fully working!

This commit is contained in:
Kalam :p
2020-06-28 12:44:20 +01:00
parent bb6da691a7
commit 174683d2b4
128 changed files with 15017 additions and 4411 deletions

BIN
Darwin/7za Normal file

Binary file not shown.

BIN
Darwin/fatcat Normal file

Binary file not shown.

BIN
Darwin/ndsblc Normal file

Binary file not shown.

BIN
Darwin/twltool Normal file

Binary file not shown.

BIN
Linux/7za Normal file

Binary file not shown.

BIN
Linux/fatcat Normal file

Binary file not shown.

BIN
Linux/ndsblc Normal file

Binary file not shown.

BIN
Linux/twltool Normal file

Binary file not shown.


3
certifi/__init__.py Normal file

@@ -0,0 +1,3 @@
from .core import contents, where
__version__ = "2020.04.05.1"

12
certifi/__main__.py Normal file

@@ -0,0 +1,12 @@
import argparse

from certifi import contents, where

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--contents", action="store_true")
args = parser.parse_args()

if args.contents:
    print(contents())
else:
    print(where())
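A quick way to exercise this entry point from a test script (a sketch, assuming the certifi package added in this commit is importable):

import subprocess
import sys

# Run the module as a script: with no flags it prints the bundle path,
# with -c/--contents it dumps the PEM bundle itself.
result = subprocess.run([sys.executable, "-m", "certifi"],
                        capture_output=True, text=True)
print(result.stdout.strip())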


4641
certifi/cacert.pem Normal file

File diff suppressed because it is too large

30
certifi/core.py Normal file

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-

"""
certifi.py
~~~~~~~~~~

This module returns the installation location of cacert.pem or its contents.
"""
import os

try:
    from importlib.resources import read_text
except ImportError:
    # This fallback will work for Python versions prior to 3.7 that lack the
    # importlib.resources module but relies on the existing `where` function
    # so won't address issues with environments like PyOxidizer that don't set
    # __file__ on modules.
    def read_text(_module, _path, encoding="ascii"):
        with open(where(), "r", encoding=encoding) as data:
            return data.read()


def where():
    f = os.path.dirname(__file__)
    return os.path.join(f, "cacert.pem")


def contents():
    return read_text("certifi", "cacert.pem", encoding="ascii")
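In practice the two public helpers are used like this (a minimal sketch, assuming the package is importable; the printed path depends on where the package is installed):

import certifi

print(certifi.where())           # absolute path to the bundled cacert.pem
print(certifi.contents()[:64])   # start of the PEM data itself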

BIN
lazy.ico Normal file

Binary file not shown.

Size: 109 KiB

78
main.py

@@ -7,9 +7,9 @@ import platform
 import sys
 import requests
 import json
-import py7zr
 from pathlib import Path
 import shutil
+from subprocess import Popen
 import zipfile
 dsiVersions = ["1.0 - 1.3 (USA, EUR, AUS, JPN)", "1.4 - 1.4.5 (USA, EUR, AUS, JPN)", "All versions (KOR, CHN)"]
@@ -18,6 +18,8 @@ memoryPitLinks = ["https://github.com/YourKalamity/just-a-dsi-cfw-installer/raw/
 window = tkinter.Tk()
 window.sourceFolder = ''
 window.sourceFile = ''
+appTitle = tkinter.Label(text="Lazy DSi file downloader")
+appTitle.width = 100
 SDlabel = tkinter.Label(text = "SD card directory")
 SDlabel.width = 100
 SDentry = tkinter.Entry()
@@ -63,6 +65,40 @@ def validateDirectory(directory):
 def start():
     outputBox.delete(0, tkinter.END)
     sysname = platform.system()
+    _7za = os.path.join(sysname, '7za')
+    _7z = None
+    if sysname == "Windows":
+        from winreg import OpenKey, QueryValueEx, HKEY_LOCAL_MACHINE, KEY_READ, KEY_WOW64_64KEY
+        print('Searching for 7-Zip in the Windows registry...')
+        try:
+            with OpenKey(HKEY_LOCAL_MACHINE, 'SOFTWARE\\7-Zip', 0, KEY_READ | KEY_WOW64_64KEY) as hkey:
+                _7z = os.path.join(QueryValueEx(hkey, 'Path')[0], '7z.exe')
+                if not os.path.exists(_7z):
+                    raise WindowsError
+                _7za = _7z
+        except WindowsError:
+            print('Searching for 7-Zip in the 32-bit Windows registry...')
+            try:
+                with OpenKey(HKEY_LOCAL_MACHINE, 'SOFTWARE\\7-Zip') as hkey:
+                    _7z = os.path.join(QueryValueEx(hkey, 'Path')[0], '7z.exe')
+                    if not os.path.exists(_7z):
+                        raise WindowsError
+                    _7za = _7z
+            except WindowsError:
+                print("7-Zip not found, please install it before using")
+                outputbox("7-Zip not found")
+                return
+        print("7-Zip found!")
+    outputBox.configure(state='normal')
+    outputBox.delete('1.0', tkinter.END)
+    outputBox.configure(state='disabled')
     #Variables
     directory = SDentry.get()
     version = firmwareVersion.get()
@@ -85,19 +121,25 @@ def start():
     r = requests.get(memoryPitDownload, allow_redirects=True)
     memoryPitLocation = memoryPitLocation + "pit.bin"
     open(memoryPitLocation, 'wb').write(r.content)
     outputbox("Memory Pit Downloaded")
     #Download TWiLight Menu
     r = requests.get(getLatestTWLmenu(), allow_redirects=True)
     TWLmenuLocation = temp + "TWiLightMenu.7z"
     open(TWLmenuLocation, 'wb').write(r.content)
     outputbox("TWiLight Menu ++ Downloaded")
     #Extract TWiLight Menu
-    archive = py7zr.SevenZipFile(TWLmenuLocation, mode='r')
-    archive.extractall(path=temp)
-    archive.close()
-    outputbox("TWiLight Menu ++ Extracted")
+    proc = Popen([_7za, 'x', TWLmenuLocation, '-o' + temp, '_nds', 'DSi - CFW users',
+                  'DSi&3DS - SD card users', 'roms'])
+    ret_val = proc.wait()
+    if ret_val == 0:
+        outputbox("TWiLight Menu ++ Extracted")
+    else:
+        outputbox("TWiLight Menu ++ extraction failed")
+        return
     #Move TWiLight Menu
     shutil.copy(temp + "DSi&3DS - SD card users/BOOT.NDS", directory)
@@ -106,32 +148,40 @@ def start():
     shutil.move(temp + "DSi - CFW users/SDNAND root/title", directory)
     shutil.copy(temp + "DSi&3DS - SD card users/_nds/nds-bootstrap-hb-nightly.nds", directory + "/_nds")
     shutil.copy(temp + "DSi&3DS - SD card users/_nds/nds-bootstrap-hb-release.nds", directory + "/_nds")
-    outputbox("TWiLight Menu placed")
+    outputbox("TWiLight Menu ++ placed")
     #Download dumpTool
     r = requests.get(getLatestdumpTool(), allow_redirects=True)
     dumpToolLocation = directory + "/dumpTool.nds"
     open(dumpToolLocation, 'wb').write(r.content)
     outputbox("dumpTool Downloaded")
     if unlaunchNeeded == 1:
         #Download Unlaunch
         url = "https://problemkaputt.de/unlaunch.zip"
         r = requests.get(url, allow_redirects=True)
         unlaunchLocation = temp + "unlaunch.zip"
-        open(dumpToolLocation, 'wb').write(r.content)
+        open(unlaunchLocation, 'wb').write(r.content)
         outputbox("Unlaunch Downloaded")
         #Extract Unlaunch
         with zipfile.ZipFile(unlaunchLocation, 'r') as zip_ref:
             zip_ref.extractall(directory)
     #Delete tmp folder
     shutil.rmtree(directory + '/tmp')
     outputbox("Done!")
 def chooseDir():
     window.sourceFolder = filedialog.askdirectory(parent=window, initialdir="/", title='Please select the directory of your SD card')
     SDentry.delete(0, tkinter.END)
     SDentry.insert(0, window.sourceFolder)
 b_chooseDir = tkinter.Button(window, text = "Choose Folder", width = 20, command = chooseDir)
 b_chooseDir.width = 100
 b_chooseDir.height = 50
@@ -152,7 +202,9 @@ outputLabel = tkinter.Label(text="Output")
 outputLabel.width = 100
 outputBox = tkinter.Text(window, state='disabled', width = 30, height = 10)
 window.title("Lazy DSi file downloader")
+window.resizable(0, 0)
+appTitle.pack()
 SDlabel.pack()
 SDentry.pack()
 b_chooseDir.pack()
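The main change above swaps the py7zr-based extraction for a bundled 7za binary driven through subprocess. Reduced to its essentials, the new pattern looks like this (a sketch with hypothetical paths, using subprocess.run rather than Popen/wait for brevity):

import os
import platform
import subprocess

archive = "TWiLightMenu.7z"                         # hypothetical input archive
outdir = "extracted"                                # hypothetical output directory
sevenzip = os.path.join(platform.system(), "7za")   # e.g. Linux/7za or Darwin/7za

# 7-Zip exits with 0 on success; '-o' takes the output directory with no space.
result = subprocess.run([sevenzip, "x", archive, "-o" + outdir])
if result.returncode != 0:
    raise RuntimeError("7za exited with code {}".format(result.returncode))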

File diff suppressed because it is too large

61
py7zr/callbacks.py

@@ -1,61 +0,0 @@
#!/usr/bin/python -u
#
# p7zr library
#
# Copyright (c) 2020 Hiroshi Miura <miurahr@linux.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
from abc import ABC, abstractmethod


class Callback(ABC):
    """Abstract base class for progress callbacks."""
    @abstractmethod
    def report_start_preparation(self):
        """Report a start of preparation event, such as building the list of files and reading their properties."""
        pass

    @abstractmethod
    def report_start(self, processing_file_path, processing_bytes):
        """Report a start event for the specified archive file and its input bytes."""
        pass

    @abstractmethod
    def report_end(self, processing_file_path, wrote_bytes):
        """Report an end event for the specified archive file and its output bytes."""
        pass

    @abstractmethod
    def report_warning(self, message):
        """Report a warning event with its message."""
        pass

    @abstractmethod
    def report_postprocess(self):
        """Report a start of post-processing event, such as setting file properties and permissions or creating symlinks."""
        pass


class ExtractCallback(Callback):
    """Abstract base class for extraction progress callbacks."""
    pass


class ArchiveCallback(Callback):
    """Abstract base class for archiving progress callbacks."""
    pass
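For context, a concrete implementation of the removed callback API would have looked roughly like this (a sketch; the method names come from the abstract class above):

from py7zr.callbacks import ExtractCallback

class PrintingCallback(ExtractCallback):
    """Minimal concrete callback that logs every event to stdout."""
    def report_start_preparation(self):
        print("preparing...")

    def report_start(self, processing_file_path, processing_bytes):
        print("start:", processing_file_path, processing_bytes)

    def report_end(self, processing_file_path, wrote_bytes):
        print("done:", processing_file_path, wrote_bytes)

    def report_warning(self, message):
        print("warning:", message)

    def report_postprocess(self):
        print("post-processing...")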

317
py7zr/cli.py

@@ -1,317 +0,0 @@
#!/usr/bin/env python
#
# Pure python p7zr implementation
# Copyright (C) 2019, 2020 Hiroshi Miura
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
import argparse
import getpass
import os
import pathlib
import re
import shutil
import sys
from lzma import CHECK_CRC64, CHECK_SHA256, is_check_supported
from typing import Any, List, Optional
import texttable # type: ignore
import py7zr
from py7zr.callbacks import ExtractCallback
from py7zr.helpers import Local
from py7zr.properties import READ_BLOCKSIZE, SupportedMethods
class CliExtractCallback(ExtractCallback):
def __init__(self, total_bytes, ofd=sys.stdout):
self.ofd = ofd
self.archive_total = total_bytes
self.total_bytes = 0
self.columns, _ = shutil.get_terminal_size(fallback=(80, 24))
self.pwidth = 0
def report_start_preparation(self):
pass
def report_start(self, processing_file_path, processing_bytes):
self.ofd.write('- {}'.format(processing_file_path))
self.pwidth += len(processing_file_path) + 2
def report_end(self, processing_file_path, wrote_bytes):
self.total_bytes += int(wrote_bytes)
plest = self.columns - self.pwidth
progress = self.total_bytes / self.archive_total
msg = '({:.0%})\n'.format(progress)
if plest - len(msg) > 0:
self.ofd.write(msg.rjust(plest))
else:
self.ofd.write(msg)
self.pwidth = 0
def report_postprocess(self):
pass
def report_warning(self, message):
pass
class Cli():
dunits = {'b': 1, 'B': 1, 'k': 1024, 'K': 1024, 'm': 1024 * 1024, 'M': 1024 * 1024,
'g': 1024 * 1024 * 1024, 'G': 1024 * 1024 * 1024}
def __init__(self):
self.parser = self._create_parser()
self.unit_pattern = re.compile(r'^([0-9]+)([bkmg]?)$', re.IGNORECASE)
def run(self, arg: Optional[Any] = None) -> int:
args = self.parser.parse_args(arg)
return args.func(args)
def _create_parser(self):
parser = argparse.ArgumentParser(prog='py7zr', description='py7zr',
formatter_class=argparse.RawTextHelpFormatter, add_help=True)
subparsers = parser.add_subparsers(title='subcommands', help='subcommand for py7zr l .. list, x .. extract,'
' t .. check integrity, i .. information')
list_parser = subparsers.add_parser('l')
list_parser.set_defaults(func=self.run_list)
list_parser.add_argument("arcfile", help="7z archive file")
list_parser.add_argument("--verbose", action="store_true", help="verbose output")
extract_parser = subparsers.add_parser('x')
extract_parser.set_defaults(func=self.run_extract)
extract_parser.add_argument("arcfile", help="7z archive file")
extract_parser.add_argument("odir", nargs="?", help="output directory")
extract_parser.add_argument("-P", "--password", action="store_true",
help="Password protected archive(you will be asked a password).")
extract_parser.add_argument("--verbose", action="store_true", help="verbose output")
create_parser = subparsers.add_parser('c')
create_parser.set_defaults(func=self.run_create)
create_parser.add_argument("arcfile", help="7z archive file")
create_parser.add_argument("filenames", nargs="+", help="filenames to archive")
create_parser.add_argument("-v", "--volume", nargs=1, help="Create volumes.")
test_parser = subparsers.add_parser('t')
test_parser.set_defaults(func=self.run_test)
test_parser.add_argument("arcfile", help="7z archive file")
info_parser = subparsers.add_parser("i")
info_parser.set_defaults(func=self.run_info)
parser.set_defaults(func=self.show_help)
return parser
def show_help(self, args):
self.parser.print_help()
return(0)
def run_info(self, args):
print("py7zr version {} {}".format(py7zr.__version__, py7zr.__copyright__))
print("Formats:")
table = texttable.Texttable()
table.set_deco(texttable.Texttable.HEADER)
table.set_cols_dtype(['t', 't'])
table.set_cols_align(["l", "r"])
for f in SupportedMethods.formats:
m = ''.join(' {:02x}'.format(x) for x in f['magic'])
table.add_row([f['name'], m])
print(table.draw())
print("\nCodecs:")
table = texttable.Texttable()
table.set_deco(texttable.Texttable.HEADER)
table.set_cols_dtype(['t', 't'])
table.set_cols_align(["l", "r"])
for c in SupportedMethods.codecs:
m = ''.join('{:02x}'.format(x) for x in c['id'])
table.add_row([m, c['name']])
print(table.draw())
print("\nChecks:")
print("CHECK_NONE")
print("CHECK_CRC32")
if is_check_supported(CHECK_CRC64):
print("CHECK_CRC64")
if is_check_supported(CHECK_SHA256):
print("CHECK_SHA256")
def run_list(self, args):
"""Print a table of contents to file. """
target = args.arcfile
verbose = args.verbose
if not py7zr.is_7zfile(target):
print('not a 7z file')
return(1)
with open(target, 'rb') as f:
a = py7zr.SevenZipFile(f)
file = sys.stdout
archive_info = a.archiveinfo()
archive_list = a.list()
if verbose:
file.write("Listing archive: {}\n".format(target))
file.write("--\n")
file.write("Path = {}\n".format(archive_info.filename))
file.write("Type = 7z\n")
fstat = os.stat(archive_info.filename)
file.write("Phisical Size = {}\n".format(fstat.st_size))
file.write("Headers Size = {}\n".format(archive_info.header_size))
file.write("Method = {}\n".format(archive_info.method_names))
if archive_info.solid:
file.write("Solid = {}\n".format('+'))
else:
file.write("Solid = {}\n".format('-'))
file.write("Blocks = {}\n".format(archive_info.blocks))
file.write('\n')
file.write(
'total %d files and directories in %sarchive\n' % (len(archive_list),
(archive_info.solid and 'solid ') or ''))
file.write(' Date Time Attr Size Compressed Name\n')
file.write('------------------- ----- ------------ ------------ ------------------------\n')
for f in archive_list:
if f.creationtime is not None:
creationdate = f.creationtime.astimezone(Local).strftime("%Y-%m-%d")
creationtime = f.creationtime.astimezone(Local).strftime("%H:%M:%S")
else:
creationdate = ' '
creationtime = ' '
if f.is_directory:
attrib = 'D...'
else:
attrib = '....'
if f.archivable:
attrib += 'A'
else:
attrib += '.'
if f.is_directory:
extra = ' 0 '
elif f.compressed is None:
extra = ' '
else:
extra = '%12d ' % (f.compressed)
file.write('%s %s %s %12d %s %s\n' % (creationdate, creationtime, attrib,
f.uncompressed, extra, f.filename))
file.write('------------------- ----- ------------ ------------ ------------------------\n')
return(0)
@staticmethod
def print_archiveinfo(archive, file):
file.write("--\n")
file.write("Path = {}\n".format(archive.filename))
file.write("Type = 7z\n")
fstat = os.stat(archive.filename)
file.write("Phisical Size = {}\n".format(fstat.st_size))
file.write("Headers Size = {}\n".format(archive.header.size)) # fixme.
file.write("Method = {}\n".format(archive._get_method_names()))
if archive._is_solid():
file.write("Solid = {}\n".format('+'))
else:
file.write("Solid = {}\n".format('-'))
file.write("Blocks = {}\n".format(len(archive.header.main_streams.unpackinfo.folders)))
def run_test(self, args):
target = args.arcfile
if not py7zr.is_7zfile(target):
print('not a 7z file')
return(1)
with open(target, 'rb') as f:
a = py7zr.SevenZipFile(f)
file = sys.stdout
file.write("Testing archive: {}\n".format(a.filename))
self.print_archiveinfo(archive=a, file=file)
file.write('\n')
if a.testzip() is None:
file.write('Everything is Ok\n')
return(0)
else:
file.write('Bad 7zip file\n')
return(1)
def run_extract(self, args: argparse.Namespace) -> int:
target = args.arcfile
verbose = args.verbose
if not py7zr.is_7zfile(target):
print('not a 7z file')
return(1)
if not args.password:
password = None # type: Optional[str]
else:
try:
password = getpass.getpass()
except getpass.GetPassWarning:
sys.stderr.write('Warning: your password may be shown.\n')
return(1)
a = py7zr.SevenZipFile(target, 'r', password=password)
cb = None # Optional[ExtractCallback]
if verbose:
archive_info = a.archiveinfo()
cb = CliExtractCallback(total_bytes=archive_info.uncompressed, ofd=sys.stderr)
if args.odir:
a.extractall(path=args.odir, callback=cb)
else:
a.extractall(callback=cb)
return(0)
def _check_volumesize_valid(self, size: str) -> bool:
if self.unit_pattern.match(size):
return True
else:
return False
def _volumesize_unitconv(self, size: str) -> int:
m = self.unit_pattern.match(size)
num = m.group(1)
unit = m.group(2)
return int(num) if unit is None else int(num) * self.dunits[unit]
def run_create(self, args):
sztarget = args.arcfile # type: str
filenames = args.filenames # type: List[str]
volume_size = args.volume[0] if getattr(args, 'volume', None) is not None else None
if volume_size is not None and not self._check_volumesize_valid(volume_size):
sys.stderr.write('Error: Specified volume size is invalid.\n')
self.show_help(args)
exit(1)
if not sztarget.endswith('.7z'):
sztarget += '.7z'
target = pathlib.Path(sztarget)
if target.exists():
sys.stderr.write('Archive file exists!\n')
self.show_help(args)
exit(1)
with py7zr.SevenZipFile(target, 'w') as szf:
for path in filenames:
src = pathlib.Path(path)
if src.is_dir():
szf.writeall(src)
else:
szf.write(src)
if volume_size is None:
return (0)
size = self._volumesize_unitconv(volume_size)
self._split_file(target, size)
target.unlink()
return(0)
def _split_file(self, filepath, size):
chapters = 0
written = [0, 0]
total_size = filepath.stat().st_size
with filepath.open('rb') as src:
while written[0] <= total_size:
with open(str(filepath) + '.%03d' % chapters, 'wb') as tgt:
written[1] = 0
while written[1] < size:
read_size = min(READ_BLOCKSIZE, size - written[1])
tgt.write(src.read(read_size))
written[1] += read_size
written[0] += read_size
chapters += 1
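This CLI is normally reached through py7zr's console entry point, but it can also be driven programmatically (a sketch against the removed module; archive.7z is a hypothetical file):

from py7zr.cli import Cli

cli = Cli()
cli.run(["i"])                        # show supported formats, codecs and checks
cli.run(["l", "archive.7z"])          # list the (hypothetical) archive
cli.run(["x", "archive.7z", "out"])   # extract it into ./out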

395
py7zr/compression.py

@@ -1,395 +0,0 @@
#!/usr/bin/python -u
#
# p7zr library
#
# Copyright (c) 2019 Hiroshi Miura <miurahr@linux.com>
# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de
# 7-Zip Copyright (C) 1999-2010 Igor Pavlov
# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
import bz2
import io
import lzma
import os
import queue
import sys
import threading
from typing import IO, Any, BinaryIO, Dict, List, Optional, Union
from py7zr.exceptions import Bad7zFile, CrcError, UnsupportedCompressionMethodError
from py7zr.extra import AESDecompressor, CopyDecompressor, DeflateDecompressor, ISevenZipDecompressor, ZstdDecompressor
from py7zr.helpers import MemIO, NullIO, calculate_crc32, readlink
from py7zr.properties import READ_BLOCKSIZE, ArchivePassword, CompressionMethod
if sys.version_info < (3, 6):
import pathlib2 as pathlib
else:
import pathlib
try:
import zstandard as Zstd # type: ignore
except ImportError:
Zstd = None
class Worker:
"""Extract worker class to invoke handler"""
def __init__(self, files, src_start: int, header) -> None:
self.target_filepath = {} # type: Dict[int, Union[MemIO, pathlib.Path, None]]
self.files = files
self.src_start = src_start
self.header = header
def extract(self, fp: BinaryIO, parallel: bool, q=None) -> None:
"""Extract worker method to handle 7zip folder and decompress each files."""
if hasattr(self.header, 'main_streams') and self.header.main_streams is not None:
src_end = self.src_start + self.header.main_streams.packinfo.packpositions[-1]
numfolders = self.header.main_streams.unpackinfo.numfolders
if numfolders == 1:
self.extract_single(fp, self.files, self.src_start, src_end, q)
else:
folders = self.header.main_streams.unpackinfo.folders
positions = self.header.main_streams.packinfo.packpositions
empty_files = [f for f in self.files if f.emptystream]
if not parallel:
self.extract_single(fp, empty_files, 0, 0, q)
for i in range(numfolders):
self.extract_single(fp, folders[i].files, self.src_start + positions[i],
self.src_start + positions[i + 1], q)
else:
filename = getattr(fp, 'name', None)
self.extract_single(open(filename, 'rb'), empty_files, 0, 0, q)
extract_threads = []
for i in range(numfolders):
p = threading.Thread(target=self.extract_single,
args=(filename, folders[i].files,
self.src_start + positions[i], self.src_start + positions[i + 1], q))
p.start()
extract_threads.append((p))
for p in extract_threads:
p.join()
else:
empty_files = [f for f in self.files if f.emptystream]
self.extract_single(fp, empty_files, 0, 0, q)
def extract_single(self, fp: Union[BinaryIO, str], files, src_start: int, src_end: int,
q: Optional[queue.Queue]) -> None:
"""Single thread extractor that takes file lists in single 7zip folder."""
if files is None:
return
if isinstance(fp, str):
fp = open(fp, 'rb')
fp.seek(src_start)
for f in files:
if q is not None:
q.put(('s', str(f.filename), str(f.compressed) if f.compressed is not None else '0'))
fileish = self.target_filepath.get(f.id, None)
if fileish is not None:
fileish.parent.mkdir(parents=True, exist_ok=True)
with fileish.open(mode='wb') as ofp:
if not f.emptystream:
# extract to file
crc32 = self.decompress(fp, f.folder, ofp, f.uncompressed[-1], f.compressed, src_end)
ofp.seek(0)
if f.crc32 is not None and crc32 != f.crc32:
raise CrcError("{}".format(f.filename))
else:
pass # just create empty file
elif not f.emptystream:
# read and bin off a data but check crc
with NullIO() as ofp:
crc32 = self.decompress(fp, f.folder, ofp, f.uncompressed[-1], f.compressed, src_end)
if f.crc32 is not None and crc32 != f.crc32:
raise CrcError("{}".format(f.filename))
if q is not None:
q.put(('e', str(f.filename), str(f.uncompressed[-1])))
def decompress(self, fp: BinaryIO, folder, fq: IO[Any],
size: int, compressed_size: Optional[int], src_end: int) -> int:
"""decompressor wrapper called from extract method.
:parameter fp: archive source file pointer
:parameter folder: Folder object that have decompressor object.
:parameter fq: output file pathlib.Path
:parameter size: uncompressed size of target file.
:parameter compressed_size: compressed size of target file.
:parameter src_end: end position of the folder
:returns CRC32 of the file
"""
assert folder is not None
crc32 = 0
out_remaining = size
decompressor = folder.get_decompressor(compressed_size)
while out_remaining > 0:
max_length = min(out_remaining, io.DEFAULT_BUFFER_SIZE)
rest_size = src_end - fp.tell()
read_size = min(READ_BLOCKSIZE, rest_size)
if read_size == 0:
tmp = decompressor.decompress(b'', max_length)
if len(tmp) == 0:
raise Exception("decompression went wrong: no output data.")
else:
inp = fp.read(read_size)
tmp = decompressor.decompress(inp, max_length)
if len(tmp) > 0 and out_remaining >= len(tmp):
out_remaining -= len(tmp)
fq.write(tmp)
crc32 = calculate_crc32(tmp, crc32)
if out_remaining <= 0:
break
if fp.tell() >= src_end:
# Check folder.digest integrity.
if decompressor.crc is not None and not decompressor.check_crc():
raise Bad7zFile("Folder CRC32 error.")
return crc32
def _find_link_target(self, target):
"""Find the target member of a symlink or hardlink member in the archive.
"""
targetname = target.as_posix() # type: str
linkname = readlink(targetname)
# Check windows full path symlinks
if linkname.startswith("\\\\?\\"):
linkname = linkname[4:]
# normalize as posix style
linkname = pathlib.Path(linkname).as_posix() # type: str
member = None
for j in range(len(self.files)):
if linkname == self.files[j].origin.as_posix():
# FIXME: when API user specify arcname, it will break
member = os.path.relpath(linkname, os.path.dirname(targetname))
break
if member is None:
member = linkname
return member
def archive(self, fp: BinaryIO, folder, deref=False):
"""Run archive task for specified 7zip folder."""
compressor = folder.get_compressor()
outsize = 0
self.header.main_streams.packinfo.numstreams = 1
num_unpack_streams = 0
self.header.main_streams.substreamsinfo.digests = []
self.header.main_streams.substreamsinfo.digestsdefined = []
last_file_index = 0
foutsize = 0
for i, f in enumerate(self.files):
file_info = f.file_properties()
self.header.files_info.files.append(file_info)
self.header.files_info.emptyfiles.append(f.emptystream)
foutsize = 0
if f.is_symlink and not deref:
last_file_index = i
num_unpack_streams += 1
link_target = self._find_link_target(f.origin) # type: str
tgt = link_target.encode('utf-8') # type: bytes
insize = len(tgt)
crc = calculate_crc32(tgt, 0) # type: int
out = compressor.compress(tgt)
outsize += len(out)
foutsize += len(out)
fp.write(out)
self.header.main_streams.substreamsinfo.digests.append(crc)
self.header.main_streams.substreamsinfo.digestsdefined.append(True)
self.header.main_streams.substreamsinfo.unpacksizes.append(insize)
self.header.files_info.files[i]['maxsize'] = foutsize
elif not f.emptystream:
last_file_index = i
num_unpack_streams += 1
insize = 0
with f.origin.open(mode='rb') as fd:
data = fd.read(READ_BLOCKSIZE)
insize += len(data)
crc = 0
while data:
crc = calculate_crc32(data, crc)
out = compressor.compress(data)
outsize += len(out)
foutsize += len(out)
fp.write(out)
data = fd.read(READ_BLOCKSIZE)
insize += len(data)
self.header.main_streams.substreamsinfo.digests.append(crc)
self.header.main_streams.substreamsinfo.digestsdefined.append(True)
self.header.files_info.files[i]['maxsize'] = foutsize
self.header.main_streams.substreamsinfo.unpacksizes.append(insize)
else:
out = compressor.flush()
outsize += len(out)
foutsize += len(out)
fp.write(out)
if len(self.files) > 0:
self.header.files_info.files[last_file_index]['maxsize'] = foutsize
# Update size data in header
self.header.main_streams.packinfo.packsizes = [outsize]
folder.unpacksizes = [sum(self.header.main_streams.substreamsinfo.unpacksizes)]
self.header.main_streams.substreamsinfo.num_unpackstreams_folders = [num_unpack_streams]
def register_filelike(self, id: int, fileish: Union[MemIO, pathlib.Path, None]) -> None:
"""register file-ish to worker."""
self.target_filepath[id] = fileish
class SevenZipDecompressor:
"""Main decompressor object which is properly configured and bind to each 7zip folder.
because 7zip folder can have a custom compression method"""
lzma_methods_map = {
CompressionMethod.LZMA: lzma.FILTER_LZMA1,
CompressionMethod.LZMA2: lzma.FILTER_LZMA2,
CompressionMethod.DELTA: lzma.FILTER_DELTA,
CompressionMethod.P7Z_BCJ: lzma.FILTER_X86,
CompressionMethod.BCJ_ARM: lzma.FILTER_ARM,
CompressionMethod.BCJ_ARMT: lzma.FILTER_ARMTHUMB,
CompressionMethod.BCJ_IA64: lzma.FILTER_IA64,
CompressionMethod.BCJ_PPC: lzma.FILTER_POWERPC,
CompressionMethod.BCJ_SPARC: lzma.FILTER_SPARC,
}
FILTER_BZIP2 = 0x31
FILTER_ZIP = 0x32
FILTER_COPY = 0x33
FILTER_AES = 0x34
FILTER_ZSTD = 0x35
alt_methods_map = {
CompressionMethod.MISC_BZIP2: FILTER_BZIP2,
CompressionMethod.MISC_DEFLATE: FILTER_ZIP,
CompressionMethod.COPY: FILTER_COPY,
CompressionMethod.CRYPT_AES256_SHA256: FILTER_AES,
CompressionMethod.MISC_ZSTD: FILTER_ZSTD,
}
def __init__(self, coders: List[Dict[str, Any]], size: int, crc: Optional[int]) -> None:
# Get password which was set when creation of py7zr.SevenZipFile object.
self.input_size = size
self.consumed = 0 # type: int
self.crc = crc
self.digest = None # type: Optional[int]
if self._check_lzma_coders(coders):
self._set_lzma_decompressor(coders)
else:
self._set_alternative_decompressor(coders)
def _check_lzma_coders(self, coders: List[Dict[str, Any]]) -> bool:
res = True
for coder in coders:
if self.lzma_methods_map.get(coder['method'], None) is None:
res = False
break
return res
def _set_lzma_decompressor(self, coders: List[Dict[str, Any]]) -> None:
filters = [] # type: List[Dict[str, Any]]
for coder in coders:
if coder['numinstreams'] != 1 or coder['numoutstreams'] != 1:
raise UnsupportedCompressionMethodError('Only a simple compression method is currently supported.')
filter_id = self.lzma_methods_map.get(coder['method'], None)
if filter_id is None:
raise UnsupportedCompressionMethodError
properties = coder.get('properties', None)
if properties is not None:
filters[:0] = [lzma._decode_filter_properties(filter_id, properties)] # type: ignore
else:
filters[:0] = [{'id': filter_id}]
self.decompressor = lzma.LZMADecompressor(format=lzma.FORMAT_RAW, filters=filters) # type: Union[bz2.BZ2Decompressor, lzma.LZMADecompressor, ISevenZipDecompressor] # noqa
def _set_alternative_decompressor(self, coders: List[Dict[str, Any]]) -> None:
filter_id = self.alt_methods_map.get(coders[0]['method'], None)
if filter_id == self.FILTER_BZIP2:
self.decompressor = bz2.BZ2Decompressor()
elif filter_id == self.FILTER_ZIP:
self.decompressor = DeflateDecompressor()
elif filter_id == self.FILTER_COPY:
self.decompressor = CopyDecompressor()
elif filter_id == self.FILTER_ZSTD and Zstd:
self.decompressor = ZstdDecompressor()
elif filter_id == self.FILTER_AES:
password = ArchivePassword().get()
properties = coders[0].get('properties', None)
self.decompressor = AESDecompressor(properties, password, coders[1:])
else:
raise UnsupportedCompressionMethodError
def decompress(self, data: bytes, max_length: Optional[int] = None) -> bytes:
self.consumed += len(data)
if max_length is not None:
folder_data = self.decompressor.decompress(data, max_length=max_length)
else:
folder_data = self.decompressor.decompress(data)
# calculate CRC with uncompressed data
if self.crc is not None:
self.digest = calculate_crc32(folder_data, self.digest)
return folder_data
def check_crc(self):
return self.crc == self.digest
class SevenZipCompressor:
"""Main compressor object to configured for each 7zip folder."""
__slots__ = ['filters', 'compressor', 'coders']
lzma_methods_map_r = {
lzma.FILTER_LZMA2: CompressionMethod.LZMA2,
lzma.FILTER_DELTA: CompressionMethod.DELTA,
lzma.FILTER_X86: CompressionMethod.P7Z_BCJ,
}
def __init__(self, filters=None):
if filters is None:
self.filters = [{"id": lzma.FILTER_LZMA2, "preset": 7 | lzma.PRESET_EXTREME}, ]
else:
self.filters = filters
self.compressor = lzma.LZMACompressor(format=lzma.FORMAT_RAW, filters=self.filters)
self.coders = []
for filter in self.filters:
if filter is None:
break
method = self.lzma_methods_map_r[filter['id']]
properties = lzma._encode_filter_properties(filter)
self.coders.append({'method': method, 'properties': properties, 'numinstreams': 1, 'numoutstreams': 1})
def compress(self, data):
return self.compressor.compress(data)
def flush(self):
return self.compressor.flush()
def get_methods_names(coders: List[dict]) -> List[str]:
"""Return human readable method names for specified coders"""
methods_name_map = {
CompressionMethod.LZMA2: "LZMA2",
CompressionMethod.LZMA: "LZMA",
CompressionMethod.DELTA: "delta",
CompressionMethod.P7Z_BCJ: "BCJ",
CompressionMethod.BCJ_ARM: "BCJ(ARM)",
CompressionMethod.BCJ_ARMT: "BCJ(ARMT)",
CompressionMethod.BCJ_IA64: "BCJ(IA64)",
CompressionMethod.BCJ_PPC: "BCJ(POWERPC)",
CompressionMethod.BCJ_SPARC: "BCJ(SPARC)",
CompressionMethod.CRYPT_AES256_SHA256: "7zAES",
}
methods_names = [] # type: List[str]
for coder in coders:
try:
methods_names.append(methods_name_map[coder['method']])
except KeyError:
raise UnsupportedCompressionMethodError("Unknown method {}".format(coder['method']))
return methods_names
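SevenZipCompressor above is essentially a thin wrapper over lzma's raw-stream API; its round trip can be reproduced with the standard library alone (a sketch using the same default filter chain as the class):

import lzma

filters = [{"id": lzma.FILTER_LZMA2, "preset": 7 | lzma.PRESET_EXTREME}]

comp = lzma.LZMACompressor(format=lzma.FORMAT_RAW, filters=filters)
blob = comp.compress(b"hello 7z" * 100) + comp.flush()

decomp = lzma.LZMADecompressor(format=lzma.FORMAT_RAW, filters=filters)
assert decomp.decompress(blob) == b"hello 7z" * 100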

46
py7zr/exceptions.py

@@ -1,46 +0,0 @@
#
# p7zr library
#
# Copyright (c) 2019 Hiroshi Miura <miurahr@linux.com>
# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de
# 7-Zip Copyright (C) 1999-2010 Igor Pavlov
# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
class ArchiveError(Exception):
    pass


class Bad7zFile(ArchiveError):
    pass


class CrcError(ArchiveError):
    pass


class UnsupportedCompressionMethodError(ArchiveError):
    pass


class DecompressionError(ArchiveError):
    pass


class InternalError(ArchiveError):
    pass
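Callers are expected to catch these through the common ArchiveError base class (a usage sketch against the upstream py7zr package; broken.7z is a hypothetical path):

import py7zr
from py7zr.exceptions import ArchiveError, Bad7zFile

try:
    with py7zr.SevenZipFile("broken.7z") as archive:
        archive.extractall()
except Bad7zFile:
    print("not a valid 7z archive")
except ArchiveError as exc:
    print("extraction failed:", exc)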

214
py7zr/extra.py

@@ -1,214 +0,0 @@
#!/usr/bin/python -u
#
# p7zr library
#
# Copyright (c) 2019 Hiroshi Miura <miurahr@linux.com>
# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de
# 7-Zip Copyright (C) 1999-2010 Igor Pavlov
# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
import lzma
import zlib
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union
from Crypto.Cipher import AES
from py7zr import UnsupportedCompressionMethodError
from py7zr.helpers import Buffer, calculate_key
from py7zr.properties import READ_BLOCKSIZE, CompressionMethod
try:
import zstandard as Zstd # type: ignore
except ImportError:
Zstd = None
class ISevenZipCompressor(ABC):
@abstractmethod
def compress(self, data: Union[bytes, bytearray, memoryview]) -> bytes:
pass
@abstractmethod
def flush(self) -> bytes:
pass
class ISevenZipDecompressor(ABC):
@abstractmethod
def decompress(self, data: Union[bytes, bytearray, memoryview], max_length: int = -1) -> bytes:
pass
class DeflateDecompressor(ISevenZipDecompressor):
def __init__(self):
self.buf = b''
self._decompressor = zlib.decompressobj(-15)
def decompress(self, data: Union[bytes, bytearray, memoryview], max_length: int = -1):
if max_length < 0:
res = self.buf + self._decompressor.decompress(data)
self.buf = b''
else:
tmp = self.buf + self._decompressor.decompress(data)
res = tmp[:max_length]
self.buf = tmp[max_length:]
return res
class CopyDecompressor(ISevenZipDecompressor):
def __init__(self):
self._buf = bytes()
def decompress(self, data: Union[bytes, bytearray, memoryview], max_length: int = -1) -> bytes:
if max_length < 0:
length = len(data)
else:
length = min(len(data), max_length)
buflen = len(self._buf)
if length > buflen:
res = self._buf + data[:length - buflen]
self._buf = data[length - buflen:]
else:
res = self._buf[:length]
self._buf = self._buf[length:] + data
return res
class AESDecompressor(ISevenZipDecompressor):
lzma_methods_map = {
CompressionMethod.LZMA: lzma.FILTER_LZMA1,
CompressionMethod.LZMA2: lzma.FILTER_LZMA2,
CompressionMethod.DELTA: lzma.FILTER_DELTA,
CompressionMethod.P7Z_BCJ: lzma.FILTER_X86,
CompressionMethod.BCJ_ARM: lzma.FILTER_ARM,
CompressionMethod.BCJ_ARMT: lzma.FILTER_ARMTHUMB,
CompressionMethod.BCJ_IA64: lzma.FILTER_IA64,
CompressionMethod.BCJ_PPC: lzma.FILTER_POWERPC,
CompressionMethod.BCJ_SPARC: lzma.FILTER_SPARC,
}
def __init__(self, aes_properties: bytes, password: str, coders: List[Dict[str, Any]]) -> None:
byte_password = password.encode('utf-16LE')
firstbyte = aes_properties[0]
numcyclespower = firstbyte & 0x3f
if firstbyte & 0xc0 != 0:
saltsize = (firstbyte >> 7) & 1
ivsize = (firstbyte >> 6) & 1
secondbyte = aes_properties[1]
saltsize += (secondbyte >> 4)
ivsize += (secondbyte & 0x0f)
assert len(aes_properties) == 2 + saltsize + ivsize
salt = aes_properties[2:2 + saltsize]
iv = aes_properties[2 + saltsize:2 + saltsize + ivsize]
assert len(salt) == saltsize
assert len(iv) == ivsize
assert numcyclespower <= 24
if ivsize < 16:
iv += bytes('\x00' * (16 - ivsize), 'ascii')
key = calculate_key(byte_password, numcyclespower, salt, 'sha256')
if len(coders) > 0:
self.lzma_decompressor = self._set_lzma_decompressor(coders) # type: Union[lzma.LZMADecompressor, CopyDecompressor] # noqa
else:
self.lzma_decompressor = CopyDecompressor()
self.cipher = AES.new(key, AES.MODE_CBC, iv)
self.buf = Buffer(size=READ_BLOCKSIZE + 16)
self.flushed = False
else:
raise UnsupportedCompressionMethodError
# set pipeline decompressor
def _set_lzma_decompressor(self, coders: List[Dict[str, Any]]) -> lzma.LZMADecompressor:
filters = [] # type: List[Dict[str, Any]]
for coder in coders:
filter = self.lzma_methods_map.get(coder['method'], None)
if filter is not None:
properties = coder.get('properties', None)
if properties is not None:
filters[:0] = [lzma._decode_filter_properties(filter, properties)] # type: ignore
else:
filters[:0] = [{'id': filter}]
else:
raise UnsupportedCompressionMethodError
return lzma.LZMADecompressor(format=lzma.FORMAT_RAW, filters=filters)
def decompress(self, data: Union[bytes, bytearray, memoryview], max_length: int = -1) -> bytes:
if len(data) == 0 and len(self.buf) == 0: # action flush
return self.lzma_decompressor.decompress(b'', max_length)
elif len(data) == 0: # action padding
self.flushed = True
# align = 16
# padlen = (align - offset % align) % align
# = (align - (offset & (align - 1))) & (align - 1)
# = -offset & (align -1)
# = -offset & (16 - 1) = -offset & 15
padlen = -len(self.buf) & 15
self.buf.add(bytes(padlen))
temp = self.cipher.decrypt(self.buf.view) # type: bytes
self.buf.reset()
return self.lzma_decompressor.decompress(temp, max_length)
else:
currentlen = len(self.buf) + len(data)
nextpos = (currentlen // 16) * 16
if currentlen == nextpos:
self.buf.add(data)
temp = self.cipher.decrypt(self.buf.view)
self.buf.reset()
return self.lzma_decompressor.decompress(temp, max_length)
else:
buflen = len(self.buf)
temp2 = data[nextpos - buflen:]
self.buf.add(data[:nextpos - buflen])
temp = self.cipher.decrypt(self.buf.view)
self.buf.set(temp2)
return self.lzma_decompressor.decompress(temp, max_length)
class ZstdDecompressor(ISevenZipDecompressor):
def __init__(self):
if Zstd is None:
raise UnsupportedCompressionMethodError
self.buf = b'' # type: bytes
self._ctc = Zstd.ZstdDecompressor() # type: ignore
def decompress(self, data: Union[bytes, bytearray, memoryview], max_length: int = -1) -> bytes:
dobj = self._ctc.decompressobj() # type: ignore
if max_length < 0:
res = self.buf + dobj.decompress(data)
self.buf = b''
else:
tmp = self.buf + dobj.decompress(data)
res = tmp[:max_length]
self.buf = tmp[max_length:]
return res
class ZstdCompressor(ISevenZipCompressor):
def __init__(self):
if Zstd is None:
raise UnsupportedCompressionMethodError
self._ctc = Zstd.ZstdCompressor() # type: ignore
def compress(self, data: Union[bytes, bytearray, memoryview]) -> bytes:
return self._ctc.compress(data) # type: ignore
def flush(self):
pass
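DeflateDecompressor above is a buffer-managing wrapper around zlib's raw-deflate mode; the behaviour it wraps can be checked with zlib alone (a sketch):

import zlib

# Negative wbits selects raw deflate with no zlib header, which is what
# the 7z MISC_DEFLATE coder stores.
comp = zlib.compressobj(9, zlib.DEFLATED, -15)
blob = comp.compress(b"payload" * 50) + comp.flush()
assert zlib.decompressobj(-15).decompress(blob) == b"payload" * 50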

397
py7zr/helpers.py

@@ -1,397 +0,0 @@
#!/usr/bin/python -u
#
# p7zr library
#
# Copyright (c) 2019 Hiroshi Miura <miurahr@linux.com>
# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
#
import _hashlib # type: ignore # noqa
import ctypes
import os
import pathlib
import platform
import sys
import time as _time
import zlib
from datetime import datetime, timedelta, timezone, tzinfo
from typing import BinaryIO, Optional, Union
import py7zr.win32compat
def calculate_crc32(data: bytes, value: Optional[int] = None, blocksize: int = 1024 * 1024) -> int:
"""Calculate CRC32 of strings with arbitrary lengths."""
length = len(data)
pos = blocksize
if value:
value = zlib.crc32(data[:pos], value)
else:
value = zlib.crc32(data[:pos])
while pos < length:
value = zlib.crc32(data[pos:pos + blocksize], value)
pos += blocksize
return value & 0xffffffff
def _calculate_key1(password: bytes, cycles: int, salt: bytes, digest: str) -> bytes:
"""Calculate 7zip AES encryption key. Base implementation. """
if digest not in ('sha256',):
raise ValueError('Unknown digest method for password protection.')
assert cycles <= 0x3f
if cycles == 0x3f:
ba = bytearray(salt + password + bytes(32))
key = bytes(ba[:32]) # type: bytes
else:
rounds = 1 << cycles
m = _hashlib.new(digest)
for round in range(rounds):
m.update(salt + password + round.to_bytes(8, byteorder='little', signed=False))
key = m.digest()[:32]
return key
def _calculate_key2(password: bytes, cycles: int, salt: bytes, digest: str):
"""Calculate 7zip AES encryption key.
It utilizes ctypes, a memoryview buffer, and zero-copy techniques in Python."""
if digest not in ('sha256',):
raise ValueError('Unknown digest method for password protection.')
assert cycles <= 0x3f
if cycles == 0x3f:
key = bytes(bytearray(salt + password + bytes(32))[:32]) # type: bytes
else:
rounds = 1 << cycles
m = _hashlib.new(digest)
length = len(salt) + len(password)
class RoundBuf(ctypes.LittleEndianStructure):
_pack_ = 1
_fields_ = [
('saltpassword', ctypes.c_ubyte * length),
('round', ctypes.c_uint64)
]
buf = RoundBuf()
for i, c in enumerate(salt + password):
buf.saltpassword[i] = c
buf.round = 0
mv = memoryview(buf) # type: ignore # noqa
while buf.round < rounds:
m.update(mv)
buf.round += 1
key = m.digest()[:32]
return key
def _calculate_key3(password: bytes, cycles: int, salt: bytes, digest: str) -> bytes:
"""Calculate 7zip AES encryption key.
Concat values in order to reduce number of calls of Hash.update()."""
if digest not in ('sha256',):
raise ValueError('Unknown digest method for password protection.')
assert cycles <= 0x3f
if cycles == 0x3f:
ba = bytearray(salt + password + bytes(32))
key = bytes(ba[:32]) # type: bytes
else:
cat_cycle = 6
if cycles > cat_cycle:
rounds = 1 << cat_cycle
stages = 1 << (cycles - cat_cycle)
else:
rounds = 1 << cycles
stages = 1 << 0
m = _hashlib.new(digest)
saltpassword = salt + password
s = 0 # type: int # (0..stages) * rounds
if platform.python_implementation() == "PyPy":
for _ in range(stages):
m.update(memoryview(b''.join([saltpassword + (s + i).to_bytes(8, byteorder='little', signed=False)
for i in range(rounds)])))
s += rounds
else:
for _ in range(stages):
m.update(b''.join([saltpassword + (s + i).to_bytes(8, byteorder='little', signed=False)
for i in range(rounds)]))
s += rounds
key = m.digest()[:32]
return key
if platform.python_implementation() == "PyPy" or sys.version_info > (3, 6):
calculate_key = _calculate_key3
else:
calculate_key = _calculate_key2  # it is faster on CPython 3.6.x
def filetime_to_dt(ft):
"""Convert Windows NTFS file time into python datetime object."""
EPOCH_AS_FILETIME = 116444736000000000
us = (ft - EPOCH_AS_FILETIME) // 10
return datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(microseconds=us)
ZERO = timedelta(0)
HOUR = timedelta(hours=1)
SECOND = timedelta(seconds=1)
# A class capturing the platform's idea of local time.
# (May result in wrong values on historical times in
# timezones where UTC offset and/or the DST rules had
# changed in the past.)
STDOFFSET = timedelta(seconds=-_time.timezone)
if _time.daylight:
DSTOFFSET = timedelta(seconds=-_time.altzone)
else:
DSTOFFSET = STDOFFSET
DSTDIFF = DSTOFFSET - STDOFFSET
class LocalTimezone(tzinfo):
def fromutc(self, dt):
assert dt.tzinfo is self
stamp = (dt - datetime(1970, 1, 1, tzinfo=self)) // SECOND
args = _time.localtime(stamp)[:6]
dst_diff = DSTDIFF // SECOND
# Detect fold
fold = (args == _time.localtime(stamp - dst_diff))
return datetime(*args, microsecond=dt.microsecond, tzinfo=self)
def utcoffset(self, dt):
if self._isdst(dt):
return DSTOFFSET
else:
return STDOFFSET
def dst(self, dt):
if self._isdst(dt):
return DSTDIFF
else:
return ZERO
def tzname(self, dt):
return _time.tzname[self._isdst(dt)]
def _isdst(self, dt):
tt = (dt.year, dt.month, dt.day,
dt.hour, dt.minute, dt.second,
dt.weekday(), 0, 0)
stamp = _time.mktime(tt)
tt = _time.localtime(stamp)
return tt.tm_isdst > 0
Local = LocalTimezone()
TIMESTAMP_ADJUST = -11644473600
class UTC(tzinfo):
"""UTC"""
def utcoffset(self, dt):
return ZERO
def tzname(self, dt):
return "UTC"
def dst(self, dt):
return ZERO
def __call__(self):
return self
class ArchiveTimestamp(int):
"""Windows FILETIME timestamp."""
def __repr__(self):
return '%s(%d)' % (type(self).__name__, self)
def totimestamp(self) -> float:
"""Convert 7z FILETIME to Python timestamp."""
# FILETIME is 100-nanosecond intervals since 1601/01/01 (UTC)
return (self / 10000000.0) + TIMESTAMP_ADJUST
def as_datetime(self):
"""Convert FILETIME to Python datetime object."""
return datetime.fromtimestamp(self.totimestamp(), UTC())
@staticmethod
def from_datetime(val):
return ArchiveTimestamp((val - TIMESTAMP_ADJUST) * 10000000.0)
def islink(path):
"""
Cross-platform islink implementation.
Supports Windows NT symbolic links and reparse points.
"""
is_symlink = os.path.islink(str(path))
if sys.version_info >= (3, 8) or sys.platform != "win32" or sys.getwindowsversion()[0] < 6:
return is_symlink
# special check for directory junctions which py38 does.
if is_symlink:
if py7zr.win32compat.is_reparse_point(path):
is_symlink = False
return is_symlink
def readlink(path: Union[str, pathlib.Path], *, dir_fd=None) -> Union[str, pathlib.Path]:
"""
Cross-platform compat implementation of os.readlink and Path.readlink().
Supports Windows NT symbolic links and reparse points.
When called with path argument as pathlike(str), return result as a pathlike(str).
When called with Path object, return also Path object.
When called with path argument as bytes, return result as a bytes.
"""
is_path_pathlib = isinstance(path, pathlib.Path)
if sys.version_info >= (3, 9):
if is_path_pathlib and dir_fd is None:
return path.readlink()
else:
return os.readlink(path, dir_fd=dir_fd)
elif sys.version_info >= (3, 8) or sys.platform != "win32":
res = os.readlink(path, dir_fd=dir_fd)
# Hack to handle a wrong type of results
if isinstance(res, bytes):
res = os.fsdecode(res)
if is_path_pathlib:
return pathlib.Path(res)
else:
return res
elif not os.path.exists(str(path)):
raise OSError(22, 'Invalid argument', path)
return py7zr.win32compat.readlink(path)
class MemIO:
"""pathlib.Path-like IO class to write memory(io.Bytes)"""
def __init__(self, buf: BinaryIO):
self._buf = buf
def write(self, data: bytes) -> int:
return self._buf.write(data)
def read(self, length: Optional[int] = None) -> bytes:
if length is not None:
return self._buf.read(length)
else:
return self._buf.read()
def close(self) -> None:
self._buf.seek(0)
def flush(self) -> None:
pass
def seek(self, position: int) -> None:
self._buf.seek(position)
def open(self, mode=None):
return self
@property
def parent(self):
return self
def mkdir(self, parents=None, exist_ok=False):
return None
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class NullIO:
"""pathlib.Path-like IO class of /dev/null"""
def __init__(self):
pass
def write(self, data):
return len(data)
def read(self, length=None):
if length is not None:
return bytes(length)
else:
return b''
def close(self):
pass
def flush(self):
pass
def open(self, mode=None):
return self
@property
def parent(self):
return self
def mkdir(self):
return None
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class BufferOverflow(Exception):
pass
class Buffer:
def __init__(self, size: int = 16):
self._size = size
self._buf = bytearray(size)
self._buflen = 0
self.view = memoryview(self._buf[0:0])
def add(self, data: Union[bytes, bytearray, memoryview]):
length = len(data)
if length + self._buflen > self._size:
raise BufferOverflow()
self._buf[self._buflen:self._buflen + length] = data
self._buflen += length
self.view = memoryview(self._buf[0:self._buflen])
def reset(self) -> None:
self._buflen = 0
self.view = memoryview(self._buf[0:0])
def set(self, data: Union[bytes, bytearray, memoryview]) -> None:
length = len(data)
if length > self._size:
raise BufferOverflow()
self._buf[0:length] = data
self._buflen = length
self.view = memoryview(self._buf[0:length])
def __len__(self) -> int:
return self._buflen
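The chunked calculate_crc32 above should agree with a one-shot zlib.crc32 regardless of block size (a quick consistency check, assuming the removed module is still importable):

import zlib
from py7zr.helpers import calculate_crc32

data = bytes(range(256)) * 5000
assert calculate_crc32(data, blocksize=4096) == zlib.crc32(data) & 0xffffffff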

155
py7zr/properties.py

@@ -1,155 +0,0 @@
#
# p7zr library
#
# Copyright (c) 2019 Hiroshi Miura <miurahr@linux.com>
# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de
# 7-Zip Copyright (C) 1999-2010 Igor Pavlov
# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
import binascii
from enum import Enum
from typing import Optional
MAGIC_7Z = binascii.unhexlify('377abcaf271c')
FINISH_7Z = binascii.unhexlify('377abcaf271d')
READ_BLOCKSIZE = 32248
QUEUELEN = READ_BLOCKSIZE * 2
class ByteEnum(bytes, Enum):
pass
class Property(ByteEnum):
"""Hold 7zip property fixed values."""
END = binascii.unhexlify('00')
HEADER = binascii.unhexlify('01')
ARCHIVE_PROPERTIES = binascii.unhexlify('02')
ADDITIONAL_STREAMS_INFO = binascii.unhexlify('03')
MAIN_STREAMS_INFO = binascii.unhexlify('04')
FILES_INFO = binascii.unhexlify('05')
PACK_INFO = binascii.unhexlify('06')
UNPACK_INFO = binascii.unhexlify('07')
SUBSTREAMS_INFO = binascii.unhexlify('08')
SIZE = binascii.unhexlify('09')
CRC = binascii.unhexlify('0a')
FOLDER = binascii.unhexlify('0b')
CODERS_UNPACK_SIZE = binascii.unhexlify('0c')
NUM_UNPACK_STREAM = binascii.unhexlify('0d')
EMPTY_STREAM = binascii.unhexlify('0e')
EMPTY_FILE = binascii.unhexlify('0f')
ANTI = binascii.unhexlify('10')
NAME = binascii.unhexlify('11')
CREATION_TIME = binascii.unhexlify('12')
LAST_ACCESS_TIME = binascii.unhexlify('13')
LAST_WRITE_TIME = binascii.unhexlify('14')
ATTRIBUTES = binascii.unhexlify('15')
COMMENT = binascii.unhexlify('16')
ENCODED_HEADER = binascii.unhexlify('17')
START_POS = binascii.unhexlify('18')
DUMMY = binascii.unhexlify('19')
class CompressionMethod(ByteEnum):
"""Hold fixed values for method parameter."""
COPY = binascii.unhexlify('00')
DELTA = binascii.unhexlify('03')
BCJ = binascii.unhexlify('04')
PPC = binascii.unhexlify('05')
IA64 = binascii.unhexlify('06')
ARM = binascii.unhexlify('07')
ARMT = binascii.unhexlify('08')
SPARC = binascii.unhexlify('09')
# SWAP = 02..
SWAP2 = binascii.unhexlify('020302')
SWAP4 = binascii.unhexlify('020304')
# 7Z = 03..
LZMA = binascii.unhexlify('030101')
PPMD = binascii.unhexlify('030401')
P7Z_BCJ = binascii.unhexlify('03030103')
P7Z_BCJ2 = binascii.unhexlify('0303011B')
BCJ_PPC = binascii.unhexlify('03030205')
BCJ_IA64 = binascii.unhexlify('03030401')
BCJ_ARM = binascii.unhexlify('03030501')
BCJ_ARMT = binascii.unhexlify('03030701')
BCJ_SPARC = binascii.unhexlify('03030805')
LZMA2 = binascii.unhexlify('21')
# MISC : 04..
MISC_ZIP = binascii.unhexlify('0401')
MISC_BZIP2 = binascii.unhexlify('040202')
MISC_DEFLATE = binascii.unhexlify('040108')
MISC_DEFLATE64 = binascii.unhexlify('040109')
MISC_Z = binascii.unhexlify('0405')
MISC_LZH = binascii.unhexlify('0406')
NSIS_DEFLATE = binascii.unhexlify('040901')
NSIS_BZIP2 = binascii.unhexlify('040902')
#
MISC_ZSTD = binascii.unhexlify('04f71101')
MISC_BROTLI = binascii.unhexlify('04f71102')
MISC_LZ4 = binascii.unhexlify('04f71104')
MISC_LZS = binascii.unhexlify('04f71105')
MISC_LIZARD = binascii.unhexlify('04f71106')
# CRYPTO 06..
CRYPT_ZIPCRYPT = binascii.unhexlify('06f10101')
CRYPT_RAR29AES = binascii.unhexlify('06f10303')
CRYPT_AES256_SHA256 = binascii.unhexlify('06f10701')
class SupportedMethods:
"""Hold list of methods which python3 can support."""
formats = [{'name': "7z", 'magic': MAGIC_7Z}]
codecs = [{'id': CompressionMethod.LZMA, 'name': "LZMA"},
{'id': CompressionMethod.LZMA2, 'name': "LZMA2"},
{'id': CompressionMethod.DELTA, 'name': "DELTA"},
{'id': CompressionMethod.P7Z_BCJ, 'name': "BCJ"},
{'id': CompressionMethod.BCJ_PPC, 'name': 'PPC'},
{'id': CompressionMethod.BCJ_IA64, 'name': 'IA64'},
{'id': CompressionMethod.BCJ_ARM, 'name': "ARM"},
{'id': CompressionMethod.BCJ_ARMT, 'name': "ARMT"},
{'id': CompressionMethod.BCJ_SPARC, 'name': 'SPARC'}
]
# this class is Borg/Singleton
class ArchivePassword:
_shared_state = {
'_password': None,
}
def __init__(self, password: Optional[str] = None):
self.__dict__ = self._shared_state
if password is not None:
self._password = password
def set(self, password):
self._password = password
def get(self):
if self._password is not None:
return self._password
else:
return ''
def __str__(self):
if self._password is not None:
return self._password
else:
return ''
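Because ArchivePassword is a Borg, every instance shares one state dictionary, so a password set through any instance is visible through all of them (a small demonstration, assuming the removed module is still importable):

from py7zr.properties import ArchivePassword

ArchivePassword("secret")                   # set via one instance...
assert ArchivePassword().get() == "secret"  # ...and read via another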

974
py7zr/py7zr.py

@@ -1,974 +0,0 @@
#!/usr/bin/python -u
#
# p7zr library
#
# Copyright (c) 2019,2020 Hiroshi Miura <miurahr@linux.com>
# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de
# 7-Zip Copyright (C) 1999-2010 Igor Pavlov
# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
#
"""Read 7zip format archives."""
import collections.abc
import datetime
import errno
import functools
import io
import operator
import os
import queue
import stat
import sys
import threading
from io import BytesIO
from typing import IO, Any, BinaryIO, Dict, List, Optional, Tuple, Union
from py7zr.archiveinfo import Folder, Header, SignatureHeader
from py7zr.callbacks import ExtractCallback
from py7zr.compression import SevenZipCompressor, Worker, get_methods_names
from py7zr.exceptions import Bad7zFile, CrcError, DecompressionError, InternalError
from py7zr.helpers import ArchiveTimestamp, MemIO, calculate_crc32, filetime_to_dt
from py7zr.properties import MAGIC_7Z, READ_BLOCKSIZE, ArchivePassword
if sys.version_info < (3, 6):
import contextlib2 as contextlib
import pathlib2 as pathlib
else:
import contextlib
import pathlib
if sys.platform.startswith('win'):
import _winapi
FILE_ATTRIBUTE_UNIX_EXTENSION = 0x8000
FILE_ATTRIBUTE_WINDOWS_MASK = 0x04fff
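# When FILE_ATTRIBUTE_UNIX_EXTENSION is set, the upper 16 bits of a member's
# attributes word carry its POSIX st_mode (see ArchiveFile._get_unix_extension
# and SevenZipFile._make_file_info below).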
class ArchiveFile:
"""Represent each files metadata inside archive file.
It holds file properties; filename, permissions, and type whether
it is directory, link or normal file.
Instances of the :class:`ArchiveFile` class are returned by iterating :attr:`files_list` of
:class:`SevenZipFile` objects.
Each object stores information about a single member of the 7z archive. Most of users use :meth:`extractall()`.
The class also hold an archive parameter where file is exist in
archive file folder(container)."""
def __init__(self, id: int, file_info: Dict[str, Any]) -> None:
self.id = id
self._file_info = file_info
def file_properties(self) -> Dict[str, Any]:
"""Return file properties as a hash object. Following keys are included: readonly, is_directory,
posix_mode, archivable, emptystream, filename, creationtime, lastaccesstime,
lastwritetime, attributes
"""
properties = self._file_info
if properties is not None:
properties['readonly'] = self.readonly
properties['posix_mode'] = self.posix_mode
properties['archivable'] = self.archivable
properties['is_directory'] = self.is_directory
return properties
def _get_property(self, key: str) -> Any:
try:
return self._file_info[key]
except KeyError:
return None
@property
def origin(self) -> pathlib.Path:
return self._get_property('origin')
@property
def folder(self) -> Folder:
return self._get_property('folder')
@property
def filename(self) -> str:
"""return filename of archive file."""
return self._get_property('filename')
@property
def emptystream(self) -> bool:
"""True if file is empty(0-byte file), otherwise False"""
return self._get_property('emptystream')
@property
def uncompressed(self) -> List[int]:
return self._get_property('uncompressed')
@property
def uncompressed_size(self) -> int:
"""Uncompressed file size."""
return functools.reduce(operator.add, self.uncompressed)
@property
def compressed(self) -> Optional[int]:
"""Compressed size"""
return self._get_property('compressed')
@property
def crc32(self) -> Optional[int]:
"""CRC of archived file(optional)"""
return self._get_property('digest')
def _test_attribute(self, target_bit: int) -> bool:
attributes = self._get_property('attributes')
if attributes is None:
return False
return attributes & target_bit == target_bit
@property
def archivable(self) -> bool:
"""File has a Windows `archive` flag."""
return self._test_attribute(stat.FILE_ATTRIBUTE_ARCHIVE) # type: ignore # noqa
@property
def is_directory(self) -> bool:
"""True if file is a directory, otherwise False."""
return self._test_attribute(stat.FILE_ATTRIBUTE_DIRECTORY) # type: ignore # noqa
@property
def readonly(self) -> bool:
"""True if file is readonly, otherwise False."""
return self._test_attribute(stat.FILE_ATTRIBUTE_READONLY) # type: ignore # noqa
def _get_unix_extension(self) -> Optional[int]:
attributes = self._get_property('attributes')
if self._test_attribute(FILE_ATTRIBUTE_UNIX_EXTENSION):
return attributes >> 16
return None
@property
def is_symlink(self) -> bool:
"""True if file is a symbolic link, otherwise False."""
e = self._get_unix_extension()
if e is not None:
return stat.S_ISLNK(e)
return self._test_attribute(stat.FILE_ATTRIBUTE_REPARSE_POINT) # type: ignore # noqa
@property
def is_junction(self) -> bool:
"""True if file is a junction/reparse point on windows, otherwise False."""
return self._test_attribute(stat.FILE_ATTRIBUTE_REPARSE_POINT | # type: ignore # noqa
stat.FILE_ATTRIBUTE_DIRECTORY) # type: ignore # noqa
@property
def is_socket(self) -> bool:
"""True if file is a socket, otherwise False."""
e = self._get_unix_extension()
if e is not None:
return stat.S_ISSOCK(e)
return False
@property
def lastwritetime(self) -> Optional[ArchiveTimestamp]:
"""Return last written timestamp of a file."""
return self._get_property('lastwritetime')
@property
def posix_mode(self) -> Optional[int]:
"""
posix mode when a member has a unix extension property, or None
:return: Return file stat mode can be set by os.chmod()
"""
e = self._get_unix_extension()
if e is not None:
return stat.S_IMODE(e)
return None
@property
def st_fmt(self) -> Optional[int]:
"""
:return: Return the portion of the file mode that describes the file type
"""
e = self._get_unix_extension()
if e is not None:
return stat.S_IFMT(e)
return None
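# Typical read-only use of the properties above (a sketch; ``archive`` is a
# hypothetical SevenZipFile opened in 'r' mode):
#
#   for member in archive.files:
#       print(member.filename, member.uncompressed_size, member.is_directory)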
class ArchiveFileList(collections.abc.Iterable):
"""Iteratable container of ArchiveFile."""
def __init__(self, offset: int = 0):
self.files_list = [] # type: List[dict]
self.index = 0
self.offset = offset
def append(self, file_info: Dict[str, Any]) -> None:
self.files_list.append(file_info)
def __len__(self) -> int:
return len(self.files_list)
def __iter__(self) -> 'ArchiveFileListIterator':
return ArchiveFileListIterator(self)
def __getitem__(self, index):
if not 0 <= index < len(self.files_list):
    raise IndexError
res = ArchiveFile(index + self.offset, self.files_list[index])
return res
class ArchiveFileListIterator(collections.abc.Iterator):
def __init__(self, archive_file_list):
self._archive_file_list = archive_file_list
self._index = 0
def __next__(self) -> ArchiveFile:
if self._index == len(self._archive_file_list):
raise StopIteration
res = self._archive_file_list[self._index]
self._index += 1
return res
# ------------------
# Exported Classes
# ------------------
class ArchiveInfo:
"""Hold archive information"""
def __init__(self, filename, size, header_size, method_names, solid, blocks, uncompressed):
self.filename = filename
self.size = size
self.header_size = header_size
self.method_names = method_names
self.solid = solid
self.blocks = blocks
self.uncompressed = uncompressed
class FileInfo:
"""Hold archived file information."""
def __init__(self, filename, compressed, uncompressed, archivable, is_directory, creationtime, crc32):
self.filename = filename
self.compressed = compressed
self.uncompressed = uncompressed
self.archivable = archivable
self.is_directory = is_directory
self.creationtime = creationtime
self.crc32 = crc32
class SevenZipFile(contextlib.AbstractContextManager):
"""The SevenZipFile Class provides an interface to 7z archives."""
def __init__(self, file: Union[BinaryIO, str, pathlib.Path], mode: str = 'r',
*, filters: Optional[str] = None, dereference=False, password: Optional[str] = None) -> None:
if mode not in ('r', 'w', 'x', 'a'):
raise ValueError("ZipFile requires mode 'r', 'w', 'x', or 'a'")
if password is not None:
if mode not in ('r'):
raise NotImplementedError("It has not been implemented to create archive with password.")
ArchivePassword(password)
self.password_protected = True
else:
self.password_protected = False
# Check if we were passed a file-like object or not
if isinstance(file, str):
self._filePassed = False # type: bool
self.filename = file # type: str
if mode == 'r':
self.fp = open(file, 'rb') # type: BinaryIO
elif mode == 'w':
self.fp = open(file, 'w+b')
elif mode == 'x':
self.fp = open(file, 'x+b')
elif mode == 'a':
self.fp = open(file, 'r+b')
else:
raise ValueError("File open error.")
self.mode = mode
elif isinstance(file, pathlib.Path):
self._filePassed = False
self.filename = str(file)
if mode == 'r':
self.fp = file.open(mode='rb') # type: ignore # noqa # typeshed issue: 2911
elif mode == 'w':
self.fp = file.open(mode='w+b') # type: ignore # noqa
elif mode == 'x':
self.fp = file.open(mode='x+b') # type: ignore # noqa
elif mode == 'a':
self.fp = file.open(mode='r+b') # type: ignore # noqa
else:
raise ValueError("File open error.")
self.mode = mode
elif isinstance(file, io.IOBase):
self._filePassed = True
self.fp = file
self.filename = getattr(file, 'name', None)
self.mode = mode # type: ignore #noqa
else:
raise TypeError("invalid file: {}".format(type(file)))
self._fileRefCnt = 1
try:
if mode == "r":
self._real_get_contents(self.fp)
self._reset_worker()
elif mode in 'w':
# FIXME: check filters here
self.folder = self._create_folder(filters)
self.files = ArchiveFileList()
self._prepare_write()
self._reset_worker()
elif mode in 'x':
raise NotImplementedError
elif mode == 'a':
raise NotImplementedError
else:
raise ValueError("Mode must be 'r', 'w', 'x', or 'a'")
except Exception as e:
self._fpclose()
raise e
self.encoded_header_mode = False
self._dict = {} # type: Dict[str, IO[Any]]
self.dereference = dereference
self.reporterd = None # type: Optional[threading.Thread]
self.q = queue.Queue() # type: queue.Queue[Any]
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def _create_folder(self, filters):
folder = Folder()
folder.compressor = SevenZipCompressor(filters)
folder.coders = folder.compressor.coders
folder.solid = True
folder.digestdefined = False
folder.bindpairs = []
folder.totalin = 1
folder.totalout = 1
return folder
def _fpclose(self) -> None:
assert self._fileRefCnt > 0
self._fileRefCnt -= 1
if not self._fileRefCnt and not self._filePassed:
self.fp.close()
def _real_get_contents(self, fp: BinaryIO) -> None:
if not self._check_7zfile(fp):
raise Bad7zFile('not a 7z file')
self.sig_header = SignatureHeader.retrieve(self.fp)
self.afterheader = self.fp.tell()
buffer = self._read_header_data()
header = Header.retrieve(self.fp, buffer, self.afterheader)
if header is None:
return
self.header = header
buffer.close()
self.files = ArchiveFileList()
if getattr(self.header, 'files_info', None) is not None:
self._filelist_retrieve()
def _read_header_data(self) -> BytesIO:
self.fp.seek(self.sig_header.nextheaderofs, os.SEEK_CUR)
buffer = io.BytesIO(self.fp.read(self.sig_header.nextheadersize))
if self.sig_header.nextheadercrc != calculate_crc32(buffer.getvalue()):
raise Bad7zFile('invalid header data')
return buffer
class ParseStatus:
def __init__(self, src_pos=0):
self.src_pos = src_pos
self.folder = 0 # 7zip folder where target stored
self.outstreams = 0 # output stream count
self.input = 0 # unpack stream count in each folder
self.stream = 0 # target input stream position
def _gen_filename(self) -> str:
# compressed file is stored without a name, generate one
try:
basefilename = self.filename
except AttributeError:
# 7z archive file doesn't have a name
return 'contents'
else:
if basefilename is not None:
fn, ext = os.path.splitext(os.path.basename(basefilename))
return fn
else:
return 'contents'
def _get_fileinfo_sizes(self, pstat, subinfo, packinfo, folder, packsizes, unpacksizes, file_in_solid, numinstreams):
if pstat.input == 0:
folder.solid = subinfo.num_unpackstreams_folders[pstat.folder] > 1
maxsize = (folder.solid and packinfo.packsizes[pstat.stream]) or None
uncompressed = unpacksizes[pstat.outstreams]
if not isinstance(uncompressed, (list, tuple)):
uncompressed = [uncompressed] * len(folder.coders)
if file_in_solid > 0:
compressed = None
elif pstat.stream < len(packsizes): # file is compressed
compressed = packsizes[pstat.stream]
else: # file is not compressed
compressed = uncompressed
packsize = packsizes[pstat.stream:pstat.stream + numinstreams]
return maxsize, compressed, uncompressed, packsize, folder.solid
def _filelist_retrieve(self) -> None:
# Initialize references for convenience
if hasattr(self.header, 'main_streams') and self.header.main_streams is not None:
folders = self.header.main_streams.unpackinfo.folders
packinfo = self.header.main_streams.packinfo
subinfo = self.header.main_streams.substreamsinfo
packsizes = packinfo.packsizes
unpacksizes = subinfo.unpacksizes if subinfo.unpacksizes is not None else [x.unpacksizes for x in folders]
else:
subinfo = None
folders = None
packinfo = None
packsizes = []
unpacksizes = [0]
pstat = self.ParseStatus()
pstat.src_pos = self.afterheader
file_in_solid = 0
for file_id, file_info in enumerate(self.header.files_info.files):
if not file_info['emptystream'] and folders is not None:
folder = folders[pstat.folder]
numinstreams = max([coder.get('numinstreams', 1) for coder in folder.coders])
(maxsize, compressed, uncompressed,
packsize, solid) = self._get_fileinfo_sizes(pstat, subinfo, packinfo, folder, packsizes,
unpacksizes, file_in_solid, numinstreams)
pstat.input += 1
folder.solid = solid
file_info['folder'] = folder
file_info['maxsize'] = maxsize
file_info['compressed'] = compressed
file_info['uncompressed'] = uncompressed
file_info['packsizes'] = packsize
if subinfo.digestsdefined[pstat.outstreams]:
file_info['digest'] = subinfo.digests[pstat.outstreams]
if folder is None:
pstat.src_pos += file_info['compressed']
else:
if folder.solid:
file_in_solid += 1
pstat.outstreams += 1
if folder.files is None:
folder.files = ArchiveFileList(offset=file_id)
folder.files.append(file_info)
if pstat.input >= subinfo.num_unpackstreams_folders[pstat.folder]:
file_in_solid = 0
pstat.src_pos += sum(packinfo.packsizes[pstat.stream:pstat.stream + numinstreams])
pstat.folder += 1
pstat.stream += numinstreams
pstat.input = 0
else:
file_info['folder'] = None
file_info['maxsize'] = 0
file_info['compressed'] = 0
file_info['uncompressed'] = [0]
file_info['packsizes'] = [0]
if 'filename' not in file_info:
file_info['filename'] = self._gen_filename()
self.files.append(file_info)
def _num_files(self) -> int:
if getattr(self.header, 'files_info', None) is not None:
return len(self.header.files_info.files)
return 0
def _set_file_property(self, outfilename: pathlib.Path, properties: Dict[str, Any]) -> None:
# creation time
creationtime = ArchiveTimestamp(properties['lastwritetime']).totimestamp()
if creationtime is not None:
os.utime(str(outfilename), times=(creationtime, creationtime))
if os.name == 'posix':
st_mode = properties['posix_mode']
if st_mode is not None:
outfilename.chmod(st_mode)
return
# fallback: only set readonly if specified
if properties['readonly'] and not properties['is_directory']:
ro_mask = 0o777 ^ (stat.S_IWRITE | stat.S_IWGRP | stat.S_IWOTH)
outfilename.chmod(outfilename.stat().st_mode & ro_mask)
def _reset_decompressor(self) -> None:
if self.header.main_streams is not None and self.header.main_streams.unpackinfo.numfolders > 0:
for i, folder in enumerate(self.header.main_streams.unpackinfo.folders):
folder.decompressor = None
def _reset_worker(self) -> None:
"""Seek to where archive data start in archive and recreate new worker."""
self.fp.seek(self.afterheader)
self.worker = Worker(self.files, self.afterheader, self.header)
def set_encoded_header_mode(self, mode: bool) -> None:
self.encoded_header_mode = mode
@staticmethod
def _check_7zfile(fp: Union[BinaryIO, io.BufferedReader]) -> bool:
result = MAGIC_7Z == fp.read(len(MAGIC_7Z))[:len(MAGIC_7Z)]
fp.seek(-len(MAGIC_7Z), 1)
return result
def _get_method_names(self) -> str:
methods_names = [] # type: List[str]
for folder in self.header.main_streams.unpackinfo.folders:
methods_names += get_methods_names(folder.coders)
return ', '.join(x for x in methods_names)
def _test_digest_raw(self, pos: int, size: int, crc: int) -> bool:
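# Read the packed stream in READ_BLOCKSIZE chunks, folding each block into a
# running CRC32, then compare the final digest with the value from the header.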
self.fp.seek(pos)
remaining_size = size
digest = None
while remaining_size > 0:
block = min(READ_BLOCKSIZE, remaining_size)
digest = calculate_crc32(self.fp.read(block), digest)
remaining_size -= block
return digest == crc
def _prepare_write(self) -> None:
self.sig_header = SignatureHeader()
self.sig_header._write_skelton(self.fp)
self.afterheader = self.fp.tell()
self.folder.totalin = 1
self.folder.totalout = 1
self.folder.bindpairs = []
self.folder.unpacksizes = []
self.header = Header.build_header([self.folder])
def _write_archive(self):
self.worker.archive(self.fp, self.folder, deref=self.dereference)
# Write header and update signature header
(header_pos, header_len, header_crc) = self.header.write(self.fp, self.afterheader,
encoded=self.encoded_header_mode)
self.sig_header.nextheaderofs = header_pos - self.afterheader
self.sig_header.calccrc(header_len, header_crc)
self.sig_header.write(self.fp)
return
def _is_solid(self):
for f in self.header.main_streams.substreamsinfo.num_unpackstreams_folders:
if f > 1:
return True
return False
def _var_release(self):
self._dict = None
self.files = None
self.folder = None
self.header = None
self.worker = None
self.sig_header = None
@staticmethod
def _make_file_info(target: pathlib.Path, arcname: Optional[str] = None, dereference=False) -> Dict[str, Any]:
f = {} # type: Dict[str, Any]
f['origin'] = target
if arcname is not None:
f['filename'] = pathlib.Path(arcname).as_posix()
else:
f['filename'] = target.as_posix()
if os.name == 'nt':
fstat = target.lstat()
if target.is_symlink():
if dereference:
fstat = target.stat()
if stat.S_ISDIR(fstat.st_mode):
f['emptystream'] = True
f['attributes'] = fstat.st_file_attributes & FILE_ATTRIBUTE_WINDOWS_MASK # type: ignore # noqa
else:
f['emptystream'] = False
f['attributes'] = stat.FILE_ATTRIBUTE_ARCHIVE # type: ignore # noqa
f['uncompressed'] = fstat.st_size
else:
f['emptystream'] = False
f['attributes'] = fstat.st_file_attributes & FILE_ATTRIBUTE_WINDOWS_MASK # type: ignore # noqa
# f['attributes'] |= stat.FILE_ATTRIBUTE_REPARSE_POINT # type: ignore # noqa
elif target.is_dir():
f['emptystream'] = True
f['attributes'] = fstat.st_file_attributes & FILE_ATTRIBUTE_WINDOWS_MASK # type: ignore # noqa
elif target.is_file():
f['emptystream'] = False
f['attributes'] = stat.FILE_ATTRIBUTE_ARCHIVE # type: ignore # noqa
f['uncompressed'] = fstat.st_size
else:
fstat = target.lstat()
if target.is_symlink():
if dereference:
fstat = target.stat()
if stat.S_ISDIR(fstat.st_mode):
f['emptystream'] = True
f['attributes'] = stat.FILE_ATTRIBUTE_DIRECTORY # type: ignore # noqa
f['attributes'] |= FILE_ATTRIBUTE_UNIX_EXTENSION | (stat.S_IFDIR << 16)
f['attributes'] |= (stat.S_IMODE(fstat.st_mode) << 16)
else:
f['emptystream'] = False
f['attributes'] = stat.FILE_ATTRIBUTE_ARCHIVE # type: ignore # noqa
f['attributes'] |= FILE_ATTRIBUTE_UNIX_EXTENSION | (stat.S_IMODE(fstat.st_mode) << 16)
else:
f['emptystream'] = False
f['attributes'] = stat.FILE_ATTRIBUTE_ARCHIVE | stat.FILE_ATTRIBUTE_REPARSE_POINT # type: ignore # noqa
f['attributes'] |= FILE_ATTRIBUTE_UNIX_EXTENSION | (stat.S_IFLNK << 16)
f['attributes'] |= (stat.S_IMODE(fstat.st_mode) << 16)
elif target.is_dir():
f['emptystream'] = True
f['attributes'] = stat.FILE_ATTRIBUTE_DIRECTORY # type: ignore # noqa
f['attributes'] |= FILE_ATTRIBUTE_UNIX_EXTENSION | (stat.S_IFDIR << 16)
f['attributes'] |= (stat.S_IMODE(fstat.st_mode) << 16)
elif target.is_file():
f['emptystream'] = False
f['uncompressed'] = fstat.st_size
f['attributes'] = stat.FILE_ATTRIBUTE_ARCHIVE # type: ignore # noqa
f['attributes'] |= FILE_ATTRIBUTE_UNIX_EXTENSION | (stat.S_IMODE(fstat.st_mode) << 16)
f['creationtime'] = fstat.st_ctime
f['lastwritetime'] = fstat.st_mtime
f['lastaccesstime'] = fstat.st_atime
return f
# --------------------------------------------------------------------------
# The public methods which SevenZipFile provides:
def getnames(self) -> List[str]:
"""Return the members of the archive as a list of their names. It has
the same order as the list returned by getmembers().
"""
return list(map(lambda x: x.filename, self.files))
def archiveinfo(self) -> ArchiveInfo:
fstat = os.stat(self.filename)
uncompressed = 0
for f in self.files:
uncompressed += f.uncompressed_size
return ArchiveInfo(self.filename, fstat.st_size, self.header.size, self._get_method_names(),
self._is_solid(), len(self.header.main_streams.unpackinfo.folders),
uncompressed)
def list(self) -> List[FileInfo]:
"""Returns contents information """
alist = [] # type: List[FileInfo]
creationtime = None # type: Optional[datetime.datetime]
for f in self.files:
if f.lastwritetime is not None:
creationtime = filetime_to_dt(f.lastwritetime)
alist.append(FileInfo(f.filename, f.compressed, f.uncompressed_size, f.archivable, f.is_directory,
creationtime, f.crc32))
return alist
def readall(self) -> Optional[Dict[str, IO[Any]]]:
return self._extract(path=None, return_dict=True)
def extractall(self, path: Optional[Any] = None, callback: Optional[ExtractCallback] = None) -> None:
"""Extract all members from the archive to the current working
directory and set owner, modification time and permissions on
directories afterwards. `path' specifies a different directory
to extract to.
"""
self._extract(path=path, return_dict=False, callback=callback)
def read(self, targets: Optional[List[str]] = None) -> Optional[Dict[str, IO[Any]]]:
return self._extract(path=None, targets=targets, return_dict=True)
def extract(self, path: Optional[Any] = None, targets: Optional[List[str]] = None) -> None:
self._extract(path, targets, return_dict=False)
def _extract(self, path: Optional[Any] = None, targets: Optional[List[str]] = None,
return_dict: bool = False, callback: Optional[ExtractCallback] = None) -> Optional[Dict[str, IO[Any]]]:
if callback is not None and not isinstance(callback, ExtractCallback):
raise ValueError('Callback specified is not a subclass of py7zr.callbacks.ExtractCallback class')
elif callback is not None:
self.reporterd = threading.Thread(target=self.reporter, args=(callback,), daemon=True)
self.reporterd.start()
target_junction = [] # type: List[pathlib.Path]
target_sym = [] # type: List[pathlib.Path]
target_files = [] # type: List[Tuple[pathlib.Path, Dict[str, Any]]]
target_dirs = [] # type: List[pathlib.Path]
if path is not None:
if isinstance(path, str):
path = pathlib.Path(path)
try:
if not path.exists():
path.mkdir(parents=True)
else:
pass
except OSError as e:
if e.errno == errno.EEXIST and path.is_dir():
pass
else:
raise e
fnames = [] # type: List[str] # check duplicated filename in one archive?
self.q.put(('pre', None, None))
for f in self.files:
# TODO: sanity check
# check whether f.filename contains invalid characters such as '../'
if f.filename.startswith('../'):
    raise Bad7zFile
# When an archive contains multiple files with the same name,
# multi-thread decompression is disabled to guarantee extraction order.
# Currently the later entry always overwrites the earlier one.
# TODO: provide an option to select overwrite or skip.
if f.filename not in fnames:
outname = f.filename
else:
i = 0
while True:
    outname = f.filename + '_%d' % i
    if outname not in fnames:
        break
    i += 1  # advance the suffix; otherwise this loop never ends on a collision
fnames.append(outname)
if path is not None:
outfilename = path.joinpath(outname)
else:
outfilename = pathlib.Path(outname)
if os.name == 'nt':
if outfilename.is_absolute():
# hack for the Microsoft Windows path length limit (MAX_PATH, 260 characters)
outfilename = pathlib.WindowsPath('\\\\?\\' + str(outfilename))
if targets is not None and f.filename not in targets:
self.worker.register_filelike(f.id, None)
continue
if f.is_directory:
if not outfilename.exists():
target_dirs.append(outfilename)
target_files.append((outfilename, f.file_properties()))
else:
pass
elif f.is_socket:
pass
elif return_dict:
fname = outfilename.as_posix()
_buf = io.BytesIO()
self._dict[fname] = _buf
self.worker.register_filelike(f.id, MemIO(_buf))
elif f.is_symlink:
target_sym.append(outfilename)
try:
if outfilename.exists():
outfilename.unlink()
except OSError as ose:
if ose.errno not in [errno.ENOENT]:
raise
self.worker.register_filelike(f.id, outfilename)
elif f.is_junction:
target_junction.append(outfilename)
self.worker.register_filelike(f.id, outfilename)
else:
self.worker.register_filelike(f.id, outfilename)
target_files.append((outfilename, f.file_properties()))
for target_dir in sorted(target_dirs):
try:
target_dir.mkdir()
except FileExistsError:
if target_dir.is_dir():
pass
elif target_dir.is_file():
raise DecompressionError("Directory {} already exists as a normal file.".format(str(target_dir)))
else:
raise DecompressionError("Failed to create directory {} for an unknown reason.".format(str(target_dir)))
try:
if callback is not None:
self.worker.extract(self.fp, parallel=(not self.password_protected and not self._filePassed), q=self.q)
else:
self.worker.extract(self.fp, parallel=(not self.password_protected and not self._filePassed))
except CrcError as ce:
raise Bad7zFile("CRC32 error on archived file {}.".format(str(ce)))
self.q.put(('post', None, None))
if return_dict:
return self._dict
else:
# create symbolic links on target path as a working directory.
# if path is None, work on current working directory.
for t in target_sym:
sym_dst = t.resolve()
with sym_dst.open('rb') as b:
sym_src = b.read().decode(encoding='utf-8') # symlink target name stored in utf-8
sym_dst.unlink() # unlink after close().
sym_dst.symlink_to(pathlib.Path(sym_src))
# create junction point only on windows platform
if sys.platform.startswith('win'):
for t in target_junction:
junction_dst = t.resolve()
with junction_dst.open('rb') as b:
junction_target = pathlib.Path(b.read().decode(encoding='utf-8'))
junction_dst.unlink()
_winapi.CreateJunction(junction_target, str(junction_dst)) # type: ignore # noqa
# set file properties
for o, p in target_files:
self._set_file_property(o, p)
return None
def reporter(self, callback: ExtractCallback):
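# Message protocol on self.q (tuples whose first element selects the event):
#   's'    -> report_start(path, compressed_bytes)
#   'e'    -> report_end(path, wrote_bytes)
#   'pre'  -> report_start_preparation()
#   'post' -> report_postprocess()
#   'w'    -> report_warning(message)
#   None   -> terminate this reporter thread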
while True:
try:
item = self.q.get(timeout=1) # type: Optional[Tuple[str, str, str]]
except queue.Empty:
pass
else:
if item is None:
break
elif item[0] == 's':
callback.report_start(item[1], item[2])
elif item[0] == 'e':
callback.report_end(item[1], item[2])
elif item[0] == 'pre':
callback.report_start_preparation()
elif item[0] == 'post':
callback.report_postprocess()
elif item[0] == 'w':
callback.report_warning(item[1])
else:
pass
self.q.task_done()
def writeall(self, path: Union[pathlib.Path, str], arcname: Optional[str] = None):
"""Write files in target path into archive."""
if isinstance(path, str):
path = pathlib.Path(path)
if not path.exists():
raise ValueError("specified path does not exist.")
if path.is_dir() or path.is_file():
self._writeall(path, arcname)
else:
raise ValueError("specified path is not a directory or a file")
def _writeall(self, path, arcname):
try:
if path.is_symlink() and not self.dereference:
self.write(path, arcname)
elif path.is_file():
self.write(path, arcname)
elif path.is_dir():
if not path.samefile('.'):
self.write(path, arcname)
for nm in sorted(os.listdir(str(path))):
arc = os.path.join(arcname, nm) if arcname is not None else None
self._writeall(path.joinpath(nm), arc)
else:
return  # pathlib ignores ELOOP and returns False from is_*().
except OSError as ose:
if self.dereference and ose.errno in [errno.ELOOP]:
return  # ignore ELOOP here; this stops recursion on looped symlink references.
elif self.dereference and sys.platform == 'win32' and ose.errno in [errno.ENOENT]:
return  # ignore ENOENT, which Windows raises in place of ELOOP.
else:
raise
def write(self, file: Union[pathlib.Path, str], arcname: Optional[str] = None):
"""Write single target file into archive(Not implemented yet)."""
if isinstance(file, str):
path = pathlib.Path(file)
elif isinstance(file, pathlib.Path):
path = file
else:
raise ValueError("Unsupported file type.")
file_info = self._make_file_info(path, arcname, self.dereference)
self.files.append(file_info)
def close(self):
"""Flush all the data into archive and close it.
When close py7zr start reading target and writing actual archive file.
"""
if 'w' in self.mode:
self._write_archive()
if 'r' in self.mode:
if self.reporterd is not None:
self.q.put_nowait(None)
self.reporterd.join(1)
if self.reporterd.is_alive():
raise InternalError("Progress report thread terminate error.")
self.reporterd = None
self._fpclose()
self._var_release()
def reset(self) -> None:
"""When read mode, it reset file pointer, decompress worker and decompressor"""
if self.mode == 'r':
self._reset_worker()
self._reset_decompressor()
def test(self) -> Optional[bool]:
self._reset_worker()
crcs = self.header.main_streams.packinfo.crcs # type: Optional[List[int]]
if crcs is None or len(crcs) == 0:
return None
# check packed stream's crc
assert len(crcs) == len(self.header.main_streams.packinfo.packpositions)
for i, p in enumerate(self.header.main_streams.packinfo.packpositions):
if not self._test_digest_raw(p, self.header.main_streams.packinfo.packsizes[i], crcs[i]):
return False
return True
def testzip(self) -> Optional[str]:
self._reset_worker()
for f in self.files:
self.worker.register_filelike(f.id, None)
try:
self.worker.extract(self.fp, parallel=(not self.password_protected)) # TODO: print progress
except CrcError as crce:
return str(crce)
else:
return None
# --------------------
# exported functions
# --------------------
def is_7zfile(file: Union[BinaryIO, str, pathlib.Path]) -> bool:
"""Quickly see if a file is a 7Z file by checking the magic number.
The file argument may be a filename or file-like object too.
"""
result = False
try:
if isinstance(file, io.IOBase) and hasattr(file, "read"):
result = SevenZipFile._check_7zfile(file) # type: ignore # noqa
elif isinstance(file, str):
with open(file, 'rb') as fp:
result = SevenZipFile._check_7zfile(fp)
elif isinstance(file, pathlib.Path):  # PosixPath and WindowsPath are subclasses of Path
with file.open(mode='rb') as fp: # type: ignore # noqa
result = SevenZipFile._check_7zfile(fp)
else:
raise TypeError('invalid type: file should be str, pathlib.Path or BinaryIO, but {}'.format(type(file)))
except OSError:
pass
return result
def unpack_7zarchive(archive, path, extra=None):
"""Function for registering with shutil.register_unpack_format()"""
arc = SevenZipFile(archive)
arc.extractall(path)
arc.close()
def pack_7zarchive(base_name, base_dir, owner=None, group=None, dry_run=None, logger=None):
"""Function for registering with shutil.register_archive_format()"""
target_name = '{}.7z'.format(base_name)
archive = SevenZipFile(target_name, mode='w')
archive.writeall(path=base_dir)
archive.close()
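# Example of wiring the two helpers above into shutil, as their docstrings
# suggest (a sketch; the format name '7zip' and the file paths are
# illustrative):
#
#   import shutil
#   shutil.register_archive_format('7zip', pack_7zarchive, description='7zip archive')
#   shutil.register_unpack_format('7zip', ['.7z'], unpack_7zarchive)
#   shutil.make_archive('target', '7zip', root_dir='src_dir')
#   shutil.unpack_archive('target.7z', 'dst_dir')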


@@ -1,571 +0,0 @@
import getpass
import lzma
import os
import re
import sys
import pytest
import py7zr
import py7zr.archiveinfo
import py7zr.callbacks
import py7zr.cli
import py7zr.compression
import py7zr.properties
from . import check_output, decode_all, ltime2
if sys.version_info < (3, 6):
import pathlib2 as pathlib
else:
import pathlib
testdata_path = os.path.join(os.path.dirname(__file__), 'data')
os.umask(0o022)
@pytest.mark.basic
def test_basic_initinfo():
archive = py7zr.SevenZipFile(open(os.path.join(testdata_path, 'test_1.7z'), 'rb'))
assert archive is not None
@pytest.mark.cli
def test_cli_list_1(capsys):
arc = os.path.join(testdata_path, 'test_1.7z')
expected = """total 4 files and directories in solid archive
Date Time Attr Size Compressed Name
------------------- ----- ------------ ------------ ------------------------
"""
expected += "{} D.... 0 0 scripts\n".format(ltime2(2019, 3, 14, 0, 10, 8))
expected += "{} ....A 111 441 scripts/py7zr\n".format(ltime2(2019, 3, 14, 0, 10, 8))
expected += "{} ....A 58 setup.cfg\n".format(ltime2(2019, 3, 14, 0, 7, 13))
expected += "{} ....A 559 setup.py\n".format(ltime2(2019, 3, 14, 0, 9, 1))
expected += "------------------- ----- ------------ ------------ ------------------------\n"
cli = py7zr.cli.Cli()
cli.run(["l", arc])
out, err = capsys.readouterr()
assert out == expected
@pytest.mark.basic
def test_cli_list_2(capsys):
arc = os.path.join(testdata_path, 'test_3.7z')
expected = "total 28 files and directories in solid archive\n"
expected += " Date Time Attr Size Compressed Name\n"
expected += "------------------- ----- ------------ ------------ ------------------------\n"
expected += "{} D.... 0 0 5.9.7\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} D.... 0 0 5.9.7/gcc_64\n".format(ltime2(2018, 10, 18, 14, 52, 43)) # noqa: E501
expected += "{} D.... 0 0 5.9.7/gcc_64/include\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} D.... 0 0 5.9.7/gcc_64/include/QtX11Extras\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} D.... 0 0 5.9.7/gcc_64/lib\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} D.... 0 0 5.9.7/gcc_64/lib/cmake\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} D.... 0 0 5.9.7/gcc_64/lib/cmake/Qt5X11Extras\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} D.... 0 0 5.9.7/gcc_64/lib/pkgconfig\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} D.... 0 0 5.9.7/gcc_64/mkspecs\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} D.... 0 0 5.9.7/gcc_64/mkspecs/modules\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} ....A 26 8472 5.9.7/gcc_64/include/QtX11Extras/QX11Info\n".format(ltime2(2018, 10, 16, 10, 26, 21)) # noqa: E501
expected += "{} ....A 176 5.9.7/gcc_64/include/QtX11Extras/QtX11Extras\n".format(ltime2(2018, 10, 16, 10, 26, 24)) # noqa: E501
expected += "{} ....A 201 5.9.7/gcc_64/include/QtX11Extras/QtX11ExtrasDepends\n".format(ltime2(2018, 10, 16, 10, 26, 24)) # noqa: E501
expected += "{} ....A 32 5.9.7/gcc_64/include/QtX11Extras/QtX11ExtrasVersion\n".format(ltime2(2018, 10, 16, 10, 26, 24)) # noqa: E501
expected += "{} ....A 722 5.9.7/gcc_64/lib/libQt5X11Extras.la\n".format(ltime2(2018, 10, 16, 10, 26, 27)) # noqa: E501
expected += "{} ....A 2280 5.9.7/gcc_64/include/QtX11Extras/qtx11extrasglobal.h\n".format(ltime2(2018, 10, 16, 10, 26, 21)) # noqa: E501
expected += "{} ....A 222 5.9.7/gcc_64/include/QtX11Extras/qtx11extrasversion.h\n".format(ltime2(2018, 10, 16, 10, 26, 24)) # noqa: E501
expected += "{} ....A 2890 5.9.7/gcc_64/include/QtX11Extras/qx11info_x11.h\n".format(ltime2(2018, 10, 16, 10, 26, 21)) # noqa: E501
expected += "{} ....A 24 5.9.7/gcc_64/lib/libQt5X11Extras.so\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} ....A 24 5.9.7/gcc_64/lib/libQt5X11Extras.so.5\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} ....A 14568 5.9.7/gcc_64/lib/libQt5X11Extras.so.5.9.7\n".format(ltime2(2018, 10, 16, 10, 26, 27)) # noqa: E501
expected += "{} ....A 24 5.9.7/gcc_64/lib/libQt5X11Extras.so.5.9\n".format(ltime2(2018, 10, 18, 14, 52, 42)) # noqa: E501
expected += "{} ....A 6704 5.9.7/gcc_64/lib/cmake/Qt5X11Extras/Qt5X11ExtrasConfig.cmake\n".format(ltime2(2018, 10, 16, 10, 26, 24)) # noqa: E501
expected += "{} ....A 287 5.9.7/gcc_64/lib/cmake/Qt5X11Extras/Qt5X11ExtrasConfigVersion.cmake\n".format(ltime2(2018, 10, 16, 10, 26, 24)) # noqa: E501
expected += "{} ....A 283 5.9.7/gcc_64/lib/pkgconfig/Qt5X11Extras.pc\n".format(ltime2(2018, 10, 16, 10, 26, 27)) # noqa: E501
expected += "{} ....A 555 5.9.7/gcc_64/mkspecs/modules/qt_lib_x11extras.pri\n".format(ltime2(2018, 10, 16, 10, 26, 24)) # noqa: E501
expected += "{} ....A 526 5.9.7/gcc_64/mkspecs/modules/qt_lib_x11extras_private.pri\n".format(ltime2(2018, 10, 16, 10, 26, 24)) # noqa: E501
expected += "{} ....A 1064 5.9.7/gcc_64/lib/libQt5X11Extras.prl\n".format(ltime2(2018, 10, 18, 10, 28, 16)) # noqa: E501
expected += "------------------- ----- ------------ ------------ ------------------------\n"
cli = py7zr.cli.Cli()
cli.run(["l", arc])
out, err = capsys.readouterr()
assert out == expected
@pytest.mark.api
def test_basic_not_implemented_yet1(tmp_path):
with pytest.raises(NotImplementedError):
py7zr.SevenZipFile(tmp_path.joinpath('test_x.7z'), mode='x')
@pytest.mark.api
def test_write_mode(tmp_path):
py7zr.SevenZipFile(tmp_path.joinpath('test_w.7z'), mode='w')
@pytest.mark.api
def test_basic_not_implemented_yet3(tmp_path):
with tmp_path.joinpath('test_a.7z').open('w') as f:
f.write('foo')
with pytest.raises(NotImplementedError):
py7zr.SevenZipFile(tmp_path.joinpath('test_a.7z'), mode='a')
@pytest.mark.api
def test_basic_wrong_option_value(tmp_path):
with pytest.raises(ValueError):
py7zr.SevenZipFile(tmp_path.joinpath('test_p.7z'), mode='p')
@pytest.mark.basic
def test_basic_extract_1(tmp_path):
archive = py7zr.SevenZipFile(open(os.path.join(testdata_path, 'test_1.7z'), 'rb'))
expected = [{'filename': 'setup.cfg', 'mode': 33188, 'mtime': 1552522033,
'digest': 'ff77878e070c4ba52732b0c847b5a055a7c454731939c3217db4a7fb4a1e7240'},
{'filename': 'setup.py', 'mode': 33188, 'mtime': 1552522141,
'digest': 'b916eed2a4ee4e48c51a2b51d07d450de0be4dbb83d20e67f6fd166ff7921e49'},
{'filename': 'scripts/py7zr', 'mode': 33261, 'mtime': 1552522208,
'digest': 'b0385e71d6a07eb692f5fb9798e9d33aaf87be7dfff936fd2473eab2a593d4fd'}
]
decode_all(archive, expected, tmp_path)
@pytest.mark.basic
def test_basic_extract_2(tmp_path):
archive = py7zr.SevenZipFile(open(os.path.join(testdata_path, 'test_2.7z'), 'rb'))
expected = [{'filename': 'qt.qt5.597.gcc_64/installscript.qs',
'digest': '39445276e79ea43c0fa8b393b35dc621fcb2045cb82238ddf2b838a4fbf8a587'}]
decode_all(archive, expected, tmp_path)
@pytest.mark.basic
def test_basic_decode_3(tmp_path):
"""Test when passing path string instead of file-like object."""
archive = py7zr.SevenZipFile(os.path.join(testdata_path, 'test_1.7z'))
expected = [{'filename': 'setup.cfg', 'mode': 33188, 'mtime': 1552522033,
'digest': 'ff77878e070c4ba52732b0c847b5a055a7c454731939c3217db4a7fb4a1e7240'}]
decode_all(archive, expected, tmp_path)
@pytest.mark.api
def test_py7zr_is_7zfile():
assert py7zr.is_7zfile(os.path.join(testdata_path, 'test_1.7z'))
@pytest.mark.api
def test_py7zr_is_7zfile_fileish():
assert py7zr.is_7zfile(open(os.path.join(testdata_path, 'test_1.7z'), 'rb'))
@pytest.mark.api
def test_py7zr_is_7zfile_path():
assert py7zr.is_7zfile(pathlib.Path(testdata_path).joinpath('test_1.7z'))
@pytest.mark.basic
def test_py7zr_is_not_7zfile(tmp_path):
target = tmp_path.joinpath('test_not.7z')
with target.open('wb') as f:
f.write(b'12345dahodjg98adfjfak;')
with target.open('rb') as f:
assert not py7zr.is_7zfile(f)
@pytest.mark.cli
def test_cli_help(capsys):
expected = "usage: py7zr [-h] {l,x,c,t,i}"
cli = py7zr.cli.Cli()
with pytest.raises(SystemExit):
cli.run(["-h"])
out, err = capsys.readouterr()
assert out.startswith(expected)
@pytest.mark.cli
def test_cli_no_subcommand(capsys):
expected = "usage: py7zr [-h] {l,x,c,t,i}"
cli = py7zr.cli.Cli()
cli.run([])
out, err = capsys.readouterr()
assert out.startswith(expected)
@pytest.mark.cli
def test_cli_list_verbose(capsys):
arcfile = os.path.join(testdata_path, "test_1.7z")
expected = """Listing archive: {}
--
Path = {}
Type = 7z
Phisical Size = 657
Headers Size = 0
Method = LZMA2
Solid = +
Blocks = 1
total 4 files and directories in solid archive
Date Time Attr Size Compressed Name
------------------- ----- ------------ ------------ ------------------------
""".format(arcfile, arcfile)
expected += "{} D.... 0 0 scripts\n".format(ltime2(2019, 3, 14, 0, 10, 8))
expected += "{} ....A 111 441 scripts/py7zr\n".format(ltime2(2019, 3, 14, 0, 10, 8))
expected += "{} ....A 58 setup.cfg\n".format(ltime2(2019, 3, 14, 0, 7, 13))
expected += "{} ....A 559 setup.py\n".format(ltime2(2019, 3, 14, 0, 9, 1))
expected += "------------------- ----- ------------ ------------ ------------------------\n"
cli = py7zr.cli.Cli()
cli.run(["l", "--verbose", arcfile])
out, err = capsys.readouterr()
assert out == expected
@pytest.mark.cli
def test_cli_test(capsys):
arcfile = os.path.join(testdata_path, 'test_2.7z')
expected = """Testing archive: {}
--
Path = {}
Type = 7z
Phisical Size = 1663
Headers Size = 0
Method = LZMA2
Solid = -
Blocks = 1
Everything is Ok
""".format(arcfile, arcfile)
cli = py7zr.cli.Cli()
cli.run(["t", arcfile])
out, err = capsys.readouterr()
assert out == expected
@pytest.mark.cli
def test_cli_info(capsys):
if lzma.is_check_supported(lzma.CHECK_CRC64):
check0 = "\nCHECK_CRC64"
else:
check0 = ""
if lzma.is_check_supported(lzma.CHECK_SHA256):
check1 = "\nCHECK_SHA256"
else:
check1 = ""
expected_checks = """Checks:
CHECK_NONE
CHECK_CRC32{}{}""".format(check0, check1)
expected = """py7zr version {} {}
Formats:
7z 37 7a bc af 27 1c
Codecs:
030101 LZMA
21 LZMA2
03 DELTA
03030103 BCJ
03030205 PPC
03030401 IA64
03030501 ARM
03030701 ARMT
03030805 SPARC
{}
""".format(py7zr.__version__, py7zr.__copyright__, expected_checks)
cli = py7zr.cli.Cli()
cli.run(["i"])
out, err = capsys.readouterr()
assert expected == out
@pytest.mark.cli
def test_cli_extract(tmp_path):
arcfile = os.path.join(testdata_path, "test_1.7z")
cli = py7zr.cli.Cli()
cli.run(["x", arcfile, str(tmp_path.resolve())])
expected = [{'filename': 'setup.cfg', 'mode': 33188, 'mtime': 1552522033,
'digest': 'ff77878e070c4ba52732b0c847b5a055a7c454731939c3217db4a7fb4a1e7240'},
{'filename': 'setup.py', 'mode': 33188, 'mtime': 1552522141,
'digest': 'b916eed2a4ee4e48c51a2b51d07d450de0be4dbb83d20e67f6fd166ff7921e49'},
{'filename': 'scripts/py7zr', 'mode': 33261, 'mtime': 1552522208,
'digest': 'b0385e71d6a07eb692f5fb9798e9d33aaf87be7dfff936fd2473eab2a593d4fd'}
]
check_output(expected, tmp_path)
@pytest.mark.cli
def test_cli_encrypted_extract(monkeypatch, tmp_path):
def _getpasswd():
return 'secret'
monkeypatch.setattr(getpass, "getpass", _getpasswd)
arcfile = os.path.join(testdata_path, "encrypted_1.7z")
cli = py7zr.cli.Cli()
cli.run(["x", "--password", arcfile, str(tmp_path.resolve())])
expected = [{'filename': 'test1.txt', 'mode': 33188,
'digest': '0f16b2f4c3a74b9257cd6229c0b7b91855b3260327ef0a42ecf59c44d065c5b2'},
{'filename': 'test/test2.txt', 'mode': 33188,
'digest': '1d0d28682fca74c5912ea7e3f6878ccfdb6e4e249b161994b7f2870e6649ef09'}
]
check_output(expected, tmp_path)
@pytest.mark.basic
def test_digests():
arcfile = os.path.join(testdata_path, "test_2.7z")
archive = py7zr.SevenZipFile(arcfile)
assert archive.test() is None
assert archive.testzip() is None
@pytest.mark.basic
def test_digests_corrupted():
arcfile = os.path.join(testdata_path, "crc_corrupted.7z")
with py7zr.SevenZipFile(arcfile) as archive:
assert archive.test() is None
assert archive.testzip().endswith('src/scripts/py7zr')
@pytest.mark.cli
def test_non7z_ext(capsys, tmp_path):
expected = "not a 7z file\n"
arcfile = os.path.join(testdata_path, "test_1.txt")
cli = py7zr.cli.Cli()
cli.run(["x", arcfile, str(tmp_path.resolve())])
out, err = capsys.readouterr()
assert out == expected
@pytest.mark.cli
def test_non7z_test(capsys):
expected = "not a 7z file\n"
arcfile = os.path.join(testdata_path, "test_1.txt")
cli = py7zr.cli.Cli()
cli.run(["t", arcfile])
out, err = capsys.readouterr()
assert out == expected
@pytest.mark.cli
def test_non7z_list(capsys):
expected = "not a 7z file\n"
arcfile = os.path.join(testdata_path, "test_1.txt")
cli = py7zr.cli.Cli()
cli.run(["l", arcfile])
out, err = capsys.readouterr()
assert out == expected
@pytest.mark.cli
@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher")
def test_archive_creation(tmp_path, capsys):
tmp_path.joinpath('src').mkdir()
py7zr.unpack_7zarchive(os.path.join(testdata_path, 'test_1.7z'), path=tmp_path.joinpath('src'))
os.chdir(str(tmp_path))
target = "target.7z"
source = 'src'
cli = py7zr.cli.Cli()
cli.run(['c', target, source])
out, err = capsys.readouterr()
@pytest.mark.cli
def test_archive_already_exist(tmp_path, capsys):
expected = 'Archive file exists!\n'
py7zr.unpack_7zarchive(os.path.join(testdata_path, 'test_1.7z'), path=tmp_path.joinpath('src'))
target = tmp_path / "target.7z"
with target.open('w') as f:
f.write('Already exist!')
source = str(tmp_path / 'src')
cli = py7zr.cli.Cli()
with pytest.raises(SystemExit):
cli.run(['c', str(target), source])
out, err = capsys.readouterr()
assert err == expected
@pytest.mark.cli
@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher")
def test_archive_without_extension(tmp_path, capsys):
py7zr.unpack_7zarchive(os.path.join(testdata_path, 'test_1.7z'), path=tmp_path.joinpath('src'))
target = str(tmp_path / "target")
source = str(tmp_path / 'src')
cli = py7zr.cli.Cli()
cli.run(['c', target, source])
expected_target = tmp_path / "target.7z"
assert expected_target.exists()
@pytest.mark.cli
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
def test_volume_creation(tmp_path, capsys):
tmp_path.joinpath('src').mkdir()
py7zr.unpack_7zarchive(os.path.join(testdata_path, 'lzma2bcj.7z'), path=tmp_path.joinpath('src'))
target = str(tmp_path / "target.7z")
source = str(tmp_path / 'src')
cli = py7zr.cli.Cli()
cli.run(['c', target, source, '-v', '2m'])
out, err = capsys.readouterr()
@pytest.mark.cli
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher")
def test_volume_creation_wrong_volume_unit(tmp_path, capsys):
expected = 'Error: Specified volume size is invalid.\n'
target = str(tmp_path / "target.7z")
source = tmp_path / 'src'
source.mkdir()
cli = py7zr.cli.Cli()
with pytest.raises(SystemExit):
cli.run(['c', target, str(source), '-v', '2P'])
out, err = capsys.readouterr()
assert err == expected
@pytest.mark.unit
def test_py7zr_write_mode(tmp_path):
target = tmp_path.joinpath('target.7z')
archive = py7zr.SevenZipFile(target, 'w')
archive.write(os.path.join(testdata_path, "test1.txt"), "test1.txt")
assert archive.files is not None
assert len(archive.files) == 1
for f in archive.files:
assert f.filename in ('test1.txt')
assert not f.emptystream
@pytest.mark.api
def test_py7zr_writeall_single(tmp_path):
target = tmp_path.joinpath('target.7z')
archive = py7zr.SevenZipFile(target, 'w')
archive.writeall(os.path.join(testdata_path, "test1.txt"), "test1.txt")
assert archive.files is not None
assert len(archive.files) == 1
for f in archive.files:
assert f.filename in ('test1.txt')
assert not f.emptystream
@pytest.mark.api
def test_py7zr_writeall_dir(tmp_path):
target = tmp_path.joinpath('target.7z')
archive = py7zr.SevenZipFile(target, 'w')
archive.writeall(os.path.join(testdata_path, "src"), "src")
assert archive.files is not None
assert len(archive.files) == 2
for f in archive.files:
assert f.filename in ('src', 'src/bra.txt')
archive._fpclose()
@pytest.mark.api
def test_py7zr_extract_specified_file(tmp_path):
archive = py7zr.SevenZipFile(open(os.path.join(testdata_path, 'test_1.7z'), 'rb'))
expected = [{'filename': 'scripts/py7zr', 'mode': 33261, 'mtime': 1552522208,
'digest': 'b0385e71d6a07eb692f5fb9798e9d33aaf87be7dfff936fd2473eab2a593d4fd'}
]
archive.extract(path=tmp_path, targets=['scripts', 'scripts/py7zr'])
archive.close()
assert tmp_path.joinpath('scripts').is_dir()
assert tmp_path.joinpath('scripts/py7zr').exists()
assert not tmp_path.joinpath('setup.cfg').exists()
assert not tmp_path.joinpath('setup.py').exists()
check_output(expected, tmp_path)
@pytest.mark.api
def test_py7zr_extract_and_getnames(tmp_path):
archive = py7zr.SevenZipFile(open(os.path.join(testdata_path, 'test_1.7z'), 'rb'))
allfiles = archive.getnames()
filter_pattern = re.compile(r'scripts.*')
targets = []
for f in allfiles:
if filter_pattern.match(f):
targets.append(f)
archive.extract(path=tmp_path, targets=targets)
archive.close()
assert tmp_path.joinpath('scripts').is_dir()
assert tmp_path.joinpath('scripts/py7zr').exists()
assert not tmp_path.joinpath('setup.cfg').exists()
assert not tmp_path.joinpath('setup.py').exists()
@pytest.mark.api
def test_py7zr_extract_and_reset_iteration(tmp_path):
archive = py7zr.SevenZipFile(open(os.path.join(testdata_path, 'test_1.7z'), 'rb'))
iterations = archive.getnames()
for target in iterations:
archive.extract(path=tmp_path, targets=[target])
archive.reset()
archive.close()
assert tmp_path.joinpath('scripts').is_dir()
assert tmp_path.joinpath('scripts/py7zr').exists()
assert tmp_path.joinpath('setup.cfg').exists()
assert tmp_path.joinpath('setup.py').exists()
@pytest.mark.api
def test_context_manager_1(tmp_path):
with py7zr.SevenZipFile(os.path.join(testdata_path, 'test_1.7z'), 'r') as z:
z.extractall(path=tmp_path)
assert tmp_path.joinpath('scripts').is_dir()
assert tmp_path.joinpath('scripts/py7zr').exists()
assert tmp_path.joinpath('setup.cfg').exists()
assert tmp_path.joinpath('setup.py').exists()
@pytest.mark.api
def test_context_manager_2(tmp_path):
target = tmp_path.joinpath('target.7z')
with py7zr.SevenZipFile(target, 'w') as z:
z.writeall(os.path.join(testdata_path, "src"), "src")
@pytest.mark.api
def test_extract_callback(tmp_path):
class ECB(py7zr.callbacks.ExtractCallback):
def __init__(self, ofd):
self.ofd = ofd
def report_start_preparation(self):
self.ofd.write('preparation.\n')
def report_start(self, processing_file_path, processing_bytes):
self.ofd.write('start \"{}\" (compressed in {} bytes)\n'.format(processing_file_path, processing_bytes))
def report_end(self, processing_file_path, wrote_bytes):
self.ofd.write('end \"{}\" extracted to {} bytes\n'.format(processing_file_path, wrote_bytes))
def report_postprocess(self):
self.ofd.write('post processing.\n')
def report_warning(self, message):
self.ofd.write('warning: {:s}\n'.format(message))
cb = ECB(sys.stdout)
with py7zr.SevenZipFile(open(os.path.join(testdata_path, 'test_1.7z'), 'rb')) as archive:
archive.extractall(path=tmp_path, callback=cb)
@pytest.mark.api
def test_py7zr_list_values():
with py7zr.SevenZipFile(os.path.join(testdata_path, 'test_1.7z'), 'r') as z:
file_list = z.list()
assert file_list[0].filename == 'scripts'
assert file_list[1].filename == 'scripts/py7zr'
assert file_list[2].filename == 'setup.cfg'
assert file_list[3].filename == 'setup.py'
assert file_list[0].uncompressed == 0
assert file_list[1].uncompressed == 111
assert file_list[2].uncompressed == 58
assert file_list[3].uncompressed == 559
assert file_list[0].is_directory is True
assert file_list[1].archivable is True
assert file_list[2].archivable is True
assert file_list[3].archivable is True
assert file_list[0].compressed == 0
assert file_list[1].compressed == 441
assert file_list[2].compressed is None
assert file_list[3].compressed is None
assert file_list[0].crc32 is None
assert file_list[1].crc32 == 0xb36aaedb
assert file_list[2].crc32 == 0xdcbf8d07
assert file_list[3].crc32 == 0x80fc72be


@@ -1,174 +0,0 @@
import pathlib
import stat
import sys
from logging import getLogger
from typing import Union
if sys.platform == "win32":
import ctypes
from ctypes.wintypes import BOOL, DWORD, HANDLE, LPCWSTR, LPDWORD, LPVOID, LPWSTR
_stdcall_libraries = {}
_stdcall_libraries['kernel32'] = ctypes.WinDLL('kernel32')
CloseHandle = _stdcall_libraries['kernel32'].CloseHandle
CreateFileW = _stdcall_libraries['kernel32'].CreateFileW
DeviceIoControl = _stdcall_libraries['kernel32'].DeviceIoControl
GetFileAttributesW = _stdcall_libraries['kernel32'].GetFileAttributesW
OPEN_EXISTING = 3
GENERIC_READ = 2147483648
FILE_FLAG_OPEN_REPARSE_POINT = 0x00200000
FSCTL_GET_REPARSE_POINT = 0x000900A8
FILE_FLAG_BACKUP_SEMANTICS = 0x02000000
IO_REPARSE_TAG_MOUNT_POINT = 0xA0000003
IO_REPARSE_TAG_SYMLINK = 0xA000000C
MAXIMUM_REPARSE_DATA_BUFFER_SIZE = 16 * 1024
def _check_bit(val: int, flag: int) -> bool:
return bool(val & flag == flag)
class SymbolicLinkReparseBuffer(ctypes.Structure):
""" Implementing the below in Python:
typedef struct _REPARSE_DATA_BUFFER {
ULONG ReparseTag;
USHORT ReparseDataLength;
USHORT Reserved;
union {
struct {
USHORT SubstituteNameOffset;
USHORT SubstituteNameLength;
USHORT PrintNameOffset;
USHORT PrintNameLength;
ULONG Flags;
WCHAR PathBuffer[1];
} SymbolicLinkReparseBuffer;
struct {
USHORT SubstituteNameOffset;
USHORT SubstituteNameLength;
USHORT PrintNameOffset;
USHORT PrintNameLength;
WCHAR PathBuffer[1];
} MountPointReparseBuffer;
struct {
UCHAR DataBuffer[1];
} GenericReparseBuffer;
} DUMMYUNIONNAME;
} REPARSE_DATA_BUFFER, *PREPARSE_DATA_BUFFER;
"""
# See https://docs.microsoft.com/en-us/windows-hardware/drivers/ddi/content/ntifs/ns-ntifs-_reparse_data_buffer
_fields_ = [
('flags', ctypes.c_ulong),
('path_buffer', ctypes.c_byte * (MAXIMUM_REPARSE_DATA_BUFFER_SIZE - 20))
]
class MountReparseBuffer(ctypes.Structure):
_fields_ = [
('path_buffer', ctypes.c_byte * (MAXIMUM_REPARSE_DATA_BUFFER_SIZE - 16)),
]
class ReparseBufferField(ctypes.Union):
_fields_ = [
('symlink', SymbolicLinkReparseBuffer),
('mount', MountReparseBuffer)
]
class ReparseBuffer(ctypes.Structure):
_anonymous_ = ("u",)
_fields_ = [
('reparse_tag', ctypes.c_ulong),
('reparse_data_length', ctypes.c_ushort),
('reserved', ctypes.c_ushort),
('substitute_name_offset', ctypes.c_ushort),
('substitute_name_length', ctypes.c_ushort),
('print_name_offset', ctypes.c_ushort),
('print_name_length', ctypes.c_ushort),
('u', ReparseBufferField)
]
def is_reparse_point(path: Union[str, pathlib.Path]) -> bool:
GetFileAttributesW.argtypes = [LPCWSTR]
GetFileAttributesW.restype = DWORD
return _check_bit(GetFileAttributesW(str(path)), stat.FILE_ATTRIBUTE_REPARSE_POINT)
def readlink(path: Union[str, pathlib.Path]) -> Union[str, pathlib.WindowsPath]:
# FILE_FLAG_OPEN_REPARSE_POINT alone is not enough if 'path'
# is a symbolic link to a directory or a NTFS junction.
# We need to set FILE_FLAG_BACKUP_SEMANTICS as well.
# See https://docs.microsoft.com/en-us/windows/desktop/api/fileapi/nf-fileapi-createfilea
# description from _winapi.c:601
# /* REPARSE_DATA_BUFFER usage is heavily under-documented, especially for
# junction points. Here's what I've learned along the way:
# - A junction point has two components: a print name and a substitute
# name. They both describe the link target, but the substitute name is
# the physical target and the print name is shown in directory listings.
# - The print name must be a native name, prefixed with "\??\".
# - Both names are stored after each other in the same buffer (the
# PathBuffer) and both must be NUL-terminated.
# - There are four members defining their respective offset and length
# inside PathBuffer: SubstituteNameOffset, SubstituteNameLength,
# PrintNameOffset and PrintNameLength.
# - The total size we need to allocate for the REPARSE_DATA_BUFFER, thus,
# is the sum of:
# - the fixed header size (REPARSE_DATA_BUFFER_HEADER_SIZE)
# - the size of the MountPointReparseBuffer member without the PathBuffer
# - the size of the prefix ("\??\") in bytes
# - the size of the print name in bytes
# - the size of the substitute name in bytes
# - the size of two NUL terminators in bytes */
target_is_path = isinstance(path, pathlib.Path)
if target_is_path:
target = str(path)
else:
target = path
CreateFileW.argtypes = [LPWSTR, DWORD, DWORD, LPVOID, DWORD, DWORD, HANDLE]
CreateFileW.restype = HANDLE
DeviceIoControl.argtypes = [HANDLE, DWORD, LPVOID, DWORD, LPVOID, DWORD, LPDWORD, LPVOID]
DeviceIoControl.restype = BOOL
handle = HANDLE(CreateFileW(target, GENERIC_READ, 0, None, OPEN_EXISTING,
FILE_FLAG_BACKUP_SEMANTICS | FILE_FLAG_OPEN_REPARSE_POINT, 0))
buf = ReparseBuffer()
ret = DWORD(0)
status = DeviceIoControl(handle, FSCTL_GET_REPARSE_POINT, None, 0, ctypes.byref(buf),
MAXIMUM_REPARSE_DATA_BUFFER_SIZE, ctypes.byref(ret), None)
CloseHandle(handle)
if not status:
logger = getLogger(__file__)
logger.error("Failed IOCTL access to REPARSE_POINT {})".format(target))
raise ValueError("not a symbolic link or access permission violation")
if buf.reparse_tag == IO_REPARSE_TAG_SYMLINK:
offset = buf.substitute_name_offset
ending = offset + buf.substitute_name_length
rpath = bytearray(buf.symlink.path_buffer)[offset:ending].decode('UTF-16-LE')
elif buf.reparse_tag == IO_REPARSE_TAG_MOUNT_POINT:
offset = buf.substitute_name_offset
ending = offset + buf.substitute_name_length
rpath = bytearray(buf.mount.path_buffer)[offset:ending].decode('UTF-16-LE')
else:
raise ValueError("not a symbolic link")
# posixmodule.c:7859 in CPython 3.8 does the following:
# ```
# else if (rdb->ReparseTag == IO_REPARSE_TAG_MOUNT_POINT)
# {
# name = (wchar_t *)((char*)rdb->MountPointReparseBuffer.PathBuffer +
# rdb->MountPointReparseBuffer.SubstituteNameOffset);
# nameLen = rdb->MountPointReparseBuffer.SubstituteNameLength / sizeof(wchar_t);
# }
# else
# {
# PyErr_SetString(PyExc_ValueError, "not a symbolic link");
# }
# if (nameLen > 4 && wcsncmp(name, L"\\??\\", 4) == 0) {
# /* Our buffer is mutable, so this is okay */
# name[1] = L'\\';
# }
# ```
# so substitute prefix here.
if rpath.startswith('\\??\\'):
rpath = '\\\\' + rpath[2:]
if target_is_path:
return pathlib.WindowsPath(rpath)
else:
return rpath

86
urllib3/__init__.py Normal file

@@ -0,0 +1,86 @@
"""
urllib3 - Thread-safe connection pooling and re-using.
"""
from __future__ import absolute_import
import warnings
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
from . import exceptions
from .filepost import encode_multipart_formdata
from .poolmanager import PoolManager, ProxyManager, proxy_from_url
from .response import HTTPResponse
from .util.request import make_headers
from .util.url import get_host
from .util.timeout import Timeout
from .util.retry import Retry
# Set default logging handler to avoid "No handler found" warnings.
import logging
from logging import NullHandler
__author__ = "Andrey Petrov (andrey.petrov@shazow.net)"
__license__ = "MIT"
__version__ = "1.25.9"
__all__ = (
"HTTPConnectionPool",
"HTTPSConnectionPool",
"PoolManager",
"ProxyManager",
"HTTPResponse",
"Retry",
"Timeout",
"add_stderr_logger",
"connection_from_url",
"disable_warnings",
"encode_multipart_formdata",
"get_host",
"make_headers",
"proxy_from_url",
)
logging.getLogger(__name__).addHandler(NullHandler())
def add_stderr_logger(level=logging.DEBUG):
"""
Helper for quickly adding a StreamHandler to the logger. Useful for
debugging.
Returns the handler after adding it.
"""
# This method needs to be in this __init__.py to get the __name__ correct
# even if urllib3 is vendored within another package.
logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
logger.setLevel(level)
logger.debug("Added a stderr logging handler to logger: %s", __name__)
return handler
# ... Clean up.
del NullHandler
# All warning filters *must* be appended unless you're really certain that they
# shouldn't be: otherwise, it's very hard for users to use most Python
# mechanisms to silence them.
# SecurityWarning's always go off by default.
warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
# SubjectAltNameWarning's should go off once per host
warnings.simplefilter("default", exceptions.SubjectAltNameWarning, append=True)
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
# SNIMissingWarnings should go off only once.
warnings.simplefilter("default", exceptions.SNIMissingWarning, append=True)
def disable_warnings(category=exceptions.HTTPWarning):
"""
Helper for quickly disabling all urllib3 warnings.
"""
warnings.simplefilter("ignore", category)

Binary files not shown.

336
urllib3/_collections.py Normal file
View File

@@ -0,0 +1,336 @@
from __future__ import absolute_import
try:
from collections.abc import Mapping, MutableMapping
except ImportError:
from collections import Mapping, MutableMapping
try:
from threading import RLock
except ImportError: # Platform-specific: No threads available
class RLock:
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
pass
from collections import OrderedDict
from .exceptions import InvalidHeader
from .packages.six import iterkeys, itervalues, PY3
__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]
_Null = object()
class RecentlyUsedContainer(MutableMapping):
"""
Provides a thread-safe dict-like container which maintains up to
``maxsize`` keys while throwing away the least-recently-used keys beyond
``maxsize``.
:param maxsize:
Maximum number of recent elements to retain.
:param dispose_func:
Callback invoked every time an item is evicted from the container,
called as ``dispose_func(value)``.
"""
ContainerCls = OrderedDict
def __init__(self, maxsize=10, dispose_func=None):
self._maxsize = maxsize
self.dispose_func = dispose_func
self._container = self.ContainerCls()
self.lock = RLock()
def __getitem__(self, key):
# Re-insert the item, moving it to the end of the eviction line.
with self.lock:
item = self._container.pop(key)
self._container[key] = item
return item
def __setitem__(self, key, value):
evicted_value = _Null
with self.lock:
# Possibly evict the existing value of 'key'
evicted_value = self._container.get(key, _Null)
self._container[key] = value
# If we didn't evict an existing value, we might have to evict the
# least recently used item from the beginning of the container.
if len(self._container) > self._maxsize:
_key, evicted_value = self._container.popitem(last=False)
if self.dispose_func and evicted_value is not _Null:
self.dispose_func(evicted_value)
def __delitem__(self, key):
with self.lock:
value = self._container.pop(key)
if self.dispose_func:
self.dispose_func(value)
def __len__(self):
with self.lock:
return len(self._container)
def __iter__(self):
raise NotImplementedError(
"Iteration over this class is unlikely to be threadsafe."
)
def clear(self):
with self.lock:
# Copy pointers to all values, then wipe the mapping
values = list(itervalues(self._container))
self._container.clear()
if self.dispose_func:
for value in values:
self.dispose_func(value)
def keys(self):
with self.lock:
return list(iterkeys(self._container))
class HTTPHeaderDict(MutableMapping):
"""
:param headers:
An iterable of field-value pairs. Must not contain multiple field names
when compared case-insensitively.
:param kwargs:
Additional field-value pairs to pass in to ``dict.update``.
A ``dict`` like container for storing HTTP Headers.
Field names are stored and compared case-insensitively in compliance with
RFC 7230. Iteration provides the first case-sensitive key seen for each
case-insensitive pair.
Using ``__setitem__`` syntax overwrites fields that compare equal
case-insensitively in order to maintain ``dict``'s api. For fields that
compare equal, instead create a new ``HTTPHeaderDict`` and use ``.add``
in a loop.
If multiple fields that are equal case-insensitively are passed to the
constructor or ``.update``, the behavior is undefined and some will be
lost.
>>> headers = HTTPHeaderDict()
>>> headers.add('Set-Cookie', 'foo=bar')
>>> headers.add('set-cookie', 'baz=quxx')
>>> headers['content-length'] = '7'
>>> headers['SET-cookie']
'foo=bar, baz=quxx'
>>> headers['Content-Length']
'7'
"""
def __init__(self, headers=None, **kwargs):
super(HTTPHeaderDict, self).__init__()
self._container = OrderedDict()
if headers is not None:
if isinstance(headers, HTTPHeaderDict):
self._copy_from(headers)
else:
self.extend(headers)
if kwargs:
self.extend(kwargs)
def __setitem__(self, key, val):
self._container[key.lower()] = [key, val]
return self._container[key.lower()]
def __getitem__(self, key):
val = self._container[key.lower()]
return ", ".join(val[1:])
def __delitem__(self, key):
del self._container[key.lower()]
def __contains__(self, key):
return key.lower() in self._container
def __eq__(self, other):
if not isinstance(other, Mapping) and not hasattr(other, "keys"):
return False
if not isinstance(other, type(self)):
other = type(self)(other)
return dict((k.lower(), v) for k, v in self.itermerged()) == dict(
(k.lower(), v) for k, v in other.itermerged()
)
def __ne__(self, other):
return not self.__eq__(other)
if not PY3: # Python 2
iterkeys = MutableMapping.iterkeys
itervalues = MutableMapping.itervalues
__marker = object()
def __len__(self):
return len(self._container)
def __iter__(self):
# Only provide the originally cased names
for vals in self._container.values():
yield vals[0]
def pop(self, key, default=__marker):
"""D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
"""
# Using the MutableMapping function directly fails due to the private marker.
# Using ordinary dict.pop would expose the internal structures.
# So let's reinvent the wheel.
try:
value = self[key]
except KeyError:
if default is self.__marker:
raise
return default
else:
del self[key]
return value
def discard(self, key):
try:
del self[key]
except KeyError:
pass
def add(self, key, val):
"""Adds a (name, value) pair, doesn't overwrite the value if it already
exists.
>>> headers = HTTPHeaderDict(foo='bar')
>>> headers.add('Foo', 'baz')
>>> headers['foo']
'bar, baz'
"""
key_lower = key.lower()
new_vals = [key, val]
# Keep the common case aka no item present as fast as possible
vals = self._container.setdefault(key_lower, new_vals)
if new_vals is not vals:
vals.append(val)
def extend(self, *args, **kwargs):
"""Generic import function for any type of header-like object.
Adapted version of MutableMapping.update in order to insert items
with self.add instead of self.__setitem__
"""
if len(args) > 1:
raise TypeError(
"extend() takes at most 1 positional "
"arguments ({0} given)".format(len(args))
)
other = args[0] if len(args) >= 1 else ()
if isinstance(other, HTTPHeaderDict):
for key, val in other.iteritems():
self.add(key, val)
elif isinstance(other, Mapping):
for key in other:
self.add(key, other[key])
elif hasattr(other, "keys"):
for key in other.keys():
self.add(key, other[key])
else:
for key, value in other:
self.add(key, value)
for key, value in kwargs.items():
self.add(key, value)
def getlist(self, key, default=__marker):
"""Returns a list of all the values for the named field. Returns an
empty list if the key doesn't exist."""
try:
vals = self._container[key.lower()]
except KeyError:
if default is self.__marker:
return []
return default
else:
return vals[1:]
# Backwards compatibility for httplib
getheaders = getlist
getallmatchingheaders = getlist
iget = getlist
# Backwards compatibility for http.cookiejar
get_all = getlist
def __repr__(self):
return "%s(%s)" % (type(self).__name__, dict(self.itermerged()))
def _copy_from(self, other):
for key in other:
val = other.getlist(key)
if isinstance(val, list):
# Don't need to convert tuples
val = list(val)
self._container[key.lower()] = [key] + val
def copy(self):
clone = type(self)()
clone._copy_from(self)
return clone
def iteritems(self):
"""Iterate over all header lines, including duplicate ones."""
for key in self:
vals = self._container[key.lower()]
for val in vals[1:]:
yield vals[0], val
def itermerged(self):
"""Iterate over all headers, merging duplicate ones together."""
for key in self:
val = self._container[key.lower()]
yield val[0], ", ".join(val[1:])
def items(self):
return list(self.iteritems())
@classmethod
def from_httplib(cls, message): # Python 2
"""Read headers from a Python 2 httplib message object."""
# python2.7 does not expose a proper API for exporting multiheaders
# efficiently. This function re-reads raw lines from the message
# object and extracts the multiheaders properly.
obs_fold_continued_leaders = (" ", "\t")
headers = []
for line in message.headers:
if line.startswith(obs_fold_continued_leaders):
if not headers:
# We received a header line that starts with OWS as described
# in RFC-7230 S3.2.4. This indicates a multiline header, but
# there exists no previous header to which we can attach it.
raise InvalidHeader(
"Header continuation with no previous header: %s" % line
)
else:
key, value = headers[-1]
headers[-1] = (key, value + " " + line.strip())
continue
key, value = line.split(":", 1)
headers.append((key, value.strip()))
return cls(headers)
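A short sketch of how the two containers above behave; keys, values, and sizes are illustrative:

from urllib3._collections import RecentlyUsedContainer, HTTPHeaderDict

# LRU eviction: with maxsize=2, inserting a third key drops the least
# recently used entry and hands its value to dispose_func.
lru = RecentlyUsedContainer(maxsize=2, dispose_func=lambda v: print("evicted", v))
lru["a"] = 1
lru["b"] = 2
_ = lru["a"]   # touching "a" makes "b" the least recently used key
lru["c"] = 3   # prints: evicted 2

# Case-insensitive, multi-valued headers, as in the class docstring.
headers = HTTPHeaderDict()
headers.add("Set-Cookie", "foo=bar")
headers.add("set-cookie", "baz=quux")
print(headers["SET-cookie"])  # foo=bar, baz=quux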

423
urllib3/connection.py Normal file
View File

@@ -0,0 +1,423 @@
from __future__ import absolute_import
import re
import datetime
import logging
import os
import socket
from socket import error as SocketError, timeout as SocketTimeout
import warnings
from .packages import six
from .packages.six.moves.http_client import HTTPConnection as _HTTPConnection
from .packages.six.moves.http_client import HTTPException # noqa: F401
try: # Compiled with SSL?
import ssl
BaseSSLError = ssl.SSLError
except (ImportError, AttributeError): # Platform-specific: No SSL.
ssl = None
class BaseSSLError(BaseException):
pass
try:
# Python 3: not a no-op, we're adding this to the namespace so it can be imported.
ConnectionError = ConnectionError
except NameError:
# Python 2
class ConnectionError(Exception):
pass
from .exceptions import (
NewConnectionError,
ConnectTimeoutError,
SubjectAltNameWarning,
SystemTimeWarning,
)
from .packages.ssl_match_hostname import match_hostname, CertificateError
from .util.ssl_ import (
resolve_cert_reqs,
resolve_ssl_version,
assert_fingerprint,
create_urllib3_context,
ssl_wrap_socket,
)
from .util import connection
from ._collections import HTTPHeaderDict
log = logging.getLogger(__name__)
port_by_scheme = {"http": 80, "https": 443}
# When it comes time to update this value as a part of regular maintenance
# (ie test_recent_date is failing) update it to ~6 months before the current date.
RECENT_DATE = datetime.date(2019, 1, 1)
_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
class DummyConnection(object):
"""Used to detect a failed ConnectionCls import."""
pass
class HTTPConnection(_HTTPConnection, object):
"""
Based on httplib.HTTPConnection but provides an extra constructor
backwards-compatibility layer between older and newer Pythons.
Additional keyword parameters are used to configure attributes of the connection.
Accepted parameters include:
- ``strict``: See the documentation on :class:`urllib3.connectionpool.HTTPConnectionPool`
- ``source_address``: Set the source address for the current connection.
- ``socket_options``: Set specific options on the underlying socket. If not specified, then
defaults are loaded from ``HTTPConnection.default_socket_options`` which includes disabling
Nagle's algorithm (sets TCP_NODELAY to 1) unless the connection is behind a proxy.
For example, if you wish to enable TCP Keep Alive in addition to the defaults,
you might pass::
HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
]
Or you may want to disable the defaults by passing an empty list (e.g., ``[]``).
"""
default_port = port_by_scheme["http"]
#: Disable Nagle's algorithm by default.
#: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]``
default_socket_options = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]
#: Whether this connection verifies the host's certificate.
is_verified = False
def __init__(self, *args, **kw):
if not six.PY2:
kw.pop("strict", None)
# Pre-set source_address.
self.source_address = kw.get("source_address")
#: The socket options provided by the user. If no options are
#: provided, we use the default options.
self.socket_options = kw.pop("socket_options", self.default_socket_options)
_HTTPConnection.__init__(self, *args, **kw)
@property
def host(self):
"""
Getter method to remove any trailing dots that indicate the hostname is an FQDN.
In general, SSL certificates don't include the trailing dot indicating a
fully-qualified domain name, and thus, they don't validate properly when
checked against a domain name that includes the dot. In addition, some
servers may not expect to receive the trailing dot when provided.
However, the hostname with trailing dot is critical to DNS resolution; doing a
lookup with the trailing dot will properly only resolve the appropriate FQDN,
whereas a lookup without a trailing dot will search the system's search domain
list. Thus, it's important to keep the original host around for use only in
those cases where it's appropriate (i.e., when doing DNS lookup to establish the
actual TCP connection across which we're going to send HTTP requests).
"""
return self._dns_host.rstrip(".")
@host.setter
def host(self, value):
"""
Setter for the `host` property.
We assume that only urllib3 uses the _dns_host attribute; httplib itself
only uses `host`, and it seems reasonable that other libraries follow suit.
"""
self._dns_host = value
def _new_conn(self):
""" Establish a socket connection and set nodelay settings on it.
:return: New socket connection.
"""
extra_kw = {}
if self.source_address:
extra_kw["source_address"] = self.source_address
if self.socket_options:
extra_kw["socket_options"] = self.socket_options
try:
conn = connection.create_connection(
(self._dns_host, self.port), self.timeout, **extra_kw
)
except SocketTimeout:
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
except SocketError as e:
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e
)
return conn
def _prepare_conn(self, conn):
self.sock = conn
# Google App Engine's httplib does not define _tunnel_host
if getattr(self, "_tunnel_host", None):
# TODO: Fix tunnel so it doesn't depend on self.sock state.
self._tunnel()
# Mark this connection as not reusable
self.auto_open = 0
def connect(self):
conn = self._new_conn()
self._prepare_conn(conn)
def putrequest(self, method, url, *args, **kwargs):
"""Send a request to the server"""
match = _CONTAINS_CONTROL_CHAR_RE.search(method)
if match:
raise ValueError(
"Method cannot contain non-token characters %r (found at least %r)"
% (method, match.group())
)
return _HTTPConnection.putrequest(self, method, url, *args, **kwargs)
def request_chunked(self, method, url, body=None, headers=None):
"""
Alternative to the common request method, which sends the
body with chunked encoding and not as one block
"""
headers = HTTPHeaderDict(headers if headers is not None else {})
skip_accept_encoding = "accept-encoding" in headers
skip_host = "host" in headers
self.putrequest(
method, url, skip_accept_encoding=skip_accept_encoding, skip_host=skip_host
)
for header, value in headers.items():
self.putheader(header, value)
if "transfer-encoding" not in headers:
self.putheader("Transfer-Encoding", "chunked")
self.endheaders()
if body is not None:
stringish_types = six.string_types + (bytes,)
if isinstance(body, stringish_types):
body = (body,)
for chunk in body:
if not chunk:
continue
if not isinstance(chunk, bytes):
chunk = chunk.encode("utf8")
len_str = hex(len(chunk))[2:]
self.send(len_str.encode("utf-8"))
self.send(b"\r\n")
self.send(chunk)
self.send(b"\r\n")
# After the if clause, to always have a closed body
self.send(b"0\r\n\r\n")
class HTTPSConnection(HTTPConnection):
default_port = port_by_scheme["https"]
cert_reqs = None
ca_certs = None
ca_cert_dir = None
ca_cert_data = None
ssl_version = None
assert_fingerprint = None
def __init__(
self,
host,
port=None,
key_file=None,
cert_file=None,
key_password=None,
strict=None,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
ssl_context=None,
server_hostname=None,
**kw
):
HTTPConnection.__init__(self, host, port, strict=strict, timeout=timeout, **kw)
self.key_file = key_file
self.cert_file = cert_file
self.key_password = key_password
self.ssl_context = ssl_context
self.server_hostname = server_hostname
# Required property for Google AppEngine 1.9.0 which otherwise causes
# HTTPS requests to go out as HTTP. (See Issue #356)
self._protocol = "https"
def set_cert(
self,
key_file=None,
cert_file=None,
cert_reqs=None,
key_password=None,
ca_certs=None,
assert_hostname=None,
assert_fingerprint=None,
ca_cert_dir=None,
ca_cert_data=None,
):
"""
This method should only be called once, before the connection is used.
"""
# If cert_reqs is not provided we'll assume CERT_REQUIRED unless we also
# have an SSLContext object in which case we'll use its verify_mode.
if cert_reqs is None:
if self.ssl_context is not None:
cert_reqs = self.ssl_context.verify_mode
else:
cert_reqs = resolve_cert_reqs(None)
self.key_file = key_file
self.cert_file = cert_file
self.cert_reqs = cert_reqs
self.key_password = key_password
self.assert_hostname = assert_hostname
self.assert_fingerprint = assert_fingerprint
self.ca_certs = ca_certs and os.path.expanduser(ca_certs)
self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir)
self.ca_cert_data = ca_cert_data
def connect(self):
# Add certificate verification
conn = self._new_conn()
hostname = self.host
# Google App Engine's httplib does not define _tunnel_host
if getattr(self, "_tunnel_host", None):
self.sock = conn
# Calls self._set_hostport(), so self.host is
# self._tunnel_host below.
self._tunnel()
# Mark this connection as not reusable
self.auto_open = 0
# Override the host with the one we're requesting data from.
hostname = self._tunnel_host
server_hostname = hostname
if self.server_hostname is not None:
server_hostname = self.server_hostname
is_time_off = datetime.date.today() < RECENT_DATE
if is_time_off:
warnings.warn(
(
"System time is way off (before {0}). This will probably "
"lead to SSL verification errors"
).format(RECENT_DATE),
SystemTimeWarning,
)
# Wrap socket using verification with the root certs in
# trusted_root_certs
default_ssl_context = False
if self.ssl_context is None:
default_ssl_context = True
self.ssl_context = create_urllib3_context(
ssl_version=resolve_ssl_version(self.ssl_version),
cert_reqs=resolve_cert_reqs(self.cert_reqs),
)
context = self.ssl_context
context.verify_mode = resolve_cert_reqs(self.cert_reqs)
# Try to load OS default certs if none are given.
# Works well on Windows (requires Python3.4+)
if (
not self.ca_certs
and not self.ca_cert_dir
and not self.ca_cert_data
and default_ssl_context
and hasattr(context, "load_default_certs")
):
context.load_default_certs()
self.sock = ssl_wrap_socket(
sock=conn,
keyfile=self.key_file,
certfile=self.cert_file,
key_password=self.key_password,
ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir,
ca_cert_data=self.ca_cert_data,
server_hostname=server_hostname,
ssl_context=context,
)
if self.assert_fingerprint:
assert_fingerprint(
self.sock.getpeercert(binary_form=True), self.assert_fingerprint
)
elif (
context.verify_mode != ssl.CERT_NONE
and not getattr(context, "check_hostname", False)
and self.assert_hostname is not False
):
# While urllib3 attempts to always turn off hostname matching from
# the TLS library, this cannot always be done. So we check whether
# the TLS Library still thinks it's matching hostnames.
cert = self.sock.getpeercert()
if not cert.get("subjectAltName", ()):
warnings.warn(
(
"Certificate for {0} has no `subjectAltName`, falling back to check for a "
"`commonName` for now. This feature is being removed by major browsers and "
"deprecated by RFC 2818. (See https://github.com/urllib3/urllib3/issues/497 "
"for details.)".format(hostname)
),
SubjectAltNameWarning,
)
_match_hostname(cert, self.assert_hostname or server_hostname)
self.is_verified = (
context.verify_mode == ssl.CERT_REQUIRED
or self.assert_fingerprint is not None
)
def _match_hostname(cert, asserted_hostname):
try:
match_hostname(cert, asserted_hostname)
except CertificateError as e:
log.warning(
"Certificate did not match expected hostname: %s. Certificate: %s",
asserted_hostname,
cert,
)
# Add cert to exception and reraise so client code can inspect
# the cert when catching the exception, if they want to
e._peer_cert = cert
raise
if not ssl:
HTTPSConnection = DummyConnection # noqa: F811
VerifiedHTTPSConnection = HTTPSConnection
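Tying the connection class above back to its docstring: a minimal sketch of the keep-alive socket-option example it describes:

import socket
from urllib3.connection import HTTPConnection

# TCP keep-alive on top of the default TCP_NODELAY option, exactly as
# suggested in the HTTPConnection docstring above.
opts = HTTPConnection.default_socket_options + [
    (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
]
conn = HTTPConnection("example.com", 80, socket_options=opts)
conn.request("GET", "/")
print(conn.getresponse().status)
conn.close()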

1033
urllib3/connectionpool.py Normal file
View File

File diff suppressed because it is too large

Binary files not shown.

36
urllib3/contrib/_appengine_environ.py Normal file
View File

@@ -0,0 +1,36 @@
"""
This module provides means to detect the App Engine environment.
"""
import os
def is_appengine():
return is_local_appengine() or is_prod_appengine()
def is_appengine_sandbox():
"""Reports if the app is running in the first generation sandbox.
The second generation runtimes are technically still in a sandbox, but it
is much less restrictive, so generally you shouldn't need to check for it.
see https://cloud.google.com/appengine/docs/standard/runtimes
"""
return is_appengine() and os.environ["APPENGINE_RUNTIME"] == "python27"
def is_local_appengine():
return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
"SERVER_SOFTWARE", ""
).startswith("Development/")
def is_prod_appengine():
return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
"SERVER_SOFTWARE", ""
).startswith("Google App Engine/")
def is_prod_appengine_mvms():
"""Deprecated."""
return False
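A trivial sketch of how these probes combine; the module path is inferred from its contents:

import os
from urllib3.contrib._appengine_environ import is_appengine_sandbox

# Simulate a first-generation GAE standard environment for the checks above.
os.environ["APPENGINE_RUNTIME"] = "python27"
os.environ["SERVER_SOFTWARE"] = "Google App Engine/1.9.0"
print(is_appengine_sandbox())  # True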

493
urllib3/contrib/_securetransport/bindings.py Normal file
View File

@@ -0,0 +1,493 @@
"""
This module uses ctypes to bind a whole bunch of functions and constants from
SecureTransport. The goal here is to provide the low-level API to
SecureTransport. These are essentially the C-level functions and constants, and
they're pretty gross to work with.
This code is a bastardised version of the code found in Will Bond's oscrypto
library. An enormous debt is owed to him for blazing this trail for us. For
that reason, this code should be considered to be covered both by urllib3's
license and by oscrypto's:
Copyright (c) 2015-2016 Will Bond <will@wbond.net>
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
import platform
from ctypes.util import find_library
from ctypes import (
c_void_p,
c_int32,
c_char_p,
c_size_t,
c_byte,
c_uint32,
c_ulong,
c_long,
c_bool,
)
from ctypes import CDLL, POINTER, CFUNCTYPE
security_path = find_library("Security")
if not security_path:
raise ImportError("The library Security could not be found")
core_foundation_path = find_library("CoreFoundation")
if not core_foundation_path:
raise ImportError("The library CoreFoundation could not be found")
version = platform.mac_ver()[0]
version_info = tuple(map(int, version.split(".")))
if version_info < (10, 8):
raise OSError(
"Only OS X 10.8 and newer are supported, not %s.%s"
% (version_info[0], version_info[1])
)
Security = CDLL(security_path, use_errno=True)
CoreFoundation = CDLL(core_foundation_path, use_errno=True)
Boolean = c_bool
CFIndex = c_long
CFStringEncoding = c_uint32
CFData = c_void_p
CFString = c_void_p
CFArray = c_void_p
CFMutableArray = c_void_p
CFDictionary = c_void_p
CFError = c_void_p
CFType = c_void_p
CFTypeID = c_ulong
CFTypeRef = POINTER(CFType)
CFAllocatorRef = c_void_p
OSStatus = c_int32
CFDataRef = POINTER(CFData)
CFStringRef = POINTER(CFString)
CFArrayRef = POINTER(CFArray)
CFMutableArrayRef = POINTER(CFMutableArray)
CFDictionaryRef = POINTER(CFDictionary)
CFArrayCallBacks = c_void_p
CFDictionaryKeyCallBacks = c_void_p
CFDictionaryValueCallBacks = c_void_p
SecCertificateRef = POINTER(c_void_p)
SecExternalFormat = c_uint32
SecExternalItemType = c_uint32
SecIdentityRef = POINTER(c_void_p)
SecItemImportExportFlags = c_uint32
SecItemImportExportKeyParameters = c_void_p
SecKeychainRef = POINTER(c_void_p)
SSLProtocol = c_uint32
SSLCipherSuite = c_uint32
SSLContextRef = POINTER(c_void_p)
SecTrustRef = POINTER(c_void_p)
SSLConnectionRef = c_uint32
SecTrustResultType = c_uint32
SecTrustOptionFlags = c_uint32
SSLProtocolSide = c_uint32
SSLConnectionType = c_uint32
SSLSessionOption = c_uint32
try:
Security.SecItemImport.argtypes = [
CFDataRef,
CFStringRef,
POINTER(SecExternalFormat),
POINTER(SecExternalItemType),
SecItemImportExportFlags,
POINTER(SecItemImportExportKeyParameters),
SecKeychainRef,
POINTER(CFArrayRef),
]
Security.SecItemImport.restype = OSStatus
Security.SecCertificateGetTypeID.argtypes = []
Security.SecCertificateGetTypeID.restype = CFTypeID
Security.SecIdentityGetTypeID.argtypes = []
Security.SecIdentityGetTypeID.restype = CFTypeID
Security.SecKeyGetTypeID.argtypes = []
Security.SecKeyGetTypeID.restype = CFTypeID
Security.SecCertificateCreateWithData.argtypes = [CFAllocatorRef, CFDataRef]
Security.SecCertificateCreateWithData.restype = SecCertificateRef
Security.SecCertificateCopyData.argtypes = [SecCertificateRef]
Security.SecCertificateCopyData.restype = CFDataRef
Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
Security.SecCopyErrorMessageString.restype = CFStringRef
Security.SecIdentityCreateWithCertificate.argtypes = [
CFTypeRef,
SecCertificateRef,
POINTER(SecIdentityRef),
]
Security.SecIdentityCreateWithCertificate.restype = OSStatus
Security.SecKeychainCreate.argtypes = [
c_char_p,
c_uint32,
c_void_p,
Boolean,
c_void_p,
POINTER(SecKeychainRef),
]
Security.SecKeychainCreate.restype = OSStatus
Security.SecKeychainDelete.argtypes = [SecKeychainRef]
Security.SecKeychainDelete.restype = OSStatus
Security.SecPKCS12Import.argtypes = [
CFDataRef,
CFDictionaryRef,
POINTER(CFArrayRef),
]
Security.SecPKCS12Import.restype = OSStatus
SSLReadFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, c_void_p, POINTER(c_size_t))
SSLWriteFunc = CFUNCTYPE(
OSStatus, SSLConnectionRef, POINTER(c_byte), POINTER(c_size_t)
)
Security.SSLSetIOFuncs.argtypes = [SSLContextRef, SSLReadFunc, SSLWriteFunc]
Security.SSLSetIOFuncs.restype = OSStatus
Security.SSLSetPeerID.argtypes = [SSLContextRef, c_char_p, c_size_t]
Security.SSLSetPeerID.restype = OSStatus
Security.SSLSetCertificate.argtypes = [SSLContextRef, CFArrayRef]
Security.SSLSetCertificate.restype = OSStatus
Security.SSLSetCertificateAuthorities.argtypes = [SSLContextRef, CFTypeRef, Boolean]
Security.SSLSetCertificateAuthorities.restype = OSStatus
Security.SSLSetConnection.argtypes = [SSLContextRef, SSLConnectionRef]
Security.SSLSetConnection.restype = OSStatus
Security.SSLSetPeerDomainName.argtypes = [SSLContextRef, c_char_p, c_size_t]
Security.SSLSetPeerDomainName.restype = OSStatus
Security.SSLHandshake.argtypes = [SSLContextRef]
Security.SSLHandshake.restype = OSStatus
Security.SSLRead.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
Security.SSLRead.restype = OSStatus
Security.SSLWrite.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
Security.SSLWrite.restype = OSStatus
Security.SSLClose.argtypes = [SSLContextRef]
Security.SSLClose.restype = OSStatus
Security.SSLGetNumberSupportedCiphers.argtypes = [SSLContextRef, POINTER(c_size_t)]
Security.SSLGetNumberSupportedCiphers.restype = OSStatus
Security.SSLGetSupportedCiphers.argtypes = [
SSLContextRef,
POINTER(SSLCipherSuite),
POINTER(c_size_t),
]
Security.SSLGetSupportedCiphers.restype = OSStatus
Security.SSLSetEnabledCiphers.argtypes = [
SSLContextRef,
POINTER(SSLCipherSuite),
c_size_t,
]
Security.SSLSetEnabledCiphers.restype = OSStatus
Security.SSLGetNumberEnabledCiphers.argtypes = [SSLContextRef, POINTER(c_size_t)]
Security.SSLGetNumberEnabledCiphers.restype = OSStatus
Security.SSLGetEnabledCiphers.argtypes = [
SSLContextRef,
POINTER(SSLCipherSuite),
POINTER(c_size_t),
]
Security.SSLGetEnabledCiphers.restype = OSStatus
Security.SSLGetNegotiatedCipher.argtypes = [SSLContextRef, POINTER(SSLCipherSuite)]
Security.SSLGetNegotiatedCipher.restype = OSStatus
Security.SSLGetNegotiatedProtocolVersion.argtypes = [
SSLContextRef,
POINTER(SSLProtocol),
]
Security.SSLGetNegotiatedProtocolVersion.restype = OSStatus
Security.SSLCopyPeerTrust.argtypes = [SSLContextRef, POINTER(SecTrustRef)]
Security.SSLCopyPeerTrust.restype = OSStatus
Security.SecTrustSetAnchorCertificates.argtypes = [SecTrustRef, CFArrayRef]
Security.SecTrustSetAnchorCertificates.restype = OSStatus
Security.SecTrustSetAnchorCertificatesOnly.argtypes = [SecTrustRef, Boolean]
Security.SecTrustSetAnchorCertificatesOnly.restype = OSStatus
Security.SecTrustEvaluate.argtypes = [SecTrustRef, POINTER(SecTrustResultType)]
Security.SecTrustEvaluate.restype = OSStatus
Security.SecTrustGetCertificateCount.argtypes = [SecTrustRef]
Security.SecTrustGetCertificateCount.restype = CFIndex
Security.SecTrustGetCertificateAtIndex.argtypes = [SecTrustRef, CFIndex]
Security.SecTrustGetCertificateAtIndex.restype = SecCertificateRef
Security.SSLCreateContext.argtypes = [
CFAllocatorRef,
SSLProtocolSide,
SSLConnectionType,
]
Security.SSLCreateContext.restype = SSLContextRef
Security.SSLSetSessionOption.argtypes = [SSLContextRef, SSLSessionOption, Boolean]
Security.SSLSetSessionOption.restype = OSStatus
Security.SSLSetProtocolVersionMin.argtypes = [SSLContextRef, SSLProtocol]
Security.SSLSetProtocolVersionMin.restype = OSStatus
Security.SSLSetProtocolVersionMax.argtypes = [SSLContextRef, SSLProtocol]
Security.SSLSetProtocolVersionMax.restype = OSStatus
Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
Security.SecCopyErrorMessageString.restype = CFStringRef
Security.SSLReadFunc = SSLReadFunc
Security.SSLWriteFunc = SSLWriteFunc
Security.SSLContextRef = SSLContextRef
Security.SSLProtocol = SSLProtocol
Security.SSLCipherSuite = SSLCipherSuite
Security.SecIdentityRef = SecIdentityRef
Security.SecKeychainRef = SecKeychainRef
Security.SecTrustRef = SecTrustRef
Security.SecTrustResultType = SecTrustResultType
Security.SecExternalFormat = SecExternalFormat
Security.OSStatus = OSStatus
Security.kSecImportExportPassphrase = CFStringRef.in_dll(
Security, "kSecImportExportPassphrase"
)
Security.kSecImportItemIdentity = CFStringRef.in_dll(
Security, "kSecImportItemIdentity"
)
# CoreFoundation time!
CoreFoundation.CFRetain.argtypes = [CFTypeRef]
CoreFoundation.CFRetain.restype = CFTypeRef
CoreFoundation.CFRelease.argtypes = [CFTypeRef]
CoreFoundation.CFRelease.restype = None
CoreFoundation.CFGetTypeID.argtypes = [CFTypeRef]
CoreFoundation.CFGetTypeID.restype = CFTypeID
CoreFoundation.CFStringCreateWithCString.argtypes = [
CFAllocatorRef,
c_char_p,
CFStringEncoding,
]
CoreFoundation.CFStringCreateWithCString.restype = CFStringRef
CoreFoundation.CFStringGetCStringPtr.argtypes = [CFStringRef, CFStringEncoding]
CoreFoundation.CFStringGetCStringPtr.restype = c_char_p
CoreFoundation.CFStringGetCString.argtypes = [
CFStringRef,
c_char_p,
CFIndex,
CFStringEncoding,
]
CoreFoundation.CFStringGetCString.restype = c_bool
CoreFoundation.CFDataCreate.argtypes = [CFAllocatorRef, c_char_p, CFIndex]
CoreFoundation.CFDataCreate.restype = CFDataRef
CoreFoundation.CFDataGetLength.argtypes = [CFDataRef]
CoreFoundation.CFDataGetLength.restype = CFIndex
CoreFoundation.CFDataGetBytePtr.argtypes = [CFDataRef]
CoreFoundation.CFDataGetBytePtr.restype = c_void_p
CoreFoundation.CFDictionaryCreate.argtypes = [
CFAllocatorRef,
POINTER(CFTypeRef),
POINTER(CFTypeRef),
CFIndex,
CFDictionaryKeyCallBacks,
CFDictionaryValueCallBacks,
]
CoreFoundation.CFDictionaryCreate.restype = CFDictionaryRef
CoreFoundation.CFDictionaryGetValue.argtypes = [CFDictionaryRef, CFTypeRef]
CoreFoundation.CFDictionaryGetValue.restype = CFTypeRef
CoreFoundation.CFArrayCreate.argtypes = [
CFAllocatorRef,
POINTER(CFTypeRef),
CFIndex,
CFArrayCallBacks,
]
CoreFoundation.CFArrayCreate.restype = CFArrayRef
CoreFoundation.CFArrayCreateMutable.argtypes = [
CFAllocatorRef,
CFIndex,
CFArrayCallBacks,
]
CoreFoundation.CFArrayCreateMutable.restype = CFMutableArrayRef
CoreFoundation.CFArrayAppendValue.argtypes = [CFMutableArrayRef, c_void_p]
CoreFoundation.CFArrayAppendValue.restype = None
CoreFoundation.CFArrayGetCount.argtypes = [CFArrayRef]
CoreFoundation.CFArrayGetCount.restype = CFIndex
CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArrayRef, CFIndex]
CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p
CoreFoundation.kCFAllocatorDefault = CFAllocatorRef.in_dll(
CoreFoundation, "kCFAllocatorDefault"
)
CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll(
CoreFoundation, "kCFTypeArrayCallBacks"
)
CoreFoundation.kCFTypeDictionaryKeyCallBacks = c_void_p.in_dll(
CoreFoundation, "kCFTypeDictionaryKeyCallBacks"
)
CoreFoundation.kCFTypeDictionaryValueCallBacks = c_void_p.in_dll(
CoreFoundation, "kCFTypeDictionaryValueCallBacks"
)
CoreFoundation.CFTypeRef = CFTypeRef
CoreFoundation.CFArrayRef = CFArrayRef
CoreFoundation.CFStringRef = CFStringRef
CoreFoundation.CFDictionaryRef = CFDictionaryRef
except AttributeError:
raise ImportError("Error initializing ctypes")
class CFConst(object):
"""
A class object that acts as essentially a namespace for CoreFoundation
constants.
"""
kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)
class SecurityConst(object):
"""
A class object that acts as essentially a namespace for Security constants.
"""
kSSLSessionOptionBreakOnServerAuth = 0
kSSLProtocol2 = 1
kSSLProtocol3 = 2
kTLSProtocol1 = 4
kTLSProtocol11 = 7
kTLSProtocol12 = 8
# SecureTransport does not support TLS 1.3 even if there's a constant for it
kTLSProtocol13 = 10
kTLSProtocolMaxSupported = 999
kSSLClientSide = 1
kSSLStreamType = 0
kSecFormatPEMSequence = 10
kSecTrustResultInvalid = 0
kSecTrustResultProceed = 1
# This gap is present on purpose: this was kSecTrustResultConfirm, which
# is deprecated.
kSecTrustResultDeny = 3
kSecTrustResultUnspecified = 4
kSecTrustResultRecoverableTrustFailure = 5
kSecTrustResultFatalTrustFailure = 6
kSecTrustResultOtherError = 7
errSSLProtocol = -9800
errSSLWouldBlock = -9803
errSSLClosedGraceful = -9805
errSSLClosedNoNotify = -9816
errSSLClosedAbort = -9806
errSSLXCertChainInvalid = -9807
errSSLCrypto = -9809
errSSLInternal = -9810
errSSLCertExpired = -9814
errSSLCertNotYetValid = -9815
errSSLUnknownRootCert = -9812
errSSLNoRootCert = -9813
errSSLHostNameMismatch = -9843
errSSLPeerHandshakeFail = -9824
errSSLPeerUserCancelled = -9839
errSSLWeakPeerEphemeralDHKey = -9850
errSSLServerAuthCompleted = -9841
errSSLRecordOverflow = -9847
errSecVerifyFailed = -67808
errSecNoTrustSettings = -25263
errSecItemNotFound = -25300
errSecInvalidTrustSettings = -25262
# Cipher suites. We only pick the ones our default cipher string allows.
# Source: https://developer.apple.com/documentation/security/1550981-ssl_cipher_suite_values
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8
TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F
TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014
TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B
TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013
TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067
TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033
TLS_RSA_WITH_AES_256_GCM_SHA384 = 0x009D
TLS_RSA_WITH_AES_128_GCM_SHA256 = 0x009C
TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x003D
TLS_RSA_WITH_AES_128_CBC_SHA256 = 0x003C
TLS_RSA_WITH_AES_256_CBC_SHA = 0x0035
TLS_RSA_WITH_AES_128_CBC_SHA = 0x002F
TLS_AES_128_GCM_SHA256 = 0x1301
TLS_AES_256_GCM_SHA384 = 0x1302
TLS_AES_128_CCM_8_SHA256 = 0x1305
TLS_AES_128_CCM_SHA256 = 0x1304
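As a sanity check of the bindings above, a minimal sketch (macOS only; the module path is inferred from the docstring) that creates and releases a client-side TLS context through the raw functions:

# macOS only: exercise the raw SecureTransport bindings declared above.
from urllib3.contrib._securetransport.bindings import (
    CoreFoundation,
    Security,
    SecurityConst,
)

ctx = Security.SSLCreateContext(
    None, SecurityConst.kSSLClientSide, SecurityConst.kSSLStreamType
)
status = Security.SSLSetProtocolVersionMin(ctx, SecurityConst.kTLSProtocol12)
assert status == 0  # 0 is errSecSuccess
CoreFoundation.CFRelease(ctx)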

328
urllib3/contrib/_securetransport/low_level.py Normal file
View File

@@ -0,0 +1,328 @@
"""
Low-level helpers for the SecureTransport bindings.
These are Python functions that are not directly related to the high-level APIs
but are necessary to get them to work. They include a whole bunch of low-level
CoreFoundation messing about and memory management. The concerns in this module
are almost entirely about trying to avoid memory leaks and providing
appropriate and useful assistance to the higher-level code.
"""
import base64
import ctypes
import itertools
import re
import os
import ssl
import tempfile
from .bindings import Security, CoreFoundation, CFConst
# This regular expression is used to grab PEM data out of a PEM bundle.
_PEM_CERTS_RE = re.compile(
b"-----BEGIN CERTIFICATE-----\n(.*?)\n-----END CERTIFICATE-----", re.DOTALL
)
def _cf_data_from_bytes(bytestring):
"""
Given a bytestring, create a CFData object from it. This CFData object must
be CFReleased by the caller.
"""
return CoreFoundation.CFDataCreate(
CoreFoundation.kCFAllocatorDefault, bytestring, len(bytestring)
)
def _cf_dictionary_from_tuples(tuples):
"""
Given a list of Python tuples, create an associated CFDictionary.
"""
dictionary_size = len(tuples)
# We need to get the dictionary keys and values out in the same order.
keys = (t[0] for t in tuples)
values = (t[1] for t in tuples)
cf_keys = (CoreFoundation.CFTypeRef * dictionary_size)(*keys)
cf_values = (CoreFoundation.CFTypeRef * dictionary_size)(*values)
return CoreFoundation.CFDictionaryCreate(
CoreFoundation.kCFAllocatorDefault,
cf_keys,
cf_values,
dictionary_size,
CoreFoundation.kCFTypeDictionaryKeyCallBacks,
CoreFoundation.kCFTypeDictionaryValueCallBacks,
)
def _cf_string_to_unicode(value):
"""
Creates a Unicode string from a CFString object. Used entirely for error
reporting.
Yes, it annoys me quite a lot that this function is this complex.
"""
value_as_void_p = ctypes.cast(value, ctypes.POINTER(ctypes.c_void_p))
string = CoreFoundation.CFStringGetCStringPtr(
value_as_void_p, CFConst.kCFStringEncodingUTF8
)
if string is None:
buffer = ctypes.create_string_buffer(1024)
result = CoreFoundation.CFStringGetCString(
value_as_void_p, buffer, 1024, CFConst.kCFStringEncodingUTF8
)
if not result:
raise OSError("Error copying C string from CFStringRef")
string = buffer.value
if string is not None:
string = string.decode("utf-8")
return string
def _assert_no_error(error, exception_class=None):
"""
Checks the return code and throws an exception if there is an error to
report
"""
if error == 0:
return
cf_error_string = Security.SecCopyErrorMessageString(error, None)
output = _cf_string_to_unicode(cf_error_string)
CoreFoundation.CFRelease(cf_error_string)
if output is None or output == u"":
output = u"OSStatus %s" % error
if exception_class is None:
exception_class = ssl.SSLError
raise exception_class(output)
def _cert_array_from_pem(pem_bundle):
"""
Given a bundle of certs in PEM format, turns them into a CFArray of certs
that can be used to validate a cert chain.
"""
# Normalize the PEM bundle's line endings.
pem_bundle = pem_bundle.replace(b"\r\n", b"\n")
der_certs = [
base64.b64decode(match.group(1)) for match in _PEM_CERTS_RE.finditer(pem_bundle)
]
if not der_certs:
raise ssl.SSLError("No root certificates specified")
cert_array = CoreFoundation.CFArrayCreateMutable(
CoreFoundation.kCFAllocatorDefault,
0,
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
)
if not cert_array:
raise ssl.SSLError("Unable to allocate memory!")
try:
for der_bytes in der_certs:
certdata = _cf_data_from_bytes(der_bytes)
if not certdata:
raise ssl.SSLError("Unable to allocate memory!")
cert = Security.SecCertificateCreateWithData(
CoreFoundation.kCFAllocatorDefault, certdata
)
CoreFoundation.CFRelease(certdata)
if not cert:
raise ssl.SSLError("Unable to build cert object!")
CoreFoundation.CFArrayAppendValue(cert_array, cert)
CoreFoundation.CFRelease(cert)
except Exception:
# We need to free the array before the exception bubbles further.
# We only want to do that if an error occurs: otherwise, the caller
# should free.
CoreFoundation.CFRelease(cert_array)
raise
return cert_array
def _is_cert(item):
"""
Returns True if a given CFTypeRef is a certificate.
"""
expected = Security.SecCertificateGetTypeID()
return CoreFoundation.CFGetTypeID(item) == expected
def _is_identity(item):
"""
Returns True if a given CFTypeRef is an identity.
"""
expected = Security.SecIdentityGetTypeID()
return CoreFoundation.CFGetTypeID(item) == expected
def _temporary_keychain():
"""
This function creates a temporary Mac keychain that we can use to work with
credentials. This keychain uses a one-time password and a temporary file to
store the data. We expect to have one keychain per socket. The returned
SecKeychainRef must be freed by the caller, including calling
SecKeychainDelete.
Returns a tuple of the SecKeychainRef and the path to the temporary
directory that contains it.
"""
# Unfortunately, SecKeychainCreate requires a path to a keychain. This
# means we cannot use mkstemp to use a generic temporary file. Instead,
# we're going to create a temporary directory and a filename to use there.
# This filename will be 8 random bytes expanded into hex (base16). We also need
# some random bytes to password-protect the keychain we're creating, so we
# ask for 40 random bytes.
random_bytes = os.urandom(40)
filename = base64.b16encode(random_bytes[:8]).decode("utf-8")
password = base64.b16encode(random_bytes[8:]) # Must be valid UTF-8
tempdirectory = tempfile.mkdtemp()
keychain_path = os.path.join(tempdirectory, filename).encode("utf-8")
# We now want to create the keychain itself.
keychain = Security.SecKeychainRef()
status = Security.SecKeychainCreate(
keychain_path, len(password), password, False, None, ctypes.byref(keychain)
)
_assert_no_error(status)
# Having created the keychain, we want to pass it off to the caller.
return keychain, tempdirectory
def _load_items_from_file(keychain, path):
"""
Given a single file, loads all the trust objects from it into arrays and
the keychain.
Returns a tuple of lists: the first list is a list of identities, the
second a list of certs.
"""
certificates = []
identities = []
result_array = None
with open(path, "rb") as f:
raw_filedata = f.read()
try:
filedata = CoreFoundation.CFDataCreate(
CoreFoundation.kCFAllocatorDefault, raw_filedata, len(raw_filedata)
)
result_array = CoreFoundation.CFArrayRef()
result = Security.SecItemImport(
filedata, # cert data
None, # Filename, leaving it out for now
None, # What the type of the file is, we don't care
None, # what's in the file, we don't care
0, # import flags
None, # key params, can include passphrase in the future
keychain, # The keychain to insert into
ctypes.byref(result_array), # Results
)
_assert_no_error(result)
# A CFArray is not very useful to us as an intermediary
# representation, so we are going to extract the objects we want
# and then free the array. We don't need to keep hold of keys: the
# keychain already has them!
result_count = CoreFoundation.CFArrayGetCount(result_array)
for index in range(result_count):
item = CoreFoundation.CFArrayGetValueAtIndex(result_array, index)
item = ctypes.cast(item, CoreFoundation.CFTypeRef)
if _is_cert(item):
CoreFoundation.CFRetain(item)
certificates.append(item)
elif _is_identity(item):
CoreFoundation.CFRetain(item)
identities.append(item)
finally:
if result_array:
CoreFoundation.CFRelease(result_array)
CoreFoundation.CFRelease(filedata)
return (identities, certificates)
def _load_client_cert_chain(keychain, *paths):
"""
Load certificates and maybe keys from a number of files. Has the end goal
of returning a CFArray containing one SecIdentityRef, and then zero or more
SecCertificateRef objects, suitable for use as a client certificate trust
chain.
"""
# Ok, the strategy.
#
# This relies on knowing that macOS will not give you a SecIdentityRef
# unless you have imported a key into a keychain. This is a somewhat
# artificial limitation of macOS (for example, it doesn't necessarily
# affect iOS), but there is nothing inside Security.framework that lets you
# get a SecIdentityRef without having a key in a keychain.
#
# So the policy here is we take all the files and iterate them in order.
# Each one will use SecItemImport to have one or more objects loaded from
# it. We will also point at a keychain that macOS can use to work with the
# private key.
#
# Once we have all the objects, we'll check what we actually have. If we
# already have a SecIdentityRef in hand, fab: we'll use that. Otherwise,
# we'll take the first certificate (which we assume to be our leaf) and
# ask the keychain to give us a SecIdentityRef with that cert's associated
# key.
#
# We'll then return a CFArray containing the trust chain: one
# SecIdentityRef and then zero-or-more SecCertificateRef objects. The
# responsibility for freeing this CFArray will be with the caller. This
# CFArray must remain alive for the entire connection, so in practice it
# will be stored with a single SSLSocket, along with the reference to the
# keychain.
certificates = []
identities = []
# Filter out bad paths.
paths = (path for path in paths if path)
try:
for file_path in paths:
new_identities, new_certs = _load_items_from_file(keychain, file_path)
identities.extend(new_identities)
certificates.extend(new_certs)
# Ok, we have everything. The question is: do we have an identity? If
# not, we want to grab one from the first cert we have.
if not identities:
new_identity = Security.SecIdentityRef()
status = Security.SecIdentityCreateWithCertificate(
keychain, certificates[0], ctypes.byref(new_identity)
)
_assert_no_error(status)
identities.append(new_identity)
# We now want to release the original certificate, as we no longer
# need it.
CoreFoundation.CFRelease(certificates.pop(0))
# We now need to build a new CFArray that holds the trust chain.
trust_chain = CoreFoundation.CFArrayCreateMutable(
CoreFoundation.kCFAllocatorDefault,
0,
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
)
for item in itertools.chain(identities, certificates):
# ArrayAppendValue does a CFRetain on the item. That's fine,
# because the finally block will release our other refs to them.
CoreFoundation.CFArrayAppendValue(trust_chain, item)
return trust_chain
finally:
for obj in itertools.chain(identities, certificates):
CoreFoundation.CFRelease(obj)
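A small sketch of the keychain helper above (macOS only; module paths inferred), showing the cleanup the docstring makes the caller responsible for:

import shutil
from urllib3.contrib._securetransport.bindings import Security
from urllib3.contrib._securetransport.low_level import _temporary_keychain

keychain, tempdir = _temporary_keychain()
try:
    pass  # e.g. _load_client_cert_chain(keychain, "client.pem") would go here
finally:
    Security.SecKeychainDelete(keychain)  # free the keychain itself...
    shutil.rmtree(tempdir)                # ...and the directory backing it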

314
urllib3/contrib/appengine.py Normal file
View File

@@ -0,0 +1,314 @@
"""
This module provides a pool manager that uses Google App Engine's
`URLFetch Service <https://cloud.google.com/appengine/docs/python/urlfetch>`_.
Example usage::
from urllib3 import PoolManager
from urllib3.contrib.appengine import AppEngineManager, is_appengine_sandbox
if is_appengine_sandbox():
# AppEngineManager uses AppEngine's URLFetch API behind the scenes
http = AppEngineManager()
else:
# PoolManager uses a socket-level API behind the scenes
http = PoolManager()
r = http.request('GET', 'https://google.com/')
There are `limitations <https://cloud.google.com/appengine/docs/python/\
urlfetch/#Python_Quotas_and_limits>`_ to the URLFetch service and it may not be
the best choice for your application. There are three options for using
urllib3 on Google App Engine:
1. You can use :class:`AppEngineManager` with URLFetch. URLFetch is
cost-effective in many circumstances as long as your usage is within the
limitations.
2. You can use a normal :class:`~urllib3.PoolManager` by enabling sockets.
Sockets also have `limitations and restrictions
<https://cloud.google.com/appengine/docs/python/sockets/\
#limitations-and-restrictions>`_ and have a lower free quota than URLFetch.
To use sockets, be sure to specify the following in your ``app.yaml``::
env_variables:
GAE_USE_SOCKETS_HTTPLIB : 'true'
3. If you are using `App Engine Flexible
<https://cloud.google.com/appengine/docs/flexible/>`_, you can use the standard
:class:`PoolManager` without any configuration or special environment variables.
"""
from __future__ import absolute_import
import io
import logging
import warnings
from ..packages.six.moves.urllib.parse import urljoin
from ..exceptions import (
HTTPError,
HTTPWarning,
MaxRetryError,
ProtocolError,
TimeoutError,
SSLError,
)
from ..request import RequestMethods
from ..response import HTTPResponse
from ..util.timeout import Timeout
from ..util.retry import Retry
from . import _appengine_environ
try:
from google.appengine.api import urlfetch
except ImportError:
urlfetch = None
log = logging.getLogger(__name__)
class AppEnginePlatformWarning(HTTPWarning):
pass
class AppEnginePlatformError(HTTPError):
pass
class AppEngineManager(RequestMethods):
"""
Connection manager for Google App Engine sandbox applications.
This manager uses the URLFetch service directly instead of using the
emulated httplib, and is subject to URLFetch limitations as described in
the App Engine documentation `here
<https://cloud.google.com/appengine/docs/python/urlfetch>`_.
Notably it will raise an :class:`AppEnginePlatformError` if:
* URLFetch is not available.
* You attempt to use this on App Engine Flexible, as full socket
support is available.
* A request size is more than 10 megabytes.
* A response size is more than 32 megabytes.
* You use an unsupported request method such as OPTIONS.
Beyond those cases, it will raise normal urllib3 errors.
"""
def __init__(
self,
headers=None,
retries=None,
validate_certificate=True,
urlfetch_retries=True,
):
if not urlfetch:
raise AppEnginePlatformError(
"URLFetch is not available in this environment."
)
warnings.warn(
"urllib3 is using URLFetch on Google App Engine sandbox instead "
"of sockets. To use sockets directly instead of URLFetch see "
"https://urllib3.readthedocs.io/en/latest/reference/urllib3.contrib.html.",
AppEnginePlatformWarning,
)
RequestMethods.__init__(self, headers)
self.validate_certificate = validate_certificate
self.urlfetch_retries = urlfetch_retries
self.retries = retries or Retry.DEFAULT
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Return False to re-raise any potential exceptions
return False
def urlopen(
self,
method,
url,
body=None,
headers=None,
retries=None,
redirect=True,
timeout=Timeout.DEFAULT_TIMEOUT,
**response_kw
):
retries = self._get_retries(retries, redirect)
try:
follow_redirects = redirect and retries.redirect != 0 and retries.total
response = urlfetch.fetch(
url,
payload=body,
method=method,
headers=headers or {},
allow_truncated=False,
follow_redirects=self.urlfetch_retries and follow_redirects,
deadline=self._get_absolute_timeout(timeout),
validate_certificate=self.validate_certificate,
)
except urlfetch.DeadlineExceededError as e:
raise TimeoutError(self, e)
except urlfetch.InvalidURLError as e:
if "too large" in str(e):
raise AppEnginePlatformError(
"URLFetch request too large, URLFetch only "
"supports requests up to 10mb in size.",
e,
)
raise ProtocolError(e)
except urlfetch.DownloadError as e:
if "Too many redirects" in str(e):
raise MaxRetryError(self, url, reason=e)
raise ProtocolError(e)
except urlfetch.ResponseTooLargeError as e:
raise AppEnginePlatformError(
"URLFetch response too large, URLFetch only supports"
"responses up to 32mb in size.",
e,
)
except urlfetch.SSLCertificateError as e:
raise SSLError(e)
except urlfetch.InvalidMethodError as e:
raise AppEnginePlatformError(
"URLFetch does not support method: %s" % method, e
)
http_response = self._urlfetch_response_to_http_response(
response, retries=retries, **response_kw
)
# Handle redirect?
redirect_location = redirect and http_response.get_redirect_location()
if redirect_location:
# Check for redirect response
if self.urlfetch_retries and retries.raise_on_redirect:
raise MaxRetryError(self, url, "too many redirects")
else:
if http_response.status == 303:
method = "GET"
try:
retries = retries.increment(
method, url, response=http_response, _pool=self
)
except MaxRetryError:
if retries.raise_on_redirect:
raise MaxRetryError(self, url, "too many redirects")
return http_response
retries.sleep_for_retry(http_response)
log.debug("Redirecting %s -> %s", url, redirect_location)
redirect_url = urljoin(url, redirect_location)
return self.urlopen(
method,
redirect_url,
body,
headers,
retries=retries,
redirect=redirect,
timeout=timeout,
**response_kw
)
# Check if we should retry the HTTP response.
has_retry_after = bool(http_response.getheader("Retry-After"))
if retries.is_retry(method, http_response.status, has_retry_after):
retries = retries.increment(method, url, response=http_response, _pool=self)
log.debug("Retry: %s", url)
retries.sleep(http_response)
return self.urlopen(
method,
url,
body=body,
headers=headers,
retries=retries,
redirect=redirect,
timeout=timeout,
**response_kw
)
return http_response
def _urlfetch_response_to_http_response(self, urlfetch_resp, **response_kw):
if is_prod_appengine():
# Production GAE handles deflate encoding automatically, but does
# not remove the encoding header.
content_encoding = urlfetch_resp.headers.get("content-encoding")
if content_encoding == "deflate":
del urlfetch_resp.headers["content-encoding"]
transfer_encoding = urlfetch_resp.headers.get("transfer-encoding")
# We have a full response's content,
# so let's make sure we don't report ourselves as chunked data.
if transfer_encoding == "chunked":
encodings = transfer_encoding.split(",")
encodings.remove("chunked")
urlfetch_resp.headers["transfer-encoding"] = ",".join(encodings)
original_response = HTTPResponse(
# In order for decoding to work, we must present the content as
# a file-like object.
body=io.BytesIO(urlfetch_resp.content),
msg=urlfetch_resp.header_msg,
headers=urlfetch_resp.headers,
status=urlfetch_resp.status_code,
**response_kw
)
return HTTPResponse(
body=io.BytesIO(urlfetch_resp.content),
headers=urlfetch_resp.headers,
status=urlfetch_resp.status_code,
original_response=original_response,
**response_kw
)
def _get_absolute_timeout(self, timeout):
if timeout is Timeout.DEFAULT_TIMEOUT:
return None # Defer to URLFetch's default.
if isinstance(timeout, Timeout):
if timeout._read is not None or timeout._connect is not None:
warnings.warn(
"URLFetch does not support granular timeout settings, "
"reverting to total or default URLFetch timeout.",
AppEnginePlatformWarning,
)
return timeout.total
return timeout
def _get_retries(self, retries, redirect):
if not isinstance(retries, Retry):
retries = Retry.from_int(retries, redirect=redirect, default=self.retries)
if retries.connect or retries.read or retries.redirect:
warnings.warn(
"URLFetch only supports total retries and does not "
"recognize connect, read, or redirect retry parameters.",
AppEnginePlatformWarning,
)
return retries
# Alias methods from _appengine_environ to maintain public API interface.
is_appengine = _appengine_environ.is_appengine
is_appengine_sandbox = _appengine_environ.is_appengine_sandbox
is_local_appengine = _appengine_environ.is_local_appengine
is_prod_appengine = _appengine_environ.is_prod_appengine
is_prod_appengine_mvms = _appengine_environ.is_prod_appengine_mvms

121 urllib3/contrib/ntlmpool.py Normal file

@@ -0,0 +1,121 @@
"""
NTLM authenticating pool, contributed by erikcederstran
Issue #10, see: http://code.google.com/p/urllib3/issues/detail?id=10
"""
from __future__ import absolute_import
from logging import getLogger
from ntlm import ntlm
from .. import HTTPSConnectionPool
from ..packages.six.moves.http_client import HTTPSConnection
log = getLogger(__name__)
class NTLMConnectionPool(HTTPSConnectionPool):
"""
Implements an NTLM authentication version of an urllib3 connection pool
"""
scheme = "https"
def __init__(self, user, pw, authurl, *args, **kwargs):
"""
authurl is a random URL on the server that is protected by NTLM.
user is the Windows user, probably in the DOMAIN\\username format.
pw is the password for the user.
"""
super(NTLMConnectionPool, self).__init__(*args, **kwargs)
self.authurl = authurl
self.rawuser = user
user_parts = user.split("\\", 1)
self.domain = user_parts[0].upper()
self.user = user_parts[1]
self.pw = pw
def _new_conn(self):
# Performs the NTLM handshake that secures the connection. The socket
# must be kept open while requests are performed.
self.num_connections += 1
log.debug(
"Starting NTLM HTTPS connection no. %d: https://%s%s",
self.num_connections,
self.host,
self.authurl,
)
headers = {"Connection": "Keep-Alive"}
req_header = "Authorization"
resp_header = "www-authenticate"
conn = HTTPSConnection(host=self.host, port=self.port)
# Send negotiation message
headers[req_header] = "NTLM %s" % ntlm.create_NTLM_NEGOTIATE_MESSAGE(
self.rawuser
)
log.debug("Request headers: %s", headers)
conn.request("GET", self.authurl, None, headers)
res = conn.getresponse()
reshdr = dict(res.getheaders())
log.debug("Response status: %s %s", res.status, res.reason)
log.debug("Response headers: %s", reshdr)
log.debug("Response data: %s [...]", res.read(100))
# Remove the reference to the socket, so that it can not be closed by
# the response object (we want to keep the socket open)
res.fp = None
# Server should respond with a challenge message
auth_header_values = reshdr[resp_header].split(", ")
auth_header_value = None
for s in auth_header_values:
if s[:5] == "NTLM ":
auth_header_value = s[5:]
if auth_header_value is None:
raise Exception(
"Unexpected %s response header: %s" % (resp_header, reshdr[resp_header])
)
# Send authentication message
ServerChallenge, NegotiateFlags = ntlm.parse_NTLM_CHALLENGE_MESSAGE(
auth_header_value
)
auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(
ServerChallenge, self.user, self.domain, self.pw, NegotiateFlags
)
headers[req_header] = "NTLM %s" % auth_msg
log.debug("Request headers: %s", headers)
conn.request("GET", self.authurl, None, headers)
res = conn.getresponse()
log.debug("Response status: %s %s", res.status, res.reason)
log.debug("Response headers: %s", dict(res.getheaders()))
log.debug("Response data: %s [...]", res.read()[:100])
if res.status != 200:
if res.status == 401:
raise Exception("Server rejected request: wrong username or password")
raise Exception("Wrong server response: %s %s" % (res.status, res.reason))
res.fp = None
log.debug("Connection established")
return conn
def urlopen(
self,
method,
url,
body=None,
headers=None,
retries=3,
redirect=True,
assert_same_host=True,
):
if headers is None:
headers = {}
headers["Connection"] = "Keep-Alive"
return super(NTLMConnectionPool, self).urlopen(
method, url, body, headers, retries, redirect, assert_same_host
)

501 urllib3/contrib/pyopenssl.py Normal file

@@ -0,0 +1,501 @@
"""
SSL with SNI_-support for Python 2. Follow these instructions if you would
like to verify SSL certificates in Python 2. Note, the default libraries do
*not* do certificate checking; you need to do additional work to validate
certificates yourself.
This needs the following packages installed:
* pyOpenSSL (tested with 16.0.0)
* cryptography (minimum 1.3.4, from pyopenssl)
* idna (minimum 2.0, from cryptography)
However, pyopenssl depends on cryptography, which depends on idna, so while we
use all three directly here we end up having relatively few packages required.
You can install them with the following command:
pip install pyopenssl cryptography idna
To activate certificate checking, call
:func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code
before you begin making HTTP requests. This can be done in a ``sitecustomize``
module, or at any other time before your application begins using ``urllib3``,
like this::
try:
import urllib3.contrib.pyopenssl
urllib3.contrib.pyopenssl.inject_into_urllib3()
except ImportError:
pass
Now you can use :mod:`urllib3` as you normally would, and it will support SNI
when the required modules are installed.
Activating this module also has the positive side effect of disabling SSL/TLS
compression in Python 2 (see `CRIME attack`_).
If you want to configure the default list of supported cipher suites, you can
set the ``urllib3.contrib.pyopenssl.DEFAULT_SSL_CIPHER_LIST`` variable.
.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication
.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit)
"""
from __future__ import absolute_import
import OpenSSL.SSL
from cryptography import x509
from cryptography.hazmat.backends.openssl import backend as openssl_backend
from cryptography.hazmat.backends.openssl.x509 import _Certificate
try:
from cryptography.x509 import UnsupportedExtension
except ImportError:
# UnsupportedExtension is gone in cryptography >= 2.1.0
class UnsupportedExtension(Exception):
pass
from socket import timeout, error as SocketError
from io import BytesIO
try: # Platform-specific: Python 2
from socket import _fileobject
except ImportError: # Platform-specific: Python 3
_fileobject = None
from ..packages.backports.makefile import backport_makefile
import logging
import ssl
from ..packages import six
import sys
from .. import util
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# SNI always works.
HAS_SNI = True
# Map from urllib3 to PyOpenSSL compatible parameter-values.
_openssl_versions = {
util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD,
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
}
if hasattr(ssl, "PROTOCOL_SSLv3") and hasattr(OpenSSL.SSL, "SSLv3_METHOD"):
_openssl_versions[ssl.PROTOCOL_SSLv3] = OpenSSL.SSL.SSLv3_METHOD
if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD
if hasattr(ssl, "PROTOCOL_TLSv1_2") and hasattr(OpenSSL.SSL, "TLSv1_2_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD
_stdlib_to_openssl_verify = {
ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
+ OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}
_openssl_to_stdlib_verify = dict((v, k) for k, v in _stdlib_to_openssl_verify.items())
# OpenSSL will only write 16K at a time
SSL_WRITE_BLOCKSIZE = 16384
orig_util_HAS_SNI = util.HAS_SNI
orig_util_SSLContext = util.ssl_.SSLContext
log = logging.getLogger(__name__)
def inject_into_urllib3():
"Monkey-patch urllib3 with PyOpenSSL-backed SSL-support."
_validate_dependencies_met()
util.SSLContext = PyOpenSSLContext
util.ssl_.SSLContext = PyOpenSSLContext
util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI
util.IS_PYOPENSSL = True
util.ssl_.IS_PYOPENSSL = True
def extract_from_urllib3():
"Undo monkey-patching by :func:`inject_into_urllib3`."
util.SSLContext = orig_util_SSLContext
util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
util.IS_PYOPENSSL = False
util.ssl_.IS_PYOPENSSL = False
def _validate_dependencies_met():
"""
Verifies that PyOpenSSL's package-level dependencies have been met.
Throws `ImportError` if they are not met.
"""
# Method added in `cryptography==1.1`; not available in older versions
from cryptography.x509.extensions import Extensions
if getattr(Extensions, "get_extension_for_class", None) is None:
raise ImportError(
"'cryptography' module missing required functionality. "
"Try upgrading to v1.3.4 or newer."
)
# pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509
# attribute is only present on those versions.
from OpenSSL.crypto import X509
x509 = X509()
if getattr(x509, "_x509", None) is None:
raise ImportError(
"'pyOpenSSL' module missing required functionality. "
"Try upgrading to v0.14 or newer."
)
def _dnsname_to_stdlib(name):
"""
Converts a dNSName SubjectAlternativeName field to the form used by the
standard library on the given Python version.
Cryptography produces a dNSName as a unicode string that was idna-decoded
from ASCII bytes. We need to idna-encode that string to get it back, and
then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib
uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8).
If the name cannot be idna-encoded then we return None signalling that
the name given should be skipped.
"""
def idna_encode(name):
"""
Borrowed wholesale from the Python Cryptography Project. It turns out
that we can't just safely call `idna.encode`: it can explode for
wildcard names. This avoids that problem.
"""
import idna
try:
for prefix in [u"*.", u"."]:
if name.startswith(prefix):
name = name[len(prefix) :]
return prefix.encode("ascii") + idna.encode(name)
return idna.encode(name)
except idna.core.IDNAError:
return None
# Don't send IPv6 addresses through the IDNA encoder.
if ":" in name:
return name
name = idna_encode(name)
if name is None:
return None
elif sys.version_info >= (3, 0):
name = name.decode("utf-8")
return name
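# Illustrative (not part of the upstream file): u"*.b\xfccher.example"
# round-trips through idna_encode to b"*.xn--bcher-kva.example" and, on
# Python 3, is decoded back to "*.xn--bcher-kva.example"; a name that cannot
# be idna-encoded yields None and is skipped by get_subj_alt_name below.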
def get_subj_alt_name(peer_cert):
"""
    Given a PyOpenSSL certificate, provides all the subject alternative names.
"""
# Pass the cert to cryptography, which has much better APIs for this.
if hasattr(peer_cert, "to_cryptography"):
cert = peer_cert.to_cryptography()
else:
# This is technically using private APIs, but should work across all
# relevant versions before PyOpenSSL got a proper API for this.
cert = _Certificate(openssl_backend, peer_cert._x509)
# We want to find the SAN extension. Ask Cryptography to locate it (it's
# faster than looping in Python)
try:
ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
except x509.ExtensionNotFound:
# No such extension, return the empty list.
return []
except (
x509.DuplicateExtension,
UnsupportedExtension,
x509.UnsupportedGeneralNameType,
UnicodeError,
) as e:
# A problem has been found with the quality of the certificate. Assume
# no SAN field is present.
log.warning(
"A problem was encountered with the certificate that prevented "
"urllib3 from finding the SubjectAlternativeName field. This can "
"affect certificate validation. The error was %s",
e,
)
return []
# We want to return dNSName and iPAddress fields. We need to cast the IPs
# back to strings because the match_hostname function wants them as
# strings.
# Sadly the DNS names need to be idna encoded and then, on Python 3, UTF-8
# decoded. This is pretty frustrating, but that's what the standard library
# does with certificates, and so we need to attempt to do the same.
# We also want to skip over names which cannot be idna encoded.
names = [
("DNS", name)
for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName))
if name is not None
]
names.extend(
("IP Address", str(name)) for name in ext.get_values_for_type(x509.IPAddress)
)
return names
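# Illustrative return value (not part of the upstream file):
#     [("DNS", "example.com"), ("IP Address", "93.184.216.34")]
# i.e. the same (type, value) pairs the stdlib's match_hostname() reads from
# a certificate's subjectAltName.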
class WrappedSocket(object):
"""API-compatibility wrapper for Python OpenSSL's Connection-class.
Note: _makefile_refs, _drop() and _reuse() are needed for the garbage
collector of pypy.
"""
def __init__(self, connection, socket, suppress_ragged_eofs=True):
self.connection = connection
self.socket = socket
self.suppress_ragged_eofs = suppress_ragged_eofs
self._makefile_refs = 0
self._closed = False
def fileno(self):
return self.socket.fileno()
# Copy-pasted from Python 3.5 source code
def _decref_socketios(self):
if self._makefile_refs > 0:
self._makefile_refs -= 1
if self._closed:
self.close()
def recv(self, *args, **kwargs):
try:
data = self.connection.recv(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return b""
else:
raise SocketError(str(e))
except OpenSSL.SSL.ZeroReturnError:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return b""
else:
raise
except OpenSSL.SSL.WantReadError:
if not util.wait_for_read(self.socket, self.socket.gettimeout()):
raise timeout("The read operation timed out")
else:
return self.recv(*args, **kwargs)
# TLS 1.3 post-handshake authentication
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("read error: %r" % e)
else:
return data
def recv_into(self, *args, **kwargs):
try:
return self.connection.recv_into(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return 0
else:
raise SocketError(str(e))
except OpenSSL.SSL.ZeroReturnError:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return 0
else:
raise
except OpenSSL.SSL.WantReadError:
if not util.wait_for_read(self.socket, self.socket.gettimeout()):
raise timeout("The read operation timed out")
else:
return self.recv_into(*args, **kwargs)
# TLS 1.3 post-handshake authentication
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("read error: %r" % e)
def settimeout(self, timeout):
return self.socket.settimeout(timeout)
def _send_until_done(self, data):
while True:
try:
return self.connection.send(data)
except OpenSSL.SSL.WantWriteError:
if not util.wait_for_write(self.socket, self.socket.gettimeout()):
raise timeout()
continue
except OpenSSL.SSL.SysCallError as e:
raise SocketError(str(e))
def sendall(self, data):
total_sent = 0
while total_sent < len(data):
sent = self._send_until_done(
data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]
)
total_sent += sent
def shutdown(self):
# FIXME rethrow compatible exceptions should we ever use this
self.connection.shutdown()
def close(self):
if self._makefile_refs < 1:
try:
self._closed = True
return self.connection.close()
except OpenSSL.SSL.Error:
return
else:
self._makefile_refs -= 1
def getpeercert(self, binary_form=False):
x509 = self.connection.get_peer_certificate()
if not x509:
return x509
if binary_form:
return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509)
return {
"subject": ((("commonName", x509.get_subject().CN),),),
"subjectAltName": get_subj_alt_name(x509),
}
def version(self):
return self.connection.get_protocol_version_name()
def _reuse(self):
self._makefile_refs += 1
def _drop(self):
if self._makefile_refs < 1:
self.close()
else:
self._makefile_refs -= 1
if _fileobject: # Platform-specific: Python 2
def makefile(self, mode, bufsize=-1):
self._makefile_refs += 1
return _fileobject(self, mode, bufsize, close=True)
else: # Platform-specific: Python 3
makefile = backport_makefile
WrappedSocket.makefile = makefile
class PyOpenSSLContext(object):
"""
I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible
for translating the interface of the standard library ``SSLContext`` object
to calls into PyOpenSSL.
"""
def __init__(self, protocol):
self.protocol = _openssl_versions[protocol]
self._ctx = OpenSSL.SSL.Context(self.protocol)
self._options = 0
self.check_hostname = False
@property
def options(self):
return self._options
@options.setter
def options(self, value):
self._options = value
self._ctx.set_options(value)
@property
def verify_mode(self):
return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()]
@verify_mode.setter
def verify_mode(self, value):
self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)
def set_default_verify_paths(self):
self._ctx.set_default_verify_paths()
def set_ciphers(self, ciphers):
if isinstance(ciphers, six.text_type):
ciphers = ciphers.encode("utf-8")
self._ctx.set_cipher_list(ciphers)
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
if cafile is not None:
cafile = cafile.encode("utf-8")
if capath is not None:
capath = capath.encode("utf-8")
try:
self._ctx.load_verify_locations(cafile, capath)
if cadata is not None:
self._ctx.load_verify_locations(BytesIO(cadata))
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("unable to load trusted certificates: %r" % e)
def load_cert_chain(self, certfile, keyfile=None, password=None):
self._ctx.use_certificate_chain_file(certfile)
if password is not None:
if not isinstance(password, six.binary_type):
password = password.encode("utf-8")
self._ctx.set_passwd_cb(lambda *_: password)
self._ctx.use_privatekey_file(keyfile or certfile)
def wrap_socket(
self,
sock,
server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None,
):
cnx = OpenSSL.SSL.Connection(self._ctx, sock)
if isinstance(server_hostname, six.text_type): # Platform-specific: Python 3
server_hostname = server_hostname.encode("utf-8")
if server_hostname is not None:
cnx.set_tlsext_host_name(server_hostname)
cnx.set_connect_state()
while True:
try:
cnx.do_handshake()
except OpenSSL.SSL.WantReadError:
if not util.wait_for_read(sock, sock.gettimeout()):
raise timeout("select timed out")
continue
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("bad handshake: %r" % e)
break
return WrappedSocket(cnx, sock)
def _verify_callback(cnx, x509, err_no, err_depth, return_code):
return err_no == 0
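A hedged sketch of driving PyOpenSSLContext directly (placeholder hostname; in normal use you would just call inject_into_urllib3() as the module docstring shows):

    import socket
    import ssl

    ctx = PyOpenSSLContext(ssl.PROTOCOL_TLS)
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.set_default_verify_paths()
    raw = socket.create_connection(("example.com", 443))
    tls = ctx.wrap_socket(raw, server_hostname="example.com")
    print(tls.version())  # e.g. "TLSv1.2"
    tls.close()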

864 urllib3/contrib/securetransport.py Normal file

@@ -0,0 +1,864 @@
"""
SecureTransport support for urllib3 via ctypes.
This makes platform-native TLS available to urllib3 users on macOS without the
use of a compiler. This is an important feature because the Python Package
Index is moving to become a TLSv1.2-or-higher server, and the default OpenSSL
that ships with macOS is not capable of doing TLSv1.2. The only way to resolve
this is to give macOS users an alternative solution to the problem, and that
solution is to use SecureTransport.
We use ctypes here because this solution must not require a compiler. That's
because pip is not allowed to require a compiler either.
This is not intended to be a seriously long-term solution to this problem.
The hope is that PEP 543 will eventually solve this issue for us, at which
point we can retire this contrib module. But in the short term, we need to
solve the impending tire fire that is Python on Mac without this kind of
contrib module. So...here we are.
To use this module, simply import and inject it::
import urllib3.contrib.securetransport
urllib3.contrib.securetransport.inject_into_urllib3()
Happy TLSing!
This code is a bastardised version of the code found in Will Bond's oscrypto
library. An enormous debt is owed to him for blazing this trail for us. For
that reason, this code should be considered to be covered both by urllib3's
license and by oscrypto's:
Copyright (c) 2015-2016 Will Bond <will@wbond.net>
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
import contextlib
import ctypes
import errno
import os.path
import shutil
import socket
import ssl
import threading
import weakref
from .. import util
from ._securetransport.bindings import Security, SecurityConst, CoreFoundation
from ._securetransport.low_level import (
_assert_no_error,
_cert_array_from_pem,
_temporary_keychain,
_load_client_cert_chain,
)
try: # Platform-specific: Python 2
from socket import _fileobject
except ImportError: # Platform-specific: Python 3
_fileobject = None
from ..packages.backports.makefile import backport_makefile
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# SNI always works
HAS_SNI = True
orig_util_HAS_SNI = util.HAS_SNI
orig_util_SSLContext = util.ssl_.SSLContext
# This dictionary is used by the read callback to obtain a handle to the
# calling wrapped socket. This is a pretty silly approach, but for now it'll
# do. I feel like I should be able to smuggle a handle to the wrapped socket
# directly in the SSLConnectionRef, but for now this approach will work I
# guess.
#
# We need to lock around this structure for inserts, but we don't do it for
# reads/writes in the callbacks. The reasoning here goes as follows:
#
# 1. It is not possible to call into the callbacks before the dictionary is
# populated, so once in the callback the id must be in the dictionary.
# 2. The callbacks don't mutate the dictionary, they only read from it, and
# so cannot conflict with any of the insertions.
#
# This is good: if we had to lock in the callbacks we'd drastically slow down
# the performance of this code.
_connection_refs = weakref.WeakValueDictionary()
_connection_ref_lock = threading.Lock()
# Limit writes to 16kB. This is OpenSSL's limit, but we'll cargo-cult it over
# for no better reason than we need *a* limit, and this one is right there.
SSL_WRITE_BLOCKSIZE = 16384
# This is our equivalent of util.ssl_.DEFAULT_CIPHERS, but expanded out to
# individual cipher suites. We need to do this because this is how
# SecureTransport wants them.
CIPHER_SUITES = [
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_AES_256_GCM_SHA384,
SecurityConst.TLS_AES_128_GCM_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_AES_128_CCM_8_SHA256,
SecurityConst.TLS_AES_128_CCM_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA256,
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA,
]
# Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of
# TLSv1 and a high of TLSv1.2. For everything else, we pin to that version.
# TLSv1 to 1.2 are supported on macOS 10.8+
_protocol_to_min_max = {
util.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12)
}
if hasattr(ssl, "PROTOCOL_SSLv2"):
_protocol_to_min_max[ssl.PROTOCOL_SSLv2] = (
SecurityConst.kSSLProtocol2,
SecurityConst.kSSLProtocol2,
)
if hasattr(ssl, "PROTOCOL_SSLv3"):
_protocol_to_min_max[ssl.PROTOCOL_SSLv3] = (
SecurityConst.kSSLProtocol3,
SecurityConst.kSSLProtocol3,
)
if hasattr(ssl, "PROTOCOL_TLSv1"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1] = (
SecurityConst.kTLSProtocol1,
SecurityConst.kTLSProtocol1,
)
if hasattr(ssl, "PROTOCOL_TLSv1_1"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_1] = (
SecurityConst.kTLSProtocol11,
SecurityConst.kTLSProtocol11,
)
if hasattr(ssl, "PROTOCOL_TLSv1_2"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_2] = (
SecurityConst.kTLSProtocol12,
SecurityConst.kTLSProtocol12,
)
def inject_into_urllib3():
"""
Monkey-patch urllib3 with SecureTransport-backed SSL-support.
"""
util.SSLContext = SecureTransportContext
util.ssl_.SSLContext = SecureTransportContext
util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI
util.IS_SECURETRANSPORT = True
util.ssl_.IS_SECURETRANSPORT = True
def extract_from_urllib3():
"""
Undo monkey-patching by :func:`inject_into_urllib3`.
"""
util.SSLContext = orig_util_SSLContext
util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
util.IS_SECURETRANSPORT = False
util.ssl_.IS_SECURETRANSPORT = False
def _read_callback(connection_id, data_buffer, data_length_pointer):
"""
SecureTransport read callback. This is called by ST to request that data
be returned from the socket.
"""
wrapped_socket = None
try:
wrapped_socket = _connection_refs.get(connection_id)
if wrapped_socket is None:
return SecurityConst.errSSLInternal
base_socket = wrapped_socket.socket
requested_length = data_length_pointer[0]
timeout = wrapped_socket.gettimeout()
error = None
read_count = 0
try:
while read_count < requested_length:
if timeout is None or timeout >= 0:
if not util.wait_for_read(base_socket, timeout):
raise socket.error(errno.EAGAIN, "timed out")
remaining = requested_length - read_count
buffer = (ctypes.c_char * remaining).from_address(
data_buffer + read_count
)
chunk_size = base_socket.recv_into(buffer, remaining)
read_count += chunk_size
if not chunk_size:
if not read_count:
return SecurityConst.errSSLClosedGraceful
break
except (socket.error) as e:
error = e.errno
if error is not None and error != errno.EAGAIN:
data_length_pointer[0] = read_count
if error == errno.ECONNRESET or error == errno.EPIPE:
return SecurityConst.errSSLClosedAbort
raise
data_length_pointer[0] = read_count
if read_count != requested_length:
return SecurityConst.errSSLWouldBlock
return 0
except Exception as e:
if wrapped_socket is not None:
wrapped_socket._exception = e
return SecurityConst.errSSLInternal
def _write_callback(connection_id, data_buffer, data_length_pointer):
"""
SecureTransport write callback. This is called by ST to request that data
actually be sent on the network.
"""
wrapped_socket = None
try:
wrapped_socket = _connection_refs.get(connection_id)
if wrapped_socket is None:
return SecurityConst.errSSLInternal
base_socket = wrapped_socket.socket
bytes_to_write = data_length_pointer[0]
data = ctypes.string_at(data_buffer, bytes_to_write)
timeout = wrapped_socket.gettimeout()
error = None
sent = 0
try:
while sent < bytes_to_write:
if timeout is None or timeout >= 0:
if not util.wait_for_write(base_socket, timeout):
raise socket.error(errno.EAGAIN, "timed out")
chunk_sent = base_socket.send(data)
sent += chunk_sent
# This has some needless copying here, but I'm not sure there's
# much value in optimising this data path.
data = data[chunk_sent:]
except (socket.error) as e:
error = e.errno
if error is not None and error != errno.EAGAIN:
data_length_pointer[0] = sent
if error == errno.ECONNRESET or error == errno.EPIPE:
return SecurityConst.errSSLClosedAbort
raise
data_length_pointer[0] = sent
if sent != bytes_to_write:
return SecurityConst.errSSLWouldBlock
return 0
except Exception as e:
if wrapped_socket is not None:
wrapped_socket._exception = e
return SecurityConst.errSSLInternal
# We need to keep these two objects references alive: if they get GC'd while
# in use then SecureTransport could attempt to call a function that is in freed
# memory. That would be...uh...bad. Yeah, that's the word. Bad.
_read_callback_pointer = Security.SSLReadFunc(_read_callback)
_write_callback_pointer = Security.SSLWriteFunc(_write_callback)
class WrappedSocket(object):
"""
API-compatibility wrapper for Python's OpenSSL wrapped socket object.
Note: _makefile_refs, _drop(), and _reuse() are needed for the garbage
collector of PyPy.
"""
def __init__(self, socket):
self.socket = socket
self.context = None
self._makefile_refs = 0
self._closed = False
self._exception = None
self._keychain = None
self._keychain_dir = None
self._client_cert_chain = None
# We save off the previously-configured timeout and then set it to
# zero. This is done because we use select and friends to handle the
# timeouts, but if we leave the timeout set on the lower socket then
# Python will "kindly" call select on that socket again for us. Avoid
# that by forcing the timeout to zero.
self._timeout = self.socket.gettimeout()
self.socket.settimeout(0)
@contextlib.contextmanager
def _raise_on_error(self):
"""
A context manager that can be used to wrap calls that do I/O from
SecureTransport. If any of the I/O callbacks hit an exception, this
context manager will correctly propagate the exception after the fact.
This avoids silently swallowing those exceptions.
It also correctly forces the socket closed.
"""
self._exception = None
# We explicitly don't catch around this yield because in the unlikely
# event that an exception was hit in the block we don't want to swallow
# it.
yield
if self._exception is not None:
exception, self._exception = self._exception, None
self.close()
raise exception
def _set_ciphers(self):
"""
Sets up the allowed ciphers. By default this matches the set in
util.ssl_.DEFAULT_CIPHERS, at least as supported by macOS. This is done
custom and doesn't allow changing at this time, mostly because parsing
OpenSSL cipher strings is going to be a freaking nightmare.
"""
ciphers = (Security.SSLCipherSuite * len(CIPHER_SUITES))(*CIPHER_SUITES)
result = Security.SSLSetEnabledCiphers(
self.context, ciphers, len(CIPHER_SUITES)
)
_assert_no_error(result)
def _custom_validate(self, verify, trust_bundle):
"""
Called when we have set custom validation. We do this in two cases:
first, when cert validation is entirely disabled; and second, when
using a custom trust DB.
"""
# If we disabled cert validation, just say: cool.
if not verify:
return
# We want data in memory, so load it up.
if os.path.isfile(trust_bundle):
with open(trust_bundle, "rb") as f:
trust_bundle = f.read()
cert_array = None
trust = Security.SecTrustRef()
try:
# Get a CFArray that contains the certs we want.
cert_array = _cert_array_from_pem(trust_bundle)
# Ok, now the hard part. We want to get the SecTrustRef that ST has
# created for this connection, shove our CAs into it, tell ST to
# ignore everything else it knows, and then ask if it can build a
# chain. This is a buuuunch of code.
result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
_assert_no_error(result)
if not trust:
raise ssl.SSLError("Failed to copy trust reference")
result = Security.SecTrustSetAnchorCertificates(trust, cert_array)
_assert_no_error(result)
result = Security.SecTrustSetAnchorCertificatesOnly(trust, True)
_assert_no_error(result)
trust_result = Security.SecTrustResultType()
result = Security.SecTrustEvaluate(trust, ctypes.byref(trust_result))
_assert_no_error(result)
finally:
if trust:
CoreFoundation.CFRelease(trust)
if cert_array is not None:
CoreFoundation.CFRelease(cert_array)
# Ok, now we can look at what the result was.
successes = (
SecurityConst.kSecTrustResultUnspecified,
SecurityConst.kSecTrustResultProceed,
)
if trust_result.value not in successes:
raise ssl.SSLError(
"certificate verify failed, error code: %d" % trust_result.value
)
def handshake(
self,
server_hostname,
verify,
trust_bundle,
min_version,
max_version,
client_cert,
client_key,
client_key_passphrase,
):
"""
Actually performs the TLS handshake. This is run automatically by
wrapped socket, and shouldn't be needed in user code.
"""
# First, we do the initial bits of connection setup. We need to create
# a context, set its I/O funcs, and set the connection reference.
self.context = Security.SSLCreateContext(
None, SecurityConst.kSSLClientSide, SecurityConst.kSSLStreamType
)
result = Security.SSLSetIOFuncs(
self.context, _read_callback_pointer, _write_callback_pointer
)
_assert_no_error(result)
# Here we need to compute the handle to use. We do this by taking the
# id of self modulo 2**31 - 1. If this is already in the dictionary, we
# just keep incrementing by one until we find a free space.
with _connection_ref_lock:
handle = id(self) % 2147483647
while handle in _connection_refs:
handle = (handle + 1) % 2147483647
_connection_refs[handle] = self
result = Security.SSLSetConnection(self.context, handle)
_assert_no_error(result)
# If we have a server hostname, we should set that too.
if server_hostname:
if not isinstance(server_hostname, bytes):
server_hostname = server_hostname.encode("utf-8")
result = Security.SSLSetPeerDomainName(
self.context, server_hostname, len(server_hostname)
)
_assert_no_error(result)
# Setup the ciphers.
self._set_ciphers()
# Set the minimum and maximum TLS versions.
result = Security.SSLSetProtocolVersionMin(self.context, min_version)
_assert_no_error(result)
result = Security.SSLSetProtocolVersionMax(self.context, max_version)
_assert_no_error(result)
# If there's a trust DB, we need to use it. We do that by telling
# SecureTransport to break on server auth. We also do that if we don't
# want to validate the certs at all: we just won't actually do any
# authing in that case.
if not verify or trust_bundle is not None:
result = Security.SSLSetSessionOption(
self.context, SecurityConst.kSSLSessionOptionBreakOnServerAuth, True
)
_assert_no_error(result)
# If there's a client cert, we need to use it.
if client_cert:
self._keychain, self._keychain_dir = _temporary_keychain()
self._client_cert_chain = _load_client_cert_chain(
self._keychain, client_cert, client_key
)
result = Security.SSLSetCertificate(self.context, self._client_cert_chain)
_assert_no_error(result)
while True:
with self._raise_on_error():
result = Security.SSLHandshake(self.context)
if result == SecurityConst.errSSLWouldBlock:
raise socket.timeout("handshake timed out")
elif result == SecurityConst.errSSLServerAuthCompleted:
self._custom_validate(verify, trust_bundle)
continue
else:
_assert_no_error(result)
break
def fileno(self):
return self.socket.fileno()
# Copy-pasted from Python 3.5 source code
def _decref_socketios(self):
if self._makefile_refs > 0:
self._makefile_refs -= 1
if self._closed:
self.close()
def recv(self, bufsiz):
buffer = ctypes.create_string_buffer(bufsiz)
bytes_read = self.recv_into(buffer, bufsiz)
data = buffer[:bytes_read]
return data
def recv_into(self, buffer, nbytes=None):
# Read short on EOF.
if self._closed:
return 0
if nbytes is None:
nbytes = len(buffer)
buffer = (ctypes.c_char * nbytes).from_buffer(buffer)
processed_bytes = ctypes.c_size_t(0)
with self._raise_on_error():
result = Security.SSLRead(
self.context, buffer, nbytes, ctypes.byref(processed_bytes)
)
# There are some result codes that we want to treat as "not always
# errors". Specifically, those are errSSLWouldBlock,
# errSSLClosedGraceful, and errSSLClosedNoNotify.
if result == SecurityConst.errSSLWouldBlock:
# If we didn't process any bytes, then this was just a time out.
# However, we can get errSSLWouldBlock in situations when we *did*
# read some data, and in those cases we should just read "short"
# and return.
if processed_bytes.value == 0:
# Timed out, no data read.
raise socket.timeout("recv timed out")
elif result in (
SecurityConst.errSSLClosedGraceful,
SecurityConst.errSSLClosedNoNotify,
):
# The remote peer has closed this connection. We should do so as
# well. Note that we don't actually return here because in
# principle this could actually be fired along with return data.
# It's unlikely though.
self.close()
else:
_assert_no_error(result)
# Ok, we read and probably succeeded. We should return whatever data
# was actually read.
return processed_bytes.value
def settimeout(self, timeout):
self._timeout = timeout
def gettimeout(self):
return self._timeout
def send(self, data):
processed_bytes = ctypes.c_size_t(0)
with self._raise_on_error():
result = Security.SSLWrite(
self.context, data, len(data), ctypes.byref(processed_bytes)
)
if result == SecurityConst.errSSLWouldBlock and processed_bytes.value == 0:
# Timed out
raise socket.timeout("send timed out")
else:
_assert_no_error(result)
# We sent, and probably succeeded. Tell them how much we sent.
return processed_bytes.value
def sendall(self, data):
total_sent = 0
while total_sent < len(data):
sent = self.send(data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE])
total_sent += sent
def shutdown(self):
with self._raise_on_error():
Security.SSLClose(self.context)
def close(self):
# TODO: should I do clean shutdown here? Do I have to?
if self._makefile_refs < 1:
self._closed = True
if self.context:
CoreFoundation.CFRelease(self.context)
self.context = None
if self._client_cert_chain:
CoreFoundation.CFRelease(self._client_cert_chain)
self._client_cert_chain = None
if self._keychain:
Security.SecKeychainDelete(self._keychain)
CoreFoundation.CFRelease(self._keychain)
shutil.rmtree(self._keychain_dir)
self._keychain = self._keychain_dir = None
return self.socket.close()
else:
self._makefile_refs -= 1
def getpeercert(self, binary_form=False):
# Urgh, annoying.
#
# Here's how we do this:
#
# 1. Call SSLCopyPeerTrust to get hold of the trust object for this
# connection.
# 2. Call SecTrustGetCertificateAtIndex for index 0 to get the leaf.
# 3. To get the CN, call SecCertificateCopyCommonName and process that
# string so that it's of the appropriate type.
# 4. To get the SAN, we need to do something a bit more complex:
# a. Call SecCertificateCopyValues to get the data, requesting
# kSecOIDSubjectAltName.
# b. Mess about with this dictionary to try to get the SANs out.
#
# This is gross. Really gross. It's going to be a few hundred LoC extra
# just to repeat something that SecureTransport can *already do*. So my
# operating assumption at this time is that what we want to do is
# instead to just flag to urllib3 that it shouldn't do its own hostname
# validation when using SecureTransport.
if not binary_form:
raise ValueError("SecureTransport only supports dumping binary certs")
trust = Security.SecTrustRef()
certdata = None
der_bytes = None
try:
# Grab the trust store.
result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
_assert_no_error(result)
if not trust:
# Probably we haven't done the handshake yet. No biggie.
return None
cert_count = Security.SecTrustGetCertificateCount(trust)
if not cert_count:
# Also a case that might happen if we haven't handshaked.
# Handshook? Handshaken?
return None
leaf = Security.SecTrustGetCertificateAtIndex(trust, 0)
assert leaf
# Ok, now we want the DER bytes.
certdata = Security.SecCertificateCopyData(leaf)
assert certdata
data_length = CoreFoundation.CFDataGetLength(certdata)
data_buffer = CoreFoundation.CFDataGetBytePtr(certdata)
der_bytes = ctypes.string_at(data_buffer, data_length)
finally:
if certdata:
CoreFoundation.CFRelease(certdata)
if trust:
CoreFoundation.CFRelease(trust)
return der_bytes
def version(self):
protocol = Security.SSLProtocol()
result = Security.SSLGetNegotiatedProtocolVersion(
self.context, ctypes.byref(protocol)
)
_assert_no_error(result)
if protocol.value == SecurityConst.kTLSProtocol13:
raise ssl.SSLError("SecureTransport does not support TLS 1.3")
elif protocol.value == SecurityConst.kTLSProtocol12:
return "TLSv1.2"
elif protocol.value == SecurityConst.kTLSProtocol11:
return "TLSv1.1"
elif protocol.value == SecurityConst.kTLSProtocol1:
return "TLSv1"
elif protocol.value == SecurityConst.kSSLProtocol3:
return "SSLv3"
elif protocol.value == SecurityConst.kSSLProtocol2:
return "SSLv2"
else:
raise ssl.SSLError("Unknown TLS version: %r" % protocol)
def _reuse(self):
self._makefile_refs += 1
def _drop(self):
if self._makefile_refs < 1:
self.close()
else:
self._makefile_refs -= 1
if _fileobject: # Platform-specific: Python 2
def makefile(self, mode, bufsize=-1):
self._makefile_refs += 1
return _fileobject(self, mode, bufsize, close=True)
else: # Platform-specific: Python 3
def makefile(self, mode="r", buffering=None, *args, **kwargs):
# We disable buffering with SecureTransport because it conflicts with
# the buffering that ST does internally (see issue #1153 for more).
buffering = 0
return backport_makefile(self, mode, buffering, *args, **kwargs)
WrappedSocket.makefile = makefile
class SecureTransportContext(object):
"""
I am a wrapper class for the SecureTransport library, to translate the
interface of the standard library ``SSLContext`` object to calls into
SecureTransport.
"""
def __init__(self, protocol):
self._min_version, self._max_version = _protocol_to_min_max[protocol]
self._options = 0
self._verify = False
self._trust_bundle = None
self._client_cert = None
self._client_key = None
self._client_key_passphrase = None
@property
def check_hostname(self):
"""
SecureTransport cannot have its hostname checking disabled. For more,
see the comment on getpeercert() in this file.
"""
return True
@check_hostname.setter
def check_hostname(self, value):
"""
SecureTransport cannot have its hostname checking disabled. For more,
see the comment on getpeercert() in this file.
"""
pass
@property
def options(self):
# TODO: Well, crap.
#
# So this is the bit of the code that is the most likely to cause us
# trouble. Essentially we need to enumerate all of the SSL options that
# users might want to use and try to see if we can sensibly translate
# them, or whether we should just ignore them.
return self._options
@options.setter
def options(self, value):
# TODO: Update in line with above.
self._options = value
@property
def verify_mode(self):
return ssl.CERT_REQUIRED if self._verify else ssl.CERT_NONE
@verify_mode.setter
def verify_mode(self, value):
self._verify = True if value == ssl.CERT_REQUIRED else False
def set_default_verify_paths(self):
# So, this has to do something a bit weird. Specifically, what it does
# is nothing.
#
# This means that, if we had previously had load_verify_locations
# called, this does not undo that. We need to do that because it turns
# out that the rest of the urllib3 code will attempt to load the
# default verify paths if it hasn't been told about any paths, even if
        # the context itself was configured sometime earlier. We resolve that by just
# ignoring it.
pass
def load_default_certs(self):
return self.set_default_verify_paths()
def set_ciphers(self, ciphers):
# For now, we just require the default cipher string.
if ciphers != util.ssl_.DEFAULT_CIPHERS:
raise ValueError("SecureTransport doesn't support custom cipher strings")
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
# OK, we only really support cadata and cafile.
if capath is not None:
raise ValueError("SecureTransport does not support cert directories")
# Raise if cafile does not exist.
if cafile is not None:
with open(cafile):
pass
self._trust_bundle = cafile or cadata
def load_cert_chain(self, certfile, keyfile=None, password=None):
self._client_cert = certfile
self._client_key = keyfile
        self._client_key_passphrase = password
def wrap_socket(
self,
sock,
server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None,
):
# So, what do we do here? Firstly, we assert some properties. This is a
# stripped down shim, so there is some functionality we don't support.
# See PEP 543 for the real deal.
assert not server_side
assert do_handshake_on_connect
assert suppress_ragged_eofs
# Ok, we're good to go. Now we want to create the wrapped socket object
# and store it in the appropriate place.
wrapped_socket = WrappedSocket(sock)
# Now we can handshake
wrapped_socket.handshake(
server_hostname,
self._verify,
self._trust_bundle,
self._min_version,
self._max_version,
self._client_cert,
self._client_key,
self._client_key_passphrase,
)
return wrapped_socket
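As with the pyopenssl shim, the supported entry point is the injector; a minimal hedged sketch (placeholder URL, macOS only):

    import urllib3
    import urllib3.contrib.securetransport

    urllib3.contrib.securetransport.inject_into_urllib3()
    http = urllib3.PoolManager()
    print(http.request("GET", "https://example.com/").status)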

210 urllib3/contrib/socks.py Normal file

@@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
"""
This module contains provisional support for SOCKS proxies from within
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
SOCKS5. To enable its functionality, either install PySocks or install this
module with the ``socks`` extra.
The SOCKS implementation supports the full range of urllib3 features. It also
supports the following SOCKS features:
- SOCKS4A (``proxy_url='socks4a://...'``)
- SOCKS4 (``proxy_url='socks4://...'``)
- SOCKS5 with remote DNS (``proxy_url='socks5h://...'``)
- SOCKS5 with local DNS (``proxy_url='socks5://...'``)
- Usernames and passwords for the SOCKS proxy
.. note::
It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
your ``proxy_url`` to ensure that DNS resolution is done from the remote
server instead of client-side when connecting to a domain name.
SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
supports IPv4, IPv6, and domain names.
When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
will be sent as the ``userid`` section of the SOCKS request::
proxy_url="socks4a://<userid>@proxy-host"
When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
of the ``proxy_url`` will be sent as the username/password to authenticate
with the proxy::
proxy_url="socks5h://<username>:<password>@proxy-host"
"""
from __future__ import absolute_import
try:
import socks
except ImportError:
import warnings
from ..exceptions import DependencyWarning
warnings.warn(
(
"SOCKS support in urllib3 requires the installation of optional "
"dependencies: specifically, PySocks. For more information, see "
"https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies"
),
DependencyWarning,
)
raise
from socket import error as SocketError, timeout as SocketTimeout
from ..connection import HTTPConnection, HTTPSConnection
from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from ..exceptions import ConnectTimeoutError, NewConnectionError
from ..poolmanager import PoolManager
from ..util.url import parse_url
try:
import ssl
except ImportError:
ssl = None
class SOCKSConnection(HTTPConnection):
"""
A plain-text HTTP connection that connects via a SOCKS proxy.
"""
def __init__(self, *args, **kwargs):
self._socks_options = kwargs.pop("_socks_options")
super(SOCKSConnection, self).__init__(*args, **kwargs)
def _new_conn(self):
"""
Establish a new connection via the SOCKS proxy.
"""
extra_kw = {}
if self.source_address:
extra_kw["source_address"] = self.source_address
if self.socket_options:
extra_kw["socket_options"] = self.socket_options
try:
conn = socks.create_connection(
(self.host, self.port),
proxy_type=self._socks_options["socks_version"],
proxy_addr=self._socks_options["proxy_host"],
proxy_port=self._socks_options["proxy_port"],
proxy_username=self._socks_options["username"],
proxy_password=self._socks_options["password"],
proxy_rdns=self._socks_options["rdns"],
timeout=self.timeout,
**extra_kw
)
except SocketTimeout:
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
except socks.ProxyError as e:
# This is fragile as hell, but it seems to be the only way to raise
# useful errors here.
if e.socket_err:
error = e.socket_err
if isinstance(error, SocketTimeout):
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
else:
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % error
)
else:
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e
)
except SocketError as e: # Defensive: PySocks should catch all these.
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e
)
return conn
# We don't need to duplicate the Verified/Unverified distinction from
# urllib3/connection.py here because the HTTPSConnection will already have been
# correctly set to either the Verified or Unverified form by that module. This
# means the SOCKSHTTPSConnection will automatically be the correct type.
class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):
pass
class SOCKSHTTPConnectionPool(HTTPConnectionPool):
ConnectionCls = SOCKSConnection
class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):
ConnectionCls = SOCKSHTTPSConnection
class SOCKSProxyManager(PoolManager):
"""
A version of the urllib3 ProxyManager that routes connections via the
defined SOCKS proxy.
"""
pool_classes_by_scheme = {
"http": SOCKSHTTPConnectionPool,
"https": SOCKSHTTPSConnectionPool,
}
def __init__(
self,
proxy_url,
username=None,
password=None,
num_pools=10,
headers=None,
**connection_pool_kw
):
parsed = parse_url(proxy_url)
if username is None and password is None and parsed.auth is not None:
split = parsed.auth.split(":")
if len(split) == 2:
username, password = split
if parsed.scheme == "socks5":
socks_version = socks.PROXY_TYPE_SOCKS5
rdns = False
elif parsed.scheme == "socks5h":
socks_version = socks.PROXY_TYPE_SOCKS5
rdns = True
elif parsed.scheme == "socks4":
socks_version = socks.PROXY_TYPE_SOCKS4
rdns = False
elif parsed.scheme == "socks4a":
socks_version = socks.PROXY_TYPE_SOCKS4
rdns = True
else:
raise ValueError("Unable to determine SOCKS version from %s" % proxy_url)
self.proxy_url = proxy_url
socks_options = {
"socks_version": socks_version,
"proxy_host": parsed.host,
"proxy_port": parsed.port,
"username": username,
"password": password,
"rdns": rdns,
}
connection_pool_kw["_socks_options"] = socks_options
super(SOCKSProxyManager, self).__init__(
num_pools, headers, **connection_pool_kw
)
self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme
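A short, hedged example of the manager defined above (hypothetical proxy address and credentials; requires PySocks):

    from urllib3.contrib.socks import SOCKSProxyManager

    proxy = SOCKSProxyManager("socks5h://user:password@127.0.0.1:1080")
    resp = proxy.request("GET", "https://example.com/")
    print(resp.status)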

272 urllib3/exceptions.py Normal file

@@ -0,0 +1,272 @@
from __future__ import absolute_import
from .packages.six.moves.http_client import IncompleteRead as httplib_IncompleteRead
# Base Exceptions
class HTTPError(Exception):
"Base exception used by this module."
pass
class HTTPWarning(Warning):
"Base warning used by this module."
pass
class PoolError(HTTPError):
"Base exception for errors caused within a pool."
def __init__(self, pool, message):
self.pool = pool
HTTPError.__init__(self, "%s: %s" % (pool, message))
def __reduce__(self):
# For pickling purposes.
return self.__class__, (None, None)
class RequestError(PoolError):
"Base exception for PoolErrors that have associated URLs."
def __init__(self, pool, url, message):
self.url = url
PoolError.__init__(self, pool, message)
def __reduce__(self):
# For pickling purposes.
return self.__class__, (None, self.url, None)
class SSLError(HTTPError):
"Raised when SSL certificate fails in an HTTPS connection."
pass
class ProxyError(HTTPError):
"Raised when the connection to a proxy fails."
def __init__(self, message, error, *args):
super(ProxyError, self).__init__(message, error, *args)
self.original_error = error
class DecodeError(HTTPError):
"Raised when automatic decoding based on Content-Type fails."
pass
class ProtocolError(HTTPError):
"Raised when something unexpected happens mid-request/response."
pass
#: Renamed to ProtocolError but aliased for backwards compatibility.
ConnectionError = ProtocolError
# Leaf Exceptions
class MaxRetryError(RequestError):
"""Raised when the maximum number of retries is exceeded.
:param pool: The connection pool
:type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool`
:param string url: The requested Url
:param exceptions.Exception reason: The underlying error
"""
def __init__(self, pool, url, reason=None):
self.reason = reason
message = "Max retries exceeded with url: %s (Caused by %r)" % (url, reason)
RequestError.__init__(self, pool, url, message)
class HostChangedError(RequestError):
"Raised when an existing pool gets a request for a foreign host."
def __init__(self, pool, url, retries=3):
message = "Tried to open a foreign host with url: %s" % url
RequestError.__init__(self, pool, url, message)
self.retries = retries
class TimeoutStateError(HTTPError):
""" Raised when passing an invalid state to a timeout """
pass
class TimeoutError(HTTPError):
""" Raised when a socket timeout error occurs.
Catching this error will catch both :exc:`ReadTimeoutErrors
<ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
"""
pass
class ReadTimeoutError(TimeoutError, RequestError):
"Raised when a socket timeout occurs while receiving data from a server"
pass
# This timeout error does not have a URL attached and needs to inherit from the
# base HTTPError
class ConnectTimeoutError(TimeoutError):
"Raised when a socket timeout occurs while connecting to a server"
pass
class NewConnectionError(ConnectTimeoutError, PoolError):
"Raised when we fail to establish a new connection. Usually ECONNREFUSED."
pass
class EmptyPoolError(PoolError):
"Raised when a pool runs out of connections and no more are allowed."
pass
class ClosedPoolError(PoolError):
"Raised when a request enters a pool after the pool has been closed."
pass
class LocationValueError(ValueError, HTTPError):
"Raised when there is something wrong with a given URL input."
pass
class LocationParseError(LocationValueError):
"Raised when get_host or similar fails to parse the URL input."
def __init__(self, location):
message = "Failed to parse: %s" % location
HTTPError.__init__(self, message)
self.location = location
class ResponseError(HTTPError):
"Used as a container for an error reason supplied in a MaxRetryError."
GENERIC_ERROR = "too many error responses"
SPECIFIC_ERROR = "too many {status_code} error responses"
class SecurityWarning(HTTPWarning):
"Warned when performing security reducing actions"
pass
class SubjectAltNameWarning(SecurityWarning):
"Warned when connecting to a host with a certificate missing a SAN."
pass
class InsecureRequestWarning(SecurityWarning):
"Warned when making an unverified HTTPS request."
pass
class SystemTimeWarning(SecurityWarning):
"Warned when system time is suspected to be wrong"
pass
class InsecurePlatformWarning(SecurityWarning):
"Warned when certain SSL configuration is not available on a platform."
pass
class SNIMissingWarning(HTTPWarning):
"Warned when making a HTTPS request without SNI available."
pass
class DependencyWarning(HTTPWarning):
"""
Warned when an attempt is made to import a module with missing optional
dependencies.
"""
pass
class InvalidProxyConfigurationWarning(HTTPWarning):
"""
Warned when using an HTTPS proxy and an HTTPS URL. Currently
urllib3 doesn't support HTTPS proxies and the proxy will be
contacted via HTTP instead. This warning can be fixed by
changing your HTTPS proxy URL into an HTTP proxy URL.
If you encounter this warning read this:
https://github.com/urllib3/urllib3/issues/1850
"""
pass
class ResponseNotChunked(ProtocolError, ValueError):
"Response needs to be chunked in order to read it as chunks."
pass
class BodyNotHttplibCompatible(HTTPError):
"""
Body should be httplib.HTTPResponse like (have an fp attribute which
returns raw chunks) for read_chunked().
"""
pass
class IncompleteRead(HTTPError, httplib_IncompleteRead):
"""
Response length doesn't match expected Content-Length
Subclass of http_client.IncompleteRead to allow int value
for `partial` to avoid creating large objects on streamed
reads.
"""
def __init__(self, partial, expected):
super(IncompleteRead, self).__init__(partial, expected)
def __repr__(self):
return "IncompleteRead(%i bytes read, %i more expected)" % (
self.partial,
self.expected,
)
class InvalidHeader(HTTPError):
"The header provided was somehow invalid."
pass
class ProxySchemeUnknown(AssertionError, ValueError):
"ProxyManager does not support the supplied scheme"
# TODO(t-8ch): Stop inheriting from AssertionError in v2.0.
def __init__(self, scheme):
message = "Not supported proxy scheme %s" % scheme
super(ProxySchemeUnknown, self).__init__(message)
class HeaderParsingError(HTTPError):
"Raised by assert_header_parsing, but we convert it to a log.warning statement."
def __init__(self, defects, unparsed_data):
message = "%s, unparsed data: %r" % (defects or "Unknown", unparsed_data)
super(HeaderParsingError, self).__init__(message)
class UnrewindableBodyError(HTTPError):
"urllib3 encountered an error when trying to rewind a body"
pass
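To make the hierarchy concrete, a hedged sketch of catching the most common leaf exception (the port below is a placeholder chosen to refuse connections):

    import urllib3

    http = urllib3.PoolManager()
    try:
        http.request("GET", "http://127.0.0.1:9/", retries=urllib3.Retry(total=1))
    except urllib3.exceptions.MaxRetryError as exc:
        print("gave up:", exc.reason)  # typically wraps a NewConnectionError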

273 urllib3/fields.py Normal file

@@ -0,0 +1,273 @@
from __future__ import absolute_import
import email.utils
import mimetypes
import re
from .packages import six
def guess_content_type(filename, default="application/octet-stream"):
"""
Guess the "Content-Type" of a file.
:param filename:
The filename to guess the "Content-Type" of using :mod:`mimetypes`.
:param default:
If no "Content-Type" can be guessed, default to `default`.
"""
if filename:
return mimetypes.guess_type(filename)[0] or default
return default
def format_header_param_rfc2231(name, value):
"""
Helper function to format and quote a single header parameter using the
strategy defined in RFC 2231.
Particularly useful for header parameters which might contain
non-ASCII values, like file names. This follows RFC 2388 Section 4.4.
:param name:
The name of the parameter, a string expected to be ASCII only.
:param value:
The value of the parameter, provided as ``bytes`` or ``str``.
:ret:
An RFC-2231-formatted unicode string.
"""
if isinstance(value, six.binary_type):
value = value.decode("utf-8")
if not any(ch in value for ch in '"\\\r\n'):
result = u'%s="%s"' % (name, value)
try:
result.encode("ascii")
except (UnicodeEncodeError, UnicodeDecodeError):
pass
else:
return result
if six.PY2: # Python 2:
value = value.encode("utf-8")
# encode_rfc2231 accepts an encoded string and returns an ascii-encoded
# string in Python 2 but accepts and returns unicode strings in Python 3
value = email.utils.encode_rfc2231(value, "utf-8")
value = "%s*=%s" % (name, value)
if six.PY2: # Python 2:
value = value.decode("utf-8")
return value
_HTML5_REPLACEMENTS = {
u"\u0022": u"%22",
# Replace "\" with "\\".
u"\u005C": u"\u005C\u005C",
u"\u005C": u"\u005C\u005C",
}
# All control characters from 0x00 to 0x1F *except* 0x1B.
_HTML5_REPLACEMENTS.update(
{
six.unichr(cc): u"%{:02X}".format(cc)
for cc in range(0x00, 0x1F + 1)
if cc not in (0x1B,)
}
)
def _replace_multiple(value, needles_and_replacements):
def replacer(match):
return needles_and_replacements[match.group(0)]
pattern = re.compile(
r"|".join([re.escape(needle) for needle in needles_and_replacements.keys()])
)
result = pattern.sub(replacer, value)
return result
def format_header_param_html5(name, value):
"""
Helper function to format and quote a single header parameter using the
HTML5 strategy.
Particularly useful for header parameters which might contain
non-ASCII values, like file names. This follows the `HTML5 Working Draft
Section 4.10.22.7`_ and matches the behavior of curl and modern browsers.
.. _HTML5 Working Draft Section 4.10.22.7:
https://w3c.github.io/html/sec-forms.html#multipart-form-data
:param name:
The name of the parameter, a string expected to be ASCII only.
:param value:
The value of the parameter, provided as ``bytes`` or ``str``.
:ret:
A unicode string, stripped of troublesome characters.
"""
if isinstance(value, six.binary_type):
value = value.decode("utf-8")
value = _replace_multiple(value, _HTML5_REPLACEMENTS)
return u'%s="%s"' % (name, value)
# For backwards-compatibility.
format_header_param = format_header_param_html5
class RequestField(object):
"""
A data container for request body parameters.
:param name:
The name of this request field. Must be unicode.
:param data:
The data/value body.
:param filename:
An optional filename of the request field. Must be unicode.
:param headers:
An optional dict-like object of headers to initially use for the field.
:param header_formatter:
An optional callable that is used to encode and format the headers. By
default, this is :func:`format_header_param_html5`.
"""
def __init__(
self,
name,
data,
filename=None,
headers=None,
header_formatter=format_header_param_html5,
):
self._name = name
self._filename = filename
self.data = data
self.headers = {}
if headers:
self.headers = dict(headers)
self.header_formatter = header_formatter
@classmethod
def from_tuples(cls, fieldname, value, header_formatter=format_header_param_html5):
"""
A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.
Supports constructing :class:`~urllib3.fields.RequestField` from
parameter of key/value strings AND key/filetuple. A filetuple is a
(filename, data, MIME type) tuple where the MIME type is optional.
For example::
'foo': 'bar',
'fakefile': ('foofile.txt', 'contents of foofile'),
'realfile': ('barfile.txt', open('realfile').read()),
'typedfile': ('bazfile.bin', open('bazfile').read(), 'image/jpeg'),
'nonamefile': 'contents of nonamefile field',
Field names and filenames must be unicode.
"""
if isinstance(value, tuple):
if len(value) == 3:
filename, data, content_type = value
else:
filename, data = value
content_type = guess_content_type(filename)
else:
filename = None
content_type = None
data = value
request_param = cls(
fieldname, data, filename=filename, header_formatter=header_formatter
)
request_param.make_multipart(content_type=content_type)
return request_param
def _render_part(self, name, value):
"""
Overridable helper function to format a single header parameter. By
default, this calls ``self.header_formatter``.
:param name:
The name of the parameter, a string expected to be ASCII only.
:param value:
The value of the parameter, provided as a unicode string.
"""
return self.header_formatter(name, value)
def _render_parts(self, header_parts):
"""
Helper function to format and quote a single header.
Useful for single headers that are composed of multiple items. E.g.,
'Content-Disposition' fields.
:param header_parts:
A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format
as `k1="v1"; k2="v2"; ...`.
"""
parts = []
iterable = header_parts
if isinstance(header_parts, dict):
iterable = header_parts.items()
for name, value in iterable:
if value is not None:
parts.append(self._render_part(name, value))
return u"; ".join(parts)
def render_headers(self):
"""
Renders the headers for this request field.
"""
lines = []
sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"]
for sort_key in sort_keys:
if self.headers.get(sort_key, False):
lines.append(u"%s: %s" % (sort_key, self.headers[sort_key]))
for header_name, header_value in self.headers.items():
if header_name not in sort_keys:
if header_value:
lines.append(u"%s: %s" % (header_name, header_value))
lines.append(u"\r\n")
return u"\r\n".join(lines)
def make_multipart(
self, content_disposition=None, content_type=None, content_location=None
):
"""
Makes this request field into a multipart request field.
This method overrides "Content-Disposition", "Content-Type" and
"Content-Location" headers to the request parameter.
:param content_type:
The 'Content-Type' of the request body.
:param content_location:
The 'Content-Location' of the request body.
"""
self.headers["Content-Disposition"] = content_disposition or u"form-data"
self.headers["Content-Disposition"] += u"; ".join(
[
u"",
self._render_parts(
((u"name", self._name), (u"filename", self._filename))
),
]
)
self.headers["Content-Type"] = content_type
self.headers["Content-Location"] = content_location

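A short sketch of the RequestField workflow defined above; the field name and file contents are invented for illustration:

from urllib3.fields import RequestField

rf = RequestField.from_tuples("upload", ("report.txt", "hello world", "text/plain"))
print(rf.render_headers())
# Content-Disposition: form-data; name="upload"; filename="report.txt"
# Content-Type: text/plain
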
98
urllib3/filepost.py Normal file
View File

@@ -0,0 +1,98 @@
from __future__ import absolute_import
import binascii
import codecs
import os
from io import BytesIO
from .packages import six
from .packages.six import b
from .fields import RequestField
writer = codecs.lookup("utf-8")[3]
def choose_boundary():
"""
Our embarrassingly-simple replacement for mimetools.choose_boundary.
"""
boundary = binascii.hexlify(os.urandom(16))
if not six.PY2:
boundary = boundary.decode("ascii")
return boundary
def iter_field_objects(fields):
"""
Iterate over fields.
Supports list of (k, v) tuples and dicts, and lists of
:class:`~urllib3.fields.RequestField`.
"""
if isinstance(fields, dict):
i = six.iteritems(fields)
else:
i = iter(fields)
for field in i:
if isinstance(field, RequestField):
yield field
else:
yield RequestField.from_tuples(*field)
def iter_fields(fields):
"""
.. deprecated:: 1.6
Iterate over fields.
The addition of :class:`~urllib3.fields.RequestField` makes this function
obsolete. Instead, use :func:`iter_field_objects`, which returns
:class:`~urllib3.fields.RequestField` objects.
Supports list of (k, v) tuples and dicts.
"""
if isinstance(fields, dict):
return ((k, v) for k, v in six.iteritems(fields))
return ((k, v) for k, v in fields)
def encode_multipart_formdata(fields, boundary=None):
"""
Encode a dictionary of ``fields`` using the multipart/form-data MIME format.
:param fields:
Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`).
:param boundary:
If not specified, then a random boundary will be generated using
:func:`urllib3.filepost.choose_boundary`.
"""
body = BytesIO()
if boundary is None:
boundary = choose_boundary()
for field in iter_field_objects(fields):
body.write(b("--%s\r\n" % (boundary)))
writer(body).write(field.render_headers())
data = field.data
if isinstance(data, int):
data = str(data) # Backwards compatibility
if isinstance(data, six.text_type):
writer(body).write(data)
else:
body.write(data)
body.write(b"\r\n")
body.write(b("--%s--\r\n" % (boundary)))
content_type = str("multipart/form-data; boundary=%s" % boundary)
return body.getvalue(), content_type

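A sketch of encode_multipart_formdata from the file above; the fixed boundary is an assumption so the output is reproducible (normally choose_boundary() supplies a random one):

from urllib3.filepost import encode_multipart_formdata

body, content_type = encode_multipart_formdata(
    {"field": "value", "file": ("hello.txt", "hi there", "text/plain")},
    boundary="xXxBOUNDARYxXx",  # assumption: fixed instead of random
)
print(content_type)         # multipart/form-data; boundary=xXxBOUNDARYxXx
print(body.decode("utf-8"))
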
5
urllib3/packages/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from __future__ import absolute_import
from . import ssl_match_hostname
__all__ = ("ssl_match_hostname",)

0
urllib3/packages/backports/__init__.py Normal file
View File

52
urllib3/packages/backports/makefile.py Normal file
View File

@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
"""
backports.makefile
~~~~~~~~~~~~~~~~~~
Backports the Python 3 ``socket.makefile`` method for use with anything that
wants to create a "fake" socket object.
"""
import io
from socket import SocketIO
def backport_makefile(
self, mode="r", buffering=None, encoding=None, errors=None, newline=None
):
"""
Backport of ``socket.makefile`` from Python 3.5.
"""
if not set(mode) <= {"r", "w", "b"}:
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
writing = "w" in mode
reading = "r" in mode or not writing
assert reading or writing
binary = "b" in mode
rawmode = ""
if reading:
rawmode += "r"
if writing:
rawmode += "w"
raw = SocketIO(self, rawmode)
self._makefile_refs += 1
if buffering is None:
buffering = -1
if buffering < 0:
buffering = io.DEFAULT_BUFFER_SIZE
if buffering == 0:
if not binary:
raise ValueError("unbuffered streams must be binary")
return raw
if reading and writing:
buffer = io.BufferedRWPair(raw, raw, buffering)
elif reading:
buffer = io.BufferedReader(raw, buffering)
else:
assert writing
buffer = io.BufferedWriter(raw, buffering)
if binary:
return buffer
text = io.TextIOWrapper(buffer, encoding, errors, newline)
text.mode = mode
return text

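A hedged sketch of wiring backport_makefile onto a socket class; CountingSocket is a toy subclass invented here solely to provide the _makefile_refs counter the backport expects:

import socket
from urllib3.packages.backports.makefile import backport_makefile

class CountingSocket(socket.socket):
    _makefile_refs = 0            # incremented by backport_makefile
    makefile = backport_makefile  # replace the built-in makefile

a, b = socket.socketpair()
conn = CountingSocket(fileno=a.detach())
b.sendall(b"ping\n")
print(conn.makefile("rb").readline())  # b'ping\n'
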
1021
urllib3/packages/six.py Normal file

File diff suppressed because it is too large

19
urllib3/packages/ssl_match_hostname/__init__.py Normal file
View File

@@ -0,0 +1,19 @@
import sys
try:
# Our match_hostname function is the same as 3.5's, so we only want to
# import the match_hostname function if it's at least that good.
if sys.version_info < (3, 5):
raise ImportError("Fallback to vendored code")
from ssl import CertificateError, match_hostname
except ImportError:
try:
# Backport of the function from a pypi module
from backports.ssl_match_hostname import CertificateError, match_hostname
except ImportError:
# Our vendored copy
from ._implementation import CertificateError, match_hostname
# Not needed, but documenting what we provide.
__all__ = ("CertificateError", "match_hostname")

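Whichever import branch above succeeds, the package exposes the same two names; a quick check of which implementation was selected (output varies by Python version and installed packages):

from urllib3.packages.ssl_match_hostname import match_hostname

print(match_hostname.__module__)  # 'ssl' on Python 3.5+, else a backport/vendored module
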
160
urllib3/packages/ssl_match_hostname/_implementation.py Normal file
View File

@@ -0,0 +1,160 @@
"""The match_hostname() function from Python 3.3.3, essential when using SSL."""
# Note: This file is under the PSF license as the code comes from the python
# stdlib. http://docs.python.org/3/license.html
import re
import sys
# ipaddress has been backported to 2.6+ in pypi. If it is installed on the
# system, use it to handle IPAddress ServerAltnames (this was added in
# python-3.5) otherwise only do DNS matching. This allows
# backports.ssl_match_hostname to continue to be used in Python 2.7.
try:
import ipaddress
except ImportError:
ipaddress = None
__version__ = "3.5.0.1"
class CertificateError(ValueError):
pass
def _dnsname_match(dn, hostname, max_wildcards=1):
"""Matching according to RFC 6125, section 6.4.3
http://tools.ietf.org/html/rfc6125#section-6.4.3
"""
pats = []
if not dn:
return False
# Ported from python3-syntax:
# leftmost, *remainder = dn.split(r'.')
parts = dn.split(r".")
leftmost = parts[0]
remainder = parts[1:]
wildcards = leftmost.count("*")
if wildcards > max_wildcards:
# Issue #17980: avoid denials of service by refusing more
# than one wildcard per fragment. A survey of established
# policy among SSL implementations showed it to be a
# reasonable choice.
raise CertificateError(
"too many wildcards in certificate DNS name: " + repr(dn)
)
# speed up common case w/o wildcards
if not wildcards:
return dn.lower() == hostname.lower()
# RFC 6125, section 6.4.3, subitem 1.
# The client SHOULD NOT attempt to match a presented identifier in which
# the wildcard character comprises a label other than the left-most label.
if leftmost == "*":
# When '*' is a fragment by itself, it matches a non-empty dotless
# fragment.
pats.append("[^.]+")
elif leftmost.startswith("xn--") or hostname.startswith("xn--"):
# RFC 6125, section 6.4.3, subitem 3.
# The client SHOULD NOT attempt to match a presented identifier
# where the wildcard character is embedded within an A-label or
# U-label of an internationalized domain name.
pats.append(re.escape(leftmost))
else:
# Otherwise, '*' matches any dotless string, e.g. www*
pats.append(re.escape(leftmost).replace(r"\*", "[^.]*"))
# add the remaining fragments, ignore any wildcards
for frag in remainder:
pats.append(re.escape(frag))
pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE)
return pat.match(hostname)
def _to_unicode(obj):
if isinstance(obj, str) and sys.version_info < (3,):
obj = unicode(obj, encoding="ascii", errors="strict")
return obj
def _ipaddress_match(ipname, host_ip):
"""Exact matching of IP addresses.
RFC 6125 explicitly doesn't define an algorithm for this
(section 1.7.2 - "Out of Scope").
"""
# OpenSSL may add a trailing newline to a subjectAltName's IP address
# Divergence from upstream: ipaddress can't handle byte str
ip = ipaddress.ip_address(_to_unicode(ipname).rstrip())
return ip == host_ip
def match_hostname(cert, hostname):
"""Verify that *cert* (in decoded format as returned by
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
rules are followed, but IP addresses are not accepted for *hostname*.
CertificateError is raised on failure. On success, the function
returns nothing.
"""
if not cert:
raise ValueError(
"empty or no certificate, match_hostname needs a "
"SSL socket or SSL context with either "
"CERT_OPTIONAL or CERT_REQUIRED"
)
try:
# Divergence from upstream: ipaddress can't handle byte str
host_ip = ipaddress.ip_address(_to_unicode(hostname))
except ValueError:
# Not an IP address (common case)
host_ip = None
except UnicodeError:
# Divergence from upstream: Have to deal with ipaddress not taking
# byte strings. addresses should be all ascii, so we consider it not
# an ipaddress in this case
host_ip = None
except AttributeError:
# Divergence from upstream: Make ipaddress library optional
if ipaddress is None:
host_ip = None
else:
raise
dnsnames = []
san = cert.get("subjectAltName", ())
for key, value in san:
if key == "DNS":
if host_ip is None and _dnsname_match(value, hostname):
return
dnsnames.append(value)
elif key == "IP Address":
if host_ip is not None and _ipaddress_match(value, host_ip):
return
dnsnames.append(value)
if not dnsnames:
# The subject is only checked when there is no dNSName entry
# in subjectAltName
for sub in cert.get("subject", ()):
for key, value in sub:
# XXX according to RFC 2818, the most specific Common Name
# must be used.
if key == "commonName":
if _dnsname_match(value, hostname):
return
dnsnames.append(value)
if len(dnsnames) > 1:
raise CertificateError(
"hostname %r "
"doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
)
elif len(dnsnames) == 1:
raise CertificateError("hostname %r doesn't match %r" % (hostname, dnsnames[0]))
else:
raise CertificateError(
"no appropriate commonName or subjectAltName fields were found"
)

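A minimal sketch of match_hostname with a hand-built certificate dict; the cert below is fabricated for illustration, not real getpeercert() output:

from urllib3.packages.ssl_match_hostname import CertificateError, match_hostname

cert = {"subjectAltName": (("DNS", "*.example.com"),)}
match_hostname(cert, "www.example.com")      # returns None on success
try:
    match_hostname(cert, "example.org")
except CertificateError as exc:
    print(exc)  # hostname 'example.org' doesn't match '*.example.com'
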
492
urllib3/poolmanager.py Normal file
View File

@@ -0,0 +1,492 @@
from __future__ import absolute_import
import collections
import functools
import logging
import warnings
from ._collections import RecentlyUsedContainer
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from .connectionpool import port_by_scheme
from .exceptions import (
LocationValueError,
MaxRetryError,
ProxySchemeUnknown,
InvalidProxyConfigurationWarning,
)
from .packages import six
from .packages.six.moves.urllib.parse import urljoin
from .request import RequestMethods
from .util.url import parse_url
from .util.retry import Retry
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
log = logging.getLogger(__name__)
SSL_KEYWORDS = (
"key_file",
"cert_file",
"cert_reqs",
"ca_certs",
"ssl_version",
"ca_cert_dir",
"ssl_context",
"key_password",
)
# All known keyword arguments that could be provided to the pool manager, its
# pools, or the underlying connections. This is used to construct a pool key.
_key_fields = (
"key_scheme", # str
"key_host", # str
"key_port", # int
"key_timeout", # int or float or Timeout
"key_retries", # int or Retry
"key_strict", # bool
"key_block", # bool
"key_source_address", # str
"key_key_file", # str
"key_key_password", # str
"key_cert_file", # str
"key_cert_reqs", # str
"key_ca_certs", # str
"key_ssl_version", # str
"key_ca_cert_dir", # str
"key_ssl_context", # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
"key_maxsize", # int
"key_headers", # dict
"key__proxy", # parsed proxy url
"key__proxy_headers", # dict
"key_socket_options", # list of (level (int), optname (int), value (int or str)) tuples
"key__socks_options", # dict
"key_assert_hostname", # bool or string
"key_assert_fingerprint", # str
"key_server_hostname", # str
)
#: The namedtuple class used to construct keys for the connection pool.
#: All custom key schemes should include the fields in this key at a minimum.
PoolKey = collections.namedtuple("PoolKey", _key_fields)
def _default_key_normalizer(key_class, request_context):
"""
Create a pool key out of a request context dictionary.
According to RFC 3986, both the scheme and host are case-insensitive.
Therefore, this function normalizes both before constructing the pool
key for an HTTPS request. If you wish to change this behaviour, provide
alternate callables to ``key_fn_by_scheme``.
:param key_class:
The class to use when constructing the key. This should be a namedtuple
with the ``scheme`` and ``host`` keys at a minimum.
:type key_class: namedtuple
:param request_context:
A dictionary-like object that contain the context for a request.
:type request_context: dict
:return: A namedtuple that can be used as a connection pool key.
:rtype: PoolKey
"""
# Since we mutate the dictionary, make a copy first
context = request_context.copy()
context["scheme"] = context["scheme"].lower()
context["host"] = context["host"].lower()
# These are both dictionaries and need to be transformed into frozensets
for key in ("headers", "_proxy_headers", "_socks_options"):
if key in context and context[key] is not None:
context[key] = frozenset(context[key].items())
# The socket_options key may be a list and needs to be transformed into a
# tuple.
socket_opts = context.get("socket_options")
if socket_opts is not None:
context["socket_options"] = tuple(socket_opts)
# Map the kwargs to the names in the namedtuple - this is necessary since
# namedtuples can't have fields starting with '_'.
for key in list(context.keys()):
context["key_" + key] = context.pop(key)
# Default to ``None`` for keys missing from the context
for field in key_class._fields:
if field not in context:
context[field] = None
return key_class(**context)
#: A dictionary that maps a scheme to a callable that creates a pool key.
#: This can be used to alter the way pool keys are constructed, if desired.
#: Each PoolManager makes a copy of this dictionary so they can be configured
#: globally here, or individually on the instance.
key_fn_by_scheme = {
"http": functools.partial(_default_key_normalizer, PoolKey),
"https": functools.partial(_default_key_normalizer, PoolKey),
}
pool_classes_by_scheme = {"http": HTTPConnectionPool, "https": HTTPSConnectionPool}
class PoolManager(RequestMethods):
"""
Allows for arbitrary requests while transparently keeping track of
necessary connection pools for you.
:param num_pools:
Number of connection pools to cache before discarding the least
recently used pool.
:param headers:
Headers to include with all requests, unless other headers are given
explicitly.
:param \\**connection_pool_kw:
Additional parameters are used to create fresh
:class:`urllib3.connectionpool.ConnectionPool` instances.
Example::
>>> manager = PoolManager(num_pools=2)
>>> r = manager.request('GET', 'http://google.com/')
>>> r = manager.request('GET', 'http://google.com/mail')
>>> r = manager.request('GET', 'http://yahoo.com/')
>>> len(manager.pools)
2
"""
proxy = None
def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
RequestMethods.__init__(self, headers)
self.connection_pool_kw = connection_pool_kw
self.pools = RecentlyUsedContainer(num_pools, dispose_func=lambda p: p.close())
# Locally set the pool classes and keys so other PoolManagers can
# override them.
self.pool_classes_by_scheme = pool_classes_by_scheme
self.key_fn_by_scheme = key_fn_by_scheme.copy()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.clear()
# Return False to re-raise any potential exceptions
return False
def _new_pool(self, scheme, host, port, request_context=None):
"""
Create a new :class:`ConnectionPool` based on host, port, scheme, and
any additional pool keyword arguments.
If ``request_context`` is provided, it is provided as keyword arguments
to the pool class used. This method is used to actually create the
connection pools handed out by :meth:`connection_from_url` and
companion methods. It is intended to be overridden for customization.
"""
pool_cls = self.pool_classes_by_scheme[scheme]
if request_context is None:
request_context = self.connection_pool_kw.copy()
# Although the context has everything necessary to create the pool,
# this function has historically only used the scheme, host, and port
# in the positional args. When an API change is acceptable these can
# be removed.
for key in ("scheme", "host", "port"):
request_context.pop(key, None)
if scheme == "http":
for kw in SSL_KEYWORDS:
request_context.pop(kw, None)
return pool_cls(host, port, **request_context)
def clear(self):
"""
Empty our store of pools and direct them all to close.
This will not affect in-flight connections, but they will not be
re-used after completion.
"""
self.pools.clear()
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
"""
Get a :class:`ConnectionPool` based on the host, port, and scheme.
If ``port`` isn't given, it will be derived from the ``scheme`` using
``urllib3.connectionpool.port_by_scheme``. If ``pool_kwargs`` is
provided, it is merged with the instance's ``connection_pool_kw``
variable and used to create the new connection pool, if one is
needed.
"""
if not host:
raise LocationValueError("No host specified.")
request_context = self._merge_pool_kwargs(pool_kwargs)
request_context["scheme"] = scheme or "http"
if not port:
port = port_by_scheme.get(request_context["scheme"].lower(), 80)
request_context["port"] = port
request_context["host"] = host
return self.connection_from_context(request_context)
def connection_from_context(self, request_context):
"""
Get a :class:`ConnectionPool` based on the request context.
``request_context`` must at least contain the ``scheme`` key and its
value must be a key in ``key_fn_by_scheme`` instance variable.
"""
scheme = request_context["scheme"].lower()
pool_key_constructor = self.key_fn_by_scheme[scheme]
pool_key = pool_key_constructor(request_context)
return self.connection_from_pool_key(pool_key, request_context=request_context)
def connection_from_pool_key(self, pool_key, request_context=None):
"""
Get a :class:`ConnectionPool` based on the provided pool key.
``pool_key`` should be a namedtuple that only contains immutable
objects. At a minimum it must have the ``scheme``, ``host``, and
``port`` fields.
"""
with self.pools.lock:
# If the scheme, host, or port doesn't match existing open
# connections, open a new ConnectionPool.
pool = self.pools.get(pool_key)
if pool:
return pool
# Make a fresh ConnectionPool of the desired type
scheme = request_context["scheme"]
host = request_context["host"]
port = request_context["port"]
pool = self._new_pool(scheme, host, port, request_context=request_context)
self.pools[pool_key] = pool
return pool
def connection_from_url(self, url, pool_kwargs=None):
"""
Similar to :func:`urllib3.connectionpool.connection_from_url`.
If ``pool_kwargs`` is not provided and a new pool needs to be
constructed, ``self.connection_pool_kw`` is used to initialize
the :class:`urllib3.connectionpool.ConnectionPool`. If ``pool_kwargs``
is provided, it is used instead. Note that if a new pool does not
need to be created for the request, the provided ``pool_kwargs`` are
not used.
"""
u = parse_url(url)
return self.connection_from_host(
u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs
)
def _merge_pool_kwargs(self, override):
"""
Merge a dictionary of override values for self.connection_pool_kw.
This does not modify self.connection_pool_kw and returns a new dict.
Any keys in the override dictionary with a value of ``None`` are
removed from the merged dictionary.
"""
base_pool_kwargs = self.connection_pool_kw.copy()
if override:
for key, value in override.items():
if value is None:
try:
del base_pool_kwargs[key]
except KeyError:
pass
else:
base_pool_kwargs[key] = value
return base_pool_kwargs
def urlopen(self, method, url, redirect=True, **kw):
"""
Same as :meth:`urllib3.connectionpool.HTTPConnectionPool.urlopen`
with custom cross-host redirect logic and only sends the request-uri
portion of the ``url``.
The given ``url`` parameter must be absolute, such that an appropriate
:class:`urllib3.connectionpool.ConnectionPool` can be chosen for it.
"""
u = parse_url(url)
conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme)
kw["assert_same_host"] = False
kw["redirect"] = False
if "headers" not in kw:
kw["headers"] = self.headers.copy()
if self.proxy is not None and u.scheme == "http":
response = conn.urlopen(method, url, **kw)
else:
response = conn.urlopen(method, u.request_uri, **kw)
redirect_location = redirect and response.get_redirect_location()
if not redirect_location:
return response
# Support relative URLs for redirecting.
redirect_location = urljoin(url, redirect_location)
# RFC 7231, Section 6.4.4
if response.status == 303:
method = "GET"
retries = kw.get("retries")
if not isinstance(retries, Retry):
retries = Retry.from_int(retries, redirect=redirect)
# Strip headers marked as unsafe to forward to the redirected location.
# Check remove_headers_on_redirect to avoid a potential network call within
# conn.is_same_host() which may use socket.gethostbyname() in the future.
if retries.remove_headers_on_redirect and not conn.is_same_host(
redirect_location
):
headers = list(six.iterkeys(kw["headers"]))
for header in headers:
if header.lower() in retries.remove_headers_on_redirect:
kw["headers"].pop(header, None)
try:
retries = retries.increment(method, url, response=response, _pool=conn)
except MaxRetryError:
if retries.raise_on_redirect:
response.drain_conn()
raise
return response
kw["retries"] = retries
kw["redirect"] = redirect
log.info("Redirecting %s -> %s", url, redirect_location)
response.drain_conn()
return self.urlopen(method, redirect_location, **kw)
class ProxyManager(PoolManager):
"""
Behaves just like :class:`PoolManager`, but sends all requests through
the defined proxy, using the CONNECT method for HTTPS URLs.
:param proxy_url:
The URL of the proxy to be used.
:param proxy_headers:
A dictionary containing headers that will be sent to the proxy. In case
of HTTP they are sent with each request, while in the
HTTPS/CONNECT case they are sent only once. Could be used for proxy
authentication.
Example:
>>> proxy = urllib3.ProxyManager('http://localhost:3128/')
>>> r1 = proxy.request('GET', 'http://google.com/')
>>> r2 = proxy.request('GET', 'http://httpbin.org/')
>>> len(proxy.pools)
1
>>> r3 = proxy.request('GET', 'https://httpbin.org/')
>>> r4 = proxy.request('GET', 'https://twitter.com/')
>>> len(proxy.pools)
3
"""
def __init__(
self,
proxy_url,
num_pools=10,
headers=None,
proxy_headers=None,
**connection_pool_kw
):
if isinstance(proxy_url, HTTPConnectionPool):
proxy_url = "%s://%s:%i" % (
proxy_url.scheme,
proxy_url.host,
proxy_url.port,
)
proxy = parse_url(proxy_url)
if not proxy.port:
port = port_by_scheme.get(proxy.scheme, 80)
proxy = proxy._replace(port=port)
if proxy.scheme not in ("http", "https"):
raise ProxySchemeUnknown(proxy.scheme)
self.proxy = proxy
self.proxy_headers = proxy_headers or {}
connection_pool_kw["_proxy"] = self.proxy
connection_pool_kw["_proxy_headers"] = self.proxy_headers
super(ProxyManager, self).__init__(num_pools, headers, **connection_pool_kw)
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
if scheme == "https":
return super(ProxyManager, self).connection_from_host(
host, port, scheme, pool_kwargs=pool_kwargs
)
return super(ProxyManager, self).connection_from_host(
self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs
)
def _set_proxy_headers(self, url, headers=None):
"""
Sets headers needed by proxies: specifically, the Accept and Host
headers. Only sets headers not provided by the user.
"""
headers_ = {"Accept": "*/*"}
netloc = parse_url(url).netloc
if netloc:
headers_["Host"] = netloc
if headers:
headers_.update(headers)
return headers_
def _validate_proxy_scheme_url_selection(self, url_scheme):
if url_scheme == "https" and self.proxy.scheme == "https":
warnings.warn(
"Your proxy configuration specified an HTTPS scheme for the proxy. "
"Are you sure you want to use HTTPS to contact the proxy? "
"This most likely indicates an error in your configuration. "
"Read this issue for more info: "
"https://github.com/urllib3/urllib3/issues/1850",
InvalidProxyConfigurationWarning,
stacklevel=3,
)
def urlopen(self, method, url, redirect=True, **kw):
"Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
u = parse_url(url)
self._validate_proxy_scheme_url_selection(u.scheme)
if u.scheme == "http":
# For proxied HTTPS requests, httplib sets the necessary headers
# on the CONNECT to the proxy. For HTTP, we'll definitely
# need to set 'Host' at the very least.
headers = kw.get("headers", self.headers)
kw["headers"] = self._set_proxy_headers(url, headers)
return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw)
def proxy_from_url(url, **kw):
return ProxyManager(proxy_url=url, **kw)

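A hedged sketch of the pool caching described above: one normalized (scheme, host, port) context maps to one cached pool. example.com and the header value are placeholders, and the request assumes network access:

import urllib3

http = urllib3.PoolManager(num_pools=5, headers={"User-Agent": "sketch/0.1"})
r = http.request("GET", "https://example.com/")
print(r.status)

# The same normalized context hits the same cached ConnectionPool:
pool = http.connection_from_host("example.com", 443, scheme="https")
assert pool is http.connection_from_url("https://example.com/anything")
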
171
urllib3/request.py Normal file
View File

@@ -0,0 +1,171 @@
from __future__ import absolute_import
from .filepost import encode_multipart_formdata
from .packages.six.moves.urllib.parse import urlencode
__all__ = ["RequestMethods"]
class RequestMethods(object):
"""
Convenience mixin for classes that implement a :meth:`urlopen` method, such
as :class:`~urllib3.connectionpool.HTTPConnectionPool` and
:class:`~urllib3.poolmanager.PoolManager`.
Provides behavior for making common types of HTTP request methods and
decides which type of request field encoding to use.
Specifically,
:meth:`.request_encode_url` is for sending requests whose fields are
encoded in the URL (such as GET, HEAD, DELETE).
:meth:`.request_encode_body` is for sending requests whose fields are
encoded in the *body* of the request using multipart or www-form-urlencoded
(such as for POST, PUT, PATCH).
:meth:`.request` is for making any kind of request; it will look up the
appropriate encoding format and use one of the above two methods to make
the request.
Initializer parameters:
:param headers:
Headers to include with all requests, unless other headers are given
explicitly.
"""
_encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}
def __init__(self, headers=None):
self.headers = headers or {}
def urlopen(
self,
method,
url,
body=None,
headers=None,
encode_multipart=True,
multipart_boundary=None,
**kw
): # Abstract
raise NotImplementedError(
"Classes extending RequestMethods must implement "
"their own ``urlopen`` method."
)
def request(self, method, url, fields=None, headers=None, **urlopen_kw):
"""
Make a request using :meth:`urlopen` with the appropriate encoding of
``fields`` based on the ``method`` used.
This is a convenience method that requires the least amount of manual
effort. It can be used in most situations, while still having the
option to drop down to more specific methods when necessary, such as
:meth:`request_encode_url`, :meth:`request_encode_body`,
or even the lowest level :meth:`urlopen`.
"""
method = method.upper()
urlopen_kw["request_url"] = url
if method in self._encode_url_methods:
return self.request_encode_url(
method, url, fields=fields, headers=headers, **urlopen_kw
)
else:
return self.request_encode_body(
method, url, fields=fields, headers=headers, **urlopen_kw
)
def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_kw):
"""
Make a request using :meth:`urlopen` with the ``fields`` encoded in
the url. This is useful for request methods like GET, HEAD, DELETE, etc.
"""
if headers is None:
headers = self.headers
extra_kw = {"headers": headers}
extra_kw.update(urlopen_kw)
if fields:
url += "?" + urlencode(fields)
return self.urlopen(method, url, **extra_kw)
def request_encode_body(
self,
method,
url,
fields=None,
headers=None,
encode_multipart=True,
multipart_boundary=None,
**urlopen_kw
):
"""
Make a request using :meth:`urlopen` with the ``fields`` encoded in
the body. This is useful for request methods like POST, PUT, PATCH, etc.
When ``encode_multipart=True`` (default), then
:meth:`urllib3.filepost.encode_multipart_formdata` is used to encode
the payload with the appropriate content type. Otherwise
:meth:`urllib.urlencode` is used with the
'application/x-www-form-urlencoded' content type.
Multipart encoding must be used when posting files, and it's reasonably
safe to use it at other times too. However, it may break request
signing, such as with OAuth.
Supports an optional ``fields`` parameter of key/value strings AND
key/filetuple. A filetuple is a (filename, data, MIME type) tuple where
the MIME type is optional. For example::
fields = {
'foo': 'bar',
'fakefile': ('foofile.txt', 'contents of foofile'),
'realfile': ('barfile.txt', open('realfile').read()),
'typedfile': ('bazfile.bin', open('bazfile').read(),
'image/jpeg'),
'nonamefile': 'contents of nonamefile field',
}
When uploading a file, providing a filename (the first parameter of the
tuple) is optional but recommended to best mimic behavior of browsers.
Note that if ``headers`` are supplied, the 'Content-Type' header will
be overwritten because it depends on the dynamic random boundary string
which is used to compose the body of the request. The random boundary
string can be explicitly set with the ``multipart_boundary`` parameter.
"""
if headers is None:
headers = self.headers
extra_kw = {"headers": {}}
if fields:
if "body" in urlopen_kw:
raise TypeError(
"request got values for both 'fields' and 'body', can only specify one."
)
if encode_multipart:
body, content_type = encode_multipart_formdata(
fields, boundary=multipart_boundary
)
else:
body, content_type = (
urlencode(fields),
"application/x-www-form-urlencoded",
)
extra_kw["body"] = body
extra_kw["headers"] = {"Content-Type": content_type}
extra_kw["headers"].update(headers)
extra_kw.update(urlopen_kw)
return self.urlopen(method, url, **extra_kw)

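A sketch of the two encoding paths request() chooses between; httpbin.org is assumed here only as a reachable echo service:

import urllib3

http = urllib3.PoolManager()
# GET goes through request_encode_url: fields land in the query string
http.request("GET", "https://httpbin.org/get", fields={"q": "dsi"})
# POST goes through request_encode_body: fields become multipart/form-data
http.request("POST", "https://httpbin.org/post",
             fields={"file": ("notes.txt", "contents", "text/plain")})
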
821
urllib3/response.py Normal file
View File

@@ -0,0 +1,821 @@
from __future__ import absolute_import
from contextlib import contextmanager
import zlib
import io
import logging
from socket import timeout as SocketTimeout
from socket import error as SocketError
try:
import brotli
except ImportError:
brotli = None
from ._collections import HTTPHeaderDict
from .exceptions import (
BodyNotHttplibCompatible,
ProtocolError,
DecodeError,
ReadTimeoutError,
ResponseNotChunked,
IncompleteRead,
InvalidHeader,
HTTPError,
)
from .packages.six import string_types as basestring, PY3
from .packages.six.moves import http_client as httplib
from .connection import HTTPException, BaseSSLError
from .util.response import is_fp_closed, is_response_to_head
log = logging.getLogger(__name__)
class DeflateDecoder(object):
def __init__(self):
self._first_try = True
self._data = b""
self._obj = zlib.decompressobj()
def __getattr__(self, name):
return getattr(self._obj, name)
def decompress(self, data):
if not data:
return data
if not self._first_try:
return self._obj.decompress(data)
self._data += data
try:
decompressed = self._obj.decompress(data)
if decompressed:
self._first_try = False
self._data = None
return decompressed
except zlib.error:
self._first_try = False
self._obj = zlib.decompressobj(-zlib.MAX_WBITS)
try:
return self.decompress(self._data)
finally:
self._data = None
class GzipDecoderState(object):
FIRST_MEMBER = 0
OTHER_MEMBERS = 1
SWALLOW_DATA = 2
class GzipDecoder(object):
def __init__(self):
self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS)
self._state = GzipDecoderState.FIRST_MEMBER
def __getattr__(self, name):
return getattr(self._obj, name)
def decompress(self, data):
ret = bytearray()
if self._state == GzipDecoderState.SWALLOW_DATA or not data:
return bytes(ret)
while True:
try:
ret += self._obj.decompress(data)
except zlib.error:
previous_state = self._state
# Ignore data after the first error
self._state = GzipDecoderState.SWALLOW_DATA
if previous_state == GzipDecoderState.OTHER_MEMBERS:
# Allow trailing garbage acceptable in other gzip clients
return bytes(ret)
raise
data = self._obj.unused_data
if not data:
return bytes(ret)
self._state = GzipDecoderState.OTHER_MEMBERS
self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS)
if brotli is not None:
class BrotliDecoder(object):
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(self):
self._obj = brotli.Decompressor()
def decompress(self, data):
if hasattr(self._obj, "decompress"):
return self._obj.decompress(data)
return self._obj.process(data)
def flush(self):
if hasattr(self._obj, "flush"):
return self._obj.flush()
return b""
class MultiDecoder(object):
"""
From RFC7231:
If one or more encodings have been applied to a representation, the
sender that applied the encodings MUST generate a Content-Encoding
header field that lists the content codings in the order in which
they were applied.
"""
def __init__(self, modes):
self._decoders = [_get_decoder(m.strip()) for m in modes.split(",")]
def flush(self):
return self._decoders[0].flush()
def decompress(self, data):
for d in reversed(self._decoders):
data = d.decompress(data)
return data
def _get_decoder(mode):
if "," in mode:
return MultiDecoder(mode)
if mode == "gzip":
return GzipDecoder()
if brotli is not None and mode == "br":
return BrotliDecoder()
return DeflateDecoder()
class HTTPResponse(io.IOBase):
"""
HTTP Response container.
Backwards-compatible with httplib's HTTPResponse but the response ``body`` is
loaded and decoded on-demand when the ``data`` property is accessed. This
class is also compatible with the Python standard library's :mod:`io`
module, and can hence be treated as a readable object in the context of that
framework.
Extra parameters for behaviour not present in httplib.HTTPResponse:
:param preload_content:
If True, the response's body will be preloaded during construction.
:param decode_content:
If True, will attempt to decode the body based on the
'content-encoding' header.
:param original_response:
When this HTTPResponse wrapper is generated from an httplib.HTTPResponse
object, it's convenient to include the original for debug purposes. It's
otherwise unused.
:param retries:
The retries contains the last :class:`~urllib3.util.retry.Retry` that
was used during the request.
:param enforce_content_length:
Enforce content length checking. Body returned by server must match
value of Content-Length header, if present. Otherwise, raise error.
"""
CONTENT_DECODERS = ["gzip", "deflate"]
if brotli is not None:
CONTENT_DECODERS += ["br"]
REDIRECT_STATUSES = [301, 302, 303, 307, 308]
def __init__(
self,
body="",
headers=None,
status=0,
version=0,
reason=None,
strict=0,
preload_content=True,
decode_content=True,
original_response=None,
pool=None,
connection=None,
msg=None,
retries=None,
enforce_content_length=False,
request_method=None,
request_url=None,
auto_close=True,
):
if isinstance(headers, HTTPHeaderDict):
self.headers = headers
else:
self.headers = HTTPHeaderDict(headers)
self.status = status
self.version = version
self.reason = reason
self.strict = strict
self.decode_content = decode_content
self.retries = retries
self.enforce_content_length = enforce_content_length
self.auto_close = auto_close
self._decoder = None
self._body = None
self._fp = None
self._original_response = original_response
self._fp_bytes_read = 0
self.msg = msg
self._request_url = request_url
if body and isinstance(body, (basestring, bytes)):
self._body = body
self._pool = pool
self._connection = connection
if hasattr(body, "read"):
self._fp = body
# Are we using the chunked-style of transfer encoding?
self.chunked = False
self.chunk_left = None
tr_enc = self.headers.get("transfer-encoding", "").lower()
# Don't incur the penalty of creating a list and then discarding it
encodings = (enc.strip() for enc in tr_enc.split(","))
if "chunked" in encodings:
self.chunked = True
# Determine length of response
self.length_remaining = self._init_length(request_method)
# If requested, preload the body.
if preload_content and not self._body:
self._body = self.read(decode_content=decode_content)
def get_redirect_location(self):
"""
Should we redirect and where to?
:returns: Truthy redirect location string if we got a redirect status
code and valid location. ``None`` if redirect status and no
location. ``False`` if not a redirect status code.
"""
if self.status in self.REDIRECT_STATUSES:
return self.headers.get("location")
return False
def release_conn(self):
if not self._pool or not self._connection:
return
self._pool._put_conn(self._connection)
self._connection = None
def drain_conn(self):
"""
Read and discard any remaining HTTP response data in the response connection.
Unread data in the HTTPResponse connection blocks the connection from being released back to the pool.
"""
try:
self.read()
except (HTTPError, SocketError, BaseSSLError, HTTPException):
pass
@property
def data(self):
# For backwards-compat with urllib3 0.4 and earlier.
if self._body:
return self._body
if self._fp:
return self.read(cache_content=True)
@property
def connection(self):
return self._connection
def isclosed(self):
return is_fp_closed(self._fp)
def tell(self):
"""
Obtain the number of bytes pulled over the wire so far. May differ from
the amount of content returned by :meth:`HTTPResponse.read` if bytes
are encoded on the wire (e.g., compressed).
"""
return self._fp_bytes_read
def _init_length(self, request_method):
"""
Set initial length value for Response content if available.
"""
length = self.headers.get("content-length")
if length is not None:
if self.chunked:
# This Response will fail with an IncompleteRead if it can't be
# received as chunked. This method falls back to attempt reading
# the response before raising an exception.
log.warning(
"Received response with both Content-Length and "
"Transfer-Encoding set. This is expressly forbidden "
"by RFC 7230 sec 3.3.2. Ignoring Content-Length and "
"attempting to process response as Transfer-Encoding: "
"chunked."
)
return None
try:
# RFC 7230 section 3.3.2 specifies multiple content lengths can
# be sent in a single Content-Length header
# (e.g. Content-Length: 42, 42). This line ensures the values
# are all valid ints and that as long as the `set` length is 1,
# all values are the same. Otherwise, the header is invalid.
lengths = set([int(val) for val in length.split(",")])
if len(lengths) > 1:
raise InvalidHeader(
"Content-Length contained multiple "
"unmatching values (%s)" % length
)
length = lengths.pop()
except ValueError:
length = None
else:
if length < 0:
length = None
# Convert status to int for comparison
# In some cases, httplib returns a status of "_UNKNOWN"
try:
status = int(self.status)
except ValueError:
status = 0
# Check for responses that shouldn't include a body
if status in (204, 304) or 100 <= status < 200 or request_method == "HEAD":
length = 0
return length
def _init_decoder(self):
"""
Set-up the _decoder attribute if necessary.
"""
# Note: content-encoding value should be case-insensitive, per RFC 7230
# Section 3.2
content_encoding = self.headers.get("content-encoding", "").lower()
if self._decoder is None:
if content_encoding in self.CONTENT_DECODERS:
self._decoder = _get_decoder(content_encoding)
elif "," in content_encoding:
encodings = [
e.strip()
for e in content_encoding.split(",")
if e.strip() in self.CONTENT_DECODERS
]
if len(encodings):
self._decoder = _get_decoder(content_encoding)
DECODER_ERROR_CLASSES = (IOError, zlib.error)
if brotli is not None:
DECODER_ERROR_CLASSES += (brotli.error,)
def _decode(self, data, decode_content, flush_decoder):
"""
Decode the data passed in and potentially flush the decoder.
"""
if not decode_content:
return data
try:
if self._decoder:
data = self._decoder.decompress(data)
except self.DECODER_ERROR_CLASSES as e:
content_encoding = self.headers.get("content-encoding", "").lower()
raise DecodeError(
"Received response with content-encoding: %s, but "
"failed to decode it." % content_encoding,
e,
)
if flush_decoder:
data += self._flush_decoder()
return data
def _flush_decoder(self):
"""
Flushes the decoder. Should only be called if the decoder is actually
being used.
"""
if self._decoder:
buf = self._decoder.decompress(b"")
return buf + self._decoder.flush()
return b""
@contextmanager
def _error_catcher(self):
"""
Catch low-level python exceptions, instead re-raising urllib3
variants, so that low-level exceptions are not leaked in the
high-level api.
On exit, release the connection back to the pool.
"""
clean_exit = False
try:
try:
yield
except SocketTimeout:
# FIXME: Ideally we'd like to include the url in the ReadTimeoutError but
# there is yet no clean way to get at it from this context.
raise ReadTimeoutError(self._pool, None, "Read timed out.")
except BaseSSLError as e:
# FIXME: Is there a better way to differentiate between SSLErrors?
if "read operation timed out" not in str(e): # Defensive:
# This shouldn't happen but just in case we're missing an edge
# case, let's avoid swallowing SSL errors.
raise
raise ReadTimeoutError(self._pool, None, "Read timed out.")
except (HTTPException, SocketError) as e:
# This includes IncompleteRead.
raise ProtocolError("Connection broken: %r" % e, e)
# If no exception is thrown, we should avoid cleaning up
# unnecessarily.
clean_exit = True
finally:
# If we didn't terminate cleanly, we need to throw away our
# connection.
if not clean_exit:
# The response may not be closed but we're not going to use it
# anymore so close it now to ensure that the connection is
# released back to the pool.
if self._original_response:
self._original_response.close()
# Closing the response may not actually be sufficient to close
# everything, so if we have a hold of the connection close that
# too.
if self._connection:
self._connection.close()
# If we hold the original response but it's closed now, we should
# return the connection back to the pool.
if self._original_response and self._original_response.isclosed():
self.release_conn()
def read(self, amt=None, decode_content=None, cache_content=False):
"""
Similar to :meth:`httplib.HTTPResponse.read`, but with two additional
parameters: ``decode_content`` and ``cache_content``.
:param amt:
How much of the content to read. If specified, caching is skipped
because it doesn't make sense to cache partial content as the full
response.
:param decode_content:
If True, will attempt to decode the body based on the
'content-encoding' header.
:param cache_content:
If True, will save the returned data such that the same result is
returned regardless of the state of the underlying file object. This
is useful if you want the ``.data`` property to continue working
after having ``.read()`` the file object. (Overridden if ``amt`` is
set.)
"""
self._init_decoder()
if decode_content is None:
decode_content = self.decode_content
if self._fp is None:
return
flush_decoder = False
fp_closed = getattr(self._fp, "closed", False)
with self._error_catcher():
if amt is None:
# cStringIO doesn't like amt=None
data = self._fp.read() if not fp_closed else b""
flush_decoder = True
else:
cache_content = False
data = self._fp.read(amt) if not fp_closed else b""
if (
amt != 0 and not data
): # Platform-specific: Buggy versions of Python.
# Close the connection when no data is returned
#
# This is redundant to what httplib/http.client _should_
# already do. However, versions of python released before
# December 15, 2012 (http://bugs.python.org/issue16298) do
# not properly close the connection in all cases. There is
# no harm in redundantly calling close.
self._fp.close()
flush_decoder = True
if self.enforce_content_length and self.length_remaining not in (
0,
None,
):
# This is an edge case that httplib failed to cover due
# to concerns of backward compatibility. We're
# addressing it here to make sure IncompleteRead is
# raised during streaming, so all calls with incorrect
# Content-Length are caught.
raise IncompleteRead(self._fp_bytes_read, self.length_remaining)
if data:
self._fp_bytes_read += len(data)
if self.length_remaining is not None:
self.length_remaining -= len(data)
data = self._decode(data, decode_content, flush_decoder)
if cache_content:
self._body = data
return data
def stream(self, amt=2 ** 16, decode_content=None):
"""
A generator wrapper for the read() method. A call will block until
``amt`` bytes have been read from the connection or until the
connection is closed.
:param amt:
How much of the content to read. The generator will return up to
that much data per iteration, but may return less. This is particularly
likely when using compressed data. However, the empty string will
never be returned.
:param decode_content:
If True, will attempt to decode the body based on the
'content-encoding' header.
"""
if self.chunked and self.supports_chunked_reads():
for line in self.read_chunked(amt, decode_content=decode_content):
yield line
else:
while not is_fp_closed(self._fp):
data = self.read(amt=amt, decode_content=decode_content)
if data:
yield data
@classmethod
def from_httplib(ResponseCls, r, **response_kw):
"""
Given an :class:`httplib.HTTPResponse` instance ``r``, return a
corresponding :class:`urllib3.response.HTTPResponse` object.
Remaining parameters are passed to the HTTPResponse constructor, along
with ``original_response=r``.
"""
headers = r.msg
if not isinstance(headers, HTTPHeaderDict):
if PY3:
headers = HTTPHeaderDict(headers.items())
else:
# Python 2.7
headers = HTTPHeaderDict.from_httplib(headers)
# HTTPResponse objects in Python 3 don't have a .strict attribute
strict = getattr(r, "strict", 0)
resp = ResponseCls(
body=r,
headers=headers,
status=r.status,
version=r.version,
reason=r.reason,
strict=strict,
original_response=r,
**response_kw
)
return resp
# Backwards-compatibility methods for httplib.HTTPResponse
def getheaders(self):
return self.headers
def getheader(self, name, default=None):
return self.headers.get(name, default)
# Backwards compatibility for http.cookiejar
def info(self):
return self.headers
# Overrides from io.IOBase
def close(self):
if not self.closed:
self._fp.close()
if self._connection:
self._connection.close()
if not self.auto_close:
io.IOBase.close(self)
@property
def closed(self):
if not self.auto_close:
return io.IOBase.closed.__get__(self)
elif self._fp is None:
return True
elif hasattr(self._fp, "isclosed"):
return self._fp.isclosed()
elif hasattr(self._fp, "closed"):
return self._fp.closed
else:
return True
def fileno(self):
if self._fp is None:
raise IOError("HTTPResponse has no file to get a fileno from")
elif hasattr(self._fp, "fileno"):
return self._fp.fileno()
else:
raise IOError(
"The file-like object this HTTPResponse is wrapped "
"around has no file descriptor"
)
def flush(self):
if (
self._fp is not None
and hasattr(self._fp, "flush")
and not getattr(self._fp, "closed", False)
):
return self._fp.flush()
def readable(self):
# This method is required for `io` module compatibility.
return True
def readinto(self, b):
# This method is required for `io` module compatibility.
temp = self.read(len(b))
if len(temp) == 0:
return 0
else:
b[: len(temp)] = temp
return len(temp)
def supports_chunked_reads(self):
"""
Checks if the underlying file-like object looks like an
httplib.HTTPResponse object. We do this by testing for the fp
attribute. If it is present we assume it returns raw chunks as
processed by read_chunked().
"""
return hasattr(self._fp, "fp")
def _update_chunk_length(self):
# First, we'll figure out length of a chunk and then
# we'll try to read it from socket.
if self.chunk_left is not None:
return
line = self._fp.fp.readline()
line = line.split(b";", 1)[0]
try:
self.chunk_left = int(line, 16)
except ValueError:
# Invalid chunked protocol response, abort.
self.close()
raise httplib.IncompleteRead(line)
def _handle_chunk(self, amt):
returned_chunk = None
if amt is None:
chunk = self._fp._safe_read(self.chunk_left)
returned_chunk = chunk
self._fp._safe_read(2) # Toss the CRLF at the end of the chunk.
self.chunk_left = None
elif amt < self.chunk_left:
value = self._fp._safe_read(amt)
self.chunk_left = self.chunk_left - amt
returned_chunk = value
elif amt == self.chunk_left:
value = self._fp._safe_read(amt)
self._fp._safe_read(2) # Toss the CRLF at the end of the chunk.
self.chunk_left = None
returned_chunk = value
else: # amt > self.chunk_left
returned_chunk = self._fp._safe_read(self.chunk_left)
self._fp._safe_read(2) # Toss the CRLF at the end of the chunk.
self.chunk_left = None
return returned_chunk
def read_chunked(self, amt=None, decode_content=None):
"""
Similar to :meth:`HTTPResponse.read`, but with an additional
parameter: ``decode_content``.
:param amt:
How much of the content to read. If specified, caching is skipped
because it doesn't make sense to cache partial content as the full
response.
:param decode_content:
If True, will attempt to decode the body based on the
'content-encoding' header.
"""
self._init_decoder()
# FIXME: Rewrite this method and make it a class with a better structured logic.
if not self.chunked:
raise ResponseNotChunked(
"Response is not chunked. "
"Header 'transfer-encoding: chunked' is missing."
)
if not self.supports_chunked_reads():
raise BodyNotHttplibCompatible(
"Body should be httplib.HTTPResponse like. "
"It should have have an fp attribute which returns raw chunks."
)
with self._error_catcher():
# Don't bother reading the body of a HEAD request.
if self._original_response and is_response_to_head(self._original_response):
self._original_response.close()
return
# If a response is already read and closed
# then return immediately.
if self._fp.fp is None:
return
while True:
self._update_chunk_length()
if self.chunk_left == 0:
break
chunk = self._handle_chunk(amt)
decoded = self._decode(
chunk, decode_content=decode_content, flush_decoder=False
)
if decoded:
yield decoded
if decode_content:
# On CPython and PyPy, we should never need to flush the
# decoder. However, on Jython we *might* need to, so
# let's defensively do it anyway.
decoded = self._flush_decoder()
if decoded: # Platform-specific: Jython.
yield decoded
# Chunk content ends with \r\n: discard it.
while True:
line = self._fp.fp.readline()
if not line:
# Some sites may not end with '\r\n'.
break
if line == b"\r\n":
break
# We read everything; close the "file".
if self._original_response:
self._original_response.close()
def geturl(self):
"""
Returns the URL that was the source of this response.
If the request that generated this response redirected, this method
will return the final redirect location.
"""
if self.retries is not None and len(self.retries.history):
return self.retries.history[-1].redirect_location
else:
return self._request_url
def __iter__(self):
buffer = []
for chunk in self.stream(decode_content=True):
if b"\n" in chunk:
chunk = chunk.split(b"\n")
yield b"".join(buffer) + chunk[0] + b"\n"
for x in chunk[1:-1]:
yield x + b"\n"
if chunk[-1]:
buffer = [chunk[-1]]
else:
buffer = []
else:
buffer.append(chunk)
if buffer:
yield b"".join(buffer)

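A sketch of streaming a response without preloading, using the stream()/release_conn() flow defined above; assumes network access to the placeholder host:

import urllib3

http = urllib3.PoolManager()
r = http.request("GET", "https://example.com/", preload_content=False)
for chunk in r.stream(1024):    # up to 1024 bytes per iteration
    print(len(chunk))
r.release_conn()                # hand the socket back to the pool
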
46
urllib3/util/__init__.py Normal file
View File

@@ -0,0 +1,46 @@
from __future__ import absolute_import
# For backwards compatibility, provide imports that used to be here.
from .connection import is_connection_dropped
from .request import make_headers
from .response import is_fp_closed
from .ssl_ import (
SSLContext,
HAS_SNI,
IS_PYOPENSSL,
IS_SECURETRANSPORT,
assert_fingerprint,
resolve_cert_reqs,
resolve_ssl_version,
ssl_wrap_socket,
PROTOCOL_TLS,
)
from .timeout import current_time, Timeout
from .retry import Retry
from .url import get_host, parse_url, split_first, Url
from .wait import wait_for_read, wait_for_write
__all__ = (
"HAS_SNI",
"IS_PYOPENSSL",
"IS_SECURETRANSPORT",
"SSLContext",
"PROTOCOL_TLS",
"Retry",
"Timeout",
"Url",
"assert_fingerprint",
"current_time",
"is_connection_dropped",
"is_fp_closed",
"get_host",
"parse_url",
"make_headers",
"resolve_cert_reqs",
"resolve_ssl_version",
"split_first",
"ssl_wrap_socket",
"wait_for_read",
"wait_for_write",
)

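A short sketch of the re-exported helpers above; the printed values are what these calls should produce:

from urllib3.util import make_headers, parse_url, Retry, Timeout

print(make_headers(keep_alive=True, user_agent="sketch/0.1"))
# {'user-agent': 'sketch/0.1', 'connection': 'keep-alive'}
print(parse_url("https://example.com:8443/path?q=1").host)  # example.com
retry = Retry(total=3, backoff_factor=0.5)   # passed via urlopen(..., retries=retry)
timeout = Timeout(connect=2.0, read=5.0)
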
Some files were not shown because too many files have changed in this diff