import base64
import concurrent.futures
import logging
import pickle
import threading
from datetime import datetime
from io import BytesIO

import requests
from datasets import load_dataset
from openai import OpenAI
from PIL import Image
from tqdm import tqdm

ds = load_dataset("DBQ/My.Theresa.Product.prices.France")

# Print dataset info (the actual records live in the "train" split)
print(f"Dataset loaded: {len(ds['train'])} records")

# Initialize the client
client = OpenAI(
    api_key='sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf',
    base_url='https://api.siliconflow.cn/v1'  # SiliconFlow API base URL (hosts the Qwen models)
)

# Locks for thread-safe logging and result collection
print_lock = threading.Lock()
results_lock = threading.Lock()


def encode_image_to_base64(image_path):
    """Encode a local image file as base64; return None on failure."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except Exception:
        return None


# Query the vision-language model
def chat_with_gpt(prompt, image_inputs=None):
    try:
        messages = []
        content = [{"type": "text", "text": prompt}]

        if image_inputs:
            # Normalize a single image into a list
            if not isinstance(image_inputs, list):
                image_inputs = [image_inputs]

            # Attach each image, either as a remote URL or as inline base64 data
            for image_input in image_inputs:
                if image_input.startswith(('http://', 'https://')):
                    image_data = {
                        "type": "image_url",
                        "image_url": {"url": image_input}
                    }
                else:
                    # Local image: embed as a data URL
                    base64_image = encode_image_to_base64(image_input)
                    if base64_image:
                        image_data = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    else:
                        raise Exception("Could not read local image")
                content.append(image_data)

        messages.append({
            "role": "user",
            "content": content
        })

        response = client.chat.completions.create(
            model="Qwen/QVQ-72B-Preview",
            messages=messages,
            temperature=0.1,        # low temperature for more deterministic output
            top_p=0.2,              # narrow sampling for more conservative output
            max_tokens=200,         # cap the caption length
            presence_penalty=0.0,   # do not push the model toward new topics
            frequency_penalty=0.0,  # do not penalize frequent tokens
            stream=False
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}"


# Configure logging
def setup_logging():
    # Log file name includes a timestamp
    log_filename = f'process_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt'

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_filename, encoding='utf-8'),
            logging.StreamHandler()  # also log to the console
        ]
    )
    return log_filename


def process_single_item(item):
    """Process one dataset record: download its image and generate a caption."""
    try:
        image_url = item['imageurl']
        # Timeout and status check keep a bad URL from stalling a worker thread
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))

        prompt = (f"The brand of the product in the picture is {item['brand']}, "
                  f"write a caption including brand information.")
        gpt_response = chat_with_gpt(prompt, image_url)

        return {
            'brand': item['brand'],
            'image': img,
            'response': gpt_response
        }
    except Exception as e:
        with print_lock:
            logging.error(f"Error while processing record: {str(e)}")
        return None


def process_dataset(ds, batch_size=20000, max_workers=10):
    results = []
    valid_count = 0
    total_items = len(ds['train'])

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Progress bar over all records
        with tqdm(total=total_items, desc="Processing") as pbar:
            # Submit every record as its own task
            future_to_item = {executor.submit(process_single_item, item): i
                              for i, item in enumerate(ds['train'])}

            # Collect tasks as they complete
            for future in concurrent.futures.as_completed(future_to_item):
                result = future.result()
                if result:
                    with results_lock:
                        results.append(result)
                        valid_count += 1
                        # Flush to disk whenever a full batch accumulates
                        if len(results) >= batch_size:
                            save_results(results, valid_count // batch_size)
                            results = []
                pbar.update(1)

    # Save any remaining results
    if results:
        save_results(results, (valid_count // batch_size) + 1)

    logging.info(f"Done: {valid_count} records processed successfully")


def save_results(results, batch_num):
    # Persist one batch as a pickle file
    with open(f'batch_{batch_num}.pkl', 'wb') as f:
        pickle.dump(results, f)


# Run the pipeline
if __name__ == "__main__":
    # Set up logging
    log_file = setup_logging()
    logging.info("Starting dataset processing")
    logging.info(f"Dataset loaded: {len(ds['train'])} records")
    process_dataset(ds)
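# Reloading the saved batches: a minimal sketch for later inspection or
# downstream use; it is NOT called by the pipeline above. It assumes the
# batch_<n>.pkl files written by save_results() sit in the current working
# directory; the function name and pattern argument are illustrative.
def load_saved_batches(pattern='batch_*.pkl'):
    """Collect every record from the batch_*.pkl files into one list."""
    import glob
    all_results = []
    for path in sorted(glob.glob(pattern)):
        with open(path, 'rb') as f:
            all_results.extend(pickle.load(f))
    return all_results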