# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# import debugpy
# try:
#     # 5678 is the default attach port in the VS Code debug configurations.
#     # Unless a host and port are specified, host defaults to 127.0.0.1.
#     debugpy.listen(("localhost", 9501))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass

import os
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

from datasets import load_dataset, load_from_disk
from math_verify import parse, verify
from transformers import Qwen2VLForConditionalGeneration
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from open_r1.trainer import VLMGRPOTrainer


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
        max_pixels (`Optional[int]`):
            Maximum number of pixels for the image.
        min_pixels (`Optional[int]`):
            Minimum number of pixels for the image.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )


def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is correct using either symbolic verification or exact string matching."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    for content, sol in zip(contents, solution):
        reward = 0.0
        # Try symbolic verification first
        try:
            answer = parse(content)
            if float(verify(answer, parse(sol))) > 0:
                reward = 1.0
        except Exception:
            pass  # Continue to next verification method if this fails

        # If symbolic verification failed, try string matching
        if reward == 0.0:
            try:
                # Extract answer from solution if it has think/answer tags
                sol_match = re.search(r'<answer>(.*?)</answer>', sol)
                ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

                # Extract answer from content if it has think/answer tags
                content_match = re.search(r'<answer>(.*?)</answer>', content)
                student_answer = content_match.group(1).strip() if content_match else content.strip()

                # Compare the extracted answers
                if student_answer == ground_truth:
                    reward = 1.0
            except Exception:
                pass  # Keep reward as 0.0 if both methods fail

        rewards.append(reward)
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")
    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
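    # A hedged illustration (made-up strings, not taken from any dataset): a
    # completion such as
    #   "<think> 6 * 7 = 42 </think> <answer> 42 </answer>"
    # matches the pattern and scores 1.0, while a bare "42" scores 0.0.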
    completion_contents = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets ".*?" span the newlines that typically appear inside the <think> block.
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", reward_funcs)

    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    # def make_conversation_image(example):
    #     return {
    #         "prompt": [
    #             {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
    #             {
    #                 "role": "user",
    #                 "content": [
    #                     {"type": "image"},
    #                     {"type": "text", "text": example["problem"]},
    #                 ],
    #             },
    #         ],
    #     }

    QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."

    def make_conversation_image(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    if "image" in dataset[script_args.dataset_train_split].features:
        print("has image in dataset")
        dataset = dataset.map(make_conversation_image)  # pass num_proc here to utilize multiprocessing for faster mapping
        # dataset = dataset.remove_columns(["original_question", "original_answer"])
    else:
        print("no image in dataset")
        dataset = dataset.map(make_conversation)
        dataset = dataset.remove_columns("messages")

    trainer_cls = VLMGRPOTrainer

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        torch_dtype=model_args.torch_dtype,
    )

    # Train and push the model to the Hub
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
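# Example launch (a minimal sketch; the model name, dataset placeholder, and GPU
# count below are assumptions, not prescribed by this script -- adapt to your setup):
#
#   torchrun --nproc_per_node=8 grpo.py \
#       --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
#       --dataset_name <your_dataset> \
#       --output_dir <output_dir> \
#       --reward_funcs accuracy format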