修改rlhf

This commit is contained in:
jlj 2025-06-10 18:18:28 +08:00
parent 6a0ae5b132
commit 954e314afe
7 changed files with 951 additions and 839 deletions

View File

@ -1,24 +1,40 @@
from django.contrib import admin from django.contrib import admin
from .models import ( from .models import (
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback, RLHFConversation, NegotiationChat, ChatHistory, Feedback, FeedbackTag, DetailedFeedback,
ConversationSubmission, ConversationEvaluation, SystemConfig ConversationSubmission, ConversationEvaluation, SystemConfig
) )
@admin.register(RLHFConversation)
class RLHFConversationAdmin(admin.ModelAdmin):
    """Admin changelist for RLHF conversations.

    The ``id``, ``user`` and ``created_at`` columns are not read from the
    RLHF conversation itself; they are looked up through its linked
    ``negotiation_chat``.
    """

    list_display = ('negotiation_chat', 'id', 'user', 'is_submitted', 'created_at')
    list_filter = ('is_submitted',)
    search_fields = (
        'negotiation_chat__conversation_id',
        'negotiation_chat__negotiation__user__username',
    )
    # Every computed column below walks negotiation_chat (and, for ``user``,
    # negotiation -> user). Joining the chain once avoids extra queries per
    # changelist row.
    list_select_related = ('negotiation_chat__negotiation__user',)

    def id(self, obj):
        """Return the conversation id stored on the linked negotiation chat."""
        return obj.negotiation_chat.conversation_id
    # Make the computed column sortable by the underlying database column.
    id.admin_order_field = 'negotiation_chat__conversation_id'

    def user(self, obj):
        """Return the user who owns the negotiation behind this conversation."""
        return obj.negotiation_chat.negotiation.user
    user.admin_order_field = 'negotiation_chat__negotiation__user'

    def created_at(self, obj):
        """Return the creation time of the linked negotiation chat."""
        return obj.negotiation_chat.created_at
    created_at.admin_order_field = 'negotiation_chat__created_at'
@admin.register(NegotiationChat)
class NegotiationChatAdmin(admin.ModelAdmin):
    """Admin changelist for negotiation chats, browsable by creation date."""

    date_hierarchy = 'created_at'
    list_display = (
        'conversation_id',
        'negotiation',
        'creator',
        'product',
        'created_at',
        'updated_at',
    )
    list_filter = ('created_at', 'updated_at')
    search_fields = (
        'conversation_id',
        'negotiation__user__username',
        'creator__name',
        'product__name',
    )
@admin.register(ChatHistory)
class ChatHistoryAdmin(admin.ModelAdmin):
    """Admin changelist for chat history rows, with a shortened content preview."""

    date_hierarchy = 'created_at'
    list_display = ('id', 'conversation_id', 'user', 'role', 'short_content', 'created_at')
    list_filter = ('role', 'created_at')
    search_fields = ('id', 'conversation_id', 'content')

    def short_content(self, obj):
        """Return the content, cut to 50 characters plus '...' when it is longer."""
        text = obj.content
        if len(text) > 50:
            return text[:50] + '...'
        return text
@ -27,9 +43,9 @@ class MessageAdmin(admin.ModelAdmin):
@admin.register(Feedback)
class FeedbackAdmin(admin.ModelAdmin):
    """Admin changelist for basic feedback entries."""

    date_hierarchy = 'timestamp'
    list_display = (
        'id',
        'message',
        'conversation_id',
        'user',
        'feedback_value',
        'timestamp',
    )
    list_filter = ('feedback_value', 'timestamp')
    search_fields = ('id', 'message__id', 'conversation_id', 'user__username')
@ -42,25 +58,25 @@ class FeedbackTagAdmin(admin.ModelAdmin):
@admin.register(DetailedFeedback)
class DetailedFeedbackAdmin(admin.ModelAdmin):
    """Admin changelist for detailed (tagged / free-text) feedback entries."""

    date_hierarchy = 'created_at'
    list_display = (
        'id',
        'message',
        'conversation_id',
        'user',
        'feedback_type',
        'is_inline',
        'created_at',
    )
    list_filter = ('feedback_type', 'is_inline', 'created_at')
    search_fields = (
        'id',
        'message__id',
        'conversation_id',
        'user__username',
        'custom_content',
    )
@admin.register(ConversationSubmission)
class ConversationSubmissionAdmin(admin.ModelAdmin):
    """Admin changelist for submitted conversations and their review state."""

    date_hierarchy = 'submitted_at'
    list_display = (
        'id',
        'conversation_id',
        'user',
        'title',
        'status',
        'quality_score',
        'reviewer',
        'submitted_at',
        'reviewed_at',
    )
    list_filter = ('status', 'quality_score', 'submitted_at', 'reviewed_at')
    search_fields = (
        'id',
        'conversation_id',
        'user__username',
        'title',
        'description',
        'reviewer__username',
    )
@admin.register(ConversationEvaluation)
class ConversationEvaluationAdmin(admin.ModelAdmin):
    """Admin changelist for per-conversation evaluations."""

    date_hierarchy = 'created_at'
    list_display = (
        'id',
        'conversation_id',
        'user',
        'has_logical_issues',
        'needs_satisfied',
        'created_at',
    )
    list_filter = ('has_logical_issues', 'needs_satisfied', 'created_at')
    search_fields = ('id', 'conversation_id', 'user__username', 'overall_feeling')

View File

@ -1,368 +1,368 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from rlhf.models import Conversation, Message, Feedback, DetailedFeedback, FeedbackTag from rlhf.models import Conversation, Message, Feedback, DetailedFeedback, FeedbackTag
from django.db.models import Count, Avg, Sum, Q, F from django.db.models import Count, Avg, Sum, Q, F
from django.utils import timezone from django.utils import timezone
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
User = get_user_model() User = get_user_model()
class Command(BaseCommand):
    """Print an RLHF feedback report and optionally export the data as JSON.

    Options:
        --export: also write conversations, messages and feedback to a
            timestamped JSON file in the current working directory.
        --days: number of days covered by the daily trend (default 30).
    """

    help = '分析RLHF反馈数据生成统计报告'

    def add_arguments(self, parser):
        """Register the ``--export`` flag and the ``--days`` look-back window."""
        parser.add_argument(
            '--export',
            action='store_true',
            help='导出数据到JSON文件',
        )
        parser.add_argument(
            '--days',
            type=int,
            default=30,
            help='分析最近的天数',
        )

    def handle(self, *args, **options):
        """Compute each statistics section, print it, then export if requested."""
        write = self.stdout.write
        success = self.style.SUCCESS
        days = options['days']

        write(success("=" * 60))
        write(success("🤖 在线人类反馈强化学习系统 - 数据分析报告"))
        write(success("=" * 60))

        # Feedback totals
        feedback_stats = self.get_feedback_stats()
        write(success("\n📊 反馈统计:"))
        write(f" 总反馈数量: {feedback_stats['total_feedback']}")
        write(f" 正面反馈: {feedback_stats['positive_feedback']} ({feedback_stats['positive_rate']:.1f}%)")
        write(f" 负面反馈: {feedback_stats['negative_feedback']}")
        write(f" 平均反馈分数: {feedback_stats['avg_feedback']:.2f}")

        # Conversation totals
        conv_stats = self.get_conversation_stats()
        write(success("\n💬 对话统计:"))
        write(f" 总对话数量: {conv_stats['total_conversations']}")
        write(f" 总消息数量: {conv_stats['total_messages']}")
        write(f" 平均每对话消息数: {conv_stats['avg_messages_per_conversation']:.1f}")

        # Tag usage
        tag_stats = self.get_tag_stats()
        write(success("\n🏷️ 标签统计:"))
        write(" 最常用的正面标签:")
        for tag in tag_stats['top_positive']:
            write(f" - {tag['tag_name']}: {tag['count']}")
        write(" 最常用的负面标签:")
        for tag in tag_stats['top_negative']:
            write(f" - {tag['tag_name']}: {tag['count']}")

        # Daily trend
        daily_trend = self.get_daily_feedback_trend(days)
        write(success(f"\n📈 最近{days}天反馈趋势:"))
        for day in daily_trend:
            write(f" {day['date']}: {day['total']}条反馈 (正面率: {day['positive_rate']:.1f}%)")

        # Annotator activity
        user_stats = self.get_user_stats()
        write(success("\n👥 用户统计:"))
        write(f" 总用户数量: {user_stats['total_users']}")
        write(f" 活跃标注用户: {user_stats['active_users']}")
        write(f" 平均每用户标注量: {user_stats['avg_annotations_per_user']:.1f}")

        # Optional export; the exported trend honours --days as well.
        if options['export']:
            filename = self.export_data_to_json(days)
            write(success(f"\n✅ 数据已导出到: {filename}"))

    def get_feedback_stats(self):
        """Return combined totals for basic and detailed feedback.

        Returns:
            dict with ``total_feedback``, ``positive_feedback``,
            ``negative_feedback``, ``avg_feedback`` (average of the basic
            ``feedback_value`` only) and ``positive_rate`` (percent).
        """
        # Filtered Count keeps this to names the module already imports.
        basic = Feedback.objects.aggregate(
            total=Count('id'),
            positive=Count('id', filter=Q(feedback_value__gt=0)),
            negative=Count('id', filter=Q(feedback_value__lt=0)),
            avg=Avg('feedback_value'),
        )
        detailed = DetailedFeedback.objects.aggregate(
            total=Count('id'),
            positive=Count('id', filter=Q(feedback_type='positive')),
            negative=Count('id', filter=Q(feedback_type='negative')),
        )

        total = (basic['total'] or 0) + (detailed['total'] or 0)
        positive = (basic['positive'] or 0) + (detailed['positive'] or 0)
        negative = (basic['negative'] or 0) + (detailed['negative'] or 0)

        return {
            'total_feedback': total,
            'positive_feedback': positive,
            'negative_feedback': negative,
            # Avg is None when there are no basic feedback rows.
            'avg_feedback': basic['avg'] or 0,
            'positive_rate': (positive / total * 100) if total > 0 else 0,
        }

    def get_conversation_stats(self):
        """Return conversation and message totals plus the mean messages per conversation."""
        # NOTE(review): the admin module changed in this commit registers
        # RLHFConversation / NegotiationChat / ChatHistory instead of
        # Conversation / Message - confirm these two models still exist.
        per_conversation = Message.objects.values('conversation').annotate(count=Count('id'))
        avg_messages = per_conversation.aggregate(Avg('count'))['count__avg'] or 0

        return {
            'total_conversations': Conversation.objects.count(),
            'total_messages': Message.objects.count(),
            'avg_messages_per_conversation': avg_messages,
        }

    @staticmethod
    def _parse_tag_ids(raw):
        """Normalise a stored ``feedback_tags`` value into a list of tag-id strings.

        Accepts a JSON-encoded list, an already-decoded list, a JSON scalar,
        or plain non-JSON text (treated as one tag id). Anything else yields
        an empty list.
        """
        if not raw:
            return []
        parsed = raw
        if isinstance(raw, str):
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                # Not JSON at all: the text itself is a single tag id.
                return [raw]
        if isinstance(parsed, (list, tuple)):
            return [str(tag_id) for tag_id in parsed]
        if isinstance(parsed, (str, int)) and not isinstance(parsed, bool):
            # A JSON scalar such as "5" is a single tag id, not "no tags".
            return [str(parsed)]
        return []

    @staticmethod
    def _top_tags(tag_type, tag_counts, limit=5):
        """Return the ``limit`` most used tags of ``tag_type`` as name/count dicts."""
        ranked = [
            {'tag_name': tag.tag_name, 'count': tag_counts[str(tag.id)]}
            for tag in FeedbackTag.objects.filter(tag_type=tag_type)
            if str(tag.id) in tag_counts
        ]
        ranked.sort(key=lambda item: item['count'], reverse=True)
        return ranked[:limit]

    def get_tag_stats(self):
        """Return the five most used positive and negative tags.

        Usage is counted from ``DetailedFeedback.feedback_tags``; only ids
        that match an existing ``FeedbackTag`` appear in the result.
        """
        tag_counts = {}
        for feedback in DetailedFeedback.objects.iterator():
            for tag_id in self._parse_tag_ids(feedback.feedback_tags):
                tag_counts[tag_id] = tag_counts.get(tag_id, 0) + 1

        return {
            'top_positive': self._top_tags('positive', tag_counts),
            'top_negative': self._top_tags('negative', tag_counts),
        }

    @staticmethod
    def _daily_counts(queryset, date_lookup, start_date, positive, negative):
        """Yield ``(date, total, positive, negative)`` per calendar day.

        ``date_lookup`` is the ``<field>__date`` lookup to group by;
        ``positive`` / ``negative`` are ``Q`` filters selecting each class.
        """
        rows = (
            queryset.filter(**{f'{date_lookup}__gte': start_date})
            .values(date_lookup)
            .annotate(
                total=Count('id'),
                positive=Count('id', filter=positive),
                negative=Count('id', filter=negative),
            )
            # Explicit ordering also keeps a default Meta.ordering out of
            # the GROUP BY.
            .order_by(date_lookup)
        )
        for row in rows:
            yield row[date_lookup], row['total'], row['positive'], row['negative']

    def get_daily_feedback_trend(self, days=30):
        """Return per-day feedback totals for the last ``days`` days.

        Basic and detailed feedback are merged by date. Each entry has
        ``date`` (``YYYY-MM-DD``), ``total``, ``positive``, ``negative`` and
        ``positive_rate`` (percent); entries are sorted by date.
        """
        start_date = timezone.now().date() - timedelta(days=days)
        sources = (
            self._daily_counts(
                Feedback.objects, 'timestamp__date', start_date,
                Q(feedback_value__gt=0), Q(feedback_value__lt=0),
            ),
            self._daily_counts(
                DetailedFeedback.objects, 'created_at__date', start_date,
                Q(feedback_type='positive'), Q(feedback_type='negative'),
            ),
        )

        merged = {}
        for source in sources:
            for day, total, positive, negative in source:
                key = day.strftime('%Y-%m-%d')
                entry = merged.setdefault(
                    key, {'date': key, 'total': 0, 'positive': 0, 'negative': 0}
                )
                entry['total'] += total
                entry['positive'] += positive
                entry['negative'] += negative

        for entry in merged.values():
            entry['positive_rate'] = (
                (entry['positive'] / entry['total'] * 100) if entry['total'] > 0 else 0
            )

        return sorted(merged.values(), key=lambda entry: entry['date'])

    def get_user_stats(self):
        """Return user totals and the mean number of annotations per annotator.

        ``active_users`` counts users with any feedback in the last 30 days.
        """
        total_users = User.objects.count()

        users_with_feedback = User.objects.filter(
            Q(feedback__isnull=False) | Q(detailed_feedback__isnull=False)
        ).distinct().count()

        thirty_days_ago = timezone.now() - timedelta(days=30)
        active_users = User.objects.filter(
            Q(feedback__timestamp__gte=thirty_days_ago)
            | Q(detailed_feedback__created_at__gte=thirty_days_ago)
        ).distinct().count()

        # Only the user ids are needed, so avoid loading full model rows.
        annotators = set(Feedback.objects.values_list('user_id', flat=True))
        annotators.update(DetailedFeedback.objects.values_list('user_id', flat=True))
        total_annotations = Feedback.objects.count() + DetailedFeedback.objects.count()
        avg_annotations = total_annotations / len(annotators) if annotators else 0

        return {
            'total_users': total_users,
            'users_with_feedback': users_with_feedback,
            'active_users': active_users,
            'avg_annotations_per_user': avg_annotations,
        }

    def export_data_to_json(self, days=30):
        """Write conversations, messages, feedback and summary stats to a JSON file.

        Args:
            days: look-back window for the exported ``daily_trend``.

        Returns:
            The name of the file written in the current working directory
            (``rlhf_data_export_<timestamp>.json``).
        """
        # Function-scope import: Prefetch is not in the module's import list.
        from django.db.models import Prefetch

        data = {
            'conversations': [],
            'feedback_summary': self.get_feedback_stats(),
            'tag_stats': self.get_tag_stats(),
            'daily_trend': self.get_daily_feedback_trend(days),
            'export_time': timezone.now().isoformat(),
        }

        # Group all feedback by message up front: two queries in total
        # instead of two per exported message.
        basic_by_message = {}
        for fb in Feedback.objects.all():
            basic_by_message.setdefault(fb.message_id, []).append({
                'id': str(fb.id),
                'type': 'basic',
                'value': fb.feedback_value,
                'user_id': str(fb.user_id),
                'timestamp': fb.timestamp.isoformat(),
            })

        detailed_by_message = {}
        for dfb in DetailedFeedback.objects.all():
            try:
                tags = json.loads(dfb.feedback_tags) if dfb.feedback_tags else []
            except (json.JSONDecodeError, TypeError):
                tags = [dfb.feedback_tags] if dfb.feedback_tags else []
            detailed_by_message.setdefault(dfb.message_id, []).append({
                'id': str(dfb.id),
                'type': 'detailed',
                'feedback_type': dfb.feedback_type,
                'tags': tags,
                'custom_tags': dfb.custom_tags,
                'custom_content': dfb.custom_content,
                'is_inline': dfb.is_inline,
                'user_id': str(dfb.user_id),
                'timestamp': dfb.created_at.isoformat(),
            })

        # Ordering inside the Prefetch keeps the prefetched rows usable;
        # calling .order_by() on conv.messages would issue a query per
        # conversation.
        ordered_messages = Prefetch('messages', queryset=Message.objects.order_by('timestamp'))
        for conv in Conversation.objects.all().prefetch_related(ordered_messages):
            conv_data = {
                'id': str(conv.id),
                'created_at': conv.created_at.isoformat(),
                'user_id': str(conv.user_id),
                'is_submitted': conv.is_submitted,
                'messages': [],
            }
            for msg in conv.messages.all():
                conv_data['messages'].append({
                    'id': str(msg.id),
                    'role': msg.role,
                    'content': msg.content,
                    'timestamp': msg.timestamp.isoformat(),
                    # Basic feedback first, then detailed.
                    'feedback': (
                        basic_by_message.get(msg.id, [])
                        + detailed_by_message.get(msg.id, [])
                    ),
                })
            data['conversations'].append(conv_data)

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f'rlhf_data_export_{timestamp}.json'
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        return filename

View File

@ -1,289 +1,289 @@
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
from rlhf.models import ( from rlhf.models import (
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback, Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
ConversationSubmission, ConversationEvaluation, SystemConfig ConversationSubmission, ConversationEvaluation, SystemConfig
) )
from django.utils import timezone from django.utils import timezone
import json import json
import uuid import uuid
import os import os
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
User = get_user_model() User = get_user_model()
class Command(BaseCommand): class Command(BaseCommand):
help = '导入/导出RLHF数据' help = '导入/导出RLHF数据'
def add_arguments(self, parser):
    """Register the import/export path options and the two boolean flags."""
    # Options that take a file path value.
    path_options = (
        ('--import-file', '导入JSON数据文件的路径'),
        ('--export-file', '导出JSON数据文件的路径'),
    )
    for flag, help_text in path_options:
        parser.add_argument(flag, help=help_text)

    # Boolean switches.
    switch_options = (
        ('--import-tags', '从init_tags.py导入标签'),
        ('--clear', '在导入前清除现有数据'),
    )
    for flag, help_text in switch_options:
        parser.add_argument(flag, action='store_true', help=help_text)
def handle(self, *args, **options):
    """Run the first requested action: import, then export, then tag seeding.

    When none of the three options is set, a warning is written to stdout.
    """
    import_path = options['import_file']
    if import_path:
        self.import_data(import_path, options['clear'])
        return

    export_path = options['export_file']
    if export_path:
        self.export_data(export_path)
        return

    if options['import_tags']:
        self.import_tags_from_init_file()
        return

    self.stdout.write(self.style.WARNING('请指定导入文件路径或导出文件路径'))
def import_data(self, file_path, clear=False):
    """Import conversations and system configs from a JSON file.

    Args:
        file_path: path of the JSON file to read.
        clear: when true, existing data is removed before importing.

    Raises:
        CommandError: if the file does not exist or any step of the import
            fails (the original exception is chained as the cause).
    """
    # Function-scope import: the module does not import ``transaction``.
    from django.db import transaction

    if not os.path.exists(file_path):
        raise CommandError(f'文件不存在: {file_path}')

    def load_records(data):
        """Optionally clear existing rows, then import every record in ``data``."""
        if clear:
            self.stdout.write(self.style.WARNING('正在清除现有数据...'))
            self._clear_data()

        for conv_data in data.get('conversations', []):
            self._import_conversation(conv_data)

        # NOTE(review): export_data also writes a 'tags' list, which is not
        # read back here - confirm whether tags should be imported too.
        for config in data.get('system_configs', []):
            self._import_system_config(config)

    try:
        # Parse the file first so a malformed file never triggers the clear.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if clear:
            # The delete and the import succeed or fail together, so a
            # failed import cannot leave the tables empty.
            with transaction.atomic():
                load_records(data)
        else:
            load_records(data)

        self.stdout.write(self.style.SUCCESS(f'成功导入数据从: {file_path}'))
    except Exception as e:
        raise CommandError(f'导入数据失败: {str(e)}') from e
def export_data(self, file_path):
    """Export tags, system configs and all conversations to a JSON file.

    Each conversation carries its messages ordered by ``timestamp``; each
    message carries its basic feedback followed by its detailed feedback.

    Args:
        file_path: path of the JSON file to write (UTF-8, indented).

    Raises:
        CommandError: if collecting or writing the data fails (the original
            exception is chained as the cause).
    """
    # Function-scope import: the module does not import ``Prefetch``.
    from django.db.models import Prefetch

    try:
        data = {
            'conversations': [],
            'system_configs': [],
            'tags': [],
            'export_time': timezone.now().isoformat(),
        }

        # Tags
        for tag in FeedbackTag.objects.all():
            data['tags'].append({
                'id': str(tag.id),
                'tag_name': tag.tag_name,
                'tag_type': tag.tag_type,
                'description': tag.description,
                'created_at': tag.created_at.isoformat(),
            })

        # System configuration
        for config in SystemConfig.objects.all():
            data['system_configs'].append({
                'id': str(config.id),
                'config_key': config.config_key,
                'config_value': config.config_value,
                'config_type': config.config_type,
                'description': config.description,
                'created_at': config.created_at.isoformat(),
                'updated_at': config.updated_at.isoformat(),
            })

        # Group all feedback by message up front: two queries in total
        # instead of two per exported message.
        basic_by_message = {}
        for fb in Feedback.objects.all():
            basic_by_message.setdefault(fb.message_id, []).append({
                'id': str(fb.id),
                'type': 'basic',
                'feedback_value': fb.feedback_value,
                'user_id': str(fb.user_id),
                'timestamp': fb.timestamp.isoformat(),
            })

        detailed_by_message = {}
        for dfb in DetailedFeedback.objects.all():
            detailed_by_message.setdefault(dfb.message_id, []).append({
                'id': str(dfb.id),
                'type': 'detailed',
                'feedback_type': dfb.feedback_type,
                'feedback_tags': dfb.feedback_tags,
                'custom_tags': dfb.custom_tags,
                'custom_content': dfb.custom_content,
                'is_inline': dfb.is_inline,
                'user_id': str(dfb.user_id),
                'created_at': dfb.created_at.isoformat(),
                'updated_at': dfb.updated_at.isoformat(),
            })

        # Ordering inside the Prefetch keeps the prefetched rows usable;
        # calling .order_by() on conv.messages would issue a query per
        # conversation.
        ordered_messages = Prefetch('messages', queryset=Message.objects.order_by('timestamp'))
        for conv in Conversation.objects.all().prefetch_related(ordered_messages):
            conv_data = {
                'id': str(conv.id),
                'user_id': str(conv.user_id),
                'is_submitted': conv.is_submitted,
                'created_at': conv.created_at.isoformat(),
                'messages': [],
            }
            for msg in conv.messages.all():
                conv_data['messages'].append({
                    'id': str(msg.id),
                    'role': msg.role,
                    'content': msg.content,
                    'timestamp': msg.timestamp.isoformat(),
                    # Basic feedback first, then detailed.
                    'feedback': (
                        basic_by_message.get(msg.id, [])
                        + detailed_by_message.get(msg.id, [])
                    ),
                })
            data['conversations'].append(conv_data)

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        self.stdout.write(self.style.SUCCESS(f'成功导出数据到: {file_path}'))
    except Exception as e:
        raise CommandError(f'导出数据失败: {str(e)}') from e
def import_tags_from_init_file(self):
    """Seed the built-in positive and negative feedback tags.

    The tag definitions are hard-coded below. Each one is inserted with
    ``get_or_create`` keyed on ``tag_name``, so a tag that already exists is
    left as it is. Any failure is re-raised as ``CommandError``.
    """
    # (tag_type, ((tag_name, description), ...)); positive tags are inserted first.
    tag_groups = (
        ('positive', (
            ('有帮助', '回答对问题有实际帮助'),
            ('准确', '信息准确可靠'),
            ('清晰', '表达清楚易懂'),
            ('完整', '回答全面完整'),
            ('友好', '语调友善亲切'),
            ('创新', '提供了新颖的观点'),
        )),
        ('negative', (
            ('不准确', '包含错误信息'),
            ('不相关', '回答偏离主题'),
            ('不完整', '回答过于简略'),
            ('不清晰', '表达模糊难懂'),
            ('不友好', '语调生硬冷淡'),
            ('重复', '内容重复冗余'),
        )),
    )
    try:
        for tag_type, tags in tag_groups:
            for tag_name, description in tags:
                # The defaults are only applied when the tag has to be created.
                FeedbackTag.objects.get_or_create(
                    tag_name=tag_name,
                    defaults={
                        'id': str(uuid.uuid4()),
                        'tag_type': tag_type,
                        'description': description,
                        'created_at': timezone.now(),
                    },
                )
        self.stdout.write(self.style.SUCCESS('成功导入标签'))
    except Exception as e:
        raise CommandError(f'导入标签失败: {str(e)}')
def _clear_data(self):
    """Delete all rows of the models listed below, then report success."""
    # The order is significant and kept from the original: feedback rows,
    # then messages and conversations, then submissions and evaluations.
    models_in_delete_order = (
        DetailedFeedback,
        Feedback,
        Message,
        Conversation,
        ConversationSubmission,
        ConversationEvaluation,
    )
    for model in models_in_delete_order:
        model.objects.all().delete()
    self.stdout.write(self.style.SUCCESS('已清除现有数据'))
def _import_conversation(self, conv_data):
    """Import one exported conversation with its messages and feedback.

    Args:
        conv_data: dict describing a single conversation. Keys read here are
            ``id``, ``user_id``, ``is_submitted``, ``created_at`` and
            ``messages``. Each message may carry ``id``, ``role``,
            ``content``, ``timestamp`` and a ``feedback`` list whose items
            have ``type`` equal to ``'basic'`` or ``'detailed'``; items with
            any other type are ignored.

    Raises:
        CommandError: the referenced user does not exist and no superuser is
            available as a fallback.
        ValueError: a timestamp string is well formed but not a valid date.
    """
    # django.utils.timezone has no parse_datetime(); the ISO-8601 parser
    # lives in django.utils.dateparse.
    from django.utils.dateparse import parse_datetime

    def parse_dt(value):
        """Parse an ISO-8601 string, using the current time when it is missing or malformed."""
        parsed = parse_datetime(value) if value else None
        return parsed if parsed is not None else timezone.now()

    # NOTE(review): the models module in this change set comments out
    # Conversation and Message in favour of NegotiationChat/ChatHistory;
    # confirm these names still resolve for this command.

    # Resolve the owning user, falling back to the first superuser.
    user_id = conv_data.get('user_id')
    if not User.objects.filter(id=user_id).exists():
        self.stdout.write(self.style.WARNING(f'用户不存在: {user_id},将使用第一个管理员用户'))
        user = User.objects.filter(is_superuser=True).first()
        if not user:
            raise CommandError('找不到管理员用户')
        user_id = user.id

    # Create the conversation.
    conv = Conversation.objects.create(
        id=conv_data.get('id', str(uuid.uuid4())),
        user_id=user_id,
        is_submitted=conv_data.get('is_submitted', False),
        created_at=parse_dt(conv_data.get('created_at')),
    )

    # Create the messages and the feedback attached to each of them.
    for msg_data in conv_data.get('messages', []):
        msg = Message.objects.create(
            id=msg_data.get('id', str(uuid.uuid4())),
            conversation=conv,
            role=msg_data.get('role', 'user'),
            content=msg_data.get('content', ''),
            timestamp=parse_dt(msg_data.get('timestamp')),
        )
        for fb_data in msg_data.get('feedback', []):
            fb_type = fb_data.get('type')
            if fb_type == 'basic':
                Feedback.objects.create(
                    id=fb_data.get('id', str(uuid.uuid4())),
                    message=msg,
                    conversation=conv,
                    user_id=fb_data.get('user_id', user_id),
                    feedback_value=fb_data.get('feedback_value', 0),
                    timestamp=parse_dt(fb_data.get('timestamp')),
                )
            elif fb_type == 'detailed':
                DetailedFeedback.objects.create(
                    id=fb_data.get('id', str(uuid.uuid4())),
                    message=msg,
                    conversation=conv,
                    user_id=fb_data.get('user_id', user_id),
                    feedback_type=fb_data.get('feedback_type', 'neutral'),
                    feedback_tags=fb_data.get('feedback_tags', '[]'),
                    custom_tags=fb_data.get('custom_tags', ''),
                    custom_content=fb_data.get('custom_content', ''),
                    is_inline=fb_data.get('is_inline', True),
                    created_at=parse_dt(fb_data.get('created_at')),
                    updated_at=parse_dt(fb_data.get('updated_at')),
                )
def _import_system_config(self, config_data):
    """Create or update one SystemConfig row, matched on ``config_key``.

    Args:
        config_data: dict with ``config_key`` and, optionally, ``id``,
            ``config_value``, ``config_type``, ``description``,
            ``created_at`` and ``updated_at``.

    Raises:
        ValueError: a timestamp string is well formed but not a valid date.
    """
    # django.utils.timezone has no parse_datetime(); the ISO-8601 parser
    # lives in django.utils.dateparse.
    from django.utils.dateparse import parse_datetime

    def parse_dt(value):
        """Parse an ISO-8601 string, using the current time when it is missing or malformed."""
        parsed = parse_datetime(value) if value else None
        return parsed if parsed is not None else timezone.now()

    values = {
        'config_value': config_data.get('config_value', ''),
        'config_type': config_data.get('config_type', 'string'),
        'description': config_data.get('description', ''),
        'created_at': parse_dt(config_data.get('created_at')),
        'updated_at': parse_dt(config_data.get('updated_at')),
    }
    # The id is applied only on creation: assigning a new primary key to an
    # existing row makes save() target a different row instead of updating
    # the matched one. create_defaults requires Django 5.0+.
    SystemConfig.objects.update_or_create(
        config_key=config_data.get('config_key'),
        defaults=values,
        create_defaults={**values, 'id': config_data.get('id', str(uuid.uuid4()))},
    )

View File

@ -1,4 +1,4 @@
# Generated by Django 5.1.5 on 2025-06-09 08:28 # Generated by Django 5.2.1 on 2025-06-10 08:10
import django.db.models.deletion import django.db.models.deletion
import django.utils.timezone import django.utils.timezone
@ -12,6 +12,7 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = [
('chat', '0002_negotiationchat'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL), migrations.swappable_dependency(settings.AUTH_USER_MODEL),
] ]
@ -30,6 +31,17 @@ class Migration(migrations.Migration):
'verbose_name_plural': '反馈标签', 'verbose_name_plural': '反馈标签',
}, },
), ),
migrations.CreateModel(
name='RLHFConversation',
fields=[
('negotiation_chat', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, related_name='rlhf_extension', serialize=False, to='chat.negotiationchat')),
('is_submitted', models.BooleanField(default=False)),
],
options={
'verbose_name': 'RLHF对话',
'verbose_name_plural': 'RLHF对话',
},
),
migrations.CreateModel( migrations.CreateModel(
name='SystemConfig', name='SystemConfig',
fields=[ fields=[
@ -46,23 +58,11 @@ class Migration(migrations.Migration):
'verbose_name_plural': '系统配置', 'verbose_name_plural': '系统配置',
}, },
), ),
migrations.CreateModel(
name='Conversation',
fields=[
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
('is_submitted', models.BooleanField(default=False)),
('created_at', models.DateTimeField(default=django.utils.timezone.now)),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='conversations', to=settings.AUTH_USER_MODEL)),
],
options={
'verbose_name': '对话',
'verbose_name_plural': '对话',
},
),
migrations.CreateModel( migrations.CreateModel(
name='ConversationSubmission', name='ConversationSubmission',
fields=[ fields=[
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)), ('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
('conversation_id', models.CharField(max_length=100)),
('title', models.CharField(blank=True, max_length=255, null=True)), ('title', models.CharField(blank=True, max_length=255, null=True)),
('description', models.TextField(blank=True, null=True)), ('description', models.TextField(blank=True, null=True)),
('status', models.CharField(choices=[('submitted', '已提交'), ('reviewed', '已审核'), ('accepted', '已接受'), ('rejected', '已拒绝')], default='submitted', max_length=20)), ('status', models.CharField(choices=[('submitted', '已提交'), ('reviewed', '已审核'), ('accepted', '已接受'), ('rejected', '已拒绝')], default='submitted', max_length=20)),
@ -72,7 +72,6 @@ class Migration(migrations.Migration):
('reviewed_at', models.DateTimeField(blank=True, null=True)), ('reviewed_at', models.DateTimeField(blank=True, null=True)),
('created_at', models.DateTimeField(default=django.utils.timezone.now)), ('created_at', models.DateTimeField(default=django.utils.timezone.now)),
('updated_at', models.DateTimeField(default=django.utils.timezone.now)), ('updated_at', models.DateTimeField(default=django.utils.timezone.now)),
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='submissions', to='rlhf.conversation')),
('reviewer', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='reviewed_submissions', to=settings.AUTH_USER_MODEL)), ('reviewer', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='reviewed_submissions', to=settings.AUTH_USER_MODEL)),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='submissions', to=settings.AUTH_USER_MODEL)), ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='submissions', to=settings.AUTH_USER_MODEL)),
], ],
@ -81,40 +80,11 @@ class Migration(migrations.Migration):
'verbose_name_plural': '对话提交', 'verbose_name_plural': '对话提交',
}, },
), ),
migrations.CreateModel(
name='Message',
fields=[
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
('role', models.CharField(choices=[('user', '用户'), ('assistant', '助手'), ('system', '系统')], max_length=20)),
('content', models.TextField()),
('timestamp', models.DateTimeField(default=django.utils.timezone.now)),
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='messages', to='rlhf.conversation')),
],
options={
'verbose_name': '消息',
'verbose_name_plural': '消息',
'ordering': ['timestamp'],
},
),
migrations.CreateModel(
name='Feedback',
fields=[
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
('feedback_value', models.IntegerField()),
('timestamp', models.DateTimeField(default=django.utils.timezone.now)),
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to='rlhf.conversation')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to=settings.AUTH_USER_MODEL)),
('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to='rlhf.message')),
],
options={
'verbose_name': '反馈',
'verbose_name_plural': '反馈',
},
),
migrations.CreateModel( migrations.CreateModel(
name='DetailedFeedback', name='DetailedFeedback',
fields=[ fields=[
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)), ('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
('conversation_id', models.CharField(max_length=100)),
('feedback_type', models.CharField(choices=[('positive', '正面'), ('negative', '负面'), ('neutral', '中性')], max_length=20)), ('feedback_type', models.CharField(choices=[('positive', '正面'), ('negative', '负面'), ('neutral', '中性')], max_length=20)),
('feedback_tags', models.TextField(blank=True, null=True)), ('feedback_tags', models.TextField(blank=True, null=True)),
('custom_tags', models.TextField(blank=True, null=True)), ('custom_tags', models.TextField(blank=True, null=True)),
@ -122,31 +92,45 @@ class Migration(migrations.Migration):
('is_inline', models.BooleanField(default=True)), ('is_inline', models.BooleanField(default=True)),
('created_at', models.DateTimeField(default=django.utils.timezone.now)), ('created_at', models.DateTimeField(default=django.utils.timezone.now)),
('updated_at', models.DateTimeField(default=django.utils.timezone.now)), ('updated_at', models.DateTimeField(default=django.utils.timezone.now)),
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='detailed_feedback', to='rlhf.conversation')), ('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rlhf_detailed_feedback', to='chat.chathistory')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='detailed_feedback', to=settings.AUTH_USER_MODEL)), ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='detailed_feedback', to=settings.AUTH_USER_MODEL)),
('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='detailed_feedback', to='rlhf.message')),
], ],
options={ options={
'verbose_name': '详细反馈', 'verbose_name': '详细反馈',
'verbose_name_plural': '详细反馈', 'verbose_name_plural': '详细反馈',
}, },
), ),
migrations.CreateModel(
name='Feedback',
fields=[
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
('conversation_id', models.CharField(max_length=100)),
('feedback_value', models.IntegerField()),
('timestamp', models.DateTimeField(default=django.utils.timezone.now)),
('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rlhf_feedback', to='chat.chathistory')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to=settings.AUTH_USER_MODEL)),
],
options={
'verbose_name': '反馈',
'verbose_name_plural': '反馈',
},
),
migrations.CreateModel( migrations.CreateModel(
name='ConversationEvaluation', name='ConversationEvaluation',
fields=[ fields=[
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)), ('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
('conversation_id', models.CharField(max_length=100)),
('overall_feeling', models.TextField(blank=True, null=True)), ('overall_feeling', models.TextField(blank=True, null=True)),
('has_logical_issues', models.CharField(choices=[('yes', ''), ('no', ''), ('unsure', '不确定')], max_length=10)), ('has_logical_issues', models.CharField(choices=[('yes', ''), ('no', ''), ('unsure', '不确定')], max_length=10)),
('needs_satisfied', models.CharField(choices=[('yes', ''), ('no', ''), ('partially', '部分')], max_length=10)), ('needs_satisfied', models.CharField(choices=[('yes', ''), ('no', ''), ('partially', '部分')], max_length=10)),
('created_at', models.DateTimeField(default=django.utils.timezone.now)), ('created_at', models.DateTimeField(default=django.utils.timezone.now)),
('updated_at', models.DateTimeField(default=django.utils.timezone.now)), ('updated_at', models.DateTimeField(default=django.utils.timezone.now)),
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='evaluations', to='rlhf.conversation')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='evaluations', to=settings.AUTH_USER_MODEL)), ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='evaluations', to=settings.AUTH_USER_MODEL)),
], ],
options={ options={
'verbose_name': '对话评估', 'verbose_name': '对话评估',
'verbose_name_plural': '对话评估', 'verbose_name_plural': '对话评估',
'unique_together': {('conversation', 'user')}, 'unique_together': {('conversation_id', 'user')},
}, },
), ),
] ]

View File

@ -2,60 +2,91 @@ from django.db import models
import uuid import uuid
from django.utils import timezone from django.utils import timezone
from apps.user.models import User from apps.user.models import User
from apps.chat.models import ChatHistory, NegotiationChat
from apps.daren_detail.models import CreatorProfile
from apps.brands.models import Product
class Conversation(models.Model): # 使用NegotiationChat替代Conversation
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) # class Conversation(models.Model):
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='conversations') # id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
# user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='conversations')
# is_submitted = models.BooleanField(default=False)
# created_at = models.DateTimeField(default=timezone.now)
#
# class Meta:
# verbose_name = '对话'
# verbose_name_plural = '对话'
#
# def __str__(self):
# return f"Conversation {self.id[:8]}"
# 为NegotiationChat添加RLHF所需字段的代理模型
class RLHFConversation(models.Model):
    """One-to-one extension of NegotiationChat holding RLHF-specific state.

    The linked chat is also this model's primary key. ``id``, ``user`` and
    ``created_at`` are read-only properties that forward to that chat.
    """

    negotiation_chat = models.OneToOneField(
        NegotiationChat,
        on_delete=models.CASCADE,
        primary_key=True,
        related_name='rlhf_extension',
    )
    is_submitted = models.BooleanField(default=False)

    class Meta:
        verbose_name = 'RLHF对话'
        verbose_name_plural = 'RLHF对话'

    def __str__(self):
        short_id = self.negotiation_chat.conversation_id[:8]
        return f"RLHF Conversation {short_id}"

    @property
    def id(self):
        """The conversation_id of the linked chat."""
        chat = self.negotiation_chat
        return chat.conversation_id

    @property
    def user(self):
        """The user of the negotiation the linked chat belongs to."""
        chat = self.negotiation_chat
        return chat.negotiation.user

    @property
    def created_at(self):
        """The creation time of the linked chat."""
        chat = self.negotiation_chat
        return chat.created_at
class Message(models.Model): # 使用ChatHistory替代Message
ROLE_CHOICES = ( # class Message(models.Model):
('user', '用户'), # ROLE_CHOICES = (
('assistant', '助手'), # ('user', '用户'),
('system', '系统'), # ('assistant', '助手'),
) # ('system', '系统'),
# )
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) #
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages') # id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
role = models.CharField(max_length=20, choices=ROLE_CHOICES) # conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages')
content = models.TextField() # role = models.CharField(max_length=20, choices=ROLE_CHOICES)
timestamp = models.DateTimeField(default=timezone.now) # content = models.TextField()
# timestamp = models.DateTimeField(default=timezone.now)
class Meta: #
# class Meta:
verbose_name = '消息' #
verbose_name_plural = '消息' # verbose_name = '消息'
ordering = ['timestamp'] # verbose_name_plural = '消息'
# ordering = ['timestamp']
def __str__(self): #
return f"{self.role}: {self.content[:50]}..." # def __str__(self):
# return f"{self.role}: {self.content[:50]}..."
class Feedback(models.Model):
    """A single numeric feedback value a user left on one chat message."""

    id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
    message = models.ForeignKey(ChatHistory, on_delete=models.CASCADE, related_name='rlhf_feedback')
    # Stores the conversation_id of the NegotiationChat; a plain string, not a foreign key.
    conversation_id = models.CharField(max_length=100)
    user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='feedback')
    feedback_value = models.IntegerField()
    timestamp = models.DateTimeField(default=timezone.now)

    class Meta:
        verbose_name = '反馈'
        verbose_name_plural = '反馈'

    def __str__(self):
        # message_id is the stored foreign-key value, so no query is needed
        # to load the related message just to print its key.
        return f"Feedback on {self.message_id}"
class FeedbackTag(models.Model): class FeedbackTag(models.Model):
@ -71,7 +102,6 @@ class FeedbackTag(models.Model):
created_at = models.DateTimeField(default=timezone.now) created_at = models.DateTimeField(default=timezone.now)
class Meta: class Meta:
verbose_name = '反馈标签' verbose_name = '反馈标签'
verbose_name_plural = '反馈标签' verbose_name_plural = '反馈标签'
@ -87,8 +117,8 @@ class DetailedFeedback(models.Model):
) )
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
message = models.ForeignKey(Message, on_delete=models.CASCADE, related_name='detailed_feedback') message = models.ForeignKey(ChatHistory, on_delete=models.CASCADE, related_name='rlhf_detailed_feedback')
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='detailed_feedback') conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='detailed_feedback') user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='detailed_feedback')
feedback_type = models.CharField(max_length=20, choices=FEEDBACK_TYPE_CHOICES) feedback_type = models.CharField(max_length=20, choices=FEEDBACK_TYPE_CHOICES)
feedback_tags = models.TextField(blank=True, null=True) # JSON格式存储多个标签 feedback_tags = models.TextField(blank=True, null=True) # JSON格式存储多个标签
@ -99,12 +129,11 @@ class DetailedFeedback(models.Model):
updated_at = models.DateTimeField(default=timezone.now) updated_at = models.DateTimeField(default=timezone.now)
class Meta: class Meta:
verbose_name = '详细反馈' verbose_name = '详细反馈'
verbose_name_plural = '详细反馈' verbose_name_plural = '详细反馈'
def __str__(self): def __str__(self):
return f"{self.feedback_type} feedback on {self.message.id[:8]}" return f"{self.feedback_type} feedback on {self.message.id}"
class ConversationSubmission(models.Model): class ConversationSubmission(models.Model):
@ -116,7 +145,7 @@ class ConversationSubmission(models.Model):
) )
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='submissions') conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='submissions') user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='submissions')
title = models.CharField(max_length=255, blank=True, null=True) title = models.CharField(max_length=255, blank=True, null=True)
description = models.TextField(blank=True, null=True) description = models.TextField(blank=True, null=True)
@ -130,12 +159,11 @@ class ConversationSubmission(models.Model):
updated_at = models.DateTimeField(default=timezone.now) updated_at = models.DateTimeField(default=timezone.now)
class Meta: class Meta:
verbose_name = '对话提交' verbose_name = '对话提交'
verbose_name_plural = '对话提交' verbose_name_plural = '对话提交'
def __str__(self): def __str__(self):
return f"Submission for {self.conversation.id[:8]}" return f"Submission for {self.conversation_id[:8]}"
class ConversationEvaluation(models.Model): class ConversationEvaluation(models.Model):
@ -152,7 +180,7 @@ class ConversationEvaluation(models.Model):
) )
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False) id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='evaluations') conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='evaluations') user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='evaluations')
overall_feeling = models.TextField(blank=True, null=True) overall_feeling = models.TextField(blank=True, null=True)
has_logical_issues = models.CharField(max_length=10, choices=LOGICAL_CHOICES) has_logical_issues = models.CharField(max_length=10, choices=LOGICAL_CHOICES)
@ -161,13 +189,12 @@ class ConversationEvaluation(models.Model):
updated_at = models.DateTimeField(default=timezone.now) updated_at = models.DateTimeField(default=timezone.now)
class Meta: class Meta:
verbose_name = '对话评估' verbose_name = '对话评估'
verbose_name_plural = '对话评估' verbose_name_plural = '对话评估'
unique_together = ('conversation', 'user') unique_together = ('conversation_id', 'user')
def __str__(self): def __str__(self):
return f"Evaluation for {self.conversation.id[:8]}" return f"Evaluation for {self.conversation_id[:8]}"
class SystemConfig(models.Model): class SystemConfig(models.Model):

View File

@ -1,29 +1,43 @@
from rest_framework import serializers from rest_framework import serializers
from .models import ( from .models import (
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback, Feedback, FeedbackTag, DetailedFeedback,
ConversationSubmission, ConversationEvaluation, SystemConfig ConversationSubmission, ConversationEvaluation, SystemConfig,
RLHFConversation, NegotiationChat, ChatHistory
) )
from apps.user.serializers import UserSerializer from apps.user.serializers import UserSerializer
class ConversationSerializer(serializers.ModelSerializer):
    """Read-only view of a NegotiationChat in the legacy conversation shape.

    ``id`` maps to the chat's ``conversation_id``, ``user`` to the user of
    its negotiation, and ``is_submitted`` comes from the optional
    RLHFConversation extension row.
    """

    id = serializers.CharField(source='conversation_id', read_only=True)
    user = serializers.PrimaryKeyRelatedField(source='negotiation.user', read_only=True)
    is_submitted = serializers.SerializerMethodField()
    # No ``source`` here: DRF asserts when ``source`` equals the field name.
    created_at = serializers.DateTimeField(read_only=True)

    class Meta:
        model = NegotiationChat
        fields = ['id', 'user', 'is_submitted', 'created_at']
        read_only_fields = ['id', 'created_at', 'user']

    def get_is_submitted(self, obj):
        """Return the extension's flag, or False when no extension row exists."""
        try:
            return obj.rlhf_extension.is_submitted
        except RLHFConversation.DoesNotExist:
            return False
class MessageSerializer(serializers.ModelSerializer):
    """Serialize a ChatHistory row in the legacy message shape.

    ``conversation`` exposes the row's ``conversation_id`` and ``timestamp``
    exposes its ``created_at``; both are read-only.
    """

    conversation = serializers.CharField(source='conversation_id', read_only=True)
    timestamp = serializers.DateTimeField(source='created_at', read_only=True)

    class Meta:
        model = ChatHistory
        fields = ('id', 'conversation', 'role', 'content', 'timestamp')
        read_only_fields = ('id', 'timestamp')
class FeedbackSerializer(serializers.ModelSerializer):
    """Serialize Feedback rows; ``id`` and ``timestamp`` are read-only."""

    class Meta:
        model = Feedback
        fields = (
            'id',
            'message',
            'conversation_id',
            'user',
            'feedback_value',
            'timestamp',
        )
        read_only_fields = ('id', 'timestamp')
@ -37,7 +51,11 @@ class FeedbackTagSerializer(serializers.ModelSerializer):
class DetailedFeedbackSerializer(serializers.ModelSerializer):
    """Serialize DetailedFeedback rows; ``id`` and both timestamps are read-only."""

    class Meta:
        model = DetailedFeedback
        fields = (
            'id',
            'message',
            'conversation_id',
            'user',
            'feedback_type',
            'feedback_tags',
            'custom_tags',
            'custom_content',
            'is_inline',
            'created_at',
            'updated_at',
        )
        read_only_fields = ('id', 'created_at', 'updated_at')
@ -47,38 +65,72 @@ class ConversationSubmissionSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ConversationSubmission model = ConversationSubmission
fields = ['id', 'conversation', 'user', 'user_details', 'title', 'description', 'status', 'quality_score', 'reviewer', 'reviewer_details', 'reviewer_notes', 'submitted_at', 'reviewed_at', 'created_at', 'updated_at'] fields = [
'id', 'conversation_id', 'user', 'title', 'description',
'status', 'quality_score', 'reviewer', 'reviewer_details',
'reviewer_notes', 'submitted_at', 'reviewed_at', 'created_at', 'updated_at'
]
read_only_fields = ['id', 'submitted_at', 'reviewed_at', 'created_at', 'updated_at'] read_only_fields = ['id', 'submitted_at', 'reviewed_at', 'created_at', 'updated_at']
class ConversationEvaluationSerializer(serializers.ModelSerializer):
    """Serialize ConversationEvaluation rows; ``id`` and both timestamps are read-only."""

    class Meta:
        model = ConversationEvaluation
        fields = (
            'id',
            'conversation_id',
            'user',
            'overall_feeling',
            'has_logical_issues',
            'needs_satisfied',
            'created_at',
            'updated_at',
        )
        read_only_fields = ('id', 'created_at', 'updated_at')
class SystemConfigSerializer(serializers.ModelSerializer):
    """Serialize SystemConfig rows; ``id`` and both timestamps are read-only."""

    class Meta:
        model = SystemConfig
        fields = (
            'id',
            'config_key',
            'config_value',
            'config_type',
            'description',
            'created_at',
            'updated_at',
        )
        read_only_fields = ('id', 'created_at', 'updated_at')
class ConversationWithMessagesSerializer(serializers.ModelSerializer):
    """Read-only view of a NegotiationChat together with its messages.

    ``messages`` holds the ChatHistory rows that share the chat's
    ``conversation_id``, ordered by ``created_at``.
    """

    id = serializers.CharField(source='conversation_id', read_only=True)
    user = serializers.PrimaryKeyRelatedField(source='negotiation.user', read_only=True)
    is_submitted = serializers.SerializerMethodField()
    # No ``source`` here: DRF asserts when ``source`` equals the field name.
    created_at = serializers.DateTimeField(read_only=True)
    messages = serializers.SerializerMethodField()

    class Meta:
        model = NegotiationChat
        fields = ['id', 'user', 'is_submitted', 'created_at', 'messages']
        read_only_fields = ['id', 'created_at']

    def get_is_submitted(self, obj):
        """Return the extension's flag, or False when no extension row exists."""
        try:
            return obj.rlhf_extension.is_submitted
        except RLHFConversation.DoesNotExist:
            return False

    def get_messages(self, obj):
        """Return the chat's messages, oldest first, as serialized data."""
        messages = ChatHistory.objects.filter(
            conversation_id=obj.conversation_id
        ).order_by('created_at')
        return MessageSerializer(messages, many=True).data
class MessageWithFeedbackSerializer(serializers.ModelSerializer):
    """Serialize a ``ChatHistory`` message along with the feedback given on it.

    ``conversation`` and ``timestamp`` are aliases for the model's
    ``conversation_id`` and ``created_at`` attributes.
    """

    conversation = serializers.CharField(source='conversation_id', read_only=True)
    timestamp = serializers.DateTimeField(source='created_at', read_only=True)
    feedback = serializers.SerializerMethodField()
    detailed_feedback = serializers.SerializerMethodField()

    class Meta:
        model = ChatHistory
        read_only_fields = ['id', 'timestamp']
        fields = [
            'id',
            'conversation',
            'role',
            'content',
            'timestamp',
            'feedback',
            'detailed_feedback',
        ]

    def get_feedback(self, obj):
        """Return the serialized ``Feedback`` rows whose message is ``obj``."""
        rows = Feedback.objects.filter(message=obj)
        return FeedbackSerializer(rows, many=True).data

    def get_detailed_feedback(self, obj):
        """Return the serialized ``DetailedFeedback`` rows whose message is ``obj``."""
        rows = DetailedFeedback.objects.filter(message=obj)
        return DetailedFeedbackSerializer(rows, many=True).data

View File

@ -5,8 +5,9 @@ from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.pagination import PageNumberPagination from rest_framework.pagination import PageNumberPagination
from .models import ( from .models import (
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback, Feedback, FeedbackTag, DetailedFeedback,
ConversationSubmission, ConversationEvaluation, SystemConfig ConversationSubmission, ConversationEvaluation, SystemConfig,
NegotiationChat, ChatHistory, RLHFConversation, CreatorProfile, Product
) )
from .serializers import ( from .serializers import (
ConversationSerializer, MessageSerializer, FeedbackSerializer, ConversationSerializer, MessageSerializer, FeedbackSerializer,
@ -98,22 +99,44 @@ class StandardResponseMixin:
class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet): class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
queryset = Conversation.objects.all() queryset = NegotiationChat.objects.all()
serializer_class = ConversationSerializer serializer_class = ConversationSerializer
authentication_classes = [CustomTokenAuthentication] authentication_classes = [CustomTokenAuthentication]
permission_classes = [IsAuthenticated] permission_classes = [IsAuthenticated]
def get_queryset(self):
    """Return the requesting user's chats, most recently updated first."""
    own_chats = NegotiationChat.objects.filter(negotiation__user=self.request.user)
    return own_chats.order_by('-updated_at')
def perform_create(self, serializer):
    """Create a new conversation for the requesting user.

    Three rows are written: a ``Negotiation`` owned by the user, a
    ``NegotiationChat`` with a fresh UUID ``conversation_id``, and an
    ``RLHFConversation`` marking the chat as not submitted. The chat is
    linked to the first ``CreatorProfile`` and the first ``Product`` found.

    ``serializer.save()`` is never called; the new chat is attached as
    ``serializer.instance`` instead.
    """
    # Function-scope import so this method does not depend on the module header.
    from django.db import transaction

    creator = CreatorProfile.objects.first()
    product = Product.objects.first()

    # All three rows must exist together. Without the transaction, a failure
    # on a later insert would leave an orphaned Negotiation or chat behind.
    with transaction.atomic():
        negotiation = Negotiation.objects.create(
            user=self.request.user,
            status='active',
        )
        chat = NegotiationChat.objects.create(
            negotiation=negotiation,
            conversation_id=str(uuid.uuid4()),
            creator=creator,
            product=product,
        )
        RLHFConversation.objects.create(
            negotiation_chat=chat,
            is_submitted=False,
        )

    serializer.instance = chat
@action(detail=True, methods=['get'])
def messages(self, request, pk=None):
    """Return every message of the selected conversation, oldest first."""
    chat = self.get_object()
    history = ChatHistory.objects.filter(conversation_id=chat.conversation_id)
    payload = MessageSerializer(history.order_by('created_at'), many=True).data
    return self.get_standard_response(data=payload)
@ -130,27 +153,26 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
status_code=status.HTTP_400_BAD_REQUEST status_code=status.HTTP_400_BAD_REQUEST
) )
# 创建用户消息 knowledge_base = KnowledgeBase.objects.first()
user_message = Message.objects.create(
id=str(uuid.uuid4()), user_message = ChatHistory.objects.create(
conversation=conversation, user=request.user,
knowledge_base=knowledge_base,
conversation_id=conversation.conversation_id,
role='user', role='user',
content=content content=content
) )
# 这里需要调用AI服务获取回复 ai_response = self._generate_ai_response(content, conversation)
# 示例调用SiliconFlow或其他AI服务
ai_response = self._generate_ai_response(user_message.content, conversation)
# 创建AI回复消息 ai_message = ChatHistory.objects.create(
ai_message = Message.objects.create( user=request.user,
id=str(uuid.uuid4()), knowledge_base=knowledge_base,
conversation=conversation, conversation_id=conversation.conversation_id,
role='assistant', role='assistant',
content=ai_response content=ai_response
) )
# 更新用户的标注统计
self._update_annotation_stats(request.user.id) self._update_annotation_stats(request.user.id)
messages = [ messages = [
@ -166,7 +188,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
title = request.data.get('title', '') title = request.data.get('title', '')
description = request.data.get('description', '') description = request.data.get('description', '')
if conversation.is_submitted: rlhf_conv, created = RLHFConversation.objects.get_or_create(
negotiation_chat=conversation,
defaults={'is_submitted': False}
)
if rlhf_conv.is_submitted:
return self.get_standard_response( return self.get_standard_response(
code=400, code=400,
message='该对话已提交', message='该对话已提交',
@ -174,14 +201,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
status_code=status.HTTP_400_BAD_REQUEST status_code=status.HTTP_400_BAD_REQUEST
) )
# 更新对话为已提交状态 rlhf_conv.is_submitted = True
conversation.is_submitted = True rlhf_conv.save()
conversation.save()
# 创建提交记录
submission = ConversationSubmission.objects.create( submission = ConversationSubmission.objects.create(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
conversation=conversation, conversation_id=conversation.conversation_id,
user=request.user, user=request.user,
title=title, title=title,
description=description, description=description,
@ -189,12 +214,11 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
submitted_at=timezone.now() submitted_at=timezone.now()
) )
# 记录活动日志
UserActivityLog.objects.create( UserActivityLog.objects.create(
user=request.user, user=request.user,
action_type='conversation_submit', action_type='conversation_submit',
target_type='conversation', target_type='conversation',
target_id=str(conversation.id), target_id=conversation.conversation_id,
details={'title': title} details={'title': title}
) )
@ -207,7 +231,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
def resume(self, request, pk=None): def resume(self, request, pk=None):
conversation = self.get_object() conversation = self.get_object()
if not conversation.is_submitted: rlhf_conv, created = RLHFConversation.objects.get_or_create(
negotiation_chat=conversation,
defaults={'is_submitted': False}
)
if not rlhf_conv.is_submitted:
return self.get_standard_response( return self.get_standard_response(
code=400, code=400,
message='该对话未提交,无需恢复', message='该对话未提交,无需恢复',
@ -215,25 +244,22 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
status_code=status.HTTP_400_BAD_REQUEST status_code=status.HTTP_400_BAD_REQUEST
) )
# 更新对话为未提交状态 rlhf_conv.is_submitted = False
conversation.is_submitted = False rlhf_conv.save()
conversation.save()
# 获取最新的提交记录
submission = ConversationSubmission.objects.filter( submission = ConversationSubmission.objects.filter(
conversation=conversation conversation_id=conversation.conversation_id
).order_by('-submitted_at').first() ).order_by('-submitted_at').first()
if submission and submission.status == 'submitted': if submission and submission.status == 'submitted':
submission.status = 'rejected' submission.status = 'rejected'
submission.save() submission.save()
# 记录活动日志
UserActivityLog.objects.create( UserActivityLog.objects.create(
user=request.user, user=request.user,
action_type='conversation_resume', action_type='conversation_resume',
target_type='conversation', target_type='conversation',
target_id=str(conversation.id) target_id=conversation.conversation_id
) )
return self.get_standard_response( return self.get_standard_response(
@ -264,7 +290,7 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
sf_client.set_system_message(system_prompt_config.config_value) sf_client.set_system_message(system_prompt_config.config_value)
# 获取历史消息作为上下文 # 获取历史消息作为上下文
history_messages = Message.objects.filter(conversation=conversation).order_by('timestamp') history_messages = ChatHistory.objects.filter(conversation_id=conversation.conversation_id).order_by('created_at')
# 添加历史消息到客户端 # 添加历史消息到客户端
for msg in history_messages: for msg in history_messages:
@ -385,28 +411,28 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
def _get_recent_conversations(self, user_id, limit=5): def _get_recent_conversations(self, user_id, limit=5):
"""获取用户最近的对话""" """获取用户最近的对话"""
conversations = Conversation.objects.filter( conversations = NegotiationChat.objects.filter(
user_id=user_id negotiation__user_id=user_id
).order_by('-created_at')[:limit] ).order_by('-updated_at')[:limit]
result = [] result = []
for conv in conversations: for conv in conversations:
# 获取最后一条消息内容作为对话摘要 # 获取最后一条消息内容作为对话摘要
last_message = Message.objects.filter( last_message = ChatHistory.objects.filter(
conversation_id=conv.id conversation_id=conv.conversation_id
).order_by('-timestamp').first() ).order_by('-created_at').first()
# 统计消息数 # 统计消息数
message_count = Message.objects.filter(conversation_id=conv.id).count() message_count = ChatHistory.objects.filter(conversation_id=conv.conversation_id).count()
# 统计反馈数 # 统计反馈数
feedback_count = Feedback.objects.filter(conversation_id=conv.id).count() feedback_count = Feedback.objects.filter(conversation_id=conv.conversation_id).count()
detailed_count = DetailedFeedback.objects.filter(conversation_id=conv.id).count() detailed_count = DetailedFeedback.objects.filter(conversation_id=conv.conversation_id).count()
result.append({ result.append({
'id': str(conv.id), 'id': str(conv.conversation_id),
'created_at': conv.created_at.isoformat(), 'created_at': conv.updated_at.isoformat(),
'is_submitted': conv.is_submitted, 'is_submitted': conv.negotiation.is_submitted,
'message_count': message_count, 'message_count': message_count,
'feedback_count': feedback_count + detailed_count, 'feedback_count': feedback_count + detailed_count,
'summary': last_message.content[:100] + "..." if last_message and len(last_message.content) > 100 else (last_message.content if last_message else "") 'summary': last_message.content[:100] + "..." if last_message and len(last_message.content) > 100 else (last_message.content if last_message else "")
@ -416,12 +442,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
def _get_conversation_stats(self, user_id): def _get_conversation_stats(self, user_id):
"""获取对话统计""" """获取对话统计"""
total_conversations = Conversation.objects.filter(user_id=user_id).count() total_conversations = NegotiationChat.objects.filter(negotiation__user_id=user_id).count()
submitted_conversations = Conversation.objects.filter(user_id=user_id, is_submitted=True).count() submitted_conversations = NegotiationChat.objects.filter(negotiation__user_id=user_id, negotiation__is_submitted=True).count()
# 对话消息统计 # 对话消息统计
message_stats = Message.objects.filter( message_stats = ChatHistory.objects.filter(
conversation__user_id=user_id conversation__negotiation__user_id=user_id
).aggregate( ).aggregate(
total=Count('id'), total=Count('id'),
user_messages=Count('id', filter=Q(role='user')), user_messages=Count('id', filter=Q(role='user')),
@ -620,14 +646,21 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
class MessageViewSet(StandardResponseMixin, viewsets.ModelViewSet):
    """Endpoints for ``ChatHistory`` messages, limited to the caller's own conversations."""

    queryset = ChatHistory.objects.all()
    serializer_class = MessageSerializer
    authentication_classes = [CustomTokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return messages from conversations the requesting user takes part in.

        Messages are ordered by ``created_at`` ascending.
        """
        # IDs of every conversation whose negotiation belongs to the user.
        # Passed to the ``__in`` lookup as a queryset, not as a list.
        owned_ids = NegotiationChat.objects.filter(
            negotiation__user=self.request.user,
        ).values_list('conversation_id', flat=True)
        # Keep only the messages that belong to those conversations.
        in_owned_chats = ChatHistory.objects.filter(conversation_id__in=owned_ids)
        return in_owned_chats.order_by('created_at')
class FeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet): class FeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet):