"""Token-based authentication middleware for Channels WebSocket connections.

Clients pass a DRF auth token as the ``token`` query-string parameter
(e.g. ``ws://host/path/?token=<key>``). The resolved user, or ``None``,
is stored in ``scope['user']`` for downstream consumers.
"""
import logging
from urllib.parse import parse_qs

# close_old_connections is no longer called here (see TokenAuthMiddleware.__call__);
# the import is kept so any external code importing it from this module keeps working.
from django.db import close_old_connections  # noqa: F401
from rest_framework.authtoken.models import Token
from channels.middleware import BaseMiddleware
from channels.db import database_sync_to_async

logger = logging.getLogger(__name__)


def _mask_token(token_key):
    """Return a redacted form of *token_key* that is safe to write to logs."""
    return f"{token_key[:4]}...(len={len(token_key)})"


def _extract_token(scope):
    """Return the first ``token`` query-string value in *scope*, or ``''``."""
    raw_query = scope.get('query_string', b'')
    # Tolerate non-UTF-8 bytes instead of raising UnicodeDecodeError and
    # tearing down the connection.
    query_params = parse_qs(raw_query.decode('utf-8', errors='replace'))
    return query_params.get('token', [''])[0]


@database_sync_to_async
def get_user_from_token(token_key):
    """Look up the user that owns the DRF auth token *token_key*.

    Runs in a worker thread via ``database_sync_to_async``, which also
    closes stale database connections around the call.

    Returns:
        The token's user, or ``None`` if the token does not exist, the
        user is inactive, or the lookup fails unexpectedly (the error is
        logged with its traceback).
    """
    try:
        token = Token.objects.select_related('user').get(key=token_key)
    except Token.DoesNotExist:
        return None
    except Exception as e:
        # Deliberate best-effort: a lookup failure rejects the connection
        # rather than crashing the middleware.
        logger.exception("获取用户Token失败: %s", e)
        return None
    user = token.user
    # Mirror DRF's TokenAuthentication, which refuses inactive users.
    if not getattr(user, 'is_active', True):
        return None
    return user


class TokenAuthMiddleware(BaseMiddleware):
    """Populate ``scope['user']`` from a ``?token=`` query parameter.

    ``scope['user']`` is set to the authenticated user, or ``None`` when no
    token is supplied or it does not resolve to an active user.
    NOTE(review): consumers must handle ``None`` (not ``AnonymousUser``);
    confirm against callers before changing this contract.
    """

    async def __call__(self, scope, receive, send):
        # Work on a copy so adding 'user' does not leak into the caller's scope.
        scope = dict(scope)
        # close_old_connections() is intentionally not called here. Django's
        # connection handling is sync-only and should not be driven from the
        # event loop, and database_sync_to_async already cleans up connections
        # around get_user_from_token.
        token_key = _extract_token(scope)
        user = None
        if not token_key:
            logger.warning("WebSocket连接未提供Token")
        else:
            user = await get_user_from_token(token_key)
            if user:
                logger.info("WebSocket认证成功: 用户 %s", user.id)
            else:
                # Never log the full token: it is a reusable credential.
                logger.warning("WebSocket认证失败: 无效的Token %s", _mask_token(token_key))
                user = None
        scope['user'] = user
        return await super().__call__(scope, receive, send)


def TokenAuthMiddlewareStack(inner):
    """Wrap the ASGI application *inner* in :class:`TokenAuthMiddleware`."""
    return TokenAuthMiddleware(inner)