# Fortrain/webdata.py

from openai import OpenAI
import base64
from datasets import load_dataset
import requests
from PIL import Image
from io import BytesIO
import pickle
import logging
from datetime import datetime
import concurrent.futures
import threading
from tqdm import tqdm

ds = load_dataset("DBQ/My.Theresa.Product.prices.France")

# Print dataset info
print(f"Dataset loaded: {len(ds['train'])} records")

# Initialize the API client
client = OpenAI(
    api_key='sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf',
    base_url='https://api.siliconflow.cn/v1'  # base URL of the Tongyi Qianwen (Qwen) API
)
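
# A safer alternative (hypothetical; assumes the key has been exported as
# SILICONFLOW_API_KEY) would read the credential from the environment instead
# of hard-coding it:
#
#   import os
#   client = OpenAI(
#       api_key=os.environ["SILICONFLOW_API_KEY"],
#       base_url='https://api.siliconflow.cn/v1',
#   )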

# Locks for thread-safe printing/logging and result collection
print_lock = threading.Lock()
results_lock = threading.Lock()


def encode_image_to_base64(image_path):
    """Encode a local image file as a base64 string; return None on failure."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except Exception:
        return None


# Chat with the model (despite the function name, this calls Qwen/QVQ below)
def chat_with_gpt(prompt, image_inputs=None):
    try:
        messages = []
        content = [{"type": "text", "text": prompt}]
        if image_inputs:
            # Make sure image_inputs is a list
            if not isinstance(image_inputs, list):
                image_inputs = [image_inputs]
            # Process each image
            for image_input in image_inputs:
                if image_input.startswith(('http://', 'https://')):
                    image_data = {
                        "type": "image_url",
                        "image_url": {"url": image_input}
                    }
                else:
                    # Handle a local image file
                    base64_image = encode_image_to_base64(image_input)
                    if base64_image:
                        image_data = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    else:
                        raise Exception("Unable to read local image")
                content.append(image_data)
        messages.append({
            "role": "user",
            "content": content
        })
        response = client.chat.completions.create(
            model="Qwen/QVQ-72B-Preview",
            messages=messages,
            temperature=0.1,        # low temperature for more deterministic output
            top_p=0.2,              # narrow sampling for more conservative output
            max_tokens=200,         # cap the answer length
            presence_penalty=0.0,   # do not push the model toward new topics
            frequency_penalty=0.0,  # do not penalize frequent tokens
            stream=False
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}"


# Logging configuration
def setup_logging():
    # Create a log file name that includes a timestamp
    log_filename = f'process_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt'
    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_filename, encoding='utf-8'),
            logging.StreamHandler()  # also echo log messages to the console
        ]
    )
    return log_filename


def process_single_item(item):
    """Process a single dataset row: download the image and caption it."""
    try:
        image_url = item['imageurl']
        # A timeout keeps a stalled download from hanging a worker thread
        response = requests.get(image_url, timeout=30)
        img = Image.open(BytesIO(response.content))
        prompt = f"The brand of the product in the picture is {item['brand']}, write a caption including brand information."
        gpt_response = chat_with_gpt(prompt, image_url)
        return {
            'brand': item['brand'],
            'image': img,
            'response': gpt_response
        }
    except Exception as e:
        with print_lock:
            logging.error(f"Error while processing an item: {str(e)}")
        return None
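
# Quick standalone check of one row (sketch; uses the first training example):
#
#   sample = ds['train'][0]
#   print(process_single_item(sample))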


def process_dataset(ds, batch_size=20000, max_workers=10):
    results = []
    valid_count = 0
    total_items = len(ds['train'])
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Create a progress bar
        with tqdm(total=total_items, desc="Processing data") as pbar:
            # Submit all tasks
            future_to_item = {executor.submit(process_single_item, item): i
                              for i, item in enumerate(ds['train'])}
            # Collect tasks as they complete
            for future in concurrent.futures.as_completed(future_to_item):
                result = future.result()
                if result:
                    with results_lock:
                        results.append(result)
                        valid_count += 1
                        # Flush to disk once batch_size results have accumulated
                        if len(results) >= batch_size:
                            save_results(results, valid_count // batch_size)
                            results = []
                pbar.update(1)
    # Save any remaining results
    if results:
        save_results(results, (valid_count // batch_size) + 1)
    logging.info(f"Done: {valid_count} valid items processed")


def save_results(results, batch_num):
    # Persist one batch of results as a pickle file
    with open(f'batch_{batch_num}.pkl', 'wb') as f:
        pickle.dump(results, f)
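
# For downstream use, a saved batch can be loaded back like this (sketch;
# 'batch_1.pkl' stands for whichever batch file save_results wrote):
#
#   with open('batch_1.pkl', 'rb') as f:
#       batch = pickle.load(f)  # list of {'brand', 'image', 'response'} dicts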


# Run the pipeline
if __name__ == "__main__":
    # Set up logging
    log_file = setup_logging()
    logging.info("Starting dataset processing")
    logging.info(f"Dataset loaded: {len(ds['train'])} records")
    process_dataset(ds)