GRPO With Replay Buffer

This experimental trainer trains a model with GRPO, but replaces groups (and their corresponding completions) whose rewards have zero standard deviation with high-reward, high-variance groups that were used to train the model in prior batches.
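
Concretely, when every completion in a group receives the same reward, the group's advantages are all zero and it contributes no learning signal; the trainer then substitutes a previously stored group that did carry signal. The sketch below illustrates that substitution rule only; the helper names (add, sample) are assumptions, not the actual trl implementation.

import torch

# Illustrative sketch of the replacement rule, assuming a buffer with
# hypothetical add()/sample() methods; not the actual trainer code.
def maybe_replace_group(rewards, completions, replay_buffer):
    rewards_t = torch.tensor(rewards, dtype=torch.float32)
    if rewards_t.std() == 0:
        # Zero-variance group: no advantage signal, so reuse a stored
        # high-reward, high-variance group from an earlier batch if available.
        stored = replay_buffer.sample(1)
        if stored:
            return stored[0]
    else:
        # Informative group: keep it around for future replacements,
        # scored here (as an assumption) by mean reward plus spread.
        score = rewards_t.mean().item() + rewards_t.std().item()
        replay_buffer.add(score, (rewards, completions))
    return rewards, completions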

Usage

import torch
from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer
from datasets import load_dataset

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

# Make some reward groups have zero standard deviation (~25% of the time)
def custom_reward_func(completions, **kwargs):
    if torch.rand(1).item() < 0.25:
        return [0] * len(completions)  # constant rewards -> zero std for the group
    else:
        return torch.rand(len(completions)).tolist()

training_args = GRPOWithReplayBufferConfig(
    output_dir="./tmp",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    replay_buffer_size=8,
    report_to="none",
)

trainer = GRPOWithReplayBufferTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,
)

# Optional: snapshot the trainable parameters to verify they change after training
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()
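
If you kept the parameter snapshot from before training, you can check that the run actually updated the weights:

# Confirm that training changed at least some parameters
for name, param in trainer.model.named_parameters():
    if not torch.equal(param, previous_trainable_params[name]):
        print(f"{name} was updated")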

GRPOWithReplayBufferTrainer

class trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer

( args: trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_config.GRPOWithReplayBufferConfig | None = None **kwargs )

train

( resume_from_checkpoint: str | bool | None = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: list[str] | None = None )

Parameters

  • resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
  • trial (optuna.Trial or dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
  • ignore_keys_for_eval (list[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

Main training entry point.
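
For example, to continue an interrupted run from the last checkpoint written to args.output_dir (standard Trainer behavior):

trainer.train(resume_from_checkpoint=True)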

save_model

( output_dir: str | None = None _internal_call: bool = False )

Will save the model, so you can reload it using from_pretrained().

Will only save from the main process.
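
For example, saving to a subdirectory of the output_dir used above and reloading with transformers (the path is just an example):

trainer.save_model("./tmp/final")

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("./tmp/final")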

push_to_hub

( commit_message: str | None = 'End of training' blocking: bool = True token: str | None = None revision: str | None = None **kwargs )

Parameters

  • commit_message (str, optional, defaults to "End of training") — Message to commit while pushing.
  • blocking (bool, optional, defaults to True) — Whether the function should return only when the git push has finished.
  • token (str, optional, defaults to None) — Token with write permission to overwrite Trainer’s original args.
  • revision (str, optional) — The git revision to commit from. Defaults to the head of the “main” branch.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments passed along to ~Trainer.create_model_card.

Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.
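
For example, assuming you are logged in to the Hub and have set hub_model_id (or accept the default repo name derived from output_dir):

trainer.push_to_hub(commit_message="GRPO with replay buffer run")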

GRPOWithReplayBufferConfig

class trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig

( output_dir: str | None = None do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 gradient_accumulation_steps: int = 1 eval_accumulation_steps: int | None = None eval_delay: float = 0 torch_empty_cache_steps: int | None = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear' lr_scheduler_kwargs: dict | str | None = None warmup_ratio: float | None = None warmup_steps: float = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: str | None = None logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps' save_steps: float = 500 save_total_limit: int | None = None enable_jit_checkpoint: bool = False save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False use_cpu: bool = False seed: int = 42 data_seed: int | None = None bf16: bool | None = None fp16: bool = False bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: bool | None = None local_rank: int = -1 ddp_backend: str | None = None debug: str | list[transformers.debug_utils.DebugOption] = '' dataloader_drop_last: bool = False eval_steps: float | None = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: int | None = None run_name: str | None = None disable_tqdm: bool | None = None remove_unused_columns: bool | None = False label_names: list[str] | None = None load_best_model_at_end: bool = False metric_for_best_model: str | None = None greater_is_better: bool | None = None ignore_data_skip: bool = False fsdp: list[transformers.trainer_utils.FSDPOption] | str | None = None fsdp_config: dict[str, typing.Any] | str | None = None accelerator_config: dict | str | None = None parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None deepspeed: dict | str | None = None label_smoothing_factor: float = 0.0 optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused' optim_args: str | None = None group_by_length: bool = False length_column_name: str = 'length' report_to: None | str | list[str] = 'none' project: str = 'huggingface' trackio_space_id: str | None = 'trackio' ddp_find_unused_parameters: bool | None = None ddp_bucket_cap_mb: int | None = None ddp_broadcast_buffers: bool | None = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True push_to_hub: bool = False resume_from_checkpoint: str | None = None hub_model_id: str | None = None hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save' hub_token: str | None = None hub_private_repo: bool | None = None hub_always_push: bool = False hub_revision: str | None = None gradient_checkpointing: bool = True gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None include_for_metrics: list = <factory> eval_do_concat_batches: bool = True auto_find_batch_size: bool = False full_determinism: bool = False ddp_timeout: int = 
1800 torch_compile: bool = False torch_compile_backend: str | None = None torch_compile_mode: str | None = None include_num_input_tokens_seen: str | bool = 'no' neftune_noise_alpha: float | None = None optim_target_modules: None | str | list[str] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: bool = False liger_kernel_config: dict[str, bool] | None = None eval_use_gather_object: bool = False average_tokens_across_devices: bool = True use_cache: bool = False model_init_kwargs: dict | str | None = None disable_dropout: bool = False cast_lm_head_to_fp32: bool = False num_generations: int | None = 8 num_generations_eval: int | None = None max_completion_length: int | None = 256 ds3_gather_for_generation: bool = True shuffle_dataset: bool | None = True generation_batch_size: int | None = None steps_per_generation: int | None = None temperature: float = 1.0 top_p: float = 1.0 top_k: int = 0 min_p: float | None = None generation_kwargs: dict | None = None chat_template_kwargs: dict | None = None repetition_penalty: float = 1.0 use_transformers_paged: bool = False cache_implementation: str | None = None use_vllm: bool = False vllm_mode: str = 'server' vllm_model_impl: str = 'vllm' vllm_enable_sleep_mode: bool = False vllm_structured_outputs_regex: str | None = None vllm_server_base_url: str | None = None vllm_server_host: str = '0.0.0.0' vllm_server_port: int = 8000 vllm_server_timeout: float = 240.0 vllm_group_port: int = 51216 vllm_gpu_memory_utilization: float = 0.3 vllm_max_model_length: int | None = None vllm_tensor_parallel_size: int = 1 beta: float = 0.0 num_iterations: int = 1 epsilon: float = 0.2 delta: float | None = None epsilon_high: float | None = None sapo_temperature_neg: float = 1.05 sapo_temperature_pos: float = 1.0 importance_sampling_level: str = 'token' reward_weights: list[float] | None = None multi_objective_aggregation: str = 'sum_then_normalize' scale_rewards: str = 'group' loss_type: str = 'dapo' mask_truncated_completions: bool = False sync_ref_model: bool = False ref_model_mixup_alpha: float = 0.6 ref_model_sync_steps: int = 512 top_entropy_quantile: float = 1.0 max_tool_calling_iterations: int | None = None vllm_importance_sampling_correction: bool = True vllm_importance_sampling_mode: str = 'sequence_mask' vllm_importance_sampling_cap: float = 3.0 off_policy_mask_threshold: float | None = None use_bias_correction_kl: bool = False log_completions: bool = False num_completions_to_print: int | None = None log_unique_prompts: bool = False replay_buffer_size: int = 64 )

New Parameters:

  • replay_buffer_size (int, optional, defaults to 64) — Size of the replay buffer that stores the rollouts with the highest advantage scores and variance per group. If a new group has zero reward variance, it is replaced with a group sampled from the replay buffer.

ReplayBuffer

class trl.experimental.grpo_with_replay_buffer.ReplayBuffer

( max_size: int )

A simple replay buffer to store and sample previously seen rollouts.
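
Only the max_size argument is documented above. As a mental model, such a buffer behaves like a bounded priority queue: it keeps the highest-scoring rollouts and evicts the lowest-scoring one when full. Below is a minimal, self-contained sketch of that idea, not the actual trl.experimental implementation (the add/sample method names are assumptions):

import heapq
import random

class BoundedReplayBuffer:
    """Keep the max_size highest-scoring rollouts (illustrative sketch only)."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self._heap = []      # min-heap of (score, insertion_id, rollout)
        self._counter = 0    # tie-breaker so rollouts are never compared directly

    def add(self, score: float, rollout) -> None:
        item = (score, self._counter, rollout)
        self._counter += 1
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, item)
        else:
            # Push the new item and drop the current lowest-scoring entry
            heapq.heappushpop(self._heap, item)

    def sample(self, k: int = 1):
        k = min(k, len(self._heap))
        return [rollout for _, _, rollout in random.sample(self._heap, k)]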
