# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Optional remote-debugging hook: uncomment to block until a VS Code debugger attaches.
# import debugpy
# try:
#     # 5678 is the default attach port in the VS Code debug configurations.
#     # Unless a host and port are specified, host defaults to 127.0.0.1.
#     debugpy.listen(("localhost", 9501))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass

import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from PIL import Image
from torch.utils.data import Dataset
from transformers import Qwen2VLForConditionalGeneration
from math_verify import parse, verify
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
from open_r1.vlm_modules import *
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
from transformers import TrainingArguments
import yaml
import json
import random
import math

# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLVisionFlashAttention2,
    apply_rotary_pos_emb_flashatt,
    flash_attn_varlen_func,
)
import torch
from typing import Tuple
from transformers.utils import logging

# `logger` is referenced in the patched forward below but is not imported with the
# attention class, so define it here.
logger = logging.get_logger(__name__)


def custom_forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    seq_length = hidden_states.shape[0]
    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        cos = emb.cos().float()
        sin = emb.sin().float()
    else:
        cos, sin = position_embeddings
        # Cast to float32 so the flash-attention rotary kernel receives the dtype it expects.
        cos = cos.to(torch.float)
        sin = sin.to(torch.float)
    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
    q = q.squeeze(0)
    k = k.squeeze(0)

    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
        seq_length, -1
    )
    attn_output = self.proj(attn_output)
    return attn_output


# Monkey-patch the vision flash-attention forward with the fixed implementation above.
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward


# ----------------------- Main Script -----------------------
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
""" reward_funcs: list[str] = field( default_factory=lambda: ["accuracy", "format"], metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"}, ) max_pixels: Optional[int] = field( default=12845056, metadata={"help": "Maximum number of pixels for the image (for QwenVL)"}, ) min_pixels: Optional[int] = field( default=3136, metadata={"help": "Minimum number of pixels for the image (for QwenVL)"}, ) max_anyres_num: Optional[int] = field( default=12, metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"}, ) image_root: Optional[str] = field( default=None, metadata={"help": "Root directory of the image"}, ) @dataclass class GRPOModelConfig(ModelConfig): freeze_vision_modules: bool = False SYSTEM_PROMPT = ( "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " "process and answer are enclosed within and tags, respectively, i.e., " " reasoning process here answer here " ) class LazySupervisedDataset(Dataset): def __init__(self, data_path: str, script_args: GRPOScriptArguments, question_template: str): super(LazySupervisedDataset, self).__init__() self.script_args = script_args self.list_data_dict = [] self.question_template = question_template if data_path.endswith(".yaml"): with open(data_path, "r") as file: yaml_data = yaml.safe_load(file) datasets = yaml_data.get("datasets") # file should be in the format of: # datasets: # - json_path: xxxx1.json # sampling_strategy: first:1000 # - json_path: xxxx2.json # sampling_strategy: end:3000 # - json_path: xxxx3.json # sampling_strategy: random:999 for data in datasets: json_path = data.get("json_path") sampling_strategy = data.get("sampling_strategy", "all") sampling_number = None if json_path.endswith(".jsonl"): cur_data_dict = [] with open(json_path, "r") as json_file: for line in json_file: cur_data_dict.append(json.loads(line.strip())) elif json_path.endswith(".json"): with open(json_path, "r") as json_file: cur_data_dict = json.load(json_file) else: raise ValueError(f"Unsupported file type: {json_path}") if ":" in sampling_strategy: sampling_strategy, sampling_number = sampling_strategy.split(":") if "%" in sampling_number: sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100) else: sampling_number = int(sampling_number) # Apply the sampling strategy if sampling_strategy == "first" and sampling_number is not None: cur_data_dict = cur_data_dict[:sampling_number] elif sampling_strategy == "end" and sampling_number is not None: cur_data_dict = cur_data_dict[-sampling_number:] elif sampling_strategy == "random" and sampling_number is not None: random.shuffle(cur_data_dict) cur_data_dict = cur_data_dict[:sampling_number] print(f"Loaded {len(cur_data_dict)} samples from {json_path}") self.list_data_dict.extend(cur_data_dict) else: raise ValueError(f"Unsupported file type: {data_path}") def __len__(self): return len(self.list_data_dict) def __getitem__(self, i): # Format into conversation def make_conversation(example): return { "prompt": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": example["problem"]}, ], } QUESTION_TEMPLATE = self.question_template def make_conversation_image(example): return { "prompt": [ # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, { "role": "user", "content": [ {"type": "image"}, {"type": "text", "text": 
                            {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                        ],
                    },
                ],
            }

        example = self.list_data_dict[i]
        image_root = self.script_args.image_root
        if 'image' in example:
            image_path = os.path.join(image_root, example['image'])
            # In case the image is not found, fall back to a randomly selected example
            while not os.path.exists(image_path):
                print(f"Warning: Image {image_path} not found, randomly selecting another image")
                new_index = random.randint(0, len(self.list_data_dict) - 1)
                example = self.list_data_dict[new_index]
                image_path = os.path.join(image_root, example['image'])
            image = Image.open(image_path).convert("RGB")
        else:
            image = None

        return {
            'image': image,
            'problem': example['problem'],
            'solution': example['solution'],
            'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt'],
        }


def get_vlm_module(model_name_or_path):
    if "qwen" in model_name_or_path.lower():
        return Qwen2VLModule
    elif "internvl" in model_name_or_path.lower():
        return InvernVLModule
    else:
        raise ValueError(f"Unsupported model: {model_name_or_path}")


def main(script_args, training_args, model_args):
    # Load the VLM module
    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)

    # Load the reward functions
    reward_funcs_registry = {
        "accuracy": vlm_module_cls.iou_reward,
        "format": vlm_module_cls.format_reward_rec,
    }
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", reward_funcs)

    # Load the dataset
    dataset = LazySupervisedDataset(
        script_args.dataset_name,
        script_args,
        question_template=vlm_module_cls.get_question_template(task_type="rec"),
    )

    trainer_cls = VLMGRPOTrainer
    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=dataset,
        eval_dataset=None,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        max_anyres_num=script_args.max_anyres_num,
        torch_dtype=model_args.torch_dtype,
    )

    # Train and push the model to the Hub
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
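
# Illustrative launch command (an assumption, not part of the original source): the script
# filename, model ID, paths, and process count are placeholders; the flags correspond to the
# GRPOScriptArguments / GRPOConfig / GRPOModelConfig fields parsed by TrlParser above.
#
#   torchrun --nproc_per_node=8 grpo_rec.py \
#       --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
#       --dataset_name data/rec.yaml \
#       --image_root /path/to/images \
#       --output_dir checkpoints/grpo-rec \
#       --reward_funcs accuracy format \
#       --freeze_vision_modules true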