# daren_project/venv/Lib/site-packages/rest_framework_simplejwt/serializers.py
#
# 193 lines
# 5.9 KiB
# Python

from typing import Any, Dict, Optional, Type, TypeVar
from django.conf import settings
from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth.models import AbstractBaseUser, update_last_login
from django.utils.translation import gettext_lazy as _
from rest_framework import exceptions, serializers
from rest_framework.exceptions import AuthenticationFailed, ValidationError
from .models import TokenUser
from .settings import api_settings
from .tokens import RefreshToken, SlidingToken, Token, UntypedToken
# Type variable constrained to the two user classes imported above
# (``AbstractBaseUser`` and ``TokenUser``); it annotates the ``user``
# argument of ``TokenObtainSerializer.get_token``.
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
if api_settings.BLACKLIST_AFTER_ROTATION:
from .token_blacklist.models import BlacklistedToken
class PasswordField(serializers.CharField):
    """Character field preconfigured for password entry.

    Whatever the caller passes, the field ends up write-only and its
    ``style`` mapping carries ``input_type="password"``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Force the password-specific options, then defer to ``CharField``."""
        # ``setdefault`` hands back the caller's style mapping when one was
        # given, otherwise the freshly inserted empty dict.
        style = kwargs.setdefault("style", {})
        style["input_type"] = "password"
        kwargs.update(write_only=True)
        super().__init__(*args, **kwargs)
class TokenObtainSerializer(serializers.Serializer):
    """Base serializer that authenticates credentials before a token is issued.

    It declares a username field (named after the user model's
    ``USERNAME_FIELD``) and a password field, authenticates them in
    ``validate`` and stores the result on ``self.user``.  ``token_class``
    is ``None`` here; ``get_token`` relies on it being set to a token class.
    """

    # Resolved once, when the class body is executed.
    username_field = get_user_model().USERNAME_FIELD
    token_class: Optional[Type[Token]] = None

    default_error_messages = {
        "no_active_account": _("No active account found with the given credentials")
    }

    def __init__(self, *args, **kwargs) -> None:
        """Register the username and password fields on the instance."""
        super().__init__(*args, **kwargs)

        # Insertion order is kept: username field first, then password.
        self.fields[self.username_field] = serializers.CharField(write_only=True)
        self.fields["password"] = PasswordField()

    def validate(self, attrs: Dict[str, Any]) -> Dict[Any, Any]:
        """Authenticate the submitted credentials.

        Sets ``self.user`` to whatever ``authenticate`` returns and raises
        ``AuthenticationFailed`` (code ``no_active_account``) when the
        configured ``USER_AUTHENTICATION_RULE`` rejects it.

        Returns:
            An empty dict.
        """
        credentials = {
            self.username_field: attrs[self.username_field],
            "password": attrs["password"],
        }
        # The request is forwarded to ``authenticate`` only when the
        # serializer context actually contains one.
        if "request" in self.context:
            credentials["request"] = self.context["request"]

        self.user = authenticate(**credentials)

        if api_settings.USER_AUTHENTICATION_RULE(self.user):
            return {}

        raise exceptions.AuthenticationFailed(
            self.error_messages["no_active_account"],
            "no_active_account",
        )

    @classmethod
    def get_token(cls, user: AuthUser) -> Token:
        """Return ``cls.token_class.for_user(user)``."""
        return cls.token_class.for_user(user)  # type: ignore
class TokenObtainPairSerializer(TokenObtainSerializer):
    """Authenticate credentials and return a refresh/access token pair."""

    token_class = RefreshToken

    def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
        """Run the parent authentication, then add both serialized tokens.

        Returns:
            The parent's result extended with ``"refresh"`` and ``"access"``
            string entries.
        """
        result = super().validate(attrs)

        refresh_token = self.get_token(self.user)
        result["refresh"] = str(refresh_token)
        result["access"] = str(refresh_token.access_token)

        # Stamping the last-login time is opt-in via settings.
        if api_settings.UPDATE_LAST_LOGIN:
            update_last_login(None, self.user)

        return result
class TokenObtainSlidingSerializer(TokenObtainSerializer):
    """Authenticate credentials and return a single sliding token."""

    token_class = SlidingToken

    def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
        """Run the parent authentication, then add the serialized token.

        Returns:
            The parent's result extended with a ``"token"`` string entry.
        """
        result = super().validate(attrs)

        sliding_token = self.get_token(self.user)
        result["token"] = str(sliding_token)

        # Stamping the last-login time is opt-in via settings.
        if api_settings.UPDATE_LAST_LOGIN:
            update_last_login(None, self.user)

        return result
class TokenRefreshSerializer(serializers.Serializer):
    """Exchange a refresh token for a new access token.

    When ``ROTATE_REFRESH_TOKENS`` is enabled the response also carries a
    re-stamped refresh token, and the submitted one is blacklisted if
    ``BLACKLIST_AFTER_ROTATION`` is enabled and the token supports it.
    """

    refresh = serializers.CharField()
    access = serializers.CharField(read_only=True)
    token_class = RefreshToken

    default_error_messages = {
        "no_active_account": _("No active account found for the given token.")
    }

    def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
        """Validate the refresh token and build the response payload.

        Args:
            attrs: Validated field data; ``attrs["refresh"]`` is the encoded
                refresh token.

        Returns:
            A dict with an ``"access"`` entry and, when rotation is enabled,
            a ``"refresh"`` entry.

        Raises:
            AuthenticationFailed: With code ``no_active_account`` when the
                user named in the token no longer exists or is rejected by
                ``USER_AUTHENTICATION_RULE``.
        """
        refresh = self.token_class(attrs["refresh"])

        user_id = refresh.payload.get(api_settings.USER_ID_CLAIM, None)
        if user_id:
            user_model = get_user_model()
            try:
                user = user_model.objects.get(
                    **{api_settings.USER_ID_FIELD: user_id}
                )
            except user_model.DoesNotExist:
                # The account was deleted after the token was issued. Report
                # it like any other unusable account instead of letting the
                # lookup error escape.
                user = None

            # ``user is None`` is tested first so the configured rule is
            # never called without a user object.
            if user is None or not api_settings.USER_AUTHENTICATION_RULE(user):
                raise AuthenticationFailed(
                    self.error_messages["no_active_account"],
                    "no_active_account",
                )

        data = {"access": str(refresh.access_token)}

        if api_settings.ROTATE_REFRESH_TOKENS:
            if api_settings.BLACKLIST_AFTER_ROTATION:
                try:
                    # Attempt to blacklist the given refresh token
                    refresh.blacklist()
                except AttributeError:
                    # If blacklist app not installed, `blacklist` method will
                    # not be present
                    pass

            # Re-stamp the token so the rotated refresh token gets a new
            # identifier, expiry and issue time.
            refresh.set_jti()
            refresh.set_exp()
            refresh.set_iat()

            data["refresh"] = str(refresh)

        return data
class TokenRefreshSlidingSerializer(serializers.Serializer):
    """Extend the lifetime of a sliding token that is still refreshable."""

    token = serializers.CharField()
    token_class = SlidingToken

    def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
        """Check the refresh window and return the re-stamped token.

        Returns:
            A dict whose ``"token"`` entry is the re-encoded token.
        """
        sliding = self.token_class(attrs["token"])

        # The "refresh_exp" claim must not have passed yet.
        sliding.check_exp(api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM)

        # Refresh the "exp" and "iat" claims before re-encoding.
        sliding.set_exp()
        sliding.set_iat()

        encoded = str(sliding)
        return {"token": encoded}
class TokenVerifySerializer(serializers.Serializer):
    """Verify a token and, when blacklisting is active, reject blacklisted ones."""

    token = serializers.CharField(write_only=True)

    def validate(self, attrs: Dict[str, Any]) -> Dict[Any, Any]:
        """Decode the token and check it against the blacklist if enabled.

        Args:
            attrs: Validated field data; ``attrs["token"]`` is the encoded
                token.

        Returns:
            An empty dict.

        Raises:
            ValidationError: When the blacklist check is active and the
                token's JTI belongs to a blacklisted token.
        """
        token = UntypedToken(attrs["token"])

        if (
            api_settings.BLACKLIST_AFTER_ROTATION
            and "rest_framework_simplejwt.token_blacklist" in settings.INSTALLED_APPS
        ):
            # Imported here because the module-level import only runs when
            # the setting was already enabled at import time, while this
            # condition is re-evaluated on every call. A local import keeps
            # the name bound whenever this branch executes.
            from .token_blacklist.models import BlacklistedToken

            jti = token.get(api_settings.JTI_CLAIM)
            if BlacklistedToken.objects.filter(token__jti=jti).exists():
                raise ValidationError("Token is blacklisted")

        return {}
class TokenBlacklistSerializer(serializers.Serializer):
    """Blacklist a submitted refresh token on a best-effort basis."""

    refresh = serializers.CharField(write_only=True)
    token_class = RefreshToken

    def validate(self, attrs: Dict[str, Any]) -> Dict[Any, Any]:
        """Decode the refresh token and try to blacklist it.

        Returns:
            An empty dict.
        """
        submitted = self.token_class(attrs["refresh"])
        try:
            submitted.blacklist()
        except AttributeError:
            # Deliberately ignored: an AttributeError raised by the
            # ``blacklist()`` call leaves nothing to record.
            pass
        return {}