# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class GRPOConfig(TrainingArguments):
    r"""
    Configuration class for the [`GRPOTrainer`].

    Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
    [`~transformers.TrainingArguments`] documentation.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

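    Example (a minimal sketch of parsing this configuration from the command line; the script name and flag values
    are illustrative, not recommendations):

    ```python
    from transformers import HfArgumentParser

    from trl import GRPOConfig

    # Hypothetical launch: python train_grpo.py --output_dir out --num_generations 8 --use_vllm True
    parser = HfArgumentParser(GRPOConfig)
    (training_args,) = parser.parse_args_into_dataclasses()
    print(training_args.num_generations)
    ```
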
    Parameters:
        > Parameters that control the model and reference model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`GRPOTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
            requires any column other than `"prompts"` and `"completions"`, you should keep this set to `False`.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. If the prompt is longer than this value, it will be truncated from the
            left.
        num_generations (`int` or `None`, *optional*, defaults to `8`):
            Number of generations per prompt to sample. The global batch size
            (num_processes * per_device_batch_size) must be divisible by this value; see the worked example at the
            end of this parameter group.
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        max_completion_length (`int` or `None`, *optional*, defaults to `256`):
            Maximum length of the generated completion.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for
            generation, improving generation speed. However, disabling this option allows training models that
            exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this
            option is not compatible with vLLM generation.

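        A worked example of the `num_generations` divisibility constraint above (the numbers are arbitrary and
        purely illustrative):

        ```python
        # Illustrative check of the constraint; these values are not recommendations.
        num_processes = 4
        per_device_train_batch_size = 4
        num_generations = 8

        global_batch_size = num_processes * per_device_train_batch_size  # 16
        assert global_batch_size % num_generations == 0
        prompts_per_step = global_batch_size // num_generations  # 2 distinct prompts, 8 completions each
        ```
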
        > Parameters that control generation acceleration powered by vLLM

        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept free from
            training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`). An
            example configuration is sketched at the end of this parameter group.
        vllm_device (`str`, *optional*, defaults to `"auto"`):
            Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
            automatically select the next available GPU after the last one used for training. This assumes that
            training has not already occupied all available GPUs. If only one device is available, the device will
            be shared between both training and vLLM.
        vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on
            the device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and
            thus improve the model's throughput. However, if the value is too high, it may cause out-of-memory
            (OOM) errors during initialization.
        vllm_dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically
            determined based on the model configuration. Find the supported values in the vLLM documentation.
        vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
            If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
            context size, which might be much larger than the KV cache, leading to inefficiencies.
        vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
            Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the
            hardware support this feature.
        vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.

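        For illustration, a configuration that enables vLLM generation might look as follows (a minimal sketch;
        the output directory and numeric values are placeholders, not tuned recommendations):

        ```python
        from trl import GRPOConfig

        training_args = GRPOConfig(
            output_dir="grpo-output",         # placeholder
            use_vllm=True,                    # reserve one GPU for generation
            vllm_device="auto",               # pick the first GPU not used for training
            vllm_gpu_memory_utilization=0.8,  # leave headroom to avoid OOM at initialization
            vllm_max_model_len=2048,          # cap the KV cache instead of using the full context size
        )
        ```
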
        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving
            training speed.
        num_iterations (`int`, *optional*, defaults to `1`):
            Number of iterations per batch (denoted as μ in the algorithm).
        epsilon (`float`, *optional*, defaults to `0.2`):
            Epsilon value for clipping.
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards
            are weighted equally with weight `1.0`.
        sync_ref_model (`bool`, *optional*, defaults to `False`):
            Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps,
            using the `ref_model_mixup_alpha` parameter. This synchronization originates from the
            [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
        ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
            α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
            between the current policy and the previous reference policy during updates. The reference policy is
            updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev` (see the sketch at the end
            of this parameter group). To use this parameter, you must set `sync_ref_model=True`.
        ref_model_sync_steps (`int`, *optional*, defaults to `512`):
            τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
            frequently the current policy is synchronized with the reference policy. To use this parameter, you
            must set `sync_ref_model=True`.

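        A minimal sketch of the α-mixing rule above, applied parameter-wise with plain tensors (illustrative only,
        not the trainer's internal implementation):

        ```python
        import torch

        alpha = 0.6  # ref_model_mixup_alpha

        # One parameter tensor from the current policy (π_θ) and the previous reference policy (π_ref_prev).
        policy_param = torch.randn(4)
        ref_param_prev = torch.randn(4)

        # π_ref = α * π_θ + (1 - α) * π_ref_prev, applied every `ref_model_sync_steps` steps.
        ref_param = alpha * policy_param + (1 - alpha) * ref_param_prev
        ```
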
        > Parameters that control the logging

        log_completions (`bool`, *optional*, defaults to `False`):
            Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
            installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
    """

    # Parameters that control the model and reference model
    model_init_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
            "argument of the `GRPOTrainer` is provided as a string."
        },
    )

    # Parameters that control the data preprocessing
    # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
    # additional columns to compute the reward
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
            "that requires any column other than 'prompts' and 'completions', you should keep this set to `False`."
        },
    )
    max_prompt_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated "
            "from the left."
        },
    )
    num_generations: Optional[int] = field(
        default=8,
        metadata={
            "help": "Number of generations per prompt to sample. The global batch size "
            "(num_processes * per_device_batch_size) must be divisible by this value."
        },
    )
    temperature: Optional[float] = field(
        default=0.9,
        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
    )
    max_completion_length: Optional[int] = field(
        default=256,
        metadata={"help": "Maximum length of the generated completion."},
    )
    ds3_gather_for_generation: bool = field(
        default=True,
        metadata={
            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
            "generation, improving generation speed. However, disabling this option allows training models that "
            "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
            "is not compatible with vLLM generation."
        },
    )

    # Parameters that control generation acceleration powered by vLLM
    use_vllm: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
            "free from training, as vLLM will require one for generation. vLLM must be installed "
            "(`pip install vllm`)."
        },
    )
    vllm_device: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
            "will automatically select the next available GPU after the last one used for training. This assumes "
            "that training has not already occupied all available GPUs."
        },
    )
    vllm_gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )
    vllm_dtype: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
            "determined based on the model configuration. Find the supported values in the vLLM documentation."
        },
    )
    vllm_max_model_len: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
            "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
            "context size, which might be much larger than the KV cache, leading to inefficiencies."
        },
    )
    vllm_enable_prefix_caching: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
            "the hardware support this feature."
        },
    )
    vllm_guided_decoding_regex: Optional[str] = field(
        default=None,
        metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
    )

    # Parameters that control the training
    learning_rate: float = field(
        default=1e-6,
        metadata={
            "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
            "`transformers.TrainingArguments`."
        },
    )
    beta: float = field(
        default=0.04,
        metadata={
            "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
            "training speed."
        },
    )
    num_iterations: int = field(
        default=1,
        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
    )
    epsilon: float = field(
        default=0.2,
        metadata={"help": "Epsilon value for clipping."},
    )
    reward_weights: Optional[list[float]] = field(
        default=None,
        metadata={
            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
            "rewards are weighted equally with weight `1.0`."
        },
    )
    sync_ref_model: bool = field(
        default=False,
        metadata={
            "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
            "steps, using the `ref_model_mixup_alpha` parameter."
        },
    )
    ref_model_mixup_alpha: float = field(
        default=0.6,
        metadata={
            "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
            "previous reference policy during updates. The reference policy is updated according to the equation: "
            "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
        },
    )
    ref_model_sync_steps: int = field(
        default=512,
        metadata={
            "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
            "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
        },
    )

    # Parameters that control the logging
    log_completions: bool = field(
        default=False,
        metadata={
            "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
            "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
        },
    )