From 816d3fdb3a5e136265560d129304821bf62c67e2 Mon Sep 17 00:00:00 2001
From: wanjia
Date: Mon, 9 Jun 2025 18:00:00 +0800
Subject: [PATCH] =?UTF-8?q?=E7=9C=9F=E5=AE=9Eai=E5=AF=B9=E8=AF=9D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (ignored by `git am`; not part of the commit message):
- analyze_data.py: import Case and When, which get_feedback_stats() and
  get_daily_feedback_trend() use (previously a NameError at runtime).
- import_data.py: parse timestamps with django.utils.dateparse.parse_datetime;
  django.utils.timezone has no parse_datetime (previously an AttributeError).
  CommandError wrappers now chain the original exception with "from e".
- siliconflow_client.py: the added lines no longer fall back to a hard-coded
  API key; SILICONFLOW_API_KEY must be configured in settings. The key that
  is visible in the removed lines should be treated as leaked and rotated.
- Comments/docstrings on added lines are translated to English; runtime
  strings, context lines and removed lines are untouched, and every change
  is in-line so hunk line counts are unchanged.
- NOTE(review): indentation and blank context lines were reconstructed from
  a whitespace-collapsed copy of this patch. If `git am` rejects a hunk,
  regenerate the patch from the branch instead of hand-editing context.

 apps/rlhf/management/commands/analyze_data.py | 368 ++++++++++++++++++
 apps/rlhf/management/commands/import_data.py  | 289 ++++++++++++++
 apps/rlhf/siliconflow_client.py               |  22 +-
 apps/rlhf/views.py                            |  97 +++--
 4 files changed, 738 insertions(+), 38 deletions(-)
 create mode 100644 apps/rlhf/management/commands/analyze_data.py
 create mode 100644 apps/rlhf/management/commands/import_data.py

diff --git a/apps/rlhf/management/commands/analyze_data.py b/apps/rlhf/management/commands/analyze_data.py
new file mode 100644
index 0000000..e14a9b8
--- /dev/null
+++ b/apps/rlhf/management/commands/analyze_data.py
@@ -0,0 +1,368 @@
+from django.core.management.base import BaseCommand
+from rlhf.models import Conversation, Message, Feedback, DetailedFeedback, FeedbackTag
+from django.db.models import Count, Avg, Sum, Q, F, Case, When
+from django.utils import timezone
+from django.contrib.auth import get_user_model
+import json
+from datetime import datetime, timedelta
+
+User = get_user_model()
+
+class Command(BaseCommand):
+    help = '分析RLHF反馈数据,生成统计报告'
+
+    def add_arguments(self, parser):
+        parser.add_argument(
+            '--export',
+            action='store_true',
+            help='导出数据到JSON文件',
+        )
+        parser.add_argument(
+            '--days',
+            type=int,
+            default=30,
+            help='分析最近的天数',
+        )
+
+    def handle(self, *args, **options):
+        self.stdout.write(self.style.SUCCESS("=" * 60))
+        self.stdout.write(self.style.SUCCESS("🤖 在线人类反馈强化学习系统 - 数据分析报告"))
+        self.stdout.write(self.style.SUCCESS("=" * 60))
+
+        # Basic feedback statistics
+        feedback_stats = self.get_feedback_stats()
+        self.stdout.write(self.style.SUCCESS(f"\n📊 反馈统计:"))
+        self.stdout.write(f" 总反馈数量: {feedback_stats['total_feedback']}")
+        self.stdout.write(f" 正面反馈: {feedback_stats['positive_feedback']} ({feedback_stats['positive_rate']:.1f}%)")
+        self.stdout.write(f" 负面反馈: {feedback_stats['negative_feedback']}")
+        self.stdout.write(f" 平均反馈分数: {feedback_stats['avg_feedback']:.2f}")
+
+        # Conversation statistics
+        conv_stats = self.get_conversation_stats()
+        self.stdout.write(self.style.SUCCESS(f"\n💬 对话统计:"))
+        self.stdout.write(f" 总对话数量: {conv_stats['total_conversations']}")
+        self.stdout.write(f" 总消息数量: {conv_stats['total_messages']}")
+        self.stdout.write(f" 平均每对话消息数: {conv_stats['avg_messages_per_conversation']:.1f}")
+
+        # Tag statistics
+        tag_stats = self.get_tag_stats()
+        self.stdout.write(self.style.SUCCESS(f"\n🏷️ 标签统计:"))
+        self.stdout.write(f" 最常用的正面标签:")
+        for tag in tag_stats['top_positive']:
+            self.stdout.write(f" - {tag['tag_name']}: {tag['count']}次")
+
+        self.stdout.write(f" 最常用的负面标签:")
+        for tag in tag_stats['top_negative']:
+            self.stdout.write(f" - {tag['tag_name']}: {tag['count']}次")
+
+        # Daily trend over the requested window
+        days = options['days']
+        daily_trend = self.get_daily_feedback_trend(days)
+        self.stdout.write(self.style.SUCCESS(f"\n📈 最近{days}天反馈趋势:"))
+        for day in daily_trend:
+            self.stdout.write(f" {day['date']}: {day['total']}条反馈 (正面率: {day['positive_rate']:.1f}%)")
+
+        # User statistics
+        user_stats = self.get_user_stats()
+        self.stdout.write(self.style.SUCCESS(f"\n👥 用户统计:"))
+        self.stdout.write(f" 总用户数量: {user_stats['total_users']}")
+        self.stdout.write(f" 活跃标注用户: {user_stats['active_users']}")
+        self.stdout.write(f" 平均每用户标注量: {user_stats['avg_annotations_per_user']:.1f}")
+
+        # Optional JSON export
+        if options['export']:
+            filename = self.export_data_to_json()
+            self.stdout.write(self.style.SUCCESS(f"\n✅ 数据已导出到: {filename}"))
+
+    def get_feedback_stats(self):
+        """Return totals, positive/negative counts, average score and positive rate."""
+        # Basic feedback: the sign of feedback_value decides positive/negative
+        basic_feedback = Feedback.objects.aggregate(
+            total=Count('id'),
+            positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0)),
+            negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0)),
+            avg=Avg('feedback_value')
+        )
+
+        # Detailed feedback: classified by feedback_type
+        detailed_feedback = DetailedFeedback.objects.aggregate(
+            total=Count('id'),
+            positive=Count('id', filter=Q(feedback_type='positive')),
+            negative=Count('id', filter=Q(feedback_type='negative'))
+        )
+
+        # Merge both sources; aggregates are None on empty tables, hence "or 0"
+        total = (basic_feedback['total'] or 0) + (detailed_feedback['total'] or 0)
+        positive = (basic_feedback['positive'] or 0) + (detailed_feedback['positive'] or 0)
+        negative = (basic_feedback['negative'] or 0) + (detailed_feedback['negative'] or 0)
+
+        # Average score (basic feedback only) and positive percentage
+        avg_feedback = basic_feedback['avg'] or 0
+        positive_rate = (positive / total * 100) if total > 0 else 0
+
+        return {
+            'total_feedback': total,
+            'positive_feedback': positive,
+            'negative_feedback': negative,
+            'avg_feedback': avg_feedback,
+            'positive_rate': positive_rate
+        }
+
+    def get_conversation_stats(self):
+        """Return conversation count, message count and mean messages per conversation."""
+        total_conversations = Conversation.objects.count()
+        total_messages = Message.objects.count()
+
+        # Per-conversation message counts, then their average
+        conversation_messages = Message.objects.values('conversation').annotate(count=Count('id'))
+        avg_messages = conversation_messages.aggregate(Avg('count'))['count__avg'] or 0
+
+        return {
+            'total_conversations': total_conversations,
+            'total_messages': total_messages,
+            'avg_messages_per_conversation': avg_messages
+        }
+
+    def get_tag_stats(self):
+        """Return the five most used positive and negative feedback tags."""
+        # Tag usage is read from DetailedFeedback.feedback_tags
+        # Note: the value may be a JSON-encoded list, so it is parsed below
+
+        # Load every known tag first
+        all_tags = FeedbackTag.objects.all()
+        tag_id_to_name = {str(tag.id): tag.tag_name for tag in all_tags}
+
+        # Count how many times each tag id is used
+        tag_counts = {}
+        for feedback in DetailedFeedback.objects.all():
+            if feedback.feedback_tags:
+                try:
+                    # Try to parse a JSON list of tag ids
+                    tag_ids = json.loads(feedback.feedback_tags)
+                    if isinstance(tag_ids, list):
+                        for tag_id in tag_ids:
+                            tag_id = str(tag_id)
+                            if tag_id in tag_counts:
+                                tag_counts[tag_id] += 1
+                            else:
+                                tag_counts[tag_id] = 1
+                except (json.JSONDecodeError, TypeError):
+                    # Not valid JSON: treat the raw value as a single tag id
+                    tag_id = str(feedback.feedback_tags)
+                    if tag_id in tag_counts:
+                        tag_counts[tag_id] += 1
+                    else:
+                        tag_counts[tag_id] = 1
+
+        # Split the counted tags into positive and negative
+        positive_tags = FeedbackTag.objects.filter(tag_type='positive')
+        negative_tags = FeedbackTag.objects.filter(tag_type='negative')
+
+        top_positive = []
+        for tag in positive_tags:
+            tag_id = str(tag.id)
+            if tag_id in tag_counts:
+                top_positive.append({
+                    'tag_name': tag.tag_name,
+                    'count': tag_counts[tag_id]
+                })
+
+        top_negative = []
+        for tag in negative_tags:
+            tag_id = str(tag.id)
+            if tag_id in tag_counts:
+                top_negative.append({
+                    'tag_name': tag.tag_name,
+                    'count': tag_counts[tag_id]
+                })
+
+        # Sort by usage count, most used first
+        top_positive.sort(key=lambda x: x['count'], reverse=True)
+        top_negative.sort(key=lambda x: x['count'], reverse=True)
+
+        # Keep the top five of each
+        return {
+            'top_positive': top_positive[:5],
+            'top_negative': top_negative[:5]
+        }
+
+    def get_daily_feedback_trend(self, days=30):
+        """Return per-day feedback totals and positive rate for the last `days` days."""
+        # First day of the window
+        start_date = timezone.now().date() - timedelta(days=days)
+
+        # Basic feedback grouped by date
+        basic_daily = Feedback.objects.filter(timestamp__date__gte=start_date) \
+            .values('timestamp__date') \
+            .annotate(
+                date=F('timestamp__date'),
+                total=Count('id'),
+                positive=Sum(Case(When(feedback_value__gt=0, then=1), default=0)),
+                negative=Sum(Case(When(feedback_value__lt=0, then=1), default=0))
+            ) \
+            .order_by('date')
+
+        # Detailed feedback grouped by date
+        detailed_daily = DetailedFeedback.objects.filter(created_at__date__gte=start_date) \
+            .values('created_at__date') \
+            .annotate(
+                date=F('created_at__date'),
+                total=Count('id'),
+                positive=Count('id', filter=Q(feedback_type='positive')),
+                negative=Count('id', filter=Q(feedback_type='negative'))
+            ) \
+            .order_by('date')
+
+        # Merge both feedback kinds, keyed by ISO date string
+        daily_data = {}
+
+        for item in basic_daily:
+            date_str = item['date'].strftime('%Y-%m-%d')
+            daily_data[date_str] = {
+                'date': date_str,
+                'total': item['total'],
+                'positive': item['positive'],
+                'negative': item['negative']
+            }
+
+        for item in detailed_daily:
+            date_str = item['date'].strftime('%Y-%m-%d')
+            if date_str in daily_data:
+                daily_data[date_str]['total'] += item['total']
+                daily_data[date_str]['positive'] += item['positive']
+                daily_data[date_str]['negative'] += item['negative']
+            else:
+                daily_data[date_str] = {
+                    'date': date_str,
+                    'total': item['total'],
+                    'positive': item['positive'],
+                    'negative': item['negative']
+                }
+
+        # Positive percentage per day
+        for date_str, data in daily_data.items():
+            data['positive_rate'] = (data['positive'] / data['total'] * 100) if data['total'] > 0 else 0
+
+        # Flatten to a list ordered by date
+        result = list(daily_data.values())
+        result.sort(key=lambda x: x['date'])
+
+        return result
+
+    def get_user_stats(self):
+        """Return user totals, recently active annotators and mean annotations per user."""
+        # Total number of users
+        total_users = User.objects.count()
+
+        # Users with at least one feedback record
+        users_with_feedback = User.objects.filter(
+            Q(feedback__isnull=False) | Q(detailed_feedback__isnull=False)
+        ).distinct().count()
+
+        # Annotators active in the last 30 days
+        thirty_days_ago = timezone.now() - timedelta(days=30)
+        active_users = User.objects.filter(
+            Q(feedback__timestamp__gte=thirty_days_ago) |
+            Q(detailed_feedback__created_at__gte=thirty_days_ago)
+        ).distinct().count()
+
+        # Annotation count per user across both feedback kinds
+        user_annotations = {}
+
+        for feedback in Feedback.objects.all():
+            user_id = str(feedback.user_id)
+            if user_id in user_annotations:
+                user_annotations[user_id] += 1
+            else:
+                user_annotations[user_id] = 1
+
+        for feedback in DetailedFeedback.objects.all():
+            user_id = str(feedback.user_id)
+            if user_id in user_annotations:
+                user_annotations[user_id] += 1
+            else:
+                user_annotations[user_id] = 1
+
+        # Mean annotations per annotating user
+        if user_annotations:
+            avg_annotations = sum(user_annotations.values()) / len(user_annotations)
+        else:
+            avg_annotations = 0
+
+        return {
+            'total_users': total_users,
+            'users_with_feedback': users_with_feedback,
+            'active_users': active_users,
+            'avg_annotations_per_user': avg_annotations
+        }
+
+    def export_data_to_json(self):
+        """Write conversations, feedback and summary stats to a timestamped JSON file; return its name."""
+        data = {
+            'conversations': [],
+            'feedback_summary': self.get_feedback_stats(),
+            'tag_stats': self.get_tag_stats(),
+            'daily_trend': self.get_daily_feedback_trend(30),
+            'export_time': timezone.now().isoformat()
+        }
+
+        # Export conversations with their messages
+        for conv in Conversation.objects.all().prefetch_related('messages'):
+            conv_data = {
+                'id': str(conv.id),
+                'created_at': conv.created_at.isoformat(),
+                'user_id': str(conv.user_id),
+                'is_submitted': conv.is_submitted,
+                'messages': []
+            }
+
+            for msg in conv.messages.all().order_by('timestamp'):
+                msg_data = {
+                    'id': str(msg.id),
+                    'role': msg.role,
+                    'content': msg.content,
+                    'timestamp': msg.timestamp.isoformat(),
+                    'feedback': []
+                }
+
+                # Basic feedback attached to this message
+                for fb in Feedback.objects.filter(message_id=msg.id):
+                    msg_data['feedback'].append({
+                        'id': str(fb.id),
+                        'type': 'basic',
+                        'value': fb.feedback_value,
+                        'user_id': str(fb.user_id),
+                        'timestamp': fb.timestamp.isoformat()
+                    })
+
+                # Detailed feedback attached to this message
+                for dfb in DetailedFeedback.objects.filter(message_id=msg.id):
+                    try:
+                        tags = json.loads(dfb.feedback_tags) if dfb.feedback_tags else []
+                    except (json.JSONDecodeError, TypeError):
+                        tags = [dfb.feedback_tags] if dfb.feedback_tags else []
+
+                    msg_data['feedback'].append({
+                        'id': str(dfb.id),
+                        'type': 'detailed',
+                        'feedback_type': dfb.feedback_type,
+                        'tags': tags,
+                        'custom_tags': dfb.custom_tags,
+                        'custom_content': dfb.custom_content,
+                        'is_inline': dfb.is_inline,
+                        'user_id': str(dfb.user_id),
+                        'timestamp': dfb.created_at.isoformat()
+                    })
+
+                conv_data['messages'].append(msg_data)
+
+            data['conversations'].append(conv_data)
+
+        # Write to a file named after the current local time
+        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+        filename = f'rlhf_data_export_{timestamp}.json'
+
+        with open(filename, 'w', encoding='utf-8') as f:
+            json.dump(data, f, ensure_ascii=False, indent=2)
+
+        return filename
\ No newline at end of file
diff --git a/apps/rlhf/management/commands/import_data.py b/apps/rlhf/management/commands/import_data.py
new file mode 100644
index 0000000..77e2be4
--- /dev/null
+++ b/apps/rlhf/management/commands/import_data.py
@@ -0,0 +1,289 @@
+from django.core.management.base import BaseCommand, CommandError
+from rlhf.models import (
+    Conversation, Message, Feedback, FeedbackTag, DetailedFeedback,
+    ConversationSubmission, ConversationEvaluation, SystemConfig
+)
+from django.utils import dateparse, timezone
+import json
+import uuid
+import os
+from django.contrib.auth import get_user_model
+
+User = get_user_model()
+
+class Command(BaseCommand):
+    help = '导入/导出RLHF数据'
+
+    def add_arguments(self, parser):
+        parser.add_argument(
+            '--import-file',
+            help='导入JSON数据文件的路径',
+        )
+        parser.add_argument(
+            '--export-file',
+            help='导出JSON数据文件的路径',
+        )
+        parser.add_argument(
+            '--import-tags',
+            action='store_true',
+            help='从init_tags.py导入标签',
+        )
+        parser.add_argument(
+            '--clear',
+            action='store_true',
+            help='在导入前清除现有数据',
+        )
+
+    def handle(self, *args, **options):
+        if options['import_file']:
+            self.import_data(options['import_file'], options['clear'])
+        elif options['export_file']:
+            self.export_data(options['export_file'])
+        elif options['import_tags']:
+            self.import_tags_from_init_file()
+        else:
+            self.stdout.write(self.style.WARNING('请指定导入文件路径或导出文件路径'))
+
+    def import_data(self, file_path, clear=False):
+        """Import conversations and system configs from a JSON file; raise CommandError on failure."""
+        if not os.path.exists(file_path):
+            raise CommandError(f'文件不存在: {file_path}')
+
+        try:
+            with open(file_path, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+
+            if clear:
+                self.stdout.write(self.style.WARNING('正在清除现有数据...'))
+                self._clear_data()
+
+            # Import conversations
+            for conv_data in data.get('conversations', []):
+                self._import_conversation(conv_data)
+
+            # Import system configs
+            for config in data.get('system_configs', []):
+                self._import_system_config(config)
+
+            self.stdout.write(self.style.SUCCESS(f'成功导入数据从: {file_path}'))
+
+        except Exception as e:
+            raise CommandError(f'导入数据失败: {str(e)}') from e
+
+    def export_data(self, file_path):
+        """Export tags, system configs and conversations to a JSON file; raise CommandError on failure."""
+        try:
+            data = {
+                'conversations': [],
+                'system_configs': [],
+                'tags': [],
+                'export_time': timezone.now().isoformat()
+            }
+
+            # Export tags
+            for tag in FeedbackTag.objects.all():
+                data['tags'].append({
+                    'id': str(tag.id),
+                    'tag_name': tag.tag_name,
+                    'tag_type': tag.tag_type,
+                    'description': tag.description,
+                    'created_at': tag.created_at.isoformat()
+                })
+
+            # Export system configs
+            for config in SystemConfig.objects.all():
+                data['system_configs'].append({
+                    'id': str(config.id),
+                    'config_key': config.config_key,
+                    'config_value': config.config_value,
+                    'config_type': config.config_type,
+                    'description': config.description,
+                    'created_at': config.created_at.isoformat(),
+                    'updated_at': config.updated_at.isoformat()
+                })
+
+            # Export conversations
+            for conv in Conversation.objects.all().prefetch_related('messages'):
+                conv_data = {
+                    'id': str(conv.id),
+                    'user_id': str(conv.user_id),
+                    'is_submitted': conv.is_submitted,
+                    'created_at': conv.created_at.isoformat(),
+                    'messages': []
+                }
+
+                for msg in conv.messages.all().order_by('timestamp'):
+                    msg_data = {
+                        'id': str(msg.id),
+                        'role': msg.role,
+                        'content': msg.content,
+                        'timestamp': msg.timestamp.isoformat(),
+                        'feedback': []
+                    }
+
+                    # Basic feedback for this message
+                    for fb in Feedback.objects.filter(message_id=msg.id):
+                        msg_data['feedback'].append({
+                            'id': str(fb.id),
+                            'type': 'basic',
+                            'feedback_value': fb.feedback_value,
+                            'user_id': str(fb.user_id),
+                            'timestamp': fb.timestamp.isoformat()
+                        })
+
+                    # Detailed feedback for this message
+                    for dfb in DetailedFeedback.objects.filter(message_id=msg.id):
+                        msg_data['feedback'].append({
+                            'id': str(dfb.id),
+                            'type': 'detailed',
+                            'feedback_type': dfb.feedback_type,
+                            'feedback_tags': dfb.feedback_tags,
+                            'custom_tags': dfb.custom_tags,
+                            'custom_content': dfb.custom_content,
+                            'is_inline': dfb.is_inline,
+                            'user_id': str(dfb.user_id),
+                            'created_at': dfb.created_at.isoformat(),
+                            'updated_at': dfb.updated_at.isoformat()
+                        })
+
+                    conv_data['messages'].append(msg_data)
+
+                data['conversations'].append(conv_data)
+
+            with open(file_path, 'w', encoding='utf-8') as f:
+                json.dump(data, f, ensure_ascii=False, indent=2)
+
+            self.stdout.write(self.style.SUCCESS(f'成功导出数据到: {file_path}'))
+
+        except Exception as e:
+            raise CommandError(f'导出数据失败: {str(e)}') from e
+
+    def import_tags_from_init_file(self):
+        """Create the built-in positive and negative feedback tags if they are missing."""
+        try:
+            # Built-in tag definitions: (tag_name, description)
+            positive_tags = [
+                ('有帮助', '回答对问题有实际帮助'),
+                ('准确', '信息准确可靠'),
+                ('清晰', '表达清楚易懂'),
+                ('完整', '回答全面完整'),
+                ('友好', '语调友善亲切'),
+                ('创新', '提供了新颖的观点')
+            ]
+
+            negative_tags = [
+                ('不准确', '包含错误信息'),
+                ('不相关', '回答偏离主题'),
+                ('不完整', '回答过于简略'),
+                ('不清晰', '表达模糊难懂'),
+                ('不友好', '语调生硬冷淡'),
+                ('重复', '内容重复冗余')
+            ]
+
+            # Insert positive tags
+            for tag_name, description in positive_tags:
+                FeedbackTag.objects.get_or_create(
+                    tag_name=tag_name,
+                    defaults={
+                        'id': str(uuid.uuid4()),
+                        'tag_type': 'positive',
+                        'description': description,
+                        'created_at': timezone.now()
+                    }
+                )
+
+            # Insert negative tags
+            for tag_name, description in negative_tags:
+                FeedbackTag.objects.get_or_create(
+                    tag_name=tag_name,
+                    defaults={
+                        'id': str(uuid.uuid4()),
+                        'tag_type': 'negative',
+                        'description': description,
+                        'created_at': timezone.now()
+                    }
+                )
+
+            self.stdout.write(self.style.SUCCESS('成功导入标签'))
+
+        except Exception as e:
+            raise CommandError(f'导入标签失败: {str(e)}') from e
+
+    def _clear_data(self):
+        """Delete all feedback, messages, conversations, submissions and evaluations."""
+        DetailedFeedback.objects.all().delete()
+        Feedback.objects.all().delete()
+        Message.objects.all().delete()
+        Conversation.objects.all().delete()
+        ConversationSubmission.objects.all().delete()
+        ConversationEvaluation.objects.all().delete()
+        self.stdout.write(self.style.SUCCESS('已清除现有数据'))
+
+    def _import_conversation(self, conv_data):
+        """Create one conversation with its messages and feedback from exported data."""
+        # Fall back to the first superuser when the exported owner does not exist
+        user_id = conv_data.get('user_id')
+        if not User.objects.filter(id=user_id).exists():
+            self.stdout.write(self.style.WARNING(f'用户不存在: {user_id},将使用第一个管理员用户'))
+            user = User.objects.filter(is_superuser=True).first()
+            if not user:
+                raise CommandError('找不到管理员用户')
+            user_id = user.id
+
+        # Create the conversation
+        conv = Conversation.objects.create(
+            id=conv_data.get('id', str(uuid.uuid4())),
+            user_id=user_id,
+            is_submitted=conv_data.get('is_submitted', False),
+            created_at=dateparse.parse_datetime(conv_data.get('created_at', timezone.now().isoformat()))
+        )
+
+        # Create its messages
+        for msg_data in conv_data.get('messages', []):
+            msg = Message.objects.create(
+                id=msg_data.get('id', str(uuid.uuid4())),
+                conversation=conv,
+                role=msg_data.get('role', 'user'),
+                content=msg_data.get('content', ''),
+                timestamp=dateparse.parse_datetime(msg_data.get('timestamp', timezone.now().isoformat()))
+            )
+
+            # Create the feedback attached to each message
+            for fb_data in msg_data.get('feedback', []):
+                if fb_data.get('type') == 'basic':
+                    Feedback.objects.create(
+                        id=fb_data.get('id', str(uuid.uuid4())),
+                        message=msg,
+                        conversation=conv,
+                        user_id=fb_data.get('user_id', user_id),
+                        feedback_value=fb_data.get('feedback_value', 0),
+                        timestamp=dateparse.parse_datetime(fb_data.get('timestamp', timezone.now().isoformat()))
+                    )
+                elif fb_data.get('type') == 'detailed':
+                    DetailedFeedback.objects.create(
+                        id=fb_data.get('id', str(uuid.uuid4())),
+                        message=msg,
+                        conversation=conv,
+                        user_id=fb_data.get('user_id', user_id),
+                        feedback_type=fb_data.get('feedback_type', 'neutral'),
+                        feedback_tags=fb_data.get('feedback_tags', '[]'),
+                        custom_tags=fb_data.get('custom_tags', ''),
+                        custom_content=fb_data.get('custom_content', ''),
+                        is_inline=fb_data.get('is_inline', True),
+                        created_at=dateparse.parse_datetime(fb_data.get('created_at', timezone.now().isoformat())),
+                        updated_at=dateparse.parse_datetime(fb_data.get('updated_at', timezone.now().isoformat()))
+                    )
+
+    def _import_system_config(self, config_data):
+        """Create or update one SystemConfig row, matched on config_key."""
+        SystemConfig.objects.update_or_create(
+            config_key=config_data.get('config_key'),
+            defaults={
+                'id': config_data.get('id', str(uuid.uuid4())),
+                'config_value': config_data.get('config_value', ''),
+                'config_type': config_data.get('config_type', 'string'),
+                'description': config_data.get('description', ''),
+                'created_at': dateparse.parse_datetime(config_data.get('created_at', timezone.now().isoformat())),
+                'updated_at': dateparse.parse_datetime(config_data.get('updated_at', timezone.now().isoformat()))
+            }
+        )
\ No newline at end of file
diff --git a/apps/rlhf/siliconflow_client.py b/apps/rlhf/siliconflow_client.py
index 3174cbc..0b088c2 100644
--- a/apps/rlhf/siliconflow_client.py
+++ b/apps/rlhf/siliconflow_client.py
@@ -2,21 +2,22 @@ import requests
 import json
 import time
 import logging
+from django.conf import settings
 
 logger = logging.getLogger(__name__)
 
 
 class SiliconFlowClient:
-    def __init__(self, api_key="sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf", model="Qwen/QwQ-32B"):
+    def __init__(self, api_key=None, model=None):
         """
         初始化SiliconFlow客户端
         """
-        self.api_key = api_key
-        self.model = model
+        self.api_key = api_key or getattr(settings, 'SILICONFLOW_API_KEY', None)
+        self.model = model or getattr(settings, 'DEFAULT_AI_MODEL', "Qwen/QwQ-32B")
         self.base_url = "https://api.siliconflow.cn/v1"
         self.messages = []
         self.system_message = None
-        logger.info(f"初始化SiliconFlow客户端 - 模型: {model}")
+        logger.info(f"初始化SiliconFlow客户端 - 模型: {self.model}")
 
     def set_model(self, model):
         """设置使用的模型"""
@@ -64,6 +65,7 @@ class SiliconFlowClient:
                 "Content-Type": "application/json"
             }
 
+            logger.debug(f"发送请求到SiliconFlow API,模型:{self.model}")
             response = requests.post(
                 f"{self.base_url}/chat/completions",
                 json=payload,
@@ -108,6 +110,7 @@ class SiliconFlowClient:
                 "Content-Type": "application/json"
             }
 
+            logger.debug(f"发送流式请求到SiliconFlow API,模型:{self.model}")
             response = requests.post(
                 f"{self.base_url}/chat/completions",
                 json=payload,
@@ -159,12 +162,15 @@ class SiliconFlowClient:
             yield error_msg
 
     @classmethod
-    def get_available_models(cls, api_key="sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf"):
+    def get_available_models(cls, api_key=None):
         """
         获取可用的模型列表
         """
         import os
 
+        if not api_key:
+            api_key = getattr(settings, 'SILICONFLOW_API_KEY', None)
+
         # 尝试多种网络配置
         proxy_configs = [
             # 不使用代理
@@ -243,9 +249,9 @@ class SiliconFlowClient:
                 logger.warning(f"网络配置 {i+1} 异常: {str(e)}")
                 continue
 
-        # 所有配置都失败了
-        logger.error("所有网络配置都失败,无法获取模型列表")
-        raise Exception("无法连接到SiliconFlow API服务器")
+        # Every network configuration failed: fall back to the predefined model list
+        logger.error("所有网络配置都失败,使用预定义模型列表")
+        return cls._get_fallback_models()
 
     @classmethod
     def _get_fallback_models(cls):
diff --git a/apps/rlhf/views.py b/apps/rlhf/views.py
index 9872094..65de883 100644
--- a/apps/rlhf/views.py
+++ b/apps/rlhf/views.py
@@ -22,6 +22,9 @@ from datetime import datetime, timedelta
 from django.db import transaction
 from django.db.models.functions import TruncDate
 from apps.user.authentication import CustomTokenAuthentication
+from .siliconflow_client import SiliconFlowClient
+from django.conf import settings
+import logging
 
 
 # 创建统一响应格式的基类
@@ -237,25 +240,45 @@ class ConversationViewSet(StandardResponseMixin, viewsets.ModelViewSet):
 
     def _generate_ai_response(self, user_message, conversation):
         """
-        生成AI回复
-        这里只是一个示例,实际应用中需要对接真实的AI服务
+        Generate the AI reply by calling the SiliconFlow API.
         """
-        # 从系统配置获取当前使用的模型
-        model_config = SystemConfig.objects.filter(config_key='current_model').first()
-        model_name = model_config.config_value if model_config else "默认模型"
+        logger = logging.getLogger(__name__)
 
-        # 获取历史消息作为上下文
-        history_messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
-        history = []
-        for msg in history_messages:
-            history.append({"role": msg.role, "content": msg.content})
-
-        # 在这里调用实际的AI API
-        # 例如,如果使用SiliconFlow API
-        # response = sf_client.chat(user_message, history)
-
-        # 这里仅作为示例,返回一个固定的回复
-        return f"这是AI({model_name})的回复:我已收到您的消息「{user_message}」。根据您的问题,我的建议是..."
+        try:
+            # Read the model currently selected in the system config
+            model_config = SystemConfig.objects.filter(config_key='current_model').first()
+            model_name = model_config.config_value if model_config else getattr(settings, 'DEFAULT_AI_MODEL', "Qwen/QwQ-32B")
+
+            # Initialise the SiliconFlow client
+            sf_client = SiliconFlowClient(
+                api_key=getattr(settings, 'SILICONFLOW_API_KEY', None),
+                model=model_name
+            )
+
+            # Apply the system prompt, if one is configured
+            system_prompt_config = SystemConfig.objects.filter(config_key='system_prompt').first()
+            if system_prompt_config and system_prompt_config.config_value:
+                sf_client.set_system_message(system_prompt_config.config_value)
+
+            # Load the conversation history as context
+            history_messages = Message.objects.filter(conversation=conversation).order_by('timestamp')
+
+            # Feed the history into the client
+            for msg in history_messages:
+                sf_client.add_message(msg.role, msg.content)
+
+            # Call the API to get the reply
+            logger.info(f"正在调用AI API (模型: {model_name}) 处理消息: {user_message[:50]}...")
+            response = sf_client.chat(user_message)
+
+            # Log the successful reply
+            logger.info(f"AI回复成功,回复长度: {len(response)}")
+
+            return response
+
+        except Exception as e:
+            logger.exception(f"AI API调用失败: {str(e)}")
+            return f"很抱歉,AI服务暂时不可用: {str(e)}"
 
     def _update_annotation_stats(self, user_id):
         """更新用户的标注统计信息"""
@@ -653,7 +676,6 @@ class ConversationEvaluationViewSet(StandardResponseMixin, viewsets.ModelViewSet
             data=ConversationEvaluationSerializer(evaluation).data
         )
 
-
 
 class SystemConfigViewSet(StandardResponseMixin, viewsets.ModelViewSet):
     queryset = SystemConfig.objects.all()
     serializer_class = SystemConfigSerializer
@@ -718,15 +740,30 @@ class SystemConfigViewSet(StandardResponseMixin, viewsets.ModelViewSet):
 
     @action(detail=False, methods=['get'])
     def models(self, request):
-        # 返回可用的模型列表
-        return self.get_standard_response(
-            data={
-                'models': [
-                    {'id': 'model1', 'name': 'GPT-3.5'},
-                    {'id': 'model2', 'name': 'GPT-4'},
-                    {'id': 'model3', 'name': 'Claude'},
-                    {'id': 'model4', 'name': 'LLaMA'},
-                    {'id': 'model5', 'name': 'Qwen'}
-                ]
-            }
-        )
\ No newline at end of file
+        """Return the list of available models."""
+        from .siliconflow_client import SiliconFlowClient
+        from django.conf import settings
+        import logging
+
+        logger = logging.getLogger(__name__)
+
+        try:
+            # Fetch the available models from SiliconFlow
+            logger.info("正在获取可用模型列表...")
+            models = SiliconFlowClient.get_available_models(
+                api_key=getattr(settings, 'SILICONFLOW_API_KEY', None)
+            )
+            logger.info(f"成功获取 {len(models)} 个可用模型")
+
+            return self.get_standard_response(
+                data={'models': models}
+            )
+        except Exception as e:
+            logger.exception(f"获取模型列表失败: {str(e)}")
+            # Fall back to the predefined model list
+            fallback_models = SiliconFlowClient._get_fallback_models()
+            return self.get_standard_response(
+                code=200,  # still report 200 so the frontend does not treat this as an error
+                message="无法从API获取模型列表,使用预定义列表",
+                data={'models': fallback_models}
+            )
\ No newline at end of file