                                                      
                                                                 
from calendar import firstweekday
from collections import defaultdict
from contextlib import nullcontext
from datetime import datetime
import asyncio
import logging
import os
import sys
import time
import random
import copy
import psutil
import traceback
from queue import SimpleQueue
from typing_extensions import override
from typing import Callable, List, Any, Dict, Tuple, Set
from packaging.version import Version

                                     
from megatron.training.training import _TRAIN_START_TIME

from tqdm import tqdm
import numpy as np
import torch
import torch.distributed
from transformers import AutoConfig

from megatron.core import package_info
from megatron.core import mpu
from megatron.core import parallel_state
from megatron.core.distributed import finalize_model_grads
from megatron.core.utils import get_model_config
from megatron.core.utils import divide
from megatron.training import get_tokenizer
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron
from megatron.training.initialize import set_jit_fusion_options
from megatron.training.initialize import write_args_to_tensorboard
from megatron.training.training import (
    print_datetime,
    save_checkpoint_and_time,
    get_model,
    setup_model_and_optimizer,
)
from megatron.training.utils import (
    calc_params_l2_norm,
    print_rank_0,
    unwrap_model,
)
from megatron.training.global_vars import (
    get_args,
    get_timers,
    get_tensorboard_writer,
    get_wandb_writer,
)
try:
    from megatron.training.global_vars import get_energy_monitor
except ImportError:
    get_energy_monitor = None

from megatron.core.num_microbatches_calculator import (
    get_num_microbatches, )
from megatron.training.training import build_train_valid_test_data_iterators

from gpatch.training.utils import print_with_rank_and_datetime, print_meminfo_str
from gpatch.training.arguments import validate_rl_args
from gpatch.core.ppo_helper import (
    calculate_advantages_and_returns,
    calculate_ppo_rewards,
    calculate_kl_penalty,
    create_mask,
    calculate_grpo_advantages,
)
from gpatch.core.utils import get_nvml_memory_info, print_memory_tracking
from gpatch.core.aligner_helper import (
    cpu_weight_swap_v2,
    masked_mean_list,
    masked_global_statistics_list,
    normalize_tensor,
    clear_memory,
    get_iterator_k_split_list,
    cpu_dict,
    retrieve_model_state_dict_in_cpu,
)
from gpatch.training.global_vars import set_global_variables
from gpatch.core.wecube import (init_wecube_reporter, report_ppo_metrics)
from gpatch.core.parallel_state import (
    init_pg,
    is_mp_and_cp_head,
    cpu_barrier,
)
from gpatch.core.swap import (
    offload_megatron_model,
    onload_megatron_model,
    offload_megatron_optimizer,
    onload_megatron_optimizer,
)
from gpatch.core.aligner_helper import (
    broadcast_object_within_mp_and_cp,
    expand_rollout_batches,
)
from gpatch.core.models.gpt.weight_conversion.cpu_memory_model import (
    CpuMemoryModel,
)
from gpatch.rpc.monitor import (
    start_monitor_client_in_background,
    start_monitor_server_in_background,
    set_exit_flag,
    mark_made_progress,
)
from gpatch.core.utils import check_rollout_batches, list_for_tensor_tolist
from gpatch.core.smart_pad_helper import smart_pad_train_get_reorder_rollout_batches
from gpatch.training.utils import training_log
from megatron_datasets.args import parse_dataset_config

# Advantage estimator used by the GRPO branch of the actor trainer. It is read
# once at import time from the environment and must be set, e.g. "RLVR",
# "reasoning_reward", or a weighted mix such as "0.5RLVR-0.5reasoning_reward"
# (parsed by parse_task_name).
ADVANTAGE_METHOD = os.environ.get("ADVANTAGE_METHOD")
assert ADVANTAGE_METHOD is not None
print("Using ADVANTAGE_METHOD %s for ppo actor" % (ADVANTAGE_METHOD,))


def parse_task_name(task_name: str, known_names: list):
    """Parse a weighted multi-reward task name into ``(name, weight)`` pairs.

    The name is a ``-``-separated list of parts. Each part is a numeric weight
    immediately followed by one of ``known_names``. Every occurrence of the
    marker ``"GT"`` is removed before parsing.

    Example::

        parse_task_name("0.5RLVR-1.2reasoning_reward-0.8XXX_reward",
                        ["RLVR", "reasoning_reward", "XXX_reward"])
        # -> [("RLVR", 0.5), ("reasoning_reward", 1.2), ("XXX_reward", 0.8)]

    Pairs are returned in the order the parts appear in ``task_name``.

    Args:
        task_name: The encoded task name.
        known_names: Reward names that may appear as part suffixes.

    Returns:
        A list of ``(name, weight)`` tuples, one per part.

    Raises:
        ValueError: If a part does not end with any known name, or if its
            weight prefix is not a valid float.
    """
    # Try longer names first so that a name which is a suffix of another
    # (e.g. "reward" vs. "reasoning_reward") cannot steal the match. sorted()
    # is stable, so equal-length names keep the caller's order.
    candidates = sorted(known_names, key=len, reverse=True)
    task_name = task_name.replace("GT", "")
    result = []
    for part in task_name.split('-'):
        for name in candidates:
            if part.endswith(name):
                weight_str = part[:-len(name)]
                try:
                    weight = float(weight_str)
                except ValueError as err:
                    raise ValueError(
                        f"invalid weight {weight_str!r} in task name part: {part}") from err
                result.append((name, weight))
                break
        else:
            raise ValueError(f"未知的任务名部分: {part}")
    return result


def iter_to_ppo_epoch_step(iteration):
    """Map a global training iteration to its ``(epoch, ppo_step)`` pair.

    ``ppo_step`` is ``iteration`` divided by ``args.train_iters_each_rollout``
    using Megatron's ``divide``. ``epoch`` is ``ppo_step`` floor-divided by
    ``args.ppo_step_per_epoch``.
    """
    cfg = get_args()
    step = divide(iteration, cfg.train_iters_each_rollout)
    return step // cfg.ppo_step_per_epoch, step


                                               
                                                            
                                   
def setup_model_and_optimizer_and_ref(model_provider_func,
                                      actor_provider_func,
                                      model_type,
                                      no_wd_decay_cond=None,
                                      scale_lr_cond=None,
                                      lr_mult=1.0):
    """Setup model and optimizer, plus the reference-policy weights if needed.

    Steps:

    1. Build a reference model without DDP and load it from
       ``args.load_ref[0]`` with ``args.finetune`` temporarily forced to True.
    2. When ``args.ppo_initial_policy_kl_penalty > 0``, capture its weights
       via ``retrieve_model_state_dict_in_cpu``. The reference model object
       itself is then deleted.
    3. Build the trainable model, optimizer and LR scheduler with Megatron's
       ``setup_model_and_optimizer`` and wrap the model with
       ``actor_provider_func``.

    Load times for both phases are printed on rank 0.

    Args:
        model_provider_func: Megatron model provider, used for both the
            reference model and the trainable model.
        actor_provider_func: Called as
            ``actor_provider_func(model=..., ref_model_state=...)``;
            ``ref_model_state`` is None when the KL penalty is disabled.
        model_type: Megatron model type forwarded to ``get_model`` and
            ``setup_model_and_optimizer``.
        no_wd_decay_cond: Forwarded to ``setup_model_and_optimizer``.
        scale_lr_cond: Forwarded to ``setup_model_and_optimizer``.
        lr_mult: Forwarded to ``setup_model_and_optimizer``.

    Returns:
        ``(actor_model, optimizer, opt_param_scheduler)``.
    """
    args = get_args()

    # Load the reference checkpoint in finetune mode; the caller's setting is
    # restored further below.
    # NOTE(review): args.finetune / args.load_ref_x are not restored if
    # load_checkpoint raises -- consider try/finally if that matters.
    prev_finetune = args.finetune
    args.finetune = True
    ref_model = get_model(model_provider_func, model_type, wrap_with_ddp=False)

    torch.cuda.synchronize()
    before_load_ref_time = time.time()
    # load_arg names an attribute of args, so expose the first reference path
    # under a temporary attribute for the duration of the load.
    args.load_ref_x = args.load_ref[0]
    load_checkpoint(ref_model, None, None, load_arg='load_ref_x')
    del args.load_ref_x
    torch.cuda.synchronize()
    after_load_ref_time = time.time()

    # Reference weights are only retained when the KL penalty against the
    # initial policy is enabled.
    init_policy_state_dict = None
    if args.ppo_initial_policy_kl_penalty > 0:
        init_policy_state_dict = retrieve_model_state_dict_in_cpu(
            unwrap_model(ref_model)[0])
    del ref_model
    args.finetune = prev_finetune

    torch.cuda.synchronize()
    before_load_time = time.time()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
        model_provider_func,
        model_type,
        no_wd_decay_cond=no_wd_decay_cond,
        scale_lr_cond=scale_lr_cond,
        lr_mult=lr_mult)
    clear_memory()
    torch.cuda.synchronize()
    after_load_time = time.time()
    print_rank_0(
        f"Load ckpt using time: load ref model {after_load_ref_time - before_load_ref_time} seconds"
        f" load actor {after_load_time - before_load_time} seconds")

    actor_model = actor_provider_func(
        model=model, ref_model_state=init_policy_state_dict)

    return actor_model, optimizer, opt_param_scheduler


def save_actor_and_critic_ckpt(iteration, actor_model, rm_critic_client, sampler_client, optimizer,
                               opt_param_scheduler, num_floating_point_operations_so_far,
                               checkpointing_context):
    """Save the actor checkpoint and, optionally, a reference-model checkpoint.

    Before saving, each rank's ``args.train_data_consuming_progresses`` dict is
    all-gathered over the default process group and merged per key. The
    merged dict replaces the local one so the checkpoint records the combined
    data-consumption progress.

    When ``args.ppo_update_ref_w_actor_interval > 0``, the initial-policy
    weights are temporarily swapped into the actor with ``cpu_weight_swap_v2``.
    They are then saved with ``save_arg='save_ref'``.

    ``rm_critic_client`` and ``sampler_client`` are accepted for interface
    compatibility and are not used here.

    This issues ``all_gather_object``, so it must be called on every rank.
    """
    args = get_args()

    # Only the standalone-sampler setup is supported by this save path.
    assert args.ppo_standalone_sampler
    g_size = torch.distributed.get_world_size()
    prgss = [None for _ in range(g_size)]
    new_prgs = {}
    torch.distributed.all_gather_object(
        prgss,
        args.train_data_consuming_progresses,
    )
    for prgs in prgss:
        for r, prg in prgs.items():
            if r not in new_prgs:
                new_prgs[r] = copy.deepcopy(prg)
            elif new_prgs[r] is not None:
                # NOTE(review): if the first value seen for a key is None,
                # later non-None values for that key are not merged in --
                # confirm this is intended.
                new_prgs[r].merge(prg)

    args.train_data_consuming_progresses = new_prgs
    print_rank_0(
        f"save_actor_and_critic_ckpt get {args.train_data_consuming_progresses}")

    # Progress has already been gathered and merged above; presumably this
    # tells the checkpoint saver not to gather it again.
    args.gather_train_data_consuming_progresses = False
    save_checkpoint_and_time(iteration, actor_model.model, optimizer, opt_param_scheduler,
                             num_floating_point_operations_so_far, checkpointing_context)

    if args.ppo_update_ref_w_actor_interval > 0:
        # f-string so the offending value is actually shown on failure.
        assert len(args.save_ref) == 1, f'unexpected save_ref {args.save_ref}'
        print_rank_0('save ckpt for ref')

        with cpu_weight_swap_v2(actor_model.model, unwrap_model, actor_model.init_policy_state_dict):
            save_checkpoint_and_time(iteration,
                                     actor_model.model,
                                     optimizer,
                                     opt_param_scheduler,
                                     num_floating_point_operations_so_far,
                                     checkpointing_context,
                                     save_arg='save_ref',
                                     save_index=0)


class PPOActorTrainerV3(object):

    def __init__(self, extra_metric_info=None):
        '''
        Parameters
        ----------
        extra_metric_info : list[dict[str, Any]]
            extra metrics to export to wandb.

            .. highlight:: python
            .. code-block:: python

                [
                  {'key_name': 'acc_rewards', 'dtype': torch.float32},
                  {'key_name': 'fmt_rewards', 'dtype': torch.float32},
                ]

        '''
        self._extra_metric_info = extra_metric_info if extra_metric_info else []

        self.disp_rng = None
        self.replay_queue = SimpleQueue()
                                                                               
        self.replay_samples_dict = dict()
        self.sample_idx = 0
        self.eval_sample_idx = 0
        self.test_flag = False
        self.cpu_memory_model = None

                                                      
           
                                                                  
                                                                  
           
                                                
        self.rm_prompt_index_mapping: list[list] = [
            [0]]                                        

    def build_cpu_memory_model(self, model):
        """Create the ``CpuMemoryModel`` mirror of ``model``.

        Blocks on ``torch.cuda.synchronize()`` afterwards so queued CUDA work
        has completed before returning.
        """
        mirror = CpuMemoryModel(model)
        self.cpu_memory_model = mirror
        torch.cuda.synchronize()

    def update_cpu_memory_model(self, model):
        """Refresh the CPU weight mirror from ``model`` and log the elapsed time.

        ``build_cpu_memory_model`` must have been called first. CUDA is
        synchronised before and after the update so the logged time covers
        the weight update itself.
        """
        torch.cuda.synchronize()
        start = time.time()
        mirror = self.cpu_memory_model
        assert mirror is not None and isinstance(mirror, CpuMemoryModel)
        mirror.update_cpu_weights(model)
        torch.cuda.synchronize()
        elapsed = time.time() - start
        print_with_rank_and_datetime(
            f"update cpu memory model using time {elapsed}s")

    def get_cpu_memory_model(self):
        return self.cpu_memory_model

    @property
    def extra_metric_info(self):
        if isinstance(self._extra_metric_info, Callable):
            return self._extra_metric_info()
        else:
            return self._extra_metric_info

                                                                        
                                                                  
    def remove_rollout_attr_before_sampling(self, rollout_batch: Dict[str, Any]) -> Dict[str, Any]:
        return rollout_batch

                                                                        
    def add_back_rollout_attr_before_sampling(
        self,
        rollout_batches: List[Dict[str, List[Any]]],
    ):
        return rollout_batches

    def replay_rollout_batch(self, rollout_batch):
        pass

    def clear_data_cache(self):
        pass

    def post_process_rm_rollout_batch(self, rbs: List[Dict[str, List[Any]]]):
        pass

    def display_rollout_generation(self, rollout_batches: List[Dict[str, List[Any]]]):
        """Print a few decoded rollout samples for eyeballing training progress.

        Does nothing unless ``args.ppo_display_rollout_generation`` is set.
        Otherwise it uses ``self.disp_rng`` (seeded from ``args.seed``) to pick
        a rank, a rollout batch and a sample offset. It then decodes up to 4
        consecutive samples from that offset and logs, per sample:

        - prompt length and sequence length;
        - reward, if present;
        - mean per-token reward, if present;
        - every ``extra_metric_info`` value.

        The log string is built on every rank but printed only on the picked
        rank.

        Args:
            rollout_batches: Rollout batches. Each must contain ``'tokens'``,
                ``'prompt_lengths'``, ``'sequence_lengths'`` and every
                ``extra_metric_info`` key. ``'rewards'`` and
                ``'per_token_rewards'`` are optional.
        """
        args = get_args()
        tokenizer = get_tokenizer()
        if not args.ppo_display_rollout_generation:
            return

        # Lazily create the display RNG, seeded identically on every rank.
        if self.disp_rng is None:
            self.disp_rng = random.Random(args.seed)
            # On (re)start, burn 3 draws per PPO step already taken (one per
            # randint call made below on each invocation) to approximate the
            # RNG position of an uninterrupted run.
            # NOTE(review): randint(0, 1) does not necessarily consume the
            # same amount of randomness as randint over the larger ranges
            # used below, so the replayed state may not match exactly.
            _, ppo_step = iter_to_ppo_epoch_step(args.iteration)
            for _ in range(ppo_step):
                for __ in range(3):
                    self.disp_rng.randint(0, 1)

        # Ranks presumably agree on pick_rank only while their RNG streams stay
        # in lockstep, i.e. when every rank sees the same number of batches and
        # samples -- verify for uneven data-parallel splits.
        pick_rank = self.disp_rng.randint(
            0, torch.distributed.get_world_size() - 1)
        pick_batch = self.disp_rng.randint(0, len(rollout_batches) - 1)

        rollout_batch = rollout_batches[pick_batch]
        tokens = list_for_tensor_tolist(rollout_batch['tokens'], False)
        prompt_lengths = list_for_tensor_tolist(
            rollout_batch["prompt_lengths"], True)
        lengths = list_for_tensor_tolist(
            rollout_batch["sequence_lengths"], True)
        if 'rewards' in rollout_batch:
            rewards = list_for_tensor_tolist(rollout_batch["rewards"], True)
        else:
            rewards = None
        if 'per_token_rewards' in rollout_batch and rollout_batch['per_token_rewards'][0] is not None:
            per_token_rewards = list_for_tensor_tolist(
                rollout_batch["per_token_rewards"], False)
        else:
            per_token_rewards = None

        # Convert every extra metric of the picked batch to Python lists.
        extra_metrics_cpu = {}
        for em_info in self.extra_metric_info:
            em_k = em_info['key_name']
            em_v = list_for_tensor_tolist(rollout_batch[em_k], True, True)
            extra_metrics_cpu[em_k] = em_v

        # Keep a window of at most num_to_pick samples from a random offset.
        # NOTE(review): randint raises ValueError if the batch has no samples.
        pick_offset = self.disp_rng.randint(0, len(tokens) - 1)
        num_to_pick = 4
        tokens = tokens[pick_offset:pick_offset + num_to_pick]
        prompt_lengths = prompt_lengths[pick_offset:pick_offset + num_to_pick]
        lengths = lengths[pick_offset:pick_offset + num_to_pick]
        # Unwrap one level if the list conversion produced nested lists.
        if isinstance(prompt_lengths[0], list):
            prompt_lengths = prompt_lengths[0]
        if isinstance(lengths[0], list):
            lengths = lengths[0]
        if rewards is not None:
            rewards = rewards[pick_offset:pick_offset + num_to_pick]
        if per_token_rewards is not None:
            per_token_rewards = per_token_rewards[pick_offset:pick_offset + num_to_pick]
        # Truncate each sample to its sequence length (presumably dropping
        # padding) before decoding.
        for i in range(len(tokens)):
            tokens[i] = tokens[i][:lengths[i]]

        # Apply the same sample window to the extra metrics.
        for em_k, em_v in extra_metrics_cpu.items():
            extra_metrics_cpu[em_k] = em_v[pick_offset:pick_offset + num_to_pick]

        texts = tokenizer._tokenizer.batch_decode(
            tokens, skip_special_tokens=False)
        log_string = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
        log_string += f'DISPLAY_ROLLOUT_GENERATION rollout_batches[{pick_batch}]\n'
        DISPLAY_ROLLOUT_GENERATION_SEP = '-' * 40 + '\n'
        for ti, text in enumerate(texts):
            log_string += DISPLAY_ROLLOUT_GENERATION_SEP
            log_string += f'micro_batch 0 idx {ti}\n'
            log_string += f'prompt_lengths {prompt_lengths[ti]}\n'
            log_string += f'length {lengths[ti]}\n'
            if rewards is not None:
                log_string += f'reward {rewards[ti]}\n'

            if per_token_rewards is not None:
                # Mean of per-token rewards over positions
                # [prompt_length - 1, length - 1).
                tmp = per_token_rewards[ti][prompt_lengths[ti] -
                                            1:lengths[ti] - 1]
                tmp = np.mean(tmp)
                log_string += f'per_token_rewards {tmp}\n'

            # Extra metrics for this sample.
            for em_k, em_v in extra_metrics_cpu.items():
                log_string += f'{em_k} {em_v[ti]}\n'

            log_string += DISPLAY_ROLLOUT_GENERATION_SEP
            # '^' and '$' delimit the decoded text so leading/trailing
            # whitespace is visible.
            log_string += f'^{text}$\n'
            log_string += DISPLAY_ROLLOUT_GENERATION_SEP

        # Every rank builds the string, but only the picked rank prints it.
        if torch.distributed.get_rank() == pick_rank:
            print(log_string)

    def run_inference(
        self,
        rollout_get_batch_func,
        filter_samplings_func,
        actor_model,
        rm_critic_client,
        gen_rm_client,
        sampler_client,
        dataloader_iter,
        ppo_step,
        num_microbatches,
    ) -> Tuple[List[Dict[str, List[Any]]], Dict[str, float]]:
        """Produce one PPO step's rollout data together with log-probs.

        This function is run per DP, so the metrics are computed globally (see
        ``compute_global_rollout_metrics``). Steps:

        1. Fetch rollout batches via ``self.get_rollout_batches`` and onload
           the actor weights.
        2. Validate the batches on the MP/CP head and broadcast them with
           ``sampler_client.broadcast_rollout_batch``.
        3. Compute policy log-probs into ``rb["logprobs"]``. When
           ``args.ppo_initial_policy_kl_penalty > 0``, also compute reference
           log-probs into ``rb["ref_logprobs"]``.
        4. Run ``hook_before_computing_metrics``.
        5. When ``ppo_sampling_keep != ppo_sampling_repeat``, filter samples.
        6. Compute global rollout metrics and display a few generations.

        Returns:
            ``(rollout_batches, rollout_metrics)``, where ``rollout_metrics``
            is ``cpu_dict`` applied to ``compute_global_rollout_metrics``.
        """
        args = get_args()
        timers = get_timers()

        print_with_rank_and_datetime(
            f'rollout run_inference {num_microbatches=}', rank=0)
        torch.distributed.barrier()
        rollout_batches = self.get_rollout_batches(sampler_client, rm_critic_client, gen_rm_client,
                                                   ppo_step, num_microbatches,
                                                   rollout_get_batch_func, dataloader_iter)
        torch.distributed.barrier()

        onload_megatron_model(actor_model.model)
        torch.distributed.barrier()

        timers('broadcast_rollout_batches', log_level=0).start(barrier=True)
        # Validate on the MP/CP head, presumably the rank whose batches are
        # broadcast below.
        if is_mp_and_cp_head():
            assert check_rollout_batches(
                rollout_batches), f"rollout_batches fmt error: {rollout_batches=}"
            required_keys = ('tokens', 'prompt_lengths', 'sequence_lengths', 'rewards')
            for rollout_batch in rollout_batches:
                # Parenthesised (rather than a trailing backslash) so the
                # condition cannot accidentally absorb the following line.
                assert all(k in rollout_batch for k in required_keys), (
                    f"rollout_batch missing required keys {required_keys}: "
                    f"{list(rollout_batch)}")

        print_memory_tracking(
            "Memory tracking: before bcast data", rank=args.world_size-1)
        rollout_batches = sampler_client.broadcast_rollout_batch(
            rollout_batches)
        clear_memory()
        print_memory_tracking(
            "Memory tracking: after bcast data", rank=args.world_size-1)

        rollout_batches = self.add_back_rollout_attr_before_sampling(
            rollout_batches)
        assert check_rollout_batches(
            rollout_batches), f"rollout_batches fmt error: {rollout_batches=}"
        timers('broadcast_rollout_batches').stop()
        print_meminfo_str("Track CPU Memory after get_rollout_batches: ")

        clear_memory()
        print_memory_tracking(
            "Memory tracking: before get logps", rank=args.world_size-1)

        # Log-probs of the sampled tokens under the current policy.
        timers('compute_logps', log_level=0).start(barrier=True)
        if args.ppo_smart_pad_infer:
            policy_logprobs = actor_model.batch_get_policy_logprobs(
                rollout_batches)
        else:
            policy_logprobs = actor_model.get_policy_logprobs(rollout_batches)
        assert policy_logprobs is not None
        for logprobs, rb in zip(policy_logprobs, rollout_batches, strict=True):
            rb["logprobs"] = logprobs
        if args.do_monitor:
            mark_made_progress(args.monitor_server_ip, args.monitor_port)

        print_memory_tracking(
            "Memory tracking: before get ref logps", rank=args.world_size-1)
        # Reference-policy log-probs are only needed for the KL penalty.
        if args.ppo_initial_policy_kl_penalty > 0:
            if args.ppo_smart_pad_infer:
                init_policy_logprobs = actor_model.batch_get_ref_policy_logprobs(
                    rollout_batches)
            else:
                init_policy_logprobs = actor_model.get_ref_policy_logprobs(
                    rollout_batches)
            assert init_policy_logprobs is not None
            for ref_logprobs, rb in zip(init_policy_logprobs, rollout_batches, strict=True):
                rb["ref_logprobs"] = ref_logprobs

        clear_memory()
        print_memory_tracking(
            "Memory tracking: after get ref logps", rank=args.world_size-1)
        timers('compute_logps').stop()
        if args.do_monitor:
            mark_made_progress(args.monitor_server_ip, args.monitor_port)
        print_meminfo_str("Track CPU Memory after get logprobs: ")

        timers('hook_before_computing_metrics',
               log_level=0).start(barrier=True)
        rollout_batches = self.hook_before_computing_metrics(
            rollout_batches, ppo_step, is_eval=False)
        timers('hook_before_computing_metrics').stop()

        timers('filter_sampling', log_level=0).start(barrier=True)
        assert check_rollout_batches(
            rollout_batches), f"rbs fmt error: {rollout_batches=}"
        if args.ppo_sampling_keep != args.ppo_sampling_repeat:
            # The return value is ignored, so the filter must shrink the
            # batches in place to ppo_sampling_keep samples per prompt.
            filter_samplings_func(args, rollout_batches)
            expected_mbs = args.ppo_rollout_micro_batch_size * args.ppo_sampling_keep
            for rb in rollout_batches:
                assert len(rb['tokens']) == expected_mbs
            # Validate the whole list once instead of once per batch.
            assert check_rollout_batches(
                rollout_batches), f"rbs fmt error: {rollout_batches=}"
        timers('filter_sampling').stop()

        timers('compute_global_rollout_metrics',
               log_level=0).start(barrier=True)
        rollout_metrics = cpu_dict(
            self.compute_global_rollout_metrics(rollout_batches))
        timers('compute_global_rollout_metrics').stop()

        self.display_rollout_generation(rollout_batches)
        return rollout_batches, rollout_metrics

    def compute_global_rollout_metrics(self, rollout_batches: List[Dict[str, List[Any]]]):
        """Average rollout statistics over all samples in the data-parallel group.

        On each rank, the following are summed over all samples:

        - response length (``sequence_lengths - prompt_lengths``);
        - prompt length;
        - reward;
        - every ``extra_metric_info`` key.

        These sums and the local sample count are all-reduced over the
        data-parallel group with the default reduce op, then divided by the
        global sample count. On MP/CP heads with ``args.ppo_wecube_report``
        set, the local (pre-reduction) length sums and sample count are also
        sent to ``report_ppo_metrics``.

        This issues an ``all_reduce``, so every data-parallel rank must call
        it.

        Args:
            rollout_batches: Batches whose ``"prompt_lengths"``,
                ``"sequence_lengths"``, ``"rewards"`` and extra-metric entries
                are lists of tensors that ``torch.stack`` can combine.

        Returns:
            Dict mapping ``"rollout-metrics/..."`` names to Python floats. If
            the global sample count is zero, every value is ``nan`` rather
            than raising ``ZeroDivisionError``.
        """
        # Resolve the descriptors once so the accumulation order and the
        # unpacking below use the same list, even when it comes from a callable.
        extra_info = self.extra_metric_info
        metrics = defaultdict(lambda: 0)

        num_samples = 0
        for rb in rollout_batches:
            prompt_lengths = torch.stack(rb["prompt_lengths"]).view(-1)
            sequence_lengths = torch.stack(rb["sequence_lengths"]).view(-1)
            rewards = torch.stack(rb["rewards"]).view(-1)
            metrics["sequence_lengths"] += (sequence_lengths -
                                            prompt_lengths).sum()
            metrics["prompt_lengths"] += prompt_lengths.sum()
            metrics["rewards"] += rewards.sum()
            for em_info in extra_info:
                em_k = em_info['key_name']
                em_v = torch.stack(rb[em_k]).view(-1)
                metrics[em_k] += em_v.sum()
            num_samples += prompt_lengths.size(0)

        args = get_args()
        if args.ppo_wecube_report and is_mp_and_cp_head():
            raw_report = {
                "prompt_lengths": metrics["prompt_lengths"],
                "sequence_lengths": metrics["sequence_lengths"],
                "num_samples": num_samples,
            }
            # Convert tensors to Python scalars for the external reporter.
            biz_report_data = {
                k: (v.item() if torch.is_tensor(v) else v) for k, v in raw_report.items()
            }
            report_ppo_metrics(biz_report_data)

        tensor_to_accumulate = [
            metrics["sequence_lengths"],
            metrics["prompt_lengths"],
            metrics["rewards"],
            num_samples,
        ]
        tensor_to_accumulate.extend(metrics[em_info['key_name']] for em_info in extra_info)
        tensor_to_accumulate = torch.tensor(
            tensor_to_accumulate,
            dtype=torch.float32,
            device=torch.cuda.current_device(),
        )
        torch.distributed.all_reduce(tensor_to_accumulate,
                                     group=parallel_state.get_data_parallel_group())

        totals = tensor_to_accumulate.tolist()
        (
            global_response_lengths,
            global_prompt_lengths,
            global_rewards,
            global_num_samples,
        ) = totals[:4]
        # With no samples anywhere in the group, report nan instead of
        # crashing the training step with ZeroDivisionError.
        denom = global_num_samples if global_num_samples > 0 else float("nan")
        metrics = {
            "rollout-metrics/global_response_lengths_mean":
            global_response_lengths / denom,
            "rollout-metrics/global_prompt_lengths": global_prompt_lengths / denom,
            "rollout-metrics/global_rewards": global_rewards / denom,
        }
        for em_i, em_info in enumerate(extra_info):
            em_k = em_info['key_name']
            metrics[f"rollout-metrics/global_{em_k}"] = totals[4 + em_i] / denom
        return metrics

    def generate_ppo_data(self,
                          rollout_batches) -> Tuple[List[Dict[str, List[Any]]], Dict[str, float]]:
        """generate ppo specific data for training
        """
        args = get_args()
        tokenizer = get_tokenizer()
        ppo_rollout_metrics = defaultdict(lambda: 0)
        num_samples = 0
        print_with_rank_and_datetime('rollout generate_ppo_data', rank=0)

                                          
                                                            
                            
        for rollout_batch in rollout_batches:
                                  
                                                                                                         
            prompt_lengths = rollout_batch["prompt_lengths"]
            sequence_lengths = rollout_batch["sequence_lengths"]
            tokens = rollout_batch["tokens"]
            values = rollout_batch.get("values", [None])
            if values[0] is None:
                values = None
            rewards = rollout_batch["rewards"]
            confidence_rewards = rollout_batch["confidence_reward"]
            relevance_rewards = rollout_batch["relevance_reward"]
            coherent_rewards = rollout_batch["coherent_reward"]
            critic_rewards = rollout_batch["critic_reward"]

            per_token_rewards = rollout_batch["per_token_rewards"]
            logprobs = rollout_batch["logprobs"]
                               
            assert len(logprobs) == len(confidence_rewards)

            for each_sample_idx in range(len(logprobs)):
                each_token_logp = logprobs[each_sample_idx]
                reasoning_start = prompt_lengths[each_sample_idx]
                answer_start = rollout_batch['answer_start'][each_sample_idx].item(
                )
                answer_end = rollout_batch['answer_end'][each_sample_idx].item(
                )
                if answer_start < reasoning_start or answer_end < reasoning_start:
                    confidence_rewards[each_sample_idx] = -10.0
                    continue
                reasoning_N_logps = each_token_logp[reasoning_start:answer_start]
                answer_logps = each_token_logp[answer_start:answer_end]
                confidence_rewards[each_sample_idx] = sum(
                    reasoning_N_logps)/len(reasoning_N_logps) + sum(answer_logps)
            confidence_rewards = list(torch.tensor(
                confidence_rewards, dtype=torch.float32))
            rollout_batch["confidence_reward"] = confidence_rewards

            num_samples += len(prompt_lengths)

            mask = create_mask(values=logprobs,
                               prompt_lengths=prompt_lengths,
                               sequence_lengths=sequence_lengths,
                               dtype=rewards[0].dtype)
            rollout_batch["mask"] = mask

                       
            assert not args.use_grpo or (
                args.use_grpo and args.ppo_initial_policy_kl_penalty)

            if args.ppo_initial_policy_kl_penalty > 0:
                ref_logprobs = rollout_batch["ref_logprobs"]
                init_policy_kl = calculate_kl_penalty(
                    log_probs_a=logprobs,
                    log_probs_b=ref_logprobs,
                    use_absolute_kl=args.ppo_use_absolute_kl,
                )
            else:
                init_policy_kl = torch.tensor(
                    0, dtype=logprobs.dtype, device=logprobs.device)

            if not args.use_grpo:
                assert values.shape == logprobs.shape, f"values {values.shape} and logprobs {logprobs.shape} should have the same shape"

            if args.use_grpo:
                if ADVANTAGE_METHOD == "RLVR":
                    advantages, returns = calculate_grpo_advantages(
                        rewards=rewards,
                        mask=mask,
                        grpo_sampling_times=args.ppo_sampling_keep,
                        grpo_advantage_epsilon=args.grpo_advantage_epsilon,
                    )
                    rollout_batch["returns"] = returns
                elif ADVANTAGE_METHOD in ["reasoning_reward", "GTreasoning_reward"]:
                    advantages_confidence, returns_confidence = calculate_grpo_advantages(
                        rewards=confidence_rewards,
                        mask=mask,
                        grpo_sampling_times=args.ppo_sampling_keep,
                        grpo_advantage_epsilon=args.grpo_advantage_epsilon,
                    )
                    advantages_coherent, returns_coherent = calculate_grpo_advantages(
                        rewards=coherent_rewards,
                        mask=mask,
                        grpo_sampling_times=args.ppo_sampling_keep,
                        grpo_advantage_epsilon=args.grpo_advantage_epsilon,
                    )
                    advantages_relevance, returns_relevance = calculate_grpo_advantages(
                        rewards=relevance_rewards,
                        mask=mask,
                        grpo_sampling_times=args.ppo_sampling_keep,
                        grpo_advantage_epsilon=args.grpo_advantage_epsilon,
                    )
                    returns = [0.1 * rc + 0.7 * rh + 0.2 * rr
                               for rc, rh, rr in zip(returns_confidence, returns_coherent, returns_relevance)]
                    advantages = [0.1 * ac + 0.7 * ah + 0.2 * ar
                                  for ac, ah, ar in zip(advantages_confidence, advantages_coherent, advantages_relevance)]
                    rollout_batch["returns"] = returns
                elif "RLVR" in ADVANTAGE_METHOD and "reasoning_reward" in ADVANTAGE_METHOD:
                    weights = parse_task_name(ADVANTAGE_METHOD, [
                        "RLVR", "reasoning_reward"])
                    weights_RLVR = weights[0][1]
                    weights_reasoning = weights[1][1]

                    advantages_confidence, returns_confidence = calculate_grpo_advantages(
                        rewards=confidence_rewards,
                        mask=mask,
                        grpo_sampling_times=args.ppo_sampling_keep,
                        grpo_advantage_epsilon=args.grpo_advantage_epsilon,
                    )
                    advantages_coherent, returns_coherent = calculate_grpo_advantages(
                        rewards=coherent_rewards,
                        mask=mask,
                        grpo_sampling_times=args.ppo_sampling_keep,
                        grpo_advantage_epsilon=args.grpo_advantage_epsilon,
                    )
                    advantages_relevance, returns_relevance = calculate_grpo_advantages(
                        rewards=relevance_rewards,
                        mask=mask,
                        grpo_sampling_times=args.ppo_sampling_keep,
                        grpo_advantage_epsilon=args.grpo_advantage_epsilon,
                    )
                    advantages_RLVR, returns_RLVR = calculate_grpo_advantages(
                        rewards=rewards,
                        mask=mask,
                        grpo_sampling_times=args.ppo_sampling_keep,
                        grpo_advantage_epsilon=args.grpo_advantage_epsilon,
                    )
                    returns_reasoning = [0.1 * rc + 0.7 * rh + 0.2 * rr
                                         for rc, rh, rr in zip(returns_confidence, returns_coherent, returns_relevance)]
                    advantages_reasoning = [0.1 * ac + 0.7 * ah + 0.2 * ar
                                            for ac, ah, ar in zip(advantages_confidence, advantages_coherent, advantages_relevance)]
                    returns = [weights_reasoning * r1 + weights_RLVR * r2 for r1,
                               r2 in zip(returns_reasoning, returns_RLVR)]
                    advantages = [weights_reasoning * a1 + weights_RLVR * a2 for a1,
                                  a2 in zip(advantages_reasoning, advantages_RLVR)]
                    rollout_batch["returns"] = returns
                else:
                    raise NotImplementedError(
                        f"Not implemented ADVANTAGE_METHOD {ADVANTAGE_METHOD} for grpo")
            else:
                                                    
                rewards_with_kl = calculate_ppo_rewards(logprobs, rewards, None, sequence_lengths,
                                                        init_policy_kl,
                                                        args.ppo_initial_policy_kl_penalty)
                advantages, returns = calculate_advantages_and_returns(
                    values=values,
                    rewards=rewards_with_kl,
                    discount_factor=args.ppo_discount_factor,
                    gae_lambda=args.ppo_gae_lambda,
                    mask=mask,
                    per_token_rewards=per_token_rewards,
                    per_token_rewards_factor=args.ppo_calc_adv_per_token_rewards_factor,
                )
                assert returns[0].dtype == torch.float32
                rollout_batch["returns"] = returns

            assert advantages[0].dtype == torch.float32
            rollout_batch["advantages"] = advantages

                             
                                                                                              
            if args.ppo_initial_policy_kl_penalty > 0:
                ppo_rollout_metrics["ppo-metrics/init_policy_kl"] += masked_mean_list(
                    init_policy_kl, mask, dim=-1).sum().item()
        assert check_rollout_batches(
            rollout_batches), f"check rbs fmt error {rollout_batches=}"
                                                               
        ppo_rollout_metrics = {k: v / num_samples for k,
                               v in ppo_rollout_metrics.items()}

        mask_list = []
        for rollout_batch in rollout_batches:
            mask_list.extend(rollout_batch["mask"])
        for key in ["advantages", "returns", "values", 'per_token_rewards']:
            if key not in rollout_batches[0] or rollout_batches[0][key][0] is None:
                continue
            tensor_list = []
            for rollout_batch in rollout_batches:
                tensor_list.extend(rollout_batch[key])
            global_mean, global_var, min_var, max_var = masked_global_statistics_list(
                tensor_list,
                mask_list,
                group=parallel_state.get_data_parallel_group(),
            )
            ppo_rollout_metrics[f"ppo-metrics/global_{key}_mean"] = global_mean.item(
            )
            ppo_rollout_metrics[f"ppo-metrics/global_{key}_std"] = global_var.sqrt(
            ).item()
            ppo_rollout_metrics[f"ppo-metrics/global_{key}_min"] = min_var.item()
            ppo_rollout_metrics[f"ppo-metrics/global_{key}_max"] = max_var.item()
        if args.ppo_normalize_advantages and not args.use_grpo:
                                                 
                                                                    
                                            
                                                                
                                                 
                                           
                                                                 
               
            raise NotImplemented

        return rollout_batches, cpu_dict(ppo_rollout_metrics)

    @torch.no_grad()
    def generate_rollouts(self, rollout_get_batch_func, filter_samplings_func, actor_model,
                          rm_critic_client, gen_rm_client, sampler_client, dataloader_iter,
                          ppo_step,
                          num_microbatches) -> Tuple[List[Dict[str, List[Any]]], Dict[str, float]]:
        """Produce one PPO step's worth of training rollouts.

        Puts the actor into inference mode, collects sampled/scored rollouts via
        ``self.run_inference`` and then derives the PPO training data
        (advantages, returns, ...) via ``self.generate_ppo_data``.

        When ``args.ppo_save_first_rollout_data`` is set, rank 0 dumps the first
        rollout produced by this trainer instance (and its inference metrics) to
        ``./debug-tmp``. Later calls do not overwrite the dump.

        Returns:
            ``(rollout_batches, metrics)``. ``metrics`` is the inference metrics
            merged with the PPO-data metrics; on key clashes the PPO-data
            metrics win (``dict.__or__`` semantics).
        """
        timers = get_timers()
        args = get_args()
        actor_model.prepare_for_inference()

        rollout_batches, rollout_metrics = self.run_inference(rollout_get_batch_func,
                                                              filter_samplings_func, actor_model,
                                                              rm_critic_client, gen_rm_client,
                                                              sampler_client, dataloader_iter,
                                                              ppo_step, num_microbatches)

        timers('generate_ppo_data', log_level=0).start(barrier=True)
        rollout_batches, ppo_rollout_metrics = self.generate_ppo_data(
            rollout_batches)
        timers('generate_ppo_data').stop()

        # Only the *first* rollout is dumped, as the flag and file names promise.
        # Re-saving every step would overwrite it with the latest rollout and
        # cost a full torch.save per ppo step.
        if args.ppo_save_first_rollout_data and not getattr(self, '_first_rollout_data_saved', False):
            if torch.distributed.get_rank() == 0:
                debug_dir = "./debug-tmp"
                os.makedirs(debug_dir, exist_ok=True)
                torch.save(rollout_batches,
                           f"{debug_dir}/first_rollout_data_rank0.pt")
                torch.save(rollout_metrics,
                           f"{debug_dir}/first_rollout_metrics_rank0.pt")
            # Set on every rank so all ranks agree that the dump has happened.
            self._first_rollout_data_saved = True

        actor_model.finish_inference()
        return rollout_batches, rollout_metrics | ppo_rollout_metrics

    def train_step(
        self,
        config,
        actor_model,
        optimizer,
        opt_param_scheduler,
        dataloader_iter,
        ppo_step,
        train_iters_each_rollout,
        iteration,
        total_loss_dict,
        rollout_metrics,
    ):
        """Train the actor for one pass over a rollout dataloader.

        For each batch yielded by ``dataloader_iter``, this method:
          * zeroes the grad buffers of every model chunk and the optimizer;
          * runs forward/backward via ``actor_model.get_loss_and_metrics``;
          * finalizes grads and steps the optimizer;
          * advances the LR scheduler only when the update succeeded;
          * logs the batch metrics merged with ``rollout_metrics``.

        Args:
            config: model config. ``config.finalize_model_grads_func`` must be
                unset, because grads are finalized explicitly here.
            ppo_step / train_iters_each_rollout: only used to position the
                progress bar.
            iteration: global iteration counter, incremented once per batch.
            total_loss_dict: accumulator passed through to ``training_log``.
            rollout_metrics: metrics from rollout generation, merged into every
                batch's logged metrics.

        Returns:
            The updated ``iteration``.
        """
        args = get_args()
        timers = get_timers()
        actor_model.prepare_for_training()

        progress = tqdm(
            dataloader_iter,
            initial=ppo_step * train_iters_each_rollout,
            total=args.ppo_max_epochs * args.ppo_step_per_epoch * train_iters_each_rollout,
            leave=True,
            position=2,
            desc='PPO actor train',
            disable=torch.distributed.get_rank() != 0 or args.ppo_disable_tqdm,
        )
        for batch in progress:
            actor_model.prepare_for_training_step()
            for chunk in actor_model.model:
                chunk.zero_grad_buffer()
            optimizer.zero_grad()

            timers('loss-and-metrics',
                   log_level=1).start(barrier=args.barrier_with_L1_time)
            _, step_metrics = actor_model.get_loss_and_metrics(
                batch, get_num_microbatches(), forward_only=False)
            timers('loss-and-metrics').stop()

            # Grad finalization is done explicitly below, so a config-level hook
            # would run it twice.
            assert not config.finalize_model_grads_func
            finalize_model_grads(actor_model.model)
            actor_model.finish_training_step()

            timers('optimizer', log_level=1).start(
                barrier=args.barrier_with_L1_time)
            update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
            timers('optimizer').stop()

            if not update_successful:
                skipped_iter = 1
            else:
                consumed_samples = (get_num_microbatches() * args.micro_batch_size
                                    * args.data_parallel_size)
                opt_param_scheduler.step(increment=consumed_samples)
                skipped_iter = 0
            iteration += 1

            # Gather logging quantities.
            loss_scale = optimizer.get_loss_scale().item()
            params_norm = calc_params_l2_norm(actor_model.model) if args.log_params_norm else None

            # Distinct non-decoupled LRs in first-seen order; the last decoupled
            # group's LR (if any) is reported separately.
            learning_rates = []
            decoupled_learning_rate = None
            for group in optimizer.param_groups:
                if group['is_decoupled_lr']:
                    decoupled_learning_rate = group['lr']
                    continue
                if group['lr'] not in learning_rates:
                    learning_rates.append(group['lr'])

            step_metrics.update(rollout_metrics)
            training_log(step_metrics, total_loss_dict, learning_rates, decoupled_learning_rate,
                         iteration, loss_scale, False, skipped_iter, grad_norm, params_norm,
                         num_zeros_in_grad)

        if args.do_monitor:
            mark_made_progress(args.monitor_server_ip, args.monitor_port)

        actor_model.finish_training()
        optimizer.zero_grad()
        return iteration

    def get_rollout_batches(self, sampler_client, rm_critic_client, gen_rm_client, ppo_step,
                            num_microbatches, rollout_get_batch_func, dataloader_iter, is_eval=False):
        """Sample and score ``num_microbatches`` rollout micro-batches for one step.

        Phases (each bracketed by ``cpu_barrier()`` calls so all ranks advance together):
          1. Sampling: builds each micro-batch either from the replay queue (training
             only, when ``args.replay_sample`` and enough samples are queued) or from
             ``rollout_get_batch_func(dataloader_iter)``. It then calls
             ``sampler_client.generate`` for all micro-batches concurrently.
          2. ``self.hook_after_sampling`` on the sampled batches.
          3. Optional generative-RM scoring (``args.use_gen_rm``). The results of each
             RM/prompt are merged into the matching batches.
          4. Optional RM/critic scoring (``args.use_rm_and_critic``). This is issued
             first and then collected, and the results are merged into every batch.
          5. Optional replay-queue refill (training only), then advancing
             ``self.sample_idx`` / ``self.eval_sample_idx`` by ``num_microbatches``.

        Only ranks where ``is_mp_and_cp_head()`` is true hold real data. Other ranks
        get a list of ``num_microbatches`` ``None`` entries.

        Args:
            ppo_step: current step, forwarded to every client call.
            num_microbatches: number of rollout micro-batches to produce.
            is_eval: selects the eval sample counter, the eval sampling repeat and the
                ``eval_`` timer prefix, and disables the replay queue.

        Returns:
            On head ranks, a list of dicts mapping keys to per-sample lists
            (``List[Dict[str, List[Any]]]``). On other ranks, a list of ``None``.
        """
        args = get_args()
        timers = get_timers()

        eval_ = 'eval_' if is_eval else ''
        sample_idx = self.eval_sample_idx if is_eval else self.sample_idx

        # Repeat count forwarded to the sampler and the rm/critic result fetch
        # (presumably samples generated per prompt - confirm with the clients).
        sampling_repeat = args.ppo_eval_sampling_repeat if is_eval else args.ppo_sampling_repeat

        timers(f'{eval_}generate', log_level=0).start(barrier=True)

        print_memory_tracking(
            f"Memory tracking before rollout:", verbose=True, rank=0)
        print_with_rank_and_datetime(
            f'get_rollout_batches mark_sampler_ppo_step_begin', rank=0)
        sampler_client.mark_sampler_ppo_step_begin(ppo_step)
        cpu_barrier()

        # Phase 1: build every micro-batch and launch all sampler requests concurrently.
        async def gen_co_1():
            batches = [None for _ in range(num_microbatches)]
            cos = []
            rollout_mbs = args.ppo_rollout_micro_batch_size
            for rbi in range(num_microbatches):
                # Replay path: assemble one micro-batch (dict of key -> list) from
                # rollout_mbs queued samples. The keys come from the first sample.
                if not is_eval and args.replay_sample and self.replay_queue.qsize() >= rollout_mbs:
                    rb = None
                    for i in range(rollout_mbs):
                        sample = self.replay_queue.get()
                        if rb is None:
                            rb = {}
                            for k in sample.keys():
                                rb[k] = [sample[k]]
                        else:
                            for k in rb.keys():
                                rb[k].append(sample[k])
                    batches[rbi] = self.rollout_get_from_replay_queue(rb)
                else:
                    batches[rbi] = rollout_get_batch_func(dataloader_iter)
                batches[rbi] = self.remove_rollout_attr_before_sampling(
                    batches[rbi])
                # sample_idx + rbi is passed as the request's sample index.
                co = sampler_client.generate(
                    ppo_step, sample_idx + rbi, batches[rbi], sampling_repeat)
                cos.append(co)

            return await asyncio.gather(*cos)

        print_with_rank_and_datetime(
            f'{eval_}get_rollout_batches generate', rank=0)
        num_rollout_samples = num_microbatches

        if is_mp_and_cp_head():
            rbs: List[Dict[str, List[Any]]] = asyncio.run(gen_co_1())
            assert check_rollout_batches(
                rbs), f"rbs format error, may need pop('ready'): {rbs=}"
        else:
            rbs = [None for _ in range(num_microbatches)]
        cpu_barrier()
        if args.do_monitor:
            mark_made_progress(args.monitor_server_ip, args.monitor_port)

        # Release the sampler engine cache, then end the sampler's step.
        sampler_client.infer_engine_flush_cache()
        print_memory_tracking(
            f"Memory tracking: actor after sampler flush cache", verbose=True, rank=0)
        cpu_barrier()

        print_with_rank_and_datetime(
            f'{eval_}get_rollout_batches mark_sampler_ppo_step_end', rank=0)
        sampler_client.mark_sampler_ppo_step_end(ppo_step)
        cpu_barrier()
        timers(f'{eval_}generate').stop()  # spans sampler step begin -> sampler step end

        torch.cuda.synchronize()
        print_memory_tracking(
            f"Memory tracking: actor after sampler sleep", verbose=True, rank=0)

        timers(f'{eval_}hook_after_sampling', log_level=0).start(barrier=True)
        if is_mp_and_cp_head():
            rbs = self.hook_after_sampling(rbs)
        cpu_barrier()
        timers(f'{eval_}hook_after_sampling').stop()

        # Phase 3 helper: score the selected batches for one prompt index concurrently.
        async def gen_rm_co_1(req_rbs, rbs_idx, prompt_idx):
            cos = []

            for rb, rbi in zip(req_rbs, rbs_idx, strict=True):
                co = gen_rm_client.generate_rewards(ppo_step, sample_idx + rbi,
                                                    prompt_idx, rb)
                cos.append(co)

            return await asyncio.gather(*cos)

        if args.use_gen_rm:
            print_with_rank_and_datetime(
                f'{eval_}get_rollout_batches mark_gen_rm_ppo_step_begin', rank=0)
            timers(f'{eval_}gen_rm_scoring', log_level=0).start(barrier=True)
            gen_rm_client.mark_gen_rm_ppo_step_begin(ppo_step)
            cpu_barrier()

            print_with_rank_and_datetime(f'{eval_}get_rollout_batches with {args.ppo_num_rm} gen_rm(s)'
                                         f' with rm_prompt mapping {self.rm_prompt_index_mapping}',
                                         rank=0)
            # Prompt indices are global across RMs: RM k owns the next len(prompt_list) indices.
            prompt_start_idx = 0
            # With exactly one RM and one prompt, every batch is scored regardless of its keys.
            is_single_prompt = len(self.rm_prompt_index_mapping) == 1 and \
                len(self.rm_prompt_index_mapping[0]) == 1
            for rm_idx, prompt_list in enumerate(self.rm_prompt_index_mapping):
                # Switch gen-RM weights for every RM after the first.
                # NOTE(review): nothing here switches back to model 0 afterwards; confirm
                # that the client or the next step handles that.
                if rm_idx != 0:
                    print_with_rank_and_datetime(f"ppo_actor {rm_idx=}")
                    gen_rm_client.update_gen_rm_weight_by_model_idx(rm_idx)

                for pl_idx in range(len(prompt_list)):
                    prompt_idx = prompt_start_idx + pl_idx
                    if is_mp_and_cp_head():
                        req_rbs_list = []
                        req_idx_list = []
                        for rbi, rb in enumerate(rbs):
                            if f'rm_{prompt_idx}_tokens' in rb or is_single_prompt:
                                req_idx_list.append(rbi)
                                req_rbs_list.append(rb)
                        # Score the selected batches and merge each result into its source batch.
                        rbs_result = asyncio.run(gen_rm_co_1(req_rbs_list,
                                                             req_idx_list, prompt_idx))
                        for rb_idx, rb_result in zip(req_idx_list,
                                                     rbs_result, strict=True):
                            rbs[rb_idx].update(rb_result)
                    else:
                        rbs = [None for _ in range(num_rollout_samples)]
                prompt_start_idx += len(prompt_list)
                cpu_barrier()
            cpu_barrier()

            # Return value is ignored; presumably mutates rbs in place - confirm.
            if is_mp_and_cp_head():
                self.post_process_rm_rollout_batch(rbs)

            gen_rm_client.infer_engine_flush_cache()

            print_memory_tracking(
                f"Memory tracking: actor after gen_rm flush cache", verbose=True, rank=0)
            cpu_barrier()

            print_with_rank_and_datetime(
                f'{eval_}get_rollout_batches mark_gen_rm_ppo_step_end', rank=0)
            gen_rm_client.mark_gen_rm_ppo_step_end(ppo_step)
            cpu_barrier()
            timers(f'{eval_}gen_rm_scoring').stop()

            print_memory_tracking(
                f'Memory tracking: actor after gen_rm sleep', verbose=True, rank=0)
            if args.do_monitor:
                mark_made_progress(args.monitor_server_ip, args.monitor_port)

        # Phase 4: RM/critic scoring - issue all requests, barrier, then collect results.
        if args.use_rm_and_critic:
            print_with_rank_and_datetime(
                f'{eval_}get_rollout_batches mark_rm_ppo_step_begin', rank=0)
            timers(f'{eval_}scoring', log_level=0).start(barrier=True)
            rm_critic_client.mark_rm_ppo_step_begin(ppo_step)
            cpu_barrier()

            async def rm_critic_co_1():
                cos = [
                    rm_critic_client.issue_infer_rm_critic(ppo_step, sample_idx + rbi,
                                                           rbs[rbi])
                    for rbi in range(num_rollout_samples)
                ]
                return await asyncio.gather(*cos)

            print_with_rank_and_datetime(
                f'{eval_}get_rollout_batches issue_infer_rm_critic', rank=0)
            if is_mp_and_cp_head():
                asyncio.run(rm_critic_co_1())
            cpu_barrier()

            async def rm_critic_co_2():
                cos = [
                    rm_critic_client.get_infer_rm_critic_result(
                        ppo_step, sample_idx + rbi, sampling_repeat)
                    for rbi in range(num_rollout_samples)
                ]
                return await asyncio.gather(*cos)

            print_with_rank_and_datetime(
                f'{eval_}get_rollout_batches get_infer_rm_critic_result', rank=0)
            if is_mp_and_cp_head():
                crbs = asyncio.run(rm_critic_co_2())
                # Merge each batch's rm/critic outputs into the batch dict.
                for rb, crb in zip(rbs, crbs, strict=True):
                    rb.update(crb)
                assert check_rollout_batches(
                    rbs), f"rbs format error, may need pop('ready'): {crbs=}"
            else:
                rbs = [None for _ in range(num_rollout_samples)]
            cpu_barrier()

            print_with_rank_and_datetime(
                f'{eval_}get_rollout_batches mark_rm_ppo_step_end', rank=0)
            rm_critic_client.mark_rm_ppo_step_end(ppo_step)
            cpu_barrier()
            timers(f'{eval_}scoring').stop()
            if args.do_monitor:
                mark_made_progress(args.monitor_server_ip, args.monitor_port)

        # Training only: re-queue the samples from batches that carry a 'sample_useful' key.
        if not is_eval and args.replay_sample and is_mp_and_cp_head():
            for rb in rbs:
                if 'sample_useful' in rb:
                    replay_samples = self.get_replay_samples(rb)
                    for sample in replay_samples:
                        self.replay_queue.put(sample)

        if is_eval:
            self.eval_sample_idx += num_rollout_samples
        else:
            self.sample_idx += num_rollout_samples

        # Drop cached per-step data before returning.
        self.clear_data_cache()
        return rbs

    def train_loop(self, actor_model, sampler_client, rm_critic_client, gen_rm_client,
                   rollout_get_batch_func, filter_samplings_func, optimizer, opt_param_scheduler,
                   train_data_iterator, valid_data_iterator, process_non_loss_data_func, config,
                   checkpointing_context):
        """Train the actor with PPO/GRPO until ``args.ppo_max_epochs`` is reached.

        Each PPO step:
          1. generates rollouts (sampling, scoring, advantage/return computation);
          2. trains the critic through ``rm_critic_client`` unless ``args.use_grpo``;
          3. trains the actor for ``args.ppo_max_epochs_2`` passes over the rollouts;
          4. optionally refreshes the reference model, saves a checkpoint, pushes
             new weights to a standalone sampler and runs ``eval_loop``;
          5. logs this step's timers (also to wandb when configured).

        ``process_non_loss_data_func`` is accepted for interface compatibility
        but is not used here.

        Returns:
            ``(iteration, num_floating_point_operations_so_far)``. This is returned on
            every path, including when no epochs are left to run (e.g. when resuming
            from a checkpoint taken at the end of training).
        """
        args = get_args()
        timers = get_timers()

        # Write args to tensorboard.
        write_args_to_tensorboard()

        # Loss accumulator consumed by training_log.
        total_loss_dict = {}

        # Resume counters.
        iteration = args.iteration
        num_floating_point_operations_so_far = args.num_floating_point_operations_so_far

        timers('interval-time', log_level=0).start(barrier=True)
        print_datetime('before the start of training step')
        # Never set to True in this loop; keeps the exit path below available.
        should_exit = False

        epoch, ppo_step = iter_to_ppo_epoch_step(iteration)
        dp_size = parallel_state.get_data_parallel_world_size()
        sampler_dp_size = args.ppo_sampler_data_parallel_size
        rollout_gbs = args.ppo_rollout_global_batch_size
        rollout_mbs = args.ppo_rollout_micro_batch_size
        assert rollout_gbs % dp_size == 0, f"{rollout_gbs=} {dp_size=}"
        assert rollout_gbs % sampler_dp_size == 0, f"{rollout_gbs=} {sampler_dp_size=}"
        num_rollout_micro_batches = divide(rollout_gbs, rollout_mbs * dp_size)
        ppo_step_since_recover = 0

        print_rank_0((f'actor iter_to_ppo_epoch_step'
                      f' num_rollout_micro_batches {num_rollout_micro_batches}'
                      f' train_iters_each_rollout {args.train_iters_each_rollout}'
                      f' ppo_step {ppo_step}'
                      f' epoch {epoch}'))
        assert args.train_iters == args.ppo_max_epochs * args.ppo_step_per_epoch * args.train_iters_each_rollout, (
            f'invalid args train_iters {args.train_iters} != ppo_max_epochs {args.ppo_max_epochs}'
            f' * ppo_step_per_epoch {args.ppo_step_per_epoch} * train_iters_each_rollout {args.train_iters_each_rollout}'
        )
        # Loop-invariant, so it is checked once up front. It is guarded because an unset
        # (None) save_interval means "don't save", and None % int would raise TypeError.
        if args.save_interval:
            assert args.save_interval % args.train_iters_each_rollout == 0, (
                f'{args.save_interval=} must be a multiple of {args.train_iters_each_rollout=}')

        epoch_iter = range(epoch, args.ppo_max_epochs)
        if len(epoch_iter) <= 0:
            # Nothing left to train; callers still expect the 2-tuple.
            return iteration, num_floating_point_operations_so_far

        # Timers reported after every ppo step. Each name appears once: timers.log
        # resets a timer on read, so a duplicate name would report a spurious 0.
        log_metrics_keys = [
            'generate',
            'hook_after_sampling',
            'scoring',
            'gen_rm_scoring',
            'compute_logps',
            'result_future',
            'filter_sampling',
            'rm_critic_future',
            'hook_before_computing_metrics',
            'compute_global_rollout_metrics',
            'generate_ppo_data',
            'rm_critic_client_train',
            'make_rollout_dataloader_iter',
            'train_step',
            'update_sampler_weights',
            'eval',
            'ppo_step_time_cost',
        ]

        wandb_writer = get_wandb_writer()
        for _ in epoch_iter:
            # The first epoch may be partial when resuming mid-epoch.
            num_steps_in_epoch = args.ppo_step_per_epoch - ppo_step % args.ppo_step_per_epoch
            loop_iter = range(num_steps_in_epoch)
            if not loop_iter:
                return iteration, num_floating_point_operations_so_far

            perf_ctx = nullcontext()
            with perf_ctx:
                global_pbar = tqdm(
                    loop_iter,
                    initial=ppo_step,
                    total=args.ppo_max_epochs * args.ppo_step_per_epoch,
                    leave=True,
                    position=0,
                    desc="PPO actor global step",
                    disable=torch.distributed.get_rank() != 0 or args.ppo_disable_tqdm,
                )
                for _ in global_pbar:
                    timers('ppo_step_time_cost',
                           log_level=0).start(barrier=True)
                    # Rollout generation.
                    clear_memory()
                    rollout_batches, rollout_metrics = self.generate_rollouts(
                        rollout_get_batch_func,
                        filter_samplings_func,
                        actor_model,
                        rm_critic_client,
                        gen_rm_client,
                        sampler_client,
                        train_data_iterator,
                        ppo_step,
                        num_rollout_micro_batches,
                    )

                    # Critic training (GRPO has no critic).
                    if not args.use_grpo:
                        timers('rm_critic_client_train',
                               log_level=0).start(barrier=True)
                        clear_memory()
                        rm_critic_client.train(rollout_batches)
                        timers('rm_critic_client_train').stop()

                    # Actor training.
                    onload_megatron_optimizer(optimizer)
                    torch.distributed.barrier()

                    clear_memory()
                    torch.distributed.barrier()
                    print_with_rank_and_datetime('train', rank=0)
                    ex_rollout_batches = expand_rollout_batches(
                        rollout_batches)
                    total_samples = args.ppo_rollout_global_batch_size * args.ppo_sampling_keep
                    assert len(ex_rollout_batches) * mpu.get_data_parallel_world_size(
                    ) == total_samples, f"{len(ex_rollout_batches)} != {total_samples}"

                    if args.ppo_smart_pad_train:
                        num_global_batch = args.ppo_rollout_global_batch_size * \
                            args.ppo_sampling_keep // args.global_batch_size
                        train_batch_size_per_dp = total_samples // num_global_batch // args.ppo_actor_data_parallel_size
                        print_with_rank_and_datetime(
                            f"[SMART_PAD_TRAIN] {args.ppo_rollout_global_batch_size=} {args.ppo_sampling_keep=} {args.global_batch_size=} {num_global_batch=} {train_batch_size_per_dp=}", rank=0)
                        reorder_ex_rollout_batches = smart_pad_train_get_reorder_rollout_batches(
                            ex_rollout_batches, num_global_batch, train_batch_size_per_dp,
                            args.ppo_rollout_pad_to_multiple_of, ppo_step
                        )
                        ex_rollout_batches = reorder_ex_rollout_batches

                    for _ in range(args.ppo_max_epochs_2):
                        # Split the rollouts into one chunk per global batch.
                        timers('make_rollout_dataloader_iter',
                               log_level=0).start(barrier=True)
                        rollout_dataloader_iter = get_iterator_k_split_list(
                            ex_rollout_batches, args.ppo_rollout_global_batch_size *
                            args.ppo_sampling_keep // args.global_batch_size)
                        timers('make_rollout_dataloader_iter').stop()

                        timers('train_step', log_level=0).start(barrier=True)
                        iteration = self.train_step(
                            config,
                            actor_model,
                            optimizer,
                            opt_param_scheduler,
                            rollout_dataloader_iter,
                            ppo_step,
                            args.train_iters_each_rollout,
                            iteration,
                            total_loss_dict,
                            rollout_metrics,
                        )
                        timers('train_step').stop()

                    ppo_step += 1
                    ppo_step_since_recover += 1

                    # Periodic reference-model refresh.
                    if args.ppo_update_ref_w_actor_interval > 0 and ppo_step % args.ppo_update_ref_w_actor_interval == 0:
                        print_with_rank_and_datetime(
                            'update_ref_with_actor constantly', rank=0)
                        actor_model.update_ref_with_actor(
                            args.ppo_update_ref_w_actor_coef)

                    # Checkpointing.
                    if args.save and args.save_interval and ppo_step % args.ppo_step_save_interval == 0:
                        save_actor_and_critic_ckpt(iteration, actor_model, rm_critic_client,
                                                   sampler_client, optimizer, opt_param_scheduler,
                                                   num_floating_point_operations_so_far,
                                                   checkpointing_context)

                    # Free GPU memory before the sampler weight update.
                    offload_megatron_optimizer(optimizer)
                    if args.ppo_early_swap_model:
                        self.update_cpu_memory_model(
                            unwrap_model(actor_model.model)[0])
                        offload_megatron_model(actor_model.model)
                    torch.distributed.barrier()

                    # Push new actor weights to the standalone sampler.
                    timers('update_sampler_weights',
                           log_level=0).start(barrier=True)
                    u_intv = args.ppo_step_update_sampler_interval
                    if args.ppo_standalone_sampler and u_intv > 0 and ppo_step_since_recover % u_intv == 0:
                        print_with_rank_and_datetime(
                            'update_sampler_weights', rank=0)
                        sampler_client.update_weights(actor_model.model,
                                                      early_swap_model=args.ppo_early_swap_model,
                                                      cpu_memory_model=self.cpu_memory_model)
                    timers('update_sampler_weights').stop()

                    if not args.ppo_early_swap_model:
                        offload_megatron_model(actor_model.model)
                        torch.distributed.barrier()

                    if args.ppo_wecube_report and is_mp_and_cp_head():
                        report_data = {'timer_report_times': 1}
                        for report_key in [
                                'generate', 'compute_logps', 'scoring', 'rm_critic_client_train',
                                'train_step', 'update_sampler_weights', 'ppo_step_time_cost',
                        ]:
                            report_data['timer_' + report_key] = timers(
                                report_key, log_level=0).elapsed(reset=False)
                        report_ppo_metrics(report_data)

                    # Periodic evaluation.
                    timers('eval', log_level=0).start(barrier=True)
                    if args.ppo_step_eval_interval > 0 and ppo_step % args.ppo_step_eval_interval == 0:
                        print_with_rank_and_datetime(
                            f'eval at ppo_step={ppo_step - 1}', rank=0)
                        eval_metrics = self.eval_loop(
                            actor_model=actor_model,
                            sampler_client=sampler_client,
                            rm_critic_client=rm_critic_client,
                            gen_rm_client=gen_rm_client,
                            rollout_get_batch_func=rollout_get_batch_func,
                            eval_dataloader_iter=valid_data_iterator,
                            iteration=iteration,
                        )
                        print_with_rank_and_datetime(
                            f'eval_metrics: {dict(eval_metrics)}', rank=0)
                    timers('eval').stop()

                    timers('ppo_step_time_cost').stop()

                    # wandb reads without reset; timers.log below resets the timers.
                    if wandb_writer:
                        for metrics_key in log_metrics_keys:
                            wandb_writer.log({f"train_s/{metrics_key}": timers(
                                metrics_key, log_level=0).elapsed(reset=False)}, iteration)
                    timers.log(log_metrics_keys, barrier=True)
                    print_meminfo_str("Track CPU Memory finish ppo step: ")

        # Flush TensorBoard writer.
        writer = get_tensorboard_writer()
        if writer:
            writer.flush()

        # Close out pre-hooks if using distributed optimizer and overlapped param gather.
        if args.use_distributed_optimizer and args.overlap_param_gather:
            optimizer.disable_pre_hook()

        if should_exit:
            if wandb_writer:
                wandb_writer.finish()
            sys.exit()

        return iteration, num_floating_point_operations_so_far

    def eval_loop(
        self,
        actor_model,
        sampler_client,
        rm_critic_client,
        gen_rm_client,
        rollout_get_batch_func,
        eval_dataloader_iter,
        iteration,
    ):
        '''
        Run one evaluation pass and log the averaged eval metrics.

        Called once every ``args.ppo_step_eval_interval`` PPO steps. Each call:
            1. consumes ``args.ppo_eval_steps`` eval mini-batches, each of
               ``args.ppo_eval_rollout_global_batch_size`` samples, through
               ``eval_run_inference``;
            2. averages the metrics: each step's metrics (for the default
               metrics, a sample-wise mean within that step's batches) are
               summed and divided by ``args.ppo_eval_steps``, so the rollout
               batches themselves never have to be kept. This mean of means
               suits metrics such as sequence_lengths / rewards; variance-like
               or custom business metrics may need to override ``eval_loop``;
            3. logs the metrics (``eval_logging``) and the eval timers.

        Parameters
        ----------
        actor_model
            Actor wrapper; ``prepare_for_inference`` / ``finish_inference``
            bracket the whole pass.
        sampler_client, rm_critic_client, gen_rm_client
            Forwarded to ``eval_run_inference``.
        rollout_get_batch_func : Callable
            Forwarded to ``eval_run_inference``.
        eval_dataloader_iter
            Eval data iterator, forwarded to ``eval_run_inference``.
        iteration : int
            Training iteration used as the logging step.

        Returns
        -------
        collections.defaultdict
            Metric name (``'eval-'``-prefixed by ``eval_run_inference``) ->
            value averaged over ``args.ppo_eval_steps`` steps.

        Raises
        ------
        ValueError
            If ``args.ppo_eval_steps <= 0``.
        '''
        args = get_args()
        timers = get_timers()

        if args.ppo_eval_steps <= 0:
            raise ValueError(
                f'Trying to eval but got {args.ppo_eval_steps=}, should be > 0.')

        actor_model.prepare_for_inference()

        dp_size = parallel_state.get_data_parallel_world_size()
        sampler_dp_size = args.ppo_sampler_data_parallel_size
        eval_rollout_gbs = args.ppo_eval_rollout_global_batch_size
        eval_rollout_mbs = args.ppo_eval_rollout_micro_batch_size
        assert eval_rollout_gbs % dp_size == 0, f'{eval_rollout_gbs=} % {dp_size=} != 0'
        assert eval_rollout_gbs % sampler_dp_size == 0, f'{eval_rollout_gbs=} % {sampler_dp_size=} != 0'
        # Rollout micro batches each data-parallel rank produces per eval step.
        eval_num_rollout_microbatches = divide(
            eval_rollout_gbs, eval_rollout_mbs * dp_size)

        global_eval_rollout_metrics = defaultdict(lambda: 0)
        global_pbar = tqdm(
            range(args.ppo_eval_steps),
            leave=True,
            position=0,
            desc='PPO eval step',
            disable=torch.distributed.get_rank() != 0 or args.ppo_disable_tqdm,
        )
        for ppo_eval_step in global_pbar:
            clear_memory()

            # Generate + score one eval global batch and compute its metrics.
            eval_rollout_batches, eval_rollout_metrics = self.eval_run_inference(
                actor_model,
                sampler_client,
                rm_critic_client,
                gen_rm_client,
                ppo_eval_step,
                eval_num_rollout_microbatches,
                rollout_get_batch_func,
                eval_dataloader_iter,
            )
            # Only the metrics are needed from here on. Drop the batches now so
            # they are not kept alive through the next step's clear_memory()
            # and generation.
            del eval_rollout_batches
            for key, value in eval_rollout_metrics.items():
                global_eval_rollout_metrics[key] += value

        # Mean over eval steps of the per-step metrics.
        for key in global_eval_rollout_metrics:
            global_eval_rollout_metrics[key] /= args.ppo_eval_steps

        self.eval_logging(global_eval_rollout_metrics, iteration)

        log_metrics_keys = [
            "eval_generate",
            "eval_hook_after_sampling",
            "eval_scoring",
            "eval_gen_rm_scoring",
            "eval_update_sampler_weights",
            "broadcast_eval_rollout_batches",
            "eval_hook_before_computing_metrics",
            "compute_global_eval_rollout_metrics",
        ]

        wandb_writer = get_wandb_writer()
        if wandb_writer:
            for metrics_key in log_metrics_keys:
                elapsed = timers(
                    metrics_key, log_level=0).elapsed(reset=False)
                # Timers that never ran in this pass report 0; skip them.
                if elapsed > 0:
                    wandb_writer.log(
                        {f"train_s/{metrics_key}": elapsed}, iteration)
        timers.log(log_metrics_keys, barrier=True)

        actor_model.finish_inference()

        return global_eval_rollout_metrics

    def eval_run_inference(
        self,
        actor_model,
        sampler_client,
        rm_critic_client,
        gen_rm_client,
        ppo_eval_step,
        eval_num_rollout_microbatches,
        rollout_get_batch_func,
        eval_dataloader_iter,
    ) -> Tuple[List[Dict[str, List[Any]]], Dict[str, float]]:
        '''
        Run one eval step: generate and score one eval global batch, then compute its metrics.

        Steps, in this exact order (barriers and timers started with
        ``barrier=True`` are collective, so every rank must run them):
            1. ``self.get_rollout_batches(..., is_eval=True)``, between two barriers.
            2. sglang only: push the actor weights to the sampler again via
               ``sampler_client.update_weights``; unless
               ``args.ppo_early_swap_model`` is set, the Megatron model is
               onloaded before and offloaded after the transfer.
            3. On ranks where ``is_mp_and_cp_head()`` is true, validate the
               batches; then broadcast them with
               ``sampler_client.broadcast_rollout_batch`` and re-attach the
               fields removed before sampling
               (``add_back_rollout_attr_before_sampling``).
            4. ``hook_before_computing_metrics(..., is_eval=True)``.
            5. ``compute_global_rollout_metrics`` (moved to CPU with
               ``cpu_dict``); every key is prefixed with ``'eval-'``.
            6. ``display_rollout_generation`` on the final batches.

        Parameters
        ----------
        actor_model
            Actor wrapper; ``actor_model.model`` is what is sent to the sampler.
        sampler_client, rm_critic_client, gen_rm_client
            Forwarded to ``get_rollout_batches``; ``rm_critic_client`` and
            ``gen_rm_client`` are ``None`` when train_ppo_actor_v3 runs without
            those services.
        ppo_eval_step : int
            Index of this step within the current eval pass.
        eval_num_rollout_microbatches : int
            Number of rollout micro batches to generate for this step
            (eval_loop passes ``gbs / (mbs * dp_size)``).
        rollout_get_batch_func : Callable
            Forwarded to ``get_rollout_batches``.
        eval_dataloader_iter
            Eval data iterator, forwarded to ``get_rollout_batches``.

        Returns
        -------
        Tuple[List[Dict[str, List[Any]]], Dict[str, float]]
            The processed eval rollout batches and the ``'eval-'``-prefixed metrics.
        '''
        args = get_args()
        timers = get_timers()

        print_with_rank_and_datetime(
            f"eval_run_inference {eval_num_rollout_microbatches=}", rank=0)
        torch.distributed.barrier()
        eval_rollout_batches = self.get_rollout_batches(
            sampler_client,
            rm_critic_client,
            gen_rm_client,
            ppo_eval_step,
            eval_num_rollout_microbatches,
            rollout_get_batch_func,
            eval_dataloader_iter,
            is_eval=True,
        )
        torch.distributed.barrier()

        # sglang only: push the actor weights to the sampler again after eval
        # generation. NOTE(review): why vllm does not need this is not visible here.
        if args.infer_engine_impl == 'sglang':
            if not args.ppo_early_swap_model:
                onload_megatron_model(actor_model.model)
                torch.distributed.barrier()

            timers('eval_update_sampler_weights',
                   log_level=0).start(barrier=True)
            print_with_rank_and_datetime('eval_update_sampler_weights', rank=0)
            sampler_client.update_weights(
                actor_model.model,
                early_swap_model=args.ppo_early_swap_model,
                cpu_memory_model=self.cpu_memory_model,
            )
            timers('eval_update_sampler_weights').stop()

            if not args.ppo_early_swap_model:
                offload_megatron_model(actor_model.model)
                torch.distributed.barrier()

        timers('broadcast_eval_rollout_batches',
               log_level=0).start(barrier=True)
        # Validate only on mp/cp head ranks (presumably the only ranks holding
        # the generated batches before the broadcast below).
        if is_mp_and_cp_head():
            assert check_rollout_batches(
                eval_rollout_batches), f'eval_rollout_batches fmt error: {eval_rollout_batches=}'
            # Each batch must carry the fields the eval metrics rely on.
            for eval_rollout_batch in eval_rollout_batches:
                assert (
                    len(eval_rollout_batch) >= 3
                    and "tokens" in eval_rollout_batch
                    and "prompt_lengths" in eval_rollout_batch
                    and "sequence_lengths" in eval_rollout_batch
                    and "rewards" in eval_rollout_batch
                )

        eval_rollout_batches = sampler_client.broadcast_rollout_batch(
            eval_rollout_batches)
        # Release memory left over from the broadcast before re-attaching fields.
        clear_memory()

        # Put back the fields stripped by remove_rollout_attr_before_sampling.
        eval_rollout_batches = self.add_back_rollout_attr_before_sampling(
            eval_rollout_batches)
        assert check_rollout_batches(
            eval_rollout_batches), f'eval_rollout_batches fmt error: {eval_rollout_batches=}'
        timers('broadcast_eval_rollout_batches').stop()
        clear_memory()

        timers('eval_hook_before_computing_metrics',
               log_level=0).start(barrier=True)
        eval_rollout_batches = self.hook_before_computing_metrics(
            eval_rollout_batches, ppo_eval_step, is_eval=True)
        timers('eval_hook_before_computing_metrics').stop()

        timers('compute_global_eval_rollout_metrics',
               log_level=0).start(barrier=True)
        _metrics = cpu_dict(
            self.compute_global_rollout_metrics(eval_rollout_batches))
        # Prefix every metric so eval values are logged separately from train ones.
        eval_rollout_metrics = {}
        for key, value in _metrics.items():
            eval_key = 'eval-' + key
            eval_rollout_metrics[eval_key] = value
        timers('compute_global_eval_rollout_metrics').stop()

        self.display_rollout_generation(eval_rollout_batches)
        return eval_rollout_batches, eval_rollout_metrics

    def eval_logging(self, metrics, iteration):
        """
        Write eval metrics to whichever of TensorBoard / wandb is configured.

        For every ``(name, value)`` in ``metrics``:
          * TensorBoard gets ``name`` at step ``iteration`` and
            ``name + ' vs samples'`` at step ``args.consumed_train_samples``;
          * wandb gets ``{name: value}`` at step ``iteration``.
        """
        args = get_args()
        tb_writer = get_tensorboard_writer()
        wb_writer = get_wandb_writer()

        for name, val in metrics.items():
            if tb_writer:
                tb_writer.add_scalar(name, val, iteration)
                samples_tag = name + ' vs samples'
                tb_writer.add_scalar(samples_tag, val,
                                     args.consumed_train_samples)
            if wb_writer:
                wb_writer.log({name: val}, iteration)

                                        
    def rollout_get_from_replay_queue(self, rb):
        """
        Get rollout data from the replay queue; ``rb`` is the original rejected samples.

        Abstract hook: subclasses must override it. Raises instead of using
        ``assert False`` so the guard still fires under ``python -O``.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError("Inherit and override it")

                                                                    
                                                                          
                                                
    def get_replay_samples(self, rb):
        """
        Extract the samples of rollout batch ``rb`` marked as not useful, for replay.

        ``rb`` maps field name -> per-sample sequence (indexable by sample).
        For every index ``i`` whose ``rb['sample_useful'][i]`` is falsy, a new
        dict is built that holds a deep copy of ``rb[k][i]`` for every key
        except ``'ready'``. Its ``'response_tokens'`` entry is then truncated to
        the first ``rb['trunc_index'][i]`` items.

        Parameters
        ----------
        rb : dict
            Rollout batch; must contain ``'sample_useful'`` and
            ``'trunc_index'``, and ``'response_tokens'`` whenever any sample is
            not useful.

        Returns
        -------
        list[dict]
            One dict per not-useful sample, in batch order (empty if every
            sample is useful).
        """
        assert 'sample_useful' in rb
        assert 'trunc_index' in rb

        samples = []
        for i, useful in enumerate(rb['sample_useful']):
            if useful:
                continue
            # Deep copies, so a replayed sample shares no mutable state with ``rb``.
            sample_not_useful = {
                k: copy.deepcopy(v[i]) for k, v in rb.items() if k != 'ready'
            }
            sample_not_useful['response_tokens'] = (
                sample_not_useful['response_tokens'][:rb['trunc_index'][i]])
            samples.append(sample_not_useful)
        return samples

                                                     
                              
    def replay_samples(
        self,
        rejected_sampling_num,
        max_ppo_step,
        sampler_client,
        rollout_get_batch_func,
        train_data_iterator,
    ):
        """
        Issue ``rejected_sampling_num`` sampling requests to refill rejected slots.

        Replayed samples go first. ``self.is_time_to_replay_samples`` returns one
        flag per entry of ``self.replay_queue``, and entries are visited in
        queue order. Each flagged sample is handled as follows:
          * it is passed to ``self.replay_rollout_batch``;
          * it is re-issued with its stored ``prompt_data`` and
            ``retry_i=replay_times``;
          * it is removed from the queue.

        Once enough requests have been issued, the rest of the queue is kept
        untouched. Any slots still left are filled with fresh prompts from
        ``rollout_get_batch_func(train_data_iterator)``. These are issued with
        ``retry_i=0`` under new indices taken from ``self.sample_idx``.

        Parameters
        ----------
        rejected_sampling_num : int
            Number of sampling requests to issue; must be > 0.
        max_ppo_step : int
            PPO step passed to ``sampler_client.issue_sampling``.
        sampler_client
            Client whose ``issue_sampling`` sends the requests.
        rollout_get_batch_func : Callable
            Called as ``rollout_get_batch_func(train_data_iterator)`` to fetch a
            new prompt batch.
        train_data_iterator
            Iterator fresh prompts are drawn from.

        Returns
        -------
        list
            The ``issue_sampling`` responses, replayed samples first; exactly
            ``rejected_sampling_num`` of them.
        """
        assert rejected_sampling_num > 0
        replay_queue_size = len(self.replay_queue)

        replay_sampling_resps = []

        is_replay_now = self.is_time_to_replay_samples(
            self.replay_queue, self.replay_samples_dict)
        assert len(is_replay_now) == replay_queue_size
        # Indices that stay queued: those not selected now, plus everything
        # after the point where all requested slots have been filled.
        not_popped_queue = []
        for replay_i, (pop_flag, cur_sample_idx) in enumerate(zip(is_replay_now, self.replay_queue)):
            if not pop_flag:
                not_popped_queue.append(cur_sample_idx)
                continue

            sample_dict = self.replay_samples_dict[cur_sample_idx]
            self.replay_rollout_batch(sample_dict.prompt_data)
            # Log only every 5th replay to limit output.
            if replay_i % 5 == 0:
                print_with_rank_and_datetime(f"replaying sample {cur_sample_idx=} retry_i {sample_dict.replay_times} "
                                             f"ppo_step {max_ppo_step=} {rejected_sampling_num=}")
            resp = sampler_client.issue_sampling(max_ppo_step,
                                                 cur_sample_idx,
                                                 prompt_data=sample_dict.prompt_data,
                                                 retry_i=sample_dict.replay_times)
            replay_sampling_resps.append(resp)
            rejected_sampling_num -= 1
            if rejected_sampling_num == 0:
                not_popped_queue.extend(self.replay_queue[replay_i + 1:])
                break

        self.replay_queue = not_popped_queue
        print_with_rank_and_datetime(f"not_poped_data {self.replay_queue}")

        # Fill the remaining slots with fresh prompts. range() is evaluated once,
        # so decrementing the counter inside the loop only affects the log line.
        for new_i in range(rejected_sampling_num):
            prompt_data = rollout_get_batch_func(train_data_iterator)
            prompt_data = self.remove_rollout_attr_before_sampling(prompt_data)
            if new_i % 5 == 0:
                print_with_rank_and_datetime(f"sampling new call {self.sample_idx=} retry_i 0 "
                                             f"ppo_step {max_ppo_step=} {rejected_sampling_num=}")
            resp = sampler_client.issue_sampling(max_ppo_step,
                                                 self.sample_idx,
                                                 prompt_data=prompt_data,
                                                 retry_i=0)
            replay_sampling_resps.append(resp)
            self.sample_idx += 1
            rejected_sampling_num -= 1

        return replay_sampling_resps

                              
    def is_rollout_batch_accepted(self, rb):
        """
        Decide whether rollout batch ``rb`` is accepted.

        Extension point for subclasses; the base implementation accepts every batch.

        Returns
        -------
        bool
            Always ``True`` here.
        """
        accepted = True
        return accepted

                              
    def update_replay_samples_dict(self, rb, sample_idx):
        """
        Record rollout batch ``rb`` (sample index ``sample_idx``) for a later replay.

        Extension point for subclasses; the base implementation does nothing.

        Returns
        -------
        None
        """
        return None

                              
    def is_time_to_replay_samples(self, replay_queue, replay_samples_dict):
        """
        Decide, per queued sample, whether it should be replayed now.

        Abstract hook: subclasses must override it. Raises instead of using
        ``assert False`` so the guard still fires under ``python -O``; the
        original code would then have silently replayed every sample.

        Parameters
        ----------
        replay_queue : list
            Queued sample indices (``sample_idx``).
        replay_samples_dict : dict
            ``{sample_idx: sample_data}`` for the queued samples.

        Returns
        -------
        list[bool]
            One flag per entry of ``replay_queue`` (in an overriding implementation).

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError("你继承并重写这个函数")

                                 
    def hook_after_sampling(
        self, rollout_batches: List[Dict[str, List[Any]]]
    ) -> List[Dict[str, List[Any]]]:
        """
        Post-process rollout batches right after sampling.

        Extension point for subclasses; the base implementation hands the
        input back unchanged (the very same object).

        Parameters
        ----------
        rollout_batches : List[Dict[str, List[Any]]]
            Rollout batches holding the sampling results.

        Returns
        -------
        List[Dict[str, List[Any]]]
            The processed rollout batches (here: ``rollout_batches`` itself).
        """
        processed = rollout_batches
        return processed

                                           
    def hook_before_computing_metrics(
        self,
        rollout_batches: List[Dict[str, List[Any]]],
        ppo_step: int,
        is_eval: bool = False,
    ) -> List[Dict[str, List[Any]]]:
        """
        Post-process rollout batches just before metrics are computed.

        Extension point for subclasses; the base implementation hands the
        input back unchanged (the very same object).

        Parameters
        ----------
        rollout_batches : List[Dict[str, List[Any]]]
            Rollout batches at this point, which may hold sampling results,
            logprobs, fields added in ``hook_after_sampling``, etc.
        ppo_step : int
            PPO step of the train or eval pass.
        is_eval : bool
            Whether this call comes from evaluation.

        Returns
        -------
        List[Dict[str, List[Any]]]
            The processed rollout batches (here: ``rollout_batches`` itself).
        """
        processed = rollout_batches
        return processed


class MultiModalPpoActorTrainer(PPOActorTrainerV3):
    '''
    PPO actor trainer that keeps selected per-sample fields out of sampler requests.

    Before a prompt batch is sent to the samplers, the fields named in its
    ``cache_keys`` entry are moved into ``self.mm_data_cache``, keyed by the
    batch's ``unique_id``. After sampling they are appended back onto the
    rollout batches. Entries that were re-attached are tracked in
    ``self.mm_data_used`` and freed by ``clear_data_cache``. Presumably the
    cached fields are the large multi-modal inputs; confirm against the
    dataset that fills ``cache_keys``.
    '''

    def __init__(self, extra_metric_info=None):
        '''
        Parameters
        ----------
        extra_metric_info : list[dict[str, Any]]
            extra metrics to export to wandb.

            .. highlight:: python
            .. code-block:: python

                [
                  {'key_name': 'acc_rewards', 'dtype': torch.float32},
                  {'key_name': 'fmt_rewards', 'dtype': torch.float32},
                ]

        '''
        super().__init__(extra_metric_info=extra_metric_info)
        # unique_id -> {field name: value removed from the prompt batch}
        self.mm_data_cache: Dict[str, Dict[str, Any]] = {}
        # unique_ids whose cached fields have been re-attached and may be freed
        self.mm_data_used: Set[str] = set()

    @override
    def remove_rollout_attr_before_sampling(self, rollout_batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Remove cached fields from a prompt batch before it is sent to the samplers (to avoid OOM).

        The batch's ``"cache_keys"`` entry is popped. Every field it names is
        then moved from the batch into ``self.mm_data_cache[unique_id]``, where
        ``unique_id = rollout_batch["unique_id"][0]``. The batch is modified in
        place.

        Parameters
        ----------
        rollout_batch : Dict[str, Any]
            Prompt batch; must contain ``"unique_id"`` and ``"cache_keys"``.

        Returns
        -------
        Dict[str, Any]
            The same batch, without ``"cache_keys"`` and the cached fields.
        """
        assert "unique_id" in rollout_batch
        assert "cache_keys" in rollout_batch

        # Only the first id is used; presumably one sample per batch
        # (train_ppo_actor_v3 asserts ppo_rollout_micro_batch_size == 1).
        unique_id = rollout_batch["unique_id"][0]
        # Check for a duplicate id before touching the batch, so a failed
        # check leaves it intact.
        assert unique_id not in self.mm_data_cache, f"duplicate {unique_id=} in mm_data_cache"

        cache_keys = rollout_batch.pop("cache_keys")
        cached = {}
        for key in cache_keys:
            assert key in rollout_batch, f"cache key {key!r} missing from rollout batch"
            cached[key] = rollout_batch.pop(key)
        # Register the entry only once every field has been moved.
        self.mm_data_cache[unique_id] = cached

        return rollout_batch

    @override
    def add_back_rollout_attr_before_sampling(
        self,
        rollout_batches: List[Dict[str, List[Any]]],
    ) -> List[Dict[str, List[Any]]]:
        """
        Re-attach the fields removed by ``remove_rollout_attr_before_sampling``.

        ``self.mm_data_cache`` is first replaced by the result of
        ``broadcast_object_within_mp_and_cp``, presumably so that every rank of
        the model/context-parallel group holds the cache. Then, for each
        rollout batch:
          * ``"unique_id"`` is popped;
          * for every id in it, each cached field is appended to
            ``rollout_batch[field]``, creating the list if missing;
          * the id is added to ``self.mm_data_used`` so ``clear_data_cache``
            can free it later.

        Parameters
        ----------
        rollout_batches : List[Dict[str, List[Any]]]
            Rollout batches containing sampling results; each must contain
            ``"unique_id"``.

        Returns
        -------
        List[Dict[str, List[Any]]]
            The same list; the batches are modified in place.
        """
        self.mm_data_cache = broadcast_object_within_mp_and_cp(
            self.mm_data_cache)
        for rollout_batch in rollout_batches:
            assert "unique_id" in rollout_batch
            uniq_ids = rollout_batch.pop("unique_id")
            for unique_id in uniq_ids:
                assert unique_id in self.mm_data_cache, f"error: {unique_id=} {self.mm_data_cache.keys()=}"
                self.mm_data_used.add(unique_id)
                for k, v in self.mm_data_cache[unique_id].items():
                    rollout_batch.setdefault(k, []).append(v)

        return rollout_batches

    @override
    def replay_rollout_batch(self, rollout_batch):
        """
        Keep the cached fields of a batch that is about to be replayed.

        The batch's id is removed from ``self.mm_data_used``, so
        ``clear_data_cache`` does not free its entry before the replayed
        sample's fields are added back again.
        """
        assert "unique_id" in rollout_batch
        self.mm_data_used.discard(rollout_batch["unique_id"][0])

    @override
    def clear_data_cache(self):
        """Free the cache entries of every id recorded in ``self.mm_data_used``."""
        # Iterate over a snapshot because the set is modified in the loop.
        for unique_id in list(self.mm_data_used):
            del self.mm_data_cache[unique_id]
            self.mm_data_used.remove(unique_id)


def train_ppo_actor_v3(
    actor_trainer,
    model_provider,
    actor_provider,
    sampler_client_provider,
    rm_critic_client_provider,
    gen_rm_client_provider,
    train_valid_test_datasets_provider,
    rollout_get_batch_func,
    filter_samplings_func,
    model_type,
    process_non_loss_data_func=None,
    extra_args_provider=None,
    args_defaults=None,
    store=None,
):
    '''
    Train a GRPO / PPO actor.

    ``model_provider`` and the ``*_func`` arguments are plain hooks kept for
    historical compatibility; newer extension points use the object-oriented
    API, i.e. overridable methods on ``actor_trainer``.

    Parameters
    ----------
    actor_trainer : gpatch.training.v3.ppo_actor.PPOActorTrainerV3
        Actor trainer; its ``train_loop`` runs the training.
    model_provider : Callable
        A callable to provide megatron.GPTModel.
    actor_provider : Callable
        A callable to provide the GptPpoActorModel wrapper for the actor model.
    sampler_client_provider : Callable
        A callable to provide GptPpoSamplerClientV3.
    rm_critic_client_provider : Callable
        A callable to provide GptPpoRmCriticClientV3; only called when
        ``args.use_rm_and_critic`` is set.
    gen_rm_client_provider : Callable
        A callable to provide GptPpoGenRmClientV3; only called when
        ``args.use_gen_rm`` is set.
    train_valid_test_datasets_provider : Callable
        A callable to provide torch.utils.data.Dataset or DataLoader objects;
        passed to ``build_train_valid_test_data_iterators``.
    rollout_get_batch_func : Callable
        Forwarded to ``actor_trainer.train_loop``; called as
        ``rollout_get_batch_func(data_iterator)`` to fetch the next prompt
        batch that is sent to the vllm / sglang sampler.
    filter_samplings_func : Callable
        Forwarded to ``actor_trainer.train_loop``.
    model_type : megatron.core.enums.ModelType
        megatron model type, use ModelType.encoder_or_decoder for causal model.
    process_non_loss_data_func : Callable, optional
        Forwarded to ``actor_trainer.train_loop``.
    extra_args_provider : Callable, optional
        Passed to ``initialize_megatron``.
    args_defaults : dict, optional
        Passed to ``initialize_megatron``; ``None`` means ``{}``.
    store : optional
        Passed to ``initialize_megatron`` as ``store`` only when megatron core
        is >= 0.13.0; ignored otherwise.

    Returns
    -------
    None
        If ``train_loop`` raises (including ``SystemExit``), the traceback is
        printed, the monitor exit flag is set when ``args.do_monitor`` is
        enabled, and the function returns early without the final save.
    '''
    # ``None`` instead of a shared mutable ``{}`` default; an explicit dict still works.
    if args_defaults is None:
        args_defaults = {}

    print_meminfo_str("Track CPU Memory at starting:")
    # Only forward ``store`` to megatron core >= 0.13.0 (presumably older
    # initialize_megatron versions do not accept it).
    mcore_version = Version(package_info.__version__)
    extra_args = {}
    if mcore_version >= Version("0.13.0"):
        extra_args = {"store": store}
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults,
                        **extra_args)
    args = get_args()
    validate_rl_args(args)
    args.rl_role = 'actor'
    parse_dataset_config(args)
    set_global_variables(args)
    init_pg(distributed_timeout_minutes=args.distributed_timeout_minutes)

    set_jit_fusion_options()
    if mcore_version >= Version("0.13.0"):
        energy_monitor = get_energy_monitor()
        energy_monitor.setup()

    # Choose between a fresh start and a resume from ``args.save``.
    # - If there is no latest_checkpointed_iteration.txt there, or it says
    #   'release': run as a finetune without optimizer / RNG state, leaving
    #   ``args.load`` as is.
    # - Otherwise: resume everything from ``args.save``.
    if args.auto_set_finetune_arg:
        txt = os.path.join(args.save, 'latest_checkpointed_iteration.txt')
        start_from_release = True
        if os.path.exists(txt):
            # Context manager so the tracker file handle is always closed.
            with open(txt, 'r') as f:
                start_from_release = f.read(7).strip() == 'release'
        if start_from_release:
            args.finetune = True
            args.no_load_optim = True
            args.no_load_rng = True
            print_with_rank_and_datetime(
                f'auto_set_finetune_arg found release {args.finetune=} {args.no_load_optim=} {args.no_load_rng=} {args.load=}',
                rank=0
            )
        else:
            args.finetune = False
            args.no_load_optim = False
            args.no_load_rng = False
            args.load = args.save
            print_with_rank_and_datetime(
                f'auto_set_finetune_arg found resuming {args.finetune=} {args.no_load_optim=} {args.no_load_rng=} {args.load=}',
                rank=0
            )

    cpu_barrier()
    # Progress monitor: server on global rank 0, plus a client on every rank
    # that is a multiple of num_gpus_per_node (presumably one per node).
    if args.do_monitor:
        if torch.distributed.get_rank() == 0:
            start_monitor_server_in_background(
                args.monitor_server_ip, args.monitor_port)
        if torch.distributed.get_rank() % args.num_gpus_per_node == 0:
            start_monitor_client_in_background(
                args.rl_role, args.monitor_server_ip, args.monitor_port, args.monitor_max_time_wo_progress
            )

    # Use the earliest start time across all ranks (all_reduce MIN), so the
    # reported initialization time is measured from the first rank to start.
    global _TRAIN_START_TIME
    start_time_tensor = torch.tensor(
        [_TRAIN_START_TIME], dtype=torch.double, device='cuda')
    torch.distributed.all_reduce(
        start_time_tensor, op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(time.time() -
                                                                        _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    if args.ppo_wecube_report:
        init_wecube_reporter()

    # Silence per-request INFO logs from the HTTP client.
    httpx_logger = logging.getLogger("httpx")
    httpx_logger.setLevel(logging.WARNING)

    if args.num_experts is not None and args.num_experts > 1:
        # MoE: expert TP must be 1 or equal to the dense TP size.
        assert mpu.get_expert_tensor_parallel_world_size() == 1 or \
            mpu.get_expert_tensor_parallel_world_size(
        ) == mpu.get_tensor_model_parallel_world_size()

    assert args.ppo_rollout_micro_batch_size == 1, f"set --ppo-rollout-micro-batch-size 1"
    assert args.ppo_eval_rollout_micro_batch_size == 1, f"set --ppo-eval-rollout-micro-batch-size 1"
    hf_model_config = AutoConfig.from_pretrained(
        os.path.dirname(args.hf_config_json_path))
    if args.model_arch in ["gemma3"]:
        hf_model_config = hf_model_config.text_config
    args.hf_vocab_size = hf_model_config.vocab_size
    print_rank_0(
        f"config path {args.hf_config_json_path} {args.hf_vocab_size=}")
    timers = get_timers()

    print_rank_0('waiting for sampler to be ready ...')
    sampler_client = sampler_client_provider()
    sampler_client.wait_until_sampler_server_is_ready()
    cpu_barrier()

    # Optional generative reward model service.
    if args.use_gen_rm:
        print_rank_0('waiting for gen rm to be ready ...')
        gen_rm_client = gen_rm_client_provider()
        gen_rm_client.wait_until_gen_rm_server_is_ready()
        cpu_barrier()
    else:
        gen_rm_client = None

    # Optional reward model + critic service.
    if args.use_rm_and_critic:
        print_rank_0('waiting for critic to be ready ...')
        rm_critic_client = rm_critic_client_provider()
        rm_critic_client.wait_until_critic_server_is_ready()
        cpu_barrier()
    else:
        rm_critic_client = None

    print_rank_0('waiting for sampler setup...')
    sampler_client.setup()
    cpu_barrier()

    # Debug-only: exercise sglang generate around sleep / wake-up, then exit.
    if args.ppo_debug_sglang_sleep_wakeup_generate:
        print_rank_0(
            'Debug parkeychen: testing sglang generating before/after sleep and wakeup')
        sampler_client.test_sglang_sleep_wakeup_generate()
        sys.exit(0)

    # Set up the actor -> sampler weight-update group for the chosen engine.
    if sampler_client.infer_engine_impl == "vllm":
        print_with_rank_and_datetime(
            f"call vllm sampler to build stateless group")
        sampler_client.init_weight_update_group(setup_head=True)
    elif sampler_client.infer_engine_impl == "sglang":
        sampler_client.init_weight_update_group(setup_head=False)
    else:
        raise NotImplementedError(
            f"Infer engine {sampler_client.infer_engine_impl} not implemented.")
    cpu_barrier()

    # Debug-only: generate once before any weight update.
    if args.ppo_debug_update_weight:
        sampler_client.test_generate()
        torch.distributed.barrier()

    print_rank_0('waiting for sampler sleep...')
    sampler_client.sleep()
    cpu_barrier()

    if args.use_gen_rm:
        print_rank_0('waiting for gen rm setup...')
        gen_rm_client.setup()
        cpu_barrier()

        print_rank_0('waiting for gen rm sleep...')
        gen_rm_client.sleep()
        cpu_barrier()

    if args.use_rm_and_critic:
        print_rank_0('waiting for critic setup...')
        rm_critic_client.setup()
        cpu_barrier()

        print_rank_0('waiting for critic sleep...')
        rm_critic_client.sleep()
        cpu_barrier()

    print_memory_tracking(
        f"Memory tracking: actor before setup model", verbose=True, rank=0)
    print_meminfo_str("Track CPU Memory before actor setup:")

    # Build the actor model, optimizer and LR scheduler.
    timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
    actor_model, optimizer, opt_param_scheduler = setup_model_and_optimizer_and_ref(
        model_provider, actor_provider, model_type)
    timers('model-and-optimizer-setup').stop()
    print_datetime(
        'after model, optimizer, and learning rate scheduler are built')
    config = get_model_config(actor_model)
    config.hf_vocab_size = args.hf_vocab_size

    print_memory_tracking(
        f"Memory tracking: actor after setup model", verbose=True, rank=0)

    # Overlapped grad reduce / param gather and custom sync hooks are not supported here.
    if optimizer:
        config.grad_scale_func = optimizer.scale_loss
    config.timers = timers
    assert not args.overlap_grad_reduce and not args.overlap_param_gather
    assert config.no_sync_func is None
    assert config.grad_sync_func is None
    assert config.param_sync_func is None
    assert config.finalize_model_grads_func is None

    print_with_rank_and_datetime(
        f"moe model rank tp {mpu.get_tensor_model_parallel_rank()} pp {mpu.get_pipeline_model_parallel_rank()} cp {mpu.get_context_parallel_rank()} dp {mpu.get_data_parallel_rank()}"
        f" dcp {mpu.get_data_parallel_rank(with_context_parallel=True)} etp {mpu.get_expert_tensor_parallel_rank()} ep {mpu.get_expert_model_parallel_rank()} edp {mpu.get_expert_data_parallel_rank()}"
    )

    offload_megatron_optimizer(optimizer)
    if args.ppo_early_swap_model:
        actor_trainer.build_cpu_memory_model(
            unwrap_model(actor_model.model)[0])
        offload_megatron_model(actor_model.model)

    torch.distributed.barrier()
    print_meminfo_str("Track CPU Memory after first time offload: ")

    # Push the freshly built / loaded actor weights to the sampler before
    # training starts (presumably so the first rollouts use them).
    print_with_rank_and_datetime(f"begin update actor weight to vllm sampler")
    sampler_client.update_weights(actor_model.model,
                                  early_swap_model=args.ppo_early_swap_model,
                                  cpu_memory_model=actor_trainer.get_cpu_memory_model())
    print_with_rank_and_datetime(
        f"finished update actor weight to vllm sampler")

    if args.ppo_debug_update_weight:
        print_with_rank_and_datetime(
            f"begin update zero weight to vllm sampler")
        sampler_client.update_weights(actor_model.model, replace_zeros=True,
                                      early_swap_model=args.ppo_early_swap_model,
                                      cpu_memory_model=actor_trainer.get_cpu_memory_model())
        print_with_rank_and_datetime(
            f"finished update zero weight to vllm sampler")
        if sampler_client.infer_engine_impl == "vllm":
            sampler_client.check_sampler_zeros_weight()
            print_with_rank_and_datetime(f"after check zeros")
        else:
            pass
        # Debug: generate with the zeroed weights, then push the real weights
        # again and generate once more, to check that weight updates reach the
        # sampler. The process exits afterwards.
        if not args.ppo_early_swap_model:
            offload_megatron_model(actor_model.model)
            torch.distributed.barrier()
        sampler_client.mark_sampler_ppo_step_begin(0)
        sampler_client.test_generate()
        sampler_client.sleep()

        if not args.ppo_early_swap_model:
            onload_megatron_model(actor_model.model)
            torch.distributed.barrier()
        print_with_rank_and_datetime(
            f"begin update actor weight to vllm sampler")
        sampler_client.update_weights(actor_model.model,
                                      early_swap_model=args.ppo_early_swap_model,
                                      cpu_memory_model=actor_trainer.get_cpu_memory_model())
        print_with_rank_and_datetime(
            f"finished update actor weight to vllm sampler")
        if not args.ppo_early_swap_model:
            offload_megatron_model(actor_model.model)
            torch.distributed.barrier()
        if sampler_client.infer_engine_impl == "vllm":
            sampler_client.check_sampler_zeros_weight()
        else:
            pass
        sampler_client.mark_sampler_ppo_step_begin(0)
        sampler_client.test_generate()
        sampler_client.sleep()
        torch.distributed.barrier()

        sys.exit()

    if not args.ppo_early_swap_model:
        offload_megatron_model(actor_model.model)
        torch.distributed.barrier()

    # Data iterators.
    timers('train/valid/test-data-iterators-setup',
           log_level=0).start(barrier=True)
    assert args.virtual_pipeline_model_parallel_size is None
    if args.virtual_pipeline_model_parallel_size is None:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_datasets_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Shared context passed to the checkpointing calls.
    checkpointing_context = {}

    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup',
               'train/valid/test-data-iterators-setup'], barrier=True)

    print_rank_0('training ...')
    iteration = 0
    print_meminfo_str("Track CPU Memory before training: ")
    if args.do_train and args.train_iters > 0:
        try:
            iteration, num_floating_point_operations_so_far = actor_trainer.train_loop(
                actor_model, sampler_client, rm_critic_client, gen_rm_client, rollout_get_batch_func,
                filter_samplings_func, optimizer, opt_param_scheduler, train_data_iterator,
                valid_data_iterator, process_non_loss_data_func, config, checkpointing_context)
        except BaseException:
            # Deliberately broad: this also catches the SystemExit raised by
            # train_loop's sys.exit(), so the monitor is always told to stop.
            # Print the traceback, set the exit flag and skip the final save.
            traceback.print_exc()
            if args.do_monitor:
                set_exit_flag(args.monitor_server_ip, args.monitor_port)
            return

    print_datetime('after training is done')
    _, ppo_step = iter_to_ppo_epoch_step(iteration)
    # Final save, unless the last PPO step already hit a save interval
    # (presumably saved inside train_loop).
    if args.save and iteration != 0 and ppo_step % args.ppo_step_save_interval != 0:
        # Put the inference services to sleep before bringing the actor back.
        sampler_client.sleep()
        cpu_barrier()
        if args.use_gen_rm:
            gen_rm_client.sleep()
        if args.use_rm_and_critic:
            rm_critic_client.sleep()
        cpu_barrier()

        # Onload model and optimizer state so the checkpoint can be written.
        onload_megatron_model(actor_model.model)
        onload_megatron_optimizer(optimizer)
        torch.distributed.barrier()
        save_actor_and_critic_ckpt(iteration, actor_model, rm_critic_client, sampler_client,
                                   optimizer, opt_param_scheduler,
                                   num_floating_point_operations_so_far, checkpointing_context)

    wandb_writer = get_wandb_writer()
    if wandb_writer:
        wandb_writer.finish()

    if args.do_monitor:
        set_exit_flag(args.monitor_server_ip, args.monitor_port)
