# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import pathlib
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from babel.numbers import parse_decimal
from utils.math import compute_score
from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration
from math_verify import parse, verify
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
import PIL
from Levenshtein import ratio
from open_r1.utils.pycocotools.coco import COCO
from open_r1.utils.pycocotools.cocoeval import COCOeval
import json
from open_r1.vlm_modules import *

# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLVisionFlashAttention2,
    apply_rotary_pos_emb_flashatt,
    flash_attn_varlen_func,
)
import torch
from typing import Tuple
from transformers.utils import logging
from openai import OpenAI

logger = logging.get_logger(__name__)

client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY", "sk-proj-1234567890"),
    base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
)


def custom_forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    seq_length = hidden_states.shape[0]
    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        cos = emb.cos().float()
        sin = emb.sin().float()
    else:
        cos, sin = position_embeddings
        # Cast to float32 to avoid a dtype mismatch in the flash-attention rotary kernel
        cos = cos.to(torch.float)
        sin = sin.to(torch.float)
    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
    q = q.squeeze(0)
    k = k.squeeze(0)

    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
        seq_length, -1
    )
    attn_output = self.proj(attn_output)
    return attn_output


Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
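
# --- Illustrative addition (not in the original script) -----------------------------------------
# A minimal sanity check, assuming nothing re-patches the class later in the process: verify that
# the monkey patch above is what Qwen2.5-VL's vision blocks will actually call at runtime.
assert Qwen2_5_VLVisionFlashAttention2.forward is custom_forward
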
""" data_file_paths: str = field( default=None, metadata={"help": "Paths to data files, separated by ':'"}, ) image_folders: str = field( default=None, metadata={"help": "Paths to image folders, separated by ':'"}, ) arrow_cache_dir: str = field( default=None, metadata={"help": "Path to arrow cache directory"}, ) val_split_ratio: float = field( default=0.0, metadata={"help": "Ratio of validation split, default 0.0"}, ) reward_funcs: list[str] = field( default_factory=lambda: ["accuracy", "format"], metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"}, ) max_pixels: Optional[int] = field( default=12845056, metadata={"help": "Maximum number of pixels for the image (for QwenVL)"}, ) min_pixels: Optional[int] = field( default=3136, metadata={"help": "Minimum number of pixels for the image (for QwenVL)"}, ) max_anyres_num: Optional[int] = field( default=12, metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"}, ) reward_method: Optional[str] = field( default=None, metadata={ "help": "Choose reward method: 'default', 'mcp', ..." }, ) def extract_choice(text): # 1. Clean and normalize text text = text.upper() # Convert to uppercase text = re.sub(r'\s+', ' ', text) # Normalize spaces # 2. Choice should not have uppercase letters before or after choices = re.findall(r'(? len(text) * 0.7: # In last 30% of text choice_scores[choice] += 2 # Add points if followed by punctuation if pos < len(text) - 1 and text[pos+1] in '。.!!,,': choice_scores[choice] += 1 # Return highest scoring choice return max(choice_scores.items(), key=lambda x: x[1])[0] def evaluate_answer_similarity(student_answer, ground_truth): """Use llm to evaluate answer similarity.""" try: response = client.chat.completions.create( model="qwen2.5:7b", messages=[ { "role": "user", "content": "You are a evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed." 
def evaluate_answer_similarity(student_answer, ground_truth):
    """Use an LLM judge to evaluate answer similarity."""
    try:
        response = client.chat.completions.create(
            model="qwen2.5:7b",
            messages=[
                {
                    "role": "user",
                    "content": "You are an evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed."
                },
                {
                    "role": "user",
                    "content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:"
                }
            ],
            temperature=0,
        )
        result = response.choices[0].message.content.strip()
        return float(result)
    except Exception as e:
        print(f"Error in LLM evaluation: {e}")
        # If the API call fails, fall back to simple text matching
        return 1.0 if student_answer == ground_truth else 0.0


def llm_reward(content, sol, **kwargs):
    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_matches[-1].strip() if content_matches else content.strip()

    return evaluate_answer_similarity(student_answer, ground_truth)


def mcq_reward(content, sol, **kwargs):
    # For multiple choice, extract and compare choices
    has_choices = extract_choice(sol)
    correct_choice = has_choices.upper() if has_choices else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_match.group(1).strip() if content_match else content.strip()

    student_choice = extract_choice(student_answer)
    if student_choice:
        reward = 1.0 if student_choice == correct_choice else 0.0
    else:
        reward = 0.0

    return reward


def yes_no_reward(content, sol, **kwargs):
    content = content.lower()
    sol = sol.lower()

    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_match.group(1).strip() if content_match else content.strip()

    ground_yes_no = re.search(r'(yes|no)', ground_truth)
    ground_yes_no = ground_yes_no.group(1) if ground_yes_no else ''
    student_yes_no = re.search(r'(yes|no)', student_answer)
    student_yes_no = student_yes_no.group(1) if student_yes_no else ''

    reward = 1.0 if ground_yes_no == student_yes_no else 0.0

    return reward
def calculate_map(pred_bbox_list, gt_bbox_list):
    # Calculate mAP with pycocotools on a single synthetic image

    # Build the ground-truth annotations in COCO format
    gt_json = {"annotations": [], "images": [], "categories": []}
    gt_json["images"] = [{
        "id": 0,
        "width": 2048,
        "height": 2048,
        "file_name": "image_0.jpg"
    }]
    gt_json["categories"] = []
    cats2id = {}
    cat_count = 0
    for idx, gt_bbox in enumerate(gt_bbox_list):
        if gt_bbox["label"] not in cats2id:
            cats2id[gt_bbox["label"]] = cat_count
            gt_json["categories"].append({
                "id": cat_count,
                "name": gt_bbox["label"]
            })
            cat_count += 1
        gt_json["annotations"].append({
            "id": idx + 1,
            "image_id": 0,
            "category_id": cats2id[gt_bbox["label"]],
            "bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1],
                     gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0],
                     gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
            "area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
            "iscrowd": 0
        })
    coco_gt = COCO(gt_json)

    # Build the detection results; skip predictions with malformed boxes or unknown labels
    dt_json = []
    for idx, pred_bbox in enumerate(pred_bbox_list):
        try:
            dt_json.append({
                "image_id": 0,
                "category_id": cats2id[pred_bbox["label"]],
                "bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1],
                         pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0],
                         pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
                "score": 1.0,
                "area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
            })
        except (KeyError, IndexError, TypeError):
            pass
    if len(dt_json) == 0:
        return 0.0

    coco_dt = coco_gt.loadRes(dt_json)
    coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    # stats[1] is AP at IoU=0.50
    return coco_eval.stats[1]


def map_reward(content, sol, **kwargs):
    """
    Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes

    Args:
        content: String containing predicted bounding boxes in JSON format
        sol: String containing ground truth bounding boxes in JSON format

    Returns:
        float: mAP reward score between 0 and 1
    """
    # Extract JSON content between ```json tags
    pattern = r'```json(.*?)```'
    json_match = re.search(pattern, sol, re.DOTALL)
    bbox_json = json_match.group(1).strip() if json_match else None

    # Parse ground truth JSON to get bbox list
    gt_bbox_list = []
    if bbox_json:
        bbox_data = json.loads(bbox_json)
        gt_bbox_list = [item for item in bbox_data]

    # Parse predicted JSON to get bbox list
    pred_bbox_list = []
    json_match = re.search(pattern, content, re.DOTALL)
    if json_match:
        try:
            bbox_data = json.loads(json_match.group(1).strip())
            pred_bbox_list = [item for item in bbox_data]
        except (json.JSONDecodeError, TypeError):
            # Keep an empty list if JSON parsing fails
            pred_bbox_list = []

    # Calculate mAP if both prediction and ground truth exist
    if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
        bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list)
    else:
        bbox_reward = 0.0

    return bbox_reward


def numeric_reward(content, sol, **kwargs):
    content = clean_text(content)
    sol = clean_text(sol)
    try:
        content, sol = float(content), float(sol)
        return 1.0 if content == sol else 0.0
    except (ValueError, TypeError):
        # Signal to the caller that the answers were not both numeric
        return None


def math_reward(content, sol, **kwargs):
    content = clean_text(content)
    sol = clean_text(sol)
    return compute_score(content, sol)


def clean_text(text, exclude_chars=['\n', '\r']):
    # Extract content between <answer> and </answer> if present
    answer_matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
    if answer_matches:
        # Use the last match
        text = answer_matches[-1]

    for char in exclude_chars:
        if char in ['\n', '\r']:
            # If there is a space before the newline, remove the newline
            text = re.sub(r'(?<=\s)' + re.escape(char), '', text)
            # If there is no space before the newline, replace it with a space
            text = re.sub(r'(?<!\s)' + re.escape(char), ' ', text)
        else:
            text = text.replace(char, ' ')

    # Strip whitespace and trailing periods, and lowercase for comparison
    return text.strip().rstrip('.').lower()
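
# --- Illustrative addition (not in the original script) ------------------------------------------
# map_reward expects both the solution and the completion to carry a fenced ```json block whose
# items have "bbox_2d" ([x1, y1, x2, y2]) and "label" keys, e.g. (made-up values):
#
#   ```json
#   [
#       {"bbox_2d": [100, 150, 300, 400], "label": "dog"},
#       {"bbox_2d": [500, 120, 720, 380], "label": "cat"}
#   ]
#   ```
#
# calculate_map() converts these to COCO xywh boxes and returns AP at IoU=0.50 (coco_eval.stats[1]).
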
def default_accuracy_reward(content, sol, **kwargs):
    reward = 0.0

    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_matches[-1].strip() if content_matches else content.strip()

    # Try symbolic verification first for numeric answers
    try:
        answer = parse(student_answer)
        if float(verify(answer, parse(ground_truth))) > 0:
            reward = 1.0
    except Exception:
        pass  # Continue to next verification method if this fails

    # If symbolic verification failed, try string matching or fuzzy matching
    if reward == 0.0:
        try:
            # Check if ground truth contains numbers
            has_numbers = bool(re.search(r'\d', ground_truth))
            # Check if it's a multiple choice question
            has_choices = extract_choice(ground_truth)

            if has_numbers:
                # For numeric answers, use exact matching
                reward = numeric_reward(student_answer, ground_truth)
                if reward is None:
                    reward = ratio(clean_text(student_answer), clean_text(ground_truth))
            elif has_choices:
                # For multiple choice, extract and compare choices
                correct_choice = has_choices.upper()
                student_choice = extract_choice(student_answer)
                if student_choice:
                    reward = 1.0 if student_choice == correct_choice else 0.0
            else:
                # For text answers, use fuzzy matching
                reward = ratio(clean_text(student_answer), clean_text(ground_truth))
        except Exception:
            pass  # Keep reward as 0.0 if all methods fail

    return reward


def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is correct using symbolic verification, exact string matching, or fuzzy matching."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []

    for content, sol, accu_reward_method in zip(contents, solution, kwargs.get("accu_reward_method")):
        # If accu_reward_method is defined, use the corresponding reward function, otherwise use the default reward function
        if accu_reward_method == "mcq":
            reward = mcq_reward(content, sol)
        elif accu_reward_method == 'yes_no':
            reward = yes_no_reward(content, sol)
        elif accu_reward_method == 'llm':
            reward = llm_reward(content, sol)
        elif accu_reward_method == 'map':
            reward = map_reward(content, sol)
        elif accu_reward_method == 'math':
            reward = math_reward(content, sol)
        else:
            reward = default_accuracy_reward(content, sol)
        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
            image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
            problem = kwargs.get("problem")[0]
            if reward <= 1.0:  # this condition can be changed for debug
                with open(log_path, "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                    f.write(f"accu_reward_method: {accu_reward_method}\n")
                    f.write(f"image_path: {image_path}\n")
                    f.write(f"problem: {problem}\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Solution: {sol}\n")

    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]

    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    if os.getenv("DEBUG_MODE") == "true":
        log_path = os.getenv("LOG_PATH")
        with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
            f.write(f"------------- {current_time} Format reward -------------\n")
            for content, match in zip(completion_contents, matches):
                f.write(f"Content: {content}\n")
                f.write(f"Has format: {bool(match)}\n")

    return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}
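
# --- Illustrative addition (not in the original script) ------------------------------------------
# A minimal sketch of how an extra reward could be plugged into the registry; "length" and
# length_reward are hypothetical names, and the function must keep the same signature as
# accuracy_reward/format_reward (a list of completions in, a list of floats out).
#
# def length_reward(completions, **kwargs):
#     """Toy reward: 1.0 for completions shorter than 1024 characters, else 0.0."""
#     return [1.0 if len(completion[0]["content"]) < 1024 else 0.0 for completion in completions]
#
# reward_funcs_registry["length"] = length_reward  # then select it via --reward_funcs accuracy format length
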
@dataclass
class GRPOModelConfig(ModelConfig):
    freeze_vision_modules: bool = False


SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def get_vlm_module(model_name_or_path):
    if "qwen" in model_name_or_path.lower():
        return Qwen2VLModule
    elif "internvl" in model_name_or_path.lower():
        return InvernVLModule
    else:
        raise ValueError(f"Unsupported model: {model_name_or_path}")


def main(script_args, training_args, model_args):
    # Load the VLM module
    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)
    question_prompt = vlm_module_cls.get_question_template(task_type="default")

    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", reward_funcs)

    # Load the JSONL datasets
    from datasets import Dataset

    data_files = script_args.data_file_paths.split(":")
    image_folders = script_args.image_folders.split(":")
    if len(data_files) != len(image_folders):
        raise ValueError("Number of data files must match number of image folders")

    if script_args.reward_method is None:
        accu_reward_methods = ["default"] * len(data_files)
    else:
        accu_reward_methods = script_args.reward_method.split(":")
        assert len(accu_reward_methods) == len(data_files), \
            f"Number of reward methods must match number of data files: {len(accu_reward_methods)} != {len(data_files)}"

    all_data = []
    for data_file, image_folder, accu_reward_method in zip(data_files, image_folders, accu_reward_methods):
        with open(data_file, 'r') as f:
            for line in f:
                item = json.loads(line)
                if 'image' in item:
                    if isinstance(item['image'], str):
                        # Store the image path instead of loading the image
                        item['image_path'] = [os.path.join(image_folder, item['image'])]
                        del item['image']  # remove the image column so that it can be loaded later
                    elif isinstance(item['image'], list):
                        # If the image is a list, then it is a list of images (for multi-image input)
                        item['image_path'] = [os.path.join(image_folder, image) for image in item['image']]
                        del item['image']  # remove the image column so that it can be loaded later
                    else:
                        raise ValueError(f"Unsupported image type: {type(item['image'])}")

                item['problem'] = item['conversations'][0]['value'].replace('<image>', '')

                # Handle solution that could be a float or string
                solution_value = item['conversations'][1]['value']
                if isinstance(solution_value, str):
                    item['solution'] = solution_value.replace('<answer>', '').replace('</answer>', '').strip()
                else:
                    # If it's a float or other non-string type, convert it to a string
                    item['solution'] = str(solution_value)

                del item['conversations']
                # If accu_reward_method is present in the data JSONL, use that value; otherwise use the per-file default
                item['accu_reward_method'] = item.get('accu_reward_method', accu_reward_method)
                all_data.append(item)

    dataset = Dataset.from_list(all_data)
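
    # --- Illustrative addition (not in the original script) --------------------------------------
    # The loading loop above expects each JSONL line to look roughly like this (made-up content;
    # "accu_reward_method" is optional and falls back to the per-file default):
    #
    #   {"image": "images/0001.jpg",
    #    "conversations": [{"from": "human", "value": "<image>How many cats are in the picture?"},
    #                      {"from": "gpt", "value": "<answer>2</answer>"}],
    #    "accu_reward_method": "default"}
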
    def make_conversation_from_jsonl(example):
        if 'image_path' in example and example['image_path'] is not None:
            # Don't load the image here, just keep the path
            return {
                'image_path': [p for p in example['image_path']],  # Store paths instead of loaded images
                'problem': example['problem'],
                'solution': f"<answer> {example['solution']} </answer>",
                'accu_reward_method': example['accu_reward_method'],
                'prompt': [{
                    'role': 'user',
                    'content': [
                        *({'type': 'image', 'text': None} for _ in range(len(example['image_path']))),
                        {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
                    ]
                }]
            }
        else:
            return {
                'problem': example['problem'],
                'solution': f"<answer> {example['solution']} </answer>",
                'accu_reward_method': example['accu_reward_method'],
                'prompt': [{
                    'role': 'user',
                    'content': [
                        {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
                    ]
                }]
            }

    # Map the conversations
    dataset = dataset.map(make_conversation_from_jsonl, num_proc=8)

    # Split dataset for validation if requested
    splits = {'train': dataset}
    if script_args.val_split_ratio > 0:
        train_val_split = dataset.train_test_split(test_size=script_args.val_split_ratio)
        splits['train'] = train_val_split['train']
        splits['validation'] = train_val_split['test']

    # Select the trainer class
    trainer_cls = VLMGRPOTrainer
    print("using trainer:", trainer_cls.__name__)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=splits['train'],
        eval_dataset=splits.get('validation') if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    # Train, resuming from the latest checkpoint if one exists in the output directory
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    # Save and optionally push to the Hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
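
# --- Illustrative addition (not in the original script) ------------------------------------------
# A rough sketch of how this script is typically launched; the file name, paths, model name, and
# GPU count are placeholders, and GRPOConfig/TRL training flags (learning rate, batch size, etc.)
# are omitted.
#
#   torchrun --nproc_per_node=8 grpo_jsonl.py \
#       --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
#       --data_file_paths /path/to/data_a.jsonl:/path/to/data_b.jsonl \
#       --image_folders /path/to/images_a:/path/to/images_b \
#       --reward_funcs accuracy format \
#       --reward_method default:mcq \
#       --output_dir ./outputs/grpo_run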