# Fortrain/qw/open_r1/grpo_jsonl.py
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import pathlib
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional
from babel.numbers import parse_decimal
from utils.math import compute_score
from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration
from math_verify import parse, verify
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
import PIL
from Levenshtein import ratio
from open_r1.utils.pycocotools.coco import COCO
from open_r1.utils.pycocotools.cocoeval import COCOeval
import json
from open_r1.vlm_modules import *
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
import torch
from typing import Tuple
from transformers.utils import logging
from openai import OpenAI
logger = logging.get_logger(__name__)
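# OpenAI-compatible client used by the LLM-as-judge reward (see evaluate_answer_similarity below).
# The API key and base URL are read from the environment; the defaults are placeholders, and the
# judge model name ("qwen2.5:7b") suggests a local OpenAI-compatible endpoint is expected.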
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY", "sk-proj-1234567890"),
    base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
)


def custom_forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    seq_length = hidden_states.shape[0]
    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        cos = emb.cos().float()
        sin = emb.sin().float()
    else:
        cos, sin = position_embeddings
        # Cast cos/sin to float32 before the flash-attention rotary kernel (this is the fix)
        cos = cos.to(torch.float)
        sin = sin.to(torch.float)
    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
    q = q.squeeze(0)
    k = k.squeeze(0)
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
        seq_length, -1
    )
    attn_output = self.proj(attn_output)
    return attn_output


Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
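# Note: the override above only differs from the stock implementation by casting cos/sin to
# float32 before apply_rotary_pos_emb_flashatt; presumably this works around a dtype mismatch
# in the flash-attention rotary path when the vision tower runs in bf16/fp16.
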
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.
    """

    data_file_paths: str = field(
        default=None,
        metadata={"help": "Paths to data files, separated by ':'"},
    )
    image_folders: str = field(
        default=None,
        metadata={"help": "Paths to image folders, separated by ':'"},
    )
    arrow_cache_dir: str = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
    val_split_ratio: float = field(
        default=0.0,
        metadata={"help": "Ratio of validation split, default 0.0"},
    )
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
    )
    max_anyres_num: Optional[int] = field(
        default=12,
        metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
    )
    reward_method: Optional[str] = field(
        default=None,
        metadata={
            "help": "Accuracy reward method per data file, separated by ':'. "
                    "Possible values: 'default', 'mcq', 'yes_no', 'llm', 'map', 'math'"
        },
    )

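
# Illustrative invocation (the launcher flags and paths below are placeholders, not part of this
# repo's documented interface):
#   torchrun --nproc_per_node=8 grpo_jsonl.py \
#       --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
#       --data_file_paths data/a.jsonl:data/b.jsonl \
#       --image_folders /path/to/images_a:/path/to/images_b \
#       --reward_funcs accuracy format \
#       --reward_method default:mcq
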
def extract_choice(text):
    # 1. Clean and normalize text
    text = text.upper()  # Convert to uppercase
    text = re.sub(r'\s+', ' ', text)  # Normalize spaces

    # 2. Find candidate choices: single uppercase letters that are not preceded by another
    #    uppercase letter and are followed by punctuation or the end of the text
    choices = re.findall(r'(?<![A-Z])([A-Z])(?=[\.\,\?\!\:\;]|$)', text)

    if not choices:
        return None

    # 3. If only one choice, return it directly
    if len(choices) == 1:
        return choices[0]

    # 4. If multiple choices, use heuristic rules
    choice_scores = {choice: 0 for choice in choices}

    # 4.1 Keywords around choices get points
    keywords = [
        '答案', '选择', '正确', '', '',
        'answer', 'correct', 'choose', 'select', 'right',
        '认为', '应该', '觉得', 'think', 'believe', 'should'
    ]

    # Get context for each choice (20 chars before and after)
    for choice in choices:
        pos = text.find(choice)
        context = text[max(0, pos - 20):min(len(text), pos + 20)]

        # Add points for keywords
        for keyword in keywords:
            if keyword.upper() in context:
                choice_scores[choice] += 1

        # Add points if choice is near the end (usually final answer)
        if pos > len(text) * 0.7:  # In last 30% of text
            choice_scores[choice] += 2

        # Add points if followed by punctuation
        if pos < len(text) - 1 and text[pos + 1] in '。.!,':
            choice_scores[choice] += 1

    # Return highest scoring choice
    return max(choice_scores.items(), key=lambda x: x[1])[0]

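
# Examples (illustrative): extract_choice("I think the answer is B.") returns "B";
# extract_choice("Both A and C seem plausible, but the correct choice is C.") returns "C".
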
def evaluate_answer_similarity(student_answer, ground_truth):
    """Use an LLM judge to evaluate answer similarity."""
    try:
        response = client.chat.completions.create(
            model="qwen2.5:7b",
            messages=[
                {
                    "role": "user",
                    "content": "You are an evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed."
                },
                {
                    "role": "user",
                    "content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:"
                }
            ],
            temperature=0
        )
        result = response.choices[0].message.content.strip()
        return float(result)
    except Exception as e:
        print(f"Error in LLM evaluation: {e}")
        # If the API call fails, fall back to simple text matching
        return 1.0 if student_answer == ground_truth else 0.0

def llm_reward(content, sol, **kwargs):
    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_matches[-1].strip() if content_matches else content.strip()

    return evaluate_answer_similarity(student_answer, ground_truth)

def mcq_reward(content, sol, **kwargs):
    # For multiple choice, extract and compare choices
    has_choices = extract_choice(sol)
    correct_choice = has_choices.upper() if has_choices else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_match.group(1).strip() if content_match else content.strip()
    student_choice = extract_choice(student_answer)
    if student_choice:
        reward = 1.0 if student_choice == correct_choice else 0.0
    else:
        reward = 0.0

    return reward


def yes_no_reward(content, sol, **kwargs):
    content = content.lower()
    sol = sol.lower()

    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_match.group(1).strip() if content_match else content.strip()

    ground_yes_no = re.search(r'(yes|no)', ground_truth)
    ground_yes_no = ground_yes_no.group(1) if ground_yes_no else ''
    student_yes_no = re.search(r'(yes|no)', student_answer)
    student_yes_no = student_yes_no.group(1) if student_yes_no else ''

    reward = 1.0 if ground_yes_no == student_yes_no else 0.0

    return reward

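
# Example (illustrative): yes_no_reward("<answer>Yes, it is.</answer>", "<answer>yes</answer>")
# returns 1.0, since the first "yes"/"no" token found on each side is compared.
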
def calculate_map(pred_bbox_list, gt_bbox_list):
    # Calculate mAP: build a single-image COCO ground-truth object on a dummy 2048x2048 canvas
    gt_json = {"annotations": [], "images": [], "categories": []}
    gt_json["images"] = [{
        "id": 0,
        "width": 2048,
        "height": 2048,
        "file_name": "image_0.jpg"
    }]
    gt_json["categories"] = []
    cats2id = {}
    cat_count = 0
    for idx, gt_bbox in enumerate(gt_bbox_list):
        if gt_bbox["label"] not in cats2id:
            cats2id[gt_bbox["label"]] = cat_count
            gt_json["categories"].append({
                "id": cat_count,
                "name": gt_bbox["label"]
            })
            cat_count += 1
        gt_json["annotations"].append({
            "id": idx + 1,
            "image_id": 0,
            "category_id": cats2id[gt_bbox["label"]],
            "bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
            "area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
            "iscrowd": 0
        })
    coco_gt = COCO(gt_json)

    # Build detections; predictions with missing keys or unknown labels are skipped
    dt_json = []
    for idx, pred_bbox in enumerate(pred_bbox_list):
        try:
            dt_json.append({
                "image_id": 0,
                "category_id": cats2id[pred_bbox["label"]],
                "bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
                "score": 1.0,
                "area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
            })
        except Exception:
            pass
    if len(dt_json) == 0:
        return 0.0

    coco_dt = coco_gt.loadRes(dt_json)
    coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    # stats[1] is AP at IoU=0.50 in the standard COCOeval summary
    return coco_eval.stats[1]

def map_reward(content, sol, **kwargs):
    """
    Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes.

    Args:
        content: String containing predicted bounding boxes in JSON format
        sol: String containing ground truth bounding boxes in JSON format

    Returns:
        float: mAP reward score between 0 and 1
    """
    # Extract JSON content between ```json fences
    pattern = r'```json(.*?)```'
    json_match = re.search(pattern, sol, re.DOTALL)
    bbox_json = json_match.group(1).strip() if json_match else None

    # Parse ground truth JSON to get bbox list
    gt_bbox_list = []
    if bbox_json:
        bbox_data = json.loads(bbox_json)
        gt_bbox_list = [item for item in bbox_data]

    # Parse predicted JSON to get bbox list
    pred_bbox_list = []
    json_match = re.search(pattern, content, re.DOTALL)
    if json_match:
        try:
            bbox_data = json.loads(json_match.group(1).strip())
            pred_bbox_list = [item for item in bbox_data]
        except Exception:
            # Fall back to an empty list if JSON parsing fails
            pred_bbox_list = []

    # Calculate mAP if both prediction and ground truth exist
    if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
        bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list)
    else:
        bbox_reward = 0.0

    return bbox_reward

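
# Expected fenced-JSON format for both solution and completion (illustrative; the field names
# are taken from the parsing code above):
#   ```json
#   [{"bbox_2d": [x1, y1, x2, y2], "label": "person"}, ...]
#   ```
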
def numeric_reward(content, sol, **kwargs):
    content = clean_text(content)
    sol = clean_text(sol)
    try:
        content, sol = float(content), float(sol)
        return 1.0 if content == sol else 0.0
    except (ValueError, TypeError):
        return None


def math_reward(content, sol, **kwargs):
    content = clean_text(content)
    sol = clean_text(sol)
    return compute_score(content, sol)

def clean_text(text, exclude_chars=['\n', '\r']):
    # Extract content between <answer> and </answer> if present
    answer_matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
    if answer_matches:
        # Use the last match
        text = answer_matches[-1]

    for char in exclude_chars:
        if char in ['\n', '\r']:
            # If there is a space before the newline, remove the newline
            text = re.sub(r'(?<=\s)' + re.escape(char), '', text)
            # If there is no space before the newline, replace it with a space
            text = re.sub(r'(?<!\s)' + re.escape(char), ' ', text)
        else:
            text = text.replace(char, ' ')

    # Remove leading/trailing spaces, drop a trailing period, and lowercase
    return text.strip().rstrip('.').lower()

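
# Example (illustrative): clean_text("<answer> The answer is 42. </answer>") returns "the answer is 42".
#
# default_accuracy_reward below tries strategies in order: symbolic verification via math_verify,
# then, depending on the ground truth, exact numeric matching, multiple-choice matching, or fuzzy
# string matching (Levenshtein ratio).
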
def default_accuracy_reward(content, sol, **kwargs):
    reward = 0.0

    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_matches[-1].strip() if content_matches else content.strip()

    # Try symbolic verification first for numeric answers
    try:
        answer = parse(student_answer)
        if float(verify(answer, parse(ground_truth))) > 0:
            reward = 1.0
    except Exception:
        pass  # Continue to next verification method if this fails

    # If symbolic verification failed, try string matching or fuzzy matching
    if reward == 0.0:
        try:
            # Check if ground truth contains numbers
            has_numbers = bool(re.search(r'\d', ground_truth))
            # Check if it's a multiple choice question
            has_choices = extract_choice(ground_truth)

            if has_numbers:
                # For numeric answers, use exact matching
                reward = numeric_reward(student_answer, ground_truth)
                if reward is None:
                    reward = ratio(clean_text(student_answer), clean_text(ground_truth))
            elif has_choices:
                # For multiple choice, extract and compare choices
                correct_choice = has_choices.upper()
                student_choice = extract_choice(student_answer)
                if student_choice:
                    reward = 1.0 if student_choice == correct_choice else 0.0
            else:
                # For text answers, use fuzzy matching
                reward = ratio(clean_text(student_answer), clean_text(ground_truth))
        except Exception:
            pass  # Keep reward as 0.0 if all methods fail

    return reward

def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is correct using symbolic verification, exact string matching, or fuzzy matching."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []

    for content, sol, accu_reward_method in zip(contents, solution, kwargs.get("accu_reward_method")):
        # If accu_reward_method is defined, use the corresponding reward function, otherwise use the default reward function
        if accu_reward_method == "mcq":
            reward = mcq_reward(content, sol)
        elif accu_reward_method == 'yes_no':
            reward = yes_no_reward(content, sol)
        elif accu_reward_method == 'llm':
            reward = llm_reward(content, sol)
        elif accu_reward_method == 'map':
            reward = map_reward(content, sol)
        elif accu_reward_method == 'math':
            reward = math_reward(content, sol)
        else:
            reward = default_accuracy_reward(content, sol)
        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
            image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
            problem = kwargs.get("problem")[0]
            if reward <= 1.0:  # this condition can be changed for debug
                with open(log_path, "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                    f.write(f"accu_reward_method: {accu_reward_method}\n")
                    f.write(f"image_path: {image_path}\n")
                    f.write(f"problem: {problem}\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Solution: {sol}\n")

    return rewards

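
# To log per-sample rewards, set DEBUG_MODE=true and LOG_PATH=<path to a .txt file>; accuracy logs
# go to LOG_PATH and format-reward logs go to the same path with ".txt" replaced by "_format.txt".
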
def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]

    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    if os.getenv("DEBUG_MODE") == "true":
        log_path = os.getenv("LOG_PATH")
        with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
            f.write(f"------------- {current_time} Format reward -------------\n")
            for content, match in zip(completion_contents, matches):
                f.write(f"Content: {content}\n")
                f.write(f"Has format: {bool(match)}\n")

    return [1.0 if match else 0.0 for match in matches]

reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

@dataclass
class GRPOModelConfig(ModelConfig):
    freeze_vision_modules: bool = False


SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)

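# Note: SYSTEM_PROMPT is kept here for reference; the question template actually applied to each
# example in main() comes from vlm_module_cls.get_question_template(task_type="default").
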
def get_vlm_module(model_name_or_path):
    if "qwen" in model_name_or_path.lower():
        return Qwen2VLModule
    elif "internvl" in model_name_or_path.lower():
        return InvernVLModule
    else:
        raise ValueError(f"Unsupported model: {model_name_or_path}")

def main(script_args, training_args, model_args):
    # Load the VLM module
    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)
    question_prompt = vlm_module_cls.get_question_template(task_type="default")

    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", reward_funcs)

    # Load the JSONL datasets
    from datasets import Dataset

    data_files = script_args.data_file_paths.split(":")
    image_folders = script_args.image_folders.split(":")
    if len(data_files) != len(image_folders):
        raise ValueError("Number of data files must match number of image folders")

    if script_args.reward_method is None:
        accu_reward_methods = ["default"] * len(data_files)
    else:
        accu_reward_methods = script_args.reward_method.split(":")
        assert len(accu_reward_methods) == len(data_files), f"Number of reward methods must match number of data files: {len(accu_reward_methods)} != {len(data_files)}"
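
    # Each JSONL line is expected to look roughly like this (illustrative; only the fields accessed
    # below are required, and "image" may also be a list of file names for multi-image samples):
    #   {"image": "cat.jpg",
    #    "conversations": [{"from": "human", "value": "<image>How many cats are in the picture?"},
    #                      {"from": "gpt", "value": "2"}],
    #    "accu_reward_method": "default"}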
    all_data = []
    for data_file, image_folder, accu_reward_method in zip(data_files, image_folders, accu_reward_methods):
        with open(data_file, 'r') as f:
            for line in f:
                item = json.loads(line)
                if 'image' in item:
                    if isinstance(item['image'], str):
                        # Store the image path instead of loading the image
                        item['image_path'] = [os.path.join(image_folder, item['image'])]
                        del item['image']  # remove the image column so that it can be loaded later
                    elif isinstance(item['image'], list):
                        # If the image is a list, it is a list of images (for multi-image input)
                        item['image_path'] = [os.path.join(image_folder, image) for image in item['image']]
                        del item['image']  # remove the image column so that it can be loaded later
                    else:
                        raise ValueError(f"Unsupported image type: {type(item['image'])}")

                item['problem'] = item['conversations'][0]['value'].replace('<image>', '')

                # Handle a solution that could be a float or a string
                solution_value = item['conversations'][1]['value']
                if isinstance(solution_value, str):
                    item['solution'] = solution_value.replace('<answer>', '').replace('</answer>', '').strip()
                else:
                    # If it's a float or other non-string type, convert it to a string
                    item['solution'] = str(solution_value)

                del item['conversations']
                # If accu_reward_method is specified in the data JSONL, use that value; otherwise fall back to the per-file setting
                item['accu_reward_method'] = item.get('accu_reward_method', accu_reward_method)
                all_data.append(item)

    dataset = Dataset.from_list(all_data)
    def make_conversation_from_jsonl(example):
        if 'image_path' in example and example['image_path'] is not None:
            # Don't load the image here, just keep the path
            return {
                'image_path': [p for p in example['image_path']],  # Store paths instead of loaded images
                'problem': example['problem'],
                'solution': f"<answer> {example['solution']} </answer>",
                'accu_reward_method': example['accu_reward_method'],
                'prompt': [{
                    'role': 'user',
                    'content': [
                        *({'type': 'image', 'text': None} for _ in range(len(example['image_path']))),
                        {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
                    ]
                }]
            }
        else:
            return {
                'problem': example['problem'],
                'solution': f"<answer> {example['solution']} </answer>",
                'accu_reward_method': example['accu_reward_method'],
                'prompt': [{
                    'role': 'user',
                    'content': [
                        {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
                    ]
                }]
            }

    # Map the conversations
    dataset = dataset.map(make_conversation_from_jsonl, num_proc=8)
    # Split dataset for validation if requested
    splits = {'train': dataset}
    if script_args.val_split_ratio > 0:
        train_val_split = dataset.train_test_split(
            test_size=script_args.val_split_ratio
        )
        splits['train'] = train_val_split['train']
        splits['validation'] = train_val_split['test']

    # Initialize the GRPO trainer
    trainer_cls = VLMGRPOTrainer
    print("using trainer:", trainer_cls.__name__)
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=splits['train'],
        eval_dataset=splits.get('validation') if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    # Train, resuming from the latest checkpoint if one exists in the output directory
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    # Save the model and optionally push to the Hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()

if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)