# Fortrain/qw/open_r1/grpo_rec.py

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import debugpy
# try:
#     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
#     debugpy.listen(("localhost", 9501))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass
import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional
from PIL import Image
from torch.utils.data import Dataset
from transformers import Qwen2VLForConditionalGeneration
from math_verify import parse, verify
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
from open_r1.vlm_modules import *
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
from transformers import TrainingArguments
import yaml
import json
import random
import math
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
import torch
from typing import Tuple
from transformers.utils import logging

logger = logging.get_logger(__name__)  # used by the deprecation warning inside custom_forward below
def custom_forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    seq_length = hidden_states.shape[0]
    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        cos = emb.cos().float()
        sin = emb.sin().float()
    else:
        cos, sin = position_embeddings
        # Cast cos/sin to float32: this is the actual fix. With bf16/fp16 inputs the
        # rotary embedding path in apply_rotary_pos_emb_flashatt hits a dtype mismatch.
        cos = cos.to(torch.float)
        sin = sin.to(torch.float)
    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
    q = q.squeeze(0)
    k = k.squeeze(0)
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
        seq_length, -1
    )
    attn_output = self.proj(attn_output)
    return attn_output


Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
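# Optional sanity check that the monkey-patch took effect (a minimal sketch, not part of training):
#   assert Qwen2_5_VLVisionFlashAttention2.forward is custom_forward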
# ----------------------- Main Script -----------------------
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
    )
    max_anyres_num: Optional[int] = field(
        default=12,
        metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
    )
    image_root: Optional[str] = field(
        default=None,
        metadata={"help": "Root directory of the images"},
    )
@dataclass
class GRPOModelConfig(ModelConfig):
    freeze_vision_modules: bool = False
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)
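# Under this prompt a well-formed completion looks like the line below (illustrative only;
# for the REC task the answer is scored by iou_reward, so it is expected to contain a bounding box):
#   <think> the red mug sits on the left edge of the desk ... </think><answer>[120, 56, 310, 240]</answer>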
class LazySupervisedDataset(Dataset):
    def __init__(self, data_path: str, script_args: GRPOScriptArguments, question_template: str):
        super(LazySupervisedDataset, self).__init__()
        self.script_args = script_args
        self.list_data_dict = []
        self.question_template = question_template

        if data_path.endswith(".yaml"):
            with open(data_path, "r") as file:
                yaml_data = yaml.safe_load(file)
                datasets = yaml_data.get("datasets")
                # The file should be in the format of:
                # datasets:
                #   - json_path: xxxx1.json
                #     sampling_strategy: first:1000
                #   - json_path: xxxx2.json
                #     sampling_strategy: end:3000
                #   - json_path: xxxx3.json
                #     sampling_strategy: random:999
                for data in datasets:
                    json_path = data.get("json_path")
                    sampling_strategy = data.get("sampling_strategy", "all")
                    sampling_number = None

                    if json_path.endswith(".jsonl"):
                        cur_data_dict = []
                        with open(json_path, "r") as json_file:
                            for line in json_file:
                                cur_data_dict.append(json.loads(line.strip()))
                    elif json_path.endswith(".json"):
                        with open(json_path, "r") as json_file:
                            cur_data_dict = json.load(json_file)
                    else:
                        raise ValueError(f"Unsupported file type: {json_path}")

                    if ":" in sampling_strategy:
                        sampling_strategy, sampling_number = sampling_strategy.split(":")
                        if "%" in sampling_number:
                            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                        else:
                            sampling_number = int(sampling_number)

                    # Apply the sampling strategy
                    if sampling_strategy == "first" and sampling_number is not None:
                        cur_data_dict = cur_data_dict[:sampling_number]
                    elif sampling_strategy == "end" and sampling_number is not None:
                        cur_data_dict = cur_data_dict[-sampling_number:]
                    elif sampling_strategy == "random" and sampling_number is not None:
                        random.shuffle(cur_data_dict)
                        cur_data_dict = cur_data_dict[:sampling_number]
                    print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
                    self.list_data_dict.extend(cur_data_dict)
        else:
            raise ValueError(f"Unsupported file type: {data_path}")

    def __len__(self):
        return len(self.list_data_dict)
    def __getitem__(self, i):
        # Format into conversation
        def make_conversation(example):
            return {
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": example["problem"]},
                ],
            }

        QUESTION_TEMPLATE = self.question_template

        def make_conversation_image(example):
            return {
                "prompt": [
                    # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
                    {
                        "role": "user",
                        "content": [
                            {"type": "image"},
                            {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                        ],
                    },
                ],
            }

        example = self.list_data_dict[i]
        image_root = self.script_args.image_root
        if 'image' in example:
            image_path = os.path.join(image_root, example['image'])
            # In case the image is not found, fall back to a randomly selected sample
            while not os.path.exists(image_path):
                print(f"Warning: Image {image_path} not found, randomly selecting another image")
                new_index = random.randint(0, len(self.list_data_dict) - 1)
                example = self.list_data_dict[new_index]
                image_path = os.path.join(image_root, example['image'])
            image = Image.open(image_path).convert("RGB")
        else:
            image = None

        return {
            'image': image,
            'problem': example['problem'],
            'solution': example['solution'],
            'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt'],
        }
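# Minimal smoke test for the dataset (a sketch; "data/rec.yaml", the image root, and the
# "{Question}" template are hypothetical stand-ins, not values shipped with this script):
#   args = GRPOScriptArguments(dataset_name="data/rec.yaml", image_root="/path/to/images")
#   ds = LazySupervisedDataset("data/rec.yaml", args, question_template="{Question}")
#   print(len(ds), ds[0]["prompt"])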
def get_vlm_module(model_name_or_path):
    if "qwen" in model_name_or_path.lower():
        return Qwen2VLModule
    elif "internvl" in model_name_or_path.lower():
        return InvernVLModule
    else:
        raise ValueError(f"Unsupported model: {model_name_or_path}")
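# Dispatch is by substring match on the model path, e.g. (illustrative model names):
#   get_vlm_module("Qwen/Qwen2.5-VL-3B-Instruct")  -> Qwen2VLModule
#   get_vlm_module("OpenGVLab/InternVL2-8B")       -> InvernVLModule  (sic: name as exported by open_r1.vlm_modules)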
def main(script_args, training_args, model_args):
    # Load the VLM module
    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)

    # Load the reward functions
    reward_funcs_registry = {
        "accuracy": vlm_module_cls.iou_reward,
        "format": vlm_module_cls.format_reward_rec,
    }
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", reward_funcs)

    # Load the dataset
    dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))

    # Initialize the GRPO trainer
    trainer_cls = VLMGRPOTrainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=dataset,
        eval_dataset=None,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        max_anyres_num=script_args.max_anyres_num,
        torch_dtype=model_args.torch_dtype,
    )

    # Train the model
    trainer.train()

    # Save the model and optionally push it to the Hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
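
# Example launch (a sketch with hypothetical paths; the flags map to the dataclass fields above):
#   torchrun --nproc_per_node=8 grpo_rec.py \
#       --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
#       --dataset_name data/rec.yaml \
#       --image_root /path/to/images \
#       --output_dir outputs/grpo_rec \
#       --reward_funcs accuracy format \
#       --freeze_vision_modules true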