"""Evaluate brand-recognition accuracy of a Qwen2.5-VL model on a pickled test set.

The script loads the base model (optionally with LoRA weights), prompts it with
each test image, extracts the brand the model wrote between the answer tags and
reports exact-match (case-insensitive) accuracy against the reference brand.
"""
import os

# Environment variables are set before torch / tokenizers are imported so that
# they are already in place when those libraries initialise.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from PIL import Image
import requests
import torch
from torchvision import io
from typing import Dict
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
import pickle
import re
from tqdm import tqdm
from peft import PeftModel

# Load the model in half-precision on the available device(s)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen2.5-VL-3B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2"
)

# Optionally load LoRA weights (note the load_in_8bit=False and device_map arguments)
# model = PeftModel.from_pretrained(
#     model,
#     "hybrid_train_output/checkpoint-100",
#     load_in_8bit=False,
#     device_map="auto",
#     is_trainable=False,
#     assign = True
# )

# Make sure the model is in evaluation mode
model.eval()

processor = AutoProcessor.from_pretrained("Qwen2.5-VL-3B")
# Batched generation with a decoder-only model needs LEFT padding; with right
# padding the new tokens would be generated after pad tokens.
processor.tokenizer.padding_side = "left"

# Image
# image = Image.open("Tesla.jpg")

# Tag names used both in the prompt and when parsing the model output.
# NOTE(review): the original tag names were missing from the prompt and the
# regex (the regex was r'(.*?)', which always matches the empty string).
# "answer" / "reason" are an assumption -- TODO confirm they match the tags
# the model was trained / prompted with.
ANSWER_TAG = "answer"
REASON_TAG = "reason"

# Prompt text, written directly in the Qwen chat markup.
# NOTE(review): the prompt ends with "<|im_start|>assistant" without a trailing
# newline; kept as in the original -- confirm this matches the training prompt.
text_prompt = (
    "<|im_start|>system\n"
    "You are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    "<|vision_start|><|image_pad|><|vision_end|>"
    f"Please tell me the brand of the product in the picture between labels <{ANSWER_TAG}> and </{ANSWER_TAG}> "
    f"and explain the reason between labels <{REASON_TAG}> and </{REASON_TAG}> "
    "<|im_end|>\n"
    "<|im_start|>assistant"
)

# Compiled once; DOTALL lets the answer span line breaks.
ANSWER_PATTERN = re.compile(
    rf"<{re.escape(ANSWER_TAG)}>(.*?)</{re.escape(ANSWER_TAG)}>",
    re.DOTALL,
)


def extract_brand(generated_text):
    """Return the lower-cased, stripped text inside the answer tags, or None.

    Args:
        generated_text: Decoded model output for one sample.

    Returns:
        The normalised brand string, or ``None`` when no answer tag pair is
        present in ``generated_text``.
    """
    match = ANSWER_PATTERN.search(generated_text)
    if match is None:
        return None
    return match.group(1).strip().lower()


# Load the test data.
# SECURITY: pickle.load executes arbitrary code from the file; only use it on
# trusted, locally produced data.
with open("../work/bal_data/test_data.pkl", "rb") as f:
    test_data = pickle.load(f)

# Batch size
batch_size = 20
correct = 0
total = 0

# Iterate over the test data one batch at a time
for i in tqdm(range(0, len(test_data), batch_size)):
    # Prepare the data of the current batch
    batch = test_data[i:i + batch_size]
    batch_images = [item['image'] for item in batch]
    batch_brands = [item['brand'] for item in batch]
    batch_prompts = [text_prompt] * len(batch_images)

    # Preprocess text and images together
    inputs = processor(
        text=batch_prompts,
        images=batch_images,
        padding=True,
        return_tensors="pt"
    )
    # Follow the model's device instead of hard-coding "cuda"
    inputs = inputs.to(model.device)

    # Generate the output
    output_ids = model.generate(**inputs, max_new_tokens=128)
    # All prompts are padded to the same length, so the generated part of every
    # row starts at the same column.
    prompt_length = inputs.input_ids.shape[1]
    generated_ids = output_ids[:, prompt_length:]
    output_texts = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )

    # Extract the predicted brand names and compare them with the references
    for pred_text, true_brand in zip(output_texts, batch_brands):
        pred_brand = extract_brand(pred_text)
        if pred_brand is not None and pred_brand == true_brand.lower():
            correct += 1
        # Every sample counts, including outputs without a parsable answer.
        total += 1

# Compute and print the accuracy
accuracy = correct / total if total > 0 else 0
print(f"准确率: {accuracy:.2%} ({correct}/{total})")