mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-22 14:51:09 -05:00
okay fine
This commit is contained in:
12
.venv/lib/python3.12/site-packages/twisted/__init__.py
Normal file
12
.venv/lib/python3.12/site-packages/twisted/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# -*- test-case-name: twisted -*-
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Twisted: The Framework Of Your Internet.
|
||||
"""
|
||||
|
||||
from twisted._version import __version__ as version
|
||||
|
||||
__version__ = version.short()
|
||||
14
.venv/lib/python3.12/site-packages/twisted/__main__.py
Normal file
14
.venv/lib/python3.12/site-packages/twisted/__main__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
# Make the twisted module executable with the default behaviour of
|
||||
# running twist.
|
||||
# This is not a docstring to avoid changing the string output of twist.
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
from twisted.application.twist._twist import Twist
|
||||
|
||||
sys.exit(Twist.main())
|
||||
@@ -0,0 +1,24 @@
|
||||
# -*- test-case-name: twisted.test.test_paths -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Twisted integration with operating system threads.
|
||||
"""
|
||||
|
||||
|
||||
from ._ithreads import AlreadyQuit, IWorker
|
||||
from ._memory import createMemoryWorker
|
||||
from ._pool import pool
|
||||
from ._team import Team
|
||||
from ._threadworker import LockWorker, ThreadWorker
|
||||
|
||||
__all__ = [
|
||||
"ThreadWorker",
|
||||
"LockWorker",
|
||||
"IWorker",
|
||||
"AlreadyQuit",
|
||||
"Team",
|
||||
"createMemoryWorker",
|
||||
"pool",
|
||||
]
|
||||
@@ -0,0 +1,43 @@
|
||||
# -*- test-case-name: twisted._threads.test.test_convenience -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Common functionality used within the implementation of various workers.
|
||||
"""
|
||||
|
||||
|
||||
from ._ithreads import AlreadyQuit
|
||||
|
||||
|
||||
class Quit:
|
||||
"""
|
||||
A flag representing whether a worker has been quit.
|
||||
|
||||
@ivar isSet: Whether this flag is set.
|
||||
@type isSet: L{bool}
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Create a L{Quit} un-set.
|
||||
"""
|
||||
self.isSet = False
|
||||
|
||||
def set(self) -> None:
|
||||
"""
|
||||
Set the flag if it has not been set.
|
||||
|
||||
@raise AlreadyQuit: If it has been set.
|
||||
"""
|
||||
self.check()
|
||||
self.isSet = True
|
||||
|
||||
def check(self) -> None:
|
||||
"""
|
||||
Check if the flag has been set.
|
||||
|
||||
@raise AlreadyQuit: If it has been set.
|
||||
"""
|
||||
if self.isSet:
|
||||
raise AlreadyQuit()
|
||||
@@ -0,0 +1,61 @@
|
||||
# -*- test-case-name: twisted._threads.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Interfaces related to threads.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from zope.interface import Interface
|
||||
|
||||
|
||||
class AlreadyQuit(Exception):
|
||||
"""
|
||||
This worker worker is dead and cannot execute more instructions.
|
||||
"""
|
||||
|
||||
|
||||
class IWorker(Interface):
|
||||
"""
|
||||
A worker that can perform some work concurrently.
|
||||
|
||||
All methods on this interface must be thread-safe.
|
||||
"""
|
||||
|
||||
def do(task: Callable[[], None]) -> None:
|
||||
"""
|
||||
Perform the given task.
|
||||
|
||||
As an interface, this method makes no specific claims about concurrent
|
||||
execution. An L{IWorker}'s C{do} implementation may defer execution
|
||||
for later on the same thread, immediately on a different thread, or
|
||||
some combination of the two. It is valid for a C{do} method to
|
||||
schedule C{task} in such a way that it may never be executed.
|
||||
|
||||
It is important for some implementations to provide specific properties
|
||||
with respect to where C{task} is executed, of course, and client code
|
||||
may rely on a more specific implementation of C{do} than L{IWorker}.
|
||||
|
||||
@param task: a task to call in a thread or other concurrent context.
|
||||
@type task: 0-argument callable
|
||||
|
||||
@raise AlreadyQuit: if C{quit} has been called.
|
||||
"""
|
||||
|
||||
def quit() -> None:
|
||||
"""
|
||||
Free any resources associated with this L{IWorker} and cause it to
|
||||
reject all future work.
|
||||
|
||||
@raise AlreadyQuit: if this method has already been called.
|
||||
"""
|
||||
|
||||
|
||||
class IExclusiveWorker(IWorker):
|
||||
"""
|
||||
Like L{IWorker}, but with the additional guarantee that the callables
|
||||
passed to C{do} will not be called exclusively with each other.
|
||||
"""
|
||||
@@ -0,0 +1,83 @@
|
||||
# -*- test-case-name: twisted._threads.test.test_memory -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation of an in-memory worker that defers execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, Literal
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from ._convenience import Quit
|
||||
from ._ithreads import IExclusiveWorker
|
||||
|
||||
|
||||
class NoMore(Enum):
|
||||
Work = auto()
|
||||
|
||||
|
||||
NoMoreWork = NoMore.Work
|
||||
|
||||
|
||||
@implementer(IExclusiveWorker)
|
||||
class MemoryWorker:
|
||||
"""
|
||||
An L{IWorker} that queues work for later performance.
|
||||
|
||||
@ivar _quit: a flag indicating
|
||||
@type _quit: L{Quit}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pending: Callable[[], list[Callable[[], object] | Literal[NoMore.Work]]] = list,
|
||||
) -> None:
|
||||
"""
|
||||
Create a L{MemoryWorker}.
|
||||
"""
|
||||
self._quit = Quit()
|
||||
self._pending = pending()
|
||||
|
||||
def do(self, work: Callable[[], object]) -> None:
|
||||
"""
|
||||
Queue some work for to perform later; see L{createMemoryWorker}.
|
||||
|
||||
@param work: The work to perform.
|
||||
"""
|
||||
self._quit.check()
|
||||
self._pending.append(work)
|
||||
|
||||
def quit(self) -> None:
|
||||
"""
|
||||
Quit this worker.
|
||||
"""
|
||||
self._quit.set()
|
||||
self._pending.append(NoMoreWork)
|
||||
|
||||
|
||||
def createMemoryWorker() -> tuple[MemoryWorker, Callable[[], bool]]:
|
||||
"""
|
||||
Create an L{IWorker} that does nothing but defer work, to be performed
|
||||
later.
|
||||
|
||||
@return: a worker that will enqueue work to perform later, and a callable
|
||||
that will perform one element of that work.
|
||||
"""
|
||||
|
||||
def perform() -> bool:
|
||||
if not worker._pending:
|
||||
return False
|
||||
peek = worker._pending[0]
|
||||
if peek is NoMoreWork:
|
||||
return False
|
||||
worker._pending.pop(0)
|
||||
peek()
|
||||
return True
|
||||
|
||||
worker = MemoryWorker()
|
||||
return (worker, perform)
|
||||
73
.venv/lib/python3.12/site-packages/twisted/_threads/_pool.py
Normal file
73
.venv/lib/python3.12/site-packages/twisted/_threads/_pool.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# -*- test-case-name: twisted._threads.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Top level thread pool interface, used to implement
|
||||
L{twisted.python.threadpool}.
|
||||
"""
|
||||
|
||||
|
||||
from queue import Queue
|
||||
from threading import Lock, Thread, local as LocalStorage
|
||||
from typing import Callable, Optional
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from twisted.python.log import err
|
||||
from ._ithreads import IWorker
|
||||
from ._team import Team
|
||||
from ._threadworker import LockWorker, ThreadWorker
|
||||
|
||||
|
||||
class _ThreadFactory(Protocol):
|
||||
def __call__(self, *, target: Callable[..., object]) -> Thread:
|
||||
...
|
||||
|
||||
|
||||
def pool(
|
||||
currentLimit: Callable[[], int], threadFactory: _ThreadFactory = Thread
|
||||
) -> Team:
|
||||
"""
|
||||
Construct a L{Team} that spawns threads as a thread pool, with the given
|
||||
limiting function.
|
||||
|
||||
@note: Future maintainers: while the public API for the eventual move to
|
||||
twisted.threads should look I{something} like this, and while this
|
||||
function is necessary to implement the API described by
|
||||
L{twisted.python.threadpool}, I am starting to think the idea of a hard
|
||||
upper limit on threadpool size is just bad (turning memory performance
|
||||
issues into correctness issues well before we run into memory
|
||||
pressure), and instead we should build something with reactor
|
||||
integration for slowly releasing idle threads when they're not needed
|
||||
and I{rate} limiting the creation of new threads rather than just
|
||||
hard-capping it.
|
||||
|
||||
@param currentLimit: a callable that returns the current limit on the
|
||||
number of workers that the returned L{Team} should create; if it
|
||||
already has more workers than that value, no new workers will be
|
||||
created.
|
||||
@type currentLimit: 0-argument callable returning L{int}
|
||||
|
||||
@param threadFactory: Factory that, when given a C{target} keyword argument,
|
||||
returns a L{threading.Thread} that will run that target.
|
||||
@type threadFactory: callable returning a L{threading.Thread}
|
||||
|
||||
@return: a new L{Team}.
|
||||
"""
|
||||
|
||||
def startThread(target: Callable[..., object]) -> None:
|
||||
return threadFactory(target=target).start()
|
||||
|
||||
def limitedWorkerCreator() -> Optional[IWorker]:
|
||||
stats = team.statistics()
|
||||
if stats.busyWorkerCount + stats.idleWorkerCount >= currentLimit():
|
||||
return None
|
||||
return ThreadWorker(startThread, Queue())
|
||||
|
||||
team = Team(
|
||||
coordinator=LockWorker(Lock(), LocalStorage()),
|
||||
createWorker=limitedWorkerCreator,
|
||||
logException=err,
|
||||
)
|
||||
return team
|
||||
232
.venv/lib/python3.12/site-packages/twisted/_threads/_team.py
Normal file
232
.venv/lib/python3.12/site-packages/twisted/_threads/_team.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# -*- test-case-name: twisted._threads.test.test_team -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation of a L{Team} of workers; a thread-pool that can allocate work to
|
||||
workers.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Callable, Optional, Set
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from . import IWorker
|
||||
from ._convenience import Quit
|
||||
from ._ithreads import IExclusiveWorker
|
||||
|
||||
|
||||
class Statistics:
|
||||
"""
|
||||
Statistics about a L{Team}'s current activity.
|
||||
|
||||
@ivar idleWorkerCount: The number of idle workers.
|
||||
@type idleWorkerCount: L{int}
|
||||
|
||||
@ivar busyWorkerCount: The number of busy workers.
|
||||
@type busyWorkerCount: L{int}
|
||||
|
||||
@ivar backloggedWorkCount: The number of work items passed to L{Team.do}
|
||||
which have not yet been sent to a worker to be performed because not
|
||||
enough workers are available.
|
||||
@type backloggedWorkCount: L{int}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, idleWorkerCount: int, busyWorkerCount: int, backloggedWorkCount: int
|
||||
) -> None:
|
||||
self.idleWorkerCount = idleWorkerCount
|
||||
self.busyWorkerCount = busyWorkerCount
|
||||
self.backloggedWorkCount = backloggedWorkCount
|
||||
|
||||
|
||||
@implementer(IWorker)
|
||||
class Team:
|
||||
"""
|
||||
A composite L{IWorker} implementation.
|
||||
|
||||
@ivar _quit: A L{Quit} flag indicating whether this L{Team} has been quit
|
||||
yet. This may be set by an arbitrary thread since L{Team.quit} may be
|
||||
called from anywhere.
|
||||
|
||||
@ivar _coordinator: the L{IExclusiveWorker} coordinating access to this
|
||||
L{Team}'s internal resources.
|
||||
|
||||
@ivar _createWorker: a callable that will create new workers.
|
||||
|
||||
@ivar _logException: a 0-argument callable called in an exception context
|
||||
when there is an unhandled error from a task passed to L{Team.do}
|
||||
|
||||
@ivar _idle: a L{set} of idle workers.
|
||||
|
||||
@ivar _busyCount: the number of workers currently busy.
|
||||
|
||||
@ivar _pending: a C{deque} of tasks - that is, 0-argument callables passed
|
||||
to L{Team.do} - that are outstanding.
|
||||
|
||||
@ivar _shouldQuitCoordinator: A flag indicating that the coordinator should
|
||||
be quit at the next available opportunity. Unlike L{Team._quit}, this
|
||||
flag is only set by the coordinator.
|
||||
|
||||
@ivar _toShrink: the number of workers to shrink this L{Team} by at the
|
||||
next available opportunity; set in the coordinator.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
coordinator: IExclusiveWorker,
|
||||
createWorker: Callable[[], Optional[IWorker]],
|
||||
logException: Callable[[], None],
|
||||
):
|
||||
"""
|
||||
@param coordinator: an L{IExclusiveWorker} which will coordinate access
|
||||
to resources on this L{Team}; that is to say, an
|
||||
L{IExclusiveWorker} whose C{do} method ensures that its given work
|
||||
will be executed in a mutually exclusive context, not in parallel
|
||||
with other work enqueued by C{do} (although possibly in parallel
|
||||
with the caller).
|
||||
|
||||
@param createWorker: A 0-argument callable that will create an
|
||||
L{IWorker} to perform work.
|
||||
|
||||
@param logException: A 0-argument callable called in an exception
|
||||
context when the work passed to C{do} raises an exception.
|
||||
"""
|
||||
self._quit = Quit()
|
||||
self._coordinator = coordinator
|
||||
self._createWorker = createWorker
|
||||
self._logException = logException
|
||||
|
||||
# Don't touch these except from the coordinator.
|
||||
self._idle: Set[IWorker] = set()
|
||||
self._busyCount = 0
|
||||
self._pending: "deque[Callable[..., object]]" = deque()
|
||||
self._shouldQuitCoordinator = False
|
||||
self._toShrink = 0
|
||||
|
||||
def statistics(self) -> Statistics:
|
||||
"""
|
||||
Gather information on the current status of this L{Team}.
|
||||
|
||||
@return: a L{Statistics} describing the current state of this L{Team}.
|
||||
"""
|
||||
return Statistics(len(self._idle), self._busyCount, len(self._pending))
|
||||
|
||||
def grow(self, n: int) -> None:
|
||||
"""
|
||||
Increase the the number of idle workers by C{n}.
|
||||
|
||||
@param n: The number of new idle workers to create.
|
||||
@type n: L{int}
|
||||
"""
|
||||
self._quit.check()
|
||||
|
||||
@self._coordinator.do
|
||||
def createOneWorker() -> None:
|
||||
for x in range(n):
|
||||
worker = self._createWorker()
|
||||
if worker is None:
|
||||
return
|
||||
self._recycleWorker(worker)
|
||||
|
||||
def shrink(self, n: Optional[int] = None) -> None:
|
||||
"""
|
||||
Decrease the number of idle workers by C{n}.
|
||||
|
||||
@param n: The number of idle workers to shut down, or L{None} (or
|
||||
unspecified) to shut down all workers.
|
||||
@type n: L{int} or L{None}
|
||||
"""
|
||||
self._quit.check()
|
||||
self._coordinator.do(lambda: self._quitIdlers(n))
|
||||
|
||||
def _quitIdlers(self, n: Optional[int] = None) -> None:
|
||||
"""
|
||||
The implmentation of C{shrink}, performed by the coordinator worker.
|
||||
|
||||
@param n: see L{Team.shrink}
|
||||
"""
|
||||
if n is None:
|
||||
n = len(self._idle) + self._busyCount
|
||||
for x in range(n):
|
||||
if self._idle:
|
||||
self._idle.pop().quit()
|
||||
else:
|
||||
self._toShrink += 1
|
||||
if self._shouldQuitCoordinator and self._busyCount == 0:
|
||||
self._coordinator.quit()
|
||||
|
||||
def do(self, task: Callable[[], object]) -> None:
|
||||
"""
|
||||
Perform some work in a worker created by C{createWorker}.
|
||||
|
||||
@param task: the callable to run
|
||||
"""
|
||||
self._quit.check()
|
||||
self._coordinator.do(lambda: self._coordinateThisTask(task))
|
||||
|
||||
def _coordinateThisTask(self, task: Callable[..., object]) -> None:
|
||||
"""
|
||||
Select a worker to dispatch to, either an idle one or a new one, and
|
||||
perform it.
|
||||
|
||||
This method should run on the coordinator worker.
|
||||
|
||||
@param task: the task to dispatch
|
||||
@type task: 0-argument callable
|
||||
"""
|
||||
worker = self._idle.pop() if self._idle else self._createWorker()
|
||||
if worker is None:
|
||||
# The createWorker method may return None if we're out of resources
|
||||
# to create workers.
|
||||
self._pending.append(task)
|
||||
return
|
||||
not_none_worker = worker
|
||||
self._busyCount += 1
|
||||
|
||||
@worker.do
|
||||
def doWork() -> None:
|
||||
try:
|
||||
task()
|
||||
except BaseException:
|
||||
self._logException()
|
||||
|
||||
@self._coordinator.do
|
||||
def idleAndPending() -> None:
|
||||
self._busyCount -= 1
|
||||
self._recycleWorker(not_none_worker)
|
||||
|
||||
def _recycleWorker(self, worker: IWorker) -> None:
|
||||
"""
|
||||
Called only from coordinator.
|
||||
|
||||
Recycle the given worker into the idle pool.
|
||||
|
||||
@param worker: a worker created by C{createWorker} and now idle.
|
||||
@type worker: L{IWorker}
|
||||
"""
|
||||
self._idle.add(worker)
|
||||
if self._pending:
|
||||
# Re-try the first enqueued thing.
|
||||
# (Explicitly do _not_ honor _quit.)
|
||||
self._coordinateThisTask(self._pending.popleft())
|
||||
elif self._shouldQuitCoordinator:
|
||||
self._quitIdlers()
|
||||
elif self._toShrink > 0:
|
||||
self._toShrink -= 1
|
||||
self._idle.remove(worker)
|
||||
worker.quit()
|
||||
|
||||
def quit(self) -> None:
|
||||
"""
|
||||
Stop doing work and shut down all idle workers.
|
||||
"""
|
||||
self._quit.set()
|
||||
# In case all the workers are idle when we do this.
|
||||
|
||||
@self._coordinator.do
|
||||
def startFinishing() -> None:
|
||||
self._shouldQuitCoordinator = True
|
||||
self._quitIdlers()
|
||||
@@ -0,0 +1,156 @@
|
||||
# -*- test-case-name: twisted._threads.test.test_threadworker -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation of an L{IWorker} based on native threads and queues.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Callable, Iterator, Literal, Protocol, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import threading
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from ._convenience import Quit
|
||||
from ._ithreads import IExclusiveWorker
|
||||
|
||||
|
||||
class Stop(Enum):
|
||||
Thread = auto()
|
||||
|
||||
|
||||
StopThread = Stop.Thread
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
|
||||
class SimpleQueue(Protocol[T]):
|
||||
def put(self, item: T) -> None:
|
||||
...
|
||||
|
||||
def get(self) -> T:
|
||||
...
|
||||
|
||||
|
||||
# when the sentinel value is a literal in a union, this is how iter works
|
||||
smartiter: Callable[[Callable[[], T | U], U], Iterator[T]]
|
||||
smartiter = iter # type:ignore[assignment]
|
||||
|
||||
|
||||
@implementer(IExclusiveWorker)
|
||||
class ThreadWorker:
|
||||
"""
|
||||
An L{IExclusiveWorker} implemented based on a single thread and a queue.
|
||||
|
||||
This worker ensures exclusivity (i.e. it is an L{IExclusiveWorker} and not
|
||||
an L{IWorker}) by performing all of the work passed to C{do} on the I{same}
|
||||
thread.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
startThread: Callable[[Callable[[], object]], object],
|
||||
queue: SimpleQueue[Callable[[], object] | Literal[Stop.Thread]],
|
||||
):
|
||||
"""
|
||||
Create a L{ThreadWorker} with a function to start a thread and a queue
|
||||
to use to communicate with that thread.
|
||||
|
||||
@param startThread: a callable that takes a callable to run in another
|
||||
thread.
|
||||
|
||||
@param queue: A L{Queue} to use to give tasks to the thread created by
|
||||
C{startThread}.
|
||||
"""
|
||||
self._q = queue
|
||||
self._hasQuit = Quit()
|
||||
|
||||
def work() -> None:
|
||||
for task in smartiter(queue.get, StopThread):
|
||||
task()
|
||||
|
||||
startThread(work)
|
||||
|
||||
def do(self, task: Callable[[], None]) -> None:
|
||||
"""
|
||||
Perform the given task on the thread owned by this L{ThreadWorker}.
|
||||
|
||||
@param task: the function to call on a thread.
|
||||
"""
|
||||
self._hasQuit.check()
|
||||
self._q.put(task)
|
||||
|
||||
def quit(self) -> None:
|
||||
"""
|
||||
Reject all future work and stop the thread started by C{__init__}.
|
||||
"""
|
||||
# Reject all future work. Set this _before_ enqueueing _stop, so
|
||||
# that no work is ever enqueued _after_ _stop.
|
||||
self._hasQuit.set()
|
||||
self._q.put(StopThread)
|
||||
|
||||
|
||||
class SimpleLock(Protocol):
|
||||
def acquire(self) -> bool:
|
||||
...
|
||||
|
||||
def release(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
@implementer(IExclusiveWorker)
|
||||
class LockWorker:
|
||||
"""
|
||||
An L{IWorker} implemented based on a mutual-exclusion lock.
|
||||
"""
|
||||
|
||||
def __init__(self, lock: SimpleLock, local: threading.local):
|
||||
"""
|
||||
@param lock: A mutual-exclusion lock, with C{acquire} and C{release}
|
||||
methods.
|
||||
@type lock: L{threading.Lock}
|
||||
|
||||
@param local: Local storage.
|
||||
@type local: L{threading.local}
|
||||
"""
|
||||
self._quit = Quit()
|
||||
self._lock: SimpleLock | None = lock
|
||||
self._local = local
|
||||
|
||||
def do(self, work: Callable[[], None]) -> None:
|
||||
"""
|
||||
Do the given work on this thread, with the mutex acquired. If this is
|
||||
called re-entrantly, return and wait for the outer invocation to do the
|
||||
work.
|
||||
|
||||
@param work: the work to do with the lock held.
|
||||
"""
|
||||
lock = self._lock
|
||||
local = self._local
|
||||
self._quit.check()
|
||||
working = getattr(local, "working", None)
|
||||
if working is None:
|
||||
assert lock is not None, "LockWorker used after quit()"
|
||||
working = local.working = []
|
||||
working.append(work)
|
||||
lock.acquire()
|
||||
try:
|
||||
while working:
|
||||
working.pop(0)()
|
||||
finally:
|
||||
lock.release()
|
||||
local.working = None
|
||||
else:
|
||||
working.append(work)
|
||||
|
||||
def quit(self) -> None:
|
||||
"""
|
||||
Quit this L{LockWorker}.
|
||||
"""
|
||||
self._quit.set()
|
||||
self._lock = None
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- test-case-name: twisted._threads.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted._threads}.
|
||||
"""
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for convenience functionality in L{twisted._threads._convenience}.
|
||||
"""
|
||||
|
||||
|
||||
from twisted.trial.unittest import SynchronousTestCase
|
||||
from .._convenience import Quit
|
||||
from .._ithreads import AlreadyQuit
|
||||
|
||||
|
||||
class QuitTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{Quit}
|
||||
"""
|
||||
|
||||
def test_isInitiallySet(self) -> None:
|
||||
"""
|
||||
L{Quit.isSet} starts as L{False}.
|
||||
"""
|
||||
quit = Quit()
|
||||
self.assertEqual(quit.isSet, False)
|
||||
|
||||
def test_setSetsSet(self) -> None:
|
||||
"""
|
||||
L{Quit.set} sets L{Quit.isSet} to L{True}.
|
||||
"""
|
||||
quit = Quit()
|
||||
quit.set()
|
||||
self.assertEqual(quit.isSet, True)
|
||||
|
||||
def test_checkDoesNothing(self) -> None:
|
||||
"""
|
||||
L{Quit.check} initially does nothing and returns L{None}.
|
||||
"""
|
||||
quit = Quit()
|
||||
checked = quit.check() # type:ignore[func-returns-value]
|
||||
self.assertIs(checked, None)
|
||||
|
||||
def test_checkAfterSetRaises(self) -> None:
|
||||
"""
|
||||
L{Quit.check} raises L{AlreadyQuit} if L{Quit.set} has been called.
|
||||
"""
|
||||
quit = Quit()
|
||||
quit.set()
|
||||
self.assertRaises(AlreadyQuit, quit.check)
|
||||
|
||||
def test_setTwiceRaises(self) -> None:
|
||||
"""
|
||||
L{Quit.set} raises L{AlreadyQuit} if it has been called previously.
|
||||
"""
|
||||
quit = Quit()
|
||||
quit.set()
|
||||
self.assertRaises(AlreadyQuit, quit.set)
|
||||
@@ -0,0 +1,64 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted._threads._memory}.
|
||||
"""
|
||||
|
||||
from zope.interface.verify import verifyObject
|
||||
|
||||
from twisted.trial.unittest import SynchronousTestCase
|
||||
from .. import AlreadyQuit, IWorker, createMemoryWorker
|
||||
|
||||
|
||||
class MemoryWorkerTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{MemoryWorker}.
|
||||
"""
|
||||
|
||||
def test_createWorkerAndPerform(self) -> None:
|
||||
"""
|
||||
L{createMemoryWorker} creates an L{IWorker} and a callable that can
|
||||
perform work on it. The performer returns C{True} if it accomplished
|
||||
useful work.
|
||||
"""
|
||||
worker, performer = createMemoryWorker()
|
||||
verifyObject(IWorker, worker)
|
||||
done = []
|
||||
worker.do(lambda: done.append(3))
|
||||
worker.do(lambda: done.append(4))
|
||||
self.assertEqual(done, [])
|
||||
self.assertEqual(performer(), True)
|
||||
self.assertEqual(done, [3])
|
||||
self.assertEqual(performer(), True)
|
||||
self.assertEqual(done, [3, 4])
|
||||
|
||||
def test_quitQuits(self) -> None:
|
||||
"""
|
||||
Calling C{quit} on the worker returned by L{createMemoryWorker} causes
|
||||
its C{do} and C{quit} methods to raise L{AlreadyQuit}; its C{perform}
|
||||
callable will start raising L{AlreadyQuit} when the work already
|
||||
provided to C{do} has been exhausted.
|
||||
"""
|
||||
worker, performer = createMemoryWorker()
|
||||
done = []
|
||||
|
||||
def moreWork() -> None:
|
||||
done.append(7)
|
||||
|
||||
worker.do(moreWork)
|
||||
worker.quit()
|
||||
self.assertRaises(AlreadyQuit, worker.do, moreWork)
|
||||
self.assertRaises(AlreadyQuit, worker.quit)
|
||||
performer()
|
||||
self.assertEqual(done, [7])
|
||||
self.assertEqual(performer(), False)
|
||||
|
||||
def test_performWhenNothingToDoYet(self) -> None:
|
||||
"""
|
||||
The C{perform} callable returned by L{createMemoryWorker} will return
|
||||
no result when there's no work to do yet. Since there is no work to
|
||||
do, the performer returns C{False}.
|
||||
"""
|
||||
worker, performer = createMemoryWorker()
|
||||
self.assertEqual(performer(), False)
|
||||
@@ -0,0 +1,286 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted._threads._team}.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from twisted.python.components import proxyForInterface
|
||||
from twisted.python.context import call, get
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.trial.unittest import SynchronousTestCase
|
||||
from .. import AlreadyQuit, IWorker, Team, createMemoryWorker
|
||||
|
||||
|
||||
class ContextualWorker(proxyForInterface(IWorker, "_realWorker")): # type: ignore[misc]
|
||||
"""
|
||||
A worker implementation that supplies a context.
|
||||
"""
|
||||
|
||||
def __init__(self, realWorker: IWorker, **ctx: object) -> None:
|
||||
"""
|
||||
Create with a real worker and a context.
|
||||
"""
|
||||
self._realWorker = realWorker
|
||||
self._context = ctx
|
||||
|
||||
def do(self, work: Callable[[], object]) -> None:
|
||||
"""
|
||||
Perform the given work with the context given to __init__.
|
||||
|
||||
@param work: the work to pass on to the real worker.
|
||||
"""
|
||||
super().do(lambda: call(self._context, work))
|
||||
|
||||
|
||||
class TeamTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{Team}
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Set up a L{Team} with inspectable, synchronous workers that can be
|
||||
single-stepped.
|
||||
"""
|
||||
coordinator, self.coordinateOnce = createMemoryWorker()
|
||||
self.coordinator = ContextualWorker(coordinator, worker="coordinator")
|
||||
self.workerPerformers: list[Callable[[], object]] = []
|
||||
self.allWorkersEver: list[ContextualWorker] = []
|
||||
self.allUnquitWorkers: list[ContextualWorker] = []
|
||||
self.activePerformers: list[Callable[[], object]] = []
|
||||
self.noMoreWorkers = lambda: False
|
||||
|
||||
def createWorker() -> ContextualWorker | None:
|
||||
if self.noMoreWorkers():
|
||||
return None
|
||||
worker, performer = createMemoryWorker()
|
||||
self.workerPerformers.append(performer)
|
||||
self.activePerformers.append(performer)
|
||||
cw = ContextualWorker(worker, worker=len(self.workerPerformers))
|
||||
self.allWorkersEver.append(cw)
|
||||
self.allUnquitWorkers.append(cw)
|
||||
realQuit = cw.quit
|
||||
|
||||
def quitAndRemove() -> None:
|
||||
realQuit()
|
||||
self.allUnquitWorkers.remove(cw)
|
||||
self.activePerformers.remove(performer)
|
||||
|
||||
cw.quit = quitAndRemove
|
||||
return cw
|
||||
|
||||
self.failures: list[Failure] = []
|
||||
|
||||
def logException() -> None:
|
||||
self.failures.append(Failure())
|
||||
|
||||
self.team = Team(coordinator, createWorker, logException)
|
||||
|
||||
def coordinate(self) -> bool:
|
||||
"""
|
||||
Perform all work currently scheduled in the coordinator.
|
||||
|
||||
@return: whether any coordination work was performed; if the
|
||||
coordinator was idle when this was called, return L{False}
|
||||
(otherwise L{True}).
|
||||
"""
|
||||
did = False
|
||||
while self.coordinateOnce():
|
||||
did = True
|
||||
return did
|
||||
|
||||
def performAllOutstandingWork(self) -> None:
|
||||
"""
|
||||
Perform all work on the coordinator and worker performers that needs to
|
||||
be done.
|
||||
"""
|
||||
continuing = True
|
||||
while continuing:
|
||||
continuing = self.coordinate()
|
||||
for performer in self.workerPerformers:
|
||||
if performer in self.activePerformers:
|
||||
performer()
|
||||
continuing = continuing or self.coordinate()
|
||||
|
||||
def test_doDoesWorkInWorker(self) -> None:
|
||||
"""
|
||||
L{Team.do} does the work in a worker created by the createWorker
|
||||
callable.
|
||||
"""
|
||||
who = None
|
||||
|
||||
def something() -> None:
|
||||
nonlocal who
|
||||
who = get("worker")
|
||||
|
||||
self.team.do(something)
|
||||
self.coordinate()
|
||||
self.assertEqual(self.team.statistics().busyWorkerCount, 1)
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(who, 1)
|
||||
self.assertEqual(self.team.statistics().busyWorkerCount, 0)
|
||||
|
||||
def test_initialStatistics(self) -> None:
|
||||
"""
|
||||
L{Team.statistics} returns an object with idleWorkerCount,
|
||||
busyWorkerCount, and backloggedWorkCount integer attributes.
|
||||
"""
|
||||
stats = self.team.statistics()
|
||||
self.assertEqual(stats.idleWorkerCount, 0)
|
||||
self.assertEqual(stats.busyWorkerCount, 0)
|
||||
self.assertEqual(stats.backloggedWorkCount, 0)
|
||||
|
||||
def test_growCreatesIdleWorkers(self) -> None:
|
||||
"""
|
||||
L{Team.grow} increases the number of available idle workers.
|
||||
"""
|
||||
self.team.grow(5)
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(len(self.workerPerformers), 5)
|
||||
|
||||
def test_growCreateLimit(self) -> None:
|
||||
"""
|
||||
L{Team.grow} increases the number of available idle workers until the
|
||||
C{createWorker} callable starts returning None.
|
||||
"""
|
||||
self.noMoreWorkers = lambda: len(self.allWorkersEver) >= 3
|
||||
self.team.grow(5)
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(len(self.allWorkersEver), 3)
|
||||
self.assertEqual(self.team.statistics().idleWorkerCount, 3)
|
||||
|
||||
def test_shrinkQuitsWorkers(self) -> None:
|
||||
"""
|
||||
L{Team.shrink} will quit the given number of workers.
|
||||
"""
|
||||
self.team.grow(5)
|
||||
self.performAllOutstandingWork()
|
||||
self.team.shrink(3)
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(len(self.allUnquitWorkers), 2)
|
||||
|
||||
def test_shrinkToZero(self) -> None:
|
||||
"""
|
||||
L{Team.shrink} with no arguments will stop all outstanding workers.
|
||||
"""
|
||||
self.team.grow(10)
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(len(self.allUnquitWorkers), 10)
|
||||
self.team.shrink()
|
||||
self.assertEqual(len(self.allUnquitWorkers), 10)
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(len(self.allUnquitWorkers), 0)
|
||||
|
||||
def test_moreWorkWhenNoWorkersAvailable(self) -> None:
|
||||
"""
|
||||
When no additional workers are available, the given work is backlogged,
|
||||
and then performed later when the work was.
|
||||
"""
|
||||
self.team.grow(3)
|
||||
self.coordinate()
|
||||
times = 0
|
||||
|
||||
def something() -> None:
|
||||
nonlocal times
|
||||
times += 1
|
||||
|
||||
self.assertEqual(self.team.statistics().idleWorkerCount, 3)
|
||||
for i in range(3):
|
||||
self.team.do(something)
|
||||
# Make progress on the coordinator but do _not_ actually complete the
|
||||
# work, yet.
|
||||
self.coordinate()
|
||||
self.assertEqual(self.team.statistics().idleWorkerCount, 0)
|
||||
self.noMoreWorkers = lambda: True
|
||||
self.team.do(something)
|
||||
self.coordinate()
|
||||
self.assertEqual(self.team.statistics().idleWorkerCount, 0)
|
||||
self.assertEqual(self.team.statistics().backloggedWorkCount, 1)
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(self.team.statistics().backloggedWorkCount, 0)
|
||||
self.assertEqual(times, 4)
|
||||
|
||||
def test_exceptionInTask(self) -> None:
|
||||
"""
|
||||
When an exception is raised in a task passed to L{Team.do}, the
|
||||
C{logException} given to the L{Team} at construction is invoked in the
|
||||
exception context.
|
||||
"""
|
||||
self.team.do(lambda: 1 / 0)
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(len(self.failures), 1)
|
||||
self.assertEqual(self.failures[0].type, ZeroDivisionError)
|
||||
|
||||
def test_quit(self) -> None:
|
||||
"""
|
||||
L{Team.quit} causes future invocations of L{Team.do} and L{Team.quit}
|
||||
to raise L{AlreadyQuit}.
|
||||
"""
|
||||
self.team.quit()
|
||||
self.assertRaises(AlreadyQuit, self.team.quit)
|
||||
self.assertRaises(AlreadyQuit, self.team.do, list)
|
||||
|
||||
def test_quitQuits(self) -> None:
|
||||
"""
|
||||
L{Team.quit} causes all idle workers, as well as the coordinator
|
||||
worker, to quit.
|
||||
"""
|
||||
for x in range(10):
|
||||
self.team.do(list)
|
||||
self.performAllOutstandingWork()
|
||||
self.team.quit()
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(len(self.allUnquitWorkers), 0)
|
||||
self.assertRaises(AlreadyQuit, self.coordinator.quit)
|
||||
|
||||
def test_quitQuitsLaterWhenBusy(self) -> None:
|
||||
"""
|
||||
L{Team.quit} causes all busy workers to be quit once they've finished
|
||||
the work they've been given.
|
||||
"""
|
||||
self.team.grow(10)
|
||||
for x in range(5):
|
||||
self.team.do(list)
|
||||
self.coordinate()
|
||||
self.team.quit()
|
||||
self.coordinate()
|
||||
self.assertEqual(len(self.allUnquitWorkers), 5)
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(len(self.allUnquitWorkers), 0)
|
||||
self.assertRaises(AlreadyQuit, self.coordinator.quit)
|
||||
|
||||
def test_quitConcurrentWithWorkHappening(self) -> None:
|
||||
"""
|
||||
If work happens after L{Team.quit} sets its C{Quit} flag, but before
|
||||
any other work takes place, the L{Team} should still exit gracefully.
|
||||
"""
|
||||
self.team.do(list)
|
||||
originalSet = self.team._quit.set
|
||||
|
||||
def performWorkConcurrently() -> None:
|
||||
originalSet()
|
||||
self.performAllOutstandingWork()
|
||||
|
||||
self.team._quit.set = performWorkConcurrently # type:ignore[method-assign]
|
||||
self.team.quit()
|
||||
self.assertRaises(AlreadyQuit, self.team.quit)
|
||||
self.assertRaises(AlreadyQuit, self.team.do, list)
|
||||
|
||||
def test_shrinkWhenBusy(self) -> None:
|
||||
"""
|
||||
L{Team.shrink} will wait for busy workers to finish being busy and then
|
||||
quit them.
|
||||
"""
|
||||
for x in range(10):
|
||||
self.team.do(list)
|
||||
self.coordinate()
|
||||
self.assertEqual(len(self.allUnquitWorkers), 10)
|
||||
# There should be 10 busy workers at this point.
|
||||
self.team.shrink(7)
|
||||
self.performAllOutstandingWork()
|
||||
self.assertEqual(len(self.allUnquitWorkers), 3)
|
||||
@@ -0,0 +1,314 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted._threads._threadworker}.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import weakref
|
||||
from threading import ThreadError, local
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
from twisted.trial.unittest import SynchronousTestCase
|
||||
from .. import AlreadyQuit, LockWorker, ThreadWorker
|
||||
from .._threadworker import SimpleLock
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class FakeQueueEmpty(Exception):
|
||||
"""
|
||||
L{FakeQueue}'s C{get} has exhausted the queue.
|
||||
"""
|
||||
|
||||
|
||||
class WouldDeadlock(Exception):
|
||||
"""
|
||||
If this were a real lock, you'd be deadlocked because the lock would be
|
||||
double-acquired.
|
||||
"""
|
||||
|
||||
|
||||
class FakeThread:
|
||||
"""
|
||||
A fake L{threading.Thread}.
|
||||
|
||||
@ivar target: A target function to run.
|
||||
|
||||
@ivar started: Has this thread been started?
|
||||
@type started: L{bool}
|
||||
"""
|
||||
|
||||
def __init__(self, target: Callable[[], object]) -> None:
|
||||
"""
|
||||
Create a L{FakeThread} with a target.
|
||||
"""
|
||||
self.target = target
|
||||
self.started = False
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Set the "started" flag.
|
||||
"""
|
||||
self.started = True
|
||||
|
||||
|
||||
class FakeQueue(Generic[T]):
|
||||
"""
|
||||
A fake L{Queue} implementing C{put} and C{get}.
|
||||
|
||||
@ivar items: A lit of items placed by C{put} but not yet retrieved by
|
||||
C{get}.
|
||||
@type items: L{list}
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Create a L{FakeQueue}.
|
||||
"""
|
||||
self.items: list[T] = []
|
||||
|
||||
def put(self, item: T) -> None:
|
||||
"""
|
||||
Put an item into the queue for later retrieval by L{FakeQueue.get}.
|
||||
|
||||
@param item: any object
|
||||
"""
|
||||
self.items.append(item)
|
||||
|
||||
def get(self) -> T:
|
||||
"""
|
||||
Get an item.
|
||||
|
||||
@return: an item previously put by C{put}.
|
||||
"""
|
||||
if not self.items:
|
||||
raise FakeQueueEmpty()
|
||||
return self.items.pop(0)
|
||||
|
||||
|
||||
class FakeLock:
|
||||
"""
|
||||
A stand-in for L{threading.Lock}.
|
||||
|
||||
@ivar acquired: Whether this lock is presently acquired.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Create a lock in the un-acquired state.
|
||||
"""
|
||||
self.acquired = False
|
||||
|
||||
def acquire(self) -> bool:
|
||||
"""
|
||||
Acquire the lock. Raise an exception if the lock is already acquired.
|
||||
"""
|
||||
if self.acquired:
|
||||
raise WouldDeadlock()
|
||||
self.acquired = True
|
||||
return True
|
||||
|
||||
def release(self) -> None:
|
||||
"""
|
||||
Release the lock. Raise an exception if the lock is not presently
|
||||
acquired.
|
||||
"""
|
||||
if not self.acquired:
|
||||
raise ThreadError()
|
||||
self.acquired = False
|
||||
|
||||
|
||||
class ThreadWorkerTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{ThreadWorker}.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Create a worker with fake threads.
|
||||
"""
|
||||
self.fakeThreads: list[FakeThread] = []
|
||||
|
||||
def startThread(target: Callable[[], object]) -> FakeThread:
|
||||
newThread = FakeThread(target=target)
|
||||
newThread.start()
|
||||
self.fakeThreads.append(newThread)
|
||||
return newThread
|
||||
|
||||
self.worker = ThreadWorker(startThread, FakeQueue())
|
||||
|
||||
def test_startsThreadAndPerformsWork(self) -> None:
|
||||
"""
|
||||
L{ThreadWorker} calls its C{createThread} callable to create a thread,
|
||||
its C{createQueue} callable to create a queue, and then the thread's
|
||||
target pulls work from that queue.
|
||||
"""
|
||||
self.assertEqual(len(self.fakeThreads), 1)
|
||||
self.assertEqual(self.fakeThreads[0].started, True)
|
||||
|
||||
done = False
|
||||
|
||||
def doIt() -> None:
|
||||
nonlocal done
|
||||
done = True
|
||||
|
||||
self.worker.do(doIt)
|
||||
self.assertEqual(done, False)
|
||||
self.assertRaises(FakeQueueEmpty, self.fakeThreads[0].target)
|
||||
self.assertEqual(done, True)
|
||||
|
||||
def test_quitPreventsFutureCalls(self) -> None:
|
||||
"""
|
||||
L{ThreadWorker.quit} causes future calls to L{ThreadWorker.do} and
|
||||
L{ThreadWorker.quit} to raise L{AlreadyQuit}.
|
||||
"""
|
||||
self.worker.quit()
|
||||
self.assertRaises(AlreadyQuit, self.worker.quit)
|
||||
self.assertRaises(AlreadyQuit, self.worker.do, list)
|
||||
|
||||
|
||||
class LockWorkerTests(SynchronousTestCase):
|
||||
"""
|
||||
Tests for L{LockWorker}.
|
||||
"""
|
||||
|
||||
def test_fakeDeadlock(self) -> None:
|
||||
"""
|
||||
The L{FakeLock} test fixture will alert us if there's a potential
|
||||
deadlock.
|
||||
"""
|
||||
lock = FakeLock()
|
||||
lock.acquire()
|
||||
self.assertRaises(WouldDeadlock, lock.acquire)
|
||||
|
||||
def test_fakeDoubleRelease(self) -> None:
|
||||
"""
|
||||
The L{FakeLock} test fixture will alert us if there's a potential
|
||||
double-release.
|
||||
"""
|
||||
lock = FakeLock()
|
||||
self.assertRaises(ThreadError, lock.release)
|
||||
lock.acquire()
|
||||
noResult = lock.release() # type:ignore[func-returns-value]
|
||||
self.assertIs(None, noResult)
|
||||
self.assertRaises(ThreadError, lock.release)
|
||||
|
||||
def test_doExecutesImmediatelyWithLock(self) -> None:
|
||||
"""
|
||||
L{LockWorker.do} immediately performs the work it's given, while the
|
||||
lock is acquired.
|
||||
"""
|
||||
storage = local()
|
||||
lock = FakeLock()
|
||||
worker = LockWorker(lock, storage)
|
||||
done = False
|
||||
acquired = False
|
||||
|
||||
def work() -> None:
|
||||
nonlocal done, acquired
|
||||
done = True
|
||||
acquired = lock.acquired
|
||||
|
||||
worker.do(work)
|
||||
self.assertEqual(done, True)
|
||||
self.assertEqual(acquired, True)
|
||||
self.assertEqual(lock.acquired, False)
|
||||
|
||||
def test_doUnwindsReentrancy(self) -> None:
|
||||
"""
|
||||
If L{LockWorker.do} is called recursively, it postpones the inner call
|
||||
until the outer one is complete.
|
||||
"""
|
||||
lock = FakeLock()
|
||||
worker = LockWorker(lock, local())
|
||||
levels = []
|
||||
acquired = []
|
||||
level = 0
|
||||
|
||||
def work() -> None:
|
||||
nonlocal level
|
||||
level += 1
|
||||
levels.append(level)
|
||||
acquired.append(lock.acquired)
|
||||
if len(levels) < 2:
|
||||
worker.do(work)
|
||||
level -= 1
|
||||
|
||||
worker.do(work)
|
||||
self.assertEqual(levels, [1, 1])
|
||||
self.assertEqual(acquired, [True, True])
|
||||
|
||||
def test_quit(self) -> None:
|
||||
"""
|
||||
L{LockWorker.quit} frees the resources associated with its lock and
|
||||
causes further calls to C{do} and C{quit} to fail.
|
||||
"""
|
||||
lock = FakeLock()
|
||||
ref = weakref.ref(lock)
|
||||
worker = LockWorker(lock, local())
|
||||
del lock
|
||||
self.assertIsNot(ref(), None)
|
||||
worker.quit()
|
||||
gc.collect()
|
||||
self.assertIs(ref(), None)
|
||||
self.assertRaises(AlreadyQuit, worker.quit)
|
||||
self.assertRaises(AlreadyQuit, worker.do, list)
|
||||
|
||||
def test_quitWhileWorking(self) -> None:
|
||||
"""
|
||||
If L{LockWorker.quit} is invoked during a call to L{LockWorker.do}, all
|
||||
recursive work scheduled with L{LockWorker.do} will be completed and
|
||||
the lock will be released.
|
||||
"""
|
||||
lock = FakeLock()
|
||||
ref = weakref.ref(lock)
|
||||
worker = LockWorker(lock, local())
|
||||
|
||||
phase1complete = False
|
||||
phase2complete = False
|
||||
phase2acquired = None
|
||||
|
||||
def phase1() -> None:
|
||||
nonlocal phase1complete
|
||||
worker.do(phase2)
|
||||
worker.quit()
|
||||
self.assertRaises(AlreadyQuit, worker.do, list)
|
||||
phase1complete = True
|
||||
|
||||
def phase2() -> None:
|
||||
nonlocal phase2complete, phase2acquired, lock
|
||||
phase2complete = True
|
||||
phase2acquired = lock.acquired
|
||||
|
||||
worker.do(phase1)
|
||||
self.assertEqual(phase1complete, True)
|
||||
self.assertEqual(phase2complete, True)
|
||||
self.assertEqual(lock.acquired, False)
|
||||
del lock
|
||||
gc.collect()
|
||||
self.assertIs(ref(), None)
|
||||
|
||||
def test_quitWhileGettingLock(self) -> None:
|
||||
"""
|
||||
If L{LockWorker.do} is called concurrently with L{LockWorker.quit}, and
|
||||
C{quit} wins the race before C{do} gets the lock attribute, then
|
||||
L{AlreadyQuit} will be raised.
|
||||
"""
|
||||
|
||||
class RacyLockWorker(LockWorker):
|
||||
@property
|
||||
def _lock(self) -> SimpleLock | None:
|
||||
self.quit()
|
||||
it: SimpleLock = self.__dict__["_lock"]
|
||||
return it
|
||||
|
||||
@_lock.setter
|
||||
def _lock(self, value: SimpleLock | None) -> None:
|
||||
self.__dict__["_lock"] = value
|
||||
|
||||
worker = RacyLockWorker(FakeLock(), local())
|
||||
self.assertRaises(AlreadyQuit, worker.do, list)
|
||||
11
.venv/lib/python3.12/site-packages/twisted/_version.py
Normal file
11
.venv/lib/python3.12/site-packages/twisted/_version.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Provides Twisted version information.
|
||||
"""
|
||||
|
||||
# This file is auto-generated! Do not edit!
|
||||
# Use `python -m incremental.update Twisted` to change this file.
|
||||
|
||||
from incremental import Version
|
||||
|
||||
__version__ = Version("Twisted", 24, 10, 0)
|
||||
__all__ = ["__version__"]
|
||||
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Configuration objects for Twisted Applications.
|
||||
"""
|
||||
@@ -0,0 +1,596 @@
|
||||
# -*- test-case-name: twisted.application.test.test_internet,twisted.test.test_application,twisted.test.test_cooperator -*-
|
||||
|
||||
"""
|
||||
Implementation of L{twisted.application.internet.ClientService}, particularly
|
||||
its U{automat <https://automat.readthedocs.org/>} state machine.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from random import random as _goodEnoughRandom
|
||||
from typing import Callable, Optional, Protocol as TypingProtocol, TypeVar, Union
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from automat import TypeMachineBuilder, pep614
|
||||
|
||||
from twisted.application.service import Service
|
||||
from twisted.internet.defer import (
|
||||
CancelledError,
|
||||
Deferred,
|
||||
fail,
|
||||
maybeDeferred,
|
||||
succeed,
|
||||
)
|
||||
from twisted.internet.interfaces import (
|
||||
IAddress,
|
||||
IDelayedCall,
|
||||
IProtocol,
|
||||
IProtocolFactory,
|
||||
IReactorTime,
|
||||
IStreamClientEndpoint,
|
||||
ITransport,
|
||||
)
|
||||
from twisted.logger import Logger
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _maybeGlobalReactor(maybeReactor: Optional[T]) -> T:
|
||||
"""
|
||||
@return: the argument, or the global reactor if the argument is L{None}.
|
||||
"""
|
||||
if maybeReactor is None:
|
||||
from twisted.internet import reactor
|
||||
|
||||
return reactor # type:ignore[return-value]
|
||||
else:
|
||||
return maybeReactor
|
||||
|
||||
|
||||
class _Client(TypingProtocol):
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start this L{ClientService}, initiating the connection retry loop.
|
||||
"""
|
||||
|
||||
def stop(self) -> Deferred[None]:
|
||||
"""
|
||||
Stop trying to connect and disconnect any current connection.
|
||||
|
||||
@return: a L{Deferred} that fires when all outstanding connections are
|
||||
closed and all in-progress connection attempts halted.
|
||||
"""
|
||||
|
||||
def _connectionMade(self, protocol: _ReconnectingProtocolProxy) -> None:
|
||||
"""
|
||||
A connection has been made.
|
||||
|
||||
@param protocol: The protocol of the connection.
|
||||
"""
|
||||
|
||||
def _connectionFailed(self, failure: Failure) -> None:
|
||||
"""
|
||||
Deliver connection failures to any L{ClientService.whenConnected}
|
||||
L{Deferred}s that have met their failAfterFailures threshold.
|
||||
|
||||
@param failure: the Failure to fire the L{Deferred}s with.
|
||||
"""
|
||||
|
||||
def _reconnect(self, failure: Optional[Failure] = None) -> None:
|
||||
"""
|
||||
The wait between connection attempts is done.
|
||||
"""
|
||||
|
||||
def _clientDisconnected(self, failure: Optional[Failure] = None) -> None:
|
||||
"""
|
||||
The current connection has been disconnected.
|
||||
"""
|
||||
|
||||
def whenConnected(
|
||||
self, /, failAfterFailures: Optional[int] = None
|
||||
) -> Deferred[IProtocol]:
|
||||
"""
|
||||
Retrieve the currently-connected L{Protocol}, or the next one to
|
||||
connect.
|
||||
|
||||
@param failAfterFailures: number of connection failures after which the
|
||||
Deferred will deliver a Failure (None means the Deferred will only
|
||||
fail if/when the service is stopped). Set this to 1 to make the
|
||||
very first connection failure signal an error. Use 2 to allow one
|
||||
failure but signal an error if the subsequent retry then fails.
|
||||
|
||||
@return: a Deferred that fires with a protocol produced by the factory
|
||||
passed to C{__init__}. It may:
|
||||
|
||||
- fire with L{IProtocol}
|
||||
|
||||
- fail with L{CancelledError} when the service is stopped
|
||||
|
||||
- fail with e.g.
|
||||
L{DNSLookupError<twisted.internet.error.DNSLookupError>} or
|
||||
L{ConnectionRefusedError<twisted.internet.error.ConnectionRefusedError>}
|
||||
when the number of consecutive failed connection attempts
|
||||
equals the value of "failAfterFailures"
|
||||
"""
|
||||
|
||||
|
||||
@implementer(IProtocol)
|
||||
class _ReconnectingProtocolProxy:
|
||||
"""
|
||||
A proxy for a Protocol to provide connectionLost notification to a client
|
||||
connection service, in support of reconnecting when connections are lost.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, protocol: IProtocol, lostNotification: Callable[[Failure], None]
|
||||
) -> None:
|
||||
"""
|
||||
Create a L{_ReconnectingProtocolProxy}.
|
||||
|
||||
@param protocol: the application-provided L{interfaces.IProtocol}
|
||||
provider.
|
||||
@type protocol: provider of L{interfaces.IProtocol} which may
|
||||
additionally provide L{interfaces.IHalfCloseableProtocol} and
|
||||
L{interfaces.IFileDescriptorReceiver}.
|
||||
|
||||
@param lostNotification: a 1-argument callable to invoke with the
|
||||
C{reason} when the connection is lost.
|
||||
"""
|
||||
self._protocol = protocol
|
||||
self._lostNotification = lostNotification
|
||||
|
||||
def makeConnection(self, transport: ITransport) -> None:
|
||||
self._transport = transport
|
||||
self._protocol.makeConnection(transport)
|
||||
|
||||
def connectionLost(self, reason: Failure) -> None:
|
||||
"""
|
||||
The connection was lost. Relay this information.
|
||||
|
||||
@param reason: The reason the connection was lost.
|
||||
|
||||
@return: the underlying protocol's result
|
||||
"""
|
||||
try:
|
||||
return self._protocol.connectionLost(reason)
|
||||
finally:
|
||||
self._lostNotification(reason)
|
||||
|
||||
def __getattr__(self, item: str) -> object:
|
||||
return getattr(self._protocol, item)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} wrapping {self._protocol!r}>"
|
||||
|
||||
|
||||
@implementer(IProtocolFactory)
|
||||
class _DisconnectFactory:
|
||||
"""
|
||||
A L{_DisconnectFactory} is a proxy for L{IProtocolFactory} that catches
|
||||
C{connectionLost} notifications and relays them.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
protocolFactory: IProtocolFactory,
|
||||
protocolDisconnected: Callable[[Failure], None],
|
||||
) -> None:
|
||||
self._protocolFactory = protocolFactory
|
||||
self._protocolDisconnected = protocolDisconnected
|
||||
|
||||
def buildProtocol(self, addr: IAddress) -> Optional[IProtocol]:
|
||||
"""
|
||||
Create a L{_ReconnectingProtocolProxy} with the disconnect-notification
|
||||
callback we were called with.
|
||||
|
||||
@param addr: The address the connection is coming from.
|
||||
|
||||
@return: a L{_ReconnectingProtocolProxy} for a protocol produced by
|
||||
C{self._protocolFactory}
|
||||
"""
|
||||
built = self._protocolFactory.buildProtocol(addr)
|
||||
if built is None:
|
||||
return None
|
||||
return _ReconnectingProtocolProxy(built, self._protocolDisconnected)
|
||||
|
||||
def __getattr__(self, item: str) -> object:
|
||||
return getattr(self._protocolFactory, item)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<{} wrapping {!r}>".format(
|
||||
self.__class__.__name__, self._protocolFactory
|
||||
)
|
||||
|
||||
|
||||
def _deinterface(o: object) -> None:
|
||||
"""
|
||||
Remove the special runtime attributes set by L{implementer} so that a class
|
||||
can proxy through those attributes with C{__getattr__} and thereby forward
|
||||
optionally-provided interfaces by the delegated class.
|
||||
"""
|
||||
for zopeSpecial in ["__providedBy__", "__provides__", "__implemented__"]:
|
||||
delattr(o, zopeSpecial)
|
||||
|
||||
|
||||
_deinterface(_DisconnectFactory)
|
||||
_deinterface(_ReconnectingProtocolProxy)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Core:
|
||||
"""
|
||||
Shared core for ClientService state machine.
|
||||
"""
|
||||
|
||||
# required parameters
|
||||
endpoint: IStreamClientEndpoint
|
||||
factory: IProtocolFactory
|
||||
timeoutForAttempt: Callable[[int], float]
|
||||
clock: IReactorTime
|
||||
prepareConnection: Optional[Callable[[IProtocol], object]]
|
||||
|
||||
# internal state
|
||||
stopWaiters: list[Deferred[None]] = field(default_factory=list)
|
||||
awaitingConnected: list[tuple[Deferred[IProtocol], Optional[int]]] = field(
|
||||
default_factory=list
|
||||
)
|
||||
failedAttempts: int = 0
|
||||
log: Logger = Logger()
|
||||
|
||||
def waitForStop(self) -> Deferred[None]:
|
||||
self.stopWaiters.append(Deferred())
|
||||
return self.stopWaiters[-1]
|
||||
|
||||
def unawait(self, value: Union[IProtocol, Failure]) -> None:
|
||||
self.awaitingConnected, waiting = [], self.awaitingConnected
|
||||
for w, remaining in waiting:
|
||||
w.callback(value)
|
||||
|
||||
def cancelConnectWaiters(self) -> None:
|
||||
self.unawait(Failure(CancelledError()))
|
||||
|
||||
def finishStopping(self) -> None:
|
||||
self.stopWaiters, waiting = [], self.stopWaiters
|
||||
for w in waiting:
|
||||
w.callback(None)
|
||||
|
||||
|
||||
def makeMachine() -> Callable[[_Core], _Client]:
|
||||
machine = TypeMachineBuilder(_Client, _Core)
|
||||
|
||||
def waitForRetry(
|
||||
c: _Client, s: _Core, failure: Optional[Failure] = None
|
||||
) -> IDelayedCall:
|
||||
s.failedAttempts += 1
|
||||
delay = s.timeoutForAttempt(s.failedAttempts)
|
||||
s.log.info(
|
||||
"Scheduling retry {attempt} to connect {endpoint} in {delay} seconds.",
|
||||
attempt=s.failedAttempts,
|
||||
endpoint=s.endpoint,
|
||||
delay=delay,
|
||||
)
|
||||
return s.clock.callLater(delay, c._reconnect)
|
||||
|
||||
def rememberConnection(
|
||||
c: _Client, s: _Core, protocol: _ReconnectingProtocolProxy
|
||||
) -> _ReconnectingProtocolProxy:
|
||||
s.failedAttempts = 0
|
||||
s.unawait(protocol._protocol)
|
||||
return protocol
|
||||
|
||||
def attemptConnection(
|
||||
c: _Client, s: _Core, failure: Optional[Failure] = None
|
||||
) -> Deferred[_ReconnectingProtocolProxy]:
|
||||
factoryProxy = _DisconnectFactory(s.factory, c._clientDisconnected)
|
||||
connecting: Deferred[IProtocol] = s.endpoint.connect(factoryProxy)
|
||||
|
||||
def prepare(
|
||||
protocol: _ReconnectingProtocolProxy,
|
||||
) -> Deferred[_ReconnectingProtocolProxy]:
|
||||
if s.prepareConnection is not None:
|
||||
return maybeDeferred(s.prepareConnection, protocol).addCallback(
|
||||
lambda _: protocol
|
||||
)
|
||||
return succeed(protocol)
|
||||
|
||||
# endpoint.connect() is actually generic on the type of the protocol,
|
||||
# but this is not expressible via zope.interface, so we have to cast
|
||||
# https://github.com/Shoobx/mypy-zope/issues/95
|
||||
connectingProxy: Deferred[_ReconnectingProtocolProxy]
|
||||
connectingProxy = connecting # type:ignore[assignment]
|
||||
(
|
||||
connectingProxy.addCallback(prepare)
|
||||
.addCallback(c._connectionMade)
|
||||
.addErrback(c._connectionFailed)
|
||||
)
|
||||
return connectingProxy
|
||||
|
||||
# States:
|
||||
Init = machine.state("Init")
|
||||
Connecting = machine.state("Connecting", attemptConnection)
|
||||
Stopped = machine.state("Stopped")
|
||||
Waiting = machine.state("Waiting", waitForRetry)
|
||||
Connected = machine.state("Connected", rememberConnection)
|
||||
Disconnecting = machine.state("Disconnecting")
|
||||
Restarting = machine.state("Restarting")
|
||||
Stopped = machine.state("Stopped")
|
||||
|
||||
# Behavior-less state transitions:
|
||||
Init.upon(_Client.start).to(Connecting).returns(None)
|
||||
Connecting.upon(_Client.start).loop().returns(None)
|
||||
Connecting.upon(_Client._connectionMade).to(Connected).returns(None)
|
||||
Waiting.upon(_Client.start).loop().returns(None)
|
||||
Waiting.upon(_Client._reconnect).to(Connecting).returns(None)
|
||||
Connected.upon(_Client._connectionFailed).to(Waiting).returns(None)
|
||||
Connected.upon(_Client.start).loop().returns(None)
|
||||
Connected.upon(_Client._clientDisconnected).to(Waiting).returns(None)
|
||||
Disconnecting.upon(_Client.start).to(Restarting).returns(None)
|
||||
Restarting.upon(_Client.start).to(Restarting).returns(None)
|
||||
Stopped.upon(_Client.start).to(Connecting).returns(None)
|
||||
|
||||
# Behavior-full state transitions:
|
||||
@pep614(Init.upon(_Client.stop).to(Stopped))
|
||||
@pep614(Stopped.upon(_Client.stop).to(Stopped))
|
||||
def immediateStop(c: _Client, s: _Core) -> Deferred[None]:
|
||||
return succeed(None)
|
||||
|
||||
@pep614(Connecting.upon(_Client.stop).to(Disconnecting))
|
||||
def connectingStop(
|
||||
c: _Client, s: _Core, attempt: Deferred[_ReconnectingProtocolProxy]
|
||||
) -> Deferred[None]:
|
||||
waited = s.waitForStop()
|
||||
attempt.cancel()
|
||||
return waited
|
||||
|
||||
@pep614(Connecting.upon(_Client._connectionFailed, nodata=True).to(Waiting))
|
||||
def failedWhenConnecting(c: _Client, s: _Core, failure: Failure) -> None:
|
||||
ready = []
|
||||
notReady: list[tuple[Deferred[IProtocol], Optional[int]]] = []
|
||||
for w, remaining in s.awaitingConnected:
|
||||
if remaining is None:
|
||||
notReady.append((w, remaining))
|
||||
elif remaining <= 1:
|
||||
ready.append(w)
|
||||
else:
|
||||
notReady.append((w, remaining - 1))
|
||||
s.awaitingConnected = notReady
|
||||
for w in ready:
|
||||
w.callback(failure)
|
||||
|
||||
@pep614(Waiting.upon(_Client.stop).to(Stopped))
|
||||
def stop(c: _Client, s: _Core, futureRetry: IDelayedCall) -> Deferred[None]:
|
||||
waited = s.waitForStop()
|
||||
s.cancelConnectWaiters()
|
||||
futureRetry.cancel()
|
||||
s.finishStopping()
|
||||
return waited
|
||||
|
||||
@pep614(Connected.upon(_Client.stop).to(Disconnecting))
|
||||
def stopWhileConnected(
|
||||
c: _Client, s: _Core, protocol: _ReconnectingProtocolProxy
|
||||
) -> Deferred[None]:
|
||||
waited = s.waitForStop()
|
||||
protocol._transport.loseConnection()
|
||||
return waited
|
||||
|
||||
@pep614(Connected.upon(_Client.whenConnected).loop())
|
||||
def whenConnectedWhenConnected(
|
||||
c: _Client,
|
||||
s: _Core,
|
||||
protocol: _ReconnectingProtocolProxy,
|
||||
failAfterFailures: Optional[int] = None,
|
||||
) -> Deferred[IProtocol]:
|
||||
return succeed(protocol._protocol)
|
||||
|
||||
@pep614(Disconnecting.upon(_Client.stop).loop())
|
||||
@pep614(Restarting.upon(_Client.stop).to(Disconnecting))
|
||||
def discoStop(c: _Client, s: _Core) -> Deferred[None]:
|
||||
return s.waitForStop()
|
||||
|
||||
@pep614(Disconnecting.upon(_Client._connectionFailed).to(Stopped))
|
||||
@pep614(Disconnecting.upon(_Client._clientDisconnected).to(Stopped))
|
||||
def disconnectingFinished(
|
||||
c: _Client, s: _Core, failure: Optional[Failure] = None
|
||||
) -> None:
|
||||
s.cancelConnectWaiters()
|
||||
s.finishStopping()
|
||||
|
||||
@pep614(Connecting.upon(_Client.whenConnected, nodata=True).loop())
|
||||
@pep614(Waiting.upon(_Client.whenConnected, nodata=True).loop())
|
||||
@pep614(Init.upon(_Client.whenConnected).to(Init))
|
||||
@pep614(Restarting.upon(_Client.whenConnected).to(Restarting))
|
||||
@pep614(Disconnecting.upon(_Client.whenConnected).to(Disconnecting))
|
||||
def awaitingConnection(
|
||||
c: _Client, s: _Core, failAfterFailures: Optional[int] = None
|
||||
) -> Deferred[IProtocol]:
|
||||
result: Deferred[IProtocol] = Deferred()
|
||||
s.awaitingConnected.append((result, failAfterFailures))
|
||||
return result
|
||||
|
||||
@pep614(Restarting.upon(_Client._clientDisconnected).to(Connecting))
|
||||
def restartDone(c: _Client, s: _Core, failure: Optional[Failure] = None) -> None:
|
||||
s.finishStopping()
|
||||
|
||||
@pep614(Stopped.upon(_Client.whenConnected).to(Stopped))
|
||||
def notGoingToConnect(
|
||||
c: _Client, s: _Core, failAfterFailures: Optional[int] = None
|
||||
) -> Deferred[IProtocol]:
|
||||
return fail(CancelledError())
|
||||
|
||||
return machine.build()
|
||||
|
||||
|
||||
def backoffPolicy(
|
||||
initialDelay: float = 1.0,
|
||||
maxDelay: float = 60.0,
|
||||
factor: float = 1.5,
|
||||
jitter: Callable[[], float] = _goodEnoughRandom,
|
||||
) -> Callable[[int], float]:
|
||||
"""
|
||||
A timeout policy for L{ClientService} which computes an exponential backoff
|
||||
interval with configurable parameters.
|
||||
|
||||
@since: 16.1.0
|
||||
|
||||
@param initialDelay: Delay for the first reconnection attempt (default
|
||||
1.0s).
|
||||
@type initialDelay: L{float}
|
||||
|
||||
@param maxDelay: Maximum number of seconds between connection attempts
|
||||
(default 60 seconds, or one minute). Note that this value is before
|
||||
jitter is applied, so the actual maximum possible delay is this value
|
||||
plus the maximum possible result of C{jitter()}.
|
||||
@type maxDelay: L{float}
|
||||
|
||||
@param factor: A multiplicative factor by which the delay grows on each
|
||||
failed reattempt. Default: 1.5.
|
||||
@type factor: L{float}
|
||||
|
||||
@param jitter: A 0-argument callable that introduces noise into the delay.
|
||||
By default, C{random.random}, i.e. a pseudorandom floating-point value
|
||||
between zero and one.
|
||||
@type jitter: 0-argument callable returning L{float}
|
||||
|
||||
@return: a 1-argument callable that, given an attempt count, returns a
|
||||
floating point number; the number of seconds to delay.
|
||||
@rtype: see L{ClientService.__init__}'s C{retryPolicy} argument.
|
||||
"""
|
||||
|
||||
def policy(attempt: int) -> float:
|
||||
try:
|
||||
delay = min(initialDelay * (factor ** min(100, attempt)), maxDelay)
|
||||
except OverflowError:
|
||||
delay = maxDelay
|
||||
return delay + jitter()
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
_defaultPolicy = backoffPolicy()
|
||||
ClientMachine = makeMachine()
|
||||
|
||||
|
||||
class ClientService(Service):
|
||||
"""
|
||||
A L{ClientService} maintains a single outgoing connection to a client
|
||||
endpoint, reconnecting after a configurable timeout when a connection
|
||||
fails, either before or after connecting.
|
||||
|
||||
@since: 16.1.0
|
||||
"""
|
||||
|
||||
_log = Logger()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: IStreamClientEndpoint,
|
||||
factory: IProtocolFactory,
|
||||
retryPolicy: Optional[Callable[[int], float]] = None,
|
||||
clock: Optional[IReactorTime] = None,
|
||||
prepareConnection: Optional[Callable[[IProtocol], object]] = None,
|
||||
):
|
||||
"""
|
||||
@param endpoint: A L{stream client endpoint
|
||||
<interfaces.IStreamClientEndpoint>} provider which will be used to
|
||||
connect when the service starts.
|
||||
|
||||
@param factory: A L{protocol factory <interfaces.IProtocolFactory>}
|
||||
which will be used to create clients for the endpoint.
|
||||
|
||||
@param retryPolicy: A policy configuring how long L{ClientService} will
|
||||
wait between attempts to connect to C{endpoint}; a callable taking
|
||||
(the number of failed connection attempts made in a row (L{int}))
|
||||
and returning the number of seconds to wait before making another
|
||||
attempt.
|
||||
|
||||
@param clock: The clock used to schedule reconnection. It's mainly
|
||||
useful to be parametrized in tests. If the factory is serialized,
|
||||
this attribute will not be serialized, and the default value (the
|
||||
reactor) will be restored when deserialized.
|
||||
|
||||
@param prepareConnection: A single argument L{callable} that may return
|
||||
a L{Deferred}. It will be called once with the L{protocol
|
||||
<interfaces.IProtocol>} each time a new connection is made. It may
|
||||
call methods on the protocol to prepare it for use (e.g.
|
||||
authenticate) or validate it (check its health).
|
||||
|
||||
The C{prepareConnection} callable may raise an exception or return
|
||||
a L{Deferred} which fails to reject the connection. A rejected
|
||||
connection is not used to fire an L{Deferred} returned by
|
||||
L{whenConnected}. Instead, L{ClientService} handles the failure
|
||||
and continues as if the connection attempt were a failure
|
||||
(incrementing the counter passed to C{retryPolicy}).
|
||||
|
||||
L{Deferred}s returned by L{whenConnected} will not fire until any
|
||||
L{Deferred} returned by the C{prepareConnection} callable fire.
|
||||
Otherwise its successful return value is consumed, but ignored.
|
||||
|
||||
Present Since Twisted 18.7.0
|
||||
"""
|
||||
clock = _maybeGlobalReactor(clock)
|
||||
retryPolicy = _defaultPolicy if retryPolicy is None else retryPolicy
|
||||
|
||||
self._machine: _Client = ClientMachine(
|
||||
_Core(
|
||||
endpoint,
|
||||
factory,
|
||||
retryPolicy,
|
||||
clock,
|
||||
prepareConnection=prepareConnection,
|
||||
log=self._log,
|
||||
)
|
||||
)
|
||||
|
||||
def whenConnected(
|
||||
self, failAfterFailures: Optional[int] = None
|
||||
) -> Deferred[IProtocol]:
|
||||
"""
|
||||
Retrieve the currently-connected L{Protocol}, or the next one to
|
||||
connect.
|
||||
|
||||
@param failAfterFailures: number of connection failures after which
|
||||
the Deferred will deliver a Failure (None means the Deferred will
|
||||
only fail if/when the service is stopped). Set this to 1 to make
|
||||
the very first connection failure signal an error. Use 2 to
|
||||
allow one failure but signal an error if the subsequent retry
|
||||
then fails.
|
||||
@type failAfterFailures: L{int} or None
|
||||
|
||||
@return: a Deferred that fires with a protocol produced by the
|
||||
factory passed to C{__init__}
|
||||
@rtype: L{Deferred} that may:
|
||||
|
||||
- fire with L{IProtocol}
|
||||
|
||||
- fail with L{CancelledError} when the service is stopped
|
||||
|
||||
- fail with e.g.
|
||||
L{DNSLookupError<twisted.internet.error.DNSLookupError>} or
|
||||
L{ConnectionRefusedError<twisted.internet.error.ConnectionRefusedError>}
|
||||
when the number of consecutive failed connection attempts
|
||||
equals the value of "failAfterFailures"
|
||||
"""
|
||||
return self._machine.whenConnected(failAfterFailures)
|
||||
|
||||
def startService(self) -> None:
|
||||
"""
|
||||
Start this L{ClientService}, initiating the connection retry loop.
|
||||
"""
|
||||
if self.running:
|
||||
self._log.warn("Duplicate ClientService.startService {log_source}")
|
||||
return
|
||||
super().startService()
|
||||
self._machine.start()
|
||||
|
||||
def stopService(self) -> Deferred[None]:
|
||||
"""
|
||||
Stop attempting to reconnect and close any existing connections.
|
||||
|
||||
@return: a L{Deferred} that fires when all outstanding connections are
|
||||
closed and all in-progress connection attempts halted.
|
||||
"""
|
||||
super().stopService()
|
||||
return self._machine.stop()
|
||||
706
.venv/lib/python3.12/site-packages/twisted/application/app.py
Normal file
706
.venv/lib/python3.12/site-packages/twisted/application/app.py
Normal file
@@ -0,0 +1,706 @@
|
||||
# -*- test-case-name: twisted.test.test_application,twisted.test.test_twistd -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
import getpass
|
||||
import os
|
||||
import pdb
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from operator import attrgetter
|
||||
|
||||
from twisted import copyright, logger, plugin
|
||||
from twisted.application import reactors, service
|
||||
|
||||
# Expose the new implementation of installReactor at the old location.
|
||||
from twisted.application.reactors import NoSuchReactor, installReactor
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.interfaces import _ISupportsExitSignalCapturing
|
||||
from twisted.persisted import sob
|
||||
from twisted.python import failure, log, logfile, runtime, usage, util
|
||||
from twisted.python.reflect import namedAny, namedModule, qual
|
||||
|
||||
|
||||
class _BasicProfiler:
|
||||
"""
|
||||
@ivar saveStats: if C{True}, save the stats information instead of the
|
||||
human readable format
|
||||
@type saveStats: C{bool}
|
||||
|
||||
@ivar profileOutput: the name of the file use to print profile data.
|
||||
@type profileOutput: C{str}
|
||||
"""
|
||||
|
||||
def __init__(self, profileOutput, saveStats):
|
||||
self.profileOutput = profileOutput
|
||||
self.saveStats = saveStats
|
||||
|
||||
def _reportImportError(self, module, e):
|
||||
"""
|
||||
Helper method to report an import error with a profile module. This
|
||||
has to be explicit because some of these modules are removed by
|
||||
distributions due to them being non-free.
|
||||
"""
|
||||
s = f"Failed to import module {module}: {e}"
|
||||
s += """
|
||||
This is most likely caused by your operating system not including
|
||||
the module due to it being non-free. Either do not use the option
|
||||
--profile, or install the module; your operating system vendor
|
||||
may provide it in a separate package.
|
||||
"""
|
||||
raise SystemExit(s)
|
||||
|
||||
|
||||
class ProfileRunner(_BasicProfiler):
|
||||
"""
|
||||
Runner for the standard profile module.
|
||||
"""
|
||||
|
||||
def run(self, reactor):
|
||||
"""
|
||||
Run reactor under the standard profiler.
|
||||
"""
|
||||
try:
|
||||
import profile
|
||||
except ImportError as e:
|
||||
self._reportImportError("profile", e)
|
||||
|
||||
p = profile.Profile()
|
||||
p.runcall(reactor.run)
|
||||
if self.saveStats:
|
||||
p.dump_stats(self.profileOutput)
|
||||
else:
|
||||
tmp, sys.stdout = sys.stdout, open(self.profileOutput, "a")
|
||||
try:
|
||||
p.print_stats()
|
||||
finally:
|
||||
sys.stdout, tmp = tmp, sys.stdout
|
||||
tmp.close()
|
||||
|
||||
|
||||
class CProfileRunner(_BasicProfiler):
|
||||
"""
|
||||
Runner for the cProfile module.
|
||||
"""
|
||||
|
||||
def run(self, reactor):
|
||||
"""
|
||||
Run reactor under the cProfile profiler.
|
||||
"""
|
||||
try:
|
||||
import cProfile
|
||||
import pstats
|
||||
except ImportError as e:
|
||||
self._reportImportError("cProfile", e)
|
||||
|
||||
p = cProfile.Profile()
|
||||
p.runcall(reactor.run)
|
||||
if self.saveStats:
|
||||
p.dump_stats(self.profileOutput)
|
||||
else:
|
||||
with open(self.profileOutput, "w") as stream:
|
||||
s = pstats.Stats(p, stream=stream)
|
||||
s.strip_dirs()
|
||||
s.sort_stats(-1)
|
||||
s.print_stats()
|
||||
|
||||
|
||||
class AppProfiler:
|
||||
"""
|
||||
Class which selects a specific profile runner based on configuration
|
||||
options.
|
||||
|
||||
@ivar profiler: the name of the selected profiler.
|
||||
@type profiler: C{str}
|
||||
"""
|
||||
|
||||
profilers = {"profile": ProfileRunner, "cprofile": CProfileRunner}
|
||||
|
||||
def __init__(self, options):
|
||||
saveStats = options.get("savestats", False)
|
||||
profileOutput = options.get("profile", None)
|
||||
self.profiler = options.get("profiler", "cprofile").lower()
|
||||
if self.profiler in self.profilers:
|
||||
profiler = self.profilers[self.profiler](profileOutput, saveStats)
|
||||
self.run = profiler.run
|
||||
else:
|
||||
raise SystemExit(f"Unsupported profiler name: {self.profiler}")
|
||||
|
||||
|
||||
class AppLogger:
|
||||
"""
|
||||
An L{AppLogger} attaches the configured log observer specified on the
|
||||
commandline to a L{ServerOptions} object, a custom L{logger.ILogObserver},
|
||||
or a legacy custom {log.ILogObserver}.
|
||||
|
||||
@ivar _logfilename: The name of the file to which to log, if other than the
|
||||
default.
|
||||
@type _logfilename: C{str}
|
||||
|
||||
@ivar _observerFactory: Callable object that will create a log observer, or
|
||||
None.
|
||||
|
||||
@ivar _observer: log observer added at C{start} and removed at C{stop}.
|
||||
@type _observer: a callable that implements L{logger.ILogObserver} or
|
||||
L{log.ILogObserver}.
|
||||
"""
|
||||
|
||||
_observer = None
|
||||
|
||||
def __init__(self, options):
|
||||
"""
|
||||
Initialize an L{AppLogger} with a L{ServerOptions}.
|
||||
"""
|
||||
self._logfilename = options.get("logfile", "")
|
||||
self._observerFactory = options.get("logger") or None
|
||||
|
||||
def start(self, application):
|
||||
"""
|
||||
Initialize the global logging system for the given application.
|
||||
|
||||
If a custom logger was specified on the command line it will be used.
|
||||
If not, and an L{logger.ILogObserver} or legacy L{log.ILogObserver}
|
||||
component has been set on C{application}, then it will be used as the
|
||||
log observer. Otherwise a log observer will be created based on the
|
||||
command line options for built-in loggers (e.g. C{--logfile}).
|
||||
|
||||
@param application: The application on which to check for an
|
||||
L{logger.ILogObserver} or legacy L{log.ILogObserver}.
|
||||
@type application: L{twisted.python.components.Componentized}
|
||||
"""
|
||||
if self._observerFactory is not None:
|
||||
observer = self._observerFactory()
|
||||
else:
|
||||
observer = application.getComponent(logger.ILogObserver, None)
|
||||
if observer is None:
|
||||
# If there's no new ILogObserver, try the legacy one
|
||||
observer = application.getComponent(log.ILogObserver, None)
|
||||
|
||||
if observer is None:
|
||||
observer = self._getLogObserver()
|
||||
self._observer = observer
|
||||
|
||||
if logger.ILogObserver.providedBy(self._observer):
|
||||
observers = [self._observer]
|
||||
elif log.ILogObserver.providedBy(self._observer):
|
||||
observers = [logger.LegacyLogObserverWrapper(self._observer)]
|
||||
else:
|
||||
warnings.warn(
|
||||
(
|
||||
"Passing a logger factory which makes log observers which do "
|
||||
"not implement twisted.logger.ILogObserver or "
|
||||
"twisted.python.log.ILogObserver to "
|
||||
"twisted.application.app.AppLogger was deprecated in "
|
||||
"Twisted 16.2. Please use a factory that produces "
|
||||
"twisted.logger.ILogObserver (or the legacy "
|
||||
"twisted.python.log.ILogObserver) implementing objects "
|
||||
"instead."
|
||||
),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
observers = [logger.LegacyLogObserverWrapper(self._observer)]
|
||||
|
||||
logger.globalLogBeginner.beginLoggingTo(observers)
|
||||
self._initialLog()
|
||||
|
||||
def _initialLog(self):
|
||||
"""
|
||||
Print twistd start log message.
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
|
||||
logger._loggerFor(self).info(
|
||||
"twistd {version} ({exe} {pyVersion}) starting up.",
|
||||
version=copyright.version,
|
||||
exe=sys.executable,
|
||||
pyVersion=runtime.shortPythonVersion(),
|
||||
)
|
||||
logger._loggerFor(self).info(
|
||||
"reactor class: {reactor}.", reactor=qual(reactor.__class__)
|
||||
)
|
||||
|
||||
def _getLogObserver(self):
|
||||
"""
|
||||
Create a log observer to be added to the logging system before running
|
||||
this application.
|
||||
"""
|
||||
if self._logfilename == "-" or not self._logfilename:
|
||||
logFile = sys.stdout
|
||||
else:
|
||||
logFile = logfile.LogFile.fromFullPath(self._logfilename)
|
||||
return logger.textFileLogObserver(logFile)
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Remove all log observers previously set up by L{AppLogger.start}.
|
||||
"""
|
||||
logger._loggerFor(self).info("Server Shut Down.")
|
||||
if self._observer is not None:
|
||||
logger.globalLogPublisher.removeObserver(self._observer)
|
||||
self._observer = None
|
||||
|
||||
|
||||
def fixPdb():
|
||||
def do_stop(self, arg):
|
||||
self.clear_all_breaks()
|
||||
self.set_continue()
|
||||
from twisted.internet import reactor
|
||||
|
||||
reactor.callLater(0, reactor.stop)
|
||||
return 1
|
||||
|
||||
def help_stop(self):
|
||||
print(
|
||||
"stop - Continue execution, then cleanly shutdown the twisted " "reactor."
|
||||
)
|
||||
|
||||
def set_quit(self):
|
||||
os._exit(0)
|
||||
|
||||
pdb.Pdb.set_quit = set_quit
|
||||
pdb.Pdb.do_stop = do_stop
|
||||
pdb.Pdb.help_stop = help_stop
|
||||
|
||||
|
||||
def runReactorWithLogging(config, oldstdout, oldstderr, profiler=None, reactor=None):
|
||||
"""
|
||||
Start the reactor, using profiling if specified by the configuration, and
|
||||
log any error happening in the process.
|
||||
|
||||
@param config: configuration of the twistd application.
|
||||
@type config: L{ServerOptions}
|
||||
|
||||
@param oldstdout: initial value of C{sys.stdout}.
|
||||
@type oldstdout: C{file}
|
||||
|
||||
@param oldstderr: initial value of C{sys.stderr}.
|
||||
@type oldstderr: C{file}
|
||||
|
||||
@param profiler: object used to run the reactor with profiling.
|
||||
@type profiler: L{AppProfiler}
|
||||
|
||||
@param reactor: The reactor to use. If L{None}, the global reactor will
|
||||
be used.
|
||||
"""
|
||||
if reactor is None:
|
||||
from twisted.internet import reactor
|
||||
try:
|
||||
if config["profile"]:
|
||||
if profiler is not None:
|
||||
profiler.run(reactor)
|
||||
elif config["debug"]:
|
||||
sys.stdout = oldstdout
|
||||
sys.stderr = oldstderr
|
||||
if runtime.platformType == "posix":
|
||||
signal.signal(signal.SIGUSR2, lambda *args: pdb.set_trace())
|
||||
signal.signal(signal.SIGINT, lambda *args: pdb.set_trace())
|
||||
fixPdb()
|
||||
pdb.runcall(reactor.run)
|
||||
else:
|
||||
reactor.run()
|
||||
except BaseException:
|
||||
close = False
|
||||
if config["nodaemon"]:
|
||||
file = oldstdout
|
||||
else:
|
||||
file = open("TWISTD-CRASH.log", "a")
|
||||
close = True
|
||||
try:
|
||||
traceback.print_exc(file=file)
|
||||
file.flush()
|
||||
finally:
|
||||
if close:
|
||||
file.close()
|
||||
|
||||
|
||||
def getPassphrase(needed):
|
||||
if needed:
|
||||
return getpass.getpass("Passphrase: ")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def getSavePassphrase(needed):
|
||||
if needed:
|
||||
return util.getPassword("Encryption passphrase: ")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ApplicationRunner:
|
||||
"""
|
||||
An object which helps running an application based on a config object.
|
||||
|
||||
Subclass me and implement preApplication and postApplication
|
||||
methods. postApplication generally will want to run the reactor
|
||||
after starting the application.
|
||||
|
||||
@ivar config: The config object, which provides a dict-like interface.
|
||||
|
||||
@ivar application: Available in postApplication, but not
|
||||
preApplication. This is the application object.
|
||||
|
||||
@ivar profilerFactory: Factory for creating a profiler object, able to
|
||||
profile the application if options are set accordingly.
|
||||
|
||||
@ivar profiler: Instance provided by C{profilerFactory}.
|
||||
|
||||
@ivar loggerFactory: Factory for creating object responsible for logging.
|
||||
|
||||
@ivar logger: Instance provided by C{loggerFactory}.
|
||||
"""
|
||||
|
||||
profilerFactory = AppProfiler
|
||||
loggerFactory = AppLogger
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.profiler = self.profilerFactory(config)
|
||||
self.logger = self.loggerFactory(config)
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run the application.
|
||||
"""
|
||||
self.preApplication()
|
||||
self.application = self.createOrGetApplication()
|
||||
|
||||
self.logger.start(self.application)
|
||||
|
||||
self.postApplication()
|
||||
self.logger.stop()
|
||||
|
||||
def startReactor(self, reactor, oldstdout, oldstderr):
|
||||
"""
|
||||
Run the reactor with the given configuration. Subclasses should
|
||||
probably call this from C{postApplication}.
|
||||
|
||||
@see: L{runReactorWithLogging}
|
||||
"""
|
||||
if reactor is None:
|
||||
from twisted.internet import reactor
|
||||
runReactorWithLogging(self.config, oldstdout, oldstderr, self.profiler, reactor)
|
||||
|
||||
if _ISupportsExitSignalCapturing.providedBy(reactor):
|
||||
self._exitSignal = reactor._exitSignal
|
||||
else:
|
||||
self._exitSignal = None
|
||||
|
||||
def preApplication(self):
|
||||
"""
|
||||
Override in subclass.
|
||||
|
||||
This should set up any state necessary before loading and
|
||||
running the Application.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def postApplication(self):
|
||||
"""
|
||||
Override in subclass.
|
||||
|
||||
This will be called after the application has been loaded (so
|
||||
the C{application} attribute will be set). Generally this
|
||||
should start the application and run the reactor.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def createOrGetApplication(self):
|
||||
"""
|
||||
Create or load an Application based on the parameters found in the
|
||||
given L{ServerOptions} instance.
|
||||
|
||||
If a subcommand was used, the L{service.IServiceMaker} that it
|
||||
represents will be used to construct a service to be added to
|
||||
a newly-created Application.
|
||||
|
||||
Otherwise, an application will be loaded based on parameters in
|
||||
the config.
|
||||
"""
|
||||
if self.config.subCommand:
|
||||
# If a subcommand was given, it's our responsibility to create
|
||||
# the application, instead of load it from a file.
|
||||
|
||||
# loadedPlugins is set up by the ServerOptions.subCommands
|
||||
# property, which is iterated somewhere in the bowels of
|
||||
# usage.Options.
|
||||
plg = self.config.loadedPlugins[self.config.subCommand]
|
||||
ser = plg.makeService(self.config.subOptions)
|
||||
application = service.Application(plg.tapname)
|
||||
ser.setServiceParent(application)
|
||||
else:
|
||||
passphrase = getPassphrase(self.config["encrypted"])
|
||||
application = getApplication(self.config, passphrase)
|
||||
return application
|
||||
|
||||
|
||||
def getApplication(config, passphrase):
|
||||
s = [(config[t], t) for t in ["python", "source", "file"] if config[t]][0]
|
||||
filename, style = s[0], {"file": "pickle"}.get(s[1], s[1])
|
||||
try:
|
||||
log.msg("Loading %s..." % filename)
|
||||
application = service.loadApplication(filename, style, passphrase)
|
||||
log.msg("Loaded.")
|
||||
except Exception as e:
|
||||
s = "Failed to load application: %s" % e
|
||||
if isinstance(e, KeyError) and e.args[0] == "application":
|
||||
s += """
|
||||
Could not find 'application' in the file. To use 'twistd -y', your .tac
|
||||
file must create a suitable object (e.g., by calling service.Application())
|
||||
and store it in a variable named 'application'. twistd loads your .tac file
|
||||
and scans the global variables for one of this name.
|
||||
|
||||
Please read the 'Using Application' HOWTO for details.
|
||||
"""
|
||||
traceback.print_exc(file=log.logfile)
|
||||
log.msg(s)
|
||||
log.deferr()
|
||||
sys.exit("\n" + s + "\n")
|
||||
return application
|
||||
|
||||
|
||||
def _reactorAction():
|
||||
return usage.CompleteList([r.shortName for r in reactors.getReactorTypes()])
|
||||
|
||||
|
||||
class ReactorSelectionMixin:
|
||||
"""
|
||||
Provides options for selecting a reactor to install.
|
||||
|
||||
If a reactor is installed, the short name which was used to locate it is
|
||||
saved as the value for the C{"reactor"} key.
|
||||
"""
|
||||
|
||||
compData = usage.Completions(optActions={"reactor": _reactorAction})
|
||||
|
||||
messageOutput = sys.stdout
|
||||
_getReactorTypes = staticmethod(reactors.getReactorTypes)
|
||||
|
||||
def opt_help_reactors(self):
|
||||
"""
|
||||
Display a list of possibly available reactor names.
|
||||
"""
|
||||
rcts = sorted(self._getReactorTypes(), key=attrgetter("shortName"))
|
||||
notWorkingReactors = ""
|
||||
for r in rcts:
|
||||
try:
|
||||
namedModule(r.moduleName)
|
||||
self.messageOutput.write(f" {r.shortName:<4}\t{r.description}\n")
|
||||
except ImportError as e:
|
||||
notWorkingReactors += " !{:<4}\t{} ({})\n".format(
|
||||
r.shortName,
|
||||
r.description,
|
||||
e.args[0],
|
||||
)
|
||||
|
||||
if notWorkingReactors:
|
||||
self.messageOutput.write("\n")
|
||||
self.messageOutput.write(
|
||||
" reactors not available " "on this platform:\n\n"
|
||||
)
|
||||
self.messageOutput.write(notWorkingReactors)
|
||||
raise SystemExit(0)
|
||||
|
||||
def opt_reactor(self, shortName):
|
||||
"""
|
||||
Which reactor to use (see --help-reactors for a list of possibilities)
|
||||
"""
|
||||
# Actually actually actually install the reactor right at this very
|
||||
# moment, before any other code (for example, a sub-command plugin)
|
||||
# runs and accidentally imports and installs the default reactor.
|
||||
#
|
||||
# This could probably be improved somehow.
|
||||
try:
|
||||
installReactor(shortName)
|
||||
except NoSuchReactor:
|
||||
msg = (
|
||||
"The specified reactor does not exist: '%s'.\n"
|
||||
"See the list of available reactors with "
|
||||
"--help-reactors" % (shortName,)
|
||||
)
|
||||
raise usage.UsageError(msg)
|
||||
except Exception as e:
|
||||
msg = (
|
||||
"The specified reactor cannot be used, failed with error: "
|
||||
"%s.\nSee the list of available reactors with "
|
||||
"--help-reactors" % (e,)
|
||||
)
|
||||
raise usage.UsageError(msg)
|
||||
else:
|
||||
self["reactor"] = shortName
|
||||
|
||||
opt_r = opt_reactor
|
||||
|
||||
|
||||
class ServerOptions(usage.Options, ReactorSelectionMixin):
|
||||
longdesc = (
|
||||
"twistd reads a twisted.application.service.Application out "
|
||||
"of a file and runs it."
|
||||
)
|
||||
|
||||
optFlags = [
|
||||
[
|
||||
"savestats",
|
||||
None,
|
||||
"save the Stats object rather than the text output of " "the profiler.",
|
||||
],
|
||||
["no_save", "o", "do not save state on shutdown"],
|
||||
["encrypted", "e", "The specified tap/aos file is encrypted."],
|
||||
]
|
||||
|
||||
optParameters = [
|
||||
["logfile", "l", None, "log to a specified file, - for stdout"],
|
||||
[
|
||||
"logger",
|
||||
None,
|
||||
None,
|
||||
"A fully-qualified name to a log observer factory to "
|
||||
"use for the initial log observer. Takes precedence "
|
||||
"over --logfile and --syslog (when available).",
|
||||
],
|
||||
[
|
||||
"profile",
|
||||
"p",
|
||||
None,
|
||||
"Run in profile mode, dumping results to specified " "file.",
|
||||
],
|
||||
[
|
||||
"profiler",
|
||||
None,
|
||||
"cprofile",
|
||||
"Name of the profiler to use (%s)." % ", ".join(AppProfiler.profilers),
|
||||
],
|
||||
["file", "f", "twistd.tap", "read the given .tap file"],
|
||||
[
|
||||
"python",
|
||||
"y",
|
||||
None,
|
||||
"read an application from within a Python file " "(implies -o)",
|
||||
],
|
||||
["source", "s", None, "Read an application from a .tas file (AOT format)."],
|
||||
["rundir", "d", ".", "Change to a supplied directory before running"],
|
||||
]
|
||||
|
||||
compData = usage.Completions(
|
||||
mutuallyExclusive=[("file", "python", "source")],
|
||||
optActions={
|
||||
"file": usage.CompleteFiles("*.tap"),
|
||||
"python": usage.CompleteFiles("*.(tac|py)"),
|
||||
"source": usage.CompleteFiles("*.tas"),
|
||||
"rundir": usage.CompleteDirs(),
|
||||
},
|
||||
)
|
||||
|
||||
_getPlugins = staticmethod(plugin.getPlugins)
|
||||
|
||||
def __init__(self, *a, **kw):
|
||||
self["debug"] = False
|
||||
if "stdout" in kw:
|
||||
self.stdout = kw["stdout"]
|
||||
else:
|
||||
self.stdout = sys.stdout
|
||||
usage.Options.__init__(self)
|
||||
|
||||
def opt_debug(self):
|
||||
"""
|
||||
Run the application in the Python Debugger (implies nodaemon),
|
||||
sending SIGUSR2 will drop into debugger
|
||||
"""
|
||||
defer.setDebugging(True)
|
||||
failure.startDebugMode()
|
||||
self["debug"] = True
|
||||
|
||||
opt_b = opt_debug
|
||||
|
||||
def opt_spew(self):
|
||||
"""
|
||||
Print an insanely verbose log of everything that happens.
|
||||
Useful when debugging freezes or locks in complex code.
|
||||
"""
|
||||
sys.settrace(util.spewer)
|
||||
try:
|
||||
import threading
|
||||
except ImportError:
|
||||
return
|
||||
threading.settrace(util.spewer)
|
||||
|
||||
def parseOptions(self, options=None):
|
||||
if options is None:
|
||||
options = sys.argv[1:] or ["--help"]
|
||||
usage.Options.parseOptions(self, options)
|
||||
|
||||
def postOptions(self):
|
||||
if self.subCommand or self["python"]:
|
||||
self["no_save"] = True
|
||||
if self["logger"] is not None:
|
||||
try:
|
||||
self["logger"] = namedAny(self["logger"])
|
||||
except Exception as e:
|
||||
raise usage.UsageError(
|
||||
"Logger '{}' could not be imported: {}".format(self["logger"], e)
|
||||
)
|
||||
|
||||
@property
|
||||
def subCommands(self):
|
||||
plugins = self._getPlugins(service.IServiceMaker)
|
||||
self.loadedPlugins = {}
|
||||
for plug in sorted(plugins, key=attrgetter("tapname")):
|
||||
self.loadedPlugins[plug.tapname] = plug
|
||||
yield (
|
||||
plug.tapname,
|
||||
None,
|
||||
# Avoid resolving the options attribute right away, in case
|
||||
# it's a property with a non-trivial getter (eg, one which
|
||||
# imports modules).
|
||||
lambda plug=plug: plug.options(),
|
||||
plug.description,
|
||||
)
|
||||
|
||||
|
||||
def run(runApp, ServerOptions):
|
||||
config = ServerOptions()
|
||||
try:
|
||||
config.parseOptions()
|
||||
except usage.error as ue:
|
||||
commstr = " ".join(sys.argv[0:2])
|
||||
print(config)
|
||||
print(f"{commstr}: {ue}")
|
||||
else:
|
||||
runApp(config)
|
||||
|
||||
|
||||
def convertStyle(filein, typein, passphrase, fileout, typeout, encrypt):
|
||||
application = service.loadApplication(filein, typein, passphrase)
|
||||
sob.IPersistable(application).setStyle(typeout)
|
||||
passphrase = getSavePassphrase(encrypt)
|
||||
if passphrase:
|
||||
fileout = None
|
||||
sob.IPersistable(application).save(filename=fileout, passphrase=passphrase)
|
||||
|
||||
|
||||
def startApplication(application, save):
|
||||
from twisted.internet import reactor
|
||||
|
||||
service.IService(application).startService()
|
||||
if save:
|
||||
p = sob.IPersistable(application)
|
||||
reactor.addSystemEventTrigger("after", "shutdown", p.save, "shutdown")
|
||||
reactor.addSystemEventTrigger(
|
||||
"before", "shutdown", service.IService(application).stopService
|
||||
)
|
||||
|
||||
|
||||
def _exitWithSignal(sig):
|
||||
"""
|
||||
Force the application to terminate with the specified signal by replacing
|
||||
the signal handler with the default and sending the signal to ourselves.
|
||||
|
||||
@param sig: Signal to use to terminate the process with C{os.kill}.
|
||||
@type sig: C{int}
|
||||
"""
|
||||
signal.signal(sig, signal.SIG_DFL)
|
||||
os.kill(os.getpid(), sig)
|
||||
@@ -0,0 +1,427 @@
|
||||
# -*- test-case-name: twisted.application.test.test_internet,twisted.test.test_application,twisted.test.test_cooperator -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Reactor-based Services
|
||||
|
||||
Here are services to run clients, servers and periodic services using
|
||||
the reactor.
|
||||
|
||||
If you want to run a server service, L{StreamServerEndpointService} defines a
|
||||
service that can wrap an arbitrary L{IStreamServerEndpoint
|
||||
<twisted.internet.interfaces.IStreamServerEndpoint>}
|
||||
as an L{IService}. See also L{twisted.application.strports.service} for
|
||||
constructing one of these directly from a descriptive string.
|
||||
|
||||
Additionally, this module (dynamically) defines various Service subclasses that
|
||||
let you represent clients and servers in a Service hierarchy. Endpoints APIs
|
||||
should be preferred for stream server services, but since those APIs do not yet
|
||||
exist for clients or datagram services, many of these are still useful.
|
||||
|
||||
They are as follows::
|
||||
|
||||
TCPServer, TCPClient,
|
||||
UNIXServer, UNIXClient,
|
||||
SSLServer, SSLClient,
|
||||
UDPServer,
|
||||
UNIXDatagramServer, UNIXDatagramClient,
|
||||
MulticastServer
|
||||
|
||||
These classes take arbitrary arguments in their constructors and pass
|
||||
them straight on to their respective reactor.listenXXX or
|
||||
reactor.connectXXX calls.
|
||||
|
||||
For example, the following service starts a web server on port 8080:
|
||||
C{TCPServer(8080, server.Site(r))}. See the documentation for the
|
||||
reactor.listen/connect* methods for more information.
|
||||
"""
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from twisted.application import service
|
||||
from twisted.internet import task
|
||||
from twisted.internet.defer import CancelledError
|
||||
from twisted.python import log
|
||||
from ._client_service import ClientService, _maybeGlobalReactor, backoffPolicy
|
||||
|
||||
|
||||
class _VolatileDataService(service.Service):
|
||||
volatile: List[str] = []
|
||||
|
||||
def __getstate__(self):
|
||||
d = service.Service.__getstate__(self)
|
||||
for attr in self.volatile:
|
||||
if attr in d:
|
||||
del d[attr]
|
||||
return d
|
||||
|
||||
|
||||
class _AbstractServer(_VolatileDataService):
|
||||
"""
|
||||
@cvar volatile: list of attribute to remove from pickling.
|
||||
@type volatile: C{list}
|
||||
|
||||
@ivar method: the type of method to call on the reactor, one of B{TCP},
|
||||
B{UDP}, B{SSL} or B{UNIX}.
|
||||
@type method: C{str}
|
||||
|
||||
@ivar reactor: the current running reactor.
|
||||
@type reactor: a provider of C{IReactorTCP}, C{IReactorUDP},
|
||||
C{IReactorSSL} or C{IReactorUnix}.
|
||||
|
||||
@ivar _port: instance of port set when the service is started.
|
||||
@type _port: a provider of L{twisted.internet.interfaces.IListeningPort}.
|
||||
"""
|
||||
|
||||
volatile = ["_port"]
|
||||
method: str = ""
|
||||
reactor = None
|
||||
|
||||
_port = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
if "reactor" in kwargs:
|
||||
self.reactor = kwargs.pop("reactor")
|
||||
self.kwargs = kwargs
|
||||
|
||||
def privilegedStartService(self):
|
||||
service.Service.privilegedStartService(self)
|
||||
self._port = self._getPort()
|
||||
|
||||
def startService(self):
|
||||
service.Service.startService(self)
|
||||
if self._port is None:
|
||||
self._port = self._getPort()
|
||||
|
||||
def stopService(self):
|
||||
service.Service.stopService(self)
|
||||
# TODO: if startup failed, should shutdown skip stopListening?
|
||||
# _port won't exist
|
||||
if self._port is not None:
|
||||
d = self._port.stopListening()
|
||||
del self._port
|
||||
return d
|
||||
|
||||
def _getPort(self):
|
||||
"""
|
||||
Wrapper around the appropriate listen method of the reactor.
|
||||
|
||||
@return: the port object returned by the listen method.
|
||||
@rtype: an object providing
|
||||
L{twisted.internet.interfaces.IListeningPort}.
|
||||
"""
|
||||
return getattr(
|
||||
_maybeGlobalReactor(self.reactor),
|
||||
"listen{}".format(
|
||||
self.method,
|
||||
),
|
||||
)(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
class _AbstractClient(_VolatileDataService):
|
||||
"""
|
||||
@cvar volatile: list of attribute to remove from pickling.
|
||||
@type volatile: C{list}
|
||||
|
||||
@ivar method: the type of method to call on the reactor, one of B{TCP},
|
||||
B{UDP}, B{SSL} or B{UNIX}.
|
||||
@type method: C{str}
|
||||
|
||||
@ivar reactor: the current running reactor.
|
||||
@type reactor: a provider of C{IReactorTCP}, C{IReactorUDP},
|
||||
C{IReactorSSL} or C{IReactorUnix}.
|
||||
|
||||
@ivar _connection: instance of connection set when the service is started.
|
||||
@type _connection: a provider of L{twisted.internet.interfaces.IConnector}.
|
||||
"""
|
||||
|
||||
volatile = ["_connection"]
|
||||
method: str = ""
|
||||
reactor = None
|
||||
|
||||
_connection = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
if "reactor" in kwargs:
|
||||
self.reactor = kwargs.pop("reactor")
|
||||
self.kwargs = kwargs
|
||||
|
||||
def startService(self):
|
||||
service.Service.startService(self)
|
||||
self._connection = self._getConnection()
|
||||
|
||||
def stopService(self):
|
||||
service.Service.stopService(self)
|
||||
if self._connection is not None:
|
||||
self._connection.disconnect()
|
||||
del self._connection
|
||||
|
||||
def _getConnection(self):
|
||||
"""
|
||||
Wrapper around the appropriate connect method of the reactor.
|
||||
|
||||
@return: the port object returned by the connect method.
|
||||
@rtype: an object providing L{twisted.internet.interfaces.IConnector}.
|
||||
"""
|
||||
return getattr(_maybeGlobalReactor(self.reactor), f"connect{self.method}")(
|
||||
*self.args, **self.kwargs
|
||||
)
|
||||
|
||||
|
||||
_clientDoc = """Connect to {tran}
|
||||
|
||||
Call reactor.connect{tran} when the service starts, with the
|
||||
arguments given to the constructor.
|
||||
"""
|
||||
|
||||
_serverDoc = """Serve {tran} clients
|
||||
|
||||
Call reactor.listen{tran} when the service starts, with the
|
||||
arguments given to the constructor. When the service stops,
|
||||
stop listening. See twisted.internet.interfaces for documentation
|
||||
on arguments to the reactor method.
|
||||
"""
|
||||
|
||||
|
||||
class TCPServer(_AbstractServer):
|
||||
__doc__ = _serverDoc.format(tran="TCP")
|
||||
method = "TCP"
|
||||
|
||||
|
||||
class TCPClient(_AbstractClient):
|
||||
__doc__ = _clientDoc.format(tran="TCP")
|
||||
method = "TCP"
|
||||
|
||||
|
||||
class UNIXServer(_AbstractServer):
|
||||
__doc__ = _serverDoc.format(tran="UNIX")
|
||||
method = "UNIX"
|
||||
|
||||
|
||||
class UNIXClient(_AbstractClient):
|
||||
__doc__ = _clientDoc.format(tran="UNIX")
|
||||
method = "UNIX"
|
||||
|
||||
|
||||
class SSLServer(_AbstractServer):
|
||||
__doc__ = _serverDoc.format(tran="SSL")
|
||||
method = "SSL"
|
||||
|
||||
|
||||
class SSLClient(_AbstractClient):
|
||||
__doc__ = _clientDoc.format(tran="SSL")
|
||||
method = "SSL"
|
||||
|
||||
|
||||
class UDPServer(_AbstractServer):
|
||||
__doc__ = _serverDoc.format(tran="UDP")
|
||||
method = "UDP"
|
||||
|
||||
|
||||
class UNIXDatagramServer(_AbstractServer):
|
||||
__doc__ = _serverDoc.format(tran="UNIXDatagram")
|
||||
method = "UNIXDatagram"
|
||||
|
||||
|
||||
class UNIXDatagramClient(_AbstractClient):
|
||||
__doc__ = _clientDoc.format(tran="UNIXDatagram")
|
||||
method = "UNIXDatagram"
|
||||
|
||||
|
||||
class MulticastServer(_AbstractServer):
|
||||
__doc__ = _serverDoc.format(tran="Multicast")
|
||||
method = "Multicast"
|
||||
|
||||
|
||||
class TimerService(_VolatileDataService):
|
||||
"""
|
||||
Service to periodically call a function
|
||||
|
||||
Every C{step} seconds call the given function with the given arguments.
|
||||
The service starts the calls when it starts, and cancels them
|
||||
when it stops.
|
||||
|
||||
@ivar clock: Source of time. This defaults to L{None} which is
|
||||
causes L{twisted.internet.reactor} to be used.
|
||||
Feel free to set this to something else, but it probably ought to be
|
||||
set *before* calling L{startService}.
|
||||
@type clock: L{IReactorTime<twisted.internet.interfaces.IReactorTime>}
|
||||
|
||||
@ivar call: Function and arguments to call periodically.
|
||||
@type call: L{tuple} of C{(callable, args, kwargs)}
|
||||
"""
|
||||
|
||||
volatile = ["_loop", "_loopFinished"]
|
||||
|
||||
def __init__(self, step, callable, *args, **kwargs):
|
||||
"""
|
||||
@param step: The number of seconds between calls.
|
||||
@type step: L{float}
|
||||
|
||||
@param callable: Function to call
|
||||
@type callable: L{callable}
|
||||
|
||||
@param args: Positional arguments to pass to function
|
||||
@param kwargs: Keyword arguments to pass to function
|
||||
"""
|
||||
self.step = step
|
||||
self.call = (callable, args, kwargs)
|
||||
self.clock = None
|
||||
|
||||
def startService(self):
|
||||
service.Service.startService(self)
|
||||
callable, args, kwargs = self.call
|
||||
# we have to make a new LoopingCall each time we're started, because
|
||||
# an active LoopingCall remains active when serialized. If
|
||||
# LoopingCall were a _VolatileDataService, we wouldn't need to do
|
||||
# this.
|
||||
self._loop = task.LoopingCall(callable, *args, **kwargs)
|
||||
self._loop.clock = _maybeGlobalReactor(self.clock)
|
||||
self._loopFinished = self._loop.start(self.step, now=True)
|
||||
self._loopFinished.addErrback(self._failed)
|
||||
|
||||
def _failed(self, why):
|
||||
# make a note that the LoopingCall is no longer looping, so we don't
|
||||
# try to shut it down a second time in stopService. I think this
|
||||
# should be in LoopingCall. -warner
|
||||
self._loop.running = False
|
||||
log.err(why)
|
||||
|
||||
def stopService(self):
|
||||
"""
|
||||
Stop the service.
|
||||
|
||||
@rtype: L{Deferred<defer.Deferred>}
|
||||
@return: a L{Deferred<defer.Deferred>} which is fired when the
|
||||
currently running call (if any) is finished.
|
||||
"""
|
||||
if self._loop.running:
|
||||
self._loop.stop()
|
||||
self._loopFinished.addCallback(lambda _: service.Service.stopService(self))
|
||||
return self._loopFinished
|
||||
|
||||
|
||||
class CooperatorService(service.Service):
|
||||
"""
|
||||
Simple L{service.IService} which starts and stops a L{twisted.internet.task.Cooperator}.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.coop = task.Cooperator(started=False)
|
||||
|
||||
def coiterate(self, iterator):
|
||||
return self.coop.coiterate(iterator)
|
||||
|
||||
def startService(self):
|
||||
self.coop.start()
|
||||
|
||||
def stopService(self):
|
||||
self.coop.stop()
|
||||
|
||||
|
||||
class StreamServerEndpointService(service.Service):
|
||||
"""
|
||||
A L{StreamServerEndpointService} is an L{IService} which runs a server on a
|
||||
listening port described by an L{IStreamServerEndpoint
|
||||
<twisted.internet.interfaces.IStreamServerEndpoint>}.
|
||||
|
||||
@ivar factory: A server factory which will be used to listen on the
|
||||
endpoint.
|
||||
|
||||
@ivar endpoint: An L{IStreamServerEndpoint
|
||||
<twisted.internet.interfaces.IStreamServerEndpoint>} provider
|
||||
which will be used to listen when the service starts.
|
||||
|
||||
@ivar _waitingForPort: a Deferred, if C{listen} has yet been invoked on the
|
||||
endpoint, otherwise None.
|
||||
|
||||
@ivar _raiseSynchronously: Defines error-handling behavior for the case
|
||||
where C{listen(...)} raises an exception before C{startService} or
|
||||
C{privilegedStartService} have completed.
|
||||
|
||||
@type _raiseSynchronously: C{bool}
|
||||
|
||||
@since: 10.2
|
||||
"""
|
||||
|
||||
_raiseSynchronously = False
|
||||
|
||||
def __init__(self, endpoint, factory):
|
||||
self.endpoint = endpoint
|
||||
self.factory = factory
|
||||
self._waitingForPort = None
|
||||
|
||||
def privilegedStartService(self):
|
||||
"""
|
||||
Start listening on the endpoint.
|
||||
"""
|
||||
service.Service.privilegedStartService(self)
|
||||
self._waitingForPort = self.endpoint.listen(self.factory)
|
||||
raisedNow = []
|
||||
|
||||
def handleIt(err):
|
||||
if self._raiseSynchronously:
|
||||
raisedNow.append(err)
|
||||
elif not err.check(CancelledError):
|
||||
log.err(err)
|
||||
|
||||
self._waitingForPort.addErrback(handleIt)
|
||||
if raisedNow:
|
||||
raisedNow[0].raiseException()
|
||||
self._raiseSynchronously = False
|
||||
|
||||
def startService(self):
|
||||
"""
|
||||
Start listening on the endpoint, unless L{privilegedStartService} got
|
||||
around to it already.
|
||||
"""
|
||||
service.Service.startService(self)
|
||||
if self._waitingForPort is None:
|
||||
self.privilegedStartService()
|
||||
|
||||
def stopService(self):
|
||||
"""
|
||||
Stop listening on the port if it is already listening, otherwise,
|
||||
cancel the attempt to listen.
|
||||
|
||||
@return: a L{Deferred<twisted.internet.defer.Deferred>} which fires
|
||||
with L{None} when the port has stopped listening.
|
||||
"""
|
||||
self._waitingForPort.cancel()
|
||||
|
||||
def stopIt(port):
|
||||
if port is not None:
|
||||
return port.stopListening()
|
||||
|
||||
d = self._waitingForPort.addCallback(stopIt)
|
||||
|
||||
def stop(passthrough):
|
||||
self.running = False
|
||||
return passthrough
|
||||
|
||||
d.addBoth(stop)
|
||||
return d
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TimerService",
|
||||
"CooperatorService",
|
||||
"MulticastServer",
|
||||
"StreamServerEndpointService",
|
||||
"UDPServer",
|
||||
"ClientService",
|
||||
"TCPServer",
|
||||
"TCPClient",
|
||||
"UNIXServer",
|
||||
"UNIXClient",
|
||||
"SSLServer",
|
||||
"SSLClient",
|
||||
"UNIXDatagramServer",
|
||||
"UNIXDatagramClient",
|
||||
"ClientService",
|
||||
"backoffPolicy",
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
# -*- test-case-name: twisted.test.test_application -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Plugin-based system for enumerating available reactors and installing one of
|
||||
them.
|
||||
"""
|
||||
from typing import Iterable, cast
|
||||
|
||||
from zope.interface import Attribute, Interface, implementer
|
||||
|
||||
from twisted.internet.interfaces import IReactorCore
|
||||
from twisted.plugin import IPlugin, getPlugins
|
||||
from twisted.python.reflect import namedAny
|
||||
|
||||
|
||||
class IReactorInstaller(Interface):
|
||||
"""
|
||||
Definition of a reactor which can probably be installed.
|
||||
"""
|
||||
|
||||
shortName = Attribute(
|
||||
"""
|
||||
A brief string giving the user-facing name of this reactor.
|
||||
"""
|
||||
)
|
||||
|
||||
description = Attribute(
|
||||
"""
|
||||
A longer string giving a user-facing description of this reactor.
|
||||
"""
|
||||
)
|
||||
|
||||
def install() -> None:
|
||||
"""
|
||||
Install this reactor.
|
||||
"""
|
||||
|
||||
# TODO - A method which provides a best-guess as to whether this reactor
|
||||
# can actually be used in the execution environment.
|
||||
|
||||
|
||||
class NoSuchReactor(KeyError):
|
||||
"""
|
||||
Raised when an attempt is made to install a reactor which cannot be found.
|
||||
"""
|
||||
|
||||
|
||||
@implementer(IPlugin, IReactorInstaller)
|
||||
class Reactor:
|
||||
"""
|
||||
@ivar moduleName: The fully-qualified Python name of the module of which
|
||||
the install callable is an attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, shortName: str, moduleName: str, description: str):
|
||||
self.shortName = shortName
|
||||
self.moduleName = moduleName
|
||||
self.description = description
|
||||
|
||||
def install(self) -> None:
|
||||
namedAny(self.moduleName).install()
|
||||
|
||||
|
||||
def getReactorTypes() -> Iterable[IReactorInstaller]:
|
||||
"""
|
||||
Return an iterator of L{IReactorInstaller} plugins.
|
||||
"""
|
||||
return getPlugins(IReactorInstaller)
|
||||
|
||||
|
||||
def installReactor(shortName: str) -> IReactorCore:
|
||||
"""
|
||||
Install the reactor with the given C{shortName} attribute.
|
||||
|
||||
@raise NoSuchReactor: If no reactor is found with a matching C{shortName}.
|
||||
|
||||
@raise Exception: Anything that the specified reactor can raise when installed.
|
||||
"""
|
||||
for installer in getReactorTypes():
|
||||
if installer.shortName == shortName:
|
||||
installer.install()
|
||||
from twisted.internet import reactor
|
||||
|
||||
return cast(IReactorCore, reactor)
|
||||
raise NoSuchReactor(shortName)
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- test-case-name: twisted.application.runner.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Facilities for running a Twisted application.
|
||||
"""
|
||||
@@ -0,0 +1,99 @@
|
||||
# -*- test-case-name: twisted.application.runner.test.test_exit -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
System exit support.
|
||||
"""
|
||||
|
||||
import typing
|
||||
from enum import IntEnum
|
||||
from sys import exit as sysexit, stderr, stdout
|
||||
from typing import Union
|
||||
|
||||
try:
|
||||
import posix as Status
|
||||
except ImportError:
|
||||
|
||||
class Status: # type: ignore[no-redef]
|
||||
"""
|
||||
Object to hang C{EX_*} values off of as a substitute for L{posix}.
|
||||
"""
|
||||
|
||||
EX__BASE = 64
|
||||
|
||||
EX_OK = 0
|
||||
EX_USAGE = EX__BASE
|
||||
EX_DATAERR = EX__BASE + 1
|
||||
EX_NOINPUT = EX__BASE + 2
|
||||
EX_NOUSER = EX__BASE + 3
|
||||
EX_NOHOST = EX__BASE + 4
|
||||
EX_UNAVAILABLE = EX__BASE + 5
|
||||
EX_SOFTWARE = EX__BASE + 6
|
||||
EX_OSERR = EX__BASE + 7
|
||||
EX_OSFILE = EX__BASE + 8
|
||||
EX_CANTCREAT = EX__BASE + 9
|
||||
EX_IOERR = EX__BASE + 10
|
||||
EX_TEMPFAIL = EX__BASE + 11
|
||||
EX_PROTOCOL = EX__BASE + 12
|
||||
EX_NOPERM = EX__BASE + 13
|
||||
EX_CONFIG = EX__BASE + 14
|
||||
|
||||
|
||||
class ExitStatus(IntEnum):
|
||||
"""
|
||||
Standard exit status codes for system programs.
|
||||
|
||||
@cvar EX_OK: Successful termination.
|
||||
@cvar EX_USAGE: Command line usage error.
|
||||
@cvar EX_DATAERR: Data format error.
|
||||
@cvar EX_NOINPUT: Cannot open input.
|
||||
@cvar EX_NOUSER: Addressee unknown.
|
||||
@cvar EX_NOHOST: Host name unknown.
|
||||
@cvar EX_UNAVAILABLE: Service unavailable.
|
||||
@cvar EX_SOFTWARE: Internal software error.
|
||||
@cvar EX_OSERR: System error (e.g., can't fork).
|
||||
@cvar EX_OSFILE: Critical OS file missing.
|
||||
@cvar EX_CANTCREAT: Can't create (user) output file.
|
||||
@cvar EX_IOERR: Input/output error.
|
||||
@cvar EX_TEMPFAIL: Temporary failure; the user is invited to retry.
|
||||
@cvar EX_PROTOCOL: Remote error in protocol.
|
||||
@cvar EX_NOPERM: Permission denied.
|
||||
@cvar EX_CONFIG: Configuration error.
|
||||
"""
|
||||
|
||||
EX_OK = Status.EX_OK
|
||||
EX_USAGE = Status.EX_USAGE
|
||||
EX_DATAERR = Status.EX_DATAERR
|
||||
EX_NOINPUT = Status.EX_NOINPUT
|
||||
EX_NOUSER = Status.EX_NOUSER
|
||||
EX_NOHOST = Status.EX_NOHOST
|
||||
EX_UNAVAILABLE = Status.EX_UNAVAILABLE
|
||||
EX_SOFTWARE = Status.EX_SOFTWARE
|
||||
EX_OSERR = Status.EX_OSERR
|
||||
EX_OSFILE = Status.EX_OSFILE
|
||||
EX_CANTCREAT = Status.EX_CANTCREAT
|
||||
EX_IOERR = Status.EX_IOERR
|
||||
EX_TEMPFAIL = Status.EX_TEMPFAIL
|
||||
EX_PROTOCOL = Status.EX_PROTOCOL
|
||||
EX_NOPERM = Status.EX_NOPERM
|
||||
EX_CONFIG = Status.EX_CONFIG
|
||||
|
||||
|
||||
def exit(status: Union[int, ExitStatus], message: str = "") -> "typing.NoReturn":
|
||||
"""
|
||||
Exit the python interpreter with the given status and an optional message.
|
||||
|
||||
@param status: An exit status. An appropriate value from L{ExitStatus} is
|
||||
recommended.
|
||||
@param message: An optional message to print.
|
||||
"""
|
||||
if message:
|
||||
if status == ExitStatus.EX_OK:
|
||||
out = stdout
|
||||
else:
|
||||
out = stderr
|
||||
out.write(message)
|
||||
out.write("\n")
|
||||
|
||||
sysexit(status)
|
||||
@@ -0,0 +1,282 @@
|
||||
# -*- test-case-name: twisted.application.runner.test.test_pidfile -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
PID file.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
from os import getpid, kill, name as SYSTEM_NAME
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from zope.interface import Interface, implementer
|
||||
|
||||
from twisted.logger import Logger
|
||||
from twisted.python.filepath import FilePath
|
||||
|
||||
|
||||
class IPIDFile(Interface):
|
||||
"""
|
||||
Manages a file that remembers a process ID.
|
||||
"""
|
||||
|
||||
def read() -> int:
|
||||
"""
|
||||
Read the process ID stored in this PID file.
|
||||
|
||||
@return: The contained process ID.
|
||||
|
||||
@raise NoPIDFound: If this PID file does not exist.
|
||||
@raise EnvironmentError: If this PID file cannot be read.
|
||||
@raise ValueError: If this PID file's content is invalid.
|
||||
"""
|
||||
|
||||
def writeRunningPID() -> None:
|
||||
"""
|
||||
Store the PID of the current process in this PID file.
|
||||
|
||||
@raise EnvironmentError: If this PID file cannot be written.
|
||||
"""
|
||||
|
||||
def remove() -> None:
|
||||
"""
|
||||
Remove this PID file.
|
||||
|
||||
@raise EnvironmentError: If this PID file cannot be removed.
|
||||
"""
|
||||
|
||||
def isRunning() -> bool:
|
||||
"""
|
||||
Determine whether there is a running process corresponding to the PID
|
||||
in this PID file.
|
||||
|
||||
@return: True if this PID file contains a PID and a process with that
|
||||
PID is currently running; false otherwise.
|
||||
|
||||
@raise EnvironmentError: If this PID file cannot be read.
|
||||
@raise InvalidPIDFileError: If this PID file's content is invalid.
|
||||
@raise StalePIDFileError: If this PID file's content refers to a PID
|
||||
for which there is no corresponding running process.
|
||||
"""
|
||||
|
||||
def __enter__() -> "IPIDFile":
|
||||
"""
|
||||
Enter a context using this PIDFile.
|
||||
|
||||
Writes the PID file with the PID of the running process.
|
||||
|
||||
@raise AlreadyRunningError: A process corresponding to the PID in this
|
||||
PID file is already running.
|
||||
"""
|
||||
|
||||
def __exit__(
|
||||
excType: Optional[Type[BaseException]],
|
||||
excValue: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
Exit a context using this PIDFile.
|
||||
|
||||
Removes the PID file.
|
||||
"""
|
||||
|
||||
|
||||
@implementer(IPIDFile)
|
||||
class PIDFile:
|
||||
"""
|
||||
Concrete implementation of L{IPIDFile}.
|
||||
|
||||
This implementation is presently not supported on non-POSIX platforms.
|
||||
Specifically, calling L{PIDFile.isRunning} will raise
|
||||
L{NotImplementedError}.
|
||||
"""
|
||||
|
||||
_log = Logger()
|
||||
|
||||
@staticmethod
|
||||
def _format(pid: int) -> bytes:
|
||||
"""
|
||||
Format a PID file's content.
|
||||
|
||||
@param pid: A process ID.
|
||||
|
||||
@return: Formatted PID file contents.
|
||||
"""
|
||||
return f"{int(pid)}\n".encode()
|
||||
|
||||
def __init__(self, filePath: FilePath[Any]) -> None:
|
||||
"""
|
||||
@param filePath: The path to the PID file on disk.
|
||||
"""
|
||||
self.filePath = filePath
|
||||
|
||||
def read(self) -> int:
|
||||
pidString = b""
|
||||
try:
|
||||
with self.filePath.open() as fh:
|
||||
for pidString in fh:
|
||||
break
|
||||
except OSError as e:
|
||||
if e.errno == errno.ENOENT: # No such file
|
||||
raise NoPIDFound("PID file does not exist")
|
||||
raise
|
||||
|
||||
try:
|
||||
return int(pidString)
|
||||
except ValueError:
|
||||
raise InvalidPIDFileError(
|
||||
f"non-integer PID value in PID file: {pidString!r}"
|
||||
)
|
||||
|
||||
def _write(self, pid: int) -> None:
|
||||
"""
|
||||
Store a PID in this PID file.
|
||||
|
||||
@param pid: A PID to store.
|
||||
|
||||
@raise EnvironmentError: If this PID file cannot be written.
|
||||
"""
|
||||
self.filePath.setContent(self._format(pid=pid))
|
||||
|
||||
def writeRunningPID(self) -> None:
|
||||
self._write(getpid())
|
||||
|
||||
def remove(self) -> None:
|
||||
self.filePath.remove()
|
||||
|
||||
def isRunning(self) -> bool:
|
||||
try:
|
||||
pid = self.read()
|
||||
except NoPIDFound:
|
||||
return False
|
||||
|
||||
if SYSTEM_NAME == "posix":
|
||||
return self._pidIsRunningPOSIX(pid)
|
||||
else:
|
||||
raise NotImplementedError(f"isRunning is not implemented on {SYSTEM_NAME}")
|
||||
|
||||
@staticmethod
|
||||
def _pidIsRunningPOSIX(pid: int) -> bool:
|
||||
"""
|
||||
POSIX implementation for running process check.
|
||||
|
||||
Determine whether there is a running process corresponding to the given
|
||||
PID.
|
||||
|
||||
@param pid: The PID to check.
|
||||
|
||||
@return: True if the given PID is currently running; false otherwise.
|
||||
|
||||
@raise EnvironmentError: If this PID file cannot be read.
|
||||
@raise InvalidPIDFileError: If this PID file's content is invalid.
|
||||
@raise StalePIDFileError: If this PID file's content refers to a PID
|
||||
for which there is no corresponding running process.
|
||||
"""
|
||||
try:
|
||||
kill(pid, 0)
|
||||
except OSError as e:
|
||||
if e.errno == errno.ESRCH: # No such process
|
||||
raise StalePIDFileError("PID file refers to non-existing process")
|
||||
elif e.errno == errno.EPERM: # Not permitted to kill
|
||||
return True
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
return True
|
||||
|
||||
def __enter__(self) -> "PIDFile":
|
||||
try:
|
||||
if self.isRunning():
|
||||
raise AlreadyRunningError()
|
||||
except StalePIDFileError:
|
||||
self._log.info("Replacing stale PID file: {log_source}")
|
||||
self.writeRunningPID()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
excType: Optional[Type[BaseException]],
|
||||
excValue: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.remove()
|
||||
return None
|
||||
|
||||
|
||||
@implementer(IPIDFile)
|
||||
class NonePIDFile:
|
||||
"""
|
||||
PID file implementation that does nothing.
|
||||
|
||||
This is meant to be used as a "active None" object in place of a PID file
|
||||
when no PID file is desired.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def read(self) -> int:
|
||||
raise NoPIDFound("PID file does not exist")
|
||||
|
||||
def _write(self, pid: int) -> None:
|
||||
"""
|
||||
Store a PID in this PID file.
|
||||
|
||||
@param pid: A PID to store.
|
||||
|
||||
@raise EnvironmentError: If this PID file cannot be written.
|
||||
|
||||
@note: This implementation always raises an L{EnvironmentError}.
|
||||
"""
|
||||
raise OSError(errno.EPERM, "Operation not permitted")
|
||||
|
||||
def writeRunningPID(self) -> None:
|
||||
self._write(0)
|
||||
|
||||
def remove(self) -> None:
|
||||
raise OSError(errno.ENOENT, "No such file or directory")
|
||||
|
||||
def isRunning(self) -> bool:
|
||||
return False
|
||||
|
||||
def __enter__(self) -> "NonePIDFile":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
excType: Optional[Type[BaseException]],
|
||||
excValue: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
nonePIDFile: IPIDFile = NonePIDFile()
|
||||
|
||||
|
||||
class AlreadyRunningError(Exception):
|
||||
"""
|
||||
Process is already running.
|
||||
"""
|
||||
|
||||
|
||||
class InvalidPIDFileError(Exception):
|
||||
"""
|
||||
PID file contents are invalid.
|
||||
"""
|
||||
|
||||
|
||||
class StalePIDFileError(Exception):
|
||||
"""
|
||||
PID file contents are valid, but there is no process with the referenced
|
||||
PID.
|
||||
"""
|
||||
|
||||
|
||||
class NoPIDFound(Exception):
|
||||
"""
|
||||
No PID found in PID file.
|
||||
"""
|
||||
@@ -0,0 +1,166 @@
|
||||
# -*- test-case-name: twisted.application.runner.test.test_runner -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Twisted application runner.
|
||||
"""
|
||||
|
||||
from os import kill
|
||||
from signal import SIGTERM
|
||||
from sys import stderr
|
||||
from typing import Any, Callable, Mapping, TextIO
|
||||
|
||||
from attr import Factory, attrib, attrs
|
||||
from constantly import NamedConstant
|
||||
|
||||
from twisted.internet.interfaces import IReactorCore
|
||||
from twisted.logger import (
|
||||
FileLogObserver,
|
||||
FilteringLogObserver,
|
||||
Logger,
|
||||
LogLevel,
|
||||
LogLevelFilterPredicate,
|
||||
globalLogBeginner,
|
||||
textFileLogObserver,
|
||||
)
|
||||
from ._exit import ExitStatus, exit
|
||||
from ._pidfile import AlreadyRunningError, InvalidPIDFileError, IPIDFile, nonePIDFile
|
||||
|
||||
|
||||
@attrs(frozen=True)
|
||||
class Runner:
|
||||
"""
|
||||
Twisted application runner.
|
||||
|
||||
@cvar _log: The logger attached to this class.
|
||||
|
||||
@ivar _reactor: The reactor to start and run the application in.
|
||||
@ivar _pidFile: The file to store the running process ID in.
|
||||
@ivar _kill: Whether this runner should kill an existing running
|
||||
instance of the application.
|
||||
@ivar _defaultLogLevel: The default log level to start the logging
|
||||
system with.
|
||||
@ivar _logFile: A file stream to write logging output to.
|
||||
@ivar _fileLogObserverFactory: A factory for the file log observer to
|
||||
use when starting the logging system.
|
||||
@ivar _whenRunning: Hook to call after the reactor is running;
|
||||
this is where the application code that relies on the reactor gets
|
||||
called.
|
||||
@ivar _whenRunningArguments: Keyword arguments to pass to
|
||||
C{whenRunning} when it is called.
|
||||
@ivar _reactorExited: Hook to call after the reactor exits.
|
||||
@ivar _reactorExitedArguments: Keyword arguments to pass to
|
||||
C{reactorExited} when it is called.
|
||||
"""
|
||||
|
||||
_log = Logger()
|
||||
|
||||
_reactor = attrib(type=IReactorCore)
|
||||
_pidFile = attrib(type=IPIDFile, default=nonePIDFile)
|
||||
_kill = attrib(type=bool, default=False)
|
||||
_defaultLogLevel = attrib(type=NamedConstant, default=LogLevel.info)
|
||||
_logFile = attrib(type=TextIO, default=stderr)
|
||||
_fileLogObserverFactory = attrib(
|
||||
type=Callable[[TextIO], FileLogObserver], default=textFileLogObserver
|
||||
)
|
||||
_whenRunning = attrib(type=Callable[..., None], default=lambda **_: None)
|
||||
_whenRunningArguments = attrib(type=Mapping[str, Any], default=Factory(dict))
|
||||
_reactorExited = attrib(type=Callable[..., None], default=lambda **_: None)
|
||||
_reactorExitedArguments = attrib(type=Mapping[str, Any], default=Factory(dict))
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run this command.
|
||||
"""
|
||||
pidFile = self._pidFile
|
||||
|
||||
self.killIfRequested()
|
||||
|
||||
try:
|
||||
with pidFile:
|
||||
self.startLogging()
|
||||
self.startReactor()
|
||||
self.reactorExited()
|
||||
|
||||
except AlreadyRunningError:
|
||||
exit(ExitStatus.EX_CONFIG, "Already running.")
|
||||
# When testing, patched exit doesn't exit
|
||||
return # type: ignore[unreachable]
|
||||
|
||||
def killIfRequested(self) -> None:
|
||||
"""
|
||||
If C{self._kill} is true, attempt to kill a running instance of the
|
||||
application.
|
||||
"""
|
||||
pidFile = self._pidFile
|
||||
|
||||
if self._kill:
|
||||
if pidFile is nonePIDFile:
|
||||
exit(ExitStatus.EX_USAGE, "No PID file specified.")
|
||||
# When testing, patched exit doesn't exit
|
||||
return # type: ignore[unreachable]
|
||||
|
||||
try:
|
||||
pid = pidFile.read()
|
||||
except OSError:
|
||||
exit(ExitStatus.EX_IOERR, "Unable to read PID file.")
|
||||
# When testing, patched exit doesn't exit
|
||||
return # type: ignore[unreachable]
|
||||
except InvalidPIDFileError:
|
||||
exit(ExitStatus.EX_DATAERR, "Invalid PID file.")
|
||||
# When testing, patched exit doesn't exit
|
||||
return # type: ignore[unreachable]
|
||||
|
||||
self.startLogging()
|
||||
self._log.info("Terminating process: {pid}", pid=pid)
|
||||
|
||||
kill(pid, SIGTERM)
|
||||
|
||||
exit(ExitStatus.EX_OK)
|
||||
# When testing, patched exit doesn't exit
|
||||
return # type: ignore[unreachable]
|
||||
|
||||
def startLogging(self) -> None:
|
||||
"""
|
||||
Start the L{twisted.logger} logging system.
|
||||
"""
|
||||
logFile = self._logFile
|
||||
|
||||
fileLogObserverFactory = self._fileLogObserverFactory
|
||||
|
||||
fileLogObserver = fileLogObserverFactory(logFile)
|
||||
|
||||
logLevelPredicate = LogLevelFilterPredicate(
|
||||
defaultLogLevel=self._defaultLogLevel
|
||||
)
|
||||
|
||||
filteringObserver = FilteringLogObserver(fileLogObserver, [logLevelPredicate])
|
||||
|
||||
globalLogBeginner.beginLoggingTo([filteringObserver])
|
||||
|
||||
def startReactor(self) -> None:
|
||||
"""
|
||||
Register C{self._whenRunning} with the reactor so that it is called
|
||||
once the reactor is running, then start the reactor.
|
||||
"""
|
||||
self._reactor.callWhenRunning(self.whenRunning)
|
||||
|
||||
self._log.info("Starting reactor...")
|
||||
self._reactor.run()
|
||||
|
||||
def whenRunning(self) -> None:
|
||||
"""
|
||||
Call C{self._whenRunning} with C{self._whenRunningArguments}.
|
||||
|
||||
@note: This method is called after the reactor starts running.
|
||||
"""
|
||||
self._whenRunning(**self._whenRunningArguments)
|
||||
|
||||
def reactorExited(self) -> None:
|
||||
"""
|
||||
Call C{self._reactorExited} with C{self._reactorExitedArguments}.
|
||||
|
||||
@note: This method is called after the reactor exits.
|
||||
"""
|
||||
self._reactorExited(**self._reactorExitedArguments)
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- test-case-name: twisted.application.runner.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.application.runner}.
|
||||
"""
|
||||
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.application.runner._exit}.
|
||||
"""
|
||||
|
||||
from io import StringIO
|
||||
from typing import Optional, Union
|
||||
|
||||
import twisted.trial.unittest
|
||||
from ...runner import _exit
|
||||
from .._exit import ExitStatus, exit
|
||||
|
||||
|
||||
class ExitTests(twisted.trial.unittest.TestCase):
|
||||
"""
|
||||
Tests for L{exit}.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.exit = DummyExit()
|
||||
self.patch(_exit, "sysexit", self.exit)
|
||||
|
||||
def test_exitStatusInt(self) -> None:
|
||||
"""
|
||||
L{exit} given an L{int} status code will pass it to L{sys.exit}.
|
||||
"""
|
||||
status = 1234
|
||||
exit(status)
|
||||
self.assertEqual(self.exit.arg, status) # type: ignore[unreachable]
|
||||
|
||||
def test_exitConstant(self) -> None:
|
||||
"""
|
||||
L{exit} given a L{ValueConstant} status code passes the corresponding
|
||||
value to L{sys.exit}.
|
||||
"""
|
||||
status = ExitStatus.EX_CONFIG
|
||||
exit(status)
|
||||
self.assertEqual(self.exit.arg, status.value) # type: ignore[unreachable]
|
||||
|
||||
def test_exitMessageZero(self) -> None:
|
||||
"""
|
||||
L{exit} given a status code of zero (C{0}) writes the given message to
|
||||
standard output.
|
||||
"""
|
||||
out = StringIO()
|
||||
self.patch(_exit, "stdout", out)
|
||||
|
||||
message = "Hello, world."
|
||||
exit(0, message)
|
||||
|
||||
self.assertEqual(out.getvalue(), message + "\n") # type: ignore[unreachable]
|
||||
|
||||
def test_exitMessageNonZero(self) -> None:
|
||||
"""
|
||||
L{exit} given a non-zero status code writes the given message to
|
||||
standard error.
|
||||
"""
|
||||
out = StringIO()
|
||||
self.patch(_exit, "stderr", out)
|
||||
|
||||
message = "Hello, world."
|
||||
exit(64, message)
|
||||
|
||||
self.assertEqual(out.getvalue(), message + "\n") # type: ignore[unreachable]
|
||||
|
||||
|
||||
class DummyExit:
|
||||
"""
|
||||
Stub for L{sys.exit} that remembers whether it's been called and, if it
|
||||
has, what argument it was given.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.exited = False
|
||||
|
||||
def __call__(self, arg: Optional[Union[int, str]] = None) -> None:
|
||||
assert not self.exited
|
||||
|
||||
self.arg = arg
|
||||
self.exited = True
|
||||
@@ -0,0 +1,419 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.application.runner._pidfile}.
|
||||
"""
|
||||
|
||||
import errno
|
||||
from functools import wraps
|
||||
from os import getpid, name as SYSTEM_NAME
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from zope.interface.verify import verifyObject
|
||||
|
||||
from typing_extensions import NoReturn
|
||||
|
||||
import twisted.trial.unittest
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.runtime import platform
|
||||
from twisted.trial.unittest import SkipTest
|
||||
from ...runner import _pidfile
|
||||
from .._pidfile import (
|
||||
AlreadyRunningError,
|
||||
InvalidPIDFileError,
|
||||
IPIDFile,
|
||||
NonePIDFile,
|
||||
NoPIDFound,
|
||||
PIDFile,
|
||||
StalePIDFileError,
|
||||
)
|
||||
|
||||
|
||||
def ifPlatformSupported(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""
|
||||
Decorator for tests that are not expected to work on all platforms.
|
||||
|
||||
Calling L{PIDFile.isRunning} currently raises L{NotImplementedError} on
|
||||
non-POSIX platforms.
|
||||
|
||||
On an unsupported platform, we expect to see any test that calls
|
||||
L{PIDFile.isRunning} to raise either L{NotImplementedError}, L{SkipTest},
|
||||
or C{self.failureException}.
|
||||
(C{self.failureException} may occur in a test that checks for a specific
|
||||
exception but it gets NotImplementedError instead.)
|
||||
|
||||
@param f: The test method to decorate.
|
||||
|
||||
@return: The wrapped callable.
|
||||
"""
|
||||
|
||||
@wraps(f)
|
||||
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
supported = platform.getType() == "posix"
|
||||
|
||||
if supported:
|
||||
return f(self, *args, **kwargs)
|
||||
else:
|
||||
e = self.assertRaises(
|
||||
(NotImplementedError, SkipTest, self.failureException),
|
||||
f,
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
if isinstance(e, NotImplementedError):
|
||||
self.assertTrue(str(e).startswith("isRunning is not implemented on "))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class PIDFileTests(twisted.trial.unittest.TestCase):
|
||||
"""
|
||||
Tests for L{PIDFile}.
|
||||
"""
|
||||
|
||||
def filePath(self, content: Optional[bytes] = None) -> FilePath[str]:
|
||||
filePath = FilePath(self.mktemp())
|
||||
if content is not None:
|
||||
filePath.setContent(content)
|
||||
return filePath
|
||||
|
||||
def test_interface(self) -> None:
|
||||
"""
|
||||
L{PIDFile} conforms to L{IPIDFile}.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
verifyObject(IPIDFile, pidFile)
|
||||
|
||||
def test_formatWithPID(self) -> None:
|
||||
"""
|
||||
L{PIDFile._format} returns the expected format when given a PID.
|
||||
"""
|
||||
self.assertEqual(PIDFile._format(pid=1337), b"1337\n")
|
||||
|
||||
def test_readWithPID(self) -> None:
|
||||
"""
|
||||
L{PIDFile.read} returns the PID from the given file path.
|
||||
"""
|
||||
pid = 1337
|
||||
|
||||
pidFile = PIDFile(self.filePath(PIDFile._format(pid=pid)))
|
||||
|
||||
self.assertEqual(pid, pidFile.read())
|
||||
|
||||
def test_readEmptyPID(self) -> None:
|
||||
"""
|
||||
L{PIDFile.read} raises L{InvalidPIDFileError} when given an empty file
|
||||
path.
|
||||
"""
|
||||
pidValue = b""
|
||||
pidFile = PIDFile(self.filePath(b""))
|
||||
|
||||
e = self.assertRaises(InvalidPIDFileError, pidFile.read)
|
||||
self.assertEqual(str(e), f"non-integer PID value in PID file: {pidValue!r}")
|
||||
|
||||
def test_readWithBogusPID(self) -> None:
|
||||
"""
|
||||
L{PIDFile.read} raises L{InvalidPIDFileError} when given an empty file
|
||||
path.
|
||||
"""
|
||||
pidValue = b"$foo!"
|
||||
pidFile = PIDFile(self.filePath(pidValue))
|
||||
|
||||
e = self.assertRaises(InvalidPIDFileError, pidFile.read)
|
||||
self.assertEqual(str(e), f"non-integer PID value in PID file: {pidValue!r}")
|
||||
|
||||
def test_readDoesntExist(self) -> None:
|
||||
"""
|
||||
L{PIDFile.read} raises L{NoPIDFound} when given a non-existing file
|
||||
path.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
|
||||
e = self.assertRaises(NoPIDFound, pidFile.read)
|
||||
self.assertEqual(str(e), "PID file does not exist")
|
||||
|
||||
def test_readOpenRaisesOSErrorNotENOENT(self) -> None:
|
||||
"""
|
||||
L{PIDFile.read} re-raises L{OSError} if the associated C{errno} is
|
||||
anything other than L{errno.ENOENT}.
|
||||
"""
|
||||
|
||||
def oops(mode: str = "r") -> NoReturn:
|
||||
raise OSError(errno.EIO, "I/O error")
|
||||
|
||||
self.patch(FilePath, "open", oops)
|
||||
|
||||
pidFile = PIDFile(self.filePath())
|
||||
|
||||
error = self.assertRaises(OSError, pidFile.read)
|
||||
self.assertEqual(error.errno, errno.EIO)
|
||||
|
||||
def test_writePID(self) -> None:
|
||||
"""
|
||||
L{PIDFile._write} stores the given PID.
|
||||
"""
|
||||
pid = 1995
|
||||
|
||||
pidFile = PIDFile(self.filePath())
|
||||
pidFile._write(pid)
|
||||
|
||||
self.assertEqual(pidFile.read(), pid)
|
||||
|
||||
def test_writePIDInvalid(self) -> None:
|
||||
"""
|
||||
L{PIDFile._write} raises L{ValueError} when given an invalid PID.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
|
||||
self.assertRaises(ValueError, pidFile._write, "burp")
|
||||
|
||||
def test_writeRunningPID(self) -> None:
|
||||
"""
|
||||
L{PIDFile.writeRunningPID} stores the PID for the current process.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
pidFile.writeRunningPID()
|
||||
|
||||
self.assertEqual(pidFile.read(), getpid())
|
||||
|
||||
def test_remove(self) -> None:
|
||||
"""
|
||||
L{PIDFile.remove} removes the PID file.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath(b""))
|
||||
self.assertTrue(pidFile.filePath.exists())
|
||||
|
||||
pidFile.remove()
|
||||
self.assertFalse(pidFile.filePath.exists())
|
||||
|
||||
@ifPlatformSupported
|
||||
def test_isRunningDoesExist(self) -> None:
|
||||
"""
|
||||
L{PIDFile.isRunning} returns true for a process that does exist.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
pidFile._write(1337)
|
||||
|
||||
def kill(pid: int, signal: int) -> None:
|
||||
return # Don't actually kill anything
|
||||
|
||||
self.patch(_pidfile, "kill", kill)
|
||||
|
||||
self.assertTrue(pidFile.isRunning())
|
||||
|
||||
@ifPlatformSupported
|
||||
def test_isRunningThis(self) -> None:
|
||||
"""
|
||||
L{PIDFile.isRunning} returns true for this process (which is running).
|
||||
|
||||
@note: This differs from L{PIDFileTests.test_isRunningDoesExist} in
|
||||
that it actually invokes the C{kill} system call, which is useful for
|
||||
testing of our chosen method for probing the existence of a process.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
pidFile.writeRunningPID()
|
||||
|
||||
self.assertTrue(pidFile.isRunning())
|
||||
|
||||
@ifPlatformSupported
|
||||
def test_isRunningDoesNotExist(self) -> None:
|
||||
"""
|
||||
L{PIDFile.isRunning} raises L{StalePIDFileError} for a process that
|
||||
does not exist (errno=ESRCH).
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
pidFile._write(1337)
|
||||
|
||||
def kill(pid: int, signal: int) -> None:
|
||||
raise OSError(errno.ESRCH, "No such process")
|
||||
|
||||
self.patch(_pidfile, "kill", kill)
|
||||
|
||||
self.assertRaises(StalePIDFileError, pidFile.isRunning)
|
||||
|
||||
@ifPlatformSupported
|
||||
def test_isRunningNotAllowed(self) -> None:
|
||||
"""
|
||||
L{PIDFile.isRunning} returns true for a process that we are not allowed
|
||||
to kill (errno=EPERM).
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
pidFile._write(1337)
|
||||
|
||||
def kill(pid: int, signal: int) -> None:
|
||||
raise OSError(errno.EPERM, "Operation not permitted")
|
||||
|
||||
self.patch(_pidfile, "kill", kill)
|
||||
|
||||
self.assertTrue(pidFile.isRunning())
|
||||
|
||||
@ifPlatformSupported
|
||||
def test_isRunningInit(self) -> None:
|
||||
"""
|
||||
L{PIDFile.isRunning} returns true for a process that we are not allowed
|
||||
to kill (errno=EPERM).
|
||||
|
||||
@note: This differs from L{PIDFileTests.test_isRunningNotAllowed} in
|
||||
that it actually invokes the C{kill} system call, which is useful for
|
||||
testing of our chosen method for probing the existence of a process
|
||||
that we are not allowed to kill.
|
||||
|
||||
@note: In this case, we try killing C{init}, which is process #1 on
|
||||
POSIX systems, so this test is not portable. C{init} should always be
|
||||
running and should not be killable by non-root users.
|
||||
"""
|
||||
if SYSTEM_NAME != "posix":
|
||||
raise SkipTest("This test assumes POSIX")
|
||||
|
||||
pidFile = PIDFile(self.filePath())
|
||||
pidFile._write(1) # PID 1 is init on POSIX systems
|
||||
|
||||
self.assertTrue(pidFile.isRunning())
|
||||
|
||||
@ifPlatformSupported
|
||||
def test_isRunningUnknownErrno(self) -> None:
|
||||
"""
|
||||
L{PIDFile.isRunning} re-raises L{OSError} if the attached C{errno}
|
||||
value from L{os.kill} is not an expected one.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
pidFile.writeRunningPID()
|
||||
|
||||
def kill(pid: int, signal: int) -> None:
|
||||
raise OSError(errno.EEXIST, "File exists")
|
||||
|
||||
self.patch(_pidfile, "kill", kill)
|
||||
|
||||
self.assertRaises(OSError, pidFile.isRunning)
|
||||
|
||||
def test_isRunningNoPIDFile(self) -> None:
|
||||
"""
|
||||
L{PIDFile.isRunning} returns false if the PID file doesn't exist.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
|
||||
self.assertFalse(pidFile.isRunning())
|
||||
|
||||
def test_contextManager(self) -> None:
|
||||
"""
|
||||
When used as a context manager, a L{PIDFile} will store the current pid
|
||||
on entry, then removes the PID file on exit.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
self.assertFalse(pidFile.filePath.exists())
|
||||
|
||||
with pidFile:
|
||||
self.assertTrue(pidFile.filePath.exists())
|
||||
self.assertEqual(pidFile.read(), getpid())
|
||||
|
||||
self.assertFalse(pidFile.filePath.exists())
|
||||
|
||||
@ifPlatformSupported
|
||||
def test_contextManagerDoesntExist(self) -> None:
|
||||
"""
|
||||
When used as a context manager, a L{PIDFile} will replace the
|
||||
underlying PIDFile rather than raising L{AlreadyRunningError} if the
|
||||
contained PID file exists but refers to a non-running PID.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
pidFile._write(1337)
|
||||
|
||||
def kill(pid: int, signal: int) -> None:
|
||||
raise OSError(errno.ESRCH, "No such process")
|
||||
|
||||
self.patch(_pidfile, "kill", kill)
|
||||
|
||||
e = self.assertRaises(StalePIDFileError, pidFile.isRunning)
|
||||
self.assertEqual(str(e), "PID file refers to non-existing process")
|
||||
|
||||
with pidFile:
|
||||
self.assertEqual(pidFile.read(), getpid())
|
||||
|
||||
@ifPlatformSupported
|
||||
def test_contextManagerAlreadyRunning(self) -> None:
|
||||
"""
|
||||
When used as a context manager, a L{PIDFile} will raise
|
||||
L{AlreadyRunningError} if the there is already a running process with
|
||||
the contained PID.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath())
|
||||
pidFile._write(1337)
|
||||
|
||||
def kill(pid: int, signal: int) -> None:
|
||||
return # Don't actually kill anything
|
||||
|
||||
self.patch(_pidfile, "kill", kill)
|
||||
|
||||
self.assertTrue(pidFile.isRunning())
|
||||
|
||||
self.assertRaises(AlreadyRunningError, pidFile.__enter__)
|
||||
|
||||
|
||||
class NonePIDFileTests(twisted.trial.unittest.TestCase):
|
||||
"""
|
||||
Tests for L{NonePIDFile}.
|
||||
"""
|
||||
|
||||
def test_interface(self) -> None:
|
||||
"""
|
||||
L{NonePIDFile} conforms to L{IPIDFile}.
|
||||
"""
|
||||
pidFile = NonePIDFile()
|
||||
verifyObject(IPIDFile, pidFile)
|
||||
|
||||
def test_read(self) -> None:
|
||||
"""
|
||||
L{NonePIDFile.read} raises L{NoPIDFound}.
|
||||
"""
|
||||
pidFile = NonePIDFile()
|
||||
|
||||
e = self.assertRaises(NoPIDFound, pidFile.read)
|
||||
self.assertEqual(str(e), "PID file does not exist")
|
||||
|
||||
def test_write(self) -> None:
|
||||
"""
|
||||
L{NonePIDFile._write} raises L{OSError} with an errno of L{errno.EPERM}.
|
||||
"""
|
||||
pidFile = NonePIDFile()
|
||||
|
||||
error = self.assertRaises(OSError, pidFile._write, 0)
|
||||
self.assertEqual(error.errno, errno.EPERM)
|
||||
|
||||
def test_writeRunningPID(self) -> None:
|
||||
"""
|
||||
L{NonePIDFile.writeRunningPID} raises L{OSError} with an errno of
|
||||
L{errno.EPERM}.
|
||||
"""
|
||||
pidFile = NonePIDFile()
|
||||
|
||||
error = self.assertRaises(OSError, pidFile.writeRunningPID)
|
||||
self.assertEqual(error.errno, errno.EPERM)
|
||||
|
||||
def test_remove(self) -> None:
|
||||
"""
|
||||
L{NonePIDFile.remove} raises L{OSError} with an errno of L{errno.EPERM}.
|
||||
"""
|
||||
pidFile = NonePIDFile()
|
||||
|
||||
error = self.assertRaises(OSError, pidFile.remove)
|
||||
self.assertEqual(error.errno, errno.ENOENT)
|
||||
|
||||
def test_isRunning(self) -> None:
|
||||
"""
|
||||
L{NonePIDFile.isRunning} returns L{False}.
|
||||
"""
|
||||
pidFile = NonePIDFile()
|
||||
|
||||
self.assertEqual(pidFile.isRunning(), False)
|
||||
|
||||
def test_contextManager(self) -> None:
|
||||
"""
|
||||
When used as a context manager, a L{NonePIDFile} doesn't raise, despite
|
||||
not existing.
|
||||
"""
|
||||
pidFile = NonePIDFile()
|
||||
|
||||
with pidFile:
|
||||
pass
|
||||
@@ -0,0 +1,454 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.application.runner._runner}.
|
||||
"""
|
||||
|
||||
import errno
|
||||
from io import StringIO
|
||||
from signal import SIGTERM
|
||||
from types import TracebackType
|
||||
from typing import Any, Iterable, List, Optional, TextIO, Tuple, Type, Union, cast
|
||||
|
||||
from attr import Factory, attrib, attrs
|
||||
|
||||
import twisted.trial.unittest
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
from twisted.logger import (
|
||||
FileLogObserver,
|
||||
FilteringLogObserver,
|
||||
ILogObserver,
|
||||
LogBeginner,
|
||||
LogLevel,
|
||||
LogLevelFilterPredicate,
|
||||
LogPublisher,
|
||||
)
|
||||
from twisted.python.filepath import FilePath
|
||||
from ...runner import _runner
|
||||
from .._exit import ExitStatus
|
||||
from .._pidfile import NonePIDFile, PIDFile
|
||||
from .._runner import Runner
|
||||
|
||||
|
||||
class RunnerTests(twisted.trial.unittest.TestCase):
|
||||
"""
|
||||
Tests for L{Runner}.
|
||||
"""
|
||||
|
||||
def filePath(self, content: Optional[bytes] = None) -> FilePath[str]:
|
||||
filePath = FilePath(self.mktemp())
|
||||
if content is not None:
|
||||
filePath.setContent(content)
|
||||
return filePath
|
||||
|
||||
def setUp(self) -> None:
|
||||
# Patch exit and kill so we can capture usage and prevent actual exits
|
||||
# and kills.
|
||||
|
||||
self.exit = DummyExit()
|
||||
self.kill = DummyKill()
|
||||
|
||||
self.patch(_runner, "exit", self.exit)
|
||||
self.patch(_runner, "kill", self.kill)
|
||||
|
||||
# Patch getpid so we get a known result
|
||||
|
||||
self.pid = 1337
|
||||
self.pidFileContent = f"{self.pid}\n".encode()
|
||||
|
||||
# Patch globalLogBeginner so that we aren't trying to install multiple
|
||||
# global log observers.
|
||||
|
||||
self.stdout = StringIO()
|
||||
self.stderr = StringIO()
|
||||
self.stdio = DummyStandardIO(self.stdout, self.stderr)
|
||||
self.warnings = DummyWarningsModule()
|
||||
|
||||
self.globalLogPublisher = LogPublisher()
|
||||
self.globalLogBeginner = LogBeginner(
|
||||
self.globalLogPublisher,
|
||||
self.stdio.stderr,
|
||||
self.stdio,
|
||||
self.warnings,
|
||||
)
|
||||
|
||||
self.patch(_runner, "stderr", self.stderr)
|
||||
self.patch(_runner, "globalLogBeginner", self.globalLogBeginner)
|
||||
|
||||
def test_runInOrder(self) -> None:
|
||||
"""
|
||||
L{Runner.run} calls the expected methods in order.
|
||||
"""
|
||||
runner = DummyRunner(reactor=MemoryReactor())
|
||||
runner.run()
|
||||
|
||||
self.assertEqual(
|
||||
runner.calledMethods,
|
||||
[
|
||||
"killIfRequested",
|
||||
"startLogging",
|
||||
"startReactor",
|
||||
"reactorExited",
|
||||
],
|
||||
)
|
||||
|
||||
def test_runUsesPIDFile(self) -> None:
|
||||
"""
|
||||
L{Runner.run} uses the provided PID file.
|
||||
"""
|
||||
pidFile = DummyPIDFile()
|
||||
|
||||
runner = Runner(reactor=MemoryReactor(), pidFile=pidFile)
|
||||
|
||||
self.assertFalse(pidFile.entered)
|
||||
self.assertFalse(pidFile.exited)
|
||||
|
||||
runner.run()
|
||||
|
||||
self.assertTrue(pidFile.entered)
|
||||
self.assertTrue(pidFile.exited)
|
||||
|
||||
def test_runAlreadyRunning(self) -> None:
|
||||
"""
|
||||
L{Runner.run} exits with L{ExitStatus.EX_USAGE} and the expected
|
||||
message if a process is already running that corresponds to the given
|
||||
PID file.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath(self.pidFileContent))
|
||||
pidFile.isRunning = lambda: True # type: ignore[method-assign]
|
||||
|
||||
runner = Runner(reactor=MemoryReactor(), pidFile=pidFile)
|
||||
runner.run()
|
||||
|
||||
self.assertEqual(self.exit.status, ExitStatus.EX_CONFIG)
|
||||
self.assertEqual(self.exit.message, "Already running.")
|
||||
|
||||
def test_killNotRequested(self) -> None:
|
||||
"""
|
||||
L{Runner.killIfRequested} when C{kill} is false doesn't exit and
|
||||
doesn't indiscriminately murder anyone.
|
||||
"""
|
||||
runner = Runner(reactor=MemoryReactor())
|
||||
runner.killIfRequested()
|
||||
|
||||
self.assertEqual(self.kill.calls, [])
|
||||
self.assertFalse(self.exit.exited)
|
||||
|
||||
def test_killRequestedWithoutPIDFile(self) -> None:
|
||||
"""
|
||||
L{Runner.killIfRequested} when C{kill} is true but C{pidFile} is
|
||||
L{nonePIDFile} exits with L{ExitStatus.EX_USAGE} and the expected
|
||||
message; and also doesn't indiscriminately murder anyone.
|
||||
"""
|
||||
runner = Runner(reactor=MemoryReactor(), kill=True)
|
||||
runner.killIfRequested()
|
||||
|
||||
self.assertEqual(self.kill.calls, [])
|
||||
self.assertEqual(self.exit.status, ExitStatus.EX_USAGE)
|
||||
self.assertEqual(self.exit.message, "No PID file specified.")
|
||||
|
||||
def test_killRequestedWithPIDFile(self) -> None:
|
||||
"""
|
||||
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
|
||||
performs a targeted killing of the appropriate process.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath(self.pidFileContent))
|
||||
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
|
||||
runner.killIfRequested()
|
||||
|
||||
self.assertEqual(self.kill.calls, [(self.pid, SIGTERM)])
|
||||
self.assertEqual(self.exit.status, ExitStatus.EX_OK)
|
||||
self.assertIdentical(self.exit.message, None)
|
||||
|
||||
def test_killRequestedWithPIDFileCantRead(self) -> None:
|
||||
"""
|
||||
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
|
||||
that it can't read exits with L{ExitStatus.EX_IOERR}.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath(None))
|
||||
|
||||
def read() -> int:
|
||||
raise OSError(errno.EACCES, "Permission denied")
|
||||
|
||||
pidFile.read = read # type: ignore[method-assign]
|
||||
|
||||
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
|
||||
runner.killIfRequested()
|
||||
|
||||
self.assertEqual(self.exit.status, ExitStatus.EX_IOERR)
|
||||
self.assertEqual(self.exit.message, "Unable to read PID file.")
|
||||
|
||||
def test_killRequestedWithPIDFileEmpty(self) -> None:
|
||||
"""
|
||||
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
|
||||
containing no value exits with L{ExitStatus.EX_DATAERR}.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath(b""))
|
||||
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
|
||||
runner.killIfRequested()
|
||||
|
||||
self.assertEqual(self.exit.status, ExitStatus.EX_DATAERR)
|
||||
self.assertEqual(self.exit.message, "Invalid PID file.")
|
||||
|
||||
def test_killRequestedWithPIDFileNotAnInt(self) -> None:
|
||||
"""
|
||||
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
|
||||
containing a non-integer value exits with L{ExitStatus.EX_DATAERR}.
|
||||
"""
|
||||
pidFile = PIDFile(self.filePath(b"** totally not a number, dude **"))
|
||||
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
|
||||
runner.killIfRequested()
|
||||
|
||||
self.assertEqual(self.exit.status, ExitStatus.EX_DATAERR)
|
||||
self.assertEqual(self.exit.message, "Invalid PID file.")
|
||||
|
||||
def test_startLogging(self) -> None:
|
||||
"""
|
||||
L{Runner.startLogging} sets up a filtering observer with a log level
|
||||
predicate set to the given log level that contains a file observer of
|
||||
the given type which writes to the given file.
|
||||
"""
|
||||
logFile = StringIO()
|
||||
|
||||
# Patch the log beginner so that we don't try to start the already
|
||||
# running (started by trial) logging system.
|
||||
|
||||
class LogBeginner:
|
||||
observers: List[ILogObserver] = []
|
||||
|
||||
def beginLoggingTo(self, observers: Iterable[ILogObserver]) -> None:
|
||||
LogBeginner.observers = list(observers)
|
||||
|
||||
self.patch(_runner, "globalLogBeginner", LogBeginner())
|
||||
|
||||
# Patch FilteringLogObserver so we can capture its arguments
|
||||
|
||||
class MockFilteringLogObserver(FilteringLogObserver):
|
||||
observer: Optional[ILogObserver] = None
|
||||
predicates: List[LogLevelFilterPredicate] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observer: ILogObserver,
|
||||
predicates: Iterable[LogLevelFilterPredicate],
|
||||
negativeObserver: ILogObserver = cast(ILogObserver, lambda event: None),
|
||||
) -> None:
|
||||
MockFilteringLogObserver.observer = observer
|
||||
MockFilteringLogObserver.predicates = list(predicates)
|
||||
FilteringLogObserver.__init__(
|
||||
self, observer, predicates, negativeObserver
|
||||
)
|
||||
|
||||
self.patch(_runner, "FilteringLogObserver", MockFilteringLogObserver)
|
||||
|
||||
# Patch FileLogObserver so we can capture its arguments
|
||||
|
||||
class MockFileLogObserver(FileLogObserver):
|
||||
outFile: Optional[TextIO] = None
|
||||
|
||||
def __init__(self, outFile: TextIO) -> None:
|
||||
MockFileLogObserver.outFile = outFile
|
||||
FileLogObserver.__init__(self, outFile, str)
|
||||
|
||||
# Start logging
|
||||
runner = Runner(
|
||||
reactor=MemoryReactor(),
|
||||
defaultLogLevel=LogLevel.critical,
|
||||
logFile=logFile,
|
||||
fileLogObserverFactory=MockFileLogObserver,
|
||||
)
|
||||
runner.startLogging()
|
||||
|
||||
# Check for a filtering observer
|
||||
self.assertEqual(len(LogBeginner.observers), 1)
|
||||
self.assertIsInstance(LogBeginner.observers[0], FilteringLogObserver)
|
||||
|
||||
# Check log level predicate with the correct default log level
|
||||
self.assertEqual(len(MockFilteringLogObserver.predicates), 1)
|
||||
self.assertIsInstance(
|
||||
MockFilteringLogObserver.predicates[0], LogLevelFilterPredicate
|
||||
)
|
||||
self.assertIdentical(
|
||||
MockFilteringLogObserver.predicates[0].defaultLogLevel, LogLevel.critical
|
||||
)
|
||||
|
||||
# Check for a file observer attached to the filtering observer
|
||||
observer = cast(MockFileLogObserver, MockFilteringLogObserver.observer)
|
||||
self.assertIsInstance(observer, MockFileLogObserver)
|
||||
|
||||
# Check for the file we gave it
|
||||
self.assertIdentical(observer.outFile, logFile)
|
||||
|
||||
def test_startReactorWithReactor(self) -> None:
|
||||
"""
|
||||
L{Runner.startReactor} with the C{reactor} argument runs the given
|
||||
reactor.
|
||||
"""
|
||||
reactor = MemoryReactor()
|
||||
runner = Runner(reactor=reactor)
|
||||
runner.startReactor()
|
||||
|
||||
self.assertTrue(reactor.hasRun)
|
||||
|
||||
def test_startReactorWhenRunning(self) -> None:
|
||||
"""
|
||||
L{Runner.startReactor} ensures that C{whenRunning} is called with
|
||||
C{whenRunningArguments} when the reactor is running.
|
||||
"""
|
||||
self._testHook("whenRunning", "startReactor")
|
||||
|
||||
def test_whenRunningWithArguments(self) -> None:
|
||||
"""
|
||||
L{Runner.whenRunning} calls C{whenRunning} with
|
||||
C{whenRunningArguments}.
|
||||
"""
|
||||
self._testHook("whenRunning")
|
||||
|
||||
def test_reactorExitedWithArguments(self) -> None:
|
||||
"""
|
||||
L{Runner.whenRunning} calls C{reactorExited} with
|
||||
C{reactorExitedArguments}.
|
||||
"""
|
||||
self._testHook("reactorExited")
|
||||
|
||||
def _testHook(self, methodName: str, callerName: Optional[str] = None) -> None:
|
||||
"""
|
||||
Verify that the named hook is run with the expected arguments as
|
||||
specified by the arguments used to create the L{Runner}, when the
|
||||
specified caller is invoked.
|
||||
|
||||
@param methodName: The name of the hook to verify.
|
||||
|
||||
@param callerName: The name of the method that is expected to cause the
|
||||
hook to be called.
|
||||
If C{None}, use the L{Runner} method with the same name as the
|
||||
hook.
|
||||
"""
|
||||
if callerName is None:
|
||||
callerName = methodName
|
||||
|
||||
arguments = dict(a=object(), b=object(), c=object())
|
||||
argumentsSeen = []
|
||||
|
||||
def hook(**arguments: object) -> None:
|
||||
argumentsSeen.append(arguments)
|
||||
|
||||
runnerArguments = {
|
||||
methodName: hook,
|
||||
f"{methodName}Arguments": arguments.copy(),
|
||||
}
|
||||
runner = Runner(
|
||||
reactor=MemoryReactor(), **runnerArguments # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
hookCaller = getattr(runner, callerName)
|
||||
hookCaller()
|
||||
|
||||
self.assertEqual(len(argumentsSeen), 1)
|
||||
self.assertEqual(argumentsSeen[0], arguments)
|
||||
|
||||
|
||||
@attrs(frozen=True)
|
||||
class DummyRunner(Runner):
|
||||
"""
|
||||
Stub for L{Runner}.
|
||||
|
||||
Keep track of calls to some methods without actually doing anything.
|
||||
"""
|
||||
|
||||
calledMethods = attrib(type=List[str], default=Factory(list))
|
||||
|
||||
def killIfRequested(self) -> None:
|
||||
self.calledMethods.append("killIfRequested")
|
||||
|
||||
def startLogging(self) -> None:
|
||||
self.calledMethods.append("startLogging")
|
||||
|
||||
def startReactor(self) -> None:
|
||||
self.calledMethods.append("startReactor")
|
||||
|
||||
def reactorExited(self) -> None:
|
||||
self.calledMethods.append("reactorExited")
|
||||
|
||||
|
||||
class DummyPIDFile(NonePIDFile):
|
||||
"""
|
||||
Stub for L{PIDFile}.
|
||||
|
||||
Tracks context manager entry/exit without doing anything.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
NonePIDFile.__init__(self)
|
||||
|
||||
self.entered = False
|
||||
self.exited = False
|
||||
|
||||
def __enter__(self) -> "DummyPIDFile":
|
||||
self.entered = True
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
excType: Optional[Type[BaseException]],
|
||||
excValue: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.exited = True
|
||||
|
||||
|
||||
class DummyExit:
|
||||
"""
|
||||
Stub for L{_exit.exit} that remembers whether it's been called and, if it has,
|
||||
what arguments it was given.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.exited = False
|
||||
|
||||
def __call__(
|
||||
self, status: Union[int, ExitStatus], message: Optional[str] = None
|
||||
) -> None:
|
||||
assert not self.exited
|
||||
|
||||
self.status = status
|
||||
self.message = message
|
||||
self.exited = True
|
||||
|
||||
|
||||
class DummyKill:
|
||||
"""
|
||||
Stub for L{os.kill} that remembers whether it's been called and, if it has,
|
||||
what arguments it was given.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[int, int]] = []
|
||||
|
||||
def __call__(self, pid: int, sig: int) -> None:
|
||||
self.calls.append((pid, sig))
|
||||
|
||||
|
||||
class DummyStandardIO:
|
||||
"""
|
||||
Stub for L{sys} which provides L{StringIO} streams as stdout and stderr.
|
||||
"""
|
||||
|
||||
def __init__(self, stdout: TextIO, stderr: TextIO) -> None:
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
|
||||
|
||||
class DummyWarningsModule:
|
||||
"""
|
||||
Stub for L{warnings} which provides a C{showwarning} method that is a no-op.
|
||||
"""
|
||||
|
||||
def showwarning(*args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Do nothing.
|
||||
|
||||
@param args: ignored.
|
||||
@param kwargs: ignored.
|
||||
"""
|
||||
@@ -0,0 +1,420 @@
|
||||
# -*- test-case-name: twisted.application.test.test_service -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Service architecture for Twisted.
|
||||
|
||||
Services are arranged in a hierarchy. At the leafs of the hierarchy,
|
||||
the services which actually interact with the outside world are started.
|
||||
Services can be named or anonymous -- usually, they will be named if
|
||||
there is need to access them through the hierarchy (from a parent or
|
||||
a sibling).
|
||||
|
||||
Maintainer: Moshe Zadka
|
||||
"""
|
||||
|
||||
|
||||
from zope.interface import Attribute, Interface, implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.persisted import sob
|
||||
from twisted.plugin import IPlugin
|
||||
from twisted.python import components
|
||||
from twisted.python.reflect import namedAny
|
||||
|
||||
|
||||
class IServiceMaker(Interface):
|
||||
"""
|
||||
An object which can be used to construct services in a flexible
|
||||
way.
|
||||
|
||||
This interface should most often be implemented along with
|
||||
L{twisted.plugin.IPlugin}, and will most often be used by the
|
||||
'twistd' command.
|
||||
"""
|
||||
|
||||
tapname = Attribute(
|
||||
"A short string naming this Twisted plugin, for example 'web' or "
|
||||
"'pencil'. This name will be used as the subcommand of 'twistd'."
|
||||
)
|
||||
|
||||
description = Attribute(
|
||||
"A brief summary of the features provided by this "
|
||||
"Twisted application plugin."
|
||||
)
|
||||
|
||||
options = Attribute(
|
||||
"A C{twisted.python.usage.Options} subclass defining the "
|
||||
"configuration options for this application."
|
||||
)
|
||||
|
||||
def makeService(options):
|
||||
"""
|
||||
Create and return an object providing
|
||||
L{twisted.application.service.IService}.
|
||||
|
||||
@param options: A mapping (typically a C{dict} or
|
||||
L{twisted.python.usage.Options} instance) of configuration
|
||||
options to desired configuration values.
|
||||
"""
|
||||
|
||||
|
||||
@implementer(IPlugin, IServiceMaker)
|
||||
class ServiceMaker:
|
||||
"""
|
||||
Utility class to simplify the definition of L{IServiceMaker} plugins.
|
||||
"""
|
||||
|
||||
def __init__(self, name, module, description, tapname):
|
||||
self.name = name
|
||||
self.module = module
|
||||
self.description = description
|
||||
self.tapname = tapname
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
return namedAny(self.module).Options
|
||||
|
||||
@property
|
||||
def makeService(self):
|
||||
return namedAny(self.module).makeService
|
||||
|
||||
|
||||
class IService(Interface):
|
||||
"""
|
||||
A service.
|
||||
|
||||
Run start-up and shut-down code at the appropriate times.
|
||||
"""
|
||||
|
||||
name = Attribute("A C{str} which is the name of the service or C{None}.")
|
||||
|
||||
running = Attribute("A C{boolean} which indicates whether the service is running.")
|
||||
|
||||
parent = Attribute("An C{IServiceCollection} which is the parent or C{None}.")
|
||||
|
||||
def setName(name):
|
||||
"""
|
||||
Set the name of the service.
|
||||
|
||||
@type name: C{str}
|
||||
@raise RuntimeError: Raised if the service already has a parent.
|
||||
"""
|
||||
|
||||
def setServiceParent(parent):
|
||||
"""
|
||||
Set the parent of the service. This method is responsible for setting
|
||||
the C{parent} attribute on this service (the child service).
|
||||
|
||||
@type parent: L{IServiceCollection}
|
||||
@raise RuntimeError: Raised if the service already has a parent
|
||||
or if the service has a name and the parent already has a child
|
||||
by that name.
|
||||
"""
|
||||
|
||||
def disownServiceParent():
|
||||
"""
|
||||
Use this API to remove an L{IService} from an L{IServiceCollection}.
|
||||
|
||||
This method is used symmetrically with L{setServiceParent} in that it
|
||||
sets the C{parent} attribute on the child.
|
||||
|
||||
@rtype: L{Deferred<defer.Deferred>}
|
||||
@return: a L{Deferred<defer.Deferred>} which is triggered when the
|
||||
service has finished shutting down. If shutting down is immediate,
|
||||
a value can be returned (usually, L{None}).
|
||||
"""
|
||||
|
||||
def startService():
|
||||
"""
|
||||
Start the service.
|
||||
"""
|
||||
|
||||
def stopService():
|
||||
"""
|
||||
Stop the service.
|
||||
|
||||
@rtype: L{Deferred<defer.Deferred>}
|
||||
@return: a L{Deferred<defer.Deferred>} which is triggered when the
|
||||
service has finished shutting down. If shutting down is immediate,
|
||||
a value can be returned (usually, L{None}).
|
||||
"""
|
||||
|
||||
def privilegedStartService():
|
||||
"""
|
||||
Do preparation work for starting the service.
|
||||
|
||||
Here things which should be done before changing directory,
|
||||
root or shedding privileges are done.
|
||||
"""
|
||||
|
||||
|
||||
@implementer(IService)
|
||||
class Service:
|
||||
"""
|
||||
Base class for services.
|
||||
|
||||
Most services should inherit from this class. It handles the
|
||||
book-keeping responsibilities of starting and stopping, as well
|
||||
as not serializing this book-keeping information.
|
||||
"""
|
||||
|
||||
running = 0
|
||||
name = None
|
||||
parent = None
|
||||
|
||||
def __getstate__(self):
|
||||
dict = self.__dict__.copy()
|
||||
if "running" in dict:
|
||||
del dict["running"]
|
||||
return dict
|
||||
|
||||
def setName(self, name):
|
||||
if self.parent is not None:
|
||||
raise RuntimeError("cannot change name when parent exists")
|
||||
self.name = name
|
||||
|
||||
def setServiceParent(self, parent):
|
||||
if self.parent is not None:
|
||||
self.disownServiceParent()
|
||||
parent = IServiceCollection(parent, parent)
|
||||
self.parent = parent
|
||||
self.parent.addService(self)
|
||||
|
||||
def disownServiceParent(self):
|
||||
d = self.parent.removeService(self)
|
||||
self.parent = None
|
||||
return d
|
||||
|
||||
def privilegedStartService(self):
|
||||
pass
|
||||
|
||||
def startService(self):
|
||||
self.running = 1
|
||||
|
||||
def stopService(self):
|
||||
self.running = 0
|
||||
|
||||
|
||||
class IServiceCollection(Interface):
|
||||
"""
|
||||
Collection of services.
|
||||
|
||||
Contain several services, and manage their start-up/shut-down.
|
||||
Services can be accessed by name if they have a name, and it
|
||||
is always possible to iterate over them.
|
||||
"""
|
||||
|
||||
def getServiceNamed(name):
|
||||
"""
|
||||
Get the child service with a given name.
|
||||
|
||||
@type name: C{str}
|
||||
@rtype: L{IService}
|
||||
@raise KeyError: Raised if the service has no child with the
|
||||
given name.
|
||||
"""
|
||||
|
||||
def __iter__():
|
||||
"""
|
||||
Get an iterator over all child services.
|
||||
"""
|
||||
|
||||
def addService(service):
|
||||
"""
|
||||
Add a child service.
|
||||
|
||||
Only implementations of L{IService.setServiceParent} should use this
|
||||
method.
|
||||
|
||||
@type service: L{IService}
|
||||
@raise RuntimeError: Raised if the service has a child with
|
||||
the given name.
|
||||
"""
|
||||
|
||||
def removeService(service):
|
||||
"""
|
||||
Remove a child service.
|
||||
|
||||
Only implementations of L{IService.disownServiceParent} should
|
||||
use this method.
|
||||
|
||||
@type service: L{IService}
|
||||
@raise ValueError: Raised if the given service is not a child.
|
||||
@rtype: L{Deferred<defer.Deferred>}
|
||||
@return: a L{Deferred<defer.Deferred>} which is triggered when the
|
||||
service has finished shutting down. If shutting down is immediate,
|
||||
a value can be returned (usually, L{None}).
|
||||
"""
|
||||
|
||||
|
||||
@implementer(IServiceCollection)
|
||||
class MultiService(Service):
|
||||
"""
|
||||
Straightforward Service Container.
|
||||
|
||||
Hold a collection of services, and manage them in a simplistic
|
||||
way. No service will wait for another, but this object itself
|
||||
will not finish shutting down until all of its child services
|
||||
will finish.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.services = []
|
||||
self.namedServices = {}
|
||||
self.parent = None
|
||||
|
||||
def privilegedStartService(self):
|
||||
Service.privilegedStartService(self)
|
||||
for service in self:
|
||||
service.privilegedStartService()
|
||||
|
||||
def startService(self):
|
||||
Service.startService(self)
|
||||
for service in self:
|
||||
service.startService()
|
||||
|
||||
def stopService(self):
|
||||
Service.stopService(self)
|
||||
l = []
|
||||
services = list(self)
|
||||
services.reverse()
|
||||
for service in services:
|
||||
l.append(defer.maybeDeferred(service.stopService))
|
||||
return defer.DeferredList(l)
|
||||
|
||||
def getServiceNamed(self, name):
|
||||
return self.namedServices[name]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.services)
|
||||
|
||||
def addService(self, service):
|
||||
if service.name is not None:
|
||||
if service.name in self.namedServices:
|
||||
raise RuntimeError(
|
||||
"cannot have two services with same name" " '%s'" % service.name
|
||||
)
|
||||
self.namedServices[service.name] = service
|
||||
self.services.append(service)
|
||||
if self.running:
|
||||
# It may be too late for that, but we will do our best
|
||||
service.privilegedStartService()
|
||||
service.startService()
|
||||
|
||||
def removeService(self, service):
|
||||
if service.name:
|
||||
del self.namedServices[service.name]
|
||||
self.services.remove(service)
|
||||
if self.running:
|
||||
# Returning this so as not to lose information from the
|
||||
# MultiService.stopService deferred.
|
||||
return service.stopService()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class IProcess(Interface):
|
||||
"""
|
||||
Process running parameters.
|
||||
|
||||
Represents parameters for how processes should be run.
|
||||
"""
|
||||
|
||||
processName = Attribute(
|
||||
"""
|
||||
A C{str} giving the name the process should have in ps (or L{None}
|
||||
to leave the name alone).
|
||||
"""
|
||||
)
|
||||
|
||||
uid = Attribute(
|
||||
"""
|
||||
An C{int} giving the user id as which the process should run (or
|
||||
L{None} to leave the UID alone).
|
||||
"""
|
||||
)
|
||||
|
||||
gid = Attribute(
|
||||
"""
|
||||
An C{int} giving the group id as which the process should run (or
|
||||
L{None} to leave the GID alone).
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@implementer(IProcess)
|
||||
class Process:
|
||||
"""
|
||||
Process running parameters.
|
||||
|
||||
Sets up uid/gid in the constructor, and has a default
|
||||
of L{None} as C{processName}.
|
||||
"""
|
||||
|
||||
processName = None
|
||||
|
||||
def __init__(self, uid=None, gid=None):
|
||||
"""
|
||||
Set uid and gid.
|
||||
|
||||
@param uid: The user ID as whom to execute the process. If
|
||||
this is L{None}, no attempt will be made to change the UID.
|
||||
|
||||
@param gid: The group ID as whom to execute the process. If
|
||||
this is L{None}, no attempt will be made to change the GID.
|
||||
"""
|
||||
self.uid = uid
|
||||
self.gid = gid
|
||||
|
||||
|
||||
def Application(name, uid=None, gid=None):
|
||||
"""
|
||||
Return a compound class.
|
||||
|
||||
Return an object supporting the L{IService}, L{IServiceCollection},
|
||||
L{IProcess} and L{sob.IPersistable} interfaces, with the given
|
||||
parameters. Always access the return value by explicit casting to
|
||||
one of the interfaces.
|
||||
"""
|
||||
ret = components.Componentized()
|
||||
availableComponents = [MultiService(), Process(uid, gid), sob.Persistent(ret, name)]
|
||||
|
||||
for comp in availableComponents:
|
||||
ret.addComponent(comp, ignoreClass=1)
|
||||
IService(ret).setName(name)
|
||||
return ret
|
||||
|
||||
|
||||
def loadApplication(filename, kind, passphrase=None):
|
||||
"""
|
||||
Load Application from a given file.
|
||||
|
||||
The serialization format it was saved in should be given as
|
||||
C{kind}, and is one of C{pickle}, C{source}, C{xml} or C{python}. If
|
||||
C{passphrase} is given, the application was encrypted with the
|
||||
given passphrase.
|
||||
|
||||
@type filename: C{str}
|
||||
@type kind: C{str}
|
||||
@type passphrase: C{str}
|
||||
"""
|
||||
if kind == "python":
|
||||
application = sob.loadValueFromFile(filename, "application")
|
||||
else:
|
||||
application = sob.load(filename, kind)
|
||||
return application
|
||||
|
||||
|
||||
__all__ = [
|
||||
"IServiceMaker",
|
||||
"IService",
|
||||
"Service",
|
||||
"IServiceCollection",
|
||||
"MultiService",
|
||||
"IProcess",
|
||||
"Process",
|
||||
"Application",
|
||||
"loadApplication",
|
||||
]
|
||||
@@ -0,0 +1,83 @@
|
||||
# -*- test-case-name: twisted.test.test_strports -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Construct listening port services from a simple string description.
|
||||
|
||||
@see: L{twisted.internet.endpoints.serverFromString}
|
||||
@see: L{twisted.internet.endpoints.clientFromString}
|
||||
"""
|
||||
from typing import Optional, cast
|
||||
|
||||
from twisted.application.internet import StreamServerEndpointService
|
||||
from twisted.internet import endpoints, interfaces
|
||||
|
||||
|
||||
def _getReactor() -> interfaces.IReactorCore:
|
||||
from twisted.internet import reactor
|
||||
|
||||
return cast(interfaces.IReactorCore, reactor)
|
||||
|
||||
|
||||
def service(
|
||||
description: str,
|
||||
factory: interfaces.IProtocolFactory,
|
||||
reactor: Optional[interfaces.IReactorCore] = None,
|
||||
) -> StreamServerEndpointService:
|
||||
"""
|
||||
Return the service corresponding to a description.
|
||||
|
||||
@param description: The description of the listening port, in the syntax
|
||||
described by L{twisted.internet.endpoints.serverFromString}.
|
||||
@type description: C{str}
|
||||
|
||||
@param factory: The protocol factory which will build protocols for
|
||||
connections to this service.
|
||||
@type factory: L{twisted.internet.interfaces.IProtocolFactory}
|
||||
|
||||
@rtype: C{twisted.application.service.IService}
|
||||
@return: the service corresponding to a description of a reliable stream
|
||||
server.
|
||||
|
||||
@see: L{twisted.internet.endpoints.serverFromString}
|
||||
"""
|
||||
if reactor is None:
|
||||
reactor = _getReactor()
|
||||
|
||||
svc = StreamServerEndpointService(
|
||||
endpoints.serverFromString(reactor, description), factory
|
||||
)
|
||||
svc._raiseSynchronously = True
|
||||
return svc
|
||||
|
||||
|
||||
def listen(
|
||||
description: str, factory: interfaces.IProtocolFactory
|
||||
) -> interfaces.IListeningPort:
|
||||
"""
|
||||
Listen on a port corresponding to a description.
|
||||
|
||||
@param description: The description of the connecting port, in the syntax
|
||||
described by L{twisted.internet.endpoints.serverFromString}.
|
||||
@type description: L{str}
|
||||
|
||||
@param factory: The protocol factory which will build protocols on
|
||||
connection.
|
||||
@type factory: L{twisted.internet.interfaces.IProtocolFactory}
|
||||
|
||||
@rtype: L{twisted.internet.interfaces.IListeningPort}
|
||||
@return: the port corresponding to a description of a reliable virtual
|
||||
circuit server.
|
||||
|
||||
@see: L{twisted.internet.endpoints.serverFromString}
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
|
||||
name, args, kw = endpoints._parseServer(description, factory)
|
||||
return cast(
|
||||
interfaces.IListeningPort, getattr(reactor, "listen" + name)(*args, **kw)
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["service", "listen"]
|
||||
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.internet.application}.
|
||||
"""
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,175 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.application.service}.
|
||||
"""
|
||||
|
||||
|
||||
from zope.interface import implementer
|
||||
from zope.interface.exceptions import BrokenImplementation
|
||||
from zope.interface.verify import verifyObject
|
||||
|
||||
from twisted.application.service import (
|
||||
Application,
|
||||
IProcess,
|
||||
IService,
|
||||
IServiceCollection,
|
||||
Service,
|
||||
)
|
||||
from twisted.persisted.sob import IPersistable
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
@implementer(IService)
|
||||
class AlmostService:
|
||||
"""
|
||||
Implement IService in a way that can fail.
|
||||
|
||||
In general, classes should maintain invariants that adhere
|
||||
to the interfaces that they claim to implement --
|
||||
otherwise, it is a bug.
|
||||
|
||||
This is a buggy class -- the IService implementation is fragile,
|
||||
and several methods will break it. These bugs are intentional,
|
||||
as the tests trigger them -- and then check that the class,
|
||||
indeed, no longer complies with the interface (IService)
|
||||
that it claims to comply with.
|
||||
|
||||
Since the verification will, by definition, only fail on buggy classes --
|
||||
in other words, those which do not actually support the interface they
|
||||
claim to support, we have to write a buggy class to properly verify
|
||||
the interface.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, parent: IServiceCollection, running: bool) -> None:
|
||||
self.name = name
|
||||
self.parent = parent
|
||||
self.running = running
|
||||
|
||||
def makeInvalidByDeletingName(self) -> None:
|
||||
"""
|
||||
Probably not a wise method to call.
|
||||
|
||||
This method removes the :code:`name` attribute,
|
||||
which has to exist in IService classes.
|
||||
"""
|
||||
del self.name
|
||||
|
||||
def makeInvalidByDeletingParent(self) -> None:
|
||||
"""
|
||||
Probably not a wise method to call.
|
||||
|
||||
This method removes the :code:`parent` attribute,
|
||||
which has to exist in IService classes.
|
||||
"""
|
||||
del self.parent
|
||||
|
||||
def makeInvalidByDeletingRunning(self) -> None:
|
||||
"""
|
||||
Probably not a wise method to call.
|
||||
|
||||
This method removes the :code:`running` attribute,
|
||||
which has to exist in IService classes.
|
||||
"""
|
||||
del self.running
|
||||
|
||||
def setName(self, name: object) -> None:
|
||||
"""
|
||||
See L{twisted.application.service.IService}.
|
||||
|
||||
@param name: ignored
|
||||
"""
|
||||
|
||||
def setServiceParent(self, parent: object) -> None:
|
||||
"""
|
||||
See L{twisted.application.service.IService}.
|
||||
|
||||
@param parent: ignored
|
||||
"""
|
||||
|
||||
def disownServiceParent(self) -> None:
|
||||
"""
|
||||
See L{twisted.application.service.IService}.
|
||||
"""
|
||||
|
||||
def privilegedStartService(self) -> None:
|
||||
"""
|
||||
See L{twisted.application.service.IService}.
|
||||
"""
|
||||
|
||||
def startService(self) -> None:
|
||||
"""
|
||||
See L{twisted.application.service.IService}.
|
||||
"""
|
||||
|
||||
def stopService(self) -> None:
|
||||
"""
|
||||
See L{twisted.application.service.IService}.
|
||||
"""
|
||||
|
||||
|
||||
class ServiceInterfaceTests(TestCase):
|
||||
"""
|
||||
Tests for L{twisted.application.service.IService} implementation.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Build something that implements IService.
|
||||
"""
|
||||
self.almostService = AlmostService(parent=None, running=False, name=None) # type: ignore[arg-type]
|
||||
|
||||
def test_realService(self) -> None:
|
||||
"""
|
||||
Service implements IService.
|
||||
"""
|
||||
myService = Service()
|
||||
verifyObject(IService, myService)
|
||||
|
||||
def test_hasAll(self) -> None:
|
||||
"""
|
||||
AlmostService implements IService.
|
||||
"""
|
||||
verifyObject(IService, self.almostService)
|
||||
|
||||
def test_noName(self) -> None:
|
||||
"""
|
||||
AlmostService with no name does not implement IService.
|
||||
"""
|
||||
self.almostService.makeInvalidByDeletingName()
|
||||
with self.assertRaises(BrokenImplementation):
|
||||
verifyObject(IService, self.almostService)
|
||||
|
||||
def test_noParent(self) -> None:
|
||||
"""
|
||||
AlmostService with no parent does not implement IService.
|
||||
"""
|
||||
self.almostService.makeInvalidByDeletingParent()
|
||||
with self.assertRaises(BrokenImplementation):
|
||||
verifyObject(IService, self.almostService)
|
||||
|
||||
def test_noRunning(self) -> None:
|
||||
"""
|
||||
AlmostService with no running does not implement IService.
|
||||
"""
|
||||
self.almostService.makeInvalidByDeletingRunning()
|
||||
with self.assertRaises(BrokenImplementation):
|
||||
verifyObject(IService, self.almostService)
|
||||
|
||||
|
||||
class ApplicationTests(TestCase):
|
||||
"""
|
||||
Tests for L{twisted.application.service.Application}.
|
||||
"""
|
||||
|
||||
def test_applicationComponents(self) -> None:
|
||||
"""
|
||||
Check L{twisted.application.service.Application} instantiation.
|
||||
"""
|
||||
app = Application("app-name")
|
||||
|
||||
self.assertTrue(verifyObject(IService, IService(app)))
|
||||
self.assertTrue(verifyObject(IServiceCollection, IServiceCollection(app)))
|
||||
self.assertTrue(verifyObject(IProcess, IProcess(app)))
|
||||
self.assertTrue(verifyObject(IPersistable, IPersistable(app)))
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- test-case-name: twisted.application.twist.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
C{twist} command line tool.
|
||||
"""
|
||||
@@ -0,0 +1,207 @@
|
||||
# -*- test-case-name: twisted.application.twist.test.test_options -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Command line options for C{twist}.
|
||||
"""
|
||||
|
||||
import typing
|
||||
from sys import stderr, stdout
|
||||
from textwrap import dedent
|
||||
from typing import Callable, Iterable, Mapping, Optional, Sequence, Tuple, cast
|
||||
|
||||
from twisted.copyright import version
|
||||
from twisted.internet.interfaces import IReactorCore
|
||||
from twisted.logger import (
|
||||
InvalidLogLevelError,
|
||||
LogLevel,
|
||||
jsonFileLogObserver,
|
||||
textFileLogObserver,
|
||||
)
|
||||
from twisted.plugin import getPlugins
|
||||
from twisted.python.usage import Options, UsageError
|
||||
from ..reactors import NoSuchReactor, getReactorTypes, installReactor
|
||||
from ..runner._exit import ExitStatus, exit
|
||||
from ..service import IServiceMaker
|
||||
|
||||
openFile = open
|
||||
|
||||
|
||||
def _update_doc(opt: Callable[["TwistOptions", str], None], **kwargs: str) -> None:
|
||||
"""
|
||||
Update the docstring of a method that implements an option.
|
||||
The string is dedented and the given keyword arguments are substituted.
|
||||
"""
|
||||
opt.__doc__ = dedent(opt.__doc__ or "").format(**kwargs)
|
||||
|
||||
|
||||
class TwistOptions(Options):
|
||||
"""
|
||||
Command line options for C{twist}.
|
||||
"""
|
||||
|
||||
defaultReactorName = "default"
|
||||
defaultLogLevel = LogLevel.info
|
||||
|
||||
def __init__(self) -> None:
|
||||
Options.__init__(self)
|
||||
|
||||
self["reactorName"] = self.defaultReactorName
|
||||
self["logLevel"] = self.defaultLogLevel
|
||||
self["logFile"] = stdout
|
||||
# An empty long description is explicitly set here as otherwise
|
||||
# when executing from distributed trial twisted.python.usage will
|
||||
# pull the description from `__main__` which is another entry point.
|
||||
self.longdesc = ""
|
||||
|
||||
def getSynopsis(self) -> str:
|
||||
return f"{Options.getSynopsis(self)} plugin [plugin_options]"
|
||||
|
||||
def opt_version(self) -> "typing.NoReturn":
|
||||
"""
|
||||
Print version and exit.
|
||||
"""
|
||||
exit(ExitStatus.EX_OK, f"{version}")
|
||||
|
||||
def opt_reactor(self, name: str) -> None:
|
||||
"""
|
||||
The name of the reactor to use.
|
||||
(options: {options})
|
||||
"""
|
||||
# Actually actually actually install the reactor right at this very
|
||||
# moment, before any other code (for example, a sub-command plugin)
|
||||
# runs and accidentally imports and installs the default reactor.
|
||||
try:
|
||||
self["reactor"] = self.installReactor(name)
|
||||
except NoSuchReactor:
|
||||
raise UsageError(f"Unknown reactor: {name}")
|
||||
else:
|
||||
self["reactorName"] = name
|
||||
|
||||
_update_doc(
|
||||
opt_reactor,
|
||||
options=", ".join(f'"{rt.shortName}"' for rt in getReactorTypes()),
|
||||
)
|
||||
|
||||
def installReactor(self, name: str) -> IReactorCore:
|
||||
"""
|
||||
Install the reactor.
|
||||
"""
|
||||
if name == self.defaultReactorName:
|
||||
from twisted.internet import reactor
|
||||
|
||||
return cast(IReactorCore, reactor)
|
||||
else:
|
||||
return installReactor(name)
|
||||
|
||||
def opt_log_level(self, levelName: str) -> None:
|
||||
"""
|
||||
Set default log level.
|
||||
(options: {options}; default: "{default}")
|
||||
"""
|
||||
try:
|
||||
self["logLevel"] = LogLevel.levelWithName(levelName)
|
||||
except InvalidLogLevelError:
|
||||
raise UsageError(f"Invalid log level: {levelName}")
|
||||
|
||||
_update_doc(
|
||||
opt_log_level,
|
||||
options=", ".join(
|
||||
f'"{constant.name}"' for constant in LogLevel.iterconstants()
|
||||
),
|
||||
default=defaultLogLevel.name,
|
||||
)
|
||||
|
||||
def opt_log_file(self, fileName: str) -> None:
|
||||
"""
|
||||
Log to file. ("-" for stdout, "+" for stderr; default: "-")
|
||||
"""
|
||||
if fileName == "-":
|
||||
self["logFile"] = stdout
|
||||
return
|
||||
|
||||
if fileName == "+":
|
||||
self["logFile"] = stderr
|
||||
return
|
||||
|
||||
try:
|
||||
self["logFile"] = openFile(fileName, "a")
|
||||
except OSError as e:
|
||||
exit(
|
||||
ExitStatus.EX_IOERR,
|
||||
f"Unable to open log file {fileName!r}: {e}",
|
||||
)
|
||||
|
||||
def opt_log_format(self, format: str) -> None:
|
||||
"""
|
||||
Log file format.
|
||||
(options: "text", "json"; default: "text" if the log file is a tty,
|
||||
otherwise "json")
|
||||
"""
|
||||
format = format.lower()
|
||||
|
||||
if format == "text":
|
||||
self["fileLogObserverFactory"] = textFileLogObserver
|
||||
elif format == "json":
|
||||
self["fileLogObserverFactory"] = jsonFileLogObserver
|
||||
else:
|
||||
raise UsageError(f"Invalid log format: {format}")
|
||||
self["logFormat"] = format
|
||||
|
||||
_update_doc(opt_log_format)
|
||||
|
||||
def selectDefaultLogObserver(self) -> None:
|
||||
"""
|
||||
Set C{fileLogObserverFactory} to the default appropriate for the
|
||||
chosen C{logFile}.
|
||||
"""
|
||||
if "fileLogObserverFactory" not in self:
|
||||
logFile = self["logFile"]
|
||||
|
||||
if hasattr(logFile, "isatty") and logFile.isatty():
|
||||
self["fileLogObserverFactory"] = textFileLogObserver
|
||||
self["logFormat"] = "text"
|
||||
else:
|
||||
self["fileLogObserverFactory"] = jsonFileLogObserver
|
||||
self["logFormat"] = "json"
|
||||
|
||||
def parseOptions(self, options: Optional[Sequence[str]] = None) -> None:
|
||||
self.selectDefaultLogObserver()
|
||||
|
||||
Options.parseOptions(self, options=options)
|
||||
|
||||
if "reactor" not in self:
|
||||
self["reactor"] = self.installReactor(self["reactorName"])
|
||||
|
||||
@property
|
||||
def plugins(self) -> Mapping[str, IServiceMaker]:
|
||||
if "plugins" not in self:
|
||||
plugins = {}
|
||||
for plugin in getPlugins(IServiceMaker):
|
||||
plugins[plugin.tapname] = plugin
|
||||
self["plugins"] = plugins
|
||||
|
||||
return cast(Mapping[str, IServiceMaker], self["plugins"])
|
||||
|
||||
@property
|
||||
def subCommands(
|
||||
self,
|
||||
) -> Iterable[Tuple[str, None, Callable[[IServiceMaker], Options], str]]:
|
||||
plugins = self.plugins
|
||||
for name in sorted(plugins):
|
||||
plugin = plugins[name]
|
||||
|
||||
# Don't pass plugin.options along in order to avoid resolving the
|
||||
# options attribute right away, in case it's a property with a
|
||||
# non-trivial getter (eg, one which imports modules).
|
||||
def options(plugin: IServiceMaker = plugin) -> Options:
|
||||
return cast(Options, plugin.options())
|
||||
|
||||
yield (plugin.tapname, None, options, plugin.description)
|
||||
|
||||
def postOptions(self) -> None:
|
||||
Options.postOptions(self)
|
||||
|
||||
if self.subCommand is None:
|
||||
raise UsageError("No plugin specified.")
|
||||
@@ -0,0 +1,114 @@
|
||||
# -*- test-case-name: twisted.application.twist.test.test_twist -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Run a Twisted application.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Sequence
|
||||
|
||||
from twisted.application.app import _exitWithSignal
|
||||
from twisted.internet.interfaces import IReactorCore, _ISupportsExitSignalCapturing
|
||||
from twisted.python.usage import Options, UsageError
|
||||
from ..runner._exit import ExitStatus, exit
|
||||
from ..runner._runner import Runner
|
||||
from ..service import Application, IService, IServiceMaker
|
||||
from ._options import TwistOptions
|
||||
|
||||
|
||||
class Twist:
|
||||
"""
|
||||
Run a Twisted application.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def options(argv: Sequence[str]) -> TwistOptions:
|
||||
"""
|
||||
Parse command line options.
|
||||
|
||||
@param argv: Command line arguments.
|
||||
@return: The parsed options.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
|
||||
try:
|
||||
options.parseOptions(argv[1:])
|
||||
except UsageError as e:
|
||||
exit(ExitStatus.EX_USAGE, f"Error: {e}\n\n{options}")
|
||||
|
||||
return options
|
||||
|
||||
@staticmethod
|
||||
def service(plugin: IServiceMaker, options: Options) -> IService:
|
||||
"""
|
||||
Create the application service.
|
||||
|
||||
@param plugin: The name of the plugin that implements the service
|
||||
application to run.
|
||||
@param options: Options to pass to the application.
|
||||
@return: The created application service.
|
||||
"""
|
||||
service = plugin.makeService(options)
|
||||
application = Application(plugin.tapname)
|
||||
service.setServiceParent(application)
|
||||
|
||||
return IService(application)
|
||||
|
||||
@staticmethod
|
||||
def startService(reactor: IReactorCore, service: IService) -> None:
|
||||
"""
|
||||
Start the application service.
|
||||
|
||||
@param reactor: The reactor to run the service with.
|
||||
@param service: The application service to run.
|
||||
"""
|
||||
service.startService()
|
||||
|
||||
# Ask the reactor to stop the service before shutting down
|
||||
reactor.addSystemEventTrigger("before", "shutdown", service.stopService)
|
||||
|
||||
@staticmethod
|
||||
def run(twistOptions: TwistOptions) -> None:
|
||||
"""
|
||||
Run the application service.
|
||||
|
||||
@param twistOptions: Command line options to convert to runner
|
||||
arguments.
|
||||
"""
|
||||
runner = Runner(
|
||||
reactor=twistOptions["reactor"],
|
||||
defaultLogLevel=twistOptions["logLevel"],
|
||||
logFile=twistOptions["logFile"],
|
||||
fileLogObserverFactory=twistOptions["fileLogObserverFactory"],
|
||||
)
|
||||
runner.run()
|
||||
reactor = twistOptions["reactor"]
|
||||
if _ISupportsExitSignalCapturing.providedBy(reactor):
|
||||
if reactor._exitSignal is not None:
|
||||
_exitWithSignal(reactor._exitSignal)
|
||||
|
||||
@classmethod
|
||||
def main(cls, argv: Sequence[str] = sys.argv) -> None:
|
||||
"""
|
||||
Executable entry point for L{Twist}.
|
||||
Processes options and run a twisted reactor with a service.
|
||||
|
||||
@param argv: Command line arguments.
|
||||
@type argv: L{list}
|
||||
"""
|
||||
options = cls.options(argv)
|
||||
|
||||
reactor = options["reactor"]
|
||||
# If subCommand is None, TwistOptions.parseOptions() raises UsageError
|
||||
# and Twist.options() will exit the runner, so we'll never get here.
|
||||
subCommand = options.subCommand
|
||||
assert subCommand is not None
|
||||
service = cls.service(
|
||||
plugin=options.plugins[subCommand],
|
||||
options=options.subOptions,
|
||||
)
|
||||
|
||||
cls.startService(reactor, service)
|
||||
cls.run(options)
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- test-case-name: twisted.application.twist.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.application.twist}.
|
||||
"""
|
||||
@@ -0,0 +1,355 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.application.twist._options}.
|
||||
"""
|
||||
|
||||
from sys import stderr, stdout
|
||||
from typing import Callable, Dict, List, Optional, TextIO, Tuple
|
||||
|
||||
import twisted.trial.unittest
|
||||
from twisted.copyright import version
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.interfaces import IReactorCore
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
from twisted.logger import (
|
||||
FileLogObserver,
|
||||
LogLevel,
|
||||
jsonFileLogObserver,
|
||||
textFileLogObserver,
|
||||
)
|
||||
from twisted.python.usage import UsageError
|
||||
from ...reactors import NoSuchReactor
|
||||
from ...runner._exit import ExitStatus
|
||||
from ...runner.test.test_runner import DummyExit
|
||||
from ...service import ServiceMaker
|
||||
from ...twist import _options
|
||||
from .._options import TwistOptions
|
||||
|
||||
|
||||
class OptionsTests(twisted.trial.unittest.TestCase):
|
||||
"""
|
||||
Tests for L{TwistOptions}.
|
||||
"""
|
||||
|
||||
def patchExit(self) -> None:
|
||||
"""
|
||||
Patch L{_twist.exit} so we can capture usage and prevent actual exits.
|
||||
"""
|
||||
self.exit = DummyExit()
|
||||
self.patch(_options, "exit", self.exit)
|
||||
|
||||
def patchOpen(self) -> None:
|
||||
"""
|
||||
Patch L{_options.open} so we can capture usage and prevent actual opens.
|
||||
"""
|
||||
self.opened: List[Tuple[str, Optional[str]]] = []
|
||||
|
||||
def fakeOpen(name: str, mode: Optional[str] = None) -> TextIO:
|
||||
if name == "nocanopen":
|
||||
raise OSError(None, None, name)
|
||||
|
||||
self.opened.append((name, mode))
|
||||
return NotImplemented
|
||||
|
||||
self.patch(_options, "openFile", fakeOpen)
|
||||
|
||||
def patchInstallReactor(self) -> None:
|
||||
"""
|
||||
Patch C{_options.installReactor} so we can capture usage and prevent
|
||||
actual installs.
|
||||
"""
|
||||
self.installedReactors: Dict[str, IReactorCore] = {}
|
||||
|
||||
def installReactor(name: str) -> IReactorCore:
|
||||
if name != "fusion":
|
||||
raise NoSuchReactor()
|
||||
|
||||
reactor = MemoryReactor()
|
||||
self.installedReactors[name] = reactor
|
||||
return reactor
|
||||
|
||||
self.patch(_options, "installReactor", installReactor)
|
||||
|
||||
def test_synopsis(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.getSynopsis} appends arguments.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
|
||||
self.assertTrue(options.getSynopsis().endswith(" plugin [plugin_options]"))
|
||||
|
||||
def test_version(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.opt_version} exits with L{ExitStatus.EX_OK} and prints
|
||||
the version.
|
||||
"""
|
||||
self.patchExit()
|
||||
|
||||
options = TwistOptions()
|
||||
options.opt_version()
|
||||
|
||||
self.assertEquals(self.exit.status, ExitStatus.EX_OK) # type: ignore[unreachable]
|
||||
self.assertEquals(self.exit.message, version)
|
||||
|
||||
def test_reactor(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.installReactor} installs the chosen reactor and sets
|
||||
the reactor name.
|
||||
"""
|
||||
self.patchInstallReactor()
|
||||
|
||||
options = TwistOptions()
|
||||
options.opt_reactor("fusion")
|
||||
|
||||
self.assertEqual(set(self.installedReactors), {"fusion"})
|
||||
self.assertEquals(options["reactorName"], "fusion")
|
||||
|
||||
def test_installCorrectReactor(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.installReactor} installs the chosen reactor after the
|
||||
command line options have been parsed.
|
||||
"""
|
||||
self.patchInstallReactor()
|
||||
|
||||
options = TwistOptions()
|
||||
options.subCommand = "test-subcommand"
|
||||
options.parseOptions(["--reactor=fusion"])
|
||||
|
||||
self.assertEqual(set(self.installedReactors), {"fusion"})
|
||||
|
||||
def test_installReactorBogus(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.installReactor} raises UsageError if an unknown reactor
|
||||
is specified.
|
||||
"""
|
||||
self.patchInstallReactor()
|
||||
|
||||
options = TwistOptions()
|
||||
self.assertRaises(UsageError, options.opt_reactor, "coal")
|
||||
|
||||
def test_installReactorDefault(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.installReactor} returns the currently installed reactor
|
||||
when the default reactor name is specified.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
self.assertIdentical(reactor, options.installReactor("default"))
|
||||
|
||||
def test_logLevelValid(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.opt_log_level} sets the corresponding log level.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
options.opt_log_level("warn")
|
||||
|
||||
self.assertIdentical(options["logLevel"], LogLevel.warn)
|
||||
|
||||
def test_logLevelInvalid(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.opt_log_level} with an invalid log level name raises
|
||||
UsageError.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
|
||||
self.assertRaises(UsageError, options.opt_log_level, "cheese")
|
||||
|
||||
def _testLogFile(self, name: str, expectedStream: TextIO) -> None:
|
||||
"""
|
||||
Set log file name and check the selected output stream.
|
||||
|
||||
@param name: The name of the file.
|
||||
@param expectedStream: The expected stream.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
options.opt_log_file(name)
|
||||
|
||||
self.assertIdentical(options["logFile"], expectedStream)
|
||||
|
||||
def test_logFileStdout(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.opt_log_file} given C{"-"} as a file name uses stdout.
|
||||
"""
|
||||
self._testLogFile("-", stdout)
|
||||
|
||||
def test_logFileStderr(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.opt_log_file} given C{"+"} as a file name uses stderr.
|
||||
"""
|
||||
self._testLogFile("+", stderr)
|
||||
|
||||
def test_logFileNamed(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.opt_log_file} opens the given file name in append mode.
|
||||
"""
|
||||
self.patchOpen()
|
||||
|
||||
options = TwistOptions()
|
||||
options.opt_log_file("mylog")
|
||||
|
||||
self.assertEqual([("mylog", "a")], self.opened)
|
||||
|
||||
def test_logFileCantOpen(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.opt_log_file} exits with L{ExitStatus.EX_IOERR} if
|
||||
unable to open the log file due to an L{EnvironmentError}.
|
||||
"""
|
||||
self.patchExit()
|
||||
self.patchOpen()
|
||||
|
||||
options = TwistOptions()
|
||||
options.opt_log_file("nocanopen")
|
||||
|
||||
self.assertEquals(self.exit.status, ExitStatus.EX_IOERR)
|
||||
self.assertIsNotNone(self.exit.message)
|
||||
self.assertTrue(
|
||||
self.exit.message.startswith( # type: ignore[union-attr]
|
||||
"Unable to open log file 'nocanopen': "
|
||||
)
|
||||
)
|
||||
|
||||
def _testLogFormat(
|
||||
self, format: str, expectedObserverFactory: Callable[[TextIO], FileLogObserver]
|
||||
) -> None:
|
||||
"""
|
||||
Set log file format and check the selected observer factory.
|
||||
|
||||
@param format: The format of the file.
|
||||
@param expectedObserverFactory: The expected observer factory.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
options.opt_log_format(format)
|
||||
|
||||
self.assertIdentical(options["fileLogObserverFactory"], expectedObserverFactory)
|
||||
self.assertEqual(options["logFormat"], format)
|
||||
|
||||
def test_logFormatText(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.opt_log_format} given C{"text"} uses a
|
||||
L{textFileLogObserver}.
|
||||
"""
|
||||
self._testLogFormat("text", textFileLogObserver)
|
||||
|
||||
def test_logFormatJSON(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.opt_log_format} given C{"text"} uses a
|
||||
L{textFileLogObserver}.
|
||||
"""
|
||||
self._testLogFormat("json", jsonFileLogObserver)
|
||||
|
||||
def test_logFormatInvalid(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.opt_log_format} given an invalid format name raises
|
||||
L{UsageError}.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
|
||||
self.assertRaises(UsageError, options.opt_log_format, "frommage")
|
||||
|
||||
def test_selectDefaultLogObserverNoOverride(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.selectDefaultLogObserver} will not override an already
|
||||
selected observer.
|
||||
"""
|
||||
self.patchOpen()
|
||||
|
||||
options = TwistOptions()
|
||||
options.opt_log_format("text") # Ask for text
|
||||
options.opt_log_file("queso") # File, not a tty
|
||||
options.selectDefaultLogObserver()
|
||||
|
||||
# Because we didn't select a file that is a tty, the default is JSON,
|
||||
# but since we asked for text, we should get text.
|
||||
self.assertIdentical(options["fileLogObserverFactory"], textFileLogObserver)
|
||||
self.assertEqual(options["logFormat"], "text")
|
||||
|
||||
def test_selectDefaultLogObserverDefaultWithTTY(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.selectDefaultLogObserver} will not override an already
|
||||
selected observer.
|
||||
"""
|
||||
|
||||
class TTYFile:
|
||||
def isatty(self) -> bool:
|
||||
return True
|
||||
|
||||
# stdout may not be a tty, so let's make sure it thinks it is
|
||||
self.patch(_options, "stdout", TTYFile())
|
||||
|
||||
options = TwistOptions()
|
||||
options.opt_log_file("-") # stdout, a tty
|
||||
options.selectDefaultLogObserver()
|
||||
|
||||
self.assertIdentical(options["fileLogObserverFactory"], textFileLogObserver)
|
||||
self.assertEqual(options["logFormat"], "text")
|
||||
|
||||
def test_selectDefaultLogObserverDefaultWithoutTTY(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.selectDefaultLogObserver} will not override an already
|
||||
selected observer.
|
||||
"""
|
||||
self.patchOpen()
|
||||
|
||||
options = TwistOptions()
|
||||
options.opt_log_file("queso") # File, not a tty
|
||||
options.selectDefaultLogObserver()
|
||||
|
||||
self.assertIdentical(options["fileLogObserverFactory"], jsonFileLogObserver)
|
||||
self.assertEqual(options["logFormat"], "json")
|
||||
|
||||
def test_pluginsType(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.plugins} is a mapping of available plug-ins.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
plugins = options.plugins
|
||||
|
||||
for name in plugins:
|
||||
self.assertIsInstance(name, str)
|
||||
self.assertIsInstance(plugins[name], ServiceMaker)
|
||||
|
||||
def test_pluginsIncludeWeb(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.plugins} includes a C{"web"} plug-in.
|
||||
This is an attempt to verify that something we expect to be in the list
|
||||
is in there without enumerating all of the built-in plug-ins.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
|
||||
self.assertIn("web", options.plugins)
|
||||
|
||||
def test_subCommandsType(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.subCommands} is an iterable of tuples as expected by
|
||||
L{twisted.python.usage.Options}.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
|
||||
for name, shortcut, parser, doc in options.subCommands:
|
||||
self.assertIsInstance(name, str)
|
||||
self.assertIdentical(shortcut, None)
|
||||
self.assertTrue(callable(parser))
|
||||
self.assertIsInstance(doc, str)
|
||||
|
||||
def test_subCommandsIncludeWeb(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.subCommands} includes a sub-command for every plug-in.
|
||||
"""
|
||||
options = TwistOptions()
|
||||
|
||||
plugins = set(options.plugins)
|
||||
subCommands = {name for name, shortcut, parser, doc in options.subCommands}
|
||||
|
||||
self.assertEqual(subCommands, plugins)
|
||||
|
||||
def test_postOptionsNoSubCommand(self) -> None:
|
||||
"""
|
||||
L{TwistOptions.postOptions} raises L{UsageError} is it has no
|
||||
sub-command.
|
||||
"""
|
||||
self.patchInstallReactor()
|
||||
|
||||
options = TwistOptions()
|
||||
|
||||
self.assertRaises(UsageError, options.postOptions)
|
||||
@@ -0,0 +1,256 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.application.twist._twist}.
|
||||
"""
|
||||
|
||||
from sys import stdout
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import twisted.trial.unittest
|
||||
from twisted.internet.interfaces import IReactorCore
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
from twisted.logger import LogLevel, jsonFileLogObserver
|
||||
from twisted.test.test_twistd import SignalCapturingMemoryReactor
|
||||
from ...runner._exit import ExitStatus
|
||||
from ...runner._runner import Runner
|
||||
from ...runner.test.test_runner import DummyExit
|
||||
from ...service import IService, MultiService
|
||||
from ...twist import _twist
|
||||
from .._options import TwistOptions
|
||||
from .._twist import Twist
|
||||
|
||||
|
||||
class TwistTests(twisted.trial.unittest.TestCase):
|
||||
"""
|
||||
Tests for L{Twist}.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.patchInstallReactor()
|
||||
|
||||
def patchExit(self) -> None:
|
||||
"""
|
||||
Patch L{_twist.exit} so we can capture usage and prevent actual exits.
|
||||
"""
|
||||
self.exit = DummyExit()
|
||||
self.patch(_twist, "exit", self.exit)
|
||||
|
||||
def patchInstallReactor(self) -> None:
|
||||
"""
|
||||
Patch C{_options.installReactor} so we can capture usage and prevent
|
||||
actual installs.
|
||||
"""
|
||||
self.installedReactors: Dict[str, IReactorCore] = {}
|
||||
|
||||
def installReactor(_: TwistOptions, name: str) -> IReactorCore:
|
||||
reactor = MemoryReactor()
|
||||
self.installedReactors[name] = reactor
|
||||
return reactor
|
||||
|
||||
self.patch(TwistOptions, "installReactor", installReactor)
|
||||
|
||||
def patchStartService(self) -> None:
|
||||
"""
|
||||
Patch L{MultiService.startService} so we can capture usage and prevent
|
||||
actual starts.
|
||||
"""
|
||||
self.serviceStarts: List[IService] = []
|
||||
|
||||
def startService(service: IService) -> None:
|
||||
self.serviceStarts.append(service)
|
||||
|
||||
self.patch(MultiService, "startService", startService)
|
||||
|
||||
def test_optionsValidArguments(self) -> None:
|
||||
"""
|
||||
L{Twist.options} given valid arguments returns options.
|
||||
"""
|
||||
options = Twist.options(["twist", "web"])
|
||||
|
||||
self.assertIsInstance(options, TwistOptions)
|
||||
|
||||
def test_optionsInvalidArguments(self) -> None:
|
||||
"""
|
||||
L{Twist.options} given invalid arguments exits with
|
||||
L{ExitStatus.EX_USAGE} and an error/usage message.
|
||||
"""
|
||||
self.patchExit()
|
||||
|
||||
Twist.options(["twist", "--bogus-bagels"])
|
||||
|
||||
self.assertIdentical(self.exit.status, ExitStatus.EX_USAGE)
|
||||
self.assertIsNotNone(self.exit.message)
|
||||
self.assertTrue(
|
||||
self.exit.message.startswith("Error: ") # type: ignore[union-attr]
|
||||
)
|
||||
self.assertTrue(
|
||||
self.exit.message.endswith( # type: ignore[union-attr]
|
||||
f"\n\n{TwistOptions()}"
|
||||
)
|
||||
)
|
||||
|
||||
def test_service(self) -> None:
|
||||
"""
|
||||
L{Twist.service} returns an L{IService}.
|
||||
"""
|
||||
options = Twist.options(["twist", "web"]) # web should exist
|
||||
service = Twist.service(options.plugins["web"], options.subOptions)
|
||||
self.assertTrue(IService.providedBy(service))
|
||||
|
||||
def test_startService(self) -> None:
|
||||
"""
|
||||
L{Twist.startService} starts the service and registers a trigger to
|
||||
stop the service when the reactor shuts down.
|
||||
"""
|
||||
options = Twist.options(["twist", "web"])
|
||||
|
||||
reactor = options["reactor"]
|
||||
subCommand = options.subCommand
|
||||
assert subCommand is not None
|
||||
service = Twist.service(
|
||||
plugin=options.plugins[subCommand],
|
||||
options=options.subOptions,
|
||||
)
|
||||
|
||||
self.patchStartService()
|
||||
|
||||
Twist.startService(reactor, service)
|
||||
|
||||
self.assertEqual(self.serviceStarts, [service])
|
||||
self.assertEqual(
|
||||
reactor.triggers["before"]["shutdown"], [(service.stopService, (), {})]
|
||||
)
|
||||
|
||||
def test_run(self) -> None:
|
||||
"""
|
||||
L{Twist.run} runs the runner with arguments corresponding to the given
|
||||
options.
|
||||
"""
|
||||
argsSeen = []
|
||||
|
||||
self.patch(Runner, "__init__", lambda self, **args: argsSeen.append(args))
|
||||
self.patch(Runner, "run", lambda self: None)
|
||||
|
||||
twistOptions = Twist.options(
|
||||
["twist", "--reactor=default", "--log-format=json", "web"]
|
||||
)
|
||||
Twist.run(twistOptions)
|
||||
|
||||
self.assertEqual(len(argsSeen), 1)
|
||||
self.assertEqual(
|
||||
argsSeen[0],
|
||||
dict(
|
||||
reactor=self.installedReactors["default"],
|
||||
defaultLogLevel=LogLevel.info,
|
||||
logFile=stdout,
|
||||
fileLogObserverFactory=jsonFileLogObserver,
|
||||
),
|
||||
)
|
||||
|
||||
def test_main(self) -> None:
|
||||
"""
|
||||
L{Twist.main} runs the runner with arguments corresponding to the given
|
||||
command line arguments.
|
||||
"""
|
||||
self.patchStartService()
|
||||
|
||||
runners = []
|
||||
|
||||
class Runner:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.args = kwargs
|
||||
self.runs = 0
|
||||
runners.append(self)
|
||||
|
||||
def run(self) -> None:
|
||||
self.runs += 1
|
||||
|
||||
self.patch(_twist, "Runner", Runner)
|
||||
|
||||
Twist.main(["twist", "--reactor=default", "--log-format=json", "web"])
|
||||
|
||||
self.assertEqual(len(self.serviceStarts), 1)
|
||||
self.assertEqual(len(runners), 1)
|
||||
self.assertEqual(
|
||||
runners[0].args,
|
||||
dict(
|
||||
reactor=self.installedReactors["default"],
|
||||
defaultLogLevel=LogLevel.info,
|
||||
logFile=stdout,
|
||||
fileLogObserverFactory=jsonFileLogObserver,
|
||||
),
|
||||
)
|
||||
self.assertEqual(runners[0].runs, 1)
|
||||
|
||||
|
||||
class TwistExitTests(twisted.trial.unittest.TestCase):
|
||||
"""
|
||||
Tests to verify that the Twist script takes the expected actions related
|
||||
to signals and the reactor.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.exitWithSignalCalled = False
|
||||
|
||||
def fakeExitWithSignal(sig: int) -> None:
|
||||
"""
|
||||
Fake to capture whether L{twisted.application._exitWithSignal
|
||||
was called.
|
||||
|
||||
@param sig: Signal value
|
||||
@type sig: C{int}
|
||||
"""
|
||||
self.exitWithSignalCalled = True
|
||||
|
||||
self.patch(_twist, "_exitWithSignal", fakeExitWithSignal)
|
||||
|
||||
def startLogging(_: Runner) -> None:
|
||||
"""
|
||||
Prevent Runner from adding new log observers or other
|
||||
tests outside this module will fail.
|
||||
|
||||
@param _: Unused self param
|
||||
"""
|
||||
|
||||
self.patch(Runner, "startLogging", startLogging)
|
||||
|
||||
def test_twistReactorDoesntExitWithSignal(self) -> None:
|
||||
"""
|
||||
_exitWithSignal is not called if the reactor's _exitSignal attribute
|
||||
is zero.
|
||||
"""
|
||||
reactor = SignalCapturingMemoryReactor()
|
||||
reactor._exitSignal = None
|
||||
options = TwistOptions()
|
||||
options["reactor"] = reactor
|
||||
options["fileLogObserverFactory"] = jsonFileLogObserver
|
||||
|
||||
Twist.run(options)
|
||||
self.assertFalse(self.exitWithSignalCalled)
|
||||
|
||||
def test_twistReactorHasNoExitSignalAttr(self) -> None:
|
||||
"""
|
||||
_exitWithSignal is not called if the runner's reactor does not
|
||||
implement L{twisted.internet.interfaces._ISupportsExitSignalCapturing}
|
||||
"""
|
||||
reactor = MemoryReactor()
|
||||
options = TwistOptions()
|
||||
options["reactor"] = reactor
|
||||
options["fileLogObserverFactory"] = jsonFileLogObserver
|
||||
Twist.run(options)
|
||||
self.assertFalse(self.exitWithSignalCalled)
|
||||
|
||||
def test_twistReactorExitsWithSignal(self) -> None:
|
||||
"""
|
||||
_exitWithSignal is called if the runner's reactor exits due
|
||||
to a signal.
|
||||
"""
|
||||
reactor = SignalCapturingMemoryReactor()
|
||||
reactor._exitSignal = 2
|
||||
options = TwistOptions()
|
||||
options["reactor"] = reactor
|
||||
options["fileLogObserverFactory"] = jsonFileLogObserver
|
||||
Twist.run(options)
|
||||
self.assertTrue(self.exitWithSignalCalled)
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- test-case-name: twisted.conch.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Twisted Conch: The Twisted Shell. Terminal emulation, SSHv2 and telnet.
|
||||
"""
|
||||
56
.venv/lib/python3.12/site-packages/twisted/conch/avatar.py
Normal file
56
.venv/lib/python3.12/site-packages/twisted/conch/avatar.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_conch -*-
|
||||
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.conch.error import ConchError
|
||||
from twisted.conch.interfaces import IConchUser
|
||||
from twisted.conch.ssh.connection import OPEN_UNKNOWN_CHANNEL_TYPE
|
||||
from twisted.logger import Logger
|
||||
from twisted.python.compat import nativeString
|
||||
|
||||
|
||||
@implementer(IConchUser)
|
||||
class ConchUser:
|
||||
_log = Logger()
|
||||
|
||||
def __init__(self):
|
||||
self.channelLookup = {}
|
||||
self.subsystemLookup = {}
|
||||
|
||||
@property
|
||||
def conn(self):
|
||||
return self._conn
|
||||
|
||||
@conn.setter
|
||||
def conn(self, value):
|
||||
self._conn = value
|
||||
|
||||
def lookupChannel(self, channelType, windowSize, maxPacket, data):
|
||||
klass = self.channelLookup.get(channelType, None)
|
||||
if not klass:
|
||||
raise ConchError(OPEN_UNKNOWN_CHANNEL_TYPE, "unknown channel")
|
||||
else:
|
||||
return klass(
|
||||
remoteWindow=windowSize,
|
||||
remoteMaxPacket=maxPacket,
|
||||
data=data,
|
||||
avatar=self,
|
||||
)
|
||||
|
||||
def lookupSubsystem(self, subsystem, data):
|
||||
self._log.debug(
|
||||
"Subsystem lookup: {subsystem!r}", subsystem=self.subsystemLookup
|
||||
)
|
||||
klass = self.subsystemLookup.get(subsystem, None)
|
||||
if not klass:
|
||||
return False
|
||||
return klass(data, avatar=self)
|
||||
|
||||
def gotGlobalRequest(self, requestType, data):
|
||||
# XXX should this use method dispatch?
|
||||
requestType = nativeString(requestType.replace(b"-", b"_"))
|
||||
f = getattr(self, "global_%s" % requestType, None)
|
||||
if not f:
|
||||
return 0
|
||||
return f(data)
|
||||
640
.venv/lib/python3.12/site-packages/twisted/conch/checkers.py
Normal file
640
.venv/lib/python3.12/site-packages/twisted/conch/checkers.py
Normal file
@@ -0,0 +1,640 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_checkers -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Provide L{ICredentialsChecker} implementations to be used in Conch protocols.
|
||||
"""
|
||||
|
||||
|
||||
import binascii
|
||||
import errno
|
||||
import sys
|
||||
from base64 import decodebytes
|
||||
from typing import IO, Any, Callable, Iterable, Iterator, Mapping, Optional, Tuple, cast
|
||||
|
||||
from zope.interface import Interface, implementer, providedBy
|
||||
|
||||
from incremental import Version
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
from twisted.conch import error
|
||||
from twisted.conch.ssh import keys
|
||||
from twisted.cred.checkers import ICredentialsChecker
|
||||
from twisted.cred.credentials import ISSHPrivateKey, IUsernamePassword
|
||||
from twisted.cred.error import UnauthorizedLogin, UnhandledCredentials
|
||||
from twisted.internet import defer
|
||||
from twisted.logger import Logger
|
||||
from twisted.plugins.cred_unix import verifyCryptedPassword
|
||||
from twisted.python import failure, reflect
|
||||
from twisted.python.deprecate import deprecatedModuleAttribute
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.util import runAsEffectiveUser
|
||||
|
||||
_log = Logger()
|
||||
|
||||
|
||||
class UserRecord(Tuple[str, str, int, int, str, str, str]):
|
||||
"""
|
||||
A record in a UNIX-style password database. See L{pwd} for field details.
|
||||
|
||||
This corresponds to the undocumented type L{pwd.struct_passwd}, but lacks named
|
||||
field accessors.
|
||||
"""
|
||||
|
||||
@property
|
||||
def pw_dir(self) -> str: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
|
||||
class UserDB(Protocol):
|
||||
"""
|
||||
A database of users by name, like the stdlib L{pwd} module.
|
||||
|
||||
See L{twisted.python.fakepwd} for an in-memory implementation.
|
||||
"""
|
||||
|
||||
def getpwnam(self, username: str) -> UserRecord:
|
||||
"""
|
||||
Lookup a user record by name.
|
||||
|
||||
@raises KeyError: when no such user exists
|
||||
"""
|
||||
|
||||
|
||||
pwd: Optional[UserDB]
|
||||
try:
|
||||
import pwd as _pwd
|
||||
except ImportError:
|
||||
pwd = None
|
||||
else:
|
||||
pwd = cast(UserDB, _pwd)
|
||||
|
||||
|
||||
try:
|
||||
import spwd as _spwd
|
||||
except ImportError:
|
||||
spwd = None
|
||||
else:
|
||||
spwd = _spwd
|
||||
|
||||
|
||||
class CryptedPasswordRecord(Protocol):
|
||||
"""
|
||||
A sequence where the item at index 1 may be a crypted password.
|
||||
|
||||
Both L{pwd.struct_passwd} and L{spwd.struct_spwd} conform to this protocol.
|
||||
"""
|
||||
|
||||
def __getitem__(self, index: Literal[1]) -> str:
|
||||
"""
|
||||
Get the crypted password.
|
||||
"""
|
||||
|
||||
|
||||
def _lookupUser(userdb: UserDB, username: bytes) -> UserRecord:
|
||||
"""
|
||||
Lookup a user by name in a L{pwd}-style database.
|
||||
|
||||
@param userdb: The user database.
|
||||
|
||||
@param username: Identifying name in bytes. This will be decoded according
|
||||
to the filesystem encoding, as the L{pwd} module does internally.
|
||||
|
||||
@raises KeyError: when the user doesn't exist
|
||||
"""
|
||||
return userdb.getpwnam(username.decode(sys.getfilesystemencoding()))
|
||||
|
||||
|
||||
def _pwdGetByName(username: str) -> Optional[CryptedPasswordRecord]:
|
||||
"""
|
||||
Look up a user in the /etc/passwd database using the pwd module. If the
|
||||
pwd module is not available, return None.
|
||||
|
||||
@param username: the username of the user to return the passwd database
|
||||
information for.
|
||||
|
||||
@returns: A L{pwd.struct_passwd}, where field 1 may contain a crypted
|
||||
password, or L{None} when the L{pwd} database is unavailable.
|
||||
|
||||
@raises KeyError: when no such user exists
|
||||
"""
|
||||
if pwd is None:
|
||||
return None
|
||||
return cast(CryptedPasswordRecord, pwd.getpwnam(username))
|
||||
|
||||
|
||||
def _shadowGetByName(username: str) -> Optional[CryptedPasswordRecord]:
|
||||
"""
|
||||
Look up a user in the /etc/shadow database using the spwd module. If it is
|
||||
not available, return L{None}.
|
||||
|
||||
@param username: the username of the user to return the shadow database
|
||||
information for.
|
||||
@type username: L{str}
|
||||
|
||||
@returns: A L{spwd.struct_spwd}, where field 1 may contain a crypted
|
||||
password, or L{None} when the L{spwd} database is unavailable.
|
||||
|
||||
@raises KeyError: when no such user exists
|
||||
"""
|
||||
if spwd is not None:
|
||||
f = spwd.getspnam
|
||||
else:
|
||||
return None
|
||||
return cast(CryptedPasswordRecord, runAsEffectiveUser(0, 0, f, username))
|
||||
|
||||
|
||||
@implementer(ICredentialsChecker)
|
||||
class UNIXPasswordDatabase:
|
||||
"""
|
||||
A checker which validates users out of the UNIX password databases, or
|
||||
databases of a compatible format.
|
||||
|
||||
@ivar _getByNameFunctions: a C{list} of functions which are called in order
|
||||
to validate a user. The default value is such that the C{/etc/passwd}
|
||||
database will be tried first, followed by the C{/etc/shadow} database.
|
||||
"""
|
||||
|
||||
credentialInterfaces = (IUsernamePassword,)
|
||||
|
||||
def __init__(self, getByNameFunctions=None):
|
||||
if getByNameFunctions is None:
|
||||
getByNameFunctions = [_pwdGetByName, _shadowGetByName]
|
||||
self._getByNameFunctions = getByNameFunctions
|
||||
|
||||
def requestAvatarId(self, credentials):
|
||||
# We get bytes, but the Py3 pwd module uses str. So attempt to decode
|
||||
# it using the same method that CPython does for the file on disk.
|
||||
username = credentials.username.decode(sys.getfilesystemencoding())
|
||||
password = credentials.password.decode(sys.getfilesystemencoding())
|
||||
|
||||
for func in self._getByNameFunctions:
|
||||
try:
|
||||
pwnam = func(username)
|
||||
except KeyError:
|
||||
return defer.fail(UnauthorizedLogin("invalid username"))
|
||||
else:
|
||||
if pwnam is not None:
|
||||
crypted = pwnam[1]
|
||||
if crypted == "":
|
||||
continue
|
||||
|
||||
if verifyCryptedPassword(crypted, password):
|
||||
return defer.succeed(credentials.username)
|
||||
# fallback
|
||||
return defer.fail(UnauthorizedLogin("unable to verify password"))
|
||||
|
||||
|
||||
@implementer(ICredentialsChecker)
|
||||
class SSHPublicKeyDatabase:
|
||||
"""
|
||||
Checker that authenticates SSH public keys, based on public keys listed in
|
||||
authorized_keys and authorized_keys2 files in user .ssh/ directories.
|
||||
"""
|
||||
|
||||
credentialInterfaces = (ISSHPrivateKey,)
|
||||
|
||||
_userdb: UserDB = cast(UserDB, pwd)
|
||||
|
||||
def requestAvatarId(self, credentials):
|
||||
d = defer.maybeDeferred(self.checkKey, credentials)
|
||||
d.addCallback(self._cbRequestAvatarId, credentials)
|
||||
d.addErrback(self._ebRequestAvatarId)
|
||||
return d
|
||||
|
||||
def _cbRequestAvatarId(self, validKey, credentials):
|
||||
"""
|
||||
Check whether the credentials themselves are valid, now that we know
|
||||
if the key matches the user.
|
||||
|
||||
@param validKey: A boolean indicating whether or not the public key
|
||||
matches a key in the user's authorized_keys file.
|
||||
|
||||
@param credentials: The credentials offered by the user.
|
||||
@type credentials: L{ISSHPrivateKey} provider
|
||||
|
||||
@raise UnauthorizedLogin: (as a failure) if the key does not match the
|
||||
user in C{credentials}. Also raised if the user provides an invalid
|
||||
signature.
|
||||
|
||||
@raise ValidPublicKey: (as a failure) if the key matches the user but
|
||||
the credentials do not include a signature. See
|
||||
L{error.ValidPublicKey} for more information.
|
||||
|
||||
@return: The user's username, if authentication was successful.
|
||||
"""
|
||||
if not validKey:
|
||||
return failure.Failure(UnauthorizedLogin("invalid key"))
|
||||
if not credentials.signature:
|
||||
return failure.Failure(error.ValidPublicKey())
|
||||
else:
|
||||
try:
|
||||
pubKey = keys.Key.fromString(credentials.blob)
|
||||
if pubKey.verify(credentials.signature, credentials.sigData):
|
||||
return credentials.username
|
||||
except Exception: # any error should be treated as a failed login
|
||||
_log.failure("Error while verifying key")
|
||||
return failure.Failure(UnauthorizedLogin("error while verifying key"))
|
||||
return failure.Failure(UnauthorizedLogin("unable to verify key"))
|
||||
|
||||
def getAuthorizedKeysFiles(self, credentials):
|
||||
"""
|
||||
Return a list of L{FilePath} instances for I{authorized_keys} files
|
||||
which might contain information about authorized keys for the given
|
||||
credentials.
|
||||
|
||||
On OpenSSH servers, the default location of the file containing the
|
||||
list of authorized public keys is
|
||||
U{$HOME/.ssh/authorized_keys<http://www.openbsd.org/cgi-bin/man.cgi?query=sshd_config>}.
|
||||
|
||||
I{$HOME/.ssh/authorized_keys2} is also returned, though it has been
|
||||
U{deprecated by OpenSSH since
|
||||
2001<http://marc.info/?m=100508718416162>}.
|
||||
|
||||
@return: A list of L{FilePath} instances to files with the authorized keys.
|
||||
"""
|
||||
pwent = _lookupUser(self._userdb, credentials.username)
|
||||
root = FilePath(pwent.pw_dir).child(".ssh")
|
||||
files = ["authorized_keys", "authorized_keys2"]
|
||||
return [root.child(f) for f in files]
|
||||
|
||||
def checkKey(self, credentials):
|
||||
"""
|
||||
Retrieve files containing authorized keys and check against user
|
||||
credentials.
|
||||
"""
|
||||
ouid, ogid = _lookupUser(self._userdb, credentials.username)[2:4]
|
||||
for filepath in self.getAuthorizedKeysFiles(credentials):
|
||||
if not filepath.exists():
|
||||
continue
|
||||
try:
|
||||
lines = filepath.open()
|
||||
except OSError as e:
|
||||
if e.errno == errno.EACCES:
|
||||
lines = runAsEffectiveUser(ouid, ogid, filepath.open)
|
||||
else:
|
||||
raise
|
||||
with lines:
|
||||
for l in lines:
|
||||
l2 = l.split()
|
||||
if len(l2) < 2:
|
||||
continue
|
||||
try:
|
||||
if decodebytes(l2[1]) == credentials.blob:
|
||||
return True
|
||||
except binascii.Error:
|
||||
continue
|
||||
return False
|
||||
|
||||
def _ebRequestAvatarId(self, f):
|
||||
if not f.check(UnauthorizedLogin):
|
||||
_log.error(
|
||||
"Unauthorized login due to internal error: {error}", error=f.value
|
||||
)
|
||||
return failure.Failure(UnauthorizedLogin("unable to get avatar id"))
|
||||
return f
|
||||
|
||||
|
||||
@implementer(ICredentialsChecker)
|
||||
class SSHProtocolChecker:
|
||||
"""
|
||||
SSHProtocolChecker is a checker that requires multiple authentications
|
||||
to succeed. To add a checker, call my registerChecker method with
|
||||
the checker and the interface.
|
||||
|
||||
After each successful authenticate, I call my areDone method with the
|
||||
avatar id. To get a list of the successful credentials for an avatar id,
|
||||
use C{SSHProcotolChecker.successfulCredentials[avatarId]}. If L{areDone}
|
||||
returns True, the authentication has succeeded.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.checkers = {}
|
||||
self.successfulCredentials = {}
|
||||
|
||||
@property
|
||||
def credentialInterfaces(self):
|
||||
return list(self.checkers.keys())
|
||||
|
||||
def registerChecker(self, checker, *credentialInterfaces):
|
||||
if not credentialInterfaces:
|
||||
credentialInterfaces = checker.credentialInterfaces
|
||||
for credentialInterface in credentialInterfaces:
|
||||
self.checkers[credentialInterface] = checker
|
||||
|
||||
def requestAvatarId(self, credentials):
|
||||
"""
|
||||
Part of the L{ICredentialsChecker} interface. Called by a portal with
|
||||
some credentials to check if they'll authenticate a user. We check the
|
||||
interfaces that the credentials provide against our list of acceptable
|
||||
checkers. If one of them matches, we ask that checker to verify the
|
||||
credentials. If they're valid, we call our L{_cbGoodAuthentication}
|
||||
method to continue.
|
||||
|
||||
@param credentials: the credentials the L{Portal} wants us to verify
|
||||
"""
|
||||
ifac = providedBy(credentials)
|
||||
for i in ifac:
|
||||
c = self.checkers.get(i)
|
||||
if c is not None:
|
||||
d = defer.maybeDeferred(c.requestAvatarId, credentials)
|
||||
return d.addCallback(self._cbGoodAuthentication, credentials)
|
||||
return defer.fail(
|
||||
UnhandledCredentials(
|
||||
"No checker for %s" % ", ".join(map(reflect.qual, ifac))
|
||||
)
|
||||
)
|
||||
|
||||
def _cbGoodAuthentication(self, avatarId, credentials):
|
||||
"""
|
||||
Called if a checker has verified the credentials. We call our
|
||||
L{areDone} method to see if the whole of the successful authentications
|
||||
are enough. If they are, we return the avatar ID returned by the first
|
||||
checker.
|
||||
"""
|
||||
if avatarId not in self.successfulCredentials:
|
||||
self.successfulCredentials[avatarId] = []
|
||||
self.successfulCredentials[avatarId].append(credentials)
|
||||
if self.areDone(avatarId):
|
||||
del self.successfulCredentials[avatarId]
|
||||
return avatarId
|
||||
else:
|
||||
raise error.NotEnoughAuthentication()
|
||||
|
||||
def areDone(self, avatarId):
|
||||
"""
|
||||
Override to determine if the authentication is finished for a given
|
||||
avatarId.
|
||||
|
||||
@param avatarId: the avatar returned by the first checker. For
|
||||
this checker to function correctly, all the checkers must
|
||||
return the same avatar ID.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
deprecatedModuleAttribute(
|
||||
Version("Twisted", 15, 0, 0),
|
||||
(
|
||||
"Please use twisted.conch.checkers.SSHPublicKeyChecker, "
|
||||
"initialized with an instance of "
|
||||
"twisted.conch.checkers.UNIXAuthorizedKeysFiles instead."
|
||||
),
|
||||
__name__,
|
||||
"SSHPublicKeyDatabase",
|
||||
)
|
||||
|
||||
|
||||
class IAuthorizedKeysDB(Interface):
|
||||
"""
|
||||
An object that provides valid authorized ssh keys mapped to usernames.
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
|
||||
def getAuthorizedKeys(avatarId):
|
||||
"""
|
||||
Gets an iterable of authorized keys that are valid for the given
|
||||
C{avatarId}.
|
||||
|
||||
@param avatarId: the ID of the avatar
|
||||
@type avatarId: valid return value of
|
||||
L{twisted.cred.checkers.ICredentialsChecker.requestAvatarId}
|
||||
|
||||
@return: an iterable of L{twisted.conch.ssh.keys.Key}
|
||||
"""
|
||||
|
||||
|
||||
def readAuthorizedKeyFile(
|
||||
fileobj: IO[bytes], parseKey: Callable[[bytes], keys.Key] = keys.Key.fromString
|
||||
) -> Iterator[keys.Key]:
|
||||
"""
|
||||
Reads keys from an authorized keys file. Any non-comment line that cannot
|
||||
be parsed as a key will be ignored, although that particular line will
|
||||
be logged.
|
||||
|
||||
@param fileobj: something from which to read lines which can be parsed
|
||||
as keys
|
||||
@param parseKey: a callable that takes bytes and returns a
|
||||
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
|
||||
default is L{twisted.conch.ssh.keys.Key.fromString}.
|
||||
@return: an iterable of L{twisted.conch.ssh.keys.Key}
|
||||
@since: 15.0
|
||||
"""
|
||||
for line in fileobj:
|
||||
line = line.strip()
|
||||
if line and not line.startswith(b"#"): # for comments
|
||||
try:
|
||||
yield parseKey(line)
|
||||
except keys.BadKeyError as e:
|
||||
_log.error(
|
||||
"Unable to parse line {line!r} as a key: {error!s}",
|
||||
line=line,
|
||||
error=e,
|
||||
)
|
||||
|
||||
|
||||
def _keysFromFilepaths(
|
||||
filepaths: Iterable[FilePath[Any]], parseKey: Callable[[bytes], keys.Key]
|
||||
) -> Iterable[keys.Key]:
|
||||
"""
|
||||
Helper function that turns an iterable of filepaths into a generator of
|
||||
keys. If any file cannot be read, a message is logged but it is
|
||||
otherwise ignored.
|
||||
|
||||
@param filepaths: iterable of L{twisted.python.filepath.FilePath}.
|
||||
@type filepaths: iterable
|
||||
|
||||
@param parseKey: a callable that takes a string and returns a
|
||||
L{twisted.conch.ssh.keys.Key}
|
||||
@type parseKey: L{callable}
|
||||
|
||||
@return: generator of L{twisted.conch.ssh.keys.Key}
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
for fp in filepaths:
|
||||
if fp.exists():
|
||||
try:
|
||||
with fp.open() as f:
|
||||
yield from readAuthorizedKeyFile(f, parseKey)
|
||||
except OSError as e:
|
||||
_log.error("Unable to read {path!r}: {error!s}", path=fp.path, error=e)
|
||||
|
||||
|
||||
@implementer(IAuthorizedKeysDB)
|
||||
class InMemorySSHKeyDB:
|
||||
"""
|
||||
Object that provides SSH public keys based on a dictionary of usernames
|
||||
mapped to L{twisted.conch.ssh.keys.Key}s.
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
|
||||
def __init__(self, mapping: Mapping[bytes, Iterable[keys.Key]]) -> None:
|
||||
"""
|
||||
Initializes a new L{InMemorySSHKeyDB}.
|
||||
|
||||
@param mapping: mapping of usernames to iterables of
|
||||
L{twisted.conch.ssh.keys.Key}s
|
||||
|
||||
"""
|
||||
self._mapping = mapping
|
||||
|
||||
def getAuthorizedKeys(self, username: bytes) -> Iterable[keys.Key]:
|
||||
"""
|
||||
Look up the authorized keys for a user.
|
||||
|
||||
@param username: Name of the user
|
||||
"""
|
||||
return self._mapping.get(username, [])
|
||||
|
||||
|
||||
@implementer(IAuthorizedKeysDB)
|
||||
class UNIXAuthorizedKeysFiles:
|
||||
"""
|
||||
Object that provides SSH public keys based on public keys listed in
|
||||
authorized_keys and authorized_keys2 files in UNIX user .ssh/ directories.
|
||||
If any of the files cannot be read, a message is logged but that file is
|
||||
otherwise ignored.
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
|
||||
_userdb: UserDB
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
userdb: Optional[UserDB] = None,
|
||||
parseKey: Callable[[bytes], keys.Key] = keys.Key.fromString,
|
||||
):
|
||||
"""
|
||||
Initializes a new L{UNIXAuthorizedKeysFiles}.
|
||||
|
||||
@param userdb: access to the Unix user account and password database
|
||||
(default is the Python module L{pwd}, if available)
|
||||
|
||||
@param parseKey: a callable that takes a string and returns a
|
||||
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
|
||||
default is L{twisted.conch.ssh.keys.Key.fromString}.
|
||||
"""
|
||||
if userdb is not None:
|
||||
self._userdb = userdb
|
||||
elif pwd is not None:
|
||||
self._userdb = pwd
|
||||
else:
|
||||
raise ValueError("No pwd module found, and no userdb argument passed.")
|
||||
self._parseKey = parseKey
|
||||
|
||||
def getAuthorizedKeys(self, username: bytes) -> Iterable[keys.Key]:
|
||||
try:
|
||||
passwd = _lookupUser(self._userdb, username)
|
||||
except KeyError:
|
||||
return ()
|
||||
|
||||
root = FilePath(passwd.pw_dir).child(".ssh")
|
||||
files = ["authorized_keys", "authorized_keys2"]
|
||||
return _keysFromFilepaths((root.child(f) for f in files), self._parseKey)
|
||||
|
||||
|
||||
@implementer(ICredentialsChecker)
|
||||
class SSHPublicKeyChecker:
|
||||
"""
|
||||
Checker that authenticates SSH public keys, based on public keys listed in
|
||||
authorized_keys and authorized_keys2 files in user .ssh/ directories.
|
||||
|
||||
Initializing this checker with a L{UNIXAuthorizedKeysFiles} should be
|
||||
used instead of L{twisted.conch.checkers.SSHPublicKeyDatabase}.
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
|
||||
credentialInterfaces = (ISSHPrivateKey,)
|
||||
|
||||
def __init__(self, keydb: IAuthorizedKeysDB) -> None:
|
||||
"""
|
||||
Initializes a L{SSHPublicKeyChecker}.
|
||||
|
||||
@param keydb: a provider of L{IAuthorizedKeysDB}
|
||||
"""
|
||||
self._keydb = keydb
|
||||
|
||||
def requestAvatarId(self, credentials):
|
||||
d = defer.execute(self._sanityCheckKey, credentials)
|
||||
d.addCallback(self._checkKey, credentials)
|
||||
d.addCallback(self._verifyKey, credentials)
|
||||
return d
|
||||
|
||||
def _sanityCheckKey(self, credentials):
|
||||
"""
|
||||
Checks whether the provided credentials are a valid SSH key with a
|
||||
signature (does not actually verify the signature).
|
||||
|
||||
@param credentials: the credentials offered by the user
|
||||
@type credentials: L{ISSHPrivateKey} provider
|
||||
|
||||
@raise ValidPublicKey: the credentials do not include a signature. See
|
||||
L{error.ValidPublicKey} for more information.
|
||||
|
||||
@raise BadKeyError: The key included with the credentials is not
|
||||
recognized as a key.
|
||||
|
||||
@return: the key in the credentials
|
||||
@rtype: L{twisted.conch.ssh.keys.Key}
|
||||
"""
|
||||
if not credentials.signature:
|
||||
raise error.ValidPublicKey()
|
||||
|
||||
return keys.Key.fromString(credentials.blob)
|
||||
|
||||
def _checkKey(self, pubKey, credentials):
|
||||
"""
|
||||
Checks the public key against all authorized keys (if any) for the
|
||||
user.
|
||||
|
||||
@param pubKey: the key in the credentials (just to prevent it from
|
||||
having to be calculated again)
|
||||
@type pubKey:
|
||||
|
||||
@param credentials: the credentials offered by the user
|
||||
@type credentials: L{ISSHPrivateKey} provider
|
||||
|
||||
@raise UnauthorizedLogin: If the key is not authorized, or if there
|
||||
was any error obtaining a list of authorized keys for the user.
|
||||
|
||||
@return: C{pubKey} if the key is authorized
|
||||
@rtype: L{twisted.conch.ssh.keys.Key}
|
||||
"""
|
||||
if any(
|
||||
key == pubKey for key in self._keydb.getAuthorizedKeys(credentials.username)
|
||||
):
|
||||
return pubKey
|
||||
|
||||
raise UnauthorizedLogin("Key not authorized")
|
||||
|
||||
def _verifyKey(self, pubKey, credentials):
|
||||
"""
|
||||
Checks whether the credentials themselves are valid, now that we know
|
||||
if the key matches the user.
|
||||
|
||||
@param pubKey: the key in the credentials (just to prevent it from
|
||||
having to be calculated again)
|
||||
@type pubKey: L{twisted.conch.ssh.keys.Key}
|
||||
|
||||
@param credentials: the credentials offered by the user
|
||||
@type credentials: L{ISSHPrivateKey} provider
|
||||
|
||||
@raise UnauthorizedLogin: If the key signature is invalid or there
|
||||
was any error verifying the signature.
|
||||
|
||||
@return: The user's username, if authentication was successful
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
try:
|
||||
if pubKey.verify(credentials.signature, credentials.sigData):
|
||||
return credentials.username
|
||||
except Exception as e: # Any error should be treated as a failed login
|
||||
raise UnauthorizedLogin("Error while verifying key") from e
|
||||
|
||||
raise UnauthorizedLogin("Key signature invalid.")
|
||||
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
"""
|
||||
Client support code for Conch.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
@@ -0,0 +1,65 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_default -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Accesses the key agent for user authentication.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from twisted.conch.ssh import agent, channel, keys
|
||||
from twisted.internet import protocol, reactor
|
||||
from twisted.logger import Logger
|
||||
|
||||
|
||||
class SSHAgentClient(agent.SSHAgentClient):
|
||||
_log = Logger()
|
||||
|
||||
def __init__(self):
|
||||
agent.SSHAgentClient.__init__(self)
|
||||
self.blobs = []
|
||||
|
||||
def getPublicKeys(self):
|
||||
return self.requestIdentities().addCallback(self._cbPublicKeys)
|
||||
|
||||
def _cbPublicKeys(self, blobcomm):
|
||||
self._log.debug("got {num_keys} public keys", num_keys=len(blobcomm))
|
||||
self.blobs = [x[0] for x in blobcomm]
|
||||
|
||||
def getPublicKey(self):
|
||||
"""
|
||||
Return a L{Key} from the first blob in C{self.blobs}, if any, or
|
||||
return L{None}.
|
||||
"""
|
||||
if self.blobs:
|
||||
return keys.Key.fromString(self.blobs.pop(0))
|
||||
return None
|
||||
|
||||
|
||||
class SSHAgentForwardingChannel(channel.SSHChannel):
|
||||
def channelOpen(self, specificData):
|
||||
cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal)
|
||||
d = cc.connectUNIX(os.environ["SSH_AUTH_SOCK"])
|
||||
d.addCallback(self._cbGotLocal)
|
||||
d.addErrback(lambda x: self.loseConnection())
|
||||
self.buf = ""
|
||||
|
||||
def _cbGotLocal(self, local):
|
||||
self.local = local
|
||||
self.dataReceived = self.local.transport.write
|
||||
self.local.dataReceived = self.write
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buf += data
|
||||
|
||||
def closed(self):
|
||||
if self.local:
|
||||
self.local.loseConnection()
|
||||
self.local = None
|
||||
|
||||
|
||||
class SSHAgentForwardingLocal(protocol.Protocol):
|
||||
pass
|
||||
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
from twisted.conch.client import direct
|
||||
|
||||
connectTypes = {"direct": direct.connect}
|
||||
|
||||
|
||||
def connect(host, port, options, verifyHostKey, userAuthObject):
|
||||
useConnects = ["direct"]
|
||||
return _ebConnect(
|
||||
None, useConnects, host, port, options, verifyHostKey, userAuthObject
|
||||
)
|
||||
|
||||
|
||||
def _ebConnect(f, useConnects, host, port, options, vhk, uao):
|
||||
if not useConnects:
|
||||
return f
|
||||
connectType = useConnects.pop(0)
|
||||
f = connectTypes[connectType]
|
||||
d = f(host, port, options, vhk, uao)
|
||||
d.addErrback(_ebConnect, useConnects, host, port, options, vhk, uao)
|
||||
return d
|
||||
@@ -0,0 +1,331 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_knownhosts,twisted.conch.test.test_default -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Various classes and functions for implementing user-interaction in the
|
||||
command-line conch client.
|
||||
|
||||
You probably shouldn't use anything in this module directly, since it assumes
|
||||
you are sitting at an interactive terminal. For example, to programmatically
|
||||
interact with a known_hosts database, use L{twisted.conch.client.knownhosts}.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import getpass
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from base64 import decodebytes
|
||||
|
||||
from twisted.conch.client import agent
|
||||
from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile
|
||||
from twisted.conch.error import ConchError
|
||||
from twisted.conch.ssh import common, keys, userauth
|
||||
from twisted.internet import defer, protocol, reactor
|
||||
from twisted.python.compat import nativeString
|
||||
from twisted.python.filepath import FilePath
|
||||
|
||||
# The default location of the known hosts file (probably should be parsed out
|
||||
# of an ssh config file someday).
|
||||
_KNOWN_HOSTS = "~/.ssh/known_hosts"
|
||||
|
||||
|
||||
# This name is bound so that the unit tests can use 'patch' to override it.
|
||||
_open = open
|
||||
_input = input
|
||||
|
||||
|
||||
def verifyHostKey(transport, host, pubKey, fingerprint):
|
||||
"""
|
||||
Verify a host's key.
|
||||
|
||||
This function is a gross vestige of some bad factoring in the client
|
||||
internals. The actual implementation, and a better signature of this logic
|
||||
is in L{KnownHostsFile.verifyHostKey}. This function is not deprecated yet
|
||||
because the callers have not yet been rehabilitated, but they should
|
||||
eventually be changed to call that method instead.
|
||||
|
||||
However, this function does perform two functions not implemented by
|
||||
L{KnownHostsFile.verifyHostKey}. It determines the path to the user's
|
||||
known_hosts file based on the options (which should really be the options
|
||||
object's job), and it provides an opener to L{ConsoleUI} which opens
|
||||
'/dev/tty' so that the user will be prompted on the tty of the process even
|
||||
if the input and output of the process has been redirected. This latter
|
||||
part is, somewhat obviously, not portable, but I don't know of a portable
|
||||
equivalent that could be used.
|
||||
|
||||
@param host: Due to a bug in L{SSHClientTransport.verifyHostKey}, this is
|
||||
always the dotted-quad IP address of the host being connected to.
|
||||
@type host: L{str}
|
||||
|
||||
@param transport: the client transport which is attempting to connect to
|
||||
the given host.
|
||||
@type transport: L{SSHClientTransport}
|
||||
|
||||
@param fingerprint: the fingerprint of the given public key, in
|
||||
xx:xx:xx:... format. This is ignored in favor of getting the fingerprint
|
||||
from the key itself.
|
||||
@type fingerprint: L{str}
|
||||
|
||||
@param pubKey: The public key of the server being connected to.
|
||||
@type pubKey: L{str}
|
||||
|
||||
@return: a L{Deferred} which fires with C{1} if the key was successfully
|
||||
verified, or fails if the key could not be successfully verified. Failure
|
||||
types may include L{HostKeyChanged}, L{UserRejectedKey}, L{IOError} or
|
||||
L{KeyboardInterrupt}.
|
||||
"""
|
||||
actualHost = transport.factory.options["host"]
|
||||
actualKey = keys.Key.fromString(pubKey)
|
||||
kh = KnownHostsFile.fromPath(
|
||||
FilePath(
|
||||
transport.factory.options["known-hosts"] or os.path.expanduser(_KNOWN_HOSTS)
|
||||
)
|
||||
)
|
||||
ui = ConsoleUI(lambda: _open("/dev/tty", "r+b", buffering=0))
|
||||
return kh.verifyHostKey(ui, actualHost, host, actualKey)
|
||||
|
||||
|
||||
def isInKnownHosts(host, pubKey, options):
|
||||
"""
|
||||
Checks to see if host is in the known_hosts file for the user.
|
||||
|
||||
@return: 0 if it isn't, 1 if it is and is the same, 2 if it's changed.
|
||||
@rtype: L{int}
|
||||
"""
|
||||
keyType = common.getNS(pubKey)[0]
|
||||
retVal = 0
|
||||
|
||||
if not options["known-hosts"] and not os.path.exists(os.path.expanduser("~/.ssh/")):
|
||||
print("Creating ~/.ssh directory...")
|
||||
os.mkdir(os.path.expanduser("~/.ssh"))
|
||||
kh_file = options["known-hosts"] or _KNOWN_HOSTS
|
||||
try:
|
||||
known_hosts = open(os.path.expanduser(kh_file), "rb")
|
||||
except OSError:
|
||||
return 0
|
||||
with known_hosts:
|
||||
for line in known_hosts.readlines():
|
||||
split = line.split()
|
||||
if len(split) < 3:
|
||||
continue
|
||||
hosts, hostKeyType, encodedKey = split[:3]
|
||||
if host not in hosts.split(b","): # incorrect host
|
||||
continue
|
||||
if hostKeyType != keyType: # incorrect type of key
|
||||
continue
|
||||
try:
|
||||
decodedKey = decodebytes(encodedKey)
|
||||
except BaseException:
|
||||
continue
|
||||
if decodedKey == pubKey:
|
||||
return 1
|
||||
else:
|
||||
retVal = 2
|
||||
return retVal
|
||||
|
||||
|
||||
def getHostKeyAlgorithms(host, options):
|
||||
"""
|
||||
Look in known_hosts for a key corresponding to C{host}.
|
||||
This can be used to change the order of supported key types
|
||||
in the KEXINIT packet.
|
||||
|
||||
@type host: L{str}
|
||||
@param host: the host to check in known_hosts
|
||||
@type options: L{twisted.conch.client.options.ConchOptions}
|
||||
@param options: options passed to client
|
||||
@return: L{list} of L{str} representing key types or L{None}.
|
||||
"""
|
||||
knownHosts = KnownHostsFile.fromPath(
|
||||
FilePath(options["known-hosts"] or os.path.expanduser(_KNOWN_HOSTS))
|
||||
)
|
||||
keyTypes = []
|
||||
for entry in knownHosts.iterentries():
|
||||
if entry.matchesHost(host):
|
||||
if entry.keyType not in keyTypes:
|
||||
keyTypes.append(entry.keyType)
|
||||
return keyTypes or None
|
||||
|
||||
|
||||
class SSHUserAuthClient(userauth.SSHUserAuthClient):
|
||||
def __init__(self, user, options, *args):
|
||||
userauth.SSHUserAuthClient.__init__(self, user, *args)
|
||||
self.keyAgent = None
|
||||
self.options = options
|
||||
self.usedFiles = []
|
||||
if not options.identitys:
|
||||
options.identitys = ["~/.ssh/id_rsa", "~/.ssh/id_dsa"]
|
||||
|
||||
def serviceStarted(self):
|
||||
if "SSH_AUTH_SOCK" in os.environ and not self.options["noagent"]:
|
||||
self._log.debug(
|
||||
"using SSH agent {authSock!r}", authSock=os.environ["SSH_AUTH_SOCK"]
|
||||
)
|
||||
cc = protocol.ClientCreator(reactor, agent.SSHAgentClient)
|
||||
d = cc.connectUNIX(os.environ["SSH_AUTH_SOCK"])
|
||||
d.addCallback(self._setAgent)
|
||||
d.addErrback(self._ebSetAgent)
|
||||
else:
|
||||
userauth.SSHUserAuthClient.serviceStarted(self)
|
||||
|
||||
def serviceStopped(self):
|
||||
if self.keyAgent:
|
||||
self.keyAgent.transport.loseConnection()
|
||||
self.keyAgent = None
|
||||
|
||||
def _setAgent(self, a):
|
||||
self.keyAgent = a
|
||||
d = self.keyAgent.getPublicKeys()
|
||||
d.addBoth(self._ebSetAgent)
|
||||
return d
|
||||
|
||||
def _ebSetAgent(self, f):
|
||||
userauth.SSHUserAuthClient.serviceStarted(self)
|
||||
|
||||
def _getPassword(self, prompt):
|
||||
"""
|
||||
Prompt for a password using L{getpass.getpass}.
|
||||
|
||||
@param prompt: Written on tty to ask for the input.
|
||||
@type prompt: L{str}
|
||||
@return: The input.
|
||||
@rtype: L{str}
|
||||
"""
|
||||
with self._replaceStdoutStdin():
|
||||
try:
|
||||
p = getpass.getpass(prompt)
|
||||
return p
|
||||
except (KeyboardInterrupt, OSError):
|
||||
print()
|
||||
raise ConchError("PEBKAC")
|
||||
|
||||
def getPassword(self, prompt=None):
|
||||
if prompt:
|
||||
prompt = nativeString(prompt)
|
||||
else:
|
||||
prompt = "{}@{}'s password: ".format(
|
||||
nativeString(self.user),
|
||||
self.transport.transport.getPeer().host,
|
||||
)
|
||||
try:
|
||||
# We don't know the encoding the other side is using,
|
||||
# signaling that is not part of the SSH protocol. But
|
||||
# using our defaultencoding is better than just going for
|
||||
# ASCII.
|
||||
p = self._getPassword(prompt).encode(sys.getdefaultencoding())
|
||||
return defer.succeed(p)
|
||||
except ConchError:
|
||||
return defer.fail()
|
||||
|
||||
def getPublicKey(self):
|
||||
"""
|
||||
Get a public key from the key agent if possible, otherwise look in
|
||||
the next configured identity file for one.
|
||||
"""
|
||||
if self.keyAgent:
|
||||
key = self.keyAgent.getPublicKey()
|
||||
if key is not None:
|
||||
return key
|
||||
files = [x for x in self.options.identitys if x not in self.usedFiles]
|
||||
self._log.debug(
|
||||
"public key identities: {identities}\n{files}",
|
||||
identities=self.options.identitys,
|
||||
files=files,
|
||||
)
|
||||
if not files:
|
||||
return None
|
||||
file = files[0]
|
||||
self.usedFiles.append(file)
|
||||
file = os.path.expanduser(file)
|
||||
file += ".pub"
|
||||
if not os.path.exists(file):
|
||||
return self.getPublicKey() # try again
|
||||
try:
|
||||
return keys.Key.fromFile(file)
|
||||
except keys.BadKeyError:
|
||||
return self.getPublicKey() # try again
|
||||
|
||||
def signData(self, publicKey, signData):
|
||||
"""
|
||||
Extend the base signing behavior by using an SSH agent to sign the
|
||||
data, if one is available.
|
||||
|
||||
@type publicKey: L{Key}
|
||||
@type signData: L{bytes}
|
||||
"""
|
||||
if not self.usedFiles: # agent key
|
||||
return self.keyAgent.signData(publicKey.blob(), signData)
|
||||
else:
|
||||
return userauth.SSHUserAuthClient.signData(self, publicKey, signData)
|
||||
|
||||
def getPrivateKey(self):
|
||||
"""
|
||||
Try to load the private key from the last used file identified by
|
||||
C{getPublicKey}, potentially asking for the passphrase if the key is
|
||||
encrypted.
|
||||
"""
|
||||
file = os.path.expanduser(self.usedFiles[-1])
|
||||
if not os.path.exists(file):
|
||||
return None
|
||||
try:
|
||||
return defer.succeed(keys.Key.fromFile(file))
|
||||
except keys.EncryptedKeyError:
|
||||
for i in range(3):
|
||||
prompt = "Enter passphrase for key '%s': " % self.usedFiles[-1]
|
||||
try:
|
||||
p = self._getPassword(prompt).encode(sys.getfilesystemencoding())
|
||||
return defer.succeed(keys.Key.fromFile(file, passphrase=p))
|
||||
except (keys.BadKeyError, ConchError):
|
||||
pass
|
||||
return defer.fail(ConchError("bad password"))
|
||||
raise
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
reactor.stop()
|
||||
|
||||
def getGenericAnswers(self, name, instruction, prompts):
|
||||
responses = []
|
||||
with self._replaceStdoutStdin():
|
||||
if name:
|
||||
print(name.decode("utf-8"))
|
||||
if instruction:
|
||||
print(instruction.decode("utf-8"))
|
||||
for prompt, echo in prompts:
|
||||
prompt = prompt.decode("utf-8")
|
||||
if echo:
|
||||
responses.append(_input(prompt))
|
||||
else:
|
||||
responses.append(getpass.getpass(prompt))
|
||||
return defer.succeed(responses)
|
||||
|
||||
@classmethod
|
||||
def _openTty(cls):
|
||||
"""
|
||||
Open /dev/tty as two streams one in read, one in write mode,
|
||||
and return them.
|
||||
|
||||
@return: File objects for reading and writing to /dev/tty,
|
||||
corresponding to standard input and standard output.
|
||||
@rtype: A L{tuple} of L{io.TextIOWrapper} on Python 3.
|
||||
"""
|
||||
stdin = io.TextIOWrapper(open("/dev/tty", "rb"))
|
||||
stdout = io.TextIOWrapper(open("/dev/tty", "wb"))
|
||||
return stdin, stdout
|
||||
|
||||
@classmethod
|
||||
@contextlib.contextmanager
|
||||
def _replaceStdoutStdin(cls):
|
||||
"""
|
||||
Contextmanager that replaces stdout and stdin with /dev/tty
|
||||
and resets them when it is done.
|
||||
"""
|
||||
oldout, oldin = sys.stdout, sys.stdin
|
||||
sys.stdin, sys.stdout = cls._openTty()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.stdout.close()
|
||||
sys.stdin.close()
|
||||
sys.stdout, sys.stdin = oldout, oldin
|
||||
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
from twisted.conch import error
|
||||
from twisted.conch.ssh import transport
|
||||
from twisted.internet import defer, protocol, reactor
|
||||
|
||||
|
||||
class SSHClientFactory(protocol.ClientFactory):
|
||||
def __init__(self, d, options, verifyHostKey, userAuthObject):
|
||||
self.d = d
|
||||
self.options = options
|
||||
self.verifyHostKey = verifyHostKey
|
||||
self.userAuthObject = userAuthObject
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
if self.options["reconnect"]:
|
||||
connector.connect()
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
if self.d is None:
|
||||
return
|
||||
d, self.d = self.d, None
|
||||
d.errback(reason)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
trans = SSHClientTransport(self)
|
||||
if self.options["ciphers"]:
|
||||
trans.supportedCiphers = self.options["ciphers"]
|
||||
if self.options["macs"]:
|
||||
trans.supportedMACs = self.options["macs"]
|
||||
if self.options["compress"]:
|
||||
trans.supportedCompressions[0:1] = ["zlib"]
|
||||
if self.options["host-key-algorithms"]:
|
||||
trans.supportedPublicKeys = self.options["host-key-algorithms"]
|
||||
return trans
|
||||
|
||||
|
||||
class SSHClientTransport(transport.SSHClientTransport):
|
||||
def __init__(self, factory):
|
||||
self.factory = factory
|
||||
self.unixServer = None
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if self.unixServer:
|
||||
d = self.unixServer.stopListening()
|
||||
self.unixServer = None
|
||||
else:
|
||||
d = defer.succeed(None)
|
||||
d.addCallback(
|
||||
lambda x: transport.SSHClientTransport.connectionLost(self, reason)
|
||||
)
|
||||
|
||||
def receiveError(self, code, desc):
|
||||
if self.factory.d is None:
|
||||
return
|
||||
d, self.factory.d = self.factory.d, None
|
||||
d.errback(error.ConchError(desc, code))
|
||||
|
||||
def sendDisconnect(self, code, reason):
|
||||
if self.factory.d is None:
|
||||
return
|
||||
d, self.factory.d = self.factory.d, None
|
||||
transport.SSHClientTransport.sendDisconnect(self, code, reason)
|
||||
d.errback(error.ConchError(reason, code))
|
||||
|
||||
def receiveDebug(self, alwaysDisplay, message, lang):
|
||||
self._log.debug(
|
||||
"Received Debug Message: {message}",
|
||||
message=message,
|
||||
alwaysDisplay=alwaysDisplay,
|
||||
lang=lang,
|
||||
)
|
||||
if alwaysDisplay: # XXX what should happen here?
|
||||
print(message)
|
||||
|
||||
def verifyHostKey(self, pubKey, fingerprint):
|
||||
return self.factory.verifyHostKey(
|
||||
self, self.transport.getPeer().host, pubKey, fingerprint
|
||||
)
|
||||
|
||||
def setService(self, service):
|
||||
self._log.info("setting client server to {service}", service=service)
|
||||
transport.SSHClientTransport.setService(self, service)
|
||||
if service.name != "ssh-userauth" and self.factory.d is not None:
|
||||
d, self.factory.d = self.factory.d, None
|
||||
d.callback(None)
|
||||
|
||||
def connectionSecure(self):
|
||||
self.requestService(self.factory.userAuthObject)
|
||||
|
||||
|
||||
def connect(host, port, options, verifyHostKey, userAuthObject):
|
||||
d = defer.Deferred()
|
||||
factory = SSHClientFactory(d, options, verifyHostKey, userAuthObject)
|
||||
reactor.connectTCP(host, port, factory)
|
||||
return d
|
||||
@@ -0,0 +1,622 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_knownhosts -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
An implementation of the OpenSSH known_hosts database.
|
||||
|
||||
@since: 8.2
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import sys
|
||||
from binascii import Error as DecodeError, a2b_base64, b2a_base64
|
||||
from contextlib import closing
|
||||
from hashlib import sha1
|
||||
from typing import IO, Callable, Literal
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.conch.error import HostKeyChanged, InvalidEntry, UserRejectedKey
|
||||
from twisted.conch.interfaces import IKnownHostEntry
|
||||
from twisted.conch.ssh.keys import BadKeyError, FingerprintFormats, Key
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.logger import Logger
|
||||
from twisted.python.compat import nativeString
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.randbytes import secureRandom
|
||||
from twisted.python.util import FancyEqMixin
|
||||
|
||||
log = Logger()
|
||||
|
||||
|
||||
def _b64encode(s):
|
||||
"""
|
||||
Encode a binary string as base64 with no trailing newline.
|
||||
|
||||
@param s: The string to encode.
|
||||
@type s: L{bytes}
|
||||
|
||||
@return: The base64-encoded string.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return b2a_base64(s).strip()
|
||||
|
||||
|
||||
def _extractCommon(string):
|
||||
"""
|
||||
Extract common elements of base64 keys from an entry in a hosts file.
|
||||
|
||||
@param string: A known hosts file entry (a single line).
|
||||
@type string: L{bytes}
|
||||
|
||||
@return: a 4-tuple of hostname data (L{bytes}), ssh key type (L{bytes}), key
|
||||
(L{Key}), and comment (L{bytes} or L{None}). The hostname data is
|
||||
simply the beginning of the line up to the first occurrence of
|
||||
whitespace.
|
||||
@rtype: L{tuple}
|
||||
"""
|
||||
elements = string.split(None, 2)
|
||||
if len(elements) != 3:
|
||||
raise InvalidEntry()
|
||||
hostnames, keyType, keyAndComment = elements
|
||||
splitkey = keyAndComment.split(None, 1)
|
||||
if len(splitkey) == 2:
|
||||
keyString, comment = splitkey
|
||||
comment = comment.rstrip(b"\n")
|
||||
else:
|
||||
keyString = splitkey[0]
|
||||
comment = None
|
||||
key = Key.fromString(a2b_base64(keyString))
|
||||
return hostnames, keyType, key, comment
|
||||
|
||||
|
||||
class _BaseEntry:
|
||||
"""
|
||||
Abstract base of both hashed and non-hashed entry objects, since they
|
||||
represent keys and key types the same way.
|
||||
|
||||
@ivar keyType: The type of the key; either ssh-dss or ssh-rsa.
|
||||
@type keyType: L{bytes}
|
||||
|
||||
@ivar publicKey: The server public key indicated by this line.
|
||||
@type publicKey: L{twisted.conch.ssh.keys.Key}
|
||||
|
||||
@ivar comment: Trailing garbage after the key line.
|
||||
@type comment: L{bytes}
|
||||
"""
|
||||
|
||||
def __init__(self, keyType, publicKey, comment):
|
||||
self.keyType = keyType
|
||||
self.publicKey = publicKey
|
||||
self.comment = comment
|
||||
|
||||
def matchesKey(self, keyObject):
|
||||
"""
|
||||
Check to see if this entry matches a given key object.
|
||||
|
||||
@param keyObject: A public key object to check.
|
||||
@type keyObject: L{Key}
|
||||
|
||||
@return: C{True} if this entry's key matches C{keyObject}, C{False}
|
||||
otherwise.
|
||||
@rtype: L{bool}
|
||||
"""
|
||||
return self.publicKey == keyObject
|
||||
|
||||
|
||||
@implementer(IKnownHostEntry)
|
||||
class PlainEntry(_BaseEntry):
|
||||
"""
|
||||
A L{PlainEntry} is a representation of a plain-text entry in a known_hosts
|
||||
file.
|
||||
|
||||
@ivar _hostnames: the list of all host-names associated with this entry.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, hostnames: list[bytes], keyType: bytes, publicKey: Key, comment: bytes
|
||||
):
|
||||
self._hostnames: list[bytes] = hostnames
|
||||
super().__init__(keyType, publicKey, comment)
|
||||
|
||||
@classmethod
|
||||
def fromString(cls, string: bytes) -> PlainEntry:
|
||||
"""
|
||||
Parse a plain-text entry in a known_hosts file, and return a
|
||||
corresponding L{PlainEntry}.
|
||||
|
||||
@param string: a space-separated string formatted like "hostname
|
||||
key-type base64-key-data comment".
|
||||
|
||||
@raise DecodeError: if the key is not valid encoded as valid base64.
|
||||
|
||||
@raise InvalidEntry: if the entry does not have the right number of
|
||||
elements and is therefore invalid.
|
||||
|
||||
@raise BadKeyError: if the key, once decoded from base64, is not
|
||||
actually an SSH key.
|
||||
|
||||
@return: an IKnownHostEntry representing the hostname and key in the
|
||||
input line.
|
||||
|
||||
@rtype: L{PlainEntry}
|
||||
"""
|
||||
hostnames, keyType, key, comment = _extractCommon(string)
|
||||
self = cls(hostnames.split(b","), keyType, key, comment)
|
||||
return self
|
||||
|
||||
def matchesHost(self, hostname: bytes | str) -> bool:
|
||||
"""
|
||||
Check to see if this entry matches a given hostname.
|
||||
|
||||
@param hostname: A hostname or IP address literal to check against this
|
||||
entry.
|
||||
|
||||
@return: C{True} if this entry is for the given hostname or IP address,
|
||||
C{False} otherwise.
|
||||
"""
|
||||
if isinstance(hostname, str):
|
||||
hostname = hostname.encode("utf-8")
|
||||
return hostname in self._hostnames
|
||||
|
||||
def toString(self) -> bytes:
|
||||
"""
|
||||
Implement L{IKnownHostEntry.toString} by recording the comma-separated
|
||||
hostnames, key type, and base-64 encoded key.
|
||||
|
||||
@return: The string representation of this entry, with unhashed hostname
|
||||
information.
|
||||
"""
|
||||
fields = [
|
||||
b",".join(self._hostnames),
|
||||
self.keyType,
|
||||
_b64encode(self.publicKey.blob()),
|
||||
]
|
||||
if self.comment is not None:
|
||||
fields.append(self.comment)
|
||||
return b" ".join(fields)
|
||||
|
||||
|
||||
@implementer(IKnownHostEntry)
|
||||
class UnparsedEntry:
|
||||
"""
|
||||
L{UnparsedEntry} is an entry in a L{KnownHostsFile} which can't actually be
|
||||
parsed; therefore it matches no keys and no hosts.
|
||||
"""
|
||||
|
||||
def __init__(self, string):
|
||||
"""
|
||||
Create an unparsed entry from a line in a known_hosts file which cannot
|
||||
otherwise be parsed.
|
||||
"""
|
||||
self._string = string
|
||||
|
||||
def matchesHost(self, hostname):
|
||||
"""
|
||||
Always returns False.
|
||||
"""
|
||||
return False
|
||||
|
||||
def matchesKey(self, key):
|
||||
"""
|
||||
Always returns False.
|
||||
"""
|
||||
return False
|
||||
|
||||
def toString(self):
|
||||
"""
|
||||
Returns the input line, without its newline if one was given.
|
||||
|
||||
@return: The string representation of this entry, almost exactly as was
|
||||
used to initialize this entry but without a trailing newline.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return self._string.rstrip(b"\n")
|
||||
|
||||
|
||||
def _hmacedString(key, string):
|
||||
"""
|
||||
Return the SHA-1 HMAC hash of the given key and string.
|
||||
|
||||
@param key: The HMAC key.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param string: The string to be hashed.
|
||||
@type string: L{bytes}
|
||||
|
||||
@return: The keyed hash value.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
hash = hmac.HMAC(key, digestmod=sha1)
|
||||
if isinstance(string, str):
|
||||
string = string.encode("utf-8")
|
||||
hash.update(string)
|
||||
return hash.digest()
|
||||
|
||||
|
||||
@implementer(IKnownHostEntry)
|
||||
class HashedEntry(_BaseEntry, FancyEqMixin):
|
||||
"""
|
||||
A L{HashedEntry} is a representation of an entry in a known_hosts file
|
||||
where the hostname has been hashed and salted.
|
||||
|
||||
@ivar _hostSalt: the salt to combine with a hostname for hashing.
|
||||
|
||||
@ivar _hostHash: the hashed representation of the hostname.
|
||||
|
||||
@cvar MAGIC: the 'hash magic' string used to identify a hashed line in a
|
||||
known_hosts file as opposed to a plaintext one.
|
||||
"""
|
||||
|
||||
MAGIC = b"|1|"
|
||||
|
||||
compareAttributes = ("_hostSalt", "_hostHash", "keyType", "publicKey", "comment")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hostSalt: bytes,
|
||||
hostHash: bytes,
|
||||
keyType: bytes,
|
||||
publicKey: Key,
|
||||
comment: bytes | None,
|
||||
) -> None:
|
||||
self._hostSalt = hostSalt
|
||||
self._hostHash = hostHash
|
||||
super().__init__(keyType, publicKey, comment)
|
||||
|
||||
@classmethod
|
||||
def fromString(cls, string: bytes) -> HashedEntry:
|
||||
"""
|
||||
Load a hashed entry from a string representing a line in a known_hosts
|
||||
file.
|
||||
|
||||
@param string: A complete single line from a I{known_hosts} file,
|
||||
formatted as defined by OpenSSH.
|
||||
|
||||
@raise DecodeError: if the key, the hostname, or the is not valid
|
||||
encoded as valid base64
|
||||
|
||||
@raise InvalidEntry: if the entry does not have the right number of
|
||||
elements and is therefore invalid, or the host/hash portion
|
||||
contains more items than just the host and hash.
|
||||
|
||||
@raise BadKeyError: if the key, once decoded from base64, is not
|
||||
actually an SSH key.
|
||||
|
||||
@return: The newly created L{HashedEntry} instance, initialized with
|
||||
the information from C{string}.
|
||||
"""
|
||||
stuff, keyType, key, comment = _extractCommon(string)
|
||||
saltAndHash = stuff[len(cls.MAGIC) :].split(b"|")
|
||||
if len(saltAndHash) != 2:
|
||||
raise InvalidEntry()
|
||||
hostSalt, hostHash = saltAndHash
|
||||
self = cls(a2b_base64(hostSalt), a2b_base64(hostHash), keyType, key, comment)
|
||||
return self
|
||||
|
||||
def matchesHost(self, hostname):
|
||||
"""
|
||||
Implement L{IKnownHostEntry.matchesHost} to compare the hash of the
|
||||
input to the stored hash.
|
||||
|
||||
@param hostname: A hostname or IP address literal to check against this
|
||||
entry.
|
||||
@type hostname: L{bytes}
|
||||
|
||||
@return: C{True} if this entry is for the given hostname or IP address,
|
||||
C{False} otherwise.
|
||||
@rtype: L{bool}
|
||||
"""
|
||||
return hmac.compare_digest(
|
||||
_hmacedString(self._hostSalt, hostname), self._hostHash
|
||||
)
|
||||
|
||||
def toString(self):
|
||||
"""
|
||||
Implement L{IKnownHostEntry.toString} by base64-encoding the salt, host
|
||||
hash, and key.
|
||||
|
||||
@return: The string representation of this entry, with the hostname part
|
||||
hashed.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
fields = [
|
||||
self.MAGIC
|
||||
+ b"|".join([_b64encode(self._hostSalt), _b64encode(self._hostHash)]),
|
||||
self.keyType,
|
||||
_b64encode(self.publicKey.blob()),
|
||||
]
|
||||
if self.comment is not None:
|
||||
fields.append(self.comment)
|
||||
return b" ".join(fields)
|
||||
|
||||
|
||||
class KnownHostsFile:
|
||||
"""
|
||||
A structured representation of an OpenSSH-format ~/.ssh/known_hosts file.
|
||||
|
||||
@ivar _added: A list of L{IKnownHostEntry} providers which have been added
|
||||
to this instance in memory but not yet saved.
|
||||
|
||||
@ivar _clobber: A flag indicating whether the current contents of the save
|
||||
path will be disregarded and potentially overwritten or not. If
|
||||
C{True}, this will be done. If C{False}, entries in the save path will
|
||||
be read and new entries will be saved by appending rather than
|
||||
overwriting.
|
||||
@type _clobber: L{bool}
|
||||
|
||||
@ivar _savePath: See C{savePath} parameter of L{__init__}.
|
||||
"""
|
||||
|
||||
def __init__(self, savePath: FilePath[str]) -> None:
|
||||
"""
|
||||
Create a new, empty KnownHostsFile.
|
||||
|
||||
Unless you want to erase the current contents of C{savePath}, you want
|
||||
to use L{KnownHostsFile.fromPath} instead.
|
||||
|
||||
@param savePath: The L{FilePath} to which to save new entries.
|
||||
@type savePath: L{FilePath}
|
||||
"""
|
||||
self._added: list[IKnownHostEntry] = []
|
||||
self._savePath = savePath
|
||||
self._clobber = True
|
||||
|
||||
@property
|
||||
def savePath(self) -> FilePath[str]:
|
||||
"""
|
||||
@see: C{savePath} parameter of L{__init__}
|
||||
"""
|
||||
return self._savePath
|
||||
|
||||
def iterentries(self):
|
||||
"""
|
||||
Iterate over the host entries in this file.
|
||||
|
||||
@return: An iterable the elements of which provide L{IKnownHostEntry}.
|
||||
There is an element for each entry in the file as well as an element
|
||||
for each added but not yet saved entry.
|
||||
@rtype: iterable of L{IKnownHostEntry} providers
|
||||
"""
|
||||
for entry in self._added:
|
||||
yield entry
|
||||
|
||||
if self._clobber:
|
||||
return
|
||||
|
||||
try:
|
||||
fp = self._savePath.open()
|
||||
except OSError:
|
||||
return
|
||||
|
||||
with fp:
|
||||
for line in fp:
|
||||
try:
|
||||
if line.startswith(HashedEntry.MAGIC):
|
||||
entry = HashedEntry.fromString(line)
|
||||
else:
|
||||
entry = PlainEntry.fromString(line)
|
||||
except (DecodeError, InvalidEntry, BadKeyError):
|
||||
entry = UnparsedEntry(line)
|
||||
yield entry
|
||||
|
||||
def hasHostKey(self, hostname, key):
|
||||
"""
|
||||
Check for an entry with matching hostname and key.
|
||||
|
||||
@param hostname: A hostname or IP address literal to check for.
|
||||
@type hostname: L{bytes}
|
||||
|
||||
@param key: The public key to check for.
|
||||
@type key: L{Key}
|
||||
|
||||
@return: C{True} if the given hostname and key are present in this file,
|
||||
C{False} if they are not.
|
||||
@rtype: L{bool}
|
||||
|
||||
@raise HostKeyChanged: if the host key found for the given hostname
|
||||
does not match the given key.
|
||||
"""
|
||||
for lineidx, entry in enumerate(self.iterentries(), -len(self._added)):
|
||||
if entry.matchesHost(hostname) and entry.keyType == key.sshType():
|
||||
if entry.matchesKey(key):
|
||||
return True
|
||||
else:
|
||||
# Notice that lineidx is 0-based but HostKeyChanged.lineno
|
||||
# is 1-based.
|
||||
if lineidx < 0:
|
||||
line = None
|
||||
path = None
|
||||
else:
|
||||
line = lineidx + 1
|
||||
path = self._savePath
|
||||
raise HostKeyChanged(entry, path, line)
|
||||
return False
|
||||
|
||||
def verifyHostKey(
|
||||
self, ui: ConsoleUI, hostname: bytes, ip: bytes, key: Key
|
||||
) -> Deferred[bool]:
|
||||
"""
|
||||
Verify the given host key for the given IP and host, asking for
|
||||
confirmation from, and notifying, the given UI about changes to this
|
||||
file.
|
||||
|
||||
@param ui: The user interface to request an IP address from.
|
||||
|
||||
@param hostname: The hostname that the user requested to connect to.
|
||||
|
||||
@param ip: The string representation of the IP address that is actually
|
||||
being connected to.
|
||||
|
||||
@param key: The public key of the server.
|
||||
|
||||
@return: a L{Deferred} that fires with True when the key has been
|
||||
verified, or fires with an errback when the key either cannot be
|
||||
verified or has changed.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
hhk = defer.execute(self.hasHostKey, hostname, key)
|
||||
|
||||
def gotHasKey(result: bool) -> bool | Deferred[bool]:
|
||||
if result:
|
||||
if not self.hasHostKey(ip, key):
|
||||
addMessage = (
|
||||
f"Warning: Permanently added the {key.type()} host key"
|
||||
f" for IP address '{ip.decode()}' to the list of known"
|
||||
" hosts.\n"
|
||||
)
|
||||
ui.warn(addMessage.encode("utf-8"))
|
||||
self.addHostKey(ip, key)
|
||||
self.save()
|
||||
return result
|
||||
else:
|
||||
|
||||
def promptResponse(response: bool) -> bool:
|
||||
if response:
|
||||
self.addHostKey(hostname, key)
|
||||
self.addHostKey(ip, key)
|
||||
self.save()
|
||||
return response
|
||||
else:
|
||||
raise UserRejectedKey()
|
||||
|
||||
keytype: str = key.type()
|
||||
|
||||
if keytype == "EC":
|
||||
keytype = "ECDSA"
|
||||
|
||||
prompt = (
|
||||
"The authenticity of host '%s (%s)' "
|
||||
"can't be established.\n"
|
||||
"%s key fingerprint is SHA256:%s.\n"
|
||||
"Are you sure you want to continue connecting (yes/no)? "
|
||||
% (
|
||||
nativeString(hostname),
|
||||
nativeString(ip),
|
||||
keytype,
|
||||
key.fingerprint(format=FingerprintFormats.SHA256_BASE64),
|
||||
)
|
||||
)
|
||||
proceed = ui.prompt(prompt.encode(sys.getdefaultencoding()))
|
||||
return proceed.addCallback(promptResponse)
|
||||
|
||||
return hhk.addCallback(gotHasKey)
|
||||
|
||||
def addHostKey(self, hostname: bytes, key: Key) -> HashedEntry:
|
||||
"""
|
||||
Add a new L{HashedEntry} to the key database.
|
||||
|
||||
Note that you still need to call L{KnownHostsFile.save} if you wish
|
||||
these changes to be persisted.
|
||||
|
||||
@param hostname: A hostname or IP address literal to associate with the
|
||||
new entry.
|
||||
@type hostname: L{bytes}
|
||||
|
||||
@param key: The public key to associate with the new entry.
|
||||
@type key: L{Key}
|
||||
|
||||
@return: The L{HashedEntry} that was added.
|
||||
@rtype: L{HashedEntry}
|
||||
"""
|
||||
salt = secureRandom(20)
|
||||
keyType = key.sshType()
|
||||
entry = HashedEntry(salt, _hmacedString(salt, hostname), keyType, key, None)
|
||||
self._added.append(entry)
|
||||
return entry
|
||||
|
||||
def save(self) -> None:
|
||||
"""
|
||||
Save this L{KnownHostsFile} to the path it was loaded from.
|
||||
"""
|
||||
p = self._savePath.parent()
|
||||
if not p.isdir():
|
||||
p.makedirs()
|
||||
|
||||
mode: Literal["a", "w"] = "w" if self._clobber else "a"
|
||||
with self._savePath.open(mode) as hostsFileObj:
|
||||
if self._added:
|
||||
hostsFileObj.write(
|
||||
b"\n".join([entry.toString() for entry in self._added]) + b"\n"
|
||||
)
|
||||
self._added = []
|
||||
self._clobber = False
|
||||
|
||||
@classmethod
|
||||
def fromPath(cls, path: FilePath[str]) -> KnownHostsFile:
|
||||
"""
|
||||
Create a new L{KnownHostsFile}, potentially reading existing known
|
||||
hosts information from the given file.
|
||||
|
||||
@param path: A path object to use for both reading contents from and
|
||||
later saving to. If no file exists at this path, it is not an
|
||||
error; a L{KnownHostsFile} with no entries is returned.
|
||||
|
||||
@return: A L{KnownHostsFile} initialized with entries from C{path}.
|
||||
"""
|
||||
knownHosts = cls(path)
|
||||
knownHosts._clobber = False
|
||||
return knownHosts
|
||||
|
||||
|
||||
class ConsoleUI:
|
||||
"""
|
||||
A UI object that can ask true/false questions and post notifications on the
|
||||
console, to be used during key verification.
|
||||
"""
|
||||
|
||||
def __init__(self, opener: Callable[[], IO[bytes]]):
|
||||
"""
|
||||
@param opener: A no-argument callable which should open a console
|
||||
binary-mode file-like object to be used for reading and writing.
|
||||
This initializes the C{opener} attribute.
|
||||
@type opener: callable taking no arguments and returning a read/write
|
||||
file-like object
|
||||
"""
|
||||
self.opener = opener
|
||||
|
||||
def prompt(self, text: bytes) -> Deferred[bool]:
|
||||
"""
|
||||
Write the given text as a prompt to the console output, then read a
|
||||
result from the console input.
|
||||
|
||||
@param text: Something to present to a user to solicit a yes or no
|
||||
response.
|
||||
@type text: L{bytes}
|
||||
|
||||
@return: a L{Deferred} which fires with L{True} when the user answers
|
||||
'yes' and L{False} when the user answers 'no'. It may errback if
|
||||
there were any I/O errors.
|
||||
"""
|
||||
d = defer.succeed(None)
|
||||
|
||||
def body(ignored):
|
||||
with closing(self.opener()) as f:
|
||||
f.write(text)
|
||||
while True:
|
||||
answer = f.readline().strip().lower()
|
||||
if answer == b"yes":
|
||||
return True
|
||||
elif answer in {b"no", b""}:
|
||||
return False
|
||||
else:
|
||||
f.write(b"Please type 'yes' or 'no': ")
|
||||
|
||||
return d.addCallback(body)
|
||||
|
||||
def warn(self, text: bytes) -> None:
|
||||
"""
|
||||
Notify the user (non-interactively) of the provided text, by writing it
|
||||
to the console.
|
||||
|
||||
@param text: Some information the user is to be made aware of.
|
||||
"""
|
||||
try:
|
||||
with closing(self.opener()) as f:
|
||||
f.write(text)
|
||||
except Exception:
|
||||
log.failure("Failed to write to console")
|
||||
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import sys
|
||||
from typing import List, Optional, Union
|
||||
|
||||
#
|
||||
from twisted.conch.ssh.transport import SSHCiphers, SSHClientTransport
|
||||
from twisted.python import usage
|
||||
|
||||
|
||||
class ConchOptions(usage.Options):
|
||||
optParameters: List[List[Optional[Union[str, int]]]] = [
|
||||
["user", "l", None, "Log in using this user name."],
|
||||
["identity", "i", None],
|
||||
["ciphers", "c", None],
|
||||
["macs", "m", None],
|
||||
["port", "p", None, "Connect to this port. Server must be on the same port."],
|
||||
["option", "o", None, "Ignored OpenSSH options"],
|
||||
["host-key-algorithms", "", None],
|
||||
["known-hosts", "", None, "File to check for host keys"],
|
||||
["user-authentications", "", None, "Types of user authentications to use."],
|
||||
["logfile", "", None, "File to log to, or - for stdout"],
|
||||
]
|
||||
|
||||
optFlags = [
|
||||
["version", "V", "Display version number only."],
|
||||
["compress", "C", "Enable compression."],
|
||||
["log", "v", "Enable logging (defaults to stderr)"],
|
||||
["nox11", "x", "Disable X11 connection forwarding (default)"],
|
||||
["agent", "A", "Enable authentication agent forwarding"],
|
||||
["noagent", "a", "Disable authentication agent forwarding (default)"],
|
||||
["reconnect", "r", "Reconnect to the server if the connection is lost."],
|
||||
]
|
||||
|
||||
compData = usage.Completions(
|
||||
mutuallyExclusive=[("agent", "noagent")],
|
||||
optActions={
|
||||
"user": usage.CompleteUsernames(),
|
||||
"ciphers": usage.CompleteMultiList(
|
||||
[v.decode() for v in SSHCiphers.cipherMap.keys()],
|
||||
descr="ciphers to choose from",
|
||||
),
|
||||
"macs": usage.CompleteMultiList(
|
||||
[v.decode() for v in SSHCiphers.macMap.keys()],
|
||||
descr="macs to choose from",
|
||||
),
|
||||
"host-key-algorithms": usage.CompleteMultiList(
|
||||
[v.decode() for v in SSHClientTransport.supportedPublicKeys],
|
||||
descr="host key algorithms to choose from",
|
||||
),
|
||||
# "user-authentications": usage.CompleteMultiList(?
|
||||
# descr='user authentication types' ),
|
||||
},
|
||||
extraActions=[
|
||||
usage.CompleteUserAtHost(),
|
||||
usage.Completer(descr="command"),
|
||||
usage.Completer(descr="argument", repeat=True),
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kw):
|
||||
usage.Options.__init__(self, *args, **kw)
|
||||
self.identitys = []
|
||||
self.conns = None
|
||||
|
||||
def opt_identity(self, i):
|
||||
"""Identity for public-key authentication"""
|
||||
self.identitys.append(i)
|
||||
|
||||
def opt_ciphers(self, ciphers):
|
||||
"Select encryption algorithms"
|
||||
ciphers = ciphers.split(",")
|
||||
for cipher in ciphers:
|
||||
if cipher not in SSHCiphers.cipherMap:
|
||||
sys.exit("Unknown cipher type '%s'" % cipher)
|
||||
self["ciphers"] = ciphers
|
||||
|
||||
def opt_macs(self, macs):
|
||||
"Specify MAC algorithms"
|
||||
if isinstance(macs, str):
|
||||
macs = macs.encode("utf-8")
|
||||
macs = macs.split(b",")
|
||||
for mac in macs:
|
||||
if mac not in SSHCiphers.macMap:
|
||||
sys.exit("Unknown mac type '%r'" % mac)
|
||||
self["macs"] = macs
|
||||
|
||||
def opt_host_key_algorithms(self, hkas):
|
||||
"Select host key algorithms"
|
||||
if isinstance(hkas, str):
|
||||
hkas = hkas.encode("utf-8")
|
||||
hkas = hkas.split(b",")
|
||||
for hka in hkas:
|
||||
if hka not in SSHClientTransport.supportedPublicKeys:
|
||||
sys.exit("Unknown host key type '%r'" % hka)
|
||||
self["host-key-algorithms"] = hkas
|
||||
|
||||
def opt_user_authentications(self, uas):
|
||||
"Choose how to authenticate to the remote server"
|
||||
if isinstance(uas, str):
|
||||
uas = uas.encode("utf-8")
|
||||
self["user-authentications"] = uas.split(b",")
|
||||
|
||||
|
||||
# def opt_compress(self):
|
||||
# "Enable compression"
|
||||
# self.enableCompression = 1
|
||||
# SSHClientTransport.supportedCompressions[0:1] = ['zlib']
|
||||
845
.venv/lib/python3.12/site-packages/twisted/conch/endpoints.py
Normal file
845
.venv/lib/python3.12/site-packages/twisted/conch/endpoints.py
Normal file
@@ -0,0 +1,845 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_endpoints -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Endpoint implementations of various SSH interactions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
"AuthenticationFailed",
|
||||
"SSHCommandAddress",
|
||||
"SSHCommandClientEndpoint",
|
||||
]
|
||||
|
||||
import signal
|
||||
from io import BytesIO
|
||||
from os.path import expanduser
|
||||
from struct import unpack
|
||||
from typing import IO, Any
|
||||
|
||||
from zope.interface import Interface, implementer
|
||||
|
||||
from twisted.conch.client.agent import SSHAgentClient
|
||||
from twisted.conch.client.default import _KNOWN_HOSTS
|
||||
from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile
|
||||
from twisted.conch.ssh.channel import SSHChannel
|
||||
from twisted.conch.ssh.common import NS, getNS
|
||||
from twisted.conch.ssh.connection import SSHConnection
|
||||
from twisted.conch.ssh.keys import Key
|
||||
from twisted.conch.ssh.transport import SSHClientTransport
|
||||
from twisted.conch.ssh.userauth import SSHUserAuthClient
|
||||
from twisted.internet.defer import CancelledError, Deferred, succeed
|
||||
from twisted.internet.endpoints import TCP4ClientEndpoint, connectProtocol
|
||||
from twisted.internet.error import ConnectionDone, ProcessTerminated
|
||||
from twisted.internet.interfaces import IStreamClientEndpoint
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.logger import Logger
|
||||
from twisted.python.compat import nativeString, networkString
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.python.filepath import FilePath
|
||||
|
||||
|
||||
class AuthenticationFailed(Exception):
|
||||
"""
|
||||
An SSH session could not be established because authentication was not
|
||||
successful.
|
||||
"""
|
||||
|
||||
|
||||
# This should be public. See #6541.
|
||||
class _ISSHConnectionCreator(Interface):
|
||||
"""
|
||||
An L{_ISSHConnectionCreator} knows how to create SSH connections somehow.
|
||||
"""
|
||||
|
||||
def secureConnection():
|
||||
"""
|
||||
Return a new, connected, secured, but not yet authenticated instance of
|
||||
L{twisted.conch.ssh.transport.SSHServerTransport} or
|
||||
L{twisted.conch.ssh.transport.SSHClientTransport}.
|
||||
"""
|
||||
|
||||
def cleanupConnection(connection, immediate):
|
||||
"""
|
||||
Perform cleanup necessary for a connection object previously returned
|
||||
from this creator's C{secureConnection} method.
|
||||
|
||||
@param connection: An L{twisted.conch.ssh.transport.SSHServerTransport}
|
||||
or L{twisted.conch.ssh.transport.SSHClientTransport} returned by a
|
||||
previous call to C{secureConnection}. It is no longer needed by
|
||||
the caller of that method and may be closed or otherwise cleaned up
|
||||
as necessary.
|
||||
|
||||
@param immediate: If C{True} don't wait for any network communication,
|
||||
just close the connection immediately and as aggressively as
|
||||
necessary.
|
||||
"""
|
||||
|
||||
|
||||
class SSHCommandAddress:
|
||||
"""
|
||||
An L{SSHCommandAddress} instance represents the address of an SSH server, a
|
||||
username which was used to authenticate with that server, and a command
|
||||
which was run there.
|
||||
|
||||
@ivar server: See L{__init__}
|
||||
@ivar username: See L{__init__}
|
||||
@ivar command: See L{__init__}
|
||||
"""
|
||||
|
||||
def __init__(self, server, username, command):
|
||||
"""
|
||||
@param server: The address of the SSH server on which the command is
|
||||
running.
|
||||
@type server: L{IAddress} provider
|
||||
|
||||
@param username: An authentication username which was used to
|
||||
authenticate against the server at the given address.
|
||||
@type username: L{bytes}
|
||||
|
||||
@param command: A command which was run in a session channel on the
|
||||
server at the given address.
|
||||
@type command: L{bytes}
|
||||
"""
|
||||
self.server = server
|
||||
self.username = username
|
||||
self.command = command
|
||||
|
||||
|
||||
class _CommandChannel(SSHChannel):
|
||||
"""
|
||||
A L{_CommandChannel} executes a command in a session channel and connects
|
||||
its input and output to an L{IProtocol} provider.
|
||||
|
||||
@ivar _creator: See L{__init__}
|
||||
@ivar _command: See L{__init__}
|
||||
@ivar _protocolFactory: See L{__init__}
|
||||
@ivar _commandConnected: See L{__init__}
|
||||
@ivar _protocol: An L{IProtocol} provider created using C{_protocolFactory}
|
||||
which is hooked up to the running command's input and output streams.
|
||||
"""
|
||||
|
||||
name = b"session"
|
||||
_log = Logger()
|
||||
|
||||
def __init__(self, creator, command, protocolFactory, commandConnected):
|
||||
"""
|
||||
@param creator: The L{_ISSHConnectionCreator} provider which was used
|
||||
to get the connection which this channel exists on.
|
||||
@type creator: L{_ISSHConnectionCreator} provider
|
||||
|
||||
@param command: The command to be executed.
|
||||
@type command: L{bytes}
|
||||
|
||||
@param protocolFactory: A client factory to use to build a L{IProtocol}
|
||||
provider to use to associate with the running command.
|
||||
|
||||
@param commandConnected: A L{Deferred} to use to signal that execution
|
||||
of the command has failed or that it has succeeded and the command
|
||||
is now running.
|
||||
@type commandConnected: L{Deferred}
|
||||
"""
|
||||
SSHChannel.__init__(self)
|
||||
self._creator = creator
|
||||
self._command = command
|
||||
self._protocolFactory = protocolFactory
|
||||
self._commandConnected = commandConnected
|
||||
self._reason = None
|
||||
|
||||
def openFailed(self, reason):
|
||||
"""
|
||||
When the request to open a new channel to run this command in fails,
|
||||
fire the C{commandConnected} deferred with a failure indicating that.
|
||||
"""
|
||||
self._commandConnected.errback(reason)
|
||||
|
||||
def channelOpen(self, ignored):
|
||||
"""
|
||||
When the request to open a new channel to run this command in succeeds,
|
||||
issue an C{"exec"} request to run the command.
|
||||
"""
|
||||
command = self.conn.sendRequest(
|
||||
self, b"exec", NS(self._command), wantReply=True
|
||||
)
|
||||
command.addCallbacks(self._execSuccess, self._execFailure)
|
||||
|
||||
def _execFailure(self, reason):
|
||||
"""
|
||||
When the request to execute the command in this channel fails, fire the
|
||||
C{commandConnected} deferred with a failure indicating this.
|
||||
|
||||
@param reason: The cause of the command execution failure.
|
||||
@type reason: L{Failure}
|
||||
"""
|
||||
self._commandConnected.errback(reason)
|
||||
|
||||
def _execSuccess(self, ignored):
|
||||
"""
|
||||
When the request to execute the command in this channel succeeds, use
|
||||
C{protocolFactory} to build a protocol to handle the command's input
|
||||
and output and connect the protocol to a transport representing those
|
||||
streams.
|
||||
|
||||
Also fire C{commandConnected} with the created protocol after it is
|
||||
connected to its transport.
|
||||
|
||||
@param ignored: The (ignored) result of the execute request
|
||||
"""
|
||||
self._protocol = self._protocolFactory.buildProtocol(
|
||||
SSHCommandAddress(
|
||||
self.conn.transport.transport.getPeer(),
|
||||
self.conn.transport.creator.username,
|
||||
self.conn.transport.creator.command,
|
||||
)
|
||||
)
|
||||
self._protocol.makeConnection(self)
|
||||
self._commandConnected.callback(self._protocol)
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
When the command's stdout data arrives over the channel, deliver it to
|
||||
the protocol instance.
|
||||
|
||||
@param data: The bytes from the command's stdout.
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
self._protocol.dataReceived(data)
|
||||
|
||||
def request_exit_status(self, data):
|
||||
"""
|
||||
When the server sends the command's exit status, record it for later
|
||||
delivery to the protocol.
|
||||
|
||||
@param data: The network-order four byte representation of the exit
|
||||
status of the command.
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
(status,) = unpack(">L", data)
|
||||
if status != 0:
|
||||
self._reason = ProcessTerminated(status, None, None)
|
||||
|
||||
def request_exit_signal(self, data):
|
||||
"""
|
||||
When the server sends the command's exit status, record it for later
|
||||
delivery to the protocol.
|
||||
|
||||
@param data: The network-order four byte representation of the exit
|
||||
signal of the command.
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
shortSignalName, data = getNS(data)
|
||||
coreDumped, data = bool(ord(data[0:1])), data[1:]
|
||||
errorMessage, data = getNS(data)
|
||||
languageTag, data = getNS(data)
|
||||
signalName = f"SIG{nativeString(shortSignalName)}"
|
||||
signalID = getattr(signal, signalName, -1)
|
||||
self._log.info(
|
||||
"Process exited with signal {shortSignalName!r};"
|
||||
" core dumped: {coreDumped};"
|
||||
" error message: {errorMessage};"
|
||||
" language: {languageTag!r}",
|
||||
shortSignalName=shortSignalName,
|
||||
coreDumped=coreDumped,
|
||||
errorMessage=errorMessage.decode("utf-8"),
|
||||
languageTag=languageTag,
|
||||
)
|
||||
self._reason = ProcessTerminated(None, signalID, None)
|
||||
|
||||
def closed(self):
|
||||
"""
|
||||
When the channel closes, deliver disconnection notification to the
|
||||
protocol.
|
||||
"""
|
||||
self._creator.cleanupConnection(self.conn, False)
|
||||
if self._reason is None:
|
||||
reason = ConnectionDone("ssh channel closed")
|
||||
else:
|
||||
reason = self._reason
|
||||
self._protocol.connectionLost(Failure(reason))
|
||||
|
||||
|
||||
class _ConnectionReady(SSHConnection):
|
||||
"""
|
||||
L{_ConnectionReady} is an L{SSHConnection} (an SSH service) which only
|
||||
propagates the I{serviceStarted} event to a L{Deferred} to be handled
|
||||
elsewhere.
|
||||
"""
|
||||
|
||||
def __init__(self, ready):
|
||||
"""
|
||||
@param ready: A L{Deferred} which should be fired when
|
||||
I{serviceStarted} happens.
|
||||
"""
|
||||
SSHConnection.__init__(self)
|
||||
self._ready = ready
|
||||
|
||||
def serviceStarted(self):
|
||||
"""
|
||||
When the SSH I{connection} I{service} this object represents is ready
|
||||
to be used, fire the C{connectionReady} L{Deferred} to publish that
|
||||
event to some other interested party.
|
||||
|
||||
"""
|
||||
self._ready.callback(self)
|
||||
del self._ready
|
||||
|
||||
|
||||
class _UserAuth(SSHUserAuthClient):
|
||||
"""
|
||||
L{_UserAuth} implements the client part of SSH user authentication in the
|
||||
convenient way a user might expect if they are familiar with the
|
||||
interactive I{ssh} command line client.
|
||||
|
||||
L{_UserAuth} supports key-based authentication, password-based
|
||||
authentication, and delegating authentication to an agent.
|
||||
"""
|
||||
|
||||
password = None
|
||||
keys = None
|
||||
agent = None
|
||||
|
||||
def getPublicKey(self):
|
||||
"""
|
||||
Retrieve the next public key object to offer to the server, possibly
|
||||
delegating to an authentication agent if there is one.
|
||||
|
||||
@return: The public part of a key pair that could be used to
|
||||
authenticate with the server, or L{None} if there are no more
|
||||
public keys to try.
|
||||
@rtype: L{twisted.conch.ssh.keys.Key} or L{None}
|
||||
"""
|
||||
if self.agent is not None:
|
||||
return self.agent.getPublicKey()
|
||||
|
||||
if self.keys:
|
||||
self.key = self.keys.pop(0)
|
||||
else:
|
||||
self.key = None
|
||||
return self.key.public()
|
||||
|
||||
def signData(self, publicKey, signData):
|
||||
"""
|
||||
Extend the base signing behavior by using an SSH agent to sign the
|
||||
data, if one is available.
|
||||
|
||||
@type publicKey: L{Key}
|
||||
@type signData: L{str}
|
||||
"""
|
||||
if self.agent is not None:
|
||||
return self.agent.signData(publicKey.blob(), signData)
|
||||
else:
|
||||
return SSHUserAuthClient.signData(self, publicKey, signData)
|
||||
|
||||
def getPrivateKey(self):
|
||||
"""
|
||||
Get the private part of a key pair to use for authentication. The key
|
||||
corresponds to the public part most recently returned from
|
||||
C{getPublicKey}.
|
||||
|
||||
@return: A L{Deferred} which fires with the private key.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return succeed(self.key)
|
||||
|
||||
def getPassword(self):
|
||||
"""
|
||||
Get the password to use for authentication.
|
||||
|
||||
@return: A L{Deferred} which fires with the password, or L{None} if the
|
||||
password was not specified.
|
||||
"""
|
||||
if self.password is None:
|
||||
return
|
||||
return succeed(self.password)
|
||||
|
||||
def ssh_USERAUTH_SUCCESS(self, packet):
|
||||
"""
|
||||
Handle user authentication success in the normal way, but also make a
|
||||
note of the state change on the L{_CommandTransport}.
|
||||
"""
|
||||
self.transport._state = b"CHANNELLING"
|
||||
return SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet)
|
||||
|
||||
def connectToAgent(self, endpoint):
|
||||
"""
|
||||
Set up a connection to the authentication agent and trigger its
|
||||
initialization.
|
||||
|
||||
@param endpoint: An endpoint which can be used to connect to the
|
||||
authentication agent.
|
||||
@type endpoint: L{IStreamClientEndpoint} provider
|
||||
|
||||
@return: A L{Deferred} which fires when the agent connection is ready
|
||||
for use.
|
||||
"""
|
||||
factory = Factory()
|
||||
factory.protocol = SSHAgentClient
|
||||
d = endpoint.connect(factory)
|
||||
|
||||
def connected(agent):
|
||||
self.agent = agent
|
||||
return agent.getPublicKeys()
|
||||
|
||||
d.addCallback(connected)
|
||||
return d
|
||||
|
||||
def loseAgentConnection(self):
|
||||
"""
|
||||
Disconnect the agent.
|
||||
"""
|
||||
if self.agent is None:
|
||||
return
|
||||
self.agent.transport.loseConnection()
|
||||
|
||||
|
||||
class _CommandTransport(SSHClientTransport):
|
||||
"""
|
||||
L{_CommandTransport} is an SSH client I{transport} which includes a host
|
||||
key verification step before it will proceed to secure the connection.
|
||||
|
||||
L{_CommandTransport} also knows how to set up a connection to an
|
||||
authentication agent if it is told where it can connect to one.
|
||||
|
||||
@ivar _userauth: The L{_UserAuth} instance which is in charge of the
|
||||
overall authentication process or L{None} if the SSH connection has not
|
||||
reach yet the C{user-auth} service.
|
||||
@type _userauth: L{_UserAuth}
|
||||
"""
|
||||
|
||||
# STARTING -> SECURING -> AUTHENTICATING -> CHANNELLING -> RUNNING
|
||||
_state = b"STARTING"
|
||||
|
||||
_hostKeyFailure = None
|
||||
|
||||
_userauth = None
|
||||
|
||||
def __init__(self, creator):
|
||||
"""
|
||||
@param creator: The L{_NewConnectionHelper} that created this
|
||||
connection.
|
||||
|
||||
@type creator: L{_NewConnectionHelper}.
|
||||
"""
|
||||
self.connectionReady = Deferred(lambda d: self.transport.abortConnection())
|
||||
# Clear the reference to that deferred to help the garbage collector
|
||||
# and to signal to other parts of this implementation (in particular
|
||||
# connectionLost) that it has already been fired and does not need to
|
||||
# be fired again.
|
||||
|
||||
def readyFired(result):
|
||||
self.connectionReady = None
|
||||
return result
|
||||
|
||||
self.connectionReady.addBoth(readyFired)
|
||||
self.creator = creator
|
||||
|
||||
def verifyHostKey(self, hostKey, fingerprint):
|
||||
"""
|
||||
Ask the L{KnownHostsFile} provider available on the factory which
|
||||
created this protocol this protocol to verify the given host key.
|
||||
|
||||
@return: A L{Deferred} which fires with the result of
|
||||
L{KnownHostsFile.verifyHostKey}.
|
||||
"""
|
||||
hostname = self.creator.hostname
|
||||
ip = networkString(self.transport.getPeer().host)
|
||||
|
||||
self._state = b"SECURING"
|
||||
d = self.creator.knownHosts.verifyHostKey(
|
||||
self.creator.ui, hostname, ip, Key.fromString(hostKey)
|
||||
)
|
||||
d.addErrback(self._saveHostKeyFailure)
|
||||
return d
|
||||
|
||||
def _saveHostKeyFailure(self, reason):
|
||||
"""
|
||||
When host key verification fails, record the reason for the failure in
|
||||
order to fire a L{Deferred} with it later.
|
||||
|
||||
@param reason: The cause of the host key verification failure.
|
||||
@type reason: L{Failure}
|
||||
|
||||
@return: C{reason}
|
||||
@rtype: L{Failure}
|
||||
"""
|
||||
self._hostKeyFailure = reason
|
||||
return reason
|
||||
|
||||
def connectionSecure(self):
|
||||
"""
|
||||
When the connection is secure, start the authentication process.
|
||||
"""
|
||||
self._state = b"AUTHENTICATING"
|
||||
|
||||
command = _ConnectionReady(self.connectionReady)
|
||||
|
||||
self._userauth = _UserAuth(self.creator.username, command)
|
||||
self._userauth.password = self.creator.password
|
||||
if self.creator.keys:
|
||||
self._userauth.keys = list(self.creator.keys)
|
||||
|
||||
if self.creator.agentEndpoint is not None:
|
||||
d = self._userauth.connectToAgent(self.creator.agentEndpoint)
|
||||
else:
|
||||
d = succeed(None)
|
||||
|
||||
def maybeGotAgent(ignored):
|
||||
self.requestService(self._userauth)
|
||||
|
||||
d.addBoth(maybeGotAgent)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
When the underlying connection to the SSH server is lost, if there were
|
||||
any connection setup errors, propagate them. Also, clean up the
|
||||
connection to the ssh agent if one was created.
|
||||
"""
|
||||
if self._userauth:
|
||||
self._userauth.loseAgentConnection()
|
||||
|
||||
if self._state == b"RUNNING" or self.connectionReady is None:
|
||||
return
|
||||
if self._state == b"SECURING" and self._hostKeyFailure is not None:
|
||||
reason = self._hostKeyFailure
|
||||
elif self._state == b"AUTHENTICATING":
|
||||
reason = Failure(
|
||||
AuthenticationFailed("Connection lost while authenticating")
|
||||
)
|
||||
self.connectionReady.errback(reason)
|
||||
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
class SSHCommandClientEndpoint:
|
||||
"""
|
||||
L{SSHCommandClientEndpoint} exposes the command-executing functionality of
|
||||
SSH servers.
|
||||
|
||||
L{SSHCommandClientEndpoint} can set up a new SSH connection, authenticate
|
||||
it in any one of a number of different ways (keys, passwords, agents),
|
||||
launch a command over that connection and then associate its input and
|
||||
output with a protocol.
|
||||
|
||||
It can also re-use an existing, already-authenticated SSH connection
|
||||
(perhaps one which already has some SSH channels being used for other
|
||||
purposes). In this case it creates a new SSH channel to use to execute the
|
||||
command. Notably this means it supports multiplexing several different
|
||||
command invocations over a single SSH connection.
|
||||
"""
|
||||
|
||||
def __init__(self, creator, command):
|
||||
"""
|
||||
@param creator: An L{_ISSHConnectionCreator} provider which will be
|
||||
used to set up the SSH connection which will be used to run a
|
||||
command.
|
||||
@type creator: L{_ISSHConnectionCreator} provider
|
||||
|
||||
@param command: The command line to execute on the SSH server. This
|
||||
byte string is interpreted by a shell on the SSH server, so it may
|
||||
have a value like C{"ls /"}. Take care when trying to run a
|
||||
command like C{"/Volumes/My Stuff/a-program"} - spaces (and other
|
||||
special bytes) may require escaping.
|
||||
@type command: L{bytes}
|
||||
|
||||
"""
|
||||
self._creator = creator
|
||||
self._command = command
|
||||
|
||||
@classmethod
|
||||
def newConnection(
|
||||
cls,
|
||||
reactor,
|
||||
command,
|
||||
username,
|
||||
hostname,
|
||||
port=None,
|
||||
keys=None,
|
||||
password=None,
|
||||
agentEndpoint=None,
|
||||
knownHosts=None,
|
||||
ui=None,
|
||||
):
|
||||
"""
|
||||
Create and return a new endpoint which will try to create a new
|
||||
connection to an SSH server and run a command over it. It will also
|
||||
close the connection if there are problems leading up to the command
|
||||
being executed, after the command finishes, or if the connection
|
||||
L{Deferred} is cancelled.
|
||||
|
||||
@param reactor: The reactor to use to establish the connection.
|
||||
@type reactor: L{IReactorTCP} provider
|
||||
|
||||
@param command: See L{__init__}'s C{command} argument.
|
||||
|
||||
@param username: The username with which to authenticate to the SSH
|
||||
server.
|
||||
@type username: L{bytes}
|
||||
|
||||
@param hostname: The hostname of the SSH server.
|
||||
@type hostname: L{bytes}
|
||||
|
||||
@param port: The port number of the SSH server. By default, the
|
||||
standard SSH port number is used.
|
||||
@type port: L{int}
|
||||
|
||||
@param keys: Private keys with which to authenticate to the SSH server,
|
||||
if key authentication is to be attempted (otherwise L{None}).
|
||||
@type keys: L{list} of L{Key}
|
||||
|
||||
@param password: The password with which to authenticate to the SSH
|
||||
server, if password authentication is to be attempted (otherwise
|
||||
L{None}).
|
||||
@type password: L{bytes} or L{None}
|
||||
|
||||
@param agentEndpoint: An L{IStreamClientEndpoint} provider which may be
|
||||
used to connect to an SSH agent, if one is to be used to help with
|
||||
authentication.
|
||||
@type agentEndpoint: L{IStreamClientEndpoint} provider
|
||||
|
||||
@param knownHosts: The currently known host keys, used to check the
|
||||
host key presented by the server we actually connect to.
|
||||
@type knownHosts: L{KnownHostsFile}
|
||||
|
||||
@param ui: An object for interacting with users to make decisions about
|
||||
whether to accept the server host keys. If L{None}, a L{ConsoleUI}
|
||||
connected to /dev/tty will be used; if /dev/tty is unavailable, an
|
||||
object which answers C{b"no"} to all prompts will be used.
|
||||
@type ui: L{None} or L{ConsoleUI}
|
||||
|
||||
@return: A new instance of C{cls} (probably
|
||||
L{SSHCommandClientEndpoint}).
|
||||
"""
|
||||
helper = _NewConnectionHelper(
|
||||
reactor,
|
||||
hostname,
|
||||
port,
|
||||
command,
|
||||
username,
|
||||
keys,
|
||||
password,
|
||||
agentEndpoint,
|
||||
knownHosts,
|
||||
ui,
|
||||
)
|
||||
return cls(helper, command)
|
||||
|
||||
@classmethod
|
||||
def existingConnection(cls, connection, command):
|
||||
"""
|
||||
Create and return a new endpoint which will try to open a new channel
|
||||
on an existing SSH connection and run a command over it. It will
|
||||
B{not} close the connection if there is a problem executing the command
|
||||
or after the command finishes.
|
||||
|
||||
@param connection: An existing connection to an SSH server.
|
||||
@type connection: L{SSHConnection}
|
||||
|
||||
@param command: See L{SSHCommandClientEndpoint.newConnection}'s
|
||||
C{command} parameter.
|
||||
@type command: L{bytes}
|
||||
|
||||
@return: A new instance of C{cls} (probably
|
||||
L{SSHCommandClientEndpoint}).
|
||||
"""
|
||||
helper = _ExistingConnectionHelper(connection)
|
||||
return cls(helper, command)
|
||||
|
||||
def connect(self, protocolFactory):
|
||||
"""
|
||||
Set up an SSH connection, use a channel from that connection to launch
|
||||
a command, and hook the stdin and stdout of that command up as a
|
||||
transport for a protocol created by the given factory.
|
||||
|
||||
@param protocolFactory: A L{Factory} to use to create the protocol
|
||||
which will be connected to the stdin and stdout of the command on
|
||||
the SSH server.
|
||||
|
||||
@return: A L{Deferred} which will fire with an error if the connection
|
||||
cannot be set up for any reason or with the protocol instance
|
||||
created by C{protocolFactory} once it has been connected to the
|
||||
command.
|
||||
"""
|
||||
d = self._creator.secureConnection()
|
||||
d.addCallback(self._executeCommand, protocolFactory)
|
||||
return d
|
||||
|
||||
def _executeCommand(self, connection, protocolFactory):
|
||||
"""
|
||||
Given a secured SSH connection, try to execute a command in a new
|
||||
channel created on it and associate the result with a protocol from the
|
||||
given factory.
|
||||
|
||||
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
|
||||
C{connection} parameter.
|
||||
|
||||
@param protocolFactory: See L{SSHCommandClientEndpoint.connect}'s
|
||||
C{protocolFactory} parameter.
|
||||
|
||||
@return: See L{SSHCommandClientEndpoint.connect}'s return value.
|
||||
"""
|
||||
commandConnected = Deferred()
|
||||
|
||||
def disconnectOnFailure(passthrough):
|
||||
# Close the connection immediately in case of cancellation, since
|
||||
# that implies user wants it gone immediately (e.g. a timeout):
|
||||
immediate = passthrough.check(CancelledError)
|
||||
self._creator.cleanupConnection(connection, immediate)
|
||||
return passthrough
|
||||
|
||||
commandConnected.addErrback(disconnectOnFailure)
|
||||
|
||||
channel = _CommandChannel(
|
||||
self._creator, self._command, protocolFactory, commandConnected
|
||||
)
|
||||
connection.openChannel(channel)
|
||||
return commandConnected
|
||||
|
||||
|
||||
@implementer(_ISSHConnectionCreator)
|
||||
class _NewConnectionHelper:
|
||||
"""
|
||||
L{_NewConnectionHelper} implements L{_ISSHConnectionCreator} by
|
||||
establishing a brand new SSH connection, securing it, and authenticating.
|
||||
"""
|
||||
|
||||
_KNOWN_HOSTS = _KNOWN_HOSTS
|
||||
port = 22
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reactor: Any,
|
||||
hostname: str,
|
||||
port: int,
|
||||
command: str,
|
||||
username: str,
|
||||
keys: str,
|
||||
password: str,
|
||||
agentEndpoint: str,
|
||||
knownHosts: str | None,
|
||||
ui: ConsoleUI | None,
|
||||
tty: FilePath[bytes] | FilePath[str] = FilePath(b"/dev/tty"),
|
||||
):
|
||||
"""
|
||||
@param tty: The path of the tty device to use in case C{ui} is L{None}.
|
||||
@type tty: L{FilePath}
|
||||
|
||||
@see: L{SSHCommandClientEndpoint.newConnection}
|
||||
"""
|
||||
self.reactor = reactor
|
||||
self.hostname = hostname
|
||||
if port is not None:
|
||||
self.port = port
|
||||
self.command = command
|
||||
self.username = username
|
||||
self.keys = keys
|
||||
self.password = password
|
||||
self.agentEndpoint = agentEndpoint
|
||||
if knownHosts is None:
|
||||
knownHosts = self._knownHosts()
|
||||
self.knownHosts = knownHosts
|
||||
|
||||
if ui is None:
|
||||
ui = ConsoleUI(self._opener)
|
||||
self.ui = ui
|
||||
self.tty: FilePath[bytes] | FilePath[str] = tty
|
||||
|
||||
def _opener(self) -> IO[bytes]:
|
||||
"""
|
||||
Open the tty if possible, otherwise give back a file-like object from
|
||||
which C{b"no"} can be read.
|
||||
|
||||
For use as the opener argument to L{ConsoleUI}.
|
||||
"""
|
||||
try:
|
||||
return self.tty.open("r+")
|
||||
except BaseException:
|
||||
# Give back a file-like object from which can be read a byte string
|
||||
# that KnownHostsFile recognizes as rejecting some option (b"no").
|
||||
return BytesIO(b"no")
|
||||
|
||||
@classmethod
|
||||
def _knownHosts(cls):
|
||||
"""
|
||||
|
||||
@return: A L{KnownHostsFile} instance pointed at the user's personal
|
||||
I{known hosts} file.
|
||||
@rtype: L{KnownHostsFile}
|
||||
"""
|
||||
return KnownHostsFile.fromPath(FilePath(expanduser(cls._KNOWN_HOSTS)))
|
||||
|
||||
def secureConnection(self):
|
||||
"""
|
||||
Create and return a new SSH connection which has been secured and on
|
||||
which authentication has already happened.
|
||||
|
||||
@return: A L{Deferred} which fires with the ready-to-use connection or
|
||||
with a failure if something prevents the connection from being
|
||||
setup, secured, or authenticated.
|
||||
"""
|
||||
protocol = _CommandTransport(self)
|
||||
ready = protocol.connectionReady
|
||||
|
||||
sshClient = TCP4ClientEndpoint(
|
||||
self.reactor, nativeString(self.hostname), self.port
|
||||
)
|
||||
|
||||
d = connectProtocol(sshClient, protocol)
|
||||
d.addCallback(lambda ignored: ready)
|
||||
return d
|
||||
|
||||
def cleanupConnection(self, connection, immediate):
|
||||
"""
|
||||
Clean up the connection by closing it. The command running on the
|
||||
endpoint has ended so the connection is no longer needed.
|
||||
|
||||
@param connection: The L{SSHConnection} to close.
|
||||
@type connection: L{SSHConnection}
|
||||
|
||||
@param immediate: Whether to close connection immediately.
|
||||
@type immediate: L{bool}.
|
||||
"""
|
||||
if immediate:
|
||||
# We're assuming the underlying connection is an ITCPTransport,
|
||||
# which is what the current implementation is restricted to:
|
||||
connection.transport.transport.abortConnection()
|
||||
else:
|
||||
connection.transport.loseConnection()
|
||||
|
||||
|
||||
@implementer(_ISSHConnectionCreator)
|
||||
class _ExistingConnectionHelper:
|
||||
"""
|
||||
L{_ExistingConnectionHelper} implements L{_ISSHConnectionCreator} by
|
||||
handing out an existing SSH connection which is supplied to its
|
||||
initializer.
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
"""
|
||||
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
|
||||
C{connection} parameter.
|
||||
"""
|
||||
self.connection = connection
|
||||
|
||||
def secureConnection(self):
|
||||
"""
|
||||
|
||||
@return: A L{Deferred} that fires synchronously with the
|
||||
already-established connection object.
|
||||
"""
|
||||
return succeed(self.connection)
|
||||
|
||||
def cleanupConnection(self, connection, immediate):
|
||||
"""
|
||||
Do not do any cleanup on the connection. Leave that responsibility to
|
||||
whatever code created it in the first place.
|
||||
|
||||
@param connection: The L{SSHConnection} which will not be modified in
|
||||
any way.
|
||||
@type connection: L{SSHConnection}
|
||||
|
||||
@param immediate: An argument which will be ignored.
|
||||
@type immediate: L{bool}.
|
||||
"""
|
||||
96
.venv/lib/python3.12/site-packages/twisted/conch/error.py
Normal file
96
.venv/lib/python3.12/site-packages/twisted/conch/error.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
An error to represent bad things happening in Conch.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
|
||||
from twisted.cred.error import UnauthorizedLogin
|
||||
|
||||
|
||||
class ConchError(Exception):
|
||||
def __init__(self, value, data=None):
|
||||
Exception.__init__(self, value, data)
|
||||
self.value = value
|
||||
self.data = data
|
||||
|
||||
|
||||
class NotEnoughAuthentication(Exception):
|
||||
"""
|
||||
This is thrown if the authentication is valid, but is not enough to
|
||||
successfully verify the user. i.e. don't retry this type of
|
||||
authentication, try another one.
|
||||
"""
|
||||
|
||||
|
||||
class ValidPublicKey(UnauthorizedLogin):
|
||||
"""
|
||||
Raised by public key checkers when they receive public key credentials
|
||||
that don't contain a signature at all, but are valid in every other way.
|
||||
(e.g. the public key matches one in the user's authorized_keys file).
|
||||
|
||||
Protocol code (eg
|
||||
L{SSHUserAuthServer<twisted.conch.ssh.userauth.SSHUserAuthServer>}) which
|
||||
attempts to log in using
|
||||
L{ISSHPrivateKey<twisted.cred.credentials.ISSHPrivateKey>} credentials
|
||||
should be prepared to handle a failure of this type by telling the user to
|
||||
re-authenticate using the same key and to include a signature with the new
|
||||
attempt.
|
||||
|
||||
See U{http://www.ietf.org/rfc/rfc4252.txt} section 7 for more details.
|
||||
"""
|
||||
|
||||
|
||||
class IgnoreAuthentication(Exception):
|
||||
"""
|
||||
This is thrown to let the UserAuthServer know it doesn't need to handle the
|
||||
authentication anymore.
|
||||
"""
|
||||
|
||||
|
||||
class MissingKeyStoreError(Exception):
|
||||
"""
|
||||
Raised if an SSHAgentServer starts receiving data without its factory
|
||||
providing a keys dict on which to read/write key data.
|
||||
"""
|
||||
|
||||
|
||||
class UserRejectedKey(Exception):
|
||||
"""
|
||||
The user interactively rejected a key.
|
||||
"""
|
||||
|
||||
|
||||
class InvalidEntry(Exception):
|
||||
"""
|
||||
An entry in a known_hosts file could not be interpreted as a valid entry.
|
||||
"""
|
||||
|
||||
|
||||
class HostKeyChanged(Exception):
|
||||
"""
|
||||
The host key of a remote host has changed.
|
||||
|
||||
@ivar offendingEntry: The entry which contains the persistent host key that
|
||||
disagrees with the given host key.
|
||||
|
||||
@type offendingEntry: L{twisted.conch.interfaces.IKnownHostEntry}
|
||||
|
||||
@ivar path: a reference to the known_hosts file that the offending entry
|
||||
was loaded from
|
||||
|
||||
@type path: L{twisted.python.filepath.FilePath}
|
||||
|
||||
@ivar lineno: The line number of the offending entry in the given path.
|
||||
|
||||
@type lineno: L{int}
|
||||
"""
|
||||
|
||||
def __init__(self, offendingEntry, path, lineno):
|
||||
Exception.__init__(self)
|
||||
self.offendingEntry = offendingEntry
|
||||
self.path = path
|
||||
self.lineno = lineno
|
||||
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Insults: a replacement for Curses/S-Lang.
|
||||
|
||||
Very basic at the moment."""
|
||||
@@ -0,0 +1,556 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_helper -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Partial in-memory terminal emulator
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
|
||||
import re
|
||||
import string
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from incremental import Version
|
||||
|
||||
from twisted.conch.insults import insults
|
||||
from twisted.internet import defer, protocol, reactor
|
||||
from twisted.logger import Logger
|
||||
from twisted.python import _textattributes
|
||||
from twisted.python.compat import iterbytes
|
||||
from twisted.python.deprecate import deprecated, deprecatedModuleAttribute
|
||||
|
||||
FOREGROUND = 30
|
||||
BACKGROUND = 40
|
||||
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, N_COLORS = range(9)
|
||||
|
||||
|
||||
class _FormattingState(_textattributes._FormattingStateMixin):
|
||||
"""
|
||||
Represents the formatting state/attributes of a single character.
|
||||
|
||||
Character set, intensity, underlinedness, blinkitude, video
|
||||
reversal, as well as foreground and background colors made up a
|
||||
character's attributes.
|
||||
"""
|
||||
|
||||
compareAttributes = (
|
||||
"charset",
|
||||
"bold",
|
||||
"underline",
|
||||
"blink",
|
||||
"reverseVideo",
|
||||
"foreground",
|
||||
"background",
|
||||
"_subtracting",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
charset=insults.G0,
|
||||
bold=False,
|
||||
underline=False,
|
||||
blink=False,
|
||||
reverseVideo=False,
|
||||
foreground=WHITE,
|
||||
background=BLACK,
|
||||
_subtracting=False,
|
||||
):
|
||||
self.charset = charset
|
||||
self.bold = bold
|
||||
self.underline = underline
|
||||
self.blink = blink
|
||||
self.reverseVideo = reverseVideo
|
||||
self.foreground = foreground
|
||||
self.background = background
|
||||
self._subtracting = _subtracting
|
||||
|
||||
@deprecated(Version("Twisted", 13, 1, 0))
|
||||
def wantOne(self, **kw):
|
||||
"""
|
||||
Add a character attribute to a copy of this formatting state.
|
||||
|
||||
@param kw: An optional attribute name and value can be provided with
|
||||
a keyword argument.
|
||||
|
||||
@return: A formatting state instance with the new attribute.
|
||||
|
||||
@see: L{DefaultFormattingState._withAttribute}.
|
||||
"""
|
||||
k, v = kw.popitem()
|
||||
return self._withAttribute(k, v)
|
||||
|
||||
def toVT102(self):
|
||||
# Spit out a vt102 control sequence that will set up
|
||||
# all the attributes set here. Except charset.
|
||||
attrs = []
|
||||
if self._subtracting:
|
||||
attrs.append(0)
|
||||
if self.bold:
|
||||
attrs.append(insults.BOLD)
|
||||
if self.underline:
|
||||
attrs.append(insults.UNDERLINE)
|
||||
if self.blink:
|
||||
attrs.append(insults.BLINK)
|
||||
if self.reverseVideo:
|
||||
attrs.append(insults.REVERSE_VIDEO)
|
||||
if self.foreground != WHITE:
|
||||
attrs.append(FOREGROUND + self.foreground)
|
||||
if self.background != BLACK:
|
||||
attrs.append(BACKGROUND + self.background)
|
||||
if attrs:
|
||||
return "\x1b[" + ";".join(map(str, attrs)) + "m"
|
||||
return ""
|
||||
|
||||
|
||||
CharacterAttribute = _FormattingState
|
||||
|
||||
deprecatedModuleAttribute(
|
||||
Version("Twisted", 13, 1, 0),
|
||||
"Use twisted.conch.insults.text.assembleFormattedText instead.",
|
||||
"twisted.conch.insults.helper",
|
||||
"CharacterAttribute",
|
||||
)
|
||||
|
||||
|
||||
# XXX - need to support scroll regions and scroll history
|
||||
@implementer(insults.ITerminalTransport)
|
||||
class TerminalBuffer(protocol.Protocol):
|
||||
"""
|
||||
An in-memory terminal emulator.
|
||||
"""
|
||||
|
||||
for keyID in (
|
||||
b"UP_ARROW",
|
||||
b"DOWN_ARROW",
|
||||
b"RIGHT_ARROW",
|
||||
b"LEFT_ARROW",
|
||||
b"HOME",
|
||||
b"INSERT",
|
||||
b"DELETE",
|
||||
b"END",
|
||||
b"PGUP",
|
||||
b"PGDN",
|
||||
b"F1",
|
||||
b"F2",
|
||||
b"F3",
|
||||
b"F4",
|
||||
b"F5",
|
||||
b"F6",
|
||||
b"F7",
|
||||
b"F8",
|
||||
b"F9",
|
||||
b"F10",
|
||||
b"F11",
|
||||
b"F12",
|
||||
):
|
||||
execBytes = keyID + b" = object()"
|
||||
execStr = execBytes.decode("ascii")
|
||||
exec(execStr)
|
||||
|
||||
TAB = b"\t"
|
||||
BACKSPACE = b"\x7f"
|
||||
|
||||
width = 80
|
||||
height = 24
|
||||
|
||||
fill = b" "
|
||||
void = object()
|
||||
_log = Logger()
|
||||
|
||||
def getCharacter(self, x, y):
|
||||
return self.lines[y][x]
|
||||
|
||||
def connectionMade(self):
|
||||
self.reset()
|
||||
|
||||
def write(self, data):
|
||||
"""
|
||||
Add the given printable bytes to the terminal.
|
||||
|
||||
Line feeds in L{bytes} will be replaced with carriage return / line
|
||||
feed pairs.
|
||||
"""
|
||||
for b in iterbytes(data.replace(b"\n", b"\r\n")):
|
||||
self.insertAtCursor(b)
|
||||
|
||||
def _currentFormattingState(self):
|
||||
return _FormattingState(self.activeCharset, **self.graphicRendition)
|
||||
|
||||
def insertAtCursor(self, b):
|
||||
"""
|
||||
Add one byte to the terminal at the cursor and make consequent state
|
||||
updates.
|
||||
|
||||
If b is a carriage return, move the cursor to the beginning of the
|
||||
current row.
|
||||
|
||||
If b is a line feed, move the cursor to the next row or scroll down if
|
||||
the cursor is already in the last row.
|
||||
|
||||
Otherwise, if b is printable, put it at the cursor position (inserting
|
||||
or overwriting as dictated by the current mode) and move the cursor.
|
||||
"""
|
||||
if b == b"\r":
|
||||
self.x = 0
|
||||
elif b == b"\n":
|
||||
self._scrollDown()
|
||||
elif b in string.printable.encode("ascii"):
|
||||
if self.x >= self.width:
|
||||
self.nextLine()
|
||||
ch = (b, self._currentFormattingState())
|
||||
if self.modes.get(insults.modes.IRM):
|
||||
self.lines[self.y][self.x : self.x] = [ch]
|
||||
self.lines[self.y].pop()
|
||||
else:
|
||||
self.lines[self.y][self.x] = ch
|
||||
self.x += 1
|
||||
|
||||
def _emptyLine(self, width):
|
||||
return [(self.void, self._currentFormattingState()) for i in range(width)]
|
||||
|
||||
def _scrollDown(self):
|
||||
self.y += 1
|
||||
if self.y >= self.height:
|
||||
self.y -= 1
|
||||
del self.lines[0]
|
||||
self.lines.append(self._emptyLine(self.width))
|
||||
|
||||
def _scrollUp(self):
|
||||
self.y -= 1
|
||||
if self.y < 0:
|
||||
self.y = 0
|
||||
del self.lines[-1]
|
||||
self.lines.insert(0, self._emptyLine(self.width))
|
||||
|
||||
def cursorUp(self, n=1):
|
||||
self.y = max(0, self.y - n)
|
||||
|
||||
def cursorDown(self, n=1):
|
||||
self.y = min(self.height - 1, self.y + n)
|
||||
|
||||
def cursorBackward(self, n=1):
|
||||
self.x = max(0, self.x - n)
|
||||
|
||||
def cursorForward(self, n=1):
|
||||
self.x = min(self.width, self.x + n)
|
||||
|
||||
def cursorPosition(self, column, line):
|
||||
self.x = column
|
||||
self.y = line
|
||||
|
||||
def cursorHome(self):
|
||||
self.x = self.home.x
|
||||
self.y = self.home.y
|
||||
|
||||
def index(self):
|
||||
self._scrollDown()
|
||||
|
||||
def reverseIndex(self):
|
||||
self._scrollUp()
|
||||
|
||||
def nextLine(self):
|
||||
"""
|
||||
Update the cursor position attributes and scroll down if appropriate.
|
||||
"""
|
||||
self.x = 0
|
||||
self._scrollDown()
|
||||
|
||||
def saveCursor(self):
|
||||
self._savedCursor = (self.x, self.y)
|
||||
|
||||
def restoreCursor(self):
|
||||
self.x, self.y = self._savedCursor
|
||||
del self._savedCursor
|
||||
|
||||
def setModes(self, modes):
|
||||
for m in modes:
|
||||
self.modes[m] = True
|
||||
|
||||
def resetModes(self, modes):
|
||||
for m in modes:
|
||||
try:
|
||||
del self.modes[m]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def setPrivateModes(self, modes):
|
||||
"""
|
||||
Enable the given modes.
|
||||
|
||||
Track which modes have been enabled so that the implementations of
|
||||
other L{insults.ITerminalTransport} methods can be properly implemented
|
||||
to respect these settings.
|
||||
|
||||
@see: L{resetPrivateModes}
|
||||
@see: L{insults.ITerminalTransport.setPrivateModes}
|
||||
"""
|
||||
for m in modes:
|
||||
self.privateModes[m] = True
|
||||
|
||||
def resetPrivateModes(self, modes):
|
||||
"""
|
||||
Disable the given modes.
|
||||
|
||||
@see: L{setPrivateModes}
|
||||
@see: L{insults.ITerminalTransport.resetPrivateModes}
|
||||
"""
|
||||
for m in modes:
|
||||
try:
|
||||
del self.privateModes[m]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def applicationKeypadMode(self):
|
||||
self.keypadMode = "app"
|
||||
|
||||
def numericKeypadMode(self):
|
||||
self.keypadMode = "num"
|
||||
|
||||
def selectCharacterSet(self, charSet, which):
|
||||
self.charsets[which] = charSet
|
||||
|
||||
def shiftIn(self):
|
||||
self.activeCharset = insults.G0
|
||||
|
||||
def shiftOut(self):
|
||||
self.activeCharset = insults.G1
|
||||
|
||||
def singleShift2(self):
|
||||
oldActiveCharset = self.activeCharset
|
||||
self.activeCharset = insults.G2
|
||||
f = self.insertAtCursor
|
||||
|
||||
def insertAtCursor(b):
|
||||
f(b)
|
||||
del self.insertAtCursor
|
||||
self.activeCharset = oldActiveCharset
|
||||
|
||||
self.insertAtCursor = insertAtCursor
|
||||
|
||||
def singleShift3(self):
|
||||
oldActiveCharset = self.activeCharset
|
||||
self.activeCharset = insults.G3
|
||||
f = self.insertAtCursor
|
||||
|
||||
def insertAtCursor(b):
|
||||
f(b)
|
||||
del self.insertAtCursor
|
||||
self.activeCharset = oldActiveCharset
|
||||
|
||||
self.insertAtCursor = insertAtCursor
|
||||
|
||||
def selectGraphicRendition(self, *attributes):
|
||||
for a in attributes:
|
||||
if a == insults.NORMAL:
|
||||
self.graphicRendition = {
|
||||
"bold": False,
|
||||
"underline": False,
|
||||
"blink": False,
|
||||
"reverseVideo": False,
|
||||
"foreground": WHITE,
|
||||
"background": BLACK,
|
||||
}
|
||||
elif a == insults.BOLD:
|
||||
self.graphicRendition["bold"] = True
|
||||
elif a == insults.UNDERLINE:
|
||||
self.graphicRendition["underline"] = True
|
||||
elif a == insults.BLINK:
|
||||
self.graphicRendition["blink"] = True
|
||||
elif a == insults.REVERSE_VIDEO:
|
||||
self.graphicRendition["reverseVideo"] = True
|
||||
else:
|
||||
try:
|
||||
v = int(a)
|
||||
except ValueError:
|
||||
self._log.error(
|
||||
"Unknown graphic rendition attribute: {attr!r}", attr=a
|
||||
)
|
||||
else:
|
||||
if FOREGROUND <= v <= FOREGROUND + N_COLORS:
|
||||
self.graphicRendition["foreground"] = v - FOREGROUND
|
||||
elif BACKGROUND <= v <= BACKGROUND + N_COLORS:
|
||||
self.graphicRendition["background"] = v - BACKGROUND
|
||||
else:
|
||||
self._log.error(
|
||||
"Unknown graphic rendition attribute: {attr!r}", attr=a
|
||||
)
|
||||
|
||||
def eraseLine(self):
|
||||
self.lines[self.y] = self._emptyLine(self.width)
|
||||
|
||||
def eraseToLineEnd(self):
|
||||
width = self.width - self.x
|
||||
self.lines[self.y][self.x :] = self._emptyLine(width)
|
||||
|
||||
def eraseToLineBeginning(self):
|
||||
self.lines[self.y][: self.x + 1] = self._emptyLine(self.x + 1)
|
||||
|
||||
def eraseDisplay(self):
|
||||
self.lines = [self._emptyLine(self.width) for i in range(self.height)]
|
||||
|
||||
def eraseToDisplayEnd(self):
|
||||
self.eraseToLineEnd()
|
||||
height = self.height - self.y - 1
|
||||
self.lines[self.y + 1 :] = [self._emptyLine(self.width) for i in range(height)]
|
||||
|
||||
def eraseToDisplayBeginning(self):
|
||||
self.eraseToLineBeginning()
|
||||
self.lines[: self.y] = [self._emptyLine(self.width) for i in range(self.y)]
|
||||
|
||||
def deleteCharacter(self, n=1):
|
||||
del self.lines[self.y][self.x : self.x + n]
|
||||
self.lines[self.y].extend(self._emptyLine(min(self.width - self.x, n)))
|
||||
|
||||
def insertLine(self, n=1):
|
||||
self.lines[self.y : self.y] = [self._emptyLine(self.width) for i in range(n)]
|
||||
del self.lines[self.height :]
|
||||
|
||||
def deleteLine(self, n=1):
|
||||
del self.lines[self.y : self.y + n]
|
||||
self.lines.extend([self._emptyLine(self.width) for i in range(n)])
|
||||
|
||||
def reportCursorPosition(self):
|
||||
return (self.x, self.y)
|
||||
|
||||
def reset(self):
|
||||
self.home = insults.Vector(0, 0)
|
||||
self.x = self.y = 0
|
||||
self.modes = {}
|
||||
self.privateModes = {}
|
||||
self.setPrivateModes(
|
||||
[insults.privateModes.AUTO_WRAP, insults.privateModes.CURSOR_MODE]
|
||||
)
|
||||
self.numericKeypad = "app"
|
||||
self.activeCharset = insults.G0
|
||||
self.graphicRendition = {
|
||||
"bold": False,
|
||||
"underline": False,
|
||||
"blink": False,
|
||||
"reverseVideo": False,
|
||||
"foreground": WHITE,
|
||||
"background": BLACK,
|
||||
}
|
||||
self.charsets = {
|
||||
insults.G0: insults.CS_US,
|
||||
insults.G1: insults.CS_US,
|
||||
insults.G2: insults.CS_ALTERNATE,
|
||||
insults.G3: insults.CS_ALTERNATE_SPECIAL,
|
||||
}
|
||||
self.eraseDisplay()
|
||||
|
||||
def unhandledControlSequence(self, buf):
|
||||
print("Could not handle", repr(buf))
|
||||
|
||||
def __bytes__(self):
|
||||
lines = []
|
||||
for L in self.lines:
|
||||
buf = []
|
||||
length = 0
|
||||
for ch, attr in L:
|
||||
if ch is not self.void:
|
||||
buf.append(ch)
|
||||
length = len(buf)
|
||||
else:
|
||||
buf.append(self.fill)
|
||||
lines.append(b"".join(buf[:length]))
|
||||
return b"\n".join(lines)
|
||||
|
||||
def getHost(self):
|
||||
# ITransport.getHost
|
||||
raise NotImplementedError("Unimplemented: TerminalBuffer.getHost")
|
||||
|
||||
def getPeer(self):
|
||||
# ITransport.getPeer
|
||||
raise NotImplementedError("Unimplemented: TerminalBuffer.getPeer")
|
||||
|
||||
def loseConnection(self):
|
||||
# ITransport.loseConnection
|
||||
raise NotImplementedError("Unimplemented: TerminalBuffer.loseConnection")
|
||||
|
||||
def writeSequence(self, data):
|
||||
# ITransport.writeSequence
|
||||
raise NotImplementedError("Unimplemented: TerminalBuffer.writeSequence")
|
||||
|
||||
def horizontalTabulationSet(self):
|
||||
# ITerminalTransport.horizontalTabulationSet
|
||||
raise NotImplementedError(
|
||||
"Unimplemented: TerminalBuffer.horizontalTabulationSet"
|
||||
)
|
||||
|
||||
def tabulationClear(self):
|
||||
# TerminalTransport.tabulationClear
|
||||
raise NotImplementedError("Unimplemented: TerminalBuffer.tabulationClear")
|
||||
|
||||
def tabulationClearAll(self):
|
||||
# TerminalTransport.tabulationClearAll
|
||||
raise NotImplementedError("Unimplemented: TerminalBuffer.tabulationClearAll")
|
||||
|
||||
def doubleHeightLine(self, top=True):
|
||||
# ITerminalTransport.doubleHeightLine
|
||||
raise NotImplementedError("Unimplemented: TerminalBuffer.doubleHeightLine")
|
||||
|
||||
def singleWidthLine(self):
|
||||
# ITerminalTransport.singleWidthLine
|
||||
raise NotImplementedError("Unimplemented: TerminalBuffer.singleWidthLine")
|
||||
|
||||
def doubleWidthLine(self):
|
||||
# ITerminalTransport.doubleWidthLine
|
||||
raise NotImplementedError("Unimplemented: TerminalBuffer.doubleWidthLine")
|
||||
|
||||
|
||||
class ExpectationTimeout(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ExpectableBuffer(TerminalBuffer):
|
||||
_mark = 0
|
||||
|
||||
def connectionMade(self):
|
||||
TerminalBuffer.connectionMade(self)
|
||||
self._expecting = []
|
||||
|
||||
def write(self, data):
|
||||
TerminalBuffer.write(self, data)
|
||||
self._checkExpected()
|
||||
|
||||
def cursorHome(self):
|
||||
TerminalBuffer.cursorHome(self)
|
||||
self._mark = 0
|
||||
|
||||
def _timeoutExpected(self, d):
|
||||
d.errback(ExpectationTimeout())
|
||||
self._checkExpected()
|
||||
|
||||
def _checkExpected(self):
|
||||
s = self.__bytes__()[self._mark :]
|
||||
while self._expecting:
|
||||
expr, timer, deferred = self._expecting[0]
|
||||
if timer and not timer.active():
|
||||
del self._expecting[0]
|
||||
continue
|
||||
for match in expr.finditer(s):
|
||||
if timer:
|
||||
timer.cancel()
|
||||
del self._expecting[0]
|
||||
self._mark += match.end()
|
||||
s = s[match.end() :]
|
||||
deferred.callback(match)
|
||||
break
|
||||
else:
|
||||
return
|
||||
|
||||
def expect(self, expression, timeout=None, scheduler=reactor):
|
||||
d = defer.Deferred()
|
||||
timer = None
|
||||
if timeout:
|
||||
timer = scheduler.callLater(timeout, self._timeoutExpected, d)
|
||||
self._expecting.append((re.compile(expression), timer, d))
|
||||
self._checkExpected()
|
||||
return d
|
||||
|
||||
|
||||
__all__ = ["CharacterAttribute", "TerminalBuffer", "ExpectableBuffer"]
|
||||
1207
.venv/lib/python3.12/site-packages/twisted/conch/insults/insults.py
Normal file
1207
.venv/lib/python3.12/site-packages/twisted/conch/insults/insults.py
Normal file
File diff suppressed because it is too large
Load Diff
176
.venv/lib/python3.12/site-packages/twisted/conch/insults/text.py
Normal file
176
.venv/lib/python3.12/site-packages/twisted/conch/insults/text.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_text -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Character attribute manipulation API.
|
||||
|
||||
This module provides a domain-specific language (using Python syntax)
|
||||
for the creation of text with additional display attributes associated
|
||||
with it. It is intended as an alternative to manually building up
|
||||
strings containing ECMA 48 character attribute control codes. It
|
||||
currently supports foreground and background colors (black, red,
|
||||
green, yellow, blue, magenta, cyan, and white), intensity selection,
|
||||
underlining, blinking and reverse video. Character set selection
|
||||
support is planned.
|
||||
|
||||
Character attributes are specified by using two Python operations:
|
||||
attribute lookup and indexing. For example, the string \"Hello
|
||||
world\" with red foreground and all other attributes set to their
|
||||
defaults, assuming the name twisted.conch.insults.text.attributes has
|
||||
been imported and bound to the name \"A\" (with the statement C{from
|
||||
twisted.conch.insults.text import attributes as A}, for example) one
|
||||
uses this expression::
|
||||
|
||||
A.fg.red[\"Hello world\"]
|
||||
|
||||
Other foreground colors are set by substituting their name for
|
||||
\"red\". To set both a foreground and a background color, this
|
||||
expression is used::
|
||||
|
||||
A.fg.red[A.bg.green[\"Hello world\"]]
|
||||
|
||||
Note that either A.bg.green can be nested within A.fg.red or vice
|
||||
versa. Also note that multiple items can be nested within a single
|
||||
index operation by separating them with commas::
|
||||
|
||||
A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]
|
||||
|
||||
Other character attributes are set in a similar fashion. To specify a
|
||||
blinking version of the previous expression::
|
||||
|
||||
A.blink[A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]]
|
||||
|
||||
C{A.reverseVideo}, C{A.underline}, and C{A.bold} are also valid.
|
||||
|
||||
A third operation is actually supported: unary negation. This turns
|
||||
off an attribute when an enclosing expression would otherwise have
|
||||
caused it to be on. For example::
|
||||
|
||||
A.underline[A.fg.red[\"Hello\", -A.underline[\" world\"]]]
|
||||
|
||||
A formatting structure can then be serialized into a string containing the
|
||||
necessary VT102 control codes with L{assembleFormattedText}.
|
||||
|
||||
@see: L{twisted.conch.insults.text._CharacterAttributes}
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
from incremental import Version
|
||||
|
||||
from twisted.conch.insults import helper, insults
|
||||
from twisted.python import _textattributes
|
||||
from twisted.python.deprecate import deprecatedModuleAttribute
|
||||
|
||||
flatten = _textattributes.flatten
|
||||
|
||||
deprecatedModuleAttribute(
|
||||
Version("Twisted", 13, 1, 0),
|
||||
"Use twisted.conch.insults.text.assembleFormattedText instead.",
|
||||
"twisted.conch.insults.text",
|
||||
"flatten",
|
||||
)
|
||||
|
||||
_TEXT_COLORS = {
|
||||
"black": helper.BLACK,
|
||||
"red": helper.RED,
|
||||
"green": helper.GREEN,
|
||||
"yellow": helper.YELLOW,
|
||||
"blue": helper.BLUE,
|
||||
"magenta": helper.MAGENTA,
|
||||
"cyan": helper.CYAN,
|
||||
"white": helper.WHITE,
|
||||
}
|
||||
|
||||
|
||||
class _CharacterAttributes(_textattributes.CharacterAttributesMixin):
|
||||
"""
|
||||
Factory for character attributes, including foreground and background color
|
||||
and non-color attributes such as bold, reverse video and underline.
|
||||
|
||||
Character attributes are applied to actual text by using object
|
||||
indexing-syntax (C{obj['abc']}) after accessing a factory attribute, for
|
||||
example::
|
||||
|
||||
attributes.bold['Some text']
|
||||
|
||||
These can be nested to mix attributes::
|
||||
|
||||
attributes.bold[attributes.underline['Some text']]
|
||||
|
||||
And multiple values can be passed::
|
||||
|
||||
attributes.normal[attributes.bold['Some'], ' text']
|
||||
|
||||
Non-color attributes can be accessed by attribute name, available
|
||||
attributes are:
|
||||
|
||||
- bold
|
||||
- blink
|
||||
- reverseVideo
|
||||
- underline
|
||||
|
||||
Available colors are:
|
||||
|
||||
0. black
|
||||
1. red
|
||||
2. green
|
||||
3. yellow
|
||||
4. blue
|
||||
5. magenta
|
||||
6. cyan
|
||||
7. white
|
||||
|
||||
@ivar fg: Foreground colors accessed by attribute name, see above
|
||||
for possible names.
|
||||
|
||||
@ivar bg: Background colors accessed by attribute name, see above
|
||||
for possible names.
|
||||
"""
|
||||
|
||||
fg = _textattributes._ColorAttribute(
|
||||
_textattributes._ForegroundColorAttr, _TEXT_COLORS
|
||||
)
|
||||
bg = _textattributes._ColorAttribute(
|
||||
_textattributes._BackgroundColorAttr, _TEXT_COLORS
|
||||
)
|
||||
|
||||
attrs = {
|
||||
"bold": insults.BOLD,
|
||||
"blink": insults.BLINK,
|
||||
"underline": insults.UNDERLINE,
|
||||
"reverseVideo": insults.REVERSE_VIDEO,
|
||||
}
|
||||
|
||||
|
||||
def assembleFormattedText(formatted):
|
||||
"""
|
||||
Assemble formatted text from structured information.
|
||||
|
||||
Currently handled formatting includes: bold, blink, reverse, underline and
|
||||
color codes.
|
||||
|
||||
For example::
|
||||
|
||||
from twisted.conch.insults.text import attributes as A
|
||||
assembleFormattedText(
|
||||
A.normal[A.bold['Time: '], A.fg.lightRed['Now!']])
|
||||
|
||||
Would produce "Time: " in bold formatting, followed by "Now!" with a
|
||||
foreground color of light red and without any additional formatting.
|
||||
|
||||
@param formatted: Structured text and attributes.
|
||||
|
||||
@rtype: L{str}
|
||||
@return: String containing VT102 control sequences that mimic those
|
||||
specified by C{formatted}.
|
||||
|
||||
@see: L{twisted.conch.insults.text._CharacterAttributes}
|
||||
@since: 13.1
|
||||
"""
|
||||
return _textattributes.flatten(formatted, helper._FormattingState(), "toVT102")
|
||||
|
||||
|
||||
attributes = _CharacterAttributes()
|
||||
|
||||
__all__ = ["attributes", "flatten"]
|
||||
@@ -0,0 +1,936 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_window -*-
|
||||
|
||||
"""
|
||||
Simple insults-based widget library
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import array
|
||||
|
||||
from twisted.conch.insults import helper, insults
|
||||
from twisted.python import text as tptext
|
||||
|
||||
|
||||
class YieldFocus(Exception):
|
||||
"""
|
||||
Input focus manipulation exception
|
||||
"""
|
||||
|
||||
|
||||
class BoundedTerminalWrapper:
|
||||
def __init__(self, terminal, width, height, xoff, yoff):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.xoff = xoff
|
||||
self.yoff = yoff
|
||||
self.terminal = terminal
|
||||
self.cursorForward = terminal.cursorForward
|
||||
self.selectCharacterSet = terminal.selectCharacterSet
|
||||
self.selectGraphicRendition = terminal.selectGraphicRendition
|
||||
self.saveCursor = terminal.saveCursor
|
||||
self.restoreCursor = terminal.restoreCursor
|
||||
|
||||
def cursorPosition(self, x, y):
|
||||
return self.terminal.cursorPosition(
|
||||
self.xoff + min(self.width, x), self.yoff + min(self.height, y)
|
||||
)
|
||||
|
||||
def cursorHome(self):
|
||||
return self.terminal.cursorPosition(self.xoff, self.yoff)
|
||||
|
||||
def write(self, data):
|
||||
return self.terminal.write(data)
|
||||
|
||||
|
||||
class Widget:
|
||||
focused = False
|
||||
parent = None
|
||||
dirty = False
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
|
||||
def repaint(self):
|
||||
if not self.dirty:
|
||||
self.dirty = True
|
||||
if self.parent is not None and not self.parent.dirty:
|
||||
self.parent.repaint()
|
||||
|
||||
def filthy(self):
|
||||
self.dirty = True
|
||||
|
||||
def redraw(self, width, height, terminal):
|
||||
self.filthy()
|
||||
self.draw(width, height, terminal)
|
||||
|
||||
def draw(self, width, height, terminal):
|
||||
if width != self.width or height != self.height or self.dirty:
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.dirty = False
|
||||
self.render(width, height, terminal)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
pass
|
||||
|
||||
def sizeHint(self):
|
||||
return None
|
||||
|
||||
def keystrokeReceived(self, keyID, modifier):
|
||||
if keyID == b"\t":
|
||||
self.tabReceived(modifier)
|
||||
elif keyID == b"\x7f":
|
||||
self.backspaceReceived()
|
||||
elif keyID in insults.FUNCTION_KEYS:
|
||||
self.functionKeyReceived(keyID, modifier)
|
||||
else:
|
||||
self.characterReceived(keyID, modifier)
|
||||
|
||||
def tabReceived(self, modifier):
|
||||
# XXX TODO - Handle shift+tab
|
||||
raise YieldFocus()
|
||||
|
||||
def focusReceived(self):
|
||||
"""
|
||||
Called when focus is being given to this widget.
|
||||
|
||||
May raise YieldFocus is this widget does not want focus.
|
||||
"""
|
||||
self.focused = True
|
||||
self.repaint()
|
||||
|
||||
def focusLost(self):
|
||||
self.focused = False
|
||||
self.repaint()
|
||||
|
||||
def backspaceReceived(self):
|
||||
pass
|
||||
|
||||
def functionKeyReceived(self, keyID, modifier):
|
||||
name = keyID
|
||||
if not isinstance(keyID, str):
|
||||
name = name.decode("utf-8")
|
||||
|
||||
# Peel off the square brackets added by the computed definition of
|
||||
# twisted.conch.insults.insults.FUNCTION_KEYS.
|
||||
methodName = "func_" + name[1:-1]
|
||||
|
||||
func = getattr(self, methodName, None)
|
||||
if func is not None:
|
||||
func(modifier)
|
||||
|
||||
def characterReceived(self, keyID, modifier):
|
||||
pass
|
||||
|
||||
|
||||
class ContainerWidget(Widget):
|
||||
"""
|
||||
@ivar focusedChild: The contained widget which currently has
|
||||
focus, or None.
|
||||
"""
|
||||
|
||||
focusedChild = None
|
||||
focused = False
|
||||
|
||||
def __init__(self):
|
||||
Widget.__init__(self)
|
||||
self.children = []
|
||||
|
||||
def addChild(self, child):
|
||||
assert child.parent is None
|
||||
child.parent = self
|
||||
self.children.append(child)
|
||||
if self.focusedChild is None and self.focused:
|
||||
try:
|
||||
child.focusReceived()
|
||||
except YieldFocus:
|
||||
pass
|
||||
else:
|
||||
self.focusedChild = child
|
||||
self.repaint()
|
||||
|
||||
def remChild(self, child):
|
||||
assert child.parent is self
|
||||
child.parent = None
|
||||
self.children.remove(child)
|
||||
self.repaint()
|
||||
|
||||
def filthy(self):
|
||||
for ch in self.children:
|
||||
ch.filthy()
|
||||
Widget.filthy(self)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
for ch in self.children:
|
||||
ch.draw(width, height, terminal)
|
||||
|
||||
def changeFocus(self):
|
||||
self.repaint()
|
||||
|
||||
if self.focusedChild is not None:
|
||||
self.focusedChild.focusLost()
|
||||
focusedChild = self.focusedChild
|
||||
self.focusedChild = None
|
||||
try:
|
||||
curFocus = self.children.index(focusedChild) + 1
|
||||
except ValueError:
|
||||
raise YieldFocus()
|
||||
else:
|
||||
curFocus = 0
|
||||
while curFocus < len(self.children):
|
||||
try:
|
||||
self.children[curFocus].focusReceived()
|
||||
except YieldFocus:
|
||||
curFocus += 1
|
||||
else:
|
||||
self.focusedChild = self.children[curFocus]
|
||||
return
|
||||
# None of our children wanted focus
|
||||
raise YieldFocus()
|
||||
|
||||
def focusReceived(self):
|
||||
self.changeFocus()
|
||||
self.focused = True
|
||||
|
||||
def keystrokeReceived(self, keyID, modifier):
|
||||
if self.focusedChild is not None:
|
||||
try:
|
||||
self.focusedChild.keystrokeReceived(keyID, modifier)
|
||||
except YieldFocus:
|
||||
self.changeFocus()
|
||||
self.repaint()
|
||||
else:
|
||||
Widget.keystrokeReceived(self, keyID, modifier)
|
||||
|
||||
|
||||
class TopWindow(ContainerWidget):
|
||||
"""
|
||||
A top-level container object which provides focus wrap-around and paint
|
||||
scheduling.
|
||||
|
||||
@ivar painter: A no-argument callable which will be invoked when this
|
||||
widget needs to be redrawn.
|
||||
|
||||
@ivar scheduler: A one-argument callable which will be invoked with a
|
||||
no-argument callable and should arrange for it to invoked at some point in
|
||||
the near future. The no-argument callable will cause this widget and all
|
||||
its children to be redrawn. It is typically beneficial for the no-argument
|
||||
callable to be invoked at the end of handling for whatever event is
|
||||
currently active; for example, it might make sense to call it at the end of
|
||||
L{twisted.conch.insults.insults.ITerminalProtocol.keystrokeReceived}.
|
||||
Note, however, that since calls to this may also be made in response to no
|
||||
apparent event, arrangements should be made for the function to be called
|
||||
even if an event handler such as C{keystrokeReceived} is not on the call
|
||||
stack (eg, using
|
||||
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
|
||||
with a short timeout).
|
||||
"""
|
||||
|
||||
focused = True
|
||||
|
||||
def __init__(self, painter, scheduler):
|
||||
ContainerWidget.__init__(self)
|
||||
self.painter = painter
|
||||
self.scheduler = scheduler
|
||||
|
||||
_paintCall = None
|
||||
|
||||
def repaint(self):
|
||||
if self._paintCall is None:
|
||||
self._paintCall = object()
|
||||
self.scheduler(self._paint)
|
||||
ContainerWidget.repaint(self)
|
||||
|
||||
def _paint(self):
|
||||
self._paintCall = None
|
||||
self.painter()
|
||||
|
||||
def changeFocus(self):
|
||||
try:
|
||||
ContainerWidget.changeFocus(self)
|
||||
except YieldFocus:
|
||||
try:
|
||||
ContainerWidget.changeFocus(self)
|
||||
except YieldFocus:
|
||||
pass
|
||||
|
||||
def keystrokeReceived(self, keyID, modifier):
|
||||
try:
|
||||
ContainerWidget.keystrokeReceived(self, keyID, modifier)
|
||||
except YieldFocus:
|
||||
self.changeFocus()
|
||||
|
||||
|
||||
class AbsoluteBox(ContainerWidget):
|
||||
def moveChild(self, child, x, y):
|
||||
for n in range(len(self.children)):
|
||||
if self.children[n][0] is child:
|
||||
self.children[n] = (child, x, y)
|
||||
break
|
||||
else:
|
||||
raise ValueError("No such child", child)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
for ch, x, y in self.children:
|
||||
wrap = BoundedTerminalWrapper(terminal, width - x, height - y, x, y)
|
||||
ch.draw(width, height, wrap)
|
||||
|
||||
|
||||
class _Box(ContainerWidget):
|
||||
TOP, CENTER, BOTTOM = range(3)
|
||||
|
||||
def __init__(self, gravity=CENTER):
|
||||
ContainerWidget.__init__(self)
|
||||
self.gravity = gravity
|
||||
|
||||
def sizeHint(self):
|
||||
height = 0
|
||||
width = 0
|
||||
for ch in self.children:
|
||||
hint = ch.sizeHint()
|
||||
if hint is None:
|
||||
hint = (None, None)
|
||||
|
||||
if self.variableDimension == 0:
|
||||
if hint[0] is None:
|
||||
width = None
|
||||
elif width is not None:
|
||||
width += hint[0]
|
||||
if hint[1] is None:
|
||||
height = None
|
||||
elif height is not None:
|
||||
height = max(height, hint[1])
|
||||
else:
|
||||
if hint[0] is None:
|
||||
width = None
|
||||
elif width is not None:
|
||||
width = max(width, hint[0])
|
||||
if hint[1] is None:
|
||||
height = None
|
||||
elif height is not None:
|
||||
height += hint[1]
|
||||
|
||||
return width, height
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
if not self.children:
|
||||
return
|
||||
|
||||
greedy = 0
|
||||
wants = []
|
||||
for ch in self.children:
|
||||
hint = ch.sizeHint()
|
||||
if hint is None:
|
||||
hint = (None, None)
|
||||
if hint[self.variableDimension] is None:
|
||||
greedy += 1
|
||||
wants.append(hint[self.variableDimension])
|
||||
|
||||
length = (width, height)[self.variableDimension]
|
||||
totalWant = sum(w for w in wants if w is not None)
|
||||
if greedy:
|
||||
leftForGreedy = int((length - totalWant) / greedy)
|
||||
|
||||
widthOffset = heightOffset = 0
|
||||
|
||||
for want, ch in zip(wants, self.children):
|
||||
if want is None:
|
||||
want = leftForGreedy
|
||||
|
||||
subWidth, subHeight = width, height
|
||||
if self.variableDimension == 0:
|
||||
subWidth = want
|
||||
else:
|
||||
subHeight = want
|
||||
|
||||
wrap = BoundedTerminalWrapper(
|
||||
terminal,
|
||||
subWidth,
|
||||
subHeight,
|
||||
widthOffset,
|
||||
heightOffset,
|
||||
)
|
||||
ch.draw(subWidth, subHeight, wrap)
|
||||
if self.variableDimension == 0:
|
||||
widthOffset += want
|
||||
else:
|
||||
heightOffset += want
|
||||
|
||||
|
||||
class HBox(_Box):
|
||||
variableDimension = 0
|
||||
|
||||
|
||||
class VBox(_Box):
|
||||
variableDimension = 1
|
||||
|
||||
|
||||
class Packer(ContainerWidget):
|
||||
def render(self, width, height, terminal):
|
||||
if not self.children:
|
||||
return
|
||||
|
||||
root = int(len(self.children) ** 0.5 + 0.5)
|
||||
boxes = [VBox() for n in range(root)]
|
||||
for n, ch in enumerate(self.children):
|
||||
boxes[n % len(boxes)].addChild(ch)
|
||||
h = HBox()
|
||||
map(h.addChild, boxes)
|
||||
h.render(width, height, terminal)
|
||||
|
||||
|
||||
class Canvas(Widget):
|
||||
focused = False
|
||||
|
||||
contents = None
|
||||
|
||||
def __init__(self):
|
||||
Widget.__init__(self)
|
||||
self.resize(1, 1)
|
||||
|
||||
def resize(self, width, height):
|
||||
contents = array.array("B", b" " * width * height)
|
||||
if self.contents is not None:
|
||||
for x in range(min(width, self._width)):
|
||||
for y in range(min(height, self._height)):
|
||||
contents[width * y + x] = self[x, y]
|
||||
self.contents = contents
|
||||
self._width = width
|
||||
self._height = height
|
||||
if self.x >= width:
|
||||
self.x = width - 1
|
||||
if self.y >= height:
|
||||
self.y = height - 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
(x, y) = index
|
||||
return self.contents[(self._width * y) + x]
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
(x, y) = index
|
||||
self.contents[(self._width * y) + x] = value
|
||||
|
||||
def clear(self):
|
||||
self.contents = array.array("B", b" " * len(self.contents))
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
if not width or not height:
|
||||
return
|
||||
|
||||
if width != self._width or height != self._height:
|
||||
self.resize(width, height)
|
||||
for i in range(height):
|
||||
terminal.cursorPosition(0, i)
|
||||
text = self.contents[
|
||||
self._width * i : self._width * i + self._width
|
||||
].tobytes()
|
||||
text = text[:width]
|
||||
terminal.write(text)
|
||||
|
||||
|
||||
def horizontalLine(terminal, y, left, right):
|
||||
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
|
||||
terminal.cursorPosition(left, y)
|
||||
terminal.write(b"\161" * (right - left))
|
||||
terminal.selectCharacterSet(insults.CS_US, insults.G0)
|
||||
|
||||
|
||||
def verticalLine(terminal, x, top, bottom):
|
||||
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
|
||||
for n in range(top, bottom):
|
||||
terminal.cursorPosition(x, n)
|
||||
terminal.write(b"\170")
|
||||
terminal.selectCharacterSet(insults.CS_US, insults.G0)
|
||||
|
||||
|
||||
def rectangle(terminal, position, dimension):
|
||||
"""
|
||||
Draw a rectangle
|
||||
|
||||
@type position: L{tuple}
|
||||
@param position: A tuple of the (top, left) coordinates of the rectangle.
|
||||
@type dimension: L{tuple}
|
||||
@param dimension: A tuple of the (width, height) size of the rectangle.
|
||||
"""
|
||||
(top, left) = position
|
||||
(width, height) = dimension
|
||||
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
|
||||
|
||||
terminal.cursorPosition(top, left)
|
||||
terminal.write(b"\154")
|
||||
terminal.write(b"\161" * (width - 2))
|
||||
terminal.write(b"\153")
|
||||
for n in range(height - 2):
|
||||
terminal.cursorPosition(left, top + n + 1)
|
||||
terminal.write(b"\170")
|
||||
terminal.cursorForward(width - 2)
|
||||
terminal.write(b"\170")
|
||||
terminal.cursorPosition(0, top + height - 1)
|
||||
terminal.write(b"\155")
|
||||
terminal.write(b"\161" * (width - 2))
|
||||
terminal.write(b"\152")
|
||||
|
||||
terminal.selectCharacterSet(insults.CS_US, insults.G0)
|
||||
|
||||
|
||||
class Border(Widget):
|
||||
def __init__(self, containee):
|
||||
Widget.__init__(self)
|
||||
self.containee = containee
|
||||
self.containee.parent = self
|
||||
|
||||
def focusReceived(self):
|
||||
return self.containee.focusReceived()
|
||||
|
||||
def focusLost(self):
|
||||
return self.containee.focusLost()
|
||||
|
||||
def keystrokeReceived(self, keyID, modifier):
|
||||
return self.containee.keystrokeReceived(keyID, modifier)
|
||||
|
||||
def sizeHint(self):
|
||||
hint = self.containee.sizeHint()
|
||||
if hint is None:
|
||||
hint = (None, None)
|
||||
if hint[0] is None:
|
||||
x = None
|
||||
else:
|
||||
x = hint[0] + 2
|
||||
if hint[1] is None:
|
||||
y = None
|
||||
else:
|
||||
y = hint[1] + 2
|
||||
return x, y
|
||||
|
||||
def filthy(self):
|
||||
self.containee.filthy()
|
||||
Widget.filthy(self)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
if self.containee.focused:
|
||||
terminal.write(b"\x1b[31m")
|
||||
rectangle(terminal, (0, 0), (width, height))
|
||||
terminal.write(b"\x1b[0m")
|
||||
wrap = BoundedTerminalWrapper(terminal, width - 2, height - 2, 1, 1)
|
||||
self.containee.draw(width - 2, height - 2, wrap)
|
||||
|
||||
|
||||
class Button(Widget):
|
||||
def __init__(self, label, onPress):
|
||||
Widget.__init__(self)
|
||||
self.label = label
|
||||
self.onPress = onPress
|
||||
|
||||
def sizeHint(self):
|
||||
return len(self.label), 1
|
||||
|
||||
def characterReceived(self, keyID, modifier):
|
||||
if keyID == b"\r":
|
||||
self.onPress()
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
terminal.cursorPosition(0, 0)
|
||||
if self.focused:
|
||||
terminal.write(b"\x1b[1m" + self.label + b"\x1b[0m")
|
||||
else:
|
||||
terminal.write(self.label)
|
||||
|
||||
|
||||
class TextInput(Widget):
|
||||
def __init__(self, maxwidth, onSubmit):
|
||||
Widget.__init__(self)
|
||||
self.onSubmit = onSubmit
|
||||
self.maxwidth = maxwidth
|
||||
self.buffer = b""
|
||||
self.cursor = 0
|
||||
|
||||
def setText(self, text):
|
||||
self.buffer = text[: self.maxwidth]
|
||||
self.cursor = len(self.buffer)
|
||||
self.repaint()
|
||||
|
||||
def func_LEFT_ARROW(self, modifier):
|
||||
if self.cursor > 0:
|
||||
self.cursor -= 1
|
||||
self.repaint()
|
||||
|
||||
def func_RIGHT_ARROW(self, modifier):
|
||||
if self.cursor < len(self.buffer):
|
||||
self.cursor += 1
|
||||
self.repaint()
|
||||
|
||||
def backspaceReceived(self):
|
||||
if self.cursor > 0:
|
||||
self.buffer = self.buffer[: self.cursor - 1] + self.buffer[self.cursor :]
|
||||
self.cursor -= 1
|
||||
self.repaint()
|
||||
|
||||
def characterReceived(self, keyID, modifier):
|
||||
if keyID == b"\r":
|
||||
self.onSubmit(self.buffer)
|
||||
else:
|
||||
if len(self.buffer) < self.maxwidth:
|
||||
self.buffer = (
|
||||
self.buffer[: self.cursor] + keyID + self.buffer[self.cursor :]
|
||||
)
|
||||
self.cursor += 1
|
||||
self.repaint()
|
||||
|
||||
def sizeHint(self):
|
||||
return self.maxwidth + 1, 1
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
currentText = self._renderText()
|
||||
terminal.cursorPosition(0, 0)
|
||||
if self.focused:
|
||||
terminal.write(currentText[: self.cursor])
|
||||
cursor(terminal, currentText[self.cursor : self.cursor + 1] or b" ")
|
||||
terminal.write(currentText[self.cursor + 1 :])
|
||||
terminal.write(b" " * (self.maxwidth - len(currentText) + 1))
|
||||
else:
|
||||
more = self.maxwidth - len(currentText)
|
||||
terminal.write(currentText + b"_" * more)
|
||||
|
||||
def _renderText(self):
|
||||
return self.buffer
|
||||
|
||||
|
||||
class PasswordInput(TextInput):
|
||||
def _renderText(self):
|
||||
return "*" * len(self.buffer)
|
||||
|
||||
|
||||
class TextOutput(Widget):
|
||||
text = b""
|
||||
|
||||
def __init__(self, size=None):
|
||||
Widget.__init__(self)
|
||||
self.size = size
|
||||
|
||||
def sizeHint(self):
|
||||
return self.size
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
terminal.cursorPosition(0, 0)
|
||||
text = self.text[:width]
|
||||
terminal.write(text + b" " * (width - len(text)))
|
||||
|
||||
def setText(self, text):
|
||||
self.text = text
|
||||
self.repaint()
|
||||
|
||||
def focusReceived(self):
|
||||
raise YieldFocus()
|
||||
|
||||
|
||||
class TextOutputArea(TextOutput):
|
||||
WRAP, TRUNCATE = range(2)
|
||||
|
||||
def __init__(self, size=None, longLines=WRAP):
|
||||
TextOutput.__init__(self, size)
|
||||
self.longLines = longLines
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
n = 0
|
||||
inputLines = self.text.splitlines()
|
||||
outputLines = []
|
||||
while inputLines:
|
||||
if self.longLines == self.WRAP:
|
||||
line = inputLines.pop(0)
|
||||
if not isinstance(line, str):
|
||||
line = line.decode("utf-8")
|
||||
wrappedLines = []
|
||||
for wrappedLine in tptext.greedyWrap(line, width):
|
||||
if not isinstance(wrappedLine, bytes):
|
||||
wrappedLine = wrappedLine.encode("utf-8")
|
||||
wrappedLines.append(wrappedLine)
|
||||
outputLines.extend(wrappedLines or [b""])
|
||||
else:
|
||||
outputLines.append(inputLines.pop(0)[:width])
|
||||
if len(outputLines) >= height:
|
||||
break
|
||||
for n, L in enumerate(outputLines[:height]):
|
||||
terminal.cursorPosition(0, n)
|
||||
terminal.write(L)
|
||||
|
||||
|
||||
class Viewport(Widget):
|
||||
_xOffset = 0
|
||||
_yOffset = 0
|
||||
|
||||
@property
|
||||
def xOffset(self):
|
||||
return self._xOffset
|
||||
|
||||
@xOffset.setter
|
||||
def xOffset(self, value):
|
||||
if self._xOffset != value:
|
||||
self._xOffset = value
|
||||
self.repaint()
|
||||
|
||||
@property
|
||||
def yOffset(self):
|
||||
return self._yOffset
|
||||
|
||||
@yOffset.setter
|
||||
def yOffset(self, value):
|
||||
if self._yOffset != value:
|
||||
self._yOffset = value
|
||||
self.repaint()
|
||||
|
||||
_width = 160
|
||||
_height = 24
|
||||
|
||||
def __init__(self, containee):
|
||||
Widget.__init__(self)
|
||||
self.containee = containee
|
||||
self.containee.parent = self
|
||||
|
||||
self._buf = helper.TerminalBuffer()
|
||||
self._buf.width = self._width
|
||||
self._buf.height = self._height
|
||||
self._buf.connectionMade()
|
||||
|
||||
def filthy(self):
|
||||
self.containee.filthy()
|
||||
Widget.filthy(self)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
self.containee.draw(self._width, self._height, self._buf)
|
||||
|
||||
# XXX /Lame/
|
||||
for y, line in enumerate(
|
||||
self._buf.lines[self._yOffset : self._yOffset + height]
|
||||
):
|
||||
terminal.cursorPosition(0, y)
|
||||
n = 0
|
||||
for n, (ch, attr) in enumerate(line[self._xOffset : self._xOffset + width]):
|
||||
if ch is self._buf.void:
|
||||
ch = b" "
|
||||
terminal.write(ch)
|
||||
if n < width:
|
||||
terminal.write(b" " * (width - n - 1))
|
||||
|
||||
|
||||
class _Scrollbar(Widget):
|
||||
def __init__(self, onScroll):
|
||||
Widget.__init__(self)
|
||||
self.onScroll = onScroll
|
||||
self.percent = 0.0
|
||||
|
||||
def smaller(self):
|
||||
self.percent = min(1.0, max(0.0, self.onScroll(-1)))
|
||||
self.repaint()
|
||||
|
||||
def bigger(self):
|
||||
self.percent = min(1.0, max(0.0, self.onScroll(+1)))
|
||||
self.repaint()
|
||||
|
||||
|
||||
class HorizontalScrollbar(_Scrollbar):
|
||||
def sizeHint(self):
|
||||
return (None, 1)
|
||||
|
||||
def func_LEFT_ARROW(self, modifier):
|
||||
self.smaller()
|
||||
|
||||
def func_RIGHT_ARROW(self, modifier):
|
||||
self.bigger()
|
||||
|
||||
_left = "\N{BLACK LEFT-POINTING TRIANGLE}"
|
||||
_right = "\N{BLACK RIGHT-POINTING TRIANGLE}"
|
||||
_bar = "\N{LIGHT SHADE}"
|
||||
_slider = "\N{DARK SHADE}"
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
terminal.cursorPosition(0, 0)
|
||||
n = width - 3
|
||||
before = int(n * self.percent)
|
||||
after = n - before
|
||||
me = (
|
||||
self._left
|
||||
+ (self._bar * before)
|
||||
+ self._slider
|
||||
+ (self._bar * after)
|
||||
+ self._right
|
||||
)
|
||||
terminal.write(me.encode("utf-8"))
|
||||
|
||||
|
||||
class VerticalScrollbar(_Scrollbar):
|
||||
def sizeHint(self):
|
||||
return (1, None)
|
||||
|
||||
def func_UP_ARROW(self, modifier):
|
||||
self.smaller()
|
||||
|
||||
def func_DOWN_ARROW(self, modifier):
|
||||
self.bigger()
|
||||
|
||||
_up = "\N{BLACK UP-POINTING TRIANGLE}"
|
||||
_down = "\N{BLACK DOWN-POINTING TRIANGLE}"
|
||||
_bar = "\N{LIGHT SHADE}"
|
||||
_slider = "\N{DARK SHADE}"
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
terminal.cursorPosition(0, 0)
|
||||
knob = int(self.percent * (height - 2))
|
||||
terminal.write(self._up.encode("utf-8"))
|
||||
for i in range(1, height - 1):
|
||||
terminal.cursorPosition(0, i)
|
||||
if i != (knob + 1):
|
||||
terminal.write(self._bar.encode("utf-8"))
|
||||
else:
|
||||
terminal.write(self._slider.encode("utf-8"))
|
||||
terminal.cursorPosition(0, height - 1)
|
||||
terminal.write(self._down.encode("utf-8"))
|
||||
|
||||
|
||||
class ScrolledArea(Widget):
|
||||
"""
|
||||
A L{ScrolledArea} contains another widget wrapped in a viewport and
|
||||
vertical and horizontal scrollbars for moving the viewport around.
|
||||
"""
|
||||
|
||||
def __init__(self, containee):
|
||||
Widget.__init__(self)
|
||||
self._viewport = Viewport(containee)
|
||||
self._horiz = HorizontalScrollbar(self._horizScroll)
|
||||
self._vert = VerticalScrollbar(self._vertScroll)
|
||||
|
||||
for w in self._viewport, self._horiz, self._vert:
|
||||
w.parent = self
|
||||
|
||||
def _horizScroll(self, n):
|
||||
self._viewport.xOffset += n
|
||||
self._viewport.xOffset = max(0, self._viewport.xOffset)
|
||||
return self._viewport.xOffset / 25.0
|
||||
|
||||
def _vertScroll(self, n):
|
||||
self._viewport.yOffset += n
|
||||
self._viewport.yOffset = max(0, self._viewport.yOffset)
|
||||
return self._viewport.yOffset / 25.0
|
||||
|
||||
def func_UP_ARROW(self, modifier):
|
||||
self._vert.smaller()
|
||||
|
||||
def func_DOWN_ARROW(self, modifier):
|
||||
self._vert.bigger()
|
||||
|
||||
def func_LEFT_ARROW(self, modifier):
|
||||
self._horiz.smaller()
|
||||
|
||||
def func_RIGHT_ARROW(self, modifier):
|
||||
self._horiz.bigger()
|
||||
|
||||
def filthy(self):
|
||||
self._viewport.filthy()
|
||||
self._horiz.filthy()
|
||||
self._vert.filthy()
|
||||
Widget.filthy(self)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
wrapper = BoundedTerminalWrapper(terminal, width - 2, height - 2, 1, 1)
|
||||
self._viewport.draw(width - 2, height - 2, wrapper)
|
||||
if self.focused:
|
||||
terminal.write(b"\x1b[31m")
|
||||
horizontalLine(terminal, 0, 1, width - 1)
|
||||
verticalLine(terminal, 0, 1, height - 1)
|
||||
self._vert.draw(
|
||||
1, height - 1, BoundedTerminalWrapper(terminal, 1, height - 1, width - 1, 0)
|
||||
)
|
||||
self._horiz.draw(
|
||||
width, 1, BoundedTerminalWrapper(terminal, width, 1, 0, height - 1)
|
||||
)
|
||||
terminal.write(b"\x1b[0m")
|
||||
|
||||
|
||||
def cursor(terminal, ch):
|
||||
terminal.saveCursor()
|
||||
terminal.selectGraphicRendition(str(insults.REVERSE_VIDEO))
|
||||
terminal.write(ch)
|
||||
terminal.restoreCursor()
|
||||
terminal.cursorForward()
|
||||
|
||||
|
||||
class Selection(Widget):
|
||||
# Index into the sequence
|
||||
focusedIndex = 0
|
||||
|
||||
# Offset into the displayed subset of the sequence
|
||||
renderOffset = 0
|
||||
|
||||
def __init__(self, sequence, onSelect, minVisible=None):
|
||||
Widget.__init__(self)
|
||||
self.sequence = sequence
|
||||
self.onSelect = onSelect
|
||||
self.minVisible = minVisible
|
||||
if minVisible is not None:
|
||||
self._width = max(map(len, self.sequence))
|
||||
|
||||
def sizeHint(self):
|
||||
if self.minVisible is not None:
|
||||
return self._width, self.minVisible
|
||||
|
||||
def func_UP_ARROW(self, modifier):
|
||||
if self.focusedIndex > 0:
|
||||
self.focusedIndex -= 1
|
||||
if self.renderOffset > 0:
|
||||
self.renderOffset -= 1
|
||||
self.repaint()
|
||||
|
||||
def func_PGUP(self, modifier):
|
||||
if self.renderOffset != 0:
|
||||
self.focusedIndex -= self.renderOffset
|
||||
self.renderOffset = 0
|
||||
else:
|
||||
self.focusedIndex = max(0, self.focusedIndex - self.height)
|
||||
self.repaint()
|
||||
|
||||
def func_DOWN_ARROW(self, modifier):
|
||||
if self.focusedIndex < len(self.sequence) - 1:
|
||||
self.focusedIndex += 1
|
||||
if self.renderOffset < self.height - 1:
|
||||
self.renderOffset += 1
|
||||
self.repaint()
|
||||
|
||||
def func_PGDN(self, modifier):
|
||||
if self.renderOffset != self.height - 1:
|
||||
change = self.height - self.renderOffset - 1
|
||||
if change + self.focusedIndex >= len(self.sequence):
|
||||
change = len(self.sequence) - self.focusedIndex - 1
|
||||
self.focusedIndex += change
|
||||
self.renderOffset = self.height - 1
|
||||
else:
|
||||
self.focusedIndex = min(
|
||||
len(self.sequence) - 1, self.focusedIndex + self.height
|
||||
)
|
||||
self.repaint()
|
||||
|
||||
def characterReceived(self, keyID, modifier):
|
||||
if keyID == b"\r":
|
||||
self.onSelect(self.sequence[self.focusedIndex])
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
self.height = height
|
||||
start = self.focusedIndex - self.renderOffset
|
||||
if start > len(self.sequence) - height:
|
||||
start = max(0, len(self.sequence) - height)
|
||||
|
||||
elements = self.sequence[start : start + height]
|
||||
|
||||
for n, ele in enumerate(elements):
|
||||
terminal.cursorPosition(0, n)
|
||||
if n == self.renderOffset:
|
||||
terminal.saveCursor()
|
||||
if self.focused:
|
||||
modes = str(insults.REVERSE_VIDEO), str(insults.BOLD)
|
||||
else:
|
||||
modes = (str(insults.REVERSE_VIDEO),)
|
||||
terminal.selectGraphicRendition(*modes)
|
||||
text = ele[:width]
|
||||
terminal.write(text + (b" " * (width - len(text))))
|
||||
if n == self.renderOffset:
|
||||
terminal.restoreCursor()
|
||||
456
.venv/lib/python3.12/site-packages/twisted/conch/interfaces.py
Normal file
456
.venv/lib/python3.12/site-packages/twisted/conch/interfaces.py
Normal file
@@ -0,0 +1,456 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This module contains interfaces defined for the L{twisted.conch} package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from zope.interface import Attribute, Interface
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from twisted.conch.ssh.keys import Key
|
||||
|
||||
|
||||
class IConchUser(Interface):
|
||||
"""
|
||||
A user who has been authenticated to Cred through Conch. This is
|
||||
the interface between the SSH connection and the user.
|
||||
"""
|
||||
|
||||
conn = Attribute("The SSHConnection object for this user.")
|
||||
|
||||
def lookupChannel(channelType, windowSize, maxPacket, data):
|
||||
"""
|
||||
The other side requested a channel of some sort.
|
||||
|
||||
C{channelType} is the type of channel being requested,
|
||||
as an ssh connection protocol channel type.
|
||||
C{data} is any other packet data (often nothing).
|
||||
|
||||
We return a subclass of L{SSHChannel<ssh.channel.SSHChannel>}. If
|
||||
the channel type is unknown, we return C{None}.
|
||||
|
||||
For other failures, we raise an exception. If a
|
||||
L{ConchError<error.ConchError>} is raised, the C{.value} will
|
||||
be the message, and the C{.data} will be the error code.
|
||||
|
||||
@param channelType: The requested channel type
|
||||
@type channelType: L{bytes}
|
||||
@param windowSize: The initial size of the remote window
|
||||
@type windowSize: L{int}
|
||||
@param maxPacket: The largest packet we should send
|
||||
@type maxPacket: L{int}
|
||||
@param data: Additional request data
|
||||
@type data: L{bytes}
|
||||
@rtype: a subclass of L{SSHChannel} or L{None}
|
||||
"""
|
||||
|
||||
def lookupSubsystem(subsystem, data):
|
||||
"""
|
||||
The other side requested a subsystem.
|
||||
|
||||
We return a L{Protocol} implementing the requested subsystem.
|
||||
If the subsystem is not available, we return C{None}.
|
||||
|
||||
@param subsystem: The name of the subsystem being requested
|
||||
@type subsystem: L{bytes}
|
||||
@param data: Additional request data (often nothing)
|
||||
@type data: L{bytes}
|
||||
@rtype: L{Protocol} or L{None}
|
||||
"""
|
||||
|
||||
def gotGlobalRequest(requestType, data):
|
||||
"""
|
||||
A global request was sent from the other side.
|
||||
|
||||
We return a true value on success or a false value on failure.
|
||||
If we indicate success by returning a tuple, its second item
|
||||
will be sent to the other side as additional response data.
|
||||
|
||||
@param requestType: The type of the request
|
||||
@type requestType: L{bytes}
|
||||
@param data: Additional request data
|
||||
@type data: L{bytes}
|
||||
@rtype: boolean or L{tuple}
|
||||
"""
|
||||
|
||||
|
||||
class ISession(Interface):
|
||||
def getPty(term, windowSize, modes):
|
||||
"""
|
||||
Get a pseudo-terminal for use by a shell or command.
|
||||
|
||||
If a pseudo-terminal is not available, or the request otherwise
|
||||
fails, raise an exception.
|
||||
"""
|
||||
|
||||
def openShell(proto):
|
||||
"""
|
||||
Open a shell and connect it to proto.
|
||||
|
||||
@param proto: a L{ProcessProtocol} instance.
|
||||
"""
|
||||
|
||||
def execCommand(proto, command):
|
||||
"""
|
||||
Execute a command.
|
||||
|
||||
@param proto: a L{ProcessProtocol} instance.
|
||||
"""
|
||||
|
||||
def windowChanged(newWindowSize):
|
||||
"""
|
||||
Called when the size of the remote screen has changed.
|
||||
"""
|
||||
|
||||
def eofReceived():
|
||||
"""
|
||||
Called when the other side has indicated no more data will be sent.
|
||||
"""
|
||||
|
||||
def closed():
|
||||
"""
|
||||
Called when the session is closed.
|
||||
"""
|
||||
|
||||
|
||||
class EnvironmentVariableNotPermitted(ValueError):
|
||||
"""Setting this environment variable in this session is not permitted."""
|
||||
|
||||
|
||||
class ISessionSetEnv(Interface):
|
||||
"""A session that can set environment variables."""
|
||||
|
||||
def setEnv(name, value):
|
||||
"""
|
||||
Set an environment variable for the shell or command to be started.
|
||||
|
||||
From U{RFC 4254, section 6.4
|
||||
<https://tools.ietf.org/html/rfc4254#section-6.4>}: "Uncontrolled
|
||||
setting of environment variables in a privileged process can be a
|
||||
security hazard. It is recommended that implementations either
|
||||
maintain a list of allowable variable names or only set environment
|
||||
variables after the server process has dropped sufficient
|
||||
privileges."
|
||||
|
||||
(OpenSSH refuses all environment variables by default, but has an
|
||||
C{AcceptEnv} configuration option to select specific variables to
|
||||
accept.)
|
||||
|
||||
@param name: The name of the environment variable to set.
|
||||
@type name: L{bytes}
|
||||
@param value: The value of the environment variable to set.
|
||||
@type value: L{bytes}
|
||||
@raise EnvironmentVariableNotPermitted: if setting this environment
|
||||
variable is not permitted.
|
||||
"""
|
||||
|
||||
|
||||
class ISFTPServer(Interface):
|
||||
"""
|
||||
SFTP subsystem for server-side communication.
|
||||
|
||||
Each method should check to verify that the user has permission for
|
||||
their actions.
|
||||
"""
|
||||
|
||||
avatar = Attribute(
|
||||
"""
|
||||
The avatar returned by the Realm that we are authenticated with,
|
||||
and represents the logged-in user.
|
||||
"""
|
||||
)
|
||||
|
||||
def gotVersion(otherVersion, extData):
|
||||
"""
|
||||
Called when the client sends their version info.
|
||||
|
||||
otherVersion is an integer representing the version of the SFTP
|
||||
protocol they are claiming.
|
||||
extData is a dictionary of extended_name : extended_data items.
|
||||
These items are sent by the client to indicate additional features.
|
||||
|
||||
This method should return a dictionary of extended_name : extended_data
|
||||
items. These items are the additional features (if any) supported
|
||||
by the server.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def openFile(filename, flags, attrs):
|
||||
"""
|
||||
Called when the clients asks to open a file.
|
||||
|
||||
@param filename: a string representing the file to open.
|
||||
|
||||
@param flags: an integer of the flags to open the file with, ORed
|
||||
together. The flags and their values are listed at the bottom of
|
||||
L{twisted.conch.ssh.filetransfer} as FXF_*.
|
||||
|
||||
@param attrs: a list of attributes to open the file with. It is a
|
||||
dictionary, consisting of 0 or more keys. The possible keys are::
|
||||
|
||||
size: the size of the file in bytes
|
||||
uid: the user ID of the file as an integer
|
||||
gid: the group ID of the file as an integer
|
||||
permissions: the permissions of the file with as an integer.
|
||||
the bit representation of this field is defined by POSIX.
|
||||
atime: the access time of the file as seconds since the epoch.
|
||||
mtime: the modification time of the file as seconds since the epoch.
|
||||
ext_*: extended attributes. The server is not required to
|
||||
understand this, but it may.
|
||||
|
||||
NOTE: there is no way to indicate text or binary files. it is up
|
||||
to the SFTP client to deal with this.
|
||||
|
||||
This method returns an object that meets the ISFTPFile interface.
|
||||
Alternatively, it can return a L{Deferred} that will be called back
|
||||
with the object.
|
||||
"""
|
||||
|
||||
def removeFile(filename):
|
||||
"""
|
||||
Remove the given file.
|
||||
|
||||
This method returns when the remove succeeds, or a Deferred that is
|
||||
called back when it succeeds.
|
||||
|
||||
@param filename: the name of the file as a string.
|
||||
"""
|
||||
|
||||
def renameFile(oldpath, newpath):
|
||||
"""
|
||||
Rename the given file.
|
||||
|
||||
This method returns when the rename succeeds, or a L{Deferred} that is
|
||||
called back when it succeeds. If the rename fails, C{renameFile} will
|
||||
raise an implementation-dependent exception.
|
||||
|
||||
@param oldpath: the current location of the file.
|
||||
@param newpath: the new file name.
|
||||
"""
|
||||
|
||||
def makeDirectory(path, attrs):
|
||||
"""
|
||||
Make a directory.
|
||||
|
||||
This method returns when the directory is created, or a Deferred that
|
||||
is called back when it is created.
|
||||
|
||||
@param path: the name of the directory to create as a string.
|
||||
@param attrs: a dictionary of attributes to create the directory with.
|
||||
Its meaning is the same as the attrs in the L{openFile} method.
|
||||
"""
|
||||
|
||||
def removeDirectory(path):
|
||||
"""
|
||||
Remove a directory (non-recursively)
|
||||
|
||||
It is an error to remove a directory that has files or directories in
|
||||
it.
|
||||
|
||||
This method returns when the directory is removed, or a Deferred that
|
||||
is called back when it is removed.
|
||||
|
||||
@param path: the directory to remove.
|
||||
"""
|
||||
|
||||
def openDirectory(path):
|
||||
"""
|
||||
Open a directory for scanning.
|
||||
|
||||
This method returns an iterable object that has a close() method,
|
||||
or a Deferred that is called back with same.
|
||||
|
||||
The close() method is called when the client is finished reading
|
||||
from the directory. At this point, the iterable will no longer
|
||||
be used.
|
||||
|
||||
The iterable should return triples of the form (filename,
|
||||
longname, attrs) or Deferreds that return the same. The
|
||||
sequence must support __getitem__, but otherwise may be any
|
||||
'sequence-like' object.
|
||||
|
||||
filename is the name of the file relative to the directory.
|
||||
logname is an expanded format of the filename. The recommended format
|
||||
is:
|
||||
-rwxr-xr-x 1 mjos staff 348911 Mar 25 14:29 t-filexfer
|
||||
1234567890 123 12345678 12345678 12345678 123456789012
|
||||
|
||||
The first line is sample output, the second is the length of the field.
|
||||
The fields are: permissions, link count, user owner, group owner,
|
||||
size in bytes, modification time.
|
||||
|
||||
attrs is a dictionary in the format of the attrs argument to openFile.
|
||||
|
||||
@param path: the directory to open.
|
||||
"""
|
||||
|
||||
def getAttrs(path, followLinks):
|
||||
"""
|
||||
Return the attributes for the given path.
|
||||
|
||||
This method returns a dictionary in the same format as the attrs
|
||||
argument to openFile or a Deferred that is called back with same.
|
||||
|
||||
@param path: the path to return attributes for as a string.
|
||||
@param followLinks: a boolean. If it is True, follow symbolic links
|
||||
and return attributes for the real path at the base. If it is False,
|
||||
return attributes for the specified path.
|
||||
"""
|
||||
|
||||
def setAttrs(path, attrs):
|
||||
"""
|
||||
Set the attributes for the path.
|
||||
|
||||
This method returns when the attributes are set or a Deferred that is
|
||||
called back when they are.
|
||||
|
||||
@param path: the path to set attributes for as a string.
|
||||
@param attrs: a dictionary in the same format as the attrs argument to
|
||||
L{openFile}.
|
||||
"""
|
||||
|
||||
def readLink(path):
|
||||
"""
|
||||
Find the root of a set of symbolic links.
|
||||
|
||||
This method returns the target of the link, or a Deferred that
|
||||
returns the same.
|
||||
|
||||
@param path: the path of the symlink to read.
|
||||
"""
|
||||
|
||||
def makeLink(linkPath, targetPath):
|
||||
"""
|
||||
Create a symbolic link.
|
||||
|
||||
This method returns when the link is made, or a Deferred that
|
||||
returns the same.
|
||||
|
||||
@param linkPath: the pathname of the symlink as a string.
|
||||
@param targetPath: the path of the target of the link as a string.
|
||||
"""
|
||||
|
||||
def realPath(path):
|
||||
"""
|
||||
Convert any path to an absolute path.
|
||||
|
||||
This method returns the absolute path as a string, or a Deferred
|
||||
that returns the same.
|
||||
|
||||
@param path: the path to convert as a string.
|
||||
"""
|
||||
|
||||
def extendedRequest(extendedName, extendedData):
|
||||
"""
|
||||
This is the extension mechanism for SFTP. The other side can send us
|
||||
arbitrary requests.
|
||||
|
||||
If we don't implement the request given by extendedName, raise
|
||||
NotImplementedError.
|
||||
|
||||
The return value is a string, or a Deferred that will be called
|
||||
back with a string.
|
||||
|
||||
@param extendedName: the name of the request as a string.
|
||||
@param extendedData: the data the other side sent with the request,
|
||||
as a string.
|
||||
"""
|
||||
|
||||
|
||||
class IKnownHostEntry(Interface):
|
||||
"""
|
||||
A L{IKnownHostEntry} is an entry in an OpenSSH-formatted C{known_hosts}
|
||||
file.
|
||||
|
||||
@since: 8.2
|
||||
"""
|
||||
|
||||
def matchesKey(key: Key) -> bool:
|
||||
"""
|
||||
Return True if this entry matches the given Key object, False
|
||||
otherwise.
|
||||
|
||||
@param key: The key object to match against.
|
||||
"""
|
||||
|
||||
def matchesHost(hostname: bytes) -> bool:
|
||||
"""
|
||||
Return True if this entry matches the given hostname, False otherwise.
|
||||
|
||||
Note that this does no name resolution; if you want to match an IP
|
||||
address, you have to resolve it yourself, and pass it in as a dotted
|
||||
quad string.
|
||||
|
||||
@param hostname: The hostname to match against.
|
||||
"""
|
||||
|
||||
def toString() -> bytes:
|
||||
"""
|
||||
@return: a serialized string representation of this entry, suitable for
|
||||
inclusion in a known_hosts file. (Newline not included.)
|
||||
"""
|
||||
|
||||
|
||||
class ISFTPFile(Interface):
|
||||
"""
|
||||
This represents an open file on the server. An object adhering to this
|
||||
interface should be returned from L{openFile}().
|
||||
"""
|
||||
|
||||
def close():
|
||||
"""
|
||||
Close the file.
|
||||
|
||||
This method returns nothing if the close succeeds immediately, or a
|
||||
Deferred that is called back when the close succeeds.
|
||||
"""
|
||||
|
||||
def readChunk(offset, length):
|
||||
"""
|
||||
Read from the file.
|
||||
|
||||
If EOF is reached before any data is read, raise EOFError.
|
||||
|
||||
This method returns the data as a string, or a Deferred that is
|
||||
called back with same.
|
||||
|
||||
@param offset: an integer that is the index to start from in the file.
|
||||
@param length: the maximum length of data to return. The actual amount
|
||||
returned may less than this. For normal disk files, however,
|
||||
this should read the requested number (up to the end of the file).
|
||||
"""
|
||||
|
||||
def writeChunk(offset, data):
|
||||
"""
|
||||
Write to the file.
|
||||
|
||||
This method returns when the write completes, or a Deferred that is
|
||||
called when it completes.
|
||||
|
||||
@param offset: an integer that is the index to start from in the file.
|
||||
@param data: a string that is the data to write.
|
||||
"""
|
||||
|
||||
def getAttrs():
|
||||
"""
|
||||
Return the attributes for the file.
|
||||
|
||||
This method returns a dictionary in the same format as the attrs
|
||||
argument to L{openFile} or a L{Deferred} that is called back with same.
|
||||
"""
|
||||
|
||||
def setAttrs(attrs):
|
||||
"""
|
||||
Set the attributes for the file.
|
||||
|
||||
This method returns when the attributes are set or a Deferred that is
|
||||
called back when they are.
|
||||
|
||||
@param attrs: a dictionary in the same format as the attrs argument to
|
||||
L{openFile}.
|
||||
"""
|
||||
104
.venv/lib/python3.12/site-packages/twisted/conch/ls.py
Normal file
104
.venv/lib/python3.12/site-packages/twisted/conch/ls.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_cftp -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import array
|
||||
import stat
|
||||
from time import localtime, strftime, time
|
||||
|
||||
# Locale-independent month names to use instead of strftime's
|
||||
_MONTH_NAMES = dict(
|
||||
list(
|
||||
zip(
|
||||
list(range(1, 13)),
|
||||
"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def lsLine(name, s):
|
||||
"""
|
||||
Build an 'ls' line for a file ('file' in its generic sense, it
|
||||
can be of any type).
|
||||
"""
|
||||
mode = s.st_mode
|
||||
perms = array.array("B", b"-" * 10)
|
||||
ft = stat.S_IFMT(mode)
|
||||
if stat.S_ISDIR(ft):
|
||||
perms[0] = ord("d")
|
||||
elif stat.S_ISCHR(ft):
|
||||
perms[0] = ord("c")
|
||||
elif stat.S_ISBLK(ft):
|
||||
perms[0] = ord("b")
|
||||
elif stat.S_ISREG(ft):
|
||||
perms[0] = ord("-")
|
||||
elif stat.S_ISFIFO(ft):
|
||||
perms[0] = ord("f")
|
||||
elif stat.S_ISLNK(ft):
|
||||
perms[0] = ord("l")
|
||||
elif stat.S_ISSOCK(ft):
|
||||
perms[0] = ord("s")
|
||||
else:
|
||||
perms[0] = ord("!")
|
||||
# User
|
||||
if mode & stat.S_IRUSR:
|
||||
perms[1] = ord("r")
|
||||
if mode & stat.S_IWUSR:
|
||||
perms[2] = ord("w")
|
||||
if mode & stat.S_IXUSR:
|
||||
perms[3] = ord("x")
|
||||
# Group
|
||||
if mode & stat.S_IRGRP:
|
||||
perms[4] = ord("r")
|
||||
if mode & stat.S_IWGRP:
|
||||
perms[5] = ord("w")
|
||||
if mode & stat.S_IXGRP:
|
||||
perms[6] = ord("x")
|
||||
# Other
|
||||
if mode & stat.S_IROTH:
|
||||
perms[7] = ord("r")
|
||||
if mode & stat.S_IWOTH:
|
||||
perms[8] = ord("w")
|
||||
if mode & stat.S_IXOTH:
|
||||
perms[9] = ord("x")
|
||||
# Suid/sgid
|
||||
if mode & stat.S_ISUID:
|
||||
if perms[3] == ord("x"):
|
||||
perms[3] = ord("s")
|
||||
else:
|
||||
perms[3] = ord("S")
|
||||
if mode & stat.S_ISGID:
|
||||
if perms[6] == ord("x"):
|
||||
perms[6] = ord("s")
|
||||
else:
|
||||
perms[6] = ord("S")
|
||||
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("utf-8")
|
||||
lsPerms = perms.tobytes()
|
||||
lsPerms = lsPerms.decode("utf-8")
|
||||
|
||||
lsresult = [
|
||||
lsPerms,
|
||||
str(s.st_nlink).rjust(5),
|
||||
" ",
|
||||
str(s.st_uid).ljust(9),
|
||||
str(s.st_gid).ljust(9),
|
||||
str(s.st_size).rjust(8),
|
||||
" ",
|
||||
]
|
||||
# Need to specify the month manually, as strftime depends on locale
|
||||
ttup = localtime(s.st_mtime)
|
||||
sixmonths = 60 * 60 * 24 * 7 * 26
|
||||
if s.st_mtime + sixmonths < time(): # Last edited more than 6mo ago
|
||||
strtime = strftime("%%s %d %Y ", ttup)
|
||||
else:
|
||||
strtime = strftime("%%s %d %H:%M ", ttup)
|
||||
lsresult.append(strtime % (_MONTH_NAMES[ttup[1]],))
|
||||
|
||||
lsresult.append(name)
|
||||
return "".join(lsresult)
|
||||
|
||||
|
||||
__all__ = ["lsLine"]
|
||||
392
.venv/lib/python3.12/site-packages/twisted/conch/manhole.py
Normal file
392
.venv/lib/python3.12/site-packages/twisted/conch/manhole.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_manhole -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Line-input oriented interactive interpreter loop.
|
||||
|
||||
Provides classes for handling Python source input and arbitrary output
|
||||
interactively from a Twisted application. Also included is syntax coloring
|
||||
code with support for VT102 terminals, control code handling (^C, ^D, ^Q),
|
||||
and reasonable handling of Deferreds.
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
import code
|
||||
import sys
|
||||
import tokenize
|
||||
from io import BytesIO
|
||||
from traceback import format_exception
|
||||
from types import TracebackType
|
||||
from typing import Type
|
||||
|
||||
from twisted.conch import recvline
|
||||
from twisted.internet import defer
|
||||
from twisted.python.htmlizer import TokenPrinter
|
||||
from twisted.python.monkey import MonkeyPatcher
|
||||
|
||||
|
||||
class FileWrapper:
|
||||
"""
|
||||
Minimal write-file-like object.
|
||||
|
||||
Writes are translated into addOutput calls on an object passed to
|
||||
__init__. Newlines are also converted from network to local style.
|
||||
"""
|
||||
|
||||
softspace = 0
|
||||
state = "normal"
|
||||
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def write(self, data):
|
||||
self.o.addOutput(data.replace("\r\n", "\n"))
|
||||
|
||||
def writelines(self, lines):
|
||||
self.write("".join(lines))
|
||||
|
||||
|
||||
class ManholeInterpreter(code.InteractiveInterpreter):
|
||||
"""
|
||||
Interactive Interpreter with special output and Deferred support.
|
||||
|
||||
Aside from the features provided by L{code.InteractiveInterpreter}, this
|
||||
class captures sys.stdout output and redirects it to the appropriate
|
||||
location (the Manhole protocol instance). It also treats Deferreds
|
||||
which reach the top-level specially: each is formatted to the user with
|
||||
a unique identifier and a new callback and errback added to it, each of
|
||||
which will format the unique identifier and the result with which the
|
||||
Deferred fires and then pass it on to the next participant in the
|
||||
callback chain.
|
||||
"""
|
||||
|
||||
numDeferreds = 0
|
||||
|
||||
def __init__(self, handler, locals=None, filename="<console>"):
|
||||
code.InteractiveInterpreter.__init__(self, locals)
|
||||
self._pendingDeferreds = {}
|
||||
self.handler = handler
|
||||
self.filename = filename
|
||||
self.resetBuffer()
|
||||
|
||||
self.monkeyPatcher = MonkeyPatcher()
|
||||
self.monkeyPatcher.addPatch(sys, "displayhook", self.displayhook)
|
||||
self.monkeyPatcher.addPatch(sys, "excepthook", self.excepthook)
|
||||
self.monkeyPatcher.addPatch(sys, "stdout", FileWrapper(self.handler))
|
||||
|
||||
def resetBuffer(self):
|
||||
"""
|
||||
Reset the input buffer.
|
||||
"""
|
||||
self.buffer = []
|
||||
|
||||
def push(self, line):
|
||||
"""
|
||||
Push a line to the interpreter.
|
||||
|
||||
The line should not have a trailing newline; it may have
|
||||
internal newlines. The line is appended to a buffer and the
|
||||
interpreter's runsource() method is called with the
|
||||
concatenated contents of the buffer as source. If this
|
||||
indicates that the command was executed or invalid, the buffer
|
||||
is reset; otherwise, the command is incomplete, and the buffer
|
||||
is left as it was after the line was appended. The return
|
||||
value is 1 if more input is required, 0 if the line was dealt
|
||||
with in some way (this is the same as runsource()).
|
||||
|
||||
@param line: line of text
|
||||
@type line: L{bytes}
|
||||
@return: L{bool} from L{code.InteractiveInterpreter.runsource}
|
||||
"""
|
||||
self.buffer.append(line)
|
||||
source = b"\n".join(self.buffer)
|
||||
source = source.decode("utf-8")
|
||||
more = self.runsource(source, self.filename)
|
||||
if not more:
|
||||
self.resetBuffer()
|
||||
return more
|
||||
|
||||
def runcode(self, *a, **kw):
|
||||
with self.monkeyPatcher:
|
||||
code.InteractiveInterpreter.runcode(self, *a, **kw)
|
||||
|
||||
def excepthook(
|
||||
self,
|
||||
excType: Type[BaseException],
|
||||
excValue: BaseException,
|
||||
excTraceback: TracebackType,
|
||||
) -> None:
|
||||
"""
|
||||
Format exception tracebacks and write them to the output handler.
|
||||
"""
|
||||
code_obj = excTraceback.tb_frame.f_code
|
||||
if code_obj.co_filename == code.__file__ and code_obj.co_name == "runcode":
|
||||
traceback = excTraceback.tb_next
|
||||
else:
|
||||
# Workaround for https://github.com/python/cpython/issues/122478,
|
||||
# present e.g. in Python 3.12.6:
|
||||
traceback = excTraceback
|
||||
lines = format_exception(excType, excValue, traceback)
|
||||
self.write("".join(lines))
|
||||
|
||||
def displayhook(self, obj):
|
||||
self.locals["_"] = obj
|
||||
if isinstance(obj, defer.Deferred):
|
||||
# XXX Ick, where is my "hasFired()" interface?
|
||||
if hasattr(obj, "result"):
|
||||
self.write(repr(obj))
|
||||
elif id(obj) in self._pendingDeferreds:
|
||||
self.write("<Deferred #%d>" % (self._pendingDeferreds[id(obj)][0],))
|
||||
else:
|
||||
d = self._pendingDeferreds
|
||||
k = self.numDeferreds
|
||||
d[id(obj)] = (k, obj)
|
||||
self.numDeferreds += 1
|
||||
obj.addCallbacks(
|
||||
self._cbDisplayDeferred,
|
||||
self._ebDisplayDeferred,
|
||||
callbackArgs=(k, obj),
|
||||
errbackArgs=(k, obj),
|
||||
)
|
||||
self.write("<Deferred #%d>" % (k,))
|
||||
elif obj is not None:
|
||||
self.write(repr(obj))
|
||||
|
||||
def _cbDisplayDeferred(self, result, k, obj):
|
||||
self.write("Deferred #%d called back: %r" % (k, result), True)
|
||||
del self._pendingDeferreds[id(obj)]
|
||||
return result
|
||||
|
||||
def _ebDisplayDeferred(self, failure, k, obj):
|
||||
self.write("Deferred #%d failed: %r" % (k, failure.getErrorMessage()), True)
|
||||
del self._pendingDeferreds[id(obj)]
|
||||
return failure
|
||||
|
||||
def write(self, data, isAsync=None):
|
||||
self.handler.addOutput(data, isAsync)
|
||||
|
||||
|
||||
CTRL_C = b"\x03"
|
||||
CTRL_D = b"\x04"
|
||||
CTRL_BACKSLASH = b"\x1c"
|
||||
CTRL_L = b"\x0c"
|
||||
CTRL_A = b"\x01"
|
||||
CTRL_E = b"\x05"
|
||||
|
||||
|
||||
class Manhole(recvline.HistoricRecvLine):
|
||||
r"""
|
||||
Mediator between a fancy line source and an interactive interpreter.
|
||||
|
||||
This accepts lines from its transport and passes them on to a
|
||||
L{ManholeInterpreter}. Control commands (^C, ^D, ^\) are also handled
|
||||
with something approximating their normal terminal-mode behavior. It
|
||||
can optionally be constructed with a dict which will be used as the
|
||||
local namespace for any code executed.
|
||||
"""
|
||||
|
||||
namespace = None
|
||||
|
||||
def __init__(self, namespace=None):
|
||||
recvline.HistoricRecvLine.__init__(self)
|
||||
if namespace is not None:
|
||||
self.namespace = namespace.copy()
|
||||
|
||||
def connectionMade(self):
|
||||
recvline.HistoricRecvLine.connectionMade(self)
|
||||
self.interpreter = ManholeInterpreter(self, self.namespace)
|
||||
self.keyHandlers[CTRL_C] = self.handle_INT
|
||||
self.keyHandlers[CTRL_D] = self.handle_EOF
|
||||
self.keyHandlers[CTRL_L] = self.handle_FF
|
||||
self.keyHandlers[CTRL_A] = self.handle_HOME
|
||||
self.keyHandlers[CTRL_E] = self.handle_END
|
||||
self.keyHandlers[CTRL_BACKSLASH] = self.handle_QUIT
|
||||
|
||||
def handle_INT(self):
|
||||
"""
|
||||
Handle ^C as an interrupt keystroke by resetting the current input
|
||||
variables to their initial state.
|
||||
"""
|
||||
self.pn = 0
|
||||
self.lineBuffer = []
|
||||
self.lineBufferIndex = 0
|
||||
self.interpreter.resetBuffer()
|
||||
|
||||
self.terminal.nextLine()
|
||||
self.terminal.write(b"KeyboardInterrupt")
|
||||
self.terminal.nextLine()
|
||||
self.terminal.write(self.ps[self.pn])
|
||||
|
||||
def handle_EOF(self):
|
||||
if self.lineBuffer:
|
||||
self.terminal.write(b"\a")
|
||||
else:
|
||||
self.handle_QUIT()
|
||||
|
||||
def handle_FF(self):
|
||||
"""
|
||||
Handle a 'form feed' byte - generally used to request a screen
|
||||
refresh/redraw.
|
||||
"""
|
||||
self.terminal.eraseDisplay()
|
||||
self.terminal.cursorHome()
|
||||
self.drawInputLine()
|
||||
|
||||
def handle_QUIT(self):
|
||||
self.terminal.loseConnection()
|
||||
|
||||
def _needsNewline(self):
|
||||
w = self.terminal.lastWrite
|
||||
return not w.endswith(b"\n") and not w.endswith(b"\x1bE")
|
||||
|
||||
def addOutput(self, data, isAsync=None):
|
||||
if isAsync:
|
||||
self.terminal.eraseLine()
|
||||
self.terminal.cursorBackward(len(self.lineBuffer) + len(self.ps[self.pn]))
|
||||
|
||||
self.terminal.write(data)
|
||||
|
||||
if isAsync:
|
||||
if self._needsNewline():
|
||||
self.terminal.nextLine()
|
||||
|
||||
self.terminal.write(self.ps[self.pn])
|
||||
|
||||
if self.lineBuffer:
|
||||
oldBuffer = self.lineBuffer
|
||||
self.lineBuffer = []
|
||||
self.lineBufferIndex = 0
|
||||
|
||||
self._deliverBuffer(oldBuffer)
|
||||
|
||||
def lineReceived(self, line):
|
||||
more = self.interpreter.push(line)
|
||||
self.pn = bool(more)
|
||||
if self._needsNewline():
|
||||
self.terminal.nextLine()
|
||||
self.terminal.write(self.ps[self.pn])
|
||||
|
||||
|
||||
class VT102Writer:
|
||||
"""
|
||||
Colorizer for Python tokens.
|
||||
|
||||
A series of tokens are written to instances of this object. Each is
|
||||
colored in a particular way. The final line of the result of this is
|
||||
generally added to the output.
|
||||
"""
|
||||
|
||||
typeToColor = {
|
||||
"identifier": b"\x1b[31m",
|
||||
"keyword": b"\x1b[32m",
|
||||
"parameter": b"\x1b[33m",
|
||||
"variable": b"\x1b[1;33m",
|
||||
"string": b"\x1b[35m",
|
||||
"number": b"\x1b[36m",
|
||||
"op": b"\x1b[37m",
|
||||
}
|
||||
|
||||
normalColor = b"\x1b[0m"
|
||||
|
||||
def __init__(self):
|
||||
self.written = []
|
||||
|
||||
def color(self, type):
|
||||
r = self.typeToColor.get(type, b"")
|
||||
return r
|
||||
|
||||
def write(self, token, type=None):
|
||||
if token and token != b"\r":
|
||||
c = self.color(type)
|
||||
if c:
|
||||
self.written.append(c)
|
||||
self.written.append(token)
|
||||
if c:
|
||||
self.written.append(self.normalColor)
|
||||
|
||||
def __bytes__(self):
|
||||
s = b"".join(self.written)
|
||||
return s.strip(b"\n").splitlines()[-1]
|
||||
|
||||
|
||||
def lastColorizedLine(source):
|
||||
"""
|
||||
Tokenize and colorize the given Python source.
|
||||
|
||||
Returns a VT102-format colorized version of the last line of C{source}.
|
||||
|
||||
@param source: Python source code
|
||||
@type source: L{str} or L{bytes}
|
||||
@return: L{bytes} of colorized source
|
||||
"""
|
||||
if not isinstance(source, bytes):
|
||||
source = source.encode("utf-8")
|
||||
w = VT102Writer()
|
||||
p = TokenPrinter(w.write).printtoken
|
||||
s = BytesIO(source)
|
||||
|
||||
for token in tokenize.tokenize(s.readline):
|
||||
(tokenType, string, start, end, line) = token
|
||||
p(tokenType, string, start, end, line)
|
||||
|
||||
return bytes(w)
|
||||
|
||||
|
||||
class ColoredManhole(Manhole):
|
||||
"""
|
||||
A REPL which syntax colors input as users type it.
|
||||
"""
|
||||
|
||||
def getSource(self):
|
||||
"""
|
||||
Return a string containing the currently entered source.
|
||||
|
||||
This is only the code which will be considered for execution
|
||||
next.
|
||||
"""
|
||||
return b"\n".join(self.interpreter.buffer) + b"\n" + b"".join(self.lineBuffer)
|
||||
|
||||
def characterReceived(self, ch, moreCharactersComing):
|
||||
if self.mode == "insert":
|
||||
self.lineBuffer.insert(self.lineBufferIndex, ch)
|
||||
else:
|
||||
self.lineBuffer[self.lineBufferIndex : self.lineBufferIndex + 1] = [ch]
|
||||
self.lineBufferIndex += 1
|
||||
|
||||
if moreCharactersComing:
|
||||
# Skip it all, we'll get called with another character in
|
||||
# like 2 femtoseconds.
|
||||
return
|
||||
|
||||
if ch == b" ":
|
||||
# Don't bother to try to color whitespace
|
||||
self.terminal.write(ch)
|
||||
return
|
||||
|
||||
source = self.getSource()
|
||||
|
||||
# Try to write some junk
|
||||
try:
|
||||
coloredLine = lastColorizedLine(source)
|
||||
except tokenize.TokenError:
|
||||
# We couldn't do it. Strange. Oh well, just add the character.
|
||||
self.terminal.write(ch)
|
||||
else:
|
||||
# Success! Clear the source on this line.
|
||||
self.terminal.eraseLine()
|
||||
self.terminal.cursorBackward(
|
||||
len(self.lineBuffer) + len(self.ps[self.pn]) - 1
|
||||
)
|
||||
|
||||
# And write a new, colorized one.
|
||||
self.terminal.write(self.ps[self.pn] + coloredLine)
|
||||
|
||||
# And move the cursor to where it belongs
|
||||
n = len(self.lineBuffer) - self.lineBufferIndex
|
||||
if n:
|
||||
self.terminal.cursorBackward(n)
|
||||
148
.venv/lib/python3.12/site-packages/twisted/conch/manhole_ssh.py
Normal file
148
.venv/lib/python3.12/site-packages/twisted/conch/manhole_ssh.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_manhole -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
insults/SSH integration support.
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.conch import avatar, error as econch, interfaces as iconch
|
||||
from twisted.conch.insults import insults
|
||||
from twisted.conch.ssh import factory, session
|
||||
from twisted.python import components
|
||||
|
||||
|
||||
class _Glue:
|
||||
"""
|
||||
A feeble class for making one attribute look like another.
|
||||
|
||||
This should be replaced with a real class at some point, probably.
|
||||
Try not to write new code that uses it.
|
||||
"""
|
||||
|
||||
def __init__(self, **kw):
|
||||
self.__dict__.update(kw)
|
||||
|
||||
def __getattr__(self, name):
|
||||
raise AttributeError(self.name, "has no attribute", name)
|
||||
|
||||
|
||||
class TerminalSessionTransport:
|
||||
def __init__(self, proto, chainedProtocol, avatar, width, height):
|
||||
self.proto = proto
|
||||
self.avatar = avatar
|
||||
self.chainedProtocol = chainedProtocol
|
||||
|
||||
protoSession = self.proto.session
|
||||
|
||||
self.proto.makeConnection(
|
||||
_Glue(
|
||||
write=self.chainedProtocol.dataReceived,
|
||||
loseConnection=lambda: avatar.conn.sendClose(protoSession),
|
||||
name="SSH Proto Transport",
|
||||
)
|
||||
)
|
||||
|
||||
def loseConnection():
|
||||
self.proto.loseConnection()
|
||||
|
||||
self.chainedProtocol.makeConnection(
|
||||
_Glue(
|
||||
write=self.proto.write,
|
||||
loseConnection=loseConnection,
|
||||
name="Chained Proto Transport",
|
||||
)
|
||||
)
|
||||
|
||||
# XXX TODO
|
||||
# chainedProtocol is supposed to be an ITerminalTransport,
|
||||
# maybe. That means perhaps its terminalProtocol attribute is
|
||||
# an ITerminalProtocol, it could be. So calling terminalSize
|
||||
# on that should do the right thing But it'd be nice to clean
|
||||
# this bit up.
|
||||
self.chainedProtocol.terminalProtocol.terminalSize(width, height)
|
||||
|
||||
|
||||
@implementer(iconch.ISession)
|
||||
class TerminalSession(components.Adapter):
|
||||
transportFactory = TerminalSessionTransport
|
||||
chainedProtocolFactory = insults.ServerProtocol
|
||||
|
||||
def getPty(self, term, windowSize, attrs):
|
||||
self.height, self.width = windowSize[:2]
|
||||
|
||||
def openShell(self, proto):
|
||||
self.transportFactory(
|
||||
proto,
|
||||
self.chainedProtocolFactory(),
|
||||
iconch.IConchUser(self.original),
|
||||
self.width,
|
||||
self.height,
|
||||
)
|
||||
|
||||
def execCommand(self, proto, cmd):
|
||||
raise econch.ConchError("Cannot execute commands")
|
||||
|
||||
def windowChanged(self, newWindowSize):
|
||||
# ISession.windowChanged
|
||||
raise NotImplementedError("Unimplemented: TerminalSession.windowChanged")
|
||||
|
||||
def eofReceived(self):
|
||||
# ISession.eofReceived
|
||||
raise NotImplementedError("Unimplemented: TerminalSession.eofReceived")
|
||||
|
||||
def closed(self):
|
||||
# ISession.closed
|
||||
pass
|
||||
|
||||
|
||||
class TerminalUser(avatar.ConchUser, components.Adapter):
|
||||
def __init__(self, original, avatarId):
|
||||
components.Adapter.__init__(self, original)
|
||||
avatar.ConchUser.__init__(self)
|
||||
self.channelLookup[b"session"] = session.SSHSession
|
||||
|
||||
|
||||
class TerminalRealm:
|
||||
userFactory = TerminalUser
|
||||
sessionFactory = TerminalSession
|
||||
|
||||
transportFactory = TerminalSessionTransport
|
||||
chainedProtocolFactory = insults.ServerProtocol
|
||||
|
||||
def _getAvatar(self, avatarId):
|
||||
comp = components.Componentized()
|
||||
user = self.userFactory(comp, avatarId)
|
||||
sess = self.sessionFactory(comp)
|
||||
|
||||
sess.transportFactory = self.transportFactory
|
||||
sess.chainedProtocolFactory = self.chainedProtocolFactory
|
||||
|
||||
comp.setComponent(iconch.IConchUser, user)
|
||||
comp.setComponent(iconch.ISession, sess)
|
||||
|
||||
return user
|
||||
|
||||
def __init__(self, transportFactory=None):
|
||||
if transportFactory is not None:
|
||||
self.transportFactory = transportFactory
|
||||
|
||||
def requestAvatar(self, avatarId, mind, *interfaces):
|
||||
for i in interfaces:
|
||||
if i is iconch.IConchUser:
|
||||
return (iconch.IConchUser, self._getAvatar(avatarId), lambda: None)
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ConchFactory(factory.SSHFactory):
|
||||
publicKeys: Dict[bytes, bytes] = {}
|
||||
privateKeys: Dict[bytes, bytes] = {}
|
||||
|
||||
def __init__(self, portal):
|
||||
self.portal = portal
|
||||
180
.venv/lib/python3.12/site-packages/twisted/conch/manhole_tap.py
Normal file
180
.venv/lib/python3.12/site-packages/twisted/conch/manhole_tap.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
TAP plugin for creating telnet- and ssh-accessible manhole servers.
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.application import service, strports
|
||||
from twisted.conch import manhole, manhole_ssh, telnet
|
||||
from twisted.conch.insults import insults
|
||||
from twisted.conch.ssh import keys
|
||||
from twisted.cred import checkers, portal
|
||||
from twisted.internet import protocol
|
||||
from twisted.python import filepath, usage
|
||||
|
||||
|
||||
class makeTelnetProtocol:
|
||||
def __init__(self, portal):
|
||||
self.portal = portal
|
||||
|
||||
def __call__(self):
|
||||
auth = telnet.AuthenticatingTelnetProtocol
|
||||
args = (self.portal,)
|
||||
return telnet.TelnetTransport(auth, *args)
|
||||
|
||||
|
||||
class chainedProtocolFactory:
|
||||
def __init__(self, namespace):
|
||||
self.namespace = namespace
|
||||
|
||||
def __call__(self):
|
||||
return insults.ServerProtocol(manhole.ColoredManhole, self.namespace)
|
||||
|
||||
|
||||
@implementer(portal.IRealm)
|
||||
class _StupidRealm:
|
||||
def __init__(self, proto, *a, **kw):
|
||||
self.protocolFactory = proto
|
||||
self.protocolArgs = a
|
||||
self.protocolKwArgs = kw
|
||||
|
||||
def requestAvatar(self, avatarId, *interfaces):
|
||||
if telnet.ITelnetProtocol in interfaces:
|
||||
return (
|
||||
telnet.ITelnetProtocol,
|
||||
self.protocolFactory(*self.protocolArgs, **self.protocolKwArgs),
|
||||
lambda: None,
|
||||
)
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Options(usage.Options):
|
||||
optParameters = [
|
||||
[
|
||||
"telnetPort",
|
||||
"t",
|
||||
None,
|
||||
(
|
||||
"strports description of the address on which to listen for telnet "
|
||||
"connections"
|
||||
),
|
||||
],
|
||||
[
|
||||
"sshPort",
|
||||
"s",
|
||||
None,
|
||||
(
|
||||
"strports description of the address on which to listen for ssh "
|
||||
"connections"
|
||||
),
|
||||
],
|
||||
[
|
||||
"passwd",
|
||||
"p",
|
||||
"/etc/passwd",
|
||||
"name of a passwd(5)-format username/password file",
|
||||
],
|
||||
[
|
||||
"sshKeyDir",
|
||||
None,
|
||||
"<USER DATA DIR>",
|
||||
"Directory where the autogenerated SSH key is kept.",
|
||||
],
|
||||
["sshKeyName", None, "server.key", "Filename of the autogenerated SSH key."],
|
||||
["sshKeySize", None, 4096, "Size of the automatically generated SSH key."],
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
usage.Options.__init__(self)
|
||||
self["namespace"] = None
|
||||
|
||||
def postOptions(self):
|
||||
if self["telnetPort"] is None and self["sshPort"] is None:
|
||||
raise usage.UsageError(
|
||||
"At least one of --telnetPort and --sshPort must be specified"
|
||||
)
|
||||
|
||||
|
||||
def makeService(options):
|
||||
"""
|
||||
Create a manhole server service.
|
||||
|
||||
@type options: L{dict}
|
||||
@param options: A mapping describing the configuration of
|
||||
the desired service. Recognized key/value pairs are::
|
||||
|
||||
"telnetPort": strports description of the address on which
|
||||
to listen for telnet connections. If None,
|
||||
no telnet service will be started.
|
||||
|
||||
"sshPort": strports description of the address on which to
|
||||
listen for ssh connections. If None, no ssh
|
||||
service will be started.
|
||||
|
||||
"namespace": dictionary containing desired initial locals
|
||||
for manhole connections. If None, an empty
|
||||
dictionary will be used.
|
||||
|
||||
"passwd": Name of a passwd(5)-format username/password file.
|
||||
|
||||
"sshKeyDir": The folder that the SSH server key will be kept in.
|
||||
|
||||
"sshKeyName": The filename of the key.
|
||||
|
||||
"sshKeySize": The size of the key, in bits. Default is 4096.
|
||||
|
||||
@rtype: L{twisted.application.service.IService}
|
||||
@return: A manhole service.
|
||||
"""
|
||||
svc = service.MultiService()
|
||||
|
||||
namespace = options["namespace"]
|
||||
if namespace is None:
|
||||
namespace = {}
|
||||
|
||||
checker = checkers.FilePasswordDB(options["passwd"])
|
||||
|
||||
if options["telnetPort"]:
|
||||
telnetRealm = _StupidRealm(
|
||||
telnet.TelnetBootstrapProtocol,
|
||||
insults.ServerProtocol,
|
||||
manhole.ColoredManhole,
|
||||
namespace,
|
||||
)
|
||||
|
||||
telnetPortal = portal.Portal(telnetRealm, [checker])
|
||||
|
||||
telnetFactory = protocol.ServerFactory()
|
||||
telnetFactory.protocol = makeTelnetProtocol(telnetPortal)
|
||||
telnetService = strports.service(options["telnetPort"], telnetFactory)
|
||||
telnetService.setServiceParent(svc)
|
||||
|
||||
if options["sshPort"]:
|
||||
sshRealm = manhole_ssh.TerminalRealm()
|
||||
sshRealm.chainedProtocolFactory = chainedProtocolFactory(namespace)
|
||||
|
||||
sshPortal = portal.Portal(sshRealm, [checker])
|
||||
sshFactory = manhole_ssh.ConchFactory(sshPortal)
|
||||
|
||||
if options["sshKeyDir"] != "<USER DATA DIR>":
|
||||
keyDir = options["sshKeyDir"]
|
||||
else:
|
||||
from twisted.python._appdirs import getDataDirectory
|
||||
|
||||
keyDir = getDataDirectory()
|
||||
|
||||
keyLocation = filepath.FilePath(keyDir).child(options["sshKeyName"])
|
||||
|
||||
sshKey = keys._getPersistentRSAKey(keyLocation, int(options["sshKeySize"]))
|
||||
sshFactory.publicKeys[b"ssh-rsa"] = sshKey
|
||||
sshFactory.privateKeys[b"ssh-rsa"] = sshKey
|
||||
|
||||
sshService = strports.service(options["sshPort"], sshFactory)
|
||||
sshService.setServiceParent(svc)
|
||||
|
||||
return svc
|
||||
54
.venv/lib/python3.12/site-packages/twisted/conch/mixin.py
Normal file
54
.venv/lib/python3.12/site-packages/twisted/conch/mixin.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_mixin -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Experimental optimization
|
||||
|
||||
This module provides a single mixin class which allows protocols to
|
||||
collapse numerous small writes into a single larger one.
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
from twisted.internet import reactor
|
||||
|
||||
|
||||
class BufferingMixin:
|
||||
"""
|
||||
Mixin which adds write buffering.
|
||||
"""
|
||||
|
||||
_delayedWriteCall = None
|
||||
data = None
|
||||
|
||||
DELAY = 0.0
|
||||
|
||||
def schedule(self):
|
||||
return reactor.callLater(self.DELAY, self.flush)
|
||||
|
||||
def reschedule(self, token):
|
||||
token.reset(self.DELAY)
|
||||
|
||||
def write(self, data):
|
||||
"""
|
||||
Buffer some bytes to be written soon.
|
||||
|
||||
Every call to this function delays the real write by C{self.DELAY}
|
||||
seconds. When the delay expires, all collected bytes are written
|
||||
to the underlying transport using L{ITransport.writeSequence}.
|
||||
"""
|
||||
if self._delayedWriteCall is None:
|
||||
self.data = []
|
||||
self._delayedWriteCall = self.schedule()
|
||||
else:
|
||||
self.reschedule(self._delayedWriteCall)
|
||||
self.data.append(data)
|
||||
|
||||
def flush(self):
|
||||
"""
|
||||
Flush the buffer immediately.
|
||||
"""
|
||||
self._delayedWriteCall = None
|
||||
self.transport.writeSequence(self.data)
|
||||
self.data = None
|
||||
1
.venv/lib/python3.12/site-packages/twisted/conch/newsfragments/.gitignore
vendored
Normal file
1
.venv/lib/python3.12/site-packages/twisted/conch/newsfragments/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
!.gitignore
|
||||
@@ -0,0 +1,10 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
|
||||
"""
|
||||
Support for OpenSSH configuration files.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
@@ -0,0 +1,74 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_openssh_compat -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Factory for reading openssh configuration files: public keys, private keys, and
|
||||
moduli file.
|
||||
"""
|
||||
|
||||
import errno
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from twisted.conch.openssh_compat import primes
|
||||
from twisted.conch.ssh import common, factory, keys
|
||||
from twisted.python.util import runAsEffectiveUser
|
||||
|
||||
|
||||
class OpenSSHFactory(factory.SSHFactory):
|
||||
dataRoot = "/usr/local/etc"
|
||||
# For openbsd which puts moduli in a different directory from keys.
|
||||
moduliRoot = "/usr/local/etc"
|
||||
|
||||
def getPublicKeys(self):
|
||||
"""
|
||||
Return the server public keys.
|
||||
"""
|
||||
ks = {}
|
||||
for filename in os.listdir(self.dataRoot):
|
||||
if filename[:9] == "ssh_host_" and filename[-8:] == "_key.pub":
|
||||
try:
|
||||
k = keys.Key.fromFile(os.path.join(self.dataRoot, filename))
|
||||
t = common.getNS(k.blob())[0]
|
||||
ks[t] = k
|
||||
except Exception as e:
|
||||
self._log.error(
|
||||
"bad public key file {filename}: {error}",
|
||||
filename=filename,
|
||||
error=e,
|
||||
)
|
||||
return ks
|
||||
|
||||
def getPrivateKeys(self):
|
||||
"""
|
||||
Return the server private keys.
|
||||
"""
|
||||
privateKeys = {}
|
||||
for filename in os.listdir(self.dataRoot):
|
||||
if filename[:9] == "ssh_host_" and filename[-4:] == "_key":
|
||||
fullPath = os.path.join(self.dataRoot, filename)
|
||||
try:
|
||||
key = keys.Key.fromFile(fullPath)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EACCES:
|
||||
# Not allowed, let's switch to root
|
||||
key = runAsEffectiveUser(0, 0, keys.Key.fromFile, fullPath)
|
||||
privateKeys[key.sshType()] = key
|
||||
else:
|
||||
raise
|
||||
except Exception as e:
|
||||
self._log.error(
|
||||
"bad public key file {filename}: {error}",
|
||||
filename=filename,
|
||||
error=e,
|
||||
)
|
||||
else:
|
||||
privateKeys[key.sshType()] = key
|
||||
return privateKeys
|
||||
|
||||
def getPrimes(self) -> Optional[Dict[int, List[Tuple[int, int]]]]:
|
||||
try:
|
||||
return primes.parseModuliFile(self.moduliRoot + "/moduli")
|
||||
except OSError:
|
||||
return None
|
||||
@@ -0,0 +1,31 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
|
||||
"""
|
||||
Parsing for the moduli file, which contains Diffie-Hellman prime groups.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
def parseModuliFile(filename: str) -> Dict[int, List[Tuple[int, int]]]:
|
||||
with open(filename) as f:
|
||||
lines = f.readlines()
|
||||
primes: Dict[int, List[Tuple[int, int]]] = {}
|
||||
for l in lines:
|
||||
l = l.strip()
|
||||
if not l or l[0] == "#":
|
||||
continue
|
||||
tim, typ, tst, tri, sizestr, genstr, modstr = l.split()
|
||||
size = int(sizestr) + 1
|
||||
gen = int(genstr)
|
||||
mod = int(modstr, 16)
|
||||
if size not in primes:
|
||||
primes[size] = []
|
||||
primes[size].append((gen, mod))
|
||||
return primes
|
||||
569
.venv/lib/python3.12/site-packages/twisted/conch/recvline.py
Normal file
569
.venv/lib/python3.12/site-packages/twisted/conch/recvline.py
Normal file
@@ -0,0 +1,569 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_recvline -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Basic line editing support.
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
import string
|
||||
from typing import Dict
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.conch.insults import helper, insults
|
||||
from twisted.logger import Logger
|
||||
from twisted.python import reflect
|
||||
from twisted.python.compat import iterbytes
|
||||
|
||||
_counters: Dict[str, int] = {}
|
||||
|
||||
|
||||
class Logging:
|
||||
"""
|
||||
Wrapper which logs attribute lookups.
|
||||
|
||||
This was useful in debugging something, I guess. I forget what.
|
||||
It can probably be deleted or moved somewhere more appropriate.
|
||||
Nothing special going on here, really.
|
||||
"""
|
||||
|
||||
def __init__(self, original):
|
||||
self.original = original
|
||||
key = reflect.qual(original.__class__)
|
||||
count = _counters.get(key, 0)
|
||||
_counters[key] = count + 1
|
||||
self._logFile = open(key + "-" + str(count), "w")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(super().__getattribute__("original"))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(super().__getattribute__("original"))
|
||||
|
||||
def __getattribute__(self, name):
|
||||
original = super().__getattribute__("original")
|
||||
logFile = super().__getattribute__("_logFile")
|
||||
logFile.write(name + "\n")
|
||||
return getattr(original, name)
|
||||
|
||||
|
||||
@implementer(insults.ITerminalTransport)
|
||||
class TransportSequence:
|
||||
"""
|
||||
An L{ITerminalTransport} implementation which forwards calls to
|
||||
one or more other L{ITerminalTransport}s.
|
||||
|
||||
This is a cheap way for servers to keep track of the state they
|
||||
expect the client to see, since all terminal manipulations can be
|
||||
send to the real client and to a terminal emulator that lives in
|
||||
the server process.
|
||||
"""
|
||||
|
||||
for keyID in (
|
||||
b"UP_ARROW",
|
||||
b"DOWN_ARROW",
|
||||
b"RIGHT_ARROW",
|
||||
b"LEFT_ARROW",
|
||||
b"HOME",
|
||||
b"INSERT",
|
||||
b"DELETE",
|
||||
b"END",
|
||||
b"PGUP",
|
||||
b"PGDN",
|
||||
b"F1",
|
||||
b"F2",
|
||||
b"F3",
|
||||
b"F4",
|
||||
b"F5",
|
||||
b"F6",
|
||||
b"F7",
|
||||
b"F8",
|
||||
b"F9",
|
||||
b"F10",
|
||||
b"F11",
|
||||
b"F12",
|
||||
):
|
||||
execBytes = keyID + b" = object()"
|
||||
execStr = execBytes.decode("ascii")
|
||||
exec(execStr)
|
||||
|
||||
TAB = b"\t"
|
||||
BACKSPACE = b"\x7f"
|
||||
|
||||
def __init__(self, *transports):
|
||||
assert transports, "Cannot construct a TransportSequence with no transports"
|
||||
self.transports = transports
|
||||
|
||||
for method in insults.ITerminalTransport:
|
||||
exec(
|
||||
"""\
|
||||
def %s(self, *a, **kw):
|
||||
for tpt in self.transports:
|
||||
result = tpt.%s(*a, **kw)
|
||||
return result
|
||||
"""
|
||||
% (method, method)
|
||||
)
|
||||
|
||||
def getHost(self):
|
||||
# ITransport.getHost
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.getHost")
|
||||
|
||||
def getPeer(self):
|
||||
# ITransport.getPeer
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.getPeer")
|
||||
|
||||
def loseConnection(self):
|
||||
# ITransport.loseConnection
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.loseConnection")
|
||||
|
||||
def write(self, data):
|
||||
# ITransport.write
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.write")
|
||||
|
||||
def writeSequence(self, data):
|
||||
# ITransport.writeSequence
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.writeSequence")
|
||||
|
||||
def cursorUp(self, n=1):
|
||||
# ITerminalTransport.cursorUp
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.cursorUp")
|
||||
|
||||
def cursorDown(self, n=1):
|
||||
# ITerminalTransport.cursorDown
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.cursorDown")
|
||||
|
||||
def cursorForward(self, n=1):
|
||||
# ITerminalTransport.cursorForward
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.cursorForward")
|
||||
|
||||
def cursorBackward(self, n=1):
|
||||
# ITerminalTransport.cursorBackward
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.cursorBackward")
|
||||
|
||||
def cursorPosition(self, column, line):
|
||||
# ITerminalTransport.cursorPosition
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.cursorPosition")
|
||||
|
||||
def cursorHome(self):
|
||||
# ITerminalTransport.cursorHome
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.cursorHome")
|
||||
|
||||
def index(self):
|
||||
# ITerminalTransport.index
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.index")
|
||||
|
||||
def reverseIndex(self):
|
||||
# ITerminalTransport.reverseIndex
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.reverseIndex")
|
||||
|
||||
def nextLine(self):
|
||||
# ITerminalTransport.nextLine
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.nextLine")
|
||||
|
||||
def saveCursor(self):
|
||||
# ITerminalTransport.saveCursor
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.saveCursor")
|
||||
|
||||
def restoreCursor(self):
|
||||
# ITerminalTransport.restoreCursor
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.restoreCursor")
|
||||
|
||||
def setModes(self, modes):
|
||||
# ITerminalTransport.setModes
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.setModes")
|
||||
|
||||
def resetModes(self, mode):
|
||||
# ITerminalTransport.resetModes
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.resetModes")
|
||||
|
||||
def setPrivateModes(self, modes):
|
||||
# ITerminalTransport.setPrivateModes
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.setPrivateModes")
|
||||
|
||||
def resetPrivateModes(self, modes):
|
||||
# ITerminalTransport.resetPrivateModes
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.resetPrivateModes")
|
||||
|
||||
def applicationKeypadMode(self):
|
||||
# ITerminalTransport.applicationKeypadMode
|
||||
raise NotImplementedError(
|
||||
"Unimplemented: TransportSequence.applicationKeypadMode"
|
||||
)
|
||||
|
||||
def numericKeypadMode(self):
|
||||
# ITerminalTransport.numericKeypadMode
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.numericKeypadMode")
|
||||
|
||||
def selectCharacterSet(self, charSet, which):
|
||||
# ITerminalTransport.selectCharacterSet
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.selectCharacterSet")
|
||||
|
||||
def shiftIn(self):
|
||||
# ITerminalTransport.shiftIn
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.shiftIn")
|
||||
|
||||
def shiftOut(self):
|
||||
# ITerminalTransport.shiftOut
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.shiftOut")
|
||||
|
||||
def singleShift2(self):
|
||||
# ITerminalTransport.singleShift2
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.singleShift2")
|
||||
|
||||
def singleShift3(self):
|
||||
# ITerminalTransport.singleShift3
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.singleShift3")
|
||||
|
||||
def selectGraphicRendition(self, *attributes):
|
||||
# ITerminalTransport.selectGraphicRendition
|
||||
raise NotImplementedError(
|
||||
"Unimplemented: TransportSequence.selectGraphicRendition"
|
||||
)
|
||||
|
||||
def horizontalTabulationSet(self):
|
||||
# ITerminalTransport.horizontalTabulationSet
|
||||
raise NotImplementedError(
|
||||
"Unimplemented: TransportSequence.horizontalTabulationSet"
|
||||
)
|
||||
|
||||
def tabulationClear(self):
|
||||
# ITerminalTransport.tabulationClear
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.tabulationClear")
|
||||
|
||||
def tabulationClearAll(self):
|
||||
# ITerminalTransport.tabulationClearAll
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.tabulationClearAll")
|
||||
|
||||
def doubleHeightLine(self, top=True):
|
||||
# ITerminalTransport.doubleHeightLine
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.doubleHeightLine")
|
||||
|
||||
def singleWidthLine(self):
|
||||
# ITerminalTransport.singleWidthLine
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.singleWidthLine")
|
||||
|
||||
def doubleWidthLine(self):
|
||||
# ITerminalTransport.doubleWidthLine
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.doubleWidthLine")
|
||||
|
||||
def eraseToLineEnd(self):
|
||||
# ITerminalTransport.eraseToLineEnd
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.eraseToLineEnd")
|
||||
|
||||
def eraseToLineBeginning(self):
|
||||
# ITerminalTransport.eraseToLineBeginning
|
||||
raise NotImplementedError(
|
||||
"Unimplemented: TransportSequence.eraseToLineBeginning"
|
||||
)
|
||||
|
||||
def eraseLine(self):
|
||||
# ITerminalTransport.eraseLine
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.eraseLine")
|
||||
|
||||
def eraseToDisplayEnd(self):
|
||||
# ITerminalTransport.eraseToDisplayEnd
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.eraseToDisplayEnd")
|
||||
|
||||
def eraseToDisplayBeginning(self):
|
||||
# ITerminalTransport.eraseToDisplayBeginning
|
||||
raise NotImplementedError(
|
||||
"Unimplemented: TransportSequence.eraseToDisplayBeginning"
|
||||
)
|
||||
|
||||
def eraseDisplay(self):
|
||||
# ITerminalTransport.eraseDisplay
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.eraseDisplay")
|
||||
|
||||
def deleteCharacter(self, n=1):
|
||||
# ITerminalTransport.deleteCharacter
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.deleteCharacter")
|
||||
|
||||
def insertLine(self, n=1):
|
||||
# ITerminalTransport.insertLine
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.insertLine")
|
||||
|
||||
def deleteLine(self, n=1):
|
||||
# ITerminalTransport.deleteLine
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.deleteLine")
|
||||
|
||||
def reportCursorPosition(self):
|
||||
# ITerminalTransport.reportCursorPosition
|
||||
raise NotImplementedError(
|
||||
"Unimplemented: TransportSequence.reportCursorPosition"
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
# ITerminalTransport.reset
|
||||
raise NotImplementedError("Unimplemented: TransportSequence.reset")
|
||||
|
||||
def unhandledControlSequence(self, seq):
|
||||
# ITerminalTransport.unhandledControlSequence
|
||||
raise NotImplementedError(
|
||||
"Unimplemented: TransportSequence.unhandledControlSequence"
|
||||
)
|
||||
|
||||
|
||||
class LocalTerminalBufferMixin:
|
||||
"""
|
||||
A mixin for RecvLine subclasses which records the state of the terminal.
|
||||
|
||||
This is accomplished by performing all L{ITerminalTransport} operations on both
|
||||
the transport passed to makeConnection and an instance of helper.TerminalBuffer.
|
||||
|
||||
@ivar terminalCopy: A L{helper.TerminalBuffer} instance which efforts
|
||||
will be made to keep up to date with the actual terminal
|
||||
associated with this protocol instance.
|
||||
"""
|
||||
|
||||
def makeConnection(self, transport):
|
||||
self.terminalCopy = helper.TerminalBuffer()
|
||||
self.terminalCopy.connectionMade()
|
||||
return super().makeConnection(TransportSequence(transport, self.terminalCopy))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.terminalCopy)
|
||||
|
||||
|
||||
class RecvLine(insults.TerminalProtocol):
|
||||
"""
|
||||
L{TerminalProtocol} which adds line editing features.
|
||||
|
||||
Clients will be prompted for lines of input with all the usual
|
||||
features: character echoing, left and right arrow support for
|
||||
moving the cursor to different areas of the line buffer, backspace
|
||||
and delete for removing characters, and insert for toggling
|
||||
between typeover and insert mode. Tabs will be expanded to enough
|
||||
spaces to move the cursor to the next tabstop (every four
|
||||
characters by default). Enter causes the line buffer to be
|
||||
cleared and the line to be passed to the lineReceived() method
|
||||
which, by default, does nothing. Subclasses are responsible for
|
||||
redrawing the input prompt (this will probably change).
|
||||
"""
|
||||
|
||||
width = 80
|
||||
height = 24
|
||||
|
||||
TABSTOP = 4
|
||||
|
||||
ps = (b">>> ", b"... ")
|
||||
pn = 0
|
||||
_printableChars = string.printable.encode("ascii")
|
||||
|
||||
_log = Logger()
|
||||
|
||||
def connectionMade(self):
|
||||
# A list containing the characters making up the current line
|
||||
self.lineBuffer = []
|
||||
|
||||
# A zero-based (wtf else?) index into self.lineBuffer.
|
||||
# Indicates the current cursor position.
|
||||
self.lineBufferIndex = 0
|
||||
|
||||
t = self.terminal
|
||||
# A map of keyIDs to bound instance methods.
|
||||
self.keyHandlers = {
|
||||
t.LEFT_ARROW: self.handle_LEFT,
|
||||
t.RIGHT_ARROW: self.handle_RIGHT,
|
||||
t.TAB: self.handle_TAB,
|
||||
# Both of these should not be necessary, but figuring out
|
||||
# which is necessary is a huge hassle.
|
||||
b"\r": self.handle_RETURN,
|
||||
b"\n": self.handle_RETURN,
|
||||
t.BACKSPACE: self.handle_BACKSPACE,
|
||||
t.DELETE: self.handle_DELETE,
|
||||
t.INSERT: self.handle_INSERT,
|
||||
t.HOME: self.handle_HOME,
|
||||
t.END: self.handle_END,
|
||||
}
|
||||
|
||||
self.initializeScreen()
|
||||
|
||||
def initializeScreen(self):
|
||||
# Hmm, state sucks. Oh well.
|
||||
# For now we will just take over the whole terminal.
|
||||
self.terminal.reset()
|
||||
self.terminal.write(self.ps[self.pn])
|
||||
# XXX Note: I would prefer to default to starting in insert
|
||||
# mode, however this does not seem to actually work! I do not
|
||||
# know why. This is probably of interest to implementors
|
||||
# subclassing RecvLine.
|
||||
|
||||
# XXX XXX Note: But the unit tests all expect the initial mode
|
||||
# to be insert right now. Fuck, there needs to be a way to
|
||||
# query the current mode or something.
|
||||
# self.setTypeoverMode()
|
||||
self.setInsertMode()
|
||||
|
||||
def currentLineBuffer(self):
|
||||
s = b"".join(self.lineBuffer)
|
||||
return s[: self.lineBufferIndex], s[self.lineBufferIndex :]
|
||||
|
||||
def setInsertMode(self):
|
||||
self.mode = "insert"
|
||||
self.terminal.setModes([insults.modes.IRM])
|
||||
|
||||
def setTypeoverMode(self):
|
||||
self.mode = "typeover"
|
||||
self.terminal.resetModes([insults.modes.IRM])
|
||||
|
||||
def drawInputLine(self):
|
||||
"""
|
||||
Write a line containing the current input prompt and the current line
|
||||
buffer at the current cursor position.
|
||||
"""
|
||||
self.terminal.write(self.ps[self.pn] + b"".join(self.lineBuffer))
|
||||
|
||||
def terminalSize(self, width, height):
|
||||
# XXX - Clear the previous input line, redraw it at the new
|
||||
# cursor position
|
||||
self.terminal.eraseDisplay()
|
||||
self.terminal.cursorHome()
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.drawInputLine()
|
||||
|
||||
def unhandledControlSequence(self, seq):
|
||||
pass
|
||||
|
||||
def keystrokeReceived(self, keyID, modifier):
|
||||
m = self.keyHandlers.get(keyID)
|
||||
if m is not None:
|
||||
m()
|
||||
elif keyID in self._printableChars:
|
||||
self.characterReceived(keyID, False)
|
||||
else:
|
||||
self._log.warn("Received unhandled keyID: {keyID!r}", keyID=keyID)
|
||||
|
||||
def characterReceived(self, ch, moreCharactersComing):
|
||||
if self.mode == "insert":
|
||||
self.lineBuffer.insert(self.lineBufferIndex, ch)
|
||||
else:
|
||||
self.lineBuffer[self.lineBufferIndex : self.lineBufferIndex + 1] = [ch]
|
||||
self.lineBufferIndex += 1
|
||||
self.terminal.write(ch)
|
||||
|
||||
def handle_TAB(self):
|
||||
n = self.TABSTOP - (len(self.lineBuffer) % self.TABSTOP)
|
||||
self.terminal.cursorForward(n)
|
||||
self.lineBufferIndex += n
|
||||
self.lineBuffer.extend(iterbytes(b" " * n))
|
||||
|
||||
def handle_LEFT(self):
|
||||
if self.lineBufferIndex > 0:
|
||||
self.lineBufferIndex -= 1
|
||||
self.terminal.cursorBackward()
|
||||
|
||||
def handle_RIGHT(self):
|
||||
if self.lineBufferIndex < len(self.lineBuffer):
|
||||
self.lineBufferIndex += 1
|
||||
self.terminal.cursorForward()
|
||||
|
||||
def handle_HOME(self):
|
||||
if self.lineBufferIndex:
|
||||
self.terminal.cursorBackward(self.lineBufferIndex)
|
||||
self.lineBufferIndex = 0
|
||||
|
||||
def handle_END(self):
|
||||
offset = len(self.lineBuffer) - self.lineBufferIndex
|
||||
if offset:
|
||||
self.terminal.cursorForward(offset)
|
||||
self.lineBufferIndex = len(self.lineBuffer)
|
||||
|
||||
def handle_BACKSPACE(self):
|
||||
if self.lineBufferIndex > 0:
|
||||
self.lineBufferIndex -= 1
|
||||
del self.lineBuffer[self.lineBufferIndex]
|
||||
self.terminal.cursorBackward()
|
||||
self.terminal.deleteCharacter()
|
||||
|
||||
def handle_DELETE(self):
|
||||
if self.lineBufferIndex < len(self.lineBuffer):
|
||||
del self.lineBuffer[self.lineBufferIndex]
|
||||
self.terminal.deleteCharacter()
|
||||
|
||||
def handle_RETURN(self):
|
||||
line = b"".join(self.lineBuffer)
|
||||
self.lineBuffer = []
|
||||
self.lineBufferIndex = 0
|
||||
self.terminal.nextLine()
|
||||
self.lineReceived(line)
|
||||
|
||||
def handle_INSERT(self):
|
||||
assert self.mode in ("typeover", "insert")
|
||||
if self.mode == "typeover":
|
||||
self.setInsertMode()
|
||||
else:
|
||||
self.setTypeoverMode()
|
||||
|
||||
def lineReceived(self, line):
|
||||
pass
|
||||
|
||||
|
||||
class HistoricRecvLine(RecvLine):
|
||||
"""
|
||||
L{TerminalProtocol} which adds both basic line-editing features and input history.
|
||||
|
||||
Everything supported by L{RecvLine} is also supported by this class. In addition, the
|
||||
up and down arrows traverse the input history. Each received line is automatically
|
||||
added to the end of the input history.
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
RecvLine.connectionMade(self)
|
||||
|
||||
self.historyLines = []
|
||||
self.historyPosition = 0
|
||||
|
||||
t = self.terminal
|
||||
self.keyHandlers.update(
|
||||
{t.UP_ARROW: self.handle_UP, t.DOWN_ARROW: self.handle_DOWN}
|
||||
)
|
||||
|
||||
def currentHistoryBuffer(self):
|
||||
b = tuple(self.historyLines)
|
||||
return b[: self.historyPosition], b[self.historyPosition :]
|
||||
|
||||
def _deliverBuffer(self, buf):
|
||||
if buf:
|
||||
for ch in iterbytes(buf[:-1]):
|
||||
self.characterReceived(ch, True)
|
||||
self.characterReceived(buf[-1:], False)
|
||||
|
||||
def handle_UP(self):
|
||||
if self.lineBuffer and self.historyPosition == len(self.historyLines):
|
||||
self.historyLines.append(b"".join(self.lineBuffer))
|
||||
if self.historyPosition > 0:
|
||||
self.handle_HOME()
|
||||
self.terminal.eraseToLineEnd()
|
||||
|
||||
self.historyPosition -= 1
|
||||
self.lineBuffer = []
|
||||
|
||||
self._deliverBuffer(self.historyLines[self.historyPosition])
|
||||
|
||||
def handle_DOWN(self):
|
||||
if self.historyPosition < len(self.historyLines) - 1:
|
||||
self.handle_HOME()
|
||||
self.terminal.eraseToLineEnd()
|
||||
|
||||
self.historyPosition += 1
|
||||
self.lineBuffer = []
|
||||
|
||||
self._deliverBuffer(self.historyLines[self.historyPosition])
|
||||
else:
|
||||
self.handle_HOME()
|
||||
self.terminal.eraseToLineEnd()
|
||||
|
||||
self.historyPosition = len(self.historyLines)
|
||||
self.lineBuffer = []
|
||||
self.lineBufferIndex = 0
|
||||
|
||||
def handle_RETURN(self):
|
||||
if self.lineBuffer:
|
||||
self.historyLines.append(b"".join(self.lineBuffer))
|
||||
self.historyPosition = len(self.historyLines)
|
||||
return RecvLine.handle_RETURN(self)
|
||||
@@ -0,0 +1 @@
|
||||
"conch scripts"
|
||||
1002
.venv/lib/python3.12/site-packages/twisted/conch/scripts/cftp.py
Normal file
1002
.venv/lib/python3.12/site-packages/twisted/conch/scripts/cftp.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,400 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_ckeygen -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation module for the `ckeygen` command.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import getpass
|
||||
import os
|
||||
import platform
|
||||
import socket
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from importlib import reload
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from twisted.conch.ssh import keys
|
||||
from twisted.python import failure, filepath, log, usage
|
||||
|
||||
if getpass.getpass == getpass.unix_getpass: # type: ignore[attr-defined]
|
||||
try:
|
||||
import termios # hack around broken termios
|
||||
|
||||
termios.tcgetattr, termios.tcsetattr
|
||||
except (ImportError, AttributeError):
|
||||
sys.modules["termios"] = None # type: ignore[assignment]
|
||||
reload(getpass)
|
||||
|
||||
supportedKeyTypes = dict()
|
||||
|
||||
|
||||
def _keyGenerator(keyType):
|
||||
def assignkeygenerator(keygenerator):
|
||||
@wraps(keygenerator)
|
||||
def wrapper(*args, **kwargs):
|
||||
return keygenerator(*args, **kwargs)
|
||||
|
||||
supportedKeyTypes[keyType] = wrapper
|
||||
return wrapper
|
||||
|
||||
return assignkeygenerator
|
||||
|
||||
|
||||
class GeneralOptions(usage.Options):
|
||||
synopsis = """Usage: ckeygen [options]
|
||||
"""
|
||||
|
||||
longdesc = "ckeygen manipulates public/private keys in various ways."
|
||||
|
||||
optParameters = [
|
||||
["bits", "b", None, "Number of bits in the key to create."],
|
||||
["filename", "f", None, "Filename of the key file."],
|
||||
["type", "t", None, "Specify type of key to create."],
|
||||
["comment", "C", None, "Provide new comment."],
|
||||
["newpass", "N", None, "Provide new passphrase."],
|
||||
["pass", "P", None, "Provide old passphrase."],
|
||||
["format", "o", "sha256-base64", "Fingerprint format of key file."],
|
||||
[
|
||||
"private-key-subtype",
|
||||
None,
|
||||
None,
|
||||
'OpenSSH private key subtype to write ("PEM" or "v1").',
|
||||
],
|
||||
]
|
||||
|
||||
optFlags = [
|
||||
["fingerprint", "l", "Show fingerprint of key file."],
|
||||
["changepass", "p", "Change passphrase of private key file."],
|
||||
["quiet", "q", "Quiet."],
|
||||
["no-passphrase", None, "Create the key with no passphrase."],
|
||||
["showpub", "y", "Read private key file and print public key."],
|
||||
]
|
||||
|
||||
compData = usage.Completions(
|
||||
optActions={
|
||||
"type": usage.CompleteList(list(supportedKeyTypes.keys())),
|
||||
"private-key-subtype": usage.CompleteList(["PEM", "v1"]),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def run():
|
||||
options = GeneralOptions()
|
||||
try:
|
||||
options.parseOptions(sys.argv[1:])
|
||||
except usage.UsageError as u:
|
||||
print("ERROR: %s" % u)
|
||||
options.opt_help()
|
||||
sys.exit(1)
|
||||
log.discardLogs()
|
||||
log.deferr = handleError # HACK
|
||||
if options["type"]:
|
||||
if options["type"].lower() in supportedKeyTypes:
|
||||
print("Generating public/private %s key pair." % (options["type"]))
|
||||
supportedKeyTypes[options["type"].lower()](options)
|
||||
else:
|
||||
sys.exit(
|
||||
"Key type was %s, must be one of %s"
|
||||
% (options["type"], ", ".join(supportedKeyTypes.keys()))
|
||||
)
|
||||
elif options["fingerprint"]:
|
||||
printFingerprint(options)
|
||||
elif options["changepass"]:
|
||||
changePassPhrase(options)
|
||||
elif options["showpub"]:
|
||||
displayPublicKey(options)
|
||||
else:
|
||||
options.opt_help()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def enumrepresentation(options):
|
||||
if options["format"] == "md5-hex":
|
||||
options["format"] = keys.FingerprintFormats.MD5_HEX
|
||||
return options
|
||||
elif options["format"] == "sha256-base64":
|
||||
options["format"] = keys.FingerprintFormats.SHA256_BASE64
|
||||
return options
|
||||
else:
|
||||
raise keys.BadFingerPrintFormat(
|
||||
f"Unsupported fingerprint format: {options['format']}"
|
||||
)
|
||||
|
||||
|
||||
def handleError():
|
||||
global exitStatus
|
||||
exitStatus = 2
|
||||
log.err(failure.Failure())
|
||||
raise
|
||||
|
||||
|
||||
@_keyGenerator("rsa")
|
||||
def generateRSAkey(options):
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
if not options["bits"]:
|
||||
options["bits"] = 2048
|
||||
keyPrimitive = rsa.generate_private_key(
|
||||
key_size=int(options["bits"]),
|
||||
public_exponent=65537,
|
||||
backend=default_backend(),
|
||||
)
|
||||
key = keys.Key(keyPrimitive)
|
||||
_saveKey(key, options)
|
||||
|
||||
|
||||
@_keyGenerator("dsa")
|
||||
def generateDSAkey(options):
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import dsa
|
||||
|
||||
if not options["bits"]:
|
||||
options["bits"] = 1024
|
||||
keyPrimitive = dsa.generate_private_key(
|
||||
key_size=int(options["bits"]),
|
||||
backend=default_backend(),
|
||||
)
|
||||
key = keys.Key(keyPrimitive)
|
||||
_saveKey(key, options)
|
||||
|
||||
|
||||
@_keyGenerator("ecdsa")
|
||||
def generateECDSAkey(options):
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
|
||||
if not options["bits"]:
|
||||
options["bits"] = 256
|
||||
# OpenSSH supports only mandatory sections of RFC5656.
|
||||
# See https://www.openssh.com/txt/release-5.7
|
||||
curve = b"ecdsa-sha2-nistp" + str(options["bits"]).encode("ascii")
|
||||
keyPrimitive = ec.generate_private_key(
|
||||
curve=keys._curveTable[curve], backend=default_backend()
|
||||
)
|
||||
key = keys.Key(keyPrimitive)
|
||||
_saveKey(key, options)
|
||||
|
||||
|
||||
@_keyGenerator("ed25519")
|
||||
def generateEd25519key(options):
|
||||
keyPrimitive = keys.Ed25519PrivateKey.generate()
|
||||
key = keys.Key(keyPrimitive)
|
||||
_saveKey(key, options)
|
||||
|
||||
|
||||
def _defaultPrivateKeySubtype(keyType):
|
||||
"""
|
||||
Return a reasonable default private key subtype for a given key type.
|
||||
|
||||
@type keyType: L{str}
|
||||
@param keyType: A key type, as returned by
|
||||
L{twisted.conch.ssh.keys.Key.type}.
|
||||
|
||||
@rtype: L{str}
|
||||
@return: A private OpenSSH key subtype (C{'PEM'} or C{'v1'}).
|
||||
"""
|
||||
if keyType == "Ed25519":
|
||||
# No PEM format is defined for Ed25519 keys.
|
||||
return "v1"
|
||||
else:
|
||||
return "PEM"
|
||||
|
||||
|
||||
def _getKeyOrDefault(
|
||||
options: Dict[Any, Any],
|
||||
inputCollector: Optional[Callable[[str], str]] = None,
|
||||
keyTypeName: str = "rsa",
|
||||
) -> str:
|
||||
"""
|
||||
If C{options["filename"]} is None, prompt the user to enter a path
|
||||
or attempt to set it to .ssh/id_rsa
|
||||
@param options: command line options
|
||||
@param inputCollector: dependency injection for testing
|
||||
@param keyTypeName: key type or "rsa"
|
||||
"""
|
||||
if inputCollector is None:
|
||||
inputCollector = input
|
||||
filename = options["filename"]
|
||||
if not filename:
|
||||
filename = os.path.expanduser(f"~/.ssh/id_{keyTypeName}")
|
||||
if platform.system() == "Windows":
|
||||
filename = os.path.expanduser(Rf"%HOMEPATH %\.ssh\id_{keyTypeName}")
|
||||
filename = (
|
||||
inputCollector("Enter file in which the key is (%s): " % filename)
|
||||
or filename
|
||||
)
|
||||
return str(filename)
|
||||
|
||||
|
||||
def printFingerprint(options: Dict[Any, Any]) -> None:
|
||||
filename = _getKeyOrDefault(options)
|
||||
if os.path.exists(filename + ".pub"):
|
||||
filename += ".pub"
|
||||
options = enumrepresentation(options)
|
||||
try:
|
||||
key = keys.Key.fromFile(filename)
|
||||
print(
|
||||
"%s %s %s"
|
||||
% (
|
||||
key.size(),
|
||||
key.fingerprint(options["format"]),
|
||||
os.path.basename(filename),
|
||||
)
|
||||
)
|
||||
except keys.BadKeyError:
|
||||
sys.exit("bad key")
|
||||
except FileNotFoundError:
|
||||
sys.exit(f"{filename} could not be opened, please specify a file.")
|
||||
|
||||
|
||||
def changePassPhrase(options):
|
||||
filename = _getKeyOrDefault(options)
|
||||
try:
|
||||
key = keys.Key.fromFile(filename)
|
||||
except keys.EncryptedKeyError:
|
||||
# Raised if password not supplied for an encrypted key
|
||||
if not options.get("pass"):
|
||||
options["pass"] = getpass.getpass("Enter old passphrase: ")
|
||||
try:
|
||||
key = keys.Key.fromFile(filename, passphrase=options["pass"])
|
||||
except keys.BadKeyError:
|
||||
sys.exit("Could not change passphrase: old passphrase error")
|
||||
except keys.EncryptedKeyError as e:
|
||||
sys.exit(f"Could not change passphrase: {e}")
|
||||
except keys.BadKeyError as e:
|
||||
sys.exit(f"Could not change passphrase: {e}")
|
||||
except FileNotFoundError:
|
||||
sys.exit(f"{filename} could not be opened, please specify a file.")
|
||||
|
||||
if not options.get("newpass"):
|
||||
while 1:
|
||||
p1 = getpass.getpass("Enter new passphrase (empty for no passphrase): ")
|
||||
p2 = getpass.getpass("Enter same passphrase again: ")
|
||||
if p1 == p2:
|
||||
break
|
||||
print("Passphrases do not match. Try again.")
|
||||
options["newpass"] = p1
|
||||
|
||||
if options.get("private-key-subtype") is None:
|
||||
options["private-key-subtype"] = _defaultPrivateKeySubtype(key.type())
|
||||
|
||||
try:
|
||||
newkeydata = key.toString(
|
||||
"openssh",
|
||||
subtype=options["private-key-subtype"],
|
||||
passphrase=options["newpass"],
|
||||
)
|
||||
except Exception as e:
|
||||
sys.exit(f"Could not change passphrase: {e}")
|
||||
|
||||
try:
|
||||
keys.Key.fromString(newkeydata, passphrase=options["newpass"])
|
||||
except (keys.EncryptedKeyError, keys.BadKeyError) as e:
|
||||
sys.exit(f"Could not change passphrase: {e}")
|
||||
|
||||
with open(filename, "wb") as fd:
|
||||
fd.write(newkeydata)
|
||||
|
||||
print("Your identification has been saved with the new passphrase.")
|
||||
|
||||
|
||||
def displayPublicKey(options):
|
||||
filename = _getKeyOrDefault(options)
|
||||
try:
|
||||
key = keys.Key.fromFile(filename)
|
||||
except FileNotFoundError:
|
||||
sys.exit(f"{filename} could not be opened, please specify a file.")
|
||||
except keys.EncryptedKeyError:
|
||||
if not options.get("pass"):
|
||||
options["pass"] = getpass.getpass("Enter passphrase: ")
|
||||
key = keys.Key.fromFile(filename, passphrase=options["pass"])
|
||||
displayKey = key.public().toString("openssh").decode("ascii")
|
||||
print(displayKey)
|
||||
|
||||
|
||||
def _inputSaveFile(prompt: str) -> str:
|
||||
"""
|
||||
Ask the user where to save the key.
|
||||
|
||||
This needs to be a separate function so the unit test can patch it.
|
||||
"""
|
||||
return input(prompt)
|
||||
|
||||
|
||||
def _saveKey(
|
||||
key: keys.Key,
|
||||
options: Dict[Any, Any],
|
||||
inputCollector: Optional[Callable[[str], str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Persist a SSH key on local filesystem.
|
||||
|
||||
@param key: Key which is persisted on local filesystem.
|
||||
|
||||
@param options:
|
||||
|
||||
@param inputCollector: Dependency injection for testing.
|
||||
"""
|
||||
if inputCollector is None:
|
||||
inputCollector = input
|
||||
KeyTypeMapping = {"EC": "ecdsa", "Ed25519": "ed25519", "RSA": "rsa", "DSA": "dsa"}
|
||||
keyTypeName = KeyTypeMapping[key.type()]
|
||||
filename = options["filename"]
|
||||
if not filename:
|
||||
defaultPath = _getKeyOrDefault(options, inputCollector, keyTypeName)
|
||||
newPath = _inputSaveFile(
|
||||
f"Enter file in which to save the key ({defaultPath}): "
|
||||
)
|
||||
|
||||
filename = newPath.strip() or defaultPath
|
||||
|
||||
if os.path.exists(filename):
|
||||
print(f"{filename} already exists.")
|
||||
yn = inputCollector("Overwrite (y/n)? ")
|
||||
if yn[0].lower() != "y":
|
||||
sys.exit()
|
||||
|
||||
if options.get("no-passphrase"):
|
||||
options["pass"] = b""
|
||||
elif not options["pass"]:
|
||||
while 1:
|
||||
p1 = getpass.getpass("Enter passphrase (empty for no passphrase): ")
|
||||
p2 = getpass.getpass("Enter same passphrase again: ")
|
||||
if p1 == p2:
|
||||
break
|
||||
print("Passphrases do not match. Try again.")
|
||||
options["pass"] = p1
|
||||
|
||||
if options.get("private-key-subtype") is None:
|
||||
options["private-key-subtype"] = _defaultPrivateKeySubtype(key.type())
|
||||
|
||||
comment = f"{getpass.getuser()}@{socket.gethostname()}"
|
||||
|
||||
fp = filepath.FilePath(filename)
|
||||
fp.setContent(
|
||||
key.toString(
|
||||
"openssh",
|
||||
subtype=options["private-key-subtype"],
|
||||
passphrase=options["pass"],
|
||||
)
|
||||
)
|
||||
fp.chmod(0o100600)
|
||||
|
||||
filepath.FilePath(filename + ".pub").setContent(
|
||||
key.public().toString("openssh", comment=comment)
|
||||
)
|
||||
options = enumrepresentation(options)
|
||||
|
||||
print(f"Your identification has been saved in {filename}")
|
||||
print(f"Your public key has been saved in {filename}.pub")
|
||||
print(f"The key fingerprint in {options['format']} is:")
|
||||
print(key.fingerprint(options["format"]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
@@ -0,0 +1,578 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_conch -*-
|
||||
#
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
# $Id: conch.py,v 1.65 2004/03/11 00:29:14 z3p Exp $
|
||||
|
||||
# Implementation module for the `conch` command.
|
||||
#
|
||||
|
||||
import fcntl
|
||||
import getpass
|
||||
import os
|
||||
import signal
|
||||
import struct
|
||||
import sys
|
||||
import tty
|
||||
from typing import List, Tuple
|
||||
|
||||
from twisted.conch.client import connect, default
|
||||
from twisted.conch.client.options import ConchOptions
|
||||
from twisted.conch.error import ConchError
|
||||
from twisted.conch.ssh import channel, common, connection, forwarding, session
|
||||
from twisted.internet import reactor, stdio, task
|
||||
from twisted.python import log, usage
|
||||
from twisted.python.compat import ioType, networkString
|
||||
|
||||
|
||||
class ClientOptions(ConchOptions):
|
||||
synopsis = """Usage: conch [options] host [command]
|
||||
"""
|
||||
longdesc = (
|
||||
"conch is a SSHv2 client that allows logging into a remote "
|
||||
"machine and executing commands."
|
||||
)
|
||||
|
||||
optParameters = [
|
||||
["escape", "e", "~"],
|
||||
[
|
||||
"localforward",
|
||||
"L",
|
||||
None,
|
||||
"listen-port:host:port Forward local port to remote address",
|
||||
],
|
||||
[
|
||||
"remoteforward",
|
||||
"R",
|
||||
None,
|
||||
"listen-port:host:port Forward remote port to local address",
|
||||
],
|
||||
]
|
||||
|
||||
optFlags = [
|
||||
["null", "n", "Redirect input from /dev/null."],
|
||||
["fork", "f", "Fork to background after authentication."],
|
||||
["tty", "t", "Tty; allocate a tty even if command is given."],
|
||||
["notty", "T", "Do not allocate a tty."],
|
||||
["noshell", "N", "Do not execute a shell or command."],
|
||||
["subsystem", "s", "Invoke command (mandatory) as SSH2 subsystem."],
|
||||
]
|
||||
|
||||
compData = usage.Completions(
|
||||
mutuallyExclusive=[("tty", "notty")],
|
||||
optActions={
|
||||
"localforward": usage.Completer(descr="listen-port:host:port"),
|
||||
"remoteforward": usage.Completer(descr="listen-port:host:port"),
|
||||
},
|
||||
extraActions=[
|
||||
usage.CompleteUserAtHost(),
|
||||
usage.Completer(descr="command"),
|
||||
usage.Completer(descr="argument", repeat=True),
|
||||
],
|
||||
)
|
||||
|
||||
localForwards: List[Tuple[int, Tuple[int, int]]] = []
|
||||
remoteForwards: List[Tuple[int, Tuple[int, int]]] = []
|
||||
|
||||
def opt_escape(self, esc):
|
||||
"""
|
||||
Set escape character; ``none'' = disable
|
||||
"""
|
||||
if esc == "none":
|
||||
self["escape"] = None
|
||||
elif esc[0] == "^" and len(esc) == 2:
|
||||
self["escape"] = chr(ord(esc[1]) - 64)
|
||||
elif len(esc) == 1:
|
||||
self["escape"] = esc
|
||||
else:
|
||||
sys.exit(f"Bad escape character '{esc}'.")
|
||||
|
||||
def opt_localforward(self, f):
|
||||
"""
|
||||
Forward local port to remote address (lport:host:port)
|
||||
"""
|
||||
localPort, remoteHost, remotePort = f.split(":") # Doesn't do v6 yet
|
||||
localPort = int(localPort)
|
||||
remotePort = int(remotePort)
|
||||
self.localForwards.append((localPort, (remoteHost, remotePort)))
|
||||
|
||||
def opt_remoteforward(self, f):
|
||||
"""
|
||||
Forward remote port to local address (rport:host:port)
|
||||
"""
|
||||
remotePort, connHost, connPort = f.split(":") # Doesn't do v6 yet
|
||||
remotePort = int(remotePort)
|
||||
connPort = int(connPort)
|
||||
self.remoteForwards.append((remotePort, (connHost, connPort)))
|
||||
|
||||
def parseArgs(self, host, *command):
|
||||
self["host"] = host
|
||||
self["command"] = " ".join(command)
|
||||
|
||||
|
||||
# Rest of code in "run"
|
||||
options = None
|
||||
conn = None
|
||||
exitStatus = 0
|
||||
old = None
|
||||
_inRawMode = 0
|
||||
_savedRawMode = None
|
||||
|
||||
|
||||
def run():
|
||||
global options, old
|
||||
args = sys.argv[1:]
|
||||
if "-l" in args: # CVS is an idiot
|
||||
i = args.index("-l")
|
||||
args = args[i : i + 2] + args
|
||||
del args[i + 2 : i + 4]
|
||||
for arg in args[:]:
|
||||
try:
|
||||
i = args.index(arg)
|
||||
if arg[:2] == "-o" and args[i + 1][0] != "-":
|
||||
args[i : i + 2] = [] # Suck on it scp
|
||||
except ValueError:
|
||||
pass
|
||||
options = ClientOptions()
|
||||
try:
|
||||
options.parseOptions(args)
|
||||
except usage.UsageError as u:
|
||||
print(f"ERROR: {u}")
|
||||
options.opt_help()
|
||||
sys.exit(1)
|
||||
if options["log"]:
|
||||
if options["logfile"]:
|
||||
if options["logfile"] == "-":
|
||||
f = sys.stdout
|
||||
else:
|
||||
f = open(options["logfile"], "a+")
|
||||
else:
|
||||
f = sys.stderr
|
||||
realout = sys.stdout
|
||||
log.startLogging(f)
|
||||
sys.stdout = realout
|
||||
else:
|
||||
log.discardLogs()
|
||||
doConnect()
|
||||
fd = sys.stdin.fileno()
|
||||
try:
|
||||
old = tty.tcgetattr(fd)
|
||||
except BaseException:
|
||||
old = None
|
||||
try:
|
||||
oldUSR1 = signal.signal(
|
||||
signal.SIGUSR1, lambda *a: reactor.callLater(0, reConnect)
|
||||
)
|
||||
except BaseException:
|
||||
oldUSR1 = None
|
||||
try:
|
||||
reactor.run()
|
||||
finally:
|
||||
if old:
|
||||
tty.tcsetattr(fd, tty.TCSANOW, old)
|
||||
if oldUSR1:
|
||||
signal.signal(signal.SIGUSR1, oldUSR1)
|
||||
if (options["command"] and options["tty"]) or not options["notty"]:
|
||||
signal.signal(signal.SIGWINCH, signal.SIG_DFL)
|
||||
if sys.stdout.isatty() and not options["command"]:
|
||||
print("Connection to {} closed.".format(options["host"]))
|
||||
sys.exit(exitStatus)
|
||||
|
||||
|
||||
def handleError():
|
||||
from twisted.python import failure
|
||||
|
||||
global exitStatus
|
||||
exitStatus = 2
|
||||
reactor.callLater(0.01, _stopReactor)
|
||||
log.err(failure.Failure())
|
||||
raise
|
||||
|
||||
|
||||
def _stopReactor():
|
||||
try:
|
||||
reactor.stop()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
|
||||
def doConnect():
|
||||
if "@" in options["host"]:
|
||||
options["user"], options["host"] = options["host"].split("@", 1)
|
||||
if not options.identitys:
|
||||
options.identitys = ["~/.ssh/id_rsa", "~/.ssh/id_dsa"]
|
||||
host = options["host"]
|
||||
if not options["user"]:
|
||||
options["user"] = getpass.getuser()
|
||||
if not options["port"]:
|
||||
options["port"] = 22
|
||||
else:
|
||||
options["port"] = int(options["port"])
|
||||
host = options["host"]
|
||||
port = options["port"]
|
||||
vhk = default.verifyHostKey
|
||||
if not options["host-key-algorithms"]:
|
||||
options["host-key-algorithms"] = default.getHostKeyAlgorithms(host, options)
|
||||
uao = default.SSHUserAuthClient(options["user"], options, SSHConnection())
|
||||
connect.connect(host, port, options, vhk, uao).addErrback(_ebExit)
|
||||
|
||||
|
||||
def _ebExit(f):
|
||||
global exitStatus
|
||||
exitStatus = f"conch: exiting with error {f}"
|
||||
reactor.callLater(0.1, _stopReactor)
|
||||
|
||||
|
||||
def onConnect():
|
||||
# if keyAgent and options['agent']:
|
||||
# cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal, conn)
|
||||
# cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
|
||||
if hasattr(conn.transport, "sendIgnore"):
|
||||
_KeepAlive(conn)
|
||||
if options.localForwards:
|
||||
for localPort, hostport in options.localForwards:
|
||||
s = reactor.listenTCP(
|
||||
localPort,
|
||||
forwarding.SSHListenForwardingFactory(
|
||||
conn, hostport, SSHListenClientForwardingChannel
|
||||
),
|
||||
)
|
||||
conn.localForwards.append(s)
|
||||
if options.remoteForwards:
|
||||
for remotePort, hostport in options.remoteForwards:
|
||||
log.msg(f"asking for remote forwarding for {remotePort}:{hostport}")
|
||||
conn.requestRemoteForwarding(remotePort, hostport)
|
||||
reactor.addSystemEventTrigger("before", "shutdown", beforeShutdown)
|
||||
if not options["noshell"] or options["agent"]:
|
||||
conn.openChannel(SSHSession())
|
||||
if options["fork"]:
|
||||
if os.fork():
|
||||
os._exit(0)
|
||||
os.setsid()
|
||||
for i in range(3):
|
||||
try:
|
||||
os.close(i)
|
||||
except OSError as e:
|
||||
import errno
|
||||
|
||||
if e.errno != errno.EBADF:
|
||||
raise
|
||||
|
||||
|
||||
def reConnect():
|
||||
beforeShutdown()
|
||||
conn.transport.transport.loseConnection()
|
||||
|
||||
|
||||
def beforeShutdown():
|
||||
remoteForwards = options.remoteForwards
|
||||
for remotePort, hostport in remoteForwards:
|
||||
log.msg(f"cancelling {remotePort}:{hostport}")
|
||||
conn.cancelRemoteForwarding(remotePort)
|
||||
|
||||
|
||||
def stopConnection():
|
||||
if not options["reconnect"]:
|
||||
reactor.callLater(0.1, _stopReactor)
|
||||
|
||||
|
||||
class _KeepAlive:
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
self.globalTimeout = None
|
||||
self.lc = task.LoopingCall(self.sendGlobal)
|
||||
self.lc.start(300)
|
||||
|
||||
def sendGlobal(self):
|
||||
d = self.conn.sendGlobalRequest(
|
||||
b"conch-keep-alive@twistedmatrix.com", b"", wantReply=1
|
||||
)
|
||||
d.addBoth(self._cbGlobal)
|
||||
self.globalTimeout = reactor.callLater(30, self._ebGlobal)
|
||||
|
||||
def _cbGlobal(self, res):
|
||||
if self.globalTimeout:
|
||||
self.globalTimeout.cancel()
|
||||
self.globalTimeout = None
|
||||
|
||||
def _ebGlobal(self):
|
||||
if self.globalTimeout:
|
||||
self.globalTimeout = None
|
||||
self.conn.transport.loseConnection()
|
||||
|
||||
|
||||
class SSHConnection(connection.SSHConnection):
|
||||
def serviceStarted(self):
|
||||
global conn
|
||||
conn = self
|
||||
self.localForwards = []
|
||||
self.remoteForwards = {}
|
||||
onConnect()
|
||||
|
||||
def serviceStopped(self):
|
||||
lf = self.localForwards
|
||||
self.localForwards = []
|
||||
for s in lf:
|
||||
s.loseConnection()
|
||||
stopConnection()
|
||||
|
||||
def requestRemoteForwarding(self, remotePort, hostport):
|
||||
data = forwarding.packGlobal_tcpip_forward(("0.0.0.0", remotePort))
|
||||
d = self.sendGlobalRequest(b"tcpip-forward", data, wantReply=1)
|
||||
log.msg(f"requesting remote forwarding {remotePort}:{hostport}")
|
||||
d.addCallback(self._cbRemoteForwarding, remotePort, hostport)
|
||||
d.addErrback(self._ebRemoteForwarding, remotePort, hostport)
|
||||
|
||||
def _cbRemoteForwarding(self, result, remotePort, hostport):
|
||||
log.msg(f"accepted remote forwarding {remotePort}:{hostport}")
|
||||
self.remoteForwards[remotePort] = hostport
|
||||
log.msg(repr(self.remoteForwards))
|
||||
|
||||
def _ebRemoteForwarding(self, f, remotePort, hostport):
|
||||
log.msg(f"remote forwarding {remotePort}:{hostport} failed")
|
||||
log.msg(f)
|
||||
|
||||
def cancelRemoteForwarding(self, remotePort):
|
||||
data = forwarding.packGlobal_tcpip_forward(("0.0.0.0", remotePort))
|
||||
self.sendGlobalRequest(b"cancel-tcpip-forward", data)
|
||||
log.msg(f"cancelling remote forwarding {remotePort}")
|
||||
try:
|
||||
del self.remoteForwards[remotePort]
|
||||
except Exception:
|
||||
pass
|
||||
log.msg(repr(self.remoteForwards))
|
||||
|
||||
def channel_forwarded_tcpip(self, windowSize, maxPacket, data):
|
||||
log.msg(f"FTCP {data!r}")
|
||||
remoteHP, origHP = forwarding.unpackOpen_forwarded_tcpip(data)
|
||||
log.msg(self.remoteForwards)
|
||||
log.msg(remoteHP)
|
||||
if remoteHP[1] in self.remoteForwards:
|
||||
connectHP = self.remoteForwards[remoteHP[1]]
|
||||
log.msg(f"connect forwarding {connectHP}")
|
||||
return SSHConnectForwardingChannel(
|
||||
connectHP, remoteWindow=windowSize, remoteMaxPacket=maxPacket, conn=self
|
||||
)
|
||||
else:
|
||||
raise ConchError(
|
||||
connection.OPEN_CONNECT_FAILED, "don't know about that port"
|
||||
)
|
||||
|
||||
def channelClosed(self, channel):
|
||||
log.msg(f"connection closing {channel}")
|
||||
log.msg(self.channels)
|
||||
if len(self.channels) == 1: # Just us left
|
||||
log.msg("stopping connection")
|
||||
stopConnection()
|
||||
else:
|
||||
# Because of the unix thing
|
||||
self.__class__.__bases__[0].channelClosed(self, channel)
|
||||
|
||||
|
||||
class SSHSession(channel.SSHChannel):
|
||||
name = b"session"
|
||||
|
||||
def channelOpen(self, foo):
|
||||
log.msg(f"session {self.id} open")
|
||||
if options["agent"]:
|
||||
d = self.conn.sendRequest(
|
||||
self, b"auth-agent-req@openssh.com", b"", wantReply=1
|
||||
)
|
||||
d.addBoth(lambda x: log.msg(x))
|
||||
if options["noshell"]:
|
||||
return
|
||||
if (options["command"] and options["tty"]) or not options["notty"]:
|
||||
_enterRawMode()
|
||||
c = session.SSHSessionClient()
|
||||
if options["escape"] and not options["notty"]:
|
||||
self.escapeMode = 1
|
||||
c.dataReceived = self.handleInput
|
||||
else:
|
||||
c.dataReceived = self.write
|
||||
c.connectionLost = lambda x: self.sendEOF()
|
||||
self.stdio = stdio.StandardIO(c)
|
||||
fd = 0
|
||||
if options["subsystem"]:
|
||||
self.conn.sendRequest(self, b"subsystem", common.NS(options["command"]))
|
||||
elif options["command"]:
|
||||
if options["tty"]:
|
||||
term = os.environ["TERM"]
|
||||
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, "12345678")
|
||||
winSize = struct.unpack("4H", winsz)
|
||||
ptyReqData = session.packRequest_pty_req(term, winSize, "")
|
||||
self.conn.sendRequest(self, b"pty-req", ptyReqData)
|
||||
signal.signal(signal.SIGWINCH, self._windowResized)
|
||||
self.conn.sendRequest(self, b"exec", common.NS(options["command"]))
|
||||
else:
|
||||
if not options["notty"]:
|
||||
term = os.environ["TERM"]
|
||||
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, "12345678")
|
||||
winSize = struct.unpack("4H", winsz)
|
||||
ptyReqData = session.packRequest_pty_req(term, winSize, "")
|
||||
self.conn.sendRequest(self, b"pty-req", ptyReqData)
|
||||
signal.signal(signal.SIGWINCH, self._windowResized)
|
||||
self.conn.sendRequest(self, b"shell", b"")
|
||||
# if hasattr(conn.transport, 'transport'):
|
||||
# conn.transport.transport.setTcpNoDelay(1)
|
||||
|
||||
def handleInput(self, char):
|
||||
if char in (b"\n", b"\r"):
|
||||
self.escapeMode = 1
|
||||
self.write(char)
|
||||
elif self.escapeMode == 1 and char == options["escape"]:
|
||||
self.escapeMode = 2
|
||||
elif self.escapeMode == 2:
|
||||
self.escapeMode = 1 # So we can chain escapes together
|
||||
if char == b".": # Disconnect
|
||||
log.msg("disconnecting from escape")
|
||||
stopConnection()
|
||||
return
|
||||
elif char == b"\x1a": # ^Z, suspend
|
||||
|
||||
def _():
|
||||
_leaveRawMode()
|
||||
sys.stdout.flush()
|
||||
sys.stdin.flush()
|
||||
os.kill(os.getpid(), signal.SIGTSTP)
|
||||
_enterRawMode()
|
||||
|
||||
reactor.callLater(0, _)
|
||||
return
|
||||
elif char == b"R": # Rekey connection
|
||||
log.msg("rekeying connection")
|
||||
self.conn.transport.sendKexInit()
|
||||
return
|
||||
elif char == b"#": # Display connections
|
||||
self.stdio.write(b"\r\nThe following connections are open:\r\n")
|
||||
channels = self.conn.channels.keys()
|
||||
channels.sort()
|
||||
for channelId in channels:
|
||||
self.stdio.write(
|
||||
networkString(
|
||||
" #{} {}\r\n".format(
|
||||
channelId, self.conn.channels[channelId]
|
||||
)
|
||||
)
|
||||
)
|
||||
return
|
||||
self.write(b"~" + char)
|
||||
else:
|
||||
self.escapeMode = 0
|
||||
self.write(char)
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.stdio.write(data)
|
||||
|
||||
def extReceived(self, t, data):
|
||||
if t == connection.EXTENDED_DATA_STDERR:
|
||||
log.msg(f"got {len(data)} stderr data")
|
||||
if ioType(sys.stderr) == str:
|
||||
sys.stderr.buffer.write(data)
|
||||
else:
|
||||
sys.stderr.write(data)
|
||||
|
||||
def eofReceived(self):
|
||||
log.msg("got eof")
|
||||
self.stdio.loseWriteConnection()
|
||||
|
||||
def closeReceived(self):
|
||||
log.msg(f"remote side closed {self}")
|
||||
self.conn.sendClose(self)
|
||||
|
||||
def closed(self):
|
||||
global old
|
||||
log.msg(f"closed {self}")
|
||||
log.msg(repr(self.conn.channels))
|
||||
|
||||
def request_exit_status(self, data):
|
||||
global exitStatus
|
||||
exitStatus = int(struct.unpack(">L", data)[0])
|
||||
log.msg(f"exit status: {exitStatus}")
|
||||
|
||||
def sendEOF(self):
|
||||
self.conn.sendEOF(self)
|
||||
|
||||
def stopWriting(self):
|
||||
self.stdio.pauseProducing()
|
||||
|
||||
def startWriting(self):
|
||||
self.stdio.resumeProducing()
|
||||
|
||||
def _windowResized(self, *args):
|
||||
winsz = fcntl.ioctl(0, tty.TIOCGWINSZ, "12345678")
|
||||
winSize = struct.unpack("4H", winsz)
|
||||
newSize = winSize[1], winSize[0], winSize[2], winSize[3]
|
||||
self.conn.sendRequest(self, b"window-change", struct.pack("!4L", *newSize))
|
||||
|
||||
|
||||
class SSHListenClientForwardingChannel(forwarding.SSHListenClientForwardingChannel):
|
||||
pass
|
||||
|
||||
|
||||
class SSHConnectForwardingChannel(forwarding.SSHConnectForwardingChannel):
|
||||
pass
|
||||
|
||||
|
||||
def _leaveRawMode():
|
||||
global _inRawMode
|
||||
if not _inRawMode:
|
||||
return
|
||||
fd = sys.stdin.fileno()
|
||||
tty.tcsetattr(fd, tty.TCSANOW, _savedRawMode)
|
||||
_inRawMode = 0
|
||||
|
||||
|
||||
def _enterRawMode():
|
||||
global _inRawMode, _savedRawMode
|
||||
if _inRawMode:
|
||||
return
|
||||
fd = sys.stdin.fileno()
|
||||
try:
|
||||
old = tty.tcgetattr(fd)
|
||||
new = old[:]
|
||||
except BaseException:
|
||||
log.msg("not a typewriter!")
|
||||
else:
|
||||
# iflage
|
||||
new[0] = new[0] | tty.IGNPAR
|
||||
new[0] = new[0] & ~(
|
||||
tty.ISTRIP
|
||||
| tty.INLCR
|
||||
| tty.IGNCR
|
||||
| tty.ICRNL
|
||||
| tty.IXON
|
||||
| tty.IXANY
|
||||
| tty.IXOFF
|
||||
)
|
||||
if hasattr(tty, "IUCLC"):
|
||||
new[0] = new[0] & ~tty.IUCLC
|
||||
|
||||
# lflag
|
||||
new[3] = new[3] & ~(
|
||||
tty.ISIG
|
||||
| tty.ICANON
|
||||
| tty.ECHO
|
||||
| tty.ECHO
|
||||
| tty.ECHOE
|
||||
| tty.ECHOK
|
||||
| tty.ECHONL
|
||||
)
|
||||
if hasattr(tty, "IEXTEN"):
|
||||
new[3] = new[3] & ~tty.IEXTEN
|
||||
|
||||
# oflag
|
||||
new[1] = new[1] & ~tty.OPOST
|
||||
|
||||
new[6][tty.VMIN] = 1
|
||||
new[6][tty.VTIME] = 0
|
||||
|
||||
_savedRawMode = old
|
||||
tty.tcsetattr(fd, tty.TCSANOW, new)
|
||||
# tty.setraw(fd)
|
||||
_inRawMode = 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
@@ -0,0 +1,673 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_scripts -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation module for the `tkconch` command.
|
||||
"""
|
||||
|
||||
|
||||
import base64
|
||||
import getpass
|
||||
import os
|
||||
import signal
|
||||
import struct
|
||||
import sys
|
||||
import tkinter as Tkinter
|
||||
import tkinter.filedialog as tkFileDialog
|
||||
import tkinter.messagebox as tkMessageBox
|
||||
from typing import List, Tuple
|
||||
|
||||
from twisted.conch import error
|
||||
from twisted.conch.client.default import isInKnownHosts
|
||||
from twisted.conch.ssh import (
|
||||
channel,
|
||||
common,
|
||||
connection,
|
||||
forwarding,
|
||||
keys,
|
||||
session,
|
||||
transport,
|
||||
userauth,
|
||||
)
|
||||
from twisted.conch.ui import tkvt100
|
||||
from twisted.internet import defer, protocol, reactor, tksupport
|
||||
from twisted.python import log, usage
|
||||
|
||||
|
||||
class TkConchMenu(Tkinter.Frame):
|
||||
def __init__(self, *args, **params):
|
||||
## Standard heading: initialization
|
||||
Tkinter.Frame.__init__(self, *args, **params)
|
||||
|
||||
self.master.title("TkConch")
|
||||
self.localRemoteVar = Tkinter.StringVar()
|
||||
self.localRemoteVar.set("local")
|
||||
|
||||
Tkinter.Label(self, anchor="w", justify="left", text="Hostname").grid(
|
||||
column=1, row=1, sticky="w"
|
||||
)
|
||||
self.host = Tkinter.Entry(self)
|
||||
self.host.grid(column=2, columnspan=2, row=1, sticky="nesw")
|
||||
|
||||
Tkinter.Label(self, anchor="w", justify="left", text="Port").grid(
|
||||
column=1, row=2, sticky="w"
|
||||
)
|
||||
self.port = Tkinter.Entry(self)
|
||||
self.port.grid(column=2, columnspan=2, row=2, sticky="nesw")
|
||||
|
||||
Tkinter.Label(self, anchor="w", justify="left", text="Username").grid(
|
||||
column=1, row=3, sticky="w"
|
||||
)
|
||||
self.user = Tkinter.Entry(self)
|
||||
self.user.grid(column=2, columnspan=2, row=3, sticky="nesw")
|
||||
|
||||
Tkinter.Label(self, anchor="w", justify="left", text="Command").grid(
|
||||
column=1, row=4, sticky="w"
|
||||
)
|
||||
self.command = Tkinter.Entry(self)
|
||||
self.command.grid(column=2, columnspan=2, row=4, sticky="nesw")
|
||||
|
||||
Tkinter.Label(self, anchor="w", justify="left", text="Identity").grid(
|
||||
column=1, row=5, sticky="w"
|
||||
)
|
||||
self.identity = Tkinter.Entry(self)
|
||||
self.identity.grid(column=2, row=5, sticky="nesw")
|
||||
Tkinter.Button(self, command=self.getIdentityFile, text="Browse").grid(
|
||||
column=3, row=5, sticky="nesw"
|
||||
)
|
||||
|
||||
Tkinter.Label(self, text="Port Forwarding").grid(column=1, row=6, sticky="w")
|
||||
self.forwards = Tkinter.Listbox(self, height=0, width=0)
|
||||
self.forwards.grid(column=2, columnspan=2, row=6, sticky="nesw")
|
||||
Tkinter.Button(self, text="Add", command=self.addForward).grid(column=1, row=7)
|
||||
Tkinter.Button(self, text="Remove", command=self.removeForward).grid(
|
||||
column=1, row=8
|
||||
)
|
||||
self.forwardPort = Tkinter.Entry(self)
|
||||
self.forwardPort.grid(column=2, row=7, sticky="nesw")
|
||||
Tkinter.Label(self, text="Port").grid(column=3, row=7, sticky="nesw")
|
||||
self.forwardHost = Tkinter.Entry(self)
|
||||
self.forwardHost.grid(column=2, row=8, sticky="nesw")
|
||||
Tkinter.Label(self, text="Host").grid(column=3, row=8, sticky="nesw")
|
||||
self.localForward = Tkinter.Radiobutton(
|
||||
self, text="Local", variable=self.localRemoteVar, value="local"
|
||||
)
|
||||
self.localForward.grid(column=2, row=9)
|
||||
self.remoteForward = Tkinter.Radiobutton(
|
||||
self, text="Remote", variable=self.localRemoteVar, value="remote"
|
||||
)
|
||||
self.remoteForward.grid(column=3, row=9)
|
||||
|
||||
Tkinter.Label(self, text="Advanced Options").grid(
|
||||
column=1, columnspan=3, row=10, sticky="nesw"
|
||||
)
|
||||
|
||||
Tkinter.Label(self, anchor="w", justify="left", text="Cipher").grid(
|
||||
column=1, row=11, sticky="w"
|
||||
)
|
||||
self.cipher = Tkinter.Entry(self, name="cipher")
|
||||
self.cipher.grid(column=2, columnspan=2, row=11, sticky="nesw")
|
||||
|
||||
Tkinter.Label(self, anchor="w", justify="left", text="MAC").grid(
|
||||
column=1, row=12, sticky="w"
|
||||
)
|
||||
self.mac = Tkinter.Entry(self, name="mac")
|
||||
self.mac.grid(column=2, columnspan=2, row=12, sticky="nesw")
|
||||
|
||||
Tkinter.Label(self, anchor="w", justify="left", text="Escape Char").grid(
|
||||
column=1, row=13, sticky="w"
|
||||
)
|
||||
self.escape = Tkinter.Entry(self, name="escape")
|
||||
self.escape.grid(column=2, columnspan=2, row=13, sticky="nesw")
|
||||
Tkinter.Button(self, text="Connect!", command=self.doConnect).grid(
|
||||
column=1, columnspan=3, row=14, sticky="nesw"
|
||||
)
|
||||
|
||||
# Resize behavior(s)
|
||||
self.grid_rowconfigure(6, weight=1, minsize=64)
|
||||
self.grid_columnconfigure(2, weight=1, minsize=2)
|
||||
|
||||
self.master.protocol("WM_DELETE_WINDOW", sys.exit)
|
||||
|
||||
def getIdentityFile(self):
|
||||
r = tkFileDialog.askopenfilename()
|
||||
if r:
|
||||
self.identity.delete(0, Tkinter.END)
|
||||
self.identity.insert(Tkinter.END, r)
|
||||
|
||||
def addForward(self):
|
||||
port = self.forwardPort.get()
|
||||
self.forwardPort.delete(0, Tkinter.END)
|
||||
host = self.forwardHost.get()
|
||||
self.forwardHost.delete(0, Tkinter.END)
|
||||
if self.localRemoteVar.get() == "local":
|
||||
self.forwards.insert(Tkinter.END, f"L:{port}:{host}")
|
||||
else:
|
||||
self.forwards.insert(Tkinter.END, f"R:{port}:{host}")
|
||||
|
||||
def removeForward(self):
|
||||
cur = self.forwards.curselection()
|
||||
if cur:
|
||||
self.forwards.remove(cur[0])
|
||||
|
||||
def doConnect(self):
|
||||
finished = 1
|
||||
options["host"] = self.host.get()
|
||||
options["port"] = self.port.get()
|
||||
options["user"] = self.user.get()
|
||||
options["command"] = self.command.get()
|
||||
cipher = self.cipher.get()
|
||||
mac = self.mac.get()
|
||||
escape = self.escape.get()
|
||||
if cipher:
|
||||
if cipher in SSHClientTransport.supportedCiphers:
|
||||
SSHClientTransport.supportedCiphers = [cipher]
|
||||
else:
|
||||
tkMessageBox.showerror("TkConch", "Bad cipher.")
|
||||
finished = 0
|
||||
|
||||
if mac:
|
||||
if mac in SSHClientTransport.supportedMACs:
|
||||
SSHClientTransport.supportedMACs = [mac]
|
||||
elif finished:
|
||||
tkMessageBox.showerror("TkConch", "Bad MAC.")
|
||||
finished = 0
|
||||
|
||||
if escape:
|
||||
if escape == "none":
|
||||
options["escape"] = None
|
||||
elif escape[0] == "^" and len(escape) == 2:
|
||||
options["escape"] = chr(ord(escape[1]) - 64)
|
||||
elif len(escape) == 1:
|
||||
options["escape"] = escape
|
||||
elif finished:
|
||||
tkMessageBox.showerror("TkConch", "Bad escape character '%s'." % escape)
|
||||
finished = 0
|
||||
|
||||
if self.identity.get():
|
||||
options.identitys.append(self.identity.get())
|
||||
|
||||
for line in self.forwards.get(0, Tkinter.END):
|
||||
if line[0] == "L":
|
||||
options.opt_localforward(line[2:])
|
||||
else:
|
||||
options.opt_remoteforward(line[2:])
|
||||
|
||||
if "@" in options["host"]:
|
||||
options["user"], options["host"] = options["host"].split("@", 1)
|
||||
|
||||
if (not options["host"] or not options["user"]) and finished:
|
||||
tkMessageBox.showerror("TkConch", "Missing host or username.")
|
||||
finished = 0
|
||||
if finished:
|
||||
self.master.quit()
|
||||
self.master.destroy()
|
||||
if options["log"]:
|
||||
realout = sys.stdout
|
||||
log.startLogging(sys.stderr)
|
||||
sys.stdout = realout
|
||||
else:
|
||||
log.discardLogs()
|
||||
log.deferr = handleError # HACK
|
||||
if not options.identitys:
|
||||
options.identitys = ["~/.ssh/id_rsa", "~/.ssh/id_dsa"]
|
||||
host = options["host"]
|
||||
port = int(options["port"] or 22)
|
||||
log.msg((host, port))
|
||||
reactor.connectTCP(host, port, SSHClientFactory())
|
||||
frame.master.deiconify()
|
||||
frame.master.title(
|
||||
"{}@{} - TkConch".format(options["user"], options["host"])
|
||||
)
|
||||
else:
|
||||
self.focus()
|
||||
|
||||
|
||||
class GeneralOptions(usage.Options):
|
||||
synopsis = """Usage: tkconch [options] host [command]
|
||||
"""
|
||||
|
||||
optParameters = [
|
||||
["user", "l", None, "Log in using this user name."],
|
||||
["identity", "i", "~/.ssh/identity", "Identity for public key authentication"],
|
||||
["escape", "e", "~", "Set escape character; ``none'' = disable"],
|
||||
["cipher", "c", None, "Select encryption algorithm."],
|
||||
["macs", "m", None, "Specify MAC algorithms for protocol version 2."],
|
||||
["port", "p", None, "Connect to this port. Server must be on the same port."],
|
||||
[
|
||||
"localforward",
|
||||
"L",
|
||||
None,
|
||||
"listen-port:host:port Forward local port to remote address",
|
||||
],
|
||||
[
|
||||
"remoteforward",
|
||||
"R",
|
||||
None,
|
||||
"listen-port:host:port Forward remote port to local address",
|
||||
],
|
||||
]
|
||||
|
||||
optFlags = [
|
||||
["tty", "t", "Tty; allocate a tty even if command is given."],
|
||||
["notty", "T", "Do not allocate a tty."],
|
||||
["version", "V", "Display version number only."],
|
||||
["compress", "C", "Enable compression."],
|
||||
["noshell", "N", "Do not execute a shell or command."],
|
||||
["subsystem", "s", "Invoke command (mandatory) as SSH2 subsystem."],
|
||||
["log", "v", "Log to stderr"],
|
||||
["ansilog", "a", "Print the received data to stdout"],
|
||||
]
|
||||
|
||||
_ciphers = transport.SSHClientTransport.supportedCiphers
|
||||
_macs = transport.SSHClientTransport.supportedMACs
|
||||
|
||||
compData = usage.Completions(
|
||||
mutuallyExclusive=[("tty", "notty")],
|
||||
optActions={
|
||||
"cipher": usage.CompleteList([v.decode() for v in _ciphers]),
|
||||
"macs": usage.CompleteList([v.decode() for v in _macs]),
|
||||
"localforward": usage.Completer(descr="listen-port:host:port"),
|
||||
"remoteforward": usage.Completer(descr="listen-port:host:port"),
|
||||
},
|
||||
extraActions=[
|
||||
usage.CompleteUserAtHost(),
|
||||
usage.Completer(descr="command"),
|
||||
usage.Completer(descr="argument", repeat=True),
|
||||
],
|
||||
)
|
||||
|
||||
identitys: List[str] = []
|
||||
localForwards: List[Tuple[int, Tuple[int, int]]] = []
|
||||
remoteForwards: List[Tuple[int, Tuple[int, int]]] = []
|
||||
|
||||
def opt_identity(self, i):
|
||||
self.identitys.append(i)
|
||||
|
||||
def opt_localforward(self, f):
|
||||
localPort, remoteHost, remotePort = f.split(":") # doesn't do v6 yet
|
||||
localPort = int(localPort)
|
||||
remotePort = int(remotePort)
|
||||
self.localForwards.append((localPort, (remoteHost, remotePort)))
|
||||
|
||||
def opt_remoteforward(self, f):
|
||||
remotePort, connHost, connPort = f.split(":") # doesn't do v6 yet
|
||||
remotePort = int(remotePort)
|
||||
connPort = int(connPort)
|
||||
self.remoteForwards.append((remotePort, (connHost, connPort)))
|
||||
|
||||
def opt_compress(self):
|
||||
SSHClientTransport.supportedCompressions[0:1] = ["zlib"]
|
||||
|
||||
def parseArgs(self, *args):
|
||||
if args:
|
||||
self["host"] = args[0]
|
||||
self["command"] = " ".join(args[1:])
|
||||
else:
|
||||
self["host"] = ""
|
||||
self["command"] = ""
|
||||
|
||||
|
||||
# Rest of code in "run"
|
||||
options = None
|
||||
menu = None
|
||||
exitStatus = 0
|
||||
frame = None
|
||||
|
||||
|
||||
def deferredAskFrame(question, echo):
|
||||
if frame.callback:
|
||||
raise ValueError("can't ask 2 questions at once!")
|
||||
d = defer.Deferred()
|
||||
resp = []
|
||||
|
||||
def gotChar(ch, resp=resp):
|
||||
if not ch:
|
||||
return
|
||||
if ch == "\x03": # C-c
|
||||
reactor.stop()
|
||||
if ch == "\r":
|
||||
frame.write("\r\n")
|
||||
stresp = "".join(resp)
|
||||
del resp
|
||||
frame.callback = None
|
||||
d.callback(stresp)
|
||||
return
|
||||
elif 32 <= ord(ch) < 127:
|
||||
resp.append(ch)
|
||||
if echo:
|
||||
frame.write(ch)
|
||||
elif ord(ch) == 8 and resp: # BS
|
||||
if echo:
|
||||
frame.write("\x08 \x08")
|
||||
resp.pop()
|
||||
|
||||
frame.callback = gotChar
|
||||
frame.write(question)
|
||||
frame.canvas.focus_force()
|
||||
return d
|
||||
|
||||
|
||||
def run():
|
||||
global menu, options, frame
|
||||
args = sys.argv[1:]
|
||||
if "-l" in args: # cvs is an idiot
|
||||
i = args.index("-l")
|
||||
args = args[i : i + 2] + args
|
||||
del args[i + 2 : i + 4]
|
||||
for arg in args[:]:
|
||||
try:
|
||||
i = args.index(arg)
|
||||
if arg[:2] == "-o" and args[i + 1][0] != "-":
|
||||
args[i : i + 2] = [] # suck on it scp
|
||||
except ValueError:
|
||||
pass
|
||||
root = Tkinter.Tk()
|
||||
root.withdraw()
|
||||
top = Tkinter.Toplevel()
|
||||
menu = TkConchMenu(top)
|
||||
menu.pack(side=Tkinter.TOP, fill=Tkinter.BOTH, expand=1)
|
||||
options = GeneralOptions()
|
||||
try:
|
||||
options.parseOptions(args)
|
||||
except usage.UsageError as u:
|
||||
print("ERROR: %s" % u)
|
||||
options.opt_help()
|
||||
sys.exit(1)
|
||||
for k, v in options.items():
|
||||
if v and hasattr(menu, k):
|
||||
getattr(menu, k).insert(Tkinter.END, v)
|
||||
for p, (rh, rp) in options.localForwards:
|
||||
menu.forwards.insert(Tkinter.END, f"L:{p}:{rh}:{rp}")
|
||||
options.localForwards = []
|
||||
for p, (rh, rp) in options.remoteForwards:
|
||||
menu.forwards.insert(Tkinter.END, f"R:{p}:{rh}:{rp}")
|
||||
options.remoteForwards = []
|
||||
frame = tkvt100.VT100Frame(root, callback=None)
|
||||
root.geometry(
|
||||
"%dx%d"
|
||||
% (tkvt100.fontWidth * frame.width + 3, tkvt100.fontHeight * frame.height + 3)
|
||||
)
|
||||
frame.pack(side=Tkinter.TOP)
|
||||
tksupport.install(root)
|
||||
root.withdraw()
|
||||
if (options["host"] and options["user"]) or "@" in options["host"]:
|
||||
menu.doConnect()
|
||||
else:
|
||||
top.mainloop()
|
||||
reactor.run()
|
||||
sys.exit(exitStatus)
|
||||
|
||||
|
||||
def handleError():
|
||||
from twisted.python import failure
|
||||
|
||||
global exitStatus
|
||||
exitStatus = 2
|
||||
log.err(failure.Failure())
|
||||
reactor.stop()
|
||||
raise
|
||||
|
||||
|
||||
class SSHClientFactory(protocol.ClientFactory):
|
||||
noisy = True
|
||||
|
||||
def stopFactory(self):
|
||||
reactor.stop()
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return SSHClientTransport()
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
tkMessageBox.showwarning(
|
||||
"TkConch",
|
||||
f"Connection Failed, Reason:\n {reason.type}: {reason.value}",
|
||||
)
|
||||
|
||||
|
||||
class SSHClientTransport(transport.SSHClientTransport):
|
||||
def receiveError(self, code, desc):
|
||||
global exitStatus
|
||||
exitStatus = (
|
||||
"conch:\tRemote side disconnected with error code %i\nconch:\treason: %s"
|
||||
% (code, desc)
|
||||
)
|
||||
|
||||
def sendDisconnect(self, code, reason):
|
||||
global exitStatus
|
||||
exitStatus = (
|
||||
"conch:\tSending disconnect with error code %i\nconch:\treason: %s"
|
||||
% (code, reason)
|
||||
)
|
||||
transport.SSHClientTransport.sendDisconnect(self, code, reason)
|
||||
|
||||
def receiveDebug(self, alwaysDisplay, message, lang):
|
||||
global options
|
||||
if alwaysDisplay or options["log"]:
|
||||
log.msg("Received Debug Message: %s" % message)
|
||||
|
||||
def verifyHostKey(self, pubKey, fingerprint):
|
||||
# d = defer.Deferred()
|
||||
# d.addCallback(lambda x:defer.succeed(1))
|
||||
# d.callback(2)
|
||||
# return d
|
||||
goodKey = isInKnownHosts(options["host"], pubKey, {"known-hosts": None})
|
||||
if goodKey == 1: # good key
|
||||
return defer.succeed(1)
|
||||
elif goodKey == 2: # AAHHHHH changed
|
||||
return defer.fail(error.ConchError("bad host key"))
|
||||
else:
|
||||
if options["host"] == self.transport.getPeer().host:
|
||||
host = options["host"]
|
||||
khHost = options["host"]
|
||||
else:
|
||||
host = "{} ({})".format(options["host"], self.transport.getPeer().host)
|
||||
khHost = "{},{}".format(options["host"], self.transport.getPeer().host)
|
||||
keyType = common.getNS(pubKey)[0]
|
||||
ques = """The authenticity of host '{}' can't be established.\r
|
||||
{} key fingerprint is {}.""".format(
|
||||
host,
|
||||
{b"ssh-dss": "DSA", b"ssh-rsa": "RSA"}[keyType],
|
||||
fingerprint,
|
||||
)
|
||||
ques += "\r\nAre you sure you want to continue connecting (yes/no)? "
|
||||
return deferredAskFrame(ques, 1).addCallback(
|
||||
self._cbVerifyHostKey, pubKey, khHost, keyType
|
||||
)
|
||||
|
||||
def _cbVerifyHostKey(self, ans, pubKey, khHost, keyType):
|
||||
if ans.lower() not in ("yes", "no"):
|
||||
return deferredAskFrame("Please type 'yes' or 'no': ", 1).addCallback(
|
||||
self._cbVerifyHostKey, pubKey, khHost, keyType
|
||||
)
|
||||
if ans.lower() == "no":
|
||||
frame.write("Host key verification failed.\r\n")
|
||||
raise error.ConchError("bad host key")
|
||||
try:
|
||||
frame.write(
|
||||
"Warning: Permanently added '%s' (%s) to the list of "
|
||||
"known hosts.\r\n"
|
||||
% (khHost, {b"ssh-dss": "DSA", b"ssh-rsa": "RSA"}[keyType])
|
||||
)
|
||||
with open(os.path.expanduser("~/.ssh/known_hosts"), "a") as known_hosts:
|
||||
encodedKey = base64.b64encode(pubKey)
|
||||
known_hosts.write(f"\n{khHost} {keyType} {encodedKey}")
|
||||
except BaseException:
|
||||
log.deferr()
|
||||
raise error.ConchError
|
||||
|
||||
def connectionSecure(self):
|
||||
if options["user"]:
|
||||
user = options["user"]
|
||||
else:
|
||||
user = getpass.getuser()
|
||||
self.requestService(SSHUserAuthClient(user, SSHConnection()))
|
||||
|
||||
|
||||
class SSHUserAuthClient(userauth.SSHUserAuthClient):
|
||||
usedFiles: List[str] = []
|
||||
|
||||
def getPassword(self, prompt=None):
|
||||
if not prompt:
|
||||
prompt = "{}@{}'s password: ".format(self.user, options["host"])
|
||||
return deferredAskFrame(prompt, 0)
|
||||
|
||||
def getPublicKey(self):
|
||||
files = [x for x in options.identitys if x not in self.usedFiles]
|
||||
if not files:
|
||||
return None
|
||||
file = files[0]
|
||||
log.msg(file)
|
||||
self.usedFiles.append(file)
|
||||
file = os.path.expanduser(file)
|
||||
file += ".pub"
|
||||
if not os.path.exists(file):
|
||||
return
|
||||
try:
|
||||
return keys.Key.fromFile(file).blob()
|
||||
except BaseException:
|
||||
return self.getPublicKey() # try again
|
||||
|
||||
def getPrivateKey(self):
|
||||
file = os.path.expanduser(self.usedFiles[-1])
|
||||
if not os.path.exists(file):
|
||||
return None
|
||||
try:
|
||||
return defer.succeed(keys.Key.fromFile(file).keyObject)
|
||||
except keys.BadKeyError as e:
|
||||
if e.args[0] == "encrypted key with no password":
|
||||
prompt = "Enter passphrase for key '%s': " % self.usedFiles[-1]
|
||||
return deferredAskFrame(prompt, 0).addCallback(self._cbGetPrivateKey, 0)
|
||||
|
||||
def _cbGetPrivateKey(self, ans, count):
|
||||
file = os.path.expanduser(self.usedFiles[-1])
|
||||
try:
|
||||
return keys.Key.fromFile(file, password=ans).keyObject
|
||||
except keys.BadKeyError:
|
||||
if count == 2:
|
||||
raise
|
||||
prompt = "Enter passphrase for key '%s': " % self.usedFiles[-1]
|
||||
return deferredAskFrame(prompt, 0).addCallback(
|
||||
self._cbGetPrivateKey, count + 1
|
||||
)
|
||||
|
||||
|
||||
class SSHConnection(connection.SSHConnection):
|
||||
def serviceStarted(self):
|
||||
if not options["noshell"]:
|
||||
self.openChannel(SSHSession())
|
||||
if options.localForwards:
|
||||
for localPort, hostport in options.localForwards:
|
||||
reactor.listenTCP(
|
||||
localPort,
|
||||
forwarding.SSHListenForwardingFactory(
|
||||
self, hostport, forwarding.SSHListenClientForwardingChannel
|
||||
),
|
||||
)
|
||||
if options.remoteForwards:
|
||||
for remotePort, hostport in options.remoteForwards:
|
||||
log.msg(
|
||||
"asking for remote forwarding for {}:{}".format(
|
||||
remotePort, hostport
|
||||
)
|
||||
)
|
||||
data = forwarding.packGlobal_tcpip_forward(("0.0.0.0", remotePort))
|
||||
self.sendGlobalRequest("tcpip-forward", data)
|
||||
self.remoteForwards[remotePort] = hostport
|
||||
|
||||
|
||||
class SSHSession(channel.SSHChannel):
|
||||
name = b"session"
|
||||
|
||||
def channelOpen(self, foo):
|
||||
# global globalSession
|
||||
# globalSession = self
|
||||
# turn off local echo
|
||||
self.escapeMode = 1
|
||||
c = session.SSHSessionClient()
|
||||
if options["escape"]:
|
||||
c.dataReceived = self.handleInput
|
||||
else:
|
||||
c.dataReceived = self.write
|
||||
c.connectionLost = self.sendEOF
|
||||
frame.callback = c.dataReceived
|
||||
frame.canvas.focus_force()
|
||||
if options["subsystem"]:
|
||||
self.conn.sendRequest(self, b"subsystem", common.NS(options["command"]))
|
||||
elif options["command"]:
|
||||
if options["tty"]:
|
||||
term = os.environ.get("TERM", "xterm")
|
||||
# winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
|
||||
winSize = (25, 80, 0, 0) # struct.unpack('4H', winsz)
|
||||
ptyReqData = session.packRequest_pty_req(term, winSize, "")
|
||||
self.conn.sendRequest(self, b"pty-req", ptyReqData)
|
||||
self.conn.sendRequest(self, "exec", common.NS(options["command"]))
|
||||
else:
|
||||
if not options["notty"]:
|
||||
term = os.environ.get("TERM", "xterm")
|
||||
# winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
|
||||
winSize = (25, 80, 0, 0) # struct.unpack('4H', winsz)
|
||||
ptyReqData = session.packRequest_pty_req(term, winSize, "")
|
||||
self.conn.sendRequest(self, b"pty-req", ptyReqData)
|
||||
self.conn.sendRequest(self, b"shell", b"")
|
||||
self.conn.transport.transport.setTcpNoDelay(1)
|
||||
|
||||
def handleInput(self, char):
|
||||
# log.msg('handling %s' % repr(char))
|
||||
if char in ("\n", "\r"):
|
||||
self.escapeMode = 1
|
||||
self.write(char)
|
||||
elif self.escapeMode == 1 and char == options["escape"]:
|
||||
self.escapeMode = 2
|
||||
elif self.escapeMode == 2:
|
||||
self.escapeMode = 1 # so we can chain escapes together
|
||||
if char == ".": # disconnect
|
||||
log.msg("disconnecting from escape")
|
||||
reactor.stop()
|
||||
return
|
||||
elif char == "\x1a": # ^Z, suspend
|
||||
# following line courtesy of Erwin@freenode
|
||||
os.kill(os.getpid(), signal.SIGSTOP)
|
||||
return
|
||||
elif char == "R": # rekey connection
|
||||
log.msg("rekeying connection")
|
||||
self.conn.transport.sendKexInit()
|
||||
return
|
||||
self.write("~" + char)
|
||||
else:
|
||||
self.escapeMode = 0
|
||||
self.write(char)
|
||||
|
||||
def dataReceived(self, data):
|
||||
data = data.decode("utf-8")
|
||||
if options["ansilog"]:
|
||||
print(repr(data))
|
||||
frame.write(data)
|
||||
|
||||
def extReceived(self, t, data):
|
||||
if t == connection.EXTENDED_DATA_STDERR:
|
||||
log.msg("got %s stderr data" % len(data))
|
||||
sys.stderr.write(data)
|
||||
sys.stderr.flush()
|
||||
|
||||
def eofReceived(self):
|
||||
log.msg("got eof")
|
||||
sys.stdin.close()
|
||||
|
||||
def closed(self):
|
||||
log.msg("closed %s" % self)
|
||||
if len(self.conn.channels) == 1: # just us left
|
||||
reactor.stop()
|
||||
|
||||
def request_exit_status(self, data):
|
||||
global exitStatus
|
||||
exitStatus = int(struct.unpack(">L", data)[0])
|
||||
log.msg("exit status: %s" % exitStatus)
|
||||
|
||||
def sendEOF(self):
|
||||
self.conn.sendEOF(self)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
@@ -0,0 +1,10 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
|
||||
"""
|
||||
An SSHv2 implementation for Twisted. Part of the Twisted.Conch package.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
293
.venv/lib/python3.12/site-packages/twisted/conch/ssh/_kex.py
Normal file
293
.venv/lib/python3.12/site-packages/twisted/conch/ssh/_kex.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_transport -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
SSH key exchange handling.
|
||||
"""
|
||||
|
||||
|
||||
from hashlib import sha1, sha256, sha384, sha512
|
||||
|
||||
from zope.interface import Attribute, Interface, implementer
|
||||
|
||||
from twisted.conch import error
|
||||
|
||||
|
||||
class _IKexAlgorithm(Interface):
|
||||
"""
|
||||
An L{_IKexAlgorithm} describes a key exchange algorithm.
|
||||
"""
|
||||
|
||||
preference = Attribute(
|
||||
"An L{int} giving the preference of the algorithm when negotiating "
|
||||
"key exchange. Algorithms with lower precedence values are more "
|
||||
"preferred."
|
||||
)
|
||||
|
||||
hashProcessor = Attribute(
|
||||
"A callable hash algorithm constructor (e.g. C{hashlib.sha256}) "
|
||||
"suitable for use with this key exchange algorithm."
|
||||
)
|
||||
|
||||
|
||||
class _IFixedGroupKexAlgorithm(_IKexAlgorithm):
|
||||
"""
|
||||
An L{_IFixedGroupKexAlgorithm} describes a key exchange algorithm with a
|
||||
fixed prime / generator group.
|
||||
"""
|
||||
|
||||
prime = Attribute(
|
||||
"An L{int} giving the prime number used in Diffie-Hellman key "
|
||||
"exchange, or L{None} if not applicable."
|
||||
)
|
||||
|
||||
generator = Attribute(
|
||||
"An L{int} giving the generator number used in Diffie-Hellman key "
|
||||
"exchange, or L{None} if not applicable. (This is not related to "
|
||||
"Python generator functions.)"
|
||||
)
|
||||
|
||||
|
||||
class _IEllipticCurveExchangeKexAlgorithm(_IKexAlgorithm):
|
||||
"""
|
||||
An L{_IEllipticCurveExchangeKexAlgorithm} describes a key exchange algorithm
|
||||
that uses an elliptic curve exchange between the client and server.
|
||||
"""
|
||||
|
||||
|
||||
class _IGroupExchangeKexAlgorithm(_IKexAlgorithm):
|
||||
"""
|
||||
An L{_IGroupExchangeKexAlgorithm} describes a key exchange algorithm
|
||||
that uses group exchange between the client and server.
|
||||
|
||||
A prime / generator group should be chosen at run time based on the
|
||||
requested size. See RFC 4419.
|
||||
"""
|
||||
|
||||
|
||||
@implementer(_IEllipticCurveExchangeKexAlgorithm)
|
||||
class _Curve25519SHA256:
|
||||
"""
|
||||
Elliptic Curve Key Exchange using Curve25519 and SHA256. Defined in
|
||||
U{https://datatracker.ietf.org/doc/draft-ietf-curdle-ssh-curves/}.
|
||||
"""
|
||||
|
||||
preference = 1
|
||||
hashProcessor = sha256
|
||||
|
||||
|
||||
@implementer(_IEllipticCurveExchangeKexAlgorithm)
|
||||
class _Curve25519SHA256LibSSH:
|
||||
"""
|
||||
As L{_Curve25519SHA256}, but with a pre-standardized algorithm name.
|
||||
"""
|
||||
|
||||
preference = 2
|
||||
hashProcessor = sha256
|
||||
|
||||
|
||||
@implementer(_IEllipticCurveExchangeKexAlgorithm)
|
||||
class _ECDH256:
|
||||
"""
|
||||
Elliptic Curve Key Exchange with SHA-256 as HASH. Defined in
|
||||
RFC 5656.
|
||||
|
||||
Note that C{ecdh-sha2-nistp256} takes priority over nistp384 or nistp512.
|
||||
This is the same priority from OpenSSH.
|
||||
|
||||
C{ecdh-sha2-nistp256} is considered preety good cryptography.
|
||||
If you need something better consider using C{curve25519-sha256}.
|
||||
"""
|
||||
|
||||
preference = 3
|
||||
hashProcessor = sha256
|
||||
|
||||
|
||||
@implementer(_IEllipticCurveExchangeKexAlgorithm)
|
||||
class _ECDH384:
|
||||
"""
|
||||
Elliptic Curve Key Exchange with SHA-384 as HASH. Defined in
|
||||
RFC 5656.
|
||||
"""
|
||||
|
||||
preference = 4
|
||||
hashProcessor = sha384
|
||||
|
||||
|
||||
@implementer(_IEllipticCurveExchangeKexAlgorithm)
|
||||
class _ECDH512:
|
||||
"""
|
||||
Elliptic Curve Key Exchange with SHA-512 as HASH. Defined in
|
||||
RFC 5656.
|
||||
"""
|
||||
|
||||
preference = 5
|
||||
hashProcessor = sha512
|
||||
|
||||
|
||||
@implementer(_IGroupExchangeKexAlgorithm)
|
||||
class _DHGroupExchangeSHA256:
|
||||
"""
|
||||
Diffie-Hellman Group and Key Exchange with SHA-256 as HASH. Defined in
|
||||
RFC 4419, 4.2.
|
||||
"""
|
||||
|
||||
preference = 6
|
||||
hashProcessor = sha256
|
||||
|
||||
|
||||
@implementer(_IGroupExchangeKexAlgorithm)
|
||||
class _DHGroupExchangeSHA1:
|
||||
"""
|
||||
Diffie-Hellman Group and Key Exchange with SHA-1 as HASH. Defined in
|
||||
RFC 4419, 4.1.
|
||||
"""
|
||||
|
||||
preference = 7
|
||||
hashProcessor = sha1
|
||||
|
||||
|
||||
@implementer(_IFixedGroupKexAlgorithm)
|
||||
class _DHGroup14SHA1:
|
||||
"""
|
||||
Diffie-Hellman key exchange with SHA-1 as HASH and Oakley Group 14
|
||||
(2048-bit MODP Group). Defined in RFC 4253, 8.2.
|
||||
"""
|
||||
|
||||
preference = 8
|
||||
hashProcessor = sha1
|
||||
# Diffie-Hellman primes from Oakley Group 14 (RFC 3526, 3).
|
||||
prime = int(
|
||||
"323170060713110073003389139264238282488179412411402391128420"
|
||||
"097514007417066343542226196894173635693471179017379097041917"
|
||||
"546058732091950288537589861856221532121754125149017745202702"
|
||||
"357960782362488842461894775876411059286460994117232454266225"
|
||||
"221932305409190376805242355191256797158701170010580558776510"
|
||||
"388618472802579760549035697325615261670813393617995413364765"
|
||||
"591603683178967290731783845896806396719009772021941686472258"
|
||||
"710314113364293195361934716365332097170774482279885885653692"
|
||||
"086452966360772502689555059283627511211740969729980684105543"
|
||||
"595848665832916421362182310789909994486524682624169720359118"
|
||||
"52507045361090559"
|
||||
)
|
||||
generator = 2
|
||||
|
||||
|
||||
# Which ECDH hash function to use is dependent on the size.
|
||||
_kexAlgorithms = {
|
||||
b"curve25519-sha256": _Curve25519SHA256(),
|
||||
b"curve25519-sha256@libssh.org": _Curve25519SHA256LibSSH(),
|
||||
b"diffie-hellman-group-exchange-sha256": _DHGroupExchangeSHA256(),
|
||||
b"diffie-hellman-group-exchange-sha1": _DHGroupExchangeSHA1(),
|
||||
b"diffie-hellman-group14-sha1": _DHGroup14SHA1(),
|
||||
b"ecdh-sha2-nistp256": _ECDH256(),
|
||||
b"ecdh-sha2-nistp384": _ECDH384(),
|
||||
b"ecdh-sha2-nistp521": _ECDH512(),
|
||||
}
|
||||
|
||||
|
||||
def getKex(kexAlgorithm):
|
||||
"""
|
||||
Get a description of a named key exchange algorithm.
|
||||
|
||||
@param kexAlgorithm: The key exchange algorithm name.
|
||||
@type kexAlgorithm: L{bytes}
|
||||
|
||||
@return: A description of the key exchange algorithm named by
|
||||
C{kexAlgorithm}.
|
||||
@rtype: L{_IKexAlgorithm}
|
||||
|
||||
@raises ConchError: if the key exchange algorithm is not found.
|
||||
"""
|
||||
if kexAlgorithm not in _kexAlgorithms:
|
||||
raise error.ConchError(f"Unsupported key exchange algorithm: {kexAlgorithm}")
|
||||
return _kexAlgorithms[kexAlgorithm]
|
||||
|
||||
|
||||
def isEllipticCurve(kexAlgorithm):
|
||||
"""
|
||||
Returns C{True} if C{kexAlgorithm} is an elliptic curve.
|
||||
|
||||
@param kexAlgorithm: The key exchange algorithm name.
|
||||
@type kexAlgorithm: C{str}
|
||||
|
||||
@return: C{True} if C{kexAlgorithm} is an elliptic curve,
|
||||
otherwise C{False}.
|
||||
@rtype: C{bool}
|
||||
"""
|
||||
return _IEllipticCurveExchangeKexAlgorithm.providedBy(getKex(kexAlgorithm))
|
||||
|
||||
|
||||
def isFixedGroup(kexAlgorithm):
|
||||
"""
|
||||
Returns C{True} if C{kexAlgorithm} has a fixed prime / generator group.
|
||||
|
||||
@param kexAlgorithm: The key exchange algorithm name.
|
||||
@type kexAlgorithm: L{bytes}
|
||||
|
||||
@return: C{True} if C{kexAlgorithm} has a fixed prime / generator group,
|
||||
otherwise C{False}.
|
||||
@rtype: L{bool}
|
||||
"""
|
||||
return _IFixedGroupKexAlgorithm.providedBy(getKex(kexAlgorithm))
|
||||
|
||||
|
||||
def getHashProcessor(kexAlgorithm):
|
||||
"""
|
||||
Get the hash algorithm callable to use in key exchange.
|
||||
|
||||
@param kexAlgorithm: The key exchange algorithm name.
|
||||
@type kexAlgorithm: L{bytes}
|
||||
|
||||
@return: A callable hash algorithm constructor (e.g. C{hashlib.sha256}).
|
||||
@rtype: C{callable}
|
||||
"""
|
||||
kex = getKex(kexAlgorithm)
|
||||
return kex.hashProcessor
|
||||
|
||||
|
||||
def getDHGeneratorAndPrime(kexAlgorithm):
|
||||
"""
|
||||
Get the generator and the prime to use in key exchange.
|
||||
|
||||
@param kexAlgorithm: The key exchange algorithm name.
|
||||
@type kexAlgorithm: L{bytes}
|
||||
|
||||
@return: A L{tuple} containing L{int} generator and L{int} prime.
|
||||
@rtype: L{tuple}
|
||||
"""
|
||||
kex = getKex(kexAlgorithm)
|
||||
return kex.generator, kex.prime
|
||||
|
||||
|
||||
def getSupportedKeyExchanges():
|
||||
"""
|
||||
Get a list of supported key exchange algorithm names in order of
|
||||
preference.
|
||||
|
||||
@return: A C{list} of supported key exchange algorithm names.
|
||||
@rtype: C{list} of L{bytes}
|
||||
"""
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
|
||||
from twisted.conch.ssh.keys import _curveTable
|
||||
|
||||
backend = default_backend()
|
||||
kexAlgorithms = _kexAlgorithms.copy()
|
||||
for keyAlgorithm in list(kexAlgorithms):
|
||||
if keyAlgorithm.startswith(b"ecdh"):
|
||||
keyAlgorithmDsa = keyAlgorithm.replace(b"ecdh", b"ecdsa")
|
||||
supported = backend.elliptic_curve_exchange_algorithm_supported(
|
||||
ec.ECDH(), _curveTable[keyAlgorithmDsa]
|
||||
)
|
||||
elif keyAlgorithm.startswith(b"curve25519-sha256"):
|
||||
supported = backend.x25519_supported()
|
||||
else:
|
||||
supported = True
|
||||
if not supported:
|
||||
kexAlgorithms.pop(keyAlgorithm)
|
||||
return sorted(
|
||||
kexAlgorithms, key=lambda kexAlgorithm: kexAlgorithms[kexAlgorithm].preference
|
||||
)
|
||||
@@ -0,0 +1,43 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_address -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Address object for SSH network connections.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
|
||||
@since: 12.1
|
||||
"""
|
||||
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.interfaces import IAddress
|
||||
from twisted.python import util
|
||||
|
||||
|
||||
@implementer(IAddress)
|
||||
class SSHTransportAddress(util.FancyEqMixin):
|
||||
"""
|
||||
Object representing an SSH Transport endpoint.
|
||||
|
||||
This is used to ensure that any code inspecting this address and
|
||||
attempting to construct a similar connection based upon it is not
|
||||
mislead into creating a transport which is not similar to the one it is
|
||||
indicating.
|
||||
|
||||
@ivar address: An instance of an object which implements I{IAddress} to
|
||||
which this transport address is connected.
|
||||
"""
|
||||
|
||||
compareAttributes = ("address",)
|
||||
|
||||
def __init__(self, address):
|
||||
self.address = address
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SSHTransportAddress({self.address!r})"
|
||||
|
||||
def __hash__(self):
|
||||
return hash(("SSH", self.address))
|
||||
278
.venv/lib/python3.12/site-packages/twisted/conch/ssh/agent.py
Normal file
278
.venv/lib/python3.12/site-packages/twisted/conch/ssh/agent.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implements the SSH v2 key agent protocol. This protocol is documented in the
|
||||
SSH source code, in the file
|
||||
U{PROTOCOL.agent<http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent>}.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.conch.error import ConchError, MissingKeyStoreError
|
||||
from twisted.conch.ssh import keys
|
||||
from twisted.conch.ssh.common import NS, getMP, getNS
|
||||
from twisted.internet import defer, protocol
|
||||
|
||||
|
||||
class SSHAgentClient(protocol.Protocol):
|
||||
"""
|
||||
The client side of the SSH agent protocol. This is equivalent to
|
||||
ssh-add(1) and can be used with either ssh-agent(1) or the SSHAgentServer
|
||||
protocol, also in this package.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.buf = b""
|
||||
self.deferreds = []
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buf += data
|
||||
while 1:
|
||||
if len(self.buf) <= 4:
|
||||
return
|
||||
packLen = struct.unpack("!L", self.buf[:4])[0]
|
||||
if len(self.buf) < 4 + packLen:
|
||||
return
|
||||
packet, self.buf = self.buf[4 : 4 + packLen], self.buf[4 + packLen :]
|
||||
reqType = ord(packet[0:1])
|
||||
d = self.deferreds.pop(0)
|
||||
if reqType == AGENT_FAILURE:
|
||||
d.errback(ConchError("agent failure"))
|
||||
elif reqType == AGENT_SUCCESS:
|
||||
d.callback(b"")
|
||||
else:
|
||||
d.callback(packet)
|
||||
|
||||
def sendRequest(self, reqType, data):
|
||||
pack = struct.pack("!LB", len(data) + 1, reqType) + data
|
||||
self.transport.write(pack)
|
||||
d = defer.Deferred()
|
||||
self.deferreds.append(d)
|
||||
return d
|
||||
|
||||
def requestIdentities(self):
|
||||
"""
|
||||
@return: A L{Deferred} which will fire with a list of all keys found in
|
||||
the SSH agent. The list of keys is comprised of (public key blob,
|
||||
comment) tuples.
|
||||
"""
|
||||
d = self.sendRequest(AGENTC_REQUEST_IDENTITIES, b"")
|
||||
d.addCallback(self._cbRequestIdentities)
|
||||
return d
|
||||
|
||||
def _cbRequestIdentities(self, data):
|
||||
"""
|
||||
Unpack a collection of identities into a list of tuples comprised of
|
||||
public key blobs and comments.
|
||||
"""
|
||||
if ord(data[0:1]) != AGENT_IDENTITIES_ANSWER:
|
||||
raise ConchError("unexpected response: %i" % ord(data[0:1]))
|
||||
numKeys = struct.unpack("!L", data[1:5])[0]
|
||||
result = []
|
||||
data = data[5:]
|
||||
for i in range(numKeys):
|
||||
blob, data = getNS(data)
|
||||
comment, data = getNS(data)
|
||||
result.append((blob, comment))
|
||||
return result
|
||||
|
||||
def addIdentity(self, blob, comment=b""):
|
||||
"""
|
||||
Add a private key blob to the agent's collection of keys.
|
||||
"""
|
||||
req = blob
|
||||
req += NS(comment)
|
||||
return self.sendRequest(AGENTC_ADD_IDENTITY, req)
|
||||
|
||||
def signData(self, blob, data):
|
||||
"""
|
||||
Request that the agent sign the given C{data} with the private key
|
||||
which corresponds to the public key given by C{blob}. The private
|
||||
key should have been added to the agent already.
|
||||
|
||||
@type blob: L{bytes}
|
||||
@type data: L{bytes}
|
||||
@return: A L{Deferred} which fires with a signature for given data
|
||||
created with the given key.
|
||||
"""
|
||||
req = NS(blob)
|
||||
req += NS(data)
|
||||
req += b"\000\000\000\000" # flags
|
||||
return self.sendRequest(AGENTC_SIGN_REQUEST, req).addCallback(self._cbSignData)
|
||||
|
||||
def _cbSignData(self, data):
|
||||
if ord(data[0:1]) != AGENT_SIGN_RESPONSE:
|
||||
raise ConchError("unexpected data: %i" % ord(data[0:1]))
|
||||
signature = getNS(data[1:])[0]
|
||||
return signature
|
||||
|
||||
def removeIdentity(self, blob):
|
||||
"""
|
||||
Remove the private key corresponding to the public key in blob from the
|
||||
running agent.
|
||||
"""
|
||||
req = NS(blob)
|
||||
return self.sendRequest(AGENTC_REMOVE_IDENTITY, req)
|
||||
|
||||
def removeAllIdentities(self):
|
||||
"""
|
||||
Remove all keys from the running agent.
|
||||
"""
|
||||
return self.sendRequest(AGENTC_REMOVE_ALL_IDENTITIES, b"")
|
||||
|
||||
|
||||
class SSHAgentServer(protocol.Protocol):
|
||||
"""
|
||||
The server side of the SSH agent protocol. This is equivalent to
|
||||
ssh-agent(1) and can be used with either ssh-add(1) or the SSHAgentClient
|
||||
protocol, also in this package.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.buf = b""
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buf += data
|
||||
while 1:
|
||||
if len(self.buf) <= 4:
|
||||
return
|
||||
packLen = struct.unpack("!L", self.buf[:4])[0]
|
||||
if len(self.buf) < 4 + packLen:
|
||||
return
|
||||
packet, self.buf = self.buf[4 : 4 + packLen], self.buf[4 + packLen :]
|
||||
reqType = ord(packet[0:1])
|
||||
reqName = messages.get(reqType, None)
|
||||
if not reqName:
|
||||
self.sendResponse(AGENT_FAILURE, b"")
|
||||
else:
|
||||
f = getattr(self, "agentc_%s" % reqName)
|
||||
if getattr(self.factory, "keys", None) is None:
|
||||
self.sendResponse(AGENT_FAILURE, b"")
|
||||
raise MissingKeyStoreError()
|
||||
f(packet[1:])
|
||||
|
||||
def sendResponse(self, reqType, data):
|
||||
pack = struct.pack("!LB", len(data) + 1, reqType) + data
|
||||
self.transport.write(pack)
|
||||
|
||||
def agentc_REQUEST_IDENTITIES(self, data):
|
||||
"""
|
||||
Return all of the identities that have been added to the server
|
||||
"""
|
||||
assert data == b""
|
||||
numKeys = len(self.factory.keys)
|
||||
resp = []
|
||||
|
||||
resp.append(struct.pack("!L", numKeys))
|
||||
for key, comment in self.factory.keys.values():
|
||||
resp.append(NS(key.blob())) # yes, wrapped in an NS
|
||||
resp.append(NS(comment))
|
||||
self.sendResponse(AGENT_IDENTITIES_ANSWER, b"".join(resp))
|
||||
|
||||
def agentc_SIGN_REQUEST(self, data):
|
||||
"""
|
||||
Data is a structure with a reference to an already added key object and
|
||||
some data that the clients wants signed with that key. If the key
|
||||
object wasn't loaded, return AGENT_FAILURE, else return the signature.
|
||||
"""
|
||||
blob, data = getNS(data)
|
||||
if blob not in self.factory.keys:
|
||||
return self.sendResponse(AGENT_FAILURE, b"")
|
||||
signData, data = getNS(data)
|
||||
assert data == b"\000\000\000\000"
|
||||
self.sendResponse(
|
||||
AGENT_SIGN_RESPONSE, NS(self.factory.keys[blob][0].sign(signData))
|
||||
)
|
||||
|
||||
def agentc_ADD_IDENTITY(self, data):
|
||||
"""
|
||||
Adds a private key to the agent's collection of identities. On
|
||||
subsequent interactions, the private key can be accessed using only the
|
||||
corresponding public key.
|
||||
"""
|
||||
|
||||
# need to pre-read the key data so we can get past it to the comment string
|
||||
keyType, rest = getNS(data)
|
||||
if keyType == b"ssh-rsa":
|
||||
nmp = 6
|
||||
elif keyType == b"ssh-dss":
|
||||
nmp = 5
|
||||
else:
|
||||
raise keys.BadKeyError("unknown blob type: %s" % keyType)
|
||||
|
||||
rest = getMP(rest, nmp)[
|
||||
-1
|
||||
] # ignore the key data for now, we just want the comment
|
||||
comment, rest = getNS(rest) # the comment, tacked onto the end of the key blob
|
||||
|
||||
k = keys.Key.fromString(data, type="private_blob") # not wrapped in NS here
|
||||
self.factory.keys[k.blob()] = (k, comment)
|
||||
self.sendResponse(AGENT_SUCCESS, b"")
|
||||
|
||||
def agentc_REMOVE_IDENTITY(self, data):
|
||||
"""
|
||||
Remove a specific key from the agent's collection of identities.
|
||||
"""
|
||||
blob, _ = getNS(data)
|
||||
k = keys.Key.fromString(blob, type="blob")
|
||||
del self.factory.keys[k.blob()]
|
||||
self.sendResponse(AGENT_SUCCESS, b"")
|
||||
|
||||
def agentc_REMOVE_ALL_IDENTITIES(self, data):
|
||||
"""
|
||||
Remove all keys from the agent's collection of identities.
|
||||
"""
|
||||
assert data == b""
|
||||
self.factory.keys = {}
|
||||
self.sendResponse(AGENT_SUCCESS, b"")
|
||||
|
||||
# v1 messages that we ignore because we don't keep v1 keys
|
||||
# open-ssh sends both v1 and v2 commands, so we have to
|
||||
# do no-ops for v1 commands or we'll get "bad request" errors
|
||||
|
||||
def agentc_REQUEST_RSA_IDENTITIES(self, data):
|
||||
"""
|
||||
v1 message for listing RSA1 keys; superseded by
|
||||
agentc_REQUEST_IDENTITIES, which handles different key types.
|
||||
"""
|
||||
self.sendResponse(AGENT_RSA_IDENTITIES_ANSWER, struct.pack("!L", 0))
|
||||
|
||||
def agentc_REMOVE_RSA_IDENTITY(self, data):
|
||||
"""
|
||||
v1 message for removing RSA1 keys; superseded by
|
||||
agentc_REMOVE_IDENTITY, which handles different key types.
|
||||
"""
|
||||
self.sendResponse(AGENT_SUCCESS, b"")
|
||||
|
||||
def agentc_REMOVE_ALL_RSA_IDENTITIES(self, data):
|
||||
"""
|
||||
v1 message for removing all RSA1 keys; superseded by
|
||||
agentc_REMOVE_ALL_IDENTITIES, which handles different key types.
|
||||
"""
|
||||
self.sendResponse(AGENT_SUCCESS, b"")
|
||||
|
||||
|
||||
AGENTC_REQUEST_RSA_IDENTITIES = 1
|
||||
AGENT_RSA_IDENTITIES_ANSWER = 2
|
||||
AGENT_FAILURE = 5
|
||||
AGENT_SUCCESS = 6
|
||||
|
||||
AGENTC_REMOVE_RSA_IDENTITY = 8
|
||||
AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
|
||||
|
||||
AGENTC_REQUEST_IDENTITIES = 11
|
||||
AGENT_IDENTITIES_ANSWER = 12
|
||||
AGENTC_SIGN_REQUEST = 13
|
||||
AGENT_SIGN_RESPONSE = 14
|
||||
AGENTC_ADD_IDENTITY = 17
|
||||
AGENTC_REMOVE_IDENTITY = 18
|
||||
AGENTC_REMOVE_ALL_IDENTITIES = 19
|
||||
|
||||
messages = {}
|
||||
for name, value in locals().copy().items():
|
||||
if name[:7] == "AGENTC_":
|
||||
messages[value] = name[7:] # doesn't handle doubles
|
||||
312
.venv/lib/python3.12/site-packages/twisted/conch/ssh/channel.py
Normal file
312
.venv/lib/python3.12/site-packages/twisted/conch/ssh/channel.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_channel -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
The parent class for all the SSH Channels. Currently implemented channels
|
||||
are session, direct-tcp, and forwarded-tcp.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import interfaces
|
||||
from twisted.logger import Logger
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
@implementer(interfaces.ITransport)
|
||||
class SSHChannel(log.Logger):
|
||||
"""
|
||||
A class that represents a multiplexed channel over an SSH connection.
|
||||
The channel has a local window which is the maximum amount of data it will
|
||||
receive, and a remote which is the maximum amount of data the remote side
|
||||
will accept. There is also a maximum packet size for any individual data
|
||||
packet going each way.
|
||||
|
||||
@ivar name: the name of the channel.
|
||||
@type name: L{bytes}
|
||||
@ivar localWindowSize: the maximum size of the local window in bytes.
|
||||
@type localWindowSize: L{int}
|
||||
@ivar localWindowLeft: how many bytes are left in the local window.
|
||||
@type localWindowLeft: L{int}
|
||||
@ivar localMaxPacket: the maximum size of packet we will accept in bytes.
|
||||
@type localMaxPacket: L{int}
|
||||
@ivar remoteWindowLeft: how many bytes are left in the remote window.
|
||||
@type remoteWindowLeft: L{int}
|
||||
@ivar remoteMaxPacket: the maximum size of a packet the remote side will
|
||||
accept in bytes.
|
||||
@type remoteMaxPacket: L{int}
|
||||
@ivar conn: the connection this channel is multiplexed through.
|
||||
@type conn: L{SSHConnection}
|
||||
@ivar data: any data to send to the other side when the channel is
|
||||
requested.
|
||||
@type data: L{bytes}
|
||||
@ivar avatar: an avatar for the logged-in user (if a server channel)
|
||||
@ivar localClosed: True if we aren't accepting more data.
|
||||
@type localClosed: L{bool}
|
||||
@ivar remoteClosed: True if the other side isn't accepting more data.
|
||||
@type remoteClosed: L{bool}
|
||||
"""
|
||||
|
||||
_log = Logger()
|
||||
name: bytes = None # type: ignore[assignment] # only needed for client channels
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
localWindow=0,
|
||||
localMaxPacket=0,
|
||||
remoteWindow=0,
|
||||
remoteMaxPacket=0,
|
||||
conn=None,
|
||||
data=None,
|
||||
avatar=None,
|
||||
):
|
||||
self.localWindowSize = localWindow or 131072
|
||||
self.localWindowLeft = self.localWindowSize
|
||||
self.localMaxPacket = localMaxPacket or 32768
|
||||
self.remoteWindowLeft = remoteWindow
|
||||
self.remoteMaxPacket = remoteMaxPacket
|
||||
self.areWriting = 1
|
||||
self.conn = conn
|
||||
self.data = data
|
||||
self.avatar = avatar
|
||||
self.specificData = b""
|
||||
self.buf = b""
|
||||
self.extBuf = []
|
||||
self.closing = 0
|
||||
self.localClosed = 0
|
||||
self.remoteClosed = 0
|
||||
self.id = None # gets set later by SSHConnection
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__bytes__().decode("ascii")
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
"""
|
||||
Return a byte string representation of the channel
|
||||
"""
|
||||
name = self.name
|
||||
if not name:
|
||||
name = b"None"
|
||||
|
||||
return b"<SSHChannel %b (lw %d rw %d)>" % (
|
||||
name,
|
||||
self.localWindowLeft,
|
||||
self.remoteWindowLeft,
|
||||
)
|
||||
|
||||
def logPrefix(self):
|
||||
id = (self.id is not None and str(self.id)) or "unknown"
|
||||
if self.name:
|
||||
name = self.name.decode("ascii")
|
||||
else:
|
||||
name = "None"
|
||||
return f"SSHChannel {name} ({id}) on {self.conn.logPrefix()}"
|
||||
|
||||
def channelOpen(self, specificData):
|
||||
"""
|
||||
Called when the channel is opened. specificData is any data that the
|
||||
other side sent us when opening the channel.
|
||||
|
||||
@type specificData: L{bytes}
|
||||
"""
|
||||
self._log.info("channel open")
|
||||
|
||||
def openFailed(self, reason):
|
||||
"""
|
||||
Called when the open failed for some reason.
|
||||
reason.desc is a string descrption, reason.code the SSH error code.
|
||||
|
||||
@type reason: L{error.ConchError}
|
||||
"""
|
||||
self._log.error("other side refused open\nreason: {reason}", reason=reason)
|
||||
|
||||
def addWindowBytes(self, data):
|
||||
"""
|
||||
Called when bytes are added to the remote window. By default it clears
|
||||
the data buffers.
|
||||
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
self.remoteWindowLeft = self.remoteWindowLeft + data
|
||||
if not self.areWriting and not self.closing:
|
||||
self.areWriting = True
|
||||
self.startWriting()
|
||||
if self.buf:
|
||||
b = self.buf
|
||||
self.buf = b""
|
||||
self.write(b)
|
||||
if self.extBuf:
|
||||
b = self.extBuf
|
||||
self.extBuf = []
|
||||
for type, data in b:
|
||||
self.writeExtended(type, data)
|
||||
|
||||
def requestReceived(self, requestType, data):
|
||||
"""
|
||||
Called when a request is sent to this channel. By default it delegates
|
||||
to self.request_<requestType>.
|
||||
If this function returns true, the request succeeded, otherwise it
|
||||
failed.
|
||||
|
||||
@type requestType: L{bytes}
|
||||
@type data: L{bytes}
|
||||
@rtype: L{bool}
|
||||
"""
|
||||
foo = requestType.replace(b"-", b"_").decode("ascii")
|
||||
f = getattr(self, "request_" + foo, None)
|
||||
if f:
|
||||
return f(data)
|
||||
self._log.info("unhandled request for {requestType}", requestType=requestType)
|
||||
return 0
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Called when we receive data.
|
||||
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
self._log.debug("got data {data}", data=data)
|
||||
|
||||
def extReceived(self, dataType, data):
|
||||
"""
|
||||
Called when we receive extended data (usually standard error).
|
||||
|
||||
@type dataType: L{int}
|
||||
@type data: L{str}
|
||||
"""
|
||||
self._log.debug(
|
||||
"got extended data {dataType} {data!r}", dataType=dataType, data=data
|
||||
)
|
||||
|
||||
def eofReceived(self):
|
||||
"""
|
||||
Called when the other side will send no more data.
|
||||
"""
|
||||
self._log.info("remote eof")
|
||||
|
||||
def closeReceived(self):
|
||||
"""
|
||||
Called when the other side has closed the channel.
|
||||
"""
|
||||
self._log.info("remote close")
|
||||
self.loseConnection()
|
||||
|
||||
def closed(self):
|
||||
"""
|
||||
Called when the channel is closed. This means that both our side and
|
||||
the remote side have closed the channel.
|
||||
"""
|
||||
self._log.info("closed")
|
||||
|
||||
def write(self, data):
|
||||
"""
|
||||
Write some data to the channel. If there is not enough remote window
|
||||
available, buffer until it is. Otherwise, split the data into
|
||||
packets of length remoteMaxPacket and send them.
|
||||
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
if self.buf:
|
||||
self.buf += data
|
||||
return
|
||||
top = len(data)
|
||||
if top > self.remoteWindowLeft:
|
||||
data, self.buf = (
|
||||
data[: self.remoteWindowLeft],
|
||||
data[self.remoteWindowLeft :],
|
||||
)
|
||||
self.areWriting = 0
|
||||
self.stopWriting()
|
||||
top = self.remoteWindowLeft
|
||||
rmp = self.remoteMaxPacket
|
||||
write = self.conn.sendData
|
||||
r = range(0, top, rmp)
|
||||
for offset in r:
|
||||
write(self, data[offset : offset + rmp])
|
||||
self.remoteWindowLeft -= top
|
||||
if self.closing and not self.buf:
|
||||
self.loseConnection() # try again
|
||||
|
||||
def writeExtended(self, dataType, data):
|
||||
"""
|
||||
Send extended data to this channel. If there is not enough remote
|
||||
window available, buffer until there is. Otherwise, split the data
|
||||
into packets of length remoteMaxPacket and send them.
|
||||
|
||||
@type dataType: L{int}
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
if self.extBuf:
|
||||
if self.extBuf[-1][0] == dataType:
|
||||
self.extBuf[-1][1] += data
|
||||
else:
|
||||
self.extBuf.append([dataType, data])
|
||||
return
|
||||
if len(data) > self.remoteWindowLeft:
|
||||
data, self.extBuf = (
|
||||
data[: self.remoteWindowLeft],
|
||||
[[dataType, data[self.remoteWindowLeft :]]],
|
||||
)
|
||||
self.areWriting = 0
|
||||
self.stopWriting()
|
||||
while len(data) > self.remoteMaxPacket:
|
||||
self.conn.sendExtendedData(self, dataType, data[: self.remoteMaxPacket])
|
||||
data = data[self.remoteMaxPacket :]
|
||||
self.remoteWindowLeft -= self.remoteMaxPacket
|
||||
if data:
|
||||
self.conn.sendExtendedData(self, dataType, data)
|
||||
self.remoteWindowLeft -= len(data)
|
||||
if self.closing:
|
||||
self.loseConnection() # try again
|
||||
|
||||
def writeSequence(self, data):
|
||||
"""
|
||||
Part of the Transport interface. Write a list of strings to the
|
||||
channel.
|
||||
|
||||
@type data: C{list} of L{str}
|
||||
"""
|
||||
self.write(b"".join(data))
|
||||
|
||||
def loseConnection(self):
|
||||
"""
|
||||
Close the channel if there is no buferred data. Otherwise, note the
|
||||
request and return.
|
||||
"""
|
||||
self.closing = 1
|
||||
if not self.buf and not self.extBuf:
|
||||
self.conn.sendClose(self)
|
||||
|
||||
def getPeer(self):
|
||||
"""
|
||||
See: L{ITransport.getPeer}
|
||||
|
||||
@return: The remote address of this connection.
|
||||
@rtype: L{SSHTransportAddress}.
|
||||
"""
|
||||
return self.conn.transport.getPeer()
|
||||
|
||||
def getHost(self):
|
||||
"""
|
||||
See: L{ITransport.getHost}
|
||||
|
||||
@return: An address describing this side of the connection.
|
||||
@rtype: L{SSHTransportAddress}.
|
||||
"""
|
||||
return self.conn.transport.getHost()
|
||||
|
||||
def stopWriting(self):
|
||||
"""
|
||||
Called when the remote buffer is full, as a hint to stop writing.
|
||||
This can be ignored, but it can be helpful.
|
||||
"""
|
||||
|
||||
def startWriting(self):
|
||||
"""
|
||||
Called when the remote buffer has more room, as a hint to continue
|
||||
writing.
|
||||
"""
|
||||
@@ -0,0 +1,85 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_ssh -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Common functions for the SSH classes.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
|
||||
import struct
|
||||
|
||||
from cryptography.utils import int_to_bytes
|
||||
|
||||
from twisted.python.deprecate import deprecated
|
||||
from twisted.python.versions import Version
|
||||
|
||||
__all__ = ["NS", "getNS", "MP", "getMP", "ffs"]
|
||||
|
||||
|
||||
def NS(t):
|
||||
"""
|
||||
net string
|
||||
"""
|
||||
if isinstance(t, str):
|
||||
t = t.encode("utf-8")
|
||||
return struct.pack("!L", len(t)) + t
|
||||
|
||||
|
||||
def getNS(s, count=1):
|
||||
"""
|
||||
get net string
|
||||
"""
|
||||
ns = []
|
||||
c = 0
|
||||
for i in range(count):
|
||||
(l,) = struct.unpack("!L", s[c : c + 4])
|
||||
ns.append(s[c + 4 : 4 + l + c])
|
||||
c += 4 + l
|
||||
return tuple(ns) + (s[c:],)
|
||||
|
||||
|
||||
def MP(number):
|
||||
if number == 0:
|
||||
return b"\000" * 4
|
||||
assert number > 0
|
||||
bn = int_to_bytes(number)
|
||||
if ord(bn[0:1]) & 128:
|
||||
bn = b"\000" + bn
|
||||
return struct.pack(">L", len(bn)) + bn
|
||||
|
||||
|
||||
def getMP(data, count=1):
|
||||
"""
|
||||
Get multiple precision integer out of the string. A multiple precision
|
||||
integer is stored as a 4-byte length followed by length bytes of the
|
||||
integer. If count is specified, get count integers out of the string.
|
||||
The return value is a tuple of count integers followed by the rest of
|
||||
the data.
|
||||
"""
|
||||
mp = []
|
||||
c = 0
|
||||
for i in range(count):
|
||||
(length,) = struct.unpack(">L", data[c : c + 4])
|
||||
mp.append(int.from_bytes(data[c + 4 : c + 4 + length], "big"))
|
||||
c += 4 + length
|
||||
return tuple(mp) + (data[c:],)
|
||||
|
||||
|
||||
def ffs(c, s):
|
||||
"""
|
||||
first from second
|
||||
goes through the first list, looking for items in the second, returns the first one
|
||||
"""
|
||||
for i in c:
|
||||
if i in s:
|
||||
return i
|
||||
|
||||
|
||||
@deprecated(Version("Twisted", 16, 5, 0))
|
||||
def install():
|
||||
# This used to install gmpy, but is technically public API, so just do
|
||||
# nothing.
|
||||
pass
|
||||
@@ -0,0 +1,679 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_connection -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This module contains the implementation of the ssh-connection service, which
|
||||
allows access to the shell and port-forwarding.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
import string
|
||||
import struct
|
||||
|
||||
import twisted.internet.error
|
||||
from twisted.conch import error
|
||||
from twisted.conch.ssh import common, service
|
||||
from twisted.internet import defer
|
||||
from twisted.logger import Logger
|
||||
from twisted.python.compat import nativeString, networkString
|
||||
|
||||
|
||||
class SSHConnection(service.SSHService):
|
||||
"""
|
||||
An implementation of the 'ssh-connection' service. It is used to
|
||||
multiplex multiple channels over the single SSH connection.
|
||||
|
||||
@ivar localChannelID: the next number to use as a local channel ID.
|
||||
@type localChannelID: L{int}
|
||||
@ivar channels: a L{dict} mapping a local channel ID to C{SSHChannel}
|
||||
subclasses.
|
||||
@type channels: L{dict}
|
||||
@ivar localToRemoteChannel: a L{dict} mapping a local channel ID to a
|
||||
remote channel ID.
|
||||
@type localToRemoteChannel: L{dict}
|
||||
@ivar channelsToRemoteChannel: a L{dict} mapping a C{SSHChannel} subclass
|
||||
to remote channel ID.
|
||||
@type channelsToRemoteChannel: L{dict}
|
||||
@ivar deferreds: a L{dict} mapping a local channel ID to a C{list} of
|
||||
C{Deferreds} for outstanding channel requests. Also, the 'global'
|
||||
key stores the C{list} of pending global request C{Deferred}s.
|
||||
"""
|
||||
|
||||
name = b"ssh-connection"
|
||||
_log = Logger()
|
||||
|
||||
def __init__(self):
|
||||
self.localChannelID = 0 # this is the current # to use for channel ID
|
||||
# local channel ID -> remote channel ID
|
||||
self.localToRemoteChannel = {}
|
||||
# local channel ID -> subclass of SSHChannel
|
||||
self.channels = {}
|
||||
# subclass of SSHChannel -> remote channel ID
|
||||
self.channelsToRemoteChannel = {}
|
||||
# local channel -> list of deferreds for pending requests
|
||||
# or 'global' -> list of deferreds for global requests
|
||||
self.deferreds = {"global": []}
|
||||
|
||||
self.transport = None # gets set later
|
||||
|
||||
def serviceStarted(self):
|
||||
if hasattr(self.transport, "avatar"):
|
||||
self.transport.avatar.conn = self
|
||||
|
||||
def serviceStopped(self):
|
||||
"""
|
||||
Called when the connection is stopped.
|
||||
"""
|
||||
# Close any fully open channels
|
||||
for channel in list(self.channelsToRemoteChannel.keys()):
|
||||
self.channelClosed(channel)
|
||||
# Indicate failure to any channels that were in the process of
|
||||
# opening but not yet open.
|
||||
while self.channels:
|
||||
(_, channel) = self.channels.popitem()
|
||||
channel.openFailed(twisted.internet.error.ConnectionLost())
|
||||
# Errback any unfinished global requests.
|
||||
self._cleanupGlobalDeferreds()
|
||||
|
||||
def _cleanupGlobalDeferreds(self):
|
||||
"""
|
||||
All pending requests that have returned a deferred must be errbacked
|
||||
when this service is stopped, otherwise they might be left uncalled and
|
||||
uncallable.
|
||||
"""
|
||||
for d in self.deferreds["global"]:
|
||||
d.errback(error.ConchError("Connection stopped."))
|
||||
del self.deferreds["global"][:]
|
||||
|
||||
# packet methods
|
||||
def ssh_GLOBAL_REQUEST(self, packet):
|
||||
"""
|
||||
The other side has made a global request. Payload::
|
||||
string request type
|
||||
bool want reply
|
||||
<request specific data>
|
||||
|
||||
This dispatches to self.gotGlobalRequest.
|
||||
"""
|
||||
requestType, rest = common.getNS(packet)
|
||||
wantReply, rest = ord(rest[0:1]), rest[1:]
|
||||
ret = self.gotGlobalRequest(requestType, rest)
|
||||
if wantReply:
|
||||
reply = MSG_REQUEST_FAILURE
|
||||
data = b""
|
||||
if ret:
|
||||
reply = MSG_REQUEST_SUCCESS
|
||||
if isinstance(ret, (tuple, list)):
|
||||
data = ret[1]
|
||||
self.transport.sendPacket(reply, data)
|
||||
|
||||
def ssh_REQUEST_SUCCESS(self, packet):
|
||||
"""
|
||||
Our global request succeeded. Get the appropriate Deferred and call
|
||||
it back with the packet we received.
|
||||
"""
|
||||
self._log.debug("global request success")
|
||||
self.deferreds["global"].pop(0).callback(packet)
|
||||
|
||||
def ssh_REQUEST_FAILURE(self, packet):
|
||||
"""
|
||||
Our global request failed. Get the appropriate Deferred and errback
|
||||
it with the packet we received.
|
||||
"""
|
||||
self._log.debug("global request failure")
|
||||
self.deferreds["global"].pop(0).errback(
|
||||
error.ConchError("global request failed", packet)
|
||||
)
|
||||
|
||||
def ssh_CHANNEL_OPEN(self, packet):
|
||||
"""
|
||||
The other side wants to get a channel. Payload::
|
||||
string channel name
|
||||
uint32 remote channel number
|
||||
uint32 remote window size
|
||||
uint32 remote maximum packet size
|
||||
<channel specific data>
|
||||
|
||||
We get a channel from self.getChannel(), give it a local channel number
|
||||
and notify the other side. Then notify the channel by calling its
|
||||
channelOpen method.
|
||||
"""
|
||||
channelType, rest = common.getNS(packet)
|
||||
senderChannel, windowSize, maxPacket = struct.unpack(">3L", rest[:12])
|
||||
packet = rest[12:]
|
||||
try:
|
||||
channel = self.getChannel(channelType, windowSize, maxPacket, packet)
|
||||
localChannel = self.localChannelID
|
||||
self.localChannelID += 1
|
||||
channel.id = localChannel
|
||||
self.channels[localChannel] = channel
|
||||
self.channelsToRemoteChannel[channel] = senderChannel
|
||||
self.localToRemoteChannel[localChannel] = senderChannel
|
||||
openConfirmPacket = (
|
||||
struct.pack(
|
||||
">4L",
|
||||
senderChannel,
|
||||
localChannel,
|
||||
channel.localWindowSize,
|
||||
channel.localMaxPacket,
|
||||
)
|
||||
+ channel.specificData
|
||||
)
|
||||
self.transport.sendPacket(MSG_CHANNEL_OPEN_CONFIRMATION, openConfirmPacket)
|
||||
channel.channelOpen(packet)
|
||||
except Exception as e:
|
||||
self._log.failure("channel open failed")
|
||||
if isinstance(e, error.ConchError):
|
||||
textualInfo, reason = e.args
|
||||
if isinstance(textualInfo, int):
|
||||
# See #3657 and #3071
|
||||
textualInfo, reason = reason, textualInfo
|
||||
else:
|
||||
reason = OPEN_CONNECT_FAILED
|
||||
textualInfo = "unknown failure"
|
||||
self.transport.sendPacket(
|
||||
MSG_CHANNEL_OPEN_FAILURE,
|
||||
struct.pack(">2L", senderChannel, reason)
|
||||
+ common.NS(networkString(textualInfo))
|
||||
+ common.NS(b""),
|
||||
)
|
||||
|
||||
def ssh_CHANNEL_OPEN_CONFIRMATION(self, packet):
|
||||
"""
|
||||
The other side accepted our MSG_CHANNEL_OPEN request. Payload::
|
||||
uint32 local channel number
|
||||
uint32 remote channel number
|
||||
uint32 remote window size
|
||||
uint32 remote maximum packet size
|
||||
<channel specific data>
|
||||
|
||||
Find the channel using the local channel number and notify its
|
||||
channelOpen method.
|
||||
"""
|
||||
(localChannel, remoteChannel, windowSize, maxPacket) = struct.unpack(
|
||||
">4L", packet[:16]
|
||||
)
|
||||
specificData = packet[16:]
|
||||
channel = self.channels[localChannel]
|
||||
channel.conn = self
|
||||
self.localToRemoteChannel[localChannel] = remoteChannel
|
||||
self.channelsToRemoteChannel[channel] = remoteChannel
|
||||
channel.remoteWindowLeft = windowSize
|
||||
channel.remoteMaxPacket = maxPacket
|
||||
channel.channelOpen(specificData)
|
||||
|
||||
def ssh_CHANNEL_OPEN_FAILURE(self, packet):
|
||||
"""
|
||||
The other side did not accept our MSG_CHANNEL_OPEN request. Payload::
|
||||
uint32 local channel number
|
||||
uint32 reason code
|
||||
string reason description
|
||||
|
||||
Find the channel using the local channel number and notify it by
|
||||
calling its openFailed() method.
|
||||
"""
|
||||
localChannel, reasonCode = struct.unpack(">2L", packet[:8])
|
||||
reasonDesc = common.getNS(packet[8:])[0]
|
||||
channel = self.channels[localChannel]
|
||||
del self.channels[localChannel]
|
||||
channel.conn = self
|
||||
reason = error.ConchError(reasonDesc, reasonCode)
|
||||
channel.openFailed(reason)
|
||||
|
||||
def ssh_CHANNEL_WINDOW_ADJUST(self, packet):
|
||||
"""
|
||||
The other side is adding bytes to its window. Payload::
|
||||
uint32 local channel number
|
||||
uint32 bytes to add
|
||||
|
||||
Call the channel's addWindowBytes() method to add new bytes to the
|
||||
remote window.
|
||||
"""
|
||||
localChannel, bytesToAdd = struct.unpack(">2L", packet[:8])
|
||||
channel = self.channels[localChannel]
|
||||
channel.addWindowBytes(bytesToAdd)
|
||||
|
||||
def ssh_CHANNEL_DATA(self, packet):
|
||||
"""
|
||||
The other side is sending us data. Payload::
|
||||
uint32 local channel number
|
||||
string data
|
||||
|
||||
Check to make sure the other side hasn't sent too much data (more
|
||||
than what's in the window, or more than the maximum packet size). If
|
||||
they have, close the channel. Otherwise, decrease the available
|
||||
window and pass the data to the channel's dataReceived().
|
||||
"""
|
||||
localChannel, dataLength = struct.unpack(">2L", packet[:8])
|
||||
channel = self.channels[localChannel]
|
||||
# XXX should this move to dataReceived to put client in charge?
|
||||
if (
|
||||
dataLength > channel.localWindowLeft or dataLength > channel.localMaxPacket
|
||||
): # more data than we want
|
||||
self._log.error("too much data")
|
||||
self.sendClose(channel)
|
||||
return
|
||||
# packet = packet[:channel.localWindowLeft+4]
|
||||
data = common.getNS(packet[4:])[0]
|
||||
channel.localWindowLeft -= dataLength
|
||||
if channel.localWindowLeft < channel.localWindowSize // 2:
|
||||
self.adjustWindow(
|
||||
channel, channel.localWindowSize - channel.localWindowLeft
|
||||
)
|
||||
channel.dataReceived(data)
|
||||
|
||||
def ssh_CHANNEL_EXTENDED_DATA(self, packet):
|
||||
"""
|
||||
The other side is sending us exteneded data. Payload::
|
||||
uint32 local channel number
|
||||
uint32 type code
|
||||
string data
|
||||
|
||||
Check to make sure the other side hasn't sent too much data (more
|
||||
than what's in the window, or than the maximum packet size). If
|
||||
they have, close the channel. Otherwise, decrease the available
|
||||
window and pass the data and type code to the channel's
|
||||
extReceived().
|
||||
"""
|
||||
localChannel, typeCode, dataLength = struct.unpack(">3L", packet[:12])
|
||||
channel = self.channels[localChannel]
|
||||
if dataLength > channel.localWindowLeft or dataLength > channel.localMaxPacket:
|
||||
self._log.error("too much extdata")
|
||||
self.sendClose(channel)
|
||||
return
|
||||
data = common.getNS(packet[8:])[0]
|
||||
channel.localWindowLeft -= dataLength
|
||||
if channel.localWindowLeft < channel.localWindowSize // 2:
|
||||
self.adjustWindow(
|
||||
channel, channel.localWindowSize - channel.localWindowLeft
|
||||
)
|
||||
channel.extReceived(typeCode, data)
|
||||
|
||||
def ssh_CHANNEL_EOF(self, packet):
|
||||
"""
|
||||
The other side is not sending any more data. Payload::
|
||||
uint32 local channel number
|
||||
|
||||
Notify the channel by calling its eofReceived() method.
|
||||
"""
|
||||
localChannel = struct.unpack(">L", packet[:4])[0]
|
||||
channel = self.channels[localChannel]
|
||||
channel.eofReceived()
|
||||
|
||||
def ssh_CHANNEL_CLOSE(self, packet):
|
||||
"""
|
||||
The other side is closing its end; it does not want to receive any
|
||||
more data. Payload::
|
||||
uint32 local channel number
|
||||
|
||||
Notify the channnel by calling its closeReceived() method. If
|
||||
the channel has also sent a close message, call self.channelClosed().
|
||||
"""
|
||||
localChannel = struct.unpack(">L", packet[:4])[0]
|
||||
channel = self.channels[localChannel]
|
||||
channel.closeReceived()
|
||||
channel.remoteClosed = True
|
||||
if channel.localClosed and channel.remoteClosed:
|
||||
self.channelClosed(channel)
|
||||
|
||||
def ssh_CHANNEL_REQUEST(self, packet):
|
||||
"""
|
||||
The other side is sending a request to a channel. Payload::
|
||||
uint32 local channel number
|
||||
string request name
|
||||
bool want reply
|
||||
<request specific data>
|
||||
|
||||
Pass the message to the channel's requestReceived method. If the
|
||||
other side wants a reply, add callbacks which will send the
|
||||
reply.
|
||||
"""
|
||||
localChannel = struct.unpack(">L", packet[:4])[0]
|
||||
requestType, rest = common.getNS(packet[4:])
|
||||
wantReply = ord(rest[0:1])
|
||||
channel = self.channels[localChannel]
|
||||
d = defer.maybeDeferred(channel.requestReceived, requestType, rest[1:])
|
||||
if wantReply:
|
||||
d.addCallback(self._cbChannelRequest, localChannel)
|
||||
d.addErrback(self._ebChannelRequest, localChannel)
|
||||
return d
|
||||
|
||||
def _cbChannelRequest(self, result, localChannel):
|
||||
"""
|
||||
Called back if the other side wanted a reply to a channel request. If
|
||||
the result is true, send a MSG_CHANNEL_SUCCESS. Otherwise, raise
|
||||
a C{error.ConchError}
|
||||
|
||||
@param result: the value returned from the channel's requestReceived()
|
||||
method. If it's False, the request failed.
|
||||
@type result: L{bool}
|
||||
@param localChannel: the local channel ID of the channel to which the
|
||||
request was made.
|
||||
@type localChannel: L{int}
|
||||
@raises ConchError: if the result is False.
|
||||
"""
|
||||
if not result:
|
||||
raise error.ConchError("failed request")
|
||||
self.transport.sendPacket(
|
||||
MSG_CHANNEL_SUCCESS,
|
||||
struct.pack(">L", self.localToRemoteChannel[localChannel]),
|
||||
)
|
||||
|
||||
def _ebChannelRequest(self, result, localChannel):
|
||||
"""
|
||||
Called if the other wisde wanted a reply to the channel requeset and
|
||||
the channel request failed.
|
||||
|
||||
@param result: a Failure, but it's not used.
|
||||
@param localChannel: the local channel ID of the channel to which the
|
||||
request was made.
|
||||
@type localChannel: L{int}
|
||||
"""
|
||||
self.transport.sendPacket(
|
||||
MSG_CHANNEL_FAILURE,
|
||||
struct.pack(">L", self.localToRemoteChannel[localChannel]),
|
||||
)
|
||||
|
||||
def ssh_CHANNEL_SUCCESS(self, packet):
|
||||
"""
|
||||
Our channel request to the other side succeeded. Payload::
|
||||
uint32 local channel number
|
||||
|
||||
Get the C{Deferred} out of self.deferreds and call it back.
|
||||
"""
|
||||
localChannel = struct.unpack(">L", packet[:4])[0]
|
||||
if self.deferreds.get(localChannel):
|
||||
d = self.deferreds[localChannel].pop(0)
|
||||
d.callback("")
|
||||
|
||||
def ssh_CHANNEL_FAILURE(self, packet):
|
||||
"""
|
||||
Our channel request to the other side failed. Payload::
|
||||
uint32 local channel number
|
||||
|
||||
Get the C{Deferred} out of self.deferreds and errback it with a
|
||||
C{error.ConchError}.
|
||||
"""
|
||||
localChannel = struct.unpack(">L", packet[:4])[0]
|
||||
if self.deferreds.get(localChannel):
|
||||
d = self.deferreds[localChannel].pop(0)
|
||||
d.errback(error.ConchError("channel request failed"))
|
||||
|
||||
# methods for users of the connection to call
|
||||
|
||||
def sendGlobalRequest(self, request, data, wantReply=0):
|
||||
"""
|
||||
Send a global request for this connection. Current this is only used
|
||||
for remote->local TCP forwarding.
|
||||
|
||||
@type request: L{bytes}
|
||||
@type data: L{bytes}
|
||||
@type wantReply: L{bool}
|
||||
@rtype: C{Deferred}/L{None}
|
||||
"""
|
||||
self.transport.sendPacket(
|
||||
MSG_GLOBAL_REQUEST,
|
||||
common.NS(request) + (wantReply and b"\xff" or b"\x00") + data,
|
||||
)
|
||||
if wantReply:
|
||||
d = defer.Deferred()
|
||||
self.deferreds["global"].append(d)
|
||||
return d
|
||||
|
||||
def openChannel(self, channel, extra=b""):
|
||||
"""
|
||||
Open a new channel on this connection.
|
||||
|
||||
@type channel: subclass of C{SSHChannel}
|
||||
@type extra: L{bytes}
|
||||
"""
|
||||
self._log.info(
|
||||
"opening channel {id} with {localWindowSize} {localMaxPacket}",
|
||||
id=self.localChannelID,
|
||||
localWindowSize=channel.localWindowSize,
|
||||
localMaxPacket=channel.localMaxPacket,
|
||||
)
|
||||
self.transport.sendPacket(
|
||||
MSG_CHANNEL_OPEN,
|
||||
common.NS(channel.name)
|
||||
+ struct.pack(
|
||||
">3L",
|
||||
self.localChannelID,
|
||||
channel.localWindowSize,
|
||||
channel.localMaxPacket,
|
||||
)
|
||||
+ extra,
|
||||
)
|
||||
channel.id = self.localChannelID
|
||||
self.channels[self.localChannelID] = channel
|
||||
self.localChannelID += 1
|
||||
|
||||
def sendRequest(self, channel, requestType, data, wantReply=0):
|
||||
"""
|
||||
Send a request to a channel.
|
||||
|
||||
@type channel: subclass of C{SSHChannel}
|
||||
@type requestType: L{bytes}
|
||||
@type data: L{bytes}
|
||||
@type wantReply: L{bool}
|
||||
@rtype: C{Deferred}/L{None}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return
|
||||
self._log.debug("sending request {requestType}", requestType=requestType)
|
||||
self.transport.sendPacket(
|
||||
MSG_CHANNEL_REQUEST,
|
||||
struct.pack(">L", self.channelsToRemoteChannel[channel])
|
||||
+ common.NS(requestType)
|
||||
+ (b"\1" if wantReply else b"\0")
|
||||
+ data,
|
||||
)
|
||||
if wantReply:
|
||||
d = defer.Deferred()
|
||||
self.deferreds.setdefault(channel.id, []).append(d)
|
||||
return d
|
||||
|
||||
def adjustWindow(self, channel, bytesToAdd):
|
||||
"""
|
||||
Tell the other side that we will receive more data. This should not
|
||||
normally need to be called as it is managed automatically.
|
||||
|
||||
@type channel: subclass of L{SSHChannel}
|
||||
@type bytesToAdd: L{int}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return # we're already closed
|
||||
packet = struct.pack(">2L", self.channelsToRemoteChannel[channel], bytesToAdd)
|
||||
self.transport.sendPacket(MSG_CHANNEL_WINDOW_ADJUST, packet)
|
||||
self._log.debug(
|
||||
"adding {bytesToAdd} to {localWindowLeft} in channel {id}",
|
||||
bytesToAdd=bytesToAdd,
|
||||
localWindowLeft=channel.localWindowLeft,
|
||||
id=channel.id,
|
||||
)
|
||||
channel.localWindowLeft += bytesToAdd
|
||||
|
||||
def sendData(self, channel, data):
|
||||
"""
|
||||
Send data to a channel. This should not normally be used: instead use
|
||||
channel.write(data) as it manages the window automatically.
|
||||
|
||||
@type channel: subclass of L{SSHChannel}
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return # we're already closed
|
||||
self.transport.sendPacket(
|
||||
MSG_CHANNEL_DATA,
|
||||
struct.pack(">L", self.channelsToRemoteChannel[channel]) + common.NS(data),
|
||||
)
|
||||
|
||||
def sendExtendedData(self, channel, dataType, data):
|
||||
"""
|
||||
Send extended data to a channel. This should not normally be used:
|
||||
instead use channel.writeExtendedData(data, dataType) as it manages
|
||||
the window automatically.
|
||||
|
||||
@type channel: subclass of L{SSHChannel}
|
||||
@type dataType: L{int}
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return # we're already closed
|
||||
self.transport.sendPacket(
|
||||
MSG_CHANNEL_EXTENDED_DATA,
|
||||
struct.pack(">2L", self.channelsToRemoteChannel[channel], dataType)
|
||||
+ common.NS(data),
|
||||
)
|
||||
|
||||
def sendEOF(self, channel):
|
||||
"""
|
||||
Send an EOF (End of File) for a channel.
|
||||
|
||||
@type channel: subclass of L{SSHChannel}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return # we're already closed
|
||||
self._log.debug("sending eof")
|
||||
self.transport.sendPacket(
|
||||
MSG_CHANNEL_EOF, struct.pack(">L", self.channelsToRemoteChannel[channel])
|
||||
)
|
||||
|
||||
def sendClose(self, channel):
|
||||
"""
|
||||
Close a channel.
|
||||
|
||||
@type channel: subclass of L{SSHChannel}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return # we're already closed
|
||||
self._log.info("sending close {id}", id=channel.id)
|
||||
self.transport.sendPacket(
|
||||
MSG_CHANNEL_CLOSE, struct.pack(">L", self.channelsToRemoteChannel[channel])
|
||||
)
|
||||
channel.localClosed = True
|
||||
if channel.localClosed and channel.remoteClosed:
|
||||
self.channelClosed(channel)
|
||||
|
||||
# methods to override
|
||||
def getChannel(self, channelType, windowSize, maxPacket, data):
|
||||
"""
|
||||
The other side requested a channel of some sort.
|
||||
channelType is the type of channel being requested,
|
||||
windowSize is the initial size of the remote window,
|
||||
maxPacket is the largest packet we should send,
|
||||
data is any other packet data (often nothing).
|
||||
|
||||
We return a subclass of L{SSHChannel}.
|
||||
|
||||
By default, this dispatches to a method 'channel_channelType' with any
|
||||
non-alphanumerics in the channelType replace with _'s. If it cannot
|
||||
find a suitable method, it returns an OPEN_UNKNOWN_CHANNEL_TYPE error.
|
||||
The method is called with arguments of windowSize, maxPacket, data.
|
||||
|
||||
@type channelType: L{bytes}
|
||||
@type windowSize: L{int}
|
||||
@type maxPacket: L{int}
|
||||
@type data: L{bytes}
|
||||
@rtype: subclass of L{SSHChannel}/L{tuple}
|
||||
"""
|
||||
self._log.debug("got channel {channelType!r} request", channelType=channelType)
|
||||
if hasattr(self.transport, "avatar"): # this is a server!
|
||||
chan = self.transport.avatar.lookupChannel(
|
||||
channelType, windowSize, maxPacket, data
|
||||
)
|
||||
else:
|
||||
channelType = channelType.translate(TRANSLATE_TABLE)
|
||||
attr = "channel_%s" % nativeString(channelType)
|
||||
f = getattr(self, attr, None)
|
||||
if f is not None:
|
||||
chan = f(windowSize, maxPacket, data)
|
||||
else:
|
||||
chan = None
|
||||
if chan is None:
|
||||
raise error.ConchError("unknown channel", OPEN_UNKNOWN_CHANNEL_TYPE)
|
||||
else:
|
||||
chan.conn = self
|
||||
return chan
|
||||
|
||||
def gotGlobalRequest(self, requestType, data):
|
||||
"""
|
||||
We got a global request. pretty much, this is just used by the client
|
||||
to request that we forward a port from the server to the client.
|
||||
Returns either:
|
||||
- 1: request accepted
|
||||
- 1, <data>: request accepted with request specific data
|
||||
- 0: request denied
|
||||
|
||||
By default, this dispatches to a method 'global_requestType' with
|
||||
-'s in requestType replaced with _'s. The found method is passed data.
|
||||
If this method cannot be found, this method returns 0. Otherwise, it
|
||||
returns the return value of that method.
|
||||
|
||||
@type requestType: L{bytes}
|
||||
@type data: L{bytes}
|
||||
@rtype: L{int}/L{tuple}
|
||||
"""
|
||||
self._log.debug("got global {requestType} request", requestType=requestType)
|
||||
if hasattr(self.transport, "avatar"): # this is a server!
|
||||
return self.transport.avatar.gotGlobalRequest(requestType, data)
|
||||
|
||||
requestType = nativeString(requestType.replace(b"-", b"_"))
|
||||
f = getattr(self, "global_%s" % requestType, None)
|
||||
if not f:
|
||||
return 0
|
||||
return f(data)
|
||||
|
||||
def channelClosed(self, channel):
|
||||
"""
|
||||
Called when a channel is closed.
|
||||
It clears the local state related to the channel, and calls
|
||||
channel.closed().
|
||||
MAKE SURE YOU CALL THIS METHOD, even if you subclass L{SSHConnection}.
|
||||
If you don't, things will break mysteriously.
|
||||
|
||||
@type channel: L{SSHChannel}
|
||||
"""
|
||||
if channel in self.channelsToRemoteChannel: # actually open
|
||||
channel.localClosed = channel.remoteClosed = True
|
||||
del self.localToRemoteChannel[channel.id]
|
||||
del self.channels[channel.id]
|
||||
del self.channelsToRemoteChannel[channel]
|
||||
for d in self.deferreds.pop(channel.id, []):
|
||||
d.errback(error.ConchError("Channel closed."))
|
||||
channel.closed()
|
||||
|
||||
|
||||
MSG_GLOBAL_REQUEST = 80
|
||||
MSG_REQUEST_SUCCESS = 81
|
||||
MSG_REQUEST_FAILURE = 82
|
||||
MSG_CHANNEL_OPEN = 90
|
||||
MSG_CHANNEL_OPEN_CONFIRMATION = 91
|
||||
MSG_CHANNEL_OPEN_FAILURE = 92
|
||||
MSG_CHANNEL_WINDOW_ADJUST = 93
|
||||
MSG_CHANNEL_DATA = 94
|
||||
MSG_CHANNEL_EXTENDED_DATA = 95
|
||||
MSG_CHANNEL_EOF = 96
|
||||
MSG_CHANNEL_CLOSE = 97
|
||||
MSG_CHANNEL_REQUEST = 98
|
||||
MSG_CHANNEL_SUCCESS = 99
|
||||
MSG_CHANNEL_FAILURE = 100
|
||||
|
||||
OPEN_ADMINISTRATIVELY_PROHIBITED = 1
|
||||
OPEN_CONNECT_FAILED = 2
|
||||
OPEN_UNKNOWN_CHANNEL_TYPE = 3
|
||||
OPEN_RESOURCE_SHORTAGE = 4
|
||||
|
||||
# From RFC 4254
|
||||
EXTENDED_DATA_STDERR = 1
|
||||
|
||||
messages = {}
|
||||
for name, value in locals().copy().items():
|
||||
if name[:4] == "MSG_":
|
||||
messages[value] = name # Doesn't handle doubles
|
||||
|
||||
alphanums = networkString(string.ascii_letters + string.digits)
|
||||
TRANSLATE_TABLE = bytes(i if i in alphanums else ord("_") for i in range(256))
|
||||
SSHConnection.protocolMessages = messages
|
||||
129
.venv/lib/python3.12/site-packages/twisted/conch/ssh/factory.py
Normal file
129
.venv/lib/python3.12/site-packages/twisted/conch/ssh/factory.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
A Factory for SSH servers.
|
||||
|
||||
See also L{twisted.conch.openssh_compat.factory} for OpenSSH compatibility.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
|
||||
import random
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from twisted.conch import error
|
||||
from twisted.conch.ssh import _kex, connection, transport, userauth
|
||||
from twisted.internet import protocol
|
||||
from twisted.logger import Logger
|
||||
|
||||
|
||||
class SSHFactory(protocol.Factory):
|
||||
"""
|
||||
A Factory for SSH servers.
|
||||
"""
|
||||
|
||||
primes: Optional[Dict[int, List[Tuple[int, int]]]]
|
||||
|
||||
_log = Logger()
|
||||
protocol = transport.SSHServerTransport
|
||||
|
||||
services = {
|
||||
b"ssh-userauth": userauth.SSHUserAuthServer,
|
||||
b"ssh-connection": connection.SSHConnection,
|
||||
}
|
||||
|
||||
def startFactory(self) -> None:
|
||||
"""
|
||||
Check for public and private keys.
|
||||
"""
|
||||
if not hasattr(self, "publicKeys"):
|
||||
self.publicKeys = self.getPublicKeys()
|
||||
if not hasattr(self, "privateKeys"):
|
||||
self.privateKeys = self.getPrivateKeys()
|
||||
if not self.publicKeys or not self.privateKeys:
|
||||
raise error.ConchError("no host keys, failing")
|
||||
if not hasattr(self, "primes"):
|
||||
self.primes = self.getPrimes()
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
"""
|
||||
Create an instance of the server side of the SSH protocol.
|
||||
|
||||
@type addr: L{twisted.internet.interfaces.IAddress} provider
|
||||
@param addr: The address at which the server will listen.
|
||||
|
||||
@rtype: L{twisted.conch.ssh.transport.SSHServerTransport}
|
||||
@return: The built transport.
|
||||
"""
|
||||
t = protocol.Factory.buildProtocol(self, addr)
|
||||
t.supportedPublicKeys = list(
|
||||
chain.from_iterable(
|
||||
key.supportedSignatureAlgorithms() for key in self.privateKeys.values()
|
||||
)
|
||||
)
|
||||
if not self.primes:
|
||||
self._log.info(
|
||||
"disabling non-fixed-group key exchange algorithms "
|
||||
"because we cannot find moduli file"
|
||||
)
|
||||
t.supportedKeyExchanges = [
|
||||
kexAlgorithm
|
||||
for kexAlgorithm in t.supportedKeyExchanges
|
||||
if _kex.isFixedGroup(kexAlgorithm) or _kex.isEllipticCurve(kexAlgorithm)
|
||||
]
|
||||
return t
|
||||
|
||||
def getPublicKeys(self):
|
||||
"""
|
||||
Called when the factory is started to get the public portions of the
|
||||
servers host keys. Returns a dictionary mapping SSH key types to
|
||||
public key strings.
|
||||
|
||||
@rtype: L{dict}
|
||||
"""
|
||||
raise NotImplementedError("getPublicKeys unimplemented")
|
||||
|
||||
def getPrivateKeys(self):
|
||||
"""
|
||||
Called when the factory is started to get the private portions of the
|
||||
servers host keys. Returns a dictionary mapping SSH key types to
|
||||
L{twisted.conch.ssh.keys.Key} objects.
|
||||
|
||||
@rtype: L{dict}
|
||||
"""
|
||||
raise NotImplementedError("getPrivateKeys unimplemented")
|
||||
|
||||
def getPrimes(self) -> Optional[Dict[int, List[Tuple[int, int]]]]:
|
||||
"""
|
||||
Called when the factory is started to get Diffie-Hellman generators and
|
||||
primes to use. Returns a dictionary mapping number of bits to lists of
|
||||
tuple of (generator, prime).
|
||||
"""
|
||||
|
||||
def getDHPrime(self, bits: int) -> Tuple[int, int]:
|
||||
"""
|
||||
Return a tuple of (g, p) for a Diffe-Hellman process, with p being as
|
||||
close to C{bits} bits as possible.
|
||||
"""
|
||||
|
||||
def keyfunc(i: int) -> int:
|
||||
return abs(i - bits)
|
||||
|
||||
assert self.primes is not None, "Factory should have been started by now."
|
||||
primesKeys = sorted(self.primes.keys(), key=keyfunc)
|
||||
realBits = primesKeys[0]
|
||||
return random.choice(self.primes[realBits])
|
||||
|
||||
def getService(self, transport, service):
|
||||
"""
|
||||
Return a class to use as a service for the given transport.
|
||||
|
||||
@type transport: L{transport.SSHServerTransport}
|
||||
@type service: L{bytes}
|
||||
@rtype: subclass of L{service.SSHService}
|
||||
"""
|
||||
if service == b"ssh-userauth" or hasattr(transport, "avatar"):
|
||||
return self.services[service]
|
||||
1069
.venv/lib/python3.12/site-packages/twisted/conch/ssh/filetransfer.py
Normal file
1069
.venv/lib/python3.12/site-packages/twisted/conch/ssh/filetransfer.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,272 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This module contains the implementation of the TCP forwarding, which allows
|
||||
clients and servers to forward arbitrary TCP data across the connection.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.conch.ssh import channel, common
|
||||
from twisted.internet import protocol, reactor
|
||||
from twisted.internet.endpoints import HostnameEndpoint, connectProtocol
|
||||
|
||||
|
||||
class SSHListenForwardingFactory(protocol.Factory):
|
||||
def __init__(self, connection, hostport, klass):
|
||||
self.conn = connection
|
||||
self.hostport = hostport # tuple
|
||||
self.klass = klass
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
channel = self.klass(conn=self.conn)
|
||||
client = SSHForwardingClient(channel)
|
||||
channel.client = client
|
||||
addrTuple = (addr.host, addr.port)
|
||||
channelOpenData = packOpen_direct_tcpip(self.hostport, addrTuple)
|
||||
self.conn.openChannel(channel, channelOpenData)
|
||||
return client
|
||||
|
||||
|
||||
class SSHListenForwardingChannel(channel.SSHChannel):
|
||||
def channelOpen(self, specificData):
|
||||
self._log.info("opened forwarding channel {id}", id=self.id)
|
||||
if len(self.client.buf) > 1:
|
||||
b = self.client.buf[1:]
|
||||
self.write(b)
|
||||
self.client.buf = b""
|
||||
|
||||
def openFailed(self, reason):
|
||||
self.closed()
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.client.transport.write(data)
|
||||
|
||||
def eofReceived(self):
|
||||
self.client.transport.loseConnection()
|
||||
|
||||
def closed(self):
|
||||
if hasattr(self, "client"):
|
||||
self._log.info("closing local forwarding channel {id}", id=self.id)
|
||||
self.client.transport.loseConnection()
|
||||
del self.client
|
||||
|
||||
|
||||
class SSHListenClientForwardingChannel(SSHListenForwardingChannel):
|
||||
name = b"direct-tcpip"
|
||||
|
||||
|
||||
class SSHListenServerForwardingChannel(SSHListenForwardingChannel):
|
||||
name = b"forwarded-tcpip"
|
||||
|
||||
|
||||
class SSHConnectForwardingChannel(channel.SSHChannel):
|
||||
"""
|
||||
Channel used for handling server side forwarding request.
|
||||
It acts as a client for the remote forwarding destination.
|
||||
|
||||
@ivar hostport: C{(host, port)} requested by client as forwarding
|
||||
destination.
|
||||
@type hostport: L{tuple} or a C{sequence}
|
||||
|
||||
@ivar client: Protocol connected to the forwarding destination.
|
||||
@type client: L{protocol.Protocol}
|
||||
|
||||
@ivar clientBuf: Data received while forwarding channel is not yet
|
||||
connected.
|
||||
@type clientBuf: L{bytes}
|
||||
|
||||
@var _reactor: Reactor used for TCP connections.
|
||||
@type _reactor: A reactor.
|
||||
|
||||
@ivar _channelOpenDeferred: Deferred used in testing to check the
|
||||
result of C{channelOpen}.
|
||||
@type _channelOpenDeferred: L{twisted.internet.defer.Deferred}
|
||||
"""
|
||||
|
||||
_reactor = reactor
|
||||
|
||||
def __init__(self, hostport, *args, **kw):
|
||||
channel.SSHChannel.__init__(self, *args, **kw)
|
||||
self.hostport = hostport
|
||||
self.client = None
|
||||
self.clientBuf = b""
|
||||
|
||||
def channelOpen(self, specificData):
|
||||
"""
|
||||
See: L{channel.SSHChannel}
|
||||
"""
|
||||
self._log.info(
|
||||
"connecting to {host}:{port}", host=self.hostport[0], port=self.hostport[1]
|
||||
)
|
||||
ep = HostnameEndpoint(self._reactor, self.hostport[0], self.hostport[1])
|
||||
d = connectProtocol(ep, SSHForwardingClient(self))
|
||||
d.addCallbacks(self._setClient, self._close)
|
||||
self._channelOpenDeferred = d
|
||||
|
||||
def _setClient(self, client):
|
||||
"""
|
||||
Called when the connection was established to the forwarding
|
||||
destination.
|
||||
|
||||
@param client: Client protocol connected to the forwarding destination.
|
||||
@type client: L{protocol.Protocol}
|
||||
"""
|
||||
self.client = client
|
||||
self._log.info(
|
||||
"connected to {host}:{port}", host=self.hostport[0], port=self.hostport[1]
|
||||
)
|
||||
if self.clientBuf:
|
||||
self.client.transport.write(self.clientBuf)
|
||||
self.clientBuf = None
|
||||
if self.client.buf[1:]:
|
||||
self.write(self.client.buf[1:])
|
||||
self.client.buf = b""
|
||||
|
||||
def _close(self, reason):
|
||||
"""
|
||||
Called when failed to connect to the forwarding destination.
|
||||
|
||||
@param reason: Reason why connection failed.
|
||||
@type reason: L{twisted.python.failure.Failure}
|
||||
"""
|
||||
self._log.error(
|
||||
"failed to connect to {host}:{port}: {reason}",
|
||||
host=self.hostport[0],
|
||||
port=self.hostport[1],
|
||||
reason=reason,
|
||||
)
|
||||
self.loseConnection()
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
See: L{channel.SSHChannel}
|
||||
"""
|
||||
if self.client:
|
||||
self.client.transport.write(data)
|
||||
else:
|
||||
self.clientBuf += data
|
||||
|
||||
def closed(self):
|
||||
"""
|
||||
See: L{channel.SSHChannel}
|
||||
"""
|
||||
if self.client:
|
||||
self._log.info("closed remote forwarding channel {id}", id=self.id)
|
||||
if self.client.channel:
|
||||
self.loseConnection()
|
||||
self.client.transport.loseConnection()
|
||||
del self.client
|
||||
|
||||
|
||||
def openConnectForwardingClient(remoteWindow, remoteMaxPacket, data, avatar):
|
||||
remoteHP, origHP = unpackOpen_direct_tcpip(data)
|
||||
return SSHConnectForwardingChannel(
|
||||
remoteHP,
|
||||
remoteWindow=remoteWindow,
|
||||
remoteMaxPacket=remoteMaxPacket,
|
||||
avatar=avatar,
|
||||
)
|
||||
|
||||
|
||||
class SSHForwardingClient(protocol.Protocol):
|
||||
def __init__(self, channel):
|
||||
self.channel = channel
|
||||
self.buf = b"\000"
|
||||
|
||||
def dataReceived(self, data):
|
||||
if self.buf:
|
||||
self.buf += data
|
||||
else:
|
||||
self.channel.write(data)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if self.channel:
|
||||
self.channel.loseConnection()
|
||||
self.channel = None
|
||||
|
||||
|
||||
def packOpen_direct_tcpip(destination, source):
|
||||
"""
|
||||
Pack the data suitable for sending in a CHANNEL_OPEN packet.
|
||||
|
||||
@type destination: L{tuple}
|
||||
@param destination: A tuple of the (host, port) of the destination host.
|
||||
|
||||
@type source: L{tuple}
|
||||
@param source: A tuple of the (host, port) of the source host.
|
||||
"""
|
||||
(connHost, connPort) = destination
|
||||
(origHost, origPort) = source
|
||||
if isinstance(connHost, str):
|
||||
connHost = connHost.encode("utf-8")
|
||||
if isinstance(origHost, str):
|
||||
origHost = origHost.encode("utf-8")
|
||||
conn = common.NS(connHost) + struct.pack(">L", connPort)
|
||||
orig = common.NS(origHost) + struct.pack(">L", origPort)
|
||||
return conn + orig
|
||||
|
||||
|
||||
packOpen_forwarded_tcpip = packOpen_direct_tcpip
|
||||
|
||||
|
||||
def unpackOpen_direct_tcpip(data):
|
||||
"""Unpack the data to a usable format."""
|
||||
connHost, rest = common.getNS(data)
|
||||
if isinstance(connHost, bytes):
|
||||
connHost = connHost.decode("utf-8")
|
||||
connPort = int(struct.unpack(">L", rest[:4])[0])
|
||||
origHost, rest = common.getNS(rest[4:])
|
||||
if isinstance(origHost, bytes):
|
||||
origHost = origHost.decode("utf-8")
|
||||
origPort = int(struct.unpack(">L", rest[:4])[0])
|
||||
return (connHost, connPort), (origHost, origPort)
|
||||
|
||||
|
||||
unpackOpen_forwarded_tcpip = unpackOpen_direct_tcpip
|
||||
|
||||
|
||||
def packGlobal_tcpip_forward(peer):
|
||||
"""
|
||||
Pack the data for tcpip forwarding.
|
||||
|
||||
@param peer: A tuple of the (host, port) .
|
||||
@type peer: L{tuple}
|
||||
"""
|
||||
(host, port) = peer
|
||||
return common.NS(host) + struct.pack(">L", port)
|
||||
|
||||
|
||||
def unpackGlobal_tcpip_forward(data):
|
||||
host, rest = common.getNS(data)
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("utf-8")
|
||||
port = int(struct.unpack(">L", rest[:4])[0])
|
||||
return host, port
|
||||
|
||||
|
||||
"""This is how the data -> eof -> close stuff /should/ work.
|
||||
|
||||
debug3: channel 1: waiting for connection
|
||||
debug1: channel 1: connected
|
||||
debug1: channel 1: read<=0 rfd 7 len 0
|
||||
debug1: channel 1: read failed
|
||||
debug1: channel 1: close_read
|
||||
debug1: channel 1: input open -> drain
|
||||
debug1: channel 1: ibuf empty
|
||||
debug1: channel 1: send eof
|
||||
debug1: channel 1: input drain -> closed
|
||||
debug1: channel 1: rcvd eof
|
||||
debug1: channel 1: output open -> drain
|
||||
debug1: channel 1: obuf empty
|
||||
debug1: channel 1: close_write
|
||||
debug1: channel 1: output drain -> closed
|
||||
debug1: channel 1: rcvd close
|
||||
debug3: channel 1: will not send data after close
|
||||
debug1: channel 1: send close
|
||||
debug1: channel 1: is dead
|
||||
"""
|
||||
1865
.venv/lib/python3.12/site-packages/twisted/conch/ssh/keys.py
Normal file
1865
.venv/lib/python3.12/site-packages/twisted/conch/ssh/keys.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
The parent class for all the SSH services. Currently implemented services
|
||||
are ssh-userauth and ssh-connection.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from twisted.logger import Logger
|
||||
|
||||
|
||||
class SSHService:
|
||||
# this is the ssh name for the service:
|
||||
name: bytes = None # type:ignore[assignment]
|
||||
|
||||
protocolMessages: Dict[int, str] = {} # map #'s -> protocol names
|
||||
transport = None # gets set later
|
||||
|
||||
_log = Logger()
|
||||
|
||||
def serviceStarted(self):
|
||||
"""
|
||||
called when the service is active on the transport.
|
||||
"""
|
||||
|
||||
def serviceStopped(self):
|
||||
"""
|
||||
called when the service is stopped, either by the connection ending
|
||||
or by another service being started
|
||||
"""
|
||||
|
||||
def logPrefix(self):
|
||||
return "SSHService {!r} on {}".format(
|
||||
self.name, self.transport.transport.logPrefix()
|
||||
)
|
||||
|
||||
def packetReceived(self, messageNum, packet):
|
||||
"""
|
||||
called when we receive a packet on the transport
|
||||
"""
|
||||
# print self.protocolMessages
|
||||
if messageNum in self.protocolMessages:
|
||||
messageType = self.protocolMessages[messageNum]
|
||||
f = getattr(self, "ssh_%s" % messageType[4:], None)
|
||||
if f is not None:
|
||||
return f(packet)
|
||||
self._log.info(
|
||||
"couldn't handle {messageNum} {packet!r}",
|
||||
messageNum=messageNum,
|
||||
packet=packet,
|
||||
)
|
||||
self.transport.sendUnimplemented()
|
||||
440
.venv/lib/python3.12/site-packages/twisted/conch/ssh/session.py
Normal file
440
.venv/lib/python3.12/site-packages/twisted/conch/ssh/session.py
Normal file
@@ -0,0 +1,440 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_session -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This module contains the implementation of SSHSession, which (by default)
|
||||
allows access to a shell and a python interpreter over SSH.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import signal
|
||||
import struct
|
||||
import sys
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.conch.interfaces import (
|
||||
EnvironmentVariableNotPermitted,
|
||||
ISession,
|
||||
ISessionSetEnv,
|
||||
)
|
||||
from twisted.conch.ssh import channel, common, connection
|
||||
from twisted.internet import interfaces, protocol
|
||||
from twisted.logger import Logger
|
||||
from twisted.python.compat import networkString
|
||||
|
||||
log = Logger()
|
||||
|
||||
|
||||
class SSHSession(channel.SSHChannel):
|
||||
"""
|
||||
A generalized implementation of an SSH session.
|
||||
|
||||
See RFC 4254, section 6.
|
||||
|
||||
The precise implementation of the various operations that the remote end
|
||||
can send is left up to the avatar, usually via an adapter to an
|
||||
interface such as L{ISession}.
|
||||
|
||||
@ivar buf: a buffer for data received before making a connection to a
|
||||
client.
|
||||
@type buf: L{bytes}
|
||||
@ivar client: a protocol for communication with a shell, an application
|
||||
program, or a subsystem (see RFC 4254, section 6.5).
|
||||
@type client: L{SSHSessionProcessProtocol}
|
||||
@ivar session: an object providing concrete implementations of session
|
||||
operations.
|
||||
@type session: L{ISession}
|
||||
"""
|
||||
|
||||
name = b"session"
|
||||
|
||||
def __init__(self, *args, **kw):
|
||||
channel.SSHChannel.__init__(self, *args, **kw)
|
||||
self.buf = b""
|
||||
self.client = None
|
||||
self.session = None
|
||||
|
||||
def request_subsystem(self, data):
|
||||
subsystem, ignored = common.getNS(data)
|
||||
log.info('Asking for subsystem "{subsystem}"', subsystem=subsystem)
|
||||
client = self.avatar.lookupSubsystem(subsystem, data)
|
||||
if client:
|
||||
pp = SSHSessionProcessProtocol(self)
|
||||
proto = wrapProcessProtocol(pp)
|
||||
client.makeConnection(proto)
|
||||
pp.makeConnection(wrapProtocol(client))
|
||||
self.client = pp
|
||||
return 1
|
||||
else:
|
||||
log.error("Failed to get subsystem")
|
||||
return 0
|
||||
|
||||
def request_shell(self, data):
|
||||
log.info("Getting shell")
|
||||
if not self.session:
|
||||
self.session = ISession(self.avatar)
|
||||
try:
|
||||
pp = SSHSessionProcessProtocol(self)
|
||||
self.session.openShell(pp)
|
||||
except Exception:
|
||||
log.failure("Error getting shell")
|
||||
return 0
|
||||
else:
|
||||
self.client = pp
|
||||
return 1
|
||||
|
||||
def request_exec(self, data):
|
||||
if not self.session:
|
||||
self.session = ISession(self.avatar)
|
||||
f, data = common.getNS(data)
|
||||
log.info('Executing command "{f}"', f=f)
|
||||
try:
|
||||
pp = SSHSessionProcessProtocol(self)
|
||||
self.session.execCommand(pp, f)
|
||||
except Exception:
|
||||
log.failure('Error executing command "{f}"', f=f)
|
||||
return 0
|
||||
else:
|
||||
self.client = pp
|
||||
return 1
|
||||
|
||||
def request_pty_req(self, data):
|
||||
if not self.session:
|
||||
self.session = ISession(self.avatar)
|
||||
term, windowSize, modes = parseRequest_pty_req(data)
|
||||
log.info(
|
||||
"Handling pty request: {term!r} {windowSize!r}",
|
||||
term=term,
|
||||
windowSize=windowSize,
|
||||
)
|
||||
try:
|
||||
self.session.getPty(term, windowSize, modes)
|
||||
except Exception:
|
||||
log.failure("Error handling pty request")
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
def request_env(self, data):
|
||||
"""
|
||||
Process a request to pass an environment variable.
|
||||
|
||||
@param data: The environment variable name and value, each encoded
|
||||
as an SSH protocol string and concatenated.
|
||||
@type data: L{bytes}
|
||||
@return: A true value if the request to pass this environment
|
||||
variable was accepted, otherwise a false value.
|
||||
"""
|
||||
if not self.session:
|
||||
self.session = ISession(self.avatar)
|
||||
if not ISessionSetEnv.providedBy(self.session):
|
||||
return 0
|
||||
name, value, data = common.getNS(data, 2)
|
||||
try:
|
||||
self.session.setEnv(name, value)
|
||||
except EnvironmentVariableNotPermitted:
|
||||
return 0
|
||||
except Exception:
|
||||
log.failure("Error setting environment variable {name}", name=name)
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
def request_window_change(self, data):
|
||||
if not self.session:
|
||||
self.session = ISession(self.avatar)
|
||||
winSize = parseRequest_window_change(data)
|
||||
try:
|
||||
self.session.windowChanged(winSize)
|
||||
except Exception:
|
||||
log.failure("Error changing window size")
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
def dataReceived(self, data):
|
||||
if not self.client:
|
||||
# self.conn.sendClose(self)
|
||||
self.buf += data
|
||||
return
|
||||
self.client.transport.write(data)
|
||||
|
||||
def extReceived(self, dataType, data):
|
||||
if dataType == connection.EXTENDED_DATA_STDERR:
|
||||
if self.client and hasattr(self.client.transport, "writeErr"):
|
||||
self.client.transport.writeErr(data)
|
||||
else:
|
||||
log.warn("Weird extended data: {dataType}", dataType=dataType)
|
||||
|
||||
def eofReceived(self):
|
||||
# If we have a session, tell it that EOF has been received and
|
||||
# expect it to send a close message (it may need to send other
|
||||
# messages such as exit-status or exit-signal first). If we don't
|
||||
# have a session, then just send a close message directly.
|
||||
if self.session:
|
||||
self.session.eofReceived()
|
||||
elif self.client:
|
||||
self.conn.sendClose(self)
|
||||
|
||||
def closed(self):
|
||||
if self.client and self.client.transport:
|
||||
self.client.transport.loseConnection()
|
||||
if self.session:
|
||||
self.session.closed()
|
||||
|
||||
# def closeReceived(self):
|
||||
# self.loseConnection() # don't know what to do with this
|
||||
|
||||
def loseConnection(self):
|
||||
if self.client:
|
||||
self.client.transport.loseConnection()
|
||||
channel.SSHChannel.loseConnection(self)
|
||||
|
||||
|
||||
class _ProtocolWrapper(protocol.ProcessProtocol):
|
||||
"""
|
||||
This class wraps a L{Protocol} instance in a L{ProcessProtocol} instance.
|
||||
"""
|
||||
|
||||
def __init__(self, proto):
|
||||
self.proto = proto
|
||||
|
||||
def connectionMade(self):
|
||||
self.proto.connectionMade()
|
||||
|
||||
def outReceived(self, data):
|
||||
self.proto.dataReceived(data)
|
||||
|
||||
def processEnded(self, reason):
|
||||
self.proto.connectionLost(reason)
|
||||
|
||||
|
||||
class _DummyTransport:
|
||||
def __init__(self, proto):
|
||||
self.proto = proto
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.proto.transport.write(data)
|
||||
|
||||
def write(self, data):
|
||||
self.proto.dataReceived(data)
|
||||
|
||||
def writeSequence(self, seq):
|
||||
self.write(b"".join(seq))
|
||||
|
||||
def loseConnection(self):
|
||||
self.proto.connectionLost(protocol.connectionDone)
|
||||
|
||||
|
||||
def wrapProcessProtocol(inst):
|
||||
if isinstance(inst, protocol.Protocol):
|
||||
return _ProtocolWrapper(inst)
|
||||
else:
|
||||
return inst
|
||||
|
||||
|
||||
def wrapProtocol(proto):
|
||||
return _DummyTransport(proto)
|
||||
|
||||
|
||||
# SUPPORTED_SIGNALS is a list of signals that every session channel is supposed
|
||||
# to accept. See RFC 4254
|
||||
SUPPORTED_SIGNALS = [
|
||||
"ABRT",
|
||||
"ALRM",
|
||||
"FPE",
|
||||
"HUP",
|
||||
"ILL",
|
||||
"INT",
|
||||
"KILL",
|
||||
"PIPE",
|
||||
"QUIT",
|
||||
"SEGV",
|
||||
"TERM",
|
||||
"USR1",
|
||||
"USR2",
|
||||
]
|
||||
|
||||
|
||||
@implementer(interfaces.ITransport)
|
||||
class SSHSessionProcessProtocol(protocol.ProcessProtocol):
|
||||
"""I am both an L{IProcessProtocol} and an L{ITransport}.
|
||||
|
||||
I am a transport to the remote endpoint and a process protocol to the
|
||||
local subsystem.
|
||||
"""
|
||||
|
||||
# once initialized, a dictionary mapping signal values to strings
|
||||
# that follow RFC 4254.
|
||||
_signalValuesToNames = None
|
||||
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
self.lostOutOrErrFlag = False
|
||||
|
||||
def connectionMade(self):
|
||||
if self.session.buf:
|
||||
self.transport.write(self.session.buf)
|
||||
self.session.buf = None
|
||||
|
||||
def outReceived(self, data):
|
||||
self.session.write(data)
|
||||
|
||||
def errReceived(self, err):
|
||||
self.session.writeExtended(connection.EXTENDED_DATA_STDERR, err)
|
||||
|
||||
def outConnectionLost(self):
|
||||
"""
|
||||
EOF should only be sent when both STDOUT and STDERR have been closed.
|
||||
"""
|
||||
if self.lostOutOrErrFlag:
|
||||
self.session.conn.sendEOF(self.session)
|
||||
else:
|
||||
self.lostOutOrErrFlag = True
|
||||
|
||||
def errConnectionLost(self):
|
||||
"""
|
||||
See outConnectionLost().
|
||||
"""
|
||||
self.outConnectionLost()
|
||||
|
||||
def connectionLost(self, reason=None):
|
||||
self.session.loseConnection()
|
||||
|
||||
def _getSignalName(self, signum):
|
||||
"""
|
||||
Get a signal name given a signal number.
|
||||
"""
|
||||
if self._signalValuesToNames is None:
|
||||
self._signalValuesToNames = {}
|
||||
# make sure that the POSIX ones are the defaults
|
||||
for signame in SUPPORTED_SIGNALS:
|
||||
signame = "SIG" + signame
|
||||
sigvalue = getattr(signal, signame, None)
|
||||
if sigvalue is not None:
|
||||
self._signalValuesToNames[sigvalue] = signame
|
||||
for k, v in signal.__dict__.items():
|
||||
# Check for platform specific signals, ignoring Python specific
|
||||
# SIG_DFL and SIG_IGN
|
||||
if k.startswith("SIG") and not k.startswith("SIG_"):
|
||||
if v not in self._signalValuesToNames:
|
||||
self._signalValuesToNames[v] = k + "@" + sys.platform
|
||||
return self._signalValuesToNames[signum]
|
||||
|
||||
def processEnded(self, reason=None):
|
||||
"""
|
||||
When we are told the process ended, try to notify the other side about
|
||||
how the process ended using the exit-signal or exit-status requests.
|
||||
Also, close the channel.
|
||||
"""
|
||||
if reason is not None:
|
||||
err = reason.value
|
||||
if err.signal is not None:
|
||||
signame = self._getSignalName(err.signal)
|
||||
if getattr(os, "WCOREDUMP", None) is not None and os.WCOREDUMP(
|
||||
err.status
|
||||
):
|
||||
log.info("exitSignal: {signame} (core dumped)", signame=signame)
|
||||
coreDumped = True
|
||||
else:
|
||||
log.info("exitSignal: {}", signame=signame)
|
||||
coreDumped = False
|
||||
self.session.conn.sendRequest(
|
||||
self.session,
|
||||
b"exit-signal",
|
||||
common.NS(networkString(signame[3:]))
|
||||
+ (b"\1" if coreDumped else b"\0")
|
||||
+ common.NS(b"")
|
||||
+ common.NS(b""),
|
||||
)
|
||||
elif err.exitCode is not None:
|
||||
log.info("exitCode: {exitCode!r}", exitCode=err.exitCode)
|
||||
self.session.conn.sendRequest(
|
||||
self.session, b"exit-status", struct.pack(">L", err.exitCode)
|
||||
)
|
||||
self.session.loseConnection()
|
||||
|
||||
def getHost(self):
|
||||
"""
|
||||
Return the host from my session's transport.
|
||||
"""
|
||||
return self.session.conn.transport.getHost()
|
||||
|
||||
def getPeer(self):
|
||||
"""
|
||||
Return the peer from my session's transport.
|
||||
"""
|
||||
return self.session.conn.transport.getPeer()
|
||||
|
||||
def write(self, data):
|
||||
self.session.write(data)
|
||||
|
||||
def writeSequence(self, seq):
|
||||
self.session.write(b"".join(seq))
|
||||
|
||||
def loseConnection(self):
|
||||
self.session.loseConnection()
|
||||
|
||||
|
||||
class SSHSessionClient(protocol.Protocol):
|
||||
def dataReceived(self, data):
|
||||
if self.transport:
|
||||
self.transport.write(data)
|
||||
|
||||
|
||||
# methods factored out to make live easier on server writers
|
||||
def parseRequest_pty_req(data):
|
||||
"""Parse the data from a pty-req request into usable data.
|
||||
|
||||
@returns: a tuple of (terminal type, (rows, cols, xpixel, ypixel), modes)
|
||||
"""
|
||||
term, rest = common.getNS(data)
|
||||
cols, rows, xpixel, ypixel = struct.unpack(">4L", rest[:16])
|
||||
modes, ignored = common.getNS(rest[16:])
|
||||
winSize = (rows, cols, xpixel, ypixel)
|
||||
modes = [
|
||||
(ord(modes[i : i + 1]), struct.unpack(">L", modes[i + 1 : i + 5])[0])
|
||||
for i in range(0, len(modes) - 1, 5)
|
||||
]
|
||||
return term, winSize, modes
|
||||
|
||||
|
||||
def packRequest_pty_req(term, geometry, modes):
|
||||
"""
|
||||
Pack a pty-req request so that it is suitable for sending.
|
||||
|
||||
NOTE: modes must be packed before being sent here.
|
||||
|
||||
@type geometry: L{tuple}
|
||||
@param geometry: A tuple of (rows, columns, xpixel, ypixel)
|
||||
"""
|
||||
(rows, cols, xpixel, ypixel) = geometry
|
||||
termPacked = common.NS(term)
|
||||
winSizePacked = struct.pack(">4L", cols, rows, xpixel, ypixel)
|
||||
modesPacked = common.NS(modes) # depend on the client packing modes
|
||||
return termPacked + winSizePacked + modesPacked
|
||||
|
||||
|
||||
def parseRequest_window_change(data):
|
||||
"""Parse the data from a window-change request into usuable data.
|
||||
|
||||
@returns: a tuple of (rows, cols, xpixel, ypixel)
|
||||
"""
|
||||
cols, rows, xpixel, ypixel = struct.unpack(">4L", data)
|
||||
return rows, cols, xpixel, ypixel
|
||||
|
||||
|
||||
def packRequest_window_change(geometry):
|
||||
"""
|
||||
Pack a window-change request so that it is suitable for sending.
|
||||
|
||||
@type geometry: L{tuple}
|
||||
@param geometry: A tuple of (rows, columns, xpixel, ypixel)
|
||||
"""
|
||||
(rows, cols, xpixel, ypixel) = geometry
|
||||
return struct.pack(">4L", cols, rows, xpixel, ypixel)
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
def parse(s):
|
||||
s = s.strip()
|
||||
expr = []
|
||||
while s:
|
||||
if s[0:1] == b"(":
|
||||
newSexp = []
|
||||
if expr:
|
||||
expr[-1].append(newSexp)
|
||||
expr.append(newSexp)
|
||||
s = s[1:]
|
||||
continue
|
||||
if s[0:1] == b")":
|
||||
aList = expr.pop()
|
||||
s = s[1:]
|
||||
if not expr:
|
||||
assert not s
|
||||
return aList
|
||||
continue
|
||||
i = 0
|
||||
while s[i : i + 1].isdigit():
|
||||
i += 1
|
||||
assert i
|
||||
length = int(s[:i])
|
||||
data = s[i + 1 : i + 1 + length]
|
||||
expr[-1].append(data)
|
||||
s = s[i + 1 + length :]
|
||||
assert False, "this should not happen"
|
||||
|
||||
|
||||
def pack(sexp):
|
||||
return b"".join(
|
||||
b"(%b)" % (pack(o),)
|
||||
if type(o) in (type(()), type([]))
|
||||
else b"%d:%b" % (len(o), o)
|
||||
for o in sexp
|
||||
)
|
||||
2258
.venv/lib/python3.12/site-packages/twisted/conch/ssh/transport.py
Normal file
2258
.venv/lib/python3.12/site-packages/twisted/conch/ssh/transport.py
Normal file
File diff suppressed because it is too large
Load Diff
764
.venv/lib/python3.12/site-packages/twisted/conch/ssh/userauth.py
Normal file
764
.venv/lib/python3.12/site-packages/twisted/conch/ssh/userauth.py
Normal file
@@ -0,0 +1,764 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_userauth -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation of the ssh-userauth service.
|
||||
Currently implemented authentication types are public-key and password.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.conch import error, interfaces
|
||||
from twisted.conch.ssh import keys, service, transport
|
||||
from twisted.conch.ssh.common import NS, getNS
|
||||
from twisted.cred import credentials
|
||||
from twisted.cred.error import UnauthorizedLogin
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.logger import Logger
|
||||
from twisted.python import failure
|
||||
from twisted.python.compat import nativeString
|
||||
|
||||
|
||||
class SSHUserAuthServer(service.SSHService):
|
||||
"""
|
||||
A service implementing the server side of the 'ssh-userauth' service. It
|
||||
is used to authenticate the user on the other side as being able to access
|
||||
this server.
|
||||
|
||||
@ivar name: the name of this service: 'ssh-userauth'
|
||||
@type name: L{bytes}
|
||||
@ivar authenticatedWith: a list of authentication methods that have
|
||||
already been used.
|
||||
@type authenticatedWith: L{list}
|
||||
@ivar loginTimeout: the number of seconds we wait before disconnecting
|
||||
the user for taking too long to authenticate
|
||||
@type loginTimeout: L{int}
|
||||
@ivar attemptsBeforeDisconnect: the number of failed login attempts we
|
||||
allow before disconnecting.
|
||||
@type attemptsBeforeDisconnect: L{int}
|
||||
@ivar loginAttempts: the number of login attempts that have been made
|
||||
@type loginAttempts: L{int}
|
||||
@ivar passwordDelay: the number of seconds to delay when the user gives
|
||||
an incorrect password
|
||||
@type passwordDelay: L{int}
|
||||
@ivar interfaceToMethod: a L{dict} mapping credential interfaces to
|
||||
authentication methods. The server checks to see which of the
|
||||
cred interfaces have checkers and tells the client that those methods
|
||||
are valid for authentication.
|
||||
@type interfaceToMethod: L{dict}
|
||||
@ivar supportedAuthentications: A list of the supported authentication
|
||||
methods.
|
||||
@type supportedAuthentications: L{list} of L{bytes}
|
||||
@ivar user: the last username the client tried to authenticate with
|
||||
@type user: L{bytes}
|
||||
@ivar method: the current authentication method
|
||||
@type method: L{bytes}
|
||||
@ivar nextService: the service the user wants started after authentication
|
||||
has been completed.
|
||||
@type nextService: L{bytes}
|
||||
@ivar portal: the L{twisted.cred.portal.Portal} we are using for
|
||||
authentication
|
||||
@type portal: L{twisted.cred.portal.Portal}
|
||||
@ivar clock: an object with a callLater method. Stubbed out for testing.
|
||||
"""
|
||||
|
||||
name = b"ssh-userauth"
|
||||
loginTimeout = 10 * 60 * 60
|
||||
# 10 minutes before we disconnect them
|
||||
attemptsBeforeDisconnect = 20
|
||||
# 20 login attempts before a disconnect
|
||||
passwordDelay = 1 # number of seconds to delay on a failed password
|
||||
clock = reactor
|
||||
interfaceToMethod = {
|
||||
credentials.ISSHPrivateKey: b"publickey",
|
||||
credentials.IUsernamePassword: b"password",
|
||||
}
|
||||
_log = Logger()
|
||||
|
||||
def serviceStarted(self):
|
||||
"""
|
||||
Called when the userauth service is started. Set up instance
|
||||
variables, check if we should allow password authentication (only
|
||||
allow if the outgoing connection is encrypted) and set up a login
|
||||
timeout.
|
||||
"""
|
||||
self.authenticatedWith = []
|
||||
self.loginAttempts = 0
|
||||
self.user = None
|
||||
self.nextService = None
|
||||
self.portal = self.transport.factory.portal
|
||||
|
||||
self.supportedAuthentications = []
|
||||
for i in self.portal.listCredentialsInterfaces():
|
||||
if i in self.interfaceToMethod:
|
||||
self.supportedAuthentications.append(self.interfaceToMethod[i])
|
||||
|
||||
if not self.transport.isEncrypted("in"):
|
||||
# don't let us transport password in plaintext
|
||||
if b"password" in self.supportedAuthentications:
|
||||
self.supportedAuthentications.remove(b"password")
|
||||
self._cancelLoginTimeout = self.clock.callLater(
|
||||
self.loginTimeout, self.timeoutAuthentication
|
||||
)
|
||||
|
||||
def serviceStopped(self):
|
||||
"""
|
||||
Called when the userauth service is stopped. Cancel the login timeout
|
||||
if it's still going.
|
||||
"""
|
||||
if self._cancelLoginTimeout:
|
||||
self._cancelLoginTimeout.cancel()
|
||||
self._cancelLoginTimeout = None
|
||||
|
||||
def timeoutAuthentication(self):
|
||||
"""
|
||||
Called when the user has timed out on authentication. Disconnect
|
||||
with a DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE message.
|
||||
"""
|
||||
self._cancelLoginTimeout = None
|
||||
self.transport.sendDisconnect(
|
||||
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, b"you took too long"
|
||||
)
|
||||
|
||||
def tryAuth(self, kind, user, data):
|
||||
"""
|
||||
Try to authenticate the user with the given method. Dispatches to a
|
||||
auth_* method.
|
||||
|
||||
@param kind: the authentication method to try.
|
||||
@type kind: L{bytes}
|
||||
@param user: the username the client is authenticating with.
|
||||
@type user: L{bytes}
|
||||
@param data: authentication specific data sent by the client.
|
||||
@type data: L{bytes}
|
||||
@return: A Deferred called back if the method succeeded, or erred back
|
||||
if it failed.
|
||||
@rtype: C{defer.Deferred}
|
||||
"""
|
||||
self._log.debug("{user!r} trying auth {kind!r}", user=user, kind=kind)
|
||||
if kind not in self.supportedAuthentications:
|
||||
return defer.fail(error.ConchError("unsupported authentication, failing"))
|
||||
kind = nativeString(kind.replace(b"-", b"_"))
|
||||
f = getattr(self, f"auth_{kind}", None)
|
||||
if f:
|
||||
ret = f(data)
|
||||
if not ret:
|
||||
return defer.fail(
|
||||
error.ConchError(f"{kind} return None instead of a Deferred")
|
||||
)
|
||||
else:
|
||||
return ret
|
||||
return defer.fail(error.ConchError(f"bad auth type: {kind}"))
|
||||
|
||||
def ssh_USERAUTH_REQUEST(self, packet):
|
||||
"""
|
||||
The client has requested authentication. Payload::
|
||||
string user
|
||||
string next service
|
||||
string method
|
||||
<authentication specific data>
|
||||
|
||||
@type packet: L{bytes}
|
||||
"""
|
||||
user, nextService, method, rest = getNS(packet, 3)
|
||||
if user != self.user or nextService != self.nextService:
|
||||
self.authenticatedWith = [] # clear auth state
|
||||
self.user = user
|
||||
self.nextService = nextService
|
||||
self.method = method
|
||||
d = self.tryAuth(method, user, rest)
|
||||
if not d:
|
||||
self._ebBadAuth(failure.Failure(error.ConchError("auth returned none")))
|
||||
return
|
||||
d.addCallback(self._cbFinishedAuth)
|
||||
d.addErrback(self._ebMaybeBadAuth)
|
||||
d.addErrback(self._ebBadAuth)
|
||||
return d
|
||||
|
||||
def _cbFinishedAuth(self, result):
|
||||
"""
|
||||
The callback when user has successfully been authenticated. For a
|
||||
description of the arguments, see L{twisted.cred.portal.Portal.login}.
|
||||
We start the service requested by the user.
|
||||
"""
|
||||
(interface, avatar, logout) = result
|
||||
self.transport.avatar = avatar
|
||||
self.transport.logoutFunction = logout
|
||||
service = self.transport.factory.getService(self.transport, self.nextService)
|
||||
if not service:
|
||||
raise error.ConchError(f"could not get next service: {self.nextService}")
|
||||
self._log.debug(
|
||||
"{user!r} authenticated with {method!r}", user=self.user, method=self.method
|
||||
)
|
||||
self.transport.sendPacket(MSG_USERAUTH_SUCCESS, b"")
|
||||
self.transport.setService(service())
|
||||
|
||||
def _ebMaybeBadAuth(self, reason):
|
||||
"""
|
||||
An intermediate errback. If the reason is
|
||||
error.NotEnoughAuthentication, we send a MSG_USERAUTH_FAILURE, but
|
||||
with the partial success indicator set.
|
||||
|
||||
@type reason: L{twisted.python.failure.Failure}
|
||||
"""
|
||||
reason.trap(error.NotEnoughAuthentication)
|
||||
self.transport.sendPacket(
|
||||
MSG_USERAUTH_FAILURE, NS(b",".join(self.supportedAuthentications)) + b"\xff"
|
||||
)
|
||||
|
||||
def _ebBadAuth(self, reason):
|
||||
"""
|
||||
The final errback in the authentication chain. If the reason is
|
||||
error.IgnoreAuthentication, we simply return; the authentication
|
||||
method has sent its own response. Otherwise, send a failure message
|
||||
and (if the method is not 'none') increment the number of login
|
||||
attempts.
|
||||
|
||||
@type reason: L{twisted.python.failure.Failure}
|
||||
"""
|
||||
if reason.check(error.IgnoreAuthentication):
|
||||
return
|
||||
if self.method != b"none":
|
||||
self._log.debug(
|
||||
"{user!r} failed auth {method!r}", user=self.user, method=self.method
|
||||
)
|
||||
if reason.check(UnauthorizedLogin):
|
||||
self._log.debug(
|
||||
"unauthorized login: {message}", message=reason.getErrorMessage()
|
||||
)
|
||||
elif reason.check(error.ConchError):
|
||||
self._log.debug("reason: {reason}", reason=reason.getErrorMessage())
|
||||
else:
|
||||
self._log.failure(
|
||||
"Error checking auth for user {user}",
|
||||
failure=reason,
|
||||
user=self.user,
|
||||
)
|
||||
self.loginAttempts += 1
|
||||
if self.loginAttempts > self.attemptsBeforeDisconnect:
|
||||
self.transport.sendDisconnect(
|
||||
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
|
||||
b"too many bad auths",
|
||||
)
|
||||
return
|
||||
self.transport.sendPacket(
|
||||
MSG_USERAUTH_FAILURE, NS(b",".join(self.supportedAuthentications)) + b"\x00"
|
||||
)
|
||||
|
||||
def auth_publickey(self, packet):
|
||||
"""
|
||||
Public key authentication. Payload::
|
||||
byte has signature
|
||||
string algorithm name
|
||||
string key blob
|
||||
[string signature] (if has signature is True)
|
||||
|
||||
Create a SSHPublicKey credential and verify it using our portal.
|
||||
"""
|
||||
hasSig = ord(packet[0:1])
|
||||
algName, blob, rest = getNS(packet[1:], 2)
|
||||
|
||||
try:
|
||||
keys.Key.fromString(blob)
|
||||
except keys.BadKeyError:
|
||||
error = "Unsupported key type {} or bad key".format(algName.decode("ascii"))
|
||||
self._log.error(error)
|
||||
return defer.fail(UnauthorizedLogin(error))
|
||||
|
||||
signature = hasSig and getNS(rest)[0] or None
|
||||
if hasSig:
|
||||
b = (
|
||||
NS(self.transport.sessionID)
|
||||
+ bytes((MSG_USERAUTH_REQUEST,))
|
||||
+ NS(self.user)
|
||||
+ NS(self.nextService)
|
||||
+ NS(b"publickey")
|
||||
+ bytes((hasSig,))
|
||||
+ NS(algName)
|
||||
+ NS(blob)
|
||||
)
|
||||
c = credentials.SSHPrivateKey(self.user, algName, blob, b, signature)
|
||||
return self.portal.login(c, None, interfaces.IConchUser)
|
||||
else:
|
||||
c = credentials.SSHPrivateKey(self.user, algName, blob, None, None)
|
||||
return self.portal.login(c, None, interfaces.IConchUser).addErrback(
|
||||
self._ebCheckKey, packet[1:]
|
||||
)
|
||||
|
||||
def _ebCheckKey(self, reason, packet):
|
||||
"""
|
||||
Called back if the user did not sent a signature. If reason is
|
||||
error.ValidPublicKey then this key is valid for the user to
|
||||
authenticate with. Send MSG_USERAUTH_PK_OK.
|
||||
"""
|
||||
reason.trap(error.ValidPublicKey)
|
||||
# if we make it here, it means that the publickey is valid
|
||||
self.transport.sendPacket(MSG_USERAUTH_PK_OK, packet)
|
||||
return failure.Failure(error.IgnoreAuthentication())
|
||||
|
||||
def auth_password(self, packet):
|
||||
"""
|
||||
Password authentication. Payload::
|
||||
string password
|
||||
|
||||
Make a UsernamePassword credential and verify it with our portal.
|
||||
"""
|
||||
password = getNS(packet[1:])[0]
|
||||
c = credentials.UsernamePassword(self.user, password)
|
||||
return self.portal.login(c, None, interfaces.IConchUser).addErrback(
|
||||
self._ebPassword
|
||||
)
|
||||
|
||||
def _ebPassword(self, f):
|
||||
"""
|
||||
If the password is invalid, wait before sending the failure in order
|
||||
to delay brute-force password guessing.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.clock.callLater(self.passwordDelay, d.callback, f)
|
||||
return d
|
||||
|
||||
|
||||
class SSHUserAuthClient(service.SSHService):
|
||||
"""
|
||||
A service implementing the client side of 'ssh-userauth'.
|
||||
|
||||
This service will try all authentication methods provided by the server,
|
||||
making callbacks for more information when necessary.
|
||||
|
||||
@ivar name: the name of this service: 'ssh-userauth'
|
||||
@type name: L{str}
|
||||
@ivar preferredOrder: a list of authentication methods that should be used
|
||||
first, in order of preference, if supported by the server
|
||||
@type preferredOrder: L{list}
|
||||
@ivar user: the name of the user to authenticate as
|
||||
@type user: L{bytes}
|
||||
@ivar instance: the service to start after authentication has finished
|
||||
@type instance: L{service.SSHService}
|
||||
@ivar authenticatedWith: a list of strings of authentication methods we've tried
|
||||
@type authenticatedWith: L{list} of L{bytes}
|
||||
@ivar triedPublicKeys: a list of public key objects that we've tried to
|
||||
authenticate with
|
||||
@type triedPublicKeys: L{list} of L{Key}
|
||||
@ivar lastPublicKey: the last public key object we've tried to authenticate
|
||||
with
|
||||
@type lastPublicKey: L{Key}
|
||||
"""
|
||||
|
||||
name = b"ssh-userauth"
|
||||
preferredOrder = [b"publickey", b"password", b"keyboard-interactive"]
|
||||
|
||||
def __init__(self, user, instance):
|
||||
self.user = user
|
||||
self.instance = instance
|
||||
|
||||
def serviceStarted(self):
|
||||
self.authenticatedWith = []
|
||||
self.triedPublicKeys = []
|
||||
self.lastPublicKey = None
|
||||
self.askForAuth(b"none", b"")
|
||||
|
||||
def askForAuth(self, kind, extraData):
|
||||
"""
|
||||
Send a MSG_USERAUTH_REQUEST.
|
||||
|
||||
@param kind: the authentication method to try.
|
||||
@type kind: L{bytes}
|
||||
@param extraData: method-specific data to go in the packet
|
||||
@type extraData: L{bytes}
|
||||
"""
|
||||
self.lastAuth = kind
|
||||
self.transport.sendPacket(
|
||||
MSG_USERAUTH_REQUEST,
|
||||
NS(self.user) + NS(self.instance.name) + NS(kind) + extraData,
|
||||
)
|
||||
|
||||
def tryAuth(self, kind):
|
||||
"""
|
||||
Dispatch to an authentication method.
|
||||
|
||||
@param kind: the authentication method
|
||||
@type kind: L{bytes}
|
||||
"""
|
||||
kind = nativeString(kind.replace(b"-", b"_"))
|
||||
self._log.debug("trying to auth with {kind}", kind=kind)
|
||||
f = getattr(self, "auth_" + kind, None)
|
||||
if f:
|
||||
return f()
|
||||
|
||||
def _ebAuth(self, ignored, *args):
|
||||
"""
|
||||
Generic callback for a failed authentication attempt. Respond by
|
||||
asking for the list of accepted methods (the 'none' method)
|
||||
"""
|
||||
self.askForAuth(b"none", b"")
|
||||
|
||||
def ssh_USERAUTH_SUCCESS(self, packet):
|
||||
"""
|
||||
We received a MSG_USERAUTH_SUCCESS. The server has accepted our
|
||||
authentication, so start the next service.
|
||||
"""
|
||||
self.transport.setService(self.instance)
|
||||
|
||||
def ssh_USERAUTH_FAILURE(self, packet):
|
||||
"""
|
||||
We received a MSG_USERAUTH_FAILURE. Payload::
|
||||
string methods
|
||||
byte partial success
|
||||
|
||||
If partial success is C{True}, then the previous method succeeded but is
|
||||
not sufficient for authentication. C{methods} is a comma-separated list
|
||||
of accepted authentication methods.
|
||||
|
||||
We sort the list of methods by their position in C{self.preferredOrder},
|
||||
removing methods that have already succeeded. We then call
|
||||
C{self.tryAuth} with the most preferred method.
|
||||
|
||||
@param packet: the C{MSG_USERAUTH_FAILURE} payload.
|
||||
@type packet: L{bytes}
|
||||
|
||||
@return: a L{defer.Deferred} that will be callbacked with L{None} as
|
||||
soon as all authentication methods have been tried, or L{None} if no
|
||||
more authentication methods are available.
|
||||
@rtype: C{defer.Deferred} or L{None}
|
||||
"""
|
||||
canContinue, partial = getNS(packet)
|
||||
partial = ord(partial)
|
||||
if partial:
|
||||
self.authenticatedWith.append(self.lastAuth)
|
||||
|
||||
def orderByPreference(meth):
|
||||
"""
|
||||
Invoked once per authentication method in order to extract a
|
||||
comparison key which is then used for sorting.
|
||||
|
||||
@param meth: the authentication method.
|
||||
@type meth: L{bytes}
|
||||
|
||||
@return: the comparison key for C{meth}.
|
||||
@rtype: L{int}
|
||||
"""
|
||||
if meth in self.preferredOrder:
|
||||
return self.preferredOrder.index(meth)
|
||||
else:
|
||||
# put the element at the end of the list.
|
||||
return len(self.preferredOrder)
|
||||
|
||||
canContinue = sorted(
|
||||
(
|
||||
meth
|
||||
for meth in canContinue.split(b",")
|
||||
if meth not in self.authenticatedWith
|
||||
),
|
||||
key=orderByPreference,
|
||||
)
|
||||
|
||||
self._log.debug("can continue with: {methods}", methods=canContinue)
|
||||
return self._cbUserauthFailure(None, iter(canContinue))
|
||||
|
||||
def _cbUserauthFailure(self, result, iterator):
|
||||
if result:
|
||||
return
|
||||
try:
|
||||
method = next(iterator)
|
||||
except StopIteration:
|
||||
self.transport.sendDisconnect(
|
||||
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
|
||||
b"no more authentication methods available",
|
||||
)
|
||||
else:
|
||||
d = defer.maybeDeferred(self.tryAuth, method)
|
||||
d.addCallback(self._cbUserauthFailure, iterator)
|
||||
return d
|
||||
|
||||
def ssh_USERAUTH_PK_OK(self, packet):
|
||||
"""
|
||||
This message (number 60) can mean several different messages depending
|
||||
on the current authentication type. We dispatch to individual methods
|
||||
in order to handle this request.
|
||||
"""
|
||||
func = getattr(
|
||||
self,
|
||||
"ssh_USERAUTH_PK_OK_%s" % nativeString(self.lastAuth.replace(b"-", b"_")),
|
||||
None,
|
||||
)
|
||||
if func is not None:
|
||||
return func(packet)
|
||||
else:
|
||||
self.askForAuth(b"none", b"")
|
||||
|
||||
def ssh_USERAUTH_PK_OK_publickey(self, packet):
|
||||
"""
|
||||
This is MSG_USERAUTH_PK. Our public key is valid, so we create a
|
||||
signature and try to authenticate with it.
|
||||
"""
|
||||
publicKey = self.lastPublicKey
|
||||
b = (
|
||||
NS(self.transport.sessionID)
|
||||
+ bytes((MSG_USERAUTH_REQUEST,))
|
||||
+ NS(self.user)
|
||||
+ NS(self.instance.name)
|
||||
+ NS(b"publickey")
|
||||
+ b"\x01"
|
||||
+ NS(publicKey.sshType())
|
||||
+ NS(publicKey.blob())
|
||||
)
|
||||
d = self.signData(publicKey, b)
|
||||
if not d:
|
||||
self.askForAuth(b"none", b"")
|
||||
# this will fail, we'll move on
|
||||
return
|
||||
d.addCallback(self._cbSignedData)
|
||||
d.addErrback(self._ebAuth)
|
||||
|
||||
def ssh_USERAUTH_PK_OK_password(self, packet):
|
||||
"""
|
||||
This is MSG_USERAUTH_PASSWD_CHANGEREQ. The password given has expired.
|
||||
We ask for an old password and a new password, then send both back to
|
||||
the server.
|
||||
"""
|
||||
prompt, language, rest = getNS(packet, 2)
|
||||
self._oldPass = self._newPass = None
|
||||
d = self.getPassword(b"Old Password: ")
|
||||
d = d.addCallbacks(self._setOldPass, self._ebAuth)
|
||||
d.addCallback(lambda ignored: self.getPassword(prompt))
|
||||
d.addCallbacks(self._setNewPass, self._ebAuth)
|
||||
|
||||
def ssh_USERAUTH_PK_OK_keyboard_interactive(self, packet):
|
||||
"""
|
||||
This is MSG_USERAUTH_INFO_RESPONSE. The server has sent us the
|
||||
questions it wants us to answer, so we ask the user and sent the
|
||||
responses.
|
||||
"""
|
||||
name, instruction, lang, data = getNS(packet, 3)
|
||||
numPrompts = struct.unpack("!L", data[:4])[0]
|
||||
data = data[4:]
|
||||
prompts = []
|
||||
for i in range(numPrompts):
|
||||
prompt, data = getNS(data)
|
||||
echo = bool(ord(data[0:1]))
|
||||
data = data[1:]
|
||||
prompts.append((prompt, echo))
|
||||
d = self.getGenericAnswers(name, instruction, prompts)
|
||||
d.addCallback(self._cbGenericAnswers)
|
||||
d.addErrback(self._ebAuth)
|
||||
|
||||
def _cbSignedData(self, signedData):
|
||||
"""
|
||||
Called back out of self.signData with the signed data. Send the
|
||||
authentication request with the signature.
|
||||
|
||||
@param signedData: the data signed by the user's private key.
|
||||
@type signedData: L{bytes}
|
||||
"""
|
||||
publicKey = self.lastPublicKey
|
||||
self.askForAuth(
|
||||
b"publickey",
|
||||
b"\x01" + NS(publicKey.sshType()) + NS(publicKey.blob()) + NS(signedData),
|
||||
)
|
||||
|
||||
def _setOldPass(self, op):
|
||||
"""
|
||||
Called back when we are choosing a new password. Simply store the old
|
||||
password for now.
|
||||
|
||||
@param op: the old password as entered by the user
|
||||
@type op: L{bytes}
|
||||
"""
|
||||
self._oldPass = op
|
||||
|
||||
def _setNewPass(self, np):
|
||||
"""
|
||||
Called back when we are choosing a new password. Get the old password
|
||||
and send the authentication message with both.
|
||||
|
||||
@param np: the new password as entered by the user
|
||||
@type np: L{bytes}
|
||||
"""
|
||||
op = self._oldPass
|
||||
self._oldPass = None
|
||||
self.askForAuth(b"password", b"\xff" + NS(op) + NS(np))
|
||||
|
||||
def _cbGenericAnswers(self, responses):
|
||||
"""
|
||||
Called back when we are finished answering keyboard-interactive
|
||||
questions. Send the info back to the server in a
|
||||
MSG_USERAUTH_INFO_RESPONSE.
|
||||
|
||||
@param responses: a list of L{bytes} responses
|
||||
@type responses: L{list}
|
||||
"""
|
||||
data = struct.pack("!L", len(responses))
|
||||
for r in responses:
|
||||
data += NS(r.encode("UTF8"))
|
||||
self.transport.sendPacket(MSG_USERAUTH_INFO_RESPONSE, data)
|
||||
|
||||
def auth_publickey(self):
|
||||
"""
|
||||
Try to authenticate with a public key. Ask the user for a public key;
|
||||
if the user has one, send the request to the server and return True.
|
||||
Otherwise, return False.
|
||||
|
||||
@rtype: L{bool}
|
||||
"""
|
||||
d = defer.maybeDeferred(self.getPublicKey)
|
||||
d.addBoth(self._cbGetPublicKey)
|
||||
return d
|
||||
|
||||
def _cbGetPublicKey(self, publicKey):
|
||||
if not isinstance(publicKey, keys.Key): # failure or None
|
||||
publicKey = None
|
||||
if publicKey is not None:
|
||||
self.lastPublicKey = publicKey
|
||||
self.triedPublicKeys.append(publicKey)
|
||||
self._log.debug("using key of type {keyType}", keyType=publicKey.type())
|
||||
self.askForAuth(
|
||||
b"publickey", b"\x00" + NS(publicKey.sshType()) + NS(publicKey.blob())
|
||||
)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def auth_password(self):
|
||||
"""
|
||||
Try to authenticate with a password. Ask the user for a password.
|
||||
If the user will return a password, return True. Otherwise, return
|
||||
False.
|
||||
|
||||
@rtype: L{bool}
|
||||
"""
|
||||
d = self.getPassword()
|
||||
if d:
|
||||
d.addCallbacks(self._cbPassword, self._ebAuth)
|
||||
return True
|
||||
else: # returned None, don't do password auth
|
||||
return False
|
||||
|
||||
def auth_keyboard_interactive(self):
|
||||
"""
|
||||
Try to authenticate with keyboard-interactive authentication. Send
|
||||
the request to the server and return True.
|
||||
|
||||
@rtype: L{bool}
|
||||
"""
|
||||
self._log.debug("authing with keyboard-interactive")
|
||||
self.askForAuth(b"keyboard-interactive", NS(b"") + NS(b""))
|
||||
return True
|
||||
|
||||
def _cbPassword(self, password):
|
||||
"""
|
||||
Called back when the user gives a password. Send the request to the
|
||||
server.
|
||||
|
||||
@param password: the password the user entered
|
||||
@type password: L{bytes}
|
||||
"""
|
||||
self.askForAuth(b"password", b"\x00" + NS(password))
|
||||
|
||||
def signData(self, publicKey, signData):
|
||||
"""
|
||||
Sign the given data with the given public key.
|
||||
|
||||
By default, this will call getPrivateKey to get the private key,
|
||||
then sign the data using Key.sign().
|
||||
|
||||
This method is factored out so that it can be overridden to use
|
||||
alternate methods, such as a key agent.
|
||||
|
||||
@param publicKey: The public key object returned from L{getPublicKey}
|
||||
@type publicKey: L{keys.Key}
|
||||
|
||||
@param signData: the data to be signed by the private key.
|
||||
@type signData: L{bytes}
|
||||
@return: a Deferred that's called back with the signature
|
||||
@rtype: L{defer.Deferred}
|
||||
"""
|
||||
key = self.getPrivateKey()
|
||||
if not key:
|
||||
return
|
||||
return key.addCallback(self._cbSignData, signData)
|
||||
|
||||
def _cbSignData(self, privateKey, signData):
|
||||
"""
|
||||
Called back when the private key is returned. Sign the data and
|
||||
return the signature.
|
||||
|
||||
@param privateKey: the private key object
|
||||
@type privateKey: L{keys.Key}
|
||||
@param signData: the data to be signed by the private key.
|
||||
@type signData: L{bytes}
|
||||
@return: the signature
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return privateKey.sign(signData)
|
||||
|
||||
def getPublicKey(self):
|
||||
"""
|
||||
Return a public key for the user. If no more public keys are
|
||||
available, return L{None}.
|
||||
|
||||
This implementation always returns L{None}. Override it in a
|
||||
subclass to actually find and return a public key object.
|
||||
|
||||
@rtype: L{Key} or L{None}
|
||||
"""
|
||||
return None
|
||||
|
||||
def getPrivateKey(self):
|
||||
"""
|
||||
Return a L{Deferred} that will be called back with the private key
|
||||
object corresponding to the last public key from getPublicKey().
|
||||
If the private key is not available, errback on the Deferred.
|
||||
|
||||
@rtype: L{Deferred} called back with L{Key}
|
||||
"""
|
||||
return defer.fail(NotImplementedError())
|
||||
|
||||
def getPassword(self, prompt=None):
|
||||
"""
|
||||
Return a L{Deferred} that will be called back with a password.
|
||||
prompt is a string to display for the password, or None for a generic
|
||||
'user@hostname's password: '.
|
||||
|
||||
@type prompt: L{bytes}/L{None}
|
||||
@rtype: L{defer.Deferred}
|
||||
"""
|
||||
return defer.fail(NotImplementedError())
|
||||
|
||||
def getGenericAnswers(self, name, instruction, prompts):
|
||||
"""
|
||||
Returns a L{Deferred} with the responses to the promopts.
|
||||
|
||||
@param name: The name of the authentication currently in progress.
|
||||
@param instruction: Describes what the authentication wants.
|
||||
@param prompts: A list of (prompt, echo) pairs, where prompt is a
|
||||
string to display and echo is a boolean indicating whether the
|
||||
user's response should be echoed as they type it.
|
||||
"""
|
||||
return defer.fail(NotImplementedError())
|
||||
|
||||
|
||||
MSG_USERAUTH_REQUEST = 50
|
||||
MSG_USERAUTH_FAILURE = 51
|
||||
MSG_USERAUTH_SUCCESS = 52
|
||||
MSG_USERAUTH_BANNER = 53
|
||||
MSG_USERAUTH_INFO_RESPONSE = 61
|
||||
MSG_USERAUTH_PK_OK = 60
|
||||
|
||||
messages = {}
|
||||
for k, v in list(locals().items()):
|
||||
if k[:4] == "MSG_":
|
||||
messages[v] = k
|
||||
|
||||
SSHUserAuthServer.protocolMessages = messages
|
||||
SSHUserAuthClient.protocolMessages = messages
|
||||
del messages
|
||||
del v
|
||||
|
||||
# Doubles, not included in the protocols' mappings
|
||||
MSG_USERAUTH_PASSWD_CHANGEREQ = 60
|
||||
MSG_USERAUTH_INFO_REQUEST = 60
|
||||
114
.venv/lib/python3.12/site-packages/twisted/conch/stdio.py
Normal file
114
.venv/lib/python3.12/site-packages/twisted/conch/stdio.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_manhole -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Asynchronous local terminal input handling
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import termios
|
||||
import tty
|
||||
|
||||
from twisted.conch.insults.insults import ServerProtocol
|
||||
from twisted.conch.manhole import ColoredManhole
|
||||
from twisted.internet import defer, protocol, reactor, stdio
|
||||
from twisted.python import failure, log, reflect
|
||||
|
||||
|
||||
class UnexpectedOutputError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TerminalProcessProtocol(protocol.ProcessProtocol):
|
||||
def __init__(self, proto):
|
||||
self.proto = proto
|
||||
self.onConnection = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
self.proto.makeConnection(self)
|
||||
self.onConnection.callback(None)
|
||||
self.onConnection = None
|
||||
|
||||
def write(self, data):
|
||||
"""
|
||||
Write to the terminal.
|
||||
|
||||
@param data: Data to write.
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
self.transport.write(data)
|
||||
|
||||
def outReceived(self, data):
|
||||
"""
|
||||
Receive data from the terminal.
|
||||
|
||||
@param data: Data received.
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
self.proto.dataReceived(data)
|
||||
|
||||
def errReceived(self, data):
|
||||
"""
|
||||
Report an error.
|
||||
|
||||
@param data: Data to include in L{Failure}.
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
self.transport.loseConnection()
|
||||
if self.proto is not None:
|
||||
self.proto.connectionLost(failure.Failure(UnexpectedOutputError(data)))
|
||||
self.proto = None
|
||||
|
||||
def childConnectionLost(self, childFD):
|
||||
if self.proto is not None:
|
||||
self.proto.childConnectionLost(childFD)
|
||||
|
||||
def processEnded(self, reason):
|
||||
if self.proto is not None:
|
||||
self.proto.connectionLost(reason)
|
||||
self.proto = None
|
||||
|
||||
|
||||
class ConsoleManhole(ColoredManhole):
|
||||
"""
|
||||
A manhole protocol specifically for use with L{stdio.StandardIO}.
|
||||
"""
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
When the connection is lost, there is nothing more to do. Stop the
|
||||
reactor so that the process can exit.
|
||||
"""
|
||||
reactor.stop()
|
||||
|
||||
|
||||
def runWithProtocol(klass):
|
||||
fd = sys.__stdin__.fileno()
|
||||
oldSettings = termios.tcgetattr(fd)
|
||||
tty.setraw(fd)
|
||||
try:
|
||||
stdio.StandardIO(ServerProtocol(klass))
|
||||
reactor.run()
|
||||
finally:
|
||||
termios.tcsetattr(fd, termios.TCSANOW, oldSettings)
|
||||
os.write(fd, b"\r\x1bc\r")
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
log.startLogging(open("child.log", "w"))
|
||||
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
if argv:
|
||||
klass = reflect.namedClass(argv[0])
|
||||
else:
|
||||
klass = ConsoleManhole
|
||||
runWithProtocol(klass)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
91
.venv/lib/python3.12/site-packages/twisted/conch/tap.py
Normal file
91
.venv/lib/python3.12/site-packages/twisted/conch/tap.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_tap -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Support module for making SSH servers with twistd.
|
||||
"""
|
||||
|
||||
from twisted.application import strports
|
||||
from twisted.conch import checkers as conch_checkers, unix
|
||||
from twisted.conch.openssh_compat import factory
|
||||
from twisted.cred import portal, strcred
|
||||
from twisted.python import usage
|
||||
|
||||
|
||||
class Options(usage.Options, strcred.AuthOptionMixin):
|
||||
synopsis = "[-i <interface>] [-p <port>] [-d <dir>] "
|
||||
longdesc = (
|
||||
"Makes a Conch SSH server. If no authentication methods are "
|
||||
"specified, the default authentication methods are UNIX passwords "
|
||||
"and SSH public keys. If --auth options are "
|
||||
"passed, only the measures specified will be used."
|
||||
)
|
||||
optParameters = [
|
||||
["interface", "i", "", "local interface to which we listen"],
|
||||
["port", "p", "tcp:22", "Port on which to listen"],
|
||||
["data", "d", "/etc", "directory to look for host keys in"],
|
||||
[
|
||||
"moduli",
|
||||
"",
|
||||
None,
|
||||
"directory to look for moduli in " "(if different from --data)",
|
||||
],
|
||||
]
|
||||
compData = usage.Completions(
|
||||
optActions={
|
||||
"data": usage.CompleteDirs(descr="data directory"),
|
||||
"moduli": usage.CompleteDirs(descr="moduli directory"),
|
||||
"interface": usage.CompleteNetInterfaces(),
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, *a, **kw):
|
||||
usage.Options.__init__(self, *a, **kw)
|
||||
|
||||
# Call the default addCheckers (for backwards compatibility) that will
|
||||
# be used if no --auth option is provided - note that conch's
|
||||
# UNIXPasswordDatabase is used, instead of twisted.plugins.cred_unix's
|
||||
# checker
|
||||
super().addChecker(conch_checkers.UNIXPasswordDatabase())
|
||||
super().addChecker(
|
||||
conch_checkers.SSHPublicKeyChecker(conch_checkers.UNIXAuthorizedKeysFiles())
|
||||
)
|
||||
self._usingDefaultAuth = True
|
||||
|
||||
def addChecker(self, checker):
|
||||
"""
|
||||
Add the checker specified. If any checkers are added, the default
|
||||
checkers are automatically cleared and the only checkers will be the
|
||||
specified one(s).
|
||||
"""
|
||||
if self._usingDefaultAuth:
|
||||
self["credCheckers"] = []
|
||||
self["credInterfaces"] = {}
|
||||
self._usingDefaultAuth = False
|
||||
super().addChecker(checker)
|
||||
|
||||
|
||||
def makeService(config):
|
||||
"""
|
||||
Construct a service for operating a SSH server.
|
||||
|
||||
@param config: An L{Options} instance specifying server options, including
|
||||
where server keys are stored and what authentication methods to use.
|
||||
|
||||
@return: A L{twisted.application.service.IService} provider which contains
|
||||
the requested SSH server.
|
||||
"""
|
||||
|
||||
t = factory.OpenSSHFactory()
|
||||
|
||||
r = unix.UnixSSHRealm()
|
||||
t.portal = portal.Portal(r, config.get("credCheckers", []))
|
||||
t.dataRoot = config["data"]
|
||||
t.moduliRoot = config["moduli"] or config["data"]
|
||||
|
||||
port = config["port"]
|
||||
if config["interface"]:
|
||||
# Add warning here
|
||||
port += ":interface=" + config["interface"]
|
||||
return strports.service(port, t)
|
||||
1144
.venv/lib/python3.12/site-packages/twisted/conch/telnet.py
Normal file
1144
.venv/lib/python3.12/site-packages/twisted/conch/telnet.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1 @@
|
||||
"conch tests"
|
||||
671
.venv/lib/python3.12/site-packages/twisted/conch/test/keydata.py
Normal file
671
.venv/lib/python3.12/site-packages/twisted/conch/test/keydata.py
Normal file
@@ -0,0 +1,671 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_keys -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
# pylint: disable=I0011,C0103,W9401,W9402
|
||||
|
||||
"""
|
||||
Data used by test_keys as well as others.
|
||||
"""
|
||||
|
||||
|
||||
from base64 import decodebytes
|
||||
|
||||
RSAData = {
|
||||
"n": int(
|
||||
"269413617238113438198661010376758399219880277968382122687862697"
|
||||
"296942471209955603071120391975773283844560230371884389952067978"
|
||||
"789684135947515341209478065209455427327369102356204259106807047"
|
||||
"964139525310539133073743116175821417513079706301100600025815509"
|
||||
"786721808719302671068052414466483676821987505720384645561708425"
|
||||
"794379383191274856941628512616355437197560712892001107828247792"
|
||||
"561858327085521991407807015047750218508971611590850575870321007"
|
||||
"991909043252470730134547038841839367764074379439843108550888709"
|
||||
"430958143271417044750314742880542002948053835745429446485015316"
|
||||
"60749404403945254975473896534482849256068133525751"
|
||||
),
|
||||
"e": 65537,
|
||||
"d": int(
|
||||
"420335724286999695680502438485489819800002417295071059780489811"
|
||||
"840828351636754206234982682752076205397047218449504537476523960"
|
||||
"987613148307573487322720481066677105211155388802079519869249746"
|
||||
"774085882219244493290663802569201213676433159425782937159766786"
|
||||
"329742053214957933941260042101377175565683849732354700525628975"
|
||||
"239000548651346620826136200952740446562751690924335365940810658"
|
||||
"931238410612521441739702170503547025018016868116037053013935451"
|
||||
"477930426013703886193016416453215950072147440344656137718959053"
|
||||
"897268663969428680144841987624962928576808352739627262941675617"
|
||||
"7724661940425316604626522633351193810751757014073"
|
||||
),
|
||||
"p": int(
|
||||
"152689878451107675391723141129365667732639179427453246378763774"
|
||||
"448531436802867910180261906924087589684175595016060014593521649"
|
||||
"964959248408388984465569934780790357826811592229318702991401054"
|
||||
"226302790395714901636384511513449977061729214247279176398290513"
|
||||
"085108930550446985490864812445551198848562639933888780317"
|
||||
),
|
||||
"q": int(
|
||||
"176444974592327996338888725079951900172097062203378367409936859"
|
||||
"072670162290963119826394224277287608693818012745872307600855894"
|
||||
"647300295516866118620024751601329775653542084052616260193174546"
|
||||
"400544176890518564317596334518015173606460860373958663673307503"
|
||||
"231977779632583864454001476729233959405710696795574874403"
|
||||
),
|
||||
"u": int(
|
||||
"936018002388095842969518498561007090965136403384715613439364803"
|
||||
"229386793506402222847415019772053080458257034241832795210460612"
|
||||
"924445085372678524176842007912276654532773301546269997020970818"
|
||||
"155956828553418266110329867222673040098885651348225673298948529"
|
||||
"93885224775891490070400861134282266967852120152546563278"
|
||||
),
|
||||
}
|
||||
|
||||
DSAData = {
|
||||
"g": int(
|
||||
"10253261326864117157640690761723586967382334319435778695"
|
||||
"29171533815411392477819921538350732400350395446211982054"
|
||||
"96512489289702949127531056893725702005035043292195216541"
|
||||
"11525058911428414042792836395195432445511200566318251789"
|
||||
"10575695836669396181746841141924498545494149998282951407"
|
||||
"18645344764026044855941864175"
|
||||
),
|
||||
"p": int(
|
||||
"10292031726231756443208850082191198787792966516790381991"
|
||||
"77502076899763751166291092085666022362525614129374702633"
|
||||
"26262930887668422949051881895212412718444016917144560705"
|
||||
"45675251775747156453237145919794089496168502517202869160"
|
||||
"78674893099371444940800865897607102159386345313384716752"
|
||||
"18590012064772045092956919481"
|
||||
),
|
||||
"q": 1393384845225358996250882900535419012502712821577,
|
||||
"x": 1220877188542930584999385210465204342686893855021,
|
||||
"y": int(
|
||||
"14604423062661947579790240720337570315008549983452208015"
|
||||
"39426429789435409684914513123700756086453120500041882809"
|
||||
"10283610277194188071619191739512379408443695946763554493"
|
||||
"86398594314468629823767964702559709430618263927529765769"
|
||||
"10270265745700231533660131769648708944711006508965764877"
|
||||
"684264272082256183140297951"
|
||||
),
|
||||
}
|
||||
|
||||
ECDatanistp256 = {
|
||||
"x": int(
|
||||
"762825130203920963171185031449647317742997734817505505433829043"
|
||||
"45687059013883"
|
||||
),
|
||||
"y": int(
|
||||
"815431978646028526322656647694416475343443758943143196810611371"
|
||||
"59310646683104"
|
||||
),
|
||||
"privateValue": int(
|
||||
"3463874347721034170096400845565569825355565567882605"
|
||||
"9678074967909361042656500"
|
||||
),
|
||||
"curve": b"ecdsa-sha2-nistp256",
|
||||
}
|
||||
|
||||
SKECDatanistp256 = {
|
||||
"x": int(
|
||||
"239399367768747020111880335553299826848360860410053166887934464"
|
||||
"83115637049597"
|
||||
),
|
||||
"y": int(
|
||||
"114119006635761413192818806701564910719235784173643448381780025"
|
||||
"223832906554748"
|
||||
),
|
||||
"curve": b"sk-ecdsa-sha2-nistp256@openssh.com",
|
||||
}
|
||||
|
||||
ECDatanistp384 = {
|
||||
"privateValue": int(
|
||||
"280814107134858470598753916394807521398239633534281"
|
||||
"633982576099083357871098966021020900021966162732114"
|
||||
"95718603965098"
|
||||
),
|
||||
"x": int(
|
||||
"10036914308591746758780165503819213553101287571902957054148542"
|
||||
"504671046744460374996612408381962208627004841444205030"
|
||||
),
|
||||
"y": int(
|
||||
"17337335659928075994560513699823544906448896792102247714689323"
|
||||
"575406618073069185107088229463828921069465902299522926"
|
||||
),
|
||||
"curve": b"ecdsa-sha2-nistp384",
|
||||
}
|
||||
|
||||
ECDatanistp521 = {
|
||||
"x": int(
|
||||
"12944742826257420846659527752683763193401384271391513286022917"
|
||||
"29910013082920512632908350502247952686156279140016049549948975"
|
||||
"670668730618745449113644014505462"
|
||||
),
|
||||
"y": int(
|
||||
"10784108810271976186737587749436295782985563640368689081052886"
|
||||
"16296815984553198866894145509329328086635278430266482551941240"
|
||||
"591605833440825557820439734509311"
|
||||
),
|
||||
"privateValue": int(
|
||||
"662751235215460886290293902658128847495347691199214"
|
||||
"706697089140769672273950767961331442265530524063943"
|
||||
"548846724348048614239791498442599782310681891569896"
|
||||
"0565"
|
||||
),
|
||||
"curve": b"ecdsa-sha2-nistp521",
|
||||
}
|
||||
|
||||
Ed25519Data = {
|
||||
"a": (
|
||||
b"\xf1\x16\xd1\x15J\x1e\x15\x0e\x19^\x19F\xb5\xf2D\r\xb2R\xa0\xae*k"
|
||||
b"#\x13sE\xfd@\xd9W{\x8b"
|
||||
),
|
||||
"k": (
|
||||
b"7/%\xda\x8d\xd4\xa8\x9ax|a\xf0\x98\x01\xc6\xf4^mg\x05i17Li\r\x05U"
|
||||
b"\xbb\xc9DX"
|
||||
),
|
||||
}
|
||||
|
||||
SKEd25519Data = {
|
||||
"a": (
|
||||
b"\x08}'U\xd2i\x04\x11\xea\x01~+\x165iRM\xdd\xe6R\x7f\xd3\xaf\\\xa8p"
|
||||
b"\xa0LL\xe5\x8a\xa0"
|
||||
),
|
||||
"k": (
|
||||
b"7/%\xda\x8d\xd4\xa8\x9ax|a\xf0\x98\x01\xc6\xf4^mg\x05i17Li\r\x05U"
|
||||
b"\xbb\xc9DX"
|
||||
),
|
||||
}
|
||||
|
||||
privateECDSA_openssh521 = b"""-----BEGIN EC PRIVATE KEY-----
|
||||
MIHcAgEBBEIAjn0lSVF6QweS4bjOGP9RHwqxUiTastSE0MVuLtFvkxygZqQ712oZ
|
||||
ewMvqKkxthMQgxzSpGtRBcmkL7RqZ94+18qgBwYFK4EEACOhgYkDgYYABAFpX/6B
|
||||
mxxglwD+VpEvw0hcyxVzLxNnMGzxZGF7xmNj8nlF7M+TQctdlR2Xv/J+AgIeVGmB
|
||||
j2p84bkV9jBzrUNJEACsJjttZw8NbUrhxjkLT/3rMNtuwjE4vLja0P7DMTE0EV8X
|
||||
f09ETdku/z/1tOSSrSvRwmUcM9nQUJtHHAZlr5Q0fw==
|
||||
-----END EC PRIVATE KEY-----"""
|
||||
|
||||
# New format introduced in OpenSSH 6.5
|
||||
privateECDSA_openssh521_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAArAAAABNlY2RzYS
|
||||
1zaGEyLW5pc3RwNTIxAAAACG5pc3RwNTIxAAAAhQQBaV/+gZscYJcA/laRL8NIXMsVcy8T
|
||||
ZzBs8WRhe8ZjY/J5RezPk0HLXZUdl7/yfgICHlRpgY9qfOG5FfYwc61DSRAArCY7bWcPDW
|
||||
1K4cY5C0/96zDbbsIxOLy42tD+wzExNBFfF39PRE3ZLv8/9bTkkq0r0cJlHDPZ0FCbRxwG
|
||||
Za+UNH8AAAEAeRISlnkSEpYAAAATZWNkc2Etc2hhMi1uaXN0cDUyMQAAAAhuaXN0cDUyMQ
|
||||
AAAIUEAWlf/oGbHGCXAP5WkS/DSFzLFXMvE2cwbPFkYXvGY2PyeUXsz5NBy12VHZe/8n4C
|
||||
Ah5UaYGPanzhuRX2MHOtQ0kQAKwmO21nDw1tSuHGOQtP/esw227CMTi8uNrQ/sMxMTQRXx
|
||||
d/T0RN2S7/P/W05JKtK9HCZRwz2dBQm0ccBmWvlDR/AAAAQgCOfSVJUXpDB5LhuM4Y/1Ef
|
||||
CrFSJNqy1ITQxW4u0W+THKBmpDvXahl7Ay+oqTG2ExCDHNKka1EFyaQvtGpn3j7XygAAAA
|
||||
ABAg==
|
||||
-----END OPENSSH PRIVATE KEY-----
|
||||
"""
|
||||
|
||||
publicECDSA_openssh521 = (
|
||||
b"ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACF"
|
||||
b"BAFpX/6BmxxglwD+VpEvw0hcyxVzLxNnMGzxZGF7xmNj8nlF7M+TQctdlR2Xv/J+AgIeVGmB"
|
||||
b"j2p84bkV9jBzrUNJEACsJjttZw8NbUrhxjkLT/3rMNtuwjE4vLja0P7DMTE0EV8Xf09ETdku"
|
||||
b"/z/1tOSSrSvRwmUcM9nQUJtHHAZlr5Q0fw== comment"
|
||||
)
|
||||
|
||||
privateECDSA_openssh384 = b"""-----BEGIN EC PRIVATE KEY-----
|
||||
MIGkAgEBBDAtAi7I8j73WCX20qUM5hhHwHuFzYWYYILs2Sh8UZ+awNkARZ/Fu2LU
|
||||
LLl5RtOQpbWgBwYFK4EEACKhZANiAATU17sA9P5FRwSknKcFsjjsk0+E3CeXPYX0
|
||||
Tk/M0HK3PpWQWgrO8JdRHP9eFE9O/23P8BumwFt7F/AvPlCzVd35VfraFT0o4cCW
|
||||
G0RqpQ+np31aKmeJshkcYALEchnU+tQ=
|
||||
-----END EC PRIVATE KEY-----"""
|
||||
|
||||
# New format introduced in OpenSSH 6.5
|
||||
privateECDSA_openssh384_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAiAAAABNlY2RzYS
|
||||
1zaGEyLW5pc3RwMzg0AAAACG5pc3RwMzg0AAAAYQTU17sA9P5FRwSknKcFsjjsk0+E3CeX
|
||||
PYX0Tk/M0HK3PpWQWgrO8JdRHP9eFE9O/23P8BumwFt7F/AvPlCzVd35VfraFT0o4cCWG0
|
||||
RqpQ+np31aKmeJshkcYALEchnU+tQAAADIiktpWIpLaVgAAAATZWNkc2Etc2hhMi1uaXN0
|
||||
cDM4NAAAAAhuaXN0cDM4NAAAAGEE1Ne7APT+RUcEpJynBbI47JNPhNwnlz2F9E5PzNBytz
|
||||
6VkFoKzvCXURz/XhRPTv9tz/AbpsBbexfwLz5Qs1Xd+VX62hU9KOHAlhtEaqUPp6d9Wipn
|
||||
ibIZHGACxHIZ1PrUAAAAMC0CLsjyPvdYJfbSpQzmGEfAe4XNhZhgguzZKHxRn5rA2QBFn8
|
||||
W7YtQsuXlG05CltQAAAAA=
|
||||
-----END OPENSSH PRIVATE KEY-----
|
||||
"""
|
||||
|
||||
publicECDSA_openssh384 = (
|
||||
b"ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABh"
|
||||
b"BNTXuwD0/kVHBKScpwWyOOyTT4TcJ5c9hfROT8zQcrc+lZBaCs7wl1Ec/14UT07/bc/wG6bA"
|
||||
b"W3sX8C8+ULNV3flV+toVPSjhwJYbRGqlD6enfVoqZ4myGRxgAsRyGdT61A== comment"
|
||||
)
|
||||
|
||||
publicECDSA_openssh = (
|
||||
b"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABB"
|
||||
b"BKimX1DZ7+Qj0SpfePMbo1pb6yGkAb5l7duC1l855yD7tEfQfqk7bc7v46We1hLMyz6ObUBY"
|
||||
b"gkN/34n42F4vpeA= comment"
|
||||
)
|
||||
|
||||
publicSKECDSA_openssh = (
|
||||
b"sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3"
|
||||
b"BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBDTthidmBSzlQiO8aZPfLmUDOS2TSRevW8IrHPK"
|
||||
b"IhYj9/E0RnTyvPIB1eWQx4rQl5iO1mihuBz+u4LkjwVEU3XwAAAAUc3NoOmVjZHNhLWZpZG8y"
|
||||
b"LXRlc3Q= comment"
|
||||
)
|
||||
|
||||
publicSKEd25519_openssh = (
|
||||
b"sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIA"
|
||||
b"h9J1XSaQQR6gF+KxY1aVJN3eZSf9OvXKhwoExM5YqgAAAABHNzaDo= comment"
|
||||
)
|
||||
|
||||
publicSKECDSA_cert_openssh = (
|
||||
b"sk-ecdsa-sha2-nistp256-cert-v01@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3"
|
||||
b"BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBDTthidmBSzlQiO8aZPfLmUDOS2TSRevW8IrHPK"
|
||||
b"IhYj9/E0RnTyvPIB1eWQx4rQl5iO1mihuBz+u4LkjwVEU3XwAAAAUc3NoOmVjZHNhLWZpZG8y"
|
||||
b"LXRlc3Q= comment"
|
||||
)
|
||||
|
||||
publicSKEd25519_cert_openssh = (
|
||||
b"sk-ssh-ed25519-cert-v01@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIA"
|
||||
b"h9J1XSaQQR6gF+KxY1aVJN3eZSf9OvXKhwoExM5YqgAAAABHNzaDo= comment"
|
||||
)
|
||||
|
||||
privateECDSA_openssh = b"""-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIEyU1YOT2JxxofwbJXIjGftdNcJK55aQdNrhIt2xYQz0oAoGCCqGSM49
|
||||
AwEHoUQDQgAEqKZfUNnv5CPRKl948xujWlvrIaQBvmXt24LWXznnIPu0R9B+qTtt
|
||||
zu/jpZ7WEszLPo5tQFiCQ3/fifjYXi+l4A==
|
||||
-----END EC PRIVATE KEY-----"""
|
||||
|
||||
# New format introduced in OpenSSH 6.5
|
||||
privateECDSA_openssh_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAaAAAABNlY2RzYS
|
||||
1zaGEyLW5pc3RwMjU2AAAACG5pc3RwMjU2AAAAQQSopl9Q2e/kI9EqX3jzG6NaW+shpAG+
|
||||
Ze3bgtZfOecg+7RH0H6pO23O7+OlntYSzMs+jm1AWIJDf9+J+NheL6XgAAAAmCKU4hcilO
|
||||
IXAAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBKimX1DZ7+Qj0Spf
|
||||
ePMbo1pb6yGkAb5l7duC1l855yD7tEfQfqk7bc7v46We1hLMyz6ObUBYgkN/34n42F4vpe
|
||||
AAAAAgTJTVg5PYnHGh/BslciMZ+101wkrnlpB02uEi3bFhDPQAAAAA
|
||||
-----END OPENSSH PRIVATE KEY-----
|
||||
"""
|
||||
|
||||
publicEd25519_openssh = (
|
||||
b"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPEW0RVKHhUOGV4ZRrXyRA2yUqCuKmsjE3NF"
|
||||
b"/UDZV3uL comment"
|
||||
)
|
||||
|
||||
# OpenSSH has only ever supported the "new" (v1) format for Ed25519.
|
||||
privateEd25519_openssh_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
|
||||
QyNTUxOQAAACDxFtEVSh4VDhleGUa18kQNslKgriprIxNzRf1A2Vd7iwAAAJA61eMLOtXj
|
||||
CwAAAAtzc2gtZWQyNTUxOQAAACDxFtEVSh4VDhleGUa18kQNslKgriprIxNzRf1A2Vd7iw
|
||||
AAAEA3LyXajdSomnh8YfCYAcb0Xm1nBWkxN0xpDQVVu8lEWPEW0RVKHhUOGV4ZRrXyRA2y
|
||||
UqCuKmsjE3NF/UDZV3uLAAAAB2NvbW1lbnQBAgMEBQY=
|
||||
-----END OPENSSH PRIVATE KEY-----"""
|
||||
|
||||
publicRSA_openssh = (
|
||||
b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDVaqx4I9bWG+wloVDEd2NQhEUBVUIUKirg"
|
||||
b"0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHjVyqgYwBGTJAkMgUyP"
|
||||
b"95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auIJNm/9NNN9b0b/h9qp"
|
||||
b"KSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY6RKXCpCnd1bqcPUWz"
|
||||
b"xiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvbw7KW4CC1ffdOgTtDc1"
|
||||
b"foNfICZgptyti8ZseZj3 comment"
|
||||
)
|
||||
|
||||
privateRSA_openssh = b"""-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEogIBAAKCAQEA1WqseCPW1hvsJaFQxHdjUIRFAVVCFCoq4NBg7tTpo61K+jkG
|
||||
XoRVdV8ANr9vqio/gyY3wWkuW/3w89J91pjNOkB41cqoGMARkyQJDIFMj/ec7RMW
|
||||
aqQE6Ul3w+RVZLN5aJ4sCOus6AQtIXcFp47vUzANpeW7PWriCTZv/TTTfW9G/4fa
|
||||
qSknqv+t9YXmPhq4eh1KserAWvcw3x/CpOTvP5FJlkDVGXctN8Ne7J2mOkSlwqQp
|
||||
3dW6nD1Fs8YsGGTVuj3fq3/NQqyn8RgLoFgVYgukKm5Dw+QEnzWjR45G7TOlZb28
|
||||
OyluAgtX33ToE7Q3NX6DXyAmYKbcrYvGbHmY9wIDAQABAoIBACFMCGaiKNW0+44P
|
||||
chuFCQC58k438BxXS+NRf54jp+Q6mFUb6ot6mB682Lqx+YkSGGCs6MwLTglaQGq6
|
||||
L5n4syRghLnOaZWa+eL8H1FNJxXbKyet77RprL59EOuGR3BztACHlRU7N/nnFOeA
|
||||
u2geG+bdu3NjuWfmsid/z88wm8KY/dkYNi82LvE9gXqf4QMtR9s0UWI53U/prKiL
|
||||
2dbzhMQXuXGdBghCeE27xSr0w1jNVSvtvjNfBOp75gQkY/It1z0bbNWcY0MvkoiN
|
||||
Pm7aGDfYDyVniR25RjReyc7Ei+2SWjMHD9+GCPmS6dvrOAg2yc3NCgFIWzk+esrG
|
||||
gKnc1DkCgYEA2XAG2OK81HiRUJTUwRuJOGxGZFpRoJoHPUiPA1HMaxKOfRqxZedx
|
||||
dTngMgV1jRhMr5OxSbFmX3hietEMyuZNQ7Oc9Gt95gyY3M8hYo7VLhLeBK7XJG6D
|
||||
MaIVokQ9IqliJiK5su1UCp0Ig6cHDf8ZGI7Yqx3aSJwxaBGhZm3j2B0CgYEA+0QX
|
||||
i6Q2vh43Haf2YWwExKrdeD4HjB4zAq4DFIeDeuWefQhnqPKqvxJwz3Kpp8cLHYjV
|
||||
IP2cY8pHMFVOi8TP9H8WpJISdKEJwsRunIwz76Xl9+ArrU9cEaoahDdb/Xrqw818
|
||||
sMjkH1Rjtcev3/QJp/zHJfxc6ZHXksWYHlbTsSMCgYBRr+mSn5QLSoRlPpSzO5IQ
|
||||
tXS4jMnvyQ4BMvovaBKhAyauz1FoFEwmmyikAjMIX+GncJgBNHleUo7Ezza8H0tV
|
||||
rOvBU4TH4WGoStSi/0ANgB8SqVDAKhh1lAwGmxZQqEvsQc177/dLyXUCaMSYuIaI
|
||||
GFpD5wIzlyJkk4MMRSp87QKBgGlmN8ZA3SHFBPOwuD5HlHx2/C3rPzk8lcNDAVHE
|
||||
Qpfz6Bakxu7s1EkQUDgE7jvN19DMzDJpkAegG1qf/jHNHjp+cR4ZlBpOTwzfX1LV
|
||||
0Rdu7NectlWd244hX7wkiLb8r6vw76QssNyfhrADEriL4t0PwO4jPUpQ/i+4KUZY
|
||||
v7YnAoGAZhb5IDTQVCW8YTGsgvvvnDUefkpVAmiVDQqTvh6/4UD6kKdUcDHpePzg
|
||||
Zrcid5rr3dXSMEbK4tdeQZvPtUg1Uaol3N7bNClIIdvWdPx+5S9T95wJcLnkoHam
|
||||
rXp0IjScTxfLP+Cq5V6lJ94/pX8Ppoj1FdZfNxeS4NYFSRA7kvY=
|
||||
-----END RSA PRIVATE KEY-----"""
|
||||
|
||||
# New format introduced in OpenSSH 6.5
|
||||
privateRSA_openssh_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn
|
||||
NhAAAAAwEAAQAAAQEA1WqseCPW1hvsJaFQxHdjUIRFAVVCFCoq4NBg7tTpo61K+jkGXoRV
|
||||
dV8ANr9vqio/gyY3wWkuW/3w89J91pjNOkB41cqoGMARkyQJDIFMj/ec7RMWaqQE6Ul3w+
|
||||
RVZLN5aJ4sCOus6AQtIXcFp47vUzANpeW7PWriCTZv/TTTfW9G/4faqSknqv+t9YXmPhq4
|
||||
eh1KserAWvcw3x/CpOTvP5FJlkDVGXctN8Ne7J2mOkSlwqQp3dW6nD1Fs8YsGGTVuj3fq3
|
||||
/NQqyn8RgLoFgVYgukKm5Dw+QEnzWjR45G7TOlZb28OyluAgtX33ToE7Q3NX6DXyAmYKbc
|
||||
rYvGbHmY9wAAA7gXkBoMF5AaDAAAAAdzc2gtcnNhAAABAQDVaqx4I9bWG+wloVDEd2NQhE
|
||||
UBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHjVyqgY
|
||||
wBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auIJNm
|
||||
/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY6
|
||||
RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvb
|
||||
w7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3AAAAAwEAAQAAAQAhTAhmoijVtPuOD3Ib
|
||||
hQkAufJON/AcV0vjUX+eI6fkOphVG+qLepgevNi6sfmJEhhgrOjMC04JWkBqui+Z+LMkYI
|
||||
S5zmmVmvni/B9RTScV2ysnre+0aay+fRDrhkdwc7QAh5UVOzf55xTngLtoHhvm3btzY7ln
|
||||
5rInf8/PMJvCmP3ZGDYvNi7xPYF6n+EDLUfbNFFiOd1P6ayoi9nW84TEF7lxnQYIQnhNu8
|
||||
Uq9MNYzVUr7b4zXwTqe+YEJGPyLdc9G2zVnGNDL5KIjT5u2hg32A8lZ4kduUY0XsnOxIvt
|
||||
klozBw/fhgj5kunb6zgINsnNzQoBSFs5PnrKxoCp3NQ5AAAAgQCFSxt6mxIQN54frV7a/s
|
||||
aW/t81a7k04haXkiYJvb1wIAOnNb0tG6DSB0cr1N6oqAcHG7gEIKcnQTxsOTnpQc7nFx3R
|
||||
TFy8PdImJv5q1v1Icq5G+nvD0xlgRB2lE6eA9WMp1HpdBgcWXfaLPctkOuKEWk2MBi0tnR
|
||||
zrg0x4PXlUzgAAAIEA2XAG2OK81HiRUJTUwRuJOGxGZFpRoJoHPUiPA1HMaxKOfRqxZedx
|
||||
dTngMgV1jRhMr5OxSbFmX3hietEMyuZNQ7Oc9Gt95gyY3M8hYo7VLhLeBK7XJG6DMaIVok
|
||||
Q9IqliJiK5su1UCp0Ig6cHDf8ZGI7Yqx3aSJwxaBGhZm3j2B0AAACBAPtEF4ukNr4eNx2n
|
||||
9mFsBMSq3Xg+B4weMwKuAxSHg3rlnn0IZ6jyqr8ScM9yqafHCx2I1SD9nGPKRzBVTovEz/
|
||||
R/FqSSEnShCcLEbpyMM++l5ffgK61PXBGqGoQ3W/166sPNfLDI5B9UY7XHr9/0Caf8xyX8
|
||||
XOmR15LFmB5W07EjAAAAAAEC
|
||||
-----END OPENSSH PRIVATE KEY-----
|
||||
"""
|
||||
|
||||
# Encrypted with the passphrase 'encrypted'
|
||||
privateRSA_openssh_encrypted = b"""-----BEGIN RSA PRIVATE KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: DES-EDE3-CBC,FFFFFFFFFFFFFFFF
|
||||
|
||||
p2A1YsHLXkpMVcsEqhh/nCYb5AqL0uMzfEIqc8hpZ/Ub8PtLsypilMkqzYTnZIGS
|
||||
ouyPjU/WgtR4VaDnutPWdgYaKdixSEmGhKghCtXFySZqCTJ4O8NCczsktYjUK3D4
|
||||
Jtl90zL6O81WBY6xP76PBQo9lrI/heAetATeyqutc18bwQIGU+gKk32qvfo15DfS
|
||||
VYiY0Ds4D7F7fd9pz+f5+UbFUCgU+tfDvBrqodYrUgmH7jKoW/CRDCHHyeEIZDbF
|
||||
mcMwdcKOyw1sRLaPdihRSVx3kOMvIotHKVTkIDMp+0RTNeXzQnp5U2qzsxzTcG/M
|
||||
UyJN38XXkuvq5VMj2zmmjHzx34w3NK3ZxpZcoaFUqUBlNp2C8hkCLrAa/DWobKqN
|
||||
5xA1ElrQvli9XXkT/RIuy4Gc10bbGEoJjuxNRibtSxxWd5Bd1E40ocOd4l1ebI8+
|
||||
w69XvMTnsmHvkBEADGF2zfRszKnMelg+W5NER1UDuNT03i+1cuhp+2AZg8z7niTO
|
||||
M17XP3ScGVxrQAEYgtxPrPeIpFJvOx2j5Yt78U9Y2WlaAG6DrubbYv2RsMIibhOG
|
||||
yk139vMdD8FwCey6yMkkhFAJwnBtC22MAWgjmC5c6AF3SRQSjjQXepPsJcLgpOjy
|
||||
YwjhnL8w56x9kVDUNPw9A9Cqgxo2sty34ATnKrh4h59PsP83LOL6OC5WjbASgZRd
|
||||
OIBD8RloQPISo+RUF7X0i4kdaHVNPlR0KyapR+3M5BwhQuvEO99IArDV2LNKGzfc
|
||||
W4ssugm8iyAJlmwmb2yRXIDHXabInWY7XCdGk8J2qPFbDTvnPbiagJBimjVjgpWw
|
||||
tV3sVlJYqmOqmCDP78J6he04l0vaHtiOWTDEmNCrK7oFMXIIp3XWjOZGPSOJFdPs
|
||||
6Go3YB+EGWfOQxqkFM28gcqmYfVPF2sa1FbZLz0ffO11Ma/rliZxZu7WdrAXe/tc
|
||||
BgIQ8etp2PwAK4jCwwVwjIO8FzqQGpS23Y9NY3rfi97ckgYXKESFtXPsMMA+drZd
|
||||
ThbXvccfh4EPmaqQXKf4WghHiVJ+/yuY1kUIDEl/O0jRZWT7STgBim/Aha1m6qRs
|
||||
zl1H7hkDbU4solb1GM5oPzbgGTzyBc+z0XxM9iFRM+fMzPB8+yYHTr4kPbVmKBjy
|
||||
SCovjQQVsHE4YeUGTq6k/NF5cVIRKTW/RlHvzxsky1Zj31MC736jrxGw4KG7VSLZ
|
||||
fP6F5jj+mXwS7m0v5to42JBZmRJdKUD88QaGE3ncyQ4yleW5bn9Lf9SuzQg1Dhao
|
||||
3rSA1RuexsHlIAHvGxx/17X+pyygl8DJbt6TBfbLQk9wc707DJTfh5M/bnk9wwIX
|
||||
l/Hsa1WtylAMW/2MzgiVy83MbYz4+Ss6GQ5W66okWji+NxrnrYEy6q+WgVQanp7X
|
||||
D+D7oKykqE1Cdvvulvtfl5fh8wlAs8mrUnKPBBUru348u++2lfacLkxRXyT1ooqY
|
||||
uSNE5nlwFt08N2Ou/bl7yq6QNRMYrRkn+UEfHWCNYDoGMHln2/i6Z1RapQzNarik
|
||||
tJf7radBz5nBwBjP08YAEACNSQvpsUgdqiuYjLwX7efFXQva2RzqaQ==
|
||||
-----END RSA PRIVATE KEY-----"""
|
||||
|
||||
# Encrypted with the passphrase 'encrypted', and using the new format
|
||||
# introduced in OpenSSH 6.5
|
||||
privateRSA_openssh_encrypted_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABD0f9WAof
|
||||
DTbmwztb8pdrSeAAAAEAAAAAEAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQDVaqx4I9bW
|
||||
G+wloVDEd2NQhEUBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n
|
||||
3WmM06QHjVyqgYwBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9T
|
||||
MA2l5bs9auIJNm/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQN
|
||||
UZdy03w17snaY6RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASf
|
||||
NaNHjkbtM6Vlvbw7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3AAADwPQaac8s1xX3af
|
||||
hQTQexj0vEAWDQsLYzDHN9G7W+UP5WHUu7igeu2GqAC/TOnjUXDP73I+EN3n7T3JFeDRfs
|
||||
U1Z6Zqb0NKHSRVYwDIdIi8qVohFv85g6+xQ01OpaoOzz+vI34OUvCRHQGTgR6L9fQShZyC
|
||||
McopYMYfbIse6KcqkfxX3KSdG1Pao6Njx/ShFRbgvmALpR/z0EaGCzHCDxpfUyAdnxm621
|
||||
Jzaf+LverWdN7sfrfMptaS9//9iJb70sL67K+YIB64qhDnA/w9UOQfXGQFL+AEtdM0BPv8
|
||||
thP1bs7T0yucBl+ZXdrDKVLZfaS3S/w85Jlgfu+a1DG73pOBOuag435iEJ9EnspjXiiydx
|
||||
GrfSRk2C+/c4fBDZVGFscK5bfQuUUZyU1qOagekxX7WLHFKk9xajnud+nrAN070SeNwlX8
|
||||
FZ2CI4KGlQfDvVUpKanYn8Kkj3fZ+YBGyx4M+19clF65FKSM0x1Rrh5tAmNT/SNDbSc28m
|
||||
ASxrBhztzxUFTrIn3tp+uqkJniFLmFsUtiAUmj8fNyE9blykU7dqq+CqpLA872nQ9bOHHA
|
||||
JsS1oBYmQ0n6AJz8WrYMdcepqWVld6Q8QSD1zdrY/sAWUovuBA1s4oIEXZhpXSS4ZJiMfh
|
||||
PVktKBwj5bmoG/mmwYLbo0JHntK8N3TGTzTGLq5TpSBBdVvWSWo7tnfEkrFObmhi1uJSrQ
|
||||
3zfPVP6BguboxBv+oxhaUBK8UOANe6ZwM4vfiu+QN+sZqWymHIfAktz7eWzwlToe4cKpdG
|
||||
Uv+e3/7Lo2dyMl3nke5HsSUrlsMGPREuGkBih8+o85ii6D+cuCiVtus3f5c78Cir80zLIr
|
||||
Z0wWvEAjciEvml00DWaA+JIaOrWwvXySaOzFGpCqC9SQjao379bvn9P3b7kVZsy6zBfHqm
|
||||
bNEJUOuhBZaY8Okz36chh1xqh4sz7m3nsZ3GYGcvM+3mvRY72QnqsQEG0Sp1XYIn2bHa29
|
||||
tqp7CG9X8J6dqMcPeoPRDWIX9gw7EPl/M0LP6xgewGJ9bgxwle6Mnr9kNITIswjAJqrLec
|
||||
zx7dfixjAPc42ADqrw/tEdFQcSqxigcfJNKO1LbDBjh+Hk/cSBou2PoxbIcl0qfQfbGcqI
|
||||
Dbpd695IEuiW9pYR22txNoIi+7cbMsuFHxQ/OqbrX/jCsprGNNJLAjgGsVEI1JnHWDH0db
|
||||
3UbqbOHAeY3ufoYXNY1utVOIACpW3r9wBw3FjRi04d70VcKr16OXvOAHGN2G++Y+kMya84
|
||||
Hl/Kt/gA==
|
||||
-----END OPENSSH PRIVATE KEY-----
|
||||
"""
|
||||
|
||||
# Encrypted with the passphrase 'testxp'. NB: this key was generated by
|
||||
# OpenSSH, so it doesn't use the same key data as the other keys here.
|
||||
privateRSA_openssh_encrypted_aes = b"""-----BEGIN RSA PRIVATE KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: AES-128-CBC,0673309A6ACCAB4B77DEE1C1E536AC26
|
||||
|
||||
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
|
||||
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
|
||||
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
|
||||
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
|
||||
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
|
||||
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
|
||||
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
|
||||
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
|
||||
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
|
||||
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
|
||||
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
|
||||
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
|
||||
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
|
||||
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
|
||||
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
|
||||
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
|
||||
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
|
||||
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
|
||||
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
|
||||
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
|
||||
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
|
||||
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
|
||||
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
|
||||
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
|
||||
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
|
||||
-----END RSA PRIVATE KEY-----"""
|
||||
|
||||
publicRSA_lsh = (
|
||||
b"{KDEwOnB1YmxpYy1rZXkoMTQ6cnNhLXBrY3MxLXNoYTEoMTpuMjU3OgDVaqx4I9bWG+wloVD"
|
||||
b"Ed2NQhEUBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHj"
|
||||
b"VyqgYwBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auI"
|
||||
b"JNm/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY"
|
||||
b"6RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvbw"
|
||||
b"7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3KSgxOmUzOgEAASkpKQ==}"
|
||||
)
|
||||
|
||||
privateRSA_lsh = (
|
||||
b"(11:private-key(9:rsa-pkcs1(1:n257:\x00\xd5j\xacx#\xd6\xd6\x1b\xec%\xa1P"
|
||||
b"\xc4wcP\x84E\x01UB\x14**\xe0\xd0`\xee\xd4\xe9\xa3\xadJ\xfa9\x06^\x84Uu_"
|
||||
b"\x006\xbfo\xaa*?\x83&7\xc1i.[\xfd\xf0\xf3\xd2}\xd6\x98\xcd:@x\xd5\xca"
|
||||
b"\xa8\x18\xc0\x11\x93$\t\x0c\x81L\x8f\xf7\x9c\xed\x13\x16j\xa4\x04\xe9Iw"
|
||||
b"\xc3\xe4Ud\xb3yh\x9e,\x08\xeb\xac\xe8\x04-!w\x05\xa7\x8e\xefS0\r\xa5\xe5"
|
||||
b"\xbb=j\xe2\t6o\xfd4\xd3}oF\xff\x87\xda\xa9)'\xaa\xff\xad\xf5\x85\xe6>"
|
||||
b"\x1a\xb8z\x1dJ\xb1\xea\xc0Z\xf70\xdf\x1f\xc2\xa4\xe4\xef?\x91I\x96@\xd5"
|
||||
b"\x19w-7\xc3^\xec\x9d\xa6:D\xa5\xc2\xa4)\xdd\xd5\xba\x9c=E\xb3\xc6,\x18d"
|
||||
b"\xd5\xba=\xdf\xab\x7f\xcdB\xac\xa7\xf1\x18\x0b\xa0X\x15b\x0b\xa4*nC\xc3"
|
||||
b"\xe4\x04\x9f5\xa3G\x8eF\xed3\xa5e\xbd\xbc;)n\x02\x0bW\xdft\xe8\x13\xb475"
|
||||
b"~\x83_ &`\xa6\xdc\xad\x8b\xc6ly\x98\xf7)(1:e3:\x01\x00\x01)(1:d256:!L"
|
||||
b"\x08f\xa2(\xd5\xb4\xfb\x8e\x0fr\x1b\x85\t\x00\xb9\xf2N7\xf0\x1cWK\xe3Q"
|
||||
b"\x7f\x9e#\xa7\xe4:\x98U\x1b\xea\x8bz\x98\x1e\xbc\xd8\xba\xb1\xf9\x89\x12"
|
||||
b"\x18`\xac\xe8\xcc\x0bN\tZ@j\xba/\x99\xf8\xb3$`\x84\xb9\xcei\x95\x9a\xf9"
|
||||
b"\xe2\xfc\x1fQM'\x15\xdb+'\xad\xef\xb4i\xac\xbe}\x10\xeb\x86Gps\xb4\x00"
|
||||
b"\x87\x95\x15;7\xf9\xe7\x14\xe7\x80\xbbh\x1e\x1b\xe6\xdd\xbbsc\xb9g\xe6"
|
||||
b"\xb2'\x7f\xcf\xcf0\x9b\xc2\x98\xfd\xd9\x186/6.\xf1=\x81z\x9f\xe1\x03-G"
|
||||
b"\xdb4Qb9\xddO\xe9\xac\xa8\x8b\xd9\xd6\xf3\x84\xc4\x17\xb9q\x9d\x06\x08Bx"
|
||||
b"M\xbb\xc5*\xf4\xc3X\xcdU+\xed\xbe3_\x04\xea{\xe6\x04$c\xf2-\xd7=\x1bl"
|
||||
b"\xd5\x9ccC/\x92\x88\x8d>n\xda\x187\xd8\x0f%g\x89\x1d\xb9F4^\xc9\xce\xc4"
|
||||
b"\x8b\xed\x92Z3\x07\x0f\xdf\x86\x08\xf9\x92\xe9\xdb\xeb8\x086\xc9\xcd\xcd"
|
||||
b"\n\x01H[9>z\xca\xc6\x80\xa9\xdc\xd49)(1:p129:\x00\xfbD\x17\x8b\xa46\xbe"
|
||||
b"\x1e7\x1d\xa7\xf6al\x04\xc4\xaa\xddx>\x07\x8c\x1e3\x02\xae\x03\x14\x87"
|
||||
b"\x83z\xe5\x9e}\x08g\xa8\xf2\xaa\xbf\x12p\xcfr\xa9\xa7\xc7\x0b\x1d\x88"
|
||||
b"\xd5 \xfd\x9cc\xcaG0UN\x8b\xc4\xcf\xf4\x7f\x16\xa4\x92\x12t\xa1\t\xc2"
|
||||
b"\xc4n\x9c\x8c3\xef\xa5\xe5\xf7\xe0+\xadO\\\x11\xaa\x1a\x847[\xfdz\xea"
|
||||
b"\xc3\xcd|\xb0\xc8\xe4\x1fTc\xb5\xc7\xaf\xdf\xf4\t\xa7\xfc\xc7%\xfc\\\xe9"
|
||||
b"\x91\xd7\x92\xc5\x98\x1eV\xd3\xb1#)(1:q129:\x00\xd9p\x06\xd8\xe2\xbc\xd4"
|
||||
b"x\x91P\x94\xd4\xc1\x1b\x898lFdZQ\xa0\x9a\x07=H\x8f\x03Q\xcck\x12\x8e}"
|
||||
b"\x1a\xb1e\xe7qu9\xe02\x05u\x8d\x18L\xaf\x93\xb1I\xb1f_xbz\xd1\x0c\xca"
|
||||
b"\xe6MC\xb3\x9c\xf4k}\xe6\x0c\x98\xdc\xcf!b\x8e\xd5.\x12\xde\x04\xae\xd7$"
|
||||
b'n\x831\xa2\x15\xa2D="\xa9b&"\xb9\xb2\xedT\n\x9d\x08\x83\xa7\x07\r\xff'
|
||||
b"\x19\x18\x8e\xd8\xab\x1d\xdaH\x9c1h\x11\xa1fm\xe3\xd8\x1d)(1:a128:if7"
|
||||
b"\xc6@\xdd!\xc5\x04\xf3\xb0\xb8>G\x94|v\xfc-\xeb?9<\x95\xc3C\x01Q\xc4B"
|
||||
b"\x97\xf3\xe8\x16\xa4\xc6\xee\xec\xd4I\x10P8\x04\xee;\xcd\xd7\xd0\xcc\xcc"
|
||||
b"2i\x90\x07\xa0\x1bZ\x9f\xfe1\xcd\x1e:~q\x1e\x19\x94\x1aNO\x0c\xdf_R\xd5"
|
||||
b"\xd1\x17n\xec\xd7\x9c\xb6U\x9d\xdb\x8e!_\xbc$\x88\xb6\xfc\xaf\xab\xf0"
|
||||
b"\xef\xa4,\xb0\xdc\x9f\x86\xb0\x03\x12\xb8\x8b\xe2\xdd\x0f\xc0\xee#=JP"
|
||||
b"\xfe/\xb8)FX\xbf\xb6')(1:b128:Q\xaf\xe9\x92\x9f\x94\x0bJ\x84e>\x94\xb3;"
|
||||
b"\x92\x10\xb5t\xb8\x8c\xc9\xef\xc9\x0e\x012\xfa/h\x12\xa1\x03&\xae\xcfQh"
|
||||
b"\x14L&\x9b(\xa4\x023\x08_\xe1\xa7p\x98\x014y^R\x8e\xc4\xcf6\xbc\x1fKU"
|
||||
b"\xac\xeb\xc1S\x84\xc7\xe1a\xa8J\xd4\xa2\xff@\r\x80\x1f\x12\xa9P\xc0*\x18"
|
||||
b"u\x94\x0c\x06\x9b\x16P\xa8K\xecA\xcd{\xef\xf7K\xc9u\x02h\xc4\x98\xb8\x86"
|
||||
b'\x88\x18ZC\xe7\x023\x97"d\x93\x83\x0cE*|\xed)(1:c128:f\x16\xf9 4\xd0T%'
|
||||
b"\xbca1\xac\x82\xfb\xef\x9c5\x1e~JU\x02h\x95\r\n\x93\xbe\x1e\xbf\xe1@\xfa"
|
||||
b'\x90\xa7Tp1\xe9x\xfc\xe0f\xb7"w\x9a\xeb\xdd\xd5\xd20F\xca\xe2\xd7^A\x9b'
|
||||
b"\xcf\xb5H5Q\xaa%\xdc\xde\xdb4)H!\xdb\xd6t\xfc~\xe5/S\xf7\x9c\tp\xb9\xe4"
|
||||
b"\xa0v\xa6\xadzt\"4\x9cO\x17\xcb?\xe0\xaa\xe5^\xa5'\xde?\xa5\x7f\x0f\xa6"
|
||||
b"\x88\xf5\x15\xd6_7\x17\x92\xe0\xd6\x05I\x10;\x92\xf6)))"
|
||||
)
|
||||
|
||||
privateRSA_agentv3 = (
|
||||
b"\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x03\x01\x00\x01\x00\x00\x01\x00!L"
|
||||
b"\x08f\xa2(\xd5\xb4\xfb\x8e\x0fr\x1b\x85\t\x00\xb9\xf2N7\xf0\x1cWK\xe3Q"
|
||||
b"\x7f\x9e#\xa7\xe4:\x98U\x1b\xea\x8bz\x98\x1e\xbc\xd8\xba\xb1\xf9\x89\x12"
|
||||
b"\x18`\xac\xe8\xcc\x0bN\tZ@j\xba/\x99\xf8\xb3$`\x84\xb9\xcei\x95\x9a\xf9"
|
||||
b"\xe2\xfc\x1fQM'\x15\xdb+'\xad\xef\xb4i\xac\xbe}\x10\xeb\x86Gps\xb4\x00"
|
||||
b"\x87\x95\x15;7\xf9\xe7\x14\xe7\x80\xbbh\x1e\x1b\xe6\xdd\xbbsc\xb9g\xe6"
|
||||
b"\xb2'\x7f\xcf\xcf0\x9b\xc2\x98\xfd\xd9\x186/6.\xf1=\x81z\x9f\xe1\x03-G"
|
||||
b"\xdb4Qb9\xddO\xe9\xac\xa8\x8b\xd9\xd6\xf3\x84\xc4\x17\xb9q\x9d\x06\x08Bx"
|
||||
b"M\xbb\xc5*\xf4\xc3X\xcdU+\xed\xbe3_\x04\xea{\xe6\x04$c\xf2-\xd7=\x1bl"
|
||||
b"\xd5\x9ccC/\x92\x88\x8d>n\xda\x187\xd8\x0f%g\x89\x1d\xb9F4^\xc9\xce\xc4"
|
||||
b"\x8b\xed\x92Z3\x07\x0f\xdf\x86\x08\xf9\x92\xe9\xdb\xeb8\x086\xc9\xcd\xcd"
|
||||
b"\n\x01H[9>z\xca\xc6\x80\xa9\xdc\xd49\x00\x00\x01\x01\x00\xd5j\xacx#\xd6"
|
||||
b"\xd6\x1b\xec%\xa1P\xc4wcP\x84E\x01UB\x14**\xe0\xd0`\xee\xd4\xe9\xa3\xadJ"
|
||||
b"\xfa9\x06^\x84Uu_\x006\xbfo\xaa*?\x83&7\xc1i.[\xfd\xf0\xf3\xd2}\xd6\x98"
|
||||
b"\xcd:@x\xd5\xca\xa8\x18\xc0\x11\x93$\t\x0c\x81L\x8f\xf7\x9c\xed\x13\x16j"
|
||||
b"\xa4\x04\xe9Iw\xc3\xe4Ud\xb3yh\x9e,\x08\xeb\xac\xe8\x04-!w\x05\xa7\x8e"
|
||||
b"\xefS0\r\xa5\xe5\xbb=j\xe2\t6o\xfd4\xd3}oF\xff\x87\xda\xa9)'\xaa\xff\xad"
|
||||
b"\xf5\x85\xe6>\x1a\xb8z\x1dJ\xb1\xea\xc0Z\xf70\xdf\x1f\xc2\xa4\xe4\xef?"
|
||||
b"\x91I\x96@\xd5\x19w-7\xc3^\xec\x9d\xa6:D\xa5\xc2\xa4)\xdd\xd5\xba\x9c=E"
|
||||
b"\xb3\xc6,\x18d\xd5\xba=\xdf\xab\x7f\xcdB\xac\xa7\xf1\x18\x0b\xa0X\x15b"
|
||||
b"\x0b\xa4*nC\xc3\xe4\x04\x9f5\xa3G\x8eF\xed3\xa5e\xbd\xbc;)n\x02\x0bW\xdf"
|
||||
b"t\xe8\x13\xb475~\x83_ &`\xa6\xdc\xad\x8b\xc6ly\x98\xf7\x00\x00\x00\x81"
|
||||
b"\x00\x85K\x1bz\x9b\x12\x107\x9e\x1f\xad^\xda\xfe\xc6\x96\xfe\xdf5k\xb94"
|
||||
b"\xe2\x16\x97\x92&\t\xbd\xbdp \x03\xa75\xbd-\x1b\xa0\xd2\x07G+\xd4\xde"
|
||||
b"\xa8\xa8\x07\x07\x1b\xb8\x04 \xa7'A<l99\xe9A\xce\xe7\x17\x1d\xd1L\\\xbc="
|
||||
b"\xd2&&\xfej\xd6\xfdHr\xaeF\xfa{\xc3\xd3\x19`D\x1d\xa5\x13\xa7\x80\xf5c)"
|
||||
b"\xd4z]\x06\x07\x16]\xf6\x8b=\xcbd:\xe2\x84ZM\x8c\x06--\x9d\x1c\xeb\x83Lx"
|
||||
b"=yT\xce\x00\x00\x00\x81\x00\xd9p\x06\xd8\xe2\xbc\xd4x\x91P\x94\xd4\xc1"
|
||||
b"\x1b\x898lFdZQ\xa0\x9a\x07=H\x8f\x03Q\xcck\x12\x8e}\x1a\xb1e\xe7qu9\xe02"
|
||||
b"\x05u\x8d\x18L\xaf\x93\xb1I\xb1f_xbz\xd1\x0c\xca\xe6MC\xb3\x9c\xf4k}\xe6"
|
||||
b'\x0c\x98\xdc\xcf!b\x8e\xd5.\x12\xde\x04\xae\xd7$n\x831\xa2\x15\xa2D="'
|
||||
b'\xa9b&"\xb9\xb2\xedT\n\x9d\x08\x83\xa7\x07\r\xff\x19\x18\x8e\xd8\xab'
|
||||
b"\x1d\xdaH\x9c1h\x11\xa1fm\xe3\xd8\x1d\x00\x00\x00\x81\x00\xfbD\x17\x8b"
|
||||
b"\xa46\xbe\x1e7\x1d\xa7\xf6al\x04\xc4\xaa\xddx>\x07\x8c\x1e3\x02\xae\x03"
|
||||
b"\x14\x87\x83z\xe5\x9e}\x08g\xa8\xf2\xaa\xbf\x12p\xcfr\xa9\xa7\xc7\x0b"
|
||||
b"\x1d\x88\xd5 \xfd\x9cc\xcaG0UN\x8b\xc4\xcf\xf4\x7f\x16\xa4\x92\x12t\xa1"
|
||||
b"\t\xc2\xc4n\x9c\x8c3\xef\xa5\xe5\xf7\xe0+\xadO\\\x11\xaa\x1a\x847[\xfdz"
|
||||
b"\xea\xc3\xcd|\xb0\xc8\xe4\x1fTc\xb5\xc7\xaf\xdf\xf4\t\xa7\xfc\xc7%\xfc\\"
|
||||
b"\xe9\x91\xd7\x92\xc5\x98\x1eV\xd3\xb1#"
|
||||
)
|
||||
|
||||
publicDSA_openssh = b"""\
|
||||
ssh-dss AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9\
|
||||
LvFYmFFVMIuHFGlZpIL7sh3IMkqy+cssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+Gu48G\
|
||||
+yFuE8l0fVVUivos/MmYVJ66qT99htcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+Sx9a5AAAAFQD0\
|
||||
EYmTNaFJ8CS0+vFSF4nYcyEnSQAAAIEAkgLjxHJAE7qFWdTqf7EZngu7jAGmdB9k3YzMHe1ldMxEB\
|
||||
7zNw5aOnxjhoYLtiHeoEcOk2XOyvnE+VfhIWwWAdOiKRTEZlmizkvhGbq0DCe2EPMXirjqWACI5nD\
|
||||
ioQX1oEMonR8N3AEO5v9SfBqS2Q9R6OBr6lf04RvwpHZ0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQT\
|
||||
NEpklRZqeBGo1gotJggNmVaIQNIClGlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2G\
|
||||
gdgMQWC7S6WFIXePGGXqNQDdWxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8= \
|
||||
comment\
|
||||
"""
|
||||
|
||||
privateDSA_openssh = b"""\
|
||||
-----BEGIN DSA PRIVATE KEY-----
|
||||
MIIBvAIBAAKBgQCSkDrFREVQ0CKQDRx/iQAMpaPS7xWJhRVTCLhxRpWaSC+7IdyD
|
||||
JKsvnLLCDTP5Zxw935rAMi5VF2bbejw/S4GUWs7DzoKbbh/hruPBvshbhPJdH1VV
|
||||
Ir6LPzJmFSeuqk/fYbXGSmra01mZ6VVu4BQAaKsPoXti2dIJHlNPksfWuQIVAPQR
|
||||
iZM1oUnwJLT68VIXidhzISdJAoGBAJIC48RyQBO6hVnU6n+xGZ4Lu4wBpnQfZN2M
|
||||
zB3tZXTMRAe8zcOWjp8Y4aGC7Yh3qBHDpNlzsr5xPlX4SFsFgHToikUxGZZos5L4
|
||||
Rm6tAwnthDzF4q46lgAiOZw4qEF9aBDKJ0fDdwBDub/UnwaktkPUejga+pX9OEb8
|
||||
KR2dFBrvAoGAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQNICl
|
||||
GlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXeP
|
||||
GGXqNQDdWxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8CFQDV2gbL
|
||||
czUdxCus0pfEP1bddaXRLQ==
|
||||
-----END DSA PRIVATE KEY-----\
|
||||
"""
|
||||
|
||||
privateDSA_openssh_new = b"""\
|
||||
-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABsgAAAAdzc2gtZH
|
||||
NzAAAAgQCSkDrFREVQ0CKQDRx/iQAMpaPS7xWJhRVTCLhxRpWaSC+7IdyDJKsvnLLCDTP5
|
||||
Zxw935rAMi5VF2bbejw/S4GUWs7DzoKbbh/hruPBvshbhPJdH1VVIr6LPzJmFSeuqk/fYb
|
||||
XGSmra01mZ6VVu4BQAaKsPoXti2dIJHlNPksfWuQAAABUA9BGJkzWhSfAktPrxUheJ2HMh
|
||||
J0kAAACBAJIC48RyQBO6hVnU6n+xGZ4Lu4wBpnQfZN2MzB3tZXTMRAe8zcOWjp8Y4aGC7Y
|
||||
h3qBHDpNlzsr5xPlX4SFsFgHToikUxGZZos5L4Rm6tAwnthDzF4q46lgAiOZw4qEF9aBDK
|
||||
J0fDdwBDub/UnwaktkPUejga+pX9OEb8KR2dFBrvAAAAgAIUacRjCFhMmhIfGJ44ms0EzR
|
||||
KZJUWangRqNYKLSYIDZlWiEDSApRpS8got+fXnxFLkHGfUl8TOfT/oXnHPxlPxh2pFuWFh
|
||||
OHT9hoHYDEFgu0ulhSF3jxhl6jUA3VsZV/LpoXp70KmtT5yqxUYQ6ycPGexo3R8X5bMQhJ
|
||||
lz6CzfAAAB2MVcBjzFXAY8AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9Lv
|
||||
FYmFFVMIuHFGlZpIL7sh3IMkqy+cssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+
|
||||
Gu48G+yFuE8l0fVVUivos/MmYVJ66qT99htcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+S
|
||||
x9a5AAAAFQD0EYmTNaFJ8CS0+vFSF4nYcyEnSQAAAIEAkgLjxHJAE7qFWdTqf7EZngu7jA
|
||||
GmdB9k3YzMHe1ldMxEB7zNw5aOnxjhoYLtiHeoEcOk2XOyvnE+VfhIWwWAdOiKRTEZlmiz
|
||||
kvhGbq0DCe2EPMXirjqWACI5nDioQX1oEMonR8N3AEO5v9SfBqS2Q9R6OBr6lf04RvwpHZ
|
||||
0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQNIClGlLyCi35
|
||||
9efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXePGGXqNQDdWxlX8u
|
||||
mhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8AAAAVANXaBstzNR3EK6zSl8Q/Vt11
|
||||
pdEtAAAAAAE=
|
||||
-----END OPENSSH PRIVATE KEY-----
|
||||
"""
|
||||
|
||||
publicDSA_lsh = decodebytes(
|
||||
b"""\
|
||||
e0tERXdPbkIxWW14cFl5MXJaWGtvTXpwa2MyRW9NVHB3TVRJNU9nQ1NrRHJGUkVWUTBDS1FEUngv
|
||||
aVFBTXBhUFM3eFdKaFJWVENMaHhScFdhU0MrN0lkeURKS3N2bkxMQ0RUUDVaeHc5MzVyQU1pNVZG
|
||||
MmJiZWp3L1M0R1VXczdEem9LYmJoL2hydVBCdnNoYmhQSmRIMVZWSXI2TFB6Sm1GU2V1cWsvZlli
|
||||
WEdTbXJhMDFtWjZWVnU0QlFBYUtzUG9YdGkyZElKSGxOUGtzZld1U2tvTVRweE1qRTZBUFFSaVpN
|
||||
MW9VbndKTFQ2OFZJWGlkaHpJU2RKS1NneE9tY3hNams2QUpJQzQ4UnlRQk82aFZuVTZuK3hHWjRM
|
||||
dTR3QnBuUWZaTjJNekIzdFpYVE1SQWU4emNPV2pwOFk0YUdDN1loM3FCSERwTmx6c3I1eFBsWDRT
|
||||
RnNGZ0hUb2lrVXhHWlpvczVMNFJtNnRBd250aER6RjRxNDZsZ0FpT1p3NHFFRjlhQkRLSjBmRGR3
|
||||
QkR1Yi9Vbndha3RrUFVlamdhK3BYOU9FYjhLUjJkRkJydktTZ3hPbmt4TWpnNkFoUnB4R01JV0V5
|
||||
YUVoOFluamlhelFUTkVwa2xSWnFlQkdvMWdvdEpnZ05tVmFJUU5JQ2xHbEx5Q2kzNTllZkVVdVFj
|
||||
WjlTWHhNNTlQK2hlY2MvR1UvR0hha1c1WVdFNGRQMkdnZGdNUVdDN1M2V0ZJWGVQR0dYcU5RRGRX
|
||||
eGxYOHVtaGVudlFxYTFQbktyRlJoRHJKdzhaN0dqZEh4ZmxzeENFbVhQb0xOOHBLU2s9fQ==
|
||||
"""
|
||||
)
|
||||
|
||||
privateDSA_lsh = decodebytes(
|
||||
b"""\
|
||||
KDExOnByaXZhdGUta2V5KDM6ZHNhKDE6cDEyOToAkpA6xURFUNAikA0cf4kADKWj0u8ViYUVUwi4
|
||||
cUaVmkgvuyHcgySrL5yywg0z+WccPd+awDIuVRdm23o8P0uBlFrOw86Cm24f4a7jwb7IW4TyXR9V
|
||||
VSK+iz8yZhUnrqpP32G1xkpq2tNZmelVbuAUAGirD6F7YtnSCR5TT5LH1rkpKDE6cTIxOgD0EYmT
|
||||
NaFJ8CS0+vFSF4nYcyEnSSkoMTpnMTI5OgCSAuPEckATuoVZ1Op/sRmeC7uMAaZ0H2TdjMwd7WV0
|
||||
zEQHvM3Dlo6fGOGhgu2Id6gRw6TZc7K+cT5V+EhbBYB06IpFMRmWaLOS+EZurQMJ7YQ8xeKuOpYA
|
||||
IjmcOKhBfWgQyidHw3cAQ7m/1J8GpLZD1Ho4GvqV/ThG/CkdnRQa7ykoMTp5MTI4OgIUacRjCFhM
|
||||
mhIfGJ44ms0EzRKZJUWangRqNYKLSYIDZlWiEDSApRpS8got+fXnxFLkHGfUl8TOfT/oXnHPxlPx
|
||||
h2pFuWFhOHT9hoHYDEFgu0ulhSF3jxhl6jUA3VsZV/LpoXp70KmtT5yqxUYQ6ycPGexo3R8X5bMQ
|
||||
hJlz6CzfKSgxOngyMToA1doGy3M1HcQrrNKXxD9W3XWl0S0pKSk=
|
||||
"""
|
||||
)
|
||||
|
||||
privateDSA_agentv3 = decodebytes(
|
||||
b"""\
|
||||
AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9LvFYmFFVMIuHFGlZpIL7sh3IMkqy+c
|
||||
ssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+Gu48G+yFuE8l0fVVUivos/MmYVJ66qT99h
|
||||
tcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+Sx9a5AAAAFQD0EYmTNaFJ8CS0+vFSF4nYcyEnSQAA
|
||||
AIEAkgLjxHJAE7qFWdTqf7EZngu7jAGmdB9k3YzMHe1ldMxEB7zNw5aOnxjhoYLtiHeoEcOk2XOy
|
||||
vnE+VfhIWwWAdOiKRTEZlmizkvhGbq0DCe2EPMXirjqWACI5nDioQX1oEMonR8N3AEO5v9SfBqS2
|
||||
Q9R6OBr6lf04RvwpHZ0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQ
|
||||
NIClGlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXePGGXqNQDd
|
||||
WxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8AAAAVANXaBstzNR3EK6zSl8Q/Vt11
|
||||
pdEt
|
||||
"""
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DSAData",
|
||||
"RSAData",
|
||||
"privateDSA_agentv3",
|
||||
"privateDSA_lsh",
|
||||
"privateDSA_openssh",
|
||||
"privateRSA_agentv3",
|
||||
"privateRSA_lsh",
|
||||
"privateRSA_openssh",
|
||||
"publicDSA_lsh",
|
||||
"publicDSA_openssh",
|
||||
"publicRSA_lsh",
|
||||
"publicRSA_openssh",
|
||||
]
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
"""
|
||||
Loopback helper used in test_ssh and test_recvline
|
||||
"""
|
||||
|
||||
|
||||
from twisted.protocols import loopback
|
||||
|
||||
|
||||
class LoopbackRelay(loopback.LoopbackRelay):
|
||||
clearCall = None
|
||||
|
||||
def logPrefix(self):
|
||||
return f"LoopbackRelay({self.target.__class__.__name__!r})"
|
||||
|
||||
def write(self, data):
|
||||
loopback.LoopbackRelay.write(self, data)
|
||||
if self.clearCall is not None:
|
||||
self.clearCall.cancel()
|
||||
|
||||
from twisted.internet import reactor
|
||||
|
||||
self.clearCall = reactor.callLater(0, self._clearBuffer)
|
||||
|
||||
def _clearBuffer(self):
|
||||
self.clearCall = None
|
||||
loopback.LoopbackRelay.clearBuffer(self)
|
||||
@@ -0,0 +1,45 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{SSHTransportAddrress} in ssh/address.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from twisted.conch.ssh.address import SSHTransportAddress
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.test.test_address import AddressTestCaseMixin
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class SSHTransportAddressTests(unittest.TestCase, AddressTestCaseMixin):
|
||||
"""
|
||||
L{twisted.conch.ssh.address.SSHTransportAddress} is what Conch transports
|
||||
use to represent the other side of the SSH connection. This tests the
|
||||
basic functionality of that class (string representation, comparison, &c).
|
||||
"""
|
||||
|
||||
def _stringRepresentation(self, stringFunction: Callable[[object], str]) -> None:
|
||||
"""
|
||||
The string representation of C{SSHTransportAddress} should be
|
||||
"SSHTransportAddress(<stringFunction on address>)".
|
||||
"""
|
||||
addr = self.buildAddress()
|
||||
stringValue = stringFunction(addr)
|
||||
addressValue = stringFunction(addr.address)
|
||||
self.assertEqual(stringValue, "SSHTransportAddress(%s)" % addressValue)
|
||||
|
||||
def buildAddress(self) -> SSHTransportAddress:
|
||||
"""
|
||||
Create an arbitrary new C{SSHTransportAddress}. A new instance is
|
||||
created for each call, but always for the same address.
|
||||
"""
|
||||
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.1", 22))
|
||||
|
||||
def buildDifferentAddress(self) -> SSHTransportAddress:
|
||||
"""
|
||||
Like C{buildAddress}, but with a different fixed address.
|
||||
"""
|
||||
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.2", 22))
|
||||
@@ -0,0 +1,398 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.ssh.agent}.
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.test import iosim
|
||||
from twisted.trial import unittest
|
||||
|
||||
try:
|
||||
import cryptography as _cryptography
|
||||
except ImportError:
|
||||
cryptography = None
|
||||
else:
|
||||
cryptography = _cryptography
|
||||
|
||||
try:
|
||||
from twisted.conch.ssh import agent as _agent, keys as _keys
|
||||
except ImportError:
|
||||
keys = agent = None
|
||||
else:
|
||||
keys, agent = _keys, _agent
|
||||
|
||||
from twisted.conch.error import ConchError, MissingKeyStoreError
|
||||
from twisted.conch.test import keydata
|
||||
|
||||
|
||||
class StubFactory:
|
||||
"""
|
||||
Mock factory that provides the keys attribute required by the
|
||||
SSHAgentServerProtocol
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.keys = {}
|
||||
|
||||
|
||||
class AgentTestBase(unittest.TestCase):
|
||||
"""
|
||||
Tests for SSHAgentServer/Client.
|
||||
"""
|
||||
|
||||
if agent is None or keys is None:
|
||||
skip = "Cannot run without cryptography"
|
||||
|
||||
def setUp(self):
|
||||
# wire up our client <-> server
|
||||
self.client, self.server, self.pump = iosim.connectedServerAndClient(
|
||||
agent.SSHAgentServer, agent.SSHAgentClient
|
||||
)
|
||||
|
||||
# the server's end of the protocol is stateful and we store it on the
|
||||
# factory, for which we only need a mock
|
||||
self.server.factory = StubFactory()
|
||||
|
||||
# pub/priv keys of each kind
|
||||
self.rsaPrivate = keys.Key.fromString(keydata.privateRSA_openssh)
|
||||
self.dsaPrivate = keys.Key.fromString(keydata.privateDSA_openssh)
|
||||
|
||||
self.rsaPublic = keys.Key.fromString(keydata.publicRSA_openssh)
|
||||
self.dsaPublic = keys.Key.fromString(keydata.publicDSA_openssh)
|
||||
|
||||
|
||||
class ServerProtocolContractWithFactoryTests(AgentTestBase):
|
||||
"""
|
||||
The server protocol is stateful and so uses its factory to track state
|
||||
across requests. This test asserts that the protocol raises if its factory
|
||||
doesn't provide the necessary storage for that state.
|
||||
"""
|
||||
|
||||
def test_factorySuppliesKeyStorageForServerProtocol(self):
|
||||
# need a message to send into the server
|
||||
msg = struct.pack("!LB", 1, agent.AGENTC_REQUEST_IDENTITIES)
|
||||
del self.server.factory.__dict__["keys"]
|
||||
self.assertRaises(MissingKeyStoreError, self.server.dataReceived, msg)
|
||||
|
||||
|
||||
class UnimplementedVersionOneServerTests(AgentTestBase):
|
||||
"""
|
||||
Tests for methods with no-op implementations on the server. We need these
|
||||
for clients, such as openssh, that try v1 methods before going to v2.
|
||||
|
||||
Because the client doesn't expose these operations with nice method names,
|
||||
we invoke sendRequest directly with an op code.
|
||||
"""
|
||||
|
||||
def test_agentc_REQUEST_RSA_IDENTITIES(self):
|
||||
"""
|
||||
assert that we get the correct op code for an RSA identities request
|
||||
"""
|
||||
d = self.client.sendRequest(agent.AGENTC_REQUEST_RSA_IDENTITIES, b"")
|
||||
self.pump.flush()
|
||||
|
||||
def _cb(packet):
|
||||
self.assertEqual(agent.AGENT_RSA_IDENTITIES_ANSWER, ord(packet[0:1]))
|
||||
|
||||
return d.addCallback(_cb)
|
||||
|
||||
def test_agentc_REMOVE_RSA_IDENTITY(self):
|
||||
"""
|
||||
assert that we get the correct op code for an RSA remove identity request
|
||||
"""
|
||||
d = self.client.sendRequest(agent.AGENTC_REMOVE_RSA_IDENTITY, b"")
|
||||
self.pump.flush()
|
||||
return d.addCallback(self.assertEqual, b"")
|
||||
|
||||
def test_agentc_REMOVE_ALL_RSA_IDENTITIES(self):
|
||||
"""
|
||||
assert that we get the correct op code for an RSA remove all identities
|
||||
request.
|
||||
"""
|
||||
d = self.client.sendRequest(agent.AGENTC_REMOVE_ALL_RSA_IDENTITIES, b"")
|
||||
self.pump.flush()
|
||||
return d.addCallback(self.assertEqual, b"")
|
||||
|
||||
|
||||
if agent is not None:
|
||||
|
||||
class CorruptServer(agent.SSHAgentServer): # type: ignore[name-defined]
|
||||
"""
|
||||
A misbehaving server that returns bogus response op codes so that we can
|
||||
verify that our callbacks that deal with these op codes handle such
|
||||
miscreants.
|
||||
"""
|
||||
|
||||
def agentc_REQUEST_IDENTITIES(self, data):
|
||||
self.sendResponse(254, b"")
|
||||
|
||||
def agentc_SIGN_REQUEST(self, data):
|
||||
self.sendResponse(254, b"")
|
||||
|
||||
|
||||
class ClientWithBrokenServerTests(AgentTestBase):
|
||||
"""
|
||||
verify error handling code in the client using a misbehaving server
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
AgentTestBase.setUp(self)
|
||||
self.client, self.server, self.pump = iosim.connectedServerAndClient(
|
||||
CorruptServer, agent.SSHAgentClient
|
||||
)
|
||||
# the server's end of the protocol is stateful and we store it on the
|
||||
# factory, for which we only need a mock
|
||||
self.server.factory = StubFactory()
|
||||
|
||||
def test_signDataCallbackErrorHandling(self):
|
||||
"""
|
||||
Assert that L{SSHAgentClient.signData} raises a ConchError
|
||||
if we get a response from the server whose opcode doesn't match
|
||||
the protocol for data signing requests.
|
||||
"""
|
||||
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
def test_requestIdentitiesCallbackErrorHandling(self):
|
||||
"""
|
||||
Assert that L{SSHAgentClient.requestIdentities} raises a ConchError
|
||||
if we get a response from the server whose opcode doesn't match
|
||||
the protocol for identity requests.
|
||||
"""
|
||||
d = self.client.requestIdentities()
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
|
||||
class AgentKeyAdditionTests(AgentTestBase):
|
||||
"""
|
||||
Test adding different flavors of keys to an agent.
|
||||
"""
|
||||
|
||||
def test_addRSAIdentityNoComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that omitting the comment produces an
|
||||
empty string for the comment on the server.
|
||||
"""
|
||||
d = self.client.addIdentity(self.rsaPrivate.privateBlob())
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
|
||||
self.assertEqual(self.rsaPrivate, serverKey[0])
|
||||
self.assertEqual(b"", serverKey[1])
|
||||
|
||||
return d.addCallback(_check)
|
||||
|
||||
def test_addDSAIdentityNoComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that omitting the comment produces an
|
||||
empty string for the comment on the server.
|
||||
"""
|
||||
d = self.client.addIdentity(self.dsaPrivate.privateBlob())
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
|
||||
self.assertEqual(self.dsaPrivate, serverKey[0])
|
||||
self.assertEqual(b"", serverKey[1])
|
||||
|
||||
return d.addCallback(_check)
|
||||
|
||||
def test_addRSAIdentityWithComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that the server receives/stores the comment
|
||||
as sent by the client.
|
||||
"""
|
||||
d = self.client.addIdentity(
|
||||
self.rsaPrivate.privateBlob(), comment=b"My special key"
|
||||
)
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
|
||||
self.assertEqual(self.rsaPrivate, serverKey[0])
|
||||
self.assertEqual(b"My special key", serverKey[1])
|
||||
|
||||
return d.addCallback(_check)
|
||||
|
||||
def test_addDSAIdentityWithComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that the server receives/stores the comment
|
||||
as sent by the client.
|
||||
"""
|
||||
d = self.client.addIdentity(
|
||||
self.dsaPrivate.privateBlob(), comment=b"My special key"
|
||||
)
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
|
||||
self.assertEqual(self.dsaPrivate, serverKey[0])
|
||||
self.assertEqual(b"My special key", serverKey[1])
|
||||
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
class AgentClientFailureTests(AgentTestBase):
|
||||
def test_agentFailure(self):
|
||||
"""
|
||||
verify that the client raises ConchError on AGENT_FAILURE
|
||||
"""
|
||||
d = self.client.sendRequest(254, b"")
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
|
||||
class AgentIdentityRequestsTests(AgentTestBase):
|
||||
"""
|
||||
Test operations against a server with identities already loaded.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
AgentTestBase.setUp(self)
|
||||
self.server.factory.keys[self.dsaPrivate.blob()] = (
|
||||
self.dsaPrivate,
|
||||
b"a comment",
|
||||
)
|
||||
self.server.factory.keys[self.rsaPrivate.blob()] = (
|
||||
self.rsaPrivate,
|
||||
b"another comment",
|
||||
)
|
||||
|
||||
def test_signDataRSA(self):
|
||||
"""
|
||||
Sign data with an RSA private key and then verify it with the public
|
||||
key.
|
||||
"""
|
||||
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
|
||||
self.pump.flush()
|
||||
signature = self.successResultOf(d)
|
||||
|
||||
expected = self.rsaPrivate.sign(b"John Hancock")
|
||||
self.assertEqual(expected, signature)
|
||||
self.assertTrue(self.rsaPublic.verify(signature, b"John Hancock"))
|
||||
|
||||
def test_signDataDSA(self):
|
||||
"""
|
||||
Sign data with a DSA private key and then verify it with the public
|
||||
key.
|
||||
"""
|
||||
d = self.client.signData(self.dsaPublic.blob(), b"John Hancock")
|
||||
self.pump.flush()
|
||||
|
||||
def _check(sig):
|
||||
# Cannot do this b/c DSA uses random numbers when signing
|
||||
# expected = self.dsaPrivate.sign("John Hancock")
|
||||
# self.assertEqual(expected, sig)
|
||||
self.assertTrue(self.dsaPublic.verify(sig, b"John Hancock"))
|
||||
|
||||
return d.addCallback(_check)
|
||||
|
||||
def test_signDataRSAErrbackOnUnknownBlob(self):
|
||||
"""
|
||||
Assert that we get an errback if we try to sign data using a key that
|
||||
wasn't added.
|
||||
"""
|
||||
del self.server.factory.keys[self.rsaPublic.blob()]
|
||||
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
def test_requestIdentities(self):
|
||||
"""
|
||||
Assert that we get all of the keys/comments that we add when we issue a
|
||||
request for all identities.
|
||||
"""
|
||||
d = self.client.requestIdentities()
|
||||
self.pump.flush()
|
||||
|
||||
def _check(keyt):
|
||||
expected = {}
|
||||
expected[self.dsaPublic.blob()] = b"a comment"
|
||||
expected[self.rsaPublic.blob()] = b"another comment"
|
||||
|
||||
received = {}
|
||||
for k in keyt:
|
||||
received[keys.Key.fromString(k[0], type="blob").blob()] = k[1]
|
||||
self.assertEqual(expected, received)
|
||||
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
class AgentKeyRemovalTests(AgentTestBase):
|
||||
"""
|
||||
Test support for removing keys in a remote server.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
AgentTestBase.setUp(self)
|
||||
self.server.factory.keys[self.dsaPrivate.blob()] = (
|
||||
self.dsaPrivate,
|
||||
b"a comment",
|
||||
)
|
||||
self.server.factory.keys[self.rsaPrivate.blob()] = (
|
||||
self.rsaPrivate,
|
||||
b"another comment",
|
||||
)
|
||||
|
||||
def test_removeRSAIdentity(self):
|
||||
"""
|
||||
Assert that we can remove an RSA identity.
|
||||
"""
|
||||
# only need public key for this
|
||||
d = self.client.removeIdentity(self.rsaPrivate.blob())
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
self.assertEqual(1, len(self.server.factory.keys))
|
||||
self.assertIn(self.dsaPrivate.blob(), self.server.factory.keys)
|
||||
self.assertNotIn(self.rsaPrivate.blob(), self.server.factory.keys)
|
||||
|
||||
return d.addCallback(_check)
|
||||
|
||||
def test_removeDSAIdentity(self):
|
||||
"""
|
||||
Assert that we can remove a DSA identity.
|
||||
"""
|
||||
# only need public key for this
|
||||
d = self.client.removeIdentity(self.dsaPrivate.blob())
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
self.assertEqual(1, len(self.server.factory.keys))
|
||||
self.assertIn(self.rsaPrivate.blob(), self.server.factory.keys)
|
||||
|
||||
return d.addCallback(_check)
|
||||
|
||||
def test_removeAllIdentities(self):
|
||||
"""
|
||||
Assert that we can remove all identities.
|
||||
"""
|
||||
d = self.client.removeAllIdentities()
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
self.assertEqual(0, len(self.server.factory.keys))
|
||||
|
||||
return d.addCallback(_check)
|
||||
1510
.venv/lib/python3.12/site-packages/twisted/conch/test/test_cftp.py
Normal file
1510
.venv/lib/python3.12/site-packages/twisted/conch/test/test_cftp.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,358 @@
|
||||
# Copyright Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test ssh/channel.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest import skipIf
|
||||
|
||||
from zope.interface.verify import verifyObject
|
||||
|
||||
try:
|
||||
from twisted.conch.ssh import channel
|
||||
from twisted.conch.ssh.address import SSHTransportAddress
|
||||
from twisted.conch.ssh.service import SSHService
|
||||
from twisted.conch.ssh.transport import SSHServerTransport
|
||||
from twisted.internet import interfaces
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.testing import StringTransport
|
||||
|
||||
skipTest = ""
|
||||
except ImportError:
|
||||
skipTest = "Conch SSH not supported."
|
||||
SSHService = object # type: ignore[assignment,misc]
|
||||
from twisted.trial.unittest import TestCase
|
||||
|
||||
|
||||
class MockConnection(SSHService):
|
||||
"""
|
||||
A mock for twisted.conch.ssh.connection.SSHConnection. Record the data
|
||||
that channels send, and when they try to close the connection.
|
||||
|
||||
@ivar data: a L{dict} mapping channel id #s to lists of data sent by that
|
||||
channel.
|
||||
@ivar extData: a L{dict} mapping channel id #s to lists of 2-tuples
|
||||
(extended data type, data) sent by that channel.
|
||||
@ivar closes: a L{dict} mapping channel id #s to True if that channel sent
|
||||
a close message.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data: dict[channel.SSHChannel, list[bytes]] = {}
|
||||
self.extData: dict[channel.SSHChannel, list[tuple[int, bytes]]] = {}
|
||||
self.closes: dict[channel.SSHChannel, bool] = {}
|
||||
|
||||
def logPrefix(self) -> str:
|
||||
"""
|
||||
Return our logging prefix.
|
||||
"""
|
||||
return "MockConnection"
|
||||
|
||||
def sendData(self, channel: channel.SSHChannel, data: bytes) -> None:
|
||||
"""
|
||||
Record the sent data.
|
||||
"""
|
||||
self.data.setdefault(channel, []).append(data)
|
||||
|
||||
def sendExtendedData(
|
||||
self, channel: channel.SSHChannel, type: int, data: bytes
|
||||
) -> None:
|
||||
"""
|
||||
Record the sent extended data.
|
||||
"""
|
||||
self.extData.setdefault(channel, []).append((type, data))
|
||||
|
||||
def sendClose(self, channel: channel.SSHChannel) -> None:
|
||||
"""
|
||||
Record that the channel sent a close message.
|
||||
"""
|
||||
self.closes[channel] = True
|
||||
|
||||
|
||||
def connectSSHTransport(
|
||||
service: SSHService,
|
||||
hostAddress: interfaces.IAddress | None = None,
|
||||
peerAddress: interfaces.IAddress | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Connect a SSHTransport which is already connected to a remote peer to
|
||||
the channel under test.
|
||||
|
||||
@param service: Service used over the connected transport.
|
||||
@type service: L{SSHService}
|
||||
|
||||
@param hostAddress: Local address of the connected transport.
|
||||
@type hostAddress: L{interfaces.IAddress}
|
||||
|
||||
@param peerAddress: Remote address of the connected transport.
|
||||
@type peerAddress: L{interfaces.IAddress}
|
||||
"""
|
||||
transport = SSHServerTransport()
|
||||
transport.makeConnection(
|
||||
StringTransport(hostAddress=hostAddress, peerAddress=peerAddress)
|
||||
)
|
||||
transport.setService(service)
|
||||
|
||||
|
||||
@skipIf(skipTest, skipTest)
|
||||
class ChannelTests(TestCase):
|
||||
"""
|
||||
Tests for L{SSHChannel}.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Initialize the channel. remoteMaxPacket is 10 so that data is able
|
||||
to be sent (the default of 0 means no data is sent because no packets
|
||||
are made).
|
||||
"""
|
||||
self.conn = MockConnection()
|
||||
self.channel = channel.SSHChannel(conn=self.conn, remoteMaxPacket=10)
|
||||
self.channel.name = b"channel"
|
||||
|
||||
def test_interface(self) -> None:
|
||||
"""
|
||||
L{SSHChannel} instances provide L{interfaces.ITransport}.
|
||||
"""
|
||||
self.assertTrue(verifyObject(interfaces.ITransport, self.channel))
|
||||
|
||||
def test_init(self) -> None:
|
||||
"""
|
||||
Test that SSHChannel initializes correctly. localWindowSize defaults
|
||||
to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
|
||||
defaults (what OpenSSH uses for those variables).
|
||||
|
||||
The values in the second set of assertions are meaningless; they serve
|
||||
only to verify that the instance variables are assigned in the correct
|
||||
order.
|
||||
"""
|
||||
c = channel.SSHChannel(conn=self.conn)
|
||||
self.assertEqual(c.localWindowSize, 131072)
|
||||
self.assertEqual(c.localWindowLeft, 131072)
|
||||
self.assertEqual(c.localMaxPacket, 32768)
|
||||
self.assertEqual(c.remoteWindowLeft, 0)
|
||||
self.assertEqual(c.remoteMaxPacket, 0)
|
||||
self.assertEqual(c.conn, self.conn)
|
||||
self.assertIsNone(c.data)
|
||||
self.assertIsNone(c.avatar)
|
||||
|
||||
c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
|
||||
self.assertEqual(c2.localWindowSize, 1)
|
||||
self.assertEqual(c2.localWindowLeft, 1)
|
||||
self.assertEqual(c2.localMaxPacket, 2)
|
||||
self.assertEqual(c2.remoteWindowLeft, 3)
|
||||
self.assertEqual(c2.remoteMaxPacket, 4)
|
||||
self.assertEqual(c2.conn, 5)
|
||||
self.assertEqual(c2.data, 6)
|
||||
self.assertEqual(c2.avatar, 7)
|
||||
|
||||
def test_str(self) -> None:
|
||||
"""
|
||||
Test that str(SSHChannel) works gives the channel name and local and
|
||||
remote windows at a glance..
|
||||
"""
|
||||
self.assertEqual(str(self.channel), "<SSHChannel channel (lw 131072 rw 0)>")
|
||||
self.assertEqual(
|
||||
str(channel.SSHChannel(localWindow=1)), "<SSHChannel None (lw 1 rw 0)>"
|
||||
)
|
||||
|
||||
def test_bytes(self) -> None:
|
||||
"""
|
||||
Test that bytes(SSHChannel) works, gives the channel name and
|
||||
local and remote windows at a glance..
|
||||
|
||||
"""
|
||||
self.assertEqual(
|
||||
self.channel.__bytes__(), b"<SSHChannel channel (lw 131072 rw 0)>"
|
||||
)
|
||||
self.assertEqual(
|
||||
channel.SSHChannel(localWindow=1).__bytes__(),
|
||||
b"<SSHChannel None (lw 1 rw 0)>",
|
||||
)
|
||||
|
||||
def test_logPrefix(self) -> None:
|
||||
"""
|
||||
Test that SSHChannel.logPrefix gives the name of the channel, the
|
||||
local channel ID and the underlying connection.
|
||||
"""
|
||||
self.assertEqual(
|
||||
self.channel.logPrefix(), "SSHChannel channel (unknown) on MockConnection"
|
||||
)
|
||||
|
||||
def test_addWindowBytes(self) -> None:
|
||||
"""
|
||||
Test that addWindowBytes adds bytes to the window and resumes writing
|
||||
if it was paused.
|
||||
"""
|
||||
cb = [False]
|
||||
|
||||
def stubStartWriting() -> None:
|
||||
cb[0] = True
|
||||
|
||||
self.channel.startWriting = stubStartWriting # type: ignore[method-assign]
|
||||
self.channel.write(b"test")
|
||||
self.channel.writeExtended(1, b"test")
|
||||
self.channel.addWindowBytes(50)
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
|
||||
self.assertTrue(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
self.assertEqual(self.channel.buf, b"")
|
||||
self.assertEqual(self.conn.data[self.channel], [b"test"])
|
||||
self.assertEqual(self.channel.extBuf, [])
|
||||
self.assertEqual(self.conn.extData[self.channel], [(1, b"test")])
|
||||
|
||||
cb[0] = False
|
||||
self.channel.addWindowBytes(20)
|
||||
self.assertFalse(cb[0])
|
||||
|
||||
self.channel.write(b"a" * 80)
|
||||
self.channel.loseConnection()
|
||||
self.channel.addWindowBytes(20)
|
||||
self.assertFalse(cb[0])
|
||||
|
||||
def test_requestReceived(self) -> None:
|
||||
"""
|
||||
Test that requestReceived handles requests by dispatching them to
|
||||
request_* methods.
|
||||
"""
|
||||
self.channel.request_test_method = lambda data: data == b"" # type: ignore[attr-defined]
|
||||
self.assertTrue(self.channel.requestReceived(b"test-method", b""))
|
||||
self.assertFalse(self.channel.requestReceived(b"test-method", b"a"))
|
||||
self.assertFalse(self.channel.requestReceived(b"bad-method", b""))
|
||||
|
||||
def test_closeReceieved(self) -> None:
|
||||
"""
|
||||
Test that the default closeReceieved closes the connection.
|
||||
"""
|
||||
self.assertFalse(self.channel.closing)
|
||||
self.channel.closeReceived()
|
||||
self.assertTrue(self.channel.closing)
|
||||
|
||||
def test_write(self) -> None:
|
||||
"""
|
||||
Test that write handles data correctly. Send data up to the size
|
||||
of the remote window, splitting the data into packets of length
|
||||
remoteMaxPacket.
|
||||
"""
|
||||
cb = [False]
|
||||
|
||||
def stubStopWriting() -> None:
|
||||
cb[0] = True
|
||||
|
||||
# no window to start with
|
||||
self.channel.stopWriting = stubStopWriting # type: ignore[method-assign]
|
||||
self.channel.write(b"d")
|
||||
self.channel.write(b"a")
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
# regular write
|
||||
self.channel.addWindowBytes(20)
|
||||
self.channel.write(b"ta")
|
||||
data = self.conn.data[self.channel]
|
||||
self.assertEqual(data, [b"da", b"ta"])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 16)
|
||||
# larger than max packet
|
||||
self.channel.write(b"12345678901")
|
||||
self.assertEqual(data, [b"da", b"ta", b"1234567890", b"1"])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 5)
|
||||
# running out of window
|
||||
cb[0] = False
|
||||
self.channel.write(b"123456")
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
self.assertEqual(data, [b"da", b"ta", b"1234567890", b"1", b"12345"])
|
||||
self.assertEqual(self.channel.buf, b"6")
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 0)
|
||||
|
||||
def test_writeExtended(self) -> None:
|
||||
"""
|
||||
Test that writeExtended handles data correctly. Send extended data
|
||||
up to the size of the window, splitting the extended data into packets
|
||||
of length remoteMaxPacket.
|
||||
"""
|
||||
cb = [False]
|
||||
|
||||
def stubStopWriting() -> None:
|
||||
cb[0] = True
|
||||
|
||||
# no window to start with
|
||||
self.channel.stopWriting = stubStopWriting # type: ignore[method-assign]
|
||||
self.channel.writeExtended(1, b"d")
|
||||
self.channel.writeExtended(1, b"a")
|
||||
self.channel.writeExtended(2, b"t")
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
# regular write
|
||||
self.channel.addWindowBytes(20)
|
||||
self.channel.writeExtended(2, b"a")
|
||||
data = self.conn.extData[self.channel]
|
||||
self.assertEqual(data, [(1, b"da"), (2, b"t"), (2, b"a")])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 16)
|
||||
# larger than max packet
|
||||
self.channel.writeExtended(3, b"12345678901")
|
||||
self.assertEqual(
|
||||
data, [(1, b"da"), (2, b"t"), (2, b"a"), (3, b"1234567890"), (3, b"1")]
|
||||
)
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 5)
|
||||
# running out of window
|
||||
cb[0] = False
|
||||
self.channel.writeExtended(4, b"123456")
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
self.assertEqual(
|
||||
data,
|
||||
[
|
||||
(1, b"da"),
|
||||
(2, b"t"),
|
||||
(2, b"a"),
|
||||
(3, b"1234567890"),
|
||||
(3, b"1"),
|
||||
(4, b"12345"),
|
||||
],
|
||||
)
|
||||
self.assertEqual(self.channel.extBuf, [[4, b"6"]])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 0)
|
||||
|
||||
def test_writeSequence(self) -> None:
|
||||
"""
|
||||
Test that writeSequence is equivalent to write(''.join(sequece)).
|
||||
"""
|
||||
self.channel.addWindowBytes(20)
|
||||
self.channel.writeSequence(b"%d" % (i,) for i in range(10))
|
||||
self.assertEqual(self.conn.data[self.channel], [b"0123456789"])
|
||||
|
||||
def test_loseConnection(self) -> None:
|
||||
"""
|
||||
Tesyt that loseConnection() doesn't close the channel until all
|
||||
the data is sent.
|
||||
"""
|
||||
self.channel.write(b"data")
|
||||
self.channel.writeExtended(1, b"datadata")
|
||||
self.channel.loseConnection()
|
||||
self.assertIsNone(self.conn.closes.get(self.channel))
|
||||
self.channel.addWindowBytes(4) # send regular data
|
||||
self.assertIsNone(self.conn.closes.get(self.channel))
|
||||
self.channel.addWindowBytes(8) # send extended data
|
||||
self.assertTrue(self.conn.closes.get(self.channel))
|
||||
|
||||
def test_getPeer(self) -> None:
|
||||
"""
|
||||
L{SSHChannel.getPeer} returns the same object as the underlying
|
||||
transport's C{getPeer} method returns.
|
||||
"""
|
||||
peer = IPv4Address("TCP", "192.168.0.1", 54321)
|
||||
connectSSHTransport(service=self.channel.conn, peerAddress=peer)
|
||||
|
||||
self.assertEqual(SSHTransportAddress(peer), self.channel.getPeer())
|
||||
|
||||
def test_getHost(self) -> None:
|
||||
"""
|
||||
L{SSHChannel.getHost} returns the same object as the underlying
|
||||
transport's C{getHost} method returns.
|
||||
"""
|
||||
host = IPv4Address("TCP", "127.0.0.1", 12345)
|
||||
connectSSHTransport(service=self.channel.conn, hostAddress=host)
|
||||
|
||||
self.assertEqual(SSHTransportAddress(host), self.channel.getHost())
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user