# daren_project/user_management/consumers.py
#
# 891 lines
# 40 KiB
# Python
# Raw Permalink Normal View History

import json
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.db import database_sync_to_async
from channels.exceptions import StopConsumer
import logging
from rest_framework.authtoken.models import Token
# 2025-04-17 16:14:00 +08:00
from urllib.parse import parse_qs
# 2025-04-29 10:22:57 +08:00
from .models import ChatHistory, KnowledgeBase
import aiohttp
import asyncio
from django.conf import settings
import uuid
import traceback
logger = logging.getLogger(__name__)
class NotificationConsumer(AsyncWebsocketConsumer):
    """Push per-user notifications over a token-authenticated WebSocket.

    The client supplies a token in the ``token`` query-string parameter.
    On success the connection joins the group
    ``notification_user_<user id>``; events of type ``notification`` sent to
    that group are forwarded to the client as JSON text frames.
    """

    async def connect(self):
        """Authenticate the client and subscribe it to its notification group.

        The connection is closed without being accepted when the token is
        missing or unknown, or when any step of the setup raises.
        """
        try:
            # The token travels in the URL query string.
            query_string = self.scope.get('query_string', b'').decode()
            token_key = parse_qs(query_string).get('token', [''])[0]
            if not token_key:
                logger.warning("WebSocket连接尝试但没有提供token")
                await self.close()
                return

            self.user = await self.get_user_from_token(token_key)
            if not self.user:
                # Log only a short prefix: the token is a credential and
                # must not be written to the log in full.
                logger.warning(f"WebSocket连接尝试但token无效: {token_key[:8]}***")
                await self.close()
                return

            # One group per user, so notifications can be addressed by user id.
            self.room_name = f"notification_user_{self.user.id}"
            await self.channel_layer.group_add(self.room_name, self.channel_name)
            await self.accept()
            logger.info(f"用户 {self.user.username} WebSocket连接成功")
        except Exception as e:
            logger.error(f"WebSocket连接错误: {str(e)}")
            await self.close()

    @database_sync_to_async
    def get_user_from_token(self, token_key):
        """Return the user owning ``token_key``, or ``None`` if no such token exists."""
        try:
            token = Token.objects.select_related('user').get(key=token_key)
            return token.user
        except Token.DoesNotExist:
            return None

    async def disconnect(self, close_code):
        """Leave the notification group (if one was joined) and log the disconnect.

        Errors are logged and swallowed so that teardown never raises.
        """
        try:
            if hasattr(self, 'room_name'):
                await self.channel_layer.group_discard(self.room_name, self.channel_name)
            # ``self.user`` is unset when no token was supplied and ``None``
            # when the token was rejected, so it cannot be dereferenced blindly.
            user = getattr(self, 'user', None)
            username = user.username if user else 'unknown'
            logger.info(f"用户 {username} 已断开连接,关闭代码: {close_code}")
        except Exception as e:
            logger.error(f"断开连接时发生错误: {str(e)}")

    async def notification(self, event):
        """Forward a ``notification`` group event to the client as JSON.

        The whole event dict is serialised and sent; send failures are logged
        and swallowed.
        """
        try:
            await self.send(text_data=json.dumps(event))
            logger.info(f"已发送通知给用户 {self.user.username}")
        except Exception as e:
            logger.error(f"发送通知消息时发生错误: {str(e)}")
# 2025-04-29 10:22:57 +08:00
class ChatConsumer(AsyncWebsocketConsumer):
    """Token-authenticated chat WebSocket that relays a streamed answer.

    Each text frame must be a JSON object containing ``question`` and
    ``conversation_id``. The question is stored as a ``ChatHistory`` row, an
    empty assistant row is created, and the answer streamed by the external
    chat API (``settings.API_BASE_URL``) is forwarded to the client piece by
    piece before being saved on the assistant row.
    """

    async def connect(self):
        """Authenticate via the ``token`` query parameter and accept the socket.

        The connection is closed without being accepted when the token is
        missing or unknown, or when any step raises.
        """
        try:
            # The token travels in the URL query string.
            query_string = self.scope.get('query_string', b'').decode()
            token_key = parse_qs(query_string).get('token', [''])[0]
            if not token_key:
                logger.warning("WebSocket连接尝试但没有提供token")
                await self.close()
                return

            self.user = await self.get_user_from_token(token_key)
            if not self.user:
                # Log only a short prefix: the token is a credential and
                # must not be written to the log in full.
                logger.warning(f"WebSocket连接尝试但token无效: {token_key[:8]}***")
                await self.close()
                return

            # Later methods read the authenticated user from the scope.
            self.scope["user"] = self.user
            await self.accept()
            logger.info(f"用户 {self.user.username} WebSocket连接成功")
        except Exception as e:
            logger.error(f"WebSocket连接错误: {str(e)}")
            await self.close()

    @database_sync_to_async
    def get_user_from_token(self, token_key):
        """Return the user owning ``token_key``, or ``None`` if no such token exists."""
        try:
            token = Token.objects.select_related('user').get(key=token_key)
            return token.user
        except Token.DoesNotExist:
            return None

    async def disconnect(self, close_code):
        """Nothing to release: this consumer joins no groups."""
        pass

    async def receive(self, text_data=None, bytes_data=None):
        """Handle one client message: store the question, then stream the answer.

        Both parameters are optional because the framework passes a binary
        frame as ``bytes_data`` only; such frames are answered with an error
        message instead of raising ``TypeError``.
        """
        try:
            if text_data is None:
                await self.send_error("缺少必要字段")
                return

            data = json.loads(text_data)
            if (
                not isinstance(data, dict)
                or 'question' not in data
                or 'conversation_id' not in data
            ):
                await self.send_error("缺少必要字段")
                return

            question_record = await self.create_question_record(data)
            if not question_record:
                # create_question_record has already reported the error.
                return

            await self.stream_answer(question_record, data)
        except Exception as e:
            logger.error(f"处理消息时出错: {str(e)}")
            await self.send_error(f"处理消息时出错: {str(e)}")

    def _load_knowledge_bases(self, dataset_ids):
        """Load every knowledge base in ``dataset_ids`` and check read access.

        Synchronous (touches the ORM); call it from a sync context only.

        Raises:
            Exception: if an id does not exist or the user may not read it.
        """
        knowledge_bases = []
        for kb_id in dataset_ids:
            try:
                kb = KnowledgeBase.objects.get(id=kb_id)
            except KnowledgeBase.DoesNotExist:
                raise Exception(f'知识库不存在: {kb_id}')
            if not self.check_knowledge_base_permission(kb, self.scope["user"], 'read'):
                raise Exception(f'无权访问知识库: {kb.name}')
            knowledge_bases.append(kb)
        return knowledge_bases

    @database_sync_to_async
    def _create_question_record_sync(self, data):
        """Create the ``ChatHistory`` row for the user's question.

        For an existing conversation the metadata of its earliest record is
        reused; for a new conversation the metadata is built from
        ``data['dataset_id_list']`` and the optional ``data['model_id']``.

        Returns:
            The created ``ChatHistory`` on success, or a ``(None, message)``
            tuple describing the failure.
        """
        try:
            conversation_id = data['conversation_id']
            first_record = ChatHistory.objects.filter(
                conversation_id=conversation_id
            ).order_by('created_at').first()

            if first_record is not None:
                metadata = first_record.metadata or {}
                knowledge_bases = self._load_knowledge_bases(
                    metadata.get('dataset_id_list', [])
                )
                if not knowledge_bases:
                    # Without this check ``knowledge_bases[0]`` below would
                    # surface as an opaque "list index out of range".
                    raise Exception('找不到会话关联的知识库信息')
            else:
                dataset_ids = data.get('dataset_id_list', [])
                if not dataset_ids:
                    raise Exception('新会话需要提供知识库ID')
                knowledge_bases = self._load_knowledge_bases(dataset_ids)
                metadata = {
                    'model_id': data.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
                    'dataset_id_list': [str(kb.id) for kb in knowledge_bases],
                    'dataset_external_id_list': [
                        str(kb.external_id) for kb in knowledge_bases if kb.external_id
                    ],
                    'dataset_names': [kb.name for kb in knowledge_bases],
                }

            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base=knowledge_bases[0],  # first one acts as the primary
                conversation_id=conversation_id,
                title=data.get('title', 'New chat'),
                role='user',
                content=data['question'],
                metadata=metadata,
            )
        except Exception as e:
            logger.error(f"创建问题记录失败: {str(e)}")
            return None, str(e)

    async def create_question_record(self, data):
        """Create the question record, reporting any failure to the client.

        Returns:
            The created ``ChatHistory``, or ``None`` after an error message
            has been sent.
        """
        try:
            result = await self._create_question_record_sync(data)
            # A tuple is the failure form: (None, error_message).
            if isinstance(result, tuple):
                _, error_message = result
                await self.send_error(error_message)
                return None
            return result
        except Exception as e:
            await self.send_error(str(e))
            return None

    def check_knowledge_base_permission(self, kb, user, permission_type):
        """Return whether ``user`` holds ``permission_type`` on ``kb``.

        TODO: placeholder that grants access to everyone; the real permission
        check still has to be implemented.
        """
        return True

    async def _finish_answer(self, answer_record, question_record, data, full_content):
        """Persist the collected answer and send the final ``is_end`` message."""
        final_content = full_content.strip()
        answer_record.content = final_content
        await database_sync_to_async(answer_record.save)()

        title = await self.get_or_generate_title(
            question_record.conversation_id,
            data['question'],
            final_content,
        )
        await self.send_json({
            'code': 200,
            'message': '完成',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': str(question_record.conversation_id),
                'title': title,
                'dataset_id_list': question_record.metadata.get('dataset_id_list', []),
                'dataset_names': question_record.metadata.get('dataset_names', []),
                'role': 'assistant',
                'content': final_content,
                'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'is_end': True,
            },
        })

    async def stream_answer(self, question_record, data):
        """Stream the external API's answer for ``question_record`` to the client.

        Creates the assistant ``ChatHistory`` row, opens a chat on the
        external API, posts the question with ``stream=True`` and forwards
        every ``data: {...}`` event (events are separated by a blank line).
        On failure an error message is sent and whatever content was already
        collected is saved on the assistant row.
        """
        # Initialised up front so the error handler can inspect them safely.
        full_content = ""
        answer_record = None
        try:
            answer_record = await database_sync_to_async(ChatHistory.objects.create)(
                user=self.scope["user"],
                knowledge_base=question_record.knowledge_base,
                conversation_id=str(question_record.conversation_id),
                title=question_record.title,
                parent_id=str(question_record.id),
                role='assistant',
                content="",
                metadata=question_record.metadata,
            )

            await self.send_json({
                'code': 200,
                'message': '开始流式传输',
                'data': {
                    'id': str(answer_record.id),
                    'conversation_id': str(question_record.conversation_id),
                    'content': '',
                    'is_end': False,
                },
            })

            async with aiohttp.ClientSession() as session:
                # Step 1: open a chat session. The context manager releases
                # the response/connection once the JSON body has been read.
                async with session.post(
                    f"{settings.API_BASE_URL}/api/application/chat/open",
                    json={
                        "id": "d5d11efa-ea9a-11ef-9933-0242ac120006",
                        "model_id": question_record.metadata.get('model_id'),
                        "dataset_id_list": question_record.metadata.get('dataset_external_id_list', []),
                        "multiple_rounds_dialogue": False,
                        "dataset_setting": {
                            "top_n": 10,
                            "similarity": "0.3",
                            "max_paragraph_char_number": 10000,
                            "search_mode": "blend",
                            "no_references_setting": {
                                "value": "{question}",
                                "status": "ai_questioning"
                            }
                        },
                        "model_setting": {
                            "prompt": "**相关文档内容**{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**{question}"
                        },
                        "problem_optimization": False
                    }
                ) as chat_response:
                    chat_data = await chat_response.json()

                if chat_data.get('code') != 200:
                    raise Exception(f"创建聊天会话失败: {chat_data}")
                chat_id = chat_data['data']

                # Step 2: post the question and consume the streamed reply.
                async with session.post(
                    f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}",
                    json={"message": data['question'], "re_chat": False, "stream": True},
                    headers={"Content-Type": "application/json"}
                ) as response:
                    # Buffer raw bytes and decode only complete events: a
                    # network chunk may end in the middle of a multi-byte
                    # UTF-8 character, so decoding chunk by chunk can fail.
                    buffer = b""
                    async for chunk in response.content.iter_any():
                        buffer += chunk
                        while b'\n\n' in buffer:
                            raw_event, buffer = buffer.split(b'\n\n', 1)
                            line = raw_event.decode('utf-8')
                            if not line.startswith('data: '):
                                continue
                            try:
                                chunk_data = json.loads(line[6:])  # drop "data: "
                            except json.JSONDecodeError as e:
                                logger.error(f"JSON解析错误: {e}, 数据: {line}")
                                continue
                            if not isinstance(chunk_data, dict):
                                continue

                            if 'content' in chunk_data:
                                content_part = chunk_data['content']
                                full_content += content_part
                                await self.send_json({
                                    'code': 200,
                                    'message': 'partial',
                                    'data': {
                                        'id': str(answer_record.id),
                                        'conversation_id': str(question_record.conversation_id),
                                        'content': content_part,
                                        'is_end': chunk_data.get('is_end', False),
                                    },
                                })

                            if chunk_data.get('is_end', False):
                                await self._finish_answer(
                                    answer_record, question_record, data, full_content
                                )
                                return

            # The upstream closed the stream without an end marker: still
            # persist what was received and tell the client the answer is
            # over, so the text is not lost and the client is not left waiting.
            logger.warning("流式响应在收到结束标记前结束")
            await self._finish_answer(answer_record, question_record, data, full_content)
        except Exception as e:
            logger.error(f"流式处理出错: {str(e)}")
            await self.send_error(str(e))
            # Best effort: keep whatever part of the answer was collected.
            if answer_record is not None and full_content:
                try:
                    answer_record.content = full_content.strip()
                    await database_sync_to_async(answer_record.save)()
                except Exception as save_error:
                    logger.error(f"保存部分内容失败: {str(save_error)}")

    @database_sync_to_async
    def get_or_generate_title(self, conversation_id, question, answer):
        """Return the conversation's title, assigning a default if it has none.

        A title counts as missing when it is "New chat", "新对话" or empty.
        ``question`` and ``answer`` are accepted but not used yet: real title
        generation is still to be implemented, so the default "新对话" is
        written to every record of the conversation. Returns "新对话" on error.
        """
        try:
            current_title = ChatHistory.objects.filter(
                conversation_id=str(conversation_id)
            ).exclude(
                title__in=["New chat", "新对话", ""]
            ).values_list('title', flat=True).first()
            if current_title:
                return current_title

            # Placeholder until title generation exists.
            generated_title = "新对话"
            ChatHistory.objects.filter(
                conversation_id=str(conversation_id)
            ).update(title=generated_title)
            return generated_title
        except Exception as e:
            logger.error(f"获取或生成标题失败: {str(e)}")
            return "新对话"

    async def send_json(self, content):
        """Serialise ``content`` to JSON and send it as a text frame."""
        await self.send(text_data=json.dumps(content))

    async def send_error(self, message):
        """Send an error envelope (code 500, ``is_end`` true) with ``message``."""
        await self.send_json({
            'code': 500,
            'message': message,
            'data': {'is_end': True}
        })
class ChatStreamConsumer(AsyncWebsocketConsumer):
    """Token-authenticated WebSocket that streams chat answers to the client.

    Each text frame must be a JSON object containing ``question`` and
    ``conversation_id``; the first message of a new conversation must also
    name its knowledge bases via ``dataset_id`` or ``dataset_id_list``. The
    answer streamed by the external chat API (``settings.API_BASE_URL``) is
    forwarded piece by piece and finally saved on the assistant record.
    """

    async def connect(self):
        """Authenticate via the ``token`` query parameter and accept the socket.

        The connection is closed without being accepted when the token is
        missing or unknown, or when any step raises.
        """
        try:
            # The token travels in the URL query string.
            query_string = self.scope.get('query_string', b'').decode()
            token_key = parse_qs(query_string).get('token', [''])[0]
            if not token_key:
                logger.warning("WebSocket连接尝试但没有提供token")
                await self.close()
                return

            self.user = await self.get_user_from_token(token_key)
            if not self.user:
                # Log only a short prefix: the token is a credential and
                # must not be written to the log in full.
                logger.warning(f"WebSocket连接尝试但token无效: {token_key[:8]}***")
                await self.close()
                return

            # Later methods read the authenticated user from the scope.
            self.scope["user"] = self.user
            await self.accept()
            logger.info(f"用户 {self.user.username} 流式输出WebSocket连接成功")
        except Exception as e:
            logger.error(f"WebSocket连接错误: {str(e)}")
            await self.close()

    @database_sync_to_async
    def get_user_from_token(self, token_key):
        """Return the user owning ``token_key``, or ``None`` if no such token exists."""
        try:
            token = Token.objects.select_related('user').get(key=token_key)
            return token.user
        except Token.DoesNotExist:
            return None

    async def disconnect(self, close_code):
        """Log the disconnect together with its close code."""
        # ``self.user`` is unset when no token was supplied and ``None`` when
        # the token was rejected; both cases must not raise here.
        user = getattr(self, 'user', None)
        username = user.username if user else 'unknown'
        logger.info(f"用户 {username} WebSocket连接断开代码: {close_code}")

    async def receive(self, text_data=None, bytes_data=None):
        """Validate one client message and hand it to ``process_chat_request``.

        Both parameters are optional because the framework passes a binary
        frame as ``bytes_data`` only; such frames are answered with an error
        message instead of raising ``TypeError``.
        """
        try:
            if text_data is None:
                await self.send_error("缺少必填字段: question")
                return

            data = json.loads(text_data)
            if not isinstance(data, dict) or 'question' not in data:
                await self.send_error("缺少必填字段: question")
                return
            if 'conversation_id' not in data:
                await self.send_error("缺少必填字段: conversation_id")
                return

            await self.process_chat_request(data)
        except Exception as e:
            logger.error(f"处理消息时出错: {str(e)}")
            logger.error(traceback.format_exc())
            await self.send_error(f"处理消息时出错: {str(e)}")

    async def process_chat_request(self, data):
        """Run one question through the pipeline.

        Resolves the session's knowledge bases, stores the question, creates
        the empty assistant record, announces the stream to the client and
        then relays the external API's answer. Every failure is reported to
        the client as an error message.
        """
        try:
            conversation_id = data['conversation_id']
            question = data['question']

            session_info = await self.get_session_info(data)
            if not session_info:
                # The precise reason is in the log; the client still needs a
                # terminal message or it would wait indefinitely.
                await self.send_error("获取会话信息失败: 知识库不存在、无权访问或未提供知识库ID")
                return
            knowledge_bases, metadata, dataset_external_id_list = session_info

            question_record = await self.create_question_record(
                conversation_id, question, knowledge_bases, metadata
            )
            if not question_record:
                await self.send_error("创建问题记录失败")
                return

            answer_record = await self.create_answer_record(
                conversation_id, question_record, knowledge_bases, metadata
            )
            if not answer_record:
                # create_answer_record returns None on failure.
                await self.send_error("创建回答记录失败")
                return

            await self.send_json({
                'code': 200,
                'message': '开始流式传输',
                'data': {
                    'id': str(answer_record.id),
                    'conversation_id': str(conversation_id),
                    'content': '',
                    'is_end': False,
                },
            })

            await self.stream_from_external_api(
                conversation_id,
                question,
                dataset_external_id_list,
                answer_record,
                metadata,
                knowledge_bases,
            )
        except Exception as e:
            logger.error(f"处理聊天请求时出错: {str(e)}")
            logger.error(traceback.format_exc())
            await self.send_error(f"处理聊天请求时出错: {str(e)}")

    @staticmethod
    def _parse_dataset_ids(data):
        """Extract knowledge-base ids from a new-session request as strings.

        ``dataset_id`` (a single id) takes precedence. Otherwise
        ``dataset_id_list`` may be a list, a JSON-encoded list, or a plain
        string that is treated as one id. Returns ``[]`` if neither is usable.
        """
        if 'dataset_id' in data:
            return [str(data['dataset_id'])]

        raw = data.get('dataset_id_list')
        if isinstance(raw, list):
            return [str(item) for item in raw]
        if isinstance(raw, str):
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                return [raw]
            if isinstance(parsed, list):
                return [str(item) for item in parsed]
            return [raw]
        return []

    @database_sync_to_async
    def get_session_info(self, data):
        """Resolve the knowledge bases and metadata for the conversation.

        For an existing conversation the metadata of its earliest record is
        reused. For a new conversation the ids come from the request and the
        metadata is built here.

        Returns:
            ``(knowledge_bases, metadata, external_id_list)`` on success, or
            ``None`` after logging the reason (unknown id, no permission,
            missing ids, no external ids, or an unexpected error).
        """
        try:
            conversation_id = data['conversation_id']
            first_record = ChatHistory.objects.filter(
                conversation_id=conversation_id
            ).order_by('created_at').first()

            if first_record is not None:
                metadata = first_record.metadata or {}
                dataset_ids = metadata.get('dataset_id_list', [])
                external_id_list = metadata.get('dataset_external_id_list', [])
                if not dataset_ids:
                    logger.error('找不到会话关联的知识库信息')
                    return None

                knowledge_bases = []
                for kb_id in dataset_ids:
                    try:
                        kb = KnowledgeBase.objects.get(id=kb_id)
                    except KnowledgeBase.DoesNotExist:
                        logger.error(f'知识库不存在: {kb_id}')
                        return None
                    if not self.check_knowledge_base_permission(kb, self.scope["user"], 'read'):
                        logger.error(f'无权访问知识库: {kb.name}')
                        return None
                    knowledge_bases.append(kb)

                if not external_id_list or not knowledge_bases:
                    logger.error('会话关联的知识库信息不完整')
                    return None
                return knowledge_bases, metadata, external_id_list

            # New conversation: the first message must name its knowledge bases.
            dataset_ids = self._parse_dataset_ids(data)
            if not dataset_ids:
                logger.error('新会话需要提供知识库ID')
                return None

            external_id_list = []
            knowledge_bases = []
            for kb_id in dataset_ids:
                try:
                    knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first()
                    if not knowledge_base:
                        logger.error(f'知识库不存在: {kb_id}')
                        return None
                    if not self.check_knowledge_base_permission(
                        knowledge_base, self.scope["user"], 'read'
                    ):
                        logger.error(f'无权访问知识库: {knowledge_base.name}')
                        return None
                    knowledge_bases.append(knowledge_base)

                    # Only knowledge bases with an external id can be sent
                    # to the external API.
                    if knowledge_base.external_id:
                        external_id_list.append(str(knowledge_base.external_id))
                    else:
                        logger.warning(
                            f"知识库 {knowledge_base.id} ({knowledge_base.name}) 没有external_id"
                        )
                except Exception as e:
                    logger.error(f"处理知识库ID出错: {str(e)}")
                    return None

            if not external_id_list:
                logger.error('没有有效的知识库external_id')
                return None

            metadata = {
                'model_id': data.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
                'dataset_id_list': [str(kb_id) for kb_id in dataset_ids],
                'dataset_external_id_list': [str(ext_id) for ext_id in external_id_list],
                'dataset_names': [kb.name for kb in knowledge_bases],
            }
            return knowledge_bases, metadata, external_id_list
        except Exception as e:
            logger.error(f"获取会话信息时出错: {str(e)}")
            return None

    def check_knowledge_base_permission(self, kb, user, permission_type):
        """Return whether ``user`` holds ``permission_type`` on ``kb``.

        TODO: placeholder that grants access to everyone; the real permission
        check still has to be implemented.
        """
        return True

    @database_sync_to_async
    def create_question_record(self, conversation_id, question, knowledge_bases, metadata):
        """Create the user's question row; return it, or ``None`` on failure.

        The title is taken from ``metadata['title']`` and defaults to
        'New chat'. The first knowledge base is stored as the primary one.
        """
        try:
            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base=knowledge_bases[0],
                conversation_id=str(conversation_id),
                title=metadata.get('title', 'New chat'),
                role='user',
                content=question,
                metadata=metadata,
            )
        except Exception as e:
            logger.error(f"创建问题记录时出错: {str(e)}")
            return None

    @database_sync_to_async
    def create_answer_record(self, conversation_id, question_record, knowledge_bases, metadata):
        """Create the empty assistant row linked to ``question_record``.

        Returns the new ``ChatHistory``, or ``None`` on failure. Its content
        starts empty and is filled in once the stream has finished.
        """
        try:
            return ChatHistory.objects.create(
                user=self.scope["user"],
                knowledge_base=knowledge_bases[0],
                conversation_id=str(conversation_id),
                title=question_record.title,
                parent_id=str(question_record.id),
                role='assistant',
                content="",
                metadata=metadata,
            )
        except Exception as e:
            logger.error(f"创建回答记录时出错: {str(e)}")
            return None

    async def _finish_stream(self, conversation_id, question, answer_record, metadata, full_content):
        """Persist the collected answer and send the final ``is_end`` message."""
        final_content = full_content.strip()
        await self.update_answer_content(answer_record.id, final_content)
        title = await self.get_or_generate_title(conversation_id, question, final_content)
        await self.send_json({
            'code': 200,
            'message': '完成',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': str(conversation_id),
                'title': title,
                'dataset_id_list': metadata.get('dataset_id_list', []),
                'dataset_names': metadata.get('dataset_names', []),
                'role': 'assistant',
                'content': final_content,
                'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'is_end': True,
            },
        })

    async def stream_from_external_api(self, conversation_id, question, dataset_external_id_list, answer_record, metadata, knowledge_bases):
        """Relay the external API's streamed answer to the client.

        Opens a chat on the external API, posts ``question`` with
        ``stream=True`` and forwards every ``data: {...}`` event (events are
        separated by a blank line). HTTP or API-level failures are reported
        to the client; on an unexpected error the content collected so far is
        saved on ``answer_record``. ``knowledge_bases`` is accepted for
        interface compatibility and is not used here.
        """
        full_content = ""
        try:
            # UUID objects are not JSON serialisable; other values pass through.
            dataset_external_ids = [
                str(ext_id) if isinstance(ext_id, uuid.UUID) else ext_id
                for ext_id in dataset_external_id_list
            ]

            async with aiohttp.ClientSession() as session:
                # Step 1: open a chat session on the external API.
                async with session.post(
                    f"{settings.API_BASE_URL}/api/application/chat/open",
                    json={
                        "id": "d5d11efa-ea9a-11ef-9933-0242ac120006",
                        "model_id": metadata.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
                        "dataset_id_list": dataset_external_ids,
                        "multiple_rounds_dialogue": False,
                        "dataset_setting": {
                            "top_n": 10, "similarity": "0.3",
                            "max_paragraph_char_number": 10000,
                            "search_mode": "blend",
                            "no_references_setting": {
                                "value": "{question}",
                                "status": "ai_questioning"
                            }
                        },
                        "model_setting": {
                            "prompt": "**相关文档内容**{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**{question}"
                        },
                        "problem_optimization": False
                    }
                ) as chat_response:
                    if chat_response.status != 200:
                        error_msg = f"外部API调用失败: {await chat_response.text()}"
                        logger.error(error_msg)
                        await self.send_error(error_msg)
                        return
                    chat_data = await chat_response.json()

                if chat_data.get('code') != 200 or not chat_data.get('data'):
                    error_msg = f"外部API返回错误: {chat_data}"
                    logger.error(error_msg)
                    await self.send_error(error_msg)
                    return

                chat_id = chat_data['data']
                logger.info(f"成功创建聊天会话, chat_id: {chat_id}")

                # Step 2: post the question and consume the streamed reply.
                message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}"
                logger.info(f"开始流式请求: {message_url}")
                async with session.post(
                    url=message_url,
                    json={"message": question, "re_chat": False, "stream": True},
                    headers={"Content-Type": "application/json"}
                ) as message_request:
                    if message_request.status != 200:
                        error_msg = f"外部API聊天消息调用失败: {message_request.status}, {await message_request.text()}"
                        logger.error(error_msg)
                        await self.send_error(error_msg)
                        return

                    # Buffer raw bytes and decode only complete events: a
                    # network chunk may end in the middle of a multi-byte
                    # UTF-8 character, so decoding chunk by chunk can fail.
                    buffer = b""
                    logger.info("开始处理流式响应")
                    async for chunk in message_request.content.iter_any():
                        buffer += chunk
                        while b'\n\n' in buffer:
                            raw_event, buffer = buffer.split(b'\n\n', 1)
                            line = raw_event.decode('utf-8')
                            if not line.startswith('data: '):
                                continue
                            try:
                                data = json.loads(line[6:])  # drop "data: "
                            except json.JSONDecodeError as e:
                                logger.error(f"JSON解析错误: {e}, 数据: {line}")
                                continue
                            if not isinstance(data, dict):
                                continue

                            if 'content' in data:
                                content_part = data['content']
                                full_content += content_part
                                await self.send_json({
                                    'code': 200,
                                    'message': 'partial',
                                    'data': {
                                        'id': str(answer_record.id),
                                        'conversation_id': str(conversation_id),
                                        'content': content_part,
                                        'is_end': data.get('is_end', False),
                                    },
                                })

                            if data.get('is_end', False):
                                logger.info("收到流式响应结束标记")
                                await self._finish_stream(
                                    conversation_id, question, answer_record,
                                    metadata, full_content,
                                )
                                return

            # The upstream closed the stream without an end marker: still
            # persist what was received and tell the client the answer is
            # over, so the text is not lost and the client is not left waiting.
            logger.warning("流式响应在收到结束标记前结束")
            await self._finish_stream(
                conversation_id, question, answer_record, metadata, full_content
            )
        except Exception as e:
            logger.error(f"流式处理出错: {str(e)}")
            logger.error(traceback.format_exc())
            await self.send_error(str(e))
            # Best effort: keep whatever part of the answer was collected.
            if full_content:
                try:
                    await self.update_answer_content(answer_record.id, full_content.strip())
                except Exception as save_error:
                    logger.error(f"保存部分内容失败: {str(save_error)}")

    @database_sync_to_async
    def update_answer_content(self, answer_id, content):
        """Store ``content`` on the ``ChatHistory`` row ``answer_id``.

        Returns ``True`` on success and ``False`` (after logging) on failure.
        """
        try:
            answer_record = ChatHistory.objects.get(id=answer_id)
            answer_record.content = content
            answer_record.save()
            return True
        except Exception as e:
            logger.error(f"更新回答内容失败: {str(e)}")
            return False

    @database_sync_to_async
    def get_or_generate_title(self, conversation_id, question, answer):
        """Return the conversation's title, generating one if it has none.

        A title counts as missing when it is "New chat", "新对话" or empty.
        The generated title is the first 20 characters of ``question``
        followed by "..." (or the whole question when it is no longer than
        20 characters) and is written to every record of the conversation.
        ``answer`` is currently unused. Returns "新对话" on error.
        """
        try:
            current_title = ChatHistory.objects.filter(
                conversation_id=str(conversation_id)
            ).exclude(
                title__in=["New chat", "新对话", ""]
            ).values_list('title', flat=True).first()
            if current_title:
                return current_title

            # Simple truncation; could be replaced by a model-generated title.
            generated_title = question[:20] + "..." if len(question) > 20 else question
            ChatHistory.objects.filter(
                conversation_id=str(conversation_id)
            ).update(title=generated_title)
            return generated_title
        except Exception as e:
            logger.error(f"获取或生成标题失败: {str(e)}")
            return "新对话"

    async def send_json(self, content):
        """Serialise ``content`` to JSON and send it as a text frame."""
        await self.send(text_data=json.dumps(content))

    async def send_error(self, message):
        """Send an error envelope (code 500, ``is_end`` true) with ``message``."""
        await self.send_json({
            'code': 500,
            'message': message,
            'data': {'is_end': True}
        })