daren/apps/rlhf/siliconflow_client.py

307 lines
12 KiB
Python
Raw Normal View History

2025-06-09 16:29:14 +08:00
import requests
import json
import time
import logging
logger = logging.getLogger(__name__)


class SiliconFlowClient:
    """Minimal chat client for the SiliconFlow HTTP API.

    The client keeps the running conversation in ``self.messages`` (a list of
    ``{"role": ..., "content": ...}`` dicts) and sends the whole list with
    every request. An optional system message is stored separately so that it
    can be re-inserted whenever the history is reset.
    """

    # SECURITY NOTE(review): a credential is hard-coded in source control.
    # It is kept only so that existing callers relying on the default keep
    # working; it should be rotated and supplied through configuration.
    _DEFAULT_API_KEY = "sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf"

    def __init__(self, api_key=_DEFAULT_API_KEY, model="Qwen/QwQ-32B"):
        """Create a client.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            model: Model identifier placed in every chat request.
        """
        self.api_key = api_key
        self.model = model
        self.base_url = "https://api.siliconflow.cn/v1"
        self.messages = []
        self.system_message = None
        logger.info(f"初始化SiliconFlow客户端 - 模型: {model}")

    def set_model(self, model):
        """Switch the model used by subsequent requests."""
        self.model = model
        logger.info(f"SiliconFlow切换模型: {model}")

    def set_system_message(self, message):
        """Store ``message`` as the system prompt and reset the history.

        A falsy ``message`` (``None`` or ``""``) leaves the history empty.
        """
        self.system_message = message
        self._reset_messages()
        logger.debug(f"SiliconFlow设置系统消息 - 长度: {len(message) if message else 0}")

    def add_message(self, role, content):
        """Append one message to the conversation history."""
        self.messages.append({"role": role, "content": content})

    def clear_history(self):
        """Drop the conversation, keeping only the system message (if any)."""
        self._reset_messages()
        logger.debug("SiliconFlow清空对话历史")

    def _reset_messages(self):
        """Rebuild ``self.messages`` so it holds at most the system message."""
        self.messages = []
        if self.system_message:
            self.messages.append({"role": "system", "content": self.system_message})

    def _post_chat(self, stream, timeout):
        """POST the current history to ``/chat/completions``.

        Args:
            stream: Whether to request (and read) a streamed response.
            timeout: Timeout in seconds handed to ``requests.post``.

        Returns:
            The ``requests`` response object; the caller checks its status.
        """
        payload = {
            "model": self.model,
            "messages": self.messages,
            "stream": stream,
            "max_tokens": 2048,
            "temperature": 0.7,
            "top_p": 0.7,
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        return requests.post(
            f"{self.base_url}/chat/completions",
            json=payload,
            headers=headers,
            stream=stream,
            timeout=timeout
        )

    def chat(self, message):
        """Send ``message`` and return the assistant's full reply.

        Errors are not raised: on an HTTP error or any exception a
        human-readable error string is returned instead. In that case the
        user message is removed from the history again, so a retry does not
        send the same turn twice.

        Args:
            message: Text of the user turn.

        Returns:
            The assistant reply, or an error description string.
        """
        # Remember where the history ended so a failed call can be undone.
        history_len = len(self.messages)
        try:
            self.add_message("user", message)
            response = self._post_chat(stream=False, timeout=60)
            if response.status_code == 200:
                data = response.json()
                assistant_message = data['choices'][0]['message']['content']
                self.add_message("assistant", assistant_message)
                return assistant_message
            del self.messages[history_len:]
            error_msg = f"SiliconFlow API错误: {response.status_code} - {response.text}"
            logger.error(error_msg)
            return f"API调用失败: {error_msg}"
        except Exception as e:
            del self.messages[history_len:]
            error_msg = f"SiliconFlow对话出错: {str(e)}"
            logger.exception("SiliconFlow对话异常")
            return error_msg

    def chat_stream(self, message):
        """Send ``message`` and yield the assistant's reply piece by piece.

        This is a generator. Errors are not raised: an error description is
        yielded as the final item instead.

        When the generator finishes (normally, on error, or because the
        consumer closed it early) the history is settled: whatever reply text
        was received is stored as the assistant turn; if nothing was
        received, the user message is removed again so the history never ends
        with an unanswered user turn.

        Args:
            message: Text of the user turn.

        Yields:
            Successive content fragments, or a single error string.
        """
        history_len = len(self.messages)
        reply_parts = []
        try:
            self.add_message("user", message)
            # The context manager closes the connection even when we stop
            # reading early (``[DONE]``, an exception, or generator close).
            with self._post_chat(stream=True, timeout=120) as response:
                if response.status_code != 200:
                    error_msg = f"SiliconFlow API错误: {response.status_code} - {response.text}"
                    logger.error(error_msg)
                    yield f"API调用失败: {error_msg}"
                    return
                for raw_line in response.iter_lines():
                    if not raw_line:
                        continue
                    line = raw_line.decode('utf-8')
                    if not line.startswith('data: '):
                        continue
                    data_str = line[6:]  # strip the 'data: ' prefix
                    if data_str.strip() == '[DONE]':
                        break
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        # Skip lines that are not valid JSON.
                        continue
                    if 'choices' in data and data['choices']:
                        # ``delta`` may be missing or null in some chunks.
                        delta = data['choices'][0].get('delta') or {}
                        content = delta.get('content', '')
                        if content:
                            reply_parts.append(content)
                            yield content
        except requests.exceptions.Timeout:
            error_msg = "SiliconFlow请求超时"
            logger.error(error_msg)
            yield error_msg
        except Exception as e:
            error_msg = f"SiliconFlow流式对话出错: {str(e)}"
            logger.exception("SiliconFlow流式对话异常")
            yield error_msg
        finally:
            assistant_message = "".join(reply_parts)
            if assistant_message:
                self.add_message("assistant", assistant_message)
            else:
                del self.messages[history_len:]

    @classmethod
    def get_available_models(cls, api_key=_DEFAULT_API_KEY):
        """Fetch the model list and keep the entries that look like chat models.

        Several network configurations are tried in order until one returns
        HTTP 200. Models whose id contains an excluded keyword are dropped;
        of the rest, those matching a chat keyword are returned. If fewer
        than 10 match, every non-excluded model is returned instead.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.

        Returns:
            A list of ``{'id', 'name', 'description'}`` dicts.

        Raises:
            Exception: If no configuration produced an HTTP 200 response.
        """
        # ``proxies=None`` lets requests pick up environment proxies;
        # None-valued entries suppress them for that scheme.
        # SECURITY NOTE(review): the ``verify=False`` fallbacks send the API
        # key over a connection whose certificate is not checked.
        proxy_configs = [
            {'proxies': None, 'verify': True},
            {'proxies': None, 'verify': False},
            {'proxies': {'http': None, 'https': None}, 'verify': True},
            {'proxies': {'http': None, 'https': None}, 'verify': False},
        ]
        excluded_keywords = ['embedding', 'stable-diffusion', 'bge-', 'rerank', 'whisper']
        chat_keywords = ['chat', 'instruct', 'qwen', 'glm', 'internlm', 'baichuan', 'llama', 'mistral', 'claude', 'gpt', 'yi']
        headers = {"Authorization": f"Bearer {api_key}"}

        for i, config in enumerate(proxy_configs):
            try:
                logger.info(f"尝试网络配置 {i+1}/{len(proxy_configs)}")
                # Proxies are passed per request: values placed on
                # ``session.proxies`` are overridden by environment proxies.
                # A copy is passed because requests may add keys to the dict.
                proxies = dict(config['proxies']) if config['proxies'] is not None else None
                with requests.Session() as session:
                    response = session.get(
                        "https://api.siliconflow.cn/v1/models",
                        headers=headers,
                        timeout=30,
                        verify=config['verify'],
                        proxies=proxies
                    )
                if response.status_code != 200:
                    logger.warning(f"网络配置 {i+1} 失败: {response.status_code} - {response.text}")
                    continue

                all_models = response.json().get('data', [])
                logger.info(f"SiliconFlow API返回 {len(all_models)} 个模型")

                # Models with an id that contains no excluded keyword.
                candidates = [
                    model for model in all_models
                    if model.get('id', '')
                    and not any(keyword in model['id'].lower() for keyword in excluded_keywords)
                ]
                selected = [
                    model for model in candidates
                    if any(keyword in model['id'].lower() for keyword in chat_keywords)
                ]
                logger.info(f"获取SiliconFlow模型列表成功 - 总数: {len(all_models)}, 聊天模型: {len(selected)}")
                if len(selected) < 10:
                    # Too few keyword matches: fall back to every candidate.
                    logger.warning("过滤后的聊天模型数量过少,返回所有非排除模型")
                    selected = candidates
                return [
                    {
                        'id': model['id'],
                        'name': model['id'],
                        'description': model.get('description', ''),
                    }
                    for model in selected
                ]
            except Exception as e:
                logger.warning(f"网络配置 {i+1} 异常: {str(e)}")
                continue

        # Every configuration failed.
        logger.error("所有网络配置都失败,无法获取模型列表")
        raise Exception("无法连接到SiliconFlow API服务器")

    @classmethod
    def _get_fallback_models(cls):
        """Return a predefined list of common models.

        Intended as a substitute when the model list cannot be fetched;
        nothing in this class calls it automatically.

        Returns:
            A list of ``{'id', 'name', 'description'}`` dicts.
        """
        logger.info("使用预定义的模型列表作为备选方案")
        return [
            {
                'id': 'Qwen/Qwen2.5-7B-Instruct',
                'name': 'Qwen2.5-7B-Instruct',
                'description': '通义千问2.5 7B指令模型'
            },
            {
                'id': 'Qwen/Qwen2.5-14B-Instruct',
                'name': 'Qwen2.5-14B-Instruct',
                'description': '通义千问2.5 14B指令模型'
            },
            {
                'id': 'Qwen/Qwen2.5-32B-Instruct',
                'name': 'Qwen2.5-32B-Instruct',
                'description': '通义千问2.5 32B指令模型'
            },
            {
                'id': 'Qwen/Qwen2.5-72B-Instruct',
                'name': 'Qwen2.5-72B-Instruct',
                'description': '通义千问2.5 72B指令模型'
            },
            {
                'id': 'Qwen/QwQ-32B-Preview',
                'name': 'QwQ-32B-Preview',
                'description': '通义千问推理模型预览版'
            },
            {
                'id': 'deepseek-ai/DeepSeek-V2.5',
                'name': 'DeepSeek-V2.5',
                'description': 'DeepSeek V2.5 模型'
            },
            {
                'id': 'meta-llama/Llama-3.1-8B-Instruct',
                'name': 'Llama-3.1-8B-Instruct',
                'description': 'Meta Llama 3.1 8B指令模型'
            },
            {
                'id': 'meta-llama/Llama-3.1-70B-Instruct',
                'name': 'Llama-3.1-70B-Instruct',
                'description': 'Meta Llama 3.1 70B指令模型'
            },
            {
                'id': 'THUDM/glm-4-9b-chat',
                'name': 'GLM-4-9B-Chat',
                'description': '智谱GLM-4 9B对话模型'
            },
            {
                'id': 'internlm/internlm2_5-7b-chat',
                'name': 'InternLM2.5-7B-Chat',
                'description': '书生浦语2.5 7B对话模型'
            }
        ]