293 lines
9.0 KiB
Python
293 lines
9.0 KiB
Python
|
import os
|
|||
|
# 环境设置
|
|||
|
#os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
|
|||
|
import logging
|
|||
|
import pickle
|
|||
|
import random
|
|||
|
import torch
|
|||
|
torch.set_float32_matmul_precision('high')
|
|||
|
from datetime import datetime
|
|||
|
from torch.utils.data import Dataset, DataLoader
|
|||
|
from transformers import AutoTokenizer
|
|||
|
from trl import GRPOTrainer, GRPOConfig
|
|||
|
from peft import LoraConfig
|
|||
|
from predata import load_image
|
|||
|
from Internvl2_5.conversation import get_conv_template
|
|||
|
from Internvl2_5.modeling_internvl_chat import InternVLChatModel
|
|||
|
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration
|
|||
|
from accelerate import Accelerator
|
|||
|
from accelerate.utils import set_seed
|
|||
|
|
|||
|
# Default compute device: CUDA when torch reports a usable GPU, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|||
|
|
|||
|
# Logging configuration
def setup_logging():
    """Configure the root logger to write INFO-level records to a new file.

    The file is created in the current working directory and is named
    ``training_log_<YYYYmmdd_HHMMSS>.txt`` after the local time of the call.
    Each record is formatted as ``<timestamp> - <message>``.

    ``logging.basicConfig`` does nothing when the root logger already has
    handlers, so this should be called before anything else configures
    logging.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filename = f'training_log_{timestamp}.txt'

    logging.basicConfig(
        filename=log_filename,
        # The script logs non-ASCII (Chinese) messages; pin the file encoding
        # so the output does not depend on the platform's default locale.
        # (The ``encoding`` argument requires Python 3.9+.)
        encoding='utf-8',
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
|
|||
|
|
|||
|
class PromptWithImage:
    """Prompt text bundled with the image data it refers to.

    Attributes:
        text: The prompt string.
        pixel_values: The image payload stored alongside the text; it is kept
            as given and never inspected by this class.

    Both ``str()`` and ``repr()`` of an instance return the bare prompt text.
    """

    def __init__(self, text, pixel_values):
        self.text, self.pixel_values = text, pixel_values

    def __str__(self):
        return self.text

    # repr() mirrors str(): both yield the raw prompt text.
    __repr__ = __str__
|
|||
|
|
|||
|
class ChatDataset(Dataset):
    """Map-style dataset producing one brand-identification prompt per record.

    Args:
        data: Indexable collection of records; each record is read through
            ``item['image']`` and ``item['brand']``.
        tokenizer: Stored on the instance; not used by this class itself.
        batch_size: Stored on the instance; not used by this class itself.
        all_brands: Iterable of brand names listed as answer options in
            every prompt, in iteration order.
    """

    _PROMPT_HEADER = (
        "<image>\nPlease identify the most appropriate brand from the "
        "following options based on the image content:\n"
    )
    _PROMPT_FOOTER = (
        "Please only respond with the brand name, no additional explanation needed."
    )

    def __init__(self, data, tokenizer, batch_size, all_brands):
        self.data = data
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.all_brands = all_brands

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def _build_prompt_text(self):
        """Return the prompt string: header, one brand per line, footer."""
        # str.join builds the option list in one pass instead of growing the
        # string with repeated ``+=`` inside a loop.
        options = "".join(f"{brand}\n" for brand in self.all_brands)
        return f"{self._PROMPT_HEADER}{options}{self._PROMPT_FOOTER}"

    def create_prompt(self, pixel_values):
        """Wrap the prompt text and ``pixel_values`` in a ``PromptWithImage``."""
        return PromptWithImage(self._build_prompt_text(), pixel_values)

    def __getitem__(self, idx):
        """Return ``{"prompt": PromptWithImage, "correct_brand": ...}`` for ``idx``.

        The image referenced by ``item['image']`` is loaded with
        ``load_image(..., max_num=12)`` and cast to bfloat16.
        """
        item = self.data[idx]
        # load_image comes from the project's predata module; presumably it
        # returns a tensor of image tiles -- confirm against that module.
        pixel_values = load_image(item['image'], max_num=12).to(dtype=torch.bfloat16)
        return {
            "prompt": self.create_prompt(pixel_values),
            "correct_brand": item['brand'],
        }
|
|||
|
|
|||
|
|
|||
|
# Reward function
def reward_func(prompts, completions, **kwargs):
    """Score each completion 1.0 if it mentions the expected brand, else 0.0.

    A completion counts as correct when the expected brand name occurs
    anywhere inside it, compared case-insensitively (plain substring test).

    Args:
        prompts: The prompts the completions were generated for; only their
            count is used (for debug logging).
        completions: Sequence of generated strings.
        **kwargs: Must contain ``correct_brand``, a sequence of expected brand
            names aligned one-to-one with ``completions``. Other keys are
            ignored.

    Returns:
        A list of floats (1.0 or 0.0), one per completion.

    Raises:
        ValueError: If ``correct_brand`` is missing, or its length differs
            from ``completions`` (``zip`` would otherwise silently drop the
            surplus and return too few rewards).
    """
    correct_brands = kwargs.get('correct_brand')
    if correct_brands is None:
        raise ValueError("reward_func requires a 'correct_brand' keyword argument")
    if len(correct_brands) != len(completions):
        raise ValueError(
            f"got {len(completions)} completions but "
            f"{len(correct_brands)} correct_brand entries"
        )

    # Diagnostics go through logging at DEBUG level rather than print(), so
    # per-sample output does not flood stdout on every reward computation.
    logging.debug(
        "prompts length: %d, completions length: %d, kwargs keys: %s",
        len(prompts), len(completions), list(kwargs),
    )

    rewards = []
    for completion, correct_brand in zip(completions, correct_brands):
        # Simple check: does the brand name appear in the answer (case-insensitive)?
        is_correct = correct_brand.lower() in completion.lower()
        logging.debug(
            "completion: %s | correct_brand: %s | is_correct: %s",
            completion, correct_brand, is_correct,
        )
        rewards.append(float(is_correct))
    return rewards
|
|||
|
|
|||
|
|
|||
|
# Model configuration
def get_model_config():
    """Return the keyword arguments used when loading the model.

    ``main`` unpacks this dict into ``CustomInternVLModel.from_pretrained``.
    A new dict is built on every call.
    """
    config = dict(
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
        trust_remote_code=True,
    )
    # Both sub-model slots are passed explicitly as None.
    config.update(vision_model=None, language_model=None)
    return config
|
|||
|
|
|||
|
# LoRA configuration
def get_lora_config():
    """Build the LoRA adapter configuration for the trainer.

    Adapters (rank 4, alpha 16, dropout 0.1, no bias training) are attached
    to the modules whose names match ``mlp1.1`` and ``mlp1.3``.

    Returns:
        A ``peft.LoraConfig`` instance.
    """
    return LoraConfig(
        task_type="CAUSAL_LM",
        r=4,
        lora_alpha=16,
        lora_dropout=0.1,
        # PEFT documents target_modules as a list of module names (or a regex
        # string). The previous nested-dict form only matched because
        # iterating a dict yields its keys; the per-module r/alpha/dropout
        # values inside it repeated the global settings above. Per-module
        # overrides, if ever needed, belong in rank_pattern / alpha_pattern.
        # Add "q_proj" / "k_proj" here to also adapt the attention projections.
        target_modules=["mlp1.1", "mlp1.3"],
        bias="none",
    )
|
|||
|
|
|||
|
# Training configuration
def get_training_args():
    """Build the ``GRPOConfig`` handed to ``GRPOTrainer`` in ``main``.

    NOTE(review): gradient_accumulation_steps=4 repeats the value passed to
    ``Accelerator`` in ``main``; keep the two in sync if either changes.
    """
    settings = {
        "output_dir": "chat_grpo_output",
        "num_generations": 4,
        "learning_rate": 1e-5,
        "logging_steps": 100,
        "max_prompt_length": None,
        "gradient_accumulation_steps": 4,
        "max_completion_length": 50,
        "per_device_train_batch_size": 2,
        "max_steps": 1000,
        "dataloader_pin_memory": False,  # disable pin_memory
    }
    return GRPOConfig(**settings)
|
|||
|
|
|||
|
class ImagePromptProcessor:
    """Tokenizer wrapper that also accepts ``PromptWithImage`` prompts.

    Called with a batch of ``PromptWithImage`` objects, it delegates to
    ``model.for_grpo`` to obtain the prompt ids and attention mask. Called
    with anything else, it forwards the call to the wrapped tokenizer. All
    other attribute access falls through to the tokenizer.

    Args:
        tokenizer: The tokenizer to wrap.
        model: Object exposing ``for_grpo`` (a project-specific method; its
            contract is not visible here -- it is unpacked below as a
            ``(prompt_ids, prompt_mask)`` pair).
    """

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

        # Mirror the tokenizer's attributes onto this wrapper. Values are
        # snapshotted at construction time. The wrapper's own fields are
        # skipped so a tokenizer attribute named "tokenizer" or "model" can
        # never overwrite them.
        own_fields = set(vars(self))
        for attr_name in dir(tokenizer):
            # Skip dunder attributes and methods.
            if attr_name.startswith('__') or attr_name in own_fields:
                continue
            try:
                setattr(self, attr_name, getattr(tokenizer, attr_name))
            except AttributeError:
                # Best effort: attributes that cannot be read or set are
                # still reachable through __getattr__.
                pass

    def __getattr__(self, name):
        # Only invoked when normal lookup fails. Read the tokenizer straight
        # from the instance dict: going through ``self.tokenizer`` would
        # re-enter __getattr__ and recurse forever while the instance is only
        # partially initialised (e.g. during copy or unpickling).
        try:
            tokenizer = self.__dict__['tokenizer']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(tokenizer, name)

    def __call__(self, prompts, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False, **kwargs):
        """Tokenize a batch of prompts.

        Args:
            prompts: Sequence of ``PromptWithImage`` objects, or anything the
                wrapped tokenizer accepts. The type of the first element
                selects the path.
            return_tensors, padding, padding_side, add_special_tokens:
                Forwarded to the tokenizer on the text path; not used on the
                image path.
            **kwargs: Forwarded to the tokenizer on the text path.

        Returns:
            On the image path, a dict with ``"input_ids"`` and
            ``"attention_mask"``; on the text path, whatever the tokenizer
            returns.
        """
        # An empty batch has no first element to inspect; treat it as text.
        if prompts and isinstance(prompts[0], PromptWithImage):
            # Stack every prompt's tiles along dim 0 and record how many
            # entries each prompt contributed so they can be told apart.
            pixel_values = torch.cat([p.pixel_values for p in prompts], dim=0)
            texts = [str(p) for p in prompts]
            num_patches_list = [p.pixel_values.shape[0] for p in prompts]

            # Obtain prompt ids and attention mask from the model's for_grpo.
            prompt_ids, prompt_mask = self.model.for_grpo(
                self.tokenizer,
                pixel_values,
                texts,
                num_patches_list=num_patches_list,
                history=None,
                return_history=False,
            )
            return {
                "input_ids": prompt_ids,
                "attention_mask": prompt_mask,
            }

        # Plain text: defer to the wrapped tokenizer, passing extra keyword
        # arguments through instead of silently discarding them.
        return self.tokenizer(
            prompts,
            return_tensors=return_tensors,
            padding=padding,
            padding_side=padding_side,
            add_special_tokens=add_special_tokens,
            **kwargs,
        )
|
|||
|
|
|||
|
# Model subclass that overrides forward and generate
class CustomInternVLModel(InternVLChatModel):
    """InternVL chat model with a language-model fallback.

    ``forward`` first tries the parent implementation and falls back to the
    wrapped ``language_model`` on failure; ``generate`` always delegates to
    ``language_model.generate``.
    """

    def forward(self, *args, **kwargs):
        """Run the parent forward, falling back to ``language_model``.

        NOTE(review): any TypeError, ValueError or AttributeError raised
        anywhere inside the parent forward triggers the fallback, not just
        argument mismatches, so genuine bugs there are only visible through
        the printed message.
        """
        try:
            # Try the model's own forward first.
            outputs = super().forward(*args, **kwargs)
        except (TypeError, ValueError, AttributeError) as err:
            # On an argument mismatch or similar error, use language_model.
            print(f"切换到 language_model 进行前向传播: {err}")
            return self.language_model(*args, **kwargs)
        return outputs

    @torch.no_grad()
    def generate(self, input_ids, attention_mask=None, **kwargs):
        """Generate with the wrapped language model, gradients disabled.

        NOTE(review): no image tensor is passed explicitly here; only
        ``input_ids``, ``attention_mask`` and any extra keyword arguments
        reach ``language_model.generate`` -- confirm this is intended.
        """
        print("切换到 language_model 进行生成")
        language_model = self.language_model
        return language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
|
|||
|
|
|||
|
def main():
    """Run GRPO fine-tuning end to end.

    Steps: create the accelerator, seed the RNGs, configure logging, load the
    pickled dataset, load the model and its pretrained weights, build the
    tokenizer / prompt processor / dataset, train with ``GRPOTrainer`` and
    save the resulting model to ``chat_model_lora_3``.
    """
    # Initialise the accelerator.
    accelerator = Accelerator(
        gradient_accumulation_steps=4,
        mixed_precision="bf16",
    )

    # Seed the random number generators.
    set_seed(42)

    # Configure logging.
    setup_logging()

    # Load the data.
    logging.info("正在加载数据...")
    # NOTE(security): unpickling executes arbitrary code embedded in the
    # file; only load dataset files from a trusted source.
    with open('bal_data/frequent_brands_data.pkl', 'rb') as f:
        data = pickle.load(f)
    logging.info("成功加载数据")

    # Load the model through the custom subclass.
    path = 'Internvl2_5'
    model = CustomInternVLModel.from_pretrained(path, **get_model_config()).train()
    model.name_or_path = 'CustomInternVLModel'

    # Load the pretrained weights.
    print("正在加载预训练权重 vit_mlp_epoch_15.pth ...")
    # map_location='cpu' keeps deserialisation independent of the device the
    # checkpoint was saved from; load_state_dict then copies the tensors into
    # the model's existing parameters.
    state_dict = torch.load('weights/vit_mlp_epoch_15.pth', map_location='cpu')
    model.load_state_dict(state_dict)
    print("成功加载预训练权重")

    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

    # Create the processor.
    processor = ImagePromptProcessor(tokenizer, model)

    # Create the dataset; the answer options are the distinct brands, sorted.
    all_brands = sorted({item['brand'] for item in data})
    dataset = ChatDataset(data, tokenizer, 1, all_brands)

    # Create the trainer (the stock GRPOTrainer).
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=get_training_args(),
        train_dataset=dataset,
        processing_class=processor,
        peft_config=get_lora_config(),
    )

    # NOTE(review): Accelerator.prepare is documented for models, optimizers,
    # dataloaders and schedulers; confirm it has any effect on a trainer
    # object before relying on it.
    trainer = accelerator.prepare(trainer)

    # Train.
    logging.info("开始训练...")
    trainer.train()
    logging.info("训练完成")

    # Save the model.
    output_dir = "chat_model_lora_3"
    unwrapped_model = accelerator.unwrap_model(trainer.model)
    unwrapped_model.save_pretrained(output_dir)
    logging.info(f"模型已保存到 {output_dir}")


if __name__ == "__main__":
    main()
|