okay fine

This commit is contained in:
pacnpal
2024-11-03 17:47:26 +00:00
parent 01c6004a79
commit 27eb239e97
10020 changed files with 1935769 additions and 2364 deletions

View File

@@ -0,0 +1,12 @@
# -*- test-case-name: twisted.web.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Web: HTTP clients and servers, plus tools for implementing them.
Contains a L{web server<twisted.web.server>} (including an
L{HTTP implementation<twisted.web.http>}, a
L{resource model<twisted.web.resource>}), and
a L{web client<twisted.web.client>}.
"""

View File

@@ -0,0 +1,68 @@
# -*- test-case-name: twisted.web.test.test_abnf -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tools for pedantically processing the HTTP protocol.
"""
def _istoken(b: bytes) -> bool:
"""
Is the string a token per RFC 9110 section 5.6.2?
"""
for c in b:
if c not in (
b"[AWS-SECRET-REMOVED]opqrstuvwxyz" # ALPHA
b"0123456789" # DIGIT
b"!#$%&'*+-.^_`|~"
):
return False
return b != b""
def _decint(data: bytes) -> int:
"""
Parse a decimal integer of the form C{1*DIGIT}, i.e. consisting only of
decimal digits. The integer may be embedded in whitespace (space and
horizontal tab). This differs from the built-in L{int()} function by
disallowing a leading C{+} character and various forms of whitespace
(note that we sanitize linear whitespace in header values in
L{twisted.web.http_headers.Headers}).
@param data: Value to parse.
@returns: A non-negative integer.
@raises ValueError: When I{value} contains non-decimal characters.
"""
data = data.strip(b" \t")
if not data.isdigit():
raise ValueError(f"Value contains non-decimal digits: {data!r}")
return int(data)
def _ishexdigits(b: bytes) -> bool:
"""
Is the string case-insensitively hexidecimal?
It must be composed of one or more characters in the ranges a-f, A-F
and 0-9.
"""
for c in b:
if c not in b"0123456789abcdefABCDEF":
return False
return b != b""
def _hexint(b: bytes) -> int:
"""
Decode a hexadecimal integer.
Unlike L{int(b, 16)}, this raises L{ValueError} when the integer has
a prefix like C{b'0x'}, C{b'+'}, or C{b'-'}, which is desirable when
parsing network protocols.
"""
if not _ishexdigits(b):
raise ValueError(b)
return int(b, 16)

View File

@@ -0,0 +1,7 @@
# -*- test-case-name: twisted.web.test.test_httpauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTTP header-based authentication migrated from web2
"""

View File

@@ -0,0 +1,58 @@
# -*- test-case-name: twisted.web.test.test_httpauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTTP BASIC authentication.
@see: U{http://tools.ietf.org/html/rfc1945}
@see: U{http://tools.ietf.org/html/rfc2616}
@see: U{http://tools.ietf.org/html/rfc2617}
"""
import binascii
from zope.interface import implementer
from twisted.cred import credentials, error
from twisted.web.iweb import ICredentialFactory
@implementer(ICredentialFactory)
class BasicCredentialFactory:
"""
Credential Factory for HTTP Basic Authentication
@type authenticationRealm: L{bytes}
@ivar authenticationRealm: The HTTP authentication realm which will be issued in
challenges.
"""
scheme = b"basic"
def __init__(self, authenticationRealm):
self.authenticationRealm = authenticationRealm
def getChallenge(self, request):
"""
Return a challenge including the HTTP authentication realm with which
this factory was created.
"""
return {"realm": self.authenticationRealm}
def decode(self, response, request):
"""
Parse the base64-encoded, colon-separated username and password into a
L{credentials.UsernamePassword} instance.
"""
try:
creds = binascii.a2b_base64(response + b"===")
except binascii.Error:
raise error.LoginFailed("Invalid credentials")
creds = creds.split(b":", 1)
if len(creds) == 2:
return credentials.UsernamePassword(*creds)
else:
raise error.LoginFailed("Invalid credentials")

View File

@@ -0,0 +1,56 @@
# -*- test-case-name: twisted.web.test.test_httpauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of RFC2617: HTTP Digest Authentication
@see: U{http://www.faqs.org/rfcs/rfc2617.html}
"""
from zope.interface import implementer
from twisted.cred import credentials
from twisted.web.iweb import ICredentialFactory
@implementer(ICredentialFactory)
class DigestCredentialFactory:
"""
Wrapper for L{digest.DigestCredentialFactory} that implements the
L{ICredentialFactory} interface.
"""
scheme = b"digest"
def __init__(self, algorithm, authenticationRealm):
"""
Create the digest credential factory that this object wraps.
"""
self.digest = credentials.DigestCredentialFactory(
algorithm, authenticationRealm
)
def getChallenge(self, request):
"""
Generate the challenge for use in the WWW-Authenticate header
@param request: The L{IRequest} to with access was denied and for the
response to which this challenge is being generated.
@return: The L{dict} that can be used to generate a WWW-Authenticate
header.
"""
return self.digest.getChallenge(request.getClientAddress().host)
def decode(self, response, request):
"""
Create a L{twisted.cred.credentials.DigestedCredentials} object
from the given response and request.
@see: L{ICredentialFactory.decode}
"""
return self.digest.decode(
response, request.method, request.getClientAddress().host
)

View File

@@ -0,0 +1,236 @@
# -*- test-case-name: twisted.web.test.test_httpauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A guard implementation which supports HTTP header-based authentication
schemes.
If no I{Authorization} header is supplied, an anonymous login will be
attempted by using a L{Anonymous} credentials object. If such a header is
supplied and does not contain allowed credentials, or if anonymous login is
denied, a 401 will be sent in the response along with I{WWW-Authenticate}
headers for each of the allowed authentication schemes.
"""
from zope.interface import implementer
from twisted.cred import error
from twisted.cred.credentials import Anonymous
from twisted.logger import Logger
from twisted.python.components import proxyForInterface
from twisted.web import util
from twisted.web.resource import IResource, _UnsafeErrorPage
@implementer(IResource)
class UnauthorizedResource:
"""
Simple IResource to escape Resource dispatch
"""
isLeaf = True
def __init__(self, factories):
self._credentialFactories = factories
def render(self, request):
"""
Send www-authenticate headers to the client
"""
def ensureBytes(s):
return s.encode("ascii") if isinstance(s, str) else s
def generateWWWAuthenticate(scheme, challenge):
lst = []
for k, v in challenge.items():
k = ensureBytes(k)
v = ensureBytes(v)
lst.append(k + b"=" + quoteString(v))
return b" ".join([scheme, b", ".join(lst)])
def quoteString(s):
return b'"' + s.replace(b"\\", rb"\\").replace(b'"', rb"\"") + b'"'
request.setResponseCode(401)
for fact in self._credentialFactories:
challenge = fact.getChallenge(request)
request.responseHeaders.addRawHeader(
b"www-authenticate", generateWWWAuthenticate(fact.scheme, challenge)
)
if request.method == b"HEAD":
return b""
return b"Unauthorized"
def getChildWithDefault(self, path, request):
"""
Disable resource dispatch
"""
return self
def putChild(self, path, child):
# IResource.putChild
raise NotImplementedError()
@implementer(IResource)
class HTTPAuthSessionWrapper:
"""
Wrap a portal, enforcing supported header-based authentication schemes.
@ivar _portal: The L{Portal} which will be used to retrieve L{IResource}
avatars.
@ivar _credentialFactories: A list of L{ICredentialFactory} providers which
will be used to decode I{Authorization} headers into L{ICredentials}
providers.
"""
isLeaf = False
_log = Logger()
def __init__(self, portal, credentialFactories):
"""
Initialize a session wrapper
@type portal: C{Portal}
@param portal: The portal that will authenticate the remote client
@type credentialFactories: C{Iterable}
@param credentialFactories: The portal that will authenticate the
remote client based on one submitted C{ICredentialFactory}
"""
self._portal = portal
self._credentialFactories = credentialFactories
def _authorizedResource(self, request):
"""
Get the L{IResource} which the given request is authorized to receive.
If the proper authorization headers are present, the resource will be
requested from the portal. If not, an anonymous login attempt will be
made.
"""
authheader = request.getHeader(b"authorization")
if not authheader:
return util.DeferredResource(self._login(Anonymous()))
factory, respString = self._selectParseHeader(authheader)
if factory is None:
return UnauthorizedResource(self._credentialFactories)
try:
credentials = factory.decode(respString, request)
except error.LoginFailed:
return UnauthorizedResource(self._credentialFactories)
except BaseException:
self._log.failure("Unexpected failure from credentials factory")
return _UnsafeErrorPage(500, "Internal Error", "")
else:
return util.DeferredResource(self._login(credentials))
def render(self, request):
"""
Find the L{IResource} avatar suitable for the given request, if
possible, and render it. Otherwise, perhaps render an error page
requiring authorization or describing an internal server failure.
"""
return self._authorizedResource(request).render(request)
def getChildWithDefault(self, path, request):
"""
Inspect the Authorization HTTP header, and return a deferred which,
when fired after successful authentication, will return an authorized
C{Avatar}. On authentication failure, an C{UnauthorizedResource} will
be returned, essentially halting further dispatch on the wrapped
resource and all children
"""
# Don't consume any segments of the request - this class should be
# transparent!
request.postpath.insert(0, request.prepath.pop())
return self._authorizedResource(request)
def _login(self, credentials):
"""
Get the L{IResource} avatar for the given credentials.
@return: A L{Deferred} which will be called back with an L{IResource}
avatar or which will errback if authentication fails.
"""
d = self._portal.login(credentials, None, IResource)
d.addCallbacks(self._loginSucceeded, self._loginFailed)
return d
def _loginSucceeded(self, args):
"""
Handle login success by wrapping the resulting L{IResource} avatar
so that the C{logout} callback will be invoked when rendering is
complete.
"""
interface, avatar, logout = args
class ResourceWrapper(proxyForInterface(IResource, "resource")):
"""
Wrap an L{IResource} so that whenever it or a child of it
completes rendering, the cred logout hook will be invoked.
An assumption is made here that exactly one L{IResource} from
among C{avatar} and all of its children will be rendered. If
more than one is rendered, C{logout} will be invoked multiple
times and probably earlier than desired.
"""
def getChildWithDefault(self, name, request):
"""
Pass through the lookup to the wrapped resource, wrapping
the result in L{ResourceWrapper} to ensure C{logout} is
called when rendering of the child is complete.
"""
return ResourceWrapper(self.resource.getChildWithDefault(name, request))
def render(self, request):
"""
Hook into response generation so that when rendering has
finished completely (with or without error), C{logout} is
called.
"""
request.notifyFinish().addBoth(lambda ign: logout())
return super().render(request)
return ResourceWrapper(avatar)
def _loginFailed(self, result):
"""
Handle login failure by presenting either another challenge (for
expected authentication/authorization-related failures) or a server
error page (for anything else).
"""
if result.check(error.Unauthorized, error.LoginFailed):
return UnauthorizedResource(self._credentialFactories)
else:
self._log.failure(
"HTTPAuthSessionWrapper.getChildWithDefault encountered "
"unexpected error",
failure=result,
)
return _UnsafeErrorPage(500, "Internal Error", "")
def _selectParseHeader(self, header):
"""
Choose an C{ICredentialFactory} from C{_credentialFactories}
suitable to use to decode the given I{Authenticate} header.
@return: A two-tuple of a factory and the remaining portion of the
header value to be decoded or a two-tuple of L{None} if no
factory can decode the header value.
"""
elements = header.split(b" ")
scheme = elements[0].lower()
for fact in self._credentialFactories:
if fact.scheme == scheme:
return (fact, b" ".join(elements[1:]))
return (None, None)
def putChild(self, path, child):
# IResource.putChild
raise NotImplementedError()

View File

@@ -0,0 +1,200 @@
# -*- test-case-name: twisted.web.test.test_template -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import itertools
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
Optional,
TypeVar,
Union,
overload,
)
from zope.interface import implementer
from twisted.web.error import (
MissingRenderMethod,
MissingTemplateLoader,
UnexposedMethodError,
)
from twisted.web.iweb import IRenderable, IRequest, ITemplateLoader
if TYPE_CHECKING:
from twisted.web.template import Flattenable, Tag
T = TypeVar("T")
_Tc = TypeVar("_Tc", bound=Callable[..., object])
class Expose:
"""
Helper for exposing methods for various uses using a simple decorator-style
callable.
Instances of this class can be called with one or more functions as
positional arguments. The names of these functions will be added to a list
on the class object of which they are methods.
"""
def __call__(self, f: _Tc, /, *funcObjs: Callable[..., object]) -> _Tc:
"""
Add one or more functions to the set of exposed functions.
This is a way to declare something about a class definition, similar to
L{zope.interface.implementer}. Use it like this::
magic = Expose('perform extra magic')
class Foo(Bar):
def twiddle(self, x, y):
...
def frob(self, a, b):
...
magic(twiddle, frob)
Later you can query the object::
aFoo = Foo()
magic.get(aFoo, 'twiddle')(x=1, y=2)
The call to C{get} will fail if the name it is given has not been
exposed using C{magic}.
@param funcObjs: One or more function objects which will be exposed to
the client.
@return: The first of C{funcObjs}.
"""
for fObj in itertools.chain([f], funcObjs):
exposedThrough: List[Expose] = getattr(fObj, "exposedThrough", [])
exposedThrough.append(self)
setattr(fObj, "exposedThrough", exposedThrough)
return f
_nodefault = object()
@overload
def get(self, instance: object, methodName: str) -> Callable[..., Any]:
...
@overload
def get(
self, instance: object, methodName: str, default: T
) -> Union[Callable[..., Any], T]:
...
def get(
self, instance: object, methodName: str, default: object = _nodefault
) -> object:
"""
Retrieve an exposed method with the given name from the given instance.
@raise UnexposedMethodError: Raised if C{default} is not specified and
there is no exposed method with the given name.
@return: A callable object for the named method assigned to the given
instance.
"""
method = getattr(instance, methodName, None)
exposedThrough = getattr(method, "exposedThrough", [])
if self not in exposedThrough:
if default is self._nodefault:
raise UnexposedMethodError(self, methodName)
return default
return method
def exposer(thunk: Callable[..., object]) -> Expose:
expose = Expose()
expose.__doc__ = thunk.__doc__
return expose
@exposer
def renderer() -> None:
"""
Decorate with L{renderer} to use methods as template render directives.
For example::
class Foo(Element):
@renderer
def twiddle(self, request, tag):
return tag('Hello, world.')
<div xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">
<span t:render="twiddle" />
</div>
Will result in this final output::
<div>
<span>Hello, world.</span>
</div>
"""
@implementer(IRenderable)
class Element:
"""
Base for classes which can render part of a page.
An Element is a renderer that can be embedded in a stan document and can
hook its template (from the loader) up to render methods.
An Element might be used to encapsulate the rendering of a complex piece of
data which is to be displayed in multiple different contexts. The Element
allows the rendering logic to be easily re-used in different ways.
Element returns render methods which are registered using
L{twisted.web._element.renderer}. For example::
class Menu(Element):
@renderer
def items(self, request, tag):
....
Render methods are invoked with two arguments: first, the
L{twisted.web.http.Request} being served and second, the tag object which
"invoked" the render method.
@ivar loader: The factory which will be used to load documents to
return from C{render}.
"""
loader: Optional[ITemplateLoader] = None
def __init__(self, loader: Optional[ITemplateLoader] = None):
if loader is not None:
self.loader = loader
def lookupRenderMethod(
self, name: str
) -> Callable[[Optional[IRequest], "Tag"], "Flattenable"]:
"""
Look up and return the named render method.
"""
method = renderer.get(self, name, None)
if method is None:
raise MissingRenderMethod(self, name)
return method
def render(self, request: Optional[IRequest]) -> "Flattenable":
"""
Implement L{IRenderable} to allow one L{Element} to be embedded in
another's template or rendering output.
(This will simply load the template from the C{loader}; when used in a
template, the flattening engine will keep track of this object
separately as the object to lookup renderers on and call
L{Element.renderer} to look them up. The resulting object from this
method is not directly associated with this L{Element}.)
"""
loader = self.loader
if loader is None:
raise MissingTemplateLoader(self)
return loader.load()

View File

@@ -0,0 +1,486 @@
# -*- test-case-name: twisted.web.test.test_flatten,twisted.web.test.test_template -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Context-free flattener/serializer for rendering Python objects, possibly
complex or arbitrarily nested, as strings.
"""
from __future__ import annotations
from inspect import iscoroutine
from io import BytesIO
from sys import exc_info
from traceback import extract_tb
from types import GeneratorType
from typing import (
Any,
Callable,
Coroutine,
Generator,
List,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.compat import nativeString
from twisted.python.failure import Failure
from twisted.web._stan import CDATA, CharRef, Comment, Tag, slot, voidElements
from twisted.web.error import FlattenerError, UnfilledSlot, UnsupportedType
from twisted.web.iweb import IRenderable, IRequest
T = TypeVar("T")
FlattenableRecursive = Any
"""
For documentation purposes, read C{FlattenableRecursive} as L{Flattenable}.
However, since mypy doesn't support recursive type definitions (yet?),
we'll put Any in the actual definition.
"""
Flattenable = Union[
bytes,
str,
slot,
CDATA,
Comment,
Tag,
Tuple[FlattenableRecursive, ...],
List[FlattenableRecursive],
Generator[FlattenableRecursive, None, None],
CharRef,
Deferred[FlattenableRecursive],
Coroutine[Deferred[FlattenableRecursive], object, FlattenableRecursive],
IRenderable,
]
"""
Type alias containing all types that can be flattened by L{flatten()}.
"""
# The maximum number of bytes to synchronously accumulate in the flattener
# buffer before delivering them onwards.
BUFFER_SIZE = 2**16
def escapeForContent(data: Union[bytes, str]) -> bytes:
"""
Escape some character or UTF-8 byte data for inclusion in an HTML or XML
document, by replacing metacharacters (C{&<>}) with their entity
equivalents (C{&amp;&lt;&gt;}).
This is used as an input to L{_flattenElement}'s C{dataEscaper} parameter.
@param data: The string to escape.
@return: The quoted form of C{data}. If C{data} is L{str}, return a utf-8
encoded string.
"""
if isinstance(data, str):
data = data.encode("utf-8")
data = data.replace(b"&", b"&amp;").replace(b"<", b"&lt;").replace(b">", b"&gt;")
return data
def attributeEscapingDoneOutside(data: Union[bytes, str]) -> bytes:
"""
Escape some character or UTF-8 byte data for inclusion in the top level of
an attribute. L{attributeEscapingDoneOutside} actually passes the data
through unchanged, because L{writeWithAttributeEscaping} handles the
quoting of the text within attributes outside the generator returned by
L{_flattenElement}; this is used as the C{dataEscaper} argument to that
L{_flattenElement} call so that that generator does not redundantly escape
its text output.
@param data: The string to escape.
@return: The string, unchanged, except for encoding.
"""
if isinstance(data, str):
return data.encode("utf-8")
return data
def writeWithAttributeEscaping(
write: Callable[[bytes], object]
) -> Callable[[bytes], None]:
"""
Decorate a C{write} callable so that all output written is properly quoted
for inclusion within an XML attribute value.
If a L{Tag <twisted.web.template.Tag>} C{x} is flattened within the context
of the contents of another L{Tag <twisted.web.template.Tag>} C{y}, the
metacharacters (C{<>&"}) delimiting C{x} should be passed through
unchanged, but the textual content of C{x} should still be quoted, as
usual. For example: C{<y><x>&amp;</x></y>}. That is the default behavior
of L{_flattenElement} when L{escapeForContent} is passed as the
C{dataEscaper}.
However, when a L{Tag <twisted.web.template.Tag>} C{x} is flattened within
the context of an I{attribute} of another L{Tag <twisted.web.template.Tag>}
C{y}, then the metacharacters delimiting C{x} should be quoted so that it
can be parsed from the attribute's value. In the DOM itself, this is not a
valid thing to do, but given that renderers and slots may be freely moved
around in a L{twisted.web.template} template, it is a condition which may
arise in a document and must be handled in a way which produces valid
output. So, for example, you should be able to get C{<y attr="&lt;x /&gt;"
/>}. This should also be true for other XML/HTML meta-constructs such as
comments and CDATA, so if you were to serialize a L{comment
<twisted.web.template.Comment>} in an attribute you should get C{<y
attr="&lt;-- comment --&gt;" />}. Therefore in order to capture these
meta-characters, flattening is done with C{write} callable that is wrapped
with L{writeWithAttributeEscaping}.
The final case, and hopefully the much more common one as compared to
serializing L{Tag <twisted.web.template.Tag>} and arbitrary L{IRenderable}
objects within an attribute, is to serialize a simple string, and those
should be passed through for L{writeWithAttributeEscaping} to quote
without applying a second, redundant level of quoting.
@param write: A callable which will be invoked with the escaped L{bytes}.
@return: A callable that writes data with escaping.
"""
def _write(data: bytes) -> None:
write(escapeForContent(data).replace(b'"', b"&quot;"))
return _write
def escapedCDATA(data: Union[bytes, str]) -> bytes:
"""
Escape CDATA for inclusion in a document.
@param data: The string to escape.
@return: The quoted form of C{data}. If C{data} is unicode, return a utf-8
encoded string.
"""
if isinstance(data, str):
data = data.encode("utf-8")
return data.replace(b"]]>", b"]]]]><![CDATA[>")
def escapedComment(data: Union[bytes, str]) -> bytes:
"""
Within comments the sequence C{-->} can be mistaken as the end of the comment.
To ensure consistent parsing and valid output the sequence is replaced with C{--&gt;}.
Furthermore, whitespace is added when a comment ends in a dash. This is done to break
the connection of the ending C{-} with the closing C{-->}.
@param data: The string to escape.
@return: The quoted form of C{data}. If C{data} is unicode, return a utf-8
encoded string.
"""
if isinstance(data, str):
data = data.encode("utf-8")
data = data.replace(b"-->", b"--&gt;")
if data and data[-1:] == b"-":
data += b" "
return data
def _getSlotValue(
name: str,
slotData: Sequence[Optional[Mapping[str, Flattenable]]],
default: Optional[Flattenable] = None,
) -> Flattenable:
"""
Find the value of the named slot in the given stack of slot data.
"""
for slotFrame in reversed(slotData):
if slotFrame is not None and name in slotFrame:
return slotFrame[name]
else:
if default is not None:
return default
raise UnfilledSlot(name)
def _fork(d: Deferred[T]) -> Deferred[T]:
"""
Create a new L{Deferred} based on C{d} that will fire and fail with C{d}'s
result or error, but will not modify C{d}'s callback type.
"""
d2: Deferred[T] = Deferred(lambda _: d.cancel())
def callback(result: T) -> T:
d2.callback(result)
return result
def errback(failure: Failure) -> Failure:
d2.errback(failure)
return failure
d.addCallbacks(callback, errback)
return d2
def _flattenElement(
request: Optional[IRequest],
root: Flattenable,
write: Callable[[bytes], object],
slotData: List[Optional[Mapping[str, Flattenable]]],
renderFactory: Optional[IRenderable],
dataEscaper: Callable[[Union[bytes, str]], bytes],
# This is annotated as Generator[T, None, None] instead of Iterator[T]
# because mypy does not consider an Iterator to be an instance of
# GeneratorType.
) -> Generator[Union[Generator[Any, Any, Any], Deferred[Flattenable]], None, None]:
"""
Make C{root} slightly more flat by yielding all its immediate contents as
strings, deferreds or generators that are recursive calls to itself.
@param request: A request object which will be passed to
L{IRenderable.render}.
@param root: An object to be made flatter. This may be of type C{unicode},
L{str}, L{slot}, L{Tag <twisted.web.template.Tag>}, L{tuple}, L{list},
L{types.GeneratorType}, L{Deferred}, or an object that implements
L{IRenderable}.
@param write: A callable which will be invoked with each L{bytes} produced
by flattening C{root}.
@param slotData: A L{list} of L{dict} mapping L{str} slot names to data
with which those slots will be replaced.
@param renderFactory: If not L{None}, an object that provides
L{IRenderable}.
@param dataEscaper: A 1-argument callable which takes L{bytes} or
L{unicode} and returns L{bytes}, quoted as appropriate for the
rendering context. This is really only one of two values:
L{attributeEscapingDoneOutside} or L{escapeForContent}, depending on
whether the rendering context is within an attribute or not. See the
explanation in L{writeWithAttributeEscaping}.
@return: An iterator that eventually writes L{bytes} to C{write}.
It can yield other iterators or L{Deferred}s; if it yields another
iterator, the caller will iterate it; if it yields a L{Deferred},
the result of that L{Deferred} will be another generator, in which
case it is iterated. See L{_flattenTree} for the trampoline that
consumes said values.
"""
def keepGoing(
newRoot: Flattenable,
dataEscaper: Callable[[Union[bytes, str]], bytes] = dataEscaper,
renderFactory: Optional[IRenderable] = renderFactory,
write: Callable[[bytes], object] = write,
) -> Generator[Union[Flattenable, Deferred[Flattenable]], None, None]:
return _flattenElement(
request, newRoot, write, slotData, renderFactory, dataEscaper
)
def keepGoingAsync(result: Deferred[Flattenable]) -> Deferred[Flattenable]:
return result.addCallback(keepGoing)
if isinstance(root, (bytes, str)):
write(dataEscaper(root))
elif isinstance(root, slot):
slotValue = _getSlotValue(root.name, slotData, root.default)
yield keepGoing(slotValue)
elif isinstance(root, CDATA):
write(b"<![CDATA[")
write(escapedCDATA(root.data))
write(b"]]>")
elif isinstance(root, Comment):
write(b"<!--")
write(escapedComment(root.data))
write(b"-->")
elif isinstance(root, Tag):
slotData.append(root.slotData)
rendererName = root.render
if rendererName is not None:
if renderFactory is None:
raise ValueError(
f'Tag wants to be rendered by method "{rendererName}" '
f"but is not contained in any IRenderable"
)
rootClone = root.clone(False)
rootClone.render = None
renderMethod = renderFactory.lookupRenderMethod(rendererName)
result = renderMethod(request, rootClone)
yield keepGoing(result)
slotData.pop()
return
if not root.tagName:
yield keepGoing(root.children)
return
write(b"<")
if isinstance(root.tagName, str):
tagName = root.tagName.encode("ascii")
else:
tagName = root.tagName
write(tagName)
for k, v in root.attributes.items():
if isinstance(k, str):
k = k.encode("ascii")
write(b" " + k + b'="')
# Serialize the contents of the attribute, wrapping the results of
# that serialization so that _everything_ is quoted.
yield keepGoing(
v, attributeEscapingDoneOutside, write=writeWithAttributeEscaping(write)
)
write(b'"')
if root.children or nativeString(tagName) not in voidElements:
write(b">")
# Regardless of whether we're in an attribute or not, switch back
# to the escapeForContent dataEscaper. The contents of a tag must
# be quoted no matter what; in the top-level document, just so
# they're valid, and if they're within an attribute, they have to
# be quoted so that after applying the *un*-quoting required to re-
# parse the tag within the attribute, all the quoting is still
# correct.
yield keepGoing(root.children, escapeForContent)
write(b"</" + tagName + b">")
else:
write(b" />")
elif isinstance(root, (tuple, list, GeneratorType)):
for element in root:
yield keepGoing(element)
elif isinstance(root, CharRef):
escaped = "&#%d;" % (root.ordinal,)
write(escaped.encode("ascii"))
elif isinstance(root, Deferred):
yield keepGoingAsync(_fork(root))
elif iscoroutine(root):
yield keepGoingAsync(
Deferred.fromCoroutine(
cast(Coroutine[Deferred[Flattenable], object, Flattenable], root)
)
)
elif IRenderable.providedBy(root):
result = root.render(request)
yield keepGoing(result, renderFactory=root)
else:
raise UnsupportedType(root)
async def _flattenTree(
request: Optional[IRequest], root: Flattenable, write: Callable[[bytes], object]
) -> None:
"""
Make C{root} into an iterable of L{bytes} and L{Deferred} by doing a depth
first traversal of the tree.
@param request: A request object which will be passed to
L{IRenderable.render}.
@param root: An object to be made flatter. This may be of type C{unicode},
L{bytes}, L{slot}, L{Tag <twisted.web.template.Tag>}, L{tuple},
L{list}, L{types.GeneratorType}, L{Deferred}, or something providing
L{IRenderable}.
@param write: A callable which will be invoked with each L{bytes} produced
by flattening C{root}.
@return: A C{Deferred}-returning coroutine that resolves to C{None}.
"""
buf = []
bufSize = 0
# Accumulate some bytes up to the buffer size so that we don't annoy the
# upstream writer with a million tiny string.
def bufferedWrite(bs: bytes) -> None:
nonlocal bufSize
buf.append(bs)
bufSize += len(bs)
if bufSize >= BUFFER_SIZE:
flushBuffer()
# Deliver the buffered content to the upstream writer as a single string.
# This is how a "big enough" buffer gets delivered, how a buffer of any
# size is delivered before execution is suspended to wait for an
# asynchronous value, and how anything left in the buffer when we're
# finished is delivered.
def flushBuffer() -> None:
nonlocal bufSize
if bufSize > 0:
write(b"".join(buf))
del buf[:]
bufSize = 0
stack: List[Generator[Any, Any, Any]] = [
_flattenElement(request, root, bufferedWrite, [], None, escapeForContent)
]
while stack:
try:
element = next(stack[-1])
if isinstance(element, Deferred):
# Before suspending flattening for an unknown amount of time,
# flush whatever data we have collected so far.
flushBuffer()
element = await element
except StopIteration:
stack.pop()
except Exception as e:
roots = []
for generator in stack:
if generator.gi_frame is not None:
roots.append(generator.gi_frame.f_locals["root"])
stack.pop()
raise FlattenerError(e, roots, extract_tb(exc_info()[2]))
else:
stack.append(element)
# Flush any data that remains in the buffer before finishing.
flushBuffer()
def flatten(
request: Optional[IRequest], root: Flattenable, write: Callable[[bytes], object]
) -> Deferred[None]:
"""
Incrementally write out a string representation of C{root} using C{write}.
In order to create a string representation, C{root} will be decomposed into
simpler objects which will themselves be decomposed and so on until strings
or objects which can easily be converted to strings are encountered.
@param request: A request object which will be passed to the C{render}
method of any L{IRenderable} provider which is encountered.
@param root: An object to be made flatter. This may be of type L{str},
L{bytes}, L{slot}, L{Tag <twisted.web.template.Tag>}, L{tuple},
L{list}, L{types.GeneratorType}, L{Deferred}, or something that
provides L{IRenderable}.
@param write: A callable which will be invoked with each L{bytes} produced
by flattening C{root}.
@return: A L{Deferred} which will be called back with C{None} when C{root}
has been completely flattened into C{write} or which will be errbacked
if an unexpected exception occurs.
"""
return ensureDeferred(_flattenTree(request, root, write))
def flattenString(request: Optional[IRequest], root: Flattenable) -> Deferred[bytes]:
"""
Collate a string representation of C{root} into a single string.
This is basically gluing L{flatten} to an L{io.BytesIO} and returning
the results. See L{flatten} for the exact meanings of C{request} and
C{root}.
@return: A L{Deferred} which will be called back with a single UTF-8 encoded
string as its result when C{root} has been completely flattened or which
will be errbacked if an unexpected exception occurs.
"""
io = BytesIO()
d = flatten(request, root, io.write)
d.addCallback(lambda _: io.getvalue())
return cast(Deferred[bytes], d)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,112 @@
# -*- test-case-name: twisted.web.test.test_http -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTTP response code definitions.
"""
_CONTINUE = 100
SWITCHING = 101
OK = 200
CREATED = 201
ACCEPTED = 202
NON_AUTHORITATIVE_INFORMATION = 203
NO_CONTENT = 204
RESET_CONTENT = 205
PARTIAL_CONTENT = 206
MULTI_STATUS = 207
MULTIPLE_CHOICE = 300
MOVED_PERMANENTLY = 301
FOUND = 302
SEE_OTHER = 303
NOT_MODIFIED = 304
USE_PROXY = 305
TEMPORARY_REDIRECT = 307
PERMANENT_REDIRECT = 308
BAD_REQUEST = 400
UNAUTHORIZED = 401
PAYMENT_REQUIRED = 402
FORBIDDEN = 403
NOT_FOUND = 404
NOT_ALLOWED = 405
NOT_ACCEPTABLE = 406
PROXY_AUTH_REQUIRED = 407
REQUEST_TIMEOUT = 408
CONFLICT = 409
GONE = 410
LENGTH_REQUIRED = 411
PRECONDITION_FAILED = 412
REQUEST_ENTITY_TOO_LARGE = 413
REQUEST_URI_TOO_LONG = 414
UNSUPPORTED_MEDIA_TYPE = 415
REQUESTED_RANGE_NOT_SATISFIABLE = 416
EXPECTATION_FAILED = 417
IM_A_TEAPOT = 418
INTERNAL_SERVER_ERROR = 500
NOT_IMPLEMENTED = 501
BAD_GATEWAY = 502
SERVICE_UNAVAILABLE = 503
GATEWAY_TIMEOUT = 504
HTTP_VERSION_NOT_SUPPORTED = 505
INSUFFICIENT_STORAGE_SPACE = 507
NOT_EXTENDED = 510
RESPONSES = {
# 100
_CONTINUE: b"Continue",
SWITCHING: b"Switching Protocols",
# 200
OK: b"OK",
CREATED: b"Created",
ACCEPTED: b"Accepted",
NON_AUTHORITATIVE_INFORMATION: b"Non-Authoritative Information",
NO_CONTENT: b"No Content",
RESET_CONTENT: b"Reset Content.",
PARTIAL_CONTENT: b"Partial Content",
MULTI_STATUS: b"Multi-Status",
# 300
MULTIPLE_CHOICE: b"Multiple Choices",
MOVED_PERMANENTLY: b"Moved Permanently",
FOUND: b"Found",
SEE_OTHER: b"See Other",
NOT_MODIFIED: b"Not Modified",
USE_PROXY: b"Use Proxy",
# 306 not defined??
TEMPORARY_REDIRECT: b"Temporary Redirect",
PERMANENT_REDIRECT: b"Permanent Redirect",
# 400
BAD_REQUEST: b"Bad Request",
UNAUTHORIZED: b"Unauthorized",
PAYMENT_REQUIRED: b"Payment Required",
FORBIDDEN: b"Forbidden",
NOT_FOUND: b"Not Found",
NOT_ALLOWED: b"Method Not Allowed",
NOT_ACCEPTABLE: b"Not Acceptable",
PROXY_AUTH_REQUIRED: b"Proxy Authentication Required",
REQUEST_TIMEOUT: b"Request Time-out",
CONFLICT: b"Conflict",
GONE: b"Gone",
LENGTH_REQUIRED: b"Length Required",
PRECONDITION_FAILED: b"Precondition Failed",
REQUEST_ENTITY_TOO_LARGE: b"Request Entity Too Large",
REQUEST_URI_TOO_LONG: b"Request-URI Too Long",
UNSUPPORTED_MEDIA_TYPE: b"Unsupported Media Type",
REQUESTED_RANGE_NOT_SATISFIABLE: b"Requested Range not satisfiable",
EXPECTATION_FAILED: b"Expectation Failed",
IM_A_TEAPOT: b"I'm a teapot",
# 500
INTERNAL_SERVER_ERROR: b"Internal Server Error",
NOT_IMPLEMENTED: b"Not Implemented",
BAD_GATEWAY: b"Bad Gateway",
SERVICE_UNAVAILABLE: b"Service Unavailable",
GATEWAY_TIMEOUT: b"Gateway Time-out",
HTTP_VERSION_NOT_SUPPORTED: b"HTTP Version not supported",
INSUFFICIENT_STORAGE_SPACE: b"Insufficient Storage Space",
NOT_EXTENDED: b"Not Extended",
}

View File

@@ -0,0 +1,360 @@
# -*- test-case-name: twisted.web.test.test_stan -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An s-expression-like syntax for expressing xml in pure python.
Stan tags allow you to build XML documents using Python.
Stan is a DOM, or Document Object Model, implemented using basic Python types
and functions called "flatteners". A flattener is a function that knows how to
turn an object of a specific type into something that is closer to an HTML
string. Stan differs from the W3C DOM by not being as cumbersome and heavy
weight. Since the object model is built using simple python types such as lists,
strings, and dictionaries, the API is simpler and constructing a DOM less
cumbersome.
@var voidElements: the names of HTML 'U{void
elements<http://www.whatwg.org/specs/web-apps/current-work/multipage/syntax.html#void-elements>}';
those which can't have contents and can therefore be self-closing in the
output.
"""
from inspect import iscoroutine, isgenerator
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from warnings import warn
import attr
if TYPE_CHECKING:
from twisted.web.template import Flattenable
@attr.s(unsafe_hash=False, eq=False, auto_attribs=True)
class slot:
"""
Marker for markup insertion in a template.
"""
name: str
"""
The name of this slot.
The key which must be used in L{Tag.fillSlots} to fill it.
"""
children: List["Tag"] = attr.ib(init=False, factory=list)
"""
The L{Tag} objects included in this L{slot}'s template.
"""
default: Optional["Flattenable"] = None
"""
The default contents of this slot, if it is left unfilled.
If this is L{None}, an L{UnfilledSlot} will be raised, rather than
L{None} actually being used.
"""
filename: Optional[str] = None
"""
The name of the XML file from which this tag was parsed.
If it was not parsed from an XML file, L{None}.
"""
lineNumber: Optional[int] = None
"""
The line number on which this tag was encountered in the XML file
from which it was parsed.
If it was not parsed from an XML file, L{None}.
"""
columnNumber: Optional[int] = None
"""
The column number at which this tag was encountered in the XML file
from which it was parsed.
If it was not parsed from an XML file, L{None}.
"""
@attr.s(unsafe_hash=False, eq=False, repr=False, auto_attribs=True)
class Tag:
"""
A L{Tag} represents an XML tags with a tag name, attributes, and children.
A L{Tag} can be constructed using the special L{twisted.web.template.tags}
object, or it may be constructed directly with a tag name. L{Tag}s have a
special method, C{__call__}, which makes representing trees of XML natural
using pure python syntax.
"""
tagName: Union[bytes, str]
"""
The name of the represented element.
For a tag like C{<div></div>}, this would be C{"div"}.
"""
attributes: Dict[Union[bytes, str], "Flattenable"] = attr.ib(factory=dict)
"""The attributes of the element."""
children: List["Flattenable"] = attr.ib(factory=list)
"""The contents of this C{Tag}."""
render: Optional[str] = None
"""
The name of the render method to use for this L{Tag}.
This name will be looked up at render time by the
L{twisted.web.template.Element} doing the rendering,
via L{twisted.web.template.Element.lookupRenderMethod},
to determine which method to call.
"""
filename: Optional[str] = None
"""
The name of the XML file from which this tag was parsed.
If it was not parsed from an XML file, L{None}.
"""
lineNumber: Optional[int] = None
"""
The line number on which this tag was encountered in the XML file
from which it was parsed.
If it was not parsed from an XML file, L{None}.
"""
columnNumber: Optional[int] = None
"""
The column number at which this tag was encountered in the XML file
from which it was parsed.
If it was not parsed from an XML file, L{None}.
"""
slotData: Optional[Dict[str, "Flattenable"]] = attr.ib(init=False, default=None)
"""
The data which can fill slots.
If present, a dictionary mapping slot names to renderable values.
The values in this dict might be anything that can be present as
the child of a L{Tag}: strings, lists, L{Tag}s, generators, etc.
"""
def fillSlots(self, **slots: "Flattenable") -> "Tag":
"""
Remember the slots provided at this position in the DOM.
During the rendering of children of this node, slots with names in
C{slots} will be rendered as their corresponding values.
@return: C{self}. This enables the idiom C{return tag.fillSlots(...)} in
renderers.
"""
if self.slotData is None:
self.slotData = {}
self.slotData.update(slots)
return self
def __call__(self, *children: "Flattenable", **kw: "Flattenable") -> "Tag":
"""
Add children and change attributes on this tag.
This is implemented using __call__ because it then allows the natural
syntax::
table(tr1, tr2, width="100%", height="50%", border="1")
Children may be other tag instances, strings, functions, or any other
object which has a registered flatten.
Attributes may be 'transparent' tag instances (so that
C{a(href=transparent(data="foo", render=myhrefrenderer))} works),
strings, functions, or any other object which has a registered
flattener.
If the attribute is a python keyword, such as 'class', you can add an
underscore to the name, like 'class_'.
There is one special keyword argument, 'render', which will be used as
the name of the renderer and saved as the 'render' attribute of this
instance, rather than the DOM 'render' attribute in the attributes
dictionary.
"""
self.children.extend(children)
for k, v in kw.items():
if k[-1] == "_":
k = k[:-1]
if k == "render":
if not isinstance(v, str):
raise TypeError(
f'Value for "render" attribute must be str, got {v!r}'
)
self.render = v
else:
self.attributes[k] = v
return self
def _clone(self, obj: "Flattenable", deep: bool) -> "Flattenable":
"""
Clone a C{Flattenable} object; used by L{Tag.clone}.
Note that both lists and tuples are cloned into lists.
@param obj: an object with a clone method, a list or tuple, or something
which should be immutable.
@param deep: whether to continue cloning child objects; i.e. the
contents of lists, the sub-tags within a tag.
@return: a clone of C{obj}.
"""
if hasattr(obj, "clone"):
return obj.clone(deep)
elif isinstance(obj, (list, tuple)):
return [self._clone(x, deep) for x in obj]
elif isgenerator(obj):
warn(
"Cloning a Tag which contains a generator is unsafe, "
"since the generator can be consumed only once; "
"this is deprecated since Twisted 21.7.0 and will raise "
"an exception in the future",
DeprecationWarning,
)
return obj
elif iscoroutine(obj):
warn(
"Cloning a Tag which contains a coroutine is unsafe, "
"since the coroutine can run only once; "
"this is deprecated since Twisted 21.7.0 and will raise "
"an exception in the future",
DeprecationWarning,
)
return obj
else:
return obj
def clone(self, deep: bool = True) -> "Tag":
"""
Return a clone of this tag. If deep is True, clone all of this tag's
children. Otherwise, just shallow copy the children list without copying
the children themselves.
"""
if deep:
newchildren = [self._clone(x, True) for x in self.children]
else:
newchildren = self.children[:]
newattrs = self.attributes.copy()
for key in newattrs.keys():
newattrs[key] = self._clone(newattrs[key], True)
newslotdata = None
if self.slotData:
newslotdata = self.slotData.copy()
for key in newslotdata:
newslotdata[key] = self._clone(newslotdata[key], True)
newtag = Tag(
self.tagName,
attributes=newattrs,
children=newchildren,
render=self.render,
filename=self.filename,
lineNumber=self.lineNumber,
columnNumber=self.columnNumber,
)
newtag.slotData = newslotdata
return newtag
def clear(self) -> "Tag":
"""
Clear any existing children from this tag.
"""
self.children = []
return self
def __repr__(self) -> str:
rstr = ""
if self.attributes:
rstr += ", attributes=%r" % self.attributes
if self.children:
rstr += ", children=%r" % self.children
return f"Tag({self.tagName!r}{rstr})"
voidElements = (
"img",
"br",
"hr",
"base",
"meta",
"link",
"param",
"area",
"input",
"col",
"basefont",
"isindex",
"frame",
"command",
"embed",
"keygen",
"source",
"track",
"wbs",
)
@attr.s(unsafe_hash=False, eq=False, repr=False, auto_attribs=True)
class CDATA:
"""
A C{<![CDATA[]]>} block from a template. Given a separate representation in
the DOM so that they may be round-tripped through rendering without losing
information.
"""
data: str
"""The data between "C{<![CDATA[}" and "C{]]>}"."""
def __repr__(self) -> str:
return f"CDATA({self.data!r})"
@attr.s(unsafe_hash=False, eq=False, repr=False, auto_attribs=True)
class Comment:
"""
A C{<!-- -->} comment from a template. Given a separate representation in
the DOM so that they may be round-tripped through rendering without losing
information.
"""
data: str
"""The data between "C{<!--}" and "C{-->}"."""
def __repr__(self) -> str:
return f"Comment({self.data!r})"
@attr.s(unsafe_hash=False, eq=False, repr=False, auto_attribs=True)
class CharRef:
"""
A numeric character reference. Given a separate representation in the DOM
so that non-ASCII characters may be output as pure ASCII.
@since: 12.0
"""
ordinal: int
"""The ordinal value of the unicode character to which this object refers."""
def __repr__(self) -> str:
return "CharRef(%d)" % (self.ordinal,)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,27 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
I am a simple test resource.
"""
from twisted.web import static
class Test(static.Data):
isLeaf = True
def __init__(self):
static.Data.__init__(
self,
b"""
<html>
<head><title>Twisted Web Demo</title><head>
<body>
Hello! This is a Twisted Web test page.
</body>
</html>
""",
"text/html",
)

View File

@@ -0,0 +1,390 @@
# -*- test-case-name: twisted.web.test.test_distrib -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Distributed web servers.
This is going to have to be refactored so that argument parsing is done
by each subprocess and not by the main web server (i.e. GET, POST etc.).
"""
import copy
import os
import sys
try:
import pwd
except ImportError:
pwd = None # type: ignore[assignment]
from io import BytesIO
from xml.dom.minidom import getDOMImplementation
from twisted.internet import address, reactor
from twisted.logger import Logger
from twisted.persisted import styles
from twisted.spread import pb
from twisted.spread.banana import SIZE_LIMIT
from twisted.web import http, resource, server, static, util
from twisted.web.http_headers import Headers
class _ReferenceableProducerWrapper(pb.Referenceable):
def __init__(self, producer):
self.producer = producer
def remote_resumeProducing(self):
self.producer.resumeProducing()
def remote_pauseProducing(self):
self.producer.pauseProducing()
def remote_stopProducing(self):
self.producer.stopProducing()
class Request(pb.RemoteCopy, server.Request):
"""
A request which was received by a L{ResourceSubscription} and sent via
PB to a distributed node.
"""
def setCopyableState(self, state):
"""
Initialize this L{twisted.web.distrib.Request} based on the copied
state so that it closely resembles a L{twisted.web.server.Request}.
"""
for k in "host", "client":
tup = state[k]
addrdesc = {"INET": "TCP", "UNIX": "UNIX"}[tup[0]]
addr = {
"TCP": lambda: address.IPv4Address(addrdesc, tup[1], tup[2]),
"UNIX": lambda: address.UNIXAddress(tup[1]),
}[addrdesc]()
state[k] = addr
state["requestHeaders"] = Headers(dict(state["requestHeaders"]))
pb.RemoteCopy.setCopyableState(self, state)
# Emulate the local request interface --
self.content = BytesIO(self.content_data)
self.finish = self.remote.remoteMethod("finish")
self.setHeader = self.remote.remoteMethod("setHeader")
self.addCookie = self.remote.remoteMethod("addCookie")
self.setETag = self.remote.remoteMethod("setETag")
self.setResponseCode = self.remote.remoteMethod("setResponseCode")
self.setLastModified = self.remote.remoteMethod("setLastModified")
# To avoid failing if a resource tries to write a very long string
# all at once, this one will be handled slightly differently.
self._write = self.remote.remoteMethod("write")
def write(self, bytes):
"""
Write the given bytes to the response body.
@param bytes: The bytes to write. If this is longer than 640k, it
will be split up into smaller pieces.
"""
start = 0
end = SIZE_LIMIT
while True:
self._write(bytes[start:end])
start += SIZE_LIMIT
end += SIZE_LIMIT
if start >= len(bytes):
break
def registerProducer(self, producer, streaming):
self.remote.callRemote(
"registerProducer", _ReferenceableProducerWrapper(producer), streaming
).addErrback(self.fail)
def unregisterProducer(self):
self.remote.callRemote("unregisterProducer").addErrback(self.fail)
def fail(self, failure):
self._log.failure("", failure=failure)
pb.setUnjellyableForClass(server.Request, Request)
class Issue:
_log = Logger()
def __init__(self, request):
self.request = request
def finished(self, result):
if result is not server.NOT_DONE_YET:
assert isinstance(result, str), "return value not a string"
self.request.write(result)
self.request.finish()
def failed(self, failure):
# XXX: Argh. FIXME.
failure = str(failure)
self.request.write(
resource._UnsafeErrorPage(
http.INTERNAL_SERVER_ERROR,
"Server Connection Lost",
# GHSA-vg46-2rrj-3647 note: _PRE does HTML-escape the input.
"Connection to distributed server lost:" + util._PRE(failure),
).render(self.request)
)
self.request.finish()
self._log.info(failure)
class ResourceSubscription(resource.Resource):
isLeaf = 1
waiting = 0
_log = Logger()
def __init__(self, host, port):
resource.Resource.__init__(self)
self.host = host
self.port = port
self.pending = []
self.publisher = None
def __getstate__(self):
"""Get persistent state for this ResourceSubscription."""
# When I unserialize,
state = copy.copy(self.__dict__)
# Publisher won't be connected...
state["publisher"] = None
# I won't be making a connection
state["waiting"] = 0
# There will be no pending requests.
state["pending"] = []
return state
def connected(self, publisher):
"""I've connected to a publisher; I'll now send all my requests."""
self._log.info("connected to publisher")
publisher.broker.notifyOnDisconnect(self.booted)
self.publisher = publisher
self.waiting = 0
for request in self.pending:
self.render(request)
self.pending = []
def notConnected(self, msg):
"""I can't connect to a publisher; I'll now reply to all pending
requests.
"""
self._log.info("could not connect to distributed web service: {msg}", msg=msg)
self.waiting = 0
self.publisher = None
for request in self.pending:
request.write("Unable to connect to distributed server.")
request.finish()
self.pending = []
def booted(self):
self.notConnected("connection dropped")
def render(self, request):
"""Render this request, from my server.
This will always be asynchronous, and therefore return NOT_DONE_YET.
It spins off a request to the pb client, and either adds it to the list
of pending issues or requests it immediately, depending on if the
client is already connected.
"""
if not self.publisher:
self.pending.append(request)
if not self.waiting:
self.waiting = 1
bf = pb.PBClientFactory()
timeout = 10
if self.host == "unix":
reactor.connectUNIX(self.port, bf, timeout)
else:
reactor.connectTCP(self.host, self.port, bf, timeout)
d = bf.getRootObject()
d.addCallbacks(self.connected, self.notConnected)
else:
i = Issue(request)
self.publisher.callRemote("request", request).addCallbacks(
i.finished, i.failed
)
return server.NOT_DONE_YET
class ResourcePublisher(pb.Root, styles.Versioned):
"""
L{ResourcePublisher} exposes a remote API which can be used to respond
to request.
@ivar site: The site which will be used for resource lookup.
@type site: L{twisted.web.server.Site}
"""
_log = Logger()
def __init__(self, site):
self.site = site
persistenceVersion = 2
def upgradeToVersion2(self):
self.application.authorizer.removeIdentity("web")
del self.application.services[self.serviceName]
del self.serviceName
del self.application
del self.perspectiveName
def getPerspectiveNamed(self, name):
return self
def remote_request(self, request):
"""
Look up the resource for the given request and render it.
"""
res = self.site.getResourceFor(request)
self._log.info(request)
result = res.render(request)
if result is not server.NOT_DONE_YET:
request.write(result)
request.finish()
return server.NOT_DONE_YET
class UserDirectory(resource.Resource):
"""
A resource which lists available user resources and serves them as
children.
@ivar _pwd: An object like L{pwd} which is used to enumerate users and
their home directories.
"""
userDirName = "public_html"
userSocketName = ".twistd-web-pb"
template = """
<html>
<head>
<title>twisted.web.distrib.UserDirectory</title>
<style>
a
{
font-family: Lucida, Verdana, Helvetica, Arial, sans-serif;
color: #369;
text-decoration: none;
}
th
{
font-family: Lucida, Verdana, Helvetica, Arial, sans-serif;
font-weight: bold;
text-decoration: none;
text-align: left;
}
pre, code
{
font-family: "Courier New", Courier, monospace;
}
p, body, td, ol, ul, menu, blockquote, div
{
font-family: Lucida, Verdana, Helvetica, Arial, sans-serif;
color: #000;
}
</style>
</head>
<body>
<h1>twisted.web.distrib.UserDirectory</h1>
%(users)s
</body>
</html>
"""
def __init__(self, userDatabase=None):
resource.Resource.__init__(self)
if userDatabase is None:
userDatabase = pwd
self._pwd = userDatabase
def _users(self):
"""
Return a list of two-tuples giving links to user resources and text to
associate with those links.
"""
users = []
for user in self._pwd.getpwall():
name, passwd, uid, gid, gecos, dir, shell = user
realname = gecos.split(",")[0]
if not realname:
realname = name
if os.path.exists(os.path.join(dir, self.userDirName)):
users.append((name, realname + " (file)"))
twistdsock = os.path.join(dir, self.userSocketName)
if os.path.exists(twistdsock):
linkName = name + ".twistd"
users.append((linkName, realname + " (twistd)"))
return users
def render_GET(self, request):
"""
Render as HTML a listing of all known users with links to their
personal resources.
"""
domImpl = getDOMImplementation()
newDoc = domImpl.createDocument(None, "ul", None)
listing = newDoc.documentElement
for link, text in self._users():
linkElement = newDoc.createElement("a")
linkElement.setAttribute("href", link + "/")
textNode = newDoc.createTextNode(text)
linkElement.appendChild(textNode)
item = newDoc.createElement("li")
item.appendChild(linkElement)
listing.appendChild(item)
htmlDoc = self.template % ({"users": listing.toxml()})
return htmlDoc.encode("utf-8")
def getChild(self, name, request):
if name == b"":
return self
td = b".twistd"
if name.endswith(td):
username = name[: -len(td)]
sub = 1
else:
username = name
sub = 0
try:
# Decode using the filesystem encoding to reverse a transformation
# done in the pwd module.
(
pw_name,
pw_passwd,
pw_uid,
pw_gid,
pw_gecos,
pw_dir,
pw_shell,
) = self._pwd.getpwnam(username.decode(sys.getfilesystemencoding()))
except KeyError:
return resource._UnsafeNoResource()
if sub:
twistdsock = os.path.join(pw_dir, self.userSocketName)
rs = ResourceSubscription("unix", twistdsock)
self.putChild(name, rs)
return rs
else:
path = os.path.join(pw_dir, self.userDirName)
if not os.path.exists(path):
return resource._UnsafeNoResource()
return static.File(path)

View File

@@ -0,0 +1,313 @@
# -*- test-case-name: twisted.web.test.test_domhelpers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A library for performing interesting tasks with DOM objects.
This module is now deprecated.
"""
import warnings
from io import StringIO
from incremental import Version, getVersionString
from twisted.web import microdom
from twisted.web.microdom import escape, getElementsByTagName, unescape
warningString = "twisted.web.domhelpers was deprecated at {}".format(
getVersionString(Version("Twisted", 23, 10, 0))
)
warnings.warn(warningString, DeprecationWarning, stacklevel=3)
# These modules are imported here as a shortcut.
escape
getElementsByTagName
class NodeLookupError(Exception):
pass
def substitute(request, node, subs):
"""
Look through the given node's children for strings, and
attempt to do string substitution with the given parameter.
"""
for child in node.childNodes:
if hasattr(child, "nodeValue") and child.nodeValue:
child.replaceData(0, len(child.nodeValue), child.nodeValue % subs)
substitute(request, child, subs)
def _get(node, nodeId, nodeAttrs=("id", "class", "model", "pattern")):
"""
(internal) Get a node with the specified C{nodeId} as any of the C{class},
C{id} or C{pattern} attributes.
"""
if hasattr(node, "hasAttributes") and node.hasAttributes():
for nodeAttr in nodeAttrs:
if str(node.getAttribute(nodeAttr)) == nodeId:
return node
if node.hasChildNodes():
if hasattr(node.childNodes, "length"):
length = node.childNodes.length
else:
length = len(node.childNodes)
for childNum in range(length):
result = _get(node.childNodes[childNum], nodeId)
if result:
return result
def get(node, nodeId):
"""
Get a node with the specified C{nodeId} as any of the C{class},
C{id} or C{pattern} attributes. If there is no such node, raise
L{NodeLookupError}.
"""
result = _get(node, nodeId)
if result:
return result
raise NodeLookupError(nodeId)
def getIfExists(node, nodeId):
"""
Get a node with the specified C{nodeId} as any of the C{class},
C{id} or C{pattern} attributes. If there is no such node, return
L{None}.
"""
return _get(node, nodeId)
def getAndClear(node, nodeId):
"""Get a node with the specified C{nodeId} as any of the C{class},
C{id} or C{pattern} attributes. If there is no such node, raise
L{NodeLookupError}. Remove all child nodes before returning.
"""
result = get(node, nodeId)
if result:
clearNode(result)
return result
def clearNode(node):
"""
Remove all children from the given node.
"""
node.childNodes[:] = []
def locateNodes(nodeList, key, value, noNesting=1):
"""
Find subnodes in the given node where the given attribute
has the given value.
"""
returnList = []
if not isinstance(nodeList, type([])):
return locateNodes(nodeList.childNodes, key, value, noNesting)
for childNode in nodeList:
if not hasattr(childNode, "getAttribute"):
continue
if str(childNode.getAttribute(key)) == value:
returnList.append(childNode)
if noNesting:
continue
returnList.extend(locateNodes(childNode, key, value, noNesting))
return returnList
def superSetAttribute(node, key, value):
if not hasattr(node, "setAttribute"):
return
node.setAttribute(key, value)
if node.hasChildNodes():
for child in node.childNodes:
superSetAttribute(child, key, value)
def superPrependAttribute(node, key, value):
if not hasattr(node, "setAttribute"):
return
old = node.getAttribute(key)
if old:
node.setAttribute(key, value + "/" + old)
else:
node.setAttribute(key, value)
if node.hasChildNodes():
for child in node.childNodes:
superPrependAttribute(child, key, value)
def superAppendAttribute(node, key, value):
if not hasattr(node, "setAttribute"):
return
old = node.getAttribute(key)
if old:
node.setAttribute(key, old + "/" + value)
else:
node.setAttribute(key, value)
if node.hasChildNodes():
for child in node.childNodes:
superAppendAttribute(child, key, value)
def gatherTextNodes(iNode, dounescape=0, joinWith=""):
"""Visit each child node and collect its text data, if any, into a string.
For example::
>>> doc=microdom.parseString('<a>1<b>2<c>3</c>4</b></a>')
>>> gatherTextNodes(doc.documentElement)
'1234'
With dounescape=1, also convert entities back into normal characters.
@return: the gathered nodes as a single string
@rtype: str"""
gathered = []
gathered_append = gathered.append
slice = [iNode]
while len(slice) > 0:
c = slice.pop(0)
if hasattr(c, "nodeValue") and c.nodeValue is not None:
if dounescape:
val = unescape(c.nodeValue)
else:
val = c.nodeValue
gathered_append(val)
slice[:0] = c.childNodes
return joinWith.join(gathered)
class RawText(microdom.Text):
"""This is an evil and horrible speed hack. Basically, if you have a big
chunk of XML that you want to insert into the DOM, but you don't want to
incur the cost of parsing it, you can construct one of these and insert it
into the DOM. This will most certainly only work with microdom as the API
for converting nodes to xml is different in every DOM implementation.
This could be improved by making this class a Lazy parser, so if you
inserted this into the DOM and then later actually tried to mutate this
node, it would be parsed then.
"""
def writexml(
self,
writer,
indent="",
addindent="",
newl="",
strip=0,
nsprefixes=None,
namespace=None,
):
writer.write(f"{indent}{self.data}{newl}")
def findNodes(parent, matcher, accum=None):
if accum is None:
accum = []
if not parent.hasChildNodes():
return accum
for child in parent.childNodes:
# print child, child.nodeType, child.nodeName
if matcher(child):
accum.append(child)
findNodes(child, matcher, accum)
return accum
def findNodesShallowOnMatch(parent, matcher, recurseMatcher, accum=None):
if accum is None:
accum = []
if not parent.hasChildNodes():
return accum
for child in parent.childNodes:
# print child, child.nodeType, child.nodeName
if matcher(child):
accum.append(child)
if recurseMatcher(child):
findNodesShallowOnMatch(child, matcher, recurseMatcher, accum)
return accum
def findNodesShallow(parent, matcher, accum=None):
if accum is None:
accum = []
if not parent.hasChildNodes():
return accum
for child in parent.childNodes:
if matcher(child):
accum.append(child)
else:
findNodes(child, matcher, accum)
return accum
def findElementsWithAttributeShallow(parent, attribute):
"""
Return an iterable of the elements which are direct children of C{parent}
and which have the C{attribute} attribute.
"""
return findNodesShallow(
parent,
lambda n: getattr(n, "tagName", None) is not None and n.hasAttribute(attribute),
)
def findElements(parent, matcher):
"""
Return an iterable of the elements which are children of C{parent} for
which the predicate C{matcher} returns true.
"""
return findNodes(
parent,
lambda n, matcher=matcher: getattr(n, "tagName", None) is not None
and matcher(n),
)
def findElementsWithAttribute(parent, attribute, value=None):
if value:
return findElements(
parent,
lambda n, attribute=attribute, value=value: n.hasAttribute(attribute)
and n.getAttribute(attribute) == value,
)
else:
return findElements(
parent, lambda n, attribute=attribute: n.hasAttribute(attribute)
)
def findNodesNamed(parent, name):
return findNodes(parent, lambda n, name=name: n.nodeName == name)
def writeNodeData(node, oldio):
for subnode in node.childNodes:
if hasattr(subnode, "data"):
oldio.write("" + subnode.data)
else:
writeNodeData(subnode, oldio)
def getNodeText(node):
oldio = StringIO()
writeNodeData(node, oldio)
return oldio.getvalue()
def getParents(node):
l = []
while node:
l.append(node)
node = node.parentNode
return l
def namedChildren(parent, nodeName):
"""namedChildren(parent, nodeName) -> children (not descendants) of parent
that have tagName == nodeName
"""
return [n for n in parent.childNodes if getattr(n, "tagName", "") == nodeName]

View File

@@ -0,0 +1,442 @@
# -*- test-case-name: twisted.web.test.test_error -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Exception definitions for L{twisted.web}.
"""
__all__ = [
"Error",
"PageRedirect",
"InfiniteRedirection",
"RenderError",
"MissingRenderMethod",
"MissingTemplateLoader",
"UnexposedMethodError",
"UnfilledSlot",
"UnsupportedType",
"FlattenerError",
"RedirectWithNoLocation",
]
from collections.abc import Sequence
from typing import Optional, Union, cast
from twisted.python.compat import nativeString
from twisted.web._responses import RESPONSES
def _codeToMessage(code: Union[int, bytes]) -> Optional[bytes]:
"""
Returns the response message corresponding to an HTTP code, or None
if the code is unknown or unrecognized.
@param code: HTTP status code, for example C{http.NOT_FOUND}.
@return: A string message or none
"""
try:
return RESPONSES.get(int(code))
except (ValueError, AttributeError):
return None
class Error(Exception):
"""
A basic HTTP error.
@ivar status: Refers to an HTTP status code, for example C{http.NOT_FOUND}.
@param message: A short error message, for example "NOT FOUND".
@ivar response: A complete HTML document for an error page.
"""
status: bytes
message: Optional[bytes]
response: Optional[bytes]
def __init__(
self,
code: Union[int, bytes],
message: Optional[bytes] = None,
response: Optional[bytes] = None,
) -> None:
"""
Initializes a basic exception.
@type code: L{bytes} or L{int}
@param code: Refers to an HTTP status code (for example, 200) either as
an integer or a bytestring representing such. If no C{message} is
given, C{code} is mapped to a descriptive bytestring that is used
instead.
@type message: L{bytes}
@param message: A short error message, for example C{b"NOT FOUND"}.
@type response: L{bytes}
@param response: A complete HTML document for an error page.
"""
message = message or _codeToMessage(code)
Exception.__init__(self, code, message, response)
if isinstance(code, int):
# If we're given an int, convert it to a bytestring
# downloadPage gives a bytes, Agent gives an int, and it worked by
# accident previously, so just make it keep working.
code = b"%d" % (code,)
elif len(code) != 3 or not code.isdigit():
# Status codes must be 3 digits. See
# https://httpwg.org/specs/rfc9110.html#status.code.extensibility
raise ValueError(f"Not a valid HTTP status code: {code!r}")
self.status = code
self.message = message
self.response = response
def __str__(self) -> str:
s = self.status
if self.message:
s += b" " + self.message
return nativeString(s)
class PageRedirect(Error):
"""
A request resulted in an HTTP redirect.
@ivar location: The location of the redirect which was not followed.
"""
location: Optional[bytes]
def __init__(
self,
code: Union[int, bytes],
message: Optional[bytes] = None,
response: Optional[bytes] = None,
location: Optional[bytes] = None,
) -> None:
"""
Initializes a page redirect exception.
@type code: L{bytes}
@param code: Refers to an HTTP status code, for example
C{http.NOT_FOUND}. If no C{message} is given, C{code} is mapped to a
descriptive string that is used instead.
@type message: L{bytes}
@param message: A short error message, for example C{b"NOT FOUND"}.
@type response: L{bytes}
@param response: A complete HTML document for an error page.
@type location: L{bytes}
@param location: The location response-header field value. It is an
absolute URI used to redirect the receiver to a location other than
the Request-URI so the request can be completed.
"""
Error.__init__(self, code, message, response)
if self.message and location:
self.message = self.message + b" to " + location
self.location = location
class InfiniteRedirection(Error):
"""
HTTP redirection is occurring endlessly.
@ivar location: The first URL in the series of redirections which was
not followed.
"""
location: Optional[bytes]
def __init__(
self,
code: Union[int, bytes],
message: Optional[bytes] = None,
response: Optional[bytes] = None,
location: Optional[bytes] = None,
) -> None:
"""
Initializes an infinite redirection exception.
@param code: Refers to an HTTP status code, for example
C{http.NOT_FOUND}. If no C{message} is given, C{code} is mapped to a
descriptive string that is used instead.
@param message: A short error message, for example C{b"NOT FOUND"}.
@param response: A complete HTML document for an error page.
@param location: The location response-header field value. It is an
absolute URI used to redirect the receiver to a location other than
the Request-URI so the request can be completed.
"""
Error.__init__(self, code, message, response)
if self.message and location:
self.message = self.message + b" to " + location
self.location = location
class RedirectWithNoLocation(Error):
"""
Exception passed to L{ResponseFailed} if we got a redirect without a
C{Location} header field.
@type uri: L{bytes}
@ivar uri: The URI which failed to give a proper location header
field.
@since: 11.1
"""
message: bytes
uri: bytes
def __init__(self, code: Union[bytes, int], message: bytes, uri: bytes) -> None:
"""
Initializes a page redirect exception when no location is given.
@type code: L{bytes}
@param code: Refers to an HTTP status code, for example
C{http.NOT_FOUND}. If no C{message} is given, C{code} is mapped to
a descriptive string that is used instead.
@type message: L{bytes}
@param message: A short error message.
@type uri: L{bytes}
@param uri: The URI which failed to give a proper location header
field.
"""
Error.__init__(self, code, message)
self.message = self.message + b" to " + uri
self.uri = uri
class UnsupportedMethod(Exception):
"""
Raised by a resource when faced with a strange request method.
RFC 2616 (HTTP 1.1) gives us two choices when faced with this situation:
If the type of request is known to us, but not allowed for the requested
resource, respond with NOT_ALLOWED. Otherwise, if the request is something
we don't know how to deal with in any case, respond with NOT_IMPLEMENTED.
When this exception is raised by a Resource's render method, the server
will make the appropriate response.
This exception's first argument MUST be a sequence of the methods the
resource *does* support.
"""
allowedMethods = ()
def __init__(self, allowedMethods, *args):
Exception.__init__(self, allowedMethods, *args)
self.allowedMethods = allowedMethods
if not isinstance(allowedMethods, Sequence):
raise TypeError(
"First argument must be a sequence of supported methods, "
"but my first argument is not a sequence."
)
def __str__(self) -> str:
return f"Expected one of {self.allowedMethods!r}"
class SchemeNotSupported(Exception):
"""
The scheme of a URI was not one of the supported values.
"""
class RenderError(Exception):
"""
Base exception class for all errors which can occur during template
rendering.
"""
class MissingRenderMethod(RenderError):
"""
Tried to use a render method which does not exist.
@ivar element: The element which did not have the render method.
@ivar renderName: The name of the renderer which could not be found.
"""
def __init__(self, element, renderName):
RenderError.__init__(self, element, renderName)
self.element = element
self.renderName = renderName
def __repr__(self) -> str:
return "{!r}: {!r} had no render method named {!r}".format(
self.__class__.__name__,
self.element,
self.renderName,
)
class MissingTemplateLoader(RenderError):
"""
L{MissingTemplateLoader} is raised when trying to render an Element without
a template loader, i.e. a C{loader} attribute.
@ivar element: The Element which did not have a document factory.
"""
def __init__(self, element):
RenderError.__init__(self, element)
self.element = element
def __repr__(self) -> str:
return f"{self.__class__.__name__!r}: {self.element!r} had no loader"
class UnexposedMethodError(Exception):
"""
Raised on any attempt to get a method which has not been exposed.
"""
class UnfilledSlot(Exception):
"""
During flattening, a slot with no associated data was encountered.
"""
class UnsupportedType(Exception):
"""
During flattening, an object of a type which cannot be flattened was
encountered.
"""
class ExcessiveBufferingError(Exception):
"""
The HTTP/2 protocol has been forced to buffer an excessive amount of
outbound data, and has therefore closed the connection and dropped all
outbound data.
"""
class FlattenerError(Exception):
"""
An error occurred while flattening an object.
@ivar _roots: A list of the objects on the flattener's stack at the time
the unflattenable object was encountered. The first element is least
deeply nested object and the last element is the most deeply nested.
"""
def __init__(self, exception, roots, traceback):
self._exception = exception
self._roots = roots
self._traceback = traceback
Exception.__init__(self, exception, roots, traceback)
def _formatRoot(self, obj):
"""
Convert an object from C{self._roots} to a string suitable for
inclusion in a render-traceback (like a normal Python traceback, but
can include "frame" source locations which are not in Python source
files).
@param obj: Any object which can be a render step I{root}.
Typically, L{Tag}s, strings, and other simple Python types.
@return: A string representation of C{obj}.
@rtype: L{str}
"""
# There's a circular dependency between this class and 'Tag', although
# only for an isinstance() check.
from twisted.web.template import Tag
if isinstance(obj, (bytes, str)):
# It's somewhat unlikely that there will ever be a str in the roots
# list. However, something like a MemoryError during a str.replace
# call (eg, replacing " with &quot;) could possibly cause this.
# Likewise, UTF-8 encoding a unicode string to a byte string might
# fail like this.
if len(obj) > 40:
if isinstance(obj, str):
ellipsis = "<...>"
else:
ellipsis = b"<...>"
return ascii(obj[:20] + ellipsis + obj[-20:])
else:
return ascii(obj)
elif isinstance(obj, Tag):
if obj.filename is None:
return "Tag <" + obj.tagName + ">"
else:
return 'File "%s", line %d, column %d, in "%s"' % (
obj.filename,
obj.lineNumber,
obj.columnNumber,
obj.tagName,
)
else:
return ascii(obj)
def __repr__(self) -> str:
"""
Present a string representation which includes a template traceback, so
we can tell where this error occurred in the template, as well as in
Python.
"""
# Avoid importing things unnecessarily until we actually need them;
# since this is an 'error' module we should be extra paranoid about
# that.
from traceback import format_list
if self._roots:
roots = (
" " + "\n ".join([self._formatRoot(r) for r in self._roots]) + "\n"
)
else:
roots = ""
if self._traceback:
traceback = (
"\n".join(
[
line
for entry in format_list(self._traceback)
for line in entry.splitlines()
]
)
+ "\n"
)
else:
traceback = ""
return cast(
str,
(
"Exception while flattening:\n"
+ roots
+ traceback
+ self._exception.__class__.__name__
+ ": "
+ str(self._exception)
+ "\n"
),
)
def __str__(self) -> str:
return repr(self)
class UnsupportedSpecialHeader(Exception):
"""
A HTTP/2 request was received that contained a HTTP/2 pseudo-header field
that is not recognised by Twisted.
"""

View File

@@ -0,0 +1,21 @@
# -*- test-case-name: twisted.web.test.test_httpauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Resource traversal integration with L{twisted.cred} to allow for
authentication and authorization of HTTP requests.
"""
from twisted.web._auth.basic import BasicCredentialFactory
from twisted.web._auth.digest import DigestCredentialFactory
# Expose HTTP authentication classes here.
from twisted.web._auth.wrapper import HTTPAuthSessionWrapper
__all__ = [
"HTTPAuthSessionWrapper",
"BasicCredentialFactory",
"DigestCredentialFactory",
]

View File

@@ -0,0 +1,56 @@
# -*- test-case-name: twisted.web.test.test_html -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""I hold HTML generation helpers.
"""
from html import escape
from io import StringIO
from incremental import Version
from twisted.python import log
from twisted.python.deprecate import deprecated
@deprecated(Version("Twisted", 15, 3, 0), replacement="twisted.web.template")
def PRE(text):
"Wrap <pre> tags around some text and HTML-escape it."
return "<pre>" + escape(text) + "</pre>"
@deprecated(Version("Twisted", 15, 3, 0), replacement="twisted.web.template")
def UL(lst):
io = StringIO()
io.write("<ul>\n")
for el in lst:
io.write("<li> %s</li>\n" % el)
io.write("</ul>")
return io.getvalue()
@deprecated(Version("Twisted", 15, 3, 0), replacement="twisted.web.template")
def linkList(lst):
io = StringIO()
io.write("<ul>\n")
for hr, el in lst:
io.write(f'<li> <a href="{hr}">{el}</a></li>\n')
io.write("</ul>")
return io.getvalue()
@deprecated(Version("Twisted", 15, 3, 0), replacement="twisted.web.template")
def output(func, *args, **kw):
"""output(func, *args, **kw) -> html string
Either return the result of a function (which presumably returns an
HTML-legal string) or a sparse HTMLized error message and a message
in the server log.
"""
try:
return func(*args, **kw)
except BaseException:
log.msg(f"Error calling {func!r}:")
log.err()
return PRE("An error occurred.")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,284 @@
# -*- test-case-name: twisted.web.test.test_http_headers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An API for storing HTTP header names and values.
"""
from typing import (
AnyStr,
ClassVar,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
overload,
)
from twisted.python.compat import cmp, comparable
from twisted.web._abnf import _istoken
class InvalidHeaderName(ValueError):
"""
HTTP header names must be tokens, per RFC 9110 section 5.1.
"""
_T = TypeVar("_T")
def _sanitizeLinearWhitespace(headerComponent: bytes) -> bytes:
r"""
Replace linear whitespace (C{\n}, C{\r\n}, C{\r}) in a header
value with a single space.
@param headerComponent: The header value to sanitize.
@return: The sanitized header value.
"""
return b" ".join(headerComponent.splitlines())
@comparable
class Headers:
"""
Stores HTTP headers in a key and multiple value format.
When passed L{str}, header names (e.g. 'Content-Type')
are encoded using ISO-8859-1 and header values (e.g.
'text/html;charset=utf-8') are encoded using UTF-8. Some methods that return
values will return them in the same type as the name given.
If the header keys or values cannot be encoded or decoded using the rules
above, using just L{bytes} arguments to the methods of this class will
ensure no decoding or encoding is done, and L{Headers} will treat the keys
and values as opaque byte strings.
@ivar _rawHeaders: A L{dict} mapping header names as L{bytes} to L{list}s of
header values as L{bytes}.
"""
__slots__ = ["_rawHeaders"]
def __init__(
self,
rawHeaders: Optional[Mapping[AnyStr, Sequence[AnyStr]]] = None,
) -> None:
self._rawHeaders: Dict[bytes, List[bytes]] = {}
if rawHeaders is not None:
for name, values in rawHeaders.items():
self.setRawHeaders(name, values)
def __repr__(self) -> str:
"""
Return a string fully describing the headers set on this object.
"""
return "{}({!r})".format(
self.__class__.__name__,
self._rawHeaders,
)
def __cmp__(self, other):
"""
Define L{Headers} instances as being equal to each other if they have
the same raw headers.
"""
if isinstance(other, Headers):
return cmp(
sorted(self._rawHeaders.items()), sorted(other._rawHeaders.items())
)
return NotImplemented
def copy(self):
"""
Return a copy of itself with the same headers set.
@return: A new L{Headers}
"""
return self.__class__(self._rawHeaders)
def hasHeader(self, name: AnyStr) -> bool:
"""
Check for the existence of a given header.
@param name: The name of the HTTP header to check for.
@return: C{True} if the header exists, otherwise C{False}.
"""
return _nameEncoder.encode(name) in self._rawHeaders
def removeHeader(self, name: AnyStr) -> None:
"""
Remove the named header from this header object.
@param name: The name of the HTTP header to remove.
@return: L{None}
"""
self._rawHeaders.pop(_nameEncoder.encode(name), None)
def setRawHeaders(
self, name: Union[str, bytes], values: Sequence[Union[str, bytes]]
) -> None:
"""
Sets the raw representation of the given header.
@param name: The name of the HTTP header to set the values for.
@param values: A list of strings each one being a header value of
the given name.
@raise TypeError: Raised if C{values} is not a sequence of L{bytes}
or L{str}, or if C{name} is not L{bytes} or L{str}.
@return: L{None}
"""
_name = _nameEncoder.encode(name)
encodedValues: List[bytes] = []
for v in values:
if isinstance(v, str):
_v = v.encode("utf8")
else:
_v = v
encodedValues.append(_sanitizeLinearWhitespace(_v))
self._rawHeaders[_name] = encodedValues
def addRawHeader(self, name: Union[str, bytes], value: Union[str, bytes]) -> None:
"""
Add a new raw value for the given header.
@param name: The name of the header for which to set the value.
@param value: The value to set for the named header.
"""
self._rawHeaders.setdefault(_nameEncoder.encode(name), []).append(
_sanitizeLinearWhitespace(
value.encode("utf8") if isinstance(value, str) else value
)
)
@overload
def getRawHeaders(self, name: AnyStr) -> Optional[Sequence[AnyStr]]:
...
@overload
def getRawHeaders(self, name: AnyStr, default: _T) -> Union[Sequence[AnyStr], _T]:
...
def getRawHeaders(
self, name: AnyStr, default: Optional[_T] = None
) -> Union[Sequence[AnyStr], Optional[_T]]:
"""
Returns a sequence of headers matching the given name as the raw string
given.
@param name: The name of the HTTP header to get the values of.
@param default: The value to return if no header with the given C{name}
exists.
@return: If the named header is present, a sequence of its
values. Otherwise, C{default}.
"""
encodedName = _nameEncoder.encode(name)
values = self._rawHeaders.get(encodedName, [])
if not values:
return default
if isinstance(name, str):
return [v.decode("utf8") for v in values]
return values
def getAllRawHeaders(self) -> Iterator[Tuple[bytes, Sequence[bytes]]]:
"""
Return an iterator of key, value pairs of all headers contained in this
object, as L{bytes}. The keys are capitalized in canonical
capitalization.
"""
return iter(self._rawHeaders.items())
class _NameEncoder:
"""
C{_NameEncoder} converts HTTP header names to L{bytes} and canonicalizies
their capitalization.
@cvar _caseMappings: A L{dict} that maps conventionally-capitalized
header names to their canonicalized representation, for headers with
unconventional capitalization.
@cvar _canonicalHeaderCache: A L{dict} that maps header names to their
canonicalized representation.
"""
__slots__ = ("_canonicalHeaderCache",)
_canonicalHeaderCache: Dict[Union[bytes, str], bytes]
_caseMappings: ClassVar[Dict[bytes, bytes]] = {
b"Content-Md5": b"Content-MD5",
b"Dnt": b"DNT",
b"Etag": b"ETag",
b"P3p": b"P3P",
b"Te": b"TE",
b"Www-Authenticate": b"WWW-Authenticate",
b"X-Xss-Protection": b"X-XSS-Protection",
}
_MAX_CACHED_HEADERS: ClassVar[int] = 10_000
def __init__(self):
self._canonicalHeaderCache = {}
def encode(self, name: Union[str, bytes]) -> bytes:
"""
Encode the name of a header (eg 'Content-Type') to an ISO-8859-1
bytestring if required. It will be canonicalized to Http-Header-Case.
@raises InvalidHeaderName:
If the header name contains invalid characters like whitespace
or NUL.
@param name: An HTTP header name
@return: C{name}, encoded if required, in Header-Case
"""
if canonicalName := self._canonicalHeaderCache.get(name):
return canonicalName
bytes_name = name.encode("iso-8859-1") if isinstance(name, str) else name
if not _istoken(bytes_name):
raise InvalidHeaderName(bytes_name)
result = b"-".join([word.capitalize() for word in bytes_name.split(b"-")])
# Some headers have special capitalization:
if result in self._caseMappings:
result = self._caseMappings[result]
# In general, we should only see a very small number of header
# variations in the real world, so caching them is fine. However, an
# attacker could generate infinite header variations to fill up RAM, so
# we cap how many we cache. The performance degradation from lack of
# caching won't be that bad, and legit traffic won't hit it.
if len(self._canonicalHeaderCache) < self._MAX_CACHED_HEADERS:
self._canonicalHeaderCache[name] = result
return result
_nameEncoder = _NameEncoder()
"""
The global name encoder.
"""
__all__ = ["Headers"]

View File

@@ -0,0 +1,831 @@
# -*- test-case-name: twisted.web.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Interface definitions for L{twisted.web}.
@var UNKNOWN_LENGTH: An opaque object which may be used as the value of
L{IBodyProducer.length} to indicate that the length of the entity
body is not known in advance.
"""
from typing import TYPE_CHECKING, Callable, List, Optional
from zope.interface import Attribute, Interface
from twisted.cred.credentials import IUsernameDigestHash
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IPushProducer
from twisted.web.http_headers import Headers
if TYPE_CHECKING:
from twisted.web.template import Flattenable, Tag
class IRequest(Interface):
"""
An HTTP request.
@since: 9.0
"""
method = Attribute("A L{bytes} giving the HTTP method that was used.")
uri = Attribute(
"A L{bytes} giving the full encoded URI which was requested (including"
" query arguments)."
)
path = Attribute(
"A L{bytes} giving the encoded query path of the request URI (not "
"including query arguments)."
)
args = Attribute(
"A mapping of decoded query argument names as L{bytes} to "
"corresponding query argument values as L{list}s of L{bytes}. "
"For example, for a URI with C{foo=bar&foo=baz&quux=spam} "
"for its query part, C{args} will be C{{b'foo': [b'bar', b'baz'], "
"b'quux': [b'spam']}}."
)
prepath = Attribute(
"The URL path segments which have been processed during resource "
"traversal, as a list of L{bytes}."
)
postpath = Attribute(
"The URL path segments which have not (yet) been processed "
"during resource traversal, as a list of L{bytes}."
)
requestHeaders = Attribute(
"A L{http_headers.Headers} instance giving all received HTTP request "
"headers."
)
content = Attribute(
"A file-like object giving the request body. This may be a file on "
"disk, an L{io.BytesIO}, or some other type. The implementation is "
"free to decide on a per-request basis."
)
responseHeaders = Attribute(
"A L{http_headers.Headers} instance holding all HTTP response "
"headers to be sent."
)
def getHeader(key):
"""
Get an HTTP request header.
@type key: L{bytes} or L{str}
@param key: The name of the header to get the value of.
@rtype: L{bytes} or L{str} or L{None}
@return: The value of the specified header, or L{None} if that header
was not present in the request. The string type of the result
matches the type of C{key}.
"""
def getCookie(key):
"""
Get a cookie that was sent from the network.
@type key: L{bytes}
@param key: The name of the cookie to get.
@rtype: L{bytes} or L{None}
@returns: The value of the specified cookie, or L{None} if that cookie
was not present in the request.
"""
def getAllHeaders():
"""
Return dictionary mapping the names of all received headers to the last
value received for each.
Since this method does not return all header information,
C{requestHeaders.getAllRawHeaders()} may be preferred.
"""
def getRequestHostname():
"""
Get the hostname that the HTTP client passed in to the request.
This will either use the C{Host:} header (if it is available; which,
for a spec-compliant request, it will be) or the IP address of the host
we are listening on if the header is unavailable.
@note: This is the I{host portion} of the requested resource, which
means that:
1. it might be an IPv4 or IPv6 address, not just a DNS host
name,
2. there's no guarantee it's even a I{valid} host name or IP
address, since the C{Host:} header may be malformed,
3. it does not include the port number.
@returns: the requested hostname
@rtype: L{bytes}
"""
def getHost():
"""
Get my originally requesting transport's host.
@return: An L{IAddress<twisted.internet.interfaces.IAddress>}.
"""
def getClientAddress():
"""
Return the address of the client who submitted this request.
The address may not be a network address. Callers must check
its type before using it.
@since: 18.4
@return: the client's address.
@rtype: an L{IAddress} provider.
"""
def getClientIP():
"""
Return the IP address of the client who submitted this request.
This method is B{deprecated}. See L{getClientAddress} instead.
@returns: the client IP address or L{None} if the request was submitted
over a transport where IP addresses do not make sense.
@rtype: L{str} or L{None}
"""
def getUser():
"""
Return the HTTP user sent with this request, if any.
If no user was supplied, return the empty string.
@returns: the HTTP user, if any
@rtype: L{str}
"""
def getPassword():
"""
Return the HTTP password sent with this request, if any.
If no password was supplied, return the empty string.
@returns: the HTTP password, if any
@rtype: L{str}
"""
def isSecure():
"""
Return True if this request is using a secure transport.
Normally this method returns True if this request's HTTPChannel
instance is using a transport that implements ISSLTransport.
This will also return True if setHost() has been called
with ssl=True.
@returns: True if this request is secure
@rtype: C{bool}
"""
def getSession(sessionInterface=None):
"""
Look up the session associated with this request or create a new one if
there is not one.
@return: The L{Session} instance identified by the session cookie in
the request, or the C{sessionInterface} component of that session
if C{sessionInterface} is specified.
"""
def URLPath():
"""
@return: A L{URLPath<twisted.python.urlpath.URLPath>} instance
which identifies the URL for which this request is.
"""
def prePathURL():
"""
At any time during resource traversal or resource rendering,
returns an absolute URL to the most nested resource which has
yet been reached.
@see: {twisted.web.server.Request.prepath}
@return: An absolute URL.
@rtype: L{bytes}
"""
def rememberRootURL():
"""
Remember the currently-processed part of the URL for later
recalling.
"""
def getRootURL():
"""
Get a previously-remembered URL.
@return: An absolute URL.
@rtype: L{bytes}
"""
# Methods for outgoing response
def finish():
"""
Indicate that the response to this request is complete.
"""
def write(data):
"""
Write some data to the body of the response to this request. Response
headers are written the first time this method is called, after which
new response headers may not be added.
@param data: Bytes of the response body.
@type data: L{bytes}
"""
def addCookie(
k,
v,
expires=None,
domain=None,
path=None,
max_age=None,
comment=None,
secure=None,
):
"""
Set an outgoing HTTP cookie.
In general, you should consider using sessions instead of cookies, see
L{twisted.web.server.Request.getSession} and the
L{twisted.web.server.Session} class for details.
"""
def setResponseCode(code, message=None):
"""
Set the HTTP response code.
@type code: L{int}
@type message: L{bytes}
"""
def setHeader(k, v):
"""
Set an HTTP response header. Overrides any previously set values for
this header.
@type k: L{bytes} or L{str}
@param k: The name of the header for which to set the value.
@type v: L{bytes} or L{str}
@param v: The value to set for the named header. A L{str} will be
UTF-8 encoded, which may not interoperable with other
implementations. Avoid passing non-ASCII characters if possible.
"""
def redirect(url):
"""
Utility function that does a redirect.
The request should have finish() called after this.
"""
def setLastModified(when):
"""
Set the C{Last-Modified} time for the response to this request.
If I am called more than once, I ignore attempts to set Last-Modified
earlier, only replacing the Last-Modified time if it is to a later
value.
If I am a conditional request, I may modify my response code to
L{NOT_MODIFIED<http.NOT_MODIFIED>} if appropriate for the time given.
@param when: The last time the resource being returned was modified, in
seconds since the epoch.
@type when: L{int} or L{float}
@return: If I am a C{If-Modified-Since} conditional request and the time
given is not newer than the condition, I return
L{CACHED<http.CACHED>} to indicate that you should write no body.
Otherwise, I return a false value.
"""
def setETag(etag):
"""
Set an C{entity tag} for the outgoing response.
That's "entity tag" as in the HTTP/1.1 I{ETag} header, "used for
comparing two or more entities from the same requested resource."
If I am a conditional request, I may modify my response code to
L{NOT_MODIFIED<http.NOT_MODIFIED>} or
L{PRECONDITION_FAILED<http.PRECONDITION_FAILED>}, if appropriate for the
tag given.
@param etag: The entity tag for the resource being returned.
@type etag: L{str}
@return: If I am a C{If-None-Match} conditional request and the tag
matches one in the request, I return L{CACHED<http.CACHED>} to
indicate that you should write no body. Otherwise, I return a
false value.
"""
def setHost(host, port, ssl=0):
"""
Change the host and port the request thinks it's using.
This method is useful for working with reverse HTTP proxies (e.g. both
Squid and Apache's mod_proxy can do this), when the address the HTTP
client is using is different than the one we're listening on.
For example, Apache may be listening on https://www.example.com, and
then forwarding requests to http://localhost:8080, but we don't want
HTML produced by Twisted to say 'http://localhost:8080', they should
say 'https://www.example.com', so we do::
request.setHost('www.example.com', 443, ssl=1)
"""
class INonQueuedRequestFactory(Interface):
"""
A factory of L{IRequest} objects that does not take a ``queued`` parameter.
"""
def __call__(channel):
"""
Create an L{IRequest} that is operating on the given channel. There
must only be one L{IRequest} object processing at any given time on a
channel.
@param channel: A L{twisted.web.http.HTTPChannel} object.
@type channel: L{twisted.web.http.HTTPChannel}
@return: A request object.
@rtype: L{IRequest}
"""
class IAccessLogFormatter(Interface):
"""
An object which can represent an HTTP request as a line of text for
inclusion in an access log file.
"""
def __call__(timestamp, request):
"""
Generate a line for the access log.
@param timestamp: The time at which the request was completed in the
standard format for access logs.
@type timestamp: L{unicode}
@param request: The request object about which to log.
@type request: L{twisted.web.server.Request}
@return: One line describing the request without a trailing newline.
@rtype: L{unicode}
"""
class ICredentialFactory(Interface):
"""
A credential factory defines a way to generate a particular kind of
authentication challenge and a way to interpret the responses to these
challenges. It creates
L{ICredentials<twisted.cred.credentials.ICredentials>} providers from
responses. These objects will be used with L{twisted.cred} to authenticate
an authorize requests.
"""
scheme = Attribute(
"A L{str} giving the name of the authentication scheme with which "
"this factory is associated. For example, C{'basic'} or C{'digest'}."
)
def getChallenge(request):
"""
Generate a new challenge to be sent to a client.
@type request: L{twisted.web.http.Request}
@param request: The request the response to which this challenge will
be included.
@rtype: L{dict}
@return: A mapping from L{str} challenge fields to associated L{str}
values.
"""
def decode(response, request):
"""
Create a credentials object from the given response.
@type response: L{str}
@param response: scheme specific response string
@type request: L{twisted.web.http.Request}
@param request: The request being processed (from which the response
was taken).
@raise twisted.cred.error.LoginFailed: If the response is invalid.
@rtype: L{twisted.cred.credentials.ICredentials} provider
@return: The credentials represented by the given response.
"""
class IBodyProducer(IPushProducer):
"""
Objects which provide L{IBodyProducer} write bytes to an object which
provides L{IConsumer<twisted.internet.interfaces.IConsumer>} by calling its
C{write} method repeatedly.
L{IBodyProducer} providers may start producing as soon as they have an
L{IConsumer<twisted.internet.interfaces.IConsumer>} provider. That is, they
should not wait for a C{resumeProducing} call to begin writing data.
L{IConsumer.unregisterProducer<twisted.internet.interfaces.IConsumer.unregisterProducer>}
must not be called. Instead, the
L{Deferred<twisted.internet.defer.Deferred>} returned from C{startProducing}
must be fired when all bytes have been written.
L{IConsumer.write<twisted.internet.interfaces.IConsumer.write>} may
synchronously invoke any of C{pauseProducing}, C{resumeProducing}, or
C{stopProducing}. These methods must be implemented with this in mind.
@since: 9.0
"""
# Despite the restrictions above and the additional requirements of
# stopProducing documented below, this interface still needs to be an
# IPushProducer subclass. Providers of it will be passed to IConsumer
# providers which only know about IPushProducer and IPullProducer, not
# about this interface. This interface needs to remain close enough to one
# of those interfaces for consumers to work with it.
length = Attribute(
"""
C{length} is a L{int} indicating how many bytes in total this
L{IBodyProducer} will write to the consumer or L{UNKNOWN_LENGTH}
if this is not known in advance.
"""
)
def startProducing(consumer):
"""
Start producing to the given
L{IConsumer<twisted.internet.interfaces.IConsumer>} provider.
@return: A L{Deferred<twisted.internet.defer.Deferred>} which stops
production of data when L{Deferred.cancel} is called, and which
fires with L{None} when all bytes have been produced or with a
L{Failure<twisted.python.failure.Failure>} if there is any problem
before all bytes have been produced.
"""
def stopProducing():
"""
In addition to the standard behavior of
L{IProducer.stopProducing<twisted.internet.interfaces.IProducer.stopProducing>}
(stop producing data), make sure the
L{Deferred<twisted.internet.defer.Deferred>} returned by
C{startProducing} is never fired.
"""
class IRenderable(Interface):
"""
An L{IRenderable} is an object that may be rendered by the
L{twisted.web.template} templating system.
"""
def lookupRenderMethod(
name: str,
) -> Callable[[Optional[IRequest], "Tag"], "Flattenable"]:
"""
Look up and return the render method associated with the given name.
@param name: The value of a render directive encountered in the
document returned by a call to L{IRenderable.render}.
@return: A two-argument callable which will be invoked with the request
being responded to and the tag object on which the render directive
was encountered.
"""
def render(request: Optional[IRequest]) -> "Flattenable":
"""
Get the document for this L{IRenderable}.
@param request: The request in response to which this method is being
invoked.
@return: An object which can be flattened.
"""
class ITemplateLoader(Interface):
"""
A loader for templates; something usable as a value for
L{twisted.web.template.Element}'s C{loader} attribute.
"""
def load() -> List["Flattenable"]:
"""
Load a template suitable for rendering.
@return: a L{list} of flattenable objects, such as byte and unicode
strings, L{twisted.web.template.Element}s and L{IRenderable} providers.
"""
class IResponse(Interface):
"""
An object representing an HTTP response received from an HTTP server.
@since: 11.1
"""
version = Attribute(
"A three-tuple describing the protocol and protocol version "
"of the response. The first element is of type L{str}, the second "
"and third are of type L{int}. For example, C{(b'HTTP', 1, 1)}."
)
code = Attribute("The HTTP status code of this response, as a L{int}.")
phrase = Attribute("The HTTP reason phrase of this response, as a L{str}.")
headers = Attribute("The HTTP response L{Headers} of this response.")
length = Attribute(
"The L{int} number of bytes expected to be in the body of this "
"response or L{UNKNOWN_LENGTH} if the server did not indicate how "
"many bytes to expect. For I{HEAD} responses, this will be 0; if "
"the response includes a I{Content-Length} header, it will be "
"available in C{headers}."
)
request = Attribute("The L{IClientRequest} that resulted in this response.")
previousResponse = Attribute(
"The previous L{IResponse} from a redirect, or L{None} if there was no "
"previous response. This can be used to walk the response or request "
"history for redirections."
)
def deliverBody(protocol):
"""
Register an L{IProtocol<twisted.internet.interfaces.IProtocol>} provider
to receive the response body.
The protocol will be connected to a transport which provides
L{IPushProducer}. The protocol's C{connectionLost} method will be
called with:
- L{ResponseDone}, which indicates that all bytes from the response
have been successfully delivered.
- L{PotentialDataLoss}, which indicates that it cannot be determined
if the entire response body has been delivered. This only occurs
when making requests to HTTP servers which do not set
I{Content-Length} or a I{Transfer-Encoding} in the response.
- L{ResponseFailed}, which indicates that some bytes from the response
were lost. The C{reasons} attribute of the exception may provide
more specific indications as to why.
"""
def setPreviousResponse(response):
"""
Set the reference to the previous L{IResponse}.
The value of the previous response can be read via
L{IResponse.previousResponse}.
"""
class _IRequestEncoder(Interface):
"""
An object encoding data passed to L{IRequest.write}, for example for
compression purpose.
@since: 12.3
"""
def encode(data):
"""
Encode the data given and return the result.
@param data: The content to encode.
@type data: L{str}
@return: The encoded data.
@rtype: L{str}
"""
def finish():
"""
Callback called when the request is closing.
@return: If necessary, the pending data accumulated from previous
C{encode} calls.
@rtype: L{str}
"""
class _IRequestEncoderFactory(Interface):
"""
A factory for returing L{_IRequestEncoder} instances.
@since: 12.3
"""
def encoderForRequest(request):
"""
If applicable, returns a L{_IRequestEncoder} instance which will encode
the request.
"""
class IClientRequest(Interface):
"""
An object representing an HTTP request to make to an HTTP server.
@since: 13.1
"""
method = Attribute(
"The HTTP method for this request, as L{bytes}. For example: "
"C{b'GET'}, C{b'HEAD'}, C{b'POST'}, etc."
)
absoluteURI = Attribute(
"The absolute URI of the requested resource, as L{bytes}; or L{None} "
"if the absolute URI cannot be determined."
)
headers = Attribute(
"Headers to be sent to the server, as "
"a L{twisted.web.http_headers.Headers} instance."
)
class IAgent(Interface):
"""
An agent makes HTTP requests.
The way in which requests are issued is left up to each implementation.
Some may issue them directly to the server indicated by the net location
portion of the request URL. Others may use a proxy specified by system
configuration.
Processing of responses is also left very widely specified. An
implementation may perform no special handling of responses, or it may
implement redirect following or content negotiation, it may implement a
cookie store or automatically respond to authentication challenges. It may
implement many other unforeseen behaviors as well.
It is also intended that L{IAgent} implementations be composable. An
implementation which provides cookie handling features should re-use an
implementation that provides connection pooling and this combination could
be used by an implementation which adds content negotiation functionality.
Some implementations will be completely self-contained, such as those which
actually perform the network operations to send and receive requests, but
most or all other implementations should implement a small number of new
features (perhaps one new feature) and delegate the rest of the
request/response machinery to another implementation.
This allows for great flexibility in the behavior an L{IAgent} will
provide. For example, an L{IAgent} with web browser-like behavior could be
obtained by combining a number of (hypothetical) implementations::
baseAgent = Agent(reactor)
decode = ContentDecoderAgent(baseAgent, [(b"gzip", GzipDecoder())])
cookie = CookieAgent(decode, diskStore.cookie)
authenticate = AuthenticateAgent(
cookie, [diskStore.credentials, GtkAuthInterface()])
cache = CacheAgent(authenticate, diskStore.cache)
redirect = BrowserLikeRedirectAgent(cache, limit=10)
doSomeRequests(cache)
"""
def request(
method: bytes,
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
) -> Deferred[IResponse]:
"""
Request the resource at the given location.
@param method: The request method to use, such as C{b"GET"}, C{b"HEAD"},
C{b"PUT"}, C{b"POST"}, etc.
@param uri: The location of the resource to request. This should be an
absolute URI but some implementations may support relative URIs
(with absolute or relative paths). I{HTTP} and I{HTTPS} are the
schemes most likely to be supported but others may be as well.
@param headers: The headers to send with the request (or L{None} to
send no extra headers). An implementation may add its own headers
to this (for example for client identification or content
negotiation).
@param bodyProducer: An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or, L{None} if the request is to have
no body.
@return: A L{Deferred} that fires with an L{IResponse} provider when
the header of the response has been received (regardless of the
response status code) or with a L{Failure} if there is any problem
which prevents that response from being received (including
problems that prevent the request from being sent).
"""
class IPolicyForHTTPS(Interface):
"""
An L{IPolicyForHTTPS} provides a policy for verifying the certificates of
HTTPS connections, in the form of a L{client connection creator
<twisted.internet.interfaces.IOpenSSLClientConnectionCreator>} per network
location.
@since: 14.0
"""
def creatorForNetloc(hostname, port):
"""
Create a L{client connection creator
<twisted.internet.interfaces.IOpenSSLClientConnectionCreator>}
appropriate for the given URL "netloc"; i.e. hostname and port number
pair.
@param hostname: The name of the requested remote host.
@type hostname: L{bytes}
@param port: The number of the requested remote port.
@type port: L{int}
@return: A client connection creator expressing the security
requirements for the given remote host.
@rtype: L{client connection creator
<twisted.internet.interfaces.IOpenSSLClientConnectionCreator>}
"""
class IAgentEndpointFactory(Interface):
"""
An L{IAgentEndpointFactory} provides a way of constructing an endpoint
used for outgoing Agent requests. This is useful in the case of needing to
proxy outgoing connections, or to otherwise vary the transport used.
@since: 15.0
"""
def endpointForURI(uri):
"""
Construct and return an L{IStreamClientEndpoint} for the outgoing
request's connection.
@param uri: The URI of the request.
@type uri: L{twisted.web.client.URI}
@return: An endpoint which will have its C{connect} method called to
issue the request.
@rtype: an L{IStreamClientEndpoint} provider
@raises twisted.internet.error.SchemeNotSupported: If the given
URI's scheme cannot be handled by this factory.
"""
UNKNOWN_LENGTH = "twisted.web.iweb.UNKNOWN_LENGTH"
__all__ = [
"IUsernameDigestHash",
"ICredentialFactory",
"IRequest",
"IBodyProducer",
"IRenderable",
"IResponse",
"_IRequestEncoder",
"_IRequestEncoderFactory",
"IClientRequest",
"UNKNOWN_LENGTH",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
!.gitignore

View File

@@ -0,0 +1,134 @@
# -*- test-case-name: twisted.web.test.test_pages -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Utility implementations of L{IResource}.
"""
__all__ = (
"errorPage",
"notFound",
"forbidden",
)
from typing import cast
from twisted.web import http
from twisted.web.iweb import IRenderable, IRequest
from twisted.web.resource import IResource, Resource
from twisted.web.template import renderElement, tags
class _ErrorPage(Resource):
"""
L{_ErrorPage} is a resource that responds to all requests with a particular
(parameterized) HTTP status code and an HTML body containing some
descriptive text. This is useful for rendering simple error pages.
@see: L{twisted.web.pages.errorPage}
@ivar _code: An integer HTTP status code which will be used for the
response.
@ivar _brief: A short string which will be included in the response body as
the page title.
@ivar _detail: A longer string which will be included in the response body.
"""
def __init__(self, code: int, brief: str, detail: str) -> None:
super().__init__()
self._code: int = code
self._brief: str = brief
self._detail: str = detail
def render(self, request: IRequest) -> object:
"""
Respond to all requests with the given HTTP status code and an HTML
document containing the explanatory strings.
"""
request.setResponseCode(self._code)
request.setHeader(b"content-type", b"text/html; charset=utf-8")
return renderElement(
request,
# cast because the type annotations here seem off; Tag isn't an
# IRenderable but also probably should be? See
# https://github.com/twisted/twisted/issues/4982
cast(
IRenderable,
tags.html(
tags.head(tags.title(f"{self._code} - {self._brief}")),
tags.body(tags.h1(self._brief), tags.p(self._detail)),
),
),
)
def getChild(self, path: bytes, request: IRequest) -> Resource:
"""
Handle all requests for which L{_ErrorPage} lacks a child by returning
this error page.
@param path: A path segment.
@param request: HTTP request
"""
return self
def errorPage(code: int, brief: str, detail: str) -> _ErrorPage:
"""
Build a resource that responds to all requests with a particular HTTP
status code and an HTML body containing some descriptive text. This is
useful for rendering simple error pages.
The resource dynamically handles all paths below it. Use
L{IResource.putChild()} to override a specific path.
@param code: An integer HTTP status code which will be used for the
response.
@param brief: A short string which will be included in the response
body as the page title.
@param detail: A longer string which will be included in the
response body.
@returns: An L{IResource}
"""
return _ErrorPage(code, brief, detail)
def notFound(
brief: str = "No Such Resource",
message: str = "Sorry. No luck finding that resource.",
) -> IResource:
"""
Generate an L{IResource} with a 404 Not Found status code.
@see: L{twisted.web.pages.errorPage}
@param brief: A short string displayed as the page title.
@param brief: A longer string displayed in the page body.
@returns: An L{IResource}
"""
return _ErrorPage(http.NOT_FOUND, brief, message)
def forbidden(
brief: str = "Forbidden Resource", message: str = "Sorry, resource is forbidden."
) -> IResource:
"""
Generate an L{IResource} with a 403 Forbidden status code.
@see: L{twisted.web.pages.errorPage}
@param brief: A short string displayed as the page title.
@param brief: A longer string displayed in the page body.
@returns: An L{IResource}
"""
return _ErrorPage(http.FORBIDDEN, brief, message)

View File

@@ -0,0 +1,296 @@
# -*- test-case-name: twisted.web.test.test_proxy -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Simplistic HTTP proxy support.
This comes in two main variants - the Proxy and the ReverseProxy.
When a Proxy is in use, a browser trying to connect to a server (say,
www.yahoo.com) will be intercepted by the Proxy, and the proxy will covertly
connect to the server, and return the result.
When a ReverseProxy is in use, the client connects directly to the ReverseProxy
(say, www.yahoo.com) which farms off the request to one of a pool of servers,
and returns the result.
Normally, a Proxy is used on the client end of an Internet connection, while a
ReverseProxy is used on the server end.
"""
from urllib.parse import quote as urlquote, urlparse, urlunparse
from twisted.internet import reactor
from twisted.internet.protocol import ClientFactory
from twisted.web.http import _QUEUED_SENTINEL, HTTPChannel, HTTPClient, Request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
class ProxyClient(HTTPClient):
"""
Used by ProxyClientFactory to implement a simple web proxy.
@ivar _finished: A flag which indicates whether or not the original request
has been finished yet.
"""
_finished = False
def __init__(self, command, rest, version, headers, data, father):
self.father = father
self.command = command
self.rest = rest
if b"proxy-connection" in headers:
del headers[b"proxy-connection"]
headers[b"connection"] = b"close"
headers.pop(b"keep-alive", None)
self.headers = headers
self.data = data
def connectionMade(self):
self.sendCommand(self.command, self.rest)
for header, value in self.headers.items():
self.sendHeader(header, value)
self.endHeaders()
self.transport.write(self.data)
def handleStatus(self, version, code, message):
self.father.setResponseCode(int(code), message)
def handleHeader(self, key, value):
# t.web.server.Request sets default values for these headers in its
# 'process' method. When these headers are received from the remote
# server, they ought to override the defaults, rather than append to
# them.
if key.lower() in [b"server", b"date", b"content-type"]:
self.father.responseHeaders.setRawHeaders(key, [value])
else:
self.father.responseHeaders.addRawHeader(key, value)
def handleResponsePart(self, buffer):
self.father.write(buffer)
def handleResponseEnd(self):
"""
Finish the original request, indicating that the response has been
completely written to it, and disconnect the outgoing transport.
"""
if not self._finished:
self._finished = True
self.father.finish()
self.transport.loseConnection()
class ProxyClientFactory(ClientFactory):
"""
Used by ProxyRequest to implement a simple web proxy.
"""
# Type is wrong. See: https://twistedmatrix.com/trac/ticket/10006
protocol = ProxyClient # type: ignore[assignment]
def __init__(self, command, rest, version, headers, data, father):
self.father = father
self.command = command
self.rest = rest
self.headers = headers
self.data = data
self.version = version
def buildProtocol(self, addr):
return self.protocol(
self.command, self.rest, self.version, self.headers, self.data, self.father
)
def clientConnectionFailed(self, connector, reason):
"""
Report a connection failure in a response to the incoming request as
an error.
"""
self.father.setResponseCode(501, b"Gateway error")
self.father.responseHeaders.addRawHeader(b"Content-Type", b"text/html")
self.father.write(b"<H1>Could not connect</H1>")
self.father.finish()
class ProxyRequest(Request):
"""
Used by Proxy to implement a simple web proxy.
@ivar reactor: the reactor used to create connections.
@type reactor: object providing L{twisted.internet.interfaces.IReactorTCP}
"""
protocols = {b"http": ProxyClientFactory}
ports = {b"http": 80}
def __init__(self, channel, queued=_QUEUED_SENTINEL, reactor=reactor):
Request.__init__(self, channel, queued)
self.reactor = reactor
def process(self):
parsed = urlparse(self.uri)
protocol = parsed[0]
host = parsed[1].decode("ascii")
port = self.ports[protocol]
if ":" in host:
host, port = host.split(":")
port = int(port)
rest = urlunparse((b"", b"") + parsed[2:])
if not rest:
rest = rest + b"/"
class_ = self.protocols[protocol]
headers = self.getAllHeaders().copy()
if b"host" not in headers:
headers[b"host"] = host.encode("ascii")
self.content.seek(0, 0)
s = self.content.read()
clientFactory = class_(self.method, rest, self.clientproto, headers, s, self)
self.reactor.connectTCP(host, port, clientFactory)
class Proxy(HTTPChannel):
"""
This class implements a simple web proxy.
Since it inherits from L{twisted.web.http.HTTPChannel}, to use it you
should do something like this::
from twisted.web import http
f = http.HTTPFactory()
f.protocol = Proxy
Make the HTTPFactory a listener on a port as per usual, and you have
a fully-functioning web proxy!
"""
requestFactory = ProxyRequest
class ReverseProxyRequest(Request):
"""
Used by ReverseProxy to implement a simple reverse proxy.
@ivar proxyClientFactoryClass: a proxy client factory class, used to create
new connections.
@type proxyClientFactoryClass: L{ClientFactory}
@ivar reactor: the reactor used to create connections.
@type reactor: object providing L{twisted.internet.interfaces.IReactorTCP}
"""
proxyClientFactoryClass = ProxyClientFactory
def __init__(self, channel, queued=_QUEUED_SENTINEL, reactor=reactor):
Request.__init__(self, channel, queued)
self.reactor = reactor
def process(self):
"""
Handle this request by connecting to the proxied server and forwarding
it there, then forwarding the response back as the response to this
request.
"""
self.requestHeaders.setRawHeaders(b"host", [self.factory.host.encode("ascii")])
clientFactory = self.proxyClientFactoryClass(
self.method,
self.uri,
self.clientproto,
self.getAllHeaders(),
self.content.read(),
self,
)
self.reactor.connectTCP(self.factory.host, self.factory.port, clientFactory)
class ReverseProxy(HTTPChannel):
"""
Implements a simple reverse proxy.
For details of usage, see the file examples/reverse-proxy.py.
"""
requestFactory = ReverseProxyRequest
class ReverseProxyResource(Resource):
"""
Resource that renders the results gotten from another server
Put this resource in the tree to cause everything below it to be relayed
to a different server.
@ivar proxyClientFactoryClass: a proxy client factory class, used to create
new connections.
@type proxyClientFactoryClass: L{ClientFactory}
@ivar reactor: the reactor used to create connections.
@type reactor: object providing L{twisted.internet.interfaces.IReactorTCP}
"""
proxyClientFactoryClass = ProxyClientFactory
def __init__(self, host, port, path, reactor=reactor):
"""
@param host: the host of the web server to proxy.
@type host: C{str}
@param port: the port of the web server to proxy.
@type port: C{port}
@param path: the base path to fetch data from. Note that you shouldn't
put any trailing slashes in it, it will be added automatically in
request. For example, if you put B{/foo}, a request on B{/bar} will
be proxied to B{/foo/bar}. Any required encoding of special
characters (such as " " or "/") should have been done already.
@type path: C{bytes}
"""
Resource.__init__(self)
self.host = host
self.port = port
self.path = path
self.reactor = reactor
def getChild(self, path, request):
"""
Create and return a proxy resource with the same proxy configuration
as this one, except that its path also contains the segment given by
C{path} at the end.
"""
return ReverseProxyResource(
self.host,
self.port,
self.path + b"/" + urlquote(path, safe=b"").encode("utf-8"),
self.reactor,
)
def render(self, request):
"""
Render a request by forwarding it to the proxied server.
"""
# RFC 2616 tells us that we can omit the port if it's the default port,
# but we have to provide it otherwise
if self.port == 80:
host = self.host
else:
host = "%s:%d" % (self.host, self.port)
request.requestHeaders.setRawHeaders(b"host", [host.encode("ascii")])
request.content.seek(0, 0)
qs = urlparse(request.uri)[4]
if qs:
rest = self.path + b"?" + qs
else:
rest = self.path
clientFactory = self.proxyClientFactoryClass(
request.method,
rest,
request.clientproto,
request.getAllHeaders(),
request.content.read(),
request,
)
self.reactor.connectTCP(self.host, self.port, clientFactory)
return NOT_DONE_YET

View File

@@ -0,0 +1,460 @@
# -*- test-case-name: twisted.web.test.test_web, twisted.web.test.test_resource -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of the lowest-level Resource class.
See L{twisted.web.pages} for some utility implementations.
"""
from __future__ import annotations
__all__ = [
"IResource",
"getChildForRequest",
"Resource",
"ErrorPage",
"NoResource",
"ForbiddenResource",
"EncodingResourceWrapper",
]
import warnings
from typing import Sequence
from zope.interface import Attribute, Interface, implementer
from incremental import Version
from twisted.python.compat import nativeString
from twisted.python.components import proxyForInterface
from twisted.python.deprecate import deprecated
from twisted.python.reflect import prefixedMethodNames
from twisted.web._responses import FORBIDDEN, NOT_FOUND
from twisted.web.error import UnsupportedMethod
class IResource(Interface):
"""
A web resource.
"""
isLeaf = Attribute(
"""
Signal if this IResource implementor is a "leaf node" or not. If True,
getChildWithDefault will not be called on this Resource.
"""
)
def getChildWithDefault(name, request):
"""
Return a child with the given name for the given request.
This is the external interface used by the Resource publishing
machinery. If implementing IResource without subclassing
Resource, it must be provided. However, if subclassing Resource,
getChild overridden instead.
@param name: A single path component from a requested URL. For example,
a request for I{http://example.com/foo/bar} will result in calls to
this method with C{b"foo"} and C{b"bar"} as values for this
argument.
@type name: C{bytes}
@param request: A representation of all of the information about the
request that is being made for this child.
@type request: L{twisted.web.server.Request}
"""
def putChild(path: bytes, child: "IResource") -> None:
"""
Put a child L{IResource} implementor at the given path.
@param path: A single path component, to be interpreted relative to the
path this resource is found at, at which to put the given child.
For example, if resource A can be found at I{http://example.com/foo}
then a call like C{A.putChild(b"bar", B)} will make resource B
available at I{http://example.com/foo/bar}.
The path component is I{not} URL-encoded -- pass C{b'foo bar'}
rather than C{b'foo%20bar'}.
"""
def render(request):
"""
Render a request. This is called on the leaf resource for a request.
@return: Either C{server.NOT_DONE_YET} to indicate an asynchronous or a
C{bytes} instance to write as the response to the request. If
C{NOT_DONE_YET} is returned, at some point later (for example, in a
Deferred callback) call C{request.write(b"<html>")} to write data to
the request, and C{request.finish()} to send the data to the
browser.
@raise twisted.web.error.UnsupportedMethod: If the HTTP verb
requested is not supported by this resource.
"""
def getChildForRequest(resource, request):
"""
Traverse resource tree to find who will handle the request.
"""
while request.postpath and not resource.isLeaf:
pathElement = request.postpath.pop(0)
request.prepath.append(pathElement)
resource = resource.getChildWithDefault(pathElement, request)
return resource
@implementer(IResource)
class Resource:
"""
Define a web-accessible resource.
This serves two main purposes: one is to provide a standard representation
for what HTTP specification calls an 'entity', and the other is to provide
an abstract directory structure for URL retrieval.
"""
entityType = IResource
allowedMethods: Sequence[bytes]
server = None
def __init__(self):
"""
Initialize.
"""
self.children = {}
isLeaf = 0
### Abstract Collection Interface
def listStaticNames(self):
return list(self.children.keys())
def listStaticEntities(self):
return list(self.children.items())
def listNames(self):
return list(self.listStaticNames()) + self.listDynamicNames()
def listEntities(self):
return list(self.listStaticEntities()) + self.listDynamicEntities()
def listDynamicNames(self):
return []
def listDynamicEntities(self, request=None):
return []
def getStaticEntity(self, name):
return self.children.get(name)
def getDynamicEntity(self, name, request):
if name not in self.children:
return self.getChild(name, request)
else:
return None
def delEntity(self, name):
del self.children[name]
def reallyPutEntity(self, name, entity):
self.children[name] = entity
# Concrete HTTP interface
def getChild(self, path, request):
"""
Retrieve a 'child' resource from me.
Implement this to create dynamic resource generation -- resources which
are always available may be registered with self.putChild().
This will not be called if the class-level variable 'isLeaf' is set in
your subclass; instead, the 'postpath' attribute of the request will be
left as a list of the remaining path elements.
For example, the URL /foo/bar/baz will normally be::
| site.resource.getChild('foo').getChild('bar').getChild('baz').
However, if the resource returned by 'bar' has isLeaf set to true, then
the getChild call will never be made on it.
Parameters and return value have the same meaning and requirements as
those defined by L{IResource.getChildWithDefault}.
"""
return _UnsafeNoResource()
def getChildWithDefault(self, path, request):
"""
Retrieve a static or dynamically generated child resource from me.
First checks if a resource was added manually by putChild, and then
call getChild to check for dynamic resources. Only override if you want
to affect behaviour of all child lookups, rather than just dynamic
ones.
This will check to see if I have a pre-registered child resource of the
given name, and call getChild if I do not.
@see: L{IResource.getChildWithDefault}
"""
if path in self.children:
return self.children[path]
return self.getChild(path, request)
def getChildForRequest(self, request):
"""
Deprecated in favor of L{getChildForRequest}.
@see: L{twisted.web.resource.getChildForRequest}.
"""
warnings.warn(
"Please use module level getChildForRequest.", DeprecationWarning, 2
)
return getChildForRequest(self, request)
def putChild(self, path: bytes, child: IResource) -> None:
"""
Register a static child.
You almost certainly don't want '/' in your path. If you
intended to have the root of a folder, e.g. /foo/, you want
path to be ''.
@param path: A single path component.
@param child: The child resource to register.
@see: L{IResource.putChild}
"""
if not isinstance(path, bytes):
raise TypeError(f"Path segment must be bytes, but {path!r} is {type(path)}")
self.children[path] = child
# IResource is incomplete and doesn't mention this server attribute, see
# https://github.com/twisted/twisted/issues/11717
child.server = self.server # type: ignore[attr-defined]
def render(self, request):
"""
Render a given resource. See L{IResource}'s render method.
I delegate to methods of self with the form 'render_METHOD'
where METHOD is the HTTP that was used to make the
request. Examples: render_GET, render_HEAD, render_POST, and
so on. Generally you should implement those methods instead of
overriding this one.
render_METHOD methods are expected to return a byte string which will be
the rendered page, unless the return value is C{server.NOT_DONE_YET}, in
which case it is this class's responsibility to write the results using
C{request.write(data)} and then call C{request.finish()}.
Old code that overrides render() directly is likewise expected
to return a byte string or NOT_DONE_YET.
@see: L{IResource.render}
"""
m = getattr(self, "render_" + nativeString(request.method), None)
if not m:
try:
allowedMethods = self.allowedMethods
except AttributeError:
allowedMethods = _computeAllowedMethods(self)
raise UnsupportedMethod(allowedMethods)
return m(request)
def render_HEAD(self, request):
"""
Default handling of HEAD method.
I just return self.render_GET(request). When method is HEAD,
the framework will handle this correctly.
"""
return self.render_GET(request)
def _computeAllowedMethods(resource):
"""
Compute the allowed methods on a C{Resource} based on defined render_FOO
methods. Used when raising C{UnsupportedMethod} but C{Resource} does
not define C{allowedMethods} attribute.
"""
allowedMethods = []
for name in prefixedMethodNames(resource.__class__, "render_"):
# Potentially there should be an API for encode('ascii') in this
# situation - an API for taking a Python native string (bytes on Python
# 2, text on Python 3) and returning a socket-compatible string type.
allowedMethods.append(name.encode("ascii"))
return allowedMethods
class _UnsafeErrorPageBase(Resource):
"""
Base class for deprecated error page resources.
@ivar template: A native string which will have a dictionary interpolated
into it to generate the response body. The dictionary has the following
keys:
- C{"code"}: The status code passed to L{_UnsafeErrorPage.__init__}.
- C{"brief"}: The brief description passed to
L{_UnsafeErrorPage.__init__}.
- C{"detail"}: The detailed description passed to
L{_UnsafeErrorPage.__init__}.
@ivar code: An integer status code which will be used for the response.
@type code: C{int}
@ivar brief: A short string which will be included in the response body as
the page title.
@type brief: C{str}
@ivar detail: A longer string which will be included in the response body.
@type detail: C{str}
"""
template = """
<html>
<head><title>%(code)s - %(brief)s</title></head>
<body>
<h1>%(brief)s</h1>
<p>%(detail)s</p>
</body>
</html>
"""
def __init__(self, status, brief, detail):
Resource.__init__(self)
self.code = status
self.brief = brief
self.detail = detail
def render(self, request):
request.setResponseCode(self.code)
request.setHeader(b"content-type", b"text/html; charset=utf-8")
interpolated = self.template % dict(
code=self.code, brief=self.brief, detail=self.detail
)
if isinstance(interpolated, str):
return interpolated.encode("utf-8")
return interpolated
def getChild(self, chnam, request):
return self
class _UnsafeErrorPage(_UnsafeErrorPageBase):
"""
L{_UnsafeErrorPage}, publicly available via the deprecated alias
C{ErrorPage}, is a resource which responds with a particular
(parameterized) status and a body consisting of HTML containing some
descriptive text. This is useful for rendering simple error pages.
Deprecated in Twisted 22.10.0 because it permits HTML injection; use
L{twisted.web.pages.errorPage} instead.
"""
@deprecated(
Version("Twisted", 22, 10, 0),
"Use twisted.web.pages.errorPage instead, which properly escapes HTML.",
)
def __init__(self, status, brief, detail):
_UnsafeErrorPageBase.__init__(self, status, brief, detail)
class _UnsafeNoResource(_UnsafeErrorPageBase):
"""
L{_UnsafeNoResource}, publicly available via the deprecated alias
C{NoResource}, is a specialization of L{_UnsafeErrorPage} which
returns the HTTP response code I{NOT FOUND}.
Deprecated in Twisted 22.10.0 because it permits HTML injection; use
L{twisted.web.pages.notFound} instead.
"""
@deprecated(
Version("Twisted", 22, 10, 0),
"Use twisted.web.pages.notFound instead, which properly escapes HTML.",
)
def __init__(self, message="Sorry. No luck finding that resource."):
_UnsafeErrorPageBase.__init__(self, NOT_FOUND, "No Such Resource", message)
class _UnsafeForbiddenResource(_UnsafeErrorPageBase):
"""
L{_UnsafeForbiddenResource}, publicly available via the deprecated alias
C{ForbiddenResource} is a specialization of L{_UnsafeErrorPage} which
returns the I{FORBIDDEN} HTTP response code.
Deprecated in Twisted 22.10.0 because it permits HTML injection; use
L{twisted.web.pages.forbidden} instead.
"""
@deprecated(
Version("Twisted", 22, 10, 0),
"Use twisted.web.pages.forbidden instead, which properly escapes HTML.",
)
def __init__(self, message="Sorry, resource is forbidden."):
_UnsafeErrorPageBase.__init__(self, FORBIDDEN, "Forbidden Resource", message)
# Deliberately undocumented public aliases. See GHSA-vg46-2rrj-3647.
ErrorPage = _UnsafeErrorPage
NoResource = _UnsafeNoResource
ForbiddenResource = _UnsafeForbiddenResource
class _IEncodingResource(Interface):
"""
A resource which knows about L{_IRequestEncoderFactory}.
@since: 12.3
"""
def getEncoder(request):
"""
Parse the request and return an encoder if applicable, using
L{_IRequestEncoderFactory.encoderForRequest}.
@return: A L{_IRequestEncoder}, or L{None}.
"""
@implementer(_IEncodingResource)
class EncodingResourceWrapper(proxyForInterface(IResource)): # type: ignore[misc]
"""
Wrap a L{IResource}, potentially applying an encoding to the response body
generated.
Note that the returned children resources won't be wrapped, so you have to
explicitly wrap them if you want the encoding to be applied.
@ivar encoders: A list of
L{_IRequestEncoderFactory<twisted.web.iweb._IRequestEncoderFactory>}
returning L{_IRequestEncoder<twisted.web.iweb._IRequestEncoder>} that
may transform the data passed to C{Request.write}. The list must be
sorted in order of priority: the first encoder factory handling the
request will prevent the others from doing the same.
@type encoders: C{list}.
@since: 12.3
"""
def __init__(self, original, encoders):
super().__init__(original)
self._encoders = encoders
def getEncoder(self, request):
"""
Browser the list of encoders looking for one applicable encoder.
"""
for encoderFactory in self._encoders:
encoder = encoderFactory.encoderForRequest(request)
if encoder is not None:
return encoder

View File

@@ -0,0 +1,55 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
from twisted.web import resource
class RewriterResource(resource.Resource):
def __init__(self, orig, *rewriteRules):
resource.Resource.__init__(self)
self.resource = orig
self.rewriteRules = list(rewriteRules)
def _rewrite(self, request):
for rewriteRule in self.rewriteRules:
rewriteRule(request)
def getChild(self, path, request):
request.postpath.insert(0, path)
request.prepath.pop()
self._rewrite(request)
path = request.postpath.pop(0)
request.prepath.append(path)
return self.resource.getChildWithDefault(path, request)
def render(self, request):
self._rewrite(request)
return self.resource.render(request)
def tildeToUsers(request):
if request.postpath and request.postpath[0][:1] == "~":
request.postpath[:1] = ["users", request.postpath[0][1:]]
request.path = "/" + "/".join(request.prepath + request.postpath)
def alias(aliasPath, sourcePath):
"""
I am not a very good aliaser. But I'm the best I can be. If I'm
aliasing to a Resource that generates links, and it uses any parts
of request.prepath to do so, the links will not be relative to the
aliased path, but rather to the aliased-to path. That I can't
alias static.File directory listings that nicely. However, I can
still be useful, as many resources will play nice.
"""
sourcePath = sourcePath.split("/")
aliasPath = aliasPath.split("/")
def rewriter(request):
if request.postpath[: len(aliasPath)] == aliasPath:
after = request.postpath[len(aliasPath) :]
request.postpath = sourcePath + after
request.path = "/" + "/".join(request.prepath + request.postpath)
return rewriter

View File

@@ -0,0 +1,193 @@
# -*- test-case-name: twisted.web.test.test_script -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
I contain PythonScript, which is a very simple python script resource.
"""
import os
import traceback
from io import StringIO
from twisted import copyright
from twisted.python.compat import execfile, networkString
from twisted.python.filepath import _coerceToFilesystemEncoding
from twisted.web import http, resource, server, static, util
rpyNoResource = """<p>You forgot to assign to the variable "resource" in your script. For example:</p>
<pre>
# MyCoolWebApp.rpy
import mygreatresource
resource = mygreatresource.MyGreatResource()
</pre>
"""
class AlreadyCached(Exception):
"""
This exception is raised when a path has already been cached.
"""
class CacheScanner:
def __init__(self, path, registry):
self.path = path
self.registry = registry
self.doCache = 0
def cache(self):
c = self.registry.getCachedPath(self.path)
if c is not None:
raise AlreadyCached(c)
self.recache()
def recache(self):
self.doCache = 1
noRsrc = resource._UnsafeErrorPage(500, "Whoops! Internal Error", rpyNoResource)
def ResourceScript(path, registry):
"""
I am a normal py file which must define a 'resource' global, which should
be an instance of (a subclass of) web.resource.Resource; it will be
renderred.
"""
cs = CacheScanner(path, registry)
glob = {
"__file__": _coerceToFilesystemEncoding("", path),
"resource": noRsrc,
"registry": registry,
"cache": cs.cache,
"recache": cs.recache,
}
try:
execfile(path, glob, glob)
except AlreadyCached as ac:
return ac.args[0]
rsrc = glob["resource"]
if cs.doCache and rsrc is not noRsrc:
registry.cachePath(path, rsrc)
return rsrc
def ResourceTemplate(path, registry):
from quixote import ptl_compile
glob = {
"__file__": _coerceToFilesystemEncoding("", path),
"resource": resource._UnsafeErrorPage(
500, "Whoops! Internal Error", rpyNoResource
),
"registry": registry,
}
with open(path) as f: # Not closed by quixote as of 2.9.1
e = ptl_compile.compile_template(f, path)
code = compile(e, "<source>", "exec")
eval(code, glob, glob)
return glob["resource"]
class ResourceScriptWrapper(resource.Resource):
def __init__(self, path, registry=None):
resource.Resource.__init__(self)
self.path = path
self.registry = registry or static.Registry()
def render(self, request):
res = ResourceScript(self.path, self.registry)
return res.render(request)
def getChildWithDefault(self, path, request):
res = ResourceScript(self.path, self.registry)
return res.getChildWithDefault(path, request)
class ResourceScriptDirectory(resource.Resource):
"""
L{ResourceScriptDirectory} is a resource which serves scripts from a
filesystem directory. File children of a L{ResourceScriptDirectory} will
be served using L{ResourceScript}. Directory children will be served using
another L{ResourceScriptDirectory}.
@ivar path: A C{str} giving the filesystem path in which children will be
looked up.
@ivar registry: A L{static.Registry} instance which will be used to decide
how to interpret scripts found as children of this resource.
"""
def __init__(self, pathname, registry=None):
resource.Resource.__init__(self)
self.path = pathname
self.registry = registry or static.Registry()
def getChild(self, path, request):
fn = os.path.join(self.path, path)
if os.path.isdir(fn):
return ResourceScriptDirectory(fn, self.registry)
if os.path.exists(fn):
return ResourceScript(fn, self.registry)
return resource._UnsafeNoResource()
def render(self, request):
return resource._UnsafeNoResource().render(request)
class PythonScript(resource.Resource):
"""
I am an extremely simple dynamic resource; an embedded python script.
This will execute a file (usually of the extension '.epy') as Python code,
internal to the webserver.
"""
isLeaf = True
def __init__(self, filename, registry):
"""
Initialize me with a script name.
"""
self.filename = filename
self.registry = registry
def render(self, request):
"""
Render me to a web client.
Load my file, execute it in a special namespace (with 'request' and
'__file__' global vars) and finish the request. Output to the web-page
will NOT be handled with print - standard output goes to the log - but
with request.write.
"""
request.setHeader(
b"x-powered-by", networkString("Twisted/%s" % copyright.version)
)
namespace = {
"request": request,
"__file__": _coerceToFilesystemEncoding("", self.filename),
"registry": self.registry,
}
try:
execfile(self.filename, namespace, namespace)
except OSError as e:
if e.errno == 2: # file not found
request.setResponseCode(http.NOT_FOUND)
request.write(
resource._UnsafeNoResource("File not found.").render(request)
)
except BaseException:
io = StringIO()
traceback.print_exc(file=io)
output = util._PRE(io.getvalue())
output = output.encode("utf8")
request.write(output)
request.finish()
return server.NOT_DONE_YET

View File

@@ -0,0 +1,891 @@
# -*- test-case-name: twisted.web.test.test_web -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This is a web server which integrates with the twisted.internet infrastructure.
@var NOT_DONE_YET: A token value which L{twisted.web.resource.IResource.render}
implementations can return to indicate that the application will later call
C{.write} and C{.finish} to complete the request, and that the HTTP
connection should be left open.
@type NOT_DONE_YET: Opaque; do not depend on any particular type for this
value.
"""
import copy
import os
import re
import zlib
from binascii import hexlify
from html import escape
from typing import List, Optional
from urllib.parse import quote as _quote
from zope.interface import implementer
from twisted import copyright
from twisted.internet import address, interfaces
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from twisted.logger import Logger
from twisted.python import components, failure, reflect
from twisted.python.compat import nativeString, networkString
from twisted.spread.pb import Copyable, ViewPoint
from twisted.web import http, iweb, resource, util
from twisted.web.error import UnsupportedMethod
from twisted.web.http import (
NO_CONTENT,
NOT_MODIFIED,
HTTPFactory,
Request as _HTTPRequest,
datetimeToString,
unquote,
)
NOT_DONE_YET = 1
__all__ = [
"supportedMethods",
"Request",
"Session",
"Site",
"version",
"NOT_DONE_YET",
"GzipEncoderFactory",
]
# Support for other methods may be implemented on a per-resource basis.
supportedMethods = (b"GET", b"HEAD", b"POST")
def quote(string, *args, **kwargs):
return _quote(string.decode("charmap"), *args, **kwargs).encode("charmap")
def _addressToTuple(addr):
if isinstance(addr, address.IPv4Address):
return ("INET", addr.host, addr.port)
elif isinstance(addr, address.UNIXAddress):
return ("UNIX", addr.name)
else:
return tuple(addr)
@implementer(iweb.IRequest)
class Request(Copyable, http.Request, components.Componentized):
"""
An HTTP request.
@ivar defaultContentType: A L{bytes} giving the default I{Content-Type}
value to send in responses if no other value is set. L{None} disables
the default.
@ivar _insecureSession: The L{Session} object representing state that will
be transmitted over plain-text HTTP.
@ivar _secureSession: The L{Session} object representing the state that
will be transmitted only over HTTPS.
"""
defaultContentType: Optional[bytes] = b"text/html"
site = None
appRootURL = None
prepath: Optional[List[bytes]] = None
postpath: Optional[List[bytes]] = None
__pychecker__ = "unusednames=issuer"
_inFakeHead = False
_encoder = None
_log = Logger()
def __init__(self, *args, **kw):
_HTTPRequest.__init__(self, *args, **kw)
components.Componentized.__init__(self)
def getStateToCopyFor(self, issuer):
x = self.__dict__.copy()
del x["transport"]
# XXX refactor this attribute out; it's from protocol
# del x['server']
del x["channel"]
del x["content"]
del x["site"]
self.content.seek(0, 0)
x["content_data"] = self.content.read()
x["remote"] = ViewPoint(issuer, self)
# Address objects aren't jellyable
x["host"] = _addressToTuple(x["host"])
x["client"] = _addressToTuple(x["client"])
# Header objects also aren't jellyable.
x["requestHeaders"] = list(x["requestHeaders"].getAllRawHeaders())
return x
# HTML generation helpers
def sibLink(self, name):
"""
Return the text that links to a sibling of the requested resource.
@param name: The sibling resource
@type name: C{bytes}
@return: A relative URL.
@rtype: C{bytes}
"""
if self.postpath:
return (len(self.postpath) * b"../") + name
else:
return name
def childLink(self, name):
"""
Return the text that links to a child of the requested resource.
@param name: The child resource
@type name: C{bytes}
@return: A relative URL.
@rtype: C{bytes}
"""
lpp = len(self.postpath)
if lpp > 1:
return ((lpp - 1) * b"../") + name
elif lpp == 1:
return name
else: # lpp == 0
if len(self.prepath) and self.prepath[-1]:
return self.prepath[-1] + b"/" + name
else:
return name
def gotLength(self, length):
"""
Called when HTTP channel got length of content in this request.
This method is not intended for users.
@param length: The length of the request body, as indicated by the
request headers. L{None} if the request headers do not indicate a
length.
"""
try:
getContentFile = self.channel.site.getContentFile
except AttributeError:
_HTTPRequest.gotLength(self, length)
else:
self.content = getContentFile(length)
def process(self):
"""
Process a request.
Find the addressed resource in this request's L{Site},
and call L{self.render()<Request.render()>} with it.
@see: L{Site.getResourceFor()}
"""
# get site from channel
self.site = self.channel.site
# set various default headers
self.setHeader(b"Server", version)
self.setHeader(b"Date", datetimeToString())
# Resource Identification
self.prepath = []
self.postpath = list(map(unquote, self.path[1:].split(b"/")))
# Short-circuit for requests whose path is '*'.
if self.path == b"*":
self._handleStar()
return
try:
resrc = self.site.getResourceFor(self)
if resource._IEncodingResource.providedBy(resrc):
encoder = resrc.getEncoder(self)
if encoder is not None:
self._encoder = encoder
self.render(resrc)
except BaseException:
self.processingFailed(failure.Failure())
def write(self, data):
"""
Write data to the transport (if not responding to a HEAD request).
@param data: A string to write to the response.
@type data: L{bytes}
"""
if not self.startedWriting:
# Before doing the first write, check to see if a default
# Content-Type header should be supplied. We omit it on
# NOT_MODIFIED and NO_CONTENT responses. We also omit it if there
# is a Content-Length header set to 0, as empty bodies don't need
# a content-type.
needsCT = self.code not in (NOT_MODIFIED, NO_CONTENT)
contentType = self.responseHeaders.getRawHeaders(b"Content-Type")
contentLength = self.responseHeaders.getRawHeaders(b"Content-Length")
contentLengthZero = contentLength and (contentLength[0] == b"0")
if (
needsCT
and contentType is None
and self.defaultContentType is not None
and not contentLengthZero
):
self.responseHeaders.setRawHeaders(
b"Content-Type", [self.defaultContentType]
)
# Only let the write happen if we're not generating a HEAD response by
# faking out the request method. Note, if we are doing that,
# startedWriting will never be true, and the above logic may run
# multiple times. It will only actually change the responseHeaders
# once though, so it's still okay.
if not self._inFakeHead:
if self._encoder:
data = self._encoder.encode(data)
_HTTPRequest.write(self, data)
def finish(self):
"""
Override L{twisted.web.http.Request.finish} for possible encoding.
"""
if self._encoder:
data = self._encoder.finish()
if data:
_HTTPRequest.write(self, data)
return _HTTPRequest.finish(self)
def render(self, resrc):
"""
Ask a resource to render itself.
If the resource does not support the requested method,
generate a C{NOT IMPLEMENTED} or C{NOT ALLOWED} response.
@param resrc: The resource to render.
@type resrc: L{twisted.web.resource.IResource}
@see: L{IResource.render()<twisted.web.resource.IResource.render()>}
"""
try:
body = resrc.render(self)
except UnsupportedMethod as e:
allowedMethods = e.allowedMethods
if (self.method == b"HEAD") and (b"GET" in allowedMethods):
# We must support HEAD (RFC 2616, 5.1.1). If the
# resource doesn't, fake it by giving the resource
# a 'GET' request and then return only the headers,
# not the body.
self._log.info(
"Using GET to fake a HEAD request for {resrc}", resrc=resrc
)
self.method = b"GET"
self._inFakeHead = True
body = resrc.render(self)
if body is NOT_DONE_YET:
self._log.info(
"Tried to fake a HEAD request for {resrc}, but "
"it got away from me.",
resrc=resrc,
)
# Oh well, I guess we won't include the content length.
else:
self.setHeader(b"Content-Length", b"%d" % (len(body),))
self._inFakeHead = False
self.method = b"HEAD"
self.write(b"")
self.finish()
return
if self.method in (supportedMethods):
# We MUST include an Allow header
# (RFC 2616, 10.4.6 and 14.7)
self.setHeader(b"Allow", b", ".join(allowedMethods))
s = (
"""Your browser approached me (at %(URI)s) with"""
""" the method "%(method)s". I only allow"""
""" the method%(plural)s %(allowed)s here."""
% {
"URI": escape(nativeString(self.uri)),
"method": nativeString(self.method),
"plural": ((len(allowedMethods) > 1) and "s") or "",
"allowed": ", ".join([nativeString(x) for x in allowedMethods]),
}
)
epage = resource._UnsafeErrorPage(
http.NOT_ALLOWED, "Method Not Allowed", s
)
body = epage.render(self)
else:
epage = resource._UnsafeErrorPage(
http.NOT_IMPLEMENTED,
"Huh?",
"I don't know how to treat a %s request."
% (escape(self.method.decode("charmap")),),
)
body = epage.render(self)
# end except UnsupportedMethod
if body is NOT_DONE_YET:
return
if not isinstance(body, bytes):
body = resource._UnsafeErrorPage(
http.INTERNAL_SERVER_ERROR,
"Request did not return bytes",
"Request: "
# GHSA-vg46-2rrj-3647 note: _PRE does HTML-escape the input.
+ util._PRE(reflect.safe_repr(self))
+ "<br />"
+ "Resource: "
+ util._PRE(reflect.safe_repr(resrc))
+ "<br />"
+ "Value: "
+ util._PRE(reflect.safe_repr(body)),
).render(self)
if self.method == b"HEAD":
if len(body) > 0:
# This is a Bad Thing (RFC 2616, 9.4)
self._log.info(
"Warning: HEAD request {slf} for resource {resrc} is"
" returning a message body. I think I'll eat it.",
slf=self,
resrc=resrc,
)
self.setHeader(b"Content-Length", b"%d" % (len(body),))
self.write(b"")
else:
self.setHeader(b"Content-Length", b"%d" % (len(body),))
self.write(body)
self.finish()
def processingFailed(self, reason):
"""
Finish this request with an indication that processing failed and
possibly display a traceback.
@param reason: Reason this request has failed.
@type reason: L{twisted.python.failure.Failure}
@return: The reason passed to this method.
@rtype: L{twisted.python.failure.Failure}
"""
self._log.failure("", failure=reason)
if self.site.displayTracebacks:
body = (
b"<html><head><title>web.Server Traceback"
b" (most recent call last)</title></head>"
b"<body><b>web.Server Traceback"
b" (most recent call last):</b>\n\n"
+ util.formatFailure(reason)
+ b"\n\n</body></html>\n"
)
else:
body = (
b"<html><head><title>Processing Failed"
b"</title></head><body>"
b"<b>Processing Failed</b></body></html>"
)
self.setResponseCode(http.INTERNAL_SERVER_ERROR)
self.setHeader(b"Content-Type", b"text/html")
self.setHeader(b"Content-Length", b"%d" % (len(body),))
self.write(body)
self.finish()
return reason
def view_write(self, issuer, data):
"""Remote version of write; same interface."""
self.write(data)
def view_finish(self, issuer):
"""Remote version of finish; same interface."""
self.finish()
def view_addCookie(self, issuer, k, v, **kwargs):
"""Remote version of addCookie; same interface."""
self.addCookie(k, v, **kwargs)
def view_setHeader(self, issuer, k, v):
"""Remote version of setHeader; same interface."""
self.setHeader(k, v)
def view_setLastModified(self, issuer, when):
"""Remote version of setLastModified; same interface."""
self.setLastModified(when)
def view_setETag(self, issuer, tag):
"""Remote version of setETag; same interface."""
self.setETag(tag)
def view_setResponseCode(self, issuer, code, message=None):
"""
Remote version of setResponseCode; same interface.
"""
self.setResponseCode(code, message)
def view_registerProducer(self, issuer, producer, streaming):
"""Remote version of registerProducer; same interface.
(requires a remote producer.)
"""
self.registerProducer(_RemoteProducerWrapper(producer), streaming)
def view_unregisterProducer(self, issuer):
self.unregisterProducer()
### these calls remain local
_secureSession = None
_insecureSession = None
@property
def session(self):
"""
If a session has already been created or looked up with
L{Request.getSession}, this will return that object. (This will always
be the session that matches the security of the request; so if
C{forceNotSecure} is used on a secure request, this will not return
that session.)
@return: the session attribute
@rtype: L{Session} or L{None}
"""
if self.isSecure():
return self._secureSession
else:
return self._insecureSession
def getSession(self, sessionInterface=None, forceNotSecure=False):
"""
Check if there is a session cookie, and if not, create it.
By default, the cookie with be secure for HTTPS requests and not secure
for HTTP requests. If for some reason you need access to the insecure
cookie from a secure request you can set C{forceNotSecure = True}.
@param forceNotSecure: Should we retrieve a session that will be
transmitted over HTTP, even if this L{Request} was delivered over
HTTPS?
@type forceNotSecure: L{bool}
"""
# Make sure we aren't creating a secure session on a non-secure page
secure = self.isSecure() and not forceNotSecure
if not secure:
cookieString = b"TWISTED_SESSION"
sessionAttribute = "_insecureSession"
else:
cookieString = b"TWISTED_SECURE_SESSION"
sessionAttribute = "_secureSession"
session = getattr(self, sessionAttribute)
if session is not None:
# We have a previously created session.
try:
# Refresh the session, to keep it alive.
session.touch()
except (AlreadyCalled, AlreadyCancelled):
# Session has already expired.
session = None
if session is None:
# No session was created yet for this request.
cookiename = b"_".join([cookieString] + self.sitepath)
sessionCookie = self.getCookie(cookiename)
if sessionCookie:
try:
session = self.site.getSession(sessionCookie)
except KeyError:
pass
# if it still hasn't been set, fix it up.
if not session:
session = self.site.makeSession()
self.addCookie(cookiename, session.uid, path=b"/", secure=secure)
setattr(self, sessionAttribute, session)
if sessionInterface:
return session.getComponent(sessionInterface)
return session
def _prePathURL(self, prepath):
port = self.getHost().port
if self.isSecure():
default = 443
else:
default = 80
if port == default:
hostport = ""
else:
hostport = ":%d" % port
prefix = networkString(
"http%s://%s%s/"
% (
self.isSecure() and "s" or "",
nativeString(self.getRequestHostname()),
hostport,
)
)
path = b"/".join([quote(segment, safe=b"") for segment in prepath])
return prefix + path
def prePathURL(self):
return self._prePathURL(self.prepath)
def URLPath(self):
from twisted.python import urlpath
return urlpath.URLPath.fromRequest(self)
def rememberRootURL(self):
"""
Remember the currently-processed part of the URL for later
recalling.
"""
url = self._prePathURL(self.prepath[:-1])
self.appRootURL = url
def getRootURL(self):
"""
Get a previously-remembered URL.
@return: An absolute URL.
@rtype: L{bytes}
"""
return self.appRootURL
def _handleStar(self):
"""
Handle receiving a request whose path is '*'.
RFC 7231 defines an OPTIONS * request as being something that a client
can send as a low-effort way to probe server capabilities or readiness.
Rather than bother the user with this, we simply fast-path it back to
an empty 200 OK. Any non-OPTIONS verb gets a 405 Method Not Allowed
telling the client they can only use OPTIONS.
"""
if self.method == b"OPTIONS":
self.setResponseCode(http.OK)
else:
self.setResponseCode(http.NOT_ALLOWED)
self.setHeader(b"Allow", b"OPTIONS")
# RFC 7231 says we MUST set content-length 0 when responding to this
# with no body.
self.setHeader(b"Content-Length", b"0")
self.finish()
@implementer(iweb._IRequestEncoderFactory)
class GzipEncoderFactory:
"""
@cvar compressLevel: The compression level used by the compressor, default
to 9 (highest).
@since: 12.3
"""
_gzipCheckRegex = re.compile(rb"(:?^|[\s,])gzip(:?$|[\s,])")
compressLevel = 9
def encoderForRequest(self, request):
"""
Check the headers if the client accepts gzip encoding, and encodes the
request if so.
"""
acceptHeaders = b",".join(
request.requestHeaders.getRawHeaders(b"Accept-Encoding", [])
)
if self._gzipCheckRegex.search(acceptHeaders):
encoding = request.responseHeaders.getRawHeaders(b"Content-Encoding")
if encoding:
encoding = b",".join(encoding + [b"gzip"])
else:
encoding = b"gzip"
request.responseHeaders.setRawHeaders(b"Content-Encoding", [encoding])
return _GzipEncoder(self.compressLevel, request)
@implementer(iweb._IRequestEncoder)
class _GzipEncoder:
"""
An encoder which supports gzip.
@ivar _zlibCompressor: The zlib compressor instance used to compress the
stream.
@ivar _request: A reference to the originating request.
@since: 12.3
"""
_zlibCompressor = None
def __init__(self, compressLevel, request):
self._zlibCompressor = zlib.compressobj(
compressLevel, zlib.DEFLATED, 16 + zlib.MAX_WBITS
)
self._request = request
def encode(self, data):
"""
Write to the request, automatically compressing data on the fly.
"""
if not self._request.startedWriting:
# Remove the content-length header, we can't honor it
# because we compress on the fly.
self._request.responseHeaders.removeHeader(b"Content-Length")
return self._zlibCompressor.compress(data)
def finish(self):
"""
Finish handling the request request, flushing any data from the zlib
buffer.
"""
remain = self._zlibCompressor.flush()
self._zlibCompressor = None
return remain
class _RemoteProducerWrapper:
def __init__(self, remote):
self.resumeProducing = remote.remoteMethod("resumeProducing")
self.pauseProducing = remote.remoteMethod("pauseProducing")
self.stopProducing = remote.remoteMethod("stopProducing")
class Session(components.Componentized):
"""
A user's session with a system.
This utility class contains no functionality, but is used to
represent a session.
@ivar site: The L{Site} that generated the session.
@type site: L{Site}
@ivar uid: A unique identifier for the session.
@type uid: L{bytes}
@ivar _reactor: An object providing L{IReactorTime} to use for scheduling
expiration.
@ivar sessionTimeout: Time after last modification the session will expire,
in seconds.
@type sessionTimeout: L{float}
@ivar lastModified: Time the C{touch()} method was last called (or time the
session was created). A UNIX timestamp as returned by
L{IReactorTime.seconds()}.
@type lastModified: L{float}
"""
sessionTimeout = 900
_expireCall = None
def __init__(self, site, uid, reactor=None):
"""
Initialize a session with a unique ID for that session.
@param reactor: L{IReactorTime} used to schedule expiration of the
session. If C{None}, the reactor associated with I{site} is used.
"""
super().__init__()
if reactor is None:
reactor = site.reactor
self._reactor = reactor
self.site = site
self.uid = uid
self.expireCallbacks = []
self.touch()
self.sessionNamespaces = {}
def startCheckingExpiration(self):
"""
Start expiration tracking.
@return: L{None}
"""
self._expireCall = self._reactor.callLater(self.sessionTimeout, self.expire)
def notifyOnExpire(self, callback):
"""
Call this callback when the session expires or logs out.
"""
self.expireCallbacks.append(callback)
def expire(self):
"""
Expire/logout of the session.
"""
del self.site.sessions[self.uid]
for c in self.expireCallbacks:
c()
self.expireCallbacks = []
if self._expireCall and self._expireCall.active():
self._expireCall.cancel()
# Break reference cycle.
self._expireCall = None
def touch(self):
"""
Mark the session as modified, which resets expiration timer.
"""
self.lastModified = self._reactor.seconds()
if self._expireCall is not None:
self._expireCall.reset(self.sessionTimeout)
version = networkString(f"TwistedWeb/{copyright.version}")
@implementer(interfaces.IProtocolNegotiationFactory)
class Site(HTTPFactory):
"""
A web site: manage log, sessions, and resources.
@ivar requestFactory: A factory which is called with (channel)
and creates L{Request} instances. Default to L{Request}.
@ivar displayTracebacks: If set, unhandled exceptions raised during
rendering are returned to the client as HTML. Default to C{False}.
@ivar sessionFactory: factory for sessions objects. Default to L{Session}.
@ivar sessions: Mapping of session IDs to objects returned by
C{sessionFactory}.
@type sessions: L{dict} mapping L{bytes} to L{Session} given the default
C{sessionFactory}
@ivar counter: The number of sessions that have been generated.
@type counter: L{int}
@ivar sessionCheckTime: Deprecated and unused. See
L{Session.sessionTimeout} instead.
"""
counter = 0
requestFactory = Request
displayTracebacks = False
sessionFactory = Session
sessionCheckTime = 1800
_entropy = os.urandom
def __init__(self, resource, requestFactory=None, *args, **kwargs):
"""
@param resource: The root of the resource hierarchy. All request
traversal for requests received by this factory will begin at this
resource.
@type resource: L{IResource} provider
@param requestFactory: Overwrite for default requestFactory.
@type requestFactory: C{callable} or C{class}.
@see: L{twisted.web.http.HTTPFactory.__init__}
"""
super().__init__(*args, **kwargs)
self.sessions = {}
self.resource = resource
if requestFactory is not None:
self.requestFactory = requestFactory
def _openLogFile(self, path):
from twisted.python import logfile
return logfile.LogFile(os.path.basename(path), os.path.dirname(path))
def __getstate__(self):
d = self.__dict__.copy()
d["sessions"] = {}
return d
def _mkuid(self):
"""
(internal) Generate an opaque, unique ID for a user's session.
"""
self.counter = self.counter + 1
return hexlify(self._entropy(32))
def makeSession(self):
"""
Generate a new Session instance, and store it for future reference.
"""
uid = self._mkuid()
session = self.sessions[uid] = self.sessionFactory(self, uid)
session.startCheckingExpiration()
return session
def getSession(self, uid):
"""
Get a previously generated session.
@param uid: Unique ID of the session.
@type uid: L{bytes}.
@raise KeyError: If the session is not found.
"""
return self.sessions[uid]
def buildProtocol(self, addr):
"""
Generate a channel attached to this site.
"""
channel = super().buildProtocol(addr)
channel.requestFactory = self.requestFactory
channel.site = self
return channel
isLeaf = 0
def render(self, request):
"""
Redirect because a Site is always a directory.
"""
request.redirect(request.prePathURL() + b"/")
request.finish()
def getChildWithDefault(self, pathEl, request):
"""
Emulate a resource's getChild method.
"""
request.site = self
return self.resource.getChildWithDefault(pathEl, request)
def getResourceFor(self, request):
"""
Get a resource for a request.
This iterates through the resource hierarchy, calling
getChildWithDefault on each resource it finds for a path element,
stopping when it hits an element where isLeaf is true.
"""
request.site = self
# Sitepath is used to determine cookie names between distributed
# servers and disconnected sites.
request.sitepath = copy.copy(request.prepath)
return resource.getChildForRequest(self.resource, request)
# IProtocolNegotiationFactory
def acceptableProtocols(self):
"""
Protocols this server can speak.
"""
baseProtocols = [b"http/1.1"]
if http.H2_ENABLED:
baseProtocols.insert(0, b"h2")
return baseProtocols

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,644 @@
# -*- test-case-name: twisted.web.test.test_xml -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
*S*mall, *U*ncomplicated *X*ML.
This is a very simple implementation of XML/HTML as a network
protocol. It is not at all clever. Its main features are that it
does not:
- support namespaces
- mung mnemonic entity references
- validate
- perform *any* external actions (such as fetching URLs or writing files)
under *any* circumstances
- has lots and lots of horrible hacks for supporting broken HTML (as an
option, they're not on by default).
"""
from twisted.internet.protocol import Protocol
from twisted.python.reflect import prefixedMethodNames
# Elements of the three-tuples in the state table.
BEGIN_HANDLER = 0
DO_HANDLER = 1
END_HANDLER = 2
identChars = ".-_:"
lenientIdentChars = identChars + ";+#/%~"
def nop(*args, **kw):
"Do nothing."
def unionlist(*args):
l = []
for x in args:
l.extend(x)
d = {x: 1 for x in l}
return d.keys()
def zipfndict(*args, **kw):
default = kw.get("default", nop)
d = {}
for key in unionlist(*(fndict.keys() for fndict in args)):
d[key] = tuple(x.get(key, default) for x in args)
return d
def prefixedMethodClassDict(clazz, prefix):
return {
name: getattr(clazz, prefix + name)
for name in prefixedMethodNames(clazz, prefix)
}
def prefixedMethodObjDict(obj, prefix):
return {
name: getattr(obj, prefix + name)
for name in prefixedMethodNames(obj.__class__, prefix)
}
class ParseError(Exception):
def __init__(self, filename, line, col, message):
self.filename = filename
self.line = line
self.col = col
self.message = message
def __str__(self) -> str:
return f"{self.filename}:{self.line}:{self.col}: {self.message}"
class XMLParser(Protocol):
state = None
encodings = None
filename = "<xml />"
beExtremelyLenient = 0
_prepend = None
# _leadingBodyData will sometimes be set before switching to the
# 'bodydata' state, when we "accidentally" read a byte of bodydata
# in a different state.
_leadingBodyData = None
def connectionMade(self):
self.lineno = 1
self.colno = 0
self.encodings = []
def saveMark(self):
"""Get the line number and column of the last character parsed"""
# This gets replaced during dataReceived, restored afterwards
return (self.lineno, self.colno)
def _parseError(self, message):
raise ParseError(*((self.filename,) + self.saveMark() + (message,)))
def _buildStateTable(self):
"""Return a dictionary of begin, do, end state function tuples"""
# _buildStateTable leaves something to be desired but it does what it
# does.. probably slowly, so I'm doing some evil caching so it doesn't
# get called more than once per class.
stateTable = getattr(self.__class__, "__stateTable", None)
if stateTable is None:
stateTable = self.__class__.__stateTable = zipfndict(
*(
prefixedMethodObjDict(self, prefix)
for prefix in ("begin_", "do_", "end_")
)
)
return stateTable
def _decode(self, data):
if "UTF-16" in self.encodings or "UCS-2" in self.encodings:
assert not len(data) & 1, "UTF-16 must come in pairs for now"
if self._prepend:
data = self._prepend + data
for encoding in self.encodings:
data = str(data, encoding)
return data
def maybeBodyData(self):
if self.endtag:
return "bodydata"
# Get ready for fun! We're going to allow
# <script>if (foo < bar)</script> to work!
# We do this by making everything between <script> and
# </script> a Text
# BUT <script src="foo"> will be special-cased to do regular,
# lenient behavior, because those may not have </script>
# -radix
if self.tagName == "script" and "src" not in self.tagAttributes:
# we do this ourselves rather than having begin_waitforendscript
# because that can get called multiple times and we don't want
# bodydata to get reset other than the first time.
self.begin_bodydata(None)
return "waitforendscript"
return "bodydata"
def dataReceived(self, data):
stateTable = self._buildStateTable()
if not self.state:
# all UTF-16 starts with this string
if data.startswith((b"\xff\xfe", b"\xfe\xff")):
self._prepend = data[0:2]
self.encodings.append("UTF-16")
data = data[2:]
self.state = "begin"
if self.encodings:
data = self._decode(data)
else:
data = data.decode("utf-8")
# bring state, lineno, colno into local scope
lineno, colno = self.lineno, self.colno
curState = self.state
# replace saveMark with a nested scope function
_saveMark = self.saveMark
def saveMark():
return (lineno, colno)
self.saveMark = saveMark
# fetch functions from the stateTable
beginFn, doFn, endFn = stateTable[curState]
try:
for byte in data:
# do newline stuff
if byte == "\n":
lineno += 1
colno = 0
else:
colno += 1
newState = doFn(byte)
if newState is not None and newState != curState:
# this is the endFn from the previous state
endFn()
curState = newState
beginFn, doFn, endFn = stateTable[curState]
beginFn(byte)
finally:
self.saveMark = _saveMark
self.lineno, self.colno = lineno, colno
# state doesn't make sense if there's an exception..
self.state = curState
def connectionLost(self, reason):
"""
End the last state we were in.
"""
stateTable = self._buildStateTable()
stateTable[self.state][END_HANDLER]()
# state methods
def do_begin(self, byte):
if byte.isspace():
return
if byte != "<":
if self.beExtremelyLenient:
self._leadingBodyData = byte
return "bodydata"
self._parseError(f"First char of document [{byte!r}] wasn't <")
return "tagstart"
def begin_comment(self, byte):
self.commentbuf = ""
def do_comment(self, byte):
self.commentbuf += byte
if self.commentbuf.endswith("-->"):
self.gotComment(self.commentbuf[:-3])
return "bodydata"
def begin_tagstart(self, byte):
self.tagName = "" # name of the tag
self.tagAttributes = {} # attributes of the tag
self.termtag = 0 # is the tag self-terminating
self.endtag = 0
def do_tagstart(self, byte):
if byte.isalnum() or byte in identChars:
self.tagName += byte
if self.tagName == "!--":
return "comment"
elif byte.isspace():
if self.tagName:
if self.endtag:
# properly strict thing to do here is probably to only
# accept whitespace
return "waitforgt"
return "attrs"
else:
self._parseError("Whitespace before tag-name")
elif byte == ">":
if self.endtag:
self.gotTagEnd(self.tagName)
return "bodydata"
else:
self.gotTagStart(self.tagName, {})
return (
(not self.beExtremelyLenient) and "bodydata" or self.maybeBodyData()
)
elif byte == "/":
if self.tagName:
return "afterslash"
else:
self.endtag = 1
elif byte in "!?":
if self.tagName:
if not self.beExtremelyLenient:
self._parseError("Invalid character in tag-name")
else:
self.tagName += byte
self.termtag = 1
elif byte == "[":
if self.tagName == "!":
return "expectcdata"
else:
self._parseError("Invalid '[' in tag-name")
else:
if self.beExtremelyLenient:
self.bodydata = "<"
return "unentity"
self._parseError("Invalid tag character: %r" % byte)
def begin_unentity(self, byte):
self.bodydata += byte
def do_unentity(self, byte):
self.bodydata += byte
return "bodydata"
def end_unentity(self):
self.gotText(self.bodydata)
def begin_expectcdata(self, byte):
self.cdatabuf = byte
def do_expectcdata(self, byte):
self.cdatabuf += byte
cdb = self.cdatabuf
cd = "[CDATA["
if len(cd) > len(cdb):
if cd.startswith(cdb):
return
elif self.beExtremelyLenient:
## WHAT THE CRAP!? MSWord9 generates HTML that includes these
## bizarre <![if !foo]> <![endif]> chunks, so I've gotta ignore
## 'em as best I can. this should really be a separate parse
## state but I don't even have any idea what these _are_.
return "waitforgt"
else:
self._parseError("Mal-formed CDATA header")
if cd == cdb:
self.cdatabuf = ""
return "cdata"
self._parseError("Mal-formed CDATA header")
def do_cdata(self, byte):
self.cdatabuf += byte
if self.cdatabuf.endswith("]]>"):
self.cdatabuf = self.cdatabuf[:-3]
return "bodydata"
def end_cdata(self):
self.gotCData(self.cdatabuf)
self.cdatabuf = ""
def do_attrs(self, byte):
if byte.isalnum() or byte in identChars:
# XXX FIXME really handle !DOCTYPE at some point
if self.tagName == "!DOCTYPE":
return "doctype"
if self.tagName[0] in "!?":
return "waitforgt"
return "attrname"
elif byte.isspace():
return
elif byte == ">":
self.gotTagStart(self.tagName, self.tagAttributes)
return (not self.beExtremelyLenient) and "bodydata" or self.maybeBodyData()
elif byte == "/":
return "afterslash"
elif self.beExtremelyLenient:
# discard and move on? Only case I've seen of this so far was:
# <foo bar="baz"">
return
self._parseError("Unexpected character: %r" % byte)
def begin_doctype(self, byte):
self.doctype = byte
def do_doctype(self, byte):
if byte == ">":
return "bodydata"
self.doctype += byte
def end_doctype(self):
self.gotDoctype(self.doctype)
self.doctype = None
def do_waitforgt(self, byte):
if byte == ">":
if self.endtag or not self.beExtremelyLenient:
return "bodydata"
return self.maybeBodyData()
def begin_attrname(self, byte):
self.attrname = byte
self._attrname_termtag = 0
def do_attrname(self, byte):
if byte.isalnum() or byte in identChars:
self.attrname += byte
return
elif byte == "=":
return "beforeattrval"
elif byte.isspace():
return "beforeeq"
elif self.beExtremelyLenient:
if byte in "\"'":
return "attrval"
if byte in lenientIdentChars or byte.isalnum():
self.attrname += byte
return
if byte == "/":
self._attrname_termtag = 1
return
if byte == ">":
self.attrval = "True"
self.tagAttributes[self.attrname] = self.attrval
self.gotTagStart(self.tagName, self.tagAttributes)
if self._attrname_termtag:
self.gotTagEnd(self.tagName)
return "bodydata"
return self.maybeBodyData()
# something is really broken. let's leave this attribute where it
# is and move on to the next thing
return
self._parseError(f"Invalid attribute name: {self.attrname!r} {byte!r}")
def do_beforeattrval(self, byte):
if byte in "\"'":
return "attrval"
elif byte.isspace():
return
elif self.beExtremelyLenient:
if byte in lenientIdentChars or byte.isalnum():
return "messyattr"
if byte == ">":
self.attrval = "True"
self.tagAttributes[self.attrname] = self.attrval
self.gotTagStart(self.tagName, self.tagAttributes)
return self.maybeBodyData()
if byte == "\\":
# I saw this in actual HTML once:
# <font size=\"3\"><sup>SM</sup></font>
return
self._parseError(
"Invalid initial attribute value: %r; Attribute values must be quoted."
% byte
)
attrname = ""
attrval = ""
def begin_beforeeq(self, byte):
self._beforeeq_termtag = 0
def do_beforeeq(self, byte):
if byte == "=":
return "beforeattrval"
elif byte.isspace():
return
elif self.beExtremelyLenient:
if byte.isalnum() or byte in identChars:
self.attrval = "True"
self.tagAttributes[self.attrname] = self.attrval
return "attrname"
elif byte == ">":
self.attrval = "True"
self.tagAttributes[self.attrname] = self.attrval
self.gotTagStart(self.tagName, self.tagAttributes)
if self._beforeeq_termtag:
self.gotTagEnd(self.tagName)
return "bodydata"
return self.maybeBodyData()
elif byte == "/":
self._beforeeq_termtag = 1
return
self._parseError("Invalid attribute")
def begin_attrval(self, byte):
self.quotetype = byte
self.attrval = ""
def do_attrval(self, byte):
if byte == self.quotetype:
return "attrs"
self.attrval += byte
def end_attrval(self):
self.tagAttributes[self.attrname] = self.attrval
self.attrname = self.attrval = ""
def begin_messyattr(self, byte):
self.attrval = byte
def do_messyattr(self, byte):
if byte.isspace():
return "attrs"
elif byte == ">":
endTag = 0
if self.attrval.endswith("/"):
endTag = 1
self.attrval = self.attrval[:-1]
self.tagAttributes[self.attrname] = self.attrval
self.gotTagStart(self.tagName, self.tagAttributes)
if endTag:
self.gotTagEnd(self.tagName)
return "bodydata"
return self.maybeBodyData()
else:
self.attrval += byte
def end_messyattr(self):
if self.attrval:
self.tagAttributes[self.attrname] = self.attrval
def begin_afterslash(self, byte):
self._after_slash_closed = 0
def do_afterslash(self, byte):
# this state is only after a self-terminating slash, e.g. <foo/>
if self._after_slash_closed:
self._parseError("Mal-formed") # XXX When does this happen??
if byte != ">":
if self.beExtremelyLenient:
return
else:
self._parseError("No data allowed after '/'")
self._after_slash_closed = 1
self.gotTagStart(self.tagName, self.tagAttributes)
self.gotTagEnd(self.tagName)
# don't need maybeBodyData here because there better not be
# any javascript code after a <script/>... we'll see :(
return "bodydata"
def begin_bodydata(self, byte):
if self._leadingBodyData:
self.bodydata = self._leadingBodyData
del self._leadingBodyData
else:
self.bodydata = ""
def do_bodydata(self, byte):
if byte == "<":
return "tagstart"
if byte == "&":
return "entityref"
self.bodydata += byte
def end_bodydata(self):
self.gotText(self.bodydata)
self.bodydata = ""
def do_waitforendscript(self, byte):
if byte == "<":
return "waitscriptendtag"
self.bodydata += byte
def begin_waitscriptendtag(self, byte):
self.temptagdata = ""
self.tagName = ""
self.endtag = 0
def do_waitscriptendtag(self, byte):
# 1 enforce / as first byte read
# 2 enforce following bytes to be subset of "script" until
# tagName == "script"
# 2a when that happens, gotText(self.bodydata) and gotTagEnd(self.tagName)
# 3 spaces can happen anywhere, they're ignored
# e.g. < / script >
# 4 anything else causes all data I've read to be moved to the
# bodydata, and switch back to waitforendscript state
# If it turns out this _isn't_ a </script>, we need to
# remember all the data we've been through so we can append it
# to bodydata
self.temptagdata += byte
# 1
if byte == "/":
self.endtag = True
elif not self.endtag:
self.bodydata += "<" + self.temptagdata
return "waitforendscript"
# 2
elif byte.isalnum() or byte in identChars:
self.tagName += byte
if not "script".startswith(self.tagName):
self.bodydata += "<" + self.temptagdata
return "waitforendscript"
elif self.tagName == "script":
self.gotText(self.bodydata)
self.gotTagEnd(self.tagName)
return "waitforgt"
# 3
elif byte.isspace():
return "waitscriptendtag"
# 4
else:
self.bodydata += "<" + self.temptagdata
return "waitforendscript"
def begin_entityref(self, byte):
self.erefbuf = ""
self.erefextra = "" # extra bit for lenient mode
def do_entityref(self, byte):
if byte.isspace() or byte == "<":
if self.beExtremelyLenient:
# '&foo' probably was '&amp;foo'
if self.erefbuf and self.erefbuf != "amp":
self.erefextra = self.erefbuf
self.erefbuf = "amp"
if byte == "<":
return "tagstart"
else:
self.erefextra += byte
return "spacebodydata"
self._parseError("Bad entity reference")
elif byte != ";":
self.erefbuf += byte
else:
return "bodydata"
def end_entityref(self):
self.gotEntityReference(self.erefbuf)
# hacky support for space after & in entityref in beExtremelyLenient
# state should only happen in that case
def begin_spacebodydata(self, byte):
self.bodydata = self.erefextra
self.erefextra = None
do_spacebodydata = do_bodydata
end_spacebodydata = end_bodydata
# Sorta SAX-ish API
def gotTagStart(self, name, attributes):
"""Encountered an opening tag.
Default behaviour is to print."""
print("begin", name, attributes)
def gotText(self, data):
"""Encountered text
Default behaviour is to print."""
print("text:", repr(data))
def gotEntityReference(self, entityRef):
"""Encountered mnemonic entity reference
Default behaviour is to print."""
print("entityRef: &%s;" % entityRef)
def gotComment(self, comment):
"""Encountered comment.
Default behaviour is to ignore."""
pass
def gotCData(self, cdata):
"""Encountered CDATA
Default behaviour is to call the gotText method"""
self.gotText(cdata)
def gotDoctype(self, doctype):
"""Encountered DOCTYPE
This is really grotty: it basically just gives you everything between
'<!DOCTYPE' and '>' as an argument.
"""
print("!DOCTYPE", repr(doctype))
def gotTagEnd(self, name):
"""Encountered closing tag
Default behaviour is to print."""
print("end", name)

View File

@@ -0,0 +1,322 @@
# -*- test-case-name: twisted.web.test.test_tap -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Support for creating a service which runs a web server.
"""
import os
import warnings
import incremental
from twisted.application import service, strports
from twisted.internet import interfaces, reactor
from twisted.python import deprecate, reflect, threadpool, usage
from twisted.spread import pb
from twisted.web import demo, distrib, resource, script, server, static, twcgi, wsgi
class Options(usage.Options):
"""
Define the options accepted by the I{twistd web} plugin.
"""
synopsis = "[web options]"
optParameters = [
["logfile", "l", None, "Path to web CLF (Combined Log Format) log file."],
[
"certificate",
"c",
"server.pem",
"(DEPRECATED: use --listen) " "SSL certificate to use for HTTPS. ",
],
[
"privkey",
"k",
"server.pem",
"(DEPRECATED: use --listen) " "SSL certificate to use for HTTPS.",
],
]
optFlags = [
[
"notracebacks",
"n",
(
"(DEPRECATED: Tracebacks are disabled by default. "
"See --enable-tracebacks to turn them on."
),
],
[
"display-tracebacks",
"",
(
"Show uncaught exceptions during rendering tracebacks to "
"the client. WARNING: This may be a security risk and "
"expose private data!"
),
],
]
optFlags.append(
[
"personal",
"",
"Instead of generating a webserver, generate a "
"ResourcePublisher which listens on the port given by "
"--listen, or ~/%s " % (distrib.UserDirectory.userSocketName,)
+ "if --listen is not specified.",
]
)
compData = usage.Completions(
optActions={
"logfile": usage.CompleteFiles("*.log"),
"certificate": usage.CompleteFiles("***REMOVED***"),
"privkey": usage.CompleteFiles("***REMOVED***"),
}
)
longdesc = """\
This starts a webserver. If you specify no arguments, it will be a
demo webserver that has the Test class from twisted.web.demo in it."""
def __init__(self):
usage.Options.__init__(self)
self["indexes"] = []
self["root"] = None
self["extraHeaders"] = []
self["ports"] = []
self["port"] = self["https"] = None
def opt_port(self, port):
"""
(DEPRECATED: use --listen)
Strports description of port to start the server on
"""
msg = deprecate.getDeprecationWarningString(
self.opt_port, incremental.Version("Twisted", 18, 4, 0)
)
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
self["port"] = port
opt_p = opt_port
def opt_https(self, port):
"""
(DEPRECATED: use --listen)
Port to listen on for Secure HTTP.
"""
msg = deprecate.getDeprecationWarningString(
self.opt_https, incremental.Version("Twisted", 18, 4, 0)
)
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
self["https"] = port
def opt_listen(self, port):
"""
Add an strports description of port to start the server on.
[default: tcp:8080]
"""
self["ports"].append(port)
def opt_index(self, indexName):
"""
Add the name of a file used to check for directory indexes.
[default: index, index.html]
"""
self["indexes"].append(indexName)
opt_i = opt_index
def opt_user(self):
"""
Makes a server with ~/public_html and ~/.twistd-web-pb support for
users.
"""
self["root"] = distrib.UserDirectory()
opt_u = opt_user
def opt_path(self, path):
"""
<path> is either a specific file or a directory to be set as the root
of the web server. Use this if you have a directory full of HTML, cgi,
epy, or rpy files or any other files that you want to be served up raw.
"""
self["root"] = static.File(os.path.abspath(path))
self["root"].processors = {
".epy": script.PythonScript,
".rpy": script.ResourceScript,
}
self["root"].processors[".cgi"] = twcgi.CGIScript
def opt_processor(self, proc):
"""
`ext=class' where `class' is added as a Processor for files ending
with `ext'.
"""
if not isinstance(self["root"], static.File):
raise usage.UsageError("You can only use --processor after --path.")
ext, klass = proc.split("=", 1)
self["root"].processors[ext] = reflect.namedClass(klass)
def opt_class(self, className):
"""
Create a Resource subclass with a zero-argument constructor.
"""
classObj = reflect.namedClass(className)
self["root"] = classObj()
def opt_resource_script(self, name):
"""
An .rpy file to be used as the root resource of the webserver.
"""
self["root"] = script.ResourceScriptWrapper(name)
def opt_wsgi(self, name):
"""
The FQPN of a WSGI application object to serve as the root resource of
the webserver.
"""
try:
application = reflect.namedAny(name)
except (AttributeError, ValueError):
raise usage.UsageError(f"No such WSGI application: {name!r}")
pool = threadpool.ThreadPool()
reactor.callWhenRunning(pool.start)
reactor.addSystemEventTrigger("after", "shutdown", pool.stop)
self["root"] = wsgi.WSGIResource(reactor, pool, application)
def opt_mime_type(self, defaultType):
"""
Specify the default mime-type for static files.
"""
if not isinstance(self["root"], static.File):
raise usage.UsageError("You can only use --mime_type after --path.")
self["root"].defaultType = defaultType
opt_m = opt_mime_type
def opt_allow_ignore_ext(self):
"""
Specify whether or not a request for 'foo' should return 'foo.ext'
"""
if not isinstance(self["root"], static.File):
raise usage.UsageError(
"You can only use --allow_ignore_ext " "after --path."
)
self["root"].ignoreExt("*")
def opt_ignore_ext(self, ext):
"""
Specify an extension to ignore. These will be processed in order.
"""
if not isinstance(self["root"], static.File):
raise usage.UsageError("You can only use --ignore_ext " "after --path.")
self["root"].ignoreExt(ext)
def opt_add_header(self, header):
"""
Specify an additional header to be included in all responses. Specified
as "HeaderName: HeaderValue".
"""
name, value = header.split(":", 1)
self["extraHeaders"].append((name.strip(), value.strip()))
def postOptions(self):
"""
Set up conditional defaults and check for dependencies.
If SSL is not available but an HTTPS server was configured, raise a
L{UsageError} indicating that this is not possible.
If no server port was supplied, select a default appropriate for the
other options supplied.
"""
if self["port"] is not None:
self["ports"].append(self["port"])
if self["https"] is not None:
try:
reflect.namedModule("OpenSSL.SSL")
except ImportError:
raise usage.UsageError("SSL support not installed")
sslStrport = "ssl:port={}:privateKey={}:certKey={}".format(
self["https"],
self["privkey"],
self["certificate"],
)
self["ports"].append(sslStrport)
if len(self["ports"]) == 0:
if self["personal"]:
path = os.path.expanduser(
os.path.join("~", distrib.UserDirectory.userSocketName)
)
self["ports"].append("unix:" + path)
else:
self["ports"].append("tcp:8080")
def makePersonalServerFactory(site):
"""
Create and return a factory which will respond to I{distrib} requests
against the given site.
@type site: L{twisted.web.server.Site}
@rtype: L{twisted.internet.protocol.Factory}
"""
return pb.PBServerFactory(distrib.ResourcePublisher(site))
class _AddHeadersResource(resource.Resource):
def __init__(self, originalResource, headers):
self._originalResource = originalResource
self._headers = headers
def getChildWithDefault(self, name, request):
for k, v in self._headers:
request.responseHeaders.addRawHeader(k, v)
return self._originalResource.getChildWithDefault(name, request)
def makeService(config):
s = service.MultiService()
if config["root"]:
root = config["root"]
if config["indexes"]:
config["root"].indexNames = config["indexes"]
else:
# This really ought to be web.Admin or something
root = demo.Test()
if isinstance(root, static.File):
root.registry.setComponent(interfaces.IServiceCollection, s)
if config["extraHeaders"]:
root = _AddHeadersResource(root, config["extraHeaders"])
if config["logfile"]:
site = server.Site(root, logPath=config["logfile"])
else:
site = server.Site(root)
if config["display-tracebacks"]:
site.displayTracebacks = True
# Deprecate --notracebacks/-n
if config["notracebacks"]:
msg = deprecate._getDeprecationWarningString(
"--notracebacks", incremental.Version("Twisted", 19, 7, 0)
)
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
if config["personal"]:
site = makePersonalServerFactory(site)
for port in config["ports"]:
svc = strports.service(port, site)
svc.setServiceParent(s)
return s

View File

@@ -0,0 +1,60 @@
# -*- test-case-name: twisted.web.test.test_template -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTML rendering for twisted.web.
@var VALID_HTML_TAG_NAMES: A list of recognized HTML tag names, used by the
L{tag} object.
@var TEMPLATE_NAMESPACE: The XML namespace used to identify attributes and
elements used by the templating system, which should be removed from the
final output document.
@var tags: A convenience object which can produce L{Tag} objects on demand via
attribute access. For example: C{tags.div} is equivalent to C{Tag("div")}.
Tags not specified in L{VALID_HTML_TAG_NAMES} will result in an
L{AttributeError}.
"""
__all__ = [
"TEMPLATE_NAMESPACE",
"VALID_HTML_TAG_NAMES",
"Element",
"Flattenable",
"TagLoader",
"XMLString",
"XMLFile",
"renderer",
"flatten",
"flattenString",
"tags",
"Comment",
"CDATA",
"Tag",
"slot",
"CharRef",
"renderElement",
]
from ._stan import CharRef
from ._template_util import (
CDATA,
TEMPLATE_NAMESPACE,
VALID_HTML_TAG_NAMES,
Comment,
Element,
Flattenable,
Tag,
TagLoader,
XMLFile,
XMLString,
flatten,
flattenString,
renderElement,
renderer,
slot,
tags,
)

View File

@@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web}.
"""

View File

@@ -0,0 +1,95 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
General helpers for L{twisted.web} unit tests.
"""
from __future__ import annotations
from typing import Type
from twisted.internet.defer import Deferred, succeed
from twisted.trial.unittest import SynchronousTestCase
from twisted.web import server
from twisted.web._flatten import flattenString
from twisted.web.error import FlattenerError
from twisted.web.http import Request
from twisted.web.resource import IResource
from twisted.web.template import Flattenable
from .requesthelper import DummyRequest
def _render(resource: IResource, request: Request | DummyRequest) -> Deferred[None]:
result = resource.render(request)
if isinstance(result, bytes):
request.write(result)
request.finish()
return succeed(None)
elif result is server.NOT_DONE_YET:
if request.finished:
return succeed(None)
else:
return request.notifyFinish()
else:
raise ValueError(f"Unexpected return value: {result!r}")
class FlattenTestCase(SynchronousTestCase):
"""
A test case that assists with testing L{twisted.web._flatten}.
"""
def assertFlattensTo(self, root: Flattenable, target: bytes) -> Deferred[bytes]:
"""
Assert that a root element, when flattened, is equal to a string.
"""
def check(result: bytes) -> bytes:
self.assertEqual(result, target)
return result
d: Deferred[bytes] = flattenString(None, root)
d.addCallback(check)
return d
def assertFlattensImmediately(self, root: Flattenable, target: bytes) -> bytes:
"""
Assert that a root element, when flattened, is equal to a string, and
performs no asynchronus Deferred anything.
This version is more convenient in tests which wish to make multiple
assertions about flattening, since it can be called multiple times
without having to add multiple callbacks.
@return: the result of rendering L{root}, which should be equivalent to
L{target}.
@rtype: L{bytes}
"""
return self.successResultOf(self.assertFlattensTo(root, target))
def assertFlatteningRaises(self, root: Flattenable, exn: Type[Exception]) -> None:
"""
Assert flattening a root element raises a particular exception.
"""
failure = self.failureResultOf(self.assertFlattensTo(root, b""), FlattenerError)
self.assertIsInstance(failure.value._exception, exn)
def assertIsFilesystemTemporary(case, fileObj):
"""
Assert that C{fileObj} is a temporary file on the filesystem.
@param case: A C{TestCase} instance to use to make the assertion.
@raise: C{case.failureException} if C{fileObj} is not a temporary file on
the filesystem.
"""
# The tempfile API used to create content returns an instance of a
# different type depending on what platform we're running on. The point
# here is to verify that the request body is in a file that's on the
# filesystem. Having a fileno method that returns an int is a somewhat
# close approximation of this. -exarkun
case.assertIsInstance(fileObj.fileno(), int)
__all__ = ["_render", "FlattenTestCase", "assertIsFilesystemTemporary"]

View File

@@ -0,0 +1,155 @@
"""
Helpers for URI and method injection tests.
@see: U{CVE-2019-12387}
"""
import string
UNPRINTABLE_ASCII = frozenset(range(0, 128)) - frozenset(
bytearray(string.printable, "ascii")
)
NONASCII = frozenset(range(128, 256))
class MethodInjectionTestsMixin:
"""
A mixin that runs HTTP method injection tests. Define
L{MethodInjectionTestsMixin.attemptRequestWithMaliciousMethod} in
a L{twisted.trial.unittest.SynchronousTestCase} subclass to test
how HTTP client code behaves when presented with malicious HTTP
methods.
@see: U{CVE-2019-12387}
"""
def attemptRequestWithMaliciousMethod(self, method):
"""
Attempt to send a request with the given method. This should
synchronously raise a L{ValueError} if either is invalid.
@param method: the method (e.g. C{GET\x00})
@param uri: the URI
@type method:
"""
raise NotImplementedError()
def test_methodWithCLRFRejected(self):
"""
Issuing a request with a method that contains a carriage
return and line feed fails with a L{ValueError}.
"""
with self.assertRaises(ValueError) as cm:
method = b"GET\r\nX-Injected-Header: value"
self.attemptRequestWithMaliciousMethod(method)
self.assertRegex(str(cm.exception), "^Invalid method")
def test_methodWithUnprintableASCIIRejected(self):
"""
Issuing a request with a method that contains unprintable
ASCII characters fails with a L{ValueError}.
"""
for c in UNPRINTABLE_ASCII:
method = b"GET%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousMethod(method)
self.assertRegex(str(cm.exception), "^Invalid method")
def test_methodWithNonASCIIRejected(self):
"""
Issuing a request with a method that contains non-ASCII
characters fails with a L{ValueError}.
"""
for c in NONASCII:
method = b"GET%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousMethod(method)
self.assertRegex(str(cm.exception), "^Invalid method")
class URIInjectionTestsMixin:
"""
A mixin that runs HTTP URI injection tests. Define
L{MethodInjectionTestsMixin.attemptRequestWithMaliciousURI} in a
L{twisted.trial.unittest.SynchronousTestCase} subclass to test how
HTTP client code behaves when presented with malicious HTTP
URIs.
"""
def attemptRequestWithMaliciousURI(self, method):
"""
Attempt to send a request with the given URI. This should
synchronously raise a L{ValueError} if either is invalid.
@param uri: the URI.
@type method:
"""
raise NotImplementedError()
def test_hostWithCRLFRejected(self):
"""
Issuing a request with a URI whose host contains a carriage
return and line feed fails with a L{ValueError}.
"""
with self.assertRaises(ValueError) as cm:
uri = b"http://twisted\r\n.invalid/path"
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_hostWithWithUnprintableASCIIRejected(self):
"""
Issuing a request with a URI whose host contains unprintable
ASCII characters fails with a L{ValueError}.
"""
for c in UNPRINTABLE_ASCII:
uri = b"http://twisted%s.invalid/OK" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_hostWithNonASCIIRejected(self):
"""
Issuing a request with a URI whose host contains non-ASCII
characters fails with a L{ValueError}.
"""
for c in NONASCII:
uri = b"http://twisted%s.invalid/OK" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_pathWithCRLFRejected(self):
"""
Issuing a request with a URI whose path contains a carriage
return and line feed fails with a L{ValueError}.
"""
with self.assertRaises(ValueError) as cm:
uri = b"http://twisted.invalid/\r\npath"
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_pathWithWithUnprintableASCIIRejected(self):
"""
Issuing a request with a URI whose path contains unprintable
ASCII characters fails with a L{ValueError}.
"""
for c in UNPRINTABLE_ASCII:
uri = b"http://twisted.invalid/OK%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_pathWithNonASCIIRejected(self):
"""
Issuing a request with a URI whose path contains non-ASCII
characters fails with a L{ValueError}.
"""
for c in NONASCII:
uri = b"http://twisted.invalid/OK%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")

View File

@@ -0,0 +1,516 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Helpers related to HTTP requests, used by tests.
"""
from __future__ import annotations
__all__ = ["DummyChannel", "DummyRequest"]
from io import BytesIO
from typing import Dict, List, Optional
from zope.interface import implementer, verify
from incremental import Version
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, ISSLTransport
from twisted.internet.task import Clock
from twisted.python.deprecate import deprecated
from twisted.trial import unittest
from twisted.web._responses import FOUND
from twisted.web.http_headers import Headers
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET, Session, Site
textLinearWhitespaceComponents = [f"Foo{lw}bar" for lw in ["\r", "\n", "\r\n"]]
sanitizedText = "Foo bar"
bytesLinearWhitespaceComponents = [
component.encode("ascii") for component in textLinearWhitespaceComponents
]
sanitizedBytes = sanitizedText.encode("ascii")
@implementer(IAddress)
class NullAddress:
"""
A null implementation of L{IAddress}.
"""
class DummyChannel:
class TCP:
port = 80
disconnected = False
def __init__(self, peer=None):
if peer is None:
peer = IPv4Address("TCP", "192.168.1.1", 12344)
self._peer = peer
self.written = BytesIO()
self.producers = []
def getPeer(self):
return self._peer
def write(self, data):
if not isinstance(data, bytes):
raise TypeError(f"Can only write bytes to a transport, not {data!r}")
self.written.write(data)
def writeSequence(self, iovec):
for data in iovec:
self.write(data)
def getHost(self):
return IPv4Address("TCP", "10.0.0.1", self.port)
def registerProducer(self, producer, streaming):
self.producers.append((producer, streaming))
def unregisterProducer(self):
pass
def loseConnection(self):
self.disconnected = True
@implementer(ISSLTransport)
class SSL(TCP):
def abortConnection(self):
# ITCPTransport.abortConnection
pass
def getTcpKeepAlive(self):
# ITCPTransport.getTcpKeepAlive
pass
def getTcpNoDelay(self):
# ITCPTransport.getTcpNoDelay
pass
def loseWriteConnection(self):
# ITCPTransport.loseWriteConnection
pass
def setTcpKeepAlive(self, enabled):
# ITCPTransport.setTcpKeepAlive
pass
def setTcpNoDelay(self, enabled):
# ITCPTransport.setTcpNoDelay
pass
def getPeerCertificate(self):
# ISSLTransport.getPeerCertificate
pass
site = Site(Resource())
def __init__(self, peer=None):
self.transport = self.TCP(peer)
def requestDone(self, request):
pass
def writeHeaders(self, version, code, reason, headers):
if isinstance(headers, Headers):
headers = [
(k, v) for (k, values) in headers.getAllRawHeaders() for v in values
]
response_line = version + b" " + code + b" " + reason + b"\r\n"
headerSequence = [response_line]
headerSequence.extend(name + b": " + value + b"\r\n" for name, value in headers)
headerSequence.append(b"\r\n")
self.transport.writeSequence(headerSequence)
def getPeer(self):
return self.transport.getPeer()
def getHost(self):
return self.transport.getHost()
def registerProducer(self, producer, streaming):
self.transport.registerProducer(producer, streaming)
def unregisterProducer(self):
self.transport.unregisterProducer()
def write(self, data):
self.transport.write(data)
def writeSequence(self, iovec):
self.transport.writeSequence(iovec)
def loseConnection(self):
self.transport.loseConnection()
def endRequest(self):
pass
def isSecure(self):
return isinstance(self.transport, self.SSL)
def abortConnection(self):
# ITCPTransport.abortConnection
pass
def getTcpKeepAlive(self):
# ITCPTransport.getTcpKeepAlive
pass
def getTcpNoDelay(self):
# ITCPTransport.getTcpNoDelay
pass
def loseWriteConnection(self):
# ITCPTransport.loseWriteConnection
pass
def setTcpKeepAlive(self):
# ITCPTransport.setTcpKeepAlive
pass
def setTcpNoDelay(self):
# ITCPTransport.setTcpNoDelay
pass
def getPeerCertificate(self):
# ISSLTransport.getPeerCertificate
pass
class DummyRequest:
"""
Represents a dummy or fake request. See L{twisted.web.server.Request}.
@ivar _finishedDeferreds: L{None} or a C{list} of L{Deferreds} which will
be called back with L{None} when C{finish} is called or which will be
errbacked if C{processingFailed} is called.
@type requestheaders: C{Headers}
@ivar requestheaders: A Headers instance that stores values for all request
headers.
@type responseHeaders: C{Headers}
@ivar responseHeaders: A Headers instance that stores values for all
response headers.
@type responseCode: C{int}
@ivar responseCode: The response code which was passed to
C{setResponseCode}.
@type written: C{list} of C{bytes}
@ivar written: The bytes which have been written to the request.
"""
uri = b"http://dummy/"
method = b"GET"
client: Optional[IAddress] = None
sitepath: List[bytes]
written: List[bytes]
prepath: List[bytes]
args: Dict[bytes, List[bytes]]
_finishedDeferreds: List[Deferred[None]]
def registerProducer(self, prod, s):
"""
Call an L{IPullProducer}'s C{resumeProducing} method in a
loop until it unregisters itself.
@param prod: The producer.
@type prod: L{IPullProducer}
@param s: Whether or not the producer is streaming.
"""
# XXX: Handle IPushProducers
self.go = 1
while self.go:
prod.resumeProducing()
def unregisterProducer(self):
self.go = 0
def __init__(
self,
postpath: list[bytes],
session: Optional[Session] = None,
client: Optional[IAddress] = None,
) -> None:
self.sitepath = []
self.written = []
self.finished = 0
self.postpath = postpath
self.prepath = []
self.session = None
self.protoSession = session or Session(site=None, uid=b"0", reactor=Clock())
self.args = {}
self.requestHeaders = Headers()
self.responseHeaders = Headers()
self.responseCode = None
self._finishedDeferreds = []
self._serverName = b"dummy"
self.clientproto = b"HTTP/1.0"
def getAllHeaders(self):
"""
Return dictionary mapping the names of all received headers to the last
value received for each.
Since this method does not return all header information,
C{self.requestHeaders.getAllRawHeaders()} may be preferred.
NOTE: This function is a direct copy of
C{twisted.web.http.Request.getAllRawHeaders}.
"""
headers = {}
for k, v in self.requestHeaders.getAllRawHeaders():
headers[k.lower()] = v[-1]
return headers
def getHeader(self, name):
"""
Retrieve the value of a request header.
@type name: C{bytes}
@param name: The name of the request header for which to retrieve the
value. Header names are compared case-insensitively.
@rtype: C{bytes} or L{None}
@return: The value of the specified request header.
"""
return self.requestHeaders.getRawHeaders(name.lower(), [None])[0]
def setHeader(self, name, value):
"""TODO: make this assert on write() if the header is content-length"""
self.responseHeaders.addRawHeader(name, value)
def getSession(self, sessionInterface=None):
if self.session:
return self.session
assert (
not self.written
), "Session cannot be requested after data has been written."
self.session = self.protoSession
return self.session
def render(self, resource):
"""
Render the given resource as a response to this request.
This implementation only handles a few of the most common behaviors of
resources. It can handle a render method that returns a string or
C{NOT_DONE_YET}. It doesn't know anything about the semantics of
request methods (eg HEAD) nor how to set any particular headers.
Basically, it's largely broken, but sufficient for some tests at least.
It should B{not} be expanded to do all the same stuff L{Request} does.
Instead, L{DummyRequest} should be phased out and L{Request} (or some
other real code factored in a different way) used.
"""
result = resource.render(self)
if result is NOT_DONE_YET:
return
self.write(result)
self.finish()
def write(self, data):
if not isinstance(data, bytes):
raise TypeError("write() only accepts bytes")
self.written.append(data)
def notifyFinish(self) -> Deferred[None]:
"""
Return a L{Deferred} which is called back with L{None} when the request
is finished. This will probably only work if you haven't called
C{finish} yet.
"""
finished: Deferred[None] = Deferred()
self._finishedDeferreds.append(finished)
return finished
def finish(self):
"""
Record that the request is finished and callback and L{Deferred}s
waiting for notification of this.
"""
self.finished = self.finished + 1
if self._finishedDeferreds is not None:
observers = self._finishedDeferreds
self._finishedDeferreds = None
for obs in observers:
obs.callback(None)
def processingFailed(self, reason):
"""
Errback and L{Deferreds} waiting for finish notification.
"""
if self._finishedDeferreds is not None:
observers = self._finishedDeferreds
self._finishedDeferreds = None
for obs in observers:
obs.errback(reason)
def addArg(self, name, value):
self.args[name] = [value]
def setResponseCode(self, code, message=None):
"""
Set the HTTP status response code, but takes care that this is called
before any data is written.
"""
assert (
not self.written
), "Response code cannot be set after data has" "been written: {}.".format(
"@@@@".join(self.written)
)
self.responseCode = code
self.responseMessage = message
def setLastModified(self, when):
assert (
not self.written
), "Last-Modified cannot be set after data has " "been written: {}.".format(
"@@@@".join(self.written)
)
def setETag(self, tag):
assert (
not self.written
), "ETag cannot be set after data has been " "written: {}.".format(
"@@@@".join(self.written)
)
@deprecated(Version("Twisted", 18, 4, 0), replacement="getClientAddress")
def getClientIP(self):
"""
Return the IPv4 address of the client which made this request, if there
is one, otherwise L{None}.
"""
if isinstance(self.client, (IPv4Address, IPv6Address)):
return self.client.host
return None
def getClientAddress(self):
"""
Return the L{IAddress} of the client that made this request.
@return: an address.
@rtype: an L{IAddress} provider.
"""
if self.client is None:
return NullAddress()
return self.client
def getRequestHostname(self):
"""
Get a dummy hostname associated to the HTTP request.
@rtype: C{bytes}
@returns: a dummy hostname
"""
return self._serverName
def getHost(self):
"""
Get a dummy transport's host.
@rtype: C{IPv4Address}
@returns: a dummy transport's host
"""
return IPv4Address("TCP", "127.0.0.1", 80)
def setHost(self, host, port, ssl=0):
"""
Change the host and port the request thinks it's using.
@type host: C{bytes}
@param host: The value to which to change the host header.
@type ssl: C{bool}
@param ssl: A flag which, if C{True}, indicates that the request is
considered secure (if C{True}, L{isSecure} will return C{True}).
"""
self._forceSSL = ssl # set first so isSecure will work
if self.isSecure():
default = 443
else:
default = 80
if port == default:
hostHeader = host
else:
hostHeader = b"%b:%d" % (host, port)
self.requestHeaders.addRawHeader(b"host", hostHeader)
def redirect(self, url):
"""
Utility function that does a redirect.
The request should have finish() called after this.
"""
self.setResponseCode(FOUND)
self.setHeader(b"location", url)
class DummyRequestTests(unittest.SynchronousTestCase):
"""
Tests for L{DummyRequest}.
"""
def test_getClientIPDeprecated(self):
"""
L{DummyRequest.getClientIP} is deprecated in favor of
L{DummyRequest.getClientAddress}
"""
request = DummyRequest([])
request.getClientIP()
warnings = self.flushWarnings(
offendingFunctions=[self.test_getClientIPDeprecated]
)
self.assertEqual(1, len(warnings))
[warning] = warnings
self.assertEqual(warning.get("category"), DeprecationWarning)
self.assertEqual(
warning.get("message"),
(
"twisted.web.test.requesthelper.DummyRequest.getClientIP "
"was deprecated in Twisted 18.4.0; "
"please use getClientAddress instead"
),
)
def test_getClientIPSupportsIPv6(self):
"""
L{DummyRequest.getClientIP} supports IPv6 addresses, just like
L{twisted.web.http.Request.getClientIP}.
"""
request = DummyRequest([])
client = IPv6Address("TCP", "::1", 12345)
request.client = client
self.assertEqual("::1", request.getClientIP())
def test_getClientAddressWithoutClient(self):
"""
L{DummyRequest.getClientAddress} returns an L{IAddress}
provider no C{client} has been set.
"""
request = DummyRequest([])
null = request.getClientAddress()
verify.verifyObject(IAddress, null)
def test_getClientAddress(self):
"""
L{DummyRequest.getClientAddress} returns the C{client}.
"""
request = DummyRequest([])
client = IPv4Address("TCP", "127.0.0.1", 12345)
request.client = client
address = request.getClientAddress()
self.assertIs(address, client)

View File

@@ -0,0 +1,114 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web._abnf}.
"""
from twisted.trial import unittest
from twisted.web._abnf import _decint, _hexint, _ishexdigits, _istoken
class IsTokenTests(unittest.SynchronousTestCase):
"""
Test the L{twisted.web._abnf._istoken} function.
"""
def test_ok(self) -> None:
for b in (
b"GET",
b"Cache-Control",
b"&",
):
self.assertTrue(_istoken(b))
def test_bad(self) -> None:
for b in (
b"",
b" ",
b"a b",
):
self.assertFalse(_istoken(b))
class DecintTests(unittest.SynchronousTestCase):
"""
Test the L{twisted.web._abnf._decint} function.
"""
def test_valid(self) -> None:
"""
Given a decimal digits, L{_decint} return an L{int}.
"""
self.assertEqual(1, _decint(b"1"))
self.assertEqual(10, _decint(b"10"))
self.assertEqual(9000, _decint(b"9000"))
self.assertEqual(9000, _decint(b"0009000"))
def test_validWhitespace(self) -> None:
"""
L{_decint} decodes integers embedded in linear whitespace.
"""
self.assertEqual(123, _decint(b" 123"))
self.assertEqual(123, _decint(b"123\t\t"))
self.assertEqual(123, _decint(b" \t 123 \t "))
def test_invalidPlus(self) -> None:
"""
L{_decint} rejects a number with a leading C{+} character.
"""
self.assertRaises(ValueError, _decint, b"+1")
def test_invalidMinus(self) -> None:
"""
L{_decint} rejects a number with a leading C{-} character.
"""
self.assertRaises(ValueError, _decint, b"-1")
def test_invalidWhitespace(self) -> None:
"""
L{_decint} rejects a number embedded in non-linear whitespace.
"""
self.assertRaises(ValueError, _decint, b"\v1")
self.assertRaises(ValueError, _decint, b"\x1c1")
self.assertRaises(ValueError, _decint, b"1\x1e")
class HexHelperTests(unittest.SynchronousTestCase):
"""
Test the L{twisted.web._abnf._hexint} and L{_ishexdigits} helper functions.
"""
badStrings = (b"", b"0x1234", b"feds", b"-123" b"+123")
def test_isHex(self) -> None:
"""
L{_ishexdigits()} returns L{True} for nonempy bytestrings containing
hexadecimal digits.
"""
for s in (b"10", b"abcdef", b"AB1234", b"fed", b"123467890"):
self.assertIs(True, _ishexdigits(s))
def test_decodes(self) -> None:
"""
L{_hexint()} returns the integer equivalent of the input.
"""
self.assertEqual(10, _hexint(b"a"))
self.assertEqual(0x10, _hexint(b"10"))
self.assertEqual(0xABCD123, _hexint(b"abCD123"))
def test_isNotHex(self) -> None:
"""
L{_ishexdigits()} returns L{False} for bytestrings that don't contain
hexadecimal digits, including the empty string.
"""
for s in self.badStrings:
self.assertIs(False, _ishexdigits(s))
def test_decodeNotHex(self) -> None:
"""
L{_hexint()} raises L{ValueError} for bytestrings that can't
be decoded.
"""
for s in self.badStrings:
self.assertRaises(ValueError, _hexint, s)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,499 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.twcgi}.
"""
import json
import os
import sys
from io import BytesIO
from twisted.internet import address, error, interfaces, reactor
from twisted.internet.error import ConnectionLost
from twisted.python import failure, log, util
from twisted.trial import unittest
from twisted.web import client, http, http_headers, resource, server, twcgi
from twisted.web.http import INTERNAL_SERVER_ERROR, NOT_FOUND
from twisted.web.test._util import _render
from twisted.web.test.requesthelper import DummyChannel, DummyRequest
DUMMY_CGI = """\
print("Header: OK")
print("")
print("cgi output")
"""
DUAL_HEADER_CGI = """\
print("Header: spam")
print("Header: eggs")
print("")
print("cgi output")
"""
BROKEN_HEADER_CGI = """\
print("XYZ")
print("")
print("cgi output")
"""
SPECIAL_HEADER_CGI = """\
print("Server: monkeys")
print("Date: last year")
print("")
print("cgi output")
"""
READINPUT_CGI = """\
# This is an example of a correctly-written CGI script which reads a body
# from stdin, which only reads env['CONTENT_LENGTH'] bytes.
import os, sys
body_length = int(os***REMOVED***iron.get('CONTENT_LENGTH',0))
indata = sys.stdin.read(body_length)
print("Header: OK")
print("")
print("readinput ok")
"""
READALLINPUT_CGI = """\
# This is an example of the typical (incorrect) CGI script which expects
# the server to close stdin when the body of the request is complete.
# A correct CGI should only read env['CONTENT_LENGTH'] bytes.
import sys
indata = sys.stdin.read()
print("Header: OK")
print("")
print("readallinput ok")
"""
NO_DUPLICATE_CONTENT_TYPE_HEADER_CGI = """\
print("content-type: text/cgi-duplicate-test")
print("")
print("cgi output")
"""
HEADER_OUTPUT_CGI = """\
import json
import os
print("")
print("")
vals = {x:y for x,y in os***REMOVED***iron.items() if x.startswith("HTTP_")}
print(json.dumps(vals))
"""
URL_PARAMETER_CGI = """\
import os
param = str(os***REMOVED***iron['QUERY_STRING'])
print("Header: OK")
print("")
print(param)
"""
class PythonScript(twcgi.FilteredScript):
filter = sys.executable
class _StartServerAndTearDownMixin:
def startServer(self, cgi):
root = resource.Resource()
cgipath = util.sibpath(__file__, cgi)
root.putChild(b"cgi", PythonScript(cgipath))
site = server.Site(root)
self.p = reactor.listenTCP(0, site)
return self.p.getHost().port
def tearDown(self):
if getattr(self, "p", None):
return self.p.stopListening()
def writeCGI(self, source):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, "wt") as cgiFile:
cgiFile.write(source)
return cgiFilename
class CGITests(_StartServerAndTearDownMixin, unittest.TestCase):
"""
Tests for L{twcgi.FilteredScript}.
"""
if not interfaces.IReactorProcess.providedBy(reactor):
skip = "CGI tests require a functional reactor.spawnProcess()"
def test_CGI(self):
cgiFilename = self.writeCGI(DUMMY_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
d = client.Agent(reactor).request(b"GET", url)
d.addCallback(client.readBody)
d.addCallback(self._testCGI_1)
return d
def _testCGI_1(self, res):
self.assertEqual(res, b"cgi output" + os.linesep.encode("ascii"))
def test_protectedServerAndDate(self):
"""
If the CGI script emits a I{Server} or I{Date} header, these are
ignored.
"""
cgiFilename = self.writeCGI(SPECIAL_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
def checkResponse(response):
self.assertNotIn("monkeys", response.headers.getRawHeaders("server"))
self.assertNotIn("last year", response.headers.getRawHeaders("date"))
d.addCallback(checkResponse)
return d
def test_noDuplicateContentTypeHeaders(self):
"""
If the CGI script emits a I{content-type} header, make sure that the
server doesn't add an additional (duplicate) one, as per ticket 4786.
"""
cgiFilename = self.writeCGI(NO_DUPLICATE_CONTENT_TYPE_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
def checkResponse(response):
self.assertEqual(
response.headers.getRawHeaders("content-type"),
["text/cgi-duplicate-test"],
)
return response
d.addCallback(checkResponse)
return d
def test_noProxyPassthrough(self):
"""
The CGI script is never called with the Proxy header passed through.
"""
cgiFilename = self.writeCGI(HEADER_OUTPUT_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
headers = http_headers.Headers(
{b"Proxy": [b"foo"], b"X-Innocent-Header": [b"bar"]}
)
d = agent.request(b"GET", url, headers=headers)
def checkResponse(response):
headers = json.loads(response.decode("ascii"))
self.assertEqual(
set(headers.keys()),
{"HTTP_HOST", "HTTP_CONNECTION", "HTTP_X_INNOCENT_HEADER"},
)
d.addCallback(client.readBody)
d.addCallback(checkResponse)
return d
def test_duplicateHeaderCGI(self):
"""
If a CGI script emits two instances of the same header, both are sent
in the response.
"""
cgiFilename = self.writeCGI(DUAL_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
def checkResponse(response):
self.assertEqual(response.headers.getRawHeaders("header"), ["spam", "eggs"])
d.addCallback(checkResponse)
return d
def test_malformedHeaderCGI(self):
"""
Check for the error message in the duplicated header
"""
cgiFilename = self.writeCGI(BROKEN_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
loggedMessages = []
def addMessage(eventDict):
loggedMessages.append(log.textFromEventDict(eventDict))
log.addObserver(addMessage)
self.addCleanup(log.removeObserver, addMessage)
def checkResponse(ignored):
self.assertIn(
"ignoring malformed CGI header: " + repr(b"XYZ"), loggedMessages
)
d.addCallback(checkResponse)
return d
def test_ReadEmptyInput(self):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, "wt") as cgiFile:
cgiFile.write(READINPUT_CGI)
portnum = self.startServer(cgiFilename)
agent = client.Agent(reactor)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
d = agent.request(b"GET", url)
d.addCallback(client.readBody)
d.addCallback(self._test_ReadEmptyInput_1)
return d
test_ReadEmptyInput.timeout = 5 # type: ignore[attr-defined]
def _test_ReadEmptyInput_1(self, res):
expected = f"readinput ok{os.linesep}"
expected = expected.encode("ascii")
self.assertEqual(res, expected)
def test_ReadInput(self):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, "wt") as cgiFile:
cgiFile.write(READINPUT_CGI)
portnum = self.startServer(cgiFilename)
agent = client.Agent(reactor)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
d = agent.request(
uri=url,
method=b"POST",
bodyProducer=client.FileBodyProducer(BytesIO(b"Here is your stdin")),
)
d.addCallback(client.readBody)
d.addCallback(self._test_ReadInput_1)
return d
test_ReadInput.timeout = 5 # type: ignore[attr-defined]
def _test_ReadInput_1(self, res):
expected = f"readinput ok{os.linesep}"
expected = expected.encode("ascii")
self.assertEqual(res, expected)
def test_ReadAllInput(self):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, "wt") as cgiFile:
cgiFile.write(READALLINPUT_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
d = client.Agent(reactor).request(
uri=url,
method=b"POST",
bodyProducer=client.FileBodyProducer(BytesIO(b"Here is your stdin")),
)
d.addCallback(client.readBody)
d.addCallback(self._test_ReadAllInput_1)
return d
test_ReadAllInput.timeout = 5 # type: ignore[attr-defined]
def _test_ReadAllInput_1(self, res):
expected = f"readallinput ok{os.linesep}"
expected = expected.encode("ascii")
self.assertEqual(res, expected)
def test_useReactorArgument(self):
"""
L{twcgi.FilteredScript.runProcess} uses the reactor passed as an
argument to the constructor.
"""
class FakeReactor:
"""
A fake reactor recording whether spawnProcess is called.
"""
called = False
def spawnProcess(self, *args, **kwargs):
"""
Set the C{called} flag to C{True} if C{spawnProcess} is called.
@param args: Positional arguments.
@param kwargs: Keyword arguments.
"""
self.called = True
fakeReactor = FakeReactor()
request = DummyRequest(["a", "b"])
request.client = address.IPv4Address("TCP", "127.0.0.1", 12345)
resource = twcgi.FilteredScript("dummy-file", reactor=fakeReactor)
_render(resource, request)
self.assertTrue(fakeReactor.called)
class CGIScriptTests(_StartServerAndTearDownMixin, unittest.TestCase):
"""
Tests for L{twcgi.CGIScript}.
"""
def test_urlParameters(self):
"""
If the CGI script is passed URL parameters, do not fall over,
as per ticket 9887.
"""
cgiFilename = self.writeCGI(URL_PARAMETER_CGI)
portnum = self.startServer(cgiFilename)
url = b"http://localhost:%d/cgi?param=1234" % (portnum,)
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(client.readBody)
d.addCallback(self._test_urlParameters_1)
return d
def _test_urlParameters_1(self, res):
expected = f"param=1234{os.linesep}"
expected = expected.encode("ascii")
self.assertEqual(res, expected)
def test_pathInfo(self):
"""
L{twcgi.CGIScript.render} sets the process environment
I{PATH_INFO} from the request path.
"""
class FakeReactor:
"""
A fake reactor recording the environment passed to spawnProcess.
"""
def spawnProcess(self, process, filename, args, env, wdir):
"""
Store the C{env} L{dict} to an instance attribute.
@param process: Ignored
@param filename: Ignored
@param args: Ignored
@param env: The environment L{dict} which will be stored
@param wdir: Ignored
"""
self.process_env = env
_reactor = FakeReactor()
resource = twcgi.CGIScript(self.mktemp(), reactor=_reactor)
request = DummyRequest(["a", "b"])
request.client = address.IPv4Address("TCP", "127.0.0.1", 12345)
_render(resource, request)
self.assertEqual(_reactor.process_env["PATH_INFO"], "/a/b")
class CGIDirectoryTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIDirectory}.
"""
def test_render(self):
"""
L{twcgi.CGIDirectory.render} sets the HTTP response code to I{NOT
FOUND}.
"""
resource = twcgi.CGIDirectory(self.mktemp())
request = DummyRequest([""])
d = _render(resource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_notFoundChild(self):
"""
L{twcgi.CGIDirectory.getChild} returns a resource which renders an
response with the HTTP I{NOT FOUND} status code if the indicated child
does not exist as an entry in the directory used to initialized the
L{twcgi.CGIDirectory}.
"""
path = self.mktemp()
os.makedirs(path)
resource = twcgi.CGIDirectory(path)
request = DummyRequest(["foo"])
child = resource.getChild("foo", request)
d = _render(child, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
class CGIProcessProtocolTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIProcessProtocol}.
"""
def test_prematureEndOfHeaders(self):
"""
If the process communicating with L{CGIProcessProtocol} ends before
finishing writing out headers, the response has I{INTERNAL SERVER
ERROR} as its status code.
"""
request = DummyRequest([""])
protocol = twcgi.CGIProcessProtocol(request)
protocol.processEnded(failure.Failure(error.ProcessTerminated()))
self.assertEqual(request.responseCode, INTERNAL_SERVER_ERROR)
def test_connectionLost(self):
"""
Ensure that the CGI process ends cleanly when the request connection
is lost.
"""
d = DummyChannel()
request = http.Request(d, True)
protocol = twcgi.CGIProcessProtocol(request)
request.connectionLost(failure.Failure(ConnectionLost("Connection done")))
protocol.processEnded(failure.Failure(error.ProcessTerminated()))
def discardBody(response):
"""
Discard the body of a HTTP response.
@param response: The response.
@return: The response.
"""
return client.readBody(response).addCallback(lambda _: response)

View File

@@ -0,0 +1,50 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for various parts of L{twisted.web}.
"""
from __future__ import annotations
from zope.interface import implementer, verify
from twisted.internet import defer, interfaces
from twisted.trial import unittest
from twisted.web import client
@implementer(interfaces.IStreamClientEndpoint)
class DummyEndPoint:
"""An endpoint that does not connect anywhere"""
def __init__(self, someString: str) -> None:
self.someString = someString
def __repr__(self) -> str:
return f"DummyEndPoint({self.someString})"
def connect( # type: ignore[override]
self, factory: interfaces.IProtocolFactory
) -> defer.Deferred[dict[str, interfaces.IProtocolFactory]]:
return defer.succeed(dict(factory=factory))
class HTTPConnectionPoolTests(unittest.TestCase):
"""
Unit tests for L{client.HTTPConnectionPoolTest}.
"""
def test_implements(self) -> None:
"""L{DummyEndPoint}s implements L{interfaces.IStreamClientEndpoint}"""
ep = DummyEndPoint("something")
verify.verifyObject(interfaces.IStreamClientEndpoint, ep)
def test_repr(self) -> None:
"""connection L{repr()} includes endpoint's L{repr()}"""
pool = client.HTTPConnectionPool(reactor=None)
ep = DummyEndPoint("this_is_probably_unique")
d = pool.getConnection("someplace", ep)
result = self.successResultOf(d)
representation = repr(result)
self.assertIn(repr(ep), representation)

View File

@@ -0,0 +1,502 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.distrib}.
"""
from os.path import abspath
from xml.dom.minidom import parseString
try:
import pwd as _pwd
except ImportError:
pwd = None
else:
pwd = _pwd
from unittest import skipIf
from zope.interface.verify import verifyObject
from twisted.internet import defer, reactor
from twisted.logger import globalLogPublisher
from twisted.python import failure, filepath
from twisted.spread import pb
from twisted.spread.banana import SIZE_LIMIT
from twisted.test import proto_helpers
from twisted.trial.unittest import TestCase
from twisted.web import client, distrib, resource, server, static
from twisted.web.http_headers import Headers
from twisted.web.test._util import _render
from twisted.web.test.requesthelper import DummyChannel, DummyRequest
class MySite(server.Site):
pass
class PBServerFactory(pb.PBServerFactory):
"""
A PB server factory which keeps track of the most recent protocol it
created.
@ivar proto: L{None} or the L{Broker} instance most recently returned
from C{buildProtocol}.
"""
proto = None
def buildProtocol(self, addr):
self.proto = pb.PBServerFactory.buildProtocol(self, addr)
return self.proto
class ArbitraryError(Exception):
"""
An exception for this test.
"""
class DistribTests(TestCase):
port1 = None
port2 = None
sub = None
f1 = None
def tearDown(self):
"""
Clean up all the event sources left behind by either directly by
test methods or indirectly via some distrib API.
"""
dl = [defer.Deferred(), defer.Deferred()]
if self.f1 is not None and self.f1.proto is not None:
self.f1.proto.notifyOnDisconnect(lambda: dl[0].callback(None))
else:
dl[0].callback(None)
if self.sub is not None and self.sub.publisher is not None:
self.sub.publisher.broker.notifyOnDisconnect(lambda: dl[1].callback(None))
self.sub.publisher.broker.transport.loseConnection()
else:
dl[1].callback(None)
if self.port1 is not None:
dl.append(self.port1.stopListening())
if self.port2 is not None:
dl.append(self.port2.stopListening())
return defer.gatherResults(dl)
def testDistrib(self):
# site1 is the publisher
r1 = resource.Resource()
r1.putChild(b"there", static.Data(b"root", "text/plain"))
site1 = server.Site(r1)
self.f1 = PBServerFactory(distrib.ResourcePublisher(site1))
self.port1 = reactor.listenTCP(0, self.f1)
self.sub = distrib.ResourceSubscription("127.0.0.1", self.port1.getHost().port)
r2 = resource.Resource()
r2.putChild(b"here", self.sub)
f2 = MySite(r2)
self.port2 = reactor.listenTCP(0, f2)
agent = client.Agent(reactor)
url = f"http://127.0.0.1:{self.port2.getHost().port}/here/there"
url = url.encode("ascii")
d = agent.request(b"GET", url)
d.addCallback(client.readBody)
d.addCallback(self.assertEqual, b"root")
return d
def _setupDistribServer(self, child):
"""
Set up a resource on a distrib site using L{ResourcePublisher}.
@param child: The resource to publish using distrib.
@return: A tuple consisting of the host and port on which to contact
the created site.
"""
distribRoot = resource.Resource()
distribRoot.putChild(b"child", child)
distribSite = server.Site(distribRoot)
self.f1 = distribFactory = PBServerFactory(
distrib.ResourcePublisher(distribSite)
)
distribPort = reactor.listenTCP(0, distribFactory, interface="127.0.0.1")
self.addCleanup(distribPort.stopListening)
addr = distribPort.getHost()
self.sub = mainRoot = distrib.ResourceSubscription(addr.host, addr.port)
mainSite = server.Site(mainRoot)
mainPort = reactor.listenTCP(0, mainSite, interface="127.0.0.1")
self.addCleanup(mainPort.stopListening)
mainAddr = mainPort.getHost()
return mainPort, mainAddr
def _requestTest(self, child, **kwargs):
"""
Set up a resource on a distrib site using L{ResourcePublisher} and
then retrieve it from a L{ResourceSubscription} via an HTTP client.
@param child: The resource to publish using distrib.
@param **kwargs: Extra keyword arguments to pass to L{Agent.request} when
requesting the resource.
@return: A L{Deferred} which fires with the result of the request.
"""
mainPort, mainAddr = self._setupDistribServer(child)
agent = client.Agent(reactor)
url = f"http://{mainAddr.host}:{mainAddr.port}/child"
url = url.encode("ascii")
d = agent.request(b"GET", url, **kwargs)
d.addCallback(client.readBody)
return d
def _requestAgentTest(self, child, **kwargs):
"""
Set up a resource on a distrib site using L{ResourcePublisher} and
then retrieve it from a L{ResourceSubscription} via an HTTP client.
@param child: The resource to publish using distrib.
@param **kwargs: Extra keyword arguments to pass to L{Agent.request} when
requesting the resource.
@return: A L{Deferred} which fires with a tuple consisting of a
L{twisted.test.proto_helpers.AccumulatingProtocol} containing the
body of the response and an L{IResponse} with the response itself.
"""
mainPort, mainAddr = self._setupDistribServer(child)
url = f"http://{mainAddr.host}:{mainAddr.port}/child"
url = url.encode("ascii")
d = client.Agent(reactor).request(b"GET", url, **kwargs)
def cbCollectBody(response):
protocol = proto_helpers.AccumulatingProtocol()
response.deliverBody(protocol)
d = protocol.closedDeferred = defer.Deferred()
d.addCallback(lambda _: (protocol, response))
return d
d.addCallback(cbCollectBody)
return d
def test_requestHeaders(self):
"""
The request headers are available on the request object passed to a
distributed resource's C{render} method.
"""
requestHeaders = {}
logObserver = proto_helpers.EventLoggingObserver()
globalLogPublisher.addObserver(logObserver)
req = [None]
class ReportRequestHeaders(resource.Resource):
def render(self, request):
req[0] = request
requestHeaders.update(dict(request.requestHeaders.getAllRawHeaders()))
return b""
def check_logs():
msgs = [e["log_format"] for e in logObserver]
self.assertIn("connected to publisher", msgs)
self.assertIn("could not connect to distributed web service: {msg}", msgs)
self.assertIn(req[0], msgs)
globalLogPublisher.removeObserver(logObserver)
request = self._requestTest(
ReportRequestHeaders(), headers=Headers({"foo": ["bar"]})
)
def cbRequested(result):
self.f1.proto.notifyOnDisconnect(check_logs)
self.assertEqual(requestHeaders[b"Foo"], [b"bar"])
request.addCallback(cbRequested)
return request
def test_requestResponseCode(self):
"""
The response code can be set by the request object passed to a
distributed resource's C{render} method.
"""
class SetResponseCode(resource.Resource):
def render(self, request):
request.setResponseCode(200)
return ""
request = self._requestAgentTest(SetResponseCode())
def cbRequested(result):
self.assertEqual(result[0].data, b"")
self.assertEqual(result[1].code, 200)
self.assertEqual(result[1].phrase, b"OK")
request.addCallback(cbRequested)
return request
def test_requestResponseCodeMessage(self):
"""
The response code and message can be set by the request object passed to
a distributed resource's C{render} method.
"""
class SetResponseCode(resource.Resource):
def render(self, request):
request.setResponseCode(200, b"some-message")
return ""
request = self._requestAgentTest(SetResponseCode())
def cbRequested(result):
self.assertEqual(result[0].data, b"")
self.assertEqual(result[1].code, 200)
self.assertEqual(result[1].phrase, b"some-message")
request.addCallback(cbRequested)
return request
def test_largeWrite(self):
"""
If a string longer than the Banana size limit is passed to the
L{distrib.Request} passed to the remote resource, it is broken into
smaller strings to be transported over the PB connection.
"""
class LargeWrite(resource.Resource):
def render(self, request):
request.write(b"x" * SIZE_LIMIT + b"y")
request.finish()
return server.NOT_DONE_YET
request = self._requestTest(LargeWrite())
request.addCallback(self.assertEqual, b"x" * SIZE_LIMIT + b"y")
return request
def test_largeReturn(self):
"""
Like L{test_largeWrite}, but for the case where C{render} returns a
long string rather than explicitly passing it to L{Request.write}.
"""
class LargeReturn(resource.Resource):
def render(self, request):
return b"x" * SIZE_LIMIT + b"y"
request = self._requestTest(LargeReturn())
request.addCallback(self.assertEqual, b"x" * SIZE_LIMIT + b"y")
return request
def test_connectionLost(self):
"""
If there is an error issuing the request to the remote publisher, an
error response is returned.
"""
# Using pb.Root as a publisher will cause request calls to fail with an
# error every time. Just what we want to test.
self.f1 = serverFactory = PBServerFactory(pb.Root())
self.port1 = serverPort = reactor.listenTCP(0, serverFactory)
self.sub = subscription = distrib.ResourceSubscription(
"127.0.0.1", serverPort.getHost().port
)
request = DummyRequest([b""])
d = _render(subscription, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, 500)
# This is the error we caused the request to fail with. It should
# have been logged.
errors = self.flushLoggedErrors(pb.NoSuchMethod)
self.assertEqual(len(errors), 1)
# The error page is rendered as HTML.
expected = [
b"",
b"<html>",
b" <head><title>500 - Server Connection Lost</title></head>",
b" <body>",
b" <h1>Server Connection Lost</h1>",
b" <p>Connection to distributed server lost:"
b"<pre>"
b"[Failure instance: Traceback from remote host -- "
b"twisted.spread.flavors.NoSuchMethod: "
b"No such method: remote_request",
b"]</pre></p>",
b" </body>",
b"</html>",
b"",
]
self.assertEqual([b"\n".join(expected)], request.written)
d.addCallback(cbRendered)
return d
def test_logFailed(self):
"""
When a request fails, the string form of the failure is logged.
"""
logObserver = proto_helpers.EventLoggingObserver.createWithCleanup(
self, globalLogPublisher
)
f = failure.Failure(ArbitraryError())
request = DummyRequest([b""])
issue = distrib.Issue(request)
issue.failed(f)
self.assertEquals(1, len(logObserver))
self.assertIn("Failure instance", logObserver[0]["log_format"])
def test_requestFail(self):
"""
When L{twisted.web.distrib.Request}'s fail is called, the failure
is logged.
"""
logObserver = proto_helpers.EventLoggingObserver.createWithCleanup(
self, globalLogPublisher
)
err = ArbitraryError()
f = failure.Failure(err)
req = distrib.Request(DummyChannel())
req.fail(f)
self.flushLoggedErrors(ArbitraryError)
self.assertEquals(1, len(logObserver))
self.assertIs(logObserver[0]["log_failure"], f)
class _PasswordDatabase:
def __init__(self, users):
self._users = users
def getpwall(self):
return iter(self._users)
def getpwnam(self, username):
for user in self._users:
if user[0] == username:
return user
raise KeyError()
class UserDirectoryTests(TestCase):
"""
Tests for L{UserDirectory}, a resource for listing all user resources
available on a system.
"""
def setUp(self):
self.alice = ("alice", "x", 123, 456, "Alice,,,", self.mktemp(), "/bin/sh")
self.bob = ("bob", "x", 234, 567, "Bob,,,", self.mktemp(), "/bin/sh")
self.database = _PasswordDatabase([self.alice, self.bob])
self.directory = distrib.UserDirectory(self.database)
def test_interface(self):
"""
L{UserDirectory} instances provide L{resource.IResource}.
"""
self.assertTrue(verifyObject(resource.IResource, self.directory))
async def _404Test(self, name: bytes) -> None:
"""
Verify that requesting the C{name} child of C{self.directory} results
in a 404 response.
"""
request = DummyRequest([name])
result = self.directory.getChild(name, request)
d = _render(result, request)
await d
self.assertEqual(request.responseCode, 404)
async def test_getInvalidUser(self):
"""
L{UserDirectory.getChild} returns a resource which renders a 404
response when passed a string which does not correspond to any known
user.
"""
await self._404Test(b"carol")
async def test_getUserWithoutResource(self):
"""
L{UserDirectory.getChild} returns a resource which renders a 404
response when passed a string which corresponds to a known user who has
neither a user directory nor a user distrib socket.
"""
await self._404Test(b"alice")
def test_getPublicHTMLChild(self):
"""
L{UserDirectory.getChild} returns a L{static.File} instance when passed
the name of a user with a home directory containing a I{public_html}
directory.
"""
home = filepath.FilePath(self.bob[-2])
public_html = home.child("public_html")
public_html.makedirs()
request = DummyRequest(["bob"])
result = self.directory.getChild(b"bob", request)
self.assertIsInstance(result, static.File)
self.assertEqual(result.path, public_html.path)
def test_getDistribChild(self):
"""
L{UserDirectory.getChild} returns a L{ResourceSubscription} instance
when passed the name of a user suffixed with C{".twistd"} who has a
home directory containing a I{.twistd-web-pb} socket.
"""
home = filepath.FilePath(self.bob[-2])
home.makedirs()
web = home.child(".twistd-web-pb")
request = DummyRequest(["bob"])
result = self.directory.getChild(b"bob.twistd", request)
self.assertIsInstance(result, distrib.ResourceSubscription)
self.assertEqual(result.host, "unix")
self.assertEqual(abspath(result.port), web.path)
def test_invalidMethod(self):
"""
L{UserDirectory.render} raises L{UnsupportedMethod} in response to a
non-I{GET} request.
"""
request = DummyRequest([""])
request.method = "POST"
self.assertRaises(server.UnsupportedMethod, self.directory.render, request)
def test_render(self):
"""
L{UserDirectory} renders a list of links to available user content
in response to a I{GET} request.
"""
public_html = filepath.FilePath(self.alice[-2]).child("public_html")
public_html.makedirs()
web = filepath.FilePath(self.bob[-2])
web.makedirs()
# This really only works if it's a unix socket, but the implementation
# doesn't currently check for that. It probably should someday, and
# then skip users with non-sockets.
web.child(".twistd-web-pb").setContent(b"")
request = DummyRequest([""])
result = _render(self.directory, request)
def cbRendered(ignored):
document = parseString(b"".join(request.written))
# Each user should have an li with a link to their page.
[alice, bob] = document.getElementsByTagName("li")
self.assertEqual(alice.firstChild.tagName, "a")
self.assertEqual(alice.firstChild.getAttribute("href"), "alice/")
self.assertEqual(alice.firstChild.firstChild.data, "Alice (file)")
self.assertEqual(bob.firstChild.tagName, "a")
self.assertEqual(bob.firstChild.getAttribute("href"), "bob.twistd/")
self.assertEqual(bob.firstChild.firstChild.data, "Bob (twistd)")
result.addCallback(cbRendered)
return result
@skipIf(not pwd, "pwd module required")
def test_passwordDatabase(self):
"""
If L{UserDirectory} is instantiated with no arguments, it uses the
L{pwd} module as its password database.
"""
directory = distrib.UserDirectory()
self.assertIdentical(directory._pwd, pwd)

View File

@@ -0,0 +1,305 @@
# -*- test-case-name: twisted.web.test.test_domhelpers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Specific tests for (some of) the methods in L{twisted.web.domhelpers}.
"""
from importlib import reload
from typing import Any, Optional
from xml.dom import minidom
from twisted.trial.unittest import TestCase
from twisted.web import domhelpers, microdom
class DOMHelpersTestsMixin:
"""
A mixin for L{TestCase} subclasses which defines test methods for
domhelpers functionality based on a DOM creation function provided by a
subclass.
"""
dom: Optional[Any] = None
def test_getElementsByTagName(self):
doc1 = self.dom.parseString("<foo/>")
actual = domhelpers.getElementsByTagName(doc1, "foo")[0].nodeName
expected = "foo"
self.assertEqual(actual, expected)
el1 = doc1.documentElement
actual = domhelpers.getElementsByTagName(el1, "foo")[0].nodeName
self.assertEqual(actual, expected)
doc2_xml = '<a><foo in="a"/><b><foo in="b"/></b><c><foo in="c"/></c><foo in="d"/><foo in="ef"/><g><foo in="g"/><h><foo in="h"/></h></g></a>'
doc2 = self.dom.parseString(doc2_xml)
tag_list = domhelpers.getElementsByTagName(doc2, "foo")
actual = "".join([node.getAttribute("in") for node in tag_list])
expected = "abcdefgh"
self.assertEqual(actual, expected)
el2 = doc2.documentElement
tag_list = domhelpers.getElementsByTagName(el2, "foo")
actual = "".join([node.getAttribute("in") for node in tag_list])
self.assertEqual(actual, expected)
doc3_xml = """
<a><foo in="a"/>
<b><foo in="b"/>
<d><foo in="d"/>
<g><foo in="g"/></g>
<h><foo in="h"/></h>
</d>
<e><foo in="e"/>
<i><foo in="i"/></i>
</e>
</b>
<c><foo in="c"/>
<f><foo in="f"/>
<j><foo in="j"/></j>
</f>
</c>
</a>"""
doc3 = self.dom.parseString(doc3_xml)
tag_list = domhelpers.getElementsByTagName(doc3, "foo")
actual = "".join([node.getAttribute("in") for node in tag_list])
expected = "abdgheicfj"
self.assertEqual(actual, expected)
el3 = doc3.documentElement
tag_list = domhelpers.getElementsByTagName(el3, "foo")
actual = "".join([node.getAttribute("in") for node in tag_list])
self.assertEqual(actual, expected)
doc4_xml = "<foo><bar></bar><baz><foo/></baz></foo>"
doc4 = self.dom.parseString(doc4_xml)
actual = domhelpers.getElementsByTagName(doc4, "foo")
root = doc4.documentElement
expected = [root, root.childNodes[-1].childNodes[0]]
self.assertEqual(actual, expected)
actual = domhelpers.getElementsByTagName(root, "foo")
self.assertEqual(actual, expected)
def test_gatherTextNodes(self):
doc1 = self.dom.parseString("<a>foo</a>")
actual = domhelpers.gatherTextNodes(doc1)
expected = "foo"
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc1.documentElement)
self.assertEqual(actual, expected)
doc2_xml = "<a>a<b>b</b><c>c</c>def<g>g<h>h</h></g></a>"
doc2 = self.dom.parseString(doc2_xml)
actual = domhelpers.gatherTextNodes(doc2)
expected = "abcdefgh"
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc2.documentElement)
self.assertEqual(actual, expected)
doc3_xml = (
"<a>a<b>b<d>d<g>g</g><h>h</h></d><e>e<i>i</i></e></b>"
+ "<c>c<f>f<j>j</j></f></c></a>"
)
doc3 = self.dom.parseString(doc3_xml)
actual = domhelpers.gatherTextNodes(doc3)
expected = "abdgheicfj"
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc3.documentElement)
self.assertEqual(actual, expected)
def test_clearNode(self):
doc1 = self.dom.parseString("<a><b><c><d/></c></b></a>")
a_node = doc1.documentElement
domhelpers.clearNode(a_node)
self.assertEqual(a_node.toxml(), self.dom.Element("a").toxml())
doc2 = self.dom.parseString("<a><b><c><d/></c></b></a>")
b_node = doc2.documentElement.childNodes[0]
domhelpers.clearNode(b_node)
actual = doc2.documentElement.toxml()
expected = self.dom.Element("a")
expected.appendChild(self.dom.Element("b"))
self.assertEqual(actual, expected.toxml())
def test_get(self):
doc1 = self.dom.parseString('<a><b id="bar"/><c class="foo"/></a>')
doc = self.dom.Document()
node = domhelpers.get(doc1, "foo")
actual = node.toxml()
expected = doc.createElement("c")
expected.setAttribute("class", "foo")
self.assertEqual(actual, expected.toxml())
node = domhelpers.get(doc1, "bar")
actual = node.toxml()
expected = doc.createElement("b")
expected.setAttribute("id", "bar")
self.assertEqual(actual, expected.toxml())
self.assertRaises(domhelpers.NodeLookupError, domhelpers.get, doc1, "pzork")
def test_getIfExists(self):
doc1 = self.dom.parseString('<a><b id="bar"/><c class="foo"/></a>')
doc = self.dom.Document()
node = domhelpers.getIfExists(doc1, "foo")
actual = node.toxml()
expected = doc.createElement("c")
expected.setAttribute("class", "foo")
self.assertEqual(actual, expected.toxml())
node = domhelpers.getIfExists(doc1, "pzork")
self.assertIdentical(node, None)
def test_getAndClear(self):
doc1 = self.dom.parseString('<a><b id="foo"><c></c></b></a>')
doc = self.dom.Document()
node = domhelpers.getAndClear(doc1, "foo")
actual = node.toxml()
expected = doc.createElement("b")
expected.setAttribute("id", "foo")
self.assertEqual(actual, expected.toxml())
def test_locateNodes(self):
doc1 = self.dom.parseString(
'<a><b foo="olive"><c foo="olive"/></b><d foo="poopy"/></a>'
)
doc = self.dom.Document()
node_list = domhelpers.locateNodes(doc1.childNodes, "foo", "olive", noNesting=1)
actual = "".join([node.toxml() for node in node_list])
expected = doc.createElement("b")
expected.setAttribute("foo", "olive")
c = doc.createElement("c")
c.setAttribute("foo", "olive")
expected.appendChild(c)
self.assertEqual(actual, expected.toxml())
node_list = domhelpers.locateNodes(doc1.childNodes, "foo", "olive", noNesting=0)
actual = "".join([node.toxml() for node in node_list])
self.assertEqual(actual, expected.toxml() + c.toxml())
def test_getParents(self):
doc1 = self.dom.parseString("<a><b><c><d/></c><e/></b><f/></a>")
node_list = domhelpers.getParents(
doc1.childNodes[0].childNodes[0].childNodes[0]
)
actual = "".join(
[node.tagName for node in node_list if hasattr(node, "tagName")]
)
self.assertEqual(actual, "cba")
def test_findElementsWithAttribute(self):
doc1 = self.dom.parseString('<a foo="1"><b foo="2"/><c foo="1"/><d/></a>')
node_list = domhelpers.findElementsWithAttribute(doc1, "foo")
actual = "".join([node.tagName for node in node_list])
self.assertEqual(actual, "abc")
node_list = domhelpers.findElementsWithAttribute(doc1, "foo", "1")
actual = "".join([node.tagName for node in node_list])
self.assertEqual(actual, "ac")
def test_findNodesNamed(self):
doc1 = self.dom.parseString("<doc><foo/><bar/><foo>a</foo></doc>")
node_list = domhelpers.findNodesNamed(doc1, "foo")
actual = len(node_list)
self.assertEqual(actual, 2)
def test_escape(self):
j = "this string \" contains many & characters> xml< won't like"
expected = (
"this string &quot; contains many &amp; characters&gt; xml&lt; won't like"
)
self.assertEqual(domhelpers.escape(j), expected)
def test_unescape(self):
j = "this string &quot; has &&amp; entities &gt; &lt; and some characters xml won't like<"
expected = (
"this string \" has && entities > < and some characters xml won't like<"
)
self.assertEqual(domhelpers.unescape(j), expected)
def test_getNodeText(self):
"""
L{getNodeText} returns the concatenation of all the text data at or
beneath the node passed to it.
"""
node = self.dom.parseString("<foo><bar>baz</bar><bar>quux</bar></foo>")
self.assertEqual(domhelpers.getNodeText(node), "bazquux")
class MicroDOMHelpersTests(DOMHelpersTestsMixin, TestCase):
dom = microdom
def test_gatherTextNodesDropsWhitespace(self):
"""
Microdom discards whitespace-only text nodes, so L{gatherTextNodes}
returns only the text from nodes which had non-whitespace characters.
"""
doc4_xml = """<html>
<head>
</head>
<body>
stuff
</body>
</html>
"""
doc4 = self.dom.parseString(doc4_xml)
actual = domhelpers.gatherTextNodes(doc4)
expected = "\n stuff\n "
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc4.documentElement)
self.assertEqual(actual, expected)
def test_textEntitiesNotDecoded(self):
"""
Microdom does not decode entities in text nodes.
"""
doc5_xml = "<x>Souffl&amp;</x>"
doc5 = self.dom.parseString(doc5_xml)
actual = domhelpers.gatherTextNodes(doc5)
expected = "Souffl&amp;"
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc5.documentElement)
self.assertEqual(actual, expected)
def test_deprecation(self):
"""
An import will raise the deprecation warning.
"""
reload(domhelpers)
warnings = self.flushWarnings([self.test_deprecation])
self.assertEqual(1, len(warnings))
self.assertEqual(
"twisted.web.domhelpers was deprecated at Twisted 23.10.0",
warnings[0]["message"],
)
class MiniDOMHelpersTests(DOMHelpersTestsMixin, TestCase):
dom = minidom
def test_textEntitiesDecoded(self):
"""
Minidom does decode entities in text nodes.
"""
doc5_xml = "<x>Souffl&amp;</x>"
doc5 = self.dom.parseString(doc5_xml)
actual = domhelpers.gatherTextNodes(doc5)
expected = "Souffl&"
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc5.documentElement)
self.assertEqual(actual, expected)
def test_getNodeUnicodeText(self):
"""
L{domhelpers.getNodeText} returns a C{unicode} string when text
nodes are represented in the DOM with unicode, whether or not there
are non-ASCII characters present.
"""
node = self.dom.parseString("<foo>bar</foo>")
text = domhelpers.getNodeText(node)
self.assertEqual(text, "bar")
self.assertIsInstance(text, str)
node = self.dom.parseString("<foo>\N{SNOWMAN}</foo>".encode())
text = domhelpers.getNodeText(node)
self.assertEqual(text, "\N{SNOWMAN}")
self.assertIsInstance(text, str)

View File

@@ -0,0 +1,477 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTTP errors.
"""
from __future__ import annotations
import re
import sys
import traceback
from twisted.python.compat import nativeString
from twisted.trial import unittest
from twisted.web import error
from twisted.web.template import Tag
class CodeToMessageTests(unittest.TestCase):
"""
L{_codeToMessages} inverts L{_responses.RESPONSES}
"""
def test_validCode(self) -> None:
m = error._codeToMessage(b"302")
self.assertEqual(m, b"Found")
def test_invalidCode(self) -> None:
m = error._codeToMessage(b"987")
self.assertEqual(m, None)
def test_nonintegerCode(self) -> None:
m = error._codeToMessage(b"InvalidCode")
self.assertEqual(m, None)
class ErrorTests(unittest.TestCase):
"""
Tests for how L{Error} attributes are initialized.
"""
def test_noMessageValidStatus(self) -> None:
"""
If no C{message} argument is passed to the L{Error} constructor and the
C{code} argument is a valid HTTP status code, C{message} is set to the
HTTP reason phrase for C{code}.
"""
e = error.Error(b"200")
self.assertEqual(e.message, b"OK")
self.assertEqual(str(e), "200 OK")
def test_noMessageForStatus(self) -> None:
"""
If no C{message} argument is passed to the L{Error} constructor and
C{code} isn't a known HTTP status code, C{message} stays L{None}.
"""
e = error.Error(b"999")
self.assertEqual(e.message, None)
self.assertEqual(str(e), "999")
def test_invalidStatus(self) -> None:
"""
If C{code} isn't plausibly an HTTP status code (i.e., composed of
digits) it is rejected with L{ValueError}.
"""
with self.assertRaises(ValueError):
error.Error(b"InvalidStatus")
def test_messageExists(self) -> None:
"""
If a C{message} argument is passed to the L{Error} constructor, the
C{message} isn't affected by the value of C{status}.
"""
e = error.Error(b"200", b"My own message")
self.assertEqual(e.message, b"My own message")
self.assertEqual(str(e), "200 My own message")
def test_str(self) -> None:
"""
C{str()} on an L{Error} returns the code and message it was
instantiated with.
"""
# Bytestring status
e = error.Error(b"200", b"OK")
self.assertEqual(str(e), "200 OK")
# int status
e = error.Error(200, b"OK")
self.assertEqual(str(e), "200 OK")
class PageRedirectTests(unittest.TestCase):
"""
Tests for how L{PageRedirect} attributes are initialized.
"""
def test_noMessageValidStatus(self) -> None:
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and the C{code} argument is a valid HTTP status code, C{code} is mapped
to a descriptive string to which C{message} is assigned.
"""
e = error.PageRedirect(b"200", location=b"/foo")
self.assertEqual(e.message, b"OK to /foo")
def test_noMessageValidStatusNoLocation(self) -> None:
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and C{location} is also empty and the C{code} argument is a valid HTTP
status code, C{code} is mapped to a descriptive string to which
C{message} is assigned without trying to include an empty location.
"""
e = error.PageRedirect(b"200")
self.assertEqual(e.message, b"OK")
def test_noMessageInvalidStatusLocationExists(self) -> None:
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and C{code} isn't a valid HTTP status code, C{message} stays L{None}.
"""
e = error.PageRedirect(b"999", location=b"/foo")
self.assertEqual(e.message, None)
def test_messageExistsLocationExists(self) -> None:
"""
If a C{message} argument is passed to the L{PageRedirect} constructor,
the C{message} isn't affected by the value of C{status}.
"""
e = error.PageRedirect(b"200", b"My own message", location=b"/foo")
self.assertEqual(e.message, b"My own message to /foo")
def test_messageExistsNoLocation(self) -> None:
"""
If a C{message} argument is passed to the L{PageRedirect} constructor
and no location is provided, C{message} doesn't try to include the
empty location.
"""
e = error.PageRedirect(b"200", b"My own message")
self.assertEqual(e.message, b"My own message")
class InfiniteRedirectionTests(unittest.TestCase):
"""
Tests for how L{InfiniteRedirection} attributes are initialized.
"""
def test_noMessageValidStatus(self) -> None:
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and the C{code} argument is a valid HTTP status code,
C{code} is mapped to a descriptive string to which C{message} is
assigned.
"""
e = error.InfiniteRedirection(b"200", location=b"/foo")
self.assertEqual(e.message, b"OK to /foo")
def test_noMessageValidStatusNoLocation(self) -> None:
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and C{location} is also empty and the C{code} argument is a
valid HTTP status code, C{code} is mapped to a descriptive string to
which C{message} is assigned without trying to include an empty
location.
"""
e = error.InfiniteRedirection(b"200")
self.assertEqual(e.message, b"OK")
def test_noMessageInvalidStatusLocationExists(self) -> None:
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and C{code} isn't a valid HTTP status code, C{message} stays
L{None}.
"""
e = error.InfiniteRedirection(b"999", location=b"/foo")
self.assertEqual(e.message, None)
self.assertEqual(str(e), "999")
def test_messageExistsLocationExists(self) -> None:
"""
If a C{message} argument is passed to the L{InfiniteRedirection}
constructor, the C{message} isn't affected by the value of C{status}.
"""
e = error.InfiniteRedirection(b"200", b"My own message", location=b"/foo")
self.assertEqual(e.message, b"My own message to /foo")
def test_messageExistsNoLocation(self) -> None:
"""
If a C{message} argument is passed to the L{InfiniteRedirection}
constructor and no location is provided, C{message} doesn't try to
include the empty location.
"""
e = error.InfiniteRedirection(b"200", b"My own message")
self.assertEqual(e.message, b"My own message")
class RedirectWithNoLocationTests(unittest.TestCase):
"""
L{RedirectWithNoLocation} is a subclass of L{Error} which sets
a custom message in the constructor.
"""
def test_validMessage(self) -> None:
"""
When C{code}, C{message}, and C{uri} are passed to the
L{RedirectWithNoLocation} constructor, the C{message} and C{uri}
attributes are set, respectively.
"""
e = error.RedirectWithNoLocation(b"302", b"REDIRECT", b"https://example.com")
self.assertEqual(e.message, b"REDIRECT to https://example.com")
self.assertEqual(e.uri, b"https://example.com")
class MissingRenderMethodTests(unittest.TestCase):
"""
Tests for how L{MissingRenderMethod} exceptions are initialized and
displayed.
"""
def test_constructor(self) -> None:
"""
Given C{element} and C{renderName} arguments, the
L{MissingRenderMethod} constructor assigns the values to the
corresponding attributes.
"""
elt = object()
e = error.MissingRenderMethod(elt, "renderThing")
self.assertIs(e.element, elt)
self.assertIs(e.renderName, "renderThing")
def test_repr(self) -> None:
"""
A L{MissingRenderMethod} is represented using a custom string
containing the element's representation and the method name.
"""
elt = object()
e = error.MissingRenderMethod(elt, "renderThing")
self.assertEqual(
repr(e),
("'MissingRenderMethod': " "%r had no render method named 'renderThing'")
% elt,
)
class MissingTemplateLoaderTests(unittest.TestCase):
"""
Tests for how L{MissingTemplateLoader} exceptions are initialized and
displayed.
"""
def test_constructor(self) -> None:
"""
Given an C{element} argument, the L{MissingTemplateLoader} constructor
assigns the value to the corresponding attribute.
"""
elt = object()
e = error.MissingTemplateLoader(elt)
self.assertIs(e.element, elt)
def test_repr(self) -> None:
"""
A L{MissingTemplateLoader} is represented using a custom string
containing the element's representation and the method name.
"""
elt = object()
e = error.MissingTemplateLoader(elt)
self.assertEqual(repr(e), "'MissingTemplateLoader': %r had no loader" % elt)
class FlattenerErrorTests(unittest.TestCase):
"""
Tests for L{FlattenerError}.
"""
def makeFlattenerError(self, roots: list[object] = []) -> error.FlattenerError:
try:
raise RuntimeError("oh noes")
except Exception as e:
tb = traceback.extract_tb(sys.exc_info()[2])
return error.FlattenerError(e, roots, tb)
def fakeFormatRoot(self, obj: object) -> str:
return "R(%s)" % obj
def test_constructor(self) -> None:
"""
Given C{exception}, C{roots}, and C{traceback} arguments, the
L{FlattenerError} constructor assigns the roots to the C{_roots}
attribute.
"""
e = self.makeFlattenerError(roots=["a", "b"])
self.assertEqual(e._roots, ["a", "b"])
def test_str(self) -> None:
"""
The string form of a L{FlattenerError} is identical to its
representation.
"""
e = self.makeFlattenerError()
self.assertEqual(str(e), repr(e))
def test_reprWithRootsAndWithTraceback(self) -> None:
"""
The representation of a L{FlattenerError} initialized with roots and a
traceback contains a formatted representation of those roots (using
C{_formatRoot}) and a formatted traceback.
"""
e = self.makeFlattenerError(["a", "b"])
e._formatRoot = self.fakeFormatRoot # type: ignore[method-assign]
self.assertTrue(
re.match(
"Exception while flattening:\n"
" R\\(a\\)\n"
" R\\(b\\)\n"
' File "[^"]*", line [0-9]*, in makeFlattenerError\n'
' raise RuntimeError\\("oh noes"\\)\n'
"RuntimeError: oh noes\n$",
repr(e),
re.M | re.S,
),
repr(e),
)
def test_reprWithoutRootsAndWithTraceback(self) -> None:
"""
The representation of a L{FlattenerError} initialized without roots but
with a traceback contains a formatted traceback but no roots.
"""
e = self.makeFlattenerError([])
self.assertTrue(
re.match(
"Exception while flattening:\n"
' File "[^"]*", line [0-9]*, in makeFlattenerError\n'
' raise RuntimeError\\("oh noes"\\)\n'
"RuntimeError: oh noes\n$",
repr(e),
re.M | re.S,
),
repr(e),
)
def test_reprWithoutRootsAndWithoutTraceback(self) -> None:
"""
The representation of a L{FlattenerError} initialized without roots but
with a traceback contains a formatted traceback but no roots.
"""
e = error.FlattenerError(RuntimeError("oh noes"), [], None)
self.assertTrue(
re.match(
"Exception while flattening:\n" "RuntimeError: oh noes\n$",
repr(e),
re.M | re.S,
),
repr(e),
)
def test_formatRootShortUnicodeString(self) -> None:
"""
The C{_formatRoot} method formats a short unicode string using the
built-in repr.
"""
e = self.makeFlattenerError()
self.assertEqual(e._formatRoot(nativeString("abcd")), repr("abcd"))
def test_formatRootLongUnicodeString(self) -> None:
"""
The C{_formatRoot} method formats a long unicode string using the
built-in repr with an ellipsis.
"""
e = self.makeFlattenerError()
longString = nativeString("abcde-" * 20)
self.assertEqual(
e._formatRoot(longString),
repr("abcde-abcde-abcde-ab<...>e-abcde-abcde-abcde-"),
)
def test_formatRootShortByteString(self) -> None:
"""
The C{_formatRoot} method formats a short byte string using the
built-in repr.
"""
e = self.makeFlattenerError()
self.assertEqual(e._formatRoot(b"abcd"), repr(b"abcd"))
def test_formatRootLongByteString(self) -> None:
"""
The C{_formatRoot} method formats a long byte string using the
built-in repr with an ellipsis.
"""
e = self.makeFlattenerError()
longString = b"abcde-" * 20
self.assertEqual(
e._formatRoot(longString),
repr(b"abcde-abcde-abcde-ab<...>e-abcde-abcde-abcde-"),
)
def test_formatRootTagNoFilename(self) -> None:
"""
The C{_formatRoot} method formats a C{Tag} with no filename information
as 'Tag <tagName>'.
"""
e = self.makeFlattenerError()
self.assertEqual(e._formatRoot(Tag("a-tag")), "Tag <a-tag>")
def test_formatRootTagWithFilename(self) -> None:
"""
The C{_formatRoot} method formats a C{Tag} with filename information
using the filename, line, column, and tag information
"""
e = self.makeFlattenerError()
t = Tag("a-tag", filename="tpl.py", lineNumber=10, columnNumber=20)
self.assertEqual(
e._formatRoot(t), 'File "tpl.py", line 10, column 20, in "a-tag"'
)
def test_string(self) -> None:
"""
If a L{FlattenerError} is created with a string root, up to around 40
bytes from that string are included in the string representation of the
exception.
"""
self.assertEqual(
str(error.FlattenerError(RuntimeError("reason"), ["abc123xyz"], [])),
"Exception while flattening:\n" " 'abc123xyz'\n" "RuntimeError: reason\n",
)
self.assertEqual(
str(error.FlattenerError(RuntimeError("reason"), ["0123456789" * 10], [])),
"Exception while flattening:\n"
" '01234567890123456789"
"<...>01234567890123456789'\n" # TODO: re-add 0
"RuntimeError: reason\n",
)
def test_unicode(self) -> None:
"""
If a L{FlattenerError} is created with a unicode root, up to around 40
characters from that string are included in the string representation
of the exception.
"""
self.assertEqual(
str(
error.FlattenerError(RuntimeError("reason"), ["abc\N{SNOWMAN}xyz"], [])
),
"Exception while flattening:\n"
" 'abc\\u2603xyz'\n" # Codepoint for SNOWMAN
"RuntimeError: reason\n",
)
self.assertEqual(
str(
error.FlattenerError(
RuntimeError("reason"), ["01234567\N{SNOWMAN}9" * 10], []
)
),
"Exception while flattening:\n"
" '01234567\\u2603901234567\\u26039"
"<...>01234567\\u2603901234567"
"\\u26039'\n"
"RuntimeError: reason\n",
)
class UnsupportedMethodTests(unittest.SynchronousTestCase):
"""
Tests for L{UnsupportedMethod}.
"""
def test_str(self) -> None:
"""
The C{__str__} for L{UnsupportedMethod} makes it clear that what it
shows is a list of the supported methods, not the method that was
unsupported.
"""
e = error.UnsupportedMethod([b"HEAD", b"PATCH"])
self.assertEqual(
str(e),
"Expected one of [b'HEAD', b'PATCH']",
)

View File

@@ -0,0 +1,767 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for the flattening portion of L{twisted.web.template}, implemented in
L{twisted.web._flatten}.
"""
import re
import sys
import traceback
from collections import OrderedDict
from textwrap import dedent
from types import FunctionType
from typing import Callable, Dict, List, NoReturn, Optional, Tuple, cast
from xml.etree.ElementTree import XML
from zope.interface import implementer
from hamcrest import assert_that, equal_to
from twisted.internet.defer import (
CancelledError,
Deferred,
gatherResults,
passthru,
succeed,
)
from twisted.python.failure import Failure
from twisted.test.testutils import XMLAssertionMixin
from twisted.trial.unittest import SynchronousTestCase
from twisted.web._flatten import BUFFER_SIZE
from twisted.web.error import FlattenerError, UnfilledSlot, UnsupportedType
from twisted.web.iweb import IRenderable, IRequest, ITemplateLoader
from twisted.web.template import (
CDATA,
CharRef,
Comment,
Element,
Flattenable,
Tag,
TagLoader,
flatten,
flattenString,
renderer,
slot,
tags,
)
from twisted.web.test._util import FlattenTestCase
IS_PYTHON_313 = sys.version_info[:2] >= (3, 13)
class SerializationTests(FlattenTestCase, XMLAssertionMixin):
"""
Tests for flattening various things.
"""
def test_nestedTags(self) -> None:
"""
Test that nested tags flatten correctly.
"""
self.assertFlattensImmediately(
tags.html(tags.body("42"), hi="there"),
b'<html hi="there"><body>42</body></html>',
)
def test_serializeString(self) -> None:
"""
Test that strings will be flattened and escaped correctly.
"""
self.assertFlattensImmediately("one", b"one"),
self.assertFlattensImmediately("<abc&&>123", b"&lt;abc&amp;&amp;&gt;123"),
def test_serializeSelfClosingTags(self) -> None:
"""
The serialized form of a self-closing tag is C{'<tagName />'}.
"""
self.assertFlattensImmediately(tags.img(), b"<img />")
def test_serializeAttribute(self) -> None:
"""
The serialized form of attribute I{a} with value I{b} is C{'a="b"'}.
"""
self.assertFlattensImmediately(tags.img(src="foo"), b'<img src="foo" />')
def test_serializedMultipleAttributes(self) -> None:
"""
Multiple attributes are separated by a single space in their serialized
form.
"""
tag = tags.img()
tag.attributes = OrderedDict([("src", "foo"), ("name", "bar")])
self.assertFlattensImmediately(tag, b'<img src="foo" name="bar" />')
def checkAttributeSanitization(
self,
wrapData: Callable[[str], Flattenable],
wrapTag: Callable[[Tag], Flattenable],
) -> None:
"""
Common implementation of L{test_serializedAttributeWithSanitization}
and L{test_[AWS-SECRET-REMOVED]ion},
L{test_serializedAttributeWithTransparentTag}.
@param wrapData: A 1-argument callable that wraps around the
attribute's value so other tests can customize it.
@param wrapTag: A 1-argument callable that wraps around the outer tag
so other tests can customize it.
"""
self.assertFlattensImmediately(
wrapTag(tags.img(src=wrapData('<>&"'))),
b'<img src="&lt;&gt;&amp;&quot;" />',
)
def test_serializedAttributeWithSanitization(self) -> None:
"""
Attribute values containing C{"<"}, C{">"}, C{"&"}, or C{'"'} have
C{"&lt;"}, C{"&gt;"}, C{"&amp;"}, or C{"&quot;"} substituted for those
bytes in the serialized output.
"""
self.checkAttributeSanitization(passthru, passthru)
def test_[AWS-SECRET-REMOVED]ion(self) -> None:
"""
Like L{test_serializedAttributeWithSanitization}, but when the contents
of the attribute are in a L{Deferred
<twisted.internet.defer.Deferred>}.
"""
self.checkAttributeSanitization(succeed, passthru)
def test_[AWS-SECRET-REMOVED]ion(self) -> None:
"""
Like L{test_serializedAttributeWithSanitization} but with a slot.
"""
toss = []
def insertSlot(value: str) -> Flattenable:
toss.append(value)
return slot("stuff")
def fillSlot(tag: Tag) -> Tag:
return tag.fillSlots(stuff=toss.pop())
self.checkAttributeSanitization(insertSlot, fillSlot)
def test_serializedAttributeWithTransparentTag(self) -> None:
"""
Attribute values which are supplied via the value of a C{t:transparent}
tag have the same substitution rules to them as values supplied
directly.
"""
self.checkAttributeSanitization(tags.transparent, passthru)
def test_[AWS-SECRET-REMOVED]hRenderer(self) -> None:
"""
Like L{test_serializedAttributeWithTransparentTag}, but when the
attribute is rendered by a renderer on an element.
"""
class WithRenderer(Element):
def __init__(self, value: str, loader: Optional[ITemplateLoader]) -> None:
self.value = value
super().__init__(loader)
@renderer
def stuff(self, request: Optional[IRequest], tag: Tag) -> Flattenable:
return self.value
toss = []
def insertRenderer(value: str) -> Flattenable:
toss.append(value)
return tags.transparent(render="stuff")
def render(tag: Tag) -> Flattenable:
return WithRenderer(toss.pop(), TagLoader(tag))
self.checkAttributeSanitization(insertRenderer, render)
def test_serializedAttributeWithRenderable(self) -> None:
"""
Like L{test_serializedAttributeWithTransparentTag}, but when the
attribute is a provider of L{IRenderable} rather than a transparent
tag.
"""
@implementer(IRenderable)
class Arbitrary:
def __init__(self, value: Flattenable) -> None:
self.value = value
def render(self, request: Optional[IRequest]) -> Flattenable:
return self.value
def lookupRenderMethod(
self, name: str
) -> Callable[[Optional[IRequest], Tag], Flattenable]:
raise NotImplementedError("Unexpected call")
self.checkAttributeSanitization(Arbitrary, passthru)
def checkTagAttributeSerialization(
self, wrapTag: Callable[[Tag], Flattenable]
) -> None:
"""
Common implementation of L{test_serializedAttributeWithTag} and
L{test_serializedAttributeWithDeferredTag}.
@param wrapTag: A 1-argument callable that wraps around the attribute's
value so other tests can customize it.
@type wrapTag: callable taking L{Tag} and returning something
flattenable
"""
innerTag = tags.a('<>&"')
outerTag = tags.img(src=wrapTag(innerTag))
outer = self.assertFlattensImmediately(
outerTag,
b'<img src="&lt;a&gt;&amp;lt;&amp;gt;&amp;amp;&quot;&lt;/a&gt;" />',
)
inner = self.assertFlattensImmediately(innerTag, b'<a>&lt;&gt;&amp;"</a>')
# Since the above quoting is somewhat tricky, validate it by making sure
# that the main use-case for tag-within-attribute is supported here: if
# we serialize a tag, it is quoted *such that it can be parsed out again
# as a tag*.
self.assertXMLEqual(XML(outer).attrib["src"], inner)
def test_serializedAttributeWithTag(self) -> None:
"""
L{Tag} objects which are serialized within the context of an attribute
are serialized such that the text content of the attribute may be
parsed to retrieve the tag.
"""
self.checkTagAttributeSerialization(passthru)
def test_serializedAttributeWithDeferredTag(self) -> None:
"""
Like L{test_serializedAttributeWithTag}, but when the L{Tag} is in a
L{Deferred <twisted.internet.defer.Deferred>}.
"""
self.checkTagAttributeSerialization(succeed)
def test_serializedAttributeWithTagWithAttribute(self) -> None:
"""
Similar to L{test_serializedAttributeWithTag}, but for the additional
complexity where the tag which is the attribute value itself has an
attribute value which contains bytes which require substitution.
"""
flattened = self.assertFlattensImmediately(
tags.img(src=tags.a(href='<>&"')),
b'<img src="&lt;a href='
b"&quot;&amp;lt;&amp;gt;&amp;amp;&amp;quot;&quot;&gt;"
b'&lt;/a&gt;" />',
)
# As in checkTagAttributeSerialization, belt-and-suspenders:
self.assertXMLEqual(
XML(flattened).attrib["src"], b'<a href="&lt;&gt;&amp;&quot;"></a>'
)
def test_serializeComment(self) -> None:
"""
Test that comments are correctly flattened and escaped.
"""
self.assertFlattensImmediately(Comment("foo bar"), b"<!--foo bar-->")
def test_commentEscaping(self) -> Deferred[List[bytes]]:
"""
The data in a L{Comment} is escaped and mangled in the flattened output
so that the result can be safely included in an HTML document.
Test that C{>} is escaped when the sequence C{-->} is encountered
within a comment, and that comments do not end with C{-}.
"""
def verifyComment(c: bytes) -> None:
self.assertTrue(
c.startswith(b"<!--"),
f"{c!r} does not start with the comment prefix",
)
self.assertTrue(
c.endswith(b"-->"),
f"{c!r} does not end with the comment suffix",
)
# If it is shorter than 7, then the prefix and suffix overlap
# illegally.
self.assertTrue(len(c) >= 7, f"{c!r} is too short to be a legal comment")
content = c[4:-3]
if b"foo" in content:
self.assertIn(b">", content)
else:
self.assertNotIn(b">", content)
if content:
self.assertNotEqual(content[-1], b"-")
results = []
for c in [
"",
"foo > bar",
"abracadabra-",
"not-->magic",
]:
d = flattenString(None, Comment(c))
d.addCallback(verifyComment)
results.append(d)
return gatherResults(results)
def test_serializeCDATA(self) -> None:
"""
Test that CDATA is correctly flattened and escaped.
"""
self.assertFlattensImmediately(CDATA("foo bar"), b"<![CDATA[foo bar]]>"),
self.assertFlattensImmediately(
CDATA("foo ]]> bar"), b"<![CDATA[foo ]]]]><![CDATA[> bar]]>"
)
def test_serializeUnicode(self) -> None:
"""
Test that unicode is encoded correctly in the appropriate places, and
raises an error when it occurs in inappropriate place.
"""
snowman = "\N{SNOWMAN}"
self.assertFlattensImmediately(snowman, b"\xe2\x98\x83")
self.assertFlattensImmediately(tags.p(snowman), b"<p>\xe2\x98\x83</p>")
self.assertFlattensImmediately(Comment(snowman), b"<!--\xe2\x98\x83-->")
self.assertFlattensImmediately(CDATA(snowman), b"<![CDATA[\xe2\x98\x83]]>")
self.assertFlatteningRaises(Tag(snowman), UnicodeEncodeError)
self.assertFlatteningRaises(
Tag("p", attributes={snowman: ""}), UnicodeEncodeError
)
def test_serializeCharRef(self) -> None:
"""
A character reference is flattened to a string using the I{&#NNNN;}
syntax.
"""
ref = CharRef(ord("\N{SNOWMAN}"))
self.assertFlattensImmediately(ref, b"&#9731;")
def test_serializeDeferred(self) -> None:
"""
Test that a deferred is substituted with the current value in the
callback chain when flattened.
"""
self.assertFlattensImmediately(succeed("two"), b"two")
def test_serializeSameDeferredTwice(self) -> None:
"""
Test that the same deferred can be flattened twice.
"""
d = succeed("three")
self.assertFlattensImmediately(d, b"three")
self.assertFlattensImmediately(d, b"three")
def test_serializeCoroutine(self) -> None:
"""
Test that a coroutine returning a value is substituted with the that
value when flattened.
"""
from textwrap import dedent
namespace: Dict[str, FunctionType] = {}
exec(
dedent(
"""
async def coro(x):
return x
"""
),
namespace,
)
coro = namespace["coro"]
self.assertFlattensImmediately(coro("four"), b"four")
def test_serializeCoroutineWithAwait(self) -> None:
"""
Test that a coroutine returning an awaited deferred value is
substituted with that value when flattened.
"""
from textwrap import dedent
namespace = dict(succeed=succeed)
exec(
dedent(
"""
async def coro(x):
return await succeed(x)
"""
),
namespace,
)
coro = namespace["coro"]
self.assertFlattensImmediately(coro("four"), b"four")
def test_serializeIRenderable(self) -> None:
"""
Test that flattening respects all of the IRenderable interface.
"""
@implementer(IRenderable)
class FakeElement:
def render(ign, ored: object) -> Tag:
return tags.p(
"hello, ",
tags.transparent(render="test"),
" - ",
tags.transparent(render="test"),
)
def lookupRenderMethod(
ign, name: str
) -> Callable[[Optional[IRequest], Tag], Flattenable]:
self.assertEqual(name, "test")
return lambda ign, node: node("world")
self.assertFlattensImmediately(FakeElement(), b"<p>hello, world - world</p>")
def test_serializeMissingRenderFactory(self) -> None:
"""
Test that flattening a tag with a C{render} attribute when no render
factory is available in the context raises an exception.
"""
self.assertFlatteningRaises(tags.transparent(render="test"), ValueError)
def test_serializeSlots(self) -> None:
"""
Test that flattening a slot will use the slot value from the tag.
"""
t1 = tags.p(slot("test"))
t2 = t1.clone()
t2.fillSlots(test="hello, world")
self.assertFlatteningRaises(t1, UnfilledSlot)
self.assertFlattensImmediately(t2, b"<p>hello, world</p>")
def test_serializeDeferredSlots(self) -> None:
"""
Test that a slot with a deferred as its value will be flattened using
the value from the deferred.
"""
t = tags.p(slot("test"))
t.fillSlots(test=succeed(tags.em("four>")))
self.assertFlattensImmediately(t, b"<p><em>four&gt;</em></p>")
def test_unknownTypeRaises(self) -> None:
"""
Test that flattening an unknown type of thing raises an exception.
"""
self.assertFlatteningRaises(None, UnsupportedType) # type: ignore[arg-type]
class FlattenChunkingTests(SynchronousTestCase):
"""
Tests for the way pieces of the result are chunked together in calls to
the write function.
"""
def test_oneSmallChunk(self) -> None:
"""
If the entire value to be flattened is available synchronously and fits
into the buffer it is all passed to a single call to the write
function.
"""
output: List[bytes] = []
self.successResultOf(flatten(None, ["1", "2", "3"], output.append))
assert_that(output, equal_to([b"123"]))
def test_someLargeChunks(self) -> None:
"""
If the entire value to be flattened is available synchronously but does
not fit into the buffer then it is chunked into buffer-sized pieces
and these are passed to the write function.
"""
some = ["x"] * BUFFER_SIZE
someMore = ["y"] * BUFFER_SIZE
evenMore = ["z"] * BUFFER_SIZE
output: List[bytes] = []
self.successResultOf(flatten(None, [some, someMore, evenMore], output.append))
assert_that(
output,
equal_to([b"x" * BUFFER_SIZE, b"y" * BUFFER_SIZE, b"z" * BUFFER_SIZE]),
)
def _chunksSeparatedByAsyncTest(
self,
start: Callable[
[Flattenable], Tuple[Deferred[Flattenable], Callable[[], object]]
],
) -> None:
"""
Assert that flattening with a L{Deferred} returned by C{start} results
in the expected buffering behavior.
The L{Deferred} need not have a result by it is returned by C{start}
but must have a result after the callable returned along with it is
called.
The expected buffering behavior is that flattened values up to the
L{Deferred} are written together and then the result of the
L{Deferred} is written together with values following it up to the
next L{Deferred}.
"""
first_wait, first_finish = start("first-")
second_wait, second_finish = start("second-")
value = [
"already-available",
"-chunks",
first_wait,
"chunks-already-",
"computed",
second_wait,
"more-chunks-",
"already-available",
]
output: List[bytes] = []
d = flatten(None, value, output.append)
first_finish()
second_finish()
self.successResultOf(d)
assert_that(
output,
equal_to(
[
b"already-available-chunks",
b"first-chunks-already-computed",
b"second-more-chunks-already-available",
]
),
)
def test_chunksSeparatedByFiredDeferred(self) -> None:
"""
When a fired L{Deferred} is encountered any buffered data is
passed to the write function. Then the L{Deferred}'s result is passed
to another write along with following synchronous values.
This exact buffering behavior should be considered an implementation
detail and can be replaced by some other better behavior in the future
if someone wants.
"""
def sync_start(
v: Flattenable,
) -> Tuple[Deferred[Flattenable], Callable[[], None]]:
return (succeed(v), lambda: None)
self._chunksSeparatedByAsyncTest(sync_start)
def test_chunksSeparatedByUnfiredDeferred(self) -> None:
"""
When an unfired L{Deferred} is encountered any buffered data is
passed to the write function. After the result of the L{Deferred} is
available it is passed to another write along with following
synchronous values.
"""
def async_start(
v: Flattenable,
) -> Tuple[Deferred[Flattenable], Callable[[], None]]:
d: Deferred[Flattenable] = Deferred()
return (d, lambda: d.callback(v))
self._chunksSeparatedByAsyncTest(async_start)
# Use the co_filename mechanism (instead of the __file__ mechanism) because
# it is the mechanism traceback formatting uses. The two do not necessarily
# agree with each other. This requires a code object compiled in this file.
# The easiest way to get a code object is with a new function. I'll use a
# lambda to avoid adding anything else to this namespace. The result will
# be a string which agrees with the one the traceback module will put into a
# traceback for frames associated with functions defined in this file.
HERE = (lambda: None).__code__.co_filename
class FlattenerErrorTests(SynchronousTestCase):
"""
Tests for L{FlattenerError}.
"""
def test_renderable(self) -> None:
"""
If a L{FlattenerError} is created with an L{IRenderable} provider root,
the repr of that object is included in the string representation of the
exception.
"""
@implementer(IRenderable)
class Renderable: # type: ignore[misc]
def __repr__(self) -> str:
return "renderable repr"
self.assertEqual(
str(FlattenerError(RuntimeError("reason"), [Renderable()], [])),
"Exception while flattening:\n"
" renderable repr\n"
"RuntimeError: reason\n",
)
def test_tag(self) -> None:
"""
If a L{FlattenerError} is created with a L{Tag} instance with source
location information, the source location is included in the string
representation of the exception.
"""
tag = Tag("div", filename="/foo/filename.xhtml", lineNumber=17, columnNumber=12)
self.assertEqual(
str(FlattenerError(RuntimeError("reason"), [tag], [])),
"Exception while flattening:\n"
' File "/foo/filename.xhtml", line 17, column 12, in "div"\n'
"RuntimeError: reason\n",
)
def test_tagWithoutLocation(self) -> None:
"""
If a L{FlattenerError} is created with a L{Tag} instance without source
location information, only the tagName is included in the string
representation of the exception.
"""
self.assertEqual(
str(FlattenerError(RuntimeError("reason"), [Tag("span")], [])),
"Exception while flattening:\n" " Tag <span>\n" "RuntimeError: reason\n",
)
def test_traceback(self) -> None:
"""
If a L{FlattenerError} is created with traceback frames, they are
included in the string representation of the exception.
"""
# Try to be realistic in creating the data passed in for the traceback
# frames.
def f() -> None:
g()
def g() -> NoReturn:
raise RuntimeError("reason")
try:
f()
except RuntimeError as e:
# Get the traceback, minus the info for *this* frame
tbinfo = traceback.extract_tb(sys.exc_info()[2])[1:]
exc = e
else:
self.fail("f() must raise RuntimeError")
if IS_PYTHON_313:
column_marker = " ~^^\n"
else:
column_marker = ""
self.assertEqual(
str(FlattenerError(exc, [], tbinfo)),
"Exception while flattening:\n"
' File "%s", line %d, in f\n'
" g()\n"
"%s"
' File "%s", line %d, in g\n'
' raise RuntimeError("reason")\n'
"RuntimeError: reason\n"
% (
HERE,
f.__code__.co_firstlineno + 1,
column_marker,
HERE,
g.__code__.co_firstlineno + 1,
),
)
def test_asynchronousFlattenError(self) -> None:
"""
When flattening a renderer which raises an exception asynchronously,
the error is reported when it occurs.
"""
failing: Deferred[object] = Deferred()
@implementer(IRenderable)
class NotActuallyRenderable:
"No methods provided; this will fail"
def __repr__(self) -> str:
return "<unrenderable>"
def lookupRenderMethod( # type: ignore[empty-body]
self, name: str
) -> Callable[[Optional[IRequest], Tag], Flattenable]:
...
def render(self, request: Optional[IRequest]) -> Flattenable:
return failing
flattening = flattenString(None, [NotActuallyRenderable()])
self.assertNoResult(flattening)
exc = RuntimeError("example")
failing.errback(exc)
failure = self.failureResultOf(flattening, FlattenerError)
if IS_PYTHON_313:
column_marker = ".*\n.*\n.*\nRuntimeError: example\n"
else:
column_marker = ""
self.assertRegex(
str(failure.value),
re.compile(
dedent(
"""\
Exception while flattening:
\\[<unrenderable>\\]
<unrenderable>
<Deferred at .* current result: <twisted.python.failure.Failure builtins.RuntimeError: example>>
File ".*", line \\d*, in _flattenTree
element = await element.*
"""
)
+ column_marker,
flags=re.MULTILINE,
),
)
self.assertIn("RuntimeError: example", str(failure.value))
# The original exception is unmodified and will be logged separately if
# unhandled.
self.failureResultOf(failing, RuntimeError)
def test_cancel(self) -> None:
"""
The flattening of a Deferred can be cancelled.
"""
cancelCount = 0
cancelArg = None
def checkCancel(cancelled: Deferred[object]) -> None:
nonlocal cancelArg, cancelCount
cancelArg = cancelled
cancelCount += 1
err = None
def saveErr(failure: Failure) -> None:
nonlocal err
err = failure
d: Deferred[object] = Deferred(checkCancel)
flattening = flattenString(None, d)
self.assertNoResult(flattening)
d.addErrback(saveErr)
flattening.cancel()
# Check whether we got an orderly cancellation.
# Do this first to get more meaningful reporting if something crashed.
failure = self.failureResultOf(flattening, FlattenerError)
self.assertEqual(cancelCount, 1)
self.assertIs(cancelArg, d)
self.assertIsInstance(err, Failure)
self.assertIsInstance(cast(Failure, err).value, CancelledError)
exc = failure.value.args[0]
self.assertIsInstance(exc, CancelledError)

View File

@@ -0,0 +1,41 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.trial import unittest
from twisted.web import html
class WebHtmlTests(unittest.TestCase):
"""
Unit tests for L{twisted.web.html}.
"""
def test_deprecation(self) -> None:
"""
Calls to L{twisted.web.html} members emit a deprecation warning.
"""
def assertDeprecationWarningOf(method: str) -> None:
"""
Check that a deprecation warning is present.
"""
warningsShown = self.flushWarnings([self.test_deprecation])
self.assertEqual(len(warningsShown), 1)
self.assertIdentical(warningsShown[0]["category"], DeprecationWarning)
self.assertEqual(
warningsShown[0]["message"],
"twisted.web.html.%s was deprecated in Twisted 15.3.0; "
"please use twisted.web.template instead" % (method,),
)
html.PRE("")
assertDeprecationWarningOf("PRE")
html.UL([])
assertDeprecationWarningOf("UL")
html.linkList([])
assertDeprecationWarningOf("linkList")
html.output(lambda: None)
assertDeprecationWarningOf("output")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,694 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.http_headers}.
"""
from __future__ import annotations
from typing import Sequence
from twisted.trial.unittest import SynchronousTestCase, TestCase
from twisted.web.http_headers import Headers, InvalidHeaderName, _nameEncoder
from twisted.web.test.requesthelper import (
bytesLinearWhitespaceComponents,
sanitizedBytes,
textLinearWhitespaceComponents,
)
class NameEncoderTests(SynchronousTestCase):
"""
Test L{twisted.web.http_headers._NameEncoder}
"""
def test_encodeName(self) -> None:
"""
L{_NameEncoder.encode} returns the canonical capitalization for
the given header.
"""
self.assertEqual(_nameEncoder.encode(b"test"), b"Test")
self.assertEqual(_nameEncoder.encode(b"test-stuff"), b"Test-Stuff")
self.assertEqual(_nameEncoder.encode(b"content-md5"), b"Content-MD5")
self.assertEqual(_nameEncoder.encode(b"dnt"), b"DNT")
self.assertEqual(_nameEncoder.encode(b"etag"), b"ETag")
self.assertEqual(_nameEncoder.encode(b"p3p"), b"P3P")
self.assertEqual(_nameEncoder.encode(b"te"), b"TE")
self.assertEqual(_nameEncoder.encode(b"www-authenticate"), b"WWW-Authenticate")
self.assertEqual(_nameEncoder.encode(b"WWW-authenticate"), b"WWW-Authenticate")
self.assertEqual(_nameEncoder.encode(b"Www-Authenticate"), b"WWW-Authenticate")
self.assertEqual(_nameEncoder.encode(b"x-xss-protection"), b"X-XSS-Protection")
def test_encodeNameStr(self) -> None:
"""
L{_NameEncoder.encode} returns the canonical capitalization for
a header name given as a L{str}.
"""
self.assertEqual(_nameEncoder.encode("test"), b"Test")
self.assertEqual(_nameEncoder.encode("test-stuff"), b"Test-Stuff")
self.assertEqual(_nameEncoder.encode("content-md5"), b"Content-MD5")
self.assertEqual(_nameEncoder.encode("dnt"), b"DNT")
self.assertEqual(_nameEncoder.encode("etag"), b"ETag")
self.assertEqual(_nameEncoder.encode("p3p"), b"P3P")
self.assertEqual(_nameEncoder.encode("te"), b"TE")
self.assertEqual(_nameEncoder.encode("www-authenticate"), b"WWW-Authenticate")
self.assertEqual(_nameEncoder.encode("WWW-authenticate"), b"WWW-Authenticate")
self.assertEqual(_nameEncoder.encode("Www-Authenticate"), b"WWW-Authenticate")
self.assertEqual(_nameEncoder.encode("x-xss-protection"), b"X-XSS-Protection")
def test_maxCachedHeaders(self) -> None:
"""
Only a limited number of HTTP header names get cached.
"""
headers = Headers()
for i in range(_nameEncoder._MAX_CACHED_HEADERS + 200):
headers.addRawHeader(f"hello-{i}", "value")
self.assertEqual(
len(_nameEncoder._canonicalHeaderCache), _nameEncoder._MAX_CACHED_HEADERS
)
def assertSanitized(
testCase: TestCase, components: Sequence[bytes] | Sequence[str], expected: bytes
) -> None:
"""
Assert that the components are sanitized to the expected value as
both a header value, across all of L{Header}'s setters and getters.
@param testCase: A test case.
@param components: A sequence of values that contain linear
whitespace to use as header values; see
C{textLinearWhitespaceComponents} and
C{bytesLinearWhitespaceComponents}
@param expected: The expected sanitized form of the component as
a header value.
"""
name = b"Name"
for component in components:
headers = []
headers.append(Headers({name: [component]})) # type: ignore[misc]
added = Headers()
added.addRawHeader(name, component)
headers.append(added)
setHeader = Headers()
setHeader.setRawHeaders(name, [component])
headers.append(setHeader)
for header in headers:
testCase.assertEqual(list(header.getAllRawHeaders()), [(name, [expected])])
testCase.assertEqual(header.getRawHeaders(name), [expected])
class BytesHeadersTests(TestCase):
"""
Tests for L{Headers}, using L{bytes} arguments for methods.
"""
def test_sanitizeLinearWhitespace(self) -> None:
"""
Linear whitespace in header names or values is replaced with a
single space.
"""
assertSanitized(self, bytesLinearWhitespaceComponents, sanitizedBytes)
def test_initializer(self) -> None:
"""
The header values passed to L{Headers.__init__} can be retrieved via
L{Headers.getRawHeaders}.
"""
h = Headers({b"Foo": [b"bar"]})
self.assertEqual(h.getRawHeaders(b"foo"), [b"bar"])
def test_setRawHeaders(self) -> None:
"""
L{Headers.setRawHeaders} sets the header values for the given
header name to the sequence of byte string values.
"""
rawValue = [b"value1", b"value2"]
h = Headers()
h.setRawHeaders(b"test", rawValue)
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
self.assertEqual(h.getRawHeaders(b"test"), rawValue)
def test_addRawHeader(self) -> None:
"""
L{Headers.addRawHeader} adds a new value for a given header.
"""
h = Headers()
h.addRawHeader(b"test", b"lemur")
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur"])
h.addRawHeader(b"test", b"panda")
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur", b"panda"])
def test_getRawHeadersNoDefault(self) -> None:
"""
L{Headers.getRawHeaders} returns L{None} if the header is not found and
no default is specified.
"""
self.assertIsNone(Headers().getRawHeaders(b"test"))
def test_getRawHeadersDefaultValue(self) -> None:
"""
L{Headers.getRawHeaders} returns the specified default value when no
header is found.
"""
h = Headers()
default = object()
self.assertIdentical(h.getRawHeaders(b"test", default), default)
def test_getRawHeadersWithDefaultMatchingValue(self) -> None:
"""
If the object passed as the value list to L{Headers.setRawHeaders}
is later passed as a default to L{Headers.getRawHeaders}, the
result nevertheless contains encoded values.
"""
h = Headers()
default = ["value"]
h.setRawHeaders(b"key", default)
self.assertIsInstance(h.getRawHeaders(b"key", default)[0], bytes)
self.assertEqual(h.getRawHeaders(b"key", default), [b"value"])
def test_getRawHeaders(self) -> None:
"""
L{Headers.getRawHeaders} returns the values which have been set for a
given header.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemur"])
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur"])
self.assertEqual(h.getRawHeaders(b"Test"), [b"lemur"])
def test_hasHeaderTrue(self) -> None:
"""
Check that L{Headers.hasHeader} returns C{True} when the given header
is found.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemur"])
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
def test_hasHeaderFalse(self) -> None:
"""
L{Headers.hasHeader} returns C{False} when the given header is not
found.
"""
self.assertFalse(Headers().hasHeader(b"test"))
def test_removeHeader(self) -> None:
"""
Check that L{Headers.removeHeader} removes the given header.
"""
h = Headers()
h.setRawHeaders(b"foo", [b"lemur"])
self.assertTrue(h.hasHeader(b"foo"))
h.removeHeader(b"foo")
self.assertFalse(h.hasHeader(b"foo"))
h.setRawHeaders(b"bar", [b"panda"])
self.assertTrue(h.hasHeader(b"bar"))
h.removeHeader(b"Bar")
self.assertFalse(h.hasHeader(b"bar"))
def test_removeHeaderDoesntExist(self) -> None:
"""
L{Headers.removeHeader} is a no-operation when the specified header is
not found.
"""
h = Headers()
h.removeHeader(b"test")
self.assertEqual(list(h.getAllRawHeaders()), [])
def test_getAllRawHeaders(self) -> None:
"""
L{Headers.getAllRawHeaders} returns an iterable of (k, v) pairs, where
C{k} is the canonicalized representation of the header name, and C{v}
is a sequence of values.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemurs"])
h.setRawHeaders(b"www-authenticate", [b"basic aksljdlk="])
allHeaders = {(k, tuple(v)) for k, v in h.getAllRawHeaders()}
self.assertEqual(
allHeaders,
{(b"WWW-Authenticate", (b"basic aksljdlk=",)), (b"Test", (b"lemurs",))},
)
def test_headersComparison(self) -> None:
"""
A L{Headers} instance compares equal to itself and to another
L{Headers} instance with the same values.
"""
first = Headers()
first.setRawHeaders(b"foo", [b"panda"])
second = Headers()
second.setRawHeaders(b"foo", [b"panda"])
third = Headers()
third.setRawHeaders(b"foo", [b"lemur", b"panda"])
self.assertEqual(first, first)
self.assertEqual(first, second)
self.assertNotEqual(first, third)
def test_otherComparison(self) -> None:
"""
An instance of L{Headers} does not compare equal to other unrelated
objects.
"""
h = Headers()
self.assertNotEqual(h, ())
self.assertNotEqual(h, object())
self.assertNotEqual(h, b"foo")
def test_repr(self) -> None:
"""
The L{repr} of a L{Headers} instance shows the names and values of all
the headers it contains.
"""
foo = b"foo"
bar = b"bar"
baz = b"baz"
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
f"Headers({{{foo.capitalize()!r}: [{bar!r}, {baz!r}]}})",
)
def test_reprWithRawBytes(self) -> None:
"""
The L{repr} of a L{Headers} instance shows the names and values of all
the headers it contains, not attempting to decode any raw bytes.
"""
# There's no such thing as undecodable latin-1, you'll just get
# some mojibake
foo = b"foo"
# But this is invalid UTF-8! So, any accidental decoding/encoding will
# throw an exception.
bar = b"bar\xe1"
baz = b"baz\xe1"
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
f"Headers({{{foo.capitalize()!r}: [{bar!r}, {baz!r}]}})",
)
def test_subclassRepr(self) -> None:
"""
The L{repr} of an instance of a subclass of L{Headers} uses the name
of the subclass instead of the string C{"Headers"}.
"""
foo = b"foo"
bar = b"bar"
baz = b"baz"
class FunnyHeaders(Headers):
pass
self.assertEqual(
repr(FunnyHeaders({foo: [bar, baz]})),
f"FunnyHeaders({{{foo.capitalize()!r}: [{bar!r}, {baz!r}]}})",
)
def test_copy(self) -> None:
"""
L{Headers.copy} creates a new independent copy of an existing
L{Headers} instance, allowing future modifications without impacts
between the copies.
"""
h = Headers()
h.setRawHeaders(b"test", [b"foo"])
i = h.copy()
self.assertEqual(i.getRawHeaders(b"test"), [b"foo"])
h.addRawHeader(b"test", b"bar")
self.assertEqual(i.getRawHeaders(b"test"), [b"foo"])
i.addRawHeader(b"test", b"baz")
self.assertEqual(h.getRawHeaders(b"test"), [b"foo", b"bar"])
class UnicodeHeadersTests(TestCase):
"""
Tests for L{Headers}, using L{str} arguments for methods.
"""
def test_sanitizeLinearWhitespace(self) -> None:
"""
Linear whitespace in header names or values is replaced with a
single space.
"""
assertSanitized(self, textLinearWhitespaceComponents, sanitizedBytes)
def test_initializer(self) -> None:
"""
The header values passed to L{Headers.__init__} can be retrieved via
L{Headers.getRawHeaders}. If a L{bytes} argument is given, it returns
L{bytes} values, and if a L{str} argument is given, it returns
L{str} values. Both are the same header value, just encoded or
decoded.
"""
h = Headers({"Foo": ["bar"]})
self.assertEqual(h.getRawHeaders(b"foo"), [b"bar"])
self.assertEqual(h.getRawHeaders("foo"), ["bar"])
def test_setRawHeaders(self) -> None:
"""
L{Headers.setRawHeaders} sets the header values for the given
header name to the sequence of strings, encoded.
"""
rawValue = ["value1", "value2"]
rawEncodedValue = [b"value1", b"value2"]
h = Headers()
h.setRawHeaders("test", rawValue)
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
self.assertTrue(h.hasHeader("test"))
self.assertTrue(h.hasHeader("Test"))
self.assertEqual(h.getRawHeaders("test"), rawValue)
self.assertEqual(h.getRawHeaders(b"test"), rawEncodedValue)
def test_nameNotEncodable(self) -> None:
"""
Passing L{str} to any function that takes a header name will encode
said header name as ISO-8859-1, and if it cannot be encoded, it will
raise a L{UnicodeDecodeError}.
"""
h = Headers()
# Only these two functions take names
with self.assertRaises(UnicodeEncodeError):
h.setRawHeaders("\u2603", ["val"])
with self.assertRaises(UnicodeEncodeError):
h.hasHeader("\u2603")
def test_nameNotToken(self) -> None:
"""
HTTP header names must be tokens, so any names containing non-token
characters raises L{InvalidHeaderName}
"""
h = Headers()
# A non-token character within ISO-8851-1
self.assertRaises(InvalidHeaderName, h.setRawHeaders, b"\xe1", [b"val"])
self.assertRaises(InvalidHeaderName, h.setRawHeaders, "\u00e1", [b"val"])
# Whitespace
self.assertRaises(InvalidHeaderName, h.setRawHeaders, b"a b", [b"val"])
self.assertRaises(InvalidHeaderName, h.setRawHeaders, "c\nd", [b"val"])
self.assertRaises(InvalidHeaderName, h.setRawHeaders, "c\td", [b"val"])
def test_nameEncoding(self) -> None:
"""
Passing L{str} to any function that takes a header name will encode
said header name as ISO-8859-1.
"""
h = Headers()
# We set it using a Unicode string.
h.setRawHeaders("bar", [b"foo"])
# It's encoded to the ISO-8859-1 value, which we can use to access it
self.assertTrue(h.hasHeader(b"bar"))
self.assertEqual(h.getRawHeaders(b"bar"), [b"foo"])
# We can still access it using the Unicode string..
self.assertTrue(h.hasHeader("bar"))
def test_rawHeadersValueEncoding(self) -> None:
"""
Passing L{str} to L{Headers.setRawHeaders} will encode the name as
ISO-8859-1 and values as UTF-8.
"""
h = Headers()
h.setRawHeaders("x", ["\u2603", b"foo"])
self.assertTrue(h.hasHeader(b"x"))
self.assertEqual(h.getRawHeaders(b"x"), [b"\xe2\x98\x83", b"foo"])
def test_addRawHeader(self) -> None:
"""
L{Headers.addRawHeader} adds a new value for a given header.
"""
h = Headers()
h.addRawHeader("test", "lemur")
self.assertEqual(h.getRawHeaders("test"), ["lemur"])
h.addRawHeader("test", "panda")
self.assertEqual(h.getRawHeaders("test"), ["lemur", "panda"])
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur", b"panda"])
def test_getRawHeadersNoDefault(self) -> None:
"""
L{Headers.getRawHeaders} returns L{None} if the header is not found and
no default is specified.
"""
self.assertIsNone(Headers().getRawHeaders("test"))
def test_getRawHeadersDefaultValue(self) -> None:
"""
L{Headers.getRawHeaders} returns the specified default value when no
header is found.
"""
h = Headers()
default = object()
self.assertIdentical(h.getRawHeaders("test", default), default)
self.assertIdentical(h.getRawHeaders("test", None), None)
self.assertEqual(h.getRawHeaders("test", [None]), [None])
self.assertEqual(
h.getRawHeaders("test", ["\N{SNOWMAN}"]),
["\N{SNOWMAN}"],
)
def test_getRawHeadersWithDefaultMatchingValue(self) -> None:
"""
If the object passed as the value list to L{Headers.setRawHeaders}
is later passed as a default to L{Headers.getRawHeaders}, the
result nevertheless contains decoded values.
"""
h = Headers()
default = [b"value"]
h.setRawHeaders(b"key", default)
self.assertIsInstance(h.getRawHeaders("key", default)[0], str)
self.assertEqual(h.getRawHeaders("key", default), ["value"])
def test_getRawHeaders(self) -> None:
"""
L{Headers.getRawHeaders} returns the values which have been set for a
given header.
"""
h = Headers()
h.setRawHeaders("test", ["lemur"])
self.assertEqual(h.getRawHeaders("test"), ["lemur"])
self.assertEqual(h.getRawHeaders("Test"), ["lemur"])
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur"])
self.assertEqual(h.getRawHeaders(b"Test"), [b"lemur"])
def test_hasHeaderTrue(self) -> None:
"""
Check that L{Headers.hasHeader} returns C{True} when the given header
is found.
"""
h = Headers()
h.setRawHeaders("test", ["lemur"])
self.assertTrue(h.hasHeader("test"))
self.assertTrue(h.hasHeader("Test"))
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
def test_hasHeaderFalse(self) -> None:
"""
L{Headers.hasHeader} returns C{False} when the given header is not
found.
"""
self.assertFalse(Headers().hasHeader("test"))
def test_removeHeader(self) -> None:
"""
Check that L{Headers.removeHeader} removes the given header.
"""
h = Headers()
h.setRawHeaders("foo", ["lemur"])
self.assertTrue(h.hasHeader("foo"))
h.removeHeader("foo")
self.assertFalse(h.hasHeader("foo"))
self.assertFalse(h.hasHeader(b"foo"))
h.setRawHeaders("bar", ["panda"])
self.assertTrue(h.hasHeader("bar"))
h.removeHeader("Bar")
self.assertFalse(h.hasHeader("bar"))
self.assertFalse(h.hasHeader(b"bar"))
def test_removeHeaderDoesntExist(self) -> None:
"""
L{Headers.removeHeader} is a no-operation when the specified header is
not found.
"""
h = Headers()
h.removeHeader("test")
self.assertEqual(list(h.getAllRawHeaders()), [])
def test_getAllRawHeaders(self) -> None:
"""
L{Headers.getAllRawHeaders} returns an iterable of (k, v) pairs, where
C{k} is the canonicalized representation of the header name, and C{v}
is a sequence of values.
"""
h = Headers()
h.setRawHeaders("test", ["lemurs"])
h.setRawHeaders("www-authenticate", ["basic aksljdlk="])
h.setRawHeaders("content-md5", ["kjdfdfgdfgnsd"])
allHeaders = {(k, tuple(v)) for k, v in h.getAllRawHeaders()}
self.assertEqual(
allHeaders,
{
(b"WWW-Authenticate", (b"basic aksljdlk=",)),
(b"Content-MD5", (b"kjdfdfgdfgnsd",)),
(b"Test", (b"lemurs",)),
},
)
def test_headersComparison(self) -> None:
"""
A L{Headers} instance compares equal to itself and to another
L{Headers} instance with the same values.
"""
first = Headers()
first.setRawHeaders("foo", ["panda"])
second = Headers()
second.setRawHeaders("foo", ["panda"])
third = Headers()
third.setRawHeaders("foo", ["lemur", "panda"])
self.assertEqual(first, first)
self.assertEqual(first, second)
self.assertNotEqual(first, third)
# Headers instantiated with bytes equivs are also the same
firstBytes = Headers()
firstBytes.setRawHeaders(b"foo", [b"panda"])
secondBytes = Headers()
secondBytes.setRawHeaders(b"foo", [b"panda"])
thirdBytes = Headers()
thirdBytes.setRawHeaders(b"foo", [b"lemur", "panda"])
self.assertEqual(first, firstBytes)
self.assertEqual(second, secondBytes)
self.assertEqual(third, thirdBytes)
def test_otherComparison(self) -> None:
"""
An instance of L{Headers} does not compare equal to other unrelated
objects.
"""
h = Headers()
self.assertNotEqual(h, ())
self.assertNotEqual(h, object())
self.assertNotEqual(h, "foo")
def test_repr(self) -> None:
"""
The L{repr} of a L{Headers} instance shows the names and values of all
the headers it contains. This shows only reprs of bytes values, as
undecodable headers may cause an exception.
"""
foo = "foo"
bar = "bar\u2603"
baz = "baz"
fooEncoded = "'Foo'"
barEncoded = "'bar\\xe2\\x98\\x83'"
fooEncoded = "b" + fooEncoded
barEncoded = "b" + barEncoded
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
"Headers({{{}: [{}, {!r}]}})".format(
fooEncoded, barEncoded, baz.encode("utf8")
),
)
def test_subclassRepr(self) -> None:
"""
The L{repr} of an instance of a subclass of L{Headers} uses the name
of the subclass instead of the string C{"Headers"}.
"""
foo = "foo"
bar = "bar\u2603"
baz = "baz"
fooEncoded = "b'Foo'"
barEncoded = "b'bar\\xe2\\x98\\x83'"
class FunnyHeaders(Headers):
pass
self.assertEqual(
repr(FunnyHeaders({foo: [bar, baz]})),
"FunnyHeaders({%s: [%s, %r]})"
% (fooEncoded, barEncoded, baz.encode("utf8")),
)
def test_copy(self) -> None:
"""
L{Headers.copy} creates a new independent copy of an existing
L{Headers} instance, allowing future modifications without impacts
between the copies.
"""
h = Headers()
h.setRawHeaders("test", ["foo\u2603"])
i = h.copy()
# The copy contains the same value as the original
self.assertEqual(i.getRawHeaders("test"), ["foo\u2603"])
self.assertEqual(i.getRawHeaders(b"test"), [b"foo\xe2\x98\x83"])
# Add a header to the original
h.addRawHeader("test", "bar")
# Verify that the copy has not changed
self.assertEqual(i.getRawHeaders("test"), ["foo\u2603"])
self.assertEqual(i.getRawHeaders(b"test"), [b"foo\xe2\x98\x83"])
# Add a header to the copy
i.addRawHeader("Test", b"baz")
# Verify that the orignal does not have it
self.assertEqual(h.getRawHeaders("test"), ["foo\u2603", "bar"])
self.assertEqual(h.getRawHeaders(b"test"), [b"foo\xe2\x98\x83", b"bar"])
class MixedHeadersTests(TestCase):
"""
Tests for L{Headers}, mixing L{bytes} and L{str} arguments for methods
where that is permitted.
"""
def test_addRawHeader(self) -> None:
"""
L{Headers.addRawHeader} accepts mixed L{str} and L{bytes}.
"""
h = Headers()
h.addRawHeader(b"bytes", "str")
h.addRawHeader("str", b"bytes")
self.assertEqual(h.getRawHeaders(b"Bytes"), [b"str"])
self.assertEqual(h.getRawHeaders("Str"), ["bytes"])
def test_setRawHeaders(self) -> None:
"""
L{Headers.setRawHeaders} accepts mixed L{str} and L{bytes}.
"""
h = Headers()
h.setRawHeaders(b"bytes", [b"bytes"])
h.setRawHeaders("str", ["str"])
h.setRawHeaders("mixed-str", [b"bytes", "str"])
h.setRawHeaders(b"mixed-bytes", ["str", b"bytes"])
self.assertEqual(h.getRawHeaders(b"Bytes"), [b"bytes"])
self.assertEqual(h.getRawHeaders("Str"), ["str"])
self.assertEqual(h.getRawHeaders("Mixed-Str"), ["bytes", "str"])
self.assertEqual(h.getRawHeaders(b"Mixed-Bytes"), [b"str", b"bytes"])

View File

@@ -0,0 +1,644 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web._auth}.
"""
import base64
from zope.interface import implementer
from zope.interface.verify import verifyObject
from twisted.cred import error, portal
from twisted.cred.checkers import (
ANONYMOUS,
AllowAnonymousAccess,
InMemoryUsernamePasswordDatabaseDontUse,
)
from twisted.cred.credentials import IUsernamePassword
from twisted.internet.address import IPv4Address
from twisted.internet.error import ConnectionDone
from twisted.internet.testing import EventLoggingObserver
from twisted.logger import globalLogPublisher
from twisted.python.failure import Failure
from twisted.trial import unittest
from twisted.web._auth import basic, digest
from twisted.web._auth.basic import BasicCredentialFactory
from twisted.web._auth.wrapper import HTTPAuthSessionWrapper, UnauthorizedResource
from twisted.web.iweb import ICredentialFactory
from twisted.web.resource import IResource, Resource, getChildForRequest
from twisted.web.server import NOT_DONE_YET
from twisted.web.static import Data
from twisted.web.test.test_web import DummyRequest
def b64encode(s):
return base64.b64encode(s).strip()
class BasicAuthTestsMixin:
"""
L{TestCase} mixin class which defines a number of tests for
L{basic.BasicCredentialFactory}. Because this mixin defines C{setUp}, it
must be inherited before L{TestCase}.
"""
def setUp(self):
self.request = self.makeRequest()
self.realm = b"foo"
self.username = b"dreid"
self.password = b"S3CuR1Ty"
self.credentialFactory = basic.BasicCredentialFactory(self.realm)
def makeRequest(self, method=b"GET", clientAddress=None):
"""
Create a request object to be passed to
L{basic.BasicCredentialFactory.decode} along with a response value.
Override this in a subclass.
"""
raise NotImplementedError(f"{self.__class__!r} did not implement makeRequest")
def test_interface(self):
"""
L{BasicCredentialFactory} implements L{ICredentialFactory}.
"""
self.assertTrue(verifyObject(ICredentialFactory, self.credentialFactory))
def test_usernamePassword(self):
"""
L{basic.BasicCredentialFactory.decode} turns a base64-encoded response
into a L{UsernamePassword} object with a password which reflects the
one which was encoded in the response.
"""
response = b64encode(b"".join([self.username, b":", self.password]))
creds = self.credentialFactory.decode(response, self.request)
self.assertTrue(IUsernamePassword.providedBy(creds))
self.assertTrue(creds.checkPassword(self.password))
self.assertFalse(creds.checkPassword(self.password + b"wrong"))
def test_incorrectPadding(self):
"""
L{basic.BasicCredentialFactory.decode} decodes a base64-encoded
response with incorrect padding.
"""
response = b64encode(b"".join([self.username, b":", self.password]))
response = response.strip(b"=")
creds = self.credentialFactory.decode(response, self.request)
self.assertTrue(verifyObject(IUsernamePassword, creds))
self.assertTrue(creds.checkPassword(self.password))
def test_invalidEncoding(self):
"""
L{basic.BasicCredentialFactory.decode} raises L{LoginFailed} if passed
a response which is not base64-encoded.
"""
response = b"x" # one byte cannot be valid base64 text
self.assertRaises(
error.LoginFailed,
self.credentialFactory.decode,
response,
self.makeRequest(),
)
def test_invalidCredentials(self):
"""
L{basic.BasicCredentialFactory.decode} raises L{LoginFailed} when
passed a response which is not valid base64-encoded text.
"""
response = b64encode(b"123abc+/")
self.assertRaises(
error.LoginFailed,
self.credentialFactory.decode,
response,
self.makeRequest(),
)
class RequestMixin:
def makeRequest(self, method=b"GET", clientAddress=None):
"""
Create a L{DummyRequest} (change me to create a
L{twisted.web.http.Request} instead).
"""
if clientAddress is None:
clientAddress = IPv4Address("TCP", "localhost", 1234)
request = DummyRequest(b"/")
request.method = method
request.client = clientAddress
return request
class BasicAuthTests(RequestMixin, BasicAuthTestsMixin, unittest.TestCase):
"""
Basic authentication tests which use L{twisted.web.http.Request}.
"""
class DigestAuthTests(RequestMixin, unittest.TestCase):
"""
Digest authentication tests which use L{twisted.web.http.Request}.
"""
def setUp(self):
"""
Create a DigestCredentialFactory for testing
"""
self.realm = b"test realm"
self.algorithm = b"md5"
self.credentialFactory = digest.DigestCredentialFactory(
self.algorithm, self.realm
)
self.request = self.makeRequest()
def test_decode(self):
"""
L{digest.DigestCredentialFactory.decode} calls the C{decode} method on
L{twisted.cred.digest.DigestCredentialFactory} with the HTTP method and
host of the request.
"""
host = b"169.254.0.1"
method = b"GET"
done = [False]
response = object()
def check(_response, _method, _host):
self.assertEqual(response, _response)
self.assertEqual(method, _method)
self.assertEqual(host, _host)
done[0] = True
self.patch(self.credentialFactory.digest, "decode", check)
req = self.makeRequest(method, IPv4Address("TCP", host, 81))
self.credentialFactory.decode(response, req)
self.assertTrue(done[0])
def test_interface(self):
"""
L{DigestCredentialFactory} implements L{ICredentialFactory}.
"""
self.assertTrue(verifyObject(ICredentialFactory, self.credentialFactory))
def test_getChallenge(self):
"""
The challenge issued by L{DigestCredentialFactory.getChallenge} must
include C{'qop'}, C{'realm'}, C{'algorithm'}, C{'nonce'}, and
C{'opaque'} keys. The values for the C{'realm'} and C{'algorithm'}
keys must match the values supplied to the factory's initializer.
None of the values may have newlines in them.
"""
challenge = self.credentialFactory.getChallenge(self.request)
self.assertEqual(challenge["qop"], b"auth")
self.assertEqual(challenge["realm"], b"test realm")
self.assertEqual(challenge["algorithm"], b"md5")
self.assertIn("nonce", challenge)
self.assertIn("opaque", challenge)
for v in challenge.values():
self.assertNotIn(b"\n", v)
def test_getChallengeWithoutClientIP(self):
"""
L{DigestCredentialFactory.getChallenge} can issue a challenge even if
the L{Request} it is passed returns L{None} from C{getClientIP}.
"""
request = self.makeRequest(b"GET", None)
challenge = self.credentialFactory.getChallenge(request)
self.assertEqual(challenge["qop"], b"auth")
self.assertEqual(challenge["realm"], b"test realm")
self.assertEqual(challenge["algorithm"], b"md5")
self.assertIn("nonce", challenge)
self.assertIn("opaque", challenge)
class UnauthorizedResourceTests(RequestMixin, unittest.TestCase):
"""
Tests for L{UnauthorizedResource}.
"""
def test_getChildWithDefault(self):
"""
An L{UnauthorizedResource} is every child of itself.
"""
resource = UnauthorizedResource([])
self.assertIdentical(resource.getChildWithDefault("foo", None), resource)
self.assertIdentical(resource.getChildWithDefault("bar", None), resource)
def _unauthorizedRenderTest(self, request):
"""
Render L{UnauthorizedResource} for the given request object and verify
that the response code is I{Unauthorized} and that a I{WWW-Authenticate}
header is set in the response containing a challenge.
"""
resource = UnauthorizedResource([BasicCredentialFactory("example.com")])
request.render(resource)
self.assertEqual(request.responseCode, 401)
self.assertEqual(
request.responseHeaders.getRawHeaders(b"www-authenticate"),
[b'basic realm="example.com"'],
)
def test_render(self):
"""
L{UnauthorizedResource} renders with a 401 response code and a
I{WWW-Authenticate} header and puts a simple unauthorized message
into the response body.
"""
request = self.makeRequest()
self._unauthorizedRenderTest(request)
self.assertEqual(b"Unauthorized", b"".join(request.written))
def test_renderHEAD(self):
"""
The rendering behavior of L{UnauthorizedResource} for a I{HEAD} request
is like its handling of a I{GET} request, but no response body is
written.
"""
request = self.makeRequest(method=b"HEAD")
self._unauthorizedRenderTest(request)
self.assertEqual(b"", b"".join(request.written))
def test_renderQuotesRealm(self):
"""
The realm value included in the I{WWW-Authenticate} header set in
the response when L{UnauthorizedResounrce} is rendered has quotes
and backslashes escaped.
"""
resource = UnauthorizedResource([BasicCredentialFactory('example\\"foo')])
request = self.makeRequest()
request.render(resource)
self.assertEqual(
request.responseHeaders.getRawHeaders(b"www-authenticate"),
[b'basic realm="example\\\\\\"foo"'],
)
def test_renderQuotesDigest(self):
"""
The digest value included in the I{WWW-Authenticate} header
set in the response when L{UnauthorizedResource} is rendered
has quotes and backslashes escaped.
"""
resource = UnauthorizedResource(
[digest.DigestCredentialFactory(b"md5", b'example\\"foo')]
)
request = self.makeRequest()
request.render(resource)
authHeader = request.responseHeaders.getRawHeaders(b"www-authenticate")[0]
self.assertIn(b'realm="example\\\\\\"foo"', authHeader)
self.assertIn(b'hm="md5', authHeader)
implementer(portal.IRealm)
class Realm:
"""
A simple L{IRealm} implementation which gives out L{WebAvatar} for any
avatarId.
@type loggedIn: C{int}
@ivar loggedIn: The number of times C{requestAvatar} has been invoked for
L{IResource}.
@type loggedOut: C{int}
@ivar loggedOut: The number of times the logout callback has been invoked.
"""
def __init__(self, avatarFactory):
self.loggedOut = 0
self.loggedIn = 0
self.avatarFactory = avatarFactory
def requestAvatar(self, avatarId, mind, *interfaces):
if IResource in interfaces:
self.loggedIn += 1
return IResource, self.avatarFactory(avatarId), self.logout
raise NotImplementedError()
def logout(self):
self.loggedOut += 1
class HTTPAuthHeaderTests(unittest.TestCase):
"""
Tests for L{HTTPAuthSessionWrapper}.
"""
makeRequest = DummyRequest
def setUp(self):
"""
Create a realm, portal, and L{HTTPAuthSessionWrapper} to use in the tests.
"""
self.username = b"foo bar"
self.password = b"bar baz"
self.avatarContent = b"contents of the avatar resource itself"
self.childName = b"foo-child"
self.childContent = b"contents of the foo child of the avatar"
self.checker = InMemoryUsernamePasswordDatabaseDontUse()
self.checker.addUser(self.username, self.password)
self.avatar = Data(self.avatarContent, "text/plain")
self.avatar.putChild(self.childName, Data(self.childContent, "text/plain"))
self.avatars = {self.username: self.avatar}
self.realm = Realm(self.avatars.get)
self.portal = portal.Portal(self.realm, [self.checker])
self.credentialFactories = []
self.wrapper = HTTPAuthSessionWrapper(self.portal, self.credentialFactories)
def _authorizedBasicLogin(self, request):
"""
Add an I{basic authorization} header to the given request and then
dispatch it, starting from C{self.wrapper} and returning the resulting
L{IResource}.
"""
authorization = b64encode(self.username + b":" + self.password)
request.requestHeaders.addRawHeader(b"authorization", b"Basic " + authorization)
return getChildForRequest(self.wrapper, request)
def test_getChildWithDefault(self):
"""
Resource traversal which encounters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} instance when the request does
not have the required I{Authorization} headers.
"""
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(result):
self.assertEqual(request.responseCode, 401)
d.addCallback(cbFinished)
request.render(child)
return d
def _invalidAuthorizationTest(self, response):
"""
Create a request with the given value as the value of an
I{Authorization} header and perform resource traversal with it,
starting at C{self.wrapper}. Assert that the result is a 401 response
code. Return a L{Deferred} which fires when this is all done.
"""
self.credentialFactories.append(BasicCredentialFactory("example.com"))
request = self.makeRequest([self.childName])
request.requestHeaders.addRawHeader(b"authorization", response)
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(result):
self.assertEqual(request.responseCode, 401)
d.addCallback(cbFinished)
request.render(child)
return d
def test_getChildWithDefaultUnauthorizedUser(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with a user which does not exist.
"""
return self._invalidAuthorizationTest(b"Basic " + b64encode(b"foo:bar"))
def test_getChildWithDefaultUnauthorizedPassword(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with a user which exists and the wrong
password.
"""
return self._invalidAuthorizationTest(
b"Basic " + b64encode(self.username + b":bar")
)
def test_getChildWithDefaultUnrecognizedScheme(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with an unrecognized scheme.
"""
return self._invalidAuthorizationTest(b"Quux foo bar baz")
def test_getChildWithDefaultAuthorized(self):
"""
Resource traversal which encounters an L{HTTPAuthSessionWrapper}
results in an L{IResource} which renders the L{IResource} avatar
retrieved from the portal when the request has a valid I{Authorization}
header.
"""
self.credentialFactories.append(BasicCredentialFactory("example.com"))
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [self.childContent])
d.addCallback(cbFinished)
request.render(child)
return d
def test_renderAuthorized(self):
"""
Resource traversal which terminates at an L{HTTPAuthSessionWrapper}
and includes correct authentication headers results in the
L{IResource} avatar (not one of its children) retrieved from the
portal being rendered.
"""
self.credentialFactories.append(BasicCredentialFactory("example.com"))
# Request it exactly, not any of its children.
request = self.makeRequest([])
child = self._authorizedBasicLogin(request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [self.avatarContent])
d.addCallback(cbFinished)
request.render(child)
return d
def test_getChallengeCalledWithRequest(self):
"""
When L{HTTPAuthSessionWrapper} finds an L{ICredentialFactory} to issue
a challenge, it calls the C{getChallenge} method with the request as an
argument.
"""
@implementer(ICredentialFactory)
class DumbCredentialFactory:
scheme = b"dumb"
def __init__(self):
self.requests = []
def getChallenge(self, request):
self.requests.append(request)
return {}
factory = DumbCredentialFactory()
self.credentialFactories.append(factory)
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(factory.requests, [request])
d.addCallback(cbFinished)
request.render(child)
return d
def _logoutTest(self):
"""
Issue a request for an authentication-protected resource using valid
credentials and then return the C{DummyRequest} instance which was
used.
This is a helper for tests about the behavior of the logout
callback.
"""
self.credentialFactories.append(BasicCredentialFactory("example.com"))
class SlowerResource(Resource):
def render(self, request):
return NOT_DONE_YET
self.avatar.putChild(self.childName, SlowerResource())
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
request.render(child)
self.assertEqual(self.realm.loggedOut, 0)
return request
def test_logout(self):
"""
The realm's logout callback is invoked after the resource is rendered.
"""
request = self._logoutTest()
request.finish()
self.assertEqual(self.realm.loggedOut, 1)
def test_logoutOnError(self):
"""
The realm's logout callback is also invoked if there is an error
generating the response (for example, if the client disconnects
early).
"""
request = self._logoutTest()
request.processingFailed(Failure(ConnectionDone("Simulated disconnect")))
self.assertEqual(self.realm.loggedOut, 1)
def test_decodeRaises(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has a I{Basic
Authorization} header which cannot be decoded using base64.
"""
self.credentialFactories.append(BasicCredentialFactory("example.com"))
request = self.makeRequest([self.childName])
request.requestHeaders.addRawHeader(
b"authorization", b"Basic decode should fail"
)
child = getChildForRequest(self.wrapper, request)
self.assertIsInstance(child, UnauthorizedResource)
def test_selectParseResponse(self):
"""
L{HTTPAuthSessionWrapper._selectParseHeader} returns a two-tuple giving
the L{ICredentialFactory} to use to parse the header and a string
containing the portion of the header which remains to be parsed.
"""
basicAuthorization = b"Basic abcdef123456"
self.assertEqual(
self.wrapper._selectParseHeader(basicAuthorization), (None, None)
)
factory = BasicCredentialFactory("example.com")
self.credentialFactories.append(factory)
self.assertEqual(
self.wrapper._selectParseHeader(basicAuthorization),
(factory, b"abcdef123456"),
)
def test_unexpectedDecodeError(self):
"""
Any unexpected exception raised by the credential factory's C{decode}
method results in a 500 response code and causes the exception to be
logged.
"""
logObserver = EventLoggingObserver.createWithCleanup(self, globalLogPublisher)
class UnexpectedException(Exception):
pass
class BadFactory:
scheme = b"bad"
def getChallenge(self, client):
return {}
def decode(self, response, request):
raise UnexpectedException()
self.credentialFactories.append(BadFactory())
request = self.makeRequest([self.childName])
request.requestHeaders.addRawHeader(b"authorization", b"Bad abc")
child = getChildForRequest(self.wrapper, request)
request.render(child)
self.assertEqual(request.responseCode, 500)
self.assertEquals(1, len(logObserver))
self.assertIsInstance(logObserver[0]["log_failure"].value, UnexpectedException)
self.assertEqual(len(self.flushLoggedErrors(UnexpectedException)), 1)
def test_unexpectedLoginError(self):
"""
Any unexpected failure from L{Portal.login} results in a 500 response
code and causes the failure to be logged.
"""
logObserver = EventLoggingObserver.createWithCleanup(self, globalLogPublisher)
class UnexpectedException(Exception):
pass
class BrokenChecker:
credentialInterfaces = (IUsernamePassword,)
def requestAvatarId(self, credentials):
raise UnexpectedException()
self.portal.registerChecker(BrokenChecker())
self.credentialFactories.append(BasicCredentialFactory("example.com"))
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
request.render(child)
self.assertEqual(request.responseCode, 500)
self.assertEquals(1, len(logObserver))
self.assertIsInstance(logObserver[0]["log_failure"].value, UnexpectedException)
self.assertEqual(len(self.flushLoggedErrors(UnexpectedException)), 1)
def test_anonymousAccess(self):
"""
Anonymous requests are allowed if a L{Portal} has an anonymous checker
registered.
"""
unprotectedContents = b"contents of the unprotected child resource"
self.avatars[ANONYMOUS] = Resource()
self.avatars[ANONYMOUS].putChild(
self.childName, Data(unprotectedContents, "text/plain")
)
self.portal.registerChecker(AllowAnonymousAccess())
self.credentialFactories.append(BasicCredentialFactory("example.com"))
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [unprotectedContents])
d.addCallback(cbFinished)
request.render(child)
return d

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,113 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test L{twisted.web.pages}
"""
from typing import cast
from twisted.trial.unittest import SynchronousTestCase
from twisted.web.http_headers import Headers
from twisted.web.iweb import IRequest
from twisted.web.pages import errorPage, forbidden, notFound
from twisted.web.resource import IResource
from twisted.web.test.requesthelper import DummyRequest
def _render(resource: IResource) -> DummyRequest:
"""
Render a response using the given resource.
@param resource: The resource to use to handle the request.
@returns: The request that the resource handled,
"""
request = DummyRequest([b""])
# The cast is necessary because DummyRequest isn't annotated
# as an IRequest, and this can't be trivially done. See
# https://github.com/twisted/twisted/issues/11719
resource.render(cast(IRequest, request))
return request
class ErrorPageTests(SynchronousTestCase):
"""
Test L{twisted.web.pages._ErrorPage} and its public aliases L{errorPage},
L{notFound} and L{forbidden}.
"""
maxDiff = None
def assertResponse(self, request: DummyRequest, code: int, body: bytes) -> None:
self.assertEqual(request.responseCode, code)
self.assertEqual(
request.responseHeaders,
Headers({b"content-type": [b"text/html; charset=utf-8"]}),
)
self.assertEqual(
# Decode to str because unittest somehow still doesn't diff bytes
# without truncating them in 2022.
b"".join(request.written).decode("latin-1"),
body.decode("latin-1"),
)
def test_escapesHTML(self) -> None:
"""
The I{brief} and I{detail} parameters are HTML-escaped on render.
"""
self.assertResponse(
_render(errorPage(400, "A & B", "<script>alert('oops!')")),
400,
(
b"<!DOCTYPE html>\n"
b"<html><head><title>400 - A &amp; B</title></head>"
b"<body><h1>A &amp; B</h1><p>&lt;script&gt;alert('oops!')"
b"</p></body></html>"
),
)
def test_getChild(self) -> None:
"""
The C{getChild} method of the resource returned by L{errorPage} returns
the L{_ErrorPage} it is called on.
"""
page = errorPage(404, "foo", "bar")
self.assertIs(
page.getChild(b"name", cast(IRequest, DummyRequest([b""]))),
page,
)
def test_notFoundDefaults(self) -> None:
"""
The default arguments to L{twisted.web.pages.notFound} produce
a reasonable error page.
"""
self.assertResponse(
_render(notFound()),
404,
(
b"<!DOCTYPE html>\n"
b"<html><head><title>404 - No Such Resource</title></head>"
b"<body><h1>No Such Resource</h1>"
b"<p>Sorry. No luck finding that resource.</p>"
b"</body></html>"
),
)
def test_forbiddenDefaults(self) -> None:
"""
The default arguments to L{twisted.web.pages.forbidden} produce
a reasonable error page.
"""
self.assertResponse(
_render(forbidden()),
403,
(
b"<!DOCTYPE html>\n"
b"<html><head><title>403 - Forbidden Resource</title></head>"
b"<body><h1>Forbidden Resource</h1>"
b"<p>Sorry, resource is forbidden.</p>"
b"</body></html>"
),
)

View File

@@ -0,0 +1,548 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test for L{twisted.web.proxy}.
"""
from twisted.internet.testing import MemoryReactor, StringTransportWithDisconnection
from twisted.trial.unittest import TestCase
from twisted.web.proxy import (
ProxyClient,
ProxyClientFactory,
ProxyRequest,
ReverseProxyRequest,
ReverseProxyResource,
)
from twisted.web.resource import Resource
from twisted.web.server import Site
from twisted.web.test.test_web import DummyRequest
class ReverseProxyResourceTests(TestCase):
"""
Tests for L{ReverseProxyResource}.
"""
def _testRender(self, uri, expectedURI):
"""
Check that a request pointing at C{uri} produce a new proxy connection,
with the path of this request pointing at C{expectedURI}.
"""
root = Resource()
reactor = MemoryReactor()
resource = ReverseProxyResource("127.0.0.1", 1234, b"/path", reactor)
root.putChild(b"index", resource)
site = Site(root)
transport = StringTransportWithDisconnection()
channel = site.buildProtocol(None)
channel.makeConnection(transport)
# Clear the timeout if the tests failed
self.addCleanup(channel.connectionLost, None)
channel.dataReceived(b"GET " + uri + b" HTTP/1.1\r\nAccept: text/html\r\n\r\n")
[(host, port, factory, _timeout, _bind_addr)] = reactor.tcpClients
# Check that one connection has been created, to the good host/port
self.assertEqual(host, "127.0.0.1")
self.assertEqual(port, 1234)
# Check the factory passed to the connect, and its given path
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.rest, expectedURI)
self.assertEqual(factory.headers[b"host"], b"127.0.0.1:1234")
def test_render(self):
"""
Test that L{ReverseProxyResource.render} initiates a connection to the
given server with a L{ProxyClientFactory} as parameter.
"""
return self._testRender(b"/index", b"/path")
def test_render_subpage(self):
"""
Test that L{ReverseProxyResource.render} will instantiate a child
resource that will initiate a connection to the given server
requesting the apropiate url subpath.
"""
return self._testRender(b"/index/page1", b"/path/page1")
def test_renderWithQuery(self):
"""
Test that L{ReverseProxyResource.render} passes query parameters to the
created factory.
"""
return self._testRender(b"/index?foo=bar", b"/path?foo=bar")
def test_getChild(self):
"""
The L{ReverseProxyResource.getChild} method should return a resource
instance with the same class as the originating resource, forward
port, host, and reactor values, and update the path value with the
value passed.
"""
reactor = MemoryReactor()
resource = ReverseProxyResource("127.0.0.1", 1234, b"/path", reactor)
child = resource.getChild(b"foo", None)
# The child should keep the same class
self.assertIsInstance(child, ReverseProxyResource)
self.assertEqual(child.path, b"/path/foo")
self.assertEqual(child.port, 1234)
self.assertEqual(child.host, "127.0.0.1")
self.assertIdentical(child.reactor, resource.reactor)
def test_getChildWithSpecial(self):
"""
The L{ReverseProxyResource} return by C{getChild} has a path which has
already been quoted.
"""
resource = ReverseProxyResource("127.0.0.1", 1234, b"/path")
child = resource.getChild(b" /%", None)
self.assertEqual(child.path, b"/path/%20%2F%25")
class DummyChannel:
"""
A dummy HTTP channel, that does nothing but holds a transport and saves
connection lost.
@ivar transport: the transport used by the client.
@ivar lostReason: the reason saved at connection lost.
"""
def __init__(self, transport):
"""
Hold a reference to the transport.
"""
self.transport = transport
self.lostReason = None
def connectionLost(self, reason):
"""
Keep track of the connection lost reason.
"""
self.lostReason = reason
def getPeer(self):
"""
Get peer information from the transport.
"""
return self.transport.getPeer()
def getHost(self):
"""
Get host information from the transport.
"""
return self.transport.getHost()
class ProxyClientTests(TestCase):
"""
Tests for L{ProxyClient}.
"""
def _parseOutHeaders(self, content):
"""
Parse the headers out of some web content.
@param content: Bytes received from a web server.
@return: A tuple of (requestLine, headers, body). C{headers} is a dict
of headers, C{requestLine} is the first line (e.g. "POST /foo ...")
and C{body} is whatever is left.
"""
headers, body = content.split(b"\r\n\r\n")
headers = headers.split(b"\r\n")
requestLine = headers.pop(0)
return (requestLine, dict(header.split(b": ") for header in headers), body)
def makeRequest(self, path):
"""
Make a dummy request object for the URL path.
@param path: A URL path, beginning with a slash.
@return: A L{DummyRequest}.
"""
return DummyRequest(path)
def makeProxyClient(self, request, method=b"GET", headers=None, requestBody=b""):
"""
Make a L{ProxyClient} object used for testing.
@param request: The request to use.
@param method: The HTTP method to use, GET by default.
@param headers: The HTTP headers to use expressed as a dict. If not
provided, defaults to {'accept': 'text/html'}.
@param requestBody: The body of the request. Defaults to the empty
string.
@return: A L{ProxyClient}
"""
if headers is None:
headers = {b"accept": b"text/html"}
path = b"/" + request.postpath
return ProxyClient(method, path, b"HTTP/1.0", headers, requestBody, request)
def connectProxy(self, proxyClient):
"""
Connect a proxy client to a L{StringTransportWithDisconnection}.
@param proxyClient: A L{ProxyClient}.
@return: The L{StringTransportWithDisconnection}.
"""
clientTransport = StringTransportWithDisconnection()
clientTransport.protocol = proxyClient
proxyClient.makeConnection(clientTransport)
return clientTransport
def assertForwardsHeaders(self, proxyClient, requestLine, headers):
"""
Assert that C{proxyClient} sends C{headers} when it connects.
@param proxyClient: A L{ProxyClient}.
@param requestLine: The request line we expect to be sent.
@param headers: A dict of headers we expect to be sent.
@return: If the assertion is successful, return the request body as
bytes.
"""
self.connectProxy(proxyClient)
requestContent = proxyClient.transport.value()
receivedLine, receivedHeaders, body = self._parseOutHeaders(requestContent)
self.assertEqual(receivedLine, requestLine)
self.assertEqual(receivedHeaders, headers)
return body
def makeResponseBytes(self, code, message, headers, body):
lines = [b"HTTP/1.0 " + str(code).encode("ascii") + b" " + message]
for header, values in headers:
for value in values:
lines.append(header + b": " + value)
lines.extend([b"", body])
return b"\r\n".join(lines)
def assertForwardsResponse(self, request, code, message, headers, body):
"""
Assert that C{request} has forwarded a response from the server.
@param request: A L{DummyRequest}.
@param code: The expected HTTP response code.
@param message: The expected HTTP message.
@param headers: The expected HTTP headers.
@param body: The expected response body.
"""
self.assertEqual(request.responseCode, code)
self.assertEqual(request.responseMessage, message)
receivedHeaders = list(request.responseHeaders.getAllRawHeaders())
receivedHeaders.sort()
expectedHeaders = headers[:]
expectedHeaders.sort()
self.assertEqual(receivedHeaders, expectedHeaders)
self.assertEqual(b"".join(request.written), body)
def _testDataForward(
self,
code,
message,
headers,
body,
method=b"GET",
requestBody=b"",
loseConnection=True,
):
"""
Build a fake proxy connection, and send C{data} over it, checking that
it's forwarded to the originating request.
"""
request = self.makeRequest(b"foo")
client = self.makeProxyClient(
request, method, {b"accept": b"text/html"}, requestBody
)
receivedBody = self.assertForwardsHeaders(
client,
method + b" /foo HTTP/1.0",
{b"connection": b"close", b"accept": b"text/html"},
)
self.assertEqual(receivedBody, requestBody)
# Fake an answer
client.dataReceived(self.makeResponseBytes(code, message, headers, body))
# Check that the response data has been forwarded back to the original
# requester.
self.assertForwardsResponse(request, code, message, headers, body)
# Check that when the response is done, the request is finished.
if loseConnection:
client.transport.loseConnection()
# Even if we didn't call loseConnection, the transport should be
# disconnected. This lets us not rely on the server to close our
# sockets for us.
self.assertFalse(client.transport.connected)
self.assertEqual(request.finished, 1)
def test_forward(self):
"""
When connected to the server, L{ProxyClient} should send the saved
request, with modifications of the headers, and then forward the result
to the parent request.
"""
return self._testDataForward(
200, b"OK", [(b"Foo", [b"bar", b"baz"])], b"Some data\r\n"
)
def test_postData(self):
"""
Try to post content in the request, and check that the proxy client
forward the body of the request.
"""
return self._testDataForward(
200, b"OK", [(b"Foo", [b"bar"])], b"Some data\r\n", b"POST", b"Some content"
)
def test_statusWithMessage(self):
"""
If the response contains a status with a message, it should be
forwarded to the parent request with all the information.
"""
return self._testDataForward(404, b"Not Found", [], b"")
def test_contentLength(self):
"""
If the response contains a I{Content-Length} header, the inbound
request object should still only have C{finish} called on it once.
"""
data = b"foo bar baz"
return self._testDataForward(
200, b"OK", [(b"Content-Length", [str(len(data)).encode("ascii")])], data
)
def test_losesConnection(self):
"""
If the response contains a I{Content-Length} header, the outgoing
connection is closed when all response body data has been received.
"""
data = b"foo bar baz"
return self._testDataForward(
200,
b"OK",
[(b"Content-Length", [str(len(data)).encode("ascii")])],
data,
loseConnection=False,
)
def test_headersCleanups(self):
"""
The headers given at initialization should be modified:
B{proxy-connection} should be removed if present, and B{connection}
should be added.
"""
client = ProxyClient(
b"GET",
b"/foo",
b"HTTP/1.0",
{b"accept": b"text/html", b"proxy-connection": b"foo"},
b"",
None,
)
self.assertEqual(
client.headers, {b"accept": b"text/html", b"connection": b"close"}
)
def test_keepaliveNotForwarded(self):
"""
The proxy doesn't really know what to do with keepalive things from
the remote server, so we stomp over any keepalive header we get from
the client.
"""
headers = {
b"accept": b"text/html",
b"keep-alive": b"300",
b"connection": b"keep-alive",
}
expectedHeaders = headers.copy()
expectedHeaders[b"connection"] = b"close"
del expectedHeaders[b"keep-alive"]
client = ProxyClient(b"GET", b"/foo", b"HTTP/1.0", headers, b"", None)
self.assertForwardsHeaders(client, b"GET /foo HTTP/1.0", expectedHeaders)
def test_defaultHeadersOverridden(self):
"""
L{server.Request} within the proxy sets certain response headers by
default. When we get these headers back from the remote server, the
defaults are overridden rather than simply appended.
"""
request = self.makeRequest(b"foo")
request.responseHeaders.setRawHeaders(b"server", [b"old-bar"])
request.responseHeaders.setRawHeaders(b"date", [b"old-baz"])
request.responseHeaders.setRawHeaders(b"content-type", [b"old/qux"])
client = self.makeProxyClient(request, headers={b"accept": b"text/html"})
self.connectProxy(client)
headers = {
b"Server": [b"bar"],
b"Date": [b"2010-01-01"],
b"Content-Type": [b"application/x-baz"],
}
client.dataReceived(self.makeResponseBytes(200, b"OK", headers.items(), b""))
self.assertForwardsResponse(request, 200, b"OK", list(headers.items()), b"")
class ProxyClientFactoryTests(TestCase):
"""
Tests for L{ProxyClientFactory}.
"""
def test_connectionFailed(self):
"""
Check that L{ProxyClientFactory.clientConnectionFailed} produces
a B{501} response to the parent request.
"""
request = DummyRequest([b"foo"])
factory = ProxyClientFactory(
b"GET", b"/foo", b"HTTP/1.0", {b"accept": b"text/html"}, "", request
)
factory.clientConnectionFailed(None, None)
self.assertEqual(request.responseCode, 501)
self.assertEqual(request.responseMessage, b"Gateway error")
self.assertEqual(
list(request.responseHeaders.getAllRawHeaders()),
[(b"Content-Type", [b"text/html"])],
)
self.assertEqual(b"".join(request.written), b"<H1>Could not connect</H1>")
self.assertEqual(request.finished, 1)
def test_buildProtocol(self):
"""
L{ProxyClientFactory.buildProtocol} should produce a L{ProxyClient}
with the same values of attributes (with updates on the headers).
"""
factory = ProxyClientFactory(
b"GET", b"/foo", b"HTTP/1.0", {b"accept": b"text/html"}, b"Some data", None
)
proto = factory.buildProtocol(None)
self.assertIsInstance(proto, ProxyClient)
self.assertEqual(proto.command, b"GET")
self.assertEqual(proto.rest, b"/foo")
self.assertEqual(proto.data, b"Some data")
self.assertEqual(
proto.headers, {b"accept": b"text/html", b"connection": b"close"}
)
class ProxyRequestTests(TestCase):
"""
Tests for L{ProxyRequest}.
"""
def _testProcess(self, uri, expectedURI, method=b"GET", data=b""):
"""
Build a request pointing at C{uri}, and check that a proxied request
is created, pointing a C{expectedURI}.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ProxyRequest(channel, False, reactor)
request.gotLength(len(data))
request.handleContentChunk(data)
request.requestReceived(method, b"http://example.com" + uri, b"HTTP/1.0")
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], "example.com")
self.assertEqual(reactor.tcpClients[0][1], 80)
factory = reactor.tcpClients[0][2]
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.command, method)
self.assertEqual(factory.version, b"HTTP/1.0")
self.assertEqual(factory.headers, {b"host": b"example.com"})
self.assertEqual(factory.data, data)
self.assertEqual(factory.rest, expectedURI)
self.assertEqual(factory.father, request)
def test_process(self):
"""
L{ProxyRequest.process} should create a connection to the given server,
with a L{ProxyClientFactory} as connection factory, with the correct
parameters:
- forward comment, version and data values
- update headers with the B{host} value
- remove the host from the URL
- pass the request as parent request
"""
return self._testProcess(b"/foo/bar", b"/foo/bar")
def test_processWithoutTrailingSlash(self):
"""
If the incoming request doesn't contain a slash,
L{ProxyRequest.process} should add one when instantiating
L{ProxyClientFactory}.
"""
return self._testProcess(b"", b"/")
def test_processWithData(self):
"""
L{ProxyRequest.process} should be able to retrieve request body and
to forward it.
"""
return self._testProcess(b"/foo/bar", b"/foo/bar", b"POST", b"Some content")
def test_processWithPort(self):
"""
Check that L{ProxyRequest.process} correctly parse port in the incoming
URL, and create an outgoing connection with this port.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ProxyRequest(channel, False, reactor)
request.gotLength(0)
request.requestReceived(b"GET", b"http://example.com:1234/foo/bar", b"HTTP/1.0")
# That should create one connection, with the port parsed from the URL
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], "example.com")
self.assertEqual(reactor.tcpClients[0][1], 1234)
class DummyFactory:
"""
A simple holder for C{host} and C{port} information.
"""
def __init__(self, host, port):
self.host = host
self.port = port
class ReverseProxyRequestTests(TestCase):
"""
Tests for L{ReverseProxyRequest}.
"""
def test_process(self):
"""
L{ReverseProxyRequest.process} should create a connection to its
factory host/port, using a L{ProxyClientFactory} instantiated with the
correct parameters, and particularly set the B{host} header to the
factory host.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ReverseProxyRequest(channel, False, reactor)
request.factory = DummyFactory("example.com", 1234)
request.gotLength(0)
request.requestReceived(b"GET", b"/foo/bar", b"HTTP/1.0")
# Check that one connection has been created, to the good host/port
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], "example.com")
self.assertEqual(reactor.tcpClients[0][1], 1234)
# Check the factory passed to the connect, and its headers
factory = reactor.tcpClients[0][2]
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.headers, {b"host": b"example.com"})

View File

@@ -0,0 +1,300 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.resource}.
"""
from twisted.trial.unittest import TestCase
from twisted.web.error import UnsupportedMethod
from twisted.web.http_headers import Headers
from twisted.web.iweb import IRequest
from twisted.web.resource import (
FORBIDDEN,
NOT_FOUND,
Resource,
_UnsafeErrorPage as ErrorPage,
_UnsafeForbiddenResource as ForbiddenResource,
_UnsafeNoResource as NoResource,
getChildForRequest,
)
from twisted.web.test.requesthelper import DummyRequest
class ErrorPageTests(TestCase):
"""
Tests for L{_UnafeErrorPage}, L{_UnsafeNoResource}, and
L{_UnsafeForbiddenResource}.
"""
def test_deprecatedErrorPage(self) -> None:
"""
The public C{twisted.web.resource.ErrorPage} alias for the
corresponding C{_Unsafe} class produces a deprecation warning when
called.
"""
_ = ErrorPage(123, "ono", "ono!")
[warning] = self.flushWarnings()
self.assertEqual(warning["category"], DeprecationWarning)
self.assertIn("twisted.web.pages.errorPage", warning["message"])
def test_deprecatedNoResource(self) -> None:
"""
The public C{twisted.web.resource.NoResource} alias for the
corresponding C{_Unsafe} class produces a deprecation warning when
called.
"""
_ = NoResource()
[warning] = self.flushWarnings()
self.assertEqual(warning["category"], DeprecationWarning)
self.assertIn("twisted.web.pages.notFound", warning["message"])
def test_deprecatedForbiddenResource(self) -> None:
"""
The public C{twisted.web.resource.ForbiddenResource} alias for the
corresponding C{_Unsafe} class produce a deprecation warning when
called.
"""
_ = ForbiddenResource()
[warning] = self.flushWarnings()
self.assertEqual(warning["category"], DeprecationWarning)
self.assertIn("twisted.web.pages.forbidden", warning["message"])
def test_getChild(self) -> None:
"""
The C{getChild} method of L{ErrorPage} returns the L{ErrorPage} it is
called on.
"""
page = ErrorPage(321, "foo", "bar")
self.assertIdentical(page.getChild(b"name", object()), page)
def _pageRenderingTest(
self, page: Resource, code: int, brief: str, detail: str
) -> None:
request = DummyRequest([b""])
template = (
"\n"
"<html>\n"
" <head><title>%s - %s</title></head>\n"
" <body>\n"
" <h1>%s</h1>\n"
" <p>%s</p>\n"
" </body>\n"
"</html>\n"
)
expected = template % (code, brief, brief, detail)
self.assertEqual(page.render(request), expected.encode("utf-8"))
self.assertEqual(request.responseCode, code)
self.assertEqual(
request.responseHeaders,
Headers({b"content-type": [b"text/html; charset=utf-8"]}),
)
def test_errorPageRendering(self) -> None:
"""
L{ErrorPage.render} returns a C{bytes} describing the error defined by
the response code and message passed to L{ErrorPage.__init__}. It also
uses that response code to set the response code on the L{Request}
passed in.
"""
code = 321
brief = "brief description text"
detail = "much longer text might go here"
page = ErrorPage(code, brief, detail)
self._pageRenderingTest(page, code, brief, detail)
def test_noResourceRendering(self) -> None:
"""
L{NoResource} sets the HTTP I{NOT FOUND} code.
"""
detail = "long message"
page = NoResource(detail)
self._pageRenderingTest(page, NOT_FOUND, "No Such Resource", detail)
def test_forbiddenResourceRendering(self) -> None:
"""
L{ForbiddenResource} sets the HTTP I{FORBIDDEN} code.
"""
detail = "longer message"
page = ForbiddenResource(detail)
self._pageRenderingTest(page, FORBIDDEN, "Forbidden Resource", detail)
class DynamicChild(Resource):
"""
A L{Resource} to be created on the fly by L{DynamicChildren}.
"""
def __init__(self, path: bytes, request: IRequest) -> None:
Resource.__init__(self)
self.path = path
self.request = request
class DynamicChildren(Resource):
"""
A L{Resource} with dynamic children.
"""
def getChild(self, path: bytes, request: IRequest) -> DynamicChild:
return DynamicChild(path, request)
class BytesReturnedRenderable(Resource):
"""
A L{Resource} with minimal capabilities to render a response.
"""
def __init__(self, response: bytes) -> None:
"""
@param response: A C{bytes} object giving the value to return from
C{render_GET}.
"""
Resource.__init__(self)
self._response = response
def render_GET(self, request: object) -> bytes:
"""
Render a response to a I{GET} request by returning a short byte string
to be written by the server.
"""
return self._response
class ImplicitAllowedMethods(Resource):
"""
A L{Resource} which implicitly defines its allowed methods by defining
renderers to handle them.
"""
def render_GET(self, request: object) -> None:
pass
def render_PUT(self, request: object) -> None:
pass
class ResourceTests(TestCase):
"""
Tests for L{Resource}.
"""
def test_staticChildren(self) -> None:
"""
L{Resource.putChild} adds a I{static} child to the resource. That child
is returned from any call to L{Resource.getChildWithDefault} for the
child's path.
"""
resource = Resource()
child = Resource()
sibling = Resource()
resource.putChild(b"foo", child)
resource.putChild(b"bar", sibling)
self.assertIdentical(
child, resource.getChildWithDefault(b"foo", DummyRequest([]))
)
def test_dynamicChildren(self) -> None:
"""
L{Resource.getChildWithDefault} delegates to L{Resource.getChild} when
the requested path is not associated with any static child.
"""
path = b"foo"
request = DummyRequest([])
resource = DynamicChildren()
child = resource.getChildWithDefault(path, request)
self.assertIsInstance(child, DynamicChild)
self.assertEqual(child.path, path)
self.assertIdentical(child.request, request)
def test_staticChildPathType(self) -> None:
"""
Test that passing the wrong type to putChild results in a warning,
and a failure in Python 3
"""
resource = Resource()
child = Resource()
sibling = Resource()
self.assertRaises(TypeError, resource.putChild, "foo", child)
self.assertRaises(TypeError, resource.putChild, None, sibling)
def test_defaultHEAD(self) -> None:
"""
When not otherwise overridden, L{Resource.render} treats a I{HEAD}
request as if it were a I{GET} request.
"""
expected = b"insert response here"
request = DummyRequest([])
request.method = b"HEAD"
resource = BytesReturnedRenderable(expected)
self.assertEqual(expected, resource.render(request))
def test_explicitAllowedMethods(self) -> None:
"""
The L{UnsupportedMethod} raised by L{Resource.render} for an unsupported
request method has a C{allowedMethods} attribute set to the value of the
C{allowedMethods} attribute of the L{Resource}, if it has one.
"""
expected = [b"GET", b"HEAD", b"PUT"]
resource = Resource()
resource.allowedMethods = expected
request = DummyRequest([])
request.method = b"FICTIONAL"
exc = self.assertRaises(UnsupportedMethod, resource.render, request)
self.assertEqual(set(expected), set(exc.allowedMethods))
def test_implicitAllowedMethods(self) -> None:
"""
The L{UnsupportedMethod} raised by L{Resource.render} for an unsupported
request method has a C{allowedMethods} attribute set to a list of the
methods supported by the L{Resource}, as determined by the
I{render_}-prefixed methods which it defines, if C{allowedMethods} is
not explicitly defined by the L{Resource}.
"""
expected = {b"GET", b"HEAD", b"PUT"}
resource = ImplicitAllowedMethods()
request = DummyRequest([])
request.method = b"FICTIONAL"
exc = self.assertRaises(UnsupportedMethod, resource.render, request)
self.assertEqual(expected, set(exc.allowedMethods))
class GetChildForRequestTests(TestCase):
"""
Tests for L{getChildForRequest}.
"""
def test_exhaustedPostPath(self) -> None:
"""
L{getChildForRequest} returns whatever resource has been reached by the
time the request's C{postpath} is empty.
"""
request = DummyRequest([])
resource = Resource()
result = getChildForRequest(resource, request)
self.assertIdentical(resource, result)
def test_leafResource(self) -> None:
"""
L{getChildForRequest} returns the first resource it encounters with a
C{isLeaf} attribute set to C{True}.
"""
request = DummyRequest([b"foo", b"bar"])
resource = Resource()
resource.isLeaf = True
result = getChildForRequest(resource, request)
self.assertIdentical(resource, result)
def test_postPathToPrePath(self) -> None:
"""
As path segments from the request are traversed, they are taken from
C{postpath} and put into C{prepath}.
"""
request = DummyRequest([b"foo", b"bar"])
root = Resource()
child = Resource()
child.isLeaf = True
root.putChild(b"foo", child)
self.assertIdentical(child, getChildForRequest(root, request))
self.assertEqual(request.prepath, [b"foo"])
self.assertEqual(request.postpath, [b"bar"])

View File

@@ -0,0 +1,121 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.script}.
"""
import os
from twisted.internet import defer
from twisted.python.filepath import FilePath
from twisted.trial.unittest import TestCase
from twisted.web.http import NOT_FOUND
from twisted.web.script import PythonScript, ResourceScriptDirectory
from twisted.web.test._util import _render
from twisted.web.test.requesthelper import DummyRequest
class ResourceScriptDirectoryTests(TestCase):
"""
Tests for L{ResourceScriptDirectory}.
"""
def test_renderNotFound(self) -> defer.Deferred[None]:
"""
L{ResourceScriptDirectory.render} sets the HTTP response code to I{NOT
FOUND}.
"""
resource = ResourceScriptDirectory(self.mktemp())
request = DummyRequest([b""])
d = _render(resource, request)
def cbRendered(ignored: object) -> None:
self.assertEqual(request.responseCode, NOT_FOUND)
return d.addCallback(cbRendered)
def test_notFoundChild(self) -> defer.Deferred[None]:
"""
L{ResourceScriptDirectory.getChild} returns a resource which renders an
response with the HTTP I{NOT FOUND} status code if the indicated child
does not exist as an entry in the directory used to initialized the
L{ResourceScriptDirectory}.
"""
path = self.mktemp()
os.makedirs(path)
resource = ResourceScriptDirectory(path)
request = DummyRequest([b"foo"])
child = resource.getChild("foo", request)
d = _render(child, request)
def cbRendered(ignored: object) -> None:
self.assertEqual(request.responseCode, NOT_FOUND)
return d.addCallback(cbRendered)
def test_render(self) -> defer.Deferred[None]:
"""
L{ResourceScriptDirectory.getChild} returns a resource which renders a
response with the HTTP 200 status code and the content of the rpy's
C{request} global.
"""
tmp = FilePath(self.mktemp())
tmp.makedirs()
tmp.child("test.rpy").setContent(
b"""
from twisted.web.resource import Resource
class TestResource(Resource):
isLeaf = True
def render_GET(self, request):
return b'ok'
resource = TestResource()"""
)
resource = ResourceScriptDirectory(tmp._asBytesPath())
request = DummyRequest([b""])
child = resource.getChild(b"test.rpy", request)
d = _render(child, request)
def cbRendered(ignored: object) -> None:
self.assertEqual(b"".join(request.written), b"ok")
return d.addCallback(cbRendered)
class PythonScriptTests(TestCase):
"""
Tests for L{PythonScript}.
"""
def test_notFoundRender(self) -> defer.Deferred[None]:
"""
If the source file a L{PythonScript} is initialized with doesn't exist,
L{PythonScript.render} sets the HTTP response code to I{NOT FOUND}.
"""
resource = PythonScript(self.mktemp(), None)
request = DummyRequest([b""])
d = _render(resource, request)
def cbRendered(ignored: object) -> None:
self.assertEqual(request.responseCode, NOT_FOUND)
return d.addCallback(cbRendered)
def test_renderException(self) -> defer.Deferred[None]:
"""
L{ResourceScriptDirectory.getChild} returns a resource which renders a
response with the HTTP 200 status code and the content of the rpy's
C{request} global.
"""
tmp = FilePath(self.mktemp())
tmp.makedirs()
child = tmp.child("test.epy")
child.setContent(b'raise Exception("nooo")')
resource = PythonScript(child._asBytesPath(), None)
request = DummyRequest([b""])
d = _render(resource, request)
def cbRendered(ignored: object) -> None:
self.assertIn(b"nooo", b"".join(request.written))
return d.addCallback(cbRendered)

View File

@@ -0,0 +1,197 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web._stan} portion of the L{twisted.web.template}
implementation.
"""
import sys
from typing import NoReturn
from twisted.trial.unittest import TestCase
from twisted.web.template import CDATA, CharRef, Comment, Flattenable, Tag
def proto(*a: Flattenable, **kw: Flattenable) -> Tag:
"""
Produce a new tag for testing.
"""
return Tag("hello")(*a, **kw)
class TagTests(TestCase):
"""
Tests for L{Tag}.
"""
def test_renderAttribute(self) -> None:
"""
Setting an attribute named C{render} will change the C{render} instance
variable instead of adding an attribute.
"""
tag = proto(render="myRenderer")
self.assertEqual(tag.render, "myRenderer")
self.assertEqual(tag.attributes, {})
def test_renderAttributeNonString(self) -> None:
"""
Attempting to set an attribute named C{render} to something other than
a string will raise L{TypeError}.
"""
with self.assertRaises(TypeError) as e:
proto(render=83) # type: ignore[arg-type]
self.assertEqual(
e.exception.args[0], 'Value for "render" attribute must be str, got 83'
)
def test_fillSlots(self) -> None:
"""
L{Tag.fillSlots} returns self.
"""
tag = proto()
self.assertIdentical(tag, tag.fillSlots(test="test"))
def test_cloneShallow(self) -> None:
"""
L{Tag.clone} copies all attributes and children of a tag, including its
render attribute. If the shallow flag is C{False}, that's where it
stops.
"""
innerList = ["inner list"]
tag = proto("How are you", innerList, hello="world", render="aSampleMethod")
tag.fillSlots(foo="bar")
tag.filename = "foo/bar"
tag.lineNumber = 6
tag.columnNumber = 12
clone = tag.clone(deep=False)
self.assertEqual(clone.attributes["hello"], "world")
self.assertNotIdentical(clone.attributes, tag.attributes)
self.assertEqual(clone.children, ["How are you", innerList])
self.assertNotIdentical(clone.children, tag.children)
self.assertIdentical(clone.children[1], innerList)
self.assertEqual(tag.slotData, clone.slotData)
self.assertNotIdentical(tag.slotData, clone.slotData)
self.assertEqual(clone.filename, "foo/bar")
self.assertEqual(clone.lineNumber, 6)
self.assertEqual(clone.columnNumber, 12)
self.assertEqual(clone.render, "aSampleMethod")
def test_cloneDeep(self) -> None:
"""
L{Tag.clone} copies all attributes and children of a tag, including its
render attribute. In its normal operating mode (where the deep flag is
C{True}, as is the default), it will clone all sub-lists and sub-tags.
"""
innerTag = proto("inner")
innerList = ["inner list"]
tag = proto(
"How are you", innerTag, innerList, hello="world", render="aSampleMethod"
)
tag.fillSlots(foo="bar")
tag.filename = "foo/bar"
tag.lineNumber = 6
tag.columnNumber = 12
clone = tag.clone()
self.assertEqual(clone.attributes["hello"], "world")
self.assertNotIdentical(clone.attributes, tag.attributes)
self.assertNotIdentical(clone.children, tag.children)
# sanity check
self.assertIdentical(tag.children[1], innerTag)
# clone should have sub-clone
self.assertNotIdentical(clone.children[1], innerTag)
# sanity check
self.assertIdentical(tag.children[2], innerList)
# clone should have sub-clone
self.assertNotIdentical(clone.children[2], innerList)
self.assertEqual(tag.slotData, clone.slotData)
self.assertNotIdentical(tag.slotData, clone.slotData)
self.assertEqual(clone.filename, "foo/bar")
self.assertEqual(clone.lineNumber, 6)
self.assertEqual(clone.columnNumber, 12)
self.assertEqual(clone.render, "aSampleMethod")
def test_cloneGeneratorDeprecation(self) -> None:
"""
Cloning a tag containing a generator is unsafe. To avoid breaking
programs that only flatten the clone or only flatten the original,
we deprecate old behavior rather than making it an error immediately.
"""
tag = proto(str(n) for n in range(10))
self.assertWarns(
DeprecationWarning,
"Cloning a Tag which contains a generator is unsafe, "
"since the generator can be consumed only once; "
"this is deprecated since Twisted 21.7.0 and will raise "
"an exception in the future",
sys.modules[Tag.__module__].__file__,
tag.clone,
)
def test_cloneCoroutineDeprecation(self) -> None:
"""
Cloning a tag containing a coroutine is unsafe. To avoid breaking
programs that only flatten the clone or only flatten the original,
we deprecate old behavior rather than making it an error immediately.
"""
async def asyncFunc() -> NoReturn:
raise NotImplementedError
coro = asyncFunc()
tag = proto("123", coro, "789")
try:
self.assertWarns(
DeprecationWarning,
"Cloning a Tag which contains a coroutine is unsafe, "
"since the coroutine can run only once; "
"this is deprecated since Twisted 21.7.0 and will raise "
"an exception in the future",
sys.modules[Tag.__module__].__file__,
tag.clone,
)
finally:
coro.close()
def test_clear(self) -> None:
"""
L{Tag.clear} removes all children from a tag, but leaves its attributes
in place.
"""
tag = proto("these are", "children", "cool", andSoIs="this-attribute")
tag.clear()
self.assertEqual(tag.children, [])
self.assertEqual(tag.attributes, {"andSoIs": "this-attribute"})
def test_suffix(self) -> None:
"""
L{Tag.__call__} accepts Python keywords with a suffixed underscore as
the DOM attribute of that literal suffix.
"""
proto = Tag("div")
tag = proto()
tag(class_="a")
self.assertEqual(tag.attributes, {"class": "a"})
def test_commentReprPy3(self) -> None:
"""
L{Comment.__repr__} returns a value which makes it easy to see what's
in the comment.
"""
self.assertEqual(repr(Comment("hello there")), "Comment('hello there')")
def test_cdataReprPy3(self) -> None:
"""
L{CDATA.__repr__} returns a value which makes it easy to see what's in
the comment.
"""
self.assertEqual(repr(CDATA("test data")), "CDATA('test data')")
def test_charrefRepr(self) -> None:
"""
L{CharRef.__repr__} returns a value which makes it easy to see what
character is referred to.
"""
snowman = ord("\N{SNOWMAN}")
self.assertEqual(repr(CharRef(snowman)), "CharRef(9731)")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,319 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.tap}.
"""
from __future__ import annotations
import os
import stat
from typing import cast
from unittest import skipIf
from twisted.internet import endpoints, reactor
from twisted.internet.interfaces import IReactorCore, IReactorUNIX
from twisted.python.filepath import FilePath
from twisted.python.reflect import requireModule
from twisted.python.threadpool import ThreadPool
from twisted.python.usage import UsageError
from twisted.spread.pb import PBServerFactory
from twisted.trial.unittest import TestCase
from twisted.web import demo
from twisted.web.distrib import ResourcePublisher, UserDirectory
from twisted.web.script import PythonScript
from twisted.web.server import Site
from twisted.web.static import Data, File
from twisted.web.tap import (
Options,
_AddHeadersResource,
makePersonalServerFactory,
makeService,
)
from twisted.web.test.requesthelper import DummyRequest
from twisted.web.twcgi import CGIScript
from twisted.web.wsgi import WSGIResource
application = object()
class ServiceTests(TestCase):
"""
Tests for the service creation APIs in L{twisted.web.tap}.
"""
def _pathOption(self) -> tuple[FilePath[str], File]:
"""
Helper for the I{--path} tests which creates a directory and creates
an L{Options} object which uses that directory as its static
filesystem root.
@return: A two-tuple of a L{FilePath} referring to the directory and
the value associated with the C{'root'} key in the L{Options}
instance after parsing a I{--path} option.
"""
path = FilePath(self.mktemp())
path.makedirs()
options = Options()
options.parseOptions(["--path", path.path])
root = options["root"]
return path, root
def test_path(self) -> None:
"""
The I{--path} option causes L{Options} to create a root resource
which serves responses from the specified path.
"""
path, root = self._pathOption()
self.assertIsInstance(root, File)
self.assertEqual(root.path, path.path)
@skipIf(
not IReactorUNIX.providedBy(reactor),
"The reactor does not support UNIX domain sockets",
)
def test_pathServer(self) -> None:
"""
The I{--path} option to L{makeService} causes it to return a service
which will listen on the server address given by the I{--port} option.
"""
path = FilePath(self.mktemp())
path.makedirs()
port = self.mktemp()
options = Options()
options.parseOptions(["--port", "unix:" + port, "--path", path.path])
service = makeService(options)
service.startService()
self.addCleanup(service.stopService)
self.assertIsInstance(service.services[0].factory.resource, File)
self.assertEqual(service.services[0].factory.resource.path, path.path)
self.assertTrue(os.path.exists(port))
self.assertTrue(stat.S_ISSOCK(os.stat(port).st_mode))
def test_cgiProcessor(self) -> None:
"""
The I{--path} option creates a root resource which serves a
L{CGIScript} instance for any child with the C{".cgi"} extension.
"""
path, root = self._pathOption()
path.child("foo.cgi").setContent(b"")
self.assertIsInstance(root.getChild("foo.cgi", None), CGIScript)
def test_epyProcessor(self) -> None:
"""
The I{--path} option creates a root resource which serves a
L{PythonScript} instance for any child with the C{".epy"} extension.
"""
path, root = self._pathOption()
path.child("foo.epy").setContent(b"")
self.assertIsInstance(root.getChild("foo.epy", None), PythonScript)
def test_rpyProcessor(self) -> None:
"""
The I{--path} option creates a root resource which serves the
C{resource} global defined by the Python source in any child with
the C{".rpy"} extension.
"""
path, root = self._pathOption()
path.child("foo.rpy").setContent(
b"from twisted.web.static import Data\n"
b"resource = Data('content', 'major/minor')\n"
)
child = root.getChild("foo.rpy", None)
self.assertIsInstance(child, Data)
self.assertEqual(child.data, "content")
self.assertEqual(child.type, "major/minor")
def test_makePersonalServerFactory(self) -> None:
"""
L{makePersonalServerFactory} returns a PB server factory which has
as its root object a L{ResourcePublisher}.
"""
# The fact that this pile of objects can actually be used somehow is
# verified by twisted.web.test.test_distrib.
site = Site(Data(b"foo bar", "text/plain"))
serverFactory = makePersonalServerFactory(site)
self.assertIsInstance(serverFactory, PBServerFactory)
self.assertIsInstance(serverFactory.root, ResourcePublisher)
self.assertIdentical(serverFactory.root.site, site)
@skipIf(
not IReactorUNIX.providedBy(reactor),
"The reactor does not support UNIX domain sockets",
)
def test_personalServer(self) -> None:
"""
The I{--personal} option to L{makeService} causes it to return a
service which will listen on the server address given by the I{--port}
option.
"""
port = self.mktemp()
options = Options()
options.parseOptions(["--port", "unix:" + port, "--personal"])
service = makeService(options)
service.startService()
self.addCleanup(service.stopService)
self.assertTrue(os.path.exists(port))
self.assertTrue(stat.S_ISSOCK(os.stat(port).st_mode))
@skipIf(
not IReactorUNIX.providedBy(reactor),
"The reactor does not support UNIX domain sockets",
)
def test_defaultPersonalPath(self) -> None:
"""
If the I{--port} option not specified but the I{--personal} option is,
L{Options} defaults the port to C{UserDirectory.userSocketName} in the
user's home directory.
"""
options = Options()
options.parseOptions(["--personal"])
path = os.path.expanduser(os.path.join("~", UserDirectory.userSocketName))
self.assertEqual(options["ports"][0], f"unix:{path}")
def test_defaultPort(self) -> None:
"""
If the I{--port} option is not specified, L{Options} defaults the port
to C{8080}.
"""
options = Options()
options.parseOptions([])
self.assertEqual(
endpoints._parseServer(options["ports"][0], None)[:2], ("TCP", (8080, None))
)
def test_twoPorts(self) -> None:
"""
If the I{--http} option is given twice, there are two listeners
"""
options = Options()
options.parseOptions(["--listen", "tcp:8001", "--listen", "tcp:8002"])
self.assertIn("8001", options["ports"][0])
self.assertIn("8002", options["ports"][1])
def test_wsgi(self) -> None:
"""
The I{--wsgi} option takes the fully-qualifed Python name of a WSGI
application object and creates a L{WSGIResource} at the root which
serves that application.
"""
options = Options()
options.parseOptions(["--wsgi", __name__ + ".application"])
root = options["root"]
self.assertTrue(root, WSGIResource)
self.assertIdentical(root._reactor, reactor)
self.assertTrue(isinstance(root._threadpool, ThreadPool))
self.assertIdentical(root._application, application)
# The threadpool should start and stop with the reactor.
self.assertFalse(root._threadpool.started)
cast(IReactorCore, reactor).fireSystemEvent("startup")
self.assertTrue(root._threadpool.started)
self.assertFalse(root._threadpool.joined)
cast(IReactorCore, reactor).fireSystemEvent("shutdown")
self.assertTrue(root._threadpool.joined)
def test_invalidApplication(self) -> None:
"""
If I{--wsgi} is given an invalid name, L{Options.parseOptions}
raises L{UsageError}.
"""
options = Options()
for name in [__name__ + ".nosuchthing", "foo."]:
exc = self.assertRaises(UsageError, options.parseOptions, ["--wsgi", name])
self.assertEqual(str(exc), f"No such WSGI application: {name!r}")
@skipIf(requireModule("OpenSSL.SSL") is not None, "SSL module is available.")
def test_HTTPSFailureOnMissingSSL(self) -> None:
"""
An L{UsageError} is raised when C{https} is requested but there is no
support for SSL.
"""
options = Options()
exception = self.assertRaises(UsageError, options.parseOptions, ["--https=443"])
self.assertEqual("SSL support not installed", exception.args[0])
@skipIf(requireModule("OpenSSL.SSL") is None, "SSL module is not available.")
def test_HTTPSAcceptedOnAvailableSSL(self) -> None:
"""
When SSL support is present, it accepts the --https option.
"""
options = Options()
options.parseOptions(["--https=443"])
self.assertIn("ssl", options["ports"][0])
self.assertIn("443", options["ports"][0])
def test_add_header_parsing(self) -> None:
"""
When --add-header is specific, the value is parsed.
"""
options = Options()
options.parseOptions(["--add-header", "K1: V1", "--add-header", "K2: V2"])
self.assertEqual(options["extraHeaders"], [("K1", "V1"), ("K2", "V2")])
def test_add_header_resource(self) -> None:
"""
When --add-header is specified, the resource is a composition that adds
headers.
"""
options = Options()
options.parseOptions(["--add-header", "K1: V1", "--add-header", "K2: V2"])
service = makeService(options)
resource = service.services[0].factory.resource
self.assertIsInstance(resource, _AddHeadersResource)
self.assertEqual(resource._headers, [("K1", "V1"), ("K2", "V2")])
self.assertIsInstance(resource._originalResource, demo.Test)
def test_noTracebacksDeprecation(self) -> None:
"""
Passing --notracebacks is deprecated.
"""
options = Options()
options.parseOptions(["--notracebacks"])
makeService(options)
warnings = self.flushWarnings([self.test_noTracebacksDeprecation])
self.assertEqual(warnings[0]["category"], DeprecationWarning)
self.assertEqual(
warnings[0]["message"], "--notracebacks was deprecated in Twisted 19.7.0"
)
self.assertEqual(len(warnings), 1)
def test_displayTracebacks(self) -> None:
"""
Passing --display-tracebacks will enable traceback rendering on the
generated Site.
"""
options = Options()
options.parseOptions(["--display-tracebacks"])
service = makeService(options)
self.assertTrue(service.services[0].factory.displayTracebacks)
def test_displayTracebacksNotGiven(self) -> None:
"""
Not passing --display-tracebacks will leave traceback rendering on the
generated Site off.
"""
options = Options()
options.parseOptions([])
service = makeService(options)
self.assertFalse(service.services[0].factory.displayTracebacks)
class AddHeadersResourceTests(TestCase):
def test_getChildWithDefault(self) -> None:
"""
When getChildWithDefault is invoked, it adds the headers to the
response.
"""
resource = _AddHeadersResource(
demo.Test(), [("K1", "V1"), ("K2", "V2"), ("K1", "V3")]
)
request = DummyRequest([])
resource.getChildWithDefault("", request)
self.assertEqual(request.responseHeaders.getRawHeaders("K1"), ["V1", "V3"])
self.assertEqual(request.responseHeaders.getRawHeaders("K2"), ["V2"])

View File

@@ -0,0 +1,915 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.template}
"""
import sys
from io import StringIO
from typing import List, Optional
from zope.interface import implementer
from zope.interface.verify import verifyObject
from twisted.internet.defer import Deferred, succeed
from twisted.internet.testing import EventLoggingObserver
from twisted.logger import globalLogPublisher
from twisted.python.failure import Failure
from twisted.python.filepath import FilePath
from twisted.trial.unittest import TestCase
from twisted.trial.util import suppress as SUPPRESS
from twisted.web._element import UnexposedMethodError
from twisted.web.error import FlattenerError, MissingRenderMethod, MissingTemplateLoader
from twisted.web.iweb import IRequest, ITemplateLoader
from twisted.web.server import NOT_DONE_YET
from twisted.web.template import (
Element,
Flattenable,
Tag,
TagLoader,
XMLFile,
XMLString,
renderElement,
renderer,
tags,
)
from twisted.web.test._util import FlattenTestCase
from twisted.web.test.test_web import DummyRequest
_xmlFileSuppress = SUPPRESS(
category=DeprecationWarning,
message="Passing filenames or file objects to XMLFile is "
"deprecated since Twisted 12.1. Pass a FilePath instead.",
)
class TagFactoryTests(TestCase):
"""
Tests for L{_TagFactory} through the publicly-exposed L{tags} object.
"""
def test_lookupTag(self) -> None:
"""
HTML tags can be retrieved through C{tags}.
"""
tag = tags.a
self.assertEqual(tag.tagName, "a")
def test_lookupHTML5Tag(self) -> None:
"""
Twisted supports the latest and greatest HTML tags from the HTML5
specification.
"""
tag = tags.video
self.assertEqual(tag.tagName, "video")
def test_lookupTransparentTag(self) -> None:
"""
To support transparent inclusion in templates, there is a special tag,
the transparent tag, which has no name of its own but is accessed
through the "transparent" attribute.
"""
tag = tags.transparent
self.assertEqual(tag.tagName, "")
def test_lookupInvalidTag(self) -> None:
"""
Invalid tags which are not part of HTML cause AttributeErrors when
accessed through C{tags}.
"""
self.assertRaises(AttributeError, getattr, tags, "invalid")
def test_lookupXMP(self) -> None:
"""
As a special case, the <xmp> tag is simply not available through
C{tags} or any other part of the templating machinery.
"""
self.assertRaises(AttributeError, getattr, tags, "xmp")
class ElementTests(TestCase):
"""
Tests for the awesome new L{Element} class.
"""
def test_missingTemplateLoader(self) -> None:
"""
L{Element.render} raises L{MissingTemplateLoader} if the C{loader}
attribute is L{None}.
"""
element = Element()
err = self.assertRaises(MissingTemplateLoader, element.render, None)
self.assertIdentical(err.element, element)
def test_missingTemplateLoaderRepr(self) -> None:
"""
A L{MissingTemplateLoader} instance can be repr()'d without error.
"""
class PrettyReprElement(Element):
def __repr__(self) -> str:
return "Pretty Repr Element"
self.assertIn(
"Pretty Repr Element", repr(MissingTemplateLoader(PrettyReprElement()))
)
def test_missingRendererMethod(self) -> None:
"""
When called with the name which is not associated with a render method,
L{Element.lookupRenderMethod} raises L{MissingRenderMethod}.
"""
element = Element()
err = self.assertRaises(MissingRenderMethod, element.lookupRenderMethod, "foo")
self.assertIdentical(err.element, element)
self.assertEqual(err.renderName, "foo")
def test_missingRenderMethodRepr(self) -> None:
"""
A L{MissingRenderMethod} instance can be repr()'d without error.
"""
class PrettyReprElement(Element):
def __repr__(self) -> str:
return "Pretty Repr Element"
s = repr(MissingRenderMethod(PrettyReprElement(), "expectedMethod"))
self.assertIn("Pretty Repr Element", s)
self.assertIn("expectedMethod", s)
def test_definedRenderer(self) -> None:
"""
When called with the name of a defined render method,
L{Element.lookupRenderMethod} returns that render method.
"""
class ElementWithRenderMethod(Element):
@renderer
def foo(self, request: Optional[IRequest], tag: Tag) -> Flattenable:
return "bar"
foo = ElementWithRenderMethod().lookupRenderMethod("foo")
self.assertEqual(foo(None, tags.br), "bar")
def test_render(self) -> None:
"""
L{Element.render} loads a document from the C{loader} attribute and
returns it.
"""
@implementer(ITemplateLoader)
class TemplateLoader:
def load(self) -> List[Flattenable]:
return ["result"]
class StubElement(Element):
loader = TemplateLoader()
element = StubElement()
self.assertEqual(element.render(None), ["result"])
def test_misuseRenderer(self) -> None:
"""
If the L{renderer} decorator is called without any arguments, it will
raise a comprehensible exception.
"""
te = self.assertRaises(TypeError, renderer)
if sys.version_info >= (3, 10):
self.assertEqual(
str(te), "Expose.__call__() missing 1 required positional argument: 'f'"
)
else:
self.assertEqual(
str(te), "__call__() missing 1 required positional argument: 'f'"
)
def test_renderGetDirectlyError(self) -> None:
"""
Called directly, without a default, L{renderer.get} raises
L{UnexposedMethodError} when it cannot find a renderer.
"""
self.assertRaises(UnexposedMethodError, renderer.get, None, "notARenderer")
class XMLFileReprTests(TestCase):
"""
Tests for L{twisted.web.template.XMLFile}'s C{__repr__}.
"""
def test_filePath(self) -> None:
"""
An L{XMLFile} with a L{FilePath} returns a useful repr().
"""
path = FilePath("/tmp/fake.xml")
self.assertEqual(f"<XMLFile of {path!r}>", repr(XMLFile(path)))
def test_filename(self) -> None:
"""
An L{XMLFile} with a filename returns a useful repr().
"""
fname = "/tmp/fake.xml" # deprecated
self.assertEqual(f"<XMLFile of {fname!r}>", repr(XMLFile(fname))) # type: ignore[arg-type]
test_filename.suppress = [_xmlFileSuppress] # type: ignore[attr-defined]
def test_file(self) -> None:
"""
An L{XMLFile} with a file object returns a useful repr().
"""
fobj = StringIO("not xml") # deprecated
self.assertEqual(f"<XMLFile of {fobj!r}>", repr(XMLFile(fobj))) # type: ignore[arg-type]
test_file.suppress = [_xmlFileSuppress] # type: ignore[attr-defined]
class XMLLoaderTestsMixin:
deprecatedUse: bool
"""
C{True} if this use of L{XMLFile} is deprecated and should emit
a C{DeprecationWarning}.
"""
templateString = "<p>Hello, world.</p>"
"""
Simple template to use to exercise the loaders.
"""
def loaderFactory(self) -> ITemplateLoader:
raise NotImplementedError
def test_load(self) -> None:
"""
Verify that the loader returns a tag with the correct children.
"""
assert isinstance(self, TestCase)
loader = self.loaderFactory()
(tag,) = loader.load()
assert isinstance(tag, Tag)
warnings = self.flushWarnings(offendingFunctions=[self.loaderFactory])
if self.deprecatedUse:
self.assertEqual(len(warnings), 1)
self.assertEqual(warnings[0]["category"], DeprecationWarning)
self.assertEqual(
warnings[0]["message"],
"Passing filenames or file objects to XMLFile is "
"deprecated since Twisted 12.1. Pass a FilePath instead.",
)
else:
self.assertEqual(len(warnings), 0)
self.assertEqual(tag.tagName, "p")
self.assertEqual(tag.children, ["Hello, world."])
def test_loadTwice(self) -> None:
"""
If {load()} can be called on a loader twice the result should be the
same.
"""
assert isinstance(self, TestCase)
loader = self.loaderFactory()
tags1 = loader.load()
tags2 = loader.load()
self.assertEqual(tags1, tags2)
test_loadTwice.suppress = [_xmlFileSuppress] # type: ignore[attr-defined]
class XMLStringLoaderTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLString}
"""
deprecatedUse = False
def loaderFactory(self) -> ITemplateLoader:
"""
@return: an L{XMLString} constructed with C{self.templateString}.
"""
return XMLString(self.templateString)
class XMLFileWithFilePathTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s L{FilePath} support.
"""
deprecatedUse = False
def loaderFactory(self) -> ITemplateLoader:
"""
@return: an L{XMLString} constructed with a L{FilePath} pointing to a
file that contains C{self.templateString}.
"""
fp = FilePath(self.mktemp())
fp.setContent(self.templateString.encode("utf8"))
return XMLFile(fp)
class XMLFileWithFileTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s deprecated file object support.
"""
deprecatedUse = True
def loaderFactory(self) -> ITemplateLoader:
"""
@return: an L{XMLString} constructed with a file object that contains
C{self.templateString}.
"""
return XMLFile(StringIO(self.templateString)) # type: ignore[arg-type]
class XMLFileWithFilenameTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s deprecated filename support.
"""
deprecatedUse = True
def loaderFactory(self) -> ITemplateLoader:
"""
@return: an L{XMLString} constructed with a filename that points to a
file containing C{self.templateString}.
"""
fp = FilePath(self.mktemp())
fp.setContent(self.templateString.encode("utf8"))
return XMLFile(fp.path)
class FlattenIntegrationTests(FlattenTestCase):
"""
Tests for integration between L{Element} and
L{twisted.web._flatten.flatten}.
"""
def test_roundTrip(self) -> None:
"""
Given a series of parsable XML strings, verify that
L{twisted.web._flatten.flatten} will flatten the L{Element} back to the
input when sent on a round trip.
"""
fragments = [
b"<p>Hello, world.</p>",
b"<p><!-- hello, world --></p>",
b"<p><![CDATA[Hello, world.]]></p>",
b'<test1 xmlns:test2="urn:test2">' b"<test2:test3></test2:test3></test1>",
b'<test1 xmlns="urn:test2"><test3></test3></test1>',
b"<p>\xe2\x98\x83</p>",
]
for xml in fragments:
self.assertFlattensImmediately(Element(loader=XMLString(xml)), xml)
def test_entityConversion(self) -> None:
"""
When flattening an HTML entity, it should flatten out to the utf-8
representation if possible.
"""
element = Element(loader=XMLString("<p>&#9731;</p>"))
self.assertFlattensImmediately(element, b"<p>\xe2\x98\x83</p>")
def test_missingTemplateLoader(self) -> None:
"""
Rendering an Element without a loader attribute raises the appropriate
exception.
"""
self.assertFlatteningRaises(Element(), MissingTemplateLoader)
def test_missingRenderMethod(self) -> None:
"""
Flattening an L{Element} with a C{loader} which has a tag with a render
directive fails with L{FlattenerError} if there is no available render
method to satisfy that directive.
"""
element = Element(
loader=XMLString(
"""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="unknownMethod" />
"""
)
)
self.assertFlatteningRaises(element, MissingRenderMethod)
def test_transparentRendering(self) -> None:
"""
A C{transparent} element should be eliminated from the DOM and rendered as
only its children.
"""
element = Element(
loader=XMLString(
"<t:transparent "
'xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
"Hello, world."
"</t:transparent>"
)
)
self.assertFlattensImmediately(element, b"Hello, world.")
def test_attrRendering(self) -> None:
"""
An Element with an attr tag renders the vaule of its attr tag as an
attribute of its containing tag.
"""
element = Element(
loader=XMLString(
'<a xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'<t:attr name="href">http://example.com</t:attr>'
"Hello, world."
"</a>"
)
)
self.assertFlattensImmediately(
element, b'<a href="http://example.com">Hello, world.</a>'
)
def test_synchronousDeferredRecursion(self) -> None:
"""
When rendering a large number of already-fired Deferreds we should not
encounter any recursion errors or stack-depth issues.
"""
self.assertFlattensImmediately([succeed("x") for i in range(250)], b"x" * 250)
def test_errorToplevelAttr(self) -> None:
"""
A template with a toplevel C{attr} tag will not load; it will raise
L{AssertionError} if you try.
"""
self.assertRaises(
AssertionError,
XMLString,
"""<t:attr
xmlns:t='http://twistedmatrix.com/ns/twisted.web.template/0.1'
name='something'
>hello</t:attr>
""",
)
def test_errorUnnamedAttr(self) -> None:
"""
A template with an C{attr} tag with no C{name} attribute will not load;
it will raise L{AssertionError} if you try.
"""
self.assertRaises(
AssertionError,
XMLString,
"""<html><t:attr
xmlns:t='http://twistedmatrix.com/ns/twisted.web.template/0.1'
>hello</t:attr></html>""",
)
def test_lenientPrefixBehavior(self) -> None:
"""
If the parser sees a prefix it doesn't recognize on an attribute, it
will pass it on through to serialization.
"""
theInput = (
'<hello:world hello:sample="testing" '
'xmlns:hello="http://made-up.example.com/ns/not-real">'
"This is a made-up tag.</hello:world>"
)
element = Element(loader=XMLString(theInput))
self.assertFlattensTo(element, theInput.encode("utf8"))
def test_deferredRendering(self) -> None:
"""
An Element with a render method which returns a Deferred will render
correctly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(
self, request: Optional[IRequest], tag: Tag
) -> Flattenable:
return succeed("Hello, world.")
element = RenderfulElement(
loader=XMLString(
"""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod">
Goodbye, world.
</p>
"""
)
)
self.assertFlattensImmediately(element, b"Hello, world.")
def test_loaderClassAttribute(self) -> None:
"""
If there is a non-None loader attribute on the class of an Element
instance but none on the instance itself, the class attribute is used.
"""
class SubElement(Element):
loader = XMLString("<p>Hello, world.</p>")
self.assertFlattensImmediately(SubElement(), b"<p>Hello, world.</p>")
def test_directiveRendering(self) -> None:
"""
An Element with a valid render directive has that directive invoked and
the result added to the output.
"""
renders = []
class RenderfulElement(Element):
@renderer
def renderMethod(
self, request: Optional[IRequest], tag: Tag
) -> Flattenable:
renders.append((self, request))
return tag("Hello, world.")
element = RenderfulElement(
loader=XMLString(
"""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod" />
"""
)
)
self.assertFlattensImmediately(element, b"<p>Hello, world.</p>")
def test_directiveRenderingOmittingTag(self) -> None:
"""
An Element with a render method which omits the containing tag
successfully removes that tag from the output.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(
self, request: Optional[IRequest], tag: Tag
) -> Flattenable:
return "Hello, world."
element = RenderfulElement(
loader=XMLString(
"""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod">
Goodbye, world.
</p>
"""
)
)
self.assertFlattensImmediately(element, b"Hello, world.")
def test_elementContainingStaticElement(self) -> None:
"""
An Element which is returned by the render method of another Element is
rendered properly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(
self, request: Optional[IRequest], tag: Tag
) -> Flattenable:
return tag(Element(loader=XMLString("<em>Hello, world.</em>")))
element = RenderfulElement(
loader=XMLString(
"""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod" />
"""
)
)
self.assertFlattensImmediately(element, b"<p><em>Hello, world.</em></p>")
def test_elementUsingSlots(self) -> None:
"""
An Element which is returned by the render method of another Element is
rendered properly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(
self, request: Optional[IRequest], tag: Tag
) -> Flattenable:
return tag.fillSlots(test2="world.")
element = RenderfulElement(
loader=XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"'
' t:render="renderMethod">'
'<t:slot name="test1" default="Hello, " />'
'<t:slot name="test2" />'
"</p>"
)
)
self.assertFlattensImmediately(element, b"<p>Hello, world.</p>")
def test_elementContainingDynamicElement(self) -> None:
"""
Directives in the document factory of an Element returned from a render
method of another Element are satisfied from the correct object: the
"inner" Element.
"""
class OuterElement(Element):
@renderer
def outerMethod(self, request: Optional[IRequest], tag: Tag) -> Flattenable:
return tag(
InnerElement(
loader=XMLString(
"""
<t:ignored
xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="innerMethod" />
"""
)
)
)
class InnerElement(Element):
@renderer
def innerMethod(self, request: Optional[IRequest], tag: Tag) -> Flattenable:
return "Hello, world."
element = OuterElement(
loader=XMLString(
"""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="outerMethod" />
"""
)
)
self.assertFlattensImmediately(element, b"<p>Hello, world.</p>")
def test_sameLoaderTwice(self) -> None:
"""
Rendering the output of a loader, or even the same element, should
return different output each time.
"""
sharedLoader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'<t:transparent t:render="classCounter" /> '
'<t:transparent t:render="instanceCounter" />'
"</p>"
)
class DestructiveElement(Element):
count = 0
instanceCount = 0
loader = sharedLoader
@renderer
def classCounter(
self, request: Optional[IRequest], tag: Tag
) -> Flattenable:
DestructiveElement.count += 1
return tag(str(DestructiveElement.count))
@renderer
def instanceCounter(
self, request: Optional[IRequest], tag: Tag
) -> Flattenable:
self.instanceCount += 1
return tag(str(self.instanceCount))
e1 = DestructiveElement()
e2 = DestructiveElement()
self.assertFlattensImmediately(e1, b"<p>1 1</p>")
self.assertFlattensImmediately(e1, b"<p>2 2</p>")
self.assertFlattensImmediately(e2, b"<p>3 1</p>")
class TagLoaderTests(FlattenTestCase):
"""
Tests for L{TagLoader}.
"""
def setUp(self) -> None:
self.loader = TagLoader(tags.i("test"))
def test_interface(self) -> None:
"""
An instance of L{TagLoader} provides L{ITemplateLoader}.
"""
self.assertTrue(verifyObject(ITemplateLoader, self.loader))
def test_loadsList(self) -> None:
"""
L{TagLoader.load} returns a list, per L{ITemplateLoader}.
"""
self.assertIsInstance(self.loader.load(), list)
def test_flatten(self) -> None:
"""
L{TagLoader} can be used in an L{Element}, and flattens as the tag used
to construct the L{TagLoader} would flatten.
"""
e = Element(self.loader)
self.assertFlattensImmediately(e, b"<i>test</i>")
class TestElement(Element):
"""
An L{Element} that can be rendered successfully.
"""
loader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
"Hello, world."
"</p>"
)
class TestFailureElement(Element):
"""
An L{Element} that can be used in place of L{FailureElement} to verify
that L{renderElement} can render failures properly.
"""
loader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
"I failed."
"</p>"
)
def __init__(self, failure: Failure, loader: object = None) -> None:
self.failure = failure
class FailingElement(Element):
"""
An element that raises an exception when rendered.
"""
def render(self, request: Optional[IRequest]) -> "Flattenable":
a = 42
b = 0
return f"{a // b}"
class FakeSite:
"""
A minimal L{Site} object that we can use to test displayTracebacks
"""
displayTracebacks = False
@implementer(IRequest)
class DummyRenderRequest(DummyRequest): # type: ignore[misc]
"""
A dummy request object that has a C{site} attribute.
This does not implement the full IRequest interface, but enough of it
for this test suite.
"""
def __init__(self) -> None:
super().__init__([b""])
self.site = FakeSite()
class RenderElementTests(TestCase):
"""
Test L{renderElement}
"""
def setUp(self) -> None:
"""
Set up a common L{DummyRenderRequest}.
"""
self.request = DummyRenderRequest()
def test_simpleRender(self) -> Deferred[None]:
"""
L{renderElement} returns NOT_DONE_YET and eventually
writes the rendered L{Element} to the request before finishing the
request.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_: object) -> None:
self.assertEqual(
b"".join(self.request.written),
b"<!DOCTYPE html>\n" b"<p>Hello, world.</p>",
)
self.assertTrue(self.request.finished)
d.addCallback(check)
self.assertIdentical(NOT_DONE_YET, renderElement(self.request, element))
return d
def test_simpleFailure(self) -> Deferred[None]:
"""
L{renderElement} handles failures by writing a minimal
error message to the request and finishing it.
"""
element = FailingElement()
d = self.request.notifyFinish()
def check(_: object) -> None:
flushed = self.flushLoggedErrors(FlattenerError)
self.assertEqual(len(flushed), 1)
self.assertEqual(
b"".join(self.request.written),
(
b"<!DOCTYPE html>\n"
b'<div style="font-size:800%;'
b"background-color:#FFF;"
b"color:#F00"
b'">An error occurred while rendering the response.</div>'
),
)
self.assertTrue(self.request.finished)
d.addCallback(check)
self.assertIdentical(NOT_DONE_YET, renderElement(self.request, element))
return d
def test_simpleFailureWithTraceback(self) -> Deferred[None]:
"""
L{renderElement} will render a traceback when rendering of
the element fails and our site is configured to display tracebacks.
"""
logObserver = EventLoggingObserver.createWithCleanup(self, globalLogPublisher)
self.request.site.displayTracebacks = True
element = FailingElement()
d = self.request.notifyFinish()
def check(_: object) -> None:
self.assertEquals(1, len(logObserver))
f = logObserver[0]["log_failure"]
self.assertIsInstance(f.value, FlattenerError)
flushed = self.flushLoggedErrors(FlattenerError)
self.assertEqual(len(flushed), 1)
self.assertEqual(
b"".join(self.request.written), b"<!DOCTYPE html>\n<p>I failed.</p>"
)
self.assertTrue(self.request.finished)
d.addCallback(check)
renderElement(self.request, element, _failElement=TestFailureElement)
return d
def test_nonDefaultDoctype(self) -> Deferred[None]:
"""
L{renderElement} will write the doctype string specified by the
doctype keyword argument.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_: object) -> None:
self.assertEqual(
b"".join(self.request.written),
(
b'<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"'
b' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">\n'
b"<p>Hello, world.</p>"
),
)
d.addCallback(check)
renderElement(
self.request,
element,
doctype=(
b'<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"'
b' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">'
),
)
return d
def test_noneDoctype(self) -> Deferred[None]:
"""
L{renderElement} will not write out a doctype if the doctype keyword
argument is L{None}.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_: object) -> None:
self.assertEqual(b"".join(self.request.written), b"<p>Hello, world.</p>")
d.addCallback(check)
renderElement(self.request, element, doctype=None)
return d

View File

@@ -0,0 +1,433 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.util}.
"""
import gc
from twisted.internet import defer
from twisted.python.compat import networkString
from twisted.python.failure import Failure
from twisted.trial.unittest import SynchronousTestCase, TestCase
from twisted.web import resource, util
from twisted.web.error import FlattenerError
from twisted.web.http import FOUND
from twisted.web.server import Request
from twisted.web.template import TagLoader, flattenString, tags
from twisted.web.test.requesthelper import DummyChannel, DummyRequest
from twisted.web.util import (
DeferredResource,
FailureElement,
ParentRedirect,
_FrameElement,
_SourceFragmentElement,
_SourceLineElement,
_StackElement,
formatFailure,
redirectTo,
)
class RedirectToTests(TestCase):
"""
Tests for L{redirectTo}.
"""
def test_headersAndCode(self):
"""
L{redirectTo} will set the C{Location} and C{Content-Type} headers on
its request, and set the response code to C{FOUND}, so the browser will
be redirected.
"""
request = Request(DummyChannel(), True)
request.method = b"GET"
targetURL = b"http://target.example.com/4321"
redirectTo(targetURL, request)
self.assertEqual(request.code, FOUND)
self.assertEqual(
request.responseHeaders.getRawHeaders(b"location"), [targetURL]
)
self.assertEqual(
request.responseHeaders.getRawHeaders(b"content-type"),
[b"text/html; charset=utf-8"],
)
def test_redirectToUnicodeURL(self):
"""
L{redirectTo} will raise TypeError if unicode object is passed in URL
"""
request = Request(DummyChannel(), True)
request.method = b"GET"
targetURL = "http://target.example.com/4321"
self.assertRaises(TypeError, redirectTo, targetURL, request)
def test_legitimateRedirect(self):
"""
Legitimate URLs are fully interpolated in the `redirectTo` response body without transformation
"""
request = DummyRequest([b""])
html = redirectTo(b"https://twisted.org/", request)
expected = b"""
<html>
<head>
<meta http-equiv=\"refresh\" content=\"0;URL=https://twisted.org/\">
</head>
<body bgcolor=\"#FFFFFF\" text=\"#000000\">
<a href=\"https://twisted.org/\">click here</a>
</body>
</html>
"""
self.assertEqual(html, expected)
def test_maliciousRedirect(self):
"""
Malicious URLs are HTML-escaped before interpolating them in the `redirectTo` response body
"""
request = DummyRequest([b""])
html = redirectTo(
b'https://twisted.org/"><script>alert(document.location)</script>', request
)
expected = b"""
<html>
<head>
<meta http-equiv=\"refresh\" content=\"0;URL=https://twisted.org/&quot;&gt;&lt;script&gt;alert(document.location)&lt;/script&gt;\">
</head>
<body bgcolor=\"#FFFFFF\" text=\"#000000\">
<a href=\"https://twisted.org/&quot;&gt;&lt;script&gt;alert(document.location)&lt;/script&gt;\">click here</a>
</body>
</html>
"""
self.assertEqual(html, expected)
class ParentRedirectTests(SynchronousTestCase):
"""
Test L{ParentRedirect}.
"""
def doLocationTest(self, requestPath: bytes) -> bytes:
"""
Render a response to a request with path *requestPath*
@param requestPath: A slash-separated path like C{b'/foo/bar'}.
@returns: The value of the I{Location} header.
"""
request = Request(DummyChannel(), True)
request.method = b"GET"
request.prepath = requestPath.lstrip(b"/").split(b"/")
resource = ParentRedirect()
resource.render(request)
headers = request.responseHeaders.getRawHeaders(b"Location")
assert headers is not None
[location] = headers
return location
def test_locationRoot(self):
"""
At the URL root issue a redirect to the current URL, removing any query
string.
"""
self.assertEqual(b"http://10.0.0.1/", self.doLocationTest(b"/"))
self.assertEqual(b"http://10.0.0.1/", self.doLocationTest(b"/?biff=baff"))
def test_locationToRoot(self):
"""
A request for a resource one level down from the URL root produces
a redirect to the root.
"""
self.assertEqual(b"http://10.0.0.1/", self.doLocationTest(b"/foo"))
self.assertEqual(
b"http://10.0.0.1/", self.doLocationTest(b"/foo?bar=sproiiing")
)
def test_locationUpOne(self):
"""
Requests for resources directly under the path C{/foo/} produce
redirects to C{/foo/}.
"""
self.assertEqual(b"http://10.0.0.1/foo/", self.doLocationTest(b"/foo/"))
self.assertEqual(b"http://10.0.0.1/foo/", self.doLocationTest(b"/foo/bar"))
self.assertEqual(
b"http://10.0.0.1/foo/", self.doLocationTest(b"/foo/bar?biz=baz")
)
class FailureElementTests(TestCase):
"""
Tests for L{FailureElement} and related helpers which can render a
L{Failure} as an HTML string.
"""
def setUp(self):
"""
Create a L{Failure} which can be used by the rendering tests.
"""
def lineNumberProbeAlsoBroken():
message = "This is a problem"
raise Exception(message)
# Figure out the line number from which the exception will be raised.
self.base = lineNumberProbeAlsoBroken.__code__.co_firstlineno + 1
try:
lineNumberProbeAlsoBroken()
except BaseException:
self.failure = Failure(captureVars=True)
self.frame = self.failure.frames[-1]
def test_sourceLineElement(self):
"""
L{_SourceLineElement} renders a source line and line number.
"""
element = _SourceLineElement(
TagLoader(
tags.div(tags.span(render="lineNumber"), tags.span(render="sourceLine"))
),
50,
" print 'hello'",
)
d = flattenString(None, element)
expected = (
"<div><span>50</span><span>"
" \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}print 'hello'</span></div>"
)
d.addCallback(self.assertEqual, expected.encode("utf-8"))
return d
def test_sourceFragmentElement(self):
"""
L{_SourceFragmentElement} renders source lines at and around the line
number indicated by a frame object.
"""
element = _SourceFragmentElement(
TagLoader(
tags.div(
tags.span(render="lineNumber"),
tags.span(render="sourceLine"),
render="sourceLines",
)
),
self.frame,
)
source = [
" \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}message = " '"This is a problem"',
" \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}raise Exception(message)",
"",
]
d = flattenString(None, element)
stringToCheckFor = ""
for lineNumber, sourceLine in enumerate(source):
template = '<div class="snippet{}Line"><span>{}</span><span>{}</span></div>'
if lineNumber <= 1:
stringToCheckFor += template.format(
["", "Highlight"][lineNumber == 1],
self.base + lineNumber,
(" \N{NO-BREAK SPACE}" * 4 + sourceLine),
)
else:
stringToCheckFor += template.format(
"", self.base + lineNumber, ("" + sourceLine)
)
bytesToCheckFor = stringToCheckFor.encode("utf8")
d.addCallback(self.assertEqual, bytesToCheckFor)
return d
def test_frameElementFilename(self):
"""
The I{filename} renderer of L{_FrameElement} renders the filename
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(TagLoader(tags.span(render="filename")), self.frame)
d = flattenString(None, element)
d.addCallback(
# __file__ differs depending on whether an up-to-date .pyc file
# already existed.
self.assertEqual,
b"<span>" + networkString(__file__.rstrip("c")) + b"</span>",
)
return d
def test_frameElementLineNumber(self):
"""
The I{lineNumber} renderer of L{_FrameElement} renders the line number
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(TagLoader(tags.span(render="lineNumber")), self.frame)
d = flattenString(None, element)
d.addCallback(self.assertEqual, b"<span>%d</span>" % (self.base + 1,))
return d
def test_frameElementFunction(self):
"""
The I{function} renderer of L{_FrameElement} renders the line number
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(TagLoader(tags.span(render="function")), self.frame)
d = flattenString(None, element)
d.addCallback(self.assertEqual, b"<span>lineNumberProbeAlsoBroken</span>")
return d
def test_frameElementSource(self):
"""
The I{source} renderer of L{_FrameElement} renders the source code near
the source filename/line number associated with the frame object used to
initialize the L{_FrameElement}.
"""
element = _FrameElement(None, self.frame)
renderer = element.lookupRenderMethod("source")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, _SourceFragmentElement)
self.assertIdentical(result.frame, self.frame)
self.assertEqual([tag], result.loader.load())
def test_stackElement(self):
"""
The I{frames} renderer of L{_StackElement} renders each stack frame in
the list of frames used to initialize the L{_StackElement}.
"""
element = _StackElement(None, self.failure.frames[:2])
renderer = element.lookupRenderMethod("frames")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, list)
self.assertIsInstance(result[0], _FrameElement)
self.assertIdentical(result[0].frame, self.failure.frames[0])
self.assertIsInstance(result[1], _FrameElement)
self.assertIdentical(result[1].frame, self.failure.frames[1])
# They must not share the same tag object.
self.assertNotEqual(result[0].loader.load(), result[1].loader.load())
self.assertEqual(2, len(result))
def test_failureElementTraceback(self):
"""
The I{traceback} renderer of L{FailureElement} renders the failure's
stack frames using L{_StackElement}.
"""
element = FailureElement(self.failure)
renderer = element.lookupRenderMethod("traceback")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, _StackElement)
self.assertIdentical(result.stackFrames, self.failure.frames)
self.assertEqual([tag], result.loader.load())
def test_failureElementType(self):
"""
The I{type} renderer of L{FailureElement} renders the failure's
exception type.
"""
element = FailureElement(self.failure, TagLoader(tags.span(render="type")))
d = flattenString(None, element)
exc = b"builtins.Exception"
d.addCallback(self.assertEqual, b"<span>" + exc + b"</span>")
return d
def test_failureElementValue(self):
"""
The I{value} renderer of L{FailureElement} renders the value's exception
value.
"""
element = FailureElement(self.failure, TagLoader(tags.span(render="value")))
d = flattenString(None, element)
d.addCallback(self.assertEqual, b"<span>This is a problem</span>")
return d
class FormatFailureTests(TestCase):
"""
Tests for L{twisted.web.util.formatFailure} which returns an HTML string
representing the L{Failure} instance passed to it.
"""
def test_flattenerError(self):
"""
If there is an error flattening the L{Failure} instance,
L{formatFailure} raises L{FlattenerError}.
"""
self.assertRaises(FlattenerError, formatFailure, object())
def test_returnsBytes(self):
"""
The return value of L{formatFailure} is a C{str} instance (not a
C{unicode} instance) with numeric character references for any non-ASCII
characters meant to appear in the output.
"""
try:
raise Exception("Fake bug")
except BaseException:
result = formatFailure(Failure())
self.assertIsInstance(result, bytes)
self.assertTrue(all(ch < 128 for ch in result))
# Indentation happens to rely on NO-BREAK SPACE
self.assertIn(b"&#160;", result)
class SDResource(resource.Resource):
def __init__(self, default):
self.default = default
def getChildWithDefault(self, name, request):
d = defer.succeed(self.default)
resource = util.DeferredResource(d)
return resource.getChildWithDefault(name, request)
class DeferredResourceTests(SynchronousTestCase):
"""
Tests for L{DeferredResource}.
"""
def testDeferredResource(self):
r = resource.Resource()
r.isLeaf = 1
s = SDResource(r)
d = DummyRequest(["foo", "bar", "baz"])
resource.getChildForRequest(s, d)
self.assertEqual(d.postpath, ["bar", "baz"])
def test_render(self):
"""
L{DeferredResource} uses the request object's C{render} method to
render the resource which is the result of the L{Deferred} being
handled.
"""
rendered = []
request = DummyRequest([])
request.render = rendered.append
result = resource.Resource()
deferredResource = DeferredResource(defer.succeed(result))
deferredResource.render(request)
self.assertEqual(rendered, [result])
def test_renderNoFailure(self):
"""
If the L{Deferred} fails, L{DeferredResource} reports the failure via
C{processingFailed}, and does not cause an unhandled error to be
logged.
"""
request = DummyRequest([])
d = request.notifyFinish()
failure = Failure(RuntimeError())
deferredResource = DeferredResource(defer.fail(failure))
deferredResource.render(request)
self.assertEqual(self.failureResultOf(d), failure)
del deferredResource
gc.collect()
errors = self.flushLoggedErrors(RuntimeError)
self.assertEqual(errors, [])

View File

@@ -0,0 +1,213 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.vhost}.
"""
from twisted.internet.defer import gatherResults
from twisted.trial.unittest import TestCase
from twisted.web.http import NOT_FOUND
from twisted.web.resource import NoResource
from twisted.web.server import Site
from twisted.web.static import Data
from twisted.web.test._util import _render
from twisted.web.test.test_web import DummyRequest
from twisted.web.vhost import NameVirtualHost, VHostMonsterResource, _HostResource
class HostResourceTests(TestCase):
"""
Tests for L{_HostResource}.
"""
def test_getChild(self):
"""
L{_HostResource.getChild} returns the proper I{Resource} for the vhost
embedded in the URL. Verify that returning the proper I{Resource}
required changing the I{Host} in the header.
"""
bazroot = Data(b"root data", "")
bazuri = Data(b"uri data", "")
baztest = Data(b"test data", "")
bazuri.putChild(b"test", baztest)
bazroot.putChild(b"uri", bazuri)
hr = _HostResource()
root = NameVirtualHost()
root.default = Data(b"default data", "")
root.addHost(b"baz.com", bazroot)
request = DummyRequest([b"uri", b"test"])
request.prepath = [b"bar", b"http", b"baz.com"]
request.site = Site(root)
request.isSecure = lambda: False
request.host = b""
step = hr.getChild(b"baz.com", request) # Consumes rest of path
self.assertIsInstance(step, Data)
request = DummyRequest([b"uri", b"test"])
step = root.getChild(b"uri", request)
self.assertIsInstance(step, NoResource)
class NameVirtualHostTests(TestCase):
"""
Tests for L{NameVirtualHost}.
"""
def test_renderWithoutHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the
instance's C{default} if it is not L{None} and there is no I{Host}
header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.default = Data(b"correct result", "")
request = DummyRequest([b""])
self.assertEqual(virtualHostResource.render(request), b"correct result")
def test_renderWithoutHostNoDefault(self):
"""
L{NameVirtualHost.render} returns a response with a status of I{NOT
FOUND} if the instance's C{default} is L{None} and there is no I{Host}
header in the request.
"""
virtualHostResource = NameVirtualHost()
request = DummyRequest([b""])
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_renderWithHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the resource
which is the value in the instance's C{host} dictionary corresponding
to the key indicated by the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.addHost(b"example.org", Data(b"winner", ""))
request = DummyRequest([b""])
request.requestHeaders.addRawHeader(b"host", b"example.org")
d = _render(virtualHostResource, request)
def cbRendered(ignored, request):
self.assertEqual(b"".join(request.written), b"winner")
d.addCallback(cbRendered, request)
# The port portion of the Host header should not be considered.
requestWithPort = DummyRequest([b""])
requestWithPort.requestHeaders.addRawHeader(b"host", b"example.org:8000")
dWithPort = _render(virtualHostResource, requestWithPort)
def cbRendered(ignored, requestWithPort):
self.assertEqual(b"".join(requestWithPort.written), b"winner")
dWithPort.addCallback(cbRendered, requestWithPort)
return gatherResults([d, dWithPort])
def test_renderWithUnknownHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the
instance's C{default} if it is not L{None} and there is no host
matching the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.default = Data(b"correct data", "")
request = DummyRequest([b""])
request.requestHeaders.addRawHeader(b"host", b"example.com")
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(b"".join(request.written), b"correct data")
d.addCallback(cbRendered)
return d
def test_renderWithUnknownHostNoDefault(self):
"""
L{NameVirtualHost.render} returns a response with a status of I{NOT
FOUND} if the instance's C{default} is L{None} and there is no host
matching the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
request = DummyRequest([b""])
request.requestHeaders.addRawHeader(b"host", b"example.com")
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
async def test_renderWithHTMLHost(self):
"""
L{NameVirtualHost.render} doesn't echo unescaped HTML when present in
the I{Host} header.
"""
virtualHostResource = NameVirtualHost()
request = DummyRequest([b""])
request.requestHeaders.addRawHeader(b"host", b"<b>example</b>.com")
await _render(virtualHostResource, request)
self.assertNotIn(b"<b>", b"".join(request.written))
def test_getChild(self):
"""
L{NameVirtualHost.getChild} returns correct I{Resource} based off
the header and modifies I{Request} to ensure proper prepath and
postpath are set.
"""
virtualHostResource = NameVirtualHost()
leafResource = Data(b"leaf data", "")
leafResource.isLeaf = True
normResource = Data(b"norm data", "")
virtualHostResource.addHost(b"leaf.example.org", leafResource)
virtualHostResource.addHost(b"norm.example.org", normResource)
request = DummyRequest([])
request.requestHeaders.addRawHeader(b"host", b"norm.example.org")
request.prepath = [b""]
self.assertIsInstance(virtualHostResource.getChild(b"", request), NoResource)
self.assertEqual(request.prepath, [b""])
self.assertEqual(request.postpath, [])
request = DummyRequest([])
request.requestHeaders.addRawHeader(b"host", b"leaf.example.org")
request.prepath = [b""]
self.assertIsInstance(virtualHostResource.getChild(b"", request), Data)
self.assertEqual(request.prepath, [])
self.assertEqual(request.postpath, [b""])
class VHostMonsterResourceTests(TestCase):
"""
Tests for L{VHostMonsterResource}.
"""
def test_getChild(self):
"""
L{VHostMonsterResource.getChild} returns I{_HostResource} and modifies
I{Request} with correct L{Request.isSecure}.
"""
vhm = VHostMonsterResource()
request = DummyRequest([])
self.assertIsInstance(vhm.getChild(b"http", request), _HostResource)
self.assertFalse(request.isSecure())
request = DummyRequest([])
self.assertIsInstance(vhm.getChild(b"https", request), _HostResource)
self.assertTrue(request.isSecure())

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,28 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
The L{_response} module contains constants for all standard HTTP codes, along
with a mapping to the corresponding phrases.
"""
import string
from twisted.trial import unittest
from twisted.web import _responses
class ResponseTests(unittest.TestCase):
def test_constants(self) -> None:
"""
All constants besides C{RESPONSES} defined in L{_response} are
integers and are keys in C{RESPONSES}.
"""
for sym in dir(_responses):
if sym == "RESPONSES":
continue
if all((c == "_" or c in string.ascii_uppercase) for c in sym):
val = getattr(_responses, sym)
self.assertIsInstance(val, int)
self.assertIn(val, _responses.RESPONSES)

View File

@@ -0,0 +1,367 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests L{twisted.web.client} helper APIs
"""
from urllib.parse import urlparse
from twisted.trial import unittest
from twisted.web import client
class URLJoinTests(unittest.TestCase):
"""
Tests for L{client._urljoin}.
"""
def test_noFragments(self):
"""
L{client._urljoin} does not include a fragment identifier in the
resulting URL if neither the base nor the new path include a fragment
identifier.
"""
self.assertEqual(
client._urljoin(b"http://foo.com/bar", b"/quux"), b"http://foo.com/quux"
)
self.assertEqual(
client._urljoin(b"http://foo.com/bar#", b"/quux"), b"http://foo.com/quux"
)
self.assertEqual(
client._urljoin(b"http://foo.com/bar", b"/quux#"), b"http://foo.com/quux"
)
def test_preserveFragments(self):
"""
L{client._urljoin} preserves the fragment identifier from either the
new path or the base URL respectively, as specified in the HTTP 1.1 bis
draft.
@see: U{https://tools.ietf.org/html/draft-ietf-httpbis-p2-semantics-22#section-7.1.2}
"""
self.assertEqual(
client._urljoin(b"http://foo.com/bar#frag", b"/quux"),
b"http://foo.com/quux#frag",
)
self.assertEqual(
client._urljoin(b"http://foo.com/bar", b"/quux#frag2"),
b"http://foo.com/quux#frag2",
)
self.assertEqual(
client._urljoin(b"http://foo.com/bar#frag", b"/quux#frag2"),
b"http://foo.com/quux#frag2",
)
class URITests:
"""
Abstract tests for L{twisted.web.client.URI}.
Subclass this and L{unittest.TestCase}. Then provide a value for
C{host} and C{uriHost}.
@ivar host: A host specification for use in tests, must be L{bytes}.
@ivar uriHost: The host specification in URI form, must be a L{bytes}. In
most cases this is identical with C{host}. IPv6 address literals are an
exception, according to RFC 3986 section 3.2.2, as they need to be
enclosed in brackets. In this case this variable is different.
"""
def makeURIString(self, template):
"""
Replace the string "HOST" in C{template} with this test's host.
Byte strings Python between (and including) versions 3.0 and 3.4
cannot be formatted using C{%} or C{format} so this does a simple
replace.
@type template: L{bytes}
@param template: A string containing "HOST".
@rtype: L{bytes}
@return: A string where "HOST" has been replaced by C{self.host}.
"""
self.assertIsInstance(self.host, bytes)
self.assertIsInstance(self.uriHost, bytes)
self.assertIsInstance(template, bytes)
self.assertIn(b"HOST", template)
return template.replace(b"HOST", self.uriHost)
def assertURIEquals(
self, uri, scheme, netloc, host, port, path, params=b"", query=b"", fragment=b""
):
"""
Assert that all of a L{client.URI}'s components match the expected
values.
@param uri: U{client.URI} instance whose attributes will be checked
for equality.
@type scheme: L{bytes}
@param scheme: URI scheme specifier.
@type netloc: L{bytes}
@param netloc: Network location component.
@type host: L{bytes}
@param host: Host name.
@type port: L{int}
@param port: Port number.
@type path: L{bytes}
@param path: Hierarchical path.
@type params: L{bytes}
@param params: Parameters for last path segment, defaults to C{b''}.
@type query: L{bytes}
@param query: Query string, defaults to C{b''}.
@type fragment: L{bytes}
@param fragment: Fragment identifier, defaults to C{b''}.
"""
self.assertEqual(
(scheme, netloc, host, port, path, params, query, fragment),
(
uri.scheme,
uri.netloc,
uri.host,
uri.port,
uri.path,
uri.params,
uri.query,
uri.fragment,
),
)
def test_parseDefaultPort(self):
"""
L{client.URI.fromBytes} by default assumes port 80 for the I{http}
scheme and 443 for the I{https} scheme.
"""
uri = client.URI.fromBytes(self.makeURIString(b"http://HOST"))
self.assertEqual(80, uri.port)
# Weird (but commonly accepted) structure uses default port.
uri = client.URI.fromBytes(self.makeURIString(b"http://HOST:"))
self.assertEqual(80, uri.port)
uri = client.URI.fromBytes(self.makeURIString(b"https://HOST"))
self.assertEqual(443, uri.port)
def test_parseCustomDefaultPort(self):
"""
L{client.URI.fromBytes} accepts a C{defaultPort} parameter that
overrides the normal default port logic.
"""
uri = client.URI.fromBytes(self.makeURIString(b"http://HOST"), defaultPort=5144)
self.assertEqual(5144, uri.port)
uri = client.URI.fromBytes(
self.makeURIString(b"https://HOST"), defaultPort=5144
)
self.assertEqual(5144, uri.port)
def test_netlocHostPort(self):
"""
Parsing a I{URI} splits the network location component into I{host} and
I{port}.
"""
uri = client.URI.fromBytes(self.makeURIString(b"http://HOST:5144"))
self.assertEqual(5144, uri.port)
self.assertEqual(self.host, uri.host)
self.assertEqual(self.uriHost + b":5144", uri.netloc)
# Spaces in the hostname are trimmed, the default path is /.
uri = client.URI.fromBytes(self.makeURIString(b"http://HOST "))
self.assertEqual(self.uriHost, uri.netloc)
def test_path(self):
"""
Parse the path from a I{URI}.
"""
uri = self.makeURIString(b"http://HOST/foo/bar")
parsed = client.URI.fromBytes(uri)
self.assertURIEquals(
parsed,
scheme=b"http",
netloc=self.uriHost,
host=self.host,
port=80,
path=b"/foo/bar",
)
self.assertEqual(uri, parsed.toBytes())
def test_noPath(self):
"""
The path of a I{URI} that has no path is the empty string.
"""
uri = self.makeURIString(b"http://HOST")
parsed = client.URI.fromBytes(uri)
self.assertURIEquals(
parsed,
scheme=b"http",
netloc=self.uriHost,
host=self.host,
port=80,
path=b"",
)
self.assertEqual(uri, parsed.toBytes())
def test_emptyPath(self):
"""
The path of a I{URI} with an empty path is C{b'/'}.
"""
uri = self.makeURIString(b"http://HOST/")
self.assertURIEquals(
client.URI.fromBytes(uri),
scheme=b"http",
netloc=self.uriHost,
host=self.host,
port=80,
path=b"/",
)
def test_param(self):
"""
Parse I{URI} parameters from a I{URI}.
"""
uri = self.makeURIString(b"http://HOST/foo/bar;param")
parsed = client.URI.fromBytes(uri)
self.assertURIEquals(
parsed,
scheme=b"http",
netloc=self.uriHost,
host=self.host,
port=80,
path=b"/foo/bar",
params=b"param",
)
self.assertEqual(uri, parsed.toBytes())
def test_query(self):
"""
Parse the query string from a I{URI}.
"""
uri = self.makeURIString(b"http://HOST/foo/bar;param?a=1&b=2")
parsed = client.URI.fromBytes(uri)
self.assertURIEquals(
parsed,
scheme=b"http",
netloc=self.uriHost,
host=self.host,
port=80,
path=b"/foo/bar",
params=b"param",
query=b"a=1&b=2",
)
self.assertEqual(uri, parsed.toBytes())
def test_fragment(self):
"""
Parse the fragment identifier from a I{URI}.
"""
uri = self.makeURIString(b"http://HOST/foo/bar;param?a=1&b=2#frag")
parsed = client.URI.fromBytes(uri)
self.assertURIEquals(
parsed,
scheme=b"http",
netloc=self.uriHost,
host=self.host,
port=80,
path=b"/foo/bar",
params=b"param",
query=b"a=1&b=2",
fragment=b"frag",
)
self.assertEqual(uri, parsed.toBytes())
def test_originForm(self):
"""
L{client.URI.originForm} produces an absolute I{URI} path including
the I{URI} path.
"""
uri = client.URI.fromBytes(self.makeURIString(b"http://HOST/foo"))
self.assertEqual(b"/foo", uri.originForm)
def test_originFormComplex(self):
"""
L{client.URI.originForm} produces an absolute I{URI} path including
the I{URI} path, parameters and query string but excludes the fragment
identifier.
"""
uri = client.URI.fromBytes(
self.makeURIString(b"http://HOST/foo;param?a=1#frag")
)
self.assertEqual(b"/foo;param?a=1", uri.originForm)
def test_originFormNoPath(self):
"""
L{client.URI.originForm} produces a path of C{b'/'} when the I{URI}
specifies no path.
"""
uri = client.URI.fromBytes(self.makeURIString(b"http://HOST"))
self.assertEqual(b"/", uri.originForm)
def test_originFormEmptyPath(self):
"""
L{client.URI.originForm} produces a path of C{b'/'} when the I{URI}
specifies an empty path.
"""
uri = client.URI.fromBytes(self.makeURIString(b"http://HOST/"))
self.assertEqual(b"/", uri.originForm)
def test_externalUnicodeInterference(self):
"""
L{client.URI.fromBytes} parses the scheme, host, and path elements
into L{bytes}, even when passed an URL which has previously been passed
to L{urlparse} as a L{unicode} string.
"""
goodInput = self.makeURIString(b"http://HOST/path")
badInput = goodInput.decode("ascii")
urlparse(badInput)
uri = client.URI.fromBytes(goodInput)
self.assertIsInstance(uri.scheme, bytes)
self.assertIsInstance(uri.host, bytes)
self.assertIsInstance(uri.path, bytes)
class URITestsForHostname(URITests, unittest.TestCase):
"""
Tests for L{twisted.web.client.URI} with host names.
"""
uriHost = host = b"example.com"
class URITestsForIPv4(URITests, unittest.TestCase):
"""
Tests for L{twisted.web.client.URI} with IPv4 host addresses.
"""
uriHost = host = b"192.168.1.67"
class URITestsForIPv6(URITests, unittest.TestCase):
"""
Tests for L{twisted.web.client.URI} with IPv6 host addresses.
IPv6 addresses must always be surrounded by square braces in URIs. No
attempt is made to test without.
"""
host = b"fe80::20c:29ff:fea4:c60"
uriHost = b"[fe80::20c:29ff:fea4:c60]"
def test_hostBracketIPv6AddressLiteral(self):
"""
Brackets around IPv6 addresses are stripped in the host field. The host
field is then exported with brackets in the output of
L{client.URI.toBytes}.
"""
uri = client.URI.fromBytes(b"http://[::1]:80/index.html")
self.assertEqual(uri.host, b"::1")
self.assertEqual(uri.netloc, b"[::1]:80")
self.assertEqual(uri.toBytes(), b"http://[::1]:80/index.html")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,949 @@
# -*- test-case-name: twisted.web.test.test_xmlrpc -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for XML-RPC support in L{twisted.web.xmlrpc}.
"""
import datetime
from io import BytesIO, StringIO
from unittest import skipIf
from twisted.internet import defer, reactor
from twisted.internet.error import ConnectionDone
from twisted.internet.testing import EventLoggingObserver, MemoryReactor
from twisted.logger import (
FilteringLogObserver,
LogLevel,
LogLevelFilterPredicate,
globalLogPublisher,
)
from twisted.python import failure
from twisted.python.compat import nativeString, networkString
from twisted.python.reflect import namedModule
from twisted.trial import unittest
from twisted.web import client, http, server, static, xmlrpc
from twisted.web.test.test_web import DummyRequest
from twisted.web.xmlrpc import (
XMLRPC,
QueryFactory,
addIntrospection,
payloadTemplate,
withRequest,
xmlrpclib,
)
try:
namedModule("twisted.internet.ssl")
except ImportError:
sslSkip = True
else:
sslSkip = False
class AsyncXMLRPCTests(unittest.TestCase):
"""
Tests for L{XMLRPC}'s support of Deferreds.
"""
def setUp(self):
self.request = DummyRequest([""])
self.request.method = "POST"
self.request.content = StringIO(
payloadTemplate % ("async", xmlrpclib.dumps(()))
)
result = self.result = defer.Deferred()
class AsyncResource(XMLRPC):
def xmlrpc_async(self):
return result
self.resource = AsyncResource()
def test_deferredResponse(self):
"""
If an L{XMLRPC} C{xmlrpc_*} method returns a L{defer.Deferred}, the
response to the request is the result of that L{defer.Deferred}.
"""
self.resource.render(self.request)
self.assertEqual(self.request.written, [])
self.result.callback("result")
resp = xmlrpclib.loads(b"".join(self.request.written))
self.assertEqual(resp, (("result",), None))
self.assertEqual(self.request.finished, 1)
def test_interruptedDeferredResponse(self):
"""
While waiting for the L{Deferred} returned by an L{XMLRPC} C{xmlrpc_*}
method to fire, the connection the request was issued over may close.
If this happens, neither C{write} nor C{finish} is called on the
request.
"""
self.resource.render(self.request)
self.request.processingFailed(failure.Failure(ConnectionDone("Simulated")))
self.result.callback("result")
self.assertEqual(self.request.written, [])
self.assertEqual(self.request.finished, 0)
class TestRuntimeError(RuntimeError):
pass
class TestValueError(ValueError):
pass
class Test(XMLRPC):
# If you add xmlrpc_ methods to this class, go change test_listMethods
# below.
FAILURE = 666
NOT_FOUND = 23
SESSION_EXPIRED = 42
def xmlrpc_echo(self, arg):
return arg
# the doc string is part of the test
def xmlrpc_add(self, a, b):
"""
This function add two numbers.
"""
return a + b
xmlrpc_add.signature = [ # type: ignore[attr-defined]
["int", "int", "int"],
["double", "double", "double"],
]
# the doc string is part of the test
def xmlrpc_pair(self, string, num):
"""
This function puts the two arguments in an array.
"""
return [string, num]
xmlrpc_pair.signature = [["array", "string", "int"]] # type: ignore[attr-defined]
# the doc string is part of the test
def xmlrpc_defer(self, x):
"""Help for defer."""
return defer.succeed(x)
def xmlrpc_deferFail(self):
return defer.fail(TestValueError())
# don't add a doc string, it's part of the test
def xmlrpc_fail(self):
raise TestRuntimeError
def xmlrpc_fault(self):
return xmlrpc.Fault(12, "hello")
def xmlrpc_deferFault(self):
return defer.fail(xmlrpc.Fault(17, "hi"))
def xmlrpc_snowman(self, payload):
"""
Used to test that we can pass Unicode.
"""
snowman = "\u2603"
if snowman != payload:
return xmlrpc.Fault(13, "Payload not unicode snowman")
return snowman
def xmlrpc_complex(self):
return {"a": ["b", "c", 12, []], "D": "foo"}
def xmlrpc_dict(self, map, key):
return map[key]
xmlrpc_dict.help = "Help for dict." # type: ignore[attr-defined]
@withRequest
def xmlrpc_withRequest(self, request, other):
"""
A method decorated with L{withRequest} which can be called by
a test to verify that the request object really is passed as
an argument.
"""
return (
# as a proof that request is a request
request.method
+
# plus proof other arguments are still passed along
" "
+ other
)
def lookupProcedure(self, procedurePath):
try:
return XMLRPC.lookupProcedure(self, procedurePath)
except xmlrpc.NoSuchFunction:
if procedurePath.startswith("SESSION"):
raise xmlrpc.Fault(
self.SESSION_EXPIRED, "Session non-existent/expired."
)
else:
raise
class TestLookupProcedure(XMLRPC):
"""
This is a resource which customizes procedure lookup to be used by the tests
of support for this customization.
"""
def echo(self, x):
return x
def lookupProcedure(self, procedureName):
"""
Lookup a procedure from a fixed set of choices, either I{echo} or
I{system.listeMethods}.
"""
if procedureName == "echo":
return self.echo
raise xmlrpc.NoSuchFunction(
self.NOT_FOUND, f"procedure {procedureName} not found"
)
class TestListProcedures(XMLRPC):
"""
This is a resource which customizes procedure enumeration to be used by the
tests of support for this customization.
"""
def listProcedures(self):
"""
Return a list of a single method this resource will claim to support.
"""
return ["foo"]
class TestAuthHeader(Test):
"""
This is used to get the header info so that we can test
authentication.
"""
def __init__(self):
Test.__init__(self)
self.request = None
def render(self, request):
self.request = request
return Test.render(self, request)
def xmlrpc_authinfo(self):
return self.request.getUser(), self.request.getPassword()
class TestQueryProtocol(xmlrpc.QueryProtocol):
"""
QueryProtocol for tests that saves headers received and sent,
inside the factory.
"""
def connectionMade(self):
self.factory.transport = self.transport
xmlrpc.QueryProtocol.connectionMade(self)
def handleHeader(self, key, val):
self.factory.headers[key.lower()] = val
def sendHeader(self, key, val):
"""
Keep sent headers so we can inspect them later.
"""
self.factory.sent_headers[key.lower()] = val
xmlrpc.QueryProtocol.sendHeader(self, key, val)
class TestQueryFactory(xmlrpc.QueryFactory):
"""
QueryFactory using L{TestQueryProtocol} for saving headers.
"""
protocol = TestQueryProtocol
def __init__(self, *args, **kwargs):
self.headers = {}
self.sent_headers = {}
xmlrpc.QueryFactory.__init__(self, *args, **kwargs)
class TestQueryFactoryCancel(xmlrpc.QueryFactory):
"""
QueryFactory that saves a reference to the
L{twisted.internet.interfaces.IConnector} to test connection lost.
"""
def startedConnecting(self, connector):
self.connector = connector
class XMLRPCTests(unittest.TestCase):
def setUp(self):
self.p = reactor.listenTCP(0, server.Site(Test()), interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def tearDown(self):
self.factories = []
return self.p.stopListening()
def queryFactory(self, *args, **kwargs):
"""
Specific queryFactory for proxy that uses our custom
L{TestQueryFactory}, and save factories.
"""
factory = TestQueryFactory(*args, **kwargs)
self.factories.append(factory)
return factory
def proxy(self, factory=None):
"""
Return a new xmlrpc.Proxy for the test site created in
setUp(), using the given factory as the queryFactory, or
self.queryFactory if no factory is provided.
"""
p = xmlrpc.Proxy(networkString("http://127.0.0.1:%d/" % self.port))
if factory is None:
p.queryFactory = self.queryFactory
else:
p.queryFactory = factory
return p
def test_results(self):
inputOutput = [
("add", (2, 3), 5),
("defer", ("a",), "a"),
("dict", ({"a": 1}, "a"), 1),
("pair", ("a", 1), ["a", 1]),
("snowman", ("\u2603"), "\u2603"),
("complex", (), {"a": ["b", "c", 12, []], "D": "foo"}),
]
dl = []
for meth, args, outp in inputOutput:
d = self.proxy().callRemote(meth, *args)
d.addCallback(self.assertEqual, outp)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
def test_headers(self):
"""
Verify that headers sent from the client side and the ones we
get back from the server side are correct.
"""
d = self.proxy().callRemote("snowman", "\u2603")
def check_server_headers(ing):
self.assertEqual(
self.factories[0].headers[b"content-type"], b"text/xml; charset=utf-8"
)
self.assertEqual(self.factories[0].headers[b"content-length"], b"129")
def check_client_headers(ign):
self.assertEqual(
self.factories[0].sent_headers[b"user-agent"], b"Twisted/XMLRPClib"
)
self.assertEqual(
self.factories[0].sent_headers[b"content-type"],
b"text/xml; charset=utf-8",
)
self.assertEqual(self.factories[0].sent_headers[b"content-length"], b"155")
d.addCallback(check_server_headers)
d.addCallback(check_client_headers)
return d
def test_errors(self):
"""
Verify that for each way a method exposed via XML-RPC can fail, the
correct 'Content-type' header is set in the response and that the
client-side Deferred is errbacked with an appropriate C{Fault}
instance.
"""
logObserver = EventLoggingObserver()
filtered = FilteringLogObserver(
logObserver, [LogLevelFilterPredicate(defaultLogLevel=LogLevel.critical)]
)
globalLogPublisher.addObserver(filtered)
self.addCleanup(lambda: globalLogPublisher.removeObserver(filtered))
dl = []
for code, methodName in [
(666, "fail"),
(666, "deferFail"),
(12, "fault"),
(23, "noSuchMethod"),
(17, "deferFault"),
(42, "SESSION_TEST"),
]:
d = self.proxy().callRemote(methodName)
d = self.assertFailure(d, xmlrpc.Fault)
d.addCallback(lambda exc, code=code: self.assertEqual(exc.faultCode, code))
dl.append(d)
d = defer.DeferredList(dl, fireOnOneErrback=True)
def cb(ign):
for factory in self.factories:
self.assertEqual(
factory.headers[b"content-type"], b"text/xml; charset=utf-8"
)
self.assertEquals(2, len(logObserver))
f1 = logObserver[0]["log_failure"].value
f2 = logObserver[1]["log_failure"].value
if isinstance(f1, TestValueError):
self.assertIsInstance(f2, TestRuntimeError)
else:
self.assertIsInstance(f1, TestRuntimeError)
self.assertIsInstance(f2, TestValueError)
self.flushLoggedErrors(TestRuntimeError, TestValueError)
d.addCallback(cb)
return d
def test_cancel(self):
"""
A deferred from the Proxy can be cancelled, disconnecting
the L{twisted.internet.interfaces.IConnector}.
"""
def factory(*args, **kw):
factory.f = TestQueryFactoryCancel(*args, **kw)
return factory.f
d = self.proxy(factory).callRemote("add", 2, 3)
self.assertNotEqual(factory.f.connector.state, "disconnected")
d.cancel()
self.assertEqual(factory.f.connector.state, "disconnected")
d = self.assertFailure(d, defer.CancelledError)
return d
def test_errorGet(self):
"""
A classic GET on the xml server should return a NOT_ALLOWED.
"""
agent = client.Agent(reactor)
d = agent.request(b"GET", networkString("http://127.0.0.1:%d/" % (self.port,)))
def checkResponse(response):
self.assertEqual(response.code, http.NOT_ALLOWED)
d.addCallback(checkResponse)
return d
def test_errorXMLContent(self):
"""
Test that an invalid XML input returns an L{xmlrpc.Fault}.
"""
agent = client.Agent(reactor)
d = agent.request(
uri=networkString("http://127.0.0.1:%d/" % (self.port,)),
method=b"POST",
bodyProducer=client.FileBodyProducer(BytesIO(b"foo")),
)
d.addCallback(client.readBody)
def cb(result):
self.assertRaises(xmlrpc.Fault, xmlrpclib.loads, result)
d.addCallback(cb)
return d
def test_datetimeRoundtrip(self):
"""
If an L{xmlrpclib.DateTime} is passed as an argument to an XML-RPC
call and then returned by the server unmodified, the result should
be equal to the original object.
"""
when = xmlrpclib.DateTime()
d = self.proxy().callRemote("echo", when)
d.addCallback(self.assertEqual, when)
return d
def test_doubleEncodingError(self):
"""
If it is not possible to encode a response to the request (for example,
because L{xmlrpclib.dumps} raises an exception when encoding a
L{Fault}) the exception which prevents the response from being
generated is logged and the request object is finished anyway.
"""
logObserver = EventLoggingObserver()
filtered = FilteringLogObserver(
logObserver, [LogLevelFilterPredicate(defaultLogLevel=LogLevel.critical)]
)
globalLogPublisher.addObserver(filtered)
self.addCleanup(lambda: globalLogPublisher.removeObserver(filtered))
d = self.proxy().callRemote("echo", "")
# *Now* break xmlrpclib.dumps. Hopefully the client already used it.
def fakeDumps(*args, **kwargs):
raise RuntimeError("Cannot encode anything at all!")
self.patch(xmlrpclib, "dumps", fakeDumps)
# It doesn't matter how it fails, so long as it does. Also, it happens
# to fail with an implementation detail exception right now, not
# something suitable as part of a public interface.
d = self.assertFailure(d, Exception)
def cbFailed(ignored):
# The fakeDumps exception should have been logged.
self.assertEquals(1, len(logObserver))
self.assertIsInstance(logObserver[0]["log_failure"].value, RuntimeError)
self.assertEqual(len(self.flushLoggedErrors(RuntimeError)), 1)
d.addCallback(cbFailed)
return d
def test_closeConnectionAfterRequest(self):
"""
The connection to the web server is closed when the request is done.
"""
d = self.proxy().callRemote("echo", "")
def responseDone(ignored):
[factory] = self.factories
self.assertFalse(factory.transport.connected)
self.assertTrue(factory.transport.disconnected)
return d.addCallback(responseDone)
def test_tcpTimeout(self):
"""
For I{HTTP} URIs, L{xmlrpc.Proxy.callRemote} passes the value it
received for the C{connectTimeout} parameter as the C{timeout} argument
to the underlying connectTCP call.
"""
reactor = MemoryReactor()
proxy = xmlrpc.Proxy(
b"http://127.0.0.1:69", connectTimeout=2.0, reactor=reactor
)
proxy.callRemote("someMethod")
self.assertEqual(reactor.tcpClients[0][3], 2.0)
@skipIf(sslSkip, "OpenSSL not present")
def test_sslTimeout(self):
"""
For I{HTTPS} URIs, L{xmlrpc.Proxy.callRemote} passes the value it
received for the C{connectTimeout} parameter as the C{timeout} argument
to the underlying connectSSL call.
"""
reactor = MemoryReactor()
proxy = xmlrpc.Proxy(
b"https://127.0.0.1:69", connectTimeout=3.0, reactor=reactor
)
proxy.callRemote("someMethod")
self.assertEqual(reactor.sslClients[0][4], 3.0)
class XMLRPCProxyWithoutSlashTests(XMLRPCTests):
"""
Test with proxy that doesn't add a slash.
"""
def proxy(self, factory=None):
p = xmlrpc.Proxy(networkString("http://127.0.0.1:%d" % self.port))
if factory is None:
p.queryFactory = self.queryFactory
else:
p.queryFactory = factory
return p
class XMLRPCPublicLookupProcedureTests(unittest.TestCase):
"""
Tests for L{XMLRPC}'s support of subclasses which override
C{lookupProcedure} and C{listProcedures}.
"""
def createServer(self, resource):
self.p = reactor.listenTCP(0, server.Site(resource), interface="127.0.0.1")
self.addCleanup(self.p.stopListening)
self.port = self.p.getHost().port
self.proxy = xmlrpc.Proxy(networkString("http://127.0.0.1:%d" % self.port))
def test_lookupProcedure(self):
"""
A subclass of L{XMLRPC} can override C{lookupProcedure} to find
procedures that are not defined using a C{xmlrpc_}-prefixed method name.
"""
self.createServer(TestLookupProcedure())
what = "hello"
d = self.proxy.callRemote("echo", what)
d.addCallback(self.assertEqual, what)
return d
def test_errors(self):
"""
A subclass of L{XMLRPC} can override C{lookupProcedure} to raise
L{NoSuchFunction} to indicate that a requested method is not available
to be called, signalling a fault to the XML-RPC client.
"""
self.createServer(TestLookupProcedure())
d = self.proxy.callRemote("xxxx", "hello")
d = self.assertFailure(d, xmlrpc.Fault)
return d
def test_listMethods(self):
"""
A subclass of L{XMLRPC} can override C{listProcedures} to define
Overriding listProcedures should prevent introspection from being
broken.
"""
resource = TestListProcedures()
addIntrospection(resource)
self.createServer(resource)
d = self.proxy.callRemote("system.listMethods")
def listed(procedures):
# The list will also include other introspection procedures added by
# addIntrospection. We just want to see "foo" from our customized
# listProcedures.
self.assertIn("foo", procedures)
d.addCallback(listed)
return d
class SerializationConfigMixin:
"""
Mixin which defines a couple tests which should pass when a particular flag
is passed to L{XMLRPC}.
These are not meant to be exhaustive serialization tests, since L{xmlrpclib}
does all of the actual serialization work. They are just meant to exercise
a few codepaths to make sure we are calling into xmlrpclib correctly.
@ivar flagName: A C{str} giving the name of the flag which must be passed to
L{XMLRPC} to allow the tests to pass. Subclasses should set this.
@ivar value: A value which the specified flag will allow the serialization
of. Subclasses should set this.
"""
def setUp(self):
"""
Create a new XML-RPC server with C{allowNone} set to C{True}.
"""
kwargs = {self.flagName: True}
self.p = reactor.listenTCP(
0, server.Site(Test(**kwargs)), interface="127.0.0.1"
)
self.addCleanup(self.p.stopListening)
self.port = self.p.getHost().port
self.proxy = xmlrpc.Proxy(
networkString("http://127.0.0.1:%d/" % (self.port,)), **kwargs
)
def test_roundtripValue(self):
"""
C{self.value} can be round-tripped over an XMLRPC method call/response.
"""
d = self.proxy.callRemote("defer", self.value)
d.addCallback(self.assertEqual, self.value)
return d
def test_roundtripNestedValue(self):
"""
A C{dict} which contains C{self.value} can be round-tripped over an
XMLRPC method call/response.
"""
d = self.proxy.callRemote("defer", {"a": self.value})
d.addCallback(self.assertEqual, {"a": self.value})
return d
class XMLRPCAllowNoneTests(SerializationConfigMixin, unittest.TestCase):
"""
Tests for passing L{None} when the C{allowNone} flag is set.
"""
flagName = "allowNone"
value = None
class XMLRPCUseDateTimeTests(SerializationConfigMixin, unittest.TestCase):
"""
Tests for passing a C{datetime.datetime} instance when the C{useDateTime}
flag is set.
"""
flagName = "useDateTime"
value = datetime.datetime(2000, 12, 28, 3, 45, 59)
class XMLRPCAuthenticatedTests(XMLRPCTests):
"""
Test with authenticated proxy. We run this with the same input/output as
above.
"""
user = b"username"
password = b"asecret"
def setUp(self):
self.p = reactor.listenTCP(
0, server.Site(TestAuthHeader()), interface="127.0.0.1"
)
self.port = self.p.getHost().port
self.factories = []
def test_authInfoInURL(self):
url = "http://%s:%s@127.0.0.1:%d/" % (
nativeString(self.user),
nativeString(self.password),
self.port,
)
p = xmlrpc.Proxy(networkString(url))
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
def test_explicitAuthInfo(self):
p = xmlrpc.Proxy(
networkString("http://127.0.0.1:%d/" % (self.port,)),
self.user,
self.password,
)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
def test_longPassword(self):
"""
C{QueryProtocol} uses the C{base64.b64encode} function to encode user
name and password in the I{Authorization} header, so that it doesn't
embed new lines when using long inputs.
"""
longPassword = self.password * 40
p = xmlrpc.Proxy(
networkString("http://127.0.0.1:%d/" % (self.port,)),
self.user,
longPassword,
)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, longPassword])
return d
def test_explicitAuthInfoOverride(self):
p = xmlrpc.Proxy(
networkString("http://wrong:info@127.0.0.1:%d/" % (self.port,)),
self.user,
self.password,
)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
class XMLRPCIntrospectionTests(XMLRPCTests):
def setUp(self):
xmlrpc = Test()
addIntrospection(xmlrpc)
self.p = reactor.listenTCP(0, server.Site(xmlrpc), interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def test_listMethods(self):
def cbMethods(meths):
meths.sort()
self.assertEqual(
meths,
[
"add",
"complex",
"defer",
"deferFail",
"deferFault",
"dict",
"echo",
"fail",
"fault",
"pair",
"snowman",
"system.listMethods",
"system.methodHelp",
"system.methodSignature",
"withRequest",
],
)
d = self.proxy().callRemote("system.listMethods")
d.addCallback(cbMethods)
return d
def test_methodHelp(self):
inputOutputs = [
("defer", "Help for defer."),
("fail", ""),
("dict", "Help for dict."),
]
dl = []
for meth, expected in inputOutputs:
d = self.proxy().callRemote("system.methodHelp", meth)
d.addCallback(self.assertEqual, expected)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
def test_methodSignature(self):
inputOutputs = [
("defer", ""),
("add", [["int", "int", "int"], ["double", "double", "double"]]),
("pair", [["array", "string", "int"]]),
]
dl = []
for meth, expected in inputOutputs:
d = self.proxy().callRemote("system.methodSignature", meth)
d.addCallback(self.assertEqual, expected)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
class XMLRPCClientErrorHandlingTests(unittest.TestCase):
"""
Test error handling on the xmlrpc client.
"""
def setUp(self):
self.resource = static.Data(
b"This text is not a valid XML-RPC response.", b"text/plain"
)
self.resource.isLeaf = True
self.port = reactor.listenTCP(
0, server.Site(self.resource), interface="127.0.0.1"
)
def tearDown(self):
return self.port.stopListening()
def test_erroneousResponse(self):
"""
Test that calling the xmlrpc client on a static http server raises
an exception.
"""
proxy = xmlrpc.Proxy(
networkString("http://127.0.0.1:%d/" % (self.port.getHost().port,))
)
return self.assertFailure(proxy.callRemote("someMethod"), ValueError)
class QueryFactoryParseResponseTests(unittest.TestCase):
"""
Test the behaviour of L{QueryFactory.parseResponse}.
"""
def setUp(self):
# The QueryFactory that we are testing. We don't care about any
# of the constructor parameters.
self.queryFactory = QueryFactory(
path=None,
host=None,
method="POST",
user=None,
password=None,
allowNone=False,
args=(),
)
# An XML-RPC response that will parse without raising an error.
self.goodContents = xmlrpclib.dumps(("",))
# An 'XML-RPC response' that will raise a parsing error.
self.badContents = "invalid xml"
# A dummy 'reason' to pass to clientConnectionLost. We don't care
# what it is.
self.reason = failure.Failure(ConnectionDone())
def test_parseResponseCallbackSafety(self):
"""
We can safely call L{QueryFactory.clientConnectionLost} as a callback
of L{QueryFactory.parseResponse}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addCallback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.parseResponse(self.goodContents)
return d
def test_parseResponseErrbackSafety(self):
"""
We can safely call L{QueryFactory.clientConnectionLost} as an errback
of L{QueryFactory.parseResponse}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addErrback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.parseResponse(self.badContents)
return d
def test_badStatusErrbackSafety(self):
"""
We can safely call L{QueryFactory.clientConnectionLost} as an errback
of L{QueryFactory.badStatus}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addErrback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.badStatus("status", "message")
return d
def test_parseResponseWithoutData(self):
"""
Some server can send a response without any data:
L{QueryFactory.parseResponse} should catch the error and call the
result errback.
"""
content = """
<methodResponse>
<params>
<param>
</param>
</params>
</methodResponse>"""
d = self.queryFactory.deferred
self.queryFactory.parseResponse(content)
return self.assertFailure(d, IndexError)
class XMLRPCWithRequestTests(unittest.TestCase):
def setUp(self):
self.resource = Test()
def test_withRequest(self):
"""
When an XML-RPC method is called and the implementation is
decorated with L{withRequest}, the request object is passed as
the first argument.
"""
request = DummyRequest("/RPC2")
request.method = "POST"
request.content = StringIO(xmlrpclib.dumps(("foo",), "withRequest"))
def valid(n, request):
data = xmlrpclib.loads(request.written[0])
self.assertEqual(data, (("POST foo",), None))
d = request.notifyFinish().addCallback(valid, request)
self.resource.render_POST(request)
return d

View File

@@ -0,0 +1,343 @@
# -*- test-case-name: twisted.web.test.test_cgi -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
I hold resource classes and helper classes that deal with CGI scripts.
"""
# System Imports
import os
import urllib
from typing import AnyStr
# Twisted Imports
from twisted.internet import protocol
from twisted.logger import Logger
from twisted.python import filepath
from twisted.spread import pb
from twisted.web import http, resource, server, static
class CGIDirectory(resource.Resource, filepath.FilePath[AnyStr]):
def __init__(self, pathname):
resource.Resource.__init__(self)
filepath.FilePath.__init__(self, pathname)
def getChild(self, path, request):
fnp = self.child(path)
if not fnp.exists():
return static.File.childNotFound
elif fnp.isdir():
return CGIDirectory(fnp.path)
else:
return CGIScript(fnp.path)
def render(self, request):
notFound = resource.NoResource(
"CGI directories do not support directory listing."
)
return notFound.render(request)
class CGIScript(resource.Resource):
"""
L{CGIScript} is a resource which runs child processes according to the CGI
specification.
The implementation is complex due to the fact that it requires asynchronous
IPC with an external process with an unpleasant protocol.
"""
isLeaf = 1
def __init__(self, filename, registry=None, reactor=None):
"""
Initialize, with the name of a CGI script file.
"""
self.filename = filename
if reactor is None:
# This installs a default reactor, if None was installed before.
# We do a late import here, so that importing the current module
# won't directly trigger installing a default reactor.
from twisted.internet import reactor
self._reactor = reactor
def render(self, request):
"""
Do various things to conform to the CGI specification.
I will set up the usual slew of environment variables, then spin off a
process.
@type request: L{twisted.web.http.Request}
@param request: An HTTP request.
"""
scriptName = b"/" + b"/".join(request.prepath)
serverName = request.getRequestHostname().split(b":")[0]
env = {
"SERVER_SOFTWARE": server.version,
"SERVER_NAME": serverName,
"GATEWAY_INTERFACE": "CGI/1.1",
"SERVER_PROTOCOL": request.clientproto,
"SERVER_PORT": str(request.getHost().port),
"REQUEST_METHOD": request.method,
"SCRIPT_NAME": scriptName,
"SCRIPT_FILENAME": self.filename,
"REQUEST_URI": request.uri,
}
ip = request.getClientAddress().host
if ip is not None:
env["REMOTE_ADDR"] = ip
pp = request.postpath
if pp:
env["PATH_INFO"] = "/" + "/".join(pp)
if hasattr(request, "content"):
# 'request.content' is either a StringIO or a TemporaryFile, and
# the file pointer is sitting at the beginning (seek(0,0))
request.content.seek(0, 2)
length = request.content.tell()
request.content.seek(0, 0)
env["CONTENT_LENGTH"] = str(length)
try:
qindex = request.uri.index(b"?")
except ValueError:
env["QUERY_STRING"] = ""
qargs = []
else:
qs = env["QUERY_STRING"] = request.uri[qindex + 1 :]
if b"=" in qs:
qargs = []
else:
qargs = [urllib.parse.unquote(x.decode()) for x in qs.split(b"+")]
# Propagate HTTP headers
for title, header in request.getAllHeaders().items():
envname = title.replace(b"-", b"_").upper()
if title not in (b"content-type", b"content-length", b"proxy"):
envname = b"HTTP_" + envname
env[envname] = header
# Propagate our environment
for key, value in os***REMOVED***iron.items():
if key not in env:
env[key] = value
# And they're off!
self.runProcess(env, request, qargs)
return server.NOT_DONE_YET
def runProcess(self, env, request, qargs=[]):
"""
Run the cgi script.
@type env: A L{dict} of L{str}, or L{None}
@param env: The environment variables to pass to the process that will
get spawned. See
L{twisted.internet.interfaces.IReactorProcess.spawnProcess} for
more information about environments and process creation.
@type request: L{twisted.web.http.Request}
@param request: An HTTP request.
@type qargs: A L{list} of L{str}
@param qargs: The command line arguments to pass to the process that
will get spawned.
"""
p = CGIProcessProtocol(request)
self._reactor.spawnProcess(
p,
self.filename,
[self.filename] + qargs,
env,
os.path.dirname(self.filename),
)
class FilteredScript(CGIScript):
"""
I am a special version of a CGI script, that uses a specific executable.
This is useful for interfacing with other scripting languages that adhere
to the CGI standard. My C{filter} attribute specifies what executable to
run, and my C{filename} init parameter describes which script to pass to
the first argument of that script.
To customize me for a particular location of a CGI interpreter, override
C{filter}.
@type filter: L{str}
@ivar filter: The absolute path to the executable.
"""
filter = "/usr/bin/cat"
def runProcess(self, env, request, qargs=[]):
"""
Run a script through the C{filter} executable.
@type env: A L{dict} of L{str}, or L{None}
@param env: The environment variables to pass to the process that will
get spawned. See
L{twisted.internet.interfaces.IReactorProcess.spawnProcess}
for more information about environments and process creation.
@type request: L{twisted.web.http.Request}
@param request: An HTTP request.
@type qargs: A L{list} of L{str}
@param qargs: The command line arguments to pass to the process that
will get spawned.
"""
p = CGIProcessProtocol(request)
self._reactor.spawnProcess(
p,
self.filter,
[self.filter, self.filename] + qargs,
env,
os.path.dirname(self.filename),
)
class CGIProcessProtocol(protocol.ProcessProtocol, pb.Viewable):
handling_headers = 1
headers_written = 0
headertext = b""
errortext = b""
_log = Logger()
_requestFinished = False
# Remotely relay producer interface.
def view_resumeProducing(self, issuer):
self.resumeProducing()
def view_pauseProducing(self, issuer):
self.pauseProducing()
def view_stopProducing(self, issuer):
self.stopProducing()
def resumeProducing(self):
self.transport.resumeProducing()
def pauseProducing(self):
self.transport.pauseProducing()
def stopProducing(self):
self.transport.loseConnection()
def __init__(self, request):
self.request = request
self.request.notifyFinish().addBoth(self._finished)
def connectionMade(self):
self.request.registerProducer(self, 1)
self.request.content.seek(0, 0)
content = self.request.content.read()
if content:
self.transport.write(content)
self.transport.closeStdin()
def errReceived(self, error):
self.errortext = self.errortext + error
def outReceived(self, output):
"""
Handle a chunk of input
"""
# First, make sure that the headers from the script are sorted
# out (we'll want to do some parsing on these later.)
if self.handling_headers:
text = self.headertext + output
headerEnds = []
for delimiter in b"\n\n", b"\r\n\r\n", b"\r\r", b"\n\r\n":
headerend = text.find(delimiter)
if headerend != -1:
headerEnds.append((headerend, delimiter))
if headerEnds:
# The script is entirely in control of response headers;
# disable the default Content-Type value normally provided by
# twisted.web.server.Request.
self.request.defaultContentType = None
headerEnds.sort()
headerend, delimiter = headerEnds[0]
self.headertext = text[:headerend]
# This is a final version of the header text.
linebreak = delimiter[: len(delimiter) // 2]
headers = self.headertext.split(linebreak)
for header in headers:
br = header.find(b": ")
if br == -1:
self._log.error(
"ignoring malformed CGI header: {header!r}", header=header
)
else:
headerName = header[:br].lower()
headerText = header[br + 2 :]
if headerName == b"location":
self.request.setResponseCode(http.FOUND)
if headerName == b"status":
try:
# "XXX <description>" sometimes happens.
statusNum = int(headerText[:3])
except BaseException:
self._log.error("malformed status header")
else:
self.request.setResponseCode(statusNum)
else:
# Don't allow the application to control
# these required headers.
if headerName.lower() not in (b"server", b"date"):
self.request.responseHeaders.addRawHeader(
headerName, headerText
)
output = text[headerend + len(delimiter) :]
self.handling_headers = 0
if self.handling_headers:
self.headertext = text
if not self.handling_headers:
self.request.write(output)
def processEnded(self, reason):
if reason.value.exitCode != 0:
self._log.error(
"CGI {uri} exited with exit code {exitCode}",
uri=self.request.uri,
exitCode=reason.value.exitCode,
)
if self.errortext:
self._log.error(
"Errors from CGI {uri}: {errorText}",
uri=self.request.uri,
errorText=self.errortext,
)
if self.handling_headers:
self._log.error(
"Premature end of headers in {uri}: {headerText}",
uri=self.request.uri,
headerText=self.headertext,
)
if not self._requestFinished:
self.request.write(
resource.ErrorPage(
http.INTERNAL_SERVER_ERROR,
"CGI Script Error",
"Premature end of script headers.",
).render(self.request)
)
if not self._requestFinished:
self.request.unregisterProducer()
self.request.finish()
def _finished(self, ignored):
"""
Record the end of the response generation for the request being
serviced.
"""
self._requestFinished = True

View File

@@ -0,0 +1,36 @@
# -*- test-case-name: twisted.web.test.test_util -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An assortment of web server-related utilities.
"""
__all__ = [
"redirectTo",
"Redirect",
"ParentRedirect",
"DeferredResource",
"FailureElement",
"formatFailure",
# publicized by unit tests:
"_FrameElement",
"_SourceFragmentElement",
"_SourceLineElement",
"_StackElement",
"_PRE",
]
from ._template_util import (
_PRE,
DeferredResource,
FailureElement,
ParentRedirect,
Redirect,
_FrameElement,
_SourceFragmentElement,
_SourceLineElement,
_StackElement,
formatFailure,
redirectTo,
)

View File

@@ -0,0 +1,137 @@
# -*- test-case-name: twisted.web.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
I am a virtual hosts implementation.
"""
# Twisted Imports
from twisted.python import roots
from twisted.web import pages, resource
class VirtualHostCollection(roots.Homogenous):
"""Wrapper for virtual hosts collection.
This exists for configuration purposes.
"""
entityType = resource.Resource
def __init__(self, nvh):
self.nvh = nvh
def listStaticEntities(self):
return self.nvh.hosts.items()
def getStaticEntity(self, name):
return self.nvh.hosts.get(self)
def reallyPutEntity(self, name, entity):
self.nvh.addHost(name, entity)
def delEntity(self, name):
self.nvh.removeHost(name)
class NameVirtualHost(resource.Resource):
"""I am a resource which represents named virtual hosts."""
default = None
def __init__(self):
"""Initialize."""
resource.Resource.__init__(self)
self.hosts = {}
def listStaticEntities(self):
return resource.Resource.listStaticEntities(self) + [
("Virtual Hosts", VirtualHostCollection(self))
]
def getStaticEntity(self, name):
if name == "Virtual Hosts":
return VirtualHostCollection(self)
else:
return resource.Resource.getStaticEntity(self, name)
def addHost(self, name, resrc):
"""Add a host to this virtual host.
This will take a host named `name', and map it to a resource
`resrc'. For example, a setup for our virtual hosts would be::
nvh.addHost('divunal.com', divunalDirectory)
nvh.addHost('www.divunal.com', divunalDirectory)
nvh.addHost('twistedmatrix.com', twistedMatrixDirectory)
nvh.addHost('www.twistedmatrix.com', twistedMatrixDirectory)
"""
self.hosts[name] = resrc
def removeHost(self, name):
"""Remove a host."""
del self.hosts[name]
def _getResourceForRequest(self, request):
"""(Internal) Get the appropriate resource for the given host."""
hostHeader = request.getHeader(b"host")
if hostHeader is None:
return self.default or pages.notFound()
else:
host = hostHeader.lower().split(b":", 1)[0]
return self.hosts.get(host, self.default) or pages.notFound(
"Not Found",
f"host {host.decode('ascii', 'replace')!r} not in vhost map",
)
def render(self, request):
"""Implementation of resource.Resource's render method."""
resrc = self._getResourceForRequest(request)
return resrc.render(request)
def getChild(self, path, request):
"""Implementation of resource.Resource's getChild method."""
resrc = self._getResourceForRequest(request)
if resrc.isLeaf:
request.postpath.insert(0, request.prepath.pop(-1))
return resrc
else:
return resrc.getChildWithDefault(path, request)
class _HostResource(resource.Resource):
def getChild(self, path, request):
if b":" in path:
host, port = path.split(b":", 1)
port = int(port)
else:
host, port = path, 80
request.setHost(host, port)
prefixLen = 3 + request.isSecure() + 4 + len(path) + len(request.prepath[-3])
request.path = b"/" + b"/".join(request.postpath)
request.uri = request.uri[prefixLen:]
del request.prepath[:3]
return request.site.getResourceFor(request)
class VHostMonsterResource(resource.Resource):
"""
Use this to be able to record the hostname and method (http vs. https)
in the URL without disturbing your web site. If you put this resource
in a URL http://foo.com/bar then requests to
http://foo.com/bar/http/baz.com/something will be equivalent to
http://foo.com/something, except that the hostname the request will
appear to be accessing will be "baz.com". So if "baz.com" is redirecting
all requests for to foo.com, while foo.com is inaccessible from the outside,
then redirect and url generation will work correctly
"""
def getChild(self, path, request):
if path == b"http":
request.isSecure = lambda: 0
elif path == b"https":
request.isSecure = lambda: 1
return _HostResource()

View File

@@ -0,0 +1,555 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An implementation of
U{Python Web Server Gateway Interface v1.0.1<http://www.python.org/dev/peps/pep-3333/>}.
"""
from collections.abc import Sequence
from sys import exc_info
from typing import List, Union
from warnings import warn
from zope.interface import implementer
from twisted.internet.threads import blockingCallFromThread
from twisted.logger import Logger
from twisted.python.failure import Failure
from twisted.web.http import INTERNAL_SERVER_ERROR
from twisted.web.resource import IResource
from twisted.web.server import NOT_DONE_YET
# PEP-3333 -- which has superseded PEP-333 -- states that text strings MUST
# be represented using the platform's native string type, limited to
# characters defined in ISO-8859-1. Byte strings are used only for values
# read from wsgi.input, passed to write() or yielded by the application.
#
# Put another way:
#
# - All text strings are of type str, and all binary data are of
# type bytes. Text MUST always be limited to that which can be encoded as
# ISO-8859-1, U+0000 to U+00FF inclusive.
#
# The following pair of functions -- _wsgiString() and _wsgiStringToBytes() --
# are used to make Twisted's WSGI support compliant with the standard.
def _wsgiString(string: Union[str, bytes]) -> str:
"""
Convert C{string} to a WSGI "bytes-as-unicode" string.
If it's a byte string, decode as ISO-8859-1. If it's a Unicode string,
round-trip it to bytes and back using ISO-8859-1 as the encoding.
@type string: C{str} or C{bytes}
@rtype: C{str}
@raise UnicodeEncodeError: If C{string} contains non-ISO-8859-1 chars.
"""
if isinstance(string, str):
return string.encode("iso-8859-1").decode("iso-8859-1")
else:
return string.decode("iso-8859-1")
def _wsgiStringToBytes(string: str) -> bytes:
"""
Convert C{string} from a WSGI "bytes-as-unicode" string to an
ISO-8859-1 byte string.
@type string: C{str}
@rtype: C{bytes}
@raise UnicodeEncodeError: If C{string} contains non-ISO-8859-1 chars.
"""
return string.encode("iso-8859-1")
class _ErrorStream:
"""
File-like object instances of which are used as the value for the
C{'wsgi.errors'} key in the C{environ} dictionary passed to the application
object.
This simply passes writes on to L{logging<twisted.logger>} system as
error events from the C{'wsgi'} system. In the future, it may be desirable
to expose more information in the events it logs, such as the application
object which generated the message.
"""
_log = Logger()
def write(self, data: str) -> None:
"""
Generate an event for the logging system with the given bytes as the
message.
This is called in a WSGI application thread, not the I/O thread.
@type data: str
@raise TypeError: if C{data} is not a native string.
"""
if not isinstance(data, str):
raise TypeError(
"write() argument must be str, not %r (%s)"
% (data, type(data).__name__)
)
# Note that in old style, message was a tuple. logger._legacy
# will overwrite this value if it is not properly formatted here.
self._log.error(data, system="wsgi", isError=True, message=(data,))
def writelines(self, iovec: List[str]) -> None:
"""
Join the given lines and pass them to C{write} to be handled in the
usual way.
This is called in a WSGI application thread, not the I/O thread.
@param iovec: A C{list} of C{'\\n'}-terminated C{str} which will be
logged.
@raise TypeError: if C{iovec} contains any non-native strings.
"""
self.write("".join(iovec))
def flush(self):
"""
Nothing is buffered, so flushing does nothing. This method is required
to exist by PEP 333, though.
This is called in a WSGI application thread, not the I/O thread.
"""
class _InputStream:
"""
File-like object instances of which are used as the value for the
C{'wsgi.input'} key in the C{environ} dictionary passed to the application
object.
This only exists to make the handling of C{readline(-1)} consistent across
different possible underlying file-like object implementations. The other
supported methods pass through directly to the wrapped object.
"""
def __init__(self, input):
"""
Initialize the instance.
This is called in the I/O thread, not a WSGI application thread.
"""
self._wrapped = input
def read(self, size=None):
"""
Pass through to the underlying C{read}.
This is called in a WSGI application thread, not the I/O thread.
"""
# Avoid passing None because cStringIO and file don't like it.
if size is None:
return self._wrapped.read()
return self._wrapped.read(size)
def readline(self, size=None):
"""
Pass through to the underlying C{readline}, with a size of C{-1} replaced
with a size of L{None}.
This is called in a WSGI application thread, not the I/O thread.
"""
# Check for -1 because StringIO doesn't handle it correctly. Check for
# None because files and tempfiles don't accept that.
if size == -1 or size is None:
return self._wrapped.readline()
return self._wrapped.readline(size)
def readlines(self, size=None):
"""
Pass through to the underlying C{readlines}.
This is called in a WSGI application thread, not the I/O thread.
"""
# Avoid passing None because cStringIO and file don't like it.
if size is None:
return self._wrapped.readlines()
return self._wrapped.readlines(size)
def __iter__(self):
"""
Pass through to the underlying C{__iter__}.
This is called in a WSGI application thread, not the I/O thread.
"""
return iter(self._wrapped)
class _WSGIResponse:
"""
Helper for L{WSGIResource} which drives the WSGI application using a
threadpool and hooks it up to the L{http.Request}.
@ivar started: A L{bool} indicating whether or not the response status and
headers have been written to the request yet. This may only be read or
written in the WSGI application thread.
@ivar reactor: An L{IReactorThreads} provider which is used to call methods
on the request in the I/O thread.
@ivar threadpool: A L{ThreadPool} which is used to call the WSGI
application object in a non-I/O thread.
@ivar application: The WSGI application object.
@ivar request: The L{http.Request} upon which the WSGI environment is
based and to which the application's output will be sent.
@ivar environ: The WSGI environment L{dict}.
@ivar status: The HTTP response status L{str} supplied to the WSGI
I{start_response} callable by the application.
@ivar headers: A list of HTTP response headers supplied to the WSGI
I{start_response} callable by the application.
@ivar _requestFinished: A flag which indicates whether it is possible to
generate more response data or not. This is L{False} until
L{http.Request.notifyFinish} tells us the request is done,
then L{True}.
"""
_requestFinished = False
_log = Logger()
def __init__(self, reactor, threadpool, application, request):
self.started = False
self.reactor = reactor
self.threadpool = threadpool
self.application = application
self.request = request
self.request.notifyFinish().addBoth(self._finished)
if request.prepath:
scriptName = b"/" + b"/".join(request.prepath)
else:
scriptName = b""
if request.postpath:
pathInfo = b"/" + b"/".join(request.postpath)
else:
pathInfo = b""
parts = request.uri.split(b"?", 1)
if len(parts) == 1:
queryString = b""
else:
queryString = parts[1]
# All keys and values need to be native strings, i.e. of type str in
# *both* Python 2 and Python 3, so says PEP-3333.
remotePeer = request.getClientAddress()
self***REMOVED***iron = {
"REQUEST_METHOD": _wsgiString(request.method),
"REMOTE_ADDR": _wsgiString(remotePeer.host),
"REMOTE_PORT": _wsgiString(str(remotePeer.port)),
"SCRIPT_NAME": _wsgiString(scriptName),
"PATH_INFO": _wsgiString(pathInfo),
"QUERY_STRING": _wsgiString(queryString),
"CONTENT_TYPE": _wsgiString(request.getHeader(b"content-type") or ""),
"CONTENT_LENGTH": _wsgiString(request.getHeader(b"content-length") or ""),
"SERVER_NAME": _wsgiString(request.getRequestHostname()),
"SERVER_PORT": _wsgiString(str(request.getHost().port)),
"SERVER_PROTOCOL": _wsgiString(request.clientproto),
}
# The application object is entirely in control of response headers;
# disable the default Content-Type value normally provided by
# twisted.web.server.Request.
self.request.defaultContentType = None
for name, values in request.requestHeaders.getAllRawHeaders():
name = "HTTP_" + _wsgiString(name).upper().replace("-", "_")
# It might be preferable for http.HTTPChannel to clear out
# newlines.
self***REMOVED***iron[name] = ",".join(_wsgiString(v) for v in values).replace(
"\n", " "
)
self***REMOVED***iron.update(
{
"wsgi.version": (1, 0),
"wsgi.url_scheme": request.isSecure() and "https" or "http",
"wsgi.run_once": False,
"wsgi.multithread": True,
"wsgi.multiprocess": False,
"wsgi.errors": _ErrorStream(),
# Attend: request.content was owned by the I/O thread up until
# this point. By wrapping it and putting the result into the
# environment dictionary, it is effectively being given to
# another thread. This means that whatever it is, it has to be
# safe to access it from two different threads. The access
# *should* all be serialized (first the I/O thread writes to
# it, then the WSGI thread reads from it, then the I/O thread
# closes it). However, since the request is made available to
# arbitrary application code during resource traversal, it's
# possible that some other code might decide to use it in the
# I/O thread concurrently with its use in the WSGI thread.
# More likely than not, this will break. This seems like an
# unlikely possibility to me, but if it is to be allowed,
# something here needs to change. -exarkun
"wsgi.input": _InputStream(request.content),
}
)
def _finished(self, ignored):
"""
Record the end of the response generation for the request being
serviced.
"""
self._requestFinished = True
def startResponse(self, status, headers, excInfo=None):
"""
The WSGI I{start_response} callable. The given values are saved until
they are needed to generate the response.
This will be called in a non-I/O thread.
"""
if self.started and excInfo is not None:
raise excInfo[1].with_traceback(excInfo[2])
# PEP-3333 mandates that status should be a native string. In practice
# this is mandated by Twisted's HTTP implementation too.
if not isinstance(status, str):
raise TypeError(
"status must be str, not {!r} ({})".format(
status, type(status).__name__
)
)
# PEP-3333 mandates that headers should be a plain list, but in
# practice we work with any sequence type and only warn when it's not
# a plain list.
if isinstance(headers, list):
pass # This is okay.
elif isinstance(headers, Sequence):
warn(
"headers should be a list, not %r (%s)"
% (headers, type(headers).__name__),
category=RuntimeWarning,
)
else:
raise TypeError(
"headers must be a list, not %r (%s)"
% (headers, type(headers).__name__)
)
# PEP-3333 mandates that each header should be a (str, str) tuple, but
# in practice we work with any sequence type and only warn when it's
# not a plain list.
for header in headers:
if isinstance(header, tuple):
pass # This is okay.
elif isinstance(header, Sequence):
warn(
"header should be a (str, str) tuple, not %r (%s)"
% (header, type(header).__name__),
category=RuntimeWarning,
)
else:
raise TypeError(
"header must be a (str, str) tuple, not %r (%s)"
% (header, type(header).__name__)
)
# However, the sequence MUST contain only 2 elements.
if len(header) != 2:
raise TypeError(f"header must be a (str, str) tuple, not {header!r}")
# Both elements MUST be native strings. Non-native strings will be
# rejected by the underlying HTTP machinery in any case, but we
# reject them here in order to provide a more informative error.
for elem in header:
if not isinstance(elem, str):
raise TypeError(f"header must be (str, str) tuple, not {header!r}")
self.status = status
self.headers = headers
return self.write
def write(self, data):
"""
The WSGI I{write} callable returned by the I{start_response} callable.
The given bytes will be written to the response body, possibly flushing
the status and headers first.
This will be called in a non-I/O thread.
"""
# PEP-3333 states:
#
# The server or gateway must transmit the yielded bytestrings to the
# client in an unbuffered fashion, completing the transmission of
# each bytestring before requesting another one.
#
# This write() method is used for the imperative and (indirectly) for
# the more familiar iterable-of-bytestrings WSGI mechanism. It uses
# C{blockingCallFromThread} to schedule writes. This allows exceptions
# to propagate up from the underlying HTTP implementation. However,
# that underlying implementation does not, as yet, provide any way to
# know if the written data has been transmitted, so this method
# violates the above part of PEP-3333.
#
# PEP-3333 also says that a server may:
#
# Use a different thread to ensure that the block continues to be
# transmitted while the application produces the next block.
#
# Which suggests that this is actually compliant with PEP-3333,
# because writes are done in the reactor thread.
#
# However, providing some back-pressure may nevertheless be a Good
# Thing at some point in the future.
def wsgiWrite(started):
if not started:
self._sendResponseHeaders()
self.request.write(data)
try:
return blockingCallFromThread(self.reactor, wsgiWrite, self.started)
finally:
self.started = True
def _sendResponseHeaders(self):
"""
Set the response code and response headers on the request object, but
do not flush them. The caller is responsible for doing a write in
order for anything to actually be written out in response to the
request.
This must be called in the I/O thread.
"""
code, message = self.status.split(None, 1)
code = int(code)
self.request.setResponseCode(code, _wsgiStringToBytes(message))
for name, value in self.headers:
# Don't allow the application to control these required headers.
if name.lower() not in ("server", "date"):
self.request.responseHeaders.addRawHeader(
_wsgiStringToBytes(name), _wsgiStringToBytes(value)
)
def start(self):
"""
Start the WSGI application in the threadpool.
This must be called in the I/O thread.
"""
self.threadpool.callInThread(self.run)
def run(self):
"""
Call the WSGI application object, iterate it, and handle its output.
This must be called in a non-I/O thread (ie, a WSGI application
thread).
"""
try:
appIterator = self.application(self***REMOVED***iron, self.startResponse)
for elem in appIterator:
if elem:
self.write(elem)
if self._requestFinished:
break
close = getattr(appIterator, "close", None)
if close is not None:
close()
except BaseException:
def wsgiError(started, type, value, traceback):
self._log.failure(
"WSGI application error", failure=Failure(value, type, traceback)
)
if started:
self.request.loseConnection()
else:
self.request.setResponseCode(INTERNAL_SERVER_ERROR)
self.request.finish()
self.reactor.callFromThread(wsgiError, self.started, *exc_info())
else:
def wsgiFinish(started):
if not self._requestFinished:
if not started:
self._sendResponseHeaders()
self.request.finish()
self.reactor.callFromThread(wsgiFinish, self.started)
self.started = True
@implementer(IResource)
class WSGIResource:
"""
An L{IResource} implementation which delegates responsibility for all
resources hierarchically inferior to it to a WSGI application.
The C{environ} argument passed to the application, includes the
C{REMOTE_PORT} key to complement the C{REMOTE_ADDR} key.
@ivar _reactor: An L{IReactorThreads} provider which will be passed on to
L{_WSGIResponse} to schedule calls in the I/O thread.
@ivar _threadpool: A L{ThreadPool} which will be passed on to
L{_WSGIResponse} to run the WSGI application object.
@ivar _application: The WSGI application object.
"""
# Further resource segments are left up to the WSGI application object to
# handle.
isLeaf = True
def __init__(self, reactor, threadpool, application):
self._reactor = reactor
self._threadpool = threadpool
self._application = application
def render(self, request):
"""
Turn the request into the appropriate C{environ} C{dict} suitable to be
passed to the WSGI application object and then pass it on.
The WSGI application object is given almost complete control of the
rendering process. C{NOT_DONE_YET} will always be returned in order
and response completion will be dictated by the application object, as
will the status, headers, and the response body.
"""
response = _WSGIResponse(
self._reactor, self._threadpool, self._application, request
)
response.start()
return NOT_DONE_YET
def getChildWithDefault(self, name, request):
"""
Reject attempts to retrieve a child resource. All path segments beyond
the one which refers to this resource are handled by the WSGI
application object.
"""
raise RuntimeError("Cannot get IResource children from WSGIResource")
def putChild(self, path, child):
"""
Reject attempts to add a child resource to this resource. The WSGI
application object handles all path segments beneath this resource, so
L{IResource} children can never be found.
"""
raise RuntimeError("Cannot put IResource children under WSGIResource")
__all__ = ["WSGIResource"]

View File

@@ -0,0 +1,633 @@
# -*- test-case-name: twisted.web.test.test_xmlrpc -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A generic resource for publishing objects via XML-RPC.
Maintainer: Itamar Shtull-Trauring
@var Fault: See L{xmlrpclib.Fault}
@type Fault: L{xmlrpclib.Fault}
"""
# System Imports
import base64
import xmlrpc.client as xmlrpclib
from urllib.parse import urlparse
from xmlrpc.client import Binary, Boolean, DateTime, Fault
from twisted.internet import defer, error, protocol
from twisted.logger import Logger
from twisted.python import failure, reflect
from twisted.python.compat import nativeString
# Sibling Imports
from twisted.web import http, resource, server
# These are deprecated, use the class level definitions
NOT_FOUND = 8001
FAILURE = 8002
def withRequest(f):
"""
Decorator to cause the request to be passed as the first argument
to the method.
If an I{xmlrpc_} method is wrapped with C{withRequest}, the
request object is passed as the first argument to that method.
For example::
@withRequest
def xmlrpc_echo(self, request, s):
return s
@since: 10.2
"""
f.withRequest = True
return f
class NoSuchFunction(Fault):
"""
There is no function by the given name.
"""
class Handler:
"""
Handle a XML-RPC request and store the state for a request in progress.
Override the run() method and return result using self.result,
a Deferred.
We require this class since we're not using threads, so we can't
encapsulate state in a running function if we're going to have
to wait for results.
For example, lets say we want to authenticate against twisted.cred,
run a LDAP query and then pass its result to a database query, all
as a result of a single XML-RPC command. We'd use a Handler instance
to store the state of the running command.
"""
def __init__(self, resource, *args):
self.resource = resource # the XML-RPC resource we are connected to
self.result = defer.Deferred()
self.run(*args)
def run(self, *args):
# event driven equivalent of 'raise UnimplementedError'
self.result.errback(NotImplementedError("Implement run() in subclasses"))
class XMLRPC(resource.Resource):
"""
A resource that implements XML-RPC.
You probably want to connect this to '/RPC2'.
Methods published can return XML-RPC serializable results, Faults,
Binary, Boolean, DateTime, Deferreds, or Handler instances.
By default methods beginning with 'xmlrpc_' are published.
Sub-handlers for prefixed methods (e.g., system.listMethods)
can be added with putSubHandler. By default, prefixes are
separated with a '.'. Override self.separator to change this.
@ivar allowNone: Permit XML translating of Python constant None.
@type allowNone: C{bool}
@ivar useDateTime: Present C{datetime} values as C{datetime.datetime}
objects?
@type useDateTime: C{bool}
"""
# Error codes for Twisted, if they conflict with yours then
# modify them at runtime.
NOT_FOUND = 8001
FAILURE = 8002
isLeaf = 1
separator = "."
allowedMethods = (b"POST",)
_log = Logger()
def __init__(self, allowNone=False, useDateTime=False):
resource.Resource.__init__(self)
self.subHandlers = {}
self.allowNone = allowNone
self.useDateTime = useDateTime
def __setattr__(self, name, value):
self.__dict__[name] = value
def putSubHandler(self, prefix, handler):
self.subHandlers[prefix] = handler
def getSubHandler(self, prefix):
return self.subHandlers.get(prefix, None)
def getSubHandlerPrefixes(self):
return list(self.subHandlers.keys())
def render_POST(self, request):
request.content.seek(0, 0)
request.setHeader(b"content-type", b"text/xml; charset=utf-8")
try:
args, functionPath = xmlrpclib.loads(
request.content.read(), use_datetime=self.useDateTime
)
except Exception as e:
f = Fault(self.FAILURE, f"Can't deserialize input: {e}")
self._cbRender(f, request)
else:
try:
function = self.lookupProcedure(functionPath)
except Fault as f:
self._cbRender(f, request)
else:
# Use this list to track whether the response has failed or not.
# This will be used later on to decide if the result of the
# Deferred should be written out and Request.finish called.
responseFailed = []
request.notifyFinish().addErrback(responseFailed.append)
if getattr(function, "withRequest", False):
d = defer.maybeDeferred(function, request, *args)
else:
d = defer.maybeDeferred(function, *args)
d.addErrback(self._ebRender)
d.addCallback(self._cbRender, request, responseFailed)
return server.NOT_DONE_YET
def _cbRender(self, result, request, responseFailed=None):
if responseFailed:
return
if isinstance(result, Handler):
result = result.result
if not isinstance(result, Fault):
result = (result,)
try:
try:
content = xmlrpclib.dumps(
result, methodresponse=True, allow_none=self.allowNone
)
except Exception as e:
f = Fault(self.FAILURE, f"Can't serialize output: {e}")
content = xmlrpclib.dumps(
f, methodresponse=True, allow_none=self.allowNone
)
if isinstance(content, str):
content = content.encode("utf8")
request.setHeader(b"content-length", b"%d" % (len(content),))
request.write(content)
except Exception:
self._log.failure("")
request.finish()
def _ebRender(self, failure):
if isinstance(failure.value, Fault):
return failure.value
self._log.failure("", failure)
return Fault(self.FAILURE, "error")
def lookupProcedure(self, procedurePath):
"""
Given a string naming a procedure, return a callable object for that
procedure or raise NoSuchFunction.
The returned object will be called, and should return the result of the
procedure, a Deferred, or a Fault instance.
Override in subclasses if you want your own policy. The base
implementation that given C{'foo'}, C{self.xmlrpc_foo} will be returned.
If C{procedurePath} contains C{self.separator}, the sub-handler for the
initial prefix is used to search for the remaining path.
If you override C{lookupProcedure}, you may also want to override
C{listProcedures} to accurately report the procedures supported by your
resource, so that clients using the I{system.listMethods} procedure
receive accurate results.
@since: 11.1
"""
if procedurePath.find(self.separator) != -1:
prefix, procedurePath = procedurePath.split(self.separator, 1)
handler = self.getSubHandler(prefix)
if handler is None:
raise NoSuchFunction(self.NOT_FOUND, "no such subHandler %s" % prefix)
return handler.lookupProcedure(procedurePath)
f = getattr(self, "xmlrpc_%s" % procedurePath, None)
if not f:
raise NoSuchFunction(
self.NOT_FOUND, "procedure %s not found" % procedurePath
)
elif not callable(f):
raise NoSuchFunction(
self.NOT_FOUND, "procedure %s not callable" % procedurePath
)
else:
return f
def listProcedures(self):
"""
Return a list of the names of all xmlrpc procedures.
@since: 11.1
"""
return reflect.prefixedMethodNames(self.__class__, "xmlrpc_")
class XMLRPCIntrospection(XMLRPC):
"""
Implement the XML-RPC Introspection API.
By default, the methodHelp method returns the 'help' method attribute,
if it exists, otherwise the __doc__ method attribute, if it exists,
otherwise the empty string.
To enable the methodSignature method, add a 'signature' method attribute
containing a list of lists. See methodSignature's documentation for the
format. Note the type strings should be XML-RPC types, not Python types.
"""
def __init__(self, parent):
"""
Implement Introspection support for an XMLRPC server.
@param parent: the XMLRPC server to add Introspection support to.
@type parent: L{XMLRPC}
"""
XMLRPC.__init__(self)
self._xmlrpc_parent = parent
def xmlrpc_listMethods(self):
"""
Return a list of the method names implemented by this server.
"""
functions = []
todo = [(self._xmlrpc_parent, "")]
while todo:
obj, prefix = todo.pop(0)
functions.extend([prefix + name for name in obj.listProcedures()])
todo.extend(
[
(obj.getSubHandler(name), prefix + name + obj.separator)
for name in obj.getSubHandlerPrefixes()
]
)
return functions
xmlrpc_listMethods.signature = [["array"]] # type: ignore[attr-defined]
def xmlrpc_methodHelp(self, method):
"""
Return a documentation string describing the use of the given method.
"""
method = self._xmlrpc_parent.lookupProcedure(method)
return getattr(method, "help", None) or getattr(method, "__doc__", None) or ""
xmlrpc_methodHelp.signature = [["string", "string"]] # type: ignore[attr-defined]
def xmlrpc_methodSignature(self, method):
"""
Return a list of type signatures.
Each type signature is a list of the form [rtype, type1, type2, ...]
where rtype is the return type and typeN is the type of the Nth
argument. If no signature information is available, the empty
string is returned.
"""
method = self._xmlrpc_parent.lookupProcedure(method)
return getattr(method, "signature", None) or ""
xmlrpc_methodSignature.signature = [ # type: ignore[attr-defined]
["array", "string"],
["string", "string"],
]
def addIntrospection(xmlrpc):
"""
Add Introspection support to an XMLRPC server.
@param xmlrpc: the XMLRPC server to add Introspection support to.
@type xmlrpc: L{XMLRPC}
"""
xmlrpc.putSubHandler("system", XMLRPCIntrospection(xmlrpc))
class QueryProtocol(http.HTTPClient):
def connectionMade(self):
self._response = None
self.sendCommand(b"POST", self.factory.path)
self.sendHeader(b"User-Agent", b"Twisted/XMLRPClib")
self.sendHeader(b"Host", self.factory.host)
self.sendHeader(b"Content-type", b"text/xml; charset=utf-8")
payload = self.factory.payload
self.sendHeader(b"Content-length", b"%d" % (len(payload),))
if self.factory.user:
auth = b":".join([self.factory.user, self.factory.password])
authHeader = b"".join([b"Basic ", base64.b64encode(auth)])
self.sendHeader(b"Authorization", authHeader)
self.endHeaders()
self.transport.write(payload)
def handleStatus(self, version, status, message):
if status != b"200":
self.factory.badStatus(status, message)
def handleResponse(self, contents):
"""
Handle the XML-RPC response received from the server.
Specifically, disconnect from the server and store the XML-RPC
response so that it can be properly handled when the disconnect is
finished.
"""
self.transport.loseConnection()
self._response = contents
def connectionLost(self, reason):
"""
The connection to the server has been lost.
If we have a full response from the server, then parse it and fired a
Deferred with the return value or C{Fault} that the server gave us.
"""
if not reason.check(error.ConnectionDone, error.ConnectionLost):
# for example, ssl.SSL.Error
self.factory.clientConnectionLost(None, reason)
http.HTTPClient.connectionLost(self, reason)
if self._response is not None:
response, self._response = self._response, None
self.factory.parseResponse(response)
payloadTemplate = """<?xml version="1.0"?>
<methodCall>
<methodName>%s</methodName>
%s
</methodCall>
"""
class QueryFactory(protocol.ClientFactory):
"""
XML-RPC Client Factory
@ivar path: The path portion of the URL to which to post method calls.
@type path: L{bytes}
@ivar host: The value to use for the Host HTTP header.
@type host: L{bytes}
@ivar user: The username with which to authenticate with the server
when making calls.
@type user: L{bytes} or L{None}
@ivar password: The password with which to authenticate with the server
when making calls.
@type password: L{bytes} or L{None}
@ivar useDateTime: Accept datetime values as datetime.datetime objects.
also passed to the underlying xmlrpclib implementation. Defaults to
C{False}.
@type useDateTime: C{bool}
"""
deferred = None
protocol = QueryProtocol
def __init__(
self,
path,
host,
method,
user=None,
password=None,
allowNone=False,
args=(),
canceller=None,
useDateTime=False,
):
"""
@param method: The name of the method to call.
@type method: C{str}
@param allowNone: allow the use of None values in parameters. It's
passed to the underlying xmlrpclib implementation. Defaults to
C{False}.
@type allowNone: C{bool} or L{None}
@param args: the arguments to pass to the method.
@type args: C{tuple}
@param canceller: A 1-argument callable passed to the deferred as the
canceller callback.
@type canceller: callable or L{None}
"""
self.path, self.host = path, host
self.user, self.password = user, password
self.payload = payloadTemplate % (
method,
xmlrpclib.dumps(args, allow_none=allowNone),
)
if isinstance(self.payload, str):
self.payload = self.payload.encode("utf8")
self.deferred = defer.Deferred(canceller)
self.useDateTime = useDateTime
def parseResponse(self, contents):
if not self.deferred:
return
try:
response = xmlrpclib.loads(contents, use_datetime=self.useDateTime)[0][0]
except BaseException:
deferred, self.deferred = self.deferred, None
deferred.errback(failure.Failure())
else:
deferred, self.deferred = self.deferred, None
deferred.callback(response)
def clientConnectionLost(self, _, reason):
if self.deferred is not None:
deferred, self.deferred = self.deferred, None
deferred.errback(reason)
clientConnectionFailed = clientConnectionLost
def badStatus(self, status, message):
deferred, self.deferred = self.deferred, None
deferred.errback(ValueError(status, message))
class Proxy:
"""
A Proxy for making remote XML-RPC calls.
Pass the URL of the remote XML-RPC server to the constructor.
Use C{proxy.callRemote('foobar', *args)} to call remote method
'foobar' with *args.
@ivar user: The username with which to authenticate with the server
when making calls. If specified, overrides any username information
embedded in C{url}. If not specified, a value may be taken from
C{url} if present.
@type user: L{bytes} or L{None}
@ivar password: The password with which to authenticate with the server
when making calls. If specified, overrides any password information
embedded in C{url}. If not specified, a value may be taken from
C{url} if present.
@type password: L{bytes} or L{None}
@ivar allowNone: allow the use of None values in parameters. It's
passed to the underlying L{xmlrpclib} implementation. Defaults to
C{False}.
@type allowNone: C{bool} or L{None}
@ivar useDateTime: Accept datetime values as datetime.datetime objects.
also passed to the underlying L{xmlrpclib} implementation. Defaults to
C{False}.
@type useDateTime: C{bool}
@ivar connectTimeout: Number of seconds to wait before assuming the
connection has failed.
@type connectTimeout: C{float}
@ivar _reactor: The reactor used to create connections.
@type _reactor: Object providing L{twisted.internet.interfaces.IReactorTCP}
@ivar queryFactory: Object returning a factory for XML-RPC protocol. Use
this for testing, or to manipulate the XML-RPC parsing behavior. For
example, you may set this to a custom "debugging" factory object that
reimplements C{parseResponse} in order to log the raw XML-RPC contents
from the server before continuing on with parsing. Another possibility
is to implement your own XML-RPC marshaller here to handle non-standard
XML-RPC traffic.
@type queryFactory: L{twisted.web.xmlrpc.QueryFactory}
"""
queryFactory = QueryFactory
def __init__(
self,
url,
user=None,
password=None,
allowNone=False,
useDateTime=False,
connectTimeout=30.0,
reactor=None,
):
"""
@param url: The URL to which to post method calls. Calls will be made
over SSL if the scheme is HTTPS. If netloc contains username or
password information, these will be used to authenticate, as long as
the C{user} and C{password} arguments are not specified.
@type url: L{bytes}
"""
if reactor is None:
from twisted.internet import reactor
scheme, netloc, path, params, query, fragment = urlparse(url)
netlocParts = netloc.split(b"@")
if len(netlocParts) == 2:
userpass = netlocParts.pop(0).split(b":")
self.user = userpass.pop(0)
try:
self.[PASSWORD-REMOVED].pop(0)
except BaseException:
self.password = None
else:
self.user = self.password = None
hostport = netlocParts[0].split(b":")
self.host = hostport.pop(0)
try:
self.port = int(hostport.pop(0))
except BaseException:
self.port = None
self.path = path
if self.path in [b"", None]:
self.path = b"/"
self.secure = scheme == b"https"
if user is not None:
self.user = user
if password is not None:
self.[PASSWORD-REMOVED]
self.allowNone = allowNone
self.useDateTime = useDateTime
self.connectTimeout = connectTimeout
self._reactor = reactor
def callRemote(self, method, *args):
"""
Call remote XML-RPC C{method} with given arguments.
@return: a L{defer.Deferred} that will fire with the method response,
or a failure if the method failed. Generally, the failure type will
be L{Fault}, but you can also have an C{IndexError} on some buggy
servers giving empty responses.
If the deferred is cancelled before the request completes, the
connection is closed and the deferred will fire with a
L{defer.CancelledError}.
"""
def cancel(d):
factory.deferred = None
connector.disconnect()
factory = self.queryFactory(
self.path,
self.host,
method,
self.user,
self.password,
self.allowNone,
args,
cancel,
self.useDateTime,
)
if self.secure:
from twisted.internet import ssl
contextFactory = ssl.optionsForClientTLS(hostname=nativeString(self.host))
connector = self._reactor.connectSSL(
nativeString(self.host),
self.port or 443,
factory,
contextFactory,
timeout=self.connectTimeout,
)
else:
connector = self._reactor.connectTCP(
nativeString(self.host),
self.port or 80,
factory,
timeout=self.connectTimeout,
)
return factory.deferred
__all__ = [
"XMLRPC",
"Handler",
"NoSuchFunction",
"Proxy",
"Fault",
"Binary",
"Boolean",
"DateTime",
]