293 lines
9.0 KiB
Python
293 lines
9.0 KiB
Python
import os
|
||
# Environment setup
|
||
#os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
|
||
import logging
|
||
import pickle
|
||
import random
|
||
import torch
|
||
torch.set_float32_matmul_precision('high')
|
||
from datetime import datetime
|
||
from torch.utils.data import Dataset, DataLoader
|
||
from transformers import AutoTokenizer
|
||
from trl import GRPOTrainer, GRPOConfig
|
||
from peft import LoraConfig
|
||
from predata import load_image
|
||
from Internvl2_5.conversation import get_conv_template
|
||
from Internvl2_5.modeling_internvl_chat import InternVLChatModel
|
||
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration
|
||
from accelerate import Accelerator
|
||
from accelerate.utils import set_seed
|
||
|
||
# Module-level default device: "cuda" when torch reports CUDA as available,
# otherwise "cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
|
||
# Logging configuration
|
||
def setup_logging():
    """Configure the root logger to write to a timestamped log file.

    The file is named ``training_log_<YYYYmmdd_HHMMSS>.txt`` (timestamp taken
    when this function runs) and, because no directory is given, is created
    relative to the current working directory. Records at INFO level and
    above are written as ``<time> - <message>``.
    """
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    basic_config_options = {
        'filename': f'training_log_{stamp}.txt',
        'level': logging.INFO,
        'format': '%(asctime)s - %(message)s',
        'datefmt': '%Y-%m-%d %H:%M:%S',
    }
    logging.basicConfig(**basic_config_options)
|
||
|
||
class PromptWithImage:
    """Prompt text bundled with the image tensor it refers to.

    ``str()`` and ``repr()`` both return only the text, so code that treats
    the prompt as a string sees the text while ``pixel_values`` travels along
    on the same object.
    """

    def __init__(self, text, pixel_values):
        self.text, self.pixel_values = text, pixel_values

    def __str__(self):
        return self.text

    # repr() mirrors str(): both expose just the prompt text.
    __repr__ = __str__
|
||
|
||
class ChatDataset(Dataset):
    """Dataset yielding one brand-classification prompt and its label per sample.

    Each element of ``data`` is indexed with the keys ``'image'`` and
    ``'brand'``. ``all_brands`` is the list of candidate brand names written
    into every prompt.
    """

    def __init__(self, data, tokenizer, batch_size, all_brands):
        # tokenizer and batch_size are stored but not read by any method of
        # this class.
        self.data = data
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.all_brands = all_brands

    def __len__(self):
        return len(self.data)

    def create_prompt(self, pixel_values):
        """Build the multiple-choice prompt and attach ``pixel_values`` to it."""
        header = "<image>\nPlease identify the most appropriate brand from the following options based on the image content:\n"
        # One candidate brand per line, in the order given by all_brands.
        options = "".join(f"{brand}\n" for brand in self.all_brands)
        footer = "Please only respond with the brand name, no additional explanation needed."
        return PromptWithImage(header + options + footer, pixel_values)

    def __getitem__(self, idx):
        """Return ``{"prompt": PromptWithImage, "correct_brand": label}`` for ``idx``."""
        record = self.data[idx]
        image_tensor = load_image(record['image'], max_num=12)
        return {
            "prompt": self.create_prompt(image_tensor.to(dtype=torch.bfloat16)),
            "correct_brand": record['brand'],
        }
|
||
|
||
|
||
# Reward function
|
||
def reward_func(prompts, completions, **kwargs):
    """Score each completion 1.0 if it mentions the expected brand, else 0.0.

    The check is a case-insensitive substring match of the expected brand
    name inside the completion text.

    Args:
        prompts: The prompts the completions belong to; only their count is
            printed.
        completions: Generated texts (strings), one per sample.
        **kwargs: Must contain ``correct_brand``: the expected brand names,
            aligned one-to-one with ``completions``.

    Returns:
        A list of floats (1.0 or 0.0) with one entry per completion.

    Raises:
        ValueError: If ``correct_brand`` is missing, or its length differs
            from the number of completions.
    """
    # Debug output
    print("prompts length:", len(prompts))
    print("completions length:", len(completions))
    print("kwargs keys:", kwargs.keys())

    correct_brands = kwargs.get('correct_brand')
    if correct_brands is None:
        # Fail with a clear message instead of a TypeError from len(None).
        raise ValueError(
            "reward_func requires a 'correct_brand' keyword argument "
            f"(got keys: {sorted(kwargs)})"
        )
    print("correct_brands length:", len(correct_brands))

    if len(correct_brands) != len(completions):
        # zip() would silently truncate and return too few rewards.
        raise ValueError(
            f"got {len(completions)} completions but "
            f"{len(correct_brands)} correct_brand labels"
        )

    rewards = []
    for completion, correct_brand in zip(completions, correct_brands):
        expected = correct_brand.lower()
        # Simple check: does the brand name occur in the answer (case-insensitive)?
        # An empty label is a substring of every string, so it must never match.
        correct = bool(expected) and expected in completion.lower()

        # Debug output
        print("completion:", completion)
        print("correct_brand:", correct_brand)
        print("is_correct:", correct)

        rewards.append(float(correct))

    print("rewards length:", len(rewards))
    return rewards
|
||
|
||
|
||
# Model configuration
|
||
def get_model_config():
    """Return the keyword arguments used when loading the model.

    A new dict is built on every call. It requests bfloat16 weights, low CPU
    memory usage, flash attention and remote-code trust, and carries explicit
    ``None`` entries for ``vision_model`` and ``language_model``.
    """
    config = dict(
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
        trust_remote_code=True,
    )
    # Explicit None placeholders for the two sub-model arguments.
    config.update(vision_model=None, language_model=None)
    return config
|
||
|
||
# LoRA configuration
|
||
def get_lora_config():
    """Return the LoRA configuration applied to the model.

    Adapters of rank 4 (alpha 16, dropout 0.1, no bias training) are attached
    to the modules named ``mlp1.1`` and ``mlp1.3``.

    ``target_modules`` is given as a plain list of module names. PEFT
    documents this argument as a name list or a regex string; the earlier
    nested per-module dicts were not part of that API and only repeated the
    global ``r`` / ``lora_alpha`` / ``lora_dropout`` values. Per-module
    overrides, if ever needed, belong in ``rank_pattern`` / ``alpha_pattern``.
    """
    return LoraConfig(
        task_type="CAUSAL_LM",
        r=4,
        lora_alpha=16,
        lora_dropout=0.1,
        # NOTE(review): presumably the two Linear layers of the `mlp1`
        # projector - confirm against InternVLChatModel. The attention
        # projections (q_proj / k_proj) are currently not adapted.
        target_modules=["mlp1.1", "mlp1.3"],
        bias="none",
    )
|
||
|
||
# Training configuration
|
||
def get_training_args():
    """Build the ``GRPOConfig`` used for the training run.

    All hyper-parameters are fixed here: 4 generations per prompt, learning
    rate 1e-5, per-device batch size 2 with 4 gradient-accumulation steps,
    completions capped at 50 tokens, and 1000 optimisation steps. Output goes
    to ``chat_grpo_output``.
    """
    grpo_options = {
        "output_dir": "chat_grpo_output",
        "num_generations": 4,
        "learning_rate": 1e-5,
        "logging_steps": 100,
        "max_prompt_length": None,
        "gradient_accumulation_steps": 4,
        "max_completion_length": 50,
        "per_device_train_batch_size": 2,
        "max_steps": 1000,
        "dataloader_pin_memory": False,  # pin_memory disabled
    }
    return GRPOConfig(**grpo_options)
|
||
|
||
class ImagePromptProcessor:
    """Tokenizer stand-in that also accepts image-bearing prompts.

    The object is callable like a tokenizer. A batch of ``PromptWithImage``
    objects is encoded through ``model.for_grpo``; any other batch is passed
    to the wrapped tokenizer. Every other attribute access is served by the
    tokenizer.
    """

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        # Copy every non-dunder tokenizer attribute onto this object so it
        # can be used wherever the tokenizer itself is expected.
        for attr_name in dir(tokenizer):
            if attr_name.startswith('__'):
                continue
            try:
                setattr(self, attr_name, getattr(tokenizer, attr_name))
            except AttributeError:
                # Best effort: attributes listed by dir() but not readable
                # are skipped.
                pass

    def __getattr__(self, name):
        """Fall back to the wrapped tokenizer for attributes not found here.

        Raises:
            AttributeError: If no tokenizer is attached yet or the tokenizer
                lacks ``name``.
        """
        # Read the tokenizer from the instance dict, not via self.tokenizer:
        # on an instance created without __init__ (copy, pickle, __new__)
        # that lookup would re-enter __getattr__ and recurse without bound.
        try:
            tokenizer = self.__dict__['tokenizer']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(tokenizer, name)

    def __call__(self, prompts, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False, **kwargs):
        """Encode a batch of prompts.

        Args:
            prompts: Either a sequence of ``PromptWithImage`` objects or
                anything the wrapped tokenizer accepts.
            return_tensors, padding, padding_side, add_special_tokens:
                Passed to the tokenizer for plain-text input; not used for
                image prompts.
            **kwargs: Further tokenizer options, passed through for
                plain-text input.

        Returns:
            For image prompts, a dict with ``"input_ids"`` and
            ``"attention_mask"`` as returned by ``model.for_grpo``; otherwise
            whatever the tokenizer returns.
        """
        # len() guard: an empty batch has no first element to inspect.
        if len(prompts) > 0 and isinstance(prompts[0], PromptWithImage):
            # Stack all image patches along dim 0 and remember how many
            # belong to each prompt.
            pixel_values = torch.cat([p.pixel_values for p in prompts], dim=0)
            texts = [str(p) for p in prompts]
            num_patches_list = [p.pixel_values.shape[0] for p in prompts]

            # Let the model's for_grpo helper produce prompt_ids and the
            # attention mask.
            prompt_ids, prompt_mask = self.model.for_grpo(
                self.tokenizer,
                pixel_values,
                texts,
                num_patches_list=num_patches_list,
                history=None,
                return_history=False,
            )

            return {
                "input_ids": prompt_ids,
                "attention_mask": prompt_mask,
            }

        # Plain text: defer to the tokenizer, including any extra options.
        return self.tokenizer(
            prompts,
            return_tensors=return_tensors,
            padding=padding,
            padding_side=padding_side,
            add_special_tokens=add_special_tokens,
            **kwargs,
        )
|
||
|
||
# Model subclass that adds a generate method
|
||
class CustomInternVLModel(InternVLChatModel):
    """InternVL chat model with forward/generate routed to ``language_model``.

    ``forward`` first tries the parent implementation and falls back to
    ``self.language_model`` on argument-style errors; ``generate`` always
    delegates to ``self.language_model.generate``.
    """

    def forward(self, *args, **kwargs):
        try:
            # First try the parent model's own forward.
            outputs = super().forward(*args, **kwargs)
        except (TypeError, ValueError, AttributeError) as err:
            # On a signature mismatch (or another of these errors) run the
            # language model directly with the same arguments.
            # NOTE(review): these exception types are broad and may also hide
            # genuine bugs raised inside the parent forward.
            print(f"切换到 language_model 进行前向传播: {err}")
            outputs = self.language_model(*args, **kwargs)
        return outputs

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask=None,
        **kwargs
    ):
        # NOTE(review): only input_ids, attention_mask and the caller's extra
        # kwargs reach the language model; this method adds no pixel_values
        # itself - confirm that is intended.
        print("切换到 language_model 进行生成")
        language_model = self.language_model
        return language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
|
||
|
||
def main():
    """Run GRPO fine-tuning with LoRA and save the resulting model.

    Steps: create the accelerator, seed the RNGs, set up file logging, load
    the pickled dataset, load the model plus the ``vit_mlp_epoch_15``
    checkpoint, wrap the tokenizer in ``ImagePromptProcessor``, train with
    ``GRPOTrainer`` and write the trained model to ``chat_model_lora_3``.
    """
    # Initialise the accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=4,
        mixed_precision="bf16",
    )

    # Fix the random seed
    set_seed(42)

    # Set up logging
    setup_logging()

    # Load the data
    logging.info("正在加载数据...")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only use a trusted, locally produced dataset here.
    with open('bal_data/frequent_brands_data.pkl', 'rb') as f:
        data = pickle.load(f)
    logging.info("成功加载数据")

    # Load the model through the custom subclass
    path = 'Internvl2_5'
    model = CustomInternVLModel.from_pretrained(path, **get_model_config()).train()
    model.name_or_path = 'CustomInternVLModel'

    # Load the pre-trained weights
    print("正在加载预训练权重 vit_mlp_epoch_15.pth ...")
    # map_location='cpu' keeps the checkpoint off whatever device it was
    # saved from; load_state_dict then copies the values into the model's
    # existing parameters.
    state_dict = torch.load('weights/vit_mlp_epoch_15.pth', map_location='cpu')
    model.load_state_dict(state_dict)
    print("成功加载预训练权重")

    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

    # Create the processor
    processor = ImagePromptProcessor(tokenizer, model)

    # Create the dataset; candidate brands are the sorted unique labels
    all_brands = sorted({item['brand'] for item in data})
    dataset = ChatDataset(data, tokenizer, 1, all_brands)

    # Create the trainer (stock GRPOTrainer)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=get_training_args(),
        train_dataset=dataset,
        processing_class=processor,
        peft_config=get_lora_config(),
    )

    # The trainer is not passed through accelerator.prepare(): prepare()
    # handles models, optimizers, dataloaders and schedulers, and a trainer
    # object is none of these.

    # Start training
    logging.info("开始训练...")
    trainer.train()
    logging.info("训练完成")

    # Save the model: let every process finish training first, then write
    # from the main process only so ranks do not write the same files.
    output_dir = "chat_model_lora_3"
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(trainer.model)
        unwrapped_model.save_pretrained(output_dir)
        logging.info(f"模型已保存到 {output_dir}")
|
||
|
||
# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|