From 63b2aebf8442864da0c6ab83d0698cebcbc9c92b Mon Sep 17 00:00:00 2001 From: wanjia Date: Mon, 9 Jun 2025 16:56:16 +0800 Subject: [PATCH] =?UTF-8?q?=E5=93=8D=E5=BA=94=E5=8F=82=E6=95=B0=E8=A7=84?= =?UTF-8?q?=E8=8C=83=E4=BA=86flhf=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/rlhf/views.py | 337 ++++++++++++++++++++++++++------------------- 1 file changed, 199 insertions(+), 138 deletions(-) diff --git a/apps/rlhf/views.py b/apps/rlhf/views.py index 6a3db32..c59a537 100644 --- a/apps/rlhf/views.py +++ b/apps/rlhf/views.py @@ -23,7 +23,75 @@ from django.db import transaction from django.db.models.functions import TruncDate from apps.user.authentication import CustomTokenAuthentication -class ConversationViewSet(viewsets.ModelViewSet): + +# 创建统一响应格式的基类 +class StandardResponseMixin: + """标准响应格式的混合类,用于统一API响应格式""" + + def get_standard_response(self, data=None, message="成功", code=200, status_code=None): + """返回标准格式的响应""" + response_data = { + "code": code, + "message": message, + "data": data + } + return Response(response_data, status=status_code or status.HTTP_200_OK) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + self.perform_create(serializer) + headers = self.get_success_headers(serializer.data) + return self.get_standard_response( + data=serializer.data, + message="创建成功", + code=200, + status_code=status.HTTP_201_CREATED + ) + + def list(self, request, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + serializer = self.get_serializer(queryset, many=True) + return self.get_standard_response(data=serializer.data) + + def retrieve(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance) + return self.get_standard_response(data=serializer.data) + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + return self.get_standard_response(data=serializer.data, message="更新成功") + + def destroy(self, request, *args, **kwargs): + instance = self.get_object() + self.perform_destroy(instance) + return self.get_standard_response(message="删除成功", data=None) + + def get_paginated_response(self, data): + """ + 重写分页响应格式 + """ + assert self.paginator is not None + return self.get_standard_response( + data={ + 'count': self.paginator.page.paginator.count, + 'next': self.paginator.get_next_link(), + 'previous': self.paginator.get_previous_link(), + 'results': data + } + ) + + +class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): queryset = Conversation.objects.all() serializer_class = ConversationSerializer authentication_classes = [CustomTokenAuthentication] @@ -41,11 +109,7 @@ class ConversationViewSet(viewsets.ModelViewSet): conversation = self.get_object() messages = Message.objects.filter(conversation=conversation).order_by('timestamp') serializer = MessageSerializer(messages, many=True) - return Response({ - 'code': 200, - 'message': '成功', - 'data': serializer.data - }) + return self.get_standard_response(data=serializer.data) @action(detail=True, methods=['post']) def message(self, request, pk=None): @@ -53,11 +117,12 @@ class ConversationViewSet(viewsets.ModelViewSet): content = request.data.get('content') if not content: - return Response({ - 'code': 400, - 'message': '消息内容不能为空', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) + return self.get_standard_response( + code=400, + message='消息内容不能为空', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) # 创建用户消息 user_message = Message.objects.create( @@ -87,11 +152,7 @@ class ConversationViewSet(viewsets.ModelViewSet): MessageSerializer(ai_message).data ] - return Response({ - 'code': 200, - 'message': '成功', - 'data': messages - }) + return self.get_standard_response(data=messages) @action(detail=True, methods=['post']) def submit(self, request, pk=None): @@ -100,11 +161,12 @@ class ConversationViewSet(viewsets.ModelViewSet): description = request.data.get('description', '') if conversation.is_submitted: - return Response({ - 'code': 400, - 'message': '该对话已提交', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) + return self.get_standard_response( + code=400, + message='该对话已提交', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) # 更新对话为已提交状态 conversation.is_submitted = True @@ -130,22 +192,22 @@ class ConversationViewSet(viewsets.ModelViewSet): details={'title': title} ) - return Response({ - 'code': 200, - 'message': '对话提交成功', - 'data': {'submission_id': submission.id} - }) + return self.get_standard_response( + message='对话提交成功', + data={'submission_id': submission.id} + ) @action(detail=True, methods=['post']) def resume(self, request, pk=None): conversation = self.get_object() if not conversation.is_submitted: - return Response({ - 'code': 400, - 'message': '该对话未提交,无需恢复', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) + return self.get_standard_response( + code=400, + message='该对话未提交,无需恢复', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) # 更新对话为未提交状态 conversation.is_submitted = False @@ -168,11 +230,10 @@ class ConversationViewSet(viewsets.ModelViewSet): target_id=str(conversation.id) ) - return Response({ - 'code': 200, - 'message': '对话已恢复为未提交状态', - 'data': None - }) + return self.get_standard_response( + message='对话已恢复为未提交状态', + data=None + ) def _generate_ai_response(self, user_message, conversation): """ @@ -219,7 +280,7 @@ class ConversationViewSet(viewsets.ModelViewSet): stats.save() -class MessageViewSet(viewsets.ModelViewSet): +class MessageViewSet(StandardResponseMixin, viewsets.ModelViewSet): queryset = Message.objects.all() serializer_class = MessageSerializer authentication_classes = [CustomTokenAuthentication] @@ -230,7 +291,7 @@ class MessageViewSet(viewsets.ModelViewSet): return Message.objects.filter(conversation__user=user).order_by('timestamp') -class FeedbackViewSet(viewsets.ModelViewSet): +class FeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet): queryset = Feedback.objects.all() serializer_class = FeedbackSerializer authentication_classes = [CustomTokenAuthentication] @@ -246,21 +307,23 @@ class FeedbackViewSet(viewsets.ModelViewSet): feedback_value = request.data.get('feedback_value') if not message_id or not conversation_id: - return Response({ - 'code': 400, - 'message': '消息ID和对话ID不能为空', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) + return self.get_standard_response( + code=400, + message='消息ID和对话ID不能为空', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) try: message = Message.objects.get(id=message_id) conversation = Conversation.objects.get(id=conversation_id) except (Message.DoesNotExist, Conversation.DoesNotExist): - return Response({ - 'code': 404, - 'message': '消息或对话不存在', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) + return self.get_standard_response( + code=404, + message='消息或对话不存在', + data=None, + status_code=status.HTTP_404_NOT_FOUND + ) # 创建或更新反馈 feedback, created = Feedback.objects.update_or_create( @@ -277,11 +340,10 @@ class FeedbackViewSet(viewsets.ModelViewSet): # 更新用户的标注统计 self._update_annotation_stats(request.user.id, feedback_value) - return Response({ - 'code': 200, - 'message': '反馈提交成功', - 'data': FeedbackSerializer(feedback).data - }) + return self.get_standard_response( + message='反馈提交成功', + data=FeedbackSerializer(feedback).data + ) def _update_annotation_stats(self, user_id, feedback_value): """更新用户的标注统计信息""" @@ -311,7 +373,7 @@ class FeedbackViewSet(viewsets.ModelViewSet): stats.save() -class FeedbackTagViewSet(viewsets.ModelViewSet): +class FeedbackTagViewSet(StandardResponseMixin, viewsets.ModelViewSet): queryset = FeedbackTag.objects.all() serializer_class = FeedbackTagSerializer authentication_classes = [CustomTokenAuthentication] @@ -324,7 +386,7 @@ class FeedbackTagViewSet(viewsets.ModelViewSet): return FeedbackTag.objects.all() -class DetailedFeedbackViewSet(viewsets.ModelViewSet): +class DetailedFeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet): queryset = DetailedFeedback.objects.all() serializer_class = DetailedFeedbackSerializer authentication_classes = [CustomTokenAuthentication] @@ -344,21 +406,23 @@ class DetailedFeedbackViewSet(viewsets.ModelViewSet): is_inline = request.data.get('is_inline', True) if not message_id or not conversation_id or not feedback_type: - return Response({ - 'code': 400, - 'message': '消息ID、对话ID和反馈类型不能为空', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) + return self.get_standard_response( + code=400, + message='消息ID、对话ID和反馈类型不能为空', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) try: message = Message.objects.get(id=message_id) conversation = Conversation.objects.get(id=conversation_id) except (Message.DoesNotExist, Conversation.DoesNotExist): - return Response({ - 'code': 404, - 'message': '消息或对话不存在', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) + return self.get_standard_response( + code=404, + message='消息或对话不存在', + data=None, + status_code=status.HTTP_404_NOT_FOUND + ) # 将标签列表转换为JSON字符串 if isinstance(feedback_tags, list): @@ -394,11 +458,10 @@ class DetailedFeedbackViewSet(viewsets.ModelViewSet): # 更新用户的标注统计 self._update_annotation_stats(request.user.id, feedback_type) - return Response({ - 'code': 200, - 'message': '详细反馈提交成功', - 'data': DetailedFeedbackSerializer(detailed_feedback).data - }) + return self.get_standard_response( + message='详细反馈提交成功', + data=DetailedFeedbackSerializer(detailed_feedback).data + ) def _update_annotation_stats(self, user_id, feedback_type): """更新用户的标注统计信息""" @@ -428,7 +491,7 @@ class DetailedFeedbackViewSet(viewsets.ModelViewSet): stats.save() -class ConversationSubmissionViewSet(viewsets.ModelViewSet): +class ConversationSubmissionViewSet(StandardResponseMixin, viewsets.ModelViewSet): queryset = ConversationSubmission.objects.all() serializer_class = ConversationSubmissionSerializer authentication_classes = [CustomTokenAuthentication] @@ -454,11 +517,12 @@ class ConversationSubmissionViewSet(viewsets.ModelViewSet): @action(detail=True, methods=['post']) def review(self, request, pk=None): if request.user.role != 'admin': - return Response({ - 'code': 403, - 'message': '只有管理员可以进行审核', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) + return self.get_standard_response( + code=403, + message='只有管理员可以进行审核', + data=None, + status_code=status.HTTP_403_FORBIDDEN + ) submission = self.get_object() status_value = request.data.get('status') @@ -466,18 +530,20 @@ class ConversationSubmissionViewSet(viewsets.ModelViewSet): reviewer_notes = request.data.get('reviewer_notes', '') if not status_value or status_value not in ['accepted', 'rejected']: - return Response({ - 'code': 400, - 'message': '状态值无效', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) + return self.get_standard_response( + code=400, + message='状态值无效', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) if quality_score is not None and (quality_score < 1 or quality_score > 5): - return Response({ - 'code': 400, - 'message': '质量分数必须在1-5之间', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) + return self.get_standard_response( + code=400, + message='质量分数必须在1-5之间', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) # 更新提交状态 submission.status = status_value @@ -499,14 +565,13 @@ class ConversationSubmissionViewSet(viewsets.ModelViewSet): } ) - return Response({ - 'code': 200, - 'message': '审核完成', - 'data': ConversationSubmissionSerializer(submission).data - }) + return self.get_standard_response( + message='审核完成', + data=ConversationSubmissionSerializer(submission).data + ) -class ConversationEvaluationViewSet(viewsets.ModelViewSet): +class ConversationEvaluationViewSet(StandardResponseMixin, viewsets.ModelViewSet): queryset = ConversationEvaluation.objects.all() serializer_class = ConversationEvaluationSerializer authentication_classes = [CustomTokenAuthentication] @@ -523,20 +588,22 @@ class ConversationEvaluationViewSet(viewsets.ModelViewSet): needs_satisfied = request.data.get('needs_satisfied') if not conversation_id or not has_logical_issues or not needs_satisfied: - return Response({ - 'code': 400, - 'message': '对话ID、逻辑问题和需求满足度不能为空', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) + return self.get_standard_response( + code=400, + message='对话ID、逻辑问题和需求满足度不能为空', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) try: conversation = Conversation.objects.get(id=conversation_id) except Conversation.DoesNotExist: - return Response({ - 'code': 404, - 'message': '对话不存在', - 'data': None - }, status=status.HTTP_404_NOT_FOUND) + return self.get_standard_response( + code=404, + message='对话不存在', + data=None, + status_code=status.HTTP_404_NOT_FOUND + ) # 创建或更新评估 evaluation, created = ConversationEvaluation.objects.update_or_create( @@ -563,14 +630,13 @@ class ConversationEvaluationViewSet(viewsets.ModelViewSet): } ) - return Response({ - 'code': 200, - 'message': '评估提交成功', - 'data': ConversationEvaluationSerializer(evaluation).data - }) + return self.get_standard_response( + message='评估提交成功', + data=ConversationEvaluationSerializer(evaluation).data + ) -class SystemConfigViewSet(viewsets.ModelViewSet): +class SystemConfigViewSet(StandardResponseMixin, viewsets.ModelViewSet): queryset = SystemConfig.objects.all() serializer_class = SystemConfigSerializer authentication_classes = [CustomTokenAuthentication] @@ -589,33 +655,31 @@ class SystemConfigViewSet(viewsets.ModelViewSet): # 获取当前模型 model_config = SystemConfig.objects.filter(config_key='current_model').first() if model_config: - return Response({ - 'code': 200, - 'message': '成功', - 'data': {'model': model_config.config_value} - }) - return Response({ - 'code': 200, - 'message': '成功', - 'data': {'model': '默认模型'} - }) + return self.get_standard_response( + data={'model': model_config.config_value} + ) + return self.get_standard_response( + data={'model': '默认模型'} + ) elif request.method == 'POST': # 设置当前模型 if request.user.role != 'admin': - return Response({ - 'code': 403, - 'message': '只有管理员可以更改模型', - 'data': None - }, status=status.HTTP_403_FORBIDDEN) + return self.get_standard_response( + code=403, + message='只有管理员可以更改模型', + data=None, + status_code=status.HTTP_403_FORBIDDEN + ) model_name = request.data.get('model') if not model_name: - return Response({ - 'code': 400, - 'message': '模型名称不能为空', - 'data': None - }, status=status.HTTP_400_BAD_REQUEST) + return self.get_standard_response( + code=400, + message='模型名称不能为空', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) # 更新或创建配置 config, created = SystemConfig.objects.update_or_create( @@ -629,19 +693,16 @@ class SystemConfigViewSet(viewsets.ModelViewSet): } ) - return Response({ - 'code': 200, - 'message': '模型设置成功', - 'data': {'model': model_name} - }) + return self.get_standard_response( + message='模型设置成功', + data={'model': model_name} + ) @action(detail=False, methods=['get']) def models(self, request): # 返回可用的模型列表 - return Response({ - 'code': 200, - 'message': '成功', - 'data': { + return self.get_standard_response( + data={ 'models': [ {'id': 'model1', 'name': 'GPT-3.5'}, {'id': 'model2', 'name': 'GPT-4'}, @@ -650,4 +711,4 @@ class SystemConfigViewSet(viewsets.ModelViewSet): {'id': 'model5', 'name': 'Qwen'} ] } - }) \ No newline at end of file + ) \ No newline at end of file