daren/apps/chat/consumers.py
2025-05-29 10:11:19 +08:00

501 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# apps/chat/consumers.py
from channels.generic.websocket import AsyncWebsocketConsumer
import json
from channels.db import database_sync_to_async
from apps.knowledge_base.models import KnowledgeBase
from apps.chat.models import ChatHistory
from rest_framework.authtoken.models import Token
from django.conf import settings
import logging
import traceback
import uuid
import aiohttp
from urllib.parse import parse_qs
logger = logging.getLogger(__name__)
class ChatStreamConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer that relays a streamed knowledge-base answer to the client.

    Flow for each incoming message:

    1. Resolve the conversation's knowledge bases (from the stored metadata of
       an existing conversation, or from the request for a new one).
    2. Persist the user's question and an empty assistant record.
    3. Open a chat on the external service and forward every streamed
       fragment to the client as a JSON frame.
    4. Persist the full answer and send a final frame with ``is_end: True``.

    Every frame sent to the client has the shape
    ``{"code": int, "message": str, "data": {...}}``.
    """

    # Model used when neither the request nor the stored metadata names one.
    DEFAULT_MODEL_ID = '7a214d0e-e65e-11ef-9f4a-0242ac120006'
    # Value sent as "id" when opening a chat on the external service.
    APPLICATION_ID = 'd5d11efa-ea9a-11ef-9933-0242ac120006'
    # Events in the upstream stream are separated by a blank line.
    EVENT_SEPARATOR = b'\n\n'
    # Only events starting with this prefix carry a JSON payload.
    EVENT_PREFIX = 'data: '

    async def connect(self):
        """Authenticate via the ``token`` query parameter, then accept.

        The connection is closed without being accepted when the token is
        missing, unknown, or when any unexpected error occurs.
        """
        # Defined up-front so disconnect() can rely on the attribute even when
        # the connection is rejected before a user was resolved.
        self.user = None
        try:
            query_string = self.scope.get('query_string', b'').decode()
            query_params = parse_qs(query_string)
            token_key = query_params.get('token', [''])[0]
            if not token_key:
                logger.warning("WebSocket连接尝试但没有提供token")
                await self.close()
                return

            self.user = await self.get_user_from_token(token_key)
            if not self.user:
                # The credential itself is deliberately kept out of the log.
                logger.warning("WebSocket连接尝试但token无效")
                await self.close()
                return

            # Expose the authenticated user to the rest of the consumer.
            self.scope["user"] = self.user
            await self.accept()
            logger.info(f"用户 {self.user.name} 流式输出WebSocket连接成功")
        except Exception as e:
            logger.error(f"WebSocket连接错误: {str(e)}")
            await self.close()

    @database_sync_to_async
    def get_user_from_token(self, token_key):
        """Return the user owning ``token_key``, or ``None`` if no such token exists."""
        try:
            token = Token.objects.select_related('user').get(key=token_key)
            return token.user
        except Token.DoesNotExist:
            return None

    async def disconnect(self, close_code):
        """Log the disconnect; tolerates connections that never authenticated."""
        user = getattr(self, 'user', None)
        user_name = user.name if user is not None else 'unknown'
        logger.info(f"用户 {user_name} WebSocket连接断开代码: {close_code}")

    async def receive(self, text_data=None, bytes_data=None):
        """Validate an incoming frame and start processing the chat request.

        Args:
            text_data: JSON text; must decode to an object containing
                ``question`` and ``conversation_id``.
            bytes_data: Binary payload. Binary frames are not supported and
                are answered with an error frame.
        """
        try:
            if text_data is None:
                await self.send_error("仅支持文本消息")
                return

            data = json.loads(text_data)
            if not isinstance(data, dict):
                await self.send_error("消息格式错误: 需要JSON对象")
                return

            # Required fields.
            if 'question' not in data:
                await self.send_error("缺少必填字段: question")
                return
            if 'conversation_id' not in data:
                await self.send_error("缺少必填字段: conversation_id")
                return

            await self.process_chat_request(data)
        except Exception as e:
            logger.error(f"处理消息时出错: {str(e)}")
            logger.error(traceback.format_exc())
            await self.send_error(f"处理消息时出错: {str(e)}")

    async def process_chat_request(self, data):
        """Create the question/answer records and stream the answer.

        Every failure path sends an error frame so the client is never left
        waiting for a response that will not arrive.
        """
        try:
            conversation_id = data['conversation_id']
            question = data['question']

            session_info = await self.get_session_info(data)
            if not session_info:
                await self.send_error("无法获取会话信息或知识库")
                return
            knowledge_bases, metadata, dataset_external_id_list = session_info

            question_record = await self.create_question_record(
                conversation_id,
                question,
                knowledge_bases,
                metadata
            )
            if not question_record:
                await self.send_error("创建问题记录失败")
                return

            answer_record = await self.create_answer_record(
                conversation_id,
                question_record,
                knowledge_bases,
                metadata
            )
            if not answer_record:
                await self.send_error("创建回答记录失败")
                return

            # Initial frame: tells the client which record the fragments belong to.
            await self.send_json({
                'code': 200,
                'message': '开始流式传输',
                'data': {
                    'id': str(answer_record.id),
                    'conversation_id': str(conversation_id),
                    'content': '',
                    'is_end': False
                }
            })

            await self.stream_from_external_api(
                conversation_id,
                question,
                dataset_external_id_list,
                answer_record,
                metadata,
                knowledge_bases
            )
        except Exception as e:
            logger.error(f"处理聊天请求时出错: {str(e)}")
            logger.error(traceback.format_exc())
            await self.send_error(f"处理聊天请求时出错: {str(e)}")

    @database_sync_to_async
    def get_session_info(self, data):
        """Resolve the knowledge bases and metadata for a conversation.

        Returns:
            ``(knowledge_bases, metadata, external_id_list)`` on success, or
            ``None`` when the information is missing, invalid, or an error
            occurred (the reason is logged).
        """
        try:
            conversation_id = data['conversation_id']
            # NOTE(review): this lookup is not restricted to the connected
            # user, so any authenticated client can continue a conversation by
            # id -- confirm whether that is intended.
            first_record = ChatHistory.objects.filter(
                conversation_id=conversation_id
            ).order_by('created_at').first()

            if first_record is not None:
                # Existing conversation: reuse the metadata of its first record.
                return self._load_existing_session(first_record.metadata or {})
            # New conversation: the request must name the knowledge base(s).
            return self._build_new_session(data)
        except Exception as e:
            logger.error(f"获取会话信息时出错: {str(e)}")
            return None

    def _load_existing_session(self, metadata):
        """Validate the knowledge bases recorded in an existing conversation's metadata."""
        dataset_ids = metadata.get('dataset_id_list', [])
        external_id_list = metadata.get('dataset_external_id_list', [])
        if not dataset_ids:
            logger.error('找不到会话关联的知识库信息')
            return None

        knowledge_bases = []
        for kb_id in dataset_ids:
            try:
                kb = KnowledgeBase.objects.get(id=kb_id)
            except KnowledgeBase.DoesNotExist:
                logger.error(f'知识库不存在: {kb_id}')
                return None
            if not self.check_knowledge_base_permission(kb, self.scope["user"], 'read'):
                logger.error(f'无权访问知识库: {kb.name}')
                return None
            knowledge_bases.append(kb)

        if not external_id_list or not knowledge_bases:
            logger.error('会话关联的知识库信息不完整')
            return None
        return knowledge_bases, metadata, external_id_list

    @staticmethod
    def _parse_dataset_ids(data):
        """Extract knowledge-base ids from ``dataset_id`` or ``dataset_id_list``.

        ``dataset_id`` takes precedence. ``dataset_id_list`` may be a list or
        a string; a string holding a JSON list is expanded, any other string
        is treated as a single id.
        """
        if 'dataset_id' in data:
            return [str(data['dataset_id'])]

        raw = data.get('dataset_id_list')
        if isinstance(raw, list):
            return [str(item) for item in raw]
        if isinstance(raw, str):
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                return [raw]
            if isinstance(parsed, list):
                return [str(item) for item in parsed]
            return [raw]
        return []

    def _build_new_session(self, data):
        """Validate the requested knowledge bases and build metadata for a new conversation."""
        dataset_ids = self._parse_dataset_ids(data)
        if not dataset_ids:
            logger.error('新会话需要提供知识库ID')
            return None

        external_id_list = []
        knowledge_bases = []
        for kb_id in dataset_ids:
            try:
                knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first()
                if not knowledge_base:
                    logger.error(f'知识库不存在: {kb_id}')
                    return None
                if not self.check_knowledge_base_permission(knowledge_base, self.scope["user"], 'read'):
                    logger.error(f'无权访问知识库: {knowledge_base.name}')
                    return None
                knowledge_bases.append(knowledge_base)
                # Knowledge bases without an external id cannot be queried
                # upstream; they are kept in the metadata but skipped here.
                if knowledge_base.external_id:
                    external_id_list.append(str(knowledge_base.external_id))
                else:
                    logger.warning(f"知识库 {knowledge_base.id} ({knowledge_base.name}) 没有external_id")
            except Exception as e:
                logger.error(f"处理知识库ID出错: {str(e)}")
                return None

        if not external_id_list:
            logger.error('没有有效的知识库external_id')
            return None

        metadata = {
            'model_id': data.get('model_id', self.DEFAULT_MODEL_ID),
            'dataset_id_list': [str(item) for item in dataset_ids],
            'dataset_external_id_list': [str(item) for item in external_id_list],
            'dataset_names': [kb.name for kb in knowledge_bases]
        }
        return knowledge_bases, metadata, external_id_list

    def check_knowledge_base_permission(self, kb, user, permission_type):
        """Return whether ``user`` holds ``permission_type`` on ``kb``.

        TODO: placeholder -- currently grants access unconditionally; a real
        permission check still has to be implemented.
        """
        return True

    @database_sync_to_async
    def create_question_record(self, conversation_id, question, knowledge_bases, metadata):
        """Persist the user's question; returns the record or ``None`` on failure."""
        try:
            title = metadata.get('title', 'New chat')
            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base=knowledge_bases[0],  # first knowledge base is the primary one
                conversation_id=str(conversation_id),
                title=title,
                role='user',
                content=question,
                metadata=metadata
            )
        except Exception as e:
            logger.error(f"创建问题记录时出错: {str(e)}")
            return None

    @database_sync_to_async
    def create_answer_record(self, conversation_id, question_record, knowledge_bases, metadata):
        """Persist an empty assistant record linked to ``question_record``.

        The content is filled in once streaming finishes. Returns the record,
        or ``None`` on failure.
        """
        try:
            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base=knowledge_bases[0],
                conversation_id=str(conversation_id),
                title=question_record.title,
                parent_id=str(question_record.id),
                role='assistant',
                content="",  # filled in after streaming completes
                metadata=metadata
            )
        except Exception as e:
            logger.error(f"创建回答记录时出错: {str(e)}")
            return None

    async def stream_from_external_api(self, conversation_id, question, dataset_external_id_list, answer_record, metadata, knowledge_bases):
        """Open a chat on the external service and relay its streamed answer.

        Fragments are forwarded to the client as they arrive. The complete
        answer is persisted when the end marker is received, when the stream
        ends without one, or (best effort) when an error interrupts streaming.

        ``knowledge_bases`` is accepted for interface compatibility and is not
        used here.
        """
        # Created before the try block so the error handler can always
        # persist whatever was collected.
        content_parts = []
        try:
            dataset_external_ids = [
                str(item) if isinstance(item, uuid.UUID) else item
                for item in dataset_external_id_list
            ]

            async with aiohttp.ClientSession() as session:
                # Step 1: open a chat session on the external service.
                chat_id = await self._open_external_chat(session, metadata, dataset_external_ids)
                if chat_id is None:
                    return

                # Step 2: request the streamed answer.
                message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}"
                logger.info(f"开始流式请求: {message_url}")
                async with session.post(
                    url=message_url,
                    json={"message": question, "re_chat": False, "stream": True},
                    headers={"Content-Type": "application/json"}
                ) as message_request:
                    if message_request.status != 200:
                        error_msg = f"外部API聊天消息调用失败: {message_request.status}, {await message_request.text()}"
                        logger.error(error_msg)
                        await self.send_error(error_msg)
                        return

                    # Buffer raw bytes and decode only complete events: a
                    # multi-byte UTF-8 character may be split across network
                    # chunks, so decoding each chunk on its own can fail.
                    buffer = b""
                    logger.info("开始处理流式响应")
                    async for chunk in message_request.content.iter_any():
                        buffer += chunk
                        while self.EVENT_SEPARATOR in buffer:
                            raw_event, buffer = buffer.split(self.EVENT_SEPARATOR, 1)
                            finished = await self._handle_stream_event(
                                raw_event, conversation_id, question,
                                answer_record, metadata, content_parts
                            )
                            if finished:
                                return

                    # A last event may arrive without a trailing separator.
                    if buffer.strip():
                        finished = await self._handle_stream_event(
                            buffer, conversation_id, question,
                            answer_record, metadata, content_parts
                        )
                        if finished:
                            return

            # The stream ended without an end marker: still persist the
            # answer and close out the exchange for the client.
            logger.warning("流式响应未收到结束标记")
            if content_parts:
                await self._finish_answer(
                    conversation_id, question, answer_record, metadata, content_parts
                )
            else:
                await self.send_error("外部API未返回任何内容")
        except Exception as e:
            logger.error(f"流式处理出错: {str(e)}")
            logger.error(traceback.format_exc())
            # Save first: notifying the client may itself fail (for example
            # when it has disconnected) and must not cost the collected text.
            if content_parts:
                try:
                    await self.update_answer_content(answer_record.id, ''.join(content_parts).strip())
                except Exception as save_error:
                    logger.error(f"保存部分内容失败: {str(save_error)}")
            await self.send_error(str(e))

    async def _open_external_chat(self, session, metadata, dataset_external_ids):
        """Open a chat on the external service; return its id, or ``None`` after reporting an error."""
        async with session.post(
            f"{settings.API_BASE_URL}/api/application/chat/open",
            json={
                "id": self.APPLICATION_ID,
                "model_id": metadata.get('model_id', self.DEFAULT_MODEL_ID),
                "dataset_id_list": dataset_external_ids,
                "multiple_rounds_dialogue": False,
                "dataset_setting": {
                    "top_n": 10, "similarity": "0.3",
                    "max_paragraph_char_number": 10000,
                    "search_mode": "blend",
                    "no_references_setting": {
                        "value": "{question}",
                        "status": "ai_questioning"
                    }
                },
                "model_setting": {
                    "prompt": "**相关文档内容**{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**{question}"
                },
                "problem_optimization": False
            }
        ) as chat_response:
            if chat_response.status != 200:
                error_msg = f"外部API调用失败: {await chat_response.text()}"
                logger.error(error_msg)
                await self.send_error(error_msg)
                return None

            chat_data = await chat_response.json()
            if chat_data.get('code') != 200 or not chat_data.get('data'):
                error_msg = f"外部API返回错误: {chat_data}"
                logger.error(error_msg)
                await self.send_error(error_msg)
                return None

            chat_id = chat_data['data']
            logger.info(f"成功创建聊天会话, chat_id: {chat_id}")
            return chat_id

    async def _handle_stream_event(self, raw_event, conversation_id, question, answer_record, metadata, content_parts):
        """Process one upstream event; return ``True`` once the end marker was handled.

        Events that are not ``data: `` lines, or whose payload cannot be
        decoded or parsed into a JSON object, are logged (where applicable)
        and skipped.
        """
        try:
            line = raw_event.decode('utf-8')
        except UnicodeDecodeError as e:
            logger.error(f"流式数据解码错误: {e}")
            return False

        if not line.startswith(self.EVENT_PREFIX):
            return False

        try:
            data = json.loads(line[len(self.EVENT_PREFIX):])
        except json.JSONDecodeError as e:
            logger.error(f"JSON解析错误: {e}, 数据: {line}")
            return False
        if not isinstance(data, dict):
            return False

        if 'content' in data:
            content_part = data['content']
            if not isinstance(content_part, str):
                # A null fragment carries no text; anything else is stringified
                # so it can be appended to the collected answer.
                content_part = '' if content_part is None else str(content_part)
            content_parts.append(content_part)
            await self.send_json({
                'code': 200,
                'message': 'partial',
                'data': {
                    'id': str(answer_record.id),
                    'conversation_id': str(conversation_id),
                    'content': content_part,
                    'is_end': data.get('is_end', False)
                }
            })

        if data.get('is_end', False):
            logger.info("收到流式响应结束标记")
            await self._finish_answer(
                conversation_id, question, answer_record, metadata, content_parts
            )
            return True
        return False

    async def _finish_answer(self, conversation_id, question, answer_record, metadata, content_parts):
        """Persist the complete answer, resolve the title, and send the final frame."""
        full_content = ''.join(content_parts).strip()
        await self.update_answer_content(answer_record.id, full_content)
        title = await self.get_or_generate_title(
            conversation_id,
            question,
            full_content
        )
        await self.send_json({
            'code': 200,
            'message': '完成',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': str(conversation_id),
                'title': title,
                'dataset_id_list': metadata.get('dataset_id_list', []),
                'dataset_names': metadata.get('dataset_names', []),
                'role': 'assistant',
                'content': full_content,
                'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'is_end': True
            }
        })

    @database_sync_to_async
    def update_answer_content(self, answer_id, content):
        """Store ``content`` on the answer record; returns ``True`` on success, ``False`` otherwise."""
        try:
            answer_record = ChatHistory.objects.get(id=answer_id)
            answer_record.content = content
            answer_record.save()
            return True
        except Exception as e:
            logger.error(f"更新回答内容失败: {str(e)}")
            return False

    @database_sync_to_async
    def get_or_generate_title(self, conversation_id, question, answer):
        """Return the conversation's title, generating one from the question if needed.

        An existing title other than the placeholders ("New chat", "新对话",
        empty) is returned as is. Otherwise the title becomes the first 20
        characters of ``question`` (with "..." appended when truncated) and is
        written to every record of the conversation. ``answer`` is currently
        unused. Returns "新对话" if anything fails.
        """
        try:
            current_title = ChatHistory.objects.filter(
                conversation_id=str(conversation_id)
            ).exclude(
                title__in=["New chat", "新对话", ""]
            ).values_list('title', flat=True).first()
            if current_title:
                return current_title

            # Simple placeholder logic; could be replaced by a model-generated title.
            generated_title = question[:20] + "..." if len(question) > 20 else question
            ChatHistory.objects.filter(
                conversation_id=str(conversation_id)
            ).update(title=generated_title)
            return generated_title
        except Exception as e:
            logger.error(f"获取或生成标题失败: {str(e)}")
            return "新对话"

    async def send_json(self, content):
        """Serialize ``content`` to JSON and send it as a text frame."""
        await self.send(text_data=json.dumps(content))

    async def send_error(self, message):
        """Send an error frame (code 500) that also marks the exchange as ended."""
        await self.send_json({
            'code': 500,
            'message': message,
            'data': {'is_end': True}
        })