Fortrain/qw/open_r1/trainer/grpo_config.py

286 lines
15 KiB
Python
Raw Permalink Normal View History

2025-03-31 15:56:36 +08:00
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
@dataclass
class GRPOConfig(TrainingArguments):
    r"""
    Configuration class for the [`GRPOTrainer`].

    Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
    [`~transformers.TrainingArguments`] documentation.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model and reference model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`GRPOTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
            requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
        num_generations (`int` or `None`, *optional*, defaults to `8`):
            Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
            must be divisible by this value.
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        max_completion_length (`int` or `None`, *optional*, defaults to `256`):
            Maximum length of the generated completion.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
            with vLLM generation.

        > Parameters that control generation acceleration powered by vLLM

        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
            training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
        vllm_device (`str`, *optional*, defaults to `"auto"`):
            Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
            automatically select the next available GPU after the last one used for training. This assumes that
            training has not already occupied all available GPUs. If only one device is available, the device will be
            shared between both training and vLLM.
        vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
            during initialization.
        vllm_dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
            based on the model configuration. Find the supported values in the vLLM documentation.
        vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
            If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
            context size, which might be much larger than the KV cache, leading to inefficiencies.
        vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
            Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware
            support this feature.
        vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
            speed.
        num_iterations (`int`, *optional*, defaults to `1`):
            Number of iterations per batch (denoted as μ in the algorithm).
        epsilon (`float`, *optional*, defaults to `0.2`):
            Epsilon value for clipping.
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
            weighted equally with weight `1.0`.
        sync_ref_model (`bool`, *optional*, defaults to `False`):
            Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
            the `ref_model_mixup_alpha` parameter. This synchronization originates from the
            [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
        ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
            α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
            between the current policy and the previous reference policy during updates. The reference policy is
            updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
            must set `sync_ref_model=True`.
        ref_model_sync_steps (`int`, *optional*, defaults to `512`):
            τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
            frequently the current policy is synchronized with the reference policy. To use this parameter, you must
            set `sync_ref_model=True`.

        > Parameters that control the logging

        log_completions (`bool`, *optional*, defaults to `False`):
            Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
            installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
    """

    # Parameters that control the model and reference model
    model_init_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
            "argument of the `GRPOTrainer` is provided as a string."
        },
    )

    # Parameters that control the data preprocessing
    # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
    # additional columns to compute the reward
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
            "that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
        },
    )
    max_prompt_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
        },
    )
    num_generations: Optional[int] = field(
        default=8,
        metadata={
            # "per prompt" matches the class docstring: this is a per-prompt count, not a total.
            "help": "Number of generations per prompt to sample. The global batch size "
            "(num_processes * per_device_batch_size) must be divisible by this value."
        },
    )
    temperature: Optional[float] = field(
        default=0.9,
        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
    )
    max_completion_length: Optional[int] = field(
        default=256,
        metadata={"help": "Maximum length of the generated completion."},
    )
    ds3_gather_for_generation: bool = field(
        default=True,
        metadata={
            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
            "generation, improving generation speed. However, disabling this option allows training models that "
            "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
            "is not compatible with vLLM generation."
        },
    )

    # Parameters that control generation acceleration powered by vLLM
    use_vllm: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
            "unused for training, as vLLM will require one for generation. vLLM must be installed "
            "(`pip install vllm`)."
        },
    )
    vllm_device: Optional[str] = field(
        default="auto",
        metadata={
            # The single-device sentence is documented in the class docstring and was missing from the CLI help.
            "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
            "will automatically select the next available GPU after the last one used for training. This assumes "
            "that training has not already occupied all available GPUs. If only one device is available, the device "
            "will be shared between both training and vLLM."
        },
    )
    vllm_gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )
    vllm_dtype: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
            "determined based on the model configuration. Find the supported values in the vLLM documentation."
        },
    )
    vllm_max_model_len: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
            "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
            "context size, which might be much larger than the KV cache, leading to inefficiencies."
        },
    )
    vllm_enable_prefix_caching: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
            "the hardware support this feature."
        },
    )
    vllm_guided_decoding_regex: Optional[str] = field(
        default=None,
        metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
    )

    # Parameters that control the training
    # learning_rate is redeclared here only to replace the parent TrainingArguments default with 1e-6.
    learning_rate: float = field(
        default=1e-6,
        metadata={
            "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
            "`transformers.TrainingArguments`."
        },
    )
    beta: float = field(
        default=0.04,
        metadata={
            "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
            "training speed."
        },
    )
    num_iterations: int = field(
        default=1,
        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
    )
    epsilon: float = field(
        default=0.2,
        metadata={"help": "Epsilon value for clipping."},
    )
    reward_weights: Optional[list[float]] = field(
        default=None,
        metadata={
            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
            "rewards are weighted equally with weight `1.0`."
        },
    )
    sync_ref_model: bool = field(
        default=False,
        metadata={
            "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
            "steps, using the `ref_model_mixup_alpha` parameter. This synchronization originates from the TR-DPO "
            "paper."
        },
    )
    ref_model_mixup_alpha: float = field(
        default=0.6,
        metadata={
            "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
            "previous reference policy during updates. The reference policy is updated according to the equation: "
            "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
        },
    )
    ref_model_sync_steps: int = field(
        default=512,
        metadata={
            "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
            "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
        },
    )

    # Parameters that control the logging
    log_completions: bool = field(
        default=False,
        metadata={
            "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
            "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
        },
    )