2025-02-26 21:05:55 +08:00
|
|
|
from channels.middleware import BaseMiddleware
|
|
|
|
from channels.db import database_sync_to_async
|
|
|
|
from django.contrib.auth.models import AnonymousUser
|
|
|
|
from rest_framework.authtoken.models import Token
|
|
|
|
from django.contrib.auth import get_user_model
|
|
|
|
import logging
|
2025-04-29 10:21:13 +08:00
|
|
|
import re
|
|
|
|
from django.middleware.csrf import CsrfViewMiddleware
|
|
|
|
from django.conf import settings
|
|
|
|
from django.utils.deprecation import MiddlewareMixin
|
2025-02-26 21:05:55 +08:00
|
|
|
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(name=__name__)

# The user model configured for this project, resolved once at import time.
User = get_user_model()
|
|
|
|
|
|
|
|
@database_sync_to_async
def get_user(token_key):
    """Look up the user that owns a DRF auth token.

    The ORM query is synchronous; ``database_sync_to_async`` makes the
    function awaitable from async code such as the WebSocket middleware.

    Args:
        token_key: Token key string parsed from the ``Authorization`` header.

    Returns:
        The token's user when the key exists and the user is active;
        otherwise ``AnonymousUser()``. Errors are logged, never raised.
    """
    try:
        # Load the user in the same query; accessing ``token.user`` would
        # otherwise issue a second query.
        token = Token.objects.select_related('user').get(key=token_key)
        user = token.user

        # Mirror DRF's TokenAuthentication: a deactivated account must not
        # authenticate even though its token still exists.
        if not user.is_active:
            logger.warning(f"用户已停用: {user.get_username()}")
            return AnonymousUser()

        # get_username() reads USERNAME_FIELD, so it works for custom user
        # models too; ``user.username`` raises AttributeError on models that
        # lack that field, which would reject every valid token.
        logger.info(f"用户认证成功: {user.get_username()}")
        return user
    except Token.DoesNotExist:
        logger.warning(f"无效的token: {token_key}")
        return AnonymousUser()
    except Exception as e:
        # Best-effort boundary: fall back to anonymous, but keep the traceback.
        logger.exception(f"认证错误: {str(e)}")
        return AnonymousUser()
|
|
|
|
|
|
|
|
class TokenAuthMiddleware(BaseMiddleware):
    """Token authentication middleware for Channels (WebSocket) connections.

    Looks for an ``Authorization: Token <key>`` header in the ASGI scope,
    resolves the key with ``get_user`` and exposes the result as
    ``scope['user']``. Connections without a usable token get
    ``AnonymousUser``.
    """

    async def __call__(self, scope, receive, send):
        # Work on a copy so the ``user`` entry does not leak into the scope
        # dict owned by outer layers of the ASGI stack.
        scope = dict(scope)
        try:
            token_key = self._get_token_key(scope)
            if token_key is None:
                # No Authorization header, or a non-Token scheme.
                scope['user'] = AnonymousUser()
            else:
                scope['user'] = await get_user(token_key)
        except Exception as e:
            # Missing headers or a malformed header value end up here.
            logger.error(f"WebSocket认证错误: {str(e)}")
            scope['user'] = AnonymousUser()

        # Invoke the inner application exactly once, outside the try block.
        # Exceptions raised by the consumer are not authentication failures.
        # They must propagate rather than trigger a second, anonymous run of
        # the application.
        return await super().__call__(scope, receive, send)

    @staticmethod
    def _get_token_key(scope):
        """Return the key from an ``Authorization: Token <key>`` header.

        The scheme keyword is compared case-insensitively, as DRF does.

        Returns:
            The token key string, or None when there is no Authorization
            header or it uses a different scheme.

        Raises:
            KeyError: if the scope has no ``headers`` entry.
            ValueError: if the header value is not valid UTF-8 or is not
                exactly two whitespace-separated parts.
        """
        headers = dict(scope['headers'])
        raw_value = headers.get(b'authorization')
        if raw_value is None:
            return None
        token_name, token_key = raw_value.decode().split()
        if token_name.lower() != 'token':
            return None
        return token_key
|
|
|
|
|
2025-04-29 10:21:13 +08:00
|
|
|
class UserActivityMiddleware(MiddlewareMixin):
    """Placeholder middleware reserved for recording user activity.

    It currently does no work: ``process_request`` returns ``None``, so
    request handling continues unchanged.
    """

    def process_request(self, request):
        # Hook point for user-activity logging; intentionally a no-op for now.
        return None
|
2025-02-26 21:05:55 +08:00
|
|
|
|
2025-04-29 10:21:13 +08:00
|
|
|
class CSRFExemptMiddleware(MiddlewareMixin):
    """Skip CSRF enforcement for request paths listed in ``settings.CSRF_EXEMPT_URLS``.

    How matching works:

    - ``CSRF_EXEMPT_URLS`` is iterated as a collection of regular-expression
      patterns.
    - Each pattern is tried with ``re.match``, which anchors at the start
      only, against the request path with its leading slashes removed.
      For example, ``r'^api/'`` covers ``/api/...``.
    - On the first match the request is flagged with
      ``_dont_enforce_csrf_checks``.

    NOTE(review): the flag only matters if this ``process_view`` runs before
    ``CsrfViewMiddleware``'s. That means this middleware must be listed above
    it in ``MIDDLEWARE``; confirm in settings.
    """

    def process_view(self, request, callback, callback_args, callback_kwargs):
        # Without the setting there is nothing to exempt.
        if not hasattr(settings, 'CSRF_EXEMPT_URLS'):
            return None

        relative_path = request.path_info.lstrip('/')
        matched = any(
            re.match(pattern, relative_path)
            for pattern in settings.CSRF_EXEMPT_URLS
        )
        if matched:
            request._dont_enforce_csrf_checks = True

        # Returning None lets Django continue to the next middleware / view.
        return None
|