fix: 修复聊天记录知识库ID字段问题

This commit is contained in:
wanjia 2025-03-03 14:17:02 +08:00
parent 9ef1608809
commit 07a33dbf54
3 changed files with 270 additions and 144 deletions

Binary file not shown.

View File

@ -12,7 +12,8 @@ from .views import (
change_password, change_password,
RegisterView, RegisterView,
LoginView, LoginView,
LogoutView LogoutView,
ChatHistoryViewSet
) )
# 创建路由器 # 创建路由器
@ -22,6 +23,7 @@ router = DefaultRouter()
router.register(r'knowledge-bases', KnowledgeBaseViewSet, basename='knowledge-base') router.register(r'knowledge-bases', KnowledgeBaseViewSet, basename='knowledge-base')
router.register(r'permissions', PermissionViewSet, basename='permission') router.register(r'permissions', PermissionViewSet, basename='permission')
router.register(r'notifications', NotificationViewSet, basename='notification') router.register(r'notifications', NotificationViewSet, basename='notification')
router.register(r'chat-history', ChatHistoryViewSet, basename='chat-history')
# URL patterns # URL patterns
urlpatterns = [ urlpatterns = [

View File

@ -83,50 +83,52 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
queryset = ChatHistory.objects.all() queryset = ChatHistory.objects.all()
def get_queryset(self): def get_queryset(self):
"""确保用户只能看到自己的聊天记录""" """确保用户只能看到自己的未删除的聊天记录"""
return ChatHistory.objects.filter(user=self.request.user) return ChatHistory.objects.filter(
user=self.request.user,
is_deleted=False
)
def list(self, request): def list(self, request):
"""获取聊天记录列表按dataset_id分组""" """获取聊天记录列表"""
try: try:
# 获取查询参数 # 获取查询参数
dataset_id = request.query_params.get('dataset_id') dataset_id = request.query_params.get('dataset_id')
page = int(request.query_params.get('page', 1)) page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10)) page_size = int(request.query_params.get('page_size', 10))
# 基础查询
query = self.get_queryset() query = self.get_queryset()
if dataset_id: if dataset_id:
# 如果指定了dataset_id获取该数据集的完整对话历史 # 获取特定知识库的完整对话历史
records = query.filter( records = query.filter(
dataset_id=dataset_id knowledge_base__id=dataset_id
).order_by('created_at') # 按时间正序排列 ).order_by('created_at')
# 序列化对话数据
conversation = { conversation = {
'dataset_id': dataset_id, 'dataset_id': dataset_id,
'dataset_name': records.first().dataset_name if records.exists() else None, 'dataset_name': records.first().knowledge_base.name if records.exists() else None,
'messages': [{ 'messages': [{
'id': record.id, 'id': record.id,
'role': 'user' if idx % 2 == 0 else 'assistant', 'role': record.role,
'content': record.question if idx % 2 == 0 else record.answer, 'content': record.content,
'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S') 'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S')
} for idx, record in enumerate(records)] } for record in records]
} }
return Response({ return Response({
'code': 200,
'message': '获取成功', 'message': '获取成功',
'data': conversation 'data': conversation
}) })
else: else:
# 如果没有指定dataset_id获取所有对话的概览 # 获取所有对话的概览
# 按dataset_id分组获取最新一条记录
latest_chats = query.values( latest_chats = query.values(
'dataset_id' 'conversation_id',
'knowledge_base__id',
'knowledge_base__name'
).annotate( ).annotate(
latest_id=Max('id'), latest_id=Max('id'),
dataset_name=F('dataset_name'),
message_count=Count('id'), message_count=Count('id'),
last_message=Max('created_at') last_message=Max('created_at')
).order_by('-last_message') ).order_by('-last_message')
@ -143,14 +145,16 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
for chat in chats: for chat in chats:
latest_record = ChatHistory.objects.get(id=chat['latest_id']) latest_record = ChatHistory.objects.get(id=chat['latest_id'])
results.append({ results.append({
'dataset_id': chat['dataset_id'], 'conversation_id': chat['conversation_id'],
'dataset_name': chat['dataset_name'], 'dataset_id': str(chat['knowledge_base__id']),
'dataset_name': chat['knowledge_base__name'],
'message_count': chat['message_count'], 'message_count': chat['message_count'],
'last_message': latest_record.answer, 'last_message': latest_record.content,
'last_time': chat['last_message'].strftime('%Y-%m-%d %H:%M:%S') 'last_time': chat['last_message'].strftime('%Y-%m-%d %H:%M:%S')
}) })
return Response({ return Response({
'code': 200,
'message': '获取成功', 'message': '获取成功',
'data': { 'data': {
'total': total, 'total': total,
@ -161,12 +165,16 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
}) })
except Exception as e: except Exception as e:
logger.error(f"获取聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({ return Response({
'error': f'获取聊天记录失败: {str(e)}' 'code': 500,
'message': f'获取聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def create(self, request): def create(self, request):
"""创建新的聊天记录""" """创建聊天记录"""
try: try:
data = request.data data = request.data
required_fields = ['dataset_id', 'dataset_name', 'question', 'answer'] required_fields = ['dataset_id', 'dataset_name', 'question', 'answer']
@ -175,90 +183,147 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
for field in required_fields: for field in required_fields:
if field not in data: if field not in data:
return Response({ return Response({
'error': f'缺少必填字段: {field}' 'code': 400,
'message': f'缺少必填字段: {field}',
'data': None
}, status=status.HTTP_400_BAD_REQUEST) }, status=status.HTTP_400_BAD_REQUEST)
# 创建记录 # 获取或创建对话ID
record = ChatHistory.objects.create( conversation_id = data.get('conversation_id', str(uuid.uuid4()))
# 获取知识库 - 不进行 UUID 转换
try:
knowledge_base = KnowledgeBase.objects.filter(id=data['dataset_id']).first()
if not knowledge_base:
return Response({
'code': 404,
'message': '知识库不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
except Exception as e:
return Response({
'code': 400,
'message': f'无效的知识库ID: {str(e)}',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 创建用户问题记录
question_record = ChatHistory.objects.create(
user=request.user, user=request.user,
dataset_id=data['dataset_id'], knowledge_base=knowledge_base,
dataset_name=data['dataset_name'], conversation_id=conversation_id,
question=data['question'], role='user',
answer=data['answer'], content=data['question'],
model_name=data.get('model_name', 'default') metadata={'model_name': data.get('model_name', 'default')}
)
# 创建AI回答记录
answer_record = ChatHistory.objects.create(
user=request.user,
knowledge_base=knowledge_base,
conversation_id=conversation_id,
parent_id=str(question_record.id),
role='assistant',
content=data['answer'],
metadata={'model_name': data.get('model_name', 'default')}
) )
return Response({ return Response({
'code': 200,
'message': '创建成功', 'message': '创建成功',
'data': { 'data': {
'id': record.id, 'id': answer_record.id,
'dataset_id': record.dataset_id, 'conversation_id': conversation_id,
'dataset_id': str(knowledge_base.id),
'role': 'assistant', 'role': 'assistant',
'content': record.answer, 'content': answer_record.content,
'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S') 'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S')
} }
}, status=status.HTTP_201_CREATED) }, status=status.HTTP_201_CREATED)
except Exception as e: except Exception as e:
logger.error(f"创建聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({ return Response({
'error': f'创建聊天记录失败: {str(e)}' 'code': 500,
'message': f'创建聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def update(self, request, pk=None): def update(self, request, pk=None):
"""更新聊天记录""" """更新聊天记录"""
try: try:
# 获取记录
record = self.get_queryset().filter(id=pk).first() record = self.get_queryset().filter(id=pk).first()
if not record: if not record:
return Response({ return Response({
'error': '记录不存在或无权限' 'code': 404,
'message': '记录不存在或无权限',
'data': None
}, status=status.HTTP_404_NOT_FOUND) }, status=status.HTTP_404_NOT_FOUND)
# 更新字段
data = request.data data = request.data
updateable_fields = ['question', 'answer', 'model_name'] updateable_fields = ['content', 'metadata']
for field in updateable_fields: if 'content' in data:
if field in data: record.content = data['content']
setattr(record, field, data[field])
if 'metadata' in data:
current_metadata = record.metadata or {}
current_metadata.update(data['metadata'])
record.metadata = current_metadata
record.save() record.save()
return Response({ return Response({
'code': 200,
'message': '更新成功', 'message': '更新成功',
'data': { 'data': {
'id': record.id, 'id': record.id,
'conversation_id': record.conversation_id,
'role': record.role,
'content': record.content,
'metadata': record.metadata,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S') 'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
} }
}) })
except Exception as e: except Exception as e:
logger.error(f"更新聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({ return Response({
'error': f'更新聊天记录失败: {str(e)}' 'code': 500,
'message': f'更新聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def destroy(self, request, pk=None): def destroy(self, request, pk=None):
"""删除聊天记录""" """删除聊天记录(软删除)"""
try: try:
# 获取记录
record = self.get_queryset().filter(id=pk).first() record = self.get_queryset().filter(id=pk).first()
if not record: if not record:
return Response({ return Response({
'error': '记录不存在或无权限' 'code': 404,
'message': '记录不存在或无权限',
'data': None
}, status=status.HTTP_404_NOT_FOUND) }, status=status.HTTP_404_NOT_FOUND)
# 删除记录 record.soft_delete()
record.delete()
return Response({ return Response({
'message': '删除成功' 'code': 200,
'message': '删除成功',
'data': None
}) })
except Exception as e: except Exception as e:
logger.error(f"删除聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({ return Response({
'error': f'删除聊天记录失败: {str(e)}' 'code': 500,
'message': f'删除聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get']) @action(detail=False, methods=['get'])
@ -273,20 +338,18 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
page = int(request.query_params.get('page', 1)) page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10)) page_size = int(request.query_params.get('page_size', 10))
# 基础查询:当前用户的记录 # 基础查询
query = self.get_queryset() query = self.get_queryset()
# 添加关键词搜索 # 添加过滤条件
if keyword: if keyword:
query = query.filter( query = query.filter(
Q(question__icontains=keyword) | # 问题包含关键词 Q(content__icontains=keyword) |
Q(answer__icontains=keyword) | # 回答包含关键词 Q(knowledge_base__name__icontains=keyword)
Q(dataset_name__icontains=keyword) # 知识库名称包含关键词
) )
# 添加其他过滤条件
if dataset_id: if dataset_id:
query = query.filter(dataset_id=dataset_id) query = query.filter(knowledge_base__id=dataset_id)
if start_date: if start_date:
query = query.filter(created_at__gte=start_date) query = query.filter(created_at__gte=start_date)
if end_date: if end_date:
@ -305,24 +368,24 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
for record in records: for record in records:
result = { result = {
'id': record.id, 'id': record.id,
'dataset_id': record.dataset_id, 'conversation_id': record.conversation_id,
'dataset_name': record.dataset_name, 'dataset_id': str(record.knowledge_base.id),
'question': record.question, 'dataset_name': record.knowledge_base.name,
'answer': record.answer, 'role': record.role,
'model_name': record.model_name, 'content': record.content,
'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S') 'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
'metadata': record.metadata
} }
# 如果有关键词,添加高亮信息
if keyword: if keyword:
result['highlights'] = { result['highlights'] = {
'question': self._highlight_keyword(record.question, keyword), 'content': self._highlight_keyword(record.content, keyword)
'answer': self._highlight_keyword(record.answer, keyword)
} }
results.append(result) results.append(result)
return Response({ return Response({
'code': 200,
'message': '搜索成功', 'message': '搜索成功',
'data': { 'data': {
'total': total, 'total': total,
@ -333,11 +396,12 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
}) })
except Exception as e: except Exception as e:
print(f"搜索失败: {str(e)}") logger.error(f"搜索聊天记录失败: {str(e)}")
print(f"错误类型: {type(e)}") logger.error(traceback.format_exc())
print(f"错误堆栈: {traceback.format_exc()}")
return Response({ return Response({
'error': f'搜索失败: {str(e)}' 'code': 500,
'message': f'搜索失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def _highlight_keyword(self, text, keyword): def _highlight_keyword(self, text, keyword):
@ -715,16 +779,51 @@ class KnowledgeBaseViewSet(viewsets.ModelViewSet):
"data": None "data": None
}, status=status.HTTP_403_FORBIDDEN) }, status=status.HTTP_403_FORBIDDEN)
# 执行更新 with transaction.atomic():
serializer = self.get_serializer(instance, data=request.data, partial=True) # 执行本地更新
serializer.is_valid(raise_exception=True) serializer = self.get_serializer(instance, data=request.data, partial=True)
self.perform_update(serializer) serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
return Response({ # 更新外部知识库
"code": 200, if instance.external_id:
"message": "知识库更新成功", try:
"data": serializer.data api_data = {
}) "name": serializer.validated_data.get('name', instance.name),
"desc": serializer.validated_data.get('desc', instance.desc),
"type": "0", # 保持与创建时一致
"meta": {}, # 保持与创建时一致
"documents": [] # 保持与创建时一致
}
response = requests.put(
f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}',
json=api_data,
headers={'Content-Type': 'application/json'},
timeout=30
)
if response.status_code != 200:
raise ExternalAPIError(f"更新外部知识库失败,状态码: {response.status_code}, 响应: {response.text}")
api_response = response.json()
if not api_response.get('code') == 200:
raise ExternalAPIError(f"更新外部知识库失败: {api_response.get('message', '未知错误')}")
logger.info(f"外部知识库更新成功: {instance.external_id}")
except requests.exceptions.Timeout:
raise ExternalAPIError("请求超时,请稍后重试")
except requests.exceptions.RequestException as e:
raise ExternalAPIError(f"API请求失败: {str(e)}")
except Exception as e:
raise ExternalAPIError(f"更新外部知识库失败: {str(e)}")
return Response({
"code": 200,
"message": "知识库更新成功",
"data": serializer.data
})
except Http404: except Http404:
return Response({ return Response({
@ -732,6 +831,14 @@ class KnowledgeBaseViewSet(viewsets.ModelViewSet):
"message": "知识库不存在", "message": "知识库不存在",
"data": None "data": None
}, status=status.HTTP_404_NOT_FOUND) }, status=status.HTTP_404_NOT_FOUND)
except ExternalAPIError as e:
logger.error(f"更新外部知识库失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": str(e),
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
except Exception as e: except Exception as e:
logger.error(f"更新知识库失败: {str(e)}") logger.error(f"更新知识库失败: {str(e)}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
@ -858,91 +965,108 @@ class KnowledgeBaseViewSet(viewsets.ModelViewSet):
try: try:
user = request.user user = request.user
# 基础查询:根据用户角色过滤 # 基础查询排除secret类型的知识库
if user.role == 'admin': queryset = KnowledgeBase.objects.exclude(type='secret')
# 管理员可以看到所有知识库
queryset = KnowledgeBase.objects.all() # 获取用户所有有效的知识库权限
else: active_permissions = KBPermissionModel.objects.filter(
# 其他用户看不到secret类型的知识库 user=user,
queryset = KnowledgeBase.objects.exclude(type='secret') status='active',
expires_at__gt=timezone.now()
).select_related('knowledge_base')
# 创建权限映射字典
permission_map = {
str(perm.knowledge_base.id): {
'can_read': perm.can_read,
'can_edit': perm.can_edit,
'can_delete': perm.can_delete
}
for perm in active_permissions
}
# 获取每个知识库的权限信息
summaries = [] summaries = []
for kb in queryset: for kb in queryset:
# 默认权限 # 获取基础权限
permissions = { permissions = {
"can_read": False, 'can_read': False,
"can_edit": False, 'can_edit': False,
"can_delete": False 'can_delete': False
} }
# 根据角色和知识库类型设置权限 # 检查知识库特定权限
if kb.type == 'admin': kb_id = str(kb.id)
permissions.update({ if kb_id in permission_map:
"can_read": user.role == 'admin', # 只有管理员可以读 permissions.update(permission_map[kb_id])
"can_edit": user.role == 'admin', # 只有管理员可以编辑 # 如果没有特定权限,根据角色和部门设置默认权限
"can_delete": user.role == 'admin' # 只有管理员可以删除 else:
})
elif kb.type == 'leader':
if user.role == 'admin': if user.role == 'admin':
permissions.update({ permissions.update({
"can_read": True, 'can_read': True,
"can_edit": True, 'can_edit': True,
"can_delete": True 'can_delete': True
}) })
elif user.role == 'leader' and kb.department == user.department: elif kb.type == 'leader':
if user.role == 'leader' and user.department == kb.department:
permissions.update({
'can_read': True,
'can_edit': True,
'can_delete': True
})
elif user.role == 'leader' and user.department == kb.department:
permissions.update({
'can_read': True,
'can_edit': True,
'can_delete': True
})
elif user.role == 'member' and user.department == kb.department:
permissions.update({
'can_read': True,
'can_edit': False,
'can_delete': False
})
elif kb.type == 'member':
if user.role == 'leader' and user.department == kb.department:
permissions.update({
'can_read': True,
'can_edit': True,
'can_delete': True
})
elif user.role == 'member' and user.department == kb.department:
permissions.update({
'can_read': True,
'can_edit': False,
'can_delete': False
})
elif kb.type == 'private' and str(kb.user_id) == str(user.id):
permissions.update({ permissions.update({
"can_read": True, 'can_read': True,
"can_edit": False, 'can_edit': True,
"can_delete": False 'can_delete': True
})
elif kb.type == 'member':
if user.role == 'admin':
permissions.update({
"can_read": True,
"can_edit": True,
"can_delete": True
})
elif user.role == 'leader' and kb.department == user.department:
permissions.update({
"can_read": True,
"can_edit": True,
"can_delete": True
})
elif user.role == 'member' and kb.department == user.department:
permissions.update({
"can_read": True,
"can_edit": False,
"can_delete": False
})
elif kb.type == 'private':
if str(kb.user_id) == str(user.id):
permissions.update({
"can_read": True,
"can_edit": True,
"can_delete": True
}) })
# 只返回概要信息
summary = { summary = {
"id": kb.id, 'id': str(kb.id),
"name": kb.name, 'name': kb.name,
"desc": kb.desc, 'desc': kb.desc,
"type": kb.type, 'type': kb.type,
"permissions": permissions 'department': kb.department,
'permissions': permissions
} }
summaries.append(summary) summaries.append(summary)
return Response({ return Response({
"code": 200, 'code': 200,
"message": "获取知识库概要信息成功", 'message': '获取知识库概要信息成功',
"data": summaries 'data': summaries
}) })
except Exception as e: except Exception as e:
return Response({ return Response({
"code": 500, 'code': 500,
"message": f"获取知识库概要信息失败: {str(e)}", 'message': f'获取知识库概要信息失败: {str(e)}',
"data": None 'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
class PermissionViewSet(viewsets.ModelViewSet): class PermissionViewSet(viewsets.ModelViewSet):