import torch
import numpy as np
import verl.utils.torch_functional as verl_F
from verl import DataProto
from sklearn.metrics import roc_curve
from typing import Dict


def get_distritbution_level_mask(data: DataProto, config):
    """Build the boolean mask of (response position, top-k candidate) entries used by the
    distribution-level objective.

    Starting from the response part of ``attention_mask`` with the last valid token of every
    row removed, the mask is optionally narrowed by:

    * ``distribution_level_minp``: keep candidates whose old probability exceeds it
      (active when > 1e-5; default 0).
    * ``distribution_level_topp``: keep the smallest prefix of candidates, in order of
      decreasing old probability, whose cumulative probability stays <= top-p (the most
      likely candidate is always kept). Active whenever ``1 - top_p > 1e-5``.
    * ``min_entropy_proportion``: keep only the highest-entropy response positions, with the
      entropy computed over the top-k candidates only (active when > 1e-5).

    Finally top-k slot 0 is switched on at every position.

    Args:
        data: batch whose ``batch`` holds ``old_log_prob_topk_values`` ([B, T, K] top-k
            log-probabilities) and ``attention_mask`` (the response is its last T columns).
        config: mapping-like config read through ``config.get``.

    Returns:
        ``DataProto`` holding a single ``distribution_level_mask`` bool tensor [B, T, K].
    """
    topk_log_probs = data.batch["old_log_prob_topk_values"]
    batch_size, response_length, topk = topk_log_probs.shape

    # .clone(): `.bool()` returns the tensor itself when attention_mask is already bool and the
    # slice is a view, so the in-place edit below would otherwise write into the input batch.
    response_mask = data.batch["attention_mask"][:, -response_length:].bool().clone()
    response_lens = response_mask.sum(-1)

    # Exclude the last valid token of every row (for an empty row this clears column -1,
    # which is already False).
    rows = torch.arange(batch_size, device=response_mask.device)
    response_mask[rows, response_lens - 1] = False

    # .clone(): expand() returns a stride-0 view. Without a real copy, the final
    # `mask[:, :, 0] = True` would write through to every top-k slot (and into response_mask)
    # whenever none of the filters below replaces `mask` with a fresh tensor.
    mask = response_mask.unsqueeze(-1).expand(-1, -1, topk).clone()

    old_prob = torch.exp(topk_log_probs)  # [B, T, K]

    # min-p filter
    min_p = config.get("distribution_level_minp", 0)
    if min_p > 1e-5:
        mask = mask & (old_prob > min_p)

    # top-p filter
    # NOTE(review): with the default of 0 this filter is active and keeps only the single most
    # likely candidate; a value of 1 disables it. Confirm that default is intended.
    top_p = config.get("distribution_level_topp", 0)
    if 1 - top_p > 1e-5:
        sorted_prob, sorted_idx = torch.sort(old_prob, dim=-1, descending=True)
        cum_prob = torch.cumsum(sorted_prob, dim=-1)
        top_p_mask = cum_prob <= top_p
        top_p_mask[..., 0] = True  # always keep the most likely candidate
        # Map the mask from sorted order back to the original top-k order.
        reverse_idx = torch.argsort(sorted_idx, dim=-1)
        mask = mask & torch.gather(top_p_mask, dim=-1, index=reverse_idx)

    # Keep only high-entropy positions.
    min_entropy_proportion = config.get("min_entropy_proportion", 0)
    if min_entropy_proportion > 1e-5:
        entropy = -torch.sum(old_prob * topk_log_probs, dim=-1)  # [B, T]
        # Padding positions get -inf so topk never prefers them.
        entropy_pad = torch.where(response_mask, entropy, torch.full_like(entropy, -float("inf")))

        # Per-row budget k_i = p * L_i (at least 1); `slot < k_i` below keeps ceil(k_i) slots.
        valid_lens = response_mask.sum(dim=1)
        k = torch.clamp(min_entropy_proportion * valid_lens.float(), min=1)

        # NOTE(review): topk uses the floored *mean* budget, so rows whose own k_i exceeds it
        # are capped at k_mean. Confirm this approximation is intended.
        k_mean = k.mean().floor().long().item()
        _, topk_idx = torch.topk(entropy_pad, k_mean, dim=1)

        batch_idx = torch.arange(topk_idx.size(0), device=topk_idx.device).unsqueeze(1)
        slot = torch.arange(k_mean, device=topk_idx.device).unsqueeze(0)  # [1, k_mean]
        keep_idx_mask = slot < k.unsqueeze(1)  # [B, k_mean]
        keep = torch.zeros_like(entropy, dtype=torch.bool)
        keep[batch_idx.expand_as(topk_idx)[keep_idx_mask], topk_idx[keep_idx_mask]] = True
        keep = keep & response_mask
        mask = mask & keep.unsqueeze(-1)

    # NOTE(review): this also switches slot 0 on at padding positions and at the removed last
    # token; confirm downstream consumers ignore those positions.
    mask[:, :, 0] = True
    return DataProto.from_dict(tensors={'distribution_level_mask': mask})


def calculate_logp_adv_cov(logp, adv, distribution_level_mask):
    """Pearson correlation between log-probabilities and advantages over masked entries.

    Despite the name, the value returned is the correlation coefficient (covariance divided
    by both standard deviations), not the raw covariance.

    Args:
        logp: tensor of log-probabilities.
        adv: tensor of advantages, same shape as ``logp``.
        distribution_level_mask: boolean tensor selecting the entries to use.

    Returns:
        0-dim tensor holding the correlation, or the Python float ``0.0`` when fewer than
        two entries are selected or either side has zero variance.
    """
    logp = logp[distribution_level_mask]
    adv = adv[distribution_level_mask]
    if logp.numel() < 2:
        return 0.0

    logp_centered = logp - logp.mean()
    adv_centered = adv - adv.mean()

    # Covariance and both standard deviations use the same 1/N normalisation so the factors
    # cancel exactly; mixing 1/N with the unbiased 1/(N-1) std skewed the result by (N-1)/N.
    cov = (logp_centered * adv_centered).mean()
    std_logp = logp_centered.pow(2).mean().sqrt()
    std_adv = adv_centered.pow(2).mean().sqrt()

    # Avoid division by zero for constant inputs.
    if std_logp == 0 or std_adv == 0:
        return 0.0

    return cov / (std_logp * std_adv)



import torch

def suffix_mean_with_mask(score: torch.Tensor, eos_mask: torch.Tensor):
    """Mean of ``score`` over each position and all later valid positions in its row.

    For row ``b`` and position ``t`` the value is
    ``sum(score[b, t:] * eos_mask[b, t:]) / max(1, sum(eos_mask[b, t:]))``. The result is then
    multiplied by ``eos_mask``, so padding positions read 0.

    Args:
        score: [batch_size, seq_len] values.
        eos_mask: [batch_size, seq_len]; 1 (or True) for valid tokens, 0 for padding.

    Returns:
        [batch_size, seq_len] tensor of suffix means.
    """
    assert score.shape == eos_mask.shape

    def _reverse_cumsum(values):
        # Cumulative sum running from the end of each row back to its start.
        return torch.cumsum(values.flip(dims=[1]), dim=1).flip(dims=[1])

    totals = _reverse_cumsum(score * eos_mask)
    # Floor the divisor at 1: where no valid token remains the total is already 0.
    counts = torch.clamp(_reverse_cumsum(eos_mask), min=1)
    return totals / counts * eos_mask


import torch

def suffix_geomean_with_mask(score: torch.Tensor, eos_mask: torch.Tensor):
    """Geometric mean of ``score`` over each position and all later valid positions in its row.

    Computed in log space: ``exp(sum(log score[b, t:] over valid tokens) / max(1, #valid))``,
    then multiplied by ``eos_mask`` so padding positions read 0.

    Args:
        score: [batch_size, seq_len]; must be > 0 at valid positions. Padding positions may
            hold any value (for example 0); they are ignored.
        eos_mask: [batch_size, seq_len]; 1 (or True) for valid tokens, 0 for padding.

    Returns:
        [batch_size, seq_len] tensor of suffix geometric means.

    Raises:
        AssertionError: on a shape mismatch or a non-positive score at a valid position.
    """
    assert score.shape == eos_mask.shape
    valid = eos_mask.bool()
    # Only valid tokens need to be positive; padding is neutralised below.
    assert (score[valid] > 0).all(), "Score must be positive for geometric mean"

    # Replace padding with 1 (log 1 = 0); otherwise log(0) * 0 would produce NaN.
    safe_score = torch.where(valid, score, torch.ones_like(score))

    # Flip so a forward cumsum yields suffix sums.
    score_flipped = safe_score.flip(dims=[1])
    mask_flipped = eos_mask.flip(dims=[1])

    log_score = torch.log(score_flipped) * mask_flipped
    log_cumsum = torch.cumsum(log_score, dim=1).flip(dims=[1])  # [B, L]

    count = torch.cumsum(mask_flipped, dim=1).flip(dims=[1])  # valid tokens in each suffix
    count = torch.clamp(count, min=1)  # avoid division by zero

    suffix_geomean = torch.exp(log_cumsum / count)

    # Zero out padding positions.
    return suffix_geomean * eos_mask


def find_best_threshold(pos_scores: list[float], neg_scores: list[float]) -> float:
    """Return the score threshold that maximises Youden's J statistic (TPR - FPR).

    Positive examples are labelled 1 and negative examples 0. The threshold is read from
    ``roc_curve`` at the ROC point with the largest ``tpr - fpr``; on ties, the first such
    point (``np.argmax`` semantics) is used.
    """
    labels = [1] * len(pos_scores) + [0] * len(neg_scores)
    fpr, tpr, thresholds = roc_curve(labels, pos_scores + neg_scores)
    best = np.argmax(tpr - fpr)
    return float(thresholds[best])


def masked_corrcoef(x, y, mask):
    """Pearson correlation between ``x`` and ``y`` restricted to the entries selected by ``mask``.

    Both selections are cast to a common floating dtype (at least float32) before
    ``torch.corrcoef``. This allows integer or bool labels (for example 0/1 accuracies) and
    low-precision tensors to be correlated with float scores.

    Returns:
        Detached 0-dim tensor. It is NaN, on ``x``'s device, when fewer than two entries are
        selected, because ``torch.corrcoef`` cannot be evaluated there. ``torch.corrcoef``
        itself also yields NaN when either side is constant.
    """
    x_masked = x[mask]
    y_masked = y[mask]
    if x_masked.numel() < 2:
        return torch.full((), float("nan"), device=x.device)

    common = torch.promote_types(x_masked.dtype, y_masked.dtype)
    dtype = torch.promote_types(common, torch.float32)
    stacked = torch.stack((x_masked.to(dtype), y_masked.to(dtype)))
    return torch.corrcoef(stacked)[0, 1].detach()


def calculate_tok_metric(token_level_reward, response_mask, accs, metric):
    """Compute running token-reward statistics and log their correlation with accuracy.

    For each row, over the valid (``response_mask``) tokens:

    * running count: number of valid tokens up to and including each position (1 at padding);
    * running sum: cumulative sum of masked rewards;
    * running mean: running sum / running count at valid positions, 0 elsewhere;
    * sigmoid of the running mean.

    The masked correlation of each of ``token_level_reward``, running sum, running mean and
    sigmoid(running mean) with ``accs`` is stored in ``metric``.

    Args:
        token_level_reward: per-token rewards.
        response_mask: bool mask of valid tokens (it is inverted with ``~``).
        accs: per-token accuracy labels, same shape as ``token_level_reward``.
        metric: dict updated in place and also returned.

    Returns:
        ``(cnt, tok_cumsum, tok_cumsummean, tok_sigmoidcumsummean, metric)``.
    """
    valid = response_mask
    weights = valid.float()

    running_count = weights.cumsum(dim=-1)
    running_count[~valid] = 1  # harmless divisor at padding; those ratios are discarded
    running_sum = (token_level_reward * weights).cumsum(dim=-1)

    running_mean = torch.zeros_like(running_sum)
    running_mean[valid] = (running_sum / running_count)[valid]
    squashed_mean = running_mean.sigmoid()

    named_series = (
        ("tok_logpratio_cor", token_level_reward),
        ("tok_cumsum_cor", running_sum),
        ("tok_cumsummean_cor", running_mean),
        ("tok_sigmoidcumsummean_cor", squashed_mean),
    )
    for key, series in named_series:
        metric[key] = masked_corrcoef(series, accs, valid)

    return running_count, running_sum, running_mean, squashed_mean, metric


def calculate_seq_metric(token_level_reward, response_mask, acc, metric):
    """Aggregate token rewards per sequence and log their correlation with accuracy.

    Per-sequence sum and mean are taken over ``response_mask`` with ``verl_F.masked_sum`` /
    ``verl_F.masked_mean``. Their correlations with ``acc`` are computed over sequences that
    have at least one valid token and stored in ``metric`` as ``seq_sum_cor`` /
    ``seq_mean_cor``.

    Returns:
        ``(seq_reward_sum, seq_reward_mean, metric)``; ``metric`` is updated in place.
    """
    per_seq_sum = verl_F.masked_sum(token_level_reward, response_mask, axis=-1)
    per_seq_mean = verl_F.masked_mean(token_level_reward, response_mask, axis=-1)
    has_tokens = response_mask.sum(-1) > 0

    for key, series in (("seq_sum_cor", per_seq_sum), ("seq_mean_cor", per_seq_mean)):
        metric[key] = masked_corrcoef(series, acc, has_tokens)

    return per_seq_sum, per_seq_mean, metric

def normal_dist_adv_qs_diff(original_distribution_level_adv_values, distribution_level_mask, token_level_reward, response_mask, beta, accs, metric):
    """Scale distribution-level advantages by the spread of the taken-path prefix reward.

    The taken-path reward at each step is read from top-k slot 0 of
    ``original_distribution_level_adv_values``. Its prefix sums over valid positions give
    ``mean`` and ``std``. With ``baseline`` being the prefix sum up to the previous step, the
    result for every top-k slot is ``(baseline + A - mean) / std - (baseline - mean) / std``,
    which equals ``A / std`` up to floating-point rounding.

    Degenerate batches follow ``normal_dist_adv_qs_diff_minus_baseline``:
    * with fewer than two valid positions the std falls back to 1 (and the mean to 0 when
      there are none);
    * the std is floored at 1e-12.
    This keeps the output finite.

    Only ``original_distribution_level_adv_values`` and ``response_mask`` (bool) are used. The
    remaining parameters exist for signature parity with the other normalisers.

    Returns:
        Tensor shaped like ``original_distribution_level_adv_values``.
    """
    eps = 1e-12

    taken_reward = original_distribution_level_adv_values[:, :, 0].float()
    prefix_sum = torch.cumsum(taken_reward, dim=-1)
    prefix_sum[~response_mask] = 0

    # Normalisation statistics over valid positions, with finite fallbacks.
    valid_prefix = prefix_sum[response_mask]
    mean = valid_prefix.mean() if valid_prefix.numel() > 0 else prefix_sum.new_zeros(())
    std = valid_prefix.std() if valid_prefix.numel() > 1 else prefix_sum.new_ones(())
    std = std.clamp_min(eps)

    # Baseline: prefix sum up to the previous step (0 at the first step and at padding).
    baseline = torch.roll(prefix_sum, 1, dims=-1)
    baseline[~response_mask] = 0
    baseline[:, 0] = 0  # also discards the value roll() wrapped around from the end

    distribution_level_q = baseline.unsqueeze(-1) + original_distribution_level_adv_values
    return ((distribution_level_q - mean) / std) - ((baseline - mean) / std).unsqueeze(-1)


def normal_dist_adv_qs_diff_minus_baseline(
    original_distribution_level_adv_values,  # [B, T, K] = A^{TD}_phi(y'_t | y_{<t})
    distribution_level_mask,                 # [B, T, K] bool
    token_level_reward,                      # [B, T]    = r_phi(y_t) for the *actually taken* token
    response_mask,                           # [B, T]    bool
    old_log_prob_topk_values,                # [B, T, K] = log pi_old(y'_t | y_{<t})
):
    """
    Returns the scaled advantage minus its token-wise baseline:
    (A / std_q) - E_{y'~pi_old}[A / std_q]
    where std_q is the batch-wide std of the prefix sums of ``token_level_reward`` along the
    taken path (valid positions only), falling back to 1 when fewer than two positions are
    valid and floored at 1e-12. The expectation uses the old probabilities of the masked
    top-k candidates, renormalised over that subset, and the result is zeroed outside
    ``distribution_level_mask``.
    """
    tiny = 1e-12
    _batch, _seqlen, _topk = original_distribution_level_adv_values.shape  # requires a 3-D input
    keep = distribution_level_mask.float()

    # Prefix sums Q_t = sum_{i<=t} r_phi(y_i) along the taken path; only their spread is used.
    prefix_q = torch.cumsum(token_level_reward.float(), dim=-1)
    prefix_q = torch.where(response_mask, prefix_q, torch.zeros_like(prefix_q))

    taken_path_q = prefix_q[response_mask]
    if taken_path_q.numel() > 1:
        scale = taken_path_q.std()
    else:
        scale = prefix_q.new_tensor(1.0)
    scale = scale.clamp_min(tiny)

    adv_scaled = original_distribution_level_adv_values / scale

    # Probability-weighted baseline over the selected candidates at each position.
    probs = old_log_prob_topk_values.exp() * keep
    norm = probs.sum(dim=-1, keepdim=True).clamp_min(tiny)
    expected_adv = (adv_scaled * probs).sum(dim=-1, keepdim=True) / norm

    return (adv_scaled - expected_adv) * keep


def normal_dist_adv_orignial(original_distribution_level_adv_values, distribution_level_mask, token_level_reward, response_mask, beta, accs, metric):
    """Whiten distribution-level advantages under the mask, then clip them to the 1% / 99%
    quantiles of the masked entries.

    Clipping is applied to every element of the whitened tensor, not only the masked ones.
    Only ``original_distribution_level_adv_values`` and ``distribution_level_mask`` are used;
    the other parameters exist for signature parity with the sibling normalisers.
    """
    whitened = verl_F.masked_whiten(original_distribution_level_adv_values, distribution_level_mask)

    selected = whitened[distribution_level_mask]
    hi = torch.quantile(selected, 0.99)
    lo = torch.quantile(selected, 0.01)

    # Cap at the upper quantile first, then floor at the lower one (in place).
    whitened[whitened > hi] = hi
    whitened[whitened < lo] = lo
    return whitened

def masked_grpo(reward_tensor_original, mask_tensor, n_samples):
    """
    GRPO: group-relative baseline (mean only, no std normalisation).

    Rows are split into consecutive groups of ``n_samples`` (the rollouts of one prompt). For
    each group the baseline is the mean, over samples, of each sample's mean valid reward.
    Every valid reward then has that baseline subtracted; invalid positions are set to 0.

    Args:
        reward_tensor_original: [batch_size, ...] reward tensor (left unmodified).
        mask_tensor: [batch_size, ...] bool mask of valid positions.
        n_samples: samples per group (rollouts of the same prompt). A shorter trailing group
            is handled with its actual size.

    Returns:
        Baseline-subtracted reward tensor. Samples with no valid positions are excluded from
        their group's baseline (instead of turning it into NaN). A group with no valid
        positions at all is left at 0.
    """
    reward_tensor = reward_tensor_original.clone()
    reward_tensor[~mask_tensor] = 0

    for start_pos in range(0, reward_tensor.shape[0], n_samples):
        group_rewards = reward_tensor[start_pos : start_pos + n_samples]
        group_mask = mask_tensor[start_pos : start_pos + n_samples]

        # Valid rewards of every sample actually present in this (possibly partial) group.
        valid_rewards = [rewards[mask] for rewards, mask in zip(group_rewards, group_mask)]
        sample_means = [r.mean(dim=0, keepdim=True) for r in valid_rewards if r.shape[0] > 0]
        if not sample_means:
            continue

        # Baseline = average of the per-sample means within the group.
        group_baseline = torch.cat(sample_means, dim=0).mean(dim=0, keepdim=True)

        # Advantage = r - baseline (a no-op for samples without valid positions).
        for i, rewards in enumerate(valid_rewards):
            reward_tensor[start_pos + i][group_mask[i]] = rewards - group_baseline

    return reward_tensor

def masked_rloo(reward_tensor_original, mask_tensor, n_samples):
    """RLOO (leave-one-out) baseline over consecutive groups of ``n_samples`` rows.

    Invalid positions are zeroed. Let ``S`` be the sum, over the group's samples, of each
    sample's mean valid reward. Every valid reward ``r`` in the group is replaced by
    ``r * n / (n - 1) - S / (n - 1)``. This equals ``r`` minus the mean of the other samples
    when each sample has a single valid position (as in PRIME's reference implementation).

    Args:
        reward_tensor_original: reward tensor (left unmodified).
        mask_tensor: bool mask of valid positions, same leading shape.
        n_samples: group size; must be > 1.

    Returns:
        New reward tensor with the baseline applied.
    """
    out = reward_tensor_original.clone()
    out[~mask_tensor] = 0

    for group_start in range(0, out.shape[0], n_samples):
        group_end = group_start + n_samples
        per_sample_means = torch.cat(
            [
                out[row : row + 1][mask_tensor[row : row + 1]].mean(dim=0, keepdim=True)
                for row in range(group_start, group_end)
            ],
            dim=0,
        )
        baseline = per_sample_means.sum() / (n_samples - 1)

        group_view = out[group_start:group_end]
        group_mask = mask_tensor[group_start:group_end]
        group_view[group_mask] = group_view[group_mask] * (n_samples / (n_samples - 1)) - baseline

    return out

def _minibatch_absmax_normalize(token_level_scores, scale_source, mini_batch, num_mini_batch):
    """Divide each block of ``mini_batch`` consecutive rows by that block's max ``|scale_source|``.

    Args:
        token_level_scores: [batch_size, response_length] scores to rescale.
        scale_source: tensor of the same shape; its absolute maximum within each block
            (plus 1e-6) is that block's divisor.
        mini_batch: rows per block.
        num_mini_batch: number of blocks; ``mini_batch * num_mini_batch`` must equal batch_size.

    Returns:
        Rescaled [batch_size, response_length] tensor.
    """
    batch_size, response_length = token_level_scores.shape
    block_max = scale_source.view(num_mini_batch, mini_batch * response_length).abs().max(-1).values
    scores_split = token_level_scores.view(num_mini_batch, mini_batch, response_length)
    scores_split = scores_split / (block_max.view(num_mini_batch, 1, 1) + 1e-6)
    return scores_split.view(batch_size, response_length)


def compute_advantage_return(data: DataProto, eos_mask: torch.Tensor, n_samples, config):
    """Compute advantages, returns and the normalised distribution-level advantage.

    Reward sources (each optional; their returns are summed):

    * Token-level implicit reward (``config.algorithm.reward_dpo_coef > 0``):
      - taken from ``rm_scores`` on every valid response token except the last;
      - optionally mini-batch normalised (``reward_model.prime_norm == "batch_norm"``);
      - optionally RLOO-baselined (``token_level_std == "rloo"`` and ``n_samples > 1``);
      - turned into returns according to ``reward_model.seq_agg``;
      - scaled once by ``reward_dpo_coef``.
    * Outcome reward (``config.algorithm.reward_gt_coef > 0``):
      - ``acc`` placed on the last valid token;
      - RLOO-baselined when ``n_samples > 1``;
      - scaled by ``reward_gt_coef`` and converted to suffix sums.

    With ``seq_agg == "msa"`` the token-level and distribution-level scores are derived from
    ``old_log_prob`` / ``ref_log_prob`` (and their top-k variants) scaled by ``beta_train``.
    Otherwise they are read precomputed from ``data.batch``.

    When ``actor_rollout_ref.actor.distribution_level_coef > 0`` the distribution-level values
    are normalised with the method named by ``config.distribution_level_std``.

    Args:
        data: batch; reads ``response_mask``, ``acc`` and the score tensors named above.
        eos_mask: [B, T] mask of valid response tokens (not modified).
        n_samples: rollouts per prompt; consecutive groups of this size share a prompt.
        config: full trainer config.

    Returns:
        ``(advantages, returns, distribution_level_adv_values_normed, distribution_level_mask,
        metric, rm_scores, distribution_level_adv_values)``, where ``advantages`` are the
        ``returns`` whitened over ``eos_mask``.
    """
    seq_agg = config.reward_model.seq_agg

    if seq_agg == "msa":
        beta = config.reward_model.model.beta_train
        ref_log_labels = data.batch["ref_log_prob"]
        ref_log_prob_topk_values = data.batch["ref_log_prob_topk_values"]
        rm_log_labels = data.batch["old_log_prob"]
        rm_log_prob_topk_values = data.batch["old_log_prob_topk_values"]
        rm_scores = (rm_log_labels - ref_log_labels) * beta
        distribution_level_adv_values = (rm_log_prob_topk_values - ref_log_prob_topk_values) * beta
        # distribution_level_minp is a probability threshold (get_distritbution_level_mask also
        # compares it against exp(log-prob)). Comparing raw log-probs (<= 0) with a non-negative
        # threshold would leave the mask empty.
        distribution_level_mask = eos_mask.bool().unsqueeze(-1) & (
            torch.exp(rm_log_prob_topk_values) > config.distribution_level_minp
        )

        # Zero the scores from the end of each response onwards.
        max_positions = data.batch["response_mask"].sum(-1)
        for i in range(data.batch["input_ids"].shape[0]):
            rm_scores[i, max_positions[i]:] = 0
            distribution_level_adv_values[i, max_positions[i]:] = 0
    else:
        rm_scores = data.batch["rm_scores"]
        distribution_level_adv_values = data.batch["distribution_level_adv_values"]
        distribution_level_mask = data.batch["distribution_level_mask"]

    reward_tensors = []
    metric = {}
    batch_size, response_length = data.batch["response_mask"].shape
    response_mask = data.batch["response_mask"].bool()

    # Returned unchanged unless the distribution-level branch below normalises it.
    distribution_level_adv_values_normed = distribution_level_adv_values
    accs = data.batch["acc"].unsqueeze(1).repeat(1, response_length)

    with torch.no_grad():
        # Logging: correlation of sequence-level aggregated rewards with accuracy.
        seq_reward_sum, seq_reward_mean, metric = calculate_seq_metric(
            token_level_reward=rm_scores,
            response_mask=response_mask,
            acc=data.batch["acc"],
            metric=metric,
        )

        valid_response_length = eos_mask.sum(-1)
        rows = torch.arange(
            0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device
        )

        # Distribution-level advantage normalisation.
        if config.actor_rollout_ref.actor.distribution_level_coef > 0:
            # Logging: correlation of token-level statistics with accuracy.
            cnt, tok_cumsum, tok_cumsummean, tok_sigmoidcumsummean, metric = calculate_tok_metric(
                token_level_reward=distribution_level_adv_values[:, :, 0],
                response_mask=response_mask,
                accs=accs,
                metric=metric,
            )

            if config.distribution_level_std == "qs_diff":
                distribution_level_adv_values_normed = normal_dist_adv_qs_diff(
                    distribution_level_adv_values,
                    distribution_level_mask.bool(),
                    rm_scores,
                    response_mask,
                    config.reward_model.model.beta_train,
                    accs, metric)
            elif config.distribution_level_std == "qs_diff_minus_baseline":
                distribution_level_adv_values_normed = normal_dist_adv_qs_diff_minus_baseline(
                    distribution_level_adv_values,
                    distribution_level_mask.bool(),
                    rm_scores,
                    response_mask,
                    data.batch['old_log_prob_topk_values'])
            elif config.distribution_level_std == "orignial":
                distribution_level_adv_values_normed = normal_dist_adv_orignial(
                    distribution_level_adv_values,
                    distribution_level_mask.bool(),
                    rm_scores,
                    response_mask,
                    config.reward_model.model.beta_train,
                    accs, metric)
            else:
                raise NotImplementedError

        # Token-level reward.
        if config.algorithm.reward_dpo_coef > 0:
            # 1. Reward mask: valid response tokens minus the final (<eos>) one.
            #    .clone(): eos_mask.bool() is eos_mask itself when it is already bool, and this
            #    in-place edit must not leak into the eos_mask used further down.
            reward_mask = eos_mask.bool().clone()
            reward_mask[rows, valid_response_length - 1] = False

            # 2. Optional per-mini-batch scale normalisation.
            token_level_scores = rm_scores  # used as-is unless prime_norm == "batch_norm"
            if config.reward_model.prime_norm == "batch_norm":
                mini_batch = batch_size // config.trainer.n_gpus_per_node
                num_mini_batch = batch_size // mini_batch
                if seq_agg == "sum":
                    scale_source = suffix_sum_with_mask(rm_scores, reward_mask)
                elif seq_agg == "mean":
                    scale_source = suffix_mean_with_mask(rm_scores, reward_mask)
                elif seq_agg == "msa":
                    scale_source = torch.cumsum(rm_scores, dim=1)
                else:
                    raise NotImplementedError
                token_level_scores = _minibatch_absmax_normalize(
                    rm_scores, scale_source, mini_batch, num_mini_batch)

            if config.token_level_std == "rloo" and n_samples > 1:
                # reward_dpo_coef is applied once, when `returns` is appended below.
                reward_tensor = masked_rloo(token_level_scores, reward_mask, n_samples)
            else:
                reward_tensor = torch.where(reward_mask, token_level_scores, torch.zeros_like(token_level_scores))

            # 3. Aggregate rewards into returns.
            if seq_agg == "mean":
                returns = suffix_mean_with_mask(reward_tensor, eos_mask=reward_mask)
            elif seq_agg == "sum":
                returns = suffix_sum_with_mask(reward_tensor, eos_mask=reward_mask)
            elif seq_agg == "gt_minus_value":
                baselines = torch.sigmoid(exclude_current_prefix_mean(reward_tensor))
                returns = data.batch["acc"].unsqueeze(-1) - baselines
            elif seq_agg == "msa":
                returns = torch.zeros_like(reward_tensor, device=reward_tensor.device)
                for start_pos in range(0, reward_tensor.shape[0], n_samples):
                    group = slice(start_pos, start_pos + n_samples)
                    mask_prompt = eos_mask[group]
                    # Prefix-sum returns, baselined by their masked mean over the prompt's group.
                    cpr_prompt = torch.cumsum(reward_tensor[group], dim=1)
                    cpr_prompt[~mask_prompt.bool()] = 0
                    baseline_prompt = verl_F.masked_mean(cpr_prompt, mask_prompt)
                    returns[group] = cpr_prompt - baseline_prompt
                    returns[group][~mask_prompt.bool()] = 0
            else:
                raise NotImplementedError

            reward_tensors.append(returns * config.algorithm.reward_dpo_coef)

            selected = reward_mask.bool()
            accs = data.batch["acc"].unsqueeze(1).repeat(1, response_length)[selected]
            metric["reward_cor"] = torch.corrcoef(torch.concat((reward_tensor[selected].unsqueeze(0), accs.unsqueeze(0)), dim=0))[0, 1]
            metric["return_cor"] = torch.corrcoef(torch.concat((returns[selected].unsqueeze(0), accs.unsqueeze(0)), dim=0))[0, 1]

        # Outcome (ground-truth accuracy) reward on the last valid token.
        if config.algorithm.reward_gt_coef > 0.0:
            reward_tensor = torch.zeros_like(eos_mask, dtype=torch.float32)
            reward_mask = torch.zeros_like(eos_mask, dtype=torch.bool)
            reward_mask[rows, valid_response_length - 1] = True
            reward_tensor[rows, valid_response_length - 1] = data.batch['acc']
            # RLOO divides by n_samples - 1, so it needs at least two samples per prompt
            # (mirrors the n_samples > 1 guard of the token-level branch).
            if n_samples > 1:
                reward_tensor = masked_rloo(reward_tensor, reward_mask, n_samples=n_samples)
            reward_tensor = reward_tensor * config.algorithm.reward_gt_coef
            reward_tensors.append((reward_tensor * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]))

        if reward_tensors:
            returns = sum(reward_tensors)
        else:
            # No reward source enabled: all-zero returns rather than failing on sum([]) == 0.
            returns = torch.zeros_like(eos_mask, dtype=torch.float32)
        advantages = verl_F.masked_whiten(returns.clone(), eos_mask)

        return advantages, returns, distribution_level_adv_values_normed, distribution_level_mask, metric, rm_scores, distribution_level_adv_values

def compute_adpa_policy_loss(
        old_log_prob,
        log_prob,
        advantages,
        eos_mask,
        distribution_level_mask,
        cliprange,
        clip_ratio_low,
        clip_ratio_high,
        policy_log_prob_topk_values,
        distribution_level_adv_values,
        old_log_prob_topk_values,
        clip_ratio_c=3.0):
    """Dual-clip PPO loss on sampled tokens plus a distribution-level PPO loss on top-k candidates.

    Token level (averaged over ``eos_mask``):
      - clipped surrogate with bounds ``[1 - clip_ratio_low, 1 + clip_ratio_high]``, each
        defaulting to ``cliprange``;
      - dual clipping at ``clip_ratio_c`` applied only to negative advantages.

    Distribution level (entries selected by ``distribution_level_mask``):
      - the same dual-clip surrogate on every selected candidate, using the symmetric
        ``cliprange``;
      - averaged with weights equal to the candidates' old probabilities. An empty
        selection yields a loss of 0 rather than NaN.

    Returns:
        ``(pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower, topk_pg_loss, topk_pg_clipfrac,
        topk_ppo_kl, topk_pg_clipfrac_lower)``.
    """
    assert clip_ratio_c > 1.0, f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}."

    # Token-level dual-clip PPO.
    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
    pg_losses = -advantages * ratio

    if clip_ratio_low is None:
        clip_ratio_low = cliprange
    if clip_ratio_high is None:
        clip_ratio_high = cliprange
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high)
    clip_pg_losses1 = torch.max(pg_losses, pg_losses2)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = verl_F.masked_mean(torch.gt(clip_pg_losses2, pg_losses3) * (advantages < 0).float(), eos_mask)
    # The dual clip only applies when the advantage is negative.
    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

    pg_loss = verl_F.masked_mean(pg_losses, eos_mask)

    # Boolean mask: it is used for boolean indexing below (an integer mask would gather indices).
    mask_adv = distribution_level_mask.bool()

    # Step 1: distribution-level KL, mean of p_old * (log p_old - log p_new) over selected entries.
    topk_ppo_kl = verl_F.masked_mean(
        torch.exp(old_log_prob_topk_values) * (old_log_prob_topk_values - policy_log_prob_topk_values), mask_adv)

    # Step 2: distribution-level dual-clip policy loss.
    topk_prob_ratio = torch.exp(policy_log_prob_topk_values - old_log_prob_topk_values)
    topk_pg_losses = -distribution_level_adv_values * topk_prob_ratio
    # NOTE(review): candidates are clipped with the symmetric `cliprange`, not clip_ratio_low/high.
    topk_pg_losses2 = -distribution_level_adv_values * torch.clamp(topk_prob_ratio, 1.0 - cliprange, 1.0 + cliprange)
    topk_clip_pg_losses1 = torch.max(topk_pg_losses, topk_pg_losses2)
    topk_pg_clipfrac = verl_F.masked_mean(torch.gt(topk_pg_losses2, topk_pg_losses).float(), mask_adv)
    topk_pg_losses3 = -distribution_level_adv_values * clip_ratio_c
    topk_clip_pg_losses2 = torch.min(topk_pg_losses3, topk_clip_pg_losses1)
    topk_pg_clipfrac_lower = verl_F.masked_mean(torch.gt(topk_clip_pg_losses2, topk_pg_losses3) * (distribution_level_adv_values < 0).float(), mask_adv)
    topk_pg_losses = torch.where(distribution_level_adv_values < 0, topk_clip_pg_losses2, topk_clip_pg_losses1)

    # Expectation over the selected candidates, weighted by their old probabilities.
    old_prob = torch.exp(old_log_prob_topk_values[mask_adv])
    weight_sum = old_prob.sum()
    # With an empty selection both sums are 0; divide by 1 so the loss is 0 instead of NaN.
    safe_weight_sum = torch.where(weight_sum > 0, weight_sum, torch.ones_like(weight_sum))
    topk_pg_loss = (topk_pg_losses[mask_adv] * old_prob).sum() / safe_weight_sum
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower, topk_pg_loss, topk_pg_clipfrac, topk_ppo_kl, topk_pg_clipfrac_lower


def suffix_sum_with_mask(score: torch.Tensor, eos_mask: torch.Tensor):
    """Sum of masked ``score`` from each position to the end of its row.

    Args:
        score: [B, L] values.
        eos_mask: [B, L]; 1 for valid tokens, 0 for padding (padding contributes 0).

    Returns:
        [B, L] tensor whose entry ``[b, t]`` is ``sum(score[b, t:] * eos_mask[b, t:])``.
    """
    assert score.shape == eos_mask.shape

    masked = score * eos_mask
    # Reverse, accumulate, reverse back: a cumulative sum from the right.
    return masked.flip(dims=[1]).cumsum(dim=1).flip(dims=[1])

def exclude_current_prefix_mean(rewards: torch.Tensor) -> torch.Tensor:
    """
    Mean of the strictly preceding rewards in each row:
    ``baseline[i][j] = mean(rewards[i][0:j])`` for ``j >= 1``, excluding position ``j`` itself,
    and ``baseline[i][0] = 0`` because nothing precedes the first position.

    Args:
        rewards (torch.Tensor): [batch, seqlen] token-level rewards.

    Returns:
        torch.Tensor: [batch, seqlen] baseline. Floating-point inputs keep their dtype.

    Raises:
        ValueError: if ``rewards`` is not 2-D.
    """
    if rewards.dim() != 2:
        raise ValueError("rewards 必须是二维 [batch, seqlen] 张量")

    # Sum of rewards[i][0:j]: shift the inclusive prefix sum right by one position.
    cumsum = torch.cumsum(rewards, dim=1)
    sum_excl_current = torch.roll(cumsum, shifts=1, dims=1)
    sum_excl_current[:, 0] = 0  # drop the value roll() wrapped around; no predecessors here

    # Position j (0-based) has exactly j predecessors. The divisor is built in at least
    # float32 because bf16/fp16 cannot represent integers above 256/2048 exactly.
    seqlen = rewards.size(1)
    if rewards.is_floating_point():
        denom_dtype = torch.promote_types(rewards.dtype, torch.float32)
    else:
        denom_dtype = rewards.dtype
    denom = torch.arange(seqlen, device=rewards.device, dtype=denom_dtype)
    denom[0] = 1  # avoid 0/0 in the first column; it is overwritten below

    baseline = sum_excl_current / denom
    if rewards.is_floating_point():
        baseline = baseline.to(rewards.dtype)
    baseline[:, 0] = 0  # first column: no preceding rewards
    return baseline