fix: 修复聊天记录知识库ID字段问题

This commit is contained in:
wanjia 2025-03-03 14:17:02 +08:00
parent 9ef1608809
commit 07a33dbf54
3 changed files with 270 additions and 144 deletions

Binary file not shown.

View File

@ -12,7 +12,8 @@ from .views import (
change_password,
RegisterView,
LoginView,
LogoutView
LogoutView,
ChatHistoryViewSet
)
# 创建路由器
@ -22,6 +23,7 @@ router = DefaultRouter()
router.register(r'knowledge-bases', KnowledgeBaseViewSet, basename='knowledge-base')
router.register(r'permissions', PermissionViewSet, basename='permission')
router.register(r'notifications', NotificationViewSet, basename='notification')
router.register(r'chat-history', ChatHistoryViewSet, basename='chat-history')
# URL patterns
urlpatterns = [

View File

@ -83,50 +83,52 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
queryset = ChatHistory.objects.all()
def get_queryset(self):
"""确保用户只能看到自己的聊天记录"""
return ChatHistory.objects.filter(user=self.request.user)
"""确保用户只能看到自己的未删除的聊天记录"""
return ChatHistory.objects.filter(
user=self.request.user,
is_deleted=False
)
def list(self, request):
"""获取聊天记录列表按dataset_id分组"""
"""获取聊天记录列表"""
try:
# 获取查询参数
dataset_id = request.query_params.get('dataset_id')
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10))
# 基础查询
query = self.get_queryset()
if dataset_id:
# 如果指定了dataset_id获取该数据集的完整对话历史
# 获取特定知识库的完整对话历史
records = query.filter(
dataset_id=dataset_id
).order_by('created_at') # 按时间正序排列
knowledge_base__id=dataset_id
).order_by('created_at')
# 序列化对话数据
conversation = {
'dataset_id': dataset_id,
'dataset_name': records.first().dataset_name if records.exists() else None,
'dataset_name': records.first().knowledge_base.name if records.exists() else None,
'messages': [{
'id': record.id,
'role': 'user' if idx % 2 == 0 else 'assistant',
'content': record.question if idx % 2 == 0 else record.answer,
'role': record.role,
'content': record.content,
'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S')
} for idx, record in enumerate(records)]
} for record in records]
}
return Response({
'code': 200,
'message': '获取成功',
'data': conversation
})
else:
# 如果没有指定dataset_id获取所有对话的概览
# 按dataset_id分组获取最新一条记录
# 获取所有对话的概览
latest_chats = query.values(
'dataset_id'
'conversation_id',
'knowledge_base__id',
'knowledge_base__name'
).annotate(
latest_id=Max('id'),
dataset_name=F('dataset_name'),
message_count=Count('id'),
last_message=Max('created_at')
).order_by('-last_message')
@ -143,14 +145,16 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
for chat in chats:
latest_record = ChatHistory.objects.get(id=chat['latest_id'])
results.append({
'dataset_id': chat['dataset_id'],
'dataset_name': chat['dataset_name'],
'conversation_id': chat['conversation_id'],
'dataset_id': str(chat['knowledge_base__id']),
'dataset_name': chat['knowledge_base__name'],
'message_count': chat['message_count'],
'last_message': latest_record.answer,
'last_message': latest_record.content,
'last_time': chat['last_message'].strftime('%Y-%m-%d %H:%M:%S')
})
return Response({
'code': 200,
'message': '获取成功',
'data': {
'total': total,
@ -161,12 +165,16 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
})
except Exception as e:
logger.error(f"获取聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'error': f'获取聊天记录失败: {str(e)}'
'code': 500,
'message': f'获取聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def create(self, request):
"""创建新的聊天记录"""
"""创建聊天记录"""
try:
data = request.data
required_fields = ['dataset_id', 'dataset_name', 'question', 'answer']
@ -175,90 +183,147 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
for field in required_fields:
if field not in data:
return Response({
'error': f'缺少必填字段: {field}'
'code': 400,
'message': f'缺少必填字段: {field}',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 创建记录
record = ChatHistory.objects.create(
# 获取或创建对话ID
conversation_id = data.get('conversation_id', str(uuid.uuid4()))
# 获取知识库 - 不进行 UUID 转换
try:
knowledge_base = KnowledgeBase.objects.filter(id=data['dataset_id']).first()
if not knowledge_base:
return Response({
'code': 404,
'message': '知识库不存在',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
except Exception as e:
return Response({
'code': 400,
'message': f'无效的知识库ID: {str(e)}',
'data': None
}, status=status.HTTP_400_BAD_REQUEST)
# 创建用户问题记录
question_record = ChatHistory.objects.create(
user=request.user,
dataset_id=data['dataset_id'],
dataset_name=data['dataset_name'],
question=data['question'],
answer=data['answer'],
model_name=data.get('model_name', 'default')
knowledge_base=knowledge_base,
conversation_id=conversation_id,
role='user',
content=data['question'],
metadata={'model_name': data.get('model_name', 'default')}
)
# 创建AI回答记录
answer_record = ChatHistory.objects.create(
user=request.user,
knowledge_base=knowledge_base,
conversation_id=conversation_id,
parent_id=str(question_record.id),
role='assistant',
content=data['answer'],
metadata={'model_name': data.get('model_name', 'default')}
)
return Response({
'code': 200,
'message': '创建成功',
'data': {
'id': record.id,
'dataset_id': record.dataset_id,
'id': answer_record.id,
'conversation_id': conversation_id,
'dataset_id': str(knowledge_base.id),
'role': 'assistant',
'content': record.answer,
'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S')
'content': answer_record.content,
'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S')
}
}, status=status.HTTP_201_CREATED)
except Exception as e:
logger.error(f"创建聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'error': f'创建聊天记录失败: {str(e)}'
'code': 500,
'message': f'创建聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def update(self, request, pk=None):
"""更新聊天记录"""
try:
# 获取记录
record = self.get_queryset().filter(id=pk).first()
if not record:
return Response({
'error': '记录不存在或无权限'
'code': 404,
'message': '记录不存在或无权限',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 更新字段
data = request.data
updateable_fields = ['question', 'answer', 'model_name']
updateable_fields = ['content', 'metadata']
for field in updateable_fields:
if field in data:
setattr(record, field, data[field])
if 'content' in data:
record.content = data['content']
if 'metadata' in data:
current_metadata = record.metadata or {}
current_metadata.update(data['metadata'])
record.metadata = current_metadata
record.save()
return Response({
'code': 200,
'message': '更新成功',
'data': {
'id': record.id,
'conversation_id': record.conversation_id,
'role': record.role,
'content': record.content,
'metadata': record.metadata,
'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
})
except Exception as e:
logger.error(f"更新聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'error': f'更新聊天记录失败: {str(e)}'
'code': 500,
'message': f'更新聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def destroy(self, request, pk=None):
"""删除聊天记录"""
"""删除聊天记录(软删除)"""
try:
# 获取记录
record = self.get_queryset().filter(id=pk).first()
if not record:
return Response({
'error': '记录不存在或无权限'
'code': 404,
'message': '记录不存在或无权限',
'data': None
}, status=status.HTTP_404_NOT_FOUND)
# 删除记录
record.delete()
record.soft_delete()
return Response({
'message': '删除成功'
'code': 200,
'message': '删除成功',
'data': None
})
except Exception as e:
logger.error(f"删除聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'error': f'删除聊天记录失败: {str(e)}'
'code': 500,
'message': f'删除聊天记录失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@action(detail=False, methods=['get'])
@ -273,20 +338,18 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('page_size', 10))
# 基础查询:当前用户的记录
# 基础查询
query = self.get_queryset()
# 添加关键词搜索
# 添加过滤条件
if keyword:
query = query.filter(
Q(question__icontains=keyword) | # 问题包含关键词
Q(answer__icontains=keyword) | # 回答包含关键词
Q(dataset_name__icontains=keyword) # 知识库名称包含关键词
Q(content__icontains=keyword) |
Q(knowledge_base__name__icontains=keyword)
)
# 添加其他过滤条件
if dataset_id:
query = query.filter(dataset_id=dataset_id)
query = query.filter(knowledge_base__id=dataset_id)
if start_date:
query = query.filter(created_at__gte=start_date)
if end_date:
@ -305,24 +368,24 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
for record in records:
result = {
'id': record.id,
'dataset_id': record.dataset_id,
'dataset_name': record.dataset_name,
'question': record.question,
'answer': record.answer,
'model_name': record.model_name,
'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S')
'conversation_id': record.conversation_id,
'dataset_id': str(record.knowledge_base.id),
'dataset_name': record.knowledge_base.name,
'role': record.role,
'content': record.content,
'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S'),
'metadata': record.metadata
}
# 如果有关键词,添加高亮信息
if keyword:
result['highlights'] = {
'question': self._highlight_keyword(record.question, keyword),
'answer': self._highlight_keyword(record.answer, keyword)
'content': self._highlight_keyword(record.content, keyword)
}
results.append(result)
return Response({
'code': 200,
'message': '搜索成功',
'data': {
'total': total,
@ -333,11 +396,12 @@ class ChatHistoryViewSet(viewsets.ModelViewSet):
})
except Exception as e:
print(f"搜索失败: {str(e)}")
print(f"错误类型: {type(e)}")
print(f"错误堆栈: {traceback.format_exc()}")
logger.error(f"搜索聊天记录失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
'error': f'搜索失败: {str(e)}'
'code': 500,
'message': f'搜索失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def _highlight_keyword(self, text, keyword):
@ -715,16 +779,51 @@ class KnowledgeBaseViewSet(viewsets.ModelViewSet):
"data": None
}, status=status.HTTP_403_FORBIDDEN)
# 执行更新
serializer = self.get_serializer(instance, data=request.data, partial=True)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
with transaction.atomic():
# 执行本地更新
serializer = self.get_serializer(instance, data=request.data, partial=True)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
return Response({
"code": 200,
"message": "知识库更新成功",
"data": serializer.data
})
# 更新外部知识库
if instance.external_id:
try:
api_data = {
"name": serializer.validated_data.get('name', instance.name),
"desc": serializer.validated_data.get('desc', instance.desc),
"type": "0", # 保持与创建时一致
"meta": {}, # 保持与创建时一致
"documents": [] # 保持与创建时一致
}
response = requests.put(
f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}',
json=api_data,
headers={'Content-Type': 'application/json'},
timeout=30
)
if response.status_code != 200:
raise ExternalAPIError(f"更新外部知识库失败,状态码: {response.status_code}, 响应: {response.text}")
api_response = response.json()
if not api_response.get('code') == 200:
raise ExternalAPIError(f"更新外部知识库失败: {api_response.get('message', '未知错误')}")
logger.info(f"外部知识库更新成功: {instance.external_id}")
except requests.exceptions.Timeout:
raise ExternalAPIError("请求超时,请稍后重试")
except requests.exceptions.RequestException as e:
raise ExternalAPIError(f"API请求失败: {str(e)}")
except Exception as e:
raise ExternalAPIError(f"更新外部知识库失败: {str(e)}")
return Response({
"code": 200,
"message": "知识库更新成功",
"data": serializer.data
})
except Http404:
return Response({
@ -732,6 +831,14 @@ class KnowledgeBaseViewSet(viewsets.ModelViewSet):
"message": "知识库不存在",
"data": None
}, status=status.HTTP_404_NOT_FOUND)
except ExternalAPIError as e:
logger.error(f"更新外部知识库失败: {str(e)}")
logger.error(traceback.format_exc())
return Response({
"code": 500,
"message": str(e),
"data": None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
except Exception as e:
logger.error(f"更新知识库失败: {str(e)}")
logger.error(traceback.format_exc())
@ -858,91 +965,108 @@ class KnowledgeBaseViewSet(viewsets.ModelViewSet):
try:
user = request.user
# 基础查询:根据用户角色过滤
if user.role == 'admin':
# 管理员可以看到所有知识库
queryset = KnowledgeBase.objects.all()
else:
# 其他用户看不到secret类型的知识库
queryset = KnowledgeBase.objects.exclude(type='secret')
# 基础查询排除secret类型的知识库
queryset = KnowledgeBase.objects.exclude(type='secret')
# 获取用户所有有效的知识库权限
active_permissions = KBPermissionModel.objects.filter(
user=user,
status='active',
expires_at__gt=timezone.now()
).select_related('knowledge_base')
# 创建权限映射字典
permission_map = {
str(perm.knowledge_base.id): {
'can_read': perm.can_read,
'can_edit': perm.can_edit,
'can_delete': perm.can_delete
}
for perm in active_permissions
}
# 获取每个知识库的权限信息
summaries = []
for kb in queryset:
# 默认权限
# 获取基础权限
permissions = {
"can_read": False,
"can_edit": False,
"can_delete": False
'can_read': False,
'can_edit': False,
'can_delete': False
}
# 根据角色和知识库类型设置权限
if kb.type == 'admin':
permissions.update({
"can_read": user.role == 'admin', # 只有管理员可以读
"can_edit": user.role == 'admin', # 只有管理员可以编辑
"can_delete": user.role == 'admin' # 只有管理员可以删除
})
elif kb.type == 'leader':
# 检查知识库特定权限
kb_id = str(kb.id)
if kb_id in permission_map:
permissions.update(permission_map[kb_id])
# 如果没有特定权限,根据角色和部门设置默认权限
else:
if user.role == 'admin':
permissions.update({
"can_read": True,
"can_edit": True,
"can_delete": True
'can_read': True,
'can_edit': True,
'can_delete': True
})
elif user.role == 'leader' and kb.department == user.department:
elif kb.type == 'leader':
if user.role == 'leader' and user.department == kb.department:
permissions.update({
'can_read': True,
'can_edit': True,
'can_delete': True
})
elif user.role == 'leader' and user.department == kb.department:
permissions.update({
'can_read': True,
'can_edit': True,
'can_delete': True
})
elif user.role == 'member' and user.department == kb.department:
permissions.update({
'can_read': True,
'can_edit': False,
'can_delete': False
})
elif kb.type == 'member':
if user.role == 'leader' and user.department == kb.department:
permissions.update({
'can_read': True,
'can_edit': True,
'can_delete': True
})
elif user.role == 'member' and user.department == kb.department:
permissions.update({
'can_read': True,
'can_edit': False,
'can_delete': False
})
elif kb.type == 'private' and str(kb.user_id) == str(user.id):
permissions.update({
"can_read": True,
"can_edit": False,
"can_delete": False
})
elif kb.type == 'member':
if user.role == 'admin':
permissions.update({
"can_read": True,
"can_edit": True,
"can_delete": True
})
elif user.role == 'leader' and kb.department == user.department:
permissions.update({
"can_read": True,
"can_edit": True,
"can_delete": True
})
elif user.role == 'member' and kb.department == user.department:
permissions.update({
"can_read": True,
"can_edit": False,
"can_delete": False
})
elif kb.type == 'private':
if str(kb.user_id) == str(user.id):
permissions.update({
"can_read": True,
"can_edit": True,
"can_delete": True
'can_read': True,
'can_edit': True,
'can_delete': True
})
# 只返回概要信息
summary = {
"id": kb.id,
"name": kb.name,
"desc": kb.desc,
"type": kb.type,
"permissions": permissions
'id': str(kb.id),
'name': kb.name,
'desc': kb.desc,
'type': kb.type,
'department': kb.department,
'permissions': permissions
}
summaries.append(summary)
return Response({
"code": 200,
"message": "获取知识库概要信息成功",
"data": summaries
'code': 200,
'message': '获取知识库概要信息成功',
'data': summaries
})
except Exception as e:
return Response({
"code": 500,
"message": f"获取知识库概要信息失败: {str(e)}",
"data": None
'code': 500,
'message': f'获取知识库概要信息失败: {str(e)}',
'data': None
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
class PermissionViewSet(viewsets.ModelViewSet):