# apps/chat/consumers.py
import codecs
import json
import logging
import traceback
import uuid
from urllib.parse import parse_qs

import aiohttp
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from django.conf import settings
from django.utils import timezone
from rest_framework.permissions import IsAuthenticated

from apps.chat.models import ChatHistory
from apps.user.authentication import CustomTokenAuthentication
from apps.user.models import UserToken

logger = logging.getLogger(__name__)


class ChatStreamConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer that relays a streamed answer from the external chat API.

    Protocol, as implemented below:

    * The client connects with ``?token=<key>``; the key is looked up in
      ``UserToken`` and must not be expired.
    * Each inbound text frame is a JSON object with ``question`` and
      ``conversation_id``.
    * The consumer stores the question and an empty assistant row in
      ``ChatHistory``, opens a chat on the external API, then forwards every
      ``data: {...}`` event as a ``partial`` frame and finishes with a ``完成``
      frame carrying the full answer.
    * Every failure is reported with a ``code: 500`` frame whose
      ``data.is_end`` is ``True``.
    """

    # Fixed knowledge-base ID written to every ChatHistory row created here.
    DEFAULT_KNOWLEDGE_BASE_ID = "b680a4fa-37be-11f0-a7cb-0242ac120002"

    async def connect(self):
        """Authenticate via the ``token`` query parameter and accept the socket.

        The connection is closed without being accepted when the token is
        missing, unknown or expired, or when any exception occurs.
        """
        try:
            # Read the token from the URL query string.
            query_string = self.scope.get('query_string', b'').decode()
            query_params = parse_qs(query_string)
            token_key = query_params.get('token', [''])[0]

            if not token_key:
                logger.warning("WebSocket连接尝试,但没有提供token")
                await self.close()
                return

            # Validate the token; None means unknown or expired.
            self.user = await self.get_user_from_token(token_key)
            if not self.user:
                # NOTE(review): this writes the raw token value to the log;
                # consider masking it if logs are not treated as secrets.
                logger.warning(f"WebSocket连接尝试,但token无效: {token_key}")
                await self.close()
                return

            # Expose the user through the scope; the DB helpers read it there.
            self.scope["user"] = self.user

            await self.accept()
            logger.info(f"用户 {self.user.email} 流式输出WebSocket连接成功")
        except Exception as e:
            logger.error(f"WebSocket连接错误: {str(e)}")
            await self.close()

    @database_sync_to_async
    def get_user_from_token(self, token_key):
        """Return the user owning an unexpired ``UserToken``, or ``None``.

        Uses the project's ``UserToken`` model rather than
        rest_framework's ``Token``.
        """
        try:
            token = UserToken.objects.select_related('user').get(
                token=token_key,
                expired_at__gt=timezone.now(),  # only tokens that have not expired
            )
        except UserToken.DoesNotExist:
            return None
        return token.user

    async def disconnect(self, close_code):
        """Log the disconnect together with the user's e-mail when known."""
        # self.user exists but is None after a rejected token, so a plain
        # hasattr() check is not enough to make `.email` safe.
        user = getattr(self, 'user', None)
        email = getattr(user, 'email', 'unknown')
        logger.info(f"用户 {email} WebSocket连接断开,代码: {close_code}")

    async def receive(self, text_data=None, bytes_data=None):
        """Validate an inbound frame and start processing the chat request.

        Args:
            text_data: JSON text of the request; must decode to an object with
                ``question`` and ``conversation_id``.
            bytes_data: Binary frame payload. Accepted only so that binary
                frames do not raise ``TypeError`` on dispatch; such frames are
                rejected with an error message.
        """
        try:
            if text_data is None:
                await self.send_error("仅支持文本消息")
                return

            data = json.loads(text_data)
            if not isinstance(data, dict):
                await self.send_error("消息格式错误: 需要JSON对象")
                return

            # Required fields.
            if 'question' not in data:
                await self.send_error("缺少必填字段: question")
                return

            if 'conversation_id' not in data:
                await self.send_error("缺少必填字段: conversation_id")
                return

            await self.process_chat_request(data)
        except Exception as e:
            logger.error(f"处理消息时出错: {str(e)}")
            logger.error(traceback.format_exc())
            await self.send_error(f"处理消息时出错: {str(e)}")

    async def process_chat_request(self, data):
        """Persist the question, create the answer row and stream the answer.

        Args:
            data: Decoded request containing ``conversation_id`` and
                ``question``.
        """
        try:
            conversation_id = data['conversation_id']
            question = data['question']

            # The same metadata dict is stored on both rows and passed on.
            metadata = {}

            question_record = await self.create_question_record(
                conversation_id, question, metadata
            )
            if not question_record:
                # Tell the client instead of leaving it waiting for a stream.
                await self.send_error("创建问题记录失败")
                return

            answer_record = await self.create_answer_record(
                conversation_id, question_record, metadata
            )
            if not answer_record:
                await self.send_error("创建回答记录失败")
                return

            # Initial frame so the client learns the answer id up front.
            await self.send_json({
                'code': 200,
                'message': '开始流式传输',
                'data': {
                    'id': str(answer_record.id),
                    'conversation_id': str(conversation_id),
                    'content': '',
                    'is_end': False
                }
            })

            # Dataset ID list for the external API; currently always empty.
            dataset_external_id_list = []

            await self.stream_from_external_api(
                conversation_id, question, dataset_external_id_list,
                answer_record, metadata
            )
        except Exception as e:
            logger.error(f"处理聊天请求时出错: {str(e)}")
            logger.error(traceback.format_exc())
            await self.send_error(f"处理聊天请求时出错: {str(e)}")

    @database_sync_to_async
    def create_question_record(self, conversation_id, question, metadata):
        """Create the ``role='user'`` ChatHistory row; return ``None`` on failure."""
        try:
            title = "New chat"
            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base_id=self.DEFAULT_KNOWLEDGE_BASE_ID,
                conversation_id=str(conversation_id),
                title=title,
                role='user',
                content=question,
                metadata=metadata
            )
        except Exception as e:
            logger.error(f"创建问题记录时出错: {str(e)}")
            return None

    @database_sync_to_async
    def create_answer_record(self, conversation_id, question_record, metadata):
        """Create the empty ``role='assistant'`` row linked to *question_record*.

        Returns ``None`` on failure.
        """
        try:
            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base_id=self.DEFAULT_KNOWLEDGE_BASE_ID,
                conversation_id=str(conversation_id),
                title=question_record.title,
                parent_id=str(question_record.id),
                role='assistant',
                content="",  # filled in once the stream has finished
                metadata=metadata
            )
        except Exception as e:
            logger.error(f"创建回答记录时出错: {str(e)}")
            return None

    @staticmethod
    def _build_chat_open_payload(metadata, dataset_external_id_list):
        """Build the JSON body for the external ``/api/application/chat/open`` call."""
        return {
            "id": "d5d11efa-ea9a-11ef-9933-0242ac120006",
            "model_id": metadata.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
            "dataset_id_list": dataset_external_id_list,
            "multiple_rounds_dialogue": False,
            "dataset_setting": {
                "top_n": 10,
                "similarity": "0.3",
                "max_paragraph_char_number": 10000,
                "search_mode": "blend",
                "no_references_setting": {
                    "value": "{question}",
                    "status": "ai_questioning"
                }
            },
            "model_setting": {
                "prompt": "**相关文档内容**:{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**:{question}"
            },
            "problem_optimization": False
        }

    async def _finalize_answer(self, conversation_id, question, answer_record, full_content):
        """Persist the full answer, resolve the title and send the final frame."""
        final_content = full_content.strip()
        await self.update_answer_content(answer_record.id, final_content)

        title = await self.get_or_generate_title(
            conversation_id, question, final_content
        )

        await self.send_json({
            'code': 200,
            'message': '完成',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': str(conversation_id),
                'title': title,
                'role': 'assistant',
                'content': final_content,
                'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'is_end': True
            }
        })

    async def stream_from_external_api(self, conversation_id, question,
                                       dataset_external_id_list, answer_record,
                                       metadata):
        """Open a chat on the external API and relay its streamed answer.

        Step 1 posts to ``/api/application/chat/open`` to obtain a chat id.
        Step 2 posts the question to ``/api/application/chat_message/<id>``
        with ``stream=True`` and reads events separated by a blank line. Each
        ``data: `` event is parsed as JSON; its ``content`` is forwarded as a
        ``partial`` frame and accumulated. When an event has ``is_end`` set,
        the answer is saved and the final frame is sent.

        Args:
            conversation_id: Conversation identifier echoed in every frame.
            question: The user's question sent to the external API.
            dataset_external_id_list: Value for ``dataset_id_list``.
            answer_record: The assistant ChatHistory row to fill in.
            metadata: Dict; ``model_id`` is read from it when present.

        Exceptions are not propagated: they are logged, whatever content was
        collected is saved, and an error frame is sent.
        """
        # Initialised before the try so the error path can always read it.
        full_content = ""

        try:
            async with aiohttp.ClientSession() as session:
                # Step 1: create the chat session.
                async with session.post(
                    f"{settings.API_BASE_URL}/api/application/chat/open",
                    json=self._build_chat_open_payload(metadata, dataset_external_id_list)
                ) as chat_response:
                    if chat_response.status != 200:
                        error_msg = f"外部API调用失败: {await chat_response.text()}"
                        logger.error(error_msg)
                        await self.send_error(error_msg)
                        return

                    chat_data = await chat_response.json()

                    if chat_data.get('code') != 200 or not chat_data.get('data'):
                        error_msg = f"外部API返回错误: {chat_data}"
                        logger.error(error_msg)
                        await self.send_error(error_msg)
                        return

                    chat_id = chat_data['data']

                logger.info(f"成功创建聊天会话, chat_id: {chat_id}")

                # Step 2: open the streaming request.
                message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}"
                logger.info(f"开始流式请求: {message_url}")

                async with session.post(
                    url=message_url,
                    json={"message": question, "re_chat": False, "stream": True},
                    headers={"Content-Type": "application/json"}
                ) as message_request:
                    if message_request.status != 200:
                        error_msg = f"外部API聊天消息调用失败: {message_request.status}, {await message_request.text()}"
                        logger.error(error_msg)
                        await self.send_error(error_msg)
                        return

                    # Network chunks can end in the middle of a multi-byte
                    # UTF-8 character; the incremental decoder holds the
                    # dangling bytes until the next chunk arrives.
                    decoder = codecs.getincrementaldecoder('utf-8')()
                    # Holds text that does not yet form a complete event.
                    buffer = ""

                    logger.info("开始处理流式响应")
                    async for chunk in message_request.content.iter_any():
                        buffer += decoder.decode(chunk)

                        # Events are terminated by a blank line.
                        while '\n\n' in buffer:
                            line, buffer = buffer.split('\n\n', 1)

                            if not line.startswith('data: '):
                                continue

                            try:
                                data = json.loads(line[6:])  # drop the "data: " prefix
                            except json.JSONDecodeError as e:
                                logger.error(f"JSON解析错误: {e}, 数据: {line}")
                                continue

                            if 'content' in data:
                                content_part = data['content']
                                full_content += content_part

                                await self.send_json({
                                    'code': 200,
                                    'message': 'partial',
                                    'data': {
                                        'id': str(answer_record.id),
                                        'conversation_id': str(conversation_id),
                                        'content': content_part,
                                        'is_end': data.get('is_end', False)
                                    }
                                })

                            if data.get('is_end', False):
                                logger.info("收到流式响应结束标记")
                                await self._finalize_answer(
                                    conversation_id, question, answer_record, full_content
                                )
                                return

                    # The upstream closed without an end marker. Do not leave
                    # the row empty or the client waiting for a final frame.
                    logger.warning("流式响应在收到结束标记前已关闭")
                    if full_content.strip():
                        await self._finalize_answer(
                            conversation_id, question, answer_record, full_content
                        )
                    else:
                        await self.send_error("外部API未返回任何内容")

        except Exception as e:
            logger.error(f"流式处理出错: {str(e)}")
            logger.error(traceback.format_exc())

            # Save what was collected BEFORE notifying the client: if the
            # socket is already gone, send_error raises and the partial answer
            # would otherwise be lost.
            if full_content:
                try:
                    await self.update_answer_content(answer_record.id, full_content.strip())
                except Exception as save_error:
                    logger.error(f"保存部分内容失败: {str(save_error)}")

            await self.send_error(str(e))

    @database_sync_to_async
    def update_answer_content(self, answer_id, content):
        """Set ``content`` on the ChatHistory row *answer_id*.

        Returns:
            ``True`` on success, ``False`` if the lookup or save failed.
        """
        try:
            answer_record = ChatHistory.objects.get(id=answer_id)
            answer_record.content = content
            answer_record.save()
            return True
        except Exception as e:
            logger.error(f"更新回答内容失败: {str(e)}")
            return False

    @database_sync_to_async
    def get_or_generate_title(self, conversation_id, question, answer):
        """Return the conversation's title, generating one if it has none.

        An existing title other than the placeholders ``"New chat"``,
        ``"新对话"`` and ``""`` is returned as-is. Otherwise the title is the
        first 20 characters of *question* (plus ``"..."`` when truncated) and
        is written to every row of the conversation.

        Only rows owned by the connected user are read or updated, because
        ``conversation_id`` is supplied by the client.

        Args:
            conversation_id: Conversation whose title is resolved.
            question: Source text for a generated title.
            answer: Currently unused.

        Returns:
            The title, or ``"新对话"`` if anything fails.
        """
        try:
            conversation_rows = ChatHistory.objects.filter(
                user=self.scope["user"],
                conversation_id=str(conversation_id)
            )

            # Reuse an existing non-placeholder title if there is one.
            current_title = conversation_rows.exclude(
                title__in=["New chat", "新对话", ""]
            ).values_list('title', flat=True).first()

            if current_title:
                return current_title

            # Simple title generation; could be replaced by a model call.
            generated_title = question[:20] + "..." if len(question) > 20 else question

            conversation_rows.update(title=generated_title)

            return generated_title
        except Exception as e:
            logger.error(f"获取或生成标题失败: {str(e)}")
            return "新对话"

    async def send_json(self, content):
        """Serialize *content* with ``json.dumps`` and send it as a text frame."""
        await self.send(text_data=json.dumps(content))

    async def send_error(self, message):
        """Send a ``code: 500`` frame with *message* and ``is_end`` set to True."""
        await self.send_json({
            'code': 500,
            'message': message,
            'data': {'is_end': True}
        })