# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import textwrap
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Union
from unittest.mock import patch

import torch
import torch.nn as nn
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object
from accelerate.utils.other import is_compiled_module
from datasets import Dataset, IterableDataset
from packaging import version
from torch.utils.data import Sampler
from transformers import (
    AriaForConditionalGeneration,
    AriaProcessor,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl import GRPOTrainer
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.import_utils import is_vllm_available
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad

if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_vllm_available():
    from vllm import LLM, SamplingParams

if is_wandb_available():
    import wandb

# What we call a reward function is a callable that takes a list of prompts and completions and
# returns a list of rewards. When it is a string, it is treated as a model ID and loaded as a
# pretrained sequence-classification model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
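
# Illustrative sketch (hypothetical, not used below): a callable reward function receives
# `prompts` and `completions` plus any extra dataset columns as keyword arguments, and must
# return one float per completion.
def example_brevity_reward(prompts: list, completions: list, **kwargs) -> list[float]:
    # Favor shorter completions: the reward decays with the completion's character length.
    return [1.0 / (1.0 + len(str(completion))) for completion in completions]
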

class RepeatRandomSampler(Sampler):
    """
    Sampler that repeats each index of a dataset `repeat_count` times, in a random order.

    Args:
        data_source (`Sized`):
            Dataset to sample from.
        repeat_count (`int`):
            Number of times to repeat each index.

    Example:
    ```python
    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
    >>> list(sampler)
    [2, 2, 0, 0, 3, 3, 1, 1]
    ```
    """

    def __init__(self, data_source, repeat_count: int):
        self.data_source = data_source
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)

    def __iter__(self):
        # Shuffle once, then emit each shuffled index `repeat_count` times in a row.
        indexes = [
            idx
            for idx in torch.randperm(self.num_samples).tolist()
            for _ in range(self.repeat_count)
        ]
        return iter(indexes)

    def __len__(self):
        return self.num_samples * self.repeat_count
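
# Usage sketch (hypothetical dataset name): pairing the sampler with a DataLoader makes each
# example appear `repeat_count` times in a row, which is what lets GRPO group completions by prompt:
#
#     loader = torch.utils.data.DataLoader(
#         train_dataset, batch_size=8, sampler=RepeatRandomSampler(train_dataset, repeat_count=4)
#     )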

class Qwen2VLGRPOVLLMTrainer(Trainer):
    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: Optional[GRPOConfig] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[
            Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
        ] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[
            Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
        ] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[
            Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
        ] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        # Qwen2-VL related params
        max_pixels: Optional[int] = 12845056,
        min_pixels: Optional[int] = 3136,
        attn_implementation: str = "flash_attention_2",
    ):
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        model_init_kwargs["attn_implementation"] = attn_implementation
        if isinstance(model, str):
            model_id = model
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                pass  # torch_dtype is already a torch.dtype, "auto", or None
            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                torch_dtype = getattr(torch, torch_dtype)
                model_init_kwargs["torch_dtype"] = torch_dtype
            else:
                raise ValueError(
                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                )
            # Disable caching if gradient checkpointing is enabled (not supported)
            model_init_kwargs["use_cache"] = (
                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
            )
            if "Qwen2-VL" in model_id:
                model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
            elif "Qwen2.5-VL" in model_id:
                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
            elif "Aria" in model_id:
                model_init_kwargs.pop("use_cache")
                model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
            else:
                model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                raise ValueError(
                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
                    "This argument can only be used when the `model` argument is a string."
                )

        if peft_config is not None:
            model = get_peft_model(model, peft_config)

        # Reference model
        if is_deepspeed_zero3_enabled():
            # `create_reference_model` cannot deepcopy a ZeRO-3 sharded model, so reload it
            # from the checkpoint instead.
            if "Qwen2-VL" in model_id:
                self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
            elif "Qwen2.5-VL" in model_id:
                self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
            elif "Aria" in model_id:
                self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
            else:
                self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
        elif peft_config is None:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)
        else:
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None

        # Processing class
        if processing_class is None:
            if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id or "Aria" in model_id:
                processing_class = AutoProcessor.from_pretrained(model_id)
                pad_token_id = processing_class.tokenizer.pad_token_id
                processing_class.pad_token_id = pad_token_id
                processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
                if "Qwen" in model_id:  # covers both Qwen2-VL and Qwen2.5-VL
                    processing_class.image_processor.max_pixels = max_pixels
                    processing_class.image_processor.min_pixels = min_pixels
            else:
                processing_class = AutoTokenizer.from_pretrained(
                    model.config._name_or_path, padding_side="left"
                )
                pad_token_id = processing_class.pad_token_id

        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError(
                    "The number of reward processing classes must match the number of reward functions."
                )

        for i, (reward_processing_class, reward_func) in enumerate(
            zip(reward_processing_classes, reward_funcs)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator
        def data_collator(features):  # No data collation is needed in GRPO
            return features

        # Training arguments
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=1,  # HACK
            num_return_sequences=self.num_generations,
            pad_token_id=pad_token_id,
        )
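        # Note (assumption): generation in this trainer goes through vLLM's `self.sampling_params`
        # below, so this GenerationConfig appears unused here and seems kept for HF-generate parity.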
        self.beta = args.beta

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)
        self.use_vllm = args.use_vllm

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether
        # the model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Check that per_device_train/eval_batch_size * num_processes is divisible by the number of generations
        num_processes = self.accelerator.num_processes
        global_batch_size = args.per_device_train_batch_size * num_processes
        possible_values = [
            n_gen for n_gen in range(2, global_batch_size + 1) if global_batch_size % n_gen == 0
        ]

        if self.num_generations not in possible_values:
            raise ValueError(
                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
                f"batch size, the valid values for the number of generations are: {possible_values}."
            )
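        # Example: with 2 processes and per_device_train_batch_size=4, the global batch is 8,
        # so num_generations must be one of {2, 4, 8}.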
        if self.args.eval_strategy != "no":
            global_batch_size = args.per_device_eval_batch_size * num_processes
            possible_values = [
                n_gen for n_gen in range(2, global_batch_size + 1) if global_batch_size % n_gen == 0
            ]
            if self.num_generations not in possible_values:
                raise ValueError(
                    f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
                    f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
                    f"eval batch size, the valid values for the number of generations are: {possible_values}."
                )

        if self.use_vllm:
            if not is_vllm_available():
                raise ImportError(
                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
                    "`pip install vllm` to use it."
                )

            if self.accelerator.is_main_process:
                vllm_device = self.args.vllm_device
                if vllm_device == "auto":
                    vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
                # Check that the requested device is available
                if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
                    raise ValueError(
                        f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                        "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                        "value lower than the number of GPUs available on your machine—typically, reducing it by one "
                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
                    )
                # Check that the requested device is not also used for training
                if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
                    warnings.warn(
                        f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
                        "behavior. It is recommended to use a dedicated device for vLLM."
                    )
                # vLLM is not compatible with accelerate, so we patch it to (1) place the vLLM model on the desired
                # device (world_size_patch) and (2) skip a profiling check that is not designed for our setting
                # (profiling_patch).
                world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
                profiling_patch = patch(
                    "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
                    return_value=None,
                )
                with world_size_patch, profiling_patch:
                    print("vllm is running on: ", vllm_device)
                    self.llm = LLM(
                        model=model.name_or_path,
                        device=vllm_device,
                        gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
                        dtype=torch.bfloat16,
                        # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
                        # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
                        # This is particularly useful here because we generate completions from the same prompts.
                        enable_prefix_caching=True,
                        enforce_eager=True,
                        max_model_len=args.max_completion_length,
                    )
                self.sampling_params = SamplingParams(
                    temperature=args.temperature,
                    max_tokens=self.max_completion_length,
                )
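                # Note: SamplingParams is left with its default n=1, so vLLM returns one
                # completion per prompt entry; the group of `num_generations` completions per
                # prompt comes from RepeatRandomSampler repeating each prompt in the batch.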

            self._last_loaded_step = 0  # tag to avoid useless reloading during gradient accumulation

            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
            # synchronize all processes after vLLM has been fully initialized.
            self.accelerator.wait_for_everyone()
        else:
            raise ValueError(
                "Qwen2VLGRPOVLLMTrainer only supports vLLM generation. Please set `use_vllm=True`."
            )

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.
        # In this trainer, we preprocess the data ourselves, so the model's signature columns don't apply.
        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]

    # We need a custom sampler that samples each prompt `num_generations` times, so that completions
    # generated from the same prompt end up in the same group when computing grouped rewards.
    def _get_train_sampler(self):
        return RepeatRandomSampler(self.train_dataset, self.num_generations)

    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(
        self, model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep
    ):
        pixel_values = pixel_values.to(model.device)
        image_grid_thw = image_grid_thw.to(device=model.device)
        logits = model(
            input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
        ).logits  # (B, L, V)
        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
        input_ids = input_ids[:, -logits_to_keep:]  # (B, logits_to_keep), keep only the completion token IDs
        # Compute the log probabilities for the kept tokens. Use a loop to reduce the memory peak.
        logits = logits[:, -logits_to_keep:]
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)
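        # Shape sketch: with B=2 sequences, L=5 tokens, vocab V=8, and logits_to_keep=3, the
        # sliced `logits` are (2, 3, 8), each gathered row is (3,), and the result is (2, 3).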

    # The Trainer "prepares" inputs before `compute_loss` by converting them to tensors and moving them to the
    # device. Since we preprocess the data ourselves in this method, we override it to skip that step.
    def _prepare_inputs(
        self, inputs: dict[str, Union[torch.Tensor, Any]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        images = [x["image"] for x in inputs]
        prompts_text = [
            maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs
        ]
        prompt_inputs = self.processing_class(
            text=prompts_text,
            images=images,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )
        prompt_ids, prompt_mask = (
            prompt_inputs["input_ids"].to(device),
            prompt_inputs["attention_mask"].to(device),
        )
        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
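            # Note: the slicing above left-truncates, so a prompt longer than `max_prompt_length`
            # keeps only its last `max_prompt_length` tokens (those closest to the completion).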

        if self.args.use_vllm:
            # First, have the main process load the current policy weights into vLLM if needed
            if self.state.global_step != self._last_loaded_step:
                with unwrap_model_for_generation(
                    self.model,
                    self.accelerator,
                    gather_deepspeed3_params=False,  # TODO: fix this, self.args.ds3_gather_for_generation
                ) as unwrapped_model:
                    if is_compiled_module(unwrapped_model):
                        state_dict = unwrapped_model._orig_mod.state_dict()
                    else:
                        state_dict = unwrapped_model.state_dict()
                if self.accelerator.is_main_process:
                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                    llm_model.load_weights(state_dict.items())
                self._last_loaded_step = self.state.global_step

            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            all_images = gather_object(images)
            # Group prompts and images into vLLM multimodal inputs
            all_multimodal_inputs = [
                {"prompt": p, "multi_modal_data": {"image": i}}
                for p, i in zip(all_prompts_text, all_images)
            ]

            if self.accelerator.is_main_process:
                outputs = self.llm.generate(
                    all_multimodal_inputs,
                    sampling_params=self.sampling_params,
                    use_tqdm=False,
                )
                completion_ids = [
                    out.token_ids for completions in outputs for out in completions.outputs
                ]
            else:
                completion_ids = [None] * len(all_prompts_text)
            # Broadcast the completions from the main process to all processes, then slice out this
            # process's share.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
            completion_ids = completion_ids[process_slice]
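            # Example: with 2 processes and 4 local prompts each, the gathered list has 8 entries;
            # process 0 keeps completions [0:4] and process 1 keeps [4:8].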

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            raise ValueError("Only vLLM generation is supported in this version.")

        # The rest of this method follows the reference GRPO implementation.
        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        device = self.accelerator.device
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
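        # Example: completion_ids row [5, 9, eos, 7] gives is_eos [0, 0, 1, 0], eos_idx 2, and
        # completion_mask [1, 1, 1, 0]: the EOS itself is kept, later tokens are masked out.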

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
        # The processor already returns one set of image features per (repeated) prompt, so no
        # further repetition along the batch dimension is needed here.
        pixel_values = prompt_inputs["pixel_values"]
        image_grid_thw = prompt_inputs["image_grid_thw"]
        logits_to_keep = completion_ids.size(1)

        with torch.inference_mode():
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model,
                    prompt_completion_ids,
                    attention_mask,
                    pixel_values,
                    image_grid_thw,
                    logits_to_keep,
                )
            else:
                # With PEFT, disabling the adapter recovers the reference (base) model.
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        self.model,
                        prompt_completion_ids,
                        attention_mask,
                        pixel_values,
                        image_grid_thw,
                        logits_to_keep,
                    )

        # Decode the generated completions
        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]

        # Compute the rewards
        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts,
                    return_tensors="pt",
                    padding=True,
                    padding_side="right",
                    add_special_tokens=False,
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                # Repeat all input columns (except "prompt" and "completion") to match the number of generations
                reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in reward_kwargs:
                    for example in inputs:
                        # Repeat each value in the column `num_generations` times
                        reward_kwargs[key].extend([example[key]] * self.num_generations)
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
        rewards_per_func = gather(rewards_per_func)
        # Sum the rewards from all reward functions
        rewards = rewards_per_func.sum(dim=1)

        # Compute group-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
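        # Example: with num_generations=2 and rewards [1.0, 0.0, 3.0, 1.0], the group means are
        # [0.5, 2.0] and (sample) stds ~[0.71, 1.41], giving advantages of roughly
        # [0.71, -0.71, 0.71, -0.71] before the 1e-4 stabilizer.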

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]

        # Log the metrics
        reward_per_func = rewards_per_func.mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PreTrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(rewards.mean().item())
        self._metrics["reward_std"].append(std_grouped_rewards.mean().item())

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
            "pixel_values": pixel_values,
            "image_grid_thw": image_grid_thw,
        }

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        # Compute the per-token log probabilities for the model
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        pixel_values = inputs["pixel_values"]
        image_grid_thw = inputs["image_grid_thw"]
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(
            model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep
        )

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = (
            torch.exp(ref_per_token_logps - per_token_logps)
            - (ref_per_token_logps - per_token_logps)
            - 1
        )
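        # This is the k3 estimator of KL(pi_theta || pi_ref): with d = ref_logps - logps, it
        # computes exp(d) - d - 1, which is always >= 0 and unbiased in expectation under pi_theta.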

        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
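        # Numerically, exp(logps - logps.detach()) == 1 in the forward pass, so the loss value is
        # -(advantages - beta * KL) averaged over unmasked completion tokens, while its gradient
        # is the GRPO policy-gradient estimate.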

        # Log the metrics
        completion_length = (
            self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        )
        self._metrics["completion_length"].append(completion_length)

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if next(iter(logs.keys())).startswith("eval_"):
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:  # transformers<=4.46
            super().log(logs)
        self._metrics.clear()