              
                                                      
                                                                 

from functools import partial
from typing import List, Dict, Optional, Union, Any
import argparse
import math
import copy
import inspect

from torch import Tensor
import torch
import torch.nn.functional as F

from megatron.core import mpu, parallel_state
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.utils import divide
from megatron.core.utils import get_model_config

from gpatch.core.aligner_helper import broadcast_2d_tensor_within_pp, broadcast_2d_tensor_within_mp, broadcast_object_within_mp, pad_batches_to_multiple_of
from gpatch.core.aligner_interface import CriticModelInterface, ResetArgsMixin, Inferrable
from gpatch.core.device_type import is_wxacc1
from gpatch.core.models.gpt.gpt_reward_model import StateDictState
from gpatch.core.aligner_helper import (
    swap_dict,
    get_ltor_masks_and_position_ids,
    get_iterator_k_split,
    get_iterator_k_split_list,
    masked_mean,
    get_last_rank,
    average_losses_across_data_parallel_group,
    pad_batches_to_multiple_of_within_ep,
    expand_rollout_batches,
    get_gbs_batches_seqlen,
    get_max_seqlen_within_ep,
)
from gpatch.core.parallel_state import is_mp_head
from gpatch.core.wecube import report_ppo_metrics
from gpatch.core.smart_pad_helper import (
    CatedSmartPadInferHelper,
    get_column_based_batches,
)


class GptPpoCriticModel(ResetArgsMixin, CriticModelInterface):

    def __init__(
        self,
        config,
        reward_model,
        unwrap_model_func,
        forward_micro_batch_size: int = 1,
        reward_running=None,
        ppo_reward_len_penalty_coef=0.,
        ppo_reward_len_penalty_mean=None,
        ppo_reward_len_penalty_std=None,
        rm_inputs_modifier=None,
        rm_outputs_modifier=None,
        rm_output_sequence=None,
        rm_output_scalar=None,
        enable_smart_pad=False,
        pad_to_multi_of=None,
        pad_token_id=None,
        **kwargs,
    ):
        """Store the configuration and initialise the critic / reward-model bookkeeping.

        Args:
            config: object exposing at least ``use_grpo``, ``ppo_grpo_reward_type``
                and ``ppo_loss_clip_val``.
            reward_model: the model chunks; must contain exactly one element
                unless GRPO runs with the ``"rule_only"`` reward type.
            unwrap_model_func: callable used elsewhere in the class to unwrap
                ``reward_model``.
            forward_micro_batch_size: micro batch size used for inference.
            reward_running: ``argparse.Namespace`` holding the running reward
                statistics (required despite the ``None`` default).
            ppo_reward_len_penalty_coef / _mean / _std: length-penalty settings.
            rm_inputs_modifier / rm_outputs_modifier: stored as-is.
            rm_output_sequence / rm_output_scalar: optional per-RM head flags.
            enable_smart_pad / pad_to_multi_of / pad_token_id: padding settings.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if the model does not hold exactly one chunk (when
                required) or ``reward_running`` is not an ``argparse.Namespace``.
        """
        self.model = reward_model
        self.config = config
        rule_only_grpo = config.use_grpo and config.ppo_grpo_reward_type == "rule_only"
        if not rule_only_grpo:
            assert len(self.model) == 1

        self.unwrap_model_func = unwrap_model_func

        # Weight-swapping state: the object starts out in the CRITIC state with
        # no reward model loaded (-100 is the "nothing loaded" index).
        self.cpu_state_dict = None
        self.rm_state_dicts = []
        self.ppo_reward_scalings = []
        self.rm_factors = []
        self.loaded_rm_idx = -100
        self.loaded_state_dict = StateDictState.CRITIC
        self.clip_val = config.ppo_loss_clip_val
        assert isinstance(reward_running,
                          argparse.Namespace), f'type of reward_running {type(reward_running)}'

        # Inference / reward post-processing settings.
        self.forward_micro_batch_size = forward_micro_batch_size
        self.reward_running = reward_running
        self.ppo_reward_len_penalty_coef = ppo_reward_len_penalty_coef
        self.ppo_reward_len_penalty_mean = ppo_reward_len_penalty_mean
        self.ppo_reward_len_penalty_std = ppo_reward_len_penalty_std
        self.use_grpo = config.use_grpo
        self.rm_inputs_modifier = rm_inputs_modifier
        self.rm_outputs_modifier = rm_outputs_modifier
        self.rm_output_sequence = rm_output_sequence
        self.rm_output_scalar = rm_output_scalar

        # Padding configuration.
        self.enable_smart_pad = enable_smart_pad
        self.pad_to_multi_of = pad_to_multi_of
        self.pad_token_id = pad_token_id

        # Lazily filled cache of the ``validate_samples`` signature parameters.
        self.validate_samples_params = None

    def prepare_for_inference(self):
        if self.model is not None:
            for model_module in self.model:
                model_module.eval()

    def finish_inference(self):
        if self.model is not None:
            for model_module in self.model:
                model_module.train()

    def prepare_for_training_step(self):
        if self.model is not None:
            for model_module in self.model:
                model_module.train()

    def finish_training_step(self):
        pass

    def prepare_data_for_infer(
        self,
        batches: List[Dict[str, Any]],
        seq_length: int,
        rm_index: Optional[int],
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Collate per-sample dicts into the keyword arguments of one forward pass.

        Args:
            batches: per-sample dicts holding a token tensor plus sequence and
                prompt length tensors.
            seq_length: target length; every token tensor is right-padded to it
                with ``self.pad_token_id``.
            rm_index: when not ``None`` and the first sample carries
                ``rm_{rm_index}_tokens``, the reward-model specific keys are read
                instead of the generic ones.

        Returns:
            Dict with ``input_ids``, ``lengths``, ``prompt_lens``,
            ``position_ids`` and ``attention_mask``.

        Raises:
            AssertionError: if a sample carries an output-mask entry (not
                supported yet).
        """
        rm_tokens_key = f'rm_{rm_index}_tokens'
        if rm_index is not None and rm_tokens_key in batches[0]:
            tokens_key = rm_tokens_key
            # NOTE(review): unlike the tokens / output-mask keys, these two have
            # no underscore after the index — confirm against the producer.
            sequence_lengths_key = f'rm_{rm_index}sequence_lengths'
            prompt_lengths_key = f'rm_{rm_index}prompt_lengths'
            rm_output_mask_key = f'rm_{rm_index}_output_mask'
        else:
            tokens_key = "tokens"
            sequence_lengths_key = "sequence_lengths"
            prompt_lengths_key = "prompt_lengths"
            rm_output_mask_key = "rm_output_mask"

        padded_tokens = []
        sequence_lengths = []
        prompt_lengths = []
        for sample in batches:
            sample_tokens = sample[tokens_key]
            missing = seq_length - sample_tokens.shape[-1]
            # Pad only on the right of the last dimension.
            padded_tokens.append(F.pad(sample_tokens, (0, missing), value=self.pad_token_id))
            sequence_lengths.append(sample[sequence_lengths_key])
            prompt_lengths.append(sample[prompt_lengths_key])
            assert rm_output_mask_key not in sample, "TODO(xiaotaoliu): no do now"

        n_samples = len(padded_tokens)
        tokens = torch.stack(padded_tokens).view(n_samples, -1).cuda(non_blocking=True)
        attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
            tokens, 0, False, False, False, False)

        lengths = torch.stack(sequence_lengths).view(-1).cuda(non_blocking=True)
        prompt_lens = torch.stack(prompt_lengths).view(-1).cuda(non_blocking=True)
        infer_kwargs = {
            "input_ids": tokens,
            "lengths": lengths,
            "prompt_lens": prompt_lens,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

        # With smart padding enabled the extra padding step is skipped.
        if not self.enable_smart_pad:
            pad_batches_to_multiple_of_within_ep(infer_kwargs, self.pad_to_multi_of, 'inputs',
                                                 self.pad_token_id)

        return infer_kwargs

    def infer_rm_critic(self, *args, **kwargs):
        """Run reward-model and critic inference, starting with the loaded weights.

        Both ``self._infer_rm`` and ``self._infer_critic`` are called with exactly
        the same ``*args`` / ``**kwargs``, so they must accept compatible
        arguments. Whichever weights are currently loaded are evaluated first.

        Returns:
            ``(rm_output, critic_output, per_token_rewards, exceeded,
            custom_rewards)``. ``exceeded`` comes from the step that ran last;
            ``custom_rewards`` stays ``None`` unless a step returns four values.

        Raises:
            ValueError: if a step returns neither three nor four values.
            AssertionError: if more than one step yields per-token rewards.
        """
        started_as_critic = self.loaded_state_dict == StateDictState.CRITIC
        if started_as_critic:
            steps = (self._infer_critic, self._infer_rm)
        else:
            steps = (self._infer_rm, self._infer_critic)

        results = []
        per_token_rewards_list = None
        custom_rewards = None
        for step in steps:
            step_out = step(*args, **kwargs)
            n_returned = len(step_out)
            if n_returned == 4:
                output, tlrl, exceeded, custom_rewards = step_out
            elif n_returned == 3:
                output, tlrl, exceeded = step_out
            else:
                raise ValueError("invalid number of returns from inference function")
            results.append(output)
            if tlrl is not None:
                assert per_token_rewards_list is None
                per_token_rewards_list = tlrl

        if started_as_critic:
            # The critic ran first; restore the (rm, critic) output order.
            results.reverse()

        return (*results, per_token_rewards_list, exceeded, custom_rewards)

    def infer_rm_only(
        self,
        batches: List[Dict[str, Union[int, List[Any]]]] = None,
        sampling_repeat: int = None,
    ):
        output, tlrl, exceeded, custom_rewards = self._infer_rm(batches, sampling_repeat)
        return (output, None, tlrl, exceeded, custom_rewards)

    def set_output_sequence_flag(self, unwrapped_model, value_to_set):
        if hasattr(unwrapped_model, 'rm_head'):
            unwrapped_model.rm_head.output_sequence = value_to_set

    def set_output_scalar_flag(self, unwrapped_model, value_to_set):
        if hasattr(unwrapped_model, 'rm_head'):
            unwrapped_model.rm_head.output_scalar = value_to_set

    def _load_critic(self):
        """Swap the critic weights back in if a reward model is currently loaded.

        Does nothing unless the loaded state is ``StateDictState.REWARD``.
        Otherwise the stored ``cpu_state_dict`` is handed to ``swap_dict``, the
        head is switched to sequence (non-scalar) output and the bookkeeping is
        reset to the critic state.

        Raises:
            AssertionError: if no ``cpu_state_dict`` has been stored.
        """
        if self.loaded_state_dict != StateDictState.REWARD:
            return

        assert self.cpu_state_dict is not None
        critic_module = self.unwrap_model_func(self.model)[0]
        swap_dict(critic_module, self.cpu_state_dict, offload_onto_cpu=False)
        self.set_output_sequence_flag(critic_module, True)
        self.set_output_scalar_flag(critic_module, False)
        self.loaded_rm_idx = -100
        self.loaded_state_dict = StateDictState.CRITIC

    def _load_rm(self, index):
        """Make reward model ``index`` the active model and configure its head.

        In GRPO mode no weights are swapped: only the head flags are set, the
        loaded index is forced to 0 and any configured flag list must have
        exactly one entry. Otherwise the weights in ``rm_state_dicts[index]`` are
        passed to ``swap_dict``; when coming from the critic state the value it
        returns is kept in ``cpu_state_dict``.

        Head flags default to scalar output (no sequence output) when
        ``rm_output_sequence`` / ``rm_output_scalar`` are ``None``; otherwise the
        entry at ``index`` enables the flag when it equals 1.

        Raises:
            AssertionError: if ``index`` is not below ``len(rm_state_dicts)``, or
                in GRPO mode a flag list does not have exactly one entry.
        """
        assert index < len(self.rm_state_dicts)
        rm_module = self.unwrap_model_func(self.model)[0]

        def configure_head(expect_single_rm):
            # Sequence flag first, then scalar flag (same order in both modes).
            sequence_flags = self.rm_output_sequence
            if sequence_flags is None:
                emit_sequence = False
            else:
                if expect_single_rm:
                    assert len(sequence_flags) == 1
                emit_sequence = sequence_flags[index] == 1
            self.set_output_sequence_flag(rm_module, emit_sequence)

            scalar_flags = self.rm_output_scalar
            if scalar_flags is None:
                emit_scalar = True
            else:
                if expect_single_rm:
                    assert len(scalar_flags) == 1
                emit_scalar = scalar_flags[index] == 1
            self.set_output_scalar_flag(rm_module, emit_scalar)

        if self.use_grpo:
            self.loaded_rm_idx = 0
            configure_head(expect_single_rm=True)
            self.loaded_state_dict = StateDictState.REWARD
            return

        self.loaded_rm_idx = index
        rm_weights = self.rm_state_dicts[index]
        if self.loaded_state_dict == StateDictState.CRITIC:
            # Leaving the critic state: keep what swap_dict hands back so that
            # _load_critic can restore it later.
            self.cpu_state_dict = swap_dict(
                rm_module,
                rm_weights,
                offload_onto_cpu=True,
                offloaded_weights=self.cpu_state_dict,
            )
        else:
            swap_dict(rm_module, rm_weights, offload_onto_cpu=False)

        configure_head(expect_single_rm=False)
        self.loaded_state_dict = StateDictState.REWARD

    def _infer_critic(self, batches: List[Dict[str, Union[int, List[Any]]]]):
        self._load_critic()
        ts, _, exceeded = self.infer(batches)
        assert _ is None
        return ts, None, exceeded

    def infer_rule_based_rm(
        self,
        rewards,
        per_token_rewards=None,
        sequence_lengths: torch.Tensor = None,
        prompt_lengths: torch.Tensor = None,
        batches: List[Dict[str, Union[int, List[Any]]]] = None,
    ):
        '''
        Hook for rule-based rewards: subclass GptPpoCriticModel and override
        infer_rule_based_rm to supply your own rule.

        - rewards is the reward score, shape [b, 1].
        - per_token_rewards is the process-reward score, shape [b, s-1]; usually None.
        - batches[i]["tokens"] is prompt + generated tokens.
        - sequence_lengths = prompt_lengths + answer_lengths.
        - batches[i]["rm_output_mask"] is usually None.

        This base implementation applies no rule. On the model-parallel head it
        returns the inputs unchanged plus an empty custom-rewards dict; on every
        other rank it asserts that both reward inputs are None and returns None
        for the custom rewards.

        ```python
        class DqaGptPpoCriticModel:
            def infer_rule_based_rm(self, rewards, per_token_rewards=None, batches=None):
                args = get_args()
                tokenizer = get_tokenizer()

                inputs_list = []
                for batch in batches:
                    inputs_list.extend(batch["tokens"])
                tokens_cpu = list_for_tensor_tolist(inputs_list, False)

                # Put your own logic here, e.g. penalise sensitive keywords or bad formatting.
                strs = tokenizer._tokenizer.batch_decode(tokens_cpu, skip_special_tokens=False)
                for i, s in enumerate(strs):
                    if not 'box' in s:
                        # Interrupting the GPU queue like this is not ideal, but the
                        # amount of computation here is small so it is probably fine.
                        rewards[i] = rewards[i] * 0.5

                return rewards, per_token_rewards, {}
        ```
        '''
        if is_mp_head():
            return rewards, per_token_rewards, {}

        # Ranks other than the model-parallel head carry no reward tensors.
        assert rewards is None
        assert per_token_rewards is None
        return rewards, per_token_rewards, None

    def _infer_rm(self, batches: List[Dict[str, Union[int, List[Any]]]], sampling_repeat=None):
        """Compute the final rewards for ``batches``.

        Steps, in order:
          1. Unless GRPO runs in ``"rule_only"`` mode, run every reward model and
             accumulate ``rm_factors[i] * output_i`` (model-parallel head only).
          2. Call ``infer_rule_based_rm`` so subclasses can adjust the rewards.
          3. Call ``validate_samples`` (defined outside this view) and merge its
             result into the custom rewards.
          4. Apply the optional length penalty and DAPO over-long penalty.
          5. Optionally standardise with the running statistics and clamp.
          6. Optionally synchronise the running statistics across all ranks.

        Args:
            batches: rollout batches; each entry provides ``"sequence_lengths"``
                and ``"prompt_lengths"``.
            sampling_repeat: forwarded to ``validate_samples`` when its signature
                accepts it. Defaults to ``None`` so callers that forward only
                ``batches`` (e.g. through ``infer_rm_critic``) still work.

        Returns:
            ``(rewards, per_token_rewards, exceeded, custom_rewards)``.

        Raises:
            AssertionError: on inconsistent head / non-head state, or if the
                loaded state is not ``StateDictState.REWARD`` after step 1.
        """
        rewards_ret: Tensor = None
        per_token_rewards_ret: Tensor = None
        exceeded = None

        if not (self.config.use_grpo and self.config.ppo_grpo_reward_type == "rule_only"):
            assert self.model is not None
            for rm_idx in range(len(self.rm_state_dicts)):
                self._load_rm(rm_idx)
                ts, per_token_rewards, exceeded = self.infer(batches, rm_idx)

                if is_mp_head():
                    factor = self.rm_factors[rm_idx]
                    # Weighted sum of the sequence-level rewards.
                    if rm_idx == 0:
                        ts *= factor
                        rewards_ret = ts
                    else:
                        rewards_ret += factor * ts

                    # Weighted sum of the per-token rewards, when a model emits them.
                    if per_token_rewards is not None:
                        if per_token_rewards_ret is None:
                            per_token_rewards *= factor
                            per_token_rewards_ret = per_token_rewards
                        else:
                            per_token_rewards_ret += factor * per_token_rewards
        else:
            # No reward model is run; only mark the state as REWARD.
            self.loaded_state_dict = StateDictState.REWARD

        # Flatten the per-sample lengths (model-parallel head only).
        sequence_lengths = None
        prompt_lengths = None
        if is_mp_head():
            sequence_lengths_list = []
            prompt_lengths_list = []
            for batch in batches:
                sequence_lengths_list.extend(batch["sequence_lengths"])
                prompt_lengths_list.extend(batch["prompt_lengths"])
            sequence_lengths = torch.stack(sequence_lengths_list).view(-1)
            prompt_lengths = torch.stack(prompt_lengths_list).view(-1)

        # Rule-based hook.
        rewards_ret, per_token_rewards_ret, custom_rewards = self.infer_rule_based_rm(
            rewards_ret,
            per_token_rewards=per_token_rewards_ret,
            sequence_lengths=sequence_lengths,
            prompt_lengths=prompt_lengths,
            batches=batches,
        )

        # Inspect the validate_samples signature once and cache it.
        if self.validate_samples_params is None:
            sig = inspect.signature(self.validate_samples)
            self.validate_samples_params = sig.parameters

        if 'sampling_repeat' in self.validate_samples_params:
            check_sampler_result = self.validate_samples(rewards_ret, sampling_repeat=sampling_repeat)
        else:
            check_sampler_result = self.validate_samples(rewards_ret)
        if check_sampler_result is not None and len(check_sampler_result) > 0:
            # The rule hook may have returned None for the custom rewards (the
            # base implementation does so on non-head ranks); start from an empty
            # dict instead of failing on None.update(...).
            if custom_rewards is None:
                custom_rewards = {}
            custom_rewards.update(check_sampler_result)

        # Sanity checks.
        if is_mp_head():
            assert sequence_lengths is not None and prompt_lengths is not None
            assert rewards_ret is not None
        else:
            assert sequence_lengths is None and prompt_lengths is None

        # Length penalty: |response_len - mean| / std, scaled by the coefficient.
        if self.ppo_reward_len_penalty_coef > 0. and is_mp_head():
            response_lengths = sequence_lengths.cuda() - prompt_lengths.cuda()
            len_penalty = self.ppo_reward_len_penalty_coef * (
                torch.abs(response_lengths - self.ppo_reward_len_penalty_mean) /
                self.ppo_reward_len_penalty_std)
            # The penalty is a flat [b] vector whereas the rewards may be [b, 1];
            # an in-place subtraction would then need to broadcast to [b, b],
            # which in-place ops do not allow. Align the shapes first.
            if len_penalty.numel() == rewards_ret.numel():
                len_penalty = len_penalty.reshape(rewards_ret.shape)
            rewards_ret -= len_penalty

        # DAPO over-long penalty: non-positive, grows with the excess length.
        if self.config.dapo_overlong_penalty and is_mp_head():
            expected_len = self.config.ppo_resp_seq_len - self.config.dapo_overlong_buffer_len
            valid_response_length = sequence_lengths.cuda() - prompt_lengths.cuda()
            exceed_len = valid_response_length - expected_len
            overlong_reward = torch.clamp(
                -exceed_len / self.config.dapo_overlong_buffer_len *
                self.config.dapo_overlong_penalty_factor,
                max=0)
            rewards_ret += overlong_reward.view(-1, 1)

        assert self.loaded_state_dict == StateDictState.REWARD
        if self.config.ppo_enable_standardization and is_mp_head():
            # Standardise with the running mean / std, then clamp.
            rew_mean, rew_std = self.update_reward_running(rewards_ret, False, None, None)
            eps = torch.finfo(rewards_ret.dtype).eps
            rewards_ret = (rewards_ret - rew_mean) / (rew_std + eps)
            rewards_ret = rewards_ret.clamp(-self.config.ppo_reward_clip_val,
                                            self.config.ppo_reward_clip_val)

            if per_token_rewards_ret is not None:
                pt_rew_mean, pt_rew_std = self.update_reward_running(per_token_rewards_ret, True,
                                                                     prompt_lengths,
                                                                     sequence_lengths)
                eps = torch.finfo(per_token_rewards_ret.dtype).eps
                per_token_rewards_ret = (per_token_rewards_ret - pt_rew_mean) / (pt_rew_std + eps)
                per_token_rewards_ret = per_token_rewards_ret.clamp(
                    -self.config.ppo_reward_clip_val, self.config.ppo_reward_clip_val)

        # Average the running statistics over every rank of the default group.
        if self.config.ppo_enable_standardization and torch.distributed.get_world_size() > 1:
            running = self.reward_running
            tmp = torch.tensor([
                running.ppo_reward_mean,
                running.ppo_reward_var,
                running.ppo_per_token_reward_mean,
                running.ppo_per_token_reward_var,
            ],
                               dtype=torch.float32,
                               device=torch.cuda.current_device())
            if not is_wxacc1():
                torch.distributed.all_reduce(tmp, op=torch.distributed.ReduceOp.AVG)
                (
                    running.ppo_reward_mean,
                    running.ppo_reward_var,
                    running.ppo_per_token_reward_mean,
                    running.ppo_per_token_reward_var,
                ) = tmp.tolist()
                tmp = torch.tensor([
                    running.ppo_reward_count,
                    running.ppo_per_token_reward_count,
                ],
                                   dtype=torch.int64,
                                   device=torch.cuda.current_device())
                torch.distributed.all_reduce(tmp, op=torch.distributed.ReduceOp.AVG)
            else:
                # This device type takes the manual route: divide locally, then SUM.
                world_size = torch.distributed.get_world_size()
                tmp /= world_size
                torch.distributed.all_reduce(tmp, op=torch.distributed.ReduceOp.SUM)
                (
                    running.ppo_reward_mean,
                    running.ppo_reward_var,
                    running.ppo_per_token_reward_mean,
                    running.ppo_per_token_reward_var,
                ) = tmp.tolist()
                tmp = torch.tensor([
                    running.ppo_reward_count,
                    running.ppo_per_token_reward_count,
                ],
                                   dtype=torch.float,
                                   device=torch.cuda.current_device())
                tmp /= world_size
                torch.distributed.all_reduce(tmp, op=torch.distributed.ReduceOp.SUM)
            running.ppo_reward_count, running.ppo_per_token_reward_count = tmp.to(
                dtype=torch.int64).tolist()

        return rewards_ret, per_token_rewards_ret, exceeded, custom_rewards

    def get_global_statistics(self, xs, per_token, prompt_lengths, sequence_lengths):
        xs_tensor = xs                  
        if per_token:
            assert sequence_lengths is not None and prompt_lengths is not None
            mask = torch.arange(xs_tensor.shape[1], device='cuda').view(1, -1)
            mask = mask.expand(xs_tensor.shape[0], -1)
            mask = torch.logical_and(mask >= prompt_lengths.cuda(), mask < sequence_lengths.cuda())
            assert mask.dtype == torch.bool
            if is_wxacc1():
                                                                          
                xs_cnt = torch.sum(sequence_lengths.cuda() - prompt_lengths.cuda())
                xs_tensor[~mask] = 0
                xs_mean = torch.sum(xs_tensor) / xs_cnt
                xs_tensor -= xs_mean
                xs_tensor[~mask] = 0
                xs_var = torch.var(xs_tensor, correction=0) * (xs_tensor.numel() / xs_cnt)
            else:
                xs_tensor = xs_tensor[mask]
                xs_var, xs_mean = torch.var_mean(xs_tensor, correction=0)
                xs_cnt = xs_tensor.numel()
        else:
            xs_var, xs_mean = torch.var_mean(xs_tensor, correction=0)
            xs_cnt = xs_tensor.numel()
        return xs_mean, xs_var, xs_cnt

    @torch.no_grad()
    def update_reward_running(self, xs, per_token, prompt_lengths, sequence_lengths):
        assert isinstance(xs, torch.Tensor)
        xs_mean, xs_var, xs_count = self.get_global_statistics(xs, per_token, prompt_lengths,
                                                               sequence_lengths)

        running = self.reward_running
        if per_token:
            old_cnt = running.ppo_per_token_reward_count
            old_mean = running.ppo_per_token_reward_mean
            old_var = running.ppo_per_token_reward_var
        else:
            old_cnt = running.ppo_reward_count
            old_mean = running.ppo_reward_mean
            old_var = running.ppo_reward_var
        assert old_cnt >= 0 and old_var >= 0, f'old_cnt {old_cnt} old_mean {old_mean} old_var {old_var}'
        delta = xs_mean - old_mean
        tot_count = old_cnt + xs_count

                                                               
        new_sum = xs_var * xs_count
        old_sum = old_var * old_cnt + delta**2 * old_cnt * xs_count / tot_count
        tot_sum = old_sum + new_sum

        if per_token:
            running.ppo_per_token_reward_mean += delta * xs_count / tot_count
            if old_cnt == 0:
                running.ppo_per_token_reward_var = xs_var
            else:
                running.ppo_per_token_reward_var = tot_sum / tot_count
            running.ppo_per_token_reward_count = tot_count
            return running.ppo_per_token_reward_mean, math.sqrt(running.ppo_per_token_reward_var)
        else:
            running.ppo_reward_mean += delta * xs_count / tot_count
            if old_cnt == 0:
                running.ppo_reward_var = xs_var
            else:
                running.ppo_reward_var = tot_sum / tot_count
            running.ppo_reward_count = tot_count
            return running.ppo_reward_mean, math.sqrt(running.ppo_reward_var)

    def get_loss_and_metrics(self, batch, num_microbatches, forward_only):
        """Run the forward/backward schedule over ``batch`` and report the mean loss.

        Args:
            batch: dict with at least ``"tokens"``; split into
                ``num_microbatches`` micro batches.
            num_microbatches: number of micro batches to run.
            forward_only: forwarded to the schedule function.

        Returns:
            ``(loss, metrics)`` where ``loss`` is a Python float and ``metrics``
            is ``{"loss": loss}``. The value is broadcast from the rank returned
            by ``get_last_rank()``.
        """
        seq_length = batch["tokens"].size(-1)
        micro_batch_iter = get_iterator_k_split(batch, num_microbatches)

        run_schedule = get_forward_backward_func()
        per_micro_batch = run_schedule(
            forward_step_func=self.get_forward_output_and_loss_func(),
            data_iterator=micro_batch_iter,
            model=self.model,
            num_microbatches=num_microbatches,
            forward_only=forward_only,
            seq_length=seq_length,
            micro_batch_size=self.config.micro_batch_size,
        )

        if per_micro_batch:
            # Ranks that received losses average the per-micro-batch values.
            loss_mean = torch.concat([entry["avg"] for entry in per_micro_batch]).mean()
        else:
            # Ranks without losses use a placeholder to receive the broadcast.
            loss_mean = torch.tensor(0.0, device=torch.cuda.current_device())

        torch.distributed.broadcast(loss_mean, get_last_rank())
        loss_value = loss_mean.item()
        return loss_value, {"loss": loss_value}

    def get_forward_output_and_loss_func(self):
        """Build the forward-step closure used for critic training.

        The returned callable takes ``(data_iterator, model)``, runs one forward
        pass and returns ``(output, loss_func)``; it is passed as
        ``forward_step_func`` to the forward/backward schedule function.
        """

        def fwd_output_and_loss_func(data_iterator, model):
            # One micro batch: tokens plus the value-loss targets.
            batch = next(data_iterator)
            tokens = batch["tokens"].cuda()
            returns = batch["returns"]
            prev_values = batch["prev_values"]
            mask = batch["mask"]

            # Positional flags are all False / 0; their meaning is defined by
            # get_ltor_masks_and_position_ids (not visible here).
            attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
                tokens,
                0,
                False,
                False,
                False,
                False,
            )

            # With pipeline parallelism each stage keeps only what it needs:
            # the first stage the inputs, the last stage the loss targets,
            # intermediate stages nothing.
            # NOTE(review): with a single pipeline stage, returns / prev_values /
            # mask are not moved to the GPU here — presumably they already live
            # there; confirm against the data producer.
            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
                if parallel_state.is_pipeline_first_stage():
                    returns, prev_values, mask = None, None, None
                    tokens, position_ids = map(lambda x: x.cuda(non_blocking=True),
                                               (tokens, position_ids))
                elif parallel_state.is_pipeline_last_stage():
                    tokens, position_ids = None, None
                    prev_values, mask, returns = map(lambda x: x.cuda(non_blocking=True),
                                                     (prev_values, mask, returns))
                else:
                    # Intermediate stage: no inputs and no targets.
                    tokens, position_ids, returns, prev_values, mask = [None] * 5

            # Training is only valid while the critic weights are loaded.
            assert self.loaded_state_dict == StateDictState.CRITIC
            # prompt_lens receives a placeholder string; presumably the critic
            # path must never read it — TODO confirm.
            output = model(
                input_ids=tokens,
                lengths=None,
                position_ids=position_ids,
                attention_mask=attention_mask,
                prompt_lens='POISON',
                labels=None,
            )

            def loss_func(curr_values):
                """Clipped value loss: 0.5 * masked mean of the larger squared error."""
                # Drop one position (first or last, per config) before comparing
                # with returns / prev_values.
                if self.config.ppo_value_truncate_head:
                    curr_values = curr_values[:, 1:].contiguous()
                else:
                    curr_values = curr_values[:, :-1].contiguous()
                assert curr_values.dtype == torch.float32
                assert returns.dtype == torch.float32
                assert prev_values.dtype == torch.float32

                # Reward-model scaling; this closure reads the state again at
                # call time even though CRITIC was asserted above.
                if self.loaded_state_dict == StateDictState.REWARD:
                    curr_values = curr_values * self.ppo_reward_scalings[self.loaded_rm_idx]

                clip_val = self.clip_val

                if clip_val > 0.0:
                    # Limit how far the new values may move from the previous ones.
                    values_clipped = prev_values + (curr_values - prev_values).clamp(
                        -clip_val, clip_val)
                    v_loss1 = (values_clipped - returns)**2
                else:
                    # Clipping disabled: the max below reduces to v_loss2.
                    v_loss1 = torch.tensor(0.0).cuda()
                v_loss2 = (curr_values - returns)**2

                # Element-wise max of clipped / unclipped error, averaged over mask.
                loss = 0.5 * masked_mean(torch.max(v_loss1, v_loss2), mask)

                reduced_loss = average_losses_across_data_parallel_group([loss])
                return loss, {"avg": reduced_loss}

            return output, loss_func

        return fwd_output_and_loss_func

    def prepare_for_training(self):
                               
        self._load_critic()

    def broadcaset_data(self, ex_batches: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """Return the result of ``broadcast_object_within_mp`` applied to ``ex_batches``.

        ``infer`` passes ``None`` on ranks that are not the model-parallel head.
        """
        return broadcast_object_within_mp(ex_batches)

    def infer(self, batches: List[Dict[str, Union[int, List[Any]]]], rm_index: int = None):
        """Run a forward-only pass of the currently loaded weights over ``batches``.

        Args:
            batches: rollout batches; only read on the model-parallel head, the
                other ranks obtain the data through ``broadcaset_data``.
            rm_index: reward-model index forwarded to ``forward_step`` (``None``
                for critic inference).

        Returns:
            ``(rewards, per_token_rewards, exceeded)``. The two tensors are the
            results of ``broadcast_2d_tensor_within_pp``; ``exceeded`` is a list
            with one ``False`` per expanded sample.

        Raises:
            AssertionError: if per-token rewards exist while the critic is loaded.
        """
        # Only the model-parallel head expands the rollout batches.
        ex_batches = None
        if is_mp_head():
            ex_batches = expand_rollout_batches(batches, allow_nolist=True)

        ex_batches = self.broadcaset_data(ex_batches)
        seq_length = get_gbs_batches_seqlen(ex_batches, self.pad_to_multi_of)
        seq_length = get_max_seqlen_within_ep(seq_length)

        batch_size = len(ex_batches)
        num_microbatches = divide(batch_size, self.forward_micro_batch_size)
        split_batches = get_iterator_k_split_list(ex_batches, num_microbatches)
        exceeded = [False for _ in range(batch_size)]

        # Smart padding delegates the forward pass to CatedSmartPadInferHelper;
        # otherwise the micro batches are forwarded directly.
        smart_pad_helper = None
        if self.enable_smart_pad:
            get_seqlen_func = lambda input: input["tokens"].shape[-1]
            def forward_step_wrapped_func(batches, num_microbatches: int, micro_batch_size: int, seq_length: int):
                return self.forward_step(batches, num_microbatches, micro_batch_size, seq_length, rm_index)

            smart_pad_helper = CatedSmartPadInferHelper(ex_batches, self.forward_micro_batch_size)
            smart_pad_helper.forward_pipeline(self.pad_to_multi_of, get_seqlen_func=get_seqlen_func, forward_step_wrapped_func=forward_step_wrapped_func)
            rowed_based_fwd_step_rets = smart_pad_helper.get_rowed_based_forward_results()

            # fwd_step_rets is only defined (and only used below) on the last
            # pipeline stage.
            if parallel_state.is_pipeline_last_stage():
                fwd_step_rets = []
                # Convert each row-based result into a column-based dict with
                # stacked tensors, padding per-token rewards first.
                for fwd_id, rowed_based_fwd_step_ret in enumerate(rowed_based_fwd_step_rets):
                    for result_key in ["per_token_reward"]:
                        if result_key in rowed_based_fwd_step_ret[0]:
                            pad_batches_to_multiple_of(rowed_based_fwd_step_ret, self.pad_to_multi_of, result_key, 0)
                    column_based_fwd_step_ret = get_column_based_batches(rowed_based_fwd_step_ret)
                    for result_key in ["per_token_reward", "value_or_reward"]:
                        if result_key in column_based_fwd_step_ret:
                            column_based_fwd_step_ret[result_key] = torch.stack(column_based_fwd_step_ret[result_key])
                    fwd_step_rets.append(column_based_fwd_step_ret)
        else:
            fwd_step_rets = self.forward_step(
                split_batches,
                num_microbatches,
                self.forward_micro_batch_size,
                seq_length,
                rm_index,
            )
        rewards = None
        per_token_rewards = None

        if parallel_state.is_pipeline_last_stage():
            # NOTE(review): fwd_step_rets[0] is indexed before the emptiness
            # check on the next statement, so an empty result list raises
            # IndexError here rather than yielding rewards = None.
            need_tlr = 'per_token_reward' in fwd_step_rets[0]
            rewards = torch.cat([each['value_or_reward']
                                 for each in fwd_step_rets]) if len(fwd_step_rets) > 0 else None
            if need_tlr:
                if self.enable_smart_pad:
                    pad_batches_to_multiple_of(fwd_step_rets, self.pad_to_multi_of, "per_token_reward", 0)
                per_token_rewards = torch.cat([each['per_token_reward'] for each in fwd_step_rets])

            # Reward-model outputs are scaled in place by the per-model scaling.
            if self.loaded_state_dict == StateDictState.REWARD:
                rewards.mul_(self.ppo_reward_scalings[self.loaded_rm_idx])
                if need_tlr:
                    per_token_rewards.mul_(self.ppo_reward_scalings[self.loaded_rm_idx])
        # Called on every rank, including those where the values are still None.
        rewards = broadcast_2d_tensor_within_pp(rewards)
        per_token_rewards = broadcast_2d_tensor_within_pp(per_token_rewards)

        if self.loaded_state_dict == StateDictState.CRITIC:
            assert per_token_rewards is None
        return rewards, per_token_rewards, exceeded

    def forward_step(
        self,
        batches,
        num_microbatches: int,
        micro_batch_size: int,
        seq_length: int,
        rm_index: Optional[int],
    ):
        """
        Run one forward-only pass over `batches` through the pipeline schedule.

        Parameters:
        batches: data iterator handed to the schedule as `data_iterator`.
        num_microbatches: number of micro batches the schedule should consume.
        micro_batch_size: size of each micro batch.
        seq_length: sequence length passed both to the schedule and to the
            forward-step callable.
        rm_index: optional index forwarded to the forward-step callable.

        Returns:
        Whatever the schedule function returns, unmodified.
        NOTE(review): the visible caller reads it as a list of dicts keyed by
        'value_or_reward' / 'per_token_reward' on the last pipeline stage —
        confirm against the schedule implementation.
        """
        run_schedule = get_forward_backward_func()
        step_func = self.get_forward_output_only_func(seq_length, rm_index)

        # forward_only=True: inference only, no backward pass is scheduled.
        schedule_kwargs = dict(
            forward_step_func=step_func,
            data_iterator=batches,
            model=self.model,
            num_microbatches=num_microbatches,
            forward_only=True,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
        )
        return run_schedule(**schedule_kwargs)

    def get_forward_output_only_func(self, seq_length: int, rm_index: Optional[int]):
        """
        Build the forward-step callable used for forward-only inference.

        Parameters:
        seq_length: sequence length forwarded to `prepare_data_for_infer`.
        rm_index: optional index forwarded to `prepare_data_for_infer`.

        Returns:
        A `functools.partial` with `seq_length` and `rm_index` already bound;
        the remaining positional arguments are `(batches, model)`. Calling it
        returns `(output_tensor, id_func)`, where `id_func` maps the model
        output to `(tensor, result_dict)`. `result_dict` always carries
        'value_or_reward' and, for sequence-shaped reward outputs, also
        'per_token_reward'.
        """

        def fwd_output_only_func(seq_length: int, rm_index: Optional[int], batches, model):
            model_inputs = self.prepare_data_for_infer(next(batches), seq_length, rm_index)
            output_tensor = model(**model_inputs)

            on_last_stage = parallel_state.is_pipeline_last_stage()

            # Report inference metrics only from the last pipeline stage on
            # tensor-parallel rank 0.
            if on_last_stage and parallel_state.get_tensor_model_parallel_rank() == 0:
                input_ids = model_inputs["input_ids"]
                token_count = input_ids.shape[0] * input_ids.shape[1]
                report_ppo_metrics({
                    "critic_infer_times": 1,
                    "critic_infer_samples": len(input_ids),
                    "critic_infer_seq_lengths": token_count,
                })

            # Optional user hook that rewrites the raw model output.
            if on_last_stage and self.rm_outputs_modifier is not None:
                output_tensor = self.rm_outputs_modifier(
                    self.loaded_rm_idx,
                    output_tensor,
                    model_inputs["rm_output_mask"],
                )

            def id_func(output_tensor):
                # Critic weights loaded: the output must be a plain tensor.
                if self.loaded_state_dict == StateDictState.CRITIC:
                    assert torch.is_tensor(output_tensor)
                    critic_result = {"value_or_reward": output_tensor.detach().clone()}
                    return output_tensor, critic_result

                assert self.loaded_state_dict == StateDictState.REWARD

                # Non-tensor output: a (sequence, scalar) pair.
                if not torch.is_tensor(output_tensor):
                    assert isinstance(output_tensor, tuple) and len(output_tensor) == 2
                    output_seq, output_scalar = output_tensor
                    pair_result = {
                        "value_or_reward": output_scalar.detach().clone(),
                        "per_token_reward": output_seq.detach().clone(),
                    }
                    return output_scalar, pair_result

                # Single column: the tensor itself is the value/reward.
                if output_tensor.shape[1] == 1:
                    scalar_result = {"value_or_reward": output_tensor.detach().clone()}
                    return output_tensor, scalar_result

                # Wider tensor: returned as the per-token reward, with a zero
                # column filling the 'value_or_reward' slot.
                zero_scalar = torch.zeros(
                    (output_tensor.shape[0], 1),
                    dtype=output_tensor.dtype,
                    device=output_tensor.device,
                )
                token_result = {
                    "value_or_reward": zero_scalar,
                    "per_token_reward": output_tensor.detach().clone(),
                }
                return output_tensor, token_result

            return output_tensor, id_func

        return partial(fwd_output_only_func, seq_length, rm_index)

    def validate_samples(self, rewards, sampling_repeat=None):
        """
        Check the given samples and report which of them are usable.

        Parameters:
        rewards: rewards to validate; must be None on ranks for which
            `is_mp_head()` is false (asserted).
        sampling_repeat: accepted for interface compatibility; not read here.

        Returns:
        dict or None: an empty dict on ranks where `is_mp_head()` is true,
        None on every other rank. This implementation performs no checks.
        NOTE(review): the earlier documentation described a result such as
        {'sample_useful': tensor of usefulness ([b])} for rewards of shape
        [b, 1]; nothing here fills that key — confirm the intended contract.
        """
        if is_mp_head():
            return {}

        # Ranks that are not the head are expected to carry no rewards.
        assert rewards is None
        return None