okay fine

This commit is contained in:
pacnpal
2024-11-03 17:47:26 +00:00
parent 01c6004a79
commit 27eb239e97
10020 changed files with 1935769 additions and 2364 deletions

View File

@@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Names: DNS server and client implementations.
"""

View File

@@ -0,0 +1,261 @@
# -*- test-case-name: twisted.names.test.test_rfc1982 -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Utilities for handling RFC1982 Serial Number Arithmetic.
@see: U{http://tools.ietf.org/html/rfc1982}
@var RFC4034_TIME_FORMAT: RRSIG Time field presentation format. The Signature
Expiration Time and Inception Time field values MUST be represented either
as an unsigned decimal integer indicating seconds since 1 January 1970
00:00:00 UTC, or in the form YYYYMMDDHHmmSS in UTC. See U{RRSIG Presentation
Format<https://tools.ietf.org/html/rfc4034#section-3.2>}
"""
import calendar
from datetime import datetime, timedelta
from twisted.python.compat import nativeString
from twisted.python.util import FancyStrMixin
RFC4034_TIME_FORMAT = "%Y%m%d%H%M%S"
class SerialNumber(FancyStrMixin):
"""
An RFC1982 Serial Number.
This class implements RFC1982 DNS Serial Number Arithmetic.
SNA is used in DNS and specifically in DNSSEC as defined in RFC4034 in the
DNSSEC Signature Expiration and Inception Fields.
@see: U{https://tools.ietf.org/html/rfc1982}
@see: U{https://tools.ietf.org/html/rfc4034}
@ivar _serialBits: See C{serialBits} of L{__init__}.
@ivar _number: See C{number} of L{__init__}.
@ivar _modulo: The value at which wrapping will occur.
@ivar _halfRing: Half C{_modulo}. If another L{SerialNumber} value is larger
than this, it would lead to a wrapped value which is larger than the
first and comparisons are therefore ambiguous.
@ivar _maxAdd: Half C{_modulo} plus 1. If another L{SerialNumber} value is
larger than this, it would lead to a wrapped value which is larger than
the first. Comparisons with the original value would therefore be
ambiguous.
"""
showAttributes = (
("_number", "number", "%d"),
("_serialBits", "serialBits", "%d"),
)
def __init__(self, number: int, serialBits: int = 32):
"""
Construct an L{SerialNumber} instance.
@param number: An L{int} which will be stored as the modulo
C{number % 2 ^ serialBits}
@type number: L{int}
@param serialBits: The size of the serial number space. The power of two
which results in one larger than the largest integer corresponding
to a serial number value.
@type serialBits: L{int}
"""
self._serialBits = serialBits
self._modulo = 2**serialBits
self._halfRing: int = 2 ** (serialBits - 1)
self._maxAdd = 2 ** (serialBits - 1) - 1
self._number: int = int(number) % self._modulo
def _convertOther(self, other: object) -> "SerialNumber":
"""
Check that a foreign object is suitable for use in the comparison or
arithmetic magic methods of this L{SerialNumber} instance. Raise
L{TypeError} if not.
@param other: The foreign L{object} to be checked.
@return: C{other} after compatibility checks and possible coercion.
@raise TypeError: If C{other} is not compatible.
"""
if not isinstance(other, SerialNumber):
raise TypeError(f"cannot compare or combine {self!r} and {other!r}")
if self._serialBits != other._serialBits:
raise TypeError(
"cannot compare or combine SerialNumber instances with "
"different serialBits. %r and %r" % (self, other)
)
return other
def __str__(self) -> str:
"""
Return a string representation of this L{SerialNumber} instance.
@rtype: L{nativeString}
"""
return nativeString("%d" % (self._number,))
def __int__(self):
"""
@return: The integer value of this L{SerialNumber} instance.
@rtype: L{int}
"""
return self._number
def __eq__(self, other: object) -> bool:
"""
Allow rich equality comparison with another L{SerialNumber} instance.
"""
try:
other = self._convertOther(other)
except TypeError:
return NotImplemented
return other._number == self._number
def __lt__(self, other: object) -> bool:
"""
Allow I{less than} comparison with another L{SerialNumber} instance.
"""
try:
other = self._convertOther(other)
except TypeError:
return NotImplemented
return (
self._number < other._number
and (other._number - self._number) < self._halfRing
) or (
self._number > other._number
and (self._number - other._number) > self._halfRing
)
def __gt__(self, other: object) -> bool:
"""
Allow I{greater than} comparison with another L{SerialNumber} instance.
"""
try:
other_sn = self._convertOther(other)
except TypeError:
return NotImplemented
return (
self._number < other_sn._number
and (other_sn._number - self._number) > self._halfRing
) or (
self._number > other_sn._number
and (self._number - other_sn._number) < self._halfRing
)
def __le__(self, other: object) -> bool:
"""
Allow I{less than or equal} comparison with another L{SerialNumber}
instance.
"""
try:
other = self._convertOther(other)
except TypeError:
return NotImplemented
return self == other or self < other
def __ge__(self, other: object) -> bool:
"""
Allow I{greater than or equal} comparison with another L{SerialNumber}
instance.
"""
try:
other = self._convertOther(other)
except TypeError:
return NotImplemented
return self == other or self > other
def __add__(self, other: object) -> "SerialNumber":
"""
Allow I{addition} with another L{SerialNumber} instance.
Serial numbers may be incremented by the addition of a positive
integer n, where n is taken from the range of integers
[0 .. (2^(SERIAL_BITS - 1) - 1)]. For a sequence number s, the
result of such an addition, s', is defined as
s' = (s + n) modulo (2 ^ SERIAL_BITS)
where the addition and modulus operations here act upon values that are
non-negative values of unbounded size in the usual ways of integer
arithmetic.
Addition of a value outside the range
[0 .. (2^(SERIAL_BITS - 1) - 1)] is undefined.
@see: U{http://tools.ietf.org/html/rfc1982#section-3.1}
@raise ArithmeticError: If C{other} is more than C{_maxAdd}
ie more than half the maximum value of this serial number.
"""
try:
other = self._convertOther(other)
except TypeError:
return NotImplemented
if other._number <= self._maxAdd:
return SerialNumber(
(self._number + other._number) % self._modulo,
serialBits=self._serialBits,
)
else:
raise ArithmeticError(
"value %r outside the range 0 .. %r"
% (
other._number,
self._maxAdd,
)
)
def __hash__(self):
"""
Allow L{SerialNumber} instances to be hashed for use as L{dict} keys.
@rtype: L{int}
"""
return hash(self._number)
@classmethod
def fromRFC4034DateString(cls, utcDateString):
"""
Create an L{SerialNumber} instance from a date string in format
'YYYYMMDDHHMMSS' described in U{RFC4034
3.2<https://tools.ietf.org/html/rfc4034#section-3.2>}.
The L{SerialNumber} instance stores the date as a 32bit UNIX timestamp.
@see: U{https://tools.ietf.org/html/rfc4034#section-3.1.5}
@param utcDateString: A UTC date/time string of format I{YYMMDDhhmmss}
which will be converted to seconds since the UNIX epoch.
@type utcDateString: L{unicode}
@return: An L{SerialNumber} instance containing the supplied date as a
32bit UNIX timestamp.
"""
parsedDate = datetime.strptime(utcDateString, RFC4034_TIME_FORMAT)
secondsSinceEpoch = calendar.timegm(parsedDate.utctimetuple())
return cls(secondsSinceEpoch, serialBits=32)
def toRFC4034DateString(self):
"""
Calculate a date by treating the current L{SerialNumber} value as a UNIX
timestamp and return a date string in the format described in
U{RFC4034 3.2<https://tools.ietf.org/html/rfc4034#section-3.2>}.
@return: The date string.
"""
# Can't use datetime.utcfromtimestamp, because it seems to overflow the
# signed 32bit int used in the underlying C library. SNA is unsigned
# and capable of handling all timestamps up to 2**32.
d = datetime(1970, 1, 1) + timedelta(seconds=self._number)
return nativeString(d.strftime(RFC4034_TIME_FORMAT))
__all__ = ["SerialNumber"]

View File

@@ -0,0 +1,503 @@
# -*- test-case-name: twisted.names.test.test_names -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Authoritative resolvers.
"""
import os
import time
from twisted.internet import defer
from twisted.names import common, dns, error
from twisted.python import failure
from twisted.python.compat import execfile, nativeString
from twisted.python.filepath import FilePath
def getSerial(filename="/tmp/twisted-names.serial"):
"""
Return a monotonically increasing (across program runs) integer.
State is stored in the given file. If it does not exist, it is
created with rw-/---/--- permissions.
This manipulates process-global state by calling C{os.umask()}, so it isn't
thread-safe.
@param filename: Path to a file that is used to store the state across
program runs.
@type filename: L{str}
@return: a monotonically increasing number
@rtype: L{str}
"""
serial = time.strftime("%Y%m%d")
o = os.umask(0o177)
try:
if not os.path.exists(filename):
with open(filename, "w") as f:
f.write(serial + " 0")
finally:
os.umask(o)
with open(filename) as serialFile:
lastSerial, zoneID = serialFile.readline().split()
zoneID = (lastSerial == serial) and (int(zoneID) + 1) or 0
with open(filename, "w") as serialFile:
serialFile.write("%s %d" % (serial, zoneID))
serial = serial + ("%02d" % (zoneID,))
return serial
class FileAuthority(common.ResolverBase):
"""
An Authority that is loaded from a file.
This is an abstract class that implements record search logic. To create
a functional resolver, subclass it and override the L{loadFile} method.
@ivar _ADDITIONAL_PROCESSING_TYPES: Record types for which additional
processing will be done.
@ivar _ADDRESS_TYPES: Record types which are useful for inclusion in the
additional section generated during additional processing.
@ivar soa: A 2-tuple containing the SOA domain name as a L{bytes} and a
L{dns.Record_SOA}.
@ivar records: A mapping of domains (as lowercased L{bytes}) to records.
@type records: L{dict} with L{bytes} keys
"""
# See https://twistedmatrix.com/trac/ticket/6650
_ADDITIONAL_PROCESSING_TYPES = (dns.CNAME, dns.MX, dns.NS)
_ADDRESS_TYPES = (dns.A, dns.AAAA)
soa = None
records = None
def __init__(self, filename):
common.ResolverBase.__init__(self)
self.loadFile(filename)
self._cache = {}
def __setstate__(self, state):
self.__dict__ = state
def loadFile(self, filename):
"""
Load DNS records from a file.
This method populates the I{soa} and I{records} attributes. It must be
overridden in a subclass. It is called once from the initializer.
@param filename: The I{filename} parameter that was passed to the
initilizer.
@returns: L{None} -- the return value is ignored
"""
def _additionalRecords(self, answer, authority, ttl):
"""
Find locally known information that could be useful to the consumer of
the response and construct appropriate records to include in the
I{additional} section of that response.
Essentially, implement RFC 1034 section 4.3.2 step 6.
@param answer: A L{list} of the records which will be included in the
I{answer} section of the response.
@param authority: A L{list} of the records which will be included in
the I{authority} section of the response.
@param ttl: The default TTL for records for which this is not otherwise
specified.
@return: A generator of L{dns.RRHeader} instances for inclusion in the
I{additional} section. These instances represent extra information
about the records in C{answer} and C{authority}.
"""
for record in answer + authority:
if record.type in self._ADDITIONAL_PROCESSING_TYPES:
name = record.payload.name.name
for rec in self.records.get(name.lower(), ()):
if rec.TYPE in self._ADDRESS_TYPES:
yield dns.RRHeader(
name, rec.TYPE, dns.IN, rec.ttl or ttl, rec, auth=True
)
def _lookup(self, name, cls, type, timeout=None):
"""
Determine a response to a particular DNS query.
@param name: The name which is being queried and for which to lookup a
response.
@type name: L{bytes}
@param cls: The class which is being queried. Only I{IN} is
implemented here and this value is presently disregarded.
@type cls: L{int}
@param type: The type of records being queried. See the types defined
in L{twisted.names.dns}.
@type type: L{int}
@param timeout: All processing is done locally and a result is
available immediately, so the timeout value is ignored.
@return: A L{Deferred} that fires with a L{tuple} of three sets of
response records (to comprise the I{answer}, I{authority}, and
I{additional} sections of a DNS response) or with a L{Failure} if
there is a problem processing the query.
"""
cnames = []
results = []
authority = []
additional = []
default_ttl = max(self.soa[1].minimum, self.soa[1].expire)
domain_records = self.records.get(name.lower())
if domain_records:
for record in domain_records:
if record.ttl is not None:
ttl = record.ttl
else:
ttl = default_ttl
if record.TYPE == dns.NS and name.lower() != self.soa[0].lower():
# NS record belong to a child zone: this is a referral. As
# NS records are authoritative in the child zone, ours here
# are not. RFC 2181, section 6.1.
authority.append(
dns.RRHeader(name, record.TYPE, dns.IN, ttl, record, auth=False)
)
elif record.TYPE == type or type == dns.ALL_RECORDS:
results.append(
dns.RRHeader(name, record.TYPE, dns.IN, ttl, record, auth=True)
)
if record.TYPE == dns.CNAME:
cnames.append(
dns.RRHeader(name, record.TYPE, dns.IN, ttl, record, auth=True)
)
if not results:
results = cnames
# Sort of https://tools.ietf.org/html/rfc1034#section-4.3.2 .
# See https://twistedmatrix.com/trac/ticket/6732
additionalInformation = self._additionalRecords(
results, authority, default_ttl
)
if cnames:
results.extend(additionalInformation)
else:
additional.extend(additionalInformation)
if not results and not authority:
# Empty response. Include SOA record to allow clients to cache
# this response. RFC 1034, sections 3.7 and 4.3.4, and RFC 2181
# section 7.1.
authority.append(
dns.RRHeader(
self.soa[0], dns.SOA, dns.IN, ttl, self.soa[1], auth=True
)
)
return defer.succeed((results, authority, additional))
else:
if dns._isSubdomainOf(name, self.soa[0]):
# We may be the authority and we didn't find it.
# XXX: The QNAME may also be in a delegated child zone. See
# #6581 and #6580
return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
else:
# The QNAME is not a descendant of this zone. Fail with
# DomainError so that the next chained authority or
# resolver will be queried.
return defer.fail(failure.Failure(error.DomainError(name)))
def lookupZone(self, name, timeout=10):
name = dns.domainString(name)
if self.soa[0].lower() == name.lower():
# Wee hee hee hooo yea
default_ttl = max(self.soa[1].minimum, self.soa[1].expire)
if self.soa[1].ttl is not None:
soa_ttl = self.soa[1].ttl
else:
soa_ttl = default_ttl
results = [
dns.RRHeader(
self.soa[0], dns.SOA, dns.IN, soa_ttl, self.soa[1], auth=True
)
]
for k, r in self.records.items():
for rec in r:
if rec.ttl is not None:
ttl = rec.ttl
else:
ttl = default_ttl
if rec.TYPE != dns.SOA:
results.append(
dns.RRHeader(k, rec.TYPE, dns.IN, ttl, rec, auth=True)
)
results.append(results[0])
return defer.succeed((results, (), ()))
return defer.fail(failure.Failure(dns.DomainError(name)))
def _cbAllRecords(self, results):
ans, auth, add = [], [], []
for res in results:
if res[0]:
ans.extend(res[1][0])
auth.extend(res[1][1])
add.extend(res[1][2])
return ans, auth, add
class PySourceAuthority(FileAuthority):
"""
A FileAuthority that is built up from Python source code.
"""
def loadFile(self, filename):
g, l = self.setupConfigNamespace(), {}
execfile(filename, g, l)
if "zone" not in l:
raise ValueError("No zone defined in " + filename)
self.records = {}
for rr in l["zone"]:
if isinstance(rr[1], dns.Record_SOA):
self.soa = rr
self.records.setdefault(rr[0].lower(), []).append(rr[1])
def wrapRecord(self, type):
def wrapRecordFunc(name, *arg, **kw):
return (dns.domainString(name), type(*arg, **kw))
return wrapRecordFunc
def setupConfigNamespace(self):
r = {}
items = dns.__dict__.keys()
for record in [x for x in items if x.startswith("Record_")]:
type = getattr(dns, record)
f = self.wrapRecord(type)
r[record[len("Record_") :]] = f
return r
class BindAuthority(FileAuthority):
"""
An Authority that loads U{BIND zone files
<https://en.wikipedia.org/wiki/Zone_file>}.
Supports only C{$ORIGIN} and C{$TTL} directives.
"""
def loadFile(self, filename):
"""
Load records from C{filename}.
@param filename: file to read from
@type filename: L{bytes}
"""
fp = FilePath(filename)
# Not the best way to set an origin. It can be set using $ORIGIN
# though.
self.origin = nativeString(fp.basename() + b".")
lines = fp.getContent().splitlines(True)
lines = self.stripComments(lines)
lines = self.collapseContinuations(lines)
self.parseLines(lines)
def stripComments(self, lines):
"""
Strip comments from C{lines}.
@param lines: lines to work on
@type lines: iterable of L{bytes}
@return: C{lines} sans comments.
"""
return (
a.find(b";") == -1 and a or a[: a.find(b";")]
for a in [b.strip() for b in lines]
)
def collapseContinuations(self, lines):
"""
Transform multiline statements into single lines.
@param lines: lines to work on
@type lines: iterable of L{bytes}
@return: iterable of continuous lines
"""
l = []
state = 0
for line in lines:
if state == 0:
if line.find(b"(") == -1:
l.append(line)
else:
l.append(line[: line.find(b"(")])
state = 1
else:
if line.find(b")") != -1:
l[-1] += b" " + line[: line.find(b")")]
state = 0
else:
l[-1] += b" " + line
return filter(None, (line.split() for line in l))
def parseLines(self, lines):
"""
Parse C{lines}.
@param lines: lines to work on
@type lines: iterable of L{bytes}
"""
ttl = 60 * 60 * 3
origin = self.origin
self.records = {}
for line in lines:
if line[0] == b"$TTL":
ttl = dns.str2time(line[1])
elif line[0] == b"$ORIGIN":
origin = line[1]
elif line[0] == b"$INCLUDE":
raise NotImplementedError("$INCLUDE directive not implemented")
elif line[0] == b"$GENERATE":
raise NotImplementedError("$GENERATE directive not implemented")
else:
self.parseRecordLine(origin, ttl, line)
# If the origin changed, reflect that within the instance.
self.origin = origin
def addRecord(self, owner, ttl, type, domain, cls, rdata):
"""
Add a record to our authority. Expand domain with origin if necessary.
@param owner: origin?
@type owner: L{bytes}
@param ttl: time to live for the record
@type ttl: L{int}
@param domain: the domain for which the record is to be added
@type domain: L{bytes}
@param type: record type
@type type: L{str}
@param cls: record class
@type cls: L{str}
@param rdata: record data
@type rdata: L{list} of L{bytes}
"""
if not domain.endswith(b"."):
domain = domain + b"." + owner[:-1]
else:
domain = domain[:-1]
f = getattr(self, f"class_{cls}", None)
if f:
f(ttl, type, domain, rdata)
else:
raise NotImplementedError(f"Record class {cls!r} not supported")
def class_IN(self, ttl, type, domain, rdata):
"""
Simulate a class IN and recurse into the actual class.
@param ttl: time to live for the record
@type ttl: L{int}
@param type: record type
@type type: str
@param domain: the domain
@type domain: bytes
@param rdata:
@type rdata: bytes
"""
record = getattr(dns, f"Record_{nativeString(type)}", None)
if record:
r = record(*rdata)
r.ttl = ttl
self.records.setdefault(domain.lower(), []).append(r)
if type == "SOA":
self.soa = (domain, r)
else:
raise NotImplementedError(
f"Record type {nativeString(type)!r} not supported"
)
def parseRecordLine(self, origin, ttl, line):
"""
Parse a C{line} from a zone file respecting C{origin} and C{ttl}.
Add resulting records to authority.
@param origin: starting point for the zone
@type origin: L{bytes}
@param ttl: time to live for the record
@type ttl: L{int}
@param line: zone file line to parse; split by word
@type line: L{list} of L{bytes}
"""
queryClasses = {qc.encode("ascii") for qc in dns.QUERY_CLASSES.values()}
queryTypes = {qt.encode("ascii") for qt in dns.QUERY_TYPES.values()}
markers = queryClasses | queryTypes
cls = b"IN"
owner = origin
if line[0] == b"@":
line = line[1:]
owner = origin
elif not line[0].isdigit() and line[0] not in markers:
owner = line[0]
line = line[1:]
if line[0].isdigit() or line[0] in markers:
domain = owner
owner = origin
else:
domain = line[0]
line = line[1:]
if line[0] in queryClasses:
cls = line[0]
line = line[1:]
if line[0].isdigit():
ttl = int(line[0])
line = line[1:]
elif line[0].isdigit():
ttl = int(line[0])
line = line[1:]
if line[0] in queryClasses:
cls = line[0]
line = line[1:]
type = line[0]
rdata = line[1:]
self.addRecord(owner, ttl, nativeString(type), domain, nativeString(cls), rdata)

View File

@@ -0,0 +1,131 @@
# -*- test-case-name: twisted.names.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An in-memory caching resolver.
"""
from twisted.internet import defer
from twisted.names import common, dns
from twisted.python import failure, log
class CacheResolver(common.ResolverBase):
"""
A resolver that serves records from a local, memory cache.
@ivar _reactor: A provider of L{interfaces.IReactorTime}.
"""
cache = None
def __init__(self, cache=None, verbose=0, reactor=None):
common.ResolverBase.__init__(self)
self.cache = {}
self.verbose = verbose
self.cancel = {}
if reactor is None:
from twisted.internet import reactor
self._reactor = reactor
if cache:
for query, (seconds, payload) in cache.items():
self.cacheResult(query, payload, seconds)
def __setstate__(self, state):
self.__dict__ = state
now = self._reactor.seconds()
for k, (when, (ans, add, ns)) in self.cache.items():
diff = now - when
for rec in ans + add + ns:
if rec.ttl < diff:
del self.cache[k]
break
def __getstate__(self):
for c in self.cancel.values():
c.cancel()
self.cancel.clear()
return self.__dict__
def _lookup(self, name, cls, type, timeout):
now = self._reactor.seconds()
q = dns.Query(name, type, cls)
try:
when, (ans, auth, add) = self.cache[q]
except KeyError:
if self.verbose > 1:
log.msg("Cache miss for " + repr(name))
return defer.fail(failure.Failure(dns.DomainError(name)))
else:
if self.verbose:
log.msg("Cache hit for " + repr(name))
diff = now - when
try:
result = (
[
dns.RRHeader(
r.name.name, r.type, r.cls, r.ttl - diff, r.payload
)
for r in ans
],
[
dns.RRHeader(
r.name.name, r.type, r.cls, r.ttl - diff, r.payload
)
for r in auth
],
[
dns.RRHeader(
r.name.name, r.type, r.cls, r.ttl - diff, r.payload
)
for r in add
],
)
except ValueError:
return defer.fail(failure.Failure(dns.DomainError(name)))
else:
return defer.succeed(result)
def lookupAllRecords(self, name, timeout=None):
return defer.fail(failure.Failure(dns.DomainError(name)))
def cacheResult(self, query, payload, cacheTime=None):
"""
Cache a DNS entry.
@param query: a L{dns.Query} instance.
@param payload: a 3-tuple of lists of L{dns.RRHeader} records, the
matching result of the query (answers, authority and additional).
@param cacheTime: The time (seconds since epoch) at which the entry is
considered to have been added to the cache. If L{None} is given,
the current time is used.
"""
if self.verbose > 1:
log.msg("Adding %r to cache" % query)
self.cache[query] = (cacheTime or self._reactor.seconds(), payload)
if query in self.cancel:
self.cancel[query].cancel()
s = list(payload[0]) + list(payload[1]) + list(payload[2])
if s:
m = s[0].ttl
for r in s:
m = min(m, r.ttl)
else:
m = 0
self.cancel[query] = self._reactor.callLater(m, self.clearEntry, query)
def clearEntry(self, query):
del self.cache[query]
del self.cancel[query]

View File

@@ -0,0 +1,734 @@
# -*- test-case-name: twisted.names.test.test_names -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Asynchronous client DNS
The functions exposed in this module can be used for asynchronous name
resolution and dns queries.
If you need to create a resolver with specific requirements, such as needing to
do queries against a particular host, the L{createResolver} function will
return an C{IResolver}.
Future plans: Proper nameserver acquisition on Windows/MacOS,
better caching, respect timeouts
"""
import errno
import os
import warnings
from zope.interface import moduleProvides
from twisted.internet import defer, error, interfaces, protocol
from twisted.internet.abstract import isIPv6Address
from twisted.names import cache, common, dns, hosts as hostsModule, resolve, root
from twisted.python import failure, log
# Twisted imports
from twisted.python.compat import nativeString
from twisted.python.filepath import FilePath
from twisted.python.runtime import platform
moduleProvides(interfaces.IResolver)
class Resolver(common.ResolverBase):
"""
@ivar _waiting: A C{dict} mapping tuple keys of query name/type/class to
Deferreds which will be called back with the result of those queries.
This is used to avoid issuing the same query more than once in
parallel. This is more efficient on the network and helps avoid a
"birthday paradox" attack by keeping the number of outstanding requests
for a particular query fixed at one instead of allowing the attacker to
raise it to an arbitrary number.
@ivar _reactor: A provider of L{IReactorTCP}, L{IReactorUDP}, and
L{IReactorTime} which will be used to set up network resources and
track timeouts.
"""
index = 0
timeout = None
factory = None
servers = None
dynServers = ()
pending = None
connections = None
resolv = None
_lastResolvTime = None
_resolvReadInterval = 60
def __init__(self, resolv=None, servers=None, timeout=(1, 3, 11, 45), reactor=None):
"""
Construct a resolver which will query domain name servers listed in
the C{resolv.conf(5)}-format file given by C{resolv} as well as
those in the given C{servers} list. Servers are queried in a
round-robin fashion. If given, C{resolv} is periodically checked
for modification and re-parsed if it is noticed to have changed.
@type servers: C{list} of C{(str, int)} or L{None}
@param servers: If not None, interpreted as a list of (host, port)
pairs specifying addresses of domain name servers to attempt to use
for this lookup. Host addresses should be in IPv4 dotted-quad
form. If specified, overrides C{resolv}.
@type resolv: C{str}
@param resolv: Filename to read and parse as a resolver(5)
configuration file.
@type timeout: Sequence of C{int}
@param timeout: Default number of seconds after which to reissue the
query. When the last timeout expires, the query is considered
failed.
@param reactor: A provider of L{IReactorTime}, L{IReactorUDP}, and
L{IReactorTCP} which will be used to establish connections, listen
for DNS datagrams, and enforce timeouts. If not provided, the
global reactor will be used.
@raise ValueError: Raised if no nameserver addresses can be found.
"""
common.ResolverBase.__init__(self)
if reactor is None:
from twisted.internet import reactor
self._reactor = reactor
self.timeout = timeout
if servers is None:
self.servers = []
else:
self.servers = servers
self.resolv = resolv
if not len(self.servers) and not resolv:
raise ValueError("No nameservers specified")
self.factory = DNSClientFactory(self, timeout)
self.factory.noisy = 0 # Be quiet by default
self.connections = []
self.pending = []
self._waiting = {}
self.maybeParseConfig()
def __getstate__(self):
d = self.__dict__.copy()
d["connections"] = []
d["_parseCall"] = None
return d
def __setstate__(self, state):
self.__dict__.update(state)
self.maybeParseConfig()
def _openFile(self, path):
"""
Wrapper used for opening files in the class, exists primarily for unit
testing purposes.
"""
return FilePath(path).open()
def maybeParseConfig(self):
if self.resolv is None:
# Don't try to parse it, don't set up a call loop
return
try:
resolvConf = self._openFile(self.resolv)
except OSError as e:
if e.errno == errno.ENOENT:
# Missing resolv.conf is treated the same as an empty resolv.conf
self.parseConfig(())
else:
raise
else:
with resolvConf:
mtime = os.fstat(resolvConf.fileno()).st_mtime
if mtime != self._lastResolvTime:
log.msg(f"{self.resolv} changed, reparsing")
self._lastResolvTime = mtime
self.parseConfig(resolvConf)
# Check again in a little while
self._parseCall = self._reactor.callLater(
self._resolvReadInterval, self.maybeParseConfig
)
def parseConfig(self, resolvConf):
servers = []
for L in resolvConf:
L = L.strip()
if L.startswith(b"nameserver"):
resolver = (nativeString(L.split()[1]), dns.PORT)
servers.append(resolver)
log.msg(f"Resolver added {resolver!r} to server list")
elif L.startswith(b"domain"):
try:
self.domain = L.split()[1]
except IndexError:
self.domain = b""
self.search = None
elif L.startswith(b"search"):
self.search = L.split()[1:]
self.domain = None
if not servers:
servers.append(("127.0.0.1", dns.PORT))
self.dynServers = servers
def pickServer(self):
"""
Return the address of a nameserver.
TODO: Weight servers for response time so faster ones can be
preferred.
"""
if not self.servers and not self.dynServers:
return None
serverL = len(self.servers)
dynL = len(self.dynServers)
self.index += 1
self.index %= serverL + dynL
if self.index < serverL:
return self.servers[self.index]
else:
return self.dynServers[self.index - serverL]
def _connectedProtocol(self, interface=""):
"""
Return a new L{DNSDatagramProtocol} bound to a randomly selected port
number.
"""
failures = 0
proto = dns.DNSDatagramProtocol(self, reactor=self._reactor)
while True:
try:
self._reactor.listenUDP(dns.randomSource(), proto, interface=interface)
except error.CannotListenError as e:
failures += 1
if (
hasattr(e.socketError, "errno")
and e.socketError.errno == errno.EMFILE
):
# We've run out of file descriptors. Stop trying.
raise
if failures >= 1000:
# We've tried a thousand times and haven't found a port.
# This is almost impossible, and likely means something
# else weird is going on. Raise, as to not infinite loop.
raise
else:
return proto
def connectionMade(self, protocol):
"""
Called by associated L{dns.DNSProtocol} instances when they connect.
"""
self.connections.append(protocol)
for d, q, t in self.pending:
self.queryTCP(q, t).chainDeferred(d)
del self.pending[:]
def connectionLost(self, protocol):
"""
Called by associated L{dns.DNSProtocol} instances when they disconnect.
"""
if protocol in self.connections:
self.connections.remove(protocol)
def messageReceived(self, message, protocol, address=None):
log.msg("Unexpected message (%d) received from %r" % (message.id, address))
def _query(self, *args):
"""
Get a new L{DNSDatagramProtocol} instance from L{_connectedProtocol},
issue a query to it using C{*args}, and arrange for it to be
disconnected from its transport after the query completes.
@param args: Positional arguments to be passed to
L{DNSDatagramProtocol.query}.
@return: A L{Deferred} which will be called back with the result of the
query.
"""
if isIPv6Address(args[0][0]):
protocol = self._connectedProtocol(interface="::")
else:
protocol = self._connectedProtocol()
d = protocol.query(*args)
def cbQueried(result):
protocol.transport.stopListening()
return result
d.addBoth(cbQueried)
return d
def queryUDP(self, queries, timeout=None):
"""
Make a number of DNS queries via UDP.
@type queries: A C{list} of C{dns.Query} instances
@param queries: The queries to make.
@type timeout: Sequence of C{int}
@param timeout: Number of seconds after which to reissue the query.
When the last timeout expires, the query is considered failed.
@rtype: C{Deferred}
@raise C{twisted.internet.defer.TimeoutError}: When the query times
out.
"""
if timeout is None:
timeout = self.timeout
addresses = self.servers + list(self.dynServers)
if not addresses:
return defer.fail(IOError("No domain name servers available"))
# Make sure we go through servers in the list in the order they were
# specified.
addresses.reverse()
used = addresses.pop()
d = self._query(used, queries, timeout[0])
d.addErrback(self._reissue, addresses, [used], queries, timeout)
return d
def _reissue(self, reason, addressesLeft, addressesUsed, query, timeout):
reason.trap(dns.DNSQueryTimeoutError)
# If there are no servers left to be tried, adjust the timeout
# to the next longest timeout period and move all the
# "used" addresses back to the list of addresses to try.
if not addressesLeft:
addressesLeft = addressesUsed
addressesLeft.reverse()
addressesUsed = []
timeout = timeout[1:]
# If all timeout values have been used this query has failed. Tell the
# protocol we're giving up on it and return a terminal timeout failure
# to our caller.
if not timeout:
return failure.Failure(defer.TimeoutError(query))
# Get an address to try. Take it out of the list of addresses
# to try and put it ino the list of already tried addresses.
address = addressesLeft.pop()
addressesUsed.append(address)
# Issue a query to a server. Use the current timeout. Add this
# function as a timeout errback in case another retry is required.
d = self._query(address, query, timeout[0], reason.value.id)
d.addErrback(self._reissue, addressesLeft, addressesUsed, query, timeout)
return d
def queryTCP(self, queries, timeout=10):
"""
Make a number of DNS queries via TCP.
@type queries: Any non-zero number of C{dns.Query} instances
@param queries: The queries to make.
@type timeout: C{int}
@param timeout: The number of seconds after which to fail.
@rtype: C{Deferred}
"""
if not len(self.connections):
address = self.pickServer()
if address is None:
return defer.fail(IOError("No domain name servers available"))
host, port = address
self._reactor.connectTCP(host, port, self.factory)
self.pending.append((defer.Deferred(), queries, timeout))
return self.pending[-1][0]
else:
return self.connections[0].query(queries, timeout)
def filterAnswers(self, message):
"""
Extract results from the given message.
If the message was truncated, re-attempt the query over TCP and return
a Deferred which will fire with the results of that query.
If the message's result code is not C{twisted.names.dns.OK}, return a
Failure indicating the type of error which occurred.
Otherwise, return a three-tuple of lists containing the results from
the answers section, the authority section, and the additional section.
"""
if message.trunc:
return self.queryTCP(message.queries).addCallback(self.filterAnswers)
if message.rCode != dns.OK:
return failure.Failure(self.exceptionForCode(message.rCode)(message))
return (message.answers, message.authority, message.additional)
def _lookup(self, name, cls, type, timeout):
"""
Build a L{dns.Query} for the given parameters and dispatch it via UDP.
If this query is already outstanding, it will not be re-issued.
Instead, when the outstanding query receives a response, that response
will be re-used for this query as well.
@type name: C{str}
@type type: C{int}
@type cls: C{int}
@return: A L{Deferred} which fires with a three-tuple giving the
answer, authority, and additional sections of the response or with
a L{Failure} if the response code is anything other than C{dns.OK}.
"""
key = (name, type, cls)
waiting = self._waiting.get(key)
if waiting is None:
self._waiting[key] = []
d = self.queryUDP([dns.Query(name, type, cls)], timeout)
def cbResult(result):
for d in self._waiting.pop(key):
d.callback(result)
return result
d.addCallback(self.filterAnswers)
d.addBoth(cbResult)
else:
d = defer.Deferred()
waiting.append(d)
return d
# This one doesn't ever belong on UDP
def lookupZone(self, name, timeout=10):
address = self.pickServer()
if address is None:
return defer.fail(IOError("No domain name servers available"))
host, port = address
d = defer.Deferred()
controller = AXFRController(name, d)
factory = DNSClientFactory(controller, timeout)
factory.noisy = False # stfu
connector = self._reactor.connectTCP(host, port, factory)
controller.timeoutCall = self._reactor.callLater(
timeout or 10, self._timeoutZone, d, controller, connector, timeout or 10
)
def eliminateTimeout(failure):
controller.timeoutCall.cancel()
controller.timeoutCall = None
return failure
return d.addCallbacks(
self._cbLookupZone, eliminateTimeout, callbackArgs=(connector,)
)
def _timeoutZone(self, d, controller, connector, seconds):
connector.disconnect()
controller.timeoutCall = None
controller.deferred = None
d.errback(
error.TimeoutError("Zone lookup timed out after %d seconds" % (seconds,))
)
def _cbLookupZone(self, result, connector):
connector.disconnect()
return (result, [], [])
class AXFRController:
timeoutCall = None
def __init__(self, name, deferred):
self.name = name
self.deferred = deferred
self.soa = None
self.records = []
self.pending = [(deferred,)]
def connectionMade(self, protocol):
# dig saids recursion-desired to 0, so I will too
message = dns.Message(protocol.pickID(), recDes=0)
message.queries = [dns.Query(self.name, dns.AXFR, dns.IN)]
protocol.writeMessage(message)
def connectionLost(self, protocol):
# XXX Do something here - see #3428
pass
def messageReceived(self, message, protocol):
# Caveat: We have to handle two cases: All records are in 1
# message, or all records are in N messages.
# According to http://cr.yp.to/djbdns/axfr-notes.html,
# 'authority' and 'additional' are always empty, and only
# 'answers' is present.
self.records.extend(message.answers)
if not self.records:
return
if not self.soa:
if self.records[0].type == dns.SOA:
# print "first SOA!"
self.soa = self.records[0]
if len(self.records) > 1 and self.records[-1].type == dns.SOA:
# print "It's the second SOA! We're done."
if self.timeoutCall is not None:
self.timeoutCall.cancel()
self.timeoutCall = None
if self.deferred is not None:
self.deferred.callback(self.records)
self.deferred = None
from twisted.internet.base import ThreadedResolver as _ThreadedResolverImpl
class ThreadedResolver(_ThreadedResolverImpl):
def __init__(self, reactor=None):
if reactor is None:
from twisted.internet import reactor
_ThreadedResolverImpl.__init__(self, reactor)
warnings.warn(
"twisted.names.client.ThreadedResolver is deprecated since "
"Twisted 9.0, use twisted.internet.base.ThreadedResolver "
"instead.",
category=DeprecationWarning,
stacklevel=2,
)
class DNSClientFactory(protocol.ClientFactory):
def __init__(self, controller, timeout=10):
self.controller = controller
self.timeout = timeout
def clientConnectionLost(self, connector, reason):
pass
def clientConnectionFailed(self, connector, reason):
"""
Fail all pending TCP DNS queries if the TCP connection attempt
fails.
@see: L{twisted.internet.protocol.ClientFactory}
@param connector: Not used.
@type connector: L{twisted.internet.interfaces.IConnector}
@param reason: A C{Failure} containing information about the
cause of the connection failure. This will be passed as the
argument to C{errback} on every pending TCP query
C{deferred}.
@type reason: L{twisted.python.failure.Failure}
"""
# Copy the current pending deferreds then reset the master
# pending list. This prevents triggering new deferreds which
# may be added by callback or errback functions on the current
# deferreds.
pending = self.controller.pending[:]
del self.controller.pending[:]
for pendingState in pending:
d = pendingState[0]
d.errback(reason)
def buildProtocol(self, addr):
p = dns.DNSProtocol(self.controller)
p.factory = self
return p
def createResolver(servers=None, resolvconf=None, hosts=None):
r"""
Create and return a Resolver.
@type servers: C{list} of C{(str, int)} or L{None}
@param servers: If not L{None}, interpreted as a list of domain name servers
to attempt to use. Each server is a tuple of address in C{str} dotted-quad
form and C{int} port number.
@type resolvconf: C{str} or L{None}
@param resolvconf: If not L{None}, on posix systems will be interpreted as
an alternate resolv.conf to use. Will do nothing on windows systems. If
L{None}, /etc/resolv.conf will be used.
@type hosts: C{str} or L{None}
@param hosts: If not L{None}, an alternate hosts file to use. If L{None}
on posix systems, /etc/hosts will be used. On windows, C:\windows\hosts
will be used.
@rtype: C{IResolver}
"""
if platform.getType() == "posix":
if resolvconf is None:
resolvconf = b"/etc/resolv.conf"
if hosts is None:
hosts = b"/etc/hosts"
theResolver = Resolver(resolvconf, servers)
hostResolver = hostsModule.Resolver(hosts)
else:
if hosts is None:
hosts = r"c:\windows\hosts"
from twisted.internet import reactor
bootstrap = _ThreadedResolverImpl(reactor)
hostResolver = hostsModule.Resolver(hosts)
theResolver = root.bootstrap(bootstrap, resolverFactory=Resolver)
L = [hostResolver, cache.CacheResolver(), theResolver]
return resolve.ResolverChain(L)
theResolver = None
def getResolver():
"""
Get a Resolver instance.
Create twisted.names.client.theResolver if it is L{None}, and then return
that value.
@rtype: C{IResolver}
"""
global theResolver
if theResolver is None:
try:
theResolver = createResolver()
except ValueError:
theResolver = createResolver(servers=[("127.0.0.1", 53)])
return theResolver
def getHostByName(name, timeout=None, effort=10):
"""
Resolve a name to a valid ipv4 or ipv6 address.
Will errback with C{DNSQueryTimeoutError} on a timeout, C{DomainError} or
C{AuthoritativeDomainError} (or subclasses) on other errors.
@type name: C{str}
@param name: DNS name to resolve.
@type timeout: Sequence of C{int}
@param timeout: Number of seconds after which to reissue the query.
When the last timeout expires, the query is considered failed.
@type effort: C{int}
@param effort: How many times CNAME and NS records to follow while
resolving this name.
@rtype: C{Deferred}
"""
return getResolver().getHostByName(name, timeout, effort)
def query(query, timeout=None):
return getResolver().query(query, timeout)
def lookupAddress(name, timeout=None):
return getResolver().lookupAddress(name, timeout)
def lookupIPV6Address(name, timeout=None):
return getResolver().lookupIPV6Address(name, timeout)
def lookupAddress6(name, timeout=None):
return getResolver().lookupAddress6(name, timeout)
def lookupMailExchange(name, timeout=None):
return getResolver().lookupMailExchange(name, timeout)
def lookupNameservers(name, timeout=None):
return getResolver().lookupNameservers(name, timeout)
def lookupCanonicalName(name, timeout=None):
return getResolver().lookupCanonicalName(name, timeout)
def lookupMailBox(name, timeout=None):
return getResolver().lookupMailBox(name, timeout)
def lookupMailGroup(name, timeout=None):
return getResolver().lookupMailGroup(name, timeout)
def lookupMailRename(name, timeout=None):
return getResolver().lookupMailRename(name, timeout)
def lookupPointer(name, timeout=None):
return getResolver().lookupPointer(name, timeout)
def lookupAuthority(name, timeout=None):
return getResolver().lookupAuthority(name, timeout)
def lookupNull(name, timeout=None):
return getResolver().lookupNull(name, timeout)
def lookupWellKnownServices(name, timeout=None):
return getResolver().lookupWellKnownServices(name, timeout)
def lookupService(name, timeout=None):
return getResolver().lookupService(name, timeout)
def lookupHostInfo(name, timeout=None):
return getResolver().lookupHostInfo(name, timeout)
def lookupMailboxInfo(name, timeout=None):
return getResolver().lookupMailboxInfo(name, timeout)
def lookupText(name, timeout=None):
return getResolver().lookupText(name, timeout)
def lookupSenderPolicy(name, timeout=None):
return getResolver().lookupSenderPolicy(name, timeout)
def lookupResponsibility(name, timeout=None):
return getResolver().lookupResponsibility(name, timeout)
def lookupAFSDatabase(name, timeout=None):
return getResolver().lookupAFSDatabase(name, timeout)
def lookupZone(name, timeout=None):
return getResolver().lookupZone(name, timeout)
def lookupAllRecords(name, timeout=None):
return getResolver().lookupAllRecords(name, timeout)
def lookupNamingAuthorityPointer(name, timeout=None):
return getResolver().lookupNamingAuthorityPointer(name, timeout)

View File

@@ -0,0 +1,263 @@
# -*- test-case-name: twisted.names.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Base functionality useful to various parts of Twisted Names.
"""
import socket
from zope.interface import implementer
from twisted.internet import defer, error, interfaces
from twisted.logger import Logger
from twisted.names import dns
from twisted.names.error import (
DNSFormatError,
DNSNameError,
DNSNotImplementedError,
DNSQueryRefusedError,
DNSServerError,
DNSUnknownError,
)
# Helpers for indexing the three-tuples that get thrown around by this code a
# lot.
_ANS, _AUTH, _ADD = range(3)
EMPTY_RESULT = (), (), ()
@implementer(interfaces.IResolver)
class ResolverBase:
"""
L{ResolverBase} is a base class for implementations of
L{interfaces.IResolver} which deals with a lot
of the boilerplate of implementing all of the lookup methods.
@cvar _errormap: A C{dict} mapping DNS protocol failure response codes
to exception classes which will be used to represent those failures.
"""
_log = Logger()
_errormap = {
dns.EFORMAT: DNSFormatError,
dns.ESERVER: DNSServerError,
dns.ENAME: DNSNameError,
dns.ENOTIMP: DNSNotImplementedError,
dns.EREFUSED: DNSQueryRefusedError,
}
typeToMethod = None
def __init__(self):
self.typeToMethod = {}
for k, v in typeToMethod.items():
self.typeToMethod[k] = getattr(self, v)
def exceptionForCode(self, responseCode):
"""
Convert a response code (one of the possible values of
L{dns.Message.rCode} to an exception instance representing it.
@since: 10.0
"""
return self._errormap.get(responseCode, DNSUnknownError)
def query(self, query, timeout=None):
try:
method = self.typeToMethod[query.type]
except KeyError:
self._log.debug(
"Query of unknown type {query.type} for {query.name.name!r}",
query=query,
)
return defer.maybeDeferred(
self._lookup, query.name.name, dns.IN, query.type, timeout
)
else:
return defer.maybeDeferred(method, query.name.name, timeout)
def _lookup(self, name, cls, type, timeout):
return defer.fail(NotImplementedError("ResolverBase._lookup"))
def lookupAddress(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.A, timeout)
def lookupIPV6Address(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.AAAA, timeout)
def lookupAddress6(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.A6, timeout)
def lookupMailExchange(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.MX, timeout)
def lookupNameservers(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.NS, timeout)
def lookupCanonicalName(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.CNAME, timeout)
def lookupMailBox(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.MB, timeout)
def lookupMailGroup(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.MG, timeout)
def lookupMailRename(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.MR, timeout)
def lookupPointer(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.PTR, timeout)
def lookupAuthority(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.SOA, timeout)
def lookupNull(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.NULL, timeout)
def lookupWellKnownServices(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.WKS, timeout)
def lookupService(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.SRV, timeout)
def lookupHostInfo(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.HINFO, timeout)
def lookupMailboxInfo(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.MINFO, timeout)
def lookupText(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.TXT, timeout)
def lookupSenderPolicy(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.SPF, timeout)
def lookupResponsibility(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.RP, timeout)
def lookupAFSDatabase(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.AFSDB, timeout)
def lookupZone(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.AXFR, timeout)
def lookupNamingAuthorityPointer(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.NAPTR, timeout)
def lookupAllRecords(self, name, timeout=None):
return self._lookup(dns.domainString(name), dns.IN, dns.ALL_RECORDS, timeout)
# IResolverSimple
def getHostByName(self, name, timeout=None, effort=10):
name = dns.domainString(name)
# XXX - respect timeout
# XXX - this should do A and AAAA lookups, not ANY (see RFC 8482).
# https://twistedmatrix.com/trac/ticket/9691
d = self.lookupAllRecords(name, timeout)
d.addCallback(self._cbRecords, name, effort)
return d
def _cbRecords(self, records, name, effort):
(ans, auth, add) = records
result = extractRecord(self, dns.Name(name), ans + auth + add, effort)
if not result:
raise error.DNSLookupError(name)
return result
def extractRecord(resolver, name, answers, level=10):
"""
Resolve a name to an IP address, following I{CNAME} records and I{NS}
referrals recursively.
This is an implementation detail of L{ResolverBase.getHostByName}.
@param resolver: The resolver to use for the next query (unless handling
an I{NS} referral).
@type resolver: L{IResolver}
@param name: The name being looked up.
@type name: L{dns.Name}
@param answers: All of the records returned by the previous query (answers,
authority, and additional concatenated).
@type answers: L{list} of L{dns.RRHeader}
@param level: Remaining recursion budget. This is decremented at each
recursion. The query returns L{None} when it reaches 0.
@type level: L{int}
@returns: The first IPv4 or IPv6 address (as a dotted quad or colon
quibbles), or L{None} when no result is found.
@rtype: native L{str} or L{None}
"""
if not level:
return None
# FIXME: twisted.python.compat monkeypatches this if missing, so this
# condition is always true. https://twistedmatrix.com/trac/ticket/9753
if hasattr(socket, "inet_ntop"):
for r in answers:
if r.name == name and r.type == dns.A6:
return socket.inet_ntop(socket.AF_INET6, r.payload.address)
for r in answers:
if r.name == name and r.type == dns.AAAA:
return socket.inet_ntop(socket.AF_INET6, r.payload.address)
for r in answers:
if r.name == name and r.type == dns.A:
return socket.inet_ntop(socket.AF_INET, r.payload.address)
for r in answers:
if r.name == name and r.type == dns.CNAME:
result = extractRecord(resolver, r.payload.name, answers, level - 1)
if not result:
return resolver.getHostByName(r.payload.name.name, effort=level - 1)
return result
# No answers, but maybe there's a hint at who we should be asking about
# this
for r in answers:
if r.type != dns.NS:
continue
from twisted.names import client
nsResolver = client.Resolver(
servers=[
(r.payload.name.name.decode("ascii"), dns.PORT),
]
)
def queryAgain(records):
(ans, auth, add) = records
return extractRecord(nsResolver, name, ans + auth + add, level - 1)
return nsResolver.lookupAddress(name.name).addCallback(queryAgain)
typeToMethod = {
dns.A: "lookupAddress",
dns.AAAA: "lookupIPV6Address",
dns.A6: "lookupAddress6",
dns.NS: "lookupNameservers",
dns.CNAME: "lookupCanonicalName",
dns.SOA: "lookupAuthority",
dns.MB: "lookupMailBox",
dns.MG: "lookupMailGroup",
dns.MR: "lookupMailRename",
dns.NULL: "lookupNull",
dns.WKS: "lookupWellKnownServices",
dns.PTR: "lookupPointer",
dns.HINFO: "lookupHostInfo",
dns.MINFO: "lookupMailboxInfo",
dns.MX: "lookupMailExchange",
dns.TXT: "lookupText",
dns.SPF: "lookupSenderPolicy",
dns.RP: "lookupResponsibility",
dns.AFSDB: "lookupAFSDatabase",
dns.SRV: "lookupService",
dns.NAPTR: "lookupNamingAuthorityPointer",
dns.AXFR: "lookupZone",
dns.ALL_RECORDS: "lookupAllRecords",
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,94 @@
# -*- test-case-name: twisted.names.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Exception class definitions for Twisted Names.
"""
from twisted.internet.defer import TimeoutError
class DomainError(ValueError):
"""
Indicates a lookup failed because there were no records matching the given
C{name, class, type} triple.
"""
class AuthoritativeDomainError(ValueError):
"""
Indicates a lookup failed for a name for which this server is authoritative
because there were no records matching the given C{name, class, type}
triple.
"""
class DNSQueryTimeoutError(TimeoutError):
"""
Indicates a lookup failed due to a timeout.
@ivar id: The id of the message which timed out.
"""
def __init__(self, id):
TimeoutError.__init__(self)
self.id = id
class DNSFormatError(DomainError):
"""
Indicates a query failed with a result of C{twisted.names.dns.EFORMAT}.
"""
class DNSServerError(DomainError):
"""
Indicates a query failed with a result of C{twisted.names.dns.ESERVER}.
"""
class DNSNameError(DomainError):
"""
Indicates a query failed with a result of C{twisted.names.dns.ENAME}.
"""
class DNSNotImplementedError(DomainError):
"""
Indicates a query failed with a result of C{twisted.names.dns.ENOTIMP}.
"""
class DNSQueryRefusedError(DomainError):
"""
Indicates a query failed with a result of C{twisted.names.dns.EREFUSED}.
"""
class DNSUnknownError(DomainError):
"""
Indicates a query failed with an unknown result.
"""
class ResolverError(Exception):
"""
Indicates a query failed because of a decision made by the local
resolver object.
"""
__all__ = [
"DomainError",
"AuthoritativeDomainError",
"DNSQueryTimeoutError",
"DNSFormatError",
"DNSServerError",
"DNSNameError",
"DNSNotImplementedError",
"DNSQueryRefusedError",
"DNSUnknownError",
"ResolverError",
]

View File

@@ -0,0 +1,151 @@
# -*- test-case-name: twisted.names.test.test_hosts -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
hosts(5) support.
"""
from twisted.internet import defer
from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.names import common, dns
from twisted.python import failure
from twisted.python.compat import nativeString
from twisted.python.filepath import FilePath
def searchFileForAll(hostsFile, name):
"""
Search the given file, which is in hosts(5) standard format, for addresses
associated with a given name.
@param hostsFile: The name of the hosts(5)-format file to search.
@type hostsFile: L{FilePath}
@param name: The name to search for.
@type name: C{bytes}
@return: L{None} if the name is not found in the file, otherwise a
C{str} giving the address in the file associated with the name.
"""
results = []
try:
lines = hostsFile.getContent().splitlines()
except BaseException:
return results
name = name.lower()
for line in lines:
idx = line.find(b"#")
if idx != -1:
line = line[:idx]
if not line:
continue
parts = line.split()
if name.lower() in [s.lower() for s in parts[1:]]:
try:
maybeIP = nativeString(parts[0])
except ValueError: # Not ASCII.
continue
if isIPAddress(maybeIP) or isIPv6Address(maybeIP):
results.append(maybeIP)
return results
def searchFileFor(file, name):
"""
Grep given file, which is in hosts(5) standard format, for an address
entry with a given name.
@param file: The name of the hosts(5)-format file to search.
@type file: C{str} or C{bytes}
@param name: The name to search for.
@type name: C{bytes}
@return: L{None} if the name is not found in the file, otherwise a
C{str} giving the first address in the file associated with
the name.
"""
addresses = searchFileForAll(FilePath(file), name)
if addresses:
return addresses[0]
return None
class Resolver(common.ResolverBase):
"""
A resolver that services hosts(5) format files.
"""
def __init__(self, file=b"/etc/hosts", ttl=60 * 60):
common.ResolverBase.__init__(self)
self.file = file
self.ttl = ttl
def _aRecords(self, name):
"""
Return a tuple of L{dns.RRHeader} instances for all of the IPv4
addresses in the hosts file.
"""
return tuple(
dns.RRHeader(name, dns.A, dns.IN, self.ttl, dns.Record_A(addr, self.ttl))
for addr in searchFileForAll(FilePath(self.file), name)
if isIPAddress(addr)
)
def _aaaaRecords(self, name):
"""
Return a tuple of L{dns.RRHeader} instances for all of the IPv6
addresses in the hosts file.
"""
return tuple(
dns.RRHeader(
name, dns.AAAA, dns.IN, self.ttl, dns.Record_AAAA(addr, self.ttl)
)
for addr in searchFileForAll(FilePath(self.file), name)
if isIPv6Address(addr)
)
def _respond(self, name, records):
"""
Generate a response for the given name containing the given result
records, or a failure if there are no result records.
@param name: The DNS name the response is for.
@type name: C{str}
@param records: A tuple of L{dns.RRHeader} instances giving the results
that will go into the response.
@return: A L{Deferred} which will fire with a three-tuple of result
records, authority records, and additional records, or which will
fail with L{dns.DomainError} if there are no result records.
"""
if records:
return defer.succeed((records, (), ()))
return defer.fail(failure.Failure(dns.DomainError(name)))
def lookupAddress(self, name, timeout=None):
"""
Read any IPv4 addresses from C{self.file} and return them as
L{Record_A} instances.
"""
name = dns.domainString(name)
return self._respond(name, self._aRecords(name))
def lookupIPV6Address(self, name, timeout=None):
"""
Read any IPv6 addresses from C{self.file} and return them as
L{Record_AAAA} instances.
"""
name = dns.domainString(name)
return self._respond(name, self._aaaaRecords(name))
# Someday this should include IPv6 addresses too, but that will cause
# problems if users of the API (mainly via getHostByName) aren't updated to
# know about IPv6 first.
# FIXME - getHostByName knows about IPv6 now.
lookupAllRecords = lookupAddress

View File

@@ -0,0 +1 @@
!.gitignore

View File

@@ -0,0 +1,91 @@
# -*- test-case-name: twisted.names.test.test_resolve -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Lookup a name using multiple resolvers.
Future Plans: This needs someway to specify which resolver answered
the query, or someway to specify (authority|ttl|cache behavior|more?)
"""
from zope.interface import implementer
from twisted.internet import defer, interfaces
from twisted.names import common, dns, error
class FailureHandler:
def __init__(self, resolver, query, timeout):
self.resolver = resolver
self.query = query
self.timeout = timeout
def __call__(self, failure):
# AuthoritativeDomainErrors should halt resolution attempts
failure.trap(dns.DomainError, defer.TimeoutError, NotImplementedError)
return self.resolver(self.query, self.timeout)
@implementer(interfaces.IResolver)
class ResolverChain(common.ResolverBase):
"""
Lookup an address using multiple L{IResolver}s
"""
def __init__(self, resolvers):
"""
@type resolvers: L{list}
@param resolvers: A L{list} of L{IResolver} providers.
"""
common.ResolverBase.__init__(self)
self.resolvers = resolvers
def _lookup(self, name, cls, type, timeout):
"""
Build a L{dns.Query} for the given parameters and dispatch it
to each L{IResolver} in C{self.resolvers} until an answer or
L{error.AuthoritativeDomainError} is returned.
@type name: C{str}
@param name: DNS name to resolve.
@type type: C{int}
@param type: DNS record type.
@type cls: C{int}
@param cls: DNS record class.
@type timeout: Sequence of C{int}
@param timeout: Number of seconds after which to reissue the query.
When the last timeout expires, the query is considered failed.
@rtype: L{Deferred}
@return: A L{Deferred} which fires with a three-tuple of lists of
L{twisted.names.dns.RRHeader} instances. The first element of the
tuple gives answers. The second element of the tuple gives
authorities. The third element of the tuple gives additional
information. The L{Deferred} may instead fail with one of the
exceptions defined in L{twisted.names.error} or with
C{NotImplementedError}.
"""
if not self.resolvers:
return defer.fail(error.DomainError())
q = dns.Query(name, type, cls)
d = self.resolvers[0].query(q, timeout)
for r in self.resolvers[1:]:
d = d.addErrback(FailureHandler(r.query, q, timeout))
return d
def lookupAllRecords(self, name, timeout=None):
# XXX: Why is this necessary? dns.ALL_RECORDS queries should
# be handled just the same as any other type by _lookup
# above. If I remove this method all names tests still
# pass. See #6604 -rwall
if not self.resolvers:
return defer.fail(error.DomainError())
d = self.resolvers[0].lookupAllRecords(name, timeout)
for r in self.resolvers[1:]:
d = d.addErrback(FailureHandler(r.lookupAllRecords, name, timeout))
return d

View File

@@ -0,0 +1,331 @@
# -*- test-case-name: twisted.names.test.test_rootresolve -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Resolver implementation for querying successive authoritative servers to
lookup a record, starting from the root nameservers.
@author: Jp Calderone
todo::
robustify it
documentation
"""
from twisted.internet import defer
from twisted.names import common, dns, error
from twisted.python.failure import Failure
class _DummyController:
"""
A do-nothing DNS controller. This is useful when all messages received
will be responses to previously issued queries. Anything else received
will be ignored.
"""
def messageReceived(self, *args):
pass
class Resolver(common.ResolverBase):
"""
L{Resolver} implements recursive lookup starting from a specified list of
root servers.
@ivar hints: See C{hints} parameter of L{__init__}
@ivar _maximumQueries: See C{maximumQueries} parameter of L{__init__}
@ivar _reactor: See C{reactor} parameter of L{__init__}
@ivar _resolverFactory: See C{resolverFactory} parameter of L{__init__}
"""
def __init__(self, hints, maximumQueries=10, reactor=None, resolverFactory=None):
"""
@param hints: A L{list} of L{str} giving the dotted quad
representation of IP addresses of root servers at which to
begin resolving names.
@type hints: L{list} of L{str}
@param maximumQueries: An optional L{int} giving the maximum
number of queries which will be attempted to resolve a
single name.
@type maximumQueries: L{int}
@param reactor: An optional L{IReactorTime} and L{IReactorUDP}
provider to use to bind UDP ports and manage timeouts.
@type reactor: L{IReactorTime} and L{IReactorUDP} provider
@param resolverFactory: An optional callable which accepts C{reactor}
and C{servers} arguments and returns an instance that provides a
C{queryUDP} method. Defaults to L{twisted.names.client.Resolver}.
@type resolverFactory: callable
"""
common.ResolverBase.__init__(self)
self.hints = hints
self._maximumQueries = maximumQueries
self._reactor = reactor
if resolverFactory is None:
from twisted.names.client import Resolver as resolverFactory
self._resolverFactory = resolverFactory
def _roots(self):
"""
Return a list of two-tuples representing the addresses of the root
servers, as defined by C{self.hints}.
"""
return [(ip, dns.PORT) for ip in self.hints]
def _query(self, query, servers, timeout, filter):
"""
Issue one query and return a L{Deferred} which fires with its response.
@param query: The query to issue.
@type query: L{dns.Query}
@param servers: The servers which might have an answer for this
query.
@type servers: L{list} of L{tuple} of L{str} and L{int}
@param timeout: A timeout on how long to wait for the response.
@type timeout: L{tuple} of L{int}
@param filter: A flag indicating whether to filter the results. If
C{True}, the returned L{Deferred} will fire with a three-tuple of
lists of L{twisted.names.dns.RRHeader} (like the return value of
the I{lookup*} methods of L{IResolver}. IF C{False}, the result
will be a L{Message} instance.
@type filter: L{bool}
@return: A L{Deferred} which fires with the response or a timeout
error.
@rtype: L{Deferred}
"""
r = self._resolverFactory(servers=servers, reactor=self._reactor)
d = r.queryUDP([query], timeout)
if filter:
d.addCallback(r.filterAnswers)
return d
def _lookup(self, name, cls, type, timeout):
"""
Implement name lookup by recursively discovering the authoritative
server for the name and then asking it, starting at one of the servers
in C{self.hints}.
"""
if timeout is None:
# A series of timeouts for semi-exponential backoff, summing to an
# arbitrary total of 60 seconds.
timeout = (1, 3, 11, 45)
return self._discoverAuthority(
dns.Query(name, type, cls), self._roots(), timeout, self._maximumQueries
)
def _discoverAuthority(self, query, servers, timeout, queriesLeft):
"""
Issue a query to a server and follow a delegation if necessary.
@param query: The query to issue.
@type query: L{dns.Query}
@param servers: The servers which might have an answer for this
query.
@type servers: L{list} of L{tuple} of L{str} and L{int}
@param timeout: A C{tuple} of C{int} giving the timeout to use for this
query.
@param queriesLeft: A C{int} giving the number of queries which may
yet be attempted to answer this query before the attempt will be
abandoned.
@return: A L{Deferred} which fires with a three-tuple of lists of
L{twisted.names.dns.RRHeader} giving the response, or with a
L{Failure} if there is a timeout or response error.
"""
# Stop now if we've hit the query limit.
if queriesLeft <= 0:
return Failure(error.ResolverError("Query limit reached without result"))
d = self._query(query, servers, timeout, False)
d.addCallback(self._discoveredAuthority, query, timeout, queriesLeft - 1)
return d
def _discoveredAuthority(self, response, query, timeout, queriesLeft):
"""
Interpret the response to a query, checking for error codes and
following delegations if necessary.
@param response: The L{Message} received in response to issuing C{query}.
@type response: L{Message}
@param query: The L{dns.Query} which was issued.
@type query: L{dns.Query}.
@param timeout: The timeout to use if another query is indicated by
this response.
@type timeout: L{tuple} of L{int}
@param queriesLeft: A C{int} giving the number of queries which may
yet be attempted to answer this query before the attempt will be
abandoned.
@return: A L{Failure} indicating a response error, a three-tuple of
lists of L{twisted.names.dns.RRHeader} giving the response to
C{query} or a L{Deferred} which will fire with one of those.
"""
if response.rCode != dns.OK:
return Failure(self.exceptionForCode(response.rCode)(response))
# Turn the answers into a structure that's a little easier to work with.
records = {}
for answer in response.answers:
records.setdefault(answer.name, []).append(answer)
def findAnswerOrCName(name, type, cls):
cname = None
for record in records.get(name, []):
if record.cls == cls:
if record.type == type:
return record
elif record.type == dns.CNAME:
cname = record
# If there were any CNAME records, return the last one. There's
# only supposed to be zero or one, though.
return cname
seen = set()
name = query.name
record = None
while True:
seen.add(name)
previous = record
record = findAnswerOrCName(name, query.type, query.cls)
if record is None:
if name == query.name:
# If there's no answer for the original name, then this may
# be a delegation. Code below handles it.
break
else:
# Try to resolve the CNAME with another query.
d = self._discoverAuthority(
dns.Query(str(name), query.type, query.cls),
self._roots(),
timeout,
queriesLeft,
)
# We also want to include the CNAME in the ultimate result,
# otherwise this will be pretty confusing.
def cbResolved(results):
answers, authority, additional = results
answers.insert(0, previous)
return (answers, authority, additional)
d.addCallback(cbResolved)
return d
elif record.type == query.type:
return (response.answers, response.authority, response.additional)
else:
# It's a CNAME record. Try to resolve it from the records
# in this response with another iteration around the loop.
if record.payload.name in seen:
raise error.ResolverError("Cycle in CNAME processing")
name = record.payload.name
# Build a map to use to convert NS names into IP addresses.
addresses = {}
for rr in response.additional:
if rr.type == dns.A:
addresses[rr.name.name] = rr.payload.dottedQuad()
hints = []
traps = []
for rr in response.authority:
if rr.type == dns.NS:
ns = rr.payload.name.name
if ns in addresses:
hints.append((addresses[ns], dns.PORT))
else:
traps.append(ns)
if hints:
return self._discoverAuthority(query, hints, timeout, queriesLeft)
elif traps:
d = self.lookupAddress(traps[0], timeout)
def getOneAddress(results):
answers, authority, additional = results
return answers[0].payload.dottedQuad()
d.addCallback(getOneAddress)
d.addCallback(
lambda hint: self._discoverAuthority(
query, [(hint, dns.PORT)], timeout, queriesLeft - 1
)
)
return d
else:
return Failure(
error.ResolverError("Stuck at response without answers or delegation")
)
def makePlaceholder(deferred, name):
def placeholder(*args, **kw):
deferred.addCallback(lambda r: getattr(r, name)(*args, **kw))
return deferred
return placeholder
class DeferredResolver:
def __init__(self, resolverDeferred):
self.waiting = []
resolverDeferred.addCallback(self.gotRealResolver)
def gotRealResolver(self, resolver):
w = self.waiting
self.__dict__ = resolver.__dict__
self.__class__ = resolver.__class__
for d in w:
d.callback(resolver)
def __getattr__(self, name):
if name.startswith("lookup") or name in ("getHostByName", "query"):
self.waiting.append(defer.Deferred())
return makePlaceholder(self.waiting[-1], name)
raise AttributeError(name)
def bootstrap(resolver, resolverFactory=None):
"""
Lookup the root nameserver addresses using the given resolver
Return a Resolver which will eventually become a C{root.Resolver}
instance that has references to all the root servers that we were able
to look up.
@param resolver: The resolver instance which will be used to
lookup the root nameserver addresses.
@type resolver: L{twisted.internet.interfaces.IResolverSimple}
@param resolverFactory: An optional callable which returns a
resolver instance. It will passed as the C{resolverFactory}
argument to L{Resolver.__init__}.
@type resolverFactory: callable
@return: A L{DeferredResolver} which will be dynamically replaced
with L{Resolver} when the root nameservers have been looked up.
"""
domains = [chr(ord("a") + i) for i in range(13)]
L = [resolver.getHostByName("%s.root-servers.net" % d) for d in domains]
d = defer.DeferredList(L, consumeErrors=True)
def buildResolver(res):
return Resolver(
hints=[e[1] for e in res if e[0]], resolverFactory=resolverFactory
)
d.addCallback(buildResolver)
return DeferredResolver(d)

View File

@@ -0,0 +1,216 @@
# -*- test-case-name: twisted.names.test.test_names -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
__all__ = ["SecondaryAuthority", "SecondaryAuthorityService"]
from twisted.application import service
from twisted.internet import defer, task
from twisted.names import client, common, dns, resolve
from twisted.names.authority import FileAuthority
from twisted.python import failure, log
from twisted.python.compat import nativeString
class SecondaryAuthorityService(service.Service):
"""
A service that keeps one or more authorities up to date by doing hourly
zone transfers from a master.
@ivar primary: IP address of the master.
@type primary: L{str}
@ivar domains: An authority for each domain mirrored from the master.
@type domains: L{list} of L{SecondaryAuthority}
"""
calls = None
_port = 53
def __init__(self, primary, domains):
"""
@param primary: The IP address of the server from which to perform
zone transfers.
@type primary: L{str}
@param domains: A sequence of domain names for which to perform
zone transfers.
@type domains: L{list} of L{bytes}
"""
self.primary = nativeString(primary)
self.domains = [SecondaryAuthority(primary, d) for d in domains]
@classmethod
def fromServerAddressAndDomains(cls, serverAddress, domains):
"""
Construct a new L{SecondaryAuthorityService} from a tuple giving a
server address and a C{str} giving the name of a domain for which this
is an authority.
@param serverAddress: A two-tuple, the first element of which is a
C{str} giving an IP address and the second element of which is a
C{int} giving a port number. Together, these define where zone
transfers will be attempted from.
@param domains: Domain names for which to perform zone transfers.
@type domains: sequence of L{bytes}
@return: A new instance of L{SecondaryAuthorityService}.
"""
primary, port = serverAddress
service = cls(primary, [])
service._port = port
service.domains = [
SecondaryAuthority.fromServerAddressAndDomain(serverAddress, d)
for d in domains
]
return service
def getAuthority(self):
"""
Get a resolver for the transferred domains.
@rtype: L{ResolverChain}
"""
return resolve.ResolverChain(self.domains)
def startService(self):
service.Service.startService(self)
self.calls = [task.LoopingCall(d.transfer) for d in self.domains]
i = 0
from twisted.internet import reactor
for c in self.calls:
# XXX Add errbacks, respect proper timeouts
reactor.callLater(i, c.start, 60 * 60)
i += 1
def stopService(self):
service.Service.stopService(self)
for c in self.calls:
c.stop()
class SecondaryAuthority(FileAuthority):
"""
An Authority that keeps itself updated by performing zone transfers.
@ivar primary: The IP address of the server from which zone transfers will
be attempted.
@type primary: L{str}
@ivar _port: The port number of the server from which zone transfers will
be attempted.
@type _port: L{int}
@ivar domain: The domain for which this is the secondary authority.
@type domain: L{bytes}
@ivar _reactor: The reactor to use to perform the zone transfers, or
L{None} to use the global reactor.
"""
transferring = False
soa = records = None
_port = 53
_reactor = None
def __init__(self, primaryIP, domain):
"""
@param domain: The domain for which this will be the secondary
authority.
@type domain: L{bytes} or L{str}
"""
# Yep. Skip over FileAuthority.__init__. This is a hack until we have
# a good composition-based API for the complicated DNS record lookup
# logic we want to share.
common.ResolverBase.__init__(self)
self.primary = nativeString(primaryIP)
self.domain = dns.domainString(domain)
@classmethod
def fromServerAddressAndDomain(cls, serverAddress, domain):
"""
Construct a new L{SecondaryAuthority} from a tuple giving a server
address and a C{bytes} giving the name of a domain for which this is an
authority.
@param serverAddress: A two-tuple, the first element of which is a
C{str} giving an IP address and the second element of which is a
C{int} giving a port number. Together, these define where zone
transfers will be attempted from.
@param domain: A C{bytes} giving the domain to transfer.
@type domain: L{bytes}
@return: A new instance of L{SecondaryAuthority}.
"""
primary, port = serverAddress
secondary = cls(primary, domain)
secondary._port = port
return secondary
def transfer(self):
"""
Attempt a zone transfer.
@returns: A L{Deferred} that fires with L{None} when attempted zone
transfer has completed.
"""
# FIXME: This logic doesn't avoid duplicate transfers
# https://twistedmatrix.com/trac/ticket/9754
if self.transferring: # <-- never true
return
self.transfering = True # <-- speling
reactor = self._reactor
if reactor is None:
from twisted.internet import reactor
resolver = client.Resolver(
servers=[(self.primary, self._port)], reactor=reactor
)
return (
resolver.lookupZone(self.domain)
.addCallback(self._cbZone)
.addErrback(self._ebZone)
)
def _lookup(self, name, cls, type, timeout=None):
if not self.soa or not self.records:
# No transfer has occurred yet. Fail non-authoritatively so that
# the caller can try elsewhere.
return defer.fail(failure.Failure(dns.DomainError(name)))
return FileAuthority._lookup(self, name, cls, type, timeout)
def _cbZone(self, zone):
ans, _, _ = zone
self.records = r = {}
for rec in ans:
if not self.soa and rec.type == dns.SOA:
self.soa = (rec.name.name.lower(), rec.payload)
else:
r.setdefault(rec.name.name.lower(), []).append(rec.payload)
def _ebZone(self, failure):
log.msg(
"Updating %s from %s failed during zone transfer"
% (self.domain, self.primary)
)
log.err(failure)
def update(self):
self.transfer().addCallbacks(self._cbTransferred, self._ebTransferred)
def _cbTransferred(self, result):
self.transferring = False
def _ebTransferred(self, failure):
self.transferred = False
log.msg(
"Transferring %s from %s failed after zone transfer"
% (self.domain, self.primary)
)
log.err(failure)

View File

@@ -0,0 +1,569 @@
# -*- test-case-name: twisted.names.test.test_names,twisted.names.test.test_server -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Async DNS server
Future plans:
- Better config file format maybe
- Make sure to differentiate between different classes
- notice truncation bit
Important: No additional processing is done on some of the record types.
This violates the most basic RFC and is just plain annoying
for resolvers to deal with. Fix it.
@author: Jp Calderone
"""
import time
from twisted.internet import protocol
from twisted.names import dns, resolve
from twisted.python import log
class DNSServerFactory(protocol.ServerFactory):
"""
Server factory and tracker for L{DNSProtocol} connections. This class also
provides records for responses to DNS queries.
@ivar cache: A L{Cache<twisted.names.cache.CacheResolver>} instance whose
C{cacheResult} method is called when a response is received from one of
C{clients}. Defaults to L{None} if no caches are specified. See
C{caches} of L{__init__} for more details.
@type cache: L{Cache<twisted.names.cache.CacheResolver>} or L{None}
@ivar canRecurse: A flag indicating whether this server is capable of
performing recursive DNS resolution.
@type canRecurse: L{bool}
@ivar resolver: A L{resolve.ResolverChain} containing an ordered list of
C{authorities}, C{caches} and C{clients} to which queries will be
dispatched.
@type resolver: L{resolve.ResolverChain}
@ivar verbose: See L{__init__}
@ivar connections: A list of all the connected L{DNSProtocol} instances
using this object as their controller.
@type connections: C{list} of L{DNSProtocol} instances
@ivar protocol: A callable used for building a DNS stream protocol. Called
by L{DNSServerFactory.buildProtocol} and passed the L{DNSServerFactory}
instance as the one and only positional argument. Defaults to
L{dns.DNSProtocol}.
@type protocol: L{IProtocolFactory} constructor
@ivar _messageFactory: A response message constructor with an initializer
signature matching L{dns.Message.__init__}.
@type _messageFactory: C{callable}
"""
# Type is wrong. See: https://twistedmatrix.com/trac/ticket/10004#ticket
protocol = dns.DNSProtocol # type: ignore[assignment]
cache = None
_messageFactory = dns.Message
def __init__(self, authorities=None, caches=None, clients=None, verbose=0):
"""
@param authorities: Resolvers which provide authoritative answers.
@type authorities: L{list} of L{IResolver} providers
@param caches: Resolvers which provide cached non-authoritative
answers. The first cache instance is assigned to
C{DNSServerFactory.cache} and its C{cacheResult} method will be
called when a response is received from one of C{clients}.
@type caches: L{list} of L{Cache<twisted.names.cache.CacheResolver>} instances
@param clients: Resolvers which are capable of performing recursive DNS
lookups.
@type clients: L{list} of L{IResolver} providers
@param verbose: An integer controlling the verbosity of logging of
queries and responses. Default is C{0} which means no logging. Set
to C{2} to enable logging of full query and response messages.
@type verbose: L{int}
"""
resolvers = []
if authorities is not None:
resolvers.extend(authorities)
if caches is not None:
resolvers.extend(caches)
if clients is not None:
resolvers.extend(clients)
self.canRecurse = not not clients
self.resolver = resolve.ResolverChain(resolvers)
self.verbose = verbose
if caches:
self.cache = caches[-1]
self.connections = []
def _verboseLog(self, *args, **kwargs):
"""
Log a message only if verbose logging is enabled.
@param args: Positional arguments which will be passed to C{log.msg}
@param kwargs: Keyword arguments which will be passed to C{log.msg}
"""
if self.verbose > 0:
log.msg(*args, **kwargs)
def buildProtocol(self, addr):
p = self.protocol(self)
p.factory = self
return p
def connectionMade(self, protocol):
"""
Track a newly connected L{DNSProtocol}.
@param protocol: The protocol instance to be tracked.
@type protocol: L{dns.DNSProtocol}
"""
self.connections.append(protocol)
def connectionLost(self, protocol):
"""
Stop tracking a no-longer connected L{DNSProtocol}.
@param protocol: The tracked protocol instance to be which has been
lost.
@type protocol: L{dns.DNSProtocol}
"""
self.connections.remove(protocol)
def sendReply(self, protocol, message, address):
"""
Send a response C{message} to a given C{address} via the supplied
C{protocol}.
Message payload will be logged if C{DNSServerFactory.verbose} is C{>1}.
@param protocol: The DNS protocol instance to which to send the message.
@type protocol: L{dns.DNSDatagramProtocol} or L{dns.DNSProtocol}
@param message: The DNS message to be sent.
@type message: L{dns.Message}
@param address: The address to which the message will be sent or L{None}
if C{protocol} is a stream protocol.
@type address: L{tuple} or L{None}
"""
if self.verbose > 1:
s = " ".join([str(a.payload) for a in message.answers])
auth = " ".join([str(a.payload) for a in message.authority])
add = " ".join([str(a.payload) for a in message.additional])
if not s:
log.msg("Replying with no answers")
else:
log.msg("Answers are " + s)
log.msg("Authority is " + auth)
log.msg("Additional is " + add)
if address is None:
protocol.writeMessage(message)
else:
protocol.writeMessage(message, address)
self._verboseLog(
"Processed query in %0.3f seconds" % (time.time() - message.timeReceived)
)
def _responseFromMessage(
self, message, rCode=dns.OK, answers=None, authority=None, additional=None
):
"""
Generate a L{Message} instance suitable for use as the response to
C{message}.
C{queries} will be copied from the request to the response.
C{rCode}, C{answers}, C{authority} and C{additional} will be assigned to
the response, if supplied.
The C{recAv} flag will be set on the response if the C{canRecurse} flag
on this L{DNSServerFactory} is set to L{True}.
The C{auth} flag will be set on the response if *any* of the supplied
C{answers} have their C{auth} flag set to L{True}.
The response will have the same C{maxSize} as the request.
Additionally, the response will have a C{timeReceived} attribute whose
value is that of the original request and the
@see: L{dns._responseFromMessage}
@param message: The request message
@type message: L{Message}
@param rCode: The response code which will be assigned to the response.
@type message: L{int}
@param answers: An optional list of answer records which will be
assigned to the response.
@type answers: L{list} of L{dns.RRHeader}
@param authority: An optional list of authority records which will be
assigned to the response.
@type authority: L{list} of L{dns.RRHeader}
@param additional: An optional list of additional records which will be
assigned to the response.
@type additional: L{list} of L{dns.RRHeader}
@return: A response L{Message} instance.
@rtype: L{Message}
"""
if answers is None:
answers = []
if authority is None:
authority = []
if additional is None:
additional = []
authoritativeAnswer = False
for x in answers:
if x.isAuthoritative():
authoritativeAnswer = True
break
response = dns._responseFromMessage(
responseConstructor=self._messageFactory,
message=message,
recAv=self.canRecurse,
rCode=rCode,
auth=authoritativeAnswer,
)
# XXX: Timereceived is a hack which probably shouldn't be tacked onto
# the message. Use getattr here so that we don't have to set the
# timereceived on every message in the tests. See #6957.
response.timeReceived = getattr(message, "timeReceived", None)
# XXX: This is another hack. dns.Message.decode sets maxSize=0 which
# means that responses are never truncated. I'll maintain that behaviour
# here until #6949 is resolved.
response.maxSize = message.maxSize
response.answers = answers
response.authority = authority
response.additional = additional
return response
def gotResolverResponse(self, response, protocol, message, address):
"""
A callback used by L{DNSServerFactory.handleQuery} for handling the
deferred response from C{self.resolver.query}.
Constructs a response message by combining the original query message
with the resolved answer, authority and additional records.
Marks the response message as authoritative if any of the resolved
answers are found to be authoritative.
The resolved answers count will be logged if C{DNSServerFactory.verbose}
is C{>1}.
@param response: Answer records, authority records and additional records
@type response: L{tuple} of L{list} of L{dns.RRHeader} instances
@param protocol: The DNS protocol instance to which to send a response
message.
@type protocol: L{dns.DNSDatagramProtocol} or L{dns.DNSProtocol}
@param message: The original DNS query message for which a response
message will be constructed.
@type message: L{dns.Message}
@param address: The address to which the response message will be sent
or L{None} if C{protocol} is a stream protocol.
@type address: L{tuple} or L{None}
"""
ans, auth, add = response
response = self._responseFromMessage(
message=message, rCode=dns.OK, answers=ans, authority=auth, additional=add
)
self.sendReply(protocol, response, address)
l = len(ans) + len(auth) + len(add)
self._verboseLog("Lookup found %d record%s" % (l, l != 1 and "s" or ""))
if self.cache and l:
self.cache.cacheResult(message.queries[0], (ans, auth, add))
def gotResolverError(self, failure, protocol, message, address):
"""
A callback used by L{DNSServerFactory.handleQuery} for handling deferred
errors from C{self.resolver.query}.
Constructs a response message from the original query message by
assigning a suitable error code to C{rCode}.
An error message will be logged if C{DNSServerFactory.verbose} is C{>1}.
@param failure: The reason for the failed resolution (as reported by
C{self.resolver.query}).
@type failure: L{Failure<twisted.python.failure.Failure>}
@param protocol: The DNS protocol instance to which to send a response
message.
@type protocol: L{dns.DNSDatagramProtocol} or L{dns.DNSProtocol}
@param message: The original DNS query message for which a response
message will be constructed.
@type message: L{dns.Message}
@param address: The address to which the response message will be sent
or L{None} if C{protocol} is a stream protocol.
@type address: L{tuple} or L{None}
"""
if failure.check(dns.DomainError, dns.AuthoritativeDomainError):
rCode = dns.ENAME
else:
rCode = dns.ESERVER
log.err(failure)
response = self._responseFromMessage(message=message, rCode=rCode)
self.sendReply(protocol, response, address)
self._verboseLog("Lookup failed")
def handleQuery(self, message, protocol, address):
"""
Called by L{DNSServerFactory.messageReceived} when a query message is
received.
Takes the first query from the received message and dispatches it to
C{self.resolver.query}.
Adds callbacks L{DNSServerFactory.gotResolverResponse} and
L{DNSServerFactory.gotResolverError} to the resulting deferred.
Note: Multiple queries in a single message are not supported because
there is no standard way to respond with multiple rCodes, auth,
etc. This is consistent with other DNS server implementations. See
U{http://tools.ietf.org/html/draft-ietf-dnsext-edns1-03} for a proposed
solution.
@param protocol: The DNS protocol instance to which to send a response
message.
@type protocol: L{dns.DNSDatagramProtocol} or L{dns.DNSProtocol}
@param message: The original DNS query message for which a response
message will be constructed.
@type message: L{dns.Message}
@param address: The address to which the response message will be sent
or L{None} if C{protocol} is a stream protocol.
@type address: L{tuple} or L{None}
@return: A C{deferred} which fires with the resolved result or error of
the first query in C{message}.
@rtype: L{Deferred<twisted.internet.defer.Deferred>}
"""
query = message.queries[0]
return (
self.resolver.query(query)
.addCallback(self.gotResolverResponse, protocol, message, address)
.addErrback(self.gotResolverError, protocol, message, address)
)
def handleInverseQuery(self, message, protocol, address):
"""
Called by L{DNSServerFactory.messageReceived} when an inverse query
message is received.
Replies with a I{Not Implemented} error by default.
An error message will be logged if C{DNSServerFactory.verbose} is C{>1}.
Override in a subclass.
@param protocol: The DNS protocol instance to which to send a response
message.
@type protocol: L{dns.DNSDatagramProtocol} or L{dns.DNSProtocol}
@param message: The original DNS query message for which a response
message will be constructed.
@type message: L{dns.Message}
@param address: The address to which the response message will be sent
or L{None} if C{protocol} is a stream protocol.
@type address: L{tuple} or L{None}
"""
message.rCode = dns.ENOTIMP
self.sendReply(protocol, message, address)
self._verboseLog(f"Inverse query from {address!r}")
def handleStatus(self, message, protocol, address):
"""
Called by L{DNSServerFactory.messageReceived} when a status message is
received.
Replies with a I{Not Implemented} error by default.
An error message will be logged if C{DNSServerFactory.verbose} is C{>1}.
Override in a subclass.
@param protocol: The DNS protocol instance to which to send a response
message.
@type protocol: L{dns.DNSDatagramProtocol} or L{dns.DNSProtocol}
@param message: The original DNS query message for which a response
message will be constructed.
@type message: L{dns.Message}
@param address: The address to which the response message will be sent
or L{None} if C{protocol} is a stream protocol.
@type address: L{tuple} or L{None}
"""
message.rCode = dns.ENOTIMP
self.sendReply(protocol, message, address)
self._verboseLog(f"Status request from {address!r}")
def handleNotify(self, message, protocol, address):
"""
Called by L{DNSServerFactory.messageReceived} when a notify message is
received.
Replies with a I{Not Implemented} error by default.
An error message will be logged if C{DNSServerFactory.verbose} is C{>1}.
Override in a subclass.
@param protocol: The DNS protocol instance to which to send a response
message.
@type protocol: L{dns.DNSDatagramProtocol} or L{dns.DNSProtocol}
@param message: The original DNS query message for which a response
message will be constructed.
@type message: L{dns.Message}
@param address: The address to which the response message will be sent
or L{None} if C{protocol} is a stream protocol.
@type address: L{tuple} or L{None}
"""
message.rCode = dns.ENOTIMP
self.sendReply(protocol, message, address)
self._verboseLog(f"Notify message from {address!r}")
def handleOther(self, message, protocol, address):
"""
Called by L{DNSServerFactory.messageReceived} when a message with
unrecognised I{OPCODE} is received.
Replies with a I{Not Implemented} error by default.
An error message will be logged if C{DNSServerFactory.verbose} is C{>1}.
Override in a subclass.
@param protocol: The DNS protocol instance to which to send a response
message.
@type protocol: L{dns.DNSDatagramProtocol} or L{dns.DNSProtocol}
@param message: The original DNS query message for which a response
message will be constructed.
@type message: L{dns.Message}
@param address: The address to which the response message will be sent
or L{None} if C{protocol} is a stream protocol.
@type address: L{tuple} or L{None}
"""
message.rCode = dns.ENOTIMP
self.sendReply(protocol, message, address)
self._verboseLog("Unknown op code (%d) from %r" % (message.opCode, address))
def messageReceived(self, message, proto, address=None):
"""
L{DNSServerFactory.messageReceived} is called by protocols which are
under the control of this L{DNSServerFactory} whenever they receive a
DNS query message or an unexpected / duplicate / late DNS response
message.
L{DNSServerFactory.allowQuery} is called with the received message,
protocol and origin address. If it returns L{False}, a C{dns.EREFUSED}
response is sent back to the client.
Otherwise the received message is dispatched to one of
L{DNSServerFactory.handleQuery}, L{DNSServerFactory.handleInverseQuery},
L{DNSServerFactory.handleStatus}, L{DNSServerFactory.handleNotify}, or
L{DNSServerFactory.handleOther} depending on the I{OPCODE} of the
received message.
If C{DNSServerFactory.verbose} is C{>0} all received messages will be
logged in more or less detail depending on the value of C{verbose}.
@param message: The DNS message that was received.
@type message: L{dns.Message}
@param proto: The DNS protocol instance which received the message
@type proto: L{dns.DNSDatagramProtocol} or L{dns.DNSProtocol}
@param address: The address from which the message was received. Only
provided for messages received by datagram protocols. The origin of
Messages received from stream protocols can be gleaned from the
protocol C{transport} attribute.
@type address: L{tuple} or L{None}
"""
message.timeReceived = time.time()
if self.verbose:
if self.verbose > 1:
s = " ".join([str(q) for q in message.queries])
else:
s = " ".join(
[dns.QUERY_TYPES.get(q.type, "UNKNOWN") for q in message.queries]
)
if not len(s):
log.msg(f"Empty query from {address or proto.transport.getPeer()!r}")
else:
log.msg(f"{s} query from {address or proto.transport.getPeer()!r}")
if not self.allowQuery(message, proto, address):
message.rCode = dns.EREFUSED
self.sendReply(proto, message, address)
elif message.opCode == dns.OP_QUERY:
self.handleQuery(message, proto, address)
elif message.opCode == dns.OP_INVERSE:
self.handleInverseQuery(message, proto, address)
elif message.opCode == dns.OP_STATUS:
self.handleStatus(message, proto, address)
elif message.opCode == dns.OP_NOTIFY:
self.handleNotify(message, proto, address)
else:
self.handleOther(message, proto, address)
def allowQuery(self, message, protocol, address):
"""
Called by L{DNSServerFactory.messageReceived} to decide whether to
process a received message or to reply with C{dns.EREFUSED}.
This default implementation permits anything but empty queries.
Override in a subclass to implement alternative policies.
@param message: The DNS message that was received.
@type message: L{dns.Message}
@param protocol: The DNS protocol instance which received the message
@type protocol: L{dns.DNSDatagramProtocol} or L{dns.DNSProtocol}
@param address: The address from which the message was received. Only
provided for messages received by datagram protocols. The origin of
Messages received from stream protocols can be gleaned from the
protocol C{transport} attribute.
@type address: L{tuple} or L{None}
@return: L{True} if the received message contained one or more queries,
else L{False}.
@rtype: L{bool}
"""
return len(message.queries)

View File

@@ -0,0 +1,271 @@
# -*- test-case-name: twisted.names.test.test_srvconnect -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import random
from zope.interface import implementer
from twisted.internet import error, interfaces
from twisted.names import client, dns
from twisted.names.error import DNSNameError
from twisted.python.compat import nativeString
class _SRVConnector_ClientFactoryWrapper:
def __init__(self, connector, wrappedFactory):
self.__connector = connector
self.__wrappedFactory = wrappedFactory
def startedConnecting(self, connector):
self.__wrappedFactory.startedConnecting(self.__connector)
def clientConnectionFailed(self, connector, reason):
self.__connector.connectionFailed(reason)
def clientConnectionLost(self, connector, reason):
self.__connector.connectionLost(reason)
def __getattr__(self, key):
return getattr(self.__wrappedFactory, key)
@implementer(interfaces.IConnector)
class SRVConnector:
"""
A connector that looks up DNS SRV records.
RFC 2782 details how SRV records should be interpreted and selected
for subsequent connection attempts. The algorithm for using the records'
priority and weight is implemented in L{pickServer}.
@ivar servers: List of candidate server records for future connection
attempts.
@type servers: L{list} of L{dns.Record_SRV}
@ivar orderedServers: List of server records that have already been tried
in this round of connection attempts.
@type orderedServers: L{list} of L{dns.Record_SRV}
"""
stopAfterDNS = 0
def __init__(
self,
reactor,
service,
domain,
factory,
protocol="tcp",
connectFuncName="connectTCP",
connectFuncArgs=(),
connectFuncKwArgs={},
defaultPort=None,
):
"""
@param domain: The domain to connect to. If passed as a text
string, it will be encoded using C{idna} encoding.
@type domain: L{bytes} or L{str}
@param defaultPort: Optional default port number to be used when SRV
lookup fails and the service name is unknown. This should be the
port number associated with the service name as defined by the IANA
registry.
@type defaultPort: L{int}
"""
self.reactor = reactor
self.service = service
self.domain = None if domain is None else dns.domainString(domain)
self.factory = factory
self.protocol = protocol
self.connectFuncName = connectFuncName
self.connectFuncArgs = connectFuncArgs
self.connectFuncKwArgs = connectFuncKwArgs
self._defaultPort = defaultPort
self.connector = None
self.servers = None
# list of servers already used in this round:
self.orderedServers = None
def connect(self):
"""Start connection to remote server."""
self.factory.doStart()
self.factory.startedConnecting(self)
if not self.servers:
if self.domain is None:
self.connectionFailed(
error.DNSLookupError("Domain is not defined."),
)
return
d = client.lookupService(
"_%s._%s.%s"
% (
nativeString(self.service),
nativeString(self.protocol),
nativeString(self.domain),
),
)
d.addCallbacks(self._cbGotServers, self._ebGotServers)
d.addCallback(lambda x, self=self: self._reallyConnect())
if self._defaultPort:
d.addErrback(self._ebServiceUnknown)
d.addErrback(self.connectionFailed)
elif self.connector is None:
self._reallyConnect()
else:
self.connector.connect()
def _ebGotServers(self, failure):
failure.trap(DNSNameError)
# Some DNS servers reply with NXDOMAIN when in fact there are
# just no SRV records for that domain. Act as if we just got an
# empty response and use fallback.
self.servers = []
self.orderedServers = []
def _cbGotServers(self, result):
answers, auth, add = result
if (
len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b".")
):
# decidedly not available
raise error.DNSLookupError(
"Service %s not available for domain %s."
% (repr(self.service), repr(self.domain))
)
self.servers = []
self.orderedServers = []
for a in answers:
if a.type != dns.SRV or not a.payload:
continue
self.orderedServers.append(a.payload)
def _ebServiceUnknown(self, failure):
"""
Connect to the default port when the service name is unknown.
If no SRV records were found, the service name will be passed as the
port. If resolving the name fails with
L{error.ServiceNameUnknownError}, a final attempt is done using the
default port.
"""
failure.trap(error.ServiceNameUnknownError)
self.servers = [dns.Record_SRV(0, 0, self._defaultPort, self.domain)]
self.orderedServers = []
self.connect()
def pickServer(self):
"""
Pick the next server.
This selects the next server from the list of SRV records according
to their priority and weight values, as set out by the default
algorithm specified in RFC 2782.
At the beginning of a round, L{servers} is populated with
L{orderedServers}, and the latter is made empty. L{servers}
is the list of candidates, and L{orderedServers} is the list of servers
that have already been tried.
First, all records are ordered by priority and weight in ascending
order. Then for each priority level, a running sum is calculated
over the sorted list of records for that priority. Then a random value
between 0 and the final sum is compared to each record in order. The
first record that is greater than or equal to that random value is
chosen and removed from the list of candidates for this round.
@return: A tuple of target hostname and port from the chosen DNS SRV
record.
@rtype: L{tuple} of native L{str} and L{int}
"""
assert self.servers is not None
assert self.orderedServers is not None
if not self.servers and not self.orderedServers:
# no SRV record, fall back..
return nativeString(self.domain), self.service
if not self.servers and self.orderedServers:
# start new round
self.servers = self.orderedServers
self.orderedServers = []
assert self.servers
self.servers.sort(key=lambda record: (record.priority, record.weight))
minPriority = self.servers[0].priority
index = 0
weightSum = 0
weightIndex = []
for x in self.servers:
if x.priority == minPriority:
weightSum += x.weight
weightIndex.append((index, weightSum))
index += 1
rand = random.randint(0, weightSum)
for index, weight in weightIndex:
if weight >= rand:
chosen = self.servers[index]
del self.servers[index]
self.orderedServers.append(chosen)
return str(chosen.target), chosen.port
raise RuntimeError(f"Impossible {self.__class__.__name__} pickServer result.")
def _reallyConnect(self):
if self.stopAfterDNS:
self.stopAfterDNS = 0
return
self.host, self.port = self.pickServer()
assert self.host is not None, "Must have a host to connect to."
assert self.port is not None, "Must have a port to connect to."
connectFunc = getattr(self.reactor, self.connectFuncName)
self.connector = connectFunc(
self.host,
self.port,
_SRVConnector_ClientFactoryWrapper(self, self.factory),
*self.connectFuncArgs,
**self.connectFuncKwArgs,
)
def stopConnecting(self):
"""Stop attempting to connect."""
if self.connector:
self.connector.stopConnecting()
else:
self.stopAfterDNS = 1
def disconnect(self):
"""Disconnect whatever our are state is."""
if self.connector is not None:
self.connector.disconnect()
else:
self.stopConnecting()
def getDestination(self):
assert self.connector
return self.connector.getDestination()
def connectionFailed(self, reason):
self.factory.clientConnectionFailed(self, reason)
self.factory.doStop()
def connectionLost(self, reason):
self.factory.clientConnectionLost(self, reason)
self.factory.doStop()

View File

@@ -0,0 +1,149 @@
# -*- test-case-name: twisted.names.test.test_tap -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Domain Name Server
"""
import os
import traceback
from twisted.application import internet, service
from twisted.names import authority, dns, secondary, server
from twisted.python import usage
class Options(usage.Options):
optParameters = [
["interface", "i", "", "The interface to which to bind"],
["port", "p", "53", "The port on which to listen"],
[
"resolv-conf",
None,
None,
"Override location of resolv.conf (implies --recursive)",
],
["hosts-file", None, None, "Perform lookups with a hosts file"],
]
optFlags = [
["cache", "c", "Enable record caching"],
["recursive", "r", "Perform recursive lookups"],
["verbose", "v", "Log verbosely"],
]
compData = usage.Completions(
optActions={"interface": usage.CompleteNetInterfaces()}
)
zones = None
zonefiles = None
def __init__(self):
usage.Options.__init__(self)
self["verbose"] = 0
self.bindfiles = []
self.zonefiles = []
self.secondaries = []
def opt_pyzone(self, filename):
"""Specify the filename of a Python syntax zone definition"""
if not os.path.exists(filename):
raise usage.UsageError(filename + ": No such file")
self.zonefiles.append(filename)
def opt_bindzone(self, filename):
"""Specify the filename of a BIND9 syntax zone definition"""
if not os.path.exists(filename):
raise usage.UsageError(filename + ": No such file")
self.bindfiles.append(filename)
def opt_secondary(self, ip_domain):
"""Act as secondary for the specified domain, performing
zone transfers from the specified IP (IP/domain)
"""
args = ip_domain.split("/", 1)
if len(args) != 2:
raise usage.UsageError("Argument must be of the form IP[:port]/domain")
address = args[0].split(":")
if len(address) == 1:
address = (address[0], dns.PORT)
else:
try:
port = int(address[1])
except ValueError:
raise usage.UsageError(
f"Specify an integer port number, not {address[1]!r}"
)
address = (address[0], port)
self.secondaries.append((address, [args[1]]))
def opt_verbose(self):
"""Increment verbosity level"""
self["verbose"] += 1
def postOptions(self):
if self["resolv-conf"]:
self["recursive"] = True
self.svcs = []
self.zones = []
for f in self.zonefiles:
try:
self.zones.append(authority.PySourceAuthority(f))
except Exception:
traceback.print_exc()
raise usage.UsageError("Invalid syntax in " + f)
for f in self.bindfiles:
try:
self.zones.append(authority.BindAuthority(f))
except Exception:
traceback.print_exc()
raise usage.UsageError("Invalid syntax in " + f)
for f in self.secondaries:
svc = secondary.SecondaryAuthorityService.fromServerAddressAndDomains(*f)
self.svcs.append(svc)
self.zones.append(self.svcs[-1].getAuthority())
try:
self["port"] = int(self["port"])
except ValueError:
raise usage.UsageError("Invalid port: {!r}".format(self["port"]))
def _buildResolvers(config):
"""
Build DNS resolver instances in an order which leaves recursive
resolving as a last resort.
@type config: L{Options} instance
@param config: Parsed command-line configuration
@return: Two-item tuple of a list of cache resovers and a list of client
resolvers
"""
from twisted.names import cache, client, hosts
ca, cl = [], []
if config["cache"]:
ca.append(cache.CacheResolver(verbose=config["verbose"]))
if config["hosts-file"]:
cl.append(hosts.Resolver(file=config["hosts-file"]))
if config["recursive"]:
cl.append(client.createResolver(resolvconf=config["resolv-conf"]))
return ca, cl
def makeService(config):
ca, cl = _buildResolvers(config)
f = server.DNSServerFactory(config.zones, ca, cl, config["verbose"])
p = dns.DNSDatagramProtocol(f)
f.noisy = 0
ret = service.MultiService()
for klass, arg in [(internet.TCPServer, f), (internet.UDPServer, p)]:
s = klass(config["port"], arg, interface=config["interface"])
s.setServiceParent(ret)
for svc in config.svcs:
svc.setServiceParent(ret)
return ret

View File

@@ -0,0 +1 @@
"Tests for twisted.names"

View File

@@ -0,0 +1,190 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.names.cache}.
"""
import time
from zope.interface.verify import verifyClass
from twisted.internet import interfaces, task
from twisted.names import cache, dns
from twisted.trial import unittest
class CachingTests(unittest.TestCase):
"""
Tests for L{cache.CacheResolver}.
"""
def test_interface(self):
"""
L{cache.CacheResolver} implements L{interfaces.IResolver}
"""
verifyClass(interfaces.IResolver, cache.CacheResolver)
def test_lookup(self):
c = cache.CacheResolver(
{
dns.Query(name=b"example.com", type=dns.MX, cls=dns.IN): (
time.time(),
([], [], []),
)
}
)
return c.lookupMailExchange(b"example.com").addCallback(
self.assertEqual, ([], [], [])
)
def test_constructorExpires(self):
"""
Cache entries passed into L{cache.CacheResolver.__init__} get
cancelled just like entries added with cacheResult
"""
r = (
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 60, dns.Record_A("127.0.0.1", 60)
)
],
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 50, dns.Record_A("127.0.0.1", 50)
)
],
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 40, dns.Record_A("127.0.0.1", 40)
)
],
)
clock = task.Clock()
query = dns.Query(name=b"example.com", type=dns.A, cls=dns.IN)
c = cache.CacheResolver({query: (clock.seconds(), r)}, reactor=clock)
# 40 seconds is enough to expire the entry because expiration is based
# on the minimum TTL.
clock.advance(40)
self.assertNotIn(query, c.cache)
return self.assertFailure(c.lookupAddress(b"example.com"), dns.DomainError)
def test_normalLookup(self):
"""
When a cache lookup finds a cached entry from 1 second ago, it is
returned with a TTL of original TTL minus the elapsed 1 second.
"""
r = (
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 60, dns.Record_A("127.0.0.1", 60)
)
],
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 50, dns.Record_A("127.0.0.1", 50)
)
],
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 40, dns.Record_A("127.0.0.1", 40)
)
],
)
clock = task.Clock()
c = cache.CacheResolver(reactor=clock)
c.cacheResult(dns.Query(name=b"example.com", type=dns.A, cls=dns.IN), r)
clock.advance(1)
def cbLookup(result):
self.assertEqual(result[0][0].ttl, 59)
self.assertEqual(result[1][0].ttl, 49)
self.assertEqual(result[2][0].ttl, 39)
self.assertEqual(result[0][0].name.name, b"example.com")
return c.lookupAddress(b"example.com").addCallback(cbLookup)
def test_cachedResultExpires(self):
"""
Once the TTL has been exceeded, the result is removed from the cache.
"""
r = (
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 60, dns.Record_A("127.0.0.1", 60)
)
],
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 50, dns.Record_A("127.0.0.1", 50)
)
],
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 40, dns.Record_A("127.0.0.1", 40)
)
],
)
clock = task.Clock()
c = cache.CacheResolver(reactor=clock)
query = dns.Query(name=b"example.com", type=dns.A, cls=dns.IN)
c.cacheResult(query, r)
clock.advance(40)
self.assertNotIn(query, c.cache)
return self.assertFailure(c.lookupAddress(b"example.com"), dns.DomainError)
def test_expiredTTLLookup(self):
"""
When the cache is queried exactly as the cached entry should expire but
before it has actually been cleared, the cache does not return the
expired entry.
"""
r = (
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 60, dns.Record_A("127.0.0.1", 60)
)
],
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 50, dns.Record_A("127.0.0.1", 50)
)
],
[
dns.RRHeader(
b"example.com", dns.A, dns.IN, 40, dns.Record_A("127.0.0.1", 40)
)
],
)
clock = task.Clock()
# Make sure timeouts never happen, so entries won't get cleared:
clock.callLater = lambda *args, **kwargs: None
c = cache.CacheResolver(
{
dns.Query(name=b"example.com", type=dns.A, cls=dns.IN): (
clock.seconds(),
r,
)
},
reactor=clock,
)
clock.advance(60.1)
return self.assertFailure(c.lookupAddress(b"example.com"), dns.DomainError)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,129 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.names.common}.
"""
from zope.interface.verify import verifyClass
from twisted.internet.interfaces import IResolver
from twisted.names.common import ResolverBase
from twisted.names.dns import EFORMAT, ENAME, ENOTIMP, EREFUSED, ESERVER, Query
from twisted.names.error import (
DNSFormatError,
DNSNameError,
DNSNotImplementedError,
DNSQueryRefusedError,
DNSServerError,
DNSUnknownError,
)
from twisted.python.failure import Failure
from twisted.trial.unittest import SynchronousTestCase
class ExceptionForCodeTests(SynchronousTestCase):
"""
Tests for L{ResolverBase.exceptionForCode}.
"""
def setUp(self):
self.exceptionForCode = ResolverBase().exceptionForCode
def test_eformat(self):
"""
L{ResolverBase.exceptionForCode} converts L{EFORMAT} to
L{DNSFormatError}.
"""
self.assertIs(self.exceptionForCode(EFORMAT), DNSFormatError)
def test_eserver(self):
"""
L{ResolverBase.exceptionForCode} converts L{ESERVER} to
L{DNSServerError}.
"""
self.assertIs(self.exceptionForCode(ESERVER), DNSServerError)
def test_ename(self):
"""
L{ResolverBase.exceptionForCode} converts L{ENAME} to L{DNSNameError}.
"""
self.assertIs(self.exceptionForCode(ENAME), DNSNameError)
def test_enotimp(self):
"""
L{ResolverBase.exceptionForCode} converts L{ENOTIMP} to
L{DNSNotImplementedError}.
"""
self.assertIs(self.exceptionForCode(ENOTIMP), DNSNotImplementedError)
def test_erefused(self):
"""
L{ResolverBase.exceptionForCode} converts L{EREFUSED} to
L{DNSQueryRefusedError}.
"""
self.assertIs(self.exceptionForCode(EREFUSED), DNSQueryRefusedError)
def test_other(self):
"""
L{ResolverBase.exceptionForCode} converts any other response code to
L{DNSUnknownError}.
"""
self.assertIs(self.exceptionForCode(object()), DNSUnknownError)
class QueryTests(SynchronousTestCase):
"""
Tests for L{ResolverBase.query}.
"""
def test_resolverBaseProvidesIResolver(self):
"""
L{ResolverBase} provides the L{IResolver} interface.
"""
verifyClass(IResolver, ResolverBase)
def test_typeToMethodDispatch(self):
"""
L{ResolverBase.query} looks up a method to invoke using the type of the
query passed to it and the C{typeToMethod} mapping on itself.
"""
results = []
resolver = ResolverBase()
resolver.typeToMethod = {
12345: lambda query, timeout: results.append((query, timeout))
}
query = Query(name=b"example.com", type=12345)
resolver.query(query, 123)
self.assertEqual([(b"example.com", 123)], results)
def test_typeToMethodResult(self):
"""
L{ResolverBase.query} returns a L{Deferred} which fires with the result
of the method found in the C{typeToMethod} mapping for the type of the
query passed to it.
"""
expected = object()
resolver = ResolverBase()
resolver.typeToMethod = {54321: lambda query, timeout: expected}
query = Query(name=b"example.com", type=54321)
queryDeferred = resolver.query(query, 123)
result = []
queryDeferred.addBoth(result.append)
self.assertEqual(expected, result[0])
def test_unknownQueryType(self):
"""
L{ResolverBase.query} returns a L{Deferred} which fails with
L{NotImplementedError} when called with a query of a type not present in
its C{typeToMethod} dictionary.
"""
resolver = ResolverBase()
resolver.typeToMethod = {}
query = Query(name=b"example.com", type=12345)
queryDeferred = resolver.query(query, 123)
result = []
queryDeferred.addBoth(result.append)
self.assertIsInstance(result[0], Failure)
result[0].trap(NotImplementedError)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,161 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.names} example scripts.
"""
import os
import sys
from io import StringIO
from twisted.python.filepath import FilePath
from twisted.trial.unittest import SkipTest, TestCase
class ExampleTestBase:
"""
This is a mixin which adds an example to the path, tests it, and then
removes it from the path and unimports the modules which the test loaded.
Test cases which test example code and documentation listings should use
this.
This is done this way so that examples can live in isolated path entries,
next to the documentation, replete with their own plugin packages and
whatever other metadata they need. Also, example code is a rare instance
of it being valid to have multiple versions of the same code in the
repository at once, rather than relying on version control, because
documentation will often show the progression of a single piece of code as
features are added to it, and we want to test each one.
"""
def setUp(self):
"""
Add our example directory to the path and record which modules are
currently loaded.
"""
self.originalPath = sys.path[:]
self.originalModules = sys.modules.copy()
# Python usually expects native strs to be written to sys.stdout/stderr
self.fakeErr = StringIO()
self.patch(sys, "stderr", self.fakeErr)
self.fakeOut = StringIO()
self.patch(sys, "stdout", self.fakeOut)
# Get documentation root
try:
here = FilePath(os***REMOVED***iron["TOX_INI_DIR"]).child("docs")
except KeyError:
raise SkipTest(
"Examples not found ($TOX_INI_DIR unset) - cannot test",
)
# Find the example script within this branch
for childName in self.exampleRelativePath.split("/"):
here = here.child(childName)
if not here.exists():
raise SkipTest(f"Examples ({here.path}) not found - cannot test")
self.examplePath = here
# Add the example parent folder to the Python path
sys.path.append(self.examplePath.parent().path)
# Import the example as a module
moduleName = self.examplePath.basename().split(".")[0]
self.example = __import__(moduleName)
def tearDown(self):
"""
Remove the example directory from the path and remove all
modules loaded by the test from sys.modules.
"""
sys.modules.clear()
sys.modules.update(self.originalModules)
sys.path[:] = self.originalPath
def test_shebang(self):
"""
The example scripts start with the standard shebang line.
"""
with self.examplePath.open() as f:
self.assertEqual(f.readline().rstrip(), b"#!/usr/bin/env python")
def test_usageConsistency(self):
"""
The example script prints a usage message to stdout if it is
passed a --help option and then exits.
The first line should contain a USAGE summary, explaining the
accepted command arguments.
"""
# Pass None as first parameter - the reactor - it shouldn't
# get as far as calling it.
self.assertRaises(SystemExit, self.example.main, None, "--help")
out = self.fakeOut.getvalue().splitlines()
self.assertTrue(
out[0].startswith("Usage:"),
'Usage message first line should start with "Usage:". '
"Actual: %r" % (out[0],),
)
def test_usageConsistencyOnError(self):
"""
The example script prints a usage message to stderr if it is
passed unrecognized command line arguments.
The first line should contain a USAGE summary, explaining the
accepted command arguments.
The last line should contain an ERROR summary, explaining that
incorrect arguments were supplied.
"""
# Pass None as first parameter - the reactor - it shouldn't
# get as far as calling it.
self.assertRaises(SystemExit, self.example.main, None, "--unexpected_argument")
err = self.fakeErr.getvalue().splitlines()
self.assertTrue(
err[0].startswith("Usage:"),
'Usage message first line should start with "Usage:". '
"Actual: %r" % (err[0],),
)
self.assertTrue(
err[-1].startswith("ERROR:"),
'Usage message last line should start with "ERROR:" '
"Actual: %r" % (err[-1],),
)
class TestDnsTests(ExampleTestBase, TestCase):
"""
Test the testdns.py example script.
"""
exampleRelativePath = "names/examples/testdns.py"
class GetHostByNameTests(ExampleTestBase, TestCase):
"""
Test the gethostbyname.py example script.
"""
exampleRelativePath = "names/examples/gethostbyname.py"
class DnsServiceTests(ExampleTestBase, TestCase):
"""
Test the dns-service.py example script.
"""
exampleRelativePath = "names/examples/dns-service.py"
class MultiReverseLookupTests(ExampleTestBase, TestCase):
"""
Test the multi_reverse_lookup.py example script.
"""
exampleRelativePath = "names/examples/multi_reverse_lookup.py"

View File

@@ -0,0 +1,305 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for the I{hosts(5)}-based resolver, L{twisted.names.hosts}.
"""
from __future__ import annotations
from typing_extensions import Protocol
from twisted.internet.defer import gatherResults
from twisted.names.dns import (
AAAA,
IN,
A,
DomainError,
Query,
Record_A,
Record_AAAA,
RRHeader,
)
from twisted.names.hosts import Resolver, searchFileFor, searchFileForAll
from twisted.python.filepath import FilePath
from twisted.trial.unittest import SynchronousTestCase
class _SupportsMktemp(Protocol):
def mktemp(self) -> str:
...
class GoodTempPathMixin:
def path(self: _SupportsMktemp) -> FilePath[bytes]:
return FilePath(self.mktemp().encode("utf-8"))
class SearchHostsFileTests(SynchronousTestCase, GoodTempPathMixin):
"""
Tests for L{searchFileFor}, a helper which finds the first address for a
particular hostname in a I{hosts(5)}-style file.
"""
def test_findAddress(self) -> None:
"""
If there is an IPv4 address for the hostname passed to L{searchFileFor},
it is returned.
"""
hosts = self.path()
hosts.setContent(b"10.2.3.4 foo.example.com\n")
self.assertEqual("10.2.3.4", searchFileFor(hosts.path, b"foo.example.com"))
def test_notFoundAddress(self) -> None:
"""
If there is no address information for the hostname passed to
L{searchFileFor}, L{None} is returned.
"""
hosts = self.path()
hosts.setContent(b"10.2.3.4 foo.example.com\n")
self.assertIsNone(searchFileFor(hosts.path, b"bar.example.com"))
def test_firstAddress(self) -> None:
"""
The first address associated with the given hostname is returned.
"""
hosts = self.path()
hosts.setContent(
b"::1 foo.example.com\n"
b"10.1.2.3 foo.example.com\n"
b"fe80::21b:fcff:feee:5a1d foo.example.com\n"
)
self.assertEqual("::1", searchFileFor(hosts.path, b"foo.example.com"))
def test_searchFileForAliases(self) -> None:
"""
For a host with a canonical name and one or more aliases,
L{searchFileFor} can find an address given any of the names.
"""
hosts = self.path()
hosts.setContent(
b"127.0.1.1\thelmut.example.org\thelmut\n"
b"# a comment\n"
b"::1 localhost ip6-localhost ip6-loopback\n"
)
self.assertEqual(searchFileFor(hosts.path, b"helmut"), "127.0.1.1")
self.assertEqual(searchFileFor(hosts.path, b"helmut.example.org"), "127.0.1.1")
self.assertEqual(searchFileFor(hosts.path, b"ip6-localhost"), "::1")
self.assertEqual(searchFileFor(hosts.path, b"ip6-loopback"), "::1")
self.assertEqual(searchFileFor(hosts.path, b"localhost"), "::1")
class SearchHostsFileForAllTests(SynchronousTestCase, GoodTempPathMixin):
"""
Tests for L{searchFileForAll}, a helper which finds all addresses for a
particular hostname in a I{hosts(5)}-style file.
"""
def test_allAddresses(self) -> None:
"""
L{searchFileForAll} returns a list of all addresses associated with the
name passed to it.
"""
hosts = self.path()
hosts.setContent(
b"127.0.0.1 foobar.example.com\n"
b"127.0.0.2 foobar.example.com\n"
b"::1 foobar.example.com\n"
)
self.assertEqual(
["127.0.0.1", "127.0.0.2", "::1"],
searchFileForAll(hosts, b"foobar.example.com"),
)
def test_caseInsensitively(self) -> None:
"""
L{searchFileForAll} searches for names case-insensitively.
"""
hosts = self.path()
hosts.setContent(b"127.0.0.1 foobar.EXAMPLE.com\n")
self.assertEqual(["127.0.0.1"], searchFileForAll(hosts, b"FOOBAR.example.com"))
def test_readError(self) -> None:
"""
If there is an error reading the contents of the hosts file,
L{searchFileForAll} returns an empty list.
"""
self.assertEqual([], searchFileForAll(self.path(), b"example.com"))
def test_malformedIP(self) -> None:
"""
L{searchFileForAll} ignores any malformed IP addresses associated with
the name passed to it.
"""
hosts = self.path()
hosts.setContent(
b"127.0.0.1\tmiser.example.org\tmiser\n"
b"not-an-ip\tmiser\n"
b"\xffnot-ascii\t miser\n"
b"# miser\n"
b"miser\n"
b"::1 miser"
)
self.assertEqual(
["127.0.0.1", "::1"],
searchFileForAll(hosts, b"miser"),
)
class HostsTests(SynchronousTestCase, GoodTempPathMixin):
"""
Tests for the I{hosts(5)}-based L{twisted.names.hosts.Resolver}.
"""
def setUp(self) -> None:
f = self.path()
f.setContent(
b"""
1.1.1.1 EXAMPLE EXAMPLE.EXAMPLETHING
::2 mixed
1.1.1.2 MIXED
::1 ip6thingy
1.1.1.3 multiple
1.1.1.4 multiple
::3 ip6-multiple
::4 ip6-multiple
not-an-ip malformed
malformed
# malformed
1.1.1.5 malformed
::5 malformed
"""
)
self.ttl = 4200
self.resolver = Resolver(f.path, self.ttl)
def test_defaultPath(self) -> None:
"""
The default hosts file used by L{Resolver} is I{/etc/hosts} if no value
is given for the C{file} initializer parameter.
"""
resolver = Resolver()
self.assertEqual(b"/etc/hosts", resolver.file)
def test_getHostByName(self) -> None:
"""
L{hosts.Resolver.getHostByName} returns a L{Deferred} which fires with a
string giving the address of the queried name as found in the resolver's
hosts file.
"""
data = [
(b"EXAMPLE", "1.1.1.1"),
(b"EXAMPLE.EXAMPLETHING", "1.1.1.1"),
(b"MIXED", "1.1.1.2"),
]
ds = [
self.resolver.getHostByName(n).addCallback(self.assertEqual, ip)
for n, ip in data
]
self.successResultOf(gatherResults(ds))
def test_lookupAddress(self) -> None:
"""
L{hosts.Resolver.lookupAddress} returns a L{Deferred} which fires with A
records from the hosts file.
"""
d = self.resolver.lookupAddress(b"multiple")
answers, authority, additional = self.successResultOf(d)
self.assertEqual(
(
RRHeader(b"multiple", A, IN, self.ttl, Record_A("1.1.1.3", self.ttl)),
RRHeader(b"multiple", A, IN, self.ttl, Record_A("1.1.1.4", self.ttl)),
),
answers,
)
def test_lookupIPV6Address(self) -> None:
"""
L{hosts.Resolver.lookupIPV6Address} returns a L{Deferred} which fires
with AAAA records from the hosts file.
"""
d = self.resolver.lookupIPV6Address(b"ip6-multiple")
answers, authority, additional = self.successResultOf(d)
self.assertEqual(
(
RRHeader(
b"ip6-multiple", AAAA, IN, self.ttl, Record_AAAA("::3", self.ttl)
),
RRHeader(
b"ip6-multiple", AAAA, IN, self.ttl, Record_AAAA("::4", self.ttl)
),
),
answers,
)
def test_lookupAllRecords(self) -> None:
"""
L{hosts.Resolver.lookupAllRecords} returns a L{Deferred} which fires
with A records from the hosts file.
"""
d = self.resolver.lookupAllRecords(b"mixed")
answers, authority, additional = self.successResultOf(d)
self.assertEqual(
(RRHeader(b"mixed", A, IN, self.ttl, Record_A("1.1.1.2", self.ttl)),),
answers,
)
def test_notImplemented(self) -> None:
"""
L{hosts.Resolver} fails with L{NotImplementedError} for L{IResolver}
methods it doesn't implement.
"""
self.failureResultOf(
self.resolver.lookupMailExchange(b"EXAMPLE"), NotImplementedError
)
def test_query(self) -> None:
d = self.resolver.query(Query(b"EXAMPLE"))
[answer], authority, additional = self.successResultOf(d)
self.assertEqual(answer.payload.dottedQuad(), "1.1.1.1")
def test_lookupAddressNotFound(self) -> None:
"""
L{hosts.Resolver.lookupAddress} returns a L{Deferred} which fires with
L{dns.DomainError} if the name passed in has no addresses in the hosts
file.
"""
self.failureResultOf(self.resolver.lookupAddress(b"foueoa"), DomainError)
def test_lookupIPV6AddressNotFound(self) -> None:
"""
Like L{test_lookupAddressNotFound}, but for
L{hosts.Resolver.lookupIPV6Address}.
"""
self.failureResultOf(self.resolver.lookupIPV6Address(b"foueoa"), DomainError)
def test_lookupAllRecordsNotFound(self) -> None:
"""
Like L{test_lookupAddressNotFound}, but for
L{hosts.Resolver.lookupAllRecords}.
"""
self.failureResultOf(self.resolver.lookupAllRecords(b"foueoa"), DomainError)
def test_lookupMalformed(self) -> None:
"""
L{hosts.Resolver.lookupAddress} returns a L{Deferred} which fires with
the valid addresses from the hosts file, ignoring any entries that
aren't valid IP addresses.
"""
d = self.resolver.lookupAddress(b"malformed")
[answer], authority, additional = self.successResultOf(d)
self.assertEqual(
RRHeader(b"malformed", A, IN, self.ttl, Record_A("1.1.1.5", self.ttl)),
answer,
)
def test_lookupIPV6Malformed(self) -> None:
"""
Like L{test_lookupAddressMalformed}, but for
L{hosts.Resolver.lookupIPV6Address}.
"""
d = self.resolver.lookupIPV6Address(b"malformed")
[answer], authority, additional = self.successResultOf(d)
self.assertEqual(
RRHeader(b"malformed", AAAA, IN, self.ttl, Record_AAAA("::5", self.ttl)),
answer,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,36 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.names.resolve}.
"""
from twisted.names.error import DomainError
from twisted.names.resolve import ResolverChain
from twisted.trial.unittest import TestCase
class ResolverChainTests(TestCase):
"""
Tests for L{twisted.names.resolve.ResolverChain}
"""
def test_emptyResolversList(self) -> None:
"""
L{ResolverChain._lookup} returns a L{DomainError} failure if
its C{resolvers} list is empty.
"""
r = ResolverChain([])
d = r.lookupAddress("www.example.com")
f = self.failureResultOf(d)
self.assertIs(f.trap(DomainError), DomainError)
def test_emptyResolversListLookupAllRecords(self) -> None:
"""
L{ResolverChain.lookupAllRecords} returns a L{DomainError}
failure if its C{resolvers} list is empty.
"""
r = ResolverChain([])
d = r.lookupAllRecords("www.example.com")
f = self.failureResultOf(d)
self.assertIs(f.trap(DomainError), DomainError)

View File

@@ -0,0 +1,391 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for L{twisted.names.rfc1982}.
"""
import calendar
from datetime import datetime
from functools import partial
from twisted.names._rfc1982 import SerialNumber
from twisted.trial import unittest
class SerialNumberTests(unittest.TestCase):
"""
Tests for L{SerialNumber}.
"""
def test_serialBitsDefault(self):
"""
L{SerialNumber.serialBits} has default value 32.
"""
self.assertEqual(SerialNumber(1)._serialBits, 32)
def test_serialBitsOverride(self):
"""
L{SerialNumber.__init__} accepts a C{serialBits} argument whose value is
assigned to L{SerialNumber.serialBits}.
"""
self.assertEqual(SerialNumber(1, serialBits=8)._serialBits, 8)
def test_repr(self):
"""
L{SerialNumber.__repr__} returns a string containing number and
serialBits.
"""
self.assertEqual(
"<SerialNumber number=123 serialBits=32>",
repr(SerialNumber(123, serialBits=32)),
)
def test_str(self):
"""
L{SerialNumber.__str__} returns a string representation of the current
value.
"""
self.assertEqual(str(SerialNumber(123)), "123")
def test_int(self):
"""
L{SerialNumber.__int__} returns an integer representation of the current
value.
"""
self.assertEqual(int(SerialNumber(123)), 123)
def test_hash(self):
"""
L{SerialNumber.__hash__} allows L{SerialNumber} instances to be hashed
for use as dictionary keys.
"""
self.assertEqual(hash(SerialNumber(1)), hash(SerialNumber(1)))
self.assertNotEqual(hash(SerialNumber(1)), hash(SerialNumber(2)))
def test_convertOtherSerialBitsMismatch(self):
"""
L{SerialNumber._convertOther} raises L{TypeError} if the other
SerialNumber instance has a different C{serialBits} value.
"""
s1 = SerialNumber(0, serialBits=8)
s2 = SerialNumber(0, serialBits=16)
self.assertRaises(TypeError, s1._convertOther, s2)
def test_eq(self):
"""
L{SerialNumber.__eq__} provides rich equality comparison.
"""
self.assertEqual(SerialNumber(1), SerialNumber(1))
def test_eqForeignType(self):
"""
== comparison of L{SerialNumber} with a non-L{SerialNumber} instance
returns L{NotImplemented}.
"""
self.assertFalse(SerialNumber(1) == object())
self.assertIs(SerialNumber(1).__eq__(object()), NotImplemented)
def test_ne(self):
"""
L{SerialNumber.__ne__} provides rich equality comparison.
"""
self.assertFalse(SerialNumber(1) != SerialNumber(1))
self.assertNotEqual(SerialNumber(1), SerialNumber(2))
def test_neForeignType(self):
"""
!= comparison of L{SerialNumber} with a non-L{SerialNumber} instance
returns L{NotImplemented}.
"""
self.assertTrue(SerialNumber(1) != object())
self.assertIs(SerialNumber(1).__ne__(object()), NotImplemented)
def test_le(self):
"""
L{SerialNumber.__le__} provides rich <= comparison.
"""
self.assertTrue(SerialNumber(1) <= SerialNumber(1))
self.assertTrue(SerialNumber(1) <= SerialNumber(2))
def test_leForeignType(self):
"""
<= comparison of L{SerialNumber} with a non-L{SerialNumber} instance
raises L{TypeError}.
"""
self.assertRaises(TypeError, lambda: SerialNumber(1) <= object())
def test_ge(self):
"""
L{SerialNumber.__ge__} provides rich >= comparison.
"""
self.assertTrue(SerialNumber(1) >= SerialNumber(1))
self.assertTrue(SerialNumber(2) >= SerialNumber(1))
def test_geForeignType(self):
"""
>= comparison of L{SerialNumber} with a non-L{SerialNumber} instance
raises L{TypeError}.
"""
self.assertRaises(TypeError, lambda: SerialNumber(1) >= object())
def test_lt(self):
"""
L{SerialNumber.__lt__} provides rich < comparison.
"""
self.assertTrue(SerialNumber(1) < SerialNumber(2))
def test_ltForeignType(self):
"""
< comparison of L{SerialNumber} with a non-L{SerialNumber} instance
raises L{TypeError}.
"""
self.assertRaises(TypeError, lambda: SerialNumber(1) < object())
def test_gt(self):
"""
L{SerialNumber.__gt__} provides rich > comparison.
"""
self.assertTrue(SerialNumber(2) > SerialNumber(1))
def test_gtForeignType(self):
"""
> comparison of L{SerialNumber} with a non-L{SerialNumber} instance
raises L{TypeError}.
"""
self.assertRaises(TypeError, lambda: SerialNumber(2) > object())
def test_add(self):
"""
L{SerialNumber.__add__} allows L{SerialNumber} instances to be summed.
"""
self.assertEqual(SerialNumber(1) + SerialNumber(1), SerialNumber(2))
def test_addForeignType(self):
"""
Addition of L{SerialNumber} with a non-L{SerialNumber} instance raises
L{TypeError}.
"""
self.assertRaises(TypeError, lambda: SerialNumber(1) + object())
def test_addOutOfRangeHigh(self):
"""
L{SerialNumber} cannot be added with other SerialNumber values larger
than C{_maxAdd}.
"""
maxAdd = SerialNumber(1)._maxAdd
self.assertRaises(
ArithmeticError, lambda: SerialNumber(1) + SerialNumber(maxAdd + 1)
)
def test_maxVal(self):
"""
L{SerialNumber.__add__} returns a wrapped value when s1 plus the s2
would result in a value greater than the C{maxVal}.
"""
s = SerialNumber(1)
maxVal = s._halfRing + s._halfRing - 1
maxValPlus1 = maxVal + 1
self.assertTrue(SerialNumber(maxValPlus1) > SerialNumber(maxVal))
self.assertEqual(SerialNumber(maxValPlus1), SerialNumber(0))
def test_fromRFC4034DateString(self):
"""
L{SerialNumber.fromRFC4034DateString} accepts a datetime string argument
of the form 'YYYYMMDDhhmmss' and returns an L{SerialNumber} instance
whose value is the unix timestamp corresponding to that UTC date.
"""
self.assertEqual(
SerialNumber(1325376000),
SerialNumber.fromRFC4034DateString("20120101000000"),
)
def test_toRFC4034DateString(self):
"""
L{DateSerialNumber.toRFC4034DateString} interprets the current value as
a unix timestamp and returns a date string representation of that date.
"""
self.assertEqual(
"20120101000000", SerialNumber(1325376000).toRFC4034DateString()
)
def test_unixEpoch(self):
"""
L{SerialNumber.toRFC4034DateString} stores 32bit timestamps relative to
the UNIX epoch.
"""
self.assertEqual(SerialNumber(0).toRFC4034DateString(), "19700101000000")
def test_Y2106Problem(self):
"""
L{SerialNumber} wraps unix timestamps in the year 2106.
"""
self.assertEqual(SerialNumber(-1).toRFC4034DateString(), "21060207062815")
def test_Y2038Problem(self):
"""
L{SerialNumber} raises ArithmeticError when used to add dates more than
68 years in the future.
"""
maxAddTime = calendar.timegm(datetime(2038, 1, 19, 3, 14, 7).utctimetuple())
self.assertEqual(
maxAddTime,
SerialNumber(0)._maxAdd,
)
self.assertRaises(
ArithmeticError, lambda: SerialNumber(0) + SerialNumber(maxAddTime + 1)
)
def assertUndefinedComparison(testCase, s1, s2):
"""
A custom assertion for L{SerialNumber} values that cannot be meaningfully
compared.
"Note that there are some pairs of values s1 and s2 for which s1 is not
equal to s2, but for which s1 is neither greater than, nor less than, s2.
An attempt to use these ordering operators on such pairs of values produces
an undefined result."
@see: U{https://tools.ietf.org/html/rfc1982#section-3.2}
@param testCase: The L{unittest.TestCase} on which to call assertion
methods.
@type testCase: L{unittest.TestCase}
@param s1: The first value to compare.
@type s1: L{SerialNumber}
@param s2: The second value to compare.
@type s2: L{SerialNumber}
"""
testCase.assertFalse(s1 == s2)
testCase.assertFalse(s1 <= s2)
testCase.assertFalse(s1 < s2)
testCase.assertFalse(s1 > s2)
testCase.assertFalse(s1 >= s2)
serialNumber2 = partial(SerialNumber, serialBits=2)
class SerialNumber2BitTests(unittest.TestCase):
"""
Tests for correct answers to example calculations in RFC1982 5.1.
The simplest meaningful serial number space has SERIAL_BITS == 2. In this
space, the integers that make up the serial number space are 0, 1, 2, and 3.
That is, 3 == 2^SERIAL_BITS - 1.
https://tools.ietf.org/html/rfc1982#section-5.1
"""
def test_maxadd(self):
"""
In this space, the largest integer that it is meaningful to add to a
sequence number is 2^(SERIAL_BITS - 1) - 1, or 1.
"""
self.assertEqual(SerialNumber(0, serialBits=2)._maxAdd, 1)
def test_add(self):
"""
Then, as defined 0+1 == 1, 1+1 == 2, 2+1 == 3, and 3+1 == 0.
"""
self.assertEqual(serialNumber2(0) + serialNumber2(1), serialNumber2(1))
self.assertEqual(serialNumber2(1) + serialNumber2(1), serialNumber2(2))
self.assertEqual(serialNumber2(2) + serialNumber2(1), serialNumber2(3))
self.assertEqual(serialNumber2(3) + serialNumber2(1), serialNumber2(0))
def test_gt(self):
"""
Further, 1 > 0, 2 > 1, 3 > 2, and 0 > 3.
"""
self.assertTrue(serialNumber2(1) > serialNumber2(0))
self.assertTrue(serialNumber2(2) > serialNumber2(1))
self.assertTrue(serialNumber2(3) > serialNumber2(2))
self.assertTrue(serialNumber2(0) > serialNumber2(3))
def test_undefined(self):
"""
It is undefined whether 2 > 0 or 0 > 2, and whether 1 > 3 or 3 > 1.
"""
assertUndefinedComparison(self, serialNumber2(2), serialNumber2(0))
assertUndefinedComparison(self, serialNumber2(0), serialNumber2(2))
assertUndefinedComparison(self, serialNumber2(1), serialNumber2(3))
assertUndefinedComparison(self, serialNumber2(3), serialNumber2(1))
serialNumber8 = partial(SerialNumber, serialBits=8)
class SerialNumber8BitTests(unittest.TestCase):
"""
Tests for correct answers to example calculations in RFC1982 5.2.
Consider the case where SERIAL_BITS == 8. In this space the integers that
make up the serial number space are 0, 1, 2, ... 254, 255. 255 ==
2^SERIAL_BITS - 1.
https://tools.ietf.org/html/rfc1982#section-5.2
"""
def test_maxadd(self):
"""
In this space, the largest integer that it is meaningful to add to a
sequence number is 2^(SERIAL_BITS - 1) - 1, or 127.
"""
self.assertEqual(SerialNumber(0, serialBits=8)._maxAdd, 127)
def test_add(self):
"""
Addition is as expected in this space, for example: 255+1 == 0,
100+100 == 200, and 200+100 == 44.
"""
self.assertEqual(serialNumber8(255) + serialNumber8(1), serialNumber8(0))
self.assertEqual(serialNumber8(100) + serialNumber8(100), serialNumber8(200))
self.assertEqual(serialNumber8(200) + serialNumber8(100), serialNumber8(44))
def test_gt(self):
"""
Comparison is more interesting, 1 > 0, 44 > 0, 100 > 0, 100 > 44,
200 > 100, 255 > 200, 0 > 255, 100 > 255, 0 > 200, and 44 > 200.
"""
self.assertTrue(serialNumber8(1) > serialNumber8(0))
self.assertTrue(serialNumber8(44) > serialNumber8(0))
self.assertTrue(serialNumber8(100) > serialNumber8(0))
self.assertTrue(serialNumber8(100) > serialNumber8(44))
self.assertTrue(serialNumber8(200) > serialNumber8(100))
self.assertTrue(serialNumber8(255) > serialNumber8(200))
self.assertTrue(serialNumber8(100) > serialNumber8(255))
self.assertTrue(serialNumber8(0) > serialNumber8(200))
self.assertTrue(serialNumber8(44) > serialNumber8(200))
def test_surprisingAddition(self):
"""
Note that 100+100 > 100, but that (100+100)+100 < 100. Incrementing a
serial number can cause it to become "smaller". Of course, incrementing
by a smaller number will allow many more increments to be made before
this occurs. However this is always something to be aware of, it can
cause surprising errors, or be useful as it is the only defined way to
actually cause a serial number to decrease.
"""
self.assertTrue(serialNumber8(100) + serialNumber8(100) > serialNumber8(100))
self.assertTrue(
serialNumber8(100) + serialNumber8(100) + serialNumber8(100)
< serialNumber8(100)
)
def test_undefined(self):
"""
The pairs of values 0 and 128, 1 and 129, 2 and 130, etc, to 127 and 255
are not equal, but in each pair, neither number is defined as being
greater than, or less than, the other.
"""
assertUndefinedComparison(self, serialNumber8(0), serialNumber8(128))
assertUndefinedComparison(self, serialNumber8(1), serialNumber8(129))
assertUndefinedComparison(self, serialNumber8(2), serialNumber8(130))
assertUndefinedComparison(self, serialNumber8(127), serialNumber8(255))

View File

@@ -0,0 +1,738 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for Twisted.names' root resolver.
"""
from zope.interface import implementer
from zope.interface.verify import verifyClass
from twisted.internet.defer import Deferred, TimeoutError, gatherResults, succeed
from twisted.internet.interfaces import IResolverSimple
from twisted.names import client, root
from twisted.names.dns import (
CNAME,
ENAME,
HS,
IN,
NS,
OK,
A,
Message,
Name,
Query,
Record_A,
Record_CNAME,
Record_NS,
RRHeader,
)
from twisted.names.error import DNSNameError, ResolverError
from twisted.names.root import Resolver
from twisted.names.test.test_util import MemoryReactor
from twisted.python.log import msg
from twisted.trial import util
from twisted.trial.unittest import SynchronousTestCase, TestCase
def getOnePayload(results):
"""
From the result of a L{Deferred} returned by L{IResolver.lookupAddress},
return the payload of the first record in the answer section.
"""
ans, auth, add = results
return ans[0].payload
def getOneAddress(results):
"""
From the result of a L{Deferred} returned by L{IResolver.lookupAddress},
return the first IPv4 address from the answer section.
"""
return getOnePayload(results).dottedQuad()
class RootResolverTests(TestCase):
"""
Tests for L{twisted.names.root.Resolver}.
"""
def _queryTest(self, filter):
"""
Invoke L{Resolver._query} and verify that it sends the correct DNS
query. Deliver a canned response to the query and return whatever the
L{Deferred} returned by L{Resolver._query} fires with.
@param filter: The value to pass for the C{filter} parameter to
L{Resolver._query}.
"""
reactor = MemoryReactor()
resolver = Resolver([], reactor=reactor)
d = resolver._query(
Query(b"foo.example.com", A, IN), [("1.1.2.3", 1053)], (30,), filter
)
# A UDP port should have been started.
portNumber, transport = reactor.udpPorts.popitem()
# And a DNS packet sent.
[(packet, address)] = transport._sentPackets
message = Message()
message.fromStr(packet)
# It should be a query with the parameters used above.
self.assertEqual(message.queries, [Query(b"foo.example.com", A, IN)])
self.assertEqual(message.answers, [])
self.assertEqual(message.authority, [])
self.assertEqual(message.additional, [])
response = []
d.addCallback(response.append)
self.assertEqual(response, [])
# Once a reply is received, the Deferred should fire.
del message.queries[:]
message.answer = 1
message.answers.append(
RRHeader(b"foo.example.com", payload=Record_A("5.8.13.21"))
)
transport._protocol.datagramReceived(message.toStr(), ("1.1.2.3", 1053))
return response[0]
def test_filteredQuery(self):
"""
L{Resolver._query} accepts a L{Query} instance and an address, issues
the query, and returns a L{Deferred} which fires with the response to
the query. If a true value is passed for the C{filter} parameter, the
result is a three-tuple of lists of records.
"""
answer, authority, additional = self._queryTest(True)
self.assertEqual(
answer, [RRHeader(b"foo.example.com", payload=Record_A("5.8.13.21", ttl=0))]
)
self.assertEqual(authority, [])
self.assertEqual(additional, [])
def test_unfilteredQuery(self):
"""
Similar to L{test_filteredQuery}, but for the case where a false value
is passed for the C{filter} parameter. In this case, the result is a
L{Message} instance.
"""
message = self._queryTest(False)
self.assertIsInstance(message, Message)
self.assertEqual(message.queries, [])
self.assertEqual(
message.answers,
[RRHeader(b"foo.example.com", payload=Record_A("5.8.13.21", ttl=0))],
)
self.assertEqual(message.authority, [])
self.assertEqual(message.additional, [])
def _respond(self, answers=[], authority=[], additional=[], rCode=OK):
"""
Create a L{Message} suitable for use as a response to a query.
@param answers: A C{list} of two-tuples giving data for the answers
section of the message. The first element of each tuple is a name
for the L{RRHeader}. The second element is the payload.
@param authority: A C{list} like C{answers}, but for the authority
section of the response.
@param additional: A C{list} like C{answers}, but for the
additional section of the response.
@param rCode: The response code the message will be created with.
@return: A new L{Message} initialized with the given values.
"""
response = Message(rCode=rCode)
for section, data in [
(response.answers, answers),
(response.authority, authority),
(response.additional, additional),
]:
section.extend(
[
RRHeader(
name, record.TYPE, getattr(record, "CLASS", IN), payload=record
)
for (name, record) in data
]
)
return response
def _getResolver(self, serverResponses, maximumQueries=10):
"""
Create and return a new L{root.Resolver} modified to resolve queries
against the record data represented by C{servers}.
@param serverResponses: A mapping from dns server addresses to
mappings. The inner mappings are from query two-tuples (name,
type) to dictionaries suitable for use as **arguments to
L{_respond}. See that method for details.
"""
roots = ["1.1.2.3"]
resolver = Resolver(roots, maximumQueries)
def query(query, serverAddresses, timeout, filter):
msg(f"Query for QNAME {query.name} at {serverAddresses!r}")
for addr in serverAddresses:
try:
server = serverResponses[addr]
except KeyError:
continue
records = server[query.name.name, query.type]
return succeed(self._respond(**records))
resolver._query = query
return resolver
def test_lookupAddress(self):
"""
L{root.Resolver.lookupAddress} looks up the I{A} records for the
specified hostname by first querying one of the root servers the
resolver was created with and then following the authority delegations
until a result is received.
"""
servers = {
("1.1.2.3", 53): {
(b"foo.example.com", A): {
"authority": [(b"foo.example.com", Record_NS(b"ns1.example.com"))],
"additional": [(b"ns1.example.com", Record_A("34.55.89.144"))],
},
},
("34.55.89.144", 53): {
(b"foo.example.com", A): {
"answers": [(b"foo.example.com", Record_A("10.0.0.1"))],
}
},
}
resolver = self._getResolver(servers)
d = resolver.lookupAddress(b"foo.example.com")
d.addCallback(getOneAddress)
d.addCallback(self.assertEqual, "10.0.0.1")
return d
def test_lookupChecksClass(self):
"""
If a response includes a record with a class different from the one
in the query, it is ignored and lookup continues until a record with
the right class is found.
"""
badClass = Record_A("10.0.0.1")
badClass.CLASS = HS
servers = {
("1.1.2.3", 53): {
(b"foo.example.com", A): {
"answers": [(b"foo.example.com", badClass)],
"authority": [(b"foo.example.com", Record_NS(b"ns1.example.com"))],
"additional": [(b"ns1.example.com", Record_A("10.0.0.2"))],
},
},
("10.0.0.2", 53): {
(b"foo.example.com", A): {
"answers": [(b"foo.example.com", Record_A("10.0.0.3"))],
},
},
}
resolver = self._getResolver(servers)
d = resolver.lookupAddress(b"foo.example.com")
d.addCallback(getOnePayload)
d.addCallback(self.assertEqual, Record_A("10.0.0.3"))
return d
def test_missingGlue(self):
"""
If an intermediate response includes no glue records for the
authorities, separate queries are made to find those addresses.
"""
servers = {
("1.1.2.3", 53): {
(b"foo.example.com", A): {
"authority": [(b"foo.example.com", Record_NS(b"ns1.example.org"))],
# Conspicuous lack of an additional section naming ns1.example.com
},
(b"ns1.example.org", A): {
"answers": [(b"ns1.example.org", Record_A("10.0.0.1"))],
},
},
("10.0.0.1", 53): {
(b"foo.example.com", A): {
"answers": [(b"foo.example.com", Record_A("10.0.0.2"))],
},
},
}
resolver = self._getResolver(servers)
d = resolver.lookupAddress(b"foo.example.com")
d.addCallback(getOneAddress)
d.addCallback(self.assertEqual, "10.0.0.2")
return d
def test_missingName(self):
"""
If a name is missing, L{Resolver.lookupAddress} returns a L{Deferred}
which fails with L{DNSNameError}.
"""
servers = {
("1.1.2.3", 53): {
(b"foo.example.com", A): {
"rCode": ENAME,
},
},
}
resolver = self._getResolver(servers)
d = resolver.lookupAddress(b"foo.example.com")
return self.assertFailure(d, DNSNameError)
def test_answerless(self):
"""
If a query is responded to with no answers or nameserver records, the
L{Deferred} returned by L{Resolver.lookupAddress} fires with
L{ResolverError}.
"""
servers = {
("1.1.2.3", 53): {
(b"example.com", A): {},
},
}
resolver = self._getResolver(servers)
d = resolver.lookupAddress(b"example.com")
return self.assertFailure(d, ResolverError)
def test_delegationLookupError(self):
"""
If there is an error resolving the nameserver in a delegation response,
the L{Deferred} returned by L{Resolver.lookupAddress} fires with that
error.
"""
servers = {
("1.1.2.3", 53): {
(b"example.com", A): {
"authority": [(b"example.com", Record_NS(b"ns1.example.com"))],
},
(b"ns1.example.com", A): {
"rCode": ENAME,
},
},
}
resolver = self._getResolver(servers)
d = resolver.lookupAddress(b"example.com")
return self.assertFailure(d, DNSNameError)
def test_delegationLookupEmpty(self):
"""
If there are no records in the response to a lookup of a delegation
nameserver, the L{Deferred} returned by L{Resolver.lookupAddress} fires
with L{ResolverError}.
"""
servers = {
("1.1.2.3", 53): {
(b"example.com", A): {
"authority": [(b"example.com", Record_NS(b"ns1.example.com"))],
},
(b"ns1.example.com", A): {},
},
}
resolver = self._getResolver(servers)
d = resolver.lookupAddress(b"example.com")
return self.assertFailure(d, ResolverError)
def test_lookupNameservers(self):
"""
L{Resolver.lookupNameservers} is like L{Resolver.lookupAddress}, except
it queries for I{NS} records instead of I{A} records.
"""
servers = {
("1.1.2.3", 53): {
(b"example.com", A): {
"rCode": ENAME,
},
(b"example.com", NS): {
"answers": [(b"example.com", Record_NS(b"ns1.example.com"))],
},
},
}
resolver = self._getResolver(servers)
d = resolver.lookupNameservers(b"example.com")
def getOneName(results):
ans, auth, add = results
return ans[0].payload.name
d.addCallback(getOneName)
d.addCallback(self.assertEqual, Name(b"ns1.example.com"))
return d
def test_returnCanonicalName(self):
"""
If a I{CNAME} record is encountered as the answer to a query for
another record type, that record is returned as the answer.
"""
servers = {
("1.1.2.3", 53): {
(b"example.com", A): {
"answers": [
(b"example.com", Record_CNAME(b"example.net")),
(b"example.net", Record_A("10.0.0.7")),
],
},
},
}
resolver = self._getResolver(servers)
d = resolver.lookupAddress(b"example.com")
d.addCallback(lambda results: results[0]) # Get the answer section
d.addCallback(
self.assertEqual,
[
RRHeader(b"example.com", CNAME, payload=Record_CNAME(b"example.net")),
RRHeader(b"example.net", A, payload=Record_A("10.0.0.7")),
],
)
return d
def test_followCanonicalName(self):
"""
If no record of the requested type is included in a response, but a
I{CNAME} record for the query name is included, queries are made to
resolve the value of the I{CNAME}.
"""
servers = {
("1.1.2.3", 53): {
(b"example.com", A): {
"answers": [(b"example.com", Record_CNAME(b"example.net"))],
},
(b"example.net", A): {
"answers": [(b"example.net", Record_A("10.0.0.5"))],
},
},
}
resolver = self._getResolver(servers)
d = resolver.lookupAddress(b"example.com")
d.addCallback(lambda results: results[0]) # Get the answer section
d.addCallback(
self.assertEqual,
[
RRHeader(b"example.com", CNAME, payload=Record_CNAME(b"example.net")),
RRHeader(b"example.net", A, payload=Record_A("10.0.0.5")),
],
)
return d
def test_detectCanonicalNameLoop(self):
"""
If there is a cycle between I{CNAME} records in a response, this is
detected and the L{Deferred} returned by the lookup method fails
with L{ResolverError}.
"""
servers = {
("1.1.2.3", 53): {
(b"example.com", A): {
"answers": [
(b"example.com", Record_CNAME(b"example.net")),
(b"example.net", Record_CNAME(b"example.com")),
],
},
},
}
resolver = self._getResolver(servers)
d = resolver.lookupAddress(b"example.com")
return self.assertFailure(d, ResolverError)
def test_boundedQueries(self):
"""
L{Resolver.lookupAddress} won't issue more queries following
delegations than the limit passed to its initializer.
"""
servers = {
("1.1.2.3", 53): {
# First query - force it to start over with a name lookup of
# ns1.example.com
(b"example.com", A): {
"authority": [(b"example.com", Record_NS(b"ns1.example.com"))],
},
# Second query - let it resume the original lookup with the
# address of the nameserver handling the delegation.
(b"ns1.example.com", A): {
"answers": [(b"ns1.example.com", Record_A("10.0.0.2"))],
},
},
("10.0.0.2", 53): {
# Third query - let it jump straight to asking the
# delegation server by including its address here (different
# case from the first query).
(b"example.com", A): {
"authority": [(b"example.com", Record_NS(b"ns2.example.com"))],
"additional": [(b"ns2.example.com", Record_A("10.0.0.3"))],
},
},
("10.0.0.3", 53): {
# Fourth query - give it the answer, we're done.
(b"example.com", A): {
"answers": [(b"example.com", Record_A("10.0.0.4"))],
},
},
}
# Make two resolvers. One which is allowed to make 3 queries
# maximum, and so will fail, and on which may make 4, and so should
# succeed.
failer = self._getResolver(servers, 3)
failD = self.assertFailure(failer.lookupAddress(b"example.com"), ResolverError)
succeeder = self._getResolver(servers, 4)
succeedD = succeeder.lookupAddress(b"example.com")
succeedD.addCallback(getOnePayload)
succeedD.addCallback(self.assertEqual, Record_A("10.0.0.4"))
return gatherResults([failD, succeedD])
class ResolverFactoryArguments(Exception):
"""
Raised by L{raisingResolverFactory} with the *args and **kwargs passed to
that function.
"""
def __init__(self, args, kwargs):
"""
Store the supplied args and kwargs as attributes.
@param args: Positional arguments.
@param kwargs: Keyword arguments.
"""
self.args = args
self.kwargs = kwargs
def raisingResolverFactory(*args, **kwargs):
"""
Raise a L{ResolverFactoryArguments} exception containing the
positional and keyword arguments passed to resolverFactory.
@param args: A L{list} of all the positional arguments supplied by
the caller.
@param kwargs: A L{list} of all the keyword arguments supplied by
the caller.
"""
raise ResolverFactoryArguments(args, kwargs)
class RootResolverResolverFactoryTests(TestCase):
"""
Tests for L{root.Resolver._resolverFactory}.
"""
def test_resolverFactoryArgumentPresent(self):
"""
L{root.Resolver.__init__} accepts a C{resolverFactory}
argument and assigns it to C{self._resolverFactory}.
"""
r = Resolver(hints=[None], resolverFactory=raisingResolverFactory)
self.assertIs(r._resolverFactory, raisingResolverFactory)
def test_resolverFactoryArgumentAbsent(self):
"""
L{root.Resolver.__init__} sets L{client.Resolver} as the
C{_resolverFactory} if a C{resolverFactory} argument is not
supplied.
"""
r = Resolver(hints=[None])
self.assertIs(r._resolverFactory, client.Resolver)
def test_resolverFactoryOnlyExpectedArguments(self):
"""
L{root.Resolver._resolverFactory} is supplied with C{reactor} and
C{servers} keyword arguments.
"""
dummyReactor = object()
r = Resolver(
hints=["192.0.2.101"],
resolverFactory=raisingResolverFactory,
reactor=dummyReactor,
)
e = self.assertRaises(ResolverFactoryArguments, r.lookupAddress, "example.com")
self.assertEqual(
((), {"reactor": dummyReactor, "servers": [("192.0.2.101", 53)]}),
(e.args, e.kwargs),
)
ROOT_SERVERS = [
"a.root-servers.net",
"b.root-servers.net",
"c.root-servers.net",
"d.root-servers.net",
"e.root-servers.net",
"f.root-servers.net",
"g.root-servers.net",
"h.root-servers.net",
"i.root-servers.net",
"j.root-servers.net",
"k.root-servers.net",
"l.root-servers.net",
"m.root-servers.net",
]
@implementer(IResolverSimple)
class StubResolver:
"""
An L{IResolverSimple} implementer which traces all getHostByName
calls and their deferred results. The deferred results can be
accessed and fired synchronously.
"""
def __init__(self):
"""
@type calls: L{list} of L{tuple} containing C{args} and
C{kwargs} supplied to C{getHostByName} calls.
@type pendingResults: L{list} of L{Deferred} returned by
C{getHostByName}.
"""
self.calls = []
self.pendingResults = []
def getHostByName(self, *args, **kwargs):
"""
A fake implementation of L{IResolverSimple.getHostByName}
@param args: A L{list} of all the positional arguments supplied by
the caller.
@param kwargs: A L{list} of all the keyword arguments supplied by
the caller.
@return: A L{Deferred} which may be fired later from the test
fixture.
"""
self.calls.append((args, kwargs))
d = Deferred()
self.pendingResults.append(d)
return d
verifyClass(IResolverSimple, StubResolver)
class BootstrapTests(SynchronousTestCase):
"""
Tests for L{root.bootstrap}
"""
def test_returnsDeferredResolver(self):
"""
L{root.bootstrap} returns an object which is initially a
L{root.DeferredResolver}.
"""
deferredResolver = root.bootstrap(StubResolver())
self.assertIsInstance(deferredResolver, root.DeferredResolver)
def test_resolves13RootServers(self):
"""
The L{IResolverSimple} supplied to L{root.bootstrap} is used to lookup
the IP addresses of the 13 root name servers.
"""
stubResolver = StubResolver()
root.bootstrap(stubResolver)
self.assertEqual(stubResolver.calls, [((s,), {}) for s in ROOT_SERVERS])
def test_becomesResolver(self):
"""
The L{root.DeferredResolver} initially returned by L{root.bootstrap}
becomes a L{root.Resolver} when the supplied resolver has successfully
looked up all root hints.
"""
stubResolver = StubResolver()
deferredResolver = root.bootstrap(stubResolver)
for d in stubResolver.pendingResults:
d.callback("192.0.2.101")
self.assertIsInstance(deferredResolver, Resolver)
def test_resolverReceivesRootHints(self):
"""
The L{root.Resolver} which eventually replaces L{root.DeferredResolver}
is supplied with the IP addresses of the 13 root servers.
"""
stubResolver = StubResolver()
deferredResolver = root.bootstrap(stubResolver)
for d in stubResolver.pendingResults:
d.callback("192.0.2.101")
self.assertEqual(deferredResolver.hints, ["192.0.2.101"] * 13)
def test_continuesWhenSomeRootHintsFail(self):
"""
The L{root.Resolver} is eventually created, even if some of the root
hint lookups fail. Only the working root hint IP addresses are supplied
to the L{root.Resolver}.
"""
stubResolver = StubResolver()
deferredResolver = root.bootstrap(stubResolver)
results = iter(stubResolver.pendingResults)
d1 = next(results)
for d in results:
d.callback("192.0.2.101")
d1.errback(TimeoutError())
def checkHints(res):
self.assertEqual(deferredResolver.hints, ["192.0.2.101"] * 12)
d1.addBoth(checkHints)
def test_continuesWhenAllRootHintsFail(self):
"""
The L{root.Resolver} is eventually created, even if all of the root hint
lookups fail. Pending and new lookups will then fail with
AttributeError.
"""
stubResolver = StubResolver()
deferredResolver = root.bootstrap(stubResolver)
results = iter(stubResolver.pendingResults)
d1 = next(results)
for d in results:
d.errback(TimeoutError())
d1.errback(TimeoutError())
def checkHints(res):
self.assertEqual(deferredResolver.hints, [])
d1.addBoth(checkHints)
self.addCleanup(self.flushLoggedErrors, TimeoutError)
def test_passesResolverFactory(self):
"""
L{root.bootstrap} accepts a C{resolverFactory} argument which is passed
as an argument to L{root.Resolver} when it has successfully looked up
root hints.
"""
stubResolver = StubResolver()
deferredResolver = root.bootstrap(
stubResolver, resolverFactory=raisingResolverFactory
)
for d in stubResolver.pendingResults:
d.callback("192.0.2.101")
self.assertIs(deferredResolver._resolverFactory, raisingResolverFactory)
class StubDNSDatagramProtocol:
"""
A do-nothing stand-in for L{DNSDatagramProtocol} which can be used to avoid
network traffic in tests where that kind of thing doesn't matter.
"""
def query(self, *a, **kw):
return Deferred()
_retrySuppression = util.suppress(
category=DeprecationWarning,
message=(
"twisted.names.root.retry is deprecated since Twisted 10.0. Use a "
"Resolver object for retry logic."
),
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,289 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for L{twisted.names.srvconnect}.
"""
import random
from zope.interface.verify import verifyObject
from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError, ServiceNameUnknownError
from twisted.internet.interfaces import IConnector
from twisted.internet.testing import MemoryReactor
from twisted.names import client, dns, srvconnect
from twisted.names.common import ResolverBase
from twisted.names.error import DNSNameError
from twisted.trial import unittest
class FakeResolver(ResolverBase):
"""
Resolver that only gives out one given result.
Either L{results} or L{failure} must be set and will be used for
the return value of L{_lookup}
@ivar results: List of L{dns.RRHeader} for the desired result.
@type results: C{list}
@ivar failure: Failure with an exception from L{twisted.names.error}.
@type failure: L{Failure<twisted.python.failure.Failure>}
"""
def __init__(self, results=None, failure=None):
self.results = results
self.failure = failure
self.lookups = []
def _lookup(self, name, cls, qtype, timeout):
"""
Return the result or failure on lookup.
"""
self.lookups.append((name, cls, qtype, timeout))
if self.results is not None:
return defer.succeed((self.results, [], []))
else:
return defer.fail(self.failure)
class DummyFactory(protocol.ClientFactory):
"""
Dummy client factory that stores the reason of connection failure.
"""
def __init__(self):
self.reason = None
def clientConnectionFailed(self, connector, reason):
self.reason = reason
class SRVConnectorTests(unittest.TestCase):
"""
Tests for L{srvconnect.SRVConnector}.
"""
def setUp(self):
self.patch(client, "theResolver", FakeResolver())
self.reactor = MemoryReactor()
self.factory = DummyFactory()
self.connector = srvconnect.SRVConnector(
self.reactor, "xmpp-server", "example.org", self.factory
)
self.randIntArgs = []
self.randIntResults = []
def _randint(self, min, max):
"""
Fake randint.
Returns the first element of L{randIntResults} and records the
arguments passed to it in L{randIntArgs}.
@param min: Lower bound of the random number.
@type min: L{int}
@param max: Higher bound of the random number.
@type max: L{int}
@return: Fake random number from L{randIntResults}.
@rtype: L{int}
"""
self.randIntArgs.append((min, max))
return self.randIntResults.pop(0)
def test_interface(self):
"""
L{srvconnect.SRVConnector} implements L{IConnector}.
"""
verifyObject(IConnector, self.connector)
def test_SRVPresent(self):
"""
Test connectTCP gets called with the address from the SRV record.
"""
payload = dns.Record_SRV(port=6269, target="host.example.org", ttl=60)
client.theResolver.results = [
dns.RRHeader(
name="example.org", type=dns.SRV, cls=dns.IN, ttl=60, payload=payload
)
]
self.connector.connect()
self.assertIsNone(self.factory.reason)
self.assertEqual(self.reactor.tcpClients.pop()[:2], ("host.example.org", 6269))
def test_SRVNotPresent(self):
"""
Test connectTCP gets called with fallback parameters on NXDOMAIN.
"""
client.theResolver.failure = DNSNameError(b"example.org")
self.connector.connect()
self.assertIsNone(self.factory.reason)
self.assertEqual(
self.reactor.tcpClients.pop()[:2], ("example.org", "xmpp-server")
)
def test_SRVNoResult(self):
"""
Test connectTCP gets called with fallback parameters on empty result.
"""
client.theResolver.results = []
self.connector.connect()
self.assertIsNone(self.factory.reason)
self.assertEqual(
self.reactor.tcpClients.pop()[:2], ("example.org", "xmpp-server")
)
def test_SRVNoResultUnknownServiceDefaultPort(self):
"""
connectTCP gets called with default port if the service is not defined.
"""
self.connector = srvconnect.SRVConnector(
self.reactor,
"thisbetternotexist",
"example.org",
self.factory,
defaultPort=5222,
)
client.theResolver.failure = ServiceNameUnknownError()
self.connector.connect()
self.assertIsNone(self.factory.reason)
self.assertEqual(self.reactor.tcpClients.pop()[:2], ("example.org", 5222))
def test_SRVNoResultUnknownServiceNoDefaultPort(self):
"""
Connect fails on no result, unknown service and no default port.
"""
self.connector = srvconnect.SRVConnector(
self.reactor, "thisbetternotexist", "example.org", self.factory
)
client.theResolver.failure = ServiceNameUnknownError()
self.connector.connect()
self.assertTrue(self.factory.reason.check(ServiceNameUnknownError))
def test_SRVBadResult(self):
"""
Test connectTCP gets called with fallback parameters on bad result.
"""
client.theResolver.results = [
dns.RRHeader(
name="example.org", type=dns.CNAME, cls=dns.IN, ttl=60, payload=None
)
]
self.connector.connect()
self.assertIsNone(self.factory.reason)
self.assertEqual(
self.reactor.tcpClients.pop()[:2], ("example.org", "xmpp-server")
)
def test_SRVNoService(self):
"""
Test that connecting fails when no service is present.
"""
payload = dns.Record_SRV(port=5269, target=b".", ttl=60)
client.theResolver.results = [
dns.RRHeader(
name="example.org", type=dns.SRV, cls=dns.IN, ttl=60, payload=payload
)
]
self.connector.connect()
self.assertIsNotNone(self.factory.reason)
self.factory.reason.trap(DNSLookupError)
self.assertEqual(self.reactor.tcpClients, [])
def test_SRVLookupName(self):
"""
The lookup name is a native string from service, protocol and domain.
"""
client.theResolver.results = []
self.connector.connect()
name = client.theResolver.lookups[-1][0]
self.assertEqual(b"_xmpp-server._tcp.example.org", name)
def test_unicodeDomain(self):
"""
L{srvconnect.SRVConnector} automatically encodes unicode domain using
C{idna} encoding.
"""
self.connector = srvconnect.SRVConnector(
self.reactor, "xmpp-client", "\u00e9chec.example.org", self.factory
)
self.assertEqual(b"xn--chec-9oa.example.org", self.connector.domain)
def test_pickServerWeights(self):
"""
pickServer calculates running sum of weights and calls randint.
This exercises the server selection algorithm specified in RFC 2782 by
preparing fake L{random.randint} results and checking the values it was
called with.
"""
record1 = dns.Record_SRV(10, 10, 5222, "host1.example.org")
record2 = dns.Record_SRV(10, 20, 5222, "host2.example.org")
self.connector.orderedServers = [record1, record2]
self.connector.servers = []
self.patch(random, "randint", self._randint)
# 1st round
self.randIntResults = [11, 0]
self.connector.pickServer()
self.assertEqual(self.randIntArgs[0], (0, 30))
self.connector.pickServer()
self.assertEqual(self.randIntArgs[1], (0, 10))
# 2nd round
self.randIntResults = [10, 0]
self.connector.pickServer()
self.assertEqual(self.randIntArgs[2], (0, 30))
self.connector.pickServer()
self.assertEqual(self.randIntArgs[3], (0, 20))
def test_pickServerSamePriorities(self):
"""
Two records with equal priorities compare on weight (ascending).
"""
record1 = dns.Record_SRV(10, 10, 5222, "host1.example.org")
record2 = dns.Record_SRV(10, 20, 5222, "host2.example.org")
self.connector.orderedServers = [record2, record1]
self.connector.servers = []
self.patch(random, "randint", self._randint)
self.randIntResults = [0, 0]
self.assertEqual(("host1.example.org", 5222), self.connector.pickServer())
self.assertEqual(("host2.example.org", 5222), self.connector.pickServer())
def test_srvDifferentPriorities(self):
"""
Two records with differing priorities compare on priority (ascending).
"""
record1 = dns.Record_SRV(10, 0, 5222, "host1.example.org")
record2 = dns.Record_SRV(20, 0, 5222, "host2.example.org")
self.connector.orderedServers = [record2, record1]
self.connector.servers = []
self.patch(random, "randint", self._randint)
self.randIntResults = [0, 0]
self.assertEqual(("host1.example.org", 5222), self.connector.pickServer())
self.assertEqual(("host2.example.org", 5222), self.connector.pickServer())

View File

@@ -0,0 +1,118 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.names.tap}.
"""
from twisted.internet.base import ThreadedResolver
from twisted.names.client import Resolver
from twisted.names.dns import PORT
from twisted.names.resolve import ResolverChain
from twisted.names.secondary import SecondaryAuthorityService
from twisted.names.tap import Options, _buildResolvers
from twisted.python.runtime import platform
from twisted.python.usage import UsageError
from twisted.trial.unittest import SynchronousTestCase
class OptionsTests(SynchronousTestCase):
"""
Tests for L{Options}, defining how command line arguments for the DNS server
are parsed.
"""
def test_malformedSecondary(self) -> None:
"""
If the value supplied for an I{--secondary} option does not provide a
server IP address, optional port number, and domain name,
L{Options.parseOptions} raises L{UsageError}.
"""
options = Options()
self.assertRaises(UsageError, options.parseOptions, ["--secondary", ""])
self.assertRaises(UsageError, options.parseOptions, ["--secondary", "1.2.3.4"])
self.assertRaises(
UsageError, options.parseOptions, ["--secondary", "1.2.3.4:hello"]
)
self.assertRaises(
UsageError,
options.parseOptions,
["--secondary", "1.2.3.4:hello/example.com"],
)
def test_secondary(self) -> None:
"""
An argument of the form C{"ip/domain"} is parsed by L{Options} for the
I{--secondary} option and added to its list of secondaries, using the
default DNS port number.
"""
options = Options()
options.parseOptions(["--secondary", "1.2.3.4/example.com"])
self.assertEqual([(("1.2.3.4", PORT), ["example.com"])], options.secondaries)
def test_secondaryExplicitPort(self) -> None:
"""
An argument of the form C{"ip:port/domain"} can be used to specify an
alternate port number for which to act as a secondary.
"""
options = Options()
options.parseOptions(["--secondary", "1.2.3.4:5353/example.com"])
self.assertEqual([(("1.2.3.4", 5353), ["example.com"])], options.secondaries)
def test_secondaryAuthorityServices(self) -> None:
"""
After parsing I{--secondary} options, L{Options} constructs a
L{SecondaryAuthorityService} instance for each configured secondary.
"""
options = Options()
options.parseOptions(
[
"--secondary",
"1.2.3.4:5353/example.com",
"--secondary",
"1.2.3.5:5354/example.com",
]
)
self.assertEqual(len(options.svcs), 2)
secondary = options.svcs[0]
self.assertIsInstance(options.svcs[0], SecondaryAuthorityService)
self.assertEqual(secondary.primary, "1.2.3.4")
self.assertEqual(secondary._port, 5353)
secondary = options.svcs[1]
self.assertIsInstance(options.svcs[1], SecondaryAuthorityService)
self.assertEqual(secondary.primary, "1.2.3.5")
self.assertEqual(secondary._port, 5354)
def test_recursiveConfiguration(self) -> None:
"""
Recursive DNS lookups, if enabled, should be a last-resort option.
Any other lookup method (cache, local lookup, etc.) should take
precedence over recursive lookups
"""
options = Options()
options.parseOptions(["--hosts-file", "hosts.txt", "--recursive"])
ca, cl = _buildResolvers(options)
# Extra cleanup, necessary on POSIX because client.Resolver doesn't know
# when to stop parsing resolv.conf. See #NNN for improving this.
for x in cl:
if isinstance(x, ResolverChain):
recurser = x.resolvers[-1]
if isinstance(recurser, Resolver):
recurser._parseCall.cancel()
# On Windows, we need to use a threaded resolver, which leaves trash
# lying about that we can't easily clean up without reaching into the
# reactor and cancelling them. We only cancel the cleanup functions, as
# there should be no others (and it leaving a callLater lying about
# should rightly cause the test to fail).
if platform.getType() != "posix":
# We want the delayed calls on the reactor, which should be all of
# ours from the threaded resolver cleanup
from twisted.internet import reactor
for x in reactor._newTimedCalls: # type: ignore[attr-defined]
self.assertEqual(x.func.__func__, ThreadedResolver._cleanup)
x.cancel()
self.assertIsInstance(cl[-1], ResolverChain)

View File

@@ -0,0 +1,129 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Utilities for Twisted.names tests.
"""
from random import randrange
from zope.interface import implementer
from zope.interface.verify import verifyClass
from twisted.internet.address import IPv4Address
from twisted.internet.defer import succeed
from twisted.internet.interfaces import IReactorUDP, IUDPTransport
from twisted.internet.task import Clock
@implementer(IUDPTransport)
class MemoryDatagramTransport:
"""
This L{IUDPTransport} implementation enforces the usual connection rules
and captures sent traffic in a list for later inspection.
@ivar _host: The host address to which this transport is bound.
@ivar _protocol: The protocol connected to this transport.
@ivar _sentPackets: A C{list} of two-tuples of the datagrams passed to
C{write} and the addresses to which they are destined.
@ivar _connectedTo: L{None} if this transport is unconnected, otherwise an
address to which all traffic is supposedly sent.
@ivar _maxPacketSize: An C{int} giving the maximum length of a datagram
which will be successfully handled by C{write}.
"""
def __init__(self, host, protocol, maxPacketSize):
self._host = host
self._protocol = protocol
self._sentPackets = []
self._connectedTo = None
self._maxPacketSize = maxPacketSize
def getHost(self):
"""
Return the address which this transport is pretending to be bound
to.
"""
return IPv4Address("UDP", *self._host)
def connect(self, host, port):
"""
Connect this transport to the given address.
"""
if self._connectedTo is not None:
raise ValueError("Already connected")
self._connectedTo = (host, port)
def write(self, datagram, addr=None):
"""
Send the given datagram.
"""
if addr is None:
addr = self._connectedTo
if addr is None:
raise ValueError("Need an address")
if len(datagram) > self._maxPacketSize:
raise ValueError("Packet too big")
self._sentPackets.append((datagram, addr))
def stopListening(self):
"""
Shut down this transport.
"""
self._protocol.stopProtocol()
return succeed(None)
def setBroadcastAllowed(self, enabled):
"""
Dummy implementation to satisfy L{IUDPTransport}.
"""
pass
def getBroadcastAllowed(self):
"""
Dummy implementation to satisfy L{IUDPTransport}.
"""
pass
verifyClass(IUDPTransport, MemoryDatagramTransport)
@implementer(IReactorUDP)
class MemoryReactor(Clock):
"""
An L{IReactorTime} and L{IReactorUDP} provider.
Time is controlled deterministically via the base class, L{Clock}. UDP is
handled in-memory by connecting protocols to instances of
L{MemoryDatagramTransport}.
@ivar udpPorts: A C{dict} mapping port numbers to instances of
L{MemoryDatagramTransport}.
"""
def __init__(self):
Clock.__init__(self)
self.udpPorts = {}
def listenUDP(self, port, protocol, interface="", maxPacketSize=8192):
"""
Pretend to bind a UDP port and connect the given protocol to it.
"""
if port == 0:
while True:
port = randrange(1, 2**16)
if port not in self.udpPorts:
break
if port in self.udpPorts:
raise ValueError("Address in use")
transport = MemoryDatagramTransport((interface, port), protocol, maxPacketSize)
self.udpPorts[port] = transport
protocol.makeConnection(transport)
return transport
verifyClass(IReactorUDP, MemoryReactor)