Extend to other RL(HF) algorithms
=================================

Last updated: 02/25/2025.

We have already implemented the complete training pipeline for the PPO
algorithm. To extend verl to other algorithms, this document analyzes the
high-level principles of using verl and walks through an implementation of
the DPO algorithm. Users can follow the same paradigm to extend verl to other RL algorithms.

.. note:: **Key idea**: A single process drives multi-process computation and data communication.

Overall Approach
----------------

Step 1: Consider what multi-machine multi-GPU computations are needed
for each model, such as ``generate_sequences``, ``compute_log_prob`` and
``update_policy`` in the actor_rollout model. Implement them as distributed
single-program-multiple-data (SPMD) computations and encapsulate them
into APIs.

Step 2: Based on different distributed scenarios, including FSDP and 3D
parallelism in Megatron-LM, implement single-process control of data
interaction among multi-process computations.

Step 3: Use the encapsulated APIs to implement the control flow.

Example: Online DPO
-------------------

We use verl to implement a simple online DPO algorithm. The algorithm
flow of Online DPO is as follows:

1. A sample (rollout) generator shares the same weights as the actor
   model. After a batch of prompts is fed into the generator, it
   generates N responses for each prompt.
2. All prompts and responses are sent to a verifier for scoring. The
   verifier can be a reward model or a rule-based function. The scored
   responses are then grouped into pairs to form a training batch.
3. The training batch is used to train the actor model with DPO. This
   step requires a reference policy.
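
The pairing rule in step 2 is left open above. As a minimal sketch, assume
that for each prompt the highest-scoring response is treated as "chosen" and
the lowest-scoring one as "rejected", and that ties produce no pair. The
function name ``build_preference_pairs`` and the flat, prompt-by-prompt list
layout are hypothetical and not part of verl:

.. code:: python

   from typing import List, Tuple

   def build_preference_pairs(
       responses: List[str], scores: List[float], n: int
   ) -> List[Tuple[str, str]]:
       """Return (chosen, rejected) pairs, one per prompt at most.

       Assumes ``responses`` is laid out prompt by prompt with ``n``
       responses per prompt, and ``scores`` is aligned with it.
       """
       if len(responses) != len(scores) or len(responses) % n != 0:
           raise ValueError("responses and scores must align in groups of n")
       pairs = []
       for start in range(0, len(responses), n):
           group = list(zip(scores[start:start + n], responses[start:start + n]))
           best = max(group, key=lambda item: item[0])
           worst = min(group, key=lambda item: item[0])
           # Skip groups where all scores tie: there is no preference signal.
           if best[0] > worst[0]:
               pairs.append((best[1], worst[1]))
       return pairs

   pairs = build_preference_pairs(["a", "b", "c", "d"], [0.1, 0.9, 0.5, 0.5], n=2)

Other pairing rules (for example, all ordered pairs within a group) fit the
same interface; the choice affects only how the training batch is built.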

Step 1: What are the multi-machine multi-GPU computations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**Sample Generator**

Implementation details:

.. code:: python

   from verl.single_controller.base import Worker
   from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
   import ray

   @ray.remote
   class SampleGenerator(Worker):
       def __init__(self, config):
           super().__init__()
           self.config = config
           
       def generate_sequences(self, data):
           pass

Here, ``SampleGenerator`` can be viewed as a group of processes launched by
``torchrun``, with each process running the same code (SPMD).
``SampleGenerator`` needs to implement a ``generate_sequences`` API for
the control flow to call. The implementation can use any inference
engine, including vLLM, SGLang and Hugging Face. Users can largely reuse
the code in
verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py, so we do not
go into details here.

**ReferencePolicy inference**

API: compute reference log probability

.. code:: python

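   from verl.single_controller.base import Worker
   import ray

   @ray.remote
   class ReferencePolicy(Worker):
       def __init__(self):
           super().__init__()
           self.model = Model()  # placeholder for the frozen reference model
           
       def infer(self, data):
           # forward pass only; the reference policy is never updated
           return self.model(data)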
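
For intuition about what ``infer`` is expected to produce, the log
probability of a response is the sum of the per-token log-softmax values at
the generated token ids. The following is a plain-Python sketch under that
assumption; ``sequence_log_prob`` is a hypothetical helper, and a real worker
would compute this on batched tensors:

.. code:: python

   import math
   from typing import List, Sequence

   def sequence_log_prob(logits: Sequence[Sequence[float]], token_ids: List[int]) -> float:
       """Sum of log softmax(logits[t])[token_ids[t]] over all positions t."""
       total = 0.0
       for step_logits, token in zip(logits, token_ids):
           # log-sum-exp with the maximum subtracted for numerical stability
           peak = max(step_logits)
           log_z = peak + math.log(sum(math.exp(v - peak) for v in step_logits))
           total += step_logits[token] - log_z
       return total

   # Two positions with uniform logits over a vocabulary of 4 tokens.
   ref_log_prob = sequence_log_prob([[0.0] * 4, [1.0] * 4], [2, 0])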

**Actor update**

API: Update actor model parameters

.. code:: python

   from verl.single_controller.base import Worker
   import ray
   import torch.optim as optim
   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

   @ray.remote
   class DPOActor(Worker):
       def __init__(self):
           super().__init__()
           self.model = Model()  # placeholder for the actor model
           self.model = FSDP(self.model)  # or other distributed strategy
           self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
           self.loss_fn = xxx  # placeholder for the DPO loss
           
       def update(self, data):
           self.optimizer.zero_grad()
           logits = self.model(data)
           loss = self.loss_fn(logits)
           loss.backward()
           self.optimizer.step()
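
The ``loss_fn`` placeholder above stands for the DPO objective. For a single
(chosen, rejected) pair, the standard DPO loss is the negative log-sigmoid of
``beta`` times the difference between the policy's and the reference's
log-probability margins. The sketch below works on scalar sequence log
probabilities for readability; the function name, argument names and the
default ``beta=0.1`` are illustrative assumptions, and a real actor would
compute the same expression on tensors:

.. code:: python

   import math

   def dpo_loss(
       policy_chosen_logp: float,
       policy_rejected_logp: float,
       ref_chosen_logp: float,
       ref_rejected_logp: float,
       beta: float = 0.1,
   ) -> float:
       """DPO loss for one preference pair: -log(sigmoid(beta * margin))."""
       margin = (policy_chosen_logp - ref_chosen_logp) - (
           policy_rejected_logp - ref_rejected_logp
       )
       x = beta * margin
       # -log(sigmoid(x)) == log(1 + exp(-x)), written in a stable form
       return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))

   loss_when_equal = dpo_loss(0.0, 0.0, 0.0, 0.0)
   loss_prefers_chosen = dpo_loss(-1.0, -5.0, -2.0, -2.0)
   loss_prefers_rejected = dpo_loss(-5.0, -1.0, -2.0, -2.0)

The loss equals ``log(2)`` when the policy and the reference agree, and it
decreases as the policy raises the chosen response relative to the rejected
one.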

**Notes: How to distinguish between control processes and distributed computation processes**
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- Control processes are generally functions directly decorated with
  ``@ray.remote``.
- Computation processes are all wrapped into a ``RayWorkerGroup``.

Users can reuse most of the distributed computation logic implemented
for the PPO algorithm, including the FSDP and Megatron-LM backends, in
verl/verl/trainer/ppo.

Step 2: Based on different distributed scenarios, implement single-process control of multi-process data interaction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**The core problem to solve here is how a single process sends data to
multiple processes, drives multi-process computation, and how the
control process obtains the results of multi-process computation.**
First, we initialize the multi-process ``WorkerGroup`` in the control
process.

.. code:: python

   @ray.remote(num_cpus=1)
   def main_task(config):
       # construct SampleGenerator
       resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs
       ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
       # put SampleGenerator onto resource pool
       worker_group = RayWorkerGroup(resource_pool, ray_cls)
       
       # construct reference policy

As we can see, in the control process, multiple processes are wrapped
into a ``RayWorkerGroup``. Inside this ``WorkerGroup``, there is a
``self._workers`` member, where each worker is a RayActor
(XXXX) of SampleGenerator.
ray_trainer.md also provides an implementation of
``MegatronRayWorkerGroup``.

Assuming the model is distributed using FSDP and a batch of data resides
on the control process, the underlying call sequence for data parallelism
is:

.. code:: python

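   data = xxx  # placeholder: a batch of data held by the control process
   data_list = data.chunk(dp_size)  # 1. split into one shard per DP rank

   output = []
   for i, d in enumerate(data_list):
       # worker_group._workers[i] is a SampleGenerator
       # 2. remote call; returns an object reference immediately
       output.append(worker_group._workers[i].generate_sequences.remote(d))

   output = ray.get(output)  # 3. block until every worker has finished
   output = torch.cat(output)  # merge the per-worker results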

A single process calling multiple processes involves the following 3
steps:

1. Split the data into DP parts on the control process.
2. Send the data to remote, call the remote computation through RPC, and
   utilize multi-process computation.
3. Obtain the computation results of each worker on the control process
   and merge them.
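
The three steps above can be mimicked without a cluster. The sketch below
uses a thread pool from the Python standard library as a stand-in for Ray
actors, so it only illustrates the control pattern, not real remote
execution. ``ToyWorker``, ``chunk`` and ``call_all`` are hypothetical names:

.. code:: python

   from concurrent.futures import ThreadPoolExecutor
   from typing import List

   class ToyWorker:
       """Stand-in for one remote SampleGenerator process."""

       def __init__(self, rank: int):
           self.rank = rank

       def generate_sequences(self, prompts: List[str]) -> List[str]:
           return [f"{p} -> response from rank {self.rank}" for p in prompts]

   def chunk(items: List[str], num_chunks: int) -> List[List[str]]:
       """Split a list into num_chunks contiguous, near-equal parts."""
       size, remainder = divmod(len(items), num_chunks)
       parts, start = [], 0
       for i in range(num_chunks):
           end = start + size + (1 if i < remainder else 0)
           parts.append(items[start:end])
           start = end
       return parts

   def call_all(workers: List[ToyWorker], prompts: List[str]) -> List[str]:
       shards = chunk(prompts, len(workers))  # 1. split on the controller
       with ThreadPoolExecutor(max_workers=len(workers)) as pool:
           # 2. send one shard to each worker and start the computation
           futures = [
               pool.submit(w.generate_sequences, s) for w, s in zip(workers, shards)
           ]
           # 3. wait for every worker, keeping rank order
           results = [f.result() for f in futures]
       return [item for shard in results for item in shard]  # merge

   workers = [ToyWorker(rank) for rank in range(2)]
   outputs = call_all(workers, ["p0", "p1", "p2", "p3"])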

Repeating these 3 steps at every call site on the controller process
greatly hurts code readability. **In verl, we have abstracted and
encapsulated these 3 steps, so that the worker's method + dispatch +
collect can be registered into the worker_group.**

.. code:: python

   import torch
   from verl.single_controller.base.decorator import register

   def dispatch_data(worker_group, data):
       # split the batch into one shard per worker
       return data.chunk(worker_group.world_size)
       
   def collect_data(worker_group, data):
       # merge the per-worker outputs back into one batch
       return torch.cat(data)

   dispatch_mode = {
       'dispatch_fn': dispatch_data,
       'collect_fn': collect_data
   }

   @register(dispatch_mode=dispatch_mode)
   def generate_sequences(self, data):
       pass

In this way, we can directly call the method inside the worker through
the ``worker_group`` on the control (driver) process (which is a single
process):

.. code:: python

   output = worker_group.generate_sequences(data)

This single line includes data splitting, data distribution and
computation, and data collection.
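
To illustrate how a registration of this kind can turn one call into split,
dispatch and collect, here is a toy re-implementation of the idea in plain
Python. It is not verl's actual implementation: ``toy_register``,
``ToyWorker`` and ``ToyWorkerGroup`` are hypothetical names, and the workers
run sequentially in the same process:

.. code:: python

   def toy_register(dispatch_fn, collect_fn):
       """Attach dispatch and collect functions to a worker method."""
       def decorator(method):
           method._dispatch_fn = dispatch_fn
           method._collect_fn = collect_fn
           return method
       return decorator

   def dispatch_data(group, data):
       step = -(-len(data) // group.world_size)  # ceiling division
       return [data[i * step:(i + 1) * step] for i in range(group.world_size)]

   def collect_data(group, outputs):
       return [item for part in outputs for item in part]

   class ToyWorker:
       def __init__(self, rank):
           self.rank = rank

       @toy_register(dispatch_data, collect_data)
       def generate_sequences(self, data):
           return [f"{item}@{self.rank}" for item in data]

   class ToyWorkerGroup:
       def __init__(self, worker_cls, world_size):
           self.world_size = world_size
           self._workers = [worker_cls(rank) for rank in range(world_size)]
           # Expose every registered worker method as a method of the group.
           for name, method in vars(worker_cls).items():
               if callable(method) and hasattr(method, "_dispatch_fn"):
                   setattr(self, name, self._make_group_method(name, method))

       def _make_group_method(self, name, method):
           def group_method(data):
               shards = method._dispatch_fn(self, data)
               outputs = [getattr(w, name)(s) for w, s in zip(self._workers, shards)]
               return method._collect_fn(self, outputs)
           return group_method

   group = ToyWorkerGroup(ToyWorker, world_size=2)
   output = group.generate_sequences(["a", "b", "c"])

The design choice this highlights is that the dispatch and collect logic
travels with the method definition, so the controller code stays a single
call.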

Furthermore, the parallelism sizes of each model (dp, tp and pp) are
usually fixed. For these common distributed scenarios, we have
pre-implemented specific dispatch and collect methods in `decorator.py XXXX`_, which can be used directly to wrap the computations.

.. code:: python

   from verl import DataProto
   from verl.single_controller.base.decorator import register, Dispatch

   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
   def generate_sequences(self, data: DataProto) -> DataProto:
       pass

This dispatch mode requires the data interface to be ``DataProto``, which
is defined in `protocol.py XXXX`_.

Step 3: Main training loop
~~~~~~~~~~~~~~~~~~~~~~~~~~

With the APIs above, we can implement the algorithm's control flow. It is
recommended that ``main_task`` also run as a Ray remote process.

.. code:: python

   @ray.remote(num_cpus=1)
   def main_task(config):
       # construct SampleGenerator
       resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs
       ray_cls = RayClassWithInitArgs(SampleGenerator, config=config) 
       # put SampleGenerator onto resource pool
       sample_gen = RayWorkerGroup(resource_pool, ray_cls)
       
       # construct reference policy
       ray_cls = RayClassWithInitArgs(ReferencePolicy)
       ref_policy = RayWorkerGroup(resource_pool, ray_cls)
       
       # construct actor
       ray_cls = RayClassWithInitArgs(DPOActor)  
       dpo_policy = RayWorkerGroup(resource_pool, ray_cls)
       
       dataloader = DataLoader()
       
       for data in dataloader:
           # generate data
           data = sample_gen.generate_sequences(data)
           # generate scores for each data 
           data = generate_scores(data)
           # generate pairwise data using scores
           data = generate_pairwise_data(data)
           # generate ref_log_prob
           data.batch['ref_log_prob'] = ref_policy.infer(data)
           # update using dpo
           dpo_policy.update(data)
           # logging

Here, different ``WorkerGroups`` can be placed in the same resource pool or
in different resource pools using ``create_colocated_worker_cls``,
similar to `ray_trainer.py XXXX`_.
