Fortrain/qw/train.py
2025-03-31 15:56:36 +08:00

211 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
from accelerate import Accelerator
import torch
import logging
from datetime import datetime
from transformers import BitsAndBytesConfig, AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import pickle
from torch.utils.data import Dataset
import re
from trl import GRPOTrainer, GRPOConfig
from open_r1.trainer.grpo_trainer import VLMGRPOTrainer
from open_r1.vlm_modules.qwen_module import Qwen2VLModule
# Configure logging at the very start of the script.
def setup_logging(log_dir='logs'):
    """Configure root logging to a timestamped file plus the console.

    Creates ``log_dir`` if it does not exist, then calls
    ``logging.basicConfig`` at INFO level with a file handler and a stream
    handler.

    Args:
        log_dir: Directory that receives the log file. Defaults to ``'logs'``
            (relative to the current working directory), as before.

    Returns:
        The log file path, ``f'{log_dir}/training_<YYYYmmdd_HHMMSS>.log'``.

    Note:
        ``logging.basicConfig`` is a no-op when the root logger already has
        handlers, so only the first configuration takes effect.
    """
    # exist_ok=True removes the check-then-create race that occurs if several
    # processes run this script concurrently.
    os.makedirs(log_dir, exist_ok=True)
    # Timestamped file name so successive runs do not share a log file.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = f'{log_dir}/training_{timestamp}.log'
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            # Explicit UTF-8: the log messages contain Chinese text that the
            # platform's default encoding may not be able to represent.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),  # also echo to the console
        ],
    )
    logging.info(f"日志文件创建于: {log_file}")
    return log_file
# Set up logging before anything else runs.
log_file = setup_logging()

# Accelerator for this script; used at the end to unwrap the trained model
# before it is saved.
accelerator = Accelerator()

# ## TRAINING
# 4-bit NF4 quantisation with bfloat16 compute. It is handed to the model
# loader through GRPOConfig.model_init_kwargs in get_training_args().
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    # bnb_4bit_use_double_quant=True,
)
# # 修改模型加载部分
# model = AutoModelForImageTextToText.from_pretrained(
# "./model",
# quantization_config=bnb_config,
# torch_dtype=torch.bfloat16
# )
# model.gradient_checkpointing_enable()
# model = prepare_model_for_kbit_training(model, use_gradient_checkpointing = False)
# LoRA adapter settings: rank 8, alpha 32, dropout 0.1, biases left untouched.
# The target modules are presumably the attention projections of the language
# model (matched by module name) — confirm against the model's module names.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",  # causal language model
    inference_mode=False,
    r=8,  # LoRA rank
    lora_alpha=32,  # LoRA alpha (scaling)
    lora_dropout=0.1,  # dropout probability
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # layers to adapt
    bias="none",
)
# # 打印原始模型的参数统计
# total_params = sum(p.numel() for p in model.parameters())
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# logging.info("=== 训练配置信息 ===")
# logging.info(f"原始模型总参数量: {total_params:,}")
# logging.info(f"原始模型可训练参数量: {trainable_params:,}")
# logging.info(f"原始模型可训练参数占比: {100 * trainable_params / total_params:.2f}%")
# model = get_peft_model(model, peft_config)
# # 打印QLora后的参数统计
# total_params = sum(p.numel() for p in model.parameters())
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# logging.info(f"\nQLora后总参数量: {total_params:,}")
# logging.info(f"QLora后可训练参数量: {trainable_params:,}")
# logging.info(f"QLora后可训练参数占比: {100 * trainable_params / total_params:.2f}%")
# # 开启需要训练的参数的梯度更新
# model.train()
# for name, param in model.named_parameters():
# if param.requires_grad:
# # logging.info(f"开启参数 {name} 的梯度更新")
# param.requires_grad_(True)
class ChatDataset(Dataset):
    """Brand-recognition dataset that yields ready-formatted Qwen2.5-VL prompts.

    The pickle at ``data_path`` is loaded whole. Each record is indexed with
    ``'image'`` and ``'brand'`` in ``__getitem__``. Every sample uses the same
    instruction prompt.
    """

    # Instruction asking for the brand between <answer/> ... </answer> and the
    # reasoning between <thinking/> ... </thinking>. reward_func parses these
    # tags.
    PROMPT_TEXT = (
        "Please tell me the brand of the product in the picture between labels "
        "<answer/> and </answer> and explain the reason between labels "
        "<thinking/> and </thinking>"
    )

    # Qwen2.5-VL chat layout with one image placeholder in the user turn.
    # The generation prompt ends with "<|im_start|>assistant\n", including the
    # newline, which is how Qwen's chat template opens the assistant turn.
    FORMATTED_PROMPT = (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        "<|vision_start|><|image_pad|><|vision_end|>" + PROMPT_TEXT + "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    def __init__(self, data_path):
        """Load all samples from the pickle file at ``data_path``.

        SECURITY: pickle.load can execute arbitrary code. Only point this at
        data files produced by a trusted source.
        """
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"prompt", "image", "correct_brand"}`` for sample ``idx``."""
        item = self.data[idx]
        return {
            "prompt": self.FORMATTED_PROMPT,
            "image": item['image'],
            "correct_brand": item['brand'],
        }
# Load the pickled training samples (path is relative to the working directory).
logging.info("加载训练数据...")
train_dataset = ChatDataset('../work/bal_data/frequent_brands_data.pkl')
logging.info("加载了 %d 条训练数据", len(train_dataset))
# Tag-extraction patterns, compiled once.
# - The prompt asks for "<answer/>"-style opening tags, but a plain "<answer>"
#   (likewise "<thinking>") is accepted too.
# - re.DOTALL lets the captured text span several lines; without it,
#   multi-line reasoning never matched.
_ANSWER_RE = re.compile(r'<answer/?>(.*?)</answer>', re.DOTALL)
_THINKING_RE = re.compile(r'<thinking/?>(.*?)</thinking>', re.DOTALL)


def reward_func(prompts, completions, **kwargs):
    """Score each completion on whether it names the correct brand.

    A completion earns +1.0 if the text inside its first
    ``<answer/>...</answer>`` span contains the correct brand, and another
    +1.0 if its first ``<thinking/>...</thinking>`` span does. Matching is a
    case-insensitive substring test, so each reward is 0.0, 1.0 or 2.0.

    Args:
        prompts: Unused; accepted to match the trainer's reward signature.
        completions: Generated completion strings.
        **kwargs: Must contain ``correct_brand``, a sequence of brand strings
            aligned index-by-index with ``completions``.

    Returns:
        A list of float rewards, one per (completion, brand) pair, paired as
        ``zip`` pairs them.
    """
    rewards = []
    correct_brands = kwargs.get('correct_brand')
    for completion, correct_brand in zip(completions, correct_brands):
        reward = 0.0
        # Normalise once. An empty brand is a substring of every string, so it
        # must not earn any reward.
        brand = (correct_brand or "").strip().lower()
        if brand:
            answer_match = _ANSWER_RE.search(completion)
            if answer_match and brand in answer_match.group(1).lower():
                reward += 1.0
            thinking_match = _THINKING_RE.search(completion)
            if thinking_match and brand in thinking_match.group(1).lower():
                reward += 1.0
        # Lazy %-formatting: the strings are only built when DEBUG is enabled.
        logging.debug("\nCompletion: %s", completion)
        logging.debug("Correct brand: %s", correct_brand)
        logging.debug("Final reward: %s", reward)
        rewards.append(reward)
    return rewards
def get_training_args():
    """Build and return the GRPOConfig for this run.

    The model-loading options (``bnb_config`` quantisation, bfloat16 dtype,
    ``use_cache=False``) are forwarded through ``model_init_kwargs``.
    ``epsilon`` and ``num_iterations`` are assigned as attributes after the
    config has been constructed.
    """
    model_load_kwargs = {
        "quantization_config": bnb_config,
        "torch_dtype": torch.bfloat16,
        "use_cache": False,
    }
    config = GRPOConfig(
        output_dir="chat_grpo_output",
        # optimisation schedule
        learning_rate=1e-5,
        max_steps=1000,
        per_device_train_batch_size=3,
        gradient_accumulation_steps=1,
        # generation settings
        # NOTE(review): TRL's GRPOTrainer requires (processes x per-device batch)
        # to be divisible by num_generations. If VLMGRPOTrainer keeps that check,
        # 3 per device with 6 generations needs an even process count — confirm.
        num_generations=6,
        max_prompt_length=None,
        max_completion_length=200,
        # bookkeeping
        logging_steps=100,
        dataloader_pin_memory=False,
        model_init_kwargs=model_load_kwargs,
    )
    # NOTE(review): presumably set after construction because the installed
    # GRPOConfig does not accept these as constructor arguments — confirm how
    # VLMGRPOTrainer reads them.
    config.epsilon = 0.2
    config.num_iterations = 1
    return config
# Create the trainer. The model comes from the local checkpoint directory.
# Presumably the trainer loads it with args.model_init_kwargs (4-bit
# quantisation) and applies peft_config — confirm in VLMGRPOTrainer.
trainer = VLMGRPOTrainer(
    model='./Qwen2.5-VL-7B',
    reward_funcs=reward_func,
    args=get_training_args(),
    train_dataset=train_dataset,
    peft_config=peft_config,
    vlm_module=Qwen2VLModule(),
)

# Training.
logging.info("开始训练...")
trainer.train()
logging.info("训练完成")

# Save the trained model (presumably just the LoRA adapter, since a
# peft_config was supplied). Every launched process executes this script, so
# wait for all ranks and write only from the main process. Otherwise every
# rank would write the same files in output_dir concurrently.
output_dir = "chat_model_lora"
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    unwrapped_model = accelerator.unwrap_model(trainer.model)
    unwrapped_model.save_pretrained(output_dir)
    logging.info(f"模型已保存到 {output_dir}")