# daren/apps/common/middlewares.py
# 2025-05-23 19:25:35 +08:00
#
# 46 lines
# 1.6 KiB
# Python

from django.db import close_old_connections
from rest_framework.authtoken.models import Token
from channels.middleware import BaseMiddleware
from channels.db import database_sync_to_async
from urllib.parse import parse_qs
import logging
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(name=__name__)
@database_sync_to_async
def get_user_from_token(token_key):
    """Resolve a DRF auth token key to the user that owns it.

    Wrapped with ``database_sync_to_async`` so the synchronous ORM lookup can
    be awaited from async middleware.

    Args:
        token_key: Raw token key, e.g. taken from the connection's query string.

    Returns:
        The user that owns the token, or ``None`` when the token does not
        exist, the owning user is inactive, or the lookup fails for any other
        reason (the failure is logged with its traceback).
    """
    try:
        token = Token.objects.select_related('user').get(key=token_key)
    except Token.DoesNotExist:
        return None
    except Exception as e:
        # Boundary handler: a lookup failure must not crash the connection,
        # but keep the traceback in the log for diagnosis.
        logger.exception(f"获取用户Token失败: {str(e)}")
        return None

    user = token.user
    # Mirror DRF's TokenAuthentication: tokens belonging to deactivated
    # accounts must not authenticate.
    if not user.is_active:
        return None
    return user
class TokenAuthMiddleware(BaseMiddleware):
    """Channels middleware that authenticates a connection via ``?token=<key>``.

    Sets ``scope['user']`` to the token owner's user object, or to ``None``
    when no token is supplied or the token does not resolve to a user.
    """

    async def __call__(self, scope, receive, send):
        # Work on a copy so the added 'user' key does not leak into the scope
        # dict held by outer layers.
        scope = dict(scope)

        # No close_old_connections() here: database_sync_to_async already
        # closes stale connections around each call in its worker thread.
        # Calling it from the event loop is redundant and can raise
        # SynchronousOnlyOperation if a connection needs closing.

        # Extract the token from the query string. errors='replace' keeps a
        # malformed (non-UTF-8) query string from raising UnicodeDecodeError.
        query_string = scope.get('query_string', b'').decode('utf-8', errors='replace')
        query_params = parse_qs(query_string)
        token_key = query_params.get('token', [''])[0]

        user = None
        if token_key:
            user = await get_user_from_token(token_key)
            if user:
                logger.info(f"WebSocket认证成功: 用户 {user.id}")
            else:
                # Never write the full credential to the logs; a short prefix
                # is enough to correlate with client-side reports.
                logger.warning(f"WebSocket认证失败: 无效的Token {token_key[:4]}...")
                user = None
        else:
            logger.warning("WebSocket连接未提供Token")

        scope['user'] = user
        return await super().__call__(scope, receive, send)
def TokenAuthMiddlewareStack(inner):
    """Return *inner* wrapped in ``TokenAuthMiddleware``.

    Use this in the ASGI routing setup in the same way as Channels'
    ``AuthMiddlewareStack``.
    """
    wrapped_app = TokenAuthMiddleware(inner)
    return wrapped_app