289 lines
12 KiB
Python
289 lines
12 KiB
Python
import json
import os
import uuid

from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand, CommandError
from django.db import transaction
from django.db.models import Prefetch
from django.utils import timezone
from django.utils.dateparse import parse_datetime

from rlhf.models import (
    Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
    ConversationSubmission, ConversationEvaluation, SystemConfig
)
|
|
|
|
User = get_user_model()
|
|
|
|
class Command(BaseCommand):
    """Management command that imports and exports RLHF data as JSON.

    Exactly one action runs per invocation, chosen in this order of
    precedence: ``--import-file``, ``--export-file``, ``--import-tags``.
    ``--clear`` is only honoured together with ``--import-file``.
    """

    help = '导入/导出RLHF数据'

    def add_arguments(self, parser):
        """Register the command-line options of the command."""
        parser.add_argument(
            '--import-file',
            help='导入JSON数据文件的路径',
        )
        parser.add_argument(
            '--export-file',
            help='导出JSON数据文件的路径',
        )
        parser.add_argument(
            '--import-tags',
            action='store_true',
            help='从init_tags.py导入标签',
        )
        parser.add_argument(
            '--clear',
            action='store_true',
            help='在导入前清除现有数据',
        )

    def handle(self, *args, **options):
        """Dispatch to the first action requested on the command line."""
        if options['import_file']:
            self.import_data(options['import_file'], options['clear'])
        elif options['export_file']:
            self.export_data(options['export_file'])
        elif options['import_tags']:
            self.import_tags_from_init_file()
        else:
            self.stdout.write(self.style.WARNING('请指定导入文件路径或导出文件路径'))

    def import_data(self, file_path, clear=False):
        """Import conversations and system configs from a JSON file.

        The optional clearing step and the whole import run inside a single
        database transaction, so a failure part-way through rolls everything
        back instead of leaving the database wiped or half-populated.

        Only the ``conversations`` and ``system_configs`` keys are read; the
        ``tags`` section written by ``export_data`` is not restored here.

        Args:
            file_path: Path of the JSON file to read.
            clear: When true, delete the existing data before importing.

        Raises:
            CommandError: If the file does not exist or the import fails.
        """
        if not os.path.exists(file_path):
            raise CommandError(f'文件不存在: {file_path}')

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            with transaction.atomic():
                if clear:
                    self.stdout.write(self.style.WARNING('正在清除现有数据...'))
                    self._clear_data()

                # Import conversations (with their messages and feedback).
                for conv_data in data.get('conversations', []):
                    self._import_conversation(conv_data)

                # Import system configuration entries.
                for config in data.get('system_configs', []):
                    self._import_system_config(config)

            self.stdout.write(self.style.SUCCESS(f'成功导入数据从: {file_path}'))

        except Exception as e:
            # Top-level boundary: report any failure as a CommandError while
            # keeping the original exception as the cause.
            raise CommandError(f'导入数据失败: {str(e)}') from e

    def export_data(self, file_path):
        """Export tags, system configs and conversations to a JSON file.

        Args:
            file_path: Path of the JSON file to write (UTF-8, indented).

        Raises:
            CommandError: If anything fails while collecting or writing data.
        """
        try:
            data = {
                'conversations': [],
                'system_configs': [],
                'tags': [],
                'export_time': timezone.now().isoformat()
            }

            # Export tags.
            data['tags'] = [
                {
                    'id': str(tag.id),
                    'tag_name': tag.tag_name,
                    'tag_type': tag.tag_type,
                    'description': tag.description,
                    'created_at': tag.created_at.isoformat()
                }
                for tag in FeedbackTag.objects.all()
            ]

            # Export system configuration entries.
            data['system_configs'] = [
                {
                    'id': str(config.id),
                    'config_key': config.config_key,
                    'config_value': config.config_value,
                    'config_type': config.config_type,
                    'description': config.description,
                    'created_at': config.created_at.isoformat(),
                    'updated_at': config.updated_at.isoformat()
                }
                for config in SystemConfig.objects.all()
            ]

            # Export conversations. The ordering lives inside the Prefetch:
            # calling .order_by() on the prefetched manager afterwards would
            # discard the prefetch cache and issue one query per conversation.
            ordered_messages = Prefetch(
                'messages',
                queryset=Message.objects.order_by('timestamp'),
            )
            for conv in Conversation.objects.prefetch_related(ordered_messages):
                data['conversations'].append(self._serialize_conversation(conv))

            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)

            self.stdout.write(self.style.SUCCESS(f'成功导出数据到: {file_path}'))

        except Exception as e:
            raise CommandError(f'导出数据失败: {str(e)}') from e

    def _serialize_conversation(self, conv):
        """Return the export dict for one conversation and its messages."""
        messages = list(conv.messages.all())
        message_ids = [msg.id for msg in messages]

        # Two queries per conversation instead of two per message.
        basic_by_message = self._group_by_message(
            Feedback.objects.filter(message_id__in=message_ids)
        )
        detailed_by_message = self._group_by_message(
            DetailedFeedback.objects.filter(message_id__in=message_ids)
        )

        conv_data = {
            'id': str(conv.id),
            'user_id': str(conv.user_id),
            'is_submitted': conv.is_submitted,
            'created_at': conv.created_at.isoformat(),
            'messages': []
        }

        for msg in messages:
            # Basic feedback entries come first, then detailed ones.
            feedback = [
                self._serialize_basic_feedback(fb)
                for fb in basic_by_message.get(msg.id, [])
            ]
            feedback.extend(
                self._serialize_detailed_feedback(dfb)
                for dfb in detailed_by_message.get(msg.id, [])
            )
            conv_data['messages'].append({
                'id': str(msg.id),
                'role': msg.role,
                'content': msg.content,
                'timestamp': msg.timestamp.isoformat(),
                'feedback': feedback
            })

        return conv_data

    @staticmethod
    def _group_by_message(rows):
        """Group feedback rows by ``message_id``, keeping queryset order."""
        grouped = {}
        for row in rows:
            grouped.setdefault(row.message_id, []).append(row)
        return grouped

    @staticmethod
    def _serialize_basic_feedback(fb):
        """Return the export dict for a ``Feedback`` row."""
        return {
            'id': str(fb.id),
            'type': 'basic',
            'feedback_value': fb.feedback_value,
            'user_id': str(fb.user_id),
            'timestamp': fb.timestamp.isoformat()
        }

    @staticmethod
    def _serialize_detailed_feedback(dfb):
        """Return the export dict for a ``DetailedFeedback`` row."""
        return {
            'id': str(dfb.id),
            'type': 'detailed',
            'feedback_type': dfb.feedback_type,
            'feedback_tags': dfb.feedback_tags,
            'custom_tags': dfb.custom_tags,
            'custom_content': dfb.custom_content,
            'is_inline': dfb.is_inline,
            'user_id': str(dfb.user_id),
            'created_at': dfb.created_at.isoformat(),
            'updated_at': dfb.updated_at.isoformat()
        }

    def import_tags_from_init_file(self):
        """Create the built-in positive and negative feedback tags.

        The tag definitions are hard-coded below; no external file is read.
        Tags whose ``tag_name`` already exists are left untouched.

        Raises:
            CommandError: If creating any tag fails.
        """
        # Tag definitions: (tag_type, [(tag_name, description), ...]).
        tag_groups = (
            ('positive', [
                ('有帮助', '回答对问题有实际帮助'),
                ('准确', '信息准确可靠'),
                ('清晰', '表达清楚易懂'),
                ('完整', '回答全面完整'),
                ('友好', '语调友善亲切'),
                ('创新', '提供了新颖的观点'),
            ]),
            ('negative', [
                ('不准确', '包含错误信息'),
                ('不相关', '回答偏离主题'),
                ('不完整', '回答过于简略'),
                ('不清晰', '表达模糊难懂'),
                ('不友好', '语调生硬冷淡'),
                ('重复', '内容重复冗余'),
            ]),
        )

        try:
            for tag_type, tags in tag_groups:
                for tag_name, description in tags:
                    FeedbackTag.objects.get_or_create(
                        tag_name=tag_name,
                        defaults={
                            'id': str(uuid.uuid4()),
                            'tag_type': tag_type,
                            'description': description,
                            'created_at': timezone.now()
                        }
                    )

            self.stdout.write(self.style.SUCCESS('成功导入标签'))

        except Exception as e:
            raise CommandError(f'导入标签失败: {str(e)}') from e

    def _clear_data(self):
        """Delete all feedback, messages, conversations, submissions and evaluations.

        Feedback tags and system configs are not deleted.
        """
        DetailedFeedback.objects.all().delete()
        Feedback.objects.all().delete()
        Message.objects.all().delete()
        Conversation.objects.all().delete()
        ConversationSubmission.objects.all().delete()
        ConversationEvaluation.objects.all().delete()
        self.stdout.write(self.style.SUCCESS('已清除现有数据'))

    @staticmethod
    def _parse_timestamp(value):
        """Parse an ISO-8601 string into a datetime, defaulting to now.

        Args:
            value: ISO-8601 datetime string, or ``None`` when the field is
                absent from the payload.

        Returns:
            The parsed datetime, or ``timezone.now()`` when ``value`` is None.

        Raises:
            CommandError: If ``value`` is present but is not a valid datetime.
        """
        if value is None:
            return timezone.now()
        try:
            parsed = parse_datetime(value)
        except (TypeError, ValueError) as exc:
            raise CommandError(f'无效的时间格式: {value}') from exc
        if parsed is None:
            # parse_datetime returns None for strings that are not well formatted.
            raise CommandError(f'无效的时间格式: {value}')
        return parsed

    def _resolve_user_id(self, user_id):
        """Return ``user_id`` if that user exists, else the first superuser's id.

        Raises:
            CommandError: If the user is missing and no superuser exists.
        """
        if User.objects.filter(id=user_id).exists():
            return user_id

        self.stdout.write(self.style.WARNING(f'用户不存在: {user_id},将使用第一个管理员用户'))
        user = User.objects.filter(is_superuser=True).first()
        if not user:
            raise CommandError('找不到管理员用户')
        return user.id

    def _import_conversation(self, conv_data):
        """Create one conversation with its messages and feedback rows.

        Args:
            conv_data: Dict in the shape produced by ``export_data`` for a
                single conversation. Missing ids get a fresh UUID and missing
                timestamps default to the current time.

        Raises:
            CommandError: If a referenced user is missing and no superuser
                exists, or if a timestamp cannot be parsed.
        """
        # Make sure the owner exists, falling back to a superuser otherwise.
        raw_owner_id = conv_data.get('user_id')
        owner_id = self._resolve_user_id(raw_owner_id)
        # Cache lookups so each distinct user is checked (and warned about)
        # only once per conversation.
        known_users = {raw_owner_id: owner_id}

        def feedback_user_id(fb_data):
            """Resolve the user for a feedback row, like the owner above."""
            if 'user_id' not in fb_data:
                return owner_id
            raw_id = fb_data['user_id']
            if raw_id is None:
                # An explicit null is passed through unchanged.
                return None
            if raw_id not in known_users:
                known_users[raw_id] = self._resolve_user_id(raw_id)
            return known_users[raw_id]

        # Create the conversation.
        conv = Conversation.objects.create(
            id=conv_data.get('id', str(uuid.uuid4())),
            user_id=owner_id,
            is_submitted=conv_data.get('is_submitted', False),
            created_at=self._parse_timestamp(conv_data.get('created_at'))
        )

        # Create the messages.
        for msg_data in conv_data.get('messages', []):
            msg = Message.objects.create(
                id=msg_data.get('id', str(uuid.uuid4())),
                conversation=conv,
                role=msg_data.get('role', 'user'),
                content=msg_data.get('content', ''),
                timestamp=self._parse_timestamp(msg_data.get('timestamp'))
            )

            # Create the feedback; entries of any other type are ignored.
            for fb_data in msg_data.get('feedback', []):
                fb_type = fb_data.get('type')
                if fb_type == 'basic':
                    Feedback.objects.create(
                        id=fb_data.get('id', str(uuid.uuid4())),
                        message=msg,
                        conversation=conv,
                        user_id=feedback_user_id(fb_data),
                        feedback_value=fb_data.get('feedback_value', 0),
                        timestamp=self._parse_timestamp(fb_data.get('timestamp'))
                    )
                elif fb_type == 'detailed':
                    DetailedFeedback.objects.create(
                        id=fb_data.get('id', str(uuid.uuid4())),
                        message=msg,
                        conversation=conv,
                        user_id=feedback_user_id(fb_data),
                        feedback_type=fb_data.get('feedback_type', 'neutral'),
                        feedback_tags=fb_data.get('feedback_tags', '[]'),
                        custom_tags=fb_data.get('custom_tags', ''),
                        custom_content=fb_data.get('custom_content', ''),
                        is_inline=fb_data.get('is_inline', True),
                        created_at=self._parse_timestamp(fb_data.get('created_at')),
                        updated_at=self._parse_timestamp(fb_data.get('updated_at'))
                    )

    def _import_system_config(self, config_data):
        """Create or update the system config identified by ``config_key``.

        An existing row keeps its primary key; only its value fields are
        overwritten. The ``id`` from the payload (or a fresh UUID) is used
        only when a new row is created.

        Args:
            config_data: Dict in the shape produced by ``export_data`` for a
                single system config.

        Raises:
            CommandError: If ``config_key`` is missing or a timestamp cannot
                be parsed.
        """
        config_key = config_data.get('config_key')
        if not config_key:
            raise CommandError('系统配置缺少 config_key')

        values = {
            'config_value': config_data.get('config_value', ''),
            'config_type': config_data.get('config_type', 'string'),
            'description': config_data.get('description', ''),
            'created_at': self._parse_timestamp(config_data.get('created_at')),
            'updated_at': self._parse_timestamp(config_data.get('updated_at'))
        }

        existing = SystemConfig.objects.filter(config_key=config_key).first()
        if existing is None:
            SystemConfig.objects.create(
                id=config_data.get('id') or str(uuid.uuid4()),
                config_key=config_key,
                **values
            )
            return

        # Never reassign the primary key of an existing row: saving with a
        # changed pk would insert a second row instead of updating this one.
        for field_name, field_value in values.items():
            setattr(existing, field_name, field_value)
        existing.save()