Fortrain/zhtrain.py
2025-02-13 17:15:57 +08:00

235 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
import pickle
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from predata import prepare_training_data
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
from torch.optim import AdamW
# Load the tokenizer with the same options as the reload that follows model
# creation further down, so the dataset (which keeps a reference to this
# object) and the later <IMG_CONTEXT> id lookup agree on the vocabulary.
# NOTE(review): trust_remote_code=True runs tokenizer code shipped with the
# checkpoint directory; only point this at a checkpoint you trust.
tokenizer = AutoTokenizer.from_pretrained(
    "Internvl2_5", trust_remote_code=True, use_fast=False
)

# Load the preprocessed product records.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a valid_products.pkl produced by a trusted, local pipeline.
print("正在加载数据...")
with open('valid_products.pkl', 'rb') as f:
    data = pickle.load(f)
print(f"成功加载数据,共 {len(data)} 条记录")
class ProductDataset(Dataset):
    """Dataset yielding one image/question/answer training sample per record.

    Every record is read through the keys ``'name'`` and ``'image'``.
    Tokenisation and image handling are delegated to
    ``prepare_training_data``; whatever it returns is returned unchanged.
    """

    def __init__(self, data, tokenizer):
        """Keep references to the record collection and the tokenizer."""
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the training sample for the record at position ``idx``."""
        record = self.data[idx]
        prompt = "Tell me the product name in the picture."
        target = "The product name is: " + record['name']
        # NOTE(review): the original comment states that record['image'] is a
        # PIL Image object even though the keyword is named image_path --
        # confirm against predata.prepare_training_data.
        return prepare_training_data(
            question=prompt,
            answer=target,
            tokenizer=self.tokenizer,
            image_path=record['image'],
        )
# Build the dataset from the loaded records.
dataset = ProductDataset(data, tokenizer)

# Smoke test: fetch the first sample and report the shapes of its tensors.
sample = dataset[0]
print("\n数据样例:")
print(
    f"输入形状: {sample['input_ids'].shape}",
    f"图片张量形状: {sample['pixel_values'].shape}",
    f"标签形状: {sample['labels'].shape}",
    sep="\n",
)
# Select the compute device once; the model and every batch are moved to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model in bfloat16 and put it in training mode. It is moved with
# .to(device) rather than .cuda() so the placement follows the device chosen
# above instead of being hard-coded.
path = 'Internvl2_5'
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,
    vision_model=None,
    language_model=None,
).train().to(device)
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# Resume from the earlier checkpoint. The tensors are deserialised on the CPU
# and copied into the model's own parameters by load_state_dict, so no second
# copy of the weights is materialised on the accelerator.
print("正在加载预训练权重 model_epoch_5.pth ...")
model.load_state_dict(torch.load('model_epoch_5.pth', map_location='cpu'))
print("成功加载预训练权重")
def prepare_model_for_training(model, cast_trainable_params_to_fp32=True):
    """Freeze the vision and language sub-models and report what stays trainable.

    Args:
        model: Object exposing ``vision_model`` and ``language_model``
            sub-modules plus a ``parameters()`` method.
        cast_trainable_params_to_fp32: Accepted for compatibility; it has no
            effect because the fp32 cast it used to control is disabled.

    Returns:
        The same ``model`` object, modified in place.
    """
    print("正在设置模型参数...")

    # Freeze the vision model first, then the language model.
    for weight in model.vision_model.parameters():
        weight.requires_grad = False
    for weight in model.language_model.parameters():
        weight.requires_grad = False

    # Count the elements of every parameter that still requires gradients.
    trainable_params = 0
    for weight in model.parameters():
        if weight.requires_grad:
            trainable_params += weight.numel()
    print(f"可训练参数数量: {trainable_params}")

    return model
# Freeze the backbones before the optimizer is built.
model = prepare_model_for_training(model)

# Hand the optimizer only the parameters that still require gradients.
optimizer = AdamW(
    list(filter(lambda p: p.requires_grad, model.parameters())),
    lr=1e-4,  # learning rate
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
)
def custom_collate_fn(batch, *, pad_token_id=0, label_pad_id=-100):
    """Collate variable-length samples into right-padded batch tensors.

    Each sample is a mapping with ``'input_ids'``, ``'attention_mask'`` and
    ``'labels'`` tensors that are indexed as ``[0, :seq_len]`` (a leading
    axis of size one followed by the sequence axis), plus a
    ``'pixel_values'`` tensor.

    Args:
        batch: Non-empty list of samples.
        pad_token_id: Value written into the padded positions of
            ``input_ids``. Defaults to 0, the value used previously.
        label_pad_id: Value written into the padded positions of ``labels``.
            Defaults to -100, the value used previously.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` of shape
        ``(len(batch), max_len)`` and ``pixel_values`` stacked along a new
        leading axis.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    # The sequence length is taken from input_ids and applied to all three
    # text tensors of the same sample.
    seq_lens = [item['input_ids'].size(1) for item in batch]
    padded_shape = (len(batch), max(seq_lens))

    # Output dtypes follow the first sample.
    first = batch[0]
    input_ids = torch.full(padded_shape, pad_token_id, dtype=first['input_ids'].dtype)
    attention_mask = torch.zeros(padded_shape, dtype=first['attention_mask'].dtype)
    labels = torch.full(padded_shape, label_pad_id, dtype=first['labels'].dtype)

    # Copy each sample into the left part of its row; the rest stays padding.
    for row, (item, seq_len) in enumerate(zip(batch, seq_lens)):
        input_ids[row, :seq_len] = item['input_ids'][0, :seq_len]
        attention_mask[row, :seq_len] = item['attention_mask'][0, :seq_len]
        labels[row, :seq_len] = item['labels'][0, :seq_len]

    # torch.stack requires every pixel_values tensor to have the same shape.
    pixel_values = torch.stack([item['pixel_values'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'pixel_values': pixel_values,
    }
# Batch the dataset; custom_collate_fn pads the variable-length samples.
train_loader = DataLoader(
    dataset,
    collate_fn=custom_collate_fn,
    batch_size=5,  # adjust as needed
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
)

# Id of the placeholder token whose embeddings are later replaced by image
# features in the training loop.
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
def extract_feature(model, pixel_values):
    """Run images through the vision model and project them with ``mlp1``.

    Args:
        model: Object exposing ``vision_model``, ``pixel_shuffle``,
            ``downsample_ratio`` and ``mlp1``.
        pixel_values: Image tensor. A 5-D input has its first two axes merged
            before the vision model is called; any other rank is passed
            through unchanged.
            NOTE(review): the original comments describe the 5-D case as
            [batch_size, 1, 3, 448, 448] -- confirm against the data
            pipeline.

    Returns:
        The output of ``model.mlp1`` applied to the reshaped vision features,
        with one row per image handed to the vision model.
    """
    # Merge the two leading axes. For a second axis of size one this equals
    # the former squeeze(1); with several images per sample it also yields
    # the 4-D input that squeeze(1) would have left 5-D.
    if pixel_values.dim() == 5:
        pixel_values = pixel_values.flatten(0, 1)

    vision_out = model.vision_model(
        pixel_values=pixel_values,
        output_hidden_states=False,
        return_dict=True,
    )
    # Drop the first token of every sequence (the CLS token, per the
    # original comment).
    vit_embeds = vision_out.last_hidden_state[:, 1:, :]

    # Lay the remaining tokens out on a square grid, apply the model's pixel
    # shuffle, then flatten the grid back into a token sequence.
    h = w = int(vit_embeds.shape[1] ** 0.5)
    vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
    vit_embeds = model.pixel_shuffle(vit_embeds, scale_factor=model.downsample_ratio)
    vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])

    return model.mlp1(vit_embeds)
# Training loop.
num_epochs = 40  # changed to 40 epochs
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        # Move the batch tensors to the compute device.
        input_ids, attention_mask, labels, pixel_values = (
            batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'labels', 'pixel_values')
        )

        # Image features for this batch.
        vit_embeds = extract_feature(model, pixel_values)

        # Token embeddings, flattened to one row per token position.
        token_embeds = model.language_model.get_input_embeddings()(input_ids)
        B, N, C = token_embeds.shape
        flat_embeds = token_embeds.reshape(B * N, C)

        # Overwrite the rows belonging to image-context tokens with the
        # image features.
        is_image_slot = input_ids.reshape(B * N) == img_context_token_id
        flat_embeds[is_image_slot] = vit_embeds.reshape(-1, C).to(flat_embeds.device)

        # Forward pass on the restored (B, N, C) embeddings; the language
        # model computes the loss from the labels.
        outputs = model.language_model(
            inputs_embeds=flat_embeds.reshape(B, N, C),
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )
        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Report the loss every 100 batches.
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

    # Report the mean loss of the epoch.
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')

    # Save the model weights every 10 epochs.
    if (epoch + 1) % 10 == 0:
        save_path = f'model_epoch_{epoch+1}.pth'
        print(f"保存模型权重到 {save_path}")
        torch.save(model.state_dict(), save_path)