# File: Fortrain/gemma3/test.py
# Size: 160 lines, 4.9 KiB (Python)
# Last modified: 2025-03-31 15:56:36 +08:00
# pip install accelerate
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch
import base64
from io import BytesIO
import pickle
from torch.utils.data import Dataset, DataLoader
import logging
import re
from tqdm import tqdm
# Logging configuration: INFO and above, timestamped, emitted through a single
# StreamHandler (which writes to stderr by default).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
    ],
)

# Model to evaluate. To evaluate a fine-tuned checkpoint instead, point this at
# its directory (e.g. "outputs/checkpoint-400").
model_id = "gemma-3-12b"  # "outputs/checkpoint-400"

# Load the model and its processor.
logging.info("加载模型中...")
# Load the weights directly in bfloat16. The evaluation loop casts its
# floating-point inputs to bfloat16, so the weights should match. Without an
# explicit dtype, from_pretrained may materialise the weights in float32
# (depending on the transformers version), which doubles memory for a 12B model.
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor = AutoProcessor.from_pretrained(model_id)
logging.info("模型加载完成")
# Image modes Pillow's JPEG encoder can write directly. Any other mode
# (e.g. RGBA, LA, P) makes Image.save(..., format="JPEG") raise OSError.
_JPEG_SAVABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def image_to_base64(image):
    """Encode a PIL image as a base64 JPEG ``data:`` URI.

    Images whose mode JPEG cannot store (alpha or palette modes such as
    RGBA/LA/P) are converted to RGB first instead of failing on save.

    Args:
        image: A PIL ``Image`` (anything exposing ``mode``, ``convert`` and
            ``save``).

    Returns:
        A string of the form ``"data:image/jpeg;base64,<payload>"``.
    """
    if image.mode not in _JPEG_SAVABLE_MODES:
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    # Base64 output is pure ASCII.
    img_str = base64.b64encode(buffered.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{img_str}"
class TestDataset(Dataset):
    """Map-style dataset that turns pickled samples into Gemma chat prompts.

    Each stored record must provide an ``'image'`` entry (passed to
    ``image_to_base64``) and a ``'brand'`` entry (the gold label).
    """

    # Fixed system and user instructions shared by every sample.
    _SYSTEM_TEXT = "You are a helpful assistant."
    _QUESTION_TEXT = (
        "Please tell me the brand of the product in the picture between labels "
        "<answer/> and </answer> and explain the reason between labels "
        "<thinking/> and </thinking>"
    )

    def __init__(self, data_path):
        # NOTE: pickle can execute arbitrary code on load; only use trusted files.
        with open(data_path, 'rb') as fh:
            self.data = pickle.load(fh)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"messages": [system, user], "correct_brand": brand}`` for *idx*."""
        record = self.data[idx]
        system_turn = {
            "role": "system",
            "content": [{"type": "text", "text": self._SYSTEM_TEXT}],
        }
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": image_to_base64(record['image'])},
                {"type": "text", "text": self._QUESTION_TEXT},
            ],
        }
        return {
            "messages": [system_turn, user_turn],
            "correct_brand": record['brand'],
        }
# Load the pickled test split. The relative path is resolved against the
# current working directory, not this file's location.
TEST_DATA_PATH = '../work/bal_data/test_data.pkl'
logging.info("加载测试数据...")
test_dataset = TestDataset(TEST_DATA_PATH)
total_samples = len(test_dataset)
logging.info(f"加载了 {total_samples} 条测试数据")
# Extracts the answer span. Accepts both the "<answer/>" opening tag the prompt
# asks for and the conventional "<answer>", and lets the span cross newlines.
_ANSWER_PATTERN = re.compile(r'<answer/?>(.*?)</answer>', re.DOTALL)


def evaluate_prediction(prediction, correct_brand):
    """Return True if the model's tagged answer mentions the gold brand.

    Args:
        prediction: Decoded model output. The answer is expected between
            ``<answer/>`` (or ``<answer>``) and ``</answer>``.
        correct_brand: Gold brand label.

    Returns:
        True when the first tagged answer contains ``correct_brand`` as a
        case-insensitive substring (surrounding whitespace ignored). Returns
        False when there is no tagged answer or the gold label is blank.
    """
    gold = correct_brand.strip().lower()
    if not gold:
        # An empty label is a substring of everything; never count it as a hit.
        return False
    answer_match = _ANSWER_PATTERN.search(prediction)
    if answer_match is None:
        return False
    return gold in answer_match.group(1).strip().lower()
# Run batched greedy inference over the whole test set and score it.
batch_size = 25
correct_count = 0
processed_count = 0
# Number of batches (ceiling division); informational only.
total_batches = (total_samples + batch_size - 1) // batch_size
logging.info("开始测试...")

# tqdm is used as a context manager so the bar is closed even if generation
# raises part-way through.
with torch.inference_mode(), tqdm(
    total=total_samples,
    desc="测试进度",
    ncols=100,  # total width of the progress bar
    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
) as progress_bar:
    for i in range(0, total_samples, batch_size):
        end_idx = min(i + batch_size, total_samples)
        current_batch_size = end_idx - i

        # Collect the chat prompts and gold labels for this slice.
        batch_messages = []
        batch_correct_brands = []
        for idx in range(i, end_idx):
            sample = test_dataset[idx]
            batch_messages.append(sample['messages'])
            batch_correct_brands.append(sample['correct_brand'])

        # NOTE(review): every prompt shares the same text and one image, so
        # (assuming a fixed per-image token count) the rows tokenize to equal
        # length and no padding is requested here. Confirm if prompts change.
        inputs = processor.apply_chat_template(
            batch_messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device, dtype=torch.bfloat16)
        input_len = inputs["input_ids"].shape[-1]
        generation = model.generate(**inputs, max_new_tokens=300, do_sample=False)
        # Keep only the newly generated tokens (drop the prompt prefix).
        generation = generation[:, input_len:]
        predictions = processor.batch_decode(generation, skip_special_tokens=True)

        # Score this batch.
        correct_count += sum(
            evaluate_prediction(pred, correct)
            for pred, correct in zip(predictions, batch_correct_brands)
        )
        processed_count += current_batch_size

        # Update the progress bar with the running accuracy.
        progress_bar.update(current_batch_size)
        progress_bar.set_postfix({
            'acc': f'{correct_count/processed_count:.4f}',
            'correct': f'{correct_count}/{processed_count}',
        })

# Final report. Guard against an empty test set to avoid ZeroDivisionError.
final_accuracy = correct_count / total_samples if total_samples else 0.0
logging.info("\n测试完成!")
logging.info(f"总样本数: {total_samples}")
logging.info(f"正确预测数: {correct_count}")
logging.info(f"最终准确率: {final_accuracy:.4f}")
# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.