# Asynchronous GRPO

> [!IMPORTANT]
> This trainer requires `vllm>=0.17.1` and `transformers>=5.2.0`. For distributed training, only FSDP2 is supported (DeepSpeed ZeRO is not).
>
> Currently, `vllm` and `transformers` have conflicting dependency constraints. To work around this, install vLLM first and then force-install transformers:
>
> ```bash
> pip install 'vllm>=0.17.1'
> pip install 'transformers>=5.2.0' --no-deps
> ```

## Overview

`AsyncGRPOTrainer` implements the same [GRPO](grpo_trainer) algorithm but decouples rollout generation from training. A background worker continuously streams completions from a vLLM server while the training loop consumes them, so generation and gradient updates overlap instead of alternating. The API mirrors [GRPOTrainer](/docs/trl/main/en/gspo_token#trl.GRPOTrainer) — for full details on the GRPO method itself (advantage computation, KL estimation, loss formulation, reward functions, etc.), see the [GRPO Trainer](grpo_trainer) documentation. Not all features from [GRPOTrainer](/docs/trl/main/en/gspo_token#trl.GRPOTrainer) are available; refer to `AsyncGRPOConfig` for the supported parameters.

This trainer was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Amine Dirhoussi](https://huggingface.co/aminediroHF).

## How it differs from [GRPOTrainer](/docs/trl/main/en/gspo_token#trl.GRPOTrainer)

In the standard [GRPOTrainer](/docs/trl/main/en/gspo_token#trl.GRPOTrainer), generation and training are sequential: generate a batch, compute the loss, update weights, repeat. Even in [vLLM colocate mode](grpo_trainer#speed-up-training-with-vllm), where generation runs on the same GPUs, one phase must finish before the other begins.

`AsyncGRPOTrainer` separates these two concerns:

- **Rollout worker** (background thread) — sends prompts to a vLLM server, scores completions with reward functions, computes advantages, and pushes ready-to-train samples into a queue.
- **Training loop** (main process) — pulls samples from the queue, computes the clipped surrogate loss, and updates the model weights.

After every `weight_sync_steps` training steps, the updated weights are transferred to the vLLM server via NCCL so that subsequent generations reflect the latest policy.

Because generation and training run concurrently, the training samples may have been generated by a slightly older version of the model. The `max_staleness` parameter controls how many weight updates a sample can lag behind before being discarded.
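
The sketch below illustrates this producer-consumer pattern. It is schematic rather than the trainer's actual internals: the queue, the version counter, and the loop body are stand-ins for illustration.

```python
import queue
import threading

rollout_queue: queue.Queue = queue.Queue(maxsize=64)
policy_version = 0
MAX_STALENESS = 2
WEIGHT_SYNC_STEPS = 4

def rollout_worker() -> None:
    # Stand-in for the real worker: send prompts to the vLLM server, score
    # completions with the reward functions, compute advantages, enqueue.
    while True:
        rollout_queue.put({"sample": "scored rollout", "version": policy_version})

threading.Thread(target=rollout_worker, daemon=True).start()

for step in range(1, 101):
    item = rollout_queue.get()  # blocks until a rollout is ready
    if policy_version - item["version"] > MAX_STALENESS:
        continue  # discard: generated too many weight syncs ago
    # ... compute the clipped surrogate loss and take an optimizer step ...
    if step % WEIGHT_SYNC_STEPS == 0:
        policy_version += 1  # stand-in for the NCCL weight push to vLLM
```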

The number of concurrent requests sent to the vLLM server is controlled by `max_inflight_tasks`. By default it is set automatically to `max_staleness × per_device_train_batch_size × gradient_accumulation_steps × num_processes` — the maximum number of samples the trainer can consume before they become stale. Generating more than this is wasteful since the excess samples will be discarded.
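
As a concrete illustration, with hypothetical values `max_staleness=2`, `per_device_train_batch_size=4`, `gradient_accumulation_steps=8`, and a single process, the automatic default is `2 × 4 × 8 × 1 = 64` in-flight requests. The sketch below pins these values in a config; it assumes `weight_sync_steps` and `max_staleness` are set through `AsyncGRPOConfig`, as described above:

```python
from trl.experimental.async_grpo import AsyncGRPOConfig

# Hypothetical values for illustration. With one training process, the
# automatic default is:
#   max_inflight_tasks = max_staleness * per_device_train_batch_size
#                        * gradient_accumulation_steps * num_processes
#                      = 2 * 4 * 8 * 1 = 64
args = AsyncGRPOConfig(
    output_dir="Qwen3-4B-AsyncGRPO",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    weight_sync_steps=4,
    max_staleness=2,
)
```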

## Quick start

```python
# train_async_grpo.py
from datasets import load_dataset
from trl.experimental.async_grpo import AsyncGRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = AsyncGRPOTrainer(
    model="Qwen/Qwen3-4B",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()
```

The vLLM server and the trainer must run on **separate GPUs**. Use `CUDA_VISIBLE_DEVICES` to partition your GPUs. For example, with 2 GPUs, you can run the vLLM server on GPU 0 and the trainer on GPU 1 as follows:

```bash
# Terminal 1: vLLM server on GPU 0 (dev mode + NCCL weight transfer are required)
CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \
    --max-model-len 4096 \
    --logprobs-mode processed_logprobs \
    --weight-transfer-config '{"backend":"nccl"}'
```

> [!TIP]
> Set `--max-model-len` to the maximum total sequence length (prompt + completion) you expect. A lower value reduces GPU memory usage on the server, freeing more memory for the KV cache and increasing throughput. A good starting point is the prompt length plus `max_completion_length` from your config.

```bash
# Terminal 2: training on GPU 1
CUDA_VISIBLE_DEVICES=1 accelerate launch train_async_grpo.py
```
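
The same partitioning scales to more GPUs. As a sketch (assuming a 4-GPU machine; the flags mirror the commands above, plus vLLM's standard `--tensor-parallel-size` option), the server can span two GPUs while training runs on the other two:

```bash
# Terminal 1: vLLM server on GPUs 0 and 1 (tensor parallel)
CUDA_VISIBLE_DEVICES=0,1 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \
    --tensor-parallel-size 2 \
    --max-model-len 4096 \
    --logprobs-mode processed_logprobs \
    --weight-transfer-config '{"backend":"nccl"}'

# Terminal 2: training on GPUs 2 and 3
CUDA_VISIBLE_DEVICES=2,3 accelerate launch --num_processes 2 train_async_grpo.py
```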

## Design philosophy

This trainer is intentionally kept minimal and is not meant to grow into a general-purpose solution. If you need a feature that is not supported, we recommend cloning the repository and adapting the trainer to your needs directly. New features will only be considered when there is significant community demand.

## AsyncGRPOConfig[[trl.experimental.async_grpo.AsyncGRPOConfig]]

#### trl.experimental.async_grpo.AsyncGRPOConfig[[trl.experimental.async_grpo.AsyncGRPOConfig]]

[Source](https://github.com/huggingface/trl/blob/main/trl/experimental/async_grpo/async_grpo_config.py#L21)

Configuration class for the `AsyncGRPOTrainer`.

This class includes only the parameters that are specific to asynchronous GRPO training. For a full list of
training arguments, please refer to the [TrainingArguments](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments) documentation. Note that default values
in this class may differ from those in [TrainingArguments](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments).

> [!NOTE]
> These parameters have default values different from [TrainingArguments](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments):
> - `logging_steps`: Defaults to `10` instead of `500`.
> - `gradient_checkpointing`: Defaults to `True` instead of `False`.
> - `bf16`: Defaults to `True` if `fp16` is not set, instead of `False`.
> - `learning_rate`: Defaults to `1e-6` instead of `5e-5`.

## AsyncGRPOTrainer[[trl.experimental.async_grpo.AsyncGRPOTrainer]]

#### trl.experimental.async_grpo.AsyncGRPOTrainer[[trl.experimental.async_grpo.AsyncGRPOTrainer]]

[Source](https://github.com/huggingface/trl/blob/main/trl/experimental/async_grpo/async_grpo_trainer.py#L169)

Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language
Models](https://huggingface.co/papers/2402.03300). This trainer is the asynchronous version of GRPO: generation is
offloaded to an external vLLM server that runs alongside training, decoupling rollout generation from the gradient
update loop.

Example:

```python
from trl.experimental.async_grpo import AsyncGRPOTrainer
from trl.rewards import accuracy_reward
from datasets import load_dataset

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = AsyncGRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()
```

**Parameters:**

model (`str`) : Model to be trained. Must be a string: either the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a path to a *directory* containing model weights saved using [save_pretrained](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.save_pretrained), e.g., `'./my_model_directory/'`. The model is loaded using [from_pretrained](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForCausalLM.from_pretrained). The model name is also used to identify the model on the vLLM server used for generation.

reward_funcs (`RewardFunc | list[RewardFunc]`) : Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward functions with the prompts and completions and sum the rewards. Can be either:

- A single reward function: The function is provided with the prompts and the generated completions, plus any additional columns in the dataset. It should return a list of rewards. Reward functions can be either synchronous or asynchronous and can also return `None` when the reward is not applicable to those samples. This is useful for multi-task training where different reward functions apply to different types of samples. When a reward function returns `None` for a sample, that reward function is excluded from the reward calculation for that sample. For more details, see [Using a custom reward function](#using-a-custom-reward-function).
- A list of reward functions, where each item is a reward function as described above. Rewards from all functions are summed.
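
A minimal custom reward function matching this contract might look like the following sketch (hypothetical function, assuming a standard plain-text dataset so that each completion is a string):

```python
def no_repetition_reward(prompts, completions, **kwargs):
    """Penalize immediate word repetition; returns one reward per sample."""
    rewards = []
    for completion in completions:
        words = completion.split()
        # Count adjacent duplicate words in the completion.
        repeats = sum(a == b for a, b in zip(words, words[1:]))
        rewards.append(-0.1 * repeats)
    return rewards
```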

args (`AsyncGRPOConfig`, *optional*) : Configuration for this trainer. If `None`, a default configuration is used.

train_dataset ([Dataset](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset) or [IterableDataset](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.IterableDataset)) : Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are ignored. The format of the samples can be either:

- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role and content).
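
For illustration, one sample in each format might look like:

```python
# Illustrative samples; only the "prompt" column is required.
standard_sample = {"prompt": "The sky is"}
conversational_sample = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
```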

processing_class ([PreTrainedTokenizerBase](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase), *optional*) : Processing class used to process the data. The padding side must be set to `"left"`. If `None`, the processing class is loaded from the model's name with [from_pretrained](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained). A padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default.
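
For example, a tokenizer satisfying these requirements can be loaded as follows (a sketch; `Qwen/Qwen3-4B` is simply the model used elsewhere on this page):

```python
from transformers import AutoTokenizer

# Left padding is required; the pad token falls back to eos_token if unset.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```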

callbacks (list of [TrainerCallback](https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback), *optional*) : List of callbacks to customize the training loop. These are added to the list of default callbacks detailed [here](https://huggingface.co/docs/transformers/main_classes/callback). If you want to remove one of the default callbacks, use the [remove_callback](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.remove_callback) method.

optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`) : A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your model and a scheduler given by [get_linear_schedule_with_warmup](https://huggingface.co/docs/transformers/main/en/main_classes/optimizer_schedules#transformers.get_linear_schedule_with_warmup) controlled by `args`.

tools (list of `Callable`, *optional*) : A list of callable tool functions (sync or async) that the model can invoke during generation. Each tool should be a standard Python function with properly type-hinted arguments and return values, and a Google-style docstring describing its purpose, arguments, and return value. For more details, see: https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function's name, type hints, and docstring to determine how to call it. Ensure that the model's chat template supports tool use and that it has been fine-tuned for tool calling.
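
For example, a tool meeting these requirements might look like this (hypothetical function, shown only to illustrate the expected shape):

```python
def get_word_length(word: str) -> int:
    """Count the number of characters in a word.

    Args:
        word: The word to measure.

    Returns:
        The number of characters in the word.
    """
    return len(word)
```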

environment_factory (`EnvironmentFactory`, *optional*) : A callable that creates and returns an environment instance. The environment class should define methods that can be invoked as tools during generation. Each method should comply with the same requirements as the `tools` described above. If `environment_factory` is provided, an instance of the environment is created for each generation in the batch, allowing for parallel and independent interactions. The environment must also implement a callable `reset` method that can be used to reset state between generations. The `reset` method should return either `None` or a string: when it returns a string, that string is appended to the last user message before generation. This feature is experimental and may change or be removed at any time without prior notice.
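
A toy environment compatible with this contract might look like the following sketch (hypothetical class and methods):

```python
class GuessingGameEnv:
    """Toy environment: the model probes a hidden number via a tool call."""

    def __init__(self):
        self.secret = 7

    def guess(self, number: int) -> str:
        """Compare a guess against the hidden number.

        Args:
            number: The guessed number.

        Returns:
            One of "higher", "lower", or "correct".
        """
        if number < self.secret:
            return "higher"
        if number > self.secret:
            return "lower"
        return "correct"

    def reset(self) -> str:
        """Reset state between generations; the returned string is appended
        to the last user message before generation."""
        self.secret = 7
        return "Guess the hidden number between 1 and 10."
```

Passing `environment_factory=GuessingGameEnv` would then create a fresh instance for each generation in the batch.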

