# apps/chat/consumers.py
import codecs
import json
import logging
import traceback
import uuid
from urllib.parse import parse_qs

import aiohttp
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from django.conf import settings
from rest_framework.authtoken.models import Token

from apps.chat.models import ChatHistory
from apps.common.services.permission_service import PermissionService
from apps.knowledge_base.models import KnowledgeBase

logger = logging.getLogger(__name__)

# Model id used when the client does not send one (same value as the former
# inline literal).
DEFAULT_MODEL_ID = '7a214d0e-e65e-11ef-9f4a-0242ac120006'
# Application id posted to the external chat API's ``/chat/open`` endpoint.
EXTERNAL_APPLICATION_ID = 'd5d11efa-ea9a-11ef-9933-0242ac120006'
# Prompt template forwarded to the external API. ``{data}`` and ``{question}``
# are placeholders substituted by that API, so this must NOT be an f-string.
EXTERNAL_PROMPT = "**相关文档内容**:{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**:{question}"
# Prefix that marks a data line in the upstream server-sent-event stream.
SSE_DATA_PREFIX = 'data: '


def _mask_token(token_key):
    """Return a log-safe form of an auth token (first 4 characters + ``***``)."""
    return f"{token_key[:4]}***"


def _parse_sse_event(event):
    """Decode one SSE event (text between blank-line separators) into a dict.

    Returns None when the event is not a ``data: `` line, when its JSON is
    malformed (logged), or when the JSON value is not an object.
    """
    if not event.startswith(SSE_DATA_PREFIX):
        return None
    try:
        payload = json.loads(event[len(SSE_DATA_PREFIX):])
    except json.JSONDecodeError as e:
        logger.error(f"JSON解析错误: {e}, 数据: {event}")
        return None
    if not isinstance(payload, dict):
        # Previously a non-object payload crashed the whole stream.
        logger.warning(f"忽略非对象的流式数据: {event}")
        return None
    return payload


class ChatStreamConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer that relays knowledge-base chat answers as a stream.

    Protocol (as implemented below):
    * The client connects with ``?token=<DRF token>``.
    * It sends JSON messages with ``question`` and ``conversation_id``, plus
      ``dataset_id`` / ``dataset_id_list`` (and optionally ``model_id``) for
      the first message of a new conversation.
    * The server replies with ``{'code', 'message', 'data'}`` frames: one
      start frame, ``partial`` frames, and a final frame with
      ``data.is_end == True``. Errors use ``code`` 500 and ``is_end`` True.
    """

    async def connect(self):
        """Authenticate via the ``token`` query parameter, then accept the socket.

        The socket is closed without being accepted when the token is missing,
        unknown or belongs to an inactive user, or when any error occurs.
        """
        # Initialise up front so disconnect() can always read it safely.
        self.user = None
        try:
            query_string = self.scope.get('query_string', b'').decode()
            query_params = parse_qs(query_string)
            token_key = query_params.get('token', [''])[0]

            if not token_key:
                logger.warning("WebSocket连接尝试,但没有提供token")
                await self.close()
                return

            self.user = await self.get_user_from_token(token_key)
            if not self.user:
                # Never log the full token: it is a bearer credential.
                logger.warning(f"WebSocket连接尝试,但token无效: {_mask_token(token_key)}")
                await self.close()
                return

            # Expose the authenticated user to the rest of the consumer.
            self.scope["user"] = self.user

            await self.accept()
            logger.info(f"用户 {self.user.username} 流式输出WebSocket连接成功")
        except Exception as e:
            logger.error(f"WebSocket连接错误: {str(e)}")
            await self.close()

    @database_sync_to_async
    def get_user_from_token(self, token_key):
        """Return the active user owning *token_key*, or None if there is none."""
        try:
            token = Token.objects.select_related('user').get(key=token_key)
        except Token.DoesNotExist:
            return None
        # Mirror DRF's TokenAuthentication, which rejects inactive users.
        if not token.user.is_active:
            return None
        return token.user

    async def disconnect(self, close_code):
        """Log the disconnect; safe even when authentication failed."""
        # self.user may be missing or None (rejected token), so never
        # dereference it directly.
        username = getattr(getattr(self, 'user', None), 'username', 'unknown')
        logger.info(f"用户 {username} WebSocket连接断开,代码: {close_code}")

    async def receive(self, text_data=None, bytes_data=None):
        """Validate an incoming chat message and hand it to process_chat_request.

        Channels calls this with ``bytes_data`` for binary frames, so the
        parameter must be accepted; binary frames are rejected with an error.
        """
        try:
            if text_data is None:
                await self.send_error("仅支持文本消息")
                return

            data = json.loads(text_data)
            if not isinstance(data, dict):
                await self.send_error("消息格式错误: 需要JSON对象")
                return

            if 'question' not in data:
                await self.send_error("缺少必填字段: question")
                return
            if 'conversation_id' not in data:
                await self.send_error("缺少必填字段: conversation_id")
                return

            question = data['question']
            if not isinstance(question, str) or not question.strip():
                await self.send_error("question不能为空")
                return

            # Handles both new and existing conversations.
            await self.process_chat_request(data)
        except Exception as e:
            logger.exception(f"处理消息时出错: {str(e)}")
            await self.send_error(f"处理消息时出错: {str(e)}")

    async def process_chat_request(self, data):
        """Resolve the session, persist question/answer rows and start streaming.

        Every failure path sends an error frame so the client is never left
        waiting without a terminal (``is_end``) message.
        """
        try:
            conversation_id = data['conversation_id']
            question = data['question']

            session_info = await self.get_session_info(data)
            if not session_info:
                # get_session_info already logged the specific reason.
                await self.send_error("无法获取会话或知识库信息")
                return
            knowledge_bases, metadata, dataset_external_id_list = session_info

            question_record = await self.create_question_record(
                conversation_id, question, knowledge_bases, metadata
            )
            if not question_record:
                await self.send_error("创建问题记录失败")
                return

            answer_record = await self.create_answer_record(
                conversation_id, question_record, knowledge_bases, metadata
            )
            if not answer_record:
                await self.send_error("创建回答记录失败")
                return

            # Tell the client which answer id the upcoming partial frames belong to.
            await self.send_json({
                'code': 200,
                'message': '开始流式传输',
                'data': {
                    'id': str(answer_record.id),
                    'conversation_id': str(conversation_id),
                    'content': '',
                    'is_end': False,
                },
            })

            await self.stream_from_external_api(
                conversation_id,
                question,
                dataset_external_id_list,
                answer_record,
                metadata,
                knowledge_bases,
            )
        except Exception as e:
            logger.exception(f"处理聊天请求时出错: {str(e)}")
            await self.send_error(f"处理聊天请求时出错: {str(e)}")

    @database_sync_to_async
    def get_session_info(self, data):
        """Return ``(knowledge_bases, metadata, external_id_list)`` or None.

        An existing conversation reuses the metadata of its earliest record.
        A new conversation builds metadata from the knowledge-base ids in
        *data*. None is returned (and the reason logged) on any failure.
        """
        try:
            conversation_id = data['conversation_id']
            # NOTE(review): the lookup is by conversation_id only, not by user,
            # so any authenticated user who knows a conversation_id can append to
            # it. Confirm whether conversations should be scoped to their owner.
            first_record = ChatHistory.objects.filter(
                conversation_id=conversation_id
            ).order_by('created_at').first()

            if first_record is not None:
                return self._load_existing_session(first_record)
            return self._build_new_session(data)
        except Exception as e:
            logger.error(f"获取会话信息时出错: {str(e)}")
            return None

    def _load_existing_session(self, first_record):
        """Rebuild session info from the earliest record of a conversation (sync)."""
        metadata = first_record.metadata or {}
        dataset_ids = metadata.get('dataset_id_list', [])
        external_id_list = metadata.get('dataset_external_id_list', [])

        if not dataset_ids:
            logger.error('找不到会话关联的知识库信息')
            return None

        # Re-validate that every knowledge base still exists and is readable.
        knowledge_bases = []
        for kb_id in dataset_ids:
            try:
                kb = KnowledgeBase.objects.get(id=kb_id)
            except KnowledgeBase.DoesNotExist:
                logger.error(f'知识库不存在: {kb_id}')
                return None
            if not self.check_knowledge_base_permission(kb, self.scope["user"], 'read'):
                logger.error(f'无权访问知识库: {kb.name}')
                return None
            knowledge_bases.append(kb)

        if not external_id_list or not knowledge_bases:
            logger.error('会话关联的知识库信息不完整')
            return None

        return knowledge_bases, metadata, external_id_list

    @staticmethod
    def _parse_dataset_ids(data):
        """Extract knowledge-base ids (as strings) from a new-conversation payload.

        Accepts ``dataset_id`` (a single id) or ``dataset_id_list`` given either
        as a list or as a string holding a JSON list or a single id.
        """
        if 'dataset_id' in data:
            return [str(data['dataset_id'])]

        raw_ids = data.get('dataset_id_list')
        if isinstance(raw_ids, list):
            return [str(kb_id) for kb_id in raw_ids]
        if isinstance(raw_ids, str):
            try:
                parsed = json.loads(raw_ids)
            except json.JSONDecodeError:
                return [raw_ids]
            if isinstance(parsed, list):
                return [str(kb_id) for kb_id in parsed]
            return [raw_ids]
        return []

    def _build_new_session(self, data):
        """Validate the requested knowledge bases and build fresh metadata (sync)."""
        dataset_ids = self._parse_dataset_ids(data)
        if not dataset_ids:
            logger.error('新会话需要提供知识库ID')
            return None

        external_id_list = []
        knowledge_bases = []
        for kb_id in dataset_ids:
            try:
                knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first()
            except Exception as e:
                # e.g. an id that is not a valid primary-key value.
                logger.error(f"处理知识库ID出错: {str(e)}")
                return None
            if not knowledge_base:
                logger.error(f'知识库不存在: {kb_id}')
                return None

            knowledge_bases.append(knowledge_base)

            if not self.check_knowledge_base_permission(knowledge_base, self.scope["user"], 'read'):
                logger.error(f'无权访问知识库: {knowledge_base.name}')
                return None

            # Only knowledge bases mirrored in the external system can be queried.
            if knowledge_base.external_id:
                external_id_list.append(str(knowledge_base.external_id))
            else:
                logger.warning(f"知识库 {knowledge_base.id} ({knowledge_base.name}) 没有external_id")

        if not external_id_list:
            logger.error('没有有效的知识库external_id')
            return None

        metadata = {
            # `or` also covers an explicit null model_id from the client.
            'model_id': data.get('model_id') or DEFAULT_MODEL_ID,
            'dataset_id_list': [str(kb_id) for kb_id in dataset_ids],
            'dataset_external_id_list': list(external_id_list),
            'dataset_names': [kb.name for kb in knowledge_bases],
        }
        return knowledge_bases, metadata, external_id_list

    def check_knowledge_base_permission(self, kb, user, permission_type):
        """Return whether *user* may access *kb* with *permission_type*.

        SECURITY TODO: this is a placeholder that grants every request, so every
        authenticated user can use every knowledge base. ``PermissionService``
        is imported in this module but its API is not visible here — confirm
        it before wiring a real check in.
        """
        return True  # Temporary: always allow until a real check is implemented.

    @database_sync_to_async
    def create_question_record(self, conversation_id, question, knowledge_bases, metadata):
        """Persist the user's question; return the record, or None on failure."""
        try:
            title = metadata.get('title', 'New chat')
            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base=knowledge_bases[0],  # first knowledge base is the primary one
                conversation_id=str(conversation_id),
                title=title,
                role='user',
                content=question,
                metadata=metadata,
            )
        except Exception as e:
            logger.error(f"创建问题记录时出错: {str(e)}")
            return None

    @database_sync_to_async
    def create_answer_record(self, conversation_id, question_record, knowledge_bases, metadata):
        """Persist an empty assistant answer linked to *question_record*.

        The content is filled in later by update_answer_content. Returns the
        record, or None on failure.
        """
        try:
            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base=knowledge_bases[0],
                conversation_id=str(conversation_id),
                title=question_record.title,
                parent_id=str(question_record.id),
                role='assistant',
                content="",  # filled in once streaming finishes
                metadata=metadata,
            )
        except Exception as e:
            logger.error(f"创建回答记录时出错: {str(e)}")
            return None

    async def stream_from_external_api(self, conversation_id, question, dataset_external_id_list,
                                       answer_record, metadata, knowledge_bases):
        """Open an external chat, relay its SSE answer to the client and persist it.

        Each ``content`` chunk is forwarded as a ``partial`` frame. On the
        upstream ``is_end`` marker the full answer is saved and a final frame
        is sent. If the stream fails or closes early, an error frame is sent
        and whatever content was received is saved.
        """
        full_content = ""
        finished = False
        try:
            # The external API expects plain string ids.
            dataset_external_ids = [
                str(ext_id) if isinstance(ext_id, uuid.UUID) else ext_id
                for ext_id in dataset_external_id_list
            ]

            async with aiohttp.ClientSession() as session:
                chat_id = await self._open_external_chat(session, dataset_external_ids, metadata)
                if chat_id is None:
                    return

                message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}"
                logger.info(f"开始流式请求: {message_url}")

                async with session.post(
                    url=message_url,
                    json={"message": question, "re_chat": False, "stream": True},
                    headers={"Content-Type": "application/json"},
                ) as message_request:
                    if message_request.status != 200:
                        error_msg = f"外部API聊天消息调用失败: {message_request.status}, {await message_request.text()}"
                        logger.error(error_msg)
                        await self.send_error(error_msg)
                        return

                    logger.info("开始处理流式响应")
                    async for payload in self._iter_sse_payloads(message_request):
                        if 'content' in payload:
                            # Guard against null content, which would break `+=`.
                            content_part = str(payload['content'] or '')
                            full_content += content_part
                            await self.send_json({
                                'code': 200,
                                'message': 'partial',
                                'data': {
                                    'id': str(answer_record.id),
                                    'conversation_id': str(conversation_id),
                                    'content': content_part,
                                    'is_end': payload.get('is_end', False),
                                },
                            })

                        if payload.get('is_end', False):
                            logger.info("收到流式响应结束标记")
                            final_content = full_content.strip()
                            await self.update_answer_content(answer_record.id, final_content)
                            finished = True
                            await self._send_final_response(
                                conversation_id, question, answer_record, metadata, final_content
                            )
                            return

                    # Upstream closed without an is_end marker: previously the client
                    # was left waiting forever and nothing was saved.
                    logger.warning("外部API流式响应在结束标记前关闭")
                    await self.send_error("外部API流式响应意外结束")
        except Exception as e:
            logger.exception(f"流式处理出错: {str(e)}")
            await self.send_error(str(e))
        finally:
            # Persist any partial answer on error, early close or cancellation.
            if not finished and full_content:
                try:
                    await self.update_answer_content(answer_record.id, full_content.strip())
                except Exception as save_error:
                    logger.error(f"保存部分内容失败: {str(save_error)}")

    async def _open_external_chat(self, session, dataset_external_ids, metadata):
        """Create a chat on the external API and return its id.

        On failure an error frame is sent to the client and None is returned.
        """
        open_payload = {
            "id": EXTERNAL_APPLICATION_ID,
            "model_id": metadata.get('model_id', DEFAULT_MODEL_ID),
            "dataset_id_list": dataset_external_ids,
            "multiple_rounds_dialogue": False,
            "dataset_setting": {
                "top_n": 10,
                "similarity": "0.3",
                "max_paragraph_char_number": 10000,
                "search_mode": "blend",
                "no_references_setting": {
                    "value": "{question}",
                    "status": "ai_questioning",
                },
            },
            "model_setting": {
                "prompt": EXTERNAL_PROMPT,
            },
            "problem_optimization": False,
        }
        async with session.post(
            f"{settings.API_BASE_URL}/api/application/chat/open",
            json=open_payload,
        ) as chat_response:
            if chat_response.status != 200:
                error_msg = f"外部API调用失败: {await chat_response.text()}"
                logger.error(error_msg)
                await self.send_error(error_msg)
                return None
            chat_data = await chat_response.json()

        if not isinstance(chat_data, dict) or chat_data.get('code') != 200 or not chat_data.get('data'):
            error_msg = f"外部API返回错误: {chat_data}"
            logger.error(error_msg)
            await self.send_error(error_msg)
            return None

        chat_id = chat_data['data']
        logger.info(f"成功创建聊天会话, chat_id: {chat_id}")
        return chat_id

    @staticmethod
    async def _iter_sse_payloads(response):
        """Yield JSON-object payloads from an aiohttp SSE response body.

        Uses an incremental UTF-8 decoder so that a multi-byte character split
        across network chunks is decoded correctly. (Decoding each chunk on
        its own raised UnicodeDecodeError for such splits, which is common
        with Chinese text.) A final event without a trailing blank line is
        also delivered.
        """
        decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
        buffer = ""
        async for chunk in response.content.iter_any():
            buffer += decoder.decode(chunk)
            # Events are separated by a blank line.
            while '\n\n' in buffer:
                event, buffer = buffer.split('\n\n', 1)
                payload = _parse_sse_event(event)
                if payload is not None:
                    yield payload

        buffer += decoder.decode(b'', final=True)
        if buffer.strip():
            payload = _parse_sse_event(buffer)
            if payload is not None:
                yield payload

    async def _send_final_response(self, conversation_id, question, answer_record, metadata, content):
        """Resolve the conversation title and send the terminal ``完成`` frame."""
        title = await self.get_or_generate_title(conversation_id, question, content)
        await self.send_json({
            'code': 200,
            'message': '完成',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': str(conversation_id),
                'title': title,
                'dataset_id_list': metadata.get('dataset_id_list', []),
                'dataset_names': metadata.get('dataset_names', []),
                'role': 'assistant',
                'content': content,
                'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'is_end': True,
            },
        })

    @database_sync_to_async
    def update_answer_content(self, answer_id, content):
        """Store *content* on the answer record; return True on success, False otherwise."""
        try:
            answer_record = ChatHistory.objects.get(id=answer_id)
            answer_record.content = content
            answer_record.save()
            return True
        except Exception as e:
            logger.error(f"更新回答内容失败: {str(e)}")
            return False

    @database_sync_to_async
    def get_or_generate_title(self, conversation_id, question, answer):
        """Return the conversation's title, generating one from *question* if needed.

        A title counts as existing unless it is one of the placeholders
        ``"New chat"``, ``"新对话"`` or ``""``. A generated title is the first
        20 characters of *question* (with ``...`` when truncated) and is
        written to every record of the conversation. ``"新对话"`` is returned
        on error.
        """
        try:
            current_title = ChatHistory.objects.filter(
                conversation_id=str(conversation_id)
            ).exclude(
                title__in=["New chat", "新对话", ""]
            ).values_list('title', flat=True).first()

            if current_title:
                return current_title

            # Simple truncation; could be replaced by a model-generated title.
            generated_title = question[:20] + "..." if len(question) > 20 else question

            ChatHistory.objects.filter(
                conversation_id=str(conversation_id)
            ).update(title=generated_title)

            return generated_title
        except Exception as e:
            logger.error(f"获取或生成标题失败: {str(e)}")
            return "新对话"

    async def send_json(self, content):
        """Serialize *content* to JSON and send it as a text frame."""
        await self.send(text_data=json.dumps(content))

    async def send_error(self, message):
        """Send a terminal error frame (``code`` 500, ``is_end`` True)."""
        await self.send_json({
            'code': 500,
            'message': message,
            'data': {'is_end': True},
        })