46 lines
1.6 KiB
Python
46 lines
1.6 KiB
Python
![]() |
from django.db import close_old_connections
|
||
|
from rest_framework.authtoken.models import Token
|
||
|
from channels.middleware import BaseMiddleware
|
||
|
from channels.db import database_sync_to_async
|
||
|
from urllib.parse import parse_qs
|
||
|
import logging
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
@database_sync_to_async
def get_user_from_token(token_key):
    """Look up the user that owns the given auth token key.

    The function is wrapped with ``database_sync_to_async`` so the blocking
    ORM query can be awaited from async code.

    Args:
        token_key: The ``key`` value of a ``Token`` row.

    Returns:
        The ``user`` attached to the matching token, or ``None`` when no such
        token exists or the lookup fails for any other reason.
    """
    try:
        # select_related pulls the user in the same query as the token.
        token = Token.objects.select_related('user').get(key=token_key)
    except Token.DoesNotExist:
        return None
    except Exception as e:
        # Fail closed: an unexpected lookup error is treated as "not
        # authenticated". logger.exception keeps the traceback, which the
        # previous logger.error(f"...") call discarded.
        logger.exception("获取用户Token失败: %s", e)
        return None
    return token.user
|
||
|
|
||
|
class TokenAuthMiddleware(BaseMiddleware):
    """Middleware that authenticates a connection from a ``token`` query parameter.

    The token is read from the scope's query string, resolved through
    ``get_user_from_token``, and the result is stored in ``scope['user']``.
    When no token is supplied or the token is not valid, ``scope['user']`` is
    set to ``None``.
    """

    async def __call__(self, scope, receive, send):
        """Populate ``scope['user']`` and delegate to the parent middleware.

        Args:
            scope: Connection scope; ``query_string`` is read from it and
                ``user`` is written to it.
            receive: Receive callable, passed through unchanged.
            send: Send callable, passed through unchanged.

        Returns:
            Whatever ``BaseMiddleware.__call__`` returns.
        """
        # Close old database connections
        close_old_connections()

        # Extract the token from the query string. The bytes come straight
        # from the client, so decode leniently: malformed UTF-8 must lead to
        # a failed authentication, not an unhandled UnicodeDecodeError.
        raw_query = scope.get('query_string', b'')
        query_string = raw_query.decode('utf-8', errors='replace')
        query_params = parse_qs(query_string)
        token_key = query_params.get('token', [''])[0]

        if token_key:
            user = await get_user_from_token(token_key)
            if user:
                scope['user'] = user
                logger.info("WebSocket认证成功: 用户 %s", user.id)
            else:
                # Never write the full client-supplied credential to the log;
                # keep only a short prefix for correlation.
                if len(token_key) > 8:
                    masked_token = token_key[:4] + "***"
                else:
                    masked_token = "***"
                logger.warning("WebSocket认证失败: 无效的Token %s", masked_token)
                scope['user'] = None
        else:
            logger.warning("WebSocket连接未提供Token")
            scope['user'] = None

        return await super().__call__(scope, receive, send)
|
||
|
|
||
|
def TokenAuthMiddlewareStack(inner):
    """Wrap ``inner`` in a ``TokenAuthMiddleware`` and return the wrapper."""
    wrapped = TokenAuthMiddleware(inner)
    return wrapped
|