"""ASGI/WSGI middleware: token auth for Channels, activity hook, CSRF exemptions."""
import logging
import re

from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from django.middleware.csrf import CsrfViewMiddleware  # noqa: F401  (kept: may be imported from here)
from django.utils.deprecation import MiddlewareMixin
from rest_framework.authtoken.models import Token

logger = logging.getLogger(__name__)
User = get_user_model()  # The user model configured for the current project.


def _mask_token(token_key):
    """Return a log-safe form of *token_key* (short prefix only, never the full credential)."""
    return f"{str(token_key)[:4]}***"


@database_sync_to_async
def get_user(token_key):
    """Resolve a DRF auth-token key to its user.

    Wrapped with ``database_sync_to_async`` so the ORM query runs outside the
    event loop.

    Returns:
        The token's user, or ``AnonymousUser()`` if the key does not exist or
        the lookup fails for any other reason.
    """
    try:
        # select_related fetches the user in the same query as the token.
        token = Token.objects.select_related('user').get(key=token_key)
        # get_username() works for custom user models whose username field
        # is not literally called "username".
        logger.info("用户认证成功: %s", token.user.get_username())
        return token.user
    except Token.DoesNotExist:
        # Do not write the full (possibly valid-elsewhere) credential to logs.
        logger.warning("无效的token: %s", _mask_token(token_key))
        return AnonymousUser()
    except Exception as e:
        logger.error("认证错误: %s", e)
        return AnonymousUser()


class TokenAuthMiddleware(BaseMiddleware):
    """Token authentication middleware for Channels connections.

    Reads an ``Authorization: Token <key>`` header from the ASGI scope, resolves
    it with :func:`get_user`, and stores the result in ``scope['user']`` before
    handing off to the inner application. Missing, malformed, or invalid
    credentials yield ``AnonymousUser``.
    """

    async def __call__(self, scope, receive, send):
        try:
            scope['user'] = await self._authenticate(scope)
        except Exception as e:
            # Authentication is best-effort: a failure here must not drop the
            # connection, it just leaves the user anonymous.
            logger.error("WebSocket认证错误: %s", e)
            scope['user'] = AnonymousUser()
        # Invoke the inner application exactly once, outside the try block, so
        # exceptions raised by the consumer propagate instead of being
        # misreported as auth errors and re-running the application.
        return await super().__call__(scope, receive, send)

    @staticmethod
    async def _authenticate(scope):
        """Return the user identified by the scope's Authorization header, or AnonymousUser."""
        # ASGI headers are a list of (name, value) byte pairs; dict() keeps the
        # last value when a header is repeated.
        headers = dict(scope.get('headers', ()))
        raw_auth = headers.get(b'authorization')
        if raw_auth is None:
            return AnonymousUser()

        parts = raw_auth.decode().split()
        if len(parts) != 2:
            logger.warning("无效的Authorization头格式")
            return AnonymousUser()

        scheme, token_key = parts
        # Case-insensitive scheme match, as DRF's TokenAuthentication does.
        if scheme.lower() != 'token':
            return AnonymousUser()
        return await get_user(token_key)


class UserActivityMiddleware(MiddlewareMixin):
    """Placeholder middleware for recording user activity (currently a no-op)."""

    def process_request(self, request):
        # Hook point for user-activity logging; intentionally does nothing yet.
        pass


class CSRFExemptMiddleware(MiddlewareMixin):
    """Exempt requests whose path matches ``settings.CSRF_EXEMPT_URLS`` from CSRF checks.

    Each entry of ``CSRF_EXEMPT_URLS`` is a regular expression applied with
    :func:`re.match` (anchored at the start) to the request path with its
    leading ``/`` stripped. On a match, ``request._dont_enforce_csrf_checks`` is
    set. Django's ``CsrfViewMiddleware`` checks this flag to skip enforcement.

    NOTE: ``process_view`` hooks run in ``MIDDLEWARE`` order. This middleware
    must therefore be listed before ``CsrfViewMiddleware`` for the flag to take
    effect.
    """

    def process_view(self, request, callback, callback_args, callback_kwargs):
        # A missing setting, or one set to None or empty, means nothing is exempt.
        exempt_patterns = getattr(settings, 'CSRF_EXEMPT_URLS', None)
        if not exempt_patterns:
            return None

        path = request.path_info.lstrip('/')
        if any(re.match(pattern, path) for pattern in exempt_patterns):
            request._dont_enforce_csrf_checks = True
        # Returning None lets normal view processing continue.
        return None