from unsloth import FastModel
import torch
from PIL import Image
import pickle
from torch.utils.data import Dataset
import base64
from io import BytesIO
import logging
from trl import GRPOConfig, GRPOTrainer
import re
# Configure root logging once, at import time: every record goes both to
# training.log and to the console (stderr).
# The file handler is opened as UTF-8 explicitly because the log messages are
# written in Chinese; relying on the platform's default encoding can make
# writes to the file fail on systems whose locale cannot encode them.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
# ---------------------------------------------------------------------------
# Model + tokenizer, then LoRA adapters on top of them.
# Quantization (4-bit / 8-bit) and full fine-tuning are all disabled here, so
# only the LoRA adapter weights are trained.
# ---------------------------------------------------------------------------
max_seq_length = 1024

logging.info("正在加载模型和tokenizer...")
model, tokenizer = FastModel.from_pretrained(
    model_name="gemma-3-4b",
    max_seq_length=max_seq_length,
    full_finetuning=False,  # adapters only; the base weights stay frozen
    load_in_8bit=False,
    load_in_4bit=False,
)
logging.info("模型加载完成")

logging.info("正在配置PEFT参数...")
model = FastModel.get_peft_model(
    model,
    # LoRA shape: with standard LoRA scaling (alpha / r), alpha == r -> 1.0.
    r=8,
    lora_alpha=8,
    lora_dropout=0,
    bias="none",
    # Which parts of the network receive adapters: both the vision and the
    # language layers, covering attention as well as MLP modules.
    finetune_vision_layers=True,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
    random_state=3407,  # fixed seed for reproducible adapter initialisation
)
logging.info("PEFT模型配置完成")
def image_to_base64(image):
    """Encode an image as a base64 JPEG ``data:`` URL.

    Args:
        image: A PIL ``Image`` (anything exposing ``.mode``, ``.convert()``
            and ``.save(fp, format=...)``).

    Returns:
        str: ``"data:image/jpeg;base64,<payload>"``.
    """
    # JPEG cannot store alpha or palette data. Pillow raises
    # "OSError: cannot write mode RGBA as JPEG" for modes outside this set
    # (e.g. RGBA, LA, P, I, F), so flatten those to RGB before saving.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{img_str}"
class ChatDataset(Dataset):
    """Map-style dataset of (product image, brand) samples for GRPO training.

    The pickle at ``data_path`` must hold an indexable sequence of dicts with
    an ``'image'`` key (a PIL image; it is passed to ``image_to_base64``) and
    a ``'brand'`` key (the expected brand).

    Each item is ``{"prompt": messages, "correct_brand": brand}``, where
    ``messages`` is a system + user chat in which the user turn carries the
    image (as a base64 data URL) followed by the instruction text.
    """

    SYSTEM_PROMPT = "You are a helpful assistant."
    # Keep these tag names in sync with the tags the reward function extracts.
    USER_PROMPT = (
        "Please tell me the brand of the product in the picture between labels "
        "<answer> and </answer> and explain the reason between labels "
        "<thinking> and </thinking>"
    )

    def __init__(self, data_path):
        # NOTE(review): pickle.load can execute arbitrary code while loading;
        # only point this at files produced by a trusted pipeline.
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Embed the image inline as a base64 data URL.
        image_base64 = image_to_base64(item['image'])
        # Chat message layout (same format as dra.py).
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.SYSTEM_PROMPT}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_base64},
                    {"type": "text", "text": self.USER_PROMPT},
                ],
            },
        ]
        return {
            "prompt": messages,  # complete prompt: image plus instruction text
            "correct_brand": item['brand'],
        }
# Build the training set from the pickled brand samples and report its size.
logging.info("加载训练数据...")
train_dataset = ChatDataset(
    "../work/bal_data/frequent_brands_data.pkl"
)
logging.info(f"加载了 {len(train_dataset)} 条训练数据")
# Usage example:
# batch = next(iter(train_dataset))
# messages = batch["prompt"]
# input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# inputs = tokenizer(input_text, add_special_tokens=False, return_tensors="pt")
# messages = [
# {
# "role": "system",
# "content": [{"type": "text", "text": "You are a helpful assistant."}]
# },
# {
# "role": "user",
# "content": [
# {"type": "image"},
# {"type": "text", "text": "Describe this image in detail."}
# ]
# }
# ]
# input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
# image = Image.open("../work/Tesla.jpg")
# inputs = tokenizer(
# image,
# input_text,
# add_special_tokens = False,
# return_tensors = "pt",
# ).to(model.device, dtype=torch.bfloat16)
# input_len = inputs["input_ids"].shape[-1]
# with torch.inference_mode():
# generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
# generation = generation[0][input_len:]
# decoded = tokenizer.decode(generation, skip_special_tokens=True)
# print(decoded)
def reward_func(prompts, completions, **kwargs):
logging.info("开始计算奖励值...")
rewards = []
correct_brands = kwargs.get('correct_brand')
for idx, (completion, correct_brand) in enumerate(zip(completions, correct_brands)):
reward = 0.0
# 确保 completion 是字符串类型
try:
if isinstance(completion, (list, tuple)):
completion = completion[0] # 如果是列表或元组,取第一个元素
completion = str(completion) # 转换为字符串
logging.debug(f"样本 {idx + 1}:")
logging.debug(f"完整回答类型: {type(completion)}")
logging.debug(f"完整回答: {completion}")
logging.debug(f"正确品牌: {correct_brand}")
answer_match = re.search(r'(.*?)', completion)
thinking_match = re.search(r'(.*?)', completion)
# 答案部分仍然检查品牌是否正确
if answer_match:
answer_content = answer_match.group(1).lower()
if correct_brand.lower() in answer_content:
reward += 1.0
logging.debug("答案部分匹配正确 (+1.0)")
# 推理部分根据长度评分
if thinking_match:
thinking_content = thinking_match.group(1).strip()
content_length = len(thinking_content)
if content_length < 50:
thinking_reward = 0.25
level = "简单"
elif content_length < 100:
thinking_reward = 0.5
level = "基础"
elif content_length < 150:
thinking_reward = 0.75
level = "详细"
else:
thinking_reward = 1.0
level = "非常详细"
reward += thinking_reward
logging.debug(f"推理部分长度: {content_length} 字符")
logging.debug(f"推理详细程度: {level}")
logging.debug(f"推理部分得分: +{thinking_reward}")
except Exception as e:
logging.error(f"处理样本 {idx} 时发生错误: {str(e)}")
logging.error(f"completion 类型: {type(completion)}")
logging.error(f"completion 内容: {completion}")
reward = 0.0 # 发生错误时给出0分
logging.debug(f"最终奖励值: {reward}\n")
rewards.append(reward)
batch_avg = sum(rewards)/len(rewards) if rewards else 0
logging.info(f"批次平均奖励值: {batch_avg:.3f}")
return rewards
# ---------------------------------------------------------------------------
# GRPO hyper-parameters, a short summary in the log, then training.
# NOTE(review): max_prompt_length is defined but never passed to GRPOConfig,
# so the trainer uses its own default prompt-length limit. Confirm whether
# that is intended.
# ---------------------------------------------------------------------------
max_prompt_length = 256
training_args = GRPOConfig(
    output_dir="outputs",
    report_to="none",  # e.g. "wandb" to enable experiment tracking
    # Optimizer
    optim="adamw_torch_fused",
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    max_grad_norm=0.1,
    # Learning-rate schedule
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # Batching and sampling
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=2,  # completions sampled per prompt; lower it if out of memory
    max_completion_length=512,
    # Run length, checkpointing and logging (num_train_epochs is not set;
    # max_steps bounds the run)
    max_steps=400,
    save_steps=200,
    logging_steps=1,
)

# Record the key settings before training starts.
logging.info("训练配置信息:")
logging.info(f"学习率: {training_args.learning_rate}")
logging.info(f"批次大小: {training_args.per_device_train_batch_size}")
logging.info(f"梯度累积步数: {training_args.gradient_accumulation_steps}")
logging.info(f"最大训练步数: {training_args.max_steps}")

trainer = GRPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    reward_funcs=reward_func,
)
trainer.train()