# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import json
import pathlib
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional, Tuple

import torch
import PIL
from babel.numbers import parse_decimal
from Levenshtein import ratio
from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration
from transformers.utils import logging
from math_verify import parse, verify
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
from openai import OpenAI

from utils.math import compute_score
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
from open_r1.utils.pycocotools.coco import COCO
from open_r1.utils.pycocotools.cocoeval import COCOeval
from open_r1.vlm_modules import *

# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLVisionFlashAttention2,
    apply_rotary_pos_emb_flashatt,
    flash_attn_varlen_func,
)

logger = logging.get_logger(__name__)

client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY", "sk-proj-1234567890"),
    base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
)


def custom_forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    seq_length = hidden_states.shape[0]
    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        cos = emb.cos().float()
        sin = emb.sin().float()
    else:
        cos, sin = position_embeddings
        # Cast cos/sin to float32 so the flash-attention rotary kernel receives the dtype it expects
        cos = cos.to(torch.float)
        sin = sin.to(torch.float)
    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
    q = q.squeeze(0)
    k = k.squeeze(0)

    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
        seq_length, -1
    )
    attn_output = self.proj(attn_output)
    return attn_output


Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
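
# The monkey patch above exists to work around a dtype issue: the only intended change relative
# to the upstream forward is the explicit float32 cast of cos/sin before
# apply_rotary_pos_emb_flashatt (marked by the comment inside the else branch).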


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.
    """

    data_file_paths: str = field(
        default=None,
        metadata={"help": "Paths to data files, separated by ':'"},
    )
    image_folders: str = field(
        default=None,
        metadata={"help": "Paths to image folders, separated by ':'"},
    )
    arrow_cache_dir: str = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
    val_split_ratio: float = field(
        default=0.0,
        metadata={"help": "Ratio of validation split, default 0.0"},
    )
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
    )
    max_anyres_num: Optional[int] = field(
        default=12,
        metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
    )
    reward_method: Optional[str] = field(
        default=None,
        metadata={
            "help": "Accuracy reward method(s), separated by ':'. Possible values: 'default', 'mcq', 'yes_no', 'llm', 'map', 'math'"
        },
    )
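
# Example invocation (hypothetical paths and model id), showing the ':'-separated multi-dataset
# convention parsed by GRPOScriptArguments above:
#   torchrun --nproc_per_node=8 <this_script>.py \
#       --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
#       --data_file_paths /data/chartqa.jsonl:/data/docvqa.jsonl \
#       --image_folders /data/chartqa_images:/data/docvqa_images \
#       --reward_funcs accuracy format \
#       --reward_method default:mcq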


def extract_choice(text):
    # 1. Clean and normalize text
    text = text.upper()  # Convert to uppercase
    text = re.sub(r'\s+', ' ', text)  # Normalize spaces

    # 2. A choice is a single uppercase letter that is not preceded by another letter
    #    and is followed by punctuation or the end of the string
    choices = re.findall(r'(?<![A-Z])([A-Z])(?=[\.\,\?\!\:\;]|$)', text)

    if not choices:
        return None

    # 3. If only one choice, return it directly
    if len(choices) == 1:
        return choices[0]

    # 4. If multiple choices, use heuristic rules
    choice_scores = {choice: 0 for choice in choices}

    # 4.1 Keywords around choices get points
    keywords = [
        '答案', '选择', '正确', '是', '对',
        'answer', 'correct', 'choose', 'select', 'right',
        '认为', '应该', '觉得', 'think', 'believe', 'should'
    ]

    # Get context for each choice (20 chars before and after)
    for choice in choices:
        pos = text.find(choice)
        context = text[max(0, pos - 20):min(len(text), pos + 20)]

        # Add points for keywords
        for keyword in keywords:
            if keyword.upper() in context:
                choice_scores[choice] += 1

        # Add points if choice is near the end (usually final answer)
        if pos > len(text) * 0.7:  # In last 30% of text
            choice_scores[choice] += 2

        # Add points if followed by punctuation
        if pos < len(text) - 1 and text[pos + 1] in '。.!!,,':
            choice_scores[choice] += 1

    # Return highest scoring choice
    return max(choice_scores.items(), key=lambda x: x[1])[0]
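
# Illustrative behaviour of extract_choice (not exhaustive):
#   extract_choice("I think the answer is B.")      -> "B"
#   extract_choice("no option letter appears here") -> None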


def evaluate_answer_similarity(student_answer, ground_truth):
    """Use an LLM to judge whether the student's answer matches the ground truth."""
    try:
        response = client.chat.completions.create(
            model="qwen2.5:7b",
            messages=[
                {
                    "role": "user",
                    "content": "You are an evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed."
                },
                {
                    "role": "user",
                    "content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:"
                }
            ],
            temperature=0,
        )
        result = response.choices[0].message.content.strip()
        return float(result)

    except Exception as e:
        print(f"Error in LLM evaluation: {e}")
        # If the API call fails, fall back to exact string matching
        return 1.0 if student_answer == ground_truth else 0.0
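
# Note: the judge call above assumes an OpenAI-compatible endpoint (configured via
# OPENAI_API_BASE / OPENAI_API_KEY) that serves a model tagged "qwen2.5:7b", e.g. a local
# Ollama or vLLM server; adjust the model name to whatever your endpoint exposes.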


def llm_reward(content, sol, **kwargs):
    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_matches[-1].strip() if content_matches else content.strip()
    return evaluate_answer_similarity(student_answer, ground_truth)


def mcq_reward(content, sol, **kwargs):
    # For multiple choice, extract and compare choices
    has_choices = extract_choice(sol)
    correct_choice = has_choices.upper() if has_choices else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_match.group(1).strip() if content_match else content.strip()
    student_choice = extract_choice(student_answer)
    if student_choice:
        reward = 1.0 if student_choice == correct_choice else 0.0
    else:
        reward = 0.0

    return reward


def yes_no_reward(content, sol, **kwargs):
    content = content.lower()
    sol = sol.lower()

    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_match.group(1).strip() if content_match else content.strip()

    ground_yes_no = re.search(r'(yes|no)', ground_truth)
    ground_yes_no = ground_yes_no.group(1) if ground_yes_no else ''
    student_yes_no = re.search(r'(yes|no)', student_answer)
    student_yes_no = student_yes_no.group(1) if student_yes_no else ''

    reward = 1.0 if ground_yes_no == student_yes_no else 0.0

    return reward
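
# Illustrative example: yes_no_reward("<answer> Yes, it is. </answer>", "yes") -> 1.0
# (only the first "yes"/"no" found on each side is compared)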


def calculate_map(pred_bbox_list, gt_bbox_list):
    # Calculate mAP by wrapping the boxes in a minimal single-image COCO structure

    # Initialize COCO object for ground truth
    gt_json = {"annotations": [], "images": [], "categories": []}
    gt_json["images"] = [{
        "id": 0,
        "width": 2048,
        "height": 2048,
        "file_name": "image_0.jpg"
    }]

    gt_json["categories"] = []

    cats2id = {}
    cat_count = 0
    for idx, gt_bbox in enumerate(gt_bbox_list):
        if gt_bbox["label"] not in cats2id:
            cats2id[gt_bbox["label"]] = cat_count
            gt_json["categories"].append({
                "id": cat_count,
                "name": gt_bbox["label"]
            })
            cat_count += 1

        gt_json["annotations"].append({
            "id": idx + 1,
            "image_id": 0,
            "category_id": cats2id[gt_bbox["label"]],
            "bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
            "area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
            "iscrowd": 0
        })
    coco_gt = COCO(gt_json)

    dt_json = []
    for idx, pred_bbox in enumerate(pred_bbox_list):
        try:
            dt_json.append({
                "image_id": 0,
                "category_id": cats2id[pred_bbox["label"]],
                "bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
                "score": 1.0,
                "area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
            })
        except Exception:
            # Skip predictions with unknown labels or malformed boxes
            pass

    if len(dt_json) == 0:
        return 0.0

    coco_dt = coco_gt.loadRes(dt_json)
    coco_eval = COCOeval(coco_gt, coco_dt, "bbox")

    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    # stats[1] is AP at IoU=0.50
    return coco_eval.stats[1]


def map_reward(content, sol, **kwargs):
    """
    Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes.

    Args:
        content: String containing predicted bounding boxes in JSON format
        sol: String containing ground truth bounding boxes in JSON format

    Returns:
        float: mAP reward score between 0 and 1
    """
    # Extract JSON content between ```json fences
    pattern = r'```json(.*?)```'
    json_match = re.search(pattern, sol, re.DOTALL)
    bbox_json = json_match.group(1).strip() if json_match else None

    # Parse ground truth JSON to get bbox list
    gt_bbox_list = []
    if bbox_json:
        bbox_data = json.loads(bbox_json)
        gt_bbox_list = [item for item in bbox_data]

    # Parse predicted JSON to get bbox list
    pred_bbox_list = []
    json_match = re.search(pattern, content, re.DOTALL)
    if json_match:
        try:
            bbox_data = json.loads(json_match.group(1).strip())
            pred_bbox_list = [item for item in bbox_data]
        except Exception:
            # Return empty list if JSON parsing fails
            pred_bbox_list = []

    # Calculate mAP if both prediction and ground truth exist
    if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
        bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list)
    else:
        bbox_reward = 0.0

    return bbox_reward
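
# calculate_map / map_reward expect each box entry inside the ```json fences to look like
# (illustrative): {"bbox_2d": [x1, y1, x2, y2], "label": "category name"}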


def numeric_reward(content, sol, **kwargs):
    content = clean_text(content)
    sol = clean_text(sol)
    try:
        content, sol = float(content), float(sol)
        return 1.0 if content == sol else 0.0
    except (ValueError, TypeError):
        return None


def math_reward(content, sol, **kwargs):
    content = clean_text(content)
    sol = clean_text(sol)
    return compute_score(content, sol)


def clean_text(text, exclude_chars=['\n', '\r']):
    # Extract content between <answer> and </answer> if present
    answer_matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
    if answer_matches:
        # Use the last match
        text = answer_matches[-1]

    for char in exclude_chars:
        if char in ['\n', '\r']:
            # If there is a space before the newline, remove the newline
            text = re.sub(r'(?<=\s)' + re.escape(char), '', text)
            # If there is no space before the newline, replace it with a space
            text = re.sub(r'(?<!\s)' + re.escape(char), ' ', text)
        else:
            text = text.replace(char, ' ')

    # Remove leading/trailing spaces, drop a trailing period, and convert to lowercase
    return text.strip().rstrip('.').lower()
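
# Illustrative example: clean_text("<answer> The Answer.\n</answer>") -> "the answer"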


def default_accuracy_reward(content, sol, **kwargs):
    reward = 0.0
    # Extract answer from solution if it has think/answer tags
    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

    # Extract answer from content if it has think/answer tags
    content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
    student_answer = content_matches[-1].strip() if content_matches else content.strip()

    # Try symbolic verification first for numeric answers
    try:
        answer = parse(student_answer)
        if float(verify(answer, parse(ground_truth))) > 0:
            reward = 1.0
    except Exception:
        pass  # Continue to next verification method if this fails

    # If symbolic verification failed, try string matching or fuzzy matching
    if reward == 0.0:
        try:
            # Check if ground truth contains numbers
            has_numbers = bool(re.search(r'\d', ground_truth))
            # Check if it's a multiple choice question
            has_choices = extract_choice(ground_truth)

            if has_numbers:
                # For numeric answers, use exact matching
                reward = numeric_reward(student_answer, ground_truth)
                if reward is None:
                    reward = ratio(clean_text(student_answer), clean_text(ground_truth))
            elif has_choices:
                # For multiple choice, extract and compare choices
                correct_choice = has_choices.upper()
                student_choice = extract_choice(student_answer)
                if student_choice:
                    reward = 1.0 if student_choice == correct_choice else 0.0
            else:
                # For text answers, use fuzzy matching
                reward = ratio(clean_text(student_answer), clean_text(ground_truth))
        except Exception:
            pass  # Keep reward as 0.0 if all methods fail

    return reward
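
# default_accuracy_reward falls back through: math_verify symbolic check -> exact numeric match
# -> multiple-choice letter match -> Levenshtein ratio fuzzy matching on the cleaned strings.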


def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is correct using symbolic verification, exact string matching, or fuzzy matching."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol, accu_reward_method in zip(contents, solution, kwargs.get("accu_reward_method")):
        # If accu_reward_method is defined, use the corresponding reward function, otherwise use the default reward function
        if accu_reward_method == "mcq":
            reward = mcq_reward(content, sol)
        elif accu_reward_method == 'yes_no':
            reward = yes_no_reward(content, sol)
        elif accu_reward_method == 'llm':
            reward = llm_reward(content, sol)
        elif accu_reward_method == 'map':
            reward = map_reward(content, sol)
        elif accu_reward_method == 'math':
            reward = math_reward(content, sol)
        else:
            reward = default_accuracy_reward(content, sol)
        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
            image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
            problem = kwargs.get("problem")[0]
            if reward <= 1.0:  # this condition can be changed for debug
                with open(log_path, "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                    f.write(f"accu_reward_method: {accu_reward_method}\n")
                    f.write(f"image_path: {image_path}\n")
                    f.write(f"problem: {problem}\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Solution: {sol}\n")

    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]

    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    if os.getenv("DEBUG_MODE") == "true":
        log_path = os.getenv("LOG_PATH")
        with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
            f.write(f"------------- {current_time} Format reward -------------\n")
            for content, match in zip(completion_contents, matches):
                f.write(f"Content: {content}\n")
                f.write(f"Has format: {bool(match)}\n")

    return [1.0 if match else 0.0 for match in matches]
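
# format_reward returns 1.0 only when the whole completion matches the pattern, e.g. (illustrative):
#   "<think> the image shows two cats </think><answer> 2 </answer>"
# Debug logs for both reward functions can be enabled with, e.g., DEBUG_MODE=true and
# LOG_PATH=./reward_debug.txt (hypothetical path).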


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}


@dataclass
class GRPOModelConfig(ModelConfig):
    freeze_vision_modules: bool = False


SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def get_vlm_module(model_name_or_path):
    if "qwen" in model_name_or_path.lower():
        return Qwen2VLModule
    elif "internvl" in model_name_or_path.lower():
        return InvernVLModule
    else:
        raise ValueError(f"Unsupported model: {model_name_or_path}")


def main(script_args, training_args, model_args):
    # Load the VLM module
    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)
    question_prompt = vlm_module_cls.get_question_template(task_type="default")

    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", reward_funcs)

    # Load the JSONL datasets
    from datasets import Dataset

    data_files = script_args.data_file_paths.split(":")
    image_folders = script_args.image_folders.split(":")

    if len(data_files) != len(image_folders):
        raise ValueError("Number of data files must match number of image folders")

    if script_args.reward_method is None:
        accu_reward_methods = ["default"] * len(data_files)
    else:
        accu_reward_methods = script_args.reward_method.split(":")
        assert len(accu_reward_methods) == len(data_files), f"Number of reward methods must match number of data files: {len(accu_reward_methods)} != {len(data_files)}"
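
    # Each JSONL line is expected to look roughly like (illustrative; only the fields read
    # below are required, and "accu_reward_method" is optional, overriding the per-file setting):
    #   {"image": "0001.jpg",
    #    "conversations": [{"value": "<image>How many cats are in the picture?"},
    #                      {"value": "2"}],
    #    "accu_reward_method": "default"}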
    all_data = []
    for data_file, image_folder, accu_reward_method in zip(data_files, image_folders, accu_reward_methods):
        with open(data_file, 'r') as f:
            for line in f:
                item = json.loads(line)
                if 'image' in item:
                    if isinstance(item['image'], str):
                        # Store the image path instead of loading the image
                        item['image_path'] = [os.path.join(image_folder, item['image'])]
                        del item['image']  # remove the image column so that it can be loaded later
                    elif isinstance(item['image'], list):
                        # If the image is a list, it is a list of images (for multi-image input)
                        item['image_path'] = [os.path.join(image_folder, image) for image in item['image']]
                        del item['image']  # remove the image column so that it can be loaded later
                    else:
                        raise ValueError(f"Unsupported image type: {type(item['image'])}")
                item['problem'] = item['conversations'][0]['value'].replace('<image>', '')

                # Handle solution that could be a float or string
                solution_value = item['conversations'][1]['value']
                if isinstance(solution_value, str):
                    item['solution'] = solution_value.replace('<answer>', '').replace('</answer>', '').strip()
                else:
                    # If it's a float or other non-string type, keep it as is
                    item['solution'] = str(solution_value)

                del item['conversations']
                # If accu_reward_method is present in the data JSONL, use that value; otherwise use the per-file default
                item['accu_reward_method'] = item.get('accu_reward_method', accu_reward_method)
                all_data.append(item)

    dataset = Dataset.from_list(all_data)

    def make_conversation_from_jsonl(example):
        if 'image_path' in example and example['image_path'] is not None:
            # Don't load the image here, just store the path
            return {
                'image_path': [p for p in example['image_path']],  # store paths instead of loaded images
                'problem': example['problem'],
                'solution': f"<answer> {example['solution']} </answer>",
                'accu_reward_method': example['accu_reward_method'],
                'prompt': [{
                    'role': 'user',
                    'content': [
                        *({'type': 'image', 'text': None} for _ in range(len(example['image_path']))),
                        {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
                    ]
                }]
            }
        else:
            return {
                'problem': example['problem'],
                'solution': f"<answer> {example['solution']} </answer>",
                'accu_reward_method': example['accu_reward_method'],
                'prompt': [{
                    'role': 'user',
                    'content': [
                        {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
                    ]
                }]
            }

    # Map the conversations
    dataset = dataset.map(make_conversation_from_jsonl, num_proc=8)

    # Split dataset for validation if requested
    splits = {'train': dataset}
    if script_args.val_split_ratio > 0:
        train_val_split = dataset.train_test_split(
            test_size=script_args.val_split_ratio
        )
        splits['train'] = train_val_split['train']
        splits['validation'] = train_val_split['test']

    # Select the trainer class
    trainer_cls = VLMGRPOTrainer
    print("using trainer:", trainer_cls.__name__)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=splits['train'],
        eval_dataset=splits.get('validation') if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    # Train the model, resuming from the latest checkpoint if one exists
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    # Save and optionally push to the Hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)