流式输出

This commit is contained in:
dspwasc 2025-04-29 10:21:13 +08:00
parent d769f814b4
commit 61eaec4d64
10 changed files with 1524 additions and 291 deletions

View File

@ -18,11 +18,16 @@ django.setup() # 添加这行来初始化 Django
from django.core.asgi import get_asgi_application
from channels.routing import ProtocolTypeRouter, URLRouter
from channels.auth import AuthMiddlewareStack
from channels.security.websocket import AllowedHostsOriginValidator
from user_management.routing import websocket_urlpatterns
from user_management.middleware import TokenAuthMiddleware
# Use TokenAuthMiddleware in place of AuthMiddlewareStack.
# NOTE(review): this diff hunk lists both sides of the change without +/-
# markers; the AuthMiddlewareStack line below is presumably the removed line
# and the AllowedHostsOriginValidator line its replacement — confirm against
# the repository before reusing this snippet.
application = ProtocolTypeRouter({
"http": get_asgi_application(),
"websocket": AuthMiddlewareStack(
"websocket": AllowedHostsOriginValidator(
TokenAuthMiddleware(
URLRouter(websocket_urlpatterns)
)
),
})

View File

@ -41,6 +41,9 @@ ALLOWED_HOSTS = ['*'] # 仅在开发环境使用
# 服务器配置
DEBUG = False
# 是否允许注册新用户
ALLOW_REGISTRATION = True
# ALLOWED_HOSTS = ['frptx.chiyong.fun', 'localhost', '127.0.0.1']
# Application definition
@ -70,6 +73,7 @@ MIDDLEWARE = [
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'user_management.middleware.UserActivityMiddleware',
'user_management.middleware.CSRFExemptMiddleware', # 添加CSRF豁免中间件
]
ROOT_URLCONF = 'role_based_system.urls'
@ -168,7 +172,12 @@ ASGI_APPLICATION = "role_based_system.asgi.application"
# Channel layer configuration.
# NOTE(review): two "BACKEND" entries appear because this diff hunk shows the
# old and the new line together; presumably the in-memory layer is the removed
# line and the Redis layer the added one — confirm against the repository.
CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels.layers.InMemoryChannelLayer",
"BACKEND": "channels_redis.core.RedisChannelLayer",
"CONFIG": {
"hosts": [("127.0.0.1", 6379)],
"capacity": 1500, # default is 100
"expiry": 60, # default is 60 seconds
},
},
}
@ -289,3 +298,9 @@ REST_FRAMEWORK = {
'rest_framework.parsers.MultiPartParser'
],
}
# DeepSeek / SiliconCloud API credentials.
#
# Secrets must never be committed to version control, so both keys are read
# from the process environment. The key that was previously hard-coded here
# has been published in the repository history and should be revoked.
import os

DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
# Both settings used to hold the same key, so fall back to the DeepSeek key
# when no dedicated SiliconCloud key is configured.
SILICON_CLOUD_API_KEY = os.environ.get("SILICON_CLOUD_API_KEY", DEEPSEEK_API_KEY)

View File

@ -4,6 +4,13 @@ from channels.db import database_sync_to_async
from channels.exceptions import StopConsumer
import logging
from rest_framework.authtoken.models import Token
from urllib.parse import parse_qs
from .models import ChatHistory, KnowledgeBase
import aiohttp
import asyncio
from django.conf import settings
import uuid
import traceback
logger = logging.getLogger(__name__)
@ -11,19 +18,20 @@ class NotificationConsumer(AsyncWebsocketConsumer):
async def connect(self):
"""建立WebSocket连接"""
try:
# 获取token
headers = dict(self.scope['headers'])
auth_header = headers.get(b'authorization', b'').decode()
# 从URL参数中获取token
query_string = self.scope.get('query_string', b'').decode()
query_params = parse_qs(query_string)
token_key = query_params.get('token', [''])[0]
if not auth_header.startswith('Token '):
if not token_key:
logger.warning("WebSocket连接尝试但没有提供token")
await self.close()
return
token_key = auth_header.split(' ')[1]
# 验证token
self.user = await self.get_user_from_token(token_key)
if not self.user:
logger.warning(f"WebSocket连接尝试但token无效: {token_key}")
await self.close()
return
@ -34,8 +42,10 @@ class NotificationConsumer(AsyncWebsocketConsumer):
self.channel_name
)
await self.accept()
logger.info(f"用户 {self.user.username} WebSocket连接成功")
except Exception as e:
logger.error(f"WebSocket连接错误: {str(e)}")
await self.close()
@database_sync_to_async
@ -65,3 +75,816 @@ class NotificationConsumer(AsyncWebsocketConsumer):
logger.info(f"已发送通知给用户 {self.user.username}")
except Exception as e:
logger.error(f"发送通知消息时发生错误: {str(e)}")
class ChatConsumer(AsyncWebsocketConsumer):
async def connect(self):
    """Authenticate the client by token and accept the WebSocket connection.

    The token is read from the ``token`` parameter of the connection URL's
    query string. The connection is closed without being accepted when the
    token is missing, does not resolve to a user, or an unexpected error
    occurs.
    """
    try:
        # Read the token from the URL query string.
        query_string = self.scope.get('query_string', b'').decode()
        query_params = parse_qs(query_string)
        token_key = query_params.get('token', [''])[0]
        if not token_key:
            logger.warning("WebSocket连接尝试但没有提供token")
            await self.close()
            return

        # Resolve the token to a user.
        self.user = await self.get_user_from_token(token_key)
        if not self.user:
            # Never write a full credential to the log; a short prefix is
            # enough to correlate the attempt.
            masked_token = token_key[:4] + "***"
            logger.warning(f"WebSocket连接尝试但token无效: {masked_token}")
            await self.close()
            return

        # Store the authenticated user in the scope for later handlers.
        self.scope["user"] = self.user
        await self.accept()
        logger.info(f"用户 {self.user.username} WebSocket连接成功")
    except Exception as e:
        logger.error(f"WebSocket连接错误: {str(e)}")
        await self.close()
@database_sync_to_async
def get_user_from_token(self, token_key):
    """Return the user that owns ``token_key``, or ``None`` if no such token exists."""
    try:
        matched_token = Token.objects.select_related('user').get(key=token_key)
    except Token.DoesNotExist:
        return None
    return matched_token.user
async def disconnect(self, close_code):
    """Handle the WebSocket being closed; this consumer has nothing to clean up."""
    return None
async def receive(self, text_data=None, bytes_data=None):
    """Handle one incoming WebSocket frame.

    The frame must be a text frame holding a JSON object with the keys
    ``question`` and ``conversation_id``. A question record is created and
    the answer is then streamed back to the client.

    Args:
        text_data: Payload of a text frame, or ``None`` for a binary frame.
        bytes_data: Payload of a binary frame. Accepted because the
            framework passes binary frames by this keyword; binary
            messages are rejected with an error message.
    """
    try:
        if text_data is None:
            await self.send_error("仅支持文本消息")
            return

        data = json.loads(text_data)

        # Validate the required fields; anything but a JSON object cannot
        # carry them.
        if not isinstance(data, dict) or 'question' not in data or 'conversation_id' not in data:
            await self.send_error("缺少必要字段")
            return

        # Store the question; the helper reports its own errors.
        question_record = await self.create_question_record(data)
        if not question_record:
            return

        # Start streaming the answer.
        await self.stream_answer(question_record, data)
    except Exception as e:
        logger.error(f"处理消息时出错: {str(e)}")
        await self.send_error(f"处理消息时出错: {str(e)}")
@database_sync_to_async
def _create_question_record_sync(self, data):
    """Create the ``ChatHistory`` row for the user's question.

    For an existing conversation the knowledge bases are taken from the
    metadata of the conversation's first record; for a new conversation
    they come from ``data['dataset_id_list']`` and fresh metadata is built.
    Every knowledge base must exist and pass the read-permission check.

    Args:
        data: Decoded client message; reads ``conversation_id``,
            ``question`` and optionally ``dataset_id_list``, ``model_id``
            and ``title``.

    Returns:
        The created ``ChatHistory`` instance, or the tuple
        ``(None, error_message)`` when anything fails (the caller
        distinguishes the two by type).
    """
    try:
        conversation_id = data['conversation_id']
        first_record = ChatHistory.objects.filter(
            conversation_id=conversation_id
        ).order_by('created_at').first()

        if first_record is not None:
            # Existing conversation: reuse the stored metadata.
            metadata = first_record.metadata or {}
            dataset_ids = metadata.get('dataset_id_list', [])
        else:
            # New conversation: metadata is built once the knowledge bases
            # have been validated.
            metadata = None
            dataset_ids = data.get('dataset_id_list', [])
            if not dataset_ids:
                raise Exception('新会话需要提供知识库ID')

        # Validate every knowledge base the same way for both branches.
        knowledge_bases = []
        for kb_id in dataset_ids:
            try:
                kb = KnowledgeBase.objects.get(id=kb_id)
            except KnowledgeBase.DoesNotExist:
                raise Exception(f'知识库不存在: {kb_id}')
            if not self.check_knowledge_base_permission(kb, self.scope["user"], 'read'):
                raise Exception(f'无权访问知识库: {kb.name}')
            knowledge_bases.append(kb)

        # Without this guard an empty list would surface to the client as
        # "list index out of range" from knowledge_bases[0] below.
        if not knowledge_bases:
            raise Exception('找不到会话关联的知识库信息')

        if metadata is None:
            metadata = {
                'model_id': data.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
                'dataset_id_list': [str(kb.id) for kb in knowledge_bases],
                'dataset_external_id_list': [str(kb.external_id) for kb in knowledge_bases if kb.external_id],
                'dataset_names': [kb.name for kb in knowledge_bases]
            }

        # The first knowledge base is stored as the record's primary one.
        return ChatHistory.objects.create(
            user=self.scope["user"],
            knowledge_base=knowledge_bases[0],
            conversation_id=conversation_id,
            title=data.get('title', 'New chat'),
            role='user',
            content=data['question'],
            metadata=metadata
        )
    except Exception as e:
        logger.error(f"创建问题记录失败: {str(e)}")
        return None, str(e)
async def create_question_record(self, data):
    """Create the question record, reporting any failure to the client.

    Returns the created record, or ``None`` after an error message has
    been sent (the synchronous helper signals failure with a tuple whose
    second item is the message).
    """
    try:
        outcome = await self._create_question_record_sync(data)
        if not isinstance(outcome, tuple):
            return outcome
        _unused, reason = outcome
        await self.send_error(reason)
    except Exception as exc:
        await self.send_error(str(exc))
    return None
def check_knowledge_base_permission(self, kb, user, permission_type):
    """Report whether ``user`` may perform ``permission_type`` on ``kb``.

    Placeholder: no check is performed yet and access is always granted.
    """
    # TODO: implement the real permission logic; temporarily allow everything.
    granted = True
    return granted
async def stream_answer(self, question_record, data):
    """Stream the assistant's answer for ``question_record`` to the client.

    Creates an empty assistant ``ChatHistory`` row, opens a chat session on
    the external API, then relays every ``data: ...`` event of the streamed
    reply to the WebSocket as a ``partial`` message. When the upstream
    marks the end of the answer, or closes the stream without doing so,
    the accumulated text is saved on the assistant row and a final message
    with the complete answer is sent.

    Args:
        question_record: The stored user question; supplies conversation
            id, title, knowledge base and metadata for the answer row.
        data: Decoded client message; ``data['question']`` is forwarded to
            the external API.

    Failures are reported through ``send_error``; text received before the
    failure is still saved on the answer row.
    """
    import codecs  # function-scope import: only this method needs it

    full_content = ""
    answer_record = None

    async def finish():
        """Save the collected answer and send the closing message."""
        final_text = full_content.strip()
        answer_record.content = final_text
        await database_sync_to_async(answer_record.save)()

        # Fetch the existing title or generate one.
        title = await self.get_or_generate_title(
            question_record.conversation_id,
            data['question'],
            final_text
        )

        await self.send_json({
            'code': 200,
            'message': '完成',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': str(question_record.conversation_id),
                'title': title,
                'dataset_id_list': question_record.metadata.get('dataset_id_list', []),
                'dataset_names': question_record.metadata.get('dataset_names', []),
                'role': 'assistant',
                'content': final_text,
                'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'is_end': True
            }
        })

    try:
        # Create the assistant's answer row (filled in once streaming ends).
        answer_record = await database_sync_to_async(ChatHistory.objects.create)(
            user=self.scope["user"],
            knowledge_base=question_record.knowledge_base,
            conversation_id=str(question_record.conversation_id),
            title=question_record.title,
            parent_id=str(question_record.id),
            role='assistant',
            content="",
            metadata=question_record.metadata
        )

        # Tell the client that streaming is about to start.
        await self.send_json({
            'code': 200,
            'message': '开始流式传输',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': str(question_record.conversation_id),
                'content': '',
                'is_end': False
            }
        })

        async with aiohttp.ClientSession() as session:
            # Open a chat session. `async with` releases the response once
            # its body has been read.
            async with session.post(
                f"{settings.API_BASE_URL}/api/application/chat/open",
                json={
                    "id": "d5d11efa-ea9a-11ef-9933-0242ac120006",
                    "model_id": question_record.metadata.get('model_id'),
                    "dataset_id_list": question_record.metadata.get('dataset_external_id_list', []),
                    "multiple_rounds_dialogue": False,
                    "dataset_setting": {
                        "top_n": 10,
                        "similarity": "0.3",
                        "max_paragraph_char_number": 10000,
                        "search_mode": "blend",
                        "no_references_setting": {
                            "value": "{question}",
                            "status": "ai_questioning"
                        }
                    },
                    "model_setting": {
                        "prompt": "**相关文档内容**{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**{question}"
                    },
                    "problem_optimization": False
                }
            ) as chat_response:
                chat_data = await chat_response.json()

            if chat_data.get('code') != 200:
                raise Exception(f"创建聊天会话失败: {chat_data}")
            chat_id = chat_data['data']

            # Open the streaming request.
            async with session.post(
                f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}",
                json={"message": data['question'], "re_chat": False, "stream": True},
                headers={"Content-Type": "application/json"}
            ) as response:
                # Network chunks can end in the middle of a multi-byte
                # UTF-8 character; the incremental decoder keeps the
                # dangling bytes until the next chunk arrives.
                decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
                buffer = ""
                async for chunk in response.content.iter_any():
                    buffer += decoder.decode(chunk)

                    # Events are separated by a blank line.
                    while '\n\n' in buffer:
                        line, buffer = buffer.split('\n\n', 1)
                        if not line.startswith('data: '):
                            continue
                        try:
                            chunk_data = json.loads(line[6:])  # drop "data: "
                        except json.JSONDecodeError as e:
                            logger.error(f"JSON解析错误: {e}, 数据: {line}")
                            continue
                        if not isinstance(chunk_data, dict):
                            continue

                        if 'content' in chunk_data:
                            content_part = chunk_data['content']
                            full_content += content_part
                            await self.send_json({
                                'code': 200,
                                'message': 'partial',
                                'data': {
                                    'id': str(answer_record.id),
                                    'conversation_id': str(question_record.conversation_id),
                                    'content': content_part,
                                    'is_end': chunk_data.get('is_end', False)
                                }
                            })

                        if chunk_data.get('is_end', False):
                            await finish()
                            return

        # The upstream closed the stream without an end marker: still save
        # what was received and release the waiting client.
        logger.warning("流式响应在收到结束标记前已关闭")
        await finish()
    except Exception as e:
        logger.error(f"流式处理出错: {str(e)}")
        await self.send_error(str(e))
        # Keep whatever part of the answer was already collected.
        if answer_record is not None and full_content:
            try:
                answer_record.content = full_content.strip()
                await database_sync_to_async(answer_record.save)()
            except Exception as save_error:
                logger.error(f"保存部分内容失败: {str(save_error)}")
@database_sync_to_async
def get_or_generate_title(self, conversation_id, question, answer):
    """Return the conversation's title, assigning a default one if needed.

    A title already stored on any record of the conversation (other than
    the placeholder titles) is returned as is. Otherwise the default title
    is written to every record of the conversation and returned. On any
    error the default title is returned without raising.
    """
    placeholder_titles = ["New chat", "新对话", ""]
    try:
        conversation_rows = ChatHistory.objects.filter(
            conversation_id=str(conversation_id)
        )

        # Look for a real (non-placeholder) title first.
        existing_title = (
            conversation_rows
            .exclude(title__in=placeholder_titles)
            .values_list('title', flat=True)
            .first()
        )
        if existing_title:
            return existing_title

        # Title generation is not implemented yet; use the default title
        # and apply it to every record of the conversation.
        generated_title = "新对话"
        ChatHistory.objects.filter(
            conversation_id=str(conversation_id)
        ).update(title=generated_title)
        return generated_title
    except Exception as e:
        logger.error(f"获取或生成标题失败: {str(e)}")
        return "新对话"
async def send_json(self, content):
    """Serialize ``content`` to JSON and send it as a text frame."""
    payload = json.dumps(content)
    await self.send(text_data=payload)
async def send_error(self, message):
    """Send an error message (code 500) that also marks the stream as ended."""
    error_payload = {
        'code': 500,
        'message': message,
        'data': {'is_end': True}
    }
    await self.send_json(error_payload)
class ChatStreamConsumer(AsyncWebsocketConsumer):
async def connect(self):
    """Authenticate the client by token and accept the WebSocket connection.

    The token is read from the ``token`` parameter of the connection URL's
    query string. The connection is closed without being accepted when the
    token is missing, does not resolve to a user, or an unexpected error
    occurs.
    """
    try:
        # Read the token from the URL query string.
        query_string = self.scope.get('query_string', b'').decode()
        query_params = parse_qs(query_string)
        token_key = query_params.get('token', [''])[0]
        if not token_key:
            logger.warning("WebSocket连接尝试但没有提供token")
            await self.close()
            return

        # Resolve the token to a user.
        self.user = await self.get_user_from_token(token_key)
        if not self.user:
            # Never write a full credential to the log; a short prefix is
            # enough to correlate the attempt.
            masked_token = token_key[:4] + "***"
            logger.warning(f"WebSocket连接尝试但token无效: {masked_token}")
            await self.close()
            return

        # Store the authenticated user in the scope for later handlers.
        self.scope["user"] = self.user
        await self.accept()
        logger.info(f"用户 {self.user.username} 流式输出WebSocket连接成功")
    except Exception as e:
        logger.error(f"WebSocket连接错误: {str(e)}")
        await self.close()
@database_sync_to_async
def get_user_from_token(self, token_key):
    """Look up the auth token and return its user; ``None`` when the token is unknown."""
    try:
        token_row = Token.objects.select_related('user').get(key=token_key)
    except Token.DoesNotExist:
        return None
    return token_row.user
async def disconnect(self, close_code):
    """Log that the WebSocket connection was closed.

    Args:
        close_code: The close code passed in by the framework.
    """
    # `self.user` is missing when connect() failed before the lookup and is
    # None when the token did not resolve to a user; both cases must not
    # raise here.
    user = getattr(self, 'user', None)
    username = user.username if user is not None else 'unknown'
    logger.info(f"用户 {username} WebSocket连接断开代码: {close_code}")
async def receive(self, text_data=None, bytes_data=None):
    """Handle one incoming WebSocket frame.

    The frame must be a text frame holding a JSON object with the keys
    ``question`` and ``conversation_id``; valid requests are handed to
    ``process_chat_request``.

    Args:
        text_data: Payload of a text frame, or ``None`` for a binary frame.
        bytes_data: Payload of a binary frame. Accepted because the
            framework passes binary frames by this keyword; binary
            messages are rejected with an error message.
    """
    try:
        if text_data is None:
            await self.send_error("仅支持文本消息")
            return

        data = json.loads(text_data)
        if not isinstance(data, dict):
            await self.send_error("消息格式错误: 需要JSON对象")
            return

        # Check the required fields.
        if 'question' not in data:
            await self.send_error("缺少必填字段: question")
            return
        if 'conversation_id' not in data:
            await self.send_error("缺少必填字段: conversation_id")
            return

        # Handle the request for a new or an existing conversation.
        await self.process_chat_request(data)
    except Exception as e:
        logger.error(f"处理消息时出错: {str(e)}")
        logger.error(traceback.format_exc())
        await self.send_error(f"处理消息时出错: {str(e)}")
async def process_chat_request(self, data):
    """Run one chat request: store the question, then stream the answer.

    Resolves the conversation's knowledge bases and metadata, creates the
    question and (empty) answer records, announces the stream to the
    client and finally relays the external API's streamed reply. Every
    failure is reported to the client with ``send_error`` so it is never
    left waiting for a reply.

    Args:
        data: Decoded client message containing ``conversation_id`` and
            ``question``.
    """
    try:
        conversation_id = data['conversation_id']
        question = data['question']

        # Resolve knowledge bases and metadata. The helper only logs the
        # precise reason, so the client gets a general message.
        session_info = await self.get_session_info(data)
        if not session_info:
            await self.send_error("获取会话信息失败: 知识库不存在、无权访问或参数缺失")
            return
        knowledge_bases, metadata, dataset_external_id_list = session_info

        # Store the user's question.
        question_record = await self.create_question_record(
            conversation_id,
            question,
            knowledge_bases,
            metadata
        )
        if not question_record:
            await self.send_error("创建问题记录失败")
            return

        # Create the (still empty) answer record; the helper returns None
        # on failure.
        answer_record = await self.create_answer_record(
            conversation_id,
            question_record,
            knowledge_bases,
            metadata
        )
        if not answer_record:
            await self.send_error("创建回答记录失败")
            return

        # Tell the client that streaming is about to start.
        await self.send_json({
            'code': 200,
            'message': '开始流式传输',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': str(conversation_id),
                'content': '',
                'is_end': False
            }
        })

        # Relay the streamed reply of the external API.
        await self.stream_from_external_api(
            conversation_id,
            question,
            dataset_external_id_list,
            answer_record,
            metadata,
            knowledge_bases
        )
    except Exception as e:
        logger.error(f"处理聊天请求时出错: {str(e)}")
        logger.error(traceback.format_exc())
        await self.send_error(f"处理聊天请求时出错: {str(e)}")
@database_sync_to_async
def get_session_info(self, data):
    """Resolve the knowledge bases and metadata for a chat request.

    For a conversation that already has records, the metadata of its first
    record supplies the knowledge-base ids and external ids. For a new
    conversation the ids come from ``data['dataset_id']`` or
    ``data['dataset_id_list']`` (a list, or a string that may hold a JSON
    list) and fresh metadata is built.

    Returns:
        ``(knowledge_bases, metadata, external_id_list)``, or ``None``
        when information is missing, a knowledge base does not exist or
        is not readable, or any error occurs (the reason is logged).
    """
    try:
        conversation_id = data['conversation_id']

        # Look up the records already stored for this conversation.
        history = ChatHistory.objects.filter(
            conversation_id=conversation_id
        ).order_by('created_at')

        if history.exists():
            # Existing conversation: use the first record's metadata.
            metadata = history.first().metadata or {}
            stored_ids = metadata.get('dataset_id_list', [])
            external_id_list = metadata.get('dataset_external_id_list', [])
            if not stored_ids:
                logger.error('找不到会话关联的知识库信息')
                return None

            # Every knowledge base must exist and be readable by the user.
            knowledge_bases = []
            for kb_id in stored_ids:
                try:
                    kb = KnowledgeBase.objects.get(id=kb_id)
                except KnowledgeBase.DoesNotExist:
                    logger.error(f'知识库不存在: {kb_id}')
                    return None
                if not self.check_knowledge_base_permission(kb, self.scope["user"], 'read'):
                    logger.error(f'无权访问知识库: {kb.name}')
                    return None
                knowledge_bases.append(kb)

            if not (external_id_list and knowledge_bases):
                logger.error('会话关联的知识库信息不完整')
                return None
            return knowledge_bases, metadata, external_id_list

        # New conversation: the request must name the knowledge bases.
        dataset_ids = []
        if 'dataset_id' in data:
            dataset_ids.append(str(data['dataset_id']))
        elif 'dataset_id_list' in data:
            raw_ids = data['dataset_id_list']
            if isinstance(raw_ids, str):
                # A string may hold a JSON list; anything else is treated
                # as a single id.
                try:
                    parsed = json.loads(raw_ids)
                except json.JSONDecodeError:
                    parsed = None
                if isinstance(parsed, list):
                    dataset_ids = [str(item) for item in parsed]
                else:
                    dataset_ids = [str(raw_ids)]
            elif isinstance(raw_ids, list):
                dataset_ids = [str(item) for item in raw_ids]

        if not dataset_ids:
            logger.error('新会话需要提供知识库ID')
            return None

        # Validate every knowledge base and collect the external ids.
        external_id_list = []
        knowledge_bases = []
        for kb_id in dataset_ids:
            try:
                knowledge_base = KnowledgeBase.objects.filter(id=kb_id).first()
                if not knowledge_base:
                    logger.error(f'知识库不存在: {kb_id}')
                    return None
                knowledge_bases.append(knowledge_base)

                if not self.check_knowledge_base_permission(knowledge_base, self.scope["user"], 'read'):
                    logger.error(f'无权访问知识库: {knowledge_base.name}')
                    return None

                if knowledge_base.external_id:
                    external_id_list.append(str(knowledge_base.external_id))
                else:
                    logger.warning(f"知识库 {knowledge_base.id} ({knowledge_base.name}) 没有external_id")
            except Exception as e:
                logger.error(f"处理知识库ID出错: {str(e)}")
                return None

        if not external_id_list:
            logger.error('没有有效的知识库external_id')
            return None

        # Build the metadata stored with the conversation's records.
        metadata = {
            'model_id': data.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
            'dataset_id_list': [str(item) for item in dataset_ids],
            'dataset_external_id_list': [str(item) for item in external_id_list],
            'dataset_names': [kb.name for kb in knowledge_bases]
        }
        return knowledge_bases, metadata, external_id_list
    except Exception as e:
        logger.error(f"获取会话信息时出错: {str(e)}")
        return None
def check_knowledge_base_permission(self, kb, user, permission_type):
    """Tell whether ``user`` has ``permission_type`` access to ``kb``.

    Stub implementation: the arguments are not inspected and ``True`` is
    always returned.
    """
    # TODO: replace with the real permission check; allows everything for now.
    is_allowed = True
    return is_allowed
@database_sync_to_async
def create_question_record(self, conversation_id, question, knowledge_bases, metadata):
    """Store the user's question as a ``ChatHistory`` row.

    The first entry of ``knowledge_bases`` is recorded as the row's
    knowledge base and the title comes from ``metadata['title']`` (default
    ``'New chat'``). Returns the created row, or ``None`` after logging if
    anything fails.
    """
    try:
        record_fields = {
            'user': self.scope["user"],
            'knowledge_base': knowledge_bases[0],  # first one is the primary knowledge base
            'conversation_id': str(conversation_id),
            'title': metadata.get('title', 'New chat'),
            'role': 'user',
            'content': question,
            'metadata': metadata,
        }
        return ChatHistory.objects.create(**record_fields)
    except Exception as e:
        logger.error(f"创建问题记录时出错: {str(e)}")
        return None
@database_sync_to_async
def create_answer_record(self, conversation_id, question_record, knowledge_bases, metadata):
    """Create the assistant's answer row, initially with empty content.

    The row is linked to ``question_record`` through ``parent_id`` and
    shares its title. Returns the created row, or ``None`` after logging
    if anything fails.
    """
    try:
        record_fields = {
            'user': self.scope["user"],
            'knowledge_base': knowledge_bases[0],
            'conversation_id': str(conversation_id),
            'title': question_record.title,
            'parent_id': str(question_record.id),
            'role': 'assistant',
            'content': "",  # filled in once streaming has finished
            'metadata': metadata,
        }
        return ChatHistory.objects.create(**record_fields)
    except Exception as e:
        logger.error(f"创建回答记录时出错: {str(e)}")
        return None
async def stream_from_external_api(self, conversation_id, question, dataset_external_id_list, answer_record, metadata, knowledge_bases):
    """Relay the external API's streamed answer to the WebSocket client.

    Opens a chat session on the external API, posts the question with
    streaming enabled and forwards every ``data: ...`` event as a
    ``partial`` message. When the upstream marks the end of the answer, or
    closes the stream without doing so, the accumulated text is saved on
    the answer record and a final message with the complete answer is
    sent.

    Args:
        conversation_id: Conversation the answer belongs to.
        question: The user's question text.
        dataset_external_id_list: External knowledge-base ids to search.
        answer_record: The assistant record created for this answer.
        metadata: Conversation metadata (model id, dataset ids and names).
        knowledge_bases: Not used by this method; kept for the callers.

    Failures are reported through ``send_error``; text received before the
    failure is still saved.
    """
    import codecs  # function-scope import: only this method needs it

    # Collected answer text, saved once streaming ends.
    full_content = ""

    async def finish():
        """Save the collected answer and send the closing message."""
        final_text = full_content.strip()
        await self.update_answer_content(answer_record.id, final_text)

        title = await self.get_or_generate_title(
            conversation_id,
            question,
            final_text
        )

        await self.send_json({
            'code': 200,
            'message': '完成',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': str(conversation_id),
                'title': title,
                'dataset_id_list': metadata.get('dataset_id_list', []),
                'dataset_names': metadata.get('dataset_names', []),
                'role': 'assistant',
                'content': final_text,
                'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'is_end': True
            }
        })

    try:
        # UUID objects are not JSON serializable; send them as strings.
        dataset_external_ids = [
            str(external_id) if isinstance(external_id, uuid.UUID) else external_id
            for external_id in dataset_external_id_list
        ]

        async with aiohttp.ClientSession() as session:
            # Step 1: open a chat session.
            async with session.post(
                f"{settings.API_BASE_URL}/api/application/chat/open",
                json={
                    "id": "d5d11efa-ea9a-11ef-9933-0242ac120006",
                    "model_id": metadata.get('model_id', '7a214d0e-e65e-11ef-9f4a-0242ac120006'),
                    "dataset_id_list": dataset_external_ids,
                    "multiple_rounds_dialogue": False,
                    "dataset_setting": {
                        "top_n": 10, "similarity": "0.3",
                        "max_paragraph_char_number": 10000,
                        "search_mode": "blend",
                        "no_references_setting": {
                            "value": "{question}",
                            "status": "ai_questioning"
                        }
                    },
                    "model_setting": {
                        "prompt": "**相关文档内容**{data} **回答要求**:如果相关文档内容中没有可用信息,请回答\"没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作\"。请根据相关文档内容回答用户问题。不要输出与用户问题无关的内容。请使用中文回答客户问题。**用户问题**{question}"
                    },
                    "problem_optimization": False
                }
            ) as chat_response:
                if chat_response.status != 200:
                    error_msg = f"外部API调用失败: {await chat_response.text()}"
                    logger.error(error_msg)
                    await self.send_error(error_msg)
                    return

                chat_data = await chat_response.json()
                if chat_data.get('code') != 200 or not chat_data.get('data'):
                    error_msg = f"外部API返回错误: {chat_data}"
                    logger.error(error_msg)
                    await self.send_error(error_msg)
                    return

                chat_id = chat_data['data']
                logger.info(f"成功创建聊天会话, chat_id: {chat_id}")

            # Step 2: open the streaming request.
            message_url = f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}"
            logger.info(f"开始流式请求: {message_url}")

            async with session.post(
                url=message_url,
                json={"message": question, "re_chat": False, "stream": True},
                headers={"Content-Type": "application/json"}
            ) as message_request:
                if message_request.status != 200:
                    error_msg = f"外部API聊天消息调用失败: {message_request.status}, {await message_request.text()}"
                    logger.error(error_msg)
                    await self.send_error(error_msg)
                    return

                # Network chunks can end in the middle of a multi-byte
                # UTF-8 character; the incremental decoder keeps the
                # dangling bytes until the next chunk arrives.
                decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
                # Holds text until a complete event has been received.
                buffer = ""

                logger.info("开始处理流式响应")
                async for chunk in message_request.content.iter_any():
                    buffer += decoder.decode(chunk)

                    # Events are separated by a blank line.
                    while '\n\n' in buffer:
                        line, buffer = buffer.split('\n\n', 1)
                        if not line.startswith('data: '):
                            continue
                        try:
                            event = json.loads(line[6:])  # drop the "data: " prefix
                        except json.JSONDecodeError as e:
                            logger.error(f"JSON解析错误: {e}, 数据: {line}")
                            continue
                        if not isinstance(event, dict):
                            continue

                        # Forward the partial content.
                        if 'content' in event:
                            content_part = event['content']
                            full_content += content_part
                            await self.send_json({
                                'code': 200,
                                'message': 'partial',
                                'data': {
                                    'id': str(answer_record.id),
                                    'conversation_id': str(conversation_id),
                                    'content': content_part,
                                    'is_end': event.get('is_end', False)
                                }
                            })

                        # End marker: save and send the final message.
                        if event.get('is_end', False):
                            logger.info("收到流式响应结束标记")
                            await finish()
                            return

        # The upstream closed the stream without an end marker: still save
        # what was received and release the waiting client.
        logger.warning("流式响应在收到结束标记前已关闭")
        await finish()
    except Exception as e:
        logger.error(f"流式处理出错: {str(e)}")
        logger.error(traceback.format_exc())
        await self.send_error(str(e))
        # Keep whatever part of the answer was already collected.
        if full_content:
            try:
                await self.update_answer_content(answer_record.id, full_content.strip())
            except Exception as save_error:
                logger.error(f"保存部分内容失败: {str(save_error)}")
@database_sync_to_async
def update_answer_content(self, answer_id, content):
    """Write ``content`` to the answer record with id ``answer_id``.

    Returns ``True`` on success and ``False`` (after logging) when the
    record cannot be loaded or saved.
    """
    try:
        target = ChatHistory.objects.get(id=answer_id)
        target.content = content
        target.save()
    except Exception as e:
        logger.error(f"更新回答内容失败: {str(e)}")
        return False
    return True
@database_sync_to_async
def get_or_generate_title(self, conversation_id, question, answer):
    """Return the conversation's title, deriving one from the question if needed.

    A title already stored on any record of the conversation (other than
    the placeholder titles) is returned as is. Otherwise the title is the
    question itself, cut to 20 characters plus ``"..."`` when longer; it
    is written to every record of the conversation and returned. On any
    error ``"新对话"`` is returned without raising.
    """
    placeholder_titles = ["New chat", "新对话", ""]
    try:
        # Look for a real (non-placeholder) title first.
        stored_title = (
            ChatHistory.objects
            .filter(conversation_id=str(conversation_id))
            .exclude(title__in=placeholder_titles)
            .values_list('title', flat=True)
            .first()
        )
        if stored_title:
            return stored_title

        # Simple title generation (could be replaced by a DeepSeek API call).
        if len(question) > 20:
            generated_title = question[:20] + "..."
        else:
            generated_title = question

        # Apply the title to every record of the conversation.
        ChatHistory.objects.filter(
            conversation_id=str(conversation_id)
        ).update(title=generated_title)
        return generated_title
    except Exception as e:
        logger.error(f"获取或生成标题失败: {str(e)}")
        return "新对话"
async def send_json(self, content):
    """Encode ``content`` as JSON and send it to the client as a text frame."""
    encoded = json.dumps(content)
    await self.send(text_data=encoded)
async def send_error(self, message):
    """Report ``message`` to the client as a code-500 message that ends the stream."""
    failure = {
        'code': 500,
        'message': message,
        'data': {'is_end': True}
    }
    await self.send_json(failure)

View File

@ -6,23 +6,60 @@ from rest_framework.authtoken.models import Token
User = get_user_model()
class Command(BaseCommand):
help = '创建测试用户:1个管理员2个组长4个组员'
help = '创建测试用户:4个管理员7个组长4个组员'
def handle(self, *args, **kwargs):
# 创建管理员 - 技术部管理员
admin, created = User.objects.get_or_create(
username='admin',
defaults={
'email': 'admin@example.com',
'name': '张管理',
# 创建管理员 - 4个管理员
admins = [
{
'username': 'admin1',
'password': 'admin123',
'email': 'admin1@example.com',
'name': '张技术管理',
'department': '技术部门',
'role': 'admin',
},
{
'username': 'admin2',
'password': 'admin123',
'email': 'admin2@example.com',
'name': '王产品管理',
'department': '产品部门',
'role': 'admin',
},
{
'username': 'admin3',
'password': 'admin123',
'email': 'admin3@example.com',
'name': '李商务管理',
'department': '商务部门',
'role': 'admin',
},
{
'username': 'admin4',
'password': 'admin123',
'email': 'admin4@example.com',
'name': '赵HR管理',
'department': 'HR',
'role': 'admin',
}
]
for admin_data in admins:
admin, created = User.objects.get_or_create(
username=admin_data['username'],
defaults={
'email': admin_data['email'],
'name': admin_data['name'],
'role': admin_data['role'],
'department': admin_data['department'],
'is_staff': True,
'is_superuser': True,
'last_login': timezone.now()
}
)
if created:
admin.set_password('admin123')
admin.set_password(admin_data['password'])
admin.save()
token = Token.objects.create(user=admin)
self.stdout.write(self.style.SUCCESS(
@ -31,22 +68,62 @@ class Command(BaseCommand):
else:
self.stdout.write(self.style.WARNING(f'管理员用户已存在: {admin.username}'))
# 创建组长 - 研发部组长和测试部组长
# 创建组长 - 7个部门的组长
leaders = [
{
'username': 'leader1',
'password': 'leader123',
'email': 'leader1@example.com',
'name': '李研发',
'department': '研发部',
'name': '陈达人',
'department': '达人部门',
'role': 'leader'
},
{
'username': 'leader2',
'password': 'leader123',
'email': 'leader2@example.com',
'name': '王测试',
'department': '测试部',
'name': '刘商务',
'department': '商务部门',
'role': 'leader'
},
{
'username': 'leader3',
'password': 'leader123',
'email': 'leader3@example.com',
'name': '杨样本',
'department': '样本中心',
'role': 'leader'
},
{
'username': 'leader4',
'password': 'leader123',
'email': 'leader4@example.com',
'name': '黄产品',
'department': '产品部门',
'role': 'leader'
},
{
'username': 'leader5',
'password': 'leader123',
'email': 'leader5@example.com',
'name': '周AI',
'department': 'AI自媒体',
'role': 'leader'
},
{
'username': 'leader6',
'password': 'leader123',
'email': 'leader6@example.com',
'name': '吴HR',
'department': 'HR',
'role': 'leader'
},
{
'username': 'leader7',
'password': 'leader123',
'email': 'leader7@example.com',
'name': '郑技术',
'department': '技术部门',
'role': 'leader'
}
]
@ -73,67 +150,5 @@ class Command(BaseCommand):
else:
self.stdout.write(self.style.WARNING(f'组长用户已存在: {leader.username}'))
# 创建组员 - 2个开发组员2个测试组员
members = [
{
'username': 'member1',
'password': 'member123',
'email': 'member1@example.com',
'name': '赵开发',
'department': '研发部',
'role': 'member',
'group': '前端组'
},
{
'username': 'member2',
'password': 'member123',
'email': 'member2@example.com',
'name': '钱开发',
'department': '研发部',
'role': 'member',
'group': '后端组'
},
{
'username': 'member3',
'password': 'member123',
'email': 'member3@example.com',
'name': '孙测试',
'department': '测试部',
'role': 'member',
'group': '功能测试组'
},
{
'username': 'member4',
'password': 'member123',
'email': 'member4@example.com',
'name': '周测试',
'department': '测试部',
'role': 'member',
'group': '自动化测试组'
}
]
for member_data in members:
member, created = User.objects.get_or_create(
username=member_data['username'],
defaults={
'email': member_data['email'],
'name': member_data['name'],
'role': member_data['role'],
'department': member_data['department'],
'group': member_data['group'],
'is_staff': False,
'last_login': timezone.now()
}
)
if created:
member.set_password(member_data['password'])
member.save()
token = Token.objects.create(user=member)
self.stdout.write(self.style.SUCCESS(
f'成功创建组员用户: {member.username}{member.name}, Token: {token.key}'
))
else:
self.stdout.write(self.style.WARNING(f'组员用户已存在: {member.username}'))
self.stdout.write(self.style.SUCCESS('所有测试用户创建完成!'))

View File

@ -4,6 +4,10 @@ from django.contrib.auth.models import AnonymousUser
from rest_framework.authtoken.models import Token
from django.contrib.auth import get_user_model
import logging
import re
from django.middleware.csrf import CsrfViewMiddleware
from django.conf import settings
from django.utils.deprecation import MiddlewareMixin
logger = logging.getLogger(__name__)
@ -44,11 +48,28 @@ class TokenAuthMiddleware(BaseMiddleware):
scope['user'] = AnonymousUser()
return await super().__call__(scope, receive, send)
class UserActivityMiddleware:
"""用户活动中间件"""
def __init__(self, get_response):
self.get_response = get_response
class UserActivityMiddleware(MiddlewareMixin):
"""中间件用于记录用户活动"""
def __call__(self, request):
response = self.get_response(request)
return response
def process_request(self, request):
# 可以在这里记录用户活动日志
pass
class CSRFExemptMiddleware(MiddlewareMixin):
    """Exempt configured URL paths from CSRF enforcement.

    ``settings.CSRF_EXEMPT_URLS`` is an optional iterable of regular
    expressions. Each is matched with ``re.match`` against the request
    path with its leading slashes removed; on the first match the request
    is flagged with ``_dont_enforce_csrf_checks``.

    The flag is set in ``process_request`` because a ``process_view`` hook
    only runs after the ``process_view`` hooks of the middleware listed
    before it.
    NOTE(review): this middleware is appended at the end of ``MIDDLEWARE``,
    so a CSRF middleware listed earlier would presumably have finished its
    check before the original ``process_view`` set the flag — confirm the
    ordering in settings.
    """

    def process_request(self, request):
        """Flag exempt requests before any view middleware runs."""
        self._mark_if_exempt(request)
        return None

    def process_view(self, request, callback, callback_args, callback_kwargs):
        """Flag exempt requests; kept so existing callers of this hook still work."""
        self._mark_if_exempt(request)
        return None

    @staticmethod
    def _mark_if_exempt(request):
        """Set the exemption flag when the request path matches a configured pattern."""
        exempt_patterns = getattr(settings, 'CSRF_EXEMPT_URLS', None)
        if not exempt_patterns:
            return
        # Patterns are written without the leading slash.
        path = request.path_info.lstrip('/')
        if any(re.match(pattern, path) for pattern in exempt_patterns):
            setattr(request, '_dont_enforce_csrf_checks', True)

View File

@ -0,0 +1,18 @@
# Generated by Django 5.1.5 on 2025-04-23 14:20
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add the ``title`` field to the ``chathistory`` model."""

    dependencies = [('user_management', '0004_knowledgebasedocument')]

    operations = [
        migrations.AddField(
            model_name='chathistory',
            name='title',
            field=models.CharField(
                blank=True,
                default='New chat',
                help_text='对话标题',
                max_length=100,
                null=True,
            ),
        ),
    ]

View File

@ -0,0 +1,18 @@
# Generated by Django 5.1.5 on 2025-04-23 16:51
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add the ``uploader_name`` field to the ``knowledgebasedocument`` model."""

    dependencies = [('user_management', '0005_chathistory_title')]

    operations = [
        migrations.AddField(
            model_name='knowledgebasedocument',
            name='uploader_name',
            field=models.CharField(
                default='未知用户',
                max_length=100,
                verbose_name='上传者姓名',
            ),
        ),
    ]

View File

@ -291,6 +291,8 @@ class ChatHistory(models.Model):
knowledge_base = models.ForeignKey('KnowledgeBase', on_delete=models.CASCADE)
# 用于标识知识库组合的对话
conversation_id = models.CharField(max_length=100, db_index=True)
# 对话标题
title = models.CharField(max_length=100, null=True, blank=True, default='New chat', help_text="对话标题")
parent_id = models.CharField(max_length=100, null=True, blank=True)
role = models.CharField(max_length=20, choices=ROLE_CHOICES)
content = models.TextField()
@ -391,6 +393,7 @@ class ChatHistory(models.Model):
]
}
class UserProfile(models.Model):
"""用户档案模型"""
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile')
@ -682,6 +685,7 @@ class KnowledgeBaseDocument(models.Model):
document_id = models.CharField(max_length=100, verbose_name='文档ID')
document_name = models.CharField(max_length=255, verbose_name='文档名称')
external_id = models.CharField(max_length=100, verbose_name='外部文档ID')
uploader_name = models.CharField(max_length=100, default="未知用户", verbose_name='上传者姓名')
status = models.CharField(
max_length=20,
default='active',

View File

@ -3,4 +3,6 @@ from . import consumers
# WebSocket routes: each URL pattern is mapped to the ASGI application of the
# consumer class that handles it.
websocket_urlpatterns = [
re_path(r'ws/notifications/$', consumers.NotificationConsumer.as_asgi()),
re_path(r'ws/chat/$', consumers.ChatConsumer.as_asgi()),
re_path(r'ws/chat/stream/$', consumers.ChatStreamConsumer.as_asgi()),
]

View File

@ -253,87 +253,136 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
queryset = ChatHistory.objects.all()
def get_queryset(self):
"""确保用户只能看到自己的未删除的聊天记录"""
return ChatHistory.objects.filter(
user=self.request.user,
"""确保用户只能看到自己的未删除的聊天记录以及有权限的知识库关联的聊天记录"""
user = self.request.user
# 当前用户的聊天记录
user_records = ChatHistory.objects.filter(
user=user,
is_deleted=False
)
def list(self, request):
"""获取对话列表概览"""
try:
# 获取查询参数
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10))
# 获取用户有权限的知识库ID列表
accessible_kb_ids = []
for kb in KnowledgeBase.objects.all():
if self.check_knowledge_base_permission(kb, user, 'read'):
accessible_kb_ids.append(kb.id)
# 获取所有对话的概览
latest_chats = self.get_queryset().values(
'conversation_id'
).annotate(
latest_id=Max('id'),
message_count=Count('id'),
last_message=Max('created_at')
# 其他用户创建的、但当前用户有权限访问的知识库的聊天记录
others_records = ChatHistory.objects.filter(
knowledge_base_id__in=accessible_kb_ids,
is_deleted=False
).exclude(user=user) # 排除用户自己的记录,避免重复
# 合并两个查询集
combined_queryset = user_records | others_records
return combined_queryset
def list(self, request):
"""获取对话列表"""
try:
# 获取用户所有的对话
unique_conversations = self.get_queryset().values('conversation_id').annotate(
last_message=Max('created_at'),
message_count=Count('id')
).order_by('-last_message')
# 计算分页
total = latest_chats.count()
start = (page - 1) * page_size
end = start + page_size
chats = latest_chats[start:end]
# 构建结果列表
conversation_list = []
for conv in unique_conversations:
# 获取对话中的第一条消息,用于显示标题和知识库信息
first_message = self.get_queryset().filter(
conversation_id=conv['conversation_id']
).order_by('created_at').first()
results = []
for chat in chats:
# 获取最新消息记录
latest_record = ChatHistory.objects.get(id=chat['latest_id'])
if not first_message:
continue
# 从metadata中获取完整的知识库信息
# 获取知识库信息
dataset_info = []
if latest_record.metadata:
dataset_id_list = latest_record.metadata.get('dataset_id_list', [])
dataset_names = latest_record.metadata.get('dataset_names', [])
# 如果有知识库ID列表
if dataset_id_list:
# 如果同时有名称列表且长度匹配
if dataset_names and len(dataset_names) == len(dataset_id_list):
dataset_info = [{
'id': str(id),
'name': name
} for id, name in zip(dataset_id_list, dataset_names)]
else:
# 如果没有名称列表则只返回ID
datasets = KnowledgeBase.objects.filter(id__in=dataset_id_list)
dataset_info = [{
'id': str(ds.id),
'name': ds.name
} for ds in datasets]
results.append({
'conversation_id': chat['conversation_id'],
'message_count': chat['message_count'],
'last_message': latest_record.content,
'last_time': chat['last_message'].strftime('%Y-%m-%d %H:%M:%S'),
'dataset_id_list': [ds['id'] for ds in dataset_info], # 添加完整的知识库ID列表
'datasets': dataset_info # 包含ID和名称的完整信息
if first_message.metadata and 'dataset_id_list' in first_message.metadata:
# 获取用户有权限访问的知识库
valid_kb_ids = []
for kb_id in first_message.metadata['dataset_id_list']:
try:
kb = KnowledgeBase.objects.get(id=kb_id)
if self.check_knowledge_base_permission(kb, request.user, 'read'):
valid_kb_ids.append(kb_id)
dataset_info.append({
'id': str(kb.id),
'name': kb.name,
'type': kb.type
})
except KnowledgeBase.DoesNotExist:
continue
# 获取最近的消息用于预览
last_user_message = self.get_queryset().filter(
conversation_id=conv['conversation_id'],
role='user'
).order_by('-created_at').first()
# 处理对话标题 - 优先使用已有标题,否则尝试生成新标题
title = first_message.title
# 如果标题为空或为默认值'New chat',尝试生成新标题
if not title or title == 'New chat':
# 找到对话中的第一对问答
messages = list(self.get_queryset().filter(
conversation_id=conv['conversation_id']
).order_by('created_at'))
user_message = None
assistant_message = None
for i in range(len(messages)-1):
if messages[i].role == 'user' and messages[i+1].role == 'assistant' and messages[i+1].parent_id == str(messages[i].id):
user_message = messages[i]
assistant_message = messages[i+1]
break
if user_message and assistant_message:
# 调用DeepSeek API生成标题
generated_title = self._generate_conversation_title_from_deepseek(
user_message.content,
assistant_message.content
)
if generated_title:
# 更新所有相关记录的标题
title = generated_title
ChatHistory.objects.filter(
conversation_id=conv['conversation_id']
).update(title=generated_title)
# 如果生成失败使用对话ID的一部分作为临时标题
if not title:
title = f"对话 {conv['conversation_id'][:8]}"
# 构建返回结果
conversation_data = {
'conversation_id': conv['conversation_id'],
'last_message': conv['last_message'].strftime('%Y-%m-%d %H:%M:%S'),
'message_count': conv['message_count'],
'title': title,
'preview': last_user_message.content[:100] if last_user_message else "",
'datasets': dataset_info
}
conversation_list.append(conversation_data)
# 返回结果
return Response({
'code': 200,
'message': '获取成功',
'data': {
'total': total,
'page': page,
'page_size': page_size,
'results': results
}
'data': conversation_list
})
except Exception as e:
logger.error(f"获取聊天记录失败: {str(e)}")
logger.error(f"获取对话列表失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'code': 500,
'message': f'获取聊天记录失败: {str(e)}',
'message': f"获取对话列表失败: {str(e)}",
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@ -349,7 +398,7 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 获取对话历史
# 获取对话历史,确保按时间顺序排序
messages = self.get_queryset().filter(
conversation_id=conversation_id
).order_by('created_at')
@ -380,18 +429,60 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
'type': ds.type
} for ds in accessible_datasets]
# 处理对话标题 - 优先使用已有标题,否则尝试生成新标题
title = first_message.title
# 如果标题为空或为默认值'New chat',尝试生成新标题
if not title or title == 'New chat':
# 尝试找到一对完整的问答
user_message = None
assistant_message = None
for i in range(len(messages)-1):
if messages[i].role == 'user' and messages[i+1].role == 'assistant' and messages[i+1].parent_id == str(messages[i].id):
user_message = messages[i]
assistant_message = messages[i+1]
break
if user_message and assistant_message:
# 调用DeepSeek API生成标题
generated_title = self._generate_conversation_title_from_deepseek(
user_message.content,
assistant_message.content
)
if generated_title:
# 更新所有相关记录的标题
title = generated_title
ChatHistory.objects.filter(
conversation_id=conversation_id
).update(title=generated_title)
# 如果生成失败使用对话ID的一部分作为临时标题
if not title:
title = f"对话 {conversation_id[:8]}"
# 构建消息列表包含parent_id信息
message_list = []
for msg in messages:
message_data = {
'id': str(msg.id),
'parent_id': msg.parent_id, # 添加parent_id
'role': msg.role,
'content': msg.content,
'created_at': msg.created_at.strftime('%Y-%m-%d %H:%M:%S'),
'metadata': msg.metadata # 添加metadata
}
message_list.append(message_data)
return Response({
'code': 200,
'message': '获取成功',
'data': {
'conversation_id': conversation_id,
'title': title, # 返回标题
'datasets': dataset_info,
'messages': [{
'id': str(msg.id),
'role': msg.role,
'content': msg.content,
'created_at': msg.created_at.strftime('%Y-%m-%d %H:%M:%S')
} for msg in messages]
'messages': message_list
}
})
@ -512,6 +603,9 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
conversation_id = str(uuid.uuid4())
logger.info(f"创建新的会话ID: {conversation_id}")
# 获取自定义标题(如果有)
title = data.get('title', 'New chat')
# 准备metadata (仍然保存知识库名称用于内部处理)
metadata = {
'dataset_id_list': [str(id) for id in dataset_ids],
@ -523,6 +617,7 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
'message': '会话创建成功',
'data': {
'conversation_id': conversation_id,
'title': title, # 添加标题字段
'dataset_id_list': metadata['dataset_id_list']
}
})
@ -681,11 +776,15 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
'dataset_names': [kb.name for kb in knowledge_bases]
}
# 检查是否有自定义标题
title = data.get('title', 'New chat')
# 创建用户问题记录
question_record = ChatHistory.objects.create(
user=request.user,
knowledge_base=knowledge_bases[0], # 使用第一个知识库作为主知识库
conversation_id=str(conversation_id),
title=title, # 设置标题
role='user',
content=data['question'],
metadata=metadata
@ -696,7 +795,7 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
if use_stream:
# 创建流式响应
return StreamingHttpResponse(
response = StreamingHttpResponse(
self._stream_answer_from_external_api(
conversation_id=str(conversation_id),
question_record=question_record,
@ -705,8 +804,15 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
question=data['question'],
metadata=metadata
),
content_type='text/event-stream'
content_type='text/event-stream',
status=status.HTTP_201_CREATED # 修改状态码为201
)
# 添加禁用缓存的头部
response['Cache-Control'] = 'no-cache, no-store'
response['Connection'] = 'keep-alive'
return response
else:
# 使用非流式输出
logger.info("使用非流式输出模式")
@ -725,25 +831,45 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
user=request.user,
knowledge_base=knowledge_bases[0],
conversation_id=str(conversation_id),
title=title, # 设置标题
parent_id=str(question_record.id),
role='assistant',
content=answer,
metadata=metadata
)
# 如果是新会话的第一条消息,并且没有自定义标题,则自动生成标题
should_generate_title = not existing_records.exists() and (not title or title == 'New chat')
if should_generate_title:
try:
generated_title = self._generate_conversation_title_from_deepseek(
data['question'],
answer
)
if generated_title:
# 更新所有相关记录的标题
ChatHistory.objects.filter(
conversation_id=str(conversation_id)
).update(title=generated_title)
title = generated_title
except Exception as e:
logger.error(f"自动生成标题失败: {str(e)}")
# 继续执行,不影响主流程
return Response({
'code': 200,
'code': 200, # 修改状态码为201
'message': '成功',
'data': {
'id': str(answer_record.id),
'conversation_id': str(conversation_id),
'title': title, # 添加标题字段
'dataset_id_list': metadata.get('dataset_id_list', []),
'dataset_names': metadata.get('dataset_names', []),
'role': 'assistant',
'content': answer,
'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S')
}
})
}, status=status.HTTP_200_CREATED) # 修改状态码为201
except Exception as e:
logger.error(f"创建聊天记录失败: {str(e)}")
@ -760,11 +886,15 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
# 确保所有ID都是字符串
dataset_external_ids = [str(id) if isinstance(id, uuid.UUID) else id for id in dataset_external_id_list]
# 获取标题
title = question_record.title or 'New chat'
# 创建AI回答记录对象稍后更新内容
answer_record = ChatHistory.objects.create(
user=question_record.user,
knowledge_base=knowledge_bases[0],
conversation_id=str(conversation_id),
title=title, # 设置标题
parent_id=str(question_record.id),
role='assistant',
content="", # 初始内容为空
@ -803,7 +933,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
},
"problem_optimization": False
},
headers={"Content-Type": "application/json"}
headers={"Content-Type": "application/json"},
)
if chat_response.status_code != 200:
@ -831,7 +962,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
url=message_url,
json={"message": question, "re_chat": False, "stream": True},
headers={"Content-Type": "application/json"},
stream=True # 启用流式传输
stream=True, # 启用流式传输
)
if message_request.status_code != 200:
@ -872,11 +1004,12 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
# 构建响应数据
response_data = {
'code': 200,
'code': 200, # 修改状态码为201
'message': 'partial',
'data': {
'id': str(answer_record.id),
'conversation_id': str(conversation_id),
'title': title, # 添加标题字段
'content': content_part,
'is_end': data.get('is_end', False)
}
@ -892,13 +1025,46 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
answer_record.content = full_content.strip()
answer_record.save()
# 先检查当前conversation_id是否已有有效标题
current_title = ChatHistory.objects.filter(
conversation_id=str(conversation_id)
).exclude(
title__in=["New chat", "新对话", ""]
).values_list('title', flat=True).first()
# 如果已有有效标题,则复用
if current_title:
title = current_title
logger.info(f"复用已有标题: {title}")
else:
# 没有有效标题时,直接基于当前问题和回答生成标题
try:
# 直接使用当前的问题和完整的AI回答来生成标题
generated_title = self._generate_conversation_title_from_deepseek(
question, full_content.strip()
)
if generated_title:
# 更新所有相关记录的标题
ChatHistory.objects.filter(
conversation_id=str(conversation_id)
).update(title=generated_title)
title = generated_title
logger.info(f"成功生成标题: {title}")
else:
title = "新对话" # 如果生成失败,使用默认标题
logger.warning("生成标题失败,使用默认标题")
except Exception as e:
logger.error(f"自动生成标题失败: {str(e)}")
title = "新对话" # 如果出错,使用默认标题
# 发送完整内容的最终响应
final_response = {
'code': 200,
'code': 200, # 修改状态码为201
'message': '完成',
'data': {
'id': str(answer_record.id),
'conversation_id': str(conversation_id),
'title': title, # 添加生成的标题
'dataset_id_list': metadata.get('dataset_id_list', []),
'dataset_names': metadata.get('dataset_names', []),
'role': 'assistant',
@ -929,11 +1095,11 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
full_content += content_part
response_data = {
'code': 200,
'code': 200, # 修改状态码为201
'message': 'partial',
'data': {
'id': str(answer_record.id),
'conversation_id': str(conversation_id),
'conversation_id': str(conversation_id), # 添加标题字段
'content': content_part,
'is_end': data.get('is_end', False)
}
@ -1008,7 +1174,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
chat_response = requests.post(
url=f"{settings.API_BASE_URL}/api/application/chat/open",
json=chat_request_data,
headers={"Content-Type": "application/json"}
headers={"Content-Type": "application/json"},
)
logger.info(f"API响应状态码: {chat_response.status_code}")
@ -1038,7 +1205,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
message_response = requests.post(
url=f"{settings.API_BASE_URL}/api/application/chat_message/{chat_id}",
json=message_request_data,
headers={"Content-Type": "application/json"}
headers={"Content-Type": "application/json"},
)
if message_response.status_code != 200:
@ -1295,6 +1463,7 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
export_response = requests.get(
url=export_url,
stream=True # 使用流式传输处理大文件
)
@ -1351,7 +1520,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
# 调用外部API
response = requests.get(
url=api_url,
params=params
params=params,
)
if response.status_code != 200:
@ -1537,7 +1707,8 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
response = requests.get(
url=url,
params=params
params=params,
)
if response.status_code != 200:
@ -1617,6 +1788,116 @@ class ChatHistoryViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'], url_path='generate-conversation-title')
def generate_conversation_title(self, request):
    """Set a caller-supplied title on every record of one conversation.

    Query parameters:
        conversation_id: conversation whose records are retitled (required).
        title: the new title (required).

    Returns:
        A ``{code, message, data}`` JSON envelope. 400 when either parameter
        is missing, 404 when the caller has no non-deleted record in the
        conversation, 500 on any unexpected error, otherwise 200 with the
        conversation id and the new title.
    """
    def _failure(code, message, http_status):
        # Every failure response shares the same envelope with empty data.
        return Response(
            {'code': code, 'message': message, 'data': None},
            status=http_status,
        )

    try:
        params = request.query_params

        conversation_id = params.get('conversation_id')
        if not conversation_id:
            return _failure(
                400, '缺少conversation_id参数', status.HTTP_400_BAD_REQUEST
            )

        # The conversation must contain at least one live record owned by
        # the requesting user before anything is modified.
        owned_records = self.get_queryset().filter(
            conversation_id=conversation_id,
            is_deleted=False,
            user=request.user,
        ).order_by('created_at')
        if not owned_records.exists():
            return _failure(
                404, '对话不存在或无权访问', status.HTTP_404_NOT_FOUND
            )

        custom_title = params.get('title')
        if not custom_title:
            return _failure(400, '缺少title参数', status.HTTP_400_BAD_REQUEST)

        # Apply the title to all of this user's records in the conversation
        # with a single UPDATE.
        ChatHistory.objects.filter(
            conversation_id=conversation_id,
            user=request.user,
        ).update(title=custom_title)

        payload = {
            'conversation_id': conversation_id,
            'title': custom_title,
        }
        return Response({
            'code': 200,
            'message': '更新会话标题成功',
            'data': payload,
        })

    except Exception as e:
        detail = f"更新会话标题失败: {str(e)}"
        logger.error(detail)
        logger.error(traceback.format_exc())
        return _failure(500, detail, status.HTTP_500_INTERNAL_SERVER_ERROR)
def _generate_conversation_title_from_deepseek(self, user_question, assistant_answer):
    """Ask the SiliconCloud chat-completions API for a short conversation title.

    The title is generated from a single question/answer pair.

    Args:
        user_question: text of the user's message.
        assistant_answer: text of the assistant's reply to that message.

    Returns:
        The generated title, stripped and truncated to 50 characters. The
        fallback title "新对话" is returned when no API key is configured,
        the request fails or times out, or the response is not usable.
        This method never raises.
    """
    fallback_title = "新对话"
    # Upper bound (seconds) on the upstream call. Callers invoke this
    # synchronously while building HTTP responses, so an unresponsive
    # upstream must not be able to block the worker indefinitely.
    request_timeout = 30

    try:
        # The API key comes from Django settings.
        api_key = settings.SILICON_CLOUD_API_KEY
        if not api_key:
            return fallback_title

        # Build the prompt.
        prompt = f"请根据用户的问题和助手的回答生成一个简短的对话标题不超过20个字\n\n用户问题: {user_question}\n\n助手回答: {assistant_answer}"

        import requests

        url = "https://api.siliconflow.cn/v1/chat/completions"

        payload = {
            "model": "deepseek-ai/DeepSeek-V3",
            "stream": False,
            "max_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.7,
            "top_k": 50,
            "frequency_penalty": 0.5,
            "n": 1,
            "stop": [],
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ]
        }

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

        response = requests.post(
            url, json=payload, headers=headers, timeout=request_timeout
        )

        # Check the status before parsing: error bodies are not guaranteed
        # to be JSON, and should be logged as an API error rather than as an
        # unexpected exception.
        if response.status_code != 200:
            logger.error(f"生成标题时出错: {response.text}")
            return fallback_title

        response_data = response.json()
        if 'choices' in response_data and response_data['choices']:
            title = response_data['choices'][0]['message']['content'].strip()
            return title[:50]  # Truncate over-long titles.

        logger.error(f"生成标题时出错: {response.text}")
        return fallback_title

    except Exception as e:
        # Title generation is best-effort; never let it break the caller.
        logger.exception(f"生成对话标题时发生错误: {str(e)}")
        return fallback_title
class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
serializer_class = KnowledgeBaseSerializer
@ -1657,28 +1938,9 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
serializer = self.get_serializer(paginated_queryset, many=True)
data = serializer.data
# 获取文档数量统计
kb_ids = [kb.id for kb in paginated_queryset]
doc_counts = KnowledgeBaseDocument.objects.filter(
knowledge_base_id__in=kb_ids,
status='active'
).values('knowledge_base_id').annotate(
count=Count('id')
)
# 创建文档数量映射字典
doc_count_map = {
str(item['knowledge_base_id']): item['count']
for item in doc_counts
}
# 为每个知识库添加权限信息和文档数量
# 为每个知识库添加权限信息
user = request.user
for item in data:
# 添加文档数量
kb_id = item['id']
item['document_count'] = doc_count_map.get(kb_id, 0)
# 获取必要的知识库属性
kb_type = item['type']
department = item.get('department')
@ -2060,7 +2322,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
response = requests.put(
f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}',
json=api_data,
headers={'Content-Type': 'application/json'}
headers={'Content-Type': 'application/json'},
)
if response.status_code != 200:
@ -2072,6 +2335,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
logger.info(f"外部知识库更新成功: {instance.external_id}")
except requests.exceptions.Timeout:
raise ExternalAPIError("请求超时,请稍后重试")
except requests.exceptions.RequestException as e:
raise ExternalAPIError(f"API请求失败: {str(e)}")
except Exception as e:
@ -2120,24 +2385,31 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
"data": None
}, status=status.HTTP_403_FORBIDDEN)
with transaction.atomic():
# 删除外部知识库
# 删除外部知识库(如果存在)
external_delete_success = True
external_error_message = None
if instance.external_id:
try:
self._delete_external_dataset(instance.external_id)
logger.info(f"外部知识库删除成功: {instance.external_id}")
except ExternalAPIError as e:
logger.error(f"删除外部知识库失败: {str(e)}")
return Response({
"code": 500,
"message": f"删除外部知识库失败: {str(e)}",
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# 记录错误但继续执行本地删除
external_delete_success = False
external_error_message = str(e)
logger.warning(f"外部知识库删除失败,将继续删除本地知识库: {str(e)}")
# 删除本地知识库
self.perform_destroy(instance)
logger.info(f"本地知识库删除成功: id={instance.id}, name={instance.name}")
# 如果外部知识库删除失败,返回警告消息
if not external_delete_success:
return Response({
"code": 200,
"message": f"知识库已删除,但外部知识库删除失败: {external_error_message}",
"data": None
})
return Response({
"code": 200,
"message": "知识库删除成功",
@ -2159,6 +2431,59 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def _delete_external_dataset(self, external_id):
    """Best-effort delete of a dataset in the external knowledge-base service.

    Args:
        external_id: id of the dataset in the external service; a falsy
            value means there is nothing to delete.

    Returns:
        True when there was nothing to delete or the HTTP exchange completed
        (whatever status or business code came back, so the local delete can
        proceed). False when the request itself failed, e.g. a timeout or
        connection error. This method never raises.

    NOTE(review): the visible caller in ``destroy`` ignores this return
    value and only catches ``ExternalAPIError``, so a False result is
    currently reported as a fully successful delete. Confirm the intent.
    NOTE(review): a later hunk of this class shows what looks like another
    external-dataset delete helper that raises ``ExternalAPIError``. If it
    has the same name, that later definition overrides this one. Verify.
    """
    # Upper bound (seconds) on the upstream call, so an unresponsive service
    # cannot hang the request and the Timeout handler below can actually fire.
    request_timeout = 30

    try:
        if not external_id:
            logger.warning("外部知识库ID为空跳过删除")
            return True

        response = requests.delete(
            f'{settings.API_BASE_URL}/api/dataset/{external_id}',
            headers={'Content-Type': 'application/json'},
            timeout=request_timeout,
        )

        logger.info(f"删除外部知识库响应: status_code={response.status_code}, response={response.text}")

        # HTTP-level status handling.
        if response.status_code == 404:
            logger.warning(f"外部知识库不存在: {external_id}")
            return True  # An already-missing dataset counts as deleted.
        elif response.status_code not in [200, 204]:
            logger.warning(f"删除外部知识库状态码异常: {response.status_code}, {response.text}")
            return True  # Unexpected status must not block the local delete.

        # Business-level status carried in the JSON body.
        try:
            api_response = response.json()
            if api_response.get('code') != 200:
                # ``message`` may be present but null; normalise it so the
                # substring test cannot raise TypeError.
                message = api_response.get('message') or ''
                if "不存在" in message:
                    logger.warning(f"外部知识库ID不存在视为删除成功: {external_id}")
                    return True
                logger.warning(f"业务处理返回非200状态码: {api_response.get('code')}, {api_response.get('message')}")
                return True  # No exception: let the local delete continue.
            logger.info(f"外部知识库删除成功: {external_id}")
            return True
        except ValueError:
            # The body is not JSON (e.g. an empty 204 response); the HTTP
            # status was acceptable, so treat the delete as successful.
            logger.warning(f"外部知识库删除响应无法解析JSON但状态码为200视为成功: {external_id}")
            return True

    except requests.exceptions.Timeout:
        logger.error(f"删除外部知识库超时: {external_id}")
        # Do not raise: the local delete is allowed to continue.
        return False
    except requests.exceptions.RequestException as e:
        logger.error(f"删除外部知识库请求异常: {external_id}, error={str(e)}")
        # Do not raise: the local delete is allowed to continue.
        return False
    except Exception as e:
        logger.error(f"删除外部知识库其他错误: {external_id}, error={str(e)}")
        # Do not raise: the local delete is allowed to continue.
        return False
@action(detail=True, methods=['get'])
def permissions(self, request, pk=None):
"""获取用户对特定知识库的权限"""
@ -2612,10 +2937,11 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
logger.info(f"调用分割API URL: {url}")
logger.info(f"请求字段: {list(files_data.keys())}")
# 发送请求 - 移除timeout参数
# 发送请求
response = requests.post(
url,
files=files_data
files=files_data,
)
# 记录请求头和响应信息,方便排查问题
@ -2845,7 +3171,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
knowledge_base=instance,
document_id=document_id,
document_name=doc_name,
external_id=document_id
external_id=document_id,
uploader_name=user.name
)
saved_documents.append({
@ -2908,7 +3235,7 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
# 记录请求数据,方便调试
logger.info(f"上传文档数据: 文档名={doc_data.get('name')}, 段落数={len(doc_data.get('paragraphs', []))}")
# 发送请求,不设置超时限制
# 发送请求
response = requests.post(url, json=doc_data)
# 记录响应结果
@ -2985,7 +3312,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
response = requests.post(
f'{settings.API_BASE_URL}/api/dataset',
json=api_data,
headers={'Content-Type': 'application/json'}
headers={'Content-Type': 'application/json'},
)
if response.status_code != 200:
@ -3001,6 +3329,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
return dataset_id
except requests.exceptions.Timeout:
raise ExternalAPIError("请求超时,请稍后重试")
except requests.exceptions.RequestException as e:
raise ExternalAPIError(f"API请求失败: {str(e)}")
except Exception as e:
@ -3014,7 +3344,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
response = requests.delete(
f'{settings.API_BASE_URL}/api/dataset/{external_id}',
headers={'Content-Type': 'application/json'}
headers={'Content-Type': 'application/json'},
)
logger.info(f"删除外部知识库响应: status_code={response.status_code}, response={response.text}")
@ -3043,6 +3374,9 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
logger.warning(f"外部知识库删除响应无法解析JSON但状态码为200视为成功: {external_id}")
return True
except requests.exceptions.Timeout:
logger.error(f"删除外部知识库超时: {external_id}")
raise ExternalAPIError("请求超时,请稍后重试")
except requests.exceptions.RequestException as e:
logger.error(f"删除外部知识库请求异常: {external_id}, error={str(e)}")
raise ExternalAPIError(f"API请求失败: {str(e)}")
@ -3078,7 +3412,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
url = f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}/document'
response = requests.get(
url,
headers={'Content-Type': 'application/json'}
headers={'Content-Type': 'application/json'},
)
if response.status_code != 200:
@ -3138,7 +3473,8 @@ class KnowledgeBaseViewSet(KnowledgeBasePermissionMixin, viewsets.ModelViewSet):
# 添加外部API返回的额外信息
"char_length": next((d.get('char_length', 0) for d in external_documents if d.get('id') == doc.external_id), 0),
"paragraph_count": next((d.get('paragraph_count', 0) for d in external_documents if d.get('id') == doc.external_id), 0),
"is_active": next((d.get('is_active', True) for d in external_documents if d.get('id') == doc.external_id), True)
"is_active": next((d.get('is_active', True) for d in external_documents if d.get('id') == doc.external_id), True),
"uploader_name": doc.uploader_name
} for doc in documents]
return Response({
@ -4260,13 +4596,23 @@ class LoginView(APIView):
class RegisterView(APIView):
"""用户注册视图"""
permission_classes = [AllowAny]
authentication_classes = [] # 清空认证类
def post(self, request):
try:
# 检查是否允许注册
from django.conf import settings
if not getattr(settings, 'ALLOW_REGISTRATION', True):
return Response({
"code": 403,
"message": "系统当前不允许注册新用户",
"data": None
}, status=status.HTTP_403_FORBIDDEN)
data = request.data
# 检查必填字段
required_fields = ['username', 'password', 'email', 'role', 'department', 'name']
required_fields = ['username', 'password', 'email', 'role', 'name']
for field in required_fields:
if not data.get(field):
return Response({
@ -4285,32 +4631,6 @@ class RegisterView(APIView):
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证部门是否存在
if data['department'] not in settings.DEPARTMENT_GROUPS:
return Response({
"code": 400,
"message": f"无效的部门,可选部门: {', '.join(settings.DEPARTMENT_GROUPS.keys())}",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 如果是组员,验证小组
if data['role'] == 'member':
if not data.get('group'):
return Response({
"code": 400,
"message": "组员必须指定所属小组",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 验证小组是否存在且属于指定部门
valid_groups = settings.DEPARTMENT_GROUPS.get(data['department'], [])
if data['group'] not in valid_groups:
return Response({
"code": 400,
"message": f"无效的小组,{data['department']}的可选小组: {', '.join(valid_groups)}",
"data": None
}, status=status.HTTP_400_BAD_REQUEST)
# 检查用户名是否已存在
if User.objects.filter(username=data['username']).exists():
return Response({
@ -4351,9 +4671,9 @@ class RegisterView(APIView):
email=data['email'],
password=data['password'],
role=data['role'],
department=data['department'],
department=data.get('department'), # 不再强制要求部门
name=data['name'],
group=data.get('group') if data['role'] == 'member' else None,
group=data.get('group'), # 不再强制要求小组
is_staff=False,
is_superuser=False
)
@ -4577,6 +4897,7 @@ def change_password(request):
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@csrf_exempt
@api_view(['POST'])
@permission_classes([AllowAny])
def user_register(request):
@ -4592,7 +4913,7 @@ def user_register(request):
data = request.data
# 检查必填字段
required_fields = ['username', 'password', 'email', 'role', 'department', 'name']
required_fields = ['username', 'password', 'email', 'role', 'name']
for field in required_fields:
if not data.get(field):
return Response({
@ -4610,14 +4931,6 @@ def user_register(request):
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 如果是组员,必须指定小组
if data['role'] == 'member' and not data.get('group'):
return Response({
'code': 400,
'message': '组员必须指定所属小组',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 检查用户名是否已存在
if User.objects.filter(username=data['username']).exists():
return Response({
@ -4658,9 +4971,9 @@ def user_register(request):
email=data['email'],
password=data['password'],
role=data['role'],
department=data['department'],
department=data.get('department'), # 不再强制要求部门
name=data['name'],
group=data.get('group') if data['role'] == 'member' else None,
group=data.get('group'), # 不再强制要求小组
is_staff=False,
is_superuser=False
)
@ -5014,4 +5327,3 @@ def user_list(request):
'message': f'获取用户列表失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)