import os

# Environment knobs, applied before torch / transformers are imported below:
# expose only physical GPU 3 to this process, and turn off the HF `tokenizers`
# library's internal parallelism.
os.environ.update(
    CUDA_VISIBLE_DEVICES="3",
    TOKENIZERS_PARALLELISM="false",
)
from PIL import Image
import requests
import torch
from torchvision import io
from typing import Dict
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
import pickle
import re
from tqdm import tqdm
from peft import PeftModel
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_PATH = "Qwen2.5-VL-3B"
# Optional LoRA adapter applied on top of the base model, e.g.
# "hybrid_train_output/checkpoint-100". None evaluates the base model only.
LORA_CHECKPOINT = None
TEST_DATA_PATH = "../work/bal_data/test_data.pkl"
BATCH_SIZE = 20
MAX_NEW_TOKENS = 128

# Raw Qwen chat-format prompt. The <|vision_start|><|image_pad|><|vision_end|>
# span is the image placeholder the processor fills in for each image. The model
# is asked to wrap its answer in <brand>...</brand> so extract_brand() can parse it.
TEXT_PROMPT = (
    "<|im_start|>system\n"
    "You are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    "<|vision_start|><|image_pad|><|vision_end|>"
    "Please tell me the brand of the product in the picture between labels "
    "<brand> and </brand> and explain the reason between labels "
    "<reason> and </reason>"
    "<|im_end|>\n"
    # Trailing newline matches the generation prompt of Qwen's chat template.
    "<|im_start|>assistant\n"
)

# Non-greedy match of the first <brand>...</brand> span; DOTALL lets the answer
# span line breaks, IGNORECASE tolerates e.g. <Brand>.
_BRAND_PATTERN = re.compile(r"<brand>(.*?)</brand>", re.DOTALL | re.IGNORECASE)


def normalize_brand(brand) -> str:
    """Return *brand* in the canonical comparison form (trimmed, lower-case)."""
    return str(brand).strip().lower()


def extract_brand(text: str) -> "str | None":
    """Return the normalized brand inside the first <brand>...</brand> tag of *text*.

    Returns None when the model output contains no such tag.
    """
    match = _BRAND_PATTERN.search(text)
    if match is None:
        return None
    return normalize_brand(match.group(1))


def count_correct(pred_texts, true_brands) -> int:
    """Count model outputs whose extracted brand equals the ground-truth brand.

    Outputs without a parsable <brand> tag count as wrong.
    """
    return sum(
        extract_brand(pred) == normalize_brand(truth)
        for pred, truth in zip(pred_texts, true_brands)
    )


def load_model_and_processor(model_path=MODEL_PATH, lora_checkpoint=LORA_CHECKPOINT):
    """Load Qwen2.5-VL in bfloat16 (optionally with a LoRA adapter) and its processor.

    Returns:
        (model, processor): the model in eval mode, and a processor whose
        tokenizer left-pads batches for generation.
    """
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    if lora_checkpoint is not None:
        model = PeftModel.from_pretrained(model, lora_checkpoint, is_trainable=False)
    # Make sure the model is in evaluation mode.
    model.eval()

    processor = AutoProcessor.from_pretrained(model_path)
    # Batched decoder-only generation must be left-padded; with right padding,
    # shorter prompts would continue generating after their pad tokens. Prompts
    # can differ in length here because images of different sizes expand to
    # different numbers of visual tokens.
    processor.tokenizer.padding_side = "left"
    return model, processor


def evaluate(model, processor, test_data, batch_size=BATCH_SIZE,
             max_new_tokens=MAX_NEW_TOKENS):
    """Run batched brand prediction over *test_data* and return (correct, total).

    Args:
        test_data: sliceable sequence of dicts with an 'image' entry (passed
            straight to the processor) and a 'brand' ground-truth entry.
        batch_size: number of samples per generate() call; must be positive.
        max_new_tokens: generation budget per sample.

    Raises:
        ValueError: if batch_size is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    correct = 0
    total = 0
    for start in tqdm(range(0, len(test_data), batch_size)):
        batch = test_data[start:start + batch_size]
        images = [item["image"] for item in batch]
        brands = [item["brand"] for item in batch]

        inputs = processor(
            text=[TEXT_PROMPT] * len(images),
            images=images,
            padding=True,
            return_tensors="pt",
        ).to(model.device)

        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns the (padded) prompt followed by the new tokens, so
        # dropping the prompt width leaves only the generated continuation.
        new_token_ids = output_ids[:, inputs.input_ids.shape[1]:]
        output_texts = processor.batch_decode(
            new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        correct += count_correct(output_texts, brands)
        total += len(batch)
    return correct, total


def main():
    """Evaluate brand-recognition accuracy on the pickled test set and print it."""
    model, processor = load_model_and_processor()
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(TEST_DATA_PATH, "rb") as f:
        test_data = pickle.load(f)
    correct, total = evaluate(model, processor, test_data)
    accuracy = correct / total if total > 0 else 0
    print(f"准确率: {accuracy:.2%} ({correct}/{total})")


if __name__ == "__main__":
    main()