From f9d06f22deb07398289feecc23693d874356a08f Mon Sep 17 00:00:00 2001 From: wanjia Date: Mon, 9 Jun 2025 18:21:37 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84rlhf?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/rlhf/apps.py | 10 +- apps/rlhf/urls.py | 33 ++- apps/rlhf/views.py | 603 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 639 insertions(+), 7 deletions(-) diff --git a/apps/rlhf/apps.py b/apps/rlhf/apps.py index 6fd3057..75a5fb0 100644 --- a/apps/rlhf/apps.py +++ b/apps/rlhf/apps.py @@ -1,6 +1,6 @@ -from django.apps import AppConfig - - -class RlhfConfig(AppConfig): - default_auto_field = 'django.db.models.BigAutoField' +from django.apps import AppConfig + + +class RlhfConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' name = 'apps.rlhf' \ No newline at end of file diff --git a/apps/rlhf/urls.py b/apps/rlhf/urls.py index 3064636..5088da7 100644 --- a/apps/rlhf/urls.py +++ b/apps/rlhf/urls.py @@ -18,12 +18,43 @@ router.register(r'system-config', SystemConfigViewSet) urlpatterns = [ path('', include(router.urls)), - # 额外的RLHF相关API端点 + + # 对话相关API端点 path('conversation//messages/', ConversationViewSet.as_view({'get': 'messages'}), name='conversation-messages'), path('conversation//message/', ConversationViewSet.as_view({'post': 'message'}), name='send-message'), path('conversation//submit/', ConversationViewSet.as_view({'post': 'submit'}), name='submit-conversation'), path('conversation//resume/', ConversationViewSet.as_view({'post': 'resume'}), name='resume-conversation'), + + # 仪表盘和统计分析API端点 + path('dashboard/', ConversationViewSet.as_view({'get': 'dashboard'}), name='dashboard'), + + # 提交评审API端点 path('submission//review/', ConversationSubmissionViewSet.as_view({'post': 'review'}), name='review-submission'), + + # 系统配置和模型管理API端点 path('models/', SystemConfigViewSet.as_view({'get': 'models'}), name='models-list'), path('model/', SystemConfigViewSet.as_view({'get': 'model', 'post': 'model'}), name='current-model'), + + # 数据导出和命令执行API端点 + path('export-feedback/', SystemConfigViewSet.as_view({'post': 'export_feedback'}), name='export-feedback'), + path('run-command/', SystemConfigViewSet.as_view({'post': 'run_command'}), name='run-command'), + + # 将原app.py中的API路径映射到相应ViewSet方法,方便前端迁移 + path('api/conversation/new', ConversationViewSet.as_view({'post': 'create'}), name='new-conversation'), + path('api/conversation//messages', ConversationViewSet.as_view({'get': 'messages'}), name='api-conversation-messages'), + path('api/conversation//message', ConversationViewSet.as_view({'post': 'message'}), name='api-send-message'), + path('api/conversation//submit', ConversationViewSet.as_view({'post': 'submit'}), name='api-submit-conversation'), + path('api/conversation//resume', ConversationViewSet.as_view({'post': 'resume'}), name='api-resume-conversation'), + path('api/conversation//evaluation', ConversationEvaluationViewSet.as_view({ + 'get': 'retrieve', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update' + }), name='api-conversation-evaluation'), + path('api/feedback', FeedbackViewSet.as_view({'post': 'create'}), name='api-create-feedback'), + path('api/feedback/detailed', DetailedFeedbackViewSet.as_view({'post': 'create'}), name='api-create-detailed-feedback'), + path('api/feedback/tags', FeedbackTagViewSet.as_view({'get': 'list'}), name='api-feedback-tags'), + path('api/annotations/dashboard', ConversationViewSet.as_view({'get': 'dashboard'}), name='api-annotations-dashboard'), + path('api/models', SystemConfigViewSet.as_view({'get': 'models'}), name='api-models-list'), + path('api/model', SystemConfigViewSet.as_view({'get': 'model', 'post': 'model'}), name='api-current-model'), ] \ No newline at end of file diff --git a/apps/rlhf/views.py b/apps/rlhf/views.py index 65de883..85cd810 100644 --- a/apps/rlhf/views.py +++ b/apps/rlhf/views.py @@ -17,7 +17,7 @@ from apps.user.models import User, UserActivityLog, AnnotationStats from django.utils import timezone import uuid import json -from django.db.models import Count, Avg, Sum, Q, F +from django.db.models import Count, Avg, Sum, Q, F, Case, When, IntegerField from datetime import datetime, timedelta from django.db import transaction from django.db.models.functions import TruncDate @@ -25,6 +25,9 @@ from apps.user.authentication import CustomTokenAuthentication from .siliconflow_client import SiliconFlowClient from django.conf import settings import logging +from django.http import HttpResponse +from io import StringIO +from django.core.management import call_command # 创建统一响应格式的基类 @@ -302,6 +305,319 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): stats.messages_count += 1 stats.save() + @action(detail=False, methods=['get']) + def dashboard(self, request): + """获取仪表盘数据,包括反馈统计、对话统计等""" + user_id = request.user.id + + try: + # 获取基础统计 + feedback_stats = self._get_feedback_stats(user_id) + + # 获取最近对话 + recent_conversations = self._get_recent_conversations(user_id, limit=5) + + # 获取对话统计 + conversation_stats = self._get_conversation_stats(user_id) + + # 获取反馈标签统计 + tag_stats = self._get_tag_usage_stats(user_id) + + # 获取反馈趋势 + trend_data = self._get_feedback_trend(user_id, days=7) + + # 获取内联反馈统计 + inline_stats = self._get_inline_feedback_stats(user_id) + + # 构建统计数据 + dashboard_data = { + 'feedback_stats': feedback_stats, + 'conversation_stats': conversation_stats, + 'recent_conversations': recent_conversations, + 'tag_stats': tag_stats, + 'trend_data': trend_data, + 'inline_stats': inline_stats + } + + return self.get_standard_response(data=dashboard_data) + + except Exception as e: + logger = logging.getLogger(__name__) + logger.exception(f"获取仪表盘数据失败: {str(e)}") + + return self.get_standard_response( + code=500, + message=f'获取仪表盘数据失败: {str(e)}', + data=None, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + + def _get_feedback_stats(self, user_id): + """获取用户反馈统计""" + # 基本反馈统计 + basic_feedback = Feedback.objects.filter(user_id=user_id).aggregate( + total=Count('id'), + positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0, output_field=IntegerField())), + negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0, output_field=IntegerField())) + ) + + # 详细反馈统计 + detailed_feedback = DetailedFeedback.objects.filter(user_id=user_id).aggregate( + total=Count('id'), + positive=Count('id', filter=Q(feedback_type='positive')), + negative=Count('id', filter=Q(feedback_type='negative')) + ) + + # 合并统计 + total = (basic_feedback['total'] or 0) + (detailed_feedback['total'] or 0) + positive = (basic_feedback['positive'] or 0) + (detailed_feedback['positive'] or 0) + negative = (basic_feedback['negative'] or 0) + (detailed_feedback['negative'] or 0) + + # 计算质量分数(0-100) + quality_score = (positive / total * 100) if total > 0 else 0 + + return { + 'total_annotations': total, + 'positive_count': positive, + 'negative_count': negative, + 'quality_score': round(quality_score, 1) + } + + def _get_recent_conversations(self, user_id, limit=5): + """获取用户最近的对话""" + conversations = Conversation.objects.filter( + user_id=user_id + ).order_by('-created_at')[:limit] + + result = [] + for conv in conversations: + # 获取最后一条消息内容作为对话摘要 + last_message = Message.objects.filter( + conversation_id=conv.id + ).order_by('-timestamp').first() + + # 统计消息数 + message_count = Message.objects.filter(conversation_id=conv.id).count() + + # 统计反馈数 + feedback_count = Feedback.objects.filter(conversation_id=conv.id).count() + detailed_count = DetailedFeedback.objects.filter(conversation_id=conv.id).count() + + result.append({ + 'id': str(conv.id), + 'created_at': conv.created_at.isoformat(), + 'is_submitted': conv.is_submitted, + 'message_count': message_count, + 'feedback_count': feedback_count + detailed_count, + 'summary': last_message.content[:100] + "..." if last_message and len(last_message.content) > 100 else (last_message.content if last_message else "") + }) + + return result + + def _get_conversation_stats(self, user_id): + """获取对话统计""" + total_conversations = Conversation.objects.filter(user_id=user_id).count() + submitted_conversations = Conversation.objects.filter(user_id=user_id, is_submitted=True).count() + + # 对话消息统计 + message_stats = Message.objects.filter( + conversation__user_id=user_id + ).aggregate( + total=Count('id'), + user_messages=Count('id', filter=Q(role='user')), + assistant_messages=Count('id', filter=Q(role='assistant')) + ) + + # 对话评估统计 + evaluation_stats = ConversationEvaluation.objects.filter( + user_id=user_id + ).aggregate( + total=Count('id'), + satisfied=Count('id', filter=Q(needs_satisfied='yes')), + partially=Count('id', filter=Q(needs_satisfied='partially')), + not_satisfied=Count('id', filter=Q(needs_satisfied='no')), + has_issues=Count('id', filter=Q(has_logical_issues='yes')) + ) + + return { + 'total': total_conversations, + 'submitted': submitted_conversations, + 'messages': { + 'total': message_stats['total'] or 0, + 'user': message_stats['user_messages'] or 0, + 'assistant': message_stats['assistant_messages'] or 0 + }, + 'evaluations': { + 'total': evaluation_stats['total'] or 0, + 'satisfied': evaluation_stats['satisfied'] or 0, + 'partially': evaluation_stats['partially'] or 0, + 'not_satisfied': evaluation_stats['not_satisfied'] or 0, + 'has_issues': evaluation_stats['has_issues'] or 0 + } + } + + def _get_tag_usage_stats(self, user_id): + """获取标签使用统计""" + result = {'positive': [], 'negative': []} + + # 分析DetailedFeedback中的标签使用情况 + for feedback in DetailedFeedback.objects.filter(user_id=user_id): + if not feedback.feedback_tags: + continue + + try: + # 尝试解析JSON标签列表 + tag_ids = json.loads(feedback.feedback_tags) + if not isinstance(tag_ids, list): + continue + + # 获取标签详情 + for tag_id in tag_ids: + tag = FeedbackTag.objects.filter(id=tag_id).first() + if not tag: + continue + + # 根据标签类型添加到对应列表 + if tag.tag_type == 'positive': + found = False + for item in result['positive']: + if item['name'] == tag.tag_name: + item['count'] += 1 + found = True + break + + if not found: + result['positive'].append({ + 'name': tag.tag_name, + 'count': 1 + }) + elif tag.tag_type == 'negative': + found = False + for item in result['negative']: + if item['name'] == tag.tag_name: + item['count'] += 1 + found = True + break + + if not found: + result['negative'].append({ + 'name': tag.tag_name, + 'count': 1 + }) + except (json.JSONDecodeError, TypeError): + continue + + # 按使用次数排序 + result['positive'].sort(key=lambda x: x['count'], reverse=True) + result['negative'].sort(key=lambda x: x['count'], reverse=True) + + # 只返回前5个 + result['positive'] = result['positive'][:5] + result['negative'] = result['negative'][:5] + + return result + + def _get_feedback_trend(self, user_id, days=7): + """获取反馈趋势数据""" + from datetime import datetime, timedelta + + # 计算开始日期 + start_date = timezone.now().date() - timedelta(days=days-1) + + # 基本反馈按日期分组 + basic_daily = Feedback.objects.filter( + user_id=user_id, + timestamp__date__gte=start_date + ).annotate( + date=TruncDate('timestamp') + ).values('date').annotate( + total=Count('id'), + positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0, output_field=IntegerField())), + negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0, output_field=IntegerField())) + ).order_by('date') + + # 详细反馈按日期分组 + detailed_daily = DetailedFeedback.objects.filter( + user_id=user_id, + created_at__date__gte=start_date + ).annotate( + date=TruncDate('created_at') + ).values('date').annotate( + total=Count('id'), + positive=Count('id', filter=Q(feedback_type='positive')), + negative=Count('id', filter=Q(feedback_type='negative')) + ).order_by('date') + + # 合并两种反馈数据 + daily_data = {} + + for item in basic_daily: + date_str = item['date'].strftime('%Y-%m-%d') + daily_data[date_str] = { + 'date': date_str, + 'total': item['total'], + 'positive': item['positive'], + 'negative': item['negative'] + } + + for item in detailed_daily: + date_str = item['date'].strftime('%Y-%m-%d') + if date_str in daily_data: + daily_data[date_str]['total'] += item['total'] + daily_data[date_str]['positive'] += item['positive'] + daily_data[date_str]['negative'] += item['negative'] + else: + daily_data[date_str] = { + 'date': date_str, + 'total': item['total'], + 'positive': item['positive'], + 'negative': item['negative'] + } + + # 构建完整的日期范围数据 + result = { + 'labels': [], + 'positive': [], + 'negative': [] + } + + current_date = start_date + end_date = timezone.now().date() + + while current_date <= end_date: + date_str = current_date.strftime('%Y-%m-%d') + display_date = current_date.strftime('%m-%d') # 显示格式:月-日 + + result['labels'].append(display_date) + + if date_str in daily_data: + result['positive'].append(daily_data[date_str]['positive']) + result['negative'].append(daily_data[date_str]['negative']) + else: + result['positive'].append(0) + result['negative'].append(0) + + current_date += timedelta(days=1) + + return result + + def _get_inline_feedback_stats(self, user_id): + """获取内联反馈统计""" + inline_stats = DetailedFeedback.objects.filter( + user_id=user_id, + is_inline=True + ).aggregate( + total=Count('id'), + positive=Count('id', filter=Q(feedback_type='positive')), + negative=Count('id', filter=Q(feedback_type='negative')) + ) + + return { + 'total': inline_stats['total'] or 0, + 'positive': inline_stats['positive'] or 0, + 'negative': inline_stats['negative'] or 0 + } + class MessageViewSet(StandardResponseMixin, viewsets.ModelViewSet): queryset = Message.objects.all() @@ -766,4 +1082,289 @@ class SystemConfigViewSet(StandardResponseMixin, viewsets.ModelViewSet): code=200, # 仍然返回200以避免前端错误 message="无法从API获取模型列表,使用预定义列表", data={'models': fallback_models} + ) + + @action(detail=False, methods=['post']) + def export_feedback(self, request): + """导出反馈数据到JSON文件""" + import json + from datetime import datetime + + # 只允许管理员导出数据 + if request.user.role != 'admin': + return self.get_standard_response( + code=403, + message='只有管理员可以导出数据', + data=None, + status_code=status.HTTP_403_FORBIDDEN + ) + + try: + # 获取导出数据 + data = { + 'conversations': [], + 'feedback_summary': self._get_feedback_summary(), + 'export_time': timezone.now().isoformat(), + 'exporter': request.user.username + } + + # 根据请求参数过滤数据 + conversation_ids = request.data.get('conversation_ids', []) + user_ids = request.data.get('user_ids', []) + date_from = request.data.get('date_from') + date_to = request.data.get('date_to') + include_messages = request.data.get('include_messages', True) + + # 构建查询条件 + query_filter = Q() + if conversation_ids: + query_filter &= Q(id__in=conversation_ids) + if user_ids: + query_filter &= Q(user_id__in=user_ids) + if date_from: + try: + date_from = timezone.datetime.fromisoformat(date_from) + query_filter &= Q(created_at__gte=date_from) + except (ValueError, TypeError): + pass + if date_to: + try: + date_to = timezone.datetime.fromisoformat(date_to) + query_filter &= Q(created_at__lte=date_to) + except (ValueError, TypeError): + pass + + # 查询对话数据 + conversations = Conversation.objects.filter(query_filter) + + # 构建导出数据 + for conv in conversations.prefetch_related('messages'): + conv_data = { + 'id': str(conv.id), + 'user_id': str(conv.user_id), + 'is_submitted': conv.is_submitted, + 'created_at': conv.created_at.isoformat() + } + + if include_messages: + conv_data['messages'] = [] + + for msg in conv.messages.all().order_by('timestamp'): + msg_data = { + 'id': str(msg.id), + 'role': msg.role, + 'content': msg.content, + 'timestamp': msg.timestamp.isoformat(), + 'feedback': [] + } + + # 获取反馈数据 + for fb in Feedback.objects.filter(message_id=msg.id): + msg_data['feedback'].append({ + 'id': str(fb.id), + 'type': 'basic', + 'user_id': str(fb.user_id), + 'feedback_value': fb.feedback_value, + 'timestamp': fb.timestamp.isoformat() + }) + + # 获取详细反馈 + for dfb in DetailedFeedback.objects.filter(message_id=msg.id): + try: + tags = json.loads(dfb.feedback_tags) if dfb.feedback_tags else [] + except (json.JSONDecodeError, TypeError): + tags = [dfb.feedback_tags] if dfb.feedback_tags else [] + + msg_data['feedback'].append({ + 'id': str(dfb.id), + 'type': 'detailed', + 'user_id': str(dfb.user_id), + 'feedback_type': dfb.feedback_type, + 'tags': tags, + 'custom_tags': dfb.custom_tags, + 'custom_content': dfb.custom_content, + 'is_inline': dfb.is_inline, + 'created_at': dfb.created_at.isoformat() + }) + + conv_data['messages'].append(msg_data) + + data['conversations'].append(conv_data) + + # 生成文件名 + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + filename = f'rlhf_feedback_export_{timestamp}.json' + + # 返回JSON响应,让浏览器下载 + response = HttpResponse( + json.dumps(data, ensure_ascii=False, indent=2), + content_type='application/json' + ) + response['Content-Disposition'] = f'attachment; filename="{filename}"' + + # 记录活动日志 + UserActivityLog.objects.create( + user=request.user, + action_type='export_feedback', + details={ + 'filename': filename, + 'conversations_count': len(data['conversations']) + } + ) + + return response + + except Exception as e: + import logging + logger = logging.getLogger(__name__) + logger.exception(f"导出反馈数据失败: {str(e)}") + + return self.get_standard_response( + code=500, + message=f'导出数据失败: {str(e)}', + data=None, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + + def _get_feedback_summary(self): + """获取反馈数据摘要统计""" + from django.db.models import Count, Avg, Sum, Case, When, IntegerField + + # 基本反馈统计 + basic_feedback = Feedback.objects.aggregate( + total=Count('id'), + positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0, output_field=IntegerField())), + negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0, output_field=IntegerField())), + avg=Avg('feedback_value') + ) + + # 详细反馈统计 + detailed_feedback = DetailedFeedback.objects.aggregate( + total=Count('id'), + positive=Count('id', filter=Q(feedback_type='positive')), + negative=Count('id', filter=Q(feedback_type='negative')), + neutral=Count('id', filter=Q(feedback_type='neutral')) + ) + + # 合并统计 + total = (basic_feedback['total'] or 0) + (detailed_feedback['total'] or 0) + positive = (basic_feedback['positive'] or 0) + (detailed_feedback['positive'] or 0) + negative = (basic_feedback['negative'] or 0) + (detailed_feedback['negative'] or 0) + + # 计算正面反馈比例 + positive_rate = (positive / total * 100) if total > 0 else 0 + + return { + 'total_feedback': total, + 'positive_feedback': positive, + 'negative_feedback': negative, + 'average_score': basic_feedback['avg'] or 0, + 'positive_rate': positive_rate, + 'detailed_feedback_count': detailed_feedback['total'] or 0 + } + + @action(detail=False, methods=['post']) + def run_command(self, request): + """运行管理命令,仅限管理员使用""" + if request.user.role != 'admin': + return self.get_standard_response( + code=403, + message='只有管理员可以运行管理命令', + data=None, + status_code=status.HTTP_403_FORBIDDEN + ) + + command = request.data.get('command') + options = request.data.get('options', {}) + + if not command: + return self.get_standard_response( + code=400, + message='命令名称不能为空', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) + + # 限制只能运行安全的命令 + allowed_commands = ['analyze_data', 'import_data', 'init_feedback_tags'] + if command not in allowed_commands: + return self.get_standard_response( + code=400, + message=f'不允许运行该命令,允许的命令: {", ".join(allowed_commands)}', + data=None, + status_code=status.HTTP_400_BAD_REQUEST + ) + + try: + from io import StringIO + from django.core.management import call_command + + # 捕获命令输出 + out = StringIO() + err = StringIO() + + # 准备命令参数 + cmd_args = [] + cmd_kwargs = {} + + # 处理options参数 + for key, value in options.items(): + if isinstance(value, bool) and value: + # 布尔参数为True时添加为flag + cmd_args.append(f'--{key}') + elif not isinstance(value, bool): + # 非布尔参数添加为key=value + cmd_kwargs[key] = value + + # 执行命令 + call_command(command, *cmd_args, stdout=out, stderr=err, **cmd_kwargs) + + # 读取命令输出 + stdout_output = out.getvalue() + stderr_output = err.getvalue() + + # 记录活动日志 + UserActivityLog.objects.create( + user=request.user, + action_type='run_command', + target_type='system', + details={ + 'command': command, + 'options': options, + 'success': True + } + ) + + return self.get_standard_response( + message=f'命令 {command} 执行成功', + data={ + 'command': command, + 'stdout': stdout_output, + 'stderr': stderr_output + } + ) + + except Exception as e: + import logging + logger = logging.getLogger(__name__) + logger.exception(f"执行命令 {command} 失败: {str(e)}") + + # 记录活动日志 + UserActivityLog.objects.create( + user=request.user, + action_type='run_command', + target_type='system', + details={ + 'command': command, + 'options': options, + 'success': False, + 'error': str(e) + } + ) + + return self.get_standard_response( + code=500, + message=f'执行命令失败: {str(e)}', + data=None, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR ) \ No newline at end of file