# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file should be used by trainers with different distributed strategies to
implement PPO.
"""

from collections import defaultdict
import copy
import numpy as np
import torch
import random


import verl.utils.torch_functional as verl_F


class AdaptiveKLController:
    """
    KL coefficient controller that rescales itself multiplicatively, following
    https://arxiv.org/pdf/1909.08593.pdf

    ``value`` holds the current coefficient, ``target`` the KL value to steer
    towards, and ``horizon`` the divisor applied to ``n_steps`` on every update.
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.value, self.target, self.horizon = init_kl_coef, target_kl, horizon

    def update(self, current_kl, n_steps):
        """Rescale ``self.value`` in place from the observed ``current_kl``.

        The relative error ``current_kl / target - 1`` is clipped to
        [-0.2, 0.2] and weighted by ``n_steps / horizon`` before being applied.
        """
        clipped_error = np.clip(current_kl / self.target - 1, -0.2, 0.2)
        self.value *= 1 + clipped_error * n_steps / self.horizon


class FixedKLController:
    """KL controller whose coefficient stays at the value given at construction."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        """Accept the update arguments and leave ``self.value`` untouched."""
        return None


def get_kl_controller(kl_ctrl):
    """Build the KL controller selected by ``kl_ctrl.type``.

    Args:
        kl_ctrl: config object exposing ``type`` and ``kl_coef``; for the
            ``"adaptive"`` type, ``target_kl`` and ``horizon`` are read as well.

    Returns:
        A ``FixedKLController`` for ``"fixed"`` or an ``AdaptiveKLController``
        for ``"adaptive"``.

    Raises:
        AssertionError: if ``type`` is ``"adaptive"`` and ``horizon`` is not positive.
        NotImplementedError: for any other ``type``; the message names the rejected value.
    """
    if kl_ctrl.type == "fixed":
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)
    if kl_ctrl.type == "adaptive":
        assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
    # Name the offending value so a config typo is easy to spot.
    raise NotImplementedError(f"Unsupported KL controller type: {kl_ctrl.type!r}. Expected 'fixed' or 'adaptive'.")


def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
    whiten: bool = True,
):
    """Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438).

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        values: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length). [EOS] mask; tokens after [EOS] have mask zero.
            Only consulted when ``whiten`` is True.
        gamma: discount factor used in RL
        lam: lambda value of the GAE recursion
        whiten: `(bool)`
            when True the advantages are passed through ``verl_F.masked_whiten``
            with ``response_mask``; the returns are computed before that step.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)

    """
    with torch.no_grad():
        num_steps = token_level_rewards.shape[-1]
        per_step_gae = [None] * num_steps
        running_gae = 0

        # Walk backwards in time; the value after the final position is taken as 0.
        for step in range(num_steps - 1, -1, -1):
            if step < num_steps - 1:
                value_next = values[:, step + 1]
            else:
                value_next = 0.0
            td_error = token_level_rewards[:, step] + gamma * value_next - values[:, step]
            running_gae = td_error + gamma * lam * running_gae
            per_step_gae[step] = running_gae

        advantages = torch.stack(per_step_gae, dim=1)
        returns = advantages + values
        if whiten:
            advantages = verl_F.masked_whiten(advantages, response_mask)
    return advantages, returns


# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
):
    """
    Compute advantage for GRPO, operating only on Outcome reward
    (with only one scalar reward for each response).
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            one group id per sample; samples whose ``index[i]`` compare equal
            (as dict keys) share a baseline. Presumably the prompt uid of each
            response; confirm against the caller.
        epsilon: `(float)`
            added to the group std before dividing, to avoid division by zero.
        norm_adv_by_std_in_grpo: (bool)
            whether to scale the GRPO advantage.
            If True, the advantage is scaled by the std, as in the original GRPO.
            If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length). The same tensor object as ``advantages``.
    """
    # One scalar score per response: the sum of its token-level rewards.
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        # All group statistics are computed before any score is overwritten below.
        for idx, group in id2score.items():
            if len(group) == 1:
                # A singleton group has no meaningful baseline: keep the raw score.
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            else:
                group_scores = torch.tensor(group)
                id2mean[idx] = torch.mean(group_scores)
                id2std[idx] = torch.std(group_scores)
        for i in range(bsz):
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[index[i]]
        # Broadcast the per-response advantage over its (unmasked) tokens.
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores


def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6):
    """
    Advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), for outcome rewards
    (one scalar reward per response): subtract the per-group mean score, broadcast over
    tokens, then whiten with ``verl_F.masked_whiten`` under ``response_mask``.
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: group id per sample; ``index[i]`` is used as a dict key.
        epsilon: not read by this function.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    num_tokens = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    grouped_scores = defaultdict(list)
    group_baseline = {}

    with torch.no_grad():
        num_samples = scores.shape[0]
        for row in range(num_samples):
            grouped_scores[index[row]].append(scores[row])

        # Baselines are all computed before the scores are shifted in place.
        for idx, members in grouped_scores.items():
            group_size = len(members)
            if group_size == 1:
                group_baseline[idx] = torch.tensor(0.0)
            elif group_size > 1:
                group_baseline[idx] = torch.mean(torch.tensor(members))
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for row in range(num_samples):
            scores[row] = scores[row] - group_baseline[index[row]]

        per_token = scores.unsqueeze(-1).tile([1, num_tokens]) * response_mask
        scores = verl_F.masked_whiten(per_token, response_mask) * response_mask

    return scores, scores


def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6):
    """
    Leave-one-out advantage for RLOO, based on https://arxiv.org/abs/2402.14740
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            group id per sample; ``index[i]`` is used as a dict key.
        epsilon: not read by this function.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    scores = token_level_rewards.sum(dim=-1)

    grouped_scores = defaultdict(list)
    group_mean = {}

    with torch.no_grad():
        num_samples = scores.shape[0]
        for row in range(num_samples):
            grouped_scores[index[row]].append(scores[row])

        # Means are fixed before any score is rewritten in place.
        for idx, members in grouped_scores.items():
            group_size = len(members)
            if group_size == 1:
                group_mean[idx] = torch.tensor(0.0)
            elif group_size > 1:
                group_mean[idx] = torch.mean(torch.tensor(members))
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for row in range(num_samples):
            group_size = len(grouped_scores[index[row]])
            if group_size <= 1:
                # Singleton groups keep their raw score.
                continue
            own_term = scores[row] * group_size / (group_size - 1)
            mean_term = group_mean[index[row]] * group_size / (group_size - 1)
            scores[row] = own_term - mean_term

        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores


def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor):
    """
    Advantage for REINFORCE++ (https://arxiv.org/abs/2501.03262): discounted
    reward-to-go per token, whitened with ``verl_F.masked_whiten`` under ``response_mask``.
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        gamma: discount factor applied to the running return.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """

    with torch.no_grad():
        returns = torch.zeros_like(token_level_rewards)
        carried = 0

        last_step = token_level_rewards.shape[1] - 1
        for step in range(last_step, -1, -1):
            carried = token_level_rewards[:, step] + gamma * carried
            returns[:, step] = carried
            # Masked-out positions (after EOS) must not leak into earlier steps.
            carried = carried * response_mask[:, step]

        advantages = verl_F.masked_whiten(returns, response_mask) * response_mask

    return advantages, returns


def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor):
    """
    Advantage for ReMax (https://arxiv.org/abs/2310.10505), for outcome rewards
    (one scalar reward per response): masked reward-to-go minus a per-sample baseline.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """

    with torch.no_grad():
        masked_rewards = token_level_rewards * response_mask
        # Suffix sum along the token axis: flip, cumulative sum, flip back.
        reversed_cumsum = torch.cumsum(torch.flip(masked_rewards, dims=[-1]), dim=-1)
        returns = torch.flip(reversed_cumsum, dims=[-1])
        baseline_per_token = reward_baselines.unsqueeze(-1) * response_mask
        advantages = returns - baseline_per_token

    return advantages, returns


def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    """Return the scores minus ``kl_ratio`` times the log-prob gap ``old_log_prob - ref_log_prob``."""
    return token_level_scores - (old_log_prob - ref_log_prob) * kl_ratio


def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
    """
    Reduce a per-token loss matrix to a scalar under a mask.
    Args:
        loss_mat: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_agg_mode: (str) one of
            "token-mean": masked mean over all tokens (``verl_F.masked_mean``)
            "seq-mean-token-sum": masked sum per sequence, then mean over sequences
            "seq-mean-token-mean": masked mean per sequence, then mean over sequences
            "seq-mean-token-sum-norm": masked sum over everything divided by ``loss_mask.shape[-1]``
    Returns:
        loss: `a scalar torch.Tensor`
            aggregated loss
    Raises:
        ValueError: if ``loss_agg_mode`` is none of the modes above.
    """
    if loss_agg_mode == "token-mean":
        return verl_F.masked_mean(loss_mat, loss_mask)

    if loss_agg_mode == "seq-mean-token-sum":
        per_sequence = (loss_mat * loss_mask).sum(dim=-1)
        return per_sequence.mean()

    if loss_agg_mode == "seq-mean-token-mean":
        # 1e-8 keeps fully masked sequences from dividing by zero.
        per_sequence = (loss_mat * loss_mask).sum(dim=-1) / (loss_mask.sum(dim=-1) + 1e-8)
        return per_sequence.mean()

    if loss_agg_mode == "seq-mean-token-sum-norm":
        per_sequence = (loss_mat * loss_mask).sum(dim=-1)
        # The divisor (loss_mask.shape[-1]) should ideally be constant
        # throughout training to well-replicate the DrGRPO paper.
        # TODO: Perhaps add user-defined normalizer argument to
        # agg_loss to ensure divisor stays constant throughout.
        return per_sequence.sum() / loss_mask.shape[-1]

    raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")


def compute_policy_loss(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    loss_agg_mode="token-mean",
):
    """Clipped PPO policy-gradient loss.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
    Args:
        old_log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange: (float)
            The clip range used in PPO. See https://arxiv.org/abs/1707.06347
        cliprange_low: (float)
            The lower clip range used in PPO; falls back to ``cliprange`` when None.
        cliprange_high: (float)
            The higher clip range used in PPO; falls back to ``cliprange`` when None.
        loss_agg_mode: (str) choices: "token-mean" /
                                      "seq-mean-token-sum" /
                                      "seq-mean-token-mean" /
                                      "seq-mean-token-sum-norm" /
            "token-mean" is the default behavior

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac: (float)
            the fraction of policy gradient loss being clipped
        ppo_kl: (float)
            the estimated KL divergence between the latest updating policy and the old sampling policy
    """
    log_ratio = log_prob - old_log_prob
    ratio = torch.exp(log_ratio)

    # Masked mean of (old - new) log-probs serves as the KL estimate.
    ppo_kl = verl_F.masked_mean(-log_ratio, response_mask)

    lower_range = cliprange if cliprange_low is None else cliprange_low
    upper_range = cliprange if cliprange_high is None else cliprange_high

    unclipped_losses = -advantages * ratio
    clipped_losses = -advantages * torch.clamp(ratio, 1 - lower_range, 1 + upper_range)

    # max(-ratio * A, -clip(ratio, 1 - low, 1 + high) * A)
    pg_losses = torch.maximum(unclipped_losses, clipped_losses)
    was_clipped = torch.gt(clipped_losses, unclipped_losses).float()
    pg_clipfrac = verl_F.masked_mean(was_clipped, response_mask)

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl



def compute_metric_reweight(values_in, response_mask):
    """Return ``exp(A - B) * mask``-style weights from two sliding-window means.

    Masked-out positions of ``values_in`` are first replaced by the value at the
    last position where ``response_mask == 1``. ``A`` is the mean over windows of
    11 starting at each position (right-padded with its last column); ``B`` is the
    mean over windows of 10 (left-padded with its first column). The difference is
    multiplied by ``response_mask`` and exponentiated, so masked positions yield 1.

    Args:
        values_in: 2-D tensor (unpacked as ``batch, seq_len``); detached and copied.
            ``unfold`` with size 11 is applied along dim 1.
        response_mask: tensor of the same 2-D shape; compared against 1.

    Returns:
        Tensor with the same shape as ``values_in``.
    """
    detached = copy.deepcopy(values_in.detach())
    _, width = detached.shape

    # argmax over the reversed mask gives the distance of the last 1 from the end.
    offset_from_end = (response_mask.flip(dims=[1]) == 1).long().argmax(dim=1, keepdim=True)
    last_valid_idx = width - 1 - offset_from_end
    tail_value = torch.gather(detached, 1, last_valid_idx)

    keep = response_mask.unsqueeze(-1) == 1
    filled = torch.where(keep, detached.unsqueeze(-1), tail_value.unsqueeze(-1)).squeeze(-1)

    mean_11 = filled.unfold(dimension=1, size=11, step=1).mean(dim=-1)
    right_pad = mean_11[:, -1].unsqueeze(1).expand(-1, width - mean_11.shape[1])
    mean_11 = torch.cat([mean_11, right_pad], dim=1)

    mean_10 = filled.unfold(dimension=1, size=10, step=1).mean(dim=-1)
    left_pad = mean_10[:, 0].unsqueeze(1).expand(-1, width - mean_10.shape[1])
    mean_10 = torch.cat([left_pad, mean_10], dim=1)

    return torch.exp((mean_11 - mean_10) * response_mask)


def compute_policy_loss_advantage_reweight(
    old_log_prob,
    log_prob,
    entropy,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    loss_agg_mode="token-mean",
    prob_alpha=1.0,
    prob_eps=0.01,
    adv_reweight_upper_bound=1,
    adv_reweight_lower_bound=1e-4,
):
    """Clipped PPO loss with each sequence's advantage redistributed by squared entropy.

    For every sequence the masked sum of advantages is re-spread over its tokens in
    proportion to ``clip(entropy ** 2, lower_bound, upper_bound)``; the usual clipped
    surrogate is then computed on the redistributed advantages.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
    Args:
        old_log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        entropy: `(torch.Tensor)`
            per-token entropy, used elementwise with the other inputs.
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange: (float)
            The clip range used in PPO. See https://arxiv.org/abs/1707.06347
        cliprange_low: (float)
            The lower clip range used in PPO; falls back to ``cliprange`` when None.
        cliprange_high: (float)
            The higher clip range used in PPO; falls back to ``cliprange`` when None.
        loss_agg_mode: (str) choices: "token-mean" /
                                      "seq-mean-token-sum" /
                                      "seq-mean-token-mean" /
                                      "seq-mean-token-sum-norm" /
            "token-mean" is the default behavior
        prob_alpha, prob_eps: accepted but not read by the current implementation.
        adv_reweight_upper_bound, adv_reweight_lower_bound:
            clip bounds applied to ``entropy ** 2``.

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac: (float)
            the fraction of policy gradient loss being clipped
        ppo_kl: (float)
            the estimated KL divergence between the latest updating policy and the old sampling policy
    """
    log_ratio = log_prob - old_log_prob
    ratio = torch.exp(log_ratio)

    ppo_kl = verl_F.masked_mean(-log_ratio, response_mask)

    # Per-token weight: squared entropy, clipped to the configured bounds.
    # NOTE(review): `entropy` is not detached here, so gradients can flow
    # through the weights -- confirm this is intended.
    token_weight = torch.clip(entropy ** 2, min=adv_reweight_lower_bound, max=adv_reweight_upper_bound)

    sequence_advantage = (advantages * response_mask).sum(dim=1, keepdim=True)  # (batch_size, 1)
    sequence_weight = (token_weight * response_mask).sum(dim=1, keepdim=True)  # (batch_size, 1)
    sequence_weight = torch.clamp(sequence_weight, min=1e-6)
    share = token_weight / sequence_weight  # (batch_size, sequence_length)
    # Only positions with mask == 1 receive the redistributed value.
    advantages = torch.where(response_mask == 1, sequence_advantage * share, advantages)

    lower_range = cliprange if cliprange_low is None else cliprange_low
    upper_range = cliprange if cliprange_high is None else cliprange_high

    unclipped_losses = -advantages * ratio
    clipped_losses = -advantages * torch.clamp(ratio, 1 - lower_range, 1 + upper_range)

    # max(-ratio * A, -clip(ratio, 1 - low, 1 + high) * A)
    pg_losses = torch.maximum(unclipped_losses, clipped_losses)
    was_clipped = torch.gt(clipped_losses, unclipped_losses).float()
    pg_clipfrac = verl_F.masked_mean(was_clipped, response_mask)

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl

def compute_policy_loss_gspo(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    loss_agg_mode="seq-mean-token-mean",
    clip_ratio=0.0003,
    clip_ratio_low=None,
    clip_ratio_high=None,
):
    """
    Clipped policy objective and metrics for GSPO (https://arxiv.org/pdf/2507.18071).

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Accepted for interface parity; the loss is always aggregated with
            "seq-mean-token-mean" regardless of this value.
        clip_ratio (float, optional):
            Clip range used for both sides unless overridden below.
        clip_ratio_low / clip_ratio_high (float, optional):
            Per-side clip ranges; each falls back to ``clip_ratio`` when None.

    Returns:
        (pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower), where
        ``pg_clipfrac_lower`` is a constant zero tensor on ``pg_loss.device``.
    """
    low = clip_ratio if clip_ratio_low is None else clip_ratio_low
    high = clip_ratio if clip_ratio_high is None else clip_ratio_high

    token_log_ratio = log_prob - old_log_prob

    # Sequence-level importance ratio in log space:
    # log s_i(θ) = (1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t) / π_θold(y_i,t|x,y_i,<t))
    token_counts = response_mask.sum(dim=-1).clamp(min=1)
    seq_log_ratio = (token_log_ratio * response_mask).sum(dim=-1) / token_counts

    # Token-level ratio with stop-gradient (sg) on the sequence part:
    # log s_i,t(θ) = sg[log s_i(θ)] + log_prob - sg[log_prob]
    log_ratio_combined = log_prob - log_prob.detach() + seq_log_ratio.detach().unsqueeze(-1)
    # Cap before exponentiating, for numerical stability.
    log_ratio_combined = torch.clamp(log_ratio_combined, max=10.0)
    importance = torch.exp(log_ratio_combined)

    unclipped_losses = -advantages * importance
    clipped_losses = -advantages * torch.clamp(importance, 1 - low, 1 + high)
    pg_losses = torch.maximum(unclipped_losses, clipped_losses)

    # GSPO aggregates at the sequence level (seq-mean-token-mean).
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean")

    was_clipped = torch.gt(clipped_losses, unclipped_losses).float()
    pg_clipfrac = verl_F.masked_mean(was_clipped, response_mask)
    # Returned as zero for compatibility; not used in standard GSPO.
    pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)

    ppo_kl = verl_F.masked_mean(-token_log_ratio, response_mask)

    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


def compute_policy_loss_8020_split(
    old_log_prob,
    log_prob,
    entropy,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    loss_agg_mode="token-mean",
):
    """Clipped PPO loss restricted to the 20% highest-entropy response tokens.

    The clipped surrogate is computed for every token, but only the top 20% of
    valid tokens (ranked by ``entropy`` across the whole batch, at least one
    token) contribute to ``pg_loss`` and ``pg_clipfrac``. The reduced mask is
    built on a private copy, so the caller's ``response_mask`` is NOT modified.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
    Args:
        old_log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        entropy: `(torch.Tensor)`
            per-token entropy; indexed with the boolean mask ``response_mask > 0``.
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange: (float)
            The clip range used in PPO. See https://arxiv.org/abs/1707.06347
        cliprange_low: (float)
            The lower clip range used in PPO; falls back to ``cliprange`` when None.
        cliprange_high: (float)
            The higher clip range used in PPO; falls back to ``cliprange`` when None.
        loss_agg_mode: (str) choices: "token-mean" /
                                      "seq-mean-token-sum" /
                                      "seq-mean-token-mean" /
                                      "seq-mean-token-sum-norm" /
            "token-mean" is the default behavior

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss over the selected high-entropy tokens
        pg_clipfrac: (float)
            the fraction of selected tokens whose loss was clipped
        ppo_kl: (float)
            the estimated KL divergence between the latest updating policy and the
            old sampling policy, averaged over ALL response tokens (original mask)
    """
    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)

    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange

    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)  # - clip(ratio, 1-cliprange, 1+cliprange) * A
    pg_losses = torch.maximum(pg_losses1, pg_losses2)  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)

    # Build the reduced mask on a copy: writing into `response_mask` itself would
    # leak the selection back to the caller (and compound across repeated calls
    # if the mask is a view of batch data).
    selected_mask = response_mask.clone()

    all_valid = response_mask > 0
    # Flat (row-major) positions of the valid tokens.
    all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0].detach()
    all_valid_entropy = entropy[all_valid].detach().reshape(-1).cpu()

    num_valid = len(all_valid_entropy)
    if num_valid > 0:
        top_20_percent_num = max(1, int(num_valid * 0.2))
        top_20_entropy_idx = torch.topk(all_valid_entropy, k=top_20_percent_num, largest=True).indices

        # True for every valid token that is NOT among the top-entropy ones.
        drop_flags = torch.ones(num_valid, dtype=torch.bool)
        drop_flags[top_20_entropy_idx] = False

        tokens_to_mask_idx = torch.nonzero(drop_flags, as_tuple=True)[0]
        tokens_to_mask_original_idxs = all_valid_idx[tokens_to_mask_idx]

        if len(tokens_to_mask_original_idxs) > 0:
            # Convert flat positions back to (row, column) coordinates.
            seq_len = advantages.shape[1]
            selected_mask[tokens_to_mask_original_idxs // seq_len, tokens_to_mask_original_idxs % seq_len] = 0

    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), selected_mask)

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=selected_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl

def compute_policy_loss_kl_cov(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    loss_agg_mode="token-mean",
    k_percent=0.2,
    ppo_kl_coef=1,
):
    """Unclipped policy-gradient loss with an |log-ratio| penalty on high-covariance tokens.

    Every valid token gets the score
    ``(adv - mean(adv)) * (log_prob - mean(log_prob))`` (means over all valid
    tokens). The ``max(1, int(N * k / 100))`` highest-scoring tokens have
    ``ppo_kl_coef * |log_prob - old_log_prob|`` added to their loss.

    Args:
        old_log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange, cliprange_low, cliprange_high:
            accepted for signature parity with the other policy losses; this
            loss applies no ratio clipping and does not read them.
        loss_agg_mode: (str) aggregation mode forwarded to ``agg_loss``.
        k_percent: (float)
            percentage (not fraction) of valid tokens to penalize; 0 disables the penalty.
        ppo_kl_coef: (float)
            weight of the |log-ratio| penalty.

    Returns:
        pg_loss: `a scalar torch.Tensor`
        pg_clipfrac: always ``torch.tensor(0.)`` (no clipping is applied)
        ppo_kl_abs: masked mean of ``|log_prob - old_log_prob|``
    """
    negative_approx_kl = log_prob - old_log_prob
    abs_kl = negative_approx_kl.abs()
    ratio = torch.exp(negative_approx_kl)

    ppo_kl_abs = verl_F.masked_mean(abs_kl, response_mask)

    pg_losses = -advantages * ratio
    pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl

    all_valid = response_mask > 0
    # Flat (row-major) positions of the valid tokens.
    all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]
    all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()
    all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()

    # Evaluates to 0 (penalty skipped) when there are no valid tokens or k_percent is 0.
    k = min(k_percent, len(all_valid_adv))

    if k != 0:
        cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())
        k_percent_nums = max(1, int(len(cov_lst_all) * k / 100))
        large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices

        if len(large_cov_idxs) != 0:
            # Map positions within the valid subset back to (row, column) coordinates.
            large_cov_idxs = all_valid_idx[large_cov_idxs]
            seq_len = advantages.shape[1]
            rows = large_cov_idxs // seq_len
            cols = large_cov_idxs % seq_len
            pg_losses[rows, cols] = pg_losses_kl[rows, cols]

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    # The clipped surrogate / clip fraction was previously computed here and then
    # discarded; it also raised TypeError when no cliprange was supplied.
    return pg_loss, torch.tensor(0.), ppo_kl_abs



def compute_policy_loss_kl_minp(
    old_log_prob,
    log_prob,
    pos_tgt_log_prob,
    neg_tgt_log_prob,
    entropy,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    loss_agg_mode="token-mean",
    logp_pos_k_percent=0.2,
    logp_neg_k_percent=0.2,
    ent_pos_k_percent=0.2,
    ent_neg_k_percent=0.2,
    dynamic_pos_k_percent=False,
    dynamic_neg_k_percent=False,
    dynamic_coef=1.0,
    use_clip=True,
    ppo_kl_coef=1,
    onpolicy=False,
    use_adv_reweight=False,
    adv_reweight_upper_bound=1e-6,
    adv_reweight_lower_bound=1e-10,
    overlong_mask=False,
    use_coef_clip=False,
    kl_type="mse",  # "abs", "mse" or "low_var_kl"
    kl_threshold=0,
    use_tgt_log_prob_reshape=False,
    kl_minp_ablation=False,
    use_kl_minp_fork=False,
    neglect_isr=False,
    select_old_log_prob=False,
    use_reverse_kl=False,
):
    """PPO-style policy loss with a selective KL penalty toward target log-probs.

    The per-token loss starts as the (optionally clipped) surrogate. For a
    selected subset of valid tokens (lowest log-prob, highest entropy and,
    optionally, tokens hit by the ratio clip) whose current log-prob is below
    the target log-prob, the loss is replaced by
    ``surrogate + ppo_kl_coef * kl_penalty(log_prob, tgt_log_prob, kl_type)``.

    Args:
        old_log_prob, log_prob: behaviour / current policy log-probs, indexed
            as 2-D tensors with the same layout as ``advantages``.
        pos_tgt_log_prob, neg_tgt_log_prob: target log-probs used for tokens
            with positive / negative advantage respectively.
        entropy: per-token entropy; drives the high-entropy selection and the
            optional advantage re-weighting.
        advantages: per-token advantages, 2-D.
        response_mask: tokens with ``mask > 0`` are valid. NOTE: modified in
            place when ``overlong_mask`` is set.
        cliprange, cliprange_low, cliprange_high: ratio clip bounds; the
            low/high values default to ``cliprange``.
        loss_agg_mode: forwarded to ``agg_loss``.
        logp_pos_k_percent, logp_neg_k_percent: fraction of positive /
            negative-advantage tokens with the smallest log-prob to select.
        ent_pos_k_percent, ent_neg_k_percent: fraction of positive /
            negative-advantage tokens with the largest entropy to select.
            For all four fractions a selection only happens when the resulting
            count is greater than 1.
        dynamic_pos_k_percent, dynamic_neg_k_percent: also penalise clipped
            tokens of that sign and overwrite the matching ``logp_*_k_percent``
            with ``min(dynamic_coef * n_clipped / n_sign, 0.01)``.
        dynamic_coef: scale used by the dynamic fraction above.
        use_clip: take the element-wise maximum of clipped and unclipped loss.
        ppo_kl_coef: weight of the KL-to-target penalty.
        onpolicy: not used by this function.
        use_adv_reweight: redistribute each sequence's summed advantage over
            its tokens proportionally to the clamped squared entropy.
        adv_reweight_upper_bound, adv_reweight_lower_bound: clamp bounds for
            the squared entropy used in the re-weighting.
        overlong_mask: zero the mask of every sequence whose mask is fully set.
        use_coef_clip: use ``clamp(detached ratio) * log_prob`` as surrogate.
        kl_type: mode passed to ``kl_penalty``.
        kl_threshold: when > 0 and the mean KL exceeds it, the returned
            ``skip_optimizer_step`` flag is True.
        use_tgt_log_prob_reshape: use the clamped, probability-scaled penalty
            instead of ``kl_penalty``.
        kl_minp_ablation: zero the loss of selected tokens whose log-prob
            exceeds the target.
        use_kl_minp_fork: zero the loss of the selected tokens instead of
            adding a KL penalty.
        neglect_isr: treat the batch as on-policy by replacing ``old_log_prob``
            with the detached ``log_prob``.
        select_old_log_prob: rank tokens for the low-log-prob selection by
            ``old_log_prob`` instead of ``log_prob``.
        use_reverse_kl: swap the two arguments of ``kl_penalty``.

    Returns:
        ``(pg_loss, pg_clipfrac, ppo_kl_mean, pos_kl_mean, neg_kl_mean,
        skip_optimizer_step, total_reg_token_num, total_reg_frac)``. The same
        8-tuple is returned when the mask contains no valid token.
    """
    print(f"my debug: kl_type: {kl_type}")
    if overlong_mask:
        # A fully-set mask row is treated as an overlong (truncated) response.
        is_overlong_sequence = torch.sum(response_mask, dim=1) == response_mask.shape[1]
        response_mask[is_overlong_sequence] = 0
        del is_overlong_sequence

    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange

    if neglect_isr:
        print("my debug: neglect_isr=True, treating off-policy as on-policy (ISR=1)")
        old_log_prob = log_prob.detach()

    negative_approx_kl = log_prob - old_log_prob

    # Metric only, so no graph is needed.
    with torch.no_grad():
        ppo_kl_mean = verl_F.masked_mean(kl_penalty(log_prob, old_log_prob, kl_type), response_mask)

    if use_adv_reweight:
        metric = entropy.square().clamp(min=adv_reweight_lower_bound, max=adv_reweight_upper_bound)

        response_mask_sum = response_mask.sum(dim=1, keepdim=True)
        valid_sequences = response_mask_sum > 0

        total_advantages = (advantages * response_mask).sum(dim=1, keepdim=True)
        total_metrics = (metric * response_mask).sum(dim=1, keepdim=True).clamp_(min=1e-6)

        # Each token receives its share of the sequence's total advantage.
        weight_ratio = metric / total_metrics
        redistributed_advantages = total_advantages * weight_ratio

        reweight_mask = (response_mask == 1) & valid_sequences
        advantages = torch.where(reweight_mask, redistributed_advantages, advantages)

        del metric, weight_ratio, redistributed_advantages, reweight_mask

    ratio = torch.exp(negative_approx_kl)

    if not use_coef_clip:
        pg_losses1 = -advantages * ratio
    else:
        pg_losses1 = -advantages * torch.clamp(torch.exp(log_prob.detach() - old_log_prob.detach()), 1 - cliprange_low, 1 + cliprange_high) * log_prob

    if use_clip:
        ratio_clipped = ratio.clamp(1 - cliprange_low, 1 + cliprange_high)
        pg_losses2 = -advantages * ratio_clipped
        pg_losses = torch.maximum(pg_losses1, pg_losses2)

        with torch.no_grad():
            pg_clipfrac = verl_F.masked_mean((pg_losses2 > pg_losses1).float(), response_mask)
        del ratio_clipped
    else:
        pg_losses = pg_losses1
        # The clip fraction is still reported even though the loss is unclipped.
        with torch.no_grad():
            if not use_coef_clip:
                pg_losses2 = advantages.mul(-ratio.clamp(1 - cliprange_low, 1 + cliprange_high))
                pg_clipfrac = verl_F.masked_mean((pg_losses2 > pg_losses1).float(), response_mask)
                del pg_losses2
            else:
                ratio_tmp = torch.exp(log_prob.detach() - old_log_prob.detach())
                clipped_condition = (ratio_tmp < 1 - cliprange_low) | (ratio_tmp > 1 + cliprange_high)
                pg_clipfrac = verl_F.masked_mean(clipped_condition.float(), response_mask)

    del pg_losses1

    all_valid_mask = response_mask > 0
    all_valid_flat_idx = all_valid_mask.reshape(-1).nonzero(as_tuple=True)[0]

    pos_kl_mean, neg_kl_mean, total_reg_token_num = torch.tensor(0.), torch.tensor(0.), torch.tensor(0.)

    if len(all_valid_flat_idx) == 0:
        # Nothing to select from; return the same 8-tuple shape as the normal path.
        pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
        skip_optimizer_step = bool(kl_threshold > 0 and ppo_kl_mean > kl_threshold)
        total_reg_frac = total_reg_token_num / (response_mask.sum().cpu() + 1e-6)
        return pg_loss, pg_clipfrac, ppo_kl_mean, pos_kl_mean, neg_kl_mean, skip_optimizer_step, total_reg_token_num, total_reg_frac

    _, seq_len = advantages.shape
    advantages_flat = advantages.reshape(-1)
    if select_old_log_prob:
        print("my debug: select_old_log_prob=True, using old_log_prob")
        log_prob_flat = old_log_prob.reshape(-1)
    else:
        log_prob_flat = log_prob.reshape(-1)
    entropy_flat = entropy.reshape(-1)

    # Flat indices of valid tokens split by advantage sign.
    pos_indices = all_valid_flat_idx[advantages_flat[all_valid_flat_idx] > 0]
    neg_indices = all_valid_flat_idx[advantages_flat[all_valid_flat_idx] < 0]

    # Defined before its first use: both the dynamic-clip section and the
    # top-k sections below call it.
    def apply_kl_penalty_selective(target_flat_idx, tgt_log_prob, debug_name):
        """Replace the loss of the selected tokens that lie below the target.

        Returns ``(mean KL penalty, number of tokens below the target)``.
        """
        if len(target_flat_idx) == 0:
            return torch.tensor(0.), torch.tensor(0)

        batch_idx = target_flat_idx // seq_len
        seq_idx = target_flat_idx % seq_len

        # Only tokens whose current log-prob is below the target are regularised.
        below_target = log_prob[batch_idx, seq_idx].detach() < tgt_log_prob[batch_idx, seq_idx]
        reg_token_num = below_target.sum().cpu()
        batch_idx = batch_idx[below_target]
        seq_idx = seq_idx[below_target]

        if use_kl_minp_fork:
            pg_losses[batch_idx, seq_idx] = 0
            return torch.tensor(0.), reg_token_num

        cur_logp = log_prob[batch_idx, seq_idx]
        tgt_logp = tgt_log_prob[batch_idx, seq_idx]

        if use_tgt_log_prob_reshape:
            target_kl = torch.clamp((cur_logp.detach() - tgt_logp) / torch.exp(cur_logp.detach()), min=-2, max=2) * torch.exp(cur_logp)
        elif use_reverse_kl:
            print("my debug: use_reverse_kl=True, using reverse KL")
            target_kl = kl_penalty(tgt_logp, cur_logp, kl_type)
        else:
            target_kl = kl_penalty(cur_logp, tgt_logp, kl_type)

        if not use_coef_clip:
            surrogate = -advantages[batch_idx, seq_idx] * ratio[batch_idx, seq_idx]
        else:
            coef = torch.clamp(torch.exp(cur_logp.detach() - old_log_prob[batch_idx, seq_idx].detach()), 1 - cliprange_low, 1 + cliprange_high)
            surrogate = -advantages[batch_idx, seq_idx] * coef * cur_logp
        pg_losses[batch_idx, seq_idx] = surrogate + ppo_kl_coef * target_kl

        # Mean of an empty selection would be NaN; report 0 instead.
        if target_kl.numel() > 0:
            tgt_kl_mean = torch.mean(target_kl)
        else:
            tgt_kl_mean = target_kl.new_zeros(())

        if kl_minp_ablation:
            # NOTE: after the below-target filter above no remaining token can
            # satisfy this condition, so this branch currently has no effect.
            minp_mask = cur_logp > tgt_logp
            if minp_mask.any():
                pg_losses[batch_idx[minp_mask], seq_idx[minp_mask]] = 0

        print(f"my debug: {debug_name}_len: {len(target_flat_idx)}")
        return tgt_kl_mean, reg_token_num

    # Optionally penalise tokens hit by the ratio clip and derive the top-k
    # fraction from how many tokens were clipped.
    if dynamic_pos_k_percent or dynamic_neg_k_percent:
        print(f"my debug kl_minp-1: dynamic_pos_k_percent: {dynamic_pos_k_percent}, dynamic_neg_k_percent: {dynamic_neg_k_percent}")
        with torch.no_grad():
            pg_losses1_temp = -advantages * ratio
            ratio_clipped = ratio.clamp(1 - cliprange_low, 1 + cliprange_high)
            pg_losses2_temp = -advantages * ratio_clipped
            clip_condition = (pg_losses2_temp > pg_losses1_temp) & all_valid_mask
            del pg_losses2_temp, ratio_clipped, pg_losses1_temp

        if dynamic_pos_k_percent and len(pos_indices) > 0:
            clip_idxs_pos = (clip_condition & (advantages > 0)).nonzero()

            if len(clip_idxs_pos) > 0:
                logp_pos_k_percent = min(dynamic_coef * len(clip_idxs_pos) / len(pos_indices), 0.01)
                print(f"my debug: clip_idxs_pos_count: {len(clip_idxs_pos)}, pos_adv_count: {len(pos_indices)}, logp_pos_k_percent: {logp_pos_k_percent}")

                # (N, 2) row/column pairs -> flat indices.
                pos_flat_idx = clip_idxs_pos[:, 0] * seq_len + clip_idxs_pos[:, 1]
                pos_kl_mean, reg_token_num = apply_kl_penalty_selective(pos_flat_idx, pos_tgt_log_prob, "pos_dynamic_clip")
                total_reg_token_num += reg_token_num
                del clip_idxs_pos, pos_flat_idx

        if dynamic_neg_k_percent and len(neg_indices) > 0:
            clip_idxs_neg = (clip_condition & (advantages < 0)).nonzero()

            if len(clip_idxs_neg) > 0:
                logp_neg_k_percent = min(dynamic_coef * len(clip_idxs_neg) / len(neg_indices), 0.01)
                print(f"my debug: clip_idxs_neg_count: {len(clip_idxs_neg)}, neg_adv_count: {len(neg_indices)}, logp_neg_k_percent: {logp_neg_k_percent}")

                neg_flat_idx = clip_idxs_neg[:, 0] * seq_len + clip_idxs_neg[:, 1]
                neg_kl_mean, reg_token_num = apply_kl_penalty_selective(neg_flat_idx, neg_tgt_log_prob, "neg_dynamic_clip")
                total_reg_token_num += reg_token_num
                del clip_idxs_neg, neg_flat_idx

    # Positive-advantage tokens. A count of 1 (which is also what a fraction
    # of 0 produces via max(1, ...)) disables the selection.
    if len(pos_indices) > 0:
        # Smallest log-prob
        pos_logp_k = max(1, int(len(pos_indices) * logp_pos_k_percent))
        if pos_logp_k > 1:
            with torch.no_grad():
                pos_logp_values = log_prob_flat[pos_indices]
                _, pos_logp_indices = pos_logp_values.topk(k=pos_logp_k, largest=False)
                pos_target_idx = pos_indices[pos_logp_indices]
                del pos_logp_values, pos_logp_indices
            pos_kl_mean, reg_token_num = apply_kl_penalty_selective(pos_target_idx, pos_tgt_log_prob, "pos_small_logp")
            total_reg_token_num += reg_token_num
            del pos_target_idx

        # Largest entropy
        pos_ent_k = max(1, int(len(pos_indices) * ent_pos_k_percent))
        if pos_ent_k > 1:
            with torch.no_grad():
                pos_entropy_values = entropy_flat[pos_indices]
                _, pos_ent_indices = pos_entropy_values.topk(k=pos_ent_k, largest=True)
                pos_ent_target_idx = pos_indices[pos_ent_indices]
                del pos_entropy_values, pos_ent_indices
            pos_kl_mean, reg_token_num = apply_kl_penalty_selective(pos_ent_target_idx, pos_tgt_log_prob, "pos_large_entropy")
            total_reg_token_num += reg_token_num
            del pos_ent_target_idx

    # Negative-advantage tokens.
    if len(neg_indices) > 0:
        # Smallest log-prob
        neg_logp_k = max(1, int(len(neg_indices) * logp_neg_k_percent))
        if neg_logp_k > 1:
            with torch.no_grad():
                neg_logp_values = log_prob_flat[neg_indices]
                if logp_neg_k_percent == 1:
                    # Every negative token is selected; skip the top-k sort.
                    neg_logp_indices = torch.arange(len(neg_indices), device=neg_indices.device)
                else:
                    _, neg_logp_indices = neg_logp_values.topk(k=neg_logp_k, largest=False)
                neg_target_idx = neg_indices[neg_logp_indices]
                del neg_logp_values, neg_logp_indices
            neg_kl_mean, reg_token_num = apply_kl_penalty_selective(neg_target_idx, neg_tgt_log_prob, "neg_small_logp")
            total_reg_token_num += reg_token_num
            del neg_target_idx

        # Largest entropy
        neg_ent_k = max(1, int(len(neg_indices) * ent_neg_k_percent))
        if neg_ent_k > 1:
            with torch.no_grad():
                neg_entropy_values = entropy_flat[neg_indices]
                _, neg_ent_indices = neg_entropy_values.topk(k=neg_ent_k, largest=True)
                neg_ent_target_idx = neg_indices[neg_ent_indices]
                del neg_entropy_values, neg_ent_indices
            neg_kl_mean, reg_token_num = apply_kl_penalty_selective(neg_ent_target_idx, neg_tgt_log_prob, "neg_large_entropy")
            total_reg_token_num += reg_token_num
            del neg_ent_target_idx

    # Drop references that are no longer needed before aggregating.
    del pos_indices, neg_indices, all_valid_flat_idx
    del advantages_flat, log_prob_flat, entropy_flat
    del ratio, negative_approx_kl

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    # Signal the caller to skip the update when the policy drifted too far.
    if kl_threshold > 0 and ppo_kl_mean > kl_threshold:
        skip_optimizer_step = True
    else:
        skip_optimizer_step = False
    print(f"my debug: ppo_kl_mean: {ppo_kl_mean}, kl_threshold: {kl_threshold}, skip_optimizer_step: {skip_optimizer_step}")

    total_reg_frac = total_reg_token_num / (response_mask.sum().cpu() + 1e-6)
    return pg_loss, pg_clipfrac, ppo_kl_mean, pos_kl_mean, neg_kl_mean, skip_optimizer_step, total_reg_token_num, total_reg_frac


def compute_policy_loss_clip_cov(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    loss_agg_mode="token-mean",
    clip_ratio=0.0002,
    clip_cov_lb=1.0,
    clip_cov_ub=5.0,
):
    """Clipped PPO loss that additionally zeroes a random set of high-covariance tokens.

    On top of the usual ``max(-A * ratio, -A * clip(ratio))`` loss, valid
    tokens that were not already clipped and whose centred
    ``advantage * log_prob`` product lies strictly inside
    ``(clip_cov_lb, clip_cov_ub)`` are candidates; a random subset of at most
    ``max(int(clip_ratio * n_valid), 1)`` of them has its loss multiplied by 0.

    Args:
        old_log_prob, log_prob: behaviour / current policy log-probs, indexed
            as 2-D tensors with the same layout as ``advantages``.
        advantages: per-token advantages.
        response_mask: tokens with ``mask > 0`` are valid.
        cliprange, cliprange_low, cliprange_high: ratio clip bounds; the
            low/high values default to ``cliprange``.
        loss_agg_mode: forwarded to ``agg_loss``.
        clip_ratio: fraction of valid tokens that may be zeroed.
        clip_cov_lb, clip_cov_ub: exclusive covariance bounds for candidates.

    Returns:
        ``(pg_loss, pg_clipfrac, ppo_kl)`` where ``pg_clipfrac`` is the masked
        fraction of tokens zeroed by the covariance rule and ``ppo_kl`` is the
        masked mean of ``old_log_prob - log_prob``.
    """
    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange

    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)

    # Token selection and its diagnostics never feed the loss, so no graph is built.
    with torch.no_grad():
        cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * (
            log_prob - verl_F.masked_mean(log_prob, response_mask)
        )
        # Padding and already-clipped tokens can never be candidates.
        cov_all[response_mask == 0] = -torch.inf
        cov_all[clip_by_origin] = -torch.inf

        clip_num = max(int(clip_ratio * response_mask.sum().item()), 1)
        in_band = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)
        top_k_idx = torch.nonzero(in_band)  # (N, 2) row/column pairs; (0, 2) when empty

        if len(top_k_idx) > 0:
            perm = torch.randperm(len(top_k_idx))
            top_k_idx = top_k_idx[perm[:min(clip_num, len(top_k_idx))]]

        rows, cols = top_k_idx[:, 0], top_k_idx[:, 1]

        if len(top_k_idx) > 0:
            selected_probs = torch.exp(log_prob[rows, cols])
            # torch.quantile only accepts float32/float64 inputs.
            if selected_probs.dtype not in (torch.float32, torch.float64):
                selected_probs = selected_probs.float()

            prob_max = selected_probs.max().item()
            prob_min = selected_probs.min().item()
            prob_10th = torch.quantile(selected_probs, 0.1).item()
            prob_90th = torch.quantile(selected_probs, 0.9).item()

            print(f"my debug: top_k_idx prob stats - max: {prob_max:.6f}, min: {prob_min:.6f}, 10th percentile: {prob_10th:.6f}, 90th percentile: {prob_90th:.6f}")
            print(f"my debug: top_k_idx count: {len(top_k_idx)}")
        else:
            print("my debug: top_k_idx is empty, no prob stats to report")

    # Multiplicative mask: 0 for the selected tokens, 1 elsewhere.
    corr = torch.ones_like(advantages)
    corr[rows, cols] = 0

    pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)

    pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl

def compute_entropy_loss(logits, response_mask):
    """Return the masked mean of the per-token categorical entropy.

    Args:
        logits: `(torch.Tensor)`
            shape: (bs, response_length, vocab_size)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length); passed as the mask of the mean

    Returns:
        entropy: a scalar torch.Tensor

    """
    # Per-token entropy, shape (bs, response_len), then averaged under the mask.
    token_entropy = verl_F.entropy_from_logits(logits)
    return verl_F.masked_mean(token_entropy, mask=response_mask)


def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
    """Clipped value-function loss.
    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    The new prediction is kept within ``cliprange_value`` of the old value and
    the larger of the clipped / unclipped squared errors is used per token.

    Args:
        vpreds (`torch.FloatTensor`):
            Predicted values of the value head, shape (`batch_size`, `response_length`)
        returns: (`torch.FloatTensor`):
            Ground truth returns, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Old values of value head, shape (`batch_size`, `response_length`)
        response_mask:
            Mask passed to the masked means below
        cliprange_value:
            Half-width of the band around ``values`` that ``vpreds`` is clipped to

    Returns:
        vf_loss: a scalar (`torch.FloatTensor`):
            value function loss
        vf_clipfrac: a float
            The ratio of vf being clipped

    """
    lower_bound = values - cliprange_value
    upper_bound = values + cliprange_value
    clipped_preds = verl_F.clip_by_value(vpreds, lower_bound, upper_bound)

    raw_sq_err = (vpreds - returns) ** 2
    clipped_sq_err = (clipped_preds - returns) ** 2

    # Pessimistic choice: the larger error of the two per token.
    worst_sq_err = torch.max(raw_sq_err, clipped_sq_err)
    was_clipped = (clipped_sq_err > raw_sq_err).float()

    vf_loss = 0.5 * verl_F.masked_mean(worst_sq_err, response_mask)
    vf_clipfrac = verl_F.masked_mean(was_clipped, response_mask)
    return vf_loss, vf_clipfrac


def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute a per-element KL estimate between ``logprob`` and ``ref_logprob``.
    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104

    Args:
        logprob: log-probabilities under the current policy.
        ref_logprob: log-probabilities under the reference policy.
        kl_penalty: estimator name, one of

            * ``"kl"``: ``logprob - ref_logprob``
            * ``"abs"``: ``|logprob - ref_logprob|``
            * ``"mse"``: ``0.5 * (logprob - ref_logprob) ** 2``
            * ``"low_var_kl"``: ``exp(d) - d - 1`` with
              ``d = ref_logprob - logprob``
              (J. Schulman, "Approximating KL divergence", 2020,
              http://joschu.net/blog/kl-approx.html)

    Returns:
        The estimate, element-wise, clamped to ``[-10, 10]``.

    Raises:
        NotImplementedError: for ``"full"`` (which would need the full
            vocabulary distribution) and for any unknown estimator name.
    """
    if kl_penalty == "full":
        # Would require logprob / ref_logprob for every token in the vocabulary.
        raise NotImplementedError("kl_penalty='full' is not implemented")

    if kl_penalty == "kl":
        kld = logprob - ref_logprob
    elif kl_penalty == "abs":
        kld = (logprob - ref_logprob).abs()
    elif kl_penalty == "mse":
        kld = 0.5 * (logprob - ref_logprob).square()
    elif kl_penalty == "low_var_kl":
        kl = ref_logprob - logprob
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
    else:
        raise NotImplementedError(f"Unknown kl_penalty mode: {kl_penalty!r}")

    # The same clamp is applied to every estimator.
    return torch.clamp(kld, min=-10, max=10)
