# Copyright 2022 The HuggingFace Team

from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import torch
import torch.nn.functional as F

from ..utils import torch_functional as VF


if TYPE_CHECKING:
    from .config import AlgorithmConfig


class KLController(ABC):
    """Interface for objects that own a KL penalty coefficient.

    Concrete controllers expose the current coefficient as ``kl_coef`` and
    implement :meth:`update`.
    """

    # Current KL penalty coefficient; assigned by concrete controllers.
    kl_coef: float

    @abstractmethod
    def update(self, current_kl: float, n_steps: int):
        """Adjust ``kl_coef`` given the latest KL estimate and the step count."""


class AdaptiveKLController(KLController):
    """KL controller that rescales ``kl_coef`` towards a target KL value.

    Each update multiplies the coefficient by
    ``1 + clip(current_kl / target - 1, -0.2, 0.2) * n_steps / horizon``.
    """

    def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
        """Store the initial coefficient, the target KL and the horizon."""
        self.kl_coef = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int):
        """Rescale ``kl_coef`` in place from the observed ``current_kl``.

        Args:
            current_kl: Observed KL value that is compared against the target.
            n_steps: Step count that scales the size of the adjustment.
        """
        # Relative error w.r.t. the target, limited to +/-20% per update.
        clipped_error = np.clip(current_kl / self.target - 1, -0.2, 0.2)
        scale = 1 + clipped_error * n_steps / self.horizon
        self.kl_coef = self.kl_coef * scale


class FixedKLController(KLController):
    """KL controller whose coefficient never changes after construction."""

    def __init__(self, init_kl_coef: float):
        """Store the constant coefficient."""
        self.kl_coef = init_kl_coef

    def update(self, current_kl: float, n_steps: int):
        """Ignore the inputs; ``kl_coef`` is left untouched."""
        return None


class AdvantageEstimator(str, Enum):
    """Names of the advantage estimators available in this module.

    Members are ``str`` subclasses; their string values are the keys under
    which estimator functions are stored in ``ADV_ESTIMATOR_MAP``.
    """

    GAE = "gae"
    GRPO = "grpo"
    GRPO_ANCHOR = "grpo_anchor"
    GRPO_ANCHOR_VAR_TEMP = "grpo_anchor_var_temp"
    GRPO_ANCHOR_VAR_TEMP_ANCHOR = "grpo_anchor_var_temp_anchor"
    GRPO_ANCHOR_VAR_TEMP_VARIANCE = "grpo_anchor_var_temp_variance"
    GRPO_PASSK = "grpo_passk"
    GRPO_RANK_NORM = "grpo_rank_norm"
    GRPO_FAITHFUL = "grpo_faithful"
    GRPO_ZSCORE_FAITHFUL = "grpo_zscore_faithful"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REMAX = "remax"
    RLOO = "rloo"


# Registry of advantage estimators keyed by estimator name string. Entries are
# added by `register_adv_estimator` and looked up by `compute_advantage_return`.
ADV_ESTIMATOR_MAP: dict[str, Any] = {}


def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
    if algorithm_config.kl_type == "fixed":
        kl_ctrl = FixedKLController(init_kl_coef=algorithm_config.kl_coef)
    elif algorithm_config.kl_type == "adaptive":
        assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
        kl_ctrl = AdaptiveKLController(
            init_kl_coef=algorithm_config.kl_coef,
            target_kl=algorithm_config.kl_target,
            horizon=algorithm_config.kl_horizon,
        )
    else:
        raise ValueError(f"Unknown kl type: {algorithm_config.kl_type}.")

    return kl_ctrl


def register_adv_estimator(name: AdvantageEstimator):
    """Return a decorator that registers an estimator function under ``name``.

    The decorated function is wrapped with ``torch.no_grad()`` and the wrapped
    version is both stored in ``ADV_ESTIMATOR_MAP`` and returned. ``name`` may
    be an enum member (its ``value`` is used as the key) or a plain key.
    """
    registry_key = getattr(name, "value", name)

    def decorator(fn):
        no_grad_fn = torch.no_grad()(fn)
        ADV_ESTIMATOR_MAP[registry_key] = no_grad_fn
        return no_grad_fn

    return decorator


def compute_advantage_return(name: AdvantageEstimator, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
    """Look up the estimator registered under ``name`` and call it with ``kwargs``.

    ``name`` may be an enum member (its ``value`` is the lookup key) or a plain
    key. A missing key raises ``KeyError`` from the registry lookup.
    """
    registry_key = getattr(name, "value", name)
    estimator = ADV_ESTIMATOR_MAP[registry_key]
    return estimator(**kwargs)


@register_adv_estimator(AdvantageEstimator.GAE)
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute advantages and returns with a backward GAE recursion.

    Args:
        token_level_rewards: Per-token rewards, indexed as ``[:, t]``.
        values: Value estimates, indexed like ``token_level_rewards``.
        response_mask: Per-token mask; positions with value 0 are skipped,
            i.e. they carry the running value/advantage of the next kept token.
        gamma: Discount factor.
        lam: GAE lambda.
        **kwargs: Ignored.

    Returns:
        ``(advantages, returns)`` where ``returns = raw_advantages + values``
        and ``advantages`` is the raw advantage passed through
        ``VF.masked_whiten`` with ``response_mask``.
    """
    nextvalues = 0
    lastgaelam = 0
    advantages_reversed = []
    gen_len = token_level_rewards.shape[-1]
    for t in reversed(range(gen_len)):
        delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
        gaelam = delta + gamma * lam * lastgaelam

        # Blend per sample instead of `if response_mask[:, t]:`; a batch-sized
        # tensor has no single truth value, so the branch raised for batch > 1.
        keep = response_mask[:, t].to(values.dtype)
        nextvalues = values[:, t] * keep + (1 - keep) * nextvalues
        lastgaelam = gaelam * keep + (1 - keep) * lastgaelam

        advantages_reversed.append(lastgaelam)

    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values
    advantages = VF.masked_whiten(advantages, response_mask)
    return advantages, returns


@register_adv_estimator(AdvantageEstimator.GRPO)
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group-normalised outcome advantage (z-score within each group).

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim to get
            one score per sample.
        response_mask: Mask multiplied into the broadcast per-sample advantage.
        index: Per-sample group id (``index[i]`` identifies sample ``i``'s group).
        eps: Added to the group std before dividing.
        **kwargs: Ignored.

    Returns:
        The same tensor twice (used as both advantages and returns).

    Raises:
        AssertionError: If any group has fewer than two samples.
    """
    scores = token_level_rewards.sum(dim=-1)

    rows_by_group = defaultdict(list)
    for row in range(scores.shape[0]):
        key = index[row]
        # Tensor scalars hash by identity, so convert them to plain Python
        # values; otherwise every row would end up in its own group.
        rows_by_group[key.item() if hasattr(key, "item") else key].append(row)

    advantages = torch.zeros_like(scores)
    for rows in rows_by_group.values():
        assert len(rows) > 1, "GRPO needs rollout.n > 1."
        group = scores[rows]
        # Sample (Bessel-corrected) std, matching torch.std's default.
        advantages[rows] = (group - group.mean()) / (group.std() + eps)

    returns = advantages.unsqueeze(-1) * response_mask
    return returns, returns


@register_adv_estimator(AdvantageEstimator.GRPO_ANCHOR)
def compute_grpo_anchor_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group z-score advantage with two virtual anchor scores (0 and 1).

    The mean and population std of each group are computed over the group's
    scores extended with the constants 0.0 and 1.0; only the real samples
    receive an advantage.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim.
        response_mask: Mask multiplied into the broadcast per-sample advantage.
        index: Per-sample group id.
        eps: Added to the std before dividing.
        **kwargs: Ignored.

    Returns:
        The same tensor twice (used as both advantages and returns).

    Raises:
        AssertionError: If any group has fewer than two samples.
    """
    scores = token_level_rewards.sum(dim=-1)

    rows_by_group = defaultdict(list)
    for row in range(scores.shape[0]):
        key = index[row]
        # Tensor scalars hash by identity; normalise so equal ids share a group.
        rows_by_group[key.item() if hasattr(key, "item") else key].append(row)

    anchors = torch.tensor([0.0, 1.0], device=scores.device, dtype=scores.dtype)
    advantages = torch.zeros_like(scores)
    for rows in rows_by_group.values():
        assert len(rows) > 1, "GRPO needs rollout.n > 1."
        group = scores[rows]
        extended = torch.cat([group, anchors])
        advantages[rows] = (group - extended.mean()) / (extended.std(unbiased=False) + eps)

    returns = advantages.unsqueeze(-1) * response_mask
    return returns, returns


@register_adv_estimator(AdvantageEstimator.GRPO_ANCHOR_VAR_TEMP)
def compute_grpo_anchor_var_temp_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6,
        var_temp_power_low_var: float = 1.5, var_temp_power_high_var: float = 0.8, var_temp_tau: float = 5.0, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """Anchored group advantage with a variance-dependent temperature.

    Per group, scores are extended with the anchors 0.0 and 1.0. The centred
    scores are divided by ``std_ext ** power`` where ``power`` interpolates
    between ``var_temp_power_low_var`` and ``var_temp_power_high_var`` through
    a sigmoid gate on how far ``std_ext`` is from ``sqrt(1/12)``.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim.
        response_mask: Mask multiplied into the broadcast per-sample advantage.
        index: Per-sample group id.
        eps: Numerical-stability constant used in both divisions.
        var_temp_power_low_var: Exponent used when the gate is near 0.
        var_temp_power_high_var: Exponent used when the gate is near 1.
        var_temp_tau: Sharpness of the sigmoid gate.
        **kwargs: Ignored.

    Returns:
        The same tensor twice (used as both advantages and returns).

    Raises:
        AssertionError: If any group has fewer than two samples.
    """
    scores = token_level_rewards.sum(dim=-1)

    rows_by_group = defaultdict(list)
    for row in range(scores.shape[0]):
        key = index[row]
        # Tensor scalars hash by identity; normalise so equal ids share a group.
        rows_by_group[key.item() if hasattr(key, "item") else key].append(row)

    anchors = torch.tensor([0.0, 1.0], device=scores.device, dtype=scores.dtype)
    uniform_std = (1.0 / 12.0) ** 0.5  # reference std the gate is centred on
    advantages = torch.zeros_like(scores)
    for rows in rows_by_group.values():
        assert len(rows) > 1, "GRPO needs rollout.n > 1."
        group = scores[rows]
        extended = torch.cat([group, anchors])

        std_ext = extended.std(unbiased=False)
        mu_ext = extended.mean()

        distance = (std_ext - uniform_std) / (uniform_std + eps)
        # `distance` is already a tensor; feed it to sigmoid directly instead
        # of copy-constructing through torch.tensor(...), which warns.
        gate = torch.sigmoid(var_temp_tau * distance)
        power = var_temp_power_low_var + gate * (var_temp_power_high_var - var_temp_power_low_var)
        temp = std_ext ** power

        advantages[rows] = (group - mu_ext) / (temp + eps)

    returns = advantages.unsqueeze(-1) * response_mask
    return returns, returns


@register_adv_estimator(AdvantageEstimator.GRPO_ANCHOR_VAR_TEMP_ANCHOR)
def compute_grpo_anchor_var_temp_anchor_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """Anchored group z-score advantage without a variance temperature.

    Per group, the mean and population std are taken over the group's scores
    extended with the anchors 0.0 and 1.0, and each real sample is normalised
    as ``(score - mean) / (std + eps)``.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim.
        response_mask: Mask multiplied into the broadcast per-sample advantage.
        index: Per-sample group id.
        eps: Added to the std before dividing.
        **kwargs: Ignored.

    Returns:
        The same tensor twice (used as both advantages and returns).

    Raises:
        AssertionError: If any group has fewer than two samples.
    """
    scores = token_level_rewards.sum(dim=-1)

    rows_by_group = defaultdict(list)
    for row in range(scores.shape[0]):
        key = index[row]
        # Tensor scalars hash by identity; normalise so equal ids share a group.
        rows_by_group[key.item() if hasattr(key, "item") else key].append(row)

    anchors = torch.tensor([0.0, 1.0], device=scores.device, dtype=scores.dtype)
    advantages = torch.zeros_like(scores)
    for rows in rows_by_group.values():
        assert len(rows) > 1, "GRPO needs rollout.n > 1."
        group = scores[rows]
        extended = torch.cat([group, anchors])
        mu_ext = extended.mean()
        std_ext = extended.std(unbiased=False)
        advantages[rows] = (group - mu_ext) / (std_ext + eps)

    returns = advantages.unsqueeze(-1) * response_mask
    return returns, returns


@register_adv_estimator(AdvantageEstimator.GRPO_ANCHOR_VAR_TEMP_VARIANCE)
def compute_grpo_anchor_var_temp_variance_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6,
        var_temp_power_low_var: float = 1.5, var_temp_power_high_var: float = 0.8, var_temp_tau: float = 5.0, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group advantage with a variance-dependent temperature and no anchors.

    Per group, scores are centred on the group mean and divided by
    ``std ** power`` where ``std`` is the population std of the group and
    ``power`` interpolates between ``var_temp_power_low_var`` and
    ``var_temp_power_high_var`` through a sigmoid gate on how far ``std`` is
    from ``sqrt(1/12)``.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim.
        response_mask: Mask multiplied into the broadcast per-sample advantage.
        index: Per-sample group id.
        eps: Numerical-stability constant used in both divisions.
        var_temp_power_low_var: Exponent used when the gate is near 0.
        var_temp_power_high_var: Exponent used when the gate is near 1.
        var_temp_tau: Sharpness of the sigmoid gate.
        **kwargs: Ignored.

    Returns:
        The same tensor twice (used as both advantages and returns).

    Raises:
        AssertionError: If any group has fewer than two samples.
    """
    scores = token_level_rewards.sum(dim=-1)

    rows_by_group = defaultdict(list)
    for row in range(scores.shape[0]):
        key = index[row]
        # Tensor scalars hash by identity; normalise so equal ids share a group.
        rows_by_group[key.item() if hasattr(key, "item") else key].append(row)

    uniform_std = (1.0 / 12.0) ** 0.5  # reference std the gate is centred on
    advantages = torch.zeros_like(scores)
    for rows in rows_by_group.values():
        assert len(rows) > 1, "GRPO needs rollout.n > 1."
        group = scores[rows]

        std = group.std(unbiased=False)
        mu = group.mean()

        distance = (std - uniform_std) / (uniform_std + eps)
        # `distance` is already a tensor; feed it to sigmoid directly instead
        # of copy-constructing through torch.tensor(...), which warns.
        gate = torch.sigmoid(var_temp_tau * distance)
        power = var_temp_power_low_var + gate * (var_temp_power_high_var - var_temp_power_low_var)
        temp = std ** power

        advantages[rows] = (group - mu) / (temp + eps)

    returns = advantages.unsqueeze(-1) * response_mask
    return returns, returns


@register_adv_estimator(AdvantageEstimator.GRPO_PASSK)
def compute_grpo_passk_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """Advantage that only rewards the best sample of each group.

    Within each group the top-scoring sample receives
    ``(best - second_best) / (std + eps)`` (sample std of the group); every
    other sample gets 0.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim.
        response_mask: Mask multiplied into the broadcast per-sample advantage.
        index: Per-sample group id.
        eps: Added to the group std before dividing.
        **kwargs: Ignored.

    Returns:
        The same tensor twice (used as both advantages and returns).

    Raises:
        AssertionError: If any group has fewer than two samples.
    """
    scores = token_level_rewards.sum(dim=-1)
    advantages = torch.zeros_like(scores)

    rows_by_group = defaultdict(list)
    for row in range(scores.shape[0]):
        key = index[row]
        # Tensor scalars hash by identity; normalise so equal ids share a group.
        rows_by_group[key.item() if hasattr(key, "item") else key].append(row)

    for rows in rows_by_group.values():
        assert len(rows) > 1, "GRPO needs rollout.n > 1."
        group = scores[rows]
        top_values, top_positions = torch.topk(group, k=2)
        winner = rows[int(top_positions[0])]
        advantages[winner] = (top_values[0] - top_values[1]) / (group.std() + eps)

    returns = advantages.unsqueeze(-1) * response_mask
    return returns, returns


@register_adv_estimator(AdvantageEstimator.RLOO)
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """Leave-one-out outcome advantage.

    Each sample's baseline is the mean score of the *other* samples in its
    group; the advantage is the sample's score minus that baseline.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim.
        response_mask: Mask multiplied into the broadcast per-sample advantage.
        index: Per-sample group id.
        **kwargs: Ignored.

    Returns:
        The same tensor twice (used as both advantages and returns).

    Raises:
        AssertionError: If any group has fewer than two samples.
    """
    scores = token_level_rewards.sum(dim=-1)

    rows_by_group = defaultdict(list)
    for row in range(scores.shape[0]):
        key = index[row]
        # Tensor scalars hash by identity; normalise so equal ids share a group.
        rows_by_group[key.item() if hasattr(key, "item") else key].append(row)

    advantages = torch.zeros_like(scores)
    for rows in rows_by_group.values():
        sample_num = len(rows)
        assert sample_num > 1, "RLOO needs rollout.n > 1."
        group = scores[rows]
        baseline = (group.sum() - group) / (sample_num - 1)
        advantages[rows] = group - baseline

    returns = advantages.unsqueeze(-1) * response_mask
    return returns, returns


@register_adv_estimator(AdvantageEstimator.REINFORCE_PLUS_PLUS)
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """Discounted reward-to-go returns with whitened advantages.

    Args:
        token_level_rewards: Per-token rewards, indexed as ``[:, t]``.
        response_mask: Per-token mask; the running return is multiplied by
            ``response_mask[:, t]`` after each step is written.
        gamma: Discount factor.
        **kwargs: Ignored.

    Returns:
        ``(advantages, returns)`` where ``advantages`` is ``returns`` passed
        through ``VF.masked_whiten`` with ``response_mask``.
    """
    returns = torch.zeros_like(token_level_rewards)
    carried = 0
    last_step = token_level_rewards.shape[1] - 1
    # Walk backwards so each step can reuse the discounted tail already summed.
    for step in range(last_step, -1, -1):
        carried = token_level_rewards[:, step] + gamma * carried
        returns[:, step] = carried
        # Zero the carry at masked-out positions before moving one step left.
        carried = carried * response_mask[:, step]

    advantages = VF.masked_whiten(returns, response_mask)
    return advantages, returns


@register_adv_estimator(AdvantageEstimator.REMAX)
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """ReMax outcome advantage: sequence score minus a baseline score.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim to get
            one score per sequence.
        reward_baselines: Baseline subtracted from each sequence score.
            NOTE(review): assumed to hold one value per sequence (or a scalar)
            so it lines up with the summed scores - confirm against the caller.
        response_mask: Per-token mask applied to both outputs.
        **kwargs: Ignored.

    Returns:
        ``(advantages, returns)``: ``advantages`` is the per-sequence
        ``score - baseline`` broadcast over tokens and masked; ``returns`` is
        the reversed cumulative sum (reward-to-go) of the masked rewards.
    """
    scores = token_level_rewards.sum(dim=-1) - reward_baselines
    # Add a trailing token axis before masking: multiplying a (batch,) vector by
    # a (batch, tokens) mask would otherwise broadcast along the token axis.
    advantages = scores.unsqueeze(-1) * response_mask
    returns = (token_level_rewards * response_mask).flip(dims=(-1,)).cumsum(dim=-1).flip(dims=(-1,))
    return advantages, returns


def compute_rewards(
    token_level_scores: torch.Tensor,
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    kl_ratio: float,
) -> torch.Tensor:
    """Subtract a scaled log-prob difference from the token-level scores.

    Returns ``token_level_scores - (log_probs - ref_log_probs) * kl_ratio``.
    """
    log_ratio = log_probs - ref_log_probs
    penalty = log_ratio * kl_ratio
    return token_level_scores - penalty


def average_loss(
    values: torch.Tensor, mask: torch.Tensor, mode: Literal["token", "seq"], eps: float = 1e-8
) -> torch.Tensor:
    """Reduce a masked per-token tensor to a scalar.

    Args:
        values: Per-token values.
        mask: Mask with the same layout as ``values``.
        mode: ``"token"`` delegates to ``VF.masked_mean`` over all tokens;
            ``"seq"`` takes the masked mean along the last dim for each row and
            then averages the rows.
        eps: Added to the denominator.

    Raises:
        NotImplementedError: For any other ``mode``.
    """
    if mode == "token":
        return VF.masked_mean(values, mask, eps=eps)

    if mode == "seq":
        per_seq_total = (values * mask).sum(-1)
        per_seq_count = mask.sum(-1) + eps
        return (per_seq_total / per_seq_count).mean()

    raise NotImplementedError(f"Unknown mode: {mode}.")


def compute_policy_loss(
    old_log_probs: torch.Tensor,
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    clip_ratio_low: float,
    clip_ratio_high: float,
    clip_ratio_dual: float,
    tau_positive: float,
    tau_negative: float,
    loss_type: Literal["default", "gspo", "gspo_token", "cispo", "sapo"],
    loss_avg_mode: Literal["token", "seq"],
    **kwargs,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Compute the policy-gradient loss and its scalar metrics.

    Args:
        old_log_probs: Log-probs under the behaviour policy.
        log_probs: Log-probs under the current policy.
        advantages: Per-token advantages.
        response_mask: Per-token mask used for every reduction.
        clip_ratio_low: Lower clip bound is ``1 - clip_ratio_low``.
        clip_ratio_high: Upper clip bound is ``1 + clip_ratio_high``.
        clip_ratio_dual: Multiplier of ``-advantages`` forming the dual-clip
            bound applied where ``advantages < 0`` (default branch only).
        tau_positive: Sigmoid gate temperature for ``advantages >= 0`` (sapo).
        tau_negative: Sigmoid gate temperature for ``advantages < 0`` (sapo).
        loss_type: ``"gspo"``/``"gspo_token"`` use a sequence-level log ratio,
            ``"cispo"`` and ``"sapo"`` use their own objectives, and anything
            else uses the clipped objective with dual clipping.
        loss_avg_mode: Passed to ``average_loss``.
        **kwargs: Ignored.

    Returns:
        ``(loss, metrics)`` where ``metrics`` maps names to Python floats
        obtained via ``VF.masked_mean`` with ``response_mask``.
    """
    token_log_ratio = log_probs - old_log_probs

    if loss_type == "gspo_token":
        seq_log_ratio = VF.masked_mean(token_log_ratio, response_mask, dim=-1)
        # Value comes from the (detached) sequence ratio; gradient flows through
        # the per-token log-probs only.
        log_importance_ratio = seq_log_ratio.detach().unsqueeze(-1) + log_probs - log_probs.detach()
    elif loss_type == "gspo":
        seq_log_ratio = VF.masked_mean(token_log_ratio, response_mask, dim=-1)
        log_importance_ratio = seq_log_ratio.unsqueeze(-1) * response_mask
    else:
        log_importance_ratio = token_log_ratio

    # Clamp in log space before exponentiating to avoid overflow.
    ratio = torch.exp(torch.clamp(log_importance_ratio, -20.0, 20.0))
    clipped_ratio = torch.exp(
        torch.clamp(log_importance_ratio, np.log(1.0 - clip_ratio_low), np.log(1.0 + clip_ratio_high))
    )

    metrics = {
        "ppo_kl": -token_log_ratio,
        "entropy_loss": average_loss(-log_probs, response_mask, mode=loss_avg_mode),
    }

    if loss_type == "cispo":
        per_token_loss = -advantages * log_probs * clipped_ratio.detach()
    elif loss_type == "sapo":
        is_positive = (advantages >= 0).float()
        is_negative = (advantages < 0).float()
        gate_negative = 4.0 / tau_negative * torch.sigmoid(tau_negative * (ratio - 1.0))
        gate_positive = 4.0 / tau_positive * torch.sigmoid(tau_positive * (ratio - 1.0))
        per_token_loss = -advantages * (is_positive * gate_positive + is_negative * gate_negative)
    else:
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * clipped_ratio
        dual_bound_loss = -advantages * clip_ratio_dual

        upper_clipped = torch.max(unclipped_loss, clipped_loss)
        metrics["pg_clipfrac_higher"] = (unclipped_loss < clipped_loss).float()
        lower_clipped = torch.min(upper_clipped, dual_bound_loss)
        # The dual-clip bound only applies to negative advantages.
        per_token_loss = torch.where(advantages < 0, lower_clipped, upper_clipped)
        metrics["pg_clipfrac_lower"] = (upper_clipped > dual_bound_loss).float() * (advantages < 0).float()

    final_pg_loss = average_loss(per_token_loss, response_mask, mode=loss_avg_mode)
    scalar_metrics = {}
    for metric_name, metric_value in metrics.items():
        scalar_metrics[metric_name] = VF.masked_mean(metric_value, response_mask).detach().item()
    return final_pg_loss, scalar_metrics


def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_avg_mode: Literal["token", "seq"],
) -> tuple[torch.Tensor, dict[str, float]]:
    """Clipped value-function loss.

    Args:
        vpreds: Current value predictions.
        returns: Regression targets.
        values: Reference values; ``vpreds`` is clamped to
            ``values +/- cliprange_value``.
        response_mask: Per-token mask used for every reduction.
        cliprange_value: Half-width of the clamp window around ``values``.
        loss_avg_mode: Passed to ``average_loss``.

    Returns:
        ``(loss, metrics)``: ``loss`` is half the averaged element-wise maximum
        of the clipped and unclipped squared errors; ``metrics`` contains
        ``"vf_clipfrac"`` and ``"vpred_mean"`` as Python floats.
    """
    clip_low = values - cliprange_value
    clip_high = values + cliprange_value
    clamped_preds = torch.clamp(vpreds, clip_low, clip_high)

    raw_sq_err = torch.square(vpreds - returns)
    clamped_sq_err = torch.square(clamped_preds - returns)
    worst_sq_err = torch.max(raw_sq_err, clamped_sq_err)

    vf_loss = 0.5 * average_loss(worst_sq_err, response_mask, mode=loss_avg_mode)

    clip_fraction = VF.masked_mean((raw_sq_err < clamped_sq_err).float(), response_mask)
    pred_mean = VF.masked_mean(vpreds, response_mask)
    metrics = {
        "vf_clipfrac": clip_fraction.detach().item(),
        "vpred_mean": pred_mean.detach().item(),
    }
    return vf_loss, metrics


def compute_kl(
    log_probs: torch.FloatTensor,
    ref_log_probs: torch.FloatTensor,
    kl_penalty: Literal["kl", "abs", "mse", "low_var_kl", "full"],
) -> torch.Tensor:
    """Compute a KL-style penalty between current and reference log-probs.

    Both inputs are cast to float32 first.

    Args:
        log_probs: Log-probs of the current policy.
        ref_log_probs: Log-probs of the reference policy.
        kl_penalty: ``"kl"`` is the plain difference, ``"abs"`` its absolute
            value, ``"mse"`` half its square, ``"low_var_kl"`` is
            ``exp(d) - d - 1`` with ``d = ref - cur`` (both ``d`` and the
            result are clamped), and ``"full"`` sums ``F.kl_div`` over the
            last dim.

    Raises:
        NotImplementedError: For any other ``kl_penalty``.
    """
    log_probs = log_probs.float()
    ref_log_probs = ref_log_probs.float()

    if kl_penalty in ("kl", "abs", "mse"):
        delta = log_probs - ref_log_probs
        if kl_penalty == "kl":
            return delta
        if kl_penalty == "abs":
            return delta.abs()
        return 0.5 * delta.square()

    if kl_penalty == "low_var_kl":
        # Clamp the log ratio before exponentiating, then bound the estimate.
        neg_delta = (ref_log_probs - log_probs).clamp(-20.0, 20.0)
        estimate = (neg_delta.exp() - neg_delta - 1).contiguous()
        return torch.clamp(estimate, min=-10.0, max=10.0)

    if kl_penalty == "full":
        pointwise = F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none")
        return pointwise.sum(-1)

    raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")


def _rank_normalize(rewards: torch.Tensor) -> torch.Tensor:
    K = rewards.shape[0]
    if K == 1:
        return torch.tensor([0.5], device=rewards.device, dtype=torch.float32)
    
    sorted_indices = torch.argsort(rewards, descending=False)
    
    ranks = torch.zeros_like(rewards, dtype=torch.float32)
    
    i = 0
    while i < K:
        current_value = rewards[sorted_indices[i]]
        tie_start = i
        while i < K:
            if torch.isclose(rewards[sorted_indices[i]], current_value).item():
                i += 1
            else:
                break
        tie_end = i
        
        avg_rank = (tie_start + tie_end + 1) / 2.0
        for j in range(tie_start, tie_end):
            ranks[sorted_indices[j]] = avg_rank
    
    ranks = (ranks - 1.0) / (K - 1.0) if K > 1 else ranks * 0.0 + 0.5
    return ranks


@register_adv_estimator(AdvantageEstimator.GRPO_RANK_NORM)
def compute_grpo_rank_norm_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rank-based group advantage with a virtual anchor score of 1.0.

    Per group, the scores plus one anchor value 1.0 are rank-normalised to
    ``[0, 1]`` with ``_rank_normalize`` and mapped to ``[-1, 1]`` via
    ``(rank - 0.5) * 2``; the anchor's entry is discarded.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim.
        response_mask: Mask multiplied into the broadcast per-sample advantage.
        index: Per-sample group id.
        **kwargs: Ignored.

    Returns:
        The same tensor twice (used as both advantages and returns).
    """
    scores = token_level_rewards.sum(dim=-1)

    rows_by_group = defaultdict(list)
    for row in range(scores.shape[0]):
        key = index[row]
        # Tensor scalars hash by identity; normalise so equal ids share a group.
        rows_by_group[key.item() if hasattr(key, "item") else key].append(row)

    advantages = torch.zeros_like(scores)
    virtual_anchor = torch.tensor([1.0], device=scores.device, dtype=scores.dtype)

    for rows in rows_by_group.values():
        rewards_ext = torch.cat([scores[rows], virtual_anchor])
        ranks_ext = _rank_normalize(rewards_ext)
        adv_ext = (ranks_ext - 0.5) * 2.0
        # Drop the anchor's slot; ranks are float32, so cast back for the
        # indexed assignment, which requires matching dtypes.
        advantages[rows] = adv_ext[:-1].to(advantages.dtype)

    returns = advantages.unsqueeze(-1) * response_mask
    return returns, returns


@register_adv_estimator(AdvantageEstimator.GRPO_FAITHFUL)
def compute_grpo_faithful_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: torch.Tensor,
    beta0: float = 5.0,
    delta: float = 0.2,
    eps_pos: float = 0.02,
    eps_neg: float = 0.05,
    alpha_max: float = 0.5,
    rel_gate: float = 0.25,
    rel_temp: float = 0.08,
    anchor_strength: float = 0.2,
    gamma_rng: float = 0.15,
    adv_clip: float = 2.5,
    eps: float = 1e-6,
    tol: float = 1e-12,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Blend a group z-score advantage with a softmax-anchor advantage.

    Per group (groups of size <= 1 keep a zero advantage):

    1. ``adv_grpo``: scores centred on the group mean and divided by the
       population std plus ``eps``.
    2. ``adv_anchor``: softmax weights of the scores against a virtual anchor
       placed at ``rmax + shift``, centred, plus a saturation term damped by
       ``exp(-range / gamma_rng)``; rescaled to unit std when its std exceeds
       1e-6, then multiplied by ``anchor_strength``.
    3. The two are mixed with weight ``alpha`` (at most ``alpha_max``), a
       sigmoid of ``rel_gate`` minus the relative std; the mix is rescaled to
       unit std when its std exceeds 0.05 and clamped to ``+/-adv_clip``
       unless ``adv_clip`` is ``None``.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim and
            cast to float32.
        response_mask: Mask multiplied into the broadcast per-sample advantage.
        index: Per-sample group id; entries with ``.item()`` are unwrapped.
        beta0: Base inverse temperature of the softmax.
        delta: Scale in ``beta_eff = beta0 * delta / (range + delta)``.
        eps_pos: Anchor offset below ``rmax`` weighted by ``(1 + g) / 2``.
        eps_neg: Anchor offset above ``rmax`` weighted by ``(1 - g) / 2``.
        alpha_max: Upper bound of the mixing weight.
        rel_gate: Centre of the mixing sigmoid.
        rel_temp: Temperature of the mixing sigmoid.
        anchor_strength: Scale applied to the anchor advantage.
        gamma_rng: Decay scale of the saturation term.
        adv_clip: Clamp magnitude, or ``None`` to skip clamping.
        eps: Added to the group std in the z-score.
        tol: Added to ``rel_temp`` in the sigmoid denominator.
        **kwargs: Ignored.

    Returns:
        The same tensor twice (used as both advantages and returns).
    """
    scores = token_level_rewards.sum(dim=-1).float()

    rows_by_group = defaultdict(list)
    for row in range(scores.shape[0]):
        key = index[row].item() if hasattr(index[row], 'item') else index[row]
        rows_by_group[key].append(row)

    adv_all = torch.zeros_like(scores)

    # Fixed centre and temperature of the tanh gate on the group maximum.
    gate_center, gate_temp = 0.5, 0.1

    for rows in rows_by_group.values():
        group_size = len(rows)
        if group_size <= 1:
            # Singleton groups keep the zero advantage they were created with.
            continue

        group = scores[rows]
        group_mean = group.mean()
        group_std = group.std(unbiased=False)

        # Component 1: plain group z-score.
        adv_grpo = (group - group_mean) / (group_std + eps)

        best = group.max()
        worst = group.min()
        spread = (best - worst).clamp_min(0.0)

        # Softmax sharpness shrinks as the in-group spread grows.
        beta_eff = beta0 * (delta / (spread + delta))

        # Virtual anchor placed slightly below or above the group maximum.
        gate = torch.tanh((best - gate_center) / gate_temp)
        shift = (-eps_pos) * (1 + gate) / 2 + (eps_neg) * (1 - gate) / 2
        anchor_score = best + shift

        logits = beta_eff * (group - best)
        anchor_logit = beta_eff * (anchor_score - best)
        weights = torch.exp(logits)
        partition = weights.sum() + torch.exp(anchor_logit)
        probs = weights / partition

        mean_prob = probs.mean()
        centered_probs = probs - mean_prob
        saturation = mean_prob - 1.0 / (group_size + 1)

        damping = torch.exp(-spread / gamma_rng)
        adv_anchor = centered_probs + damping * saturation

        # Component 2: anchor advantage, unit-scaled when not degenerate.
        anchor_std = adv_anchor.std(unbiased=False)
        if anchor_std > 1e-6:
            adv_anchor = adv_anchor / anchor_std
        adv_anchor = anchor_strength * adv_anchor

        rel_std = group_std / (group_mean.abs() + 0.1)
        alpha = alpha_max * torch.sigmoid((rel_gate - rel_std) / (rel_temp + tol))

        mixed = (1.0 - alpha) * adv_grpo + alpha * adv_anchor

        mixed_std = mixed.std(unbiased=False)
        if mixed_std > 0.05:
            mixed = mixed / mixed_std

        if adv_clip is not None:
            mixed = mixed.clamp(-adv_clip, adv_clip)

        adv_all[rows] = mixed

    returns = adv_all.unsqueeze(-1) * response_mask
    return returns, returns


@register_adv_estimator(AdvantageEstimator.GRPO_ZSCORE_FAITHFUL)
def compute_grpo_zscore_faithful_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: torch.Tensor,
    eps_pos: float = 0.02,
    eps_neg: float = 0.05,
    adv_clip: float = 2.5,
    eps: float = 1e-6,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group z-score advantage that includes one virtual anchor sample.

    Per group of size ``K`` (groups of size <= 1 keep a zero advantage), an
    anchor score ``tau = rmax + shift`` is added to the statistics: the mean
    and variance are taken over the ``K + 1`` values, and each real sample
    gets ``(score - mean) / sqrt(var + eps)``, clamped to ``+/-adv_clip``
    unless ``adv_clip`` is ``None``.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dim and
            cast to float32.
        response_mask: Mask multiplied into the broadcast per-sample advantage.
        index: Per-sample group id; entries with ``.item()`` are unwrapped.
        eps_pos: Anchor offset below ``rmax`` weighted by ``(1 + g) / 2``.
        eps_neg: Anchor offset above ``rmax`` weighted by ``(1 - g) / 2``.
        adv_clip: Clamp magnitude, or ``None`` to skip clamping.
        eps: Added to the variance inside the square root.
        **kwargs: Ignored.

    Returns:
        The same tensor twice (used as both advantages and returns).
    """
    scores = token_level_rewards.sum(dim=-1).float()

    rows_by_group = defaultdict(list)
    for row in range(scores.shape[0]):
        key = index[row].item() if hasattr(index[row], 'item') else index[row]
        rows_by_group[key].append(row)

    adv_all = torch.zeros_like(scores)

    # Fixed centre and temperature of the tanh gate on the group maximum.
    gate_center, gate_temp = 0.5, 0.1

    for rows in rows_by_group.values():
        group_size = len(rows)
        if group_size <= 1:
            # Singleton groups keep the zero advantage they were created with.
            continue

        group = scores[rows]
        best = group.max()

        gate = torch.tanh((best - gate_center) / gate_temp)
        shift = (-eps_pos) * (1 + gate) / 2 + (eps_neg) * (1 - gate) / 2
        anchor_score = best + shift

        # Statistics over the K real scores plus the single anchor.
        total_count = group_size + 1
        mean_ext = (group.sum() + anchor_score) / total_count
        var_ext = ((group - mean_ext).pow(2).sum() + (anchor_score - mean_ext).pow(2)) / total_count
        std_ext = torch.sqrt(var_ext + eps)

        normalized = (group - mean_ext) / std_ext

        if adv_clip is not None:
            normalized = normalized.clamp(-adv_clip, adv_clip)

        adv_all[rows] = normalized

    returns = adv_all.unsqueeze(-1) * response_mask
    return returns, returns
