160 lines
4.9 KiB
Python
160 lines
4.9 KiB
Python
# Requires `accelerate` (pip install accelerate) for device_map="auto" model loading.
|
|
|
|
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
|
from PIL import Image
|
|
import requests
|
|
import torch
|
|
import base64
|
|
from io import BytesIO
|
|
import pickle
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import logging
|
|
import re
|
|
from tqdm import tqdm
|
|
|
|
# Configure root logging: INFO level and above, timestamped records, emitted
# through a single StreamHandler (which writes to stderr by default).
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
|
# Model to evaluate. To evaluate a fine-tuned run instead, point this at its
# checkpoint directory (e.g. "outputs/checkpoint-400").
model_id = "gemma-3-12b"

# Load the model and its processor.
logging.info("加载模型中...")
# Load the weights in bfloat16. The evaluation loop casts its batched inputs to
# bfloat16, and without an explicit dtype transformers (4.x) `from_pretrained`
# materializes weights in float32, roughly doubling memory for a 12B model.
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
processor = AutoProcessor.from_pretrained(model_id)
logging.info("模型加载完成")
|
|
|
|
# Image modes Pillow's JPEG writer can encode directly. Anything else
# (RGBA, LA, P, I;16, ...) makes `save(format="JPEG")` raise OSError.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def image_to_base64(image):
    """Encode a PIL image as a base64 JPEG ``data:`` URI.

    Args:
        image: A PIL ``Image`` (anything exposing ``mode``, ``convert(mode)``
            and ``save(fp, format=...)``).

    Returns:
        A string of the form ``"data:image/jpeg;base64,<payload>"``.
    """
    # JPEG cannot store alpha or palette data. Convert such images to RGB so
    # they encode instead of crashing; writable modes are saved unchanged.
    if image.mode not in _JPEG_WRITABLE_MODES:
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{img_str}"
|
|
|
|
class TestDataset(Dataset):
    """Evaluation samples loaded eagerly from a pickle file.

    The unpickled object is kept whole in ``self.data``. It must support
    ``len()`` and integer indexing. Each record must provide an ``'image'``
    entry, which is handed to ``image_to_base64`` (presumably a PIL image;
    verify against the data producer), and a ``'brand'`` entry holding the
    ground-truth label.
    """

    def __init__(self, data_path):
        # NOTE: unpickling can execute arbitrary code; only load data files
        # from a trusted source.
        with open(data_path, 'rb') as fh:
            records = pickle.load(fh)
        self.data = records

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Build the chat prompt for sample *idx* plus its expected brand."""
        record = self.data[idx]
        system_turn = {
            "role": "system",
            "content": [{"type": "text", "text": "You are a helpful assistant."}],
        }
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": image_to_base64(record['image'])},
                {"type": "text", "text": "Please tell me the brand of the product in the picture between labels <answer/> and </answer> and explain the reason between labels <thinking/> and </thinking>"},
            ],
        }
        return {"messages": [system_turn, user_turn], "correct_brand": record['brand']}
|
|
|
|
# Location of the pickled evaluation split read by TestDataset.
TEST_DATA_PATH = '../work/bal_data/test_data.pkl'

# Load the evaluation data and record how many samples it contains.
logging.info("加载测试数据...")
test_dataset = TestDataset(TEST_DATA_PATH)
total_samples = len(test_dataset)
logging.info(f"加载了 {total_samples} 条测试数据")
|
|
|
|
# Captures the text between the literal tags the prompt asks for:
# "<answer/>" ... "</answer>". DOTALL lets the answer span several lines,
# e.g. when the model puts a newline right after the opening tag.
_ANSWER_PATTERN = re.compile(r'<answer/>(.*?)</answer>', re.DOTALL)


def evaluate_prediction(prediction, correct_brand):
    """Return whether the model's tagged answer names the expected brand.

    The first ``<answer/> ... </answer>`` span in *prediction* is extracted.
    The prediction counts as correct when the ground-truth brand
    (case-insensitive, surrounding whitespace ignored) occurs as a substring
    of that span.

    Args:
        prediction: Decoded model output text.
        correct_brand: Ground-truth brand name (``None`` is treated as empty).

    Returns:
        ``True`` on a match. ``False`` when no answer span is found, when the
        brand does not occur in it, or when *correct_brand* is empty or blank.
    """
    expected = (correct_brand or "").strip().lower()
    # An empty string is a substring of every string. Without this guard a
    # sample with a blank label would always be scored as correct.
    if not expected:
        return False
    answer_match = _ANSWER_PATTERN.search(prediction)
    if answer_match is None:
        return False
    predicted_brand = answer_match.group(1).strip().lower()
    return expected in predicted_brand
|
|
|
|
# Run the evaluation over the whole test set in fixed-size batches.
batch_size = 25
correct_count = 0
processed_count = 0

# Total number of batches (ceiling division).
total_batches = (total_samples + batch_size - 1) // batch_size

logging.info("开始测试...")
# tqdm is used as a context manager so the bar is closed even if generation
# raises part-way through.
with tqdm(
    total=total_samples,
    desc="测试进度",
    ncols=100,  # total width of the progress bar
    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
) as progress_bar, torch.inference_mode():
    for i in range(0, total_samples, batch_size):
        end_idx = min(i + batch_size, total_samples)
        current_batch_size = end_idx - i

        # Gather this batch's chat prompts and ground-truth labels.
        batch_samples = [test_dataset[idx] for idx in range(i, end_idx)]
        batch_messages = [sample['messages'] for sample in batch_samples]
        batch_correct_brands = [sample['correct_brand'] for sample in batch_samples]

        # Tokenize the whole batch. padding=True lets prompts whose token
        # lengths differ be stacked into one tensor; without padding, ragged
        # rows cannot form a tensor.
        # NOTE(review): batched generation needs left padding; confirm that
        # processor.tokenizer.padding_side == "left" for this checkpoint.
        inputs = processor.apply_chat_template(
            batch_messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
        ).to(model.device, dtype=torch.bfloat16)

        # All rows share the same (padded) prompt length, so slicing at
        # input_len keeps only the newly generated tokens.
        input_len = inputs["input_ids"].shape[-1]
        generation = model.generate(**inputs, max_new_tokens=300, do_sample=False)
        generation = generation[:, input_len:]
        predictions = processor.batch_decode(generation, skip_special_tokens=True)

        # Score the batch.
        batch_correct = sum(
            1
            for pred, correct in zip(predictions, batch_correct_brands)
            if evaluate_prediction(pred, correct)
        )
        correct_count += batch_correct
        processed_count += current_batch_size

        # Advance the progress bar and show the running accuracy.
        progress_bar.update(current_batch_size)
        progress_bar.set_postfix({
            'acc': f'{correct_count/processed_count:.4f}',
            'correct': f'{correct_count}/{processed_count}'
        })

# Report the final result. An empty test set yields 0.0 instead of a
# ZeroDivisionError.
final_accuracy = correct_count / total_samples if total_samples else 0.0
logging.info("\n测试完成!")
logging.info(f"总样本数: {total_samples}")
logging.info(f"正确预测数: {correct_count}")
logging.info(f"最终准确率: {final_accuracy:.4f}")
|
|
|
|
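# NOTE(review): the lines below look like leftover sample output from an
# image-description example; they are unrelated to this brand-evaluation script.
# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.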
|