# Source: daren_project/venv/Lib/site-packages/rest_framework_simplejwt/backends.py
# File-viewer metadata captured with this copy: 155 lines, 4.9 KiB, Python.

import json
from collections.abc import Iterable
from datetime import timedelta
from typing import Any, Dict, Optional, Type, Union
import jwt
from django.utils.translation import gettext_lazy as _
from jwt import InvalidAlgorithmError, InvalidTokenError, algorithms
from .exceptions import TokenBackendError
from .tokens import Token
from .utils import format_lazy
# PyJWKClient is an optional dependency: if the installed PyJWT does not
# export it, remember that so TokenBackend can skip building a JWKS client.
try:
    from jwt import PyJWKClient, PyJWKClientError

    JWK_CLIENT_AVAILABLE = True
except ImportError:
    # PyJWKClient / PyJWKClientError stay unbound in this case; code must
    # check JWK_CLIENT_AVAILABLE before touching either name.
    JWK_CLIENT_AVAILABLE = False
# Every algorithm name TokenBackend accepts: the HMAC / RSA / ECDSA names
# spelled out here, plus whatever PyJWT lists in
# ``algorithms.requires_cryptography``.
ALLOWED_ALGORITHMS = set(algorithms.requires_cryptography) | {
    "HS256", "HS384", "HS512",
    "RS256", "RS384", "RS512",
    "ES256", "ES384", "ES512",
}
class TokenBackend:
    """
    Thin wrapper around PyJWT that encodes and decodes tokens with one fixed
    algorithm, key configuration, audience, issuer and leeway.

    All PyJWT / JWKS failures raised while decoding are re-raised as
    `TokenBackendError` so callers only have to handle one exception type.
    """

    def __init__(
        self,
        algorithm: str,
        signing_key: Optional[str] = None,
        verifying_key: str = "",
        audience: Union[str, Iterable, None] = None,
        issuer: Optional[str] = None,
        jwk_url: Optional[str] = None,
        leeway: Union[float, int, timedelta, None] = None,
        json_encoder: Optional[Type[json.JSONEncoder]] = None,
    ) -> None:
        """
        Store the backend configuration after validating the algorithm.

        :param algorithm: Algorithm name; must be in `ALLOWED_ALGORITHMS`.
        :param signing_key: Key handed to `jwt.encode`. For algorithms whose
            name starts with "HS" it is also the key used to verify tokens.
        :param verifying_key: Key used to verify tokens for non-"HS"
            algorithms when no JWKS client is configured.
        :param audience: If not None, written to the "aud" claim on encode
            and checked on decode.
        :param issuer: If not None, written to the "iss" claim on encode; it
            is always passed to `jwt.decode` as `issuer`.
        :param jwk_url: URL used to build a `PyJWKClient`. Ignored (no client
            is created) when empty or when `JWK_CLIENT_AVAILABLE` is False.
        :param leeway: Clock-skew allowance; see `get_leeway` for the
            accepted types. It is not validated here.
        :param json_encoder: Optional encoder class passed to `jwt.encode`.
        :raises TokenBackendError: If `algorithm` is not usable.
        """
        self._validate_algorithm(algorithm)

        self.algorithm = algorithm
        self.signing_key = signing_key
        self.verifying_key = verifying_key
        self.audience = audience
        self.issuer = issuer

        # PyJWKClient is only bound when the optional import succeeded, so
        # the availability flag has to be checked before the name is used.
        if JWK_CLIENT_AVAILABLE:
            self.jwks_client = PyJWKClient(jwk_url) if jwk_url else None
        else:
            self.jwks_client = None

        self.leeway = leeway
        self.json_encoder = json_encoder

    def _validate_algorithm(self, algorithm: str) -> None:
        """
        Ensure that the nominated algorithm is recognized, and that
        cryptography is installed for those algorithms that require it.

        :raises TokenBackendError: If either check fails.
        """
        if algorithm not in ALLOWED_ALGORITHMS:
            raise TokenBackendError(
                format_lazy(_("Unrecognized algorithm type '{}'"), algorithm)
            )

        if algorithm in algorithms.requires_cryptography and not algorithms.has_crypto:
            raise TokenBackendError(
                format_lazy(
                    _("You must have cryptography installed to use {}."), algorithm
                )
            )

    def get_leeway(self) -> timedelta:
        """
        Return the configured leeway normalized to a `timedelta`.

        None maps to zero seconds, an int or float is taken as a number of
        seconds, and a `timedelta` is returned unchanged.

        :raises TokenBackendError: If `self.leeway` has any other type.
        """
        if self.leeway is None:
            return timedelta(seconds=0)
        elif isinstance(self.leeway, (int, float)):
            return timedelta(seconds=self.leeway)
        elif isinstance(self.leeway, timedelta):
            return self.leeway
        else:
            raise TokenBackendError(
                format_lazy(
                    _(
                        "Unrecognized type '{}', 'leeway' must be of type int, float or timedelta."
                    ),
                    type(self.leeway),
                )
            )

    def get_verifying_key(self, token: "Token") -> Optional[str]:
        """
        Return the key that should be used to verify `token`.

        Resolution order: the signing key for "HS*" algorithms, then the key
        looked up through the JWKS client when one is configured, then the
        static `verifying_key`.

        :raises TokenBackendError: If the JWKS client raises
            `PyJWKClientError` while resolving the key.
        """
        # "HS*" algorithms verify with the same key that signs.
        if self.algorithm.startswith("HS"):
            return self.signing_key

        if self.jwks_client:
            # jwks_client is only ever set when the PyJWKClient import
            # succeeded, so PyJWKClientError is bound whenever we get here.
            try:
                return self.jwks_client.get_signing_key_from_jwt(token).key
            except PyJWKClientError as ex:
                raise TokenBackendError(_("Token is invalid or expired")) from ex

        return self.verifying_key

    def encode(self, payload: Dict[str, Any]) -> str:
        """
        Returns an encoded token for the given payload dictionary.

        The payload is shallow-copied first, so the caller's dictionary is
        not modified when the "aud" / "iss" claims are added.
        """
        jwt_payload = payload.copy()
        if self.audience is not None:
            jwt_payload["aud"] = self.audience
        if self.issuer is not None:
            jwt_payload["iss"] = self.issuer

        token = jwt.encode(
            jwt_payload,
            self.signing_key,
            algorithm=self.algorithm,
            json_encoder=self.json_encoder,
        )
        if isinstance(token, bytes):
            # For PyJWT <= 1.7.1
            return token.decode("utf-8")
        # For PyJWT >= 2.0.0a1
        return token

    def decode(self, token: "Token", verify: bool = True) -> Dict[str, Any]:
        """
        Performs a validation of the given token and returns its payload
        dictionary.

        Raises a `TokenBackendError` if the token is malformed, if its
        signature check fails, or if its 'exp' claim indicates it has expired.

        :param token: The encoded token to decode.
        :param verify: Passed to PyJWT as the "verify_signature" option.
        """
        # NOTE(review): get_verifying_key() is evaluated even when `verify`
        # is False, so a configured JWKS client is still consulted (and may
        # still raise) in that case. Confirm whether that is intended.
        try:
            return jwt.decode(
                token,
                self.get_verifying_key(token),
                algorithms=[self.algorithm],
                audience=self.audience,
                issuer=self.issuer,
                leeway=self.get_leeway(),
                options={
                    # Only check "aud" when an audience is configured.
                    "verify_aud": self.audience is not None,
                    "verify_signature": verify,
                },
            )
        except InvalidAlgorithmError as ex:
            # Handled before the broader InvalidTokenError clause so an
            # algorithm problem gets its own message.
            raise TokenBackendError(_("Invalid algorithm specified")) from ex
        except InvalidTokenError as ex:
            raise TokenBackendError(_("Token is invalid or expired")) from ex