211 lines
6.7 KiB
Python
211 lines
6.7 KiB
Python
import os
|
||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
|
||
from accelerate import Accelerator
|
||
import torch
|
||
import logging
|
||
from datetime import datetime
|
||
from transformers import BitsAndBytesConfig, AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
|
||
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||
import pickle
|
||
from torch.utils.data import Dataset
|
||
import re
|
||
from trl import GRPOTrainer, GRPOConfig
|
||
from open_r1.trainer.grpo_trainer import VLMGRPOTrainer
|
||
from open_r1.vlm_modules.qwen_module import Qwen2VLModule
|
||
|
||
# Configure logging at the very start of the script.
def setup_logging(log_dir='logs'):
    """Configure root logging to write to a timestamped file and the console.

    Args:
        log_dir: Directory for log files; created (including parents) if it
            does not exist. Defaults to ``'logs'`` relative to the current
            working directory, as before.

    Returns:
        The path of the log file that was created for this run.

    Note:
        ``logging.basicConfig`` does nothing when the root logger already has
        handlers, in which case the file/console handlers are not attached
        (the log file itself is still created, since the handler is built
        before ``basicConfig`` is called).
    """
    # exist_ok avoids the check-then-create race when several processes
    # (e.g. one per GPU) run this at the same time.
    os.makedirs(log_dir, exist_ok=True)

    # Timestamped file name so every run gets its own log file.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = os.path.join(log_dir, f'training_{timestamp}.log')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            # Explicit UTF-8: this script logs non-ASCII (Chinese) messages,
            # which the platform default encoding may not be able to encode.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),  # also echo to the console
        ],
    )

    logging.info(f"日志文件创建于: {log_file}")
    return log_file
|
||
|
||
# Logging is configured first so that every later message reaches the log file.
log_file = setup_logging()

# Accelerate handle for this script.
accelerator = Accelerator()

# ## TRAINING

# bitsandbytes quantization: 4-bit NF4 weights with bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
    # bnb_4bit_use_double_quant=True,
)
|
||
# # 修改模型加载部分
|
||
# model = AutoModelForImageTextToText.from_pretrained(
|
||
# "./model",
|
||
# quantization_config=bnb_config,
|
||
# torch_dtype=torch.bfloat16
|
||
# )
|
||
|
||
# model.gradient_checkpointing_enable()
|
||
# model = prepare_model_for_kbit_training(model, use_gradient_checkpointing = False)
|
||
# LoRA adapter configuration, handed to the trainer via `peft_config`.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",  # adapters for a causal language model
    inference_mode=False,   # adapters are trained, not only applied
    bias="none",            # bias parameters are not trained
    # LoRA hyper-parameters: rank, scaling alpha, dropout probability.
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    # Modules (matched by name) that receive LoRA adapters.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
|
||
|
||
# # 打印原始模型的参数统计
|
||
# total_params = sum(p.numel() for p in model.parameters())
|
||
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||
# logging.info("=== 训练配置信息 ===")
|
||
# logging.info(f"原始模型总参数量: {total_params:,}")
|
||
# logging.info(f"原始模型可训练参数量: {trainable_params:,}")
|
||
# logging.info(f"原始模型可训练参数占比: {100 * trainable_params / total_params:.2f}%")
|
||
|
||
# model = get_peft_model(model, peft_config)
|
||
|
||
# # 打印QLora后的参数统计
|
||
# total_params = sum(p.numel() for p in model.parameters())
|
||
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||
# logging.info(f"\nQLora后总参数量: {total_params:,}")
|
||
# logging.info(f"QLora后可训练参数量: {trainable_params:,}")
|
||
# logging.info(f"QLora后可训练参数占比: {100 * trainable_params / total_params:.2f}%")
|
||
|
||
# # 开启需要训练的参数的梯度更新
|
||
# model.train()
|
||
# for name, param in model.named_parameters():
|
||
# if param.requires_grad:
|
||
# # logging.info(f"开启参数 {name} 的梯度更新")
|
||
# param.requires_grad_(True)
|
||
|
||
class ChatDataset(Dataset):
    """Map-style dataset of image/brand records loaded from a pickle file.

    Every item becomes one GRPO sample: the same chat-formatted text prompt,
    the record's image, and the record's brand as the ground truth.
    """

    # Instruction shown to the model. The tag names must stay in sync with the
    # patterns the reward function searches for.
    PROMPT_TEXT = (
        "Please tell me the brand of the product in the picture between labels "
        "<answer/> and </answer> and explain the reason between labels "
        "<thinking/> and </thinking>"
    )

    # Prompt rendered by hand in the Qwen chat format. The prompt is identical
    # for every item, so it is built once here rather than per __getitem__.
    # The generation prompt ends with "<|im_start|>assistant\n" (including the
    # newline), which is what the Qwen chat template emits for
    # add_generation_prompt=True.
    FORMATTED_PROMPT = (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        "<|vision_start|><|image_pad|><|vision_end|>" + PROMPT_TEXT + "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    def __init__(self, data_path):
        """Load all records from *data_path*.

        Args:
            data_path: Path to a pickle holding an indexable, sized sequence of
                dict-like records with at least ``'image'`` and ``'brand'``
                keys.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only load data produced by a trusted source.
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the GRPO sample for record *idx*.

        Returns:
            dict with ``'prompt'`` (the chat-formatted text prompt),
            ``'image'`` (the record's ``'image'`` value, unchanged) and
            ``'correct_brand'`` (the record's ``'brand'`` value; presumably
            forwarded by the trainer to the reward function as the
            ``correct_brand`` keyword — confirm against the trainer).
        """
        item = self.data[idx]
        return {
            "prompt": self.FORMATTED_PROMPT,
            "image": item['image'],
            "correct_brand": item['brand'],
        }
|
||
|
||
# Load the training data. The path is relative, so it is resolved against the
# current working directory rather than this file's location.
_train_data_path = '../work/bal_data/frequent_brands_data.pkl'
logging.info("加载训练数据...")
train_dataset = ChatDataset(_train_data_path)
logging.info(f"加载了 {len(train_dataset)} 条训练数据")
|
||
|
||
# Patterns for the tags the prompt asks for. re.DOTALL lets `.` match newlines,
# so multi-line answers/reasoning are still extracted; compiled once.
_ANSWER_PATTERN = re.compile(r'<answer/>(.*?)</answer>', re.DOTALL)
_THINKING_PATTERN = re.compile(r'<thinking/>(.*?)</thinking>', re.DOTALL)


def _completion_text(completion):
    """Return the text of one completion.

    Plain strings are returned as-is; for a chat-style list of message dicts
    the first message's ``'content'`` is used.
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, (list, tuple)) and completion and isinstance(completion[0], dict):
        return str(completion[0].get('content', ''))
    return str(completion)


def reward_func(prompts, completions, **kwargs):
    """Score completions by whether they name the correct brand.

    Each completion earns:
      * +1.0 if the text between ``<answer/>`` and ``</answer>`` contains the
        correct brand (case-insensitive substring match), and
      * +1.0 if the text between ``<thinking/>`` and ``</thinking>`` contains
        it as well,
    for a maximum of 2.0. Only the first occurrence of each tag pair is
    inspected. A missing or blank brand earns 0.0 (an empty string would
    otherwise be a substring of every answer).

    Args:
        prompts: The prompts (unused; part of the reward-function signature).
        completions: Generated completions, one per sample — plain strings or
            chat-style lists of message dicts.
        **kwargs: Extra dataset columns; must include ``correct_brand``, a
            sequence of ground-truth brands aligned with *completions*.

    Returns:
        list[float]: One reward per completion.

    Raises:
        ValueError: If ``correct_brand`` is not supplied.
    """
    correct_brands = kwargs.get('correct_brand')
    if correct_brands is None:
        raise ValueError("reward_func requires a 'correct_brand' column in the dataset")

    rewards = []
    for completion, correct_brand in zip(completions, correct_brands):
        text = _completion_text(completion)
        brand = str(correct_brand).strip().lower() if correct_brand is not None else ''

        reward = 0.0
        if brand:
            for pattern in (_ANSWER_PATTERN, _THINKING_PATTERN):
                match = pattern.search(text)
                if match and brand in match.group(1).lower():
                    reward += 1.0

        logging.debug(f"\nCompletion: {completion}")
        logging.debug(f"Correct brand: {correct_brand}")
        logging.debug(f"Final reward: {reward}")

        rewards.append(reward)

    return rewards
|
||
|
||
|
||
def get_training_args():
    """Build the GRPO training configuration for this run.

    The model is loaded by the trainer from a path, so the quantization
    config (``bnb_config``), bfloat16 dtype and ``use_cache=False`` are passed
    through ``model_init_kwargs``. ``epsilon`` and ``num_iterations`` are
    assigned after the config object is constructed.

    Returns:
        GRPOConfig: The populated configuration.
    """
    model_kwargs = {
        "quantization_config": bnb_config,
        "torch_dtype": torch.bfloat16,
        "use_cache": False,
    }
    # NOTE(review): GRPO trainers generally require the global batch size
    # (processes x per-device batch x accumulation steps) to be divisible by
    # num_generations; 3 per device with 6 generations needs an even number
    # of processes — confirm against the trainer in use.
    config = GRPOConfig(
        output_dir="chat_grpo_output",
        # optimisation schedule
        learning_rate=1e-5,
        max_steps=1000,
        per_device_train_batch_size=3,
        gradient_accumulation_steps=1,
        logging_steps=100,
        # sampling / sequence lengths
        num_generations=6,
        max_prompt_length=None,
        max_completion_length=200,
        dataloader_pin_memory=False,
        model_init_kwargs=model_kwargs,
    )
    # Assigned as attributes rather than constructor keywords — presumably for
    # compatibility with GRPOConfig versions lacking these fields; confirm.
    for attr_name, attr_value in (("epsilon", 0.2), ("num_iterations", 1)):
        setattr(config, attr_name, attr_value)
    return config
|
||
|
||
# Build the trainer once the dataset, reward function and configs exist.
# The model is given as a local path; it is loaded with the trainer args'
# model_init_kwargs and wrapped with the LoRA config.
trainer = VLMGRPOTrainer(
    model='./Qwen2.5-VL-7B',
    reward_funcs=reward_func,
    args=get_training_args(),
    train_dataset=train_dataset,
    peft_config=peft_config,
    vlm_module=Qwen2VLModule(),
)

# Training
logging.info("开始训练...")
trainer.train()
logging.info("训练完成")

# Save the trained (LoRA) model. Trainer.save_model writes only from the
# process that should save and unwraps the model itself; for a PEFT model it
# saves the adapter weights. Calling save_pretrained directly here would make
# every process in a multi-GPU launch write the same files concurrently.
# NOTE(review): assumes VLMGRPOTrainer subclasses transformers.Trainer — confirm.
output_dir = "chat_model_lora"
trainer.save_model(output_dir)
logging.info(f"模型已保存到 {output_dir}")
|
||
|
||
|
||
|