import os
# Environment setup (uncomment to pin specific GPUs)
# os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

import logging
import pickle
import random
from datetime import datetime

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration
from trl import GRPOTrainer, GRPOConfig
from peft import LoraConfig
from accelerate import Accelerator
from accelerate.utils import set_seed

from predata import load_image
from Internvl2_5.conversation import get_conv_template
from Internvl2_5.modeling_internvl_chat import InternVLChatModel

torch.set_float32_matmul_precision('high')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Logging configuration
def setup_logging():
    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filename = f'training_log_{current_time}.txt'
    logging.basicConfig(
        filename=log_filename,
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )


class PromptWithImage:
    """Lightweight wrapper that carries image tensors alongside the prompt text."""

    def __init__(self, text, pixel_values):
        self.text = text
        self.pixel_values = pixel_values

    def __str__(self):
        return self.text

    def __repr__(self):
        return self.text


class ChatDataset(Dataset):
    def __init__(self, data, tokenizer, batch_size, all_brands):
        self.data = data
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.all_brands = all_brands

    def __len__(self):
        return len(self.data)

    def create_prompt(self, pixel_values):
        prompt_text = "\nPlease identify the most appropriate brand from the following options based on the image content:\n"
        for brand in self.all_brands:
            prompt_text += f"{brand}\n"
        prompt_text += "Please only respond with the brand name, no additional explanation needed."
        return PromptWithImage(prompt_text, pixel_values)

    def __getitem__(self, idx):
        item = self.data[idx]
        pixel_values = load_image(item['image'], max_num=12).to(dtype=torch.bfloat16)
        return {
            "prompt": self.create_prompt(pixel_values),
            "correct_brand": item['brand'],
        }


# Reward function: 1.0 if the ground-truth brand appears in the completion, else 0.0
def reward_func(prompts, completions, **kwargs):
    rewards = []
    # Debug output
    print("prompts length:", len(prompts))
    print("completions length:", len(completions))
    print("kwargs keys:", kwargs.keys())
    correct_brands = kwargs.get('correct_brand')
    print("correct_brands length:", len(correct_brands))
    for completion, correct_brand in zip(completions, correct_brands):
        # Simple check: does the brand name appear in the answer (case-insensitive)?
        correct = correct_brand.lower() in completion.lower()
        # Debug output
        print("completion:", completion)
        print("correct_brand:", correct_brand)
        print("is_correct:", correct)
        rewards.append(float(correct))
    print("rewards length:", len(rewards))
    return rewards


# Model configuration
def get_model_config():
    return {
        'torch_dtype': torch.bfloat16,
        'low_cpu_mem_usage': True,
        'use_flash_attn': True,
        'trust_remote_code': True,
        'vision_model': None,
        'language_model': None
    }


# LoRA configuration
def get_lora_config():
    # Note: LoraConfig expects target_modules to be a list of module names (or a regex
    # string); per-module r/alpha overrides cannot be nested inside it.
    return LoraConfig(
        task_type="CAUSAL_LM",
        r=4,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=[
            "mlp1.1",
            "mlp1.3",
            # "q_proj",
            # "k_proj",
        ],
        bias="none"
    )


# Training configuration
def get_training_args():
    return GRPOConfig(
        output_dir="chat_grpo_output",
        num_generations=4,
        learning_rate=1e-5,
        logging_steps=100,
        max_prompt_length=None,
        gradient_accumulation_steps=4,
        max_completion_length=50,
        per_device_train_batch_size=2,
        max_steps=1000,
        dataloader_pin_memory=False  # disable pin_memory
    )
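
# A minimal smoke test for reward_func, assuming GRPOTrainer forwards the dataset's
# "correct_brand" column as a keyword argument; the prompt/completion strings below
# are made-up illustrations, not samples from the training data.
def _smoke_test_reward_func():
    rewards = reward_func(
        prompts=["<image prompt>"],
        completions=["I think the brand is Nike."],
        correct_brand=["Nike"],
    )
    assert rewards == [1.0]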
class ImagePromptProcessor:
    """Wraps the tokenizer so GRPOTrainer can pass image-bearing prompts through it."""

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        # Copy all public attributes from the tokenizer
        for attr_name in dir(tokenizer):
            # Skip dunder attributes and methods
            if not attr_name.startswith('__'):
                try:
                    setattr(self, attr_name, getattr(tokenizer, attr_name))
                except AttributeError:
                    pass

    def __getattr__(self, name):
        # Fall back to the tokenizer for attributes not found on this class
        return getattr(self.tokenizer, name)

    def __call__(self, prompts, return_tensors="pt", padding=True, padding_side="left",
                 add_special_tokens=False, **kwargs):
        if isinstance(prompts[0], PromptWithImage):
            pixel_values = torch.cat([p.pixel_values for p in prompts], dim=0)
            texts = [str(p) for p in prompts]
            num_patches_list = [p.pixel_values.shape[0] for p in prompts]
            # Use the model's for_grpo helper (a batch_chat-style preprocessing path)
            # to build prompt_ids and attention_mask
            prompt_ids, prompt_mask = self.model.for_grpo(
                self.tokenizer,
                pixel_values,
                texts,
                num_patches_list=num_patches_list,
                history=None,
                return_history=False,
            )
            return {
                "input_ids": prompt_ids,
                "attention_mask": prompt_mask
            }
        else:
            # Plain text prompts
            return self.tokenizer(
                prompts,
                return_tensors=return_tensors,
                padding=padding,
                padding_side=padding_side,
                add_special_tokens=add_special_tokens
            )


# Subclass the model to add forward/generate fallbacks onto the language model
class CustomInternVLModel(InternVLChatModel):
    def forward(self, *args, **kwargs):
        try:
            # First try the model's own forward
            return super().forward(*args, **kwargs)
        except (TypeError, ValueError, AttributeError) as e:
            # On argument mismatches or other errors, fall back to the language model
            print(f"Falling back to language_model for the forward pass: {e}")
            return self.language_model(*args, **kwargs)

    @torch.no_grad()
    def generate(self, input_ids, attention_mask=None, **kwargs):
        print("Falling back to language_model for generation")
        return self.language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )


def main():
    # Initialize the accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=4,
        mixed_precision="bf16"
    )

    # Set the random seed
    set_seed(42)

    # Set up logging
    setup_logging()

    # Load the data
    logging.info("Loading data...")
    with open('bal_data/frequent_brands_data.pkl', 'rb') as f:
        data = pickle.load(f)
    logging.info("Data loaded successfully")

    # Load the model using the custom class
    path = 'Internvl2_5'
    model = CustomInternVLModel.from_pretrained(path, **get_model_config()).train()
    model.name_or_path = 'CustomInternVLModel'

    # Load pretrained weights
    print("Loading pretrained weights vit_mlp_epoch_15.pth ...")
    model.load_state_dict(torch.load('weights/vit_mlp_epoch_15.pth'))
    print("Pretrained weights loaded successfully")

    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

    # Create the processor
    processor = ImagePromptProcessor(tokenizer, model)

    # Create the dataset
    dataset = ChatDataset(data, tokenizer, 1, sorted(set(item['brand'] for item in data)))

    # Create the trainer (using the stock GRPOTrainer)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=get_training_args(),
        train_dataset=dataset,
        processing_class=processor,
        peft_config=get_lora_config()
    )

    # GRPOTrainer manages its own Accelerator internally; prepare() passes
    # non-model objects through unchanged, so this call is effectively a pass-through.
    trainer = accelerator.prepare(trainer)

    # Start training
    logging.info("Starting training...")
    trainer.train()
    logging.info("Training finished")

    # Save the model
    output_dir = "chat_model_lora_3"
    unwrapped_model = accelerator.unwrap_model(trainer.model)
    unwrapped_model.save_pretrained(output_dir)
    logging.info(f"Model saved to {output_dir}")


if __name__ == "__main__":
    main()
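
# Example launch command (the script filename is an assumption; adjust to your setup):
#   accelerate launch --mixed_precision bf16 train_grpo.py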