# pip install accelerate
"""Evaluate a Gemma 3 vision-language model on product-brand recognition.

Loads a model and processor, reads a pickled list of test samples (each a dict
with an ``'image'`` and a ``'brand'``), asks the model to name the brand inside
answer tags, and logs the fraction of samples whose tagged answer contains the
ground-truth brand (case-insensitive substring match).
"""
import base64
import logging
import pickle
import re
from io import BytesIO

import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()
    ],
)

model_id = "gemma-3-12b"  # alternative: "outputs/checkpoint-400"

# Load the model
logging.info("加载模型中...")
# Load the weights in bfloat16 so they match the dtype the inputs are cast to
# in the evaluation loop; without an explicit dtype the weights may load as
# float32 and clash with bfloat16 pixel values.
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor = AutoProcessor.from_pretrained(model_id)
logging.info("模型加载完成")

# NOTE(review): the tag names in the original prompt and regex appear to have
# been stripped. The prompt read "...between labels and and explain...", and the
# regex was the bare r'(.*?)', which always matches the empty string, so every
# prediction was scored as wrong. <answer>/<reason> are a reconstruction.
# Confirm they match the output format any fine-tuned checkpoint was trained on.
ANSWER_TAG = "answer"
REASON_TAG = "reason"
PROMPT_TEXT = (
    "Please tell me the brand of the product in the picture between labels "
    f"<{ANSWER_TAG}> and </{ANSWER_TAG}> and explain the reason between labels "
    f"<{REASON_TAG}> and </{REASON_TAG}>"
)
# Non-greedy capture of the first answer span. DOTALL lets an answer that
# spans several lines still match.
_ANSWER_RE = re.compile(rf"<{ANSWER_TAG}>(.*?)</{ANSWER_TAG}>", re.DOTALL)


def image_to_base64(image):
    """Encode an image as a JPEG ``data:`` URI.

    Args:
        image: A PIL ``Image``. Modes that JPEG cannot store (e.g. RGBA, P, LA)
            are converted to RGB first, dropping any alpha channel, instead of
            letting ``save`` fail.

    Returns:
        A string of the form ``"data:image/jpeg;base64,<payload>"``.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"


class TestDataset(Dataset):
    """Map-style dataset over a pickled list of test samples.

    Each sample is read as a dict with an ``'image'`` (presumably a PIL image;
    it is JPEG-encoded via ``image_to_base64``) and a ``'brand'`` string.
    ``__getitem__`` turns it into a chat conversation plus the expected brand.
    """

    def __init__(self, data_path):
        # SECURITY: pickle.load can execute arbitrary code. Only load data
        # files from a trusted source.
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"messages": <chat conversation>, "correct_brand": <str>}``."""
        item = self.data[idx]
        image_base64 = image_to_base64(item['image'])
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_base64},
                    {"type": "text", "text": PROMPT_TEXT},
                ],
            },
        ]
        return {
            "messages": messages,
            "correct_brand": item['brand'],
        }


# Load the test data
logging.info("加载测试数据...")
test_dataset = TestDataset('../work/bal_data/test_data.pkl')
total_samples = len(test_dataset)
logging.info(f"加载了 {total_samples} 条测试数据")


def evaluate_prediction(prediction, correct_brand):
    """Check whether the model's tagged answer names the correct brand.

    Args:
        prediction: Decoded model output text.
        correct_brand: Ground-truth brand string.

    Returns:
        True if the first span wrapped in the ANSWER_TAG tags, after stripping
        and lower-casing, contains ``correct_brand.lower()``. False if no such
        span exists or it does not contain the brand.
    """
    answer_match = _ANSWER_RE.search(prediction)
    if answer_match is None:
        return False
    predicted_brand = answer_match.group(1).strip().lower()
    return correct_brand.lower() in predicted_brand


# Run the evaluation
batch_size = 25
correct_count = 0
processed_count = 0
# Total number of batches (informational; the loop below steps by batch_size)
total_batches = (total_samples + batch_size - 1) // batch_size

logging.info("开始测试...")
# Using tqdm as a context manager closes the bar even if generation raises.
with tqdm(
    total=total_samples,
    desc="测试进度",
    ncols=100,  # total width of the progress bar
    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
) as progress_bar, torch.inference_mode():
    for i in range(0, total_samples, batch_size):
        # Gather the samples for the current batch
        end_idx = min(i + batch_size, total_samples)
        current_batch_size = end_idx - i
        samples = [test_dataset[idx] for idx in range(i, end_idx)]
        batch_messages = [sample['messages'] for sample in samples]
        batch_correct_brands = [sample['correct_brand'] for sample in samples]

        # Tokenize the batch; the dtype argument is intended to cast the
        # floating-point inputs (e.g. pixel values) to bfloat16.
        inputs = processor.apply_chat_template(
            batch_messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device, dtype=torch.bfloat16)

        # NOTE(review): slicing at input_len assumes every prompt in the batch
        # has the same length or is left-padded. Confirm padding behavior if
        # prompts can differ in length.
        input_len = inputs["input_ids"].shape[-1]
        generation = model.generate(**inputs, max_new_tokens=300, do_sample=False)
        # Keep only the newly generated tokens (drop the echoed prompt).
        generation = generation[:, input_len:]
        predictions = processor.batch_decode(generation, skip_special_tokens=True)

        # Score the batch (each True counts as 1)
        correct_count += sum(
            evaluate_prediction(pred, correct)
            for pred, correct in zip(predictions, batch_correct_brands)
        )
        processed_count += current_batch_size

        # Update the progress bar
        progress_bar.update(current_batch_size)
        progress_bar.set_postfix({
            'acc': f'{correct_count/processed_count:.4f}',
            'correct': f'{correct_count}/{processed_count}',
        })

# Report final results (an empty test set yields 0.0 instead of dividing by zero)
final_accuracy = correct_count / total_samples if total_samples else 0.0
logging.info("\n测试完成!")
logging.info(f"总样本数: {total_samples}")
logging.info(f"正确预测数: {correct_count}")
logging.info(f"最终准确率: {final_accuracy:.4f}")

# NOTE(review): the comment lines below look like leftover sample model output
# (an image description) that is unrelated to this brand evaluation.
# Consider removing them.
# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.