# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Single Process Actor
"""

import itertools
import logging
import os
from typing import Tuple
import numpy as np
import torch
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, compute_policy_loss_focal, compute_policy_loss_implicit, compute_sft_loss, kl_penalty, compute_policy_loss_kl_cov, compute_policy_loss_clip_cov
from verl.utils.debug import GPUMemoryLogger
from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
from verl.workers.actor import BasePPOActor

# Public API of this module.
__all__ = ["DataParallelPPOActor"]

# Module-level logger named after this file; its level is read from the VERL_LOGGING_LEVEL
# environment variable and defaults to "WARN".
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

def split_dict_by_batch(original_dict, mini_bsz):
    """Split a dict of sequences into a list of consecutive mini-batch dicts.

    Every value is sliced into chunks of ``mini_bsz`` items. The number of chunks is set by
    the longest value, so shorter values contribute empty slices to the trailing chunks.

    Args:
        original_dict (dict): maps each key to a sliceable sequence (list, array, ...).
        mini_bsz (int): number of items per chunk.

    Returns:
        list[dict]: one dict per chunk, each with the same keys as ``original_dict``.
    """
    # Ceil-divide every value length; 0 is included so an empty dict yields no chunks.
    chunk_counts = [(len(seq) + mini_bsz - 1) // mini_bsz for seq in original_dict.values()]
    num_chunks = max([0, *chunk_counts])

    chunks = []
    for chunk_idx in range(num_chunks):
        lo = chunk_idx * mini_bsz
        hi = lo + mini_bsz
        chunks.append({key: seq[lo:hi] for key, seq in original_dict.items()})
    return chunks
import numpy as np
import torch

def _segment_ppl_mask(segment_ids, response_log_probs, response_mask, keep_high):
    """Select, per sample, the segments whose perplexity lies on one side of that sample's median.

    A segment's "perplexity" is the negative mean token log-prob over the segment (a
    log-perplexity; the ordering is identical). Segments are only compared with other
    segments of the SAME sample, and only ids >= 0 that actually occur among the sample's
    valid tokens are considered.

    Assumes the valid tokens of row ``i`` are its first ``response_mask[i].sum()`` positions
    (a right-padded response mask) -- TODO confirm for multi-turn loss masks.

    Args:
        segment_ids (torch.Tensor): (batch_size, max_seq_len) segment id per token.
        response_log_probs (torch.Tensor): (batch_size, max_seq_len) token log-probs.
        response_mask (torch.Tensor): (batch_size, max_seq_len), 1 on valid response tokens.
        keep_high (bool): select segments with ppl >= median if True, ppl <= median otherwise.

    Returns:
        torch.Tensor: mask shaped and typed like ``response_mask``, 1 on selected tokens.
    """
    selected = torch.zeros_like(response_mask, dtype=torch.bool)

    for i in range(segment_ids.shape[0]):
        seq_len = int(response_mask[i].sum().item())
        if seq_len == 0:
            # No valid tokens: nothing to select (.max()/percentile would fail on an empty row).
            continue
        row_segments = segment_ids[i, :seq_len]
        row_log_probs = response_log_probs[i, :seq_len]

        # Use only ids present in this row: gaps in the numbering would otherwise produce empty
        # segments whose NaN mean poisons the percentile and silently selects nothing.
        seg_ids = [s for s in torch.unique(row_segments).tolist() if s >= 0]
        if not seg_ids:
            continue
        seg_ppl = [-row_log_probs[row_segments == s].float().mean().item() for s in seg_ids]
        median = np.percentile(seg_ppl, 50)

        for s, ppl in zip(seg_ids, seg_ppl):
            if (ppl >= median) if keep_high else (ppl <= median):
                selected[i, :seq_len] |= row_segments == s

    return selected.to(response_mask.dtype)


def generate_perplexity_mask(segment_ids, response_log_probs, response_mask, threshold=0.3):
    """Mark the tokens of each sample's high-perplexity segments.

    For every sample independently, the segments whose (log-)perplexity is >= the median over
    that sample's segments are selected.

    Args:
        segment_ids (torch.Tensor): segment ids, shape (batch_size, max_seq_len).
        response_log_probs (torch.Tensor): response log-probs, shape (batch_size, max_seq_len).
        response_mask (torch.Tensor): response mask, shape (batch_size, max_seq_len).
        threshold (float): currently unused -- the cutoff is always the per-sample median.
            Kept for backward compatibility.

    Returns:
        torch.Tensor: mask of high-perplexity tokens, shape (batch_size, max_seq_len), with the
        dtype of ``response_mask``.
    """
    return _segment_ppl_mask(segment_ids, response_log_probs, response_mask, keep_high=True)


def generate_lowppl_mask(segment_ids, response_log_probs, response_mask, threshold=0.3):
    """Mark the tokens of each sample's low-perplexity segments.

    For every sample independently, the segments whose (log-)perplexity is <= the median over
    that sample's segments are selected.

    Args:
        segment_ids (torch.Tensor): segment ids, shape (batch_size, max_seq_len).
        response_log_probs (torch.Tensor): response log-probs, shape (batch_size, max_seq_len).
        response_mask (torch.Tensor): response mask, shape (batch_size, max_seq_len).
        threshold (float): currently unused -- the cutoff is always the per-sample median.
            Kept for backward compatibility.

    Returns:
        torch.Tensor: mask of low-perplexity tokens, shape (batch_size, max_seq_len), with the
        dtype of ``response_mask``.
    """
    return _segment_ppl_mask(segment_ids, response_log_probs, response_mask, keep_high=False)
  

def generate_highentropy_mask(entropy, response_mask, how='instance', threshold=0.3):
    """
    Identifies high-entropy tokens within response segments of a batch.

    This function creates a binary mask to highlight tokens that have the highest
    entropy values, based on a specified percentage threshold. The calculation
    can be performed across all valid tokens in the batch or independently
    for each response (instance) in the batch.

    Args:
        entropy (torch.Tensor): A tensor of shape (bsz, ntoken) containing the
            entropy value for each token.
        response_mask (torch.Tensor): A tensor of shape (bsz, ntoken) with 1s
            indicating tokens that are part of the response and 0s otherwise.
            It should be a float or long tensor.
        how (str): The method for calculation. Must be one of:
            - 'all': The top `threshold` percentage is calculated from all
                     tokens marked `1` in `response_mask` across the
                     entire batch.
            - 'instance': The top `threshold` percentage is calculated for each
                          response (row) independently.
        threshold (float, optional): The fraction of high-entropy tokens to
            select. Defaults to 0.3. The resulting token count is clamped to
            [0, number of valid tokens], so values outside [0, 1] never select
            non-response tokens.

    Returns:
        torch.Tensor: A long tensor of shape (bsz, ntoken) with 1s marking the
            selected high-entropy tokens and 0s elsewhere (always 0 where
            `response_mask` is 0).

    Raises:
        TypeError: If 'entropy' or 'response_mask' are not torch.Tensors.
        ValueError: If input tensors do not have matching shapes or if 'how'
                    is not 'all' or 'instance'.
    """
    # --- Input Validation ---
    if not isinstance(entropy, torch.Tensor) or not isinstance(response_mask, torch.Tensor):
        raise TypeError("Inputs 'entropy' and 'response_mask' must be torch.Tensors.")
    if entropy.shape != response_mask.shape:
        raise ValueError("Inputs 'entropy' and 'response_mask' must have the same shape.")
    if how not in ['all', 'instance']:
        raise ValueError("Argument 'how' must be either 'all' or 'instance'.")

    # --- Initialization ---
    bsz, ntoken = entropy.shape
    device = entropy.device

    # Work on a copy so the caller's tensor is untouched. Non-response tokens get -inf so
    # they sort last and are never picked while k <= number of valid tokens.
    masked_entropy = entropy.clone()
    masked_entropy[response_mask == 0] = -torch.inf

    # --- Mask Generation ---
    if how == 'all':
        # Treat the entire batch as a single sequence.
        flat_masked_entropy = masked_entropy.flatten()

        num_valid_tokens = response_mask.sum()
        if num_valid_tokens == 0:
            return torch.zeros_like(entropy, dtype=torch.long, device=device)

        # Clamp k to [0, num_valid_tokens]: a larger k would pick padding (-inf) tokens or
        # exceed numel inside topk; a negative k would make topk raise.
        k = int(num_valid_tokens * threshold)
        k = min(max(k, 0), int(num_valid_tokens))
        if k == 0:
            return torch.zeros_like(entropy, dtype=torch.long, device=device)

        # Indices of the top-k entropy values in the flattened tensor.
        _, top_indices = torch.topk(flat_masked_entropy, k, sorted=False)

        highentropy_mask = torch.zeros_like(flat_masked_entropy, dtype=torch.long, device=device)
        highentropy_mask[top_indices] = 1

        # Back to the original (bsz, ntoken) shape.
        return highentropy_mask.view(bsz, ntoken)

    # how == 'instance'
    num_valid_per_instance = response_mask.sum(dim=1)
    k_per_instance = (num_valid_per_instance * threshold).int()
    # Same clamp as above, per row: never select more tokens than the row has valid ones.
    k_per_instance = torch.minimum(k_per_instance.clamp(min=0), num_valid_per_instance.int())

    # Sort each row in descending order and keep the original positions.
    _, sorted_indices = torch.sort(masked_entropy, dim=1, descending=True)

    # True for the first k (row-specific) positions of each sorted row.
    arange_tokens = torch.arange(ntoken, device=device).expand(bsz, -1)
    topk_mask = arange_tokens < k_per_instance.unsqueeze(1)

    # Scatter the selections back to their original token positions.
    highentropy_mask = torch.zeros_like(entropy, dtype=torch.long, device=device)
    highentropy_mask.scatter_(1, sorted_indices, topk_mask.long())

    return highentropy_mask


def get_mask(reference, selector):
    """Broadcast a per-row selector across all columns of ``reference``.

    Args:
        reference (torch.Tensor): 2-D tensor whose shape, dtype and device the result copies.
        selector (Sequence): one value (e.g. bool) per row of ``reference``.

    Returns:
        torch.Tensor: tensor like ``reference`` whose row ``i`` is filled with ``selector[i]``
        (converted through float32 to ``reference``'s dtype).
    """
    per_row = torch.FloatTensor(selector).reshape((-1, 1))
    # copy_ broadcasts the (rows, 1) column over every column and casts dtype/device.
    return torch.zeros_like(reference).copy_(per_row)

class DataParallelPPOActor(BasePPOActor):
    def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None):
        """Wrap ``actor_module`` as a data-parallel PPO actor.

        Args:
            config: actor config; read keys include ``use_remove_padding``, ``use_fused_kernels``,
                ``use_torch_compile`` and ``ulysses_sequence_parallel_size``.
            actor_module: the policy model.
            actor_optimizer: optimizer for ``actor_module``. When None, this is the Reference Policy.
        """
        super().__init__(config)
        self.actor_module = actor_module
        self.actor_optimizer = actor_optimizer

        cfg = self.config
        self.use_remove_padding = cfg.get("use_remove_padding", False)
        print(f"Actor use_remove_padding={self.use_remove_padding}")
        self.use_fused_kernels = cfg.get("use_fused_kernels", False)
        print(f"Actor use_fused_kernels={self.use_fused_kernels}")

        self.ulysses_sequence_parallel_size = cfg.ulysses_sequence_parallel_size
        self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1

        # torch.compile is used by default unless the config sets use_torch_compile=False.
        use_torch_compile = cfg.get("use_torch_compile", True)

        def _maybe_compile(fn):
            return torch.compile(fn, dynamic=True) if use_torch_compile else fn

        self.compute_entropy_from_logits = _maybe_compile(verl_F.entropy_from_logits)
        self.compute_self_certainty_from_logits = _maybe_compile(verl_F.self_certainty_from_logits)

        if self.use_fused_kernels:
            from verl.utils.experimental.torch_functional import FusedLinearForPPO

            # NOTE: FusedLinearForPPO has an error when compiled, so it is deliberately left
            # uncompiled even when use_torch_compile is on.
            self.fused_linear_for_ppo = FusedLinearForPPO()

    def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False, calculate_self_certainty=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the actor on one micro-batch and return per-token statistics of the response part.

        Args:
            micro_batch: mapping holding ``input_ids``, ``attention_mask``, ``position_ids``
                ((bsz, seqlen), or (bsz, 3, seqlen) position ids for qwen2vl mrope) and ``responses``
                (bsz, response_len). An optional ``multi_modal_inputs`` entry is a list of dicts
                whose tensors are concatenated along dim 0 and passed to the model.
            temperature: logits are divided by this value before log-probs/entropy are computed.
            calculate_entropy: also compute per-token entropy.
            calculate_self_certainty: also compute per-token self-certainty. Not supported with
                ``use_fused_kernels`` because the fused kernel never materializes the logits.

        Returns:
            ``(entropy, log_probs, self_certainty)``, each of shape (bs, response_len).
            ``entropy`` and ``self_certainty`` may be None when they were not requested.

        Raises:
            NotImplementedError: if ``calculate_self_certainty`` is requested with fused kernels.
        """
        if calculate_self_certainty and self.use_fused_kernels:
            # The fused linear kernel produces log-probs/entropy straight from the hidden states
            # and never materializes logits, so there is nothing to compute self-certainty from.
            raise NotImplementedError("calculate_self_certainty is not supported when use_fused_kernels=True")

        response_length = micro_batch["responses"].size(-1)
        multi_modal_inputs = {}
        if "multi_modal_inputs" in micro_batch:
            for key in micro_batch["multi_modal_inputs"][0].keys():
                multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            input_ids = micro_batch["input_ids"]
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch["attention_mask"]
            position_ids = micro_batch["position_ids"]
            entropy = None
            self_certainty = None
            if position_ids.dim() == 3:  # qwen2vl mrope
                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)

            if self.use_remove_padding:
                input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align the rotary
                if position_ids.dim() == 3:
                    position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices).transpose(0, 1).unsqueeze(1)  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
                else:
                    position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1)

                # for compute the log_prob
                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)

                # pad and slice the inputs if sp > 1
                if self.use_ulysses_sp:
                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                        input_ids_rmpad,
                        position_ids_rmpad=position_ids_rmpad,
                        sp_size=self.ulysses_sequence_parallel_size,
                    )
                    input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
                        input_ids_rmpad_rolled,
                        position_ids_rmpad=None,
                        sp_size=self.ulysses_sequence_parallel_size,
                    )

                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)

                # only pass input_ids and position_ids to enable flash_attn_varlen
                output = self.actor_module(
                    input_ids=input_ids_rmpad,
                    attention_mask=None,
                    position_ids=position_ids_rmpad,
                    **multi_modal_inputs,
                    use_cache=False,
                )  # prevent model thinks we are generating

                if self.use_fused_kernels:
                    hidden_states = output.last_hidden_state
                    vocab_weights = self.actor_module.lm_head.weight

                    log_probs, entropy_rmpad = self.fused_linear_for_ppo(
                        hidden_states=hidden_states.squeeze(0),
                        vocab_weights=vocab_weights,
                        input_ids=input_ids_rmpad_rolled,
                        temperature=temperature,
                    )

                else:
                    logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
                    logits_rmpad.div_(temperature)

                    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
                    # Entropy and self-certainty both read logits_rmpad too, so the in-place backward of
                    # logprobs_from_logits is disabled whenever either is requested (presumably the
                    # in-place backward reuses the logits buffer -- see verl.utils.torch_functional).
                    inplace_backward = not (calculate_entropy or calculate_self_certainty)
                    log_probs = logprobs_from_logits(
                        logits=logits_rmpad,
                        labels=input_ids_rmpad_rolled,
                        inplace_backward=inplace_backward,
                    )

                    # compute entropy / self-certainty
                    if calculate_entropy:
                        entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)
                    if calculate_self_certainty:
                        self_certainty_rmpad = self.compute_self_certainty_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)

                # gather log_prob if sp > 1
                if self.use_ulysses_sp:
                    # gather and unpad for the ulysses sp
                    log_probs = gather_outpus_and_unpad(
                        log_probs,
                        gather_dim=0,
                        unpad_dim=0,
                        padding_size=pad_size,
                    )
                    if calculate_entropy:
                        entropy_rmpad = gather_outpus_and_unpad(
                            entropy_rmpad,
                            gather_dim=0,
                            unpad_dim=0,
                            padding_size=pad_size,
                        )
                    if calculate_self_certainty:
                        self_certainty_rmpad = gather_outpus_and_unpad(
                            self_certainty_rmpad,
                            gather_dim=0,
                            unpad_dim=0,
                            padding_size=pad_size,
                        )
                # pad back to (bsz, seqlen)
                if calculate_entropy:
                    full_entropy = pad_input(
                        hidden_states=entropy_rmpad.unsqueeze(-1),
                        indices=indices,
                        batch=batch_size,
                        seqlen=seqlen,
                    )
                if calculate_self_certainty:
                    full_self_certainty = pad_input(
                        hidden_states=self_certainty_rmpad.unsqueeze(-1),
                        indices=indices,
                        batch=batch_size,
                        seqlen=seqlen,
                    )
                full_log_probs = pad_input(
                    hidden_states=log_probs.unsqueeze(-1),
                    indices=indices,
                    batch=batch_size,
                    seqlen=seqlen,
                )

                # only return response part:
                if calculate_entropy:
                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)
                if calculate_self_certainty:
                    self_certainty = full_self_certainty.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)
                log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)

            else:  # not using rmpad and no ulysses sp
                output = self.actor_module(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    **multi_modal_inputs,
                    use_cache=False,
                )  # prevent model thinks we are generating

                if self.use_fused_kernels:
                    hidden_states = output.last_hidden_state
                    vocab_weights = self.actor_module.lm_head.weight

                    log_probs, entropy = self.fused_linear_for_ppo(
                        hidden_states=hidden_states[:, -response_length - 1 : -1, :],
                        vocab_weights=vocab_weights,
                        input_ids=micro_batch["responses"],
                        temperature=temperature,
                    )
                else:
                    logits = output.logits

                    logits.div_(temperature)
                    logits = logits[:, -response_length - 1 : -1, :]  # (bsz, response_length, vocab_size)
                    log_probs = logprobs_from_logits(logits, micro_batch["responses"])
                    if calculate_entropy:
                        entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)
                    if calculate_self_certainty:
                        self_certainty = verl_F.self_certainty_from_logits(logits)  # (bsz, response_length)
            return entropy, log_probs, self_certainty

    def _optimizer_step(self):
        """Clip gradients to ``config.grad_clip`` and step the optimizer.

        The step is skipped (and the accumulated gradients are zeroed instead) when the
        gradient norm is not finite.

        Returns:
            The gradient norm reported by the clipping routine.
        """
        assert self.config.grad_clip is not None
        max_norm = self.config.grad_clip
        module = self.actor_module

        # Pick the clipping routine matching how the module is wrapped (FSDP1, FSDP2, plain).
        if isinstance(module, FSDP):
            grad_norm = module.clip_grad_norm_(max_norm=max_norm)
        elif isinstance(module, FSDPModule):
            grad_norm = fsdp2_clip_grad_norm_(module.parameters(), max_norm=max_norm)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_norm)

        if torch.isfinite(grad_norm):
            self.actor_optimizer.step()
            return grad_norm

        # Non-finite norm: drop this update rather than applying corrupted gradients.
        print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}")
        self.actor_optimizer.zero_grad()
        return grad_norm

    @GPUMemoryLogger(role="dp actor", logger=logger)
    def compute_log_prob(self, data: DataProto, calculate_entropy=False, calculate_self_certainty=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the log probability of the responses given input_ids, attention_mask and position_ids

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``responses``:  tensor of shape [batch_size, response_length]. torch.int64.

                ``meta_info`` must provide ``micro_batch_size``, ``temperature`` and
                ``use_dynamic_bsz`` (plus ``max_token_len`` when dynamic batching is on).
            calculate_entropy: also return per-token entropy.
            calculate_self_certainty: also return per-token self-certainty.

        Returns:
            Tuple ``(log_probs, entropys, self_certaintys)``, each [batch_size, response_length] and in
            the sample order of ``data``; ``entropys`` / ``self_certaintys`` are None unless requested.
        """
        # set to eval
        self.actor_module.eval()

        micro_batch_size = data.meta_info["micro_batch_size"]
        temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error
        use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]

        def _get_micro_batches(data: DataProto) -> Tuple[list, list | None]:
            # Returns (micro_batches, indices); indices is the per-micro-batch list of original
            # sample positions when dynamic batching reorders samples, else None.
            select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
            batch = data.select(batch_keys=select_keys).batch
            has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch

            if has_multi_modal_inputs:
                all_multi_modal_inputs_list = data.non_tensor_batch["multi_modal_inputs"]
                if use_dynamic_bsz:
                    max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
                    rearranged_text_micro_batches, textual_indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)

                    final_micro_batches_list = []
                    for i, text_mb_td in enumerate(rearranged_text_micro_batches):
                        current_original_indices = textual_indices[i]
                        current_mm_inputs_list = [all_multi_modal_inputs_list[idx] for idx in current_original_indices]

                        mb_dict = {k: v for k, v in text_mb_td.items()}
                        mb_dict["multi_modal_inputs"] = current_mm_inputs_list
                        final_micro_batches_list.append(mb_dict)
                    return final_micro_batches_list, textual_indices
                else:
                    num_micro_batches = batch.batch_size[0] // micro_batch_size
                    micro_batches_dp = data.chunk(num_micro_batches)
                    return micro_batches_dp, None
            elif use_dynamic_bsz:
                max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
                micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
                return micro_batches, indices
            else:
                micro_batches = batch.split(micro_batch_size)
                return micro_batches, None

        micro_batches, indices = _get_micro_batches(data)

        log_probs_lst = []
        entropy_lst = []
        self_certainty_lst = []
        for micro_batch in micro_batches:
            if isinstance(micro_batch, DataProto):
                micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}
            with torch.no_grad():
                entropy, log_probs, self_certainty = self._forward_micro_batch(micro_batch, temperature=temperature, calculate_entropy=calculate_entropy, calculate_self_certainty=calculate_self_certainty)
            log_probs_lst.append(log_probs)
            if calculate_entropy:
                entropy_lst.append(entropy)
            if calculate_self_certainty:
                self_certainty_lst.append(self_certainty)

        log_probs = torch.concat(log_probs_lst, dim=0)
        entropys = None
        self_certaintys = None
        if calculate_entropy:
            entropys = torch.concat(entropy_lst, dim=0)
        if calculate_self_certainty:
            self_certaintys = torch.concat(self_certainty_lst, dim=0)
        if use_dynamic_bsz:
            indices = list(itertools.chain.from_iterable(indices))
            assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
            # Undo the token-length-balanced reordering so EVERY returned tensor lines up with the
            # input batch (self-certainty included).
            log_probs = log_probs[revert_indices]
            if calculate_entropy:
                entropys = entropys[revert_indices]
            if calculate_self_certainty:
                self_certaintys = self_certaintys[revert_indices]

        return log_probs, entropys, self_certaintys
    
    @GPUMemoryLogger(role="dp actor", logger=logger)
    def update_policy(self, data: DataProto, extra_data=None):
        # make sure we are in training mode
        self.actor_module.train()

        temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error
        multi_turn = data.meta_info.get("multi_turn", False)
        non_tensor_batch = data.non_tensor_batch
        tmp = ['offpolicy' in x for x in non_tensor_batch["index"]]
        select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages", "segments"]
        if multi_turn:
            select_keys.append("loss_mask")
        if self.config.use_kl_loss or getattr(self.config, "use_ref", False):
            select_keys.append("ref_log_prob")
        if self.config.explore_loss.startswith("segment_entropy"):
            select_keys.append('segments')
        if self.config.explore_loss.startswith("lower_ppl"):
            select_keys.append('sequence_level_ppl')
        batch = data.select(batch_keys=select_keys).batch
        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()

        # Split to make minibatch iterator for updating the actor
        # See PPO paper for details. https://arxiv.org/abs/1707.06347
        if has_multi_modal_inputs:
            num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
            non_tensor_select_keys = ["multi_modal_inputs"]
            dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
        else:
            dataloader = batch.split(self.config.ppo_mini_batch_size)
        

        metrics = {}
        split_non_tensor_batch = split_dict_by_batch(original_dict=non_tensor_batch, mini_bsz=self.config.ppo_mini_batch_size)

        for epoch in range(self.config.ppo_epochs):
            for batch_idx, data in enumerate(dataloader):
                # split batch into micro_batches
                mini_batch = data
                mini_batch_non_tensor_batch = split_non_tensor_batch[batch_idx]
                selector = ['explore' in x for x in mini_batch_non_tensor_batch["index"]]
                op_selector = ['offpolicy' in x for x in mini_batch_non_tensor_batch["index"]]
                if "responses" in mini_batch:
                    mini_batch["entropy_mask"] = get_mask(mini_batch['responses'], selector)
                    mini_batch["op_mask"] = get_mask(mini_batch['responses'], op_selector)
                    
                if has_multi_modal_inputs:
                    self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                    num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu
                    micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
                    # micro_non_tensor_batch = split_dict_by_batch(original_dict=mini_batch_non_tensor_batch, mini_bsz=num_micro_batches)
                elif self.config.use_dynamic_bsz:
                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                    micro_batches, grouped_idxes = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
                
                else:
                    self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                    # split batch into micro_batches
                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
                    # micro_non_tensor_batch = split_dict_by_batch(original_dict=mini_batch_non_tensor_batch, mini_bsz=self.config.ppo_micro_batch_size_per_gpu)

                self.actor_optimizer.zero_grad()
                cnt = 0
                for data in micro_batches:
                    # Support all hardwares
                    cnt += 1
                    if isinstance(data, DataProto):
                        data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch}
                    else:
                        data = data.to(torch.cuda.current_device())  # actor device is cpu when using offload
                        
                    responses = data["responses"]
                    response_length = responses.size(1)
                    attention_mask = data["attention_mask"]
                    if multi_turn:
                        response_mask = data["loss_mask"][:, -response_length:]
                    else:
                        response_mask = attention_mask[:, -response_length:]

                    old_log_prob = data["old_log_probs"]
                    advantages = data["advantages"]
                    ref_log_probs = None 

                    if "ref_log_prob" in data:
                        ref_log_probs = data['ref_log_prob']

                    clip_ratio = self.config.clip_ratio
                    clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
                    clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
                    clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
                    entropy_coeff = self.config.entropy_coeff
                    loss_agg_mode = self.config.loss_agg_mode
                    loss_mode = self.config.loss_mode

                    # all return: (bsz, response_length)
                    calculate_entropy = False
                    # if entropy_coeff != 0:
                    calculate_entropy = True

                    entropy, log_prob, certainty = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy)
                    if loss_mode == "vanilla":
                        pg_loss, pg_clipfrac, ppo_kl = compute_policy_loss(
                            old_log_prob=old_log_prob,
                            log_prob=log_prob,
                            advantages=advantages,
                            response_mask=response_mask,
                            cliprange=clip_ratio,
                            cliprange_low=clip_ratio_low,
                            cliprange_high=clip_ratio_high,
                            loss_agg_mode=loss_agg_mode,
                        )
                    elif loss_mode == "focal":
                        pg_loss, pg_clipfrac, ppo_kl = compute_policy_loss_focal(
                            old_log_prob=old_log_prob,
                            log_prob=log_prob,
                            advantages=advantages,
                            response_mask=response_mask,
                            cliprange=clip_ratio,
                            cliprange_low=clip_ratio_low,
                            cliprange_high=clip_ratio_high,
                            loss_agg_mode=loss_agg_mode,
                        )
                    elif loss_mode.startswith("implicit"):
                        pg_loss, pg_clipfrac, ppo_kl = compute_policy_loss_implicit(
                            old_log_prob=old_log_prob,
                            log_prob=log_prob,
                            advantages=advantages,
                            response_mask=response_mask,
                            cliprange=clip_ratio,
                            cliprange_low=clip_ratio_low,
                            cliprange_high=clip_ratio_high,
                            loss_agg_mode=loss_agg_mode,
                            entropy=entropy,
                            ref_log_prob=ref_log_probs
                        )
                    elif loss_mode == "20percent":
                        mask_higher = generate_highentropy_mask(entropy, response_mask)
                        pg_loss, pg_clipfrac, ppo_kl = compute_policy_loss(
                            old_log_prob=old_log_prob,
                            log_prob=log_prob,
                            advantages=advantages,
                            response_mask=mask_higher,
                            cliprange=clip_ratio,
                            cliprange_low=clip_ratio_low,
                            cliprange_high=clip_ratio_high,
                            loss_agg_mode=loss_agg_mode,
                        )
                    elif loss_mode == "clip_cov":
                        pg_loss, pg_clipfrac, ppo_kl= compute_policy_loss_clip_cov(
                            old_log_prob=old_log_prob,
                            log_prob=log_prob,
                            advantages=advantages,
                            response_mask=response_mask,
                            cliprange=clip_ratio,
                            cliprange_low=clip_ratio_low,
                            cliprange_high=clip_ratio_high,
                            loss_agg_mode=loss_agg_mode,
                            clip_ratio=self.config.clip_cov_ratio,
                            clip_cov_lb=self.config.clip_cov_lb,
                            clip_cov_ub=self.config.clip_cov_ub,
                        )

                    elif loss_mode == "kl_cov":
                        pg_loss, pg_clipfrac, ppo_kl= compute_policy_loss_kl_cov(
                            old_log_prob=old_log_prob,
                            log_prob=log_prob,
                            advantages=advantages,
                            response_mask=response_mask,
                            loss_agg_mode=loss_agg_mode,
                            k_percent=self.config.k_percent,
                            ppo_kl_coef=self.config.ppo_kl_coef,
                        )

                    else:
                        raise ValueError(f"Unsupported loss mode: {self.config.loss_mode}")
                    ######### check segment statistics
                    segments = data['segments'] 
                    low_ppl_mask = generate_perplexity_mask(segments, log_prob, response_mask)
                    flag = (response_mask*low_ppl_mask).sum()>0
                    if flag:
                        high_perplexity_seg_token_entropy = entropy[(response_mask*low_ppl_mask).bool()]
                        low_perplexity_seg_token_entropy = entropy[torch.logical_not((response_mask*(1-low_ppl_mask)).bool())]
                        metrics["segments/high_perplexity_entropy_min"] = high_perplexity_seg_token_entropy.min().item()
                        metrics["segments/high_perplexity_entropy_max"] = high_perplexity_seg_token_entropy.max().item()
                        metrics["segments/high_perplexity_entropy_mean"] = high_perplexity_seg_token_entropy.mean().item()
                        metrics["segments/low_perplexity_entropy_min"] = low_perplexity_seg_token_entropy.min().item()
                        metrics["segments/low_perplexity_entropy_max"] = low_perplexity_seg_token_entropy.max().item()
                        metrics["segments/low_perplexity_entropy_mean"] = low_perplexity_seg_token_entropy.mean().item()
                    if entropy_coeff != 0: #  and flag:
                        print(f"===> {self.config.explore_loss} entropy_loss: ")
                        emin = entropy[entropy>0].min().detach()
                        emean = entropy[entropy>0].mean().detach()
                        sigma = (emean - emin)/3. 
                        upper = sigma*2.0 + emean
                        print(f"===> entropy_loss: emin: {emin.item()}, emean: {emean.item()}, emax:{entropy.max().detach().item()}, sigma: {sigma.item()}, clipped: {upper.item()}")
                        if self.config.explore_loss=="lower_entropy":
                            mask0 = response_mask * data['entropy_mask']
                            mask = mask0 * (entropy<1.0).to(response_mask.dtype)
                            entropy_loss = agg_loss(loss_mat=entropy, loss_mask=mask, loss_agg_mode=loss_agg_mode)
                            metrics["actor/entropy_loss"] = entropy_loss.detach().item()
                            metrics["actor/entropy_max"] = entropy.detach().max().item()
                            metrics["actor/policy_entropy"] = agg_loss(loss_mat=entropy, loss_mask=response_mask * (1-data['entropy_mask']), loss_agg_mode=loss_agg_mode).detach().max().item()
                            ratio = mask.sum() / mask0.sum()
                            metrics["actor/smaller_entropy_ratio_explorer"] = ratio.detach().item()
                            target_entropy = 0.15
                            mean_entropy = entropy_loss.detach().item()
                            scaler =  min(target_entropy / (mean_entropy +1e-8), 5.0)
                            policy_loss = pg_loss - entropy_coeff * scaler * entropy_loss
                            print(f"===> {self.config.explore_loss}  entropy_loss: {scaler} * {entropy_coeff} (target={target_entropy}, mean={mean_entropy})")
                        elif self.config.explore_loss=="lower_entropy_noexp":
                            mask0 = response_mask 
                            mask = mask0 * (entropy<1.0).to(response_mask.dtype)
                            entropy_loss = agg_loss(loss_mat=entropy, loss_mask=mask, loss_agg_mode=loss_agg_mode)
                            metrics["actor/entropy_loss"] = entropy_loss.detach().item()
                            metrics["actor/entropy_max"] = entropy.detach().max().item()
                            metrics["actor/policy_entropy"] = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode).detach().max().item()
                            ratio = mask.sum() / mask0.sum()
                            metrics["actor/smaller_entropy_ratio_explorer"] = ratio.detach().item()
                            target_entropy = 0.15
                            mean_entropy = entropy_loss.detach().item()
                            scaler =  min(target_entropy / (mean_entropy +1e-8), 5.0)
                            policy_loss = pg_loss - entropy_coeff * scaler * entropy_loss
                            print(f"===> {self.config.explore_loss} entropy_loss: {scaler} * {entropy_coeff} (target={target_entropy}, mean={mean_entropy})")
                        elif self.config.explore_loss=="lower_ppl_entropy_noexp":
                            ppls = data['sequence_level_ppl']
                            lower_ppl_ratio = (ppls<0.6).float().sum()/len(ppls)
                            mask0 = response_mask.clone() 
                            mask0[ppls>=0.6] = 0 
                            mask0[advantages>0] = 0
                            entloss_ratio = mask0.any(-1).sum()/len(ppls)

                            mask = mask0 
                            entropy_loss = agg_loss(loss_mat=entropy, loss_mask=mask, loss_agg_mode=loss_agg_mode)
                            metrics["actor/entropy_loss"] = entropy_loss.detach().item()
                            metrics["actor/entropy_max"] = entropy.detach().max().item()
                            metrics["actor/policy_entropy"] = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode).detach().max().item()
                            metrics["actor/smaller_ppl_ratio"] = lower_ppl_ratio.detach().item()
                            target_entropy = 0.15
                            mean_entropy = entropy_loss.detach().item()
                            scaler =  min(target_entropy / (mean_entropy +1e-8), 5.0)
                            policy_loss = pg_loss - entropy_coeff * scaler * entropy_loss
                            print(f'===> lower ppl ratio: {metrics["actor/smaller_ppl_ratio"]}, has entropy loss ratio: {entloss_ratio.item()}')
                            print(f"===> {self.config.explore_loss} entropy_loss: {scaler} * {entropy_coeff} (target={target_entropy}, mean={mean_entropy})")
                        
                        elif self.config.explore_loss=="segment_entropy":
                            segments = data['segments'] # why is segments not available?
                            ppl_mask = generate_perplexity_mask(segments, log_prob, response_mask)
                            mask0 = response_mask * data['entropy_mask']
                            mask1 = ppl_mask * mask0 
                            high_perplexity_ratio = mask1.sum()/mask0.sum()
                            mask0 = mask1
                            metrics["actor/high_perplexity_ratio_explorer"] = high_perplexity_ratio.detach().item()
                            mask = mask0 * (entropy<0.6).to(response_mask.dtype)
                            entropy_loss = agg_loss(loss_mat=entropy, loss_mask=mask, loss_agg_mode=loss_agg_mode)
                            metrics["actor/entropy_loss"] = entropy_loss.detach().item()
                            metrics["actor/entropy_max"] = entropy.detach().max().item()
                            metrics["actor/policy_entropy"] = agg_loss(loss_mat=entropy, loss_mask=response_mask * (1-data['entropy_mask']) * (1-data['op_mask']), loss_agg_mode=loss_agg_mode).detach().max().item()
                            metrics["actor/explorer_entropy"] = agg_loss(loss_mat=entropy, loss_mask=response_mask * data['entropy_mask'], loss_agg_mode=loss_agg_mode).detach().max().item()
                            
                            ratio = mask.sum() / mask0.sum()
                            metrics["actor/smaller_entropy_ratio_explorer"] = ratio.detach().item()
                            target_entropy = 0.15
                            mean_entropy = entropy_loss.detach().item()
                            scaler =  min(target_entropy / (mean_entropy +1e-8), 5.0)
                            if mean_entropy<0.075: 
                                policy_loss = pg_loss - entropy_coeff * scaler * entropy_loss
                            else: policy_loss = pg_loss
                            print(f"===> {self.config.explore_loss} entropy_loss: {scaler} * {entropy_coeff} *{entropy_loss.item()} (target={target_entropy}, mean={mean_entropy})")
                            print(f"===> high ppl ratio: {high_perplexity_ratio.detach().item()}, smaller entropy ratio: {ratio.item()}")
                                  
                        elif self.config.explore_loss.startswith("segment_entropy_v2"):
                            # segments = data['segments'] # why is segments not available?
                            # low_ppl_mask = generate_perplexity_mask(segments, log_prob, response_mask)
                            if "noexp" in self.config.explore_loss:
                                mask0 = response_mask
                            else: mask0 = response_mask  * data['entropy_mask']
                            mask1 = low_ppl_mask * mask0 
                            high_perplexity_ratio = mask1.sum()/mask0.sum()
                            mask0 = mask1
                            metrics["actor/high_perplexity_ratio_explorer"] = high_perplexity_ratio.detach().item()
                            mask = mask0 * (entropy<1.0).to(response_mask.dtype)
                            entropy_loss = agg_loss(loss_mat=entropy, loss_mask=mask, loss_agg_mode=loss_agg_mode)
                            # breakpoint()
                            metrics["actor/entropy_loss"] = entropy_loss.detach().item()
                            metrics["actor/entropy_max"] = entropy.detach().max().item()
                            metrics["actor/policy_entropy"] = agg_loss(loss_mat=entropy, loss_mask=response_mask , loss_agg_mode=loss_agg_mode).detach().max().item()
                            ratio = mask.sum() / mask0.sum()
                            metrics["actor/smaller_entropy_ratio_explorer"] = ratio.detach().item()
                            target_entropy = 0.3
                            mean_entropy = entropy_loss.detach().item()
                            scaler =  min(target_entropy / (mean_entropy +1e-8), 5.0)
                            policy_loss = pg_loss - entropy_coeff * scaler * entropy_loss
                            print(f"===> {self.config.explore_loss} entropy_loss: {scaler} * {entropy_coeff} (target={target_entropy}, mean={mean_entropy})")
                            
                        elif self.config.explore_loss=="segment_entropy_noexp":
                            # segments = data['segments'] # why is segments not available?
                            # low_ppl_mask = generate_perplexity_mask(segments, log_prob, response_mask)
                            mask0 = response_mask  
                            mask1 = low_ppl_mask * mask0 
                            high_perplexity_ratio = mask1.sum()/mask0.sum()
                            mask0 = mask1
                            metrics["actor/high_perplexity_ratio_explorer"] = high_perplexity_ratio.detach().item()
                            mask = mask0 * (entropy<1.0).to(response_mask.dtype)
                            entropy_loss = agg_loss(loss_mat=entropy, loss_mask=mask, loss_agg_mode=loss_agg_mode)
                            metrics["actor/entropy_loss"] = entropy_loss.detach().item()
                            metrics["actor/entropy_max"] = entropy.detach().max().item()
                            metrics["actor/policy_entropy"] = agg_loss(loss_mat=entropy, loss_mask=response_mask , loss_agg_mode=loss_agg_mode).detach().max().item()
                            ratio = mask.sum() / mask0.sum()
                            metrics["actor/smaller_entropy_ratio_explorer"] = ratio.detach().item()
                            target_entropy = 0.15
                            mean_entropy = entropy_loss.detach().item()
                            scaler =  min(target_entropy / (mean_entropy +1e-8), 5.0)
                            policy_loss = pg_loss - entropy_coeff * scaler * entropy_loss
                            print(f"===> {self.config.explore_loss} entropy_loss: {scaler} * {entropy_coeff} (target={target_entropy}, mean={mean_entropy})")
                            
                        else:
                            mask0 = response_mask 
                            metrics["actor/entropy_max"] = entropy.detach().max().item()
                            metrics["actor/policy_entropy"] = entropy.detach().mean().item()
                            higher_entropy = (entropy>entropy.mean()).to(mask0.dtype) 
                            mask = mask0 * higher_entropy 
                            ratio = mask.sum()/mask0.sum()
                            metrics["actor/higher_entropy_ratio"] = ratio.detach().item()
                            ###### uncomment below will use only higher entropy
                            entropy_loss = agg_loss(loss_mat=entropy, loss_mask=mask, loss_agg_mode=loss_agg_mode)
                            metrics["actor/entropy_loss"] = entropy_loss.detach().item()
                            entropy_part = - 0.1*entropy_coeff * entropy_loss
                            policy_loss = pg_loss + entropy_part
                            if self.config.explore_loss!="higher_entropy":
                                mask = mask0
                                print(f"====> standard entropy loss {entropy_part.item()}")
                            else:
                                print(f"===> higher entropy mode filters with {ratio} higer entropy tokens")
                            

                        # if emean>4.0 or upper>8.0:
                        #     scaler = 0.0
                        #     policy_loss = pg_loss
                        #     print(f"===> entropy loss skipped")
                        # else:
                        #     entropy = torch.clamp(entropy, max=upper)
                        #     mask = response_mask * data['entropy_mask'] if self.config.explore else response_mask
                        #     entropy_loss = agg_loss(loss_mat=entropy, loss_mask=mask, loss_agg_mode=loss_agg_mode)
                        #     metrics["actor/entropy_loss"] = entropy_loss.detach().item()
                        #     metrics["actor/entropy_max"] = entropy.detach().max().item()
                        #     # compute policy loss
                        #     upper = 0.6 
                        #     target_entropy = 0.5
                        #     mean_entropy = entropy_loss.detach().item()
                        #     scaler =  min(target_entropy / (mean_entropy +1e-8), 5.0)
                        #     policy_loss = pg_loss - scaler * entropy_coeff * entropy_loss
                        #     print(f"===> explore={self.config.explore}, entropy loss weight {entropy_coeff}*{scaler} (mean-entropy-loss={mean_entropy}). loss_agg_mode={loss_agg_mode}")
                    else:
                        policy_loss = pg_loss
                        metrics["actor/policy_entropy"] = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode).item()

                    if self.config.use_kl_loss:
                        ref_log_prob = data["ref_log_prob"]
                        # compute kl loss
                        kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
                        kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)

                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
                        metrics["actor/kl_loss"] = kl_loss.detach().item()
                        metrics["actor/kl_coef"] = self.config.kl_loss_coef
                    
                    if self.config.use_dynamic_bsz:
                        # relative to the dynamic bsz
                        loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
                    else:
                        loss = policy_loss / self.gradient_accumulation
                    loss.backward()

                    data = {
                        "actor/pg_loss": pg_loss.detach().item(),
                        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
                        "actor/ppo_kl": ppo_kl.detach().item(),
                        # "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
                    }
                    append_to_dict(metrics, data)

                grad_norm = self._optimizer_step()
                data = {"actor/grad_norm": grad_norm.detach().item()}
            append_to_dict(metrics, data)
        self.actor_optimizer.zero_grad()
        return metrics
    
    @GPUMemoryLogger(role="dp actor", logger=logger)
    def update_policy_sft(self, data: DataProto):
        """Run one supervised fine-tuning (SFT) update pass over ``data``.

        Splitting:
            - ``data`` is split into mini-batches of ``ppo_mini_batch_size``, or
              chunked when multi-modal inputs are present.
            - Each mini-batch is split into micro-batches. This uses a fixed size,
              a token budget when ``use_dynamic_bsz`` is set, or chunking for
              multi-modal inputs.

        Per micro-batch:
            - The loss from ``compute_sft_loss`` is computed, plus a KL penalty
              against ``ref_log_prob`` when ``use_kl_loss`` is set.
            - The loss is scaled for gradient accumulation and back-propagated.

        Per mini-batch, one optimizer step is taken.

        Note: ``entropy_coeff`` is not applied on this path; no entropy
        regularization term is added to the SFT loss.

        Args:
            data: batch providing ``responses``, ``input_ids``, ``attention_mask``,
                ``position_ids`` and ``segments``. It also needs ``loss_mask`` for
                multi-turn, ``ref_log_prob`` when ``use_kl_loss``/``use_ref`` is
                set, and ``sequence_level_ppl`` for ``lower_ppl*`` explore losses.
                ``meta_info["temperature"]`` is required.

        Returns:
            dict of metrics:
                - ``actor/pg_loss`` is a list with one entry per micro-batch.
                - ``actor/grad_norm`` is a list with one entry per mini-batch.
                - Every other key holds the value from the last micro-batch that
                  produced it.
        """
        # make sure we are in training mode
        self.actor_module.train()

        temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error
        multi_turn = data.meta_info.get("multi_turn", False)
        select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "segments"]
        if multi_turn:
            select_keys.append("loss_mask")
        if self.config.use_kl_loss or getattr(self.config, "use_ref", False):
            select_keys.append("ref_log_prob")
        if self.config.explore_loss.startswith("lower_ppl"):
            select_keys.append("sequence_level_ppl")
        batch = data.select(batch_keys=select_keys).batch
        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
        non_tensor_select_keys = ["multi_modal_inputs"]

        # Split to make minibatch iterator for updating the actor
        # See PPO paper for details. https://arxiv.org/abs/1707.06347
        if has_multi_modal_inputs:
            num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
            dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
        else:
            dataloader = batch.split(self.config.ppo_mini_batch_size)

        # These hyper-parameters do not change during the update; read them once.
        clip_ratio = self.config.clip_ratio
        clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
        clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
        loss_agg_mode = self.config.loss_agg_mode

        metrics = {}
        for mini_batch in dataloader:
            # split the mini-batch into micro-batches
            if has_multi_modal_inputs:
                self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu
                micro_batches = mini_batch.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
                print("gradacc", self.gradient_accumulation)
            elif self.config.use_dynamic_bsz:
                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
                print("gradacc is unclear because dynamic bsz is used")
            else:
                self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
                print("gradacc", self.gradient_accumulation)
            print(f"==> gradient step with {len(mini_batch)} entries, per-gpu mini_bsz={self.config.ppo_mini_batch_size}, micro_bsz={self.config.ppo_micro_batch_size_per_gpu}")

            self.actor_optimizer.zero_grad()
            num_micro = 0
            last_micro_bsz = 0
            for micro_batch in micro_batches:
                num_micro += 1
                # Support all hardwares
                if isinstance(micro_batch, DataProto):
                    micro_data = {**micro_batch.batch.to(torch.cuda.current_device()), **micro_batch.non_tensor_batch}
                else:
                    micro_data = micro_batch.to(torch.cuda.current_device())  # actor device is cpu when using offload

                responses = micro_data["responses"]
                response_length = responses.size(1)
                if multi_turn:
                    response_mask = micro_data["loss_mask"][:, -response_length:]
                else:
                    response_mask = micro_data["attention_mask"][:, -response_length:]

                # all return: (bsz, response_length)
                entropy, log_prob, _ = self._forward_micro_batch(micro_batch=micro_data, temperature=temperature, calculate_entropy=True)
                pg_loss = compute_sft_loss(
                    log_prob=log_prob,
                    response_mask=response_mask,
                    cliprange=clip_ratio,
                    cliprange_low=clip_ratio_low,
                    cliprange_high=clip_ratio_high,
                    loss_agg_mode=loss_agg_mode,
                )

                ######### segment statistics
                # Tokens selected by generate_perplexity_mask are reported under the
                # "high_perplexity" keys. Response tokens outside that mask are
                # reported under "low_perplexity"; padding is excluded from both.
                ppl_mask = generate_perplexity_mask(micro_data["segments"], log_prob, response_mask)
                high_ppl_tokens = (response_mask * ppl_mask).bool()
                low_ppl_tokens = (response_mask * (1 - ppl_mask)).bool()
                metrics.update(self._masked_entropy_stats(entropy, high_ppl_tokens, "segments/high_perplexity_entropy"))
                metrics.update(self._masked_entropy_stats(entropy, low_ppl_tokens, "segments/low_perplexity_entropy"))
                metrics["actor/policy_entropy"] = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode).max().item()
                metrics["actor/entropy_max"] = entropy.max().item()

                positive_entropy = entropy[entropy > 0].detach()
                if positive_entropy.numel() > 0:  # min() of an empty tensor raises
                    emin = positive_entropy.min()
                    emean = positive_entropy.mean()
                    sigma = (emean - emin) / 3.0
                    upper = sigma * 2.0 + emean
                    metrics["actor/entropy_min"] = emin.item()
                    print(f"===> pgloss: {pg_loss.item()}, entropy: emin: {emin.item()}, emean: {emean.item()}, emax:{entropy.max().detach().item()}, sigma: {sigma.item()}, clipped: {upper.item()}")

                # No entropy bonus on the SFT path. Assign unconditionally so that a
                # non-zero entropy_coeff cannot leave policy_loss undefined.
                policy_loss = pg_loss

                if self.config.use_kl_loss:
                    ref_log_prob = micro_data["ref_log_prob"]
                    # compute kl loss
                    kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
                    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)

                    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
                    metrics["actor/kl_loss"] = kl_loss.detach().item()
                    metrics["actor/kl_coef"] = self.config.kl_loss_coef

                micro_bsz = responses.size(0)
                if self.config.use_dynamic_bsz:
                    # Scale by this micro-batch's share of the mini-batch. Use the
                    # sample count; len() of a dict micro-batch would count its keys.
                    loss = policy_loss * (micro_bsz / self.config.ppo_mini_batch_size)
                else:
                    loss = policy_loss / self.gradient_accumulation
                loss.backward()

                append_to_dict(metrics, {"actor/pg_loss": pg_loss.detach().item()})
                last_micro_bsz = micro_bsz

            print(f"===> gradient step with {num_micro}x{last_micro_bsz} on each device")
            grad_norm = self._optimizer_step()
            # Record the grad norm of every optimizer step, not just the last one.
            append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()})

        self.actor_optimizer.zero_grad()
        return metrics

    @staticmethod
    def _masked_entropy_stats(entropy: torch.Tensor, mask: torch.Tensor, prefix: str) -> dict:
        """Return ``{prefix}_min``, ``{prefix}_max`` and ``{prefix}_mean`` of ``entropy[mask]``.

        Args:
            entropy: per-token entropy tensor.
            mask: boolean tensor used to index ``entropy``.
            prefix: metric-name prefix.

        Returns:
            A dict of Python floats. It is empty when ``mask`` selects nothing,
            because ``min``/``max`` of an empty tensor raise in PyTorch.
        """
        selected = entropy[mask]
        if selected.numel() == 0:
            return {}
        return {
            f"{prefix}_min": selected.min().item(),
            f"{prefix}_max": selected.max().item(),
            f"{prefix}_mean": selected.mean().item(),
        }
    
    def compute_entropy_for_every_source(self, data: DataProto) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute token-level and per-sample mean entropy for every sample in ``data``.

        The actor is run forward without gradients over ``data`` in micro-batches.
        Results are returned in the same row order as ``data``, even when dynamic
        batch sizing reorders samples internally.

        Args:
            data: batch providing the ``responses``, ``input_ids``,
                ``attention_mask`` and ``position_ids`` tensors, plus the
                ``meta_info`` keys ``train_mode``, ``micro_batch_size``,
                ``temperature`` and ``use_dynamic_bsz``. ``max_token_len`` is
                also required when ``use_dynamic_bsz`` is truthy.

        Returns:
            A tuple ``(entropy, mean_entropy)``:

            - ``entropy`` is the per-token entropy produced by
              ``_forward_micro_batch``. It is presumably
              ``(bsz, response_length)``; TODO confirm.
            - ``mean_entropy`` is the masked mean of ``entropy`` over each
              sample's response tokens (reduced over the last axis, presumably
              ``(bsz,)``).

        Side effects:
            Puts ``self.actor_module`` into train or eval mode according to
            ``meta_info['train_mode']`` and does not restore the previous mode.
            Also prints the selected mode.
        """
        meta = data.meta_info

        if meta["train_mode"]:
            self.actor_module.train()
            print("train mode")
        else:
            self.actor_module.eval()
            print("eval mode")

        micro_batch_size = meta["micro_batch_size"]
        # The key must be present, so a missing temperature raises KeyError
        # instead of failing silently.
        _ = meta["temperature"]
        use_dynamic_bsz = meta["use_dynamic_bsz"]
        # NOTE(review): the provided temperature is ignored and entropy is always
        # computed with temperature=1. Confirm this override is intentional.
        temperature = 1

        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
        batch = data.select(batch_keys=select_keys).batch

        if use_dynamic_bsz:
            # Token-budget micro-batching may reorder samples. `indices` records
            # the order so it can be undone below.
            max_token_len = meta["max_token_len"] * self.ulysses_sequence_parallel_size
            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
        else:
            micro_batches = batch.split(micro_batch_size)

        entropy_chunks = []
        mean_entropy_chunks = []
        for micro_batch in micro_batches:
            response_length = micro_batch["responses"].size(1)
            # The attention mask restricted to the trailing response positions.
            response_mask = micro_batch["attention_mask"][:, -response_length:]
            with torch.no_grad():
                entropy, _, _ = self._forward_micro_batch(micro_batch, temperature=temperature, calculate_entropy=True)
                # Reducing over the last axis yields one value per sample.
                mean_entropy = verl_F.masked_mean(entropy, response_mask, axis=-1)
            entropy_chunks.append(entropy)
            mean_entropy_chunks.append(mean_entropy)

        entropy_all = torch.concat(entropy_chunks, dim=0)
        mean_entropy_all = torch.concat(mean_entropy_chunks, dim=0)

        if use_dynamic_bsz:
            # Undo the micro-batch reordering so outputs line up with the rows of `data`.
            indices = list(itertools.chain.from_iterable(indices))
            assert len(indices) == entropy_all.size(0), f"{len(indices)} vs. {entropy_all.size()}"
            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
            entropy_all = entropy_all[revert_indices]
            mean_entropy_all = mean_entropy_all[revert_indices]

        return entropy_all, mean_entropy_all
