okay fine

This commit is contained in:
pacnpal
2024-11-03 17:47:26 +00:00
parent 387c4740e7
commit 27f3326e22
10020 changed files with 1935769 additions and 2364 deletions


@@ -0,0 +1,12 @@
# -*- test-case-name: twisted -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted: The Framework Of Your Internet.
"""
from twisted._version import __version__ as version
__version__ = version.short()


@@ -0,0 +1,14 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# Make the twisted module executable with the default behaviour of
# running twist.
# This is not a docstring to avoid changing the string output of twist.
import sys
if __name__ == "__main__":
from twisted.application.twist._twist import Twist
sys.exit(Twist.main())


@@ -0,0 +1,24 @@
# -*- test-case-name: twisted._threads.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted integration with operating system threads.
"""
from ._ithreads import AlreadyQuit, IWorker
from ._memory import createMemoryWorker
from ._pool import pool
from ._team import Team
from ._threadworker import LockWorker, ThreadWorker
__all__ = [
"ThreadWorker",
"LockWorker",
"IWorker",
"AlreadyQuit",
"Team",
"createMemoryWorker",
"pool",
]


@@ -0,0 +1,43 @@
# -*- test-case-name: twisted._threads.test.test_convenience -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Common functionality used within the implementation of various workers.
"""
from ._ithreads import AlreadyQuit
class Quit:
"""
A flag representing whether a worker has been quit.
@ivar isSet: Whether this flag is set.
@type isSet: L{bool}
"""
def __init__(self) -> None:
"""
Create a L{Quit} un-set.
"""
self.isSet = False
def set(self) -> None:
"""
Set the flag if it has not been set.
@raise AlreadyQuit: If it has been set.
"""
self.check()
self.isSet = True
def check(self) -> None:
"""
Check if the flag has been set.
@raise AlreadyQuit: If it has been set.
"""
if self.isSet:
raise AlreadyQuit()
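
A minimal sketch of how the workers in this package use Quit as a one-shot guard (illustrative only, not part of the file above):
from twisted._threads._convenience import Quit
from twisted._threads._ithreads import AlreadyQuit

flag = Quit()
flag.check()      # does nothing while the flag is un-set
flag.set()        # marks the owner as quit
try:
    flag.set()    # a second set() (or any later check()) now raises
except AlreadyQuit:
    print("already quit")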


@@ -0,0 +1,61 @@
# -*- test-case-name: twisted._threads.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Interfaces related to threads.
"""
from typing import Callable
from zope.interface import Interface
class AlreadyQuit(Exception):
"""
This worker is dead and cannot execute more instructions.
"""
class IWorker(Interface):
"""
A worker that can perform some work concurrently.
All methods on this interface must be thread-safe.
"""
def do(task: Callable[[], None]) -> None:
"""
Perform the given task.
As an interface, this method makes no specific claims about concurrent
execution. An L{IWorker}'s C{do} implementation may defer execution
for later on the same thread, immediately on a different thread, or
some combination of the two. It is valid for a C{do} method to
schedule C{task} in such a way that it may never be executed.
It is important for some implementations to provide specific properties
with respect to where C{task} is executed, of course, and client code
may rely on a more specific implementation of C{do} than L{IWorker}.
@param task: a task to call in a thread or other concurrent context.
@type task: 0-argument callable
@raise AlreadyQuit: if C{quit} has been called.
"""
def quit() -> None:
"""
Free any resources associated with this L{IWorker} and cause it to
reject all future work.
@raise AlreadyQuit: if this method has already been called.
"""
class IExclusiveWorker(IWorker):
"""
Like L{IWorker}, but with the additional guarantee that the callables
passed to C{do} will never be invoked concurrently with each other.
"""


@@ -0,0 +1,83 @@
# -*- test-case-name: twisted._threads.test.test_memory -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of an in-memory worker that defers execution.
"""
from __future__ import annotations
from enum import Enum, auto
from typing import Callable, Literal
from zope.interface import implementer
from ._convenience import Quit
from ._ithreads import IExclusiveWorker
class NoMore(Enum):
Work = auto()
NoMoreWork = NoMore.Work
@implementer(IExclusiveWorker)
class MemoryWorker:
"""
An L{IWorker} that queues work for later performance.
@ivar _quit: a flag indicating whether this worker has been quit.
@type _quit: L{Quit}
"""
def __init__(
self,
pending: Callable[[], list[Callable[[], object] | Literal[NoMore.Work]]] = list,
) -> None:
"""
Create a L{MemoryWorker}.
"""
self._quit = Quit()
self._pending = pending()
def do(self, work: Callable[[], object]) -> None:
"""
Queue some work to perform later; see L{createMemoryWorker}.
@param work: The work to perform.
"""
self._quit.check()
self._pending.append(work)
def quit(self) -> None:
"""
Quit this worker.
"""
self._quit.set()
self._pending.append(NoMoreWork)
def createMemoryWorker() -> tuple[MemoryWorker, Callable[[], bool]]:
"""
Create an L{IWorker} that does nothing but defer work, to be performed
later.
@return: a worker that will enqueue work to perform later, and a callable
that will perform one element of that work.
"""
def perform() -> bool:
if not worker._pending:
return False
peek = worker._pending[0]
if peek is NoMoreWork:
return False
worker._pending.pop(0)
peek()
return True
worker = MemoryWorker()
return (worker, perform)
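
A short usage sketch of createMemoryWorker (illustrative; it mirrors the behaviour exercised by the tests later in this commit):
from twisted._threads import createMemoryWorker

worker, perform = createMemoryWorker()
results = []
worker.do(lambda: results.append(1))
worker.do(lambda: results.append(2))
perform()   # runs the first queued task, returns True
perform()   # runs the second, returns True
perform()   # nothing left to do, returns False
assert results == [1, 2]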


@@ -0,0 +1,73 @@
# -*- test-case-name: twisted._threads.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Top level thread pool interface, used to implement
L{twisted.python.threadpool}.
"""
from queue import Queue
from threading import Lock, Thread, local as LocalStorage
from typing import Callable, Optional
from typing_extensions import Protocol
from twisted.python.log import err
from ._ithreads import IWorker
from ._team import Team
from ._threadworker import LockWorker, ThreadWorker
class _ThreadFactory(Protocol):
def __call__(self, *, target: Callable[..., object]) -> Thread:
...
def pool(
currentLimit: Callable[[], int], threadFactory: _ThreadFactory = Thread
) -> Team:
"""
Construct a L{Team} that spawns threads as a thread pool, with the given
limiting function.
@note: Future maintainers: while the public API for the eventual move to
twisted.threads should look I{something} like this, and while this
function is necessary to implement the API described by
L{twisted.python.threadpool}, I am starting to think the idea of a hard
upper limit on threadpool size is just bad (turning memory performance
issues into correctness issues well before we run into memory
pressure), and instead we should build something with reactor
integration for slowly releasing idle threads when they're not needed
and I{rate} limiting the creation of new threads rather than just
hard-capping it.
@param currentLimit: a callable that returns the current limit on the
number of workers that the returned L{Team} should create; if it
already has more workers than that value, no new workers will be
created.
@type currentLimit: 0-argument callable returning L{int}
@param threadFactory: Factory that, when given a C{target} keyword argument,
returns a L{threading.Thread} that will run that target.
@type threadFactory: callable returning a L{threading.Thread}
@return: a new L{Team}.
"""
def startThread(target: Callable[..., object]) -> None:
return threadFactory(target=target).start()
def limitedWorkerCreator() -> Optional[IWorker]:
stats = team.statistics()
if stats.busyWorkerCount + stats.idleWorkerCount >= currentLimit():
return None
return ThreadWorker(startThread, Queue())
team = Team(
coordinator=LockWorker(Lock(), LocalStorage()),
createWorker=limitedWorkerCreator,
logException=err,
)
return team
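
A minimal sketch of driving the pool above; the limit of 4 threads is an arbitrary illustrative value:
from twisted._threads import pool

team = pool(lambda: 4)            # never more than 4 worker threads
team.do(lambda: print("ran on a pooled thread"))
team.quit()                       # reject new work and stop idle workers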


@@ -0,0 +1,232 @@
# -*- test-case-name: twisted._threads.test.test_team -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of a L{Team} of workers; a thread-pool that can allocate work to
workers.
"""
from __future__ import annotations
from collections import deque
from typing import Callable, Optional, Set
from zope.interface import implementer
from . import IWorker
from ._convenience import Quit
from ._ithreads import IExclusiveWorker
class Statistics:
"""
Statistics about a L{Team}'s current activity.
@ivar idleWorkerCount: The number of idle workers.
@type idleWorkerCount: L{int}
@ivar busyWorkerCount: The number of busy workers.
@type busyWorkerCount: L{int}
@ivar backloggedWorkCount: The number of work items passed to L{Team.do}
which have not yet been sent to a worker to be performed because not
enough workers are available.
@type backloggedWorkCount: L{int}
"""
def __init__(
self, idleWorkerCount: int, busyWorkerCount: int, backloggedWorkCount: int
) -> None:
self.idleWorkerCount = idleWorkerCount
self.busyWorkerCount = busyWorkerCount
self.backloggedWorkCount = backloggedWorkCount
@implementer(IWorker)
class Team:
"""
A composite L{IWorker} implementation.
@ivar _quit: A L{Quit} flag indicating whether this L{Team} has been quit
yet. This may be set by an arbitrary thread since L{Team.quit} may be
called from anywhere.
@ivar _coordinator: the L{IExclusiveWorker} coordinating access to this
L{Team}'s internal resources.
@ivar _createWorker: a callable that will create new workers.
@ivar _logException: a 0-argument callable called in an exception context
when there is an unhandled error from a task passed to L{Team.do}
@ivar _idle: a L{set} of idle workers.
@ivar _busyCount: the number of workers currently busy.
@ivar _pending: a C{deque} of tasks - that is, 0-argument callables passed
to L{Team.do} - that are outstanding.
@ivar _shouldQuitCoordinator: A flag indicating that the coordinator should
be quit at the next available opportunity. Unlike L{Team._quit}, this
flag is only set by the coordinator.
@ivar _toShrink: the number of workers to shrink this L{Team} by at the
next available opportunity; set in the coordinator.
"""
def __init__(
self,
coordinator: IExclusiveWorker,
createWorker: Callable[[], Optional[IWorker]],
logException: Callable[[], None],
):
"""
@param coordinator: an L{IExclusiveWorker} which will coordinate access
to resources on this L{Team}; that is to say, an
L{IExclusiveWorker} whose C{do} method ensures that its given work
will be executed in a mutually exclusive context, not in parallel
with other work enqueued by C{do} (although possibly in parallel
with the caller).
@param createWorker: A 0-argument callable that will create an
L{IWorker} to perform work.
@param logException: A 0-argument callable called in an exception
context when the work passed to C{do} raises an exception.
"""
self._quit = Quit()
self._coordinator = coordinator
self._createWorker = createWorker
self._logException = logException
# Don't touch these except from the coordinator.
self._idle: Set[IWorker] = set()
self._busyCount = 0
self._pending: "deque[Callable[..., object]]" = deque()
self._shouldQuitCoordinator = False
self._toShrink = 0
def statistics(self) -> Statistics:
"""
Gather information on the current status of this L{Team}.
@return: a L{Statistics} describing the current state of this L{Team}.
"""
return Statistics(len(self._idle), self._busyCount, len(self._pending))
def grow(self, n: int) -> None:
"""
Increase the number of idle workers by C{n}.
@param n: The number of new idle workers to create.
@type n: L{int}
"""
self._quit.check()
@self._coordinator.do
def createOneWorker() -> None:
for x in range(n):
worker = self._createWorker()
if worker is None:
return
self._recycleWorker(worker)
def shrink(self, n: Optional[int] = None) -> None:
"""
Decrease the number of idle workers by C{n}.
@param n: The number of idle workers to shut down, or L{None} (or
unspecified) to shut down all workers.
@type n: L{int} or L{None}
"""
self._quit.check()
self._coordinator.do(lambda: self._quitIdlers(n))
def _quitIdlers(self, n: Optional[int] = None) -> None:
"""
The implementation of C{shrink}, performed by the coordinator worker.
@param n: see L{Team.shrink}
"""
if n is None:
n = len(self._idle) + self._busyCount
for x in range(n):
if self._idle:
self._idle.pop().quit()
else:
self._toShrink += 1
if self._shouldQuitCoordinator and self._busyCount == 0:
self._coordinator.quit()
def do(self, task: Callable[[], object]) -> None:
"""
Perform some work in a worker created by C{createWorker}.
@param task: the callable to run
"""
self._quit.check()
self._coordinator.do(lambda: self._coordinateThisTask(task))
def _coordinateThisTask(self, task: Callable[..., object]) -> None:
"""
Select a worker to dispatch to, either an idle one or a new one, and
dispatch the task to it.
This method should run on the coordinator worker.
@param task: the task to dispatch
@type task: 0-argument callable
"""
worker = self._idle.pop() if self._idle else self._createWorker()
if worker is None:
# The createWorker method may return None if we're out of resources
# to create workers.
self._pending.append(task)
return
not_none_worker = worker
self._busyCount += 1
@worker.do
def doWork() -> None:
try:
task()
except BaseException:
self._logException()
@self._coordinator.do
def idleAndPending() -> None:
self._busyCount -= 1
self._recycleWorker(not_none_worker)
def _recycleWorker(self, worker: IWorker) -> None:
"""
Called only from coordinator.
Recycle the given worker into the idle pool.
@param worker: a worker created by C{createWorker} and now idle.
@type worker: L{IWorker}
"""
self._idle.add(worker)
if self._pending:
# Re-try the first enqueued thing.
# (Explicitly do _not_ honor _quit.)
self._coordinateThisTask(self._pending.popleft())
elif self._shouldQuitCoordinator:
self._quitIdlers()
elif self._toShrink > 0:
self._toShrink -= 1
self._idle.remove(worker)
worker.quit()
def quit(self) -> None:
"""
Stop doing work and shut down all idle workers.
"""
self._quit.set()
# In case all the workers are idle when we do this.
@self._coordinator.do
def startFinishing() -> None:
self._shouldQuitCoordinator = True
self._quitIdlers()
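
A sketch of wiring Team to in-memory workers so each step can be driven by hand, much as the tests later in this commit do (names such as coordinate and performers are illustrative):
from twisted._threads import Team, createMemoryWorker

coordinator, coordinate = createMemoryWorker()
performers = []

def createWorker():
    worker, perform = createMemoryWorker()
    performers.append(perform)
    return worker

team = Team(coordinator, createWorker, logException=lambda: None)
team.do(lambda: print("task ran"))
coordinate()                      # the coordinator hands the task to a worker
for perform in performers:
    perform()                     # the worker actually runs the task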


@@ -0,0 +1,156 @@
# -*- test-case-name: twisted._threads.test.test_threadworker -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of an L{IWorker} based on native threads and queues.
"""
from __future__ import annotations
from enum import Enum, auto
from typing import TYPE_CHECKING, Callable, Iterator, Literal, Protocol, TypeVar
if TYPE_CHECKING:
import threading
from zope.interface import implementer
from ._convenience import Quit
from ._ithreads import IExclusiveWorker
class Stop(Enum):
Thread = auto()
StopThread = Stop.Thread
T = TypeVar("T")
U = TypeVar("U")
class SimpleQueue(Protocol[T]):
def put(self, item: T) -> None:
...
def get(self) -> T:
...
# when the sentinel value is a literal in a union, this is how iter works
smartiter: Callable[[Callable[[], T | U], U], Iterator[T]]
smartiter = iter # type:ignore[assignment]
@implementer(IExclusiveWorker)
class ThreadWorker:
"""
An L{IExclusiveWorker} implemented based on a single thread and a queue.
This worker ensures exclusivity (i.e. it is an L{IExclusiveWorker}, not
merely an L{IWorker}) by performing all of the work passed to C{do} on the I{same}
thread.
"""
def __init__(
self,
startThread: Callable[[Callable[[], object]], object],
queue: SimpleQueue[Callable[[], object] | Literal[Stop.Thread]],
):
"""
Create a L{ThreadWorker} with a function to start a thread and a queue
to use to communicate with that thread.
@param startThread: a callable that takes a callable to run in another
thread.
@param queue: A L{Queue} to use to give tasks to the thread created by
C{startThread}.
"""
self._q = queue
self._hasQuit = Quit()
def work() -> None:
for task in smartiter(queue.get, StopThread):
task()
startThread(work)
def do(self, task: Callable[[], None]) -> None:
"""
Perform the given task on the thread owned by this L{ThreadWorker}.
@param task: the function to call on a thread.
"""
self._hasQuit.check()
self._q.put(task)
def quit(self) -> None:
"""
Reject all future work and stop the thread started by C{__init__}.
"""
# Reject all future work. Set this _before_ enqueueing _stop, so
# that no work is ever enqueued _after_ _stop.
self._hasQuit.set()
self._q.put(StopThread)
class SimpleLock(Protocol):
def acquire(self) -> bool:
...
def release(self) -> None:
...
@implementer(IExclusiveWorker)
class LockWorker:
"""
An L{IWorker} implemented based on a mutual-exclusion lock.
"""
def __init__(self, lock: SimpleLock, local: threading.local):
"""
@param lock: A mutual-exclusion lock, with C{acquire} and C{release}
methods.
@type lock: L{threading.Lock}
@param local: Local storage.
@type local: L{threading.local}
"""
self._quit = Quit()
self._lock: SimpleLock | None = lock
self._local = local
def do(self, work: Callable[[], None]) -> None:
"""
Do the given work on this thread, with the mutex acquired. If this is
called re-entrantly, return and wait for the outer invocation to do the
work.
@param work: the work to do with the lock held.
"""
lock = self._lock
local = self._local
self._quit.check()
working = getattr(local, "working", None)
if working is None:
assert lock is not None, "LockWorker used after quit()"
working = local.working = []
working.append(work)
lock.acquire()
try:
while working:
working.pop(0)()
finally:
lock.release()
local.working = None
else:
working.append(work)
def quit(self) -> None:
"""
Quit this L{LockWorker}.
"""
self._quit.set()
self._lock = None
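
A brief sketch of constructing the two workers above with the standard-library primitives that pool() itself uses (real threads; illustrative only):
from queue import Queue
from threading import Lock, Thread, local
from twisted._threads import LockWorker, ThreadWorker

worker = ThreadWorker(lambda work: Thread(target=work).start(), Queue())
worker.do(lambda: print("runs on the worker thread"))
worker.quit()          # enqueues the stop sentinel; the thread then exits

locker = LockWorker(Lock(), local())
locker.do(lambda: print("runs immediately, with the lock held"))
locker.quit()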


@@ -0,0 +1,7 @@
# -*- test-case-name: twisted._threads.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted._threads}.
"""


@@ -0,0 +1,56 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for convenience functionality in L{twisted._threads._convenience}.
"""
from twisted.trial.unittest import SynchronousTestCase
from .._convenience import Quit
from .._ithreads import AlreadyQuit
class QuitTests(SynchronousTestCase):
"""
Tests for L{Quit}
"""
def test_isInitiallySet(self) -> None:
"""
L{Quit.isSet} starts as L{False}.
"""
quit = Quit()
self.assertEqual(quit.isSet, False)
def test_setSetsSet(self) -> None:
"""
L{Quit.set} sets L{Quit.isSet} to L{True}.
"""
quit = Quit()
quit.set()
self.assertEqual(quit.isSet, True)
def test_checkDoesNothing(self) -> None:
"""
L{Quit.check} initially does nothing and returns L{None}.
"""
quit = Quit()
checked = quit.check() # type:ignore[func-returns-value]
self.assertIs(checked, None)
def test_checkAfterSetRaises(self) -> None:
"""
L{Quit.check} raises L{AlreadyQuit} if L{Quit.set} has been called.
"""
quit = Quit()
quit.set()
self.assertRaises(AlreadyQuit, quit.check)
def test_setTwiceRaises(self) -> None:
"""
L{Quit.set} raises L{AlreadyQuit} if it has been called previously.
"""
quit = Quit()
quit.set()
self.assertRaises(AlreadyQuit, quit.set)


@@ -0,0 +1,64 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted._threads._memory}.
"""
from zope.interface.verify import verifyObject
from twisted.trial.unittest import SynchronousTestCase
from .. import AlreadyQuit, IWorker, createMemoryWorker
class MemoryWorkerTests(SynchronousTestCase):
"""
Tests for L{MemoryWorker}.
"""
def test_createWorkerAndPerform(self) -> None:
"""
L{createMemoryWorker} creates an L{IWorker} and a callable that can
perform work on it. The performer returns C{True} if it accomplished
useful work.
"""
worker, performer = createMemoryWorker()
verifyObject(IWorker, worker)
done = []
worker.do(lambda: done.append(3))
worker.do(lambda: done.append(4))
self.assertEqual(done, [])
self.assertEqual(performer(), True)
self.assertEqual(done, [3])
self.assertEqual(performer(), True)
self.assertEqual(done, [3, 4])
def test_quitQuits(self) -> None:
"""
Calling C{quit} on the worker returned by L{createMemoryWorker} causes
its C{do} and C{quit} methods to raise L{AlreadyQuit}; its C{perform}
callable will return C{False} once the work already provided to C{do}
has been exhausted.
"""
worker, performer = createMemoryWorker()
done = []
def moreWork() -> None:
done.append(7)
worker.do(moreWork)
worker.quit()
self.assertRaises(AlreadyQuit, worker.do, moreWork)
self.assertRaises(AlreadyQuit, worker.quit)
performer()
self.assertEqual(done, [7])
self.assertEqual(performer(), False)
def test_performWhenNothingToDoYet(self) -> None:
"""
The C{perform} callable returned by L{createMemoryWorker} returns
C{False} when there is no work to do yet.
"""
worker, performer = createMemoryWorker()
self.assertEqual(performer(), False)


@@ -0,0 +1,286 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted._threads._team}.
"""
from __future__ import annotations
from typing import Callable
from twisted.python.components import proxyForInterface
from twisted.python.context import call, get
from twisted.python.failure import Failure
from twisted.trial.unittest import SynchronousTestCase
from .. import AlreadyQuit, IWorker, Team, createMemoryWorker
class ContextualWorker(proxyForInterface(IWorker, "_realWorker")): # type: ignore[misc]
"""
A worker implementation that supplies a context.
"""
def __init__(self, realWorker: IWorker, **ctx: object) -> None:
"""
Create with a real worker and a context.
"""
self._realWorker = realWorker
self._context = ctx
def do(self, work: Callable[[], object]) -> None:
"""
Perform the given work with the context given to __init__.
@param work: the work to pass on to the real worker.
"""
super().do(lambda: call(self._context, work))
class TeamTests(SynchronousTestCase):
"""
Tests for L{Team}
"""
def setUp(self) -> None:
"""
Set up a L{Team} with inspectable, synchronous workers that can be
single-stepped.
"""
coordinator, self.coordinateOnce = createMemoryWorker()
self.coordinator = ContextualWorker(coordinator, worker="coordinator")
self.workerPerformers: list[Callable[[], object]] = []
self.allWorkersEver: list[ContextualWorker] = []
self.allUnquitWorkers: list[ContextualWorker] = []
self.activePerformers: list[Callable[[], object]] = []
self.noMoreWorkers = lambda: False
def createWorker() -> ContextualWorker | None:
if self.noMoreWorkers():
return None
worker, performer = createMemoryWorker()
self.workerPerformers.append(performer)
self.activePerformers.append(performer)
cw = ContextualWorker(worker, worker=len(self.workerPerformers))
self.allWorkersEver.append(cw)
self.allUnquitWorkers.append(cw)
realQuit = cw.quit
def quitAndRemove() -> None:
realQuit()
self.allUnquitWorkers.remove(cw)
self.activePerformers.remove(performer)
cw.quit = quitAndRemove
return cw
self.failures: list[Failure] = []
def logException() -> None:
self.failures.append(Failure())
self.team = Team(coordinator, createWorker, logException)
def coordinate(self) -> bool:
"""
Perform all work currently scheduled in the coordinator.
@return: whether any coordination work was performed; if the
coordinator was idle when this was called, return L{False}
(otherwise L{True}).
"""
did = False
while self.coordinateOnce():
did = True
return did
def performAllOutstandingWork(self) -> None:
"""
Perform all work on the coordinator and worker performers that needs to
be done.
"""
continuing = True
while continuing:
continuing = self.coordinate()
for performer in self.workerPerformers:
if performer in self.activePerformers:
performer()
continuing = continuing or self.coordinate()
def test_doDoesWorkInWorker(self) -> None:
"""
L{Team.do} does the work in a worker created by the createWorker
callable.
"""
who = None
def something() -> None:
nonlocal who
who = get("worker")
self.team.do(something)
self.coordinate()
self.assertEqual(self.team.statistics().busyWorkerCount, 1)
self.performAllOutstandingWork()
self.assertEqual(who, 1)
self.assertEqual(self.team.statistics().busyWorkerCount, 0)
def test_initialStatistics(self) -> None:
"""
L{Team.statistics} returns an object with idleWorkerCount,
busyWorkerCount, and backloggedWorkCount integer attributes.
"""
stats = self.team.statistics()
self.assertEqual(stats.idleWorkerCount, 0)
self.assertEqual(stats.busyWorkerCount, 0)
self.assertEqual(stats.backloggedWorkCount, 0)
def test_growCreatesIdleWorkers(self) -> None:
"""
L{Team.grow} increases the number of available idle workers.
"""
self.team.grow(5)
self.performAllOutstandingWork()
self.assertEqual(len(self.workerPerformers), 5)
def test_growCreateLimit(self) -> None:
"""
L{Team.grow} increases the number of available idle workers until the
C{createWorker} callable starts returning None.
"""
self.noMoreWorkers = lambda: len(self.allWorkersEver) >= 3
self.team.grow(5)
self.performAllOutstandingWork()
self.assertEqual(len(self.allWorkersEver), 3)
self.assertEqual(self.team.statistics().idleWorkerCount, 3)
def test_shrinkQuitsWorkers(self) -> None:
"""
L{Team.shrink} will quit the given number of workers.
"""
self.team.grow(5)
self.performAllOutstandingWork()
self.team.shrink(3)
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 2)
def test_shrinkToZero(self) -> None:
"""
L{Team.shrink} with no arguments will stop all outstanding workers.
"""
self.team.grow(10)
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 10)
self.team.shrink()
self.assertEqual(len(self.allUnquitWorkers), 10)
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 0)
def test_moreWorkWhenNoWorkersAvailable(self) -> None:
"""
When no additional workers are available, the given work is backlogged,
and then performed later, once a worker becomes available again.
"""
self.team.grow(3)
self.coordinate()
times = 0
def something() -> None:
nonlocal times
times += 1
self.assertEqual(self.team.statistics().idleWorkerCount, 3)
for i in range(3):
self.team.do(something)
# Make progress on the coordinator but do _not_ actually complete the
# work, yet.
self.coordinate()
self.assertEqual(self.team.statistics().idleWorkerCount, 0)
self.noMoreWorkers = lambda: True
self.team.do(something)
self.coordinate()
self.assertEqual(self.team.statistics().idleWorkerCount, 0)
self.assertEqual(self.team.statistics().backloggedWorkCount, 1)
self.performAllOutstandingWork()
self.assertEqual(self.team.statistics().backloggedWorkCount, 0)
self.assertEqual(times, 4)
def test_exceptionInTask(self) -> None:
"""
When an exception is raised in a task passed to L{Team.do}, the
C{logException} given to the L{Team} at construction is invoked in the
exception context.
"""
self.team.do(lambda: 1 / 0)
self.performAllOutstandingWork()
self.assertEqual(len(self.failures), 1)
self.assertEqual(self.failures[0].type, ZeroDivisionError)
def test_quit(self) -> None:
"""
L{Team.quit} causes future invocations of L{Team.do} and L{Team.quit}
to raise L{AlreadyQuit}.
"""
self.team.quit()
self.assertRaises(AlreadyQuit, self.team.quit)
self.assertRaises(AlreadyQuit, self.team.do, list)
def test_quitQuits(self) -> None:
"""
L{Team.quit} causes all idle workers, as well as the coordinator
worker, to quit.
"""
for x in range(10):
self.team.do(list)
self.performAllOutstandingWork()
self.team.quit()
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 0)
self.assertRaises(AlreadyQuit, self.coordinator.quit)
def test_quitQuitsLaterWhenBusy(self) -> None:
"""
L{Team.quit} causes all busy workers to be quit once they've finished
the work they've been given.
"""
self.team.grow(10)
for x in range(5):
self.team.do(list)
self.coordinate()
self.team.quit()
self.coordinate()
self.assertEqual(len(self.allUnquitWorkers), 5)
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 0)
self.assertRaises(AlreadyQuit, self.coordinator.quit)
def test_quitConcurrentWithWorkHappening(self) -> None:
"""
If work happens after L{Team.quit} sets its C{Quit} flag, but before
any other work takes place, the L{Team} should still exit gracefully.
"""
self.team.do(list)
originalSet = self.team._quit.set
def performWorkConcurrently() -> None:
originalSet()
self.performAllOutstandingWork()
self.team._quit.set = performWorkConcurrently # type:ignore[method-assign]
self.team.quit()
self.assertRaises(AlreadyQuit, self.team.quit)
self.assertRaises(AlreadyQuit, self.team.do, list)
def test_shrinkWhenBusy(self) -> None:
"""
L{Team.shrink} will wait for busy workers to finish being busy and then
quit them.
"""
for x in range(10):
self.team.do(list)
self.coordinate()
self.assertEqual(len(self.allUnquitWorkers), 10)
# There should be 10 busy workers at this point.
self.team.shrink(7)
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 3)


@@ -0,0 +1,314 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted._threads._threadworker}.
"""
from __future__ import annotations
import gc
import weakref
from threading import ThreadError, local
from typing import Callable, Generic, TypeVar
from twisted.trial.unittest import SynchronousTestCase
from .. import AlreadyQuit, LockWorker, ThreadWorker
from .._threadworker import SimpleLock
T = TypeVar("T")
class FakeQueueEmpty(Exception):
"""
L{FakeQueue}'s C{get} has exhausted the queue.
"""
class WouldDeadlock(Exception):
"""
If this were a real lock, you'd be deadlocked because the lock would be
double-acquired.
"""
class FakeThread:
"""
A fake L{threading.Thread}.
@ivar target: A target function to run.
@ivar started: Has this thread been started?
@type started: L{bool}
"""
def __init__(self, target: Callable[[], object]) -> None:
"""
Create a L{FakeThread} with a target.
"""
self.target = target
self.started = False
def start(self) -> None:
"""
Set the "started" flag.
"""
self.started = True
class FakeQueue(Generic[T]):
"""
A fake L{Queue} implementing C{put} and C{get}.
@ivar items: A list of items placed by C{put} but not yet retrieved by
C{get}.
@type items: L{list}
"""
def __init__(self) -> None:
"""
Create a L{FakeQueue}.
"""
self.items: list[T] = []
def put(self, item: T) -> None:
"""
Put an item into the queue for later retrieval by L{FakeQueue.get}.
@param item: any object
"""
self.items.append(item)
def get(self) -> T:
"""
Get an item.
@return: an item previously put by C{put}.
"""
if not self.items:
raise FakeQueueEmpty()
return self.items.pop(0)
class FakeLock:
"""
A stand-in for L{threading.Lock}.
@ivar acquired: Whether this lock is presently acquired.
"""
def __init__(self) -> None:
"""
Create a lock in the un-acquired state.
"""
self.acquired = False
def acquire(self) -> bool:
"""
Acquire the lock. Raise an exception if the lock is already acquired.
"""
if self.acquired:
raise WouldDeadlock()
self.acquired = True
return True
def release(self) -> None:
"""
Release the lock. Raise an exception if the lock is not presently
acquired.
"""
if not self.acquired:
raise ThreadError()
self.acquired = False
class ThreadWorkerTests(SynchronousTestCase):
"""
Tests for L{ThreadWorker}.
"""
def setUp(self) -> None:
"""
Create a worker with fake threads.
"""
self.fakeThreads: list[FakeThread] = []
def startThread(target: Callable[[], object]) -> FakeThread:
newThread = FakeThread(target=target)
newThread.start()
self.fakeThreads.append(newThread)
return newThread
self.worker = ThreadWorker(startThread, FakeQueue())
def test_startsThreadAndPerformsWork(self) -> None:
"""
L{ThreadWorker} calls its C{startThread} callable to start a thread, and
the thread's target pulls work from the queue passed to C{__init__}.
"""
self.assertEqual(len(self.fakeThreads), 1)
self.assertEqual(self.fakeThreads[0].started, True)
done = False
def doIt() -> None:
nonlocal done
done = True
self.worker.do(doIt)
self.assertEqual(done, False)
self.assertRaises(FakeQueueEmpty, self.fakeThreads[0].target)
self.assertEqual(done, True)
def test_quitPreventsFutureCalls(self) -> None:
"""
L{ThreadWorker.quit} causes future calls to L{ThreadWorker.do} and
L{ThreadWorker.quit} to raise L{AlreadyQuit}.
"""
self.worker.quit()
self.assertRaises(AlreadyQuit, self.worker.quit)
self.assertRaises(AlreadyQuit, self.worker.do, list)
class LockWorkerTests(SynchronousTestCase):
"""
Tests for L{LockWorker}.
"""
def test_fakeDeadlock(self) -> None:
"""
The L{FakeLock} test fixture will alert us if there's a potential
deadlock.
"""
lock = FakeLock()
lock.acquire()
self.assertRaises(WouldDeadlock, lock.acquire)
def test_fakeDoubleRelease(self) -> None:
"""
The L{FakeLock} test fixture will alert us if there's a potential
double-release.
"""
lock = FakeLock()
self.assertRaises(ThreadError, lock.release)
lock.acquire()
noResult = lock.release() # type:ignore[func-returns-value]
self.assertIs(None, noResult)
self.assertRaises(ThreadError, lock.release)
def test_doExecutesImmediatelyWithLock(self) -> None:
"""
L{LockWorker.do} immediately performs the work it's given, while the
lock is acquired.
"""
storage = local()
lock = FakeLock()
worker = LockWorker(lock, storage)
done = False
acquired = False
def work() -> None:
nonlocal done, acquired
done = True
acquired = lock.acquired
worker.do(work)
self.assertEqual(done, True)
self.assertEqual(acquired, True)
self.assertEqual(lock.acquired, False)
def test_doUnwindsReentrancy(self) -> None:
"""
If L{LockWorker.do} is called recursively, it postpones the inner call
until the outer one is complete.
"""
lock = FakeLock()
worker = LockWorker(lock, local())
levels = []
acquired = []
level = 0
def work() -> None:
nonlocal level
level += 1
levels.append(level)
acquired.append(lock.acquired)
if len(levels) < 2:
worker.do(work)
level -= 1
worker.do(work)
self.assertEqual(levels, [1, 1])
self.assertEqual(acquired, [True, True])
def test_quit(self) -> None:
"""
L{LockWorker.quit} frees the resources associated with its lock and
causes further calls to C{do} and C{quit} to fail.
"""
lock = FakeLock()
ref = weakref.ref(lock)
worker = LockWorker(lock, local())
del lock
self.assertIsNot(ref(), None)
worker.quit()
gc.collect()
self.assertIs(ref(), None)
self.assertRaises(AlreadyQuit, worker.quit)
self.assertRaises(AlreadyQuit, worker.do, list)
def test_quitWhileWorking(self) -> None:
"""
If L{LockWorker.quit} is invoked during a call to L{LockWorker.do}, all
recursive work scheduled with L{LockWorker.do} will be completed and
the lock will be released.
"""
lock = FakeLock()
ref = weakref.ref(lock)
worker = LockWorker(lock, local())
phase1complete = False
phase2complete = False
phase2acquired = None
def phase1() -> None:
nonlocal phase1complete
worker.do(phase2)
worker.quit()
self.assertRaises(AlreadyQuit, worker.do, list)
phase1complete = True
def phase2() -> None:
nonlocal phase2complete, phase2acquired, lock
phase2complete = True
phase2acquired = lock.acquired
worker.do(phase1)
self.assertEqual(phase1complete, True)
self.assertEqual(phase2complete, True)
self.assertEqual(lock.acquired, False)
del lock
gc.collect()
self.assertIs(ref(), None)
def test_quitWhileGettingLock(self) -> None:
"""
If L{LockWorker.do} is called concurrently with L{LockWorker.quit}, and
C{quit} wins the race before C{do} gets the lock attribute, then
L{AlreadyQuit} will be raised.
"""
class RacyLockWorker(LockWorker):
@property
def _lock(self) -> SimpleLock | None:
self.quit()
it: SimpleLock = self.__dict__["_lock"]
return it
@_lock.setter
def _lock(self, value: SimpleLock | None) -> None:
self.__dict__["_lock"] = value
worker = RacyLockWorker(FakeLock(), local())
self.assertRaises(AlreadyQuit, worker.do, list)


@@ -0,0 +1,11 @@
"""
Provides Twisted version information.
"""
# This file is auto-generated! Do not edit!
# Use `python -m incremental.update Twisted` to change this file.
from incremental import Version
__version__ = Version("Twisted", 24, 10, 0)
__all__ = ["__version__"]


@@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Configuration objects for Twisted Applications.
"""


@@ -0,0 +1,596 @@
# -*- test-case-name: twisted.application.test.test_internet,twisted.test.test_application,twisted.test.test_cooperator -*-
"""
Implementation of L{twisted.application.internet.ClientService}, particularly
its U{automat <https://automat.readthedocs.org/>} state machine.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from random import random as _goodEnoughRandom
from typing import Callable, Optional, Protocol as TypingProtocol, TypeVar, Union
from zope.interface import implementer
from automat import TypeMachineBuilder, pep614
from twisted.application.service import Service
from twisted.internet.defer import (
CancelledError,
Deferred,
fail,
maybeDeferred,
succeed,
)
from twisted.internet.interfaces import (
IAddress,
IDelayedCall,
IProtocol,
IProtocolFactory,
IReactorTime,
IStreamClientEndpoint,
ITransport,
)
from twisted.logger import Logger
from twisted.python.failure import Failure
T = TypeVar("T")
def _maybeGlobalReactor(maybeReactor: Optional[T]) -> T:
"""
@return: the argument, or the global reactor if the argument is L{None}.
"""
if maybeReactor is None:
from twisted.internet import reactor
return reactor # type:ignore[return-value]
else:
return maybeReactor
class _Client(TypingProtocol):
def start(self) -> None:
"""
Start this L{ClientService}, initiating the connection retry loop.
"""
def stop(self) -> Deferred[None]:
"""
Stop trying to connect and disconnect any current connection.
@return: a L{Deferred} that fires when all outstanding connections are
closed and all in-progress connection attempts halted.
"""
def _connectionMade(self, protocol: _ReconnectingProtocolProxy) -> None:
"""
A connection has been made.
@param protocol: The protocol of the connection.
"""
def _connectionFailed(self, failure: Failure) -> None:
"""
Deliver connection failures to any L{ClientService.whenConnected}
L{Deferred}s that have met their failAfterFailures threshold.
@param failure: the Failure to fire the L{Deferred}s with.
"""
def _reconnect(self, failure: Optional[Failure] = None) -> None:
"""
The wait between connection attempts is done.
"""
def _clientDisconnected(self, failure: Optional[Failure] = None) -> None:
"""
The current connection has been disconnected.
"""
def whenConnected(
self, /, failAfterFailures: Optional[int] = None
) -> Deferred[IProtocol]:
"""
Retrieve the currently-connected L{Protocol}, or the next one to
connect.
@param failAfterFailures: number of connection failures after which the
Deferred will deliver a Failure (None means the Deferred will only
fail if/when the service is stopped). Set this to 1 to make the
very first connection failure signal an error. Use 2 to allow one
failure but signal an error if the subsequent retry then fails.
@return: a Deferred that fires with a protocol produced by the factory
passed to C{__init__}. It may:
- fire with L{IProtocol}
- fail with L{CancelledError} when the service is stopped
- fail with e.g.
L{DNSLookupError<twisted.internet.error.DNSLookupError>} or
L{ConnectionRefusedError<twisted.internet.error.ConnectionRefusedError>}
when the number of consecutive failed connection attempts
equals the value of "failAfterFailures"
"""
@implementer(IProtocol)
class _ReconnectingProtocolProxy:
"""
A proxy for a Protocol to provide connectionLost notification to a client
connection service, in support of reconnecting when connections are lost.
"""
def __init__(
self, protocol: IProtocol, lostNotification: Callable[[Failure], None]
) -> None:
"""
Create a L{_ReconnectingProtocolProxy}.
@param protocol: the application-provided L{interfaces.IProtocol}
provider.
@type protocol: provider of L{interfaces.IProtocol} which may
additionally provide L{interfaces.IHalfCloseableProtocol} and
L{interfaces.IFileDescriptorReceiver}.
@param lostNotification: a 1-argument callable to invoke with the
C{reason} when the connection is lost.
"""
self._protocol = protocol
self._lostNotification = lostNotification
def makeConnection(self, transport: ITransport) -> None:
self._transport = transport
self._protocol.makeConnection(transport)
def connectionLost(self, reason: Failure) -> None:
"""
The connection was lost. Relay this information.
@param reason: The reason the connection was lost.
@return: the underlying protocol's result
"""
try:
return self._protocol.connectionLost(reason)
finally:
self._lostNotification(reason)
def __getattr__(self, item: str) -> object:
return getattr(self._protocol, item)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} wrapping {self._protocol!r}>"
@implementer(IProtocolFactory)
class _DisconnectFactory:
"""
A L{_DisconnectFactory} is a proxy for L{IProtocolFactory} that catches
C{connectionLost} notifications and relays them.
"""
def __init__(
self,
protocolFactory: IProtocolFactory,
protocolDisconnected: Callable[[Failure], None],
) -> None:
self._protocolFactory = protocolFactory
self._protocolDisconnected = protocolDisconnected
def buildProtocol(self, addr: IAddress) -> Optional[IProtocol]:
"""
Create a L{_ReconnectingProtocolProxy} with the disconnect-notification
callback we were called with.
@param addr: The address the connection is coming from.
@return: a L{_ReconnectingProtocolProxy} for a protocol produced by
C{self._protocolFactory}
"""
built = self._protocolFactory.buildProtocol(addr)
if built is None:
return None
return _ReconnectingProtocolProxy(built, self._protocolDisconnected)
def __getattr__(self, item: str) -> object:
return getattr(self._protocolFactory, item)
def __repr__(self) -> str:
return "<{} wrapping {!r}>".format(
self.__class__.__name__, self._protocolFactory
)
def _deinterface(o: object) -> None:
"""
Remove the special runtime attributes set by L{implementer} so that a class
can proxy through those attributes with C{__getattr__} and thereby forward
optionally-provided interfaces by the delegated class.
"""
for zopeSpecial in ["__providedBy__", "__provides__", "__implemented__"]:
delattr(o, zopeSpecial)
_deinterface(_DisconnectFactory)
_deinterface(_ReconnectingProtocolProxy)
@dataclass
class _Core:
"""
Shared core for ClientService state machine.
"""
# required parameters
endpoint: IStreamClientEndpoint
factory: IProtocolFactory
timeoutForAttempt: Callable[[int], float]
clock: IReactorTime
prepareConnection: Optional[Callable[[IProtocol], object]]
# internal state
stopWaiters: list[Deferred[None]] = field(default_factory=list)
awaitingConnected: list[tuple[Deferred[IProtocol], Optional[int]]] = field(
default_factory=list
)
failedAttempts: int = 0
log: Logger = Logger()
def waitForStop(self) -> Deferred[None]:
self.stopWaiters.append(Deferred())
return self.stopWaiters[-1]
def unawait(self, value: Union[IProtocol, Failure]) -> None:
self.awaitingConnected, waiting = [], self.awaitingConnected
for w, remaining in waiting:
w.callback(value)
def cancelConnectWaiters(self) -> None:
self.unawait(Failure(CancelledError()))
def finishStopping(self) -> None:
self.stopWaiters, waiting = [], self.stopWaiters
for w in waiting:
w.callback(None)
def makeMachine() -> Callable[[_Core], _Client]:
machine = TypeMachineBuilder(_Client, _Core)
def waitForRetry(
c: _Client, s: _Core, failure: Optional[Failure] = None
) -> IDelayedCall:
s.failedAttempts += 1
delay = s.timeoutForAttempt(s.failedAttempts)
s.log.info(
"Scheduling retry {attempt} to connect {endpoint} in {delay} seconds.",
attempt=s.failedAttempts,
endpoint=s.endpoint,
delay=delay,
)
return s.clock.callLater(delay, c._reconnect)
def rememberConnection(
c: _Client, s: _Core, protocol: _ReconnectingProtocolProxy
) -> _ReconnectingProtocolProxy:
s.failedAttempts = 0
s.unawait(protocol._protocol)
return protocol
def attemptConnection(
c: _Client, s: _Core, failure: Optional[Failure] = None
) -> Deferred[_ReconnectingProtocolProxy]:
factoryProxy = _DisconnectFactory(s.factory, c._clientDisconnected)
connecting: Deferred[IProtocol] = s.endpoint.connect(factoryProxy)
def prepare(
protocol: _ReconnectingProtocolProxy,
) -> Deferred[_ReconnectingProtocolProxy]:
if s.prepareConnection is not None:
return maybeDeferred(s.prepareConnection, protocol).addCallback(
lambda _: protocol
)
return succeed(protocol)
# endpoint.connect() is actually generic on the type of the protocol,
# but this is not expressible via zope.interface, so we have to cast
# https://github.com/Shoobx/mypy-zope/issues/95
connectingProxy: Deferred[_ReconnectingProtocolProxy]
connectingProxy = connecting # type:ignore[assignment]
(
connectingProxy.addCallback(prepare)
.addCallback(c._connectionMade)
.addErrback(c._connectionFailed)
)
return connectingProxy
# States:
Init = machine.state("Init")
Connecting = machine.state("Connecting", attemptConnection)
Stopped = machine.state("Stopped")
Waiting = machine.state("Waiting", waitForRetry)
Connected = machine.state("Connected", rememberConnection)
Disconnecting = machine.state("Disconnecting")
Restarting = machine.state("Restarting")
Stopped = machine.state("Stopped")
# Behavior-less state transitions:
Init.upon(_Client.start).to(Connecting).returns(None)
Connecting.upon(_Client.start).loop().returns(None)
Connecting.upon(_Client._connectionMade).to(Connected).returns(None)
Waiting.upon(_Client.start).loop().returns(None)
Waiting.upon(_Client._reconnect).to(Connecting).returns(None)
Connected.upon(_Client._connectionFailed).to(Waiting).returns(None)
Connected.upon(_Client.start).loop().returns(None)
Connected.upon(_Client._clientDisconnected).to(Waiting).returns(None)
Disconnecting.upon(_Client.start).to(Restarting).returns(None)
Restarting.upon(_Client.start).to(Restarting).returns(None)
Stopped.upon(_Client.start).to(Connecting).returns(None)
# Behavior-full state transitions:
@pep614(Init.upon(_Client.stop).to(Stopped))
@pep614(Stopped.upon(_Client.stop).to(Stopped))
def immediateStop(c: _Client, s: _Core) -> Deferred[None]:
return succeed(None)
@pep614(Connecting.upon(_Client.stop).to(Disconnecting))
def connectingStop(
c: _Client, s: _Core, attempt: Deferred[_ReconnectingProtocolProxy]
) -> Deferred[None]:
waited = s.waitForStop()
attempt.cancel()
return waited
@pep614(Connecting.upon(_Client._connectionFailed, nodata=True).to(Waiting))
def failedWhenConnecting(c: _Client, s: _Core, failure: Failure) -> None:
ready = []
notReady: list[tuple[Deferred[IProtocol], Optional[int]]] = []
for w, remaining in s.awaitingConnected:
if remaining is None:
notReady.append((w, remaining))
elif remaining <= 1:
ready.append(w)
else:
notReady.append((w, remaining - 1))
s.awaitingConnected = notReady
for w in ready:
w.callback(failure)
@pep614(Waiting.upon(_Client.stop).to(Stopped))
def stop(c: _Client, s: _Core, futureRetry: IDelayedCall) -> Deferred[None]:
waited = s.waitForStop()
s.cancelConnectWaiters()
futureRetry.cancel()
s.finishStopping()
return waited
@pep614(Connected.upon(_Client.stop).to(Disconnecting))
def stopWhileConnected(
c: _Client, s: _Core, protocol: _ReconnectingProtocolProxy
) -> Deferred[None]:
waited = s.waitForStop()
protocol._transport.loseConnection()
return waited
@pep614(Connected.upon(_Client.whenConnected).loop())
def whenConnectedWhenConnected(
c: _Client,
s: _Core,
protocol: _ReconnectingProtocolProxy,
failAfterFailures: Optional[int] = None,
) -> Deferred[IProtocol]:
return succeed(protocol._protocol)
@pep614(Disconnecting.upon(_Client.stop).loop())
@pep614(Restarting.upon(_Client.stop).to(Disconnecting))
def discoStop(c: _Client, s: _Core) -> Deferred[None]:
return s.waitForStop()
@pep614(Disconnecting.upon(_Client._connectionFailed).to(Stopped))
@pep614(Disconnecting.upon(_Client._clientDisconnected).to(Stopped))
def disconnectingFinished(
c: _Client, s: _Core, failure: Optional[Failure] = None
) -> None:
s.cancelConnectWaiters()
s.finishStopping()
@pep614(Connecting.upon(_Client.whenConnected, nodata=True).loop())
@pep614(Waiting.upon(_Client.whenConnected, nodata=True).loop())
@pep614(Init.upon(_Client.whenConnected).to(Init))
@pep614(Restarting.upon(_Client.whenConnected).to(Restarting))
@pep614(Disconnecting.upon(_Client.whenConnected).to(Disconnecting))
def awaitingConnection(
c: _Client, s: _Core, failAfterFailures: Optional[int] = None
) -> Deferred[IProtocol]:
result: Deferred[IProtocol] = Deferred()
s.awaitingConnected.append((result, failAfterFailures))
return result
@pep614(Restarting.upon(_Client._clientDisconnected).to(Connecting))
def restartDone(c: _Client, s: _Core, failure: Optional[Failure] = None) -> None:
s.finishStopping()
@pep614(Stopped.upon(_Client.whenConnected).to(Stopped))
def notGoingToConnect(
c: _Client, s: _Core, failAfterFailures: Optional[int] = None
) -> Deferred[IProtocol]:
return fail(CancelledError())
return machine.build()
def backoffPolicy(
initialDelay: float = 1.0,
maxDelay: float = 60.0,
factor: float = 1.5,
jitter: Callable[[], float] = _goodEnoughRandom,
) -> Callable[[int], float]:
"""
A timeout policy for L{ClientService} which computes an exponential backoff
interval with configurable parameters.
@since: 16.1.0
@param initialDelay: Delay for the first reconnection attempt (default
1.0s).
@type initialDelay: L{float}
@param maxDelay: Maximum number of seconds between connection attempts
(default 60 seconds, or one minute). Note that this value is before
jitter is applied, so the actual maximum possible delay is this value
plus the maximum possible result of C{jitter()}.
@type maxDelay: L{float}
@param factor: A multiplicative factor by which the delay grows on each
failed reattempt. Default: 1.5.
@type factor: L{float}
@param jitter: A 0-argument callable that introduces noise into the delay.
By default, C{random.random}, i.e. a pseudorandom floating-point value
between zero and one.
@type jitter: 0-argument callable returning L{float}
@return: a 1-argument callable that, given an attempt count, returns a
floating point number; the number of seconds to delay.
@rtype: see L{ClientService.__init__}'s C{retryPolicy} argument.
"""
def policy(attempt: int) -> float:
try:
delay = min(initialDelay * (factor ** min(100, attempt)), maxDelay)
except OverflowError:
delay = maxDelay
return delay + jitter()
return policy
_defaultPolicy = backoffPolicy()
ClientMachine = makeMachine()
class ClientService(Service):
"""
A L{ClientService} maintains a single outgoing connection to a client
endpoint, reconnecting after a configurable timeout when a connection
fails, either before or after connecting.
@since: 16.1.0
"""
_log = Logger()
def __init__(
self,
endpoint: IStreamClientEndpoint,
factory: IProtocolFactory,
retryPolicy: Optional[Callable[[int], float]] = None,
clock: Optional[IReactorTime] = None,
prepareConnection: Optional[Callable[[IProtocol], object]] = None,
):
"""
@param endpoint: A L{stream client endpoint
<interfaces.IStreamClientEndpoint>} provider which will be used to
connect when the service starts.
@param factory: A L{protocol factory <interfaces.IProtocolFactory>}
which will be used to create clients for the endpoint.
@param retryPolicy: A policy configuring how long L{ClientService} will
wait between attempts to connect to C{endpoint}; a callable taking
(the number of failed connection attempts made in a row (L{int}))
and returning the number of seconds to wait before making another
attempt.
@param clock: The clock used to schedule reconnection. It's mainly
useful to be parametrized in tests. If the factory is serialized,
this attribute will not be serialized, and the default value (the
reactor) will be restored when deserialized.
@param prepareConnection: A single argument L{callable} that may return
a L{Deferred}. It will be called once with the L{protocol
<interfaces.IProtocol>} each time a new connection is made. It may
call methods on the protocol to prepare it for use (e.g.
authenticate) or validate it (check its health).
The C{prepareConnection} callable may reject the connection by raising
an exception or by returning a L{Deferred} that fails. A rejected
connection is not used to fire a L{Deferred} returned by
L{whenConnected}. Instead, L{ClientService} handles the failure
and continues as if the connection attempt were a failure
(incrementing the counter passed to C{retryPolicy}).
L{Deferred}s returned by L{whenConnected} will not fire until any
L{Deferred} returned by the C{prepareConnection} callable has fired.
Its successful return value is otherwise consumed, but ignored.
Present Since Twisted 18.7.0
"""
clock = _maybeGlobalReactor(clock)
retryPolicy = _defaultPolicy if retryPolicy is None else retryPolicy
self._machine: _Client = ClientMachine(
_Core(
endpoint,
factory,
retryPolicy,
clock,
prepareConnection=prepareConnection,
log=self._log,
)
)
def whenConnected(
self, failAfterFailures: Optional[int] = None
) -> Deferred[IProtocol]:
"""
Retrieve the currently-connected L{Protocol}, or the next one to
connect.
@param failAfterFailures: number of connection failures after which
the Deferred will deliver a Failure (None means the Deferred will
only fail if/when the service is stopped). Set this to 1 to make
the very first connection failure signal an error. Use 2 to
allow one failure but signal an error if the subsequent retry
then fails.
@type failAfterFailures: L{int} or None
@return: a Deferred that fires with a protocol produced by the
factory passed to C{__init__}
@rtype: L{Deferred} that may:
- fire with L{IProtocol}
- fail with L{CancelledError} when the service is stopped
- fail with e.g.
L{DNSLookupError<twisted.internet.error.DNSLookupError>} or
L{ConnectionRefusedError<twisted.internet.error.ConnectionRefusedError>}
when the number of consecutive failed connection attempts
equals the value of "failAfterFailures"
"""
return self._machine.whenConnected(failAfterFailures)
def startService(self) -> None:
"""
Start this L{ClientService}, initiating the connection retry loop.
"""
if self.running:
self._log.warn("Duplicate ClientService.startService {log_source}")
return
super().startService()
self._machine.start()
def stopService(self) -> Deferred[None]:
"""
Stop attempting to reconnect and close any existing connections.
@return: a L{Deferred} that fires when all outstanding connections are
closed and all in-progress connection attempts halted.
"""
super().stopService()
return self._machine.stop()
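
A minimal sketch of wiring ClientService to an endpoint; the host, port, protocol and retry parameters below are placeholder assumptions:
from twisted.application.internet import ClientService, backoffPolicy
from twisted.internet import reactor
from twisted.internet.endpoints import TCP4ClientEndpoint
from twisted.internet.protocol import Factory, Protocol

endpoint = TCP4ClientEndpoint(reactor, "example.com", 8080)
service = ClientService(
    endpoint,
    Factory.forProtocol(Protocol),
    retryPolicy=backoffPolicy(initialDelay=0.5, maxDelay=30.0),
)
service.startService()
d = service.whenConnected(failAfterFailures=3)  # fires with the connected protocol
# reactor.run() must be running for connection attempts to proceed;
# service.stopService() later shuts the connection and the retry loop down.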


@@ -0,0 +1,706 @@
# -*- test-case-name: twisted.test.test_application,twisted.test.test_twistd -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import getpass
import os
import pdb
import signal
import sys
import traceback
import warnings
from operator import attrgetter
from twisted import copyright, logger, plugin
from twisted.application import reactors, service
# Expose the new implementation of installReactor at the old location.
from twisted.application.reactors import NoSuchReactor, installReactor
from twisted.internet import defer
from twisted.internet.interfaces import _ISupportsExitSignalCapturing
from twisted.persisted import sob
from twisted.python import failure, log, logfile, runtime, usage, util
from twisted.python.reflect import namedAny, namedModule, qual
class _BasicProfiler:
"""
@ivar saveStats: if C{True}, save the stats information instead of the
human readable format
@type saveStats: C{bool}
@ivar profileOutput: the name of the file used to print profile data.
@type profileOutput: C{str}
"""
def __init__(self, profileOutput, saveStats):
self.profileOutput = profileOutput
self.saveStats = saveStats
def _reportImportError(self, module, e):
"""
Helper method to report an import error with a profile module. This
has to be explicit because some of these modules are removed by
distributions due to them being non-free.
"""
s = f"Failed to import module {module}: {e}"
s += """
This is most likely caused by your operating system not including
the module due to it being non-free. Either do not use the option
--profile, or install the module; your operating system vendor
may provide it in a separate package.
"""
raise SystemExit(s)
class ProfileRunner(_BasicProfiler):
"""
Runner for the standard profile module.
"""
def run(self, reactor):
"""
Run reactor under the standard profiler.
"""
try:
import profile
except ImportError as e:
self._reportImportError("profile", e)
p = profile.Profile()
p.runcall(reactor.run)
if self.saveStats:
p.dump_stats(self.profileOutput)
else:
tmp, sys.stdout = sys.stdout, open(self.profileOutput, "a")
try:
p.print_stats()
finally:
sys.stdout, tmp = tmp, sys.stdout
tmp.close()
class CProfileRunner(_BasicProfiler):
"""
Runner for the cProfile module.
"""
def run(self, reactor):
"""
Run reactor under the cProfile profiler.
"""
try:
import cProfile
import pstats
except ImportError as e:
self._reportImportError("cProfile", e)
p = cProfile.Profile()
p.runcall(reactor.run)
if self.saveStats:
p.dump_stats(self.profileOutput)
else:
with open(self.profileOutput, "w") as stream:
s = pstats.Stats(p, stream=stream)
s.strip_dirs()
s.sort_stats(-1)
s.print_stats()
class AppProfiler:
"""
Class which selects a specific profile runner based on configuration
options.
@ivar profiler: the name of the selected profiler.
@type profiler: C{str}
"""
profilers = {"profile": ProfileRunner, "cprofile": CProfileRunner}
def __init__(self, options):
saveStats = options.get("savestats", False)
profileOutput = options.get("profile", None)
self.profiler = options.get("profiler", "cprofile").lower()
if self.profiler in self.profilers:
profiler = self.profilers[self.profiler](profileOutput, saveStats)
self.run = profiler.run
else:
raise SystemExit(f"Unsupported profiler name: {self.profiler}")
class AppLogger:
"""
An L{AppLogger} attaches the configured log observer specified on the
commandline to a L{ServerOptions} object, a custom L{logger.ILogObserver},
or a legacy custom L{log.ILogObserver}.
@ivar _logfilename: The name of the file to which to log, if other than the
default.
@type _logfilename: C{str}
@ivar _observerFactory: Callable object that will create a log observer, or
None.
@ivar _observer: log observer added at C{start} and removed at C{stop}.
@type _observer: a callable that implements L{logger.ILogObserver} or
L{log.ILogObserver}.
"""
_observer = None
def __init__(self, options):
"""
Initialize an L{AppLogger} with a L{ServerOptions}.
"""
self._logfilename = options.get("logfile", "")
self._observerFactory = options.get("logger") or None
def start(self, application):
"""
Initialize the global logging system for the given application.
If a custom logger was specified on the command line it will be used.
If not, and an L{logger.ILogObserver} or legacy L{log.ILogObserver}
component has been set on C{application}, then it will be used as the
log observer. Otherwise a log observer will be created based on the
command line options for built-in loggers (e.g. C{--logfile}).
@param application: The application on which to check for an
L{logger.ILogObserver} or legacy L{log.ILogObserver}.
@type application: L{twisted.python.components.Componentized}
"""
if self._observerFactory is not None:
observer = self._observerFactory()
else:
observer = application.getComponent(logger.ILogObserver, None)
if observer is None:
# If there's no new ILogObserver, try the legacy one
observer = application.getComponent(log.ILogObserver, None)
if observer is None:
observer = self._getLogObserver()
self._observer = observer
if logger.ILogObserver.providedBy(self._observer):
observers = [self._observer]
elif log.ILogObserver.providedBy(self._observer):
observers = [logger.LegacyLogObserverWrapper(self._observer)]
else:
warnings.warn(
(
"Passing a logger factory which makes log observers which do "
"not implement twisted.logger.ILogObserver or "
"twisted.python.log.ILogObserver to "
"twisted.application.app.AppLogger was deprecated in "
"Twisted 16.2. Please use a factory that produces "
"twisted.logger.ILogObserver (or the legacy "
"twisted.python.log.ILogObserver) implementing objects "
"instead."
),
DeprecationWarning,
stacklevel=2,
)
observers = [logger.LegacyLogObserverWrapper(self._observer)]
logger.globalLogBeginner.beginLoggingTo(observers)
self._initialLog()
def _initialLog(self):
"""
Print twistd start log message.
"""
from twisted.internet import reactor
logger._loggerFor(self).info(
"twistd {version} ({exe} {pyVersion}) starting up.",
version=copyright.version,
exe=sys.executable,
pyVersion=runtime.shortPythonVersion(),
)
logger._loggerFor(self).info(
"reactor class: {reactor}.", reactor=qual(reactor.__class__)
)
def _getLogObserver(self):
"""
Create a log observer to be added to the logging system before running
this application.
"""
if self._logfilename == "-" or not self._logfilename:
logFile = sys.stdout
else:
logFile = logfile.LogFile.fromFullPath(self._logfilename)
return logger.textFileLogObserver(logFile)
def stop(self):
"""
Remove all log observers previously set up by L{AppLogger.start}.
"""
logger._loggerFor(self).info("Server Shut Down.")
if self._observer is not None:
logger.globalLogPublisher.removeObserver(self._observer)
self._observer = None
def fixPdb():
def do_stop(self, arg):
self.clear_all_breaks()
self.set_continue()
from twisted.internet import reactor
reactor.callLater(0, reactor.stop)
return 1
def help_stop(self):
print(
"stop - Continue execution, then cleanly shutdown the twisted " "reactor."
)
def set_quit(self):
os._exit(0)
pdb.Pdb.set_quit = set_quit
pdb.Pdb.do_stop = do_stop
pdb.Pdb.help_stop = help_stop
def runReactorWithLogging(config, oldstdout, oldstderr, profiler=None, reactor=None):
"""
Start the reactor, using profiling if specified by the configuration, and
log any error happening in the process.
@param config: configuration of the twistd application.
@type config: L{ServerOptions}
@param oldstdout: initial value of C{sys.stdout}.
@type oldstdout: C{file}
@param oldstderr: initial value of C{sys.stderr}.
@type oldstderr: C{file}
@param profiler: object used to run the reactor with profiling.
@type profiler: L{AppProfiler}
@param reactor: The reactor to use. If L{None}, the global reactor will
be used.
"""
if reactor is None:
from twisted.internet import reactor
try:
if config["profile"]:
if profiler is not None:
profiler.run(reactor)
elif config["debug"]:
sys.stdout = oldstdout
sys.stderr = oldstderr
if runtime.platformType == "posix":
signal.signal(signal.SIGUSR2, lambda *args: pdb.set_trace())
signal.signal(signal.SIGINT, lambda *args: pdb.set_trace())
fixPdb()
pdb.runcall(reactor.run)
else:
reactor.run()
except BaseException:
close = False
if config["nodaemon"]:
file = oldstdout
else:
file = open("TWISTD-CRASH.log", "a")
close = True
try:
traceback.print_exc(file=file)
file.flush()
finally:
if close:
file.close()
def getPassphrase(needed):
if needed:
return getpass.getpass("Passphrase: ")
else:
return None
def getSavePassphrase(needed):
if needed:
return util.getPassword("Encryption passphrase: ")
else:
return None
class ApplicationRunner:
"""
An object which helps running an application based on a config object.
Subclass me and implement preApplication and postApplication
methods. postApplication generally will want to run the reactor
after starting the application.
@ivar config: The config object, which provides a dict-like interface.
@ivar application: Available in postApplication, but not
preApplication. This is the application object.
@ivar profilerFactory: Factory for creating a profiler object, able to
profile the application if options are set accordingly.
@ivar profiler: Instance provided by C{profilerFactory}.
@ivar loggerFactory: Factory for creating object responsible for logging.
@ivar logger: Instance provided by C{loggerFactory}.
"""
profilerFactory = AppProfiler
loggerFactory = AppLogger
def __init__(self, config):
self.config = config
self.profiler = self.profilerFactory(config)
self.logger = self.loggerFactory(config)
def run(self):
"""
Run the application.
"""
self.preApplication()
self.application = self.createOrGetApplication()
self.logger.start(self.application)
self.postApplication()
self.logger.stop()
def startReactor(self, reactor, oldstdout, oldstderr):
"""
Run the reactor with the given configuration. Subclasses should
probably call this from C{postApplication}.
@see: L{runReactorWithLogging}
"""
if reactor is None:
from twisted.internet import reactor
runReactorWithLogging(self.config, oldstdout, oldstderr, self.profiler, reactor)
if _ISupportsExitSignalCapturing.providedBy(reactor):
self._exitSignal = reactor._exitSignal
else:
self._exitSignal = None
def preApplication(self):
"""
Override in subclass.
This should set up any state necessary before loading and
running the Application.
"""
raise NotImplementedError()
def postApplication(self):
"""
Override in subclass.
This will be called after the application has been loaded (so
the C{application} attribute will be set). Generally this
should start the application and run the reactor.
"""
raise NotImplementedError()
def createOrGetApplication(self):
"""
Create or load an Application based on the parameters found in the
given L{ServerOptions} instance.
If a subcommand was used, the L{service.IServiceMaker} that it
represents will be used to construct a service to be added to
a newly-created Application.
Otherwise, an application will be loaded based on parameters in
the config.
"""
if self.config.subCommand:
# If a subcommand was given, it's our responsibility to create
# the application, instead of load it from a file.
# loadedPlugins is set up by the ServerOptions.subCommands
# property, which is iterated somewhere in the bowels of
# usage.Options.
plg = self.config.loadedPlugins[self.config.subCommand]
ser = plg.makeService(self.config.subOptions)
application = service.Application(plg.tapname)
ser.setServiceParent(application)
else:
passphrase = getPassphrase(self.config["encrypted"])
application = getApplication(self.config, passphrase)
return application
def getApplication(config, passphrase):
s = [(config[t], t) for t in ["python", "source", "file"] if config[t]][0]
filename, style = s[0], {"file": "pickle"}.get(s[1], s[1])
try:
log.msg("Loading %s..." % filename)
application = service.loadApplication(filename, style, passphrase)
log.msg("Loaded.")
except Exception as e:
s = "Failed to load application: %s" % e
if isinstance(e, KeyError) and e.args[0] == "application":
s += """
Could not find 'application' in the file. To use 'twistd -y', your .tac
file must create a suitable object (e.g., by calling service.Application())
and store it in a variable named 'application'. twistd loads your .tac file
and scans the global variables for one of this name.
Please read the 'Using Application' HOWTO for details.
"""
traceback.print_exc(file=log.logfile)
log.msg(s)
log.deferr()
sys.exit("\n" + s + "\n")
return application
def _reactorAction():
return usage.CompleteList([r.shortName for r in reactors.getReactorTypes()])
class ReactorSelectionMixin:
"""
Provides options for selecting a reactor to install.
If a reactor is installed, the short name which was used to locate it is
saved as the value for the C{"reactor"} key.
"""
compData = usage.Completions(optActions={"reactor": _reactorAction})
messageOutput = sys.stdout
_getReactorTypes = staticmethod(reactors.getReactorTypes)
def opt_help_reactors(self):
"""
Display a list of possibly available reactor names.
"""
rcts = sorted(self._getReactorTypes(), key=attrgetter("shortName"))
notWorkingReactors = ""
for r in rcts:
try:
namedModule(r.moduleName)
self.messageOutput.write(f" {r.shortName:<4}\t{r.description}\n")
except ImportError as e:
notWorkingReactors += " !{:<4}\t{} ({})\n".format(
r.shortName,
r.description,
e.args[0],
)
if notWorkingReactors:
self.messageOutput.write("\n")
self.messageOutput.write(
" reactors not available " "on this platform:\n\n"
)
self.messageOutput.write(notWorkingReactors)
raise SystemExit(0)
def opt_reactor(self, shortName):
"""
Which reactor to use (see --help-reactors for a list of possibilities)
"""
# Actually actually actually install the reactor right at this very
# moment, before any other code (for example, a sub-command plugin)
# runs and accidentally imports and installs the default reactor.
#
# This could probably be improved somehow.
try:
installReactor(shortName)
except NoSuchReactor:
msg = (
"The specified reactor does not exist: '%s'.\n"
"See the list of available reactors with "
"--help-reactors" % (shortName,)
)
raise usage.UsageError(msg)
except Exception as e:
msg = (
"The specified reactor cannot be used, failed with error: "
"%s.\nSee the list of available reactors with "
"--help-reactors" % (e,)
)
raise usage.UsageError(msg)
else:
self["reactor"] = shortName
opt_r = opt_reactor
class ServerOptions(usage.Options, ReactorSelectionMixin):
longdesc = (
"twistd reads a twisted.application.service.Application out "
"of a file and runs it."
)
optFlags = [
[
"savestats",
None,
"save the Stats object rather than the text output of " "the profiler.",
],
["no_save", "o", "do not save state on shutdown"],
["encrypted", "e", "The specified tap/aos file is encrypted."],
]
optParameters = [
["logfile", "l", None, "log to a specified file, - for stdout"],
[
"logger",
None,
None,
"A fully-qualified name to a log observer factory to "
"use for the initial log observer. Takes precedence "
"over --logfile and --syslog (when available).",
],
[
"profile",
"p",
None,
"Run in profile mode, dumping results to specified " "file.",
],
[
"profiler",
None,
"cprofile",
"Name of the profiler to use (%s)." % ", ".join(AppProfiler.profilers),
],
["file", "f", "twistd.tap", "read the given .tap file"],
[
"python",
"y",
None,
"read an application from within a Python file " "(implies -o)",
],
["source", "s", None, "Read an application from a .tas file (AOT format)."],
["rundir", "d", ".", "Change to a supplied directory before running"],
]
compData = usage.Completions(
mutuallyExclusive=[("file", "python", "source")],
optActions={
"file": usage.CompleteFiles("*.tap"),
"python": usage.CompleteFiles("*.(tac|py)"),
"source": usage.CompleteFiles("*.tas"),
"rundir": usage.CompleteDirs(),
},
)
_getPlugins = staticmethod(plugin.getPlugins)
def __init__(self, *a, **kw):
self["debug"] = False
if "stdout" in kw:
self.stdout = kw["stdout"]
else:
self.stdout = sys.stdout
usage.Options.__init__(self)
def opt_debug(self):
"""
Run the application in the Python Debugger (implies nodaemon);
sending SIGUSR2 will drop into the debugger.
"""
defer.setDebugging(True)
failure.startDebugMode()
self["debug"] = True
opt_b = opt_debug
def opt_spew(self):
"""
Print an insanely verbose log of everything that happens.
Useful when debugging freezes or locks in complex code.
"""
sys.settrace(util.spewer)
try:
import threading
except ImportError:
return
threading.settrace(util.spewer)
def parseOptions(self, options=None):
if options is None:
options = sys.argv[1:] or ["--help"]
usage.Options.parseOptions(self, options)
def postOptions(self):
if self.subCommand or self["python"]:
self["no_save"] = True
if self["logger"] is not None:
try:
self["logger"] = namedAny(self["logger"])
except Exception as e:
raise usage.UsageError(
"Logger '{}' could not be imported: {}".format(self["logger"], e)
)
@property
def subCommands(self):
plugins = self._getPlugins(service.IServiceMaker)
self.loadedPlugins = {}
for plug in sorted(plugins, key=attrgetter("tapname")):
self.loadedPlugins[plug.tapname] = plug
yield (
plug.tapname,
None,
# Avoid resolving the options attribute right away, in case
# it's a property with a non-trivial getter (eg, one which
# imports modules).
lambda plug=plug: plug.options(),
plug.description,
)
def run(runApp, ServerOptions):
config = ServerOptions()
try:
config.parseOptions()
except usage.error as ue:
commstr = " ".join(sys.argv[0:2])
print(config)
print(f"{commstr}: {ue}")
else:
runApp(config)
def convertStyle(filein, typein, passphrase, fileout, typeout, encrypt):
application = service.loadApplication(filein, typein, passphrase)
sob.IPersistable(application).setStyle(typeout)
passphrase = getSavePassphrase(encrypt)
if passphrase:
fileout = None
sob.IPersistable(application).save(filename=fileout, passphrase=passphrase)
def startApplication(application, save):
from twisted.internet import reactor
service.IService(application).startService()
if save:
p = sob.IPersistable(application)
reactor.addSystemEventTrigger("after", "shutdown", p.save, "shutdown")
reactor.addSystemEventTrigger(
"before", "shutdown", service.IService(application).stopService
)
def _exitWithSignal(sig):
"""
Force the application to terminate with the specified signal by replacing
the signal handler with the default and sending the signal to ourselves.
@param sig: Signal to use to terminate the process with C{os.kill}.
@type sig: C{int}
"""
signal.signal(sig, signal.SIG_DFL)
os.kill(os.getpid(), sig)
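A minimal sketch of how the AppProfiler defined above is driven by a configuration mapping; the option values are illustrative assumptions rather than anything this file mandates:

from twisted.application.app import AppProfiler
from twisted.internet import reactor

options = {"profiler": "cprofile", "profile": "twistd.prof", "savestats": True}
profiler = AppProfiler(options)   # selects CProfileRunner from the "profiler" option
profiler.run(reactor)             # blocks in reactor.run(); stats are dumped to twistd.prof on exit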

View File

@@ -0,0 +1,427 @@
# -*- test-case-name: twisted.application.test.test_internet,twisted.test.test_application,twisted.test.test_cooperator -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Reactor-based Services
Here are services to run clients, servers and periodic services using
the reactor.
If you want to run a server service, L{StreamServerEndpointService} defines a
service that can wrap an arbitrary L{IStreamServerEndpoint
<twisted.internet.interfaces.IStreamServerEndpoint>}
as an L{IService}. See also L{twisted.application.strports.service} for
constructing one of these directly from a descriptive string.
Additionally, this module (dynamically) defines various Service subclasses that
let you represent clients and servers in a Service hierarchy. Endpoints APIs
should be preferred for stream server services, but since those APIs do not yet
exist for clients or datagram services, many of these are still useful.
They are as follows::
TCPServer, TCPClient,
UNIXServer, UNIXClient,
SSLServer, SSLClient,
UDPServer,
UNIXDatagramServer, UNIXDatagramClient,
MulticastServer
These classes take arbitrary arguments in their constructors and pass
them straight on to their respective reactor.listenXXX or
reactor.connectXXX calls.
For example, the following service starts a web server on port 8080:
C{TCPServer(8080, server.Site(r))}. See the documentation for the
reactor.listen/connect* methods for more information.
"""
from typing import List
from twisted.application import service
from twisted.internet import task
from twisted.internet.defer import CancelledError
from twisted.python import log
from ._client_service import ClientService, _maybeGlobalReactor, backoffPolicy
class _VolatileDataService(service.Service):
volatile: List[str] = []
def __getstate__(self):
d = service.Service.__getstate__(self)
for attr in self.volatile:
if attr in d:
del d[attr]
return d
class _AbstractServer(_VolatileDataService):
"""
@cvar volatile: list of attributes to remove from pickling.
@type volatile: C{list}
@ivar method: the type of method to call on the reactor, one of B{TCP},
B{UDP}, B{SSL} or B{UNIX}.
@type method: C{str}
@ivar reactor: the current running reactor.
@type reactor: a provider of C{IReactorTCP}, C{IReactorUDP},
C{IReactorSSL} or C{IReactorUnix}.
@ivar _port: instance of port set when the service is started.
@type _port: a provider of L{twisted.internet.interfaces.IListeningPort}.
"""
volatile = ["_port"]
method: str = ""
reactor = None
_port = None
def __init__(self, *args, **kwargs):
self.args = args
if "reactor" in kwargs:
self.reactor = kwargs.pop("reactor")
self.kwargs = kwargs
def privilegedStartService(self):
service.Service.privilegedStartService(self)
self._port = self._getPort()
def startService(self):
service.Service.startService(self)
if self._port is None:
self._port = self._getPort()
def stopService(self):
service.Service.stopService(self)
# TODO: if startup failed, should shutdown skip stopListening?
# _port won't exist
if self._port is not None:
d = self._port.stopListening()
del self._port
return d
def _getPort(self):
"""
Wrapper around the appropriate listen method of the reactor.
@return: the port object returned by the listen method.
@rtype: an object providing
L{twisted.internet.interfaces.IListeningPort}.
"""
return getattr(
_maybeGlobalReactor(self.reactor),
"listen{}".format(
self.method,
),
)(*self.args, **self.kwargs)
class _AbstractClient(_VolatileDataService):
"""
@cvar volatile: list of attributes to remove from pickling.
@type volatile: C{list}
@ivar method: the type of method to call on the reactor, one of B{TCP},
B{UDP}, B{SSL} or B{UNIX}.
@type method: C{str}
@ivar reactor: the current running reactor.
@type reactor: a provider of C{IReactorTCP}, C{IReactorUDP},
C{IReactorSSL} or C{IReactorUnix}.
@ivar _connection: instance of connection set when the service is started.
@type _connection: a provider of L{twisted.internet.interfaces.IConnector}.
"""
volatile = ["_connection"]
method: str = ""
reactor = None
_connection = None
def __init__(self, *args, **kwargs):
self.args = args
if "reactor" in kwargs:
self.reactor = kwargs.pop("reactor")
self.kwargs = kwargs
def startService(self):
service.Service.startService(self)
self._connection = self._getConnection()
def stopService(self):
service.Service.stopService(self)
if self._connection is not None:
self._connection.disconnect()
del self._connection
def _getConnection(self):
"""
Wrapper around the appropriate connect method of the reactor.
@return: the connector object returned by the connect method.
@rtype: an object providing L{twisted.internet.interfaces.IConnector}.
"""
return getattr(_maybeGlobalReactor(self.reactor), f"connect{self.method}")(
*self.args, **self.kwargs
)
_clientDoc = """Connect to {tran}
Call reactor.connect{tran} when the service starts, with the
arguments given to the constructor.
"""
_serverDoc = """Serve {tran} clients
Call reactor.listen{tran} when the service starts, with the
arguments given to the constructor. When the service stops,
stop listening. See twisted.internet.interfaces for documentation
on arguments to the reactor method.
"""
class TCPServer(_AbstractServer):
__doc__ = _serverDoc.format(tran="TCP")
method = "TCP"
class TCPClient(_AbstractClient):
__doc__ = _clientDoc.format(tran="TCP")
method = "TCP"
class UNIXServer(_AbstractServer):
__doc__ = _serverDoc.format(tran="UNIX")
method = "UNIX"
class UNIXClient(_AbstractClient):
__doc__ = _clientDoc.format(tran="UNIX")
method = "UNIX"
class SSLServer(_AbstractServer):
__doc__ = _serverDoc.format(tran="SSL")
method = "SSL"
class SSLClient(_AbstractClient):
__doc__ = _clientDoc.format(tran="SSL")
method = "SSL"
class UDPServer(_AbstractServer):
__doc__ = _serverDoc.format(tran="UDP")
method = "UDP"
class UNIXDatagramServer(_AbstractServer):
__doc__ = _serverDoc.format(tran="UNIXDatagram")
method = "UNIXDatagram"
class UNIXDatagramClient(_AbstractClient):
__doc__ = _clientDoc.format(tran="UNIXDatagram")
method = "UNIXDatagram"
class MulticastServer(_AbstractServer):
__doc__ = _serverDoc.format(tran="Multicast")
method = "Multicast"
class TimerService(_VolatileDataService):
"""
Service to periodically call a function
Every C{step} seconds call the given function with the given arguments.
The service starts the calls when it starts, and cancels them
when it stops.
@ivar clock: Source of time. This defaults to L{None}, which causes
L{twisted.internet.reactor} to be used.
Feel free to set this to something else, but it probably ought to be
set *before* calling L{startService}.
@type clock: L{IReactorTime<twisted.internet.interfaces.IReactorTime>}
@ivar call: Function and arguments to call periodically.
@type call: L{tuple} of C{(callable, args, kwargs)}
"""
volatile = ["_loop", "_loopFinished"]
def __init__(self, step, callable, *args, **kwargs):
"""
@param step: The number of seconds between calls.
@type step: L{float}
@param callable: Function to call
@type callable: L{callable}
@param args: Positional arguments to pass to function
@param kwargs: Keyword arguments to pass to function
"""
self.step = step
self.call = (callable, args, kwargs)
self.clock = None
def startService(self):
service.Service.startService(self)
callable, args, kwargs = self.call
# we have to make a new LoopingCall each time we're started, because
# an active LoopingCall remains active when serialized. If
# LoopingCall were a _VolatileDataService, we wouldn't need to do
# this.
self._loop = task.LoopingCall(callable, *args, **kwargs)
self._loop.clock = _maybeGlobalReactor(self.clock)
self._loopFinished = self._loop.start(self.step, now=True)
self._loopFinished.addErrback(self._failed)
def _failed(self, why):
# make a note that the LoopingCall is no longer looping, so we don't
# try to shut it down a second time in stopService. I think this
# should be in LoopingCall. -warner
self._loop.running = False
log.err(why)
def stopService(self):
"""
Stop the service.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is fired when the
currently running call (if any) is finished.
"""
if self._loop.running:
self._loop.stop()
self._loopFinished.addCallback(lambda _: service.Service.stopService(self))
return self._loopFinished
class CooperatorService(service.Service):
"""
Simple L{service.IService} which starts and stops a L{twisted.internet.task.Cooperator}.
"""
def __init__(self):
self.coop = task.Cooperator(started=False)
def coiterate(self, iterator):
return self.coop.coiterate(iterator)
def startService(self):
self.coop.start()
def stopService(self):
self.coop.stop()
class StreamServerEndpointService(service.Service):
"""
A L{StreamServerEndpointService} is an L{IService} which runs a server on a
listening port described by an L{IStreamServerEndpoint
<twisted.internet.interfaces.IStreamServerEndpoint>}.
@ivar factory: A server factory which will be used to listen on the
endpoint.
@ivar endpoint: An L{IStreamServerEndpoint
<twisted.internet.interfaces.IStreamServerEndpoint>} provider
which will be used to listen when the service starts.
@ivar _waitingForPort: a Deferred, if C{listen} has been invoked on the
endpoint; otherwise L{None}.
@ivar _raiseSynchronously: Defines error-handling behavior for the case
where C{listen(...)} raises an exception before C{startService} or
C{privilegedStartService} have completed.
@type _raiseSynchronously: C{bool}
@since: 10.2
"""
_raiseSynchronously = False
def __init__(self, endpoint, factory):
self.endpoint = endpoint
self.factory = factory
self._waitingForPort = None
def privilegedStartService(self):
"""
Start listening on the endpoint.
"""
service.Service.privilegedStartService(self)
self._waitingForPort = self.endpoint.listen(self.factory)
raisedNow = []
def handleIt(err):
if self._raiseSynchronously:
raisedNow.append(err)
elif not err.check(CancelledError):
log.err(err)
self._waitingForPort.addErrback(handleIt)
if raisedNow:
raisedNow[0].raiseException()
self._raiseSynchronously = False
def startService(self):
"""
Start listening on the endpoint, unless L{privilegedStartService} got
around to it already.
"""
service.Service.startService(self)
if self._waitingForPort is None:
self.privilegedStartService()
def stopService(self):
"""
Stop listening on the port if it is already listening, otherwise,
cancel the attempt to listen.
@return: a L{Deferred<twisted.internet.defer.Deferred>} which fires
with L{None} when the port has stopped listening.
"""
self._waitingForPort.cancel()
def stopIt(port):
if port is not None:
return port.stopListening()
d = self._waitingForPort.addCallback(stopIt)
def stop(passthrough):
self.running = False
return passthrough
d.addBoth(stop)
return d
__all__ = [
"TimerService",
"CooperatorService",
"MulticastServer",
"StreamServerEndpointService",
"UDPServer",
"ClientService",
"TCPServer",
"TCPClient",
"UNIXServer",
"UNIXClient",
"SSLServer",
"SSLClient",
"UNIXDatagramServer",
"UNIXDatagramClient",
"ClientService",
"backoffPolicy",
]
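A minimal sketch of composing these services into an application hierarchy; the application name, interval, and port are illustrative assumptions:

from twisted.application import internet, service
from twisted.internet.protocol import Factory, Protocol

application = service.Application("demo")
heartbeat = internet.TimerService(60.0, print, "heartbeat")      # call print("heartbeat") every 60s
heartbeat.setServiceParent(application)
echo = internet.TCPServer(8080, Factory.forProtocol(Protocol))   # reactor.listenTCP(8080, ...) on start
echo.setServiceParent(application)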

View File

@@ -0,0 +1,87 @@
# -*- test-case-name: twisted.test.test_application -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Plugin-based system for enumerating available reactors and installing one of
them.
"""
from typing import Iterable, cast
from zope.interface import Attribute, Interface, implementer
from twisted.internet.interfaces import IReactorCore
from twisted.plugin import IPlugin, getPlugins
from twisted.python.reflect import namedAny
class IReactorInstaller(Interface):
"""
Definition of a reactor which can probably be installed.
"""
shortName = Attribute(
"""
A brief string giving the user-facing name of this reactor.
"""
)
description = Attribute(
"""
A longer string giving a user-facing description of this reactor.
"""
)
def install() -> None:
"""
Install this reactor.
"""
# TODO - A method which provides a best-guess as to whether this reactor
# can actually be used in the execution environment.
class NoSuchReactor(KeyError):
"""
Raised when an attempt is made to install a reactor which cannot be found.
"""
@implementer(IPlugin, IReactorInstaller)
class Reactor:
"""
@ivar moduleName: The fully-qualified Python name of the module of which
the install callable is an attribute.
"""
def __init__(self, shortName: str, moduleName: str, description: str):
self.shortName = shortName
self.moduleName = moduleName
self.description = description
def install(self) -> None:
namedAny(self.moduleName).install()
def getReactorTypes() -> Iterable[IReactorInstaller]:
"""
Return an iterator of L{IReactorInstaller} plugins.
"""
return getPlugins(IReactorInstaller)
def installReactor(shortName: str) -> IReactorCore:
"""
Install the reactor with the given C{shortName} attribute.
@raise NoSuchReactor: If no reactor is found with a matching C{shortName}.
@raise Exception: Anything that the specified reactor can raise when installed.
"""
for installer in getReactorTypes():
if installer.shortName == shortName:
installer.install()
from twisted.internet import reactor
return cast(IReactorCore, reactor)
raise NoSuchReactor(shortName)
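A minimal sketch of installing a reactor by its plugin short name; "poll" is an illustrative, platform-dependent choice, and the call must happen before anything imports the default reactor:

from twisted.application.reactors import NoSuchReactor, installReactor

try:
    reactor = installReactor("poll")          # e.g. what twistd --reactor=poll does
except NoSuchReactor:
    from twisted.internet import reactor      # fall back to the platform default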

View File

@@ -0,0 +1,7 @@
# -*- test-case-name: twisted.application.runner.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Facilities for running a Twisted application.
"""

View File

@@ -0,0 +1,99 @@
# -*- test-case-name: twisted.application.runner.test.test_exit -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
System exit support.
"""
import typing
from enum import IntEnum
from sys import exit as sysexit, stderr, stdout
from typing import Union
try:
import posix as Status
except ImportError:
class Status: # type: ignore[no-redef]
"""
Object to hang C{EX_*} values off of as a substitute for L{posix}.
"""
EX__BASE = 64
EX_OK = 0
EX_USAGE = EX__BASE
EX_DATAERR = EX__BASE + 1
EX_NOINPUT = EX__BASE + 2
EX_NOUSER = EX__BASE + 3
EX_NOHOST = EX__BASE + 4
EX_UNAVAILABLE = EX__BASE + 5
EX_SOFTWARE = EX__BASE + 6
EX_OSERR = EX__BASE + 7
EX_OSFILE = EX__BASE + 8
EX_CANTCREAT = EX__BASE + 9
EX_IOERR = EX__BASE + 10
EX_TEMPFAIL = EX__BASE + 11
EX_PROTOCOL = EX__BASE + 12
EX_NOPERM = EX__BASE + 13
EX_CONFIG = EX__BASE + 14
class ExitStatus(IntEnum):
"""
Standard exit status codes for system programs.
@cvar EX_OK: Successful termination.
@cvar EX_USAGE: Command line usage error.
@cvar EX_DATAERR: Data format error.
@cvar EX_NOINPUT: Cannot open input.
@cvar EX_NOUSER: Addressee unknown.
@cvar EX_NOHOST: Host name unknown.
@cvar EX_UNAVAILABLE: Service unavailable.
@cvar EX_SOFTWARE: Internal software error.
@cvar EX_OSERR: System error (e.g., can't fork).
@cvar EX_OSFILE: Critical OS file missing.
@cvar EX_CANTCREAT: Can't create (user) output file.
@cvar EX_IOERR: Input/output error.
@cvar EX_TEMPFAIL: Temporary failure; the user is invited to retry.
@cvar EX_PROTOCOL: Remote error in protocol.
@cvar EX_NOPERM: Permission denied.
@cvar EX_CONFIG: Configuration error.
"""
EX_OK = Status.EX_OK
EX_USAGE = Status.EX_USAGE
EX_DATAERR = Status.EX_DATAERR
EX_NOINPUT = Status.EX_NOINPUT
EX_NOUSER = Status.EX_NOUSER
EX_NOHOST = Status.EX_NOHOST
EX_UNAVAILABLE = Status.EX_UNAVAILABLE
EX_SOFTWARE = Status.EX_SOFTWARE
EX_OSERR = Status.EX_OSERR
EX_OSFILE = Status.EX_OSFILE
EX_CANTCREAT = Status.EX_CANTCREAT
EX_IOERR = Status.EX_IOERR
EX_TEMPFAIL = Status.EX_TEMPFAIL
EX_PROTOCOL = Status.EX_PROTOCOL
EX_NOPERM = Status.EX_NOPERM
EX_CONFIG = Status.EX_CONFIG
def exit(status: Union[int, ExitStatus], message: str = "") -> "typing.NoReturn":
"""
Exit the python interpreter with the given status and an optional message.
@param status: An exit status. An appropriate value from L{ExitStatus} is
recommended.
@param message: An optional message to print.
"""
if message:
if status == ExitStatus.EX_OK:
out = stdout
else:
out = stderr
out.write(message)
out.write("\n")
sysexit(status)
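A minimal sketch of a command-line entry point using these helpers; the argument check is an illustrative assumption:

from sys import argv

from twisted.application.runner._exit import ExitStatus, exit

def main() -> None:
    if len(argv) < 2:
        # Non-zero status, so the message is written to stderr.
        exit(ExitStatus.EX_USAGE, f"usage: {argv[0]} CONFIG_FILE")
    exit(ExitStatus.EX_OK)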

View File

@@ -0,0 +1,282 @@
# -*- test-case-name: twisted.application.runner.test.test_pidfile -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
PID file.
"""
from __future__ import annotations
import errno
from os import getpid, kill, name as SYSTEM_NAME
from types import TracebackType
from typing import Any, Optional, Type
from zope.interface import Interface, implementer
from twisted.logger import Logger
from twisted.python.filepath import FilePath
class IPIDFile(Interface):
"""
Manages a file that remembers a process ID.
"""
def read() -> int:
"""
Read the process ID stored in this PID file.
@return: The contained process ID.
@raise NoPIDFound: If this PID file does not exist.
@raise EnvironmentError: If this PID file cannot be read.
@raise ValueError: If this PID file's content is invalid.
"""
def writeRunningPID() -> None:
"""
Store the PID of the current process in this PID file.
@raise EnvironmentError: If this PID file cannot be written.
"""
def remove() -> None:
"""
Remove this PID file.
@raise EnvironmentError: If this PID file cannot be removed.
"""
def isRunning() -> bool:
"""
Determine whether there is a running process corresponding to the PID
in this PID file.
@return: True if this PID file contains a PID and a process with that
PID is currently running; false otherwise.
@raise EnvironmentError: If this PID file cannot be read.
@raise InvalidPIDFileError: If this PID file's content is invalid.
@raise StalePIDFileError: If this PID file's content refers to a PID
for which there is no corresponding running process.
"""
def __enter__() -> "IPIDFile":
"""
Enter a context using this PIDFile.
Writes the PID file with the PID of the running process.
@raise AlreadyRunningError: A process corresponding to the PID in this
PID file is already running.
"""
def __exit__(
excType: Optional[Type[BaseException]],
excValue: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
"""
Exit a context using this PIDFile.
Removes the PID file.
"""
@implementer(IPIDFile)
class PIDFile:
"""
Concrete implementation of L{IPIDFile}.
This implementation is presently not supported on non-POSIX platforms.
Specifically, calling L{PIDFile.isRunning} will raise
L{NotImplementedError}.
"""
_log = Logger()
@staticmethod
def _format(pid: int) -> bytes:
"""
Format a PID file's content.
@param pid: A process ID.
@return: Formatted PID file contents.
"""
return f"{int(pid)}\n".encode()
def __init__(self, filePath: FilePath[Any]) -> None:
"""
@param filePath: The path to the PID file on disk.
"""
self.filePath = filePath
def read(self) -> int:
pidString = b""
try:
with self.filePath.open() as fh:
for pidString in fh:
break
except OSError as e:
if e.errno == errno.ENOENT: # No such file
raise NoPIDFound("PID file does not exist")
raise
try:
return int(pidString)
except ValueError:
raise InvalidPIDFileError(
f"non-integer PID value in PID file: {pidString!r}"
)
def _write(self, pid: int) -> None:
"""
Store a PID in this PID file.
@param pid: A PID to store.
@raise EnvironmentError: If this PID file cannot be written.
"""
self.filePath.setContent(self._format(pid=pid))
def writeRunningPID(self) -> None:
self._write(getpid())
def remove(self) -> None:
self.filePath.remove()
def isRunning(self) -> bool:
try:
pid = self.read()
except NoPIDFound:
return False
if SYSTEM_NAME == "posix":
return self._pidIsRunningPOSIX(pid)
else:
raise NotImplementedError(f"isRunning is not implemented on {SYSTEM_NAME}")
@staticmethod
def _pidIsRunningPOSIX(pid: int) -> bool:
"""
POSIX implementation for running process check.
Determine whether there is a running process corresponding to the given
PID.
@param pid: The PID to check.
@return: True if the given PID is currently running; false otherwise.
@raise EnvironmentError: If this PID file cannot be read.
@raise InvalidPIDFileError: If this PID file's content is invalid.
@raise StalePIDFileError: If this PID file's content refers to a PID
for which there is no corresponding running process.
"""
try:
kill(pid, 0)
except OSError as e:
if e.errno == errno.ESRCH: # No such process
raise StalePIDFileError("PID file refers to non-existing process")
elif e.errno == errno.EPERM: # Not permitted to kill
return True
else:
raise
else:
return True
def __enter__(self) -> "PIDFile":
try:
if self.isRunning():
raise AlreadyRunningError()
except StalePIDFileError:
self._log.info("Replacing stale PID file: {log_source}")
self.writeRunningPID()
return self
def __exit__(
self,
excType: Optional[Type[BaseException]],
excValue: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.remove()
return None
@implementer(IPIDFile)
class NonePIDFile:
"""
PID file implementation that does nothing.
This is meant to be used as an "active None" object in place of a PID file
when no PID file is desired.
"""
def __init__(self) -> None:
pass
def read(self) -> int:
raise NoPIDFound("PID file does not exist")
def _write(self, pid: int) -> None:
"""
Store a PID in this PID file.
@param pid: A PID to store.
@raise EnvironmentError: If this PID file cannot be written.
@note: This implementation always raises an L{EnvironmentError}.
"""
raise OSError(errno.EPERM, "Operation not permitted")
def writeRunningPID(self) -> None:
self._write(0)
def remove(self) -> None:
raise OSError(errno.ENOENT, "No such file or directory")
def isRunning(self) -> bool:
return False
def __enter__(self) -> "NonePIDFile":
return self
def __exit__(
self,
excType: Optional[Type[BaseException]],
excValue: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
return None
nonePIDFile: IPIDFile = NonePIDFile()
class AlreadyRunningError(Exception):
"""
Process is already running.
"""
class InvalidPIDFileError(Exception):
"""
PID file contents are invalid.
"""
class StalePIDFileError(Exception):
"""
PID file contents are valid, but there is no process with the referenced
PID.
"""
class NoPIDFound(Exception):
"""
No PID found in PID file.
"""

View File

@@ -0,0 +1,166 @@
# -*- test-case-name: twisted.application.runner.test.test_runner -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted application runner.
"""
from os import kill
from signal import SIGTERM
from sys import stderr
from typing import Any, Callable, Mapping, TextIO
from attr import Factory, attrib, attrs
from constantly import NamedConstant
from twisted.internet.interfaces import IReactorCore
from twisted.logger import (
FileLogObserver,
FilteringLogObserver,
Logger,
LogLevel,
LogLevelFilterPredicate,
globalLogBeginner,
textFileLogObserver,
)
from ._exit import ExitStatus, exit
from ._pidfile import AlreadyRunningError, InvalidPIDFileError, IPIDFile, nonePIDFile
@attrs(frozen=True)
class Runner:
"""
Twisted application runner.
@cvar _log: The logger attached to this class.
@ivar _reactor: The reactor to start and run the application in.
@ivar _pidFile: The file to store the running process ID in.
@ivar _kill: Whether this runner should kill an existing running
instance of the application.
@ivar _defaultLogLevel: The default log level to start the logging
system with.
@ivar _logFile: A file stream to write logging output to.
@ivar _fileLogObserverFactory: A factory for the file log observer to
use when starting the logging system.
@ivar _whenRunning: Hook to call after the reactor is running;
this is where the application code that relies on the reactor gets
called.
@ivar _whenRunningArguments: Keyword arguments to pass to
C{whenRunning} when it is called.
@ivar _reactorExited: Hook to call after the reactor exits.
@ivar _reactorExitedArguments: Keyword arguments to pass to
C{reactorExited} when it is called.
"""
_log = Logger()
_reactor = attrib(type=IReactorCore)
_pidFile = attrib(type=IPIDFile, default=nonePIDFile)
_kill = attrib(type=bool, default=False)
_defaultLogLevel = attrib(type=NamedConstant, default=LogLevel.info)
_logFile = attrib(type=TextIO, default=stderr)
_fileLogObserverFactory = attrib(
type=Callable[[TextIO], FileLogObserver], default=textFileLogObserver
)
_whenRunning = attrib(type=Callable[..., None], default=lambda **_: None)
_whenRunningArguments = attrib(type=Mapping[str, Any], default=Factory(dict))
_reactorExited = attrib(type=Callable[..., None], default=lambda **_: None)
_reactorExitedArguments = attrib(type=Mapping[str, Any], default=Factory(dict))
def run(self) -> None:
"""
Run this command.
"""
pidFile = self._pidFile
self.killIfRequested()
try:
with pidFile:
self.startLogging()
self.startReactor()
self.reactorExited()
except AlreadyRunningError:
exit(ExitStatus.EX_CONFIG, "Already running.")
# When testing, patched exit doesn't exit
return # type: ignore[unreachable]
def killIfRequested(self) -> None:
"""
If C{self._kill} is true, attempt to kill a running instance of the
application.
"""
pidFile = self._pidFile
if self._kill:
if pidFile is nonePIDFile:
exit(ExitStatus.EX_USAGE, "No PID file specified.")
# When testing, patched exit doesn't exit
return # type: ignore[unreachable]
try:
pid = pidFile.read()
except OSError:
exit(ExitStatus.EX_IOERR, "Unable to read PID file.")
# When testing, patched exit doesn't exit
return # type: ignore[unreachable]
except InvalidPIDFileError:
exit(ExitStatus.EX_DATAERR, "Invalid PID file.")
# When testing, patched exit doesn't exit
return # type: ignore[unreachable]
self.startLogging()
self._log.info("Terminating process: {pid}", pid=pid)
kill(pid, SIGTERM)
exit(ExitStatus.EX_OK)
# When testing, patched exit doesn't exit
return # type: ignore[unreachable]
def startLogging(self) -> None:
"""
Start the L{twisted.logger} logging system.
"""
logFile = self._logFile
fileLogObserverFactory = self._fileLogObserverFactory
fileLogObserver = fileLogObserverFactory(logFile)
logLevelPredicate = LogLevelFilterPredicate(
defaultLogLevel=self._defaultLogLevel
)
filteringObserver = FilteringLogObserver(fileLogObserver, [logLevelPredicate])
globalLogBeginner.beginLoggingTo([filteringObserver])
def startReactor(self) -> None:
"""
Register C{self._whenRunning} with the reactor so that it is called
once the reactor is running, then start the reactor.
"""
self._reactor.callWhenRunning(self.whenRunning)
self._log.info("Starting reactor...")
self._reactor.run()
def whenRunning(self) -> None:
"""
Call C{self._whenRunning} with C{self._whenRunningArguments}.
@note: This method is called after the reactor starts running.
"""
self._whenRunning(**self._whenRunningArguments)
def reactorExited(self) -> None:
"""
Call C{self._reactorExited} with C{self._reactorExitedArguments}.
@note: This method is called after the reactor exits.
"""
self._reactorExited(**self._reactorExitedArguments)
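A minimal sketch wiring L{Runner} to the global reactor; the hook and log level are illustrative assumptions (attrs exposes the private attributes as keyword arguments without the leading underscore):

from twisted.internet import reactor
from twisted.logger import LogLevel
from twisted.application.runner._pidfile import nonePIDFile
from twisted.application.runner._runner import Runner

def whenRunning(**kwargs: object) -> None:
    print("application started")            # runs once the reactor is up

Runner(
    reactor=reactor,
    pidFile=nonePIDFile,
    defaultLogLevel=LogLevel.debug,
    whenRunning=whenRunning,
).run()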

View File

@@ -0,0 +1,7 @@
# -*- test-case-name: twisted.application.runner.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.runner}.
"""

View File

@@ -0,0 +1,82 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.runner._exit}.
"""
from io import StringIO
from typing import Optional, Union
import twisted.trial.unittest
from ...runner import _exit
from .._exit import ExitStatus, exit
class ExitTests(twisted.trial.unittest.TestCase):
"""
Tests for L{exit}.
"""
def setUp(self) -> None:
self.exit = DummyExit()
self.patch(_exit, "sysexit", self.exit)
def test_exitStatusInt(self) -> None:
"""
L{exit} given an L{int} status code will pass it to L{sys.exit}.
"""
status = 1234
exit(status)
self.assertEqual(self.exit.arg, status) # type: ignore[unreachable]
def test_exitConstant(self) -> None:
"""
L{exit} given an L{ExitStatus} constant passes the corresponding
value to L{sys.exit}.
"""
status = ExitStatus.EX_CONFIG
exit(status)
self.assertEqual(self.exit.arg, status.value) # type: ignore[unreachable]
def test_exitMessageZero(self) -> None:
"""
L{exit} given a status code of zero (C{0}) writes the given message to
standard output.
"""
out = StringIO()
self.patch(_exit, "stdout", out)
message = "Hello, world."
exit(0, message)
self.assertEqual(out.getvalue(), message + "\n") # type: ignore[unreachable]
def test_exitMessageNonZero(self) -> None:
"""
L{exit} given a non-zero status code writes the given message to
standard error.
"""
out = StringIO()
self.patch(_exit, "stderr", out)
message = "Hello, world."
exit(64, message)
self.assertEqual(out.getvalue(), message + "\n") # type: ignore[unreachable]
class DummyExit:
"""
Stub for L{sys.exit} that remembers whether it's been called and, if it
has, what argument it was given.
"""
def __init__(self) -> None:
self.exited = False
def __call__(self, arg: Optional[Union[int, str]] = None) -> None:
assert not self.exited
self.arg = arg
self.exited = True

View File

@@ -0,0 +1,419 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.runner._pidfile}.
"""
import errno
from functools import wraps
from os import getpid, name as SYSTEM_NAME
from typing import Any, Callable, Optional
from zope.interface.verify import verifyObject
from typing_extensions import NoReturn
import twisted.trial.unittest
from twisted.python.filepath import FilePath
from twisted.python.runtime import platform
from twisted.trial.unittest import SkipTest
from ...runner import _pidfile
from .._pidfile import (
AlreadyRunningError,
InvalidPIDFileError,
IPIDFile,
NonePIDFile,
NoPIDFound,
PIDFile,
StalePIDFileError,
)
def ifPlatformSupported(f: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorator for tests that are not expected to work on all platforms.
Calling L{PIDFile.isRunning} currently raises L{NotImplementedError} on
non-POSIX platforms.
On an unsupported platform, we expect to see any test that calls
L{PIDFile.isRunning} to raise either L{NotImplementedError}, L{SkipTest},
or C{self.failureException}.
(C{self.failureException} may occur in a test that checks for a specific
exception but it gets NotImplementedError instead.)
@param f: The test method to decorate.
@return: The wrapped callable.
"""
@wraps(f)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
supported = platform.getType() == "posix"
if supported:
return f(self, *args, **kwargs)
else:
e = self.assertRaises(
(NotImplementedError, SkipTest, self.failureException),
f,
self,
*args,
**kwargs,
)
if isinstance(e, NotImplementedError):
self.assertTrue(str(e).startswith("isRunning is not implemented on "))
return wrapper
class PIDFileTests(twisted.trial.unittest.TestCase):
"""
Tests for L{PIDFile}.
"""
def filePath(self, content: Optional[bytes] = None) -> FilePath[str]:
filePath = FilePath(self.mktemp())
if content is not None:
filePath.setContent(content)
return filePath
def test_interface(self) -> None:
"""
L{PIDFile} conforms to L{IPIDFile}.
"""
pidFile = PIDFile(self.filePath())
verifyObject(IPIDFile, pidFile)
def test_formatWithPID(self) -> None:
"""
L{PIDFile._format} returns the expected format when given a PID.
"""
self.assertEqual(PIDFile._format(pid=1337), b"1337\n")
def test_readWithPID(self) -> None:
"""
L{PIDFile.read} returns the PID from the given file path.
"""
pid = 1337
pidFile = PIDFile(self.filePath(PIDFile._format(pid=pid)))
self.assertEqual(pid, pidFile.read())
def test_readEmptyPID(self) -> None:
"""
L{PIDFile.read} raises L{InvalidPIDFileError} when given an empty file
path.
"""
pidValue = b""
pidFile = PIDFile(self.filePath(b""))
e = self.assertRaises(InvalidPIDFileError, pidFile.read)
self.assertEqual(str(e), f"non-integer PID value in PID file: {pidValue!r}")
def test_readWithBogusPID(self) -> None:
"""
L{PIDFile.read} raises L{InvalidPIDFileError} when the file contains a
non-integer value.
"""
pidValue = b"$foo!"
pidFile = PIDFile(self.filePath(pidValue))
e = self.assertRaises(InvalidPIDFileError, pidFile.read)
self.assertEqual(str(e), f"non-integer PID value in PID file: {pidValue!r}")
def test_readDoesntExist(self) -> None:
"""
L{PIDFile.read} raises L{NoPIDFound} when given a non-existing file
path.
"""
pidFile = PIDFile(self.filePath())
e = self.assertRaises(NoPIDFound, pidFile.read)
self.assertEqual(str(e), "PID file does not exist")
def test_readOpenRaisesOSErrorNotENOENT(self) -> None:
"""
L{PIDFile.read} re-raises L{OSError} if the associated C{errno} is
anything other than L{errno.ENOENT}.
"""
def oops(mode: str = "r") -> NoReturn:
raise OSError(errno.EIO, "I/O error")
self.patch(FilePath, "open", oops)
pidFile = PIDFile(self.filePath())
error = self.assertRaises(OSError, pidFile.read)
self.assertEqual(error.errno, errno.EIO)
def test_writePID(self) -> None:
"""
L{PIDFile._write} stores the given PID.
"""
pid = 1995
pidFile = PIDFile(self.filePath())
pidFile._write(pid)
self.assertEqual(pidFile.read(), pid)
def test_writePIDInvalid(self) -> None:
"""
L{PIDFile._write} raises L{ValueError} when given an invalid PID.
"""
pidFile = PIDFile(self.filePath())
self.assertRaises(ValueError, pidFile._write, "burp")
def test_writeRunningPID(self) -> None:
"""
L{PIDFile.writeRunningPID} stores the PID for the current process.
"""
pidFile = PIDFile(self.filePath())
pidFile.writeRunningPID()
self.assertEqual(pidFile.read(), getpid())
def test_remove(self) -> None:
"""
L{PIDFile.remove} removes the PID file.
"""
pidFile = PIDFile(self.filePath(b""))
self.assertTrue(pidFile.filePath.exists())
pidFile.remove()
self.assertFalse(pidFile.filePath.exists())
@ifPlatformSupported
def test_isRunningDoesExist(self) -> None:
"""
L{PIDFile.isRunning} returns true for a process that does exist.
"""
pidFile = PIDFile(self.filePath())
pidFile._write(1337)
def kill(pid: int, signal: int) -> None:
return # Don't actually kill anything
self.patch(_pidfile, "kill", kill)
self.assertTrue(pidFile.isRunning())
@ifPlatformSupported
def test_isRunningThis(self) -> None:
"""
L{PIDFile.isRunning} returns true for this process (which is running).
@note: This differs from L{PIDFileTests.test_isRunningDoesExist} in
that it actually invokes the C{kill} system call, which is useful for
testing our chosen method for probing the existence of a process.
"""
pidFile = PIDFile(self.filePath())
pidFile.writeRunningPID()
self.assertTrue(pidFile.isRunning())
@ifPlatformSupported
def test_isRunningDoesNotExist(self) -> None:
"""
L{PIDFile.isRunning} raises L{StalePIDFileError} for a process that
does not exist (errno=ESRCH).
"""
pidFile = PIDFile(self.filePath())
pidFile._write(1337)
def kill(pid: int, signal: int) -> None:
raise OSError(errno.ESRCH, "No such process")
self.patch(_pidfile, "kill", kill)
self.assertRaises(StalePIDFileError, pidFile.isRunning)
@ifPlatformSupported
def test_isRunningNotAllowed(self) -> None:
"""
L{PIDFile.isRunning} returns true for a process that we are not allowed
to kill (errno=EPERM).
"""
pidFile = PIDFile(self.filePath())
pidFile._write(1337)
def kill(pid: int, signal: int) -> None:
raise OSError(errno.EPERM, "Operation not permitted")
self.patch(_pidfile, "kill", kill)
self.assertTrue(pidFile.isRunning())
@ifPlatformSupported
def test_isRunningInit(self) -> None:
"""
L{PIDFile.isRunning} returns true for a process that we are not allowed
to kill (errno=EPERM).
@note: This differs from L{PIDFileTests.test_isRunningNotAllowed} in
that it actually invokes the C{kill} system call, which is useful for
testing our chosen method for probing the existence of a process
that we are not allowed to kill.
@note: In this case, we try killing C{init}, which is process #1 on
POSIX systems, so this test is not portable. C{init} should always be
running and should not be killable by non-root users.
"""
if SYSTEM_NAME != "posix":
raise SkipTest("This test assumes POSIX")
pidFile = PIDFile(self.filePath())
pidFile._write(1) # PID 1 is init on POSIX systems
self.assertTrue(pidFile.isRunning())
@ifPlatformSupported
def test_isRunningUnknownErrno(self) -> None:
"""
L{PIDFile.isRunning} re-raises L{OSError} if the attached C{errno}
value from L{os.kill} is not an expected one.
"""
pidFile = PIDFile(self.filePath())
pidFile.writeRunningPID()
def kill(pid: int, signal: int) -> None:
raise OSError(errno.EEXIST, "File exists")
self.patch(_pidfile, "kill", kill)
self.assertRaises(OSError, pidFile.isRunning)
def test_isRunningNoPIDFile(self) -> None:
"""
L{PIDFile.isRunning} returns false if the PID file doesn't exist.
"""
pidFile = PIDFile(self.filePath())
self.assertFalse(pidFile.isRunning())
def test_contextManager(self) -> None:
"""
When used as a context manager, a L{PIDFile} stores the current PID on
entry, then removes the PID file on exit.
"""
pidFile = PIDFile(self.filePath())
self.assertFalse(pidFile.filePath.exists())
with pidFile:
self.assertTrue(pidFile.filePath.exists())
self.assertEqual(pidFile.read(), getpid())
self.assertFalse(pidFile.filePath.exists())
@ifPlatformSupported
def test_contextManagerDoesntExist(self) -> None:
"""
When used as a context manager, a L{PIDFile} will replace the
underlying PIDFile rather than raising L{AlreadyRunningError} if the
contained PID file exists but refers to a non-running PID.
"""
pidFile = PIDFile(self.filePath())
pidFile._write(1337)
def kill(pid: int, signal: int) -> None:
raise OSError(errno.ESRCH, "No such process")
self.patch(_pidfile, "kill", kill)
e = self.assertRaises(StalePIDFileError, pidFile.isRunning)
self.assertEqual(str(e), "PID file refers to non-existing process")
with pidFile:
self.assertEqual(pidFile.read(), getpid())
@ifPlatformSupported
def test_contextManagerAlreadyRunning(self) -> None:
"""
When used as a context manager, a L{PIDFile} will raise
L{AlreadyRunningError} if there is already a running process with
the contained PID.
"""
pidFile = PIDFile(self.filePath())
pidFile._write(1337)
def kill(pid: int, signal: int) -> None:
return # Don't actually kill anything
self.patch(_pidfile, "kill", kill)
self.assertTrue(pidFile.isRunning())
self.assertRaises(AlreadyRunningError, pidFile.__enter__)
class NonePIDFileTests(twisted.trial.unittest.TestCase):
"""
Tests for L{NonePIDFile}.
"""
def test_interface(self) -> None:
"""
L{NonePIDFile} conforms to L{IPIDFile}.
"""
pidFile = NonePIDFile()
verifyObject(IPIDFile, pidFile)
def test_read(self) -> None:
"""
L{NonePIDFile.read} raises L{NoPIDFound}.
"""
pidFile = NonePIDFile()
e = self.assertRaises(NoPIDFound, pidFile.read)
self.assertEqual(str(e), "PID file does not exist")
def test_write(self) -> None:
"""
L{NonePIDFile._write} raises L{OSError} with an errno of L{errno.EPERM}.
"""
pidFile = NonePIDFile()
error = self.assertRaises(OSError, pidFile._write, 0)
self.assertEqual(error.errno, errno.EPERM)
def test_writeRunningPID(self) -> None:
"""
L{NonePIDFile.writeRunningPID} raises L{OSError} with an errno of
L{errno.EPERM}.
"""
pidFile = NonePIDFile()
error = self.assertRaises(OSError, pidFile.writeRunningPID)
self.assertEqual(error.errno, errno.EPERM)
def test_remove(self) -> None:
"""
L{NonePIDFile.remove} raises L{OSError} with an errno of L{errno.EPERM}.
"""
pidFile = NonePIDFile()
error = self.assertRaises(OSError, pidFile.remove)
self.assertEqual(error.errno, errno.ENOENT)
def test_isRunning(self) -> None:
"""
L{NonePIDFile.isRunning} returns L{False}.
"""
pidFile = NonePIDFile()
self.assertEqual(pidFile.isRunning(), False)
def test_contextManager(self) -> None:
"""
When used as a context manager, a L{NonePIDFile} doesn't raise, despite
not existing.
"""
pidFile = NonePIDFile()
with pidFile:
pass

View File

@@ -0,0 +1,454 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.runner._runner}.
"""
import errno
from io import StringIO
from signal import SIGTERM
from types import TracebackType
from typing import Any, Iterable, List, Optional, TextIO, Tuple, Type, Union, cast
from attr import Factory, attrib, attrs
import twisted.trial.unittest
from twisted.internet.testing import MemoryReactor
from twisted.logger import (
FileLogObserver,
FilteringLogObserver,
ILogObserver,
LogBeginner,
LogLevel,
LogLevelFilterPredicate,
LogPublisher,
)
from twisted.python.filepath import FilePath
from ...runner import _runner
from .._exit import ExitStatus
from .._pidfile import NonePIDFile, PIDFile
from .._runner import Runner
class RunnerTests(twisted.trial.unittest.TestCase):
"""
Tests for L{Runner}.
"""
def filePath(self, content: Optional[bytes] = None) -> FilePath[str]:
filePath = FilePath(self.mktemp())
if content is not None:
filePath.setContent(content)
return filePath
def setUp(self) -> None:
# Patch exit and kill so we can capture usage and prevent actual exits
# and kills.
self.exit = DummyExit()
self.kill = DummyKill()
self.patch(_runner, "exit", self.exit)
self.patch(_runner, "kill", self.kill)
# Patch getpid so we get a known result
self.pid = 1337
self.pidFileContent = f"{self.pid}\n".encode()
# Patch globalLogBeginner so that we aren't trying to install multiple
# global log observers.
self.stdout = StringIO()
self.stderr = StringIO()
self.stdio = DummyStandardIO(self.stdout, self.stderr)
self.warnings = DummyWarningsModule()
self.globalLogPublisher = LogPublisher()
self.globalLogBeginner = LogBeginner(
self.globalLogPublisher,
self.stdio.stderr,
self.stdio,
self.warnings,
)
self.patch(_runner, "stderr", self.stderr)
self.patch(_runner, "globalLogBeginner", self.globalLogBeginner)
def test_runInOrder(self) -> None:
"""
L{Runner.run} calls the expected methods in order.
"""
runner = DummyRunner(reactor=MemoryReactor())
runner.run()
self.assertEqual(
runner.calledMethods,
[
"killIfRequested",
"startLogging",
"startReactor",
"reactorExited",
],
)
def test_runUsesPIDFile(self) -> None:
"""
L{Runner.run} uses the provided PID file.
"""
pidFile = DummyPIDFile()
runner = Runner(reactor=MemoryReactor(), pidFile=pidFile)
self.assertFalse(pidFile.entered)
self.assertFalse(pidFile.exited)
runner.run()
self.assertTrue(pidFile.entered)
self.assertTrue(pidFile.exited)
def test_runAlreadyRunning(self) -> None:
"""
L{Runner.run} exits with L{ExitStatus.EX_CONFIG} and the expected
message if a process is already running that corresponds to the given
PID file.
"""
pidFile = PIDFile(self.filePath(self.pidFileContent))
pidFile.isRunning = lambda: True # type: ignore[method-assign]
runner = Runner(reactor=MemoryReactor(), pidFile=pidFile)
runner.run()
self.assertEqual(self.exit.status, ExitStatus.EX_CONFIG)
self.assertEqual(self.exit.message, "Already running.")
def test_killNotRequested(self) -> None:
"""
L{Runner.killIfRequested} when C{kill} is false doesn't exit and
doesn't indiscriminately murder anyone.
"""
runner = Runner(reactor=MemoryReactor())
runner.killIfRequested()
self.assertEqual(self.kill.calls, [])
self.assertFalse(self.exit.exited)
def test_killRequestedWithoutPIDFile(self) -> None:
"""
L{Runner.killIfRequested} when C{kill} is true but C{pidFile} is
L{nonePIDFile} exits with L{ExitStatus.EX_USAGE} and the expected
message; and also doesn't indiscriminately murder anyone.
"""
runner = Runner(reactor=MemoryReactor(), kill=True)
runner.killIfRequested()
self.assertEqual(self.kill.calls, [])
self.assertEqual(self.exit.status, ExitStatus.EX_USAGE)
self.assertEqual(self.exit.message, "No PID file specified.")
def test_killRequestedWithPIDFile(self) -> None:
"""
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
performs a targeted killing of the appropriate process.
"""
pidFile = PIDFile(self.filePath(self.pidFileContent))
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
runner.killIfRequested()
self.assertEqual(self.kill.calls, [(self.pid, SIGTERM)])
self.assertEqual(self.exit.status, ExitStatus.EX_OK)
self.assertIdentical(self.exit.message, None)
def test_killRequestedWithPIDFileCantRead(self) -> None:
"""
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
that it can't read exits with L{ExitStatus.EX_IOERR}.
"""
pidFile = PIDFile(self.filePath(None))
def read() -> int:
raise OSError(errno.EACCES, "Permission denied")
pidFile.read = read # type: ignore[method-assign]
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
runner.killIfRequested()
self.assertEqual(self.exit.status, ExitStatus.EX_IOERR)
self.assertEqual(self.exit.message, "Unable to read PID file.")
def test_killRequestedWithPIDFileEmpty(self) -> None:
"""
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
containing no value exits with L{ExitStatus.EX_DATAERR}.
"""
pidFile = PIDFile(self.filePath(b""))
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
runner.killIfRequested()
self.assertEqual(self.exit.status, ExitStatus.EX_DATAERR)
self.assertEqual(self.exit.message, "Invalid PID file.")
def test_killRequestedWithPIDFileNotAnInt(self) -> None:
"""
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
containing a non-integer value exits with L{ExitStatus.EX_DATAERR}.
"""
pidFile = PIDFile(self.filePath(b"** totally not a number, dude **"))
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
runner.killIfRequested()
self.assertEqual(self.exit.status, ExitStatus.EX_DATAERR)
self.assertEqual(self.exit.message, "Invalid PID file.")
def test_startLogging(self) -> None:
"""
L{Runner.startLogging} sets up a filtering observer with a log level
predicate set to the given log level that contains a file observer of
the given type which writes to the given file.
"""
logFile = StringIO()
# Patch the log beginner so that we don't try to start the already
# running (started by trial) logging system.
class LogBeginner:
observers: List[ILogObserver] = []
def beginLoggingTo(self, observers: Iterable[ILogObserver]) -> None:
LogBeginner.observers = list(observers)
self.patch(_runner, "globalLogBeginner", LogBeginner())
# Patch FilteringLogObserver so we can capture its arguments
class MockFilteringLogObserver(FilteringLogObserver):
observer: Optional[ILogObserver] = None
predicates: List[LogLevelFilterPredicate] = []
def __init__(
self,
observer: ILogObserver,
predicates: Iterable[LogLevelFilterPredicate],
negativeObserver: ILogObserver = cast(ILogObserver, lambda event: None),
) -> None:
MockFilteringLogObserver.observer = observer
MockFilteringLogObserver.predicates = list(predicates)
FilteringLogObserver.__init__(
self, observer, predicates, negativeObserver
)
self.patch(_runner, "FilteringLogObserver", MockFilteringLogObserver)
# Patch FileLogObserver so we can capture its arguments
class MockFileLogObserver(FileLogObserver):
outFile: Optional[TextIO] = None
def __init__(self, outFile: TextIO) -> None:
MockFileLogObserver.outFile = outFile
FileLogObserver.__init__(self, outFile, str)
# Start logging
runner = Runner(
reactor=MemoryReactor(),
defaultLogLevel=LogLevel.critical,
logFile=logFile,
fileLogObserverFactory=MockFileLogObserver,
)
runner.startLogging()
# Check for a filtering observer
self.assertEqual(len(LogBeginner.observers), 1)
self.assertIsInstance(LogBeginner.observers[0], FilteringLogObserver)
# Check log level predicate with the correct default log level
self.assertEqual(len(MockFilteringLogObserver.predicates), 1)
self.assertIsInstance(
MockFilteringLogObserver.predicates[0], LogLevelFilterPredicate
)
self.assertIdentical(
MockFilteringLogObserver.predicates[0].defaultLogLevel, LogLevel.critical
)
# Check for a file observer attached to the filtering observer
observer = cast(MockFileLogObserver, MockFilteringLogObserver.observer)
self.assertIsInstance(observer, MockFileLogObserver)
# Check for the file we gave it
self.assertIdentical(observer.outFile, logFile)
def test_startReactorWithReactor(self) -> None:
"""
L{Runner.startReactor} with the C{reactor} argument runs the given
reactor.
"""
reactor = MemoryReactor()
runner = Runner(reactor=reactor)
runner.startReactor()
self.assertTrue(reactor.hasRun)
def test_startReactorWhenRunning(self) -> None:
"""
L{Runner.startReactor} ensures that C{whenRunning} is called with
C{whenRunningArguments} when the reactor is running.
"""
self._testHook("whenRunning", "startReactor")
def test_whenRunningWithArguments(self) -> None:
"""
L{Runner.whenRunning} calls C{whenRunning} with
C{whenRunningArguments}.
"""
self._testHook("whenRunning")
def test_reactorExitedWithArguments(self) -> None:
"""
L{Runner.whenRunning} calls C{reactorExited} with
C{reactorExitedArguments}.
"""
self._testHook("reactorExited")
def _testHook(self, methodName: str, callerName: Optional[str] = None) -> None:
"""
Verify that the named hook is run with the expected arguments as
specified by the arguments used to create the L{Runner}, when the
specified caller is invoked.
@param methodName: The name of the hook to verify.
@param callerName: The name of the method that is expected to cause the
hook to be called.
If C{None}, use the L{Runner} method with the same name as the
hook.
"""
if callerName is None:
callerName = methodName
arguments = dict(a=object(), b=object(), c=object())
argumentsSeen = []
def hook(**arguments: object) -> None:
argumentsSeen.append(arguments)
runnerArguments = {
methodName: hook,
f"{methodName}Arguments": arguments.copy(),
}
runner = Runner(
reactor=MemoryReactor(), **runnerArguments # type: ignore[arg-type]
)
hookCaller = getattr(runner, callerName)
hookCaller()
self.assertEqual(len(argumentsSeen), 1)
self.assertEqual(argumentsSeen[0], arguments)
@attrs(frozen=True)
class DummyRunner(Runner):
"""
Stub for L{Runner}.
Keep track of calls to some methods without actually doing anything.
"""
calledMethods = attrib(type=List[str], default=Factory(list))
def killIfRequested(self) -> None:
self.calledMethods.append("killIfRequested")
def startLogging(self) -> None:
self.calledMethods.append("startLogging")
def startReactor(self) -> None:
self.calledMethods.append("startReactor")
def reactorExited(self) -> None:
self.calledMethods.append("reactorExited")
class DummyPIDFile(NonePIDFile):
"""
Stub for L{PIDFile}.
Tracks context manager entry/exit without doing anything.
"""
def __init__(self) -> None:
NonePIDFile.__init__(self)
self.entered = False
self.exited = False
def __enter__(self) -> "DummyPIDFile":
self.entered = True
return self
def __exit__(
self,
excType: Optional[Type[BaseException]],
excValue: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.exited = True
class DummyExit:
"""
Stub for L{_exit.exit} that remembers whether it's been called and, if it has,
what arguments it was given.
"""
def __init__(self) -> None:
self.exited = False
def __call__(
self, status: Union[int, ExitStatus], message: Optional[str] = None
) -> None:
assert not self.exited
self.status = status
self.message = message
self.exited = True
class DummyKill:
"""
Stub for L{os.kill} that remembers whether it's been called and, if it has,
what arguments it was given.
"""
def __init__(self) -> None:
self.calls: List[Tuple[int, int]] = []
def __call__(self, pid: int, sig: int) -> None:
self.calls.append((pid, sig))
class DummyStandardIO:
"""
Stub for L{sys} which provides L{StringIO} streams as stdout and stderr.
"""
def __init__(self, stdout: TextIO, stderr: TextIO) -> None:
self.stdout = stdout
self.stderr = stderr
class DummyWarningsModule:
"""
Stub for L{warnings} which provides a C{showwarning} method that is a no-op.
"""
def showwarning(*args: Any, **kwargs: Any) -> None:
"""
Do nothing.
@param args: ignored.
@param kwargs: ignored.
"""

View File

@@ -0,0 +1,420 @@
# -*- test-case-name: twisted.application.test.test_service -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Service architecture for Twisted.
Services are arranged in a hierarchy. At the leaves of the hierarchy,
the services which actually interact with the outside world are started.
Services can be named or anonymous -- usually, they will be named if
there is need to access them through the hierarchy (from a parent or
a sibling).
Maintainer: Moshe Zadka
"""
from zope.interface import Attribute, Interface, implementer
from twisted.internet import defer
from twisted.persisted import sob
from twisted.plugin import IPlugin
from twisted.python import components
from twisted.python.reflect import namedAny
class IServiceMaker(Interface):
"""
An object which can be used to construct services in a flexible
way.
This interface should most often be implemented along with
L{twisted.plugin.IPlugin}, and will most often be used by the
'twistd' command.
"""
tapname = Attribute(
"A short string naming this Twisted plugin, for example 'web' or "
"'pencil'. This name will be used as the subcommand of 'twistd'."
)
description = Attribute(
"A brief summary of the features provided by this "
"Twisted application plugin."
)
options = Attribute(
"A C{twisted.python.usage.Options} subclass defining the "
"configuration options for this application."
)
def makeService(options):
"""
Create and return an object providing
L{twisted.application.service.IService}.
@param options: A mapping (typically a C{dict} or
L{twisted.python.usage.Options} instance) of configuration
options to desired configuration values.
"""
@implementer(IPlugin, IServiceMaker)
class ServiceMaker:
"""
Utility class to simplify the definition of L{IServiceMaker} plugins.
"""
def __init__(self, name, module, description, tapname):
self.name = name
self.module = module
self.description = description
self.tapname = tapname
@property
def options(self):
return namedAny(self.module).Options
@property
def makeService(self):
return namedAny(self.module).makeService
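# A hedged usage sketch for ServiceMaker: a plugin module (conventionally
# placed under twisted/plugins/) can expose an IServiceMaker provider with a
# single assignment. The module path and tap name below are illustrative and
# not taken from this codebase:
#
#     from twisted.application.service import ServiceMaker
#
#     myServer = ServiceMaker(
#         "My Server",                 # human-readable name
#         "mypackage.tap",             # module providing Options and makeService
#         "Runs my example service.",  # description shown in help output
#         "myserver",                  # sub-command name for twistd/twist
#     )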
class IService(Interface):
"""
A service.
Run start-up and shut-down code at the appropriate times.
"""
name = Attribute("A C{str} which is the name of the service or C{None}.")
running = Attribute("A C{boolean} which indicates whether the service is running.")
parent = Attribute("An C{IServiceCollection} which is the parent or C{None}.")
def setName(name):
"""
Set the name of the service.
@type name: C{str}
@raise RuntimeError: Raised if the service already has a parent.
"""
def setServiceParent(parent):
"""
Set the parent of the service. This method is responsible for setting
the C{parent} attribute on this service (the child service).
@type parent: L{IServiceCollection}
@raise RuntimeError: Raised if the service already has a parent
or if the service has a name and the parent already has a child
by that name.
"""
def disownServiceParent():
"""
Use this API to remove an L{IService} from an L{IServiceCollection}.
This method is used symmetrically with L{setServiceParent} in that it
sets the C{parent} attribute on the child.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is triggered when the
service has finished shutting down. If shutting down is immediate,
a value can be returned (usually, L{None}).
"""
def startService():
"""
Start the service.
"""
def stopService():
"""
Stop the service.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is triggered when the
service has finished shutting down. If shutting down is immediate,
a value can be returned (usually, L{None}).
"""
def privilegedStartService():
"""
Do preparation work for starting the service.
This is where things which should be done before changing directory,
changing root, or shedding privileges are done.
"""
@implementer(IService)
class Service:
"""
Base class for services.
Most services should inherit from this class. It handles the
book-keeping responsibilities of starting and stopping, as well
as not serializing this book-keeping information.
"""
running = 0
name = None
parent = None
def __getstate__(self):
dict = self.__dict__.copy()
if "running" in dict:
del dict["running"]
return dict
def setName(self, name):
if self.parent is not None:
raise RuntimeError("cannot change name when parent exists")
self.name = name
def setServiceParent(self, parent):
if self.parent is not None:
self.disownServiceParent()
parent = IServiceCollection(parent, parent)
self.parent = parent
self.parent.addService(self)
def disownServiceParent(self):
d = self.parent.removeService(self)
self.parent = None
return d
def privilegedStartService(self):
pass
def startService(self):
self.running = 1
def stopService(self):
self.running = 0
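# A minimal subclassing sketch (illustrative only): concrete services usually
# override startService/stopService and call up to Service so that the
# C{running} book-keeping stays correct.
#
#     class HeartbeatService(Service):
#         def startService(self):
#             super().startService()
#             # start timers, listeners, etc. here
#
#         def stopService(self):
#             # stop timers, listeners, etc. here
#             return super().stopService()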
class IServiceCollection(Interface):
"""
Collection of services.
Contain several services, and manage their start-up/shut-down.
Services can be accessed by name if they have a name, and it
is always possible to iterate over them.
"""
def getServiceNamed(name):
"""
Get the child service with a given name.
@type name: C{str}
@rtype: L{IService}
@raise KeyError: Raised if the service has no child with the
given name.
"""
def __iter__():
"""
Get an iterator over all child services.
"""
def addService(service):
"""
Add a child service.
Only implementations of L{IService.setServiceParent} should use this
method.
@type service: L{IService}
@raise RuntimeError: Raised if the service has a child with
the given name.
"""
def removeService(service):
"""
Remove a child service.
Only implementations of L{IService.disownServiceParent} should
use this method.
@type service: L{IService}
@raise ValueError: Raised if the given service is not a child.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is triggered when the
service has finished shutting down. If shutting down is immediate,
a value can be returned (usually, L{None}).
"""
@implementer(IServiceCollection)
class MultiService(Service):
"""
Straightforward Service Container.
Hold a collection of services, and manage them in a simplistic
way. No service will wait for another, but this object itself
will not finish shutting down until all of its child services
will finish.
"""
def __init__(self):
self.services = []
self.namedServices = {}
self.parent = None
def privilegedStartService(self):
Service.privilegedStartService(self)
for service in self:
service.privilegedStartService()
def startService(self):
Service.startService(self)
for service in self:
service.startService()
def stopService(self):
Service.stopService(self)
l = []
services = list(self)
services.reverse()
for service in services:
l.append(defer.maybeDeferred(service.stopService))
return defer.DeferredList(l)
def getServiceNamed(self, name):
return self.namedServices[name]
def __iter__(self):
return iter(self.services)
def addService(self, service):
if service.name is not None:
if service.name in self.namedServices:
raise RuntimeError(
"cannot have two services with same name" " '%s'" % service.name
)
self.namedServices[service.name] = service
self.services.append(service)
if self.running:
# It may be too late for that, but we will do our best
service.privilegedStartService()
service.startService()
def removeService(self, service):
if service.name:
del self.namedServices[service.name]
self.services.remove(service)
if self.running:
# Returning this so as not to lose information from the
# MultiService.stopService deferred.
return service.stopService()
else:
return None
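# A composition sketch (names are illustrative): children are attached with
# setServiceParent() and can later be retrieved by name from the parent.
#
#     root = MultiService()
#     child = Service()
#     child.setName("worker")
#     child.setServiceParent(root)
#     root.getServiceNamed("worker") is child  # True
#     root.startService()                      # starts every child service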
class IProcess(Interface):
"""
Process running parameters.
Represents parameters for how processes should be run.
"""
processName = Attribute(
"""
A C{str} giving the name the process should have in ps (or L{None}
to leave the name alone).
"""
)
uid = Attribute(
"""
An C{int} giving the user id as which the process should run (or
L{None} to leave the UID alone).
"""
)
gid = Attribute(
"""
An C{int} giving the group id as which the process should run (or
L{None} to leave the GID alone).
"""
)
@implementer(IProcess)
class Process:
"""
Process running parameters.
Sets up uid/gid in the constructor, and has a default
of L{None} as C{processName}.
"""
processName = None
def __init__(self, uid=None, gid=None):
"""
Set uid and gid.
@param uid: The user ID as whom to execute the process. If
this is L{None}, no attempt will be made to change the UID.
@param gid: The group ID as whom to execute the process. If
this is L{None}, no attempt will be made to change the GID.
"""
self.uid = uid
self.gid = gid
def Application(name, uid=None, gid=None):
"""
Return a compound class.
Return an object supporting the L{IService}, L{IServiceCollection},
L{IProcess} and L{sob.IPersistable} interfaces, with the given
parameters. Always access the return value by explicit casting to
one of the interfaces.
"""
ret = components.Componentized()
availableComponents = [MultiService(), Process(uid, gid), sob.Persistent(ret, name)]
for comp in availableComponents:
ret.addComponent(comp, ignoreClass=1)
IService(ret).setName(name)
return ret
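# A usage sketch for Application(): the returned Componentized object does not
# itself provide IService, so (as the docstring says) adapt it explicitly to
# the interface you need. The name below is illustrative:
#
#     app = Application("example")
#     svc = IService(app)      # the MultiService component
#     proc = IProcess(app)     # uid/gid/processName parameters
#     svc.name == "example"    # True; set by Application()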
def loadApplication(filename, kind, passphrase=None):
"""
Load Application from a given file.
The serialization format it was saved in should be given as
C{kind}, and is one of C{pickle}, C{source}, C{xml} or C{python}. If
C{passphrase} is given, the application was encrypted with the
given passphrase.
@type filename: C{str}
@type kind: C{str}
@type passphrase: C{str}
"""
if kind == "python":
application = sob.loadValueFromFile(filename, "application")
else:
application = sob.load(filename, kind)
return application
__all__ = [
"IServiceMaker",
"IService",
"Service",
"IServiceCollection",
"MultiService",
"IProcess",
"Process",
"Application",
"loadApplication",
]

View File

@@ -0,0 +1,83 @@
# -*- test-case-name: twisted.test.test_strports -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Construct listening port services from a simple string description.
@see: L{twisted.internet.endpoints.serverFromString}
@see: L{twisted.internet.endpoints.clientFromString}
"""
from typing import Optional, cast
from twisted.application.internet import StreamServerEndpointService
from twisted.internet import endpoints, interfaces
def _getReactor() -> interfaces.IReactorCore:
from twisted.internet import reactor
return cast(interfaces.IReactorCore, reactor)
def service(
description: str,
factory: interfaces.IProtocolFactory,
reactor: Optional[interfaces.IReactorCore] = None,
) -> StreamServerEndpointService:
"""
Return the service corresponding to a description.
@param description: The description of the listening port, in the syntax
described by L{twisted.internet.endpoints.serverFromString}.
@type description: C{str}
@param factory: The protocol factory which will build protocols for
connections to this service.
@type factory: L{twisted.internet.interfaces.IProtocolFactory}
@rtype: C{twisted.application.service.IService}
@return: the service corresponding to a description of a reliable stream
server.
@see: L{twisted.internet.endpoints.serverFromString}
"""
if reactor is None:
reactor = _getReactor()
svc = StreamServerEndpointService(
endpoints.serverFromString(reactor, description), factory
)
svc._raiseSynchronously = True
return svc
def listen(
description: str, factory: interfaces.IProtocolFactory
) -> interfaces.IListeningPort:
"""
Listen on a port corresponding to a description.
@param description: The description of the connecting port, in the syntax
described by L{twisted.internet.endpoints.serverFromString}.
@type description: L{str}
@param factory: The protocol factory which will build protocols on
connection.
@type factory: L{twisted.internet.interfaces.IProtocolFactory}
@rtype: L{twisted.internet.interfaces.IListeningPort}
@return: the port corresponding to a description of a reliable virtual
circuit server.
@see: L{twisted.internet.endpoints.serverFromString}
"""
from twisted.internet import reactor
name, args, kw = endpoints._parseServer(description, factory)
return cast(
interfaces.IListeningPort, getattr(reactor, "listen" + name)(*args, **kw)
)
__all__ = ["service", "listen"]

View File

@@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application}.
"""

File diff suppressed because it is too large

View File

@@ -0,0 +1,175 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.service}.
"""
from zope.interface import implementer
from zope.interface.exceptions import BrokenImplementation
from zope.interface.verify import verifyObject
from twisted.application.service import (
Application,
IProcess,
IService,
IServiceCollection,
Service,
)
from twisted.persisted.sob import IPersistable
from twisted.trial.unittest import TestCase
@implementer(IService)
class AlmostService:
"""
Implement IService in a way that can fail.
In general, classes should maintain invariants that adhere
to the interfaces that they claim to implement --
otherwise, it is a bug.
This is a buggy class -- the IService implementation is fragile,
and several methods will break it. These bugs are intentional,
as the tests trigger them -- and then check that the class,
indeed, no longer complies with the interface (IService)
that it claims to comply with.
Since the verification will, by definition, only fail on buggy classes --
in other words, those which do not actually support the interface they
claim to support -- we have to write a buggy class to properly verify
the interface.
"""
def __init__(self, name: str, parent: IServiceCollection, running: bool) -> None:
self.name = name
self.parent = parent
self.running = running
def makeInvalidByDeletingName(self) -> None:
"""
Probably not a wise method to call.
This method removes the :code:`name` attribute,
which has to exist in IService classes.
"""
del self.name
def makeInvalidByDeletingParent(self) -> None:
"""
Probably not a wise method to call.
This method removes the :code:`parent` attribute,
which has to exist in IService classes.
"""
del self.parent
def makeInvalidByDeletingRunning(self) -> None:
"""
Probably not a wise method to call.
This method removes the :code:`running` attribute,
which has to exist in IService classes.
"""
del self.running
def setName(self, name: object) -> None:
"""
See L{twisted.application.service.IService}.
@param name: ignored
"""
def setServiceParent(self, parent: object) -> None:
"""
See L{twisted.application.service.IService}.
@param parent: ignored
"""
def disownServiceParent(self) -> None:
"""
See L{twisted.application.service.IService}.
"""
def privilegedStartService(self) -> None:
"""
See L{twisted.application.service.IService}.
"""
def startService(self) -> None:
"""
See L{twisted.application.service.IService}.
"""
def stopService(self) -> None:
"""
See L{twisted.application.service.IService}.
"""
class ServiceInterfaceTests(TestCase):
"""
Tests for L{twisted.application.service.IService} implementation.
"""
def setUp(self) -> None:
"""
Build something that implements IService.
"""
self.almostService = AlmostService(parent=None, running=False, name=None) # type: ignore[arg-type]
def test_realService(self) -> None:
"""
Service implements IService.
"""
myService = Service()
verifyObject(IService, myService)
def test_hasAll(self) -> None:
"""
AlmostService implements IService.
"""
verifyObject(IService, self.almostService)
def test_noName(self) -> None:
"""
AlmostService with no name does not implement IService.
"""
self.almostService.makeInvalidByDeletingName()
with self.assertRaises(BrokenImplementation):
verifyObject(IService, self.almostService)
def test_noParent(self) -> None:
"""
AlmostService with no parent does not implement IService.
"""
self.almostService.makeInvalidByDeletingParent()
with self.assertRaises(BrokenImplementation):
verifyObject(IService, self.almostService)
def test_noRunning(self) -> None:
"""
AlmostService with no running does not implement IService.
"""
self.almostService.makeInvalidByDeletingRunning()
with self.assertRaises(BrokenImplementation):
verifyObject(IService, self.almostService)
class ApplicationTests(TestCase):
"""
Tests for L{twisted.application.service.Application}.
"""
def test_applicationComponents(self) -> None:
"""
Check L{twisted.application.service.Application} instantiation.
"""
app = Application("app-name")
self.assertTrue(verifyObject(IService, IService(app)))
self.assertTrue(verifyObject(IServiceCollection, IServiceCollection(app)))
self.assertTrue(verifyObject(IProcess, IProcess(app)))
self.assertTrue(verifyObject(IPersistable, IPersistable(app)))

View File

@@ -0,0 +1,7 @@
# -*- test-case-name: twisted.application.twist.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
C{twist} command line tool.
"""

View File

@@ -0,0 +1,207 @@
# -*- test-case-name: twisted.application.twist.test.test_options -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Command line options for C{twist}.
"""
import typing
from sys import stderr, stdout
from textwrap import dedent
from typing import Callable, Iterable, Mapping, Optional, Sequence, Tuple, cast
from twisted.copyright import version
from twisted.internet.interfaces import IReactorCore
from twisted.logger import (
InvalidLogLevelError,
LogLevel,
jsonFileLogObserver,
textFileLogObserver,
)
from twisted.plugin import getPlugins
from twisted.python.usage import Options, UsageError
from ..reactors import NoSuchReactor, getReactorTypes, installReactor
from ..runner._exit import ExitStatus, exit
from ..service import IServiceMaker
openFile = open
def _update_doc(opt: Callable[["TwistOptions", str], None], **kwargs: str) -> None:
"""
Update the docstring of a method that implements an option.
The string is dedented and the given keyword arguments are substituted.
"""
opt.__doc__ = dedent(opt.__doc__ or "").format(**kwargs)
class TwistOptions(Options):
"""
Command line options for C{twist}.
"""
defaultReactorName = "default"
defaultLogLevel = LogLevel.info
def __init__(self) -> None:
Options.__init__(self)
self["reactorName"] = self.defaultReactorName
self["logLevel"] = self.defaultLogLevel
self["logFile"] = stdout
# An empty long description is explicitly set here because otherwise,
# when executing from distributed trial, twisted.python.usage will
# pull the description from `__main__`, which is another entry point.
self.longdesc = ""
def getSynopsis(self) -> str:
return f"{Options.getSynopsis(self)} plugin [plugin_options]"
def opt_version(self) -> "typing.NoReturn":
"""
Print version and exit.
"""
exit(ExitStatus.EX_OK, f"{version}")
def opt_reactor(self, name: str) -> None:
"""
The name of the reactor to use.
(options: {options})
"""
# Actually actually actually install the reactor right at this very
# moment, before any other code (for example, a sub-command plugin)
# runs and accidentally imports and installs the default reactor.
try:
self["reactor"] = self.installReactor(name)
except NoSuchReactor:
raise UsageError(f"Unknown reactor: {name}")
else:
self["reactorName"] = name
_update_doc(
opt_reactor,
options=", ".join(f'"{rt.shortName}"' for rt in getReactorTypes()),
)
def installReactor(self, name: str) -> IReactorCore:
"""
Install the reactor.
"""
if name == self.defaultReactorName:
from twisted.internet import reactor
return cast(IReactorCore, reactor)
else:
return installReactor(name)
def opt_log_level(self, levelName: str) -> None:
"""
Set default log level.
(options: {options}; default: "{default}")
"""
try:
self["logLevel"] = LogLevel.levelWithName(levelName)
except InvalidLogLevelError:
raise UsageError(f"Invalid log level: {levelName}")
_update_doc(
opt_log_level,
options=", ".join(
f'"{constant.name}"' for constant in LogLevel.iterconstants()
),
default=defaultLogLevel.name,
)
def opt_log_file(self, fileName: str) -> None:
"""
Log to file. ("-" for stdout, "+" for stderr; default: "-")
"""
if fileName == "-":
self["logFile"] = stdout
return
if fileName == "+":
self["logFile"] = stderr
return
try:
self["logFile"] = openFile(fileName, "a")
except OSError as e:
exit(
ExitStatus.EX_IOERR,
f"Unable to open log file {fileName!r}: {e}",
)
def opt_log_format(self, format: str) -> None:
"""
Log file format.
(options: "text", "json"; default: "text" if the log file is a tty,
otherwise "json")
"""
format = format.lower()
if format == "text":
self["fileLogObserverFactory"] = textFileLogObserver
elif format == "json":
self["fileLogObserverFactory"] = jsonFileLogObserver
else:
raise UsageError(f"Invalid log format: {format}")
self["logFormat"] = format
_update_doc(opt_log_format)
def selectDefaultLogObserver(self) -> None:
"""
Set C{fileLogObserverFactory} to the default appropriate for the
chosen C{logFile}.
"""
if "fileLogObserverFactory" not in self:
logFile = self["logFile"]
if hasattr(logFile, "isatty") and logFile.isatty():
self["fileLogObserverFactory"] = textFileLogObserver
self["logFormat"] = "text"
else:
self["fileLogObserverFactory"] = jsonFileLogObserver
self["logFormat"] = "json"
def parseOptions(self, options: Optional[Sequence[str]] = None) -> None:
self.selectDefaultLogObserver()
Options.parseOptions(self, options=options)
if "reactor" not in self:
self["reactor"] = self.installReactor(self["reactorName"])
@property
def plugins(self) -> Mapping[str, IServiceMaker]:
if "plugins" not in self:
plugins = {}
for plugin in getPlugins(IServiceMaker):
plugins[plugin.tapname] = plugin
self["plugins"] = plugins
return cast(Mapping[str, IServiceMaker], self["plugins"])
@property
def subCommands(
self,
) -> Iterable[Tuple[str, None, Callable[[IServiceMaker], Options], str]]:
plugins = self.plugins
for name in sorted(plugins):
plugin = plugins[name]
# Don't pass plugin.options along in order to avoid resolving the
# options attribute right away, in case it's a property with a
# non-trivial getter (e.g., one which imports modules).
def options(plugin: IServiceMaker = plugin) -> Options:
return cast(Options, plugin.options())
yield (plugin.tapname, None, options, plugin.description)
def postOptions(self) -> None:
Options.postOptions(self)
if self.subCommand is None:
raise UsageError("No plugin specified.")

View File

@@ -0,0 +1,114 @@
# -*- test-case-name: twisted.application.twist.test.test_twist -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Run a Twisted application.
"""
import sys
from typing import Sequence
from twisted.application.app import _exitWithSignal
from twisted.internet.interfaces import IReactorCore, _ISupportsExitSignalCapturing
from twisted.python.usage import Options, UsageError
from ..runner._exit import ExitStatus, exit
from ..runner._runner import Runner
from ..service import Application, IService, IServiceMaker
from ._options import TwistOptions
class Twist:
"""
Run a Twisted application.
"""
@staticmethod
def options(argv: Sequence[str]) -> TwistOptions:
"""
Parse command line options.
@param argv: Command line arguments.
@return: The parsed options.
"""
options = TwistOptions()
try:
options.parseOptions(argv[1:])
except UsageError as e:
exit(ExitStatus.EX_USAGE, f"Error: {e}\n\n{options}")
return options
@staticmethod
def service(plugin: IServiceMaker, options: Options) -> IService:
"""
Create the application service.
@param plugin: The name of the plugin that implements the service
application to run.
@param options: Options to pass to the application.
@return: The created application service.
"""
service = plugin.makeService(options)
application = Application(plugin.tapname)
service.setServiceParent(application)
return IService(application)
@staticmethod
def startService(reactor: IReactorCore, service: IService) -> None:
"""
Start the application service.
@param reactor: The reactor to run the service with.
@param service: The application service to run.
"""
service.startService()
# Ask the reactor to stop the service before shutting down
reactor.addSystemEventTrigger("before", "shutdown", service.stopService)
@staticmethod
def run(twistOptions: TwistOptions) -> None:
"""
Run the application service.
@param twistOptions: Command line options to convert to runner
arguments.
"""
runner = Runner(
reactor=twistOptions["reactor"],
defaultLogLevel=twistOptions["logLevel"],
logFile=twistOptions["logFile"],
fileLogObserverFactory=twistOptions["fileLogObserverFactory"],
)
runner.run()
reactor = twistOptions["reactor"]
if _ISupportsExitSignalCapturing.providedBy(reactor):
if reactor._exitSignal is not None:
_exitWithSignal(reactor._exitSignal)
@classmethod
def main(cls, argv: Sequence[str] = sys.argv) -> None:
"""
Executable entry point for L{Twist}.
Processes options and run a twisted reactor with a service.
@param argv: Command line arguments.
@type argv: L{list}
"""
options = cls.options(argv)
reactor = options["reactor"]
# If subCommand is None, TwistOptions.parseOptions() raises UsageError
# and Twist.options() will exit the runner, so we'll never get here.
subCommand = options.subCommand
assert subCommand is not None
service = cls.service(
plugin=options.plugins[subCommand],
options=options.subOptions,
)
cls.startService(reactor, service)
cls.run(options)
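# A programmatic entry-point sketch, mirroring how the tests drive this class
# (argv[0] is skipped by option parsing):
#
#     Twist.main(["twist", "--reactor=default", "--log-format=json", "web"])
#
# This parses options, installs the reactor, builds the "web" plugin's
# service, starts it, and runs the reactor until shutdown.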

View File

@@ -0,0 +1,7 @@
# -*- test-case-name: twisted.application.twist.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.twist}.
"""

View File

@@ -0,0 +1,355 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.twist._options}.
"""
from sys import stderr, stdout
from typing import Callable, Dict, List, Optional, TextIO, Tuple
import twisted.trial.unittest
from twisted.copyright import version
from twisted.internet import reactor
from twisted.internet.interfaces import IReactorCore
from twisted.internet.testing import MemoryReactor
from twisted.logger import (
FileLogObserver,
LogLevel,
jsonFileLogObserver,
textFileLogObserver,
)
from twisted.python.usage import UsageError
from ...reactors import NoSuchReactor
from ...runner._exit import ExitStatus
from ...runner.test.test_runner import DummyExit
from ...service import ServiceMaker
from ...twist import _options
from .._options import TwistOptions
class OptionsTests(twisted.trial.unittest.TestCase):
"""
Tests for L{TwistOptions}.
"""
def patchExit(self) -> None:
"""
Patch L{_options.exit} so we can capture usage and prevent actual exits.
"""
self.exit = DummyExit()
self.patch(_options, "exit", self.exit)
def patchOpen(self) -> None:
"""
Patch L{_options.openFile} so we can capture usage and prevent actual opens.
"""
self.opened: List[Tuple[str, Optional[str]]] = []
def fakeOpen(name: str, mode: Optional[str] = None) -> TextIO:
if name == "nocanopen":
raise OSError(None, None, name)
self.opened.append((name, mode))
return NotImplemented
self.patch(_options, "openFile", fakeOpen)
def patchInstallReactor(self) -> None:
"""
Patch C{_options.installReactor} so we can capture usage and prevent
actual installs.
"""
self.installedReactors: Dict[str, IReactorCore] = {}
def installReactor(name: str) -> IReactorCore:
if name != "fusion":
raise NoSuchReactor()
reactor = MemoryReactor()
self.installedReactors[name] = reactor
return reactor
self.patch(_options, "installReactor", installReactor)
def test_synopsis(self) -> None:
"""
L{TwistOptions.getSynopsis} appends arguments.
"""
options = TwistOptions()
self.assertTrue(options.getSynopsis().endswith(" plugin [plugin_options]"))
def test_version(self) -> None:
"""
L{TwistOptions.opt_version} exits with L{ExitStatus.EX_OK} and prints
the version.
"""
self.patchExit()
options = TwistOptions()
options.opt_version()
self.assertEquals(self.exit.status, ExitStatus.EX_OK) # type: ignore[unreachable]
self.assertEquals(self.exit.message, version)
def test_reactor(self) -> None:
"""
L{TwistOptions.installReactor} installs the chosen reactor and sets
the reactor name.
"""
self.patchInstallReactor()
options = TwistOptions()
options.opt_reactor("fusion")
self.assertEqual(set(self.installedReactors), {"fusion"})
self.assertEquals(options["reactorName"], "fusion")
def test_installCorrectReactor(self) -> None:
"""
L{TwistOptions.installReactor} installs the chosen reactor after the
command line options have been parsed.
"""
self.patchInstallReactor()
options = TwistOptions()
options.subCommand = "test-subcommand"
options.parseOptions(["--reactor=fusion"])
self.assertEqual(set(self.installedReactors), {"fusion"})
def test_installReactorBogus(self) -> None:
"""
L{TwistOptions.installReactor} raises UsageError if an unknown reactor
is specified.
"""
self.patchInstallReactor()
options = TwistOptions()
self.assertRaises(UsageError, options.opt_reactor, "coal")
def test_installReactorDefault(self) -> None:
"""
L{TwistOptions.installReactor} returns the currently installed reactor
when the default reactor name is specified.
"""
options = TwistOptions()
self.assertIdentical(reactor, options.installReactor("default"))
def test_logLevelValid(self) -> None:
"""
L{TwistOptions.opt_log_level} sets the corresponding log level.
"""
options = TwistOptions()
options.opt_log_level("warn")
self.assertIdentical(options["logLevel"], LogLevel.warn)
def test_logLevelInvalid(self) -> None:
"""
L{TwistOptions.opt_log_level} with an invalid log level name raises
UsageError.
"""
options = TwistOptions()
self.assertRaises(UsageError, options.opt_log_level, "cheese")
def _testLogFile(self, name: str, expectedStream: TextIO) -> None:
"""
Set log file name and check the selected output stream.
@param name: The name of the file.
@param expectedStream: The expected stream.
"""
options = TwistOptions()
options.opt_log_file(name)
self.assertIdentical(options["logFile"], expectedStream)
def test_logFileStdout(self) -> None:
"""
L{TwistOptions.opt_log_file} given C{"-"} as a file name uses stdout.
"""
self._testLogFile("-", stdout)
def test_logFileStderr(self) -> None:
"""
L{TwistOptions.opt_log_file} given C{"+"} as a file name uses stderr.
"""
self._testLogFile("+", stderr)
def test_logFileNamed(self) -> None:
"""
L{TwistOptions.opt_log_file} opens the given file name in append mode.
"""
self.patchOpen()
options = TwistOptions()
options.opt_log_file("mylog")
self.assertEqual([("mylog", "a")], self.opened)
def test_logFileCantOpen(self) -> None:
"""
L{TwistOptions.opt_log_file} exits with L{ExitStatus.EX_IOERR} if
unable to open the log file due to an L{OSError}.
"""
self.patchExit()
self.patchOpen()
options = TwistOptions()
options.opt_log_file("nocanopen")
self.assertEquals(self.exit.status, ExitStatus.EX_IOERR)
self.assertIsNotNone(self.exit.message)
self.assertTrue(
self.exit.message.startswith( # type: ignore[union-attr]
"Unable to open log file 'nocanopen': "
)
)
def _testLogFormat(
self, format: str, expectedObserverFactory: Callable[[TextIO], FileLogObserver]
) -> None:
"""
Set log file format and check the selected observer factory.
@param format: The format of the file.
@param expectedObserverFactory: The expected observer factory.
"""
options = TwistOptions()
options.opt_log_format(format)
self.assertIdentical(options["fileLogObserverFactory"], expectedObserverFactory)
self.assertEqual(options["logFormat"], format)
def test_logFormatText(self) -> None:
"""
L{TwistOptions.opt_log_format} given C{"text"} uses a
L{textFileLogObserver}.
"""
self._testLogFormat("text", textFileLogObserver)
def test_logFormatJSON(self) -> None:
"""
L{TwistOptions.opt_log_format} given C{"text"} uses a
L{textFileLogObserver}.
"""
self._testLogFormat("json", jsonFileLogObserver)
def test_logFormatInvalid(self) -> None:
"""
L{TwistOptions.opt_log_format} given an invalid format name raises
L{UsageError}.
"""
options = TwistOptions()
self.assertRaises(UsageError, options.opt_log_format, "frommage")
def test_selectDefaultLogObserverNoOverride(self) -> None:
"""
L{TwistOptions.selectDefaultLogObserver} will not override an already
selected observer.
"""
self.patchOpen()
options = TwistOptions()
options.opt_log_format("text") # Ask for text
options.opt_log_file("queso") # File, not a tty
options.selectDefaultLogObserver()
# Because we didn't select a file that is a tty, the default is JSON,
# but since we asked for text, we should get text.
self.assertIdentical(options["fileLogObserverFactory"], textFileLogObserver)
self.assertEqual(options["logFormat"], "text")
def test_selectDefaultLogObserverDefaultWithTTY(self) -> None:
"""
L{TwistOptions.selectDefaultLogObserver} selects a L{textFileLogObserver}
by default when the log file is a tty.
"""
class TTYFile:
def isatty(self) -> bool:
return True
# stdout may not be a tty, so let's make sure it thinks it is
self.patch(_options, "stdout", TTYFile())
options = TwistOptions()
options.opt_log_file("-") # stdout, a tty
options.selectDefaultLogObserver()
self.assertIdentical(options["fileLogObserverFactory"], textFileLogObserver)
self.assertEqual(options["logFormat"], "text")
def test_selectDefaultLogObserverDefaultWithoutTTY(self) -> None:
"""
L{TwistOptions.selectDefaultLogObserver} selects a L{jsonFileLogObserver}
by default when the log file is not a tty.
"""
self.patchOpen()
options = TwistOptions()
options.opt_log_file("queso") # File, not a tty
options.selectDefaultLogObserver()
self.assertIdentical(options["fileLogObserverFactory"], jsonFileLogObserver)
self.assertEqual(options["logFormat"], "json")
def test_pluginsType(self) -> None:
"""
L{TwistOptions.plugins} is a mapping of available plug-ins.
"""
options = TwistOptions()
plugins = options.plugins
for name in plugins:
self.assertIsInstance(name, str)
self.assertIsInstance(plugins[name], ServiceMaker)
def test_pluginsIncludeWeb(self) -> None:
"""
L{TwistOptions.plugins} includes a C{"web"} plug-in.
This is an attempt to verify that something we expect to be in the list
is in there without enumerating all of the built-in plug-ins.
"""
options = TwistOptions()
self.assertIn("web", options.plugins)
def test_subCommandsType(self) -> None:
"""
L{TwistOptions.subCommands} is an iterable of tuples as expected by
L{twisted.python.usage.Options}.
"""
options = TwistOptions()
for name, shortcut, parser, doc in options.subCommands:
self.assertIsInstance(name, str)
self.assertIdentical(shortcut, None)
self.assertTrue(callable(parser))
self.assertIsInstance(doc, str)
def test_subCommandsIncludeWeb(self) -> None:
"""
L{TwistOptions.subCommands} includes a sub-command for every plug-in.
"""
options = TwistOptions()
plugins = set(options.plugins)
subCommands = {name for name, shortcut, parser, doc in options.subCommands}
self.assertEqual(subCommands, plugins)
def test_postOptionsNoSubCommand(self) -> None:
"""
L{TwistOptions.postOptions} raises L{UsageError} if it has no
sub-command.
"""
self.patchInstallReactor()
options = TwistOptions()
self.assertRaises(UsageError, options.postOptions)

View File

@@ -0,0 +1,256 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.twist._twist}.
"""
from sys import stdout
from typing import Any, Dict, List
import twisted.trial.unittest
from twisted.internet.interfaces import IReactorCore
from twisted.internet.testing import MemoryReactor
from twisted.logger import LogLevel, jsonFileLogObserver
from twisted.test.test_twistd import SignalCapturingMemoryReactor
from ...runner._exit import ExitStatus
from ...runner._runner import Runner
from ...runner.test.test_runner import DummyExit
from ...service import IService, MultiService
from ...twist import _twist
from .._options import TwistOptions
from .._twist import Twist
class TwistTests(twisted.trial.unittest.TestCase):
"""
Tests for L{Twist}.
"""
def setUp(self) -> None:
self.patchInstallReactor()
def patchExit(self) -> None:
"""
Patch L{_twist.exit} so we can capture usage and prevent actual exits.
"""
self.exit = DummyExit()
self.patch(_twist, "exit", self.exit)
def patchInstallReactor(self) -> None:
"""
Patch L{TwistOptions.installReactor} so we can capture usage and prevent
actual installs.
"""
self.installedReactors: Dict[str, IReactorCore] = {}
def installReactor(_: TwistOptions, name: str) -> IReactorCore:
reactor = MemoryReactor()
self.installedReactors[name] = reactor
return reactor
self.patch(TwistOptions, "installReactor", installReactor)
def patchStartService(self) -> None:
"""
Patch L{MultiService.startService} so we can capture usage and prevent
actual starts.
"""
self.serviceStarts: List[IService] = []
def startService(service: IService) -> None:
self.serviceStarts.append(service)
self.patch(MultiService, "startService", startService)
def test_optionsValidArguments(self) -> None:
"""
L{Twist.options} given valid arguments returns options.
"""
options = Twist.options(["twist", "web"])
self.assertIsInstance(options, TwistOptions)
def test_optionsInvalidArguments(self) -> None:
"""
L{Twist.options} given invalid arguments exits with
L{ExitStatus.EX_USAGE} and an error/usage message.
"""
self.patchExit()
Twist.options(["twist", "--bogus-bagels"])
self.assertIdentical(self.exit.status, ExitStatus.EX_USAGE)
self.assertIsNotNone(self.exit.message)
self.assertTrue(
self.exit.message.startswith("Error: ") # type: ignore[union-attr]
)
self.assertTrue(
self.exit.message.endswith( # type: ignore[union-attr]
f"\n\n{TwistOptions()}"
)
)
def test_service(self) -> None:
"""
L{Twist.service} returns an L{IService}.
"""
options = Twist.options(["twist", "web"]) # web should exist
service = Twist.service(options.plugins["web"], options.subOptions)
self.assertTrue(IService.providedBy(service))
def test_startService(self) -> None:
"""
L{Twist.startService} starts the service and registers a trigger to
stop the service when the reactor shuts down.
"""
options = Twist.options(["twist", "web"])
reactor = options["reactor"]
subCommand = options.subCommand
assert subCommand is not None
service = Twist.service(
plugin=options.plugins[subCommand],
options=options.subOptions,
)
self.patchStartService()
Twist.startService(reactor, service)
self.assertEqual(self.serviceStarts, [service])
self.assertEqual(
reactor.triggers["before"]["shutdown"], [(service.stopService, (), {})]
)
def test_run(self) -> None:
"""
L{Twist.run} runs the runner with arguments corresponding to the given
options.
"""
argsSeen = []
self.patch(Runner, "__init__", lambda self, **args: argsSeen.append(args))
self.patch(Runner, "run", lambda self: None)
twistOptions = Twist.options(
["twist", "--reactor=default", "--log-format=json", "web"]
)
Twist.run(twistOptions)
self.assertEqual(len(argsSeen), 1)
self.assertEqual(
argsSeen[0],
dict(
reactor=self.installedReactors["default"],
defaultLogLevel=LogLevel.info,
logFile=stdout,
fileLogObserverFactory=jsonFileLogObserver,
),
)
def test_main(self) -> None:
"""
L{Twist.main} runs the runner with arguments corresponding to the given
command line arguments.
"""
self.patchStartService()
runners = []
class Runner:
def __init__(self, **kwargs: Any) -> None:
self.args = kwargs
self.runs = 0
runners.append(self)
def run(self) -> None:
self.runs += 1
self.patch(_twist, "Runner", Runner)
Twist.main(["twist", "--reactor=default", "--log-format=json", "web"])
self.assertEqual(len(self.serviceStarts), 1)
self.assertEqual(len(runners), 1)
self.assertEqual(
runners[0].args,
dict(
reactor=self.installedReactors["default"],
defaultLogLevel=LogLevel.info,
logFile=stdout,
fileLogObserverFactory=jsonFileLogObserver,
),
)
self.assertEqual(runners[0].runs, 1)
class TwistExitTests(twisted.trial.unittest.TestCase):
"""
Tests to verify that the Twist script takes the expected actions related
to signals and the reactor.
"""
def setUp(self) -> None:
self.exitWithSignalCalled = False
def fakeExitWithSignal(sig: int) -> None:
"""
Fake to capture whether L{twisted.application.app._exitWithSignal}
was called.
@param sig: Signal value
@type sig: C{int}
"""
self.exitWithSignalCalled = True
self.patch(_twist, "_exitWithSignal", fakeExitWithSignal)
def startLogging(_: Runner) -> None:
"""
Prevent Runner from adding new log observers; otherwise, other
tests outside this module will fail.
@param _: Unused self param
"""
self.patch(Runner, "startLogging", startLogging)
def test_twistReactorDoesntExitWithSignal(self) -> None:
"""
_exitWithSignal is not called if the reactor's _exitSignal attribute
is L{None}.
"""
reactor = SignalCapturingMemoryReactor()
reactor._exitSignal = None
options = TwistOptions()
options["reactor"] = reactor
options["fileLogObserverFactory"] = jsonFileLogObserver
Twist.run(options)
self.assertFalse(self.exitWithSignalCalled)
def test_twistReactorHasNoExitSignalAttr(self) -> None:
"""
_exitWithSignal is not called if the runner's reactor does not
implement L{twisted.internet.interfaces._ISupportsExitSignalCapturing}
"""
reactor = MemoryReactor()
options = TwistOptions()
options["reactor"] = reactor
options["fileLogObserverFactory"] = jsonFileLogObserver
Twist.run(options)
self.assertFalse(self.exitWithSignalCalled)
def test_twistReactorExitsWithSignal(self) -> None:
"""
_exitWithSignal is called if the runner's reactor exits due
to a signal.
"""
reactor = SignalCapturingMemoryReactor()
reactor._exitSignal = 2
options = TwistOptions()
options["reactor"] = reactor
options["fileLogObserverFactory"] = jsonFileLogObserver
Twist.run(options)
self.assertTrue(self.exitWithSignalCalled)

View File

@@ -0,0 +1,7 @@
# -*- test-case-name: twisted.conch.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Conch: The Twisted Shell. Terminal emulation, SSHv2 and telnet.
"""

View File

@@ -0,0 +1,56 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
from zope.interface import implementer
from twisted.conch.error import ConchError
from twisted.conch.interfaces import IConchUser
from twisted.conch.ssh.connection import OPEN_UNKNOWN_CHANNEL_TYPE
from twisted.logger import Logger
from twisted.python.compat import nativeString
@implementer(IConchUser)
class ConchUser:
_log = Logger()
def __init__(self):
self.channelLookup = {}
self.subsystemLookup = {}
@property
def conn(self):
return self._conn
@conn.setter
def conn(self, value):
self._conn = value
def lookupChannel(self, channelType, windowSize, maxPacket, data):
klass = self.channelLookup.get(channelType, None)
if not klass:
raise ConchError(OPEN_UNKNOWN_CHANNEL_TYPE, "unknown channel")
else:
return klass(
remoteWindow=windowSize,
remoteMaxPacket=maxPacket,
data=data,
avatar=self,
)
def lookupSubsystem(self, subsystem, data):
self._log.debug(
"Subsystem lookup: {subsystem!r}", subsystem=self.subsystemLookup
)
klass = self.subsystemLookup.get(subsystem, None)
if not klass:
return False
return klass(data, avatar=self)
def gotGlobalRequest(self, requestType, data):
# XXX should this use method dispatch?
requestType = nativeString(requestType.replace(b"-", b"_"))
f = getattr(self, "global_%s" % requestType, None)
if not f:
return 0
return f(data)

View File

@@ -0,0 +1,640 @@
# -*- test-case-name: twisted.conch.test.test_checkers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Provide L{ICredentialsChecker} implementations to be used in Conch protocols.
"""
import binascii
import errno
import sys
from base64 import decodebytes
from typing import IO, Any, Callable, Iterable, Iterator, Mapping, Optional, Tuple, cast
from zope.interface import Interface, implementer, providedBy
from incremental import Version
from typing_extensions import Literal, Protocol
from twisted.conch import error
from twisted.conch.ssh import keys
from twisted.cred.checkers import ICredentialsChecker
from twisted.cred.credentials import ISSHPrivateKey, IUsernamePassword
from twisted.cred.error import UnauthorizedLogin, UnhandledCredentials
from twisted.internet import defer
from twisted.logger import Logger
from twisted.plugins.cred_unix import verifyCryptedPassword
from twisted.python import failure, reflect
from twisted.python.deprecate import deprecatedModuleAttribute
from twisted.python.filepath import FilePath
from twisted.python.util import runAsEffectiveUser
_log = Logger()
class UserRecord(Tuple[str, str, int, int, str, str, str]):
"""
A record in a UNIX-style password database. See L{pwd} for field details.
This corresponds to the undocumented type L{pwd.struct_passwd}, but lacks named
field accessors.
"""
@property
def pw_dir(self) -> str: # type: ignore[empty-body]
...
class UserDB(Protocol):
"""
A database of users by name, like the stdlib L{pwd} module.
See L{twisted.python.fakepwd} for an in-memory implementation.
"""
def getpwnam(self, username: str) -> UserRecord:
"""
Lookup a user record by name.
@raises KeyError: when no such user exists
"""
pwd: Optional[UserDB]
try:
import pwd as _pwd
except ImportError:
pwd = None
else:
pwd = cast(UserDB, _pwd)
try:
import spwd as _spwd
except ImportError:
spwd = None
else:
spwd = _spwd
class CryptedPasswordRecord(Protocol):
"""
A sequence where the item at index 1 may be a crypted password.
Both L{pwd.struct_passwd} and L{spwd.struct_spwd} conform to this protocol.
"""
def __getitem__(self, index: Literal[1]) -> str:
"""
Get the crypted password.
"""
def _lookupUser(userdb: UserDB, username: bytes) -> UserRecord:
"""
Lookup a user by name in a L{pwd}-style database.
@param userdb: The user database.
@param username: Identifying name in bytes. This will be decoded according
to the filesystem encoding, as the L{pwd} module does internally.
@raises KeyError: when the user doesn't exist
"""
return userdb.getpwnam(username.decode(sys.getfilesystemencoding()))
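# A test-oriented sketch: any object with a pwd-style getpwnam() satisfies
# UserDB, for example the in-memory twisted.python.fakepwd.UserDatabase
# mentioned above (the user details here are illustrative):
#
#     from twisted.python.fakepwd import UserDatabase
#
#     userdb = UserDatabase()
#     userdb.addUser("alice", "x", 1000, 1000, "Alice", "/home/alice", "/bin/sh")
#     record = _lookupUser(userdb, b"alice")
#     record.pw_dir  # "/home/alice"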
def _pwdGetByName(username: str) -> Optional[CryptedPasswordRecord]:
"""
Look up a user in the /etc/passwd database using the pwd module. If the
pwd module is not available, return None.
@param username: the username of the user to return the passwd database
information for.
@returns: A L{pwd.struct_passwd}, where field 1 may contain a crypted
password, or L{None} when the L{pwd} database is unavailable.
@raises KeyError: when no such user exists
"""
if pwd is None:
return None
return cast(CryptedPasswordRecord, pwd.getpwnam(username))
def _shadowGetByName(username: str) -> Optional[CryptedPasswordRecord]:
"""
Look up a user in the /etc/shadow database using the spwd module. If it is
not available, return L{None}.
@param username: the username of the user to return the shadow database
information for.
@type username: L{str}
@returns: A L{spwd.struct_spwd}, where field 1 may contain a crypted
password, or L{None} when the L{spwd} database is unavailable.
@raises KeyError: when no such user exists
"""
if spwd is not None:
f = spwd.getspnam
else:
return None
return cast(CryptedPasswordRecord, runAsEffectiveUser(0, 0, f, username))
@implementer(ICredentialsChecker)
class UNIXPasswordDatabase:
"""
A checker which validates users out of the UNIX password databases, or
databases of a compatible format.
@ivar _getByNameFunctions: a C{list} of functions which are called in order
to validate a user. The default value is such that the C{/etc/passwd}
database will be tried first, followed by the C{/etc/shadow} database.
"""
credentialInterfaces = (IUsernamePassword,)
def __init__(self, getByNameFunctions=None):
if getByNameFunctions is None:
getByNameFunctions = [_pwdGetByName, _shadowGetByName]
self._getByNameFunctions = getByNameFunctions
def requestAvatarId(self, credentials):
# We get bytes, but the Py3 pwd module uses str. So attempt to decode
# it using the same method that CPython does for the file on disk.
username = credentials.username.decode(sys.getfilesystemencoding())
password = credentials.password.decode(sys.getfilesystemencoding())
for func in self._getByNameFunctions:
try:
pwnam = func(username)
except KeyError:
return defer.fail(UnauthorizedLogin("invalid username"))
else:
if pwnam is not None:
crypted = pwnam[1]
if crypted == "":
continue
if verifyCryptedPassword(crypted, password):
return defer.succeed(credentials.username)
# fallback
return defer.fail(UnauthorizedLogin("unable to verify password"))
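# Illustrative usage sketch (not part of the original module): wiring
# UNIXPasswordDatabase to a custom lookup function instead of the system
# databases. Only index 1 of the returned record (the crypted password)
# matters to the checker; the other fields below are hypothetical.
def _exampleUNIXPasswordDatabase():  # pragma: no cover - illustrative only
    # The standard-library crypt module is POSIX-only (and deprecated in
    # recent Pythons); verifyCryptedPassword earlier in this module uses it.
    import crypt

    from twisted.cred.credentials import UsernamePassword

    crypted = crypt.crypt("sesame", "ab")

    def staticGetByName(username):
        # A pwd-style record for a single hypothetical user.
        if username == "alice":
            return ("alice", crypted, 1000, 1000, "", "/home/alice", "/bin/sh")
        raise KeyError(username)

    checker = UNIXPasswordDatabase(getByNameFunctions=[staticGetByName])
    # Fires with the avatar id (the username bytes) on success, or fails
    # with UnauthorizedLogin otherwise.
    return checker.requestAvatarId(UsernamePassword(b"alice", b"sesame"))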
@implementer(ICredentialsChecker)
class SSHPublicKeyDatabase:
"""
Checker that authenticates SSH public keys, based on public keys listed in
authorized_keys and authorized_keys2 files in user .ssh/ directories.
"""
credentialInterfaces = (ISSHPrivateKey,)
_userdb: UserDB = cast(UserDB, pwd)
def requestAvatarId(self, credentials):
d = defer.maybeDeferred(self.checkKey, credentials)
d.addCallback(self._cbRequestAvatarId, credentials)
d.addErrback(self._ebRequestAvatarId)
return d
def _cbRequestAvatarId(self, validKey, credentials):
"""
Check whether the credentials themselves are valid, now that we know
if the key matches the user.
@param validKey: A boolean indicating whether or not the public key
matches a key in the user's authorized_keys file.
@param credentials: The credentials offered by the user.
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: (as a failure) if the key does not match the
user in C{credentials}. Also raised if the user provides an invalid
signature.
@raise ValidPublicKey: (as a failure) if the key matches the user but
the credentials do not include a signature. See
L{error.ValidPublicKey} for more information.
@return: The user's username, if authentication was successful.
"""
if not validKey:
return failure.Failure(UnauthorizedLogin("invalid key"))
if not credentials.signature:
return failure.Failure(error.ValidPublicKey())
else:
try:
pubKey = keys.Key.fromString(credentials.blob)
if pubKey.verify(credentials.signature, credentials.sigData):
return credentials.username
except Exception: # any error should be treated as a failed login
_log.failure("Error while verifying key")
return failure.Failure(UnauthorizedLogin("error while verifying key"))
return failure.Failure(UnauthorizedLogin("unable to verify key"))
def getAuthorizedKeysFiles(self, credentials):
"""
Return a list of L{FilePath} instances for I{authorized_keys} files
which might contain information about authorized keys for the given
credentials.
On OpenSSH servers, the default location of the file containing the
list of authorized public keys is
U{$HOME/.ssh/authorized_keys<http://www.openbsd.org/cgi-bin/man.cgi?query=sshd_config>}.
I{$HOME/.ssh/authorized_keys2} is also returned, though it has been
U{deprecated by OpenSSH since
2001<http://marc.info/?m=100508718416162>}.
@return: A list of L{FilePath} instances to files with the authorized keys.
"""
pwent = _lookupUser(self._userdb, credentials.username)
root = FilePath(pwent.pw_dir).child(".ssh")
files = ["authorized_keys", "authorized_keys2"]
return [root.child(f) for f in files]
def checkKey(self, credentials):
"""
Retrieve files containing authorized keys and check against user
credentials.
"""
ouid, ogid = _lookupUser(self._userdb, credentials.username)[2:4]
for filepath in self.getAuthorizedKeysFiles(credentials):
if not filepath.exists():
continue
try:
lines = filepath.open()
except OSError as e:
if e.errno == errno.EACCES:
lines = runAsEffectiveUser(ouid, ogid, filepath.open)
else:
raise
with lines:
for l in lines:
l2 = l.split()
if len(l2) < 2:
continue
try:
if decodebytes(l2[1]) == credentials.blob:
return True
except binascii.Error:
continue
return False
def _ebRequestAvatarId(self, f):
if not f.check(UnauthorizedLogin):
_log.error(
"Unauthorized login due to internal error: {error}", error=f.value
)
return failure.Failure(UnauthorizedLogin("unable to get avatar id"))
return f
@implementer(ICredentialsChecker)
class SSHProtocolChecker:
"""
SSHProtocolChecker is a checker that requires multiple authentications
to succeed. To add a checker, call my registerChecker method with
the checker and the interface.
    After each successful authentication, I call my areDone method with the
    avatar id. To get a list of the successful credentials for an avatar id,
    use C{SSHProtocolChecker.successfulCredentials[avatarId]}. If L{areDone}
returns True, the authentication has succeeded.
"""
def __init__(self):
self.checkers = {}
self.successfulCredentials = {}
@property
def credentialInterfaces(self):
return list(self.checkers.keys())
def registerChecker(self, checker, *credentialInterfaces):
if not credentialInterfaces:
credentialInterfaces = checker.credentialInterfaces
for credentialInterface in credentialInterfaces:
self.checkers[credentialInterface] = checker
def requestAvatarId(self, credentials):
"""
Part of the L{ICredentialsChecker} interface. Called by a portal with
some credentials to check if they'll authenticate a user. We check the
interfaces that the credentials provide against our list of acceptable
checkers. If one of them matches, we ask that checker to verify the
credentials. If they're valid, we call our L{_cbGoodAuthentication}
method to continue.
@param credentials: the credentials the L{Portal} wants us to verify
"""
ifac = providedBy(credentials)
for i in ifac:
c = self.checkers.get(i)
if c is not None:
d = defer.maybeDeferred(c.requestAvatarId, credentials)
return d.addCallback(self._cbGoodAuthentication, credentials)
return defer.fail(
UnhandledCredentials(
"No checker for %s" % ", ".join(map(reflect.qual, ifac))
)
)
def _cbGoodAuthentication(self, avatarId, credentials):
"""
Called if a checker has verified the credentials. We call our
L{areDone} method to see if the whole of the successful authentications
are enough. If they are, we return the avatar ID returned by the first
checker.
"""
if avatarId not in self.successfulCredentials:
self.successfulCredentials[avatarId] = []
self.successfulCredentials[avatarId].append(credentials)
if self.areDone(avatarId):
del self.successfulCredentials[avatarId]
return avatarId
else:
raise error.NotEnoughAuthentication()
def areDone(self, avatarId):
"""
Override to determine if the authentication is finished for a given
avatarId.
@param avatarId: the avatar returned by the first checker. For
this checker to function correctly, all the checkers must
return the same avatar ID.
"""
return True
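# Illustrative sketch (not part of the original module): a hypothetical
# subclass that requires two different kinds of credentials to be verified
# for the same avatar id before authentication is considered complete.
class _ExampleTwoFactorChecker(SSHProtocolChecker):  # pragma: no cover
    """
    Require at least two successful credential checks per avatar id.
    """

    def areDone(self, avatarId):
        return len(self.successfulCredentials[avatarId]) >= 2


# A caller would register one checker per credential interface, e.g.:
#     checker = _ExampleTwoFactorChecker()
#     checker.registerChecker(UNIXPasswordDatabase(), IUsernamePassword)
#     checker.registerChecker(SSHPublicKeyChecker(keydb), ISSHPrivateKey)
# and hand ``checker`` to the Portal in place of the individual checkers.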
deprecatedModuleAttribute(
Version("Twisted", 15, 0, 0),
(
"Please use twisted.conch.checkers.SSHPublicKeyChecker, "
"initialized with an instance of "
"twisted.conch.checkers.UNIXAuthorizedKeysFiles instead."
),
__name__,
"SSHPublicKeyDatabase",
)
class IAuthorizedKeysDB(Interface):
"""
An object that provides valid authorized ssh keys mapped to usernames.
@since: 15.0
"""
def getAuthorizedKeys(avatarId):
"""
Gets an iterable of authorized keys that are valid for the given
C{avatarId}.
@param avatarId: the ID of the avatar
@type avatarId: valid return value of
L{twisted.cred.checkers.ICredentialsChecker.requestAvatarId}
@return: an iterable of L{twisted.conch.ssh.keys.Key}
"""
def readAuthorizedKeyFile(
fileobj: IO[bytes], parseKey: Callable[[bytes], keys.Key] = keys.Key.fromString
) -> Iterator[keys.Key]:
"""
Reads keys from an authorized keys file. Any non-comment line that cannot
be parsed as a key will be ignored, although that particular line will
be logged.
@param fileobj: something from which to read lines which can be parsed
as keys
@param parseKey: a callable that takes bytes and returns a
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
default is L{twisted.conch.ssh.keys.Key.fromString}.
@return: an iterable of L{twisted.conch.ssh.keys.Key}
@since: 15.0
"""
for line in fileobj:
line = line.strip()
if line and not line.startswith(b"#"): # for comments
try:
yield parseKey(line)
except keys.BadKeyError as e:
_log.error(
"Unable to parse line {line!r} as a key: {error!s}",
line=line,
error=e,
)
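# Illustrative sketch (not part of the original module): parsing an
# authorized_keys-style blob held in memory. The key text below is a
# placeholder, so it fails to parse and is logged and skipped, which
# demonstrates the error handling described above.
def _exampleReadAuthorizedKeyFile():  # pragma: no cover - illustrative only
    from io import BytesIO

    blob = BytesIO(
        b"# comment lines and unparseable lines are skipped\n"
        b"ssh-ed25519 AAAA...placeholder... alice@example.com\n"
    )
    # readAuthorizedKeyFile is a generator; materialise it to trigger
    # parsing. With real key lines it yields one Key per line.
    return list(readAuthorizedKeyFile(blob))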
def _keysFromFilepaths(
filepaths: Iterable[FilePath[Any]], parseKey: Callable[[bytes], keys.Key]
) -> Iterable[keys.Key]:
"""
Helper function that turns an iterable of filepaths into a generator of
keys. If any file cannot be read, a message is logged but it is
otherwise ignored.
@param filepaths: iterable of L{twisted.python.filepath.FilePath}.
@type filepaths: iterable
@param parseKey: a callable that takes a string and returns a
L{twisted.conch.ssh.keys.Key}
@type parseKey: L{callable}
@return: generator of L{twisted.conch.ssh.keys.Key}
@since: 15.0
"""
for fp in filepaths:
if fp.exists():
try:
with fp.open() as f:
yield from readAuthorizedKeyFile(f, parseKey)
except OSError as e:
_log.error("Unable to read {path!r}: {error!s}", path=fp.path, error=e)
@implementer(IAuthorizedKeysDB)
class InMemorySSHKeyDB:
"""
Object that provides SSH public keys based on a dictionary of usernames
mapped to L{twisted.conch.ssh.keys.Key}s.
@since: 15.0
"""
def __init__(self, mapping: Mapping[bytes, Iterable[keys.Key]]) -> None:
"""
Initializes a new L{InMemorySSHKeyDB}.
@param mapping: mapping of usernames to iterables of
L{twisted.conch.ssh.keys.Key}s
"""
self._mapping = mapping
def getAuthorizedKeys(self, username: bytes) -> Iterable[keys.Key]:
"""
Look up the authorized keys for a user.
@param username: Name of the user
"""
return self._mapping.get(username, [])
@implementer(IAuthorizedKeysDB)
class UNIXAuthorizedKeysFiles:
"""
Object that provides SSH public keys based on public keys listed in
authorized_keys and authorized_keys2 files in UNIX user .ssh/ directories.
If any of the files cannot be read, a message is logged but that file is
otherwise ignored.
@since: 15.0
"""
_userdb: UserDB
def __init__(
self,
userdb: Optional[UserDB] = None,
parseKey: Callable[[bytes], keys.Key] = keys.Key.fromString,
):
"""
Initializes a new L{UNIXAuthorizedKeysFiles}.
@param userdb: access to the Unix user account and password database
(default is the Python module L{pwd}, if available)
@param parseKey: a callable that takes a string and returns a
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
default is L{twisted.conch.ssh.keys.Key.fromString}.
"""
if userdb is not None:
self._userdb = userdb
elif pwd is not None:
self._userdb = pwd
else:
raise ValueError("No pwd module found, and no userdb argument passed.")
self._parseKey = parseKey
def getAuthorizedKeys(self, username: bytes) -> Iterable[keys.Key]:
try:
passwd = _lookupUser(self._userdb, username)
except KeyError:
return ()
root = FilePath(passwd.pw_dir).child(".ssh")
files = ["authorized_keys", "authorized_keys2"]
return _keysFromFilepaths((root.child(f) for f in files), self._parseKey)
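# Illustrative sketch (not part of the original module): pointing
# UNIXAuthorizedKeysFiles at an in-memory user database instead of the real
# pwd module. It assumes twisted.python.fakepwd.UserDatabase and its
# addUser(username, password, uid, gid, gecos, home, shell) signature.
def _exampleUNIXAuthorizedKeysFiles():  # pragma: no cover - illustrative only
    from twisted.python.fakepwd import UserDatabase

    userdb = UserDatabase()
    userdb.addUser("alice", "*", 1000, 1000, "Alice", "/home/alice", "/bin/sh")
    keydb = UNIXAuthorizedKeysFiles(userdb=userdb)
    # Reads /home/alice/.ssh/authorized_keys and authorized_keys2, skipping
    # files that do not exist or cannot be read.
    return list(keydb.getAuthorizedKeys(b"alice"))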
@implementer(ICredentialsChecker)
class SSHPublicKeyChecker:
"""
Checker that authenticates SSH public keys, based on public keys listed in
authorized_keys and authorized_keys2 files in user .ssh/ directories.
    This checker, initialized with a L{UNIXAuthorizedKeysFiles}, should be
    used instead of L{twisted.conch.checkers.SSHPublicKeyDatabase}.
@since: 15.0
"""
credentialInterfaces = (ISSHPrivateKey,)
def __init__(self, keydb: IAuthorizedKeysDB) -> None:
"""
Initializes a L{SSHPublicKeyChecker}.
@param keydb: a provider of L{IAuthorizedKeysDB}
"""
self._keydb = keydb
def requestAvatarId(self, credentials):
d = defer.execute(self._sanityCheckKey, credentials)
d.addCallback(self._checkKey, credentials)
d.addCallback(self._verifyKey, credentials)
return d
def _sanityCheckKey(self, credentials):
"""
Checks whether the provided credentials are a valid SSH key with a
signature (does not actually verify the signature).
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise ValidPublicKey: the credentials do not include a signature. See
L{error.ValidPublicKey} for more information.
@raise BadKeyError: The key included with the credentials is not
recognized as a key.
@return: the key in the credentials
@rtype: L{twisted.conch.ssh.keys.Key}
"""
if not credentials.signature:
raise error.ValidPublicKey()
return keys.Key.fromString(credentials.blob)
def _checkKey(self, pubKey, credentials):
"""
Checks the public key against all authorized keys (if any) for the
user.
@param pubKey: the key in the credentials (just to prevent it from
having to be calculated again)
        @type pubKey: L{twisted.conch.ssh.keys.Key}
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: If the key is not authorized, or if there
was any error obtaining a list of authorized keys for the user.
@return: C{pubKey} if the key is authorized
@rtype: L{twisted.conch.ssh.keys.Key}
"""
if any(
key == pubKey for key in self._keydb.getAuthorizedKeys(credentials.username)
):
return pubKey
raise UnauthorizedLogin("Key not authorized")
def _verifyKey(self, pubKey, credentials):
"""
Checks whether the credentials themselves are valid, now that we know
if the key matches the user.
@param pubKey: the key in the credentials (just to prevent it from
having to be calculated again)
@type pubKey: L{twisted.conch.ssh.keys.Key}
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: If the key signature is invalid or there
was any error verifying the signature.
@return: The user's username, if authentication was successful
@rtype: L{bytes}
"""
try:
if pubKey.verify(credentials.signature, credentials.sigData):
return credentials.username
except Exception as e: # Any error should be treated as a failed login
raise UnauthorizedLogin("Error while verifying key") from e
raise UnauthorizedLogin("Key signature invalid.")

View File

@@ -0,0 +1,9 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Client support code for Conch.
Maintainer: Paul Swartz
"""

View File

@@ -0,0 +1,65 @@
# -*- test-case-name: twisted.conch.test.test_default -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Accesses the key agent for user authentication.
Maintainer: Paul Swartz
"""
import os
from twisted.conch.ssh import agent, channel, keys
from twisted.internet import protocol, reactor
from twisted.logger import Logger
class SSHAgentClient(agent.SSHAgentClient):
_log = Logger()
def __init__(self):
agent.SSHAgentClient.__init__(self)
self.blobs = []
def getPublicKeys(self):
return self.requestIdentities().addCallback(self._cbPublicKeys)
def _cbPublicKeys(self, blobcomm):
self._log.debug("got {num_keys} public keys", num_keys=len(blobcomm))
self.blobs = [x[0] for x in blobcomm]
def getPublicKey(self):
"""
Return a L{Key} from the first blob in C{self.blobs}, if any, or
return L{None}.
"""
if self.blobs:
return keys.Key.fromString(self.blobs.pop(0))
return None
class SSHAgentForwardingChannel(channel.SSHChannel):
def channelOpen(self, specificData):
cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal)
d = cc.connectUNIX(os.environ["SSH_AUTH_SOCK"])
d.addCallback(self._cbGotLocal)
d.addErrback(lambda x: self.loseConnection())
        # Data received before the local agent connection exists is buffered
        # as bytes (dataReceived below appends bytes).
        self.buf = b""
def _cbGotLocal(self, local):
self.local = local
self.dataReceived = self.local.transport.write
self.local.dataReceived = self.write
def dataReceived(self, data):
self.buf += data
def closed(self):
if self.local:
self.local.loseConnection()
self.local = None
class SSHAgentForwardingLocal(protocol.Protocol):
pass
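# Illustrative sketch (not part of the original module): connecting this
# SSHAgentClient to a running ssh-agent via $SSH_AUTH_SOCK and listing its
# key blobs, mirroring what twisted.conch.client.default does.
def _exampleListAgentKeys():  # pragma: no cover - illustrative only
    cc = protocol.ClientCreator(reactor, SSHAgentClient)
    d = cc.connectUNIX(os.environ["SSH_AUTH_SOCK"])

    def connected(agentClient):
        # getPublicKeys() fills agentClient.blobs; getPublicKey() then pops
        # one Key at a time from that list.
        return agentClient.getPublicKeys().addCallback(
            lambda _: agentClient.blobs
        )

    return d.addCallback(connected)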

View File

@@ -0,0 +1,24 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
from twisted.conch.client import direct
connectTypes = {"direct": direct.connect}
def connect(host, port, options, verifyHostKey, userAuthObject):
useConnects = ["direct"]
return _ebConnect(
None, useConnects, host, port, options, verifyHostKey, userAuthObject
)
def _ebConnect(f, useConnects, host, port, options, vhk, uao):
if not useConnects:
return f
connectType = useConnects.pop(0)
f = connectTypes[connectType]
d = f(host, port, options, vhk, uao)
d.addErrback(_ebConnect, useConnects, host, port, options, vhk, uao)
return d

View File

@@ -0,0 +1,331 @@
# -*- test-case-name: twisted.conch.test.test_knownhosts,twisted.conch.test.test_default -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Various classes and functions for implementing user-interaction in the
command-line conch client.
You probably shouldn't use anything in this module directly, since it assumes
you are sitting at an interactive terminal. For example, to programmatically
interact with a known_hosts database, use L{twisted.conch.client.knownhosts}.
"""
import contextlib
import getpass
import io
import os
import sys
from base64 import decodebytes
from twisted.conch.client import agent
from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile
from twisted.conch.error import ConchError
from twisted.conch.ssh import common, keys, userauth
from twisted.internet import defer, protocol, reactor
from twisted.python.compat import nativeString
from twisted.python.filepath import FilePath
# The default location of the known hosts file (probably should be parsed out
# of an ssh config file someday).
_KNOWN_HOSTS = "~/.ssh/known_hosts"
# This name is bound so that the unit tests can use 'patch' to override it.
_open = open
_input = input
def verifyHostKey(transport, host, pubKey, fingerprint):
"""
Verify a host's key.
This function is a gross vestige of some bad factoring in the client
    internals. The actual implementation of this logic, with a better
    signature, is in L{KnownHostsFile.verifyHostKey}. This function is not
    deprecated yet because the callers have not yet been rehabilitated, but
    they should eventually be changed to call that method instead.
    However, this function does perform two tasks not implemented by
L{KnownHostsFile.verifyHostKey}. It determines the path to the user's
known_hosts file based on the options (which should really be the options
object's job), and it provides an opener to L{ConsoleUI} which opens
'/dev/tty' so that the user will be prompted on the tty of the process even
    if the input and output of the process have been redirected. This latter
part is, somewhat obviously, not portable, but I don't know of a portable
equivalent that could be used.
@param host: Due to a bug in L{SSHClientTransport.verifyHostKey}, this is
always the dotted-quad IP address of the host being connected to.
@type host: L{str}
@param transport: the client transport which is attempting to connect to
the given host.
@type transport: L{SSHClientTransport}
@param fingerprint: the fingerprint of the given public key, in
xx:xx:xx:... format. This is ignored in favor of getting the fingerprint
from the key itself.
@type fingerprint: L{str}
@param pubKey: The public key of the server being connected to.
@type pubKey: L{str}
@return: a L{Deferred} which fires with C{1} if the key was successfully
verified, or fails if the key could not be successfully verified. Failure
types may include L{HostKeyChanged}, L{UserRejectedKey}, L{IOError} or
L{KeyboardInterrupt}.
"""
actualHost = transport.factory.options["host"]
actualKey = keys.Key.fromString(pubKey)
kh = KnownHostsFile.fromPath(
FilePath(
transport.factory.options["known-hosts"] or os.path.expanduser(_KNOWN_HOSTS)
)
)
ui = ConsoleUI(lambda: _open("/dev/tty", "r+b", buffering=0))
return kh.verifyHostKey(ui, actualHost, host, actualKey)
def isInKnownHosts(host, pubKey, options):
"""
Checks to see if host is in the known_hosts file for the user.
@return: 0 if it isn't, 1 if it is and is the same, 2 if it's changed.
@rtype: L{int}
"""
keyType = common.getNS(pubKey)[0]
retVal = 0
if not options["known-hosts"] and not os.path.exists(os.path.expanduser("~/.ssh/")):
print("Creating ~/.ssh directory...")
os.mkdir(os.path.expanduser("~/.ssh"))
kh_file = options["known-hosts"] or _KNOWN_HOSTS
try:
known_hosts = open(os.path.expanduser(kh_file), "rb")
except OSError:
return 0
with known_hosts:
for line in known_hosts.readlines():
split = line.split()
if len(split) < 3:
continue
hosts, hostKeyType, encodedKey = split[:3]
if host not in hosts.split(b","): # incorrect host
continue
if hostKeyType != keyType: # incorrect type of key
continue
try:
decodedKey = decodebytes(encodedKey)
except BaseException:
continue
if decodedKey == pubKey:
return 1
else:
retVal = 2
return retVal
def getHostKeyAlgorithms(host, options):
"""
Look in known_hosts for a key corresponding to C{host}.
This can be used to change the order of supported key types
in the KEXINIT packet.
@type host: L{str}
@param host: the host to check in known_hosts
@type options: L{twisted.conch.client.options.ConchOptions}
@param options: options passed to client
@return: L{list} of L{str} representing key types or L{None}.
"""
knownHosts = KnownHostsFile.fromPath(
FilePath(options["known-hosts"] or os.path.expanduser(_KNOWN_HOSTS))
)
keyTypes = []
for entry in knownHosts.iterentries():
if entry.matchesHost(host):
if entry.keyType not in keyTypes:
keyTypes.append(entry.keyType)
return keyTypes or None
class SSHUserAuthClient(userauth.SSHUserAuthClient):
def __init__(self, user, options, *args):
userauth.SSHUserAuthClient.__init__(self, user, *args)
self.keyAgent = None
self.options = options
self.usedFiles = []
if not options.identitys:
options.identitys = ["~/.ssh/id_rsa", "~/.ssh/id_dsa"]
def serviceStarted(self):
if "SSH_AUTH_SOCK" in os.environ and not self.options["noagent"]:
self._log.debug(
"using SSH agent {authSock!r}", authSock=os.environ["SSH_AUTH_SOCK"]
)
cc = protocol.ClientCreator(reactor, agent.SSHAgentClient)
d = cc.connectUNIX(os.environ["SSH_AUTH_SOCK"])
d.addCallback(self._setAgent)
d.addErrback(self._ebSetAgent)
else:
userauth.SSHUserAuthClient.serviceStarted(self)
def serviceStopped(self):
if self.keyAgent:
self.keyAgent.transport.loseConnection()
self.keyAgent = None
def _setAgent(self, a):
self.keyAgent = a
d = self.keyAgent.getPublicKeys()
d.addBoth(self._ebSetAgent)
return d
def _ebSetAgent(self, f):
userauth.SSHUserAuthClient.serviceStarted(self)
def _getPassword(self, prompt):
"""
Prompt for a password using L{getpass.getpass}.
@param prompt: Written on tty to ask for the input.
@type prompt: L{str}
@return: The input.
@rtype: L{str}
"""
with self._replaceStdoutStdin():
try:
p = getpass.getpass(prompt)
return p
except (KeyboardInterrupt, OSError):
print()
raise ConchError("PEBKAC")
def getPassword(self, prompt=None):
if prompt:
prompt = nativeString(prompt)
else:
prompt = "{}@{}'s password: ".format(
nativeString(self.user),
self.transport.transport.getPeer().host,
)
try:
# We don't know the encoding the other side is using,
# signaling that is not part of the SSH protocol. But
# using our defaultencoding is better than just going for
# ASCII.
p = self._getPassword(prompt).encode(sys.getdefaultencoding())
return defer.succeed(p)
except ConchError:
return defer.fail()
def getPublicKey(self):
"""
Get a public key from the key agent if possible, otherwise look in
the next configured identity file for one.
"""
if self.keyAgent:
key = self.keyAgent.getPublicKey()
if key is not None:
return key
files = [x for x in self.options.identitys if x not in self.usedFiles]
self._log.debug(
"public key identities: {identities}\n{files}",
identities=self.options.identitys,
files=files,
)
if not files:
return None
file = files[0]
self.usedFiles.append(file)
file = os.path.expanduser(file)
file += ".pub"
if not os.path.exists(file):
return self.getPublicKey() # try again
try:
return keys.Key.fromFile(file)
except keys.BadKeyError:
return self.getPublicKey() # try again
def signData(self, publicKey, signData):
"""
Extend the base signing behavior by using an SSH agent to sign the
data, if one is available.
@type publicKey: L{Key}
@type signData: L{bytes}
"""
if not self.usedFiles: # agent key
return self.keyAgent.signData(publicKey.blob(), signData)
else:
return userauth.SSHUserAuthClient.signData(self, publicKey, signData)
def getPrivateKey(self):
"""
Try to load the private key from the last used file identified by
C{getPublicKey}, potentially asking for the passphrase if the key is
encrypted.
"""
file = os.path.expanduser(self.usedFiles[-1])
if not os.path.exists(file):
return None
try:
return defer.succeed(keys.Key.fromFile(file))
except keys.EncryptedKeyError:
for i in range(3):
prompt = "Enter passphrase for key '%s': " % self.usedFiles[-1]
try:
p = self._getPassword(prompt).encode(sys.getfilesystemencoding())
return defer.succeed(keys.Key.fromFile(file, passphrase=p))
except (keys.BadKeyError, ConchError):
pass
return defer.fail(ConchError("bad password"))
raise
except KeyboardInterrupt:
print()
reactor.stop()
def getGenericAnswers(self, name, instruction, prompts):
responses = []
with self._replaceStdoutStdin():
if name:
print(name.decode("utf-8"))
if instruction:
print(instruction.decode("utf-8"))
for prompt, echo in prompts:
prompt = prompt.decode("utf-8")
if echo:
responses.append(_input(prompt))
else:
responses.append(getpass.getpass(prompt))
return defer.succeed(responses)
@classmethod
def _openTty(cls):
"""
        Open /dev/tty as two streams, one in read mode and one in write mode,
and return them.
@return: File objects for reading and writing to /dev/tty,
corresponding to standard input and standard output.
        @rtype: A L{tuple} of two L{io.TextIOWrapper} objects.
"""
stdin = io.TextIOWrapper(open("/dev/tty", "rb"))
stdout = io.TextIOWrapper(open("/dev/tty", "wb"))
return stdin, stdout
@classmethod
@contextlib.contextmanager
def _replaceStdoutStdin(cls):
"""
Contextmanager that replaces stdout and stdin with /dev/tty
and resets them when it is done.
"""
oldout, oldin = sys.stdout, sys.stdin
sys.stdin, sys.stdout = cls._openTty()
try:
yield
finally:
sys.stdout.close()
sys.stdin.close()
sys.stdout, sys.stdin = oldout, oldin

View File

@@ -0,0 +1,98 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.conch import error
from twisted.conch.ssh import transport
from twisted.internet import defer, protocol, reactor
class SSHClientFactory(protocol.ClientFactory):
def __init__(self, d, options, verifyHostKey, userAuthObject):
self.d = d
self.options = options
self.verifyHostKey = verifyHostKey
self.userAuthObject = userAuthObject
def clientConnectionLost(self, connector, reason):
if self.options["reconnect"]:
connector.connect()
def clientConnectionFailed(self, connector, reason):
if self.d is None:
return
d, self.d = self.d, None
d.errback(reason)
def buildProtocol(self, addr):
trans = SSHClientTransport(self)
if self.options["ciphers"]:
trans.supportedCiphers = self.options["ciphers"]
if self.options["macs"]:
trans.supportedMACs = self.options["macs"]
if self.options["compress"]:
trans.supportedCompressions[0:1] = ["zlib"]
if self.options["host-key-algorithms"]:
trans.supportedPublicKeys = self.options["host-key-algorithms"]
return trans
class SSHClientTransport(transport.SSHClientTransport):
def __init__(self, factory):
self.factory = factory
self.unixServer = None
def connectionLost(self, reason):
if self.unixServer:
d = self.unixServer.stopListening()
self.unixServer = None
else:
d = defer.succeed(None)
d.addCallback(
lambda x: transport.SSHClientTransport.connectionLost(self, reason)
)
def receiveError(self, code, desc):
if self.factory.d is None:
return
d, self.factory.d = self.factory.d, None
d.errback(error.ConchError(desc, code))
def sendDisconnect(self, code, reason):
if self.factory.d is None:
return
d, self.factory.d = self.factory.d, None
transport.SSHClientTransport.sendDisconnect(self, code, reason)
d.errback(error.ConchError(reason, code))
def receiveDebug(self, alwaysDisplay, message, lang):
self._log.debug(
"Received Debug Message: {message}",
message=message,
alwaysDisplay=alwaysDisplay,
lang=lang,
)
if alwaysDisplay: # XXX what should happen here?
print(message)
def verifyHostKey(self, pubKey, fingerprint):
return self.factory.verifyHostKey(
self, self.transport.getPeer().host, pubKey, fingerprint
)
def setService(self, service):
self._log.info("setting client server to {service}", service=service)
transport.SSHClientTransport.setService(self, service)
if service.name != "ssh-userauth" and self.factory.d is not None:
d, self.factory.d = self.factory.d, None
d.callback(None)
def connectionSecure(self):
self.requestService(self.factory.userAuthObject)
def connect(host, port, options, verifyHostKey, userAuthObject):
d = defer.Deferred()
factory = SSHClientFactory(d, options, verifyHostKey, userAuthObject)
reactor.connectTCP(host, port, factory)
return d

View File

@@ -0,0 +1,622 @@
# -*- test-case-name: twisted.conch.test.test_knownhosts -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An implementation of the OpenSSH known_hosts database.
@since: 8.2
"""
from __future__ import annotations
import hmac
import sys
from binascii import Error as DecodeError, a2b_base64, b2a_base64
from contextlib import closing
from hashlib import sha1
from typing import IO, Callable, Literal
from zope.interface import implementer
from twisted.conch.error import HostKeyChanged, InvalidEntry, UserRejectedKey
from twisted.conch.interfaces import IKnownHostEntry
from twisted.conch.ssh.keys import BadKeyError, FingerprintFormats, Key
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.logger import Logger
from twisted.python.compat import nativeString
from twisted.python.filepath import FilePath
from twisted.python.randbytes import secureRandom
from twisted.python.util import FancyEqMixin
log = Logger()
def _b64encode(s):
"""
Encode a binary string as base64 with no trailing newline.
@param s: The string to encode.
@type s: L{bytes}
@return: The base64-encoded string.
@rtype: L{bytes}
"""
return b2a_base64(s).strip()
def _extractCommon(string):
"""
Extract common elements of base64 keys from an entry in a hosts file.
@param string: A known hosts file entry (a single line).
@type string: L{bytes}
@return: a 4-tuple of hostname data (L{bytes}), ssh key type (L{bytes}), key
(L{Key}), and comment (L{bytes} or L{None}). The hostname data is
simply the beginning of the line up to the first occurrence of
whitespace.
@rtype: L{tuple}
"""
elements = string.split(None, 2)
if len(elements) != 3:
raise InvalidEntry()
hostnames, keyType, keyAndComment = elements
splitkey = keyAndComment.split(None, 1)
if len(splitkey) == 2:
keyString, comment = splitkey
comment = comment.rstrip(b"\n")
else:
keyString = splitkey[0]
comment = None
key = Key.fromString(a2b_base64(keyString))
return hostnames, keyType, key, comment
class _BaseEntry:
"""
Abstract base of both hashed and non-hashed entry objects, since they
represent keys and key types the same way.
    @ivar keyType: The type of the key; for example, C{b"ssh-rsa"} or
        C{b"ssh-ed25519"}.
    @type keyType: L{bytes}
    @ivar publicKey: The server public key indicated by this line.
    @type publicKey: L{twisted.conch.ssh.keys.Key}
    @ivar comment: Any trailing comment after the key data, or L{None}.
    @type comment: L{bytes} or L{None}
"""
def __init__(self, keyType, publicKey, comment):
self.keyType = keyType
self.publicKey = publicKey
self.comment = comment
def matchesKey(self, keyObject):
"""
Check to see if this entry matches a given key object.
@param keyObject: A public key object to check.
@type keyObject: L{Key}
@return: C{True} if this entry's key matches C{keyObject}, C{False}
otherwise.
@rtype: L{bool}
"""
return self.publicKey == keyObject
@implementer(IKnownHostEntry)
class PlainEntry(_BaseEntry):
"""
A L{PlainEntry} is a representation of a plain-text entry in a known_hosts
file.
@ivar _hostnames: the list of all host-names associated with this entry.
"""
def __init__(
self, hostnames: list[bytes], keyType: bytes, publicKey: Key, comment: bytes
):
self._hostnames: list[bytes] = hostnames
super().__init__(keyType, publicKey, comment)
@classmethod
def fromString(cls, string: bytes) -> PlainEntry:
"""
Parse a plain-text entry in a known_hosts file, and return a
corresponding L{PlainEntry}.
@param string: a space-separated string formatted like "hostname
key-type base64-key-data comment".
        @raise DecodeError: if the key is not valid base64-encoded data.
@raise InvalidEntry: if the entry does not have the right number of
elements and is therefore invalid.
@raise BadKeyError: if the key, once decoded from base64, is not
actually an SSH key.
@return: an IKnownHostEntry representing the hostname and key in the
input line.
@rtype: L{PlainEntry}
"""
hostnames, keyType, key, comment = _extractCommon(string)
self = cls(hostnames.split(b","), keyType, key, comment)
return self
def matchesHost(self, hostname: bytes | str) -> bool:
"""
Check to see if this entry matches a given hostname.
@param hostname: A hostname or IP address literal to check against this
entry.
@return: C{True} if this entry is for the given hostname or IP address,
C{False} otherwise.
"""
if isinstance(hostname, str):
hostname = hostname.encode("utf-8")
return hostname in self._hostnames
def toString(self) -> bytes:
"""
Implement L{IKnownHostEntry.toString} by recording the comma-separated
hostnames, key type, and base-64 encoded key.
@return: The string representation of this entry, with unhashed hostname
information.
"""
fields = [
b",".join(self._hostnames),
self.keyType,
_b64encode(self.publicKey.blob()),
]
if self.comment is not None:
fields.append(self.comment)
return b" ".join(fields)
@implementer(IKnownHostEntry)
class UnparsedEntry:
"""
L{UnparsedEntry} is an entry in a L{KnownHostsFile} which can't actually be
parsed; therefore it matches no keys and no hosts.
"""
def __init__(self, string):
"""
Create an unparsed entry from a line in a known_hosts file which cannot
otherwise be parsed.
"""
self._string = string
def matchesHost(self, hostname):
"""
Always returns False.
"""
return False
def matchesKey(self, key):
"""
Always returns False.
"""
return False
def toString(self):
"""
Returns the input line, without its newline if one was given.
@return: The string representation of this entry, almost exactly as was
used to initialize this entry but without a trailing newline.
@rtype: L{bytes}
"""
return self._string.rstrip(b"\n")
def _hmacedString(key, string):
"""
Return the SHA-1 HMAC hash of the given key and string.
@param key: The HMAC key.
@type key: L{bytes}
@param string: The string to be hashed.
@type string: L{bytes}
@return: The keyed hash value.
@rtype: L{bytes}
"""
hash = hmac.HMAC(key, digestmod=sha1)
if isinstance(string, str):
string = string.encode("utf-8")
hash.update(string)
return hash.digest()
@implementer(IKnownHostEntry)
class HashedEntry(_BaseEntry, FancyEqMixin):
"""
A L{HashedEntry} is a representation of an entry in a known_hosts file
where the hostname has been hashed and salted.
@ivar _hostSalt: the salt to combine with a hostname for hashing.
@ivar _hostHash: the hashed representation of the hostname.
@cvar MAGIC: the 'hash magic' string used to identify a hashed line in a
known_hosts file as opposed to a plaintext one.
"""
MAGIC = b"|1|"
compareAttributes = ("_hostSalt", "_hostHash", "keyType", "publicKey", "comment")
def __init__(
self,
hostSalt: bytes,
hostHash: bytes,
keyType: bytes,
publicKey: Key,
comment: bytes | None,
) -> None:
self._hostSalt = hostSalt
self._hostHash = hostHash
super().__init__(keyType, publicKey, comment)
@classmethod
def fromString(cls, string: bytes) -> HashedEntry:
"""
Load a hashed entry from a string representing a line in a known_hosts
file.
@param string: A complete single line from a I{known_hosts} file,
formatted as defined by OpenSSH.
        @raise DecodeError: if the key, the host salt, or the host hash is not
            valid base64-encoded data.
@raise InvalidEntry: if the entry does not have the right number of
elements and is therefore invalid, or the host/hash portion
contains more items than just the host and hash.
@raise BadKeyError: if the key, once decoded from base64, is not
actually an SSH key.
@return: The newly created L{HashedEntry} instance, initialized with
the information from C{string}.
"""
stuff, keyType, key, comment = _extractCommon(string)
saltAndHash = stuff[len(cls.MAGIC) :].split(b"|")
if len(saltAndHash) != 2:
raise InvalidEntry()
hostSalt, hostHash = saltAndHash
self = cls(a2b_base64(hostSalt), a2b_base64(hostHash), keyType, key, comment)
return self
def matchesHost(self, hostname):
"""
Implement L{IKnownHostEntry.matchesHost} to compare the hash of the
input to the stored hash.
@param hostname: A hostname or IP address literal to check against this
entry.
@type hostname: L{bytes}
@return: C{True} if this entry is for the given hostname or IP address,
C{False} otherwise.
@rtype: L{bool}
"""
return hmac.compare_digest(
_hmacedString(self._hostSalt, hostname), self._hostHash
)
def toString(self):
"""
Implement L{IKnownHostEntry.toString} by base64-encoding the salt, host
hash, and key.
@return: The string representation of this entry, with the hostname part
hashed.
@rtype: L{bytes}
"""
fields = [
self.MAGIC
+ b"|".join([_b64encode(self._hostSalt), _b64encode(self._hostHash)]),
self.keyType,
_b64encode(self.publicKey.blob()),
]
if self.comment is not None:
fields.append(self.comment)
return b" ".join(fields)
class KnownHostsFile:
"""
A structured representation of an OpenSSH-format ~/.ssh/known_hosts file.
@ivar _added: A list of L{IKnownHostEntry} providers which have been added
to this instance in memory but not yet saved.
@ivar _clobber: A flag indicating whether the current contents of the save
path will be disregarded and potentially overwritten or not. If
C{True}, this will be done. If C{False}, entries in the save path will
be read and new entries will be saved by appending rather than
overwriting.
@type _clobber: L{bool}
@ivar _savePath: See C{savePath} parameter of L{__init__}.
"""
def __init__(self, savePath: FilePath[str]) -> None:
"""
Create a new, empty KnownHostsFile.
Unless you want to erase the current contents of C{savePath}, you want
to use L{KnownHostsFile.fromPath} instead.
@param savePath: The L{FilePath} to which to save new entries.
@type savePath: L{FilePath}
"""
self._added: list[IKnownHostEntry] = []
self._savePath = savePath
self._clobber = True
@property
def savePath(self) -> FilePath[str]:
"""
@see: C{savePath} parameter of L{__init__}
"""
return self._savePath
def iterentries(self):
"""
Iterate over the host entries in this file.
@return: An iterable the elements of which provide L{IKnownHostEntry}.
There is an element for each entry in the file as well as an element
for each added but not yet saved entry.
@rtype: iterable of L{IKnownHostEntry} providers
"""
for entry in self._added:
yield entry
if self._clobber:
return
try:
fp = self._savePath.open()
except OSError:
return
with fp:
for line in fp:
try:
if line.startswith(HashedEntry.MAGIC):
entry = HashedEntry.fromString(line)
else:
entry = PlainEntry.fromString(line)
except (DecodeError, InvalidEntry, BadKeyError):
entry = UnparsedEntry(line)
yield entry
def hasHostKey(self, hostname, key):
"""
Check for an entry with matching hostname and key.
@param hostname: A hostname or IP address literal to check for.
@type hostname: L{bytes}
@param key: The public key to check for.
@type key: L{Key}
@return: C{True} if the given hostname and key are present in this file,
C{False} if they are not.
@rtype: L{bool}
@raise HostKeyChanged: if the host key found for the given hostname
does not match the given key.
"""
for lineidx, entry in enumerate(self.iterentries(), -len(self._added)):
if entry.matchesHost(hostname) and entry.keyType == key.sshType():
if entry.matchesKey(key):
return True
else:
# Notice that lineidx is 0-based but HostKeyChanged.lineno
# is 1-based.
if lineidx < 0:
line = None
path = None
else:
line = lineidx + 1
path = self._savePath
raise HostKeyChanged(entry, path, line)
return False
def verifyHostKey(
self, ui: ConsoleUI, hostname: bytes, ip: bytes, key: Key
) -> Deferred[bool]:
"""
Verify the given host key for the given IP and host, asking for
confirmation from, and notifying, the given UI about changes to this
file.
        @param ui: The user interface to use for notifying the user and for
            prompting for confirmation of keys that are not yet known.
@param hostname: The hostname that the user requested to connect to.
@param ip: The string representation of the IP address that is actually
being connected to.
@param key: The public key of the server.
@return: a L{Deferred} that fires with True when the key has been
verified, or fires with an errback when the key either cannot be
verified or has changed.
@rtype: L{Deferred}
"""
hhk = defer.execute(self.hasHostKey, hostname, key)
def gotHasKey(result: bool) -> bool | Deferred[bool]:
if result:
if not self.hasHostKey(ip, key):
addMessage = (
f"Warning: Permanently added the {key.type()} host key"
f" for IP address '{ip.decode()}' to the list of known"
" hosts.\n"
)
ui.warn(addMessage.encode("utf-8"))
self.addHostKey(ip, key)
self.save()
return result
else:
def promptResponse(response: bool) -> bool:
if response:
self.addHostKey(hostname, key)
self.addHostKey(ip, key)
self.save()
return response
else:
raise UserRejectedKey()
keytype: str = key.type()
if keytype == "EC":
keytype = "ECDSA"
prompt = (
"The authenticity of host '%s (%s)' "
"can't be established.\n"
"%s key fingerprint is SHA256:%s.\n"
"Are you sure you want to continue connecting (yes/no)? "
% (
nativeString(hostname),
nativeString(ip),
keytype,
key.fingerprint(format=FingerprintFormats.SHA256_BASE64),
)
)
proceed = ui.prompt(prompt.encode(sys.getdefaultencoding()))
return proceed.addCallback(promptResponse)
return hhk.addCallback(gotHasKey)
def addHostKey(self, hostname: bytes, key: Key) -> HashedEntry:
"""
Add a new L{HashedEntry} to the key database.
Note that you still need to call L{KnownHostsFile.save} if you wish
these changes to be persisted.
@param hostname: A hostname or IP address literal to associate with the
new entry.
@type hostname: L{bytes}
@param key: The public key to associate with the new entry.
@type key: L{Key}
@return: The L{HashedEntry} that was added.
@rtype: L{HashedEntry}
"""
salt = secureRandom(20)
keyType = key.sshType()
entry = HashedEntry(salt, _hmacedString(salt, hostname), keyType, key, None)
self._added.append(entry)
return entry
def save(self) -> None:
"""
Save this L{KnownHostsFile} to the path it was loaded from.
"""
p = self._savePath.parent()
if not p.isdir():
p.makedirs()
mode: Literal["a", "w"] = "w" if self._clobber else "a"
with self._savePath.open(mode) as hostsFileObj:
if self._added:
hostsFileObj.write(
b"\n".join([entry.toString() for entry in self._added]) + b"\n"
)
self._added = []
self._clobber = False
@classmethod
def fromPath(cls, path: FilePath[str]) -> KnownHostsFile:
"""
Create a new L{KnownHostsFile}, potentially reading existing known
hosts information from the given file.
@param path: A path object to use for both reading contents from and
later saving to. If no file exists at this path, it is not an
error; a L{KnownHostsFile} with no entries is returned.
@return: A L{KnownHostsFile} initialized with entries from C{path}.
"""
knownHosts = cls(path)
knownHosts._clobber = False
return knownHosts
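# Illustrative sketch (not part of the original module): loading a
# known_hosts file, recording a key for a new host, and saving the result.
# The path and hostname are hypothetical; ``key`` is a Key supplied by the
# caller.
def _exampleKnownHostsFile(key):  # pragma: no cover - illustrative only
    import os

    hostsPath = FilePath(os.path.expanduser("~/.ssh/known_hosts"))
    knownHosts = KnownHostsFile.fromPath(hostsPath)
    # hasHostKey raises HostKeyChanged if the host is present with a
    # different key; verifyHostKey is the higher-level, prompting API.
    if not knownHosts.hasHostKey(b"example.com", key):
        # addHostKey only updates the in-memory state; save() appends the
        # new HashedEntry to the underlying file.
        knownHosts.addHostKey(b"example.com", key)
        knownHosts.save()
    return knownHosts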
class ConsoleUI:
"""
A UI object that can ask true/false questions and post notifications on the
console, to be used during key verification.
"""
def __init__(self, opener: Callable[[], IO[bytes]]):
"""
@param opener: A no-argument callable which should open a console
binary-mode file-like object to be used for reading and writing.
This initializes the C{opener} attribute.
@type opener: callable taking no arguments and returning a read/write
file-like object
"""
self.opener = opener
def prompt(self, text: bytes) -> Deferred[bool]:
"""
Write the given text as a prompt to the console output, then read a
result from the console input.
@param text: Something to present to a user to solicit a yes or no
response.
@type text: L{bytes}
@return: a L{Deferred} which fires with L{True} when the user answers
'yes' and L{False} when the user answers 'no'. It may errback if
there were any I/O errors.
"""
d = defer.succeed(None)
def body(ignored):
with closing(self.opener()) as f:
f.write(text)
while True:
answer = f.readline().strip().lower()
if answer == b"yes":
return True
elif answer in {b"no", b""}:
return False
else:
f.write(b"Please type 'yes' or 'no': ")
return d.addCallback(body)
def warn(self, text: bytes) -> None:
"""
Notify the user (non-interactively) of the provided text, by writing it
to the console.
@param text: Some information the user is to be made aware of.
"""
try:
with closing(self.opener()) as f:
f.write(text)
except Exception:
log.failure("Failed to write to console")
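# Illustrative sketch (not part of the original module): driving ConsoleUI
# from an in-memory buffer instead of /dev/tty, which is how a test might
# exercise prompt(). The canned answer below is hypothetical.
def _exampleConsoleUI():  # pragma: no cover - illustrative only
    from io import BytesIO

    class _Console(BytesIO):
        # prompt() writes the question, then reads lines until it sees
        # b"yes", b"no", or an empty line.
        def readline(self, *args, **kwargs):
            return b"yes\n"

    ui = ConsoleUI(lambda: _Console())
    # Fires with True because the fake console always answers "yes".
    return ui.prompt(b"Are you sure you want to continue connecting (yes/no)? ")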

View File

@@ -0,0 +1,109 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import sys
from typing import List, Optional, Union
#
from twisted.conch.ssh.transport import SSHCiphers, SSHClientTransport
from twisted.python import usage
class ConchOptions(usage.Options):
optParameters: List[List[Optional[Union[str, int]]]] = [
["user", "l", None, "Log in using this user name."],
["identity", "i", None],
["ciphers", "c", None],
["macs", "m", None],
["port", "p", None, "Connect to this port. Server must be on the same port."],
["option", "o", None, "Ignored OpenSSH options"],
["host-key-algorithms", "", None],
["known-hosts", "", None, "File to check for host keys"],
["user-authentications", "", None, "Types of user authentications to use."],
["logfile", "", None, "File to log to, or - for stdout"],
]
optFlags = [
["version", "V", "Display version number only."],
["compress", "C", "Enable compression."],
["log", "v", "Enable logging (defaults to stderr)"],
["nox11", "x", "Disable X11 connection forwarding (default)"],
["agent", "A", "Enable authentication agent forwarding"],
["noagent", "a", "Disable authentication agent forwarding (default)"],
["reconnect", "r", "Reconnect to the server if the connection is lost."],
]
compData = usage.Completions(
mutuallyExclusive=[("agent", "noagent")],
optActions={
"user": usage.CompleteUsernames(),
"ciphers": usage.CompleteMultiList(
[v.decode() for v in SSHCiphers.cipherMap.keys()],
descr="ciphers to choose from",
),
"macs": usage.CompleteMultiList(
[v.decode() for v in SSHCiphers.macMap.keys()],
descr="macs to choose from",
),
"host-key-algorithms": usage.CompleteMultiList(
[v.decode() for v in SSHClientTransport.supportedPublicKeys],
descr="host key algorithms to choose from",
),
# "user-authentications": usage.CompleteMultiList(?
# descr='user authentication types' ),
},
extraActions=[
usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr="argument", repeat=True),
],
)
def __init__(self, *args, **kw):
usage.Options.__init__(self, *args, **kw)
self.identitys = []
self.conns = None
def opt_identity(self, i):
"""Identity for public-key authentication"""
self.identitys.append(i)
def opt_ciphers(self, ciphers):
"Select encryption algorithms"
ciphers = ciphers.split(",")
for cipher in ciphers:
if cipher not in SSHCiphers.cipherMap:
sys.exit("Unknown cipher type '%s'" % cipher)
self["ciphers"] = ciphers
def opt_macs(self, macs):
"Specify MAC algorithms"
if isinstance(macs, str):
macs = macs.encode("utf-8")
macs = macs.split(b",")
for mac in macs:
if mac not in SSHCiphers.macMap:
sys.exit("Unknown mac type '%r'" % mac)
self["macs"] = macs
def opt_host_key_algorithms(self, hkas):
"Select host key algorithms"
if isinstance(hkas, str):
hkas = hkas.encode("utf-8")
hkas = hkas.split(b",")
for hka in hkas:
if hka not in SSHClientTransport.supportedPublicKeys:
sys.exit("Unknown host key type '%r'" % hka)
self["host-key-algorithms"] = hkas
def opt_user_authentications(self, uas):
"Choose how to authenticate to the remote server"
if isinstance(uas, str):
uas = uas.encode("utf-8")
self["user-authentications"] = uas.split(b",")
# def opt_compress(self):
# "Enable compression"
# self.enableCompression = 1
# SSHClientTransport.supportedCompressions[0:1] = ['zlib']
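# Illustrative sketch (not part of the original module): parsing a few
# conch-style flags with ConchOptions. The conch scripts subclass this and
# add positional-argument handling on top of it.
def _exampleConchOptions():  # pragma: no cover - illustrative only
    opts = ConchOptions()
    opts.parseOptions(
        ["-l", "alice", "-p", "2222", "-i", "~/.ssh/id_ed25519", "-C"]
    )
    # After parsing: opts["user"] == "alice", opts["port"] == "2222",
    # opts.identitys == ["~/.ssh/id_ed25519"], and opts["compress"] is truthy.
    return opts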

View File

@@ -0,0 +1,845 @@
# -*- test-case-name: twisted.conch.test.test_endpoints -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Endpoint implementations of various SSH interactions.
"""
from __future__ import annotations
__all__ = [
"AuthenticationFailed",
"SSHCommandAddress",
"SSHCommandClientEndpoint",
]
import signal
from io import BytesIO
from os.path import expanduser
from struct import unpack
from typing import IO, Any
from zope.interface import Interface, implementer
from twisted.conch.client.agent import SSHAgentClient
from twisted.conch.client.default import _KNOWN_HOSTS
from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile
from twisted.conch.ssh.channel import SSHChannel
from twisted.conch.ssh.common import NS, getNS
from twisted.conch.ssh.connection import SSHConnection
from twisted.conch.ssh.keys import Key
from twisted.conch.ssh.transport import SSHClientTransport
from twisted.conch.ssh.userauth import SSHUserAuthClient
from twisted.internet.defer import CancelledError, Deferred, succeed
from twisted.internet.endpoints import TCP4ClientEndpoint, connectProtocol
from twisted.internet.error import ConnectionDone, ProcessTerminated
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.internet.protocol import Factory
from twisted.logger import Logger
from twisted.python.compat import nativeString, networkString
from twisted.python.failure import Failure
from twisted.python.filepath import FilePath
class AuthenticationFailed(Exception):
"""
An SSH session could not be established because authentication was not
successful.
"""
# This should be public. See #6541.
class _ISSHConnectionCreator(Interface):
"""
An L{_ISSHConnectionCreator} knows how to create SSH connections somehow.
"""
def secureConnection():
"""
Return a new, connected, secured, but not yet authenticated instance of
L{twisted.conch.ssh.transport.SSHServerTransport} or
L{twisted.conch.ssh.transport.SSHClientTransport}.
"""
def cleanupConnection(connection, immediate):
"""
Perform cleanup necessary for a connection object previously returned
from this creator's C{secureConnection} method.
@param connection: An L{twisted.conch.ssh.transport.SSHServerTransport}
or L{twisted.conch.ssh.transport.SSHClientTransport} returned by a
previous call to C{secureConnection}. It is no longer needed by
the caller of that method and may be closed or otherwise cleaned up
as necessary.
@param immediate: If C{True} don't wait for any network communication,
just close the connection immediately and as aggressively as
necessary.
"""
class SSHCommandAddress:
"""
An L{SSHCommandAddress} instance represents the address of an SSH server, a
username which was used to authenticate with that server, and a command
which was run there.
@ivar server: See L{__init__}
@ivar username: See L{__init__}
@ivar command: See L{__init__}
"""
def __init__(self, server, username, command):
"""
@param server: The address of the SSH server on which the command is
running.
@type server: L{IAddress} provider
@param username: An authentication username which was used to
authenticate against the server at the given address.
@type username: L{bytes}
@param command: A command which was run in a session channel on the
server at the given address.
@type command: L{bytes}
"""
self.server = server
self.username = username
self.command = command
class _CommandChannel(SSHChannel):
"""
A L{_CommandChannel} executes a command in a session channel and connects
its input and output to an L{IProtocol} provider.
@ivar _creator: See L{__init__}
@ivar _command: See L{__init__}
@ivar _protocolFactory: See L{__init__}
@ivar _commandConnected: See L{__init__}
@ivar _protocol: An L{IProtocol} provider created using C{_protocolFactory}
which is hooked up to the running command's input and output streams.
"""
name = b"session"
_log = Logger()
def __init__(self, creator, command, protocolFactory, commandConnected):
"""
@param creator: The L{_ISSHConnectionCreator} provider which was used
to get the connection which this channel exists on.
@type creator: L{_ISSHConnectionCreator} provider
@param command: The command to be executed.
@type command: L{bytes}
@param protocolFactory: A client factory to use to build a L{IProtocol}
provider to use to associate with the running command.
@param commandConnected: A L{Deferred} to use to signal that execution
of the command has failed or that it has succeeded and the command
is now running.
@type commandConnected: L{Deferred}
"""
SSHChannel.__init__(self)
self._creator = creator
self._command = command
self._protocolFactory = protocolFactory
self._commandConnected = commandConnected
self._reason = None
def openFailed(self, reason):
"""
When the request to open a new channel to run this command in fails,
fire the C{commandConnected} deferred with a failure indicating that.
"""
self._commandConnected.errback(reason)
def channelOpen(self, ignored):
"""
When the request to open a new channel to run this command in succeeds,
issue an C{"exec"} request to run the command.
"""
command = self.conn.sendRequest(
self, b"exec", NS(self._command), wantReply=True
)
command.addCallbacks(self._execSuccess, self._execFailure)
def _execFailure(self, reason):
"""
When the request to execute the command in this channel fails, fire the
C{commandConnected} deferred with a failure indicating this.
@param reason: The cause of the command execution failure.
@type reason: L{Failure}
"""
self._commandConnected.errback(reason)
def _execSuccess(self, ignored):
"""
When the request to execute the command in this channel succeeds, use
C{protocolFactory} to build a protocol to handle the command's input
and output and connect the protocol to a transport representing those
streams.
Also fire C{commandConnected} with the created protocol after it is
connected to its transport.
@param ignored: The (ignored) result of the execute request
"""
self._protocol = self._protocolFactory.buildProtocol(
SSHCommandAddress(
self.conn.transport.transport.getPeer(),
self.conn.transport.creator.username,
self.conn.transport.creator.command,
)
)
self._protocol.makeConnection(self)
self._commandConnected.callback(self._protocol)
def dataReceived(self, data):
"""
When the command's stdout data arrives over the channel, deliver it to
the protocol instance.
@param data: The bytes from the command's stdout.
@type data: L{bytes}
"""
self._protocol.dataReceived(data)
def request_exit_status(self, data):
"""
When the server sends the command's exit status, record it for later
delivery to the protocol.
@param data: The network-order four byte representation of the exit
status of the command.
@type data: L{bytes}
"""
(status,) = unpack(">L", data)
if status != 0:
self._reason = ProcessTerminated(status, None, None)
def request_exit_signal(self, data):
"""
        When the server sends the command's exit signal, record it for later
        delivery to the protocol.
        @param data: The SSH-encoded exit signal information from the server:
            the signal name, a core-dump flag, an error message, and a
            language tag.
@type data: L{bytes}
"""
shortSignalName, data = getNS(data)
coreDumped, data = bool(ord(data[0:1])), data[1:]
errorMessage, data = getNS(data)
languageTag, data = getNS(data)
signalName = f"SIG{nativeString(shortSignalName)}"
signalID = getattr(signal, signalName, -1)
self._log.info(
"Process exited with signal {shortSignalName!r};"
" core dumped: {coreDumped};"
" error message: {errorMessage};"
" language: {languageTag!r}",
shortSignalName=shortSignalName,
coreDumped=coreDumped,
errorMessage=errorMessage.decode("utf-8"),
languageTag=languageTag,
)
self._reason = ProcessTerminated(None, signalID, None)
def closed(self):
"""
When the channel closes, deliver disconnection notification to the
protocol.
"""
self._creator.cleanupConnection(self.conn, False)
if self._reason is None:
reason = ConnectionDone("ssh channel closed")
else:
reason = self._reason
self._protocol.connectionLost(Failure(reason))
class _ConnectionReady(SSHConnection):
"""
L{_ConnectionReady} is an L{SSHConnection} (an SSH service) which only
propagates the I{serviceStarted} event to a L{Deferred} to be handled
elsewhere.
"""
def __init__(self, ready):
"""
@param ready: A L{Deferred} which should be fired when
I{serviceStarted} happens.
"""
SSHConnection.__init__(self)
self._ready = ready
def serviceStarted(self):
"""
When the SSH I{connection} I{service} this object represents is ready
to be used, fire the C{connectionReady} L{Deferred} to publish that
event to some other interested party.
"""
self._ready.callback(self)
del self._ready
class _UserAuth(SSHUserAuthClient):
"""
L{_UserAuth} implements the client part of SSH user authentication in the
convenient way a user might expect if they are familiar with the
interactive I{ssh} command line client.
L{_UserAuth} supports key-based authentication, password-based
authentication, and delegating authentication to an agent.
"""
password = None
keys = None
agent = None
def getPublicKey(self):
"""
Retrieve the next public key object to offer to the server, possibly
delegating to an authentication agent if there is one.
@return: The public part of a key pair that could be used to
authenticate with the server, or L{None} if there are no more
public keys to try.
@rtype: L{twisted.conch.ssh.keys.Key} or L{None}
"""
if self.agent is not None:
return self.agent.getPublicKey()
        if self.keys:
            self.key = self.keys.pop(0)
        else:
            # No keys are left to try; return None as documented so that the
            # authentication machinery can move on to another method.
            self.key = None
            return None
        return self.key.public()
def signData(self, publicKey, signData):
"""
Extend the base signing behavior by using an SSH agent to sign the
data, if one is available.
@type publicKey: L{Key}
@type signData: L{str}
"""
if self.agent is not None:
return self.agent.signData(publicKey.blob(), signData)
else:
return SSHUserAuthClient.signData(self, publicKey, signData)
def getPrivateKey(self):
"""
Get the private part of a key pair to use for authentication. The key
corresponds to the public part most recently returned from
C{getPublicKey}.
@return: A L{Deferred} which fires with the private key.
@rtype: L{Deferred}
"""
return succeed(self.key)
def getPassword(self):
"""
Get the password to use for authentication.
@return: A L{Deferred} which fires with the password, or L{None} if the
password was not specified.
"""
if self.password is None:
return
return succeed(self.password)
def ssh_USERAUTH_SUCCESS(self, packet):
"""
Handle user authentication success in the normal way, but also make a
note of the state change on the L{_CommandTransport}.
"""
self.transport._state = b"CHANNELLING"
return SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet)
def connectToAgent(self, endpoint):
"""
Set up a connection to the authentication agent and trigger its
initialization.
@param endpoint: An endpoint which can be used to connect to the
authentication agent.
@type endpoint: L{IStreamClientEndpoint} provider
@return: A L{Deferred} which fires when the agent connection is ready
for use.
"""
factory = Factory()
factory.protocol = SSHAgentClient
d = endpoint.connect(factory)
def connected(agent):
self.agent = agent
return agent.getPublicKeys()
d.addCallback(connected)
return d
def loseAgentConnection(self):
"""
Disconnect the agent.
"""
if self.agent is None:
return
self.agent.transport.loseConnection()
class _CommandTransport(SSHClientTransport):
"""
L{_CommandTransport} is an SSH client I{transport} which includes a host
key verification step before it will proceed to secure the connection.
L{_CommandTransport} also knows how to set up a connection to an
authentication agent if it is told where it can connect to one.
@ivar _userauth: The L{_UserAuth} instance which is in charge of the
        overall authentication process or L{None} if the SSH connection has
        not yet reached the C{user-auth} service.
@type _userauth: L{_UserAuth}
"""
# STARTING -> SECURING -> AUTHENTICATING -> CHANNELLING -> RUNNING
_state = b"STARTING"
_hostKeyFailure = None
_userauth = None
def __init__(self, creator):
"""
@param creator: The L{_NewConnectionHelper} that created this
connection.
@type creator: L{_NewConnectionHelper}.
"""
self.connectionReady = Deferred(lambda d: self.transport.abortConnection())
# Clear the reference to that deferred to help the garbage collector
# and to signal to other parts of this implementation (in particular
# connectionLost) that it has already been fired and does not need to
# be fired again.
def readyFired(result):
self.connectionReady = None
return result
self.connectionReady.addBoth(readyFired)
self.creator = creator
def verifyHostKey(self, hostKey, fingerprint):
"""
        Ask the L{KnownHostsFile} provider available on the factory which
        created this protocol to verify the given host key.
@return: A L{Deferred} which fires with the result of
L{KnownHostsFile.verifyHostKey}.
"""
hostname = self.creator.hostname
ip = networkString(self.transport.getPeer().host)
self._state = b"SECURING"
d = self.creator.knownHosts.verifyHostKey(
self.creator.ui, hostname, ip, Key.fromString(hostKey)
)
d.addErrback(self._saveHostKeyFailure)
return d
def _saveHostKeyFailure(self, reason):
"""
When host key verification fails, record the reason for the failure in
order to fire a L{Deferred} with it later.
@param reason: The cause of the host key verification failure.
@type reason: L{Failure}
@return: C{reason}
@rtype: L{Failure}
"""
self._hostKeyFailure = reason
return reason
def connectionSecure(self):
"""
When the connection is secure, start the authentication process.
"""
self._state = b"AUTHENTICATING"
command = _ConnectionReady(self.connectionReady)
self._userauth = _UserAuth(self.creator.username, command)
self._userauth.password = self.creator.password
if self.creator.keys:
self._userauth.keys = list(self.creator.keys)
if self.creator.agentEndpoint is not None:
d = self._userauth.connectToAgent(self.creator.agentEndpoint)
else:
d = succeed(None)
def maybeGotAgent(ignored):
self.requestService(self._userauth)
d.addBoth(maybeGotAgent)
def connectionLost(self, reason):
"""
When the underlying connection to the SSH server is lost, if there were
any connection setup errors, propagate them. Also, clean up the
connection to the ssh agent if one was created.
"""
if self._userauth:
self._userauth.loseAgentConnection()
if self._state == b"RUNNING" or self.connectionReady is None:
return
if self._state == b"SECURING" and self._hostKeyFailure is not None:
reason = self._hostKeyFailure
elif self._state == b"AUTHENTICATING":
reason = Failure(
AuthenticationFailed("Connection lost while authenticating")
)
self.connectionReady.errback(reason)
@implementer(IStreamClientEndpoint)
class SSHCommandClientEndpoint:
"""
L{SSHCommandClientEndpoint} exposes the command-executing functionality of
SSH servers.
L{SSHCommandClientEndpoint} can set up a new SSH connection, authenticate
it in any one of a number of different ways (keys, passwords, agents),
launch a command over that connection and then associate its input and
output with a protocol.
It can also re-use an existing, already-authenticated SSH connection
(perhaps one which already has some SSH channels being used for other
purposes). In this case it creates a new SSH channel to use to execute the
command. Notably this means it supports multiplexing several different
command invocations over a single SSH connection.
"""
def __init__(self, creator, command):
"""
@param creator: An L{_ISSHConnectionCreator} provider which will be
used to set up the SSH connection which will be used to run a
command.
@type creator: L{_ISSHConnectionCreator} provider
@param command: The command line to execute on the SSH server. This
byte string is interpreted by a shell on the SSH server, so it may
have a value like C{"ls /"}. Take care when trying to run a
command like C{"/Volumes/My Stuff/a-program"} - spaces (and other
special bytes) may require escaping.
@type command: L{bytes}
"""
self._creator = creator
self._command = command
@classmethod
def newConnection(
cls,
reactor,
command,
username,
hostname,
port=None,
keys=None,
password=None,
agentEndpoint=None,
knownHosts=None,
ui=None,
):
"""
Create and return a new endpoint which will try to create a new
connection to an SSH server and run a command over it. It will also
close the connection if there are problems leading up to the command
being executed, after the command finishes, or if the connection
L{Deferred} is cancelled.
@param reactor: The reactor to use to establish the connection.
@type reactor: L{IReactorTCP} provider
@param command: See L{__init__}'s C{command} argument.
@param username: The username with which to authenticate to the SSH
server.
@type username: L{bytes}
@param hostname: The hostname of the SSH server.
@type hostname: L{bytes}
@param port: The port number of the SSH server. By default, the
standard SSH port number is used.
@type port: L{int}
@param keys: Private keys with which to authenticate to the SSH server,
if key authentication is to be attempted (otherwise L{None}).
@type keys: L{list} of L{Key}
@param password: The password with which to authenticate to the SSH
server, if password authentication is to be attempted (otherwise
L{None}).
@type password: L{bytes} or L{None}
@param agentEndpoint: An L{IStreamClientEndpoint} provider which may be
used to connect to an SSH agent, if one is to be used to help with
authentication.
@type agentEndpoint: L{IStreamClientEndpoint} provider
@param knownHosts: The currently known host keys, used to check the
host key presented by the server we actually connect to.
@type knownHosts: L{KnownHostsFile}
@param ui: An object for interacting with users to make decisions about
whether to accept the server host keys. If L{None}, a L{ConsoleUI}
connected to /dev/tty will be used; if /dev/tty is unavailable, an
object which answers C{b"no"} to all prompts will be used.
@type ui: L{None} or L{ConsoleUI}
@return: A new instance of C{cls} (probably
L{SSHCommandClientEndpoint}).
"""
helper = _NewConnectionHelper(
reactor,
hostname,
port,
command,
username,
keys,
password,
agentEndpoint,
knownHosts,
ui,
)
return cls(helper, command)
@classmethod
def existingConnection(cls, connection, command):
"""
Create and return a new endpoint which will try to open a new channel
on an existing SSH connection and run a command over it. It will
B{not} close the connection if there is a problem executing the command
or after the command finishes.
@param connection: An existing connection to an SSH server.
@type connection: L{SSHConnection}
@param command: See L{SSHCommandClientEndpoint.newConnection}'s
C{command} parameter.
@type command: L{bytes}
@return: A new instance of C{cls} (probably
L{SSHCommandClientEndpoint}).
"""
helper = _ExistingConnectionHelper(connection)
return cls(helper, command)
def connect(self, protocolFactory):
"""
Set up an SSH connection, use a channel from that connection to launch
a command, and hook the stdin and stdout of that command up as a
transport for a protocol created by the given factory.
@param protocolFactory: A L{Factory} to use to create the protocol
which will be connected to the stdin and stdout of the command on
the SSH server.
@return: A L{Deferred} which will fire with an error if the connection
cannot be set up for any reason or with the protocol instance
created by C{protocolFactory} once it has been connected to the
command.
"""
d = self._creator.secureConnection()
d.addCallback(self._executeCommand, protocolFactory)
return d
def _executeCommand(self, connection, protocolFactory):
"""
Given a secured SSH connection, try to execute a command in a new
channel created on it and associate the result with a protocol from the
given factory.
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
C{connection} parameter.
@param protocolFactory: See L{SSHCommandClientEndpoint.connect}'s
C{protocolFactory} parameter.
@return: See L{SSHCommandClientEndpoint.connect}'s return value.
"""
commandConnected = Deferred()
def disconnectOnFailure(passthrough):
# Close the connection immediately in case of cancellation, since
# that implies user wants it gone immediately (e.g. a timeout):
immediate = passthrough.check(CancelledError)
self._creator.cleanupConnection(connection, immediate)
return passthrough
commandConnected.addErrback(disconnectOnFailure)
channel = _CommandChannel(
self._creator, self._command, protocolFactory, commandConnected
)
connection.openChannel(channel)
return commandConnected
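# Illustrative sketch (not part of the original module): one way an
# application might use SSHCommandClientEndpoint.newConnection to run a
# remote command and collect its output.  The hostname, credentials, and
# command below are hypothetical placeholders.
def _exampleRunCommand(reactor):
    from twisted.internet.defer import Deferred
    from twisted.internet.protocol import Factory, Protocol
    class CollectOutput(Protocol):
        def connectionMade(self):
            self.finished = Deferred()
            self.output = b""
        def dataReceived(self, data):
            self.output += data
        def connectionLost(self, reason):
            # Fires with everything the remote command wrote to stdout.
            self.finished.callback(self.output)
    endpoint = SSHCommandClientEndpoint.newConnection(
        reactor,
        b"cat /etc/hostname",
        b"alice",
        b"ssh.example.invalid",
        port=22,
        password=b"correct horse battery staple",
    )
    d = endpoint.connect(Factory.forProtocol(CollectOutput))
    # connect() fires with the connected protocol; then wait for the command
    # to finish and hand back its accumulated output.
    d.addCallback(lambda proto: proto.finished)
    return d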
@implementer(_ISSHConnectionCreator)
class _NewConnectionHelper:
"""
L{_NewConnectionHelper} implements L{_ISSHConnectionCreator} by
establishing a brand new SSH connection, securing it, and authenticating.
"""
_KNOWN_HOSTS = _KNOWN_HOSTS
port = 22
def __init__(
self,
reactor: Any,
hostname: str,
port: int,
command: str,
username: str,
keys: str,
password: str,
agentEndpoint: str,
knownHosts: str | None,
ui: ConsoleUI | None,
tty: FilePath[bytes] | FilePath[str] = FilePath(b"/dev/tty"),
):
"""
@param tty: The path of the tty device to use in case C{ui} is L{None}.
@type tty: L{FilePath}
@see: L{SSHCommandClientEndpoint.newConnection}
"""
self.reactor = reactor
self.hostname = hostname
if port is not None:
self.port = port
self.command = command
self.username = username
self.keys = keys
self.password = password
self.agentEndpoint = agentEndpoint
if knownHosts is None:
knownHosts = self._knownHosts()
self.knownHosts = knownHosts
if ui is None:
ui = ConsoleUI(self._opener)
self.ui = ui
self.tty: FilePath[bytes] | FilePath[str] = tty
def _opener(self) -> IO[bytes]:
"""
Open the tty if possible, otherwise give back a file-like object from
which C{b"no"} can be read.
For use as the opener argument to L{ConsoleUI}.
"""
try:
return self.tty.open("r+")
except BaseException:
            # Give back a file-like object from which a byte string that
            # KnownHostsFile recognizes as rejecting some option (b"no") can
            # be read.
return BytesIO(b"no")
@classmethod
def _knownHosts(cls):
"""
@return: A L{KnownHostsFile} instance pointed at the user's personal
I{known hosts} file.
@rtype: L{KnownHostsFile}
"""
return KnownHostsFile.fromPath(FilePath(expanduser(cls._KNOWN_HOSTS)))
def secureConnection(self):
"""
Create and return a new SSH connection which has been secured and on
which authentication has already happened.
@return: A L{Deferred} which fires with the ready-to-use connection or
with a failure if something prevents the connection from being
setup, secured, or authenticated.
"""
protocol = _CommandTransport(self)
ready = protocol.connectionReady
sshClient = TCP4ClientEndpoint(
self.reactor, nativeString(self.hostname), self.port
)
d = connectProtocol(sshClient, protocol)
d.addCallback(lambda ignored: ready)
return d
def cleanupConnection(self, connection, immediate):
"""
Clean up the connection by closing it. The command running on the
endpoint has ended so the connection is no longer needed.
@param connection: The L{SSHConnection} to close.
@type connection: L{SSHConnection}
@param immediate: Whether to close connection immediately.
@type immediate: L{bool}.
"""
if immediate:
# We're assuming the underlying connection is an ITCPTransport,
# which is what the current implementation is restricted to:
connection.transport.transport.abortConnection()
else:
connection.transport.loseConnection()
@implementer(_ISSHConnectionCreator)
class _ExistingConnectionHelper:
"""
L{_ExistingConnectionHelper} implements L{_ISSHConnectionCreator} by
handing out an existing SSH connection which is supplied to its
initializer.
"""
def __init__(self, connection):
"""
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
C{connection} parameter.
"""
self.connection = connection
def secureConnection(self):
"""
@return: A L{Deferred} that fires synchronously with the
already-established connection object.
"""
return succeed(self.connection)
def cleanupConnection(self, connection, immediate):
"""
Do not do any cleanup on the connection. Leave that responsibility to
whatever code created it in the first place.
@param connection: The L{SSHConnection} which will not be modified in
any way.
@type connection: L{SSHConnection}
@param immediate: An argument which will be ignored.
@type immediate: L{bool}.
"""

View File

@@ -0,0 +1,96 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An error to represent bad things happening in Conch.
Maintainer: Paul Swartz
"""
from twisted.cred.error import UnauthorizedLogin
class ConchError(Exception):
def __init__(self, value, data=None):
Exception.__init__(self, value, data)
self.value = value
self.data = data
class NotEnoughAuthentication(Exception):
"""
    This is thrown if the authentication is valid, but is not enough to
    successfully verify the user, i.e. don't retry this type of
    authentication; try another one.
"""
class ValidPublicKey(UnauthorizedLogin):
"""
Raised by public key checkers when they receive public key credentials
that don't contain a signature at all, but are valid in every other way.
(e.g. the public key matches one in the user's authorized_keys file).
    Protocol code (e.g.
L{SSHUserAuthServer<twisted.conch.ssh.userauth.SSHUserAuthServer>}) which
attempts to log in using
L{ISSHPrivateKey<twisted.cred.credentials.ISSHPrivateKey>} credentials
should be prepared to handle a failure of this type by telling the user to
re-authenticate using the same key and to include a signature with the new
attempt.
See U{http://www.ietf.org/rfc/rfc4252.txt} section 7 for more details.
"""
class IgnoreAuthentication(Exception):
"""
This is thrown to let the UserAuthServer know it doesn't need to handle the
authentication anymore.
"""
class MissingKeyStoreError(Exception):
"""
Raised if an SSHAgentServer starts receiving data without its factory
providing a keys dict on which to read/write key data.
"""
class UserRejectedKey(Exception):
"""
The user interactively rejected a key.
"""
class InvalidEntry(Exception):
"""
An entry in a known_hosts file could not be interpreted as a valid entry.
"""
class HostKeyChanged(Exception):
"""
The host key of a remote host has changed.
@ivar offendingEntry: The entry which contains the persistent host key that
disagrees with the given host key.
@type offendingEntry: L{twisted.conch.interfaces.IKnownHostEntry}
@ivar path: a reference to the known_hosts file that the offending entry
was loaded from
@type path: L{twisted.python.filepath.FilePath}
@ivar lineno: The line number of the offending entry in the given path.
@type lineno: L{int}
"""
def __init__(self, offendingEntry, path, lineno):
Exception.__init__(self)
self.offendingEntry = offendingEntry
self.path = path
self.lineno = lineno

View File

@@ -0,0 +1,4 @@
"""
Insults: a replacement for Curses/S-Lang.
Very basic at the moment."""

View File

@@ -0,0 +1,556 @@
# -*- test-case-name: twisted.conch.test.test_helper -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Partial in-memory terminal emulator
@author: Jp Calderone
"""
import re
import string
from zope.interface import implementer
from incremental import Version
from twisted.conch.insults import insults
from twisted.internet import defer, protocol, reactor
from twisted.logger import Logger
from twisted.python import _textattributes
from twisted.python.compat import iterbytes
from twisted.python.deprecate import deprecated, deprecatedModuleAttribute
FOREGROUND = 30
BACKGROUND = 40
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, N_COLORS = range(9)
class _FormattingState(_textattributes._FormattingStateMixin):
"""
Represents the formatting state/attributes of a single character.
Character set, intensity, underlinedness, blinkitude, video
    reversal, as well as foreground and background colors make up a
character's attributes.
"""
compareAttributes = (
"charset",
"bold",
"underline",
"blink",
"reverseVideo",
"foreground",
"background",
"_subtracting",
)
def __init__(
self,
charset=insults.G0,
bold=False,
underline=False,
blink=False,
reverseVideo=False,
foreground=WHITE,
background=BLACK,
_subtracting=False,
):
self.charset = charset
self.bold = bold
self.underline = underline
self.blink = blink
self.reverseVideo = reverseVideo
self.foreground = foreground
self.background = background
self._subtracting = _subtracting
@deprecated(Version("Twisted", 13, 1, 0))
def wantOne(self, **kw):
"""
Add a character attribute to a copy of this formatting state.
@param kw: An optional attribute name and value can be provided with
a keyword argument.
@return: A formatting state instance with the new attribute.
@see: L{DefaultFormattingState._withAttribute}.
"""
k, v = kw.popitem()
return self._withAttribute(k, v)
def toVT102(self):
# Spit out a vt102 control sequence that will set up
# all the attributes set here. Except charset.
attrs = []
if self._subtracting:
attrs.append(0)
if self.bold:
attrs.append(insults.BOLD)
if self.underline:
attrs.append(insults.UNDERLINE)
if self.blink:
attrs.append(insults.BLINK)
if self.reverseVideo:
attrs.append(insults.REVERSE_VIDEO)
if self.foreground != WHITE:
attrs.append(FOREGROUND + self.foreground)
if self.background != BLACK:
attrs.append(BACKGROUND + self.background)
if attrs:
return "\x1b[" + ";".join(map(str, attrs)) + "m"
return ""
CharacterAttribute = _FormattingState
deprecatedModuleAttribute(
Version("Twisted", 13, 1, 0),
"Use twisted.conch.insults.text.assembleFormattedText instead.",
"twisted.conch.insults.helper",
"CharacterAttribute",
)
# XXX - need to support scroll regions and scroll history
@implementer(insults.ITerminalTransport)
class TerminalBuffer(protocol.Protocol):
"""
An in-memory terminal emulator.
"""
for keyID in (
b"UP_ARROW",
b"DOWN_ARROW",
b"RIGHT_ARROW",
b"LEFT_ARROW",
b"HOME",
b"INSERT",
b"DELETE",
b"END",
b"PGUP",
b"PGDN",
b"F1",
b"F2",
b"F3",
b"F4",
b"F5",
b"F6",
b"F7",
b"F8",
b"F9",
b"F10",
b"F11",
b"F12",
):
execBytes = keyID + b" = object()"
execStr = execBytes.decode("ascii")
exec(execStr)
TAB = b"\t"
BACKSPACE = b"\x7f"
width = 80
height = 24
fill = b" "
void = object()
_log = Logger()
def getCharacter(self, x, y):
return self.lines[y][x]
def connectionMade(self):
self.reset()
def write(self, data):
"""
Add the given printable bytes to the terminal.
        Line feeds in C{data} will be replaced with carriage return / line
feed pairs.
"""
for b in iterbytes(data.replace(b"\n", b"\r\n")):
self.insertAtCursor(b)
def _currentFormattingState(self):
return _FormattingState(self.activeCharset, **self.graphicRendition)
def insertAtCursor(self, b):
"""
Add one byte to the terminal at the cursor and make consequent state
updates.
If b is a carriage return, move the cursor to the beginning of the
current row.
If b is a line feed, move the cursor to the next row or scroll down if
the cursor is already in the last row.
Otherwise, if b is printable, put it at the cursor position (inserting
or overwriting as dictated by the current mode) and move the cursor.
"""
if b == b"\r":
self.x = 0
elif b == b"\n":
self._scrollDown()
elif b in string.printable.encode("ascii"):
if self.x >= self.width:
self.nextLine()
ch = (b, self._currentFormattingState())
if self.modes.get(insults.modes.IRM):
self.lines[self.y][self.x : self.x] = [ch]
self.lines[self.y].pop()
else:
self.lines[self.y][self.x] = ch
self.x += 1
def _emptyLine(self, width):
return [(self.void, self._currentFormattingState()) for i in range(width)]
def _scrollDown(self):
self.y += 1
if self.y >= self.height:
self.y -= 1
del self.lines[0]
self.lines.append(self._emptyLine(self.width))
def _scrollUp(self):
self.y -= 1
if self.y < 0:
self.y = 0
del self.lines[-1]
self.lines.insert(0, self._emptyLine(self.width))
def cursorUp(self, n=1):
self.y = max(0, self.y - n)
def cursorDown(self, n=1):
self.y = min(self.height - 1, self.y + n)
def cursorBackward(self, n=1):
self.x = max(0, self.x - n)
def cursorForward(self, n=1):
self.x = min(self.width, self.x + n)
def cursorPosition(self, column, line):
self.x = column
self.y = line
def cursorHome(self):
self.x = self.home.x
self.y = self.home.y
def index(self):
self._scrollDown()
def reverseIndex(self):
self._scrollUp()
def nextLine(self):
"""
Update the cursor position attributes and scroll down if appropriate.
"""
self.x = 0
self._scrollDown()
def saveCursor(self):
self._savedCursor = (self.x, self.y)
def restoreCursor(self):
self.x, self.y = self._savedCursor
del self._savedCursor
def setModes(self, modes):
for m in modes:
self.modes[m] = True
def resetModes(self, modes):
for m in modes:
try:
del self.modes[m]
except KeyError:
pass
def setPrivateModes(self, modes):
"""
Enable the given modes.
        Track which modes have been enabled so that other
        L{insults.ITerminalTransport} methods can be implemented to respect
        these settings.
@see: L{resetPrivateModes}
@see: L{insults.ITerminalTransport.setPrivateModes}
"""
for m in modes:
self.privateModes[m] = True
def resetPrivateModes(self, modes):
"""
Disable the given modes.
@see: L{setPrivateModes}
@see: L{insults.ITerminalTransport.resetPrivateModes}
"""
for m in modes:
try:
del self.privateModes[m]
except KeyError:
pass
def applicationKeypadMode(self):
self.keypadMode = "app"
def numericKeypadMode(self):
self.keypadMode = "num"
def selectCharacterSet(self, charSet, which):
self.charsets[which] = charSet
def shiftIn(self):
self.activeCharset = insults.G0
def shiftOut(self):
self.activeCharset = insults.G1
def singleShift2(self):
oldActiveCharset = self.activeCharset
self.activeCharset = insults.G2
f = self.insertAtCursor
def insertAtCursor(b):
f(b)
del self.insertAtCursor
self.activeCharset = oldActiveCharset
self.insertAtCursor = insertAtCursor
def singleShift3(self):
oldActiveCharset = self.activeCharset
self.activeCharset = insults.G3
f = self.insertAtCursor
def insertAtCursor(b):
f(b)
del self.insertAtCursor
self.activeCharset = oldActiveCharset
self.insertAtCursor = insertAtCursor
def selectGraphicRendition(self, *attributes):
for a in attributes:
if a == insults.NORMAL:
self.graphicRendition = {
"bold": False,
"underline": False,
"blink": False,
"reverseVideo": False,
"foreground": WHITE,
"background": BLACK,
}
elif a == insults.BOLD:
self.graphicRendition["bold"] = True
elif a == insults.UNDERLINE:
self.graphicRendition["underline"] = True
elif a == insults.BLINK:
self.graphicRendition["blink"] = True
elif a == insults.REVERSE_VIDEO:
self.graphicRendition["reverseVideo"] = True
else:
try:
v = int(a)
except ValueError:
self._log.error(
"Unknown graphic rendition attribute: {attr!r}", attr=a
)
else:
if FOREGROUND <= v <= FOREGROUND + N_COLORS:
self.graphicRendition["foreground"] = v - FOREGROUND
elif BACKGROUND <= v <= BACKGROUND + N_COLORS:
self.graphicRendition["background"] = v - BACKGROUND
else:
self._log.error(
"Unknown graphic rendition attribute: {attr!r}", attr=a
)
def eraseLine(self):
self.lines[self.y] = self._emptyLine(self.width)
def eraseToLineEnd(self):
width = self.width - self.x
self.lines[self.y][self.x :] = self._emptyLine(width)
def eraseToLineBeginning(self):
self.lines[self.y][: self.x + 1] = self._emptyLine(self.x + 1)
def eraseDisplay(self):
self.lines = [self._emptyLine(self.width) for i in range(self.height)]
def eraseToDisplayEnd(self):
self.eraseToLineEnd()
height = self.height - self.y - 1
self.lines[self.y + 1 :] = [self._emptyLine(self.width) for i in range(height)]
def eraseToDisplayBeginning(self):
self.eraseToLineBeginning()
self.lines[: self.y] = [self._emptyLine(self.width) for i in range(self.y)]
def deleteCharacter(self, n=1):
del self.lines[self.y][self.x : self.x + n]
self.lines[self.y].extend(self._emptyLine(min(self.width - self.x, n)))
def insertLine(self, n=1):
self.lines[self.y : self.y] = [self._emptyLine(self.width) for i in range(n)]
del self.lines[self.height :]
def deleteLine(self, n=1):
del self.lines[self.y : self.y + n]
self.lines.extend([self._emptyLine(self.width) for i in range(n)])
def reportCursorPosition(self):
return (self.x, self.y)
def reset(self):
self.home = insults.Vector(0, 0)
self.x = self.y = 0
self.modes = {}
self.privateModes = {}
self.setPrivateModes(
[insults.privateModes.AUTO_WRAP, insults.privateModes.CURSOR_MODE]
)
self.numericKeypad = "app"
self.activeCharset = insults.G0
self.graphicRendition = {
"bold": False,
"underline": False,
"blink": False,
"reverseVideo": False,
"foreground": WHITE,
"background": BLACK,
}
self.charsets = {
insults.G0: insults.CS_US,
insults.G1: insults.CS_US,
insults.G2: insults.CS_ALTERNATE,
insults.G3: insults.CS_ALTERNATE_SPECIAL,
}
self.eraseDisplay()
def unhandledControlSequence(self, buf):
print("Could not handle", repr(buf))
def __bytes__(self):
lines = []
for L in self.lines:
buf = []
length = 0
for ch, attr in L:
if ch is not self.void:
buf.append(ch)
length = len(buf)
else:
buf.append(self.fill)
lines.append(b"".join(buf[:length]))
return b"\n".join(lines)
def getHost(self):
# ITransport.getHost
raise NotImplementedError("Unimplemented: TerminalBuffer.getHost")
def getPeer(self):
# ITransport.getPeer
raise NotImplementedError("Unimplemented: TerminalBuffer.getPeer")
def loseConnection(self):
# ITransport.loseConnection
raise NotImplementedError("Unimplemented: TerminalBuffer.loseConnection")
def writeSequence(self, data):
# ITransport.writeSequence
raise NotImplementedError("Unimplemented: TerminalBuffer.writeSequence")
def horizontalTabulationSet(self):
# ITerminalTransport.horizontalTabulationSet
raise NotImplementedError(
"Unimplemented: TerminalBuffer.horizontalTabulationSet"
)
def tabulationClear(self):
# TerminalTransport.tabulationClear
raise NotImplementedError("Unimplemented: TerminalBuffer.tabulationClear")
def tabulationClearAll(self):
# TerminalTransport.tabulationClearAll
raise NotImplementedError("Unimplemented: TerminalBuffer.tabulationClearAll")
def doubleHeightLine(self, top=True):
# ITerminalTransport.doubleHeightLine
raise NotImplementedError("Unimplemented: TerminalBuffer.doubleHeightLine")
def singleWidthLine(self):
# ITerminalTransport.singleWidthLine
raise NotImplementedError("Unimplemented: TerminalBuffer.singleWidthLine")
def doubleWidthLine(self):
# ITerminalTransport.doubleWidthLine
raise NotImplementedError("Unimplemented: TerminalBuffer.doubleWidthLine")
class ExpectationTimeout(Exception):
pass
class ExpectableBuffer(TerminalBuffer):
_mark = 0
def connectionMade(self):
TerminalBuffer.connectionMade(self)
self._expecting = []
def write(self, data):
TerminalBuffer.write(self, data)
self._checkExpected()
def cursorHome(self):
TerminalBuffer.cursorHome(self)
self._mark = 0
def _timeoutExpected(self, d):
d.errback(ExpectationTimeout())
self._checkExpected()
def _checkExpected(self):
s = self.__bytes__()[self._mark :]
while self._expecting:
expr, timer, deferred = self._expecting[0]
if timer and not timer.active():
del self._expecting[0]
continue
for match in expr.finditer(s):
if timer:
timer.cancel()
del self._expecting[0]
self._mark += match.end()
s = s[match.end() :]
deferred.callback(match)
break
else:
return
def expect(self, expression, timeout=None, scheduler=reactor):
d = defer.Deferred()
timer = None
if timeout:
timer = scheduler.callLater(timeout, self._timeoutExpected, d)
self._expecting.append((re.compile(expression), timer, d))
self._checkExpected()
return d
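# Illustrative sketch (not part of the original module): ExpectableBuffer
# adds a Deferred-based "wait until this pattern appears" helper on top of
# the emulator.  The prompt bytes below are hypothetical output.
def _exampleExpect():
    term = ExpectableBuffer()
    term.connectionMade()
    d = term.expect(b"login:")  # fires with the re match object
    term.write(b"Ubuntu 22.04\r\nlogin:")
    return d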
__all__ = ["CharacterAttribute", "TerminalBuffer", "ExpectableBuffer"]

File diff suppressed because it is too large

View File

@@ -0,0 +1,176 @@
# -*- test-case-name: twisted.conch.test.test_text -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Character attribute manipulation API.
This module provides a domain-specific language (using Python syntax)
for the creation of text with additional display attributes associated
with it. It is intended as an alternative to manually building up
strings containing ECMA 48 character attribute control codes. It
currently supports foreground and background colors (black, red,
green, yellow, blue, magenta, cyan, and white), intensity selection,
underlining, blinking and reverse video. Character set selection
support is planned.
Character attributes are specified by using two Python operations:
attribute lookup and indexing. For example, the string \"Hello
world\" with red foreground and all other attributes set to their
defaults, assuming the name twisted.conch.insults.text.attributes has
been imported and bound to the name \"A\" (with the statement C{from
twisted.conch.insults.text import attributes as A}, for example) one
uses this expression::
A.fg.red[\"Hello world\"]
Other foreground colors are set by substituting their name for
\"red\". To set both a foreground and a background color, this
expression is used::
A.fg.red[A.bg.green[\"Hello world\"]]
Note that either A.bg.green can be nested within A.fg.red or vice
versa. Also note that multiple items can be nested within a single
index operation by separating them with commas::
A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]
Other character attributes are set in a similar fashion. To specify a
blinking version of the previous expression::
A.blink[A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]]
C{A.reverseVideo}, C{A.underline}, and C{A.bold} are also valid.
A third operation is actually supported: unary negation. This turns
off an attribute when an enclosing expression would otherwise have
caused it to be on. For example::
A.underline[A.fg.red[\"Hello\", -A.underline[\" world\"]]]
A formatting structure can then be serialized into a string containing the
necessary VT102 control codes with L{assembleFormattedText}.
@see: L{twisted.conch.insults.text._CharacterAttributes}
@author: Jp Calderone
"""
from incremental import Version
from twisted.conch.insults import helper, insults
from twisted.python import _textattributes
from twisted.python.deprecate import deprecatedModuleAttribute
flatten = _textattributes.flatten
deprecatedModuleAttribute(
Version("Twisted", 13, 1, 0),
"Use twisted.conch.insults.text.assembleFormattedText instead.",
"twisted.conch.insults.text",
"flatten",
)
_TEXT_COLORS = {
"black": helper.BLACK,
"red": helper.RED,
"green": helper.GREEN,
"yellow": helper.YELLOW,
"blue": helper.BLUE,
"magenta": helper.MAGENTA,
"cyan": helper.CYAN,
"white": helper.WHITE,
}
class _CharacterAttributes(_textattributes.CharacterAttributesMixin):
"""
Factory for character attributes, including foreground and background color
and non-color attributes such as bold, reverse video and underline.
Character attributes are applied to actual text by using object
indexing-syntax (C{obj['abc']}) after accessing a factory attribute, for
example::
attributes.bold['Some text']
These can be nested to mix attributes::
attributes.bold[attributes.underline['Some text']]
And multiple values can be passed::
attributes.normal[attributes.bold['Some'], ' text']
Non-color attributes can be accessed by attribute name, available
attributes are:
- bold
- blink
- reverseVideo
- underline
Available colors are:
0. black
1. red
2. green
3. yellow
4. blue
5. magenta
6. cyan
7. white
@ivar fg: Foreground colors accessed by attribute name, see above
for possible names.
@ivar bg: Background colors accessed by attribute name, see above
for possible names.
"""
fg = _textattributes._ColorAttribute(
_textattributes._ForegroundColorAttr, _TEXT_COLORS
)
bg = _textattributes._ColorAttribute(
_textattributes._BackgroundColorAttr, _TEXT_COLORS
)
attrs = {
"bold": insults.BOLD,
"blink": insults.BLINK,
"underline": insults.UNDERLINE,
"reverseVideo": insults.REVERSE_VIDEO,
}
def assembleFormattedText(formatted):
"""
Assemble formatted text from structured information.
Currently handled formatting includes: bold, blink, reverse, underline and
color codes.
For example::
from twisted.conch.insults.text import attributes as A
assembleFormattedText(
A.normal[A.bold['Time: '], A.fg.lightRed['Now!']])
Would produce "Time: " in bold formatting, followed by "Now!" with a
foreground color of light red and without any additional formatting.
@param formatted: Structured text and attributes.
@rtype: L{str}
@return: String containing VT102 control sequences that mimic those
specified by C{formatted}.
@see: L{twisted.conch.insults.text._CharacterAttributes}
@since: 13.1
"""
return _textattributes.flatten(formatted, helper._FormattingState(), "toVT102")
attributes = _CharacterAttributes()
__all__ = ["attributes", "flatten"]

View File

@@ -0,0 +1,936 @@
# -*- test-case-name: twisted.conch.test.test_window -*-
"""
Simple insults-based widget library
@author: Jp Calderone
"""
from __future__ import annotations
import array
from twisted.conch.insults import helper, insults
from twisted.python import text as tptext
class YieldFocus(Exception):
"""
Input focus manipulation exception
"""
class BoundedTerminalWrapper:
def __init__(self, terminal, width, height, xoff, yoff):
self.width = width
self.height = height
self.xoff = xoff
self.yoff = yoff
self.terminal = terminal
self.cursorForward = terminal.cursorForward
self.selectCharacterSet = terminal.selectCharacterSet
self.selectGraphicRendition = terminal.selectGraphicRendition
self.saveCursor = terminal.saveCursor
self.restoreCursor = terminal.restoreCursor
def cursorPosition(self, x, y):
return self.terminal.cursorPosition(
self.xoff + min(self.width, x), self.yoff + min(self.height, y)
)
def cursorHome(self):
return self.terminal.cursorPosition(self.xoff, self.yoff)
def write(self, data):
return self.terminal.write(data)
class Widget:
focused = False
parent = None
dirty = False
width: int | None = None
height: int | None = None
def repaint(self):
if not self.dirty:
self.dirty = True
if self.parent is not None and not self.parent.dirty:
self.parent.repaint()
def filthy(self):
self.dirty = True
def redraw(self, width, height, terminal):
self.filthy()
self.draw(width, height, terminal)
def draw(self, width, height, terminal):
if width != self.width or height != self.height or self.dirty:
self.width = width
self.height = height
self.dirty = False
self.render(width, height, terminal)
def render(self, width, height, terminal):
pass
def sizeHint(self):
return None
def keystrokeReceived(self, keyID, modifier):
if keyID == b"\t":
self.tabReceived(modifier)
elif keyID == b"\x7f":
self.backspaceReceived()
elif keyID in insults.FUNCTION_KEYS:
self.functionKeyReceived(keyID, modifier)
else:
self.characterReceived(keyID, modifier)
def tabReceived(self, modifier):
# XXX TODO - Handle shift+tab
raise YieldFocus()
def focusReceived(self):
"""
Called when focus is being given to this widget.
        May raise YieldFocus if this widget does not want focus.
"""
self.focused = True
self.repaint()
def focusLost(self):
self.focused = False
self.repaint()
def backspaceReceived(self):
pass
def functionKeyReceived(self, keyID, modifier):
name = keyID
if not isinstance(keyID, str):
name = name.decode("utf-8")
# Peel off the square brackets added by the computed definition of
# twisted.conch.insults.insults.FUNCTION_KEYS.
methodName = "func_" + name[1:-1]
func = getattr(self, methodName, None)
if func is not None:
func(modifier)
def characterReceived(self, keyID, modifier):
pass
class ContainerWidget(Widget):
"""
@ivar focusedChild: The contained widget which currently has
focus, or None.
"""
focusedChild = None
focused = False
def __init__(self):
Widget.__init__(self)
self.children = []
def addChild(self, child):
assert child.parent is None
child.parent = self
self.children.append(child)
if self.focusedChild is None and self.focused:
try:
child.focusReceived()
except YieldFocus:
pass
else:
self.focusedChild = child
self.repaint()
def remChild(self, child):
assert child.parent is self
child.parent = None
self.children.remove(child)
self.repaint()
def filthy(self):
for ch in self.children:
ch.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
for ch in self.children:
ch.draw(width, height, terminal)
def changeFocus(self):
self.repaint()
if self.focusedChild is not None:
self.focusedChild.focusLost()
focusedChild = self.focusedChild
self.focusedChild = None
try:
curFocus = self.children.index(focusedChild) + 1
except ValueError:
raise YieldFocus()
else:
curFocus = 0
while curFocus < len(self.children):
try:
self.children[curFocus].focusReceived()
except YieldFocus:
curFocus += 1
else:
self.focusedChild = self.children[curFocus]
return
# None of our children wanted focus
raise YieldFocus()
def focusReceived(self):
self.changeFocus()
self.focused = True
def keystrokeReceived(self, keyID, modifier):
if self.focusedChild is not None:
try:
self.focusedChild.keystrokeReceived(keyID, modifier)
except YieldFocus:
self.changeFocus()
self.repaint()
else:
Widget.keystrokeReceived(self, keyID, modifier)
class TopWindow(ContainerWidget):
"""
A top-level container object which provides focus wrap-around and paint
scheduling.
@ivar painter: A no-argument callable which will be invoked when this
widget needs to be redrawn.
@ivar scheduler: A one-argument callable which will be invoked with a
    no-argument callable and should arrange for it to be invoked at some point in
the near future. The no-argument callable will cause this widget and all
its children to be redrawn. It is typically beneficial for the no-argument
callable to be invoked at the end of handling for whatever event is
currently active; for example, it might make sense to call it at the end of
L{twisted.conch.insults.insults.ITerminalProtocol.keystrokeReceived}.
Note, however, that since calls to this may also be made in response to no
apparent event, arrangements should be made for the function to be called
even if an event handler such as C{keystrokeReceived} is not on the call
stack (eg, using
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
with a short timeout).
"""
focused = True
def __init__(self, painter, scheduler):
ContainerWidget.__init__(self)
self.painter = painter
self.scheduler = scheduler
_paintCall = None
def repaint(self):
if self._paintCall is None:
self._paintCall = object()
self.scheduler(self._paint)
ContainerWidget.repaint(self)
def _paint(self):
self._paintCall = None
self.painter()
def changeFocus(self):
try:
ContainerWidget.changeFocus(self)
except YieldFocus:
try:
ContainerWidget.changeFocus(self)
except YieldFocus:
pass
def keystrokeReceived(self, keyID, modifier):
try:
ContainerWidget.keystrokeReceived(self, keyID, modifier)
except YieldFocus:
self.changeFocus()
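# Illustrative sketch (not part of the original module): how a terminal
# application might host a widget tree in a TopWindow.  It assumes the
# standard insults.TerminalProtocol contract (``self.terminal`` and
# ``keystrokeReceived``) plus a running reactor; the widget layout itself is
# hypothetical.
class _ExampleWidgetApp(insults.TerminalProtocol):
    width = 80
    height = 24
    def connectionMade(self):
        from twisted.internet import reactor
        self.window = TopWindow(self._draw, lambda f: reactor.callLater(0, f))
        vbox = VBox()
        vbox.addChild(TextOutput())
        vbox.addChild(TextInput(maxwidth=30, onSubmit=lambda line: None))
        self.window.addChild(Border(vbox))
        self.terminal.eraseDisplay()
        self._draw()
    def _draw(self):
        # Repaint the whole widget tree onto the real terminal transport.
        self.window.draw(self.width, self.height, self.terminal)
    def keystrokeReceived(self, keyID, modifier):
        # Route keystrokes to whichever widget currently has focus.
        self.window.keystrokeReceived(keyID, modifier)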
class AbsoluteBox(ContainerWidget):
def moveChild(self, child, x, y):
for n in range(len(self.children)):
if self.children[n][0] is child:
self.children[n] = (child, x, y)
break
else:
raise ValueError("No such child", child)
def render(self, width, height, terminal):
for ch, x, y in self.children:
wrap = BoundedTerminalWrapper(terminal, width - x, height - y, x, y)
ch.draw(width, height, wrap)
class _Box(ContainerWidget):
TOP, CENTER, BOTTOM = range(3)
def __init__(self, gravity=CENTER):
ContainerWidget.__init__(self)
self.gravity = gravity
def sizeHint(self):
height = 0
width = 0
for ch in self.children:
hint = ch.sizeHint()
if hint is None:
hint = (None, None)
if self.variableDimension == 0:
if hint[0] is None:
width = None
elif width is not None:
width += hint[0]
if hint[1] is None:
height = None
elif height is not None:
height = max(height, hint[1])
else:
if hint[0] is None:
width = None
elif width is not None:
width = max(width, hint[0])
if hint[1] is None:
height = None
elif height is not None:
height += hint[1]
return width, height
def render(self, width, height, terminal):
if not self.children:
return
greedy = 0
wants = []
for ch in self.children:
hint = ch.sizeHint()
if hint is None:
hint = (None, None)
if hint[self.variableDimension] is None:
greedy += 1
wants.append(hint[self.variableDimension])
length = (width, height)[self.variableDimension]
totalWant = sum(w for w in wants if w is not None)
if greedy:
leftForGreedy = int((length - totalWant) / greedy)
widthOffset = heightOffset = 0
for want, ch in zip(wants, self.children):
if want is None:
want = leftForGreedy
subWidth, subHeight = width, height
if self.variableDimension == 0:
subWidth = want
else:
subHeight = want
wrap = BoundedTerminalWrapper(
terminal,
subWidth,
subHeight,
widthOffset,
heightOffset,
)
ch.draw(subWidth, subHeight, wrap)
if self.variableDimension == 0:
widthOffset += want
else:
heightOffset += want
class HBox(_Box):
variableDimension = 0
class VBox(_Box):
variableDimension = 1
class Packer(ContainerWidget):
def render(self, width, height, terminal):
if not self.children:
return
root = int(len(self.children) ** 0.5 + 0.5)
boxes = [VBox() for n in range(root)]
for n, ch in enumerate(self.children):
boxes[n % len(boxes)].addChild(ch)
h = HBox()
        # map() is lazy on Python 3, so add the children explicitly.
        for box in boxes:
            h.addChild(box)
h.render(width, height, terminal)
class Canvas(Widget):
focused = False
    contents = None
    # Start the tracked cursor position at the origin so resize() can clamp
    # it safely.
    x = 0
    y = 0
def __init__(self):
Widget.__init__(self)
self.resize(1, 1)
def resize(self, width, height):
contents = array.array("B", b" " * width * height)
if self.contents is not None:
for x in range(min(width, self._width)):
for y in range(min(height, self._height)):
contents[width * y + x] = self[x, y]
self.contents = contents
self._width = width
self._height = height
if self.x >= width:
self.x = width - 1
if self.y >= height:
self.y = height - 1
def __getitem__(self, index):
(x, y) = index
return self.contents[(self._width * y) + x]
def __setitem__(self, index, value):
(x, y) = index
self.contents[(self._width * y) + x] = value
def clear(self):
self.contents = array.array("B", b" " * len(self.contents))
def render(self, width, height, terminal):
if not width or not height:
return
if width != self._width or height != self._height:
self.resize(width, height)
for i in range(height):
terminal.cursorPosition(0, i)
text = self.contents[
self._width * i : self._width * i + self._width
].tobytes()
text = text[:width]
terminal.write(text)
def horizontalLine(terminal, y, left, right):
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
terminal.cursorPosition(left, y)
terminal.write(b"\161" * (right - left))
terminal.selectCharacterSet(insults.CS_US, insults.G0)
def verticalLine(terminal, x, top, bottom):
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
for n in range(top, bottom):
terminal.cursorPosition(x, n)
terminal.write(b"\170")
terminal.selectCharacterSet(insults.CS_US, insults.G0)
def rectangle(terminal, position, dimension):
"""
Draw a rectangle
@type position: L{tuple}
@param position: A tuple of the (top, left) coordinates of the rectangle.
@type dimension: L{tuple}
@param dimension: A tuple of the (width, height) size of the rectangle.
"""
(top, left) = position
(width, height) = dimension
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
terminal.cursorPosition(top, left)
terminal.write(b"\154")
terminal.write(b"\161" * (width - 2))
terminal.write(b"\153")
for n in range(height - 2):
terminal.cursorPosition(left, top + n + 1)
terminal.write(b"\170")
terminal.cursorForward(width - 2)
terminal.write(b"\170")
terminal.cursorPosition(0, top + height - 1)
terminal.write(b"\155")
terminal.write(b"\161" * (width - 2))
terminal.write(b"\152")
terminal.selectCharacterSet(insults.CS_US, insults.G0)
class Border(Widget):
def __init__(self, containee):
Widget.__init__(self)
self.containee = containee
self.containee.parent = self
def focusReceived(self):
return self.containee.focusReceived()
def focusLost(self):
return self.containee.focusLost()
def keystrokeReceived(self, keyID, modifier):
return self.containee.keystrokeReceived(keyID, modifier)
def sizeHint(self):
hint = self.containee.sizeHint()
if hint is None:
hint = (None, None)
if hint[0] is None:
x = None
else:
x = hint[0] + 2
if hint[1] is None:
y = None
else:
y = hint[1] + 2
return x, y
def filthy(self):
self.containee.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
if self.containee.focused:
terminal.write(b"\x1b[31m")
rectangle(terminal, (0, 0), (width, height))
terminal.write(b"\x1b[0m")
wrap = BoundedTerminalWrapper(terminal, width - 2, height - 2, 1, 1)
self.containee.draw(width - 2, height - 2, wrap)
class Button(Widget):
def __init__(self, label, onPress):
Widget.__init__(self)
self.label = label
self.onPress = onPress
def sizeHint(self):
return len(self.label), 1
def characterReceived(self, keyID, modifier):
if keyID == b"\r":
self.onPress()
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
if self.focused:
terminal.write(b"\x1b[1m" + self.label + b"\x1b[0m")
else:
terminal.write(self.label)
class TextInput(Widget):
def __init__(self, maxwidth, onSubmit):
Widget.__init__(self)
self.onSubmit = onSubmit
self.maxwidth = maxwidth
self.buffer = b""
self.cursor = 0
def setText(self, text):
self.buffer = text[: self.maxwidth]
self.cursor = len(self.buffer)
self.repaint()
def func_LEFT_ARROW(self, modifier):
if self.cursor > 0:
self.cursor -= 1
self.repaint()
def func_RIGHT_ARROW(self, modifier):
if self.cursor < len(self.buffer):
self.cursor += 1
self.repaint()
def backspaceReceived(self):
if self.cursor > 0:
self.buffer = self.buffer[: self.cursor - 1] + self.buffer[self.cursor :]
self.cursor -= 1
self.repaint()
def characterReceived(self, keyID, modifier):
if keyID == b"\r":
self.onSubmit(self.buffer)
else:
if len(self.buffer) < self.maxwidth:
self.buffer = (
self.buffer[: self.cursor] + keyID + self.buffer[self.cursor :]
)
self.cursor += 1
self.repaint()
def sizeHint(self):
return self.maxwidth + 1, 1
def render(self, width, height, terminal):
currentText = self._renderText()
terminal.cursorPosition(0, 0)
if self.focused:
terminal.write(currentText[: self.cursor])
cursor(terminal, currentText[self.cursor : self.cursor + 1] or b" ")
terminal.write(currentText[self.cursor + 1 :])
terminal.write(b" " * (self.maxwidth - len(currentText) + 1))
else:
more = self.maxwidth - len(currentText)
terminal.write(currentText + b"_" * more)
def _renderText(self):
return self.buffer
class PasswordInput(TextInput):
def _renderText(self):
return "*" * len(self.buffer)
class TextOutput(Widget):
text = b""
def __init__(self, size=None):
Widget.__init__(self)
self.size = size
def sizeHint(self):
return self.size
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
text = self.text[:width]
terminal.write(text + b" " * (width - len(text)))
def setText(self, text):
self.text = text
self.repaint()
def focusReceived(self):
raise YieldFocus()
class TextOutputArea(TextOutput):
WRAP, TRUNCATE = range(2)
def __init__(self, size=None, longLines=WRAP):
TextOutput.__init__(self, size)
self.longLines = longLines
def render(self, width, height, terminal):
n = 0
inputLines = self.text.splitlines()
outputLines = []
while inputLines:
if self.longLines == self.WRAP:
line = inputLines.pop(0)
if not isinstance(line, str):
line = line.decode("utf-8")
wrappedLines = []
for wrappedLine in tptext.greedyWrap(line, width):
if not isinstance(wrappedLine, bytes):
wrappedLine = wrappedLine.encode("utf-8")
wrappedLines.append(wrappedLine)
outputLines.extend(wrappedLines or [b""])
else:
outputLines.append(inputLines.pop(0)[:width])
if len(outputLines) >= height:
break
for n, L in enumerate(outputLines[:height]):
terminal.cursorPosition(0, n)
terminal.write(L)
class Viewport(Widget):
_xOffset = 0
_yOffset = 0
@property
def xOffset(self):
return self._xOffset
@xOffset.setter
def xOffset(self, value):
if self._xOffset != value:
self._xOffset = value
self.repaint()
@property
def yOffset(self):
return self._yOffset
@yOffset.setter
def yOffset(self, value):
if self._yOffset != value:
self._yOffset = value
self.repaint()
_width = 160
_height = 24
def __init__(self, containee):
Widget.__init__(self)
self.containee = containee
self.containee.parent = self
self._buf = helper.TerminalBuffer()
self._buf.width = self._width
self._buf.height = self._height
self._buf.connectionMade()
def filthy(self):
self.containee.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
self.containee.draw(self._width, self._height, self._buf)
# XXX /Lame/
for y, line in enumerate(
self._buf.lines[self._yOffset : self._yOffset + height]
):
terminal.cursorPosition(0, y)
n = 0
for n, (ch, attr) in enumerate(line[self._xOffset : self._xOffset + width]):
if ch is self._buf.void:
ch = b" "
terminal.write(ch)
if n < width:
terminal.write(b" " * (width - n - 1))
class _Scrollbar(Widget):
def __init__(self, onScroll):
Widget.__init__(self)
self.onScroll = onScroll
self.percent = 0.0
def smaller(self):
self.percent = min(1.0, max(0.0, self.onScroll(-1)))
self.repaint()
def bigger(self):
self.percent = min(1.0, max(0.0, self.onScroll(+1)))
self.repaint()
class HorizontalScrollbar(_Scrollbar):
def sizeHint(self):
return (None, 1)
def func_LEFT_ARROW(self, modifier):
self.smaller()
def func_RIGHT_ARROW(self, modifier):
self.bigger()
_left = "\N{BLACK LEFT-POINTING TRIANGLE}"
_right = "\N{BLACK RIGHT-POINTING TRIANGLE}"
_bar = "\N{LIGHT SHADE}"
_slider = "\N{DARK SHADE}"
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
n = width - 3
before = int(n * self.percent)
after = n - before
me = (
self._left
+ (self._bar * before)
+ self._slider
+ (self._bar * after)
+ self._right
)
terminal.write(me.encode("utf-8"))
class VerticalScrollbar(_Scrollbar):
def sizeHint(self):
return (1, None)
def func_UP_ARROW(self, modifier):
self.smaller()
def func_DOWN_ARROW(self, modifier):
self.bigger()
_up = "\N{BLACK UP-POINTING TRIANGLE}"
_down = "\N{BLACK DOWN-POINTING TRIANGLE}"
_bar = "\N{LIGHT SHADE}"
_slider = "\N{DARK SHADE}"
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
knob = int(self.percent * (height - 2))
terminal.write(self._up.encode("utf-8"))
for i in range(1, height - 1):
terminal.cursorPosition(0, i)
if i != (knob + 1):
terminal.write(self._bar.encode("utf-8"))
else:
terminal.write(self._slider.encode("utf-8"))
terminal.cursorPosition(0, height - 1)
terminal.write(self._down.encode("utf-8"))
class ScrolledArea(Widget):
"""
A L{ScrolledArea} contains another widget wrapped in a viewport and
vertical and horizontal scrollbars for moving the viewport around.
"""
def __init__(self, containee):
Widget.__init__(self)
self._viewport = Viewport(containee)
self._horiz = HorizontalScrollbar(self._horizScroll)
self._vert = VerticalScrollbar(self._vertScroll)
for w in self._viewport, self._horiz, self._vert:
w.parent = self
def _horizScroll(self, n):
self._viewport.xOffset += n
self._viewport.xOffset = max(0, self._viewport.xOffset)
return self._viewport.xOffset / 25.0
def _vertScroll(self, n):
self._viewport.yOffset += n
self._viewport.yOffset = max(0, self._viewport.yOffset)
return self._viewport.yOffset / 25.0
def func_UP_ARROW(self, modifier):
self._vert.smaller()
def func_DOWN_ARROW(self, modifier):
self._vert.bigger()
def func_LEFT_ARROW(self, modifier):
self._horiz.smaller()
def func_RIGHT_ARROW(self, modifier):
self._horiz.bigger()
def filthy(self):
self._viewport.filthy()
self._horiz.filthy()
self._vert.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
wrapper = BoundedTerminalWrapper(terminal, width - 2, height - 2, 1, 1)
self._viewport.draw(width - 2, height - 2, wrapper)
if self.focused:
terminal.write(b"\x1b[31m")
horizontalLine(terminal, 0, 1, width - 1)
verticalLine(terminal, 0, 1, height - 1)
self._vert.draw(
1, height - 1, BoundedTerminalWrapper(terminal, 1, height - 1, width - 1, 0)
)
self._horiz.draw(
width, 1, BoundedTerminalWrapper(terminal, width, 1, 0, height - 1)
)
terminal.write(b"\x1b[0m")
def cursor(terminal, ch):
terminal.saveCursor()
terminal.selectGraphicRendition(str(insults.REVERSE_VIDEO))
terminal.write(ch)
terminal.restoreCursor()
terminal.cursorForward()
class Selection(Widget):
# Index into the sequence
focusedIndex = 0
# Offset into the displayed subset of the sequence
renderOffset = 0
def __init__(self, sequence, onSelect, minVisible=None):
Widget.__init__(self)
self.sequence = sequence
self.onSelect = onSelect
self.minVisible = minVisible
if minVisible is not None:
self._width = max(map(len, self.sequence))
def sizeHint(self):
if self.minVisible is not None:
return self._width, self.minVisible
def func_UP_ARROW(self, modifier):
if self.focusedIndex > 0:
self.focusedIndex -= 1
if self.renderOffset > 0:
self.renderOffset -= 1
self.repaint()
def func_PGUP(self, modifier):
if self.renderOffset != 0:
self.focusedIndex -= self.renderOffset
self.renderOffset = 0
else:
self.focusedIndex = max(0, self.focusedIndex - self.height)
self.repaint()
def func_DOWN_ARROW(self, modifier):
if self.focusedIndex < len(self.sequence) - 1:
self.focusedIndex += 1
if self.renderOffset < self.height - 1:
self.renderOffset += 1
self.repaint()
def func_PGDN(self, modifier):
if self.renderOffset != self.height - 1:
change = self.height - self.renderOffset - 1
if change + self.focusedIndex >= len(self.sequence):
change = len(self.sequence) - self.focusedIndex - 1
self.focusedIndex += change
self.renderOffset = self.height - 1
else:
self.focusedIndex = min(
len(self.sequence) - 1, self.focusedIndex + self.height
)
self.repaint()
def characterReceived(self, keyID, modifier):
if keyID == b"\r":
self.onSelect(self.sequence[self.focusedIndex])
def render(self, width, height, terminal):
self.height = height
start = self.focusedIndex - self.renderOffset
if start > len(self.sequence) - height:
start = max(0, len(self.sequence) - height)
elements = self.sequence[start : start + height]
for n, ele in enumerate(elements):
terminal.cursorPosition(0, n)
if n == self.renderOffset:
terminal.saveCursor()
if self.focused:
modes = str(insults.REVERSE_VIDEO), str(insults.BOLD)
else:
modes = (str(insults.REVERSE_VIDEO),)
terminal.selectGraphicRendition(*modes)
text = ele[:width]
terminal.write(text + (b" " * (width - len(text))))
if n == self.renderOffset:
terminal.restoreCursor()

View File

@@ -0,0 +1,456 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains interfaces defined for the L{twisted.conch} package.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from zope.interface import Attribute, Interface
if TYPE_CHECKING:
from twisted.conch.ssh.keys import Key
class IConchUser(Interface):
"""
A user who has been authenticated to Cred through Conch. This is
the interface between the SSH connection and the user.
"""
conn = Attribute("The SSHConnection object for this user.")
def lookupChannel(channelType, windowSize, maxPacket, data):
"""
The other side requested a channel of some sort.
C{channelType} is the type of channel being requested,
as an ssh connection protocol channel type.
C{data} is any other packet data (often nothing).
We return a subclass of L{SSHChannel<ssh.channel.SSHChannel>}. If
the channel type is unknown, we return C{None}.
For other failures, we raise an exception. If a
L{ConchError<error.ConchError>} is raised, the C{.value} will
be the message, and the C{.data} will be the error code.
@param channelType: The requested channel type
@type channelType: L{bytes}
@param windowSize: The initial size of the remote window
@type windowSize: L{int}
@param maxPacket: The largest packet we should send
@type maxPacket: L{int}
@param data: Additional request data
@type data: L{bytes}
@rtype: a subclass of L{SSHChannel} or L{None}
"""
def lookupSubsystem(subsystem, data):
"""
The other side requested a subsystem.
We return a L{Protocol} implementing the requested subsystem.
If the subsystem is not available, we return C{None}.
@param subsystem: The name of the subsystem being requested
@type subsystem: L{bytes}
@param data: Additional request data (often nothing)
@type data: L{bytes}
@rtype: L{Protocol} or L{None}
"""
def gotGlobalRequest(requestType, data):
"""
A global request was sent from the other side.
We return a true value on success or a false value on failure.
If we indicate success by returning a tuple, its second item
will be sent to the other side as additional response data.
@param requestType: The type of the request
@type requestType: L{bytes}
@param data: Additional request data
@type data: L{bytes}
@rtype: boolean or L{tuple}
"""
class ISession(Interface):
def getPty(term, windowSize, modes):
"""
Get a pseudo-terminal for use by a shell or command.
If a pseudo-terminal is not available, or the request otherwise
fails, raise an exception.
"""
def openShell(proto):
"""
Open a shell and connect it to proto.
@param proto: a L{ProcessProtocol} instance.
"""
def execCommand(proto, command):
"""
Execute a command.
@param proto: a L{ProcessProtocol} instance.
"""
def windowChanged(newWindowSize):
"""
Called when the size of the remote screen has changed.
"""
def eofReceived():
"""
Called when the other side has indicated no more data will be sent.
"""
def closed():
"""
Called when the session is closed.
"""
class EnvironmentVariableNotPermitted(ValueError):
"""Setting this environment variable in this session is not permitted."""
class ISessionSetEnv(Interface):
"""A session that can set environment variables."""
def setEnv(name, value):
"""
Set an environment variable for the shell or command to be started.
From U{RFC 4254, section 6.4
<https://tools.ietf.org/html/rfc4254#section-6.4>}: "Uncontrolled
setting of environment variables in a privileged process can be a
security hazard. It is recommended that implementations either
maintain a list of allowable variable names or only set environment
variables after the server process has dropped sufficient
privileges."
(OpenSSH refuses all environment variables by default, but has an
C{AcceptEnv} configuration option to select specific variables to
accept.)
@param name: The name of the environment variable to set.
@type name: L{bytes}
@param value: The value of the environment variable to set.
@type value: L{bytes}
@raise EnvironmentVariableNotPermitted: if setting this environment
variable is not permitted.
"""
class ISFTPServer(Interface):
"""
SFTP subsystem for server-side communication.
Each method should check to verify that the user has permission for
their actions.
"""
avatar = Attribute(
"""
The avatar returned by the Realm that we are authenticated with,
and represents the logged-in user.
"""
)
def gotVersion(otherVersion, extData):
"""
Called when the client sends their version info.
otherVersion is an integer representing the version of the SFTP
protocol they are claiming.
extData is a dictionary of extended_name : extended_data items.
These items are sent by the client to indicate additional features.
This method should return a dictionary of extended_name : extended_data
items. These items are the additional features (if any) supported
by the server.
"""
return {}
def openFile(filename, flags, attrs):
"""
        Called when the client asks to open a file.
@param filename: a string representing the file to open.
@param flags: an integer of the flags to open the file with, ORed
together. The flags and their values are listed at the bottom of
L{twisted.conch.ssh.filetransfer} as FXF_*.
@param attrs: a list of attributes to open the file with. It is a
dictionary, consisting of 0 or more keys. The possible keys are::
size: the size of the file in bytes
uid: the user ID of the file as an integer
gid: the group ID of the file as an integer
            permissions: the permissions of the file as an integer. The
            bit representation of this field is defined by POSIX.
atime: the access time of the file as seconds since the epoch.
mtime: the modification time of the file as seconds since the epoch.
ext_*: extended attributes. The server is not required to
understand this, but it may.
        NOTE: there is no way to indicate text or binary files. It is up
to the SFTP client to deal with this.
This method returns an object that meets the ISFTPFile interface.
Alternatively, it can return a L{Deferred} that will be called back
with the object.
"""
def removeFile(filename):
"""
Remove the given file.
This method returns when the remove succeeds, or a Deferred that is
called back when it succeeds.
@param filename: the name of the file as a string.
"""
def renameFile(oldpath, newpath):
"""
Rename the given file.
This method returns when the rename succeeds, or a L{Deferred} that is
called back when it succeeds. If the rename fails, C{renameFile} will
raise an implementation-dependent exception.
@param oldpath: the current location of the file.
@param newpath: the new file name.
"""
def makeDirectory(path, attrs):
"""
Make a directory.
This method returns when the directory is created, or a Deferred that
is called back when it is created.
@param path: the name of the directory to create as a string.
@param attrs: a dictionary of attributes to create the directory with.
Its meaning is the same as the attrs in the L{openFile} method.
"""
def removeDirectory(path):
"""
        Remove a directory (non-recursively).
It is an error to remove a directory that has files or directories in
it.
This method returns when the directory is removed, or a Deferred that
is called back when it is removed.
@param path: the directory to remove.
"""
def openDirectory(path):
"""
Open a directory for scanning.
This method returns an iterable object that has a close() method,
or a Deferred that is called back with same.
The close() method is called when the client is finished reading
from the directory. At this point, the iterable will no longer
be used.
The iterable should return triples of the form (filename,
longname, attrs) or Deferreds that return the same. The
sequence must support __getitem__, but otherwise may be any
'sequence-like' object.
filename is the name of the file relative to the directory.
        longname is an expanded format of the filename. The recommended format
is:
-rwxr-xr-x 1 mjos staff 348911 Mar 25 14:29 t-filexfer
1234567890 123 12345678 12345678 12345678 123456789012
The first line is sample output, the second is the length of the field.
The fields are: permissions, link count, user owner, group owner,
size in bytes, modification time.
attrs is a dictionary in the format of the attrs argument to openFile.
@param path: the directory to open.
"""
def getAttrs(path, followLinks):
"""
Return the attributes for the given path.
This method returns a dictionary in the same format as the attrs
argument to openFile or a Deferred that is called back with same.
@param path: the path to return attributes for as a string.
@param followLinks: a boolean. If it is True, follow symbolic links
and return attributes for the real path at the base. If it is False,
return attributes for the specified path.
"""
def setAttrs(path, attrs):
"""
Set the attributes for the path.
This method returns when the attributes are set or a Deferred that is
called back when they are.
@param path: the path to set attributes for as a string.
@param attrs: a dictionary in the same format as the attrs argument to
L{openFile}.
"""
def readLink(path):
"""
Find the root of a set of symbolic links.
This method returns the target of the link, or a Deferred that
returns the same.
@param path: the path of the symlink to read.
"""
def makeLink(linkPath, targetPath):
"""
Create a symbolic link.
This method returns when the link is made, or a Deferred that
returns the same.
@param linkPath: the pathname of the symlink as a string.
@param targetPath: the path of the target of the link as a string.
"""
def realPath(path):
"""
Convert any path to an absolute path.
This method returns the absolute path as a string, or a Deferred
that returns the same.
@param path: the path to convert as a string.
"""
def extendedRequest(extendedName, extendedData):
"""
This is the extension mechanism for SFTP. The other side can send us
arbitrary requests.
If we don't implement the request given by extendedName, raise
NotImplementedError.
The return value is a string, or a Deferred that will be called
back with a string.
@param extendedName: the name of the request as a string.
@param extendedData: the data the other side sent with the request,
as a string.
"""
class IKnownHostEntry(Interface):
"""
A L{IKnownHostEntry} is an entry in an OpenSSH-formatted C{known_hosts}
file.
@since: 8.2
"""
def matchesKey(key: Key) -> bool:
"""
Return True if this entry matches the given Key object, False
otherwise.
@param key: The key object to match against.
"""
def matchesHost(hostname: bytes) -> bool:
"""
Return True if this entry matches the given hostname, False otherwise.
Note that this does no name resolution; if you want to match an IP
address, you have to resolve it yourself, and pass it in as a dotted
quad string.
@param hostname: The hostname to match against.
"""
def toString() -> bytes:
"""
@return: a serialized string representation of this entry, suitable for
inclusion in a known_hosts file. (Newline not included.)
"""
class ISFTPFile(Interface):
"""
This represents an open file on the server. An object adhering to this
interface should be returned from L{openFile}().
"""
def close():
"""
Close the file.
This method returns nothing if the close succeeds immediately, or a
Deferred that is called back when the close succeeds.
"""
def readChunk(offset, length):
"""
Read from the file.
If EOF is reached before any data is read, raise EOFError.
This method returns the data as a string, or a Deferred that is
called back with same.
@param offset: an integer that is the index to start from in the file.
@param length: the maximum length of data to return. The actual amount
        returned may be less than this. For normal disk files, however,
this should read the requested number (up to the end of the file).
"""
def writeChunk(offset, data):
"""
Write to the file.
This method returns when the write completes, or a Deferred that is
called when it completes.
@param offset: an integer that is the index to start from in the file.
@param data: a string that is the data to write.
"""
def getAttrs():
"""
Return the attributes for the file.
This method returns a dictionary in the same format as the attrs
argument to L{openFile} or a L{Deferred} that is called back with same.
"""
def setAttrs(attrs):
"""
Set the attributes for the file.
This method returns when the attributes are set or a Deferred that is
called back when they are.
@param attrs: a dictionary in the same format as the attrs argument to
L{openFile}.
"""

View File

@@ -0,0 +1,104 @@
# -*- test-case-name: twisted.conch.test.test_cftp -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import array
import stat
from time import localtime, strftime, time
# Locale-independent month names to use instead of strftime's
_MONTH_NAMES = dict(
list(
zip(
list(range(1, 13)),
"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(),
)
)
)
def lsLine(name, s):
"""
Build an 'ls' line for a file ('file' in its generic sense, it
can be of any type).
"""
mode = s.st_mode
perms = array.array("B", b"-" * 10)
ft = stat.S_IFMT(mode)
if stat.S_ISDIR(ft):
perms[0] = ord("d")
elif stat.S_ISCHR(ft):
perms[0] = ord("c")
elif stat.S_ISBLK(ft):
perms[0] = ord("b")
elif stat.S_ISREG(ft):
perms[0] = ord("-")
elif stat.S_ISFIFO(ft):
perms[0] = ord("f")
elif stat.S_ISLNK(ft):
perms[0] = ord("l")
elif stat.S_ISSOCK(ft):
perms[0] = ord("s")
else:
perms[0] = ord("!")
# User
if mode & stat.S_IRUSR:
perms[1] = ord("r")
if mode & stat.S_IWUSR:
perms[2] = ord("w")
if mode & stat.S_IXUSR:
perms[3] = ord("x")
# Group
if mode & stat.S_IRGRP:
perms[4] = ord("r")
if mode & stat.S_IWGRP:
perms[5] = ord("w")
if mode & stat.S_IXGRP:
perms[6] = ord("x")
# Other
if mode & stat.S_IROTH:
perms[7] = ord("r")
if mode & stat.S_IWOTH:
perms[8] = ord("w")
if mode & stat.S_IXOTH:
perms[9] = ord("x")
# Suid/sgid
if mode & stat.S_ISUID:
if perms[3] == ord("x"):
perms[3] = ord("s")
else:
perms[3] = ord("S")
if mode & stat.S_ISGID:
if perms[6] == ord("x"):
perms[6] = ord("s")
else:
perms[6] = ord("S")
if isinstance(name, bytes):
name = name.decode("utf-8")
lsPerms = perms.tobytes()
lsPerms = lsPerms.decode("utf-8")
lsresult = [
lsPerms,
str(s.st_nlink).rjust(5),
" ",
str(s.st_uid).ljust(9),
str(s.st_gid).ljust(9),
str(s.st_size).rjust(8),
" ",
]
# Need to specify the month manually, as strftime depends on locale
ttup = localtime(s.st_mtime)
sixmonths = 60 * 60 * 24 * 7 * 26
if s.st_mtime + sixmonths < time(): # Last edited more than 6mo ago
strtime = strftime("%%s %d %Y ", ttup)
else:
strtime = strftime("%%s %d %H:%M ", ttup)
lsresult.append(strtime % (_MONTH_NAMES[ttup[1]],))
lsresult.append(name)
return "".join(lsresult)
__all__ = ["lsLine"]
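# Usage sketch (the path and the resulting line are illustrative; any
# os.stat_result will do):
#
#   import os
#   print(lsLine(b"hosts", os.lstat("/etc/hosts")))
#   # -rw-r--r--     1 0        0              221 Jan 01 12:00 hosts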

View File

@@ -0,0 +1,392 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Line-input oriented interactive interpreter loop.
Provides classes for handling Python source input and arbitrary output
interactively from a Twisted application. Also included is syntax coloring
code with support for VT102 terminals, control code handling (^C, ^D, ^Q),
and reasonable handling of Deferreds.
@author: Jp Calderone
"""
import code
import sys
import tokenize
from io import BytesIO
from traceback import format_exception
from types import TracebackType
from typing import Type
from twisted.conch import recvline
from twisted.internet import defer
from twisted.python.htmlizer import TokenPrinter
from twisted.python.monkey import MonkeyPatcher
class FileWrapper:
"""
Minimal write-file-like object.
Writes are translated into addOutput calls on an object passed to
__init__. Newlines are also converted from network to local style.
"""
softspace = 0
state = "normal"
def __init__(self, o):
self.o = o
def flush(self):
pass
def write(self, data):
self.o.addOutput(data.replace("\r\n", "\n"))
def writelines(self, lines):
self.write("".join(lines))
class ManholeInterpreter(code.InteractiveInterpreter):
"""
Interactive Interpreter with special output and Deferred support.
Aside from the features provided by L{code.InteractiveInterpreter}, this
class captures sys.stdout output and redirects it to the appropriate
location (the Manhole protocol instance). It also treats Deferreds
which reach the top-level specially: each is formatted to the user with
a unique identifier and a new callback and errback added to it, each of
which will format the unique identifier and the result with which the
Deferred fires and then pass it on to the next participant in the
callback chain.
"""
numDeferreds = 0
def __init__(self, handler, locals=None, filename="<console>"):
code.InteractiveInterpreter.__init__(self, locals)
self._pendingDeferreds = {}
self.handler = handler
self.filename = filename
self.resetBuffer()
self.monkeyPatcher = MonkeyPatcher()
self.monkeyPatcher.addPatch(sys, "displayhook", self.displayhook)
self.monkeyPatcher.addPatch(sys, "excepthook", self.excepthook)
self.monkeyPatcher.addPatch(sys, "stdout", FileWrapper(self.handler))
def resetBuffer(self):
"""
Reset the input buffer.
"""
self.buffer = []
def push(self, line):
"""
Push a line to the interpreter.
The line should not have a trailing newline; it may have
internal newlines. The line is appended to a buffer and the
interpreter's runsource() method is called with the
concatenated contents of the buffer as source. If this
indicates that the command was executed or invalid, the buffer
is reset; otherwise, the command is incomplete, and the buffer
is left as it was after the line was appended. The return
value is 1 if more input is required, 0 if the line was dealt
with in some way (this is the same as runsource()).
@param line: line of text
@type line: L{bytes}
@return: L{bool} from L{code.InteractiveInterpreter.runsource}
"""
self.buffer.append(line)
source = b"\n".join(self.buffer)
source = source.decode("utf-8")
more = self.runsource(source, self.filename)
if not more:
self.resetBuffer()
return more
def runcode(self, *a, **kw):
with self.monkeyPatcher:
code.InteractiveInterpreter.runcode(self, *a, **kw)
def excepthook(
self,
excType: Type[BaseException],
excValue: BaseException,
excTraceback: TracebackType,
) -> None:
"""
Format exception tracebacks and write them to the output handler.
"""
code_obj = excTraceback.tb_frame.f_code
if code_obj.co_filename == code.__file__ and code_obj.co_name == "runcode":
traceback = excTraceback.tb_next
else:
# Workaround for https://github.com/python/cpython/issues/122478,
# present e.g. in Python 3.12.6:
traceback = excTraceback
lines = format_exception(excType, excValue, traceback)
self.write("".join(lines))
def displayhook(self, obj):
self.locals["_"] = obj
if isinstance(obj, defer.Deferred):
# XXX Ick, where is my "hasFired()" interface?
if hasattr(obj, "result"):
self.write(repr(obj))
elif id(obj) in self._pendingDeferreds:
self.write("<Deferred #%d>" % (self._pendingDeferreds[id(obj)][0],))
else:
d = self._pendingDeferreds
k = self.numDeferreds
d[id(obj)] = (k, obj)
self.numDeferreds += 1
obj.addCallbacks(
self._cbDisplayDeferred,
self._ebDisplayDeferred,
callbackArgs=(k, obj),
errbackArgs=(k, obj),
)
self.write("<Deferred #%d>" % (k,))
elif obj is not None:
self.write(repr(obj))
def _cbDisplayDeferred(self, result, k, obj):
self.write("Deferred #%d called back: %r" % (k, result), True)
del self._pendingDeferreds[id(obj)]
return result
def _ebDisplayDeferred(self, failure, k, obj):
self.write("Deferred #%d failed: %r" % (k, failure.getErrorMessage()), True)
del self._pendingDeferreds[id(obj)]
return failure
def write(self, data, isAsync=None):
self.handler.addOutput(data, isAsync)
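# Illustrative sketch, not part of the module: a handler only needs an
# addOutput method. Writing to sys.__stdout__ (rather than sys.stdout) avoids
# recursing through the FileWrapper patch installed around runcode().
class _PrintHandler:
    def addOutput(self, data, isAsync=None):
        sys.__stdout__.write(data)
def _interpreterSketch():
    interp = ManholeInterpreter(_PrintHandler(), locals={"answer": 42})
    interp.push(b"answer * 2")        # displayhook prints 84
    interp.push(b"def f(x):")         # returns True: more input required
    interp.push(b"    return x + 1")
    interp.push(b"")                  # blank line completes the definition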
CTRL_C = b"\x03"
CTRL_D = b"\x04"
CTRL_BACKSLASH = b"\x1c"
CTRL_L = b"\x0c"
CTRL_A = b"\x01"
CTRL_E = b"\x05"
class Manhole(recvline.HistoricRecvLine):
r"""
Mediator between a fancy line source and an interactive interpreter.
This accepts lines from its transport and passes them on to a
L{ManholeInterpreter}. Control commands (^C, ^D, ^\) are also handled
with something approximating their normal terminal-mode behavior. It
can optionally be constructed with a dict which will be used as the
local namespace for any code executed.
"""
namespace = None
def __init__(self, namespace=None):
recvline.HistoricRecvLine.__init__(self)
if namespace is not None:
self.namespace = namespace.copy()
def connectionMade(self):
recvline.HistoricRecvLine.connectionMade(self)
self.interpreter = ManholeInterpreter(self, self.namespace)
self.keyHandlers[CTRL_C] = self.handle_INT
self.keyHandlers[CTRL_D] = self.handle_EOF
self.keyHandlers[CTRL_L] = self.handle_FF
self.keyHandlers[CTRL_A] = self.handle_HOME
self.keyHandlers[CTRL_E] = self.handle_END
self.keyHandlers[CTRL_BACKSLASH] = self.handle_QUIT
def handle_INT(self):
"""
Handle ^C as an interrupt keystroke by resetting the current input
variables to their initial state.
"""
self.pn = 0
self.lineBuffer = []
self.lineBufferIndex = 0
self.interpreter.resetBuffer()
self.terminal.nextLine()
self.terminal.write(b"KeyboardInterrupt")
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
def handle_EOF(self):
if self.lineBuffer:
self.terminal.write(b"\a")
else:
self.handle_QUIT()
def handle_FF(self):
"""
Handle a 'form feed' byte - generally used to request a screen
refresh/redraw.
"""
self.terminal.eraseDisplay()
self.terminal.cursorHome()
self.drawInputLine()
def handle_QUIT(self):
self.terminal.loseConnection()
def _needsNewline(self):
w = self.terminal.lastWrite
return not w.endswith(b"\n") and not w.endswith(b"\x1bE")
def addOutput(self, data, isAsync=None):
if isAsync:
self.terminal.eraseLine()
self.terminal.cursorBackward(len(self.lineBuffer) + len(self.ps[self.pn]))
self.terminal.write(data)
if isAsync:
if self._needsNewline():
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
if self.lineBuffer:
oldBuffer = self.lineBuffer
self.lineBuffer = []
self.lineBufferIndex = 0
self._deliverBuffer(oldBuffer)
def lineReceived(self, line):
more = self.interpreter.push(line)
self.pn = bool(more)
if self._needsNewline():
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
class VT102Writer:
"""
Colorizer for Python tokens.
A series of tokens are written to instances of this object. Each is
colored in a particular way. The final line of the result of this is
generally added to the output.
"""
typeToColor = {
"identifier": b"\x1b[31m",
"keyword": b"\x1b[32m",
"parameter": b"\x1b[33m",
"variable": b"\x1b[1;33m",
"string": b"\x1b[35m",
"number": b"\x1b[36m",
"op": b"\x1b[37m",
}
normalColor = b"\x1b[0m"
def __init__(self):
self.written = []
def color(self, type):
r = self.typeToColor.get(type, b"")
return r
def write(self, token, type=None):
if token and token != b"\r":
c = self.color(type)
if c:
self.written.append(c)
self.written.append(token)
if c:
self.written.append(self.normalColor)
def __bytes__(self):
s = b"".join(self.written)
return s.strip(b"\n").splitlines()[-1]
def lastColorizedLine(source):
"""
Tokenize and colorize the given Python source.
Returns a VT102-format colorized version of the last line of C{source}.
@param source: Python source code
@type source: L{str} or L{bytes}
@return: L{bytes} of colorized source
"""
if not isinstance(source, bytes):
source = source.encode("utf-8")
w = VT102Writer()
p = TokenPrinter(w.write).printtoken
s = BytesIO(source)
for token in tokenize.tokenize(s.readline):
(tokenType, string, start, end, line) = token
p(tokenType, string, start, end, line)
return bytes(w)
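# Usage sketch (illustrative): only the last line of the given source comes
# back, wrapped in the escape sequences from VT102Writer.typeToColor.
def _colorizeSketch():
    return lastColorizedLine("x = 1\ny = x + 2")   # colorized b"y = x + 2"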
class ColoredManhole(Manhole):
"""
A REPL which syntax colors input as users type it.
"""
def getSource(self):
"""
Return a string containing the currently entered source.
This is only the code which will be considered for execution
next.
"""
return b"\n".join(self.interpreter.buffer) + b"\n" + b"".join(self.lineBuffer)
def characterReceived(self, ch, moreCharactersComing):
if self.mode == "insert":
self.lineBuffer.insert(self.lineBufferIndex, ch)
else:
self.lineBuffer[self.lineBufferIndex : self.lineBufferIndex + 1] = [ch]
self.lineBufferIndex += 1
if moreCharactersComing:
# Skip it all, we'll get called with another character in
# like 2 femtoseconds.
return
if ch == b" ":
# Don't bother to try to color whitespace
self.terminal.write(ch)
return
source = self.getSource()
# Try to write some junk
try:
coloredLine = lastColorizedLine(source)
except tokenize.TokenError:
# We couldn't do it. Strange. Oh well, just add the character.
self.terminal.write(ch)
else:
# Success! Clear the source on this line.
self.terminal.eraseLine()
self.terminal.cursorBackward(
len(self.lineBuffer) + len(self.ps[self.pn]) - 1
)
# And write a new, colorized one.
self.terminal.write(self.ps[self.pn] + coloredLine)
# And move the cursor to where it belongs
n = len(self.lineBuffer) - self.lineBufferIndex
if n:
self.terminal.cursorBackward(n)

View File

@@ -0,0 +1,148 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
insults/SSH integration support.
@author: Jp Calderone
"""
from typing import Dict
from zope.interface import implementer
from twisted.conch import avatar, error as econch, interfaces as iconch
from twisted.conch.insults import insults
from twisted.conch.ssh import factory, session
from twisted.python import components
class _Glue:
"""
A feeble class for making one attribute look like another.
This should be replaced with a real class at some point, probably.
Try not to write new code that uses it.
"""
def __init__(self, **kw):
self.__dict__.update(kw)
def __getattr__(self, name):
raise AttributeError(self.name, "has no attribute", name)
class TerminalSessionTransport:
def __init__(self, proto, chainedProtocol, avatar, width, height):
self.proto = proto
self.avatar = avatar
self.chainedProtocol = chainedProtocol
protoSession = self.proto.session
self.proto.makeConnection(
_Glue(
write=self.chainedProtocol.dataReceived,
loseConnection=lambda: avatar.conn.sendClose(protoSession),
name="SSH Proto Transport",
)
)
def loseConnection():
self.proto.loseConnection()
self.chainedProtocol.makeConnection(
_Glue(
write=self.proto.write,
loseConnection=loseConnection,
name="Chained Proto Transport",
)
)
# XXX TODO
# chainedProtocol is supposed to be an ITerminalTransport,
# maybe. That means perhaps its terminalProtocol attribute is
# an ITerminalProtocol, it could be. So calling terminalSize
        # on that should do the right thing. But it'd be nice to clean
# this bit up.
self.chainedProtocol.terminalProtocol.terminalSize(width, height)
@implementer(iconch.ISession)
class TerminalSession(components.Adapter):
transportFactory = TerminalSessionTransport
chainedProtocolFactory = insults.ServerProtocol
def getPty(self, term, windowSize, attrs):
self.height, self.width = windowSize[:2]
def openShell(self, proto):
self.transportFactory(
proto,
self.chainedProtocolFactory(),
iconch.IConchUser(self.original),
self.width,
self.height,
)
def execCommand(self, proto, cmd):
raise econch.ConchError("Cannot execute commands")
def windowChanged(self, newWindowSize):
# ISession.windowChanged
raise NotImplementedError("Unimplemented: TerminalSession.windowChanged")
def eofReceived(self):
# ISession.eofReceived
raise NotImplementedError("Unimplemented: TerminalSession.eofReceived")
def closed(self):
# ISession.closed
pass
class TerminalUser(avatar.ConchUser, components.Adapter):
def __init__(self, original, avatarId):
components.Adapter.__init__(self, original)
avatar.ConchUser.__init__(self)
self.channelLookup[b"session"] = session.SSHSession
class TerminalRealm:
userFactory = TerminalUser
sessionFactory = TerminalSession
transportFactory = TerminalSessionTransport
chainedProtocolFactory = insults.ServerProtocol
def _getAvatar(self, avatarId):
comp = components.Componentized()
user = self.userFactory(comp, avatarId)
sess = self.sessionFactory(comp)
sess.transportFactory = self.transportFactory
sess.chainedProtocolFactory = self.chainedProtocolFactory
comp.setComponent(iconch.IConchUser, user)
comp.setComponent(iconch.ISession, sess)
return user
def __init__(self, transportFactory=None):
if transportFactory is not None:
self.transportFactory = transportFactory
def requestAvatar(self, avatarId, mind, *interfaces):
for i in interfaces:
if i is iconch.IConchUser:
return (iconch.IConchUser, self._getAvatar(avatarId), lambda: None)
raise NotImplementedError()
class ConchFactory(factory.SSHFactory):
publicKeys: Dict[bytes, bytes] = {}
privateKeys: Dict[bytes, bytes] = {}
def __init__(self, portal):
self.portal = portal

View File

@@ -0,0 +1,180 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
TAP plugin for creating telnet- and ssh-accessible manhole servers.
@author: Jp Calderone
"""
from zope.interface import implementer
from twisted.application import service, strports
from twisted.conch import manhole, manhole_ssh, telnet
from twisted.conch.insults import insults
from twisted.conch.ssh import keys
from twisted.cred import checkers, portal
from twisted.internet import protocol
from twisted.python import filepath, usage
class makeTelnetProtocol:
def __init__(self, portal):
self.portal = portal
def __call__(self):
auth = telnet.AuthenticatingTelnetProtocol
args = (self.portal,)
return telnet.TelnetTransport(auth, *args)
class chainedProtocolFactory:
def __init__(self, namespace):
self.namespace = namespace
def __call__(self):
return insults.ServerProtocol(manhole.ColoredManhole, self.namespace)
@implementer(portal.IRealm)
class _StupidRealm:
def __init__(self, proto, *a, **kw):
self.protocolFactory = proto
self.protocolArgs = a
self.protocolKwArgs = kw
def requestAvatar(self, avatarId, *interfaces):
if telnet.ITelnetProtocol in interfaces:
return (
telnet.ITelnetProtocol,
self.protocolFactory(*self.protocolArgs, **self.protocolKwArgs),
lambda: None,
)
raise NotImplementedError()
class Options(usage.Options):
optParameters = [
[
"telnetPort",
"t",
None,
(
"strports description of the address on which to listen for telnet "
"connections"
),
],
[
"sshPort",
"s",
None,
(
"strports description of the address on which to listen for ssh "
"connections"
),
],
[
"passwd",
"p",
"/etc/passwd",
"name of a passwd(5)-format username/password file",
],
[
"sshKeyDir",
None,
"<USER DATA DIR>",
"Directory where the autogenerated SSH key is kept.",
],
["sshKeyName", None, "server.key", "Filename of the autogenerated SSH key."],
["sshKeySize", None, 4096, "Size of the automatically generated SSH key."],
]
def __init__(self):
usage.Options.__init__(self)
self["namespace"] = None
def postOptions(self):
if self["telnetPort"] is None and self["sshPort"] is None:
raise usage.UsageError(
"At least one of --telnetPort and --sshPort must be specified"
)
def makeService(options):
"""
Create a manhole server service.
@type options: L{dict}
@param options: A mapping describing the configuration of
the desired service. Recognized key/value pairs are::
"telnetPort": strports description of the address on which
to listen for telnet connections. If None,
no telnet service will be started.
"sshPort": strports description of the address on which to
listen for ssh connections. If None, no ssh
service will be started.
"namespace": dictionary containing desired initial locals
for manhole connections. If None, an empty
dictionary will be used.
"passwd": Name of a passwd(5)-format username/password file.
"sshKeyDir": The folder that the SSH server key will be kept in.
"sshKeyName": The filename of the key.
"sshKeySize": The size of the key, in bits. Default is 4096.
@rtype: L{twisted.application.service.IService}
@return: A manhole service.
"""
svc = service.MultiService()
namespace = options["namespace"]
if namespace is None:
namespace = {}
checker = checkers.FilePasswordDB(options["passwd"])
if options["telnetPort"]:
telnetRealm = _StupidRealm(
telnet.TelnetBootstrapProtocol,
insults.ServerProtocol,
manhole.ColoredManhole,
namespace,
)
telnetPortal = portal.Portal(telnetRealm, [checker])
telnetFactory = protocol.ServerFactory()
telnetFactory.protocol = makeTelnetProtocol(telnetPortal)
telnetService = strports.service(options["telnetPort"], telnetFactory)
telnetService.setServiceParent(svc)
if options["sshPort"]:
sshRealm = manhole_ssh.TerminalRealm()
sshRealm.chainedProtocolFactory = chainedProtocolFactory(namespace)
sshPortal = portal.Portal(sshRealm, [checker])
sshFactory = manhole_ssh.ConchFactory(sshPortal)
if options["sshKeyDir"] != "<USER DATA DIR>":
keyDir = options["sshKeyDir"]
else:
from twisted.python._appdirs import getDataDirectory
keyDir = getDataDirectory()
keyLocation = filepath.FilePath(keyDir).child(options["sshKeyName"])
sshKey = keys._getPersistentRSAKey(keyLocation, int(options["sshKeySize"]))
sshFactory.publicKeys[b"ssh-rsa"] = sshKey
sshFactory.privateKeys[b"ssh-rsa"] = sshKey
sshService = strports.service(options["sshPort"], sshFactory)
sshService.setServiceParent(svc)
return svc
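# Illustrative configuration (addresses, paths, and namespace values are made
# up). With "sshPort" set to None only the telnet listener is created, so no
# SSH host key is needed.
def _exampleManholeService():
    return makeService(
        {
            "telnetPort": "tcp:6023:interface=127.0.0.1",
            "sshPort": None,
            "namespace": {"answer": 42},
            "passwd": "/etc/manhole.passwd",
            "sshKeyDir": "<USER DATA DIR>",
            "sshKeyName": "server.key",
            "sshKeySize": 4096,
        }
    )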

View File

@@ -0,0 +1,54 @@
# -*- test-case-name: twisted.conch.test.test_mixin -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Experimental optimization
This module provides a single mixin class which allows protocols to
collapse numerous small writes into a single larger one.
@author: Jp Calderone
"""
from twisted.internet import reactor
class BufferingMixin:
"""
Mixin which adds write buffering.
"""
_delayedWriteCall = None
data = None
DELAY = 0.0
def schedule(self):
return reactor.callLater(self.DELAY, self.flush)
def reschedule(self, token):
token.reset(self.DELAY)
def write(self, data):
"""
Buffer some bytes to be written soon.
Every call to this function delays the real write by C{self.DELAY}
seconds. When the delay expires, all collected bytes are written
to the underlying transport using L{ITransport.writeSequence}.
"""
if self._delayedWriteCall is None:
self.data = []
self._delayedWriteCall = self.schedule()
else:
self.reschedule(self._delayedWriteCall)
self.data.append(data)
def flush(self):
"""
Flush the buffer immediately.
"""
self._delayedWriteCall = None
self.transport.writeSequence(self.data)
self.data = None
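# A minimal sketch of how the mixin is meant to be used (the protocol and the
# DELAY value are invented): writes made close together in time reach the
# transport as a single writeSequence call once the delayed call fires.
from twisted.internet import protocol
class _BufferedEcho(BufferingMixin, protocol.Protocol):
    DELAY = 0.01  # coalesce writes arriving within 10ms of one another
    def dataReceived(self, data):
        self.write(data)  # buffered, flushed later by flush()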

View File

@@ -0,0 +1 @@
!.gitignore

View File

@@ -0,0 +1,10 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Support for OpenSSH configuration files.
Maintainer: Paul Swartz
"""

View File

@@ -0,0 +1,74 @@
# -*- test-case-name: twisted.conch.test.test_openssh_compat -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Factory for reading openssh configuration files: public keys, private keys, and
moduli file.
"""
import errno
import os
from typing import Dict, List, Optional, Tuple
from twisted.conch.openssh_compat import primes
from twisted.conch.ssh import common, factory, keys
from twisted.python.util import runAsEffectiveUser
class OpenSSHFactory(factory.SSHFactory):
dataRoot = "/usr/local/etc"
# For openbsd which puts moduli in a different directory from keys.
moduliRoot = "/usr/local/etc"
def getPublicKeys(self):
"""
Return the server public keys.
"""
ks = {}
for filename in os.listdir(self.dataRoot):
if filename[:9] == "ssh_host_" and filename[-8:] == "_key.pub":
try:
k = keys.Key.fromFile(os.path.join(self.dataRoot, filename))
t = common.getNS(k.blob())[0]
ks[t] = k
except Exception as e:
self._log.error(
"bad public key file {filename}: {error}",
filename=filename,
error=e,
)
return ks
def getPrivateKeys(self):
"""
Return the server private keys.
"""
privateKeys = {}
for filename in os.listdir(self.dataRoot):
if filename[:9] == "ssh_host_" and filename[-4:] == "_key":
fullPath = os.path.join(self.dataRoot, filename)
try:
key = keys.Key.fromFile(fullPath)
except OSError as e:
if e.errno == errno.EACCES:
# Not allowed, let's switch to root
key = runAsEffectiveUser(0, 0, keys.Key.fromFile, fullPath)
privateKeys[key.sshType()] = key
else:
raise
except Exception as e:
self._log.error(
"bad public key file {filename}: {error}",
filename=filename,
error=e,
)
else:
privateKeys[key.sshType()] = key
return privateKeys
def getPrimes(self) -> Optional[Dict[int, List[Tuple[int, int]]]]:
try:
return primes.parseModuliFile(self.moduliRoot + "/moduli")
except OSError:
return None

View File

@@ -0,0 +1,31 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Parsing for the moduli file, which contains Diffie-Hellman prime groups.
Maintainer: Paul Swartz
"""
from typing import Dict, List, Tuple
def parseModuliFile(filename: str) -> Dict[int, List[Tuple[int, int]]]:
with open(filename) as f:
lines = f.readlines()
primes: Dict[int, List[Tuple[int, int]]] = {}
for l in lines:
l = l.strip()
if not l or l[0] == "#":
continue
tim, typ, tst, tri, sizestr, genstr, modstr = l.split()
size = int(sizestr) + 1
gen = int(genstr)
mod = int(modstr, 16)
if size not in primes:
primes[size] = []
primes[size].append((gen, mod))
return primes
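# Illustrative round-trip (the numbers are invented and are not a real
# Diffie-Hellman group). Each non-comment line holds
# "time type tests trials size generator modulus(hex)", and the result is
# keyed by size + 1.
import tempfile
def _moduliSketch():
    line = "20120821044040 2 6 100 1023 2 %x\n" % ((1 << 1023) + 1659)
    with tempfile.NamedTemporaryFile("w", suffix=".moduli", delete=False) as f:
        f.write(line)
    groups = parseModuliFile(f.name)
    return groups[1024]  # [(2, <modulus as an int>)]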

View File

@@ -0,0 +1,569 @@
# -*- test-case-name: twisted.conch.test.test_recvline -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Basic line editing support.
@author: Jp Calderone
"""
import string
from typing import Dict
from zope.interface import implementer
from twisted.conch.insults import helper, insults
from twisted.logger import Logger
from twisted.python import reflect
from twisted.python.compat import iterbytes
_counters: Dict[str, int] = {}
class Logging:
"""
Wrapper which logs attribute lookups.
This was useful in debugging something, I guess. I forget what.
It can probably be deleted or moved somewhere more appropriate.
Nothing special going on here, really.
"""
def __init__(self, original):
self.original = original
key = reflect.qual(original.__class__)
count = _counters.get(key, 0)
_counters[key] = count + 1
self._logFile = open(key + "-" + str(count), "w")
def __str__(self) -> str:
return str(super().__getattribute__("original"))
def __repr__(self) -> str:
return repr(super().__getattribute__("original"))
def __getattribute__(self, name):
original = super().__getattribute__("original")
logFile = super().__getattribute__("_logFile")
logFile.write(name + "\n")
return getattr(original, name)
@implementer(insults.ITerminalTransport)
class TransportSequence:
"""
An L{ITerminalTransport} implementation which forwards calls to
one or more other L{ITerminalTransport}s.
This is a cheap way for servers to keep track of the state they
expect the client to see, since all terminal manipulations can be
    sent to the real client and to a terminal emulator that lives in
the server process.
"""
for keyID in (
b"UP_ARROW",
b"DOWN_ARROW",
b"RIGHT_ARROW",
b"LEFT_ARROW",
b"HOME",
b"INSERT",
b"DELETE",
b"END",
b"PGUP",
b"PGDN",
b"F1",
b"F2",
b"F3",
b"F4",
b"F5",
b"F6",
b"F7",
b"F8",
b"F9",
b"F10",
b"F11",
b"F12",
):
execBytes = keyID + b" = object()"
execStr = execBytes.decode("ascii")
exec(execStr)
TAB = b"\t"
BACKSPACE = b"\x7f"
def __init__(self, *transports):
assert transports, "Cannot construct a TransportSequence with no transports"
self.transports = transports
for method in insults.ITerminalTransport:
exec(
"""\
def %s(self, *a, **kw):
for tpt in self.transports:
result = tpt.%s(*a, **kw)
return result
"""
% (method, method)
)
def getHost(self):
# ITransport.getHost
raise NotImplementedError("Unimplemented: TransportSequence.getHost")
def getPeer(self):
# ITransport.getPeer
raise NotImplementedError("Unimplemented: TransportSequence.getPeer")
def loseConnection(self):
# ITransport.loseConnection
raise NotImplementedError("Unimplemented: TransportSequence.loseConnection")
def write(self, data):
# ITransport.write
raise NotImplementedError("Unimplemented: TransportSequence.write")
def writeSequence(self, data):
# ITransport.writeSequence
raise NotImplementedError("Unimplemented: TransportSequence.writeSequence")
def cursorUp(self, n=1):
# ITerminalTransport.cursorUp
raise NotImplementedError("Unimplemented: TransportSequence.cursorUp")
def cursorDown(self, n=1):
# ITerminalTransport.cursorDown
raise NotImplementedError("Unimplemented: TransportSequence.cursorDown")
def cursorForward(self, n=1):
# ITerminalTransport.cursorForward
raise NotImplementedError("Unimplemented: TransportSequence.cursorForward")
def cursorBackward(self, n=1):
# ITerminalTransport.cursorBackward
raise NotImplementedError("Unimplemented: TransportSequence.cursorBackward")
def cursorPosition(self, column, line):
# ITerminalTransport.cursorPosition
raise NotImplementedError("Unimplemented: TransportSequence.cursorPosition")
def cursorHome(self):
# ITerminalTransport.cursorHome
raise NotImplementedError("Unimplemented: TransportSequence.cursorHome")
def index(self):
# ITerminalTransport.index
raise NotImplementedError("Unimplemented: TransportSequence.index")
def reverseIndex(self):
# ITerminalTransport.reverseIndex
raise NotImplementedError("Unimplemented: TransportSequence.reverseIndex")
def nextLine(self):
# ITerminalTransport.nextLine
raise NotImplementedError("Unimplemented: TransportSequence.nextLine")
def saveCursor(self):
# ITerminalTransport.saveCursor
raise NotImplementedError("Unimplemented: TransportSequence.saveCursor")
def restoreCursor(self):
# ITerminalTransport.restoreCursor
raise NotImplementedError("Unimplemented: TransportSequence.restoreCursor")
def setModes(self, modes):
# ITerminalTransport.setModes
raise NotImplementedError("Unimplemented: TransportSequence.setModes")
def resetModes(self, mode):
# ITerminalTransport.resetModes
raise NotImplementedError("Unimplemented: TransportSequence.resetModes")
def setPrivateModes(self, modes):
# ITerminalTransport.setPrivateModes
raise NotImplementedError("Unimplemented: TransportSequence.setPrivateModes")
def resetPrivateModes(self, modes):
# ITerminalTransport.resetPrivateModes
raise NotImplementedError("Unimplemented: TransportSequence.resetPrivateModes")
def applicationKeypadMode(self):
# ITerminalTransport.applicationKeypadMode
raise NotImplementedError(
"Unimplemented: TransportSequence.applicationKeypadMode"
)
def numericKeypadMode(self):
# ITerminalTransport.numericKeypadMode
raise NotImplementedError("Unimplemented: TransportSequence.numericKeypadMode")
def selectCharacterSet(self, charSet, which):
# ITerminalTransport.selectCharacterSet
raise NotImplementedError("Unimplemented: TransportSequence.selectCharacterSet")
def shiftIn(self):
# ITerminalTransport.shiftIn
raise NotImplementedError("Unimplemented: TransportSequence.shiftIn")
def shiftOut(self):
# ITerminalTransport.shiftOut
raise NotImplementedError("Unimplemented: TransportSequence.shiftOut")
def singleShift2(self):
# ITerminalTransport.singleShift2
raise NotImplementedError("Unimplemented: TransportSequence.singleShift2")
def singleShift3(self):
# ITerminalTransport.singleShift3
raise NotImplementedError("Unimplemented: TransportSequence.singleShift3")
def selectGraphicRendition(self, *attributes):
# ITerminalTransport.selectGraphicRendition
raise NotImplementedError(
"Unimplemented: TransportSequence.selectGraphicRendition"
)
def horizontalTabulationSet(self):
# ITerminalTransport.horizontalTabulationSet
raise NotImplementedError(
"Unimplemented: TransportSequence.horizontalTabulationSet"
)
def tabulationClear(self):
# ITerminalTransport.tabulationClear
raise NotImplementedError("Unimplemented: TransportSequence.tabulationClear")
def tabulationClearAll(self):
# ITerminalTransport.tabulationClearAll
raise NotImplementedError("Unimplemented: TransportSequence.tabulationClearAll")
def doubleHeightLine(self, top=True):
# ITerminalTransport.doubleHeightLine
raise NotImplementedError("Unimplemented: TransportSequence.doubleHeightLine")
def singleWidthLine(self):
# ITerminalTransport.singleWidthLine
raise NotImplementedError("Unimplemented: TransportSequence.singleWidthLine")
def doubleWidthLine(self):
# ITerminalTransport.doubleWidthLine
raise NotImplementedError("Unimplemented: TransportSequence.doubleWidthLine")
def eraseToLineEnd(self):
# ITerminalTransport.eraseToLineEnd
raise NotImplementedError("Unimplemented: TransportSequence.eraseToLineEnd")
def eraseToLineBeginning(self):
# ITerminalTransport.eraseToLineBeginning
raise NotImplementedError(
"Unimplemented: TransportSequence.eraseToLineBeginning"
)
def eraseLine(self):
# ITerminalTransport.eraseLine
raise NotImplementedError("Unimplemented: TransportSequence.eraseLine")
def eraseToDisplayEnd(self):
# ITerminalTransport.eraseToDisplayEnd
raise NotImplementedError("Unimplemented: TransportSequence.eraseToDisplayEnd")
def eraseToDisplayBeginning(self):
# ITerminalTransport.eraseToDisplayBeginning
raise NotImplementedError(
"Unimplemented: TransportSequence.eraseToDisplayBeginning"
)
def eraseDisplay(self):
# ITerminalTransport.eraseDisplay
raise NotImplementedError("Unimplemented: TransportSequence.eraseDisplay")
def deleteCharacter(self, n=1):
# ITerminalTransport.deleteCharacter
raise NotImplementedError("Unimplemented: TransportSequence.deleteCharacter")
def insertLine(self, n=1):
# ITerminalTransport.insertLine
raise NotImplementedError("Unimplemented: TransportSequence.insertLine")
def deleteLine(self, n=1):
# ITerminalTransport.deleteLine
raise NotImplementedError("Unimplemented: TransportSequence.deleteLine")
def reportCursorPosition(self):
# ITerminalTransport.reportCursorPosition
raise NotImplementedError(
"Unimplemented: TransportSequence.reportCursorPosition"
)
def reset(self):
# ITerminalTransport.reset
raise NotImplementedError("Unimplemented: TransportSequence.reset")
def unhandledControlSequence(self, seq):
# ITerminalTransport.unhandledControlSequence
raise NotImplementedError(
"Unimplemented: TransportSequence.unhandledControlSequence"
)
class LocalTerminalBufferMixin:
"""
A mixin for RecvLine subclasses which records the state of the terminal.
This is accomplished by performing all L{ITerminalTransport} operations on both
the transport passed to makeConnection and an instance of helper.TerminalBuffer.
@ivar terminalCopy: A L{helper.TerminalBuffer} instance which efforts
will be made to keep up to date with the actual terminal
associated with this protocol instance.
"""
def makeConnection(self, transport):
self.terminalCopy = helper.TerminalBuffer()
self.terminalCopy.connectionMade()
return super().makeConnection(TransportSequence(transport, self.terminalCopy))
def __str__(self) -> str:
return str(self.terminalCopy)
class RecvLine(insults.TerminalProtocol):
"""
L{TerminalProtocol} which adds line editing features.
Clients will be prompted for lines of input with all the usual
features: character echoing, left and right arrow support for
moving the cursor to different areas of the line buffer, backspace
and delete for removing characters, and insert for toggling
between typeover and insert mode. Tabs will be expanded to enough
spaces to move the cursor to the next tabstop (every four
characters by default). Enter causes the line buffer to be
cleared and the line to be passed to the lineReceived() method
which, by default, does nothing. Subclasses are responsible for
redrawing the input prompt (this will probably change).
"""
width = 80
height = 24
TABSTOP = 4
ps = (b">>> ", b"... ")
pn = 0
_printableChars = string.printable.encode("ascii")
_log = Logger()
def connectionMade(self):
# A list containing the characters making up the current line
self.lineBuffer = []
# A zero-based (wtf else?) index into self.lineBuffer.
# Indicates the current cursor position.
self.lineBufferIndex = 0
t = self.terminal
# A map of keyIDs to bound instance methods.
self.keyHandlers = {
t.LEFT_ARROW: self.handle_LEFT,
t.RIGHT_ARROW: self.handle_RIGHT,
t.TAB: self.handle_TAB,
# Both of these should not be necessary, but figuring out
# which is necessary is a huge hassle.
b"\r": self.handle_RETURN,
b"\n": self.handle_RETURN,
t.BACKSPACE: self.handle_BACKSPACE,
t.DELETE: self.handle_DELETE,
t.INSERT: self.handle_INSERT,
t.HOME: self.handle_HOME,
t.END: self.handle_END,
}
self.initializeScreen()
def initializeScreen(self):
# Hmm, state sucks. Oh well.
# For now we will just take over the whole terminal.
self.terminal.reset()
self.terminal.write(self.ps[self.pn])
# XXX Note: I would prefer to default to starting in insert
# mode, however this does not seem to actually work! I do not
# know why. This is probably of interest to implementors
# subclassing RecvLine.
# XXX XXX Note: But the unit tests all expect the initial mode
# to be insert right now. Fuck, there needs to be a way to
# query the current mode or something.
# self.setTypeoverMode()
self.setInsertMode()
def currentLineBuffer(self):
s = b"".join(self.lineBuffer)
return s[: self.lineBufferIndex], s[self.lineBufferIndex :]
def setInsertMode(self):
self.mode = "insert"
self.terminal.setModes([insults.modes.IRM])
def setTypeoverMode(self):
self.mode = "typeover"
self.terminal.resetModes([insults.modes.IRM])
def drawInputLine(self):
"""
Write a line containing the current input prompt and the current line
buffer at the current cursor position.
"""
self.terminal.write(self.ps[self.pn] + b"".join(self.lineBuffer))
def terminalSize(self, width, height):
# XXX - Clear the previous input line, redraw it at the new
# cursor position
self.terminal.eraseDisplay()
self.terminal.cursorHome()
self.width = width
self.height = height
self.drawInputLine()
def unhandledControlSequence(self, seq):
pass
def keystrokeReceived(self, keyID, modifier):
m = self.keyHandlers.get(keyID)
if m is not None:
m()
elif keyID in self._printableChars:
self.characterReceived(keyID, False)
else:
self._log.warn("Received unhandled keyID: {keyID!r}", keyID=keyID)
def characterReceived(self, ch, moreCharactersComing):
if self.mode == "insert":
self.lineBuffer.insert(self.lineBufferIndex, ch)
else:
self.lineBuffer[self.lineBufferIndex : self.lineBufferIndex + 1] = [ch]
self.lineBufferIndex += 1
self.terminal.write(ch)
def handle_TAB(self):
n = self.TABSTOP - (len(self.lineBuffer) % self.TABSTOP)
self.terminal.cursorForward(n)
self.lineBufferIndex += n
self.lineBuffer.extend(iterbytes(b" " * n))
def handle_LEFT(self):
if self.lineBufferIndex > 0:
self.lineBufferIndex -= 1
self.terminal.cursorBackward()
def handle_RIGHT(self):
if self.lineBufferIndex < len(self.lineBuffer):
self.lineBufferIndex += 1
self.terminal.cursorForward()
def handle_HOME(self):
if self.lineBufferIndex:
self.terminal.cursorBackward(self.lineBufferIndex)
self.lineBufferIndex = 0
def handle_END(self):
offset = len(self.lineBuffer) - self.lineBufferIndex
if offset:
self.terminal.cursorForward(offset)
self.lineBufferIndex = len(self.lineBuffer)
def handle_BACKSPACE(self):
if self.lineBufferIndex > 0:
self.lineBufferIndex -= 1
del self.lineBuffer[self.lineBufferIndex]
self.terminal.cursorBackward()
self.terminal.deleteCharacter()
def handle_DELETE(self):
if self.lineBufferIndex < len(self.lineBuffer):
del self.lineBuffer[self.lineBufferIndex]
self.terminal.deleteCharacter()
def handle_RETURN(self):
line = b"".join(self.lineBuffer)
self.lineBuffer = []
self.lineBufferIndex = 0
self.terminal.nextLine()
self.lineReceived(line)
def handle_INSERT(self):
assert self.mode in ("typeover", "insert")
if self.mode == "typeover":
self.setInsertMode()
else:
self.setTypeoverMode()
def lineReceived(self, line):
pass
class HistoricRecvLine(RecvLine):
"""
L{TerminalProtocol} which adds both basic line-editing features and input history.
Everything supported by L{RecvLine} is also supported by this class. In addition, the
up and down arrows traverse the input history. Each received line is automatically
added to the end of the input history.
"""
def connectionMade(self):
RecvLine.connectionMade(self)
self.historyLines = []
self.historyPosition = 0
t = self.terminal
self.keyHandlers.update(
{t.UP_ARROW: self.handle_UP, t.DOWN_ARROW: self.handle_DOWN}
)
def currentHistoryBuffer(self):
b = tuple(self.historyLines)
return b[: self.historyPosition], b[self.historyPosition :]
def _deliverBuffer(self, buf):
if buf:
for ch in iterbytes(buf[:-1]):
self.characterReceived(ch, True)
self.characterReceived(buf[-1:], False)
def handle_UP(self):
if self.lineBuffer and self.historyPosition == len(self.historyLines):
self.historyLines.append(b"".join(self.lineBuffer))
if self.historyPosition > 0:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition -= 1
self.lineBuffer = []
self._deliverBuffer(self.historyLines[self.historyPosition])
def handle_DOWN(self):
if self.historyPosition < len(self.historyLines) - 1:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition += 1
self.lineBuffer = []
self._deliverBuffer(self.historyLines[self.historyPosition])
else:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition = len(self.historyLines)
self.lineBuffer = []
self.lineBufferIndex = 0
def handle_RETURN(self):
if self.lineBuffer:
self.historyLines.append(b"".join(self.lineBuffer))
self.historyPosition = len(self.historyLines)
return RecvLine.handle_RETURN(self)
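# A hypothetical subclass, to show the intended division of labour: the base
# class clears the line buffer and calls lineReceived(), while the subclass
# writes any response and redraws the prompt. It would typically be served
# through insults.ServerProtocol(_EchoLine) over conch or telnet.
class _EchoLine(HistoricRecvLine):
    def lineReceived(self, line):
        self.terminal.write(b"echo: " + line)
        self.terminal.nextLine()
        self.drawInputLine()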

View File

@@ -0,0 +1 @@
"conch scripts"

File diff suppressed because it is too large

View File

@@ -0,0 +1,400 @@
# -*- test-case-name: twisted.conch.test.test_ckeygen -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation module for the `ckeygen` command.
"""
from __future__ import annotations
import getpass
import os
import platform
import socket
import sys
from collections.abc import Callable
from functools import wraps
from importlib import reload
from typing import Any, Dict, Optional
from twisted.conch.ssh import keys
from twisted.python import failure, filepath, log, usage
if getpass.getpass == getpass.unix_getpass: # type: ignore[attr-defined]
try:
import termios # hack around broken termios
termios.tcgetattr, termios.tcsetattr
except (ImportError, AttributeError):
sys.modules["termios"] = None # type: ignore[assignment]
reload(getpass)
supportedKeyTypes = dict()
def _keyGenerator(keyType):
def assignkeygenerator(keygenerator):
@wraps(keygenerator)
def wrapper(*args, **kwargs):
return keygenerator(*args, **kwargs)
supportedKeyTypes[keyType] = wrapper
return wrapper
return assignkeygenerator
class GeneralOptions(usage.Options):
synopsis = """Usage: ckeygen [options]
"""
longdesc = "ckeygen manipulates public/private keys in various ways."
optParameters = [
["bits", "b", None, "Number of bits in the key to create."],
["filename", "f", None, "Filename of the key file."],
["type", "t", None, "Specify type of key to create."],
["comment", "C", None, "Provide new comment."],
["newpass", "N", None, "Provide new passphrase."],
["pass", "P", None, "Provide old passphrase."],
["format", "o", "sha256-base64", "Fingerprint format of key file."],
[
"private-key-subtype",
None,
None,
'OpenSSH private key subtype to write ("PEM" or "v1").',
],
]
optFlags = [
["fingerprint", "l", "Show fingerprint of key file."],
["changepass", "p", "Change passphrase of private key file."],
["quiet", "q", "Quiet."],
["no-passphrase", None, "Create the key with no passphrase."],
["showpub", "y", "Read private key file and print public key."],
]
compData = usage.Completions(
optActions={
"type": usage.CompleteList(list(supportedKeyTypes.keys())),
"private-key-subtype": usage.CompleteList(["PEM", "v1"]),
}
)
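# For illustration, the same options the command-line tool accepts can be
# parsed programmatically; the key type, size, and filename here are made up.
def _optionsSketch():
    opts = GeneralOptions()
    opts.parseOptions(["-t", "ecdsa", "-b", "384", "-f", "demo_key"])
    return opts  # equivalent to: ckeygen -t ecdsa -b 384 -f demo_key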
def run():
options = GeneralOptions()
try:
options.parseOptions(sys.argv[1:])
except usage.UsageError as u:
print("ERROR: %s" % u)
options.opt_help()
sys.exit(1)
log.discardLogs()
log.deferr = handleError # HACK
if options["type"]:
if options["type"].lower() in supportedKeyTypes:
print("Generating public/private %s key pair." % (options["type"]))
supportedKeyTypes[options["type"].lower()](options)
else:
sys.exit(
"Key type was %s, must be one of %s"
% (options["type"], ", ".join(supportedKeyTypes.keys()))
)
elif options["fingerprint"]:
printFingerprint(options)
elif options["changepass"]:
changePassPhrase(options)
elif options["showpub"]:
displayPublicKey(options)
else:
options.opt_help()
sys.exit(1)
def enumrepresentation(options):
if options["format"] == "md5-hex":
options["format"] = keys.FingerprintFormats.MD5_HEX
return options
elif options["format"] == "sha256-base64":
options["format"] = keys.FingerprintFormats.SHA256_BASE64
return options
else:
raise keys.BadFingerPrintFormat(
f"Unsupported fingerprint format: {options['format']}"
)
def handleError():
global exitStatus
exitStatus = 2
log.err(failure.Failure())
raise
@_keyGenerator("rsa")
def generateRSAkey(options):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
if not options["bits"]:
options["bits"] = 2048
keyPrimitive = rsa.generate_private_key(
key_size=int(options["bits"]),
public_exponent=65537,
backend=default_backend(),
)
key = keys.Key(keyPrimitive)
_saveKey(key, options)
@_keyGenerator("dsa")
def generateDSAkey(options):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import dsa
if not options["bits"]:
options["bits"] = 1024
keyPrimitive = dsa.generate_private_key(
key_size=int(options["bits"]),
backend=default_backend(),
)
key = keys.Key(keyPrimitive)
_saveKey(key, options)
@_keyGenerator("ecdsa")
def generateECDSAkey(options):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
if not options["bits"]:
options["bits"] = 256
# OpenSSH supports only mandatory sections of RFC5656.
# See https://www.openssh.com/txt/release-5.7
curve = b"ecdsa-sha2-nistp" + str(options["bits"]).encode("ascii")
keyPrimitive = ec.generate_private_key(
curve=keys._curveTable[curve], backend=default_backend()
)
key = keys.Key(keyPrimitive)
_saveKey(key, options)
@_keyGenerator("ed25519")
def generateEd25519key(options):
keyPrimitive = keys.Ed25519PrivateKey.generate()
key = keys.Key(keyPrimitive)
_saveKey(key, options)
def _defaultPrivateKeySubtype(keyType):
"""
Return a reasonable default private key subtype for a given key type.
@type keyType: L{str}
@param keyType: A key type, as returned by
L{twisted.conch.ssh.keys.Key.type}.
@rtype: L{str}
@return: A private OpenSSH key subtype (C{'PEM'} or C{'v1'}).
"""
if keyType == "Ed25519":
# No PEM format is defined for Ed25519 keys.
return "v1"
else:
return "PEM"
def _getKeyOrDefault(
options: Dict[Any, Any],
inputCollector: Optional[Callable[[str], str]] = None,
keyTypeName: str = "rsa",
) -> str:
"""
    If C{options["filename"]} is not set, prompt the user for a path,
    defaulting to C{~/.ssh/id_<keyTypeName>}.
    @param options: command line options
    @param inputCollector: dependency injection for testing
    @param keyTypeName: name of the key type, used to build the default path
"""
if inputCollector is None:
inputCollector = input
filename = options["filename"]
if not filename:
filename = os.path.expanduser(f"~/.ssh/id_{keyTypeName}")
if platform.system() == "Windows":
filename = os.path.expanduser(Rf"%HOMEPATH %\.ssh\id_{keyTypeName}")
filename = (
inputCollector("Enter file in which the key is (%s): " % filename)
or filename
)
return str(filename)
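# Illustrative usage sketch (not part of the original module, added for
# clarity): with no filename option given, _getKeyOrDefault suggests
# ~/.ssh/id_<keyTypeName>, and an empty answer at the prompt keeps that
# default.  The lambda below stands in for the interactive prompt.
def _exampleDefaultKeyPath() -> str:
    options: Dict[Any, Any] = {"filename": None}
    return _getKeyOrDefault(options, inputCollector=lambda prompt: "")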
def printFingerprint(options: Dict[Any, Any]) -> None:
filename = _getKeyOrDefault(options)
if os.path.exists(filename + ".pub"):
filename += ".pub"
options = enumrepresentation(options)
try:
key = keys.Key.fromFile(filename)
print(
"%s %s %s"
% (
key.size(),
key.fingerprint(options["format"]),
os.path.basename(filename),
)
)
except keys.BadKeyError:
sys.exit("bad key")
except FileNotFoundError:
sys.exit(f"{filename} could not be opened, please specify a file.")
def changePassPhrase(options):
filename = _getKeyOrDefault(options)
try:
key = keys.Key.fromFile(filename)
except keys.EncryptedKeyError:
# Raised if password not supplied for an encrypted key
if not options.get("pass"):
options["pass"] = getpass.getpass("Enter old passphrase: ")
try:
key = keys.Key.fromFile(filename, passphrase=options["pass"])
except keys.BadKeyError:
sys.exit("Could not change passphrase: old passphrase error")
except keys.EncryptedKeyError as e:
sys.exit(f"Could not change passphrase: {e}")
except keys.BadKeyError as e:
sys.exit(f"Could not change passphrase: {e}")
except FileNotFoundError:
sys.exit(f"{filename} could not be opened, please specify a file.")
if not options.get("newpass"):
while 1:
p1 = getpass.getpass("Enter new passphrase (empty for no passphrase): ")
p2 = getpass.getpass("Enter same passphrase again: ")
if p1 == p2:
break
print("Passphrases do not match. Try again.")
options["newpass"] = p1
if options.get("private-key-subtype") is None:
options["private-key-subtype"] = _defaultPrivateKeySubtype(key.type())
try:
newkeydata = key.toString(
"openssh",
subtype=options["private-key-subtype"],
passphrase=options["newpass"],
)
except Exception as e:
sys.exit(f"Could not change passphrase: {e}")
try:
keys.Key.fromString(newkeydata, passphrase=options["newpass"])
except (keys.EncryptedKeyError, keys.BadKeyError) as e:
sys.exit(f"Could not change passphrase: {e}")
with open(filename, "wb") as fd:
fd.write(newkeydata)
print("Your identification has been saved with the new passphrase.")
def displayPublicKey(options):
filename = _getKeyOrDefault(options)
try:
key = keys.Key.fromFile(filename)
except FileNotFoundError:
sys.exit(f"{filename} could not be opened, please specify a file.")
except keys.EncryptedKeyError:
if not options.get("pass"):
options["pass"] = getpass.getpass("Enter passphrase: ")
key = keys.Key.fromFile(filename, passphrase=options["pass"])
displayKey = key.public().toString("openssh").decode("ascii")
print(displayKey)
def _inputSaveFile(prompt: str) -> str:
"""
Ask the user where to save the key.
This needs to be a separate function so the unit test can patch it.
"""
return input(prompt)
def _saveKey(
key: keys.Key,
options: Dict[Any, Any],
inputCollector: Optional[Callable[[str], str]] = None,
) -> None:
"""
Persist a SSH key on local filesystem.
@param key: Key which is persisted on local filesystem.
@param options:
@param inputCollector: Dependency injection for testing.
"""
if inputCollector is None:
inputCollector = input
KeyTypeMapping = {"EC": "ecdsa", "Ed25519": "ed25519", "RSA": "rsa", "DSA": "dsa"}
keyTypeName = KeyTypeMapping[key.type()]
filename = options["filename"]
if not filename:
defaultPath = _getKeyOrDefault(options, inputCollector, keyTypeName)
newPath = _inputSaveFile(
f"Enter file in which to save the key ({defaultPath}): "
)
filename = newPath.strip() or defaultPath
if os.path.exists(filename):
print(f"{filename} already exists.")
yn = inputCollector("Overwrite (y/n)? ")
if yn[0].lower() != "y":
sys.exit()
if options.get("no-passphrase"):
options["pass"] = b""
elif not options["pass"]:
while 1:
p1 = getpass.getpass("Enter passphrase (empty for no passphrase): ")
p2 = getpass.getpass("Enter same passphrase again: ")
if p1 == p2:
break
print("Passphrases do not match. Try again.")
options["pass"] = p1
if options.get("private-key-subtype") is None:
options["private-key-subtype"] = _defaultPrivateKeySubtype(key.type())
comment = f"{getpass.getuser()}@{socket.gethostname()}"
fp = filepath.FilePath(filename)
fp.setContent(
key.toString(
"openssh",
subtype=options["private-key-subtype"],
passphrase=options["pass"],
)
)
fp.chmod(0o100600)
filepath.FilePath(filename + ".pub").setContent(
key.public().toString("openssh", comment=comment)
)
options = enumrepresentation(options)
print(f"Your identification has been saved in {filename}")
print(f"Your public key has been saved in {filename}.pub")
print(f"The key fingerprint in {options['format']} is:")
print(key.fingerprint(options["format"]))
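# Illustrative sketch (not part of the original module, added for clarity):
# generating an Ed25519 key in memory and persisting it with _saveKey.  The
# path below is a hypothetical example value.
def _exampleSaveEd25519Key(path: str = "/tmp/example_id_ed25519") -> None:
    key = keys.Key(keys.Ed25519PrivateKey.generate())
    options: Dict[Any, Any] = {
        "filename": path,
        "pass": "",
        "no-passphrase": True,
        "private-key-subtype": None,
        "format": "sha256-base64",
    }
    _saveKey(key, options)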
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,578 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
# $Id: conch.py,v 1.65 2004/03/11 00:29:14 z3p Exp $
# Implementation module for the `conch` command.
#
import fcntl
import getpass
import os
import signal
import struct
import sys
import tty
from typing import List, Tuple
from twisted.conch.client import connect, default
from twisted.conch.client.options import ConchOptions
from twisted.conch.error import ConchError
from twisted.conch.ssh import channel, common, connection, forwarding, session
from twisted.internet import reactor, stdio, task
from twisted.python import log, usage
from twisted.python.compat import ioType, networkString
class ClientOptions(ConchOptions):
synopsis = """Usage: conch [options] host [command]
"""
longdesc = (
"conch is a SSHv2 client that allows logging into a remote "
"machine and executing commands."
)
optParameters = [
["escape", "e", "~"],
[
"localforward",
"L",
None,
"listen-port:host:port Forward local port to remote address",
],
[
"remoteforward",
"R",
None,
"listen-port:host:port Forward remote port to local address",
],
]
optFlags = [
["null", "n", "Redirect input from /dev/null."],
["fork", "f", "Fork to background after authentication."],
["tty", "t", "Tty; allocate a tty even if command is given."],
["notty", "T", "Do not allocate a tty."],
["noshell", "N", "Do not execute a shell or command."],
["subsystem", "s", "Invoke command (mandatory) as SSH2 subsystem."],
]
compData = usage.Completions(
mutuallyExclusive=[("tty", "notty")],
optActions={
"localforward": usage.Completer(descr="listen-port:host:port"),
"remoteforward": usage.Completer(descr="listen-port:host:port"),
},
extraActions=[
usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr="argument", repeat=True),
],
)
localForwards: List[Tuple[int, Tuple[int, int]]] = []
remoteForwards: List[Tuple[int, Tuple[int, int]]] = []
def opt_escape(self, esc):
"""
Set escape character; ``none'' = disable
"""
if esc == "none":
self["escape"] = None
elif esc[0] == "^" and len(esc) == 2:
self["escape"] = chr(ord(esc[1]) - 64)
elif len(esc) == 1:
self["escape"] = esc
else:
sys.exit(f"Bad escape character '{esc}'.")
def opt_localforward(self, f):
"""
Forward local port to remote address (lport:host:port)
"""
localPort, remoteHost, remotePort = f.split(":") # Doesn't do v6 yet
localPort = int(localPort)
remotePort = int(remotePort)
self.localForwards.append((localPort, (remoteHost, remotePort)))
def opt_remoteforward(self, f):
"""
Forward remote port to local address (rport:host:port)
"""
remotePort, connHost, connPort = f.split(":") # Doesn't do v6 yet
remotePort = int(remotePort)
connPort = int(connPort)
self.remoteForwards.append((remotePort, (connHost, connPort)))
def parseArgs(self, host, *command):
self["host"] = host
self["command"] = " ".join(command)
# Rest of code in "run"
options = None
conn = None
exitStatus = 0
old = None
_inRawMode = 0
_savedRawMode = None
def run():
global options, old
args = sys.argv[1:]
if "-l" in args: # CVS is an idiot
i = args.index("-l")
args = args[i : i + 2] + args
del args[i + 2 : i + 4]
for arg in args[:]:
try:
i = args.index(arg)
if arg[:2] == "-o" and args[i + 1][0] != "-":
args[i : i + 2] = [] # Suck on it scp
except ValueError:
pass
options = ClientOptions()
try:
options.parseOptions(args)
except usage.UsageError as u:
print(f"ERROR: {u}")
options.opt_help()
sys.exit(1)
if options["log"]:
if options["logfile"]:
if options["logfile"] == "-":
f = sys.stdout
else:
f = open(options["logfile"], "a+")
else:
f = sys.stderr
realout = sys.stdout
log.startLogging(f)
sys.stdout = realout
else:
log.discardLogs()
doConnect()
fd = sys.stdin.fileno()
try:
old = tty.tcgetattr(fd)
except BaseException:
old = None
try:
oldUSR1 = signal.signal(
signal.SIGUSR1, lambda *a: reactor.callLater(0, reConnect)
)
except BaseException:
oldUSR1 = None
try:
reactor.run()
finally:
if old:
tty.tcsetattr(fd, tty.TCSANOW, old)
if oldUSR1:
signal.signal(signal.SIGUSR1, oldUSR1)
if (options["command"] and options["tty"]) or not options["notty"]:
signal.signal(signal.SIGWINCH, signal.SIG_DFL)
if sys.stdout.isatty() and not options["command"]:
print("Connection to {} closed.".format(options["host"]))
sys.exit(exitStatus)
def handleError():
from twisted.python import failure
global exitStatus
exitStatus = 2
reactor.callLater(0.01, _stopReactor)
log.err(failure.Failure())
raise
def _stopReactor():
try:
reactor.stop()
except BaseException:
pass
def doConnect():
if "@" in options["host"]:
options["user"], options["host"] = options["host"].split("@", 1)
if not options.identitys:
options.identitys = ["~/.ssh/id_rsa", "~/.ssh/id_dsa"]
host = options["host"]
if not options["user"]:
options["user"] = getpass.getuser()
if not options["port"]:
options["port"] = 22
else:
options["port"] = int(options["port"])
host = options["host"]
port = options["port"]
vhk = default.verifyHostKey
if not options["host-key-algorithms"]:
options["host-key-algorithms"] = default.getHostKeyAlgorithms(host, options)
uao = default.SSHUserAuthClient(options["user"], options, SSHConnection())
connect.connect(host, port, options, vhk, uao).addErrback(_ebExit)
def _ebExit(f):
global exitStatus
exitStatus = f"conch: exiting with error {f}"
reactor.callLater(0.1, _stopReactor)
def onConnect():
# if keyAgent and options['agent']:
# cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal, conn)
# cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
if hasattr(conn.transport, "sendIgnore"):
_KeepAlive(conn)
if options.localForwards:
for localPort, hostport in options.localForwards:
s = reactor.listenTCP(
localPort,
forwarding.SSHListenForwardingFactory(
conn, hostport, SSHListenClientForwardingChannel
),
)
conn.localForwards.append(s)
if options.remoteForwards:
for remotePort, hostport in options.remoteForwards:
log.msg(f"asking for remote forwarding for {remotePort}:{hostport}")
conn.requestRemoteForwarding(remotePort, hostport)
reactor.addSystemEventTrigger("before", "shutdown", beforeShutdown)
if not options["noshell"] or options["agent"]:
conn.openChannel(SSHSession())
if options["fork"]:
if os.fork():
os._exit(0)
os.setsid()
for i in range(3):
try:
os.close(i)
except OSError as e:
import errno
if e.errno != errno.EBADF:
raise
def reConnect():
beforeShutdown()
conn.transport.transport.loseConnection()
def beforeShutdown():
remoteForwards = options.remoteForwards
for remotePort, hostport in remoteForwards:
log.msg(f"cancelling {remotePort}:{hostport}")
conn.cancelRemoteForwarding(remotePort)
def stopConnection():
if not options["reconnect"]:
reactor.callLater(0.1, _stopReactor)
class _KeepAlive:
def __init__(self, conn):
self.conn = conn
self.globalTimeout = None
self.lc = task.LoopingCall(self.sendGlobal)
self.lc.start(300)
def sendGlobal(self):
d = self.conn.sendGlobalRequest(
b"conch-keep-alive@twistedmatrix.com", b"", wantReply=1
)
d.addBoth(self._cbGlobal)
self.globalTimeout = reactor.callLater(30, self._ebGlobal)
def _cbGlobal(self, res):
if self.globalTimeout:
self.globalTimeout.cancel()
self.globalTimeout = None
def _ebGlobal(self):
if self.globalTimeout:
self.globalTimeout = None
self.conn.transport.loseConnection()
class SSHConnection(connection.SSHConnection):
def serviceStarted(self):
global conn
conn = self
self.localForwards = []
self.remoteForwards = {}
onConnect()
def serviceStopped(self):
lf = self.localForwards
self.localForwards = []
for s in lf:
s.loseConnection()
stopConnection()
def requestRemoteForwarding(self, remotePort, hostport):
data = forwarding.packGlobal_tcpip_forward(("0.0.0.0", remotePort))
d = self.sendGlobalRequest(b"tcpip-forward", data, wantReply=1)
log.msg(f"requesting remote forwarding {remotePort}:{hostport}")
d.addCallback(self._cbRemoteForwarding, remotePort, hostport)
d.addErrback(self._ebRemoteForwarding, remotePort, hostport)
def _cbRemoteForwarding(self, result, remotePort, hostport):
log.msg(f"accepted remote forwarding {remotePort}:{hostport}")
self.remoteForwards[remotePort] = hostport
log.msg(repr(self.remoteForwards))
def _ebRemoteForwarding(self, f, remotePort, hostport):
log.msg(f"remote forwarding {remotePort}:{hostport} failed")
log.msg(f)
def cancelRemoteForwarding(self, remotePort):
data = forwarding.packGlobal_tcpip_forward(("0.0.0.0", remotePort))
self.sendGlobalRequest(b"cancel-tcpip-forward", data)
log.msg(f"cancelling remote forwarding {remotePort}")
try:
del self.remoteForwards[remotePort]
except Exception:
pass
log.msg(repr(self.remoteForwards))
def channel_forwarded_tcpip(self, windowSize, maxPacket, data):
log.msg(f"FTCP {data!r}")
remoteHP, origHP = forwarding.unpackOpen_forwarded_tcpip(data)
log.msg(self.remoteForwards)
log.msg(remoteHP)
if remoteHP[1] in self.remoteForwards:
connectHP = self.remoteForwards[remoteHP[1]]
log.msg(f"connect forwarding {connectHP}")
return SSHConnectForwardingChannel(
connectHP, remoteWindow=windowSize, remoteMaxPacket=maxPacket, conn=self
)
else:
raise ConchError(
connection.OPEN_CONNECT_FAILED, "don't know about that port"
)
def channelClosed(self, channel):
log.msg(f"connection closing {channel}")
log.msg(self.channels)
if len(self.channels) == 1: # Just us left
log.msg("stopping connection")
stopConnection()
else:
# Because of the unix thing
self.__class__.__bases__[0].channelClosed(self, channel)
class SSHSession(channel.SSHChannel):
name = b"session"
def channelOpen(self, foo):
log.msg(f"session {self.id} open")
if options["agent"]:
d = self.conn.sendRequest(
self, b"auth-agent-req@openssh.com", b"", wantReply=1
)
d.addBoth(lambda x: log.msg(x))
if options["noshell"]:
return
if (options["command"] and options["tty"]) or not options["notty"]:
_enterRawMode()
c = session.SSHSessionClient()
if options["escape"] and not options["notty"]:
self.escapeMode = 1
c.dataReceived = self.handleInput
else:
c.dataReceived = self.write
c.connectionLost = lambda x: self.sendEOF()
self.stdio = stdio.StandardIO(c)
fd = 0
if options["subsystem"]:
self.conn.sendRequest(self, b"subsystem", common.NS(options["command"]))
elif options["command"]:
if options["tty"]:
term = os.environ["TERM"]
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, "12345678")
winSize = struct.unpack("4H", winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, "")
self.conn.sendRequest(self, b"pty-req", ptyReqData)
signal.signal(signal.SIGWINCH, self._windowResized)
self.conn.sendRequest(self, b"exec", common.NS(options["command"]))
else:
if not options["notty"]:
term = os.environ["TERM"]
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, "12345678")
winSize = struct.unpack("4H", winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, "")
self.conn.sendRequest(self, b"pty-req", ptyReqData)
signal.signal(signal.SIGWINCH, self._windowResized)
self.conn.sendRequest(self, b"shell", b"")
# if hasattr(conn.transport, 'transport'):
# conn.transport.transport.setTcpNoDelay(1)
def handleInput(self, char):
if char in (b"\n", b"\r"):
self.escapeMode = 1
self.write(char)
elif self.escapeMode == 1 and char == options["escape"]:
self.escapeMode = 2
elif self.escapeMode == 2:
self.escapeMode = 1 # So we can chain escapes together
if char == b".": # Disconnect
log.msg("disconnecting from escape")
stopConnection()
return
elif char == b"\x1a": # ^Z, suspend
def _():
_leaveRawMode()
sys.stdout.flush()
sys.stdin.flush()
os.kill(os.getpid(), signal.SIGTSTP)
_enterRawMode()
reactor.callLater(0, _)
return
elif char == b"R": # Rekey connection
log.msg("rekeying connection")
self.conn.transport.sendKexInit()
return
elif char == b"#": # Display connections
self.stdio.write(b"\r\nThe following connections are open:\r\n")
                # dict views have no sort() method on Python 3, so sort a copy.
                channels = sorted(self.conn.channels.keys())
for channelId in channels:
self.stdio.write(
networkString(
" #{} {}\r\n".format(
channelId, self.conn.channels[channelId]
)
)
)
return
self.write(b"~" + char)
else:
self.escapeMode = 0
self.write(char)
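    # Illustrative note (not part of the original module, added for clarity):
    # after a newline, <escape>. disconnects, <escape>^Z suspends the local
    # client, <escape>R rekeys the connection, and <escape># lists the open
    # channels; any other character following the escape is sent prefixed
    # with "~".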
def dataReceived(self, data):
self.stdio.write(data)
def extReceived(self, t, data):
if t == connection.EXTENDED_DATA_STDERR:
log.msg(f"got {len(data)} stderr data")
if ioType(sys.stderr) == str:
sys.stderr.buffer.write(data)
else:
sys.stderr.write(data)
def eofReceived(self):
log.msg("got eof")
self.stdio.loseWriteConnection()
def closeReceived(self):
log.msg(f"remote side closed {self}")
self.conn.sendClose(self)
def closed(self):
global old
log.msg(f"closed {self}")
log.msg(repr(self.conn.channels))
def request_exit_status(self, data):
global exitStatus
exitStatus = int(struct.unpack(">L", data)[0])
log.msg(f"exit status: {exitStatus}")
def sendEOF(self):
self.conn.sendEOF(self)
def stopWriting(self):
self.stdio.pauseProducing()
def startWriting(self):
self.stdio.resumeProducing()
def _windowResized(self, *args):
winsz = fcntl.ioctl(0, tty.TIOCGWINSZ, "12345678")
winSize = struct.unpack("4H", winsz)
newSize = winSize[1], winSize[0], winSize[2], winSize[3]
self.conn.sendRequest(self, b"window-change", struct.pack("!4L", *newSize))
class SSHListenClientForwardingChannel(forwarding.SSHListenClientForwardingChannel):
pass
class SSHConnectForwardingChannel(forwarding.SSHConnectForwardingChannel):
pass
def _leaveRawMode():
global _inRawMode
if not _inRawMode:
return
fd = sys.stdin.fileno()
tty.tcsetattr(fd, tty.TCSANOW, _savedRawMode)
_inRawMode = 0
def _enterRawMode():
global _inRawMode, _savedRawMode
if _inRawMode:
return
fd = sys.stdin.fileno()
try:
old = tty.tcgetattr(fd)
new = old[:]
except BaseException:
log.msg("not a typewriter!")
else:
        # iflag
new[0] = new[0] | tty.IGNPAR
new[0] = new[0] & ~(
tty.ISTRIP
| tty.INLCR
| tty.IGNCR
| tty.ICRNL
| tty.IXON
| tty.IXANY
| tty.IXOFF
)
if hasattr(tty, "IUCLC"):
new[0] = new[0] & ~tty.IUCLC
# lflag
new[3] = new[3] & ~(
tty.ISIG
| tty.ICANON
            | tty.ECHO
| tty.ECHOE
| tty.ECHOK
| tty.ECHONL
)
if hasattr(tty, "IEXTEN"):
new[3] = new[3] & ~tty.IEXTEN
# oflag
new[1] = new[1] & ~tty.OPOST
new[6][tty.VMIN] = 1
new[6][tty.VTIME] = 0
_savedRawMode = old
tty.tcsetattr(fd, tty.TCSANOW, new)
# tty.setraw(fd)
_inRawMode = 1
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,673 @@
# -*- test-case-name: twisted.conch.test.test_scripts -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation module for the `tkconch` command.
"""
import base64
import getpass
import os
import signal
import struct
import sys
import tkinter as Tkinter
import tkinter.filedialog as tkFileDialog
import tkinter.messagebox as tkMessageBox
from typing import List, Tuple
from twisted.conch import error
from twisted.conch.client.default import isInKnownHosts
from twisted.conch.ssh import (
channel,
common,
connection,
forwarding,
keys,
session,
transport,
userauth,
)
from twisted.conch.ui import tkvt100
from twisted.internet import defer, protocol, reactor, tksupport
from twisted.python import log, usage
class TkConchMenu(Tkinter.Frame):
def __init__(self, *args, **params):
## Standard heading: initialization
Tkinter.Frame.__init__(self, *args, **params)
self.master.title("TkConch")
self.localRemoteVar = Tkinter.StringVar()
self.localRemoteVar.set("local")
Tkinter.Label(self, anchor="w", justify="left", text="Hostname").grid(
column=1, row=1, sticky="w"
)
self.host = Tkinter.Entry(self)
self.host.grid(column=2, columnspan=2, row=1, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="Port").grid(
column=1, row=2, sticky="w"
)
self.port = Tkinter.Entry(self)
self.port.grid(column=2, columnspan=2, row=2, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="Username").grid(
column=1, row=3, sticky="w"
)
self.user = Tkinter.Entry(self)
self.user.grid(column=2, columnspan=2, row=3, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="Command").grid(
column=1, row=4, sticky="w"
)
self.command = Tkinter.Entry(self)
self.command.grid(column=2, columnspan=2, row=4, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="Identity").grid(
column=1, row=5, sticky="w"
)
self.identity = Tkinter.Entry(self)
self.identity.grid(column=2, row=5, sticky="nesw")
Tkinter.Button(self, command=self.getIdentityFile, text="Browse").grid(
column=3, row=5, sticky="nesw"
)
Tkinter.Label(self, text="Port Forwarding").grid(column=1, row=6, sticky="w")
self.forwards = Tkinter.Listbox(self, height=0, width=0)
self.forwards.grid(column=2, columnspan=2, row=6, sticky="nesw")
Tkinter.Button(self, text="Add", command=self.addForward).grid(column=1, row=7)
Tkinter.Button(self, text="Remove", command=self.removeForward).grid(
column=1, row=8
)
self.forwardPort = Tkinter.Entry(self)
self.forwardPort.grid(column=2, row=7, sticky="nesw")
Tkinter.Label(self, text="Port").grid(column=3, row=7, sticky="nesw")
self.forwardHost = Tkinter.Entry(self)
self.forwardHost.grid(column=2, row=8, sticky="nesw")
Tkinter.Label(self, text="Host").grid(column=3, row=8, sticky="nesw")
self.localForward = Tkinter.Radiobutton(
self, text="Local", variable=self.localRemoteVar, value="local"
)
self.localForward.grid(column=2, row=9)
self.remoteForward = Tkinter.Radiobutton(
self, text="Remote", variable=self.localRemoteVar, value="remote"
)
self.remoteForward.grid(column=3, row=9)
Tkinter.Label(self, text="Advanced Options").grid(
column=1, columnspan=3, row=10, sticky="nesw"
)
Tkinter.Label(self, anchor="w", justify="left", text="Cipher").grid(
column=1, row=11, sticky="w"
)
self.cipher = Tkinter.Entry(self, name="cipher")
self.cipher.grid(column=2, columnspan=2, row=11, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="MAC").grid(
column=1, row=12, sticky="w"
)
self.mac = Tkinter.Entry(self, name="mac")
self.mac.grid(column=2, columnspan=2, row=12, sticky="nesw")
Tkinter.Label(self, anchor="w", justify="left", text="Escape Char").grid(
column=1, row=13, sticky="w"
)
self.escape = Tkinter.Entry(self, name="escape")
self.escape.grid(column=2, columnspan=2, row=13, sticky="nesw")
Tkinter.Button(self, text="Connect!", command=self.doConnect).grid(
column=1, columnspan=3, row=14, sticky="nesw"
)
# Resize behavior(s)
self.grid_rowconfigure(6, weight=1, minsize=64)
self.grid_columnconfigure(2, weight=1, minsize=2)
self.master.protocol("WM_DELETE_WINDOW", sys.exit)
def getIdentityFile(self):
r = tkFileDialog.askopenfilename()
if r:
self.identity.delete(0, Tkinter.END)
self.identity.insert(Tkinter.END, r)
def addForward(self):
port = self.forwardPort.get()
self.forwardPort.delete(0, Tkinter.END)
host = self.forwardHost.get()
self.forwardHost.delete(0, Tkinter.END)
if self.localRemoteVar.get() == "local":
self.forwards.insert(Tkinter.END, f"L:{port}:{host}")
else:
self.forwards.insert(Tkinter.END, f"R:{port}:{host}")
def removeForward(self):
cur = self.forwards.curselection()
if cur:
            self.forwards.delete(cur[0])
def doConnect(self):
finished = 1
options["host"] = self.host.get()
options["port"] = self.port.get()
options["user"] = self.user.get()
options["command"] = self.command.get()
cipher = self.cipher.get()
mac = self.mac.get()
escape = self.escape.get()
if cipher:
if cipher in SSHClientTransport.supportedCiphers:
SSHClientTransport.supportedCiphers = [cipher]
else:
tkMessageBox.showerror("TkConch", "Bad cipher.")
finished = 0
if mac:
if mac in SSHClientTransport.supportedMACs:
SSHClientTransport.supportedMACs = [mac]
elif finished:
tkMessageBox.showerror("TkConch", "Bad MAC.")
finished = 0
if escape:
if escape == "none":
options["escape"] = None
elif escape[0] == "^" and len(escape) == 2:
options["escape"] = chr(ord(escape[1]) - 64)
elif len(escape) == 1:
options["escape"] = escape
elif finished:
tkMessageBox.showerror("TkConch", "Bad escape character '%s'." % escape)
finished = 0
if self.identity.get():
options.identitys.append(self.identity.get())
for line in self.forwards.get(0, Tkinter.END):
if line[0] == "L":
options.opt_localforward(line[2:])
else:
options.opt_remoteforward(line[2:])
if "@" in options["host"]:
options["user"], options["host"] = options["host"].split("@", 1)
if (not options["host"] or not options["user"]) and finished:
tkMessageBox.showerror("TkConch", "Missing host or username.")
finished = 0
if finished:
self.master.quit()
self.master.destroy()
if options["log"]:
realout = sys.stdout
log.startLogging(sys.stderr)
sys.stdout = realout
else:
log.discardLogs()
log.deferr = handleError # HACK
if not options.identitys:
options.identitys = ["~/.ssh/id_rsa", "~/.ssh/id_dsa"]
host = options["host"]
port = int(options["port"] or 22)
log.msg((host, port))
reactor.connectTCP(host, port, SSHClientFactory())
frame.master.deiconify()
frame.master.title(
"{}@{} - TkConch".format(options["user"], options["host"])
)
else:
self.focus()
class GeneralOptions(usage.Options):
synopsis = """Usage: tkconch [options] host [command]
"""
optParameters = [
["user", "l", None, "Log in using this user name."],
["identity", "i", "~/.ssh/identity", "Identity for public key authentication"],
["escape", "e", "~", "Set escape character; ``none'' = disable"],
["cipher", "c", None, "Select encryption algorithm."],
["macs", "m", None, "Specify MAC algorithms for protocol version 2."],
["port", "p", None, "Connect to this port. Server must be on the same port."],
[
"localforward",
"L",
None,
"listen-port:host:port Forward local port to remote address",
],
[
"remoteforward",
"R",
None,
"listen-port:host:port Forward remote port to local address",
],
]
optFlags = [
["tty", "t", "Tty; allocate a tty even if command is given."],
["notty", "T", "Do not allocate a tty."],
["version", "V", "Display version number only."],
["compress", "C", "Enable compression."],
["noshell", "N", "Do not execute a shell or command."],
["subsystem", "s", "Invoke command (mandatory) as SSH2 subsystem."],
["log", "v", "Log to stderr"],
["ansilog", "a", "Print the received data to stdout"],
]
_ciphers = transport.SSHClientTransport.supportedCiphers
_macs = transport.SSHClientTransport.supportedMACs
compData = usage.Completions(
mutuallyExclusive=[("tty", "notty")],
optActions={
"cipher": usage.CompleteList([v.decode() for v in _ciphers]),
"macs": usage.CompleteList([v.decode() for v in _macs]),
"localforward": usage.Completer(descr="listen-port:host:port"),
"remoteforward": usage.Completer(descr="listen-port:host:port"),
},
extraActions=[
usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr="argument", repeat=True),
],
)
identitys: List[str] = []
localForwards: List[Tuple[int, Tuple[int, int]]] = []
remoteForwards: List[Tuple[int, Tuple[int, int]]] = []
def opt_identity(self, i):
self.identitys.append(i)
def opt_localforward(self, f):
localPort, remoteHost, remotePort = f.split(":") # doesn't do v6 yet
localPort = int(localPort)
remotePort = int(remotePort)
self.localForwards.append((localPort, (remoteHost, remotePort)))
def opt_remoteforward(self, f):
remotePort, connHost, connPort = f.split(":") # doesn't do v6 yet
remotePort = int(remotePort)
connPort = int(connPort)
self.remoteForwards.append((remotePort, (connHost, connPort)))
def opt_compress(self):
SSHClientTransport.supportedCompressions[0:1] = ["zlib"]
def parseArgs(self, *args):
if args:
self["host"] = args[0]
self["command"] = " ".join(args[1:])
else:
self["host"] = ""
self["command"] = ""
# Rest of code in "run"
options = None
menu = None
exitStatus = 0
frame = None
def deferredAskFrame(question, echo):
if frame.callback:
raise ValueError("can't ask 2 questions at once!")
d = defer.Deferred()
resp = []
def gotChar(ch, resp=resp):
if not ch:
return
if ch == "\x03": # C-c
reactor.stop()
if ch == "\r":
frame.write("\r\n")
stresp = "".join(resp)
del resp
frame.callback = None
d.callback(stresp)
return
elif 32 <= ord(ch) < 127:
resp.append(ch)
if echo:
frame.write(ch)
elif ord(ch) == 8 and resp: # BS
if echo:
frame.write("\x08 \x08")
resp.pop()
frame.callback = gotChar
frame.write(question)
frame.canvas.focus_force()
return d
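# Illustrative note (not part of the original module, added for clarity):
# deferredAskFrame collects keystrokes from the VT100 frame one character at
# a time: printable characters are appended (and echoed only when echo is
# true), backspace removes the last character, and Enter fires the returned
# Deferred with the accumulated string.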
def run():
global menu, options, frame
args = sys.argv[1:]
if "-l" in args: # cvs is an idiot
i = args.index("-l")
args = args[i : i + 2] + args
del args[i + 2 : i + 4]
for arg in args[:]:
try:
i = args.index(arg)
if arg[:2] == "-o" and args[i + 1][0] != "-":
args[i : i + 2] = [] # suck on it scp
except ValueError:
pass
root = Tkinter.Tk()
root.withdraw()
top = Tkinter.Toplevel()
menu = TkConchMenu(top)
menu.pack(side=Tkinter.TOP, fill=Tkinter.BOTH, expand=1)
options = GeneralOptions()
try:
options.parseOptions(args)
except usage.UsageError as u:
print("ERROR: %s" % u)
options.opt_help()
sys.exit(1)
for k, v in options.items():
if v and hasattr(menu, k):
getattr(menu, k).insert(Tkinter.END, v)
for p, (rh, rp) in options.localForwards:
menu.forwards.insert(Tkinter.END, f"L:{p}:{rh}:{rp}")
options.localForwards = []
for p, (rh, rp) in options.remoteForwards:
menu.forwards.insert(Tkinter.END, f"R:{p}:{rh}:{rp}")
options.remoteForwards = []
frame = tkvt100.VT100Frame(root, callback=None)
root.geometry(
"%dx%d"
% (tkvt100.fontWidth * frame.width + 3, tkvt100.fontHeight * frame.height + 3)
)
frame.pack(side=Tkinter.TOP)
tksupport.install(root)
root.withdraw()
if (options["host"] and options["user"]) or "@" in options["host"]:
menu.doConnect()
else:
top.mainloop()
reactor.run()
sys.exit(exitStatus)
def handleError():
from twisted.python import failure
global exitStatus
exitStatus = 2
log.err(failure.Failure())
reactor.stop()
raise
class SSHClientFactory(protocol.ClientFactory):
noisy = True
def stopFactory(self):
reactor.stop()
def buildProtocol(self, addr):
return SSHClientTransport()
def clientConnectionFailed(self, connector, reason):
tkMessageBox.showwarning(
"TkConch",
f"Connection Failed, Reason:\n {reason.type}: {reason.value}",
)
class SSHClientTransport(transport.SSHClientTransport):
def receiveError(self, code, desc):
global exitStatus
exitStatus = (
"conch:\tRemote side disconnected with error code %i\nconch:\treason: %s"
% (code, desc)
)
def sendDisconnect(self, code, reason):
global exitStatus
exitStatus = (
"conch:\tSending disconnect with error code %i\nconch:\treason: %s"
% (code, reason)
)
transport.SSHClientTransport.sendDisconnect(self, code, reason)
def receiveDebug(self, alwaysDisplay, message, lang):
global options
if alwaysDisplay or options["log"]:
log.msg("Received Debug Message: %s" % message)
def verifyHostKey(self, pubKey, fingerprint):
# d = defer.Deferred()
# d.addCallback(lambda x:defer.succeed(1))
# d.callback(2)
# return d
goodKey = isInKnownHosts(options["host"], pubKey, {"known-hosts": None})
if goodKey == 1: # good key
return defer.succeed(1)
elif goodKey == 2: # AAHHHHH changed
return defer.fail(error.ConchError("bad host key"))
else:
if options["host"] == self.transport.getPeer().host:
host = options["host"]
khHost = options["host"]
else:
host = "{} ({})".format(options["host"], self.transport.getPeer().host)
khHost = "{},{}".format(options["host"], self.transport.getPeer().host)
keyType = common.getNS(pubKey)[0]
ques = """The authenticity of host '{}' can't be established.\r
{} key fingerprint is {}.""".format(
host,
{b"ssh-dss": "DSA", b"ssh-rsa": "RSA"}[keyType],
fingerprint,
)
ques += "\r\nAre you sure you want to continue connecting (yes/no)? "
return deferredAskFrame(ques, 1).addCallback(
self._cbVerifyHostKey, pubKey, khHost, keyType
)
def _cbVerifyHostKey(self, ans, pubKey, khHost, keyType):
if ans.lower() not in ("yes", "no"):
return deferredAskFrame("Please type 'yes' or 'no': ", 1).addCallback(
self._cbVerifyHostKey, pubKey, khHost, keyType
)
if ans.lower() == "no":
frame.write("Host key verification failed.\r\n")
raise error.ConchError("bad host key")
try:
frame.write(
"Warning: Permanently added '%s' (%s) to the list of "
"known hosts.\r\n"
% (khHost, {b"ssh-dss": "DSA", b"ssh-rsa": "RSA"}[keyType])
)
with open(os.path.expanduser("~/.ssh/known_hosts"), "a") as known_hosts:
encodedKey = base64.b64encode(pubKey)
                known_hosts.write(
                    f"\n{khHost} {keyType.decode()} {encodedKey.decode()}"
                )
except BaseException:
log.deferr()
raise error.ConchError
def connectionSecure(self):
if options["user"]:
user = options["user"]
else:
user = getpass.getuser()
self.requestService(SSHUserAuthClient(user, SSHConnection()))
class SSHUserAuthClient(userauth.SSHUserAuthClient):
usedFiles: List[str] = []
def getPassword(self, prompt=None):
if not prompt:
prompt = "{}@{}'s password: ".format(self.user, options["host"])
return deferredAskFrame(prompt, 0)
def getPublicKey(self):
files = [x for x in options.identitys if x not in self.usedFiles]
if not files:
return None
file = files[0]
log.msg(file)
self.usedFiles.append(file)
file = os.path.expanduser(file)
file += ".pub"
if not os.path.exists(file):
return
try:
return keys.Key.fromFile(file).blob()
except BaseException:
return self.getPublicKey() # try again
def getPrivateKey(self):
file = os.path.expanduser(self.usedFiles[-1])
if not os.path.exists(file):
return None
try:
return defer.succeed(keys.Key.fromFile(file).keyObject)
except keys.BadKeyError as e:
if e.args[0] == "encrypted key with no password":
prompt = "Enter passphrase for key '%s': " % self.usedFiles[-1]
return deferredAskFrame(prompt, 0).addCallback(self._cbGetPrivateKey, 0)
def _cbGetPrivateKey(self, ans, count):
file = os.path.expanduser(self.usedFiles[-1])
try:
return keys.Key.fromFile(file, password=ans).keyObject
except keys.BadKeyError:
if count == 2:
raise
prompt = "Enter passphrase for key '%s': " % self.usedFiles[-1]
return deferredAskFrame(prompt, 0).addCallback(
self._cbGetPrivateKey, count + 1
)
class SSHConnection(connection.SSHConnection):
def serviceStarted(self):
if not options["noshell"]:
self.openChannel(SSHSession())
if options.localForwards:
for localPort, hostport in options.localForwards:
reactor.listenTCP(
localPort,
forwarding.SSHListenForwardingFactory(
self, hostport, forwarding.SSHListenClientForwardingChannel
),
)
if options.remoteForwards:
for remotePort, hostport in options.remoteForwards:
log.msg(
"asking for remote forwarding for {}:{}".format(
remotePort, hostport
)
)
data = forwarding.packGlobal_tcpip_forward(("0.0.0.0", remotePort))
self.sendGlobalRequest("tcpip-forward", data)
self.remoteForwards[remotePort] = hostport
class SSHSession(channel.SSHChannel):
name = b"session"
def channelOpen(self, foo):
# global globalSession
# globalSession = self
# turn off local echo
self.escapeMode = 1
c = session.SSHSessionClient()
if options["escape"]:
c.dataReceived = self.handleInput
else:
c.dataReceived = self.write
c.connectionLost = self.sendEOF
frame.callback = c.dataReceived
frame.canvas.focus_force()
if options["subsystem"]:
self.conn.sendRequest(self, b"subsystem", common.NS(options["command"]))
elif options["command"]:
if options["tty"]:
term = os.environ.get("TERM", "xterm")
# winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = (25, 80, 0, 0) # struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, "")
self.conn.sendRequest(self, b"pty-req", ptyReqData)
self.conn.sendRequest(self, "exec", common.NS(options["command"]))
else:
if not options["notty"]:
term = os.environ.get("TERM", "xterm")
# winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = (25, 80, 0, 0) # struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, "")
self.conn.sendRequest(self, b"pty-req", ptyReqData)
self.conn.sendRequest(self, b"shell", b"")
self.conn.transport.transport.setTcpNoDelay(1)
def handleInput(self, char):
# log.msg('handling %s' % repr(char))
if char in ("\n", "\r"):
self.escapeMode = 1
self.write(char)
elif self.escapeMode == 1 and char == options["escape"]:
self.escapeMode = 2
elif self.escapeMode == 2:
self.escapeMode = 1 # so we can chain escapes together
if char == ".": # disconnect
log.msg("disconnecting from escape")
reactor.stop()
return
elif char == "\x1a": # ^Z, suspend
# following line courtesy of Erwin@freenode
os.kill(os.getpid(), signal.SIGSTOP)
return
elif char == "R": # rekey connection
log.msg("rekeying connection")
self.conn.transport.sendKexInit()
return
self.write("~" + char)
else:
self.escapeMode = 0
self.write(char)
def dataReceived(self, data):
data = data.decode("utf-8")
if options["ansilog"]:
print(repr(data))
frame.write(data)
def extReceived(self, t, data):
if t == connection.EXTENDED_DATA_STDERR:
log.msg("got %s stderr data" % len(data))
sys.stderr.write(data)
sys.stderr.flush()
def eofReceived(self):
log.msg("got eof")
sys.stdin.close()
def closed(self):
log.msg("closed %s" % self)
if len(self.conn.channels) == 1: # just us left
reactor.stop()
def request_exit_status(self, data):
global exitStatus
exitStatus = int(struct.unpack(">L", data)[0])
log.msg("exit status: %s" % exitStatus)
def sendEOF(self):
self.conn.sendEOF(self)
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,10 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
An SSHv2 implementation for Twisted. Part of the Twisted.Conch package.
Maintainer: Paul Swartz
"""

View File

@@ -0,0 +1,293 @@
# -*- test-case-name: twisted.conch.test.test_transport -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
SSH key exchange handling.
"""
from hashlib import sha1, sha256, sha384, sha512
from zope.interface import Attribute, Interface, implementer
from twisted.conch import error
class _IKexAlgorithm(Interface):
"""
An L{_IKexAlgorithm} describes a key exchange algorithm.
"""
preference = Attribute(
"An L{int} giving the preference of the algorithm when negotiating "
"key exchange. Algorithms with lower precedence values are more "
"preferred."
)
hashProcessor = Attribute(
"A callable hash algorithm constructor (e.g. C{hashlib.sha256}) "
"suitable for use with this key exchange algorithm."
)
class _IFixedGroupKexAlgorithm(_IKexAlgorithm):
"""
An L{_IFixedGroupKexAlgorithm} describes a key exchange algorithm with a
fixed prime / generator group.
"""
prime = Attribute(
"An L{int} giving the prime number used in Diffie-Hellman key "
"exchange, or L{None} if not applicable."
)
generator = Attribute(
"An L{int} giving the generator number used in Diffie-Hellman key "
"exchange, or L{None} if not applicable. (This is not related to "
"Python generator functions.)"
)
class _IEllipticCurveExchangeKexAlgorithm(_IKexAlgorithm):
"""
An L{_IEllipticCurveExchangeKexAlgorithm} describes a key exchange algorithm
that uses an elliptic curve exchange between the client and server.
"""
class _IGroupExchangeKexAlgorithm(_IKexAlgorithm):
"""
An L{_IGroupExchangeKexAlgorithm} describes a key exchange algorithm
that uses group exchange between the client and server.
A prime / generator group should be chosen at run time based on the
requested size. See RFC 4419.
"""
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _Curve25519SHA256:
"""
Elliptic Curve Key Exchange using Curve25519 and SHA256. Defined in
U{https://datatracker.ietf.org/doc/draft-ietf-curdle-ssh-curves/}.
"""
preference = 1
hashProcessor = sha256
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _Curve25519SHA256LibSSH:
"""
As L{_Curve25519SHA256}, but with a pre-standardized algorithm name.
"""
preference = 2
hashProcessor = sha256
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _ECDH256:
"""
Elliptic Curve Key Exchange with SHA-256 as HASH. Defined in
RFC 5656.
    Note that C{ecdh-sha2-nistp256} takes priority over C{nistp384} and
    C{nistp521}; this matches the priority used by OpenSSH.
    C{ecdh-sha2-nistp256} is considered pretty good cryptography; if you need
    something stronger, consider using C{curve25519-sha256}.
"""
preference = 3
hashProcessor = sha256
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _ECDH384:
"""
Elliptic Curve Key Exchange with SHA-384 as HASH. Defined in
RFC 5656.
"""
preference = 4
hashProcessor = sha384
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _ECDH512:
"""
Elliptic Curve Key Exchange with SHA-512 as HASH. Defined in
RFC 5656.
"""
preference = 5
hashProcessor = sha512
@implementer(_IGroupExchangeKexAlgorithm)
class _DHGroupExchangeSHA256:
"""
Diffie-Hellman Group and Key Exchange with SHA-256 as HASH. Defined in
RFC 4419, 4.2.
"""
preference = 6
hashProcessor = sha256
@implementer(_IGroupExchangeKexAlgorithm)
class _DHGroupExchangeSHA1:
"""
Diffie-Hellman Group and Key Exchange with SHA-1 as HASH. Defined in
RFC 4419, 4.1.
"""
preference = 7
hashProcessor = sha1
@implementer(_IFixedGroupKexAlgorithm)
class _DHGroup14SHA1:
"""
Diffie-Hellman key exchange with SHA-1 as HASH and Oakley Group 14
(2048-bit MODP Group). Defined in RFC 4253, 8.2.
"""
preference = 8
hashProcessor = sha1
# Diffie-Hellman primes from Oakley Group 14 (RFC 3526, 3).
prime = int(
"323170060713110073003389139264238282488179412411402391128420"
"097514007417066343542226196894173635693471179017379097041917"
"546058732091950288537589861856221532121754125149017745202702"
"357960782362488842461894775876411059286460994117232454266225"
"221932305409190376805242355191256797158701170010580558776510"
"388618472802579760549035697325615261670813393617995413364765"
"591603683178967290731783845896806396719009772021941686472258"
"710314113364293195361934716365332097170774482279885885653692"
"086452966360772502689555059283627511211740969729980684105543"
"595848665832916421362182310789909994486524682624169720359118"
"52507045361090559"
)
generator = 2
# Which ECDH hash function to use is dependent on the size.
_kexAlgorithms = {
b"curve25519-sha256": _Curve25519SHA256(),
b"curve25519-sha256@libssh.org": _Curve25519SHA256LibSSH(),
b"diffie-hellman-group-exchange-sha256": _DHGroupExchangeSHA256(),
b"diffie-hellman-group-exchange-sha1": _DHGroupExchangeSHA1(),
b"diffie-hellman-group14-sha1": _DHGroup14SHA1(),
b"ecdh-sha2-nistp256": _ECDH256(),
b"ecdh-sha2-nistp384": _ECDH384(),
b"ecdh-sha2-nistp521": _ECDH512(),
}
def getKex(kexAlgorithm):
"""
Get a description of a named key exchange algorithm.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: A description of the key exchange algorithm named by
C{kexAlgorithm}.
@rtype: L{_IKexAlgorithm}
@raises ConchError: if the key exchange algorithm is not found.
"""
if kexAlgorithm not in _kexAlgorithms:
raise error.ConchError(f"Unsupported key exchange algorithm: {kexAlgorithm}")
return _kexAlgorithms[kexAlgorithm]
def isEllipticCurve(kexAlgorithm):
"""
Returns C{True} if C{kexAlgorithm} is an elliptic curve.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: C{str}
@return: C{True} if C{kexAlgorithm} is an elliptic curve,
otherwise C{False}.
@rtype: C{bool}
"""
return _IEllipticCurveExchangeKexAlgorithm.providedBy(getKex(kexAlgorithm))
def isFixedGroup(kexAlgorithm):
"""
Returns C{True} if C{kexAlgorithm} has a fixed prime / generator group.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: C{True} if C{kexAlgorithm} has a fixed prime / generator group,
otherwise C{False}.
@rtype: L{bool}
"""
return _IFixedGroupKexAlgorithm.providedBy(getKex(kexAlgorithm))
def getHashProcessor(kexAlgorithm):
"""
Get the hash algorithm callable to use in key exchange.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: A callable hash algorithm constructor (e.g. C{hashlib.sha256}).
@rtype: C{callable}
"""
kex = getKex(kexAlgorithm)
return kex.hashProcessor
def getDHGeneratorAndPrime(kexAlgorithm):
"""
Get the generator and the prime to use in key exchange.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: A L{tuple} containing L{int} generator and L{int} prime.
@rtype: L{tuple}
"""
kex = getKex(kexAlgorithm)
return kex.generator, kex.prime
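# Illustrative usage sketch (not part of the original module, added for
# clarity): resolving a key exchange name to its hash constructor and, for a
# fixed-group algorithm, its Diffie-Hellman parameters.
def _exampleKexLookup():
    hashProcessor = getHashProcessor(b"diffie-hellman-group14-sha1")  # sha1
    generator, prime = getDHGeneratorAndPrime(b"diffie-hellman-group14-sha1")
    return hashProcessor, generator, prime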
def getSupportedKeyExchanges():
"""
Get a list of supported key exchange algorithm names in order of
preference.
@return: A C{list} of supported key exchange algorithm names.
@rtype: C{list} of L{bytes}
"""
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
from twisted.conch.ssh.keys import _curveTable
backend = default_backend()
kexAlgorithms = _kexAlgorithms.copy()
for keyAlgorithm in list(kexAlgorithms):
if keyAlgorithm.startswith(b"ecdh"):
keyAlgorithmDsa = keyAlgorithm.replace(b"ecdh", b"ecdsa")
supported = backend.elliptic_curve_exchange_algorithm_supported(
ec.ECDH(), _curveTable[keyAlgorithmDsa]
)
elif keyAlgorithm.startswith(b"curve25519-sha256"):
supported = backend.x25519_supported()
else:
supported = True
if not supported:
kexAlgorithms.pop(keyAlgorithm)
return sorted(
kexAlgorithms, key=lambda kexAlgorithm: kexAlgorithms[kexAlgorithm].preference
)

View File

@@ -0,0 +1,43 @@
# -*- test-case-name: twisted.conch.test.test_address -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Address object for SSH network connections.
Maintainer: Paul Swartz
@since: 12.1
"""
from zope.interface import implementer
from twisted.internet.interfaces import IAddress
from twisted.python import util
@implementer(IAddress)
class SSHTransportAddress(util.FancyEqMixin):
"""
Object representing an SSH Transport endpoint.
This is used to ensure that any code inspecting this address and
attempting to construct a similar connection based upon it is not
    misled into creating a transport which is not similar to the one it is
indicating.
@ivar address: An instance of an object which implements I{IAddress} to
which this transport address is connected.
"""
compareAttributes = ("address",)
def __init__(self, address):
self.address = address
def __repr__(self) -> str:
return f"SSHTransportAddress({self.address!r})"
def __hash__(self):
return hash(("SSH", self.address))
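# Illustrative usage sketch (not part of the original module, added for
# clarity): wrapping the peer address of a TCP transport, as conch does when
# reporting the address of an SSH connection.
def _exampleTransportAddress():
    from twisted.internet.address import IPv4Address
    return SSHTransportAddress(IPv4Address("TCP", "127.0.0.1", 22))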

View File

@@ -0,0 +1,278 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implements the SSH v2 key agent protocol. This protocol is documented in the
SSH source code, in the file
U{PROTOCOL.agent<http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent>}.
Maintainer: Paul Swartz
"""
import struct
from twisted.conch.error import ConchError, MissingKeyStoreError
from twisted.conch.ssh import keys
from twisted.conch.ssh.common import NS, getMP, getNS
from twisted.internet import defer, protocol
class SSHAgentClient(protocol.Protocol):
"""
The client side of the SSH agent protocol. This is equivalent to
ssh-add(1) and can be used with either ssh-agent(1) or the SSHAgentServer
protocol, also in this package.
"""
def __init__(self):
self.buf = b""
self.deferreds = []
def dataReceived(self, data):
self.buf += data
while 1:
if len(self.buf) <= 4:
return
packLen = struct.unpack("!L", self.buf[:4])[0]
if len(self.buf) < 4 + packLen:
return
packet, self.buf = self.buf[4 : 4 + packLen], self.buf[4 + packLen :]
reqType = ord(packet[0:1])
d = self.deferreds.pop(0)
if reqType == AGENT_FAILURE:
d.errback(ConchError("agent failure"))
elif reqType == AGENT_SUCCESS:
d.callback(b"")
else:
d.callback(packet)
def sendRequest(self, reqType, data):
pack = struct.pack("!LB", len(data) + 1, reqType) + data
self.transport.write(pack)
d = defer.Deferred()
self.deferreds.append(d)
return d
def requestIdentities(self):
"""
@return: A L{Deferred} which will fire with a list of all keys found in
the SSH agent. The list of keys is comprised of (public key blob,
comment) tuples.
"""
d = self.sendRequest(AGENTC_REQUEST_IDENTITIES, b"")
d.addCallback(self._cbRequestIdentities)
return d
def _cbRequestIdentities(self, data):
"""
Unpack a collection of identities into a list of tuples comprised of
public key blobs and comments.
"""
if ord(data[0:1]) != AGENT_IDENTITIES_ANSWER:
raise ConchError("unexpected response: %i" % ord(data[0:1]))
numKeys = struct.unpack("!L", data[1:5])[0]
result = []
data = data[5:]
for i in range(numKeys):
blob, data = getNS(data)
comment, data = getNS(data)
result.append((blob, comment))
return result
def addIdentity(self, blob, comment=b""):
"""
Add a private key blob to the agent's collection of keys.
"""
req = blob
req += NS(comment)
return self.sendRequest(AGENTC_ADD_IDENTITY, req)
def signData(self, blob, data):
"""
Request that the agent sign the given C{data} with the private key
which corresponds to the public key given by C{blob}. The private
key should have been added to the agent already.
@type blob: L{bytes}
@type data: L{bytes}
@return: A L{Deferred} which fires with a signature for given data
created with the given key.
"""
req = NS(blob)
req += NS(data)
req += b"\000\000\000\000" # flags
return self.sendRequest(AGENTC_SIGN_REQUEST, req).addCallback(self._cbSignData)
def _cbSignData(self, data):
if ord(data[0:1]) != AGENT_SIGN_RESPONSE:
raise ConchError("unexpected data: %i" % ord(data[0:1]))
signature = getNS(data[1:])[0]
return signature
def removeIdentity(self, blob):
"""
Remove the private key corresponding to the public key in blob from the
running agent.
"""
req = NS(blob)
return self.sendRequest(AGENTC_REMOVE_IDENTITY, req)
def removeAllIdentities(self):
"""
Remove all keys from the running agent.
"""
return self.sendRequest(AGENTC_REMOVE_ALL_IDENTITIES, b"")
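# Illustrative sketch (not part of the original module, added for clarity):
# the wire framing used by sendRequest()/sendResponse() is a 4-byte
# big-endian length covering the type byte plus the payload, followed by the
# type byte and the payload itself.
def _exampleAgentFrame():
    payload = b""
    # A five-byte "request identities" message: b"\x00\x00\x00\x01\x0b".
    return struct.pack("!LB", len(payload) + 1, AGENTC_REQUEST_IDENTITIES) + payload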
class SSHAgentServer(protocol.Protocol):
"""
The server side of the SSH agent protocol. This is equivalent to
ssh-agent(1) and can be used with either ssh-add(1) or the SSHAgentClient
protocol, also in this package.
"""
def __init__(self):
self.buf = b""
def dataReceived(self, data):
self.buf += data
while 1:
if len(self.buf) <= 4:
return
packLen = struct.unpack("!L", self.buf[:4])[0]
if len(self.buf) < 4 + packLen:
return
packet, self.buf = self.buf[4 : 4 + packLen], self.buf[4 + packLen :]
reqType = ord(packet[0:1])
reqName = messages.get(reqType, None)
if not reqName:
self.sendResponse(AGENT_FAILURE, b"")
else:
f = getattr(self, "agentc_%s" % reqName)
if getattr(self.factory, "keys", None) is None:
self.sendResponse(AGENT_FAILURE, b"")
raise MissingKeyStoreError()
f(packet[1:])
def sendResponse(self, reqType, data):
pack = struct.pack("!LB", len(data) + 1, reqType) + data
self.transport.write(pack)
def agentc_REQUEST_IDENTITIES(self, data):
"""
Return all of the identities that have been added to the server
"""
assert data == b""
numKeys = len(self.factory.keys)
resp = []
resp.append(struct.pack("!L", numKeys))
for key, comment in self.factory.keys.values():
resp.append(NS(key.blob())) # yes, wrapped in an NS
resp.append(NS(comment))
self.sendResponse(AGENT_IDENTITIES_ANSWER, b"".join(resp))
def agentc_SIGN_REQUEST(self, data):
"""
Data is a structure with a reference to an already added key object and
        some data that the client wants signed with that key. If the key
object wasn't loaded, return AGENT_FAILURE, else return the signature.
"""
blob, data = getNS(data)
if blob not in self.factory.keys:
return self.sendResponse(AGENT_FAILURE, b"")
signData, data = getNS(data)
assert data == b"\000\000\000\000"
self.sendResponse(
AGENT_SIGN_RESPONSE, NS(self.factory.keys[blob][0].sign(signData))
)
def agentc_ADD_IDENTITY(self, data):
"""
Adds a private key to the agent's collection of identities. On
subsequent interactions, the private key can be accessed using only the
corresponding public key.
"""
# need to pre-read the key data so we can get past it to the comment string
keyType, rest = getNS(data)
if keyType == b"ssh-rsa":
nmp = 6
elif keyType == b"ssh-dss":
nmp = 5
else:
raise keys.BadKeyError("unknown blob type: %s" % keyType)
rest = getMP(rest, nmp)[
-1
] # ignore the key data for now, we just want the comment
comment, rest = getNS(rest) # the comment, tacked onto the end of the key blob
k = keys.Key.fromString(data, type="private_blob") # not wrapped in NS here
self.factory.keys[k.blob()] = (k, comment)
self.sendResponse(AGENT_SUCCESS, b"")
def agentc_REMOVE_IDENTITY(self, data):
"""
Remove a specific key from the agent's collection of identities.
"""
blob, _ = getNS(data)
k = keys.Key.fromString(blob, type="blob")
del self.factory.keys[k.blob()]
self.sendResponse(AGENT_SUCCESS, b"")
def agentc_REMOVE_ALL_IDENTITIES(self, data):
"""
Remove all keys from the agent's collection of identities.
"""
assert data == b""
self.factory.keys = {}
self.sendResponse(AGENT_SUCCESS, b"")
# v1 messages that we ignore because we don't keep v1 keys
# open-ssh sends both v1 and v2 commands, so we have to
# do no-ops for v1 commands or we'll get "bad request" errors
def agentc_REQUEST_RSA_IDENTITIES(self, data):
"""
v1 message for listing RSA1 keys; superseded by
agentc_REQUEST_IDENTITIES, which handles different key types.
"""
self.sendResponse(AGENT_RSA_IDENTITIES_ANSWER, struct.pack("!L", 0))
def agentc_REMOVE_RSA_IDENTITY(self, data):
"""
v1 message for removing RSA1 keys; superseded by
agentc_REMOVE_IDENTITY, which handles different key types.
"""
self.sendResponse(AGENT_SUCCESS, b"")
def agentc_REMOVE_ALL_RSA_IDENTITIES(self, data):
"""
v1 message for removing all RSA1 keys; superseded by
agentc_REMOVE_ALL_IDENTITIES, which handles different key types.
"""
self.sendResponse(AGENT_SUCCESS, b"")
AGENTC_REQUEST_RSA_IDENTITIES = 1
AGENT_RSA_IDENTITIES_ANSWER = 2
AGENT_FAILURE = 5
AGENT_SUCCESS = 6
AGENTC_REMOVE_RSA_IDENTITY = 8
AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
AGENTC_REQUEST_IDENTITIES = 11
AGENT_IDENTITIES_ANSWER = 12
AGENTC_SIGN_REQUEST = 13
AGENT_SIGN_RESPONSE = 14
AGENTC_ADD_IDENTITY = 17
AGENTC_REMOVE_IDENTITY = 18
AGENTC_REMOVE_ALL_IDENTITIES = 19
messages = {}
for name, value in locals().copy().items():
if name[:7] == "AGENTC_":
messages[value] = name[7:] # doesn't handle doubles
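# Illustrative note (not part of the original module, added for clarity): the
# loop above builds a reverse lookup from message numbers to the names used
# for dispatch, e.g. messages[11] == "REQUEST_IDENTITIES" and
# messages[13] == "SIGN_REQUEST".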

View File

@@ -0,0 +1,312 @@
# -*- test-case-name: twisted.conch.test.test_channel -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
The parent class for all the SSH Channels. Currently implemented channels
are session, direct-tcp, and forwarded-tcp.
Maintainer: Paul Swartz
"""
from zope.interface import implementer
from twisted.internet import interfaces
from twisted.logger import Logger
from twisted.python import log
@implementer(interfaces.ITransport)
class SSHChannel(log.Logger):
"""
A class that represents a multiplexed channel over an SSH connection.
The channel has a local window which is the maximum amount of data it will
    receive, and a remote window which is the maximum amount of data the remote side
will accept. There is also a maximum packet size for any individual data
packet going each way.
@ivar name: the name of the channel.
@type name: L{bytes}
@ivar localWindowSize: the maximum size of the local window in bytes.
@type localWindowSize: L{int}
@ivar localWindowLeft: how many bytes are left in the local window.
@type localWindowLeft: L{int}
@ivar localMaxPacket: the maximum size of packet we will accept in bytes.
@type localMaxPacket: L{int}
@ivar remoteWindowLeft: how many bytes are left in the remote window.
@type remoteWindowLeft: L{int}
@ivar remoteMaxPacket: the maximum size of a packet the remote side will
accept in bytes.
@type remoteMaxPacket: L{int}
@ivar conn: the connection this channel is multiplexed through.
@type conn: L{SSHConnection}
@ivar data: any data to send to the other side when the channel is
requested.
@type data: L{bytes}
@ivar avatar: an avatar for the logged-in user (if a server channel)
@ivar localClosed: True if we aren't accepting more data.
@type localClosed: L{bool}
@ivar remoteClosed: True if the other side isn't accepting more data.
@type remoteClosed: L{bool}
"""
_log = Logger()
name: bytes = None # type: ignore[assignment] # only needed for client channels
def __init__(
self,
localWindow=0,
localMaxPacket=0,
remoteWindow=0,
remoteMaxPacket=0,
conn=None,
data=None,
avatar=None,
):
self.localWindowSize = localWindow or 131072
self.localWindowLeft = self.localWindowSize
self.localMaxPacket = localMaxPacket or 32768
self.remoteWindowLeft = remoteWindow
self.remoteMaxPacket = remoteMaxPacket
self.areWriting = 1
self.conn = conn
self.data = data
self.avatar = avatar
self.specificData = b""
self.buf = b""
self.extBuf = []
self.closing = 0
self.localClosed = 0
self.remoteClosed = 0
self.id = None # gets set later by SSHConnection
def __str__(self) -> str:
return self.__bytes__().decode("ascii")
def __bytes__(self) -> bytes:
"""
Return a byte string representation of the channel
"""
name = self.name
if not name:
name = b"None"
return b"<SSHChannel %b (lw %d rw %d)>" % (
name,
self.localWindowLeft,
self.remoteWindowLeft,
)
def logPrefix(self):
id = (self.id is not None and str(self.id)) or "unknown"
if self.name:
name = self.name.decode("ascii")
else:
name = "None"
return f"SSHChannel {name} ({id}) on {self.conn.logPrefix()}"
def channelOpen(self, specificData):
"""
Called when the channel is opened. specificData is any data that the
other side sent us when opening the channel.
@type specificData: L{bytes}
"""
self._log.info("channel open")
def openFailed(self, reason):
"""
Called when the open failed for some reason.
reason.desc is a string description, and reason.code is the SSH error code.
@type reason: L{error.ConchError}
"""
self._log.error("other side refused open\nreason: {reason}", reason=reason)
def addWindowBytes(self, data):
"""
Called when bytes are added to the remote window. By default it clears
the data buffers.
@type data: L{bytes}
"""
self.remoteWindowLeft = self.remoteWindowLeft + data
if not self.areWriting and not self.closing:
self.areWriting = True
self.startWriting()
if self.buf:
b = self.buf
self.buf = b""
self.write(b)
if self.extBuf:
b = self.extBuf
self.extBuf = []
for type, data in b:
self.writeExtended(type, data)
def requestReceived(self, requestType, data):
"""
Called when a request is sent to this channel. By default it delegates
to self.request_<requestType>.
If this function returns true, the request succeeded, otherwise it
failed.
@type requestType: L{bytes}
@type data: L{bytes}
@rtype: L{bool}
"""
foo = requestType.replace(b"-", b"_").decode("ascii")
f = getattr(self, "request_" + foo, None)
if f:
return f(data)
self._log.info("unhandled request for {requestType}", requestType=requestType)
return 0
def dataReceived(self, data):
"""
Called when we receive data.
@type data: L{bytes}
"""
self._log.debug("got data {data}", data=data)
def extReceived(self, dataType, data):
"""
Called when we receive extended data (usually standard error).
@type dataType: L{int}
@type data: L{str}
"""
self._log.debug(
"got extended data {dataType} {data!r}", dataType=dataType, data=data
)
def eofReceived(self):
"""
Called when the other side will send no more data.
"""
self._log.info("remote eof")
def closeReceived(self):
"""
Called when the other side has closed the channel.
"""
self._log.info("remote close")
self.loseConnection()
def closed(self):
"""
Called when the channel is closed. This means that both our side and
the remote side have closed the channel.
"""
self._log.info("closed")
def write(self, data):
"""
Write some data to the channel. If there is not enough remote window
available, buffer until it is. Otherwise, split the data into
packets of length remoteMaxPacket and send them.
@type data: L{bytes}
"""
if self.buf:
self.buf += data
return
top = len(data)
if top > self.remoteWindowLeft:
data, self.buf = (
data[: self.remoteWindowLeft],
data[self.remoteWindowLeft :],
)
self.areWriting = 0
self.stopWriting()
top = self.remoteWindowLeft
rmp = self.remoteMaxPacket
write = self.conn.sendData
r = range(0, top, rmp)
for offset in r:
write(self, data[offset : offset + rmp])
self.remoteWindowLeft -= top
if self.closing and not self.buf:
self.loseConnection() # try again
def writeExtended(self, dataType, data):
"""
Send extended data to this channel. If there is not enough remote
window available, buffer until there is. Otherwise, split the data
into packets of length remoteMaxPacket and send them.
@type dataType: L{int}
@type data: L{bytes}
"""
if self.extBuf:
if self.extBuf[-1][0] == dataType:
self.extBuf[-1][1] += data
else:
self.extBuf.append([dataType, data])
return
if len(data) > self.remoteWindowLeft:
data, self.extBuf = (
data[: self.remoteWindowLeft],
[[dataType, data[self.remoteWindowLeft :]]],
)
self.areWriting = 0
self.stopWriting()
while len(data) > self.remoteMaxPacket:
self.conn.sendExtendedData(self, dataType, data[: self.remoteMaxPacket])
data = data[self.remoteMaxPacket :]
self.remoteWindowLeft -= self.remoteMaxPacket
if data:
self.conn.sendExtendedData(self, dataType, data)
self.remoteWindowLeft -= len(data)
if self.closing:
self.loseConnection() # try again
def writeSequence(self, data):
"""
Part of the Transport interface. Write a list of strings to the
channel.
@type data: C{list} of L{str}
"""
self.write(b"".join(data))
def loseConnection(self):
"""
Close the channel if there is no buffered data. Otherwise, note the
request and return.
"""
self.closing = 1
if not self.buf and not self.extBuf:
self.conn.sendClose(self)
def getPeer(self):
"""
See: L{ITransport.getPeer}
@return: The remote address of this connection.
@rtype: L{SSHTransportAddress}.
"""
return self.conn.transport.getPeer()
def getHost(self):
"""
See: L{ITransport.getHost}
@return: An address describing this side of the connection.
@rtype: L{SSHTransportAddress}.
"""
return self.conn.transport.getHost()
def stopWriting(self):
"""
Called when the remote buffer is full, as a hint to stop writing.
This can be ignored, but it can be helpful.
"""
def startWriting(self):
"""
Called when the remote buffer has more room, as a hint to continue
writing.
"""

View File

@@ -0,0 +1,85 @@
# -*- test-case-name: twisted.conch.test.test_ssh -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Common functions for the SSH classes.
Maintainer: Paul Swartz
"""
import struct
from cryptography.utils import int_to_bytes
from twisted.python.deprecate import deprecated
from twisted.python.versions import Version
__all__ = ["NS", "getNS", "MP", "getMP", "ffs"]
def NS(t):
"""
net string
"""
if isinstance(t, str):
t = t.encode("utf-8")
return struct.pack("!L", len(t)) + t
def getNS(s, count=1):
"""
get net string
"""
ns = []
c = 0
for i in range(count):
(l,) = struct.unpack("!L", s[c : c + 4])
ns.append(s[c + 4 : 4 + l + c])
c += 4 + l
return tuple(ns) + (s[c:],)
def MP(number):
if number == 0:
return b"\000" * 4
assert number > 0
bn = int_to_bytes(number)
if ord(bn[0:1]) & 128:
bn = b"\000" + bn
return struct.pack(">L", len(bn)) + bn
def getMP(data, count=1):
"""
Get multiple precision integer out of the string. A multiple precision
integer is stored as a 4-byte length followed by length bytes of the
integer. If count is specified, get count integers out of the string.
The return value is a tuple of count integers followed by the rest of
the data.
"""
mp = []
c = 0
for i in range(count):
(length,) = struct.unpack(">L", data[c : c + 4])
mp.append(int.from_bytes(data[c + 4 : c + 4 + length], "big"))
c += 4 + length
return tuple(mp) + (data[c:],)
def ffs(c, s):
"""
first from second:
goes through the first list looking for items in the second and returns
the first one found.
"""
for i in c:
if i in s:
return i
@deprecated(Version("Twisted", 16, 5, 0))
def install():
# This used to install gmpy, but is technically public API, so just do
# nothing.
pass
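
For illustration, a small round trip through the helpers above (the values are arbitrary): NS and MP build the wire encoding, and getNS and getMP take it apart again.

from twisted.conch.ssh import common

payload = common.NS(b"ssh-userauth") + common.MP(65537)
name, rest = common.getNS(payload)   # name == b"ssh-userauth"
exponent, rest = common.getMP(rest)  # exponent == 65537, rest == b""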

View File

@@ -0,0 +1,679 @@
# -*- test-case-name: twisted.conch.test.test_connection -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of the ssh-connection service, which
allows access to the shell and port-forwarding.
Maintainer: Paul Swartz
"""
import string
import struct
import twisted.internet.error
from twisted.conch import error
from twisted.conch.ssh import common, service
from twisted.internet import defer
from twisted.logger import Logger
from twisted.python.compat import nativeString, networkString
class SSHConnection(service.SSHService):
"""
An implementation of the 'ssh-connection' service. It is used to
multiplex multiple channels over the single SSH connection.
@ivar localChannelID: the next number to use as a local channel ID.
@type localChannelID: L{int}
@ivar channels: a L{dict} mapping a local channel ID to C{SSHChannel}
subclasses.
@type channels: L{dict}
@ivar localToRemoteChannel: a L{dict} mapping a local channel ID to a
remote channel ID.
@type localToRemoteChannel: L{dict}
@ivar channelsToRemoteChannel: a L{dict} mapping a C{SSHChannel} subclass
to remote channel ID.
@type channelsToRemoteChannel: L{dict}
@ivar deferreds: a L{dict} mapping a local channel ID to a C{list} of
C{Deferreds} for outstanding channel requests. Also, the 'global'
key stores the C{list} of pending global request C{Deferred}s.
"""
name = b"ssh-connection"
_log = Logger()
def __init__(self):
self.localChannelID = 0 # this is the current # to use for channel ID
# local channel ID -> remote channel ID
self.localToRemoteChannel = {}
# local channel ID -> subclass of SSHChannel
self.channels = {}
# subclass of SSHChannel -> remote channel ID
self.channelsToRemoteChannel = {}
# local channel -> list of deferreds for pending requests
# or 'global' -> list of deferreds for global requests
self.deferreds = {"global": []}
self.transport = None # gets set later
def serviceStarted(self):
if hasattr(self.transport, "avatar"):
self.transport.avatar.conn = self
def serviceStopped(self):
"""
Called when the connection is stopped.
"""
# Close any fully open channels
for channel in list(self.channelsToRemoteChannel.keys()):
self.channelClosed(channel)
# Indicate failure to any channels that were in the process of
# opening but not yet open.
while self.channels:
(_, channel) = self.channels.popitem()
channel.openFailed(twisted.internet.error.ConnectionLost())
# Errback any unfinished global requests.
self._cleanupGlobalDeferreds()
def _cleanupGlobalDeferreds(self):
"""
All pending requests that have returned a deferred must be errbacked
when this service is stopped, otherwise they might be left uncalled and
uncallable.
"""
for d in self.deferreds["global"]:
d.errback(error.ConchError("Connection stopped."))
del self.deferreds["global"][:]
# packet methods
def ssh_GLOBAL_REQUEST(self, packet):
"""
The other side has made a global request. Payload::
string request type
bool want reply
<request specific data>
This dispatches to self.gotGlobalRequest.
"""
requestType, rest = common.getNS(packet)
wantReply, rest = ord(rest[0:1]), rest[1:]
ret = self.gotGlobalRequest(requestType, rest)
if wantReply:
reply = MSG_REQUEST_FAILURE
data = b""
if ret:
reply = MSG_REQUEST_SUCCESS
if isinstance(ret, (tuple, list)):
data = ret[1]
self.transport.sendPacket(reply, data)
def ssh_REQUEST_SUCCESS(self, packet):
"""
Our global request succeeded. Get the appropriate Deferred and call
it back with the packet we received.
"""
self._log.debug("global request success")
self.deferreds["global"].pop(0).callback(packet)
def ssh_REQUEST_FAILURE(self, packet):
"""
Our global request failed. Get the appropriate Deferred and errback
it with the packet we received.
"""
self._log.debug("global request failure")
self.deferreds["global"].pop(0).errback(
error.ConchError("global request failed", packet)
)
def ssh_CHANNEL_OPEN(self, packet):
"""
The other side wants to get a channel. Payload::
string channel name
uint32 remote channel number
uint32 remote window size
uint32 remote maximum packet size
<channel specific data>
We get a channel from self.getChannel(), give it a local channel number
and notify the other side. Then notify the channel by calling its
channelOpen method.
"""
channelType, rest = common.getNS(packet)
senderChannel, windowSize, maxPacket = struct.unpack(">3L", rest[:12])
packet = rest[12:]
try:
channel = self.getChannel(channelType, windowSize, maxPacket, packet)
localChannel = self.localChannelID
self.localChannelID += 1
channel.id = localChannel
self.channels[localChannel] = channel
self.channelsToRemoteChannel[channel] = senderChannel
self.localToRemoteChannel[localChannel] = senderChannel
openConfirmPacket = (
struct.pack(
">4L",
senderChannel,
localChannel,
channel.localWindowSize,
channel.localMaxPacket,
)
+ channel.specificData
)
self.transport.sendPacket(MSG_CHANNEL_OPEN_CONFIRMATION, openConfirmPacket)
channel.channelOpen(packet)
except Exception as e:
self._log.failure("channel open failed")
if isinstance(e, error.ConchError):
textualInfo, reason = e.args
if isinstance(textualInfo, int):
# See #3657 and #3071
textualInfo, reason = reason, textualInfo
else:
reason = OPEN_CONNECT_FAILED
textualInfo = "unknown failure"
self.transport.sendPacket(
MSG_CHANNEL_OPEN_FAILURE,
struct.pack(">2L", senderChannel, reason)
+ common.NS(networkString(textualInfo))
+ common.NS(b""),
)
def ssh_CHANNEL_OPEN_CONFIRMATION(self, packet):
"""
The other side accepted our MSG_CHANNEL_OPEN request. Payload::
uint32 local channel number
uint32 remote channel number
uint32 remote window size
uint32 remote maximum packet size
<channel specific data>
Find the channel using the local channel number and notify its
channelOpen method.
"""
(localChannel, remoteChannel, windowSize, maxPacket) = struct.unpack(
">4L", packet[:16]
)
specificData = packet[16:]
channel = self.channels[localChannel]
channel.conn = self
self.localToRemoteChannel[localChannel] = remoteChannel
self.channelsToRemoteChannel[channel] = remoteChannel
channel.remoteWindowLeft = windowSize
channel.remoteMaxPacket = maxPacket
channel.channelOpen(specificData)
def ssh_CHANNEL_OPEN_FAILURE(self, packet):
"""
The other side did not accept our MSG_CHANNEL_OPEN request. Payload::
uint32 local channel number
uint32 reason code
string reason description
Find the channel using the local channel number and notify it by
calling its openFailed() method.
"""
localChannel, reasonCode = struct.unpack(">2L", packet[:8])
reasonDesc = common.getNS(packet[8:])[0]
channel = self.channels[localChannel]
del self.channels[localChannel]
channel.conn = self
reason = error.ConchError(reasonDesc, reasonCode)
channel.openFailed(reason)
def ssh_CHANNEL_WINDOW_ADJUST(self, packet):
"""
The other side is adding bytes to its window. Payload::
uint32 local channel number
uint32 bytes to add
Call the channel's addWindowBytes() method to add new bytes to the
remote window.
"""
localChannel, bytesToAdd = struct.unpack(">2L", packet[:8])
channel = self.channels[localChannel]
channel.addWindowBytes(bytesToAdd)
def ssh_CHANNEL_DATA(self, packet):
"""
The other side is sending us data. Payload::
uint32 local channel number
string data
Check to make sure the other side hasn't sent too much data (more
than what's in the window, or more than the maximum packet size). If
they have, close the channel. Otherwise, decrease the available
window and pass the data to the channel's dataReceived().
"""
localChannel, dataLength = struct.unpack(">2L", packet[:8])
channel = self.channels[localChannel]
# XXX should this move to dataReceived to put client in charge?
if (
dataLength > channel.localWindowLeft or dataLength > channel.localMaxPacket
): # more data than we want
self._log.error("too much data")
self.sendClose(channel)
return
# packet = packet[:channel.localWindowLeft+4]
data = common.getNS(packet[4:])[0]
channel.localWindowLeft -= dataLength
if channel.localWindowLeft < channel.localWindowSize // 2:
self.adjustWindow(
channel, channel.localWindowSize - channel.localWindowLeft
)
channel.dataReceived(data)
def ssh_CHANNEL_EXTENDED_DATA(self, packet):
"""
The other side is sending us extended data. Payload::
uint32 local channel number
uint32 type code
string data
Check to make sure the other side hasn't sent too much data (more
than what's in the window, or more than the maximum packet size). If
they have, close the channel. Otherwise, decrease the available
window and pass the data and type code to the channel's
extReceived().
"""
localChannel, typeCode, dataLength = struct.unpack(">3L", packet[:12])
channel = self.channels[localChannel]
if dataLength > channel.localWindowLeft or dataLength > channel.localMaxPacket:
self._log.error("too much extdata")
self.sendClose(channel)
return
data = common.getNS(packet[8:])[0]
channel.localWindowLeft -= dataLength
if channel.localWindowLeft < channel.localWindowSize // 2:
self.adjustWindow(
channel, channel.localWindowSize - channel.localWindowLeft
)
channel.extReceived(typeCode, data)
def ssh_CHANNEL_EOF(self, packet):
"""
The other side is not sending any more data. Payload::
uint32 local channel number
Notify the channel by calling its eofReceived() method.
"""
localChannel = struct.unpack(">L", packet[:4])[0]
channel = self.channels[localChannel]
channel.eofReceived()
def ssh_CHANNEL_CLOSE(self, packet):
"""
The other side is closing its end; it does not want to receive any
more data. Payload::
uint32 local channel number
Notify the channel by calling its closeReceived() method. If
the channel has also sent a close message, call self.channelClosed().
"""
localChannel = struct.unpack(">L", packet[:4])[0]
channel = self.channels[localChannel]
channel.closeReceived()
channel.remoteClosed = True
if channel.localClosed and channel.remoteClosed:
self.channelClosed(channel)
def ssh_CHANNEL_REQUEST(self, packet):
"""
The other side is sending a request to a channel. Payload::
uint32 local channel number
string request name
bool want reply
<request specific data>
Pass the message to the channel's requestReceived method. If the
other side wants a reply, add callbacks which will send the
reply.
"""
localChannel = struct.unpack(">L", packet[:4])[0]
requestType, rest = common.getNS(packet[4:])
wantReply = ord(rest[0:1])
channel = self.channels[localChannel]
d = defer.maybeDeferred(channel.requestReceived, requestType, rest[1:])
if wantReply:
d.addCallback(self._cbChannelRequest, localChannel)
d.addErrback(self._ebChannelRequest, localChannel)
return d
def _cbChannelRequest(self, result, localChannel):
"""
Called back if the other side wanted a reply to a channel request. If
the result is true, send a MSG_CHANNEL_SUCCESS. Otherwise, raise
a C{error.ConchError}
@param result: the value returned from the channel's requestReceived()
method. If it's False, the request failed.
@type result: L{bool}
@param localChannel: the local channel ID of the channel to which the
request was made.
@type localChannel: L{int}
@raises ConchError: if the result is False.
"""
if not result:
raise error.ConchError("failed request")
self.transport.sendPacket(
MSG_CHANNEL_SUCCESS,
struct.pack(">L", self.localToRemoteChannel[localChannel]),
)
def _ebChannelRequest(self, result, localChannel):
"""
Called if the other side wanted a reply to the channel request and
the channel request failed.
@param result: a Failure, but it's not used.
@param localChannel: the local channel ID of the channel to which the
request was made.
@type localChannel: L{int}
"""
self.transport.sendPacket(
MSG_CHANNEL_FAILURE,
struct.pack(">L", self.localToRemoteChannel[localChannel]),
)
def ssh_CHANNEL_SUCCESS(self, packet):
"""
Our channel request to the other side succeeded. Payload::
uint32 local channel number
Get the C{Deferred} out of self.deferreds and call it back.
"""
localChannel = struct.unpack(">L", packet[:4])[0]
if self.deferreds.get(localChannel):
d = self.deferreds[localChannel].pop(0)
d.callback("")
def ssh_CHANNEL_FAILURE(self, packet):
"""
Our channel request to the other side failed. Payload::
uint32 local channel number
Get the C{Deferred} out of self.deferreds and errback it with a
C{error.ConchError}.
"""
localChannel = struct.unpack(">L", packet[:4])[0]
if self.deferreds.get(localChannel):
d = self.deferreds[localChannel].pop(0)
d.errback(error.ConchError("channel request failed"))
# methods for users of the connection to call
def sendGlobalRequest(self, request, data, wantReply=0):
"""
Send a global request for this connection. Currently this is only used
for remote->local TCP forwarding.
@type request: L{bytes}
@type data: L{bytes}
@type wantReply: L{bool}
@rtype: C{Deferred}/L{None}
"""
self.transport.sendPacket(
MSG_GLOBAL_REQUEST,
common.NS(request) + (wantReply and b"\xff" or b"\x00") + data,
)
if wantReply:
d = defer.Deferred()
self.deferreds["global"].append(d)
return d
def openChannel(self, channel, extra=b""):
"""
Open a new channel on this connection.
@type channel: subclass of C{SSHChannel}
@type extra: L{bytes}
"""
self._log.info(
"opening channel {id} with {localWindowSize} {localMaxPacket}",
id=self.localChannelID,
localWindowSize=channel.localWindowSize,
localMaxPacket=channel.localMaxPacket,
)
self.transport.sendPacket(
MSG_CHANNEL_OPEN,
common.NS(channel.name)
+ struct.pack(
">3L",
self.localChannelID,
channel.localWindowSize,
channel.localMaxPacket,
)
+ extra,
)
channel.id = self.localChannelID
self.channels[self.localChannelID] = channel
self.localChannelID += 1
def sendRequest(self, channel, requestType, data, wantReply=0):
"""
Send a request to a channel.
@type channel: subclass of C{SSHChannel}
@type requestType: L{bytes}
@type data: L{bytes}
@type wantReply: L{bool}
@rtype: C{Deferred}/L{None}
"""
if channel.localClosed:
return
self._log.debug("sending request {requestType}", requestType=requestType)
self.transport.sendPacket(
MSG_CHANNEL_REQUEST,
struct.pack(">L", self.channelsToRemoteChannel[channel])
+ common.NS(requestType)
+ (b"\1" if wantReply else b"\0")
+ data,
)
if wantReply:
d = defer.Deferred()
self.deferreds.setdefault(channel.id, []).append(d)
return d
def adjustWindow(self, channel, bytesToAdd):
"""
Tell the other side that we will receive more data. This should not
normally need to be called as it is managed automatically.
@type channel: subclass of L{SSHChannel}
@type bytesToAdd: L{int}
"""
if channel.localClosed:
return # we're already closed
packet = struct.pack(">2L", self.channelsToRemoteChannel[channel], bytesToAdd)
self.transport.sendPacket(MSG_CHANNEL_WINDOW_ADJUST, packet)
self._log.debug(
"adding {bytesToAdd} to {localWindowLeft} in channel {id}",
bytesToAdd=bytesToAdd,
localWindowLeft=channel.localWindowLeft,
id=channel.id,
)
channel.localWindowLeft += bytesToAdd
def sendData(self, channel, data):
"""
Send data to a channel. This should not normally be used: instead use
channel.write(data) as it manages the window automatically.
@type channel: subclass of L{SSHChannel}
@type data: L{bytes}
"""
if channel.localClosed:
return # we're already closed
self.transport.sendPacket(
MSG_CHANNEL_DATA,
struct.pack(">L", self.channelsToRemoteChannel[channel]) + common.NS(data),
)
def sendExtendedData(self, channel, dataType, data):
"""
Send extended data to a channel. This should not normally be used:
instead use channel.writeExtended(dataType, data) as it manages
the window automatically.
@type channel: subclass of L{SSHChannel}
@type dataType: L{int}
@type data: L{bytes}
"""
if channel.localClosed:
return # we're already closed
self.transport.sendPacket(
MSG_CHANNEL_EXTENDED_DATA,
struct.pack(">2L", self.channelsToRemoteChannel[channel], dataType)
+ common.NS(data),
)
def sendEOF(self, channel):
"""
Send an EOF (End of File) for a channel.
@type channel: subclass of L{SSHChannel}
"""
if channel.localClosed:
return # we're already closed
self._log.debug("sending eof")
self.transport.sendPacket(
MSG_CHANNEL_EOF, struct.pack(">L", self.channelsToRemoteChannel[channel])
)
def sendClose(self, channel):
"""
Close a channel.
@type channel: subclass of L{SSHChannel}
"""
if channel.localClosed:
return # we're already closed
self._log.info("sending close {id}", id=channel.id)
self.transport.sendPacket(
MSG_CHANNEL_CLOSE, struct.pack(">L", self.channelsToRemoteChannel[channel])
)
channel.localClosed = True
if channel.localClosed and channel.remoteClosed:
self.channelClosed(channel)
# methods to override
def getChannel(self, channelType, windowSize, maxPacket, data):
"""
The other side requested a channel of some sort.
channelType is the type of channel being requested,
windowSize is the initial size of the remote window,
maxPacket is the largest packet we should send,
data is any other packet data (often nothing).
We return a subclass of L{SSHChannel}.
By default, this dispatches to a method 'channel_channelType' with any
non-alphanumerics in the channelType replaced with _'s. If it cannot
find a suitable method, it returns an OPEN_UNKNOWN_CHANNEL_TYPE error.
The method is called with arguments of windowSize, maxPacket, data.
@type channelType: L{bytes}
@type windowSize: L{int}
@type maxPacket: L{int}
@type data: L{bytes}
@rtype: subclass of L{SSHChannel}/L{tuple}
"""
self._log.debug("got channel {channelType!r} request", channelType=channelType)
if hasattr(self.transport, "avatar"): # this is a server!
chan = self.transport.avatar.lookupChannel(
channelType, windowSize, maxPacket, data
)
else:
channelType = channelType.translate(TRANSLATE_TABLE)
attr = "channel_%s" % nativeString(channelType)
f = getattr(self, attr, None)
if f is not None:
chan = f(windowSize, maxPacket, data)
else:
chan = None
if chan is None:
raise error.ConchError("unknown channel", OPEN_UNKNOWN_CHANNEL_TYPE)
else:
chan.conn = self
return chan
def gotGlobalRequest(self, requestType, data):
"""
We got a global request. Pretty much, this is just used by the client
to request that we forward a port from the server to the client.
Returns either:
- 1: request accepted
- 1, <data>: request accepted with request specific data
- 0: request denied
By default, this dispatches to a method 'global_requestType' with
-'s in requestType replaced with _'s. The found method is passed data.
If this method cannot be found, this method returns 0. Otherwise, it
returns the return value of that method.
@type requestType: L{bytes}
@type data: L{bytes}
@rtype: L{int}/L{tuple}
"""
self._log.debug("got global {requestType} request", requestType=requestType)
if hasattr(self.transport, "avatar"): # this is a server!
return self.transport.avatar.gotGlobalRequest(requestType, data)
requestType = nativeString(requestType.replace(b"-", b"_"))
f = getattr(self, "global_%s" % requestType, None)
if not f:
return 0
return f(data)
def channelClosed(self, channel):
"""
Called when a channel is closed.
It clears the local state related to the channel, and calls
channel.closed().
MAKE SURE YOU CALL THIS METHOD, even if you subclass L{SSHConnection}.
If you don't, things will break mysteriously.
@type channel: L{SSHChannel}
"""
if channel in self.channelsToRemoteChannel: # actually open
channel.localClosed = channel.remoteClosed = True
del self.localToRemoteChannel[channel.id]
del self.channels[channel.id]
del self.channelsToRemoteChannel[channel]
for d in self.deferreds.pop(channel.id, []):
d.errback(error.ConchError("Channel closed."))
channel.closed()
MSG_GLOBAL_REQUEST = 80
MSG_REQUEST_SUCCESS = 81
MSG_REQUEST_FAILURE = 82
MSG_CHANNEL_OPEN = 90
MSG_CHANNEL_OPEN_CONFIRMATION = 91
MSG_CHANNEL_OPEN_FAILURE = 92
MSG_CHANNEL_WINDOW_ADJUST = 93
MSG_CHANNEL_DATA = 94
MSG_CHANNEL_EXTENDED_DATA = 95
MSG_CHANNEL_EOF = 96
MSG_CHANNEL_CLOSE = 97
MSG_CHANNEL_REQUEST = 98
MSG_CHANNEL_SUCCESS = 99
MSG_CHANNEL_FAILURE = 100
OPEN_ADMINISTRATIVELY_PROHIBITED = 1
OPEN_CONNECT_FAILED = 2
OPEN_UNKNOWN_CHANNEL_TYPE = 3
OPEN_RESOURCE_SHORTAGE = 4
# From RFC 4254
EXTENDED_DATA_STDERR = 1
messages = {}
for name, value in locals().copy().items():
if name[:4] == "MSG_":
messages[value] = name # Doesn't handle doubles
alphanums = networkString(string.ascii_letters + string.digits)
TRANSLATE_TABLE = bytes(i if i in alphanums else ord("_") for i in range(256))
SSHConnection.protocolMessages = messages
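
As a hedged usage sketch (not part of this module), a client-side channel usually drives the connection through openChannel() and sendRequest(); the command below is an arbitrary example.

from twisted.conch.ssh import channel, common


class CommandChannel(channel.SSHChannel):
    name = b"session"

    def channelOpen(self, specificData):
        # ask the peer to run a command; the Deferred fires when the request
        # is acknowledged with MSG_CHANNEL_SUCCESS or MSG_CHANNEL_FAILURE
        d = self.conn.sendRequest(self, b"exec", common.NS(b"uptime"), wantReply=1)
        d.addCallback(lambda _: self._log.info("exec request accepted"))


# with conn being an established SSHConnection:
#     conn.openChannel(CommandChannel(conn=conn))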

View File

@@ -0,0 +1,129 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A Factory for SSH servers.
See also L{twisted.conch.openssh_compat.factory} for OpenSSH compatibility.
Maintainer: Paul Swartz
"""
import random
from itertools import chain
from typing import Dict, List, Optional, Tuple
from twisted.conch import error
from twisted.conch.ssh import _kex, connection, transport, userauth
from twisted.internet import protocol
from twisted.logger import Logger
class SSHFactory(protocol.Factory):
"""
A Factory for SSH servers.
"""
primes: Optional[Dict[int, List[Tuple[int, int]]]]
_log = Logger()
protocol = transport.SSHServerTransport
services = {
b"ssh-userauth": userauth.SSHUserAuthServer,
b"ssh-connection": connection.SSHConnection,
}
def startFactory(self) -> None:
"""
Check for public and private keys.
"""
if not hasattr(self, "publicKeys"):
self.publicKeys = self.getPublicKeys()
if not hasattr(self, "privateKeys"):
self.privateKeys = self.getPrivateKeys()
if not self.publicKeys or not self.privateKeys:
raise error.ConchError("no host keys, failing")
if not hasattr(self, "primes"):
self.primes = self.getPrimes()
def buildProtocol(self, addr):
"""
Create an instance of the server side of the SSH protocol.
@type addr: L{twisted.internet.interfaces.IAddress} provider
@param addr: The address at which the server will listen.
@rtype: L{twisted.conch.ssh.transport.SSHServerTransport}
@return: The built transport.
"""
t = protocol.Factory.buildProtocol(self, addr)
t.supportedPublicKeys = list(
chain.from_iterable(
key.supportedSignatureAlgorithms() for key in self.privateKeys.values()
)
)
if not self.primes:
self._log.info(
"disabling non-fixed-group key exchange algorithms "
"because we cannot find moduli file"
)
t.supportedKeyExchanges = [
kexAlgorithm
for kexAlgorithm in t.supportedKeyExchanges
if _kex.isFixedGroup(kexAlgorithm) or _kex.isEllipticCurve(kexAlgorithm)
]
return t
def getPublicKeys(self):
"""
Called when the factory is started to get the public portions of the
servers host keys. Returns a dictionary mapping SSH key types to
public key strings.
@rtype: L{dict}
"""
raise NotImplementedError("getPublicKeys unimplemented")
def getPrivateKeys(self):
"""
Called when the factory is started to get the private portions of the
servers host keys. Returns a dictionary mapping SSH key types to
L{twisted.conch.ssh.keys.Key} objects.
@rtype: L{dict}
"""
raise NotImplementedError("getPrivateKeys unimplemented")
def getPrimes(self) -> Optional[Dict[int, List[Tuple[int, int]]]]:
"""
Called when the factory is started to get Diffie-Hellman generators and
primes to use. Returns a dictionary mapping number of bits to lists of
tuple of (generator, prime).
"""
def getDHPrime(self, bits: int) -> Tuple[int, int]:
"""
Return a tuple of (g, p) for a Diffie-Hellman process, with p being as
close to C{bits} bits as possible.
"""
def keyfunc(i: int) -> int:
return abs(i - bits)
assert self.primes is not None, "Factory should have been started by now."
primesKeys = sorted(self.primes.keys(), key=keyfunc)
realBits = primesKeys[0]
return random.choice(self.primes[realBits])
def getService(self, transport, service):
"""
Return a class to use as a service for the given transport.
@type transport: L{transport.SSHServerTransport}
@type service: L{bytes}
@rtype: subclass of L{service.SSHService}
"""
if service == b"ssh-userauth" or hasattr(transport, "avatar"):
return self.services[service]
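
A minimal concrete factory might look like the sketch below; the host key file paths are assumptions, and a real factory would also set a portal for authentication.

from twisted.conch.ssh import factory, keys


class ExampleSSHFactory(factory.SSHFactory):
    def getPublicKeys(self):
        # illustrative path; any readable host key file works
        return {b"ssh-rsa": keys.Key.fromFile("ssh_host_rsa_key.pub")}

    def getPrivateKeys(self):
        return {b"ssh-rsa": keys.Key.fromFile("ssh_host_rsa_key")}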

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,272 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of the TCP forwarding, which allows
clients and servers to forward arbitrary TCP data across the connection.
Maintainer: Paul Swartz
"""
import struct
from twisted.conch.ssh import channel, common
from twisted.internet import protocol, reactor
from twisted.internet.endpoints import HostnameEndpoint, connectProtocol
class SSHListenForwardingFactory(protocol.Factory):
def __init__(self, connection, hostport, klass):
self.conn = connection
self.hostport = hostport # tuple
self.klass = klass
def buildProtocol(self, addr):
channel = self.klass(conn=self.conn)
client = SSHForwardingClient(channel)
channel.client = client
addrTuple = (addr.host, addr.port)
channelOpenData = packOpen_direct_tcpip(self.hostport, addrTuple)
self.conn.openChannel(channel, channelOpenData)
return client
class SSHListenForwardingChannel(channel.SSHChannel):
def channelOpen(self, specificData):
self._log.info("opened forwarding channel {id}", id=self.id)
if len(self.client.buf) > 1:
b = self.client.buf[1:]
self.write(b)
self.client.buf = b""
def openFailed(self, reason):
self.closed()
def dataReceived(self, data):
self.client.transport.write(data)
def eofReceived(self):
self.client.transport.loseConnection()
def closed(self):
if hasattr(self, "client"):
self._log.info("closing local forwarding channel {id}", id=self.id)
self.client.transport.loseConnection()
del self.client
class SSHListenClientForwardingChannel(SSHListenForwardingChannel):
name = b"direct-tcpip"
class SSHListenServerForwardingChannel(SSHListenForwardingChannel):
name = b"forwarded-tcpip"
class SSHConnectForwardingChannel(channel.SSHChannel):
"""
Channel used for handling server side forwarding request.
It acts as a client for the remote forwarding destination.
@ivar hostport: C{(host, port)} requested by client as forwarding
destination.
@type hostport: L{tuple} or a C{sequence}
@ivar client: Protocol connected to the forwarding destination.
@type client: L{protocol.Protocol}
@ivar clientBuf: Data received while forwarding channel is not yet
connected.
@type clientBuf: L{bytes}
@var _reactor: Reactor used for TCP connections.
@type _reactor: A reactor.
@ivar _channelOpenDeferred: Deferred used in testing to check the
result of C{channelOpen}.
@type _channelOpenDeferred: L{twisted.internet.defer.Deferred}
"""
_reactor = reactor
def __init__(self, hostport, *args, **kw):
channel.SSHChannel.__init__(self, *args, **kw)
self.hostport = hostport
self.client = None
self.clientBuf = b""
def channelOpen(self, specificData):
"""
See: L{channel.SSHChannel}
"""
self._log.info(
"connecting to {host}:{port}", host=self.hostport[0], port=self.hostport[1]
)
ep = HostnameEndpoint(self._reactor, self.hostport[0], self.hostport[1])
d = connectProtocol(ep, SSHForwardingClient(self))
d.addCallbacks(self._setClient, self._close)
self._channelOpenDeferred = d
def _setClient(self, client):
"""
Called when the connection was established to the forwarding
destination.
@param client: Client protocol connected to the forwarding destination.
@type client: L{protocol.Protocol}
"""
self.client = client
self._log.info(
"connected to {host}:{port}", host=self.hostport[0], port=self.hostport[1]
)
if self.clientBuf:
self.client.transport.write(self.clientBuf)
self.clientBuf = None
if self.client.buf[1:]:
self.write(self.client.buf[1:])
self.client.buf = b""
def _close(self, reason):
"""
Called when failed to connect to the forwarding destination.
@param reason: Reason why connection failed.
@type reason: L{twisted.python.failure.Failure}
"""
self._log.error(
"failed to connect to {host}:{port}: {reason}",
host=self.hostport[0],
port=self.hostport[1],
reason=reason,
)
self.loseConnection()
def dataReceived(self, data):
"""
See: L{channel.SSHChannel}
"""
if self.client:
self.client.transport.write(data)
else:
self.clientBuf += data
def closed(self):
"""
See: L{channel.SSHChannel}
"""
if self.client:
self._log.info("closed remote forwarding channel {id}", id=self.id)
if self.client.channel:
self.loseConnection()
self.client.transport.loseConnection()
del self.client
def openConnectForwardingClient(remoteWindow, remoteMaxPacket, data, avatar):
remoteHP, origHP = unpackOpen_direct_tcpip(data)
return SSHConnectForwardingChannel(
remoteHP,
remoteWindow=remoteWindow,
remoteMaxPacket=remoteMaxPacket,
avatar=avatar,
)
class SSHForwardingClient(protocol.Protocol):
def __init__(self, channel):
self.channel = channel
self.buf = b"\000"
def dataReceived(self, data):
if self.buf:
self.buf += data
else:
self.channel.write(data)
def connectionLost(self, reason):
if self.channel:
self.channel.loseConnection()
self.channel = None
def packOpen_direct_tcpip(destination, source):
"""
Pack the data suitable for sending in a CHANNEL_OPEN packet.
@type destination: L{tuple}
@param destination: A tuple of the (host, port) of the destination host.
@type source: L{tuple}
@param source: A tuple of the (host, port) of the source host.
"""
(connHost, connPort) = destination
(origHost, origPort) = source
if isinstance(connHost, str):
connHost = connHost.encode("utf-8")
if isinstance(origHost, str):
origHost = origHost.encode("utf-8")
conn = common.NS(connHost) + struct.pack(">L", connPort)
orig = common.NS(origHost) + struct.pack(">L", origPort)
return conn + orig
packOpen_forwarded_tcpip = packOpen_direct_tcpip
def unpackOpen_direct_tcpip(data):
"""Unpack the data to a usable format."""
connHost, rest = common.getNS(data)
if isinstance(connHost, bytes):
connHost = connHost.decode("utf-8")
connPort = int(struct.unpack(">L", rest[:4])[0])
origHost, rest = common.getNS(rest[4:])
if isinstance(origHost, bytes):
origHost = origHost.decode("utf-8")
origPort = int(struct.unpack(">L", rest[:4])[0])
return (connHost, connPort), (origHost, origPort)
unpackOpen_forwarded_tcpip = unpackOpen_direct_tcpip
def packGlobal_tcpip_forward(peer):
"""
Pack the data for tcpip forwarding.
@param peer: A tuple of (host, port).
@type peer: L{tuple}
"""
(host, port) = peer
return common.NS(host) + struct.pack(">L", port)
def unpackGlobal_tcpip_forward(data):
host, rest = common.getNS(data)
if isinstance(host, bytes):
host = host.decode("utf-8")
port = int(struct.unpack(">L", rest[:4])[0])
return host, port
"""This is how the data -> eof -> close stuff /should/ work.
debug3: channel 1: waiting for connection
debug1: channel 1: connected
debug1: channel 1: read<=0 rfd 7 len 0
debug1: channel 1: read failed
debug1: channel 1: close_read
debug1: channel 1: input open -> drain
debug1: channel 1: ibuf empty
debug1: channel 1: send eof
debug1: channel 1: input drain -> closed
debug1: channel 1: rcvd eof
debug1: channel 1: output open -> drain
debug1: channel 1: obuf empty
debug1: channel 1: close_write
debug1: channel 1: output drain -> closed
debug1: channel 1: rcvd close
debug3: channel 1: will not send data after close
debug1: channel 1: send close
debug1: channel 1: is dead
"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,56 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
The parent class for all the SSH services. Currently implemented services
are ssh-userauth and ssh-connection.
Maintainer: Paul Swartz
"""
from typing import Dict
from twisted.logger import Logger
class SSHService:
# this is the ssh name for the service:
name: bytes = None # type:ignore[assignment]
protocolMessages: Dict[int, str] = {} # map #'s -> protocol names
transport = None # gets set later
_log = Logger()
def serviceStarted(self):
"""
called when the service is active on the transport.
"""
def serviceStopped(self):
"""
called when the service is stopped, either by the connection ending
or by another service being started
"""
def logPrefix(self):
return "SSHService {!r} on {}".format(
self.name, self.transport.transport.logPrefix()
)
def packetReceived(self, messageNum, packet):
"""
called when we receive a packet on the transport
"""
# print self.protocolMessages
if messageNum in self.protocolMessages:
messageType = self.protocolMessages[messageNum]
f = getattr(self, "ssh_%s" % messageType[4:], None)
if f is not None:
return f(packet)
self._log.info(
"couldn't handle {messageNum} {packet!r}",
messageNum=messageNum,
packet=packet,
)
self.transport.sendUnimplemented()
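
For illustration, a concrete SSHService is essentially a name, a message table, and ssh_* handlers that packetReceived() dispatches to; the service name and message number below are hypothetical.

from twisted.conch.ssh import service

MSG_PING = 192  # hypothetical message number for this example service


class PingService(service.SSHService):
    name = b"example-ping"
    protocolMessages = {MSG_PING: "MSG_PING"}

    def ssh_PING(self, packet):
        # packetReceived() strips the "MSG_" prefix to find this handler
        self._log.info("got ping: {packet!r}", packet=packet)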

View File

@@ -0,0 +1,440 @@
# -*- test-case-name: twisted.conch.test.test_session -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of SSHSession, which (by default)
allows access to a shell and a python interpreter over SSH.
Maintainer: Paul Swartz
"""
import os
import signal
import struct
import sys
from zope.interface import implementer
from twisted.conch.interfaces import (
EnvironmentVariableNotPermitted,
ISession,
ISessionSetEnv,
)
from twisted.conch.ssh import channel, common, connection
from twisted.internet import interfaces, protocol
from twisted.logger import Logger
from twisted.python.compat import networkString
log = Logger()
class SSHSession(channel.SSHChannel):
"""
A generalized implementation of an SSH session.
See RFC 4254, section 6.
The precise implementation of the various operations that the remote end
can send is left up to the avatar, usually via an adapter to an
interface such as L{ISession}.
@ivar buf: a buffer for data received before making a connection to a
client.
@type buf: L{bytes}
@ivar client: a protocol for communication with a shell, an application
program, or a subsystem (see RFC 4254, section 6.5).
@type client: L{SSHSessionProcessProtocol}
@ivar session: an object providing concrete implementations of session
operations.
@type session: L{ISession}
"""
name = b"session"
def __init__(self, *args, **kw):
channel.SSHChannel.__init__(self, *args, **kw)
self.buf = b""
self.client = None
self.session = None
def request_subsystem(self, data):
subsystem, ignored = common.getNS(data)
log.info('Asking for subsystem "{subsystem}"', subsystem=subsystem)
client = self.avatar.lookupSubsystem(subsystem, data)
if client:
pp = SSHSessionProcessProtocol(self)
proto = wrapProcessProtocol(pp)
client.makeConnection(proto)
pp.makeConnection(wrapProtocol(client))
self.client = pp
return 1
else:
log.error("Failed to get subsystem")
return 0
def request_shell(self, data):
log.info("Getting shell")
if not self.session:
self.session = ISession(self.avatar)
try:
pp = SSHSessionProcessProtocol(self)
self.session.openShell(pp)
except Exception:
log.failure("Error getting shell")
return 0
else:
self.client = pp
return 1
def request_exec(self, data):
if not self.session:
self.session = ISession(self.avatar)
f, data = common.getNS(data)
log.info('Executing command "{f}"', f=f)
try:
pp = SSHSessionProcessProtocol(self)
self.session.execCommand(pp, f)
except Exception:
log.failure('Error executing command "{f}"', f=f)
return 0
else:
self.client = pp
return 1
def request_pty_req(self, data):
if not self.session:
self.session = ISession(self.avatar)
term, windowSize, modes = parseRequest_pty_req(data)
log.info(
"Handling pty request: {term!r} {windowSize!r}",
term=term,
windowSize=windowSize,
)
try:
self.session.getPty(term, windowSize, modes)
except Exception:
log.failure("Error handling pty request")
return 0
else:
return 1
def request_env(self, data):
"""
Process a request to pass an environment variable.
@param data: The environment variable name and value, each encoded
as an SSH protocol string and concatenated.
@type data: L{bytes}
@return: A true value if the request to pass this environment
variable was accepted, otherwise a false value.
"""
if not self.session:
self.session = ISession(self.avatar)
if not ISessionSetEnv.providedBy(self.session):
return 0
name, value, data = common.getNS(data, 2)
try:
self.session.setEnv(name, value)
except EnvironmentVariableNotPermitted:
return 0
except Exception:
log.failure("Error setting environment variable {name}", name=name)
return 0
else:
return 1
def request_window_change(self, data):
if not self.session:
self.session = ISession(self.avatar)
winSize = parseRequest_window_change(data)
try:
self.session.windowChanged(winSize)
except Exception:
log.failure("Error changing window size")
return 0
else:
return 1
def dataReceived(self, data):
if not self.client:
# self.conn.sendClose(self)
self.buf += data
return
self.client.transport.write(data)
def extReceived(self, dataType, data):
if dataType == connection.EXTENDED_DATA_STDERR:
if self.client and hasattr(self.client.transport, "writeErr"):
self.client.transport.writeErr(data)
else:
log.warn("Weird extended data: {dataType}", dataType=dataType)
def eofReceived(self):
# If we have a session, tell it that EOF has been received and
# expect it to send a close message (it may need to send other
# messages such as exit-status or exit-signal first). If we don't
# have a session, then just send a close message directly.
if self.session:
self.session.eofReceived()
elif self.client:
self.conn.sendClose(self)
def closed(self):
if self.client and self.client.transport:
self.client.transport.loseConnection()
if self.session:
self.session.closed()
# def closeReceived(self):
# self.loseConnection() # don't know what to do with this
def loseConnection(self):
if self.client:
self.client.transport.loseConnection()
channel.SSHChannel.loseConnection(self)
class _ProtocolWrapper(protocol.ProcessProtocol):
"""
This class wraps a L{Protocol} instance in a L{ProcessProtocol} instance.
"""
def __init__(self, proto):
self.proto = proto
def connectionMade(self):
self.proto.connectionMade()
def outReceived(self, data):
self.proto.dataReceived(data)
def processEnded(self, reason):
self.proto.connectionLost(reason)
class _DummyTransport:
def __init__(self, proto):
self.proto = proto
def dataReceived(self, data):
self.proto.transport.write(data)
def write(self, data):
self.proto.dataReceived(data)
def writeSequence(self, seq):
self.write(b"".join(seq))
def loseConnection(self):
self.proto.connectionLost(protocol.connectionDone)
def wrapProcessProtocol(inst):
if isinstance(inst, protocol.Protocol):
return _ProtocolWrapper(inst)
else:
return inst
def wrapProtocol(proto):
return _DummyTransport(proto)
# SUPPORTED_SIGNALS is a list of signals that every session channel is supposed
# to accept. See RFC 4254
SUPPORTED_SIGNALS = [
"ABRT",
"ALRM",
"FPE",
"HUP",
"ILL",
"INT",
"KILL",
"PIPE",
"QUIT",
"SEGV",
"TERM",
"USR1",
"USR2",
]
@implementer(interfaces.ITransport)
class SSHSessionProcessProtocol(protocol.ProcessProtocol):
"""I am both an L{IProcessProtocol} and an L{ITransport}.
I am a transport to the remote endpoint and a process protocol to the
local subsystem.
"""
# once initialized, a dictionary mapping signal values to strings
# that follow RFC 4254.
_signalValuesToNames = None
def __init__(self, session):
self.session = session
self.lostOutOrErrFlag = False
def connectionMade(self):
if self.session.buf:
self.transport.write(self.session.buf)
self.session.buf = None
def outReceived(self, data):
self.session.write(data)
def errReceived(self, err):
self.session.writeExtended(connection.EXTENDED_DATA_STDERR, err)
def outConnectionLost(self):
"""
EOF should only be sent when both STDOUT and STDERR have been closed.
"""
if self.lostOutOrErrFlag:
self.session.conn.sendEOF(self.session)
else:
self.lostOutOrErrFlag = True
def errConnectionLost(self):
"""
See outConnectionLost().
"""
self.outConnectionLost()
def connectionLost(self, reason=None):
self.session.loseConnection()
def _getSignalName(self, signum):
"""
Get a signal name given a signal number.
"""
if self._signalValuesToNames is None:
self._signalValuesToNames = {}
# make sure that the POSIX ones are the defaults
for signame in SUPPORTED_SIGNALS:
signame = "SIG" + signame
sigvalue = getattr(signal, signame, None)
if sigvalue is not None:
self._signalValuesToNames[sigvalue] = signame
for k, v in signal.__dict__.items():
# Check for platform specific signals, ignoring Python specific
# SIG_DFL and SIG_IGN
if k.startswith("SIG") and not k.startswith("SIG_"):
if v not in self._signalValuesToNames:
self._signalValuesToNames[v] = k + "@" + sys.platform
return self._signalValuesToNames[signum]
def processEnded(self, reason=None):
"""
When we are told the process ended, try to notify the other side about
how the process ended using the exit-signal or exit-status requests.
Also, close the channel.
"""
if reason is not None:
err = reason.value
if err.signal is not None:
signame = self._getSignalName(err.signal)
if getattr(os, "WCOREDUMP", None) is not None and os.WCOREDUMP(
err.status
):
log.info("exitSignal: {signame} (core dumped)", signame=signame)
coreDumped = True
else:
log.info("exitSignal: {}", signame=signame)
coreDumped = False
self.session.conn.sendRequest(
self.session,
b"exit-signal",
common.NS(networkString(signame[3:]))
+ (b"\1" if coreDumped else b"\0")
+ common.NS(b"")
+ common.NS(b""),
)
elif err.exitCode is not None:
log.info("exitCode: {exitCode!r}", exitCode=err.exitCode)
self.session.conn.sendRequest(
self.session, b"exit-status", struct.pack(">L", err.exitCode)
)
self.session.loseConnection()
def getHost(self):
"""
Return the host from my session's transport.
"""
return self.session.conn.transport.getHost()
def getPeer(self):
"""
Return the peer from my session's transport.
"""
return self.session.conn.transport.getPeer()
def write(self, data):
self.session.write(data)
def writeSequence(self, seq):
self.session.write(b"".join(seq))
def loseConnection(self):
self.session.loseConnection()
class SSHSessionClient(protocol.Protocol):
def dataReceived(self, data):
if self.transport:
self.transport.write(data)
# methods factored out to make life easier on server writers
def parseRequest_pty_req(data):
"""Parse the data from a pty-req request into usable data.
@returns: a tuple of (terminal type, (rows, cols, xpixel, ypixel), modes)
"""
term, rest = common.getNS(data)
cols, rows, xpixel, ypixel = struct.unpack(">4L", rest[:16])
modes, ignored = common.getNS(rest[16:])
winSize = (rows, cols, xpixel, ypixel)
modes = [
(ord(modes[i : i + 1]), struct.unpack(">L", modes[i + 1 : i + 5])[0])
for i in range(0, len(modes) - 1, 5)
]
return term, winSize, modes
def packRequest_pty_req(term, geometry, modes):
"""
Pack a pty-req request so that it is suitable for sending.
NOTE: modes must be packed before being sent here.
@type geometry: L{tuple}
@param geometry: A tuple of (rows, columns, xpixel, ypixel)
"""
(rows, cols, xpixel, ypixel) = geometry
termPacked = common.NS(term)
winSizePacked = struct.pack(">4L", cols, rows, xpixel, ypixel)
modesPacked = common.NS(modes) # depend on the client packing modes
return termPacked + winSizePacked + modesPacked
def parseRequest_window_change(data):
"""Parse the data from a window-change request into usuable data.
@returns: a tuple of (rows, cols, xpixel, ypixel)
"""
cols, rows, xpixel, ypixel = struct.unpack(">4L", data)
return rows, cols, xpixel, ypixel
def packRequest_window_change(geometry):
"""
Pack a window-change request so that it is suitable for sending.
@type geometry: L{tuple}
@param geometry: A tuple of (rows, columns, xpixel, ypixel)
"""
(rows, cols, xpixel, ypixel) = geometry
return struct.pack(">4L", cols, rows, xpixel, ypixel)

View File

@@ -0,0 +1,40 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
def parse(s):
s = s.strip()
expr = []
while s:
if s[0:1] == b"(":
newSexp = []
if expr:
expr[-1].append(newSexp)
expr.append(newSexp)
s = s[1:]
continue
if s[0:1] == b")":
aList = expr.pop()
s = s[1:]
if not expr:
assert not s
return aList
continue
i = 0
while s[i : i + 1].isdigit():
i += 1
assert i
length = int(s[:i])
data = s[i + 1 : i + 1 + length]
expr[-1].append(data)
s = s[i + 1 + length :]
assert False, "this should not happen"
def pack(sexp):
return b"".join(
b"(%b)" % (pack(o),)
if type(o) in (type(()), type([]))
else b"%d:%b" % (len(o), o)
for o in sexp
)
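
For illustration, the parse and pack helpers above round-trip simple nested byte-string structures (the values are arbitrary).

packed = pack([[b"abc", b"def"], [b"ghi"]])
# packed == b"(3:abc3:def)(3:ghi)"
assert parse(b"(3:abc3:def)") == [b"abc", b"def"]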

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,764 @@
# -*- test-case-name: twisted.conch.test.test_userauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of the ssh-userauth service.
Currently implemented authentication types are public-key and password.
Maintainer: Paul Swartz
"""
import struct
from twisted.conch import error, interfaces
from twisted.conch.ssh import keys, service, transport
from twisted.conch.ssh.common import NS, getNS
from twisted.cred import credentials
from twisted.cred.error import UnauthorizedLogin
from twisted.internet import defer, reactor
from twisted.logger import Logger
from twisted.python import failure
from twisted.python.compat import nativeString
class SSHUserAuthServer(service.SSHService):
"""
A service implementing the server side of the 'ssh-userauth' service. It
is used to authenticate the user on the other side as being able to access
this server.
@ivar name: the name of this service: 'ssh-userauth'
@type name: L{bytes}
@ivar authenticatedWith: a list of authentication methods that have
already been used.
@type authenticatedWith: L{list}
@ivar loginTimeout: the number of seconds we wait before disconnecting
the user for taking too long to authenticate
@type loginTimeout: L{int}
@ivar attemptsBeforeDisconnect: the number of failed login attempts we
allow before disconnecting.
@type attemptsBeforeDisconnect: L{int}
@ivar loginAttempts: the number of login attempts that have been made
@type loginAttempts: L{int}
@ivar passwordDelay: the number of seconds to delay when the user gives
an incorrect password
@type passwordDelay: L{int}
@ivar interfaceToMethod: a L{dict} mapping credential interfaces to
authentication methods. The server checks to see which of the
cred interfaces have checkers and tells the client that those methods
are valid for authentication.
@type interfaceToMethod: L{dict}
@ivar supportedAuthentications: A list of the supported authentication
methods.
@type supportedAuthentications: L{list} of L{bytes}
@ivar user: the last username the client tried to authenticate with
@type user: L{bytes}
@ivar method: the current authentication method
@type method: L{bytes}
@ivar nextService: the service the user wants started after authentication
has been completed.
@type nextService: L{bytes}
@ivar portal: the L{twisted.cred.portal.Portal} we are using for
authentication
@type portal: L{twisted.cred.portal.Portal}
@ivar clock: an object with a callLater method. Stubbed out for testing.
"""
name = b"ssh-userauth"
loginTimeout = 10 * 60 * 60
# 10 hours (in seconds) before we disconnect them
attemptsBeforeDisconnect = 20
# 20 login attempts before a disconnect
passwordDelay = 1 # number of seconds to delay on a failed password
clock = reactor
interfaceToMethod = {
credentials.ISSHPrivateKey: b"publickey",
credentials.IUsernamePassword: b"password",
}
_log = Logger()
def serviceStarted(self):
"""
Called when the userauth service is started. Set up instance
variables, check if we should allow password authentication (only
allow if the outgoing connection is encrypted) and set up a login
timeout.
"""
self.authenticatedWith = []
self.loginAttempts = 0
self.user = None
self.nextService = None
self.portal = self.transport.factory.portal
self.supportedAuthentications = []
for i in self.portal.listCredentialsInterfaces():
if i in self.interfaceToMethod:
self.supportedAuthentications.append(self.interfaceToMethod[i])
if not self.transport.isEncrypted("in"):
# don't let us transport password in plaintext
if b"password" in self.supportedAuthentications:
self.supportedAuthentications.remove(b"password")
self._cancelLoginTimeout = self.clock.callLater(
self.loginTimeout, self.timeoutAuthentication
)
def serviceStopped(self):
"""
Called when the userauth service is stopped. Cancel the login timeout
if it's still going.
"""
if self._cancelLoginTimeout:
self._cancelLoginTimeout.cancel()
self._cancelLoginTimeout = None
def timeoutAuthentication(self):
"""
Called when the user has timed out on authentication. Disconnect
with a DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE message.
"""
self._cancelLoginTimeout = None
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, b"you took too long"
)
def tryAuth(self, kind, user, data):
"""
Try to authenticate the user with the given method. Dispatches to a
auth_* method.
@param kind: the authentication method to try.
@type kind: L{bytes}
@param user: the username the client is authenticating with.
@type user: L{bytes}
@param data: authentication specific data sent by the client.
@type data: L{bytes}
@return: A Deferred called back if the method succeeded, or erred back
if it failed.
@rtype: C{defer.Deferred}
"""
self._log.debug("{user!r} trying auth {kind!r}", user=user, kind=kind)
if kind not in self.supportedAuthentications:
return defer.fail(error.ConchError("unsupported authentication, failing"))
kind = nativeString(kind.replace(b"-", b"_"))
f = getattr(self, f"auth_{kind}", None)
if f:
ret = f(data)
if not ret:
return defer.fail(
error.ConchError(f"{kind} return None instead of a Deferred")
)
else:
return ret
return defer.fail(error.ConchError(f"bad auth type: {kind}"))
def ssh_USERAUTH_REQUEST(self, packet):
"""
The client has requested authentication. Payload::
string user
string next service
string method
<authentication specific data>
@type packet: L{bytes}
"""
user, nextService, method, rest = getNS(packet, 3)
if user != self.user or nextService != self.nextService:
self.authenticatedWith = [] # clear auth state
self.user = user
self.nextService = nextService
self.method = method
d = self.tryAuth(method, user, rest)
if not d:
self._ebBadAuth(failure.Failure(error.ConchError("auth returned none")))
return
d.addCallback(self._cbFinishedAuth)
d.addErrback(self._ebMaybeBadAuth)
d.addErrback(self._ebBadAuth)
return d
def _cbFinishedAuth(self, result):
"""
The callback when user has successfully been authenticated. For a
description of the arguments, see L{twisted.cred.portal.Portal.login}.
We start the service requested by the user.
"""
(interface, avatar, logout) = result
self.transport.avatar = avatar
self.transport.logoutFunction = logout
service = self.transport.factory.getService(self.transport, self.nextService)
if not service:
raise error.ConchError(f"could not get next service: {self.nextService}")
self._log.debug(
"{user!r} authenticated with {method!r}", user=self.user, method=self.method
)
self.transport.sendPacket(MSG_USERAUTH_SUCCESS, b"")
self.transport.setService(service())
def _ebMaybeBadAuth(self, reason):
"""
An intermediate errback. If the reason is
error.NotEnoughAuthentication, we send a MSG_USERAUTH_FAILURE, but
with the partial success indicator set.
@type reason: L{twisted.python.failure.Failure}
"""
reason.trap(error.NotEnoughAuthentication)
self.transport.sendPacket(
MSG_USERAUTH_FAILURE, NS(b",".join(self.supportedAuthentications)) + b"\xff"
)
def _ebBadAuth(self, reason):
"""
The final errback in the authentication chain. If the reason is
error.IgnoreAuthentication, we simply return; the authentication
method has sent its own response. Otherwise, send a failure message
and (if the method is not 'none') increment the number of login
attempts.
@type reason: L{twisted.python.failure.Failure}
"""
if reason.check(error.IgnoreAuthentication):
return
if self.method != b"none":
self._log.debug(
"{user!r} failed auth {method!r}", user=self.user, method=self.method
)
if reason.check(UnauthorizedLogin):
self._log.debug(
"unauthorized login: {message}", message=reason.getErrorMessage()
)
elif reason.check(error.ConchError):
self._log.debug("reason: {reason}", reason=reason.getErrorMessage())
else:
self._log.failure(
"Error checking auth for user {user}",
failure=reason,
user=self.user,
)
self.loginAttempts += 1
if self.loginAttempts > self.attemptsBeforeDisconnect:
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
b"too many bad auths",
)
return
self.transport.sendPacket(
MSG_USERAUTH_FAILURE, NS(b",".join(self.supportedAuthentications)) + b"\x00"
)
def auth_publickey(self, packet):
"""
Public key authentication. Payload::
byte has signature
string algorithm name
string key blob
[string signature] (if has signature is True)
Create a SSHPublicKey credential and verify it using our portal.
"""
hasSig = ord(packet[0:1])
algName, blob, rest = getNS(packet[1:], 2)
try:
keys.Key.fromString(blob)
except keys.BadKeyError:
error = "Unsupported key type {} or bad key".format(algName.decode("ascii"))
self._log.error(error)
return defer.fail(UnauthorizedLogin(error))
signature = hasSig and getNS(rest)[0] or None
if hasSig:
b = (
NS(self.transport.sessionID)
+ bytes((MSG_USERAUTH_REQUEST,))
+ NS(self.user)
+ NS(self.nextService)
+ NS(b"publickey")
+ bytes((hasSig,))
+ NS(algName)
+ NS(blob)
)
c = credentials.SSHPrivateKey(self.user, algName, blob, b, signature)
return self.portal.login(c, None, interfaces.IConchUser)
else:
c = credentials.SSHPrivateKey(self.user, algName, blob, None, None)
return self.portal.login(c, None, interfaces.IConchUser).addErrback(
self._ebCheckKey, packet[1:]
)
def _ebCheckKey(self, reason, packet):
"""
Called back if the user did not send a signature. If reason is
error.ValidPublicKey then this key is valid for the user to
authenticate with. Send MSG_USERAUTH_PK_OK.
"""
reason.trap(error.ValidPublicKey)
# if we make it here, it means that the publickey is valid
self.transport.sendPacket(MSG_USERAUTH_PK_OK, packet)
return failure.Failure(error.IgnoreAuthentication())
def auth_password(self, packet):
"""
Password authentication. Payload::
string password
Make a UsernamePassword credential and verify it with our portal.
"""
password = getNS(packet[1:])[0]
c = credentials.UsernamePassword(self.user, password)
return self.portal.login(c, None, interfaces.IConchUser).addErrback(
self._ebPassword
)
def _ebPassword(self, f):
"""
If the password is invalid, wait before sending the failure in order
to delay brute-force password guessing.
"""
d = defer.Deferred()
self.clock.callLater(self.passwordDelay, d.callback, f)
return d
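# --- Illustrative sketch, not part of the original module ------------------
# serviceStarted() above derives supportedAuthentications from the portal's
# registered credential checkers via interfaceToMethod.  A minimal portal
# that would make this server offer b"password" could be built like this;
# "realm" is a placeholder for any IRealm that returns IConchUser avatars.
def _examplePasswordPortal(realm):
    from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
    from twisted.cred.portal import Portal

    checker = InMemoryUsernamePasswordDatabaseDontUse()
    checker.addUser(b"alice", b"secret")  # satisfies IUsernamePassword
    p = Portal(realm)
    p.registerChecker(checker)
    # listCredentialsInterfaces() now includes IUsernamePassword, so the
    # server advertises b"password" (only over an encrypted transport).
    return p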
class SSHUserAuthClient(service.SSHService):
"""
A service implementing the client side of 'ssh-userauth'.
This service will try all authentication methods provided by the server,
making callbacks for more information when necessary.
@ivar name: the name of this service: 'ssh-userauth'
@type name: L{str}
@ivar preferredOrder: a list of authentication methods that should be used
first, in order of preference, if supported by the server
@type preferredOrder: L{list}
@ivar user: the name of the user to authenticate as
@type user: L{bytes}
@ivar instance: the service to start after authentication has finished
@type instance: L{service.SSHService}
@ivar authenticatedWith: a list of strings of authentication methods we've tried
@type authenticatedWith: L{list} of L{bytes}
@ivar triedPublicKeys: a list of public key objects that we've tried to
authenticate with
@type triedPublicKeys: L{list} of L{Key}
@ivar lastPublicKey: the last public key object we've tried to authenticate
with
@type lastPublicKey: L{Key}
"""
name = b"ssh-userauth"
preferredOrder = [b"publickey", b"password", b"keyboard-interactive"]
def __init__(self, user, instance):
self.user = user
self.instance = instance
def serviceStarted(self):
self.authenticatedWith = []
self.triedPublicKeys = []
self.lastPublicKey = None
self.askForAuth(b"none", b"")
def askForAuth(self, kind, extraData):
"""
Send a MSG_USERAUTH_REQUEST.
@param kind: the authentication method to try.
@type kind: L{bytes}
@param extraData: method-specific data to go in the packet
@type extraData: L{bytes}
"""
self.lastAuth = kind
self.transport.sendPacket(
MSG_USERAUTH_REQUEST,
NS(self.user) + NS(self.instance.name) + NS(kind) + extraData,
)
def tryAuth(self, kind):
"""
Dispatch to an authentication method.
@param kind: the authentication method
@type kind: L{bytes}
"""
kind = nativeString(kind.replace(b"-", b"_"))
self._log.debug("trying to auth with {kind}", kind=kind)
f = getattr(self, "auth_" + kind, None)
if f:
return f()
def _ebAuth(self, ignored, *args):
"""
Generic callback for a failed authentication attempt. Respond by
asking for the list of accepted methods (the 'none' method)
"""
self.askForAuth(b"none", b"")
def ssh_USERAUTH_SUCCESS(self, packet):
"""
We received a MSG_USERAUTH_SUCCESS. The server has accepted our
authentication, so start the next service.
"""
self.transport.setService(self.instance)
def ssh_USERAUTH_FAILURE(self, packet):
"""
We received a MSG_USERAUTH_FAILURE. Payload::
string methods
byte partial success
If partial success is C{True}, then the previous method succeeded but is
not sufficient for authentication. C{methods} is a comma-separated list
of accepted authentication methods.
We sort the list of methods by their position in C{self.preferredOrder},
removing methods that have already succeeded. We then call
C{self.tryAuth} with the most preferred method.
@param packet: the C{MSG_USERAUTH_FAILURE} payload.
@type packet: L{bytes}
@return: a L{defer.Deferred} that will be callbacked with L{None} as
soon as all authentication methods have been tried, or L{None} if no
more authentication methods are available.
@rtype: C{defer.Deferred} or L{None}
"""
canContinue, partial = getNS(packet)
partial = ord(partial)
if partial:
self.authenticatedWith.append(self.lastAuth)
def orderByPreference(meth):
"""
Invoked once per authentication method in order to extract a
comparison key which is then used for sorting.
@param meth: the authentication method.
@type meth: L{bytes}
@return: the comparison key for C{meth}.
@rtype: L{int}
"""
if meth in self.preferredOrder:
return self.preferredOrder.index(meth)
else:
# put the element at the end of the list.
return len(self.preferredOrder)
canContinue = sorted(
(
meth
for meth in canContinue.split(b",")
if meth not in self.authenticatedWith
),
key=orderByPreference,
)
self._log.debug("can continue with: {methods}", methods=canContinue)
return self._cbUserauthFailure(None, iter(canContinue))
def _cbUserauthFailure(self, result, iterator):
if result:
return
try:
method = next(iterator)
except StopIteration:
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
b"no more authentication methods available",
)
else:
d = defer.maybeDeferred(self.tryAuth, method)
d.addCallback(self._cbUserauthFailure, iterator)
return d
def ssh_USERAUTH_PK_OK(self, packet):
"""
This message (number 60) can mean several different things depending
on the current authentication type. We dispatch to individual methods
in order to handle this request.
"""
func = getattr(
self,
"ssh_USERAUTH_PK_OK_%s" % nativeString(self.lastAuth.replace(b"-", b"_")),
None,
)
if func is not None:
return func(packet)
else:
self.askForAuth(b"none", b"")
def ssh_USERAUTH_PK_OK_publickey(self, packet):
"""
This is MSG_USERAUTH_PK_OK. Our public key is valid, so we create a
signature and try to authenticate with it.
"""
publicKey = self.lastPublicKey
b = (
NS(self.transport.sessionID)
+ bytes((MSG_USERAUTH_REQUEST,))
+ NS(self.user)
+ NS(self.instance.name)
+ NS(b"publickey")
+ b"\x01"
+ NS(publicKey.sshType())
+ NS(publicKey.blob())
)
d = self.signData(publicKey, b)
if not d:
self.askForAuth(b"none", b"")
# this will fail, we'll move on
return
d.addCallback(self._cbSignedData)
d.addErrback(self._ebAuth)
def ssh_USERAUTH_PK_OK_password(self, packet):
"""
This is MSG_USERAUTH_PASSWD_CHANGEREQ. The password given has expired.
We ask for an old password and a new password, then send both back to
the server.
"""
prompt, language, rest = getNS(packet, 2)
self._oldPass = self._newPass = None
d = self.getPassword(b"Old Password: ")
d = d.addCallbacks(self._setOldPass, self._ebAuth)
d.addCallback(lambda ignored: self.getPassword(prompt))
d.addCallbacks(self._setNewPass, self._ebAuth)
def ssh_USERAUTH_PK_OK_keyboard_interactive(self, packet):
"""
This is MSG_USERAUTH_INFO_REQUEST. The server has sent us the
questions it wants us to answer, so we ask the user and send the
responses.
"""
name, instruction, lang, data = getNS(packet, 3)
numPrompts = struct.unpack("!L", data[:4])[0]
data = data[4:]
prompts = []
for i in range(numPrompts):
prompt, data = getNS(data)
echo = bool(ord(data[0:1]))
data = data[1:]
prompts.append((prompt, echo))
d = self.getGenericAnswers(name, instruction, prompts)
d.addCallback(self._cbGenericAnswers)
d.addErrback(self._ebAuth)
def _cbSignedData(self, signedData):
"""
Called back out of self.signData with the signed data. Send the
authentication request with the signature.
@param signedData: the data signed by the user's private key.
@type signedData: L{bytes}
"""
publicKey = self.lastPublicKey
self.askForAuth(
b"publickey",
b"\x01" + NS(publicKey.sshType()) + NS(publicKey.blob()) + NS(signedData),
)
def _setOldPass(self, op):
"""
Called back when we are choosing a new password. Simply store the old
password for now.
@param op: the old password as entered by the user
@type op: L{bytes}
"""
self._oldPass = op
def _setNewPass(self, np):
"""
Called back when we are choosing a new password. Get the old password
and send the authentication message with both.
@param np: the new password as entered by the user
@type np: L{bytes}
"""
op = self._oldPass
self._oldPass = None
self.askForAuth(b"password", b"\xff" + NS(op) + NS(np))
def _cbGenericAnswers(self, responses):
"""
Called back when we are finished answering keyboard-interactive
questions. Send the info back to the server in a
MSG_USERAUTH_INFO_RESPONSE.
@param responses: a list of L{bytes} responses
@type responses: L{list}
"""
data = struct.pack("!L", len(responses))
for r in responses:
data += NS(r.encode("UTF8"))
self.transport.sendPacket(MSG_USERAUTH_INFO_RESPONSE, data)
def auth_publickey(self):
"""
Try to authenticate with a public key. Ask the user for a public key;
if the user has one, send the request to the server and return True.
Otherwise, return False.
@rtype: L{bool}
"""
d = defer.maybeDeferred(self.getPublicKey)
d.addBoth(self._cbGetPublicKey)
return d
def _cbGetPublicKey(self, publicKey):
if not isinstance(publicKey, keys.Key): # failure or None
publicKey = None
if publicKey is not None:
self.lastPublicKey = publicKey
self.triedPublicKeys.append(publicKey)
self._log.debug("using key of type {keyType}", keyType=publicKey.type())
self.askForAuth(
b"publickey", b"\x00" + NS(publicKey.sshType()) + NS(publicKey.blob())
)
return True
else:
return False
def auth_password(self):
"""
Try to authenticate with a password. Ask the user for a password.
If a password is forthcoming (C{getPassword} returned a Deferred),
return True. Otherwise, return False.
@rtype: L{bool}
"""
d = self.getPassword()
if d:
d.addCallbacks(self._cbPassword, self._ebAuth)
return True
else: # returned None, don't do password auth
return False
def auth_keyboard_interactive(self):
"""
Try to authenticate with keyboard-interactive authentication. Send
the request to the server and return True.
@rtype: L{bool}
"""
self._log.debug("authing with keyboard-interactive")
self.askForAuth(b"keyboard-interactive", NS(b"") + NS(b""))
return True
def _cbPassword(self, password):
"""
Called back when the user gives a password. Send the request to the
server.
@param password: the password the user entered
@type password: L{bytes}
"""
self.askForAuth(b"password", b"\x00" + NS(password))
def signData(self, publicKey, signData):
"""
Sign the given data using the private key that corresponds to C{publicKey}.
By default, this will call getPrivateKey to get the private key,
then sign the data using Key.sign().
This method is factored out so that it can be overridden to use
alternate methods, such as a key agent.
@param publicKey: The public key object returned from L{getPublicKey}
@type publicKey: L{keys.Key}
@param signData: the data to be signed by the private key.
@type signData: L{bytes}
@return: a Deferred that's called back with the signature
@rtype: L{defer.Deferred}
"""
key = self.getPrivateKey()
if not key:
return
return key.addCallback(self._cbSignData, signData)
def _cbSignData(self, privateKey, signData):
"""
Called back when the private key is returned. Sign the data and
return the signature.
@param privateKey: the private key object
@type privateKey: L{keys.Key}
@param signData: the data to be signed by the private key.
@type signData: L{bytes}
@return: the signature
@rtype: L{bytes}
"""
return privateKey.sign(signData)
def getPublicKey(self):
"""
Return a public key for the user. If no more public keys are
available, return L{None}.
This implementation always returns L{None}. Override it in a
subclass to actually find and return a public key object.
@rtype: L{Key} or L{None}
"""
return None
def getPrivateKey(self):
"""
Return a L{Deferred} that will be called back with the private key
object corresponding to the last public key from getPublicKey().
If the private key is not available, errback on the Deferred.
@rtype: L{Deferred} called back with L{Key}
"""
return defer.fail(NotImplementedError())
def getPassword(self, prompt=None):
"""
Return a L{Deferred} that will be called back with a password.
prompt is a string to display for the password, or None for a generic
'user@hostname's password: '.
@type prompt: L{bytes}/L{None}
@rtype: L{defer.Deferred}
"""
return defer.fail(NotImplementedError())
def getGenericAnswers(self, name, instruction, prompts):
"""
Returns a L{Deferred} with the responses to the prompts.
@param name: The name of the authentication currently in progress.
@param instruction: Describes what the authentication wants.
@param prompts: A list of (prompt, echo) pairs, where prompt is a
string to display and echo is a boolean indicating whether the
user's response should be echoed as they type it.
"""
return defer.fail(NotImplementedError())
MSG_USERAUTH_REQUEST = 50
MSG_USERAUTH_FAILURE = 51
MSG_USERAUTH_SUCCESS = 52
MSG_USERAUTH_BANNER = 53
MSG_USERAUTH_INFO_RESPONSE = 61
MSG_USERAUTH_PK_OK = 60
messages = {}
for k, v in list(locals().items()):
if k[:4] == "MSG_":
messages[v] = k
SSHUserAuthServer.protocolMessages = messages
SSHUserAuthClient.protocolMessages = messages
del messages
del v
# Doubles, not included in the protocols' mappings
MSG_USERAUTH_PASSWD_CHANGEREQ = 60
MSG_USERAUTH_INFO_REQUEST = 60
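# --- Illustrative sketch, not part of the original module ------------------
# A minimal SSHUserAuthClient subclass showing the hooks an application is
# expected to override (getPassword, getPublicKey, getPrivateKey).  The key
# path and password below are placeholders; "defer" and "keys" are this
# module's existing imports.
class _ExampleUserAuthClient(SSHUserAuthClient):
    keyPath = "/home/alice/.ssh/id_rsa"  # placeholder path

    def getPassword(self, prompt=None):
        # Fire with the password, or return None to skip password auth.
        return defer.succeed(b"example-password")

    def getPublicKey(self):
        return keys.Key.fromFile(self.keyPath).public()

    def getPrivateKey(self):
        return defer.succeed(keys.Key.fromFile(self.keyPath))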

View File

@@ -0,0 +1,114 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Asynchronous local terminal input handling
@author: Jp Calderone
"""
import os
import sys
import termios
import tty
from twisted.conch.insults.insults import ServerProtocol
from twisted.conch.manhole import ColoredManhole
from twisted.internet import defer, protocol, reactor, stdio
from twisted.python import failure, log, reflect
class UnexpectedOutputError(Exception):
pass
class TerminalProcessProtocol(protocol.ProcessProtocol):
def __init__(self, proto):
self.proto = proto
self.onConnection = defer.Deferred()
def connectionMade(self):
self.proto.makeConnection(self)
self.onConnection.callback(None)
self.onConnection = None
def write(self, data):
"""
Write to the terminal.
@param data: Data to write.
@type data: L{bytes}
"""
self.transport.write(data)
def outReceived(self, data):
"""
Receive data from the terminal.
@param data: Data received.
@type data: L{bytes}
"""
self.proto.dataReceived(data)
def errReceived(self, data):
"""
Report an error.
@param data: Data to include in L{Failure}.
@type data: L{bytes}
"""
self.transport.loseConnection()
if self.proto is not None:
self.proto.connectionLost(failure.Failure(UnexpectedOutputError(data)))
self.proto = None
def childConnectionLost(self, childFD):
if self.proto is not None:
self.proto.childConnectionLost(childFD)
def processEnded(self, reason):
if self.proto is not None:
self.proto.connectionLost(reason)
self.proto = None
class ConsoleManhole(ColoredManhole):
"""
A manhole protocol specifically for use with L{stdio.StandardIO}.
"""
def connectionLost(self, reason):
"""
When the connection is lost, there is nothing more to do. Stop the
reactor so that the process can exit.
"""
reactor.stop()
def runWithProtocol(klass):
fd = sys.__stdin__.fileno()
oldSettings = termios.tcgetattr(fd)
tty.setraw(fd)
try:
stdio.StandardIO(ServerProtocol(klass))
reactor.run()
finally:
termios.tcsetattr(fd, termios.TCSANOW, oldSettings)
os.write(fd, b"\r\x1bc\r")
def main(argv=None):
log.startLogging(open("child.log", "w"))
if argv is None:
argv = sys.argv[1:]
if argv:
klass = reflect.namedClass(argv[0])
else:
klass = ConsoleManhole
runWithProtocol(klass)
if __name__ == "__main__":
main()
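# --- Illustrative sketch, not part of the original module ------------------
# main() takes the fully-qualified name of a terminal protocol class as its
# only argument and falls back to ConsoleManhole when none is given.  The
# class name below is just an example of what reflect.namedClass() resolves.
def _exampleRun(protocolName="twisted.conch.manhole.ColoredManhole"):
    # Equivalent to invoking this file with protocolName on the command
    # line: the terminal is switched to raw mode and the named protocol is
    # driven over stdio until the reactor stops.
    main([protocolName])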

View File

@@ -0,0 +1,91 @@
# -*- test-case-name: twisted.conch.test.test_tap -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Support module for making SSH servers with twistd.
"""
from twisted.application import strports
from twisted.conch import checkers as conch_checkers, unix
from twisted.conch.openssh_compat import factory
from twisted.cred import portal, strcred
from twisted.python import usage
class Options(usage.Options, strcred.AuthOptionMixin):
synopsis = "[-i <interface>] [-p <port>] [-d <dir>] "
longdesc = (
"Makes a Conch SSH server. If no authentication methods are "
"specified, the default authentication methods are UNIX passwords "
"and SSH public keys. If --auth options are "
"passed, only the measures specified will be used."
)
optParameters = [
["interface", "i", "", "local interface to which we listen"],
["port", "p", "tcp:22", "Port on which to listen"],
["data", "d", "/etc", "directory to look for host keys in"],
[
"moduli",
"",
None,
"directory to look for moduli in " "(if different from --data)",
],
]
compData = usage.Completions(
optActions={
"data": usage.CompleteDirs(descr="data directory"),
"moduli": usage.CompleteDirs(descr="moduli directory"),
"interface": usage.CompleteNetInterfaces(),
}
)
def __init__(self, *a, **kw):
usage.Options.__init__(self, *a, **kw)
# Call the default addCheckers (for backwards compatibility) that will
# be used if no --auth option is provided - note that conch's
# UNIXPasswordDatabase is used, instead of twisted.plugins.cred_unix's
# checker
super().addChecker(conch_checkers.UNIXPasswordDatabase())
super().addChecker(
conch_checkers.SSHPublicKeyChecker(conch_checkers.UNIXAuthorizedKeysFiles())
)
self._usingDefaultAuth = True
def addChecker(self, checker):
"""
Add the checker specified. If any checkers are added, the default
checkers are automatically cleared and the only checkers will be the
specified one(s).
"""
if self._usingDefaultAuth:
self["credCheckers"] = []
self["credInterfaces"] = {}
self._usingDefaultAuth = False
super().addChecker(checker)
def makeService(config):
"""
Construct a service for operating an SSH server.
@param config: An L{Options} instance specifying server options, including
where server keys are stored and what authentication methods to use.
@return: A L{twisted.application.service.IService} provider which contains
the requested SSH server.
"""
t = factory.OpenSSHFactory()
r = unix.UnixSSHRealm()
t.portal = portal.Portal(r, config.get("credCheckers", []))
t.dataRoot = config["data"]
t.moduliRoot = config["moduli"] or config["data"]
port = config["port"]
if config["interface"]:
# Add warning here
port += ":interface=" + config["interface"]
return strports.service(port, t)
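# --- Illustrative sketch, not part of the original module ------------------
# How twistd drives this module, condensed: parse options, then build the
# service.  The port and data directory are placeholders, and constructing
# Options() requires a UNIX-like system because it installs the UNIX
# checkers above.
def _exampleService():
    config = Options()
    config.parseOptions(["--port", "tcp:2222", "--data", "/etc/ssh"])
    return makeService(config)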

File diff suppressed because it is too large

View File

@@ -0,0 +1 @@
"conch tests"

View File

@@ -0,0 +1,671 @@
# -*- test-case-name: twisted.conch.test.test_keys -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# pylint: disable=I0011,C0103,W9401,W9402
"""
Data used by test_keys as well as others.
"""
from base64 import decodebytes
RSAData = {
"n": int(
"269413617238113438198661010376758399219880277968382122687862697"
"296942471209955603071120391975773283844560230371884389952067978"
"789684135947515341209478065209455427327369102356204259106807047"
"964139525310539133073743116175821417513079706301100600025815509"
"786721808719302671068052414466483676821987505720384645561708425"
"794379383191274856941628512616355437197560712892001107828247792"
"561858327085521991407807015047750218508971611590850575870321007"
"991909043252470730134547038841839367764074379439843108550888709"
"430958143271417044750314742880542002948053835745429446485015316"
"60749404403945254975473896534482849256068133525751"
),
"e": 65537,
"d": int(
"420335724286999695680502438485489819800002417295071059780489811"
"840828351636754206234982682752076205397047218449504537476523960"
"987613148307573487322720481066677105211155388802079519869249746"
"774085882219244493290663802569201213676433159425782937159766786"
"329742053214957933941260042101377175565683849732354700525628975"
"239000548651346620826136200952740446562751690924335365940810658"
"931238410612521441739702170503547025018016868116037053013935451"
"477930426013703886193016416453215950072147440344656137718959053"
"897268663969428680144841987624962928576808352739627262941675617"
"7724661940425316604626522633351193810751757014073"
),
"p": int(
"152689878451107675391723141129365667732639179427453246378763774"
"448531436802867910180261906924087589684175595016060014593521649"
"964959248408388984465569934780790357826811592229318702991401054"
"226302790395714901636384511513449977061729214247279176398290513"
"085108930550446985490864812445551198848562639933888780317"
),
"q": int(
"176444974592327996338888725079951900172097062203378367409936859"
"072670162290963119826394224277287608693818012745872307600855894"
"647300295516866118620024751601329775653542084052616260193174546"
"400544176890518564317596334518015173606460860373958663673307503"
"231977779632583864454001476729233959405710696795574874403"
),
"u": int(
"936018002388095842969518498561007090965136403384715613439364803"
"229386793506402222847415019772053080458257034241832795210460612"
"924445085372678524176842007912276654532773301546269997020970818"
"155956828553418266110329867222673040098885651348225673298948529"
"93885224775891490070400861134282266967852120152546563278"
),
}
DSAData = {
"g": int(
"10253261326864117157640690761723586967382334319435778695"
"29171533815411392477819921538350732400350395446211982054"
"96512489289702949127531056893725702005035043292195216541"
"11525058911428414042792836395195432445511200566318251789"
"10575695836669396181746841141924498545494149998282951407"
"18645344764026044855941864175"
),
"p": int(
"10292031726231756443208850082191198787792966516790381991"
"77502076899763751166291092085666022362525614129374702633"
"26262930887668422949051881895212412718444016917144560705"
"45675251775747156453237145919794089496168502517202869160"
"78674893099371444940800865897607102159386345313384716752"
"18590012064772045092956919481"
),
"q": 1393384845225358996250882900535419012502712821577,
"x": 1220877188542930584999385210465204342686893855021,
"y": int(
"14604423062661947579790240720337570315008549983452208015"
"39426429789435409684914513123700756086453120500041882809"
"10283610277194188071619191739512379408443695946763554493"
"86398594314468629823767964702559709430618263927529765769"
"10270265745700231533660131769648708944711006508965764877"
"684264272082256183140297951"
),
}
ECDatanistp256 = {
"x": int(
"762825130203920963171185031449647317742997734817505505433829043"
"45687059013883"
),
"y": int(
"815431978646028526322656647694416475343443758943143196810611371"
"59310646683104"
),
"privateValue": int(
"3463874347721034170096400845565569825355565567882605"
"9678074967909361042656500"
),
"curve": b"ecdsa-sha2-nistp256",
}
SKECDatanistp256 = {
"x": int(
"239399367768747020111880335553299826848360860410053166887934464"
"83115637049597"
),
"y": int(
"114119006635761413192818806701564910719235784173643448381780025"
"223832906554748"
),
"curve": b"sk-ecdsa-sha2-nistp256@openssh.com",
}
ECDatanistp384 = {
"privateValue": int(
"280814107134858470598753916394807521398239633534281"
"633982576099083357871098966021020900021966162732114"
"95718603965098"
),
"x": int(
"10036914308591746758780165503819213553101287571902957054148542"
"504671046744460374996612408381962208627004841444205030"
),
"y": int(
"17337335659928075994560513699823544906448896792102247714689323"
"575406618073069185107088229463828921069465902299522926"
),
"curve": b"ecdsa-sha2-nistp384",
}
ECDatanistp521 = {
"x": int(
"12944742826257420846659527752683763193401384271391513286022917"
"29910013082920512632908350502247952686156279140016049549948975"
"670668730618745449113644014505462"
),
"y": int(
"10784108810271976186737587749436295782985563640368689081052886"
"16296815984553198866894145509329328086635278430266482551941240"
"591605833440825557820439734509311"
),
"privateValue": int(
"662751235215460886290293902658128847495347691199214"
"706697089140769672273950767961331442265530524063943"
"548846724348048614239791498442599782310681891569896"
"0565"
),
"curve": b"ecdsa-sha2-nistp521",
}
Ed25519Data = {
"a": (
b"\xf1\x16\xd1\x15J\x1e\x15\x0e\x19^\x19F\xb5\xf2D\r\xb2R\xa0\xae*k"
b"#\x13sE\xfd@\xd9W{\x8b"
),
"k": (
b"7/%\xda\x8d\xd4\xa8\x9ax|a\xf0\x98\x01\xc6\xf4^mg\x05i17Li\r\x05U"
b"\xbb\xc9DX"
),
}
SKEd25519Data = {
"a": (
b"\x08}'U\xd2i\x04\x11\xea\x01~+\x165iRM\xdd\xe6R\x7f\xd3\xaf\\\xa8p"
b"\xa0LL\xe5\x8a\xa0"
),
"k": (
b"7/%\xda\x8d\xd4\xa8\x9ax|a\xf0\x98\x01\xc6\xf4^mg\x05i17Li\r\x05U"
b"\xbb\xc9DX"
),
}
privateECDSA_openssh521 = b"""-----BEGIN EC PRIVATE KEY-----
MIHcAgEBBEIAjn0lSVF6QweS4bjOGP9RHwqxUiTastSE0MVuLtFvkxygZqQ712oZ
ewMvqKkxthMQgxzSpGtRBcmkL7RqZ94+18qgBwYFK4EEACOhgYkDgYYABAFpX/6B
mxxglwD+VpEvw0hcyxVzLxNnMGzxZGF7xmNj8nlF7M+TQctdlR2Xv/J+AgIeVGmB
j2p84bkV9jBzrUNJEACsJjttZw8NbUrhxjkLT/3rMNtuwjE4vLja0P7DMTE0EV8X
f09ETdku/z/1tOSSrSvRwmUcM9nQUJtHHAZlr5Q0fw==
-----END EC PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateECDSA_openssh521_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAArAAAABNlY2RzYS
1zaGEyLW5pc3RwNTIxAAAACG5pc3RwNTIxAAAAhQQBaV/+gZscYJcA/laRL8NIXMsVcy8T
ZzBs8WRhe8ZjY/J5RezPk0HLXZUdl7/yfgICHlRpgY9qfOG5FfYwc61DSRAArCY7bWcPDW
1K4cY5C0/96zDbbsIxOLy42tD+wzExNBFfF39PRE3ZLv8/9bTkkq0r0cJlHDPZ0FCbRxwG
Za+UNH8AAAEAeRISlnkSEpYAAAATZWNkc2Etc2hhMi1uaXN0cDUyMQAAAAhuaXN0cDUyMQ
AAAIUEAWlf/oGbHGCXAP5WkS/DSFzLFXMvE2cwbPFkYXvGY2PyeUXsz5NBy12VHZe/8n4C
Ah5UaYGPanzhuRX2MHOtQ0kQAKwmO21nDw1tSuHGOQtP/esw227CMTi8uNrQ/sMxMTQRXx
d/T0RN2S7/P/W05JKtK9HCZRwz2dBQm0ccBmWvlDR/AAAAQgCOfSVJUXpDB5LhuM4Y/1Ef
CrFSJNqy1ITQxW4u0W+THKBmpDvXahl7Ay+oqTG2ExCDHNKka1EFyaQvtGpn3j7XygAAAA
ABAg==
-----END OPENSSH PRIVATE KEY-----
"""
publicECDSA_openssh521 = (
b"ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACF"
b"BAFpX/6BmxxglwD+VpEvw0hcyxVzLxNnMGzxZGF7xmNj8nlF7M+TQctdlR2Xv/J+AgIeVGmB"
b"j2p84bkV9jBzrUNJEACsJjttZw8NbUrhxjkLT/3rMNtuwjE4vLja0P7DMTE0EV8Xf09ETdku"
b"/z/1tOSSrSvRwmUcM9nQUJtHHAZlr5Q0fw== comment"
)
privateECDSA_openssh384 = b"""-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDAtAi7I8j73WCX20qUM5hhHwHuFzYWYYILs2Sh8UZ+awNkARZ/Fu2LU
LLl5RtOQpbWgBwYFK4EEACKhZANiAATU17sA9P5FRwSknKcFsjjsk0+E3CeXPYX0
Tk/M0HK3PpWQWgrO8JdRHP9eFE9O/23P8BumwFt7F/AvPlCzVd35VfraFT0o4cCW
G0RqpQ+np31aKmeJshkcYALEchnU+tQ=
-----END EC PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateECDSA_openssh384_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAiAAAABNlY2RzYS
1zaGEyLW5pc3RwMzg0AAAACG5pc3RwMzg0AAAAYQTU17sA9P5FRwSknKcFsjjsk0+E3CeX
PYX0Tk/M0HK3PpWQWgrO8JdRHP9eFE9O/23P8BumwFt7F/AvPlCzVd35VfraFT0o4cCWG0
RqpQ+np31aKmeJshkcYALEchnU+tQAAADIiktpWIpLaVgAAAATZWNkc2Etc2hhMi1uaXN0
cDM4NAAAAAhuaXN0cDM4NAAAAGEE1Ne7APT+RUcEpJynBbI47JNPhNwnlz2F9E5PzNBytz
6VkFoKzvCXURz/XhRPTv9tz/AbpsBbexfwLz5Qs1Xd+VX62hU9KOHAlhtEaqUPp6d9Wipn
ibIZHGACxHIZ1PrUAAAAMC0CLsjyPvdYJfbSpQzmGEfAe4XNhZhgguzZKHxRn5rA2QBFn8
W7YtQsuXlG05CltQAAAAA=
-----END OPENSSH PRIVATE KEY-----
"""
publicECDSA_openssh384 = (
b"ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABh"
b"BNTXuwD0/kVHBKScpwWyOOyTT4TcJ5c9hfROT8zQcrc+lZBaCs7wl1Ec/14UT07/bc/wG6bA"
b"W3sX8C8+ULNV3flV+toVPSjhwJYbRGqlD6enfVoqZ4myGRxgAsRyGdT61A== comment"
)
publicECDSA_openssh = (
b"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABB"
b"BKimX1DZ7+Qj0SpfePMbo1pb6yGkAb5l7duC1l855yD7tEfQfqk7bc7v46We1hLMyz6ObUBY"
b"gkN/34n42F4vpeA= comment"
)
publicSKECDSA_openssh = (
b"sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3"
b"BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBDTthidmBSzlQiO8aZPfLmUDOS2TSRevW8IrHPK"
b"IhYj9/E0RnTyvPIB1eWQx4rQl5iO1mihuBz+u4LkjwVEU3XwAAAAUc3NoOmVjZHNhLWZpZG8y"
b"LXRlc3Q= comment"
)
publicSKEd25519_openssh = (
b"sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIA"
b"h9J1XSaQQR6gF+KxY1aVJN3eZSf9OvXKhwoExM5YqgAAAABHNzaDo= comment"
)
publicSKECDSA_cert_openssh = (
b"sk-ecdsa-sha2-nistp256-cert-v01@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3"
b"BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBDTthidmBSzlQiO8aZPfLmUDOS2TSRevW8IrHPK"
b"IhYj9/E0RnTyvPIB1eWQx4rQl5iO1mihuBz+u4LkjwVEU3XwAAAAUc3NoOmVjZHNhLWZpZG8y"
b"LXRlc3Q= comment"
)
publicSKEd25519_cert_openssh = (
b"sk-ssh-ed25519-cert-v01@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIA"
b"h9J1XSaQQR6gF+KxY1aVJN3eZSf9OvXKhwoExM5YqgAAAABHNzaDo= comment"
)
privateECDSA_openssh = b"""-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEyU1YOT2JxxofwbJXIjGftdNcJK55aQdNrhIt2xYQz0oAoGCCqGSM49
AwEHoUQDQgAEqKZfUNnv5CPRKl948xujWlvrIaQBvmXt24LWXznnIPu0R9B+qTtt
zu/jpZ7WEszLPo5tQFiCQ3/fifjYXi+l4A==
-----END EC PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateECDSA_openssh_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAaAAAABNlY2RzYS
1zaGEyLW5pc3RwMjU2AAAACG5pc3RwMjU2AAAAQQSopl9Q2e/kI9EqX3jzG6NaW+shpAG+
Ze3bgtZfOecg+7RH0H6pO23O7+OlntYSzMs+jm1AWIJDf9+J+NheL6XgAAAAmCKU4hcilO
IXAAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBKimX1DZ7+Qj0Spf
ePMbo1pb6yGkAb5l7duC1l855yD7tEfQfqk7bc7v46We1hLMyz6ObUBYgkN/34n42F4vpe
AAAAAgTJTVg5PYnHGh/BslciMZ+101wkrnlpB02uEi3bFhDPQAAAAA
-----END OPENSSH PRIVATE KEY-----
"""
publicEd25519_openssh = (
b"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPEW0RVKHhUOGV4ZRrXyRA2yUqCuKmsjE3NF"
b"/UDZV3uL comment"
)
# OpenSSH has only ever supported the "new" (v1) format for Ed25519.
privateEd25519_openssh_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
QyNTUxOQAAACDxFtEVSh4VDhleGUa18kQNslKgriprIxNzRf1A2Vd7iwAAAJA61eMLOtXj
CwAAAAtzc2gtZWQyNTUxOQAAACDxFtEVSh4VDhleGUa18kQNslKgriprIxNzRf1A2Vd7iw
AAAEA3LyXajdSomnh8YfCYAcb0Xm1nBWkxN0xpDQVVu8lEWPEW0RVKHhUOGV4ZRrXyRA2y
UqCuKmsjE3NF/UDZV3uLAAAAB2NvbW1lbnQBAgMEBQY=
-----END OPENSSH PRIVATE KEY-----"""
publicRSA_openssh = (
b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDVaqx4I9bWG+wloVDEd2NQhEUBVUIUKirg"
b"0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHjVyqgYwBGTJAkMgUyP"
b"95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auIJNm/9NNN9b0b/h9qp"
b"KSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY6RKXCpCnd1bqcPUWz"
b"xiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvbw7KW4CC1ffdOgTtDc1"
b"foNfICZgptyti8ZseZj3 comment"
)
privateRSA_openssh = b"""-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEA1WqseCPW1hvsJaFQxHdjUIRFAVVCFCoq4NBg7tTpo61K+jkG
XoRVdV8ANr9vqio/gyY3wWkuW/3w89J91pjNOkB41cqoGMARkyQJDIFMj/ec7RMW
aqQE6Ul3w+RVZLN5aJ4sCOus6AQtIXcFp47vUzANpeW7PWriCTZv/TTTfW9G/4fa
qSknqv+t9YXmPhq4eh1KserAWvcw3x/CpOTvP5FJlkDVGXctN8Ne7J2mOkSlwqQp
3dW6nD1Fs8YsGGTVuj3fq3/NQqyn8RgLoFgVYgukKm5Dw+QEnzWjR45G7TOlZb28
OyluAgtX33ToE7Q3NX6DXyAmYKbcrYvGbHmY9wIDAQABAoIBACFMCGaiKNW0+44P
chuFCQC58k438BxXS+NRf54jp+Q6mFUb6ot6mB682Lqx+YkSGGCs6MwLTglaQGq6
L5n4syRghLnOaZWa+eL8H1FNJxXbKyet77RprL59EOuGR3BztACHlRU7N/nnFOeA
u2geG+bdu3NjuWfmsid/z88wm8KY/dkYNi82LvE9gXqf4QMtR9s0UWI53U/prKiL
2dbzhMQXuXGdBghCeE27xSr0w1jNVSvtvjNfBOp75gQkY/It1z0bbNWcY0MvkoiN
Pm7aGDfYDyVniR25RjReyc7Ei+2SWjMHD9+GCPmS6dvrOAg2yc3NCgFIWzk+esrG
gKnc1DkCgYEA2XAG2OK81HiRUJTUwRuJOGxGZFpRoJoHPUiPA1HMaxKOfRqxZedx
dTngMgV1jRhMr5OxSbFmX3hietEMyuZNQ7Oc9Gt95gyY3M8hYo7VLhLeBK7XJG6D
MaIVokQ9IqliJiK5su1UCp0Ig6cHDf8ZGI7Yqx3aSJwxaBGhZm3j2B0CgYEA+0QX
i6Q2vh43Haf2YWwExKrdeD4HjB4zAq4DFIeDeuWefQhnqPKqvxJwz3Kpp8cLHYjV
IP2cY8pHMFVOi8TP9H8WpJISdKEJwsRunIwz76Xl9+ArrU9cEaoahDdb/Xrqw818
sMjkH1Rjtcev3/QJp/zHJfxc6ZHXksWYHlbTsSMCgYBRr+mSn5QLSoRlPpSzO5IQ
tXS4jMnvyQ4BMvovaBKhAyauz1FoFEwmmyikAjMIX+GncJgBNHleUo7Ezza8H0tV
rOvBU4TH4WGoStSi/0ANgB8SqVDAKhh1lAwGmxZQqEvsQc177/dLyXUCaMSYuIaI
GFpD5wIzlyJkk4MMRSp87QKBgGlmN8ZA3SHFBPOwuD5HlHx2/C3rPzk8lcNDAVHE
Qpfz6Bakxu7s1EkQUDgE7jvN19DMzDJpkAegG1qf/jHNHjp+cR4ZlBpOTwzfX1LV
0Rdu7NectlWd244hX7wkiLb8r6vw76QssNyfhrADEriL4t0PwO4jPUpQ/i+4KUZY
v7YnAoGAZhb5IDTQVCW8YTGsgvvvnDUefkpVAmiVDQqTvh6/4UD6kKdUcDHpePzg
Zrcid5rr3dXSMEbK4tdeQZvPtUg1Uaol3N7bNClIIdvWdPx+5S9T95wJcLnkoHam
rXp0IjScTxfLP+Cq5V6lJ94/pX8Ppoj1FdZfNxeS4NYFSRA7kvY=
-----END RSA PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateRSA_openssh_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAQEA1WqseCPW1hvsJaFQxHdjUIRFAVVCFCoq4NBg7tTpo61K+jkGXoRV
dV8ANr9vqio/gyY3wWkuW/3w89J91pjNOkB41cqoGMARkyQJDIFMj/ec7RMWaqQE6Ul3w+
RVZLN5aJ4sCOus6AQtIXcFp47vUzANpeW7PWriCTZv/TTTfW9G/4faqSknqv+t9YXmPhq4
eh1KserAWvcw3x/CpOTvP5FJlkDVGXctN8Ne7J2mOkSlwqQp3dW6nD1Fs8YsGGTVuj3fq3
/NQqyn8RgLoFgVYgukKm5Dw+QEnzWjR45G7TOlZb28OyluAgtX33ToE7Q3NX6DXyAmYKbc
rYvGbHmY9wAAA7gXkBoMF5AaDAAAAAdzc2gtcnNhAAABAQDVaqx4I9bWG+wloVDEd2NQhE
UBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHjVyqgY
wBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auIJNm
/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY6
RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvb
w7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3AAAAAwEAAQAAAQAhTAhmoijVtPuOD3Ib
hQkAufJON/AcV0vjUX+eI6fkOphVG+qLepgevNi6sfmJEhhgrOjMC04JWkBqui+Z+LMkYI
S5zmmVmvni/B9RTScV2ysnre+0aay+fRDrhkdwc7QAh5UVOzf55xTngLtoHhvm3btzY7ln
5rInf8/PMJvCmP3ZGDYvNi7xPYF6n+EDLUfbNFFiOd1P6ayoi9nW84TEF7lxnQYIQnhNu8
Uq9MNYzVUr7b4zXwTqe+YEJGPyLdc9G2zVnGNDL5KIjT5u2hg32A8lZ4kduUY0XsnOxIvt
klozBw/fhgj5kunb6zgINsnNzQoBSFs5PnrKxoCp3NQ5AAAAgQCFSxt6mxIQN54frV7a/s
aW/t81a7k04haXkiYJvb1wIAOnNb0tG6DSB0cr1N6oqAcHG7gEIKcnQTxsOTnpQc7nFx3R
TFy8PdImJv5q1v1Icq5G+nvD0xlgRB2lE6eA9WMp1HpdBgcWXfaLPctkOuKEWk2MBi0tnR
zrg0x4PXlUzgAAAIEA2XAG2OK81HiRUJTUwRuJOGxGZFpRoJoHPUiPA1HMaxKOfRqxZedx
dTngMgV1jRhMr5OxSbFmX3hietEMyuZNQ7Oc9Gt95gyY3M8hYo7VLhLeBK7XJG6DMaIVok
Q9IqliJiK5su1UCp0Ig6cHDf8ZGI7Yqx3aSJwxaBGhZm3j2B0AAACBAPtEF4ukNr4eNx2n
9mFsBMSq3Xg+B4weMwKuAxSHg3rlnn0IZ6jyqr8ScM9yqafHCx2I1SD9nGPKRzBVTovEz/
R/FqSSEnShCcLEbpyMM++l5ffgK61PXBGqGoQ3W/166sPNfLDI5B9UY7XHr9/0Caf8xyX8
XOmR15LFmB5W07EjAAAAAAEC
-----END OPENSSH PRIVATE KEY-----
"""
# Encrypted with the passphrase 'encrypted'
privateRSA_openssh_encrypted = b"""-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-EDE3-CBC,FFFFFFFFFFFFFFFF
p2A1YsHLXkpMVcsEqhh/nCYb5AqL0uMzfEIqc8hpZ/Ub8PtLsypilMkqzYTnZIGS
ouyPjU/WgtR4VaDnutPWdgYaKdixSEmGhKghCtXFySZqCTJ4O8NCczsktYjUK3D4
Jtl90zL6O81WBY6xP76PBQo9lrI/heAetATeyqutc18bwQIGU+gKk32qvfo15DfS
VYiY0Ds4D7F7fd9pz+f5+UbFUCgU+tfDvBrqodYrUgmH7jKoW/CRDCHHyeEIZDbF
mcMwdcKOyw1sRLaPdihRSVx3kOMvIotHKVTkIDMp+0RTNeXzQnp5U2qzsxzTcG/M
UyJN38XXkuvq5VMj2zmmjHzx34w3NK3ZxpZcoaFUqUBlNp2C8hkCLrAa/DWobKqN
5xA1ElrQvli9XXkT/RIuy4Gc10bbGEoJjuxNRibtSxxWd5Bd1E40ocOd4l1ebI8+
w69XvMTnsmHvkBEADGF2zfRszKnMelg+W5NER1UDuNT03i+1cuhp+2AZg8z7niTO
M17XP3ScGVxrQAEYgtxPrPeIpFJvOx2j5Yt78U9Y2WlaAG6DrubbYv2RsMIibhOG
yk139vMdD8FwCey6yMkkhFAJwnBtC22MAWgjmC5c6AF3SRQSjjQXepPsJcLgpOjy
YwjhnL8w56x9kVDUNPw9A9Cqgxo2sty34ATnKrh4h59PsP83LOL6OC5WjbASgZRd
OIBD8RloQPISo+RUF7X0i4kdaHVNPlR0KyapR+3M5BwhQuvEO99IArDV2LNKGzfc
W4ssugm8iyAJlmwmb2yRXIDHXabInWY7XCdGk8J2qPFbDTvnPbiagJBimjVjgpWw
tV3sVlJYqmOqmCDP78J6he04l0vaHtiOWTDEmNCrK7oFMXIIp3XWjOZGPSOJFdPs
6Go3YB+EGWfOQxqkFM28gcqmYfVPF2sa1FbZLz0ffO11Ma/rliZxZu7WdrAXe/tc
BgIQ8etp2PwAK4jCwwVwjIO8FzqQGpS23Y9NY3rfi97ckgYXKESFtXPsMMA+drZd
ThbXvccfh4EPmaqQXKf4WghHiVJ+/yuY1kUIDEl/O0jRZWT7STgBim/Aha1m6qRs
zl1H7hkDbU4solb1GM5oPzbgGTzyBc+z0XxM9iFRM+fMzPB8+yYHTr4kPbVmKBjy
SCovjQQVsHE4YeUGTq6k/NF5cVIRKTW/RlHvzxsky1Zj31MC736jrxGw4KG7VSLZ
fP6F5jj+mXwS7m0v5to42JBZmRJdKUD88QaGE3ncyQ4yleW5bn9Lf9SuzQg1Dhao
3rSA1RuexsHlIAHvGxx/17X+pyygl8DJbt6TBfbLQk9wc707DJTfh5M/bnk9wwIX
l/Hsa1WtylAMW/2MzgiVy83MbYz4+Ss6GQ5W66okWji+NxrnrYEy6q+WgVQanp7X
D+D7oKykqE1Cdvvulvtfl5fh8wlAs8mrUnKPBBUru348u++2lfacLkxRXyT1ooqY
uSNE5nlwFt08N2Ou/bl7yq6QNRMYrRkn+UEfHWCNYDoGMHln2/i6Z1RapQzNarik
tJf7radBz5nBwBjP08YAEACNSQvpsUgdqiuYjLwX7efFXQva2RzqaQ==
-----END RSA PRIVATE KEY-----"""
# Encrypted with the passphrase 'encrypted', and using the new format
# introduced in OpenSSH 6.5
privateRSA_openssh_encrypted_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABD0f9WAof
DTbmwztb8pdrSeAAAAEAAAAAEAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQDVaqx4I9bW
G+wloVDEd2NQhEUBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n
3WmM06QHjVyqgYwBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9T
MA2l5bs9auIJNm/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQN
UZdy03w17snaY6RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASf
NaNHjkbtM6Vlvbw7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3AAADwPQaac8s1xX3af
hQTQexj0vEAWDQsLYzDHN9G7W+UP5WHUu7igeu2GqAC/TOnjUXDP73I+EN3n7T3JFeDRfs
U1Z6Zqb0NKHSRVYwDIdIi8qVohFv85g6+xQ01OpaoOzz+vI34OUvCRHQGTgR6L9fQShZyC
McopYMYfbIse6KcqkfxX3KSdG1Pao6Njx/ShFRbgvmALpR/z0EaGCzHCDxpfUyAdnxm621
Jzaf+LverWdN7sfrfMptaS9//9iJb70sL67K+YIB64qhDnA/w9UOQfXGQFL+AEtdM0BPv8
thP1bs7T0yucBl+ZXdrDKVLZfaS3S/w85Jlgfu+a1DG73pOBOuag435iEJ9EnspjXiiydx
GrfSRk2C+/c4fBDZVGFscK5bfQuUUZyU1qOagekxX7WLHFKk9xajnud+nrAN070SeNwlX8
FZ2CI4KGlQfDvVUpKanYn8Kkj3fZ+YBGyx4M+19clF65FKSM0x1Rrh5tAmNT/SNDbSc28m
ASxrBhztzxUFTrIn3tp+uqkJniFLmFsUtiAUmj8fNyE9blykU7dqq+CqpLA872nQ9bOHHA
JsS1oBYmQ0n6AJz8WrYMdcepqWVld6Q8QSD1zdrY/sAWUovuBA1s4oIEXZhpXSS4ZJiMfh
PVktKBwj5bmoG/mmwYLbo0JHntK8N3TGTzTGLq5TpSBBdVvWSWo7tnfEkrFObmhi1uJSrQ
3zfPVP6BguboxBv+oxhaUBK8UOANe6ZwM4vfiu+QN+sZqWymHIfAktz7eWzwlToe4cKpdG
Uv+e3/7Lo2dyMl3nke5HsSUrlsMGPREuGkBih8+o85ii6D+cuCiVtus3f5c78Cir80zLIr
Z0wWvEAjciEvml00DWaA+JIaOrWwvXySaOzFGpCqC9SQjao379bvn9P3b7kVZsy6zBfHqm
bNEJUOuhBZaY8Okz36chh1xqh4sz7m3nsZ3GYGcvM+3mvRY72QnqsQEG0Sp1XYIn2bHa29
tqp7CG9X8J6dqMcPeoPRDWIX9gw7EPl/M0LP6xgewGJ9bgxwle6Mnr9kNITIswjAJqrLec
zx7dfixjAPc42ADqrw/tEdFQcSqxigcfJNKO1LbDBjh+Hk/cSBou2PoxbIcl0qfQfbGcqI
Dbpd695IEuiW9pYR22txNoIi+7cbMsuFHxQ/OqbrX/jCsprGNNJLAjgGsVEI1JnHWDH0db
3UbqbOHAeY3ufoYXNY1utVOIACpW3r9wBw3FjRi04d70VcKr16OXvOAHGN2G++Y+kMya84
Hl/Kt/gA==
-----END OPENSSH PRIVATE KEY-----
"""
# Encrypted with the passphrase 'testxp'. NB: this key was generated by
# OpenSSH, so it doesn't use the same key data as the other keys here.
privateRSA_openssh_encrypted_aes = b"""-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,0673309A6ACCAB4B77DEE1C1E536AC26
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
-----END RSA PRIVATE KEY-----"""
publicRSA_lsh = (
b"{KDEwOnB1YmxpYy1rZXkoMTQ6cnNhLXBrY3MxLXNoYTEoMTpuMjU3OgDVaqx4I9bWG+wloVD"
b"Ed2NQhEUBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHj"
b"VyqgYwBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auI"
b"JNm/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY"
b"6RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvbw"
b"7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3KSgxOmUzOgEAASkpKQ==}"
)
privateRSA_lsh = (
b"(11:private-key(9:rsa-pkcs1(1:n257:\x00\xd5j\xacx#\xd6\xd6\x1b\xec%\xa1P"
b"\xc4wcP\x84E\x01UB\x14**\xe0\xd0`\xee\xd4\xe9\xa3\xadJ\xfa9\x06^\x84Uu_"
b"\x006\xbfo\xaa*?\x83&7\xc1i.[\xfd\xf0\xf3\xd2}\xd6\x98\xcd:@x\xd5\xca"
b"\xa8\x18\xc0\x11\x93$\t\x0c\x81L\x8f\xf7\x9c\xed\x13\x16j\xa4\x04\xe9Iw"
b"\xc3\xe4Ud\xb3yh\x9e,\x08\xeb\xac\xe8\x04-!w\x05\xa7\x8e\xefS0\r\xa5\xe5"
b"\xbb=j\xe2\t6o\xfd4\xd3}oF\xff\x87\xda\xa9)'\xaa\xff\xad\xf5\x85\xe6>"
b"\x1a\xb8z\x1dJ\xb1\xea\xc0Z\xf70\xdf\x1f\xc2\xa4\xe4\xef?\x91I\x96@\xd5"
b"\x19w-7\xc3^\xec\x9d\xa6:D\xa5\xc2\xa4)\xdd\xd5\xba\x9c=E\xb3\xc6,\x18d"
b"\xd5\xba=\xdf\xab\x7f\xcdB\xac\xa7\xf1\x18\x0b\xa0X\x15b\x0b\xa4*nC\xc3"
b"\xe4\x04\x9f5\xa3G\x8eF\xed3\xa5e\xbd\xbc;)n\x02\x0bW\xdft\xe8\x13\xb475"
b"~\x83_ &`\xa6\xdc\xad\x8b\xc6ly\x98\xf7)(1:e3:\x01\x00\x01)(1:d256:!L"
b"\x08f\xa2(\xd5\xb4\xfb\x8e\x0fr\x1b\x85\t\x00\xb9\xf2N7\xf0\x1cWK\xe3Q"
b"\x7f\x9e#\xa7\xe4:\x98U\x1b\xea\x8bz\x98\x1e\xbc\xd8\xba\xb1\xf9\x89\x12"
b"\x18`\xac\xe8\xcc\x0bN\tZ@j\xba/\x99\xf8\xb3$`\x84\xb9\xcei\x95\x9a\xf9"
b"\xe2\xfc\x1fQM'\x15\xdb+'\xad\xef\xb4i\xac\xbe}\x10\xeb\x86Gps\xb4\x00"
b"\x87\x95\x15;7\xf9\xe7\x14\xe7\x80\xbbh\x1e\x1b\xe6\xdd\xbbsc\xb9g\xe6"
b"\xb2'\x7f\xcf\xcf0\x9b\xc2\x98\xfd\xd9\x186/6.\xf1=\x81z\x9f\xe1\x03-G"
b"\xdb4Qb9\xddO\xe9\xac\xa8\x8b\xd9\xd6\xf3\x84\xc4\x17\xb9q\x9d\x06\x08Bx"
b"M\xbb\xc5*\xf4\xc3X\xcdU+\xed\xbe3_\x04\xea{\xe6\x04$c\xf2-\xd7=\x1bl"
b"\xd5\x9ccC/\x92\x88\x8d>n\xda\x187\xd8\x0f%g\x89\x1d\xb9F4^\xc9\xce\xc4"
b"\x8b\xed\x92Z3\x07\x0f\xdf\x86\x08\xf9\x92\xe9\xdb\xeb8\x086\xc9\xcd\xcd"
b"\n\x01H[9>z\xca\xc6\x80\xa9\xdc\xd49)(1:p129:\x00\xfbD\x17\x8b\xa46\xbe"
b"\x1e7\x1d\xa7\xf6al\x04\xc4\xaa\xddx>\x07\x8c\x1e3\x02\xae\x03\x14\x87"
b"\x83z\xe5\x9e}\x08g\xa8\xf2\xaa\xbf\x12p\xcfr\xa9\xa7\xc7\x0b\x1d\x88"
b"\xd5 \xfd\x9cc\xcaG0UN\x8b\xc4\xcf\xf4\x7f\x16\xa4\x92\x12t\xa1\t\xc2"
b"\xc4n\x9c\x8c3\xef\xa5\xe5\xf7\xe0+\xadO\\\x11\xaa\x1a\x847[\xfdz\xea"
b"\xc3\xcd|\xb0\xc8\xe4\x1fTc\xb5\xc7\xaf\xdf\xf4\t\xa7\xfc\xc7%\xfc\\\xe9"
b"\x91\xd7\x92\xc5\x98\x1eV\xd3\xb1#)(1:q129:\x00\xd9p\x06\xd8\xe2\xbc\xd4"
b"x\x91P\x94\xd4\xc1\x1b\x898lFdZQ\xa0\x9a\x07=H\x8f\x03Q\xcck\x12\x8e}"
b"\x1a\xb1e\xe7qu9\xe02\x05u\x8d\x18L\xaf\x93\xb1I\xb1f_xbz\xd1\x0c\xca"
b"\xe6MC\xb3\x9c\xf4k}\xe6\x0c\x98\xdc\xcf!b\x8e\xd5.\x12\xde\x04\xae\xd7$"
b'n\x831\xa2\x15\xa2D="\xa9b&"\xb9\xb2\xedT\n\x9d\x08\x83\xa7\x07\r\xff'
b"\x19\x18\x8e\xd8\xab\x1d\xdaH\x9c1h\x11\xa1fm\xe3\xd8\x1d)(1:a128:if7"
b"\xc6@\xdd!\xc5\x04\xf3\xb0\xb8>G\x94|v\xfc-\xeb?9<\x95\xc3C\x01Q\xc4B"
b"\x97\xf3\xe8\x16\xa4\xc6\xee\xec\xd4I\x10P8\x04\xee;\xcd\xd7\xd0\xcc\xcc"
b"2i\x90\x07\xa0\x1bZ\x9f\xfe1\xcd\x1e:~q\x1e\x19\x94\x1aNO\x0c\xdf_R\xd5"
b"\xd1\x17n\xec\xd7\x9c\xb6U\x9d\xdb\x8e!_\xbc$\x88\xb6\xfc\xaf\xab\xf0"
b"\xef\xa4,\xb0\xdc\x9f\x86\xb0\x03\x12\xb8\x8b\xe2\xdd\x0f\xc0\xee#=JP"
b"\xfe/\xb8)FX\xbf\xb6')(1:b128:Q\xaf\xe9\x92\x9f\x94\x0bJ\x84e>\x94\xb3;"
b"\x92\x10\xb5t\xb8\x8c\xc9\xef\xc9\x0e\x012\xfa/h\x12\xa1\x03&\xae\xcfQh"
b"\x14L&\x9b(\xa4\x023\x08_\xe1\xa7p\x98\x014y^R\x8e\xc4\xcf6\xbc\x1fKU"
b"\xac\xeb\xc1S\x84\xc7\xe1a\xa8J\xd4\xa2\xff@\r\x80\x1f\x12\xa9P\xc0*\x18"
b"u\x94\x0c\x06\x9b\x16P\xa8K\xecA\xcd{\xef\xf7K\xc9u\x02h\xc4\x98\xb8\x86"
b'\x88\x18ZC\xe7\x023\x97"d\x93\x83\x0cE*|\xed)(1:c128:f\x16\xf9 4\xd0T%'
b"\xbca1\xac\x82\xfb\xef\x9c5\x1e~JU\x02h\x95\r\n\x93\xbe\x1e\xbf\xe1@\xfa"
b'\x90\xa7Tp1\xe9x\xfc\xe0f\xb7"w\x9a\xeb\xdd\xd5\xd20F\xca\xe2\xd7^A\x9b'
b"\xcf\xb5H5Q\xaa%\xdc\xde\xdb4)H!\xdb\xd6t\xfc~\xe5/S\xf7\x9c\tp\xb9\xe4"
b"\xa0v\xa6\xadzt\"4\x9cO\x17\xcb?\xe0\xaa\xe5^\xa5'\xde?\xa5\x7f\x0f\xa6"
b"\x88\xf5\x15\xd6_7\x17\x92\xe0\xd6\x05I\x10;\x92\xf6)))"
)
privateRSA_agentv3 = (
b"\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x03\x01\x00\x01\x00\x00\x01\x00!L"
b"\x08f\xa2(\xd5\xb4\xfb\x8e\x0fr\x1b\x85\t\x00\xb9\xf2N7\xf0\x1cWK\xe3Q"
b"\x7f\x9e#\xa7\xe4:\x98U\x1b\xea\x8bz\x98\x1e\xbc\xd8\xba\xb1\xf9\x89\x12"
b"\x18`\xac\xe8\xcc\x0bN\tZ@j\xba/\x99\xf8\xb3$`\x84\xb9\xcei\x95\x9a\xf9"
b"\xe2\xfc\x1fQM'\x15\xdb+'\xad\xef\xb4i\xac\xbe}\x10\xeb\x86Gps\xb4\x00"
b"\x87\x95\x15;7\xf9\xe7\x14\xe7\x80\xbbh\x1e\x1b\xe6\xdd\xbbsc\xb9g\xe6"
b"\xb2'\x7f\xcf\xcf0\x9b\xc2\x98\xfd\xd9\x186/6.\xf1=\x81z\x9f\xe1\x03-G"
b"\xdb4Qb9\xddO\xe9\xac\xa8\x8b\xd9\xd6\xf3\x84\xc4\x17\xb9q\x9d\x06\x08Bx"
b"M\xbb\xc5*\xf4\xc3X\xcdU+\xed\xbe3_\x04\xea{\xe6\x04$c\xf2-\xd7=\x1bl"
b"\xd5\x9ccC/\x92\x88\x8d>n\xda\x187\xd8\x0f%g\x89\x1d\xb9F4^\xc9\xce\xc4"
b"\x8b\xed\x92Z3\x07\x0f\xdf\x86\x08\xf9\x92\xe9\xdb\xeb8\x086\xc9\xcd\xcd"
b"\n\x01H[9>z\xca\xc6\x80\xa9\xdc\xd49\x00\x00\x01\x01\x00\xd5j\xacx#\xd6"
b"\xd6\x1b\xec%\xa1P\xc4wcP\x84E\x01UB\x14**\xe0\xd0`\xee\xd4\xe9\xa3\xadJ"
b"\xfa9\x06^\x84Uu_\x006\xbfo\xaa*?\x83&7\xc1i.[\xfd\xf0\xf3\xd2}\xd6\x98"
b"\xcd:@x\xd5\xca\xa8\x18\xc0\x11\x93$\t\x0c\x81L\x8f\xf7\x9c\xed\x13\x16j"
b"\xa4\x04\xe9Iw\xc3\xe4Ud\xb3yh\x9e,\x08\xeb\xac\xe8\x04-!w\x05\xa7\x8e"
b"\xefS0\r\xa5\xe5\xbb=j\xe2\t6o\xfd4\xd3}oF\xff\x87\xda\xa9)'\xaa\xff\xad"
b"\xf5\x85\xe6>\x1a\xb8z\x1dJ\xb1\xea\xc0Z\xf70\xdf\x1f\xc2\xa4\xe4\xef?"
b"\x91I\x96@\xd5\x19w-7\xc3^\xec\x9d\xa6:D\xa5\xc2\xa4)\xdd\xd5\xba\x9c=E"
b"\xb3\xc6,\x18d\xd5\xba=\xdf\xab\x7f\xcdB\xac\xa7\xf1\x18\x0b\xa0X\x15b"
b"\x0b\xa4*nC\xc3\xe4\x04\x9f5\xa3G\x8eF\xed3\xa5e\xbd\xbc;)n\x02\x0bW\xdf"
b"t\xe8\x13\xb475~\x83_ &`\xa6\xdc\xad\x8b\xc6ly\x98\xf7\x00\x00\x00\x81"
b"\x00\x85K\x1bz\x9b\x12\x107\x9e\x1f\xad^\xda\xfe\xc6\x96\xfe\xdf5k\xb94"
b"\xe2\x16\x97\x92&\t\xbd\xbdp \x03\xa75\xbd-\x1b\xa0\xd2\x07G+\xd4\xde"
b"\xa8\xa8\x07\x07\x1b\xb8\x04 \xa7'A<l99\xe9A\xce\xe7\x17\x1d\xd1L\\\xbc="
b"\xd2&&\xfej\xd6\xfdHr\xaeF\xfa{\xc3\xd3\x19`D\x1d\xa5\x13\xa7\x80\xf5c)"
b"\xd4z]\x06\x07\x16]\xf6\x8b=\xcbd:\xe2\x84ZM\x8c\x06--\x9d\x1c\xeb\x83Lx"
b"=yT\xce\x00\x00\x00\x81\x00\xd9p\x06\xd8\xe2\xbc\xd4x\x91P\x94\xd4\xc1"
b"\x1b\x898lFdZQ\xa0\x9a\x07=H\x8f\x03Q\xcck\x12\x8e}\x1a\xb1e\xe7qu9\xe02"
b"\x05u\x8d\x18L\xaf\x93\xb1I\xb1f_xbz\xd1\x0c\xca\xe6MC\xb3\x9c\xf4k}\xe6"
b'\x0c\x98\xdc\xcf!b\x8e\xd5.\x12\xde\x04\xae\xd7$n\x831\xa2\x15\xa2D="'
b'\xa9b&"\xb9\xb2\xedT\n\x9d\x08\x83\xa7\x07\r\xff\x19\x18\x8e\xd8\xab'
b"\x1d\xdaH\x9c1h\x11\xa1fm\xe3\xd8\x1d\x00\x00\x00\x81\x00\xfbD\x17\x8b"
b"\xa46\xbe\x1e7\x1d\xa7\xf6al\x04\xc4\xaa\xddx>\x07\x8c\x1e3\x02\xae\x03"
b"\x14\x87\x83z\xe5\x9e}\x08g\xa8\xf2\xaa\xbf\x12p\xcfr\xa9\xa7\xc7\x0b"
b"\x1d\x88\xd5 \xfd\x9cc\xcaG0UN\x8b\xc4\xcf\xf4\x7f\x16\xa4\x92\x12t\xa1"
b"\t\xc2\xc4n\x9c\x8c3\xef\xa5\xe5\xf7\xe0+\xadO\\\x11\xaa\x1a\x847[\xfdz"
b"\xea\xc3\xcd|\xb0\xc8\xe4\x1fTc\xb5\xc7\xaf\xdf\xf4\t\xa7\xfc\xc7%\xfc\\"
b"\xe9\x91\xd7\x92\xc5\x98\x1eV\xd3\xb1#"
)
publicDSA_openssh = b"""\
ssh-dss AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9\
LvFYmFFVMIuHFGlZpIL7sh3IMkqy+cssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+Gu48G\
+yFuE8l0fVVUivos/MmYVJ66qT99htcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+Sx9a5AAAAFQD0\
EYmTNaFJ8CS0+vFSF4nYcyEnSQAAAIEAkgLjxHJAE7qFWdTqf7EZngu7jAGmdB9k3YzMHe1ldMxEB\
7zNw5aOnxjhoYLtiHeoEcOk2XOyvnE+VfhIWwWAdOiKRTEZlmizkvhGbq0DCe2EPMXirjqWACI5nD\
ioQX1oEMonR8N3AEO5v9SfBqS2Q9R6OBr6lf04RvwpHZ0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQT\
NEpklRZqeBGo1gotJggNmVaIQNIClGlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2G\
gdgMQWC7S6WFIXePGGXqNQDdWxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8= \
comment\
"""
privateDSA_openssh = b"""\
-----BEGIN DSA PRIVATE KEY-----
MIIBvAIBAAKBgQCSkDrFREVQ0CKQDRx/iQAMpaPS7xWJhRVTCLhxRpWaSC+7IdyD
JKsvnLLCDTP5Zxw935rAMi5VF2bbejw/S4GUWs7DzoKbbh/hruPBvshbhPJdH1VV
Ir6LPzJmFSeuqk/fYbXGSmra01mZ6VVu4BQAaKsPoXti2dIJHlNPksfWuQIVAPQR
iZM1oUnwJLT68VIXidhzISdJAoGBAJIC48RyQBO6hVnU6n+xGZ4Lu4wBpnQfZN2M
zB3tZXTMRAe8zcOWjp8Y4aGC7Yh3qBHDpNlzsr5xPlX4SFsFgHToikUxGZZos5L4
Rm6tAwnthDzF4q46lgAiOZw4qEF9aBDKJ0fDdwBDub/UnwaktkPUejga+pX9OEb8
KR2dFBrvAoGAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQNICl
GlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXeP
GGXqNQDdWxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8CFQDV2gbL
czUdxCus0pfEP1bddaXRLQ==
-----END DSA PRIVATE KEY-----\
"""
privateDSA_openssh_new = b"""\
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABsgAAAAdzc2gtZH
NzAAAAgQCSkDrFREVQ0CKQDRx/iQAMpaPS7xWJhRVTCLhxRpWaSC+7IdyDJKsvnLLCDTP5
Zxw935rAMi5VF2bbejw/S4GUWs7DzoKbbh/hruPBvshbhPJdH1VVIr6LPzJmFSeuqk/fYb
XGSmra01mZ6VVu4BQAaKsPoXti2dIJHlNPksfWuQAAABUA9BGJkzWhSfAktPrxUheJ2HMh
J0kAAACBAJIC48RyQBO6hVnU6n+xGZ4Lu4wBpnQfZN2MzB3tZXTMRAe8zcOWjp8Y4aGC7Y
h3qBHDpNlzsr5xPlX4SFsFgHToikUxGZZos5L4Rm6tAwnthDzF4q46lgAiOZw4qEF9aBDK
J0fDdwBDub/UnwaktkPUejga+pX9OEb8KR2dFBrvAAAAgAIUacRjCFhMmhIfGJ44ms0EzR
KZJUWangRqNYKLSYIDZlWiEDSApRpS8got+fXnxFLkHGfUl8TOfT/oXnHPxlPxh2pFuWFh
OHT9hoHYDEFgu0ulhSF3jxhl6jUA3VsZV/LpoXp70KmtT5yqxUYQ6ycPGexo3R8X5bMQhJ
lz6CzfAAAB2MVcBjzFXAY8AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9Lv
FYmFFVMIuHFGlZpIL7sh3IMkqy+cssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+
Gu48G+yFuE8l0fVVUivos/MmYVJ66qT99htcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+S
x9a5AAAAFQD0EYmTNaFJ8CS0+vFSF4nYcyEnSQAAAIEAkgLjxHJAE7qFWdTqf7EZngu7jA
GmdB9k3YzMHe1ldMxEB7zNw5aOnxjhoYLtiHeoEcOk2XOyvnE+VfhIWwWAdOiKRTEZlmiz
kvhGbq0DCe2EPMXirjqWACI5nDioQX1oEMonR8N3AEO5v9SfBqS2Q9R6OBr6lf04RvwpHZ
0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQNIClGlLyCi35
9efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXePGGXqNQDdWxlX8u
mhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8AAAAVANXaBstzNR3EK6zSl8Q/Vt11
pdEtAAAAAAE=
-----END OPENSSH PRIVATE KEY-----
"""
publicDSA_lsh = decodebytes(
b"""\
e0tERXdPbkIxWW14cFl5MXJaWGtvTXpwa2MyRW9NVHB3TVRJNU9nQ1NrRHJGUkVWUTBDS1FEUngv
aVFBTXBhUFM3eFdKaFJWVENMaHhScFdhU0MrN0lkeURKS3N2bkxMQ0RUUDVaeHc5MzVyQU1pNVZG
MmJiZWp3L1M0R1VXczdEem9LYmJoL2hydVBCdnNoYmhQSmRIMVZWSXI2TFB6Sm1GU2V1cWsvZlli
WEdTbXJhMDFtWjZWVnU0QlFBYUtzUG9YdGkyZElKSGxOUGtzZld1U2tvTVRweE1qRTZBUFFSaVpN
MW9VbndKTFQ2OFZJWGlkaHpJU2RKS1NneE9tY3hNams2QUpJQzQ4UnlRQk82aFZuVTZuK3hHWjRM
dTR3QnBuUWZaTjJNekIzdFpYVE1SQWU4emNPV2pwOFk0YUdDN1loM3FCSERwTmx6c3I1eFBsWDRT
RnNGZ0hUb2lrVXhHWlpvczVMNFJtNnRBd250aER6RjRxNDZsZ0FpT1p3NHFFRjlhQkRLSjBmRGR3
QkR1Yi9Vbndha3RrUFVlamdhK3BYOU9FYjhLUjJkRkJydktTZ3hPbmt4TWpnNkFoUnB4R01JV0V5
YUVoOFluamlhelFUTkVwa2xSWnFlQkdvMWdvdEpnZ05tVmFJUU5JQ2xHbEx5Q2kzNTllZkVVdVFj
WjlTWHhNNTlQK2hlY2MvR1UvR0hha1c1WVdFNGRQMkdnZGdNUVdDN1M2V0ZJWGVQR0dYcU5RRGRX
eGxYOHVtaGVudlFxYTFQbktyRlJoRHJKdzhaN0dqZEh4ZmxzeENFbVhQb0xOOHBLU2s9fQ==
"""
)
privateDSA_lsh = decodebytes(
b"""\
KDExOnByaXZhdGUta2V5KDM6ZHNhKDE6cDEyOToAkpA6xURFUNAikA0cf4kADKWj0u8ViYUVUwi4
cUaVmkgvuyHcgySrL5yywg0z+WccPd+awDIuVRdm23o8P0uBlFrOw86Cm24f4a7jwb7IW4TyXR9V
VSK+iz8yZhUnrqpP32G1xkpq2tNZmelVbuAUAGirD6F7YtnSCR5TT5LH1rkpKDE6cTIxOgD0EYmT
NaFJ8CS0+vFSF4nYcyEnSSkoMTpnMTI5OgCSAuPEckATuoVZ1Op/sRmeC7uMAaZ0H2TdjMwd7WV0
zEQHvM3Dlo6fGOGhgu2Id6gRw6TZc7K+cT5V+EhbBYB06IpFMRmWaLOS+EZurQMJ7YQ8xeKuOpYA
IjmcOKhBfWgQyidHw3cAQ7m/1J8GpLZD1Ho4GvqV/ThG/CkdnRQa7ykoMTp5MTI4OgIUacRjCFhM
mhIfGJ44ms0EzRKZJUWangRqNYKLSYIDZlWiEDSApRpS8got+fXnxFLkHGfUl8TOfT/oXnHPxlPx
h2pFuWFhOHT9hoHYDEFgu0ulhSF3jxhl6jUA3VsZV/LpoXp70KmtT5yqxUYQ6ycPGexo3R8X5bMQ
hJlz6CzfKSgxOngyMToA1doGy3M1HcQrrNKXxD9W3XWl0S0pKSk=
"""
)
privateDSA_agentv3 = decodebytes(
b"""\
AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9LvFYmFFVMIuHFGlZpIL7sh3IMkqy+c
ssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+Gu48G+yFuE8l0fVVUivos/MmYVJ66qT99h
tcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+Sx9a5AAAAFQD0EYmTNaFJ8CS0+vFSF4nYcyEnSQAA
AIEAkgLjxHJAE7qFWdTqf7EZngu7jAGmdB9k3YzMHe1ldMxEB7zNw5aOnxjhoYLtiHeoEcOk2XOy
vnE+VfhIWwWAdOiKRTEZlmizkvhGbq0DCe2EPMXirjqWACI5nDioQX1oEMonR8N3AEO5v9SfBqS2
Q9R6OBr6lf04RvwpHZ0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQ
NIClGlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXePGGXqNQDd
WxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8AAAAVANXaBstzNR3EK6zSl8Q/Vt11
pdEt
"""
)
__all__ = [
"DSAData",
"RSAData",
"privateDSA_agentv3",
"privateDSA_lsh",
"privateDSA_openssh",
"privateRSA_agentv3",
"privateRSA_lsh",
"privateRSA_openssh",
"publicDSA_lsh",
"publicDSA_openssh",
"publicRSA_lsh",
"publicRSA_openssh",
]
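# --- Illustrative sketch, not part of the original module ------------------
# The blobs above are consumed by the key tests roughly like this: the same
# key material round-trips between the serialized forms and the numeric
# components in RSAData.
def _exampleRoundTrip():
    from twisted.conch.ssh import keys

    key = keys.Key.fromString(privateRSA_openssh)
    assert not key.isPublic()
    assert key.data()["n"] == RSAData["n"]
    assert key.public().blob() == keys.Key.fromString(publicRSA_openssh).blob()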

View File

@@ -0,0 +1,28 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Loopback helper used in test_ssh and test_recvline
"""
from twisted.protocols import loopback
class LoopbackRelay(loopback.LoopbackRelay):
clearCall = None
def logPrefix(self):
return f"LoopbackRelay({self.target.__class__.__name__!r})"
def write(self, data):
loopback.LoopbackRelay.write(self, data)
if self.clearCall is not None:
self.clearCall.cancel()
from twisted.internet import reactor
self.clearCall = reactor.callLater(0, self._clearBuffer)
def _clearBuffer(self):
self.clearCall = None
loopback.LoopbackRelay.clearBuffer(self)
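# --- Illustrative sketch, not part of the original module ------------------
# A typical test wiring for this relay: two protocols are cross-connected so
# that bytes written by one are delivered to the other on the next reactor
# iteration (the callLater(0, ...) above).  Bare Protocol instances stand in
# for real server/client protocols here.
def _exampleWiring():
    from twisted.internet.protocol import Protocol

    serverProto = Protocol()
    clientProto = Protocol()
    serverProto.makeConnection(LoopbackRelay(clientProto))
    clientProto.makeConnection(LoopbackRelay(serverProto))
    return serverProto, clientProto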

View File

@@ -0,0 +1,45 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{SSHTransportAddress} in ssh/address.py
"""
from __future__ import annotations
from typing import Callable
from twisted.conch.ssh.address import SSHTransportAddress
from twisted.internet.address import IPv4Address
from twisted.internet.test.test_address import AddressTestCaseMixin
from twisted.trial import unittest
class SSHTransportAddressTests(unittest.TestCase, AddressTestCaseMixin):
"""
L{twisted.conch.ssh.address.SSHTransportAddress} is what Conch transports
use to represent the other side of the SSH connection. This tests the
basic functionality of that class (string representation, comparison, &c).
"""
def _stringRepresentation(self, stringFunction: Callable[[object], str]) -> None:
"""
The string representation of C{SSHTransportAddress} should be
"SSHTransportAddress(<stringFunction on address>)".
"""
addr = self.buildAddress()
stringValue = stringFunction(addr)
addressValue = stringFunction(addr.address)
self.assertEqual(stringValue, "SSHTransportAddress(%s)" % addressValue)
def buildAddress(self) -> SSHTransportAddress:
"""
Create an arbitrary new C{SSHTransportAddress}. A new instance is
created for each call, but always for the same address.
"""
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.1", 22))
def buildDifferentAddress(self) -> SSHTransportAddress:
"""
Like C{buildAddress}, but with a different fixed address.
"""
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.2", 22))

View File

@@ -0,0 +1,398 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.ssh.agent}.
"""
import struct
from twisted.test import iosim
from twisted.trial import unittest
try:
import cryptography as _cryptography
except ImportError:
cryptography = None
else:
cryptography = _cryptography
try:
from twisted.conch.ssh import agent as _agent, keys as _keys
except ImportError:
keys = agent = None
else:
keys, agent = _keys, _agent
from twisted.conch.error import ConchError, MissingKeyStoreError
from twisted.conch.test import keydata
class StubFactory:
"""
Mock factory that provides the keys attribute required by the
SSHAgentServerProtocol
"""
def __init__(self):
self.keys = {}
class AgentTestBase(unittest.TestCase):
"""
Tests for SSHAgentServer/Client.
"""
if agent is None or keys is None:
skip = "Cannot run without cryptography"
def setUp(self):
# wire up our client <-> server
self.client, self.server, self.pump = iosim.connectedServerAndClient(
agent.SSHAgentServer, agent.SSHAgentClient
)
# the server's end of the protocol is stateful and we store it on the
# factory, for which we only need a mock
self.server.factory = StubFactory()
# pub/priv keys of each kind
self.rsaPrivate = keys.Key.fromString(keydata.privateRSA_openssh)
self.dsaPrivate = keys.Key.fromString(keydata.privateDSA_openssh)
self.rsaPublic = keys.Key.fromString(keydata.publicRSA_openssh)
self.dsaPublic = keys.Key.fromString(keydata.publicDSA_openssh)
class ServerProtocolContractWithFactoryTests(AgentTestBase):
"""
The server protocol is stateful and so uses its factory to track state
across requests. This test asserts that the protocol raises if its factory
doesn't provide the necessary storage for that state.
"""
def test_factorySuppliesKeyStorageForServerProtocol(self):
# need a message to send into the server
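# Agent messages are framed as a 4-byte big-endian length followed by the
# payload; "!LB" packs a length of 1 and the single op-code byte.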
msg = struct.pack("!LB", 1, agent.AGENTC_REQUEST_IDENTITIES)
del self.server.factory.__dict__["keys"]
self.assertRaises(MissingKeyStoreError, self.server.dataReceived, msg)
class UnimplementedVersionOneServerTests(AgentTestBase):
"""
Tests for methods with no-op implementations on the server. We need these
for clients, such as openssh, that try v1 methods before going to v2.
Because the client doesn't expose these operations with nice method names,
we invoke sendRequest directly with an op code.
"""
def test_agentc_REQUEST_RSA_IDENTITIES(self):
"""
assert that we get the correct op code for an RSA identities request
"""
d = self.client.sendRequest(agent.AGENTC_REQUEST_RSA_IDENTITIES, b"")
self.pump.flush()
def _cb(packet):
self.assertEqual(agent.AGENT_RSA_IDENTITIES_ANSWER, ord(packet[0:1]))
return d.addCallback(_cb)
def test_agentc_REMOVE_RSA_IDENTITY(self):
"""
assert that we get the correct op code for an RSA remove identity request
"""
d = self.client.sendRequest(agent.AGENTC_REMOVE_RSA_IDENTITY, b"")
self.pump.flush()
return d.addCallback(self.assertEqual, b"")
def test_agentc_REMOVE_ALL_RSA_IDENTITIES(self):
"""
assert that we get the correct op code for an RSA remove all identities
request.
"""
d = self.client.sendRequest(agent.AGENTC_REMOVE_ALL_RSA_IDENTITIES, b"")
self.pump.flush()
return d.addCallback(self.assertEqual, b"")
if agent is not None:
class CorruptServer(agent.SSHAgentServer): # type: ignore[name-defined]
"""
A misbehaving server that returns bogus response op codes so that we can
verify that our callbacks that deal with these op codes handle such
miscreants.
"""
def agentc_REQUEST_IDENTITIES(self, data):
self.sendResponse(254, b"")
def agentc_SIGN_REQUEST(self, data):
self.sendResponse(254, b"")
class ClientWithBrokenServerTests(AgentTestBase):
"""
verify error handling code in the client using a misbehaving server
"""
def setUp(self):
AgentTestBase.setUp(self)
self.client, self.server, self.pump = iosim.connectedServerAndClient(
CorruptServer, agent.SSHAgentClient
)
# the server's end of the protocol is stateful and we store it on the
# factory, for which we only need a mock
self.server.factory = StubFactory()
def test_signDataCallbackErrorHandling(self):
"""
Assert that the Deferred returned by L{SSHAgentClient.signData} fails
with a ConchError if the server's response opcode doesn't match the
protocol for data signing requests.
"""
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
self.pump.flush()
return self.assertFailure(d, ConchError)
def test_requestIdentitiesCallbackErrorHandling(self):
"""
Assert that the Deferred returned by L{SSHAgentClient.requestIdentities}
fails with a ConchError if the server's response opcode doesn't match
the protocol for identity requests.
"""
d = self.client.requestIdentities()
self.pump.flush()
return self.assertFailure(d, ConchError)
class AgentKeyAdditionTests(AgentTestBase):
"""
Test adding different flavors of keys to an agent.
"""
def test_addRSAIdentityNoComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that omitting the comment produces an
empty string for the comment on the server.
"""
d = self.client.addIdentity(self.rsaPrivate.privateBlob())
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
self.assertEqual(self.rsaPrivate, serverKey[0])
self.assertEqual(b"", serverKey[1])
return d.addCallback(_check)
def test_addDSAIdentityNoComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that omitting the comment produces an
empty string for the comment on the server.
"""
d = self.client.addIdentity(self.dsaPrivate.privateBlob())
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
self.assertEqual(self.dsaPrivate, serverKey[0])
self.assertEqual(b"", serverKey[1])
return d.addCallback(_check)
def test_addRSAIdentityWithComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that the server receives/stores the comment
as sent by the client.
"""
d = self.client.addIdentity(
self.rsaPrivate.privateBlob(), comment=b"My special key"
)
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
self.assertEqual(self.rsaPrivate, serverKey[0])
self.assertEqual(b"My special key", serverKey[1])
return d.addCallback(_check)
def test_addDSAIdentityWithComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that the server receives/stores the comment
as sent by the client.
"""
d = self.client.addIdentity(
self.dsaPrivate.privateBlob(), comment=b"My special key"
)
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
self.assertEqual(self.dsaPrivate, serverKey[0])
self.assertEqual(b"My special key", serverKey[1])
return d.addCallback(_check)
class AgentClientFailureTests(AgentTestBase):
def test_agentFailure(self):
"""
Verify that the client's request Deferred fails with ConchError when the
server replies with AGENT_FAILURE.
"""
d = self.client.sendRequest(254, b"")
self.pump.flush()
return self.assertFailure(d, ConchError)
class AgentIdentityRequestsTests(AgentTestBase):
"""
Test operations against a server with identities already loaded.
"""
def setUp(self):
AgentTestBase.setUp(self)
self.server.factory.keys[self.dsaPrivate.blob()] = (
self.dsaPrivate,
b"a comment",
)
self.server.factory.keys[self.rsaPrivate.blob()] = (
self.rsaPrivate,
b"another comment",
)
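# Key.blob() returns the public-key wire blob even for a private Key, so
# these entries can be looked up later with the matching public keys.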
def test_signDataRSA(self):
"""
Sign data with an RSA private key and then verify it with the public
key.
"""
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
self.pump.flush()
signature = self.successResultOf(d)
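# ssh-rsa signatures (PKCS#1 v1.5) are deterministic, so the agent's
# signature can be compared against one computed locally.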
expected = self.rsaPrivate.sign(b"John Hancock")
self.assertEqual(expected, signature)
self.assertTrue(self.rsaPublic.verify(signature, b"John Hancock"))
def test_signDataDSA(self):
"""
Sign data with a DSA private key and then verify it with the public
key.
"""
d = self.client.signData(self.dsaPublic.blob(), b"John Hancock")
self.pump.flush()
def _check(sig):
# Cannot compare to a locally computed signature because DSA signing
# uses a random nonce, so signatures differ between runs:
# expected = self.dsaPrivate.sign("John Hancock")
# self.assertEqual(expected, sig)
self.assertTrue(self.dsaPublic.verify(sig, b"John Hancock"))
return d.addCallback(_check)
def test_signDataRSAErrbackOnUnknownBlob(self):
"""
Assert that we get an errback if we try to sign data using a key that
wasn't added.
"""
del self.server.factory.keys[self.rsaPublic.blob()]
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
self.pump.flush()
return self.assertFailure(d, ConchError)
def test_requestIdentities(self):
"""
Assert that we get all of the keys/comments that we add when we issue a
request for all identities.
"""
d = self.client.requestIdentities()
self.pump.flush()
def _check(keyt):
expected = {}
expected[self.dsaPublic.blob()] = b"a comment"
expected[self.rsaPublic.blob()] = b"another comment"
received = {}
for k in keyt:
received[keys.Key.fromString(k[0], type="blob").blob()] = k[1]
self.assertEqual(expected, received)
return d.addCallback(_check)
class AgentKeyRemovalTests(AgentTestBase):
"""
Test support for removing keys in a remote server.
"""
def setUp(self):
AgentTestBase.setUp(self)
self.server.factory.keys[self.dsaPrivate.blob()] = (
self.dsaPrivate,
b"a comment",
)
self.server.factory.keys[self.rsaPrivate.blob()] = (
self.rsaPrivate,
b"another comment",
)
def test_removeRSAIdentity(self):
"""
Assert that we can remove an RSA identity.
"""
# only need public key for this
d = self.client.removeIdentity(self.rsaPrivate.blob())
self.pump.flush()
def _check(ignored):
self.assertEqual(1, len(self.server.factory.keys))
self.assertIn(self.dsaPrivate.blob(), self.server.factory.keys)
self.assertNotIn(self.rsaPrivate.blob(), self.server.factory.keys)
return d.addCallback(_check)
def test_removeDSAIdentity(self):
"""
Assert that we can remove a DSA identity.
"""
# only need public key for this
d = self.client.removeIdentity(self.dsaPrivate.blob())
self.pump.flush()
def _check(ignored):
self.assertEqual(1, len(self.server.factory.keys))
self.assertIn(self.rsaPrivate.blob(), self.server.factory.keys)
return d.addCallback(_check)
def test_removeAllIdentities(self):
"""
Assert that we can remove all identities.
"""
d = self.client.removeAllIdentities()
self.pump.flush()
def _check(ignored):
self.assertEqual(0, len(self.server.factory.keys))
return d.addCallback(_check)

File diff suppressed because it is too large

View File

@@ -0,0 +1,358 @@
# Copyright Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test ssh/channel.py.
"""
from __future__ import annotations
from unittest import skipIf
from zope.interface.verify import verifyObject
try:
from twisted.conch.ssh import channel
from twisted.conch.ssh.address import SSHTransportAddress
from twisted.conch.ssh.service import SSHService
from twisted.conch.ssh.transport import SSHServerTransport
from twisted.internet import interfaces
from twisted.internet.address import IPv4Address
from twisted.internet.testing import StringTransport
skipTest = ""
except ImportError:
skipTest = "Conch SSH not supported."
SSHService = object # type: ignore[assignment,misc]
from twisted.trial.unittest import TestCase
class MockConnection(SSHService):
"""
A mock for twisted.conch.ssh.connection.SSHConnection. Record the data
that channels send, and when they try to close the connection.
@ivar data: a L{dict} mapping channels to lists of data sent by that
channel.
@ivar extData: a L{dict} mapping channels to lists of 2-tuples
(extended data type, data) sent by that channel.
@ivar closes: a L{dict} mapping channels to True if that channel sent
a close message.
"""
def __init__(self) -> None:
self.data: dict[channel.SSHChannel, list[bytes]] = {}
self.extData: dict[channel.SSHChannel, list[tuple[int, bytes]]] = {}
self.closes: dict[channel.SSHChannel, bool] = {}
def logPrefix(self) -> str:
"""
Return our logging prefix.
"""
return "MockConnection"
def sendData(self, channel: channel.SSHChannel, data: bytes) -> None:
"""
Record the sent data.
"""
self.data.setdefault(channel, []).append(data)
def sendExtendedData(
self, channel: channel.SSHChannel, type: int, data: bytes
) -> None:
"""
Record the sent extended data.
"""
self.extData.setdefault(channel, []).append((type, data))
def sendClose(self, channel: channel.SSHChannel) -> None:
"""
Record that the channel sent a close message.
"""
self.closes[channel] = True
def connectSSHTransport(
service: SSHService,
hostAddress: interfaces.IAddress | None = None,
peerAddress: interfaces.IAddress | None = None,
) -> None:
"""
Connect an SSHTransport, already connected to a remote peer, to the
channel under test.
@param service: Service used over the connected transport.
@type service: L{SSHService}
@param hostAddress: Local address of the connected transport.
@type hostAddress: L{interfaces.IAddress}
@param peerAddress: Remote address of the connected transport.
@type peerAddress: L{interfaces.IAddress}
"""
transport = SSHServerTransport()
transport.makeConnection(
StringTransport(hostAddress=hostAddress, peerAddress=peerAddress)
)
transport.setService(service)
@skipIf(skipTest, skipTest)
class ChannelTests(TestCase):
"""
Tests for L{SSHChannel}.
"""
def setUp(self) -> None:
"""
Initialize the channel. remoteMaxPacket is 10 so that data can be
sent (the default of 0 means nothing is sent because no packets can
be made).
"""
self.conn = MockConnection()
self.channel = channel.SSHChannel(conn=self.conn, remoteMaxPacket=10)
self.channel.name = b"channel"
def test_interface(self) -> None:
"""
L{SSHChannel} instances provide L{interfaces.ITransport}.
"""
self.assertTrue(verifyObject(interfaces.ITransport, self.channel))
def test_init(self) -> None:
"""
Test that SSHChannel initializes correctly. localWindowSize defaults
to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
defaults (what OpenSSH uses for those variables).
The values in the second set of assertions are meaningless; they serve
only to verify that the instance variables are assigned in the correct
order.
"""
c = channel.SSHChannel(conn=self.conn)
self.assertEqual(c.localWindowSize, 131072)
self.assertEqual(c.localWindowLeft, 131072)
self.assertEqual(c.localMaxPacket, 32768)
self.assertEqual(c.remoteWindowLeft, 0)
self.assertEqual(c.remoteMaxPacket, 0)
self.assertEqual(c.conn, self.conn)
self.assertIsNone(c.data)
self.assertIsNone(c.avatar)
c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
self.assertEqual(c2.localWindowSize, 1)
self.assertEqual(c2.localWindowLeft, 1)
self.assertEqual(c2.localMaxPacket, 2)
self.assertEqual(c2.remoteWindowLeft, 3)
self.assertEqual(c2.remoteMaxPacket, 4)
self.assertEqual(c2.conn, 5)
self.assertEqual(c2.data, 6)
self.assertEqual(c2.avatar, 7)
def test_str(self) -> None:
"""
Test that str(SSHChannel) gives the channel name and the local and
remote windows at a glance.
"""
self.assertEqual(str(self.channel), "<SSHChannel channel (lw 131072 rw 0)>")
self.assertEqual(
str(channel.SSHChannel(localWindow=1)), "<SSHChannel None (lw 1 rw 0)>"
)
def test_bytes(self) -> None:
"""
Test that bytes(SSHChannel) gives the channel name and the local
and remote windows at a glance.
"""
self.assertEqual(
self.channel.__bytes__(), b"<SSHChannel channel (lw 131072 rw 0)>"
)
self.assertEqual(
channel.SSHChannel(localWindow=1).__bytes__(),
b"<SSHChannel None (lw 1 rw 0)>",
)
def test_logPrefix(self) -> None:
"""
Test that SSHChannel.logPrefix gives the name of the channel, the
local channel ID and the underlying connection.
"""
self.assertEqual(
self.channel.logPrefix(), "SSHChannel channel (unknown) on MockConnection"
)
def test_addWindowBytes(self) -> None:
"""
Test that addWindowBytes adds bytes to the window and resumes writing
if it was paused.
"""
cb = [False]
def stubStartWriting() -> None:
cb[0] = True
self.channel.startWriting = stubStartWriting # type: ignore[method-assign]
self.channel.write(b"test")
self.channel.writeExtended(1, b"test")
self.channel.addWindowBytes(50)
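# Granting window bytes flushes both buffered 4-byte writes, consuming 8
# of the 50 new window bytes.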
self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
self.assertTrue(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(self.channel.buf, b"")
self.assertEqual(self.conn.data[self.channel], [b"test"])
self.assertEqual(self.channel.extBuf, [])
self.assertEqual(self.conn.extData[self.channel], [(1, b"test")])
cb[0] = False
self.channel.addWindowBytes(20)
self.assertFalse(cb[0])
self.channel.write(b"a" * 80)
self.channel.loseConnection()
self.channel.addWindowBytes(20)
self.assertFalse(cb[0])
def test_requestReceived(self) -> None:
"""
Test that requestReceived handles requests by dispatching them to
request_* methods.
"""
self.channel.request_test_method = lambda data: data == b"" # type: ignore[attr-defined]
self.assertTrue(self.channel.requestReceived(b"test-method", b""))
self.assertFalse(self.channel.requestReceived(b"test-method", b"a"))
self.assertFalse(self.channel.requestReceived(b"bad-method", b""))
def test_closeReceieved(self) -> None:
"""
Test that the default closeReceived closes the connection.
"""
self.assertFalse(self.channel.closing)
self.channel.closeReceived()
self.assertTrue(self.channel.closing)
def test_write(self) -> None:
"""
Test that write handles data correctly. Send data up to the size
of the remote window, splitting the data into packets of length
remoteMaxPacket.
"""
cb = [False]
def stubStopWriting() -> None:
cb[0] = True
# no window to start with
self.channel.stopWriting = stubStopWriting # type: ignore[method-assign]
self.channel.write(b"d")
self.channel.write(b"a")
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
# regular write
self.channel.addWindowBytes(20)
self.channel.write(b"ta")
data = self.conn.data[self.channel]
self.assertEqual(data, [b"da", b"ta"])
self.assertEqual(self.channel.remoteWindowLeft, 16)
# larger than max packet
self.channel.write(b"12345678901")
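# remoteMaxPacket is 10, so the 11-byte write is split into a full packet
# plus the 1-byte remainder; the window drops from 16 to 5.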
self.assertEqual(data, [b"da", b"ta", b"1234567890", b"1"])
self.assertEqual(self.channel.remoteWindowLeft, 5)
# running out of window
cb[0] = False
self.channel.write(b"123456")
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(data, [b"da", b"ta", b"1234567890", b"1", b"12345"])
self.assertEqual(self.channel.buf, b"6")
self.assertEqual(self.channel.remoteWindowLeft, 0)
def test_writeExtended(self) -> None:
"""
Test that writeExtended handles data correctly. Send extended data
up to the size of the window, splitting the extended data into packets
of length remoteMaxPacket.
"""
cb = [False]
def stubStopWriting() -> None:
cb[0] = True
# no window to start with
self.channel.stopWriting = stubStopWriting # type: ignore[method-assign]
self.channel.writeExtended(1, b"d")
self.channel.writeExtended(1, b"a")
self.channel.writeExtended(2, b"t")
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
# regular write
self.channel.addWindowBytes(20)
self.channel.writeExtended(2, b"a")
data = self.conn.extData[self.channel]
self.assertEqual(data, [(1, b"da"), (2, b"t"), (2, b"a")])
self.assertEqual(self.channel.remoteWindowLeft, 16)
# larger than max packet
self.channel.writeExtended(3, b"12345678901")
self.assertEqual(
data, [(1, b"da"), (2, b"t"), (2, b"a"), (3, b"1234567890"), (3, b"1")]
)
self.assertEqual(self.channel.remoteWindowLeft, 5)
# running out of window
cb[0] = False
self.channel.writeExtended(4, b"123456")
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(
data,
[
(1, b"da"),
(2, b"t"),
(2, b"a"),
(3, b"1234567890"),
(3, b"1"),
(4, b"12345"),
],
)
self.assertEqual(self.channel.extBuf, [[4, b"6"]])
self.assertEqual(self.channel.remoteWindowLeft, 0)
def test_writeSequence(self) -> None:
"""
Test that writeSequence is equivalent to write(''.join(sequence)).
"""
self.channel.addWindowBytes(20)
self.channel.writeSequence(b"%d" % (i,) for i in range(10))
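# writeSequence joins the iterable and sends it as a single write, hence
# one 10-byte chunk.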
self.assertEqual(self.conn.data[self.channel], [b"0123456789"])
def test_loseConnection(self) -> None:
"""
Test that loseConnection() doesn't close the channel until all
the data is sent.
"""
self.channel.write(b"data")
self.channel.writeExtended(1, b"datadata")
self.channel.loseConnection()
self.assertIsNone(self.conn.closes.get(self.channel))
self.channel.addWindowBytes(4) # send regular data
self.assertIsNone(self.conn.closes.get(self.channel))
self.channel.addWindowBytes(8) # send extended data
self.assertTrue(self.conn.closes.get(self.channel))
def test_getPeer(self) -> None:
"""
L{SSHChannel.getPeer} returns the same object as the underlying
transport's C{getPeer} method returns.
"""
peer = IPv4Address("TCP", "192.168.0.1", 54321)
connectSSHTransport(service=self.channel.conn, peerAddress=peer)
self.assertEqual(SSHTransportAddress(peer), self.channel.getPeer())
def test_getHost(self) -> None:
"""
L{SSHChannel.getHost} returns the same object as the underlying
transport's C{getHost} method returns.
"""
host = IPv4Address("TCP", "127.0.0.1", 12345)
connectSSHTransport(service=self.channel.conn, hostAddress=host)
self.assertEqual(SSHTransportAddress(host), self.channel.getHost())

Some files were not shown because too many files have changed in this diff