193 lines
5.9 KiB
Python
193 lines
5.9 KiB
Python
from typing import Any, Dict, Optional, Type, TypeVar
|
|
|
|
from django.conf import settings
|
|
from django.contrib.auth import authenticate, get_user_model
|
|
from django.contrib.auth.models import AbstractBaseUser, update_last_login
|
|
from django.utils.translation import gettext_lazy as _
|
|
from rest_framework import exceptions, serializers
|
|
from rest_framework.exceptions import AuthenticationFailed, ValidationError
|
|
|
|
from .models import TokenUser
|
|
from .settings import api_settings
|
|
from .tokens import RefreshToken, SlidingToken, Token, UntypedToken
|
|
|
|
# Constrained type variable for the user objects accepted by
# `TokenObtainSerializer.get_token`: either a Django auth user instance
# (any `AbstractBaseUser` subclass) or the local `TokenUser`.
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)

# The blacklist model is only imported when BLACKLIST_AFTER_ROTATION is set
# at import time. `TokenVerifySerializer.validate` refers to
# `BlacklistedToken` only under a condition that also checks this setting
# (and that the blacklist app is in INSTALLED_APPS).
if api_settings.BLACKLIST_AFTER_ROTATION:
    from .token_blacklist.models import BlacklistedToken
|
|
|
|
|
|
class PasswordField(serializers.CharField):
    """Write-only ``CharField`` whose ``style`` requests a password input.

    ``write_only`` is always forced to ``True`` and ``style["input_type"]``
    is always forced to ``"password"``. Any other ``style`` keys the caller
    supplies are kept.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Build a new dict instead of mutating the caller's: a ``style``
        # mapping may be shared between several field declarations. Also
        # tolerate an explicit ``style=None``, which used to raise TypeError.
        style = dict(kwargs.get("style") or {})
        style["input_type"] = "password"
        kwargs["style"] = style

        kwargs["write_only"] = True

        super().__init__(*args, **kwargs)
|
|
|
|
|
|
class TokenObtainSerializer(serializers.Serializer):
    """Base serializer that authenticates credentials before issuing tokens.

    Each instance gets two write-only inputs: the user model's
    ``USERNAME_FIELD`` and ``password``. Subclasses choose ``token_class`` and
    extend ``validate`` to put the issued token(s) into the returned dict.
    """

    username_field = get_user_model().USERNAME_FIELD
    token_class: Optional[Type[Token]] = None

    default_error_messages = {
        "no_active_account": _("No active account found with the given credentials")
    }

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # The username field's name comes from the configured user model,
        # so both inputs are attached per instance rather than declared.
        self.fields[self.username_field] = serializers.CharField(write_only=True)
        self.fields["password"] = PasswordField()

    def validate(self, attrs: Dict[str, Any]) -> Dict[Any, Any]:
        """Authenticate the submitted credentials and remember the user.

        The result of ``authenticate`` is stored on ``self.user`` so that
        subclasses can issue tokens for it.

        Returns:
            An empty dict that subclasses populate.

        Raises:
            AuthenticationFailed: when ``USER_AUTHENTICATION_RULE`` rejects
                ``self.user``. Django's ``authenticate`` returns ``None`` for
                rejected credentials, so that case is decided by the rule too.
        """
        credentials = {
            self.username_field: attrs[self.username_field],
            "password": attrs["password"],
        }

        # Only forward the request to the auth backends when the caller put
        # one in the serializer context.
        if "request" in self.context:
            credentials["request"] = self.context["request"]

        self.user = authenticate(**credentials)

        user_is_allowed = api_settings.USER_AUTHENTICATION_RULE(self.user)
        if user_is_allowed:
            return {}

        raise exceptions.AuthenticationFailed(
            self.error_messages["no_active_account"],
            "no_active_account",
        )

    @classmethod
    def get_token(cls, user: AuthUser) -> Token:
        """Build a token of type ``cls.token_class`` for ``user``."""
        # ``token_class`` is None on this base class; concrete subclasses
        # are expected to set it before this is called.
        return cls.token_class.for_user(user)  # type: ignore
|
|
|
|
|
|
class TokenObtainPairSerializer(TokenObtainSerializer):
    """Authenticate credentials and answer with a refresh/access token pair."""

    token_class = RefreshToken

    def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
        """Return the base result extended with ``refresh`` and ``access``.

        When ``UPDATE_LAST_LOGIN`` is enabled, Django's ``update_last_login``
        is also called for the authenticated user.
        """
        payload = super().validate(attrs)
        refresh_token = self.get_token(self.user)

        # Keyword order keeps the original evaluation order: the refresh
        # token is stringified before its access token is derived.
        payload.update(
            refresh=str(refresh_token),
            access=str(refresh_token.access_token),
        )

        if api_settings.UPDATE_LAST_LOGIN:
            update_last_login(None, self.user)

        return payload
|
|
|
|
|
|
class TokenObtainSlidingSerializer(TokenObtainSerializer):
    """Authenticate credentials and answer with a single sliding token."""

    token_class = SlidingToken

    def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
        """Return the base result extended with a ``token`` entry.

        When ``UPDATE_LAST_LOGIN`` is enabled, Django's ``update_last_login``
        is also called for the authenticated user.
        """
        result = super().validate(attrs)
        result["token"] = str(self.get_token(self.user))

        if api_settings.UPDATE_LAST_LOGIN:
            update_last_login(None, self.user)

        return result
|
|
|
|
|
|
class TokenRefreshSerializer(serializers.Serializer):
    """Exchange a refresh token for a new access token.

    When ``ROTATE_REFRESH_TOKENS`` is enabled, the response also carries a
    new refresh token. If ``BLACKLIST_AFTER_ROTATION`` is set as well, the
    submitted token is blacklisted first, provided the token supports it.
    """

    refresh = serializers.CharField()
    access = serializers.CharField(read_only=True)
    token_class = RefreshToken

    default_error_messages = {
        "no_active_account": _("No active account found for the given token.")
    }

    def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
        """Validate ``attrs["refresh"]`` and return the new token(s).

        Returns:
            ``{"access": ...}``, plus ``"refresh"`` when rotation is enabled.

        Raises:
            AuthenticationFailed: if the token names a user that no longer
                exists, or one rejected by ``USER_AUTHENTICATION_RULE``.
        """
        refresh = self.token_class(attrs["refresh"])

        user_id = refresh.payload.get(api_settings.USER_ID_CLAIM, None)
        if user_id:
            # Use filter().first() rather than get(). A user deleted after
            # the token was issued (or a non-unique USER_ID_FIELD) must not
            # surface as an unhandled DoesNotExist/MultipleObjectsReturned,
            # i.e. a server error.
            user = (
                get_user_model()
                .objects.filter(**{api_settings.USER_ID_FIELD: user_id})
                .first()
            )
            if user is None or not api_settings.USER_AUTHENTICATION_RULE(user):
                raise AuthenticationFailed(
                    self.error_messages["no_active_account"],
                    "no_active_account",
                )

        data = {"access": str(refresh.access_token)}

        if api_settings.ROTATE_REFRESH_TOKENS:
            if api_settings.BLACKLIST_AFTER_ROTATION:
                # `blacklist` is only present when the blacklist app is
                # installed. Look it up explicitly so that an AttributeError
                # raised *inside* blacklist() is not silently swallowed,
                # which would leave the old token un-revoked.
                blacklist = getattr(refresh, "blacklist", None)
                if blacklist is not None:
                    blacklist()

            # Turn the same token object into the rotated refresh token:
            # new jti, expiry and issued-at.
            refresh.set_jti()
            refresh.set_exp()
            refresh.set_iat()

            data["refresh"] = str(refresh)

        return data
|
|
|
|
|
|
class TokenRefreshSlidingSerializer(serializers.Serializer):
    """Renew a sliding token whose refresh window has not yet closed."""

    token = serializers.CharField()
    token_class = SlidingToken

    def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
        """Return ``{"token": ...}`` with refreshed ``exp`` and ``iat`` claims.

        ``check_exp`` is first run against the
        ``SLIDING_TOKEN_REFRESH_EXP_CLAIM`` claim. The intent is to refuse
        tokens whose refresh deadline has already passed.
        """
        sliding = self.token_class(attrs["token"])

        sliding.check_exp(api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM)

        # Slide the window forward.
        sliding.set_exp()
        sliding.set_iat()

        renewed = str(sliding)
        return {"token": renewed}
|
|
|
|
|
|
class TokenVerifySerializer(serializers.Serializer):
    """Check a token of any type and, if blacklisting is active, its revocation."""

    token = serializers.CharField(write_only=True)

    def validate(self, attrs: Dict[str, Any]) -> Dict[Any, Any]:
        """Validate ``attrs["token"]`` and return an empty dict on success.

        Constructing ``UntypedToken`` is presumably where signature/expiry
        checks happen; any error it raises propagates unchanged. If
        ``BLACKLIST_AFTER_ROTATION`` is on and the blacklist app is installed,
        a token whose ``jti`` has been blacklisted is also rejected.

        Raises:
            ValidationError: if the token is blacklisted.
        """
        token = UntypedToken(attrs["token"])

        if (
            api_settings.BLACKLIST_AFTER_ROTATION
            and "rest_framework_simplejwt.token_blacklist" in settings.INSTALLED_APPS
        ):
            # Import here instead of relying on the module-level import. That
            # import only happens if the setting was already true at import
            # time, so a later change to api_settings (e.g. when reloaded in
            # tests) would otherwise leave this name unbound.
            from .token_blacklist.models import BlacklistedToken

            jti = token.get(api_settings.JTI_CLAIM)
            if BlacklistedToken.objects.filter(token__jti=jti).exists():
                raise ValidationError("Token is blacklisted")

        return {}
|
|
|
|
|
|
class TokenBlacklistSerializer(serializers.Serializer):
    """Blacklist a submitted refresh token when blacklisting is available."""

    refresh = serializers.CharField(write_only=True)
    token_class = RefreshToken

    def validate(self, attrs: Dict[str, Any]) -> Dict[Any, Any]:
        """Parse ``attrs["refresh"]`` and blacklist it if the token supports it.

        Returns:
            An empty dict; the token field is write-only.
        """
        refresh = self.token_class(attrs["refresh"])

        # `blacklist` is absent when the blacklist app is not installed, and
        # that case stays a no-op. Checking for the method explicitly, rather
        # than catching AttributeError around the call, keeps a real failure
        # inside blacklist() from being reported as a successful revocation.
        blacklist = getattr(refresh, "blacklist", None)
        if blacklist is not None:
            blacklist()

        return {}
|