import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
from accelerate import Accelerator
import torch
import logging
from datetime import datetime
from transformers import BitsAndBytesConfig, AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import pickle
from torch.utils.data import Dataset
import re
from trl import GRPOTrainer, GRPOConfig
from open_r1.trainer.grpo_trainer import VLMGRPOTrainer
from open_r1.vlm_modules.qwen_module import Qwen2VLModule
# Configure logging at the very start
def setup_logging():
    # Create the logs directory if it does not exist
    if not os.path.exists('logs'):
        os.makedirs('logs')
    # Build a timestamped log file name
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = f'logs/training_{timestamp}.log'
    # Configure the log format and handlers
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()  # also echo to the console
        ]
    )
    logging.info(f"Log file created at: {log_file}")
    return log_file

# Set up logging
log_file = setup_logging()
# Initialize the accelerator
accelerator = Accelerator()
# ## TRAINING
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    # bnb_4bit_use_double_quant=True,
)
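# NOTE: nf4 4-bit quantization with bfloat16 compute keeps the 7B base model small
# enough for QLoRA-style training on a single GPU's memory budget; double quantization
# is left disabled here but can be enabled above to save a bit more memory.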
# # Model loading (modified; now handled via the trainer's model_init_kwargs below)
# model = AutoModelForImageTextToText.from_pretrained(
#     "./model",
#     quantization_config=bnb_config,
#     torch_dtype=torch.bfloat16
# )
# model.gradient_checkpointing_enable()
# model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
peft_config = LoraConfig(
    task_type="CAUSAL_LM",  # this is a causal language model
    inference_mode=False,
    r=8,                    # LoRA rank
    lora_alpha=32,          # LoRA alpha
    lora_dropout=0.1,       # dropout probability
    target_modules=[        # model layers to adapt
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    ],
    bias="none",
)
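# Only the attention projections are adapted; with r=8 this keeps the trainable
# parameter count to a small fraction of the 7B base model. Extending target_modules
# to the MLP projections (gate_proj/up_proj/down_proj) is a common variant.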
# # Print parameter statistics for the original model
# total_params = sum(p.numel() for p in model.parameters())
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# logging.info("=== Training configuration ===")
# logging.info(f"Original model total parameters: {total_params:,}")
# logging.info(f"Original model trainable parameters: {trainable_params:,}")
# logging.info(f"Original model trainable ratio: {100 * trainable_params / total_params:.2f}%")
# model = get_peft_model(model, peft_config)
# # Print parameter statistics after QLoRA
# total_params = sum(p.numel() for p in model.parameters())
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# logging.info(f"\nTotal parameters after QLoRA: {total_params:,}")
# logging.info(f"Trainable parameters after QLoRA: {trainable_params:,}")
# logging.info(f"Trainable ratio after QLoRA: {100 * trainable_params / total_params:.2f}%")
# # Enable gradient updates for the parameters that need training
# model.train()
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         # logging.info(f"Enabling gradient updates for parameter {name}")
#         param.requires_grad_(True)
class ChatDataset(Dataset):
    def __init__(self, data_path):
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # The tag names must match what reward_func extracts below
        prompt_text = (
            "Please tell me the brand of the product in the picture between labels "
            "<answer> and </answer> and explain the reason between labels <think> and </think>"
        )
        # Format the prompt with the Qwen2-VL chat template
        formatted_prompt = (
            "<|im_start|>system\n"
            "You are a helpful assistant.<|im_end|>\n"
            "<|im_start|>user\n"
            "<|vision_start|><|image_pad|><|vision_end|>" + prompt_text + "<|im_end|>\n"
            "<|im_start|>assistant"
        )
        return {
            "prompt": formatted_prompt,
            "image": item['image'],
            "correct_brand": item['brand']
        }
# Load the dataset
logging.info("Loading training data...")
train_dataset = ChatDataset('../work/bal_data/frequent_brands_data.pkl')
logging.info(f"Loaded {len(train_dataset)} training samples")
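# Log one formatted sample so the chat template and the expected <think>/<answer>
# tag instructions can be checked by eye before training starts (informational only).
_sample = train_dataset[0]
logging.info(f"Sample prompt:\n{_sample['prompt']}")
logging.info(f"Sample correct_brand: {_sample['correct_brand']}")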
def reward_func(prompts, completions, **kwargs):
    rewards = []
    correct_brands = kwargs.get('correct_brand')
    for completion, correct_brand in zip(completions, correct_brands):
        reward = 0.0
        # Extract the content between the <answer> tags
        answer_match = re.search(r'<answer>(.*?)</answer>', completion)
        # Extract the content between the <think> tags
        thinking_match = re.search(r'<think>(.*?)</think>', completion)
        if answer_match:
            answer_content = answer_match.group(1).lower()
            if correct_brand.lower() in answer_content:
                reward += 1.0
        if thinking_match:
            thinking_content = thinking_match.group(1).lower()  # use a separate variable
            if correct_brand.lower() in thinking_content:  # use the thinking content
                reward += 1.0
        # Use logging instead of print
        logging.debug(f"\nCompletion: {completion}")
        logging.debug(f"Correct brand: {correct_brand}")
        logging.debug(f"Final reward: {reward}")
        rewards.append(reward)
    return rewards
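# Minimal smoke test of the reward function with a synthetic completion; the tag
# format here is an assumption that must match what the model is prompted to emit.
_demo_completion = "<think>The swoosh logo indicates Nike.</think><answer>Nike</answer>"
assert reward_func(prompts=["_"], completions=[_demo_completion], correct_brand=["Nike"]) == [2.0]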
def get_training_args():
    args = GRPOConfig(
        output_dir="chat_grpo_output",
        num_generations=6,
        learning_rate=1e-5,
        logging_steps=100,
        max_prompt_length=None,
        gradient_accumulation_steps=1,
        max_completion_length=200,
        per_device_train_batch_size=3,
        max_steps=1000,
        dataloader_pin_memory=False,
        model_init_kwargs={
            "quantization_config": bnb_config,
            "torch_dtype": torch.bfloat16,
            "use_cache": False
        }
    )
    args.epsilon = 0.2
    args.num_iterations = 1
    return args
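# Note: GRPO-style trainers typically require the effective generation batch
# (per_device_train_batch_size times the number of processes) to be divisible by
# num_generations; with num_generations=6 and a per-device batch of 3 this assumes
# a multi-GPU launch (e.g. 2 processes). Adjust either value for single-GPU runs.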
# Now build the trainer
trainer = VLMGRPOTrainer(
    model='./Qwen2.5-VL-7B',
    reward_funcs=reward_func,
    args=get_training_args(),
    train_dataset=train_dataset,
    peft_config=peft_config,
    vlm_module=Qwen2VLModule()
)
# Training-related logging
logging.info("Starting training...")
trainer.train()
logging.info("Training finished")
# Save the model
output_dir = "chat_model_lora"
unwrapped_model = accelerator.unwrap_model(trainer.model)
unwrapped_model.save_pretrained(output_dir)
logging.info(f"Model saved to {output_dir}")
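# Optionally save the processor alongside the LoRA adapter so the checkpoint can be
# reloaded for inference later; this assumes the base model path used above
# ('./Qwen2.5-VL-7B') is still available at save time.
processor = AutoProcessor.from_pretrained('./Qwen2.5-VL-7B')
processor.save_pretrained(output_dir)
logging.info(f"Processor saved to {output_dir}")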