daren/apps/chat/consumers.py
2025-05-23 19:54:32 +08:00

387 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# apps/chat/consumers.py
from channels.generic.websocket import AsyncWebsocketConsumer
import json
from channels.db import database_sync_to_async
from apps.chat.models import ChatHistory
from apps.user.models import UserToken
from django.conf import settings
import logging
import traceback
import uuid
import aiohttp
from urllib.parse import parse_qs
from django.utils import timezone
from rest_framework.permissions import IsAuthenticated
from apps.user.authentication import CustomTokenAuthentication
# Module-level logger, named after this module, shared by every consumer below.
logger: logging.Logger = logging.getLogger(name=__name__)
class ChatStreamConsumer(AsyncWebsocketConsumer):
    """Stream answers from the external chat API to an authenticated WebSocket client.

    Protocol, as implemented below:

    * The client connects with ``?token=<UserToken.token>``. The connection is
      closed without being accepted when the token is missing, unknown or
      expired.
    * Each client message is a JSON object with ``question`` and
      ``conversation_id``. One question row and one (initially empty) answer
      row are written to ``ChatHistory`` per message.
    * Every server message is ``{"code": ..., "message": ..., "data": {...}}``.
      The last message of a turn (final answer or error) carries
      ``data["is_end"] == True``.
    """

    # Fixed knowledge base ID stored on every ChatHistory row.
    DEFAULT_KNOWLEDGE_BASE_ID = "b680a4fa-37be-11f0-a7cb-0242ac120002"
    # Application ID sent as "id" to the external chat/open endpoint.
    EXTERNAL_APPLICATION_ID = "d5d11efa-ea9a-11ef-9933-0242ac120006"
    # Model ID used when the request metadata carries no "model_id".
    DEFAULT_MODEL_ID = "7a214d0e-e65e-11ef-9f4a-0242ac120006"
    # Titles that mean "this conversation has no real title yet".
    PLACEHOLDER_TITLES = ("New chat", "新对话", "")
    # Prompt template sent to the external API. {data} and {question} are not
    # Python format fields; presumably the external API substitutes them.
    EXTERNAL_PROMPT = "**相关文档内容**{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**{question}"

    async def connect(self):
        """Authenticate the client by its ``token`` query parameter, then accept.

        The socket is closed instead of accepted when the token is missing,
        not found, expired, or when authentication raises.
        """
        try:
            # The token comes from the URL query string, e.g. ws://host/path/?token=xxx
            query_string = self.scope.get('query_string', b'').decode()
            token_key = parse_qs(query_string).get('token', [''])[0]
            if not token_key:
                logger.warning("WebSocket连接尝试但没有提供token")
                await self.close()
                return

            self.user = await self.get_user_from_token(token_key)
            if not self.user:
                # Never write the raw credential to the logs.
                logger.warning("WebSocket连接尝试但token无效")
                await self.close()
                return

            # Expose the authenticated user to the rest of the consumer.
            self.scope["user"] = self.user
            await self.accept()
            logger.info(f"用户 {self.user.email} 流式输出WebSocket连接成功")
        except Exception as e:
            logger.error(f"WebSocket连接错误: {str(e)}")
            logger.error(traceback.format_exc())
            await self.close()

    @database_sync_to_async
    def get_user_from_token(self, token_key):
        """Return the user owning a non-expired ``UserToken``, or None.

        Uses the project's ``UserToken`` model rather than DRF's ``Token``.
        """
        try:
            token = UserToken.objects.select_related('user').get(
                token=token_key,
                expired_at__gt=timezone.now(),  # only tokens that have not expired
            )
            return token.user
        except UserToken.DoesNotExist:
            return None

    async def disconnect(self, close_code):
        """Log the disconnect.

        ``self.user`` is unset when ``connect`` failed early, and None when the
        token was invalid. Both cases are logged as 'unknown' instead of raising.
        """
        user = getattr(self, 'user', None)
        user_label = user.email if user is not None else 'unknown'
        logger.info(f"用户 {user_label} WebSocket连接断开代码: {close_code}")

    async def receive(self, text_data=None, bytes_data=None):
        """Validate one client message and run a chat turn for it.

        Channels calls ``receive(text_data=...)`` for text frames and
        ``receive(bytes_data=...)`` for binary frames. Both are parsed as JSON,
        since ``json.loads`` also accepts UTF-8 bytes. Validation failures and
        unexpected errors are reported to the client via ``send_error``.
        """
        try:
            payload = text_data if text_data is not None else bytes_data
            data = json.loads(payload)
            if not isinstance(data, dict):
                await self.send_error("消息格式错误: 需要JSON对象")
                return

            # Check the required fields.
            if data.get('question') is None:
                await self.send_error("缺少必填字段: question")
                return
            if not isinstance(data['question'], str) or not data['question'].strip():
                await self.send_error("字段question必须是非空字符串")
                return
            if data.get('conversation_id') in (None, ""):
                await self.send_error("缺少必填字段: conversation_id")
                return

            # Handles both new and existing conversations.
            await self.process_chat_request(data)
        except Exception as e:
            logger.error(f"处理消息时出错: {str(e)}")
            logger.error(traceback.format_exc())
            await self.send_error(f"处理消息时出错: {str(e)}")

    async def process_chat_request(self, data):
        """Persist the question/answer rows for one turn, then stream the answer.

        ``data`` must already contain ``question`` and ``conversation_id``,
        as validated by ``receive``.
        """
        try:
            conversation_id = data['conversation_id']
            question = data['question']
            # Extra per-message metadata stored on both rows (currently empty).
            metadata = {}

            question_record = await self.create_question_record(
                conversation_id,
                question,
                metadata,
            )
            if not question_record:
                # The DB error is already logged. Tell the client, so it does
                # not wait for a stream that will never start.
                await self.send_error("创建问题记录失败")
                return

            answer_record = await self.create_answer_record(
                conversation_id,
                question_record,
                metadata,
            )
            if not answer_record:
                await self.send_error("创建回答记录失败")
                return

            # Announce the stream. The client learns the answer row id here.
            await self.send_json({
                'code': 200,
                'message': '开始流式传输',
                'data': {
                    'id': str(answer_record.id),
                    'conversation_id': str(conversation_id),
                    'content': '',
                    'is_end': False,
                },
            })

            # Dataset IDs for the external API; intentionally empty for now.
            dataset_external_id_list = []
            await self.stream_from_external_api(
                conversation_id,
                question,
                dataset_external_id_list,
                answer_record,
                metadata,
            )
        except Exception as e:
            logger.error(f"处理聊天请求时出错: {str(e)}")
            logger.error(traceback.format_exc())
            await self.send_error(f"处理聊天请求时出错: {str(e)}")

    @database_sync_to_async
    def create_question_record(self, conversation_id, question, metadata):
        """Create the user's question row. Return it, or None on any error (logged)."""
        try:
            title = "New chat"
            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base_id=self.DEFAULT_KNOWLEDGE_BASE_ID,
                conversation_id=str(conversation_id),
                title=title,
                role='user',
                content=question,
                metadata=metadata,
            )
        except Exception as e:
            logger.error(f"创建问题记录时出错: {str(e)}")
            return None

    @database_sync_to_async
    def create_answer_record(self, conversation_id, question_record, metadata):
        """Create an empty assistant row linked to ``question_record``.

        Returns the row, or None on any error (logged).
        """
        try:
            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base_id=self.DEFAULT_KNOWLEDGE_BASE_ID,
                conversation_id=str(conversation_id),
                title=question_record.title,
                parent_id=str(question_record.id),
                role='assistant',
                content="",  # filled in once the stream completes
                metadata=metadata,
            )
        except Exception as e:
            logger.error(f"创建回答记录时出错: {str(e)}")
            return None

    async def stream_from_external_api(self, conversation_id, question, dataset_external_id_list, answer_record, metadata):
        """Open an external chat, relay its streamed answer, and persist it.

        Each ``content`` piece is forwarded to the client as a 'partial'
        message. When the stream ends, with or without an ``is_end`` marker,
        the full answer is saved and a final '完成' message is sent. On an
        unexpected error the client gets an error message, and whatever was
        received so far is saved.
        """
        # Initialised before ``try`` so the error handler can always read it.
        full_content = ""
        try:
            async with aiohttp.ClientSession() as session:
                # Step 1: open a chat session on the external API.
                chat_id = await self._open_external_chat(session, dataset_external_id_list, metadata)
                if chat_id is None:
                    return

                # Step 2: request the streamed answer.
                message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}"
                logger.info(f"开始流式请求: {message_url}")
                async with session.post(
                    url=message_url,
                    json={"message": question, "re_chat": False, "stream": True},
                    headers={"Content-Type": "application/json"},
                ) as message_request:
                    if message_request.status != 200:
                        error_msg = f"外部API聊天消息调用失败: {message_request.status}, {await message_request.text()}"
                        logger.error(error_msg)
                        await self.send_error(error_msg)
                        return

                    logger.info("开始处理流式响应")
                    events = self._iter_sse_events(message_request.content)
                    try:
                        async for event in events:
                            if 'content' in event:
                                content_part = event['content'] or ""
                                full_content += content_part
                                await self.send_json({
                                    'code': 200,
                                    'message': 'partial',
                                    'data': {
                                        'id': str(answer_record.id),
                                        'conversation_id': str(conversation_id),
                                        'content': content_part,
                                        'is_end': event.get('is_end', False),
                                    },
                                })
                            if event.get('is_end', False):
                                logger.info("收到流式响应结束标记")
                                await self._finish_answer(conversation_id, question, answer_record, full_content)
                                return
                    finally:
                        # Close the async generator deterministically, including on early return.
                        await events.aclose()

                    # The upstream closed without an is_end marker. Still persist
                    # and finish the turn, so the client does not wait forever.
                    logger.warning("流式响应在收到结束标记前终止")
                    await self._finish_answer(conversation_id, question, answer_record, full_content)
        except Exception as e:
            logger.error(f"流式处理出错: {str(e)}")
            logger.error(traceback.format_exc())
            await self.send_error(str(e))
            # Keep whatever was received before the failure.
            if full_content:
                try:
                    await self.update_answer_content(answer_record.id, full_content.strip())
                except Exception as save_error:
                    logger.error(f"保存部分内容失败: {str(save_error)}")

    def _build_open_chat_payload(self, dataset_external_id_list, metadata):
        """Build the JSON body for the external ``chat/open`` endpoint."""
        return {
            "id": self.EXTERNAL_APPLICATION_ID,
            "model_id": metadata.get('model_id', self.DEFAULT_MODEL_ID),
            "dataset_id_list": dataset_external_id_list,
            "multiple_rounds_dialogue": False,
            "dataset_setting": {
                "top_n": 10,
                "similarity": "0.3",
                "max_paragraph_char_number": 10000,
                "search_mode": "blend",
                "no_references_setting": {
                    "value": "{question}",
                    "status": "ai_questioning",
                },
            },
            "model_setting": {
                "prompt": self.EXTERNAL_PROMPT,
            },
            "problem_optimization": False,
        }

    async def _open_external_chat(self, session, dataset_external_id_list, metadata):
        """Open an external chat and return its id.

        Returns None after reporting the failure to the client.
        """
        async with session.post(
            f"{settings.API_BASE_URL}/api/application/chat/open",
            json=self._build_open_chat_payload(dataset_external_id_list, metadata),
        ) as chat_response:
            if chat_response.status != 200:
                error_msg = f"外部API调用失败: {await chat_response.text()}"
                logger.error(error_msg)
                await self.send_error(error_msg)
                return None
            chat_data = await chat_response.json()

        if not isinstance(chat_data, dict) or chat_data.get('code') != 200 or not chat_data.get('data'):
            error_msg = f"外部API返回错误: {chat_data}"
            logger.error(error_msg)
            await self.send_error(error_msg)
            return None

        chat_id = chat_data['data']
        logger.info(f"成功创建聊天会话, chat_id: {chat_id}")
        return chat_id

    async def _iter_sse_events(self, stream):
        """Yield the JSON object of each ``data: `` event in an SSE byte stream.

        Raw bytes are buffered and split on the blank-line separator before
        decoding. A network chunk can end in the middle of a multi-byte UTF-8
        character, but the ASCII ``b"\\n\\n"`` separator never occurs inside
        one. A final event without a trailing separator is also yielded.
        """
        buffer = b""
        async for chunk in stream.iter_any():
            buffer += chunk
            while b"\n\n" in buffer:
                raw_event, buffer = buffer.split(b"\n\n", 1)
                event = self._decode_sse_event(raw_event)
                if event is not None:
                    yield event
        # The upstream may close without a blank line after its last event.
        if buffer.strip():
            event = self._decode_sse_event(buffer)
            if event is not None:
                yield event

    @staticmethod
    def _decode_sse_event(raw_event):
        """Return the JSON object of one complete ``data: ...`` event, or None to skip it."""
        line = raw_event.decode('utf-8')
        if not line.startswith('data: '):
            return None
        try:
            event = json.loads(line[6:])  # strip the "data: " prefix
        except json.JSONDecodeError as e:
            logger.error(f"JSON解析错误: {e}, 数据: {line}")
            return None
        return event if isinstance(event, dict) else None

    async def _finish_answer(self, conversation_id, question, answer_record, full_content):
        """Save the full answer, resolve the conversation title, and send the final message."""
        content = full_content.strip()
        await self.update_answer_content(answer_record.id, content)
        title = await self.get_or_generate_title(conversation_id, question, content)
        await self.send_json({
            'code': 200,
            'message': '完成',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': str(conversation_id),
                'title': title,
                'role': 'assistant',
                'content': content,
                'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'is_end': True,
            },
        })

    @database_sync_to_async
    def update_answer_content(self, answer_id, content):
        """Overwrite the content of ChatHistory row ``answer_id``.

        Returns True on success, False on any error (logged).
        """
        try:
            answer_record = ChatHistory.objects.get(id=answer_id)
            answer_record.content = content
            answer_record.save()
            return True
        except Exception as e:
            logger.error(f"更新回答内容失败: {str(e)}")
            return False

    @database_sync_to_async
    def get_or_generate_title(self, conversation_id, question, answer):
        """Return this user's title for the conversation, creating one if needed.

        Only rows owned by the connected user are read or updated, because
        ``conversation_id`` is supplied by the client. If a real (non-placeholder)
        title exists, rows still holding a placeholder are brought in line with
        it. Otherwise, a title is derived from the first 20 characters of the
        question and written to all of the user's rows in the conversation.
        ``answer`` is currently unused. Falls back to "新对话" on error.
        """
        try:
            conversation = ChatHistory.objects.filter(
                user=self.scope["user"],
                conversation_id=str(conversation_id),
            )
            current_title = conversation.exclude(
                title__in=self.PLACEHOLDER_TITLES,
            ).values_list('title', flat=True).first()
            if current_title:
                # New turns are created with the "New chat" placeholder; give
                # them the conversation's existing title.
                conversation.filter(title__in=self.PLACEHOLDER_TITLES).update(title=current_title)
                return current_title

            # Simple heuristic; could be replaced by an LLM-generated title.
            generated_title = question[:20] + "..." if len(question) > 20 else question
            conversation.update(title=generated_title)
            return generated_title
        except Exception as e:
            logger.error(f"获取或生成标题失败: {str(e)}")
            return "新对话"

    async def send_json(self, content):
        """Serialize ``content`` to JSON and send it as a text frame."""
        await self.send(text_data=json.dumps(content))

    async def send_error(self, message):
        """Send an error message that also ends the current turn (``is_end`` True)."""
        await self.send_json({
            'code': 500,
            'message': message,
            'data': {'is_end': True},
        })