diff --git a/clean_text.py b/clean_text.py new file mode 100644 index 0000000..bd18302 --- /dev/null +++ b/clean_text.py @@ -0,0 +1,43 @@ +import pickle + +def clean_text(text): + # 将文本反转 + reversed_text = text[::-1] + # 查找第一个句号的位置 + dot_pos = reversed_text.find('.') + + if dot_pos == -1: # 如果没有找到句号 + return text + + # 删除句号之前的所有文本,然后再次反转 + cleaned_text = reversed_text[dot_pos:][::-1] + return cleaned_text.strip() + +# 加载原始数据 +print("正在加载数据...") +with open('batch_1.pkl', 'rb') as f: + data = pickle.load(f) + +# 处理文本 +print("正在处理文本...") +cleaned_data = [] +for item in data: + cleaned_item = item.copy() # 复制原始数据项 + cleaned_item['response'] = clean_text(item['response']) + cleaned_data.append(cleaned_item) + +# 保存处理后的数据 +print("正在保存清理后的数据...") +with open('batch_1_cleaned.pkl', 'wb') as f: + pickle.dump(cleaned_data, f) + +# 打印示例 +print("\n处理示例:") +for i in range(min(3, len(data))): + print(f"\n原始文本 {i+1}:") + print(data[i]['response']) + print(f"\n处理后文本 {i+1}:") + print(cleaned_data[i]['response']) + +print(f"\n总数据量: {len(data)}") +print("数据已保存到 batch_1_cleaned.pkl") \ No newline at end of file diff --git a/test2.py b/test2.py new file mode 100644 index 0000000..93f5817 --- /dev/null +++ b/test2.py @@ -0,0 +1,168 @@ +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '3' + +import numpy as np +import torch +import torchvision.transforms as T +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoModel, AutoTokenizer +import requests +from io import BytesIO +from urllib.parse import urlparse +from torch.profiler import profile, record_function, ProfilerActivity +import time +import pickle +import random +from datetime import datetime + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + +def build_transform(input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + return transform + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + +def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + +def load_image(image_file, input_size=448, max_num=12): + # 如果已经是 PIL Image 对象 + if isinstance(image_file, Image.Image): + image = image_file.convert('RGB') + # 如果是 URL + elif isinstance(image_file, str) and bool(urlparse(image_file).netloc): + try: + response = requests.get(image_file, timeout=10) + response.raise_for_status() # 检查请求是否成功 + image = Image.open(BytesIO(response.content)).convert('RGB') + except Exception as e: + raise ValueError(f"无法从URL加载图片: {str(e)}") + # 如果是本地文件路径 + elif isinstance(image_file, str): + image = Image.open(image_file).convert('RGB') + # 如果是字节数据 + elif isinstance(image_file, bytes): + image = Image.open(BytesIO(image_file)).convert('RGB') + else: + raise ValueError(f"不支持的图片格式: {type(image_file)}") + + transform = build_transform(input_size=input_size) + images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + +# If you want to load a model using multiple GPUs, please refer to the `Multiple GPUs` section. +path = 'Internvl2_5' +model = AutoModel.from_pretrained( + path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + use_flash_attn=True, + trust_remote_code=True, + vision_model = None, + language_model = None).eval().cuda() + +# 加载要测试的权重 +state_dict = torch.load("mlp_epoch_5.pth") # 加载训练好的权重 +model.load_state_dict(state_dict) +generation_config = dict(max_new_tokens=1024, do_sample=True) +tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False) + +print(model) + +# 测试 +# 加载数据 +with open('batch_1.pkl', 'rb') as f: + data = pickle.load(f) + +# 随机选择100条数据 +test_samples = random.sample(data, 100) + +# 创建结果文件 +timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") +results_file = f'test_results_{timestamp}.txt' + +with open(results_file, 'w', encoding='utf-8') as f: + for i, test_item in enumerate(test_samples, 1): + question = "Identify the brand of the product in the picture, and write a caption including brand information in 200 words." + expected_answer = test_item['response'] + + try: + pixel_values = load_image(test_item['image'], max_num=12).to(torch.bfloat16).cuda() + response, _ = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True) + + # 写入结果 + f.write(f"Test case {i}/100:\n") + f.write(f"Question: {question}\n") + f.write(f"Expected: {expected_answer}\n") + f.write(f"Response: {response}\n") + f.write("-" * 50 + "\n") + + # 打印进度 + print(f"Processed {i}/100 samples") + + except Exception as e: + f.write(f"Test case {i}/100 - ERROR:\n") + f.write(f"Error message: {str(e)}\n") + f.write("-" * 50 + "\n") + print(f"Error processing sample {i}: {str(e)}") + +print(f"测试完成!结果已保存到文件:{results_file}") diff --git a/webdata.py b/webdata.py new file mode 100644 index 0000000..f39738e --- /dev/null +++ b/webdata.py @@ -0,0 +1,168 @@ +from openai import OpenAI +import base64 +from datasets import load_dataset +import requests +from PIL import Image +from io import BytesIO +import pickle +import logging +from datetime import datetime +import concurrent.futures +import threading +from tqdm import tqdm + +ds = load_dataset("DBQ/My.Theresa.Product.prices.France") +# 打印数据集信息 +print(f"数据集加载完成,共包含 {len(ds)} 条数据") + +# 初始化客户端 +client = OpenAI( + api_key='sk-xqbujijjqqmlmlvkhvxeogqjtzslnhdtqxqgiyuhwpoqcjvf', + base_url='https://api.siliconflow.cn/v1' # 通义千问API的基础URL +) + +# 添加线程锁用于安全打印和保存 +print_lock = threading.Lock() +results_lock = threading.Lock() + +def encode_image_to_base64(image_path): + """将本地图片转换为base64编码""" + try: + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + except Exception as e: + return None + +# 使用ChatGPT模型 +def chat_with_gpt(prompt, image_inputs=None): + try: + messages = [] + content = [{"type": "text", "text": prompt}] + + if image_inputs: + # 确保image_inputs是列表 + if not isinstance(image_inputs, list): + image_inputs = [image_inputs] + + # 处理每张图片 + for image_input in image_inputs: + if image_input.startswith(('http://', 'https://')): + image_data = { + "type": "image_url", + "image_url": {"url": image_input} + } + else: + # 处理本地图片 + base64_image = encode_image_to_base64(image_input) + if base64_image: + image_data = { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + } + else: + raise Exception("无法读取本地图片") + content.append(image_data) + + messages.append({ + "role": "user", + "content": content + }) + + response = client.chat.completions.create( + model="Qwen/QVQ-72B-Preview", + messages=messages, + temperature=0.1, # 降低温度使输出更加确定性 + top_p=0.2, # 降低采样范围,使输出更加保守 + max_tokens=200, # 控制回答长度 + presence_penalty=0.0, # 不鼓励模型谈论新主题 + frequency_penalty=0.0, # 不惩罚频繁词汇 + stream=False + ) + return response.choices[0].message.content + except Exception as e: + return f"发生错误:{str(e)}" + +# 设置日志配置 +def setup_logging(): + # 创建日志文件名(包含时间戳) + log_filename = f'process_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt' + + # 配置日志 + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(log_filename, encoding='utf-8'), + logging.StreamHandler() # 同时输出到控制台 + ] + ) + return log_filename + +def process_single_item(item): + """处理单个数据项的函数""" + try: + image_url = item['imageurl'] + response = requests.get(image_url) + img = Image.open(BytesIO(response.content)) + + prompt = f"The brand of the product in the picture is {item['brand']}, write a caption including brand information." + gpt_response = chat_with_gpt(prompt, image_url) + + return { + 'brand': item['brand'], + 'image': img, + 'response': gpt_response + } + except Exception as e: + with print_lock: + logging.error(f"处理数据时出错: {str(e)}") + return None + +def process_dataset(ds, batch_size=20000, max_workers=10): + results = [] + valid_count = 0 + total_items = len(ds['train']) + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # 创建进度条 + with tqdm(total=total_items, desc="处理数据") as pbar: + # 提交所有任务 + future_to_item = {executor.submit(process_single_item, item): i + for i, item in enumerate(ds['train'])} + + # 处理完成的任务 + for future in concurrent.futures.as_completed(future_to_item): + result = future.result() + if result: + with results_lock: + results.append(result) + valid_count += 1 + + # 达到batch_size时保存 + if len(results) >= batch_size: + save_results(results, valid_count // batch_size) + results = [] + + pbar.update(1) + + # 保存剩余结果 + if results: + save_results(results, (valid_count // batch_size) + 1) + + logging.info(f"处理完成,共处理 {valid_count} 条有效数据") + +def save_results(results, batch_num): + # 保存为pickle文件 + with open(f'batch_{batch_num}.pkl', 'wb') as f: + pickle.dump(results, f) + +# 运行处理 +if __name__ == "__main__": + # 设置日志 + log_file = setup_logging() + logging.info("开始处理数据集") + logging.info(f"数据集加载完成,共包含 {len(ds)} 条数据") + + process_dataset(ds)