import logging

from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from rest_framework.authtoken.models import Token

logger = logging.getLogger(__name__)
User = get_user_model()  # The project's active user model


def _mask_token(token_key):
    """Return a log-safe rendering of *token_key* (short prefix only)."""
    return f"{str(token_key)[:4]}***"


@database_sync_to_async
def get_user(token_key):
    """Look up the user that owns the auth token *token_key*.

    Runs the ORM query in a worker thread via ``database_sync_to_async`` so it
    can be awaited from async code.

    Args:
        token_key: The key string of a ``rest_framework.authtoken`` Token.

    Returns:
        The token's user when the token exists and the user is active;
        otherwise an ``AnonymousUser`` instance. Never raises: lookup errors
        are logged and mapped to ``AnonymousUser``.
    """
    try:
        # select_related avoids a second query when reading token.user.
        token = Token.objects.select_related("user").get(key=token_key)
    except Token.DoesNotExist:
        # Never write the full credential to the log; a prefix is enough
        # to correlate entries.
        logger.warning("无效的token: %s", _mask_token(token_key))
        return AnonymousUser()
    except Exception as e:
        # Best-effort boundary: a lookup failure must not crash the handshake.
        logger.error("认证错误: %s", e)
        return AnonymousUser()

    user = token.user
    # Deactivated accounts must not authenticate, matching the check that
    # DRF's TokenAuthentication performs.
    if not user.is_active:
        logger.warning("用户已被禁用: %s", user.get_username())
        return AnonymousUser()

    # get_username() also works for custom user models whose USERNAME_FIELD
    # is not called "username".
    logger.info("用户认证成功: %s", user.get_username())
    return user


class TokenAuthMiddleware(BaseMiddleware):
    """Channels middleware that authenticates a connection by DRF token.

    Reads an ``Authorization: Token <key>`` header from the ASGI scope and
    stores the resolved user (or ``AnonymousUser``) in ``scope['user']``
    before delegating to the wrapped application.
    """

    async def __call__(self, scope, receive, send):
        """Populate ``scope['user']`` and invoke the inner application once.

        Exceptions raised by the inner application propagate to the caller.
        They are deliberately not caught here, so that a consumer error is
        neither mislabelled as an authentication error nor causes the inner
        application to be run a second time.
        """
        scope['user'] = await self._resolve_user(scope)
        return await super().__call__(scope, receive, send)

    async def _resolve_user(self, scope):
        """Return the user for *scope*'s Authorization header, else anonymous."""
        try:
            # Header names and values in the ASGI scope are byte strings.
            headers = dict(scope.get('headers') or ())
            raw_value = headers.get(b'authorization')
            if raw_value is None:
                return AnonymousUser()

            # Expect exactly "<scheme> <key>"; any other shape raises
            # ValueError on unpacking and is handled below.
            token_name, token_key = raw_value.decode().split()
            if token_name != 'Token':
                return AnonymousUser()

            return await get_user(token_key)
        except Exception as e:
            # Boundary handler: a malformed header must degrade to an
            # anonymous connection rather than abort the handshake.
            logger.error("WebSocket认证错误: %s", e)
            return AnonymousUser()


class UserActivityMiddleware:
    """Django HTTP middleware placeholder for user-activity handling.

    Currently a pass-through: it forwards the request to the next handler and
    returns that handler's response unchanged.
    """

    def __init__(self, get_response):
        """Store the next handler in the middleware chain."""
        self.get_response = get_response

    def __call__(self, request):
        """Return the downstream response for *request* unmodified."""
        return self.get_response(request)