diff --git a/vlm_grpo_lora.py b/vlm_grpo_lora.py
new file mode 100644
index 0000000..bfacb5d
--- /dev/null
+++ b/vlm_grpo_lora.py
@@ -0,0 +1,292 @@
+import os
+# Environment setup
+#os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
+import logging
+import pickle
+import torch
+torch.set_float32_matmul_precision('high')
+from datetime import datetime
+from torch.utils.data import Dataset
+from transformers import AutoTokenizer
+from trl import GRPOTrainer, GRPOConfig
+from peft import LoraConfig
+from predata import load_image
+from Internvl2_5.modeling_internvl_chat import InternVLChatModel
+from accelerate import Accelerator
+from accelerate.utils import set_seed
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Logging configuration
+def setup_logging():
+    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
+    log_filename = f'training_log_{current_time}.txt'
+    logging.basicConfig(
+        filename=log_filename,
+        level=logging.INFO,
+        format='%(asctime)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+
+# Wrapper that lets a prompt carry its image tensor while still behaving
+# like a plain string for the tokenizer/trainer.
+class PromptWithImage:
+    def __init__(self, text, pixel_values):
+        self.text = text
+        self.pixel_values = pixel_values
+
+    def __str__(self):
+        return self.text
+
+    def __repr__(self):
+        return self.text
+
+class ChatDataset(Dataset):
+    def __init__(self, data, tokenizer, batch_size, all_brands):
+        self.data = data
+        self.tokenizer = tokenizer
+        self.batch_size = batch_size
+        self.all_brands = all_brands
+
+    def __len__(self):
+        return len(self.data)
+
+    def create_prompt(self, pixel_values):
+        prompt_text = "\nPlease identify the most appropriate brand from the following options based on the image content:\n"
+        for brand in self.all_brands:
+            prompt_text += f"{brand}\n"
+        prompt_text += "Please only respond with the brand name, no additional explanation needed."
+        return PromptWithImage(prompt_text, pixel_values)
+
+    def __getitem__(self, idx):
+        item = self.data[idx]
+        pixel_values = load_image(item['image'], max_num=12).to(dtype=torch.bfloat16)
+
+        return {
+            "prompt": self.create_prompt(pixel_values),
+            "correct_brand": item['brand'],
+        }
+
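+# A quick sanity check of what ChatDataset yields (a minimal sketch; the two
+# brand names are made up for illustration, and `data`/`tokenizer` are the
+# objects built in main() below):
+#
+#     ds = ChatDataset(data, tokenizer, 1, all_brands=["BrandA", "BrandB"])
+#     sample = ds[0]
+#     print(sample["prompt"])         # option list rendered from all_brands
+#     print(sample["correct_brand"])  # ground-truth label consumed by reward_func
+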
+
+# Reward function: binary reward, 1.0 iff the ground-truth brand name
+# appears in the completion.
+def reward_func(prompts, completions, **kwargs):
+    rewards = []
+    # Print debug info
+    print("prompts length:", len(prompts))
+    print("completions length:", len(completions))
+    print("kwargs keys:", kwargs.keys())
+    correct_brands = kwargs.get('correct_brand')
+    print("correct_brands length:", len(correct_brands))
+
+    for completion, correct_brand in zip(completions, correct_brands):
+        # Simple check: does the brand name appear in the answer (case-insensitive)?
+        correct = correct_brand.lower() in completion.lower()
+
+        # Print debug info
+        print("completion:", completion)
+        print("correct_brand:", correct_brand)
+        print("is_correct:", correct)
+
+        rewards.append(float(correct))
+
+    print("rewards length:", len(rewards))
+    return rewards
+
+
+# Model configuration
+def get_model_config():
+    return {
+        'torch_dtype': torch.bfloat16,
+        'low_cpu_mem_usage': True,
+        'use_flash_attn': True,
+        'trust_remote_code': True,
+        'vision_model': None,
+        'language_model': None
+    }
+
+# LoRA configuration. Note: peft's LoraConfig expects target_modules to be a
+# list of module names (or a regex string), not a per-module dict; r,
+# lora_alpha, and lora_dropout are set once at the top level.
+def get_lora_config():
+    return LoraConfig(
+        task_type="CAUSAL_LM",
+        r=4,
+        lora_alpha=16,
+        lora_dropout=0.1,
+        target_modules=[
+            "mlp1.1",
+            "mlp1.3",
+            # "q_proj",
+            # "k_proj",
+        ],
+        bias="none"
+    )
+
+# Training configuration
+def get_training_args():
+    return GRPOConfig(
+        output_dir="chat_grpo_output",
+        num_generations=4,
+        learning_rate=1e-5,
+        logging_steps=100,
+        max_prompt_length=None,
+        gradient_accumulation_steps=4,
+        max_completion_length=50,
+        per_device_train_batch_size=2,
+        max_steps=1000,
+        dataloader_pin_memory=False  # disable pin_memory
+    )
+
+# Processor that routes image-bearing prompts through the model's own
+# preprocessing and plain-text prompts through the tokenizer.
+class ImagePromptProcessor:
+    def __init__(self, tokenizer, model):
+        self.tokenizer = tokenizer
+        self.model = model
+        # Mirror all public tokenizer attributes onto this wrapper
+        for attr_name in dir(tokenizer):
+            # Skip dunder attributes and methods
+            if not attr_name.startswith('__'):
+                try:
+                    setattr(self, attr_name, getattr(tokenizer, attr_name))
+                except AttributeError:
+                    pass
+
+    def __getattr__(self, name):
+        # Anything not found on this class falls back to the tokenizer
+        return getattr(self.tokenizer, name)
+
+    def __call__(self, prompts, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False, **kwargs):
+        if isinstance(prompts[0], PromptWithImage):
+            pixel_values = torch.cat([p.pixel_values for p in prompts], dim=0)
+            texts = [str(p) for p in prompts]
+            num_patches_list = [p.pixel_values.shape[0] for p in prompts]
+
+            # Use the model's for_grpo helper to build prompt_ids and attention_mask
+            prompt_ids, prompt_mask = self.model.for_grpo(
+                self.tokenizer,
+                pixel_values,
+                texts,
+                num_patches_list=num_patches_list,
+                history=None,
+                return_history=False,
+            )
+
+            return {
+                "input_ids": prompt_ids,
+                "attention_mask": prompt_mask
+            }
+        else:
+            # Plain-text prompts: fall back to the tokenizer
+            return self.tokenizer(
+                prompts,
+                return_tensors=return_tensors,
+                padding=padding,
+                padding_side=padding_side,
+                add_special_tokens=add_special_tokens
+            )
+
+# Subclass the model so that forward and generate can fall back to the
+# underlying language model when the wrapper's signature doesn't match.
+class CustomInternVLModel(InternVLChatModel):
+    def forward(self, *args, **kwargs):
+        try:
+            # First try the model's own forward
+            return super().forward(*args, **kwargs)
+        except (TypeError, ValueError, AttributeError) as e:
+            # On a signature mismatch or other error, use the language_model
+            print(f"Falling back to language_model for the forward pass: {e}")
+            return self.language_model(*args, **kwargs)
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids,
+        attention_mask=None,
+        **kwargs
+    ):
+        print("Delegating generation to language_model")
+        return self.language_model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **kwargs
+        )
+
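+# reward_func in isolation (a minimal sketch; the prompts and completions are
+# made up, and GRPOTrainer forwards extra dataset columns such as
+# correct_brand to reward functions through **kwargs):
+#
+#     rewards = reward_func(
+#         prompts=["p1", "p2"],
+#         completions=["The brand is BrandA.", "No idea"],
+#         correct_brand=["BrandA", "BrandA"],
+#     )
+#     # -> [1.0, 0.0]
+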
print(f"切换到 language_model 进行前向传播: {str(e)}") + return self.language_model(*args, **kwargs) + + @torch.no_grad() + def generate( + self, + input_ids, + attention_mask=None, + **kwargs + ): + + print(f"切换到 language_model 进行生成") + return self.language_model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs + ) + +def main(): + # 初始化 accelerator + accelerator = Accelerator( + gradient_accumulation_steps=4, + mixed_precision="bf16" + ) + + # 设置随机种子 + set_seed(42) + + # 设置日志 + setup_logging() + + # 加载数据 + logging.info("正在加载数据...") + with open('bal_data/frequent_brands_data.pkl', 'rb') as f: + data = pickle.load(f) + logging.info(f"成功加载数据") + + # 加载模型时使用自定义类 + path = 'Internvl2_5' + model = CustomInternVLModel.from_pretrained(path, **get_model_config()).train() + model.name_or_path = 'CustomInternVLModel' + + # 加载预训练权重 + print("正在加载预训练权重 vit_mlp_epoch_15.pth ...") + model.load_state_dict(torch.load('weights/vit_mlp_epoch_15.pth')) + print("成功加载预训练权重") + + tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False) + + # 创建处理器 + processor = ImagePromptProcessor(tokenizer, model) + + # 创建数据集 + dataset = ChatDataset(data, tokenizer, 1, sorted(list(set(item['brand'] for item in data)))) + + # 创建训练器(使用原始的 GRPOTrainer) + trainer = GRPOTrainer( + model=model, + reward_funcs=reward_func, + args=get_training_args(), + train_dataset=dataset, + processing_class=processor, + peft_config=get_lora_config() + ) + + # 使用 accelerator 准备模型和训练器 + trainer = accelerator.prepare(trainer) + + # 开始训练 + logging.info("开始训练...") + trainer.train() + logging.info("训练完成") + + # 保存模型 + output_dir = "chat_model_lora_3" + unwrapped_model = accelerator.unwrap_model(trainer.model) + unwrapped_model.save_pretrained(output_dir) + logging.info(f"模型已保存到 {output_dir}") + +if __name__ == "__main__": + main()