修改rlhf
This commit is contained in:
parent
6a0ae5b132
commit
954e314afe
@ -1,24 +1,40 @@
|
||||
from django.contrib import admin
|
||||
from .models import (
|
||||
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
|
||||
RLHFConversation, NegotiationChat, ChatHistory, Feedback, FeedbackTag, DetailedFeedback,
|
||||
ConversationSubmission, ConversationEvaluation, SystemConfig
|
||||
)
|
||||
|
||||
|
||||
@admin.register(Conversation)
|
||||
class ConversationAdmin(admin.ModelAdmin):
|
||||
list_display = ('id', 'user', 'is_submitted', 'created_at')
|
||||
list_filter = ('is_submitted', 'created_at')
|
||||
search_fields = ('id', 'user__username')
|
||||
@admin.register(RLHFConversation)
|
||||
class RLHFConversationAdmin(admin.ModelAdmin):
|
||||
list_display = ('negotiation_chat', 'id', 'user', 'is_submitted', 'created_at')
|
||||
list_filter = ('is_submitted',)
|
||||
search_fields = ('negotiation_chat__conversation_id', 'negotiation_chat__negotiation__user__username')
|
||||
|
||||
def id(self, obj):
|
||||
return obj.negotiation_chat.conversation_id
|
||||
|
||||
def user(self, obj):
|
||||
return obj.negotiation_chat.negotiation.user
|
||||
|
||||
def created_at(self, obj):
|
||||
return obj.negotiation_chat.created_at
|
||||
|
||||
|
||||
@admin.register(NegotiationChat)
|
||||
class NegotiationChatAdmin(admin.ModelAdmin):
|
||||
list_display = ('conversation_id', 'negotiation', 'creator', 'product', 'created_at', 'updated_at')
|
||||
list_filter = ('created_at', 'updated_at')
|
||||
search_fields = ('conversation_id', 'negotiation__user__username', 'creator__name', 'product__name')
|
||||
date_hierarchy = 'created_at'
|
||||
|
||||
|
||||
@admin.register(Message)
|
||||
class MessageAdmin(admin.ModelAdmin):
|
||||
list_display = ('id', 'conversation', 'role', 'short_content', 'timestamp')
|
||||
list_filter = ('role', 'timestamp')
|
||||
search_fields = ('id', 'conversation__id', 'content')
|
||||
date_hierarchy = 'timestamp'
|
||||
@admin.register(ChatHistory)
|
||||
class ChatHistoryAdmin(admin.ModelAdmin):
|
||||
list_display = ('id', 'conversation_id', 'user', 'role', 'short_content', 'created_at')
|
||||
list_filter = ('role', 'created_at')
|
||||
search_fields = ('id', 'conversation_id', 'content')
|
||||
date_hierarchy = 'created_at'
|
||||
|
||||
def short_content(self, obj):
|
||||
return obj.content[:50] + '...' if len(obj.content) > 50 else obj.content
|
||||
@ -27,9 +43,9 @@ class MessageAdmin(admin.ModelAdmin):
|
||||
|
||||
@admin.register(Feedback)
|
||||
class FeedbackAdmin(admin.ModelAdmin):
|
||||
list_display = ('id', 'message', 'conversation', 'user', 'feedback_value', 'timestamp')
|
||||
list_display = ('id', 'message', 'conversation_id', 'user', 'feedback_value', 'timestamp')
|
||||
list_filter = ('feedback_value', 'timestamp')
|
||||
search_fields = ('id', 'message__id', 'conversation__id', 'user__username')
|
||||
search_fields = ('id', 'message__id', 'conversation_id', 'user__username')
|
||||
date_hierarchy = 'timestamp'
|
||||
|
||||
|
||||
@ -42,25 +58,25 @@ class FeedbackTagAdmin(admin.ModelAdmin):
|
||||
|
||||
@admin.register(DetailedFeedback)
|
||||
class DetailedFeedbackAdmin(admin.ModelAdmin):
|
||||
list_display = ('id', 'message', 'conversation', 'user', 'feedback_type', 'is_inline', 'created_at')
|
||||
list_display = ('id', 'message', 'conversation_id', 'user', 'feedback_type', 'is_inline', 'created_at')
|
||||
list_filter = ('feedback_type', 'is_inline', 'created_at')
|
||||
search_fields = ('id', 'message__id', 'conversation__id', 'user__username', 'custom_content')
|
||||
search_fields = ('id', 'message__id', 'conversation_id', 'user__username', 'custom_content')
|
||||
date_hierarchy = 'created_at'
|
||||
|
||||
|
||||
@admin.register(ConversationSubmission)
|
||||
class ConversationSubmissionAdmin(admin.ModelAdmin):
|
||||
list_display = ('id', 'conversation', 'user', 'title', 'status', 'quality_score', 'reviewer', 'submitted_at', 'reviewed_at')
|
||||
list_display = ('id', 'conversation_id', 'user', 'title', 'status', 'quality_score', 'reviewer', 'submitted_at', 'reviewed_at')
|
||||
list_filter = ('status', 'quality_score', 'submitted_at', 'reviewed_at')
|
||||
search_fields = ('id', 'conversation__id', 'user__username', 'title', 'description', 'reviewer__username')
|
||||
search_fields = ('id', 'conversation_id', 'user__username', 'title', 'description', 'reviewer__username')
|
||||
date_hierarchy = 'submitted_at'
|
||||
|
||||
|
||||
@admin.register(ConversationEvaluation)
|
||||
class ConversationEvaluationAdmin(admin.ModelAdmin):
|
||||
list_display = ('id', 'conversation', 'user', 'has_logical_issues', 'needs_satisfied', 'created_at')
|
||||
list_display = ('id', 'conversation_id', 'user', 'has_logical_issues', 'needs_satisfied', 'created_at')
|
||||
list_filter = ('has_logical_issues', 'needs_satisfied', 'created_at')
|
||||
search_fields = ('id', 'conversation__id', 'user__username', 'overall_feeling')
|
||||
search_fields = ('id', 'conversation_id', 'user__username', 'overall_feeling')
|
||||
date_hierarchy = 'created_at'
|
||||
|
||||
|
||||
|
@ -1,368 +1,368 @@
|
||||
from django.core.management.base import BaseCommand
|
||||
from rlhf.models import Conversation, Message, Feedback, DetailedFeedback, FeedbackTag
|
||||
from django.db.models import Count, Avg, Sum, Q, F
|
||||
from django.utils import timezone
|
||||
from django.contrib.auth import get_user_model
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = '分析RLHF反馈数据,生成统计报告'
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
'--export',
|
||||
action='store_true',
|
||||
help='导出数据到JSON文件',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--days',
|
||||
type=int,
|
||||
default=30,
|
||||
help='分析最近的天数',
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
self.stdout.write(self.style.SUCCESS("=" * 60))
|
||||
self.stdout.write(self.style.SUCCESS("🤖 在线人类反馈强化学习系统 - 数据分析报告"))
|
||||
self.stdout.write(self.style.SUCCESS("=" * 60))
|
||||
|
||||
# 基本统计
|
||||
feedback_stats = self.get_feedback_stats()
|
||||
self.stdout.write(self.style.SUCCESS(f"\n📊 反馈统计:"))
|
||||
self.stdout.write(f" 总反馈数量: {feedback_stats['total_feedback']}")
|
||||
self.stdout.write(f" 正面反馈: {feedback_stats['positive_feedback']} ({feedback_stats['positive_rate']:.1f}%)")
|
||||
self.stdout.write(f" 负面反馈: {feedback_stats['negative_feedback']}")
|
||||
self.stdout.write(f" 平均反馈分数: {feedback_stats['avg_feedback']:.2f}")
|
||||
|
||||
# 对话统计
|
||||
conv_stats = self.get_conversation_stats()
|
||||
self.stdout.write(self.style.SUCCESS(f"\n💬 对话统计:"))
|
||||
self.stdout.write(f" 总对话数量: {conv_stats['total_conversations']}")
|
||||
self.stdout.write(f" 总消息数量: {conv_stats['total_messages']}")
|
||||
self.stdout.write(f" 平均每对话消息数: {conv_stats['avg_messages_per_conversation']:.1f}")
|
||||
|
||||
# 标签统计
|
||||
tag_stats = self.get_tag_stats()
|
||||
self.stdout.write(self.style.SUCCESS(f"\n🏷️ 标签统计:"))
|
||||
self.stdout.write(f" 最常用的正面标签:")
|
||||
for tag in tag_stats['top_positive']:
|
||||
self.stdout.write(f" - {tag['tag_name']}: {tag['count']}次")
|
||||
|
||||
self.stdout.write(f" 最常用的负面标签:")
|
||||
for tag in tag_stats['top_negative']:
|
||||
self.stdout.write(f" - {tag['tag_name']}: {tag['count']}次")
|
||||
|
||||
# 每日趋势
|
||||
days = options['days']
|
||||
daily_trend = self.get_daily_feedback_trend(days)
|
||||
self.stdout.write(self.style.SUCCESS(f"\n📈 最近{days}天反馈趋势:"))
|
||||
for day in daily_trend:
|
||||
self.stdout.write(f" {day['date']}: {day['total']}条反馈 (正面率: {day['positive_rate']:.1f}%)")
|
||||
|
||||
# 用户统计
|
||||
user_stats = self.get_user_stats()
|
||||
self.stdout.write(self.style.SUCCESS(f"\n👥 用户统计:"))
|
||||
self.stdout.write(f" 总用户数量: {user_stats['total_users']}")
|
||||
self.stdout.write(f" 活跃标注用户: {user_stats['active_users']}")
|
||||
self.stdout.write(f" 平均每用户标注量: {user_stats['avg_annotations_per_user']:.1f}")
|
||||
|
||||
# 导出数据
|
||||
if options['export']:
|
||||
filename = self.export_data_to_json()
|
||||
self.stdout.write(self.style.SUCCESS(f"\n✅ 数据已导出到: {filename}"))
|
||||
|
||||
def get_feedback_stats(self):
|
||||
"""获取反馈统计信息"""
|
||||
# 基本反馈统计
|
||||
basic_feedback = Feedback.objects.aggregate(
|
||||
total=Count('id'),
|
||||
positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0)),
|
||||
negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0)),
|
||||
avg=Avg('feedback_value')
|
||||
)
|
||||
|
||||
# 详细反馈统计
|
||||
detailed_feedback = DetailedFeedback.objects.aggregate(
|
||||
total=Count('id'),
|
||||
positive=Count('id', filter=Q(feedback_type='positive')),
|
||||
negative=Count('id', filter=Q(feedback_type='negative'))
|
||||
)
|
||||
|
||||
# 合并统计
|
||||
total = (basic_feedback['total'] or 0) + (detailed_feedback['total'] or 0)
|
||||
positive = (basic_feedback['positive'] or 0) + (detailed_feedback['positive'] or 0)
|
||||
negative = (basic_feedback['negative'] or 0) + (detailed_feedback['negative'] or 0)
|
||||
|
||||
# 计算平均分和正面比例
|
||||
avg_feedback = basic_feedback['avg'] or 0
|
||||
positive_rate = (positive / total * 100) if total > 0 else 0
|
||||
|
||||
return {
|
||||
'total_feedback': total,
|
||||
'positive_feedback': positive,
|
||||
'negative_feedback': negative,
|
||||
'avg_feedback': avg_feedback,
|
||||
'positive_rate': positive_rate
|
||||
}
|
||||
|
||||
def get_conversation_stats(self):
|
||||
"""获取对话统计信息"""
|
||||
total_conversations = Conversation.objects.count()
|
||||
total_messages = Message.objects.count()
|
||||
|
||||
# 计算每个对话的消息数量分布
|
||||
conversation_messages = Message.objects.values('conversation').annotate(count=Count('id'))
|
||||
avg_messages = conversation_messages.aggregate(Avg('count'))['count__avg'] or 0
|
||||
|
||||
return {
|
||||
'total_conversations': total_conversations,
|
||||
'total_messages': total_messages,
|
||||
'avg_messages_per_conversation': avg_messages
|
||||
}
|
||||
|
||||
def get_tag_stats(self):
|
||||
"""获取标签使用统计"""
|
||||
# 分析DetailedFeedback中的标签使用情况
|
||||
# 注意:由于标签可能存储为JSON字符串,这里需要解析
|
||||
|
||||
# 首先获取所有的标签
|
||||
all_tags = FeedbackTag.objects.all()
|
||||
tag_id_to_name = {str(tag.id): tag.tag_name for tag in all_tags}
|
||||
|
||||
# 计算每个标签的使用次数
|
||||
tag_counts = {}
|
||||
for feedback in DetailedFeedback.objects.all():
|
||||
if feedback.feedback_tags:
|
||||
try:
|
||||
# 尝试解析JSON标签列表
|
||||
tag_ids = json.loads(feedback.feedback_tags)
|
||||
if isinstance(tag_ids, list):
|
||||
for tag_id in tag_ids:
|
||||
tag_id = str(tag_id)
|
||||
if tag_id in tag_counts:
|
||||
tag_counts[tag_id] += 1
|
||||
else:
|
||||
tag_counts[tag_id] = 1
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 如果不是有效的JSON,可能是单个标签ID
|
||||
tag_id = str(feedback.feedback_tags)
|
||||
if tag_id in tag_counts:
|
||||
tag_counts[tag_id] += 1
|
||||
else:
|
||||
tag_counts[tag_id] = 1
|
||||
|
||||
# 获取排名前5的正面和负面标签
|
||||
positive_tags = FeedbackTag.objects.filter(tag_type='positive')
|
||||
negative_tags = FeedbackTag.objects.filter(tag_type='negative')
|
||||
|
||||
top_positive = []
|
||||
for tag in positive_tags:
|
||||
tag_id = str(tag.id)
|
||||
if tag_id in tag_counts:
|
||||
top_positive.append({
|
||||
'tag_name': tag.tag_name,
|
||||
'count': tag_counts[tag_id]
|
||||
})
|
||||
|
||||
top_negative = []
|
||||
for tag in negative_tags:
|
||||
tag_id = str(tag.id)
|
||||
if tag_id in tag_counts:
|
||||
top_negative.append({
|
||||
'tag_name': tag.tag_name,
|
||||
'count': tag_counts[tag_id]
|
||||
})
|
||||
|
||||
# 按使用次数排序
|
||||
top_positive.sort(key=lambda x: x['count'], reverse=True)
|
||||
top_negative.sort(key=lambda x: x['count'], reverse=True)
|
||||
|
||||
# 取前5
|
||||
return {
|
||||
'top_positive': top_positive[:5],
|
||||
'top_negative': top_negative[:5]
|
||||
}
|
||||
|
||||
def get_daily_feedback_trend(self, days=30):
|
||||
"""获取每日反馈趋势"""
|
||||
# 计算开始日期
|
||||
start_date = timezone.now().date() - timedelta(days=days)
|
||||
|
||||
# 基本反馈按日期分组
|
||||
basic_daily = Feedback.objects.filter(timestamp__date__gte=start_date) \
|
||||
.values('timestamp__date') \
|
||||
.annotate(
|
||||
date=F('timestamp__date'),
|
||||
total=Count('id'),
|
||||
positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0)),
|
||||
negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0))
|
||||
) \
|
||||
.order_by('date')
|
||||
|
||||
# 详细反馈按日期分组
|
||||
detailed_daily = DetailedFeedback.objects.filter(created_at__date__gte=start_date) \
|
||||
.values('created_at__date') \
|
||||
.annotate(
|
||||
date=F('created_at__date'),
|
||||
total=Count('id'),
|
||||
positive=Count('id', filter=Q(feedback_type='positive')),
|
||||
negative=Count('id', filter=Q(feedback_type='negative'))
|
||||
) \
|
||||
.order_by('date')
|
||||
|
||||
# 合并两种反馈数据
|
||||
daily_data = {}
|
||||
|
||||
for item in basic_daily:
|
||||
date_str = item['date'].strftime('%Y-%m-%d')
|
||||
daily_data[date_str] = {
|
||||
'date': date_str,
|
||||
'total': item['total'],
|
||||
'positive': item['positive'],
|
||||
'negative': item['negative']
|
||||
}
|
||||
|
||||
for item in detailed_daily:
|
||||
date_str = item['date'].strftime('%Y-%m-%d')
|
||||
if date_str in daily_data:
|
||||
daily_data[date_str]['total'] += item['total']
|
||||
daily_data[date_str]['positive'] += item['positive']
|
||||
daily_data[date_str]['negative'] += item['negative']
|
||||
else:
|
||||
daily_data[date_str] = {
|
||||
'date': date_str,
|
||||
'total': item['total'],
|
||||
'positive': item['positive'],
|
||||
'negative': item['negative']
|
||||
}
|
||||
|
||||
# 计算正面反馈比例
|
||||
for date_str, data in daily_data.items():
|
||||
data['positive_rate'] = (data['positive'] / data['total'] * 100) if data['total'] > 0 else 0
|
||||
|
||||
# 转换为列表并按日期排序
|
||||
result = list(daily_data.values())
|
||||
result.sort(key=lambda x: x['date'])
|
||||
|
||||
return result
|
||||
|
||||
def get_user_stats(self):
|
||||
"""获取用户统计信息"""
|
||||
# 总用户数
|
||||
total_users = User.objects.count()
|
||||
|
||||
# 有反馈记录的用户数
|
||||
users_with_feedback = User.objects.filter(
|
||||
Q(feedback__isnull=False) | Q(detailed_feedback__isnull=False)
|
||||
).distinct().count()
|
||||
|
||||
# 最近30天活跃的标注用户
|
||||
thirty_days_ago = timezone.now() - timedelta(days=30)
|
||||
active_users = User.objects.filter(
|
||||
Q(feedback__timestamp__gte=thirty_days_ago) |
|
||||
Q(detailed_feedback__created_at__gte=thirty_days_ago)
|
||||
).distinct().count()
|
||||
|
||||
# 计算每个用户的标注量
|
||||
user_annotations = {}
|
||||
|
||||
for feedback in Feedback.objects.all():
|
||||
user_id = str(feedback.user_id)
|
||||
if user_id in user_annotations:
|
||||
user_annotations[user_id] += 1
|
||||
else:
|
||||
user_annotations[user_id] = 1
|
||||
|
||||
for feedback in DetailedFeedback.objects.all():
|
||||
user_id = str(feedback.user_id)
|
||||
if user_id in user_annotations:
|
||||
user_annotations[user_id] += 1
|
||||
else:
|
||||
user_annotations[user_id] = 1
|
||||
|
||||
# 计算平均每用户标注量
|
||||
if user_annotations:
|
||||
avg_annotations = sum(user_annotations.values()) / len(user_annotations)
|
||||
else:
|
||||
avg_annotations = 0
|
||||
|
||||
return {
|
||||
'total_users': total_users,
|
||||
'users_with_feedback': users_with_feedback,
|
||||
'active_users': active_users,
|
||||
'avg_annotations_per_user': avg_annotations
|
||||
}
|
||||
|
||||
def export_data_to_json(self):
|
||||
"""导出数据到JSON文件"""
|
||||
data = {
|
||||
'conversations': [],
|
||||
'feedback_summary': self.get_feedback_stats(),
|
||||
'tag_stats': self.get_tag_stats(),
|
||||
'daily_trend': self.get_daily_feedback_trend(30),
|
||||
'export_time': timezone.now().isoformat()
|
||||
}
|
||||
|
||||
# 导出对话和消息数据
|
||||
for conv in Conversation.objects.all().prefetch_related('messages'):
|
||||
conv_data = {
|
||||
'id': str(conv.id),
|
||||
'created_at': conv.created_at.isoformat(),
|
||||
'user_id': str(conv.user_id),
|
||||
'is_submitted': conv.is_submitted,
|
||||
'messages': []
|
||||
}
|
||||
|
||||
for msg in conv.messages.all().order_by('timestamp'):
|
||||
msg_data = {
|
||||
'id': str(msg.id),
|
||||
'role': msg.role,
|
||||
'content': msg.content,
|
||||
'timestamp': msg.timestamp.isoformat(),
|
||||
'feedback': []
|
||||
}
|
||||
|
||||
# 获取消息的反馈
|
||||
for fb in Feedback.objects.filter(message_id=msg.id):
|
||||
msg_data['feedback'].append({
|
||||
'id': str(fb.id),
|
||||
'type': 'basic',
|
||||
'value': fb.feedback_value,
|
||||
'user_id': str(fb.user_id),
|
||||
'timestamp': fb.timestamp.isoformat()
|
||||
})
|
||||
|
||||
# 获取详细反馈
|
||||
for dfb in DetailedFeedback.objects.filter(message_id=msg.id):
|
||||
try:
|
||||
tags = json.loads(dfb.feedback_tags) if dfb.feedback_tags else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
tags = [dfb.feedback_tags] if dfb.feedback_tags else []
|
||||
|
||||
msg_data['feedback'].append({
|
||||
'id': str(dfb.id),
|
||||
'type': 'detailed',
|
||||
'feedback_type': dfb.feedback_type,
|
||||
'tags': tags,
|
||||
'custom_tags': dfb.custom_tags,
|
||||
'custom_content': dfb.custom_content,
|
||||
'is_inline': dfb.is_inline,
|
||||
'user_id': str(dfb.user_id),
|
||||
'timestamp': dfb.created_at.isoformat()
|
||||
})
|
||||
|
||||
conv_data['messages'].append(msg_data)
|
||||
|
||||
data['conversations'].append(conv_data)
|
||||
|
||||
# 保存到文件
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
filename = f'rlhf_data_export_{timestamp}.json'
|
||||
|
||||
with open(filename, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from rlhf.models import Conversation, Message, Feedback, DetailedFeedback, FeedbackTag
|
||||
from django.db.models import Count, Avg, Sum, Q, F
|
||||
from django.utils import timezone
|
||||
from django.contrib.auth import get_user_model
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = '分析RLHF反馈数据,生成统计报告'
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
'--export',
|
||||
action='store_true',
|
||||
help='导出数据到JSON文件',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--days',
|
||||
type=int,
|
||||
default=30,
|
||||
help='分析最近的天数',
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
self.stdout.write(self.style.SUCCESS("=" * 60))
|
||||
self.stdout.write(self.style.SUCCESS("🤖 在线人类反馈强化学习系统 - 数据分析报告"))
|
||||
self.stdout.write(self.style.SUCCESS("=" * 60))
|
||||
|
||||
# 基本统计
|
||||
feedback_stats = self.get_feedback_stats()
|
||||
self.stdout.write(self.style.SUCCESS(f"\n📊 反馈统计:"))
|
||||
self.stdout.write(f" 总反馈数量: {feedback_stats['total_feedback']}")
|
||||
self.stdout.write(f" 正面反馈: {feedback_stats['positive_feedback']} ({feedback_stats['positive_rate']:.1f}%)")
|
||||
self.stdout.write(f" 负面反馈: {feedback_stats['negative_feedback']}")
|
||||
self.stdout.write(f" 平均反馈分数: {feedback_stats['avg_feedback']:.2f}")
|
||||
|
||||
# 对话统计
|
||||
conv_stats = self.get_conversation_stats()
|
||||
self.stdout.write(self.style.SUCCESS(f"\n💬 对话统计:"))
|
||||
self.stdout.write(f" 总对话数量: {conv_stats['total_conversations']}")
|
||||
self.stdout.write(f" 总消息数量: {conv_stats['total_messages']}")
|
||||
self.stdout.write(f" 平均每对话消息数: {conv_stats['avg_messages_per_conversation']:.1f}")
|
||||
|
||||
# 标签统计
|
||||
tag_stats = self.get_tag_stats()
|
||||
self.stdout.write(self.style.SUCCESS(f"\n🏷️ 标签统计:"))
|
||||
self.stdout.write(f" 最常用的正面标签:")
|
||||
for tag in tag_stats['top_positive']:
|
||||
self.stdout.write(f" - {tag['tag_name']}: {tag['count']}次")
|
||||
|
||||
self.stdout.write(f" 最常用的负面标签:")
|
||||
for tag in tag_stats['top_negative']:
|
||||
self.stdout.write(f" - {tag['tag_name']}: {tag['count']}次")
|
||||
|
||||
# 每日趋势
|
||||
days = options['days']
|
||||
daily_trend = self.get_daily_feedback_trend(days)
|
||||
self.stdout.write(self.style.SUCCESS(f"\n📈 最近{days}天反馈趋势:"))
|
||||
for day in daily_trend:
|
||||
self.stdout.write(f" {day['date']}: {day['total']}条反馈 (正面率: {day['positive_rate']:.1f}%)")
|
||||
|
||||
# 用户统计
|
||||
user_stats = self.get_user_stats()
|
||||
self.stdout.write(self.style.SUCCESS(f"\n👥 用户统计:"))
|
||||
self.stdout.write(f" 总用户数量: {user_stats['total_users']}")
|
||||
self.stdout.write(f" 活跃标注用户: {user_stats['active_users']}")
|
||||
self.stdout.write(f" 平均每用户标注量: {user_stats['avg_annotations_per_user']:.1f}")
|
||||
|
||||
# 导出数据
|
||||
if options['export']:
|
||||
filename = self.export_data_to_json()
|
||||
self.stdout.write(self.style.SUCCESS(f"\n✅ 数据已导出到: {filename}"))
|
||||
|
||||
def get_feedback_stats(self):
|
||||
"""获取反馈统计信息"""
|
||||
# 基本反馈统计
|
||||
basic_feedback = Feedback.objects.aggregate(
|
||||
total=Count('id'),
|
||||
positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0)),
|
||||
negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0)),
|
||||
avg=Avg('feedback_value')
|
||||
)
|
||||
|
||||
# 详细反馈统计
|
||||
detailed_feedback = DetailedFeedback.objects.aggregate(
|
||||
total=Count('id'),
|
||||
positive=Count('id', filter=Q(feedback_type='positive')),
|
||||
negative=Count('id', filter=Q(feedback_type='negative'))
|
||||
)
|
||||
|
||||
# 合并统计
|
||||
total = (basic_feedback['total'] or 0) + (detailed_feedback['total'] or 0)
|
||||
positive = (basic_feedback['positive'] or 0) + (detailed_feedback['positive'] or 0)
|
||||
negative = (basic_feedback['negative'] or 0) + (detailed_feedback['negative'] or 0)
|
||||
|
||||
# 计算平均分和正面比例
|
||||
avg_feedback = basic_feedback['avg'] or 0
|
||||
positive_rate = (positive / total * 100) if total > 0 else 0
|
||||
|
||||
return {
|
||||
'total_feedback': total,
|
||||
'positive_feedback': positive,
|
||||
'negative_feedback': negative,
|
||||
'avg_feedback': avg_feedback,
|
||||
'positive_rate': positive_rate
|
||||
}
|
||||
|
||||
def get_conversation_stats(self):
|
||||
"""获取对话统计信息"""
|
||||
total_conversations = Conversation.objects.count()
|
||||
total_messages = Message.objects.count()
|
||||
|
||||
# 计算每个对话的消息数量分布
|
||||
conversation_messages = Message.objects.values('conversation').annotate(count=Count('id'))
|
||||
avg_messages = conversation_messages.aggregate(Avg('count'))['count__avg'] or 0
|
||||
|
||||
return {
|
||||
'total_conversations': total_conversations,
|
||||
'total_messages': total_messages,
|
||||
'avg_messages_per_conversation': avg_messages
|
||||
}
|
||||
|
||||
def get_tag_stats(self):
|
||||
"""获取标签使用统计"""
|
||||
# 分析DetailedFeedback中的标签使用情况
|
||||
# 注意:由于标签可能存储为JSON字符串,这里需要解析
|
||||
|
||||
# 首先获取所有的标签
|
||||
all_tags = FeedbackTag.objects.all()
|
||||
tag_id_to_name = {str(tag.id): tag.tag_name for tag in all_tags}
|
||||
|
||||
# 计算每个标签的使用次数
|
||||
tag_counts = {}
|
||||
for feedback in DetailedFeedback.objects.all():
|
||||
if feedback.feedback_tags:
|
||||
try:
|
||||
# 尝试解析JSON标签列表
|
||||
tag_ids = json.loads(feedback.feedback_tags)
|
||||
if isinstance(tag_ids, list):
|
||||
for tag_id in tag_ids:
|
||||
tag_id = str(tag_id)
|
||||
if tag_id in tag_counts:
|
||||
tag_counts[tag_id] += 1
|
||||
else:
|
||||
tag_counts[tag_id] = 1
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 如果不是有效的JSON,可能是单个标签ID
|
||||
tag_id = str(feedback.feedback_tags)
|
||||
if tag_id in tag_counts:
|
||||
tag_counts[tag_id] += 1
|
||||
else:
|
||||
tag_counts[tag_id] = 1
|
||||
|
||||
# 获取排名前5的正面和负面标签
|
||||
positive_tags = FeedbackTag.objects.filter(tag_type='positive')
|
||||
negative_tags = FeedbackTag.objects.filter(tag_type='negative')
|
||||
|
||||
top_positive = []
|
||||
for tag in positive_tags:
|
||||
tag_id = str(tag.id)
|
||||
if tag_id in tag_counts:
|
||||
top_positive.append({
|
||||
'tag_name': tag.tag_name,
|
||||
'count': tag_counts[tag_id]
|
||||
})
|
||||
|
||||
top_negative = []
|
||||
for tag in negative_tags:
|
||||
tag_id = str(tag.id)
|
||||
if tag_id in tag_counts:
|
||||
top_negative.append({
|
||||
'tag_name': tag.tag_name,
|
||||
'count': tag_counts[tag_id]
|
||||
})
|
||||
|
||||
# 按使用次数排序
|
||||
top_positive.sort(key=lambda x: x['count'], reverse=True)
|
||||
top_negative.sort(key=lambda x: x['count'], reverse=True)
|
||||
|
||||
# 取前5
|
||||
return {
|
||||
'top_positive': top_positive[:5],
|
||||
'top_negative': top_negative[:5]
|
||||
}
|
||||
|
||||
def get_daily_feedback_trend(self, days=30):
|
||||
"""获取每日反馈趋势"""
|
||||
# 计算开始日期
|
||||
start_date = timezone.now().date() - timedelta(days=days)
|
||||
|
||||
# 基本反馈按日期分组
|
||||
basic_daily = Feedback.objects.filter(timestamp__date__gte=start_date) \
|
||||
.values('timestamp__date') \
|
||||
.annotate(
|
||||
date=F('timestamp__date'),
|
||||
total=Count('id'),
|
||||
positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0)),
|
||||
negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0))
|
||||
) \
|
||||
.order_by('date')
|
||||
|
||||
# 详细反馈按日期分组
|
||||
detailed_daily = DetailedFeedback.objects.filter(created_at__date__gte=start_date) \
|
||||
.values('created_at__date') \
|
||||
.annotate(
|
||||
date=F('created_at__date'),
|
||||
total=Count('id'),
|
||||
positive=Count('id', filter=Q(feedback_type='positive')),
|
||||
negative=Count('id', filter=Q(feedback_type='negative'))
|
||||
) \
|
||||
.order_by('date')
|
||||
|
||||
# 合并两种反馈数据
|
||||
daily_data = {}
|
||||
|
||||
for item in basic_daily:
|
||||
date_str = item['date'].strftime('%Y-%m-%d')
|
||||
daily_data[date_str] = {
|
||||
'date': date_str,
|
||||
'total': item['total'],
|
||||
'positive': item['positive'],
|
||||
'negative': item['negative']
|
||||
}
|
||||
|
||||
for item in detailed_daily:
|
||||
date_str = item['date'].strftime('%Y-%m-%d')
|
||||
if date_str in daily_data:
|
||||
daily_data[date_str]['total'] += item['total']
|
||||
daily_data[date_str]['positive'] += item['positive']
|
||||
daily_data[date_str]['negative'] += item['negative']
|
||||
else:
|
||||
daily_data[date_str] = {
|
||||
'date': date_str,
|
||||
'total': item['total'],
|
||||
'positive': item['positive'],
|
||||
'negative': item['negative']
|
||||
}
|
||||
|
||||
# 计算正面反馈比例
|
||||
for date_str, data in daily_data.items():
|
||||
data['positive_rate'] = (data['positive'] / data['total'] * 100) if data['total'] > 0 else 0
|
||||
|
||||
# 转换为列表并按日期排序
|
||||
result = list(daily_data.values())
|
||||
result.sort(key=lambda x: x['date'])
|
||||
|
||||
return result
|
||||
|
||||
def get_user_stats(self):
|
||||
"""获取用户统计信息"""
|
||||
# 总用户数
|
||||
total_users = User.objects.count()
|
||||
|
||||
# 有反馈记录的用户数
|
||||
users_with_feedback = User.objects.filter(
|
||||
Q(feedback__isnull=False) | Q(detailed_feedback__isnull=False)
|
||||
).distinct().count()
|
||||
|
||||
# 最近30天活跃的标注用户
|
||||
thirty_days_ago = timezone.now() - timedelta(days=30)
|
||||
active_users = User.objects.filter(
|
||||
Q(feedback__timestamp__gte=thirty_days_ago) |
|
||||
Q(detailed_feedback__created_at__gte=thirty_days_ago)
|
||||
).distinct().count()
|
||||
|
||||
# 计算每个用户的标注量
|
||||
user_annotations = {}
|
||||
|
||||
for feedback in Feedback.objects.all():
|
||||
user_id = str(feedback.user_id)
|
||||
if user_id in user_annotations:
|
||||
user_annotations[user_id] += 1
|
||||
else:
|
||||
user_annotations[user_id] = 1
|
||||
|
||||
for feedback in DetailedFeedback.objects.all():
|
||||
user_id = str(feedback.user_id)
|
||||
if user_id in user_annotations:
|
||||
user_annotations[user_id] += 1
|
||||
else:
|
||||
user_annotations[user_id] = 1
|
||||
|
||||
# 计算平均每用户标注量
|
||||
if user_annotations:
|
||||
avg_annotations = sum(user_annotations.values()) / len(user_annotations)
|
||||
else:
|
||||
avg_annotations = 0
|
||||
|
||||
return {
|
||||
'total_users': total_users,
|
||||
'users_with_feedback': users_with_feedback,
|
||||
'active_users': active_users,
|
||||
'avg_annotations_per_user': avg_annotations
|
||||
}
|
||||
|
||||
def export_data_to_json(self):
|
||||
"""导出数据到JSON文件"""
|
||||
data = {
|
||||
'conversations': [],
|
||||
'feedback_summary': self.get_feedback_stats(),
|
||||
'tag_stats': self.get_tag_stats(),
|
||||
'daily_trend': self.get_daily_feedback_trend(30),
|
||||
'export_time': timezone.now().isoformat()
|
||||
}
|
||||
|
||||
# 导出对话和消息数据
|
||||
for conv in Conversation.objects.all().prefetch_related('messages'):
|
||||
conv_data = {
|
||||
'id': str(conv.id),
|
||||
'created_at': conv.created_at.isoformat(),
|
||||
'user_id': str(conv.user_id),
|
||||
'is_submitted': conv.is_submitted,
|
||||
'messages': []
|
||||
}
|
||||
|
||||
for msg in conv.messages.all().order_by('timestamp'):
|
||||
msg_data = {
|
||||
'id': str(msg.id),
|
||||
'role': msg.role,
|
||||
'content': msg.content,
|
||||
'timestamp': msg.timestamp.isoformat(),
|
||||
'feedback': []
|
||||
}
|
||||
|
||||
# 获取消息的反馈
|
||||
for fb in Feedback.objects.filter(message_id=msg.id):
|
||||
msg_data['feedback'].append({
|
||||
'id': str(fb.id),
|
||||
'type': 'basic',
|
||||
'value': fb.feedback_value,
|
||||
'user_id': str(fb.user_id),
|
||||
'timestamp': fb.timestamp.isoformat()
|
||||
})
|
||||
|
||||
# 获取详细反馈
|
||||
for dfb in DetailedFeedback.objects.filter(message_id=msg.id):
|
||||
try:
|
||||
tags = json.loads(dfb.feedback_tags) if dfb.feedback_tags else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
tags = [dfb.feedback_tags] if dfb.feedback_tags else []
|
||||
|
||||
msg_data['feedback'].append({
|
||||
'id': str(dfb.id),
|
||||
'type': 'detailed',
|
||||
'feedback_type': dfb.feedback_type,
|
||||
'tags': tags,
|
||||
'custom_tags': dfb.custom_tags,
|
||||
'custom_content': dfb.custom_content,
|
||||
'is_inline': dfb.is_inline,
|
||||
'user_id': str(dfb.user_id),
|
||||
'timestamp': dfb.created_at.isoformat()
|
||||
})
|
||||
|
||||
conv_data['messages'].append(msg_data)
|
||||
|
||||
data['conversations'].append(conv_data)
|
||||
|
||||
# 保存到文件
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
filename = f'rlhf_data_export_{timestamp}.json'
|
||||
|
||||
with open(filename, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return filename
|
@ -1,289 +1,289 @@
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from rlhf.models import (
|
||||
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
|
||||
ConversationSubmission, ConversationEvaluation, SystemConfig
|
||||
)
|
||||
from django.utils import timezone
|
||||
import json
|
||||
import uuid
|
||||
import os
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = '导入/导出RLHF数据'
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
'--import-file',
|
||||
help='导入JSON数据文件的路径',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--export-file',
|
||||
help='导出JSON数据文件的路径',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--import-tags',
|
||||
action='store_true',
|
||||
help='从init_tags.py导入标签',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--clear',
|
||||
action='store_true',
|
||||
help='在导入前清除现有数据',
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
if options['import_file']:
|
||||
self.import_data(options['import_file'], options['clear'])
|
||||
elif options['export_file']:
|
||||
self.export_data(options['export_file'])
|
||||
elif options['import_tags']:
|
||||
self.import_tags_from_init_file()
|
||||
else:
|
||||
self.stdout.write(self.style.WARNING('请指定导入文件路径或导出文件路径'))
|
||||
|
||||
def import_data(self, file_path, clear=False):
|
||||
"""从JSON文件导入数据"""
|
||||
if not os.path.exists(file_path):
|
||||
raise CommandError(f'文件不存在: {file_path}')
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
if clear:
|
||||
self.stdout.write(self.style.WARNING('正在清除现有数据...'))
|
||||
self._clear_data()
|
||||
|
||||
# 导入对话
|
||||
for conv_data in data.get('conversations', []):
|
||||
self._import_conversation(conv_data)
|
||||
|
||||
# 导入配置
|
||||
for config in data.get('system_configs', []):
|
||||
self._import_system_config(config)
|
||||
|
||||
self.stdout.write(self.style.SUCCESS(f'成功导入数据从: {file_path}'))
|
||||
|
||||
except Exception as e:
|
||||
raise CommandError(f'导入数据失败: {str(e)}')
|
||||
|
||||
def export_data(self, file_path):
|
||||
"""导出数据到JSON文件"""
|
||||
try:
|
||||
data = {
|
||||
'conversations': [],
|
||||
'system_configs': [],
|
||||
'tags': [],
|
||||
'export_time': timezone.now().isoformat()
|
||||
}
|
||||
|
||||
# 导出标签
|
||||
for tag in FeedbackTag.objects.all():
|
||||
data['tags'].append({
|
||||
'id': str(tag.id),
|
||||
'tag_name': tag.tag_name,
|
||||
'tag_type': tag.tag_type,
|
||||
'description': tag.description,
|
||||
'created_at': tag.created_at.isoformat()
|
||||
})
|
||||
|
||||
# 导出系统配置
|
||||
for config in SystemConfig.objects.all():
|
||||
data['system_configs'].append({
|
||||
'id': str(config.id),
|
||||
'config_key': config.config_key,
|
||||
'config_value': config.config_value,
|
||||
'config_type': config.config_type,
|
||||
'description': config.description,
|
||||
'created_at': config.created_at.isoformat(),
|
||||
'updated_at': config.updated_at.isoformat()
|
||||
})
|
||||
|
||||
# 导出对话
|
||||
for conv in Conversation.objects.all().prefetch_related('messages'):
|
||||
conv_data = {
|
||||
'id': str(conv.id),
|
||||
'user_id': str(conv.user_id),
|
||||
'is_submitted': conv.is_submitted,
|
||||
'created_at': conv.created_at.isoformat(),
|
||||
'messages': []
|
||||
}
|
||||
|
||||
for msg in conv.messages.all().order_by('timestamp'):
|
||||
msg_data = {
|
||||
'id': str(msg.id),
|
||||
'role': msg.role,
|
||||
'content': msg.content,
|
||||
'timestamp': msg.timestamp.isoformat(),
|
||||
'feedback': []
|
||||
}
|
||||
|
||||
# 获取基本反馈
|
||||
for fb in Feedback.objects.filter(message_id=msg.id):
|
||||
msg_data['feedback'].append({
|
||||
'id': str(fb.id),
|
||||
'type': 'basic',
|
||||
'feedback_value': fb.feedback_value,
|
||||
'user_id': str(fb.user_id),
|
||||
'timestamp': fb.timestamp.isoformat()
|
||||
})
|
||||
|
||||
# 获取详细反馈
|
||||
for dfb in DetailedFeedback.objects.filter(message_id=msg.id):
|
||||
msg_data['feedback'].append({
|
||||
'id': str(dfb.id),
|
||||
'type': 'detailed',
|
||||
'feedback_type': dfb.feedback_type,
|
||||
'feedback_tags': dfb.feedback_tags,
|
||||
'custom_tags': dfb.custom_tags,
|
||||
'custom_content': dfb.custom_content,
|
||||
'is_inline': dfb.is_inline,
|
||||
'user_id': str(dfb.user_id),
|
||||
'created_at': dfb.created_at.isoformat(),
|
||||
'updated_at': dfb.updated_at.isoformat()
|
||||
})
|
||||
|
||||
conv_data['messages'].append(msg_data)
|
||||
|
||||
data['conversations'].append(conv_data)
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
self.stdout.write(self.style.SUCCESS(f'成功导出数据到: {file_path}'))
|
||||
|
||||
except Exception as e:
|
||||
raise CommandError(f'导出数据失败: {str(e)}')
|
||||
|
||||
def import_tags_from_init_file(self):
|
||||
"""从init_tags.py导入标签"""
|
||||
try:
|
||||
# 定义标签数据
|
||||
positive_tags = [
|
||||
('有帮助', '回答对问题有实际帮助'),
|
||||
('准确', '信息准确可靠'),
|
||||
('清晰', '表达清楚易懂'),
|
||||
('完整', '回答全面完整'),
|
||||
('友好', '语调友善亲切'),
|
||||
('创新', '提供了新颖的观点')
|
||||
]
|
||||
|
||||
negative_tags = [
|
||||
('不准确', '包含错误信息'),
|
||||
('不相关', '回答偏离主题'),
|
||||
('不完整', '回答过于简略'),
|
||||
('不清晰', '表达模糊难懂'),
|
||||
('不友好', '语调生硬冷淡'),
|
||||
('重复', '内容重复冗余')
|
||||
]
|
||||
|
||||
# 插入正面标签
|
||||
for tag_name, description in positive_tags:
|
||||
FeedbackTag.objects.get_or_create(
|
||||
tag_name=tag_name,
|
||||
defaults={
|
||||
'id': str(uuid.uuid4()),
|
||||
'tag_type': 'positive',
|
||||
'description': description,
|
||||
'created_at': timezone.now()
|
||||
}
|
||||
)
|
||||
|
||||
# 插入负面标签
|
||||
for tag_name, description in negative_tags:
|
||||
FeedbackTag.objects.get_or_create(
|
||||
tag_name=tag_name,
|
||||
defaults={
|
||||
'id': str(uuid.uuid4()),
|
||||
'tag_type': 'negative',
|
||||
'description': description,
|
||||
'created_at': timezone.now()
|
||||
}
|
||||
)
|
||||
|
||||
self.stdout.write(self.style.SUCCESS('成功导入标签'))
|
||||
|
||||
except Exception as e:
|
||||
raise CommandError(f'导入标签失败: {str(e)}')
|
||||
|
||||
def _clear_data(self):
|
||||
"""清除现有数据"""
|
||||
DetailedFeedback.objects.all().delete()
|
||||
Feedback.objects.all().delete()
|
||||
Message.objects.all().delete()
|
||||
Conversation.objects.all().delete()
|
||||
ConversationSubmission.objects.all().delete()
|
||||
ConversationEvaluation.objects.all().delete()
|
||||
self.stdout.write(self.style.SUCCESS('已清除现有数据'))
|
||||
|
||||
def _import_conversation(self, conv_data):
|
||||
"""导入单个对话"""
|
||||
# 检查用户是否存在
|
||||
user_id = conv_data.get('user_id')
|
||||
if not User.objects.filter(id=user_id).exists():
|
||||
self.stdout.write(self.style.WARNING(f'用户不存在: {user_id},将使用第一个管理员用户'))
|
||||
user = User.objects.filter(is_superuser=True).first()
|
||||
if not user:
|
||||
raise CommandError('找不到管理员用户')
|
||||
user_id = user.id
|
||||
|
||||
# 创建对话
|
||||
conv = Conversation.objects.create(
|
||||
id=conv_data.get('id', str(uuid.uuid4())),
|
||||
user_id=user_id,
|
||||
is_submitted=conv_data.get('is_submitted', False),
|
||||
created_at=timezone.parse_datetime(conv_data.get('created_at', timezone.now().isoformat()))
|
||||
)
|
||||
|
||||
# 创建消息
|
||||
for msg_data in conv_data.get('messages', []):
|
||||
msg = Message.objects.create(
|
||||
id=msg_data.get('id', str(uuid.uuid4())),
|
||||
conversation=conv,
|
||||
role=msg_data.get('role', 'user'),
|
||||
content=msg_data.get('content', ''),
|
||||
timestamp=timezone.parse_datetime(msg_data.get('timestamp', timezone.now().isoformat()))
|
||||
)
|
||||
|
||||
# 创建反馈
|
||||
for fb_data in msg_data.get('feedback', []):
|
||||
if fb_data.get('type') == 'basic':
|
||||
Feedback.objects.create(
|
||||
id=fb_data.get('id', str(uuid.uuid4())),
|
||||
message=msg,
|
||||
conversation=conv,
|
||||
user_id=fb_data.get('user_id', user_id),
|
||||
feedback_value=fb_data.get('feedback_value', 0),
|
||||
timestamp=timezone.parse_datetime(fb_data.get('timestamp', timezone.now().isoformat()))
|
||||
)
|
||||
elif fb_data.get('type') == 'detailed':
|
||||
DetailedFeedback.objects.create(
|
||||
id=fb_data.get('id', str(uuid.uuid4())),
|
||||
message=msg,
|
||||
conversation=conv,
|
||||
user_id=fb_data.get('user_id', user_id),
|
||||
feedback_type=fb_data.get('feedback_type', 'neutral'),
|
||||
feedback_tags=fb_data.get('feedback_tags', '[]'),
|
||||
custom_tags=fb_data.get('custom_tags', ''),
|
||||
custom_content=fb_data.get('custom_content', ''),
|
||||
is_inline=fb_data.get('is_inline', True),
|
||||
created_at=timezone.parse_datetime(fb_data.get('created_at', timezone.now().isoformat())),
|
||||
updated_at=timezone.parse_datetime(fb_data.get('updated_at', timezone.now().isoformat()))
|
||||
)
|
||||
|
||||
def _import_system_config(self, config_data):
|
||||
"""导入系统配置"""
|
||||
SystemConfig.objects.update_or_create(
|
||||
config_key=config_data.get('config_key'),
|
||||
defaults={
|
||||
'id': config_data.get('id', str(uuid.uuid4())),
|
||||
'config_value': config_data.get('config_value', ''),
|
||||
'config_type': config_data.get('config_type', 'string'),
|
||||
'description': config_data.get('description', ''),
|
||||
'created_at': timezone.parse_datetime(config_data.get('created_at', timezone.now().isoformat())),
|
||||
'updated_at': timezone.parse_datetime(config_data.get('updated_at', timezone.now().isoformat()))
|
||||
}
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from rlhf.models import (
|
||||
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
|
||||
ConversationSubmission, ConversationEvaluation, SystemConfig
|
||||
)
|
||||
from django.utils import timezone
|
||||
import json
|
||||
import uuid
|
||||
import os
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = '导入/导出RLHF数据'
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
'--import-file',
|
||||
help='导入JSON数据文件的路径',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--export-file',
|
||||
help='导出JSON数据文件的路径',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--import-tags',
|
||||
action='store_true',
|
||||
help='从init_tags.py导入标签',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--clear',
|
||||
action='store_true',
|
||||
help='在导入前清除现有数据',
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
if options['import_file']:
|
||||
self.import_data(options['import_file'], options['clear'])
|
||||
elif options['export_file']:
|
||||
self.export_data(options['export_file'])
|
||||
elif options['import_tags']:
|
||||
self.import_tags_from_init_file()
|
||||
else:
|
||||
self.stdout.write(self.style.WARNING('请指定导入文件路径或导出文件路径'))
|
||||
|
||||
def import_data(self, file_path, clear=False):
|
||||
"""从JSON文件导入数据"""
|
||||
if not os.path.exists(file_path):
|
||||
raise CommandError(f'文件不存在: {file_path}')
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
if clear:
|
||||
self.stdout.write(self.style.WARNING('正在清除现有数据...'))
|
||||
self._clear_data()
|
||||
|
||||
# 导入对话
|
||||
for conv_data in data.get('conversations', []):
|
||||
self._import_conversation(conv_data)
|
||||
|
||||
# 导入配置
|
||||
for config in data.get('system_configs', []):
|
||||
self._import_system_config(config)
|
||||
|
||||
self.stdout.write(self.style.SUCCESS(f'成功导入数据从: {file_path}'))
|
||||
|
||||
except Exception as e:
|
||||
raise CommandError(f'导入数据失败: {str(e)}')
|
||||
|
||||
def export_data(self, file_path):
|
||||
"""导出数据到JSON文件"""
|
||||
try:
|
||||
data = {
|
||||
'conversations': [],
|
||||
'system_configs': [],
|
||||
'tags': [],
|
||||
'export_time': timezone.now().isoformat()
|
||||
}
|
||||
|
||||
# 导出标签
|
||||
for tag in FeedbackTag.objects.all():
|
||||
data['tags'].append({
|
||||
'id': str(tag.id),
|
||||
'tag_name': tag.tag_name,
|
||||
'tag_type': tag.tag_type,
|
||||
'description': tag.description,
|
||||
'created_at': tag.created_at.isoformat()
|
||||
})
|
||||
|
||||
# 导出系统配置
|
||||
for config in SystemConfig.objects.all():
|
||||
data['system_configs'].append({
|
||||
'id': str(config.id),
|
||||
'config_key': config.config_key,
|
||||
'config_value': config.config_value,
|
||||
'config_type': config.config_type,
|
||||
'description': config.description,
|
||||
'created_at': config.created_at.isoformat(),
|
||||
'updated_at': config.updated_at.isoformat()
|
||||
})
|
||||
|
||||
# 导出对话
|
||||
for conv in Conversation.objects.all().prefetch_related('messages'):
|
||||
conv_data = {
|
||||
'id': str(conv.id),
|
||||
'user_id': str(conv.user_id),
|
||||
'is_submitted': conv.is_submitted,
|
||||
'created_at': conv.created_at.isoformat(),
|
||||
'messages': []
|
||||
}
|
||||
|
||||
for msg in conv.messages.all().order_by('timestamp'):
|
||||
msg_data = {
|
||||
'id': str(msg.id),
|
||||
'role': msg.role,
|
||||
'content': msg.content,
|
||||
'timestamp': msg.timestamp.isoformat(),
|
||||
'feedback': []
|
||||
}
|
||||
|
||||
# 获取基本反馈
|
||||
for fb in Feedback.objects.filter(message_id=msg.id):
|
||||
msg_data['feedback'].append({
|
||||
'id': str(fb.id),
|
||||
'type': 'basic',
|
||||
'feedback_value': fb.feedback_value,
|
||||
'user_id': str(fb.user_id),
|
||||
'timestamp': fb.timestamp.isoformat()
|
||||
})
|
||||
|
||||
# 获取详细反馈
|
||||
for dfb in DetailedFeedback.objects.filter(message_id=msg.id):
|
||||
msg_data['feedback'].append({
|
||||
'id': str(dfb.id),
|
||||
'type': 'detailed',
|
||||
'feedback_type': dfb.feedback_type,
|
||||
'feedback_tags': dfb.feedback_tags,
|
||||
'custom_tags': dfb.custom_tags,
|
||||
'custom_content': dfb.custom_content,
|
||||
'is_inline': dfb.is_inline,
|
||||
'user_id': str(dfb.user_id),
|
||||
'created_at': dfb.created_at.isoformat(),
|
||||
'updated_at': dfb.updated_at.isoformat()
|
||||
})
|
||||
|
||||
conv_data['messages'].append(msg_data)
|
||||
|
||||
data['conversations'].append(conv_data)
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
self.stdout.write(self.style.SUCCESS(f'成功导出数据到: {file_path}'))
|
||||
|
||||
except Exception as e:
|
||||
raise CommandError(f'导出数据失败: {str(e)}')
|
||||
|
||||
def import_tags_from_init_file(self):
|
||||
"""从init_tags.py导入标签"""
|
||||
try:
|
||||
# 定义标签数据
|
||||
positive_tags = [
|
||||
('有帮助', '回答对问题有实际帮助'),
|
||||
('准确', '信息准确可靠'),
|
||||
('清晰', '表达清楚易懂'),
|
||||
('完整', '回答全面完整'),
|
||||
('友好', '语调友善亲切'),
|
||||
('创新', '提供了新颖的观点')
|
||||
]
|
||||
|
||||
negative_tags = [
|
||||
('不准确', '包含错误信息'),
|
||||
('不相关', '回答偏离主题'),
|
||||
('不完整', '回答过于简略'),
|
||||
('不清晰', '表达模糊难懂'),
|
||||
('不友好', '语调生硬冷淡'),
|
||||
('重复', '内容重复冗余')
|
||||
]
|
||||
|
||||
# 插入正面标签
|
||||
for tag_name, description in positive_tags:
|
||||
FeedbackTag.objects.get_or_create(
|
||||
tag_name=tag_name,
|
||||
defaults={
|
||||
'id': str(uuid.uuid4()),
|
||||
'tag_type': 'positive',
|
||||
'description': description,
|
||||
'created_at': timezone.now()
|
||||
}
|
||||
)
|
||||
|
||||
# 插入负面标签
|
||||
for tag_name, description in negative_tags:
|
||||
FeedbackTag.objects.get_or_create(
|
||||
tag_name=tag_name,
|
||||
defaults={
|
||||
'id': str(uuid.uuid4()),
|
||||
'tag_type': 'negative',
|
||||
'description': description,
|
||||
'created_at': timezone.now()
|
||||
}
|
||||
)
|
||||
|
||||
self.stdout.write(self.style.SUCCESS('成功导入标签'))
|
||||
|
||||
except Exception as e:
|
||||
raise CommandError(f'导入标签失败: {str(e)}')
|
||||
|
||||
def _clear_data(self):
|
||||
"""清除现有数据"""
|
||||
DetailedFeedback.objects.all().delete()
|
||||
Feedback.objects.all().delete()
|
||||
Message.objects.all().delete()
|
||||
Conversation.objects.all().delete()
|
||||
ConversationSubmission.objects.all().delete()
|
||||
ConversationEvaluation.objects.all().delete()
|
||||
self.stdout.write(self.style.SUCCESS('已清除现有数据'))
|
||||
|
||||
def _import_conversation(self, conv_data):
|
||||
"""导入单个对话"""
|
||||
# 检查用户是否存在
|
||||
user_id = conv_data.get('user_id')
|
||||
if not User.objects.filter(id=user_id).exists():
|
||||
self.stdout.write(self.style.WARNING(f'用户不存在: {user_id},将使用第一个管理员用户'))
|
||||
user = User.objects.filter(is_superuser=True).first()
|
||||
if not user:
|
||||
raise CommandError('找不到管理员用户')
|
||||
user_id = user.id
|
||||
|
||||
# 创建对话
|
||||
conv = Conversation.objects.create(
|
||||
id=conv_data.get('id', str(uuid.uuid4())),
|
||||
user_id=user_id,
|
||||
is_submitted=conv_data.get('is_submitted', False),
|
||||
created_at=timezone.parse_datetime(conv_data.get('created_at', timezone.now().isoformat()))
|
||||
)
|
||||
|
||||
# 创建消息
|
||||
for msg_data in conv_data.get('messages', []):
|
||||
msg = Message.objects.create(
|
||||
id=msg_data.get('id', str(uuid.uuid4())),
|
||||
conversation=conv,
|
||||
role=msg_data.get('role', 'user'),
|
||||
content=msg_data.get('content', ''),
|
||||
timestamp=timezone.parse_datetime(msg_data.get('timestamp', timezone.now().isoformat()))
|
||||
)
|
||||
|
||||
# 创建反馈
|
||||
for fb_data in msg_data.get('feedback', []):
|
||||
if fb_data.get('type') == 'basic':
|
||||
Feedback.objects.create(
|
||||
id=fb_data.get('id', str(uuid.uuid4())),
|
||||
message=msg,
|
||||
conversation=conv,
|
||||
user_id=fb_data.get('user_id', user_id),
|
||||
feedback_value=fb_data.get('feedback_value', 0),
|
||||
timestamp=timezone.parse_datetime(fb_data.get('timestamp', timezone.now().isoformat()))
|
||||
)
|
||||
elif fb_data.get('type') == 'detailed':
|
||||
DetailedFeedback.objects.create(
|
||||
id=fb_data.get('id', str(uuid.uuid4())),
|
||||
message=msg,
|
||||
conversation=conv,
|
||||
user_id=fb_data.get('user_id', user_id),
|
||||
feedback_type=fb_data.get('feedback_type', 'neutral'),
|
||||
feedback_tags=fb_data.get('feedback_tags', '[]'),
|
||||
custom_tags=fb_data.get('custom_tags', ''),
|
||||
custom_content=fb_data.get('custom_content', ''),
|
||||
is_inline=fb_data.get('is_inline', True),
|
||||
created_at=timezone.parse_datetime(fb_data.get('created_at', timezone.now().isoformat())),
|
||||
updated_at=timezone.parse_datetime(fb_data.get('updated_at', timezone.now().isoformat()))
|
||||
)
|
||||
|
||||
def _import_system_config(self, config_data):
|
||||
"""导入系统配置"""
|
||||
SystemConfig.objects.update_or_create(
|
||||
config_key=config_data.get('config_key'),
|
||||
defaults={
|
||||
'id': config_data.get('id', str(uuid.uuid4())),
|
||||
'config_value': config_data.get('config_value', ''),
|
||||
'config_type': config_data.get('config_type', 'string'),
|
||||
'description': config_data.get('description', ''),
|
||||
'created_at': timezone.parse_datetime(config_data.get('created_at', timezone.now().isoformat())),
|
||||
'updated_at': timezone.parse_datetime(config_data.get('updated_at', timezone.now().isoformat()))
|
||||
}
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
# Generated by Django 5.1.5 on 2025-06-09 08:28
|
||||
# Generated by Django 5.2.1 on 2025-06-10 08:10
|
||||
|
||||
import django.db.models.deletion
|
||||
import django.utils.timezone
|
||||
@ -12,6 +12,7 @@ class Migration(migrations.Migration):
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
('chat', '0002_negotiationchat'),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
@ -30,6 +31,17 @@ class Migration(migrations.Migration):
|
||||
'verbose_name_plural': '反馈标签',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='RLHFConversation',
|
||||
fields=[
|
||||
('negotiation_chat', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, related_name='rlhf_extension', serialize=False, to='chat.negotiationchat')),
|
||||
('is_submitted', models.BooleanField(default=False)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'RLHF对话',
|
||||
'verbose_name_plural': 'RLHF对话',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='SystemConfig',
|
||||
fields=[
|
||||
@ -46,23 +58,11 @@ class Migration(migrations.Migration):
|
||||
'verbose_name_plural': '系统配置',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Conversation',
|
||||
fields=[
|
||||
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
|
||||
('is_submitted', models.BooleanField(default=False)),
|
||||
('created_at', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='conversations', to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '对话',
|
||||
'verbose_name_plural': '对话',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ConversationSubmission',
|
||||
fields=[
|
||||
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
|
||||
('conversation_id', models.CharField(max_length=100)),
|
||||
('title', models.CharField(blank=True, max_length=255, null=True)),
|
||||
('description', models.TextField(blank=True, null=True)),
|
||||
('status', models.CharField(choices=[('submitted', '已提交'), ('reviewed', '已审核'), ('accepted', '已接受'), ('rejected', '已拒绝')], default='submitted', max_length=20)),
|
||||
@ -72,7 +72,6 @@ class Migration(migrations.Migration):
|
||||
('reviewed_at', models.DateTimeField(blank=True, null=True)),
|
||||
('created_at', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('updated_at', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='submissions', to='rlhf.conversation')),
|
||||
('reviewer', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='reviewed_submissions', to=settings.AUTH_USER_MODEL)),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='submissions', to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
@ -81,40 +80,11 @@ class Migration(migrations.Migration):
|
||||
'verbose_name_plural': '对话提交',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Message',
|
||||
fields=[
|
||||
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
|
||||
('role', models.CharField(choices=[('user', '用户'), ('assistant', '助手'), ('system', '系统')], max_length=20)),
|
||||
('content', models.TextField()),
|
||||
('timestamp', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='messages', to='rlhf.conversation')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '消息',
|
||||
'verbose_name_plural': '消息',
|
||||
'ordering': ['timestamp'],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Feedback',
|
||||
fields=[
|
||||
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
|
||||
('feedback_value', models.IntegerField()),
|
||||
('timestamp', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to='rlhf.conversation')),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to=settings.AUTH_USER_MODEL)),
|
||||
('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to='rlhf.message')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '反馈',
|
||||
'verbose_name_plural': '反馈',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='DetailedFeedback',
|
||||
fields=[
|
||||
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
|
||||
('conversation_id', models.CharField(max_length=100)),
|
||||
('feedback_type', models.CharField(choices=[('positive', '正面'), ('negative', '负面'), ('neutral', '中性')], max_length=20)),
|
||||
('feedback_tags', models.TextField(blank=True, null=True)),
|
||||
('custom_tags', models.TextField(blank=True, null=True)),
|
||||
@ -122,31 +92,45 @@ class Migration(migrations.Migration):
|
||||
('is_inline', models.BooleanField(default=True)),
|
||||
('created_at', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('updated_at', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='detailed_feedback', to='rlhf.conversation')),
|
||||
('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rlhf_detailed_feedback', to='chat.chathistory')),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='detailed_feedback', to=settings.AUTH_USER_MODEL)),
|
||||
('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='detailed_feedback', to='rlhf.message')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '详细反馈',
|
||||
'verbose_name_plural': '详细反馈',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Feedback',
|
||||
fields=[
|
||||
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
|
||||
('conversation_id', models.CharField(max_length=100)),
|
||||
('feedback_value', models.IntegerField()),
|
||||
('timestamp', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('message', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rlhf_feedback', to='chat.chathistory')),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='feedback', to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '反馈',
|
||||
'verbose_name_plural': '反馈',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ConversationEvaluation',
|
||||
fields=[
|
||||
('id', models.CharField(default=uuid.uuid4, editable=False, max_length=36, primary_key=True, serialize=False)),
|
||||
('conversation_id', models.CharField(max_length=100)),
|
||||
('overall_feeling', models.TextField(blank=True, null=True)),
|
||||
('has_logical_issues', models.CharField(choices=[('yes', '是'), ('no', '否'), ('unsure', '不确定')], max_length=10)),
|
||||
('needs_satisfied', models.CharField(choices=[('yes', '是'), ('no', '否'), ('partially', '部分')], max_length=10)),
|
||||
('created_at', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('updated_at', models.DateTimeField(default=django.utils.timezone.now)),
|
||||
('conversation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='evaluations', to='rlhf.conversation')),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='evaluations', to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '对话评估',
|
||||
'verbose_name_plural': '对话评估',
|
||||
'unique_together': {('conversation', 'user')},
|
||||
'unique_together': {('conversation_id', 'user')},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
@ -2,60 +2,91 @@ from django.db import models
|
||||
import uuid
|
||||
from django.utils import timezone
|
||||
from apps.user.models import User
|
||||
from apps.chat.models import ChatHistory, NegotiationChat
|
||||
from apps.daren_detail.models import CreatorProfile
|
||||
from apps.brands.models import Product
|
||||
|
||||
|
||||
class Conversation(models.Model):
|
||||
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='conversations')
|
||||
# 使用NegotiationChat替代Conversation
|
||||
# class Conversation(models.Model):
|
||||
# id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
|
||||
# user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='conversations')
|
||||
# is_submitted = models.BooleanField(default=False)
|
||||
# created_at = models.DateTimeField(default=timezone.now)
|
||||
#
|
||||
# class Meta:
|
||||
# verbose_name = '对话'
|
||||
# verbose_name_plural = '对话'
|
||||
#
|
||||
# def __str__(self):
|
||||
# return f"Conversation {self.id[:8]}"
|
||||
|
||||
|
||||
# 为NegotiationChat添加RLHF所需字段的代理模型
|
||||
class RLHFConversation(models.Model):
|
||||
"""对NegotiationChat的扩展,添加RLHF所需字段"""
|
||||
negotiation_chat = models.OneToOneField(NegotiationChat, on_delete=models.CASCADE, primary_key=True, related_name='rlhf_extension')
|
||||
is_submitted = models.BooleanField(default=False)
|
||||
created_at = models.DateTimeField(default=timezone.now)
|
||||
|
||||
class Meta:
|
||||
verbose_name = '对话'
|
||||
verbose_name_plural = '对话'
|
||||
verbose_name = 'RLHF对话'
|
||||
verbose_name_plural = 'RLHF对话'
|
||||
|
||||
def __str__(self):
|
||||
return f"Conversation {self.id[:8]}"
|
||||
return f"RLHF Conversation {self.negotiation_chat.conversation_id[:8]}"
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.negotiation_chat.conversation_id
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
# 从谈判关联的用户中获取
|
||||
return self.negotiation_chat.negotiation.user
|
||||
|
||||
@property
|
||||
def created_at(self):
|
||||
return self.negotiation_chat.created_at
|
||||
|
||||
|
||||
class Message(models.Model):
|
||||
ROLE_CHOICES = (
|
||||
('user', '用户'),
|
||||
('assistant', '助手'),
|
||||
('system', '系统'),
|
||||
)
|
||||
|
||||
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
|
||||
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages')
|
||||
role = models.CharField(max_length=20, choices=ROLE_CHOICES)
|
||||
content = models.TextField()
|
||||
timestamp = models.DateTimeField(default=timezone.now)
|
||||
|
||||
class Meta:
|
||||
|
||||
verbose_name = '消息'
|
||||
verbose_name_plural = '消息'
|
||||
ordering = ['timestamp']
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.role}: {self.content[:50]}..."
|
||||
# 使用ChatHistory替代Message
|
||||
# class Message(models.Model):
|
||||
# ROLE_CHOICES = (
|
||||
# ('user', '用户'),
|
||||
# ('assistant', '助手'),
|
||||
# ('system', '系统'),
|
||||
# )
|
||||
#
|
||||
# id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
|
||||
# conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='messages')
|
||||
# role = models.CharField(max_length=20, choices=ROLE_CHOICES)
|
||||
# content = models.TextField()
|
||||
# timestamp = models.DateTimeField(default=timezone.now)
|
||||
#
|
||||
# class Meta:
|
||||
#
|
||||
# verbose_name = '消息'
|
||||
# verbose_name_plural = '消息'
|
||||
# ordering = ['timestamp']
|
||||
#
|
||||
# def __str__(self):
|
||||
# return f"{self.role}: {self.content[:50]}..."
|
||||
|
||||
|
||||
class Feedback(models.Model):
|
||||
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
|
||||
message = models.ForeignKey(Message, on_delete=models.CASCADE, related_name='feedback')
|
||||
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='feedback')
|
||||
message = models.ForeignKey(ChatHistory, on_delete=models.CASCADE, related_name='rlhf_feedback')
|
||||
conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='feedback')
|
||||
feedback_value = models.IntegerField()
|
||||
timestamp = models.DateTimeField(default=timezone.now)
|
||||
|
||||
class Meta:
|
||||
|
||||
verbose_name = '反馈'
|
||||
verbose_name_plural = '反馈'
|
||||
|
||||
def __str__(self):
|
||||
return f"Feedback on {self.message.id[:8]}"
|
||||
return f"Feedback on {self.message.id}"
|
||||
|
||||
|
||||
class FeedbackTag(models.Model):
|
||||
@ -71,7 +102,6 @@ class FeedbackTag(models.Model):
|
||||
created_at = models.DateTimeField(default=timezone.now)
|
||||
|
||||
class Meta:
|
||||
|
||||
verbose_name = '反馈标签'
|
||||
verbose_name_plural = '反馈标签'
|
||||
|
||||
@ -87,8 +117,8 @@ class DetailedFeedback(models.Model):
|
||||
)
|
||||
|
||||
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
|
||||
message = models.ForeignKey(Message, on_delete=models.CASCADE, related_name='detailed_feedback')
|
||||
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='detailed_feedback')
|
||||
message = models.ForeignKey(ChatHistory, on_delete=models.CASCADE, related_name='rlhf_detailed_feedback')
|
||||
conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='detailed_feedback')
|
||||
feedback_type = models.CharField(max_length=20, choices=FEEDBACK_TYPE_CHOICES)
|
||||
feedback_tags = models.TextField(blank=True, null=True) # JSON格式存储多个标签
|
||||
@ -99,12 +129,11 @@ class DetailedFeedback(models.Model):
|
||||
updated_at = models.DateTimeField(default=timezone.now)
|
||||
|
||||
class Meta:
|
||||
|
||||
verbose_name = '详细反馈'
|
||||
verbose_name_plural = '详细反馈'
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.feedback_type} feedback on {self.message.id[:8]}"
|
||||
return f"{self.feedback_type} feedback on {self.message.id}"
|
||||
|
||||
|
||||
class ConversationSubmission(models.Model):
|
||||
@ -116,7 +145,7 @@ class ConversationSubmission(models.Model):
|
||||
)
|
||||
|
||||
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
|
||||
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='submissions')
|
||||
conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='submissions')
|
||||
title = models.CharField(max_length=255, blank=True, null=True)
|
||||
description = models.TextField(blank=True, null=True)
|
||||
@ -130,12 +159,11 @@ class ConversationSubmission(models.Model):
|
||||
updated_at = models.DateTimeField(default=timezone.now)
|
||||
|
||||
class Meta:
|
||||
|
||||
verbose_name = '对话提交'
|
||||
verbose_name_plural = '对话提交'
|
||||
|
||||
def __str__(self):
|
||||
return f"Submission for {self.conversation.id[:8]}"
|
||||
return f"Submission for {self.conversation_id[:8]}"
|
||||
|
||||
|
||||
class ConversationEvaluation(models.Model):
|
||||
@ -152,7 +180,7 @@ class ConversationEvaluation(models.Model):
|
||||
)
|
||||
|
||||
id = models.CharField(primary_key=True, max_length=36, default=uuid.uuid4, editable=False)
|
||||
conversation = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name='evaluations')
|
||||
conversation_id = models.CharField(max_length=100) # 存储NegotiationChat的conversation_id
|
||||
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='evaluations')
|
||||
overall_feeling = models.TextField(blank=True, null=True)
|
||||
has_logical_issues = models.CharField(max_length=10, choices=LOGICAL_CHOICES)
|
||||
@ -161,13 +189,12 @@ class ConversationEvaluation(models.Model):
|
||||
updated_at = models.DateTimeField(default=timezone.now)
|
||||
|
||||
class Meta:
|
||||
|
||||
verbose_name = '对话评估'
|
||||
verbose_name_plural = '对话评估'
|
||||
unique_together = ('conversation', 'user')
|
||||
unique_together = ('conversation_id', 'user')
|
||||
|
||||
def __str__(self):
|
||||
return f"Evaluation for {self.conversation.id[:8]}"
|
||||
return f"Evaluation for {self.conversation_id[:8]}"
|
||||
|
||||
|
||||
class SystemConfig(models.Model):
|
||||
|
@ -1,29 +1,43 @@
|
||||
from rest_framework import serializers
|
||||
from .models import (
|
||||
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
|
||||
ConversationSubmission, ConversationEvaluation, SystemConfig
|
||||
Feedback, FeedbackTag, DetailedFeedback,
|
||||
ConversationSubmission, ConversationEvaluation, SystemConfig,
|
||||
RLHFConversation, NegotiationChat, ChatHistory
|
||||
)
|
||||
from apps.user.serializers import UserSerializer
|
||||
|
||||
|
||||
class ConversationSerializer(serializers.ModelSerializer):
|
||||
id = serializers.CharField(source='conversation_id', read_only=True)
|
||||
user = serializers.PrimaryKeyRelatedField(source='negotiation.user', read_only=True)
|
||||
is_submitted = serializers.SerializerMethodField()
|
||||
created_at = serializers.DateTimeField(source='created_at', read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = Conversation
|
||||
model = NegotiationChat
|
||||
fields = ['id', 'user', 'is_submitted', 'created_at']
|
||||
read_only_fields = ['id', 'created_at', 'user']
|
||||
|
||||
def get_is_submitted(self, obj):
|
||||
# 尝试获取RLHF扩展
|
||||
try:
|
||||
return obj.rlhf_extension.is_submitted
|
||||
except RLHFConversation.DoesNotExist:
|
||||
return False
|
||||
|
||||
|
||||
class MessageSerializer(serializers.ModelSerializer):
|
||||
conversation = serializers.CharField(source='conversation_id', read_only=True)
|
||||
timestamp = serializers.DateTimeField(source='created_at', read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = Message
|
||||
model = ChatHistory
|
||||
fields = ['id', 'conversation', 'role', 'content', 'timestamp']
|
||||
read_only_fields = ['id', 'timestamp']
|
||||
|
||||
|
||||
class FeedbackSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Feedback
|
||||
fields = ['id', 'message', 'conversation', 'user', 'feedback_value', 'timestamp']
|
||||
fields = ['id', 'message', 'conversation_id', 'user', 'feedback_value', 'timestamp']
|
||||
read_only_fields = ['id', 'timestamp']
|
||||
|
||||
|
||||
@ -37,7 +51,11 @@ class FeedbackTagSerializer(serializers.ModelSerializer):
|
||||
class DetailedFeedbackSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = DetailedFeedback
|
||||
fields = ['id', 'message', 'conversation', 'user', 'feedback_type', 'feedback_tags', 'custom_tags', 'custom_content', 'is_inline', 'created_at', 'updated_at']
|
||||
fields = [
|
||||
'id', 'message', 'conversation_id', 'user', 'feedback_type',
|
||||
'feedback_tags', 'custom_tags', 'custom_content', 'is_inline',
|
||||
'created_at', 'updated_at'
|
||||
]
|
||||
read_only_fields = ['id', 'created_at', 'updated_at']
|
||||
|
||||
|
||||
@ -47,38 +65,72 @@ class ConversationSubmissionSerializer(serializers.ModelSerializer):
|
||||
|
||||
class Meta:
|
||||
model = ConversationSubmission
|
||||
fields = ['id', 'conversation', 'user', 'user_details', 'title', 'description', 'status', 'quality_score', 'reviewer', 'reviewer_details', 'reviewer_notes', 'submitted_at', 'reviewed_at', 'created_at', 'updated_at']
|
||||
fields = [
|
||||
'id', 'conversation_id', 'user', 'title', 'description',
|
||||
'status', 'quality_score', 'reviewer', 'reviewer_details',
|
||||
'reviewer_notes', 'submitted_at', 'reviewed_at', 'created_at', 'updated_at'
|
||||
]
|
||||
read_only_fields = ['id', 'submitted_at', 'reviewed_at', 'created_at', 'updated_at']
|
||||
|
||||
|
||||
class ConversationEvaluationSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ConversationEvaluation
|
||||
fields = ['id', 'conversation', 'user', 'overall_feeling', 'has_logical_issues', 'needs_satisfied', 'created_at', 'updated_at']
|
||||
fields = [
|
||||
'id', 'conversation_id', 'user', 'overall_feeling',
|
||||
'has_logical_issues', 'needs_satisfied', 'created_at', 'updated_at'
|
||||
]
|
||||
read_only_fields = ['id', 'created_at', 'updated_at']
|
||||
|
||||
|
||||
class SystemConfigSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = SystemConfig
|
||||
fields = ['id', 'config_key', 'config_value', 'config_type', 'description', 'created_at', 'updated_at']
|
||||
fields = [
|
||||
'id', 'config_key', 'config_value', 'config_type',
|
||||
'description', 'created_at', 'updated_at'
|
||||
]
|
||||
read_only_fields = ['id', 'created_at', 'updated_at']
|
||||
|
||||
|
||||
class ConversationWithMessagesSerializer(serializers.ModelSerializer):
|
||||
messages = MessageSerializer(many=True, read_only=True)
|
||||
id = serializers.CharField(source='conversation_id', read_only=True)
|
||||
user = serializers.PrimaryKeyRelatedField(source='negotiation.user', read_only=True)
|
||||
is_submitted = serializers.SerializerMethodField()
|
||||
created_at = serializers.DateTimeField(source='created_at', read_only=True)
|
||||
messages = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = Conversation
|
||||
model = NegotiationChat
|
||||
fields = ['id', 'user', 'is_submitted', 'created_at', 'messages']
|
||||
read_only_fields = ['id', 'created_at']
|
||||
|
||||
def get_is_submitted(self, obj):
|
||||
try:
|
||||
return obj.rlhf_extension.is_submitted
|
||||
except RLHFConversation.DoesNotExist:
|
||||
return False
|
||||
|
||||
def get_messages(self, obj):
|
||||
messages = ChatHistory.objects.filter(
|
||||
conversation_id=obj.conversation_id
|
||||
).order_by('created_at')
|
||||
return MessageSerializer(messages, many=True).data
|
||||
|
||||
|
||||
class MessageWithFeedbackSerializer(serializers.ModelSerializer):
|
||||
feedback = FeedbackSerializer(many=True, read_only=True)
|
||||
detailed_feedback = DetailedFeedbackSerializer(many=True, read_only=True)
|
||||
conversation = serializers.CharField(source='conversation_id', read_only=True)
|
||||
timestamp = serializers.DateTimeField(source='created_at', read_only=True)
|
||||
feedback = serializers.SerializerMethodField()
|
||||
detailed_feedback = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = Message
|
||||
model = ChatHistory
|
||||
fields = ['id', 'conversation', 'role', 'content', 'timestamp', 'feedback', 'detailed_feedback']
|
||||
read_only_fields = ['id', 'timestamp']
|
||||
|
||||
def get_feedback(self, obj):
|
||||
feedback = Feedback.objects.filter(message=obj)
|
||||
return FeedbackSerializer(feedback, many=True).data
|
||||
|
||||
def get_detailed_feedback(self, obj):
|
||||
detailed_feedback = DetailedFeedback.objects.filter(message=obj)
|
||||
return DetailedFeedbackSerializer(detailed_feedback, many=True).data
|
@ -5,8 +5,9 @@ from rest_framework.decorators import action
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from .models import (
|
||||
Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
|
||||
ConversationSubmission, ConversationEvaluation, SystemConfig
|
||||
Feedback, FeedbackTag, DetailedFeedback,
|
||||
ConversationSubmission, ConversationEvaluation, SystemConfig,
|
||||
NegotiationChat, ChatHistory, RLHFConversation, CreatorProfile, Product
|
||||
)
|
||||
from .serializers import (
|
||||
ConversationSerializer, MessageSerializer, FeedbackSerializer,
|
||||
@ -98,22 +99,44 @@ class StandardResponseMixin:
|
||||
|
||||
|
||||
class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
queryset = Conversation.objects.all()
|
||||
queryset = NegotiationChat.objects.all()
|
||||
serializer_class = ConversationSerializer
|
||||
authentication_classes = [CustomTokenAuthentication]
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get_queryset(self):
|
||||
user = self.request.user
|
||||
return Conversation.objects.filter(user=user).order_by('-created_at')
|
||||
return NegotiationChat.objects.filter(negotiation__user=user).order_by('-updated_at')
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save(user=self.request.user)
|
||||
creator = CreatorProfile.objects.first()
|
||||
product = Product.objects.first()
|
||||
|
||||
negotiation = Negotiation.objects.create(
|
||||
user=self.request.user,
|
||||
status='active'
|
||||
)
|
||||
|
||||
chat = NegotiationChat.objects.create(
|
||||
negotiation=negotiation,
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
creator=creator,
|
||||
product=product
|
||||
)
|
||||
|
||||
RLHFConversation.objects.create(
|
||||
negotiation_chat=chat,
|
||||
is_submitted=False
|
||||
)
|
||||
|
||||
serializer.instance = chat
|
||||
|
||||
@action(detail=True, methods=['get'])
|
||||
def messages(self, request, pk=None):
|
||||
conversation = self.get_object()
|
||||
messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
|
||||
messages = ChatHistory.objects.filter(
|
||||
conversation_id=conversation.conversation_id
|
||||
).order_by('created_at')
|
||||
serializer = MessageSerializer(messages, many=True)
|
||||
return self.get_standard_response(data=serializer.data)
|
||||
|
||||
@ -130,27 +153,26 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 创建用户消息
|
||||
user_message = Message.objects.create(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation=conversation,
|
||||
knowledge_base = KnowledgeBase.objects.first()
|
||||
|
||||
user_message = ChatHistory.objects.create(
|
||||
user=request.user,
|
||||
knowledge_base=knowledge_base,
|
||||
conversation_id=conversation.conversation_id,
|
||||
role='user',
|
||||
content=content
|
||||
)
|
||||
|
||||
# 这里需要调用AI服务获取回复
|
||||
# 示例:调用SiliconFlow或其他AI服务
|
||||
ai_response = self._generate_ai_response(user_message.content, conversation)
|
||||
ai_response = self._generate_ai_response(content, conversation)
|
||||
|
||||
# 创建AI回复消息
|
||||
ai_message = Message.objects.create(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation=conversation,
|
||||
ai_message = ChatHistory.objects.create(
|
||||
user=request.user,
|
||||
knowledge_base=knowledge_base,
|
||||
conversation_id=conversation.conversation_id,
|
||||
role='assistant',
|
||||
content=ai_response
|
||||
)
|
||||
|
||||
# 更新用户的标注统计
|
||||
self._update_annotation_stats(request.user.id)
|
||||
|
||||
messages = [
|
||||
@ -166,7 +188,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
title = request.data.get('title', '')
|
||||
description = request.data.get('description', '')
|
||||
|
||||
if conversation.is_submitted:
|
||||
rlhf_conv, created = RLHFConversation.objects.get_or_create(
|
||||
negotiation_chat=conversation,
|
||||
defaults={'is_submitted': False}
|
||||
)
|
||||
|
||||
if rlhf_conv.is_submitted:
|
||||
return self.get_standard_response(
|
||||
code=400,
|
||||
message='该对话已提交',
|
||||
@ -174,14 +201,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 更新对话为已提交状态
|
||||
conversation.is_submitted = True
|
||||
conversation.save()
|
||||
rlhf_conv.is_submitted = True
|
||||
rlhf_conv.save()
|
||||
|
||||
# 创建提交记录
|
||||
submission = ConversationSubmission.objects.create(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation=conversation,
|
||||
conversation_id=conversation.conversation_id,
|
||||
user=request.user,
|
||||
title=title,
|
||||
description=description,
|
||||
@ -189,12 +214,11 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
submitted_at=timezone.now()
|
||||
)
|
||||
|
||||
# 记录活动日志
|
||||
UserActivityLog.objects.create(
|
||||
user=request.user,
|
||||
action_type='conversation_submit',
|
||||
target_type='conversation',
|
||||
target_id=str(conversation.id),
|
||||
target_id=conversation.conversation_id,
|
||||
details={'title': title}
|
||||
)
|
||||
|
||||
@ -207,7 +231,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
def resume(self, request, pk=None):
|
||||
conversation = self.get_object()
|
||||
|
||||
if not conversation.is_submitted:
|
||||
rlhf_conv, created = RLHFConversation.objects.get_or_create(
|
||||
negotiation_chat=conversation,
|
||||
defaults={'is_submitted': False}
|
||||
)
|
||||
|
||||
if not rlhf_conv.is_submitted:
|
||||
return self.get_standard_response(
|
||||
code=400,
|
||||
message='该对话未提交,无需恢复',
|
||||
@ -215,25 +244,22 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 更新对话为未提交状态
|
||||
conversation.is_submitted = False
|
||||
conversation.save()
|
||||
rlhf_conv.is_submitted = False
|
||||
rlhf_conv.save()
|
||||
|
||||
# 获取最新的提交记录
|
||||
submission = ConversationSubmission.objects.filter(
|
||||
conversation=conversation
|
||||
conversation_id=conversation.conversation_id
|
||||
).order_by('-submitted_at').first()
|
||||
|
||||
if submission and submission.status == 'submitted':
|
||||
submission.status = 'rejected'
|
||||
submission.save()
|
||||
|
||||
# 记录活动日志
|
||||
UserActivityLog.objects.create(
|
||||
user=request.user,
|
||||
action_type='conversation_resume',
|
||||
target_type='conversation',
|
||||
target_id=str(conversation.id)
|
||||
target_id=conversation.conversation_id
|
||||
)
|
||||
|
||||
return self.get_standard_response(
|
||||
@ -264,7 +290,7 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
sf_client.set_system_message(system_prompt_config.config_value)
|
||||
|
||||
# 获取历史消息作为上下文
|
||||
history_messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
|
||||
history_messages = ChatHistory.objects.filter(conversation_id=conversation.conversation_id).order_by('created_at')
|
||||
|
||||
# 添加历史消息到客户端
|
||||
for msg in history_messages:
|
||||
@ -385,28 +411,28 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
|
||||
def _get_recent_conversations(self, user_id, limit=5):
|
||||
"""获取用户最近的对话"""
|
||||
conversations = Conversation.objects.filter(
|
||||
user_id=user_id
|
||||
).order_by('-created_at')[:limit]
|
||||
conversations = NegotiationChat.objects.filter(
|
||||
negotiation__user_id=user_id
|
||||
).order_by('-updated_at')[:limit]
|
||||
|
||||
result = []
|
||||
for conv in conversations:
|
||||
# 获取最后一条消息内容作为对话摘要
|
||||
last_message = Message.objects.filter(
|
||||
conversation_id=conv.id
|
||||
).order_by('-timestamp').first()
|
||||
last_message = ChatHistory.objects.filter(
|
||||
conversation_id=conv.conversation_id
|
||||
).order_by('-created_at').first()
|
||||
|
||||
# 统计消息数
|
||||
message_count = Message.objects.filter(conversation_id=conv.id).count()
|
||||
message_count = ChatHistory.objects.filter(conversation_id=conv.conversation_id).count()
|
||||
|
||||
# 统计反馈数
|
||||
feedback_count = Feedback.objects.filter(conversation_id=conv.id).count()
|
||||
detailed_count = DetailedFeedback.objects.filter(conversation_id=conv.id).count()
|
||||
feedback_count = Feedback.objects.filter(conversation_id=conv.conversation_id).count()
|
||||
detailed_count = DetailedFeedback.objects.filter(conversation_id=conv.conversation_id).count()
|
||||
|
||||
result.append({
|
||||
'id': str(conv.id),
|
||||
'created_at': conv.created_at.isoformat(),
|
||||
'is_submitted': conv.is_submitted,
|
||||
'id': str(conv.conversation_id),
|
||||
'created_at': conv.updated_at.isoformat(),
|
||||
'is_submitted': conv.negotiation.is_submitted,
|
||||
'message_count': message_count,
|
||||
'feedback_count': feedback_count + detailed_count,
|
||||
'summary': last_message.content[:100] + "..." if last_message and len(last_message.content) > 100 else (last_message.content if last_message else "")
|
||||
@ -416,12 +442,12 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
|
||||
def _get_conversation_stats(self, user_id):
|
||||
"""获取对话统计"""
|
||||
total_conversations = Conversation.objects.filter(user_id=user_id).count()
|
||||
submitted_conversations = Conversation.objects.filter(user_id=user_id, is_submitted=True).count()
|
||||
total_conversations = NegotiationChat.objects.filter(negotiation__user_id=user_id).count()
|
||||
submitted_conversations = NegotiationChat.objects.filter(negotiation__user_id=user_id, negotiation__is_submitted=True).count()
|
||||
|
||||
# 对话消息统计
|
||||
message_stats = Message.objects.filter(
|
||||
conversation__user_id=user_id
|
||||
message_stats = ChatHistory.objects.filter(
|
||||
conversation__negotiation__user_id=user_id
|
||||
).aggregate(
|
||||
total=Count('id'),
|
||||
user_messages=Count('id', filter=Q(role='user')),
|
||||
@ -620,14 +646,21 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
|
||||
|
||||
class MessageViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
queryset = Message.objects.all()
|
||||
queryset = ChatHistory.objects.all()
|
||||
serializer_class = MessageSerializer
|
||||
authentication_classes = [CustomTokenAuthentication]
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get_queryset(self):
|
||||
user = self.request.user
|
||||
return Message.objects.filter(conversation__user=user).order_by('timestamp')
|
||||
# 获取用户参与的所有对话的ID
|
||||
user_conversation_ids = NegotiationChat.objects.filter(
|
||||
negotiation__user=user
|
||||
).values_list('conversation_id', flat=True)
|
||||
# 筛选这些对话中的消息
|
||||
return ChatHistory.objects.filter(
|
||||
conversation_id__in=user_conversation_ids
|
||||
).order_by('created_at')
|
||||
|
||||
|
||||
class FeedbackViewSet(StandardResponseMixin, viewsets.ModelViewSet):
|
||||
|
Loading…
Reference in New Issue
Block a user