62 lines
2.2 KiB
Python
62 lines
2.2 KiB
Python
# apps/notification/consumers.py
|
||
import json
import logging
from urllib.parse import parse_qs

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from rest_framework.authtoken.models import Token
|
||
|
||
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)
|
||
|
||
class NotificationConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer that pushes per-user notifications.

    Each authenticated connection joins the channel-layer group
    ``notification_user_<user id>``. Events dispatched to the ``notification``
    handler are forwarded to the client as JSON text frames.
    """

    async def connect(self):
        """Authenticate the client and subscribe it to its notification group.

        Authentication is attempted from, in order:

        1. a DRF auth token passed as ``?token=<key>`` in the query string;
        2. ``scope["user"]``, if some auth middleware populated it.

        If neither yields a user, ``close()`` is called before ``accept()``,
        so the handshake is rejected.
        """
        token_key = self._get_query_param("token")

        if token_key:
            # Authenticate using the token from the query string.
            self.user = await self.get_user_from_token(token_key)
            if not self.user:
                # Never log the token value itself: it is a bearer credential.
                logger.error("Invalid token in WebSocket connection")
                await self.close()
                return
        else:
            # Fall back to the user placed in the scope (if authenticated).
            self.user = self.scope.get("user")
            if not self.user or not self.user.is_authenticated:
                logger.error("No valid authentication in WebSocket connection")
                await self.close()
                return

        logger.info("WebSocket connected for user: %s", self.user.id)
        self.group_name = f"notification_user_{self.user.id}"
        await self.channel_layer.group_add(self.group_name, self.channel_name)
        await self.accept()

    def _get_query_param(self, name):
        """Return the first value of query-string parameter *name*, or None.

        ``parse_qs`` URL-decodes values. Unlike a naive ``split("=")``, it
        does not raise when a value itself contains ``=``. Blank values
        (``?token=``) are dropped, so they yield None.
        """
        raw = self.scope.get("query_string", b"").decode()
        values = parse_qs(raw).get(name)
        return values[0] if values else None

    @database_sync_to_async
    def get_user_from_token(self, token_key):
        """Return the active user owning DRF auth token *token_key*, or None.

        Wrapped with ``database_sync_to_async`` so the synchronous ORM query
        can be awaited from ``connect``. Returns None when the token does not
        exist, when its user is inactive, or when the lookup fails
        unexpectedly (the error is logged with a traceback).
        """
        try:
            token = Token.objects.select_related("user").get(key=token_key)
        except Token.DoesNotExist:
            return None
        except Exception as e:
            # Best-effort: any lookup failure rejects the connection instead of
            # crashing the handshake; keep the traceback for diagnosis.
            logger.exception("Error authenticating token: %s", e)
            return None

        # Mirror DRF's TokenAuthentication, which refuses inactive users.
        if not token.user.is_active:
            return None
        return token.user

    async def disconnect(self, close_code):
        """Leave the notification group, if ``connect`` got far enough to join it."""
        logger.info("WebSocket disconnected with code: %s", close_code)
        # group_name is only set after successful authentication in connect().
        if hasattr(self, "group_name"):
            await self.channel_layer.group_discard(self.group_name, self.channel_name)

    async def notification(self, event):
        """Handle notification events.

        Sends ``event["data"]`` to the client as a JSON text frame.
        """
        await self.send(text_data=json.dumps(event["data"]))
|