daren/apps/rlhf/views.py
2025-06-09 18:21:37 +08:00

1370 lines
53 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from django.shortcuts import render, get_object_or_404
from rest_framework import viewsets, status
from rest_framework.response import Response
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.pagination import PageNumberPagination
from .models import (
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
ConversationSubmission, ConversationEvaluation, SystemConfig
)
from .serializers import (
ConversationSerializer, MessageSerializer, FeedbackSerializer,
FeedbackTagSerializer, DetailedFeedbackSerializer, ConversationSubmissionSerializer,
ConversationEvaluationSerializer, SystemConfigSerializer
)
from apps.user.models import User, UserActivityLog, AnnotationStats
from django.utils import timezone
import uuid
import json
from django.db.models import Count, Avg, Sum, Q, F, Case, When, IntegerField
from datetime import datetime, timedelta
from django.db import transaction
from django.db.models.functions import TruncDate
from apps.user.authentication import CustomTokenAuthentication
from .siliconflow_client import SiliconFlowClient
from django.conf import settings
import logging
from django.http import HttpResponse
from io import StringIO
from django.core.management import call_command
# 创建统一响应格式的基类
class StandardResponseMixin:
    """Mixin that wraps DRF responses in a uniform ``{code, message, data}`` envelope.

    The ``ModelViewSet`` CRUD handlers (create/list/retrieve/update/destroy) and
    the paginated list response are overridden so that they also use the
    envelope produced by :meth:`get_standard_response`.
    """

    def get_standard_response(self, data=None, message="成功", code=200, status_code=None, headers=None):
        """Build a ``Response`` carrying the standard envelope.

        Args:
            data: Payload placed under the ``data`` key.
            message: Human-readable message placed under ``message``.
            code: Application-level code placed under ``code`` (independent of
                the HTTP status).
            status_code: HTTP status; falls back to 200 when not given.
            headers: Optional extra HTTP headers (e.g. ``Location`` after create).

        Returns:
            A DRF ``Response``.
        """
        response_data = {
            "code": code,
            "message": message,
            "data": data
        }
        return Response(response_data, status=status_code or status.HTTP_200_OK, headers=headers)

    def create(self, request, *args, **kwargs):
        """Validate and create an object; respond 201 with the envelope and success headers."""
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        # Previously computed and then dropped; forward them so clients get Location.
        headers = self.get_success_headers(serializer.data)
        return self.get_standard_response(
            data=serializer.data,
            message="创建成功",
            code=200,
            status_code=status.HTTP_201_CREATED,
            headers=headers
        )

    def list(self, request, *args, **kwargs):
        """List objects, paginated when a paginator is configured."""
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True)
        return self.get_standard_response(data=serializer.data)

    def retrieve(self, request, *args, **kwargs):
        """Return a single object inside the envelope."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return self.get_standard_response(data=serializer.data)

    def update(self, request, *args, **kwargs):
        """Full (PUT) or partial (PATCH, via ``partial=True``) update."""
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)
        return self.get_standard_response(data=serializer.data, message="更新成功")

    def destroy(self, request, *args, **kwargs):
        """Delete an object; responds 200 with ``data=None``."""
        instance = self.get_object()
        self.perform_destroy(instance)
        return self.get_standard_response(message="删除成功", data=None)

    def get_paginated_response(self, data):
        """Return the paginated page (count/next/previous/results) inside the envelope."""
        assert self.paginator is not None
        return self.get_standard_response(
            data={
                'count': self.paginator.page.paginator.count,
                'next': self.paginator.get_next_link(),
                'previous': self.paginator.get_previous_link(),
                'results': data
            }
        )
class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Conversations owned by the requesting user, plus chat, submission and dashboard actions."""

    queryset = Conversation.objects.all()
    serializer_class = ConversationSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Limit every operation to the requesting user's conversations, newest first."""
        user = self.request.user
        return Conversation.objects.filter(user=user).order_by('-created_at')

    def perform_create(self, serializer):
        """Record the requesting user as the owner of a new conversation."""
        serializer.save(user=self.request.user)

    @action(detail=True, methods=['get'])
    def messages(self, request, pk=None):
        """Return the conversation's messages in chronological order."""
        conversation = self.get_object()
        messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
        serializer = MessageSerializer(messages, many=True)
        return self.get_standard_response(data=serializer.data)

    @action(detail=True, methods=['post'])
    def message(self, request, pk=None):
        """Store a user message, generate an assistant reply, and store the reply too.

        Request body: ``content`` (required, non-empty).
        Response data: ``[user_message, assistant_message]``, both serialized.
        """
        conversation = self.get_object()
        content = request.data.get('content')
        if not content:
            return self.get_standard_response(
                code=400,
                message='消息内容不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        user_message = Message.objects.create(
            id=str(uuid.uuid4()),
            conversation=conversation,
            role='user',
            content=content
        )
        # The new user message is sent as this turn's prompt, so it is excluded
        # from the history; otherwise the model would receive it twice.
        ai_response = self._generate_ai_response(
            user_message.content,
            conversation,
            exclude_message_id=user_message.id
        )
        ai_message = Message.objects.create(
            id=str(uuid.uuid4()),
            conversation=conversation,
            role='assistant',
            content=ai_response
        )
        # Increments messages_count once per exchange (not once per stored message).
        self._update_annotation_stats(request.user.id)
        messages = [
            MessageSerializer(user_message).data,
            MessageSerializer(ai_message).data
        ]
        return self.get_standard_response(data=messages)

    @action(detail=True, methods=['post'])
    def submit(self, request, pk=None):
        """Mark the conversation as submitted and create a ``ConversationSubmission``.

        Request body: optional ``title`` and ``description``.
        Responds 400 if the conversation is already submitted.
        """
        conversation = self.get_object()
        title = request.data.get('title', '')
        description = request.data.get('description', '')
        with transaction.atomic():
            # Re-read under a row lock so two concurrent submits cannot both pass
            # the "already submitted" check and create duplicate submissions.
            conversation = Conversation.objects.select_for_update().get(pk=conversation.pk)
            if conversation.is_submitted:
                return self.get_standard_response(
                    code=400,
                    message='该对话已提交',
                    data=None,
                    status_code=status.HTTP_400_BAD_REQUEST
                )
            conversation.is_submitted = True
            conversation.save()
            submission = ConversationSubmission.objects.create(
                id=str(uuid.uuid4()),
                conversation=conversation,
                user=request.user,
                title=title,
                description=description,
                status='submitted',
                submitted_at=timezone.now()
            )
            UserActivityLog.objects.create(
                user=request.user,
                action_type='conversation_submit',
                target_type='conversation',
                target_id=str(conversation.id),
                details={'title': title}
            )
        return self.get_standard_response(
            message='对话提交成功',
            data={'submission_id': submission.id}
        )

    @action(detail=True, methods=['post'])
    def resume(self, request, pk=None):
        """Revert a submitted conversation to editable.

        A pending (``'submitted'``) latest submission is marked ``'rejected'``.
        Responds 400 if the conversation is not submitted.
        """
        conversation = self.get_object()
        with transaction.atomic():
            # Same locking rationale as ``submit``: keep the check and write consistent.
            conversation = Conversation.objects.select_for_update().get(pk=conversation.pk)
            if not conversation.is_submitted:
                return self.get_standard_response(
                    code=400,
                    message='该对话未提交,无需恢复',
                    data=None,
                    status_code=status.HTTP_400_BAD_REQUEST
                )
            conversation.is_submitted = False
            conversation.save()
            submission = ConversationSubmission.objects.filter(
                conversation=conversation
            ).order_by('-submitted_at').first()
            if submission and submission.status == 'submitted':
                submission.status = 'rejected'
                submission.save()
            UserActivityLog.objects.create(
                user=request.user,
                action_type='conversation_resume',
                target_type='conversation',
                target_id=str(conversation.id)
            )
        return self.get_standard_response(
            message='对话已恢复为未提交状态',
            data=None
        )

    def _generate_ai_response(self, user_message, conversation, exclude_message_id=None):
        """Generate an assistant reply through the SiliconFlow API.

        The model name comes from ``SystemConfig('current_model')``, falling back
        to ``settings.DEFAULT_AI_MODEL`` and then ``"Qwen/QwQ-32B"``. An optional
        system prompt comes from ``SystemConfig('system_prompt')``.

        Args:
            user_message: Text of the current user turn, passed to ``chat()``.
            conversation: Conversation whose stored messages form the history.
            exclude_message_id: Optional id of a stored message to leave out of
                the history, typically the just-saved copy of ``user_message``.

        Returns:
            The reply text. On any error, a human-readable error string instead;
            this method never raises.
        """
        logger = logging.getLogger(__name__)
        try:
            model_config = SystemConfig.objects.filter(config_key='current_model').first()
            model_name = model_config.config_value if model_config else getattr(settings, 'DEFAULT_AI_MODEL', "Qwen/QwQ-32B")
            sf_client = SiliconFlowClient(
                api_key=getattr(settings, 'SILICONFLOW_API_KEY', None),
                model=model_name
            )
            system_prompt_config = SystemConfig.objects.filter(config_key='system_prompt').first()
            if system_prompt_config and system_prompt_config.config_value:
                sf_client.set_system_message(system_prompt_config.config_value)
            history_messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
            if exclude_message_id is not None:
                # NOTE(review): assumes SiliconFlowClient.chat() appends its
                # argument as the final user turn -- confirm in siliconflow_client.
                history_messages = history_messages.exclude(id=exclude_message_id)
            for msg in history_messages:
                sf_client.add_message(msg.role, msg.content)
            logger.info(f"正在调用AI API (模型: {model_name}) 处理消息: {user_message[:50]}...")
            response = sf_client.chat(user_message)
            logger.info(f"AI回复成功回复长度: {len(response)}")
            return response
        except Exception as e:
            logger.exception(f"AI API调用失败: {str(e)}")
            return f"很抱歉AI服务暂时不可用: {str(e)}"

    def _update_annotation_stats(self, user_id):
        """Increment today's ``messages_count`` for the user, creating the row if needed."""
        today = timezone.now().date()
        stats, _created = AnnotationStats.objects.get_or_create(
            user_id=user_id,
            date=today,
            defaults={
                'id': str(uuid.uuid4()),
                'total_annotations': 0,
                'positive_annotations': 0,
                'negative_annotations': 0,
                'conversations_count': 0,
                'messages_count': 0
            }
        )
        # Increment in SQL so concurrent requests cannot overwrite each other's counts.
        AnnotationStats.objects.filter(pk=stats.pk).update(
            messages_count=F('messages_count') + 1
        )

    @action(detail=False, methods=['get'])
    def dashboard(self, request):
        """Return dashboard data for the requesting user.

        Includes feedback stats, conversation stats, recent conversations, tag
        usage, a 7-day feedback trend and inline-feedback stats. Responds 500
        with the error text if any part fails.
        """
        user_id = request.user.id
        try:
            feedback_stats = self._get_feedback_stats(user_id)
            recent_conversations = self._get_recent_conversations(user_id, limit=5)
            conversation_stats = self._get_conversation_stats(user_id)
            tag_stats = self._get_tag_usage_stats(user_id)
            trend_data = self._get_feedback_trend(user_id, days=7)
            inline_stats = self._get_inline_feedback_stats(user_id)
            dashboard_data = {
                'feedback_stats': feedback_stats,
                'conversation_stats': conversation_stats,
                'recent_conversations': recent_conversations,
                'tag_stats': tag_stats,
                'trend_data': trend_data,
                'inline_stats': inline_stats
            }
            return self.get_standard_response(data=dashboard_data)
        except Exception as e:
            logger = logging.getLogger(__name__)
            logger.exception(f"获取仪表盘数据失败: {str(e)}")
            return self.get_standard_response(
                code=500,
                message=f'获取仪表盘数据失败: {str(e)}',
                data=None,
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

    def _get_feedback_stats(self, user_id):
        """Combine basic and detailed feedback counts for the user.

        ``quality_score`` is the percentage of positive feedback (0-100, one decimal).
        """
        basic_feedback = Feedback.objects.filter(user_id=user_id).aggregate(
            total=Count('id'),
            positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0, output_field=IntegerField())),
            negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0, output_field=IntegerField()))
        )
        detailed_feedback = DetailedFeedback.objects.filter(user_id=user_id).aggregate(
            total=Count('id'),
            positive=Count('id', filter=Q(feedback_type='positive')),
            negative=Count('id', filter=Q(feedback_type='negative'))
        )
        # Sum returns None on an empty set, hence the ``or 0``.
        total = (basic_feedback['total'] or 0) + (detailed_feedback['total'] or 0)
        positive = (basic_feedback['positive'] or 0) + (detailed_feedback['positive'] or 0)
        negative = (basic_feedback['negative'] or 0) + (detailed_feedback['negative'] or 0)
        quality_score = (positive / total * 100) if total > 0 else 0
        return {
            'total_annotations': total,
            'positive_count': positive,
            'negative_count': negative,
            'quality_score': round(quality_score, 1)
        }

    def _get_recent_conversations(self, user_id, limit=5):
        """Summarize the user's ``limit`` newest conversations.

        Each entry has message/feedback counts and a summary: the last message,
        truncated to 100 characters plus "..." when longer.
        """
        conversations = Conversation.objects.filter(
            user_id=user_id
        ).order_by('-created_at')[:limit]
        result = []
        for conv in conversations:
            last_message = Message.objects.filter(
                conversation_id=conv.id
            ).order_by('-timestamp').first()
            message_count = Message.objects.filter(conversation_id=conv.id).count()
            feedback_count = Feedback.objects.filter(conversation_id=conv.id).count()
            detailed_count = DetailedFeedback.objects.filter(conversation_id=conv.id).count()
            if last_message is None:
                summary = ""
            elif len(last_message.content) > 100:
                summary = last_message.content[:100] + "..."
            else:
                summary = last_message.content
            result.append({
                'id': str(conv.id),
                'created_at': conv.created_at.isoformat(),
                'is_submitted': conv.is_submitted,
                'message_count': message_count,
                'feedback_count': feedback_count + detailed_count,
                'summary': summary
            })
        return result

    def _get_conversation_stats(self, user_id):
        """Count the user's conversations, messages by role, and evaluation outcomes."""
        total_conversations = Conversation.objects.filter(user_id=user_id).count()
        submitted_conversations = Conversation.objects.filter(user_id=user_id, is_submitted=True).count()
        message_stats = Message.objects.filter(
            conversation__user_id=user_id
        ).aggregate(
            total=Count('id'),
            user_messages=Count('id', filter=Q(role='user')),
            assistant_messages=Count('id', filter=Q(role='assistant'))
        )
        evaluation_stats = ConversationEvaluation.objects.filter(
            user_id=user_id
        ).aggregate(
            total=Count('id'),
            satisfied=Count('id', filter=Q(needs_satisfied='yes')),
            partially=Count('id', filter=Q(needs_satisfied='partially')),
            not_satisfied=Count('id', filter=Q(needs_satisfied='no')),
            has_issues=Count('id', filter=Q(has_logical_issues='yes'))
        )
        return {
            'total': total_conversations,
            'submitted': submitted_conversations,
            'messages': {
                'total': message_stats['total'] or 0,
                'user': message_stats['user_messages'] or 0,
                'assistant': message_stats['assistant_messages'] or 0
            },
            'evaluations': {
                'total': evaluation_stats['total'] or 0,
                'satisfied': evaluation_stats['satisfied'] or 0,
                'partially': evaluation_stats['partially'] or 0,
                'not_satisfied': evaluation_stats['not_satisfied'] or 0,
                'has_issues': evaluation_stats['has_issues'] or 0
            }
        }

    def _get_tag_usage_stats(self, user_id):
        """Count tag usage in the user's detailed feedback.

        ``DetailedFeedback.feedback_tags`` holds a JSON-encoded list of tag ids (a
        plain list is also accepted). Returns the five most-used positive and
        negative tags as ``{'name': ..., 'count': ...}``; ties keep first-seen order.
        """
        # Load the tag table once instead of issuing one query per tag reference.
        # Ids are compared as strings so "1" and 1 resolve to the same tag.
        tags_by_id = {str(tag.id): tag for tag in FeedbackTag.objects.all()}
        counts = {'positive': {}, 'negative': {}}
        raw_tag_lists = DetailedFeedback.objects.filter(user_id=user_id).values_list('feedback_tags', flat=True)
        for raw_tags in raw_tag_lists:
            if not raw_tags:
                continue
            if isinstance(raw_tags, list):
                tag_ids = raw_tags
            else:
                try:
                    tag_ids = json.loads(raw_tags)
                except (json.JSONDecodeError, TypeError):
                    continue
            if not isinstance(tag_ids, list):
                continue
            for tag_id in tag_ids:
                tag = tags_by_id.get(str(tag_id))
                if tag is None or tag.tag_type not in counts:
                    continue
                bucket = counts[tag.tag_type]
                bucket[tag.tag_name] = bucket.get(tag.tag_name, 0) + 1
        # sorted() is stable, so equal counts keep their first-seen order.
        return {
            tag_type: [
                {'name': name, 'count': count}
                for name, count in sorted(bucket.items(), key=lambda item: item[1], reverse=True)[:5]
            ]
            for tag_type, bucket in counts.items()
        }

    def _get_feedback_trend(self, user_id, days=7):
        """Return daily positive/negative feedback counts for the last ``days`` days.

        The result has chart-ready parallel lists: ``labels`` ('%m-%d'),
        ``positive`` and ``negative``, with days that have no feedback filled with 0.
        """
        now = timezone.now()
        # TruncDate and __date lookups bucket by the current time zone, so the
        # window must be computed in that zone too; the UTC date can lag a day.
        today = timezone.localdate(now) if timezone.is_aware(now) else now.date()
        start_date = today - timedelta(days=days - 1)
        basic_daily = Feedback.objects.filter(
            user_id=user_id,
            timestamp__date__gte=start_date
        ).annotate(
            date=TruncDate('timestamp')
        ).values('date').annotate(
            positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0, output_field=IntegerField())),
            negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0, output_field=IntegerField()))
        ).order_by('date')
        detailed_daily = DetailedFeedback.objects.filter(
            user_id=user_id,
            created_at__date__gte=start_date
        ).annotate(
            date=TruncDate('created_at')
        ).values('date').annotate(
            positive=Count('id', filter=Q(feedback_type='positive')),
            negative=Count('id', filter=Q(feedback_type='negative'))
        ).order_by('date')
        daily_data = {}
        for rows in (basic_daily, detailed_daily):
            for item in rows:
                date_str = item['date'].strftime('%Y-%m-%d')
                bucket = daily_data.setdefault(date_str, {'positive': 0, 'negative': 0})
                bucket['positive'] += item['positive']
                bucket['negative'] += item['negative']
        result = {
            'labels': [],
            'positive': [],
            'negative': []
        }
        current_date = start_date
        while current_date <= today:
            bucket = daily_data.get(current_date.strftime('%Y-%m-%d'), {'positive': 0, 'negative': 0})
            result['labels'].append(current_date.strftime('%m-%d'))
            result['positive'].append(bucket['positive'])
            result['negative'].append(bucket['negative'])
            current_date += timedelta(days=1)
        return result

    def _get_inline_feedback_stats(self, user_id):
        """Count the user's inline detailed feedback (total / positive / negative)."""
        inline_stats = DetailedFeedback.objects.filter(
            user_id=user_id,
            is_inline=True
        ).aggregate(
            total=Count('id'),
            positive=Count('id', filter=Q(feedback_type='positive')),
            negative=Count('id', filter=Q(feedback_type='negative'))
        )
        return {
            'total': inline_stats['total'] or 0,
            'positive': inline_stats['positive'] or 0,
            'negative': inline_stats['negative'] or 0
        }
class MessageViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """CRUD over messages that belong to the requesting user's conversations."""

    serializer_class = MessageSerializer
    queryset = Message.objects.all()
    permission_classes = [IsAuthenticated]
    authentication_classes = [CustomTokenAuthentication]

    def get_queryset(self):
        """Return only messages from the user's own conversations, oldest first."""
        own_messages = Message.objects.filter(conversation__user=self.request.user)
        return own_messages.order_by('timestamp')
class FeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Simple numeric feedback on individual messages (one per user and message)."""

    queryset = Feedback.objects.all()
    serializer_class = FeedbackSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return the requesting user's feedback, newest first."""
        user = self.request.user
        return Feedback.objects.filter(user=user).order_by('-timestamp')

    def create(self, request, *args, **kwargs):
        """Create the user's feedback on a message, or update the existing one.

        Request body: ``message_id``, ``conversation_id`` and a numeric
        ``feedback_value`` (> 0 counts as positive, < 0 as negative).
        Responds 400 on missing ids or a non-numeric value, and 404 if the
        message or conversation does not exist.
        """
        message_id = request.data.get('message_id')
        conversation_id = request.data.get('conversation_id')
        feedback_value = request.data.get('feedback_value')
        if not message_id or not conversation_id:
            return self.get_standard_response(
                code=400,
                message='消息ID和对话ID不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        # feedback_value is compared numerically later; a missing or non-numeric
        # value used to crash with a TypeError (HTTP 500). Form posts send strings.
        if isinstance(feedback_value, str):
            try:
                feedback_value = int(feedback_value)
            except ValueError:
                feedback_value = None
        if not isinstance(feedback_value, (int, float)):
            return self.get_standard_response(
                code=400,
                message='反馈值必须是数字',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        try:
            Message.objects.get(id=message_id)
            Conversation.objects.get(id=conversation_id)
        except (Message.DoesNotExist, Conversation.DoesNotExist):
            return self.get_standard_response(
                code=404,
                message='消息或对话不存在',
                data=None,
                status_code=status.HTTP_404_NOT_FOUND
            )
        with transaction.atomic():
            # The new id is only used on creation. Putting 'id' in
            # update_or_create's defaults reassigned the primary key of an
            # existing row, so save() inserted a duplicate instead of updating.
            feedback, created = Feedback.objects.get_or_create(
                message_id=message_id,
                conversation_id=conversation_id,
                user=request.user,
                defaults={
                    'id': str(uuid.uuid4()),
                    'feedback_value': feedback_value,
                    'timestamp': timezone.now()
                }
            )
            if not created:
                feedback.feedback_value = feedback_value
                feedback.timestamp = timezone.now()
                feedback.save(update_fields=['feedback_value', 'timestamp'])
            self._update_annotation_stats(request.user.id, feedback_value)
        return self.get_standard_response(
            message='反馈提交成功',
            data=FeedbackSerializer(feedback).data
        )

    def _update_annotation_stats(self, user_id, feedback_value):
        """Count one annotation in today's stats, classified by the sign of ``feedback_value``."""
        today = timezone.now().date()
        stats, _created = AnnotationStats.objects.get_or_create(
            user_id=user_id,
            date=today,
            defaults={
                'id': str(uuid.uuid4()),
                'total_annotations': 0,
                'positive_annotations': 0,
                'negative_annotations': 0,
                'conversations_count': 0,
                'messages_count': 0
            }
        )
        increments = {'total_annotations': F('total_annotations') + 1}
        if feedback_value > 0:
            increments['positive_annotations'] = F('positive_annotations') + 1
        elif feedback_value < 0:
            increments['negative_annotations'] = F('negative_annotations') + 1
        # Increment in SQL so concurrent requests do not lose updates.
        AnnotationStats.objects.filter(pk=stats.pk).update(**increments)
class FeedbackTagViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """CRUD over feedback tags, optionally filtered with ``?type=positive|negative``."""

    serializer_class = FeedbackTagSerializer
    queryset = FeedbackTag.objects.all()
    permission_classes = [IsAuthenticated]
    authentication_classes = [CustomTokenAuthentication]

    def get_queryset(self):
        """Return every tag, or only one type when ``type`` is a recognised value."""
        requested_type = self.request.query_params.get('type')
        if requested_type not in ('positive', 'negative'):
            return FeedbackTag.objects.all()
        return FeedbackTag.objects.filter(tag_type=requested_type)
class DetailedFeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Detailed (tagged / free-text) feedback on individual messages."""

    queryset = DetailedFeedback.objects.all()
    serializer_class = DetailedFeedbackSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return the requesting user's detailed feedback, newest first."""
        user = self.request.user
        return DetailedFeedback.objects.filter(user=user).order_by('-created_at')

    def create(self, request, *args, **kwargs):
        """Create a detailed feedback entry, log the activity and update today's stats.

        Required: ``message_id``, ``conversation_id``, ``feedback_type``.
        Optional: ``feedback_tags`` (a list is stored JSON-encoded),
        ``custom_tags``, ``custom_content`` and ``is_inline`` (default True).
        """
        message_id = request.data.get('message_id')
        conversation_id = request.data.get('conversation_id')
        feedback_type = request.data.get('feedback_type')
        feedback_tags = request.data.get('feedback_tags', [])
        custom_tags = request.data.get('custom_tags', '')
        custom_content = request.data.get('custom_content', '')
        is_inline = request.data.get('is_inline', True)
        if not message_id or not conversation_id or not feedback_type:
            return self.get_standard_response(
                code=400,
                message='消息ID、对话ID和反馈类型不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        try:
            message = Message.objects.get(id=message_id)
            conversation = Conversation.objects.get(id=conversation_id)
        except (Message.DoesNotExist, Conversation.DoesNotExist):
            return self.get_standard_response(
                code=404,
                message='消息或对话不存在',
                data=None,
                status_code=status.HTTP_404_NOT_FOUND
            )
        # Tags are stored as a JSON string; the dashboard's tag statistics parse it back.
        if isinstance(feedback_tags, list):
            feedback_tags = json.dumps(feedback_tags)
        # Feedback row, activity log and stats are written together or not at all.
        with transaction.atomic():
            detailed_feedback = DetailedFeedback.objects.create(
                id=str(uuid.uuid4()),
                message=message,
                conversation=conversation,
                user=request.user,
                feedback_type=feedback_type,
                feedback_tags=feedback_tags,
                custom_tags=custom_tags,
                custom_content=custom_content,
                is_inline=is_inline,
                created_at=timezone.now(),
                updated_at=timezone.now()
            )
            UserActivityLog.objects.create(
                user=request.user,
                action_type='detailed_feedback_submit',
                target_type='message',
                target_id=message_id,
                details={
                    'feedback_type': feedback_type,
                    'is_inline': is_inline
                }
            )
            self._update_annotation_stats(request.user.id, feedback_type)
        return self.get_standard_response(
            message='详细反馈提交成功',
            data=DetailedFeedbackSerializer(detailed_feedback).data
        )

    def _update_annotation_stats(self, user_id, feedback_type):
        """Count one annotation in today's stats, classified by ``feedback_type``."""
        today = timezone.now().date()
        stats, _created = AnnotationStats.objects.get_or_create(
            user_id=user_id,
            date=today,
            defaults={
                'id': str(uuid.uuid4()),
                'total_annotations': 0,
                'positive_annotations': 0,
                'negative_annotations': 0,
                'conversations_count': 0,
                'messages_count': 0
            }
        )
        increments = {'total_annotations': F('total_annotations') + 1}
        if feedback_type == 'positive':
            increments['positive_annotations'] = F('positive_annotations') + 1
        elif feedback_type == 'negative':
            increments['negative_annotations'] = F('negative_annotations') + 1
        # Increment in SQL so concurrent requests do not lose updates.
        AnnotationStats.objects.filter(pk=stats.pk).update(**increments)
class ConversationSubmissionViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Conversation submissions; admins see and review all, users see their own."""

    queryset = ConversationSubmission.objects.all()
    serializer_class = ConversationSubmissionSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return submissions visible to the user, optionally filtered by ``?status=``, newest first."""
        user = self.request.user
        if user.role == 'admin':
            queryset = ConversationSubmission.objects.all()
        else:
            queryset = ConversationSubmission.objects.filter(user=user)
        status_filter = self.request.query_params.get('status')
        if status_filter:
            queryset = queryset.filter(status=status_filter)
        return queryset.order_by('-submitted_at')

    @staticmethod
    def _parse_quality_score(value):
        """Return ``value`` as a number in [1, 5], or None if it is not one.

        Numeric strings (from form-encoded requests) are converted with ``float``.
        Booleans, NaN and infinities are rejected.
        """
        if isinstance(value, str):
            try:
                value = float(value)
            except ValueError:
                return None
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        # NaN and +/-inf fail this chained comparison as well.
        return value if 1 <= value <= 5 else None

    @action(detail=True, methods=['post'])
    def review(self, request, pk=None):
        """Accept or reject a submission (admin only).

        Request body: ``status`` ('accepted' or 'rejected'), optional
        ``quality_score`` (1-5) and ``reviewer_notes``.
        """
        if request.user.role != 'admin':
            return self.get_standard_response(
                code=403,
                message='只有管理员可以进行审核',
                data=None,
                status_code=status.HTTP_403_FORBIDDEN
            )
        submission = self.get_object()
        status_value = request.data.get('status')
        quality_score = request.data.get('quality_score')
        reviewer_notes = request.data.get('reviewer_notes', '')
        if not status_value or status_value not in ['accepted', 'rejected']:
            return self.get_standard_response(
                code=400,
                message='状态值无效',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        if quality_score is not None:
            # Comparing a string to ints used to raise TypeError (HTTP 500).
            quality_score = self._parse_quality_score(quality_score)
            if quality_score is None:
                return self.get_standard_response(
                    code=400,
                    message='质量分数必须在1-5之间',
                    data=None,
                    status_code=status.HTTP_400_BAD_REQUEST
                )
        with transaction.atomic():
            submission.status = status_value
            submission.quality_score = quality_score
            submission.reviewer = request.user
            submission.reviewer_notes = reviewer_notes
            submission.reviewed_at = timezone.now()
            submission.save()
            UserActivityLog.objects.create(
                user=request.user,
                action_type='submission_review',
                target_type='submission',
                target_id=submission.id,
                details={
                    'status': status_value,
                    'quality_score': quality_score
                }
            )
        return self.get_standard_response(
            message='审核完成',
            data=ConversationSubmissionSerializer(submission).data
        )
class ConversationEvaluationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Overall conversation evaluations, at most one per user and conversation."""

    serializer_class = ConversationEvaluationSerializer
    queryset = ConversationEvaluation.objects.all()
    permission_classes = [IsAuthenticated]
    authentication_classes = [CustomTokenAuthentication]

    def get_queryset(self):
        """Return only the requesting user's evaluations, newest first."""
        return ConversationEvaluation.objects.filter(
            user=self.request.user
        ).order_by('-created_at')

    def create(self, request, *args, **kwargs):
        """Create the user's evaluation of a conversation.

        Required: ``conversation_id``, ``has_logical_issues``, ``needs_satisfied``.
        Optional: ``overall_feeling``. A second evaluation of the same
        conversation by the same user is rejected with 400; the response
        includes the existing evaluation's id.
        """
        payload = request.data
        conversation_id = payload.get('conversation_id')
        overall_feeling = payload.get('overall_feeling', '')
        has_logical_issues = payload.get('has_logical_issues')
        needs_satisfied = payload.get('needs_satisfied')

        required_present = (
            conversation_id
            and has_logical_issues is not None
            and needs_satisfied is not None
        )
        if not required_present:
            return self.get_standard_response(
                code=400,
                message='对话ID、逻辑问题和需求满足度不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            Conversation.objects.get(id=conversation_id)
        except Conversation.DoesNotExist:
            return self.get_standard_response(
                code=404,
                message='对话不存在',
                data=None,
                status_code=status.HTTP_404_NOT_FOUND
            )

        previous = ConversationEvaluation.objects.filter(
            conversation_id=conversation_id,
            user=request.user
        ).first()
        if previous is not None:
            # Updates must go through PUT/PATCH on the existing evaluation.
            return self.get_standard_response(
                code=400,
                message='您已经对这个对话进行过评估请使用PUT或PATCH方法更新评估',
                data={
                    'evaluation_id': previous.id,
                    'created_at': previous.created_at
                },
                status_code=status.HTTP_400_BAD_REQUEST
            )

        evaluation = ConversationEvaluation.objects.create(
            id=str(uuid.uuid4()),
            conversation_id=conversation_id,
            user=request.user,
            overall_feeling=overall_feeling,
            has_logical_issues=has_logical_issues,
            needs_satisfied=needs_satisfied,
            created_at=timezone.now(),
            updated_at=timezone.now()
        )
        log_details = {
            'has_logical_issues': has_logical_issues,
            'needs_satisfied': needs_satisfied
        }
        UserActivityLog.objects.create(
            user=request.user,
            action_type='conversation_evaluation',
            target_type='conversation',
            target_id=conversation_id,
            details=log_details
        )
        serialized = ConversationEvaluationSerializer(evaluation).data
        return self.get_standard_response(message='评估提交成功', data=serialized)
class SystemConfigViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """System configuration plus admin utilities (model selection, export, management commands)."""

    queryset = SystemConfig.objects.all()
    serializer_class = SystemConfigSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Admins see every config entry; other users only keys starting with ``public_``."""
        if self.request.user.role == 'admin':
            return SystemConfig.objects.all()
        return SystemConfig.objects.filter(config_key__startswith='public_')

    @action(detail=False, methods=['get', 'post'])
    def model(self, request):
        """GET: return the current AI model name. POST (admin only): set it from ``model``."""
        if request.method == 'GET':
            model_config = SystemConfig.objects.filter(config_key='current_model').first()
            if model_config:
                return self.get_standard_response(
                    data={'model': model_config.config_value}
                )
            return self.get_standard_response(
                data={'model': '默认模型'}
            )
        if request.user.role != 'admin':
            return self.get_standard_response(
                code=403,
                message='只有管理员可以更改模型',
                data=None,
                status_code=status.HTTP_403_FORBIDDEN
            )
        model_name = request.data.get('model')
        if not model_name:
            return self.get_standard_response(
                code=400,
                message='模型名称不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        # The old code read `created` inside the update_or_create call that
        # assigns it (UnboundLocalError on every POST). It would also have
        # reassigned the primary key. The id is now set only on creation.
        with transaction.atomic():
            config, created = SystemConfig.objects.get_or_create(
                config_key='current_model',
                defaults={
                    'id': str(uuid.uuid4()),
                    'config_value': model_name,
                    'config_type': 'string',
                    'description': '当前使用的AI模型',
                    'updated_at': timezone.now()
                }
            )
            if not created:
                config.config_value = model_name
                config.config_type = 'string'
                config.description = '当前使用的AI模型'
                config.updated_at = timezone.now()
                config.save()
        return self.get_standard_response(
            message='模型设置成功',
            data={'model': model_name}
        )

    @action(detail=False, methods=['get'])
    def models(self, request):
        """Return the models available from SiliconFlow, or the client's fallback list on error."""
        logger = logging.getLogger(__name__)
        try:
            logger.info("正在获取可用模型列表...")
            models = SiliconFlowClient.get_available_models(
                api_key=getattr(settings, 'SILICONFLOW_API_KEY', None)
            )
            logger.info(f"成功获取 {len(models)} 个可用模型")
            return self.get_standard_response(
                data={'models': models}
            )
        except Exception as e:
            logger.exception(f"获取模型列表失败: {str(e)}")
            fallback_models = SiliconFlowClient._get_fallback_models()
            return self.get_standard_response(
                code=200,  # still 200 so the frontend keeps working with the fallback list
                message="无法从API获取模型列表使用预定义列表",
                data={'models': fallback_models}
            )

    @action(detail=False, methods=['post'])
    def export_feedback(self, request):
        """Export conversations, optionally with messages and feedback, as a JSON download (admin only).

        Optional request fields:
            conversation_ids / user_ids: restrict to these conversations / owners.
            date_from / date_to: ISO-8601 bounds on ``Conversation.created_at``;
                missing or unparseable values are ignored.
            include_messages: include messages and their feedback (default True).
        """
        from django.db.models import Prefetch

        if request.user.role != 'admin':
            return self.get_standard_response(
                code=403,
                message='只有管理员可以导出数据',
                data=None,
                status_code=status.HTTP_403_FORBIDDEN
            )
        try:
            data = {
                'conversations': [],
                'feedback_summary': self._get_feedback_summary(),
                'export_time': timezone.now().isoformat(),
                'exporter': request.user.username
            }
            conversation_ids = request.data.get('conversation_ids', [])
            user_ids = request.data.get('user_ids', [])
            date_from = self._parse_export_datetime(request.data.get('date_from'))
            date_to = self._parse_export_datetime(request.data.get('date_to'))
            include_messages = request.data.get('include_messages', True)

            query_filter = Q()
            if conversation_ids:
                query_filter &= Q(id__in=conversation_ids)
            if user_ids:
                query_filter &= Q(user_id__in=user_ids)
            if date_from is not None:
                query_filter &= Q(created_at__gte=date_from)
            if date_to is not None:
                query_filter &= Q(created_at__lte=date_to)

            conversations = Conversation.objects.filter(query_filter)
            if include_messages:
                # Prefetch messages already ordered; calling .order_by() on the
                # related manager would discard the prefetch cache and re-query.
                conversations = conversations.prefetch_related(
                    Prefetch('messages', queryset=Message.objects.order_by('timestamp'))
                )
            for conv in conversations:
                data['conversations'].append(self._export_conversation(conv, include_messages))

            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f'rlhf_feedback_export_{timestamp}.json'
            response = HttpResponse(
                json.dumps(data, ensure_ascii=False, indent=2),
                content_type='application/json'
            )
            response['Content-Disposition'] = f'attachment; filename="{filename}"'
            UserActivityLog.objects.create(
                user=request.user,
                action_type='export_feedback',
                details={
                    'filename': filename,
                    'conversations_count': len(data['conversations'])
                }
            )
            return response
        except Exception as e:
            logger = logging.getLogger(__name__)
            logger.exception(f"导出反馈数据失败: {str(e)}")
            return self.get_standard_response(
                code=500,
                message=f'导出数据失败: {str(e)}',
                data=None,
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

    @staticmethod
    def _parse_export_datetime(value):
        """Parse an ISO-8601 filter bound; return None if it is missing or unparseable.

        With time-zone support enabled, naive values are made aware in the
        current time zone so they compare cleanly with aware database values.
        """
        if not value:
            return None
        try:
            parsed = datetime.fromisoformat(value)
        except (ValueError, TypeError):
            return None
        if getattr(settings, 'USE_TZ', False) and timezone.is_naive(parsed):
            parsed = timezone.make_aware(parsed)
        return parsed

    def _export_conversation(self, conv, include_messages):
        """Serialize one conversation, plus its messages and their feedback when requested.

        Feedback is fetched with two queries per conversation (basic and
        detailed) rather than two per message.
        """
        conv_data = {
            'id': str(conv.id),
            'user_id': str(conv.user_id),
            'is_submitted': conv.is_submitted,
            'created_at': conv.created_at.isoformat()
        }
        if not include_messages:
            return conv_data
        messages = list(conv.messages.all())
        message_ids = [msg.id for msg in messages]
        basic_by_message = {}
        for fb in Feedback.objects.filter(message_id__in=message_ids):
            basic_by_message.setdefault(fb.message_id, []).append(fb)
        detailed_by_message = {}
        for dfb in DetailedFeedback.objects.filter(message_id__in=message_ids):
            detailed_by_message.setdefault(dfb.message_id, []).append(dfb)

        conv_data['messages'] = []
        for msg in messages:
            msg_data = {
                'id': str(msg.id),
                'role': msg.role,
                'content': msg.content,
                'timestamp': msg.timestamp.isoformat(),
                'feedback': []
            }
            for fb in basic_by_message.get(msg.id, []):
                msg_data['feedback'].append({
                    'id': str(fb.id),
                    'type': 'basic',
                    'user_id': str(fb.user_id),
                    'feedback_value': fb.feedback_value,
                    'timestamp': fb.timestamp.isoformat()
                })
            for dfb in detailed_by_message.get(msg.id, []):
                # Tags are normally a JSON list; fall back to wrapping the raw value.
                try:
                    tags = json.loads(dfb.feedback_tags) if dfb.feedback_tags else []
                except (json.JSONDecodeError, TypeError):
                    tags = [dfb.feedback_tags] if dfb.feedback_tags else []
                msg_data['feedback'].append({
                    'id': str(dfb.id),
                    'type': 'detailed',
                    'user_id': str(dfb.user_id),
                    'feedback_type': dfb.feedback_type,
                    'tags': tags,
                    'custom_tags': dfb.custom_tags,
                    'custom_content': dfb.custom_content,
                    'is_inline': dfb.is_inline,
                    'created_at': dfb.created_at.isoformat()
                })
            conv_data['messages'].append(msg_data)
        return conv_data

    def _get_feedback_summary(self):
        """Summarize all feedback: totals, positive/negative counts and rates.

        ``average_score`` is the mean ``feedback_value`` of basic feedback only.
        ``positive_rate`` is a percentage (0-100).
        """
        basic_feedback = Feedback.objects.aggregate(
            total=Count('id'),
            positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0, output_field=IntegerField())),
            negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0, output_field=IntegerField())),
            avg=Avg('feedback_value')
        )
        detailed_feedback = DetailedFeedback.objects.aggregate(
            total=Count('id'),
            positive=Count('id', filter=Q(feedback_type='positive')),
            negative=Count('id', filter=Q(feedback_type='negative')),
            neutral=Count('id', filter=Q(feedback_type='neutral'))
        )
        total = (basic_feedback['total'] or 0) + (detailed_feedback['total'] or 0)
        positive = (basic_feedback['positive'] or 0) + (detailed_feedback['positive'] or 0)
        negative = (basic_feedback['negative'] or 0) + (detailed_feedback['negative'] or 0)
        positive_rate = (positive / total * 100) if total > 0 else 0
        return {
            'total_feedback': total,
            'positive_feedback': positive,
            'negative_feedback': negative,
            'average_score': basic_feedback['avg'] or 0,
            'positive_rate': positive_rate,
            'detailed_feedback_count': detailed_feedback['total'] or 0
        }

    @action(detail=False, methods=['post'])
    def run_command(self, request):
        """Run a whitelisted Django management command (admin only).

        Request body: ``command`` (name) and optional ``options`` (object). In
        ``options``, ``true`` booleans become ``--flag`` arguments, ``false``
        booleans are dropped, and other values are passed as keyword options.
        The response includes the captured stdout/stderr.
        """
        if request.user.role != 'admin':
            return self.get_standard_response(
                code=403,
                message='只有管理员可以运行管理命令',
                data=None,
                status_code=status.HTTP_403_FORBIDDEN
            )
        command = request.data.get('command')
        options = request.data.get('options', {})
        if options is None:
            options = {}
        if not command:
            return self.get_standard_response(
                code=400,
                message='命令名称不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        allowed_commands = ['analyze_data', 'import_data', 'init_feedback_tags']
        if command not in allowed_commands:
            return self.get_standard_response(
                code=400,
                message=f'不允许运行该命令,允许的命令: {", ".join(allowed_commands)}',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        # A non-object `options` used to crash on .items() and surface as a 500.
        if not isinstance(options, dict):
            return self.get_standard_response(
                code=400,
                message='命令参数options必须是对象',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        try:
            out = StringIO()
            err = StringIO()
            cmd_args = []
            cmd_kwargs = {}
            for key, value in options.items():
                if isinstance(value, bool):
                    if value:
                        cmd_args.append(f'--{key}')
                else:
                    cmd_kwargs[key] = value
            call_command(command, *cmd_args, stdout=out, stderr=err, **cmd_kwargs)
            stdout_output = out.getvalue()
            stderr_output = err.getvalue()
            UserActivityLog.objects.create(
                user=request.user,
                action_type='run_command',
                target_type='system',
                details={
                    'command': command,
                    'options': options,
                    'success': True
                }
            )
            return self.get_standard_response(
                message=f'命令 {command} 执行成功',
                data={
                    'command': command,
                    'stdout': stdout_output,
                    'stderr': stderr_output
                }
            )
        except Exception as e:
            logger = logging.getLogger(__name__)
            logger.exception(f"执行命令 {command} 失败: {str(e)}")
            UserActivityLog.objects.create(
                user=request.user,
                action_type='run_command',
                target_type='system',
                details={
                    'command': command,
                    'options': options,
                    'success': False,
                    'error': str(e)
                }
            )
            return self.get_standard_response(
                code=500,
                message=f'执行命令失败: {str(e)}',
                data=None,
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )