"""Management command for importing and exporting RLHF data as JSON."""
import json
import os
import uuid

from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand, CommandError
from django.db import transaction
from django.db.models import Prefetch
from django.utils import timezone
from django.utils.dateparse import parse_datetime

from rlhf.models import (
    Conversation,
    Message,
    Feedback,
    FeedbackTag,
    DetailedFeedback,
    ConversationSubmission,
    ConversationEvaluation,
    SystemConfig
)

User = get_user_model()


class Command(BaseCommand):
    """Import RLHF data from, or export it to, a JSON file.

    Exactly one action runs per invocation, chosen in this priority order:
    ``--import-file``, ``--export-file``, ``--import-tags``.
    """

    help = '导入/导出RLHF数据'

    def add_arguments(self, parser):
        """Register the command-line options understood by this command."""
        parser.add_argument(
            '--import-file',
            help='导入JSON数据文件的路径',
        )
        parser.add_argument(
            '--export-file',
            help='导出JSON数据文件的路径',
        )
        parser.add_argument(
            '--import-tags',
            action='store_true',
            help='从init_tags.py导入标签',
        )
        parser.add_argument(
            '--clear',
            action='store_true',
            help='在导入前清除现有数据',
        )

    def handle(self, *args, **options):
        """Dispatch to the import, export or tag-seeding action."""
        if options['import_file']:
            self.import_data(options['import_file'], options['clear'])
        elif options['export_file']:
            self.export_data(options['export_file'])
        elif options['import_tags']:
            self.import_tags_from_init_file()
        else:
            self.stdout.write(self.style.WARNING('请指定导入文件路径或导出文件路径'))

    def import_data(self, file_path, clear=False):
        """Import tags, conversations and system configs from a JSON file.

        Args:
            file_path: Path of the JSON file to read.
            clear: When true, delete existing conversation data first.

        Raises:
            CommandError: If the file does not exist or anything fails while
                reading or importing it.
        """
        if not os.path.exists(file_path):
            raise CommandError(f'文件不存在: {file_path}')

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Clear and import atomically: a failure half-way through must
            # not leave the database wiped or partially populated.
            with transaction.atomic():
                if clear:
                    self.stdout.write(self.style.WARNING('正在清除现有数据...'))
                    self._clear_data()

                # Tags are written by export_data, so read them back too.
                for tag_data in data.get('tags', []):
                    self._import_tag(tag_data)

                for conv_data in data.get('conversations', []):
                    self._import_conversation(conv_data)

                for config in data.get('system_configs', []):
                    self._import_system_config(config)

            self.stdout.write(self.style.SUCCESS(f'成功导入数据从: {file_path}'))
        except Exception as e:
            raise CommandError(f'导入数据失败: {str(e)}') from e

    def export_data(self, file_path):
        """Export tags, system configs and conversations to a JSON file.

        Args:
            file_path: Path of the JSON file to write (UTF-8, indented).

        Raises:
            CommandError: If anything fails while reading or writing.
        """
        try:
            data = {
                'conversations': [],
                'system_configs': [],
                'tags': [],
                'export_time': timezone.now().isoformat()
            }

            for tag in FeedbackTag.objects.all():
                data['tags'].append({
                    'id': str(tag.id),
                    'tag_name': tag.tag_name,
                    'tag_type': tag.tag_type,
                    'description': tag.description,
                    'created_at': tag.created_at.isoformat()
                })

            for config in SystemConfig.objects.all():
                data['system_configs'].append({
                    'id': str(config.id),
                    'config_key': config.config_key,
                    'config_value': config.config_value,
                    'config_type': config.config_type,
                    'description': config.description,
                    'created_at': config.created_at.isoformat(),
                    'updated_at': config.updated_at.isoformat()
                })

            # Order inside the Prefetch: calling .order_by() on the related
            # manager afterwards would bypass the prefetch cache and issue
            # one extra query per conversation.
            ordered_messages = Prefetch(
                'messages',
                queryset=Message.objects.order_by('timestamp'),
            )
            for conv in Conversation.objects.prefetch_related(ordered_messages):
                data['conversations'].append({
                    'id': str(conv.id),
                    'user_id': str(conv.user_id),
                    'is_submitted': conv.is_submitted,
                    'created_at': conv.created_at.isoformat(),
                    'messages': [
                        self._serialize_message(msg)
                        for msg in conv.messages.all()
                    ]
                })

            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)

            self.stdout.write(self.style.SUCCESS(f'成功导出数据到: {file_path}'))
        except Exception as e:
            raise CommandError(f'导出数据失败: {str(e)}') from e

    def import_tags_from_init_file(self):
        """Create the built-in positive and negative feedback tags.

        Tags are matched by ``tag_name``; existing tags are left untouched.

        Raises:
            CommandError: If creating any tag fails.
        """
        try:
            tags_by_type = {
                'positive': [
                    ('有帮助', '回答对问题有实际帮助'),
                    ('准确', '信息准确可靠'),
                    ('清晰', '表达清楚易懂'),
                    ('完整', '回答全面完整'),
                    ('友好', '语调友善亲切'),
                    ('创新', '提供了新颖的观点')
                ],
                'negative': [
                    ('不准确', '包含错误信息'),
                    ('不相关', '回答偏离主题'),
                    ('不完整', '回答过于简略'),
                    ('不清晰', '表达模糊难懂'),
                    ('不友好', '语调生硬冷淡'),
                    ('重复', '内容重复冗余')
                ],
            }

            for tag_type, tags in tags_by_type.items():
                for tag_name, description in tags:
                    FeedbackTag.objects.get_or_create(
                        tag_name=tag_name,
                        defaults={
                            'id': str(uuid.uuid4()),
                            'tag_type': tag_type,
                            'description': description,
                            'created_at': timezone.now()
                        }
                    )

            self.stdout.write(self.style.SUCCESS('成功导入标签'))
        except Exception as e:
            raise CommandError(f'导入标签失败: {str(e)}') from e

    @staticmethod
    def _parse_timestamp(value):
        """Parse an ISO 8601 string, defaulting to the current time.

        Args:
            value: ISO 8601 datetime string, or a missing/empty value.

        Returns:
            The parsed datetime, or ``timezone.now()`` when ``value`` is
            missing, ``None`` or empty.

        Raises:
            CommandError: If ``value`` is a string that is not a datetime.
        """
        if not value:
            return timezone.now()
        parsed = parse_datetime(value)
        if parsed is None:
            # parse_datetime signals a badly formatted string with None.
            raise CommandError(f'无法解析时间: {value}')
        return parsed

    @staticmethod
    def _serialize_message(msg):
        """Return the export dict for one message, including its feedback."""
        feedback = []

        # Basic feedback first, then detailed feedback.
        for fb in Feedback.objects.filter(message_id=msg.id):
            feedback.append({
                'id': str(fb.id),
                'type': 'basic',
                'feedback_value': fb.feedback_value,
                'user_id': str(fb.user_id),
                'timestamp': fb.timestamp.isoformat()
            })

        for dfb in DetailedFeedback.objects.filter(message_id=msg.id):
            feedback.append({
                'id': str(dfb.id),
                'type': 'detailed',
                'feedback_type': dfb.feedback_type,
                'feedback_tags': dfb.feedback_tags,
                'custom_tags': dfb.custom_tags,
                'custom_content': dfb.custom_content,
                'is_inline': dfb.is_inline,
                'user_id': str(dfb.user_id),
                'created_at': dfb.created_at.isoformat(),
                'updated_at': dfb.updated_at.isoformat()
            })

        return {
            'id': str(msg.id),
            'role': msg.role,
            'content': msg.content,
            'timestamp': msg.timestamp.isoformat(),
            'feedback': feedback
        }

    def _clear_data(self):
        """Delete all feedback, messages, conversations and their records."""
        DetailedFeedback.objects.all().delete()
        Feedback.objects.all().delete()
        Message.objects.all().delete()
        Conversation.objects.all().delete()
        ConversationSubmission.objects.all().delete()
        ConversationEvaluation.objects.all().delete()
        self.stdout.write(self.style.SUCCESS('已清除现有数据'))

    def _import_tag(self, tag_data):
        """Create one exported feedback tag unless its name already exists."""
        tag_name = tag_data.get('tag_name')
        if not tag_name:
            # Nothing to match on; skip rather than create a nameless tag.
            return
        FeedbackTag.objects.get_or_create(
            tag_name=tag_name,
            defaults={
                'id': tag_data.get('id', str(uuid.uuid4())),
                'tag_type': tag_data.get('tag_type'),
                'description': tag_data.get('description', ''),
                'created_at': self._parse_timestamp(tag_data.get('created_at'))
            }
        )

    def _import_conversation(self, conv_data):
        """Create one conversation with its messages and feedback.

        If the conversation's user does not exist, the first superuser is
        used as the owner instead.

        Raises:
            CommandError: If the user is missing and no superuser exists.
        """
        user_id = conv_data.get('user_id')
        if not User.objects.filter(id=user_id).exists():
            self.stdout.write(self.style.WARNING(f'用户不存在: {user_id},将使用第一个管理员用户'))
            user = User.objects.filter(is_superuser=True).first()
            if not user:
                raise CommandError('找不到管理员用户')
            user_id = user.id

        conv = Conversation.objects.create(
            id=conv_data.get('id', str(uuid.uuid4())),
            user_id=user_id,
            is_submitted=conv_data.get('is_submitted', False),
            created_at=self._parse_timestamp(conv_data.get('created_at'))
        )

        for msg_data in conv_data.get('messages', []):
            msg = Message.objects.create(
                id=msg_data.get('id', str(uuid.uuid4())),
                conversation=conv,
                role=msg_data.get('role', 'user'),
                content=msg_data.get('content', ''),
                timestamp=self._parse_timestamp(msg_data.get('timestamp'))
            )

            for fb_data in msg_data.get('feedback', []):
                fb_type = fb_data.get('type')
                if fb_type == 'basic':
                    Feedback.objects.create(
                        id=fb_data.get('id', str(uuid.uuid4())),
                        message=msg,
                        conversation=conv,
                        user_id=fb_data.get('user_id', user_id),
                        feedback_value=fb_data.get('feedback_value', 0),
                        timestamp=self._parse_timestamp(fb_data.get('timestamp'))
                    )
                elif fb_type == 'detailed':
                    DetailedFeedback.objects.create(
                        id=fb_data.get('id', str(uuid.uuid4())),
                        message=msg,
                        conversation=conv,
                        user_id=fb_data.get('user_id', user_id),
                        feedback_type=fb_data.get('feedback_type', 'neutral'),
                        feedback_tags=fb_data.get('feedback_tags', '[]'),
                        custom_tags=fb_data.get('custom_tags', ''),
                        custom_content=fb_data.get('custom_content', ''),
                        is_inline=fb_data.get('is_inline', True),
                        created_at=self._parse_timestamp(fb_data.get('created_at')),
                        updated_at=self._parse_timestamp(fb_data.get('updated_at'))
                    )

    def _import_system_config(self, config_data):
        """Create or update one system config, matched by ``config_key``."""
        fields = {
            'config_value': config_data.get('config_value', ''),
            'config_type': config_data.get('config_type', 'string'),
            'description': config_data.get('description', ''),
            'created_at': self._parse_timestamp(config_data.get('created_at')),
            'updated_at': self._parse_timestamp(config_data.get('updated_at'))
        }

        # The id is supplied only when creating. Reassigning the id of an
        # existing row would insert a second row instead of updating it
        # (assuming ``id`` is the primary key -- TODO confirm on the model).
        config, created = SystemConfig.objects.get_or_create(
            config_key=config_data.get('config_key'),
            defaults={
                'id': config_data.get('id', str(uuid.uuid4())),
                **fields
            }
        )
        if not created:
            for field_name, value in fields.items():
                setattr(config, field_name, value)
            config.save()