# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class GRPOConfig(TrainingArguments):
    r"""
    Configuration class for the [`GRPOTrainer`].

    Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
    [`~transformers.TrainingArguments`] documentation.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.
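
    Example (a minimal, illustrative sketch; it assumes `GRPOConfig` is importable from `trl`, and the script name
    and flag values below are hypothetical):

    ```python
    from transformers import HfArgumentParser

    from trl import GRPOConfig

    parser = HfArgumentParser(GRPOConfig)
    (training_args,) = parser.parse_args_into_dataclasses()
    # e.g. launched as: python train.py --output_dir out --num_generations 16 --use_vllm True
    print(training_args.num_generations)  # 16
    ```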

    Parameters:
        > Parameters that control the model and reference model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`GRPOTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
            requires any column other than `"prompts"` and `"completions"`, you should keep this set to `False`.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. If the prompt is longer than this value, it will be truncated from the left.
        num_generations (`int` or `None`, *optional*, defaults to `8`):
            Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
            must be divisible by this value.
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        max_completion_length (`int` or `None`, *optional*, defaults to `256`):
            Maximum length of the generated completion.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
            with vLLM generation.

        > Parameters that control generation acceleration powered by vLLM

        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
            training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
        vllm_device (`str`, *optional*, defaults to `"auto"`):
            Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
            automatically select the next available GPU after the last one used for training. This assumes that
            training has not already occupied all available GPUs. If only one device is available, the device will be
            shared between both training and vLLM.
        vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
            during initialization.
        vllm_dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
            based on the model configuration. Find the supported values in the vLLM documentation.
        vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
            If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
            context size, which might be much larger than the KV cache, leading to inefficiencies.
        vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
            Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the
            hardware support this feature.
        vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
            speed.
        num_iterations (`int`, *optional*, defaults to `1`):
            Number of iterations per batch (denoted as μ in the algorithm).
        epsilon (`float`, *optional*, defaults to `0.2`):
            Epsilon value for clipping.
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
            weighted equally with weight `1.0`.
        sync_ref_model (`bool`, *optional*, defaults to `False`):
            Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
            the `ref_model_mixup_alpha` parameter. This synchronization originates from the
            [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
        ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
            α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
            between the current policy and the previous reference policy during updates. The reference policy is
            updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
            must set `sync_ref_model=True`.
        ref_model_sync_steps (`int`, *optional*, defaults to `512`):
            τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
            frequently the current policy is synchronized with the reference policy. To use this parameter, you must
            set `sync_ref_model=True`.

        > Parameters that control the logging

        log_completions (`bool`, *optional*, defaults to `False`):
            Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
            installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
    """

    # Parameters that control the model and reference model
    model_init_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
            "argument of the `GRPOTrainer` is provided as a string."
        },
    )

    # Parameters that control the data preprocessing
    # The default value of `remove_unused_columns` is overridden from the parent class, because in GRPO we usually
    # rely on additional columns to compute the reward
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
            "that requires any column other than 'prompts' and 'completions', you should keep this set to `False`."
        },
    )
    max_prompt_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated "
            "from the left."
        },
    )
    num_generations: Optional[int] = field(
        default=8,
        metadata={
            "help": "Number of generations per prompt to sample. The global batch size "
            "(num_processes * per_device_batch_size) must be divisible by this value."
        },
    )
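    # Illustrative arithmetic (not enforced here): with num_processes=4 and per_device_train_batch_size=4, the
    # global batch holds 4 * 4 = 16 generations, which is divisible by num_generations=8 (2 prompts, 8 completions
    # each).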
    temperature: Optional[float] = field(
        default=0.9,
        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
    )
    max_completion_length: Optional[int] = field(
        default=256,
        metadata={"help": "Maximum length of the generated completion."},
    )
    ds3_gather_for_generation: bool = field(
        default=True,
        metadata={
            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
            "generation, improving generation speed. However, disabling this option allows training models that "
            "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
            "is not compatible with vLLM generation."
        },
    )

    # Parameters that control generation acceleration powered by vLLM
    use_vllm: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
            "unused for training, as vLLM will require one for generation. vLLM must be installed "
            "(`pip install vllm`)."
        },
    )
    vllm_device: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
            "will automatically select the next available GPU after the last one used for training. This assumes "
            "that training has not already occupied all available GPUs."
        },
    )
    vllm_gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )
    vllm_dtype: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
            "determined based on the model configuration. Find the supported values in the vLLM documentation."
        },
    )
    vllm_max_model_len: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
            "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
            "context size, which might be much larger than the KV cache, leading to inefficiencies."
        },
    )
    vllm_enable_prefix_caching: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
            "the hardware support this feature."
        },
    )
    vllm_guided_decoding_regex: Optional[str] = field(
        default=None,
        metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
    )
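    # Illustrative (hypothetical values): a config that offloads generation to a dedicated GPU might look like
    # GRPOConfig(output_dir="out", use_vllm=True, vllm_device="cuda:1", vllm_gpu_memory_utilization=0.7).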

    # Parameters that control the training
    learning_rate: float = field(
        default=1e-6,
        metadata={
            "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
            "`transformers.TrainingArguments`."
        },
    )
    beta: float = field(
        default=0.04,
        metadata={
            "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
            "training speed."
        },
    )
    num_iterations: int = field(
        default=1,
        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
    )
    epsilon: float = field(
        default=0.2,
        metadata={"help": "Epsilon value for clipping."},
    )
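    # Note: `epsilon` plays the same role as in the PPO-style clipped surrogate; a sketch of the per-token term is
    # min(ratio * advantage, clamp(ratio, 1 - epsilon, 1 + epsilon) * advantage), where ratio = pi_theta / pi_old.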
    reward_weights: Optional[list[float]] = field(
        default=None,
        metadata={
            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
            "rewards are weighted equally with weight `1.0`."
        },
    )
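    # Illustrative: with two reward functions and reward_weights=[0.7, 0.3], the scalar reward for a completion is
    # assumed to be 0.7 * r1 + 0.3 * r2; with reward_weights=None, it is r1 + r2 (equal weights of 1.0).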
    sync_ref_model: bool = field(
        default=False,
        metadata={
            "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
            "steps, using the `ref_model_mixup_alpha` parameter."
        },
    )
    ref_model_mixup_alpha: float = field(
        default=0.6,
        metadata={
            "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
            "previous reference policy during updates. The reference policy is updated according to the equation: "
            "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
        },
    )
    ref_model_sync_steps: int = field(
        default=512,
        metadata={
            "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
            "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
        },
    )
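    # Illustrative sketch of the sync performed every `ref_model_sync_steps` steps, assuming an elementwise mix of
    # parameters with alpha = ref_model_mixup_alpha:
    #     ref_param.data = alpha * policy_param.data + (1 - alpha) * ref_param.data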

    # Parameters that control the logging
    log_completions: bool = field(
        default=False,
        metadata={
            "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
            "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
        },
    )