              
                                                      
                                                                 

import sys
from functools import partial
from typing import List, Dict, Any, Tuple
from packaging.version import Version

import torch

from megatron.core import package_info
from megatron.core import mpu, parallel_state, InferenceParams
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core import tensor_parallel
from megatron.core.utils import divide, get_model_config
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from gpatch.core.device_type import is_wxacc1
from gpatch.core.aligner_interface import AlignableGenerativeInterface, ResetArgsMixin
from gpatch.core.ppo_helper import vocab_parallel_entropy, calculate_kl_loss
from gpatch.core.tensor_parallel.mappings import all_gather_to_context_parallel_region
from gpatch.core.aligner_helper import (
    cpu_weight_swap_v2,
    get_ltor_masks_and_position_ids,
    get_iterator_k_split,
    get_iterator_k_split_list,
    from_parallel_logits_to_logprobs,
    broadcast_2d_tensor_within_pp,
    broadcast_object_within_pp,
    expand_rollout_batches,
    masked_mean,
    average_losses_across_data_parallel_group,
    get_last_rank,
    get_gbs_batches_seqlen,
    retrieve_model_state_dict_in_cpu,
    get_tensor_on_this_cp_rank,
    get_max_seqlen_within_ep,
    get_max_seqlen_within_dp,
    clear_memory,
)

from gpatch.core.utils import print_with_rank_and_datetime, print_memory_tracking
from gpatch.core.smart_pad_helper import (
    GroupSmartPadInferHelper,
    get_split_batchs,
    preprocess_packed_seqs,
    postprocess_packed_seqs,
)
from gpatch.core.wecube import (
    report_ppo_metrics
)


                      
class GptPpoActorModel(ResetArgsMixin, AlignableGenerativeInterface):

    def __init__(
        self,
        model,
        ref_model_state,
        unwrap_model_func,
        forward_micro_batch_size=1,
        gen_left_pad=False,
        ppo_rollout_temperature=1.0,
        pad_to_multi_of=None,
        pad_token_id=None,
        dynamic_mbs_target_seqlen: int | None = None,
        dynamic_mbs_limit: int | None = None,
        ppo_pack_seq: bool = False,
    ):
        """Store the wrapped model, the reference weights and PPO/GRPO settings.

        Args:
            model: list holding exactly one model chunk (asserted below).
            ref_model_state: state dict kept as ``init_policy_state_dict``;
                it is swapped into the model when reference log-probs are
                computed.
            unwrap_model_func: callable used to strip wrappers off the model.
            forward_micro_batch_size: micro-batch size for inference-only
                log-prob passes.
            gen_left_pad: stored as-is; not read inside this class.
            ppo_rollout_temperature: stored as-is; not read inside this class.
            pad_to_multi_of: forwarded to the sequence-length helpers.
            pad_token_id: fill value used when right-padding token tensors.
            dynamic_mbs_target_seqlen: when set, ``get_loss_and_metrics``
                derives the micro-batch size from the sequence length.
            dynamic_mbs_limit: upper bound for the derived micro-batch size.
            ppo_pack_seq: enables the packed-sequence forward path in the
                GRPO forward step.
        """
        # --- model handles -------------------------------------------------
        self.model = model
        assert len(self.model) == 1
        cfg = get_model_config(self.model[0])
        self.config = cfg
        self.init_policy_state_dict = ref_model_state
        self.unwrap_model_func = unwrap_model_func

        # --- constructor arguments and runtime placeholders ------------------
        self.forward_micro_batch_size = forward_micro_batch_size
        self.gen_left_pad = gen_left_pad
        self.distributed_adam_offload_manager = None
        self._optimizer = None

        # --- loss hyper-parameters read from the model config --------------
        self.entropy_bonus = cfg.ppo_entropy_bonus
        self.ratio_eps = cfg.ppo_ratio_eps
        self.clip_ratio_low = cfg.ppo_clip_ratio_low
        self.clip_ratio_high = cfg.ppo_clip_ratio_high
        self.clamp_kl_val = cfg.ppo_clamp_kl_val
        self.logps_ratio_clamp = cfg.ppo_logps_ratio_clamp
        self.dual_clip_ratio_c = cfg.ppo_dual_clip_ratio_c
        self.use_grpo = cfg.use_grpo
        self.use_gspo_loss = cfg.use_gspo_loss
        self.ppo_rollout_temperature = ppo_rollout_temperature
        if self.dual_clip_ratio_c is not None:
            assert self.dual_clip_ratio_c > 1.0, "dual-clip PPO should be greater than 1.0"

        # --- padding / batching knobs --------------------------------------
        self.pad_to_multi_of = pad_to_multi_of
        self.pad_token_id = pad_token_id
        self.dynamic_mbs_target_seqlen = dynamic_mbs_target_seqlen
        self.dynamic_mbs_limit = dynamic_mbs_limit
        self.ppo_pack_seq = ppo_pack_seq

        # --- progress-logging state used by the log-prob forward step ------
        self.batch_iters = 0
        self.total_iters = 0
        self.batch_log_str = ""

    def prepare_for_inference(self):
        for model_module in self.model:
            model_module.eval()

    def finish_inference(self):
        for model_module in self.model:
            model_module.train()

    def infer(self, *args, **kwargs):
        raise NotImplementedError('not implemented yet')

    def prepare_for_training_step(self):
        for model_module in self.model:
            model_module.train()

    def finish_training_step(self):
        pass

    def get_loss_and_metrics(
        self,
        batch: List[Dict[str, Any]],
        num_microbatches: int,
        forward_only: bool,
    ):
        """Run the forward(/backward) schedule over ``batch`` and average metrics.

        When ``dynamic_mbs_target_seqlen`` is set, the micro-batch size is
        derived from the padded sequence length (optionally capped by
        ``dynamic_mbs_limit``) and lowered until it divides ``len(batch)``;
        otherwise the caller-supplied ``num_microbatches`` and the configured
        micro-batch size are used.

        Args:
            batch: per-sample dicts making up this step's batch.
            num_microbatches: micro-batch count used when dynamic
                micro-batching is not in effect.
            forward_only: passed through to the forward/backward schedule.

        Returns:
            ``(loss, metrics)``: ``metrics`` maps each metric name to a Python
            number (plus ``"seq_length"``); ``loss`` is ``metrics["loss"]``.
        """
        use_dynamic_target = self.dynamic_mbs_target_seqlen is not None

        seq_length = get_gbs_batches_seqlen(batch, self.pad_to_multi_of)
        if use_dynamic_target:
            seq_length = get_max_seqlen_within_dp(seq_length)
        else:
            seq_length = get_max_seqlen_within_ep(seq_length)
        dynamic_num_microbatches = 0
        dynamic_mbs = 0
        batch_size = len(batch)

        if use_dynamic_target:
            dynamic_mbs = self.dynamic_mbs_target_seqlen // seq_length * \
                self.config.micro_batch_size
            if dynamic_mbs == 0:
                dynamic_mbs = 1
            # The limit is optional (constructor default is None); only cap
            # the micro-batch size when a limit was actually configured.
            if self.dynamic_mbs_limit is not None:
                dynamic_mbs = min(self.dynamic_mbs_limit, dynamic_mbs)
            # Walk down to the largest micro-batch size that evenly divides
            # the batch.
            while dynamic_mbs >= 1:
                if batch_size % dynamic_mbs == 0:
                    dynamic_num_microbatches = batch_size // dynamic_mbs
                    break
                dynamic_mbs -= 1
            if dynamic_num_microbatches > 0:
                assert dynamic_num_microbatches * \
                    dynamic_mbs == batch_size, f"{dynamic_mbs=} {dynamic_num_microbatches=} {batch_size=} mismatch!"

        enable_dynamic_mbs = dynamic_num_microbatches > 0
        print_with_rank_and_datetime(
            f"[TRAIN] {seq_length=} {batch_size=} {num_microbatches=} {enable_dynamic_mbs=} {dynamic_num_microbatches=} {dynamic_mbs=}", rank=0)

        effective_num_microbatches = (
            dynamic_num_microbatches if enable_dynamic_mbs else num_microbatches
        )
        effective_mbs = dynamic_mbs if enable_dynamic_mbs else self.config.micro_batch_size
        data_iter = get_iterator_k_split_list(batch, effective_num_microbatches)

        fwd_bwd_function = get_forward_backward_func()
        if self.use_grpo:
            forward_step_func = self.get_actor_grpo_forward_output_and_loss_func(
                seq_length)
        else:
            forward_step_func = self.get_actor_forward_output_and_loss_func()
        losses_reduced_per_micro_batch = fwd_bwd_function(
            forward_step_func=forward_step_func,
            data_iterator=data_iter,
            model=self.model,
            num_microbatches=effective_num_microbatches,
            forward_only=forward_only,
            seq_length=seq_length,
            decoder_seq_length=seq_length,
            micro_batch_size=effective_mbs,
        )
        clear_memory()

        metrics = {"seq_length": seq_length}
        metrics_key = ["loss", "ppo_ratio",
                       "ppo_ratio_clamped", "scaled_entropy"]
        if self.use_grpo:
            metrics_key += ["grpo_kl_loss"]
        for key in metrics_key:
            if losses_reduced_per_micro_batch:
                metric_mean = torch.stack(
                    [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch]).mean()
            else:
                # Ranks whose schedule returned no per-micro-batch results use
                # a placeholder that the broadcast below overwrites.
                metric_mean = torch.tensor(
                    0.0, device=torch.cuda.current_device())
            torch.distributed.broadcast(metric_mean, get_last_rank())
            metrics[key] = metric_mean.cpu().item()
        return metrics["loss"], metrics

    def get_actor_forward_output_and_loss_func(self):
        """Build the forward-step callable for the clipped PPO actor loss.

        Returns:
            A function ``(data_iterator, model) -> (parallel_logits, loss_func)``
            in the form expected by the forward/backward schedule.
            ``loss_func`` returns ``(loss, metrics_dict)`` with the keys
            ``loss``, ``ppo_ratio``, ``ppo_ratio_clamped`` and
            ``scaled_entropy``.
        """

        def fwd_output_and_loss_func(data_iterator, model):
            raw_batch = next(data_iterator)
            non_blocking = not is_wxacc1()
            response_tokens = raw_batch["response_tokens"].cuda(
                non_blocking=non_blocking)
            advantages = raw_batch["advantages"]
            mask = raw_batch["mask"]
            prev_logprobs = raw_batch["prev_logprobs"]

            attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
                data=response_tokens,
                eod_token=0,
                reset_position_ids=False,
                reset_attention_mask=False,
                eod_mask_loss=False,
                compute_attention_mask=False,
            )

            batch = {
                "tokens": response_tokens,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "advantages": advantages,
                "prev_log_probs": prev_logprobs,
                "mask": mask,
            }
            # Decide which tensors this pipeline stage actually needs.
            required_keys = set()
            if parallel_state.get_pipeline_model_parallel_world_size() == 1:
                required_keys.update(batch.keys())
            else:
                required_keys.add("attention_mask")
                if parallel_state.is_pipeline_first_stage():
                    required_keys.update(("tokens", "position_ids"))
                if parallel_state.is_pipeline_last_stage():
                    required_keys.update(
                        ("tokens", "advantages", "mask", "prev_log_probs"))

            # Move only the required tensors to the device. Entries that are
            # already None (e.g. an attention mask that was not computed) are
            # kept as None instead of calling ``.cuda()`` on them, matching
            # the guard used by the GRPO data preparation.
            batch = {
                key: (
                    val.cuda(non_blocking=non_blocking)
                    if key in required_keys and val is not None
                    else None
                )
                for key, val in batch.items()
            }

            parallel_logits = model(
                batch["tokens"],
                batch["position_ids"],
                batch["attention_mask"],
                labels=None,
            )

            def loss_func(parallel_logits):
                parallel_logits = parallel_logits.float()
                mask = batch["mask"]
                advantages = batch["advantages"]
                prev_log_probs = batch["prev_log_probs"]
                tokens = batch["tokens"]
                assert advantages.dtype == torch.float32
                assert prev_log_probs.dtype == torch.float32
                # Separate copy of the logits for the entropy term.
                parallel_logits_clone = parallel_logits.clone()

                # Fused cross-entropy is not supported here; the previously
                # present fused branch was unreachable behind this assert.
                assert not self.config.cross_entropy_loss_fusion
                curr_log_probs = from_parallel_logits_to_logprobs(
                    vocab_parallel_logits=parallel_logits, target=tokens)

                scaled_entropy = vocab_parallel_entropy(
                    parallel_logits_clone, mask)

                # Probability ratio between current and rollout policy,
                # clipped symmetrically around 1.
                ratios = (curr_log_probs - prev_log_probs).exp()
                ratios_clamped = ratios.clamp(
                    1.0 - self.ratio_eps, 1.0 + self.ratio_eps)

                loss1 = -advantages * ratios
                loss2 = -advantages * ratios_clamped
                actor_loss = masked_mean(torch.max(loss1, loss2), mask)
                loss = actor_loss - scaled_entropy * self.entropy_bonus

                with torch.no_grad():
                    ppo_ratio = masked_mean(ratios.detach(), mask)
                    ppo_ratio_clamped = masked_mean(
                        ratios_clamped.detach(), mask)
                    scaled_entropy = scaled_entropy.detach()

                reduced_actor_loss = average_losses_across_data_parallel_group(
                    [loss])
                return (
                    loss,
                    {
                        "loss": reduced_actor_loss,
                        "ppo_ratio": ppo_ratio,
                        "ppo_ratio_clamped": ppo_ratio_clamped,
                        "scaled_entropy": scaled_entropy,
                    },
                )

            return parallel_logits, loss_func

        return fwd_output_and_loss_func

    def prepare_data_for_grpo_loss(
        self,
        batches: List[Dict[str, Any]],
        seqlen: int,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Pad, stack and device-place one micro-batch for the GRPO loss.

        Args:
            batches: per-sample dicts providing ``tokens``, ``advantages``,
                ``mask``, ``logprobs``, ``ref_logprobs`` and
                ``sequence_lengths``.
            seqlen: width that ``tokens`` is right-padded to; the remaining
                per-token fields are padded to ``seqlen - 1``.

        Returns:
            ``(batch, fwd_kwargs)``: ``fwd_kwargs`` holds the model inputs
            (``input_ids``, ``position_ids``, ``attention_mask``, ``labels``)
            and ``batch`` the remaining loss inputs. Entries not needed on
            this pipeline stage are ``None``.
        """
        non_blocking = not is_wxacc1()

        def _right_pad(tensor, width, fill):
            # Pad the last dimension on the right up to ``width``.
            return torch.nn.functional.pad(
                tensor, (0, width - tensor.shape[-1]), value=fill)

        # (source key, padded width, fill value)
        pad_specs = (
            ('tokens', seqlen, self.pad_token_id),
            ('advantages', seqlen - 1, 0),
            ('mask', seqlen - 1, 0),
            ('logprobs', seqlen - 1, 0),
            ('ref_logprobs', seqlen - 1, 0),
        )
        padded = {name: [] for name, _, _ in pad_specs}
        seq_len_rows = []
        for sample in batches:
            for name, width, fill in pad_specs:
                padded[name].append(_right_pad(sample[name], width, fill))
            seq_len_rows.append(sample['sequence_lengths'])

        tokens = torch.stack(padded['tokens']).cuda(non_blocking=non_blocking)
        advantages = torch.stack(padded['advantages'])
        mask = torch.stack(padded['mask'])
        logprobs = torch.stack(padded['logprobs'])
        ref_logprobs = torch.stack(padded['ref_logprobs'])
        sequence_lengths = torch.stack(seq_len_rows)

        attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
            data=tokens,
            eod_token=0,
            reset_position_ids=False,
            reset_attention_mask=False,
            eod_mask_loss=False,
            compute_attention_mask=False,
        )
        # Keep the full-length tokens as the log-prob target before any
        # context-parallel slicing of the model inputs.
        target = tokens.detach().clone()

        shard_for_cp = self.config.context_parallel_size > 1 and not self.ppo_pack_seq
        if shard_for_cp:
            tokens = get_tensor_on_this_cp_rank(tokens, 1, key_name="tokens")
            attention_mask = get_tensor_on_this_cp_rank(
                attention_mask, 2, key_name="attention_mask")
            position_ids = get_tensor_on_this_cp_rank(
                position_ids, 1, key_name="position_ids")

        batch = {
            "tokens": tokens,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "advantages": advantages,
            "prev_log_probs": logprobs,
            "mask": mask,
            "ref_log_probs": ref_logprobs,
            'target': target,
            'sequence_lengths': sequence_lengths,
        }

        # Work out which entries this pipeline stage has to keep.
        if parallel_state.get_pipeline_model_parallel_world_size() == 1:
            required_keys = set(batch.keys())
        else:
            required_keys = {"attention_mask", "sequence_lengths"}
            if parallel_state.is_pipeline_first_stage():
                required_keys.update(("tokens", "position_ids"))
            if parallel_state.is_pipeline_last_stage():
                required_keys.update(
                    ("tokens", "advantages", "mask", "prev_log_probs", "ref_log_probs", 'target'))

        placed = {}
        for key, val in batch.items():
            if key in required_keys and val is not None:
                placed[key] = val.cuda(non_blocking=non_blocking)
            else:
                placed[key] = None
        batch = placed

        fwd_kwargs = dict(
            input_ids=batch.pop("tokens"),
            position_ids=batch.pop("position_ids"),
            attention_mask=batch.pop("attention_mask"),
            labels=None,
        )
        return batch, fwd_kwargs

    def get_actor_grpo_forward_output_and_loss_func(self, seqlen: int):
        """Build the forward-step callable for the GRPO (optionally GSPO) loss.

        Args:
            seqlen: padded sequence length handed to
                ``prepare_data_for_grpo_loss`` for every micro-batch.

        Returns:
            A function ``(data_iterator, model) -> (parallel_logits, loss_func)``
            in the form expected by the forward/backward schedule.
            ``loss_func`` returns ``(bwd_loss, metrics_dict)`` with the keys
            ``loss``, ``ppo_ratio``, ``ppo_ratio_clamped``, ``scaled_entropy``
            and ``grpo_kl_loss``.
        """

        def fwd_output_and_loss_func(seqlen, data_iterator, model):
            batches: List[Dict[str, Any]] = next(data_iterator)
            unwrapped_model = self.unwrap_model_func(model)

            batch, fwd_kwargs = self.prepare_data_for_grpo_loss(
                batches, seqlen)
            for key in ["mask", "advantages", "prev_log_probs", "ref_log_probs", "target"]:
                assert key in batch

            if not self.ppo_pack_seq:
                parallel_logits = model(**fwd_kwargs)
            else:
                # The packed path is not used with 'mrope' position embeddings
                # unless multi-latent attention is enabled.
                assert not (
                    unwrapped_model.position_embedding_type == 'mrope' and
                    not unwrapped_model.config.multi_latent_attention
                )
                # Build a 0/1 mask marking the non-padding positions of each
                # row from the per-sample sequence lengths.
                cur_mbs, cur_max_seqlen = fwd_kwargs['input_ids'].shape[:2]
                cur_actual_seqlen = batch['sequence_lengths'].unsqueeze(
                    1).expand(-1, cur_max_seqlen)
                tmpa = torch.arange(cur_max_seqlen, device='cuda', dtype=torch.int32).unsqueeze(
                    0).expand(cur_mbs, -1)
                pad_mask = torch.ones(
                    cur_mbs, cur_max_seqlen, device='cuda', dtype=torch.int32)
                pad_mask[tmpa >= cur_actual_seqlen] = 0
                input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(
                    fwd_kwargs['input_ids'],
                    pad_mask,
                    pre_process=unwrapped_model.pre_process,
                )
                input_ids_rmpad = input_ids_rmpad.contiguous()
                output_rmpad = model(
                    input_ids=input_ids_rmpad,
                    position_ids=fwd_kwargs['position_ids'],
                    attention_mask=None,
                    labels=None,
                    packed_seq_params=packed_seq_params,
                )
                parallel_logits = postprocess_packed_seqs(
                    output_rmpad,
                    packed_seq_params,
                    pad_mask,
                    cur_mbs,
                    cur_max_seqlen,
                    post_process=unwrapped_model.post_process,
                )

            if isinstance(parallel_logits, tuple):
                parallel_logits = parallel_logits[0]
            assert isinstance(parallel_logits, torch.Tensor)

            def loss_func(parallel_logits):
                parallel_logits = parallel_logits.float()
                mask = batch["mask"]
                advantages = batch["advantages"]
                prev_log_probs = batch["prev_log_probs"]
                ref_log_probs = batch["ref_log_probs"]
                assert advantages.dtype == torch.float32
                assert prev_log_probs.dtype == torch.float32
                target = batch["target"]

                # Separate copy of the logits for the entropy term.
                parallel_logits_clone = parallel_logits.clone()

                assert not self.config.cross_entropy_loss_fusion
                curr_log_probs = from_parallel_logits_to_logprobs(
                    vocab_parallel_logits=parallel_logits, target=target)

                scaled_entropy = vocab_parallel_entropy(
                    parallel_logits_clone, mask)

                if self.use_gspo_loss:
                    negative_approx_kl = curr_log_probs - prev_log_probs
                    # Sequence-level log ratio: masked mean of the per-token
                    # log ratios (denominator clamped so empty rows give 0).
                    seq_lengths = torch.sum(mask, dim=-1).clamp(min=1)
                    negative_approx_kl_seq = torch.sum(
                        negative_approx_kl * mask, dim=-1) / seq_lengths

                    # Value equals the detached sequence-level log ratio,
                    # while the gradient flows through the per-token
                    # ``curr_log_probs`` term only.
                    log_seq_importance_ratio = curr_log_probs - \
                        curr_log_probs.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)
                    # Upper bound on the log ratio before exponentiating.
                    log_seq_importance_ratio = torch.clamp(
                        log_seq_importance_ratio, max=10.0)

                    ratios = torch.exp(log_seq_importance_ratio)
                else:
                    # Token-level ratio, optionally clamping the log ratio
                    # symmetrically before exponentiating.
                    if self.logps_ratio_clamp is not None:
                        ratios = torch.clamp(curr_log_probs - prev_log_probs,
                                             min=-self.logps_ratio_clamp,
                                             max=self.logps_ratio_clamp).exp()
                    else:
                        ratios = (curr_log_probs - prev_log_probs).exp()

                # Asymmetric clip bounds fall back to the symmetric epsilon.
                clip_ratio_low = self.clip_ratio_low if self.clip_ratio_low is not None else self.ratio_eps
                clip_ratio_high = self.clip_ratio_high if self.clip_ratio_high is not None else self.ratio_eps
                ratios_clamped = ratios.clamp(
                    1.0 - clip_ratio_low, 1.0 + clip_ratio_high)

                loss1 = -advantages * ratios
                loss2 = -advantages * ratios_clamped
                clip_max_loss = torch.maximum(loss1, loss2)
                # Dual clip: for negative advantages additionally bound the
                # loss by ``-advantages * c``.
                if self.dual_clip_ratio_c is not None:
                    loss3 = -advantages * self.dual_clip_ratio_c
                    clip_min_loss = torch.min(loss3, clip_max_loss)
                    actor_loss = torch.where(
                        advantages < 0, clip_min_loss, clip_max_loss)
                else:
                    actor_loss = clip_max_loss

                if self.use_gspo_loss:
                    # Per-sequence masked mean, then mean over sequences. The
                    # denominator is clamped exactly like ``seq_lengths``
                    # above so a fully masked row contributes 0 rather than
                    # turning the whole loss into NaN via 0/0.
                    seq_token_counts = torch.sum(mask, dim=-1).clamp(min=1)
                    seq_losses = torch.sum(
                        actor_loss * mask, dim=-1) / seq_token_counts
                    actor_loss = torch.mean(seq_losses)
                else:
                    actor_loss = masked_mean(actor_loss, mask)

                loss = actor_loss - scaled_entropy * self.entropy_bonus

                with torch.no_grad():
                    if self.use_gspo_loss:
                        # GSPO statistics are plain (unmasked) means.
                        ppo_ratio = ratios.detach().mean()
                        ppo_ratio_clamped = ratios_clamped.detach().mean()
                    else:
                        ppo_ratio = masked_mean(ratios.detach(), mask)
                        ppo_ratio_clamped = masked_mean(
                            ratios_clamped.detach(), mask)
                    scaled_entropy = scaled_entropy.detach()

                # KL penalty against the reference policy.
                # NOTE(review): ``clamp_kl_loss`` is keyed on
                # ``dual_clip_ratio_c`` rather than on ``clamp_kl_val`` --
                # confirm this coupling is intended.
                kl_loss = masked_mean(
                    calculate_kl_loss(
                        cur_log_probs=curr_log_probs,
                        ref_log_probs=ref_log_probs,
                        use_absolute_kl=False,
                        use_low_var_kl=True,
                        clamp_kl_loss=self.dual_clip_ratio_c is not None,
                        clamp_kl_val=self.clamp_kl_val,
                    ),
                    mask)
                loss = loss + kl_loss * self.config.grpo_kl_loss_beta

                reduced_actor_loss = average_losses_across_data_parallel_group(
                    [loss, kl_loss])

                # Releases older than 0.12.1 get the backward loss multiplied
                # by the context-parallel size; newer ones use it unscaled.
                if Version(package_info.__version__) < Version("0.12.1"):
                    bwd_loss = loss * self.config.context_parallel_size
                else:
                    bwd_loss = loss.clone()

                return (
                    bwd_loss,
                    {
                        "loss": reduced_actor_loss[0],
                        "ppo_ratio": ppo_ratio,
                        "ppo_ratio_clamped": ppo_ratio_clamped,
                        "scaled_entropy": scaled_entropy,
                        "grpo_kl_loss": reduced_actor_loss[1],
                    },
                )

            return parallel_logits, loss_func

        return partial(fwd_output_and_loss_func, seqlen)

    def prepare_for_training(self):
                               
        pass

    def finish_training(self):
        pass

    def get_policy_logprobs(
        self,
        rollout_batches: List[Dict[str, List[Any]]],
    ) -> List[List[torch.Tensor]]:
        """Compute current-policy log-probs, regrouped per rollout batch.

        Args:
            rollout_batches: one dict per rollout batch; the length of its
                first value is taken as the number of samples per batch.

        Returns:
            One list per rollout batch, each holding that batch's per-sample
            log-prob tensors.
        """
        first_field = list(rollout_batches[0].values())[0]
        sampling_repeat = len(first_field)
        num_groups = len(rollout_batches)

        batches_list = expand_rollout_batches(rollout_batches)
        assert len(batches_list) == sampling_repeat * len(rollout_batches)

        # Reset the progress counters read by the log-prob forward step.
        self.batch_iters = 0
        self.total_iters = len(batches_list) // self.forward_micro_batch_size
        self.batch_log_str = "get_policy_logprobs microbatch "

        log_probs = self.get_inference_log_probs(
            batches_list, forward_micro_batch_size=self.forward_micro_batch_size)

        # Slice the flat result back into consecutive per-batch groups.
        return [
            log_probs[group * sampling_repeat:(group + 1) * sampling_repeat]
            for group in range(num_groups)
        ]

    def get_ref_policy_logprobs(
        self,
        rollout_batches: List[Dict[str, List[Any]]],
    ) -> List[List[torch.Tensor]]:
        """Compute reference-policy log-probs, regrouped per rollout batch.

        The stored ``init_policy_state_dict`` is swapped into the model for
        the duration of the forward passes.

        Args:
            rollout_batches: one dict per rollout batch; the length of its
                first value is taken as the number of samples per batch.

        Returns:
            One list per rollout batch, each holding that batch's per-sample
            log-prob tensors.
        """
        first_field = list(rollout_batches[0].values())[0]
        sampling_repeat = len(first_field)
        num_groups = len(rollout_batches)

        batches_list = expand_rollout_batches(rollout_batches)
        assert len(batches_list) == sampling_repeat * len(rollout_batches)

        # Reset the progress counters read by the log-prob forward step.
        self.batch_iters = 0
        self.total_iters = len(batches_list) // self.forward_micro_batch_size
        self.batch_log_str = "get_ref_policy_logprobs microbatch "

        assert self.init_policy_state_dict
        with cpu_weight_swap_v2(self.model, self.unwrap_model_func, self.init_policy_state_dict):
            init_log_probs = self.get_inference_log_probs(
                batches_list, forward_micro_batch_size=self.forward_micro_batch_size)

        # Slice the flat result back into consecutive per-batch groups.
        return [
            init_log_probs[group * sampling_repeat:(group + 1) * sampling_repeat]
            for group in range(num_groups)
        ]

    def batch_get_policy_logprobs(self, rollout_batches):
        self.batch_log_str = "get_policy_logprobs microbatch "
        log_probs = self.batch_get_inference_log_probs(
            rollout_batches,
            forward_micro_batch_size=self.forward_micro_batch_size,
        )
        return log_probs

    def batch_get_ref_policy_logprobs(self, rollout_batches):
        """Compute reference-policy log-probs through the smart-pad pipeline.

        The stored ``init_policy_state_dict`` is swapped into the model while
        the forward passes run.
        """
        self.batch_log_str = "get_ref_policy_logprobs microbatch "
        assert self.init_policy_state_dict
        weight_swap = cpu_weight_swap_v2(
            self.model, self.unwrap_model_func, self.init_policy_state_dict)
        with weight_swap:
            init_log_probs = self.batch_get_inference_log_probs(
                rollout_batches,
                forward_micro_batch_size=self.forward_micro_batch_size,
            )
        return init_log_probs

    @torch.no_grad()
    def batch_get_inference_log_probs(self, rollout_batches, forward_micro_batch_size) -> List[List[torch.Tensor]]:
        """Compute per-sample log-probs for grouped rollout batches.

        Args:
            rollout_batches: one dict per rollout batch on this rank; the
                length of its first value is taken as the per-batch sample
                count (``mbs``).
            forward_micro_batch_size: micro-batch size of the forward passes.
                It is used consistently for the divisibility check, the
                iteration count and the smart-pad helper.

        Returns:
            ``len(rollout_batches)`` lists, each with ``mbs`` CPU log-prob
            tensors.
        """
        gbs_per_dp_rank = len(rollout_batches)
        mbs = len(list(rollout_batches[0].values())[0])
        assert mbs > 0, f"[SMART_PAD] can't predict mbs from rollout_batches, {rollout_batches[0]=}"
        total_samples = gbs_per_dp_rank * mbs
        assert total_samples % forward_micro_batch_size == 0, (
            f"{gbs_per_dp_rank=} * {mbs=} should be multiple of {forward_micro_batch_size=}"
        )

        num_batches = divide(total_samples, forward_micro_batch_size)

        # Reset the progress counters read by the log-prob forward step.
        self.batch_iters = 0
        self.total_iters = num_batches

        smart_pad_helper = GroupSmartPadInferHelper(
            rollout_batches, forward_micro_batch_size)

        # Parameter name kept as in the original callback in case the helper
        # passes it by keyword.
        def get_seqlen_func(input):
            return input["tokens"].shape[-1]

        smart_pad_helper.forward_pipeline(pad_to_multi_of=self.pad_to_multi_of,
                                          get_seqlen_func=get_seqlen_func, forward_step_wrapped_func=self.forward_step)
        logprobs_list = smart_pad_helper.get_rowed_based_forward_results(
            is_row_based_rets=True)

        # Only the last pipeline stage flattens results here; they are then
        # shared with the other stages through the broadcast below.
        flatten_logprobs_list = []
        if parallel_state.is_pipeline_last_stage():
            for logprobs_list_per_forward_step in logprobs_list:
                for logprobs_per_sample in logprobs_list_per_forward_step:
                    flatten_logprobs_list.append(logprobs_per_sample.cpu())

        flatten_logprobs_list = broadcast_object_within_pp(
            flatten_logprobs_list)
        logprobs = get_split_batchs(flatten_logprobs_list, mbs)
        assert len(
            logprobs) == gbs_per_dp_rank, f"len(logprobs) expect {gbs_per_dp_rank}, but get {len(logprobs)}"

        return logprobs

    @torch.no_grad()
    def get_inference_log_probs(
        self,
        batches_list: List[Dict[str, Any]],
        forward_micro_batch_size,
    ) -> List[torch.Tensor]:
        """Run forward-only passes and return one log-prob tensor per sample.

        Args:
            batches_list: flat list of per-sample dicts.
            forward_micro_batch_size: samples per forward micro-batch; the
                list length is split by it via ``divide``.

        Returns:
            ``len(batches_list)`` CPU tensors, one per input sample.
        """
        gbs = len(batches_list)
        padded_seqlen = get_max_seqlen_within_ep(
            get_gbs_batches_seqlen(batches_list, self.pad_to_multi_of))
        num_microbatches = divide(gbs, forward_micro_batch_size)
        batch_iter = get_iterator_k_split_list(batches_list, num_microbatches)

        step_func = self.get_logprob_output_only_func(
            padded_seqlen, inference_only=True)
        run_schedule = get_forward_backward_func()
        per_microbatch = run_schedule(
            forward_step_func=step_func,
            data_iterator=batch_iter,
            model=self.model,
            num_microbatches=num_microbatches,
            forward_only=True,
            seq_length=padded_seqlen,
            micro_batch_size=forward_micro_batch_size,
            collect_non_loss_data=True,
            decoder_seq_length=padded_seqlen,
        )

        # Ranks that collected nothing pass None into the broadcast.
        if len(per_microbatch) > 0:
            logprobs = torch.cat(per_microbatch)
        else:
            logprobs = None

        logprobs = broadcast_2d_tensor_within_pp(logprobs)
        if not is_wxacc1():
            assert logprobs.dtype == torch.float32, f'{logprobs.dtype=}'

        assert logprobs.shape[0] == gbs
        # Split into one row per sample and drop the leading batch dimension.
        rows = logprobs.cpu().chunk(gbs)
        logprobs = [row.squeeze(0) for row in rows]
        clear_memory()
        return logprobs

    def forward_step(self, batch_iter, num_microbatches, micro_batch_size, seq_length):
        """Run one forward-only schedule that collects log-probs.

        Args:
            batch_iter: iterator yielding the micro-batches.
            num_microbatches: number of micro-batches to consume.
            micro_batch_size: samples per micro-batch.
            seq_length: padded sequence length for this call.

        Returns:
            Whatever the forward/backward schedule returns with
            ``collect_non_loss_data=True``.
        """
        step_func = self.get_logprob_output_only_func(
            seqlen=seq_length, inference_only=True)
        run_schedule = get_forward_backward_func()
        collected = run_schedule(
            forward_step_func=step_func,
            data_iterator=batch_iter,
            model=self.model,
            num_microbatches=num_microbatches,
            forward_only=True,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            collect_non_loss_data=True,
            decoder_seq_length=seq_length,
        )
        clear_memory()
        return collected

                     
    def get_logprob_output_only_func(self, seqlen, inference_only=True):
        """Return a forward-step callable that yields log-probs instead of a loss.

        Args:
            seqlen: Sequence length bound as the first argument of the step
                function; it is forwarded to
                ``prepare_data_for_model_forward_only``.
            inference_only: Forwarded to ``from_parallel_logits_to_logprobs``.

        Returns:
            A ``functools.partial`` taking ``(dataloader_iter, model)`` and
            returning ``(output_tensor, id_func)``.
        """

        def log_prob_output_only_func(seqlen, dataloader_iter, model):
            micro_batch: List[Dict[str, Any]] = next(dataloader_iter)

            # Progress logging: only when a total is configured, and only on
            # global rank 0.
            if self.total_iters > 0:
                if torch.distributed.get_rank() == 0:
                    self.batch_iters += 1
                    print_with_rank_and_datetime(
                        f"{self.batch_log_str} {self.batch_iters:8d}/{self.total_iters:8d}")

            fwd_kwargs = self.prepare_data_for_model_forward_only(micro_batch, seqlen)
            # "target" is not a model argument; it is kept for the log-prob
            # computation below.
            target = fwd_kwargs.pop("target")

            model_out = model(**fwd_kwargs)
            # When the model returns a tuple, only its first element is used.
            output_tensor = model_out[0] if isinstance(model_out, tuple) else model_out
            assert isinstance(output_tensor, torch.Tensor)

            def id_func(output_tensor, non_loss_data=True):
                assert not self.config.cross_entropy_loss_fusion
                return from_parallel_logits_to_logprobs(
                    vocab_parallel_logits=output_tensor,
                    target=target,
                    inference_only=inference_only,
                )

            return output_tensor, id_func

        step_func = partial(log_prob_output_only_func, seqlen)
        return step_func

    def prepare_data_for_model_forward_only(self, batches: List[Dict[str, Any]],
                                            seqlen: int) -> Dict[str, Any]:
        """Assemble the keyword arguments for a forward-only model call.

        Every batch's ``"tokens"`` tensor is right-padded along its last
        dimension up to ``seqlen`` with ``self.pad_token_id``. The padded
        tensors are stacked, reshaped to ``(len(batches), -1)`` and moved to
        the CUDA device.

        Args:
            batches: Dicts that each hold a ``"tokens"`` tensor.
            seqlen: Length every token tensor is padded to.

        Returns:
            A dict with the keys ``"target"`` (a detached clone of the padded
            tokens, taken before any context-parallel slicing),
            ``"input_ids"``, ``"position_ids"`` and ``"attention_mask"``.
        """
        padded = [
            torch.nn.functional.pad(
                sample["tokens"],
                (0, seqlen - sample["tokens"].shape[-1]),
                value=self.pad_token_id,
            ) for sample in batches
        ]

        stacked = torch.stack(padded)
        tokens = stacked.view(len(padded), -1).cuda(non_blocking=True)
        # Snapshot the full token tensor before it may be sliced below.
        target = tokens.detach().clone()

        attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
            tokens, 0, False, False, False, False)
        if attention_mask is not None:
            batch_dim = tokens.size(0)
            attention_mask = attention_mask.expand(batch_dim, -1, -1, -1)

        if self.config.context_parallel_size > 1:
            # NOTE(review): attention_mask may be None here (see the guard
            # above); whether the helper accepts None is not visible from this
            # method -- confirm.
            tokens = get_tensor_on_this_cp_rank(tokens, 1, key_name="tokens")
            attention_mask = get_tensor_on_this_cp_rank(
                attention_mask, 2, key_name="attention_mask")
            position_ids = get_tensor_on_this_cp_rank(
                position_ids, 1, key_name="position_ids")

        return {
            "target": target,
            "input_ids": tokens,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

    @torch.no_grad()
    def update_ref_with_actor(self, coef):
        """Blend the current actor weights into ``self.init_policy_state_dict``.

        For every tensor entry of the actor state dict the stored value
        becomes ``stored * (1 - coef) + coef * actor``.

        Args:
            coef: Weight given to the actor value; the stored value keeps
                ``1 - coef``.
        """
        assert self.init_policy_state_dict
        ref_state = self.init_policy_state_dict
        actor_module = self.unwrap_model_func(self.model)[0]
        actor_state = retrieve_model_state_dict_in_cpu(actor_module)
        keep = 1 - coef
        for name, actor_value in actor_state.items():
            # Every actor key must already exist in the stored state.
            assert name in ref_state
            if torch.is_tensor(actor_value):
                ref_state[name] *= keep
                ref_state[name] += coef * actor_value