daren/apps/rlhf/views.py
2025-06-09 16:29:14 +08:00

559 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from django.shortcuts import render, get_object_or_404
from rest_framework import viewsets, status
from rest_framework.response import Response
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.pagination import PageNumberPagination
from .models import (
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
ConversationSubmission, ConversationEvaluation, SystemConfig
)
from .serializers import (
ConversationSerializer, MessageSerializer, FeedbackSerializer,
FeedbackTagSerializer, DetailedFeedbackSerializer, ConversationSubmissionSerializer,
ConversationEvaluationSerializer, SystemConfigSerializer
)
from apps.user.models import User, UserActivityLog, AnnotationStats
from django.utils import timezone
import uuid
import json
from django.db.models import Count, Avg, Sum, Q, F
from datetime import datetime, timedelta
from django.db import transaction
from django.db.models.functions import TruncDate
from apps.user.authentication import CustomTokenAuthentication
class ConversationViewSet(viewsets.ModelViewSet):
    """Chat conversations of the authenticated user, with chat/submit actions.

    ``get_queryset`` restricts everything to ``request.user``, so the
    ``get_object()`` calls in the custom actions only resolve the caller's own
    conversations.
    """

    queryset = Conversation.objects.all()
    serializer_class = ConversationSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return the caller's conversations, newest first."""
        user = self.request.user
        return Conversation.objects.filter(user=user).order_by('-created_at')

    def perform_create(self, serializer):
        """Save a new conversation owned by the authenticated user."""
        serializer.save(user=self.request.user)

    @action(detail=True, methods=['get'])
    def messages(self, request, pk=None):
        """Return every message of the conversation, oldest first."""
        conversation = self.get_object()
        messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
        serializer = MessageSerializer(messages, many=True)
        return Response(serializer.data)

    @action(detail=True, methods=['post'])
    def message(self, request, pk=None):
        """Store a user message, generate an assistant reply, and return both.

        Request body: ``content`` (required, non-empty).

        Returns a two-element list: the serialized user message followed by the
        serialized assistant message. Responds 400 when ``content`` is empty.
        """
        conversation = self.get_object()
        content = request.data.get('content')
        if not content:
            return Response({'error': '消息内容不能为空'}, status=status.HTTP_400_BAD_REQUEST)

        # Save the user's message first so it is part of the history that the
        # reply generator reads.
        user_message = Message.objects.create(
            id=str(uuid.uuid4()),
            conversation=conversation,
            role='user',
            content=content
        )

        # Reply generation is currently a placeholder (see _generate_ai_response).
        ai_response = self._generate_ai_response(user_message.content, conversation)

        ai_message = Message.objects.create(
            id=str(uuid.uuid4()),
            conversation=conversation,
            role='assistant',
            content=ai_response
        )

        # Count this exchange in the user's daily statistics.
        self._update_annotation_stats(request.user.id)

        messages = [
            MessageSerializer(user_message).data,
            MessageSerializer(ai_message).data
        ]
        return Response(messages)

    @action(detail=True, methods=['post'])
    def submit(self, request, pk=None):
        """Mark the conversation as submitted and record a submission.

        Request body: optional ``title`` and ``description`` (default '').

        Creates a ConversationSubmission with status ``'submitted'`` and a
        ``conversation_submit`` activity log entry. Responds 400 if the
        conversation is already submitted.
        """
        conversation = self.get_object()
        title = request.data.get('title', '')
        description = request.data.get('description', '')

        with transaction.atomic():
            # Re-read the row under a lock so two concurrent submits cannot both
            # pass the is_submitted check. The atomic block also commits the flag,
            # the submission record and the log entry together, or none of them.
            conversation = Conversation.objects.select_for_update().get(pk=conversation.pk)
            if conversation.is_submitted:
                return Response({'error': '该对话已提交'}, status=status.HTTP_400_BAD_REQUEST)

            conversation.is_submitted = True
            conversation.save()

            submission = ConversationSubmission.objects.create(
                id=str(uuid.uuid4()),
                conversation=conversation,
                user=request.user,
                title=title,
                description=description,
                status='submitted',
                submitted_at=timezone.now()
            )

            UserActivityLog.objects.create(
                user=request.user,
                action_type='conversation_submit',
                target_type='conversation',
                target_id=str(conversation.id),
                details={'title': title}
            )

        return Response({
            'message': '对话提交成功',
            'submission_id': submission.id
        })

    @action(detail=True, methods=['post'])
    def resume(self, request, pk=None):
        """Revert a submitted conversation to the unsubmitted state.

        If the most recent submission (by ``submitted_at``) still has status
        ``'submitted'``, it is changed to ``'rejected'``. A
        ``conversation_resume`` activity log entry is written. Responds 400 if
        the conversation is not currently submitted.
        """
        conversation = self.get_object()

        with transaction.atomic():
            # Lock the row so the flag check and the updates below cannot
            # interleave with a concurrent submit/resume.
            conversation = Conversation.objects.select_for_update().get(pk=conversation.pk)
            if not conversation.is_submitted:
                return Response({'error': '该对话未提交,无需恢复'}, status=status.HTTP_400_BAD_REQUEST)

            conversation.is_submitted = False
            conversation.save()

            submission = ConversationSubmission.objects.filter(
                conversation=conversation
            ).order_by('-submitted_at').first()
            if submission and submission.status == 'submitted':
                submission.status = 'rejected'
                submission.save()

            UserActivityLog.objects.create(
                user=request.user,
                action_type='conversation_resume',
                target_type='conversation',
                target_id=str(conversation.id)
            )

        return Response({'message': '对话已恢复为未提交状态'})

    def _generate_ai_response(self, user_message, conversation):
        """Generate the assistant reply for ``user_message``.

        Placeholder implementation: no AI service is called yet. It returns a
        fixed template containing the configured model name and the user's
        message. The model name comes from the SystemConfig key
        ``current_model`` and falls back to "默认模型" when that key is missing.

        Args:
            user_message: text of the message being answered.
            conversation: the Conversation whose history would be sent as context.
        """
        model_config = SystemConfig.objects.filter(config_key='current_model').first()
        model_name = model_config.config_value if model_config else "默认模型"

        # Chat history in role/content form. It is not used yet; it is what a
        # real chat API call would receive as context, e.g.
        #     response = sf_client.chat(user_message, history)
        history = [
            {"role": msg.role, "content": msg.content}
            for msg in Message.objects.filter(conversation=conversation).order_by('timestamp')
        ]

        return f"这是AI({model_name})的回复:我已收到您的消息「{user_message}」。根据您的问题,我的建议是..."

    def _update_annotation_stats(self, user_id):
        """Increment today's ``messages_count`` for ``user_id``.

        Creates today's AnnotationStats row, with all counters at 0, if it does
        not exist yet. The row is selected FOR UPDATE inside a transaction. On
        backends that support row locks, this keeps concurrent requests from
        losing increments.
        """
        # NOTE(review): timezone.now() is UTC when USE_TZ=True, so "today" is the
        # UTC date; use timezone.localdate() if the local day is intended.
        today = timezone.now().date()
        with transaction.atomic():
            stats, _created = AnnotationStats.objects.select_for_update().get_or_create(
                user_id=user_id,
                date=today,
                defaults={
                    'id': str(uuid.uuid4()),
                    'total_annotations': 0,
                    'positive_annotations': 0,
                    'negative_annotations': 0,
                    'conversations_count': 0,
                    'messages_count': 0
                }
            )
            stats.messages_count += 1
            stats.save()
class MessageViewSet(viewsets.ModelViewSet):
    """CRUD endpoint for messages, scoped to conversations the caller owns."""

    queryset = Message.objects.all()
    serializer_class = MessageSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return messages from the caller's own conversations, oldest first."""
        owned_messages = Message.objects.filter(conversation__user=self.request.user)
        return owned_messages.order_by('timestamp')
class FeedbackViewSet(viewsets.ModelViewSet):
    """Per-message numeric feedback (``feedback_value``) left by users."""

    queryset = Feedback.objects.all()
    serializer_class = FeedbackSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return the caller's feedback, newest first."""
        user = self.request.user
        return Feedback.objects.filter(user=user).order_by('-timestamp')

    def create(self, request, *args, **kwargs):
        """Create or update the caller's feedback on one message.

        Request body:
            message_id, conversation_id: required. The message must belong to
                the conversation.
            feedback_value: required integer. Values > 0 count as positive and
                values < 0 as negative in the daily statistics.

        Returns the serialized feedback. Responds 400 on missing or invalid
        input and 404 if the message or conversation does not exist.
        """
        message_id = request.data.get('message_id')
        conversation_id = request.data.get('conversation_id')
        feedback_value = request.data.get('feedback_value')

        if not message_id or not conversation_id:
            return Response({'error': '消息ID和对话ID不能为空'}, status=status.HTTP_400_BAD_REQUEST)

        # The statistics compare feedback_value with 0. Reject missing or
        # non-integer values here instead of failing with a TypeError (HTTP 500).
        try:
            feedback_value = int(feedback_value)
        except (TypeError, ValueError):
            return Response({'error': '反馈值无效'}, status=status.HTTP_400_BAD_REQUEST)

        try:
            conversation = Conversation.objects.get(id=conversation_id)
            # Only accept a message that belongs to the given conversation, so
            # the stored (message, conversation) pair is consistent.
            message = Message.objects.get(id=message_id, conversation=conversation)
        except (Message.DoesNotExist, Conversation.DoesNotExist):
            return Response({'error': '消息或对话不存在'}, status=status.HTTP_404_NOT_FOUND)

        # NOTE(review): ownership of the conversation is not checked here; any
        # authenticated user who knows the ids can leave feedback. Confirm intended.
        fields = {
            'feedback_value': feedback_value,
            'timestamp': timezone.now(),
        }
        with transaction.atomic():
            # get_or_create applies `defaults` only when inserting. A new row
            # therefore gets a fresh id, and an existing row keeps its id and is
            # updated in place below.
            feedback, created = Feedback.objects.select_for_update().get_or_create(
                message=message,
                conversation=conversation,
                user=request.user,
                defaults={'id': str(uuid.uuid4()), **fields},
            )
            if not created:
                for name, value in fields.items():
                    setattr(feedback, name, value)
                feedback.save()

            self._update_annotation_stats(request.user.id, feedback_value)

        return Response(FeedbackSerializer(feedback).data)

    def _update_annotation_stats(self, user_id, feedback_value):
        """Count one annotation in today's AnnotationStats for ``user_id``.

        Increments ``total_annotations``. Also increments
        ``positive_annotations`` when ``feedback_value`` > 0, or
        ``negative_annotations`` when it is < 0. Creates today's row, with all
        counters at 0, if it is missing. The row is selected FOR UPDATE inside a
        transaction so concurrent requests do not lose increments (on backends
        with row locks).
        """
        # NOTE(review): this runs for updates of existing feedback too, so
        # re-rating a message counts as another annotation. Confirm intended.
        today = timezone.now().date()
        with transaction.atomic():
            stats, _created = AnnotationStats.objects.select_for_update().get_or_create(
                user_id=user_id,
                date=today,
                defaults={
                    'id': str(uuid.uuid4()),
                    'total_annotations': 0,
                    'positive_annotations': 0,
                    'negative_annotations': 0,
                    'conversations_count': 0,
                    'messages_count': 0
                }
            )
            stats.total_annotations += 1
            if feedback_value > 0:
                stats.positive_annotations += 1
            elif feedback_value < 0:
                stats.negative_annotations += 1
            stats.save()
class FeedbackTagViewSet(viewsets.ModelViewSet):
    """Feedback tags, optionally narrowed with ``?type=positive`` or ``?type=negative``.

    Any other (or missing) ``type`` value returns every tag.
    """

    queryset = FeedbackTag.objects.all()
    serializer_class = FeedbackTagSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return all tags, or only those whose tag_type matches ``?type=``."""
        requested_type = self.request.query_params.get('type')
        if requested_type not in ('positive', 'negative'):
            return FeedbackTag.objects.all()
        return FeedbackTag.objects.filter(tag_type=requested_type)
class DetailedFeedbackViewSet(viewsets.ModelViewSet):
    """Detailed feedback on messages (type, tags, free text) left by users."""

    queryset = DetailedFeedback.objects.all()
    serializer_class = DetailedFeedbackSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return the caller's detailed feedback, newest first."""
        user = self.request.user
        return DetailedFeedback.objects.filter(user=user).order_by('-created_at')

    def create(self, request, *args, **kwargs):
        """Record a new detailed feedback on one message.

        Request body:
            message_id, conversation_id, feedback_type: required. The message
                must belong to the conversation.
            feedback_tags: optional, default []. A list is JSON-encoded before
                storage; any other value is stored as received.
            custom_tags, custom_content: optional, default ''.
            is_inline: optional, default True.

        Always inserts a new row (no upsert). It also writes a
        ``detailed_feedback_submit`` activity log entry and updates the caller's
        daily statistics, all in one transaction. Returns the serialized
        feedback. Responds 400 on missing fields and 404 if the message or
        conversation does not exist.
        """
        message_id = request.data.get('message_id')
        conversation_id = request.data.get('conversation_id')
        feedback_type = request.data.get('feedback_type')
        feedback_tags = request.data.get('feedback_tags', [])
        custom_tags = request.data.get('custom_tags', '')
        custom_content = request.data.get('custom_content', '')
        is_inline = request.data.get('is_inline', True)

        if not message_id or not conversation_id or not feedback_type:
            return Response({
                'error': '消息ID、对话ID和反馈类型不能为空'
            }, status=status.HTTP_400_BAD_REQUEST)

        try:
            conversation = Conversation.objects.get(id=conversation_id)
            # Only accept a message that belongs to the given conversation, so
            # the stored (message, conversation) pair is consistent.
            message = Message.objects.get(id=message_id, conversation=conversation)
        except (Message.DoesNotExist, Conversation.DoesNotExist):
            return Response({'error': '消息或对话不存在'}, status=status.HTTP_404_NOT_FOUND)

        # Tag lists are stored as a JSON string.
        if isinstance(feedback_tags, list):
            feedback_tags = json.dumps(feedback_tags)

        now = timezone.now()
        with transaction.atomic():
            detailed_feedback = DetailedFeedback.objects.create(
                id=str(uuid.uuid4()),
                message=message,
                conversation=conversation,
                user=request.user,
                feedback_type=feedback_type,
                feedback_tags=feedback_tags,
                custom_tags=custom_tags,
                custom_content=custom_content,
                is_inline=is_inline,
                created_at=now,
                updated_at=now
            )

            UserActivityLog.objects.create(
                user=request.user,
                action_type='detailed_feedback_submit',
                target_type='message',
                target_id=message_id,
                details={
                    'feedback_type': feedback_type,
                    'is_inline': is_inline
                }
            )

            self._update_annotation_stats(request.user.id, feedback_type)

        return Response(DetailedFeedbackSerializer(detailed_feedback).data)

    def _update_annotation_stats(self, user_id, feedback_type):
        """Count one annotation in today's AnnotationStats for ``user_id``.

        Increments ``total_annotations``. Also increments
        ``positive_annotations`` for ``feedback_type == 'positive'`` or
        ``negative_annotations`` for ``'negative'``; any other type only counts
        toward the total. Creates today's row, with all counters at 0, if it is
        missing. The row is selected FOR UPDATE inside a transaction so
        concurrent requests do not lose increments (on backends with row locks).
        """
        today = timezone.now().date()
        with transaction.atomic():
            stats, _created = AnnotationStats.objects.select_for_update().get_or_create(
                user_id=user_id,
                date=today,
                defaults={
                    'id': str(uuid.uuid4()),
                    'total_annotations': 0,
                    'positive_annotations': 0,
                    'negative_annotations': 0,
                    'conversations_count': 0,
                    'messages_count': 0
                }
            )
            stats.total_annotations += 1
            if feedback_type == 'positive':
                stats.positive_annotations += 1
            elif feedback_type == 'negative':
                stats.negative_annotations += 1
            stats.save()
class ConversationSubmissionViewSet(viewsets.ModelViewSet):
    """Conversation submissions. Admins see all of them and can review them."""

    queryset = ConversationSubmission.objects.all()
    serializer_class = ConversationSubmissionSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return submissions visible to the caller, newest first.

        Admins (``role == 'admin'``) see every submission; other users see
        only their own. The optional ``?status=`` query parameter filters by
        submission status.
        """
        user = self.request.user
        if user.role == 'admin':
            queryset = ConversationSubmission.objects.all()
        else:
            queryset = ConversationSubmission.objects.filter(user=user)

        status_filter = self.request.query_params.get('status')
        if status_filter:
            queryset = queryset.filter(status=status_filter)

        return queryset.order_by('-submitted_at')

    @action(detail=True, methods=['post'])
    def review(self, request, pk=None):
        """Accept or reject a submission (admin only).

        Request body:
            status: ``'accepted'`` or ``'rejected'`` (required).
            quality_score: optional number in [1, 5].
            reviewer_notes: optional string, default ''.

        Records the reviewer and review time and writes a ``submission_review``
        activity log entry. Responds 403 for non-admins and 400 for an invalid
        status or score.
        """
        if request.user.role != 'admin':
            return Response({'error': '只有管理员可以进行审核'}, status=status.HTTP_403_FORBIDDEN)

        submission = self.get_object()
        status_value = request.data.get('status')
        quality_score = request.data.get('quality_score')
        reviewer_notes = request.data.get('reviewer_notes', '')

        if not status_value or status_value not in ['accepted', 'rejected']:
            return Response({'error': '状态值无效'}, status=status.HTTP_400_BAD_REQUEST)

        if quality_score is not None:
            # Validate numerically. Comparing a string (e.g. form data or "abc")
            # with an int would raise TypeError and surface as HTTP 500.
            # NaN fails the range check below.
            try:
                numeric_score = float(quality_score)
            except (TypeError, ValueError):
                return Response({'error': '质量分数必须在1-5之间'}, status=status.HTTP_400_BAD_REQUEST)
            if not 1 <= numeric_score <= 5:
                return Response({'error': '质量分数必须在1-5之间'}, status=status.HTTP_400_BAD_REQUEST)

        with transaction.atomic():
            submission.status = status_value
            submission.quality_score = quality_score
            submission.reviewer = request.user
            submission.reviewer_notes = reviewer_notes
            submission.reviewed_at = timezone.now()
            submission.save()

            UserActivityLog.objects.create(
                user=request.user,
                action_type='submission_review',
                target_type='submission',
                target_id=submission.id,
                details={
                    'status': status_value,
                    'quality_score': quality_score
                }
            )

        return Response({
            'message': '审核完成',
            'submission': ConversationSubmissionSerializer(submission).data
        })
class ConversationEvaluationViewSet(viewsets.ModelViewSet):
    """Whole-conversation evaluations left by users."""

    queryset = ConversationEvaluation.objects.all()
    serializer_class = ConversationEvaluationSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return the caller's evaluations, newest first."""
        user = self.request.user
        return ConversationEvaluation.objects.filter(user=user).order_by('-created_at')

    def create(self, request, *args, **kwargs):
        """Create or update the caller's evaluation of one conversation.

        Request body:
            conversation_id: required.
            has_logical_issues, needs_satisfied: required. Only a missing value
                (None or '') is rejected, so a boolean ``false`` is accepted.
            overall_feeling: optional, default ''.

        Writes a ``conversation_evaluation`` activity log entry and returns the
        serialized evaluation. Responds 400 on missing fields and 404 if the
        conversation does not exist.
        """
        conversation_id = request.data.get('conversation_id')
        overall_feeling = request.data.get('overall_feeling', '')
        has_logical_issues = request.data.get('has_logical_issues')
        needs_satisfied = request.data.get('needs_satisfied')

        # A truthiness test would reject a legitimate ``false`` answer, so only
        # absent/empty values count as missing.
        missing = (None, '')
        if not conversation_id or has_logical_issues in missing or needs_satisfied in missing:
            return Response({
                'error': '对话ID、逻辑问题和需求满足度不能为空'
            }, status=status.HTTP_400_BAD_REQUEST)

        try:
            conversation = Conversation.objects.get(id=conversation_id)
        except Conversation.DoesNotExist:
            return Response({'error': '对话不存在'}, status=status.HTTP_404_NOT_FOUND)

        fields = {
            'overall_feeling': overall_feeling,
            'has_logical_issues': has_logical_issues,
            'needs_satisfied': needs_satisfied,
            'updated_at': timezone.now(),
        }
        with transaction.atomic():
            # get_or_create applies `defaults` only when inserting. A new row
            # therefore gets a fresh id, and an existing row keeps its id and is
            # updated in place below.
            evaluation, created = ConversationEvaluation.objects.select_for_update().get_or_create(
                conversation=conversation,
                user=request.user,
                defaults={'id': str(uuid.uuid4()), **fields},
            )
            if not created:
                for name, value in fields.items():
                    setattr(evaluation, name, value)
                evaluation.save()

            UserActivityLog.objects.create(
                user=request.user,
                action_type='conversation_evaluation',
                target_type='conversation',
                target_id=conversation_id,
                details={
                    'has_logical_issues': has_logical_issues,
                    'needs_satisfied': needs_satisfied
                }
            )

        return Response(ConversationEvaluationSerializer(evaluation).data)
class SystemConfigViewSet(viewsets.ModelViewSet):
    """System configuration entries plus helpers for the current AI model.

    Admins (``role == 'admin'``) see every entry; other users only see keys
    starting with ``public_``.
    """

    queryset = SystemConfig.objects.all()
    serializer_class = SystemConfigSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    # NOTE(review): the inherited create/update/destroy actions only require
    # IsAuthenticated. Non-admins can therefore create entries and modify or
    # delete ``public_`` ones; only the ``model`` POST below checks for admin.
    # Confirm this is intended.

    def get_queryset(self):
        """Return all configs for admins, only ``public_*`` keys for others."""
        if self.request.user.role == 'admin':
            return SystemConfig.objects.all()
        return SystemConfig.objects.filter(config_key__startswith='public_')

    @action(detail=False, methods=['get', 'post'])
    def model(self, request):
        """Read (GET) or set (POST, admin only) the current AI model name.

        GET returns ``{'model': <name>}``, falling back to "默认模型" when the
        ``current_model`` entry does not exist.

        POST body: ``model`` (required). The value is stored in the SystemConfig
        entry with key ``current_model``, which is created if missing. Responds
        403 for non-admins and 400 for an empty name.
        """
        if request.method == 'GET':
            model_config = SystemConfig.objects.filter(config_key='current_model').first()
            if model_config:
                return Response({'model': model_config.config_value})
            return Response({'model': '默认模型'})
        elif request.method == 'POST':
            if request.user.role != 'admin':
                return Response({'error': '只有管理员可以更改模型'}, status=status.HTTP_403_FORBIDDEN)

            model_name = request.data.get('model')
            if not model_name:
                return Response({'error': '模型名称不能为空'}, status=status.HTTP_400_BAD_REQUEST)

            fields = {
                'config_value': model_name,
                'config_type': 'string',
                'description': '当前使用的AI模型',
                'updated_at': timezone.now(),
            }
            with transaction.atomic():
                # get_or_create applies `defaults` only when inserting. A new
                # row therefore gets a fresh id, and an existing row keeps its id
                # and is updated in place below.
                config, created = SystemConfig.objects.select_for_update().get_or_create(
                    config_key='current_model',
                    defaults={'id': str(uuid.uuid4()), **fields},
                )
                if not created:
                    for name, value in fields.items():
                        setattr(config, name, value)
                    config.save()

            return Response({'model': model_name})

    @action(detail=False, methods=['get'])
    def models(self, request):
        """Return the hard-coded list of selectable models."""
        return Response({
            'models': [
                {'id': 'model1', 'name': 'GPT-3.5'},
                {'id': 'model2', 'name': 'GPT-4'},
                {'id': 'model3', 'name': 'Claude'},
                {'id': 'model4', 'name': 'LLaMA'},
                {'id': 'model5', 'name': 'Qwen'}
            ]
        })