qw和gemma3 grpo

This commit is contained in:
Zixiao Wang 2025-03-31 15:56:36 +08:00
parent e77007c5a8
commit 6d01dcc49a
35 changed files with 6449 additions and 0 deletions

159
gemma3/test.py Normal file
View File

@ -0,0 +1,159 @@
# pip install accelerate
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch
import base64
from io import BytesIO
import pickle
from torch.utils.data import Dataset, DataLoader
import logging
import re
from tqdm import tqdm
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler()
]
)
model_id = "gemma-3-12b"#"outputs/checkpoint-400"
# 加载模型
logging.info("加载模型中...")
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto"
).eval()
processor = AutoProcessor.from_pretrained(model_id)
logging.info("模型加载完成")
def image_to_base64(image):
# 将PIL Image对象转换为base64
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return f"data:image/jpeg;base64,{img_str}"
class TestDataset(Dataset):
def __init__(self, data_path):
with open(data_path, 'rb') as f:
self.data = pickle.load(f)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
image_base64 = image_to_base64(item['image'])
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}]
},
{
"role": "user",
"content": [
{"type": "image", "image": image_base64},
{"type": "text", "text": "Please tell me the brand of the product in the picture between labels <answer/> and </answer> and explain the reason between labels <thinking/> and </thinking>"}
]
}
]
return {
"messages": messages,
"correct_brand": item['brand']
}
# 加载测试数据
logging.info("加载测试数据...")
test_dataset = TestDataset('../work/bal_data/test_data.pkl')
total_samples = len(test_dataset)
logging.info(f"加载了 {total_samples} 条测试数据")
def evaluate_prediction(prediction, correct_brand):
answer_match = re.search(r'<answer/>(.*?)</answer>', prediction)
if answer_match:
predicted_brand = answer_match.group(1).strip().lower()
return correct_brand.lower() in predicted_brand
return False
# 进行测试
batch_size = 25
correct_count = 0
processed_count = 0
# 计算总批次数
total_batches = (total_samples + batch_size - 1) // batch_size
logging.info("开始测试...")
progress_bar = tqdm(
total=total_samples,
desc="测试进度",
ncols=100, # 进度条总长度
bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
)
with torch.inference_mode():
for i in range(0, total_samples, batch_size):
# 准备batch数据
batch_messages = []
batch_correct_brands = []
# 获取当前batch的数据
end_idx = min(i + batch_size, total_samples)
current_batch_size = end_idx - i
for idx in range(i, end_idx):
sample = test_dataset[idx]
batch_messages.append(sample['messages'])
batch_correct_brands.append(sample['correct_brand'])
# 处理当前batch
inputs = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
generation = model.generate(**inputs, max_new_tokens=300, do_sample=False)
generation = generation[:,input_len:]
predictions = processor.batch_decode(generation, skip_special_tokens=True)
# 评估结果
batch_correct = 0
for pred, correct in zip(predictions, batch_correct_brands):
if evaluate_prediction(pred, correct):
correct_count += 1
batch_correct += 1
processed_count += current_batch_size
# 更新进度条
progress_bar.update(current_batch_size)
progress_bar.set_postfix({
'acc': f'{correct_count/processed_count:.4f}',
'correct': f'{correct_count}/{processed_count}'
})
progress_bar.close()
# 输出最终结果
final_accuracy = correct_count / total_samples
logging.info("\n测试完成!")
logging.info(f"总样本数: {total_samples}")
logging.info(f"正确预测数: {correct_count}")
logging.info(f"最终准确率: {final_accuracy:.4f}")
# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.

238
gemma3/train.py Normal file
View File

@ -0,0 +1,238 @@
from unsloth import FastModel
import torch
from PIL import Image
import pickle
from torch.utils.data import Dataset
import base64
from io import BytesIO
import logging
from trl import GRPOConfig, GRPOTrainer
import re
# 在文件开头添加日志配置
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('training.log'),
logging.StreamHandler()
]
)
max_seq_length = 1024
# 在模型加载后添加日志
logging.info("正在加载模型和tokenizer...")
model, tokenizer = FastModel.from_pretrained(
model_name = "gemma-3-4b",
max_seq_length = max_seq_length, # Choose any for long context!
load_in_4bit = False, # 4 bit quantization to reduce memory
load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
full_finetuning = False, # [NEW!] We have full finetuning now!
)
logging.info("模型加载完成")
logging.info("正在配置PEFT参数...")
model = FastModel.get_peft_model(
model,
finetune_vision_layers = True, # Turn off for just text!
finetune_language_layers = True, # Should leave on!
finetune_attention_modules = True, # Attention good for GRPO
finetune_mlp_modules = True, # SHould leave on always!
r = 8, # Larger = higher accuracy, but might overfit
lora_alpha = 8, # Recommended alpha == r at least
lora_dropout = 0,
bias = "none",
random_state = 3407,
)
logging.info("PEFT模型配置完成")
def image_to_base64(image):
# 将PIL Image对象转换为base64
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return f"data:image/jpeg;base64,{img_str}"
class ChatDataset(Dataset):
def __init__(self, data_path):
with open(data_path, 'rb') as f:
self.data = pickle.load(f)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
# 将图片转换为base64格式
image_base64 = image_to_base64(item['image'])
# 使用dra.py的messages格式
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}]
},
{
"role": "user",
"content": [
{"type": "image", "image": image_base64},
{"type": "text", "text": "Please tell me the brand of the product in the picture between labels <answer/> and </answer> and explain the reason between labels <thinking/> and </thinking>"}
]
}
]
return {
"prompt": messages, # 包含了图片和提示文本的完整模板
"correct_brand": item['brand']
}
# 加载数据集
logging.info("加载训练数据...")
train_dataset = ChatDataset('../work/bal_data/frequent_brands_data.pkl')
logging.info(f"加载了 {len(train_dataset)} 条训练数据")
# 使用示例:
# batch = next(iter(train_dataset))
# messages = batch["messages"]
# input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# inputs = tokenizer(input_text, add_special_tokens=False, return_tensors="pt")
# messages = [
# {
# "role": "system",
# "content": [{"type": "text", "text": "You are a helpful assistant."}]
# },
# {
# "role": "user",
# "content": [
# {"type": "image"},
# {"type": "text", "text": "Describe this image in detail."}
# ]
# }
# ]
# input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
# image = Image.open("../work/Tesla.jpg")
# inputs = tokenizer(
# image,
# input_text,
# add_special_tokens = False,
# return_tensors = "pt",
# ).to(model.device, dtype=torch.bfloat16)
# input_len = inputs["input_ids"].shape[-1]
# with torch.inference_mode():
# generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
# generation = generation[0][input_len:]
# decoded = tokenizer.decode(generation, skip_special_tokens=True)
# print(decoded)
def reward_func(prompts, completions, **kwargs):
logging.info("开始计算奖励值...")
rewards = []
correct_brands = kwargs.get('correct_brand')
for idx, (completion, correct_brand) in enumerate(zip(completions, correct_brands)):
reward = 0.0
# 确保 completion 是字符串类型
try:
if isinstance(completion, (list, tuple)):
completion = completion[0] # 如果是列表或元组,取第一个元素
completion = str(completion) # 转换为字符串
logging.debug(f"样本 {idx + 1}:")
logging.debug(f"完整回答类型: {type(completion)}")
logging.debug(f"完整回答: {completion}")
logging.debug(f"正确品牌: {correct_brand}")
answer_match = re.search(r'<answer/>(.*?)</answer>', completion)
thinking_match = re.search(r'<thinking/>(.*?)</thinking>', completion)
# 答案部分仍然检查品牌是否正确
if answer_match:
answer_content = answer_match.group(1).lower()
if correct_brand.lower() in answer_content:
reward += 1.0
logging.debug("答案部分匹配正确 (+1.0)")
# 推理部分根据长度评分
if thinking_match:
thinking_content = thinking_match.group(1).strip()
content_length = len(thinking_content)
if content_length < 50:
thinking_reward = 0.25
level = "简单"
elif content_length < 100:
thinking_reward = 0.5
level = "基础"
elif content_length < 150:
thinking_reward = 0.75
level = "详细"
else:
thinking_reward = 1.0
level = "非常详细"
reward += thinking_reward
logging.debug(f"推理部分长度: {content_length} 字符")
logging.debug(f"推理详细程度: {level}")
logging.debug(f"推理部分得分: +{thinking_reward}")
except Exception as e:
logging.error(f"处理样本 {idx} 时发生错误: {str(e)}")
logging.error(f"completion 类型: {type(completion)}")
logging.error(f"completion 内容: {completion}")
reward = 0.0 # 发生错误时给出0分
logging.debug(f"最终奖励值: {reward}\n")
rewards.append(reward)
batch_avg = sum(rewards)/len(rewards) if rewards else 0
logging.info(f"批次平均奖励值: {batch_avg:.3f}")
return rewards
max_prompt_length = 256
training_args = GRPOConfig(
learning_rate = 5e-6,
adam_beta1 = 0.9,
adam_beta2 = 0.99,
weight_decay = 0.1,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
optim = "adamw_torch_fused",
logging_steps = 1,
per_device_train_batch_size = 1,
gradient_accumulation_steps = 4, # Increase to 4 for smoother training
num_generations = 2, # Decrease if out of memory
max_completion_length = 512,
# num_train_epochs = 1, # Set to 1 for a full training run
max_steps = 400,
save_steps = 200,
max_grad_norm = 0.1,
report_to = "none", # Can use Weights & Biases
output_dir = "outputs",
)
# 在训练开始前添加配置信息日志
logging.info("训练配置信息:")
logging.info(f"学习率: {training_args.learning_rate}")
logging.info(f"批次大小: {training_args.per_device_train_batch_size}")
logging.info(f"梯度累积步数: {training_args.gradient_accumulation_steps}")
logging.info(f"最大训练步数: {training_args.max_steps}")
trainer = GRPOTrainer(
model = model,
processing_class = tokenizer,
reward_funcs = reward_func,
args = training_args,
train_dataset = train_dataset,
)
trainer.train()

0
qw/open_r1/__init__.py Executable file
View File

Binary file not shown.

82
qw/open_r1/configs.py Executable file
View File

@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
import trl
# TODO: add the shared options with a mixin to reduce code duplication
@dataclass
class GRPOConfig(trl.GRPOConfig):
"""
args for callbacks, benchmarks etc
"""
benchmarks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
)
callbacks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
)
system_prompt: Optional[str] = field(
default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
)
hub_model_revision: Optional[str] = field(
default="main", metadata={"help": "The Hub model branch to push the model to."}
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": ("The project to store runs under.")},
)
@dataclass
class SFTConfig(trl.SFTConfig):
"""
args for callbacks, benchmarks etc
"""
benchmarks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
)
callbacks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
)
system_prompt: Optional[str] = field(
default=None,
metadata={"help": "The optional system prompt to use for benchmarking."},
)
hub_model_revision: Optional[str] = field(
default="main",
metadata={"help": "The Hub model branch to push the model to."},
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": ("The project to store runs under.")},
)

85
qw/open_r1/evaluate.py Executable file
View File

@ -0,0 +1,85 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom evaluation tasks for LightEval."""
from lighteval.metrics.dynamic_metrics import (
ExprExtractionConfig,
LatexExtractionConfig,
multilingual_extractive_match_metric,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
metric = multilingual_extractive_match_metric(
language=Language.ENGLISH,
fallback_mode="first_match",
precision=5,
gold_extraction_target=(LatexExtractionConfig(),),
pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
aggregation_function=max,
)
def prompt_fn(line, task_name: str = None):
"""Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
return Doc(
task_name=task_name,
query=line["problem"],
choices=[line["solution"]],
gold_index=0,
)
# Define tasks
aime24 = LightevalTaskConfig(
name="aime24",
suite=["custom"],
prompt_function=prompt_fn,
hf_repo="HuggingFaceH4/aime_2024",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[metric],
version=1,
)
math_500 = LightevalTaskConfig(
name="math_500",
suite=["custom"],
prompt_function=prompt_fn,
hf_repo="HuggingFaceH4/MATH-500",
hf_subset="default",
hf_avail_splits=["test"],
evaluation_splits=["test"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[metric],
version=1,
)
# Add tasks to the table
TASKS_TABLE = []
TASKS_TABLE.append(aime24)
TASKS_TABLE.append(math_500)
# MODULE LOGIC
if __name__ == "__main__":
print([t["name"] for t in TASKS_TABLE])
print(len(TASKS_TABLE))

156
qw/open_r1/generate.py Executable file
View File

@ -0,0 +1,156 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
def build_distilabel_pipeline(
model: str,
base_url: str = "http://localhost:8000/v1",
prompt_column: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_new_tokens: int = 8192,
num_generations: int = 1,
) -> Pipeline:
generation_kwargs = {"max_new_tokens": max_new_tokens}
if temperature is not None:
generation_kwargs["temperature"] = temperature
if top_p is not None:
generation_kwargs["top_p"] = top_p
with Pipeline().ray() as pipeline:
TextGeneration(
llm=OpenAILLM(
base_url=base_url,
api_key="something",
model=model,
# thinking can take some time...
timeout=10 * 60,
generation_kwargs=generation_kwargs,
),
input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
num_generations=num_generations,
)
return pipeline
if __name__ == "__main__":
import argparse
from datasets import load_dataset
parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
parser.add_argument(
"--hf-dataset",
type=str,
required=True,
help="HuggingFace dataset to load",
)
parser.add_argument(
"--hf-dataset-config",
type=str,
required=False,
help="Dataset config to use",
)
parser.add_argument(
"--hf-dataset-split",
type=str,
default="train",
help="Dataset split to use",
)
parser.add_argument("--prompt-column", type=str, default="prompt")
parser.add_argument(
"--model",
type=str,
required=True,
help="Model name to use for generation",
)
parser.add_argument(
"--vllm-server-url",
type=str,
default="http://localhost:8000/v1",
help="URL of the vLLM server",
)
parser.add_argument(
"--temperature",
type=float,
help="Temperature for generation",
)
parser.add_argument(
"--top-p",
type=float,
help="Top-p value for generation",
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=8192,
help="Maximum number of new tokens to generate",
)
parser.add_argument(
"--num-generations",
type=int,
default=1,
help="Number of generations per problem",
)
parser.add_argument(
"--hf-output-dataset",
type=str,
required=False,
help="HuggingFace repo to push results to",
)
parser.add_argument(
"--private",
action="store_true",
help="Whether to make the output dataset private when pushing to HF Hub",
)
args = parser.parse_args()
print("\nRunning with arguments:")
for arg, value in vars(args).items():
print(f" {arg}: {value}")
print()
print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split)
print("Dataset loaded!")
pipeline = build_distilabel_pipeline(
model=args.model,
base_url=args.vllm_server_url,
prompt_column=args.prompt_column,
temperature=args.temperature,
top_p=args.top_p,
max_new_tokens=args.max_new_tokens,
num_generations=args.num_generations,
)
print("Running generation pipeline...")
distiset = pipeline.run(dataset=dataset, use_cache=False)
print("Generation pipeline finished!")
if args.hf_output_dataset:
print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
distiset.push_to_hub(args.hf_output_dataset, private=args.private)
print("Dataset pushed!")

214
qw/open_r1/grpo.py Executable file
View File

@ -0,0 +1,214 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import debugpy
# try:
# # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
# debugpy.listen(("localhost", 9501))
# print("Waiting for debugger attach")
# debugpy.wait_for_client()
# except Exception as e:
# pass
import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration
from math_verify import parse, verify
from open_r1.trainer import VLMGRPOTrainer
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
@dataclass
class GRPOScriptArguments(ScriptArguments):
"""
Script arguments for the GRPO training script.
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format'.
"""
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
)
max_pixels: Optional[int] = field(
default=12845056,
metadata={"help": "Maximum number of pixels for the image"},
)
min_pixels: Optional[int] = field(
default=3136,
metadata={"help": "Minimum number of pixels for the image"},
)
def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is correct using either symbolic verification or exact string matching."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
for content, sol in zip(contents, solution):
reward = 0.0
# Try symbolic verification first
try:
answer = parse(content)
if float(verify(answer, parse(sol))) > 0:
reward = 1.0
except Exception:
pass # Continue to next verification method if this fails
# If symbolic verification failed, try string matching
if reward == 0.0:
try:
# Extract answer from solution if it has think/answer tags
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
# Extract answer from content if it has think/answer tags
content_match = re.search(r'<answer>(.*?)</answer>', content)
student_answer = content_match.group(1).strip() if content_match else content.strip()
# Compare the extracted answers
if student_answer == ground_truth:
reward = 1.0
except Exception:
pass # Keep reward as 0.0 if both methods fail
rewards.append(reward)
if os.getenv("DEBUG_MODE") == "true":
log_path = os.getenv("LOG_PATH")
# local_rank = int(os.getenv("LOCAL_RANK", 0))
with open(log_path, "a") as f:
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
f.write(f"Content: {content}\n")
f.write(f"Solution: {sol}\n")
return rewards
def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
reward_funcs_registry = {
"accuracy": accuracy_reward,
"format": format_reward,
}
SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
)
def main(script_args, training_args, model_args):
# Get reward functions
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
print("reward_funcs:", reward_funcs)
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
# Format into conversation
def make_conversation(example):
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
}
# def make_conversation_image(example):
# return {
# "prompt": [
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
# {
# "role": "user",
# "content": [
# {"type": "image"},
# {"type": "text", "text": example["problem"]},
# ],
# },
# ],
# }
QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
def make_conversation_image(example):
return {
"prompt": [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
],
},
],
}
if "image" in dataset[script_args.dataset_train_split].features:
print("has image in dataset")
dataset = dataset.map(make_conversation_image) # Utilize multiprocessing for faster mapping
# dataset = dataset.remove_columns(["original_question", "original_answer"])
else:
print("no image in dataset")
dataset = dataset.map(make_conversation)
dataset = dataset.remove_columns("messages")
trainer_cls = VLMGRPOTrainer
# Initialize the GRPO trainer
trainer = trainer_cls(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
attn_implementation=model_args.attn_implementation,
max_pixels=script_args.max_pixels,
min_pixels=script_args.min_pixels,
torch_dtype=model_args.torch_dtype,
)
# Train and push the model to the Hub
trainer.train()
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
if __name__ == "__main__":
parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)

649
qw/open_r1/grpo_jsonl.py Normal file
View File

@ -0,0 +1,649 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import pathlib
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional
from babel.numbers import parse_decimal
from utils.math import compute_score
from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration
from math_verify import parse, verify
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
import PIL
from Levenshtein import ratio
from open_r1.utils.pycocotools.coco import COCO
from open_r1.utils.pycocotools.cocoeval import COCOeval
import json
from open_r1.vlm_modules import *
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
import torch
from typing import Tuple
from transformers.utils import logging
from openai import OpenAI
logger = logging.get_logger(__name__)
client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY", "sk-proj-1234567890"),
base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
)
def custom_forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
# print(111, 222, 333, 444, 555, 666, 777, 888, 999)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos().float()
sin = emb.sin().float()
else:
cos, sin = position_embeddings
# Add this
cos = cos.to(torch.float)
sin = sin.to(torch.float)
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
q = q.squeeze(0)
k = k.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
)
attn_output = self.proj(attn_output)
return attn_output
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
@dataclass
class GRPOScriptArguments(ScriptArguments):
"""
Script arguments for the GRPO training script.
"""
data_file_paths: str = field(
default=None,
metadata={"help": "Paths to data files, separated by ':'"},
)
image_folders: str = field(
default=None,
metadata={"help": "Paths to image folders, separated by ':'"},
)
arrow_cache_dir: str = field(
default=None,
metadata={"help": "Path to arrow cache directory"},
)
val_split_ratio: float = field(
default=0.0,
metadata={"help": "Ratio of validation split, default 0.0"},
)
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
)
max_pixels: Optional[int] = field(
default=12845056,
metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
)
min_pixels: Optional[int] = field(
default=3136,
metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
)
max_anyres_num: Optional[int] = field(
default=12,
metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
)
reward_method: Optional[str] = field(
default=None,
metadata={
"help": "Choose reward method: 'default', 'mcp', ..."
},
)
def extract_choice(text):
# 1. Clean and normalize text
text = text.upper() # Convert to uppercase
text = re.sub(r'\s+', ' ', text) # Normalize spaces
# 2. Choice should not have uppercase letters before or after
choices = re.findall(r'(?<![A-Z])([A-Z])(?=[\.\,\?\!\:\;]|$)', text)
if not choices:
return None
# 3. If only one choice, return it directly
if len(choices) == 1:
return choices[0]
# 4. If multiple choices, use heuristic rules
choice_scores = {choice: 0 for choice in choices}
# 4.1 Keywords around choices get points
keywords = [
'答案', '选择', '正确', '', '',
'answer', 'correct', 'choose', 'select', 'right',
'认为', '应该', '觉得', 'think', 'believe', 'should'
]
# Get context for each choice (20 chars before and after)
for choice in choices:
pos = text.find(choice)
context = text[max(0, pos-20):min(len(text), pos+20)]
# Add points for keywords
for keyword in keywords:
if keyword.upper() in context:
choice_scores[choice] += 1
# Add points if choice is near the end (usually final answer)
if pos > len(text) * 0.7: # In last 30% of text
choice_scores[choice] += 2
# Add points if followed by punctuation
if pos < len(text) - 1 and text[pos+1] in '。.!,':
choice_scores[choice] += 1
# Return highest scoring choice
return max(choice_scores.items(), key=lambda x: x[1])[0]
def evaluate_answer_similarity(student_answer, ground_truth):
"""Use llm to evaluate answer similarity."""
try:
response = client.chat.completions.create(
model="qwen2.5:7b",
messages=[
{
"role": "user",
"content": "You are a evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed."
},
{
"role": "user",
"content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:"
}
],
temperature=0
)
result = response.choices[0].message.content.strip()
return float(result)
except Exception as e:
print(f"Error in GPT evaluation: {e}")
# If API call fails, fall back to simple text matching
return 1.0 if student_answer ==ground_truth else 0.0
def llm_reward(content, sol, **kwargs):
# Extract answer from content if it has think/answer tags
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
# Extract answer from content if it has think/answer tags
content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
student_answer = content_matches[-1].strip() if content_matches else content.strip()
return evaluate_answer_similarity(student_answer, ground_truth)
def mcq_reward(content, sol, **kwargs):
# For multiple choice, extract and compare choices
has_choices = extract_choice(sol)
correct_choice = has_choices.upper() if has_choices else sol.strip()
# Extract answer from content if it has think/answer tags
content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
student_answer = content_match.group(1).strip() if content_match else content.strip()
student_choice = extract_choice(student_answer)
if student_choice:
reward = 1.0 if student_choice == correct_choice else 0.0
else:
reward = 0.0
return reward
def yes_no_reward(content, sol, **kwargs):
content = content.lower()
sol = sol.lower()
# Extract answer from solution if it has think/answer tags
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
# Extract answer from content if it has think/answer tags
content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
student_answer = content_match.group(1).strip() if content_match else content.strip()
ground_yes_no = re.search(r'(yes|no)', ground_truth)
ground_yes_no = ground_yes_no.group(1) if ground_yes_no else ''
student_yes_no = re.search(r'(yes|no)', student_answer)
student_yes_no = student_yes_no.group(1) if student_yes_no else ''
reward = 1.0 if ground_yes_no == student_yes_no else 0.0
return reward
def calculate_map(pred_bbox_list, gt_bbox_list):
# Calculate mAP
# Initialize COCO object for ground truth
gt_json = {"annotations": [], "images": [], "categories": []}
gt_json["images"] = [{
"id": 0,
"width": 2048,
"height": 2048,
"file_name": "image_0.jpg"
}]
gt_json["categories"] = []
cats2id = {}
cat_count = 0
for idx, gt_bbox in enumerate(gt_bbox_list):
if gt_bbox["label"] not in cats2id:
cats2id[gt_bbox["label"]] = cat_count
gt_json["categories"].append({
"id": cat_count,
"name": gt_bbox["label"]
})
cat_count += 1
gt_json["annotations"].append({
"id": idx+1,
"image_id": 0,
"category_id": cats2id[gt_bbox["label"]],
"bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
"area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
"iscrowd": 0
})
coco_gt = COCO(gt_json)
dt_json = []
for idx, pred_bbox in enumerate(pred_bbox_list):
try:
dt_json.append({
"image_id": 0,
"category_id": cats2id[pred_bbox["label"]],
"bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
"score": 1.0,
"area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
})
except:
pass
if len(dt_json) == 0:
return 0.0
coco_dt = coco_gt.loadRes(dt_json)
coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
return coco_eval.stats[1]
def map_reward(content, sol, **kwargs):
"""
Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes
Args:
content: String containing predicted bounding boxes in JSON format
sol: String containing ground truth bounding boxes in JSON format
Returns:
float: mAP reward score between 0 and 1
"""
# Extract JSON content between ```json tags
pattern = r'```json(.*?)```'
json_match = re.search(pattern, sol, re.DOTALL)
bbox_json = json_match.group(1).strip() if json_match else None
# Parse ground truth JSON to get bbox list
gt_bbox_list = []
if bbox_json:
bbox_data = json.loads(bbox_json)
gt_bbox_list = [item for item in bbox_data]
# Parse predicted JSON to get bbox list
pred_bbox_list = []
json_match = re.search(pattern, content, re.DOTALL)
if json_match:
try:
bbox_data = json.loads(json_match.group(1).strip())
pred_bbox_list = [item for item in bbox_data]
except:
# Return empty list if JSON parsing fails
pred_bbox_list = []
# Calculate mAP if both prediction and ground truth exist
if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list)
else:
bbox_reward = 0.0
return bbox_reward
def numeric_reward(content, sol, **kwargs):
content = clean_text(content)
sol = clean_text(sol)
try:
content, sol = float(content), float(sol)
return 1.0 if content == sol else 0.0
except:
return None
def math_reward(content, sol, **kwargs):
content = clean_text(content)
sol = clean_text(sol)
return compute_score(content, sol)
def clean_text(text, exclue_chars=['\n', '\r']):
# Extract content between <answer> and </answer> if present
answer_matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
if answer_matches:
# Use the last match
text = answer_matches[-1]
for char in exclue_chars:
if char in ['\n', '\r']:
# If there is a space before the newline, remove the newline
text = re.sub(r'(?<=\s)' + re.escape(char), '', text)
# If there is no space before the newline, replace it with a space
text = re.sub(r'(?<!\s)' + re.escape(char), ' ', text)
else:
text = text.replace(char, ' ')
# Remove leading and trailing spaces and convert to lowercase
return text.strip().rstrip('.').lower()
def default_accuracy_reward(content, sol, **kwargs):
reward = 0.0
# Extract answer from solution if it has think/answer tags
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
# Extract answer from content if it has think/answer tags
content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
student_answer = content_matches[-1].strip() if content_matches else content.strip()
# Try symbolic verification first for numeric answers
try:
answer = parse(student_answer)
if float(verify(answer, parse(ground_truth))) > 0:
reward = 1.0
except Exception:
pass # Continue to next verification method if this fails
# If symbolic verification failed, try string matching or fuzzy matching
if reward == 0.0:
try:
# Check if ground truth contains numbers
has_numbers = bool(re.search(r'\d', ground_truth))
# Check if it's a multiple choice question
has_choices = extract_choice(ground_truth)
if has_numbers:
# For numeric answers, use exact matching
reward = numeric_reward(student_answer, ground_truth)
if reward is None:
reward = ratio(clean_text(student_answer), clean_text(ground_truth))
elif has_choices:
# For multiple choice, extract and compare choices
correct_choice = has_choices.upper()
student_choice = extract_choice(student_answer)
if student_choice:
reward = 1.0 if student_choice == correct_choice else 0.0
else:
# For text answers, use fuzzy matching
reward = ratio(clean_text(student_answer), clean_text(ground_truth))
except Exception:
pass # Keep reward as 0.0 if all methods fail
return reward
def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is correct using symbolic verification, exact string matching, or fuzzy matching."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol, accu_reward_method in zip(contents, solution, kwargs.get("accu_reward_method")):
# if accu_reward_method is defined, use the corresponding reward function, otherwise use the default reward function
if accu_reward_method == "mcq":
reward = mcq_reward(content, sol)
elif accu_reward_method == 'yes_no':
reward = yes_no_reward(content, sol)
elif accu_reward_method == 'llm':
reward = llm_reward(content, sol)
elif accu_reward_method == 'map':
reward = map_reward(content, sol)
elif accu_reward_method == 'math':
reward = math_reward(content, sol)
else:
reward = default_accuracy_reward(content, sol)
rewards.append(reward)
if os.getenv("DEBUG_MODE") == "true":
log_path = os.getenv("LOG_PATH")
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
problem = kwargs.get("problem")[0]
if reward <= 1.0: # this condition can be changed for debug
with open(log_path, "a", encoding='utf-8') as f:
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
f.write(f"accu_reward_method: {accu_reward_method}\n")
f.write(f"image_path: {image_path}\n")
f.write(f"problem: {problem}\n")
f.write(f"Content: {content}\n")
f.write(f"Solution: {sol}\n")
return rewards
def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
if os.getenv("DEBUG_MODE") == "true":
log_path = os.getenv("LOG_PATH")
with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
f.write(f"------------- {current_time} Format reward -------------\n")
for content, match in zip(completion_contents, matches):
f.write(f"Content: {content}\n")
f.write(f"Has format: {bool(match)}\n")
return [1.0 if match else 0.0 for match in matches]
reward_funcs_registry = {
"accuracy": accuracy_reward,
"format": format_reward,
}
@dataclass
class GRPOModelConfig(ModelConfig):
freeze_vision_modules: bool = False
SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
)
def get_vlm_module(model_name_or_path):
if "qwen" in model_name_or_path.lower():
return Qwen2VLModule
elif "internvl" in model_name_or_path.lower():
return InvernVLModule
else:
raise ValueError(f"Unsupported model: {model_name_or_path}")
def main(script_args, training_args, model_args):
# Load the VLM module
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
print("using vlm module:", vlm_module_cls.__name__)
question_prompt = vlm_module_cls.get_question_template(task_type="default")
# Get reward functions
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
print("reward_funcs:", reward_funcs)
# Load the JSONL datasets
import json
from datasets import Dataset
data_files = script_args.data_file_paths.split(":")
image_folders = script_args.image_folders.split(":")
if len(data_files) != len(image_folders):
raise ValueError("Number of data files must match number of image folders")
if script_args.reward_method is None:
accu_reward_methods = ["default"] * len(data_files)
else:
accu_reward_methods = script_args.reward_method.split(":")
assert len(accu_reward_methods) == len(data_files), f"Number of reward methods must match number of data files: {len(accu_reward_methods)} != {len(data_files)}"
if len(data_files) != len(image_folders):
raise ValueError("Number of data files must match number of image folders")
all_data = []
for data_file, image_folder, accu_reward_method in zip(data_files, image_folders, accu_reward_methods):
with open(data_file, 'r') as f:
for line in f:
item = json.loads(line)
if 'image' in item:
if isinstance(item['image'], str):
# Store image path instead of loading the image
item['image_path'] = [os.path.join(image_folder, item['image'])]
del item['image'] # remove the image column so that it can be loaded later
elif isinstance(item['image'], list):
# if the image is a list, then it is a list of images (for multi-image input)
item['image_path'] = [os.path.join(image_folder, image) for image in item['image']]
del item['image'] # remove the image column so that it can be loaded later
else:
raise ValueError(f"Unsupported image type: {type(item['image'])}")
# Remove immediate image loading
item['problem'] = item['conversations'][0]['value'].replace('<image>', '')
# Handle solution that could be a float or string
solution_value = item['conversations'][1]['value']
if isinstance(solution_value, str):
item['solution'] = solution_value.replace('<answer>', '').replace('</answer>', '').strip()
else:
# If it's a float or other non-string type, keep it as is
item['solution'] = str(solution_value)
del item['conversations']
item['accu_reward_method'] = item.get('accu_reward_method', accu_reward_method) # if accu_reward_method is in the data jsonl, use the value in the data jsonl, otherwise use the defined value
all_data.append(item)
dataset = Dataset.from_list(all_data)
def make_conversation_from_jsonl(example):
if 'image_path' in example and example['image_path'] is not None:
# Don't load image here, just store the path
return {
'image_path': [p for p in example['image_path']], # Store path instead of loaded image
'problem': example['problem'],
'solution': f"<answer> {example['solution']} </answer>",
'accu_reward_method': example['accu_reward_method'],
'prompt': [{
'role': 'user',
'content': [
*({'type': 'image', 'text': None} for _ in range(len(example['image_path']))),
{'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
]
}]
}
else:
return {
'problem': example['problem'],
'solution': f"<answer> {example['solution']} </answer>",
'accu_reward_method': example['accu_reward_method'],
'prompt': [{
'role': 'user',
'content': [
{'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
]
}]
}
# Map the conversations
dataset = dataset.map(make_conversation_from_jsonl, num_proc=8)
# Split dataset for validation if requested
splits = {'train': dataset}
if script_args.val_split_ratio > 0:
train_val_split = dataset.train_test_split(
test_size=script_args.val_split_ratio
)
splits['train'] = train_val_split['train']
splits['validation'] = train_val_split['test']
# Select trainer class based on vlm_trainer argument
trainer_cls = VLMGRPOTrainer
print("using trainer:", trainer_cls.__name__)
# Initialize the GRPO trainer
trainer = trainer_cls(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
vlm_module=vlm_module_cls(),
train_dataset=splits['train'],
eval_dataset=splits.get('validation') if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
freeze_vision_modules=model_args.freeze_vision_modules,
attn_implementation=model_args.attn_implementation,
max_pixels=script_args.max_pixels,
min_pixels=script_args.min_pixels,
)
# Train and push the model to the Hub
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
if __name__ == "__main__":
parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)

291
qw/open_r1/grpo_rec.py Executable file
View File

@ -0,0 +1,291 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import debugpy
# try:
# # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
# debugpy.listen(("localhost", 9501))
# print("Waiting for debugger attach")
# debugpy.wait_for_client()
# except Exception as e:
# pass
import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional
from PIL import Image
from torch.utils.data import Dataset
from transformers import Qwen2VLForConditionalGeneration
from math_verify import parse, verify
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
from open_r1.vlm_modules import *
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
from transformers import TrainingArguments
import yaml
import json
import random
import math
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
import torch
from typing import Tuple
def custom_forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
# print(111, 222, 333, 444, 555, 666, 777, 888, 999)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos().float()
sin = emb.sin().float()
else:
cos, sin = position_embeddings
# Add this
cos = cos.to(torch.float)
sin = sin.to(torch.float)
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
q = q.squeeze(0)
k = k.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
)
attn_output = self.proj(attn_output)
return attn_output
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
# ----------------------- Main Script -----------------------
@dataclass
class GRPOScriptArguments(ScriptArguments):
"""
Script arguments for the GRPO training script.
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format'.
"""
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
)
max_pixels: Optional[int] = field(
default=12845056,
metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
)
min_pixels: Optional[int] = field(
default=3136,
metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
)
max_anyres_num: Optional[int] = field(
default=12,
metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
)
image_root: Optional[str] = field(
default=None,
metadata={"help": "Root directory of the image"},
)
@dataclass
class GRPOModelConfig(ModelConfig):
freeze_vision_modules: bool = False
SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
)
class LazySupervisedDataset(Dataset):
def __init__(self, data_path: str, script_args: GRPOScriptArguments, question_template: str):
super(LazySupervisedDataset, self).__init__()
self.script_args = script_args
self.list_data_dict = []
self.question_template = question_template
if data_path.endswith(".yaml"):
with open(data_path, "r") as file:
yaml_data = yaml.safe_load(file)
datasets = yaml_data.get("datasets")
# file should be in the format of:
# datasets:
# - json_path: xxxx1.json
# sampling_strategy: first:1000
# - json_path: xxxx2.json
# sampling_strategy: end:3000
# - json_path: xxxx3.json
# sampling_strategy: random:999
for data in datasets:
json_path = data.get("json_path")
sampling_strategy = data.get("sampling_strategy", "all")
sampling_number = None
if json_path.endswith(".jsonl"):
cur_data_dict = []
with open(json_path, "r") as json_file:
for line in json_file:
cur_data_dict.append(json.loads(line.strip()))
elif json_path.endswith(".json"):
with open(json_path, "r") as json_file:
cur_data_dict = json.load(json_file)
else:
raise ValueError(f"Unsupported file type: {json_path}")
if ":" in sampling_strategy:
sampling_strategy, sampling_number = sampling_strategy.split(":")
if "%" in sampling_number:
sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
else:
sampling_number = int(sampling_number)
# Apply the sampling strategy
if sampling_strategy == "first" and sampling_number is not None:
cur_data_dict = cur_data_dict[:sampling_number]
elif sampling_strategy == "end" and sampling_number is not None:
cur_data_dict = cur_data_dict[-sampling_number:]
elif sampling_strategy == "random" and sampling_number is not None:
random.shuffle(cur_data_dict)
cur_data_dict = cur_data_dict[:sampling_number]
print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
self.list_data_dict.extend(cur_data_dict)
else:
raise ValueError(f"Unsupported file type: {data_path}")
def __len__(self):
return len(self.list_data_dict)
def __getitem__(self, i):
# Format into conversation
def make_conversation(example):
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
}
QUESTION_TEMPLATE = self.question_template
def make_conversation_image(example):
return {
"prompt": [
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
],
},
],
}
example = self.list_data_dict[i]
image_root = self.script_args.image_root
if 'image' in example:
image_path = os.path.join(image_root, example['image'])
# In case the image is not found
while not os.path.exists(image_path):
print(f"Warning: Image {image_path} not found, randomly selecting another image")
new_index = random.randint(0, len(self.list_data_dict)-1)
example = self.list_data_dict[new_index]
image_path = os.path.join(image_root, example['image'])
image = Image.open(image_path).convert("RGB")
else:
image = None
return {
'image': image,
'problem': example['problem'],
'solution': example['solution'],
'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt'],
}
def get_vlm_module(model_name_or_path):
if "qwen" in model_name_or_path.lower():
return Qwen2VLModule
elif "internvl" in model_name_or_path.lower():
return InvernVLModule
else:
raise ValueError(f"Unsupported model: {model_name_or_path}")
def main(script_args, training_args, model_args):
# Load the VLM module
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
print("using vlm module:", vlm_module_cls.__name__)
# Load the reward functions
reward_funcs_registry = {
"accuracy": vlm_module_cls.iou_reward,
"format": vlm_module_cls.format_reward_rec,
}
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
print("reward_funcs:", reward_funcs)
# Load the dataset
dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))
trainer_cls = VLMGRPOTrainer
# Initialize the GRPO trainer
trainer = trainer_cls(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
vlm_module=vlm_module_cls(),
train_dataset=dataset,
eval_dataset=None,
peft_config=get_peft_config(model_args),
freeze_vision_modules=model_args.freeze_vision_modules,
attn_implementation=model_args.attn_implementation,
max_pixels=script_args.max_pixels,
min_pixels=script_args.min_pixels,
max_anyres_num=script_args.max_anyres_num,
torch_dtype=model_args.torch_dtype,
)
# Train and push the model to the Hub
trainer.train()
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
if __name__ == "__main__":
parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)

346
qw/open_r1/sft.py Executable file
View File

@ -0,0 +1,346 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Supervised fine-tuning script for decoder language models.
Usage:
# On 1 node of 8 x H100s
accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
--max_seq_length 4096 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--gradient_checkpointing \
--bf16 \
--logging_steps 5 \
--eval_strategy steps \
--eval_steps 100 \
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
"""
import logging
import os
import sys
import datasets
import torch
from torch.utils.data import Dataset
import transformers
from datasets import load_dataset
from transformers import AutoTokenizer, set_seed, AutoProcessor
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import SFTConfig
from open_r1.utils.callbacks import get_callbacks
import yaml
import json
import math
import random
from PIL import Image
from trl import (
ModelConfig,
ScriptArguments,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from dataclasses import field
from qwen_vl_utils import process_vision_info
logger = logging.getLogger(__name__)
from dataclasses import dataclass
@dataclass
class SFTScriptArguments(ScriptArguments):
image_root: str = field(default=None, metadata={"help": "The root directory of the image."})
processor = None
class LazySupervisedDataset(Dataset):
def __init__(self, data_path: str, script_args: ScriptArguments):
super(LazySupervisedDataset, self).__init__()
self.script_args = script_args
self.list_data_dict = []
if data_path.endswith(".yaml"):
with open(data_path, "r") as file:
yaml_data = yaml.safe_load(file)
datasets = yaml_data.get("datasets")
# file should be in the format of:
# datasets:
# - json_path: xxxx1.json
# sampling_strategy: first:1000
# - json_path: xxxx2.json
# sampling_strategy: end:3000
# - json_path: xxxx3.json
# sampling_strategy: random:999
for data in datasets:
json_path = data.get("json_path")
sampling_strategy = data.get("sampling_strategy", "all")
sampling_number = None
if json_path.endswith(".jsonl"):
cur_data_dict = []
with open(json_path, "r") as json_file:
for line in json_file:
cur_data_dict.append(json.loads(line.strip()))
elif json_path.endswith(".json"):
with open(json_path, "r") as json_file:
cur_data_dict = json.load(json_file)
else:
raise ValueError(f"Unsupported file type: {json_path}")
if ":" in sampling_strategy:
sampling_strategy, sampling_number = sampling_strategy.split(":")
if "%" in sampling_number:
sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
else:
sampling_number = int(sampling_number)
# Apply the sampling strategy
if sampling_strategy == "first" and sampling_number is not None:
cur_data_dict = cur_data_dict[:sampling_number]
elif sampling_strategy == "end" and sampling_number is not None:
cur_data_dict = cur_data_dict[-sampling_number:]
elif sampling_strategy == "random" and sampling_number is not None:
random.shuffle(cur_data_dict)
cur_data_dict = cur_data_dict[:sampling_number]
print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
self.list_data_dict.extend(cur_data_dict)
else:
raise ValueError(f"Unsupported file type: {data_path}")
def __len__(self):
return len(self.list_data_dict)
def __getitem__(self, i):
# Format into conversation
def make_conversation_image(example):
image_root = self.script_args.image_root
# print(111, image_root)
# print(222, example['image'])
image_path = os.path.join(image_root, example['image'])
x1, y1, x2, y2 = example["solution"]
normal_caption = example["normal_caption"]
return [
{
"role": "user",
"content": [
{"type": "image", "image": f"file://{image_path}"},
{"type": "text", "text": example["problem"]},
],
},
{
"role": "assistant",
"content": f'```json\n[\n\t{{"bbox_2d": [{int(x1)}, {int(y1)}, {int(x2)}, {int(y2)}], "label": "{normal_caption}"}}\n]\n```',
}
]
example = self.list_data_dict[i]
example["messages"] = make_conversation_image(example)
return example
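# For illustration: with a hypothetical example where solution = [10, 20, 110, 220]
# and normal_caption = "red mug", the assistant target built above renders as:
#
# ```json
# [
#     {"bbox_2d": [10, 20, 110, 220], "label": "red mug"}
# ]
# ```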
def collate_fn(examples):
texts = [
processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=True)
for example in examples
]
image_inputs = []
for example in examples:
imgs, vids = process_vision_info(example["messages"])
image_inputs.append(imgs)
batch = processor(
text=texts,
images=image_inputs,
return_tensors="pt",
padding=True,
)
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
labels[labels == image_token_id] = -100
batch["labels"] = labels
return batch
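# Sketch of the label masking above, assuming toy ids where pad_token_id = 0
# and the image placeholder token id = 5 (hypothetical values):
#
#   input_ids: [5, 5, 17, 42, 9, 0, 0]
#   labels:    [-100, -100, 17, 42, 9, -100, -100]
#
# Padding and image placeholder positions are set to -100 so the cross-entropy
# loss ignores them; only real text tokens contribute to the SFT objective.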
def main(script_args, training_args, model_args):
# Set seed for reproducibility
set_seed(training_args.seed)
###############
# Setup logging
###############
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process a small summary
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Data parameters {training_args}")
# Check for last checkpoint
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
################
# Load datasets
################
dataset = LazySupervisedDataset(script_args.dataset_name, script_args)
################
# Load tokenizer
################
global processor
if "vl" in model_args.model_name_or_path.lower():
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
logger.info("Using AutoProcessor for vision-language model.")
else:
processor = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
logger.info("Using AutoTokenizer for text-only model.")
if hasattr(processor, "pad_token") and processor.pad_token is None:
processor.pad_token = processor.eos_token
elif hasattr(processor.tokenizer, "pad_token") and processor.tokenizer.pad_token is None:
processor.tokenizer.pad_token = processor.tokenizer.eos_token
###################
# Model init kwargs
###################
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
# training_args.model_init_kwargs = model_kwargs
from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
if "Qwen2-VL" in model_args.model_name_or_path:
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_args.model_name_or_path, **model_kwargs
)
elif "Qwen2.5-VL" in model_args.model_name_or_path:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_args.model_name_or_path, **model_kwargs
)
else:
raise ValueError(f"Unsupported model: {model_args.model_name_or_path}")
############################
# Initialize the SFT Trainer
############################
training_args.dataset_kwargs = {
"skip_prepare_dataset": True,
}
training_args.remove_unused_columns = False
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=None,
processing_class=processor.tokenizer,
data_collator=collate_fn,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
)
###############
# Training loop
###############
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
##################################
# Save model and create model card
##################################
logger.info("*** Save model ***")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
# Save everything else on main process
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"dataset": list(script_args.dataset_name),
"dataset_tags": list(script_args.dataset_name),
"tags": ["open-r1"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
#############
# push to hub
#############
if training_args.push_to_hub:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)
if __name__ == "__main__":
parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
print(script_args)
main(script_args, training_args, model_args)

4
qw/open_r1/trainer/__init__.py Executable file
View File

@ -0,0 +1,4 @@
from .grpo_trainer import VLMGRPOTrainer
from .grpo_config import GRPOConfig
__all__ = ["VLMGRPOTrainer"]

View File

@ -0,0 +1,286 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
@dataclass
class GRPOConfig(TrainingArguments):
r"""
Configuration class for the [`GRPOTrainer`].
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
[`~transformers.TrainingArguments`] documentation.
Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
> Parameters that control the model and reference model
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
argument of the [`GRPOTrainer`] is provided as a string.
> Parameters that control the data preprocessing
remove_unused_columns (`bool`, *optional*, defaults to `False`):
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
num_generations (`int` or `None`, *optional*, defaults to `8`):
Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
must be divisible by this value.
temperature (`float`, *optional*, defaults to `0.9`):
Temperature for sampling. The higher the temperature, the more random the completions.
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
Maximum length of the generated completion.
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
improving generation speed. However, disabling this option allows training models that exceed the VRAM
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
with vLLM generation.
> Parameters that control generation acceleration powered by vLLM
use_vllm (`bool`, *optional*, defaults to `False`):
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
vllm_device (`str`, *optional*, defaults to `"auto"`):
Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
automatically select the next available GPU after the last one used for training. This assumes that
training has not already occupied all available GPUs. If only one device is available, the device will be
shared between both training and vLLM.
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
during initialization.
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
based on the model configuration. Find the supported values in the vLLM documentation.
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
context size, which might be much larger than the KV cache, leading to inefficiencies.
vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware
support this feature.
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
> Parameters that control the training
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
speed.
num_iterations (`int`, *optional*, defaults to `1`):
Number of iterations per batch (denoted as μ in the algorithm).
epsilon (`float`, *optional*, defaults to `0.2`):
Epsilon value for clipping.
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
weighted equally with weight `1.0`.
sync_ref_model (`bool`, *optional*, defaults to `False`):
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
the `ref_model_mixup_alpha` parameter. This synchronization originates from the
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
between the current policy and the previous reference policy during updates. The reference policy is
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
must set `sync_ref_model=True`.
ref_model_sync_steps (`int`, *optional*, defaults to `512`):
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
set `sync_ref_model=True`.
> Parameters that control the logging
log_completions (`bool`, *optional*, defaults to `False`):
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
"""
# Parameters that control the model and reference model
model_init_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
"argument of the `GRPOTrainer` is provided as a string."
},
)
# Parameters that control the data preprocessing
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
# additional columns to compute the reward
remove_unused_columns: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
"that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
},
)
max_prompt_length: Optional[int] = field(
default=512,
metadata={
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
},
)
num_generations: Optional[int] = field(
default=8,
metadata={
"help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
"must be divisible by this value."
},
)
temperature: Optional[float] = field(
default=0.9,
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
)
max_completion_length: Optional[int] = field(
default=256,
metadata={"help": "Maximum length of the generated completion."},
)
ds3_gather_for_generation: bool = field(
default=True,
metadata={
"help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
"generation, improving generation speed. However, disabling this option allows training models that "
"exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
"is not compatible with vLLM generation."
},
)
# Parameters that control generation acceleration powered by vLLM
use_vllm: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
"unused for training, as vLLM will require one for generation. vLLM must be installed "
"(`pip install vllm`)."
},
)
vllm_device: Optional[str] = field(
default="auto",
metadata={
"help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
"will automatically select the next available GPU after the last one used for training. This assumes "
"that training has not already occupied all available GPUs."
},
)
vllm_gpu_memory_utilization: float = field(
default=0.9,
metadata={
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
"out-of-memory (OOM) errors during initialization."
},
)
vllm_dtype: Optional[str] = field(
default="auto",
metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
vllm_max_model_len: Optional[int] = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
vllm_enable_prefix_caching: Optional[bool] = field(
default=True,
metadata={
"help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
"the hardware support this feature."
},
)
vllm_guided_decoding_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
# Parameters that control the training
learning_rate: float = field(
default=1e-6,
metadata={
"help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
"`transformers.TrainingArguments`."
},
)
beta: float = field(
default=0.04,
metadata={
"help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
"training speed."
},
)
num_iterations: int = field(
default=1,
metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
)
epsilon: float = field(
default=0.2,
metadata={"help": "Epsilon value for clipping."},
)
reward_weights: Optional[list[float]] = field(
default=None,
metadata={
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
"rewards are weighted equally with weight `1.0`."
},
)
sync_ref_model: bool = field(
default=False,
metadata={
"help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
},
)
ref_model_mixup_alpha: float = field(
default=0.6,
metadata={
"help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
"previous reference policy during updates. The reference policy is updated according to the equation: "
"`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
},
)
ref_model_sync_steps: int = field(
default=512,
metadata={
"help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
},
)
# Parameters that control the logging
log_completions: bool = field(
default=False,
metadata={
"help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
},
)
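# Minimal usage sketch (hypothetical values), assuming this module is exposed
# as `open_r1.trainer`. Like any `TrainingArguments` subclass, the config can
# be built directly or parsed from the command line with `TrlParser`:
#
#     from open_r1.trainer import GRPOConfig
#
#     args = GRPOConfig(
#         output_dir="outputs/grpo-demo",      # hypothetical path
#         per_device_train_batch_size=8,
#         num_generations=8,                   # global batch size must be divisible by this
#         max_completion_length=256,
#         beta=0.04,
#         learning_rate=1e-6,
#     )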

View File

@ -0,0 +1,849 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import textwrap
from collections import defaultdict
from typing import Any, Callable, Optional, Union, Sized
import torch
import torch.utils.data
import transformers
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
AriaForConditionalGeneration,
AriaProcessor,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoProcessor,
AutoTokenizer,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Trainer,
TrainerCallback,
is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url
from trl import GRPOTrainer
from accelerate.utils import is_peft_model, set_seed
import PIL.Image
import copy
from torch.utils.data import Sampler
import warnings
if is_peft_available():
from peft import PeftConfig, get_peft_model
if is_wandb_available():
import wandb
from open_r1.vlm_modules.vlm_module import VLMBaseModule
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
class RepeatRandomSampler(Sampler):
"""
Sampler that repeats the indices of a dataset in a structured manner.
Args:
data_source (`Sized`):
Dataset to sample from.
mini_repeat_count (`int`):
Number of times to repeat each index per batch.
batch_size (`int`, *optional*, defaults to `1`):
Number of unique indices per batch.
repeat_count (`int`, *optional*, defaults to `1`):
Number of times to repeat the full sampling process.
seed (`int` or `None`, *optional*, defaults to `None`):
Random seed for reproducibility.
"""
def __init__(
self,
data_source: Sized,
mini_repeat_count: int,
batch_size: int = 1,
repeat_count: int = 1,
seed: Optional[int] = None,
):
self.data_source = data_source
self.mini_repeat_count = mini_repeat_count
self.batch_size = batch_size
self.repeat_count = repeat_count
self.num_samples = len(data_source)
self.seed = seed
self.generator = torch.Generator()
if seed is not None:
self.generator.manual_seed(seed)
def __iter__(self):
indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
for chunk in indexes:
for _ in range(self.repeat_count):
for index in chunk:
for _ in range(self.mini_repeat_count):
yield index
def __len__(self) -> int:
return self.num_samples * self.mini_repeat_count * self.repeat_count
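# Illustration of the yield pattern: with num_samples=4, batch_size=2,
# mini_repeat_count=2 and repeat_count=2, and a random permutation
# [i0, i1, i2, i3], the sampler yields
#     i0, i0, i1, i1,  i0, i0, i1, i1,  i2, i2, i3, i3,  i2, i2, i3, i3
# i.e. each chunk of `batch_size` unique indices is emitted `repeat_count`
# times, and every index within a chunk is repeated `mini_repeat_count` times
# in a row (e.g. once per completion to be generated for that prompt).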
class VLMGRPOTrainer(Trainer):
"""
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
Example:
```python
from datasets import load_dataset
from trl import GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs="weqweasdas/RM-Gemma-2B",
train_dataset=dataset,
)
trainer.train()
```
Args:
model (`Union[str, PreTrainedModel]`):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
a path to a *directory* containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
functions with the prompts and completions and sum the rewards. Can be either:
- A single reward function, such as:
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
path to a *directory* containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
keyword arguments in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
- A custom reward function: The function is provided with the prompts and the generated completions,
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
[Using a custom reward function](#using-a-custom-reward-function).
- A list of reward functions, where each item can independently be any of the above types. Mixing different
types within the list (e.g., a string model ID and a custom reward function) is allowed.
args ([`GRPOConfig`], *optional*, defaults to `None`):
Configuration for this trainer. If `None`, a default configuration is used.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
ignored. The format of the samples can be either:
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
Processing class used to process the data. The padding side must be set to "left". If `None`, the
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
- A single processing class: Used when `reward_funcs` contains only one reward function.
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
the corresponding entries in `reward_processing_classes` are ignored.
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
List of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
method.
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
"""
def __init__(
self,
model: Union[str, PreTrainedModel],
reward_funcs: Union[RewardFunc, list[RewardFunc]],
args: GRPOConfig = None,
vlm_module: VLMBaseModule = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
processing_class: Optional[PreTrainedTokenizerBase] = None,
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
peft_config: Optional["PeftConfig"] = None,
freeze_vision_modules: Optional[bool] = False,
attn_implementation: str = "flash_attention_2",
torch_dtype: str = "bfloat16",
**kwargs,
):
# Args
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
model_name = model_name.split("/")[-1]
args = GRPOConfig(f"{model_name}-GRPO")
self.vlm_module = vlm_module
# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
# FIXME
# Remember to modify it in the InternVL module as well
model_init_kwargs["attn_implementation"] = attn_implementation
if model_init_kwargs.get("torch_dtype") is None:
model_init_kwargs["torch_dtype"] = torch_dtype
assert isinstance(model, str), "model must be a string in the current implementation"
model_id = model
torch_dtype = model_init_kwargs.get("torch_dtype")
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
pass # torch_dtype is already a torch.dtype or "auto" or None
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
torch_dtype = getattr(torch, torch_dtype)
else:
raise ValueError(
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
)
model_init_kwargs["use_cache"] = (
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
)
# Disable caching if gradient checkpointing is enabled (not supported)
model_init_kwargs["use_cache"] = (
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
)
model_cls = self.vlm_module.get_model_class(model_id, model_init_kwargs)
model = model_cls.from_pretrained(model_id, **model_init_kwargs)
# LoRA
self.vision_modules_keywords = self.vlm_module.get_vision_modules_keywords()
if peft_config is not None:
def find_all_linear_names(model, multimodal_keywords):
cls = torch.nn.Linear
lora_module_names = set()
for name, module in model.named_modules():
# LoRA is not applied to the vision modules
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
lora_module_names.add(name)
for m in lora_module_names: # needed for 16-bit
if "embed_tokens" in m:
lora_module_names.remove(m)
return list(lora_module_names)
target_modules = find_all_linear_names(model, self.vision_modules_keywords)
peft_config.target_modules = target_modules
model = get_peft_model(model, peft_config)
# Freeze vision modules
if freeze_vision_modules:
print("Freezing vision modules...")
for n, p in model.named_parameters():
if any(keyword in n for keyword in self.vision_modules_keywords):
p.requires_grad = False
# Enable gradient checkpointing if requested
if args.gradient_checkpointing:
model = self._enable_gradient_checkpointing(model, args)
# Reference model
if is_deepspeed_zero3_enabled():
self.ref_model = model_cls.from_pretrained(model_id, **model_init_kwargs)
elif peft_config is None:
# If PEFT configuration is not provided, create a reference model based on the initial model.
self.ref_model = create_reference_model(model)
else:
# If PEFT is used, the reference model is not needed since the adapter can be disabled
# to revert to the initial model.
self.ref_model = None
# Processing class
if processing_class is None:
processing_cls = self.vlm_module.get_processing_class()
processing_class = processing_cls.from_pretrained(model_id, trust_remote_code=model_init_kwargs.get("trust_remote_code", None))
for processing_keyword in self.vlm_module.get_custom_processing_keywords():
if processing_keyword in kwargs:
setattr(processing_class, processing_keyword, kwargs[processing_keyword])
if getattr(processing_class, "tokenizer", None) is not None:
pad_token_id = processing_class.tokenizer.pad_token_id
processing_class.pad_token_id = pad_token_id
processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
else:
assert isinstance(processing_class, PreTrainedTokenizerBase), "processing_class must be an instance of PreTrainedTokenizerBase if it has no tokenizer attribute"
pad_token_id = processing_class.pad_token_id
self.vlm_module.post_model_init(model, processing_class)
self.vlm_module.post_model_init(self.ref_model, processing_class)
# Reward functions
if not isinstance(reward_funcs, list):
reward_funcs = [reward_funcs]
for i, reward_func in enumerate(reward_funcs):
if isinstance(reward_func, str):
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
reward_func, num_labels=1, **model_init_kwargs
)
self.reward_funcs = reward_funcs
# Reward processing class
if reward_processing_classes is None:
reward_processing_classes = [None] * len(reward_funcs)
elif not isinstance(reward_processing_classes, list):
reward_processing_classes = [reward_processing_classes]
else:
if len(reward_processing_classes) != len(reward_funcs):
raise ValueError("The number of reward processing classes must match the number of reward functions.")
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class is None:
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
if reward_processing_class.pad_token_id is None:
reward_processing_class.pad_token = reward_processing_class.eos_token
# The reward model computes the reward for the latest non-padded token in the input sequence.
# So it's important to set the pad token ID to the padding token ID of the processing class.
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
reward_processing_classes[i] = reward_processing_class
self.reward_processing_classes = reward_processing_classes
# Data collator
def data_collator(features): # No data collation is needed in GRPO
return features
# Training arguments
# max_prompt_length is not supported yet, so it is forced to None
self.max_prompt_length = None
if args.max_prompt_length is not None:
warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.generation_config = GenerationConfig(
max_new_tokens=self.max_completion_length,
do_sample=True,
temperature=1,
pad_token_id=pad_token_id,
)
if hasattr(self.vlm_module, "get_eos_token_id"): # For InternVL
self.generation_config.eos_token_id = self.vlm_module.get_eos_token_id(processing_class)
print(222, self.vlm_module.get_eos_token_id(processing_class))
self.beta = args.beta
self.epsilon = args.epsilon
# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
# Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle
self._step = 0
# Buffer the batch to reuse generated outputs across multiple updates
self._buffered_inputs = [None] * args.gradient_accumulation_steps
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
# This acts as a flag to indicate that the warning has already been issued.
model.warnings_issued["estimate_tokens"] = True
# Initialize the metrics
self._metrics = defaultdict(list)
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
callbacks=callbacks,
optimizers=optimizers,
)
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
num_processes = self.accelerator.num_processes
global_batch_size = args.per_device_train_batch_size * num_processes
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
if self.num_generations not in possible_values:
raise ValueError(
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
f"batch size, the valid values for the number of generations are: {possible_values}."
)
if self.args.eval_strategy != "no":
global_batch_size = args.per_device_eval_batch_size * num_processes
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
if self.num_generations not in possible_values:
raise ValueError(
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
f"eval batch size, the valid values for the number of generations are: {possible_values}."
)
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
# it's safer to set it in all cases.
set_seed(args.seed, device_specific=True)
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
# self.model_accepts_loss_kwargs to False to enable scaling.
self.model_accepts_loss_kwargs = False
if self.ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# Ensure use_cache is disabled
model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(model):
model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
else:
try:
model.gradient_checkpointing_enable()
except Exception:
# For InternVL; these operations are copied from the original training script of InternVL
model.language_model.config.use_cache = False
model.vision_model.gradient_checkpointing = True
model.vision_model.encoder.gradient_checkpointing = True
model.language_model._set_gradient_checkpointing()
# This line is necessary, otherwise the `model.gradient_checkpointing_enable()` will be executed during the training process, leading to an error since InternVL does not support this operation.
args.gradient_checkpointing = False
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)
if use_reentrant:
model.enable_input_require_grads()
return model
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
if self._signature_columns is None:
self._signature_columns = ["prompt"]
# Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, **custom_multimodal_inputs):
logits = model(input_ids=input_ids, attention_mask=attention_mask, **custom_multimodal_inputs).logits # (B, L, V)
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)
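# Shape sketch: for input_ids of shape (B, L), logits[:, :-1, :] holds the
# predictions for tokens 2..L, so gathering at input_ids[:, 1:] yields one
# log-probability per non-initial token and the result has shape (B, L-1).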
def _prepare_inputs(self, inputs):
# Simple pass-through, just like original
return inputs
def _get_key_from_inputs(self, x, key):
ele = x.get(key, None)
assert ele is not None, f"The key {key} is not found in the input"
if isinstance(ele, list):
return [e for e in ele]
else:
return [ele]
def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = self.vlm_module.prepare_prompt(self.processing_class, inputs)
# Handle both pre-loaded images and image paths
images = []
for x in inputs:
if "image" in x:
imgs = self._get_key_from_inputs(x, "image")
elif "image_path" in x and x["image_path"] is not None:
imgs = [PIL.Image.open(p) for p in self._get_key_from_inputs(x, "image_path")]
for img in imgs:
try:
# Ensure minimum dimensions of 28 pixels
w, h = img.size
if w < 28 or h < 28:
# Calculate new dimensions maintaining aspect ratio
if w < h:
new_w = 28
new_h = int(h * (28/w))
else:
new_h = 28
new_w = int(w * (28/h))
img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
except Exception:
pass
images.append(img)
prompt_inputs = self.vlm_module.prepare_model_inputs(
self.processing_class,
prompts_text,
images,
return_tensors="pt",
padding=True,
padding_side="left",
add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
# max_prompt_length is not supported yet
# if self.max_prompt_length is not None:
# prompt_ids = prompt_ids[:, -self.max_prompt_length :]
# prompt_inputs["input_ids"] = prompt_ids
# prompt_mask = prompt_mask[:, -self.max_prompt_length :]
# prompt_inputs["attention_mask"] = prompt_mask
# Generate completions
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
generate_returned_result = unwrapped_model.generate(
**{k: v for k, v in prompt_inputs.items() if k not in self.vlm_module.get_non_generate_params()},
generation_config=self.generation_config
)
prompt_length = prompt_ids.size(1)
if not self.vlm_module.is_embeds_input():
prompt_completion_ids = generate_returned_result
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]
else:
# In this case, the input of the LLM backbone is the embedding of the combination of the image and text prompt
# So the returned result of the `generate` method only contains the completion ids
completion_ids = generate_returned_result
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
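# Example of the masking above for one hypothetical row: completion ids
# [a, b, EOS, c, d] give is_eos = [0, 0, 1, 0, 0], eos_idx = 2, and
# completion_mask = [1, 1, 1, 0, 0]; the EOS token itself is kept and
# everything generated after it is ignored.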
# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
# Get the multimodal inputs
multimodal_keywords = self.vlm_module.get_custom_multimodal_keywords()
multimodal_inputs = {k: prompt_inputs[k] if k in prompt_inputs else None for k in multimodal_keywords}
with torch.no_grad():
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
# computation here, and use per_token_logps.detach() instead.
if self.num_iterations > 1:
old_per_token_logps = self._get_per_token_logps(
model, prompt_completion_ids, attention_mask, **multimodal_inputs
)
old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
else:
old_per_token_logps = None
if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(
self.ref_model, prompt_completion_ids, attention_mask, **multimodal_inputs
)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(
model, prompt_completion_ids, attention_mask, **multimodal_inputs
)
ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]
# Decode the generated completions
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
# Compute the rewards
# No need to duplicate prompts as we're not generating multiple completions per prompt
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, PreTrainedModel):
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
for key in reward_kwargs:
for example in inputs:
# No need to duplicate prompts as we're not generating multiple completions per prompt
# reward_kwargs[key].extend([example[key]] * self.num_generations)
reward_kwargs[key].extend([example[key]])
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
# Gather rewards across processes
rewards_per_func = self.accelerator.gather(rewards_per_func)
# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)
# Compute grouped-wise rewards
# Each group consists of num_generations completions for the same prompt
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
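# In formula form, the group-relative advantage for completion i within a
# group of G = num_generations samples of the same prompt is
#     A_i = (r_i - mean(r_1..r_G)) / (std(r_1..r_G) + 1e-4)
# The small constant in the denominator guards against zero variance when all
# completions in a group receive identical rewards.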
# Get only the local slice of advantages
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
advantages = advantages[process_slice]
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"old_per_token_logps": old_per_token_logps,
"ref_per_token_logps": ref_per_token_logps,
"advantages": advantages,
"multimodal_inputs": multimodal_inputs
}
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Check if we need to generate new completions or use buffered ones
if self.state.global_step % self.num_iterations == 0:
inputs = self._generate_and_score_completions(inputs, model)
self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
else:
inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
self._step += 1
# Get the prepared inputs
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
multimodal_inputs = inputs["multimodal_inputs"]
# Concatenate for full sequence
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
# Get the current policy's log probabilities
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, **multimodal_inputs)
# Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]
# Get the advantages from inputs
advantages = inputs["advantages"]
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
# and use per_token_logps.detach() instead
old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
# Compute the policy ratio and clipped version
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
# Add KL penalty if beta > 0
if self.beta > 0:
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
per_token_loss = per_token_loss + self.beta * per_token_kl
# Log KL divergence
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
# Compute final loss
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Log clip ratio
is_clipped = (per_token_loss1 < per_token_loss2).float()
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
return loss
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
logs = {**logs, **metrics}
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
super().log(logs, start_time)
else: # transformers<=4.46
super().log(logs)
self._metrics.clear()
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str` or `None`, *optional*, defaults to `None`):
Name of the model.
dataset_name (`str` or `None`, *optional*, defaults to `None`):
Name of the dataset used for training.
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
tags = tags or []
if isinstance(tags, str):
tags = [tags]
if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")
citation = textwrap.dedent(
"""\
@article{zhihong2024deepseekmath,
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
year = 2024,
eprint = {arXiv:2402.03300},
}
"""
)
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="GRPO",
trainer_citation=citation,
paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
paper_id="2402.03300",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))
def _get_train_sampler(self) -> Sampler:
"""Returns a sampler that ensures proper data sampling for GRPO training."""
effective_batch_size = (
self.args.per_device_train_batch_size
* self.accelerator.num_processes
* self.args.gradient_accumulation_steps
)
return RepeatRandomSampler(
data_source=self.train_dataset,
mini_repeat_count=self.num_generations,
batch_size=effective_batch_size // self.num_generations,
repeat_count=self.num_iterations,
seed=self.args.seed,
)
def _get_eval_sampler(self, eval_dataset) -> Sampler:
"""Returns a sampler for evaluation."""
return RepeatRandomSampler(
data_source=eval_dataset,
mini_repeat_count=self.num_generations,
seed=self.args.seed,
)
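# --- The block below is a minimal, self-contained sketch (illustration only, not part of the
# trainer) of the group-normalized advantages and the clipped surrogate objective computed in
# `_generate_and_score_completions` / `compute_loss` above. All tensor values are toy assumptions.
if __name__ == "__main__":
    import torch

    num_generations, epsilon = 4, 0.2
    # Toy rewards for one prompt group of 4 completions
    rewards = torch.tensor([1.0, 0.0, 0.5, 0.5])
    mean_g = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations)
    std_g = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations)
    advantages = (rewards - mean_g) / (std_g + 1e-4)

    # One token per completion; with num_iterations == 1 the old log-probs equal the new ones,
    # so the policy ratio is 1 and the clipped loss reduces to -advantage.
    per_token_logps = torch.tensor([[-1.0], [-2.0], [-1.5], [-1.5]])
    old_per_token_logps = per_token_logps.detach()
    coef_1 = torch.exp(per_token_logps - old_per_token_logps)
    coef_2 = torch.clamp(coef_1, 1 - epsilon, 1 + epsilon)
    per_token_loss = -torch.min(coef_1 * advantages.unsqueeze(1), coef_2 * advantages.unsqueeze(1))
    print(advantages)       # group-normalized advantages
    print(per_token_loss)   # equals -advantages here because the ratio is exactly 1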


@ -0,0 +1,825 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import textwrap
from collections import defaultdict
from typing import Any, Callable, Optional, Union
from accelerate.utils.other import is_compiled_module
from accelerate.utils import broadcast_object_list, gather, gather_object
import torch
import torch.utils.data
import transformers
import warnings
from unittest.mock import patch
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
AriaForConditionalGeneration,
AriaProcessor,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoProcessor,
AutoTokenizer,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Trainer,
TrainerCallback,
is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.import_utils import is_vllm_available
from trl.models import (
create_reference_model,
prepare_deepspeed,
unwrap_model_for_generation,
)
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
from trl import GRPOTrainer
import copy
if is_peft_available():
from peft import PeftConfig, get_peft_model
if is_vllm_available():
from vllm import LLM, SamplingParams
if is_wandb_available():
import wandb
import torch.nn as nn
from torch.utils.data import Sampler
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
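# --- A minimal illustrative example (an assumption for this document, not shipped with the
# trainer) of a custom callable reward matching the signature above: the trainer calls it with
# `prompts=`, `completions=` and forwards any extra dataset columns as keyword arguments, and
# conversational completions arrive as lists of {"role", "content"} messages.
def example_format_reward(prompts, completions, **kwargs) -> list[float]:
    rewards = []
    for completion in completions:
        text = completion[0]["content"] if isinstance(completion, list) else completion
        # reward completions that wrap their answer in the expected tags (tags are illustrative)
        rewards.append(1.0 if "<answer>" in text and "</answer>" in text else 0.0)
    return rewards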
class RepeatRandomSampler(Sampler):
"""
Sampler that repeats the indices of a dataset N times.
Args:
data_source (`Sized`):
Dataset to sample from.
repeat_count (`int`):
Number of times to repeat each index.
Example:
```python
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
>>> list(sampler)
[2, 2, 0, 0, 3, 3, 1, 1]
```
"""
def __init__(self, data_source, repeat_count: int):
self.data_source = data_source
self.repeat_count = repeat_count
self.num_samples = len(data_source)
def __iter__(self):
indexes = [
idx
for idx in torch.randperm(self.num_samples).tolist()
for _ in range(self.repeat_count)
]
return iter(indexes)
def __len__(self):
return self.num_samples * self.repeat_count
class Qwen2VLGRPOVLLMTrainer(Trainer):
def __init__(
self,
model: Union[str, PreTrainedModel],
reward_funcs: Union[RewardFunc, list[RewardFunc]],
args: GRPOConfig = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[
Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
] = None,
processing_class: Optional[PreTrainedTokenizerBase] = None,
reward_processing_classes: Optional[
Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[
Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
] = (None, None),
peft_config: Optional["PeftConfig"] = None,
# qwen2-vl related params
max_pixels: Optional[int] = 12845056,
min_pixels: Optional[int] = 3136,
attn_implementation: str = "flash_attention_2",
):
# Args
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
model_name = model_name.split("/")[-1]
args = GRPOConfig(f"{model_name}-GRPO")
# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
model_init_kwargs["attn_implementation"] = attn_implementation
if isinstance(model, str):
model_id = model
torch_dtype = model_init_kwargs.get("torch_dtype")
if (
isinstance(torch_dtype, torch.dtype)
or torch_dtype == "auto"
or torch_dtype is None
):
pass # torch_dtype is already a torch.dtype or "auto" or None
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
torch_dtype = getattr(torch, torch_dtype)
model_init_kwargs["torch_dtype"] = torch_dtype
else:
raise ValueError(
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
)
# Disable caching if gradient checkpointing is enabled (not supported)
model_init_kwargs["use_cache"] = (
False
if args.gradient_checkpointing
else model_init_kwargs.get("use_cache")
)
if "Qwen2-VL" in model_id:
model = Qwen2VLForConditionalGeneration.from_pretrained(
model, **model_init_kwargs
)
elif "Qwen2.5-VL" in model_id:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
elif "Aria" in model_id:
model_init_kwargs.pop("use_cache")
model = AriaForConditionalGeneration.from_pretrained(
model, **model_init_kwargs
)
else:
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
raise ValueError(
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
"This argument can only be used when the `model` argument is a string."
)
if peft_config is not None:
model = get_peft_model(model, peft_config)
# Reference model
if is_deepspeed_zero3_enabled():
if "Qwen2-VL" in model_id:
self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
model_id, **model_init_kwargs
)
elif "Aria" in model_id:
self.ref_model = AriaForConditionalGeneration.from_pretrained(
model_id, **model_init_kwargs
)
else:
self.ref_model = AutoModelForCausalLM.from_pretrained(
model_id, **model_init_kwargs
)
elif peft_config is None:
# If PEFT configuration is not provided, create a reference model based on the initial model.
self.ref_model = create_reference_model(model)
else:
# If PEFT is used, the reference model is not needed since the adapter can be disabled
# to revert to the initial model.
self.ref_model = None
# Processing class
if processing_class is None:
if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id or "Aria" in model_id:
processing_class = AutoProcessor.from_pretrained(model_id)
pad_token_id = processing_class.tokenizer.pad_token_id
processing_class.pad_token_id = pad_token_id
processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
if "Qwen" in model_id or "Qwen2.5-VL" in model_id:
processing_class.image_processor.max_pixels = max_pixels
processing_class.image_processor.min_pixels = min_pixels
else:
processing_class = AutoTokenizer.from_pretrained(
model.config._name_or_path, padding_side="left"
)
pad_token_id = processing_class.pad_token_id
# Reward functions
if not isinstance(reward_funcs, list):
reward_funcs = [reward_funcs]
for i, reward_func in enumerate(reward_funcs):
if isinstance(reward_func, str):
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
reward_func, num_labels=1, **model_init_kwargs
)
self.reward_funcs = reward_funcs
# Reward processing class
if reward_processing_classes is None:
reward_processing_classes = [None] * len(reward_funcs)
elif not isinstance(reward_processing_classes, list):
reward_processing_classes = [reward_processing_classes]
else:
if len(reward_processing_classes) != len(reward_funcs):
raise ValueError(
"The number of reward processing classes must match the number of reward functions."
)
for i, (reward_processing_class, reward_func) in enumerate(
zip(reward_processing_classes, reward_funcs)
):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class is None:
reward_processing_class = AutoTokenizer.from_pretrained(
reward_func.config._name_or_path
)
if reward_processing_class.pad_token_id is None:
reward_processing_class.pad_token = (
reward_processing_class.eos_token
)
# The reward model computes the reward for the latest non-padded token in the input sequence.
# So it's important to set the pad token ID to the padding token ID of the processing class.
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
reward_processing_classes[i] = reward_processing_class
self.reward_processing_classes = reward_processing_classes
# Data collator
def data_collator(features): # No data collation is needed in GRPO
return features
# Training arguments
self.max_prompt_length = args.max_prompt_length
self.max_completion_length = (
args.max_completion_length
) # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.generation_config = GenerationConfig(
max_new_tokens=self.max_completion_length,
do_sample=True,
temperature=1, # HACK
num_return_sequences=self.num_generations,
pad_token_id=pad_token_id,
)
self.beta = args.beta
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
# This acts as a flag to indicate that the warning has already been issued.
model.warnings_issued["estimate_tokens"] = True
# Initialize the metrics
self._metrics = defaultdict(list)
self.use_vllm = args.use_vllm
# # rewrite the processing AutoTokenizer -> AutoProcessor
# model_id = model if isinstance(model, str) else model.config._name_or_path
# if processing_class is None:
# if "Qwen2-VL" in model_id or "Aria" in model_id:
# processing_class = AutoProcessor.from_pretrained(model_id)
# pad_token_id = processing_class.tokenizer.pad_token_id
# processing_class.pad_token_id = pad_token_id
# processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
# if "Qwen2-VL" in model_id:
# processing_class.image_processor.max_pixels = max_pixels
# processing_class.image_processor.min_pixels = min_pixels
# else:
# processing_class = AutoTokenizer.from_pretrained(
# model.config._name_or_path, padding_side="left"
# )
# pad_token_id = processing_class.pad_token_id
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
callbacks=callbacks,
optimizers=optimizers,
)
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
# self.model_accepts_loss_kwargs to False to enable scaling.
self.model_accepts_loss_kwargs = False
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
num_processes = self.accelerator.num_processes
global_batch_size = args.per_device_train_batch_size * num_processes
possible_values = [
n_gen
for n_gen in range(2, global_batch_size + 1)
if (global_batch_size) % n_gen == 0
]
if self.num_generations not in possible_values:
raise ValueError(
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
f"batch size, the valid values for the number of generations are: {possible_values}."
)
if self.args.eval_strategy != "no":
global_batch_size = args.per_device_eval_batch_size * num_processes
possible_values = [
n_gen
for n_gen in range(2, global_batch_size + 1)
if (global_batch_size) % n_gen == 0
]
if self.num_generations not in possible_values:
raise ValueError(
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
f"eval batch size, the valid values for the number of generations are: {possible_values}."
)
if self.use_vllm:
if not is_vllm_available():
raise ImportError(
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
"`pip install vllm` to use it."
)
if self.accelerator.is_main_process:
vllm_device = self.args.vllm_device
if vllm_device == "auto":
vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
# Check that the requested device is available
if (
vllm_device.split(":")[0] == "cuda"
and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
):
raise ValueError(
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
"value lower than the number of GPUs available on your machine—typically, reducing it by one "
f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
)
# Check that the requested device is not also used for training
if vllm_device in {
f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
}:
warnings.warn(
f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
"behavior. It is recommended to use a dedicated device for vLLM."
)
# vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
# model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
# setting (profiling_patch).
world_size_patch = patch(
"torch.distributed.get_world_size", return_value=1
)
profiling_patch = patch(
"vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
return_value=None,
)
with world_size_patch, profiling_patch:
print("vllm is running on: ", vllm_device)
self.llm = LLM(
model=model.name_or_path,
device=vllm_device,
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
dtype=torch.bfloat16,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=True,
enforce_eager=True,
max_model_len=args.max_completion_length,
)
self.sampling_params = SamplingParams(
temperature=args.temperature,
max_tokens=self.max_completion_length,
)
self._last_loaded_step = (
0 # tag to avoid useless loading during grad accumulation
)
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
# synchronize all processes after vLLM has been fully initialized.
self.accelerator.wait_for_everyone()
else:
raise ValueError(
"Qwen2VLGRPOVLLMTrainer only supports vllm generation, please set --use_vllm True"
)
if self.ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(
self.ref_model, evaluation_mode=True
)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(
reward_func, evaluation_mode=True
)
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
if self._signature_columns is None:
self._signature_columns = ["prompt"]
# We need a custom sampler that samples the same prompt multiple times
def _get_train_sampler(self):
return RepeatRandomSampler(self.train_dataset, self.num_generations)
# Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(
self,
model,
input_ids,
attention_mask,
pixel_values,
image_grid_thw,
logits_to_keep,
):
pixel_values = pixel_values.to(model.device)
image_grid_thw = image_grid_thw.to(device=model.device)
logits = model(
input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
).logits # (B, L, V)
logits = logits[
:, :-1, :
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[
:, -logits_to_keep:
] # (B, L-1), exclude the first input ID since we don't have logits for it
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
logits = logits[:, -logits_to_keep:]
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(
log_probs, dim=1, index=input_ids_row.unsqueeze(1)
).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)
# Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
# Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
def _prepare_inputs(
self, inputs: dict[str, Union[torch.Tensor, Any]]
) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
images = [x["image"] for x in inputs]
prompts_text = [
maybe_apply_chat_template(example, self.processing_class)["prompt"]
for example in inputs
]
prompt_inputs = self.processing_class(
# prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
text=prompts_text,
images=images,
return_tensors="pt",
padding=True,
padding_side="left",
add_special_tokens=False,
)
prompt_ids, prompt_mask = (
prompt_inputs["input_ids"].to(device),
prompt_inputs["attention_mask"].to(device),
)
if self.max_prompt_length is not None:
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(
self.model,
self.accelerator,
gather_deepspeed3_params=False, # TODO: fix this, self.args.ds3_gather_for_generation,
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
state_dict = unwrapped_model._orig_mod.state_dict()
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
all_images = gather_object(images)
# group into pairs
all_multimodal_inputs = [
{"prompt": p, "multi_modal_data": {"image": i}}
for p, i in zip(all_prompts_text, all_images)
]
if self.accelerator.is_main_process:
outputs = self.llm.generate(
all_multimodal_inputs,
sampling_params=self.sampling_params,
use_tqdm=False,
)
completion_ids = [
out.token_ids
for completions in outputs
for out in completions.outputs
]
else:
completion_ids = [None] * len(all_prompts_text)
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]
# Pad the completions, and concatenate them with the prompts
completion_ids = [
torch.tensor(ids, device=device) for ids in completion_ids
]
completion_ids = pad(
completion_ids, padding_value=self.processing_class.pad_token_id
)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
raise ValueError("Only vLLM generation is supported in this version ")
# below are the same with yifan's code
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
device = self.accelerator.device
eos_idx = torch.full(
(is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
is_eos.size(0), -1
)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
# pixel_values = prompt_inputs["pixel_values"].repeat_interleave(
# self.num_generations, dim=0
# )
pixel_values = prompt_inputs["pixel_values"]
# [None].repeat_interleave(self.num_generations, dim=0)
# pixel_values = pixel_values.view(-1, pixel_values.shape[-1])
image_grid_thw = prompt_inputs["image_grid_thw"]
# .repeat_interleave(
# self.num_generations, dim=0
# )
logits_to_keep = completion_ids.size(1)
with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(
self.ref_model,
prompt_completion_ids,
attention_mask,
pixel_values,
image_grid_thw,
logits_to_keep,
)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(
self.model,
prompt_completion_ids,
attention_mask,
pixel_values,
image_grid_thw,
logits_to_keep,
)
# Decode the generated completions
completions = self.processing_class.batch_decode(
completion_ids, skip_special_tokens=True
)
if is_conversational(inputs[0]):
completions = [
[{"role": "assistant", "content": completion}]
for completion in completions
]
# Compute the rewards
rewards_per_func = torch.zeros(
len(prompts), len(self.reward_funcs), device=device
)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, PreTrainedModel):
if is_conversational(inputs[0]):
messages = [
{"messages": p + c} for p, c in zip(prompts, completions)
]
texts = [
apply_chat_template(x, reward_processing_class)["text"]
for x in messages
]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts,
return_tensors="pt",
padding=True,
padding_side="right",
add_special_tokens=False,
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
:, 0
] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
reward_kwargs = {
key: []
for key in inputs[0].keys()
if key not in ["prompt", "completion"]
}
for key in reward_kwargs:
for example in inputs:
# Repeat each value in the column for `num_generations` times
reward_kwargs[key].extend([example[key]] * self.num_generations)
output_reward_func = reward_func(
prompts=prompts, completions=completions, **reward_kwargs
)
rewards_per_func[:, i] = torch.tensor(
output_reward_func, dtype=torch.float32, device=device
)
rewards_per_func = gather(rewards_per_func)
# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)
# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
self.num_generations, dim=0
)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(
self.num_generations, dim=0
)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
advantages = advantages[process_slice]
# Log the metrics
reward_per_func = rewards_per_func.mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
self._metrics[f"rewards/{reward_func_name}"].append(
reward_per_func[i].item()
)
self._metrics["reward"].append(rewards.mean().item())
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"ref_per_token_logps": ref_per_token_logps,
"advantages": advantages,
"pixel_values": pixel_values,
"image_grid_thw": image_grid_thw,
}
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Compute the per-token log probabilities for the model
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = (
inputs["completion_ids"],
inputs["completion_mask"],
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
pixel_values = inputs["pixel_values"]
image_grid_thw = inputs["image_grid_thw"]
logits_to_keep = completion_ids.size(
1
) # we only need to compute the logits for the completion tokens
per_token_logps = self._get_per_token_logps(
model,
input_ids,
attention_mask,
pixel_values,
image_grid_thw,
logits_to_keep,
)
# Compute the KL divergence between the model and the reference model
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = (
torch.exp(ref_per_token_logps - per_token_logps)
- (ref_per_token_logps - per_token_logps)
- 1
)
# x - x.detach() allows for preserving gradients from x
advantages = inputs["advantages"]
per_token_loss = torch.exp(
per_token_logps - per_token_logps.detach()
) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = (
(per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
).mean()
# Log the metrics
completion_length = (
self.accelerator.gather_for_metrics(completion_mask.sum(1))
.float()
.mean()
.item()
)
self._metrics["completion_length"].append(completion_length)
mean_kl = (
(per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
).mean()
self._metrics["kl"].append(
self.accelerator.gather_for_metrics(mean_kl).mean().item()
)
return loss
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
if next(iter(logs.keys())).startswith("eval_"):
metrics = {f"eval_{key}": val for key, val in metrics.items()}
logs = {**logs, **metrics}
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
super().log(logs, start_time)
else: # transformers<=4.46
super().log(logs)
self._metrics.clear()
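# --- A minimal, self-contained sketch (illustration only) of the EOS-masking step used in
# `_prepare_inputs` above: everything after the first EOS token is masked out, the EOS itself is
# kept, and completions without an EOS keep their full length. The token ids below are toy values.
if __name__ == "__main__":
    import torch

    eos_token_id = 2
    completion_ids = torch.tensor([[5, 7, 2, 0, 0],   # EOS at position 2
                                   [4, 4, 4, 4, 4]])  # no EOS generated
    is_eos = completion_ids == eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
    print(completion_mask)
    # tensor([[1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1]])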

0
qw/open_r1/utils/__init__.py Executable file

86
qw/open_r1/utils/callbacks.py Executable file

@ -0,0 +1,86 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
from typing import List
from transformers import TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from .evaluation import run_benchmark_jobs
from .hub import push_to_hub_revision
def is_slurm_available() -> bool:
# returns true if a slurm queueing system is available
try:
subprocess.run(["sinfo"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return True
except FileNotFoundError:
return False
class DummyConfig:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class PushToHubRevisionCallback(TrainerCallback):
def __init__(self, model_config) -> None:
self.model_config = model_config
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.is_world_process_zero:
global_step = state.global_step
# WARNING: if you use dataclasses.replace(args, ...) the accelerator dist state will be broken, so I do this workaround
# Also if you instantiate a new SFTConfig, the accelerator dist state will be broken
dummy_config = DummyConfig(
hub_model_id=args.hub_model_id,
hub_model_revision=f"{args.hub_model_revision}-step-{global_step:09d}",
output_dir=f"{args.output_dir}/checkpoint-{global_step}",
system_prompt=args.system_prompt,
)
future = push_to_hub_revision(
dummy_config, extra_ignore_patterns=["*.pt"]
) # don't push the optimizer states
if is_slurm_available():
dummy_config.benchmarks = args.benchmarks
def run_benchmark_callback(_):
print(f"Checkpoint {global_step} pushed to hub.")
run_benchmark_jobs(dummy_config, self.model_config)
future.add_done_callback(run_benchmark_callback)
CALLBACKS = {
"push_to_hub_revision": PushToHubRevisionCallback,
}
def get_callbacks(train_config, model_config) -> List[TrainerCallback]:
callbacks = []
for callback_name in train_config.callbacks:
if callback_name not in CALLBACKS:
raise ValueError(f"Callback {callback_name} not found in CALLBACKS.")
callbacks.append(CALLBACKS[callback_name](model_config))
return callbacks
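# --- A minimal usage sketch (illustration only): `SimpleNamespace` stands in for the parsed
# training/model configs used in the real recipes; the returned `callbacks` list would then be
# passed to the Trainer.
if __name__ == "__main__":
    from types import SimpleNamespace

    train_config = SimpleNamespace(callbacks=["push_to_hub_revision"])
    model_config = SimpleNamespace()
    callbacks = get_callbacks(train_config, model_config)
    print([type(cb).__name__ for cb in callbacks])  # ['PushToHubRevisionCallback']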

105
qw/open_r1/utils/evaluation.py Executable file

@ -0,0 +1,105 @@
import subprocess
from typing import TYPE_CHECKING, Dict, Union
from .hub import get_gpu_count_for_vllm, get_param_count_from_repo_id
if TYPE_CHECKING:
from trl import GRPOConfig, SFTConfig, ModelConfig
import os
# We need a special environment setup to launch vLLM from within Slurm training jobs.
# - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105
# - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269
user_home_directory = os.path.expanduser("~")
VLLM_SLURM_PREFIX = [
"env",
"-i",
"bash",
"-c",
f"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch ",
]
def register_lighteval_task(
configs: Dict[str, str], eval_suite: str, task_name: str, task_list: str, num_fewshot: int = 0
):
"""Registers a LightEval task configuration.
- Core tasks can be added from this table: https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl
- Custom tasks that require their own metrics / scripts should be stored in scripts/evaluation/extended_lighteval_tasks
Args:
configs (Dict[str, str]): The dictionary to store the task configuration.
eval_suite (str): The evaluation suite (e.g. "custom" or "lighteval").
task_name (str): The name of the task.
task_list (str): The comma-separated list of tasks in the format "extended|{task_name}|{num_fewshot}|0" or "lighteval|{task_name}|{num_fewshot}|0".
num_fewshot (int, optional): The number of few-shot examples. Defaults to 0.
"""
# Format task list in lighteval format
task_list = ",".join(f"{eval_suite}|{task}|{num_fewshot}|0" for task in task_list.split(","))
configs[task_name] = task_list
LIGHTEVAL_TASKS = {}
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "math_500", "math_500", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime24", "aime24", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime25_part1", "aime25:part1", 0)
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "gpqa", "gpqa:diamond", 0)
def get_lighteval_tasks():
return list(LIGHTEVAL_TASKS.keys())
SUPPORTED_BENCHMARKS = get_lighteval_tasks()
def run_lighteval_job(
benchmark: str, training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig"
) -> None:
task_list = LIGHTEVAL_TASKS[benchmark]
model_name = training_args.hub_model_id
model_revision = training_args.hub_model_revision
# For large models >= 30b params or those running the MATH benchmark, we need to shard them across the GPUs to avoid OOM
num_gpus = get_gpu_count_for_vllm(model_name, model_revision)
if get_param_count_from_repo_id(model_name) >= 30_000_000_000:
tensor_parallel = True
else:
tensor_parallel = False
cmd = VLLM_SLURM_PREFIX.copy()
cmd_args = [
f"--gres=gpu:{num_gpus}",
f"--job-name=or1_{benchmark}_{model_name.split('/')[-1]}_{model_revision}",
"slurm/evaluate.slurm",
benchmark,
f'"{task_list}"',
model_name,
model_revision,
f"{tensor_parallel}",
f"{model_args.trust_remote_code}",
]
if training_args.system_prompt is not None:
cmd_args.append(f"--system_prompt={training_args.system_prompt}")
cmd[-1] += " " + " ".join(cmd_args)
subprocess.run(cmd, check=True)
def run_benchmark_jobs(training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig") -> None:
benchmarks = training_args.benchmarks
if len(benchmarks) == 1 and benchmarks[0] == "all":
benchmarks = get_lighteval_tasks()
# Evaluate on all supported benchmarks. Later we may want to include a `chat` option
# that just evaluates on `ifeval` and `mt_bench` etc.
for benchmark in benchmarks:
print(f"Launching benchmark `{benchmark}`")
if benchmark in get_lighteval_tasks():
run_lighteval_job(benchmark, training_args, model_args)
else:
raise ValueError(f"Unknown benchmark {benchmark}")

131
qw/open_r1/utils/hub.py Executable file

@ -0,0 +1,131 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from concurrent.futures import Future
from transformers import AutoConfig
from huggingface_hub import (
create_branch,
create_repo,
get_safetensors_metadata,
list_repo_commits,
list_repo_files,
list_repo_refs,
repo_exists,
upload_folder,
)
from trl import GRPOConfig, SFTConfig
logger = logging.getLogger(__name__)
def push_to_hub_revision(training_args: SFTConfig | GRPOConfig, extra_ignore_patterns=[]) -> Future:
"""Pushes the model to branch on a Hub repo."""
# Create a repo if it doesn't exist yet
repo_url = create_repo(repo_id=training_args.hub_model_id, private=True, exist_ok=True)
# Get initial commit to branch from
initial_commit = list_repo_commits(training_args.hub_model_id)[-1]
# Now create the branch we'll be pushing to
create_branch(
repo_id=training_args.hub_model_id,
branch=training_args.hub_model_revision,
revision=initial_commit.commit_id,
exist_ok=True,
)
logger.info(f"Created target repo at {repo_url}")
logger.info(f"Pushing to the Hub revision {training_args.hub_model_revision}...")
ignore_patterns = ["checkpoint-*", "*.pth"]
ignore_patterns.extend(extra_ignore_patterns)
future = upload_folder(
repo_id=training_args.hub_model_id,
folder_path=training_args.output_dir,
revision=training_args.hub_model_revision,
commit_message=f"Add {training_args.hub_model_revision} checkpoint",
ignore_patterns=ignore_patterns,
run_as_future=True,
)
logger.info(f"Pushed to {repo_url} revision {training_args.hub_model_revision} successfully!")
return future
def check_hub_revision_exists(training_args: SFTConfig | GRPOConfig):
"""Checks if a given Hub revision exists."""
if repo_exists(training_args.hub_model_id):
if training_args.push_to_hub_revision is True:
# First check if the revision exists
revisions = [rev.name for rev in list_repo_refs(training_args.hub_model_id).branches]
# If the revision exists, we next check it has a README file
if training_args.hub_model_revision in revisions:
repo_files = list_repo_files(
repo_id=training_args.hub_model_id, revision=training_args.hub_model_revision
)
if "README.md" in repo_files and training_args.overwrite_hub_revision is False:
raise ValueError(
f"Revision {training_args.hub_model_revision} already exists. "
"Use --overwrite_hub_revision to overwrite it."
)
def get_param_count_from_repo_id(repo_id: str) -> int:
"""Function to get model param counts from safetensors metadata or find patterns like 42m, 1.5b, 0.5m or products like 8x7b in a repo ID."""
try:
metadata = get_safetensors_metadata(repo_id)
return list(metadata.parameter_count.values())[0]
except Exception:
# Pattern to match products (like 8x7b) and single values (like 42m)
pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])"
matches = re.findall(pattern, repo_id.lower())
param_counts = []
for full_match, number1, _, _, number2, _, unit in matches:
if number2: # If there's a second number, it's a product
number = float(number1) * float(number2)
else: # Otherwise, it's a single value
number = float(number1)
if unit == "b":
number *= 1_000_000_000 # Convert to billion
elif unit == "m":
number *= 1_000_000 # Convert to million
param_counts.append(number)
if len(param_counts) > 0:
# Return the largest number
return int(max(param_counts))
else:
# Return -1 if no match found
return -1
def get_gpu_count_for_vllm(model_name: str, revision: str = "main", num_gpus: int = 8) -> int:
"""vLLM enforces a constraint that the number of attention heads must be divisible by the number of GPUs and 64 must be divisible by the number of GPUs.
This function calculates the number of GPUs to use for decoding based on the number of attention heads in the model.
"""
config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
# Get number of attention heads
num_heads = config.num_attention_heads
# Reduce num_gpus so that num_heads is divisible by num_gpus and 64 is divisible by num_gpus
while num_heads % num_gpus != 0 or 64 % num_gpus != 0:
logger.info(f"Reducing num_gpus from {num_gpus} to {num_gpus - 1} to make num_heads divisible by num_gpus")
num_gpus -= 1
return num_gpus
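# --- A minimal offline sketch (illustration only) of the two fallback heuristics above: the
# repo-ID regex used when safetensors metadata is unavailable, and the GPU-count reduction loop,
# replayed on toy values so that no Hub access is needed. The repo IDs and head count are assumptions.
if __name__ == "__main__":
    pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])"
    print(re.findall(pattern, "mixtral-8x7b"))   # [('8x7', '8', '', 'x7', '7', '', 'b')] -> 8 * 7 = 56B
    print(re.findall(pattern, "qwen2.5-vl-7b"))  # [('7', '7', '', '', '', '', 'b')]      -> 7B

    num_heads, num_gpus = 28, 8  # toy head count; the real value comes from AutoConfig
    while num_heads % num_gpus != 0 or 64 % num_gpus != 0:
        num_gpus -= 1
    print(num_gpus)  # 4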

220
qw/open_r1/utils/math.py Normal file

@ -0,0 +1,220 @@
from math_verify import parse, verify
def compute_score(solution_str, ground_truth) -> float:
retval = 0.
if solution_str == ground_truth:
return 1.0
if float(verify(parse(solution_str), parse(ground_truth))) > 0:
return 1.0
try:
answer = solution_str
string_in_last_boxed = last_boxed_only_string(solution_str)
if string_in_last_boxed is not None:
answer = remove_boxed(string_in_last_boxed)
if is_equiv(answer, ground_truth):
return 1.0
except Exception as e:
print(e)
return retval
def remove_boxed(s):
if "\\boxed " in s:
left = "\\boxed "
assert s[:len(left)] == left
return s[len(left):]
left = "\\boxed{"
assert s[:len(left)] == left
assert s[-1] == "}"
return s[len(left):-1]
def last_boxed_only_string(string):
idx = string.rfind("\\boxed")
if "\\boxed " in string:
return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
return retval
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
if str1 is None and str2 is None:
print("WARNING: Both None")
return True
if str1 is None or str2 is None:
return False
try:
ss1 = strip_string(str1)
ss2 = strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except Exception:
return str1 == str2
def fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except AssertionError:
return string
def remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def strip_string(string):
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "") # noqa: W605
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = fix_a_slash_b(string)
return string
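# --- A few illustrative checks of the helpers above (worked out by hand from the rules they
# implement; not an exhaustive test suite).
if __name__ == "__main__":
    assert last_boxed_only_string("Thus the answer is \\boxed{42}.") == "\\boxed{42}"
    assert remove_boxed("\\boxed{42}") == "42"
    assert strip_string("\\frac12") == "\\frac{1}{2}"  # \frac12 -> \frac{1}{2}
    assert is_equiv("3/4", "\\frac{3}{4}")             # a/b -> \frac{a}{b}
    assert is_equiv("0.5", "\\frac12")                 # 0.5 is special-cased to \frac{1}{2}
    print("all normalization checks passed")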


@ -0,0 +1,398 @@
import json
import time
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
import numpy as np
import copy
import itertools
# from . import mask as maskUtils  # NOTE: disabled in this copy; the RLE paths below (showAnns, loadRes, annToRLE) require pycocotools' mask utilities
import os
from collections import defaultdict
import sys
PYTHON_VERSION = sys.version_info[0]
if PYTHON_VERSION == 2:
from urllib import urlretrieve
elif PYTHON_VERSION == 3:
from urllib.request import urlretrieve
def _isArrayLike(obj):
return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
class COCO:
def __init__(self, annotation_file=None):
"""
Constructor of Microsoft COCO helper class for reading and visualizing annotations.
:param annotation_file (str): location of annotation file
:param image_folder (str): location to the folder that hosts images.
:return:
"""
# load dataset
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
if not annotation_file == None:
# print('loading annotations into memory...')
tic = time.time()
if type(annotation_file) == dict:
dataset = annotation_file
else:
dataset = json.load(open(annotation_file, 'r'))
assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
# print('Done (t={:0.2f}s)'.format(time.time()- tic))
self.dataset = dataset
self.createIndex()
def createIndex(self):
# create index
# print('creating index...')
anns, cats, imgs = {}, {}, {}
imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
if 'annotations' in self.dataset:
for ann in self.dataset['annotations']:
imgToAnns[ann['image_id']].append(ann)
anns[ann['id']] = ann
if 'images' in self.dataset:
for img in self.dataset['images']:
imgs[img['id']] = img
if 'categories' in self.dataset:
for cat in self.dataset['categories']:
cats[cat['id']] = cat
if 'annotations' in self.dataset and 'categories' in self.dataset:
for ann in self.dataset['annotations']:
catToImgs[ann['category_id']].append(ann['image_id'])
# print('index created!')
# create class members
self.anns = anns
self.imgToAnns = imgToAnns
self.catToImgs = catToImgs
self.imgs = imgs
self.cats = cats
def info(self):
"""
Print information about the annotation file.
:return:
"""
for key, value in self.dataset['info'].items():
print('{}: {}'.format(key, value))
def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
"""
Get ann ids that satisfy given filter conditions. default skips that filter
:param imgIds (int array) : get anns for given imgs
catIds (int array) : get anns for given cats
areaRng (float array) : get anns for given area range (e.g. [0 inf])
iscrowd (boolean) : get anns for given crowd label (False or True)
:return: ids (int array) : integer array of ann ids
"""
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
catIds = catIds if _isArrayLike(catIds) else [catIds]
if len(imgIds) == len(catIds) == len(areaRng) == 0:
anns = self.dataset['annotations']
else:
if not len(imgIds) == 0:
lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
anns = list(itertools.chain.from_iterable(lists))
else:
anns = self.dataset['annotations']
anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
if not iscrowd == None:
ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
else:
ids = [ann['id'] for ann in anns]
return ids
def getCatIds(self, catNms=[], supNms=[], catIds=[]):
"""
filtering parameters. default skips that filter.
:param catNms (str array) : get cats for given cat names
:param supNms (str array) : get cats for given supercategory names
:param catIds (int array) : get cats for given cat ids
:return: ids (int array) : integer array of cat ids
"""
catNms = catNms if _isArrayLike(catNms) else [catNms]
supNms = supNms if _isArrayLike(supNms) else [supNms]
catIds = catIds if _isArrayLike(catIds) else [catIds]
if len(catNms) == len(supNms) == len(catIds) == 0:
cats = self.dataset['categories']
else:
cats = self.dataset['categories']
cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
ids = [cat['id'] for cat in cats]
return ids
def getImgIds(self, imgIds=[], catIds=[]):
'''
Get img ids that satisfy given filter conditions.
:param imgIds (int array) : get imgs for given ids
:param catIds (int array) : get imgs with all given cats
:return: ids (int array) : integer array of img ids
'''
imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
catIds = catIds if _isArrayLike(catIds) else [catIds]
if len(imgIds) == len(catIds) == 0:
ids = self.imgs.keys()
else:
ids = set(imgIds)
for i, catId in enumerate(catIds):
if i == 0 and len(ids) == 0:
ids = set(self.catToImgs[catId])
else:
ids &= set(self.catToImgs[catId])
return list(ids)
def loadAnns(self, ids=[]):
"""
Load anns with the specified ids.
:param ids (int array) : integer ids specifying anns
:return: anns (object array) : loaded ann objects
"""
if _isArrayLike(ids):
return [self.anns[id] for id in ids]
elif type(ids) == int:
return [self.anns[ids]]
def loadCats(self, ids=[]):
"""
Load cats with the specified ids.
:param ids (int array) : integer ids specifying cats
:return: cats (object array) : loaded cat objects
"""
if _isArrayLike(ids):
return [self.cats[id] for id in ids]
elif type(ids) == int:
return [self.cats[ids]]
def loadImgs(self, ids=[]):
"""
Load imgs with the specified ids.
:param ids (int array) : integer ids specifying img
:return: imgs (object array) : loaded img objects
"""
if _isArrayLike(ids):
return [self.imgs[id] for id in ids]
elif type(ids) == int:
return [self.imgs[ids]]
def showAnns(self, anns, draw_bbox=False):
"""
Display the specified annotations.
:param anns (array of object): annotations to display
:return: None
"""
if len(anns) == 0:
return 0
if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
datasetType = 'instances'
elif 'caption' in anns[0]:
datasetType = 'captions'
else:
raise Exception('datasetType not supported')
if datasetType == 'instances':
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in anns:
c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
if 'segmentation' in ann:
if type(ann['segmentation']) == list:
# polygon
for seg in ann['segmentation']:
poly = np.array(seg).reshape((int(len(seg)/2), 2))
polygons.append(Polygon(poly))
color.append(c)
else:
# mask
t = self.imgs[ann['image_id']]
if type(ann['segmentation']['counts']) == list:
rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
else:
rle = [ann['segmentation']]
m = maskUtils.decode(rle)
img = np.ones( (m.shape[0], m.shape[1], 3) )
if ann['iscrowd'] == 1:
color_mask = np.array([2.0,166.0,101.0])/255
if ann['iscrowd'] == 0:
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack( (img, m*0.5) ))
if 'keypoints' in ann and type(ann['keypoints']) == list:
# turn skeleton into zero-based index
sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
kp = np.array(ann['keypoints'])
x = kp[0::3]
y = kp[1::3]
v = kp[2::3]
for sk in sks:
if np.all(v[sk]>0):
plt.plot(x[sk],y[sk], linewidth=3, color=c)
plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
if draw_bbox:
[bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
np_poly = np.array(poly).reshape((4,2))
polygons.append(Polygon(np_poly))
color.append(c)
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
ax.add_collection(p)
p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
ax.add_collection(p)
elif datasetType == 'captions':
for ann in anns:
print(ann['caption'])
def loadRes(self, resFile):
"""
Load result file and return a result api object.
:param resFile (str) : file name of result file
:return: res (obj) : result api object
"""
res = COCO()
res.dataset['images'] = [img for img in self.dataset['images']]
# print('Loading and preparing results...')
tic = time.time()
if type(resFile) == str or (PYTHON_VERSION == 2 and type(resFile) == unicode):
anns = json.load(open(resFile))
elif type(resFile) == np.ndarray:
anns = self.loadNumpyAnnotations(resFile)
else:
anns = resFile
assert type(anns) == list, 'results in not an array of objects'
annsImgIds = [ann['image_id'] for ann in anns]
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
'Results do not correspond to current coco set'
if 'caption' in anns[0]:
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
for id, ann in enumerate(anns):
ann['id'] = id+1
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
bb = ann['bbox']
x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
if not 'segmentation' in ann:
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
ann['area'] = bb[2]*bb[3]
ann['id'] = id+1
ann['iscrowd'] = 0
elif 'segmentation' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
# now only support compressed RLE format as segmentation results
ann['area'] = maskUtils.area(ann['segmentation'])
if not 'bbox' in ann:
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
ann['id'] = id+1
ann['iscrowd'] = 0
elif 'keypoints' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
s = ann['keypoints']
x = s[0::3]
y = s[1::3]
x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y)
ann['area'] = (x1-x0)*(y1-y0)
ann['id'] = id + 1
ann['bbox'] = [x0,y0,x1-x0,y1-y0]
# print('DONE (t={:0.2f}s)'.format(time.time()- tic))
res.dataset['annotations'] = anns
res.createIndex()
return res
def download(self, tarDir = None, imgIds = [] ):
'''
Download COCO images from mscoco.org server.
:param tarDir (str): COCO results directory name
imgIds (list): images to be downloaded
:return:
'''
if tarDir is None:
print('Please specify target directory')
return -1
if len(imgIds) == 0:
imgs = self.imgs.values()
else:
imgs = self.loadImgs(imgIds)
N = len(imgs)
if not os.path.exists(tarDir):
os.makedirs(tarDir)
for i, img in enumerate(imgs):
tic = time.time()
fname = os.path.join(tarDir, img['file_name'])
if not os.path.exists(fname):
urlretrieve(img['coco_url'], fname)
print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic))
def loadNumpyAnnotations(self, data):
"""
Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class}
:param data (numpy.ndarray)
:return: annotations (python nested list)
"""
print('Converting ndarray to lists...')
assert(type(data) == np.ndarray)
print(data.shape)
assert(data.shape[1] == 7)
N = data.shape[0]
ann = []
for i in range(N):
if i % 1000000 == 0:
print('{}/{}'.format(i,N))
ann += [{
'image_id' : int(data[i, 0]),
'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ],
'score' : data[i, 5],
'category_id': int(data[i, 6]),
}]
return ann
def annToRLE(self, ann):
"""
Convert annotation which can be polygons, uncompressed RLE to RLE.
:return: binary mask (numpy 2D array)
"""
t = self.imgs[ann['image_id']]
h, w = t['height'], t['width']
segm = ann['segmentation']
if type(segm) == list:
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = maskUtils.frPyObjects(segm, h, w)
rle = maskUtils.merge(rles)
elif type(segm['counts']) == list:
# uncompressed RLE
rle = maskUtils.frPyObjects(segm, h, w)
else:
# rle
rle = ann['segmentation']
return rle
def annToMask(self, ann):
"""
Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
:return: binary mask (numpy 2D array)
"""
rle = self.annToRLE(ann)
m = maskUtils.decode(rle)
return m

View File

@ -0,0 +1,532 @@
import numpy as np
import datetime
import time
from collections import defaultdict
from pycocotools import mask as maskUtils
import copy
class COCOeval:
# Interface for evaluating detection on the Microsoft COCO dataset.
#
# The usage for CocoEval is as follows:
# cocoGt=..., cocoDt=... # load dataset and results
# E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object
# E.params.recThrs = ...; # set parameters as desired
# E.evaluate(); # run per image evaluation
# E.accumulate(); # accumulate per image results
# E.summarize(); # display summary metrics of results
# For example usage see evalDemo.m and http://mscoco.org/.
#
# The evaluation parameters are as follows (defaults in brackets):
# imgIds - [all] N img ids to use for evaluation
# catIds - [all] K cat ids to use for evaluation
# iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation
# recThrs - [0:.01:1] R=101 recall thresholds for evaluation
# areaRng - [...] A=4 object area ranges for evaluation
# maxDets - [1 10 100] M=3 thresholds on max detections per image
# iouType - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints'
# iouType replaced the now DEPRECATED useSegm parameter.
# useCats - [1] if true use category labels for evaluation
# Note: if useCats=0 category labels are ignored as in proposal scoring.
# Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.
#
# evaluate(): evaluates detections on every image and every category and
# concats the results into the "evalImgs" with fields:
# dtIds - [1xD] id for each of the D detections (dt)
# gtIds - [1xG] id for each of the G ground truths (gt)
# dtMatches - [TxD] matching gt id at each IoU or 0
# gtMatches - [TxG] matching dt id at each IoU or 0
# dtScores - [1xD] confidence of each dt
# gtIgnore - [1xG] ignore flag for each gt
# dtIgnore - [TxD] ignore flag for each dt at each IoU
#
# accumulate(): accumulates the per-image, per-category evaluation
# results in "evalImgs" into the dictionary "eval" with fields:
# params - parameters used for evaluation
# date - date evaluation was performed
# counts - [T,R,K,A,M] parameter dimensions (see above)
# precision - [TxRxKxAxM] precision for every evaluation setting
# recall - [TxKxAxM] max recall for every evaluation setting
# Note: precision and recall==-1 for settings with no gt objects.
#
# See also coco, mask, pycocoDemo, pycocoEvalDemo
#
# Microsoft COCO Toolbox. version 2.0
# Data, paper, and tutorials available at: http://mscoco.org/
# Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
# Licensed under the Simplified BSD License [see coco/license.txt]
def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
'''
Initialize CocoEval using coco APIs for gt and dt
:param cocoGt: coco object with ground truth annotations
:param cocoDt: coco object with detection results
:return: None
'''
if not iouType:
print('iouType not specified. use default iouType segm')
self.cocoGt = cocoGt # ground truth COCO API
self.cocoDt = cocoDt # detections COCO API
self.evalImgs = defaultdict(list) # per-image per-category evaluation results [KxAxI] elements
self.eval = {} # accumulated evaluation results
self._gts = defaultdict(list) # gt for evaluation
self._dts = defaultdict(list) # dt for evaluation
self.params = Params(iouType=iouType) # parameters
self._paramsEval = {} # parameters for evaluation
self.stats = [] # result summarization
self.ious = {} # ious between all gts and dts
if not cocoGt is None:
self.params.imgIds = sorted(cocoGt.getImgIds())
self.params.catIds = sorted(cocoGt.getCatIds())
def _prepare(self):
'''
Prepare ._gts and ._dts for evaluation based on params
:return: None
'''
def _toMask(anns, coco):
# modify ann['segmentation'] by reference
for ann in anns:
rle = coco.annToRLE(ann)
ann['segmentation'] = rle
p = self.params
if p.useCats:
gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
else:
gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
# convert ground truth to mask if iouType == 'segm'
if p.iouType == 'segm':
_toMask(gts, self.cocoGt)
_toMask(dts, self.cocoDt)
# set ignore flag
for gt in gts:
gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0
gt['ignore'] = 'iscrowd' in gt and gt['iscrowd']
if p.iouType == 'keypoints':
gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore']
self._gts = defaultdict(list) # gt for evaluation
self._dts = defaultdict(list) # dt for evaluation
for gt in gts:
self._gts[gt['image_id'], gt['category_id']].append(gt)
for dt in dts:
self._dts[dt['image_id'], dt['category_id']].append(dt)
self.evalImgs = defaultdict(list) # per-image per-category evaluation results
self.eval = {} # accumulated evaluation results
def evaluate(self):
'''
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
:return: None
'''
tic = time.time()
#('Running per image evaluation...')
p = self.params
# add backward compatibility if useSegm is specified in params
if not p.useSegm is None:
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
# print('Evaluate annotation type *{}*'.format(p.iouType))
p.imgIds = list(np.unique(p.imgIds))
if p.useCats:
p.catIds = list(np.unique(p.catIds))
p.maxDets = sorted(p.maxDets)
self.params=p
self._prepare()
# loop through images, area range, max detection number
catIds = p.catIds if p.useCats else [-1]
if p.iouType == 'segm' or p.iouType == 'bbox':
computeIoU = self.computeIoU
elif p.iouType == 'keypoints':
computeIoU = self.computeOks
self.ious = {(imgId, catId): computeIoU(imgId, catId) \
for imgId in p.imgIds
for catId in catIds}
evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1]
self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet)
for catId in catIds
for areaRng in p.areaRng
for imgId in p.imgIds
]
self._paramsEval = copy.deepcopy(self.params)
toc = time.time()
#print('DONE (t={:0.2f}s).'.format(toc-tic))
def computeIoU(self, imgId, catId):
p = self.params
if p.useCats:
gt = self._gts[imgId,catId]
dt = self._dts[imgId,catId]
else:
gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
if len(gt) == 0 and len(dt) ==0:
return []
inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
dt = [dt[i] for i in inds]
if len(dt) > p.maxDets[-1]:
dt=dt[0:p.maxDets[-1]]
if p.iouType == 'segm':
g = [g['segmentation'] for g in gt]
d = [d['segmentation'] for d in dt]
elif p.iouType == 'bbox':
g = [g['bbox'] for g in gt]
d = [d['bbox'] for d in dt]
else:
raise Exception('unknown iouType for iou computation')
# compute iou between each dt and gt region
iscrowd = [int(o['iscrowd']) for o in gt]
ious = maskUtils.iou(d,g,iscrowd)
return ious
def computeOks(self, imgId, catId):
p = self.params
# dimention here should be Nxm
gts = self._gts[imgId, catId]
dts = self._dts[imgId, catId]
inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
dts = [dts[i] for i in inds]
if len(dts) > p.maxDets[-1]:
dts = dts[0:p.maxDets[-1]]
# if len(gts) == 0 and len(dts) == 0:
if len(gts) == 0 or len(dts) == 0:
return []
ious = np.zeros((len(dts), len(gts)))
sigmas = p.kpt_oks_sigmas
vars = (sigmas * 2)**2
k = len(sigmas)
# compute oks between each detection and ground truth object
for j, gt in enumerate(gts):
# create bounds for ignore regions(double the gt bbox)
g = np.array(gt['keypoints'])
xg = g[0::3]; yg = g[1::3]; vg = g[2::3]
k1 = np.count_nonzero(vg > 0)
bb = gt['bbox']
x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2
for i, dt in enumerate(dts):
d = np.array(dt['keypoints'])
xd = d[0::3]; yd = d[1::3]
if k1>0:
# measure the per-keypoint distance if keypoints visible
dx = xd - xg
dy = yd - yg
else:
# measure minimum distance to keypoints in (x0,y0) & (x1,y1)
z = np.zeros((k))
dx = np.max((z, x0-xd),axis=0)+np.max((z, xd-x1),axis=0)
dy = np.max((z, y0-yd),axis=0)+np.max((z, yd-y1),axis=0)
e = (dx**2 + dy**2) / vars / (gt['area']+np.spacing(1)) / 2
if k1 > 0:
e=e[vg > 0]
ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
return ious
def evaluateImg(self, imgId, catId, aRng, maxDet):
'''
perform evaluation for single category and image
:return: dict (single image results)
'''
p = self.params
if p.useCats:
gt = self._gts[imgId,catId]
dt = self._dts[imgId,catId]
else:
gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
if len(gt) == 0 and len(dt) ==0:
return None
for g in gt:
if g['ignore'] or (g['area']<aRng[0] or g['area']>aRng[1]):
g['_ignore'] = 1
else:
g['_ignore'] = 0
# sort dt highest score first, sort gt ignore last
gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
gt = [gt[i] for i in gtind]
dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
dt = [dt[i] for i in dtind[0:maxDet]]
iscrowd = [int(o['iscrowd']) for o in gt]
# load computed ious
ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId]
T = len(p.iouThrs)
G = len(gt)
D = len(dt)
gtm = np.zeros((T,G))
dtm = np.zeros((T,D))
gtIg = np.array([g['_ignore'] for g in gt])
dtIg = np.zeros((T,D))
if not len(ious)==0:
for tind, t in enumerate(p.iouThrs):
for dind, d in enumerate(dt):
# information about best match so far (m=-1 -> unmatched)
iou = min([t,1-1e-10])
m = -1
for gind, g in enumerate(gt):
# if this gt already matched, and not a crowd, continue
if gtm[tind,gind]>0 and not iscrowd[gind]:
continue
# if dt matched to reg gt, and on ignore gt, stop
if m>-1 and gtIg[m]==0 and gtIg[gind]==1:
break
# continue to next gt unless better match made
if ious[dind,gind] < iou:
continue
# if match successful and best so far, store appropriately
iou=ious[dind,gind]
m=gind
# if match made store id of match for both dt and gt
if m ==-1:
continue
dtIg[tind,dind] = gtIg[m]
dtm[tind,dind] = gt[m]['id']
gtm[tind,m] = d['id']
# set unmatched detections outside of area range to ignore
a = np.array([d['area']<aRng[0] or d['area']>aRng[1] for d in dt]).reshape((1, len(dt)))
dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0)))
# store results for given image and category
return {
'image_id': imgId,
'category_id': catId,
'aRng': aRng,
'maxDet': maxDet,
'dtIds': [d['id'] for d in dt],
'gtIds': [g['id'] for g in gt],
'dtMatches': dtm,
'gtMatches': gtm,
'dtScores': [d['score'] for d in dt],
'gtIgnore': gtIg,
'dtIgnore': dtIg,
}
def accumulate(self, p = None):
'''
Accumulate per image evaluation results and store the result in self.eval
:param p: input params for evaluation
:return: None
'''
#print('Accumulating evaluation results...')
tic = time.time()
if not self.evalImgs:
print('Please run evaluate() first')
# allows input customized parameters
if p is None:
p = self.params
p.catIds = p.catIds if p.useCats == 1 else [-1]
T = len(p.iouThrs)
R = len(p.recThrs)
K = len(p.catIds) if p.useCats else 1
A = len(p.areaRng)
M = len(p.maxDets)
precision = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories
recall = -np.ones((T,K,A,M))
scores = -np.ones((T,R,K,A,M))
# create dictionary for future indexing
_pe = self._paramsEval
catIds = _pe.catIds if _pe.useCats else [-1]
setK = set(catIds)
setA = set(map(tuple, _pe.areaRng))
setM = set(_pe.maxDets)
setI = set(_pe.imgIds)
# get inds to evaluate
k_list = [n for n, k in enumerate(p.catIds) if k in setK]
m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA]
i_list = [n for n, i in enumerate(p.imgIds) if i in setI]
I0 = len(_pe.imgIds)
A0 = len(_pe.areaRng)
# retrieve E at each category, area range, and max number of detections
for k, k0 in enumerate(k_list):
Nk = k0*A0*I0
for a, a0 in enumerate(a_list):
Na = a0*I0
for m, maxDet in enumerate(m_list):
E = [self.evalImgs[Nk + Na + i] for i in i_list]
E = [e for e in E if not e is None]
if len(E) == 0:
continue
dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E])
# different sorting method generates slightly different results.
# mergesort is used to be consistent as Matlab implementation.
inds = np.argsort(-dtScores, kind='mergesort')
dtScoresSorted = dtScores[inds]
dtm = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds]
dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet] for e in E], axis=1)[:,inds]
gtIg = np.concatenate([e['gtIgnore'] for e in E])
npig = np.count_nonzero(gtIg==0 )
if npig == 0:
continue
tps = np.logical_and( dtm, np.logical_not(dtIg) )
fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg) )
tp_sum = np.cumsum(tps, axis=1).astype(dtype=float)
fp_sum = np.cumsum(fps, axis=1).astype(dtype=float)
for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
tp = np.array(tp)
fp = np.array(fp)
nd = len(tp)
rc = tp / npig
pr = tp / (fp+tp+np.spacing(1))
q = np.zeros((R,))
ss = np.zeros((R,))
if nd:
recall[t,k,a,m] = rc[-1]
else:
recall[t,k,a,m] = 0
# numpy is slow without cython optimization for accessing elements
# use python array gets significant speed improvement
pr = pr.tolist(); q = q.tolist()
for i in range(nd-1, 0, -1):
if pr[i] > pr[i-1]:
pr[i-1] = pr[i]
inds = np.searchsorted(rc, p.recThrs, side='left')
try:
for ri, pi in enumerate(inds):
q[ri] = pr[pi]
ss[ri] = dtScoresSorted[pi]
except:
pass
precision[t,:,k,a,m] = np.array(q)
scores[t,:,k,a,m] = np.array(ss)
self.eval = {
'params': p,
'counts': [T, R, K, A, M],
'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'precision': precision,
'recall': recall,
'scores': scores,
}
toc = time.time()
# print('DONE (t={:0.2f}s).'.format( toc-tic))
def summarize(self):
'''
Compute and display summary metrics for evaluation results.
Note this functin can *only* be applied on the default parameter setting
'''
def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
p = self.params
iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
typeStr = '(AP)' if ap==1 else '(AR)'
iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
if iouThr is None else '{:0.2f}'.format(iouThr)
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
if ap == 1:
# dimension of precision: [TxRxKxAxM]
s = self.eval['precision']
# IoU
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
s = s[:,:,:,aind,mind]
else:
# dimension of recall: [TxKxAxM]
s = self.eval['recall']
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
s = s[:,:,aind,mind]
if len(s[s>-1])==0:
mean_s = -1
else:
mean_s = np.mean(s[s>-1])
#print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
return mean_s
def _summarizeDets():
stats = np.zeros((12,))
stats[0] = _summarize(1)
stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
return stats
def _summarizeKps():
stats = np.zeros((10,))
stats[0] = _summarize(1, maxDets=20)
stats[1] = _summarize(1, maxDets=20, iouThr=.5)
stats[2] = _summarize(1, maxDets=20, iouThr=.75)
stats[3] = _summarize(1, maxDets=20, areaRng='medium')
stats[4] = _summarize(1, maxDets=20, areaRng='large')
stats[5] = _summarize(0, maxDets=20)
stats[6] = _summarize(0, maxDets=20, iouThr=.5)
stats[7] = _summarize(0, maxDets=20, iouThr=.75)
stats[8] = _summarize(0, maxDets=20, areaRng='medium')
stats[9] = _summarize(0, maxDets=20, areaRng='large')
return stats
if not self.eval:
raise Exception('Please run accumulate() first')
iouType = self.params.iouType
if iouType == 'segm' or iouType == 'bbox':
summarize = _summarizeDets
elif iouType == 'keypoints':
summarize = _summarizeKps
self.stats = summarize()
def __str__(self):
self.summarize()
class Params:
'''
Params for coco evaluation api
'''
def setDetParams(self):
self.imgIds = []
self.catIds = []
# np.arange causes trouble. the data point on arange is slightly larger than the true value
self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
self.maxDets = [1, 10, 100]
self.areaRng = [[0 ** 2, 1e5 ** 2], [0 ** 2, 32 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
self.areaRngLbl = ['all', 'small', 'medium', 'large']
self.useCats = 1
def setKpParams(self):
self.imgIds = []
self.catIds = []
# np.arange causes trouble. the data point on arange is slightly larger than the true value
self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
self.maxDets = [20]
self.areaRng = [[0 ** 2, 1e5 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
self.areaRngLbl = ['all', 'medium', 'large']
self.useCats = 1
self.kpt_oks_sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0
def __init__(self, iouType='segm'):
if iouType == 'segm' or iouType == 'bbox':
self.setDetParams()
elif iouType == 'keypoints':
self.setKpParams()
else:
raise Exception('iouType not supported')
self.iouType = iouType
# useSegm is deprecated
self.useSegm = None

View File

@ -0,0 +1,5 @@
from .vlm_module import VLMBaseModule
from .qwen_module import Qwen2VLModule
from .internvl_module import InvernVLModule
__all__ = ["VLMBaseModule", "Qwen2VLModule", "InvernVLModule"]

View File

@ -0,0 +1,292 @@
from open_r1.vlm_modules.vlm_module import VLMBaseModule
from typing import Dict, Any, Union
from transformers import AutoModel, AutoProcessor, AutoConfig
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers.feature_extraction_sequence_utils import BatchFeature
IMG_START_TOKEN='<img>'
IMG_END_TOKEN='</img>'
IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
class InvernVLModule(VLMBaseModule):
def __init__(self):
super().__init__()
self.conv_template = None
self.num_image_token = None
def get_vlm_key(self):
return "internvl"
def get_model_class(self, model_id: str, model_init_kwargs: dict):
assert "InternVL" in model_id, f"model_id must contain 'InternVL', but got {model_id}"
self.model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
# The model class of InternVL when being mapped has been determined by its config
model_cls = AutoModel
# InternVL should be inputted with "trust_remote_code=True"
model_init_kwargs["trust_remote_code"] = True
# "use_cache" should be removed
model_init_kwargs.pop("use_cache", None)
# "flash_attention_2" should be modified to "use_flash_attn" in InternVL
if "flash_attention_2" in model_init_kwargs.get("attn_implementation", ""):
model_init_kwargs["use_flash_attn"] = True
model_init_kwargs.pop("attn_implementation")
return model_cls
def post_model_init(self, model, processing_class):
self.conv_template = model.conv_template if self.conv_template is None else self.conv_template
self.num_image_token = model.num_image_token if self.num_image_token is None else self.num_image_token
img_context_token_id = processing_class.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
model.img_context_token_id = img_context_token_id
def is_embeds_input(self):
return True
def get_processing_class(self):
return AutoProcessor
def get_eos_token_id(self, processing_class):
eos_token_id = processing_class.convert_tokens_to_ids(self.conv_template.sep.strip())
return eos_token_id
def get_vision_modules_keywords(self):
return ['vision_model']
def get_custom_multimodal_keywords(self):
return ['pixel_values', 'image_flags']
def get_non_generate_params(self):
return ['image_flags']
def get_custom_processing_keywords(self):
return ['max_anyres_num']
def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
prompts_text = []
for example in inputs:
template = self.conv_template.copy()
conversation_list = example["prompt"]
system_message = extract_system_message(conversation_list)
if system_message is not None:
template.system_message = system_message
processed_list = process_conversation_list(conversation_list, system_message)
for i, processed_item in enumerate(processed_list):
if i % 2 == 0:
template.append_message(template.roles[0], processed_item)
else:
template.append_message(template.roles[1], processed_item)
if len(processed_list) % 2 == 1:
template.append_message(template.roles[1], None)
query = template.get_prompt()
prompts_text.append(query)
return prompts_text
def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
# Process images
full_pixel_values = []
num_patches_list = []
for img in images:
pixel_values = self._load_image(img, input_size=self.model_config.vision_config.image_size, max_num=processing_class.max_anyres_num)
full_pixel_values.append(pixel_values)
num_patches_list.append(pixel_values.shape[0])
full_pixel_values = torch.cat(full_pixel_values, dim=0)
# Process prompts
queries = []
image_idx = 0
for query in prompts_text:
while "<image>" in query:
num_patches = num_patches_list[image_idx]
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
query = query.replace("<image>", image_tokens, 1)
image_idx += 1
queries.append(query)
assert image_idx == len(num_patches_list)
model_inputs = processing_class(
queries,
return_tensors=return_tensors,
padding=padding,
padding_side=padding_side,
add_special_tokens=add_special_tokens,
)
model_inputs["pixel_values"] = full_pixel_values
# Only support pure-image data currently (each sample should contain the image)
model_inputs['image_flags'] = torch.ones(full_pixel_values.shape[0], dtype=torch.long)
model_inputs = BatchFeature(data=model_inputs)
return model_inputs
def _load_image(self, image: Image.Image, input_size: int=448, max_num:int=12):
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
@staticmethod
def get_question_template(task_type: str):
match task_type:
case _:
return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
@staticmethod
def format_reward_rec(completions, **kwargs):
"""Check if the InternVL model output matches a specific format."""
import re
pattern = r"<think>.*?</think>\s*<answer>.*?\[\d+,\s*\d+,\s*\d+,\s*\d+\].*?</answer>"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
@staticmethod
def iou_reward(completions, solution, **kwargs):
"""Calculate IoU reward between predicted bounding box from InternVL model and ground truth bounding box."""
"""Adopt soft iou reward here"""
import re
import os
from datetime import datetime
def iou(box1, box2):
inter_x1 = max(box1[0], box2[0])
inter_y1 = max(box1[1], box2[1])
inter_x2 = min(box1[2]-1, box2[2]-1)
inter_y2 = min(box1[3]-1, box2[3]-1)
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
else:
inter = 0
union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
return float(inter)/union
contents = [completion[0]["content"] for completion in completions]
rewards = []
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
answer_tag_pattern = r'<answer>(.*?)</answer>'
bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
for content, sol in zip(contents, solution):
reward = 0.0
# Try symbolic verification first
try:
content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
if content_answer_match:
content_answer = content_answer_match.group(1).strip()
bbox_match = re.search(bbox_pattern, content_answer)
if bbox_match:
bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
reward = iou(bbox, sol)
except Exception:
pass # Continue to next verification method if this fails
rewards.append(reward)
if os.getenv("DEBUG_MODE") == "true":
log_path = os.getenv("LOG_PATH")
# local_rank = int(os.getenv("LOCAL_RANK", 0))
with open(log_path, "a", encoding='utf-8') as f:
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
f.write(f"Content: {content}\n")
f.write(f"Solution: {sol}\n")
return rewards
def process_conversation_list(conversation_list, system_message=None, image_newline=True):
if system_message is not None:
conversation_list = conversation_list[1:]
processed_list = []
for item in conversation_list:
role = item["role"]
content = item["content"]
if isinstance(content, list):
overall_str = ""
for content_item in content:
if content_item.get("type") == "image":
overall_str += "<image>" if not image_newline else "<image>\n"
elif content_item.get("type") == "text":
overall_str += content_item.get("text")
else:
raise ValueError(f"Unsupported content type: {type(content_item)}")
processed_list.append(overall_str)
elif isinstance(content, str):
processed_list.append(content)
else:
raise ValueError(f"Unsupported content type: {type(content)}")
return processed_list
def extract_system_message(conversation_list):
if conversation_list[0]["role"] == "system":
if isinstance(conversation_list[0]["content"], list):
return conversation_list[0]["content"][0]["text"]
else:
return conversation_list[0]["content"]
return None
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images

View File

@ -0,0 +1,134 @@
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor
from typing import Dict, Any, Union
from trl.data_utils import maybe_apply_chat_template
import torch
from open_r1.vlm_modules.vlm_module import VLMBaseModule
class Qwen2VLModule(VLMBaseModule):
def __init__(self):
super().__init__()
def get_vlm_key(self):
return "qwen"
def get_model_class(self, model_id: str, model_init_kwargs: dict):
if "Qwen2-VL" in model_id:
model_cls = Qwen2VLForConditionalGeneration
elif "Qwen2.5-VL" in model_id:
model_cls = Qwen2_5_VLForConditionalGeneration
else:
raise ValueError(f"Unsupported model: {model_id}")
return model_cls
def post_model_init(self, model, processing_class):
pass
def get_processing_class(self):
return AutoProcessor
def get_vision_modules_keywords(self):
return ['visual']
def get_custom_multimodal_keywords(self):
return ['pixel_values', 'image_grid_thw']
def get_non_generate_params(self):
return []
def get_custom_processing_keywords(self):
return ['max_pixels', 'min_pixels']
def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
prompts_text = [maybe_apply_chat_template(example, processing_class)["prompt"] for example in inputs]
return prompts_text
def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
# FIXME
# This could only process pure-multimodal or pure-text inputs
if len(images) > 0:
prompt_inputs = processing_class(
text=prompts_text,
images=images,
return_tensors=return_tensors,
padding=padding,
padding_side=padding_side,
add_special_tokens=add_special_tokens)
else:
prompt_inputs = processing_class(
text=prompts_text,
return_tensors=return_tensors,
padding=padding,
padding_side=padding_side,
add_special_tokens=add_special_tokens)
return prompt_inputs
@staticmethod
def get_question_template(task_type: str):
match task_type:
case "rec":
return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
case _:
return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
@staticmethod
def format_reward_rec(completions, **kwargs):
"""Check if the Qwen model output matches a specific format."""
import re
pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
def format_reward(completions, **kwargs):
pattern = r"<think>.*?</think>\s*<answer>.*?\[.*?{\"bbox_2d\":\s*\[\s*\d+,\s*\d+,\s*\d+,\s*\d+\s*\]\s*,\s*\"label\":\s*\".*?\"\s*}.*?\].*?</answer>"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
@staticmethod
def iou_reward(completions, solution, **kwargs):
"""Calculate IoU reward between predicted bounding box from Qwen model and ground truth bounding box."""
import re
import os
from datetime import datetime
def iou(box1, box2):
inter_x1 = max(box1[0], box2[0])
inter_y1 = max(box1[1], box2[1])
inter_x2 = min(box1[2]-1, box2[2]-1)
inter_y2 = min(box1[3]-1, box2[3]-1)
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
else:
inter = 0
union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
return float(inter)/union
contents = [completion[0]["content"] for completion in completions]
rewards = []
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
answer_tag_pattern = r'<answer>(.*?)</answer>'
bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
for content, sol in zip(contents, solution):
reward = 0.0
# Try symbolic verification first
try:
content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
if content_answer_match:
content_answer = content_answer_match.group(1).strip()
bbox_match = re.search(bbox_pattern, content_answer)
if bbox_match:
bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
if iou(bbox, sol) > 0.5:
reward = 1.0
except Exception:
pass # Continue to next verification method if this fails
rewards.append(reward)
if os.getenv("DEBUG_MODE") == "true":
log_path = os.getenv("LOG_PATH")
# local_rank = int(os.getenv("LOCAL_RANK", 0))
with open(log_path, "a", encoding='utf-8') as f:
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
f.write(f"Content: {content}\n")
f.write(f"Solution: {sol}\n")
return rewards

View File

@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, Union
import torch
class VLMBaseModule(ABC):
def __init__(self):
super().__init__()
@abstractmethod
def get_vlm_key(self):
pass
@abstractmethod
def get_model_class(self, model_id: str, model_init_kwargs: dict):
pass
def post_model_init(self, model, processing_class):
pass
def is_embeds_input(self):
return False
@abstractmethod
def get_processing_class(self):
pass
@abstractmethod
def get_vision_modules_keywords(self):
pass
@abstractmethod
def get_custom_multimodal_keywords(self):
pass
@abstractmethod
def get_non_generate_params(self):
pass
@abstractmethod
def get_custom_processing_keywords(self):
pass
@abstractmethod
def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
pass
@abstractmethod
def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors, padding, padding_side, add_special_tokens):
pass

102
qw/test.py Normal file
View File

@ -0,0 +1,102 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from PIL import Image
import requests
import torch
from torchvision import io
from typing import Dict
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
import pickle
import re
from tqdm import tqdm
from peft import PeftModel
# Load the model in half-precision on the available device(s)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen2.5-VL-3B", torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2"
)
# 加载LoRA权重添加load_in_8bit=False和device_map参数
# model = PeftModel.from_pretrained(
# model,
# "hybrid_train_output/checkpoint-100",
# load_in_8bit=False,
# device_map="auto",
# is_trainable=False,
# assign = True
# )
# 确保模型处于评估模式
model.eval()
processor = AutoProcessor.from_pretrained("Qwen2.5-VL-3B")
# Image
# image = Image.open("Tesla.jpg")
# 定义提示文本
text_prompt = (
"<|im_start|>system\n"
"You are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
"<|vision_start|><|image_pad|><|vision_end|>"
"Please tell me the brand of the product in the picture between labels <answer/> and </answer> "
"and explain the reason between labels <thinking/> and </thinking>"
"<|im_end|>\n"
"<|im_start|>assistant"
)
# 加载测试数据
with open("../work/bal_data/test_data.pkl", "rb") as f:
test_data = pickle.load(f)
# 批处理大小
batch_size = 20
correct = 0
total = 0
# 遍历测试数据
for i in tqdm(range(0, len(test_data), batch_size)):
# 准备当前批次的数据
batch = test_data[i:i+batch_size]
batch_images = [item['image'] for item in batch]
batch_brands = [item['brand'] for item in batch]
batch_prompts = [text_prompt] * len(batch_images)
# 模型处理
inputs = processor(
text=batch_prompts,
images=batch_images,
padding=True,
return_tensors="pt"
)
inputs = inputs.to("cuda")
# 生成输出
output_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids = [
output_ids[len(input_ids):]
for input_ids, output_ids in zip(inputs.input_ids, output_ids)
]
output_texts = processor.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
# 提取预测的品牌名称并比较
for pred_text, true_brand in zip(output_texts, batch_brands):
# 使用正则表达式提取<answer>标签中的内容
match = re.search(r'<answer>(.*?)</answer>', pred_text)
if match:
pred_brand = match.group(1).strip().lower()
true_brand = true_brand.lower()
# 比较预测结果
if pred_brand == true_brand:
correct += 1
total += 1
# 计算并输出准确率
accuracy = correct / total if total > 0 else 0
print(f"准确率: {accuracy:.2%} ({correct}/{total})")

210
qw/train.py Normal file
View File

@ -0,0 +1,210 @@
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
from accelerate import Accelerator
import torch
import logging
from datetime import datetime
from transformers import BitsAndBytesConfig, AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import pickle
from torch.utils.data import Dataset
import re
from trl import GRPOTrainer, GRPOConfig
from open_r1.trainer.grpo_trainer import VLMGRPOTrainer
from open_r1.vlm_modules.qwen_module import Qwen2VLModule
# 在最开始添加日志配置
def setup_logging():
# 创建logs目录如果不存在
if not os.path.exists('logs'):
os.makedirs('logs')
# 生成带时间戳的日志文件名
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = f'logs/training_{timestamp}.log'
# 配置日志格式
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(message)s',
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler() # 同时输出到控制台
]
)
logging.info(f"日志文件创建于: {log_file}")
return log_file
# 设置日志
log_file = setup_logging()
# 初始化 accelerator
accelerator = Accelerator()
# ## TRAINING
bnb_config = BitsAndBytesConfig(
load_in_4bit= True,
bnb_4bit_quant_type= "nf4",
bnb_4bit_compute_dtype= torch.bfloat16,
# bnb_4bit_use_double_quant= True,
)
# # 修改模型加载部分
# model = AutoModelForImageTextToText.from_pretrained(
# "./model",
# quantization_config=bnb_config,
# torch_dtype=torch.bfloat16
# )
# model.gradient_checkpointing_enable()
# model = prepare_model_for_kbit_training(model, use_gradient_checkpointing = False)
peft_config = LoraConfig(
task_type="CAUSAL_LM", # 因为是Causal Language Model
inference_mode=False,
r=8, # LoRA 秩
lora_alpha=32, # LoRA alpha参数
lora_dropout=0.1, # Dropout概率
target_modules=[ # 需要训练的模型层
"q_proj",
"k_proj",
"v_proj",
"o_proj",
],
bias="none",
)
# # 打印原始模型的参数统计
# total_params = sum(p.numel() for p in model.parameters())
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# logging.info("=== 训练配置信息 ===")
# logging.info(f"原始模型总参数量: {total_params:,}")
# logging.info(f"原始模型可训练参数量: {trainable_params:,}")
# logging.info(f"原始模型可训练参数占比: {100 * trainable_params / total_params:.2f}%")
# model = get_peft_model(model, peft_config)
# # 打印QLora后的参数统计
# total_params = sum(p.numel() for p in model.parameters())
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# logging.info(f"\nQLora后总参数量: {total_params:,}")
# logging.info(f"QLora后可训练参数量: {trainable_params:,}")
# logging.info(f"QLora后可训练参数占比: {100 * trainable_params / total_params:.2f}%")
# # 开启需要训练的参数的梯度更新
# model.train()
# for name, param in model.named_parameters():
# if param.requires_grad:
# # logging.info(f"开启参数 {name} 的梯度更新")
# param.requires_grad_(True)
class ChatDataset(Dataset):
def __init__(self, data_path):
with open(data_path, 'rb') as f:
self.data = pickle.load(f)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
prompt_text = "Please tell me the brand of the product in the picture between labels <answer/> and </answer> and explain the reason between labels <thinking/> and </thinking>"
# 使用模板格式化prompt
formatted_prompt = (
"<|im_start|>system\n"
"You are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
"<|vision_start|><|image_pad|><|vision_end|>" + prompt_text + "<|im_end|>\n"
"<|im_start|>assistant"
)
return {
"prompt": formatted_prompt,
"image": item['image'],
"correct_brand": item['brand']
}
# 加载数据集
logging.info("加载训练数据...")
train_dataset = ChatDataset('../work/bal_data/frequent_brands_data.pkl')
logging.info(f"加载了 {len(train_dataset)} 条训练数据")
def reward_func(prompts, completions, **kwargs):
rewards = []
correct_brands = kwargs.get('correct_brand')
for completion, correct_brand in zip(completions, correct_brands):
reward = 0.0
# 提取<answer>标签中的内容
answer_match = re.search(r'<answer/>(.*?)</answer>', completion)
# 提取<thinking>标签中的内容
thinking_match = re.search(r'<thinking/>(.*?)</thinking>', completion)
if answer_match:
answer_content = answer_match.group(1).lower()
if correct_brand.lower() in answer_content:
reward += 1.0
if thinking_match:
thinking_content = thinking_match.group(1).lower() # 使用单独的变量
if correct_brand.lower() in thinking_content: # 使用thinking的内容
reward += 1.0
# 使用logging替代print
logging.debug(f"\nCompletion: {completion}")
logging.debug(f"Correct brand: {correct_brand}")
logging.debug(f"Final reward: {reward}")
rewards.append(reward)
return rewards
def get_training_args():
args = GRPOConfig(
output_dir="chat_grpo_output",
num_generations=6,
learning_rate=1e-5,
logging_steps=100,
max_prompt_length=None,
gradient_accumulation_steps=1,
max_completion_length=200,
per_device_train_batch_size=3,
max_steps=1000,
dataloader_pin_memory=False,
model_init_kwargs={
"quantization_config": bnb_config,
"torch_dtype": torch.bfloat16,
"use_cache": False
}
)
args.epsilon = 0.2
args.num_iterations = 1
return args
# 然后再创建trainer
trainer = VLMGRPOTrainer(
model='./Qwen2.5-VL-7B',
reward_funcs=reward_func,
args=get_training_args(),
train_dataset=train_dataset,
peft_config=peft_config,
vlm_module=Qwen2VLModule()
)
# 训练相关的日志
logging.info("开始训练...")
trainer.train()
logging.info("训练完成")
# 保存模型
output_dir = "chat_model_lora"
unwrapped_model = accelerator.unwrap_model(trainer.model)
unwrapped_model.save_pretrained(output_dir)
logging.info(f"模型已保存到 {output_dir}")