from openai import OpenAI
import base64
from datasets import load_dataset
import requests
from PIL import Image
from io import BytesIO
import pickle
import logging
from datetime import datetime
import concurrent.futures
import threading
from tqdm import tqdm

ds = load_dataset("DBQ/My.Theresa.Product.prices.France")
# Print dataset info; len(ds) would count splits, not rows, so use the 'train' split
print(f"Dataset loaded: {len(ds['train'])} rows")

# Initialize the client
# NOTE: avoid committing API keys in source; prefer an environment variable (see sketch below)
client = OpenAI(
    api_key='sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf',
    base_url='https://api.siliconflow.cn/v1'  # SiliconFlow base URL (OpenAI-compatible, serves the Qwen models)
)
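
# A safer alternative sketch (assumption: the key has been exported as
# SILICONFLOW_API_KEY in the environment; that variable name is illustrative):
#
#     import os
#     client = OpenAI(
#         api_key=os.environ["SILICONFLOW_API_KEY"],
#         base_url='https://api.siliconflow.cn/v1',
#     )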

# Thread locks for safe printing and result saving
print_lock = threading.Lock()
results_lock = threading.Lock()

def encode_image_to_base64(image_path):
    """Convert a local image file to a base64-encoded string."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except Exception:
        # Return None so the caller can decide how to handle unreadable images
        return None

# Query the vision-language model (the script uses Qwen/QVQ-72B-Preview via the OpenAI-compatible endpoint)
def chat_with_gpt(prompt, image_inputs=None):
    try:
        messages = []
        content = [{"type": "text", "text": prompt}]

        if image_inputs:
            # Ensure image_inputs is a list
            if not isinstance(image_inputs, list):
                image_inputs = [image_inputs]

            # Process each image
            for image_input in image_inputs:
                if image_input.startswith(('http://', 'https://')):
                    image_data = {
                        "type": "image_url",
                        "image_url": {"url": image_input}
                    }
                else:
                    # Handle a local image file
                    base64_image = encode_image_to_base64(image_input)
                    if base64_image:
                        image_data = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    else:
                        raise Exception("Could not read local image")
                content.append(image_data)

        messages.append({
            "role": "user",
            "content": content
        })

        response = client.chat.completions.create(
            model="Qwen/QVQ-72B-Preview",
            messages=messages,
            temperature=0.1,        # lower temperature for more deterministic output
            top_p=0.2,              # narrower sampling for more conservative output
            max_tokens=200,         # cap the response length
            presence_penalty=0.0,   # do not push the model toward new topics
            frequency_penalty=0.0,  # do not penalize frequent tokens
            stream=False
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}"
# Logging configuration
def setup_logging():
    # Build a log file name that includes a timestamp
    log_filename = f'process_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt'

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_filename, encoding='utf-8'),
            logging.StreamHandler()  # also echo to the console
        ]
    )
    return log_filename

def process_single_item(item):
    """Process a single dataset row."""
    try:
        image_url = item['imageurl']
        response = requests.get(image_url, timeout=30)  # timeout keeps worker threads from hanging
        img = Image.open(BytesIO(response.content))

        prompt = f"The brand of the product in the picture is {item['brand']}, write a caption including brand information."
        gpt_response = chat_with_gpt(prompt, image_url)

        return {
            'brand': item['brand'],
            'image': img,
            'response': gpt_response
        }
    except Exception as e:
        with print_lock:
            logging.error(f"Error while processing item: {str(e)}")
        return None
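
# Single-item smoke-test sketch (illustrative): process one row before launching
# the full threaded run, to verify the dataset fields and the API round trip.
def _demo_process_single_item():
    sample = ds['train'][0]
    result = process_single_item(sample)
    if result:
        print(result['brand'], '->', result['response'])
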
def process_dataset(ds, batch_size=20000, max_workers=10):
    results = []
    valid_count = 0
    total_items = len(ds['train'])

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Progress bar over all rows
        with tqdm(total=total_items, desc="Processing data") as pbar:
            # Submit all tasks up front
            future_to_item = {executor.submit(process_single_item, item): i
                              for i, item in enumerate(ds['train'])}

            # Handle tasks as they complete
            for future in concurrent.futures.as_completed(future_to_item):
                result = future.result()
                if result:
                    with results_lock:
                        results.append(result)
                        valid_count += 1

                        # Save a batch once batch_size results have accumulated
                        if len(results) >= batch_size:
                            save_results(results, valid_count // batch_size)
                            results = []

                pbar.update(1)

    # Save any remaining results
    if results:
        save_results(results, (valid_count // batch_size) + 1)

    logging.info(f"Processing finished, {valid_count} valid records in total")

def save_results(results, batch_num):
    # Persist the batch as a pickle file
    with open(f'batch_{batch_num}.pkl', 'wb') as f:
        pickle.dump(results, f)
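
# Loader sketch for the saved batches (illustrative): pickle.load returns the
# list of result dicts written by save_results, so a batch can be re-read like this.
def _demo_load_batch(batch_num=1):
    with open(f'batch_{batch_num}.pkl', 'rb') as f:
        results = pickle.load(f)
    print(f"batch {batch_num}: {len(results)} records")
    return results
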
# Run the pipeline
if __name__ == "__main__":
    # Set up logging
    log_file = setup_logging()
    logging.info("Starting dataset processing")
    logging.info(f"Dataset loaded: {len(ds['train'])} rows")

    process_dataset(ds)