# daren/apps/rlhf/views.py
# 769 lines · 29 KiB · Python
# (revision 2025-06-09 16:29:14 +08:00)
from django.shortcuts import render, get_object_or_404
from rest_framework import viewsets, status
from rest_framework.response import Response
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.pagination import PageNumberPagination
from .models import (
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
ConversationSubmission, ConversationEvaluation, SystemConfig
)
from .serializers import (
ConversationSerializer, MessageSerializer, FeedbackSerializer,
FeedbackTagSerializer, DetailedFeedbackSerializer, ConversationSubmissionSerializer,
ConversationEvaluationSerializer, SystemConfigSerializer
)
from apps.user.models import User, UserActivityLog, AnnotationStats
from django.utils import timezone
import uuid
import json
from django.db.models import Count, Avg, Sum, Q, F
from datetime import datetime, timedelta
from django.db import transaction
from django.db.models.functions import TruncDate
from apps.user.authentication import CustomTokenAuthentication
# (revision 2025-06-09 18:00:00 +08:00)
from .siliconflow_client import SiliconFlowClient
from django.conf import settings
import logging
# (revision 2025-06-09 16:29:14 +08:00)
# (revision 2025-06-09 16:56:16 +08:00)
# Base mixin that gives every ViewSet a unified API response format.
class StandardResponseMixin:
    """Mixin that wraps ViewSet responses in one uniform JSON envelope.

    Every body built by :meth:`get_standard_response` has the shape
    ``{"code": <code>, "message": <message>, "data": <payload>}``. ``code`` is an
    application-level value chosen by the caller and is independent of the HTTP
    status actually sent (``create`` reports ``code=200`` with HTTP 201).

    The CRUD handlers below mirror DRF's ``ModelViewSet`` handlers but route
    their results through the envelope, so this mixin has to appear before the
    ViewSet base class in a subclass's bases.
    """

    def get_standard_response(self, data=None, message="成功", code=200, status_code=None, headers=None):
        """Build a DRF ``Response`` carrying the standard envelope.

        Args:
            data: Payload stored under ``"data"``.
            message: Human-readable text stored under ``"message"``.
            code: Application-level code stored under ``"code"``.
            status_code: HTTP status to send; HTTP 200 when omitted (or falsy).
            headers: Optional extra HTTP headers to attach to the response.

        Returns:
            A ``rest_framework.response.Response``.
        """
        response_data = {
            "code": code,
            "message": message,
            "data": data,
        }
        return Response(response_data, status=status_code or status.HTTP_200_OK, headers=headers)

    def create(self, request, *args, **kwargs):
        """Validate and save a new object; respond with HTTP 201 inside the envelope."""
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        # These success headers were previously computed and then dropped; DRF's
        # own create handler sends them with its 201 response.
        headers = self.get_success_headers(serializer.data)
        return self.get_standard_response(
            data=serializer.data,
            message="创建成功",
            code=200,
            status_code=status.HTTP_201_CREATED,
            headers=headers,
        )

    def list(self, request, *args, **kwargs):
        """List (optionally paginated) objects from the filtered queryset."""
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True)
        return self.get_standard_response(data=serializer.data)

    def retrieve(self, request, *args, **kwargs):
        """Return one object inside the envelope."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return self.get_standard_response(data=serializer.data)

    def update(self, request, *args, **kwargs):
        """Full or partial (``partial=True`` kwarg) update of one object."""
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)
        return self.get_standard_response(data=serializer.data, message="更新成功")

    def destroy(self, request, *args, **kwargs):
        """Delete one object; responds with HTTP 200 and ``data=None``."""
        instance = self.get_object()
        self.perform_destroy(instance)
        return self.get_standard_response(message="删除成功", data=None)

    def get_paginated_response(self, data):
        """Wrap a page of results as ``{count, next, previous, results}`` inside the envelope."""
        assert self.paginator is not None
        return self.get_standard_response(
            data={
                'count': self.paginator.page.paginator.count,
                'next': self.paginator.get_next_link(),
                'previous': self.paginator.get_previous_link(),
                'results': data
            }
        )
class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Chat conversations belonging to the authenticated user.

    Besides the standard CRUD endpoints (scoped to the caller's own
    conversations) this exposes:

    * ``GET  messages`` - the conversation's messages in chronological order;
    * ``POST message``  - store a user message, obtain an AI reply, store it;
    * ``POST submit``   - mark the conversation submitted and create a
      ``ConversationSubmission`` record;
    * ``POST resume``   - undo a submission (the latest still-``submitted``
      submission is marked ``rejected``).
    """

    queryset = Conversation.objects.all()
    serializer_class = ConversationSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return only the caller's conversations, newest first."""
        user = self.request.user
        return Conversation.objects.filter(user=user).order_by('-created_at')

    def perform_create(self, serializer):
        """Persist a new conversation owned by the authenticated user."""
        serializer.save(user=self.request.user)

    @action(detail=True, methods=['get'])
    def messages(self, request, pk=None):
        """Return all messages of the conversation, oldest first."""
        conversation = self.get_object()
        messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
        serializer = MessageSerializer(messages, many=True)
        return self.get_standard_response(data=serializer.data)

    @action(detail=True, methods=['post'])
    def message(self, request, pk=None):
        """Append a user message and the AI's reply to the conversation.

        Request body: ``content`` (required, non-empty).
        Returns the serialized user message and assistant message, in that order.
        If the AI call fails, the assistant message holds the fallback text from
        ``_generate_ai_response`` rather than the request failing.
        """
        conversation = self.get_object()
        content = request.data.get('content')
        if not content:
            return self.get_standard_response(
                code=400,
                message='消息内容不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        # Store the user's message.
        user_message = Message.objects.create(
            id=str(uuid.uuid4()),
            conversation=conversation,
            role='user',
            content=content
        )

        # Ask the AI service for a reply. The message just stored is passed as the
        # prompt, so it is excluded from the replayed history to avoid sending it twice.
        ai_response = self._generate_ai_response(
            user_message.content,
            conversation,
            exclude_message_id=user_message.id,
        )

        # Store the AI reply.
        ai_message = Message.objects.create(
            id=str(uuid.uuid4()),
            conversation=conversation,
            role='assistant',
            content=ai_response
        )

        # Update the user's daily annotation statistics.
        self._update_annotation_stats(request.user.id)

        messages = [
            MessageSerializer(user_message).data,
            MessageSerializer(ai_message).data
        ]
        return self.get_standard_response(data=messages)

    @action(detail=True, methods=['post'])
    def submit(self, request, pk=None):
        """Mark the conversation as submitted and create a submission record.

        Request body: optional ``title`` and ``description``.
        Returns 400 if the conversation is already submitted, otherwise the new
        ``submission_id``.
        """
        conversation = self.get_object()
        title = request.data.get('title', '')
        description = request.data.get('description', '')

        # The flag change, the submission record and the activity log persist
        # together or not at all. The row lock keeps two concurrent submits from
        # both passing the is_submitted check.
        with transaction.atomic():
            conversation = Conversation.objects.select_for_update().get(pk=conversation.pk)
            if conversation.is_submitted:
                return self.get_standard_response(
                    code=400,
                    message='该对话已提交',
                    data=None,
                    status_code=status.HTTP_400_BAD_REQUEST
                )

            # Mark the conversation as submitted.
            conversation.is_submitted = True
            conversation.save()

            # Create the submission record.
            submission = ConversationSubmission.objects.create(
                id=str(uuid.uuid4()),
                conversation=conversation,
                user=request.user,
                title=title,
                description=description,
                status='submitted',
                submitted_at=timezone.now()
            )

            # Record the activity log entry.
            UserActivityLog.objects.create(
                user=request.user,
                action_type='conversation_submit',
                target_type='conversation',
                target_id=str(conversation.id),
                details={'title': title}
            )

        return self.get_standard_response(
            message='对话提交成功',
            data={'submission_id': submission.id}
        )

    @action(detail=True, methods=['post'])
    def resume(self, request, pk=None):
        """Return a submitted conversation to the unsubmitted state.

        The most recent submission (by ``submitted_at``) is marked ``rejected`` if
        it is still ``submitted``. Returns 400 if the conversation is not submitted.
        """
        conversation = self.get_object()

        with transaction.atomic():
            conversation = Conversation.objects.select_for_update().get(pk=conversation.pk)
            if not conversation.is_submitted:
                return self.get_standard_response(
                    code=400,
                    message='该对话未提交,无需恢复',
                    data=None,
                    status_code=status.HTTP_400_BAD_REQUEST
                )

            # Mark the conversation as not submitted.
            conversation.is_submitted = False
            conversation.save()

            # Fetch the latest submission record.
            submission = ConversationSubmission.objects.filter(
                conversation=conversation
            ).order_by('-submitted_at').first()

            if submission and submission.status == 'submitted':
                submission.status = 'rejected'
                submission.save()

            # Record the activity log entry.
            UserActivityLog.objects.create(
                user=request.user,
                action_type='conversation_resume',
                target_type='conversation',
                target_id=str(conversation.id)
            )

        return self.get_standard_response(
            message='对话已恢复为未提交状态',
            data=None
        )

    def _generate_ai_response(self, user_message, conversation, exclude_message_id=None):
        """Generate the assistant reply for ``user_message`` via the SiliconFlow API.

        The model name comes from the ``current_model`` SystemConfig row. It falls
        back to ``settings.DEFAULT_AI_MODEL`` (or ``"Qwen/QwQ-32B"``) when that row
        is missing or empty. A non-empty ``system_prompt`` row is installed as the
        system message, and the stored conversation history is replayed as context
        before ``chat(user_message)`` is called.

        Args:
            user_message: Text of the new user turn, sent via ``chat()``.
            conversation: Conversation whose stored messages form the context.
            exclude_message_id: Optional id of a stored message to leave out of
                the replayed history (the caller passes the user message it just
                stored, since that same text is sent through ``chat()``).
                NOTE(review): assumes ``SiliconFlowClient.chat`` appends its
                argument as a user turn itself - confirm in siliconflow_client.

        Returns:
            The reply text. On any exception the error is logged and a fallback
            string containing the error text is returned instead of raising.
        """
        logger = logging.getLogger(__name__)
        try:
            # Resolve the model currently configured by an admin.
            model_config = SystemConfig.objects.filter(config_key='current_model').first()
            if model_config and model_config.config_value:
                model_name = model_config.config_value
            else:
                model_name = getattr(settings, 'DEFAULT_AI_MODEL', "Qwen/QwQ-32B")

            sf_client = SiliconFlowClient(
                api_key=getattr(settings, 'SILICONFLOW_API_KEY', None),
                model=model_name
            )

            # Optional system prompt.
            system_prompt_config = SystemConfig.objects.filter(config_key='system_prompt').first()
            if system_prompt_config and system_prompt_config.config_value:
                sf_client.set_system_message(system_prompt_config.config_value)

            # Replay the stored history as context.
            history_messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
            if exclude_message_id is not None:
                history_messages = history_messages.exclude(id=exclude_message_id)
            for msg in history_messages:
                sf_client.add_message(msg.role, msg.content)

            logger.info(f"正在调用AI API (模型: {model_name}) 处理消息: {user_message[:50]}...")
            response = sf_client.chat(user_message)

            logger.info(f"AI回复成功回复长度: {len(response)}")
            return response

        except Exception as e:
            logger.exception(f"AI API调用失败: {str(e)}")
            return f"很抱歉AI服务暂时不可用: {str(e)}"

    def _update_annotation_stats(self, user_id):
        """Increment today's ``messages_count`` for ``user_id``, creating the day's row if needed."""
        today = timezone.now().date()

        # Get or create today's statistics row.
        stats, _ = AnnotationStats.objects.get_or_create(
            user_id=user_id,
            date=today,
            defaults={
                'id': str(uuid.uuid4()),
                'total_annotations': 0,
                'positive_annotations': 0,
                'negative_annotations': 0,
                'conversations_count': 0,
                'messages_count': 0
            }
        )

        # In-database increment. The former read-modify-write
        # (``+= 1`` then ``save()``) lost increments under concurrent requests,
        # and its full save() rewrote the other counters with stale values.
        # NOTE(review): QuerySet.update() skips save()/signals; if AnnotationStats
        # has auto_now fields they are no longer touched here - confirm.
        AnnotationStats.objects.filter(pk=stats.pk).update(
            messages_count=F('messages_count') + 1
        )
class MessageViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """CRUD endpoints for individual chat messages.

    Every query is limited to messages whose conversation belongs to the
    authenticated user, ordered oldest first.

    NOTE(review): the inherited ``create`` does not itself verify that the target
    conversation belongs to the caller; whether ``MessageSerializer`` prevents
    that is not visible here - confirm.
    """

    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]
    serializer_class = MessageSerializer
    queryset = Message.objects.all()

    def get_queryset(self):
        """Return the caller's messages sorted by ``timestamp`` ascending."""
        own_messages = Message.objects.filter(conversation__user=self.request.user)
        return own_messages.order_by('timestamp')
class FeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Numeric (``feedback_value``) feedback on individual messages.

    The list/detail endpoints only expose the caller's own feedback, newest first.
    """

    queryset = Feedback.objects.all()
    serializer_class = FeedbackSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return the caller's feedback ordered by ``timestamp`` descending."""
        user = self.request.user
        return Feedback.objects.filter(user=user).order_by('-timestamp')

    def create(self, request, *args, **kwargs):
        """Create or update the caller's feedback on one message.

        Request body: ``message_id``, ``conversation_id`` and an integer
        ``feedback_value`` (> 0 counts as positive, < 0 as negative).
        At most one feedback row per (message, conversation, user) is maintained;
        a repeated call updates its value and timestamp.

        Returns 400 for missing ids, a non-integer value, or a message that is
        not part of the given conversation; 404 if either object does not exist.
        """
        message_id = request.data.get('message_id')
        conversation_id = request.data.get('conversation_id')
        feedback_value = request.data.get('feedback_value')

        if not message_id or not conversation_id:
            return self.get_standard_response(
                code=400,
                message='消息ID和对话ID不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        # The value is compared with 0 for the statistics. Non-integer input
        # (missing, form strings that are not numbers, ...) used to raise
        # TypeError there, i.e. HTTP 500 after the row was already written.
        try:
            feedback_value = int(feedback_value)
        except (TypeError, ValueError):
            return self.get_standard_response(
                code=400,
                message='反馈值必须是整数',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            message = Message.objects.get(id=message_id)
            conversation = Conversation.objects.get(id=conversation_id)
        except (Message.DoesNotExist, Conversation.DoesNotExist):
            return self.get_standard_response(
                code=404,
                message='消息或对话不存在',
                data=None,
                status_code=status.HTTP_404_NOT_FOUND
            )

        if message.conversation_id != conversation.id:
            return self.get_standard_response(
                code=400,
                message='消息不属于该对话',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        now = timezone.now()
        with transaction.atomic():
            # Deliberately not update_or_create(defaults={'id': <new uuid>}): on the
            # update path that assigned a brand-new primary key to the existing
            # object, so save() inserted a second row instead of updating it.
            feedback = (
                Feedback.objects.select_for_update()
                .filter(message=message, conversation=conversation, user=request.user)
                .first()
            )
            if feedback is None:
                feedback = Feedback.objects.create(
                    id=str(uuid.uuid4()),
                    message=message,
                    conversation=conversation,
                    user=request.user,
                    feedback_value=feedback_value,
                    timestamp=now
                )
            else:
                feedback.feedback_value = feedback_value
                feedback.timestamp = now
                feedback.save()

            # Update the user's daily annotation statistics.
            self._update_annotation_stats(request.user.id, feedback_value)

        return self.get_standard_response(
            message='反馈提交成功',
            data=FeedbackSerializer(feedback).data
        )

    def _update_annotation_stats(self, user_id, feedback_value):
        """Count one annotation for ``user_id`` today.

        ``total_annotations`` always increases by one; ``positive_annotations`` or
        ``negative_annotations`` increases when ``feedback_value`` is > 0 or < 0.
        """
        today = timezone.now().date()

        # Get or create today's statistics row.
        stats, _ = AnnotationStats.objects.get_or_create(
            user_id=user_id,
            date=today,
            defaults={
                'id': str(uuid.uuid4()),
                'total_annotations': 0,
                'positive_annotations': 0,
                'negative_annotations': 0,
                'conversations_count': 0,
                'messages_count': 0
            }
        )

        # In-database increments, so concurrent requests cannot overwrite each
        # other's counts the way read-modify-write + save() could.
        increments = {'total_annotations': F('total_annotations') + 1}
        if feedback_value > 0:
            increments['positive_annotations'] = F('positive_annotations') + 1
        elif feedback_value < 0:
            increments['negative_annotations'] = F('negative_annotations') + 1
        AnnotationStats.objects.filter(pk=stats.pk).update(**increments)
class FeedbackTagViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Predefined feedback tags.

    ``?type=positive`` or ``?type=negative`` narrows the listing to that tag
    type; any other value, or no value, returns every tag.
    """

    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]
    serializer_class = FeedbackTagSerializer
    queryset = FeedbackTag.objects.all()

    def get_queryset(self):
        """Return all tags, or only those matching a recognised ``type`` parameter."""
        requested = self.request.query_params.get('type')
        if requested not in ('positive', 'negative'):
            return FeedbackTag.objects.all()
        return FeedbackTag.objects.filter(tag_type=requested)
class DetailedFeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Detailed (tagged / free-text) feedback on individual messages.

    The list/detail endpoints only expose the caller's own entries, newest first.
    """

    queryset = DetailedFeedback.objects.all()
    serializer_class = DetailedFeedbackSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return the caller's detailed feedback ordered by ``created_at`` descending."""
        user = self.request.user
        return DetailedFeedback.objects.filter(user=user).order_by('-created_at')

    def create(self, request, *args, **kwargs):
        """Record one detailed feedback entry (every call creates a new row).

        Request body:
            message_id, conversation_id, feedback_type: required.
            feedback_tags: optional; a list is stored as a JSON string.
            custom_tags, custom_content: optional strings (default '').
            is_inline: optional (default True).

        Returns 400 for missing required fields or a message that is not part of
        the given conversation; 404 if either object does not exist.
        """
        message_id = request.data.get('message_id')
        conversation_id = request.data.get('conversation_id')
        feedback_type = request.data.get('feedback_type')
        feedback_tags = request.data.get('feedback_tags', [])
        custom_tags = request.data.get('custom_tags', '')
        custom_content = request.data.get('custom_content', '')
        is_inline = request.data.get('is_inline', True)

        if not message_id or not conversation_id or not feedback_type:
            return self.get_standard_response(
                code=400,
                message='消息ID、对话ID和反馈类型不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            message = Message.objects.get(id=message_id)
            conversation = Conversation.objects.get(id=conversation_id)
        except (Message.DoesNotExist, Conversation.DoesNotExist):
            return self.get_standard_response(
                code=404,
                message='消息或对话不存在',
                data=None,
                status_code=status.HTTP_404_NOT_FOUND
            )

        if message.conversation_id != conversation.id:
            return self.get_standard_response(
                code=400,
                message='消息不属于该对话',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        # Store a list of tags as a JSON string.
        if isinstance(feedback_tags, list):
            feedback_tags = json.dumps(feedback_tags)

        # Read the clock once so a fresh row has created_at == updated_at.
        now = timezone.now()
        with transaction.atomic():
            detailed_feedback = DetailedFeedback.objects.create(
                id=str(uuid.uuid4()),
                message=message,
                conversation=conversation,
                user=request.user,
                feedback_type=feedback_type,
                feedback_tags=feedback_tags,
                custom_tags=custom_tags,
                custom_content=custom_content,
                is_inline=is_inline,
                created_at=now,
                updated_at=now
            )

            # Record the activity log entry.
            UserActivityLog.objects.create(
                user=request.user,
                action_type='detailed_feedback_submit',
                target_type='message',
                target_id=message_id,
                details={
                    'feedback_type': feedback_type,
                    'is_inline': is_inline
                }
            )

            # Update the user's daily annotation statistics.
            self._update_annotation_stats(request.user.id, feedback_type)

        return self.get_standard_response(
            message='详细反馈提交成功',
            data=DetailedFeedbackSerializer(detailed_feedback).data
        )

    def _update_annotation_stats(self, user_id, feedback_type):
        """Count one annotation for ``user_id`` today.

        ``total_annotations`` always increases by one; ``positive_annotations`` /
        ``negative_annotations`` increase when ``feedback_type`` is exactly
        ``'positive'`` / ``'negative'``.
        """
        today = timezone.now().date()

        # Get or create today's statistics row.
        stats, _ = AnnotationStats.objects.get_or_create(
            user_id=user_id,
            date=today,
            defaults={
                'id': str(uuid.uuid4()),
                'total_annotations': 0,
                'positive_annotations': 0,
                'negative_annotations': 0,
                'conversations_count': 0,
                'messages_count': 0
            }
        )

        # In-database increments avoid lost updates under concurrent requests.
        increments = {'total_annotations': F('total_annotations') + 1}
        if feedback_type == 'positive':
            increments['positive_annotations'] = F('positive_annotations') + 1
        elif feedback_type == 'negative':
            increments['negative_annotations'] = F('negative_annotations') + 1
        AnnotationStats.objects.filter(pk=stats.pk).update(**increments)
class ConversationSubmissionViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Conversation submissions.

    Admins (``user.role == 'admin'``) can list every submission and review them;
    other users only see their own. ``?status=`` filters by submission status.

    NOTE(review): the inherited update/partial_update/destroy handlers are not
    admin-restricted here; whether ConversationSubmissionSerializer makes fields
    such as ``status`` / ``quality_score`` read-only is not visible - confirm.
    """

    queryset = ConversationSubmission.objects.all()
    serializer_class = ConversationSubmissionSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return visible submissions (optionally status-filtered), newest first."""
        user = self.request.user

        # Admins can see every submission.
        if user.role == 'admin':
            queryset = ConversationSubmission.objects.all()
        else:
            # Regular users only see their own submissions.
            queryset = ConversationSubmission.objects.filter(user=user)

        # Optional status filter.
        status_filter = self.request.query_params.get('status')
        if status_filter:
            queryset = queryset.filter(status=status_filter)

        return queryset.order_by('-submitted_at')

    @action(detail=True, methods=['post'])
    def review(self, request, pk=None):
        """Admin review of one submission.

        Request body:
            status: required, ``'accepted'`` or ``'rejected'``.
            quality_score: optional number in [1, 5]; an empty string counts as
                not provided.
            reviewer_notes: optional string.

        Returns 403 for non-admins and 400 for an invalid status or score.
        """
        if request.user.role != 'admin':
            return self.get_standard_response(
                code=403,
                message='只有管理员可以进行审核',
                data=None,
                status_code=status.HTTP_403_FORBIDDEN
            )

        submission = self.get_object()
        status_value = request.data.get('status')
        quality_score = request.data.get('quality_score')
        reviewer_notes = request.data.get('reviewer_notes', '')

        if not status_value or status_value not in ['accepted', 'rejected']:
            return self.get_standard_response(
                code=400,
                message='状态值无效',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        # Form-encoded requests send '' for an empty field.
        if quality_score == '':
            quality_score = None
        if quality_score is not None:
            # Validate numerically: a string or non-number used to hit
            # ``quality_score < 1`` and raise TypeError (HTTP 500).
            try:
                numeric_score = float(quality_score)
            except (TypeError, ValueError):
                numeric_score = None
            # NaN fails both comparisons and is rejected too.
            if numeric_score is None or not 1 <= numeric_score <= 5:
                return self.get_standard_response(
                    code=400,
                    message='质量分数必须在1-5之间',
                    data=None,
                    status_code=status.HTTP_400_BAD_REQUEST
                )

        # The review and its activity log entry persist together or not at all.
        with transaction.atomic():
            submission.status = status_value
            submission.quality_score = quality_score
            submission.reviewer = request.user
            submission.reviewer_notes = reviewer_notes
            submission.reviewed_at = timezone.now()
            submission.save()

            UserActivityLog.objects.create(
                user=request.user,
                action_type='submission_review',
                target_type='submission',
                target_id=submission.id,
                details={
                    'status': status_value,
                    'quality_score': quality_score
                }
            )

        return self.get_standard_response(
            message='审核完成',
            data=ConversationSubmissionSerializer(submission).data
        )
class ConversationEvaluationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Whole-conversation evaluations.

    The list/detail endpoints only expose the caller's own evaluations, newest first.
    """

    queryset = ConversationEvaluation.objects.all()
    serializer_class = ConversationEvaluationSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return the caller's evaluations ordered by ``created_at`` descending."""
        user = self.request.user
        return ConversationEvaluation.objects.filter(user=user).order_by('-created_at')

    def create(self, request, *args, **kwargs):
        """Create the caller's single evaluation of a conversation.

        Request body: ``conversation_id``, ``has_logical_issues`` and
        ``needs_satisfied`` are required (the latter two only need to be present,
        any non-None value is accepted); ``overall_feeling`` is optional.

        Returns 400 if fields are missing or the caller already evaluated this
        conversation (with the existing evaluation's id and creation time), and
        404 if the conversation does not exist.
        """
        conversation_id = request.data.get('conversation_id')
        overall_feeling = request.data.get('overall_feeling', '')
        has_logical_issues = request.data.get('has_logical_issues')
        needs_satisfied = request.data.get('needs_satisfied')

        if not conversation_id or has_logical_issues is None or needs_satisfied is None:
            return self.get_standard_response(
                code=400,
                message='对话ID、逻辑问题和需求满足度不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            conversation = Conversation.objects.get(id=conversation_id)
        except Conversation.DoesNotExist:
            return self.get_standard_response(
                code=404,
                message='对话不存在',
                data=None,
                status_code=status.HTTP_404_NOT_FOUND
            )

        # Only one evaluation per (conversation, user); updates go through PUT/PATCH.
        existing_evaluation = ConversationEvaluation.objects.filter(
            conversation_id=conversation_id,
            user=request.user
        ).first()

        if existing_evaluation:
            return self.get_standard_response(
                code=400,
                message='您已经对这个对话进行过评估请使用PUT或PATCH方法更新评估',
                data={
                    'evaluation_id': existing_evaluation.id,
                    'created_at': existing_evaluation.created_at
                },
                status_code=status.HTTP_400_BAD_REQUEST
            )

        # Read the clock once so a fresh row has created_at == updated_at; the
        # evaluation and its activity log entry persist together or not at all.
        now = timezone.now()
        with transaction.atomic():
            evaluation = ConversationEvaluation.objects.create(
                id=str(uuid.uuid4()),
                conversation_id=conversation_id,
                user=request.user,
                overall_feeling=overall_feeling,
                has_logical_issues=has_logical_issues,
                needs_satisfied=needs_satisfied,
                created_at=now,
                updated_at=now
            )

            UserActivityLog.objects.create(
                user=request.user,
                action_type='conversation_evaluation',
                target_type='conversation',
                target_id=conversation_id,
                details={
                    'has_logical_issues': has_logical_issues,
                    'needs_satisfied': needs_satisfied
                }
            )

        return self.get_standard_response(
            message='评估提交成功',
            data=ConversationEvaluationSerializer(evaluation).data
        )
class SystemConfigViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """System configuration key/value rows plus model-selection endpoints.

    Admins (``user.role == 'admin'``) can read and modify every row. Other users
    can only read rows whose key starts with ``public_``. The generic
    create/update/destroy handlers are admin-only, matching the admin-only
    ``model`` POST; otherwise any user could add rows such as ``system_prompt``
    or ``current_model``.
    """

    queryset = SystemConfig.objects.all()
    serializer_class = SystemConfigSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Admins see every row; other users only ``public_*`` rows."""
        if self.request.user.role == 'admin':
            return SystemConfig.objects.all()
        return SystemConfig.objects.filter(config_key__startswith='public_')

    def _admin_only(self, request, message='只有管理员可以修改系统配置'):
        """Return a 403 envelope response for non-admins, or ``None`` for admins."""
        if request.user.role == 'admin':
            return None
        return self.get_standard_response(
            code=403,
            message=message,
            data=None,
            status_code=status.HTTP_403_FORBIDDEN
        )

    def create(self, request, *args, **kwargs):
        """Admin-only create."""
        denied = self._admin_only(request)
        if denied is not None:
            return denied
        return super().create(request, *args, **kwargs)

    def update(self, request, *args, **kwargs):
        """Admin-only update (partial updates are routed here too)."""
        denied = self._admin_only(request)
        if denied is not None:
            return denied
        return super().update(request, *args, **kwargs)

    def destroy(self, request, *args, **kwargs):
        """Admin-only delete."""
        denied = self._admin_only(request)
        if denied is not None:
            return denied
        return super().destroy(request, *args, **kwargs)

    @action(detail=False, methods=['get', 'post'])
    def model(self, request):
        """GET: report the configured model name. POST (admin only): set it.

        POST body: ``model`` (required, non-empty). The ``current_model`` row is
        updated in place, or created if it does not exist yet.
        """
        if request.method != 'POST':
            # GET - and HEAD, which DRF also routes to this handler; before,
            # anything other than GET/POST fell through and returned None.
            model_config = SystemConfig.objects.filter(config_key='current_model').first()
            if model_config:
                return self.get_standard_response(
                    data={'model': model_config.config_value}
                )
            return self.get_standard_response(
                data={'model': '默认模型'}
            )

        denied = self._admin_only(request, '只有管理员可以更改模型')
        if denied is not None:
            return denied

        model_name = request.data.get('model')
        if not model_name:
            return self.get_standard_response(
                code=400,
                message='模型名称不能为空',
                data=None,
                status_code=status.HTTP_400_BAD_REQUEST
            )

        # The old update_or_create(defaults={'id': ... if created else F('id')})
        # read ``created`` in the same statement that assigns it, so every POST
        # raised UnboundLocalError. Update-or-insert explicitly instead.
        now = timezone.now()
        with transaction.atomic():
            config = (
                SystemConfig.objects.select_for_update()
                .filter(config_key='current_model')
                .first()
            )
            if config is None:
                SystemConfig.objects.create(
                    id=str(uuid.uuid4()),
                    config_key='current_model',
                    config_value=model_name,
                    config_type='string',
                    description='当前使用的AI模型',
                    updated_at=now
                )
            else:
                config.config_value = model_name
                config.config_type = 'string'
                config.description = '当前使用的AI模型'
                config.updated_at = now
                config.save()

        return self.get_standard_response(
            message='模型设置成功',
            data={'model': model_name}
        )

    @action(detail=False, methods=['get'])
    def models(self, request):
        """Return the available model list.

        The list is fetched from SiliconFlow. On any error, the client's
        predefined fallback list is returned instead, still with code 200.
        """
        # SiliconFlowClient, settings and logging are already imported at module level.
        logger = logging.getLogger(__name__)

        try:
            logger.info("正在获取可用模型列表...")
            models = SiliconFlowClient.get_available_models(
                api_key=getattr(settings, 'SILICONFLOW_API_KEY', None)
            )
            logger.info(f"成功获取 {len(models)} 个可用模型")
            return self.get_standard_response(
                data={'models': models}
            )
        except Exception as e:
            logger.exception(f"获取模型列表失败: {str(e)}")
            # Fall back to the predefined model list.
            fallback_models = SiliconFlowClient._get_fallback_models()
            return self.get_standard_response(
                code=200,  # still 200 so the frontend does not treat this as an error
                message="无法从API获取模型列表使用预定义列表",
                data={'models': fallback_models}
            )