
from collections import defaultdict
from typing_extensions import override
from types import SimpleNamespace
from typing import Callable, List, Any, Dict, Tuple, Set
import torch
from calendar import firstweekday
from contextlib import nullcontext
import asyncio
import sys
import os

from megatron.training.training import _TRAIN_START_TIME  # WARN: don't change its import order

from tqdm import tqdm
import torch.distributed

from megatron.core import parallel_state
from megatron.training.global_vars import (
    get_args,
    get_timers,
)
from gpatch.core.wecube import (
    report_ppo_metrics,
)
from gpatch.training.utils import (
    print_with_rank_and_datetime,
    print_rank_0,
    print_meminfo_str,
)
from gpatch.training.v3.ppo_actor import PPOActorTrainerV3
from gpatch.core.ppo_helper import (
    calculate_kl_penalty,
    create_mask,
    calculate_grpo_advantages,
)
from gpatch.core.aligner_helper import (
    cpu_dict,
    clear_memory,
    masked_mean_list,
    masked_global_statistics_list,
)
from gpatch.core.parallel_state import (
    is_mp_and_cp_head,
)
from gpatch.core.swap import (
    offload_megatron_model,
    onload_megatron_model,
)
from gpatch.core.utils import (
    print_memory_tracking,
    check_rollout_batches, 
    list_for_tensor_tolist,
)
from gpatch.core.wecube import (
    report_ppo_metrics,
)
from gpatch.rpc.monitor import (
    mark_made_progress,
)

from megatron.core import mpu
from megatron.core import parallel_state
from megatron.core.distributed import finalize_model_grads
from megatron.core.utils import divide
from megatron.training.initialize import write_args_to_tensorboard
from megatron.training.training import (
    print_datetime,
)
from megatron.training.utils import (
    calc_params_l2_norm,
    print_rank_0,
    unwrap_model,
)
from megatron.training.global_vars import (
    get_args,
    get_timers,
    get_tensorboard_writer,
    get_wandb_writer,
)
from megatron.core.num_microbatches_calculator import (
    get_num_microbatches, )

from gpatch.training.utils import print_with_rank_and_datetime, print_meminfo_str
from gpatch.core.utils import get_nvml_memory_info, print_memory_tracking
from gpatch.core.aligner_helper import (
    clear_memory,
    get_iterator_k_split_list,
    cpu_dict,
)
from gpatch.training.v3.ppo_actor import iter_to_ppo_epoch_step, save_actor_and_critic_ckpt
from gpatch.core.wecube import report_ppo_metrics
from gpatch.core.parallel_state import (
    is_mp_and_cp_head,
    cpu_barrier,
)
from gpatch.core.swap import (
    offload_megatron_model,
    onload_megatron_model,
    offload_megatron_optimizer,
    onload_megatron_optimizer,
)
from gpatch.core.aligner_helper import (
    expand_rollout_batches,
)
from gpatch.rpc.monitor import (
    mark_made_progress,
)
from gpatch.core.utils import check_rollout_batches
from gpatch.core.smart_pad_helper import smart_pad_train_get_reorder_rollout_batches
from gpatch.training.utils import training_log
class MathRLActorTrainer(PPOActorTrainerV3):

    @override
    def train_step(
        self,
        config,
        actor_model,
        optimizer,
        opt_param_scheduler,
        dataloader_iter,
        ppo_step,
        train_iters_each_rollout,
        iteration,
        total_loss_dict,
        rollout_metrics,
    ):
        """Run one actor optimizer update per batch yielded by ``dataloader_iter``.

        For every batch: reset gradient buffers, compute loss/metrics with a
        forward+backward pass, finalize (reduce) gradients, step the optimizer
        and (on a successful update) the LR scheduler, then log the step via
        ``training_log`` with ``rollout_metrics`` merged into the step metrics.

        Args:
            config: model config; its ``finalize_model_grads_func`` must be unset
                because gradients are finalized explicitly here.
            actor_model: actor wrapper exposing ``prepare_for_training*``,
                ``get_loss_and_metrics``, ``finish_training*`` and ``.model``.
            optimizer: Megatron optimizer.
            opt_param_scheduler: LR/weight-decay scheduler.
            dataloader_iter: iterable of training batches.
            ppo_step: current PPO step (used for the progress-bar offset).
            train_iters_each_rollout: optimizer steps per rollout (progress bar).
            iteration: global iteration counter before this call.
            total_loss_dict: accumulator passed through to ``training_log``.
            rollout_metrics: metrics from rollout generation, merged into every
                step's metrics before logging.

        Returns:
            The updated global iteration counter (one increment per batch).
        """
        args = get_args()
        timers = get_timers()

        def split_learning_rates(param_groups):
            # Unique non-decoupled LRs in first-seen order, plus the last decoupled LR (or None).
            regular_lrs = []
            decoupled_lr = None
            for group in param_groups:
                if group['is_decoupled_lr']:
                    decoupled_lr = group['lr']
                    continue
                lr = group['lr']
                if lr not in regular_lrs:
                    regular_lrs.append(lr)
            return regular_lrs, decoupled_lr

        actor_model.prepare_for_training()

        progress = tqdm(
            dataloader_iter,
            initial=ppo_step * train_iters_each_rollout,
            total=args.ppo_max_epochs * args.ppo_step_per_epoch * train_iters_each_rollout,
            leave=True,
            position=2,
            desc='PPO actor train',
            disable=torch.distributed.get_rank() != 0 or args.ppo_disable_tqdm,
        )
        for train_batch in progress:
            # Clear gradients left over from the previous update.
            actor_model.prepare_for_training_step()
            for chunk in actor_model.model:
                chunk.zero_grad_buffer()
            optimizer.zero_grad()

            # Forward + backward pass.
            timers('loss-and-metrics', log_level=1).start(barrier=args.barrier_with_L1_time)
            _, step_metrics = actor_model.get_loss_and_metrics(
                train_batch,
                get_num_microbatches(),
                forward_only=False,
                iteration=iteration,
                total_iters=args.train_iters,
            )
            timers('loss-and-metrics').stop()

            # Gradients are finalized explicitly here rather than via the config hook.
            assert not config.finalize_model_grads_func
            finalize_model_grads(actor_model.model)
            actor_model.finish_training_step()

            timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
            update_ok, grad_norm, zeros_in_grad = optimizer.step()
            timers('optimizer').stop()

            skipped_iter = 0 if update_ok else 1
            if not skipped_iter:
                samples_consumed = (get_num_microbatches() * args.micro_batch_size
                                    * args.data_parallel_size)
                opt_param_scheduler.step(increment=samples_consumed)
            iteration += 1

            # Logging and metrics reporting.
            loss_scale = optimizer.get_loss_scale().item()
            params_norm = calc_params_l2_norm(actor_model.model) if args.log_params_norm else None
            learning_rates, decoupled_learning_rate = split_learning_rates(optimizer.param_groups)

            step_metrics.update(rollout_metrics)
            training_log(step_metrics, total_loss_dict, learning_rates, decoupled_learning_rate,
                         iteration, loss_scale, False, skipped_iter, grad_norm, params_norm,
                         zeros_in_grad)

        if args.do_monitor:
            mark_made_progress(args.monitor_server_ip, args.monitor_port)

        actor_model.finish_training()
        optimizer.zero_grad()
        return iteration

    @override
    def get_rollout_batches(self, sampler_client, rm_critic_client, gen_rm_client, ppo_step,
                            num_microbatches, rollout_get_batch_func, dataloader_iter, is_eval=False):
        """Produce scored rollout micro-batches for one PPO step (train or eval).

        Stages, each bracketed by the client's ``mark_*_ppo_step_begin/end``
        calls and ``cpu_barrier()``:

        1. Sampling: for each of ``num_microbatches`` micro-batches, take a batch
           from ``self.replay_queue`` (train only, when ``args.replay_sample`` is
           set and at least ``args.ppo_rollout_micro_batch_size`` samples are
           queued) or from ``rollout_get_batch_func(dataloader_iter)``. Then issue
           all ``sampler_client.generate`` calls concurrently.
        2. Generative-RM scoring, when ``args.use_gen_rm``.
        3. RM / critic scoring, when ``args.use_rm_and_critic``.
        4. Replay-queue refill from batches containing ``'sample_useful'``
           (train only, when ``args.replay_sample``).

        Only MP+CP head ranks (``is_mp_and_cp_head()``) issue the RPCs and hold
        data. Every other rank ends up with a list of ``None`` placeholders.

        Args:
            sampler_client: generation client (``generate``, cache flush, step markers).
            rm_critic_client: RM/critic client, used when ``args.use_rm_and_critic``.
            gen_rm_client: generative-RM client, used when ``args.use_gen_rm``.
            ppo_step: current PPO step, forwarded to every client call.
            num_microbatches: number of rollout micro-batches to produce.
            rollout_get_batch_func: called as ``rollout_get_batch_func(dataloader_iter)``
                to fetch a fresh micro-batch.
            dataloader_iter: iterator handed to ``rollout_get_batch_func``.
            is_eval: selects the eval sample index, the ``eval_`` timer prefix and
                the eval sampling repeat; disables replay.

        Returns:
            On head ranks, a list of ``num_microbatches`` rollout-batch dicts. On
            other ranks, a list of ``None``.

        Side effects:
            Advances ``self.eval_sample_idx`` (eval) or ``self.sample_idx`` (train)
            by ``num_microbatches`` and calls ``self.clear_data_cache()``.
        """
        # use cpu barrier to prevent overrun cuda kernels
        args = get_args()
        timers = get_timers()

        eval_ = 'eval_' if is_eval else ''
        sample_idx = self.eval_sample_idx if is_eval else self.sample_idx

        # Train and eval use independent sampling_repeat settings.
        sampling_repeat = args.ppo_eval_sampling_repeat if is_eval else args.ppo_sampling_repeat
        # Eval has no separate sampling_keep; it equals sampling_repeat.
        # NOTE: sampling_keep is not referenced anywhere else in this method.
        sampling_keep = args.ppo_eval_sampling_repeat if is_eval else args.ppo_sampling_keep

        timers(f'{eval_}generate', log_level=0).start(barrier=True)

        print_memory_tracking(f"Memory tracking before rollout:")
        print_with_rank_and_datetime(f'get_rollout_batches mark_sampler_ppo_step_begin', rank=0)
        sampler_client.mark_sampler_ppo_step_begin(ppo_step)
        cpu_barrier()

        # sampling
        async def gen_co_1():
            # Build every micro-batch first, then await all generate requests together.
            batches = [None for _ in range(num_microbatches)]
            cos = []
            rollout_mbs = args.ppo_rollout_micro_batch_size
            for rbi in range(num_microbatches):
                if not is_eval and args.replay_sample and self.replay_queue.qsize() >= rollout_mbs:
                    # Pop rollout_mbs replayed samples and transpose them from a list
                    # of per-sample dicts into one dict of lists (keys come from the
                    # first sample).
                    rb = None
                    for i in range(rollout_mbs):
                        sample = self.replay_queue.get()
                        if rb is None:
                            rb = {}
                            for k in sample.keys():
                                rb[k] = [sample[k]]
                        else:
                            for k in rb.keys():
                                rb[k].append(sample[k])
                    batches[rbi] = self.rollout_get_from_replay_queue(rb)
                else:
                    batches[rbi] = rollout_get_batch_func(dataloader_iter)
                batches[rbi] = self.remove_rollout_attr_before_sampling(batches[rbi])
                co = sampler_client.generate(ppo_step, sample_idx + rbi, batches[rbi], sampling_repeat)
                cos.append(co)

            return await asyncio.gather(*cos)

        print_with_rank_and_datetime(f'{eval_}get_rollout_batches generate', rank=0)
        num_rollout_samples = num_microbatches

        # Only MP+CP head ranks talk to the sampler; the others keep placeholders.
        if is_mp_and_cp_head():
            rbs: List[Dict[str, List[Any]]] = asyncio.run(gen_co_1())
            assert check_rollout_batches(rbs), f"rbs format error, may need pop('ready'): {rbs=}"
        else:
            rbs = [None for _ in range(num_microbatches)]
        cpu_barrier()
        if args.do_monitor:
            mark_made_progress(args.monitor_server_ip, args.monitor_port)

        # flush cache
        sampler_client.infer_engine_flush_cache()
        print_memory_tracking(f"Memory tracking: actor after sampler flush cache")
        cpu_barrier()

        print_with_rank_and_datetime(f'{eval_}get_rollout_batches mark_sampler_ppo_step_end', rank=0)
        sampler_client.mark_sampler_ppo_step_end(ppo_step)
        cpu_barrier()
        timers(f'{eval_}generate').stop()  # Aggregate time only; long-tail stragglers are not visible here...

        torch.cuda.synchronize()
        print_memory_tracking(f"Memory tracking: actor after sampler sleep")

        # gen rm scoring
        async def gen_rm_co_1(req_rbs, rbs_idx, prompt_idx):
            # Score the selected batches with RM prompt `prompt_idx`, concurrently.
            cos = []

            for rb, rbi in zip(req_rbs, rbs_idx, strict=True):
                co = gen_rm_client.generate_rewards(ppo_step, sample_idx + rbi,
                                                    prompt_idx, rb)
                cos.append(co)

            return await asyncio.gather(*cos)

        if args.use_gen_rm:
            print_with_rank_and_datetime(f'{eval_}get_rollout_batches mark_gen_rm_ppo_step_begin', rank=0)
            timers(f'{eval_}gen_rm_scoring', log_level=0).start(barrier=True)
            gen_rm_client.mark_gen_rm_ppo_step_begin(ppo_step)
            cpu_barrier()

            print_with_rank_and_datetime(f'{eval_}get_rollout_batches with {args.ppo_num_rm} gen_rm(s)'
                                         f' with rm_prompt mapping {self.rm_prompt_index_mapping}',
                                         rank=0)
            # Prompt indices are global across RMs: RM k's prompts start after
            # all prompts of RMs 0..k-1.
            prompt_start_idx = 0
            # With exactly one RM and one prompt, every batch is sent for scoring.
            is_single_prompt = len(self.rm_prompt_index_mapping) == 1 and \
                               len(self.rm_prompt_index_mapping[0]) == 1
            for rm_idx, prompt_list in enumerate(self.rm_prompt_index_mapping):
                # Load the weights of RM `rm_idx`; RM 0 is the default and is not reloaded.
                # NOTE(review): after switching to rm_idx > 0, nothing here switches back
                # to RM 0 for the next step — confirm the step begin/end markers handle it.
                if rm_idx != 0:
                    print_with_rank_and_datetime(f"ppo_actor {rm_idx=}")
                    gen_rm_client.update_gen_rm_weight_by_model_idx(rm_idx)

                for pl_idx in range(len(prompt_list)):
                    prompt_idx = prompt_start_idx + pl_idx
                    if is_mp_and_cp_head():
                        # Send only batches that carry tokens for this prompt
                        # (all batches in the single-prompt case).
                        req_rbs_list = []
                        req_idx_list = []
                        for rbi, rb in enumerate(rbs):
                            if f'rm_{prompt_idx}_tokens' in rb or is_single_prompt:
                                req_idx_list.append(rbi)
                                req_rbs_list.append(rb)

                        rbs_result = asyncio.run(gen_rm_co_1(req_rbs_list,
                                                             req_idx_list, prompt_idx))
                        # Merge the returned reward fields back into the originating batches.
                        for rb_idx, rb_result in zip(req_idx_list,
                                                     rbs_result, strict=True):
                            rbs[rb_idx].update(rb_result)
                    else:
                        rbs = [None for _ in range(num_rollout_samples)]
                prompt_start_idx += len(prompt_list)
                cpu_barrier()
            cpu_barrier()

            if is_mp_and_cp_head():
                self.post_process_rm_rollout_batch(rbs)

            gen_rm_client.infer_engine_flush_cache()

            print_memory_tracking(f"Memory tracking: actor after gen_rm flush cache")
            cpu_barrier()

            print_with_rank_and_datetime(f'{eval_}get_rollout_batches mark_gen_rm_ppo_step_end', rank=0)
            gen_rm_client.mark_gen_rm_ppo_step_end(ppo_step)
            cpu_barrier()
            timers(f'{eval_}gen_rm_scoring').stop()

            print_memory_tracking(f'Memory tracking: actor after gen_rm sleep')
            if args.do_monitor:
                mark_made_progress(args.monitor_server_ip, args.monitor_port)

        # bt scoring
        if args.use_rm_and_critic:
            print_with_rank_and_datetime(f'{eval_}get_rollout_batches mark_rm_ppo_step_begin', rank=0)
            timers(f'{eval_}scoring', log_level=0).start(barrier=True)
            rm_critic_client.mark_rm_ppo_step_begin(ppo_step)
            cpu_barrier()

            # Two-phase RM/critic inference: issue all requests first...
            async def rm_critic_co_1():
                cos = [
                    rm_critic_client.issue_infer_rm_critic(ppo_step, sample_idx + rbi,
                                                           rbs[rbi])
                    for rbi in range(num_rollout_samples)
                ]
                return await asyncio.gather(*cos)

            print_with_rank_and_datetime(f'{eval_}get_rollout_batches issue_infer_rm_critic', rank=0)
            if is_mp_and_cp_head():
                asyncio.run(rm_critic_co_1())
            cpu_barrier()

            # ...then collect the results per sample index.
            async def rm_critic_co_2():
                cos = [
                    rm_critic_client.get_infer_rm_critic_result(ppo_step, sample_idx + rbi, sampling_repeat)
                    for rbi in range(num_rollout_samples)
                ]
                return await asyncio.gather(*cos)

            print_with_rank_and_datetime(f'{eval_}get_rollout_batches get_infer_rm_critic_result', rank=0)
            if is_mp_and_cp_head():
                crbs = asyncio.run(rm_critic_co_2())
                # Merge the returned fields into the matching rollout batch.
                for rb, crb in zip(rbs, crbs, strict=True):
                    rb.update(crb)
                assert check_rollout_batches(rbs), f"rbs format error, may need pop('ready'): {crbs=}"
            else:
                rbs = [None for _ in range(num_rollout_samples)]
            cpu_barrier()

            print_with_rank_and_datetime(f'{eval_}get_rollout_batches mark_rm_ppo_step_end', rank=0)
            rm_critic_client.mark_rm_ppo_step_end(ppo_step)
            cpu_barrier()
            timers(f'{eval_}scoring').stop()
            if args.do_monitor:
                mark_made_progress(args.monitor_server_ip, args.monitor_port)

        # Refill the replay queue (train only) from batches that carry 'sample_useful'.
        if not is_eval and args.replay_sample and is_mp_and_cp_head():
            for rb in rbs:
                if 'sample_useful' in rb:
                    replay_samples = self.get_replay_samples(rb)
                    for sample in replay_samples:
                        self.replay_queue.put(sample)

        if is_eval:
            self.eval_sample_idx += num_rollout_samples
        else:
            self.sample_idx += num_rollout_samples

        # When replay_samples is in use, this call may need to move.
        self.clear_data_cache()
        return rbs

    @override
    def train_loop(self, actor_model, sampler_client, rm_critic_client, gen_rm_client,
                   rollout_get_batch_func, filter_samplings_func, optimizer, opt_param_scheduler,
                   train_data_iterator, valid_data_iterator, process_non_loss_data_func, config,
                   checkpointing_context):
        """Train the actor for all remaining PPO steps.

        Each PPO step does the following, in order:

        - generate rollouts (``self.generate_rollouts``);
        - train the critic when not using GRPO;
        - onload the optimizer;
        - run ``args.ppo_max_epochs_2`` passes of ``self.train_step`` over the
          expanded (optionally smart-padded) rollouts;
        - optionally refresh the reference model and save a checkpoint;
        - offload the optimizer/model and push weights to the sampler;
        - report timers and run evaluation at the configured intervals.

        When ``args.ppo_sampling_dynamic == "linear"``, ``args.ppo_sampling_keep``
        is rewritten before every step. It follows one linear ramp over the whole
        run, computed once from the configured value (see
        ``_resolve_sampling_keep_range`` / ``_linear_sampling_keep``).

        Returns:
            ``(iteration, num_floating_point_operations_so_far)``. This includes
            the early-exit path when no PPO epochs remain.
        """
        args = get_args()
        timers = get_timers()

        # Write args to tensorboard
        write_args_to_tensorboard()

        # Tracking loss.
        total_loss_dict = {}

        # Iterations.
        iteration = args.iteration
        num_floating_point_operations_so_far = args.num_floating_point_operations_so_far

        timers('interval-time', log_level=0).start(barrier=True)
        print_datetime('before the start of training step')
        # Not set anywhere in this loop yet; kept for parity with Megatron's train().
        # (Named to avoid shadowing the `exit` builtin.)
        exit_requested = False

        epoch, ppo_step = iter_to_ppo_epoch_step(iteration)
        dp_size = parallel_state.get_data_parallel_world_size()
        sampler_dp_size = args.ppo_sampler_data_parallel_size
        rollout_gbs = args.ppo_rollout_global_batch_size
        rollout_mbs = args.ppo_rollout_micro_batch_size
        assert rollout_gbs % dp_size == 0, f"{rollout_gbs=} {dp_size=}"
        assert rollout_gbs % sampler_dp_size == 0, f"{rollout_gbs=} {sampler_dp_size=}"
        num_rollout_micro_batches = divide(rollout_gbs, rollout_mbs * dp_size)
        ppo_step_since_recover = 0

        print_rank_0((f'actor iter_to_ppo_epoch_step'
                      f' num_rollout_micro_batches {num_rollout_micro_batches}'
                      f' train_iters_each_rollout {args.train_iters_each_rollout}'
                      f' ppo_step {ppo_step}'
                      f' epoch {epoch}'))
        assert args.train_iters == args.ppo_max_epochs * args.ppo_step_per_epoch * args.train_iters_each_rollout, (
            f'invalid args train_iters {args.train_iters} != ppo_max_epochs {args.ppo_max_epochs}'
            f' * ppo_step_per_epoch {args.ppo_step_per_epoch} * train_iters_each_rollout {args.train_iters_each_rollout}'
        )

        epoch_iter = range(epoch, args.ppo_max_epochs)
        if len(epoch_iter) <= 0:
            # Nothing left to train (e.g. resumed from a finished run). Keep the
            # return type consistent with the normal exit path.
            return iteration, num_floating_point_operations_so_far

        # Resolve the sampling_keep ramp once, from the configured value. The
        # linear schedule overwrites args.ppo_sampling_keep every step, so
        # re-deriving the range per epoch would compound the ramp.
        keep_range = self._resolve_sampling_keep_range(args)

        wandb_writer = get_wandb_writer()
        total_steps = args.ppo_max_epochs * args.ppo_step_per_epoch
        for epoch_idx in epoch_iter:
            num_steps_in_epoch = args.ppo_step_per_epoch - ppo_step % args.ppo_step_per_epoch
            loop_iter = range(num_steps_in_epoch)
            if not loop_iter:
                # Fall through to the writer flush / cleanup below instead of
                # returning None from the middle of the loop.
                break

            perf_ctx = nullcontext()
            with perf_ctx:
                global_pbar = tqdm(
                    loop_iter,
                    initial=ppo_step,
                    total=total_steps,
                    leave=True,
                    position=0,
                    desc="PPO actor global step",
                    disable=torch.distributed.get_rank() != 0 or args.ppo_disable_tqdm,
                )
                for step_idx in global_pbar:
                    timers('ppo_step_time_cost', log_level=0).start(barrier=True)

                    if args.ppo_sampling_dynamic == "linear":
                        # Derive the in-epoch position from ppo_step (not step_idx) so a
                        # run resumed mid-epoch continues the ramp where it left off.
                        args.ppo_sampling_keep = self._linear_sampling_keep(
                            keep_range[0],
                            keep_range[1],
                            epoch_idx,
                            ppo_step % args.ppo_step_per_epoch,
                            args.ppo_max_epochs,
                            args.ppo_step_per_epoch,
                        )
                    else:
                        assert args.ppo_sampling_dynamic == "static"

                    print_rank_0(f"[PODS][DEBUG] {step_idx=} {args.ppo_sampling_keep=} {args.ppo_rollout_global_batch_size=} {args.global_batch_size=}")

                    # gen rollouts
                    clear_memory()
                    rollout_batches, rollout_metrics = self.generate_rollouts(
                        rollout_get_batch_func,
                        filter_samplings_func,
                        actor_model,
                        rm_critic_client,
                        gen_rm_client,
                        sampler_client,
                        train_data_iterator,
                        ppo_step,
                        num_rollout_micro_batches,
                    )

                    # send critic train
                    if not args.use_grpo:
                        timers('rm_critic_client_train', log_level=0).start(barrier=True)
                        clear_memory()

                        rm_critic_client.train(rollout_batches)
                        timers('rm_critic_client_train').stop()

                    # onload actor
                    onload_megatron_optimizer(optimizer)
                    torch.distributed.barrier()

                    clear_memory()
                    torch.distributed.barrier()
                    print_with_rank_and_datetime('train', rank=0)
                    ex_rollout_batches = expand_rollout_batches(rollout_batches)
                    total_samples = args.ppo_rollout_global_batch_size * args.ppo_sampling_keep
                    assert len(ex_rollout_batches) * mpu.get_data_parallel_world_size(
                    ) == total_samples, f"{len(ex_rollout_batches)} * {mpu.get_data_parallel_world_size()} != {total_samples}"

                    if args.ppo_smart_pad_train:
                        num_global_batch = args.ppo_rollout_global_batch_size * args.ppo_sampling_keep // args.global_batch_size
                        train_batch_size_per_dp = total_samples // num_global_batch // args.ppo_actor_data_parallel_size
                        print_with_rank_and_datetime(f"[SMART_PAD_TRAIN] {args.ppo_rollout_global_batch_size=} {args.ppo_sampling_keep=} {args.global_batch_size=} {num_global_batch=} {train_batch_size_per_dp=}", rank=0)
                        reorder_ex_rollout_batches = smart_pad_train_get_reorder_rollout_batches(
                            ex_rollout_batches, num_global_batch, train_batch_size_per_dp,
                            args.ppo_rollout_pad_to_multiple_of, ppo_step
                        )
                        ex_rollout_batches = reorder_ex_rollout_batches

                    for _ in range(args.ppo_max_epochs_2):
                        # make train iterator
                        timers('make_rollout_dataloader_iter', log_level=0).start(barrier=True)
                        # Doing padding this deep inside the loop may have a performance cost;
                        # smart padding could instead run right after expand_rollout_batches.
                        rollout_dataloader_iter = get_iterator_k_split_list(
                            ex_rollout_batches, args.ppo_rollout_global_batch_size *
                            args.ppo_sampling_keep // args.global_batch_size)
                        timers('make_rollout_dataloader_iter').stop()

                        # start training
                        timers('train_step', log_level=0).start(barrier=True)
                        iteration = self.train_step(
                            config,
                            actor_model,
                            optimizer,
                            opt_param_scheduler,
                            rollout_dataloader_iter,
                            ppo_step,
                            args.train_iters_each_rollout,
                            iteration,
                            total_loss_dict,
                            rollout_metrics,
                        )
                        timers('train_step').stop()

                    ppo_step += 1
                    ppo_step_since_recover += 1

                    # Update the reference model from the actor.
                    if args.ppo_update_ref_w_actor_interval > 0 and ppo_step % args.ppo_update_ref_w_actor_interval == 0:
                        print_with_rank_and_datetime('update_ref_with_actor constantly', rank=0)
                        actor_model.update_ref_with_actor(args.ppo_update_ref_w_actor_coef)

                    # save ckpt
                    assert args.save_interval % args.train_iters_each_rollout == 0
                    if args.save and args.save_interval and ppo_step % args.ppo_step_save_interval == 0:
                        save_actor_and_critic_ckpt(iteration, actor_model, rm_critic_client,
                                                   sampler_client, optimizer, opt_param_scheduler,
                                                   num_floating_point_operations_so_far,
                                                   checkpointing_context)

                    # offload actor optimizer
                    offload_megatron_optimizer(optimizer)
                    if args.ppo_early_swap_model:
                        self.update_cpu_memory_model(unwrap_model(actor_model.model)[0])
                        offload_megatron_model(actor_model.model)
                    torch.distributed.barrier()

                    # Push the actor weights to the sampler.
                    timers('update_sampler_weights', log_level=0).start(barrier=True)
                    u_intv = args.ppo_step_update_sampler_interval
                    if args.ppo_standalone_sampler and u_intv > 0 and ppo_step_since_recover % u_intv == 0:
                        print_with_rank_and_datetime('update_sampler_weights', rank=0)
                        sampler_client.update_weights(actor_model.model,
                                                      early_swap_model=args.ppo_early_swap_model,
                                                      cpu_memory_model=self.cpu_memory_model)
                    timers('update_sampler_weights').stop()

                    # offload actor model
                    if not args.ppo_early_swap_model:
                        offload_megatron_model(actor_model.model)
                        torch.distributed.barrier()

                    if args.ppo_wecube_report and is_mp_and_cp_head():
                        report_data = {'timer_report_times': 1}
                        for report_key in [
                                'generate', 'compute_logps', 'scoring', 'rm_critic_client_train',
                                'train_step', 'update_sampler_weights', 'ppo_step_time_cost',
                        ]:
                            report_data['timer_' + report_key] = timers(
                                report_key, log_level=0).elapsed(reset=False)
                        report_ppo_metrics(report_data)

                    # eval
                    timers('eval', log_level=0).start(barrier=True)
                    if args.ppo_step_eval_interval > 0 and ppo_step % args.ppo_step_eval_interval == 0:
                        print_with_rank_and_datetime(f'eval at ppo_step={ppo_step - 1}', rank=0)
                        eval_metrics = self.eval_loop(
                            actor_model=actor_model,
                            sampler_client=sampler_client,
                            rm_critic_client=rm_critic_client,
                            gen_rm_client=gen_rm_client,
                            rollout_get_batch_func=rollout_get_batch_func,
                            eval_dataloader_iter=valid_data_iterator,
                            iteration=iteration,
                        )
                        print_with_rank_and_datetime(f'eval_metrics: {dict(eval_metrics)}', rank=0)
                    timers('eval').stop()

                    timers('ppo_step_time_cost').stop()

                    # Each timer listed once; a duplicate entry was logged twice before.
                    log_metrics_keys = [
                        'generate',
                        'scoring',
                        'gen_rm_scoring',
                        'compute_logps',
                        'result_future',
                        'filter_sampling',
                        'rm_critic_future',
                        'compute_global_rollout_metrics',
                        'generate_ppo_data',
                        'rm_critic_client_train',
                        'make_rollout_dataloader_iter',
                        'train_step',
                        'update_sampler_weights',
                        'eval',
                        'ppo_step_time_cost',
                    ]
                    if wandb_writer:
                        for metrics_key in log_metrics_keys:
                            wandb_writer.log({f"train_s/{metrics_key}": timers(
                                metrics_key, log_level=0).elapsed(reset=False)}, iteration)
                    timers.log(log_metrics_keys, barrier=True)
                    print_meminfo_str("Track CPU Memory finish ppo step: ")

        # Flush TensorBoard and WandB writers.
        writer = get_tensorboard_writer()
        if writer:
            writer.flush()

        # Close out pre-hooks if using distributed optimizer and overlapped param gather.
        # FIXME replace it with `disable_forward_pre_hook`
        if args.use_distributed_optimizer and args.overlap_param_gather:
            optimizer.disable_pre_hook()

        # If any exit conditions (signal handler, duration, iterations) have been reached, exit.
        if exit_requested:
            if wandb_writer:
                wandb_writer.finish()
            sys.exit()

        return iteration, num_floating_point_operations_so_far

    def _resolve_sampling_keep_range(self, args):
        """Return ``(start, end)`` of the dynamic ``ppo_sampling_keep`` ramp.

        For ``args.ppo_sampling_dynamic == "static"``, log the fixed value and
        return ``None``. With ``args.ppo_sampling_keep_final > 0``, the ramp goes
        from the configured ``ppo_sampling_keep`` to ``ppo_sampling_keep_final``
        (both must be <= ``ppo_sampling_repeat``). Otherwise it defaults to
        0.5x -> 2x the configured value, with the end capped at
        ``ppo_sampling_repeat``.

        Must be called before the schedule first overwrites
        ``args.ppo_sampling_keep``.
        """
        if args.ppo_sampling_dynamic == "static":
            print_rank_0(f"[PODS][STATIC] ppo_sampling_keep {args.ppo_sampling_keep}")
            return None

        if args.ppo_sampling_keep_final > 0:
            ppo_sampling_keep_original = args.ppo_sampling_keep
            ppo_sampling_keep_new = args.ppo_sampling_keep_final
            assert ppo_sampling_keep_original <= args.ppo_sampling_repeat, f"{ppo_sampling_keep_original=} {args.ppo_sampling_repeat=}"
            assert ppo_sampling_keep_new <= args.ppo_sampling_repeat, f"{ppo_sampling_keep_new=} {args.ppo_sampling_repeat=}"
        else:
            # default: 0.5x -> 2x
            ppo_sampling_keep_original = args.ppo_sampling_keep / 2
            ppo_sampling_keep_new = min(args.ppo_sampling_keep * 2, args.ppo_sampling_repeat)
        print_rank_0(f"[PODS][DYNAMIC] ppo_sampling_keep_original {ppo_sampling_keep_original} -> ppo_sampling_keep_new {ppo_sampling_keep_new}")
        return ppo_sampling_keep_original, ppo_sampling_keep_new

    @staticmethod
    def _linear_sampling_keep(keep_start, keep_end, epoch_idx, step_in_epoch,
                              num_epochs, steps_per_epoch):
        """Return the linearly scheduled ``ppo_sampling_keep`` for one PPO step.

        Epoch ``e`` ramps from ``start + (end - start) * e / num_epochs`` to
        ``start + (end - start) * (e + 1) / num_epochs``. Within an epoch the
        position is ``step_in_epoch / (steps_per_epoch - 1)``, or 0 for
        single-step epochs, which previously divided by zero.

        The fraction is kept as an integer numerator/denominator so that
        mathematically integral values are not truncated down by float error.
        The result is truncated with ``int()``, as before, and clamped to >= 1,
        because a keep of 0 makes the downstream batch arithmetic divide by zero.
        """
        if steps_per_epoch > 1:
            numer = epoch_idx * (steps_per_epoch - 1) + step_in_epoch
            denom = num_epochs * (steps_per_epoch - 1)
        else:
            numer, denom = epoch_idx, num_epochs
        value = keep_start + (keep_end - keep_start) * numer / denom
        return max(1, int(value))

    @override
    def generate_ppo_data(self,
                          rollout_batches) -> Tuple[List[Dict[str, List[Any]]], Dict[str, float]]:
        """Attach GRPO training data to every rollout batch and compute PPO metrics.

        For each rollout batch (modified in place):
          * ``mask`` is created with ``create_mask`` if it is missing;
          * ``advantages`` / ``returns`` are computed with ``calculate_grpo_advantages``
            unless both are already present;
          * when ``args.ppo_initial_policy_kl_penalty > 0`` the KL between ``logprobs``
            and ``ref_logprobs`` is accumulated into ``ppo-metrics/init_policy_kl``.
            This metric is local to the DP rank.

        Global (data-parallel) mean/std/min/max statistics are then added for each of
        ``advantages``, ``returns``, ``values`` and ``per_token_rewards``. A key is only
        included when it is present and not ``None`` in the first batch.

        Args:
            rollout_batches: list of rollout-batch dicts whose values are per-sample
                lists. Required keys are ``prompt_lengths``, ``sequence_lengths``,
                ``rewards`` and ``logprobs``; ``ref_logprobs`` is also required when a
                KL penalty is used.

        Returns:
            ``(rollout_batches, metrics)``, with ``metrics`` passed through ``cpu_dict``.

        Raises:
            AssertionError: if ``use_grpo`` is off, if GRPO is used without an
                initial-policy KL penalty, or if advantages are not float32.
            NotImplementedError: advantage normalisation for non-GRPO PPO.
        """
        args = get_args()
        ppo_rollout_metrics = defaultdict(float)
        num_samples = 0
        print_with_rank_and_datetime('rollout generate_ppo_data', rank=0)

        # These depend only on the config, so check them once instead of per batch.
        # GRPO without an initial-policy KL penalty is disabled for now.
        assert not args.use_grpo or args.ppo_initial_policy_kl_penalty
        assert args.use_grpo, f"use_grpo {args.use_grpo} should be True"

        # TODO: consider padding logprobs / ref_logprobs to the max length once computed.
        # NOTE: the PPO-specific (non-GRPO) path still needs testing.
        for rollout_batch in rollout_batches:
            # NOTE: every item in a rollout batch, and every value computed here, must
            # have a leading batch (B) dimension.
            prompt_lengths = rollout_batch["prompt_lengths"]
            sequence_lengths = rollout_batch["sequence_lengths"]
            rewards = rollout_batch["rewards"]
            logprobs = rollout_batch["logprobs"]
            num_samples += len(prompt_lengths)

            try:
                mask = rollout_batch["mask"]
            except KeyError:
                mask = create_mask(values=logprobs,
                                   prompt_lengths=prompt_lengths,
                                   sequence_lengths=sequence_lengths,
                                   dtype=rewards[0].dtype)
                rollout_batch["mask"] = mask

            # The KL is only needed for the metric below. The former `else` branch
            # built `torch.tensor(0, dtype=logprobs.dtype, ...)`, which fails because
            # `logprobs` is a per-sample list (assumed from the List[Any] schema).
            init_policy_kl = None
            if args.ppo_initial_policy_kl_penalty > 0:
                init_policy_kl = calculate_kl_penalty(
                    log_probs_a=logprobs,
                    log_probs_b=rollout_batch["ref_logprobs"],
                    use_absolute_kl=args.ppo_use_absolute_kl,
                )

            try:
                advantages, returns = rollout_batch["advantages"], rollout_batch["returns"]
            except KeyError:
                advantages, returns = calculate_grpo_advantages(
                    rewards=rewards,
                    mask=mask,
                    grpo_sampling_times=args.ppo_sampling_keep,
                    grpo_advantage_epsilon=args.grpo_advantage_epsilon,
                )
                rollout_batch["returns"] = returns
                rollout_batch["advantages"] = advantages
            assert advantages[0].dtype == torch.float32

            # NOTE: this metric is not accumulated globally, so it differs between DP ranks.
            if init_policy_kl is not None:
                ppo_rollout_metrics["ppo-metrics/init_policy_kl"] += masked_mean_list(
                    init_policy_kl, mask, dim=-1).sum().item()
        assert check_rollout_batches(rollout_batches), f"check rbs fmt error {rollout_batches=}"
        # Average the local (non-global) metrics over the local samples.
        ppo_rollout_metrics = {k: v / num_samples for k, v in ppo_rollout_metrics.items()}

        mask_list = [m for rollout_batch in rollout_batches for m in rollout_batch["mask"]]
        for key in ["advantages", "returns", "values", 'per_token_rewards']:
            if key not in rollout_batches[0] or rollout_batches[0][key][0] is None:
                continue
            tensor_list = [t for rollout_batch in rollout_batches for t in rollout_batch[key]]
            global_mean, global_var, min_var, max_var = masked_global_statistics_list(
                tensor_list,
                mask_list,
                group=parallel_state.get_data_parallel_group(),
            )
            ppo_rollout_metrics[f"ppo-metrics/global_{key}_mean"] = global_mean.item()
            ppo_rollout_metrics[f"ppo-metrics/global_{key}_std"] = global_var.sqrt().item()
            ppo_rollout_metrics[f"ppo-metrics/global_{key}_min"] = min_var.item()
            ppo_rollout_metrics[f"ppo-metrics/global_{key}_max"] = max_var.item()
        if args.ppo_normalize_advantages and not args.use_grpo:
            # TODO: normalise "advantages" in place inside rollout_batches
            # (see masked_global_mean_var_list), e.g. via normalize_tensor over the
            # data-parallel group.
            # `raise NotImplemented` raised a TypeError; NotImplementedError is the intended signal.
            raise NotImplementedError(
                "ppo_normalize_advantages is not implemented for non-GRPO PPO")

        return rollout_batches, cpu_dict(ppo_rollout_metrics)

    @override
    @torch.no_grad()
    def generate_rollouts(self, rollout_get_batch_func, filter_samplings_func, actor_model,
                          rm_critic_client, gen_rm_client, sampler_client, dataloader_iter,
                          ppo_step,
                          num_microbatches) -> Tuple[List[Dict[str, List[Any]]], Dict[str, float]]:
        """Produce one PPO step's rollout batches together with their metrics.

        Steps:
          1. put the actor into inference mode (``prepare_for_inference``);
          2. run ``run_inference``;
          3. run ``generate_ppo_data`` under the ``generate_ppo_data`` timer;
          4. optionally dump the first rollout on global rank 0;
          5. leave inference mode (``finish_inference``).

        Returns:
            ``(rollout_batches, rollout_metrics | ppo_rollout_metrics)``.
        """
        timers = get_timers()
        args = get_args()
        actor_model.prepare_for_inference()

        rollout_batches, rollout_metrics = self.run_inference(rollout_get_batch_func,
                                                              filter_samplings_func, actor_model,
                                                              rm_critic_client, gen_rm_client,
                                                              sampler_client, dataloader_iter,
                                                              ppo_step, num_microbatches)

        timers('generate_ppo_data', log_level=0).start(barrier=True)
        rollout_batches, ppo_rollout_metrics = self.generate_ppo_data(rollout_batches)
        timers('generate_ppo_data').stop()

        # As the flag name says, only the FIRST rollout is dumped. Without this guard,
        # every PPO step re-serialised the full batches and overwrote the files.
        if args.ppo_save_first_rollout_data and not getattr(
                self, "_first_rollout_data_saved", False):
            if torch.distributed.get_rank() == 0:
                debug_dir = "./debug-tmp"
                os.makedirs(debug_dir, exist_ok=True)
                torch.save(rollout_batches, f"{debug_dir}/first_rollout_data_rank0.pt")
                torch.save(rollout_metrics, f"{debug_dir}/first_rollout_metrics_rank0.pt")
            self._first_rollout_data_saved = True

        actor_model.finish_inference()
        return rollout_batches, rollout_metrics | ppo_rollout_metrics


    @override
    def is_rollout_batch_accepted(self, rb):
        """Report whether a single rollout batch was flagged as useful.

        Args:
            rb: one rollout batch (a dict). ``rb['sample_useful']`` must be indexable,
                and its first element must be a tensor.

        Returns:
            The Python scalar held by ``rb['sample_useful'][0]``, via ``.item()``.
        """
        useful_flags = rb['sample_useful']
        first_flag = useful_flags[0]
        return first_flag.item()

    @override
    def update_replay_samples_dict(self, rb, sample_idx):
        """Register a rejected sample for replay, or re-queue or drop a known one.

        Args:
            rb: a single rollout batch (dict) holding the sample. ``rb['prompt_lengths']``
                must be a 1-D tensor. The prompt is taken as the first
                ``prompt_lengths[0]`` tokens of ``rb['response_tokens'][0]``.
            sample_idx: key of the sample in ``self.replay_samples_dict``.

        Side effects:
            * unknown ``sample_idx``: its prompt data is stored with ``replay_times=1``
              and it is appended to ``self.replay_queue`` (first enqueue);
            * known ``sample_idx`` whose ``replay_times`` exceeds
              ``args.ppo_dynamic_sampling_max_replay``: the entry is popped and the
              give-up is logged;
            * otherwise: ``replay_times`` is incremented and the sample is re-queued.

        Returns:
            None.
        """
        max_replay_times = get_args().ppo_dynamic_sampling_max_replay

        if sample_idx in self.replay_samples_dict:
            record = self.replay_samples_dict[sample_idx]
            if record.replay_times <= max_replay_times:
                # Still within budget: count it and put it back in the queue.
                record.replay_times += 1
                self.replay_queue.append(sample_idx)
                return
            # Over the replay limit: give up on this sample.
            # NOTE(review): replay_times starts at 1 and the test uses `>`, so a sample
            # can be queued max_replay_times + 1 times in total. Confirm this is intended.
            removed_value = self.replay_samples_dict.pop(sample_idx)
            print_with_rank_and_datetime(
                f"failure. give up replay sample {sample_idx=} replay_times {removed_value.replay_times}"
            )
            return

        lens = rb['prompt_lengths']
        assert lens.ndim == 1
        first_prompt_len = lens.tolist()[0]
        prompt_tokens = rb["response_tokens"][0][:first_prompt_len]
        labels = rb['gt_label']
        progress = rb.get('train_data_consuming_progress', None)

        prompt_data = dict()
        prompt_data["prompt_token_ids"] = [dict(prompt_token_ids=prompt_tokens.tolist())]
        prompt_data["lpad_lens"] = lens[0].view(1)
        prompt_data["gt_label"] = labels[0].view(1)
        prompt_data["train_data_consuming_progress"] = progress

        self.replay_samples_dict[sample_idx] = SimpleNamespace(
            prompt_data=prompt_data, replay_times=1)
        # First time the sample enters the replay queue.
        self.replay_queue.append(sample_idx)

    @override
    def is_time_to_replay_samples(self, replay_queue, replay_samples_dict):
        """Decide, for each queued sample, whether it should be replayed now.

        Args:
            replay_queue: sized sequence of ``sample_idx`` values waiting for replay.
            replay_samples_dict: mapping ``{sample_idx: sample_data}``. The default
                policy does not use it.

        Returns:
            A list of bools aligned with ``replay_queue``. The current placeholder
            policy replays every queued sample immediately. Override this method to
            apply a custom rule.
        """
        pending = len(replay_queue)
        return [True for _ in range(pending)]
    
    @override
    def run_inference(
        self,
        rollout_get_batch_func,
        filter_samplings_func,
        actor_model,
        rm_critic_client,
        gen_rm_client,
        sampler_client,
        dataloader_iter,
        ppo_step,
        num_microbatches,
    ) -> Tuple[List[Dict[str, List[Any]]], Dict[str, float]]:
        """Generate rollout batches, attach log-probs, filter samplings and compute metrics.

        This runs per DP rank, so the rollout metrics have to be computed globally
        (``compute_global_rollout_metrics``).

        Steps:
          1. fetch rollout batches with ``get_rollout_batches``;
          2. onload the actor model, validate the batches on the MP/CP head and
             broadcast them with ``sampler_client.broadcast_rollout_batch``;
          3. store policy log-probs under ``logprobs``. When
             ``ppo_initial_policy_kl_penalty > 0``, also store reference log-probs
             under ``ref_logprobs``;
          4. when ``ppo_sampling_keep != ppo_sampling_repeat``, filter samplings in
             place with ``filter_samplings_func``. The pre-filter metrics are also
             reported with an ``-original`` suffix.

        Returns:
            ``(rollout_batches, rollout_metrics)``.
        """
        args = get_args()
        timers = get_timers()

        print_with_rank_and_datetime(f'rollout run_inference {num_microbatches=}', rank=0)
        torch.distributed.barrier()
        rollout_batches = self.get_rollout_batches(sampler_client, rm_critic_client, gen_rm_client,
                                                   ppo_step, num_microbatches,
                                                   rollout_get_batch_func, dataloader_iter)
        torch.distributed.barrier()

        onload_megatron_model(actor_model.model)
        torch.distributed.barrier()

        timers('broadcast_rollout_batches', log_level=0).start(barrier=True)
        # Validate rollout_batches on the rank that owns them before broadcasting.
        if is_mp_and_cp_head():
            assert check_rollout_batches(rollout_batches), f"rollout_batches fmt error: {rollout_batches=}"
            for rollout_batch in rollout_batches:
                # Parenthesised instead of the previous backslash continuation, which
                # ended in a trailing `\` followed by a blank line.
                assert (
                    len(rollout_batch) >= 3
                    and 'tokens' in rollout_batch
                    and 'prompt_lengths' in rollout_batch
                    and 'sequence_lengths' in rollout_batch
                    and 'rewards' in rollout_batch
                )

        print_memory_tracking("Memory tracking: before bcast data", rank=args.world_size-1)

        rollout_batches = sampler_client.broadcast_rollout_batch(rollout_batches)
        clear_memory()
        print_memory_tracking("Memory tracking: after bcast data", rank=args.world_size-1)

        rollout_batches = self.add_back_rollout_attr_before_sampling(rollout_batches)
        assert check_rollout_batches(rollout_batches), f"rollout_batches fmt error: {rollout_batches=}"
        timers('broadcast_rollout_batches').stop()
        print_meminfo_str("Track CPU Memory after get_rollout_batches: ")

        # Important: free memory before the log-prob forward passes.
        clear_memory()
        print_memory_tracking("Memory tracking: before get logps", rank=args.world_size-1)

        timers('compute_logps', log_level=0).start(barrier=True)
        if args.ppo_smart_pad_infer:
            policy_logprobs = actor_model.batch_get_policy_logprobs(rollout_batches)
        else:
            policy_logprobs = actor_model.get_policy_logprobs(rollout_batches)
        assert policy_logprobs is not None
        for logprobs, rb in zip(policy_logprobs, rollout_batches, strict=True):
            rb["logprobs"] = logprobs
        if args.do_monitor:
            mark_made_progress(args.monitor_server_ip, args.monitor_port)

        print_memory_tracking("Memory tracking: before get ref logps", rank=args.world_size-1)
        if args.ppo_initial_policy_kl_penalty > 0:
            if args.ppo_smart_pad_infer:
                init_policy_logprobs = actor_model.batch_get_ref_policy_logprobs(rollout_batches)
            else:
                init_policy_logprobs = actor_model.get_ref_policy_logprobs(rollout_batches)
            assert init_policy_logprobs is not None
            for ref_logprobs, rb in zip(init_policy_logprobs, rollout_batches, strict=True):
                rb["ref_logprobs"] = ref_logprobs

        clear_memory()
        print_memory_tracking("Memory tracking: after get ref logps", rank=args.world_size-1)
        timers('compute_logps').stop()
        if args.do_monitor:
            mark_made_progress(args.monitor_server_ip, args.monitor_port)
        print_meminfo_str("Track CPU Memory after get logprobs: ")

        timers('filter_sampling', log_level=0).start(barrier=True)
        assert check_rollout_batches(rollout_batches), f"rbs fmt error: {rollout_batches=}"
        rollout_metrics_original = None
        if args.ppo_sampling_keep != args.ppo_sampling_repeat:
            rollout_metrics_original = cpu_dict(self.compute_global_rollout_metrics(rollout_batches))
            filter_samplings_func(args, rollout_batches)
            expected_mbs = args.ppo_rollout_micro_batch_size * args.ppo_sampling_keep
            for rb in rollout_batches:
                assert len(rb['tokens']) == expected_mbs, \
                    f"filtered batch size {len(rb['tokens'])} != expected {expected_mbs}"
            # Validate the whole list once. Previously it was re-validated inside the
            # per-batch loop, which made the checks O(n^2).
            assert check_rollout_batches(rollout_batches), f"rbs fmt error: {rollout_batches=}"
        timers('filter_sampling').stop()

        timers('compute_global_rollout_metrics', log_level=0).start(barrier=True)
        rollout_metrics = cpu_dict(self.compute_global_rollout_metrics(rollout_batches))
        if rollout_metrics_original is not None:
            print_rank_0("[MATH][DEBUG] Updating rollout metrics from original rollout_batches")
            rollout_metrics.update({f"{k}-original": v for k, v in rollout_metrics_original.items()})
        timers('compute_global_rollout_metrics').stop()

        self.display_rollout_generation(rollout_batches)
        return rollout_batches, rollout_metrics


    def compute_pass_at_k_torch_vectorized(self, rewards_tensor):
        """Compute the unbiased pass@k estimate for every k in ``1..rrep``.

        For a row with ``n`` samples, ``c`` of them correct,
        ``pass@k = 1 - C(n - c, k) / C(n, k)``. It is evaluated in product form as
        ``1 - prod_{j=0}^{c-1} (1 - k / (n - c + 1 + j))``. Rows with fewer than ``k``
        failures get ``pass@k = 1``.

        Args:
            rewards_tensor: tensor of shape ``[batch_size, rrep]``. Each entry is
                expected to be 0 or 1, and ``c`` is the row sum.

        Returns:
            float64 tensor of shape ``[batch_size, rrep]``. Entry ``[i, k - 1]`` is the
            pass@k value of row ``i``, for ``k`` from 1 to ``rrep``.
        """
        num_rows, num_draws = rewards_tensor.shape
        # float64 keeps the long products below numerically stable.
        scores = rewards_tensor.to(torch.float64)
        device = scores.device

        num_correct = scores.sum(dim=1)  # c, shape [num_rows]
        # j = 0..n-1, shape [1, n]; independent of k, so built once.
        draw_offsets = torch.arange(num_draws, dtype=torch.float64, device=device).unsqueeze(0)

        result = torch.zeros_like(scores, dtype=torch.float64)
        for k in range(1, num_draws + 1):
            # With fewer than k failures, any k draws contain a success: pass@k = 1.
            certain = (num_draws - num_correct) < k
            column = torch.ones(num_rows, dtype=torch.float64, device=device)
            uncertain = ~certain
            if uncertain.any():
                c_col = num_correct[uncertain].unsqueeze(1)  # [m, 1]
                denominators = num_draws - c_col + 1 + draw_offsets  # [m, n]
                # Only j < c contributes; other positions become the neutral factor 1.
                factors = torch.where(draw_offsets < c_col, 1.0 - k / denominators, 1.0)
                column[uncertain] = 1.0 - torch.prod(factors, dim=1)
            result[:, k - 1] = column
        return result

    def compute_global_rollout_metrics_eval(self, rollout_batches: List[Dict[str, List[Any]]]):
        """Compute evaluation rollout metrics aggregated over the data-parallel group.

        Response lengths, prompt lengths, rewards and each extra metric in
        ``self.extra_metric_info`` are summed over the local batches. The per-prompt
        pass@k sums and the prompt count are summed as well. Everything is then
        all-reduced over the data-parallel group and divided by the global counts.

        Args:
            rollout_batches: local rollout batches. ``prompt_lengths``,
                ``sequence_lengths``, ``rewards`` and the extra-metric keys are lists
                of tensors. Each batch's rewards are reshaped to
                ``[ppo_rollout_micro_batch_size, ppo_sampling_repeat, -1]``, so a batch
                must hold ``rmbs * rrep`` samples grouped by prompt.

        Returns:
            Dict of metric name -> Python float. It includes
            ``passk/pass@{k}`` for ``k`` in ``1..ppo_sampling_repeat``, averaged over
            all prompts of the DP group.
        """
        metrics = defaultdict(lambda: 0)

        args = get_args()
        rmbs = args.ppo_rollout_micro_batch_size
        rrep = args.ppo_sampling_repeat

        rollout_batches_rewards = []

        num_samples = 0
        for rb in rollout_batches:
            prompt_lengths = torch.stack(rb["prompt_lengths"]).view(-1)
            sequence_lengths = torch.stack(rb["sequence_lengths"]).view(-1)
            rewards = torch.stack(rb["rewards"]).view(-1)
            metrics["sequence_lengths"] += (sequence_lengths - prompt_lengths).sum()
            metrics["prompt_lengths"] += prompt_lengths.sum()
            metrics["rewards"] += rewards.sum()

            # Group the samples by prompt: [rmbs, rrep, reward_dim].
            rollout_batches_rewards.append(rewards.view(rmbs, rrep, -1))

            for em_info in self.extra_metric_info:
                em_k = em_info['key_name']
                metrics[em_k] += torch.stack(rb[em_k]).view(-1).sum()
            num_samples += prompt_lengths.size(0)

        if rollout_batches_rewards:
            # shape: [rbs, rmbs, rrep, reward_dim]
            rollout_batches_rewards = torch.stack(rollout_batches_rewards, dim=0)
            print_rank_0(f"[MATH][DEBUG] {rollout_batches_rewards.shape=}")
            # shape: [rbs * rmbs, rrep], one row per prompt.
            rollout_batches_pass_at_k = self.compute_pass_at_k_torch_vectorized(
                rollout_batches_rewards.flatten(start_dim=0, end_dim=1).mean(dim=-1))
        else:
            # No local batches. This rank still has to join the all-reduce below,
            # otherwise the other DP ranks hang; torch.stack([]) used to crash here.
            rollout_batches_pass_at_k = torch.zeros(0, rrep, dtype=torch.float64)

        if args.ppo_wecube_report and is_mp_and_cp_head():
            biz_report_data = {
                "prompt_lengths": metrics["prompt_lengths"],
                "sequence_lengths": metrics["sequence_lengths"],
                "num_samples": num_samples,
            }
            biz_report_data = {
                k: (v.item() if torch.is_tensor(v) else v) for k, v in biz_report_data.items()
            }
            report_ppo_metrics(biz_report_data)

        tensor_to_accumulate = [
            metrics["sequence_lengths"],
            metrics["prompt_lengths"],
            metrics["rewards"],
            num_samples,
        ]
        for em_info in self.extra_metric_info:
            tensor_to_accumulate.append(metrics[em_info['key_name']])
        # pass@k is averaged over prompts, so its per-k sums and the prompt count are
        # reduced too. Previously only the local DP shard's pass@k was reported.
        pass_at_k_offset = len(tensor_to_accumulate)
        tensor_to_accumulate.append(rollout_batches_pass_at_k.size(0))
        tensor_to_accumulate.extend(rollout_batches_pass_at_k.sum(dim=0).tolist())
        tensor_to_accumulate = torch.tensor(
            tensor_to_accumulate,
            dtype=torch.float32,
            device=torch.cuda.current_device(),
        )
        torch.distributed.all_reduce(tensor_to_accumulate,
                                     group=parallel_state.get_data_parallel_group())

        tensor_to_accumulate_as_list = tensor_to_accumulate.tolist()
        (
            global_response_lengths,
            global_prompt_lengths,
            global_rewards,
            global_num_samples,
        ) = tensor_to_accumulate_as_list[:4]
        metrics = {
            "rollout-metrics/global_response_lengths_mean":
            global_response_lengths / global_num_samples,
            "rollout-metrics/global_prompt_lengths": global_prompt_lengths / global_num_samples,
            "rollout-metrics/global_rewards": global_rewards / global_num_samples,
        }
        for em_i, em_info in enumerate(self.extra_metric_info):
            em_k = em_info['key_name']
            metrics[f"rollout-metrics/global_{em_k}"] = tensor_to_accumulate_as_list[
                4 + em_i] / global_num_samples

        global_num_prompts = tensor_to_accumulate_as_list[pass_at_k_offset]
        global_pass_at_k_sums = tensor_to_accumulate_as_list[pass_at_k_offset + 1:]
        for i, pass_sum in enumerate(global_pass_at_k_sums):
            metrics[f"passk/pass@{i+1}"] = pass_sum / global_num_prompts

        return metrics


    @override
    def eval_run_inference(
        self,
        actor_model,
        sampler_client,
        rm_critic_client,
        gen_rm_client,
        ppo_eval_step,
        eval_num_rollout_microbatches,
        rollout_get_batch_func,
        eval_dataloader_iter,
    ) -> Tuple[List[Dict[str, List[Any]]], Dict[str, float]]:
        """Run one evaluation rollout pass and compute its DP-global metrics.

        Steps:
          1. fetch eval rollout batches with ``get_rollout_batches(..., is_eval=True)``;
          2. for the sglang engine, push the actor weights to the sampler again. If
             ``ppo_early_swap_model`` is off, the model is onloaded before the push
             and offloaded after it;
          3. validate the batches on the MP/CP head, broadcast them and restore the
             pre-sampling attributes;
          4. compute ``compute_global_rollout_metrics_eval`` and prefix every key
             with ``eval-``.

        Returns:
            ``(eval_rollout_batches, eval_rollout_metrics)``.
        """
        args = get_args()
        timers = get_timers()

        print_with_rank_and_datetime(f"eval_run_inference {eval_num_rollout_microbatches=}", rank=0)
        torch.distributed.barrier()
        eval_rollout_batches = self.get_rollout_batches(
            sampler_client, rm_critic_client, gen_rm_client,
            ppo_eval_step, eval_num_rollout_microbatches,
            rollout_get_batch_func, eval_dataloader_iter,
            is_eval=True,
        )
        torch.distributed.barrier()

        # sglang (0.4.6.post5) cannot keep its weights across sleep, so they have to be
        # pushed again with update_weights.
        if args.infer_engine_impl == 'sglang':
            early_swap = args.ppo_early_swap_model
            if not early_swap:
                onload_megatron_model(actor_model.model)
                torch.distributed.barrier()

            timers('eval_update_sampler_weights', log_level=0).start(barrier=True)
            print_with_rank_and_datetime('eval_update_sampler_weights', rank=0)
            sampler_client.update_weights(
                actor_model.model,
                early_swap_model=early_swap,
                cpu_memory_model=self.cpu_memory_model,
            )
            timers('eval_update_sampler_weights').stop()

            if not early_swap:
                offload_megatron_model(actor_model.model)
                torch.distributed.barrier()

        timers('broadcast_eval_rollout_batches', log_level=0).start(barrier=True)
        # Only the owning head validates the batches before they are broadcast.
        if is_mp_and_cp_head():
            assert check_rollout_batches(eval_rollout_batches), f'eval_rollout_batches fmt error: {eval_rollout_batches=}'
            required_keys = ("tokens", "prompt_lengths", "sequence_lengths", "rewards")
            for batch in eval_rollout_batches:
                assert len(batch) >= 3 and all(key in batch for key in required_keys)

        eval_rollout_batches = sampler_client.broadcast_rollout_batch(eval_rollout_batches)
        eval_rollout_batches = self.add_back_rollout_attr_before_sampling(eval_rollout_batches)
        assert check_rollout_batches(eval_rollout_batches), f'eval_rollout_batches fmt error: {eval_rollout_batches=}'
        timers('broadcast_eval_rollout_batches').stop()

        timers('compute_global_eval_rollout_metrics', log_level=0).start(barrier=True)
        raw_metrics = cpu_dict(self.compute_global_rollout_metrics_eval(eval_rollout_batches))
        eval_rollout_metrics = {'eval-' + name: val for name, val in raw_metrics.items()}
        timers('compute_global_eval_rollout_metrics').stop()

        self.display_rollout_generation(eval_rollout_batches)
        return eval_rollout_batches, eval_rollout_metrics

    
