import requests
import json
import time
import logging

logger = logging.getLogger(__name__)


class SiliconFlowClient:

    def __init__(self, api_key="sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf", model="Qwen/QwQ-32B"):
        """
        Initialize the SiliconFlow client.
        """
        self.api_key = api_key
        self.model = model
        self.base_url = "https://api.siliconflow.cn/v1"
        self.messages = []
        self.system_message = None

        logger.info(f"Initialized SiliconFlow client - model: {model}")

    def set_model(self, model):
        """Set the model to use."""
        self.model = model
        logger.info(f"SiliconFlow switched model to: {model}")

    def set_system_message(self, message):
        """Set the system message and reset the conversation history."""
        self.system_message = message
        self.messages = []
        if message:
            self.messages.append({"role": "system", "content": message})
        logger.debug(f"SiliconFlow system message set - length: {len(message) if message else 0}")

    def add_message(self, role, content):
        """Append a message to the conversation history."""
        self.messages.append({"role": role, "content": content})

    def clear_history(self):
        """Clear the conversation history, keeping the system message if one is set."""
        self.messages = []
        if self.system_message:
            self.messages.append({"role": "system", "content": self.system_message})
        logger.debug("SiliconFlow conversation history cleared")

    def chat(self, message):
        """
        Non-streaming chat completion.
        """
        try:
            # Add the user message to the history
            self.add_message("user", message)

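            # OpenAI-style chat completions request body; max_tokens,
            # temperature and top_p are fixed here rather than exposed
            # as parameters of chat().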
            payload = {
                "model": self.model,
                "messages": self.messages,
                "stream": False,
                "max_tokens": 2048,
                "temperature": 0.7,
                "top_p": 0.7,
            }

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            response = requests.post(
                f"{self.base_url}/chat/completions",
                json=payload,
                headers=headers,
                timeout=60
            )

            if response.status_code == 200:
                data = response.json()
                assistant_message = data['choices'][0]['message']['content']
                self.add_message("assistant", assistant_message)
                return assistant_message
            else:
                error_msg = f"SiliconFlow API error: {response.status_code} - {response.text}"
                logger.error(error_msg)
                return f"API call failed: {error_msg}"

        except Exception as e:
            error_msg = f"SiliconFlow chat error: {str(e)}"
            logger.exception("SiliconFlow chat exception")
            return error_msg

    def chat_stream(self, message):
        """
        Streaming chat completion; yields content chunks as they arrive.
        """
        try:
            # Add the user message to the history
            self.add_message("user", message)

            payload = {
                "model": self.model,
                "messages": self.messages,
                "stream": True,
                "max_tokens": 2048,
                "temperature": 0.7,
                "top_p": 0.7,
            }

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            response = requests.post(
                f"{self.base_url}/chat/completions",
                json=payload,
                headers=headers,
                stream=True,
                timeout=120
            )

            if response.status_code != 200:
                error_msg = f"SiliconFlow API error: {response.status_code} - {response.text}"
                logger.error(error_msg)
                yield f"API call failed: {error_msg}"
                return

            assistant_message = ""
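            # The streaming response is server-sent events: each line looks like
            # "data: {json chunk}" and the stream ends with "data: [DONE]".
            # Incremental text arrives in choices[0].delta.content.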
            for line in response.iter_lines():
                if line:
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        data_str = line[6:]  # Strip the 'data: ' prefix

                        if data_str.strip() == '[DONE]':
                            break

                        try:
                            data = json.loads(data_str)
                            if 'choices' in data and len(data['choices']) > 0:
                                delta = data['choices'][0].get('delta', {})
                                content = delta.get('content', '')

                                if content:
                                    assistant_message += content
                                    yield content

                        except json.JSONDecodeError:
                            continue

            # Append the complete assistant reply to the history
            if assistant_message:
                self.add_message("assistant", assistant_message)

        except requests.exceptions.Timeout:
            error_msg = "SiliconFlow request timed out"
            logger.error(error_msg)
            yield error_msg
        except Exception as e:
            error_msg = f"SiliconFlow streaming chat error: {str(e)}"
            logger.exception("SiliconFlow streaming chat exception")
            yield error_msg

    @classmethod
    def get_available_models(cls, api_key="sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf"):
        """
        Fetch the list of available models from the SiliconFlow API.
        """
        # Try several network configurations in turn
        proxy_configs = [
            # Default settings (requests honors any environment proxies)
            {'proxies': None, 'verify': True},
            # Same, but with SSL verification disabled
            {'proxies': None, 'verify': False},
            # Explicitly disable proxies
            {'proxies': {'http': None, 'https': None}, 'verify': True},
            {'proxies': {'http': None, 'https': None}, 'verify': False},
        ]

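        # Try each configuration in order and return on the first success;
        # if every attempt fails, raise so the caller can fall back
        # (e.g. to _get_fallback_models below).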
        for i, config in enumerate(proxy_configs):
            try:
                logger.info(f"Trying network configuration {i+1}/{len(proxy_configs)}")

                headers = {"Authorization": f"Bearer {api_key}"}

                # Use a session for finer control over the connection
                session = requests.Session()
                if config['proxies'] is not None:
                    session.proxies.update(config['proxies'])

                response = session.get(
                    "https://api.siliconflow.cn/v1/models",
                    headers=headers,
                    timeout=30,
                    verify=config['verify']
                )

                if response.status_code == 200:
                    data = response.json()
                    all_models = data.get('data', [])
                    logger.info(f"SiliconFlow API returned {len(all_models)} models")

                    models = []
                    excluded_keywords = ['embedding', 'stable-diffusion', 'bge-', 'rerank', 'whisper']

                    for model in all_models:
                        model_id = model.get('id', '')
                        if not model_id:
                            continue

                        # Skip models that are clearly not chat models
                        if any(keyword in model_id.lower() for keyword in excluded_keywords):
                            continue

                        # Keep models whose id contains a common chat-model keyword (chat, instruct, etc.)
                        chat_keywords = ['chat', 'instruct', 'qwen', 'glm', 'internlm', 'baichuan', 'llama', 'mistral', 'claude', 'gpt', 'yi']
                        if any(keyword in model_id.lower() for keyword in chat_keywords):
                            models.append({
                                'id': model_id,
                                'name': model_id,
                                'description': model.get('description', ''),
                            })

                    logger.info(f"Fetched SiliconFlow model list - total: {len(all_models)}, chat models: {len(models)}")

                    # If filtering left too few models, return everything except the explicit exclusions
                    if len(models) < 10:
                        logger.warning("Too few chat models after filtering, returning all non-excluded models")
                        models = []
                        for model in all_models:
                            model_id = model.get('id', '')
                            if model_id and not any(keyword in model_id.lower() for keyword in excluded_keywords):
                                models.append({
                                    'id': model_id,
                                    'name': model_id,
                                    'description': model.get('description', ''),
                                })

                    return models
                else:
                    logger.warning(f"Network configuration {i+1} failed: {response.status_code} - {response.text}")

            except Exception as e:
                logger.warning(f"Network configuration {i+1} raised an exception: {str(e)}")
                continue

        # Every configuration failed
        logger.error("All network configurations failed, unable to fetch the model list")
        raise Exception("Unable to connect to the SiliconFlow API server")

    @classmethod
    def _get_fallback_models(cls):
        """
        Return a predefined list of common models for when the API call fails.
        """
        logger.info("Using the predefined model list as a fallback")
        return [
            {
                'id': 'Qwen/Qwen2.5-7B-Instruct',
                'name': 'Qwen2.5-7B-Instruct',
                'description': 'Qwen2.5 7B instruct model'
            },
            {
                'id': 'Qwen/Qwen2.5-14B-Instruct',
                'name': 'Qwen2.5-14B-Instruct',
                'description': 'Qwen2.5 14B instruct model'
            },
            {
                'id': 'Qwen/Qwen2.5-32B-Instruct',
                'name': 'Qwen2.5-32B-Instruct',
                'description': 'Qwen2.5 32B instruct model'
            },
            {
                'id': 'Qwen/Qwen2.5-72B-Instruct',
                'name': 'Qwen2.5-72B-Instruct',
                'description': 'Qwen2.5 72B instruct model'
            },
            {
                'id': 'Qwen/QwQ-32B-Preview',
                'name': 'QwQ-32B-Preview',
                'description': 'Qwen reasoning model, preview release'
            },
            {
                'id': 'deepseek-ai/DeepSeek-V2.5',
                'name': 'DeepSeek-V2.5',
                'description': 'DeepSeek V2.5 model'
            },
            {
                'id': 'meta-llama/Llama-3.1-8B-Instruct',
                'name': 'Llama-3.1-8B-Instruct',
                'description': 'Meta Llama 3.1 8B instruct model'
            },
            {
                'id': 'meta-llama/Llama-3.1-70B-Instruct',
                'name': 'Llama-3.1-70B-Instruct',
                'description': 'Meta Llama 3.1 70B instruct model'
            },
            {
                'id': 'THUDM/glm-4-9b-chat',
                'name': 'GLM-4-9B-Chat',
                'description': 'Zhipu GLM-4 9B chat model'
            },
            {
                'id': 'internlm/internlm2_5-7b-chat',
                'name': 'InternLM2.5-7B-Chat',
                'description': 'InternLM2.5 7B chat model'
            }
        ]
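

# The demo below is a minimal usage sketch, not part of the original module:
# it assumes a SILICONFLOW_API_KEY environment variable (or the baked-in
# default key) grants access to the API.
if __name__ == "__main__":
    import os

    logging.basicConfig(level=logging.INFO)

    api_key = os.environ.get("SILICONFLOW_API_KEY")
    client = SiliconFlowClient(api_key=api_key) if api_key else SiliconFlowClient()
    client.set_system_message("You are a helpful assistant.")

    # Non-streaming: returns the full reply (or an error string) at once.
    print(client.chat("Introduce yourself in one sentence."))

    # Streaming: consume incremental chunks as the generator yields them.
    for chunk in client.chat_stream("Count from 1 to 5."):
        print(chunk, end="", flush=True)
    print()

    # Model discovery, falling back to the predefined list if the API is unreachable.
    try:
        models = SiliconFlowClient.get_available_models(api_key or client.api_key)
    except Exception:
        models = SiliconFlowClient._get_fallback_models()
    print(f"{len(models)} models available")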