daren/apps/rlhf/siliconflow_client.py
2025-06-09 16:29:14 +08:00

307 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import json
import time
import logging
# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(__name__)
class SiliconFlowClient:
    """Client for the SiliconFlow OpenAI-compatible chat-completions API.

    The instance keeps the conversation history in ``self.messages`` (a list of
    ``{"role": ..., "content": ...}`` dicts) and sends the whole history with
    every request made by :meth:`chat` and :meth:`chat_stream`.
    """

    # API root used by the instance methods and by get_available_models().
    BASE_URL = "https://api.siliconflow.cn/v1"

    # NOTE(review): this is an API key literal committed to source control.
    # It is kept only so callers relying on the old default keep working; it
    # should be revoked and supplied through the SILICONFLOW_API_KEY
    # environment variable or the ``api_key`` argument instead.
    _DEFAULT_API_KEY = "sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf"

    # Model ids containing any of these substrings are never offered as chat models.
    _EXCLUDED_KEYWORDS = ('embedding', 'stable-diffusion', 'bge-', 'rerank', 'whisper')
    # Model ids containing any of these substrings are treated as chat models.
    _CHAT_KEYWORDS = ('chat', 'instruct', 'qwen', 'glm', 'internlm', 'baichuan',
                      'llama', 'mistral', 'claude', 'gpt', 'yi')
    # If fewer keyword matches than this are found, return every non-excluded model.
    _MIN_CHAT_MODELS = 10

    def __init__(self, api_key=None, model="Qwen/QwQ-32B"):
        """Create a client with an empty conversation history.

        Args:
            api_key: SiliconFlow API key. When None, the SILICONFLOW_API_KEY
                environment variable is used, falling back to the legacy
                built-in key.
            model: model id sent with every chat request.
        """
        self.api_key = self._resolve_api_key(api_key)
        self.model = model
        self.base_url = self.BASE_URL
        self.messages = []
        self.system_message = None
        logger.info(f"初始化SiliconFlow客户端 - 模型: {model}")

    @classmethod
    def _resolve_api_key(cls, api_key):
        """Return *api_key*, or the env-var / legacy default when it is None."""
        if api_key is not None:
            return api_key
        import os
        return os.environ.get("SILICONFLOW_API_KEY") or cls._DEFAULT_API_KEY

    def set_model(self, model):
        """Switch the model id used for subsequent requests."""
        self.model = model
        logger.info(f"SiliconFlow切换模型: {model}")

    def set_system_message(self, message):
        """Set the system prompt and reset the history to contain only it.

        A falsy *message* clears the history without adding a system entry.
        """
        self.system_message = message
        self.messages = []
        if message:
            self.messages.append({"role": "system", "content": message})
        logger.debug(f"SiliconFlow设置系统消息 - 长度: {len(message) if message else 0}")

    def add_message(self, role, content):
        """Append one ``{"role", "content"}`` entry to the history."""
        self.messages.append({"role": role, "content": content})

    def clear_history(self):
        """Drop all messages, keeping the current system prompt if one is set."""
        self.messages = []
        if self.system_message:
            self.messages.append({"role": "system", "content": self.system_message})
        logger.debug("SiliconFlow清空对话历史")

    def _headers(self):
        """Return the headers for an authenticated JSON request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _chat_payload(self, stream):
        """Build the /chat/completions request body from the current history."""
        return {
            "model": self.model,
            "messages": self.messages,
            "stream": stream,
            "max_tokens": 2048,
            "temperature": 0.7,
            "top_p": 0.7,
        }

    def chat(self, message):
        """Send *message* and return the assistant's reply (non-streaming).

        On success, the user message and the reply are appended to the history.
        On failure, an error string is returned instead of raising. The user
        message is also removed again, so a retry does not leave two
        consecutive user turns in the history.
        """
        history_len = len(self.messages)
        try:
            self.add_message("user", message)
            response = requests.post(
                f"{self.base_url}/chat/completions",
                json=self._chat_payload(stream=False),
                headers=self._headers(),
                timeout=60,
            )
            if response.status_code != 200:
                error_msg = f"SiliconFlow API错误: {response.status_code} - {response.text}"
                logger.error(error_msg)
                del self.messages[history_len:]  # roll back the unanswered user turn
                return f"API调用失败: {error_msg}"
            data = response.json()
            assistant_message = data['choices'][0]['message']['content']
            self.add_message("assistant", assistant_message)
            return assistant_message
        except Exception as e:
            error_msg = f"SiliconFlow对话出错: {str(e)}"
            logger.exception("SiliconFlow对话异常")
            del self.messages[history_len:]  # roll back the unanswered user turn
            return error_msg

    @staticmethod
    def _parse_stream_line(raw_line):
        """Parse one server-sent-events line of a streaming chat response.

        Returns a ``(done, content)`` pair. ``done`` is True only for the
        ``[DONE]`` terminator. ``content`` is the text delta carried by the
        line, or "" for blank/keep-alive/non-data lines, malformed JSON, and
        chunks without content.
        """
        if not raw_line:
            return False, ""
        line = raw_line.decode('utf-8') if isinstance(raw_line, bytes) else raw_line
        # SSE allows an optional single space after the "data:" field name.
        if not line.startswith('data:'):
            return False, ""
        data_str = line[5:].strip()
        if data_str == '[DONE]':
            return True, ""
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            return False, ""
        if not isinstance(data, dict):
            return False, ""
        choices = data.get('choices') or []
        if not choices:
            return False, ""
        delta = choices[0].get('delta') or {}
        return False, delta.get('content') or ""

    def chat_stream(self, message):
        """Send *message* and yield the assistant's reply incrementally.

        Yields text fragments as they arrive. Once the stream finishes, the
        complete (non-empty) reply is appended to the history. Errors are
        yielded as a single error string rather than raised; in that case the
        user message is removed from the history again.
        """
        history_len = len(self.messages)
        try:
            self.add_message("user", message)
            # The context manager closes the streamed connection on every exit
            # path, including when the consumer stops iterating early
            # (GeneratorExit is not caught by ``except Exception`` below).
            with requests.post(
                f"{self.base_url}/chat/completions",
                json=self._chat_payload(stream=True),
                headers=self._headers(),
                stream=True,
                timeout=120,
            ) as response:
                if response.status_code != 200:
                    error_msg = f"SiliconFlow API错误: {response.status_code} - {response.text}"
                    logger.error(error_msg)
                    del self.messages[history_len:]  # roll back the unanswered user turn
                    yield f"API调用失败: {error_msg}"
                    return
                parts = []
                for raw_line in response.iter_lines():
                    done, content = self._parse_stream_line(raw_line)
                    if done:
                        break
                    if content:
                        parts.append(content)
                        yield content
            # Record the complete assistant reply in the history.
            assistant_message = "".join(parts)
            if assistant_message:
                self.add_message("assistant", assistant_message)
        except requests.exceptions.Timeout:
            error_msg = "SiliconFlow请求超时"
            logger.error(error_msg)
            del self.messages[history_len:]
            yield error_msg
        except Exception as e:
            error_msg = f"SiliconFlow流式对话出错: {str(e)}"
            logger.exception("SiliconFlow流式对话异常")
            del self.messages[history_len:]
            yield error_msg

    @classmethod
    def _select_models(cls, all_models, require_chat_keyword):
        """Convert raw /models entries into ``{'id', 'name', 'description'}`` dicts.

        Entries without an id, or whose id contains an excluded keyword, are
        dropped. With *require_chat_keyword*, the id must also contain one of
        the chat keywords. Matching is case-insensitive.
        """
        selected = []
        for model in all_models:
            model_id = model.get('id', '')
            if not model_id:
                continue
            lowered = model_id.lower()
            if any(keyword in lowered for keyword in cls._EXCLUDED_KEYWORDS):
                continue
            if require_chat_keyword and not any(keyword in lowered for keyword in cls._CHAT_KEYWORDS):
                continue
            selected.append({
                'id': model_id,
                'name': model_id,
                'description': model.get('description', ''),
            })
        return selected

    @classmethod
    def get_available_models(cls, api_key=None):
        """Fetch the chat models available to *api_key* from the /models endpoint.

        Tries several network configurations in turn and returns the filtered
        list from the first one that answers HTTP 200. The configurations
        combine environment proxy settings on/off with TLS verification on/off.

        Args:
            api_key: API key. When None, it is resolved the same way as in
                ``__init__``.

        Returns:
            A list of ``{'id', 'name', 'description'}`` dicts. These are the
            keyword-matched chat models, or every non-excluded model when fewer
            than ``_MIN_CHAT_MODELS`` match.

        Raises:
            Exception: if every configuration fails. The exception is chained
                to the last underlying error, if any.
        """
        api_key = cls._resolve_api_key(api_key)
        headers = {"Authorization": f"Bearer {api_key}"}
        # ``trust_env=False`` is what actually bypasses HTTP(S)_PROXY environment
        # variables. Setting ``session.proxies`` to ``{'http': None, ...}`` is
        # overridden by them. Note that it also ignores REQUESTS_CA_BUNDLE.
        # NOTE(review): the ``verify=False`` attempts send the bearer token over
        # unverified TLS; consider dropping them.
        network_configs = [
            {'trust_env': True, 'verify': True},    # system/env proxy settings
            {'trust_env': True, 'verify': False},   # system proxy, no TLS verification
            {'trust_env': False, 'verify': True},   # explicitly bypass proxies
            {'trust_env': False, 'verify': False},
        ]
        last_error = None
        for i, config in enumerate(network_configs):
            try:
                logger.info(f"尝试网络配置 {i+1}/{len(network_configs)}")
                with requests.Session() as session:
                    session.trust_env = config['trust_env']
                    response = session.get(
                        f"{cls.BASE_URL}/models",
                        headers=headers,
                        timeout=30,
                        verify=config['verify'],
                    )
                    if response.status_code != 200:
                        logger.warning(f"网络配置 {i+1} 失败: {response.status_code} - {response.text}")
                        continue
                    all_models = response.json().get('data') or []
                logger.info(f"SiliconFlow API返回 {len(all_models)} 个模型")
                models = cls._select_models(all_models, require_chat_keyword=True)
                logger.info(f"获取SiliconFlow模型列表成功 - 总数: {len(all_models)}, 聊天模型: {len(models)}")
                # Too few keyword matches: fall back to every non-excluded model.
                if len(models) < cls._MIN_CHAT_MODELS:
                    logger.warning("过滤后的聊天模型数量过少,返回所有非排除模型")
                    models = cls._select_models(all_models, require_chat_keyword=False)
                return models
            except Exception as e:
                last_error = e
                logger.warning(f"网络配置 {i+1} 异常: {str(e)}")
                continue
        # Every configuration failed.
        logger.error("所有网络配置都失败,无法获取模型列表")
        raise Exception("无法连接到SiliconFlow API服务器") from last_error

    @classmethod
    def _get_fallback_models(cls):
        """Return a hard-coded list of common models for use when the API is unreachable."""
        logger.info("使用预定义的模型列表作为备选方案")
        return [
            {
                'id': 'Qwen/Qwen2.5-7B-Instruct',
                'name': 'Qwen2.5-7B-Instruct',
                'description': '通义千问2.5 7B指令模型'
            },
            {
                'id': 'Qwen/Qwen2.5-14B-Instruct',
                'name': 'Qwen2.5-14B-Instruct',
                'description': '通义千问2.5 14B指令模型'
            },
            {
                'id': 'Qwen/Qwen2.5-32B-Instruct',
                'name': 'Qwen2.5-32B-Instruct',
                'description': '通义千问2.5 32B指令模型'
            },
            {
                'id': 'Qwen/Qwen2.5-72B-Instruct',
                'name': 'Qwen2.5-72B-Instruct',
                'description': '通义千问2.5 72B指令模型'
            },
            {
                'id': 'Qwen/QwQ-32B-Preview',
                'name': 'QwQ-32B-Preview',
                'description': '通义千问推理模型预览版'
            },
            {
                'id': 'deepseek-ai/DeepSeek-V2.5',
                'name': 'DeepSeek-V2.5',
                'description': 'DeepSeek V2.5 模型'
            },
            {
                'id': 'meta-llama/Llama-3.1-8B-Instruct',
                'name': 'Llama-3.1-8B-Instruct',
                'description': 'Meta Llama 3.1 8B指令模型'
            },
            {
                'id': 'meta-llama/Llama-3.1-70B-Instruct',
                'name': 'Llama-3.1-70B-Instruct',
                'description': 'Meta Llama 3.1 70B指令模型'
            },
            {
                'id': 'THUDM/glm-4-9b-chat',
                'name': 'GLM-4-9B-Chat',
                'description': '智谱GLM-4 9B对话模型'
            },
            {
                'id': 'internlm/internlm2_5-7b-chat',
                'name': 'InternLM2.5-7B-Chat',
                'description': '书生浦语2.5 7B对话模型'
            }
        ]