import requests
import json
import time
import logging

logger = logging.getLogger(__name__)


class SiliconFlowClient:
    def __init__(self, api_key="sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf", model="Qwen/QwQ-32B"):
        """
        Initialize the SiliconFlow client.
        """
        self.api_key = api_key
        self.model = model
        self.base_url = "https://api.siliconflow.cn/v1"
        self.messages = []
        self.system_message = None

        logger.info(f"Initialized SiliconFlow client - model: {model}")

    def set_model(self, model):
        """Set the model to use."""
        self.model = model
        logger.info(f"SiliconFlow switched model to: {model}")

    def set_system_message(self, message):
        """Set the system message and reset the conversation history."""
        self.system_message = message
        self.messages = []
        if message:
            self.messages.append({"role": "system", "content": message})
        logger.debug(f"SiliconFlow system message set - length: {len(message) if message else 0}")

    def add_message(self, role, content):
        """Append a message to the conversation history."""
        self.messages.append({"role": role, "content": content})

    def clear_history(self):
        """Clear the conversation history, keeping the system message if set."""
        self.messages = []
        if self.system_message:
            self.messages.append({"role": "system", "content": self.system_message})
        logger.debug("SiliconFlow conversation history cleared")

    def chat(self, message):
        """
        Non-streaming chat completion.
        """
        try:
            # Add the user message to the history
            self.add_message("user", message)

            payload = {
                "model": self.model,
                "messages": self.messages,
                "stream": False,
                "max_tokens": 2048,
                "temperature": 0.7,
                "top_p": 0.7,
            }

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            response = requests.post(
                f"{self.base_url}/chat/completions",
                json=payload,
                headers=headers,
                timeout=60
            )

            if response.status_code == 200:
                data = response.json()
                assistant_message = data['choices'][0]['message']['content']
                self.add_message("assistant", assistant_message)
                return assistant_message
            else:
                error_msg = f"SiliconFlow API error: {response.status_code} - {response.text}"
                logger.error(error_msg)
                return f"API call failed: {error_msg}"

        except Exception as e:
            error_msg = f"SiliconFlow chat error: {str(e)}"
            logger.exception("SiliconFlow chat exception")
            return error_msg

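    # Note (added for clarity): chat() assumes the OpenAI-compatible chat-completions
    # response shape that SiliconFlow returns, roughly:
    #   {"choices": [{"message": {"role": "assistant", "content": "..."}}], ...}
    # chat_stream() below parses the corresponding SSE stream, where each event line
    # is "data: {json chunk}" and the stream terminates with "data: [DONE]".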
    def chat_stream(self, message):
        """
        Streaming chat completion; yields content chunks as they arrive.
        """
        try:
            # Add the user message to the history
            self.add_message("user", message)

            payload = {
                "model": self.model,
                "messages": self.messages,
                "stream": True,
                "max_tokens": 2048,
                "temperature": 0.7,
                "top_p": 0.7,
            }

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            response = requests.post(
                f"{self.base_url}/chat/completions",
                json=payload,
                headers=headers,
                stream=True,
                timeout=120
            )

            if response.status_code != 200:
                error_msg = f"SiliconFlow API error: {response.status_code} - {response.text}"
                logger.error(error_msg)
                yield f"API call failed: {error_msg}"
                return

            assistant_message = ""
            for line in response.iter_lines():
                if line:
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        data_str = line[6:]  # strip the 'data: ' prefix

                        if data_str.strip() == '[DONE]':
                            break

                        try:
                            data = json.loads(data_str)
                            if 'choices' in data and len(data['choices']) > 0:
                                delta = data['choices'][0].get('delta', {})
                                content = delta.get('content', '')

                                if content:
                                    assistant_message += content
                                    yield content

                        except json.JSONDecodeError:
                            continue

            # Append the complete assistant reply to the history
            if assistant_message:
                self.add_message("assistant", assistant_message)

        except requests.exceptions.Timeout:
            error_msg = "SiliconFlow request timed out"
            logger.error(error_msg)
            yield error_msg
        except Exception as e:
            error_msg = f"SiliconFlow streaming chat error: {str(e)}"
            logger.exception("SiliconFlow streaming chat exception")
            yield error_msg

    @classmethod
    def get_available_models(cls, api_key="sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf"):
        """
        Fetch the list of available models, trying several network configurations.
        """
        proxy_configs = [
            # Default settings: system proxy (if any), SSL verification enabled
            {'proxies': None, 'verify': True},
            # System proxy with SSL verification disabled
            {'proxies': None, 'verify': False},
            # Explicitly disable proxies
            {'proxies': {'http': None, 'https': None}, 'verify': True},
            {'proxies': {'http': None, 'https': None}, 'verify': False},
        ]

        for i, config in enumerate(proxy_configs):
            try:
                logger.info(f"Trying network configuration {i+1}/{len(proxy_configs)}")

                headers = {"Authorization": f"Bearer {api_key}"}

                # Use a session for finer control over the connection
                session = requests.Session()
                if config['proxies'] is not None:
                    session.proxies.update(config['proxies'])

                response = session.get(
                    "https://api.siliconflow.cn/v1/models",
                    headers=headers,
                    timeout=30,
                    verify=config['verify']
                )

                if response.status_code == 200:
                    data = response.json()
                    all_models = data.get('data', [])
                    logger.info(f"SiliconFlow API returned {len(all_models)} models")

                    models = []
                    excluded_keywords = ['embedding', 'stable-diffusion', 'bge-', 'rerank', 'whisper']

                    for model in all_models:
                        model_id = model.get('id', '')
                        if not model_id:
                            continue

                        # Skip models that are clearly not chat models
                        if any(keyword in model_id.lower() for keyword in excluded_keywords):
                            continue

                        # Keep models whose id contains common chat-model keywords (chat, instruct, etc.)
                        chat_keywords = ['chat', 'instruct', 'qwen', 'glm', 'internlm', 'baichuan', 'llama', 'mistral', 'claude', 'gpt', 'yi']
                        if any(keyword in model_id.lower() for keyword in chat_keywords):
                            models.append({
                                'id': model_id,
                                'name': model_id,
                                'description': model.get('description', ''),
                            })

                    logger.info(f"Fetched SiliconFlow model list - total: {len(all_models)}, chat models: {len(models)}")

                    # If filtering leaves too few models, return everything except the explicit exclusions
                    if len(models) < 10:
                        logger.warning("Too few chat models after filtering; returning all non-excluded models")
                        models = []
                        for model in all_models:
                            model_id = model.get('id', '')
                            if model_id and not any(keyword in model_id.lower() for keyword in excluded_keywords):
                                models.append({
                                    'id': model_id,
                                    'name': model_id,
                                    'description': model.get('description', ''),
                                })

                    return models
                else:
                    logger.warning(f"Network configuration {i+1} failed: {response.status_code} - {response.text}")

            except Exception as e:
                logger.warning(f"Network configuration {i+1} raised an exception: {str(e)}")
                continue

        # All configurations failed
        logger.error("All network configurations failed; could not fetch the model list")
        raise Exception("Unable to connect to the SiliconFlow API server")

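    # Note (added for clarity): get_available_models() raises when every network
    # configuration fails; _get_fallback_models() below provides a hard-coded list
    # that callers can use in that case (it is not invoked automatically here).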
    @classmethod
    def _get_fallback_models(cls):
        """
        Return a predefined list of common models when the API call fails.
        """
        logger.info("Using the predefined model list as a fallback")
        return [
            {
                'id': 'Qwen/Qwen2.5-7B-Instruct',
                'name': 'Qwen2.5-7B-Instruct',
                'description': 'Qwen2.5 7B instruct model'
            },
            {
                'id': 'Qwen/Qwen2.5-14B-Instruct',
                'name': 'Qwen2.5-14B-Instruct',
                'description': 'Qwen2.5 14B instruct model'
            },
            {
                'id': 'Qwen/Qwen2.5-32B-Instruct',
                'name': 'Qwen2.5-32B-Instruct',
                'description': 'Qwen2.5 32B instruct model'
            },
            {
                'id': 'Qwen/Qwen2.5-72B-Instruct',
                'name': 'Qwen2.5-72B-Instruct',
                'description': 'Qwen2.5 72B instruct model'
            },
            {
                'id': 'Qwen/QwQ-32B-Preview',
                'name': 'QwQ-32B-Preview',
                'description': 'Qwen QwQ reasoning model, preview version'
            },
            {
                'id': 'deepseek-ai/DeepSeek-V2.5',
                'name': 'DeepSeek-V2.5',
                'description': 'DeepSeek V2.5 model'
            },
            {
                'id': 'meta-llama/Llama-3.1-8B-Instruct',
                'name': 'Llama-3.1-8B-Instruct',
                'description': 'Meta Llama 3.1 8B instruct model'
            },
            {
                'id': 'meta-llama/Llama-3.1-70B-Instruct',
                'name': 'Llama-3.1-70B-Instruct',
                'description': 'Meta Llama 3.1 70B instruct model'
            },
            {
                'id': 'THUDM/glm-4-9b-chat',
                'name': 'GLM-4-9B-Chat',
                'description': 'Zhipu GLM-4 9B chat model'
            },
            {
                'id': 'internlm/internlm2_5-7b-chat',
                'name': 'InternLM2.5-7B-Chat',
                'description': 'InternLM2.5 7B chat model'
            }
        ]
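

# Illustrative usage sketch (not part of the original module): shows how the
# client above might be exercised. It assumes a valid SiliconFlow API key and
# network access; the prompts are placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    client = SiliconFlowClient()  # uses the default api_key and model
    client.set_system_message("You are a helpful assistant.")

    # Non-streaming call: returns the full reply as a single string
    print(client.chat("Hello, who are you?"))

    # Streaming call: chunks are yielded as they arrive
    for chunk in client.chat_stream("Tell me a short joke."):
        print(chunk, end="", flush=True)
    print()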