# apps/notification/consumers.py
"""WebSocket consumer that pushes per-user notifications to clients."""
import json
import logging
from urllib.parse import parse_qs

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from rest_framework.authtoken.models import Token

logger = logging.getLogger(__name__)


class NotificationConsumer(AsyncWebsocketConsumer):
    """Deliver notifications to an authenticated user over a WebSocket.

    Clients authenticate in one of two ways:

    * with a DRF token passed as the ``token`` query-string parameter
      (``...?token=<key>``), or
    * when no token is given, with the user already stored in
      ``scope['user']`` (e.g. by an authentication middleware, if one is
      configured).

    Each accepted connection joins the channel-layer group
    ``notification_user_<user id>``. Group messages whose ``type`` is
    ``"notification"`` are dispatched to :meth:`notification` and forwarded
    to the client as JSON.
    """

    async def connect(self):
        """Authenticate the client, join its per-user group, then accept.

        The handshake is rejected (``close()`` is called before ``accept()``)
        in two cases:

        * a token was supplied but does not resolve to an active user, or
        * no token was supplied and the scope has no authenticated user.
        """
        token_key = self._extract_token(self.scope.get('query_string', b''))

        if token_key:
            # Token-based authentication.
            self.user = await self.get_user_from_token(token_key)
            if not self.user:
                # Do not log the token value itself: it is a credential.
                logger.error("Invalid token in WebSocket connection")
                await self.close()
                return
        else:
            # Fall back to the user in the scope (if any auth is configured).
            self.user = self.scope.get('user')
            if not self.user or not self.user.is_authenticated:
                logger.error("No valid authentication in WebSocket connection")
                await self.close()
                return

        logger.info("WebSocket connected for user: %s", self.user.id)

        self.group_name = f"notification_user_{self.user.id}"
        await self.channel_layer.group_add(
            self.group_name,
            self.channel_name,
        )
        await self.accept()

    @staticmethod
    def _extract_token(query_string):
        """Return the ``token`` query parameter, or ``None`` if absent or blank.

        Args:
            query_string: The raw ASGI query string, as ``bytes`` or ``str``.

        Uses ``parse_qs`` so that values containing ``=`` or percent-escapes
        are handled. A naive ``split('=')`` raises ``ValueError`` for
        ``a=b=c`` and would abort the handshake. Blank values (``token=``) are
        dropped, so the caller falls back to scope authentication. If
        ``token`` is repeated, the last occurrence wins.
        """
        if isinstance(query_string, bytes):
            query_string = query_string.decode('utf-8', errors='replace')
        values = parse_qs(query_string).get('token')
        return values[-1] if values else None

    @database_sync_to_async
    def get_user_from_token(self, token_key):
        """Resolve a DRF token key to its user, or ``None`` if not usable.

        Wrapped with ``database_sync_to_async`` so the synchronous ORM query
        runs outside the event loop.

        Returns ``None`` in three cases:

        * the token does not exist,
        * the owning user is inactive, or
        * the lookup raised an unexpected error.
        """
        try:
            token = Token.objects.select_related('user').get(key=token_key)
        except Token.DoesNotExist:
            return None
        except Exception:
            # Best-effort: treat unexpected errors as "not authenticated",
            # but keep the traceback for diagnosis.
            logger.exception("Error authenticating token")
            return None

        user = token.user
        # Match DRF's TokenAuthentication, which rejects inactive users.
        if not user.is_active:
            return None
        return user

    async def disconnect(self, close_code):
        """Leave the per-user group, if ``connect`` got far enough to join one."""
        logger.info("WebSocket disconnected with code: %s", close_code)
        # group_name is only set once authentication succeeded in connect().
        if hasattr(self, 'group_name'):
            await self.channel_layer.group_discard(
                self.group_name,
                self.channel_name,
            )

    async def notification(self, event):
        """Handle a ``notification`` group message by sending its payload.

        Forwards ``event['data']`` to the client as a JSON text frame. The
        payload is assumed to be JSON-serializable; confirm this against the
        code that sends to this group.
        """
        await self.send(text_data=json.dumps(event['data']))