238 lines
8.1 KiB
Python
238 lines
8.1 KiB
Python
|
from unsloth import FastModel
|
||
|
import torch
|
||
|
from PIL import Image
|
||
|
import pickle
|
||
|
from torch.utils.data import Dataset
|
||
|
import base64
|
||
|
from io import BytesIO
|
||
|
import logging
|
||
|
from trl import GRPOConfig, GRPOTrainer
|
||
|
import re
|
||
|
|
||
|
# Configure root logging once at import time: records at INFO and above are
# written both to training.log and to the console stream.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: the script's log messages are Chinese, and the
        # platform default encoding is not guaranteed to be able to write them.
        logging.FileHandler('training.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
|
||
|
|
||
|
# Maximum sequence length handed to the model loader.
max_seq_length = 1024

# Load the base model and its tokenizer/processor through Unsloth.
logging.info("正在加载模型和tokenizer...")
model, tokenizer = FastModel.from_pretrained(
    model_name="gemma-3-4b",        # base model identifier / path
    max_seq_length=max_seq_length,
    load_in_4bit=False,             # 4-bit quantization disabled
    load_in_8bit=False,             # 8-bit quantization disabled
    full_finetuning=False,          # full fine-tuning off; adapters are attached below
)
logging.info("模型加载完成")

# Attach LoRA adapters. The first group selects which parts of the model get
# adapters (vision + language towers, attention + MLP modules); the second
# group holds the LoRA hyper-parameters.
logging.info("正在配置PEFT参数...")
_lora_targets = dict(
    finetune_vision_layers=True,      # set False for text-only tuning
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
)
_lora_hparams = dict(
    r=8,                 # adapter rank: larger = more capacity, more overfitting risk
    lora_alpha=8,        # scaling factor, kept equal to r
    lora_dropout=0,
    bias="none",
    random_state=3407,   # fixed seed for reproducible adapter init
)
model = FastModel.get_peft_model(model, **_lora_targets, **_lora_hparams)
logging.info("PEFT模型配置完成")
|
||
|
|
||
|
def image_to_base64(image):
    """Encode a PIL image as a JPEG ``data:`` URI.

    Args:
        image: A PIL ``Image`` (any object exposing ``mode``, ``convert`` and
            ``save`` like one).

    Returns:
        A string of the form ``"data:image/jpeg;base64,<payload>"``.
    """
    # JPEG has no alpha channel or palette; Pillow raises OSError when asked to
    # save e.g. RGBA / P / LA images as JPEG. Modes the JPEG writer accepts are
    # left untouched so their output is unchanged; everything else goes to RGB.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    # Base64 output is pure ASCII.
    img_str = base64.b64encode(buffered.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{img_str}"
|
||
|
|
||
|
class ChatDataset(Dataset):
    """Map-style dataset of product images paired with their brand label.

    Each item becomes a chat-format ``prompt`` (a system turn plus a user turn
    carrying the image as a base64 data URI and the instruction text) together
    with the ground-truth brand under ``correct_brand``.
    """

    _SYSTEM_TEXT = "You are a helpful assistant."
    _USER_TEXT = (
        "Please tell me the brand of the product in the picture between labels "
        "<answer/> and </answer> and explain the reason between labels "
        "<thinking/> and </thinking>"
    )

    def __init__(self, data_path):
        # NOTE: unpickling can execute arbitrary code; only load trusted,
        # locally produced data files here.
        with open(data_path, "rb") as fh:
            self.data = pickle.load(fh)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Records are indexed by 'image' and 'brand'.
        record = self.data[idx]
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": image_to_base64(record["image"])},
                {"type": "text", "text": self._USER_TEXT},
            ],
        }
        system_turn = {
            "role": "system",
            "content": [{"type": "text", "text": self._SYSTEM_TEXT}],
        }
        return {
            "prompt": [system_turn, user_turn],  # full template: image + instruction
            "correct_brand": record["brand"],
        }
|
||
|
|
||
|
# Load the pickled training records and report how many were found.
logging.info("加载训练数据...")
_train_data_path = '../work/bal_data/frequent_brands_data.pkl'
train_dataset = ChatDataset(_train_data_path)
logging.info(f"加载了 {len(train_dataset)} 条训练数据")
|
||
|
|
||
|
# 使用示例:
|
||
|
# batch = next(iter(train_dataset))
|
||
|
# messages = batch["prompt"]
|
||
|
# input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||
|
# inputs = tokenizer(input_text, add_special_tokens=False, return_tensors="pt")
|
||
|
|
||
|
# messages = [
|
||
|
# {
|
||
|
# "role": "system",
|
||
|
# "content": [{"type": "text", "text": "You are a helpful assistant."}]
|
||
|
# },
|
||
|
# {
|
||
|
# "role": "user",
|
||
|
# "content": [
|
||
|
# {"type": "image"},
|
||
|
# {"type": "text", "text": "Describe this image in detail."}
|
||
|
# ]
|
||
|
# }
|
||
|
# ]
|
||
|
|
||
|
# input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
|
||
|
# image = Image.open("../work/Tesla.jpg")
|
||
|
# inputs = tokenizer(
|
||
|
# image,
|
||
|
# input_text,
|
||
|
# add_special_tokens = False,
|
||
|
# return_tensors = "pt",
|
||
|
# ).to(model.device, dtype=torch.bfloat16)
|
||
|
|
||
|
# input_len = inputs["input_ids"].shape[-1]
|
||
|
|
||
|
# with torch.inference_mode():
|
||
|
# generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
||
|
# generation = generation[0][input_len:]
|
||
|
|
||
|
# decoded = tokenizer.decode(generation, skip_special_tokens=True)
|
||
|
# print(decoded)
|
||
|
|
||
|
# Tag patterns requested in the user prompt. re.DOTALL lets ``.`` match
# newlines, so multi-line answers / reasoning are still captured.
_ANSWER_PATTERN = re.compile(r'<answer/>(.*?)</answer>', re.DOTALL)
_THINKING_PATTERN = re.compile(r'<thinking/>(.*?)</thinking>', re.DOTALL)


def _completion_to_text(completion):
    """Return the plain generated text contained in one ``completion``.

    Handles a plain string, or a list/tuple whose first element is either a
    string or a message dict whose ``content`` is a string or a list of typed
    parts (only ``{"type": "text"}`` parts are kept). Anything else falls back
    to ``str(...)``.
    """
    if isinstance(completion, (list, tuple)):
        completion = completion[0] if completion else ""
    if isinstance(completion, dict):
        content = completion.get("content", "")
        if isinstance(content, (list, tuple)):
            content = "".join(
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        return str(content)
    return str(completion)


def _thinking_length_reward(length):
    """Map the character length of the reasoning text to (reward, level label)."""
    if length < 50:
        return 0.25, "简单"
    if length < 100:
        return 0.5, "基础"
    if length < 150:
        return 0.75, "详细"
    return 1.0, "非常详细"


def reward_func(prompts, completions, **kwargs):
    """Score each completion on answer correctness and reasoning length.

    Per-completion reward (0.0 to 2.0):
      * +1.0 if the text between ``<answer/>`` and ``</answer>`` contains the
        ground-truth brand (case-insensitive substring match; an empty or
        non-string brand never matches);
      * +0.25 / 0.5 / 0.75 / 1.0 for a ``<thinking/>...</thinking>`` section
        whose stripped content is <50 / <100 / <150 / >=150 characters.
    A completion that raises while being scored receives 0.0.

    Args:
        prompts: Batch prompts (unused; kept for the reward-function call
            signature).
        completions: One generated completion per sample, either a plain
            string or a list of message dicts.
        **kwargs: Extra dataset columns; must include ``correct_brand``, a
            sequence aligned with ``completions``.

    Returns:
        A list of float rewards, one per scored completion.

    Raises:
        ValueError: if ``correct_brand`` is not supplied.
    """
    logging.info("开始计算奖励值...")
    correct_brands = kwargs.get('correct_brand')
    if correct_brands is None:
        # zip() over None would fail with an opaque TypeError; fail clearly.
        raise ValueError("reward_func requires the 'correct_brand' column in kwargs")

    rewards = []
    for idx, (completion, correct_brand) in enumerate(zip(completions, correct_brands)):
        reward = 0.0
        try:
            completion = _completion_to_text(completion)

            logging.debug("样本 %d:", idx + 1)
            logging.debug("完整回答类型: %s", type(completion))
            logging.debug("完整回答: %s", completion)
            logging.debug("正确品牌: %s", correct_brand)

            answer_match = _ANSWER_PATTERN.search(completion)
            thinking_match = _THINKING_PATTERN.search(completion)

            # Answer section: +1.0 when it mentions the ground-truth brand.
            if answer_match:
                answer_content = answer_match.group(1).lower()
                # "" is a substring of every string, so an empty brand must not
                # count as a match; non-string brands are treated as empty.
                brand = correct_brand.strip().lower() if isinstance(correct_brand, str) else ""
                if brand and brand in answer_content:
                    reward += 1.0
                    logging.debug("答案部分匹配正确 (+1.0)")

            # Reasoning section: graded purely by length.
            if thinking_match:
                thinking_content = thinking_match.group(1).strip()
                content_length = len(thinking_content)
                thinking_reward, level = _thinking_length_reward(content_length)
                reward += thinking_reward
                logging.debug("推理部分长度: %d 字符", content_length)
                logging.debug("推理详细程度: %s", level)
                logging.debug("推理部分得分: +%s", thinking_reward)

        except Exception as e:
            # Best effort: a malformed sample scores 0 instead of aborting the
            # training step; the traceback is logged for diagnosis.
            logging.exception("处理样本 %d 时发生错误: %s", idx + 1, e)
            logging.error("completion 类型: %s", type(completion))
            logging.error("completion 内容: %s", completion)
            reward = 0.0

        logging.debug("最终奖励值: %s\n", reward)
        rewards.append(reward)

    batch_avg = sum(rewards) / len(rewards) if rewards else 0
    logging.info(f"批次平均奖励值: {batch_avg:.3f}")
    return rewards
|
||
|
|
||
|
# NOTE(review): max_prompt_length is not passed to GRPOConfig below (nor used
# anywhere else visible), so it currently has no effect. Before wiring it in,
# confirm how the trainer truncates multimodal prompts.
max_prompt_length = 256

# GRPO hyper-parameters, grouped by concern.
training_args = GRPOConfig(
    # Optimizer
    optim="adamw_torch_fused",
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    max_grad_norm=0.1,
    # Learning-rate schedule
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # Batching and generation.
    # NOTE(review): TRL expects the effective batch to be divisible by
    # num_generations; confirm the installed trl/unsloth versions handle
    # batch size 1 with 2 generations.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,   # larger values give smoother updates
    num_generations=2,               # decrease if out of memory
    max_completion_length=512,
    # Run length and checkpointing (a fixed step budget; num_train_epochs unset)
    max_steps=400,
    save_steps=200,
    # Logging and output
    logging_steps=1,
    report_to="none",                # e.g. Weights & Biases could be used instead
    output_dir="outputs",
)
|
||
|
|
||
|
# Echo the key training hyper-parameters to the log before training starts.
logging.info("训练配置信息:")
for _label, _attr in (
    ("学习率", "learning_rate"),
    ("批次大小", "per_device_train_batch_size"),
    ("梯度累积步数", "gradient_accumulation_steps"),
    ("最大训练步数", "max_steps"),
):
    logging.info(f"{_label}: {getattr(training_args, _attr)}")
|
||
|
|
||
|
# Assemble the GRPO trainer from the module-level PEFT model, tokenizer,
# reward function, config and dataset, then run training.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    reward_funcs=reward_func,
)
trainer.train()
|