diff --git a/apps/rlhf/admin.py b/apps/rlhf/admin.py index 9bbea70..9b37ecb 100644 --- a/apps/rlhf/admin.py +++ b/apps/rlhf/admin.py @@ -1,24 +1,40 @@ from django.contrib import admin from .models import ( - Conversation, Message, Feedback, FeedbackTag, DetailedFeedback, + RLHFConversation, NegotiationChat, ChatHistory, Feedback, FeedbackTag, DetailedFeedback, ConversationSubmission, ConversationEvaluation, SystemConfig ) -@admin.register(Conversation) -class ConversationAdmin(admin.ModelAdmin): - list_display = ('id', 'user', 'is_submitted', 'created_at') - list_filter = ('is_submitted', 'created_at') - search_fields = ('id', 'user__username') +@admin.register(RLHFConversation) +class RLHFConversationAdmin(admin.ModelAdmin): + list_display = ('negotiation_chat', 'id', 'user', 'is_submitted', 'created_at') + list_filter = ('is_submitted',) + search_fields = ('negotiation_chat__conversation_id', 'negotiation_chat__negotiation__user__username') + + def id(self, obj): + return obj.negotiation_chat.conversation_id + + def user(self, obj): + return obj.negotiation_chat.negotiation.user + + def created_at(self, obj): + return obj.negotiation_chat.created_at + + +@admin.register(NegotiationChat) +class NegotiationChatAdmin(admin.ModelAdmin): + list_display = ('conversation_id', 'negotiation', 'creator', 'product', 'created_at', 'updated_at') + list_filter = ('created_at', 'updated_at') + search_fields = ('conversation_id', 'negotiation__user__username', 'creator__name', 'product__name') date_hierarchy = 'created_at' -@admin.register(Message) -class MessageAdmin(admin.ModelAdmin): - list_display = ('id', 'conversation', 'role', 'short_content', 'timestamp') - list_filter = ('role', 'timestamp') - search_fields = ('id', 'conversation__id', 'content') - date_hierarchy = 'timestamp' +@admin.register(ChatHistory) +class ChatHistoryAdmin(admin.ModelAdmin): + list_display = ('id', 'conversation_id', 'user', 'role', 'short_content', 'created_at') + list_filter = ('role', 'created_at') + search_fields = ('id', 'conversation_id', 'content') + date_hierarchy = 'created_at' def short_content(self, obj): return obj.content[:50] + '...' if len(obj.content) > 50 else obj.content @@ -27,9 +43,9 @@ class MessageAdmin(admin.ModelAdmin): @admin.register(Feedback) class FeedbackAdmin(admin.ModelAdmin): - list_display = ('id', 'message', 'conversation', 'user', 'feedback_value', 'timestamp') + list_display = ('id', 'message', 'conversation_id', 'user', 'feedback_value', 'timestamp') list_filter = ('feedback_value', 'timestamp') - search_fields = ('id', 'message__id', 'conversation__id', 'user__username') + search_fields = ('id', 'message__id', 'conversation_id', 'user__username') date_hierarchy = 'timestamp' @@ -42,25 +58,25 @@ class FeedbackTagAdmin(admin.ModelAdmin): @admin.register(DetailedFeedback) class DetailedFeedbackAdmin(admin.ModelAdmin): - list_display = ('id', 'message', 'conversation', 'user', 'feedback_type', 'is_inline', 'created_at') + list_display = ('id', 'message', 'conversation_id', 'user', 'feedback_type', 'is_inline', 'created_at') list_filter = ('feedback_type', 'is_inline', 'created_at') - search_fields = ('id', 'message__id', 'conversation__id', 'user__username', 'custom_content') + search_fields = ('id', 'message__id', 'conversation_id', 'user__username', 'custom_content') date_hierarchy = 'created_at' @admin.register(ConversationSubmission) class ConversationSubmissionAdmin(admin.ModelAdmin): - list_display = ('id', 'conversation', 'user', 'title', 'status', 'quality_score', 'reviewer', 'submitted_at', 'reviewed_at') + list_display = ('id', 'conversation_id', 'user', 'title', 'status', 'quality_score', 'reviewer', 'submitted_at', 'reviewed_at') list_filter = ('status', 'quality_score', 'submitted_at', 'reviewed_at') - search_fields = ('id', 'conversation__id', 'user__username', 'title', 'description', 'reviewer__username') + search_fields = ('id', 'conversation_id', 'user__username', 'title', 'description', 'reviewer__username') date_hierarchy = 'submitted_at' @admin.register(ConversationEvaluation) class ConversationEvaluationAdmin(admin.ModelAdmin): - list_display = ('id', 'conversation', 'user', 'has_logical_issues', 'needs_satisfied', 'created_at') + list_display = ('id', 'conversation_id', 'user', 'has_logical_issues', 'needs_satisfied', 'created_at') list_filter = ('has_logical_issues', 'needs_satisfied', 'created_at') - search_fields = ('id', 'conversation__id', 'user__username', 'overall_feeling') + search_fields = ('id', 'conversation_id', 'user__username', 'overall_feeling') date_hierarchy = 'created_at' diff --git a/apps/rlhf/management/commands/analyze_data.py b/apps/rlhf/management/commands/analyze_data.py index e14a9b8..2c24cfb 100644 --- a/apps/rlhf/management/commands/analyze_data.py +++ b/apps/rlhf/management/commands/analyze_data.py @@ -1,368 +1,368 @@ -from django.core.management.base import BaseCommand -from rlhf.models import Conversation, Message, Feedback, DetailedFeedback, FeedbackTag -from django.db.models import Count, Avg, Sum, Q, F -from django.utils import timezone -from django.contrib.auth import get_user_model -import json -from datetime import datetime, timedelta - -User = get_user_model() - -class Command(BaseCommand): - help = '分析RLHF反馈数据,生成统计报告' - - def add_arguments(self, parser): - parser.add_argument( - '--export', - action='store_true', - help='导出数据到JSON文件', - ) - parser.add_argument( - '--days', - type=int, - default=30, - help='分析最近的天数', - ) - - def handle(self, *args, **options): - self.stdout.write(self.style.SUCCESS("=" * 60)) - self.stdout.write(self.style.SUCCESS("🤖 在线人类反馈强化学习系统 - 数据分析报告")) - self.stdout.write(self.style.SUCCESS("=" * 60)) - - # 基本统计 - feedback_stats = self.get_feedback_stats() - self.stdout.write(self.style.SUCCESS(f"\n📊 反馈统计:")) - self.stdout.write(f" 总反馈数量: {feedback_stats['total_feedback']}") - self.stdout.write(f" 正面反馈: {feedback_stats['positive_feedback']} ({feedback_stats['positive_rate']:.1f}%)") - self.stdout.write(f" 负面反馈: {feedback_stats['negative_feedback']}") - self.stdout.write(f" 平均反馈分数: {feedback_stats['avg_feedback']:.2f}") - - # 对话统计 - conv_stats = self.get_conversation_stats() - self.stdout.write(self.style.SUCCESS(f"\n💬 对话统计:")) - self.stdout.write(f" 总对话数量: {conv_stats['total_conversations']}") - self.stdout.write(f" 总消息数量: {conv_stats['total_messages']}") - self.stdout.write(f" 平均每对话消息数: {conv_stats['avg_messages_per_conversation']:.1f}") - - # 标签统计 - tag_stats = self.get_tag_stats() - self.stdout.write(self.style.SUCCESS(f"\n🏷️ 标签统计:")) - self.stdout.write(f" 最常用的正面标签:") - for tag in tag_stats['top_positive']: - self.stdout.write(f" - {tag['tag_name']}: {tag['count']}次") - - self.stdout.write(f" 最常用的负面标签:") - for tag in tag_stats['top_negative']: - self.stdout.write(f" - {tag['tag_name']}: {tag['count']}次") - - # 每日趋势 - days = options['days'] - daily_trend = self.get_daily_feedback_trend(days) - self.stdout.write(self.style.SUCCESS(f"\n📈 最近{days}天反馈趋势:")) - for day in daily_trend: - self.stdout.write(f" {day['date']}: {day['total']}条反馈 (正面率: {day['positive_rate']:.1f}%)") - - # 用户统计 - user_stats = self.get_user_stats() - self.stdout.write(self.style.SUCCESS(f"\n👥 用户统计:")) - self.stdout.write(f" 总用户数量: {user_stats['total_users']}") - self.stdout.write(f" 活跃标注用户: {user_stats['active_users']}") - self.stdout.write(f" 平均每用户标注量: {user_stats['avg_annotations_per_user']:.1f}") - - # 导出数据 - if options['export']: - filename = self.export_data_to_json() - self.stdout.write(self.style.SUCCESS(f"\n✅ 数据已导出到: {filename}")) - - def get_feedback_stats(self): - """获取反馈统计信息""" - # 基本反馈统计 - basic_feedback = Feedback.objects.aggregate( - total=Count('id'), - positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0)), - negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0)), - avg=Avg('feedback_value') - ) - - # 详细反馈统计 - detailed_feedback = DetailedFeedback.objects.aggregate( - total=Count('id'), - positive=Count('id', filter=Q(feedback_type='positive')), - negative=Count('id', filter=Q(feedback_type='negative')) - ) - - # 合并统计 - total = (basic_feedback['total'] or 0) + (detailed_feedback['total'] or 0) - positive = (basic_feedback['positive'] or 0) + (detailed_feedback['positive'] or 0) - negative = (basic_feedback['negative'] or 0) + (detailed_feedback['negative'] or 0) - - # 计算平均分和正面比例 - avg_feedback = basic_feedback['avg'] or 0 - positive_rate = (positive / total * 100) if total > 0 else 0 - - return { - 'total_feedback': total, - 'positive_feedback': positive, - 'negative_feedback': negative, - 'avg_feedback': avg_feedback, - 'positive_rate': positive_rate - } - - def get_conversation_stats(self): - """获取对话统计信息""" - total_conversations = Conversation.objects.count() - total_messages = Message.objects.count() - - # 计算每个对话的消息数量分布 - conversation_messages = Message.objects.values('conversation').annotate(count=Count('id')) - avg_messages = conversation_messages.aggregate(Avg('count'))['count__avg'] or 0 - - return { - 'total_conversations': total_conversations, - 'total_messages': total_messages, - 'avg_messages_per_conversation': avg_messages - } - - def get_tag_stats(self): - """获取标签使用统计""" - # 分析DetailedFeedback中的标签使用情况 - # 注意:由于标签可能存储为JSON字符串,这里需要解析 - - # 首先获取所有的标签 - all_tags = FeedbackTag.objects.all() - tag_id_to_name = {str(tag.id): tag.tag_name for tag in all_tags} - - # 计算每个标签的使用次数 - tag_counts = {} - for feedback in DetailedFeedback.objects.all(): - if feedback.feedback_tags: - try: - # 尝试解析JSON标签列表 - tag_ids = json.loads(feedback.feedback_tags) - if isinstance(tag_ids, list): - for tag_id in tag_ids: - tag_id = str(tag_id) - if tag_id in tag_counts: - tag_counts[tag_id] += 1 - else: - tag_counts[tag_id] = 1 - except (json.JSONDecodeError, TypeError): - # 如果不是有效的JSON,可能是单个标签ID - tag_id = str(feedback.feedback_tags) - if tag_id in tag_counts: - tag_counts[tag_id] += 1 - else: - tag_counts[tag_id] = 1 - - # 获取排名前5的正面和负面标签 - positive_tags = FeedbackTag.objects.filter(tag_type='positive') - negative_tags = FeedbackTag.objects.filter(tag_type='negative') - - top_positive = [] - for tag in positive_tags: - tag_id = str(tag.id) - if tag_id in tag_counts: - top_positive.append({ - 'tag_name': tag.tag_name, - 'count': tag_counts[tag_id] - }) - - top_negative = [] - for tag in negative_tags: - tag_id = str(tag.id) - if tag_id in tag_counts: - top_negative.append({ - 'tag_name': tag.tag_name, - 'count': tag_counts[tag_id] - }) - - # 按使用次数排序 - top_positive.sort(key=lambda x: x['count'], reverse=True) - top_negative.sort(key=lambda x: x['count'], reverse=True) - - # 取前5 - return { - 'top_positive': top_positive[:5], - 'top_negative': top_negative[:5] - } - - def get_daily_feedback_trend(self, days=30): - """获取每日反馈趋势""" - # 计算开始日期 - start_date = timezone.now().date() - timedelta(days=days) - - # 基本反馈按日期分组 - basic_daily = Feedback.objects.filter(timestamp__date__gte=start_date) \ - .values('timestamp__date') \ - .annotate( - date=F('timestamp__date'), - total=Count('id'), - positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0)), - negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0)) - ) \ - .order_by('date') - - # 详细反馈按日期分组 - detailed_daily = DetailedFeedback.objects.filter(created_at__date__gte=start_date) \ - .values('created_at__date') \ - .annotate( - date=F('created_at__date'), - total=Count('id'), - positive=Count('id', filter=Q(feedback_type='positive')), - negative=Count('id', filter=Q(feedback_type='negative')) - ) \ - .order_by('date') - - # 合并两种反馈数据 - daily_data = {} - - for item in basic_daily: - date_str = item['date'].strftime('%Y-%m-%d') - daily_data[date_str] = { - 'date': date_str, - 'total': item['total'], - 'positive': item['positive'], - 'negative': item['negative'] - } - - for item in detailed_daily: - date_str = item['date'].strftime('%Y-%m-%d') - if date_str in daily_data: - daily_data[date_str]['total'] += item['total'] - daily_data[date_str]['positive'] += item['positive'] - daily_data[date_str]['negative'] += item['negative'] - else: - daily_data[date_str] = { - 'date': date_str, - 'total': item['total'], - 'positive': item['positive'], - 'negative': item['negative'] - } - - # 计算正面反馈比例 - for date_str, data in daily_data.items(): - data['positive_rate'] = (data['positive'] / data['total'] * 100) if data['total'] > 0 else 0 - - # 转换为列表并按日期排序 - result = list(daily_data.values()) - result.sort(key=lambda x: x['date']) - - return result - - def get_user_stats(self): - """获取用户统计信息""" - # 总用户数 - total_users = User.objects.count() - - # 有反馈记录的用户数 - users_with_feedback = User.objects.filter( - Q(feedback__isnull=False) | Q(detailed_feedback__isnull=False) - ).distinct().count() - - # 最近30天活跃的标注用户 - thirty_days_ago = timezone.now() - timedelta(days=30) - active_users = User.objects.filter( - Q(feedback__timestamp__gte=thirty_days_ago) | - Q(detailed_feedback__created_at__gte=thirty_days_ago) - ).distinct().count() - - # 计算每个用户的标注量 - user_annotations = {} - - for feedback in Feedback.objects.all(): - user_id = str(feedback.user_id) - if user_id in user_annotations: - user_annotations[user_id] += 1 - else: - user_annotations[user_id] = 1 - - for feedback in DetailedFeedback.objects.all(): - user_id = str(feedback.user_id) - if user_id in user_annotations: - user_annotations[user_id] += 1 - else: - user_annotations[user_id] = 1 - - # 计算平均每用户标注量 - if user_annotations: - avg_annotations = sum(user_annotations.values()) / len(user_annotations) - else: - avg_annotations = 0 - - return { - 'total_users': total_users, - 'users_with_feedback': users_with_feedback, - 'active_users': active_users, - 'avg_annotations_per_user': avg_annotations - } - - def export_data_to_json(self): - """导出数据到JSON文件""" - data = { - 'conversations': [], - 'feedback_summary': self.get_feedback_stats(), - 'tag_stats': self.get_tag_stats(), - 'daily_trend': self.get_daily_feedback_trend(30), - 'export_time': timezone.now().isoformat() - } - - # 导出对话和消息数据 - for conv in Conversation.objects.all().prefetch_related('messages'): - conv_data = { - 'id': str(conv.id), - 'created_at': conv.created_at.isoformat(), - 'user_id': str(conv.user_id), - 'is_submitted': conv.is_submitted, - 'messages': [] - } - - for msg in conv.messages.all().order_by('timestamp'): - msg_data = { - 'id': str(msg.id), - 'role': msg.role, - 'content': msg.content, - 'timestamp': msg.timestamp.isoformat(), - 'feedback': [] - } - - # 获取消息的反馈 - for fb in Feedback.objects.filter(message_id=msg.id): - msg_data['feedback'].append({ - 'id': str(fb.id), - 'type': 'basic', - 'value': fb.feedback_value, - 'user_id': str(fb.user_id), - 'timestamp': fb.timestamp.isoformat() - }) - - # 获取详细反馈 - for dfb in DetailedFeedback.objects.filter(message_id=msg.id): - try: - tags = json.loads(dfb.feedback_tags) if dfb.feedback_tags else [] - except (json.JSONDecodeError, TypeError): - tags = [dfb.feedback_tags] if dfb.feedback_tags else [] - - msg_data['feedback'].append({ - 'id': str(dfb.id), - 'type': 'detailed', - 'feedback_type': dfb.feedback_type, - 'tags': tags, - 'custom_tags': dfb.custom_tags, - 'custom_content': dfb.custom_content, - 'is_inline': dfb.is_inline, - 'user_id': str(dfb.user_id), - 'timestamp': dfb.created_at.isoformat() - }) - - conv_data['messages'].append(msg_data) - - data['conversations'].append(conv_data) - - # 保存到文件 - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - filename = f'rlhf_data_export_{timestamp}.json' - - with open(filename, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=2) - +from django.core.management.base import BaseCommand +from rlhf.models import Conversation, Message, Feedback, DetailedFeedback, FeedbackTag +from django.db.models import Count, Avg, Sum, Q, F +from django.utils import timezone +from django.contrib.auth import get_user_model +import json +from datetime import datetime, timedelta + +User = get_user_model() + +class Command(BaseCommand): + help = '分析RLHF反馈数据,生成统计报告' + + def add_arguments(self, parser): + parser.add_argument( + '--export', + action='store_true', + help='导出数据到JSON文件', + ) + parser.add_argument( + '--days', + type=int, + default=30, + help='分析最近的天数', + ) + + def handle(self, *args, **options): + self.stdout.write(self.style.SUCCESS("=" * 60)) + self.stdout.write(self.style.SUCCESS("🤖 在线人类反馈强化学习系统 - 数据分析报告")) + self.stdout.write(self.style.SUCCESS("=" * 60)) + + # 基本统计 + feedback_stats = self.get_feedback_stats() + self.stdout.write(self.style.SUCCESS(f"\n📊 反馈统计:")) + self.stdout.write(f" 总反馈数量: {feedback_stats['total_feedback']}") + self.stdout.write(f" 正面反馈: {feedback_stats['positive_feedback']} ({feedback_stats['positive_rate']:.1f}%)") + self.stdout.write(f" 负面反馈: {feedback_stats['negative_feedback']}") + self.stdout.write(f" 平均反馈分数: {feedback_stats['avg_feedback']:.2f}") + + # 对话统计 + conv_stats = self.get_conversation_stats() + self.stdout.write(self.style.SUCCESS(f"\n💬 对话统计:")) + self.stdout.write(f" 总对话数量: {conv_stats['total_conversations']}") + self.stdout.write(f" 总消息数量: {conv_stats['total_messages']}") + self.stdout.write(f" 平均每对话消息数: {conv_stats['avg_messages_per_conversation']:.1f}") + + # 标签统计 + tag_stats = self.get_tag_stats() + self.stdout.write(self.style.SUCCESS(f"\n🏷️ 标签统计:")) + self.stdout.write(f" 最常用的正面标签:") + for tag in tag_stats['top_positive']: + self.stdout.write(f" - {tag['tag_name']}: {tag['count']}次") + + self.stdout.write(f" 最常用的负面标签:") + for tag in tag_stats['top_negative']: + self.stdout.write(f" - {tag['tag_name']}: {tag['count']}次") + + # 每日趋势 + days = options['days'] + daily_trend = self.get_daily_feedback_trend(days) + self.stdout.write(self.style.SUCCESS(f"\n📈 最近{days}天反馈趋势:")) + for day in daily_trend: + self.stdout.write(f" {day['date']}: {day['total']}条反馈 (正面率: {day['positive_rate']:.1f}%)") + + # 用户统计 + user_stats = self.get_user_stats() + self.stdout.write(self.style.SUCCESS(f"\n👥 用户统计:")) + self.stdout.write(f" 总用户数量: {user_stats['total_users']}") + self.stdout.write(f" 活跃标注用户: {user_stats['active_users']}") + self.stdout.write(f" 平均每用户标注量: {user_stats['avg_annotations_per_user']:.1f}") + + # 导出数据 + if options['export']: + filename = self.export_data_to_json() + self.stdout.write(self.style.SUCCESS(f"\n✅ 数据已导出到: {filename}")) + + def get_feedback_stats(self): + """获取反馈统计信息""" + # 基本反馈统计 + basic_feedback = Feedback.objects.aggregate( + total=Count('id'), + positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0)), + negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0)), + avg=Avg('feedback_value') + ) + + # 详细反馈统计 + detailed_feedback = DetailedFeedback.objects.aggregate( + total=Count('id'), + positive=Count('id', filter=Q(feedback_type='positive')), + negative=Count('id', filter=Q(feedback_type='negative')) + ) + + # 合并统计 + total = (basic_feedback['total'] or 0) + (detailed_feedback['total'] or 0) + positive = (basic_feedback['positive'] or 0) + (detailed_feedback['positive'] or 0) + negative = (basic_feedback['negative'] or 0) + (detailed_feedback['negative'] or 0) + + # 计算平均分和正面比例 + avg_feedback = basic_feedback['avg'] or 0 + positive_rate = (positive / total * 100) if total > 0 else 0 + + return { + 'total_feedback': total, + 'positive_feedback': positive, + 'negative_feedback': negative, + 'avg_feedback': avg_feedback, + 'positive_rate': positive_rate + } + + def get_conversation_stats(self): + """获取对话统计信息""" + total_conversations = Conversation.objects.count() + total_messages = Message.objects.count() + + # 计算每个对话的消息数量分布 + conversation_messages = Message.objects.values('conversation').annotate(count=Count('id')) + avg_messages = conversation_messages.aggregate(Avg('count'))['count__avg'] or 0 + + return { + 'total_conversations': total_conversations, + 'total_messages': total_messages, + 'avg_messages_per_conversation': avg_messages + } + + def get_tag_stats(self): + """获取标签使用统计""" + # 分析DetailedFeedback中的标签使用情况 + # 注意:由于标签可能存储为JSON字符串,这里需要解析 + + # 首先获取所有的标签 + all_tags = FeedbackTag.objects.all() + tag_id_to_name = {str(tag.id): tag.tag_name for tag in all_tags} + + # 计算每个标签的使用次数 + tag_counts = {} + for feedback in DetailedFeedback.objects.all(): + if feedback.feedback_tags: + try: + # 尝试解析JSON标签列表 + tag_ids = json.loads(feedback.feedback_tags) + if isinstance(tag_ids, list): + for tag_id in tag_ids: + tag_id = str(tag_id) + if tag_id in tag_counts: + tag_counts[tag_id] += 1 + else: + tag_counts[tag_id] = 1 + except (json.JSONDecodeError, TypeError): + # 如果不是有效的JSON,可能是单个标签ID + tag_id = str(feedback.feedback_tags) + if tag_id in tag_counts: + tag_counts[tag_id] += 1 + else: + tag_counts[tag_id] = 1 + + # 获取排名前5的正面和负面标签 + positive_tags = FeedbackTag.objects.filter(tag_type='positive') + negative_tags = FeedbackTag.objects.filter(tag_type='negative') + + top_positive = [] + for tag in positive_tags: + tag_id = str(tag.id) + if tag_id in tag_counts: + top_positive.append({ + 'tag_name': tag.tag_name, + 'count': tag_counts[tag_id] + }) + + top_negative = [] + for tag in negative_tags: + tag_id = str(tag.id) + if tag_id in tag_counts: + top_negative.append({ + 'tag_name': tag.tag_name, + 'count': tag_counts[tag_id] + }) + + # 按使用次数排序 + top_positive.sort(key=lambda x: x['count'], reverse=True) + top_negative.sort(key=lambda x: x['count'], reverse=True) + + # 取前5 + return { + 'top_positive': top_positive[:5], + 'top_negative': top_negative[:5] + } + + def get_daily_feedback_trend(self, days=30): + """获取每日反馈趋势""" + # 计算开始日期 + start_date = timezone.now().date() - timedelta(days=days) + + # 基本反馈按日期分组 + basic_daily = Feedback.objects.filter(timestamp__date__gte=start_date) \ + .values('timestamp__date') \ + .annotate( + date=F('timestamp__date'), + total=Count('id'), + positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0)), + negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0)) + ) \ + .order_by('date') + + # 详细反馈按日期分组 + detailed_daily = DetailedFeedback.objects.filter(created_at__date__gte=start_date) \ + .values('created_at__date') \ + .annotate( + date=F('created_at__date'), + total=Count('id'), + positive=Count('id', filter=Q(feedback_type='positive')), + negative=Count('id', filter=Q(feedback_type='negative')) + ) \ + .order_by('date') + + # 合并两种反馈数据 + daily_data = {} + + for item in basic_daily: + date_str = item['date'].strftime('%Y-%m-%d') + daily_data[date_str] = { + 'date': date_str, + 'total': item['total'], + 'positive': item['positive'], + 'negative': item['negative'] + } + + for item in detailed_daily: + date_str = item['date'].strftime('%Y-%m-%d') + if date_str in daily_data: + daily_data[date_str]['total'] += item['total'] + daily_data[date_str]['positive'] += item['positive'] + daily_data[date_str]['negative'] += item['negative'] + else: + daily_data[date_str] = { + 'date': date_str, + 'total': item['total'], + 'positive': item['positive'], + 'negative': item['negative'] + } + + # 计算正面反馈比例 + for date_str, data in daily_data.items(): + data['positive_rate'] = (data['positive'] / data['total'] * 100) if data['total'] > 0 else 0 + + # 转换为列表并按日期排序 + result = list(daily_data.values()) + result.sort(key=lambda x: x['date']) + + return result + + def get_user_stats(self): + """获取用户统计信息""" + # 总用户数 + total_users = User.objects.count() + + # 有反馈记录的用户数 + users_with_feedback = User.objects.filter( + Q(feedback__isnull=False) | Q(detailed_feedback__isnull=False) + ).distinct().count() + + # 最近30天活跃的标注用户 + thirty_days_ago = timezone.now() - timedelta(days=30) + active_users = User.objects.filter( + Q(feedback__timestamp__gte=thirty_days_ago) | + Q(detailed_feedback__created_at__gte=thirty_days_ago) + ).distinct().count() + + # 计算每个用户的标注量 + user_annotations = {} + + for feedback in Feedback.objects.all(): + user_id = str(feedback.user_id) + if user_id in user_annotations: + user_annotations[user_id] += 1 + else: + user_annotations[user_id] = 1 + + for feedback in DetailedFeedback.objects.all(): + user_id = str(feedback.user_id) + if user_id in user_annotations: + user_annotations[user_id] += 1 + else: + user_annotations[user_id] = 1 + + # 计算平均每用户标注量 + if user_annotations: + avg_annotations = sum(user_annotations.values()) / len(user_annotations) + else: + avg_annotations = 0 + + return { + 'total_users': total_users, + 'users_with_feedback': users_with_feedback, + 'active_users': active_users, + 'avg_annotations_per_user': avg_annotations + } + + def export_data_to_json(self): + """导出数据到JSON文件""" + data = { + 'conversations': [], + 'feedback_summary': self.get_feedback_stats(), + 'tag_stats': self.get_tag_stats(), + 'daily_trend': self.get_daily_feedback_trend(30), + 'export_time': timezone.now().isoformat() + } + + # 导出对话和消息数据 + for conv in Conversation.objects.all().prefetch_related('messages'): + conv_data = { + 'id': str(conv.id), + 'created_at': conv.created_at.isoformat(), + 'user_id': str(conv.user_id), + 'is_submitted': conv.is_submitted, + 'messages': [] + } + + for msg in conv.messages.all().order_by('timestamp'): + msg_data = { + 'id': str(msg.id), + 'role': msg.role, + 'content': msg.content, + 'timestamp': msg.timestamp.isoformat(), + 'feedback': [] + } + + # 获取消息的反馈 + for fb in Feedback.objects.filter(message_id=msg.id): + msg_data['feedback'].append({ + 'id': str(fb.id), + 'type': 'basic', + 'value': fb.feedback_value, + 'user_id': str(fb.user_id), + 'timestamp': fb.timestamp.isoformat() + }) + + # 获取详细反馈 + for dfb in DetailedFeedback.objects.filter(message_id=msg.id): + try: + tags = json.loads(dfb.feedback_tags) if dfb.feedback_tags else [] + except (json.JSONDecodeError, TypeError): + tags = [dfb.feedback_tags] if dfb.feedback_tags else [] + + msg_data['feedback'].append({ + 'id': str(dfb.id), + 'type': 'detailed', + 'feedback_type': dfb.feedback_type, + 'tags': tags, + 'custom_tags': dfb.custom_tags, + 'custom_content': dfb.custom_content, + 'is_inline': dfb.is_inline, + 'user_id': str(dfb.user_id), + 'timestamp': dfb.created_at.isoformat() + }) + + conv_data['messages'].append(msg_data) + + data['conversations'].append(conv_data) + + # 保存到文件 + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + filename = f'rlhf_data_export_{timestamp}.json' + + with open(filename, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + return filename \ No newline at end of file diff --git a/apps/rlhf/management/commands/import_data.py b/apps/rlhf/management/commands/import_data.py index 77e2be4..5c275f2 100644 --- a/apps/rlhf/management/commands/import_data.py +++ b/apps/rlhf/management/commands/import_data.py @@ -1,289 +1,289 @@ -from django.core.management.base import BaseCommand, CommandError -from rlhf.models import ( - Conversation, Message, Feedback, FeedbackTag, DetailedFeedback, - ConversationSubmission, ConversationEvaluation, SystemConfig -) -from django.utils import timezone -import json -import uuid -import os -from django.contrib.auth import get_user_model - -User = get_user_model() - -class Command(BaseCommand): - help = '导入/导出RLHF数据' - - def add_arguments(self, parser): - parser.add_argument( - '--import-file', - help='导入JSON数据文件的路径', - ) - parser.add_argument( - '--export-file', - help='导出JSON数据文件的路径', - ) - parser.add_argument( - '--import-tags', - action='store_true', - help='从init_tags.py导入标签', - ) - parser.add_argument( - '--clear', - action='store_true', - help='在导入前清除现有数据', - ) - - def handle(self, *args, **options): - if options['import_file']: - self.import_data(options['import_file'], options['clear']) - elif options['export_file']: - self.export_data(options['export_file']) - elif options['import_tags']: - self.import_tags_from_init_file() - else: - self.stdout.write(self.style.WARNING('请指定导入文件路径或导出文件路径')) - - def import_data(self, file_path, clear=False): - """从JSON文件导入数据""" - if not os.path.exists(file_path): - raise CommandError(f'文件不存在: {file_path}') - - try: - with open(file_path, 'r', encoding='utf-8') as f: - data = json.load(f) - - if clear: - self.stdout.write(self.style.WARNING('正在清除现有数据...')) - self._clear_data() - - # 导入对话 - for conv_data in data.get('conversations', []): - self._import_conversation(conv_data) - - # 导入配置 - for config in data.get('system_configs', []): - self._import_system_config(config) - - self.stdout.write(self.style.SUCCESS(f'成功导入数据从: {file_path}')) - - except Exception as e: - raise CommandError(f'导入数据失败: {str(e)}') - - def export_data(self, file_path): - """导出数据到JSON文件""" - try: - data = { - 'conversations': [], - 'system_configs': [], - 'tags': [], - 'export_time': timezone.now().isoformat() - } - - # 导出标签 - for tag in FeedbackTag.objects.all(): - data['tags'].append({ - 'id': str(tag.id), - 'tag_name': tag.tag_name, - 'tag_type': tag.tag_type, - 'description': tag.description, - 'created_at': tag.created_at.isoformat() - }) - - # 导出系统配置 - for config in SystemConfig.objects.all(): - data['system_configs'].append({ - 'id': str(config.id), - 'config_key': config.config_key, - 'config_value': config.config_value, - 'config_type': config.config_type, - 'description': config.description, - 'created_at': config.created_at.isoformat(), - 'updated_at': config.updated_at.isoformat() - }) - - # 导出对话 - for conv in Conversation.objects.all().prefetch_related('messages'): - conv_data = { - 'id': str(conv.id), - 'user_id': str(conv.user_id), - 'is_submitted': conv.is_submitted, - 'created_at': conv.created_at.isoformat(), - 'messages': [] - } - - for msg in conv.messages.all().order_by('timestamp'): - msg_data = { - 'id': str(msg.id), - 'role': msg.role, - 'content': msg.content, - 'timestamp': msg.timestamp.isoformat(), - 'feedback': [] - } - - # 获取基本反馈 - for fb in Feedback.objects.filter(message_id=msg.id): - msg_data['feedback'].append({ - 'id': str(fb.id), - 'type': 'basic', - 'feedback_value': fb.feedback_value, - 'user_id': str(fb.user_id), - 'timestamp': fb.timestamp.isoformat() - }) - - # 获取详细反馈 - for dfb in DetailedFeedback.objects.filter(message_id=msg.id): - msg_data['feedback'].append({ - 'id': str(dfb.id), - 'type': 'detailed', - 'feedback_type': dfb.feedback_type, - 'feedback_tags': dfb.feedback_tags, - 'custom_tags': dfb.custom_tags, - 'custom_content': dfb.custom_content, - 'is_inline': dfb.is_inline, - 'user_id': str(dfb.user_id), - 'created_at': dfb.created_at.isoformat(), - 'updated_at': dfb.updated_at.isoformat() - }) - - conv_data['messages'].append(msg_data) - - data['conversations'].append(conv_data) - - with open(file_path, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=2) - - self.stdout.write(self.style.SUCCESS(f'成功导出数据到: {file_path}')) - - except Exception as e: - raise CommandError(f'导出数据失败: {str(e)}') - - def import_tags_from_init_file(self): - """从init_tags.py导入标签""" - try: - # 定义标签数据 - positive_tags = [ - ('有帮助', '回答对问题有实际帮助'), - ('准确', '信息准确可靠'), - ('清晰', '表达清楚易懂'), - ('完整', '回答全面完整'), - ('友好', '语调友善亲切'), - ('创新', '提供了新颖的观点') - ] - - negative_tags = [ - ('不准确', '包含错误信息'), - ('不相关', '回答偏离主题'), - ('不完整', '回答过于简略'), - ('不清晰', '表达模糊难懂'), - ('不友好', '语调生硬冷淡'), - ('重复', '内容重复冗余') - ] - - # 插入正面标签 - for tag_name, description in positive_tags: - FeedbackTag.objects.get_or_create( - tag_name=tag_name, - defaults={ - 'id': str(uuid.uuid4()), - 'tag_type': 'positive', - 'description': description, - 'created_at': timezone.now() - } - ) - - # 插入负面标签 - for tag_name, description in negative_tags: - FeedbackTag.objects.get_or_create( - tag_name=tag_name, - defaults={ - 'id': str(uuid.uuid4()), - 'tag_type': 'negative', - 'description': description, - 'created_at': timezone.now() - } - ) - - self.stdout.write(self.style.SUCCESS('成功导入标签')) - - except Exception as e: - raise CommandError(f'导入标签失败: {str(e)}') - - def _clear_data(self): - """清除现有数据""" - DetailedFeedback.objects.all().delete() - Feedback.objects.all().delete() - Message.objects.all().delete() - Conversation.objects.all().delete() - ConversationSubmission.objects.all().delete() - ConversationEvaluation.objects.all().delete() - self.stdout.write(self.style.SUCCESS('已清除现有数据')) - - def _import_conversation(self, conv_data): - """导入单个对话""" - # 检查用户是否存在 - user_id = conv_data.get('user_id') - if not User.objects.filter(id=user_id).exists(): - self.stdout.write(self.style.WARNING(f'用户不存在: {user_id},将使用第一个管理员用户')) - user = User.objects.filter(is_superuser=True).first() - if not user: - raise CommandError('找不到管理员用户') - user_id = user.id - - # 创建对话 - conv = Conversation.objects.create( - id=conv_data.get('id', str(uuid.uuid4())), - user_id=user_id, - is_submitted=conv_data.get('is_submitted', False), - created_at=timezone.parse_datetime(conv_data.get('created_at', timezone.now().isoformat())) - ) - - # 创建消息 - for msg_data in conv_data.get('messages', []): - msg = Message.objects.create( - id=msg_data.get('id', str(uuid.uuid4())), - conversation=conv, - role=msg_data.get('role', 'user'), - content=msg_data.get('content', ''), - timestamp=timezone.parse_datetime(msg_data.get('timestamp', timezone.now().isoformat())) - ) - - # 创建反馈 - for fb_data in msg_data.get('feedback', []): - if fb_data.get('type') == 'basic': - Feedback.objects.create( - id=fb_data.get('id', str(uuid.uuid4())), - message=msg, - conversation=conv, - user_id=fb_data.get('user_id', user_id), - feedback_value=fb_data.get('feedback_value', 0), - timestamp=timezone.parse_datetime(fb_data.get('timestamp', timezone.now().isoformat())) - ) - elif fb_data.get('type') == 'detailed': - DetailedFeedback.objects.create( - id=fb_data.get('id', str(uuid.uuid4())), - message=msg, - conversation=conv, - user_id=fb_data.get('user_id', user_id), - feedback_type=fb_data.get('feedback_type', 'neutral'), - feedback_tags=fb_data.get('feedback_tags', '[]'), - custom_tags=fb_data.get('custom_tags', ''), - custom_content=fb_data.get('custom_content', ''), - is_inline=fb_data.get('is_inline', True), - created_at=timezone.parse_datetime(fb_data.get('created_at', timezone.now().isoformat())), - updated_at=timezone.parse_datetime(fb_data.get('updated_at', timezone.now().isoformat())) - ) - - def _import_system_config(self, config_data): - """导入系统配置""" - SystemConfig.objects.update_or_create( - config_key=config_data.get('config_key'), - defaults={ - 'id': config_data.get('id', str(uuid.uuid4())), - 'config_value': config_data.get('config_value', ''), - 'config_type': config_data.get('config_type', 'string'), - 'description': config_data.get('description', ''), - 'created_at': timezone.parse_datetime(config_data.get('created_at', timezone.now().isoformat())), - 'updated_at': timezone.parse_datetime(config_data.get('updated_at', timezone.now().isoformat())) - } +from django.core.management.base import BaseCommand, CommandError +from rlhf.models import ( + Conversation, Message, Feedback, FeedbackTag, DetailedFeedback, + ConversationSubmission, ConversationEvaluation, SystemConfig +) +from django.utils import timezone +import json +import uuid +import os +from django.contrib.auth import get_user_model + +User = get_user_model() + +class Command(BaseCommand): + help = '导入/导出RLHF数据' + + def add_arguments(self, parser): + parser.add_argument( + '--import-file', + help='导入JSON数据文件的路径', + ) + parser.add_argument( + '--export-file', + help='导出JSON数据文件的路径', + ) + parser.add_argument( + '--import-tags', + action='store_true', + help='从init_tags.py导入标签', + ) + parser.add_argument( + '--clear', + action='store_true', + help='在导入前清除现有数据', + ) + + def handle(self, *args, **options): + if options['import_file']: + self.import_data(options['import_file'], options['clear']) + elif options['export_file']: + self.export_data(options['export_file']) + elif options['import_tags']: + self.import_tags_from_init_file() + else: + self.stdout.write(self.style.WARNING('请指定导入文件路径或导出文件路径')) + + def import_data(self, file_path, clear=False): + """从JSON文件导入数据""" + if not os.path.exists(file_path): + raise CommandError(f'文件不存在: {file_path}') + + try: + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + if clear: + self.stdout.write(self.style.WARNING('正在清除现有数据...')) + self._clear_data() + + # 导入对话 + for conv_data in data.get('conversations', []): + self._import_conversation(conv_data) + + # 导入配置 + for config in data.get('system_configs', []): + self._import_system_config(config) + + self.stdout.write(self.style.SUCCESS(f'成功导入数据从: {file_path}')) + + except Exception as e: + raise CommandError(f'导入数据失败: {str(e)}') + + def export_data(self, file_path): + """导出数据到JSON文件""" + try: + data = { + 'conversations': [], + 'system_configs': [], + 'tags': [], + 'export_time': timezone.now().isoformat() + } + + # 导出标签 + for tag in FeedbackTag.objects.all(): + data['tags'].append({ + 'id': str(tag.id), + 'tag_name': tag.tag_name, + 'tag_type': tag.tag_type, + 'description': tag.description, + 'created_at': tag.created_at.isoformat() + }) + + # 导出系统配置 + for config in SystemConfig.objects.all(): + data['system_configs'].append({ + 'id': str(config.id), + 'config_key': config.config_key, + 'config_value': config.config_value, + 'config_type': config.config_type, + 'description': config.description, + 'created_at': config.created_at.isoformat(), + 'updated_at': config.updated_at.isoformat() + }) + + # 导出对话 + for conv in Conversation.objects.all().prefetch_related('messages'): + conv_data = { + 'id': str(conv.id), + 'user_id': str(conv.user_id), + 'is_submitted': conv.is_submitted, + 'created_at': conv.created_at.isoformat(), + 'messages': [] + } + + for msg in conv.messages.all().order_by('timestamp'): + msg_data = { + 'id': str(msg.id), + 'role': msg.role, + 'content': msg.content, + 'timestamp': msg.timestamp.isoformat(), + 'feedback': [] + } + + # 获取基本反馈 + for fb in Feedback.objects.filter(message_id=msg.id): + msg_data['feedback'].append({ + 'id': str(fb.id), + 'type': 'basic', + 'feedback_value': fb.feedback_value, + 'user_id': str(fb.user_id), + 'timestamp': fb.timestamp.isoformat() + }) + + # 获取详细反馈 + for dfb in DetailedFeedback.objects.filter(message_id=msg.id): + msg_data['feedback'].append({ + 'id': str(dfb.id), + 'type': 'detailed', + 'feedback_type': dfb.feedback_type, + 'feedback_tags': dfb.feedback_tags, + 'custom_tags': dfb.custom_tags, + 'custom_content': dfb.custom_content, + 'is_inline': dfb.is_inline, + 'user_id': str(dfb.user_id), + 'created_at': dfb.created_at.isoformat(), + 'updated_at': dfb.updated_at.isoformat() + }) + + conv_data['messages'].append(msg_data) + + data['conversations'].append(conv_data) + + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + self.stdout.write(self.style.SUCCESS(f'成功导出数据到: {file_path}')) + + except Exception as e: + raise CommandError(f'导出数据失败: {str(e)}') + + def import_tags_from_init_file(self): + """从init_tags.py导入标签""" + try: + # 定义标签数据 + positive_tags = [ + ('有帮助', '回答对问题有实际帮助'), + ('准确', '信息准确可靠'), + ('清晰', '表达清楚易懂'), + ('完整', '回答全面完整'), + ('友好', '语调友善亲切'), + ('创新', '提供了新颖的观点') + ] + + negative_tags = [ + ('不准确', '包含错误信息'), + ('不相关', '回答偏离主题'), + ('不完整', '回答过于简略'), + ('不清晰', '表达模糊难懂'), + ('不友好', '语调生硬冷淡'), + ('重复', '内容重复冗余') + ] + + # 插入正面标签 + for tag_name, description in positive_tags: + FeedbackTag.objects.get_or_create( + tag_name=tag_name, + defaults={ + 'id': str(uuid.uuid4()), + 'tag_type': 'positive', + 'description': description, + 'created_at': timezone.now() + } + ) + + # 插入负面标签 + for tag_name, description in negative_tags: + FeedbackTag.objects.get_or_create( + tag_name=tag_name, + defaults={ + 'id': str(uuid.uuid4()), + 'tag_type': 'negative', + 'description': description, + 'created_at': timezone.now() + } + ) + + self.stdout.write(self.style.SUCCESS('成功导入标签')) + + except Exception as e: + raise CommandError(f'导入标签失败: {str(e)}') + + def _clear_data(self): + """清除现有数据""" + DetailedFeedback.objects.all().delete() + Feedback.objects.all().delete() + Message.objects.all().delete() + Conversation.objects.all().delete() + ConversationSubmission.objects.all().delete() + ConversationEvaluation.objects.all().delete() + self.stdout.write(self.style.SUCCESS('已清除现有数据')) + + def _import_conversation(self, conv_data): + """导入单个对话""" + # 检查用户是否存在 + user_id = conv_data.get('user_id') + if not User.objects.filter(id=user_id).exists(): + self.stdout.write(self.style.WARNING(f'用户不存在: {user_id},将使用第一个管理员用户')) + user = User.objects.filter(is_superuser=True).first() + if not user: + raise CommandError('找不到管理员用户') + user_id = user.id + + # 创建对话 + conv = Conversation.objects.create( + id=conv_data.get('id', str(uuid.uuid4())), + user_id=user_id, + is_submitted=conv_data.get('is_submitted', False), + created_at=timezone.parse_datetime(conv_data.get('created_at', timezone.now().isoformat())) + ) + + # 创建消息 + for msg_data in conv_data.get('messages', []): + msg = Message.objects.create( + id=msg_data.get('id', str(uuid.uuid4())), + conversation=conv, + role=msg_data.get('role', 'user'), + content=msg_data.get('content', ''), + timestamp=timezone.parse_datetime(msg_data.get('timestamp', timezone.now().isoformat())) + ) + + # 创建反馈 + for fb_data in msg_data.get('feedback', []): + if fb_data.get('type') == 'basic': + Feedback.objects.create( + id=fb_data.get('id', str(uuid.uuid4())), + message=msg, + conversation=conv, + user_id=fb_data.get('user_id', user_id), + feedback_value=fb_data.get('feedback_value', 0), + timestamp=timezone.parse_datetime(fb_data.get('timestamp', timezone.now().isoformat())) + ) + elif fb_data.get('type') == 'detailed': + DetailedFeedback.objects.create( + id=fb_data.get('id', str(uuid.uuid4())), + message=msg, + conversation=conv, + user_id=fb_data.get('user_id', user_id), + feedback_type=fb_data.get('feedback_type', 'neutral'), + feedback_tags=fb_data.get('feedback_tags', '[]'), + custom_tags=fb_data.get('custom_tags', ''), + custom_content=fb_data.get('custom_content', ''), + is_inline=fb_data.get('is_inline', True), + created_at=timezone.parse_datetime(fb_data.get('created_at', timezone.now().isoformat())), + updated_at=timezone.parse_datetime(fb_data.get('updated_at', timezone.now().isoformat())) + ) + + def _import_system_config(self, config_data): + """导入系统配置""" + SystemConfig.objects.update_or_create( + config_key=config_data.get('config_key'), + defaults={ + 'id': config_data.get('id', str(uuid.uuid4())), + 'config_value': config_data.get('config_value', ''), + 'config_type': config_data.get('config_type', 'string'), + 'description': config_data.get('description', ''), + 'created_at': timezone.parse_datetime(config_data.get('created_at', timezone.now().isoformat())), + 'updated_at': timezone.parse_datetime(config_data.get('updated_at', timezone.now().isoformat())) + } ) \ No newline at end of file diff --git a/apps/rlhf/migrations/0001_initial.py b/apps/rlhf/migrations/0001_initial.py index 320c586..e6c419d 100644 --- a/apps/rlhf/migrations/0001_initial.py +++ b/apps/rlhf/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 5.1.5 on 2025-06-09 08:28 +# Generated by Django 5.2.1 on 2025-06-10 08:10 import django.db.models.deletion import django.utils.timezone @@ -12,6 +12,7 @@ class Migration(migrations.Migration): initial = True dependencies = [ + ('chat', '0002_negotiationchat'), migrations.swappable_dependency(settings.AUTH_USER_MODEL), ] @@ -30,6 +31,17 @@ class Migration(migrations.Migration): 'verbose_name_plural': '反馈标签', }, ), + migrations.CreateModel( + name='RLHFConversation', + fields=[ + ('negotiation_chat', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, related_name='rlhf_extension', serialize=False, to='chat.negotiationchat')), + ('is_submitted', models.BooleanField(default=False)), + ], + options={ + 'verbose_name': 'RLHF对话', + 'verbose_name_plural': 'RLHF对话', + }, + ), migrations.CreateModel( name='SystemConfig', fields=[ @@ -46,23 +58,11 @@ class Migration(migrations.Migration): 'verbose_name_plural': '系统配置', }, ), - migrations.CreateModel( - name='Conversation', - fields=[ - ('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)), - ('is_submitted', models.BooleanField(default=False)), - ('created_at', models.DateTimeField(default=django.utils.timezone.now)), - ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='conversations', to=settings.AUTH_USER_MODEL)), - ], - options={ - 'verbose_name': '对话', - 'verbose_name_plural': '对话', - }, - ), migrations.CreateModel( name='ConversationSubmission', fields=[ ('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)), + ('conversation_id', models.CharField(max_length=100)), ('title', models.CharField(blank=True, max_length=255, null=True)), ('description', models.TextField(blank=True, null=True)), ('status', models.CharField(choices=[('submitted', '已提交'), ('reviewed', '已审核'), ('accepted', '已接受'), ('rejected', '已拒绝')], default='submitted', max_length=20)), @@ -72,7 +72,6 @@ class Migration(migrations.Migration): ('reviewed_at', models.DateTimeField(blank=True, null=True)), ('created_at', models.DateTimeField(default=django.utils.timezone.now)), ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), - ('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='submissions', to='rlhf.conversation')), ('reviewer', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='reviewed_submissions', to=settings.AUTH_USER_MODEL)), ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='submissions', to=settings.AUTH_USER_MODEL)), ], @@ -81,40 +80,11 @@ class Migration(migrations.Migration): 'verbose_name_plural': '对话提交', }, ), - migrations.CreateModel( - name='Message', - fields=[ - ('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)), - ('role', models.CharField(choices=[('user', '用户'), ('assistant', '助手'), ('system', '系统')], max_length=20)), - ('content', models.TextField()), - ('timestamp', models.DateTimeField(default=django.utils.timezone.now)), - ('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='messages', to='rlhf.conversation')), - ], - options={ - 'verbose_name': '消息', - 'verbose_name_plural': '消息', - 'ordering': ['timestamp'], - }, - ), - migrations.CreateModel( - name='Feedback', - fields=[ - ('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)), - ('feedback_value', models.IntegerField()), - ('timestamp', models.DateTimeField(default=django.utils.timezone.now)), - ('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to='rlhf.conversation')), - ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to=settings.AUTH_USER_MODEL)), - ('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to='rlhf.message')), - ], - options={ - 'verbose_name': '反馈', - 'verbose_name_plural': '反馈', - }, - ), migrations.CreateModel( name='DetailedFeedback', fields=[ ('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)), + ('conversation_id', models.CharField(max_length=100)), ('feedback_type', models.CharField(choices=[('positive', '正面'), ('negative', '负面'), ('neutral', '中性')], max_length=20)), ('feedback_tags', models.TextField(blank=True, null=True)), ('custom_tags', models.TextField(blank=True, null=True)), @@ -122,31 +92,45 @@ class Migration(migrations.Migration): ('is_inline', models.BooleanField(default=True)), ('created_at', models.DateTimeField(default=django.utils.timezone.now)), ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), - ('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='detailed_feedback', to='rlhf.conversation')), + ('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rlhf_detailed_feedback', to='chat.chathistory')), ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='detailed_feedback', to=settings.AUTH_USER_MODEL)), - ('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='detailed_feedback', to='rlhf.message')), ], options={ 'verbose_name': '详细反馈', 'verbose_name_plural': '详细反馈', }, ), + migrations.CreateModel( + name='Feedback', + fields=[ + ('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)), + ('conversation_id', models.CharField(max_length=100)), + ('feedback_value', models.IntegerField()), + ('timestamp', models.DateTimeField(default=django.utils.timezone.now)), + ('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rlhf_feedback', to='chat.chathistory')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'verbose_name': '反馈', + 'verbose_name_plural': '反馈', + }, + ), migrations.CreateModel( name='ConversationEvaluation', fields=[ ('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)), + ('conversation_id', models.CharField(max_length=100)), ('overall_feeling', models.TextField(blank=True, null=True)), ('has_logical_issues', models.CharField(choices=[('yes', '是'), ('no', '否'), ('unsure', '不确定')], max_length=10)), ('needs_satisfied', models.CharField(choices=[('yes', '是'), ('no', '否'), ('partially', '部分')], max_length=10)), ('created_at', models.DateTimeField(default=django.utils.timezone.now)), ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), - ('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='evaluations', to='rlhf.conversation')), ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='evaluations', to=settings.AUTH_USER_MODEL)), ], options={ 'verbose_name': '对话评估', 'verbose_name_plural': '对话评估', - 'unique_together': {('conversation', 'user')}, + 'unique_together': {('conversation_id', 'user')}, }, ), ] diff --git a/apps/rlhf/models.py b/apps/rlhf/models.py index fc756c8..428a5cb 100644 --- a/apps/rlhf/models.py +++ b/apps/rlhf/models.py @@ -2,60 +2,91 @@ from django.db import models import uuid from django.utils import timezone from apps.user.models import User +from apps.chat.models import ChatHistory, NegotiationChat +from apps.daren_detail.models import CreatorProfile +from apps.brands.models import Product -class Conversation(models.Model): - id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) - user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='conversations') +# 使用NegotiationChat替代Conversation +# class Conversation(models.Model): +# id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) +# user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='conversations') +# is_submitted = models.BooleanField(default=False) +# created_at = models.DateTimeField(default=timezone.now) +# +# class Meta: +# verbose_name = '对话' +# verbose_name_plural = '对话' +# +# def __str__(self): +# return f"Conversation {self.id[:8]}" + + +# 为NegotiationChat添加RLHF所需字段的代理模型 +class RLHFConversation(models.Model): + """对NegotiationChat的扩展,添加RLHF所需字段""" + negotiation_chat = models.OneToOneField(NegotiationChat, on_delete=models.CASCADE, primary_key=True, related_name='rlhf_extension') is_submitted = models.BooleanField(default=False) - created_at = models.DateTimeField(default=timezone.now) class Meta: - verbose_name = '对话' - verbose_name_plural = '对话' + verbose_name = 'RLHF对话' + verbose_name_plural = 'RLHF对话' def __str__(self): - return f"Conversation {self.id[:8]}" + return f"RLHF Conversation {self.negotiation_chat.conversation_id[:8]}" + + @property + def id(self): + return self.negotiation_chat.conversation_id + + @property + def user(self): + # 从谈判关联的用户中获取 + return self.negotiation_chat.negotiation.user + + @property + def created_at(self): + return self.negotiation_chat.created_at -class Message(models.Model): - ROLE_CHOICES = ( - ('user', '用户'), - ('assistant', '助手'), - ('system', '系统'), - ) - - id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) - conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages') - role = models.CharField(max_length=20, choices=ROLE_CHOICES) - content = models.TextField() - timestamp = models.DateTimeField(default=timezone.now) - - class Meta: - - verbose_name = '消息' - verbose_name_plural = '消息' - ordering = ['timestamp'] - - def __str__(self): - return f"{self.role}: {self.content[:50]}..." +# 使用ChatHistory替代Message +# class Message(models.Model): +# ROLE_CHOICES = ( +# ('user', '用户'), +# ('assistant', '助手'), +# ('system', '系统'), +# ) +# +# id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) +# conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages') +# role = models.CharField(max_length=20, choices=ROLE_CHOICES) +# content = models.TextField() +# timestamp = models.DateTimeField(default=timezone.now) +# +# class Meta: +# +# verbose_name = '消息' +# verbose_name_plural = '消息' +# ordering = ['timestamp'] +# +# def __str__(self): +# return f"{self.role}: {self.content[:50]}..." class Feedback(models.Model): id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) - message = models.ForeignKey(Message, on_delete=models.CASCADE, related_name='feedback') - conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='feedback') + message = models.ForeignKey(ChatHistory, on_delete=models.CASCADE, related_name='rlhf_feedback') + conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='feedback') feedback_value = models.IntegerField() timestamp = models.DateTimeField(default=timezone.now) class Meta: - verbose_name = '反馈' verbose_name_plural = '反馈' def __str__(self): - return f"Feedback on {self.message.id[:8]}" + return f"Feedback on {self.message.id}" class FeedbackTag(models.Model): @@ -71,7 +102,6 @@ class FeedbackTag(models.Model): created_at = models.DateTimeField(default=timezone.now) class Meta: - verbose_name = '反馈标签' verbose_name_plural = '反馈标签' @@ -87,8 +117,8 @@ class DetailedFeedback(models.Model): ) id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) - message = models.ForeignKey(Message, on_delete=models.CASCADE, related_name='detailed_feedback') - conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='detailed_feedback') + message = models.ForeignKey(ChatHistory, on_delete=models.CASCADE, related_name='rlhf_detailed_feedback') + conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='detailed_feedback') feedback_type = models.CharField(max_length=20, choices=FEEDBACK_TYPE_CHOICES) feedback_tags = models.TextField(blank=True, null=True) # JSON格式存储多个标签 @@ -99,12 +129,11 @@ class DetailedFeedback(models.Model): updated_at = models.DateTimeField(default=timezone.now) class Meta: - verbose_name = '详细反馈' verbose_name_plural = '详细反馈' def __str__(self): - return f"{self.feedback_type} feedback on {self.message.id[:8]}" + return f"{self.feedback_type} feedback on {self.message.id}" class ConversationSubmission(models.Model): @@ -116,7 +145,7 @@ class ConversationSubmission(models.Model): ) id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) - conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='submissions') + conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='submissions') title = models.CharField(max_length=255, blank=True, null=True) description = models.TextField(blank=True, null=True) @@ -130,12 +159,11 @@ class ConversationSubmission(models.Model): updated_at = models.DateTimeField(default=timezone.now) class Meta: - verbose_name = '对话提交' verbose_name_plural = '对话提交' def __str__(self): - return f"Submission for {self.conversation.id[:8]}" + return f"Submission for {self.conversation_id[:8]}" class ConversationEvaluation(models.Model): @@ -152,7 +180,7 @@ class ConversationEvaluation(models.Model): ) id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) - conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='evaluations') + conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='evaluations') overall_feeling = models.TextField(blank=True, null=True) has_logical_issues = models.CharField(max_length=10, choices=LOGICAL_CHOICES) @@ -161,13 +189,12 @@ class ConversationEvaluation(models.Model): updated_at = models.DateTimeField(default=timezone.now) class Meta: - verbose_name = '对话评估' verbose_name_plural = '对话评估' - unique_together = ('conversation', 'user') + unique_together = ('conversation_id', 'user') def __str__(self): - return f"Evaluation for {self.conversation.id[:8]}" + return f"Evaluation for {self.conversation_id[:8]}" class SystemConfig(models.Model): diff --git a/apps/rlhf/serializers.py b/apps/rlhf/serializers.py index 7ed471a..700c674 100644 --- a/apps/rlhf/serializers.py +++ b/apps/rlhf/serializers.py @@ -1,29 +1,43 @@ from rest_framework import serializers from .models import ( - Conversation, Message, Feedback, FeedbackTag, DetailedFeedback, - ConversationSubmission, ConversationEvaluation, SystemConfig + Feedback, FeedbackTag, DetailedFeedback, + ConversationSubmission, ConversationEvaluation, SystemConfig, + RLHFConversation, NegotiationChat, ChatHistory ) from apps.user.serializers import UserSerializer class ConversationSerializer(serializers.ModelSerializer): + id = serializers.CharField(source='conversation_id', read_only=True) + user = serializers.PrimaryKeyRelatedField(source='negotiation.user', read_only=True) + is_submitted = serializers.SerializerMethodField() + created_at = serializers.DateTimeField(source='created_at', read_only=True) + class Meta: - model = Conversation + model = NegotiationChat fields = ['id', 'user', 'is_submitted', 'created_at'] - read_only_fields = ['id', 'created_at', 'user'] + + def get_is_submitted(self, obj): + # 尝试获取RLHF扩展 + try: + return obj.rlhf_extension.is_submitted + except RLHFConversation.DoesNotExist: + return False class MessageSerializer(serializers.ModelSerializer): + conversation = serializers.CharField(source='conversation_id', read_only=True) + timestamp = serializers.DateTimeField(source='created_at', read_only=True) + class Meta: - model = Message + model = ChatHistory fields = ['id', 'conversation', 'role', 'content', 'timestamp'] - read_only_fields = ['id', 'timestamp'] class FeedbackSerializer(serializers.ModelSerializer): class Meta: model = Feedback - fields = ['id', 'message', 'conversation', 'user', 'feedback_value', 'timestamp'] + fields = ['id', 'message', 'conversation_id', 'user', 'feedback_value', 'timestamp'] read_only_fields = ['id', 'timestamp'] @@ -37,7 +51,11 @@ class FeedbackTagSerializer(serializers.ModelSerializer): class DetailedFeedbackSerializer(serializers.ModelSerializer): class Meta: model = DetailedFeedback - fields = ['id', 'message', 'conversation', 'user', 'feedback_type', 'feedback_tags', 'custom_tags', 'custom_content', 'is_inline', 'created_at', 'updated_at'] + fields = [ + 'id', 'message', 'conversation_id', 'user', 'feedback_type', + 'feedback_tags', 'custom_tags', 'custom_content', 'is_inline', + 'created_at', 'updated_at' + ] read_only_fields = ['id', 'created_at', 'updated_at'] @@ -47,38 +65,72 @@ class ConversationSubmissionSerializer(serializers.ModelSerializer): class Meta: model = ConversationSubmission - fields = ['id', 'conversation', 'user', 'user_details', 'title', 'description', 'status', 'quality_score', 'reviewer', 'reviewer_details', 'reviewer_notes', 'submitted_at', 'reviewed_at', 'created_at', 'updated_at'] + fields = [ + 'id', 'conversation_id', 'user', 'title', 'description', + 'status', 'quality_score', 'reviewer', 'reviewer_details', + 'reviewer_notes', 'submitted_at', 'reviewed_at', 'created_at', 'updated_at' + ] read_only_fields = ['id', 'submitted_at', 'reviewed_at', 'created_at', 'updated_at'] class ConversationEvaluationSerializer(serializers.ModelSerializer): class Meta: model = ConversationEvaluation - fields = ['id', 'conversation', 'user', 'overall_feeling', 'has_logical_issues', 'needs_satisfied', 'created_at', 'updated_at'] + fields = [ + 'id', 'conversation_id', 'user', 'overall_feeling', + 'has_logical_issues', 'needs_satisfied', 'created_at', 'updated_at' + ] read_only_fields = ['id', 'created_at', 'updated_at'] class SystemConfigSerializer(serializers.ModelSerializer): class Meta: model = SystemConfig - fields = ['id', 'config_key', 'config_value', 'config_type', 'description', 'created_at', 'updated_at'] + fields = [ + 'id', 'config_key', 'config_value', 'config_type', + 'description', 'created_at', 'updated_at' + ] read_only_fields = ['id', 'created_at', 'updated_at'] class ConversationWithMessagesSerializer(serializers.ModelSerializer): - messages = MessageSerializer(many=True, read_only=True) + id = serializers.CharField(source='conversation_id', read_only=True) + user = serializers.PrimaryKeyRelatedField(source='negotiation.user', read_only=True) + is_submitted = serializers.SerializerMethodField() + created_at = serializers.DateTimeField(source='created_at', read_only=True) + messages = serializers.SerializerMethodField() class Meta: - model = Conversation + model = NegotiationChat fields = ['id', 'user', 'is_submitted', 'created_at', 'messages'] - read_only_fields = ['id', 'created_at'] + + def get_is_submitted(self, obj): + try: + return obj.rlhf_extension.is_submitted + except RLHFConversation.DoesNotExist: + return False + + def get_messages(self, obj): + messages = ChatHistory.objects.filter( + conversation_id=obj.conversation_id + ).order_by('created_at') + return MessageSerializer(messages, many=True).data class MessageWithFeedbackSerializer(serializers.ModelSerializer): - feedback = FeedbackSerializer(many=True, read_only=True) - detailed_feedback = DetailedFeedbackSerializer(many=True, read_only=True) + conversation = serializers.CharField(source='conversation_id', read_only=True) + timestamp = serializers.DateTimeField(source='created_at', read_only=True) + feedback = serializers.SerializerMethodField() + detailed_feedback = serializers.SerializerMethodField() class Meta: - model = Message + model = ChatHistory fields = ['id', 'conversation', 'role', 'content', 'timestamp', 'feedback', 'detailed_feedback'] - read_only_fields = ['id', 'timestamp'] \ No newline at end of file + + def get_feedback(self, obj): + feedback = Feedback.objects.filter(message=obj) + return FeedbackSerializer(feedback, many=True).data + + def get_detailed_feedback(self, obj): + detailed_feedback = DetailedFeedback.objects.filter(message=obj) + return DetailedFeedbackSerializer(detailed_feedback, many=True).data \ No newline at end of file diff --git a/apps/rlhf/views.py b/apps/rlhf/views.py index 85cd810..a17f4aa 100644 --- a/apps/rlhf/views.py +++ b/apps/rlhf/views.py @@ -5,8 +5,9 @@ from rest_framework.decorators import action from rest_framework.permissions import IsAuthenticated from rest_framework.pagination import PageNumberPagination from .models import ( - Conversation, Message, Feedback, FeedbackTag, DetailedFeedback, - ConversationSubmission, ConversationEvaluation, SystemConfig + Feedback, FeedbackTag, DetailedFeedback, + ConversationSubmission, ConversationEvaluation, SystemConfig, + NegotiationChat, ChatHistory, RLHFConversation, CreatorProfile, Product ) from .serializers import ( ConversationSerializer, MessageSerializer, FeedbackSerializer, @@ -98,22 +99,44 @@ class StandardResponseMixin: class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): - queryset = Conversation.objects.all() + queryset = NegotiationChat.objects.all() serializer_class = ConversationSerializer authentication_classes = [CustomTokenAuthentication] permission_classes = [IsAuthenticated] def get_queryset(self): user = self.request.user - return Conversation.objects.filter(user=user).order_by('-created_at') + return NegotiationChat.objects.filter(negotiation__user=user).order_by('-updated_at') def perform_create(self, serializer): - serializer.save(user=self.request.user) + creator = CreatorProfile.objects.first() + product = Product.objects.first() + + negotiation = Negotiation.objects.create( + user=self.request.user, + status='active' + ) + + chat = NegotiationChat.objects.create( + negotiation=negotiation, + conversation_id=str(uuid.uuid4()), + creator=creator, + product=product + ) + + RLHFConversation.objects.create( + negotiation_chat=chat, + is_submitted=False + ) + + serializer.instance = chat @action(detail=True, methods=['get']) def messages(self, request, pk=None): conversation = self.get_object() - messages = Message.objects.filter(conversation=conversation).order_by('timestamp') + messages = ChatHistory.objects.filter( + conversation_id=conversation.conversation_id + ).order_by('created_at') serializer = MessageSerializer(messages, many=True) return self.get_standard_response(data=serializer.data) @@ -130,27 +153,26 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): status_code=status.HTTP_400_BAD_REQUEST ) - # 创建用户消息 - user_message = Message.objects.create( - id=str(uuid.uuid4()), - conversation=conversation, + knowledge_base = KnowledgeBase.objects.first() + + user_message = ChatHistory.objects.create( + user=request.user, + knowledge_base=knowledge_base, + conversation_id=conversation.conversation_id, role='user', content=content ) - # 这里需要调用AI服务获取回复 - # 示例:调用SiliconFlow或其他AI服务 - ai_response = self._generate_ai_response(user_message.content, conversation) + ai_response = self._generate_ai_response(content, conversation) - # 创建AI回复消息 - ai_message = Message.objects.create( - id=str(uuid.uuid4()), - conversation=conversation, + ai_message = ChatHistory.objects.create( + user=request.user, + knowledge_base=knowledge_base, + conversation_id=conversation.conversation_id, role='assistant', content=ai_response ) - # 更新用户的标注统计 self._update_annotation_stats(request.user.id) messages = [ @@ -166,7 +188,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): title = request.data.get('title', '') description = request.data.get('description', '') - if conversation.is_submitted: + rlhf_conv, created = RLHFConversation.objects.get_or_create( + negotiation_chat=conversation, + defaults={'is_submitted': False} + ) + + if rlhf_conv.is_submitted: return self.get_standard_response( code=400, message='该对话已提交', @@ -174,14 +201,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): status_code=status.HTTP_400_BAD_REQUEST ) - # 更新对话为已提交状态 - conversation.is_submitted = True - conversation.save() + rlhf_conv.is_submitted = True + rlhf_conv.save() - # 创建提交记录 submission = ConversationSubmission.objects.create( id=str(uuid.uuid4()), - conversation=conversation, + conversation_id=conversation.conversation_id, user=request.user, title=title, description=description, @@ -189,12 +214,11 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): submitted_at=timezone.now() ) - # 记录活动日志 UserActivityLog.objects.create( user=request.user, action_type='conversation_submit', target_type='conversation', - target_id=str(conversation.id), + target_id=conversation.conversation_id, details={'title': title} ) @@ -207,7 +231,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): def resume(self, request, pk=None): conversation = self.get_object() - if not conversation.is_submitted: + rlhf_conv, created = RLHFConversation.objects.get_or_create( + negotiation_chat=conversation, + defaults={'is_submitted': False} + ) + + if not rlhf_conv.is_submitted: return self.get_standard_response( code=400, message='该对话未提交,无需恢复', @@ -215,25 +244,22 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): status_code=status.HTTP_400_BAD_REQUEST ) - # 更新对话为未提交状态 - conversation.is_submitted = False - conversation.save() + rlhf_conv.is_submitted = False + rlhf_conv.save() - # 获取最新的提交记录 submission = ConversationSubmission.objects.filter( - conversation=conversation + conversation_id=conversation.conversation_id ).order_by('-submitted_at').first() if submission and submission.status == 'submitted': submission.status = 'rejected' submission.save() - # 记录活动日志 UserActivityLog.objects.create( user=request.user, action_type='conversation_resume', target_type='conversation', - target_id=str(conversation.id) + target_id=conversation.conversation_id ) return self.get_standard_response( @@ -264,7 +290,7 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): sf_client.set_system_message(system_prompt_config.config_value) # 获取历史消息作为上下文 - history_messages = Message.objects.filter(conversation=conversation).order_by('timestamp') + history_messages = ChatHistory.objects.filter(conversation_id=conversation.conversation_id).order_by('created_at') # 添加历史消息到客户端 for msg in history_messages: @@ -385,28 +411,28 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): def _get_recent_conversations(self, user_id, limit=5): """获取用户最近的对话""" - conversations = Conversation.objects.filter( - user_id=user_id - ).order_by('-created_at')[:limit] + conversations = NegotiationChat.objects.filter( + negotiation__user_id=user_id + ).order_by('-updated_at')[:limit] result = [] for conv in conversations: # 获取最后一条消息内容作为对话摘要 - last_message = Message.objects.filter( - conversation_id=conv.id - ).order_by('-timestamp').first() + last_message = ChatHistory.objects.filter( + conversation_id=conv.conversation_id + ).order_by('-created_at').first() # 统计消息数 - message_count = Message.objects.filter(conversation_id=conv.id).count() + message_count = ChatHistory.objects.filter(conversation_id=conv.conversation_id).count() # 统计反馈数 - feedback_count = Feedback.objects.filter(conversation_id=conv.id).count() - detailed_count = DetailedFeedback.objects.filter(conversation_id=conv.id).count() + feedback_count = Feedback.objects.filter(conversation_id=conv.conversation_id).count() + detailed_count = DetailedFeedback.objects.filter(conversation_id=conv.conversation_id).count() result.append({ - 'id': str(conv.id), - 'created_at': conv.created_at.isoformat(), - 'is_submitted': conv.is_submitted, + 'id': str(conv.conversation_id), + 'created_at': conv.updated_at.isoformat(), + 'is_submitted': conv.negotiation.is_submitted, 'message_count': message_count, 'feedback_count': feedback_count + detailed_count, 'summary': last_message.content[:100] + "..." if last_message and len(last_message.content) > 100 else (last_message.content if last_message else "") @@ -416,12 +442,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): def _get_conversation_stats(self, user_id): """获取对话统计""" - total_conversations = Conversation.objects.filter(user_id=user_id).count() - submitted_conversations = Conversation.objects.filter(user_id=user_id, is_submitted=True).count() + total_conversations = NegotiationChat.objects.filter(negotiation__user_id=user_id).count() + submitted_conversations = NegotiationChat.objects.filter(negotiation__user_id=user_id, negotiation__is_submitted=True).count() # 对话消息统计 - message_stats = Message.objects.filter( - conversation__user_id=user_id + message_stats = ChatHistory.objects.filter( + conversation__negotiation__user_id=user_id ).aggregate( total=Count('id'), user_messages=Count('id', filter=Q(role='user')), @@ -620,14 +646,21 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): class MessageViewSet(StandardResponseMixin, viewsets.ModelViewSet): - queryset = Message.objects.all() + queryset = ChatHistory.objects.all() serializer_class = MessageSerializer authentication_classes = [CustomTokenAuthentication] permission_classes = [IsAuthenticated] def get_queryset(self): user = self.request.user - return Message.objects.filter(conversation__user=user).order_by('timestamp') + # 获取用户参与的所有对话的ID + user_conversation_ids = NegotiationChat.objects.filter( + negotiation__user=user + ).values_list('conversation_id', flat=True) + # 筛选这些对话中的消息 + return ChatHistory.objects.filter( + conversation_id__in=user_conversation_ids + ).order_by('created_at') class FeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet):