211 lines
6.7 KiB
Python
211 lines
6.7 KiB
Python
|
import os
|
|||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|||
|
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
|
|||
|
from accelerate import Accelerator
|
|||
|
import torch
|
|||
|
import logging
|
|||
|
from datetime import datetime
|
|||
|
from transformers import BitsAndBytesConfig, AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
|
|||
|
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
|||
|
import pickle
|
|||
|
from torch.utils.data import Dataset
|
|||
|
import re
|
|||
|
from trl import GRPOTrainer, GRPOConfig
|
|||
|
from open_r1.trainer.grpo_trainer import VLMGRPOTrainer
|
|||
|
from open_r1.vlm_modules.qwen_module import Qwen2VLModule
|
|||
|
|
|||
|
# Logging configuration (set up first, before anything else in the script).
def setup_logging(log_dir: str = 'logs') -> str:
    """Configure root logging to a timestamped file plus the console.

    Creates ``log_dir`` if needed, then calls ``logging.basicConfig`` at INFO
    level with two handlers: a file handler writing to
    ``<log_dir>/training_<YYYYmmdd_HHMMSS>.log`` and a stream handler.
    Note that ``logging.basicConfig`` does nothing if the root logger already
    has handlers; the file is still created in that case.

    Args:
        log_dir: Directory for log files. Defaults to ``'logs'`` (relative to
            the current working directory), matching the original behavior.

    Returns:
        The path of the log file.
    """
    # exist_ok avoids the check-then-create race when several processes
    # (e.g. one per GPU) start at the same time.
    os.makedirs(log_dir, exist_ok=True)

    # Timestamped filename so successive runs do not overwrite each other.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = os.path.join(log_dir, f'training_{timestamp}.log')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            # Explicit UTF-8: log messages in this script contain non-ASCII
            # (Chinese) text, which can fail under a non-UTF-8 locale default.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),  # also echo to the console
        ],
    )

    logging.info(f"日志文件创建于: {log_file}")
    return log_file
|
|||
|
|
|||
|
# --- Process-wide setup ---------------------------------------------------
# Configure logging before anything else so later messages reach the log file.
log_file = setup_logging()

# Module-level Accelerator; used at the end of the script to unwrap the
# trained model before saving.
accelerator = Accelerator()

# --- Training: quantization settings --------------------------------------
# 4-bit "nf4" quantization with bfloat16 compute dtype. The double-quant
# option stays commented out and is therefore not passed.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
    # bnb_4bit_use_double_quant=True,
)
|
|||
|
# # 修改模型加载部分
|
|||
|
# model = AutoModelForImageTextToText.from_pretrained(
|
|||
|
# "./model",
|
|||
|
# quantization_config=bnb_config,
|
|||
|
# torch_dtype=torch.bfloat16
|
|||
|
# )
|
|||
|
|
|||
|
# model.gradient_checkpointing_enable()
|
|||
|
# model = prepare_model_for_kbit_training(model, use_gradient_checkpointing = False)
|
|||
|
# LoRA adapter settings; handed to the trainer through its `peft_config` argument.
peft_config = LoraConfig(
    r=8,               # LoRA rank
    lora_alpha=32,     # LoRA alpha parameter
    lora_dropout=0.1,  # dropout probability
    bias="none",
    inference_mode=False,
    task_type="CAUSAL_LM",  # causal language model
    # Modules that receive trainable LoRA adapters.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
|
|||
|
|
|||
|
# # 打印原始模型的参数统计
|
|||
|
# total_params = sum(p.numel() for p in model.parameters())
|
|||
|
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|||
|
# logging.info("=== 训练配置信息 ===")
|
|||
|
# logging.info(f"原始模型总参数量: {total_params:,}")
|
|||
|
# logging.info(f"原始模型可训练参数量: {trainable_params:,}")
|
|||
|
# logging.info(f"原始模型可训练参数占比: {100 * trainable_params / total_params:.2f}%")
|
|||
|
|
|||
|
# model = get_peft_model(model, peft_config)
|
|||
|
|
|||
|
# # 打印QLora后的参数统计
|
|||
|
# total_params = sum(p.numel() for p in model.parameters())
|
|||
|
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|||
|
# logging.info(f"\nQLora后总参数量: {total_params:,}")
|
|||
|
# logging.info(f"QLora后可训练参数量: {trainable_params:,}")
|
|||
|
# logging.info(f"QLora后可训练参数占比: {100 * trainable_params / total_params:.2f}%")
|
|||
|
|
|||
|
# # 开启需要训练的参数的梯度更新
|
|||
|
# model.train()
|
|||
|
# for name, param in model.named_parameters():
|
|||
|
# if param.requires_grad:
|
|||
|
# # logging.info(f"开启参数 {name} 的梯度更新")
|
|||
|
# param.requires_grad_(True)
|
|||
|
|
|||
|
class ChatDataset(Dataset):
    """Map-style dataset of (image, brand) samples loaded from a pickle file.

    Every sample gets the same chat-formatted prompt asking for the product
    brand inside ``<answer/>...</answer>`` and a reason inside
    ``<thinking/>...</thinking>``, with one image placeholder in the user turn.

    ``__init__(data_path)``: ``data_path`` is a pickle file holding an indexable
    sequence whose elements are mappings with at least ``'image'`` and
    ``'brand'`` keys.
    """

    PROMPT_TEXT = (
        "Please tell me the brand of the product in the picture between labels "
        "<answer/> and </answer> and explain the reason between labels "
        "<thinking/> and </thinking>"
    )

    # The prompt is identical for every sample, so build it once. It ends with
    # "<|im_start|>assistant\n" — the trailing newline matches the generation
    # prompt emitted by the Qwen chat template; without it the model's first
    # token would be spent on the newline.
    FORMATTED_PROMPT = (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        "<|vision_start|><|image_pad|><|vision_end|>" + PROMPT_TEXT + "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    def __init__(self, data_path):
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # from a trusted source.
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"prompt", "image", "correct_brand"}`` for sample ``idx``.

        ``correct_brand`` is forwarded by the trainer to the reward function.
        """
        item = self.data[idx]
        return {
            "prompt": self.FORMATTED_PROMPT,
            "image": item['image'],
            "correct_brand": item['brand'],
        }
|
|||
|
|
|||
|
# Load the training data (path is relative to the current working directory).
TRAIN_DATA_PATH = '../work/bal_data/frequent_brands_data.pkl'

logging.info("加载训练数据...")
train_dataset = ChatDataset(TRAIN_DATA_PATH)
num_train_samples = len(train_dataset)
logging.info(f"加载了 {num_train_samples} 条训练数据")
|
|||
|
|
|||
|
# Tag patterns matching the format the training prompt asks for. re.DOTALL
# lets `.` match newlines, so multi-line answers/reasoning are still captured.
_ANSWER_RE = re.compile(r'<answer/>(.*?)</answer>', re.DOTALL)
_THINKING_RE = re.compile(r'<thinking/>(.*?)</thinking>', re.DOTALL)


def _completion_text(completion):
    """Return the generated text of one completion.

    Strings are returned unchanged; a message list (presumably what the trainer
    passes for chat-format prompts — confirm) is reduced to the first
    message's ``'content'``.
    """
    if isinstance(completion, str):
        return completion
    return completion[0]["content"]


def reward_func(prompts, completions, **kwargs):
    """Score completions by whether they name the correct brand.

    Per completion: +1.0 if the text inside ``<answer/>...</answer>`` contains
    the correct brand (case-insensitive substring), and +1.0 more if the text
    inside ``<thinking/>...</thinking>`` does. Maximum reward is 2.0. An empty
    or missing brand earns 0.0 (an empty substring would otherwise match
    anything).

    Args:
        prompts: Unused; accepted for the trainer's reward-function signature.
        completions: Generated completions aligned with ``correct_brand``.
        **kwargs: Must contain ``correct_brand``, the ground-truth brands
            (the dataset's ``correct_brand`` field), one per completion.

    Returns:
        List of float rewards, one per (completion, brand) pair.

    Raises:
        TypeError: If ``correct_brand`` is not supplied.
    """
    correct_brands = kwargs.get('correct_brand')
    if correct_brands is None:
        raise TypeError("reward_func requires a 'correct_brand' keyword argument")

    rewards = []
    for completion, correct_brand in zip(completions, correct_brands):
        text = _completion_text(completion)
        brand = "" if correct_brand is None else str(correct_brand).strip().lower()

        reward = 0.0
        if brand:
            answer_match = _ANSWER_RE.search(text)
            if answer_match and brand in answer_match.group(1).lower():
                reward += 1.0
            thinking_match = _THINKING_RE.search(text)
            if thinking_match and brand in thinking_match.group(1).lower():
                reward += 1.0

        # Lazy %-formatting: strings are only built when DEBUG is enabled.
        logging.debug("\nCompletion: %s", text)
        logging.debug("Correct brand: %s", correct_brand)
        logging.debug("Final reward: %s", reward)

        rewards.append(reward)

    return rewards
|
|||
|
|
|||
|
|
|||
|
def get_training_args():
    """Build the GRPOConfig used for training.

    Model-loading options (the module-level ``bnb_config`` quantization
    config, bfloat16 dtype, ``use_cache=False``) are forwarded through
    ``model_init_kwargs``. ``epsilon`` and ``num_iterations`` are assigned as
    attributes after construction instead of being passed to the constructor
    (presumably because the installed GRPOConfig may not accept them as
    arguments — confirm).

    NOTE(review): GRPO trainers commonly require
    num_processes * per_device_train_batch_size to be divisible by
    num_generations; with 3 per device and 6 generations that implies an even
    number of processes — verify against the trainer in use.
    """
    model_kwargs = {
        "quantization_config": bnb_config,
        "torch_dtype": torch.bfloat16,
        "use_cache": False,
    }
    config_kwargs = dict(
        output_dir="chat_grpo_output",
        # Optimization / schedule
        learning_rate=1e-5,
        max_steps=1000,
        per_device_train_batch_size=3,
        gradient_accumulation_steps=1,
        # Generation
        num_generations=6,
        max_prompt_length=None,
        max_completion_length=200,
        # Logging / data loading
        logging_steps=100,
        dataloader_pin_memory=False,
        model_init_kwargs=model_kwargs,
    )
    training_args = GRPOConfig(**config_kwargs)

    training_args.epsilon = 0.2
    training_args.num_iterations = 1
    return training_args
|
|||
|
|
|||
|
# Build the trainer: base model path, dataset, reward function, LoRA config,
# training arguments and the Qwen2VLModule adapter.
# get_training_args() is still evaluated before Qwen2VLModule(), as before.
trainer = VLMGRPOTrainer(
    model='./Qwen2.5-VL-7B',
    train_dataset=train_dataset,
    reward_funcs=reward_func,
    peft_config=peft_config,
    args=get_training_args(),
    vlm_module=Qwen2VLModule(),
)
|
|||
|
|
|||
|
|
|||
|
# Run training.
logging.info("开始训练...")
trainer.train()
logging.info("训练完成")

# Save the trained model.
output_dir = "chat_model_lora"
# Under a multi-process launch every rank executes this script. Wait until all
# ranks are done, then let only the main process write to disk so ranks do not
# write the same files into output_dir concurrently.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    unwrapped_model = accelerator.unwrap_model(trainer.model)
    unwrapped_model.save_pretrained(output_dir)
    logging.info(f"模型已保存到 {output_dir}")
|
|||
|
|
|||
|
|
|||
|
|