# Fortrain/vlm_grpo_lora.py

import os
# Environment setup
# os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
import logging
import pickle
import random
import torch
torch.set_float32_matmul_precision('high')
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from trl import GRPOTrainer, GRPOConfig
from peft import LoraConfig
from predata import load_image
from Internvl2_5.conversation import get_conv_template
from Internvl2_5.modeling_internvl_chat import InternVLChatModel
from accelerate import Accelerator
from accelerate.utils import set_seed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging configuration
def setup_logging():
    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filename = f'training_log_{current_time}.txt'
    logging.basicConfig(
        filename=log_filename,
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
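# Produces log files named like training_log_20250331_154701.txt in the
# current working directory.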

class PromptWithImage:
    def __init__(self, text, pixel_values):
        self.text = text
        self.pixel_values = pixel_values

    def __str__(self):
        return self.text

    def __repr__(self):
        return self.text
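# __str__/__repr__ return only the text, so code that expects plain string
# prompts (e.g. TRL's prompt logging) can treat a PromptWithImage as one
# while the pixel values ride along for the custom processor below.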

class ChatDataset(Dataset):
    def __init__(self, data, tokenizer, batch_size, all_brands):
        self.data = data
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.all_brands = all_brands

    def __len__(self):
        return len(self.data)

    def create_prompt(self, pixel_values):
        prompt_text = "<image>\nPlease identify the most appropriate brand from the following options based on the image content:\n"
        for brand in self.all_brands:
            prompt_text += f"{brand}\n"
        prompt_text += "Please only respond with the brand name, no additional explanation needed."
        return PromptWithImage(prompt_text, pixel_values)

    def __getitem__(self, idx):
        item = self.data[idx]
        pixel_values = load_image(item['image'], max_num=12).to(dtype=torch.bfloat16)
        return {
            "prompt": self.create_prompt(pixel_values),
            "correct_brand": item['brand'],
        }
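# Each item looks like (hypothetical example):
#   {"prompt": PromptWithImage("<image>\nPlease identify ...", <tensor>),
#    "correct_brand": "Nike"}
# TRL's GRPOTrainer forwards extra dataset columns such as "correct_brand"
# to the reward function as keyword arguments.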

# Reward function
def reward_func(prompts, completions, **kwargs):
    rewards = []
    # Debug output
    print("prompts length:", len(prompts))
    print("completions length:", len(completions))
    print("kwargs keys:", kwargs.keys())
    correct_brands = kwargs.get('correct_brand')
    print("correct_brands length:", len(correct_brands))
    for completion, correct_brand in zip(completions, correct_brands):
        # Simple check: does the brand name appear in the answer (case-insensitive)?
        correct = correct_brand.lower() in completion.lower()
        # Debug output
        print("completion:", completion)
        print("correct_brand:", correct_brand)
        print("is_correct:", correct)
        rewards.append(float(correct))
    print("rewards length:", len(rewards))
    return rewards
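# Quick sanity check with hypothetical values:
#   reward_func(["p"], ["The logo belongs to Nike."], correct_brand=["Nike"])
#   # -> [1.0]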

# Model configuration
def get_model_config():
    return {
        'torch_dtype': torch.bfloat16,
        'low_cpu_mem_usage': True,
        'use_flash_attn': True,
        'trust_remote_code': True,
        'vision_model': None,
        'language_model': None
    }

# LoRA configuration
def get_lora_config():
    return LoraConfig(
        task_type="CAUSAL_LM",
        r=4,
        lora_alpha=16,
        lora_dropout=0.1,
        # target_modules takes a list of module-name suffixes (or a regex),
        # not a per-module config dict; the r/lora_alpha/lora_dropout above
        # apply to every targeted module.
        target_modules=[
            "mlp1.1",
            "mlp1.3",
            # "q_proj",
            # "k_proj",
        ],
        bias="none"
    )
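# If per-module ranks are actually wanted, recent peft versions (assumption:
# peft >= 0.7) support rank_pattern / alpha_pattern dicts keyed by
# module-name regex, e.g. rank_pattern={"mlp1.3": 8}.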

# Training configuration
def get_training_args():
    return GRPOConfig(
        output_dir="chat_grpo_output",
        num_generations=4,
        learning_rate=1e-5,
        logging_steps=100,
        max_prompt_length=None,
        gradient_accumulation_steps=4,
        max_completion_length=50,
        per_device_train_batch_size=2,
        max_steps=1000,
        dataloader_pin_memory=False  # disable pin_memory
    )
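# Note: depending on the TRL version, per_device_train_batch_size counts
# either prompts or individual generations, and the divisibility rules
# between the global batch and num_generations differ; if GRPOTrainer raises
# a batch-size error here, adjust per_device_train_batch_size or
# num_generations accordingly.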

class ImagePromptProcessor:
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        # Inherit all of the tokenizer's attributes
        for attr_name in dir(tokenizer):
            # Skip built-in attributes and methods
            if not attr_name.startswith('__'):
                try:
                    setattr(self, attr_name, getattr(tokenizer, attr_name))
                except AttributeError:
                    pass

    def __getattr__(self, name):
        # Fall back to the tokenizer for attributes not found on this class
        return getattr(self.tokenizer, name)

    def __call__(self, prompts, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False, **kwargs):
        if isinstance(prompts[0], PromptWithImage):
            pixel_values = torch.cat([p.pixel_values for p in prompts], dim=0)
            texts = [str(p) for p in prompts]
            num_patches_list = [p.pixel_values.shape[0] for p in prompts]
            # Use the model's GRPO helper to build prompt_ids and attention_mask
            prompt_ids, prompt_mask = self.model.for_grpo(
                self.tokenizer,
                pixel_values,
                texts,
                num_patches_list=num_patches_list,
                history=None,
                return_history=False,
            )
            return {
                "input_ids": prompt_ids,
                "attention_mask": prompt_mask
            }
        else:
            # Plain-text prompts go straight through the tokenizer
            return self.tokenizer(
                prompts,
                return_tensors=return_tensors,
                padding=padding,
                padding_side=padding_side,
                add_special_tokens=add_special_tokens
            )
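# GRPOTrainer only needs its processing_class to be callable like a tokenizer
# and to expose attributes such as pad_token_id / eos_token_id; copying the
# tokenizer's attributes plus the __getattr__ fallback duck-types that
# interface while routing image prompts through the model's own preprocessing.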

# Subclass the model to add a generate method and a forward fallback
class CustomInternVLModel(InternVLChatModel):
    def forward(self, *args, **kwargs):
        try:
            # First try the model's own forward
            return super().forward(*args, **kwargs)
        except (TypeError, ValueError, AttributeError) as e:
            # On a signature mismatch or similar error, fall back to the language model
            print(f"Falling back to language_model for the forward pass: {str(e)}")
            return self.language_model(*args, **kwargs)

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask=None,
        **kwargs
    ):
        print("Falling back to language_model for generation")
        return self.language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
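# Note: generation is delegated entirely to the language model; this assumes
# the custom for_grpo preprocessing above has already encoded everything the
# LM needs into input_ids / attention_mask.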

def main():
    # Initialize the accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=4,
        mixed_precision="bf16"
    )
    # Set the random seed
    set_seed(42)
    # Set up logging
    setup_logging()
    # Load the data
    logging.info("Loading data...")
    with open('bal_data/frequent_brands_data.pkl', 'rb') as f:
        data = pickle.load(f)
    logging.info("Data loaded successfully")
    # Load the model with the custom class
    path = 'Internvl2_5'
    model = CustomInternVLModel.from_pretrained(path, **get_model_config()).train()
    model.name_or_path = 'CustomInternVLModel'
    # Load the pretrained weights
    print("Loading pretrained weights vit_mlp_epoch_15.pth ...")
    model.load_state_dict(torch.load('weights/vit_mlp_epoch_15.pth'))
    print("Pretrained weights loaded successfully")
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
    # Create the processor
    processor = ImagePromptProcessor(tokenizer, model)
    # Create the dataset
    dataset = ChatDataset(data, tokenizer, 1, sorted(set(item['brand'] for item in data)))
    # Create the trainer (using the stock GRPOTrainer)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=get_training_args(),
        train_dataset=dataset,
        processing_class=processor,
        peft_config=get_lora_config()
    )
    # Prepare via the accelerator (effectively a no-op for a Trainer object;
    # GRPOTrainer manages Accelerate internally)
    trainer = accelerator.prepare(trainer)
    # Start training
    logging.info("Starting training...")
    trainer.train()
    logging.info("Training finished")
    # Save the model
    output_dir = "chat_model_lora_3"
    unwrapped_model = accelerator.unwrap_model(trainer.model)
    unwrapped_model.save_pretrained(output_dir)
    logging.info(f"Model saved to {output_dir}")


if __name__ == "__main__":
    main()