diff --git a/__pycache__/py7zr.cpython-38.pyc b/__pycache__/py7zr.cpython-38.pyc new file mode 100644 index 0000000..8f7b7f6 Binary files /dev/null and b/__pycache__/py7zr.cpython-38.pyc differ diff --git a/main.py b/main.py new file mode 100644 index 0000000..e8f89d0 --- /dev/null +++ b/main.py @@ -0,0 +1,132 @@ +import tkinter +from tkinter import messagebox +from tkinter import filedialog +import os +import platform +import sys +import requests +import json +import py7zr +from pathlib import Path + +dsiVersions = ["1.0 - 1.3 (USA, EUR, AUS, JPN)", "1.4 - 1.4.5 (USA, EUR, AUS, JPN)", "All versions (KOR, CHN)"] +memoryPitLinks = ["https://github.com/YourKalamity/just-a-dsi-cfw-installer/raw/master/assets/files/memoryPit/256/pit.bin","https://github.com/YourKalamity/just-a-dsi-cfw-installer/raw/master/assets/files/memoryPit/768_1024/pit.bin"] + +window = tkinter.Tk() +window.sourceFolder = '' +window.sourceFile = '' +SDlabel = tkinter.Label(text = "SD card directory") +SDlabel.width = 100 +SDentry = tkinter.Entry() +SDentry.width = 100 + +def getLatestTWLmenu(): + release = json.loads(requests.get("https://api.github.com/repos/DS-Homebrew/TWiLightMenu/releases/latest").content) + url = release["assets"][0]["browser_download_url"] + return url + +def outputbox(message): + outputBox.configure(state='normal') + outputBox.insert('end', message) + outputBox.configure(state='disabled') + +def validateDirectory(directory): + try: + directory = str(directory) + except TypeError: + outputbox("That's not a valid directory") + return False + try: + string = directory + "/test.file" + with open(string, 'w'): + pass + os.remove(string) + except FileNotFoundError: + outputbox("That's not a valid directory") + outputbox(" or you do not have the") + outputbox(" permission to write there") + return False + except PermissionError: + outputbox("You do not have write") + outputbox(" access to that folder") + return False + else: + return True + +def start(): + # a Text widget is only editable while state='normal', and its start index is '1.0' + outputBox.configure(state='normal') + outputBox.delete('1.0', tkinter.END) + outputBox.configure(state='disabled') + #Variables + directory = SDentry.get() + version = firmwareVersion.get() + unlaunchNeeded = unlaunch.get() + + directoryValidated = validateDirectory(directory) + if directoryValidated == False: + return + if dsiVersions.index(version) == 1: + memoryPitDownload = memoryPitLinks[1] + elif dsiVersions.index(version) in [0,2]: + memoryPitDownload = memoryPitLinks[0] + + temp = directory + "/tmp/" + Path(temp).mkdir(parents=True,exist_ok=True) + + #Download Memory Pit + memoryPitLocation = directory + "/private/ds/app/484E494A/" + Path(memoryPitLocation).mkdir(parents=True, exist_ok=True) + r = requests.get(memoryPitDownload, allow_redirects=True) + memoryPitLocation = memoryPitLocation + "pit.bin" + with open(memoryPitLocation, 'wb') as f: + f.write(r.content) + outputbox("Memory Pit Downloaded ") + + #Download TWiLight Menu + r = requests.get(getLatestTWLmenu(), allow_redirects=True) + TWLmenuLocation = temp + "TWiLightMenu.7z" + with open(TWLmenuLocation, 'wb') as f: + f.write(r.content) + outputbox("TWiLight Menu ++ Downloaded ") + + #Extract TWiLight Menu + with py7zr.SevenZipFile(TWLmenuLocation, mode='r') as archive: + archive.extractall(path=temp) + +def chooseDir(): + window.sourceFolder = filedialog.askdirectory(parent=window, initialdir= "/", title='Please select the directory of your SD card') + SDentry.delete(0, tkinter.END) + SDentry.insert(0, window.sourceFolder) +b_chooseDir = tkinter.Button(window, text = "Choose Folder", width = 20, command = chooseDir) +b_chooseDir.width = 100 +b_chooseDir.height = 50 + 
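# --- Editorial sketch, not part of the patch: start() writes r.content to disk
# --- without checking the HTTP status, so a failed download would leave an error
# --- page on the SD card. Below is a hedged variant using only public requests
# --- and py7zr API; the helper name download_to is invented here for illustration.
def download_to(url, dest):
    r = requests.get(url, allow_redirects=True, timeout=30)
    r.raise_for_status()  # abort on 4xx/5xx instead of saving the error body
    with open(dest, 'wb') as f:
        f.write(r.content)
# Inside start() the TWiLight Menu step could then read:
#     download_to(getLatestTWLmenu(), temp + "TWiLightMenu.7z")
#     with py7zr.SevenZipFile(temp + "TWiLightMenu.7z", mode='r') as archive:
#         archive.extractall(path=temp)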
+firmwareLabel = tkinter.Label(text = "Select your DSi firmware") +firmwareLabel.width = 100 + +firmwareVersion = tkinter.StringVar(window) +firmwareVersion.set(dsiVersions[0]) +selector = tkinter.OptionMenu(window, firmwareVersion, *dsiVersions) +selector.width = 100 + +unlaunch = tkinter.IntVar(value=1) +unlaunchCheck = tkinter.Checkbutton(window, text = "Install Unlaunch?", variable =unlaunch) + +startButton = tkinter.Button(window, text = "Start", width = 20, command = start) +outputLabel = tkinter.Label(text="Output") +outputLabel.width = 100 +outputBox = tkinter.Text(window,state='disabled', width = 30, height = 10) + + +SDlabel.pack() +SDentry.pack() +b_chooseDir.pack() +firmwareLabel.pack() +selector.pack() +unlaunchCheck.pack() +startButton.pack() +outputLabel.pack() +outputBox.pack() +window.mainloop() + diff --git a/py7zr/archiveinfo.py b/py7zr/archiveinfo.py new file mode 100644 index 0000000..2e2b95e --- /dev/null +++ b/py7zr/archiveinfo.py @@ -0,0 +1,1094 @@ +#!/usr/bin/python -u +# +# p7zr library +# +# Copyright (c) 2019,2020 Hiroshi Miura +# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de +# 7-Zip Copyright (C) 1999-2010 Igor Pavlov +# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# + +import functools +import io +import os +import struct +from binascii import unhexlify +from functools import reduce +from io import BytesIO +from operator import and_, or_ +from struct import pack, unpack +from typing import Any, BinaryIO, Dict, List, Optional, Tuple + +from py7zr.compression import SevenZipCompressor, SevenZipDecompressor +from py7zr.exceptions import Bad7zFile +from py7zr.helpers import ArchiveTimestamp, calculate_crc32 +from py7zr.properties import MAGIC_7Z, CompressionMethod, Property + +MAX_LENGTH = 65536 +P7ZIP_MAJOR_VERSION = b'\x00' +P7ZIP_MINOR_VERSION = b'\x04' + + +def read_crcs(file: BinaryIO, count: int) -> List[int]: + data = file.read(4 * count) + return [unpack(' Tuple[bytes, ...]: + return unpack(b'B' * length, file.read(length)) + + +def read_byte(file: BinaryIO) -> int: + return ord(file.read(1)) + + +def write_bytes(file: BinaryIO, data: bytes): + return file.write(data) + + +def write_byte(file: BinaryIO, data): + assert len(data) == 1 + return write_bytes(file, data) + + +def read_real_uint64(file: BinaryIO) -> Tuple[int, bytes]: + """read 8 bytes, return unpacked value as a little endian unsigned long long, and raw data.""" + res = file.read(8) + a = unpack(' Tuple[int, bytes]: + """read 4 bytes, return unpacked value as a little endian unsigned long, and raw data.""" + res = file.read(4) + a = unpack(' int: + """read UINT64, definition show in write_uint64()""" + b = ord(file.read(1)) + if b == 255: + return read_real_uint64(file)[0] + blen = [(0b01111111, 0), (0b10111111, 1), (0b11011111, 2), (0b11101111, 3), + 
(0b11110111, 4), (0b11111011, 5), (0b11111101, 6), (0b11111110, 7)] + mask = 0x80 + vlen = 8 + for v, l in blen: + if b <= v: + vlen = l + break + mask >>= 1 + if vlen == 0: + return b & (mask - 1) + val = file.read(vlen) + value = int.from_bytes(val, byteorder='little') + highpart = b & (mask - 1) + return value + (highpart << (vlen * 8)) + + +def write_real_uint64(file: BinaryIO, value: int): + """write 8 bytes, as an unsigned long long.""" + file.write(pack(' 0x01ffffffffffffff: + file.write(b'\xff') + file.write(value.to_bytes(8, 'little')) + return + byte_length = (value.bit_length() + 7) // 8 + ba = bytearray(value.to_bytes(byte_length, 'little')) + high_byte = int(ba[-1]) + if high_byte < 2 << (8 - byte_length - 1): + for x in range(byte_length - 1): + high_byte |= 0x80 >> x + file.write(pack('B', high_byte)) + file.write(ba[:byte_length - 1]) + else: + mask = 0x80 + for x in range(byte_length): + mask |= 0x80 >> x + file.write(pack('B', mask)) + file.write(ba) + + +def read_boolean(file: BinaryIO, count: int, checkall: bool = False) -> List[bool]: + if checkall: + all_defined = file.read(1) + if all_defined != unhexlify('00'): + return [True] * count + result = [] + b = 0 + mask = 0 + for i in range(count): + if mask == 0: + b = ord(file.read(1)) + mask = 0x80 + result.append(b & mask != 0) + mask >>= 1 + return result + + +def write_boolean(file: BinaryIO, booleans: List[bool], all_defined: bool = False): + if all_defined and reduce(and_, booleans, True): + file.write(b'\x01') + return + elif all_defined: + file.write(b'\x00') + o = bytearray(-(-len(booleans) // 8)) + for i, b in enumerate(booleans): + if b: + o[i // 8] |= 1 << (7 - i % 8) + file.write(o) + + +def read_utf16(file: BinaryIO) -> str: + """read a utf-16 string from file""" + val = '' + for _ in range(MAX_LENGTH): + ch = file.read(2) + if ch == unhexlify('0000'): + break + val += ch.decode('utf-16LE') + return val + + +def write_utf16(file: BinaryIO, val: str): + """write a utf-16 string to file""" + for c in val: + file.write(c.encode('utf-16LE')) + file.write(b'\x00\x00') + + +def bits_to_bytes(bit_length: int) -> int: + return - (-bit_length // 8) + + +class ArchiveProperties: + + __slots__ = ['property_data'] + + def __init__(self): + self.property_data = [] + + @classmethod + def retrieve(cls, file): + return cls()._read(file) + + def _read(self, file): + pid = file.read(1) + if pid == Property.ARCHIVE_PROPERTIES: + while True: + ptype = file.read(1) + if ptype == Property.END: + break + size = read_uint64(file) + props = read_bytes(file, size) + self.property_data.append(props) + return self + + def write(self, file): + if len(self.property_data) > 0: + write_byte(file, Property.ARCHIVE_PROPERTIES) + for data in self.property_data: + write_uint64(file, len(data)) + write_bytes(file, data) + write_byte(file, Property.END) + + +class PackInfo: + """ information about packed streams """ + + __slots__ = ['packpos', 'numstreams', 'packsizes', 'packpositions', 'crcs'] + + def __init__(self) -> None: + self.packpos = 0 # type: int + self.numstreams = 0 # type: int + self.packsizes = [] # type: List[int] + self.crcs = None # type: Optional[List[int]] + + @classmethod + def retrieve(cls, file: BinaryIO): + return cls()._read(file) + + def _read(self, file: BinaryIO): + self.packpos = read_uint64(file) + self.numstreams = read_uint64(file) + pid = file.read(1) + if pid == Property.SIZE: + self.packsizes = [read_uint64(file) for _ in range(self.numstreams)] + pid = file.read(1) + if pid == Property.CRC: + self.crcs = 
[read_uint64(file) for _ in range(self.numstreams)] + pid = file.read(1) + if pid != Property.END: + raise Bad7zFile('end id expected but %s found' % repr(pid)) + self.packpositions = [sum(self.packsizes[:i]) for i in range(self.numstreams + 1)] # type: List[int] + return self + + def write(self, file: BinaryIO): + assert self.packpos is not None + numstreams = len(self.packsizes) + assert self.crcs is None or len(self.crcs) == numstreams + write_byte(file, Property.PACK_INFO) + write_uint64(file, self.packpos) + write_uint64(file, numstreams) + write_byte(file, Property.SIZE) + for size in self.packsizes: + write_uint64(file, size) + if self.crcs is not None: + write_bytes(file, Property.CRC) + for crc in self.crcs: + write_uint64(file, crc) + write_byte(file, Property.END) + + +class Folder: + """ a "Folder" represents a stream of compressed data. + coders: list of coder + num_coders: length of coders + coder: hash list + keys of coders: method, numinstreams, numoutstreams, properties + unpacksizes: uncompressed sizes of outstreams + """ + + __slots__ = ['unpacksizes', 'solid', 'coders', 'digestdefined', 'totalin', 'totalout', + 'bindpairs', 'packed_indices', 'crc', 'decompressor', 'compressor', 'files'] + + def __init__(self) -> None: + self.unpacksizes = None # type: Optional[List[int]] + self.coders = [] # type: List[Dict[str, Any]] + self.bindpairs = [] # type: List[Any] + self.packed_indices = [] # type: List[int] + # calculated values + self.totalin = 0 # type: int + self.totalout = 0 # type: int + # internal values + self.solid = False # type: bool + self.digestdefined = False # type: bool + self.crc = None # type: Optional[int] + # compress/decompress objects + self.decompressor = None # type: Optional[SevenZipDecompressor] + self.compressor = None # type: Optional[SevenZipCompressor] + self.files = None + + @classmethod + def retrieve(cls, file: BinaryIO): + obj = cls() + obj._read(file) + return obj + + def _read(self, file: BinaryIO) -> None: + num_coders = read_uint64(file) + for _ in range(num_coders): + b = read_byte(file) + methodsize = b & 0xf + iscomplex = b & 0x10 == 0x10 + hasattributes = b & 0x20 == 0x20 + c = {'method': file.read(methodsize)} # type: Dict[str, Any] + if iscomplex: + c['numinstreams'] = read_uint64(file) + c['numoutstreams'] = read_uint64(file) + else: + c['numinstreams'] = 1 + c['numoutstreams'] = 1 + self.totalin += c['numinstreams'] + self.totalout += c['numoutstreams'] + if hasattributes: + proplen = read_uint64(file) + c['properties'] = file.read(proplen) + self.coders.append(c) + num_bindpairs = self.totalout - 1 + for i in range(num_bindpairs): + self.bindpairs.append((read_uint64(file), read_uint64(file),)) + num_packedstreams = self.totalin - num_bindpairs + if num_packedstreams == 1: + for i in range(self.totalin): + if self._find_in_bin_pair(i) < 0: # there is no in_bin_pair + self.packed_indices.append(i) + elif num_packedstreams > 1: + for i in range(num_packedstreams): + self.packed_indices.append(read_uint64(file)) + + def write(self, file: BinaryIO): + num_coders = len(self.coders) + assert num_coders > 0 + write_uint64(file, num_coders) + for i, c in enumerate(self.coders): + id = c['method'] # type: bytes + id_size = len(id) & 0x0f + iscomplex = 0x10 if not self.is_simple(c) else 0x00 + hasattributes = 0x20 if c['properties'] is not None else 0x00 + flag = struct.pack('B', id_size | iscomplex | hasattributes) + write_byte(file, flag) + write_bytes(file, id[:id_size]) + if not self.is_simple(c): + write_uint64(file, c['numinstreams']) 
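+ # only single-output coders are emitted by this writer; the assert below guards that invariant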
+ assert c['numoutstreams'] == 1 + write_uint64(file, c['numoutstreams']) + if c['properties'] is not None: + write_uint64(file, len(c['properties'])) + write_bytes(file, c['properties']) + num_bindpairs = self.totalout - 1 + assert len(self.bindpairs) == num_bindpairs + num_packedstreams = self.totalin - num_bindpairs + for bp in self.bindpairs: + write_uint64(file, bp[0]) + write_uint64(file, bp[1]) + if num_packedstreams > 1: + for pi in self.packed_indices: + write_uint64(file, pi) + + def is_simple(self, coder): + return coder['numinstreams'] == 1 and coder['numoutstreams'] == 1 + + def get_decompressor(self, size: int, reset: bool = False) -> SevenZipDecompressor: + if self.decompressor is not None and not reset: + return self.decompressor + else: + self.decompressor = SevenZipDecompressor(self.coders, size, self.crc) + return self.decompressor + + def get_compressor(self) -> SevenZipCompressor: + if self.compressor is not None: + return self.compressor + else: + try: + # FIXME: set filters + self.compressor = SevenZipCompressor() + self.coders = self.compressor.coders + return self.compressor + except Exception as e: + raise e + + def get_unpack_size(self) -> int: + if self.unpacksizes is None: + return 0 + for i in range(len(self.unpacksizes) - 1, -1, -1): + if self._find_out_bin_pair(i): + return self.unpacksizes[i] + raise TypeError('not found') + + def _find_in_bin_pair(self, index: int) -> int: + for idx, (a, b) in enumerate(self.bindpairs): + if a == index: + return idx + return -1 + + def _find_out_bin_pair(self, index: int) -> int: + for idx, (a, b) in enumerate(self.bindpairs): + if b == index: + return idx + return -1 + + def is_encrypted(self) -> bool: + return CompressionMethod.CRYPT_AES256_SHA256 in [x['method'] for x in self.coders] + + +class UnpackInfo: + """ combines multiple folders """ + + __slots__ = ['numfolders', 'folders', 'datastreamidx'] + + @classmethod + def retrieve(cls, file: BinaryIO): + obj = cls() + obj._read(file) + return obj + + def __init__(self): + self.numfolders = None + self.folders = [] + self.datastreamidx = None + + def _read(self, file: BinaryIO): + pid = file.read(1) + if pid != Property.FOLDER: + raise Bad7zFile('folder id expected but %s found' % repr(pid)) + self.numfolders = read_uint64(file) + self.folders = [] + external = read_byte(file) + if external == 0x00: + self.folders = [Folder.retrieve(file) for _ in range(self.numfolders)] + else: + datastreamidx = read_uint64(file) + current_pos = file.tell() + file.seek(datastreamidx, 0) + self.folders = [Folder.retrieve(file) for _ in range(self.numfolders)] + file.seek(current_pos, 0) + self._retrieve_coders_info(file) + + def _retrieve_coders_info(self, file: BinaryIO): + pid = file.read(1) + if pid != Property.CODERS_UNPACK_SIZE: + raise Bad7zFile('coders unpack size id expected but %s found' % repr(pid)) + for folder in self.folders: + folder.unpacksizes = [read_uint64(file) for _ in range(folder.totalout)] + pid = file.read(1) + if pid == Property.CRC: + defined = read_boolean(file, self.numfolders, checkall=True) + crcs = read_crcs(file, self.numfolders) + for idx, folder in enumerate(self.folders): + folder.digestdefined = defined[idx] + folder.crc = crcs[idx] + pid = file.read(1) + if pid != Property.END: + raise Bad7zFile('end id expected but %s found at %d' % (repr(pid), file.tell())) + + def write(self, file: BinaryIO): + assert self.numfolders is not None + assert self.folders is not None + assert self.numfolders == len(self.folders) + file.write(Property.UNPACK_INFO) + 
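# kFolder record follows: folder count, an external flag (0x00 = folders stored inline), then each folder's coder descriptions +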
file.write(Property.FOLDER) + write_uint64(file, self.numfolders) + write_byte(file, b'\x00') + for folder in self.folders: + folder.write(file) + # If support external entity, we may write + # self.datastreamidx here. + # folder data will be written in another place. + # write_byte(file, b'\x01') + # assert self.datastreamidx is not None + # write_uint64(file, self.datastreamidx) + write_byte(file, Property.CODERS_UNPACK_SIZE) + for folder in self.folders: + for i in range(folder.totalout): + write_uint64(file, folder.unpacksizes[i]) + write_byte(file, Property.END) + + +class SubstreamsInfo: + """ defines the substreams of a folder """ + + __slots__ = ['digests', 'digestsdefined', 'unpacksizes', 'num_unpackstreams_folders'] + + def __init__(self): + self.digests = [] # type: List[int] + self.digestsdefined = [] # type: List[bool] + self.unpacksizes = None # type: Optional[List[int]] + self.num_unpackstreams_folders = [] # type: List[int] + + @classmethod + def retrieve(cls, file: BinaryIO, numfolders: int, folders: List[Folder]): + obj = cls() + obj._read(file, numfolders, folders) + return obj + + def _read(self, file: BinaryIO, numfolders: int, folders: List[Folder]): + pid = file.read(1) + if pid == Property.NUM_UNPACK_STREAM: + self.num_unpackstreams_folders = [read_uint64(file) for _ in range(numfolders)] + pid = file.read(1) + else: + self.num_unpackstreams_folders = [1] * numfolders + if pid == Property.SIZE: + self.unpacksizes = [] + for i in range(len(self.num_unpackstreams_folders)): + totalsize = 0 # type: int + for j in range(1, self.num_unpackstreams_folders[i]): + size = read_uint64(file) + self.unpacksizes.append(size) + totalsize += size + self.unpacksizes.append(folders[i].get_unpack_size() - totalsize) + pid = file.read(1) + num_digests = 0 + num_digests_total = 0 + for i in range(numfolders): + numsubstreams = self.num_unpackstreams_folders[i] + if numsubstreams != 1 or not folders[i].digestdefined: + num_digests += numsubstreams + num_digests_total += numsubstreams + if pid == Property.CRC: + defined = read_boolean(file, num_digests, checkall=True) + crcs = read_crcs(file, num_digests) + didx = 0 + for i in range(numfolders): + folder = folders[i] + numsubstreams = self.num_unpackstreams_folders[i] + if numsubstreams == 1 and folder.digestdefined and folder.crc is not None: + self.digestsdefined.append(True) + self.digests.append(folder.crc) + else: + for j in range(numsubstreams): + self.digestsdefined.append(defined[didx]) + self.digests.append(crcs[didx]) + didx += 1 + pid = file.read(1) + if pid != Property.END: + raise Bad7zFile('end id expected but %r found' % pid) + if not self.digestsdefined: + self.digestsdefined = [False] * num_digests_total + self.digests = [0] * num_digests_total + + def write(self, file: BinaryIO, numfolders: int): + assert self.num_unpackstreams_folders is not None + if len(self.num_unpackstreams_folders) == 0: + # nothing to write + return + if self.unpacksizes is None: + raise ValueError + write_byte(file, Property.SUBSTREAMS_INFO) + if not functools.reduce(lambda x, y: x and (y == 1), self.num_unpackstreams_folders, True): + write_byte(file, Property.NUM_UNPACK_STREAM) + for n in self.num_unpackstreams_folders: + write_uint64(file, n) + write_byte(file, Property.SIZE) + idx = 0 + for i in range(numfolders): + for j in range(1, self.num_unpackstreams_folders[i]): + size = self.unpacksizes[idx] + write_uint64(file, size) + idx += 1 + idx += 1 + if functools.reduce(lambda x, y: x or y, self.digestsdefined, False): + write_byte(file, 
Property.CRC) + write_boolean(file, self.digestsdefined, all_defined=True) + write_crcs(file, self.digests) + write_byte(file, Property.END) + + +class StreamsInfo: + """ information about compressed streams """ + + __slots__ = ['packinfo', 'unpackinfo', 'substreamsinfo'] + + def __init__(self): + self.packinfo = None # type: PackInfo + self.unpackinfo = None # type: UnpackInfo + self.substreamsinfo = None # type: Optional[SubstreamsInfo] + + @classmethod + def retrieve(cls, file: BinaryIO): + obj = cls() + obj.read(file) + return obj + + def read(self, file: BinaryIO) -> None: + pid = file.read(1) + if pid == Property.PACK_INFO: + self.packinfo = PackInfo.retrieve(file) + pid = file.read(1) + if pid == Property.UNPACK_INFO: + self.unpackinfo = UnpackInfo.retrieve(file) + pid = file.read(1) + if pid == Property.SUBSTREAMS_INFO: + self.substreamsinfo = SubstreamsInfo.retrieve(file, self.unpackinfo.numfolders, self.unpackinfo.folders) + pid = file.read(1) + if pid != Property.END: + raise Bad7zFile('end id expected but %s found' % repr(pid)) + + def write(self, file: BinaryIO): + write_byte(file, Property.MAIN_STREAMS_INFO) + self._write(file) + + def _write(self, file: BinaryIO): + if self.packinfo is not None: + self.packinfo.write(file) + if self.unpackinfo is not None: + self.unpackinfo.write(file) + if self.substreamsinfo is not None: + self.substreamsinfo.write(file, self.unpackinfo.numfolders) + write_byte(file, Property.END) + + +class HeaderStreamsInfo(StreamsInfo): + + def __init__(self): + super().__init__() + self.packinfo = PackInfo() + self.unpackinfo = UnpackInfo() + folder = Folder() + folder.compressor = SevenZipCompressor() + folder.coders = folder.compressor.coders + folder.solid = False + folder.digestdefined = False + folder.bindpairs = [] + folder.totalin = 1 + folder.totalout = 1 + folder.digestdefined = [True] + self.unpackinfo.numfolders = 1 + self.unpackinfo.folders = [folder] + + def write(self, file: BinaryIO): + self._write(file) + + +class FilesInfo: + """ holds file properties """ + + __slots__ = ['files', 'emptyfiles', 'antifiles'] + + def __init__(self): + self.files = [] # type: List[Dict[str, Any]] + self.emptyfiles = [] # type: List[bool] + self.antifiles = None + + @classmethod + def retrieve(cls, file: BinaryIO): + obj = cls() + obj._read(file) + return obj + + def _read(self, fp: BinaryIO): + numfiles = read_uint64(fp) + self.files = [{'emptystream': False} for _ in range(numfiles)] + numemptystreams = 0 + while True: + prop = fp.read(1) + if prop == Property.END: + break + size = read_uint64(fp) + if prop == Property.DUMMY: + # Added by newer versions of 7z to adjust padding. 
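+ # the kDummy payload is alignment padding only, so skip it rather than parse it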
+ fp.seek(size, os.SEEK_CUR) + continue + buffer = io.BytesIO(fp.read(size)) + if prop == Property.EMPTY_STREAM: + isempty = read_boolean(buffer, numfiles, checkall=False) + list(map(lambda x, y: x.update({'emptystream': y}), self.files, isempty)) # type: ignore + numemptystreams += isempty.count(True) + elif prop == Property.EMPTY_FILE: + self.emptyfiles = read_boolean(buffer, numemptystreams, checkall=False) + elif prop == Property.ANTI: + self.antifiles = read_boolean(buffer, numemptystreams, checkall=False) + elif prop == Property.NAME: + external = buffer.read(1) + if external == b'\x00': + self._read_name(buffer) + else: + dataindex = read_uint64(buffer) + current_pos = fp.tell() + fp.seek(dataindex, 0) + self._read_name(fp) + fp.seek(current_pos, 0) + elif prop == Property.CREATION_TIME: + self._read_times(buffer, 'creationtime') + elif prop == Property.LAST_ACCESS_TIME: + self._read_times(buffer, 'lastaccesstime') + elif prop == Property.LAST_WRITE_TIME: + self._read_times(buffer, 'lastwritetime') + elif prop == Property.ATTRIBUTES: + defined = read_boolean(buffer, numfiles, checkall=True) + external = buffer.read(1) + if external == b'\x00': + self._read_attributes(buffer, defined) + else: + dataindex = read_uint64(buffer) + # try to read external data + current_pos = fp.tell() + fp.seek(dataindex, 0) + self._read_attributes(fp, defined) + fp.seek(current_pos, 0) + elif prop == Property.START_POS: + self._read_start_pos(buffer) + else: + raise Bad7zFile('invalid type %r' % prop) + + def _read_name(self, buffer: BinaryIO) -> None: + for f in self.files: + f['filename'] = read_utf16(buffer).replace('\\', '/') + + def _read_attributes(self, buffer: BinaryIO, defined: List[bool]) -> None: + for idx, f in enumerate(self.files): + f['attributes'] = read_uint32(buffer)[0] if defined[idx] else None + + def _read_times(self, fp: BinaryIO, name: str) -> None: + defined = read_boolean(fp, len(self.files), checkall=True) + # NOTE: the "external" flag is currently ignored, should be 0x00 + external = fp.read(1) + assert external == b'\x00' + for i, f in enumerate(self.files): + f[name] = ArchiveTimestamp(read_real_uint64(fp)[0]) if defined[i] else None + + def _read_start_pos(self, fp: BinaryIO) -> None: + defined = read_boolean(fp, len(self.files), checkall=True) + # NOTE: the "external" flag is currently ignored, should be 0x00 + external = fp.read(1) + assert external == 0x00 + for i, f in enumerate(self.files): + f['startpos'] = read_real_uint64(fp)[0] if defined[i] else None + + def _write_times(self, fp: BinaryIO, propid, name: str) -> None: + write_byte(fp, propid) + defined = [] # type: List[bool] + num_defined = 0 # type: int + for f in self.files: + if name in f.keys(): + if f[name] is not None: + defined.append(True) + num_defined += 1 + size = num_defined * 8 + 2 + if not reduce(and_, defined, True): + size += bits_to_bytes(num_defined) + write_uint64(fp, size) + write_boolean(fp, defined, all_defined=True) + write_byte(fp, b'\x00') + for i, file in enumerate(self.files): + if defined[i]: + write_real_uint64(fp, ArchiveTimestamp.from_datetime(file[name])) + else: + pass + + def _write_prop_bool_vector(self, fp: BinaryIO, propid, vector) -> None: + write_byte(fp, propid) + write_boolean(fp, vector, all_defined=True) + + @staticmethod + def _are_there(vector) -> bool: + if vector is not None: + if functools.reduce(or_, vector, False): + return True + return False + + def _write_names(self, file: BinaryIO): + name_defined = 0 + names = [] + name_size = 0 + for f in self.files: + if 
f.get('filename', None) is not None: + name_defined += 1 + names.append(f['filename']) + name_size += len(f['filename'].encode('utf-16LE')) + 2 # len(str + NULL_WORD) + if name_defined > 0: + write_byte(file, Property.NAME) + write_uint64(file, name_size + 1) + write_byte(file, b'\x00') + for n in names: + write_utf16(file, n) + + def _write_attributes(self, file): + defined = [] # type: List[bool] + num_defined = 0 + for f in self.files: + if 'attributes' in f.keys() and f['attributes'] is not None: + defined.append(True) + num_defined += 1 + else: + defined.append(False) + size = num_defined * 4 + 2 + if num_defined != len(defined): + size += bits_to_bytes(num_defined) + write_byte(file, Property.ATTRIBUTES) + write_uint64(file, size) + write_boolean(file, defined, all_defined=True) + write_byte(file, b'\x00') + for i, f in enumerate(self.files): + if defined[i]: + write_uint32(file, f['attributes']) + + def write(self, file: BinaryIO): + assert self.files is not None + write_byte(file, Property.FILES_INFO) + numfiles = len(self.files) + write_uint64(file, numfiles) + emptystreams = [] # List[bool] + for f in self.files: + emptystreams.append(f['emptystream']) + if self._are_there(emptystreams): + write_byte(file, Property.EMPTY_STREAM) + write_uint64(file, bits_to_bytes(numfiles)) + write_boolean(file, emptystreams, all_defined=False) + else: + if self._are_there(self.emptyfiles): + self._write_prop_bool_vector(file, Property.EMPTY_FILE, self.emptyfiles) + if self._are_there(self.antifiles): + self._write_prop_bool_vector(file, Property.ANTI, self.antifiles) + # Name + self._write_names(file) + # timestamps + self._write_times(file, Property.CREATION_TIME, 'creationtime') + self._write_times(file, Property.LAST_ACCESS_TIME, 'lastaccesstime') + self._write_times(file, Property.LAST_WRITE_TIME, 'lastwritetime') + # start_pos + # FIXME: TBD + # attribute + self._write_attributes(file) + write_byte(file, Property.END) + + +class Header: + """ the archive header """ + + __slot__ = ['solid', 'properties', 'additional_streams', 'main_streams', 'files_info', + 'size', '_start_pos'] + + def __init__(self) -> None: + self.solid = False + self.properties = None + self.additional_streams = None + self.main_streams = None + self.files_info = None + self.size = 0 # fixme. Not implemented yet + self._start_pos = 0 + + @classmethod + def retrieve(cls, fp: BinaryIO, buffer: BytesIO, start_pos: int): + obj = cls() + obj._read(fp, buffer, start_pos) + return obj + + def _read(self, fp: BinaryIO, buffer: BytesIO, start_pos: int) -> None: + self._start_pos = start_pos + fp.seek(self._start_pos) + self._decode_header(fp, buffer) + + def _decode_header(self, fp: BinaryIO, buffer: BytesIO) -> None: + """ + Decode header data or encoded header data from buffer. + When buffer consist of encoded buffer, it get stream data + from it and call itself recursively + """ + pid = buffer.read(1) + if not pid: + # empty archive + return + elif pid == Property.HEADER: + self._extract_header_info(buffer) + return + elif pid != Property.ENCODED_HEADER: + raise TypeError('Unknown field: %r' % id) + # get from encoded header + streams = HeaderStreamsInfo.retrieve(buffer) + self._decode_header(fp, self._get_headerdata_from_streams(fp, streams)) + + def _get_headerdata_from_streams(self, fp: BinaryIO, streams: StreamsInfo) -> BytesIO: + """get header data from given streams.unpackinfo and packinfo. 
+ folder data are stored in raw data positioned in afterheader.""" + buffer = io.BytesIO() + src_start = self._start_pos + for folder in streams.unpackinfo.folders: + uncompressed = folder.unpacksizes + if not isinstance(uncompressed, (list, tuple)): + uncompressed = [uncompressed] * len(folder.coders) + compressed_size = streams.packinfo.packsizes[0] + uncompressed_size = uncompressed[-1] + + src_start += streams.packinfo.packpos + fp.seek(src_start, 0) + decompressor = folder.get_decompressor(compressed_size) + folder_data = decompressor.decompress(fp.read(compressed_size))[:uncompressed_size] + src_start += uncompressed_size + if folder.digestdefined: + if folder.crc != calculate_crc32(folder_data): + raise Bad7zFile('invalid block data') + buffer.write(folder_data) + buffer.seek(0, 0) + return buffer + + def _encode_header(self, file: BinaryIO, afterheader: int): + startpos = file.tell() + packpos = startpos - afterheader + buf = io.BytesIO() + _, raw_header_len, raw_crc = self.write(buf, 0, False) + streams = HeaderStreamsInfo() + streams.packinfo.packpos = packpos + folder = streams.unpackinfo.folders[0] + folder.crc = [raw_crc] + folder.unpacksizes = [raw_header_len] + compressed_len = 0 + buf.seek(0, 0) + data = buf.read(io.DEFAULT_BUFFER_SIZE) + while data: + out = folder.compressor.compress(data) + compressed_len += len(out) + file.write(out) + data = buf.read(io.DEFAULT_BUFFER_SIZE) + out = folder.compressor.flush() + compressed_len += len(out) + file.write(out) + # + streams.packinfo.packsizes = [compressed_len] + # actual header start position + startpos = file.tell() + write_byte(file, Property.ENCODED_HEADER) + streams.write(file) + write_byte(file, Property.END) + return startpos + + def write(self, file: BinaryIO, afterheader: int, encoded: bool = True): + startpos = file.tell() + if encoded: + startpos = self._encode_header(file, afterheader) + else: + write_byte(file, Property.HEADER) + # Archive properties + if self.main_streams is not None: + self.main_streams.write(file) + # Files Info + if self.files_info is not None: + self.files_info.write(file) + if self.properties is not None: + self.properties.write(file) + # AdditionalStreams + if self.additional_streams is not None: + self.additional_streams.write(file) + write_byte(file, Property.END) + endpos = file.tell() + header_len = endpos - startpos + file.seek(startpos, io.SEEK_SET) + crc = calculate_crc32(file.read(header_len)) + file.seek(endpos, io.SEEK_SET) + return startpos, header_len, crc + + def _extract_header_info(self, fp: BinaryIO) -> None: + pid = fp.read(1) + if pid == Property.ARCHIVE_PROPERTIES: + self.properties = ArchiveProperties.retrieve(fp) + pid = fp.read(1) + if pid == Property.ADDITIONAL_STREAMS_INFO: + self.additional_streams = StreamsInfo.retrieve(fp) + pid = fp.read(1) + if pid == Property.MAIN_STREAMS_INFO: + self.main_streams = StreamsInfo.retrieve(fp) + pid = fp.read(1) + if pid == Property.FILES_INFO: + self.files_info = FilesInfo.retrieve(fp) + pid = fp.read(1) + if pid != Property.END: + raise Bad7zFile('end id expected but %s found' % (repr(pid))) + + @staticmethod + def build_header(folders): + header = Header() + header.files_info = FilesInfo() + header.main_streams = StreamsInfo() + header.main_streams.packinfo = PackInfo() + header.main_streams.packinfo.numstreams = 0 + header.main_streams.packinfo.packpos = 0 + header.main_streams.unpackinfo = UnpackInfo() + header.main_streams.unpackinfo.numfolders = len(folders) + header.main_streams.unpackinfo.folders = folders + 
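# substreams info starts as bare bookkeeping here; Worker.archive() in compression.py fills in digests and unpacksizes per file +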
header.main_streams.substreamsinfo = SubstreamsInfo() + header.main_streams.substreamsinfo.num_unpackstreams_folders = [len(folders)] + header.main_streams.substreamsinfo.unpacksizes = [] + return header + + +class SignatureHeader: + """The SignatureHeader class hold information of a signature header of archive.""" + + __slots__ = ['version', 'startheadercrc', 'nextheaderofs', 'nextheadersize', 'nextheadercrc'] + + def __init__(self) -> None: + self.version = (P7ZIP_MAJOR_VERSION, P7ZIP_MINOR_VERSION) # type: Tuple[bytes, ...] + self.startheadercrc = None # type: Optional[int] + self.nextheaderofs = None # type: Optional[int] + self.nextheadersize = None # type: Optional[int] + self.nextheadercrc = None # type: Optional[int] + + @classmethod + def retrieve(cls, file: BinaryIO): + obj = cls() + obj._read(file) + return obj + + def _read(self, file: BinaryIO) -> None: + file.seek(len(MAGIC_7Z), 0) + self.version = read_bytes(file, 2) + self.startheadercrc, _ = read_uint32(file) + self.nextheaderofs, data = read_real_uint64(file) + crc = calculate_crc32(data) + self.nextheadersize, data = read_real_uint64(file) + crc = calculate_crc32(data, crc) + self.nextheadercrc, data = read_uint32(file) + crc = calculate_crc32(data, crc) + if crc != self.startheadercrc: + raise Bad7zFile('invalid header data') + + def calccrc(self, length: int, header_crc: int): + self.nextheadersize = length + self.nextheadercrc = header_crc + assert self.nextheaderofs is not None + buf = io.BytesIO() + write_real_uint64(buf, self.nextheaderofs) + write_real_uint64(buf, self.nextheadersize) + write_uint32(buf, self.nextheadercrc) + startdata = buf.getvalue() + self.startheadercrc = calculate_crc32(startdata) + + def write(self, file: BinaryIO): + assert self.startheadercrc is not None + assert self.nextheadercrc is not None + assert self.nextheaderofs is not None + assert self.nextheadersize is not None + file.seek(0, 0) + write_bytes(file, MAGIC_7Z) + write_byte(file, self.version[0]) + write_byte(file, self.version[1]) + write_uint32(file, self.startheadercrc) + write_real_uint64(file, self.nextheaderofs) + write_real_uint64(file, self.nextheadersize) + write_uint32(file, self.nextheadercrc) + + def _write_skelton(self, file: BinaryIO): + file.seek(0, 0) + write_bytes(file, MAGIC_7Z) + write_byte(file, self.version[0]) + write_byte(file, self.version[1]) + write_uint32(file, 1) + write_real_uint64(file, 2) + write_real_uint64(file, 3) + write_uint32(file, 4) + + +class FinishHeader(): + """Finish header for multi-volume 7z file.""" + + def __init__(self): + self.archive_start_offset = None # data offset from end of the finish header + self.additional_start_block_size = None # start signature & start header size + self.finish_header_size = 20 + 16 + + @classmethod + def retrieve(cls, file): + obj = cls() + obj._read(file) + return obj + + def _read(self, file): + self.archive_start_offset = read_uint64(file) + self.additional_start_block_size = read_uint64(file) diff --git a/py7zr/callbacks.py b/py7zr/callbacks.py new file mode 100644 index 0000000..6b2c083 --- /dev/null +++ b/py7zr/callbacks.py @@ -0,0 +1,61 @@ +#!/usr/bin/python -u +# +# p7zr library +# +# Copyright (c) 2020 Hiroshi Miura +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. 
+# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# + +from abc import ABC, abstractmethod + + +class Callback(ABC): + """Abstract base class for progress callbacks.""" + + @abstractmethod + def report_start_preparation(self): + """report a start of preparation event such as making a list of files and looking into their properties.""" + pass + + @abstractmethod + def report_start(self, processing_file_path, processing_bytes): + """report a start event of specified archive file and its input bytes.""" + pass + + @abstractmethod + def report_end(self, processing_file_path, wrote_bytes): + """report an end event of specified archive file and its output bytes.""" + pass + + @abstractmethod + def report_warning(self, message): + """report a warning event with its message.""" + pass + + @abstractmethod + def report_postprocess(self): + """report a start of post processing event such as setting file properties and permissions or creating symlinks.""" + pass + + +class ExtractCallback(Callback): + """Abstract base class for extraction progress callbacks.""" + pass + + +class ArchiveCallback(Callback): + """Abstract base class for archiving progress callbacks.""" + pass diff --git a/py7zr/cli.py b/py7zr/cli.py new file mode 100644 index 0000000..3d7808f --- /dev/null +++ b/py7zr/cli.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python +# +# Pure python p7zr implementation +# Copyright (C) 2019, 2020 Hiroshi Miura +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details.
+# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +import argparse +import getpass +import os +import pathlib +import re +import shutil +import sys +from lzma import CHECK_CRC64, CHECK_SHA256, is_check_supported +from typing import Any, List, Optional + +import texttable # type: ignore + +import py7zr +from py7zr.callbacks import ExtractCallback +from py7zr.helpers import Local +from py7zr.properties import READ_BLOCKSIZE, SupportedMethods + + +class CliExtractCallback(ExtractCallback): + + def __init__(self, total_bytes, ofd=sys.stdout): + self.ofd = ofd + self.archive_total = total_bytes + self.total_bytes = 0 + self.columns, _ = shutil.get_terminal_size(fallback=(80, 24)) + self.pwidth = 0 + + def report_start_preparation(self): + pass + + def report_start(self, processing_file_path, processing_bytes): + self.ofd.write('- {}'.format(processing_file_path)) + self.pwidth += len(processing_file_path) + 2 + + def report_end(self, processing_file_path, wrote_bytes): + self.total_bytes += int(wrote_bytes) + plest = self.columns - self.pwidth + progress = self.total_bytes / self.archive_total + msg = '({:.0%})\n'.format(progress) + if plest - len(msg) > 0: + self.ofd.write(msg.rjust(plest)) + else: + self.ofd.write(msg) + self.pwidth = 0 + + def report_postprocess(self): + pass + + def report_warning(self, message): + pass + + +class Cli(): + + dunits = {'b': 1, 'B': 1, 'k': 1024, 'K': 1024, 'm': 1024 * 1024, 'M': 1024 * 1024, + 'g': 1024 * 1024 * 1024, 'G': 1024 * 1024 * 1024} + + def __init__(self): + self.parser = self._create_parser() + self.unit_pattern = re.compile(r'^([0-9]+)([bkmg]?)$', re.IGNORECASE) + + def run(self, arg: Optional[Any] = None) -> int: + args = self.parser.parse_args(arg) + return args.func(args) + + def _create_parser(self): + parser = argparse.ArgumentParser(prog='py7zr', description='py7zr', + formatter_class=argparse.RawTextHelpFormatter, add_help=True) + subparsers = parser.add_subparsers(title='subcommands', help='subcommand for py7zr l .. list, x .. extract,' + ' t .. check integrity, i .. 
information') + list_parser = subparsers.add_parser('l') + list_parser.set_defaults(func=self.run_list) + list_parser.add_argument("arcfile", help="7z archive file") + list_parser.add_argument("--verbose", action="store_true", help="verbose output") + extract_parser = subparsers.add_parser('x') + extract_parser.set_defaults(func=self.run_extract) + extract_parser.add_argument("arcfile", help="7z archive file") + extract_parser.add_argument("odir", nargs="?", help="output directory") + extract_parser.add_argument("-P", "--password", action="store_true", + help="Password protected archive(you will be asked a password).") + extract_parser.add_argument("--verbose", action="store_true", help="verbose output") + create_parser = subparsers.add_parser('c') + create_parser.set_defaults(func=self.run_create) + create_parser.add_argument("arcfile", help="7z archive file") + create_parser.add_argument("filenames", nargs="+", help="filenames to archive") + create_parser.add_argument("-v", "--volume", nargs=1, help="Create volumes.") + test_parser = subparsers.add_parser('t') + test_parser.set_defaults(func=self.run_test) + test_parser.add_argument("arcfile", help="7z archive file") + info_parser = subparsers.add_parser("i") + info_parser.set_defaults(func=self.run_info) + parser.set_defaults(func=self.show_help) + return parser + + def show_help(self, args): + self.parser.print_help() + return(0) + + def run_info(self, args): + print("py7zr version {} {}".format(py7zr.__version__, py7zr.__copyright__)) + print("Formats:") + table = texttable.Texttable() + table.set_deco(texttable.Texttable.HEADER) + table.set_cols_dtype(['t', 't']) + table.set_cols_align(["l", "r"]) + for f in SupportedMethods.formats: + m = ''.join(' {:02x}'.format(x) for x in f['magic']) + table.add_row([f['name'], m]) + print(table.draw()) + print("\nCodecs:") + table = texttable.Texttable() + table.set_deco(texttable.Texttable.HEADER) + table.set_cols_dtype(['t', 't']) + table.set_cols_align(["l", "r"]) + for c in SupportedMethods.codecs: + m = ''.join('{:02x}'.format(x) for x in c['id']) + table.add_row([m, c['name']]) + print(table.draw()) + print("\nChecks:") + print("CHECK_NONE") + print("CHECK_CRC32") + if is_check_supported(CHECK_CRC64): + print("CHECK_CRC64") + if is_check_supported(CHECK_SHA256): + print("CHECK_SHA256") + + def run_list(self, args): + """Print a table of contents to file. 
""" + target = args.arcfile + verbose = args.verbose + if not py7zr.is_7zfile(target): + print('not a 7z file') + return(1) + with open(target, 'rb') as f: + a = py7zr.SevenZipFile(f) + file = sys.stdout + archive_info = a.archiveinfo() + archive_list = a.list() + if verbose: + file.write("Listing archive: {}\n".format(target)) + file.write("--\n") + file.write("Path = {}\n".format(archive_info.filename)) + file.write("Type = 7z\n") + fstat = os.stat(archive_info.filename) + file.write("Phisical Size = {}\n".format(fstat.st_size)) + file.write("Headers Size = {}\n".format(archive_info.header_size)) + file.write("Method = {}\n".format(archive_info.method_names)) + if archive_info.solid: + file.write("Solid = {}\n".format('+')) + else: + file.write("Solid = {}\n".format('-')) + file.write("Blocks = {}\n".format(archive_info.blocks)) + file.write('\n') + file.write( + 'total %d files and directories in %sarchive\n' % (len(archive_list), + (archive_info.solid and 'solid ') or '')) + file.write(' Date Time Attr Size Compressed Name\n') + file.write('------------------- ----- ------------ ------------ ------------------------\n') + for f in archive_list: + if f.creationtime is not None: + creationdate = f.creationtime.astimezone(Local).strftime("%Y-%m-%d") + creationtime = f.creationtime.astimezone(Local).strftime("%H:%M:%S") + else: + creationdate = ' ' + creationtime = ' ' + if f.is_directory: + attrib = 'D...' + else: + attrib = '....' + if f.archivable: + attrib += 'A' + else: + attrib += '.' + if f.is_directory: + extra = ' 0 ' + elif f.compressed is None: + extra = ' ' + else: + extra = '%12d ' % (f.compressed) + file.write('%s %s %s %12d %s %s\n' % (creationdate, creationtime, attrib, + f.uncompressed, extra, f.filename)) + file.write('------------------- ----- ------------ ------------ ------------------------\n') + + return(0) + + @staticmethod + def print_archiveinfo(archive, file): + file.write("--\n") + file.write("Path = {}\n".format(archive.filename)) + file.write("Type = 7z\n") + fstat = os.stat(archive.filename) + file.write("Phisical Size = {}\n".format(fstat.st_size)) + file.write("Headers Size = {}\n".format(archive.header.size)) # fixme. 
+ file.write("Method = {}\n".format(archive._get_method_names())) + if archive._is_solid(): + file.write("Solid = {}\n".format('+')) + else: + file.write("Solid = {}\n".format('-')) + file.write("Blocks = {}\n".format(len(archive.header.main_streams.unpackinfo.folders))) + + def run_test(self, args): + target = args.arcfile + if not py7zr.is_7zfile(target): + print('not a 7z file') + return(1) + with open(target, 'rb') as f: + a = py7zr.SevenZipFile(f) + file = sys.stdout + file.write("Testing archive: {}\n".format(a.filename)) + self.print_archiveinfo(archive=a, file=file) + file.write('\n') + if a.testzip() is None: + file.write('Everything is Ok\n') + return(0) + else: + file.write('Bad 7zip file\n') + return(1) + + def run_extract(self, args: argparse.Namespace) -> int: + target = args.arcfile + verbose = args.verbose + if not py7zr.is_7zfile(target): + print('not a 7z file') + return(1) + if not args.password: + password = None # type: Optional[str] + else: + try: + password = getpass.getpass() + except getpass.GetPassWarning: + sys.stderr.write('Warning: your password may be shown.\n') + return(1) + a = py7zr.SevenZipFile(target, 'r', password=password) + cb = None # Optional[ExtractCallback] + if verbose: + archive_info = a.archiveinfo() + cb = CliExtractCallback(total_bytes=archive_info.uncompressed, ofd=sys.stderr) + if args.odir: + a.extractall(path=args.odir, callback=cb) + else: + a.extractall(callback=cb) + return(0) + + def _check_volumesize_valid(self, size: str) -> bool: + if self.unit_pattern.match(size): + return True + else: + return False + + def _volumesize_unitconv(self, size: str) -> int: + m = self.unit_pattern.match(size) + num = m.group(1) + unit = m.group(2) + return int(num) if unit is None else int(num) * self.dunits[unit] + + def run_create(self, args): + sztarget = args.arcfile # type: str + filenames = args.filenames # type: List[str] + volume_size = args.volume[0] if getattr(args, 'volume', None) is not None else None + if volume_size is not None and not self._check_volumesize_valid(volume_size): + sys.stderr.write('Error: Specified volume size is invalid.\n') + self.show_help(args) + exit(1) + if not sztarget.endswith('.7z'): + sztarget += '.7z' + target = pathlib.Path(sztarget) + if target.exists(): + sys.stderr.write('Archive file exists!\n') + self.show_help(args) + exit(1) + with py7zr.SevenZipFile(target, 'w') as szf: + for path in filenames: + src = pathlib.Path(path) + if src.is_dir(): + szf.writeall(src) + else: + szf.write(src) + if volume_size is None: + return (0) + size = self._volumesize_unitconv(volume_size) + self._split_file(target, size) + target.unlink() + return(0) + + def _split_file(self, filepath, size): + chapters = 0 + written = [0, 0] + total_size = filepath.stat().st_size + with filepath.open('rb') as src: + while written[0] <= total_size: + with open(str(filepath) + '.%03d' % chapters, 'wb') as tgt: + written[1] = 0 + while written[1] < size: + read_size = min(READ_BLOCKSIZE, size - written[1]) + tgt.write(src.read(read_size)) + written[1] += read_size + written[0] += read_size + chapters += 1 diff --git a/py7zr/compression.py b/py7zr/compression.py new file mode 100644 index 0000000..4ba303a --- /dev/null +++ b/py7zr/compression.py @@ -0,0 +1,395 @@ +#!/usr/bin/python -u +# +# p7zr library +# +# Copyright (c) 2019 Hiroshi Miura +# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de +# 7-Zip Copyright (C) 1999-2010 Igor Pavlov +# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov +# +# This library is free software; you can 
redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# +import bz2 +import io +import lzma +import os +import queue +import sys +import threading +from typing import IO, Any, BinaryIO, Dict, List, Optional, Union + +from py7zr.exceptions import Bad7zFile, CrcError, UnsupportedCompressionMethodError +from py7zr.extra import AESDecompressor, CopyDecompressor, DeflateDecompressor, ISevenZipDecompressor, ZstdDecompressor +from py7zr.helpers import MemIO, NullIO, calculate_crc32, readlink +from py7zr.properties import READ_BLOCKSIZE, ArchivePassword, CompressionMethod + +if sys.version_info < (3, 6): + import pathlib2 as pathlib +else: + import pathlib +try: + import zstandard as Zstd # type: ignore +except ImportError: + Zstd = None + + +class Worker: + """Extract worker class to invoke handler""" + + def __init__(self, files, src_start: int, header) -> None: + self.target_filepath = {} # type: Dict[int, Union[MemIO, pathlib.Path, None]] + self.files = files + self.src_start = src_start + self.header = header + + def extract(self, fp: BinaryIO, parallel: bool, q=None) -> None: + """Extract worker method to handle 7zip folder and decompress each files.""" + if hasattr(self.header, 'main_streams') and self.header.main_streams is not None: + src_end = self.src_start + self.header.main_streams.packinfo.packpositions[-1] + numfolders = self.header.main_streams.unpackinfo.numfolders + if numfolders == 1: + self.extract_single(fp, self.files, self.src_start, src_end, q) + else: + folders = self.header.main_streams.unpackinfo.folders + positions = self.header.main_streams.packinfo.packpositions + empty_files = [f for f in self.files if f.emptystream] + if not parallel: + self.extract_single(fp, empty_files, 0, 0, q) + for i in range(numfolders): + self.extract_single(fp, folders[i].files, self.src_start + positions[i], + self.src_start + positions[i + 1], q) + else: + filename = getattr(fp, 'name', None) + self.extract_single(open(filename, 'rb'), empty_files, 0, 0, q) + extract_threads = [] + for i in range(numfolders): + p = threading.Thread(target=self.extract_single, + args=(filename, folders[i].files, + self.src_start + positions[i], self.src_start + positions[i + 1], q)) + p.start() + extract_threads.append((p)) + for p in extract_threads: + p.join() + else: + empty_files = [f for f in self.files if f.emptystream] + self.extract_single(fp, empty_files, 0, 0, q) + + def extract_single(self, fp: Union[BinaryIO, str], files, src_start: int, src_end: int, + q: Optional[queue.Queue]) -> None: + """Single thread extractor that takes file lists in single 7zip folder.""" + if files is None: + return + if isinstance(fp, str): + fp = open(fp, 'rb') + fp.seek(src_start) + for f in files: + if q is not None: + q.put(('s', str(f.filename), str(f.compressed) if f.compressed is not None else '0')) + fileish = self.target_filepath.get(f.id, None) + if fileish is not None: + 
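# fileish is either a pathlib.Path or an in-memory MemIO; both expose parent and open() +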
fileish.parent.mkdir(parents=True, exist_ok=True) + with fileish.open(mode='wb') as ofp: + if not f.emptystream: + # extract to file + crc32 = self.decompress(fp, f.folder, ofp, f.uncompressed[-1], f.compressed, src_end) + ofp.seek(0) + if f.crc32 is not None and crc32 != f.crc32: + raise CrcError("{}".format(f.filename)) + else: + pass # just create empty file + elif not f.emptystream: + # read and bin off a data but check crc + with NullIO() as ofp: + crc32 = self.decompress(fp, f.folder, ofp, f.uncompressed[-1], f.compressed, src_end) + if f.crc32 is not None and crc32 != f.crc32: + raise CrcError("{}".format(f.filename)) + if q is not None: + q.put(('e', str(f.filename), str(f.uncompressed[-1]))) + + def decompress(self, fp: BinaryIO, folder, fq: IO[Any], + size: int, compressed_size: Optional[int], src_end: int) -> int: + """decompressor wrapper called from extract method. + + :parameter fp: archive source file pointer + :parameter folder: Folder object that have decompressor object. + :parameter fq: output file pathlib.Path + :parameter size: uncompressed size of target file. + :parameter compressed_size: compressed size of target file. + :parameter src_end: end position of the folder + :returns CRC32 of the file + """ + assert folder is not None + crc32 = 0 + out_remaining = size + decompressor = folder.get_decompressor(compressed_size) + while out_remaining > 0: + max_length = min(out_remaining, io.DEFAULT_BUFFER_SIZE) + rest_size = src_end - fp.tell() + read_size = min(READ_BLOCKSIZE, rest_size) + if read_size == 0: + tmp = decompressor.decompress(b'', max_length) + if len(tmp) == 0: + raise Exception("decompression get wrong: no output data.") + else: + inp = fp.read(read_size) + tmp = decompressor.decompress(inp, max_length) + if len(tmp) > 0 and out_remaining >= len(tmp): + out_remaining -= len(tmp) + fq.write(tmp) + crc32 = calculate_crc32(tmp, crc32) + if out_remaining <= 0: + break + if fp.tell() >= src_end: + # Check folder.digest integrity. + if decompressor.crc is not None and not decompressor.check_crc(): + raise Bad7zFile("Folder CRC32 error.") + return crc32 + + def _find_link_target(self, target): + """Find the target member of a symlink or hardlink member in the archive. 
+ """ + targetname = target.as_posix() # type: str + linkname = readlink(targetname) + # Check windows full path symlinks + if linkname.startswith("\\\\?\\"): + linkname = linkname[4:] + # normalize as posix style + linkname = pathlib.Path(linkname).as_posix() # type: str + member = None + for j in range(len(self.files)): + if linkname == self.files[j].origin.as_posix(): + # FIXME: when API user specify arcname, it will break + member = os.path.relpath(linkname, os.path.dirname(targetname)) + break + if member is None: + member = linkname + return member + + def archive(self, fp: BinaryIO, folder, deref=False): + """Run archive task for specified 7zip folder.""" + compressor = folder.get_compressor() + outsize = 0 + self.header.main_streams.packinfo.numstreams = 1 + num_unpack_streams = 0 + self.header.main_streams.substreamsinfo.digests = [] + self.header.main_streams.substreamsinfo.digestsdefined = [] + last_file_index = 0 + foutsize = 0 + for i, f in enumerate(self.files): + file_info = f.file_properties() + self.header.files_info.files.append(file_info) + self.header.files_info.emptyfiles.append(f.emptystream) + foutsize = 0 + if f.is_symlink and not deref: + last_file_index = i + num_unpack_streams += 1 + link_target = self._find_link_target(f.origin) # type: str + tgt = link_target.encode('utf-8') # type: bytes + insize = len(tgt) + crc = calculate_crc32(tgt, 0) # type: int + out = compressor.compress(tgt) + outsize += len(out) + foutsize += len(out) + fp.write(out) + self.header.main_streams.substreamsinfo.digests.append(crc) + self.header.main_streams.substreamsinfo.digestsdefined.append(True) + self.header.main_streams.substreamsinfo.unpacksizes.append(insize) + self.header.files_info.files[i]['maxsize'] = foutsize + elif not f.emptystream: + last_file_index = i + num_unpack_streams += 1 + insize = 0 + with f.origin.open(mode='rb') as fd: + data = fd.read(READ_BLOCKSIZE) + insize += len(data) + crc = 0 + while data: + crc = calculate_crc32(data, crc) + out = compressor.compress(data) + outsize += len(out) + foutsize += len(out) + fp.write(out) + data = fd.read(READ_BLOCKSIZE) + insize += len(data) + self.header.main_streams.substreamsinfo.digests.append(crc) + self.header.main_streams.substreamsinfo.digestsdefined.append(True) + self.header.files_info.files[i]['maxsize'] = foutsize + self.header.main_streams.substreamsinfo.unpacksizes.append(insize) + else: + out = compressor.flush() + outsize += len(out) + foutsize += len(out) + fp.write(out) + if len(self.files) > 0: + self.header.files_info.files[last_file_index]['maxsize'] = foutsize + # Update size data in header + self.header.main_streams.packinfo.packsizes = [outsize] + folder.unpacksizes = [sum(self.header.main_streams.substreamsinfo.unpacksizes)] + self.header.main_streams.substreamsinfo.num_unpackstreams_folders = [num_unpack_streams] + + def register_filelike(self, id: int, fileish: Union[MemIO, pathlib.Path, None]) -> None: + """register file-ish to worker.""" + self.target_filepath[id] = fileish + + +class SevenZipDecompressor: + """Main decompressor object which is properly configured and bind to each 7zip folder. 
+ because 7zip folder can have a custom compression method""" + + lzma_methods_map = { + CompressionMethod.LZMA: lzma.FILTER_LZMA1, + CompressionMethod.LZMA2: lzma.FILTER_LZMA2, + CompressionMethod.DELTA: lzma.FILTER_DELTA, + CompressionMethod.P7Z_BCJ: lzma.FILTER_X86, + CompressionMethod.BCJ_ARM: lzma.FILTER_ARM, + CompressionMethod.BCJ_ARMT: lzma.FILTER_ARMTHUMB, + CompressionMethod.BCJ_IA64: lzma.FILTER_IA64, + CompressionMethod.BCJ_PPC: lzma.FILTER_POWERPC, + CompressionMethod.BCJ_SPARC: lzma.FILTER_SPARC, + } + + FILTER_BZIP2 = 0x31 + FILTER_ZIP = 0x32 + FILTER_COPY = 0x33 + FILTER_AES = 0x34 + FILTER_ZSTD = 0x35 + alt_methods_map = { + CompressionMethod.MISC_BZIP2: FILTER_BZIP2, + CompressionMethod.MISC_DEFLATE: FILTER_ZIP, + CompressionMethod.COPY: FILTER_COPY, + CompressionMethod.CRYPT_AES256_SHA256: FILTER_AES, + CompressionMethod.MISC_ZSTD: FILTER_ZSTD, + } + + def __init__(self, coders: List[Dict[str, Any]], size: int, crc: Optional[int]) -> None: + # Get password which was set when creation of py7zr.SevenZipFile object. + self.input_size = size + self.consumed = 0 # type: int + self.crc = crc + self.digest = None # type: Optional[int] + if self._check_lzma_coders(coders): + self._set_lzma_decompressor(coders) + else: + self._set_alternative_decompressor(coders) + + def _check_lzma_coders(self, coders: List[Dict[str, Any]]) -> bool: + res = True + for coder in coders: + if self.lzma_methods_map.get(coder['method'], None) is None: + res = False + break + return res + + def _set_lzma_decompressor(self, coders: List[Dict[str, Any]]) -> None: + filters = [] # type: List[Dict[str, Any]] + for coder in coders: + if coder['numinstreams'] != 1 or coder['numoutstreams'] != 1: + raise UnsupportedCompressionMethodError('Only a simple compression method is currently supported.') + filter_id = self.lzma_methods_map.get(coder['method'], None) + if filter_id is None: + raise UnsupportedCompressionMethodError + properties = coder.get('properties', None) + if properties is not None: + filters[:0] = [lzma._decode_filter_properties(filter_id, properties)] # type: ignore + else: + filters[:0] = [{'id': filter_id}] + self.decompressor = lzma.LZMADecompressor(format=lzma.FORMAT_RAW, filters=filters) # type: Union[bz2.BZ2Decompressor, lzma.LZMADecompressor, ISevenZipDecompressor] # noqa + + def _set_alternative_decompressor(self, coders: List[Dict[str, Any]]) -> None: + filter_id = self.alt_methods_map.get(coders[0]['method'], None) + if filter_id == self.FILTER_BZIP2: + self.decompressor = bz2.BZ2Decompressor() + elif filter_id == self.FILTER_ZIP: + self.decompressor = DeflateDecompressor() + elif filter_id == self.FILTER_COPY: + self.decompressor = CopyDecompressor() + elif filter_id == self.FILTER_ZSTD and Zstd: + self.decompressor = ZstdDecompressor() + elif filter_id == self.FILTER_AES: + password = ArchivePassword().get() + properties = coders[0].get('properties', None) + self.decompressor = AESDecompressor(properties, password, coders[1:]) + else: + raise UnsupportedCompressionMethodError + + def decompress(self, data: bytes, max_length: Optional[int] = None) -> bytes: + self.consumed += len(data) + if max_length is not None: + folder_data = self.decompressor.decompress(data, max_length=max_length) + else: + folder_data = self.decompressor.decompress(data) + # calculate CRC with uncompressed data + if self.crc is not None: + self.digest = calculate_crc32(folder_data, self.digest) + return folder_data + + def check_crc(self): + return self.crc == self.digest + + +class SevenZipCompressor: + + 
"""Main compressor object to configured for each 7zip folder.""" + + __slots__ = ['filters', 'compressor', 'coders'] + + lzma_methods_map_r = { + lzma.FILTER_LZMA2: CompressionMethod.LZMA2, + lzma.FILTER_DELTA: CompressionMethod.DELTA, + lzma.FILTER_X86: CompressionMethod.P7Z_BCJ, + } + + def __init__(self, filters=None): + if filters is None: + self.filters = [{"id": lzma.FILTER_LZMA2, "preset": 7 | lzma.PRESET_EXTREME}, ] + else: + self.filters = filters + self.compressor = lzma.LZMACompressor(format=lzma.FORMAT_RAW, filters=self.filters) + self.coders = [] + for filter in self.filters: + if filter is None: + break + method = self.lzma_methods_map_r[filter['id']] + properties = lzma._encode_filter_properties(filter) + self.coders.append({'method': method, 'properties': properties, 'numinstreams': 1, 'numoutstreams': 1}) + + def compress(self, data): + return self.compressor.compress(data) + + def flush(self): + return self.compressor.flush() + + +def get_methods_names(coders: List[dict]) -> List[str]: + """Return human readable method names for specified coders""" + methods_name_map = { + CompressionMethod.LZMA2: "LZMA2", + CompressionMethod.LZMA: "LZMA", + CompressionMethod.DELTA: "delta", + CompressionMethod.P7Z_BCJ: "BCJ", + CompressionMethod.BCJ_ARM: "BCJ(ARM)", + CompressionMethod.BCJ_ARMT: "BCJ(ARMT)", + CompressionMethod.BCJ_IA64: "BCJ(IA64)", + CompressionMethod.BCJ_PPC: "BCJ(POWERPC)", + CompressionMethod.BCJ_SPARC: "BCJ(SPARC)", + CompressionMethod.CRYPT_AES256_SHA256: "7zAES", + } + methods_names = [] # type: List[str] + for coder in coders: + try: + methods_names.append(methods_name_map[coder['method']]) + except KeyError: + raise UnsupportedCompressionMethodError("Unknown method {}".format(coder['method'])) + return methods_names diff --git a/py7zr/exceptions.py b/py7zr/exceptions.py new file mode 100644 index 0000000..4286266 --- /dev/null +++ b/py7zr/exceptions.py @@ -0,0 +1,46 @@ +# +# p7zr library +# +# Copyright (c) 2019 Hiroshi Miura +# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de +# 7-Zip Copyright (C) 1999-2010 Igor Pavlov +# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. 
+# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# + + +class ArchiveError(Exception): + pass + + +class Bad7zFile(ArchiveError): + pass + + +class CrcError(ArchiveError): + pass + + +class UnsupportedCompressionMethodError(ArchiveError): + pass + + +class DecompressionError(ArchiveError): + pass + + +class InternalError(ArchiveError): + pass diff --git a/py7zr/extra.py b/py7zr/extra.py new file mode 100644 index 0000000..309ea94 --- /dev/null +++ b/py7zr/extra.py @@ -0,0 +1,214 @@ +#!/usr/bin/python -u +# +# p7zr library +# +# Copyright (c) 2019 Hiroshi Miura +# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de +# 7-Zip Copyright (C) 1999-2010 Igor Pavlov +# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# +import lzma +import zlib +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Union + +from Crypto.Cipher import AES + +from py7zr import UnsupportedCompressionMethodError +from py7zr.helpers import Buffer, calculate_key +from py7zr.properties import READ_BLOCKSIZE, CompressionMethod + +try: + import zstandard as Zstd # type: ignore +except ImportError: + Zstd = None + + +class ISevenZipCompressor(ABC): + @abstractmethod + def compress(self, data: Union[bytes, bytearray, memoryview]) -> bytes: + pass + + @abstractmethod + def flush(self) -> bytes: + pass + + +class ISevenZipDecompressor(ABC): + @abstractmethod + def decompress(self, data: Union[bytes, bytearray, memoryview], max_length: int = -1) -> bytes: + pass + + +class DeflateDecompressor(ISevenZipDecompressor): + def __init__(self): + self.buf = b'' + self._decompressor = zlib.decompressobj(-15) + + def decompress(self, data: Union[bytes, bytearray, memoryview], max_length: int = -1): + if max_length < 0: + res = self.buf + self._decompressor.decompress(data) + self.buf = b'' + else: + tmp = self.buf + self._decompressor.decompress(data) + res = tmp[:max_length] + self.buf = tmp[max_length:] + return res + + +class CopyDecompressor(ISevenZipDecompressor): + + def __init__(self): + self._buf = bytes() + + def decompress(self, data: Union[bytes, bytearray, memoryview], max_length: int = -1) -> bytes: + if max_length < 0: + length = len(data) + else: + length = min(len(data), max_length) + buflen = len(self._buf) + if length > buflen: + res = self._buf + data[:length - buflen] + self._buf = data[length - buflen:] + else: + res = self._buf[:length] + self._buf = self._buf[length:] + data + return res + + +class AESDecompressor(ISevenZipDecompressor): + + lzma_methods_map = { + CompressionMethod.LZMA: lzma.FILTER_LZMA1, + CompressionMethod.LZMA2: lzma.FILTER_LZMA2, + CompressionMethod.DELTA: lzma.FILTER_DELTA, + 
CompressionMethod.P7Z_BCJ: lzma.FILTER_X86,
+        CompressionMethod.BCJ_ARM: lzma.FILTER_ARM,
+        CompressionMethod.BCJ_ARMT: lzma.FILTER_ARMTHUMB,
+        CompressionMethod.BCJ_IA64: lzma.FILTER_IA64,
+        CompressionMethod.BCJ_PPC: lzma.FILTER_POWERPC,
+        CompressionMethod.BCJ_SPARC: lzma.FILTER_SPARC,
+    }
+
+    def __init__(self, aes_properties: bytes, password: str, coders: List[Dict[str, Any]]) -> None:
+        byte_password = password.encode('utf-16LE')
+        firstbyte = aes_properties[0]
+        numcyclespower = firstbyte & 0x3f
+        if firstbyte & 0xc0 != 0:
+            saltsize = (firstbyte >> 7) & 1
+            ivsize = (firstbyte >> 6) & 1
+            secondbyte = aes_properties[1]
+            saltsize += (secondbyte >> 4)
+            ivsize += (secondbyte & 0x0f)
+            assert len(aes_properties) == 2 + saltsize + ivsize
+            salt = aes_properties[2:2 + saltsize]
+            iv = aes_properties[2 + saltsize:2 + saltsize + ivsize]
+            assert len(salt) == saltsize
+            assert len(iv) == ivsize
+            assert numcyclespower <= 24
+            if ivsize < 16:
+                iv += bytes('\x00' * (16 - ivsize), 'ascii')
+            key = calculate_key(byte_password, numcyclespower, salt, 'sha256')
+            if len(coders) > 0:
+                self.lzma_decompressor = self._set_lzma_decompressor(coders)  # type: Union[lzma.LZMADecompressor, CopyDecompressor]  # noqa
+            else:
+                self.lzma_decompressor = CopyDecompressor()
+            self.cipher = AES.new(key, AES.MODE_CBC, iv)
+            self.buf = Buffer(size=READ_BLOCKSIZE + 16)
+            self.flushed = False
+        else:
+            raise UnsupportedCompressionMethodError
+
+    # set pipeline decompressor
+    def _set_lzma_decompressor(self, coders: List[Dict[str, Any]]) -> lzma.LZMADecompressor:
+        filters = []  # type: List[Dict[str, Any]]
+        for coder in coders:
+            filter = self.lzma_methods_map.get(coder['method'], None)
+            if filter is not None:
+                properties = coder.get('properties', None)
+                if properties is not None:
+                    filters[:0] = [lzma._decode_filter_properties(filter, properties)]  # type: ignore
+                else:
+                    filters[:0] = [{'id': filter}]
+            else:
+                raise UnsupportedCompressionMethodError
+        return lzma.LZMADecompressor(format=lzma.FORMAT_RAW, filters=filters)
+
+    def decompress(self, data: Union[bytes, bytearray, memoryview], max_length: int = -1) -> bytes:
+        if len(data) == 0 and len(self.buf) == 0:  # action flush
+            return self.lzma_decompressor.decompress(b'', max_length)
+        elif len(data) == 0:  # action padding
+            self.flushed = True
+            # align = 16
+            # padlen = (align - offset % align) % align
+            #        = (align - (offset & (align - 1))) & (align - 1)
+            #        = -offset & (align - 1)
+            #        = -offset & (16 - 1) = -offset & 15
+            padlen = -len(self.buf) & 15
+            self.buf.add(bytes(padlen))
+            temp = self.cipher.decrypt(self.buf.view)  # type: bytes
+            self.buf.reset()
+            return self.lzma_decompressor.decompress(temp, max_length)
+        else:
+            currentlen = len(self.buf) + len(data)
+            nextpos = (currentlen // 16) * 16
+            if currentlen == nextpos:
+                self.buf.add(data)
+                temp = self.cipher.decrypt(self.buf.view)
+                self.buf.reset()
+                return self.lzma_decompressor.decompress(temp, max_length)
+            else:
+                buflen = len(self.buf)
+                temp2 = data[nextpos - buflen:]
+                self.buf.add(data[:nextpos - buflen])
+                temp = self.cipher.decrypt(self.buf.view)
+                self.buf.set(temp2)
+                return self.lzma_decompressor.decompress(temp, max_length)
+
+
+class ZstdDecompressor(ISevenZipDecompressor):
+
+    def __init__(self):
+        if Zstd is None:
+            raise UnsupportedCompressionMethodError
+        self.buf = b''  # type: bytes
+        self._ctc = Zstd.ZstdDecompressor()  # type: ignore
+
+    def decompress(self, data: Union[bytes, bytearray, memoryview], max_length: int = -1) -> bytes:
+        dobj = self._ctc.decompressobj()  # type: ignore
+        if max_length < 0:
+            res = self.buf + dobj.decompress(data)
+            self.buf = b''
+        else:
+            tmp = self.buf + dobj.decompress(data)
+            res = tmp[:max_length]
+            self.buf = tmp[max_length:]
+        return res
+
+
+class ZstdCompressor(ISevenZipCompressor):
+
+    def __init__(self):
+        if Zstd is None:
+            raise UnsupportedCompressionMethodError
+        self._ctc = Zstd.ZstdCompressor()  # type: ignore
+
+    def compress(self, data: Union[bytes, bytearray, memoryview]) -> bytes:
+        return self._ctc.compress(data)  # type: ignore
+
+    def flush(self):
+        pass
diff --git a/py7zr/helpers.py b/py7zr/helpers.py
new file mode 100644
index 0000000..0bd7eba
--- /dev/null
+++ b/py7zr/helpers.py
@@ -0,0 +1,397 @@
+#!/usr/bin/python -u
+#
+# p7zr library
+#
+# Copyright (c) 2019 Hiroshi Miura
+# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+#
+#
+
+import _hashlib  # type: ignore  # noqa
+import ctypes
+import os
+import pathlib
+import platform
+import sys
+import time as _time
+import zlib
+from datetime import datetime, timedelta, timezone, tzinfo
+from typing import BinaryIO, Optional, Union
+
+import py7zr.win32compat
+
+
+def calculate_crc32(data: bytes, value: Optional[int] = None, blocksize: int = 1024 * 1024) -> int:
+    """Calculate CRC32 of strings with arbitrary lengths."""
+    length = len(data)
+    pos = blocksize
+    if value:
+        value = zlib.crc32(data[:pos], value)
+    else:
+        value = zlib.crc32(data[:pos])
+    while pos < length:
+        value = zlib.crc32(data[pos:pos + blocksize], value)
+        pos += blocksize
+
+    return value & 0xffffffff
+
+
+def _calculate_key1(password: bytes, cycles: int, salt: bytes, digest: str) -> bytes:
+    """Calculate 7zip AES encryption key. Base implementation."""
+    if digest not in ('sha256',):
+        raise ValueError('Unknown digest method for password protection.')
+    assert cycles <= 0x3f
+    if cycles == 0x3f:
+        ba = bytearray(salt + password + bytes(32))
+        key = bytes(ba[:32])  # type: bytes
+    else:
+        rounds = 1 << cycles
+        m = _hashlib.new(digest)
+        for round in range(rounds):
+            m.update(salt + password + round.to_bytes(8, byteorder='little', signed=False))
+        key = m.digest()[:32]
+    return key
+
+
+def _calculate_key2(password: bytes, cycles: int, salt: bytes, digest: str):
+    """Calculate 7zip AES encryption key.
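+
+    The scheme (shared with _calculate_key1 above) hashes salt + password +
+    an 8-byte little-endian round counter for 2**cycles rounds and takes the
+    first 32 bytes of the digest as the key.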
+    It utilizes ctypes and a memoryview buffer for zero-copy hashing on Python."""
+    if digest not in ('sha256',):
+        raise ValueError('Unknown digest method for password protection.')
+    assert cycles <= 0x3f
+    if cycles == 0x3f:
+        key = bytes(bytearray(salt + password + bytes(32))[:32])  # type: bytes
+    else:
+        rounds = 1 << cycles
+        m = _hashlib.new(digest)
+        length = len(salt) + len(password)
+
+        class RoundBuf(ctypes.LittleEndianStructure):
+            _pack_ = 1
+            _fields_ = [
+                ('saltpassword', ctypes.c_ubyte * length),
+                ('round', ctypes.c_uint64)
+            ]
+
+        buf = RoundBuf()
+        for i, c in enumerate(salt + password):
+            buf.saltpassword[i] = c
+        buf.round = 0
+        mv = memoryview(buf)  # type: ignore  # noqa
+        while buf.round < rounds:
+            m.update(mv)
+            buf.round += 1
+        key = m.digest()[:32]
+    return key
+
+
+def _calculate_key3(password: bytes, cycles: int, salt: bytes, digest: str) -> bytes:
+    """Calculate 7zip AES encryption key.
+    Concatenate values in order to reduce the number of Hash.update() calls."""
+    if digest not in ('sha256',):
+        raise ValueError('Unknown digest method for password protection.')
+    assert cycles <= 0x3f
+    if cycles == 0x3f:
+        ba = bytearray(salt + password + bytes(32))
+        key = bytes(ba[:32])  # type: bytes
+    else:
+        cat_cycle = 6
+        if cycles > cat_cycle:
+            rounds = 1 << cat_cycle
+            stages = 1 << (cycles - cat_cycle)
+        else:
+            rounds = 1 << cycles
+            stages = 1 << 0
+        m = _hashlib.new(digest)
+        saltpassword = salt + password
+        s = 0  # type: int  # (0..stages) * rounds
+        if platform.python_implementation() == "PyPy":
+            for _ in range(stages):
+                m.update(memoryview(b''.join([saltpassword + (s + i).to_bytes(8, byteorder='little', signed=False)
+                                              for i in range(rounds)])))
+                s += rounds
+        else:
+            for _ in range(stages):
+                m.update(b''.join([saltpassword + (s + i).to_bytes(8, byteorder='little', signed=False)
+                                   for i in range(rounds)]))
+                s += rounds
+        key = m.digest()[:32]
+
+    return key
+
+
+if platform.python_implementation() == "PyPy" or sys.version_info > (3, 6):
+    calculate_key = _calculate_key3
+else:
+    calculate_key = _calculate_key2  # it is faster on CPython 3.6.x
+
+
+def filetime_to_dt(ft):
+    """Convert Windows NTFS file time into python datetime object."""
+    EPOCH_AS_FILETIME = 116444736000000000
+    us = (ft - EPOCH_AS_FILETIME) // 10
+    return datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(microseconds=us)
+
+
+ZERO = timedelta(0)
+HOUR = timedelta(hours=1)
+SECOND = timedelta(seconds=1)
+
+# A class capturing the platform's idea of local time.
+# (May result in wrong values on historical times in
+# timezones where UTC offset and/or the DST rules had
+# changed in the past.)
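+# Illustrative example: for US Eastern time, _time.timezone == 18000 and
+# _time.altzone == 14400, so STDOFFSET below becomes -5h and DSTOFFSET -4h.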
+
+STDOFFSET = timedelta(seconds=-_time.timezone)
+if _time.daylight:
+    DSTOFFSET = timedelta(seconds=-_time.altzone)
+else:
+    DSTOFFSET = STDOFFSET
+
+DSTDIFF = DSTOFFSET - STDOFFSET
+
+
+class LocalTimezone(tzinfo):
+
+    def fromutc(self, dt):
+        assert dt.tzinfo is self
+        stamp = (dt - datetime(1970, 1, 1, tzinfo=self)) // SECOND
+        args = _time.localtime(stamp)[:6]
+        dst_diff = DSTDIFF // SECOND
+        # Detect fold
+        fold = (args == _time.localtime(stamp - dst_diff))
+        return datetime(*args, microsecond=dt.microsecond, tzinfo=self, fold=fold)
+
+    def utcoffset(self, dt):
+        if self._isdst(dt):
+            return DSTOFFSET
+        else:
+            return STDOFFSET
+
+    def dst(self, dt):
+        if self._isdst(dt):
+            return DSTDIFF
+        else:
+            return ZERO
+
+    def tzname(self, dt):
+        return _time.tzname[self._isdst(dt)]
+
+    def _isdst(self, dt):
+        tt = (dt.year, dt.month, dt.day,
+              dt.hour, dt.minute, dt.second,
+              dt.weekday(), 0, 0)
+        stamp = _time.mktime(tt)
+        tt = _time.localtime(stamp)
+        return tt.tm_isdst > 0
+
+
+Local = LocalTimezone()
+TIMESTAMP_ADJUST = -11644473600
+
+
+class UTC(tzinfo):
+    """UTC"""
+
+    def utcoffset(self, dt):
+        return ZERO
+
+    def tzname(self, dt):
+        return "UTC"
+
+    def dst(self, dt):
+        return ZERO
+
+    def __call__(self):
+        return self
+
+
+class ArchiveTimestamp(int):
+    """Windows FILETIME timestamp."""
+
+    def __repr__(self):
+        return '%s(%d)' % (type(self).__name__, self)
+
+    def totimestamp(self) -> float:
+        """Convert 7z FILETIME to Python timestamp."""
+        # FILETIME is 100-nanosecond intervals since 1601/01/01 (UTC)
+        return (self / 10000000.0) + TIMESTAMP_ADJUST
+
+    def as_datetime(self):
+        """Convert FILETIME to Python datetime object."""
+        return datetime.fromtimestamp(self.totimestamp(), UTC())
+
+    @staticmethod
+    def from_datetime(val):
+        return ArchiveTimestamp((val - TIMESTAMP_ADJUST) * 10000000.0)
+
+
+def islink(path):
+    """
+    Cross-platform islink implementation.
+    Supports Windows NT symbolic links and reparse points.
+    """
+    is_symlink = os.path.islink(str(path))
+    if sys.version_info >= (3, 8) or sys.platform != "win32" or sys.getwindowsversion()[0] < 6:
+        return is_symlink
+    # special check for directory junctions, which py38 handles itself.
+    if is_symlink:
+        if py7zr.win32compat.is_reparse_point(path):
+            is_symlink = False
+    return is_symlink
+
+
+def readlink(path: Union[str, pathlib.Path], *, dir_fd=None) -> Union[str, pathlib.Path]:
+    """
+    Cross-platform compat implementation of os.readlink and Path.readlink().
+    Supports Windows NT symbolic links and reparse points.
+    When called with a str path, the result is returned as a str.
+    When called with a Path object, the result is also a Path object.
+    When called with a bytes path, the result is returned as bytes.
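+
+    Illustrative example (assuming 'lnk' is a symlink pointing at 'target'):
+
+        readlink(pathlib.Path('lnk'))  ->  pathlib.Path('target')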
+ """ + is_path_pathlib = isinstance(path, pathlib.Path) + if sys.version_info >= (3, 9): + if is_path_pathlib and dir_fd is None: + return path.readlink() + else: + return os.readlink(path, dir_fd=dir_fd) + elif sys.version_info >= (3, 8) or sys.platform != "win32": + res = os.readlink(path, dir_fd=dir_fd) + # Hack to handle a wrong type of results + if isinstance(res, bytes): + res = os.fsdecode(res) + if is_path_pathlib: + return pathlib.Path(res) + else: + return res + elif not os.path.exists(str(path)): + raise OSError(22, 'Invalid argument', path) + return py7zr.win32compat.readlink(path) + + +class MemIO: + """pathlib.Path-like IO class to write memory(io.Bytes)""" + def __init__(self, buf: BinaryIO): + self._buf = buf + + def write(self, data: bytes) -> int: + return self._buf.write(data) + + def read(self, length: Optional[int] = None) -> bytes: + if length is not None: + return self._buf.read(length) + else: + return self._buf.read() + + def close(self) -> None: + self._buf.seek(0) + + def flush(self) -> None: + pass + + def seek(self, position: int) -> None: + self._buf.seek(position) + + def open(self, mode=None): + return self + + @property + def parent(self): + return self + + def mkdir(self, parents=None, exist_ok=False): + return None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class NullIO: + """pathlib.Path-like IO class of /dev/null""" + + def __init__(self): + pass + + def write(self, data): + return len(data) + + def read(self, length=None): + if length is not None: + return bytes(length) + else: + return b'' + + def close(self): + pass + + def flush(self): + pass + + def open(self, mode=None): + return self + + @property + def parent(self): + return self + + def mkdir(self): + return None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class BufferOverflow(Exception): + pass + + +class Buffer: + + def __init__(self, size: int = 16): + self._size = size + self._buf = bytearray(size) + self._buflen = 0 + self.view = memoryview(self._buf[0:0]) + + def add(self, data: Union[bytes, bytearray, memoryview]): + length = len(data) + if length + self._buflen > self._size: + raise BufferOverflow() + self._buf[self._buflen:self._buflen + length] = data + self._buflen += length + self.view = memoryview(self._buf[0:self._buflen]) + + def reset(self) -> None: + self._buflen = 0 + self.view = memoryview(self._buf[0:0]) + + def set(self, data: Union[bytes, bytearray, memoryview]) -> None: + length = len(data) + if length > self._size: + raise BufferOverflow() + self._buf[0:length] = data + self._buflen = length + self.view = memoryview(self._buf[0:length]) + + def __len__(self) -> int: + return self._buflen diff --git a/py7zr/properties.py b/py7zr/properties.py new file mode 100644 index 0000000..38cfbe8 --- /dev/null +++ b/py7zr/properties.py @@ -0,0 +1,155 @@ +# +# p7zr library +# +# Copyright (c) 2019 Hiroshi Miura +# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de +# 7-Zip Copyright (C) 1999-2010 Igor Pavlov +# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. 
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+#
+
+import binascii
+from enum import Enum
+from typing import Optional
+
+MAGIC_7Z = binascii.unhexlify('377abcaf271c')
+FINISH_7Z = binascii.unhexlify('377abcaf271d')
+READ_BLOCKSIZE = 32248
+QUEUELEN = READ_BLOCKSIZE * 2
+
+
+class ByteEnum(bytes, Enum):
+    pass
+
+
+class Property(ByteEnum):
+    """Hold 7zip property fixed values."""
+    END = binascii.unhexlify('00')
+    HEADER = binascii.unhexlify('01')
+    ARCHIVE_PROPERTIES = binascii.unhexlify('02')
+    ADDITIONAL_STREAMS_INFO = binascii.unhexlify('03')
+    MAIN_STREAMS_INFO = binascii.unhexlify('04')
+    FILES_INFO = binascii.unhexlify('05')
+    PACK_INFO = binascii.unhexlify('06')
+    UNPACK_INFO = binascii.unhexlify('07')
+    SUBSTREAMS_INFO = binascii.unhexlify('08')
+    SIZE = binascii.unhexlify('09')
+    CRC = binascii.unhexlify('0a')
+    FOLDER = binascii.unhexlify('0b')
+    CODERS_UNPACK_SIZE = binascii.unhexlify('0c')
+    NUM_UNPACK_STREAM = binascii.unhexlify('0d')
+    EMPTY_STREAM = binascii.unhexlify('0e')
+    EMPTY_FILE = binascii.unhexlify('0f')
+    ANTI = binascii.unhexlify('10')
+    NAME = binascii.unhexlify('11')
+    CREATION_TIME = binascii.unhexlify('12')
+    LAST_ACCESS_TIME = binascii.unhexlify('13')
+    LAST_WRITE_TIME = binascii.unhexlify('14')
+    ATTRIBUTES = binascii.unhexlify('15')
+    COMMENT = binascii.unhexlify('16')
+    ENCODED_HEADER = binascii.unhexlify('17')
+    START_POS = binascii.unhexlify('18')
+    DUMMY = binascii.unhexlify('19')
+
+
+class CompressionMethod(ByteEnum):
+    """Hold fixed values for method parameter."""
+    COPY = binascii.unhexlify('00')
+    DELTA = binascii.unhexlify('03')
+    BCJ = binascii.unhexlify('04')
+    PPC = binascii.unhexlify('05')
+    IA64 = binascii.unhexlify('06')
+    ARM = binascii.unhexlify('07')
+    ARMT = binascii.unhexlify('08')
+    SPARC = binascii.unhexlify('09')
+    # SWAP = 02..
+    SWAP2 = binascii.unhexlify('020302')
+    SWAP4 = binascii.unhexlify('020304')
+    # 7Z = 03..
+    LZMA = binascii.unhexlify('030101')
+    PPMD = binascii.unhexlify('030401')
+    P7Z_BCJ = binascii.unhexlify('03030103')
+    P7Z_BCJ2 = binascii.unhexlify('0303011B')
+    BCJ_PPC = binascii.unhexlify('03030205')
+    BCJ_IA64 = binascii.unhexlify('03030401')
+    BCJ_ARM = binascii.unhexlify('03030501')
+    BCJ_ARMT = binascii.unhexlify('03030701')
+    BCJ_SPARC = binascii.unhexlify('03030805')
+    LZMA2 = binascii.unhexlify('21')
+    # MISC : 04..
+    MISC_ZIP = binascii.unhexlify('0401')
+    MISC_BZIP2 = binascii.unhexlify('040202')
+    MISC_DEFLATE = binascii.unhexlify('040108')
+    MISC_DEFLATE64 = binascii.unhexlify('040109')
+    MISC_Z = binascii.unhexlify('0405')
+    MISC_LZH = binascii.unhexlify('0406')
+    NSIS_DEFLATE = binascii.unhexlify('040901')
+    NSIS_BZIP2 = binascii.unhexlify('040902')
+    #
+    MISC_ZSTD = binascii.unhexlify('04f71101')
+    MISC_BROTLI = binascii.unhexlify('04f71102')
+    MISC_LZ4 = binascii.unhexlify('04f71104')
+    MISC_LZS = binascii.unhexlify('04f71105')
+    MISC_LIZARD = binascii.unhexlify('04f71106')
+    # CRYPTO 06..
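+    # (CRYPT_AES256_SHA256 below is the 7zAES scheme handled by AESDecompressor:
+    # AES-256 in CBC mode, key derived from the password with SHA-256 rounds.)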
+ CRYPT_ZIPCRYPT = binascii.unhexlify('06f10101') + CRYPT_RAR29AES = binascii.unhexlify('06f10303') + CRYPT_AES256_SHA256 = binascii.unhexlify('06f10701') + + +class SupportedMethods: + """Hold list of methods which python3 can support.""" + formats = [{'name': "7z", 'magic': MAGIC_7Z}] + codecs = [{'id': CompressionMethod.LZMA, 'name': "LZMA"}, + {'id': CompressionMethod.LZMA2, 'name': "LZMA2"}, + {'id': CompressionMethod.DELTA, 'name': "DELTA"}, + {'id': CompressionMethod.P7Z_BCJ, 'name': "BCJ"}, + {'id': CompressionMethod.BCJ_PPC, 'name': 'PPC'}, + {'id': CompressionMethod.BCJ_IA64, 'name': 'IA64'}, + {'id': CompressionMethod.BCJ_ARM, 'name': "ARM"}, + {'id': CompressionMethod.BCJ_ARMT, 'name': "ARMT"}, + {'id': CompressionMethod.BCJ_SPARC, 'name': 'SPARC'} + ] + + +# this class is Borg/Singleton +class ArchivePassword: + + _shared_state = { + '_password': None, + } + + def __init__(self, password: Optional[str] = None): + self.__dict__ = self._shared_state + if password is not None: + self._password = password + + def set(self, password): + self._password = password + + def get(self): + if self._password is not None: + return self._password + else: + return '' + + def __str__(self): + if self._password is not None: + return self._password + else: + return '' diff --git a/py7zr/py7zr.py b/py7zr/py7zr.py new file mode 100644 index 0000000..7d228d0 --- /dev/null +++ b/py7zr/py7zr.py @@ -0,0 +1,974 @@ +#!/usr/bin/python -u +# +# p7zr library +# +# Copyright (c) 2019,2020 Hiroshi Miura +# Copyright (c) 2004-2015 by Joachim Bauch, mail@joachim-bauch.de +# 7-Zip Copyright (C) 1999-2010 Igor Pavlov +# LZMA SDK Copyright (C) 1999-2010 Igor Pavlov +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# +# +"""Read 7zip format archives.""" +import collections.abc +import datetime +import errno +import functools +import io +import operator +import os +import queue +import stat +import sys +import threading +from io import BytesIO +from typing import IO, Any, BinaryIO, Dict, List, Optional, Tuple, Union + +from py7zr.archiveinfo import Folder, Header, SignatureHeader +from py7zr.callbacks import ExtractCallback +from py7zr.compression import SevenZipCompressor, Worker, get_methods_names +from py7zr.exceptions import Bad7zFile, CrcError, DecompressionError, InternalError +from py7zr.helpers import ArchiveTimestamp, MemIO, calculate_crc32, filetime_to_dt +from py7zr.properties import MAGIC_7Z, READ_BLOCKSIZE, ArchivePassword + +if sys.version_info < (3, 6): + import contextlib2 as contextlib + import pathlib2 as pathlib +else: + import contextlib + import pathlib + +if sys.platform.startswith('win'): + import _winapi + +FILE_ATTRIBUTE_UNIX_EXTENSION = 0x8000 +FILE_ATTRIBUTE_WINDOWS_MASK = 0x04fff + + +class ArchiveFile: + """Represent each files metadata inside archive file. 
+ It holds file properties; filename, permissions, and type whether + it is directory, link or normal file. + + Instances of the :class:`ArchiveFile` class are returned by iterating :attr:`files_list` of + :class:`SevenZipFile` objects. + Each object stores information about a single member of the 7z archive. Most of users use :meth:`extractall()`. + + The class also hold an archive parameter where file is exist in + archive file folder(container).""" + def __init__(self, id: int, file_info: Dict[str, Any]) -> None: + self.id = id + self._file_info = file_info + + def file_properties(self) -> Dict[str, Any]: + """Return file properties as a hash object. Following keys are included: ‘readonly’, ‘is_directory’, + ‘posix_mode’, ‘archivable’, ‘emptystream’, ‘filename’, ‘creationtime’, ‘lastaccesstime’, + ‘lastwritetime’, ‘attributes’ + """ + properties = self._file_info + if properties is not None: + properties['readonly'] = self.readonly + properties['posix_mode'] = self.posix_mode + properties['archivable'] = self.archivable + properties['is_directory'] = self.is_directory + return properties + + def _get_property(self, key: str) -> Any: + try: + return self._file_info[key] + except KeyError: + return None + + @property + def origin(self) -> pathlib.Path: + return self._get_property('origin') + + @property + def folder(self) -> Folder: + return self._get_property('folder') + + @property + def filename(self) -> str: + """return filename of archive file.""" + return self._get_property('filename') + + @property + def emptystream(self) -> bool: + """True if file is empty(0-byte file), otherwise False""" + return self._get_property('emptystream') + + @property + def uncompressed(self) -> List[int]: + return self._get_property('uncompressed') + + @property + def uncompressed_size(self) -> int: + """Uncompressed file size.""" + return functools.reduce(operator.add, self.uncompressed) + + @property + def compressed(self) -> Optional[int]: + """Compressed size""" + return self._get_property('compressed') + + @property + def crc32(self) -> Optional[int]: + """CRC of archived file(optional)""" + return self._get_property('digest') + + def _test_attribute(self, target_bit: int) -> bool: + attributes = self._get_property('attributes') + if attributes is None: + return False + return attributes & target_bit == target_bit + + @property + def archivable(self) -> bool: + """File has a Windows `archive` flag.""" + return self._test_attribute(stat.FILE_ATTRIBUTE_ARCHIVE) # type: ignore # noqa + + @property + def is_directory(self) -> bool: + """True if file is a directory, otherwise False.""" + return self._test_attribute(stat.FILE_ATTRIBUTE_DIRECTORY) # type: ignore # noqa + + @property + def readonly(self) -> bool: + """True if file is readonly, otherwise False.""" + return self._test_attribute(stat.FILE_ATTRIBUTE_READONLY) # type: ignore # noqa + + def _get_unix_extension(self) -> Optional[int]: + attributes = self._get_property('attributes') + if self._test_attribute(FILE_ATTRIBUTE_UNIX_EXTENSION): + return attributes >> 16 + return None + + @property + def is_symlink(self) -> bool: + """True if file is a symbolic link, otherwise False.""" + e = self._get_unix_extension() + if e is not None: + return stat.S_ISLNK(e) + return self._test_attribute(stat.FILE_ATTRIBUTE_REPARSE_POINT) # type: ignore # noqa + + @property + def is_junction(self) -> bool: + """True if file is a junction/reparse point on windows, otherwise False.""" + return self._test_attribute(stat.FILE_ATTRIBUTE_REPARSE_POINT | # type: ignore # 
noqa + stat.FILE_ATTRIBUTE_DIRECTORY) # type: ignore # noqa + + @property + def is_socket(self) -> bool: + """True if file is a socket, otherwise False.""" + e = self._get_unix_extension() + if e is not None: + return stat.S_ISSOCK(e) + return False + + @property + def lastwritetime(self) -> Optional[ArchiveTimestamp]: + """Return last written timestamp of a file.""" + return self._get_property('lastwritetime') + + @property + def posix_mode(self) -> Optional[int]: + """ + posix mode when a member has a unix extension property, or None + :return: Return file stat mode can be set by os.chmod() + """ + e = self._get_unix_extension() + if e is not None: + return stat.S_IMODE(e) + return None + + @property + def st_fmt(self) -> Optional[int]: + """ + :return: Return the portion of the file mode that describes the file type + """ + e = self._get_unix_extension() + if e is not None: + return stat.S_IFMT(e) + return None + + +class ArchiveFileList(collections.abc.Iterable): + """Iteratable container of ArchiveFile.""" + + def __init__(self, offset: int = 0): + self.files_list = [] # type: List[dict] + self.index = 0 + self.offset = offset + + def append(self, file_info: Dict[str, Any]) -> None: + self.files_list.append(file_info) + + def __len__(self) -> int: + return len(self.files_list) + + def __iter__(self) -> 'ArchiveFileListIterator': + return ArchiveFileListIterator(self) + + def __getitem__(self, index): + if index > len(self.files_list): + raise IndexError + if index < 0: + raise IndexError + res = ArchiveFile(index + self.offset, self.files_list[index]) + return res + + +class ArchiveFileListIterator(collections.abc.Iterator): + + def __init__(self, archive_file_list): + self._archive_file_list = archive_file_list + self._index = 0 + + def __next__(self) -> ArchiveFile: + if self._index == len(self._archive_file_list): + raise StopIteration + res = self._archive_file_list[self._index] + self._index += 1 + return res + + +# ------------------ +# Exported Classes +# ------------------ +class ArchiveInfo: + """Hold archive information""" + + def __init__(self, filename, size, header_size, method_names, solid, blocks, uncompressed): + self.filename = filename + self.size = size + self.header_size = header_size + self.method_names = method_names + self.solid = solid + self.blocks = blocks + self.uncompressed = uncompressed + + +class FileInfo: + """Hold archived file information.""" + + def __init__(self, filename, compressed, uncompressed, archivable, is_directory, creationtime, crc32): + self.filename = filename + self.compressed = compressed + self.uncompressed = uncompressed + self.archivable = archivable + self.is_directory = is_directory + self.creationtime = creationtime + self.crc32 = crc32 + + +class SevenZipFile(contextlib.AbstractContextManager): + """The SevenZipFile Class provides an interface to 7z archives.""" + + def __init__(self, file: Union[BinaryIO, str, pathlib.Path], mode: str = 'r', + *, filters: Optional[str] = None, dereference=False, password: Optional[str] = None) -> None: + if mode not in ('r', 'w', 'x', 'a'): + raise ValueError("ZipFile requires mode 'r', 'w', 'x', or 'a'") + if password is not None: + if mode not in ('r'): + raise NotImplementedError("It has not been implemented to create archive with password.") + ArchivePassword(password) + self.password_protected = True + else: + self.password_protected = False + # Check if we were passed a file-like object or not + if isinstance(file, str): + self._filePassed = False # type: bool + self.filename = file # 
type: str + if mode == 'r': + self.fp = open(file, 'rb') # type: BinaryIO + elif mode == 'w': + self.fp = open(file, 'w+b') + elif mode == 'x': + self.fp = open(file, 'x+b') + elif mode == 'a': + self.fp = open(file, 'r+b') + else: + raise ValueError("File open error.") + self.mode = mode + elif isinstance(file, pathlib.Path): + self._filePassed = False + self.filename = str(file) + if mode == 'r': + self.fp = file.open(mode='rb') # type: ignore # noqa # typeshed issue: 2911 + elif mode == 'w': + self.fp = file.open(mode='w+b') # type: ignore # noqa + elif mode == 'x': + self.fp = file.open(mode='x+b') # type: ignore # noqa + elif mode == 'a': + self.fp = file.open(mode='r+b') # type: ignore # noqa + else: + raise ValueError("File open error.") + self.mode = mode + elif isinstance(file, io.IOBase): + self._filePassed = True + self.fp = file + self.filename = getattr(file, 'name', None) + self.mode = mode # type: ignore #noqa + else: + raise TypeError("invalid file: {}".format(type(file))) + self._fileRefCnt = 1 + try: + if mode == "r": + self._real_get_contents(self.fp) + self._reset_worker() + elif mode in 'w': + # FIXME: check filters here + self.folder = self._create_folder(filters) + self.files = ArchiveFileList() + self._prepare_write() + self._reset_worker() + elif mode in 'x': + raise NotImplementedError + elif mode == 'a': + raise NotImplementedError + else: + raise ValueError("Mode must be 'r', 'w', 'x', or 'a'") + except Exception as e: + self._fpclose() + raise e + self.encoded_header_mode = False + self._dict = {} # type: Dict[str, IO[Any]] + self.dereference = dereference + self.reporterd = None # type: Optional[threading.Thread] + self.q = queue.Queue() # type: queue.Queue[Any] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def _create_folder(self, filters): + folder = Folder() + folder.compressor = SevenZipCompressor(filters) + folder.coders = folder.compressor.coders + folder.solid = True + folder.digestdefined = False + folder.bindpairs = [] + folder.totalin = 1 + folder.totalout = 1 + return folder + + def _fpclose(self) -> None: + assert self._fileRefCnt > 0 + self._fileRefCnt -= 1 + if not self._fileRefCnt and not self._filePassed: + self.fp.close() + + def _real_get_contents(self, fp: BinaryIO) -> None: + if not self._check_7zfile(fp): + raise Bad7zFile('not a 7z file') + self.sig_header = SignatureHeader.retrieve(self.fp) + self.afterheader = self.fp.tell() + buffer = self._read_header_data() + header = Header.retrieve(self.fp, buffer, self.afterheader) + if header is None: + return + self.header = header + buffer.close() + self.files = ArchiveFileList() + if getattr(self.header, 'files_info', None) is not None: + self._filelist_retrieve() + + def _read_header_data(self) -> BytesIO: + self.fp.seek(self.sig_header.nextheaderofs, os.SEEK_CUR) + buffer = io.BytesIO(self.fp.read(self.sig_header.nextheadersize)) + if self.sig_header.nextheadercrc != calculate_crc32(buffer.getvalue()): + raise Bad7zFile('invalid header data') + return buffer + + class ParseStatus: + def __init__(self, src_pos=0): + self.src_pos = src_pos + self.folder = 0 # 7zip folder where target stored + self.outstreams = 0 # output stream count + self.input = 0 # unpack stream count in each folder + self.stream = 0 # target input stream position + + def _gen_filename(self) -> str: + # compressed file is stored without a name, generate one + try: + basefilename = self.filename + except AttributeError: + # 7z archive file doesn't have a name + 
return 'contents' + else: + if basefilename is not None: + fn, ext = os.path.splitext(os.path.basename(basefilename)) + return fn + else: + return 'contents' + + def _get_fileinfo_sizes(self, pstat, subinfo, packinfo, folder, packsizes, unpacksizes, file_in_solid, numinstreams): + if pstat.input == 0: + folder.solid = subinfo.num_unpackstreams_folders[pstat.folder] > 1 + maxsize = (folder.solid and packinfo.packsizes[pstat.stream]) or None + uncompressed = unpacksizes[pstat.outstreams] + if not isinstance(uncompressed, (list, tuple)): + uncompressed = [uncompressed] * len(folder.coders) + if file_in_solid > 0: + compressed = None + elif pstat.stream < len(packsizes): # file is compressed + compressed = packsizes[pstat.stream] + else: # file is not compressed + compressed = uncompressed + packsize = packsizes[pstat.stream:pstat.stream + numinstreams] + return maxsize, compressed, uncompressed, packsize, folder.solid + + def _filelist_retrieve(self) -> None: + # Initialize references for convenience + if hasattr(self.header, 'main_streams') and self.header.main_streams is not None: + folders = self.header.main_streams.unpackinfo.folders + packinfo = self.header.main_streams.packinfo + subinfo = self.header.main_streams.substreamsinfo + packsizes = packinfo.packsizes + unpacksizes = subinfo.unpacksizes if subinfo.unpacksizes is not None else [x.unpacksizes for x in folders] + else: + subinfo = None + folders = None + packinfo = None + packsizes = [] + unpacksizes = [0] + + pstat = self.ParseStatus() + pstat.src_pos = self.afterheader + file_in_solid = 0 + + for file_id, file_info in enumerate(self.header.files_info.files): + if not file_info['emptystream'] and folders is not None: + folder = folders[pstat.folder] + numinstreams = max([coder.get('numinstreams', 1) for coder in folder.coders]) + (maxsize, compressed, uncompressed, + packsize, solid) = self._get_fileinfo_sizes(pstat, subinfo, packinfo, folder, packsizes, + unpacksizes, file_in_solid, numinstreams) + pstat.input += 1 + folder.solid = solid + file_info['folder'] = folder + file_info['maxsize'] = maxsize + file_info['compressed'] = compressed + file_info['uncompressed'] = uncompressed + file_info['packsizes'] = packsize + if subinfo.digestsdefined[pstat.outstreams]: + file_info['digest'] = subinfo.digests[pstat.outstreams] + if folder is None: + pstat.src_pos += file_info['compressed'] + else: + if folder.solid: + file_in_solid += 1 + pstat.outstreams += 1 + if folder.files is None: + folder.files = ArchiveFileList(offset=file_id) + folder.files.append(file_info) + if pstat.input >= subinfo.num_unpackstreams_folders[pstat.folder]: + file_in_solid = 0 + pstat.src_pos += sum(packinfo.packsizes[pstat.stream:pstat.stream + numinstreams]) + pstat.folder += 1 + pstat.stream += numinstreams + pstat.input = 0 + else: + file_info['folder'] = None + file_info['maxsize'] = 0 + file_info['compressed'] = 0 + file_info['uncompressed'] = [0] + file_info['packsizes'] = [0] + + if 'filename' not in file_info: + file_info['filename'] = self._gen_filename() + self.files.append(file_info) + + def _num_files(self) -> int: + if getattr(self.header, 'files_info', None) is not None: + return len(self.header.files_info.files) + return 0 + + def _set_file_property(self, outfilename: pathlib.Path, properties: Dict[str, Any]) -> None: + # creation time + creationtime = ArchiveTimestamp(properties['lastwritetime']).totimestamp() + if creationtime is not None: + os.utime(str(outfilename), times=(creationtime, creationtime)) + if os.name == 'posix': + st_mode = 
properties['posix_mode'] + if st_mode is not None: + outfilename.chmod(st_mode) + return + # fallback: only set readonly if specified + if properties['readonly'] and not properties['is_directory']: + ro_mask = 0o777 ^ (stat.S_IWRITE | stat.S_IWGRP | stat.S_IWOTH) + outfilename.chmod(outfilename.stat().st_mode & ro_mask) + + def _reset_decompressor(self) -> None: + if self.header.main_streams is not None and self.header.main_streams.unpackinfo.numfolders > 0: + for i, folder in enumerate(self.header.main_streams.unpackinfo.folders): + folder.decompressor = None + + def _reset_worker(self) -> None: + """Seek to where archive data start in archive and recreate new worker.""" + self.fp.seek(self.afterheader) + self.worker = Worker(self.files, self.afterheader, self.header) + + def set_encoded_header_mode(self, mode: bool) -> None: + self.encoded_header_mode = mode + + @staticmethod + def _check_7zfile(fp: Union[BinaryIO, io.BufferedReader]) -> bool: + result = MAGIC_7Z == fp.read(len(MAGIC_7Z))[:len(MAGIC_7Z)] + fp.seek(-len(MAGIC_7Z), 1) + return result + + def _get_method_names(self) -> str: + methods_names = [] # type: List[str] + for folder in self.header.main_streams.unpackinfo.folders: + methods_names += get_methods_names(folder.coders) + return ', '.join(x for x in methods_names) + + def _test_digest_raw(self, pos: int, size: int, crc: int) -> bool: + self.fp.seek(pos) + remaining_size = size + digest = None + while remaining_size > 0: + block = min(READ_BLOCKSIZE, remaining_size) + digest = calculate_crc32(self.fp.read(block), digest) + remaining_size -= block + return digest == crc + + def _prepare_write(self) -> None: + self.sig_header = SignatureHeader() + self.sig_header._write_skelton(self.fp) + self.afterheader = self.fp.tell() + self.folder.totalin = 1 + self.folder.totalout = 1 + self.folder.bindpairs = [] + self.folder.unpacksizes = [] + self.header = Header.build_header([self.folder]) + + def _write_archive(self): + self.worker.archive(self.fp, self.folder, deref=self.dereference) + # Write header and update signature header + (header_pos, header_len, header_crc) = self.header.write(self.fp, self.afterheader, + encoded=self.encoded_header_mode) + self.sig_header.nextheaderofs = header_pos - self.afterheader + self.sig_header.calccrc(header_len, header_crc) + self.sig_header.write(self.fp) + return + + def _is_solid(self): + for f in self.header.main_streams.substreamsinfo.num_unpackstreams_folders: + if f > 1: + return True + return False + + def _var_release(self): + self._dict = None + self.files = None + self.folder = None + self.header = None + self.worker = None + self.sig_header = None + + @staticmethod + def _make_file_info(target: pathlib.Path, arcname: Optional[str] = None, dereference=False) -> Dict[str, Any]: + f = {} # type: Dict[str, Any] + f['origin'] = target + if arcname is not None: + f['filename'] = pathlib.Path(arcname).as_posix() + else: + f['filename'] = target.as_posix() + if os.name == 'nt': + fstat = target.lstat() + if target.is_symlink(): + if dereference: + fstat = target.stat() + if stat.S_ISDIR(fstat.st_mode): + f['emptystream'] = True + f['attributes'] = fstat.st_file_attributes & FILE_ATTRIBUTE_WINDOWS_MASK # type: ignore # noqa + else: + f['emptystream'] = False + f['attributes'] = stat.FILE_ATTRIBUTE_ARCHIVE # type: ignore # noqa + f['uncompressed'] = fstat.st_size + else: + f['emptystream'] = False + f['attributes'] = fstat.st_file_attributes & FILE_ATTRIBUTE_WINDOWS_MASK # type: ignore # noqa + # f['attributes'] |= 
stat.FILE_ATTRIBUTE_REPARSE_POINT # type: ignore # noqa + elif target.is_dir(): + f['emptystream'] = True + f['attributes'] = fstat.st_file_attributes & FILE_ATTRIBUTE_WINDOWS_MASK # type: ignore # noqa + elif target.is_file(): + f['emptystream'] = False + f['attributes'] = stat.FILE_ATTRIBUTE_ARCHIVE # type: ignore # noqa + f['uncompressed'] = fstat.st_size + else: + fstat = target.lstat() + if target.is_symlink(): + if dereference: + fstat = target.stat() + if stat.S_ISDIR(fstat.st_mode): + f['emptystream'] = True + f['attributes'] = stat.FILE_ATTRIBUTE_DIRECTORY # type: ignore # noqa + f['attributes'] |= FILE_ATTRIBUTE_UNIX_EXTENSION | (stat.S_IFDIR << 16) + f['attributes'] |= (stat.S_IMODE(fstat.st_mode) << 16) + else: + f['emptystream'] = False + f['attributes'] = stat.FILE_ATTRIBUTE_ARCHIVE # type: ignore # noqa + f['attributes'] |= FILE_ATTRIBUTE_UNIX_EXTENSION | (stat.S_IMODE(fstat.st_mode) << 16) + else: + f['emptystream'] = False + f['attributes'] = stat.FILE_ATTRIBUTE_ARCHIVE | stat.FILE_ATTRIBUTE_REPARSE_POINT # type: ignore # noqa + f['attributes'] |= FILE_ATTRIBUTE_UNIX_EXTENSION | (stat.S_IFLNK << 16) + f['attributes'] |= (stat.S_IMODE(fstat.st_mode) << 16) + elif target.is_dir(): + f['emptystream'] = True + f['attributes'] = stat.FILE_ATTRIBUTE_DIRECTORY # type: ignore # noqa + f['attributes'] |= FILE_ATTRIBUTE_UNIX_EXTENSION | (stat.S_IFDIR << 16) + f['attributes'] |= (stat.S_IMODE(fstat.st_mode) << 16) + elif target.is_file(): + f['emptystream'] = False + f['uncompressed'] = fstat.st_size + f['attributes'] = stat.FILE_ATTRIBUTE_ARCHIVE # type: ignore # noqa + f['attributes'] |= FILE_ATTRIBUTE_UNIX_EXTENSION | (stat.S_IMODE(fstat.st_mode) << 16) + + f['creationtime'] = fstat.st_ctime + f['lastwritetime'] = fstat.st_mtime + f['lastaccesstime'] = fstat.st_atime + return f + + # -------------------------------------------------------------------------- + # The public methods which SevenZipFile provides: + def getnames(self) -> List[str]: + """Return the members of the archive as a list of their names. It has + the same order as the list returned by getmembers(). + """ + return list(map(lambda x: x.filename, self.files)) + + def archiveinfo(self) -> ArchiveInfo: + fstat = os.stat(self.filename) + uncompressed = 0 + for f in self.files: + uncompressed += f.uncompressed_size + return ArchiveInfo(self.filename, fstat.st_size, self.header.size, self._get_method_names(), + self._is_solid(), len(self.header.main_streams.unpackinfo.folders), + uncompressed) + + def list(self) -> List[FileInfo]: + """Returns contents information """ + alist = [] # type: List[FileInfo] + creationtime = None # type: Optional[datetime.datetime] + for f in self.files: + if f.lastwritetime is not None: + creationtime = filetime_to_dt(f.lastwritetime) + alist.append(FileInfo(f.filename, f.compressed, f.uncompressed_size, f.archivable, f.is_directory, + creationtime, f.crc32)) + return alist + + def readall(self) -> Optional[Dict[str, IO[Any]]]: + return self._extract(path=None, return_dict=True) + + def extractall(self, path: Optional[Any] = None, callback: Optional[ExtractCallback] = None) -> None: + """Extract all members from the archive to the current working + directory and set owner, modification time and permissions on + directories afterwards. `path' specifies a different directory + to extract to. 
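+
+        A minimal usage sketch (illustrative):
+
+            with SevenZipFile('archive.7z', 'r') as archive:
+                archive.extractall(path='/tmp/out')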
+        """
+        self._extract(path=path, return_dict=False, callback=callback)
+
+    def read(self, targets: Optional[List[str]] = None) -> Optional[Dict[str, IO[Any]]]:
+        return self._extract(path=None, targets=targets, return_dict=True)
+
+    def extract(self, path: Optional[Any] = None, targets: Optional[List[str]] = None) -> None:
+        self._extract(path, targets, return_dict=False)
+
+    def _extract(self, path: Optional[Any] = None, targets: Optional[List[str]] = None,
+                 return_dict: bool = False, callback: Optional[ExtractCallback] = None) -> Optional[Dict[str, IO[Any]]]:
+        if callback is not None and not isinstance(callback, ExtractCallback):
+            raise ValueError('Callback specified is not a subclass of py7zr.callbacks.ExtractCallback class')
+        elif callback is not None:
+            self.reporterd = threading.Thread(target=self.reporter, args=(callback,), daemon=True)
+            self.reporterd.start()
+        target_junction = []  # type: List[pathlib.Path]
+        target_sym = []  # type: List[pathlib.Path]
+        target_files = []  # type: List[Tuple[pathlib.Path, Dict[str, Any]]]
+        target_dirs = []  # type: List[pathlib.Path]
+        if path is not None:
+            if isinstance(path, str):
+                path = pathlib.Path(path)
+            try:
+                if not path.exists():
+                    path.mkdir(parents=True)
+                else:
+                    pass
+            except OSError as e:
+                if e.errno == errno.EEXIST and path.is_dir():
+                    pass
+                else:
+                    raise e
+        fnames = []  # type: List[str]  # check duplicated filename in one archive?
+        self.q.put(('pre', None, None))
+        for f in self.files:
+            # TODO: sanity check
+            # reject f.filename when it contains invalid components such as '../'
+            if f.filename.startswith('../'):
+                raise Bad7zFile
+            # When the archive contains multiple files with the same name,
+            # multi-threaded decompression is turned off to guarantee archive order.
+            # Currently later entries always overwrite earlier ones.
+            # TODO: provide option to select overwrite or skip.
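+            # Names already seen get a '_<n>' suffix appended below, so every
+            # duplicated entry is still extracted to a distinct path.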
+            if f.filename not in fnames:
+                outname = f.filename
+            else:
+                i = 0
+                while True:
+                    outname = f.filename + '_%d' % i
+                    if outname not in fnames:
+                        break
+                    i += 1
+            fnames.append(outname)
+            if path is not None:
+                outfilename = path.joinpath(outname)
+            else:
+                outfilename = pathlib.Path(outname)
+            if os.name == 'nt':
+                if outfilename.is_absolute():
+                    # hack for microsoft windows path length limit < 255
+                    outfilename = pathlib.WindowsPath('\\\\?\\' + str(outfilename))
+            if targets is not None and f.filename not in targets:
+                self.worker.register_filelike(f.id, None)
+                continue
+            if f.is_directory:
+                if not outfilename.exists():
+                    target_dirs.append(outfilename)
+                    target_files.append((outfilename, f.file_properties()))
+                else:
+                    pass
+            elif f.is_socket:
+                pass
+            elif return_dict:
+                fname = outfilename.as_posix()
+                _buf = io.BytesIO()
+                self._dict[fname] = _buf
+                self.worker.register_filelike(f.id, MemIO(_buf))
+            elif f.is_symlink:
+                target_sym.append(outfilename)
+                try:
+                    if outfilename.exists():
+                        outfilename.unlink()
+                except OSError as ose:
+                    if ose.errno not in [errno.ENOENT]:
+                        raise
+                self.worker.register_filelike(f.id, outfilename)
+            elif f.is_junction:
+                target_junction.append(outfilename)
+                self.worker.register_filelike(f.id, outfilename)
+            else:
+                self.worker.register_filelike(f.id, outfilename)
+                target_files.append((outfilename, f.file_properties()))
+        for target_dir in sorted(target_dirs):
+            try:
+                target_dir.mkdir()
+            except FileExistsError:
+                if target_dir.is_dir():
+                    pass
+                elif target_dir.is_file():
+                    raise DecompressionError("Directory {} already exists as a normal file.".format(str(target_dir)))
+                else:
+                    raise DecompressionError("Creating directory {} failed for an unknown reason.".format(str(target_dir)))
+
+        try:
+            if callback is not None:
+                self.worker.extract(self.fp, parallel=(not self.password_protected and not self._filePassed), q=self.q)
+            else:
+                self.worker.extract(self.fp, parallel=(not self.password_protected and not self._filePassed))
+        except CrcError as ce:
+            raise Bad7zFile("CRC32 error on archived file {}.".format(str(ce)))
+
+        self.q.put(('post', None, None))
+        if return_dict:
+            return self._dict
+        else:
+            # create symbolic links on target path as a working directory.
+            # if path is None, work on current working directory.
+            for t in target_sym:
+                sym_dst = t.resolve()
+                with sym_dst.open('rb') as b:
+                    sym_src = b.read().decode(encoding='utf-8')  # symlink target name stored in utf-8
+                sym_dst.unlink()  # unlink after close().
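+                # The extracted file held only the link-target string;
+                # replace it with a real symbolic link to that target.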
+                sym_dst.symlink_to(pathlib.Path(sym_src))
+            # create junction points only on the windows platform
+            if sys.platform.startswith('win'):
+                for t in target_junction:
+                    junction_dst = t.resolve()
+                    with junction_dst.open('rb') as b:
+                        junction_target = pathlib.Path(b.read().decode(encoding='utf-8'))
+                        junction_dst.unlink()
+                        _winapi.CreateJunction(junction_target, str(junction_dst))  # type: ignore  # noqa
+            # set file properties
+            for o, p in target_files:
+                self._set_file_property(o, p)
+            return None
+
+    def reporter(self, callback: ExtractCallback):
+        while True:
+            try:
+                item = self.q.get(timeout=1)  # type: Optional[Tuple[str, str, str]]
+            except queue.Empty:
+                pass
+            else:
+                if item is None:
+                    break
+                elif item[0] == 's':
+                    callback.report_start(item[1], item[2])
+                elif item[0] == 'e':
+                    callback.report_end(item[1], item[2])
+                elif item[0] == 'pre':
+                    callback.report_start_preparation()
+                elif item[0] == 'post':
+                    callback.report_postprocess()
+                elif item[0] == 'w':
+                    callback.report_warning(item[1])
+                else:
+                    pass
+                self.q.task_done()
+
+    def writeall(self, path: Union[pathlib.Path, str], arcname: Optional[str] = None):
+        """Write files under the target path into the archive."""
+        if isinstance(path, str):
+            path = pathlib.Path(path)
+        if not path.exists():
+            raise ValueError("specified path does not exist.")
+        if path.is_dir() or path.is_file():
+            self._writeall(path, arcname)
+        else:
+            raise ValueError("specified path is not a directory or a file")
+
+    def _writeall(self, path, arcname):
+        try:
+            if path.is_symlink() and not self.dereference:
+                self.write(path, arcname)
+            elif path.is_file():
+                self.write(path, arcname)
+            elif path.is_dir():
+                if not path.samefile('.'):
+                    self.write(path, arcname)
+                for nm in sorted(os.listdir(str(path))):
+                    arc = os.path.join(arcname, nm) if arcname is not None else None
+                    self._writeall(path.joinpath(nm), arc)
+            else:
+                return  # pathlib ignores ELOOP and returns False from is_*().
+        except OSError as ose:
+            if self.dereference and ose.errno in [errno.ELOOP]:
+                return  # ignore ELOOP here; this stops recursion on looped symlink references.
+            elif self.dereference and sys.platform == 'win32' and ose.errno in [errno.ENOENT]:
+                return  # ignore ENOENT, which windows raises where POSIX would raise ELOOP.
+            else:
+                raise
+
+    def write(self, file: Union[pathlib.Path, str], arcname: Optional[str] = None):
+        """Register a single target file to be written into the archive on close()."""
+        if isinstance(file, str):
+            path = pathlib.Path(file)
+        elif isinstance(file, pathlib.Path):
+            path = file
+        else:
+            raise ValueError("Unsupported file type.")
+        file_info = self._make_file_info(path, arcname, self.dereference)
+        self.files.append(file_info)
+
+    def close(self):
+        """Flush all the data into the archive and close it.
+        On close, py7zr starts reading the registered targets and writes the actual archive file.
+        """
+        if 'w' in self.mode:
+            self._write_archive()
+        if 'r' in self.mode:
+            if self.reporterd is not None:
+                self.q.put_nowait(None)
+                self.reporterd.join(1)
+                if self.reporterd.is_alive():
+                    raise InternalError("Progress report thread failed to terminate.")
+                self.reporterd = None
+        self._fpclose()
+        self._var_release()
+
+    def reset(self) -> None:
+        """In read mode, reset the file pointer, the decompression worker, and the decompressor."""
+        if self.mode == 'r':
+            self._reset_worker()
+            self._reset_decompressor()
+
+    def test(self) -> Optional[bool]:
+        self._reset_worker()
+        crcs = self.header.main_streams.packinfo.crcs  # type: Optional[List[int]]
+        if crcs is None or len(crcs) == 0:
+            return None
+        # check packed stream's crc
+        assert len(crcs) == len(self.header.main_streams.packinfo.packpositions)
+        for i, p in enumerate(self.header.main_streams.packinfo.packpositions):
+            if not self._test_digest_raw(p, self.header.main_streams.packinfo.packsizes[i], crcs[i]):
+                return False
+        return True
+
+    def testzip(self) -> Optional[str]:
+        self._reset_worker()
+        for f in self.files:
+            self.worker.register_filelike(f.id, None)
+        try:
+            self.worker.extract(self.fp, parallel=(not self.password_protected))  # TODO: print progress
+        except CrcError as crce:
+            return str(crce)
+        else:
+            return None
+
+
+# --------------------
+# exported functions
+# --------------------
+def is_7zfile(file: Union[BinaryIO, str, pathlib.Path]) -> bool:
+    """Quickly see if a file is a 7Z file by checking the magic number.
+    The file argument may be a filename or a file-like object.
+    """
+    result = False
+    try:
+        if isinstance(file, io.IOBase) and hasattr(file, "read"):
+            result = SevenZipFile._check_7zfile(file)  # type: ignore  # noqa
+        elif isinstance(file, str):
+            with open(file, 'rb') as fp:
+                result = SevenZipFile._check_7zfile(fp)
+        elif isinstance(file, pathlib.Path) or isinstance(file, pathlib.PosixPath) or \
+                isinstance(file, pathlib.WindowsPath):
+            with file.open(mode='rb') as fp:  # type: ignore  # noqa
+                result = SevenZipFile._check_7zfile(fp)
+        else:
+            raise TypeError('invalid type: file should be str, pathlib.Path or BinaryIO, but {}'.format(type(file)))
+    except OSError:
+        pass
+    return result
+
+
+def unpack_7zarchive(archive, path, extra=None):
+    """Function for registering with shutil.register_unpack_format()"""
+    arc = SevenZipFile(archive)
+    arc.extractall(path)
+    arc.close()
+
+
+def pack_7zarchive(base_name, base_dir, owner=None, group=None, dry_run=None, logger=None):
+    """Function for registering with shutil.register_archive_format()"""
+    target_name = '{}.7z'.format(base_name)
+    archive = SevenZipFile(target_name, mode='w')
+    archive.writeall(path=base_dir)
+    archive.close()
diff --git a/py7zr/win32compat.py b/py7zr/win32compat.py
new file mode 100644
index 0000000..dc72bfd
--- /dev/null
+++ b/py7zr/win32compat.py
@@ -0,0 +1,174 @@
+import pathlib
+import stat
+import sys
+from logging import getLogger
+from typing import Union
+
+if sys.platform == "win32":
+    import ctypes
+    from ctypes.wintypes import BOOL, DWORD, HANDLE, LPCWSTR, LPDWORD, LPVOID, LPWSTR
+
+    _stdcall_libraries = {}
+    _stdcall_libraries['kernel32'] = ctypes.WinDLL('kernel32')
+    CloseHandle = _stdcall_libraries['kernel32'].CloseHandle
+    CreateFileW = _stdcall_libraries['kernel32'].CreateFileW
+    DeviceIoControl = _stdcall_libraries['kernel32'].DeviceIoControl
+    GetFileAttributesW = _stdcall_libraries['kernel32'].GetFileAttributesW
+    OPEN_EXISTING = 3
+    GENERIC_READ = 2147483648
+    FILE_FLAG_OPEN_REPARSE_POINT = 0x00200000
+
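+    # Win32 constants used below (descriptive note): GENERIC_READ is
+    # 0x80000000, OPEN_EXISTING fails if the path does not exist, and
+    # FILE_FLAG_OPEN_REPARSE_POINT tells CreateFileW to open the reparse
+    # point itself rather than the file or directory it points to.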
FSCTL_GET_REPARSE_POINT = 0x000900A8 + FILE_FLAG_BACKUP_SEMANTICS = 0x02000000 + IO_REPARSE_TAG_MOUNT_POINT = 0xA0000003 + IO_REPARSE_TAG_SYMLINK = 0xA000000C + MAXIMUM_REPARSE_DATA_BUFFER_SIZE = 16 * 1024 + + def _check_bit(val: int, flag: int) -> bool: + return bool(val & flag == flag) + + class SymbolicLinkReparseBuffer(ctypes.Structure): + """ Implementing the below in Python: + + typedef struct _REPARSE_DATA_BUFFER { + ULONG ReparseTag; + USHORT ReparseDataLength; + USHORT Reserved; + union { + struct { + USHORT SubstituteNameOffset; + USHORT SubstituteNameLength; + USHORT PrintNameOffset; + USHORT PrintNameLength; + ULONG Flags; + WCHAR PathBuffer[1]; + } SymbolicLinkReparseBuffer; + struct { + USHORT SubstituteNameOffset; + USHORT SubstituteNameLength; + USHORT PrintNameOffset; + USHORT PrintNameLength; + WCHAR PathBuffer[1]; + } MountPointReparseBuffer; + struct { + UCHAR DataBuffer[1]; + } GenericReparseBuffer; + } DUMMYUNIONNAME; + } REPARSE_DATA_BUFFER, *PREPARSE_DATA_BUFFER; + """ + # See https://docs.microsoft.com/en-us/windows-hardware/drivers/ddi/content/ntifs/ns-ntifs-_reparse_data_buffer + _fields_ = [ + ('flags', ctypes.c_ulong), + ('path_buffer', ctypes.c_byte * (MAXIMUM_REPARSE_DATA_BUFFER_SIZE - 20)) + ] + + class MountReparseBuffer(ctypes.Structure): + _fields_ = [ + ('path_buffer', ctypes.c_byte * (MAXIMUM_REPARSE_DATA_BUFFER_SIZE - 16)), + ] + + class ReparseBufferField(ctypes.Union): + _fields_ = [ + ('symlink', SymbolicLinkReparseBuffer), + ('mount', MountReparseBuffer) + ] + + class ReparseBuffer(ctypes.Structure): + _anonymous_ = ("u",) + _fields_ = [ + ('reparse_tag', ctypes.c_ulong), + ('reparse_data_length', ctypes.c_ushort), + ('reserved', ctypes.c_ushort), + ('substitute_name_offset', ctypes.c_ushort), + ('substitute_name_length', ctypes.c_ushort), + ('print_name_offset', ctypes.c_ushort), + ('print_name_length', ctypes.c_ushort), + ('u', ReparseBufferField) + ] + + def is_reparse_point(path: Union[str, pathlib.Path]) -> bool: + GetFileAttributesW.argtypes = [LPCWSTR] + GetFileAttributesW.restype = DWORD + return _check_bit(GetFileAttributesW(str(path)), stat.FILE_ATTRIBUTE_REPARSE_POINT) + + def readlink(path: Union[str, pathlib.Path]) -> Union[str, pathlib.WindowsPath]: + # FILE_FLAG_OPEN_REPARSE_POINT alone is not enough if 'path' + # is a symbolic link to a directory or a NTFS junction. + # We need to set FILE_FLAG_BACKUP_SEMANTICS as well. + # See https://docs.microsoft.com/en-us/windows/desktop/api/fileapi/nf-fileapi-createfilea + + # description from _winapi.c:601 + # /* REPARSE_DATA_BUFFER usage is heavily under-documented, especially for + # junction points. Here's what I've learned along the way: + # - A junction point has two components: a print name and a substitute + # name. They both describe the link target, but the substitute name is + # the physical target and the print name is shown in directory listings. + # - The print name must be a native name, prefixed with "\??\". + # - Both names are stored after each other in the same buffer (the + # PathBuffer) and both must be NUL-terminated. + # - There are four members defining their respective offset and length + # inside PathBuffer: SubstituteNameOffset, SubstituteNameLength, + # PrintNameOffset and PrintNameLength. 
+        # - The total size we need to allocate for the REPARSE_DATA_BUFFER, thus,
+        #   is the sum of:
+        #   - the fixed header size (REPARSE_DATA_BUFFER_HEADER_SIZE)
+        #   - the size of the MountPointReparseBuffer member without the PathBuffer
+        #   - the size of the prefix ("\??\") in bytes
+        #   - the size of the print name in bytes
+        #   - the size of the substitute name in bytes
+        #   - the size of two NUL terminators in bytes */
+
+        target_is_path = isinstance(path, pathlib.Path)
+        if target_is_path:
+            target = str(path)
+        else:
+            target = path
+        CreateFileW.argtypes = [LPWSTR, DWORD, DWORD, LPVOID, DWORD, DWORD, HANDLE]
+        CreateFileW.restype = HANDLE
+        DeviceIoControl.argtypes = [HANDLE, DWORD, LPVOID, DWORD, LPVOID, DWORD, LPDWORD, LPVOID]
+        DeviceIoControl.restype = BOOL
+        handle = HANDLE(CreateFileW(target, GENERIC_READ, 0, None, OPEN_EXISTING,
+                                    FILE_FLAG_BACKUP_SEMANTICS | FILE_FLAG_OPEN_REPARSE_POINT, 0))
+        buf = ReparseBuffer()
+        ret = DWORD(0)
+        status = DeviceIoControl(handle, FSCTL_GET_REPARSE_POINT, None, 0, ctypes.byref(buf),
+                                 MAXIMUM_REPARSE_DATA_BUFFER_SIZE, ctypes.byref(ret), None)
+        CloseHandle(handle)
+        if not status:
+            logger = getLogger(__file__)
+            logger.error("Failed IOCTL access to REPARSE_POINT {}".format(target))
+            raise ValueError("not a symbolic link or access permission violation")
+
+        if buf.reparse_tag == IO_REPARSE_TAG_SYMLINK:
+            offset = buf.substitute_name_offset
+            ending = offset + buf.substitute_name_length
+            rpath = bytearray(buf.symlink.path_buffer)[offset:ending].decode('UTF-16-LE')
+        elif buf.reparse_tag == IO_REPARSE_TAG_MOUNT_POINT:
+            offset = buf.substitute_name_offset
+            ending = offset + buf.substitute_name_length
+            rpath = bytearray(buf.mount.path_buffer)[offset:ending].decode('UTF-16-LE')
+        else:
+            raise ValueError("not a symbolic link")
+        # CPython does the equivalent at posixmodule.c:7859 in py38:
+        # ```
+        # else if (rdb->ReparseTag == IO_REPARSE_TAG_MOUNT_POINT)
+        # {
+        #     name = (wchar_t *)((char*)rdb->MountPointReparseBuffer.PathBuffer +
+        #                        rdb->MountPointReparseBuffer.SubstituteNameOffset);
+        #     nameLen = rdb->MountPointReparseBuffer.SubstituteNameLength / sizeof(wchar_t);
+        # }
+        # else
+        # {
+        #     PyErr_SetString(PyExc_ValueError, "not a symbolic link");
+        # }
+        # if (nameLen > 4 && wcsncmp(name, L"\\??\\", 4) == 0) {
+        #     /* Our buffer is mutable, so this is okay */
+        #     name[1] = L'\\';
+        # }
+        # ```
+        # so substitute the prefix here.
+        if rpath.startswith('\\??\\'):
+            rpath = '\\\\' + rpath[2:]
+        if target_is_path:
+            return pathlib.WindowsPath(rpath)
+        else:
+            return rpath
diff --git a/requests/__init__.py b/requests/__init__.py
new file mode 100644
index 0000000..626247c
--- /dev/null
+++ b/requests/__init__.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+
+#   __
+#  /__)  _  _     _   _ _/   _
+# / (   (- (/ (/ (- _)  /  _)
+#          /
+
+"""
+Requests HTTP Library
+~~~~~~~~~~~~~~~~~~~~~
+
+Requests is an HTTP library, written in Python, for human beings.
+Basic GET usage:
+
+   >>> import requests
+   >>> r = requests.get('https://www.python.org')
+   >>> r.status_code
+   200
+   >>> b'Python is a programming language' in r.content
+   True
+
+... or POST:
+
+   >>> payload = dict(key1='value1', key2='value2')
+   >>> r = requests.post('https://httpbin.org/post', data=payload)
+   >>> print(r.text)
+   {
+     ...
+     "form": {
+       "key1": "value1",
+       "key2": "value2"
+     },
+     ...
+   }
+
+The other HTTP methods are supported - see `requests.api`. Full documentation
+is at <https://requests.readthedocs.io>.
+
+:copyright: (c) 2017 by Kenneth Reitz.
+:license: Apache 2.0, see LICENSE for more details.
+""" + +import urllib3 +import chardet +import warnings +from .exceptions import RequestsDependencyWarning + + +def check_compatibility(urllib3_version, chardet_version): + urllib3_version = urllib3_version.split('.') + assert urllib3_version != ['dev'] # Verify urllib3 isn't installed from git. + + # Sometimes, urllib3 only reports its version as 16.1. + if len(urllib3_version) == 2: + urllib3_version.append('0') + + # Check urllib3 for compatibility. + major, minor, patch = urllib3_version # noqa: F811 + major, minor, patch = int(major), int(minor), int(patch) + # urllib3 >= 1.21.1, <= 1.25 + assert major == 1 + assert minor >= 21 + assert minor <= 25 + + # Check chardet for compatibility. + major, minor, patch = chardet_version.split('.')[:3] + major, minor, patch = int(major), int(minor), int(patch) + # chardet >= 3.0.2, < 3.1.0 + assert major == 3 + assert minor < 1 + assert patch >= 2 + + +def _check_cryptography(cryptography_version): + # cryptography < 1.3.4 + try: + cryptography_version = list(map(int, cryptography_version.split('.'))) + except ValueError: + return + + if cryptography_version < [1, 3, 4]: + warning = 'Old version of cryptography ({}) may cause slowdown.'.format(cryptography_version) + warnings.warn(warning, RequestsDependencyWarning) + +# Check imported dependencies for compatibility. +try: + check_compatibility(urllib3.__version__, chardet.__version__) +except (AssertionError, ValueError): + warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported " + "version!".format(urllib3.__version__, chardet.__version__), + RequestsDependencyWarning) + +# Attempt to enable urllib3's SNI support, if possible +try: + from urllib3.contrib import pyopenssl + pyopenssl.inject_into_urllib3() + + # Check cryptography version + from cryptography import __version__ as cryptography_version + _check_cryptography(cryptography_version) +except ImportError: + pass + +# urllib3's DependencyWarnings should be silenced. +from urllib3.exceptions import DependencyWarning +warnings.simplefilter('ignore', DependencyWarning) + +from .__version__ import __title__, __description__, __url__, __version__ +from .__version__ import __build__, __author__, __author_email__, __license__ +from .__version__ import __copyright__, __cake__ + +from . import utils +from . import packages +from .models import Request, Response, PreparedRequest +from .api import request, get, head, post, patch, put, delete, options +from .sessions import session, Session +from .status_codes import codes +from .exceptions import ( + RequestException, Timeout, URLRequired, + TooManyRedirects, HTTPError, ConnectionError, + FileModeWarning, ConnectTimeout, ReadTimeout +) + +# Set default logging handler to avoid "No handler found" warnings. +import logging +from logging import NullHandler + +logging.getLogger(__name__).addHandler(NullHandler()) + +# FileModeWarnings go off per the default. 
+warnings.simplefilter('default', FileModeWarning, append=True)
diff --git a/requests/__version__.py b/requests/__version__.py
new file mode 100644
index 0000000..b9e7df4
--- /dev/null
+++ b/requests/__version__.py
@@ -0,0 +1,14 @@
+# .-. .-. .-. . . .-. .-. .-. .-.
+# |(  |-  |.| | | |-  `-.  |  `-.
+# ' ' `-' `-`.`-' `-' `-'  '  `-'
+
+__title__ = 'requests'
+__description__ = 'Python HTTP for Humans.'
+__url__ = 'https://requests.readthedocs.io'
+__version__ = '2.23.0'
+__build__ = 0x022300
+__author__ = 'Kenneth Reitz'
+__author_email__ = 'me@kennethreitz.org'
+__license__ = 'Apache 2.0'
+__copyright__ = 'Copyright 2020 Kenneth Reitz'
+__cake__ = u'\u2728 \U0001f370 \u2728'
diff --git a/requests/_internal_utils.py b/requests/_internal_utils.py
new file mode 100644
index 0000000..759d9a5
--- /dev/null
+++ b/requests/_internal_utils.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+
+"""
+requests._internal_utils
+~~~~~~~~~~~~~~
+
+Provides utility functions that are consumed internally by Requests
+which depend on extremely few external helpers (such as compat).
+"""
+
+from .compat import is_py2, builtin_str, str
+
+
+def to_native_string(string, encoding='ascii'):
+    """Given a string object, regardless of type, returns a representation of
+    that string in the native string type, encoding and decoding where
+    necessary. This assumes ASCII unless told otherwise.
+    """
+    if isinstance(string, builtin_str):
+        out = string
+    else:
+        if is_py2:
+            out = string.encode(encoding)
+        else:
+            out = string.decode(encoding)
+
+    return out
+
+
+def unicode_is_ascii(u_string):
+    """Determine if unicode string only contains ASCII characters.
+
+    :param str u_string: unicode string to check. Must be unicode
+        and not Python 2 `str`.
+    :rtype: bool
+    """
+    assert isinstance(u_string, str)
+    try:
+        u_string.encode('ascii')
+        return True
+    except UnicodeEncodeError:
+        return False
diff --git a/requests/adapters.py b/requests/adapters.py
new file mode 100644
index 0000000..fa4d9b3
--- /dev/null
+++ b/requests/adapters.py
@@ -0,0 +1,533 @@
+# -*- coding: utf-8 -*-
+
+"""
+requests.adapters
+~~~~~~~~~~~~~~~~~
+
+This module contains the transport adapters that Requests uses to define
+and maintain connections.
+""" + +import os.path +import socket + +from urllib3.poolmanager import PoolManager, proxy_from_url +from urllib3.response import HTTPResponse +from urllib3.util import parse_url +from urllib3.util import Timeout as TimeoutSauce +from urllib3.util.retry import Retry +from urllib3.exceptions import ClosedPoolError +from urllib3.exceptions import ConnectTimeoutError +from urllib3.exceptions import HTTPError as _HTTPError +from urllib3.exceptions import MaxRetryError +from urllib3.exceptions import NewConnectionError +from urllib3.exceptions import ProxyError as _ProxyError +from urllib3.exceptions import ProtocolError +from urllib3.exceptions import ReadTimeoutError +from urllib3.exceptions import SSLError as _SSLError +from urllib3.exceptions import ResponseError +from urllib3.exceptions import LocationValueError + +from .models import Response +from .compat import urlparse, basestring +from .utils import (DEFAULT_CA_BUNDLE_PATH, extract_zipped_paths, + get_encoding_from_headers, prepend_scheme_if_needed, + get_auth_from_url, urldefragauth, select_proxy) +from .structures import CaseInsensitiveDict +from .cookies import extract_cookies_to_jar +from .exceptions import (ConnectionError, ConnectTimeout, ReadTimeout, SSLError, + ProxyError, RetryError, InvalidSchema, InvalidProxyURL, + InvalidURL) +from .auth import _basic_auth_str + +try: + from urllib3.contrib.socks import SOCKSProxyManager +except ImportError: + def SOCKSProxyManager(*args, **kwargs): + raise InvalidSchema("Missing dependencies for SOCKS support.") + +DEFAULT_POOLBLOCK = False +DEFAULT_POOLSIZE = 10 +DEFAULT_RETRIES = 0 +DEFAULT_POOL_TIMEOUT = None + + +class BaseAdapter(object): + """The Base Transport Adapter""" + + def __init__(self): + super(BaseAdapter, self).__init__() + + def send(self, request, stream=False, timeout=None, verify=True, + cert=None, proxies=None): + """Sends PreparedRequest object. Returns Response object. + + :param request: The :class:`PreparedRequest ` being sent. + :param stream: (optional) Whether to stream the request content. + :param timeout: (optional) How long to wait for the server to send + data before giving up, as a float, or a :ref:`(connect timeout, + read timeout) ` tuple. + :type timeout: float or tuple + :param verify: (optional) Either a boolean, in which case it controls whether we verify + the server's TLS certificate, or a string, in which case it must be a path + to a CA bundle to use + :param cert: (optional) Any user-provided SSL certificate to be trusted. + :param proxies: (optional) The proxies dictionary to apply to the request. + """ + raise NotImplementedError + + def close(self): + """Cleans up adapter specific items.""" + raise NotImplementedError + + +class HTTPAdapter(BaseAdapter): + """The built-in HTTP Adapter for urllib3. + + Provides a general-case interface for Requests sessions to contact HTTP and + HTTPS urls by implementing the Transport Adapter interface. This class will + usually be created by the :class:`Session ` class under the + covers. + + :param pool_connections: The number of urllib3 connection pools to cache. + :param pool_maxsize: The maximum number of connections to save in the pool. + :param max_retries: The maximum number of retries each connection + should attempt. Note, this applies only to failed DNS lookups, socket + connections and connection timeouts, never to requests where data has + made it to the server. By default, Requests does not retry failed + connections. 
If you need granular control over the conditions under + which we retry a request, import urllib3's ``Retry`` class and pass + that instead. + :param pool_block: Whether the connection pool should block for connections. + + Usage:: + + >>> import requests + >>> s = requests.Session() + >>> a = requests.adapters.HTTPAdapter(max_retries=3) + >>> s.mount('http://', a) + """ + __attrs__ = ['max_retries', 'config', '_pool_connections', '_pool_maxsize', + '_pool_block'] + + def __init__(self, pool_connections=DEFAULT_POOLSIZE, + pool_maxsize=DEFAULT_POOLSIZE, max_retries=DEFAULT_RETRIES, + pool_block=DEFAULT_POOLBLOCK): + if max_retries == DEFAULT_RETRIES: + self.max_retries = Retry(0, read=False) + else: + self.max_retries = Retry.from_int(max_retries) + self.config = {} + self.proxy_manager = {} + + super(HTTPAdapter, self).__init__() + + self._pool_connections = pool_connections + self._pool_maxsize = pool_maxsize + self._pool_block = pool_block + + self.init_poolmanager(pool_connections, pool_maxsize, block=pool_block) + + def __getstate__(self): + return {attr: getattr(self, attr, None) for attr in self.__attrs__} + + def __setstate__(self, state): + # Can't handle by adding 'proxy_manager' to self.__attrs__ because + # self.poolmanager uses a lambda function, which isn't pickleable. + self.proxy_manager = {} + self.config = {} + + for attr, value in state.items(): + setattr(self, attr, value) + + self.init_poolmanager(self._pool_connections, self._pool_maxsize, + block=self._pool_block) + + def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK, **pool_kwargs): + """Initializes a urllib3 PoolManager. + + This method should not be called from user code, and is only + exposed for use when subclassing the + :class:`HTTPAdapter `. + + :param connections: The number of urllib3 connection pools to cache. + :param maxsize: The maximum number of connections to save in the pool. + :param block: Block when no free connections are available. + :param pool_kwargs: Extra keyword arguments used to initialize the Pool Manager. + """ + # save these values for pickling + self._pool_connections = connections + self._pool_maxsize = maxsize + self._pool_block = block + + self.poolmanager = PoolManager(num_pools=connections, maxsize=maxsize, + block=block, strict=True, **pool_kwargs) + + def proxy_manager_for(self, proxy, **proxy_kwargs): + """Return urllib3 ProxyManager for the given proxy. + + This method should not be called from user code, and is only + exposed for use when subclassing the + :class:`HTTPAdapter `. + + :param proxy: The proxy to return a urllib3 ProxyManager for. + :param proxy_kwargs: Extra keyword arguments used to configure the Proxy Manager. + :returns: ProxyManager + :rtype: urllib3.ProxyManager + """ + if proxy in self.proxy_manager: + manager = self.proxy_manager[proxy] + elif proxy.lower().startswith('socks'): + username, password = get_auth_from_url(proxy) + manager = self.proxy_manager[proxy] = SOCKSProxyManager( + proxy, + username=username, + password=password, + num_pools=self._pool_connections, + maxsize=self._pool_maxsize, + block=self._pool_block, + **proxy_kwargs + ) + else: + proxy_headers = self.proxy_headers(proxy) + manager = self.proxy_manager[proxy] = proxy_from_url( + proxy, + proxy_headers=proxy_headers, + num_pools=self._pool_connections, + maxsize=self._pool_maxsize, + block=self._pool_block, + **proxy_kwargs) + + return manager + + def cert_verify(self, conn, url, verify, cert): + """Verify a SSL certificate. 
This method should not be called from user + code, and is only exposed for use when subclassing the + :class:`HTTPAdapter `. + + :param conn: The urllib3 connection object associated with the cert. + :param url: The requested URL. + :param verify: Either a boolean, in which case it controls whether we verify + the server's TLS certificate, or a string, in which case it must be a path + to a CA bundle to use + :param cert: The SSL certificate to verify. + """ + if url.lower().startswith('https') and verify: + + cert_loc = None + + # Allow self-specified cert location. + if verify is not True: + cert_loc = verify + + if not cert_loc: + cert_loc = extract_zipped_paths(DEFAULT_CA_BUNDLE_PATH) + + if not cert_loc or not os.path.exists(cert_loc): + raise IOError("Could not find a suitable TLS CA certificate bundle, " + "invalid path: {}".format(cert_loc)) + + conn.cert_reqs = 'CERT_REQUIRED' + + if not os.path.isdir(cert_loc): + conn.ca_certs = cert_loc + else: + conn.ca_cert_dir = cert_loc + else: + conn.cert_reqs = 'CERT_NONE' + conn.ca_certs = None + conn.ca_cert_dir = None + + if cert: + if not isinstance(cert, basestring): + conn.cert_file = cert[0] + conn.key_file = cert[1] + else: + conn.cert_file = cert + conn.key_file = None + if conn.cert_file and not os.path.exists(conn.cert_file): + raise IOError("Could not find the TLS certificate file, " + "invalid path: {}".format(conn.cert_file)) + if conn.key_file and not os.path.exists(conn.key_file): + raise IOError("Could not find the TLS key file, " + "invalid path: {}".format(conn.key_file)) + + def build_response(self, req, resp): + """Builds a :class:`Response ` object from a urllib3 + response. This should not be called from user code, and is only exposed + for use when subclassing the + :class:`HTTPAdapter ` + + :param req: The :class:`PreparedRequest ` used to generate the response. + :param resp: The urllib3 response object. + :rtype: requests.Response + """ + response = Response() + + # Fallback to None if there's no status_code, for whatever reason. + response.status_code = getattr(resp, 'status', None) + + # Make headers case-insensitive. + response.headers = CaseInsensitiveDict(getattr(resp, 'headers', {})) + + # Set encoding. + response.encoding = get_encoding_from_headers(response.headers) + response.raw = resp + response.reason = response.raw.reason + + if isinstance(req.url, bytes): + response.url = req.url.decode('utf-8') + else: + response.url = req.url + + # Add new cookies from the server. + extract_cookies_to_jar(response.cookies, req, resp) + + # Give the Response some context. + response.request = req + response.connection = self + + return response + + def get_connection(self, url, proxies=None): + """Returns a urllib3 connection for the given URL. This should not be + called from user code, and is only exposed for use when subclassing the + :class:`HTTPAdapter `. + + :param url: The URL to connect to. + :param proxies: (optional) A Requests-style dictionary of proxies used on this request. + :rtype: urllib3.ConnectionPool + """ + proxy = select_proxy(url, proxies) + + if proxy: + proxy = prepend_scheme_if_needed(proxy, 'http') + proxy_url = parse_url(proxy) + if not proxy_url.host: + raise InvalidProxyURL("Please check proxy URL. 
It is malformed" + " and could be missing the host.") + proxy_manager = self.proxy_manager_for(proxy) + conn = proxy_manager.connection_from_url(url) + else: + # Only scheme should be lower case + parsed = urlparse(url) + url = parsed.geturl() + conn = self.poolmanager.connection_from_url(url) + + return conn + + def close(self): + """Disposes of any internal state. + + Currently, this closes the PoolManager and any active ProxyManager, + which closes any pooled connections. + """ + self.poolmanager.clear() + for proxy in self.proxy_manager.values(): + proxy.clear() + + def request_url(self, request, proxies): + """Obtain the url to use when making the final request. + + If the message is being sent through a HTTP proxy, the full URL has to + be used. Otherwise, we should only use the path portion of the URL. + + This should not be called from user code, and is only exposed for use + when subclassing the + :class:`HTTPAdapter `. + + :param request: The :class:`PreparedRequest ` being sent. + :param proxies: A dictionary of schemes or schemes and hosts to proxy URLs. + :rtype: str + """ + proxy = select_proxy(request.url, proxies) + scheme = urlparse(request.url).scheme + + is_proxied_http_request = (proxy and scheme != 'https') + using_socks_proxy = False + if proxy: + proxy_scheme = urlparse(proxy).scheme.lower() + using_socks_proxy = proxy_scheme.startswith('socks') + + url = request.path_url + if is_proxied_http_request and not using_socks_proxy: + url = urldefragauth(request.url) + + return url + + def add_headers(self, request, **kwargs): + """Add any headers needed by the connection. As of v2.0 this does + nothing by default, but is left for overriding by users that subclass + the :class:`HTTPAdapter `. + + This should not be called from user code, and is only exposed for use + when subclassing the + :class:`HTTPAdapter `. + + :param request: The :class:`PreparedRequest ` to add headers to. + :param kwargs: The keyword arguments from the call to send(). + """ + pass + + def proxy_headers(self, proxy): + """Returns a dictionary of the headers to add to any request sent + through a proxy. This works with urllib3 magic to ensure that they are + correctly sent to the proxy, rather than in a tunnelled request if + CONNECT is being used. + + This should not be called from user code, and is only exposed for use + when subclassing the + :class:`HTTPAdapter `. + + :param proxy: The url of the proxy being used for this request. + :rtype: dict + """ + headers = {} + username, password = get_auth_from_url(proxy) + + if username: + headers['Proxy-Authorization'] = _basic_auth_str(username, + password) + + return headers + + def send(self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None): + """Sends PreparedRequest object. Returns Response object. + + :param request: The :class:`PreparedRequest ` being sent. + :param stream: (optional) Whether to stream the request content. + :param timeout: (optional) How long to wait for the server to send + data before giving up, as a float, or a :ref:`(connect timeout, + read timeout) ` tuple. + :type timeout: float or tuple or urllib3 Timeout object + :param verify: (optional) Either a boolean, in which case it controls whether + we verify the server's TLS certificate, or a string, in which case it + must be a path to a CA bundle to use + :param cert: (optional) Any user-provided SSL certificate to be trusted. + :param proxies: (optional) The proxies dictionary to apply to the request. 
+ :rtype: requests.Response + """ + + try: + conn = self.get_connection(request.url, proxies) + except LocationValueError as e: + raise InvalidURL(e, request=request) + + self.cert_verify(conn, request.url, verify, cert) + url = self.request_url(request, proxies) + self.add_headers(request, stream=stream, timeout=timeout, verify=verify, cert=cert, proxies=proxies) + + chunked = not (request.body is None or 'Content-Length' in request.headers) + + if isinstance(timeout, tuple): + try: + connect, read = timeout + timeout = TimeoutSauce(connect=connect, read=read) + except ValueError as e: + # this may raise a string formatting error. + err = ("Invalid timeout {}. Pass a (connect, read) " + "timeout tuple, or a single float to set " + "both timeouts to the same value".format(timeout)) + raise ValueError(err) + elif isinstance(timeout, TimeoutSauce): + pass + else: + timeout = TimeoutSauce(connect=timeout, read=timeout) + + try: + if not chunked: + resp = conn.urlopen( + method=request.method, + url=url, + body=request.body, + headers=request.headers, + redirect=False, + assert_same_host=False, + preload_content=False, + decode_content=False, + retries=self.max_retries, + timeout=timeout + ) + + # Send the request. + else: + if hasattr(conn, 'proxy_pool'): + conn = conn.proxy_pool + + low_conn = conn._get_conn(timeout=DEFAULT_POOL_TIMEOUT) + + try: + low_conn.putrequest(request.method, + url, + skip_accept_encoding=True) + + for header, value in request.headers.items(): + low_conn.putheader(header, value) + + low_conn.endheaders() + + for i in request.body: + low_conn.send(hex(len(i))[2:].encode('utf-8')) + low_conn.send(b'\r\n') + low_conn.send(i) + low_conn.send(b'\r\n') + low_conn.send(b'0\r\n\r\n') + + # Receive the response from the server + try: + # For Python 2.7, use buffering of HTTP responses + r = low_conn.getresponse(buffering=True) + except TypeError: + # For compatibility with Python 3.3+ + r = low_conn.getresponse() + + resp = HTTPResponse.from_httplib( + r, + pool=conn, + connection=low_conn, + preload_content=False, + decode_content=False + ) + except: + # If we hit any problems here, clean up the connection. + # Then, reraise so that we can handle the actual exception. + low_conn.close() + raise + + except (ProtocolError, socket.error) as err: + raise ConnectionError(err, request=request) + + except MaxRetryError as e: + if isinstance(e.reason, ConnectTimeoutError): + # TODO: Remove this in 3.0.0: see #2811 + if not isinstance(e.reason, NewConnectionError): + raise ConnectTimeout(e, request=request) + + if isinstance(e.reason, ResponseError): + raise RetryError(e, request=request) + + if isinstance(e.reason, _ProxyError): + raise ProxyError(e, request=request) + + if isinstance(e.reason, _SSLError): + # This branch is for urllib3 v1.22 and later. 
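+                    # urllib3 >= 1.22 wraps TLS failures in MaxRetryError
+                    # with e.reason set to an SSLError, so it is unwrapped
+                    # here; the (_SSLError, _HTTPError) handler further down
+                    # still covers older urllib3 that raised SSLError bare.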
+ raise SSLError(e, request=request) + + raise ConnectionError(e, request=request) + + except ClosedPoolError as e: + raise ConnectionError(e, request=request) + + except _ProxyError as e: + raise ProxyError(e) + + except (_SSLError, _HTTPError) as e: + if isinstance(e, _SSLError): + # This branch is for urllib3 versions earlier than v1.22 + raise SSLError(e, request=request) + elif isinstance(e, ReadTimeoutError): + raise ReadTimeout(e, request=request) + else: + raise + + return self.build_response(request, resp) diff --git a/requests/api.py b/requests/api.py new file mode 100644 index 0000000..e978e20 --- /dev/null +++ b/requests/api.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- + +""" +requests.api +~~~~~~~~~~~~ + +This module implements the Requests API. + +:copyright: (c) 2012 by Kenneth Reitz. +:license: Apache2, see LICENSE for more details. +""" + +from . import sessions + + +def request(method, url, **kwargs): + """Constructs and sends a :class:`Request `. + + :param method: method for the new :class:`Request` object: ``GET``, ``OPTIONS``, ``HEAD``, ``POST``, ``PUT``, ``PATCH``, or ``DELETE``. + :param url: URL for the new :class:`Request` object. + :param params: (optional) Dictionary, list of tuples or bytes to send + in the query string for the :class:`Request`. + :param data: (optional) Dictionary, list of tuples, bytes, or file-like + object to send in the body of the :class:`Request`. + :param json: (optional) A JSON serializable Python object to send in the body of the :class:`Request`. + :param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`. + :param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`. + :param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``) for multipart encoding upload. + ``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple ``('filename', fileobj, 'content_type')`` + or a 4-tuple ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content-type'`` is a string + defining the content type of the given file and ``custom_headers`` a dict-like object containing additional headers + to add for the file. + :param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth. + :param timeout: (optional) How many seconds to wait for the server to send data + before giving up, as a float, or a :ref:`(connect timeout, read + timeout) ` tuple. + :type timeout: float or tuple + :param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection. Defaults to ``True``. + :type allow_redirects: bool + :param proxies: (optional) Dictionary mapping protocol to the URL of the proxy. + :param verify: (optional) Either a boolean, in which case it controls whether we verify + the server's TLS certificate, or a string, in which case it must be a path + to a CA bundle to use. Defaults to ``True``. + :param stream: (optional) if ``False``, the response content will be immediately downloaded. + :param cert: (optional) if String, path to ssl client cert file (.pem). If Tuple, ('cert', 'key') pair. + :return: :class:`Response ` object + :rtype: requests.Response + + Usage:: + + >>> import requests + >>> req = requests.request('GET', 'https://httpbin.org/get') + >>> req + + """ + + # By using the 'with' statement we are sure the session is closed, thus we + # avoid leaving sockets open which can trigger a ResourceWarning in some + # cases, and look like a memory leak in others. 
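+    # Equivalent long-hand form of the 'with' block below (illustrative only):
+    #
+    #     session = sessions.Session()
+    #     try:
+    #         return session.request(method=method, url=url, **kwargs)
+    #     finally:
+    #         session.close()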
+ with sessions.Session() as session: + return session.request(method=method, url=url, **kwargs) + + +def get(url, params=None, **kwargs): + r"""Sends a GET request. + + :param url: URL for the new :class:`Request` object. + :param params: (optional) Dictionary, list of tuples or bytes to send + in the query string for the :class:`Request`. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :return: :class:`Response ` object + :rtype: requests.Response + """ + + kwargs.setdefault('allow_redirects', True) + return request('get', url, params=params, **kwargs) + + +def options(url, **kwargs): + r"""Sends an OPTIONS request. + + :param url: URL for the new :class:`Request` object. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :return: :class:`Response ` object + :rtype: requests.Response + """ + + kwargs.setdefault('allow_redirects', True) + return request('options', url, **kwargs) + + +def head(url, **kwargs): + r"""Sends a HEAD request. + + :param url: URL for the new :class:`Request` object. + :param \*\*kwargs: Optional arguments that ``request`` takes. If + `allow_redirects` is not provided, it will be set to `False` (as + opposed to the default :meth:`request` behavior). + :return: :class:`Response ` object + :rtype: requests.Response + """ + + kwargs.setdefault('allow_redirects', False) + return request('head', url, **kwargs) + + +def post(url, data=None, json=None, **kwargs): + r"""Sends a POST request. + + :param url: URL for the new :class:`Request` object. + :param data: (optional) Dictionary, list of tuples, bytes, or file-like + object to send in the body of the :class:`Request`. + :param json: (optional) json data to send in the body of the :class:`Request`. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :return: :class:`Response ` object + :rtype: requests.Response + """ + + return request('post', url, data=data, json=json, **kwargs) + + +def put(url, data=None, **kwargs): + r"""Sends a PUT request. + + :param url: URL for the new :class:`Request` object. + :param data: (optional) Dictionary, list of tuples, bytes, or file-like + object to send in the body of the :class:`Request`. + :param json: (optional) json data to send in the body of the :class:`Request`. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :return: :class:`Response ` object + :rtype: requests.Response + """ + + return request('put', url, data=data, **kwargs) + + +def patch(url, data=None, **kwargs): + r"""Sends a PATCH request. + + :param url: URL for the new :class:`Request` object. + :param data: (optional) Dictionary, list of tuples, bytes, or file-like + object to send in the body of the :class:`Request`. + :param json: (optional) json data to send in the body of the :class:`Request`. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :return: :class:`Response ` object + :rtype: requests.Response + """ + + return request('patch', url, data=data, **kwargs) + + +def delete(url, **kwargs): + r"""Sends a DELETE request. + + :param url: URL for the new :class:`Request` object. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :return: :class:`Response ` object + :rtype: requests.Response + """ + + return request('delete', url, **kwargs) diff --git a/requests/auth.py b/requests/auth.py new file mode 100644 index 0000000..eeface3 --- /dev/null +++ b/requests/auth.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- + +""" +requests.auth +~~~~~~~~~~~~~ + +This module contains the authentication handlers for Requests. 
+""" + +import os +import re +import time +import hashlib +import threading +import warnings + +from base64 import b64encode + +from .compat import urlparse, str, basestring +from .cookies import extract_cookies_to_jar +from ._internal_utils import to_native_string +from .utils import parse_dict_header + +CONTENT_TYPE_FORM_URLENCODED = 'application/x-www-form-urlencoded' +CONTENT_TYPE_MULTI_PART = 'multipart/form-data' + + +def _basic_auth_str(username, password): + """Returns a Basic Auth string.""" + + # "I want us to put a big-ol' comment on top of it that + # says that this behaviour is dumb but we need to preserve + # it because people are relying on it." + # - Lukasa + # + # These are here solely to maintain backwards compatibility + # for things like ints. This will be removed in 3.0.0. + if not isinstance(username, basestring): + warnings.warn( + "Non-string usernames will no longer be supported in Requests " + "3.0.0. Please convert the object you've passed in ({!r}) to " + "a string or bytes object in the near future to avoid " + "problems.".format(username), + category=DeprecationWarning, + ) + username = str(username) + + if not isinstance(password, basestring): + warnings.warn( + "Non-string passwords will no longer be supported in Requests " + "3.0.0. Please convert the object you've passed in ({!r}) to " + "a string or bytes object in the near future to avoid " + "problems.".format(type(password)), + category=DeprecationWarning, + ) + password = str(password) + # -- End Removal -- + + if isinstance(username, str): + username = username.encode('latin1') + + if isinstance(password, str): + password = password.encode('latin1') + + authstr = 'Basic ' + to_native_string( + b64encode(b':'.join((username, password))).strip() + ) + + return authstr + + +class AuthBase(object): + """Base class that all auth implementations derive from""" + + def __call__(self, r): + raise NotImplementedError('Auth hooks must be callable.') + + +class HTTPBasicAuth(AuthBase): + """Attaches HTTP Basic Authentication to the given Request object.""" + + def __init__(self, username, password): + self.username = username + self.password = password + + def __eq__(self, other): + return all([ + self.username == getattr(other, 'username', None), + self.password == getattr(other, 'password', None) + ]) + + def __ne__(self, other): + return not self == other + + def __call__(self, r): + r.headers['Authorization'] = _basic_auth_str(self.username, self.password) + return r + + +class HTTPProxyAuth(HTTPBasicAuth): + """Attaches HTTP Proxy Authentication to a given Request object.""" + + def __call__(self, r): + r.headers['Proxy-Authorization'] = _basic_auth_str(self.username, self.password) + return r + + +class HTTPDigestAuth(AuthBase): + """Attaches HTTP Digest Authentication to the given Request object.""" + + def __init__(self, username, password): + self.username = username + self.password = password + # Keep state in per-thread local storage + self._thread_local = threading.local() + + def init_per_thread_state(self): + # Ensure state is initialized just once per-thread + if not hasattr(self._thread_local, 'init'): + self._thread_local.init = True + self._thread_local.last_nonce = '' + self._thread_local.nonce_count = 0 + self._thread_local.chal = {} + self._thread_local.pos = None + self._thread_local.num_401_calls = None + + def build_digest_header(self, method, url): + """ + :rtype: str + """ + + realm = self._thread_local.chal['realm'] + nonce = self._thread_local.chal['nonce'] + qop = 
self._thread_local.chal.get('qop') + algorithm = self._thread_local.chal.get('algorithm') + opaque = self._thread_local.chal.get('opaque') + hash_utf8 = None + + if algorithm is None: + _algorithm = 'MD5' + else: + _algorithm = algorithm.upper() + # lambdas assume digest modules are imported at the top level + if _algorithm == 'MD5' or _algorithm == 'MD5-SESS': + def md5_utf8(x): + if isinstance(x, str): + x = x.encode('utf-8') + return hashlib.md5(x).hexdigest() + hash_utf8 = md5_utf8 + elif _algorithm == 'SHA': + def sha_utf8(x): + if isinstance(x, str): + x = x.encode('utf-8') + return hashlib.sha1(x).hexdigest() + hash_utf8 = sha_utf8 + elif _algorithm == 'SHA-256': + def sha256_utf8(x): + if isinstance(x, str): + x = x.encode('utf-8') + return hashlib.sha256(x).hexdigest() + hash_utf8 = sha256_utf8 + elif _algorithm == 'SHA-512': + def sha512_utf8(x): + if isinstance(x, str): + x = x.encode('utf-8') + return hashlib.sha512(x).hexdigest() + hash_utf8 = sha512_utf8 + + KD = lambda s, d: hash_utf8("%s:%s" % (s, d)) + + if hash_utf8 is None: + return None + + # XXX not implemented yet + entdig = None + p_parsed = urlparse(url) + #: path is request-uri defined in RFC 2616 which should not be empty + path = p_parsed.path or "/" + if p_parsed.query: + path += '?' + p_parsed.query + + A1 = '%s:%s:%s' % (self.username, realm, self.password) + A2 = '%s:%s' % (method, path) + + HA1 = hash_utf8(A1) + HA2 = hash_utf8(A2) + + if nonce == self._thread_local.last_nonce: + self._thread_local.nonce_count += 1 + else: + self._thread_local.nonce_count = 1 + ncvalue = '%08x' % self._thread_local.nonce_count + s = str(self._thread_local.nonce_count).encode('utf-8') + s += nonce.encode('utf-8') + s += time.ctime().encode('utf-8') + s += os.urandom(8) + + cnonce = (hashlib.sha1(s).hexdigest()[:16]) + if _algorithm == 'MD5-SESS': + HA1 = hash_utf8('%s:%s:%s' % (HA1, nonce, cnonce)) + + if not qop: + respdig = KD(HA1, "%s:%s" % (nonce, HA2)) + elif qop == 'auth' or 'auth' in qop.split(','): + noncebit = "%s:%s:%s:%s:%s" % ( + nonce, ncvalue, cnonce, 'auth', HA2 + ) + respdig = KD(HA1, noncebit) + else: + # XXX handle auth-int. + return None + + self._thread_local.last_nonce = nonce + + # XXX should the partial digests be encoded too? + base = 'username="%s", realm="%s", nonce="%s", uri="%s", ' \ + 'response="%s"' % (self.username, realm, nonce, path, respdig) + if opaque: + base += ', opaque="%s"' % opaque + if algorithm: + base += ', algorithm="%s"' % algorithm + if entdig: + base += ', digest="%s"' % entdig + if qop: + base += ', qop="auth", nc=%s, cnonce="%s"' % (ncvalue, cnonce) + + return 'Digest %s' % (base) + + def handle_redirect(self, r, **kwargs): + """Reset num_401_calls counter on redirects.""" + if r.is_redirect: + self._thread_local.num_401_calls = 1 + + def handle_401(self, r, **kwargs): + """ + Takes the given response and tries digest-auth, if needed. + + :rtype: requests.Response + """ + + # If response is not 4xx, do not auth + # See https://github.com/psf/requests/issues/3772 + if not 400 <= r.status_code < 500: + self._thread_local.num_401_calls = 1 + return r + + if self._thread_local.pos is not None: + # Rewind the file position indicator of the body to where + # it was to resend the request. 
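+            # The original send consumed the file-like body; __call__ saved
+            # body.tell() into _thread_local.pos, so seeking back lets the
+            # authenticated retry transmit the same bytes again.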
+ r.request.body.seek(self._thread_local.pos) + s_auth = r.headers.get('www-authenticate', '') + + if 'digest' in s_auth.lower() and self._thread_local.num_401_calls < 2: + + self._thread_local.num_401_calls += 1 + pat = re.compile(r'digest ', flags=re.IGNORECASE) + self._thread_local.chal = parse_dict_header(pat.sub('', s_auth, count=1)) + + # Consume content and release the original connection + # to allow our new request to reuse the same one. + r.content + r.close() + prep = r.request.copy() + extract_cookies_to_jar(prep._cookies, r.request, r.raw) + prep.prepare_cookies(prep._cookies) + + prep.headers['Authorization'] = self.build_digest_header( + prep.method, prep.url) + _r = r.connection.send(prep, **kwargs) + _r.history.append(r) + _r.request = prep + + return _r + + self._thread_local.num_401_calls = 1 + return r + + def __call__(self, r): + # Initialize per-thread state, if needed + self.init_per_thread_state() + # If we have a saved nonce, skip the 401 + if self._thread_local.last_nonce: + r.headers['Authorization'] = self.build_digest_header(r.method, r.url) + try: + self._thread_local.pos = r.body.tell() + except AttributeError: + # In the case of HTTPDigestAuth being reused and the body of + # the previous request was a file-like object, pos has the + # file position of the previous body. Ensure it's set to + # None. + self._thread_local.pos = None + r.register_hook('response', self.handle_401) + r.register_hook('response', self.handle_redirect) + self._thread_local.num_401_calls = 1 + + return r + + def __eq__(self, other): + return all([ + self.username == getattr(other, 'username', None), + self.password == getattr(other, 'password', None) + ]) + + def __ne__(self, other): + return not self == other diff --git a/requests/certs.py b/requests/certs.py new file mode 100644 index 0000000..d1a378d --- /dev/null +++ b/requests/certs.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +requests.certs +~~~~~~~~~~~~~~ + +This module returns the preferred default CA certificate bundle. There is +only one — the one from the certifi package. + +If you are packaging Requests, e.g., for a Linux distribution or a managed +environment, you can change the definition of where() to return a separately +packaged CA bundle. +""" +from certifi import where + +if __name__ == '__main__': + print(where()) diff --git a/requests/compat.py b/requests/compat.py new file mode 100644 index 0000000..5de0769 --- /dev/null +++ b/requests/compat.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- + +""" +requests.compat +~~~~~~~~~~~~~~~ + +This module handles import compatibility issues between Python 2 and +Python 3. +""" + +import chardet + +import sys + +# ------- +# Pythons +# ------- + +# Syntax sugar. +_ver = sys.version_info + +#: Python 2.x? +is_py2 = (_ver[0] == 2) + +#: Python 3.x? +is_py3 = (_ver[0] == 3) + +try: + import simplejson as json +except ImportError: + import json + +# --------- +# Specifics +# --------- + +if is_py2: + from urllib import ( + quote, unquote, quote_plus, unquote_plus, urlencode, getproxies, + proxy_bypass, proxy_bypass_environment, getproxies_environment) + from urlparse import urlparse, urlunparse, urljoin, urlsplit, urldefrag + from urllib2 import parse_http_list + import cookielib + from Cookie import Morsel + from StringIO import StringIO + # Keep OrderedDict for backwards compatibility. 
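+    # On Python 2 the ABCs (Callable, Mapping, MutableMapping) still live in
+    # 'collections'; the Python 3 branch below imports them from
+    # 'collections.abc', their new home (plain 'collections' access was
+    # removed in Python 3.10).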
+ from collections import Callable, Mapping, MutableMapping, OrderedDict + + + builtin_str = str + bytes = str + str = unicode + basestring = basestring + numeric_types = (int, long, float) + integer_types = (int, long) + +elif is_py3: + from urllib.parse import urlparse, urlunparse, urljoin, urlsplit, urlencode, quote, unquote, quote_plus, unquote_plus, urldefrag + from urllib.request import parse_http_list, getproxies, proxy_bypass, proxy_bypass_environment, getproxies_environment + from http import cookiejar as cookielib + from http.cookies import Morsel + from io import StringIO + # Keep OrderedDict for backwards compatibility. + from collections import OrderedDict + from collections.abc import Callable, Mapping, MutableMapping + + builtin_str = str + str = str + bytes = bytes + basestring = (str, bytes) + numeric_types = (int, float) + integer_types = (int,) diff --git a/requests/cookies.py b/requests/cookies.py new file mode 100644 index 0000000..56fccd9 --- /dev/null +++ b/requests/cookies.py @@ -0,0 +1,549 @@ +# -*- coding: utf-8 -*- + +""" +requests.cookies +~~~~~~~~~~~~~~~~ + +Compatibility code to be able to use `cookielib.CookieJar` with requests. + +requests.utils imports from here, so be careful with imports. +""" + +import copy +import time +import calendar + +from ._internal_utils import to_native_string +from .compat import cookielib, urlparse, urlunparse, Morsel, MutableMapping + +try: + import threading +except ImportError: + import dummy_threading as threading + + +class MockRequest(object): + """Wraps a `requests.Request` to mimic a `urllib2.Request`. + + The code in `cookielib.CookieJar` expects this interface in order to correctly + manage cookie policies, i.e., determine whether a cookie can be set, given the + domains of the request and the cookie. + + The original request object is read-only. The client is responsible for collecting + the new headers via `get_new_headers()` and interpreting them appropriately. You + probably want `get_cookie_header`, defined below. 
+ """ + + def __init__(self, request): + self._r = request + self._new_headers = {} + self.type = urlparse(self._r.url).scheme + + def get_type(self): + return self.type + + def get_host(self): + return urlparse(self._r.url).netloc + + def get_origin_req_host(self): + return self.get_host() + + def get_full_url(self): + # Only return the response's URL if the user hadn't set the Host + # header + if not self._r.headers.get('Host'): + return self._r.url + # If they did set it, retrieve it and reconstruct the expected domain + host = to_native_string(self._r.headers['Host'], encoding='utf-8') + parsed = urlparse(self._r.url) + # Reconstruct the URL as we expect it + return urlunparse([ + parsed.scheme, host, parsed.path, parsed.params, parsed.query, + parsed.fragment + ]) + + def is_unverifiable(self): + return True + + def has_header(self, name): + return name in self._r.headers or name in self._new_headers + + def get_header(self, name, default=None): + return self._r.headers.get(name, self._new_headers.get(name, default)) + + def add_header(self, key, val): + """cookielib has no legitimate use for this method; add it back if you find one.""" + raise NotImplementedError("Cookie headers should be added with add_unredirected_header()") + + def add_unredirected_header(self, name, value): + self._new_headers[name] = value + + def get_new_headers(self): + return self._new_headers + + @property + def unverifiable(self): + return self.is_unverifiable() + + @property + def origin_req_host(self): + return self.get_origin_req_host() + + @property + def host(self): + return self.get_host() + + +class MockResponse(object): + """Wraps a `httplib.HTTPMessage` to mimic a `urllib.addinfourl`. + + ...what? Basically, expose the parsed HTTP headers from the server response + the way `cookielib` expects to see them. + """ + + def __init__(self, headers): + """Make a MockResponse for `cookielib` to read. + + :param headers: a httplib.HTTPMessage or analogous carrying the headers + """ + self._headers = headers + + def info(self): + return self._headers + + def getheaders(self, name): + self._headers.getheaders(name) + + +def extract_cookies_to_jar(jar, request, response): + """Extract the cookies from the response into a CookieJar. + + :param jar: cookielib.CookieJar (not necessarily a RequestsCookieJar) + :param request: our own requests.Request object + :param response: urllib3.HTTPResponse object + """ + if not (hasattr(response, '_original_response') and + response._original_response): + return + # the _original_response field is the wrapped httplib.HTTPResponse object, + req = MockRequest(request) + # pull out the HTTPMessage with the headers and put it in the mock: + res = MockResponse(response._original_response.msg) + jar.extract_cookies(res, req) + + +def get_cookie_header(jar, request): + """ + Produce an appropriate Cookie header string to be sent with `request`, or None. + + :rtype: str + """ + r = MockRequest(request) + jar.add_cookie_header(r) + return r.get_new_headers().get('Cookie') + + +def remove_cookie_by_name(cookiejar, name, domain=None, path=None): + """Unsets a cookie by name, by default over all domains and paths. + + Wraps CookieJar.clear(), is O(n). 
+ """ + clearables = [] + for cookie in cookiejar: + if cookie.name != name: + continue + if domain is not None and domain != cookie.domain: + continue + if path is not None and path != cookie.path: + continue + clearables.append((cookie.domain, cookie.path, cookie.name)) + + for domain, path, name in clearables: + cookiejar.clear(domain, path, name) + + +class CookieConflictError(RuntimeError): + """There are two cookies that meet the criteria specified in the cookie jar. + Use .get and .set and include domain and path args in order to be more specific. + """ + + +class RequestsCookieJar(cookielib.CookieJar, MutableMapping): + """Compatibility class; is a cookielib.CookieJar, but exposes a dict + interface. + + This is the CookieJar we create by default for requests and sessions that + don't specify one, since some clients may expect response.cookies and + session.cookies to support dict operations. + + Requests does not use the dict interface internally; it's just for + compatibility with external client code. All requests code should work + out of the box with externally provided instances of ``CookieJar``, e.g. + ``LWPCookieJar`` and ``FileCookieJar``. + + Unlike a regular CookieJar, this class is pickleable. + + .. warning:: dictionary operations that are normally O(1) may be O(n). + """ + + def get(self, name, default=None, domain=None, path=None): + """Dict-like get() that also supports optional domain and path args in + order to resolve naming collisions from using one cookie jar over + multiple domains. + + .. warning:: operation is O(n), not O(1). + """ + try: + return self._find_no_duplicates(name, domain, path) + except KeyError: + return default + + def set(self, name, value, **kwargs): + """Dict-like set() that also supports optional domain and path args in + order to resolve naming collisions from using one cookie jar over + multiple domains. + """ + # support client code that unsets cookies by assignment of a None value: + if value is None: + remove_cookie_by_name(self, name, domain=kwargs.get('domain'), path=kwargs.get('path')) + return + + if isinstance(value, Morsel): + c = morsel_to_cookie(value) + else: + c = create_cookie(name, value, **kwargs) + self.set_cookie(c) + return c + + def iterkeys(self): + """Dict-like iterkeys() that returns an iterator of names of cookies + from the jar. + + .. seealso:: itervalues() and iteritems(). + """ + for cookie in iter(self): + yield cookie.name + + def keys(self): + """Dict-like keys() that returns a list of names of cookies from the + jar. + + .. seealso:: values() and items(). + """ + return list(self.iterkeys()) + + def itervalues(self): + """Dict-like itervalues() that returns an iterator of values of cookies + from the jar. + + .. seealso:: iterkeys() and iteritems(). + """ + for cookie in iter(self): + yield cookie.value + + def values(self): + """Dict-like values() that returns a list of values of cookies from the + jar. + + .. seealso:: keys() and items(). + """ + return list(self.itervalues()) + + def iteritems(self): + """Dict-like iteritems() that returns an iterator of name-value tuples + from the jar. + + .. seealso:: iterkeys() and itervalues(). + """ + for cookie in iter(self): + yield cookie.name, cookie.value + + def items(self): + """Dict-like items() that returns a list of name-value tuples from the + jar. Allows client-code to call ``dict(RequestsCookieJar)`` and get a + vanilla python dict of key value pairs. + + .. seealso:: keys() and values(). 
+ """ + return list(self.iteritems()) + + def list_domains(self): + """Utility method to list all the domains in the jar.""" + domains = [] + for cookie in iter(self): + if cookie.domain not in domains: + domains.append(cookie.domain) + return domains + + def list_paths(self): + """Utility method to list all the paths in the jar.""" + paths = [] + for cookie in iter(self): + if cookie.path not in paths: + paths.append(cookie.path) + return paths + + def multiple_domains(self): + """Returns True if there are multiple domains in the jar. + Returns False otherwise. + + :rtype: bool + """ + domains = [] + for cookie in iter(self): + if cookie.domain is not None and cookie.domain in domains: + return True + domains.append(cookie.domain) + return False # there is only one domain in jar + + def get_dict(self, domain=None, path=None): + """Takes as an argument an optional domain and path and returns a plain + old Python dict of name-value pairs of cookies that meet the + requirements. + + :rtype: dict + """ + dictionary = {} + for cookie in iter(self): + if ( + (domain is None or cookie.domain == domain) and + (path is None or cookie.path == path) + ): + dictionary[cookie.name] = cookie.value + return dictionary + + def __contains__(self, name): + try: + return super(RequestsCookieJar, self).__contains__(name) + except CookieConflictError: + return True + + def __getitem__(self, name): + """Dict-like __getitem__() for compatibility with client code. Throws + exception if there are more than one cookie with name. In that case, + use the more explicit get() method instead. + + .. warning:: operation is O(n), not O(1). + """ + return self._find_no_duplicates(name) + + def __setitem__(self, name, value): + """Dict-like __setitem__ for compatibility with client code. Throws + exception if there is already a cookie of that name in the jar. In that + case, use the more explicit set() method instead. + """ + self.set(name, value) + + def __delitem__(self, name): + """Deletes a cookie given a name. Wraps ``cookielib.CookieJar``'s + ``remove_cookie_by_name()``. + """ + remove_cookie_by_name(self, name) + + def set_cookie(self, cookie, *args, **kwargs): + if hasattr(cookie.value, 'startswith') and cookie.value.startswith('"') and cookie.value.endswith('"'): + cookie.value = cookie.value.replace('\\"', '') + return super(RequestsCookieJar, self).set_cookie(cookie, *args, **kwargs) + + def update(self, other): + """Updates this jar with cookies from another CookieJar or dict-like""" + if isinstance(other, cookielib.CookieJar): + for cookie in other: + self.set_cookie(copy.copy(cookie)) + else: + super(RequestsCookieJar, self).update(other) + + def _find(self, name, domain=None, path=None): + """Requests uses this method internally to get cookie values. + + If there are conflicting cookies, _find arbitrarily chooses one. + See _find_no_duplicates if you want an exception thrown if there are + conflicting cookies. + + :param name: a string containing name of cookie + :param domain: (optional) string containing domain of cookie + :param path: (optional) string containing path of cookie + :return: cookie.value + """ + for cookie in iter(self): + if cookie.name == name: + if domain is None or cookie.domain == domain: + if path is None or cookie.path == path: + return cookie.value + + raise KeyError('name=%r, domain=%r, path=%r' % (name, domain, path)) + + def _find_no_duplicates(self, name, domain=None, path=None): + """Both ``__get_item__`` and ``get`` call this function: it's never + used elsewhere in Requests. 
+
+        :param name: a string containing name of cookie
+        :param domain: (optional) string containing domain of cookie
+        :param path: (optional) string containing path of cookie
+        :raises KeyError: if cookie is not found
+        :raises CookieConflictError: if there are multiple cookies
+            that match name and optionally domain and path
+        :return: cookie.value
+        """
+        toReturn = None
+        for cookie in iter(self):
+            if cookie.name == name:
+                if domain is None or cookie.domain == domain:
+                    if path is None or cookie.path == path:
+                        if toReturn is not None:  # if there are multiple cookies that meet passed in criteria
+                            raise CookieConflictError('There are multiple cookies with name %r' % (name,))
+                        toReturn = cookie.value  # we will eventually return this as long as no cookie conflict
+
+        if toReturn is not None:
+            return toReturn
+        raise KeyError('name=%r, domain=%r, path=%r' % (name, domain, path))
+
+    def __getstate__(self):
+        """Unlike a normal CookieJar, this class is pickleable."""
+        state = self.__dict__.copy()
+        # remove the unpickleable RLock object
+        state.pop('_cookies_lock')
+        return state
+
+    def __setstate__(self, state):
+        """Unlike a normal CookieJar, this class is pickleable."""
+        self.__dict__.update(state)
+        if '_cookies_lock' not in self.__dict__:
+            self._cookies_lock = threading.RLock()
+
+    def copy(self):
+        """Return a copy of this RequestsCookieJar."""
+        new_cj = RequestsCookieJar()
+        new_cj.set_policy(self.get_policy())
+        new_cj.update(self)
+        return new_cj
+
+    def get_policy(self):
+        """Return the CookiePolicy instance used."""
+        return self._policy
+
+
+def _copy_cookie_jar(jar):
+    if jar is None:
+        return None
+
+    if hasattr(jar, 'copy'):
+        # We're dealing with an instance of RequestsCookieJar
+        return jar.copy()
+    # We're dealing with a generic CookieJar instance
+    new_jar = copy.copy(jar)
+    new_jar.clear()
+    for cookie in jar:
+        new_jar.set_cookie(copy.copy(cookie))
+    return new_jar
+
+
+def create_cookie(name, value, **kwargs):
+    """Make a cookie from underspecified parameters.
+
+    By default, the pair of `name` and `value` will be set for the domain ''
+    and sent on every request (this is sometimes called a "supercookie").
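+
+    Illustrative usage (the values are hypothetical)::
+
+        >>> c = create_cookie('session', 'abc123', domain='example.com')
+        >>> (c.name, c.value, c.domain)
+        ('session', 'abc123', 'example.com')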
+ """ + result = { + 'version': 0, + 'name': name, + 'value': value, + 'port': None, + 'domain': '', + 'path': '/', + 'secure': False, + 'expires': None, + 'discard': True, + 'comment': None, + 'comment_url': None, + 'rest': {'HttpOnly': None}, + 'rfc2109': False, + } + + badargs = set(kwargs) - set(result) + if badargs: + err = 'create_cookie() got unexpected keyword arguments: %s' + raise TypeError(err % list(badargs)) + + result.update(kwargs) + result['port_specified'] = bool(result['port']) + result['domain_specified'] = bool(result['domain']) + result['domain_initial_dot'] = result['domain'].startswith('.') + result['path_specified'] = bool(result['path']) + + return cookielib.Cookie(**result) + + +def morsel_to_cookie(morsel): + """Convert a Morsel object into a Cookie containing the one k/v pair.""" + + expires = None + if morsel['max-age']: + try: + expires = int(time.time() + int(morsel['max-age'])) + except ValueError: + raise TypeError('max-age: %s must be integer' % morsel['max-age']) + elif morsel['expires']: + time_template = '%a, %d-%b-%Y %H:%M:%S GMT' + expires = calendar.timegm( + time.strptime(morsel['expires'], time_template) + ) + return create_cookie( + comment=morsel['comment'], + comment_url=bool(morsel['comment']), + discard=False, + domain=morsel['domain'], + expires=expires, + name=morsel.key, + path=morsel['path'], + port=None, + rest={'HttpOnly': morsel['httponly']}, + rfc2109=False, + secure=bool(morsel['secure']), + value=morsel.value, + version=morsel['version'] or 0, + ) + + +def cookiejar_from_dict(cookie_dict, cookiejar=None, overwrite=True): + """Returns a CookieJar from a key/value dictionary. + + :param cookie_dict: Dict of key/values to insert into CookieJar. + :param cookiejar: (optional) A cookiejar to add the cookies to. + :param overwrite: (optional) If False, will not replace cookies + already in the jar with new ones. + :rtype: CookieJar + """ + if cookiejar is None: + cookiejar = RequestsCookieJar() + + if cookie_dict is not None: + names_from_jar = [cookie.name for cookie in cookiejar] + for name in cookie_dict: + if overwrite or (name not in names_from_jar): + cookiejar.set_cookie(create_cookie(name, cookie_dict[name])) + + return cookiejar + + +def merge_cookies(cookiejar, cookies): + """Add cookies to cookiejar and returns a merged CookieJar. + + :param cookiejar: CookieJar object to add the cookies to. + :param cookies: Dictionary or CookieJar object to be added. + :rtype: CookieJar + """ + if not isinstance(cookiejar, cookielib.CookieJar): + raise ValueError('You can only merge into CookieJar') + + if isinstance(cookies, dict): + cookiejar = cookiejar_from_dict( + cookies, cookiejar=cookiejar, overwrite=False) + elif isinstance(cookies, cookielib.CookieJar): + try: + cookiejar.update(cookies) + except AttributeError: + for cookie_in_jar in cookies: + cookiejar.set_cookie(cookie_in_jar) + + return cookiejar diff --git a/requests/exceptions.py b/requests/exceptions.py new file mode 100644 index 0000000..a80cad8 --- /dev/null +++ b/requests/exceptions.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- + +""" +requests.exceptions +~~~~~~~~~~~~~~~~~~~ + +This module contains the set of Requests' exceptions. +""" +from urllib3.exceptions import HTTPError as BaseHTTPError + + +class RequestException(IOError): + """There was an ambiguous exception that occurred while handling your + request. 
+ """ + + def __init__(self, *args, **kwargs): + """Initialize RequestException with `request` and `response` objects.""" + response = kwargs.pop('response', None) + self.response = response + self.request = kwargs.pop('request', None) + if (response is not None and not self.request and + hasattr(response, 'request')): + self.request = self.response.request + super(RequestException, self).__init__(*args, **kwargs) + + +class HTTPError(RequestException): + """An HTTP error occurred.""" + + +class ConnectionError(RequestException): + """A Connection error occurred.""" + + +class ProxyError(ConnectionError): + """A proxy error occurred.""" + + +class SSLError(ConnectionError): + """An SSL error occurred.""" + + +class Timeout(RequestException): + """The request timed out. + + Catching this error will catch both + :exc:`~requests.exceptions.ConnectTimeout` and + :exc:`~requests.exceptions.ReadTimeout` errors. + """ + + +class ConnectTimeout(ConnectionError, Timeout): + """The request timed out while trying to connect to the remote server. + + Requests that produced this error are safe to retry. + """ + + +class ReadTimeout(Timeout): + """The server did not send any data in the allotted amount of time.""" + + +class URLRequired(RequestException): + """A valid URL is required to make a request.""" + + +class TooManyRedirects(RequestException): + """Too many redirects.""" + + +class MissingSchema(RequestException, ValueError): + """The URL schema (e.g. http or https) is missing.""" + + +class InvalidSchema(RequestException, ValueError): + """See defaults.py for valid schemas.""" + + +class InvalidURL(RequestException, ValueError): + """The URL provided was somehow invalid.""" + + +class InvalidHeader(RequestException, ValueError): + """The header value provided was somehow invalid.""" + + +class InvalidProxyURL(InvalidURL): + """The proxy URL provided is invalid.""" + + +class ChunkedEncodingError(RequestException): + """The server declared chunked encoding but sent an invalid chunk.""" + + +class ContentDecodingError(RequestException, BaseHTTPError): + """Failed to decode response content""" + + +class StreamConsumedError(RequestException, TypeError): + """The content for this response was already consumed""" + + +class RetryError(RequestException): + """Custom retries logic failed""" + + +class UnrewindableBodyError(RequestException): + """Requests encountered an error when trying to rewind a body""" + +# Warnings + + +class RequestsWarning(Warning): + """Base warning for Requests.""" + pass + + +class FileModeWarning(RequestsWarning, DeprecationWarning): + """A file was opened in text mode, but Requests determined its binary length.""" + pass + + +class RequestsDependencyWarning(RequestsWarning): + """An imported dependency doesn't match the expected version range.""" + pass diff --git a/requests/help.py b/requests/help.py new file mode 100644 index 0000000..e53d35e --- /dev/null +++ b/requests/help.py @@ -0,0 +1,119 @@ +"""Module containing bug report helper(s).""" +from __future__ import print_function + +import json +import platform +import sys +import ssl + +import idna +import urllib3 +import chardet + +from . import __version__ as requests_version + +try: + from urllib3.contrib import pyopenssl +except ImportError: + pyopenssl = None + OpenSSL = None + cryptography = None +else: + import OpenSSL + import cryptography + + +def _implementation(): + """Return a dict with the Python implementation and version. 
+ + Provide both the name and the version of the Python implementation + currently running. For example, on CPython 2.7.5 it will return + {'name': 'CPython', 'version': '2.7.5'}. + + This function works best on CPython and PyPy: in particular, it probably + doesn't work for Jython or IronPython. Future investigation should be done + to work out the correct shape of the code for those platforms. + """ + implementation = platform.python_implementation() + + if implementation == 'CPython': + implementation_version = platform.python_version() + elif implementation == 'PyPy': + implementation_version = '%s.%s.%s' % (sys.pypy_version_info.major, + sys.pypy_version_info.minor, + sys.pypy_version_info.micro) + if sys.pypy_version_info.releaselevel != 'final': + implementation_version = ''.join([ + implementation_version, sys.pypy_version_info.releaselevel + ]) + elif implementation == 'Jython': + implementation_version = platform.python_version() # Complete Guess + elif implementation == 'IronPython': + implementation_version = platform.python_version() # Complete Guess + else: + implementation_version = 'Unknown' + + return {'name': implementation, 'version': implementation_version} + + +def info(): + """Generate information for a bug report.""" + try: + platform_info = { + 'system': platform.system(), + 'release': platform.release(), + } + except IOError: + platform_info = { + 'system': 'Unknown', + 'release': 'Unknown', + } + + implementation_info = _implementation() + urllib3_info = {'version': urllib3.__version__} + chardet_info = {'version': chardet.__version__} + + pyopenssl_info = { + 'version': None, + 'openssl_version': '', + } + if OpenSSL: + pyopenssl_info = { + 'version': OpenSSL.__version__, + 'openssl_version': '%x' % OpenSSL.SSL.OPENSSL_VERSION_NUMBER, + } + cryptography_info = { + 'version': getattr(cryptography, '__version__', ''), + } + idna_info = { + 'version': getattr(idna, '__version__', ''), + } + + system_ssl = ssl.OPENSSL_VERSION_NUMBER + system_ssl_info = { + 'version': '%x' % system_ssl if system_ssl is not None else '' + } + + return { + 'platform': platform_info, + 'implementation': implementation_info, + 'system_ssl': system_ssl_info, + 'using_pyopenssl': pyopenssl is not None, + 'pyOpenSSL': pyopenssl_info, + 'urllib3': urllib3_info, + 'chardet': chardet_info, + 'cryptography': cryptography_info, + 'idna': idna_info, + 'requests': { + 'version': requests_version, + }, + } + + +def main(): + """Pretty-print the bug information as JSON.""" + print(json.dumps(info(), sort_keys=True, indent=2)) + + +if __name__ == '__main__': + main() diff --git a/requests/hooks.py b/requests/hooks.py new file mode 100644 index 0000000..7a51f21 --- /dev/null +++ b/requests/hooks.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +""" +requests.hooks +~~~~~~~~~~~~~~ + +This module provides the capabilities for the Requests hooks system. + +Available hooks: + +``response``: + The response generated from a Request. 
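+
+Illustrative registration of a response hook (a sketch; the callback name is
+hypothetical)::
+
+    >>> def print_url(r, *args, **kwargs):            # doctest: +SKIP
+    ...     print(r.url)
+    >>> requests.get('https://example.com',           # doctest: +SKIP
+    ...              hooks={'response': print_url})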
+""" +HOOKS = ['response'] + + +def default_hooks(): + return {event: [] for event in HOOKS} + +# TODO: response is the only one + + +def dispatch_hook(key, hooks, hook_data, **kwargs): + """Dispatches a hook dictionary on a given piece of data.""" + hooks = hooks or {} + hooks = hooks.get(key) + if hooks: + if hasattr(hooks, '__call__'): + hooks = [hooks] + for hook in hooks: + _hook_data = hook(hook_data, **kwargs) + if _hook_data is not None: + hook_data = _hook_data + return hook_data diff --git a/requests/models.py b/requests/models.py new file mode 100644 index 0000000..3579883 --- /dev/null +++ b/requests/models.py @@ -0,0 +1,954 @@ +# -*- coding: utf-8 -*- + +""" +requests.models +~~~~~~~~~~~~~~~ + +This module contains the primary objects that power Requests. +""" + +import datetime +import sys + +# Import encoding now, to avoid implicit import later. +# Implicit import within threads may cause LookupError when standard library is in a ZIP, +# such as in Embedded Python. See https://github.com/psf/requests/issues/3578. +import encodings.idna + +from urllib3.fields import RequestField +from urllib3.filepost import encode_multipart_formdata +from urllib3.util import parse_url +from urllib3.exceptions import ( + DecodeError, ReadTimeoutError, ProtocolError, LocationParseError) + +from io import UnsupportedOperation +from .hooks import default_hooks +from .structures import CaseInsensitiveDict + +from .auth import HTTPBasicAuth +from .cookies import cookiejar_from_dict, get_cookie_header, _copy_cookie_jar +from .exceptions import ( + HTTPError, MissingSchema, InvalidURL, ChunkedEncodingError, + ContentDecodingError, ConnectionError, StreamConsumedError) +from ._internal_utils import to_native_string, unicode_is_ascii +from .utils import ( + guess_filename, get_auth_from_url, requote_uri, + stream_decode_response_unicode, to_key_val_list, parse_header_links, + iter_slices, guess_json_utf, super_len, check_header_validity) +from .compat import ( + Callable, Mapping, + cookielib, urlunparse, urlsplit, urlencode, str, bytes, + is_py2, chardet, builtin_str, basestring) +from .compat import json as complexjson +from .status_codes import codes + +#: The set of HTTP status codes that indicate an automatically +#: processable redirect. +REDIRECT_STATI = ( + codes.moved, # 301 + codes.found, # 302 + codes.other, # 303 + codes.temporary_redirect, # 307 + codes.permanent_redirect, # 308 +) + +DEFAULT_REDIRECT_LIMIT = 30 +CONTENT_CHUNK_SIZE = 10 * 1024 +ITER_CHUNK_SIZE = 512 + + +class RequestEncodingMixin(object): + @property + def path_url(self): + """Build the path URL to use.""" + + url = [] + + p = urlsplit(self.url) + + path = p.path + if not path: + path = '/' + + url.append(path) + + query = p.query + if query: + url.append('?') + url.append(query) + + return ''.join(url) + + @staticmethod + def _encode_params(data): + """Encode parameters in a piece of data. + + Will successfully encode parameters when passed as a dict or a list of + 2-tuples. Order is retained if data is a list of 2-tuples but arbitrary + if parameters are supplied as a dict. 
+ """ + + if isinstance(data, (str, bytes)): + return data + elif hasattr(data, 'read'): + return data + elif hasattr(data, '__iter__'): + result = [] + for k, vs in to_key_val_list(data): + if isinstance(vs, basestring) or not hasattr(vs, '__iter__'): + vs = [vs] + for v in vs: + if v is not None: + result.append( + (k.encode('utf-8') if isinstance(k, str) else k, + v.encode('utf-8') if isinstance(v, str) else v)) + return urlencode(result, doseq=True) + else: + return data + + @staticmethod + def _encode_files(files, data): + """Build the body for a multipart/form-data request. + + Will successfully encode files when passed as a dict or a list of + tuples. Order is retained if data is a list of tuples but arbitrary + if parameters are supplied as a dict. + The tuples may be 2-tuples (filename, fileobj), 3-tuples (filename, fileobj, contentype) + or 4-tuples (filename, fileobj, contentype, custom_headers). + """ + if (not files): + raise ValueError("Files must be provided.") + elif isinstance(data, basestring): + raise ValueError("Data must not be a string.") + + new_fields = [] + fields = to_key_val_list(data or {}) + files = to_key_val_list(files or {}) + + for field, val in fields: + if isinstance(val, basestring) or not hasattr(val, '__iter__'): + val = [val] + for v in val: + if v is not None: + # Don't call str() on bytestrings: in Py3 it all goes wrong. + if not isinstance(v, bytes): + v = str(v) + + new_fields.append( + (field.decode('utf-8') if isinstance(field, bytes) else field, + v.encode('utf-8') if isinstance(v, str) else v)) + + for (k, v) in files: + # support for explicit filename + ft = None + fh = None + if isinstance(v, (tuple, list)): + if len(v) == 2: + fn, fp = v + elif len(v) == 3: + fn, fp, ft = v + else: + fn, fp, ft, fh = v + else: + fn = guess_filename(v) or k + fp = v + + if isinstance(fp, (str, bytes, bytearray)): + fdata = fp + elif hasattr(fp, 'read'): + fdata = fp.read() + elif fp is None: + continue + else: + fdata = fp + + rf = RequestField(name=k, data=fdata, filename=fn, headers=fh) + rf.make_multipart(content_type=ft) + new_fields.append(rf) + + body, content_type = encode_multipart_formdata(new_fields) + + return body, content_type + + +class RequestHooksMixin(object): + def register_hook(self, event, hook): + """Properly register a hook.""" + + if event not in self.hooks: + raise ValueError('Unsupported event specified, with event name "%s"' % (event)) + + if isinstance(hook, Callable): + self.hooks[event].append(hook) + elif hasattr(hook, '__iter__'): + self.hooks[event].extend(h for h in hook if isinstance(h, Callable)) + + def deregister_hook(self, event, hook): + """Deregister a previously registered hook. + Returns True if the hook existed, False if not. + """ + + try: + self.hooks[event].remove(hook) + return True + except ValueError: + return False + + +class Request(RequestHooksMixin): + """A user-created :class:`Request ` object. + + Used to prepare a :class:`PreparedRequest `, which is sent to the server. + + :param method: HTTP method to use. + :param url: URL to send. + :param headers: dictionary of headers to send. + :param files: dictionary of {filename: fileobject} files to multipart upload. + :param data: the body to attach to the request. If a dictionary or + list of tuples ``[(key, value)]`` is provided, form-encoding will + take place. + :param json: json for the body to attach to the request (if files or data is not specified). + :param params: URL parameters to append to the URL. 
If a dictionary or
+        list of tuples ``[(key, value)]`` is provided, form-encoding will
+        take place.
+    :param auth: Auth handler or (user, pass) tuple.
+    :param cookies: dictionary or CookieJar of cookies to attach to this request.
+    :param hooks: dictionary of callback hooks, for internal usage.
+
+    Usage::
+
+      >>> import requests
+      >>> req = requests.Request('GET', 'https://httpbin.org/get')
+      >>> req.prepare()
+      <PreparedRequest [GET]>
+    """
+
+    def __init__(self,
+            method=None, url=None, headers=None, files=None, data=None,
+            params=None, auth=None, cookies=None, hooks=None, json=None):
+
+        # Default empty dicts for dict params.
+        data = [] if data is None else data
+        files = [] if files is None else files
+        headers = {} if headers is None else headers
+        params = {} if params is None else params
+        hooks = {} if hooks is None else hooks
+
+        self.hooks = default_hooks()
+        for (k, v) in list(hooks.items()):
+            self.register_hook(event=k, hook=v)
+
+        self.method = method
+        self.url = url
+        self.headers = headers
+        self.files = files
+        self.data = data
+        self.json = json
+        self.params = params
+        self.auth = auth
+        self.cookies = cookies
+
+    def __repr__(self):
+        return '<Request [%s]>' % (self.method)
+
+    def prepare(self):
+        """Constructs a :class:`PreparedRequest <PreparedRequest>` for transmission and returns it."""
+        p = PreparedRequest()
+        p.prepare(
+            method=self.method,
+            url=self.url,
+            headers=self.headers,
+            files=self.files,
+            data=self.data,
+            json=self.json,
+            params=self.params,
+            auth=self.auth,
+            cookies=self.cookies,
+            hooks=self.hooks,
+        )
+        return p
+
+
+class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
+    """The fully mutable :class:`PreparedRequest <PreparedRequest>` object,
+    containing the exact bytes that will be sent to the server.
+
+    Generated from either a :class:`Request <Request>` object or manually.
+
+    Usage::
+
+      >>> import requests
+      >>> req = requests.Request('GET', 'https://httpbin.org/get')
+      >>> r = req.prepare()
+      >>> r
+      <PreparedRequest [GET]>
+
+      >>> s = requests.Session()
+      >>> s.send(r)
+      <Response [200]>
+    """
+
+    def __init__(self):
+        #: HTTP verb to send to the server.
+        self.method = None
+        #: HTTP URL to send the request to.
+        self.url = None
+        #: dictionary of HTTP headers.
+        self.headers = None
+        # The `CookieJar` used to create the Cookie header will be stored here
+        # after prepare_cookies is called
+        self._cookies = None
+        #: request body to send to the server.
+        self.body = None
+        #: dictionary of callback hooks, for internal usage.
+        self.hooks = default_hooks()
+        #: integer denoting starting position of a readable file-like body.
+        self._body_position = None
+
+    def prepare(self,
+            method=None, url=None, headers=None, files=None, data=None,
+            params=None, auth=None, cookies=None, hooks=None, json=None):
+        """Prepares the entire request with the given parameters."""
+
+        self.prepare_method(method)
+        self.prepare_url(url, params)
+        self.prepare_headers(headers)
+        self.prepare_cookies(cookies)
+        self.prepare_body(data, files, json)
+        self.prepare_auth(auth, url)
+
+        # Note that prepare_auth must be last to enable authentication schemes
+        # such as OAuth to work on a fully prepared request.
+
+        # This MUST go after prepare_auth.
Authenticators could add a hook + self.prepare_hooks(hooks) + + def __repr__(self): + return '' % (self.method) + + def copy(self): + p = PreparedRequest() + p.method = self.method + p.url = self.url + p.headers = self.headers.copy() if self.headers is not None else None + p._cookies = _copy_cookie_jar(self._cookies) + p.body = self.body + p.hooks = self.hooks + p._body_position = self._body_position + return p + + def prepare_method(self, method): + """Prepares the given HTTP method.""" + self.method = method + if self.method is not None: + self.method = to_native_string(self.method.upper()) + + @staticmethod + def _get_idna_encoded_host(host): + import idna + + try: + host = idna.encode(host, uts46=True).decode('utf-8') + except idna.IDNAError: + raise UnicodeError + return host + + def prepare_url(self, url, params): + """Prepares the given HTTP URL.""" + #: Accept objects that have string representations. + #: We're unable to blindly call unicode/str functions + #: as this will include the bytestring indicator (b'') + #: on python 3.x. + #: https://github.com/psf/requests/pull/2238 + if isinstance(url, bytes): + url = url.decode('utf8') + else: + url = unicode(url) if is_py2 else str(url) + + # Remove leading whitespaces from url + url = url.lstrip() + + # Don't do any URL preparation for non-HTTP schemes like `mailto`, + # `data` etc to work around exceptions from `url_parse`, which + # handles RFC 3986 only. + if ':' in url and not url.lower().startswith('http'): + self.url = url + return + + # Support for unicode domain names and paths. + try: + scheme, auth, host, port, path, query, fragment = parse_url(url) + except LocationParseError as e: + raise InvalidURL(*e.args) + + if not scheme: + error = ("Invalid URL {0!r}: No schema supplied. Perhaps you meant http://{0}?") + error = error.format(to_native_string(url, 'utf8')) + + raise MissingSchema(error) + + if not host: + raise InvalidURL("Invalid URL %r: No host supplied" % url) + + # In general, we want to try IDNA encoding the hostname if the string contains + # non-ASCII characters. This allows users to automatically get the correct IDNA + # behaviour. For strings containing only ASCII characters, we need to also verify + # it doesn't start with a wildcard (*), before allowing the unencoded hostname. + if not unicode_is_ascii(host): + try: + host = self._get_idna_encoded_host(host) + except UnicodeError: + raise InvalidURL('URL has an invalid label.') + elif host.startswith(u'*'): + raise InvalidURL('URL has an invalid label.') + + # Carefully reconstruct the network location + netloc = auth or '' + if netloc: + netloc += '@' + netloc += host + if port: + netloc += ':' + str(port) + + # Bare domains aren't valid URLs. 
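+        # (Illustrative: parse_url('http://example.com') returns an empty
+        # path, so after this fix-up the prepared URL reads
+        # 'http://example.com/'.)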
+ if not path: + path = '/' + + if is_py2: + if isinstance(scheme, str): + scheme = scheme.encode('utf-8') + if isinstance(netloc, str): + netloc = netloc.encode('utf-8') + if isinstance(path, str): + path = path.encode('utf-8') + if isinstance(query, str): + query = query.encode('utf-8') + if isinstance(fragment, str): + fragment = fragment.encode('utf-8') + + if isinstance(params, (str, bytes)): + params = to_native_string(params) + + enc_params = self._encode_params(params) + if enc_params: + if query: + query = '%s&%s' % (query, enc_params) + else: + query = enc_params + + url = requote_uri(urlunparse([scheme, netloc, path, None, query, fragment])) + self.url = url + + def prepare_headers(self, headers): + """Prepares the given HTTP headers.""" + + self.headers = CaseInsensitiveDict() + if headers: + for header in headers.items(): + # Raise exception on invalid header value. + check_header_validity(header) + name, value = header + self.headers[to_native_string(name)] = value + + def prepare_body(self, data, files, json=None): + """Prepares the given HTTP body data.""" + + # Check if file, fo, generator, iterator. + # If not, run through normal process. + + # Nottin' on you. + body = None + content_type = None + + if not data and json is not None: + # urllib3 requires a bytes-like body. Python 2's json.dumps + # provides this natively, but Python 3 gives a Unicode string. + content_type = 'application/json' + body = complexjson.dumps(json) + if not isinstance(body, bytes): + body = body.encode('utf-8') + + is_stream = all([ + hasattr(data, '__iter__'), + not isinstance(data, (basestring, list, tuple, Mapping)) + ]) + + try: + length = super_len(data) + except (TypeError, AttributeError, UnsupportedOperation): + length = None + + if is_stream: + body = data + + if getattr(body, 'tell', None) is not None: + # Record the current file position before reading. + # This will allow us to rewind a file in the event + # of a redirect. + try: + self._body_position = body.tell() + except (IOError, OSError): + # This differentiates from None, allowing us to catch + # a failed `tell()` later when trying to rewind the body + self._body_position = object() + + if files: + raise NotImplementedError('Streamed bodies and files are mutually exclusive.') + + if length: + self.headers['Content-Length'] = builtin_str(length) + else: + self.headers['Transfer-Encoding'] = 'chunked' + else: + # Multi-part file uploads. + if files: + (body, content_type) = self._encode_files(files, data) + else: + if data: + body = self._encode_params(data) + if isinstance(data, basestring) or hasattr(data, 'read'): + content_type = None + else: + content_type = 'application/x-www-form-urlencoded' + + self.prepare_content_length(body) + + # Add content-type if it wasn't explicitly provided. + if content_type and ('content-type' not in self.headers): + self.headers['Content-Type'] = content_type + + self.body = body + + def prepare_content_length(self, body): + """Prepare Content-Length header based on request method and body""" + if body is not None: + length = super_len(body) + if length: + # If length exists, set it. Otherwise, we fallback + # to Transfer-Encoding: chunked. + self.headers['Content-Length'] = builtin_str(length) + elif self.method not in ('GET', 'HEAD') and self.headers.get('Content-Length') is None: + # Set Content-Length to 0 for methods that can have a body + # but don't provide one. (i.e. 
not GET or HEAD) + self.headers['Content-Length'] = '0' + + def prepare_auth(self, auth, url=''): + """Prepares the given HTTP auth data.""" + + # If no Auth is explicitly provided, extract it from the URL first. + if auth is None: + url_auth = get_auth_from_url(self.url) + auth = url_auth if any(url_auth) else None + + if auth: + if isinstance(auth, tuple) and len(auth) == 2: + # special-case basic HTTP auth + auth = HTTPBasicAuth(*auth) + + # Allow auth to make its changes. + r = auth(self) + + # Update self to reflect the auth changes. + self.__dict__.update(r.__dict__) + + # Recompute Content-Length + self.prepare_content_length(self.body) + + def prepare_cookies(self, cookies): + """Prepares the given HTTP cookie data. + + This function eventually generates a ``Cookie`` header from the + given cookies using cookielib. Due to cookielib's design, the header + will not be regenerated if it already exists, meaning this function + can only be called once for the life of the + :class:`PreparedRequest ` object. Any subsequent calls + to ``prepare_cookies`` will have no actual effect, unless the "Cookie" + header is removed beforehand. + """ + if isinstance(cookies, cookielib.CookieJar): + self._cookies = cookies + else: + self._cookies = cookiejar_from_dict(cookies) + + cookie_header = get_cookie_header(self._cookies, self) + if cookie_header is not None: + self.headers['Cookie'] = cookie_header + + def prepare_hooks(self, hooks): + """Prepares the given hooks.""" + # hooks can be passed as None to the prepare method and to this + # method. To prevent iterating over None, simply use an empty list + # if hooks is False-y + hooks = hooks or [] + for event in hooks: + self.register_hook(event, hooks[event]) + + +class Response(object): + """The :class:`Response ` object, which contains a + server's response to an HTTP request. + """ + + __attrs__ = [ + '_content', 'status_code', 'headers', 'url', 'history', + 'encoding', 'reason', 'cookies', 'elapsed', 'request' + ] + + def __init__(self): + self._content = False + self._content_consumed = False + self._next = None + + #: Integer Code of responded HTTP Status, e.g. 404 or 200. + self.status_code = None + + #: Case-insensitive Dictionary of Response Headers. + #: For example, ``headers['content-encoding']`` will return the + #: value of a ``'Content-Encoding'`` response header. + self.headers = CaseInsensitiveDict() + + #: File-like object representation of response (for advanced usage). + #: Use of ``raw`` requires that ``stream=True`` be set on the request. + #: This requirement does not apply for use internally to Requests. + self.raw = None + + #: Final URL location of Response. + self.url = None + + #: Encoding to decode with when accessing r.text. + self.encoding = None + + #: A list of :class:`Response ` objects from + #: the history of the Request. Any redirect responses will end + #: up here. The list is sorted from the oldest to the most recent request. + self.history = [] + + #: Textual reason of responded HTTP Status, e.g. "Not Found" or "OK". + self.reason = None + + #: A CookieJar of Cookies the server sent back. + self.cookies = cookiejar_from_dict({}) + + #: The amount of time elapsed between sending the request + #: and the arrival of the response (as a timedelta). + #: This property specifically measures the time taken between sending + #: the first byte of the request and finishing parsing the headers. It + #: is therefore unaffected by consuming the response content or the + #: value of the ``stream`` keyword argument. 
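+        #: (Illustrative: ``r.elapsed.total_seconds()`` converts the
+        #: measured interval to a float number of seconds.)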
+        self.elapsed = datetime.timedelta(0)
+
+        #: The :class:`PreparedRequest <PreparedRequest>` object to which this
+        #: is a response.
+        self.request = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    def __getstate__(self):
+        # Consume everything; accessing the content attribute makes
+        # sure the content has been fully read.
+        if not self._content_consumed:
+            self.content
+
+        return {attr: getattr(self, attr, None) for attr in self.__attrs__}
+
+    def __setstate__(self, state):
+        for name, value in state.items():
+            setattr(self, name, value)
+
+        # pickled objects do not have .raw
+        setattr(self, '_content_consumed', True)
+        setattr(self, 'raw', None)
+
+    def __repr__(self):
+        return '<Response [%s]>' % (self.status_code)
+
+    def __bool__(self):
+        """Returns True if :attr:`status_code` is less than 400.
+
+        This attribute checks if the status code of the response is between
+        400 and 600 to see if there was a client error or a server error. If
+        the status code is between 200 and 400, this will return True. This
+        is **not** a check to see if the response code is ``200 OK``.
+        """
+        return self.ok
+
+    def __nonzero__(self):
+        """Returns True if :attr:`status_code` is less than 400.
+
+        This attribute checks if the status code of the response is between
+        400 and 600 to see if there was a client error or a server error. If
+        the status code is between 200 and 400, this will return True. This
+        is **not** a check to see if the response code is ``200 OK``.
+        """
+        return self.ok
+
+    def __iter__(self):
+        """Allows you to use a response as an iterator."""
+        return self.iter_content(128)
+
+    @property
+    def ok(self):
+        """Returns True if :attr:`status_code` is less than 400, False if not.
+
+        This attribute checks if the status code of the response is between
+        400 and 600 to see if there was a client error or a server error. If
+        the status code is between 200 and 400, this will return True. This
+        is **not** a check to see if the response code is ``200 OK``.
+        """
+        try:
+            self.raise_for_status()
+        except HTTPError:
+            return False
+        return True
+
+    @property
+    def is_redirect(self):
+        """True if this Response is a well-formed HTTP redirect that could have
+        been processed automatically (by :meth:`Session.resolve_redirects`).
+        """
+        return ('location' in self.headers and self.status_code in REDIRECT_STATI)
+
+    @property
+    def is_permanent_redirect(self):
+        """True if this Response is one of the permanent versions of redirect."""
+        return ('location' in self.headers and self.status_code in (codes.moved_permanently, codes.permanent_redirect))
+
+    @property
+    def next(self):
+        """Returns a PreparedRequest for the next request in a redirect chain, if there is one."""
+        return self._next
+
+    @property
+    def apparent_encoding(self):
+        """The apparent encoding, provided by the chardet library."""
+        return chardet.detect(self.content)['encoding']
+
+    def iter_content(self, chunk_size=1, decode_unicode=False):
+        """Iterates over the response data. When stream=True is set on the
+        request, this avoids reading the content at once into memory for
+        large responses. The chunk size is the number of bytes it should
+        read into memory. This is not necessarily the length of each item
+        returned as decoding can take place.
+
+        chunk_size must be of type int or None. A value of None will
+        function differently depending on the value of `stream`.
+        stream=True will read data as it arrives in whatever size the
+        chunks are received. If stream=False, data is returned as
+        a single chunk.
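+
+        A typical streaming download loop (a sketch; ``url`` and ``fd`` are
+        hypothetical)::
+
+            >>> r = requests.get(url, stream=True)             # doctest: +SKIP
+            >>> for chunk in r.iter_content(chunk_size=8192):  # doctest: +SKIP
+            ...     fd.write(chunk)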
+ + If decode_unicode is True, content will be decoded using the best + available encoding based on the response. + """ + + def generate(): + # Special case for urllib3. + if hasattr(self.raw, 'stream'): + try: + for chunk in self.raw.stream(chunk_size, decode_content=True): + yield chunk + except ProtocolError as e: + raise ChunkedEncodingError(e) + except DecodeError as e: + raise ContentDecodingError(e) + except ReadTimeoutError as e: + raise ConnectionError(e) + else: + # Standard file-like object. + while True: + chunk = self.raw.read(chunk_size) + if not chunk: + break + yield chunk + + self._content_consumed = True + + if self._content_consumed and isinstance(self._content, bool): + raise StreamConsumedError() + elif chunk_size is not None and not isinstance(chunk_size, int): + raise TypeError("chunk_size must be an int, it is instead a %s." % type(chunk_size)) + # simulate reading small chunks of the content + reused_chunks = iter_slices(self._content, chunk_size) + + stream_chunks = generate() + + chunks = reused_chunks if self._content_consumed else stream_chunks + + if decode_unicode: + chunks = stream_decode_response_unicode(chunks, self) + + return chunks + + def iter_lines(self, chunk_size=ITER_CHUNK_SIZE, decode_unicode=False, delimiter=None): + """Iterates over the response data, one line at a time. When + stream=True is set on the request, this avoids reading the + content at once into memory for large responses. + + .. note:: This method is not reentrant safe. + """ + + pending = None + + for chunk in self.iter_content(chunk_size=chunk_size, decode_unicode=decode_unicode): + + if pending is not None: + chunk = pending + chunk + + if delimiter: + lines = chunk.split(delimiter) + else: + lines = chunk.splitlines() + + if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]: + pending = lines.pop() + else: + pending = None + + for line in lines: + yield line + + if pending is not None: + yield pending + + @property + def content(self): + """Content of the response, in bytes.""" + + if self._content is False: + # Read the contents. + if self._content_consumed: + raise RuntimeError( + 'The content for this response was already consumed') + + if self.status_code == 0 or self.raw is None: + self._content = None + else: + self._content = b''.join(self.iter_content(CONTENT_CHUNK_SIZE)) or b'' + + self._content_consumed = True + # don't need to release the connection; that's been handled by urllib3 + # since we exhausted the data. + return self._content + + @property + def text(self): + """Content of the response, in unicode. + + If Response.encoding is None, encoding will be guessed using + ``chardet``. + + The encoding of the response content is determined based solely on HTTP + headers, following RFC 2616 to the letter. If you can take advantage of + non-HTTP knowledge to make a better guess at the encoding, you should + set ``r.encoding`` appropriately before accessing this property. + """ + + # Try charset from content-type + content = None + encoding = self.encoding + + if not self.content: + return str('') + + # Fallback to auto-detected encoding. + if self.encoding is None: + encoding = self.apparent_encoding + + # Decode unicode from given encoding. + try: + content = str(self.content, encoding, errors='replace') + except (LookupError, TypeError): + # A LookupError is raised if the encoding was not found which could + # indicate a misspelling or similar mistake. + # + # A TypeError can be raised if encoding is None + # + # So we try blindly encoding. 
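+            # (Illustrative: str(b'caf\xc3\xa9', errors='replace') decodes as
+            # UTF-8, the default, replacing undecodable bytes with U+FFFD,
+            # so this fallback cannot raise UnicodeDecodeError.)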
+ content = str(self.content, errors='replace') + + return content + + def json(self, **kwargs): + r"""Returns the json-encoded content of a response, if any. + + :param \*\*kwargs: Optional arguments that ``json.loads`` takes. + :raises ValueError: If the response body does not contain valid json. + """ + + if not self.encoding and self.content and len(self.content) > 3: + # No encoding set. JSON RFC 4627 section 3 states we should expect + # UTF-8, -16 or -32. Detect which one to use; If the detection or + # decoding fails, fall back to `self.text` (using chardet to make + # a best guess). + encoding = guess_json_utf(self.content) + if encoding is not None: + try: + return complexjson.loads( + self.content.decode(encoding), **kwargs + ) + except UnicodeDecodeError: + # Wrong UTF codec detected; usually because it's not UTF-8 + # but some other 8-bit codec. This is an RFC violation, + # and the server didn't bother to tell us what codec *was* + # used. + pass + return complexjson.loads(self.text, **kwargs) + + @property + def links(self): + """Returns the parsed header links of the response, if any.""" + + header = self.headers.get('link') + + # l = MultiDict() + l = {} + + if header: + links = parse_header_links(header) + + for link in links: + key = link.get('rel') or link.get('url') + l[key] = link + + return l + + def raise_for_status(self): + """Raises stored :class:`HTTPError`, if one occurred.""" + + http_error_msg = '' + if isinstance(self.reason, bytes): + # We attempt to decode utf-8 first because some servers + # choose to localize their reason strings. If the string + # isn't utf-8, we fall back to iso-8859-1 for all other + # encodings. (See PR #3538) + try: + reason = self.reason.decode('utf-8') + except UnicodeDecodeError: + reason = self.reason.decode('iso-8859-1') + else: + reason = self.reason + + if 400 <= self.status_code < 500: + http_error_msg = u'%s Client Error: %s for url: %s' % (self.status_code, reason, self.url) + + elif 500 <= self.status_code < 600: + http_error_msg = u'%s Server Error: %s for url: %s' % (self.status_code, reason, self.url) + + if http_error_msg: + raise HTTPError(http_error_msg, response=self) + + def close(self): + """Releases the connection back to the pool. Once this method has been + called the underlying ``raw`` object must not be accessed again. + + *Note: Should not normally need to be called explicitly.* + """ + if not self._content_consumed: + self.raw.close() + + release_conn = getattr(self.raw, 'release_conn', None) + if release_conn is not None: + release_conn() diff --git a/requests/packages.py b/requests/packages.py new file mode 100644 index 0000000..7232fe0 --- /dev/null +++ b/requests/packages.py @@ -0,0 +1,14 @@ +import sys + +# This code exists for backwards compatibility reasons. +# I don't like it either. Just look the other way. :) + +for package in ('urllib3', 'idna', 'chardet'): + locals()[package] = __import__(package) + # This traversal is apparently necessary such that the identities are + # preserved (requests.packages.urllib3.* is urllib3.*) + for mod in list(sys.modules): + if mod == package or mod.startswith(package + '.'): + sys.modules['requests.packages.' + mod] = sys.modules[mod] + +# Kinda cool, though, right? 
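+#
+# (Illustrative check of the aliasing above, runnable in a REPL:
+#
+#     >>> import requests.packages, urllib3
+#     >>> requests.packages.urllib3 is urllib3
+#     True
+# )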
diff --git a/requests/sessions.py b/requests/sessions.py new file mode 100644 index 0000000..2845880 --- /dev/null +++ b/requests/sessions.py @@ -0,0 +1,767 @@ +# -*- coding: utf-8 -*- + +""" +requests.session +~~~~~~~~~~~~~~~~ + +This module provides a Session object to manage and persist settings across +requests (cookies, auth, proxies). +""" +import os +import sys +import time +from datetime import timedelta +from collections import OrderedDict + +from .auth import _basic_auth_str +from .compat import cookielib, is_py3, urljoin, urlparse, Mapping +from .cookies import ( + cookiejar_from_dict, extract_cookies_to_jar, RequestsCookieJar, merge_cookies) +from .models import Request, PreparedRequest, DEFAULT_REDIRECT_LIMIT +from .hooks import default_hooks, dispatch_hook +from ._internal_utils import to_native_string +from .utils import to_key_val_list, default_headers, DEFAULT_PORTS +from .exceptions import ( + TooManyRedirects, InvalidSchema, ChunkedEncodingError, ContentDecodingError) + +from .structures import CaseInsensitiveDict +from .adapters import HTTPAdapter + +from .utils import ( + requote_uri, get_environ_proxies, get_netrc_auth, should_bypass_proxies, + get_auth_from_url, rewind_body +) + +from .status_codes import codes + +# formerly defined here, reexposed here for backward compatibility +from .models import REDIRECT_STATI + +# Preferred clock, based on which one is more accurate on a given system. +if sys.platform == 'win32': + try: # Python 3.4+ + preferred_clock = time.perf_counter + except AttributeError: # Earlier than Python 3. + preferred_clock = time.clock +else: + preferred_clock = time.time + + +def merge_setting(request_setting, session_setting, dict_class=OrderedDict): + """Determines appropriate setting for a given request, taking into account + the explicit setting on that request, and the setting in the session. If a + setting is a dictionary, they will be merged together using `dict_class` + """ + + if session_setting is None: + return request_setting + + if request_setting is None: + return session_setting + + # Bypass if not a dictionary (e.g. verify) + if not ( + isinstance(session_setting, Mapping) and + isinstance(request_setting, Mapping) + ): + return request_setting + + merged_setting = dict_class(to_key_val_list(session_setting)) + merged_setting.update(to_key_val_list(request_setting)) + + # Remove keys that are set to None. Extract keys first to avoid altering + # the dictionary during iteration. + none_keys = [k for (k, v) in merged_setting.items() if v is None] + for key in none_keys: + del merged_setting[key] + + return merged_setting + + +def merge_hooks(request_hooks, session_hooks, dict_class=OrderedDict): + """Properly merges both requests and session hooks. + + This is necessary because when request_hooks == {'response': []}, the + merge breaks Session hooks entirely. + """ + if session_hooks is None or session_hooks.get('response') == []: + return request_hooks + + if request_hooks is None or request_hooks.get('response') == []: + return session_hooks + + return merge_setting(request_hooks, session_hooks, dict_class) + + +class SessionRedirectMixin(object): + + def get_redirect_target(self, resp): + """Receives a Response. Returns a redirect URI or ``None``""" + # Due to the nature of how requests processes redirects this method will + # be called at least once upon the original response and at least twice + # on each subsequent redirect response (if any). 
+ # If a custom mixin is used to handle this logic, it may be advantageous + # to cache the redirect location onto the response object as a private + # attribute. + if resp.is_redirect: + location = resp.headers['location'] + # Currently the underlying http module on py3 decode headers + # in latin1, but empirical evidence suggests that latin1 is very + # rarely used with non-ASCII characters in HTTP headers. + # It is more likely to get UTF8 header rather than latin1. + # This causes incorrect handling of UTF8 encoded location headers. + # To solve this, we re-encode the location in latin1. + if is_py3: + location = location.encode('latin1') + return to_native_string(location, 'utf8') + return None + + def should_strip_auth(self, old_url, new_url): + """Decide whether Authorization header should be removed when redirecting""" + old_parsed = urlparse(old_url) + new_parsed = urlparse(new_url) + if old_parsed.hostname != new_parsed.hostname: + return True + # Special case: allow http -> https redirect when using the standard + # ports. This isn't specified by RFC 7235, but is kept to avoid + # breaking backwards compatibility with older versions of requests + # that allowed any redirects on the same host. + if (old_parsed.scheme == 'http' and old_parsed.port in (80, None) + and new_parsed.scheme == 'https' and new_parsed.port in (443, None)): + return False + + # Handle default port usage corresponding to scheme. + changed_port = old_parsed.port != new_parsed.port + changed_scheme = old_parsed.scheme != new_parsed.scheme + default_port = (DEFAULT_PORTS.get(old_parsed.scheme, None), None) + if (not changed_scheme and old_parsed.port in default_port + and new_parsed.port in default_port): + return False + + # Standard case: root URI must match + return changed_port or changed_scheme + + def resolve_redirects(self, resp, req, stream=False, timeout=None, + verify=True, cert=None, proxies=None, yield_requests=False, **adapter_kwargs): + """Receives a Response. Returns a generator of Responses or Requests.""" + + hist = [] # keep track of history + + url = self.get_redirect_target(resp) + previous_fragment = urlparse(req.url).fragment + while url: + prepared_request = req.copy() + + # Update history and keep track of redirects. + # resp.history must ignore the original request in this loop + hist.append(resp) + resp.history = hist[1:] + + try: + resp.content # Consume socket so it can be released + except (ChunkedEncodingError, ContentDecodingError, RuntimeError): + resp.raw.read(decode_content=False) + + if len(resp.history) >= self.max_redirects: + raise TooManyRedirects('Exceeded {} redirects.'.format(self.max_redirects), response=resp) + + # Release the connection back into the pool. + resp.close() + + # Handle redirection without scheme (see: RFC 1808 Section 4) + if url.startswith('//'): + parsed_rurl = urlparse(resp.url) + url = ':'.join([to_native_string(parsed_rurl.scheme), url]) + + # Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2) + parsed = urlparse(url) + if parsed.fragment == '' and previous_fragment: + parsed = parsed._replace(fragment=previous_fragment) + elif parsed.fragment: + previous_fragment = parsed.fragment + url = parsed.geturl() + + # Facilitate relative 'location' headers, as allowed by RFC 7231. + # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') + # Compliant with RFC3986, we percent encode the url. 
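+            # (Illustrative: a relative 'Location: /c d' against a resp.url of
+            # 'http://example.com/a/b' becomes 'http://example.com/c%20d'
+            # after requote_uri and urljoin.)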
+ if not parsed.netloc: + url = urljoin(resp.url, requote_uri(url)) + else: + url = requote_uri(url) + + prepared_request.url = to_native_string(url) + + self.rebuild_method(prepared_request, resp) + + # https://github.com/psf/requests/issues/1084 + if resp.status_code not in (codes.temporary_redirect, codes.permanent_redirect): + # https://github.com/psf/requests/issues/3490 + purged_headers = ('Content-Length', 'Content-Type', 'Transfer-Encoding') + for header in purged_headers: + prepared_request.headers.pop(header, None) + prepared_request.body = None + + headers = prepared_request.headers + headers.pop('Cookie', None) + + # Extract any cookies sent on the response to the cookiejar + # in the new request. Because we've mutated our copied prepared + # request, use the old one that we haven't yet touched. + extract_cookies_to_jar(prepared_request._cookies, req, resp.raw) + merge_cookies(prepared_request._cookies, self.cookies) + prepared_request.prepare_cookies(prepared_request._cookies) + + # Rebuild auth and proxy information. + proxies = self.rebuild_proxies(prepared_request, proxies) + self.rebuild_auth(prepared_request, resp) + + # A failed tell() sets `_body_position` to `object()`. This non-None + # value ensures `rewindable` will be True, allowing us to raise an + # UnrewindableBodyError, instead of hanging the connection. + rewindable = ( + prepared_request._body_position is not None and + ('Content-Length' in headers or 'Transfer-Encoding' in headers) + ) + + # Attempt to rewind consumed file-like object. + if rewindable: + rewind_body(prepared_request) + + # Override the original request. + req = prepared_request + + if yield_requests: + yield req + else: + + resp = self.send( + req, + stream=stream, + timeout=timeout, + verify=verify, + cert=cert, + proxies=proxies, + allow_redirects=False, + **adapter_kwargs + ) + + extract_cookies_to_jar(self.cookies, prepared_request, resp.raw) + + # extract redirect url, if any, for the next loop + url = self.get_redirect_target(resp) + yield resp + + def rebuild_auth(self, prepared_request, response): + """When being redirected we may want to strip authentication from the + request to avoid leaking credentials. This method intelligently removes + and reapplies authentication where possible to avoid credential loss. + """ + headers = prepared_request.headers + url = prepared_request.url + + if 'Authorization' in headers and self.should_strip_auth(response.request.url, url): + # If we get redirected to a new host, we should strip out any + # authentication headers. + del headers['Authorization'] + + # .netrc might have more auth for us on our new host. + new_auth = get_netrc_auth(url) if self.trust_env else None + if new_auth is not None: + prepared_request.prepare_auth(new_auth) + + + def rebuild_proxies(self, prepared_request, proxies): + """This method re-evaluates the proxy configuration by considering the + environment variables. If we are redirected to a URL covered by + NO_PROXY, we strip the proxy configuration. Otherwise, we set missing + proxy keys for this URL (in case they were stripped by a previous + redirect). + + This method also replaces the Proxy-Authorization header where + necessary. 
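+
+        For example (an illustration, assuming ``trust_env`` is set and the
+        environment defines ``HTTP_PROXY=http://proxy:3128``): a redirect to
+        an ``http://`` URL with ``proxies={}`` is resolved to
+        ``{'http': 'http://proxy:3128'}``.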
+
+        :rtype: dict
+        """
+        proxies = proxies if proxies is not None else {}
+        headers = prepared_request.headers
+        url = prepared_request.url
+        scheme = urlparse(url).scheme
+        new_proxies = proxies.copy()
+        no_proxy = proxies.get('no_proxy')
+
+        bypass_proxy = should_bypass_proxies(url, no_proxy=no_proxy)
+        if self.trust_env and not bypass_proxy:
+            environ_proxies = get_environ_proxies(url, no_proxy=no_proxy)
+
+            proxy = environ_proxies.get(scheme, environ_proxies.get('all'))
+
+            if proxy:
+                new_proxies.setdefault(scheme, proxy)
+
+        if 'Proxy-Authorization' in headers:
+            del headers['Proxy-Authorization']
+
+        try:
+            username, password = get_auth_from_url(new_proxies[scheme])
+        except KeyError:
+            username, password = None, None
+
+        if username and password:
+            headers['Proxy-Authorization'] = _basic_auth_str(username, password)
+
+        return new_proxies
+
+    def rebuild_method(self, prepared_request, response):
+        """When being redirected we may want to change the method of the request
+        based on certain specs or browser behavior.
+        """
+        method = prepared_request.method
+
+        # https://tools.ietf.org/html/rfc7231#section-6.4.4
+        if response.status_code == codes.see_other and method != 'HEAD':
+            method = 'GET'
+
+        # Do what the browsers do, despite standards...
+        # First, turn 302s into GETs.
+        if response.status_code == codes.found and method != 'HEAD':
+            method = 'GET'
+
+        # Second, if a POST is responded to with a 301, turn it into a GET.
+        # This bizarre behaviour is explained in Issue 1704.
+        if response.status_code == codes.moved and method == 'POST':
+            method = 'GET'
+
+        prepared_request.method = method
+
+
+class Session(SessionRedirectMixin):
+    """A Requests session.
+
+    Provides cookie persistence, connection-pooling, and configuration.
+
+    Basic Usage::
+
+      >>> import requests
+      >>> s = requests.Session()
+      >>> s.get('https://httpbin.org/get')
+      <Response [200]>
+
+    Or as a context manager::
+
+      >>> with requests.Session() as s:
+      ...     s.get('https://httpbin.org/get')
+      <Response [200]>
+    """
+
+    __attrs__ = [
+        'headers', 'cookies', 'auth', 'proxies', 'hooks', 'params', 'verify',
+        'cert', 'adapters', 'stream', 'trust_env',
+        'max_redirects',
+    ]
+
+    def __init__(self):
+
+        #: A case-insensitive dictionary of headers to be sent on each
+        #: :class:`Request <Request>` sent from this
+        #: :class:`Session <Session>`.
+        self.headers = default_headers()
+
+        #: Default Authentication tuple or object to attach to
+        #: :class:`Request <Request>`.
+        self.auth = None
+
+        #: Dictionary mapping protocol or protocol and host to the URL of the proxy
+        #: (e.g. {'http': 'foo.bar:3128', 'http://host.name': 'foo.bar:4012'}) to
+        #: be used on each :class:`Request <Request>`.
+        self.proxies = {}
+
+        #: Event-handling hooks.
+        self.hooks = default_hooks()
+
+        #: Dictionary of querystring data to attach to each
+        #: :class:`Request <Request>`. The dictionary values may be lists for
+        #: representing multivalued query parameters.
+        self.params = {}
+
+        #: Stream response content default.
+        self.stream = False
+
+        #: SSL Verification default.
+        self.verify = True
+
+        #: SSL client certificate default, if String, path to ssl client
+        #: cert file (.pem). If Tuple, ('cert', 'key') pair.
+        self.cert = None
+
+        #: Maximum number of redirects allowed. If the request exceeds this
+        #: limit, a :class:`TooManyRedirects` exception is raised.
+        #: This defaults to requests.models.DEFAULT_REDIRECT_LIMIT, which is
+        #: 30.
+        self.max_redirects = DEFAULT_REDIRECT_LIMIT
+
+        #: Trust environment settings for proxy configuration, default
+        #: authentication and similar.
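+        #: For illustration: with ``trust_env`` enabled, a matching ``.netrc``
+        #: entry supplies default auth in :meth:`prepare_request`, and
+        #: ``REQUESTS_CA_BUNDLE`` can override ``verify`` in
+        #: :meth:`merge_environment_settings`.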
+        self.trust_env = True
+
+        #: A CookieJar containing all currently outstanding cookies set on this
+        #: session. By default it is a
+        #: :class:`RequestsCookieJar <requests.cookies.RequestsCookieJar>`, but
+        #: may be any other ``cookielib.CookieJar`` compatible object.
+        self.cookies = cookiejar_from_dict({})
+
+        # Default connection adapters.
+        self.adapters = OrderedDict()
+        self.mount('https://', HTTPAdapter())
+        self.mount('http://', HTTPAdapter())
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    def prepare_request(self, request):
+        """Constructs a :class:`PreparedRequest <PreparedRequest>` for
+        transmission and returns it. The :class:`PreparedRequest` has settings
+        merged from the :class:`Request <Request>` instance and those of the
+        :class:`Session`.
+
+        :param request: :class:`Request` instance to prepare with this
+            session's settings.
+        :rtype: requests.PreparedRequest
+        """
+        cookies = request.cookies or {}
+
+        # Bootstrap CookieJar.
+        if not isinstance(cookies, cookielib.CookieJar):
+            cookies = cookiejar_from_dict(cookies)
+
+        # Merge with session cookies
+        merged_cookies = merge_cookies(
+            merge_cookies(RequestsCookieJar(), self.cookies), cookies)
+
+        # Set environment's basic authentication if not explicitly set.
+        auth = request.auth
+        if self.trust_env and not auth and not self.auth:
+            auth = get_netrc_auth(request.url)
+
+        p = PreparedRequest()
+        p.prepare(
+            method=request.method.upper(),
+            url=request.url,
+            files=request.files,
+            data=request.data,
+            json=request.json,
+            headers=merge_setting(request.headers, self.headers, dict_class=CaseInsensitiveDict),
+            params=merge_setting(request.params, self.params),
+            auth=merge_setting(auth, self.auth),
+            cookies=merged_cookies,
+            hooks=merge_hooks(request.hooks, self.hooks),
+        )
+        return p
+
+    def request(self, method, url,
+            params=None, data=None, headers=None, cookies=None, files=None,
+            auth=None, timeout=None, allow_redirects=True, proxies=None,
+            hooks=None, stream=None, verify=None, cert=None, json=None):
+        """Constructs a :class:`Request <Request>`, prepares it and sends it.
+        Returns :class:`Response <Response>` object.
+
+        :param method: method for the new :class:`Request` object.
+        :param url: URL for the new :class:`Request` object.
+        :param params: (optional) Dictionary or bytes to be sent in the query
+            string for the :class:`Request`.
+        :param data: (optional) Dictionary, list of tuples, bytes, or file-like
+            object to send in the body of the :class:`Request`.
+        :param json: (optional) json to send in the body of the
+            :class:`Request`.
+        :param headers: (optional) Dictionary of HTTP Headers to send with the
+            :class:`Request`.
+        :param cookies: (optional) Dict or CookieJar object to send with the
+            :class:`Request`.
+        :param files: (optional) Dictionary of ``'filename': file-like-objects``
+            for multipart encoding upload.
+        :param auth: (optional) Auth tuple or callable to enable
+            Basic/Digest/Custom HTTP Auth.
+        :param timeout: (optional) How long to wait for the server to send
+            data before giving up, as a float, or a :ref:`(connect timeout,
+            read timeout) <timeouts>` tuple.
+        :type timeout: float or tuple
+        :param allow_redirects: (optional) Set to True by default.
+        :type allow_redirects: bool
+        :param proxies: (optional) Dictionary mapping protocol or protocol and
+            hostname to the URL of the proxy.
+        :param stream: (optional) whether to immediately download the response
+            content. Defaults to ``False``.
+ :param verify: (optional) Either a boolean, in which case it controls whether we verify + the server's TLS certificate, or a string, in which case it must be a path + to a CA bundle to use. Defaults to ``True``. + :param cert: (optional) if String, path to ssl client cert file (.pem). + If Tuple, ('cert', 'key') pair. + :rtype: requests.Response + """ + # Create the Request. + req = Request( + method=method.upper(), + url=url, + headers=headers, + files=files, + data=data or {}, + json=json, + params=params or {}, + auth=auth, + cookies=cookies, + hooks=hooks, + ) + prep = self.prepare_request(req) + + proxies = proxies or {} + + settings = self.merge_environment_settings( + prep.url, proxies, stream, verify, cert + ) + + # Send the request. + send_kwargs = { + 'timeout': timeout, + 'allow_redirects': allow_redirects, + } + send_kwargs.update(settings) + resp = self.send(prep, **send_kwargs) + + return resp + + def get(self, url, **kwargs): + r"""Sends a GET request. Returns :class:`Response` object. + + :param url: URL for the new :class:`Request` object. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :rtype: requests.Response + """ + + kwargs.setdefault('allow_redirects', True) + return self.request('GET', url, **kwargs) + + def options(self, url, **kwargs): + r"""Sends a OPTIONS request. Returns :class:`Response` object. + + :param url: URL for the new :class:`Request` object. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :rtype: requests.Response + """ + + kwargs.setdefault('allow_redirects', True) + return self.request('OPTIONS', url, **kwargs) + + def head(self, url, **kwargs): + r"""Sends a HEAD request. Returns :class:`Response` object. + + :param url: URL for the new :class:`Request` object. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :rtype: requests.Response + """ + + kwargs.setdefault('allow_redirects', False) + return self.request('HEAD', url, **kwargs) + + def post(self, url, data=None, json=None, **kwargs): + r"""Sends a POST request. Returns :class:`Response` object. + + :param url: URL for the new :class:`Request` object. + :param data: (optional) Dictionary, list of tuples, bytes, or file-like + object to send in the body of the :class:`Request`. + :param json: (optional) json to send in the body of the :class:`Request`. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :rtype: requests.Response + """ + + return self.request('POST', url, data=data, json=json, **kwargs) + + def put(self, url, data=None, **kwargs): + r"""Sends a PUT request. Returns :class:`Response` object. + + :param url: URL for the new :class:`Request` object. + :param data: (optional) Dictionary, list of tuples, bytes, or file-like + object to send in the body of the :class:`Request`. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :rtype: requests.Response + """ + + return self.request('PUT', url, data=data, **kwargs) + + def patch(self, url, data=None, **kwargs): + r"""Sends a PATCH request. Returns :class:`Response` object. + + :param url: URL for the new :class:`Request` object. + :param data: (optional) Dictionary, list of tuples, bytes, or file-like + object to send in the body of the :class:`Request`. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :rtype: requests.Response + """ + + return self.request('PATCH', url, data=data, **kwargs) + + def delete(self, url, **kwargs): + r"""Sends a DELETE request. Returns :class:`Response` object. 
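+
+        Illustrative usage (assuming httpbin.org semantics)::
+
+            >>> s = Session()
+            >>> s.delete('https://httpbin.org/delete')
+            <Response [200]>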
+ + :param url: URL for the new :class:`Request` object. + :param \*\*kwargs: Optional arguments that ``request`` takes. + :rtype: requests.Response + """ + + return self.request('DELETE', url, **kwargs) + + def send(self, request, **kwargs): + """Send a given PreparedRequest. + + :rtype: requests.Response + """ + # Set defaults that the hooks can utilize to ensure they always have + # the correct parameters to reproduce the previous request. + kwargs.setdefault('stream', self.stream) + kwargs.setdefault('verify', self.verify) + kwargs.setdefault('cert', self.cert) + kwargs.setdefault('proxies', self.proxies) + + # It's possible that users might accidentally send a Request object. + # Guard against that specific failure case. + if isinstance(request, Request): + raise ValueError('You can only send PreparedRequests.') + + # Set up variables needed for resolve_redirects and dispatching of hooks + allow_redirects = kwargs.pop('allow_redirects', True) + stream = kwargs.get('stream') + hooks = request.hooks + + # Get the appropriate adapter to use + adapter = self.get_adapter(url=request.url) + + # Start time (approximately) of the request + start = preferred_clock() + + # Send the request + r = adapter.send(request, **kwargs) + + # Total elapsed time of the request (approximately) + elapsed = preferred_clock() - start + r.elapsed = timedelta(seconds=elapsed) + + # Response manipulation hooks + r = dispatch_hook('response', hooks, r, **kwargs) + + # Persist cookies + if r.history: + + # If the hooks create history then we want those cookies too + for resp in r.history: + extract_cookies_to_jar(self.cookies, resp.request, resp.raw) + + extract_cookies_to_jar(self.cookies, request, r.raw) + + # Redirect resolving generator. + gen = self.resolve_redirects(r, request, **kwargs) + + # Resolve redirects if allowed. + history = [resp for resp in gen] if allow_redirects else [] + + # Shuffle things around if there's history. + if history: + # Insert the first (original) request at the start + history.insert(0, r) + # Get the last request made + r = history.pop() + r.history = history + + # If redirects aren't being followed, store the response on the Request for Response.next(). + if not allow_redirects: + try: + r._next = next(self.resolve_redirects(r, request, yield_requests=True, **kwargs)) + except StopIteration: + pass + + if not stream: + r.content + + return r + + def merge_environment_settings(self, url, proxies, stream, verify, cert): + """ + Check the environment and merge it with some settings. + + :rtype: dict + """ + # Gather clues from the surrounding environment. + if self.trust_env: + # Set environment's proxies. + no_proxy = proxies.get('no_proxy') if proxies is not None else None + env_proxies = get_environ_proxies(url, no_proxy=no_proxy) + for (k, v) in env_proxies.items(): + proxies.setdefault(k, v) + + # Look for requests environment configuration and be compatible + # with cURL. + if verify is True or verify is None: + verify = (os.environ.get('REQUESTS_CA_BUNDLE') or + os.environ.get('CURL_CA_BUNDLE')) + + # Merge all the kwargs. + proxies = merge_setting(proxies, self.proxies) + stream = merge_setting(stream, self.stream) + verify = merge_setting(verify, self.verify) + cert = merge_setting(cert, self.cert) + + return {'verify': verify, 'proxies': proxies, 'stream': stream, + 'cert': cert} + + def get_adapter(self, url): + """ + Returns the appropriate connection adapter for the given URL. 
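+
+        For example (an illustration, using the default mounts)::
+
+            >>> s = Session()
+            >>> s.get_adapter('https://example.com/path')  # doctest: +ELLIPSIS
+            <requests.adapters.HTTPAdapter object at ...>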
+ + :rtype: requests.adapters.BaseAdapter + """ + for (prefix, adapter) in self.adapters.items(): + + if url.lower().startswith(prefix.lower()): + return adapter + + # Nothing matches :-/ + raise InvalidSchema("No connection adapters were found for {!r}".format(url)) + + def close(self): + """Closes all adapters and as such the session""" + for v in self.adapters.values(): + v.close() + + def mount(self, prefix, adapter): + """Registers a connection adapter to a prefix. + + Adapters are sorted in descending order by prefix length. + """ + self.adapters[prefix] = adapter + keys_to_move = [k for k in self.adapters if len(k) < len(prefix)] + + for key in keys_to_move: + self.adapters[key] = self.adapters.pop(key) + + def __getstate__(self): + state = {attr: getattr(self, attr, None) for attr in self.__attrs__} + return state + + def __setstate__(self, state): + for attr, value in state.items(): + setattr(self, attr, value) + + +def session(): + """ + Returns a :class:`Session` for context-management. + + .. deprecated:: 1.0.0 + + This method has been deprecated since version 1.0.0 and is only kept for + backwards compatibility. New code should use :class:`~requests.sessions.Session` + to create a session. This may be removed at a future date. + + :rtype: Session + """ + return Session() diff --git a/requests/status_codes.py b/requests/status_codes.py new file mode 100644 index 0000000..d80a7cd --- /dev/null +++ b/requests/status_codes.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- + +r""" +The ``codes`` object defines a mapping from common names for HTTP statuses +to their numerical codes, accessible either as attributes or as dictionary +items. + +Example:: + + >>> import requests + >>> requests.codes['temporary_redirect'] + 307 + >>> requests.codes.teapot + 418 + >>> requests.codes['\o/'] + 200 + +Some codes have multiple names, and both upper- and lower-case versions of +the names are allowed. For example, ``codes.ok``, ``codes.OK``, and +``codes.okay`` all correspond to the HTTP status code 200. +""" + +from .structures import LookupDict + +_codes = { + + # Informational. + 100: ('continue',), + 101: ('switching_protocols',), + 102: ('processing',), + 103: ('checkpoint',), + 122: ('uri_too_long', 'request_uri_too_long'), + 200: ('ok', 'okay', 'all_ok', 'all_okay', 'all_good', '\\o/', '✓'), + 201: ('created',), + 202: ('accepted',), + 203: ('non_authoritative_info', 'non_authoritative_information'), + 204: ('no_content',), + 205: ('reset_content', 'reset'), + 206: ('partial_content', 'partial'), + 207: ('multi_status', 'multiple_status', 'multi_stati', 'multiple_stati'), + 208: ('already_reported',), + 226: ('im_used',), + + # Redirection. + 300: ('multiple_choices',), + 301: ('moved_permanently', 'moved', '\\o-'), + 302: ('found',), + 303: ('see_other', 'other'), + 304: ('not_modified',), + 305: ('use_proxy',), + 306: ('switch_proxy',), + 307: ('temporary_redirect', 'temporary_moved', 'temporary'), + 308: ('permanent_redirect', + 'resume_incomplete', 'resume',), # These 2 to be removed in 3.0 + + # Client Error. 
+ 400: ('bad_request', 'bad'), + 401: ('unauthorized',), + 402: ('payment_required', 'payment'), + 403: ('forbidden',), + 404: ('not_found', '-o-'), + 405: ('method_not_allowed', 'not_allowed'), + 406: ('not_acceptable',), + 407: ('proxy_authentication_required', 'proxy_auth', 'proxy_authentication'), + 408: ('request_timeout', 'timeout'), + 409: ('conflict',), + 410: ('gone',), + 411: ('length_required',), + 412: ('precondition_failed', 'precondition'), + 413: ('request_entity_too_large',), + 414: ('request_uri_too_large',), + 415: ('unsupported_media_type', 'unsupported_media', 'media_type'), + 416: ('requested_range_not_satisfiable', 'requested_range', 'range_not_satisfiable'), + 417: ('expectation_failed',), + 418: ('im_a_teapot', 'teapot', 'i_am_a_teapot'), + 421: ('misdirected_request',), + 422: ('unprocessable_entity', 'unprocessable'), + 423: ('locked',), + 424: ('failed_dependency', 'dependency'), + 425: ('unordered_collection', 'unordered'), + 426: ('upgrade_required', 'upgrade'), + 428: ('precondition_required', 'precondition'), + 429: ('too_many_requests', 'too_many'), + 431: ('header_fields_too_large', 'fields_too_large'), + 444: ('no_response', 'none'), + 449: ('retry_with', 'retry'), + 450: ('blocked_by_windows_parental_controls', 'parental_controls'), + 451: ('unavailable_for_legal_reasons', 'legal_reasons'), + 499: ('client_closed_request',), + + # Server Error. + 500: ('internal_server_error', 'server_error', '/o\\', '✗'), + 501: ('not_implemented',), + 502: ('bad_gateway',), + 503: ('service_unavailable', 'unavailable'), + 504: ('gateway_timeout',), + 505: ('http_version_not_supported', 'http_version'), + 506: ('variant_also_negotiates',), + 507: ('insufficient_storage',), + 509: ('bandwidth_limit_exceeded', 'bandwidth'), + 510: ('not_extended',), + 511: ('network_authentication_required', 'network_auth', 'network_authentication'), +} + +codes = LookupDict(name='status_codes') + +def _init(): + for code, titles in _codes.items(): + for title in titles: + setattr(codes, title, code) + if not title.startswith(('\\', '/')): + setattr(codes, title.upper(), code) + + def doc(code): + names = ', '.join('``%s``' % n for n in _codes[code]) + return '* %d: %s' % (code, names) + + global __doc__ + __doc__ = (__doc__ + '\n' + + '\n'.join(doc(code) for code in sorted(_codes)) + if __doc__ is not None else None) + +_init() diff --git a/requests/structures.py b/requests/structures.py new file mode 100644 index 0000000..8ee0ba7 --- /dev/null +++ b/requests/structures.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- + +""" +requests.structures +~~~~~~~~~~~~~~~~~~~ + +Data structures that power Requests. +""" + +from collections import OrderedDict + +from .compat import Mapping, MutableMapping + + +class CaseInsensitiveDict(MutableMapping): + """A case-insensitive ``dict``-like object. + + Implements all methods and operations of + ``MutableMapping`` as well as dict's ``copy``. Also + provides ``lower_items``. + + All keys are expected to be strings. The structure remembers the + case of the last key to be set, and ``iter(instance)``, + ``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()`` + will contain case-sensitive keys. 
However, querying and contains
+    testing is case insensitive::
+
+        cid = CaseInsensitiveDict()
+        cid['Accept'] = 'application/json'
+        cid['aCCEPT'] == 'application/json'  # True
+        list(cid) == ['Accept']  # True
+
+    For example, ``headers['content-encoding']`` will return the
+    value of a ``'Content-Encoding'`` response header, regardless
+    of how the header name was originally stored.
+
+    If the constructor, ``.update``, or equality comparison
+    operations are given keys that have equal ``.lower()``s, the
+    behavior is undefined.
+    """
+
+    def __init__(self, data=None, **kwargs):
+        self._store = OrderedDict()
+        if data is None:
+            data = {}
+        self.update(data, **kwargs)
+
+    def __setitem__(self, key, value):
+        # Use the lowercased key for lookups, but store the actual
+        # key alongside the value.
+        self._store[key.lower()] = (key, value)
+
+    def __getitem__(self, key):
+        return self._store[key.lower()][1]
+
+    def __delitem__(self, key):
+        del self._store[key.lower()]
+
+    def __iter__(self):
+        return (casedkey for casedkey, mappedvalue in self._store.values())
+
+    def __len__(self):
+        return len(self._store)
+
+    def lower_items(self):
+        """Like iteritems(), but with all lowercase keys."""
+        return (
+            (lowerkey, keyval[1])
+            for (lowerkey, keyval)
+            in self._store.items()
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, Mapping):
+            other = CaseInsensitiveDict(other)
+        else:
+            return NotImplemented
+        # Compare insensitively
+        return dict(self.lower_items()) == dict(other.lower_items())
+
+    # Copy is required
+    def copy(self):
+        return CaseInsensitiveDict(self._store.values())
+
+    def __repr__(self):
+        return str(dict(self.items()))
+
+
+class LookupDict(dict):
+    """Dictionary lookup object."""
+
+    def __init__(self, name=None):
+        self.name = name
+        super(LookupDict, self).__init__()
+
+    def __repr__(self):
+        return '<lookup \'%s\'>' % (self.name)
+
+    def __getitem__(self, key):
+        # We allow fall-through here, so values default to None
+
+        return self.__dict__.get(key, None)
+
+    def get(self, key, default=None):
+        return self.__dict__.get(key, default)
diff --git a/requests/utils.py b/requests/utils.py
new file mode 100644
index 0000000..c1700d7
--- /dev/null
+++ b/requests/utils.py
@@ -0,0 +1,982 @@
+# -*- coding: utf-8 -*-
+
+"""
+requests.utils
+~~~~~~~~~~~~~~
+
+This module provides utility functions that are used within Requests
+that are also useful for external consumption.
+"""
+
+import codecs
+import contextlib
+import io
+import os
+import re
+import socket
+import struct
+import sys
+import tempfile
+import warnings
+import zipfile
+from collections import OrderedDict
+
+from .__version__ import __version__
+from . import certs
+# to_native_string is unused here, but imported here for backwards compatibility
+from ._internal_utils import to_native_string
+from .compat import parse_http_list as _parse_list_header
+from .compat import (
+    quote, urlparse, bytes, str, unquote, getproxies,
+    proxy_bypass, urlunparse, basestring, integer_types, is_py3,
+    proxy_bypass_environment, getproxies_environment, Mapping)
+from .cookies import cookiejar_from_dict
+from .structures import CaseInsensitiveDict
+from .exceptions import (
+    InvalidURL, InvalidHeader, FileModeWarning, UnrewindableBodyError)
+
+NETRC_FILES = ('.netrc', '_netrc')
+
+DEFAULT_CA_BUNDLE_PATH = certs.where()
+
+DEFAULT_PORTS = {'http': 80, 'https': 443}
+
+
+if sys.platform == 'win32':
+    # provide a proxy_bypass version on Windows without DNS lookups
+
+    def proxy_bypass_registry(host):
+        try:
+            if is_py3:
+                import winreg
+            else:
+                import _winreg as winreg
+        except ImportError:
+            return False
+
+        try:
+            internetSettings = winreg.OpenKey(winreg.HKEY_CURRENT_USER,
+                r'Software\Microsoft\Windows\CurrentVersion\Internet Settings')
+            # ProxyEnable could be REG_SZ or REG_DWORD, normalizing it
+            proxyEnable = int(winreg.QueryValueEx(internetSettings,
+                                                  'ProxyEnable')[0])
+            # ProxyOverride is almost always a string
+            proxyOverride = winreg.QueryValueEx(internetSettings,
+                                                'ProxyOverride')[0]
+        except OSError:
+            return False
+        if not proxyEnable or not proxyOverride:
+            return False
+
+        # make a check value list from the registry entry: replace the
+        # '<local>' string by the localhost entry and the corresponding
+        # canonical entry.
+        proxyOverride = proxyOverride.split(';')
+        # now check if we match one of the registry values.
+        for test in proxyOverride:
+            if test == '<local>':
+                if '.' not in host:
+                    return True
+            test = test.replace(".", r"\.")     # mask dots
+            test = test.replace("*", r".*")     # change glob sequence
+            test = test.replace("?", r".")      # change glob char
+            if re.match(test, host, re.I):
+                return True
+        return False
+
+    def proxy_bypass(host):  # noqa
+        """Return True, if the host should be bypassed.
+
+        Checks proxy settings gathered from the environment, if specified,
+        or the registry.
+        """
+        if getproxies_environment():
+            return proxy_bypass_environment(host)
+        else:
+            return proxy_bypass_registry(host)
+
+
+def dict_to_sequence(d):
+    """Returns an internal sequence dictionary update."""
+
+    if hasattr(d, 'items'):
+        d = d.items()
+
+    return d
+
+
+def super_len(o):
+    total_length = None
+    current_position = 0
+
+    if hasattr(o, '__len__'):
+        total_length = len(o)
+
+    elif hasattr(o, 'len'):
+        total_length = o.len
+
+    elif hasattr(o, 'fileno'):
+        try:
+            fileno = o.fileno()
+        except io.UnsupportedOperation:
+            pass
+        else:
+            total_length = os.fstat(fileno).st_size
+
+            # Having used fstat to determine the file length, we need to
+            # confirm that this file was opened up in binary mode.
+            if 'b' not in o.mode:
+                warnings.warn((
+                    "Requests has determined the content-length for this "
+                    "request using the binary size of the file: however, the "
+                    "file has been opened in text mode (i.e. without the 'b' "
+                    "flag in the mode). This may lead to an incorrect "
+                    "content-length. In Requests 3.0, support will be removed "
+                    "for files in text mode."),
+                    FileModeWarning
+                )
+
+    if hasattr(o, 'tell'):
+        try:
+            current_position = o.tell()
+        except (OSError, IOError):
+            # This can happen in some weird situations, such as when the file
+            # is actually a special file descriptor like stdin.
In this + # instance, we don't know what the length is, so set it to zero and + # let requests chunk it instead. + if total_length is not None: + current_position = total_length + else: + if hasattr(o, 'seek') and total_length is None: + # StringIO and BytesIO have seek but no useable fileno + try: + # seek to end of file + o.seek(0, 2) + total_length = o.tell() + + # seek back to current position to support + # partially read file-like objects + o.seek(current_position or 0) + except (OSError, IOError): + total_length = 0 + + if total_length is None: + total_length = 0 + + return max(0, total_length - current_position) + + +def get_netrc_auth(url, raise_errors=False): + """Returns the Requests tuple auth for a given url from netrc.""" + + try: + from netrc import netrc, NetrcParseError + + netrc_path = None + + for f in NETRC_FILES: + try: + loc = os.path.expanduser('~/{}'.format(f)) + except KeyError: + # os.path.expanduser can fail when $HOME is undefined and + # getpwuid fails. See https://bugs.python.org/issue20164 & + # https://github.com/psf/requests/issues/1846 + return + + if os.path.exists(loc): + netrc_path = loc + break + + # Abort early if there isn't one. + if netrc_path is None: + return + + ri = urlparse(url) + + # Strip port numbers from netloc. This weird `if...encode`` dance is + # used for Python 3.2, which doesn't support unicode literals. + splitstr = b':' + if isinstance(url, str): + splitstr = splitstr.decode('ascii') + host = ri.netloc.split(splitstr)[0] + + try: + _netrc = netrc(netrc_path).authenticators(host) + if _netrc: + # Return with login / password + login_i = (0 if _netrc[0] else 1) + return (_netrc[login_i], _netrc[2]) + except (NetrcParseError, IOError): + # If there was a parsing error or a permissions issue reading the file, + # we'll just skip netrc auth unless explicitly asked to raise errors. + if raise_errors: + raise + + # AppEngine hackiness. + except (ImportError, AttributeError): + pass + + +def guess_filename(obj): + """Tries to guess the filename of the given object.""" + name = getattr(obj, 'name', None) + if (name and isinstance(name, basestring) and name[0] != '<' and + name[-1] != '>'): + return os.path.basename(name) + + +def extract_zipped_paths(path): + """Replace nonexistent paths that look like they refer to a member of a zip + archive with the location of an extracted copy of the target, or else + just return the provided path unchanged. + """ + if os.path.exists(path): + # this is already a valid path, no need to do anything further + return path + + # find the first valid part of the provided path and treat that as a zip archive + # assume the rest of the path is the name of a member in the archive + archive, member = os.path.split(path) + while archive and not os.path.exists(archive): + archive, prefix = os.path.split(archive) + member = '/'.join([prefix, member]) + + if not zipfile.is_zipfile(archive): + return path + + zip_file = zipfile.ZipFile(archive) + if member not in zip_file.namelist(): + return path + + # we have a valid zip archive and a valid member of that archive + tmp = tempfile.gettempdir() + extracted_path = os.path.join(tmp, *member.split('/')) + if not os.path.exists(extracted_path): + extracted_path = zip_file.extract(member, path=tmp) + + return extracted_path + + +def from_key_val_list(value): + """Take an object and test to see if it can be represented as a + dictionary. 
Unless it can not be represented as such, return an + OrderedDict, e.g., + + :: + + >>> from_key_val_list([('key', 'val')]) + OrderedDict([('key', 'val')]) + >>> from_key_val_list('string') + Traceback (most recent call last): + ... + ValueError: cannot encode objects that are not 2-tuples + >>> from_key_val_list({'key': 'val'}) + OrderedDict([('key', 'val')]) + + :rtype: OrderedDict + """ + if value is None: + return None + + if isinstance(value, (str, bytes, bool, int)): + raise ValueError('cannot encode objects that are not 2-tuples') + + return OrderedDict(value) + + +def to_key_val_list(value): + """Take an object and test to see if it can be represented as a + dictionary. If it can be, return a list of tuples, e.g., + + :: + + >>> to_key_val_list([('key', 'val')]) + [('key', 'val')] + >>> to_key_val_list({'key': 'val'}) + [('key', 'val')] + >>> to_key_val_list('string') + Traceback (most recent call last): + ... + ValueError: cannot encode objects that are not 2-tuples + + :rtype: list + """ + if value is None: + return None + + if isinstance(value, (str, bytes, bool, int)): + raise ValueError('cannot encode objects that are not 2-tuples') + + if isinstance(value, Mapping): + value = value.items() + + return list(value) + + +# From mitsuhiko/werkzeug (used with permission). +def parse_list_header(value): + """Parse lists as described by RFC 2068 Section 2. + + In particular, parse comma-separated lists where the elements of + the list may include quoted-strings. A quoted-string could + contain a comma. A non-quoted string could have quotes in the + middle. Quotes are removed automatically after parsing. + + It basically works like :func:`parse_set_header` just that items + may appear multiple times and case sensitivity is preserved. + + The return value is a standard :class:`list`: + + >>> parse_list_header('token, "quoted value"') + ['token', 'quoted value'] + + To create a header from the :class:`list` again, use the + :func:`dump_header` function. + + :param value: a string with a list header. + :return: :class:`list` + :rtype: list + """ + result = [] + for item in _parse_list_header(value): + if item[:1] == item[-1:] == '"': + item = unquote_header_value(item[1:-1]) + result.append(item) + return result + + +# From mitsuhiko/werkzeug (used with permission). +def parse_dict_header(value): + """Parse lists of key, value pairs as described by RFC 2068 Section 2 and + convert them into a python dict: + + >>> d = parse_dict_header('foo="is a fish", bar="as well"') + >>> type(d) is dict + True + >>> sorted(d.items()) + [('bar', 'as well'), ('foo', 'is a fish')] + + If there is no value for a key it will be `None`: + + >>> parse_dict_header('key_without_value') + {'key_without_value': None} + + To create a header from the :class:`dict` again, use the + :func:`dump_header` function. + + :param value: a string with a dict header. + :return: :class:`dict` + :rtype: dict + """ + result = {} + for item in _parse_list_header(value): + if '=' not in item: + result[item] = None + continue + name, value = item.split('=', 1) + if value[:1] == value[-1:] == '"': + value = unquote_header_value(value[1:-1]) + result[name] = value + return result + + +# From mitsuhiko/werkzeug (used with permission). +def unquote_header_value(value, is_filename=False): + r"""Unquotes a header value. (Reversal of :func:`quote_header_value`). + This does not use the real unquoting but what browsers are actually + using for quoting. + + :param value: the header value to unquote. 
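+
+    For example (an illustration, not upstream documentation):
+    ``unquote_header_value('"va\\"lue"')`` returns ``va"lue``.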
+    :rtype: str
+    """
+    if value and value[0] == value[-1] == '"':
+        # this is not the real unquoting, but fixing this so that the
+        # RFC is met will result in bugs with internet explorer and
+        # probably some other browsers as well.  IE for example is
+        # uploading files with "C:\foo\bar.txt" as filename
+        value = value[1:-1]
+
+        # if this is a filename and the starting characters look like
+        # a UNC path, then just return the value without quotes.  Using the
+        # replace sequence below on a UNC path has the effect of turning
+        # the leading double slash into a single slash and then
+        # _fix_ie_filename() doesn't work correctly.  See #458.
+        if not is_filename or value[:2] != '\\\\':
+            return value.replace('\\\\', '\\').replace('\\"', '"')
+    return value
+
+
+def dict_from_cookiejar(cj):
+    """Returns a key/value dictionary from a CookieJar.
+
+    :param cj: CookieJar object to extract cookies from.
+    :rtype: dict
+    """
+
+    cookie_dict = {}
+
+    for cookie in cj:
+        cookie_dict[cookie.name] = cookie.value
+
+    return cookie_dict
+
+
+def add_dict_to_cookiejar(cj, cookie_dict):
+    """Returns a CookieJar from a key/value dictionary.
+
+    :param cj: CookieJar to insert cookies into.
+    :param cookie_dict: Dict of key/values to insert into CookieJar.
+    :rtype: CookieJar
+    """
+
+    return cookiejar_from_dict(cookie_dict, cj)
+
+
+def get_encodings_from_content(content):
+    """Returns encodings from given content string.
+
+    :param content: bytestring to extract encodings from.
+    """
+    warnings.warn((
+        'In requests 3.0, get_encodings_from_content will be removed. For '
+        'more information, please see the discussion on issue #2266. (This'
+        ' warning should only appear once.)'),
+        DeprecationWarning)
+
+    charset_re = re.compile(r'<meta.*?charset=["\']*(.+?)["\'>]', flags=re.I)
+    pragma_re = re.compile(r'<meta.*?content=["\']*;?charset=(.+?)["\'>]', flags=re.I)
+    xml_re = re.compile(r'^<\?xml.*?encoding=["\']*(.+?)["\'>]')
+
+    return (charset_re.findall(content) +
+            pragma_re.findall(content) +
+            xml_re.findall(content))
+
+
+def _parse_content_type_header(header):
+    """Returns content type and parameters from given header
+
+    :param header: string
+    :return: tuple containing content type and dictionary of
+        parameters
+    """
+
+    tokens = header.split(';')
+    content_type, params = tokens[0].strip(), tokens[1:]
+    params_dict = {}
+    items_to_strip = "\"' "
+
+    for param in params:
+        param = param.strip()
+        if param:
+            key, value = param, True
+            index_of_equals = param.find("=")
+            if index_of_equals != -1:
+                key = param[:index_of_equals].strip(items_to_strip)
+                value = param[index_of_equals + 1:].strip(items_to_strip)
+            params_dict[key.lower()] = value
+    return content_type, params_dict
+
+
+def get_encoding_from_headers(headers):
+    """Returns encodings from given HTTP Header Dict.
+
+    :param headers: dictionary to extract encoding from.
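+
+    For example (an illustration): a ``content-type`` of
+    ``text/html; charset=utf-8`` yields ``'utf-8'``, while a bare
+    ``text/html`` falls back to ``'ISO-8859-1'`` per RFC 2616.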
+ :rtype: str + """ + + content_type = headers.get('content-type') + + if not content_type: + return None + + content_type, params = _parse_content_type_header(content_type) + + if 'charset' in params: + return params['charset'].strip("'\"") + + if 'text' in content_type: + return 'ISO-8859-1' + + +def stream_decode_response_unicode(iterator, r): + """Stream decodes a iterator.""" + + if r.encoding is None: + for item in iterator: + yield item + return + + decoder = codecs.getincrementaldecoder(r.encoding)(errors='replace') + for chunk in iterator: + rv = decoder.decode(chunk) + if rv: + yield rv + rv = decoder.decode(b'', final=True) + if rv: + yield rv + + +def iter_slices(string, slice_length): + """Iterate over slices of a string.""" + pos = 0 + if slice_length is None or slice_length <= 0: + slice_length = len(string) + while pos < len(string): + yield string[pos:pos + slice_length] + pos += slice_length + + +def get_unicode_from_response(r): + """Returns the requested content back in unicode. + + :param r: Response object to get unicode content from. + + Tried: + + 1. charset from content-type + 2. fall back and replace all unicode characters + + :rtype: str + """ + warnings.warn(( + 'In requests 3.0, get_unicode_from_response will be removed. For ' + 'more information, please see the discussion on issue #2266. (This' + ' warning should only appear once.)'), + DeprecationWarning) + + tried_encodings = [] + + # Try charset from content-type + encoding = get_encoding_from_headers(r.headers) + + if encoding: + try: + return str(r.content, encoding) + except UnicodeError: + tried_encodings.append(encoding) + + # Fall back: + try: + return str(r.content, encoding, errors='replace') + except TypeError: + return r.content + + +# The unreserved URI characters (RFC 3986) +UNRESERVED_SET = frozenset( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~") + + +def unquote_unreserved(uri): + """Un-escape any percent-escape sequences in a URI that are unreserved + characters. This leaves all reserved, illegal and non-ASCII bytes encoded. + + :rtype: str + """ + parts = uri.split('%') + for i in range(1, len(parts)): + h = parts[i][0:2] + if len(h) == 2 and h.isalnum(): + try: + c = chr(int(h, 16)) + except ValueError: + raise InvalidURL("Invalid percent-escape sequence: '%s'" % h) + + if c in UNRESERVED_SET: + parts[i] = c + parts[i][2:] + else: + parts[i] = '%' + parts[i] + else: + parts[i] = '%' + parts[i] + return ''.join(parts) + + +def requote_uri(uri): + """Re-quote the given URI. + + This function passes the given URI through an unquote/quote cycle to + ensure that it is fully and consistently quoted. + + :rtype: str + """ + safe_with_percent = "!#$%&'()*+,/:;=?@[]~" + safe_without_percent = "!#$&'()*+,/:;=?@[]~" + try: + # Unquote only the unreserved characters + # Then quote only illegal characters (do not quote reserved, + # unreserved, or '%') + return quote(unquote_unreserved(uri), safe=safe_with_percent) + except InvalidURL: + # We couldn't unquote the given URI, so let's try quoting it, but + # there may be unquoted '%'s in the URI. We need to make sure they're + # properly quoted so they do not cause issues elsewhere. 
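+        # (Illustrative: requote_uri('http://x/%zz a') raises InvalidURL
+        # while unquoting because '%zz' is not a valid escape, so this
+        # branch quotes the raw string -- '%' and ' ' included -- giving
+        # 'http://x/%25zz%20a'.)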
+ return quote(uri, safe=safe_without_percent) + + +def address_in_network(ip, net): + """This function allows you to check if an IP belongs to a network subnet + + Example: returns True if ip = 192.168.1.1 and net = 192.168.1.0/24 + returns False if ip = 192.168.1.1 and net = 192.168.100.0/24 + + :rtype: bool + """ + ipaddr = struct.unpack('=L', socket.inet_aton(ip))[0] + netaddr, bits = net.split('/') + netmask = struct.unpack('=L', socket.inet_aton(dotted_netmask(int(bits))))[0] + network = struct.unpack('=L', socket.inet_aton(netaddr))[0] & netmask + return (ipaddr & netmask) == (network & netmask) + + +def dotted_netmask(mask): + """Converts mask from /xx format to xxx.xxx.xxx.xxx + + Example: if mask is 24 function returns 255.255.255.0 + + :rtype: str + """ + bits = 0xffffffff ^ (1 << 32 - mask) - 1 + return socket.inet_ntoa(struct.pack('>I', bits)) + + +def is_ipv4_address(string_ip): + """ + :rtype: bool + """ + try: + socket.inet_aton(string_ip) + except socket.error: + return False + return True + + +def is_valid_cidr(string_network): + """ + Very simple check of the cidr format in no_proxy variable. + + :rtype: bool + """ + if string_network.count('/') == 1: + try: + mask = int(string_network.split('/')[1]) + except ValueError: + return False + + if mask < 1 or mask > 32: + return False + + try: + socket.inet_aton(string_network.split('/')[0]) + except socket.error: + return False + else: + return False + return True + + +@contextlib.contextmanager +def set_environ(env_name, value): + """Set the environment variable 'env_name' to 'value' + + Save previous value, yield, and then restore the previous value stored in + the environment variable 'env_name'. + + If 'value' is None, do nothing""" + value_changed = value is not None + if value_changed: + old_value = os.environ.get(env_name) + os.environ[env_name] = value + try: + yield + finally: + if value_changed: + if old_value is None: + del os.environ[env_name] + else: + os.environ[env_name] = old_value + + +def should_bypass_proxies(url, no_proxy): + """ + Returns whether we should bypass proxies or not. + + :rtype: bool + """ + # Prioritize lowercase environment variables over uppercase + # to keep a consistent behaviour with other http projects (curl, wget). + get_proxy = lambda k: os.environ.get(k) or os.environ.get(k.upper()) + + # First check whether no_proxy is defined. If it is, check that the URL + # we're getting isn't in the no_proxy list. + no_proxy_arg = no_proxy + if no_proxy is None: + no_proxy = get_proxy('no_proxy') + parsed = urlparse(url) + + if parsed.hostname is None: + # URLs don't always have hostnames, e.g. file:/// urls. + return True + + if no_proxy: + # We need to check whether we match here. We need to see if we match + # the end of the hostname, both with and without the port. + no_proxy = ( + host for host in no_proxy.replace(' ', '').split(',') if host + ) + + if is_ipv4_address(parsed.hostname): + for proxy_ip in no_proxy: + if is_valid_cidr(proxy_ip): + if address_in_network(parsed.hostname, proxy_ip): + return True + elif parsed.hostname == proxy_ip: + # If no_proxy ip was defined in plain IP notation instead of cidr notation & + # matches the IP of the index + return True + else: + host_with_port = parsed.hostname + if parsed.port: + host_with_port += ':{}'.format(parsed.port) + + for host in no_proxy: + if parsed.hostname.endswith(host) or host_with_port.endswith(host): + # The URL does match something in no_proxy, so we don't want + # to apply the proxies on this URL. 
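+                    # (Illustrative: with no_proxy='example.com,10.0.0.0/8',
+                    # both 'http://sub.example.com/' and 'http://10.1.2.3/'
+                    # are bypassed -- the former by the suffix match below,
+                    # the latter by the CIDR branch above.)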
+                    return True
+
+    with set_environ('no_proxy', no_proxy_arg):
+        # parsed.hostname can be `None` in cases such as a file URI.
+        try:
+            bypass = proxy_bypass(parsed.hostname)
+        except (TypeError, socket.gaierror):
+            bypass = False
+
+        if bypass:
+            return True
+
+    return False
+
+
+def get_environ_proxies(url, no_proxy=None):
+    """
+    Return a dict of environment proxies.
+
+    :rtype: dict
+    """
+    if should_bypass_proxies(url, no_proxy=no_proxy):
+        return {}
+    else:
+        return getproxies()
+
+
+def select_proxy(url, proxies):
+    """Select a proxy for the url, if applicable.
+
+    :param url: The url being used for the request
+    :param proxies: A dictionary of schemes or schemes and hosts to proxy URLs
+    """
+    proxies = proxies or {}
+    urlparts = urlparse(url)
+    if urlparts.hostname is None:
+        return proxies.get(urlparts.scheme, proxies.get('all'))
+
+    proxy_keys = [
+        urlparts.scheme + '://' + urlparts.hostname,
+        urlparts.scheme,
+        'all://' + urlparts.hostname,
+        'all',
+    ]
+    proxy = None
+    for proxy_key in proxy_keys:
+        if proxy_key in proxies:
+            proxy = proxies[proxy_key]
+            break
+
+    return proxy
+
+
+def default_user_agent(name="python-requests"):
+    """
+    Return a string representing the default user agent.
+
+    :rtype: str
+    """
+    return '%s/%s' % (name, __version__)
+
+
+def default_headers():
+    """
+    :rtype: requests.structures.CaseInsensitiveDict
+    """
+    return CaseInsensitiveDict({
+        'User-Agent': default_user_agent(),
+        'Accept-Encoding': ', '.join(('gzip', 'deflate')),
+        'Accept': '*/*',
+        'Connection': 'keep-alive',
+    })
+
+
+def parse_header_links(value):
+    """Return a list of parsed link headers.
+
+    i.e. Link: <http:/.../front.jpeg>; rel=front; type="image/jpeg",<http://.../back.jpeg>; rel=back;type="image/jpeg"
+
+    :rtype: list
+    """
+
+    links = []
+
+    replace_chars = ' \'"'
+
+    value = value.strip(replace_chars)
+    if not value:
+        return links
+
+    for val in re.split(', *<', value):
+        try:
+            url, params = val.split(';', 1)
+        except ValueError:
+            url, params = val, ''
+
+        link = {'url': url.strip('<> \'"')}
+
+        for param in params.split(';'):
+            try:
+                key, value = param.split('=')
+            except ValueError:
+                break
+
+            link[key.strip(replace_chars)] = value.strip(replace_chars)
+
+        links.append(link)
+
+    return links
+
+
+# Null bytes; no need to recreate these on each call to guess_json_utf
+_null = '\x00'.encode('ascii')  # encoding to ASCII for Python 3
+_null2 = _null * 2
+_null3 = _null * 3
+
+
+def guess_json_utf(data):
+    """
+    :rtype: str
+    """
+    # JSON always starts with two ASCII characters, so detection is as
+    # easy as counting the nulls and from their location and count
+    # determine the encoding. Also detect a BOM, if present.
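+    # (Illustrative examples, not upstream code:
+    #     guess_json_utf(b'{"a": 1}') == 'utf-8'
+    #     guess_json_utf('{"a": 1}'.encode('utf-16-le')) == 'utf-16-le')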
+ sample = data[:4] + if sample in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE): + return 'utf-32' # BOM included + if sample[:3] == codecs.BOM_UTF8: + return 'utf-8-sig' # BOM included, MS style (discouraged) + if sample[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE): + return 'utf-16' # BOM included + nullcount = sample.count(_null) + if nullcount == 0: + return 'utf-8' + if nullcount == 2: + if sample[::2] == _null2: # 1st and 3rd are null + return 'utf-16-be' + if sample[1::2] == _null2: # 2nd and 4th are null + return 'utf-16-le' + # Did not detect 2 valid UTF-16 ascii-range characters + if nullcount == 3: + if sample[:3] == _null3: + return 'utf-32-be' + if sample[1:] == _null3: + return 'utf-32-le' + # Did not detect a valid UTF-32 ascii-range character + return None + + +def prepend_scheme_if_needed(url, new_scheme): + """Given a URL that may or may not have a scheme, prepend the given scheme. + Does not replace a present scheme with the one provided as an argument. + + :rtype: str + """ + scheme, netloc, path, params, query, fragment = urlparse(url, new_scheme) + + # urlparse is a finicky beast, and sometimes decides that there isn't a + # netloc present. Assume that it's being over-cautious, and switch netloc + # and path if urlparse decided there was no netloc. + if not netloc: + netloc, path = path, netloc + + return urlunparse((scheme, netloc, path, params, query, fragment)) + + +def get_auth_from_url(url): + """Given a url with authentication components, extract them into a tuple of + username,password. + + :rtype: (str,str) + """ + parsed = urlparse(url) + + try: + auth = (unquote(parsed.username), unquote(parsed.password)) + except (AttributeError, TypeError): + auth = ('', '') + + return auth + + +# Moved outside of function to avoid recompile every call +_CLEAN_HEADER_REGEX_BYTE = re.compile(b'^\\S[^\\r\\n]*$|^$') +_CLEAN_HEADER_REGEX_STR = re.compile(r'^\S[^\r\n]*$|^$') + + +def check_header_validity(header): + """Verifies that header value is a string which doesn't contain + leading whitespace or return characters. This prevents unintended + header injection. + + :param header: tuple, in the format (name, value). + """ + name, value = header + + if isinstance(value, bytes): + pat = _CLEAN_HEADER_REGEX_BYTE + else: + pat = _CLEAN_HEADER_REGEX_STR + try: + if not pat.match(value): + raise InvalidHeader("Invalid return character or leading space in header: %s" % name) + except TypeError: + raise InvalidHeader("Value for header {%s: %s} must be of type str or " + "bytes, not %s" % (name, value, type(value))) + + +def urldefragauth(url): + """ + Given a url remove the fragment and the authentication part. + + :rtype: str + """ + scheme, netloc, path, params, query, fragment = urlparse(url) + + # see func:`prepend_scheme_if_needed` + if not netloc: + netloc, path = path, netloc + + netloc = netloc.rsplit('@', 1)[-1] + + return urlunparse((scheme, netloc, path, params, query, '')) + + +def rewind_body(prepared_request): + """Move file pointer back to its recorded starting position + so it can be read again on redirect. + """ + body_seek = getattr(prepared_request.body, 'seek', None) + if body_seek is not None and isinstance(prepared_request._body_position, integer_types): + try: + body_seek(prepared_request._body_position) + except (IOError, OSError): + raise UnrewindableBodyError("An error occurred when rewinding request " + "body for redirect.") + else: + raise UnrewindableBodyError("Unable to rewind request body for redirect.")
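+
+# (Illustrative usage sketch for the helpers above, assuming this vendored
+# copy is importable as ``requests``:
+#
+#     >>> from requests import utils
+#     >>> utils.requote_uri('http://example.com/a b')
+#     'http://example.com/a%20b'
+#     >>> utils.get_auth_from_url('http://user:pass@example.com/')
+#     ('user', 'pass')
+#     >>> utils.urldefragauth('http://user:pass@example.com/path#frag')
+#     'http://example.com/path'
+# )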