155 lines
4.9 KiB
Python
155 lines
4.9 KiB
Python
import json
|
|
from collections.abc import Iterable
|
|
from datetime import timedelta
|
|
from typing import Any, Dict, Optional, Type, Union
|
|
|
|
import jwt
|
|
from django.utils.translation import gettext_lazy as _
|
|
from jwt import InvalidAlgorithmError, InvalidTokenError, algorithms
|
|
|
|
from .exceptions import TokenBackendError
|
|
from .tokens import Token
|
|
from .utils import format_lazy
|
|
|
|
# PyJWKClient is an optional PyJWT feature; record whether it could be
# imported so the backend can decide whether a JWKS client may be built.
try:
    from jwt import PyJWKClient, PyJWKClientError
except ImportError:
    JWK_CLIENT_AVAILABLE = False
else:
    JWK_CLIENT_AVAILABLE = True
|
|
|
|
# Algorithm names accepted by TokenBackend: the names listed explicitly
# below plus every name PyJWT reports in `algorithms.requires_cryptography`.
ALLOWED_ALGORITHMS = {
    "HS256",
    "HS384",
    "HS512",
    "RS256",
    "RS384",
    "RS512",
    "ES256",
    "ES384",
    "ES512",
    *algorithms.requires_cryptography,
}
|
|
|
|
|
|
class TokenBackend:
    """
    Encodes and decodes JSON Web Tokens through PyJWT using one configured
    algorithm, its key material, and optional audience, issuer and leeway
    settings.
    """

    def __init__(
        self,
        algorithm: str,
        signing_key: Optional[str] = None,
        verifying_key: str = "",
        audience: Union[str, Iterable, None] = None,
        issuer: Optional[str] = None,
        jwk_url: Optional[str] = None,
        leeway: Union[float, int, timedelta, None] = None,
        json_encoder: Optional[Type[json.JSONEncoder]] = None,
    ) -> None:
        """
        Store the backend configuration after validating `algorithm`.

        Raises a `TokenBackendError` (via `_validate_algorithm`) when the
        algorithm is not acceptable.
        """
        self._validate_algorithm(algorithm)

        self.algorithm = algorithm
        self.signing_key = signing_key
        self.verifying_key = verifying_key
        self.audience = audience
        self.issuer = issuer

        # A JWKS client is built only when PyJWT exposes PyJWKClient and a
        # non-empty URL was supplied; otherwise the attribute is None.
        wants_jwks = JWK_CLIENT_AVAILABLE and jwk_url
        self.jwks_client = PyJWKClient(jwk_url) if wants_jwks else None

        self.leeway = leeway
        self.json_encoder = json_encoder

    def _validate_algorithm(self, algorithm: str) -> None:
        """
        Check that `algorithm` is one of ALLOWED_ALGORITHMS and, when PyJWT
        lists it as requiring cryptography, that cryptography is available.

        Raises a `TokenBackendError` when either check fails.
        """
        if algorithm not in ALLOWED_ALGORITHMS:
            message = format_lazy(_("Unrecognized algorithm type '{}'"), algorithm)
            raise TokenBackendError(message)

        needs_cryptography = algorithm in algorithms.requires_cryptography
        if needs_cryptography and not algorithms.has_crypto:
            message = format_lazy(
                _("You must have cryptography installed to use {}."), algorithm
            )
            raise TokenBackendError(message)

    def get_leeway(self) -> timedelta:
        """
        Return the configured leeway as a `timedelta`.

        `None` yields a zero-length delta, an int or float is taken as a
        number of seconds, and a `timedelta` is returned unchanged. Any other
        type raises a `TokenBackendError`.
        """
        leeway = self.leeway

        if leeway is None:
            return timedelta(seconds=0)
        if isinstance(leeway, (int, float)):
            return timedelta(seconds=leeway)
        if isinstance(leeway, timedelta):
            return leeway

        raise TokenBackendError(
            format_lazy(
                _(
                    "Unrecognized type '{}', 'leeway' must be of type int, float or timedelta."
                ),
                type(leeway),
            )
        )

    def get_verifying_key(self, token: Token) -> Optional[str]:
        """
        Return the key used to verify `token`.

        Algorithms whose name starts with "HS" use the signing key. Otherwise
        the key is looked up through the JWKS client when one is configured,
        and the static verifying key is used when it is not.

        Raises a `TokenBackendError` if the JWKS client reports an error.
        """
        if self.algorithm.startswith("HS"):
            return self.signing_key

        if not self.jwks_client:
            return self.verifying_key

        try:
            jwk = self.jwks_client.get_signing_key_from_jwt(token)
            return jwk.key
        except PyJWKClientError as exc:
            raise TokenBackendError(_("Token is invalid or expired")) from exc

    def encode(self, payload: Dict[str, Any]) -> str:
        """
        Return an encoded token string for the given payload dictionary.

        The payload is copied first; the configured audience and issuer, when
        set, are written to the copy's "aud" and "iss" claims.
        """
        claims = payload.copy()
        if self.audience is not None:
            claims["aud"] = self.audience
        if self.issuer is not None:
            claims["iss"] = self.issuer

        encoded = jwt.encode(
            claims,
            self.signing_key,
            algorithm=self.algorithm,
            json_encoder=self.json_encoder,
        )

        # PyJWT <= 1.7.1 returns bytes; PyJWT >= 2.0.0a1 returns str.
        return encoded.decode("utf-8") if isinstance(encoded, bytes) else encoded

    def decode(self, token: Token, verify: bool = True) -> Dict[str, Any]:
        """
        Validate the given token and return its payload dictionary.

        Raises a `TokenBackendError` if the token is malformed, if its
        signature check fails, or if its 'exp' claim indicates it has expired.
        """
        decode_options = {
            "verify_aud": self.audience is not None,
            "verify_signature": verify,
        }

        try:
            # Key lookup and leeway resolution stay inside the try so that any
            # PyJWT token error they raise is translated like a decode error.
            key = self.get_verifying_key(token)
            leeway = self.get_leeway()
            return jwt.decode(
                token,
                key,
                algorithms=[self.algorithm],
                audience=self.audience,
                issuer=self.issuer,
                leeway=leeway,
                options=decode_options,
            )
        except InvalidAlgorithmError as exc:
            raise TokenBackendError(_("Invalid algorithm specified")) from exc
        except InvalidTokenError as exc:
            raise TokenBackendError(_("Token is invalid or expired")) from exc
|