Add vlm_grpo_lora.py
This commit is contained in:
parent
00f8a10076
commit
e77007c5a8
292
vlm_grpo_lora.py
Normal file
@@ -0,0 +1,292 @@
import os

# Environment setup
#os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

import logging
import pickle
import random
import torch

torch.set_float32_matmul_precision('high')

from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from trl import GRPOTrainer, GRPOConfig
from peft import LoraConfig
from predata import load_image
from Internvl2_5.conversation import get_conv_template
from Internvl2_5.modeling_internvl_chat import InternVLChatModel
from accelerate import Accelerator
from accelerate.utils import set_seed


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Logging configuration
def setup_logging():
    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filename = f'training_log_{current_time}.txt'
    logging.basicConfig(
        filename=log_filename,
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )


class PromptWithImage:
    def __init__(self, text, pixel_values):
        self.text = text
        self.pixel_values = pixel_values

    def __str__(self):
        return self.text

    def __repr__(self):
        return self.text
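

# Design note: PromptWithImage stringifies to its prompt text so that
# GRPOTrainer can treat each prompt as a plain string (for logging and
# templating), while ImagePromptProcessor below retrieves the attached
# pixel_values from the same object.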


class ChatDataset(Dataset):
    def __init__(self, data, tokenizer, batch_size, all_brands):
        self.data = data
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.all_brands = all_brands

    def __len__(self):
        return len(self.data)

    def create_prompt(self, pixel_values):
        prompt_text = "<image>\nPlease identify the most appropriate brand from the following options based on the image content:\n"
        for brand in self.all_brands:
            prompt_text += f"{brand}\n"
        prompt_text += "Please only respond with the brand name, no additional explanation needed."

        return PromptWithImage(prompt_text, pixel_values)

    def __getitem__(self, idx):
        item = self.data[idx]
        pixel_values = load_image(item['image'], max_num=12).to(dtype=torch.bfloat16)

        return {
            "prompt": self.create_prompt(pixel_values),
            "correct_brand": item['brand'],
        }
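

# Note: GRPOTrainer consumes the "prompt" field of each item; extra columns
# such as "correct_brand" are not fed to the model but kept for the reward
# functions (see reward_func below).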


# Reward function
def reward_func(prompts, completions, **kwargs):
    rewards = []
    # Print debug information
    print("prompts length:", len(prompts))
    print("completions length:", len(completions))
    print("kwargs keys:", kwargs.keys())
    correct_brands = kwargs.get('correct_brand')
    print("correct_brands length:", len(correct_brands))

    for completion, correct_brand in zip(completions, correct_brands):
        # Simple check: does the brand name appear in the response (case-insensitive)?
        correct = correct_brand.lower() in completion.lower()

        # Print debug information
        print("completion:", completion)
        print("correct_brand:", correct_brand)
        print("is_correct:", correct)

        rewards.append(float(correct))

    print("rewards length:", len(rewards))
    return rewards
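

# A rough sketch of how GRPOTrainer invokes this reward function (an assumption
# based on TRL's documented behavior, not code from this repo). With 2 prompts
# and num_generations=2, each prompt is repeated once per sampled completion,
# and extra dataset columns arrive as keyword arguments:
#
#   reward_func(
#       prompts=[p1, p1, p2, p2],
#       completions=["Nike", "Adidas", "Puma", "puma shoes"],
#       correct_brand=["Nike", "Nike", "Puma", "Puma"],
#   )
#   # -> [1.0, 0.0, 1.0, 1.0]  (one scalar reward per completion)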


# Model configuration
def get_model_config():
    return {
        'torch_dtype': torch.bfloat16,
        'low_cpu_mem_usage': True,
        'use_flash_attn': True,
        'trust_remote_code': True,
        'vision_model': None,
        'language_model': None
    }


# LoRA configuration
def get_lora_config():
    # peft's LoraConfig expects target_modules as a list of module-name
    # patterns; r / lora_alpha / lora_dropout are set once at the config level
    # rather than per module.
    return LoraConfig(
        task_type="CAUSAL_LM",
        r=4,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=[
            "mlp1.1",
            "mlp1.3",
            # "q_proj",
            # "k_proj",
        ],
        bias="none"
    )
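

# Design note (an assumption about the InternVL2.5 architecture): "mlp1" is the
# vision-to-language projector, typically Sequential(LayerNorm, Linear, GELU,
# Linear), so "mlp1.1" and "mlp1.3" are its two Linear layers. LoRA is thus
# applied only to the multimodal projector; the commented-out q_proj / k_proj
# entries would extend it to the LLM attention projections.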


# Training configuration
def get_training_args():
    return GRPOConfig(
        output_dir="chat_grpo_output",
        num_generations=4,
        learning_rate=1e-5,
        logging_steps=100,
        max_prompt_length=None,
        gradient_accumulation_steps=4,
        max_completion_length=50,
        per_device_train_batch_size=2,
        max_steps=1000,
        dataloader_pin_memory=False  # disable pin_memory
    )
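

# Batch-size arithmetic: each device draws 2 prompts per step and accumulates
# gradients over 4 steps, i.e. 2 * 4 = 8 prompts per device per optimizer
# update, with num_generations=4 completions sampled and scored per prompt.
# TRL constrains the effective generation batch to be divisible by
# num_generations; the exact rule varies by TRL version, so it is worth
# verifying against the installed release.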


class ImagePromptProcessor:
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        # Inherit all of the tokenizer's attributes
        for attr_name in dir(tokenizer):
            # Skip built-in attributes and methods
            if not attr_name.startswith('__'):
                try:
                    setattr(self, attr_name, getattr(tokenizer, attr_name))
                except AttributeError:
                    pass

    def __getattr__(self, name):
        # If an attribute is not found on this class, fall back to the tokenizer
        return getattr(self.tokenizer, name)

    def __call__(self, prompts, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False, **kwargs):
        if isinstance(prompts[0], PromptWithImage):
            pixel_values = torch.cat([p.pixel_values for p in prompts], dim=0)
            texts = [str(p) for p in prompts]
            num_patches_list = [p.pixel_values.shape[0] for p in prompts]

            # Use the model's for_grpo helper to obtain prompt_ids and attention_mask
            prompt_ids, prompt_mask = self.model.for_grpo(
                self.tokenizer,
                pixel_values,
                texts,
                num_patches_list=num_patches_list,
                history=None,
                return_history=False,
            )

            return {
                "input_ids": prompt_ids,
                "attention_mask": prompt_mask
            }
        else:
            # Handle plain text
            return self.tokenizer(
                prompts,
                return_tensors=return_tensors,
                padding=padding,
                padding_side=padding_side,
                add_special_tokens=add_special_tokens
            )
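

# Why left padding: GRPOTrainer batches prompts and generates continuations in
# a single call, so prompts are padded on the left to make every sequence end
# at its last real token. A sketch of how the trainer is assumed to use this
# processor:
#
#   batch = processor([prompt_a, prompt_b], return_tensors="pt",
#                     padding=True, padding_side="left")
#   # batch["input_ids"] / batch["attention_mask"] feed model.generate(...)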


# Subclass the model to add a generate method
class CustomInternVLModel(InternVLChatModel):
    def forward(self, *args, **kwargs):
        try:
            # First, try the model's own forward
            return super().forward(*args, **kwargs)
        except (TypeError, ValueError, AttributeError) as e:
            # On a signature mismatch or other error, fall back to the language_model
            print(f"Falling back to language_model for the forward pass: {str(e)}")
            return self.language_model(*args, **kwargs)

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask=None,
        **kwargs
    ):
        print("Falling back to language_model for generation")
        return self.language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
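

# Caveat: generate() bypasses InternVL's multimodal generation path and calls
# the bare language_model on the processor's prompt_ids. Whether the image
# content still reaches the LLM therefore depends entirely on what the custom
# for_grpo helper encodes into those ids; this is an assumption about
# repo-local code, not standard InternVL behavior.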


def main():
    # Initialize the accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=4,
        mixed_precision="bf16"
    )

    # Set the random seed
    set_seed(42)

    # Set up logging
    setup_logging()

    # Load the data
    logging.info("Loading data...")
    with open('bal_data/frequent_brands_data.pkl', 'rb') as f:
        data = pickle.load(f)
    logging.info("Data loaded successfully")

    # Load the model using the custom class
    path = 'Internvl2_5'
    model = CustomInternVLModel.from_pretrained(path, **get_model_config()).train()
    model.name_or_path = 'CustomInternVLModel'

    # Load the pretrained weights
    print("Loading pretrained weights vit_mlp_epoch_15.pth ...")
    model.load_state_dict(torch.load('weights/vit_mlp_epoch_15.pth'))
    print("Pretrained weights loaded successfully")

    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

    # Create the processor
    processor = ImagePromptProcessor(tokenizer, model)

    # Create the dataset
    dataset = ChatDataset(data, tokenizer, 1, sorted(set(item['brand'] for item in data)))

    # Create the trainer (using the stock GRPOTrainer)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=get_training_args(),
        train_dataset=dataset,
        processing_class=processor,
        peft_config=get_lora_config()
    )

    # Prepare the model and trainer with the accelerator
    trainer = accelerator.prepare(trainer)

    # Start training
    logging.info("Starting training...")
    trainer.train()
    logging.info("Training complete")

    # Save the model
    output_dir = "chat_model_lora_3"
    unwrapped_model = accelerator.unwrap_model(trainer.model)
    unwrapped_model.save_pretrained(output_dir)
    logging.info(f"Model saved to {output_dir}")


if __name__ == "__main__":
    main()
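
# Example launch command (an assumption; adjust the process count to the
# available GPUs -- the script is not tied to a specific topology):
#   accelerate launch --num_processes 2 vlm_grpo_lora.py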