Add vlm_grpo_lora.py

CHAOGAO 2025-03-31 15:47:01 +08:00
parent 00f8a10076
commit e77007c5a8

vlm_grpo_lora.py

@@ -0,0 +1,292 @@
import os
# Environment setup
# os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
import logging
import pickle
import torch
torch.set_float32_matmul_precision('high')
from datetime import datetime
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from trl import GRPOTrainer, GRPOConfig
from peft import LoraConfig
from predata import load_image
from Internvl2_5.modeling_internvl_chat import InternVLChatModel
from accelerate import Accelerator
from accelerate.utils import set_seed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Logging configuration
def setup_logging():
    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filename = f'training_log_{current_time}.txt'
    logging.basicConfig(
        filename=log_filename,
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
# Wraps a prompt string together with its image tensor, so the prompt can
# flow through GRPOTrainer's text pipeline while still carrying pixel values.
class PromptWithImage:
    def __init__(self, text, pixel_values):
        self.text = text
        self.pixel_values = pixel_values

    def __str__(self):
        return self.text

    def __repr__(self):
        return self.text
class ChatDataset(Dataset):
    def __init__(self, data, tokenizer, batch_size, all_brands):
        self.data = data
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.all_brands = all_brands

    def __len__(self):
        return len(self.data)

    def create_prompt(self, pixel_values):
        prompt_text = "<image>\nPlease identify the most appropriate brand from the following options based on the image content:\n"
        for brand in self.all_brands:
            prompt_text += f"{brand}\n"
        prompt_text += "Please only respond with the brand name, no additional explanation needed."
        return PromptWithImage(prompt_text, pixel_values)

    def __getitem__(self, idx):
        item = self.data[idx]
        pixel_values = load_image(item['image'], max_num=12).to(dtype=torch.bfloat16)
        return {
            "prompt": self.create_prompt(pixel_values),
            "correct_brand": item['brand'],
        }
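# Note on the GRPO dataset contract (based on TRL's documented behavior):
# GRPOTrainer reads the "prompt" column for generation, and any extra
# columns returned here (such as "correct_brand") are forwarded to the
# reward functions as keyword arguments.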
# Reward function
def reward_func(prompts, completions, **kwargs):
    rewards = []
    # Debug output
    print("prompts length:", len(prompts))
    print("completions length:", len(completions))
    print("kwargs keys:", kwargs.keys())
    correct_brands = kwargs.get('correct_brand')
    print("correct_brands length:", len(correct_brands))
    for completion, correct_brand in zip(completions, correct_brands):
        # Simple check: does the brand name appear in the completion (case-insensitive)?
        correct = correct_brand.lower() in completion.lower()
        # Debug output
        print("completion:", completion)
        print("correct_brand:", correct_brand)
        print("is_correct:", correct)
        rewards.append(float(correct))
    print("rewards length:", len(rewards))
    return rewards
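# A minimal sanity check of the reward contract (illustrative only; the
# brand names below are hypothetical):
#   reward_func(["p1", "p2"], ["It looks like Nike.", "Not sure."],
#               correct_brand=["Nike", "Adidas"])
#   -> [1.0, 0.0]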
# Model configuration
def get_model_config():
    return {
        'torch_dtype': torch.bfloat16,
        'low_cpu_mem_usage': True,
        'use_flash_attn': True,
        'trust_remote_code': True,
        'vision_model': None,
        'language_model': None
    }
# LoRA configuration
def get_lora_config():
    return LoraConfig(
        task_type="CAUSAL_LM",
        r=4,
        lora_alpha=16,
        lora_dropout=0.1,
        # peft expects target_modules to be a list of module names (or a
        # regex string), not a nested per-module dict; the r / lora_alpha /
        # lora_dropout values above apply to every targeted module.
        target_modules=[
            "mlp1.1",
            "mlp1.3",
            # "q_proj",
            # "k_proj",
        ],
        bias="none"
    )
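# If different ranks per module are ever needed, recent peft releases expose
# rank_pattern / alpha_pattern for that (an assumed alternative, not used
# here), e.g.:
#   LoraConfig(r=4, lora_alpha=16, target_modules=["mlp1.1", "mlp1.3"],
#              rank_pattern={"mlp1.3": 8}, alpha_pattern={"mlp1.3": 32})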
# Training configuration
def get_training_args():
    return GRPOConfig(
        output_dir="chat_grpo_output",
        num_generations=4,
        learning_rate=1e-5,
        logging_steps=100,
        max_prompt_length=None,
        gradient_accumulation_steps=4,
        max_completion_length=50,
        per_device_train_batch_size=2,
        max_steps=1000,
        dataloader_pin_memory=False  # disable pin_memory
    )
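# Sizing note (an assumption about current TRL behavior; worth verifying
# against the installed version): GRPO requires the global batch,
# num_processes * per_device_train_batch_size, to be divisible by
# num_generations, so num_generations=4 with per_device_train_batch_size=2
# implies at least 2 processes.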
class ImagePromptProcessor:
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        # Copy over all of the tokenizer's attributes
        for attr_name in dir(tokenizer):
            # Skip dunder attributes and methods
            if not attr_name.startswith('__'):
                try:
                    setattr(self, attr_name, getattr(tokenizer, attr_name))
                except AttributeError:
                    pass

    def __getattr__(self, name):
        # Fall back to the tokenizer for attributes not found on this class
        return getattr(self.tokenizer, name)
    def __call__(self, prompts, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False, **kwargs):
        if isinstance(prompts[0], PromptWithImage):
            pixel_values = torch.cat([p.pixel_values for p in prompts], dim=0)
            texts = [str(p) for p in prompts]
            num_patches_list = [p.pixel_values.shape[0] for p in prompts]
            # Use the model's for_grpo helper to build prompt_ids and attention_mask
            prompt_ids, prompt_mask = self.model.for_grpo(
                self.tokenizer,
                pixel_values,
                texts,
                num_patches_list=num_patches_list,
                history=None,
                return_history=False,
            )
            return {
                "input_ids": prompt_ids,
                "attention_mask": prompt_mask
            }
        else:
            # Plain-text path
            return self.tokenizer(
                prompts,
                return_tensors=return_tensors,
                padding=padding,
                padding_side=padding_side,
                add_special_tokens=add_special_tokens
            )
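# Why this wrapper exists (a reading of the code above, not a documented TRL
# API): GRPOTrainer calls processing_class like a tokenizer and only needs
# input_ids / attention_mask back, so this class masquerades as the tokenizer
# while routing image-bearing prompts through the model's own multimodal
# preprocessing.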
# Subclass the model to add a generate method
class CustomInternVLModel(InternVLChatModel):
    def forward(self, *args, **kwargs):
        try:
            # First try the model's own forward
            return super().forward(*args, **kwargs)
        except (TypeError, ValueError, AttributeError) as e:
            # On a signature mismatch or other error, fall back to the language model
            print(f"Falling back to language_model for the forward pass: {str(e)}")
            return self.language_model(*args, **kwargs)

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask=None,
        **kwargs
    ):
        print("Falling back to language_model for generation")
        return self.language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
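# Caveat (based on InternVL's usual design; worth double-checking against
# for_grpo): InternVLChatModel normally injects vision features by replacing
# IMG_CONTEXT token embeddings before delegating to language_model.generate,
# so generating directly from input_ids skips that injection unless the ids
# produced by for_grpo already account for it.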
def main():
    # Initialize the accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=4,
        mixed_precision="bf16"
    )
    # Set the random seed
    set_seed(42)
    # Set up logging
    setup_logging()
    # Load data
    logging.info("Loading data...")
    with open('bal_data/frequent_brands_data.pkl', 'rb') as f:
        data = pickle.load(f)
    logging.info(f"Loaded {len(data)} samples")
    # Load the model with the custom class
    path = 'Internvl2_5'
    model = CustomInternVLModel.from_pretrained(path, **get_model_config()).train()
    model.name_or_path = 'CustomInternVLModel'
    # Load pretrained weights
    print("Loading pretrained weights vit_mlp_epoch_15.pth ...")
    model.load_state_dict(torch.load('weights/vit_mlp_epoch_15.pth'))
    print("Pretrained weights loaded")
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
    # Create the processor
    processor = ImagePromptProcessor(tokenizer, model)
    # Create the dataset
    dataset = ChatDataset(data, tokenizer, 1, sorted(list(set(item['brand'] for item in data))))
    # Create the trainer (using the stock GRPOTrainer)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=get_training_args(),
        train_dataset=dataset,
        processing_class=processor,
        peft_config=get_lora_config()
    )
    # Hand the trainer to the accelerator (GRPOTrainer also manages its own
    # Accelerate state internally; this mainly keeps unwrap_model usable below)
    trainer = accelerator.prepare(trainer)
    # Start training
    logging.info("Starting training...")
    trainer.train()
    logging.info("Training finished")
    # Save the model
    output_dir = "chat_model_lora_3"
    unwrapped_model = accelerator.unwrap_model(trainer.model)
    unwrapped_model.save_pretrained(output_dir)
    logging.info(f"Model saved to {output_dir}")

if __name__ == "__main__":
    main()
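# Typical launch (assumed invocation; adjust the process count to the hardware):
#   accelerate launch --num_processes 2 vlm_grpo_lora.py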