first commit

This commit is contained in:
pacnpal
2024-10-28 17:09:57 -04:00
commit 2e1b4d7af7
9993 changed files with 1182741 additions and 0 deletions

View File

@@ -0,0 +1,74 @@
from .api_jwk import PyJWK, PyJWKSet
from .api_jws import (
PyJWS,
get_algorithm_by_name,
get_unverified_header,
register_algorithm,
unregister_algorithm,
)
from .api_jwt import PyJWT, decode, encode
from .exceptions import (
DecodeError,
ExpiredSignatureError,
ImmatureSignatureError,
InvalidAlgorithmError,
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
InvalidKeyError,
InvalidSignatureError,
InvalidTokenError,
MissingRequiredClaimError,
PyJWKClientConnectionError,
PyJWKClientError,
PyJWKError,
PyJWKSetError,
PyJWTError,
)
from .jwks_client import PyJWKClient
__version__ = "2.9.0"
__title__ = "PyJWT"
__description__ = "JSON Web Token implementation in Python"
__url__ = "https://pyjwt.readthedocs.io"
__uri__ = __url__
__doc__ = f"{__description__} <{__uri__}>"
__author__ = "José Padilla"
__email__ = "hello@jpadilla.com"
__license__ = "MIT"
__copyright__ = "Copyright 2015-2022 José Padilla"
__all__ = [
"PyJWS",
"PyJWT",
"PyJWKClient",
"PyJWK",
"PyJWKSet",
"decode",
"encode",
"get_unverified_header",
"register_algorithm",
"unregister_algorithm",
"get_algorithm_by_name",
# Exceptions
"DecodeError",
"ExpiredSignatureError",
"ImmatureSignatureError",
"InvalidAlgorithmError",
"InvalidAudienceError",
"InvalidIssuedAtError",
"InvalidIssuerError",
"InvalidKeyError",
"InvalidSignatureError",
"InvalidTokenError",
"MissingRequiredClaimError",
"PyJWKClientConnectionError",
"PyJWKClientError",
"PyJWKError",
"PyJWKSetError",
"PyJWTError",
]

View File

@@ -0,0 +1,860 @@
from __future__ import annotations
import hashlib
import hmac
import json
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict
from .utils import (
base64url_decode,
base64url_encode,
der_to_raw_signature,
force_bytes,
from_base64url_uint,
is_pem_format,
is_ssh_key,
raw_to_der_signature,
to_base64url_uint,
)
try:
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.ec import (
ECDSA,
SECP256K1,
SECP256R1,
SECP384R1,
SECP521R1,
EllipticCurve,
EllipticCurvePrivateKey,
EllipticCurvePrivateNumbers,
EllipticCurvePublicKey,
EllipticCurvePublicNumbers,
)
from cryptography.hazmat.primitives.asymmetric.ed448 import (
Ed448PrivateKey,
Ed448PublicKey,
)
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
Ed25519PrivateKey,
Ed25519PublicKey,
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey,
RSAPrivateNumbers,
RSAPublicKey,
RSAPublicNumbers,
rsa_crt_dmp1,
rsa_crt_dmq1,
rsa_crt_iqmp,
rsa_recover_prime_factors,
)
from cryptography.hazmat.primitives.serialization import (
Encoding,
NoEncryption,
PrivateFormat,
PublicFormat,
load_pem_private_key,
load_pem_public_key,
load_ssh_public_key,
)
has_crypto = True
except ModuleNotFoundError:
has_crypto = False
if TYPE_CHECKING:
# Type aliases for convenience in algorithms method signatures
AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
AllowedOKPKeys = (
Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
)
AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
AllowedPrivateKeys = (
RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
)
AllowedPublicKeys = (
RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
)
requires_cryptography = {
"RS256",
"RS384",
"RS512",
"ES256",
"ES256K",
"ES384",
"ES521",
"ES512",
"PS256",
"PS384",
"PS512",
"EdDSA",
}
def get_default_algorithms() -> dict[str, Algorithm]:
"""
Returns the algorithms that are implemented by the library.
"""
default_algorithms = {
"none": NoneAlgorithm(),
"HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
"HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
"HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
}
if has_crypto:
default_algorithms.update(
{
"RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
"RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
"RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
"ES256": ECAlgorithm(ECAlgorithm.SHA256),
"ES256K": ECAlgorithm(ECAlgorithm.SHA256),
"ES384": ECAlgorithm(ECAlgorithm.SHA384),
"ES521": ECAlgorithm(ECAlgorithm.SHA512),
"ES512": ECAlgorithm(
ECAlgorithm.SHA512
), # Backward compat for #219 fix
"PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
"PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
"PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
"EdDSA": OKPAlgorithm(),
}
)
return default_algorithms
class Algorithm(ABC):
"""
The interface for an algorithm used to sign and verify tokens.
"""
def compute_hash_digest(self, bytestr: bytes) -> bytes:
"""
Compute a hash digest using the specified algorithm's hash algorithm.
If there is no hash algorithm, raises a NotImplementedError.
"""
# lookup self.hash_alg if defined in a way that mypy can understand
hash_alg = getattr(self, "hash_alg", None)
if hash_alg is None:
raise NotImplementedError
if (
has_crypto
and isinstance(hash_alg, type)
and issubclass(hash_alg, hashes.HashAlgorithm)
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return bytes(digest.finalize())
else:
return bytes(hash_alg(bytestr).digest())
@abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
@abstractmethod
def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
@abstractmethod
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
"""
@overload
@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ... # pragma: no cover
@overload
@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ... # pragma: no cover
@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
"""
Serializes a given key into a JWK
"""
@staticmethod
@abstractmethod
def from_jwk(jwk: str | JWKDict) -> Any:
"""
Deserializes a given key from JWK back into a key object
"""
class NoneAlgorithm(Algorithm):
"""
Placeholder for use when no signing or verification
operations are required.
"""
def prepare_key(self, key: str | None) -> None:
if key == "":
key = None
if key is not None:
raise InvalidKeyError('When alg = "none", key value must be None.')
return key
def sign(self, msg: bytes, key: None) -> bytes:
return b""
def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
return False
@staticmethod
def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
raise NotImplementedError()
@staticmethod
def from_jwk(jwk: str | JWKDict) -> NoReturn:
raise NotImplementedError()
class HMACAlgorithm(Algorithm):
"""
Performs signing and verification operations using HMAC
and the specified hash function.
"""
SHA256: ClassVar[HashlibHash] = hashlib.sha256
SHA384: ClassVar[HashlibHash] = hashlib.sha384
SHA512: ClassVar[HashlibHash] = hashlib.sha512
def __init__(self, hash_alg: HashlibHash) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key: str | bytes) -> bytes:
key_bytes = force_bytes(key)
if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
raise InvalidKeyError(
"The specified key is an asymmetric key or x509 certificate and"
" should not be used as an HMAC secret."
)
return key_bytes
@overload
@staticmethod
def to_jwk(
key_obj: str | bytes, as_dict: Literal[True]
) -> JWKDict: ... # pragma: no cover
@overload
@staticmethod
def to_jwk(
key_obj: str | bytes, as_dict: Literal[False] = False
) -> str: ... # pragma: no cover
@staticmethod
def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
jwk = {
"k": base64url_encode(force_bytes(key_obj)).decode(),
"kty": "oct",
}
if as_dict:
return jwk
else:
return json.dumps(jwk)
@staticmethod
def from_jwk(jwk: str | JWKDict) -> bytes:
try:
if isinstance(jwk, str):
obj: JWKDict = json.loads(jwk)
elif isinstance(jwk, dict):
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
if obj.get("kty") != "oct":
raise InvalidKeyError("Not an HMAC key")
return base64url_decode(obj["k"])
def sign(self, msg: bytes, key: bytes) -> bytes:
return hmac.new(key, msg, self.hash_alg).digest()
def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
return hmac.compare_digest(sig, self.sign(msg, key))
if has_crypto:
class RSAAlgorithm(Algorithm):
"""
Performs signing and verification operations using
RSASSA-PKCS-v1_5 and the specified hash function.
"""
SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
return key
if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")
key_bytes = force_bytes(key)
try:
if key_bytes.startswith(b"ssh-rsa"):
return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
else:
return cast(
RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
)
except ValueError:
try:
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
except (ValueError, UnsupportedAlgorithm):
raise InvalidKeyError("Could not parse the provided public key.")
@overload
@staticmethod
def to_jwk(
key_obj: AllowedRSAKeys, as_dict: Literal[True]
) -> JWKDict: ... # pragma: no cover
@overload
@staticmethod
def to_jwk(
key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
) -> str: ... # pragma: no cover
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
obj: dict[str, Any] | None = None
if hasattr(key_obj, "private_numbers"):
# Private key
numbers = key_obj.private_numbers()
obj = {
"kty": "RSA",
"key_ops": ["sign"],
"n": to_base64url_uint(numbers.public_numbers.n).decode(),
"e": to_base64url_uint(numbers.public_numbers.e).decode(),
"d": to_base64url_uint(numbers.d).decode(),
"p": to_base64url_uint(numbers.p).decode(),
"q": to_base64url_uint(numbers.q).decode(),
"dp": to_base64url_uint(numbers.dmp1).decode(),
"dq": to_base64url_uint(numbers.dmq1).decode(),
"qi": to_base64url_uint(numbers.iqmp).decode(),
}
elif hasattr(key_obj, "verify"):
# Public key
numbers = key_obj.public_numbers()
obj = {
"kty": "RSA",
"key_ops": ["verify"],
"n": to_base64url_uint(numbers.n).decode(),
"e": to_base64url_uint(numbers.e).decode(),
}
else:
raise InvalidKeyError("Not a public or private key")
if as_dict:
return obj
else:
return json.dumps(obj)
@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
elif isinstance(jwk, dict):
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
if obj.get("kty") != "RSA":
raise InvalidKeyError("Not an RSA key")
if "d" in obj and "e" in obj and "n" in obj:
# Private key
if "oth" in obj:
raise InvalidKeyError(
"Unsupported RSA private key: > 2 primes not supported"
)
other_props = ["p", "q", "dp", "dq", "qi"]
props_found = [prop in obj for prop in other_props]
any_props_found = any(props_found)
if any_props_found and not all(props_found):
raise InvalidKeyError(
"RSA key must include all parameters if any are present besides d"
)
public_numbers = RSAPublicNumbers(
from_base64url_uint(obj["e"]),
from_base64url_uint(obj["n"]),
)
if any_props_found:
numbers = RSAPrivateNumbers(
d=from_base64url_uint(obj["d"]),
p=from_base64url_uint(obj["p"]),
q=from_base64url_uint(obj["q"]),
dmp1=from_base64url_uint(obj["dp"]),
dmq1=from_base64url_uint(obj["dq"]),
iqmp=from_base64url_uint(obj["qi"]),
public_numbers=public_numbers,
)
else:
d = from_base64url_uint(obj["d"])
p, q = rsa_recover_prime_factors(
public_numbers.n, d, public_numbers.e
)
numbers = RSAPrivateNumbers(
d=d,
p=p,
q=q,
dmp1=rsa_crt_dmp1(d, p),
dmq1=rsa_crt_dmq1(d, q),
iqmp=rsa_crt_iqmp(p, q),
public_numbers=public_numbers,
)
return numbers.private_key()
elif "n" in obj and "e" in obj:
# Public key
return RSAPublicNumbers(
from_base64url_uint(obj["e"]),
from_base64url_uint(obj["n"]),
).public_key()
else:
raise InvalidKeyError("Not a public or private key")
def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
try:
key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
return True
except InvalidSignature:
return False
class ECAlgorithm(Algorithm):
"""
Performs signing and verification operations using
ECDSA and the specified hash function
"""
SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
return key
if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")
key_bytes = force_bytes(key)
# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try
# the Verifying Key first.
try:
if key_bytes.startswith(b"ecdsa-sha2-"):
crypto_key = load_ssh_public_key(key_bytes)
else:
crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
except ValueError:
crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(
crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
):
raise InvalidKeyError(
"Expecting a [AWS-SECRET-REMOVED]licKey. Wrong key provided for ECDSA algorithms"
)
return crypto_key
def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
der_sig = key.sign(msg, ECDSA(self.hash_alg()))
return der_to_raw_signature(der_sig, key.curve)
def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
try:
der_sig = raw_to_der_signature(sig, key.curve)
except ValueError:
return False
try:
public_key = (
key.public_key()
if isinstance(key, EllipticCurvePrivateKey)
else key
)
public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
return True
except InvalidSignature:
return False
@overload
@staticmethod
def to_jwk(
key_obj: AllowedECKeys, as_dict: Literal[True]
) -> JWKDict: ... # pragma: no cover
@overload
@staticmethod
def to_jwk(
key_obj: AllowedECKeys, as_dict: Literal[False] = False
) -> str: ... # pragma: no cover
@staticmethod
def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey):
public_numbers = key_obj.public_numbers()
else:
raise InvalidKeyError("Not a public or private key")
if isinstance(key_obj.curve, SECP256R1):
crv = "P-256"
elif isinstance(key_obj.curve, SECP384R1):
crv = "P-384"
elif isinstance(key_obj.curve, SECP521R1):
crv = "P-521"
elif isinstance(key_obj.curve, SECP256K1):
crv = "secp256k1"
else:
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
obj: dict[str, Any] = {
"kty": "EC",
"crv": crv,
"x": to_base64url_uint(public_numbers.x).decode(),
"y": to_base64url_uint(public_numbers.y).decode(),
}
if isinstance(key_obj, EllipticCurvePrivateKey):
obj["d"] = to_base64url_uint(
key_obj.private_numbers().private_value
).decode()
if as_dict:
return obj
else:
return json.dumps(obj)
@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
elif isinstance(jwk, dict):
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
if obj.get("kty") != "EC":
raise InvalidKeyError("Not an Elliptic curve key")
if "x" not in obj or "y" not in obj:
raise InvalidKeyError("Not an Elliptic curve key")
x = base64url_decode(obj.get("x"))
y = base64url_decode(obj.get("y"))
curve = obj.get("crv")
curve_obj: EllipticCurve
if curve == "P-256":
if len(x) == len(y) == 32:
curve_obj = SECP256R1()
else:
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
elif curve == "P-384":
if len(x) == len(y) == 48:
curve_obj = SECP384R1()
else:
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
elif curve == "P-521":
if len(x) == len(y) == 66:
curve_obj = SECP521R1()
else:
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
elif curve == "secp256k1":
if len(x) == len(y) == 32:
curve_obj = SECP256K1()
else:
raise InvalidKeyError(
"Coords should be 32 bytes for curve secp256k1"
)
else:
raise InvalidKeyError(f"Invalid curve: {curve}")
public_numbers = EllipticCurvePublicNumbers(
x=int.from_bytes(x, byteorder="big"),
y=int.from_bytes(y, byteorder="big"),
curve=curve_obj,
)
if "d" not in obj:
return public_numbers.public_key()
d = base64url_decode(obj.get("d"))
if len(d) != len(x):
raise InvalidKeyError(
"D should be {} bytes for curve {}", len(x), curve
)
return EllipticCurvePrivateNumbers(
int.from_bytes(d, byteorder="big"), public_numbers
).private_key()
class RSAPSSAlgorithm(RSAAlgorithm):
"""
Performs a signature using RSASSA-PSS with MGF1
"""
def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
return key.sign(
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
try:
key.verify(
sig,
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
return True
except InvalidSignature:
return False
class OKPAlgorithm(Algorithm):
"""
Performs signing and verification operations using EdDSA
This class requires ``cryptography>=2.6`` to be installed.
"""
def __init__(self, **kwargs: Any) -> None:
pass
def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
if isinstance(key, (bytes, str)):
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
key_bytes = key.encode("utf-8") if isinstance(key, str) else key
if "-----BEGIN PUBLIC" in key_str:
key = load_pem_public_key(key_bytes) # type: ignore[assignment]
elif "-----BEGIN PRIVATE" in key_str:
key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
elif key_str[0:4] == "ssh-":
key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(
key,
(Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
):
raise InvalidKeyError(
"Expecting a [AWS-SECRET-REMOVED]licKey. Wrong key provided for EdDSA algorithms"
)
return key
def sign(
self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
) -> bytes:
"""
Sign a message ``msg`` using the EdDSA private key ``key``
:param str|bytes msg: Message to sign
:param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
or :class:`.Ed448PrivateKey` isinstance
:return bytes signature: The signature, as bytes
"""
msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
return key.sign(msg_bytes)
def verify(
self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
) -> bool:
"""
Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
:param str|bytes sig: EdDSA signature to check ``msg`` against
:param str|bytes msg: Message to sign
:param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
A private or public EdDSA key instance
:return bool verified: True if signature is valid, False if not.
"""
try:
msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
public_key = (
key.public_key()
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
else key
)
public_key.verify(sig_bytes, msg_bytes)
return True # If no exception was raised, the signature is valid.
except InvalidSignature:
return False
@overload
@staticmethod
def to_jwk(
key: AllowedOKPKeys, as_dict: Literal[True]
) -> JWKDict: ... # pragma: no cover
@overload
@staticmethod
def to_jwk(
key: AllowedOKPKeys, as_dict: Literal[False] = False
) -> str: ... # pragma: no cover
@staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
x = key.public_bytes(
encoding=Encoding.Raw,
format=PublicFormat.Raw,
)
crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
obj = {
"x": base64url_encode(force_bytes(x)).decode(),
"kty": "OKP",
"crv": crv,
}
if as_dict:
return obj
else:
return json.dumps(obj)
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
d = key.private_bytes(
encoding=Encoding.Raw,
format=PrivateFormat.Raw,
encryption_algorithm=NoEncryption(),
)
x = key.public_key().public_bytes(
encoding=Encoding.Raw,
format=PublicFormat.Raw,
)
crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
obj = {
"x": base64url_encode(force_bytes(x)).decode(),
"d": base64url_encode(force_bytes(d)).decode(),
"kty": "OKP",
"crv": crv,
}
if as_dict:
return obj
else:
return json.dumps(obj)
raise InvalidKeyError("Not a public or private key")
@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
elif isinstance(jwk, dict):
obj = jwk
else:
raise ValueError
except ValueError:
raise InvalidKeyError("Key is not valid JSON")
if obj.get("kty") != "OKP":
raise InvalidKeyError("Not an Octet Key Pair")
curve = obj.get("crv")
if curve != "Ed25519" and curve != "Ed448":
raise InvalidKeyError(f"Invalid curve: {curve}")
if "x" not in obj:
raise InvalidKeyError('OKP should have "x" parameter')
x = base64url_decode(obj.get("x"))
try:
if "d" not in obj:
if curve == "Ed25519":
return Ed25519PublicKey.from_public_bytes(x)
return Ed448PublicKey.from_public_bytes(x)
d = base64url_decode(obj.get("d"))
if curve == "Ed25519":
return Ed25519PrivateKey.from_private_bytes(d)
return Ed448PrivateKey.from_private_bytes(d)
except ValueError as err:
raise InvalidKeyError("Invalid key parameter") from err

View File

@@ -0,0 +1,144 @@
from __future__ import annotations
import json
import time
from typing import Any
from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
from .exceptions import (
InvalidKeyError,
MissingCryptographyError,
PyJWKError,
PyJWKSetError,
PyJWTError,
)
from .types import JWKDict
class PyJWK:
def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None:
self._algorithms = get_default_algorithms()
self._jwk_data = jwk_data
kty = self._jwk_data.get("kty", None)
if not kty:
raise InvalidKeyError(f"kty is not found: {self._jwk_data}")
if not algorithm and isinstance(self._jwk_data, dict):
algorithm = self._jwk_data.get("alg", None)
if not algorithm:
# Determine alg with kty (and crv).
crv = self._jwk_data.get("crv", None)
if kty == "EC":
if crv == "P-256" or not crv:
algorithm = "ES256"
elif crv == "P-384":
algorithm = "ES384"
elif crv == "P-521":
algorithm = "ES512"
elif crv == "secp256k1":
algorithm = "ES256K"
else:
raise InvalidKeyError(f"Unsupported crv: {crv}")
elif kty == "RSA":
algorithm = "RS256"
elif kty == "oct":
algorithm = "HS256"
elif kty == "OKP":
if not crv:
raise InvalidKeyError(f"crv is not found: {self._jwk_data}")
if crv == "Ed25519":
algorithm = "EdDSA"
else:
raise InvalidKeyError(f"Unsupported crv: {crv}")
else:
raise InvalidKeyError(f"Unsupported kty: {kty}")
if not has_crypto and algorithm in requires_cryptography:
raise MissingCryptographyError(
f"{algorithm} requires 'cryptography' to be installed."
)
self.algorithm_name = algorithm
if algorithm in self._algorithms:
self.Algorithm = self._algorithms[algorithm]
else:
raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}")
self.key = self.Algorithm.from_jwk(self._jwk_data)
@staticmethod
def from_dict(obj: JWKDict, algorithm: str | None = None) -> PyJWK:
return PyJWK(obj, algorithm)
@staticmethod
def from_json(data: str, algorithm: None = None) -> PyJWK:
obj = json.loads(data)
return PyJWK.from_dict(obj, algorithm)
@property
def key_type(self) -> str | None:
return self._jwk_data.get("kty", None)
@property
def key_id(self) -> str | None:
return self._jwk_data.get("kid", None)
@property
def public_key_use(self) -> str | None:
return self._jwk_data.get("use", None)
class PyJWKSet:
def __init__(self, keys: list[JWKDict]) -> None:
self.keys = []
if not keys:
raise PyJWKSetError("The JWK Set did not contain any keys")
if not isinstance(keys, list):
raise PyJWKSetError("Invalid JWK Set value")
for key in keys:
try:
self.keys.append(PyJWK(key))
except PyJWTError as error:
if isinstance(error, MissingCryptographyError):
raise error
# skip unusable keys
continue
if len(self.keys) == 0:
raise PyJWKSetError(
"The JWK Set did not contain any usable keys. Perhaps 'cryptography' is not installed?"
)
@staticmethod
def from_dict(obj: dict[str, Any]) -> PyJWKSet:
keys = obj.get("keys", [])
return PyJWKSet(keys)
@staticmethod
def from_json(data: str) -> PyJWKSet:
obj = json.loads(data)
return PyJWKSet.from_dict(obj)
def __getitem__(self, kid: str) -> PyJWK:
for key in self.keys:
if key.key_id == kid:
return key
raise KeyError(f"keyset has no key for kid: {kid}")
class PyJWTSetWithTimestamp:
def __init__(self, jwk_set: PyJWKSet):
self.jwk_set = jwk_set
self.timestamp = time.monotonic()
def get_jwk_set(self) -> PyJWKSet:
return self.jwk_set
def get_timestamp(self) -> float:
return self.timestamp

View File

@@ -0,0 +1,335 @@
from __future__ import annotations
import binascii
import json
import warnings
from typing import TYPE_CHECKING, Any
from .algorithms import (
Algorithm,
get_default_algorithms,
has_crypto,
requires_cryptography,
)
from .api_jwk import PyJWK
from .exceptions import (
DecodeError,
InvalidAlgorithmError,
InvalidSignatureError,
InvalidTokenError,
)
from .utils import base64url_decode, base64url_encode
from .warnings import RemovedInPyjwt3Warning
if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
class PyJWS:
header_typ = "JWT"
def __init__(
self,
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
set(algorithms) if algorithms is not None else set(self._algorithms)
)
# Remove algorithms that aren't on the whitelist
for key in list(self._algorithms.keys()):
if key not in self._valid_algs:
del self._algorithms[key]
if options is None:
options = {}
self.options = {**self._get_default_options(), **options}
@staticmethod
def _get_default_options() -> dict[str, bool]:
return {"verify_signature": True}
def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
"""
Registers a new Algorithm for use when creating and verifying tokens.
"""
if alg_id in self._algorithms:
raise ValueError("Algorithm already has a handler.")
if not isinstance(alg_obj, Algorithm):
raise TypeError("Object is not of type `Algorithm`")
self._algorithms[alg_id] = alg_obj
self._valid_algs.add(alg_id)
def unregister_algorithm(self, alg_id: str) -> None:
"""
Unregisters an Algorithm for use when creating and verifying tokens
Throws KeyError if algorithm is not registered.
"""
if alg_id not in self._algorithms:
raise KeyError(
"The specified algorithm could not be removed"
" because it is not registered."
)
del self._algorithms[alg_id]
self._valid_algs.remove(alg_id)
def get_algorithms(self) -> list[str]:
"""
Returns a list of supported values for the 'alg' parameter.
"""
return list(self._valid_algs)
def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
"""
For a given string name, return the matching Algorithm object.
Example usage:
>>> jws_obj.get_algorithm_by_name("RS256")
"""
try:
return self._algorithms[alg_name]
except KeyError as e:
if not has_crypto and alg_name in requires_cryptography:
raise NotImplementedError(
f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
) from e
raise NotImplementedError("Algorithm not supported") from e
def encode(
self,
payload: bytes,
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
sort_headers: bool = True,
) -> str:
segments = []
# declare a new var to narrow the type for type checkers
algorithm_: str = algorithm if algorithm is not None else "none"
# Prefer headers values if present to function parameters.
if headers:
headers_alg = headers.get("alg")
if headers_alg:
algorithm_ = headers["alg"]
headers_b64 = headers.get("b64")
if headers_b64 is False:
is_payload_detached = True
# Header
header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}
if headers:
self._validate_headers(headers)
header.update(headers)
if not header["typ"]:
del header["typ"]
if is_payload_detached:
header["b64"] = False
elif "b64" in header:
# True is the standard value for b64, so no need for it
del header["b64"]
json_header = json.dumps(
header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
).encode()
segments.append(base64url_encode(json_header))
if is_payload_detached:
msg_payload = payload
else:
msg_payload = base64url_encode(payload)
segments.append(msg_payload)
# Segments
signing_input = b".".join(segments)
alg_obj = self.get_algorithm_by_name(algorithm_)
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)
segments.append(base64url_encode(signature))
# Don't put the payload content inside the encoded token when detached
if is_payload_detached:
segments[1] = b""
encoded_string = b".".join(segments)
return encoded_string.decode("utf-8")
def decode_complete(
self,
jwt: str | bytes,
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
"and will be removed in pyjwt version 3. "
f"Unsupported kwargs: {tuple(kwargs.keys())}",
RemovedInPyjwt3Warning,
)
if options is None:
options = {}
merged_options = {**self.options, **options}
verify_signature = merged_options["verify_signature"]
if verify_signature and not algorithms and not isinstance(key, PyJWK):
raise DecodeError(
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
)
payload, signing_input, header, signature = self._load(jwt)
if header.get("b64", True) is False:
if detached_payload is None:
raise DecodeError(
'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
)
payload = detached_payload
signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])
if verify_signature:
self._verify_signature(signing_input, header, signature, key, algorithms)
return {
"payload": payload,
"header": header,
"signature": signature,
}
def decode(
self,
jwt: str | bytes,
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
) -> Any:
if kwargs:
warnings.warn(
"passing additional kwargs to decode() is deprecated "
"and will be removed in pyjwt version 3. "
f"Unsupported kwargs: {tuple(kwargs.keys())}",
RemovedInPyjwt3Warning,
)
decoded = self.decode_complete(
jwt, key, algorithms, options, detached_payload=detached_payload
)
return decoded["payload"]
def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
"""Returns back the JWT header parameters as a dict()
Note: The signature is not verified so the header parameters
should not be fully trusted until signature verification is complete
"""
headers = self._load(jwt)[2]
self._validate_headers(headers)
return headers
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")
if not isinstance(jwt, bytes):
raise DecodeError(f"Invalid token type. Token must be a {bytes}")
try:
signing_input, crypto_segment = jwt.rsplit(b".", 1)
header_segment, payload_segment = signing_input.split(b".", 1)
except ValueError as err:
raise DecodeError("Not enough segments") from err
try:
header_data = base64url_decode(header_segment)
except (TypeError, binascii.Error) as err:
raise DecodeError("Invalid header padding") from err
try:
header = json.loads(header_data)
except ValueError as e:
raise DecodeError(f"Invalid header string: {e}") from e
if not isinstance(header, dict):
raise DecodeError("Invalid header string: must be a json object")
try:
payload = base64url_decode(payload_segment)
except (TypeError, binascii.Error) as err:
raise DecodeError("Invalid payload padding") from err
try:
signature = base64url_decode(crypto_segment)
except (TypeError, binascii.Error) as err:
raise DecodeError("Invalid crypto padding") from err
return (payload, signing_input, header, signature)
def _verify_signature(
self,
signing_input: bytes,
header: dict[str, Any],
signature: bytes,
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None,
) -> None:
if algorithms is None and isinstance(key, PyJWK):
algorithms = [key.algorithm_name]
try:
alg = header["alg"]
except KeyError:
raise InvalidAlgorithmError("Algorithm not specified")
if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed")
if isinstance(key, PyJWK):
alg_obj = key.Algorithm
prepared_key = key.key
else:
try:
alg_obj = self.get_algorithm_by_name(alg)
except NotImplementedError as e:
raise InvalidAlgorithmError("Algorithm not supported") from e
prepared_key = alg_obj.prepare_key(key)
if not alg_obj.verify(signing_input, prepared_key, signature):
raise InvalidSignatureError("Signature verification failed")
def _validate_headers(self, headers: dict[str, Any]) -> None:
if "kid" in headers:
self._validate_kid(headers["kid"])
def _validate_kid(self, kid: Any) -> None:
if not isinstance(kid, str):
raise InvalidTokenError("Key ID header parameter must be a string")
_jws_global_obj = PyJWS()
encode = _jws_global_obj.encode
decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
get_unverified_header = _jws_global_obj.get_unverified_header

View File

@@ -0,0 +1,377 @@
from __future__ import annotations
import json
import warnings
from calendar import timegm
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, List
from . import api_jws
from .exceptions import (
DecodeError,
ExpiredSignatureError,
ImmatureSignatureError,
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
MissingRequiredClaimError,
)
from .warnings import RemovedInPyjwt3Warning
if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
from .api_jwk import PyJWK
class PyJWT:
def __init__(self, options: dict[str, Any] | None = None) -> None:
if options is None:
options = {}
self.options: dict[str, Any] = {**self._get_default_options(), **options}
@staticmethod
def _get_default_options() -> dict[str, bool | list[str]]:
return {
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": True,
"verify_aud": True,
"verify_iss": True,
"require": [],
}
def encode(
self,
payload: dict[str, Any],
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
) -> str:
# Check that we get a dict
if not isinstance(payload, dict):
raise TypeError(
"Expecting a dict object, as JWT only supports "
"JSON objects as payloads."
)
# Payload
payload = payload.copy()
for time_claim in ["exp", "iat", "nbf"]:
# Convert datetime to a intDate value in known time-format claims
if isinstance(payload.get(time_claim), datetime):
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
json_payload = self._encode_payload(
payload,
headers=headers,
json_encoder=json_encoder,
)
return api_jws.encode(
json_payload,
key,
algorithm,
headers,
json_encoder,
sort_headers=sort_headers,
)
def _encode_payload(
self,
payload: dict[str, Any],
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
) -> bytes:
"""
Encode a given payload to the bytes to be signed.
This method is intended to be overridden by subclasses that need to
encode the payload in a different way, e.g. compress the payload.
"""
return json.dumps(
payload,
separators=(",", ":"),
cls=json_encoder,
).encode("utf-8")
def decode_complete(
self,
jwt: str | bytes,
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims
# consider putting in options
audience: str | Iterable[str] | None = None,
issuer: str | List[str] | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs: Any,
) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
"and will be removed in pyjwt version 3. "
f"Unsupported kwargs: {tuple(kwargs.keys())}",
RemovedInPyjwt3Warning,
)
options = dict(options or {}) # shallow-copy or initialize an empty dict
options.setdefault("verify_signature", True)
# If the user has set the legacy `verify` argument, and it doesn't match
# what the relevant `options` entry for the argument is, inform the user
# that they're likely making a mistake.
if verify is not None and verify != options["verify_signature"]:
warnings.warn(
"The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. "
"The equivalent is setting `verify_signature` to False in the `options` dictionary. "
"This invocation has a mismatch between the kwarg and the option entry.",
category=DeprecationWarning,
)
if not options["verify_signature"]:
options.setdefault("verify_exp", False)
options.setdefault("verify_nbf", False)
options.setdefault("verify_iat", False)
options.setdefault("verify_aud", False)
options.setdefault("verify_iss", False)
if options["verify_signature"] and not algorithms:
raise DecodeError(
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
)
decoded = api_jws.decode_complete(
jwt,
key=key,
algorithms=algorithms,
options=options,
detached_payload=detached_payload,
)
payload = self._decode_payload(decoded)
merged_options = {**self.options, **options}
self._validate_claims(
payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
)
decoded["payload"] = payload
return decoded
def _decode_payload(self, decoded: dict[str, Any]) -> Any:
"""
Decode the payload from a JWS dictionary (payload, signature, header).
This method is intended to be overridden by subclasses that need to
decode the payload in a different way, e.g. decompress compressed
payloads.
"""
try:
payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError(f"Invalid payload string: {e}")
if not isinstance(payload, dict):
raise DecodeError("Invalid payload string: must be a json object")
return payload
def decode(
self,
jwt: str | bytes,
key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims
# consider putting in options
audience: str | Iterable[str] | None = None,
issuer: str | List[str] | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs: Any,
) -> Any:
if kwargs:
warnings.warn(
"passing additional kwargs to decode() is deprecated "
"and will be removed in pyjwt version 3. "
f"Unsupported kwargs: {tuple(kwargs.keys())}",
RemovedInPyjwt3Warning,
)
decoded = self.decode_complete(
jwt,
key,
algorithms,
options,
verify=verify,
detached_payload=detached_payload,
audience=audience,
issuer=issuer,
leeway=leeway,
)
return decoded["payload"]
def _validate_claims(
self,
payload: dict[str, Any],
options: dict[str, Any],
audience=None,
issuer=None,
leeway: float | timedelta = 0,
) -> None:
if isinstance(leeway, timedelta):
leeway = leeway.total_seconds()
if audience is not None and not isinstance(audience, (str, Iterable)):
raise TypeError("audience must be a string, iterable or None")
self._validate_required_claims(payload, options)
now = datetime.now(tz=timezone.utc).timestamp()
if "iat" in payload and options["verify_iat"]:
self._validate_iat(payload, now, leeway)
if "nbf" in payload and options["verify_nbf"]:
self._validate_nbf(payload, now, leeway)
if "exp" in payload and options["verify_exp"]:
self._validate_exp(payload, now, leeway)
if options["verify_iss"]:
self._validate_iss(payload, issuer)
if options["verify_aud"]:
self._validate_aud(
payload, audience, strict=options.get("strict_aud", False)
)
def _validate_required_claims(
self,
payload: dict[str, Any],
options: dict[str, Any],
) -> None:
for claim in options["require"]:
if payload.get(claim) is None:
raise MissingRequiredClaimError(claim)
def _validate_iat(
self,
payload: dict[str, Any],
now: float,
leeway: float,
) -> None:
try:
iat = int(payload["iat"])
except ValueError:
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
if iat > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (iat)")
def _validate_nbf(
self,
payload: dict[str, Any],
now: float,
leeway: float,
) -> None:
try:
nbf = int(payload["nbf"])
except ValueError:
raise DecodeError("Not Before claim (nbf) must be an integer.")
if nbf > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (nbf)")
def _validate_exp(
self,
payload: dict[str, Any],
now: float,
leeway: float,
) -> None:
try:
exp = int(payload["exp"])
except ValueError:
raise DecodeError("Expiration Time claim (exp) must be an integer.")
if exp <= (now - leeway):
raise ExpiredSignatureError("Signature has expired")
def _validate_aud(
self,
payload: dict[str, Any],
audience: str | Iterable[str] | None,
*,
strict: bool = False,
) -> None:
if audience is None:
if "aud" not in payload or not payload["aud"]:
return
# Application did not specify an audience, but
# the token has the 'aud' claim
raise InvalidAudienceError("Invalid audience")
if "aud" not in payload or not payload["aud"]:
# Application specified an audience, but it could not be
# verified since the token does not contain a claim.
raise MissingRequiredClaimError("aud")
audience_claims = payload["aud"]
# In strict mode, we forbid list matching: the supplied audience
# must be a string, and it must exactly match the audience claim.
if strict:
# Only a single audience is allowed in strict mode.
if not isinstance(audience, str):
raise InvalidAudienceError("Invalid audience (strict)")
# Only a single audience claim is allowed in strict mode.
if not isinstance(audience_claims, str):
raise InvalidAudienceError("Invalid claim format in token (strict)")
if audience != audience_claims:
raise InvalidAudienceError("Audience doesn't match (strict)")
return
if isinstance(audience_claims, str):
audience_claims = [audience_claims]
if not isinstance(audience_claims, list):
raise InvalidAudienceError("Invalid claim format in token")
if any(not isinstance(c, str) for c in audience_claims):
raise InvalidAudienceError("Invalid claim format in token")
if isinstance(audience, str):
audience = [audience]
if all(aud not in audience_claims for aud in audience):
raise InvalidAudienceError("Audience doesn't match")
def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
if issuer is None:
return
if "iss" not in payload:
raise MissingRequiredClaimError("iss")
if isinstance(issuer, list):
if payload["iss"] not in issuer:
raise InvalidIssuerError("Invalid issuer")
else:
if payload["iss"] != issuer:
raise InvalidIssuerError("Invalid issuer")
_jwt_global_obj = PyJWT()
encode = _jwt_global_obj.encode
decode_complete = _jwt_global_obj.decode_complete
decode = _jwt_global_obj.decode

View File

@@ -0,0 +1,74 @@
class PyJWTError(Exception):
"""
Base class for all exceptions
"""
pass
class InvalidTokenError(PyJWTError):
pass
class DecodeError(InvalidTokenError):
pass
class InvalidSignatureError(DecodeError):
pass
class ExpiredSignatureError(InvalidTokenError):
pass
class InvalidAudienceError(InvalidTokenError):
pass
class InvalidIssuerError(InvalidTokenError):
pass
class InvalidIssuedAtError(InvalidTokenError):
pass
class ImmatureSignatureError(InvalidTokenError):
pass
class InvalidKeyError(PyJWTError):
pass
class InvalidAlgorithmError(InvalidTokenError):
pass
class MissingRequiredClaimError(InvalidTokenError):
def __init__(self, claim: str) -> None:
self.claim = claim
def __str__(self) -> str:
return f'Token is missing the "{self.claim}" claim'
class PyJWKError(PyJWTError):
pass
class MissingCryptographyError(PyJWKError):
pass
class PyJWKSetError(PyJWTError):
pass
class PyJWKClientError(PyJWTError):
pass
class PyJWKClientConnectionError(PyJWKClientError):
pass

View File

@@ -0,0 +1,64 @@
import json
import platform
import sys
from typing import Dict
from . import __version__ as pyjwt_version
try:
import cryptography
cryptography_version = cryptography.__version__
except ModuleNotFoundError:
cryptography_version = ""
def info() -> Dict[str, Dict[str, str]]:
"""
Generate information for a bug report.
Based on the requests package help utility module.
"""
try:
platform_info = {
"system": platform.system(),
"release": platform.release(),
}
except OSError:
platform_info = {"system": "Unknown", "release": "Unknown"}
implementation = platform.python_implementation()
if implementation == "CPython":
implementation_version = platform.python_version()
elif implementation == "PyPy":
pypy_version_info = sys.pypy_version_info # type: ignore[attr-defined]
implementation_version = (
f"{pypy_version_info.major}."
f"{pypy_version_info.minor}."
f"{pypy_version_info.micro}"
)
if pypy_version_info.releaselevel != "final":
implementation_version = "".join(
[implementation_version, pypy_version_info.releaselevel]
)
else:
implementation_version = "Unknown"
return {
"platform": platform_info,
"implementation": {
"name": implementation,
"version": implementation_version,
},
"cryptography": {"version": cryptography_version},
"pyjwt": {"version": pyjwt_version},
}
def main() -> None:
"""Pretty-print the bug information as JSON."""
print(json.dumps(info(), sort_keys=True, indent=2))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,31 @@
import time
from typing import Optional
from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp
class JWKSetCache:
def __init__(self, lifespan: int) -> None:
self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
self.lifespan = lifespan
def put(self, jwk_set: PyJWKSet) -> None:
if jwk_set is not None:
self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
else:
# clear cache
self.jwk_set_with_timestamp = None
def get(self) -> Optional[PyJWKSet]:
if self.jwk_set_with_timestamp is None or self.is_expired():
return None
return self.jwk_set_with_timestamp.get_jwk_set()
def is_expired(self) -> bool:
return (
self.jwk_set_with_timestamp is not None
and self.lifespan > -1
and time.monotonic()
> self.jwk_set_with_timestamp.get_timestamp() + self.lifespan
)

View File

@@ -0,0 +1,124 @@
import json
import urllib.request
from functools import lru_cache
from ssl import SSLContext
from typing import Any, Dict, List, Optional
from urllib.error import URLError
from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientConnectionError, PyJWKClientError
from .jwk_set_cache import JWKSetCache
class PyJWKClient:
def __init__(
self,
uri: str,
cache_keys: bool = False,
max_cached_keys: int = 16,
cache_jwk_set: bool = True,
lifespan: int = 300,
headers: Optional[Dict[str, Any]] = None,
timeout: int = 30,
ssl_context: Optional[SSLContext] = None,
):
if headers is None:
headers = {}
self.uri = uri
self.jwk_set_cache: Optional[JWKSetCache] = None
self.headers = headers
self.timeout = timeout
self.ssl_context = ssl_context
if cache_jwk_set:
# Init jwt set cache with default or given lifespan.
# Default lifespan is 300 seconds (5 minutes).
if lifespan <= 0:
raise PyJWKClientError(
f'Lifespan must be greater than 0, the input is "{lifespan}"'
)
self.jwk_set_cache = JWKSetCache(lifespan)
else:
self.jwk_set_cache = None
if cache_keys:
# Cache signing keys
# Ignore mypy (https://github.com/python/mypy/issues/2427)
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore
def fetch_data(self) -> Any:
jwk_set: Any = None
try:
r = urllib.request.Request(url=self.uri, headers=self.headers)
with urllib.request.urlopen(
r, timeout=self.timeout, context=self.ssl_context
) as response:
jwk_set = json.load(response)
except (URLError, TimeoutError) as e:
raise PyJWKClientConnectionError(
f'Fail to fetch data from the url, err: "{e}"'
)
else:
return jwk_set
finally:
if self.jwk_set_cache is not None:
self.jwk_set_cache.put(jwk_set)
def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
data = None
if self.jwk_set_cache is not None and not refresh:
data = self.jwk_set_cache.get()
if data is None:
data = self.fetch_data()
if not isinstance(data, dict):
raise PyJWKClientError("The JWKS endpoint did not return a JSON object")
return PyJWKSet.from_dict(data)
def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
jwk_set = self.get_jwk_set(refresh)
signing_keys = [
jwk_set_key
for jwk_set_key in jwk_set.keys
if jwk_set_key.public_key_use in ["sig", None] and jwk_set_key.key_id
]
if not signing_keys:
raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")
return signing_keys
def get_signing_key(self, kid: str) -> PyJWK:
signing_keys = self.get_signing_keys()
signing_key = self.match_kid(signing_keys, kid)
if not signing_key:
# If no matching signing key from the jwk set, refresh the jwk set and try again.
signing_keys = self.get_signing_keys(refresh=True)
signing_key = self.match_kid(signing_keys, kid)
if not signing_key:
raise PyJWKClientError(
f'Unable to find a signing key that matches: "{kid}"'
)
return signing_key
def get_signing_key_from_jwt(self, token: str) -> PyJWK:
unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"]
return self.get_signing_key(header.get("kid"))
@staticmethod
def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]:
signing_key = None
for key in signing_keys:
if key.key_id == kid:
signing_key = key
break
return signing_key

View File

@@ -0,0 +1,5 @@
from typing import Any, Callable, Dict
JWKDict = Dict[str, Any]
HashlibHash = Callable[..., Any]

View File

@@ -0,0 +1,145 @@
import base64
import binascii
import re
from typing import Union
try:
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
from cryptography.hazmat.primitives.asymmetric.utils import (
decode_dss_signature,
encode_dss_signature,
)
except ModuleNotFoundError:
pass
def force_bytes(value: Union[bytes, str]) -> bytes:
if isinstance(value, str):
return value.encode("utf-8")
elif isinstance(value, bytes):
return value
else:
raise TypeError("Expected a string value")
def base64url_decode(input: Union[bytes, str]) -> bytes:
input_bytes = force_bytes(input)
rem = len(input_bytes) % 4
if rem > 0:
input_bytes += b"=" * (4 - rem)
return base64.urlsafe_b64decode(input_bytes)
def base64url_encode(input: bytes) -> bytes:
return base64.urlsafe_b64encode(input).replace(b"=", b"")
def to_base64url_uint(val: int) -> bytes:
if val < 0:
raise ValueError("Must be a positive integer")
int_bytes = bytes_from_int(val)
if len(int_bytes) == 0:
int_bytes = b"\x00"
return base64url_encode(int_bytes)
def from_base64url_uint(val: Union[bytes, str]) -> int:
data = base64url_decode(force_bytes(val))
return int.from_bytes(data, byteorder="big")
def number_to_bytes(num: int, num_bytes: int) -> bytes:
padded_hex = "%0*x" % (2 * num_bytes, num)
return binascii.a2b_hex(padded_hex.encode("ascii"))
def bytes_to_number(string: bytes) -> int:
return int(binascii.b2a_hex(string), 16)
def bytes_from_int(val: int) -> bytes:
remaining = val
byte_length = 0
while remaining != 0:
remaining >>= 8
byte_length += 1
return val.to_bytes(byte_length, "big", signed=False)
def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes:
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
r, s = decode_dss_signature(der_sig)
return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes:
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
if len(raw_sig) != 2 * num_bytes:
raise ValueError("Invalid signature")
r = bytes_to_number(raw_sig[:num_bytes])
s = bytes_to_number(raw_sig[num_bytes:])
return bytes(encode_dss_signature(r, s))
# Based on https://github.[AWS-SECRET-REMOVED]f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252
_PEMS = {
b"CERTIFICATE",
b"TRUSTED CERTIFICATE",
b"PRIVATE KEY",
b"PUBLIC KEY",
b"ENCRYPTED PRIVATE KEY",
b"OPENSSH PRIVATE KEY",
b"DSA PRIVATE KEY",
b"RSA PRIVATE KEY",
b"RSA PUBLIC KEY",
b"EC PRIVATE KEY",
b"DH PARAMETERS",
b"NEW CERTIFICATE REQUEST",
b"CERTIFICATE REQUEST",
b"SSH2 PUBLIC KEY",
b"SSH2 ENCRYPTED PRIVATE KEY",
b"X509 CRL",
}
_PEM_RE = re.compile(
b"----[- ]BEGIN ("
+ b"|".join(_PEMS)
+ b""")[- ]----\r?
.+?\r?
----[- ]END \\1[- ]----\r?\n?""",
re.DOTALL,
)
def is_pem_format(key: bytes) -> bool:
return bool(_PEM_RE.search(key))
# Based on https://github.[AWS-SECRET-REMOVED][AWS-SECRET-REMOVED][AWS-SECRET-REMOVED].py#L40-L46
_SSH_KEY_FORMATS = (
b"ssh-ed25519",
b"ssh-rsa",
b"ssh-dss",
b"ecdsa-sha2-nistp256",
b"ecdsa-sha2-nistp384",
b"ecdsa-sha2-nistp521",
)
def is_ssh_key(key: bytes) -> bool:
return key.startswith(_SSH_KEY_FORMATS)

View File

@@ -0,0 +1,2 @@
class RemovedInPyjwt3Warning(DeprecationWarning):
pass