Fortrain/qw/test.py
2025-03-31 15:56:36 +08:00

103 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
# Expose only the GPU with index 3 to this process. This is set before torch
# is imported below so that it is in place before any CUDA initialization.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# Disable parallelism in the Hugging Face `tokenizers` library
# (presumably to avoid its fork-related warnings -- TODO confirm the motivation).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from PIL import Image
import requests
import torch
from torchvision import io
from typing import Dict
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
import pickle
import re
from tqdm import tqdm
from peft import PeftModel
# Load the base Qwen2.5-VL model in bfloat16, let `device_map="auto"` place it
# on the visible device(s), and use the FlashAttention-2 attention kernels.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen2.5-VL-3B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)

# Optional: wrap the base model with LoRA adapter weights from a training
# checkpoint (passing load_in_8bit=False and device_map). Currently disabled,
# so the plain base model is what gets evaluated.
# model = PeftModel.from_pretrained(
#     model,
#     "hybrid_train_output/checkpoint-100",
#     load_in_8bit=False,
#     device_map="auto",
#     is_trainable=False,
#     assign=True,
# )

# Make sure the model is in evaluation mode.
model.eval()

processor = AutoProcessor.from_pretrained("Qwen2.5-VL-3B")
# This script generates in batches with padding=True. A decoder-only model must
# be left-padded for that: with right padding the continuation would start
# after pad tokens, and the FlashAttention-2 code path may refuse the batch.
processor.tokenizer.padding_side = "left"
# Single-image example (disabled):
# image = Image.open("Tesla.jpg")

# Prompt text, written out by hand in the <|im_start|>/<|im_end|> chat markup.
# NOTE: the opening labels are spelled "<answer/>" and "<thinking/>" (with a
# slash) in the instruction, and the assistant header has no trailing newline.
text_prompt = "".join((
    # System turn.
    "<|im_start|>system\n",
    "You are a helpful assistant.<|im_end|>\n",
    # User turn: image placeholder followed by the instruction.
    "<|im_start|>user\n",
    "<|vision_start|><|image_pad|><|vision_end|>",
    "Please tell me the brand of the product in the picture between labels <answer/> and </answer> ",
    "and explain the reason between labels <thinking/> and </thinking>",
    "<|im_end|>\n",
    # Assistant header: generation continues from here.
    "<|im_start|>assistant",
))
# Number of samples sent through the model per step.
batch_size = 20

# Running tallies: correctly predicted samples and samples scored so far.
correct = total = 0

# Load the pickled test set; it is later sliced and its items are read via the
# 'image' and 'brand' keys.
# NOTE(review): pickle.load can execute arbitrary code -- only point this at a
# trusted file.
with open("../work/bal_data/test_data.pkl", "rb") as f:
    test_data = pickle.load(f)
# Pattern for the predicted brand. The prompt used by this script literally
# asks for "<answer/>" as the opening label, so both "<answer>" and
# "<answer/>" are accepted; DOTALL lets the answer span line breaks.
# Compiled once, outside the loop.
_ANSWER_RE = re.compile(r'<answer\s*/?>(.*?)</answer>', re.DOTALL)

# Walk the test set in slices of `batch_size` samples.
for start in tqdm(range(0, len(test_data), batch_size)):
    # Gather the images and ground-truth brands of the current batch.
    batch = test_data[start:start + batch_size]
    batch_images = [item['image'] for item in batch]
    batch_brands = [item['brand'] for item in batch]
    # Every sample uses the same prompt.
    batch_prompts = [text_prompt] * len(batch_images)

    # Tokenize the prompts and preprocess the images into one padded batch.
    inputs = processor(
        text=batch_prompts,
        images=batch_images,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of hard-coding "cuda".
    inputs = inputs.to(model.device)

    # Generate, then drop the prompt tokens so only the continuation is decoded.
    output_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids = [
        full_seq[len(prompt_seq):]
        for prompt_seq, full_seq in zip(inputs.input_ids, output_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    # Score each sample: case-insensitive exact match on the extracted brand.
    for pred_text, true_brand in zip(output_texts, batch_brands):
        match = _ANSWER_RE.search(pred_text)
        if match:
            pred_brand = match.group(1).strip().lower()
            if pred_brand == true_brand.lower():
                correct += 1
        # A response without a parsable answer still counts as a (wrong) sample.
        total += 1
# Compute the overall accuracy (0 when no sample was scored) and print it as a
# percentage together with the raw correct/total counts.
if total > 0:
    accuracy = correct / total
else:
    accuracy = 0
print(f"准确率: {accuracy:.2%} ({correct}/{total})")