daren/apps/rlhf/views.py
2025-06-09 18:00:00 +08:00

769 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from django.shortcuts import render, get_object_or_404
from rest_framework import viewsets, status
from rest_framework.response import Response
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.pagination import PageNumberPagination
from .models import (
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
ConversationSubmission, ConversationEvaluation, SystemConfig
)
from .serializers import (
ConversationSerializer, MessageSerializer, FeedbackSerializer,
FeedbackTagSerializer, DetailedFeedbackSerializer, ConversationSubmissionSerializer,
ConversationEvaluationSerializer, SystemConfigSerializer
)
from apps.user.models import User, UserActivityLog, AnnotationStats
from django.utils import timezone
import uuid
import json
from django.db.models import Count, Avg, Sum, Q, F
from datetime import datetime, timedelta
from django.db import transaction
from django.db.models.functions import TruncDate
from apps.user.authentication import CustomTokenAuthentication
from .siliconflow_client import SiliconFlowClient
from django.conf import settings
import logging
# 创建统一响应格式的基类
class StandardResponseMixin:
    """Mixin that wraps ViewSet responses in a uniform JSON envelope.

    Every body has the shape ``{"code": ..., "message": ..., "data": ...}``.
    ``code`` is an application-level code carried in the body; the HTTP
    status is chosen independently via ``status_code``.
    """

    def get_standard_response(self, data=None, message="成功", code=200, status_code=None, headers=None):
        """Return a ``Response`` using the standard envelope.

        Args:
            data: Payload stored under ``data``.
            message: Human-readable message stored under ``message``.
            code: Application-level code stored under ``code``.
            status_code: HTTP status; falls back to 200 OK when not given.
            headers: Optional extra HTTP headers (e.g. ``Location`` after a create).
        """
        response_data = {
            "code": code,
            "message": message,
            "data": data
        }
        return Response(
            response_data,
            status=status_code or status.HTTP_200_OK,
            headers=headers,
        )

    def create(self, request, *args, **kwargs):
        """Validate and save a new object, answering with HTTP 201.

        The body ``code`` stays 200 while the HTTP status is 201. The success
        headers computed by DRF's ``get_success_headers`` are forwarded to the
        response instead of being discarded.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return self.get_standard_response(
            data=serializer.data,
            message="创建成功",
            code=200,
            status_code=status.HTTP_201_CREATED,
            headers=headers,
        )

    def list(self, request, *args, **kwargs):
        """List objects, paginated through ``get_paginated_response`` when a paginator is active."""
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True)
        return self.get_standard_response(data=serializer.data)

    def retrieve(self, request, *args, **kwargs):
        """Return a single serialized object inside the envelope."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return self.get_standard_response(data=serializer.data)

    def update(self, request, *args, **kwargs):
        """Full or partial update (``partial`` is passed in by ``partial_update``)."""
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)
        if getattr(instance, '_prefetched_objects_cache', None):
            # Mirror DRF's UpdateModelMixin: drop stale prefetched relations so
            # the serialized response reflects the update.
            instance._prefetched_objects_cache = {}
        return self.get_standard_response(data=serializer.data, message="更新成功")

    def destroy(self, request, *args, **kwargs):
        """Delete the object and return an envelope with ``data=None``."""
        instance = self.get_object()
        self.perform_destroy(instance)
        return self.get_standard_response(message="删除成功", data=None)

    def get_paginated_response(self, data):
        """Wrap a paginated page in the standard envelope.

        ``data`` carries ``count``, ``next``, ``previous`` and ``results``.
        """
        assert self.paginator is not None
        return self.get_standard_response(
            data={
                'count': self.paginator.page.paginator.count,
                'next': self.paginator.get_next_link(),
                'previous': self.paginator.get_previous_link(),
                'results': data
            }
        )
class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Conversation CRUD plus chat, submit and resume actions for the requesting user."""
    queryset = Conversation.objects.all()
    serializer_class = ConversationSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Only the requesting user's conversations, newest first."""
        user = self.request.user
        return Conversation.objects.filter(user=user).order_by('-created_at')

    def perform_create(self, serializer):
        """Record the requesting user as the conversation owner."""
        serializer.save(user=self.request.user)

    @action(detail=True, methods=['get'])
    def messages(self, request, pk=None):
        """Return every message of the conversation in timestamp order."""
        conversation = self.get_object()
        messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
        serializer = MessageSerializer(messages, many=True)
        return self.get_standard_response(data=serializer.data)

    @action(detail=True, methods=['post'])
    def message(self, request, pk=None):
        """Store a user message, obtain an AI reply, and return both messages.

        Body: ``content`` (required). Responds 400 when it is empty. On success,
        ``data`` is a two-item list: the serialized user message followed by
        the serialized assistant message.
        """
        conversation = self.get_object()
        content = request.data.get('content')
        if not content:
            return self.get_standard_response(
                code=400,
                message='消息内容不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        user_message = Message.objects.create(
            id=str(uuid.uuid4()),
            conversation=conversation,
            role='user',
            content=content
        )
        # The new message is sent to the model as the current turn, so it is
        # excluded from the replayed history to avoid sending it twice.
        ai_response = self._generate_ai_response(
            user_message.content,
            conversation,
            exclude_message_id=user_message.id,
        )
        ai_message = Message.objects.create(
            id=str(uuid.uuid4()),
            conversation=conversation,
            role='assistant',
            content=ai_response
        )
        self._update_annotation_stats(request.user.id)
        messages = [
            MessageSerializer(user_message).data,
            MessageSerializer(ai_message).data
        ]
        return self.get_standard_response(data=messages)

    @action(detail=True, methods=['post'])
    def submit(self, request, pk=None):
        """Mark the conversation as submitted and create a ConversationSubmission.

        Body: optional ``title`` and ``description``. Responds 400 if the
        conversation is already submitted.
        """
        conversation = self.get_object()
        title = request.data.get('title', '')
        description = request.data.get('description', '')
        with transaction.atomic():
            # Lock the row so concurrent submits cannot both pass the check
            # below, and so the flag and the submission record change together.
            conversation = Conversation.objects.select_for_update().get(pk=conversation.pk)
            if conversation.is_submitted:
                return self.get_standard_response(
                    code=400,
                    message='该对话已提交',
                    data=None,
                    status_code=status.HTTP_400_BAD_REQUEST
                )
            conversation.is_submitted = True
            conversation.save()
            submission = ConversationSubmission.objects.create(
                id=str(uuid.uuid4()),
                conversation=conversation,
                user=request.user,
                title=title,
                description=description,
                status='submitted',
                submitted_at=timezone.now()
            )
            UserActivityLog.objects.create(
                user=request.user,
                action_type='conversation_submit',
                target_type='conversation',
                target_id=str(conversation.id),
                details={'title': title}
            )
        return self.get_standard_response(
            message='对话提交成功',
            data={'submission_id': submission.id}
        )

    @action(detail=True, methods=['post'])
    def resume(self, request, pk=None):
        """Return a submitted conversation to the unsubmitted state.

        If the latest submission (by ``submitted_at``) still has status
        ``'submitted'``, it is marked ``'rejected'``. Responds 400 when the
        conversation is not submitted.
        """
        conversation = self.get_object()
        with transaction.atomic():
            conversation = Conversation.objects.select_for_update().get(pk=conversation.pk)
            if not conversation.is_submitted:
                return self.get_standard_response(
                    code=400,
                    message='该对话未提交,无需恢复',
                    data=None,
                    status_code=status.HTTP_400_BAD_REQUEST
                )
            conversation.is_submitted = False
            conversation.save()
            submission = ConversationSubmission.objects.filter(
                conversation=conversation
            ).order_by('-submitted_at').first()
            if submission and submission.status == 'submitted':
                submission.status = 'rejected'
                submission.save()
            UserActivityLog.objects.create(
                user=request.user,
                action_type='conversation_resume',
                target_type='conversation',
                target_id=str(conversation.id)
            )
        return self.get_standard_response(
            message='对话已恢复为未提交状态',
            data=None
        )

    def _generate_ai_response(self, user_message, conversation, exclude_message_id=None):
        """Ask the SiliconFlow API for a reply to ``user_message``.

        The model name comes from the ``current_model`` SystemConfig entry. When
        that entry is missing or empty, ``settings.DEFAULT_AI_MODEL`` is used,
        falling back to ``"Qwen/QwQ-32B"``. An optional ``system_prompt`` entry
        is applied as the system message. The conversation's stored messages
        are replayed as history, except ``exclude_message_id`` when given.

        Any exception is logged and turned into an apology string containing
        the exception text; that string is returned instead of raising.

        NOTE(review): assumes ``SiliconFlowClient.chat`` appends its argument as
        the latest user turn — confirm against the client implementation.
        """
        logger = logging.getLogger(__name__)
        try:
            model_config = SystemConfig.objects.filter(config_key='current_model').first()
            if model_config and model_config.config_value:
                model_name = model_config.config_value
            else:
                model_name = getattr(settings, 'DEFAULT_AI_MODEL', "Qwen/QwQ-32B")
            sf_client = SiliconFlowClient(
                api_key=getattr(settings, 'SILICONFLOW_API_KEY', None),
                model=model_name
            )
            system_prompt_config = SystemConfig.objects.filter(config_key='system_prompt').first()
            if system_prompt_config and system_prompt_config.config_value:
                sf_client.set_system_message(system_prompt_config.config_value)
            history_messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
            if exclude_message_id is not None:
                history_messages = history_messages.exclude(id=exclude_message_id)
            for msg in history_messages:
                sf_client.add_message(msg.role, msg.content)
            logger.info(f"正在调用AI API (模型: {model_name}) 处理消息: {user_message[:50]}...")
            response = sf_client.chat(user_message)
            logger.info(f"AI回复成功回复长度: {len(response)}")
            return response
        except Exception as e:
            # Deliberate best-effort: the caller stores this text as the reply.
            logger.exception(f"AI API调用失败: {str(e)}")
            return f"很抱歉AI服务暂时不可用: {str(e)}"

    def _update_annotation_stats(self, user_id):
        """Increment today's ``messages_count`` for ``user_id``, creating the row if needed."""
        today = timezone.now().date()
        stats, _ = AnnotationStats.objects.get_or_create(
            user_id=user_id,
            date=today,
            defaults={
                'id': str(uuid.uuid4()),
                'total_annotations': 0,
                'positive_annotations': 0,
                'negative_annotations': 0,
                'conversations_count': 0,
                'messages_count': 0
            }
        )
        # Increment inside the database: a read-modify-write save() of the whole
        # row can drop concurrent increments and overwrite other counters.
        # NOTE(review): QuerySet.update() bypasses save()/auto_now; confirm
        # AnnotationStats relies on neither.
        AnnotationStats.objects.filter(pk=stats.pk).update(
            messages_count=F('messages_count') + 1
        )
class MessageViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Message endpoints restricted to conversations owned by the requester.

    NOTE(review): ``get_queryset`` does not constrain ``create``; confirm that
    MessageSerializer prevents attaching messages to other users' conversations.
    """
    queryset = Message.objects.all()
    serializer_class = MessageSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Messages from the current user's conversations, oldest first."""
        owned_messages = Message.objects.filter(conversation__user=self.request.user)
        return owned_messages.order_by('timestamp')
class FeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Numeric feedback on messages, keyed by (message, conversation, user)."""
    queryset = Feedback.objects.all()
    serializer_class = FeedbackSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """The requesting user's feedback, newest first."""
        user = self.request.user
        return Feedback.objects.filter(user=user).order_by('-timestamp')

    def create(self, request, *args, **kwargs):
        """Create or update the requester's feedback on a message.

        Body: ``message_id`` and ``conversation_id`` (both required) plus a
        numeric ``feedback_value``. A repeat call for the same
        (message, conversation, user) updates the existing row's value and
        timestamp. Responds 400 for missing ids or a non-numeric value, and 404
        when the message does not exist in the given conversation.
        """
        message_id = request.data.get('message_id')
        conversation_id = request.data.get('conversation_id')
        feedback_value = request.data.get('feedback_value')
        if not message_id or not conversation_id:
            return self.get_standard_response(
                code=400,
                message='消息ID和对话ID不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        numeric_value = self._to_number(feedback_value)
        if numeric_value is None:
            # Previously a missing/non-numeric value crashed at the `> 0`
            # comparison after the feedback row had already been written.
            return self.get_standard_response(
                code=400,
                message='反馈值必须是数字',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        # The message must exist AND belong to the given conversation; an
        # existing message row implies its conversation exists.
        if not Message.objects.filter(id=message_id, conversation_id=conversation_id).exists():
            return self.get_standard_response(
                code=404,
                message='消息或对话不存在',
                data=None,
                status_code=status.HTTP_404_NOT_FOUND
            )
        now = timezone.now()
        with transaction.atomic():
            # get_or_create + explicit update instead of update_or_create with
            # 'id' in defaults: on the update path that overwrote the primary
            # key, so save() no longer matched the row and inserted a new one
            # (or failed on a uniqueness constraint).
            feedback, created = Feedback.objects.get_or_create(
                message_id=message_id,
                conversation_id=conversation_id,
                user=request.user,
                defaults={
                    'id': str(uuid.uuid4()),
                    'feedback_value': feedback_value,
                    'timestamp': now
                }
            )
            if not created:
                feedback.feedback_value = feedback_value
                feedback.timestamp = now
                feedback.save()
            # Every call counts as one annotation, including re-votes
            # (unchanged from the original behaviour).
            self._update_annotation_stats(request.user.id, numeric_value)
        return self.get_standard_response(
            message='反馈提交成功',
            data=FeedbackSerializer(feedback).data
        )

    @staticmethod
    def _to_number(value):
        """Return ``value`` as a float, or None when it is missing or not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def _update_annotation_stats(self, user_id, feedback_value):
        """Bump today's annotation counters for ``user_id``.

        ``total_annotations`` always increases by one. ``positive_annotations``
        or ``negative_annotations`` increases when ``feedback_value`` is
        greater than or less than zero, respectively.
        """
        today = timezone.now().date()
        stats, _ = AnnotationStats.objects.get_or_create(
            user_id=user_id,
            date=today,
            defaults={
                'id': str(uuid.uuid4()),
                'total_annotations': 0,
                'positive_annotations': 0,
                'negative_annotations': 0,
                'conversations_count': 0,
                'messages_count': 0
            }
        )
        increments = {'total_annotations': F('total_annotations') + 1}
        if feedback_value > 0:
            increments['positive_annotations'] = F('positive_annotations') + 1
        elif feedback_value < 0:
            increments['negative_annotations'] = F('negative_annotations') + 1
        # Atomic in-database increment; a full-row save() could lose concurrent
        # updates to this row. NOTE(review): update() bypasses save()/auto_now.
        AnnotationStats.objects.filter(pk=stats.pk).update(**increments)
class FeedbackTagViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Catalogue of feedback tags, optionally filtered with ``?type=positive|negative``."""
    queryset = FeedbackTag.objects.all()
    serializer_class = FeedbackTagSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """All tags, or only those whose ``tag_type`` matches a recognised ``type`` parameter."""
        requested_type = self.request.query_params.get('type')
        # Missing or unrecognised values fall back to the full list.
        if requested_type not in ('positive', 'negative'):
            return FeedbackTag.objects.all()
        return FeedbackTag.objects.filter(tag_type=requested_type)
class DetailedFeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Detailed (tagged / free-text) feedback on individual messages."""
    queryset = DetailedFeedback.objects.all()
    serializer_class = DetailedFeedbackSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """The requesting user's detailed feedback, newest first."""
        user = self.request.user
        return DetailedFeedback.objects.filter(user=user).order_by('-created_at')

    def create(self, request, *args, **kwargs):
        """Create a DetailedFeedback record, log the activity, and update stats.

        Body:
            message_id, conversation_id, feedback_type: required.
            feedback_tags: a list is stored as a JSON string; other values are
                stored as given.
            custom_tags, custom_content: default to ``''``.
            is_inline: defaults to ``True``.

        Responds 400 when a required field is missing, and 404 when the message
        does not exist in the given conversation.
        """
        message_id = request.data.get('message_id')
        conversation_id = request.data.get('conversation_id')
        feedback_type = request.data.get('feedback_type')
        feedback_tags = request.data.get('feedback_tags', [])
        custom_tags = request.data.get('custom_tags', '')
        custom_content = request.data.get('custom_content', '')
        is_inline = request.data.get('is_inline', True)
        if not message_id or not conversation_id or not feedback_type:
            return self.get_standard_response(
                code=400,
                message='消息ID、对话ID和反馈类型不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        try:
            # Fetching by both ids rejects messages from another conversation;
            # select_related loads the conversation in the same query.
            message = Message.objects.select_related('conversation').get(
                id=message_id, conversation_id=conversation_id
            )
        except Message.DoesNotExist:
            return self.get_standard_response(
                code=404,
                message='消息或对话不存在',
                data=None,
                status_code=status.HTTP_404_NOT_FOUND
            )
        conversation = message.conversation
        if isinstance(feedback_tags, list):
            feedback_tags = json.dumps(feedback_tags)
        now = timezone.now()
        # The feedback row, its activity log and the stats change together.
        with transaction.atomic():
            detailed_feedback = DetailedFeedback.objects.create(
                id=str(uuid.uuid4()),
                message=message,
                conversation=conversation,
                user=request.user,
                feedback_type=feedback_type,
                feedback_tags=feedback_tags,
                custom_tags=custom_tags,
                custom_content=custom_content,
                is_inline=is_inline,
                created_at=now,
                updated_at=now
            )
            UserActivityLog.objects.create(
                user=request.user,
                action_type='detailed_feedback_submit',
                target_type='message',
                target_id=message_id,
                details={
                    'feedback_type': feedback_type,
                    'is_inline': is_inline
                }
            )
            self._update_annotation_stats(request.user.id, feedback_type)
        return self.get_standard_response(
            message='详细反馈提交成功',
            data=DetailedFeedbackSerializer(detailed_feedback).data
        )

    def _update_annotation_stats(self, user_id, feedback_type):
        """Bump today's annotation counters for ``user_id``.

        ``total_annotations`` always increases. ``positive_annotations`` or
        ``negative_annotations`` increases for ``feedback_type`` equal to
        ``'positive'`` or ``'negative'``; other types touch only the total.
        """
        today = timezone.now().date()
        stats, _ = AnnotationStats.objects.get_or_create(
            user_id=user_id,
            date=today,
            defaults={
                'id': str(uuid.uuid4()),
                'total_annotations': 0,
                'positive_annotations': 0,
                'negative_annotations': 0,
                'conversations_count': 0,
                'messages_count': 0
            }
        )
        increments = {'total_annotations': F('total_annotations') + 1}
        if feedback_type == 'positive':
            increments['positive_annotations'] = F('positive_annotations') + 1
        elif feedback_type == 'negative':
            increments['negative_annotations'] = F('negative_annotations') + 1
        # Atomic in-database increment; a full-row save() could lose concurrent
        # updates to this row. NOTE(review): update() bypasses save()/auto_now.
        AnnotationStats.objects.filter(pk=stats.pk).update(**increments)
class ConversationSubmissionViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Conversation submissions plus an admin-only review action."""
    queryset = ConversationSubmission.objects.all()
    serializer_class = ConversationSubmissionSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Submissions, newest first; admins see all, others only their own.

        An optional ``?status=`` query parameter filters on ``status``.
        """
        user = self.request.user
        if user.role == 'admin':
            queryset = ConversationSubmission.objects.all()
        else:
            queryset = ConversationSubmission.objects.filter(user=user)
        status_filter = self.request.query_params.get('status')
        if status_filter:
            queryset = queryset.filter(status=status_filter)
        return queryset.order_by('-submitted_at')

    @action(detail=True, methods=['post'])
    def review(self, request, pk=None):
        """Accept or reject a submission (admins only).

        Body:
            status: ``'accepted'`` or ``'rejected'`` (required).
            quality_score: optional number between 1 and 5 inclusive.
            reviewer_notes: optional, defaults to ``''``.

        Responds 403 for non-admins, and 400 for an invalid status or score.
        """
        if request.user.role != 'admin':
            return self.get_standard_response(
                code=403,
                message='只有管理员可以进行审核',
                data=None,
                status_code=status.HTTP_403_FORBIDDEN
            )
        submission = self.get_object()
        status_value = request.data.get('status')
        quality_score = request.data.get('quality_score')
        reviewer_notes = request.data.get('reviewer_notes', '')
        if not status_value or status_value not in ['accepted', 'rejected']:
            return self.get_standard_response(
                code=400,
                message='状态值无效',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        if quality_score is not None:
            # Convert before comparing: a string/non-numeric score used to raise
            # TypeError (HTTP 500), and NaN slipped through both comparisons.
            try:
                numeric_score = float(quality_score)
            except (TypeError, ValueError):
                numeric_score = None
            if numeric_score is None or not 1 <= numeric_score <= 5:
                return self.get_standard_response(
                    code=400,
                    message='质量分数必须在1-5之间',
                    data=None,
                    status_code=status.HTTP_400_BAD_REQUEST
                )
        # The review result and its activity log are written together.
        with transaction.atomic():
            submission.status = status_value
            submission.quality_score = quality_score
            submission.reviewer = request.user
            submission.reviewer_notes = reviewer_notes
            submission.reviewed_at = timezone.now()
            submission.save()
            UserActivityLog.objects.create(
                user=request.user,
                action_type='submission_review',
                target_type='submission',
                target_id=submission.id,
                details={
                    'status': status_value,
                    'quality_score': quality_score
                }
            )
        return self.get_standard_response(
            message='审核完成',
            data=ConversationSubmissionSerializer(submission).data
        )
class ConversationEvaluationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Whole-conversation evaluations authored by the requesting user."""
    queryset = ConversationEvaluation.objects.all()
    serializer_class = ConversationEvaluationSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Evaluations written by the current user, newest first."""
        return ConversationEvaluation.objects.filter(
            user=self.request.user
        ).order_by('-created_at')

    def create(self, request, *args, **kwargs):
        """Create the requester's single evaluation of a conversation.

        Required body fields are ``conversation_id``, ``has_logical_issues`` and
        ``needs_satisfied``; ``overall_feeling`` defaults to ``''``. Responds 400
        when a required field is missing, 404 when the conversation does not
        exist, and 400 (with the existing evaluation's id and ``created_at``)
        when this user already evaluated the conversation.
        """
        payload = request.data
        conversation_id = payload.get('conversation_id')
        overall_feeling = payload.get('overall_feeling', '')
        has_logical_issues = payload.get('has_logical_issues')
        needs_satisfied = payload.get('needs_satisfied')

        required_missing = (
            not conversation_id
            or has_logical_issues is None
            or needs_satisfied is None
        )
        if required_missing:
            return self.get_standard_response(
                code=400,
                message='对话ID、逻辑问题和需求满足度不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            # Existence check only; the instance itself is not needed.
            Conversation.objects.get(id=conversation_id)
        except Conversation.DoesNotExist:
            return self.get_standard_response(
                code=404,
                message='对话不存在',
                data=None,
                status_code=status.HTTP_404_NOT_FOUND
            )

        previous = ConversationEvaluation.objects.filter(
            conversation_id=conversation_id,
            user=request.user
        ).first()
        if previous:
            # One evaluation per user and conversation; edits go through PUT/PATCH.
            return self.get_standard_response(
                code=400,
                message='您已经对这个对话进行过评估请使用PUT或PATCH方法更新评估',
                data={
                    'evaluation_id': previous.id,
                    'created_at': previous.created_at
                },
                status_code=status.HTTP_400_BAD_REQUEST
            )

        new_evaluation = ConversationEvaluation.objects.create(
            id=str(uuid.uuid4()),
            conversation_id=conversation_id,
            user=request.user,
            overall_feeling=overall_feeling,
            has_logical_issues=has_logical_issues,
            needs_satisfied=needs_satisfied,
            created_at=timezone.now(),
            updated_at=timezone.now()
        )
        log_details = {
            'has_logical_issues': has_logical_issues,
            'needs_satisfied': needs_satisfied
        }
        UserActivityLog.objects.create(
            user=request.user,
            action_type='conversation_evaluation',
            target_type='conversation',
            target_id=conversation_id,
            details=log_details
        )
        serialized = ConversationEvaluationSerializer(new_evaluation).data
        return self.get_standard_response(message='评估提交成功', data=serialized)
class SystemConfigViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """System configuration endpoints.

    Users with ``role == 'admin'`` can read and write every entry. Other
    authenticated users can only read entries whose ``config_key`` starts with
    ``public_``.
    """
    queryset = SystemConfig.objects.all()
    serializer_class = SystemConfigSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """All configs for admins, otherwise only ``public_``-prefixed keys."""
        if self.request.user.role == 'admin':
            return SystemConfig.objects.all()
        return SystemConfig.objects.filter(config_key__startswith='public_')

    def _forbid_non_admin(self, request):
        """Return a 403 envelope for non-admin users, or None for admins."""
        if request.user.role == 'admin':
            return None
        return self.get_standard_response(
            code=403,
            message='只有管理员可以修改系统配置',
            data=None,
            status_code=status.HTTP_403_FORBIDDEN
        )

    def create(self, request, *args, **kwargs):
        """Admin-only create.

        Without this check, any authenticated user could create entries such
        as ``system_prompt`` or ``current_model``, which the chat view reads.
        """
        denied = self._forbid_non_admin(request)
        if denied is not None:
            return denied
        return super().create(request, *args, **kwargs)

    def update(self, request, *args, **kwargs):
        """Admin-only full/partial update (``partial_update`` delegates here)."""
        denied = self._forbid_non_admin(request)
        if denied is not None:
            return denied
        return super().update(request, *args, **kwargs)

    def destroy(self, request, *args, **kwargs):
        """Admin-only delete."""
        denied = self._forbid_non_admin(request)
        if denied is not None:
            return denied
        return super().destroy(request, *args, **kwargs)

    @action(detail=False, methods=['get', 'post'])
    def model(self, request):
        """Read (GET/HEAD) or set (POST, admin only) the ``current_model`` config.

        GET returns ``{'model': <value>}``, or ``'默认模型'`` when unset.
        POST requires a non-empty ``model`` field.
        """
        if request.method != 'POST':
            # GET, and HEAD (which DRF routes to the GET handler); HEAD
            # previously fell through both branches and returned None.
            model_config = SystemConfig.objects.filter(config_key='current_model').first()
            if model_config:
                return self.get_standard_response(
                    data={'model': model_config.config_value}
                )
            return self.get_standard_response(
                data={'model': '默认模型'}
            )
        if request.user.role != 'admin':
            return self.get_standard_response(
                code=403,
                message='只有管理员可以更改模型',
                data=None,
                status_code=status.HTTP_403_FORBIDDEN
            )
        model_name = request.data.get('model')
        if not model_name:
            return self.get_standard_response(
                code=400,
                message='模型名称不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )
        # The original update_or_create read `created` before it was assigned
        # (UnboundLocalError on every POST) and would have rewritten the primary
        # key of an existing row. A fresh id is now generated only on creation.
        now = timezone.now()
        config, created = SystemConfig.objects.get_or_create(
            config_key='current_model',
            defaults={
                'id': str(uuid.uuid4()),
                'config_value': model_name,
                'config_type': 'string',
                'description': '当前使用的AI模型',
                'updated_at': now
            }
        )
        if not created:
            config.config_value = model_name
            config.config_type = 'string'
            config.description = '当前使用的AI模型'
            config.updated_at = now
            config.save()
        return self.get_standard_response(
            message='模型设置成功',
            data={'model': model_name}
        )

    @action(detail=False, methods=['get'])
    def models(self, request):
        """Return the model list from SiliconFlow, or a predefined fallback list.

        Any failure while fetching is logged and answered with the fallback
        list, still using code 200 so the frontend keeps working.
        """
        logger = logging.getLogger(__name__)
        try:
            logger.info("正在获取可用模型列表...")
            models = SiliconFlowClient.get_available_models(
                api_key=getattr(settings, 'SILICONFLOW_API_KEY', None)
            )
            logger.info(f"成功获取 {len(models)} 个可用模型")
            return self.get_standard_response(
                data={'models': models}
            )
        except Exception as e:
            logger.exception(f"获取模型列表失败: {str(e)}")
            fallback_models = SiliconFlowClient._get_fallback_models()
            return self.get_standard_response(
                code=200,  # still 200 so the frontend does not treat it as an error
                message="无法从API获取模型列表使用预定义列表",
                data={'models': fallback_models}
            )