diff --git a/requirements.txt b/requirements.txt index 5b42cfb..2e65c21 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/user_management/urls.py b/user_management/urls.py index ab74202..4a44b82 100644 --- a/user_management/urls.py +++ b/user_management/urls.py @@ -12,7 +12,8 @@ from .views import ( change_password, RegisterView, LoginView, - LogoutView + LogoutView, + ChatHistoryViewSet ) # 创建路由器 @@ -22,6 +23,7 @@ router = DefaultRouter() router.register(r'knowledge-bases', KnowledgeBaseViewSet, basename='knowledge-base') router.register(r'permissions', PermissionViewSet, basename='permission') router.register(r'notifications', NotificationViewSet, basename='notification') +router.register(r'chat-history', ChatHistoryViewSet, basename='chat-history') # URL patterns urlpatterns = [ diff --git a/user_management/views.py b/user_management/views.py index 3699077..565fbc4 100644 --- a/user_management/views.py +++ b/user_management/views.py @@ -83,50 +83,52 @@ class ChatHistoryViewSet(viewsets.ModelViewSet): queryset = ChatHistory.objects.all() def get_queryset(self): - """确保用户只能看到自己的聊天记录""" - return ChatHistory.objects.filter(user=self.request.user) + """确保用户只能看到自己的未删除的聊天记录""" + return ChatHistory.objects.filter( + user=self.request.user, + is_deleted=False + ) def list(self, request): - """获取聊天记录列表,按dataset_id分组""" + """获取聊天记录列表""" try: # 获取查询参数 dataset_id = request.query_params.get('dataset_id') page = int(request.query_params.get('page', 1)) page_size = int(request.query_params.get('page_size', 10)) - # 基础查询 query = self.get_queryset() if dataset_id: - # 如果指定了dataset_id,获取该数据集的完整对话历史 + # 获取特定知识库的完整对话历史 records = query.filter( - dataset_id=dataset_id - ).order_by('created_at') # 按时间正序排列 + knowledge_base__id=dataset_id + ).order_by('created_at') - # 序列化对话数据 conversation = { 'dataset_id': dataset_id, - 'dataset_name': records.first().dataset_name if records.exists() else None, + 'dataset_name': records.first().knowledge_base.name if records.exists() else None, 'messages': [{ 'id': record.id, - 'role': 'user' if idx % 2 == 0 else 'assistant', - 'content': record.question if idx % 2 == 0 else record.answer, + 'role': record.role, + 'content': record.content, 'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S') - } for idx, record in enumerate(records)] + } for record in records] } return Response({ + 'code': 200, 'message': '获取成功', 'data': conversation }) else: - # 如果没有指定dataset_id,获取所有对话的概览 - # 按dataset_id分组,获取最新一条记录 + # 获取所有对话的概览 latest_chats = query.values( - 'dataset_id' + 'conversation_id', + 'knowledge_base__id', + 'knowledge_base__name' ).annotate( latest_id=Max('id'), - dataset_name=F('dataset_name'), message_count=Count('id'), last_message=Max('created_at') ).order_by('-last_message') @@ -143,14 +145,16 @@ class ChatHistoryViewSet(viewsets.ModelViewSet): for chat in chats: latest_record = ChatHistory.objects.get(id=chat['latest_id']) results.append({ - 'dataset_id': chat['dataset_id'], - 'dataset_name': chat['dataset_name'], + 'conversation_id': chat['conversation_id'], + 'dataset_id': str(chat['knowledge_base__id']), + 'dataset_name': chat['knowledge_base__name'], 'message_count': chat['message_count'], - 'last_message': latest_record.answer, + 'last_message': latest_record.content, 'last_time': chat['last_message'].strftime('%Y-%m-%d %H:%M:%S') }) return Response({ + 'code': 200, 'message': '获取成功', 'data': { 'total': total, @@ -161,12 +165,16 @@ class ChatHistoryViewSet(viewsets.ModelViewSet): }) except Exception as e: + logger.error(f"获取聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) return Response({ - 'error': f'获取聊天记录失败: {str(e)}' + 'code': 500, + 'message': f'获取聊天记录失败: {str(e)}', + 'data': None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) def create(self, request): - """创建新的聊天记录""" + """创建聊天记录""" try: data = request.data required_fields = ['dataset_id', 'dataset_name', 'question', 'answer'] @@ -175,90 +183,147 @@ class ChatHistoryViewSet(viewsets.ModelViewSet): for field in required_fields: if field not in data: return Response({ - 'error': f'缺少必填字段: {field}' + 'code': 400, + 'message': f'缺少必填字段: {field}', + 'data': None }, status=status.HTTP_400_BAD_REQUEST) - # 创建记录 - record = ChatHistory.objects.create( + # 获取或创建对话ID + conversation_id = data.get('conversation_id', str(uuid.uuid4())) + + # 获取知识库 - 不进行 UUID 转换 + try: + knowledge_base = KnowledgeBase.objects.filter(id=data['dataset_id']).first() + if not knowledge_base: + return Response({ + 'code': 404, + 'message': '知识库不存在', + 'data': None + }, status=status.HTTP_404_NOT_FOUND) + except Exception as e: + return Response({ + 'code': 400, + 'message': f'无效的知识库ID: {str(e)}', + 'data': None + }, status=status.HTTP_400_BAD_REQUEST) + + # 创建用户问题记录 + question_record = ChatHistory.objects.create( user=request.user, - dataset_id=data['dataset_id'], - dataset_name=data['dataset_name'], - question=data['question'], - answer=data['answer'], - model_name=data.get('model_name', 'default') + knowledge_base=knowledge_base, + conversation_id=conversation_id, + role='user', + content=data['question'], + metadata={'model_name': data.get('model_name', 'default')} + ) + + # 创建AI回答记录 + answer_record = ChatHistory.objects.create( + user=request.user, + knowledge_base=knowledge_base, + conversation_id=conversation_id, + parent_id=str(question_record.id), + role='assistant', + content=data['answer'], + metadata={'model_name': data.get('model_name', 'default')} ) return Response({ + 'code': 200, 'message': '创建成功', 'data': { - 'id': record.id, - 'dataset_id': record.dataset_id, + 'id': answer_record.id, + 'conversation_id': conversation_id, + 'dataset_id': str(knowledge_base.id), 'role': 'assistant', - 'content': record.answer, - 'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S') + 'content': answer_record.content, + 'created_at': answer_record.created_at.strftime('%Y-%m-%d %H:%M:%S') } }, status=status.HTTP_201_CREATED) except Exception as e: + logger.error(f"创建聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) return Response({ - 'error': f'创建聊天记录失败: {str(e)}' + 'code': 500, + 'message': f'创建聊天记录失败: {str(e)}', + 'data': None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) def update(self, request, pk=None): """更新聊天记录""" try: - # 获取记录 record = self.get_queryset().filter(id=pk).first() if not record: return Response({ - 'error': '记录不存在或无权限' + 'code': 404, + 'message': '记录不存在或无权限', + 'data': None }, status=status.HTTP_404_NOT_FOUND) - # 更新字段 data = request.data - updateable_fields = ['question', 'answer', 'model_name'] + updateable_fields = ['content', 'metadata'] - for field in updateable_fields: - if field in data: - setattr(record, field, data[field]) + if 'content' in data: + record.content = data['content'] + + if 'metadata' in data: + current_metadata = record.metadata or {} + current_metadata.update(data['metadata']) + record.metadata = current_metadata record.save() return Response({ + 'code': 200, 'message': '更新成功', 'data': { 'id': record.id, + 'conversation_id': record.conversation_id, + 'role': record.role, + 'content': record.content, + 'metadata': record.metadata, 'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S') } }) except Exception as e: + logger.error(f"更新聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) return Response({ - 'error': f'更新聊天记录失败: {str(e)}' + 'code': 500, + 'message': f'更新聊天记录失败: {str(e)}', + 'data': None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) def destroy(self, request, pk=None): - """删除聊天记录""" + """删除聊天记录(软删除)""" try: - # 获取记录 record = self.get_queryset().filter(id=pk).first() if not record: return Response({ - 'error': '记录不存在或无权限' + 'code': 404, + 'message': '记录不存在或无权限', + 'data': None }, status=status.HTTP_404_NOT_FOUND) - # 删除记录 - record.delete() + record.soft_delete() return Response({ - 'message': '删除成功' + 'code': 200, + 'message': '删除成功', + 'data': None }) except Exception as e: + logger.error(f"删除聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) return Response({ - 'error': f'删除聊天记录失败: {str(e)}' + 'code': 500, + 'message': f'删除聊天记录失败: {str(e)}', + 'data': None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) @action(detail=False, methods=['get']) @@ -273,20 +338,18 @@ class ChatHistoryViewSet(viewsets.ModelViewSet): page = int(request.query_params.get('page', 1)) page_size = int(request.query_params.get('page_size', 10)) - # 基础查询:当前用户的记录 + # 基础查询 query = self.get_queryset() - # 添加关键词搜索 + # 添加过滤条件 if keyword: query = query.filter( - Q(question__icontains=keyword) | # 问题包含关键词 - Q(answer__icontains=keyword) | # 回答包含关键词 - Q(dataset_name__icontains=keyword) # 知识库名称包含关键词 + Q(content__icontains=keyword) | + Q(knowledge_base__name__icontains=keyword) ) - # 添加其他过滤条件 if dataset_id: - query = query.filter(dataset_id=dataset_id) + query = query.filter(knowledge_base__id=dataset_id) if start_date: query = query.filter(created_at__gte=start_date) if end_date: @@ -305,24 +368,24 @@ class ChatHistoryViewSet(viewsets.ModelViewSet): for record in records: result = { 'id': record.id, - 'dataset_id': record.dataset_id, - 'dataset_name': record.dataset_name, - 'question': record.question, - 'answer': record.answer, - 'model_name': record.model_name, - 'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S') + 'conversation_id': record.conversation_id, + 'dataset_id': str(record.knowledge_base.id), + 'dataset_name': record.knowledge_base.name, + 'role': record.role, + 'content': record.content, + 'created_at': record.created_at.strftime('%Y-%m-%d %H:%M:%S'), + 'metadata': record.metadata } - # 如果有关键词,添加高亮信息 if keyword: result['highlights'] = { - 'question': self._highlight_keyword(record.question, keyword), - 'answer': self._highlight_keyword(record.answer, keyword) + 'content': self._highlight_keyword(record.content, keyword) } results.append(result) return Response({ + 'code': 200, 'message': '搜索成功', 'data': { 'total': total, @@ -333,11 +396,12 @@ class ChatHistoryViewSet(viewsets.ModelViewSet): }) except Exception as e: - print(f"搜索失败: {str(e)}") - print(f"错误类型: {type(e)}") - print(f"错误堆栈: {traceback.format_exc()}") + logger.error(f"搜索聊天记录失败: {str(e)}") + logger.error(traceback.format_exc()) return Response({ - 'error': f'搜索失败: {str(e)}' + 'code': 500, + 'message': f'搜索失败: {str(e)}', + 'data': None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) def _highlight_keyword(self, text, keyword): @@ -715,16 +779,51 @@ class KnowledgeBaseViewSet(viewsets.ModelViewSet): "data": None }, status=status.HTTP_403_FORBIDDEN) - # 执行更新 - serializer = self.get_serializer(instance, data=request.data, partial=True) - serializer.is_valid(raise_exception=True) - self.perform_update(serializer) + with transaction.atomic(): + # 执行本地更新 + serializer = self.get_serializer(instance, data=request.data, partial=True) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) - return Response({ - "code": 200, - "message": "知识库更新成功", - "data": serializer.data - }) + # 更新外部知识库 + if instance.external_id: + try: + api_data = { + "name": serializer.validated_data.get('name', instance.name), + "desc": serializer.validated_data.get('desc', instance.desc), + "type": "0", # 保持与创建时一致 + "meta": {}, # 保持与创建时一致 + "documents": [] # 保持与创建时一致 + } + + response = requests.put( + f'{settings.API_BASE_URL}/api/dataset/{instance.external_id}', + json=api_data, + headers={'Content-Type': 'application/json'}, + timeout=30 + ) + + if response.status_code != 200: + raise ExternalAPIError(f"更新外部知识库失败,状态码: {response.status_code}, 响应: {response.text}") + + api_response = response.json() + if not api_response.get('code') == 200: + raise ExternalAPIError(f"更新外部知识库失败: {api_response.get('message', '未知错误')}") + + logger.info(f"外部知识库更新成功: {instance.external_id}") + + except requests.exceptions.Timeout: + raise ExternalAPIError("请求超时,请稍后重试") + except requests.exceptions.RequestException as e: + raise ExternalAPIError(f"API请求失败: {str(e)}") + except Exception as e: + raise ExternalAPIError(f"更新外部知识库失败: {str(e)}") + + return Response({ + "code": 200, + "message": "知识库更新成功", + "data": serializer.data + }) except Http404: return Response({ @@ -732,6 +831,14 @@ class KnowledgeBaseViewSet(viewsets.ModelViewSet): "message": "知识库不存在", "data": None }, status=status.HTTP_404_NOT_FOUND) + except ExternalAPIError as e: + logger.error(f"更新外部知识库失败: {str(e)}") + logger.error(traceback.format_exc()) + return Response({ + "code": 500, + "message": str(e), + "data": None + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) except Exception as e: logger.error(f"更新知识库失败: {str(e)}") logger.error(traceback.format_exc()) @@ -858,91 +965,108 @@ class KnowledgeBaseViewSet(viewsets.ModelViewSet): try: user = request.user - # 基础查询:根据用户角色过滤 - if user.role == 'admin': - # 管理员可以看到所有知识库 - queryset = KnowledgeBase.objects.all() - else: - # 其他用户看不到secret类型的知识库 - queryset = KnowledgeBase.objects.exclude(type='secret') + # 基础查询:排除secret类型的知识库 + queryset = KnowledgeBase.objects.exclude(type='secret') + + # 获取用户所有有效的知识库权限 + active_permissions = KBPermissionModel.objects.filter( + user=user, + status='active', + expires_at__gt=timezone.now() + ).select_related('knowledge_base') + + # 创建权限映射字典 + permission_map = { + str(perm.knowledge_base.id): { + 'can_read': perm.can_read, + 'can_edit': perm.can_edit, + 'can_delete': perm.can_delete + } + for perm in active_permissions + } - # 获取每个知识库的权限信息 summaries = [] for kb in queryset: - # 默认权限 + # 获取基础权限 permissions = { - "can_read": False, - "can_edit": False, - "can_delete": False + 'can_read': False, + 'can_edit': False, + 'can_delete': False } - - # 根据角色和知识库类型设置权限 - if kb.type == 'admin': - permissions.update({ - "can_read": user.role == 'admin', # 只有管理员可以读 - "can_edit": user.role == 'admin', # 只有管理员可以编辑 - "can_delete": user.role == 'admin' # 只有管理员可以删除 - }) - elif kb.type == 'leader': + + # 检查知识库特定权限 + kb_id = str(kb.id) + if kb_id in permission_map: + permissions.update(permission_map[kb_id]) + # 如果没有特定权限,根据角色和部门设置默认权限 + else: if user.role == 'admin': permissions.update({ - "can_read": True, - "can_edit": True, - "can_delete": True + 'can_read': True, + 'can_edit': True, + 'can_delete': True }) - elif user.role == 'leader' and kb.department == user.department: + elif kb.type == 'leader': + if user.role == 'leader' and user.department == kb.department: + permissions.update({ + 'can_read': True, + 'can_edit': True, + 'can_delete': True + }) + elif user.role == 'leader' and user.department == kb.department: + permissions.update({ + 'can_read': True, + 'can_edit': True, + 'can_delete': True + }) + elif user.role == 'member' and user.department == kb.department: + permissions.update({ + 'can_read': True, + 'can_edit': False, + 'can_delete': False + }) + elif kb.type == 'member': + if user.role == 'leader' and user.department == kb.department: + permissions.update({ + 'can_read': True, + 'can_edit': True, + 'can_delete': True + }) + elif user.role == 'member' and user.department == kb.department: + permissions.update({ + 'can_read': True, + 'can_edit': False, + 'can_delete': False + }) + elif kb.type == 'private' and str(kb.user_id) == str(user.id): permissions.update({ - "can_read": True, - "can_edit": False, - "can_delete": False - }) - elif kb.type == 'member': - if user.role == 'admin': - permissions.update({ - "can_read": True, - "can_edit": True, - "can_delete": True - }) - elif user.role == 'leader' and kb.department == user.department: - permissions.update({ - "can_read": True, - "can_edit": True, - "can_delete": True - }) - elif user.role == 'member' and kb.department == user.department: - permissions.update({ - "can_read": True, - "can_edit": False, - "can_delete": False - }) - elif kb.type == 'private': - if str(kb.user_id) == str(user.id): - permissions.update({ - "can_read": True, - "can_edit": True, - "can_delete": True + 'can_read': True, + 'can_edit': True, + 'can_delete': True }) + # 只返回概要信息 summary = { - "id": kb.id, - "name": kb.name, - "desc": kb.desc, - "type": kb.type, - "permissions": permissions + 'id': str(kb.id), + 'name': kb.name, + 'desc': kb.desc, + 'type': kb.type, + 'department': kb.department, + 'permissions': permissions } summaries.append(summary) return Response({ - "code": 200, - "message": "获取知识库概要信息成功", - "data": summaries + 'code': 200, + 'message': '获取知识库概要信息成功', + 'data': summaries }) except Exception as e: return Response({ - "code": 500, - "message": f"获取知识库概要信息失败: {str(e)}", - "data": None + 'code': 500, + 'message': f'获取知识库概要信息失败: {str(e)}', + 'data': None }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) class PermissionViewSet(viewsets.ModelViewSet):