daren/apps/chat/views.py
2025-05-29 10:11:19 +08:00

795 lines
33 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import html
import json
import logging
import re
import traceback
import uuid
from datetime import datetime

from django.db.models import Q, Max, Count
from django.http import HttpResponse, StreamingHttpResponse
from rest_framework import viewsets, status
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.decorators import action

from apps.user.models import User
from apps.knowledge_base.models import KnowledgeBase
from apps.chat.models import ChatHistory
from apps.chat.serializers import ChatHistorySerializer
from apps.common.services.chat_service import ChatService
from apps.user.authentication import CustomTokenAuthentication
from apps.chat.services.chat_api import (
    ExternalAPIError, stream_chat_answer, get_chat_answer, generate_conversation_title,
    get_hit_test_documents, generate_conversation_title_from_deepseek
)
logger = logging.getLogger(__name__)
class ChatHistoryViewSet(viewsets.ModelViewSet):
permission_classes = [IsAuthenticated]
authentication_classes = [CustomTokenAuthentication]
serializer_class = ChatHistorySerializer
queryset = ChatHistory.objects.all()
def get_queryset(self):
    """Return the non-deleted chat records the requesting user may see.

    A record is visible when the user authored it, or when it belongs to a
    knowledge base for which ``check_knowledge_base_permission(kb, user, 'read')``
    is truthy.
    """
    current_user = self.request.user
    readable_kb_ids = []
    # Permission is decided per knowledge base in Python, so every knowledge
    # base row is loaded and checked here.
    for knowledge_base in KnowledgeBase.objects.all():
        if self.check_knowledge_base_permission(knowledge_base, current_user, 'read'):
            readable_kb_ids.append(knowledge_base.id)
    visibility = Q(user=current_user) | Q(knowledge_base_id__in=readable_kb_ids)
    return ChatHistory.objects.filter(visibility, is_deleted=False)
def list(self, request):
    """Return a paginated overview of the conversations visible to the user.

    Records from ``get_queryset()`` are grouped by ``conversation_id``. For each
    group the response reports the number of records, the content of the record
    with the largest id, the newest ``created_at`` (formatted) and the datasets
    stored in that record's metadata. Groups are ordered newest first.

    Query params:
        page (default 1) and page_size (default 10): must be positive
        integers, otherwise a 400 envelope is returned.

    Returns the usual ``{'code', 'message', 'data'}`` envelope; unexpected
    errors are logged with a traceback and reported as 500.
    """
    try:
        page = int(request.query_params.get('page', 1))
        page_size = int(request.query_params.get('page_size', 10))
    except (TypeError, ValueError):
        page = page_size = 0
    if page < 1 or page_size < 1:
        # Non-numeric values used to surface as a 500, and non-positive ones
        # produced a negative queryset slice (also a 500).
        return Response({
            'code': 400,
            'message': '分页参数无效: page和page_size必须为正整数',
            'data': None
        }, status=status.HTTP_400_BAD_REQUEST)
    try:
        # One row per conversation: largest record id, record count and newest
        # created_at (exposed under the name 'last_message').
        latest_chats = self.get_queryset().values(
            'conversation_id'
        ).annotate(
            latest_id=Max('id'),
            message_count=Count('id'),
            last_message=Max('created_at')
        ).order_by('-last_message')
        total = latest_chats.count()
        start = (page - 1) * page_size
        # A sliced queryset caches its rows after the first iteration, so
        # iterating it twice below issues a single query.
        chats = latest_chats[start:start + page_size]
        # Load every "latest" record of this page in one query instead of one
        # query per conversation.
        latest_records = ChatHistory.objects.in_bulk([chat['latest_id'] for chat in chats])
        results = []
        for chat in chats:
            latest_record = latest_records.get(chat['latest_id'])
            if latest_record is None:
                # The record vanished between the two queries; skip the group.
                continue
            dataset_info = []
            if latest_record.metadata:
                dataset_id_list = latest_record.metadata.get('dataset_id_list', [])
                dataset_names = latest_record.metadata.get('dataset_names', [])
                if dataset_id_list:
                    if dataset_names and len(dataset_names) == len(dataset_id_list):
                        # Names captured when the record was written line up
                        # with the ids, so no lookup is needed.
                        dataset_info = [
                            {'id': str(ds_id), 'name': name}
                            for ds_id, name in zip(dataset_id_list, dataset_names)
                        ]
                    else:
                        dataset_info = [
                            {'id': str(ds.id), 'name': ds.name}
                            for ds in KnowledgeBase.objects.filter(id__in=dataset_id_list)
                        ]
            results.append({
                'conversation_id': chat['conversation_id'],
                'message_count': chat['message_count'],
                'last_message': latest_record.content,
                'last_time': chat['last_message'].strftime('%Y-%m-%d %H:%M:%S'),
                'dataset_id_list': [ds['id'] for ds in dataset_info],
                'datasets': dataset_info
            })
        return Response({
            'code': 200,
            'message': '获取成功',
            'data': {
                'total': total,
                'page': page,
                'page_size': page_size,
                'results': results
            }
        })
    except Exception as e:
        logger.error(f"获取聊天记录失败: {str(e)}")
        logger.error(traceback.format_exc())
        return Response({
            'code': 500,
            'message': f'获取聊天记录失败: {str(e)}',
            'data': None
        }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
def conversation_detail(self, request):
    """Return the detail of one conversation, delegated to ``ChatService``.

    Query params:
        conversation_id (required): a missing or empty value yields 400.

    A ``ValueError`` from ``ChatService.get_conversation_detail`` is reported
    as 404 with its message; any other exception is logged together with its
    traceback and reported as 500.
    """
    try:
        conversation_id = request.query_params.get('conversation_id')
        if not conversation_id:
            return Response(
                {'code': 400, 'message': '缺少conversation_id参数', 'data': None},
                status=status.HTTP_400_BAD_REQUEST,
            )
        detail = ChatService().get_conversation_detail(request.user, conversation_id)
        return Response({'code': 200, 'message': '获取成功', 'data': detail})
    except ValueError as err:
        return Response(
            {'code': 404, 'message': str(err), 'data': None},
            status=status.HTTP_404_NOT_FOUND,
        )
    except Exception as err:
        logger.error(f"获取对话详情失败: {str(err)}")
        logger.error(traceback.format_exc())
        return Response(
            {'code': 500, 'message': f'获取对话详情失败: {str(err)}', 'data': None},
            status=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
@action(detail=False, methods=['get'])
def available_datasets(self, request):
    """Return every knowledge base the requesting user is allowed to read.

    Each item carries ``id`` (stringified), ``name``, ``type``, ``department``
    and ``description`` (read from the model's ``desc`` attribute). Any
    exception is logged (message only) and reported as a 500 envelope.
    """
    try:
        viewer = request.user
        readable = []
        for candidate in KnowledgeBase.objects.all():
            if self.check_knowledge_base_permission(candidate, viewer, 'read'):
                readable.append(candidate)
        payload = []
        for kb in readable:
            payload.append({
                'id': str(kb.id),
                'name': kb.name,
                'type': kb.type,
                'department': kb.department,
                'description': kb.desc,
            })
        return Response({'code': 200, 'message': '获取成功', 'data': payload})
    except Exception as err:
        logger.error(f"获取可用知识库列表失败: {str(err)}")
        return Response(
            {'code': 500, 'message': f'获取可用知识库列表失败: {str(err)}', 'data': None},
            status=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
@action(detail=False, methods=['post'])
def create_conversation(self, request):
    """Allocate a new conversation id for a set of knowledge bases, without a question.

    Body fields (one of them is required):
        dataset_id: a single id; takes precedence over ``dataset_id_list``.
        dataset_id_list: a list of ids, a JSON-encoded list, or a bare id
            string (a string that is not valid JSON is used as one id; valid
            JSON that is not a list yields no ids).

    Every id must name an existing knowledge base (404) that the user may read
    (403). Any exception while checking one id yields 400. Nothing is persisted
    here: the response only carries the generated ``conversation_id`` and the
    normalized id list.
    """
    def reply_error(code, message, http_status):
        return Response({'code': code, 'message': message, 'data': None}, status=http_status)

    try:
        payload = request.data
        if 'dataset_id' in payload:
            requested_ids = [str(payload['dataset_id'])]
        elif 'dataset_id_list' in payload and isinstance(payload['dataset_id_list'], (list, str)):
            raw_ids = payload['dataset_id_list']
            requested_ids = []
            if isinstance(raw_ids, str):
                try:
                    decoded = json.loads(raw_ids)
                except json.JSONDecodeError:
                    # Not JSON at all: treat the whole string as one id.
                    requested_ids = [str(raw_ids)]
                else:
                    if isinstance(decoded, list):
                        requested_ids = [str(item) for item in decoded]
            else:
                requested_ids = [str(item) for item in raw_ids]
        else:
            return reply_error(400, '缺少必填字段: dataset_id 或 dataset_id_list', status.HTTP_400_BAD_REQUEST)

        if not requested_ids:
            return reply_error(400, '至少需要提供一个知识库ID', status.HTTP_400_BAD_REQUEST)

        user = request.user
        knowledge_bases = []
        for kb_id in requested_ids:
            try:
                kb = KnowledgeBase.objects.filter(id=kb_id).first()
                if not kb:
                    return reply_error(404, f'知识库不存在: {kb_id}', status.HTTP_404_NOT_FOUND)
                knowledge_bases.append(kb)
                if not self.check_knowledge_base_permission(kb, user, 'read'):
                    return reply_error(403, f'无权访问知识库: {kb.name}', status.HTTP_403_FORBIDDEN)
            except Exception as exc:
                # e.g. a malformed id rejected by the ORM lookup.
                return reply_error(400, f'处理知识库ID出错: {str(exc)}', status.HTTP_400_BAD_REQUEST)

        conversation_id = str(uuid.uuid4())
        logger.info(f"创建新的会话ID: {conversation_id}")
        # Names are collected alongside the ids, but only the ids are returned.
        metadata = {
            'dataset_id_list': [str(kb_id) for kb_id in requested_ids],
            'dataset_names': [kb.name for kb in knowledge_bases],
        }
        return Response({
            'code': 200,
            'message': '会话创建成功',
            'data': {
                'conversation_id': conversation_id,
                'dataset_id_list': metadata['dataset_id_list'],
            },
        })
    except Exception as exc:
        logger.error(f"创建会话失败: {str(exc)}")
        logger.error(traceback.format_exc())
        return reply_error(500, f'创建会话失败: {str(exc)}', status.HTTP_500_INTERNAL_SERVER_ERROR)
def create(self, request):
    """Record a question and answer it, streamed (SSE) or in a single response.

    ``ChatService.create_chat_record`` processes the payload and returns
    ``(question_record, conversation_id, metadata, knowledge_bases,
    external_id_list)``. The answer record is attached to ``knowledge_bases[0]``.

    Request fields read here:
        question: sent to the answer backend.
        stream (default True): stream the answer as ``text/event-stream``.
            The strings "false", "0", "no", "off" and "" (case-insensitive,
            surrounding whitespace ignored) disable streaming, so form-encoded
            clients can opt out.
        title (default 'New chat'): title stored on the answer record.
        conversation_id: forwarded to ``create_chat_record``.

    Returns 201 with either an event stream or a JSON envelope. A
    ``ValueError`` raised before the response is built gives 400, and any
    other exception gives 500. Errors raised while streaming are logged and
    sent to the client as a final ``code: 500`` event, because the response
    has already started by then.
    """
    try:
        chat_service = ChatService()
        question_record, conversation_id, metadata, knowledge_bases, external_id_list = chat_service.create_chat_record(
            request.user, request.data, request.data.get('conversation_id')
        )
        raw_stream = request.data.get('stream', True)
        if isinstance(raw_stream, str):
            # bool('false') is True, so textual flags need explicit parsing.
            use_stream = raw_stream.strip().lower() not in ('false', '0', 'no', 'off', '')
        else:
            use_stream = bool(raw_stream)
        title = request.data.get('title', 'New chat')

        if use_stream:
            def resolve_title(answer_text):
                """Return the conversation's non-default title, generating one if none exists."""
                current_title = ChatHistory.objects.filter(
                    conversation_id=conversation_id
                ).exclude(
                    title__in=["New chat", "新对话", ""]
                ).values_list('title', flat=True).first()
                if current_title:
                    return current_title
                try:
                    generated_title = generate_conversation_title(
                        request.data['question'], answer_text
                    )
                    if generated_title:
                        ChatHistory.objects.filter(
                            conversation_id=conversation_id
                        ).update(title=generated_title)
                        return generated_title
                    return "新对话"
                except ExternalAPIError as e:
                    logger.error(f"自动生成标题失败: {str(e)}")
                    return "新对话"

            def stream_response():
                # Create the (empty) answer record first so its id can be sent
                # in the opening event; content is filled in as chunks arrive.
                answer_record = ChatHistory.objects.create(
                    user=question_record.user,
                    knowledge_base=knowledge_bases[0],
                    conversation_id=conversation_id,
                    title=title,
                    parent_id=str(question_record.id),
                    role='assistant',
                    content="",
                    metadata=metadata
                )
                yield f"data: {json.dumps({'code': 200, 'message': '开始流式传输', 'data': {'id': str(answer_record.id), 'conversation_id': conversation_id, 'content': '', 'is_end': False}})}\n\n"
                full_content = ""
                try:
                    for data in stream_chat_answer(conversation_id, request.data['question'], external_id_list, metadata):
                        # Upstream yields SSE frames ("data: {...}\n\n"): drop the
                        # "data:" prefix and the trailing blank line; json.loads
                        # ignores the leftover leading space.
                        parsed_data = json.loads(data[5:-2])
                        if parsed_data['code'] == 200 and 'content' in parsed_data['data']:
                            content_part = parsed_data['data']['content']
                            full_content += content_part
                            response_data = {
                                'code': 200,
                                'message': 'partial',
                                'data': {
                                    'id': str(answer_record.id),
                                    'conversation_id': conversation_id,
                                    'title': title,
                                    'content': content_part,
                                    'is_end': parsed_data['data']['is_end']
                                }
                            }
                            yield f"data: {json.dumps(response_data)}\n\n"
                            if parsed_data['data']['is_end']:
                                answer_record.content = full_content.strip()
                                answer_record.save()
                                title_updated = resolve_title(full_content.strip())
                                final_response = {
                                    'code': 200,
                                    'message': '完成',
                                    'data': {
                                        'id': str(answer_record.id),
                                        'conversation_id': conversation_id,
                                        'title': title_updated,
                                        'dataset_id_list': metadata.get('dataset_id_list', []),
                                        'dataset_names': metadata.get('dataset_names', []),
                                        'role': 'assistant',
                                        'content': full_content.strip(),
                                        'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                                        'is_end': True
                                    }
                                }
                                yield f"data: {json.dumps(final_response)}\n\n"
                                break
                        elif parsed_data['code'] != 200:
                            # Forward upstream error frames unchanged and stop.
                            yield data
                            break
                except Exception as stream_error:
                    # The view's except clauses cannot see errors raised here
                    # (the response has already been returned), so report in-band.
                    logger.error(f"流式输出失败: {str(stream_error)}")
                    logger.error(traceback.format_exc())
                    error_response = {
                        'code': 500,
                        'message': f'流式输出失败: {str(stream_error)}',
                        'data': {
                            'id': str(answer_record.id),
                            'conversation_id': conversation_id,
                            'is_end': True
                        }
                    }
                    yield f"data: {json.dumps(error_response)}\n\n"
                finally:
                    # Persist whatever arrived, also after errors or a client
                    # disconnect (GeneratorExit).
                    if full_content:
                        try:
                            answer_record.content = full_content.strip()
                            answer_record.save()
                        except Exception as save_error:
                            logger.error(f"保存部分内容失败: {str(save_error)}")

            response = StreamingHttpResponse(
                stream_response(),
                content_type='text/event-stream',
                status=status.HTTP_201_CREATED
            )
            response['Cache-Control'] = 'no-cache, no-store'
            # 'Connection: keep-alive' is deliberately not set: it is a
            # hop-by-hop header, which PEP 3333 forbids WSGI applications from
            # sending (wsgiref / runserver raise an AssertionError on it).
            return response

        logger.info("使用非流式输出模式")
        try:
            answer = get_chat_answer(external_id_list, request.data['question'])
        except ExternalAPIError as e:
            logger.error(f"获取回答失败: {str(e)}")
            return Response({
                'code': 500,
                'message': f'获取回答失败: {str(e)}',
                'data': None
            }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        if answer is None:
            return Response({
                'code': 500,
                'message': '获取回答失败',
                'data': None
            }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        answer_record = ChatHistory.objects.create(
            user=request.user,
            knowledge_base=knowledge_bases[0],
            conversation_id=conversation_id,
            title=title,
            parent_id=str(question_record.id),
            role='assistant',
            content=answer,
            metadata=metadata
        )
        # Generate a title only for the first exchange of a conversation that
        # still has the default title. Both records of this exchange must be
        # excluded; excluding only the question always matched the answer
        # record created just above, so no title was ever generated.
        has_earlier_records = ChatHistory.objects.filter(
            conversation_id=conversation_id
        ).exclude(
            id__in=[question_record.id, answer_record.id]
        ).exists()
        if not has_earlier_records and (not title or title == 'New chat'):
            try:
                generated_title = generate_conversation_title(
                    request.data['question'], answer
                )
                if generated_title:
                    ChatHistory.objects.filter(conversation_id=conversation_id).update(title=generated_title)
                    title = generated_title
            except ExternalAPIError as e:
                logger.error(f"自动生成标题失败: {str(e)}")
        return Response({
            'code': 200,
            'message': '成功',
            'data': {
                'id': str(answer_record.id),
                'conversation_id': conversation_id,
                'title': title,
                'dataset_id_list': metadata.get('dataset_id_list', []),
                'dataset_names': metadata.get('dataset_names', []),
                'role': 'assistant',
                'content': answer,
                'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S')
            }
        }, status=status.HTTP_201_CREATED)
    except ValueError as e:
        return Response({
            'code': 400,
            'message': str(e),
            'data': None
        }, status=status.HTTP_400_BAD_REQUEST)
    except Exception as e:
        logger.error(f"创建聊天记录失败: {str(e)}")
        logger.error(traceback.format_exc())
        return Response({
            'code': 500,
            'message': f'创建聊天记录失败: {str(e)}',
            'data': None
        }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['post'])
def hit_test(self, request):
    """Score how well documents of the given knowledge bases match a question.

    Body fields:
        question (required).
        dataset_id_list (required, non-empty): a list, a JSON-encoded list, or a
            single id. A JSON scalar or a non-JSON value is wrapped in a list.

    Every id must name an existing knowledge base (404) the user may read
    (403). Knowledge bases without ``external_id`` are ignored; if none remain
    the response is 400. ``get_hit_test_documents`` is called once per external
    id, and an ``ExternalAPIError`` for one of them is logged and skipped.
    Matches are returned highest ``similarity`` first (missing counts as 0).
    """
    def reply_error(code, message, http_status):
        return Response({'code': code, 'message': message, 'data': None}, status=http_status)

    try:
        payload = request.data
        incomplete = (
            'question' not in payload
            or 'dataset_id_list' not in payload
            or not payload['dataset_id_list']
        )
        if incomplete:
            return reply_error(400, '缺少必填字段: question 或 dataset_id_list', status.HTTP_400_BAD_REQUEST)
        question = payload['question']
        requested = payload['dataset_id_list']
        if not isinstance(requested, list):
            try:
                decoded = json.loads(requested)
            except (json.JSONDecodeError, TypeError):
                requested = [requested]
            else:
                requested = decoded if isinstance(decoded, list) else [decoded]

        external_ids = []
        for kb_id in requested:
            kb = KnowledgeBase.objects.filter(id=kb_id).first()
            if not kb:
                return reply_error(404, f'知识库不存在: {kb_id}', status.HTTP_404_NOT_FOUND)
            if not self.check_knowledge_base_permission(kb, request.user, 'read'):
                return reply_error(403, f'无权访问知识库: {kb.name}', status.HTTP_403_FORBIDDEN)
            if kb.external_id:
                external_ids.append(str(kb.external_id))
        if not external_ids:
            return reply_error(400, '没有有效的知识库external_id', status.HTTP_400_BAD_REQUEST)

        matched = []
        for external_id in external_ids:
            try:
                docs = get_hit_test_documents(external_id, question)
                if docs:
                    matched.extend(docs)
            except ExternalAPIError as err:
                # Best effort: one failing knowledge base does not fail the request.
                logger.error(f"调用hit_test失败: 知识库ID={external_id}, 错误={str(err)}")
        matched.sort(key=lambda doc: doc.get('similarity', 0), reverse=True)
        return Response({
            'code': 200,
            'message': '成功',
            'data': {
                'question': question,
                'matched_documents': matched,
                'total_count': len(matched),
            },
        })
    except Exception as err:
        logger.error(f"hit_test接口调用失败: {str(err)}")
        logger.error(traceback.format_exc())
        return reply_error(500, f'hit_test接口调用失败: {str(err)}', status.HTTP_500_INTERNAL_SERVER_ERROR)
def _highlight_keyword(self, text, keyword):
"""高亮关键词"""
if not keyword or not text:
return text
return text.replace(keyword, f'<em class="highlight">{keyword}</em>')
def update(self, request, pk=None, **kwargs):
    """Edit one of the caller's own chat records.

    Only fields present in the body are changed: ``content`` replaces the
    record content, and ``metadata`` (must be a JSON object) is merged key by
    key into the existing metadata. Extra keyword arguments are accepted and
    ignored. In particular, DRF's ``partial_update`` (PATCH) calls
    ``update(..., partial=True)``, which previously raised a TypeError. Every
    update here is already partial.

    Only the author may edit a record. Read access to a knowledge base makes
    other users' records visible through ``get_queryset``, but must not grant
    write access (the conversation-title endpoint of this viewset applies the
    same owner filter).

    Returns 404 when the record is missing or not owned by the caller, 400 for
    a non-object ``metadata``, otherwise the updated record. ``updated_at`` in
    the response is the server's current local time, not a stored field.
    """
    try:
        record = self.get_queryset().filter(id=pk, user=request.user).first()
        if not record:
            return Response({
                'code': 404,
                'message': '记录不存在或无权限',
                'data': None
            }, status=status.HTTP_404_NOT_FOUND)
        data = request.data
        # Validate before mutating; a non-dict used to fail inside dict.update (500).
        if 'metadata' in data and not isinstance(data['metadata'], dict):
            return Response({
                'code': 400,
                'message': 'metadata必须为JSON对象',
                'data': None
            }, status=status.HTTP_400_BAD_REQUEST)
        if 'content' in data:
            record.content = data['content']
        if 'metadata' in data:
            merged_metadata = dict(record.metadata or {})
            merged_metadata.update(data['metadata'])
            record.metadata = merged_metadata
        record.save()
        return Response({
            'code': 200,
            'message': '更新成功',
            'data': {
                'id': str(record.id),
                'conversation_id': record.conversation_id,
                'role': record.role,
                'content': record.content,
                'metadata': record.metadata,
                'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            }
        })
    except Exception as e:
        logger.error(f"更新聊天记录失败: {str(e)}")
        logger.error(traceback.format_exc())
        return Response({
            'code': 500,
            'message': f'更新聊天记录失败: {str(e)}',
            'data': None
        }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def destroy(self, request, pk=None):
    """Soft-delete one of the caller's own chat records via ``record.soft_delete()``.

    Only the author may delete a record. Read access to a knowledge base makes
    other users' records visible through ``get_queryset``, but must not allow
    deleting them (the conversation-title endpoint of this viewset applies the
    same owner filter).

    Returns 404 when the record is missing or not owned by the caller, 500 on
    unexpected errors (logged with a traceback).
    """
    try:
        record = self.get_queryset().filter(id=pk, user=request.user).first()
        if not record:
            return Response({
                'code': 404,
                'message': '记录不存在或无权限',
                'data': None
            }, status=status.HTTP_404_NOT_FOUND)
        record.soft_delete()
        return Response({
            'code': 200,
            'message': '删除成功',
            'data': None
        })
    except Exception as e:
        logger.error(f"删除聊天记录失败: {str(e)}")
        logger.error(traceback.format_exc())
        return Response({
            'code': 500,
            'message': f'删除聊天记录失败: {str(e)}',
            'data': None
        }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
def search(self, request):
    """Search the chat records visible to the caller.

    Query params:
        keyword: case-insensitive match on record content or knowledge base
            name. When given, each result also carries an HTML
            ``highlights.content`` built by ``_highlight_keyword``.
        dataset_id: restrict to one knowledge base (403 if it exists but is
            not readable by the caller).
        start_date / end_date: bounds on ``created_at`` (``gte`` / ``lte``),
            passed to the ORM unchanged.
        page (default 1) and page_size (default 10): must be positive
            integers, otherwise a 400 envelope is returned.

    Results are ordered newest first.
    """
    try:
        page = int(request.query_params.get('page', 1))
        page_size = int(request.query_params.get('page_size', 10))
    except (TypeError, ValueError):
        page = page_size = 0
    if page < 1 or page_size < 1:
        # Non-numeric values used to surface as a 500, and non-positive ones
        # produced a negative queryset slice (also a 500).
        return Response({
            'code': 400,
            'message': '分页参数无效: page和page_size必须为正整数',
            'data': None
        }, status=status.HTTP_400_BAD_REQUEST)
    try:
        keyword = request.query_params.get('keyword', '').strip()
        dataset_id = request.query_params.get('dataset_id')
        start_date = request.query_params.get('start_date')
        end_date = request.query_params.get('end_date')
        # select_related avoids one knowledge-base query per result row.
        query = self.get_queryset().select_related('knowledge_base')
        if keyword:
            query = query.filter(
                Q(content__icontains=keyword) |
                Q(knowledge_base__name__icontains=keyword)
            )
        if dataset_id:
            knowledge_base = KnowledgeBase.objects.filter(id=dataset_id).first()
            if knowledge_base and not self.check_knowledge_base_permission(knowledge_base, request.user, 'read'):
                return Response({
                    'code': 403,
                    'message': '无权访问该知识库',
                    'data': None
                }, status=status.HTTP_403_FORBIDDEN)
            query = query.filter(knowledge_base__id=dataset_id)
        if start_date:
            query = query.filter(created_at__gte=start_date)
        if end_date:
            # NOTE(review): a date-only end_date compares against midnight, so
            # records later that same day are excluded — confirm intent.
            query = query.filter(created_at__lte=end_date)
        total = query.count()
        start = (page - 1) * page_size
        records = query.order_by('-created_at')[start:start + page_size]
        results = []
        for record in records:
            kb = record.knowledge_base
            results.append({
                'id': str(record.id),
                'conversation_id': record.conversation_id,
                # Guard in case the FK is nullable (not visible from here).
                'dataset_id': str(kb.id) if kb else None,
                'dataset_name': kb.name if kb else None,
                'role': record.role,
                'content': record.content,
                'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'metadata': record.metadata,
                'highlights': {'content': self._highlight_keyword(record.content, keyword)} if keyword else {}
            })
        return Response({
            'code': 200,
            'message': '搜索成功',
            'data': {
                'total': total,
                'page': page,
                'page_size': page_size,
                'results': results
            }
        })
    except Exception as e:
        logger.error(f"搜索聊天记录失败: {str(e)}")
        logger.error(traceback.format_exc())
        return Response({
            'code': 500,
            'message': f'搜索失败: {str(e)}',
            'data': None
        }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['delete'])
def delete_conversation(self, request):
    """Soft-delete every record the caller owns in one conversation.

    Query params:
        conversation_id (required): a missing or empty value yields 400.

    Only the caller's own records are deleted. Read access to a knowledge base
    makes other users' records visible through ``get_queryset``, but must not
    allow deleting them (the conversation-title endpoint of this viewset
    applies the same owner filter). Returns 404 when nothing matches,
    otherwise the number of records deleted.
    """
    try:
        conversation_id = request.query_params.get('conversation_id')
        if not conversation_id:
            return Response({
                'code': 400,
                'message': '缺少必要参数: conversation_id',
                'data': None
            }, status=status.HTTP_400_BAD_REQUEST)
        # Fetch once; the old code queried exists(), count() and then iterated.
        records = list(self.get_queryset().filter(
            conversation_id=conversation_id,
            user=request.user
        ))
        if not records:
            return Response({
                'code': 404,
                'message': '未找到该会话或无权限访问',
                'data': None
            }, status=status.HTTP_404_NOT_FOUND)
        for record in records:
            record.soft_delete()
        return Response({
            'code': 200,
            'message': '删除成功',
            'data': {
                'conversation_id': conversation_id,
                'deleted_count': len(records)
            }
        })
    except Exception as e:
        logger.error(f"删除会话失败: {str(e)}")
        logger.error(traceback.format_exc())
        return Response({
            'code': 500,
            'message': f'删除会话失败: {str(e)}',
            'data': None
        }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'], url_path='generate-conversation-title')
def generate_conversation_title(self, request):
    """Rename a conversation owned by the caller.

    Despite its name, nothing is generated: every record of the conversation
    that belongs to the caller gets the ``title`` query parameter as its title.

    Query params:
        conversation_id (required): missing -> 400.
        title (required): missing -> 400. This is checked only after the
            conversation is found, so an unknown conversation yields 404 first.

    NOTE(review): this endpoint changes state on a GET request.
    """
    try:
        params = request.query_params
        conversation_id = params.get('conversation_id')
        if not conversation_id:
            return Response(
                {'code': 400, 'message': '缺少conversation_id参数', 'data': None},
                status=status.HTTP_400_BAD_REQUEST,
            )
        owned_messages = self.get_queryset().filter(
            conversation_id=conversation_id,
            is_deleted=False,
            user=request.user,
        )
        if not owned_messages.exists():
            return Response(
                {'code': 404, 'message': '对话不存在或无权访问', 'data': None},
                status=status.HTTP_404_NOT_FOUND,
            )
        new_title = params.get('title')
        if not new_title:
            return Response(
                {'code': 400, 'message': '缺少title参数', 'data': None},
                status=status.HTTP_400_BAD_REQUEST,
            )
        # The rename applies to all of the caller's records in the conversation.
        ChatHistory.objects.filter(
            conversation_id=conversation_id,
            user=request.user,
        ).update(title=new_title)
        return Response({
            'code': 200,
            'message': '更新会话标题成功',
            'data': {'conversation_id': conversation_id, 'title': new_title},
        })
    except Exception as err:
        logger.error(f"更新会话标题失败: {str(err)}")
        logger.error(traceback.format_exc())
        return Response(
            {'code': 500, 'message': f"更新会话标题失败: {str(err)}", 'data': None},
            status=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )