55 lines
1.9 KiB
Python
55 lines
1.9 KiB
Python
from rest_framework import authentication
|
|
from rest_framework import exceptions
|
|
from django.contrib.auth.models import AnonymousUser
|
|
from .models import User, UserToken
|
|
from django.utils import timezone
|
|
from rest_framework.exceptions import APIException
|
|
from rest_framework import status
|
|
|
|
class CustomAuthenticationFailed(APIException):
    """HTTP 401 API error whose response body is ``{"code": ..., "message": ...}``.

    Args:
        code: value placed under the ``"code"`` key of the response body.
        message: human-readable text placed under the ``"message"`` key.

    NOTE(review): ``APIException.__init__`` is deliberately not called, so
    ``detail`` stays a plain dict. Going through the base initializer would
    presumably coerce the values into DRF ``ErrorDetail`` strings, turning an
    integer ``code`` into a string in the JSON response.
    """

    status_code = status.HTTP_401_UNAUTHORIZED

    def __init__(self, code, message):
        # Keep the payload verbatim; DRF renders a dict ``detail`` as-is.
        payload = {"code": code, "message": message}
        self.detail = payload
|
|
|
|
class CustomTokenAuthentication(authentication.BaseAuthentication):
    """Authenticate requests carrying an ``Authorization: Token <key>`` header.

    The key is looked up in ``UserToken`` together with its related ``user``.
    The request is authenticated only when the token's ``expired_at`` lies in
    the future and the user is active. Every failure is raised as
    ``CustomAuthenticationFailed`` with code 401.
    """

    keyword = 'Token'  # authentication scheme expected in the Authorization header

    def authenticate(self, request):
        """Authenticate ``request`` from its ``Authorization`` header.

        Returns:
            ``(user, None)``: the token's user; no auth object is attached.

        Raises:
            CustomAuthenticationFailed: the header is missing or malformed,
                the token is unknown or expired, the user is inactive, or
                the lookup fails for any other reason.
        """
        auth_header = request.META.get('HTTP_AUTHORIZATION')

        # NOTE(review): DRF's convention is to return None when no credentials
        # are sent, so other authenticators and AllowAny views keep working.
        # This class rejects such requests outright; confirm that is intended.
        if not auth_header:
            raise CustomAuthenticationFailed(401, '未提供认证头')

        try:
            parts = auth_header.split()
            # Expect exactly "<scheme> <token>". The scheme is matched
            # case-insensitively, as RFC 7235 requires for auth-schemes
            # (DRF's built-in TokenAuthentication does the same).
            if len(parts) != 2 or parts[0].lower() != self.keyword.lower():
                raise CustomAuthenticationFailed(401, '认证头格式不正确')
            token = parts[1]

            # Fetch the token and its user in one query; expired tokens are
            # excluded by the query filter itself.
            try:
                token_obj = UserToken.objects.select_related('user').get(
                    token=token,
                    expired_at__gt=timezone.now(),
                )
            except UserToken.DoesNotExist:
                raise CustomAuthenticationFailed(401, '无效或过期的token') from None

            if not token_obj.user.is_active:
                raise CustomAuthenticationFailed(401, '用户未激活')

            return (token_obj.user, None)
        except CustomAuthenticationFailed:
            # The specific failures above already carry the right message.
            raise
        except Exception as exc:
            # Any other error (e.g. a database failure) is still reported to
            # the client as a generic 401. The original error is chained as
            # ``__cause__`` so it is not lost for debugging.
            raise CustomAuthenticationFailed(401, '认证失败') from exc

    def authenticate_header(self, request):
        """Return the scheme to advertise in a ``WWW-Authenticate`` header.

        NOTE(review): DRF's APIView attaches this header only to
        ``NotAuthenticated``/``AuthenticationFailed`` exceptions, and
        ``CustomAuthenticationFailed`` derives from ``APIException`` directly.
        Presumably the header is therefore never sent with its 401
        responses; verify.
        """
        return self.keyword