# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import textwrap
from collections import defaultdict
from typing import Any, Callable, Optional, Sized, Union

import torch
import torch.utils.data
import transformers
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
    AriaForConditionalGeneration,
    AriaProcessor,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url
from trl import GRPOTrainer

from accelerate.utils import is_peft_model, set_seed
import PIL.Image

import copy
from torch.utils.data import Sampler
import warnings

if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_wandb_available():
    import wandb

from open_r1.vlm_modules.vlm_module import VLMBaseModule

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
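# A minimal sketch of a custom callable matching `RewardFunc` (the name and scoring rule below are
# purely illustrative, not part of this trainer): it receives the prompts, the generated completions
# and any extra dataset columns as keyword arguments, and returns one float per completion, e.g.
#
#     def brevity_reward(prompts, completions, **kwargs) -> list[float]:
#         return [-float(len(str(c))) for c in completions]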


class RepeatRandomSampler(Sampler):
    """
    Sampler that repeats the indices of a dataset in a structured manner.

    Args:
        data_source (`Sized`):
            Dataset to sample from.
        mini_repeat_count (`int`):
            Number of times to repeat each index per batch.
        batch_size (`int`, *optional*, defaults to `1`):
            Number of unique indices per batch.
        repeat_count (`int`, *optional*, defaults to `1`):
            Number of times to repeat the full sampling process.
        seed (`int` or `None`, *optional*, defaults to `None`):
            Random seed for reproducibility.
    """

    def __init__(
        self,
        data_source: Sized,
        mini_repeat_count: int,
        batch_size: int = 1,
        repeat_count: int = 1,
        seed: Optional[int] = None,
    ):
        self.data_source = data_source
        self.mini_repeat_count = mini_repeat_count
        self.batch_size = batch_size
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
        indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
        indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]

        for chunk in indexes:
            for _ in range(self.repeat_count):
                for index in chunk:
                    for _ in range(self.mini_repeat_count):
                        yield index

    def __len__(self) -> int:
        return self.num_samples * self.mini_repeat_count * self.repeat_count


class VLMGRPOTrainer(Trainer):
    """
    Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
    paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).

    Example:

    ```python
    from datasets import load_dataset
    from trl import GRPOTrainer

    dataset = load_dataset("trl-lib/tldr", split="train")

    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs="weqweasdas/RM-Gemma-2B",
        train_dataset=dataset,
    )

    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
              a path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
              loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
              in `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
            Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
            functions with the prompts and completions and sum the rewards. Can be either:

            - A single reward function, such as:
                - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
                  path to a *directory* containing model weights saved using
                  [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
                  loaded using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with
                  `num_labels=1` and the keyword arguments in `args.model_init_kwargs`.
                - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
                - A custom reward function: The function is provided with the prompts and the generated completions,
                  plus any additional columns in the dataset. It should return a list of rewards. For more details,
                  see [Using a custom reward function](#using-a-custom-reward-function).
            - A list of reward functions, where each item can independently be any of the above types. Mixing
              different types within the list (e.g., a string model ID and a custom reward function) is allowed.
        args ([`GRPOConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default configuration is used.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset
            are ignored. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
            Processing class used to process the data. The padding side must be set to "left". If `None`, the
            processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
        reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
            Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:

            - A single processing class: Used when `reward_funcs` contains only one reward function.
            - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.

            If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
            `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
            For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
            the corresponding entries in `reward_processing_classes` are ignored.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
    """

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: GRPOConfig = None,
        vlm_module: VLMBaseModule = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        freeze_vision_modules: Optional[bool] = False,
        attn_implementation: str = "flash_attention_2",
        torch_dtype: str = "bfloat16",
        **kwargs,
    ):
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        self.vlm_module = vlm_module

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        # FIXME
        # Remember to modify this in the InternVL module as well
        model_init_kwargs["attn_implementation"] = attn_implementation
        if model_init_kwargs.get("torch_dtype") is None:
            model_init_kwargs["torch_dtype"] = torch_dtype

        assert isinstance(model, str), "model must be a string in the current implementation"
        model_id = model
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
            pass  # torch_dtype is already a torch.dtype or "auto" or None
        elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
            torch_dtype = getattr(torch, torch_dtype)
        else:
            raise ValueError(
                "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )
        # Disable caching if gradient checkpointing is enabled (not supported)
        model_init_kwargs["use_cache"] = (
            False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
        )
        model_cls = self.vlm_module.get_model_class(model_id, model_init_kwargs)
        model = model_cls.from_pretrained(model_id, **model_init_kwargs)

        # LoRA
        self.vision_modules_keywords = self.vlm_module.get_vision_modules_keywords()
        if peft_config is not None:
            def find_all_linear_names(model, multimodal_keywords):
                cls = torch.nn.Linear
                lora_module_names = set()
                for name, module in model.named_modules():
                    # LoRA is not applied to the vision modules
                    if any(mm_keyword in name for mm_keyword in multimodal_keywords):
                        continue
                    if isinstance(module, cls):
                        lora_module_names.add(name)
                # Iterate over a copy so that removing entries does not break the iteration
                for m in list(lora_module_names):  # needed for 16-bit
                    if "embed_tokens" in m:
                        lora_module_names.remove(m)
                return list(lora_module_names)

            target_modules = find_all_linear_names(model, self.vision_modules_keywords)
            peft_config.target_modules = target_modules
            model = get_peft_model(model, peft_config)

        # Freeze vision modules
        if freeze_vision_modules:
            print("Freezing vision modules...")
            for n, p in model.named_parameters():
                if any(keyword in n for keyword in self.vision_modules_keywords):
                    p.requires_grad = False

        # Enable gradient checkpointing if requested
        if args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args)

        # Reference model
        if is_deepspeed_zero3_enabled():
            self.ref_model = model_cls.from_pretrained(model_id, **model_init_kwargs)
        elif peft_config is None:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)
        else:
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None

        # Processing class
        if processing_class is None:
            processing_cls = self.vlm_module.get_processing_class()
            processing_class = processing_cls.from_pretrained(model_id, trust_remote_code=model_init_kwargs.get("trust_remote_code", None))
            for processing_keyword in self.vlm_module.get_custom_processing_keywords():
                if processing_keyword in kwargs:
                    setattr(processing_class, processing_keyword, kwargs[processing_keyword])
            if getattr(processing_class, "tokenizer", None) is not None:
                pad_token_id = processing_class.tokenizer.pad_token_id
                processing_class.pad_token_id = pad_token_id
                processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
            else:
                assert isinstance(processing_class, PreTrainedTokenizerBase), "processing_class must be an instance of PreTrainedTokenizerBase if it has no tokenizer attribute"
                pad_token_id = processing_class.pad_token_id

        self.vlm_module.post_model_init(model, processing_class)
        self.vlm_module.post_model_init(self.ref_model, processing_class)

        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError("The number of reward processing classes must match the number of reward functions.")

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator
        def data_collator(features):  # No data collation is needed in GRPO
            return features

        # Training arguments
        self.max_prompt_length = None  # max_prompt_length is not supported yet
        if args.max_prompt_length is not None:
            warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")

        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=1,
            pad_token_id=pad_token_id,
        )
        if hasattr(self.vlm_module, "get_eos_token_id"):  # For InternVL
            self.generation_config.eos_token_id = self.vlm_module.get_eos_token_id(processing_class)
        self.beta = args.beta
        self.epsilon = args.epsilon

        # Multi-step
        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
        # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle
        self._step = 0
        # Buffer the batch to reuse generated outputs across multiple updates
        self._buffered_inputs = [None] * args.gradient_accumulation_steps

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )
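
        # Illustrative numbers for the check below: with 8 processes and `per_device_train_batch_size=4`,
        # the global train batch size is 32, so `num_generations` must be one of 2, 4, 8, 16 or 32.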

        # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
        num_processes = self.accelerator.num_processes
        global_batch_size = args.per_device_train_batch_size * num_processes
        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if global_batch_size % n_gen == 0]
        if self.num_generations not in possible_values:
            raise ValueError(
                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
                f"batch size, the valid values for the number of generations are: {possible_values}."
            )
        if self.args.eval_strategy != "no":
            global_batch_size = args.per_device_eval_batch_size * num_processes
            possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if global_batch_size % n_gen == 0]
            if self.num_generations not in possible_values:
                raise ValueError(
                    f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
                    f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
                    f"eval batch size, the valid values for the number of generations are: {possible_values}."
                )

        # Ensure each process receives a unique seed to prevent duplicate completions when generating with
        # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
        # it's safer to set it in all cases.
        set_seed(args.seed, device_specific=True)

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

    def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
        """Enables gradient checkpointing for the model."""
        # Ensure use_cache is disabled
        model.config.use_cache = False

        # Enable gradient checkpointing on the base model for PEFT
        if is_peft_model(model):
            model.base_model.gradient_checkpointing_enable()
        # Enable gradient checkpointing for non-PEFT models
        else:
            try:
                model.gradient_checkpointing_enable()
            except Exception:
                # For InternVL; these operations are copied from the original training script of InternVL
                model.language_model.config.use_cache = False
                model.vision_model.gradient_checkpointing = True
                model.vision_model.encoder.gradient_checkpointing = True
                model.language_model._set_gradient_checkpointing()
                # This line is necessary, otherwise `model.gradient_checkpointing_enable()` would be executed again
                # during training, leading to an error since InternVL does not support this operation.
                args.gradient_checkpointing = False

        gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        use_reentrant = (
            "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
        )

        if use_reentrant:
            model.enable_input_require_grads()

        return model

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.
        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]

    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(self, model, input_ids, attention_mask, **custom_multimodal_inputs):
        logits = model(input_ids=input_ids, attention_mask=attention_mask, **custom_multimodal_inputs).logits  # (B, L, V)
        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
        input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
        # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)

    def _prepare_inputs(self, inputs):
        # Simple pass-through, just like original
        return inputs

    def _get_key_from_inputs(self, x, key):
        ele = x.get(key, None)
        assert ele is not None, f"The key {key} is not found in the input"
        if isinstance(ele, list):
            return [e for e in ele]
        else:
            return [ele]

    def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = self.vlm_module.prepare_prompt(self.processing_class, inputs)
        # Handle both pre-loaded images and image paths
        images = []
        for x in inputs:
            if "image" in x:
                imgs = self._get_key_from_inputs(x, "image")
            elif "image_path" in x and x["image_path"] is not None:
                imgs = [PIL.Image.open(p) for p in self._get_key_from_inputs(x, "image_path")]
            else:
                # Text-only sample: avoid reusing `imgs` from a previous iteration
                imgs = []

            for img in imgs:
                try:
                    # Ensure minimum dimensions of 28 pixels
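                    # (28 px likely matches the minimum image side expected by Qwen2-VL-style
                    # processors, i.e. 14-pixel patches with a 2x2 merge -- an explanatory note,
                    # not something checked in this file.)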
                    w, h = img.size
                    if w < 28 or h < 28:
                        # Calculate new dimensions maintaining aspect ratio
                        if w < h:
                            new_w = 28
                            new_h = int(h * (28 / w))
                        else:
                            new_h = 28
                            new_w = int(w * (28 / h))
                        img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
                except Exception:
                    pass

                images.append(img)

        prompt_inputs = self.vlm_module.prepare_model_inputs(
            self.processing_class,
            prompts_text,
            images,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )

        prompt_inputs = super()._prepare_inputs(prompt_inputs)

        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        # max_prompt_length is not supported yet
        # if self.max_prompt_length is not None:
        #     prompt_ids = prompt_ids[:, -self.max_prompt_length :]
        #     prompt_inputs["input_ids"] = prompt_ids
        #     prompt_mask = prompt_mask[:, -self.max_prompt_length :]
        #     prompt_inputs["attention_mask"] = prompt_mask

        # Generate completions
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            generate_returned_result = unwrapped_model.generate(
                **{k: v for k, v in prompt_inputs.items() if k not in self.vlm_module.get_non_generate_params()},
                generation_config=self.generation_config,
            )
            prompt_length = prompt_ids.size(1)
            if not self.vlm_module.is_embeds_input():
                prompt_completion_ids = generate_returned_result
                prompt_ids = prompt_completion_ids[:, :prompt_length]
                completion_ids = prompt_completion_ids[:, prompt_length:]
            else:
                # In this case, the input of the LLM backbone is the embedding of the combination of the image and text prompt
                # So the returned result of the `generate` method only contains the completion ids
                completion_ids = generate_returned_result
                prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
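        # Example: if a row of `completion_ids` is [t1, t2, EOS, x, x], the corresponding row of
        # `completion_mask` is [1, 1, 1, 0, 0]: the first EOS is kept, everything after it is ignored
        # by the loss and the completion-length metric.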

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        # Get the multimodal inputs
        multimodal_keywords = self.vlm_module.get_custom_multimodal_keywords()
        multimodal_inputs = {k: prompt_inputs[k] if k in prompt_inputs else None for k in multimodal_keywords}
        with torch.no_grad():
            # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
            # computation here, and use per_token_logps.detach() instead.
            if self.num_iterations > 1:
                old_per_token_logps = self._get_per_token_logps(
                    model, prompt_completion_ids, attention_mask, **multimodal_inputs
                )
                old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
            else:
                old_per_token_logps = None

            if self.beta == 0.0:
                ref_per_token_logps = None
            elif self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, **multimodal_inputs
                )
            else:
                with self.accelerator.unwrap_model(model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        model, prompt_completion_ids, attention_mask, **multimodal_inputs
                    )
            if ref_per_token_logps is not None:
                # Keep only the completion part, aligned with the shift done in `_get_per_token_logps`
                ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]

        # Decode the generated completions
        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]

        # Compute the rewards
        # No need to duplicate prompts as we're not generating multiple completions per prompt

        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                # Repeat all input columns (except "prompt" and "completion") to match the number of generations
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        # No need to duplicate prompts as we're not generating multiple completions per prompt
                        # reward_kwargs[key].extend([example[key]] * self.num_generations)
                        reward_kwargs[key].extend([example[key]])
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Gather rewards across processes
        rewards_per_func = self.accelerator.gather(rewards_per_func)

        # Sum the rewards from all reward functions
        rewards = rewards_per_func.sum(dim=1)

        # Compute grouped-wise rewards
        # Each group consists of num_generations completions for the same prompt
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
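        # This is the group-relative advantage of the GRPO paper: A_i = (r_i - mean(r_group)) / (std(r_group) + 1e-4),
        # where the small constant guards against zero variance when every completion in a group gets the same reward.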

        # Get only the local slice of advantages
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())

        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
            "multimodal_inputs": multimodal_inputs,
        }

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        # Check if we need to generate new completions or use buffered ones
        if self.state.global_step % self.num_iterations == 0:
            inputs = self._generate_and_score_completions(inputs, model)
            self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
        else:
            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
        self._step += 1

        # Get the prepared inputs
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        multimodal_inputs = inputs["multimodal_inputs"]

        # Concatenate for full sequence
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)

        # Get the current policy's log probabilities
        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, **multimodal_inputs)
        # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
        per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]

        # Get the advantages from inputs
        advantages = inputs["advantages"]

        # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
        # and use per_token_logps.detach() instead
        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()

        # Compute the policy ratio and clipped version
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
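        # PPO-style clipped surrogate, written per token: loss_t = -min(r_t * A, clip(r_t, 1 - eps, 1 + eps) * A),
        # with r_t = exp(logp - old_logp) the probability ratio and A the group-relative advantage.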

        # Add KL penalty if beta > 0
        if self.beta > 0:
            ref_per_token_logps = inputs["ref_per_token_logps"]
            per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
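            # This is the k3 estimator of KL(pi_theta || pi_ref): exp(x) - x - 1 with
            # x = log pi_ref - log pi_theta; it is unbiased and always non-negative.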
            per_token_loss = per_token_loss + self.beta * per_token_kl

            # Log KL divergence
            mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
            self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        # Compute final loss
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log clip ratio
        is_clipped = (per_token_loss1 < per_token_loss2).float()
        clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
        self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())

        return loss

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:  # transformers<=4.46
            super().log(logs)
        self._metrics.clear()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent(
            """\
            @article{zhihong2024deepseekmath,
                title  = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
                author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
                year   = 2024,
                eprint = {arXiv:2402.03300},
            }
            """
        )

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GRPO",
            trainer_citation=citation,
            paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
            paper_id="2402.03300",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))

    def _get_train_sampler(self) -> Sampler:
        """Returns a sampler that ensures proper data sampling for GRPO training."""
        effective_batch_size = (
            self.args.per_device_train_batch_size
            * self.accelerator.num_processes
            * self.args.gradient_accumulation_steps
        )
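
        # Illustrative numbers: per_device_train_batch_size=4, 8 processes and gradient_accumulation_steps=2
        # give effective_batch_size = 64; with num_generations=8 the sampler below draws 64 // 8 = 8 unique
        # prompts per effective batch, repeats each one 8 times in a row, and replays the whole pattern
        # num_iterations times so buffered generations can be reused.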

        return RepeatRandomSampler(
            data_source=self.train_dataset,
            mini_repeat_count=self.num_generations,
            batch_size=effective_batch_size // self.num_generations,
            repeat_count=self.num_iterations,
            seed=self.args.seed,
        )

    def _get_eval_sampler(self, eval_dataset) -> Sampler:
        """Returns a sampler for evaluation."""
        return RepeatRandomSampler(
            data_source=eval_dataset,
            mini_repeat_count=self.num_generations,
            seed=self.args.seed,
        )