Files
thrillwiki_django_no_react/venv/lib/python3.11/site-packages/jwcrypto/jwa.py
2024-10-28 17:09:57 -04:00

1203 lines
34 KiB
Python

# Copyright (C) 2016 JWCrypto Project Contributors - see LICENSE file
import os
import struct
from abc import ABCMeta, abstractmethod
from binascii import hexlify, unhexlify
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import constant_time, hashes, hmac
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import utils as ec_utils
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives.keywrap import aes_key_unwrap, aes_key_wrap
from cryptography.hazmat.primitives.padding import PKCS7
from jwcrypto.common import InvalidCEKeyLength
from jwcrypto.common import InvalidJWAAlgorithm
from jwcrypto.common import InvalidJWEKeyLength
from jwcrypto.common import InvalidJWEKeyType
from jwcrypto.common import InvalidJWEOperation
from jwcrypto.common import base64url_decode, base64url_encode
from jwcrypto.common import json_decode
from jwcrypto.jwk import JWK
# Implements RFC 7518 - JSON Web Algorithms (JWA)
default_max_pbkdf2_iterations = 16384
class JWAAlgorithm(metaclass=ABCMeta):
@property
@abstractmethod
def name(self):
"""The algorithm Name"""
@property
@abstractmethod
def description(self):
"""A short description"""
@property
@abstractmethod
def keysize(self):
"""The algorithm key size"""
@property
@abstractmethod
def algorithm_usage_location(self):
"""One of 'alg', 'enc' or 'JWK'"""
@property
@abstractmethod
def algorithm_use(self):
"""One of 'sig', 'kex', 'enc'"""
@property
def input_keysize(self):
"""The input key size"""
try:
return self.wrap_key_size
except AttributeError:
return self.keysize
def _bitsize(x):
return len(x) * 8
def _inbytes(x):
return x // 8
def _randombits(x):
if x % 8 != 0:
raise ValueError("length must be a multiple of 8")
return os.urandom(_inbytes(x))
# Note: the number of bits should be a multiple of 16
def _encode_int(n, bits):
e = '{:x}'.format(n)
ilen = ((bits + 7) // 8) * 2 # number of bytes rounded up times 2 bytes
return unhexlify(e.rjust(ilen, '0')[:ilen])
def _decode_int(n):
return int(hexlify(n), 16)
class _RawJWS:
def sign(self, key, payload):
raise NotImplementedError
def verify(self, key, payload, signature):
raise NotImplementedError
class _RawHMAC(_RawJWS):
def __init__(self, hashfn):
self.backend = default_backend()
self.hashfn = hashfn
def _hmac_setup(self, key, payload):
h = hmac.HMAC(key, self.hashfn, backend=self.backend)
h.update(payload)
return h
def sign(self, key, payload):
skey = base64url_decode(key.get_op_key('sign'))
h = self._hmac_setup(skey, payload)
return h.finalize()
def verify(self, key, payload, signature):
vkey = base64url_decode(key.get_op_key('verify'))
h = self._hmac_setup(vkey, payload)
h.verify(signature)
class _RawRSA(_RawJWS):
def __init__(self, padfn, hashfn):
self.padfn = padfn
self.hashfn = hashfn
def sign(self, key, payload):
skey = key.get_op_key('sign')
return skey.sign(payload, self.padfn, self.hashfn)
def verify(self, key, payload, signature):
pkey = key.get_op_key('verify')
pkey.verify(signature, payload, self.padfn, self.hashfn)
class _RawEC(_RawJWS):
def __init__(self, curve, hashfn):
self._curve = curve
self.hashfn = hashfn
@property
def curve(self):
return self._curve
def sign(self, key, payload):
skey = key.get_op_key('sign', self._curve)
size = skey.key_size
signature = skey.sign(payload, ec.ECDSA(self.hashfn))
r, s = ec_utils.decode_dss_signature(signature)
return _encode_int(r, size) + _encode_int(s, size)
def verify(self, key, payload, signature):
pkey = key.get_op_key('verify', self._curve)
r = signature[:len(signature) // 2]
s = signature[len(signature) // 2:]
enc_signature = ec_utils.encode_dss_signature(
int(hexlify(r), 16), int(hexlify(s), 16))
pkey.verify(enc_signature, payload, ec.ECDSA(self.hashfn))
class _RawNone(_RawJWS):
def sign(self, key, payload):
return ''
def verify(self, key, payload, signature):
if key['kty'] != 'oct' or key.get_op_key() != '':
raise InvalidSignature('The "none" signature cannot be verified')
class _HS256(_RawHMAC, JWAAlgorithm):
name = "HS256"
description = "HMAC using SHA-256"
keysize = 256
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_HS256, self).__init__(hashes.SHA256())
class _HS384(_RawHMAC, JWAAlgorithm):
name = "HS384"
description = "HMAC using SHA-384"
keysize = 384
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_HS384, self).__init__(hashes.SHA384())
class _HS512(_RawHMAC, JWAAlgorithm):
name = "HS512"
description = "HMAC using SHA-512"
keysize = 512
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_HS512, self).__init__(hashes.SHA512())
class _RS256(_RawRSA, JWAAlgorithm):
name = "RS256"
description = "RSASSA-PKCS1-v1_5 using SHA-256"
keysize = 2048
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_RS256, self).__init__(padding.PKCS1v15(), hashes.SHA256())
class _RS384(_RawRSA, JWAAlgorithm):
name = "RS384"
description = "RSASSA-PKCS1-v1_5 using SHA-384"
keysize = 2048
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_RS384, self).__init__(padding.PKCS1v15(), hashes.SHA384())
class _RS512(_RawRSA, JWAAlgorithm):
name = "RS512"
description = "RSASSA-PKCS1-v1_5 using SHA-512"
keysize = 2048
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_RS512, self).__init__(padding.PKCS1v15(), hashes.SHA512())
class _ES256(_RawEC, JWAAlgorithm):
name = "ES256"
description = "ECDSA using P-256 and SHA-256"
keysize = 256
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_ES256, self).__init__('P-256', hashes.SHA256())
class _ES256K(_RawEC, JWAAlgorithm):
name = "ES256K"
description = "ECDSA using secp256k1 curve and SHA-256"
keysize = 256
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_ES256K, self).__init__('secp256k1', hashes.SHA256())
class _ES384(_RawEC, JWAAlgorithm):
name = "ES384"
description = "ECDSA using P-384 and SHA-384"
keysize = 384
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_ES384, self).__init__('P-384', hashes.SHA384())
class _ES512(_RawEC, JWAAlgorithm):
name = "ES512"
description = "ECDSA using P-521 and SHA-512"
keysize = 512
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_ES512, self).__init__('P-521', hashes.SHA512())
class _PS256(_RawRSA, JWAAlgorithm):
name = "PS256"
description = "RSASSA-PSS using SHA-256 and MGF1 with SHA-256"
keysize = 2048
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
padfn = padding.PSS(padding.MGF1(hashes.SHA256()),
hashes.SHA256.digest_size)
super(_PS256, self).__init__(padfn, hashes.SHA256())
class _PS384(_RawRSA, JWAAlgorithm):
name = "PS384"
description = "RSASSA-PSS using SHA-384 and MGF1 with SHA-384"
keysize = 2048
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
padfn = padding.PSS(padding.MGF1(hashes.SHA384()),
hashes.SHA384.digest_size)
super(_PS384, self).__init__(padfn, hashes.SHA384())
class _PS512(_RawRSA, JWAAlgorithm):
name = "PS512"
description = "RSASSA-PSS using SHA-512 and MGF1 with SHA-512"
keysize = 2048
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
padfn = padding.PSS(padding.MGF1(hashes.SHA512()),
hashes.SHA512.digest_size)
super(_PS512, self).__init__(padfn, hashes.SHA512())
class _None(_RawNone, JWAAlgorithm):
name = "none"
description = "No digital signature or MAC performed"
keysize = 0
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
class _RawKeyMgmt:
def wrap(self, key, bitsize, cek, headers):
raise NotImplementedError
def unwrap(self, key, bitsize, ek, headers):
raise NotImplementedError
class _RSA(_RawKeyMgmt):
def __init__(self, padfn):
self.padfn = padfn
def _check_key(self, key):
if not isinstance(key, JWK):
raise ValueError('key is not a JWK object')
if key['kty'] != 'RSA':
raise InvalidJWEKeyType('RSA', key['kty'])
# FIXME: get key size and ensure > 2048 bits
def wrap(self, key, bitsize, cek, headers):
self._check_key(key)
if not cek:
cek = _randombits(bitsize)
rk = key.get_op_key('wrapKey')
ek = rk.encrypt(cek, self.padfn)
return {'cek': cek, 'ek': ek}
def unwrap(self, key, bitsize, ek, headers):
self._check_key(key)
rk = key.get_op_key('unwrapKey')
cek = rk.decrypt(ek, self.padfn)
if _bitsize(cek) != bitsize:
raise InvalidJWEKeyLength(bitsize, _bitsize(cek))
return cek
class _Rsa15(_RSA, JWAAlgorithm):
name = 'RSA1_5'
description = "RSAES-PKCS1-v1_5"
keysize = 2048
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
def __init__(self):
super(_Rsa15, self).__init__(padding.PKCS1v15())
def unwrap(self, key, bitsize, ek, headers):
self._check_key(key)
# Address MMA attack by implementing RFC 3218 - 2.3.2. Random Filling
# provides a random cek that will cause the decryption engine to
# run to the end, but will fail decryption later.
# always generate a random cek so we spend roughly the
# same time as in the exception side of the branch
cek = _randombits(bitsize)
try:
cek = super(_Rsa15, self).unwrap(key, bitsize, ek, headers)
# always raise so we always run through the exception handling
# code in all cases
raise ValueError('Dummy')
except Exception: # pylint: disable=broad-except
return cek
class _RsaOaep(_RSA, JWAAlgorithm):
name = 'RSA-OAEP'
description = "RSAES OAEP using default parameters"
keysize = 2048
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
def __init__(self):
super(_RsaOaep, self).__init__(
padding.OAEP(padding.MGF1(hashes.SHA1()),
hashes.SHA1(), None))
class _RsaOaep256(_RSA, JWAAlgorithm): # noqa: ignore=N801
name = 'RSA-OAEP-256'
description = "RSAES OAEP using SHA-256 and MGF1 with SHA-256"
keysize = 2048
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
def __init__(self):
super(_RsaOaep256, self).__init__(
padding.OAEP(padding.MGF1(hashes.SHA256()),
hashes.SHA256(), None))
class _AesKw(_RawKeyMgmt):
keysize = None
def __init__(self):
self.backend = default_backend()
def _get_key(self, key, op):
if not isinstance(key, JWK):
raise ValueError('key is not a JWK object')
if key['kty'] != 'oct':
raise InvalidJWEKeyType('oct', key['kty'])
rk = base64url_decode(key.get_op_key(op))
if _bitsize(rk) != self.keysize:
raise InvalidJWEKeyLength(self.keysize, _bitsize(rk))
return rk
def wrap(self, key, bitsize, cek, headers):
rk = self._get_key(key, 'encrypt')
if not cek:
cek = _randombits(bitsize)
ek = aes_key_wrap(rk, cek, default_backend())
return {'cek': cek, 'ek': ek}
def unwrap(self, key, bitsize, ek, headers):
rk = self._get_key(key, 'decrypt')
cek = aes_key_unwrap(rk, ek, default_backend())
if _bitsize(cek) != bitsize:
raise InvalidJWEKeyLength(bitsize, _bitsize(cek))
return cek
class _A128KW(_AesKw, JWAAlgorithm):
name = 'A128KW'
description = "AES Key Wrap using 128-bit key"
keysize = 128
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
class _A192KW(_AesKw, JWAAlgorithm):
name = 'A192KW'
description = "AES Key Wrap using 192-bit key"
keysize = 192
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
class _A256KW(_AesKw, JWAAlgorithm):
name = 'A256KW'
description = "AES Key Wrap using 256-bit key"
keysize = 256
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
class _AesGcmKw(_RawKeyMgmt):
keysize = None
def __init__(self):
self.backend = default_backend()
def _get_key(self, key, op):
if not isinstance(key, JWK):
raise ValueError('key is not a JWK object')
if key['kty'] != 'oct':
raise InvalidJWEKeyType('oct', key['kty'])
rk = base64url_decode(key.get_op_key(op))
if _bitsize(rk) != self.keysize:
raise InvalidJWEKeyLength(self.keysize, _bitsize(rk))
return rk
def wrap(self, key, bitsize, cek, headers):
rk = self._get_key(key, 'encrypt')
if not cek:
cek = _randombits(bitsize)
iv = _randombits(96)
cipher = Cipher(algorithms.AES(rk), modes.GCM(iv),
backend=self.backend)
encryptor = cipher.encryptor()
ek = encryptor.update(cek) + encryptor.finalize()
tag = encryptor.tag
return {'cek': cek, 'ek': ek,
'header': {'iv': base64url_encode(iv),
'tag': base64url_encode(tag)}}
def unwrap(self, key, bitsize, ek, headers):
rk = self._get_key(key, 'decrypt')
if 'iv' not in headers:
raise ValueError('Invalid Header, missing "iv" parameter')
iv = base64url_decode(headers['iv'])
if 'tag' not in headers:
raise ValueError('Invalid Header, missing "tag" parameter')
tag = base64url_decode(headers['tag'])
cipher = Cipher(algorithms.AES(rk), modes.GCM(iv, tag),
backend=self.backend)
decryptor = cipher.decryptor()
cek = decryptor.update(ek) + decryptor.finalize()
if _bitsize(cek) != bitsize:
raise InvalidJWEKeyLength(bitsize, _bitsize(cek))
return cek
class _A128GcmKw(_AesGcmKw, JWAAlgorithm):
name = 'A128GCMKW'
description = "Key wrapping with AES GCM using 128-bit key"
keysize = 128
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
class _A192GcmKw(_AesGcmKw, JWAAlgorithm):
name = 'A192GCMKW'
description = "Key wrapping with AES GCM using 192-bit key"
keysize = 192
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
class _A256GcmKw(_AesGcmKw, JWAAlgorithm):
name = 'A256GCMKW'
description = "Key wrapping with AES GCM using 256-bit key"
keysize = 256
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
class _Pbes2HsAesKw(_RawKeyMgmt):
name = None
keysize = None
hashsize = None
def __init__(self):
self.backend = default_backend()
self.aeskwmap = {128: _A128KW, 192: _A192KW, 256: _A256KW}
def _get_key(self, alg, key, p2s, p2c):
if p2c > default_max_pbkdf2_iterations:
raise ValueError('Invalid p2c value, too large')
if not isinstance(key, JWK):
# backwards compatibility for old interface
if isinstance(key, bytes):
plain = key
else:
plain = key.encode('utf8')
else:
plain = base64url_decode(key.get_op_key())
salt = bytes(self.name.encode('utf8')) + b'\x00' + p2s
if self.hashsize == 256:
hashalg = hashes.SHA256()
elif self.hashsize == 384:
hashalg = hashes.SHA384()
elif self.hashsize == 512:
hashalg = hashes.SHA512()
else:
raise ValueError('Unknown Hash Size')
kdf = PBKDF2HMAC(algorithm=hashalg, length=_inbytes(self.keysize),
salt=salt, iterations=p2c, backend=self.backend)
rk = kdf.derive(plain)
if _bitsize(rk) != self.keysize:
raise InvalidJWEKeyLength(self.keysize, len(rk))
return JWK(kty="oct", use="enc", k=base64url_encode(rk))
def wrap(self, key, bitsize, cek, headers):
ret_header = {}
if 'p2s' in headers:
p2s = base64url_decode(headers['p2s'])
if len(p2s) < 8:
raise ValueError('Invalid Salt, must be 8 or more octects')
else:
p2s = _randombits(128)
ret_header['p2s'] = base64url_encode(p2s)
if 'p2c' in headers:
p2c = headers['p2c']
else:
p2c = 8192
ret_header['p2c'] = p2c
kek = self._get_key(headers['alg'], key, p2s, p2c)
aeskw = self.aeskwmap[self.keysize]()
ret = aeskw.wrap(kek, bitsize, cek, headers)
if len(ret_header) > 0:
ret['header'] = ret_header
return ret
def unwrap(self, key, bitsize, ek, headers):
if 'p2s' not in headers:
raise ValueError('Invalid Header, missing "p2s" parameter')
if 'p2c' not in headers:
raise ValueError('Invalid Header, missing "p2c" parameter')
p2s = base64url_decode(headers['p2s'])
p2c = headers['p2c']
kek = self._get_key(headers['alg'], key, p2s, p2c)
aeskw = self.aeskwmap[self.keysize]()
return aeskw.unwrap(kek, bitsize, ek, headers)
class _Pbes2Hs256A128Kw(_Pbes2HsAesKw, JWAAlgorithm):
name = 'PBES2-HS256+A128KW'
description = 'PBES2 with HMAC SHA-256 and "A128KW" wrapping'
keysize = 128
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
hashsize = 256
class _Pbes2Hs384A192Kw(_Pbes2HsAesKw, JWAAlgorithm):
name = 'PBES2-HS384+A192KW'
description = 'PBES2 with HMAC SHA-384 and "A192KW" wrapping'
keysize = 192
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
hashsize = 384
class _Pbes2Hs512A256Kw(_Pbes2HsAesKw, JWAAlgorithm):
name = 'PBES2-HS512+A256KW'
description = 'PBES2 with HMAC SHA-512 and "A256KW" wrapping'
keysize = 256
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
hashsize = 512
class _Direct(_RawKeyMgmt, JWAAlgorithm):
name = 'dir'
description = "Direct use of a shared symmetric key"
keysize = 128
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
def _check_key(self, key):
if not isinstance(key, JWK):
raise ValueError('key is not a JWK object')
if key['kty'] != 'oct':
raise InvalidJWEKeyType('oct', key['kty'])
def wrap(self, key, bitsize, cek, headers):
self._check_key(key)
if cek:
return (cek, None)
k = base64url_decode(key.get_op_key('encrypt'))
if _bitsize(k) != bitsize:
raise InvalidCEKeyLength(bitsize, _bitsize(k))
return {'cek': k}
def unwrap(self, key, bitsize, ek, headers):
self._check_key(key)
if ek != b'':
raise ValueError('Invalid Encryption Key.')
cek = base64url_decode(key.get_op_key('decrypt'))
if _bitsize(cek) != bitsize:
raise InvalidJWEKeyLength(bitsize, _bitsize(cek))
return cek
class _EcdhEs(_RawKeyMgmt, JWAAlgorithm):
name = 'ECDH-ES'
description = "ECDH-ES using Concat KDF"
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
keysize = None
def __init__(self):
self.backend = default_backend()
self.aeskwmap = {128: _A128KW, 192: _A192KW, 256: _A256KW}
def _check_key(self, key):
if not isinstance(key, JWK):
raise ValueError('key is not a JWK object')
if key['kty'] not in ['EC', 'OKP']:
raise InvalidJWEKeyType('EC or OKP', key['kty'])
if key['kty'] == 'OKP':
if key['crv'] not in ['X25519', 'X448']:
raise InvalidJWEKeyType('X25519 or X448',
key['crv'])
def _derive(self, privkey, pubkey, alg, bitsize, headers):
# OtherInfo is defined in NIST SP 56A 5.8.1.2.1
# AlgorithmID
otherinfo = struct.pack('>I', len(alg))
otherinfo += bytes(alg.encode('utf8'))
# PartyUInfo
apu = base64url_decode(headers['apu']) if 'apu' in headers else b''
otherinfo += struct.pack('>I', len(apu))
otherinfo += apu
# PartyVInfo
apv = base64url_decode(headers['apv']) if 'apv' in headers else b''
otherinfo += struct.pack('>I', len(apv))
otherinfo += apv
# SuppPubInfo
otherinfo += struct.pack('>I', bitsize)
# no SuppPrivInfo
# Shared Key generation
if isinstance(privkey, ec.EllipticCurvePrivateKey):
shared_key = privkey.exchange(ec.ECDH(), pubkey)
else:
# X25519/X448
shared_key = privkey.exchange(pubkey)
ckdf = ConcatKDFHash(algorithm=hashes.SHA256(),
length=_inbytes(bitsize),
otherinfo=otherinfo,
backend=self.backend)
return ckdf.derive(shared_key)
def wrap(self, key, bitsize, cek, headers):
self._check_key(key)
dk_size = self.keysize
if self.keysize is None:
if cek is not None:
raise InvalidJWEOperation('ECDH-ES cannot use an existing CEK')
alg = headers['enc']
dk_size = bitsize
else:
alg = headers['alg']
epk = JWK.generate(kty=key['kty'], crv=key['crv'])
dk = self._derive(epk.get_op_key('unwrapKey'),
key.get_op_key('wrapKey'),
alg, dk_size, headers)
if self.keysize is None:
ret = {'cek': dk}
else:
aeskw = self.aeskwmap[self.keysize]()
kek = JWK(kty="oct", use="enc", k=base64url_encode(dk))
ret = aeskw.wrap(kek, bitsize, cek, headers)
ret['header'] = {'epk': json_decode(epk.export_public())}
return ret
def unwrap(self, key, bitsize, ek, headers):
if 'epk' not in headers:
raise ValueError('Invalid Header, missing "epk" parameter')
self._check_key(key)
dk_size = self.keysize
if self.keysize is None:
alg = headers['enc']
dk_size = bitsize
else:
alg = headers['alg']
epk = JWK(**headers['epk'])
dk = self._derive(key.get_op_key('unwrapKey'),
epk.get_op_key('wrapKey'),
alg, dk_size, headers)
if self.keysize is None:
return dk
else:
aeskw = self.aeskwmap[self.keysize]()
kek = JWK(kty="oct", use="enc", k=base64url_encode(dk))
cek = aeskw.unwrap(kek, bitsize, ek, headers)
return cek
class _EcdhEsAes128Kw(_EcdhEs):
name = 'ECDH-ES+A128KW'
description = 'ECDH-ES using Concat KDF and "A128KW" wrapping'
keysize = 128
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
class _EcdhEsAes192Kw(_EcdhEs):
name = 'ECDH-ES+A192KW'
description = 'ECDH-ES using Concat KDF and "A192KW" wrapping'
keysize = 192
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
class _EcdhEsAes256Kw(_EcdhEs):
name = 'ECDH-ES+A256KW'
description = 'ECDH-ES using Concat KDF and "A256KW" wrapping'
keysize = 256
algorithm_usage_location = 'alg'
algorithm_use = 'kex'
class _EdDsa(_RawJWS, JWAAlgorithm):
name = 'EdDSA'
description = 'EdDSA using Ed25519 or Ed448 algorithms'
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
keysize = None
def sign(self, key, payload):
if key['crv'] in ['Ed25519', 'Ed448']:
skey = key.get_op_key('sign')
return skey.sign(payload)
raise NotImplementedError
def verify(self, key, payload, signature):
if key['crv'] in ['Ed25519', 'Ed448']:
pkey = key.get_op_key('verify')
return pkey.verify(signature, payload)
raise NotImplementedError
class _RawJWE:
def encrypt(self, k, aad, m):
raise NotImplementedError
def decrypt(self, k, aad, iv, e, t):
raise NotImplementedError
class _AesCbcHmacSha2(_RawJWE):
keysize = None
def __init__(self, hashfn):
self.backend = default_backend()
self.hashfn = hashfn
self.blocksize = algorithms.AES.block_size
self.wrap_key_size = self.keysize * 2
def _mac(self, k, aad, iv, e):
al = _encode_int(_bitsize(aad), 64)
h = hmac.HMAC(k, self.hashfn, backend=self.backend)
h.update(aad)
h.update(iv)
h.update(e)
h.update(al)
m = h.finalize()
return m[:_inbytes(self.keysize)]
# RFC 7518 - 5.2.2
def encrypt(self, k, aad, m):
""" Encrypt according to the selected encryption and hashing
functions.
:param k: Encryption key
:param aad: Additional Authentication Data
:param m: Plaintext
Returns a dictionary with the computed data.
"""
if len(k) != _inbytes(self.wrap_key_size):
raise ValueError("Invalid input key size")
hkey = k[:_inbytes(self.keysize)]
ekey = k[_inbytes(self.keysize):]
# encrypt
iv = _randombits(self.blocksize)
cipher = Cipher(algorithms.AES(ekey), modes.CBC(iv),
backend=self.backend)
encryptor = cipher.encryptor()
padder = PKCS7(self.blocksize).padder()
padded_data = padder.update(m) + padder.finalize()
e = encryptor.update(padded_data) + encryptor.finalize()
# mac
t = self._mac(hkey, aad, iv, e)
return (iv, e, t)
def decrypt(self, k, aad, iv, e, t):
""" Decrypt according to the selected encryption and hashing
functions.
:param k: Encryption key
:param aad: Additional Authenticated Data
:param iv: Initialization Vector
:param e: Ciphertext
:param t: Authentication Tag
Returns plaintext or raises an error
"""
if len(k) != _inbytes(self.wrap_key_size):
raise ValueError("Invalid input key size")
hkey = k[:_inbytes(self.keysize)]
dkey = k[_inbytes(self.keysize):]
# verify mac
if not constant_time.bytes_eq(t, self._mac(hkey, aad, iv, e)):
raise InvalidSignature('Failed to verify MAC')
# decrypt
cipher = Cipher(algorithms.AES(dkey), modes.CBC(iv),
backend=self.backend)
decryptor = cipher.decryptor()
d = decryptor.update(e) + decryptor.finalize()
unpadder = PKCS7(self.blocksize).unpadder()
return unpadder.update(d) + unpadder.finalize()
class _A128CbcHs256(_AesCbcHmacSha2, JWAAlgorithm):
name = 'A128CBC-HS256'
description = "AES_128_CBC_HMAC_SHA_256 authenticated"
keysize = 128
algorithm_usage_location = 'enc'
algorithm_use = 'enc'
def __init__(self):
super(_A128CbcHs256, self).__init__(hashes.SHA256())
class _A192CbcHs384(_AesCbcHmacSha2, JWAAlgorithm):
name = 'A192CBC-HS384'
description = "AES_192_CBC_HMAC_SHA_384 authenticated"
keysize = 192
algorithm_usage_location = 'enc'
algorithm_use = 'enc'
def __init__(self):
super(_A192CbcHs384, self).__init__(hashes.SHA384())
class _A256CbcHs512(_AesCbcHmacSha2, JWAAlgorithm):
name = 'A256CBC-HS512'
description = "AES_256_CBC_HMAC_SHA_512 authenticated"
keysize = 256
algorithm_usage_location = 'enc'
algorithm_use = 'enc'
def __init__(self):
super(_A256CbcHs512, self).__init__(hashes.SHA512())
class _AesGcm(_RawJWE):
keysize = None
def __init__(self):
self.backend = default_backend()
self.wrap_key_size = self.keysize
# RFC 7518 - 5.3
def encrypt(self, k, aad, m):
""" Encrypt according to the selected encryption and hashing
functions.
:param k: Encryption key
:param aad: Additional Authentication Data
:param m: Plaintext
Returns a dictionary with the computed data.
"""
iv = _randombits(96)
cipher = Cipher(algorithms.AES(k), modes.GCM(iv),
backend=self.backend)
encryptor = cipher.encryptor()
encryptor.authenticate_additional_data(aad)
e = encryptor.update(m) + encryptor.finalize()
return (iv, e, encryptor.tag)
def decrypt(self, k, aad, iv, e, t):
""" Decrypt according to the selected encryption and hashing
functions.
:param k: Encryption key
:param aad: Additional Authenticated Data
:param iv: Initialization Vector
:param e: Ciphertext
:param t: Authentication Tag
Returns plaintext or raises an error
"""
cipher = Cipher(algorithms.AES(k), modes.GCM(iv, t),
backend=self.backend)
decryptor = cipher.decryptor()
decryptor.authenticate_additional_data(aad)
return decryptor.update(e) + decryptor.finalize()
class _A128Gcm(_AesGcm, JWAAlgorithm):
name = 'A128GCM'
description = "AES GCM using 128-bit key"
keysize = 128
algorithm_usage_location = 'enc'
algorithm_use = 'enc'
class _A192Gcm(_AesGcm, JWAAlgorithm):
name = 'A192GCM'
description = "AES GCM using 192-bit key"
keysize = 192
algorithm_usage_location = 'enc'
algorithm_use = 'enc'
class _A256Gcm(_AesGcm, JWAAlgorithm):
name = 'A256GCM'
description = "AES GCM using 256-bit key"
keysize = 256
algorithm_usage_location = 'enc'
algorithm_use = 'enc'
class _BP256R1(_RawEC, JWAAlgorithm):
name = "BP256R1"
description = (
"ECDSA using Brainpool256R1 curve and SHA-256"
" (unregistered, custom-defined in breach"
" of IETF rules by gematik GmbH)"
)
keysize = 256
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_BP256R1, self).__init__('BP-256', hashes.SHA256())
class _BP384R1(_RawEC, JWAAlgorithm):
name = "BP384R1"
description = (
"ECDSA using Brainpool384R1 curve and SHA-384"
" (unregistered, custom-defined in breach"
" of IETF rules by gematik GmbH)"
)
keysize = 384
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_BP384R1, self).__init__('BP-384', hashes.SHA384())
class _BP512R1(_RawEC, JWAAlgorithm):
name = "BP512R1"
description = (
"ECDSA using Brainpool512R1 curve and SHA-512"
" (unregistered, custom-defined in breach"
" of IETF rules by gematik GmbH)"
)
keysize = 512
algorithm_usage_location = 'alg'
algorithm_use = 'sig'
def __init__(self):
super(_BP512R1, self).__init__('BP-512', hashes.SHA512())
class JWA:
"""JWA Signing Algorithms.
This class provides access to all JWA algorithms.
"""
algorithms_registry = {
'HS256': _HS256,
'HS384': _HS384,
'HS512': _HS512,
'RS256': _RS256,
'RS384': _RS384,
'RS512': _RS512,
'ES256': _ES256,
'ES256K': _ES256K,
'ES384': _ES384,
'ES512': _ES512,
'PS256': _PS256,
'PS384': _PS384,
'PS512': _PS512,
'none': _None,
'RSA1_5': _Rsa15,
'RSA-OAEP': _RsaOaep,
'RSA-OAEP-256': _RsaOaep256,
'A128KW': _A128KW,
'A192KW': _A192KW,
'A256KW': _A256KW,
'dir': _Direct,
'ECDH-ES': _EcdhEs,
'ECDH-ES+A128KW': _EcdhEsAes128Kw,
'ECDH-ES+A192KW': _EcdhEsAes192Kw,
'ECDH-ES+A256KW': _EcdhEsAes256Kw,
'EdDSA': _EdDsa,
'A128GCMKW': _A128GcmKw,
'A192GCMKW': _A192GcmKw,
'A256GCMKW': _A256GcmKw,
'PBES2-HS256+A128KW': _Pbes2Hs256A128Kw,
'PBES2-HS384+A192KW': _Pbes2Hs384A192Kw,
'PBES2-HS512+A256KW': _Pbes2Hs512A256Kw,
'A128CBC-HS256': _A128CbcHs256,
'A192CBC-HS384': _A192CbcHs384,
'A256CBC-HS512': _A256CbcHs512,
'A128GCM': _A128Gcm,
'A192GCM': _A192Gcm,
'A256GCM': _A256Gcm,
'BP256R1': _BP256R1,
'BP384R1': _BP384R1,
'BP512R1': _BP512R1
}
@classmethod
def instantiate_alg(cls, name, use=None):
alg = cls.algorithms_registry[name]
if use is not None and alg.algorithm_use != use:
raise KeyError
return alg()
@classmethod
def signing_alg(cls, name):
try:
return cls.instantiate_alg(name, use='sig')
except KeyError:
raise InvalidJWAAlgorithm('%s is not a valid Signing algorithm'
' name' % name) from None
@classmethod
def keymgmt_alg(cls, name):
try:
return cls.instantiate_alg(name, use='kex')
except KeyError:
raise InvalidJWAAlgorithm('%s is not a valid Key Management'
' algorithm name' % name) from None
@classmethod
def encryption_alg(cls, name):
try:
return cls.instantiate_alg(name, use='enc')
except KeyError:
raise InvalidJWAAlgorithm('%s is not a valid Encryption'
' algorithm name' % name) from None