from typing import Optional, Tuple, Dict, Any

import math
import torch
import torch.distributed as dist
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F

from .utils import masked_mean, masked_sum

from ares.utils.utils import get_tensor_stats, flatten_dict, get_global_statistics
from ares.models.megatron_log_softmax import vocab_parallel_log_softmax_and_gather_targets
from ares.models.megatron_entropy import vocab_parallel_entropy_by_logits
from ares.models.auto_temperature import apply_temperature
from ares.models.megatron_mtp_loss import compute_mtp_loss
from ares.utils.ppo import compute_approx_kl
from megatron.core import parallel_state as mpu
from ares.utils.logger import logger


def average_metrics_across_data_parallel_group(metrics: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Average single-element metric tensors over the data-parallel group.

    All metrics are packed into one tensor so that a single all-reduce
    (default reduction: sum) is issued, then divided by the group size.

    Args:
        metrics: Mapping from metric name to a one-element tensor (each value
            must be viewable as shape ``(1,)``).

    Returns:
        Mapping with the same keys, in the same order, whose values are the
        group-averaged metrics as Python numbers. An empty input yields an
        empty dict without issuing any collective call.
    """
    # torch.cat rejects an empty list, so short-circuit instead of crashing.
    # NOTE(review): this skips the collective, so every rank is presumably
    # expected to pass the same set of keys -- confirm against the caller.
    if not metrics:
        return {}

    dp_group = mpu.get_data_parallel_group()
    packed = torch.cat([value.view(1) for value in metrics.values()])
    torch.distributed.all_reduce(packed, group=dp_group)
    packed = packed / torch.distributed.get_world_size(group=dp_group)
    # One host transfer for all metrics instead of one .item() per key.
    return dict(zip(metrics.keys(), packed.tolist()))


def agg_loss(
    loss_mat: torch.Tensor,
    loss_mask: torch.Tensor,
    loss_agg_mode: str,
    batch: Dict[str, Any],
    max_tokens: Optional[int] = None,
) -> torch.Tensor:
    """
    Reduce a per-token loss matrix to a single value.

    Args:
        loss_mat: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_mask: `(torch.Tensor)`
            shape: (bs, response_length); positions to include in the loss
        loss_agg_mode: (str) one of
            "dr_grpo": masked sum divided by `bs * max_tokens`
            "token-mean": masked sum divided by
                `total_tokens_per_group[0] / (num_responses[0] / bs)`
            "seq-mean-token-sum": per-sequence masked sum, averaged over sequences
            "seq-mean-token-mean": per-sequence masked mean, averaged over sequences
        batch: only read in "token-mean" mode, where it must provide
            "total_tokens_per_group" (all entries equal) and "num_responses"
        max_tokens: (`int`, *optional*) required for "dr_grpo" only;
            the fixed per-sequence token budget used as the normaliser
    Returns:
        loss: `a scalar torch.Tensor`
            aggregated loss
    Raises:
        ValueError: if `loss_agg_mode` is not one of the modes above
    """
    if loss_agg_mode == "dr_grpo":
        assert isinstance(max_tokens, int)
        normaliser = loss_mat.size(0) * max_tokens
        return (loss_mat * loss_mask).sum() / normaliser

    if loss_agg_mode == "token-mean":
        group_token_totals = batch.get("total_tokens_per_group", None)
        response_counts = batch.get("num_responses", None)
        micro_bs = loss_mat.shape[0]
        assert group_token_totals is not None and response_counts is not None
        assert (group_token_totals == group_token_totals[0]).all()

        # Scale the group-wide token count by the share of the group's
        # responses that this micro-batch holds.
        scale = group_token_totals[0].item() / (response_counts[0].item() / micro_bs)
        return masked_sum(loss_mat, loss_mask) / scale

    if loss_agg_mode == "seq-mean-token-sum":
        per_sequence = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
        return torch.mean(per_sequence)  # seq-mean

    if loss_agg_mode == "seq-mean-token-mean":
        return masked_mean(loss_mat, loss_mask, dim=-1).mean()

    raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")


class PolicyLoss(nn.Module):
    """Token-level clipped PPO policy loss with optional auxiliary terms.

    ``forward`` builds per-token probability ratios between the current and
    the behaviour policy, applies the clipped (and optionally dual-clipped)
    surrogate objective, and adds KL, NLL, entropy and MTP terms, each scaled
    by its coefficient. It returns the total loss and a dict of statistics.

    Args:
        cliprange: A scalar clip width used on both sides of 1.0, or a
            list/tuple ``(low, high)`` with separate lower and upper widths.
            A one-element sequence is treated like a scalar.
        clip_ratio_c: Dual-clip constant applied where the advantage is
            negative; ``None`` disables dual-clip. Must be > 1.0 when given.
        kl_loss_coef: Weight of the KL term. With 0.0 the KL value is still
            computed for logging, but without gradients.
        kl_loss_type: Estimator name forwarded to ``compute_approx_kl``.
        nll_loss_coef: Weight of the NLL term (only non-zero when the batch
            carries ``"rewards"``).
        entropy_loss_coef: Weight of the (negative) entropy term. With 0.0
            the entropy is computed without gradients.
        ignore_useless_data: When True and the batch has ``"ignore_mask"``,
            flagged samples are removed from the action mask.
        enable_temperature: When True, ``batch["temperature"]`` is passed to
            the log-prob and entropy helpers instead of 1.0.
        spike_prob_ratio_limit: When set, any sequence containing a ratio that
            is not below ``limit * running_max_ratio`` is masked out.
        loss_agg_mode: Aggregation mode forwarded to ``agg_loss``.
        max_tokens: Forwarded to ``agg_loss`` (used by its "dr_grpo" mode).
        mtp_loss_coef: Weight of the MTP term; also passed to
            ``compute_mtp_loss``.
        dump_token_level_log: When True, token-level values of the first
            sample are added to the stats under ``"log_stats"``.
    """

    def __init__(
        self,
        cliprange=0.2,
        clip_ratio_c=10.0,  # The original paper default is 3.0. None is for backwards compatible
        kl_loss_coef=0.0,
        kl_loss_type="k3",
        nll_loss_coef=0.0,
        entropy_loss_coef=0.0,
        ignore_useless_data=False,
        enable_temperature=False,
        spike_prob_ratio_limit=None,
        loss_agg_mode="seq-mean-token-mean",
        max_tokens=None,
        mtp_loss_coef=0.0,
        dump_token_level_log=False,
    ):
        super().__init__()
        if clip_ratio_c is not None:
            assert (
                clip_ratio_c > 1.0
            ), f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}."
        self.clip_ratio_c = clip_ratio_c  # dual-clip
        self.cliprange, self.cliphighrange = self._split_cliprange(cliprange)
        self.kl_loss_coef = kl_loss_coef
        self.kl_loss_type = kl_loss_type
        self.nll_loss_coef = nll_loss_coef
        self.entropy_loss_coef = entropy_loss_coef
        self.ignore_useless_data = ignore_useless_data
        self.enable_temperature = enable_temperature
        self.spike_ratio_limit = spike_prob_ratio_limit
        # Running maximum of the unmasked ratios seen so far; starts at 1.0.
        self.history_ratio_max = torch.tensor(1.0, dtype=torch.float32)
        self.loss_agg_mode = loss_agg_mode
        self.max_tokens = max_tokens
        self.mtp_loss_coef = mtp_loss_coef
        self.dump_token_level_log = dump_token_level_log

    @staticmethod
    def _split_cliprange(cliprange):
        """Return ``(low, high)`` clip widths from a scalar or a list/tuple."""
        if not isinstance(cliprange, (list, tuple)):
            return cliprange, cliprange
        if len(cliprange) == 0:
            raise ValueError("cliprange must be a number or a non-empty (low, high) sequence")
        low = cliprange[0]
        # A one-element sequence would otherwise be stored as a list and
        # break the arithmetic in forward(); treat it as a symmetric range.
        high = cliprange[1] if len(cliprange) > 1 else low
        return low, high

    @staticmethod
    def _train_staleness(batch, reference):
        """Mean of ``train_version - data_version``, or zeros(1) if unavailable."""
        try:
            data_version = batch["data_version"]
            train_version = batch["train_version"]
            # mean() is not defined for integer tensors; cast first so that
            # integer version counters are not silently reported as 0.
            return (train_version - data_version).float().mean()
        except Exception:
            # Best effort: the version fields are optional logging inputs.
            return torch.zeros(1, dtype=reference.dtype, device=reference.device)

    @staticmethod
    def _entropy_distribution_stats(flat_entropys):
        """Describe the distribution of a 1-D tensor of entropies.

        Returns two dicts: the fraction of values falling in fixed entropy
        ranges, and the mean entropy inside bands delimited by the
        90/80/60/40/20-th percentiles. Both are empty for empty input.
        """
        entropy_ratios = {}
        means = {}
        if flat_entropys.numel() == 0:
            return entropy_ratios, means

        percentiles = [90, 80, 60, 40, 20]  # corresponding to 10%, 20%, 40%, 60%, 80%
        thresholds = torch.quantile(
            flat_entropys,
            torch.tensor([p / 100 for p in percentiles], device=flat_entropys.device),
        )
        entropy_ranges = [(0, 0.5), (0.5, 0.8), (0.8, 1.2), (1.2, 1.5), (1.5, 2.0), (2.0, 3.0), (3.0, float('inf'))]
        total_elements = len(flat_entropys)
        for lower, upper in entropy_ranges:
            if upper == float('inf'):
                in_range = flat_entropys >= lower
                key = f"entropy_{lower}+_ratio"
            else:
                in_range = (flat_entropys >= lower) & (flat_entropys < upper)
                key = f"entropy_{lower}-{upper}_ratio"
            entropy_ratios[key] = torch.sum(in_range) / total_elements

        # Mean of the top 10%, then of each band between adjacent percentiles.
        means["mean_entropy_top_10%"] = torch.mean(flat_entropys[flat_entropys >= thresholds[0]])
        for band, (upper_pct, lower_pct) in enumerate(zip(percentiles, percentiles[1:])):
            in_band = (flat_entropys >= thresholds[band + 1]) & (flat_entropys < thresholds[band])
            means[f"mean_entropy_top_{100-lower_pct}%_to_{100-upper_pct}%"] = torch.mean(flat_entropys[in_band])
        return entropy_ratios, means

    def forward(
        self, batch: Dict[str, torch.Tensor], output: torch.Tensor, non_loss_data=True
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Compute the total loss and logging statistics for one micro-batch.

        Args:
            batch: Must contain "input_ids", "action_log_probs", "advantages"
                and "action_mask". Optional keys read here:
                "action_ref_log_probs", "ignore_mask", "temperature",
                "rewards", "data_version", "train_version". The whole dict is
                also passed to ``agg_loss`` and ``compute_mtp_loss``.
            output: Model logits, commented in this file as (B, S, V/TP).
            non_loss_data: Not used by this method.

        Returns:
            ``(total_loss, stats)`` where ``total_loss`` is the weighted sum
            of all loss terms and ``stats`` maps names to Python numbers
            (plus ``"log_stats"``, a list, when token-level dumping is on).
        """
        # output: (B, S, V/TP)
        output = output.float()  # already done by megatron fp16module
        temperature = 1.0
        if self.enable_temperature:
            temperature = batch["temperature"]
        input_ids = batch["input_ids"]  # (B, S)
        old_action_logprobs = batch["action_log_probs"]  # (B, A)
        action_ref_logprobs = batch.get('action_ref_log_probs', None)  # (B, A)
        advantages = batch["advantages"]  # (B, A) or (B)
        if advantages.dim() == 1:
            advantages = advantages.view(-1, 1)
        action_mask = batch["action_mask"]  # (B, A)
        ignore_mask = batch.get("ignore_mask", None)  # one bool per sample (unsqueezed to 2-D below)
        staleness = self._train_staleness(batch, advantages)

        if ignore_mask is not None and self.ignore_useless_data:
            # we don't want to compute the loss for some data. e.g. the response that is truncated.
            unsqueezed_ignore_mask = ignore_mask.unsqueeze(1)  # size: (mbs, 1)
            assert unsqueezed_ignore_mask.dim() == 2 and unsqueezed_ignore_mask.dtype == torch.bool
            action_mask = action_mask * (~unsqueezed_ignore_mask)

        # NOTE(review): the original author flagged that an action_mask with
        # zero columns (num_actions == 0) is not handled correctly here.
        num_actions = action_mask.size(1)

        # entropy (tracked by autograd only when it contributes to the loss
        # and the caller has not already disabled gradients)
        entropy_needs_grad = torch.is_grad_enabled() and self.entropy_loss_coef != 0.0
        with torch.set_grad_enabled(entropy_needs_grad):
            full_entropys = vocab_parallel_entropy_by_logits(output, temperature, False)
        action_entropys = full_entropys[:, -num_actions - 1 : -1]  # (B, A)

        # action logprobs: label at position t is the token at position t + 1
        # shift_labels = torch.roll(input_ids, -1, 1)  # The npu torch.roll interface needs to be verified
        shift_labels = torch.zeros_like(input_ids)
        shift_labels[:, :-1] = input_ids[:, 1:]
        logprobs = vocab_parallel_log_softmax_and_gather_targets(output, temperature, shift_labels, inplace=True)
        action_logprobs = logprobs[:, -num_actions - 1 : -1]  # (B, A)

        # MTP loss
        # NOTE(review): `output` was just passed with inplace=True above;
        # confirm compute_mtp_loss does not rely on the untouched logits.
        mtp_loss = compute_mtp_loss(batch, output, self.mtp_loss_coef, pad_token_id=0)

        # PPO's pessimistic surrogate
        # Masked positions get log-ratio 0, i.e. ratio 1.
        log_ratio = (action_logprobs - old_action_logprobs) * action_mask
        ratio = log_ratio.exp()

        # ignore the spike ratio, maybe there is a bug...
        self.history_ratio_max = self.history_ratio_max.to(ratio.device)
        if self.spike_ratio_limit is not None:
            spike_ratio_value = self.spike_ratio_limit * self.history_ratio_max
            # A sequence is kept only if every one of its ratios is below the limit.
            not_spike_mask = (ratio < spike_ratio_value).all(dim=-1).unsqueeze(dim=-1)
            action_mask = action_mask * not_spike_mask.to(action_mask.dtype)

        # clip ratio statistics
        select_ratio = torch.masked_select(ratio, action_mask.bool())
        if select_ratio.numel() > 0:
            self.history_ratio_max = torch.max(self.history_ratio_max, select_ratio.max())
        clipped_high = ((ratio > 1.0 + self.cliphighrange) & (advantages > 0)) * action_mask
        clipped_low = ((ratio < 1.0 - self.cliprange) & (advantages < 0)) * action_mask
        num_valid = action_mask.sum()
        ratio_clip_upper_frac = clipped_high.sum() / num_valid
        ratio_clip_lower_frac = clipped_low.sum() / num_valid
        # Fractions relative to positive / negative advantage tokens only;
        # 0/0 (no such tokens) is reported as 0 instead of NaN.
        ratio_clip_positive_upper_frac = clipped_high.sum() / ((advantages > 0) * action_mask).sum()
        if torch.isnan(ratio_clip_positive_upper_frac).item():
            ratio_clip_positive_upper_frac = torch.zeros_like(ratio_clip_upper_frac)
        ratio_clip_negative_lower_frac = clipped_low.sum() / ((advantages < 0) * action_mask).sum()
        if torch.isnan(ratio_clip_negative_lower_frac).item():
            ratio_clip_negative_lower_frac = torch.zeros_like(ratio_clip_upper_frac)

        # policy loss
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1.0 - self.cliprange, 1.0 + self.cliphighrange) * advantages
        clip_pg_losses1 = -torch.min(surr1, surr2)
        if self.clip_ratio_c is not None:
            # dual-clip: bound the loss for negative advantages
            pg_losses3 = -advantages * self.clip_ratio_c
            clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
            policy_loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
            pg_clipfrac_lower = masked_mean(
                torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), action_mask
            )
        else:
            policy_loss = clip_pg_losses1
            pg_clipfrac_lower = torch.tensor(0.0, dtype=torch.float32)
        policy_loss = agg_loss(policy_loss, action_mask, self.loss_agg_mode, batch, self.max_tokens)
        pg_clipfrac = masked_mean((surr2 < surr1).float(), action_mask)
        ppl = torch.exp(-masked_mean(action_logprobs, action_mask, dim=-1))
        old_ppl = torch.exp(-masked_mean(old_action_logprobs, action_mask, dim=-1))

        # entropy loss
        with torch.set_grad_enabled(entropy_needs_grad):
            entropy_loss = -agg_loss(action_entropys, action_mask, self.loss_agg_mode, batch, self.max_tokens)

        with torch.no_grad():
            entropy = masked_mean(action_entropys, action_mask)

        # Logging-only entropy histogram: detach so no graph is kept, and
        # cast the mask because masked_select requires a boolean mask.
        flat_masked_action_entropys = torch.masked_select(action_entropys.detach(), action_mask.bool())
        entropy_ratios, means = self._entropy_distribution_stats(flat_masked_action_entropys)

        # kl loss: beta * DKL(pi_theta || pi_ref); gradients only when weighted
        if action_ref_logprobs is None:
            kl_loss = torch.zeros(1, dtype=logprobs.dtype, device=logprobs.device)
        else:
            kl_needs_grad = torch.is_grad_enabled() and self.kl_loss_coef != 0.0
            with torch.set_grad_enabled(kl_needs_grad):
                kl_loss = compute_approx_kl(
                    action_logprobs,
                    action_ref_logprobs,
                    action_mask=action_mask,
                    kl_estimator=self.kl_loss_type,
                )
                kl_loss = agg_loss(kl_loss, action_mask, self.loss_agg_mode, batch, self.max_tokens)

        # nll_loss
        rewards = batch.get("rewards", None)
        if rewards is not None:
            nll_loss_mask = (rewards > 0) & action_mask
            rewards_mean = masked_mean(rewards, action_mask, dim=-1)
            nll_loss = -masked_mean(action_logprobs, nll_loss_mask, dim=-1) * rewards_mean
            # Average only over sequences whose mean reward is positive.
            nll_loss_mask = rewards_mean > 0
            nll_loss = masked_mean(nll_loss, nll_loss_mask, dim=-1)
        else:
            nll_loss = torch.zeros(1, dtype=logprobs.dtype, device=logprobs.device)

        total_loss = (
            policy_loss
            + self.kl_loss_coef * kl_loss
            + self.nll_loss_coef * nll_loss
            + self.entropy_loss_coef * entropy_loss
            + self.mtp_loss_coef * mtp_loss
        )
        stats = {
            "total_loss": total_loss,
            "policy_loss": policy_loss,
            "entropy_loss": entropy_loss,
            "kl_loss": kl_loss,
            "nll_loss": nll_loss,
            "mtp_loss": mtp_loss,
            "pg_clipfrac": pg_clipfrac,
            "pg_clipfrac_lower": pg_clipfrac_lower,
            "ratio": masked_mean(ratio, action_mask, dim=-1).mean(),
            "ratio_clip_upper_frac": ratio_clip_upper_frac,
            "ratio_clip_lower_frac": ratio_clip_lower_frac,
            "ratio_clip_positive_upper_frac": ratio_clip_positive_upper_frac,
            "ratio_clip_negative_lower_frac": ratio_clip_negative_lower_frac,
            "ppl": ppl.mean(),
            "ppl_old": old_ppl.mean(),
            "train_staleness": staleness,
        }
        logger.info(f'loss stats: {stats}')
        if flat_masked_action_entropys.numel() > 0:
            stats.update({"entropy": entropy})

        if self.dump_token_level_log:
            # token level log return to megatron engine then return to trainer
            with torch.no_grad():
                actual_clip_high = ((ratio > 1.0 + self.cliphighrange) & (advantages > 0)) * action_mask
                actual_clip_low = ((ratio < 1.0 - self.cliprange) & (advantages < 0)) * action_mask
                # Upper clipping as it would be with the lower width.
                clip_high = ((ratio > 1.0 + self.cliprange) & (advantages > 0)) * action_mask

            # Only the first sample of the micro-batch is dumped.
            stats["log_stats"] = [
                tensor[0].detach().cpu().tolist()
                for tensor in (
                    input_ids,
                    advantages,
                    action_entropys,
                    action_mask,
                    action_logprobs,
                    old_action_logprobs,
                    actual_clip_high,
                    clip_high,
                    actual_clip_low,
                )
            ]

        stats.update(entropy_ratios)
        stats.update(means)
        if stats["ratio"].item() > 1e2:
            torch.set_printoptions(edgeitems=10000, linewidth=100000)
            ratio_value = stats["ratio"].item()
            logger.info(f"abnormal ratio!!! current ratio value: {ratio_value}")
            logger.info(f"action mask: {action_mask}")
            logger.info(f"action_logprobs: {action_logprobs}")
            logger.info(f"old_action_logprobs: {old_action_logprobs}")
            logger.info(f"log_ratio: {log_ratio}")
            logger.info(f"ratio: {ratio}")

        # The allreduce between dp will be done in trainer
        # stats = average_metrics_across_data_parallel_group(stats)
        stats = {k: (v if k.endswith("log_stats") else v.item()) for k, v in stats.items()}

        return total_loss, stats


class GSPOPolicyLoss(nn.Module):
    """Sequence-level clipped policy loss (GSPO) with optional auxiliary terms.

    Unlike a token-level PPO loss, ``forward`` uses one importance ratio per
    sequence: the exponential of the masked mean of the per-token
    log-probability differences. Advantages are likewise reduced to one value
    per sequence with a masked mean. KL, NLL, entropy and MTP terms are added
    with their coefficients.

    Args:
        cliprange: A scalar clip width used on both sides of 1.0, or a
            list/tuple ``(low, high)`` with separate lower and upper widths.
            A one-element sequence is treated like a scalar.
        clip_ratio_c: Dual-clip constant applied where the advantage is
            negative; ``None`` disables dual-clip. Must be > 1.0 when given.
        kl_loss_coef: Weight of the KL term. With 0.0 the KL value is still
            computed for logging, but without gradients.
        kl_loss_type: Estimator name forwarded to ``compute_approx_kl``.
        nll_loss_coef: Weight of the NLL term (only non-zero when the batch
            carries ``"rewards"``).
        entropy_loss_coef: Weight of the (negative) entropy term. With 0.0
            the entropy is computed without gradients.
        ignore_useless_data: When True and the batch has ``"ignore_mask"``,
            flagged samples are removed from the action mask.
        enable_temperature: When True, ``batch["temperature"]`` is passed to
            the log-prob and entropy helpers instead of 1.0.
        spike_prob_ratio_limit: When set, sequences whose ratio is not below
            ``limit * history_ratio_max`` are removed from the action mask.
        loss_agg_mode: Aggregation mode forwarded to ``agg_loss`` for the
            entropy and KL terms.
        max_tokens: Forwarded to ``agg_loss`` (used by its "dr_grpo" mode).
        mtp_loss_coef: Weight of the MTP term; also passed to
            ``compute_mtp_loss``.
        dump_token_level_log: When True, values of the first sample are added
            to the stats under ``"log_stats"``.
        use_gspo: Stored on the instance; not read anywhere in this class.
    """

    def __init__(
        self,
        cliprange=0.2,
        clip_ratio_c=10.0,  # The original paper default is 3.0. None is for backwards compatible
        kl_loss_coef=0.0,
        kl_loss_type="k3",
        nll_loss_coef=0.0,
        entropy_loss_coef=0.0,
        ignore_useless_data=False,
        enable_temperature=False,
        spike_prob_ratio_limit=None,
        loss_agg_mode="seq-mean-token-mean",
        max_tokens=None,
        mtp_loss_coef=0.0,
        dump_token_level_log=False,
        use_gspo=True,
    ):
        super().__init__()
        if clip_ratio_c is not None:
            assert (
                clip_ratio_c > 1.0
            ), f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}."
        self.clip_ratio_c = clip_ratio_c  # dual-clip
        self.cliprange, self.cliphighrange = self._split_cliprange(cliprange)
        self.kl_loss_coef = kl_loss_coef
        self.kl_loss_type = kl_loss_type
        self.nll_loss_coef = nll_loss_coef
        self.entropy_loss_coef = entropy_loss_coef
        self.ignore_useless_data = ignore_useless_data
        self.enable_temperature = enable_temperature
        self.spike_ratio_limit = spike_prob_ratio_limit
        # NOTE(review): this class never updates history_ratio_max, so the
        # spike limit is always relative to 1.0 -- confirm this is intended.
        self.history_ratio_max = torch.tensor(1.0, dtype=torch.float32)
        self.loss_agg_mode = loss_agg_mode
        self.max_tokens = max_tokens
        self.mtp_loss_coef = mtp_loss_coef
        self.dump_token_level_log = dump_token_level_log
        self.use_gspo = use_gspo

    @staticmethod
    def _split_cliprange(cliprange):
        """Return ``(low, high)`` clip widths from a scalar or a list/tuple."""
        if not isinstance(cliprange, (list, tuple)):
            return cliprange, cliprange
        if len(cliprange) == 0:
            raise ValueError("cliprange must be a number or a non-empty (low, high) sequence")
        low = cliprange[0]
        # A one-element sequence would otherwise be stored as a list and
        # break the arithmetic in forward(); treat it as a symmetric range.
        high = cliprange[1] if len(cliprange) > 1 else low
        return low, high

    @staticmethod
    def _train_staleness(batch, reference):
        """Mean of ``train_version - data_version``, or zeros(1) if unavailable."""
        try:
            data_version = batch["data_version"]
            train_version = batch["train_version"]
            # mean() is not defined for integer tensors; cast first so that
            # integer version counters are not silently reported as 0.
            return (train_version - data_version).float().mean()
        except Exception:
            # Best effort: the version fields are optional logging inputs.
            return torch.zeros(1, dtype=reference.dtype, device=reference.device)

    @staticmethod
    def _entropy_distribution_stats(flat_entropys):
        """Describe the distribution of a 1-D tensor of entropies.

        Returns two dicts: the fraction of values falling in fixed entropy
        ranges, and the mean entropy inside bands delimited by the
        90/80/60/40/20-th percentiles. Both are empty for empty input.
        """
        entropy_ratios = {}
        means = {}
        if flat_entropys.numel() == 0:
            return entropy_ratios, means

        percentiles = [90, 80, 60, 40, 20]  # corresponding to 10%, 20%, 40%, 60%, 80%
        thresholds = torch.quantile(
            flat_entropys,
            torch.tensor([p / 100 for p in percentiles], device=flat_entropys.device),
        )
        entropy_ranges = [(0, 0.5), (0.5, 0.8), (0.8, 1.2), (1.2, 1.5), (1.5, 2.0), (2.0, 3.0), (3.0, float('inf'))]
        total_elements = len(flat_entropys)
        for lower, upper in entropy_ranges:
            if upper == float('inf'):
                in_range = flat_entropys >= lower
                key = f"entropy_{lower}+_ratio"
            else:
                in_range = (flat_entropys >= lower) & (flat_entropys < upper)
                key = f"entropy_{lower}-{upper}_ratio"
            entropy_ratios[key] = torch.sum(in_range) / total_elements

        # Mean of the top 10%, then of each band between adjacent percentiles.
        means["mean_entropy_top_10%"] = torch.mean(flat_entropys[flat_entropys >= thresholds[0]])
        for band, (upper_pct, lower_pct) in enumerate(zip(percentiles, percentiles[1:])):
            in_band = (flat_entropys >= thresholds[band + 1]) & (flat_entropys < thresholds[band])
            means[f"mean_entropy_top_{100-lower_pct}%_to_{100-upper_pct}%"] = torch.mean(flat_entropys[in_band])
        return entropy_ratios, means

    def forward(
        self, batch: Dict[str, torch.Tensor], output: torch.Tensor, non_loss_data=True
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Compute the total loss and logging statistics for one micro-batch.

        Args:
            batch: Must contain "input_ids", "action_log_probs", "advantages"
                and "action_mask". Optional keys read here:
                "action_ref_log_probs", "ignore_mask", "temperature",
                "rewards", "data_version", "train_version". The whole dict is
                also passed to ``agg_loss`` and ``compute_mtp_loss``.
            output: Model logits, commented in this file as (B, S, V/TP).
            non_loss_data: Not used by this method.

        Returns:
            ``(total_loss, stats)``. The policy term of ``total_loss`` is the
            mean of the per-sequence losses. ``stats`` maps names to Python
            numbers (plus ``"log_stats"``, a list, when dumping is on);
            per-sequence quantities are reported as batch means.
        """
        # output: (B, S, V/TP)
        output = output.float()  # already done by megatron fp16module
        temperature = 1.0
        if self.enable_temperature:
            temperature = batch["temperature"]
        input_ids = batch["input_ids"]  # (B, S)
        old_action_logprobs = batch["action_log_probs"]  # (B, A)
        action_ref_logprobs = batch.get('action_ref_log_probs', None)  # (B, A)
        advantages = batch["advantages"]  # (B, A) or (B)
        if advantages.dim() == 1:
            advantages = advantages.view(-1, 1)
        action_mask = batch["action_mask"]  # (B, A)
        ignore_mask = batch.get("ignore_mask", None)  # one bool per sample (unsqueezed to 2-D below)
        staleness = self._train_staleness(batch, advantages)

        if ignore_mask is not None and self.ignore_useless_data:
            # we don't want to compute the loss for some data. e.g. the response that is truncated.
            unsqueezed_ignore_mask = ignore_mask.unsqueeze(1)  # size: (mbs, 1)
            assert unsqueezed_ignore_mask.dim() == 2 and unsqueezed_ignore_mask.dtype == torch.bool
            action_mask = action_mask * (~unsqueezed_ignore_mask)

        # NOTE(review): the original author flagged that an action_mask with
        # zero columns (num_actions == 0) is not handled correctly here.
        num_actions = action_mask.size(1)

        # entropy (tracked by autograd only when it contributes to the loss
        # and the caller has not already disabled gradients)
        entropy_needs_grad = torch.is_grad_enabled() and self.entropy_loss_coef != 0.0
        with torch.set_grad_enabled(entropy_needs_grad):
            full_entropys = vocab_parallel_entropy_by_logits(output, temperature, False)
        action_entropys = full_entropys[:, -num_actions - 1 : -1]  # (B, A)

        # action logprobs: label at position t is the token at position t + 1
        # shift_labels = torch.roll(input_ids, -1, 1)  # The npu torch.roll interface needs to be verified
        shift_labels = torch.zeros_like(input_ids)
        shift_labels[:, :-1] = input_ids[:, 1:]
        logprobs = vocab_parallel_log_softmax_and_gather_targets(output, temperature, shift_labels, inplace=True)
        action_logprobs = logprobs[:, -num_actions - 1 : -1]  # (B, A)

        # MTP loss
        # NOTE(review): `output` was just passed with inplace=True above;
        # confirm compute_mtp_loss does not rely on the untouched logits.
        mtp_loss = compute_mtp_loss(batch, output, self.mtp_loss_coef, pad_token_id=0)

        # Sequence-level ratio: exp of the masked mean per-token log-ratio.
        # Presumably masked_mean(..., dim=-1) yields one value per sequence;
        # the reshape(-1, 1) calls below do not depend on whether it keeps
        # the reduced dimension.
        mask_mean_log_ratio = masked_mean(action_logprobs - old_action_logprobs, action_mask, dim=-1)
        ratio = mask_mean_log_ratio.exp()

        advantages = masked_mean(advantages, action_mask, dim=-1)

        # ignore the spike ratio, maybe there is a bug...
        # NOTE(review): the mask below affects the entropy/KL/ppl terms only;
        # the policy loss is built from `ratio` and is not masked by it.
        self.history_ratio_max = self.history_ratio_max.to(ratio.device)
        if self.spike_ratio_limit is not None:
            spike_ratio_value = self.spike_ratio_limit * self.history_ratio_max
            # One flag per sequence, shaped (B, 1) so that it masks only the
            # offending rows; reducing over the batch axis would drop every
            # sequence as soon as a single one spikes.
            not_spike_mask = (ratio < spike_ratio_value).reshape(-1, 1)
            action_mask = action_mask * not_spike_mask.to(action_mask.dtype)

        # clip ratio (one 0/1 flag per sequence)
        seq_ratio_clip_upper_frac = ((advantages > 0) & (ratio > 1.0 + self.cliphighrange)).float()
        seq_ratio_clip_lower_frac = ((advantages < 0) & (ratio < 1.0 - self.cliprange)).float()

        # policy loss
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1.0 - self.cliprange, 1.0 + self.cliphighrange) * advantages
        clip_pg_losses1 = -torch.min(surr1, surr2)

        if self.clip_ratio_c is not None:
            # dual-clip: bound the loss for negative advantages
            pg_losses3 = -advantages * self.clip_ratio_c
            clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
            policy_loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
        else:
            policy_loss = clip_pg_losses1
        # Average the per-sequence losses. With one sequence this leaves the
        # value unchanged; with several it makes the loss (and the .item()
        # calls on the stats below) well defined instead of raising.
        policy_loss = policy_loss.mean()
        # from https://stackoverflow.com/questions/61988776/how-to-calculate-perplexity-for-a-language-model-using-pytorch
        ppl = torch.exp(-masked_mean(action_logprobs, action_mask, dim=-1))
        old_ppl = torch.exp(-masked_mean(old_action_logprobs, action_mask, dim=-1))

        # entropy loss
        with torch.set_grad_enabled(entropy_needs_grad):
            entropy_loss = -agg_loss(action_entropys, action_mask, self.loss_agg_mode, batch, self.max_tokens)

        with torch.no_grad():
            entropy = masked_mean(action_entropys, action_mask)

        # Logging-only entropy histogram: detach so no graph is kept, and
        # cast the mask because masked_select requires a boolean mask.
        flat_masked_action_entropys = torch.masked_select(action_entropys.detach(), action_mask.bool())
        entropy_ratios, means = self._entropy_distribution_stats(flat_masked_action_entropys)

        # kl loss: beta * DKL(pi_theta || pi_ref); gradients only when weighted
        if action_ref_logprobs is None:
            kl_loss = torch.zeros(1, dtype=logprobs.dtype, device=logprobs.device)
        else:
            kl_needs_grad = torch.is_grad_enabled() and self.kl_loss_coef != 0.0
            with torch.set_grad_enabled(kl_needs_grad):
                kl_loss = compute_approx_kl(
                    action_logprobs,
                    action_ref_logprobs,
                    action_mask=action_mask,
                    kl_estimator=self.kl_loss_type,
                )
                kl_loss = agg_loss(kl_loss, action_mask, self.loss_agg_mode, batch, self.max_tokens)

        # nll_loss
        rewards = batch.get("rewards", None)
        if rewards is not None:
            nll_loss_mask = (rewards > 0) & action_mask
            rewards_mean = masked_mean(rewards, action_mask, dim=-1)
            nll_loss = -masked_mean(action_logprobs, nll_loss_mask, dim=-1) * rewards_mean
            # Average only over sequences whose mean reward is positive.
            nll_loss_mask = rewards_mean > 0
            nll_loss = masked_mean(nll_loss, nll_loss_mask, dim=-1)
        else:
            nll_loss = torch.zeros(1, dtype=logprobs.dtype, device=logprobs.device)

        total_loss = (
            policy_loss
            + self.kl_loss_coef * kl_loss
            + self.nll_loss_coef * nll_loss
            + self.entropy_loss_coef * entropy_loss
            + self.mtp_loss_coef * mtp_loss
        )
        stats = {
            "total_loss": total_loss,
            "policy_loss": policy_loss,
            "entropy_loss": entropy_loss,
            "kl_loss": kl_loss,
            "nll_loss": nll_loss,
            "mtp_loss": mtp_loss,
            # Per-sequence values are averaged so each stat is one number.
            "ratio": ratio.mean(),
            "seq_ratio_clip_upper_frac": seq_ratio_clip_upper_frac.mean(),
            "seq_ratio_clip_lower_frac": seq_ratio_clip_lower_frac.mean(),
            "ppl": ppl.mean(),
            "ppl_old": old_ppl.mean(),
            "train_staleness": staleness,
        }
        logger.info(f'loss stats: {stats}')
        if flat_masked_action_entropys.numel() > 0:
            stats.update({"entropy": entropy})

        if self.dump_token_level_log:
            # token level log return to megatron engine then return to trainer
            with torch.no_grad():
                # Shape the per-sequence values as (B, 1) so they broadcast
                # along the token axis of the (B, A) action mask.
                seq_ratio = ratio.reshape(-1, 1)
                seq_advantages = advantages.reshape(-1, 1)
                actual_clip_high = ((seq_ratio > 1.0 + self.cliphighrange) & (seq_advantages > 0)) * action_mask
                actual_clip_low = ((seq_ratio < 1.0 - self.cliprange) & (seq_advantages < 0)) * action_mask
                # Upper clipping as it would be with the lower width.
                clip_high = ((seq_ratio > 1.0 + self.cliprange) & (seq_advantages > 0)) * action_mask

            # Only the first sample of the micro-batch is dumped.
            stats["log_stats"] = [
                tensor[0].detach().cpu().tolist()
                for tensor in (
                    input_ids,
                    advantages,
                    action_entropys,
                    action_mask,
                    action_logprobs,
                    old_action_logprobs,
                    actual_clip_high,
                    clip_high,
                    actual_clip_low,
                )
            ]

        stats.update(entropy_ratios)
        stats.update(means)
        if stats["ratio"].item() > 1e2:
            torch.set_printoptions(edgeitems=10000, linewidth=100000)
            ratio_value = stats["ratio"].item()
            logger.info(f"abnormal ratio!!! current ratio value: {ratio_value}")
            logger.info(f"action mask: {action_mask}")
            logger.info(f"action_logprobs: {action_logprobs}")
            logger.info(f"old_action_logprobs: {old_action_logprobs}")
            logger.info(f"log_ratio: {mask_mean_log_ratio}")
            logger.info(f"ratio: {ratio}")

        # The allreduce between dp will be done in trainer
        # stats = average_metrics_across_data_parallel_group(stats)
        stats = {k: (v if k.endswith("log_stats") else v.item()) for k, v in stats.items()}

        return total_loss, stats


def get_last_token_rm_score(batch: Dict[str, Any], output: torch.Tensor, non_loss_data: bool = True):
    """Return the slice of `output` at the last index of dim 1, cast to float32.

    `batch` and `non_loss_data` are accepted but not used.
    """
    last_position = output[:, -1]
    return last_position.to(torch.float32)


def gather_values_dist(batch: Dict[str, Any], output: torch.Tensor, non_loss_data: bool = True):
    """Return `output` without its last entry along dim 1, cast to float32.

    `batch` and `non_loss_data` are accepted but not used.
    """
    without_last = output[:, :-1]
    return without_last.to(torch.float32)


def reward_to_binary(reward):
    """Affinely map `reward` via (reward + 1) / 2, so -1 -> 0 and 1 -> 1."""
    shifted = reward + 1
    return shifted / 2