# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import verl
import verl.utils.torch_functional as verl_F
from verl.trainer.ppo.core_algos import compute_value_model_metrics

def compute_prime_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config, dpo_acc=0.5, ):
    """Use the PRIME reward-model scores as a value model and run GAE on top of it.

    The per-token ``rm_scores`` are accumulated into an exclusive prefix sum,
    shifted by a baseline (DPO loss) or a Q0 estimate (other losses), squashed
    with a sigmoid and used as the value estimate in a GAE recursion with
    ``gamma = 1`` and ``lam = config.algorithm.lam``. The only reward is
    ``acc``, written at the last valid response token of each sample.

    Side effects:
        * ``data.batch['rm_scores']`` is zeroed in place where ``eos_mask == 0``.
        * With a non-DPO loss and ``config.algorithm.q0_estimator == 'soft'``,
          ``data.batch['acc']`` is overwritten in place with soft targets
          (``sigmoid(+-M)``); those soft targets are then also used as the
          terminal reward and passed to the metrics helper.

    Args:
        data: batch providing ``prompts``, ``attention_mask``, ``rm_scores``
            (indexed ``[sample, response_token]``) and ``acc`` (indexed per sample).
        eos_mask: response mask; positions equal to 0 are treated as padding.
        n_samples: not used by this estimator.
        config: reads ``algorithm.lam``, ``reward_model.model.loss_type`` and,
            for non-DPO losses, ``reward_model.model.beta_test`` and
            ``algorithm.q0_estimator`` (``'soft'``, ``'mcts'`` or ``'none'``;
            any other value behaves like ``'none'``).
        dpo_acc: not used by this estimator.

    Returns:
        ``(advantages, returns, metrics)`` where ``advantages`` is the GAE
        advantage after ``verl_F.masked_whiten``, ``returns`` is the
        un-whitened advantage plus the value estimate, and ``metrics`` is the
        result of ``compute_value_model_metrics``.
    """
    prompt_ids = data.batch['prompts']
    prompt_length = prompt_ids.shape[-1]
    valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)
    gamma = 1
    lam = config.algorithm.lam

    with torch.no_grad():
        assert 'rm_scores' in data.batch.keys() and 'acc' in data.batch.keys()
        # NOTE: no clone here, so the masking below also modifies data.batch['rm_scores'].
        q_tensor = data.batch['rm_scores']
        q_tensor[eos_mask == 0] = 0
        V_last = q_tensor.sum(dim=-1)

        # Exclusive prefix sum: column t holds the sum of scores strictly before t,
        # column 0 is 0. It is built from a separate buffer because assigning
        # Q[:, 1:] = Q[:, :-1] on the same tensor copies between overlapping
        # memory, which PyTorch does not define.
        prefix_sum = q_tensor.cumsum(dim=-1)
        Q_tensor = torch.zeros_like(prefix_sum)
        Q_tensor[:, 1:] = prefix_sum[:, :-1]

        # Earlier per-prompt-group experiments (partition = acc - V_last,
        # min-max normalisation of Q, a clamped per-group sigmoid baseline) were
        # dropped in favour of the batch-level handling below. Using the overall
        # accuracy to find the baseline is a biased estimate.
        if config.reward_model.model.loss_type == 'dpo':
            avg_score = V_last.mean()
            avg_reward = data.batch['acc'].mean()
            # NOTE: avg_reward of exactly 0 or 1 makes the log term +inf / -inf,
            # which saturates the sigmoid below to 0 / 1.
            baseline_score = avg_score + torch.log((1 - avg_reward) / avg_reward)
            Q_tensor = torch.sigmoid(Q_tensor - baseline_score)
        else:
            # Estimate Q_0 from the accuracy signal. The log ratio tends to go
            # negative easily and positive rarely, which even a CE loss hardly
            # overcomes, so the estimate is pinned to a soft bound.
            # The reward must not be exactly 0/1, otherwise the inverse sigmoid
            # is undefined. If warping is needed, change the CE loss as well.
            soft_bound = 0.99
            M = -torch.log(1 / torch.tensor(soft_bound, device=Q_tensor.device) - 1)
            prepend_q0 = torch.zeros_like(data.batch['acc'])
            beta = config.reward_model.model.beta_test

            if config.algorithm.q0_estimator == 'soft':
                for i in range(0, Q_tensor.shape[0]):
                    # Lock the final value to +-M according to the outcome and
                    # derive Q_0 backwards from it.
                    V_target = M * (2 * data.batch['acc'][i] - 1)
                    prepend_q0[i] = V_target - V_last[i]
                    # The outcome label is replaced by its soft version in place,
                    # so no small negative advantage is kept for samples the
                    # policy already gets right with high probability.
                    data.batch['acc'][i] = torch.sigmoid(V_target)
            elif config.algorithm.q0_estimator == 'mcts':
                # Guess a constant q0; in practice this behaves much like using
                # a large constant as the baseline.
                if beta > 0:
                    prepend_q0[:] = M
                else:
                    prepend_q0[:] = -M
            elif config.algorithm.q0_estimator == 'none':
                pass

            Q_tensor += prepend_q0.unsqueeze(1)
            Q_tensor = torch.sigmoid(Q_tensor)

        # From here on Q_tensor plays the role of V: V(t) = Q(t-1), with V_0
        # being the partition / Q0 term. Padding carries no value.
        Q_tensor[eos_mask == 0] = 0

        # Outcome reward on the last valid response token only.
        # assumes every sample has at least one valid response token; a length of
        # 0 would index column -1 -- TODO confirm against caller
        token_level_rewards = torch.zeros_like(q_tensor)
        token_level_rewards[
            torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
            valid_response_length - 1] = data.batch['acc']

        # GAE: delta_t = r_t + gamma * V_{t+1} - V_t,
        #      A_t = delta_t + gamma * lam * A_{t+1}.
        lastgaelam = 0
        advantages_reversed = []
        gen_len = q_tensor.shape[1]

        for t in reversed(range(gen_len)):
            nextvalues = Q_tensor[:, t + 1] if t < gen_len - 1 else 0.0
            delta = token_level_rewards[:, t] + gamma * nextvalues - Q_tensor[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)

        returns = advantages + Q_tensor
        advantages = verl_F.masked_whiten(advantages, eos_mask)

        metrics = compute_value_model_metrics(Q_tensor, eos_mask, data.batch['acc'], returns)

    return advantages, returns, metrics

def compute_prime_value_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config):
    """Combine a GAE advantage built on RLOO-centred PRIME scores with the raw ``acc`` reward.

    Steps:
        1. Copy ``rm_scores``, zero the padding and apply a leave-one-out
           baseline over each group of ``n_samples`` consecutive rows.
        2. Turn the centred scores into a value estimate: exclusive prefix sum
           plus ``acc - V_last`` per sample, zeroed on padding.
        3. Run GAE (``gamma = 1``, ``lam = config.algorithm.lam``) with ``acc``
           as the reward on the last valid response token.
        4. Mix the GAE advantage (weight ``config.algorithm.reward_dpo_coef``)
           with ``eos_mask * acc`` (weight ``config.algorithm.reward_gt_coef``)
           and whiten the sum with ``verl_F.masked_whiten``.

    Args:
        data: batch providing ``prompts``, ``attention_mask``, ``rm_scores``
            and ``acc``. ``rm_scores`` is cloned, so it is not modified here.
        eos_mask: response mask; positions equal to 0 are treated as padding.
        n_samples: number of consecutive rows forming one prompt group.
        config: reads ``algorithm.lam``, ``algorithm.reward_dpo_coef`` and
            ``algorithm.reward_gt_coef``.

    Returns:
        ``(advantages, returns, metrics)``; ``returns`` is the un-mixed GAE
        advantage plus the value estimate, ``metrics`` comes from
        ``compute_value_model_metrics``.
    """
    prompt_ids = data.batch['prompts']
    prompt_length = prompt_ids.shape[-1]
    valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)
    gamma = 1
    lam = config.algorithm.lam

    with torch.no_grad():
        assert 'rm_scores' in data.batch.keys() and 'acc' in data.batch.keys()
        q_tensor = data.batch['rm_scores'].clone()
        q_tensor[eos_mask == 0] = 0

        def masked_rloo(reward_tensor_original, mask_tensor):
            """Leave-one-out centring of the masked entries, one group of ``n_samples`` rows at a time."""
            reward_tensor = reward_tensor_original.clone()
            reward_tensor[~mask_tensor] = 0
            for start_pos in range(0, reward_tensor.shape[0], n_samples):
                # Mean of the masked entries of every row in the group.
                cur_rewards_mean = torch.cat([
                    reward_tensor[pos:pos + 1][mask_tensor[pos:pos + 1]].mean(dim=0, keepdim=True)
                    for pos in range(start_pos, start_pos + n_samples)
                ],
                    dim=0)
                cur_rewards_sum = cur_rewards_mean.sum()
                cur_reward_baseline = cur_rewards_sum / (n_samples - 1)
                # The slice is a view, so the masked assignment writes through
                # to reward_tensor.
                reward_tensor[start_pos:start_pos + n_samples][
                    mask_tensor[start_pos:start_pos + n_samples]] = \
                    reward_tensor[start_pos:start_pos + n_samples][
                        mask_tensor[start_pos:start_pos + n_samples]] * (
                            n_samples / (n_samples - 1)) - cur_reward_baseline

            return reward_tensor

        # Apply RLOO to the process rewards.
        q_tensor = masked_rloo(q_tensor, eos_mask.bool())

        V_last = q_tensor.sum(dim=-1)
        # Exclusive prefix sum built from a separate buffer; shifting a tensor
        # onto itself (Q[:, 1:] = Q[:, :-1]) copies between overlapping memory,
        # which PyTorch does not define.
        prefix_sum = q_tensor.cumsum(dim=-1)
        Q_tensor = torch.zeros_like(prefix_sum)
        Q_tensor[:, 1:] = prefix_sum[:, :-1]

        # Set Q0 like PRIME: shift every row so that the value after the last
        # token equals acc.
        Q_tensor += (data.batch['acc'] - V_last).unsqueeze(-1)
        Q_tensor[eos_mask == 0] = 0

        # Outcome reward on the last valid response token only.
        token_level_rewards = torch.zeros_like(q_tensor)
        token_level_rewards[
            torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
            valid_response_length - 1] = data.batch['acc']

        # GAE: delta_t = r_t + gamma * V_{t+1} - V_t,
        #      A_t = delta_t + gamma * lam * A_{t+1}.
        lastgaelam = 0
        advantages_reversed = []
        gen_len = q_tensor.shape[1]

        for t in reversed(range(gen_len)):
            nextvalues = Q_tensor[:, t + 1] if t < gen_len - 1 else 0.0
            delta = token_level_rewards[:, t] + gamma * nextvalues - Q_tensor[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + Q_tensor

        # Advantage contribution of the plain outcome reward.
        advantages_acc = eos_mask * data.batch['acc'].unsqueeze(-1)

        # Combine both sources and whiten over the valid tokens.
        advantages = verl_F.masked_whiten(
            advantages * config.algorithm.reward_dpo_coef + advantages_acc * config.algorithm.reward_gt_coef,
            eos_mask)

        metrics = compute_value_model_metrics(Q_tensor, eos_mask, data.batch['acc'], returns)

    return advantages, returns, metrics

def compute_reasonable_prime_value_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config):
    """Value estimate from within-group win rates of the PRIME scores, followed by GAE.

    The reward model is assumed to be trained with a DPO loss. For every group
    of ``n_samples`` consecutive rows:

        * the masked ``rm_scores`` are centred by the group mean;
        * for each prefix value ``Q[s, t]`` the win rate is the mean over the
          group's final scores ``V_last[j]`` of ``sigmoid(Q[s, t] - V_last[j])``;
        * the value is ``g * w / (g * w + (1 - g) * (1 - w))`` with ``g`` the
          group mean of ``acc`` and ``w`` the win rate.

    GAE (``gamma = 1``, ``lam = config.algorithm.lam``) is then run with
    ``acc`` as the reward on the last valid token, and the result is mixed
    with ``eos_mask * acc`` using ``config.algorithm.reward_dpo_coef`` /
    ``reward_gt_coef`` before ``verl_F.masked_whiten``.

    Side effects:
        ``data.batch['rm_scores']`` is modified in place (padding zeroed and
        group mean subtracted).
        NOTE(review): the sibling ``compute_prime_value_advantage_return``
        clones ``rm_scores`` first; confirm whether the persistent centring
        here is intended.

    Args:
        data: batch providing ``prompts``, ``attention_mask``, ``rm_scores`` and ``acc``.
        eos_mask: response mask; positions equal to 0 are treated as padding.
        n_samples: number of consecutive rows forming one prompt group.
        config: reads ``algorithm.lam``, ``algorithm.reward_dpo_coef`` and
            ``algorithm.reward_gt_coef``.

    Returns:
        ``(advantages, returns, metrics)``.
    """
    prompt_ids = data.batch['prompts']
    prompt_length = prompt_ids.shape[-1]
    valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)
    gamma = 1
    lam = config.algorithm.lam

    with torch.no_grad():
        assert 'rm_scores' in data.batch.keys() and 'acc' in data.batch.keys()
        q_tensor = data.batch['rm_scores']
        q_tensor[eos_mask == 0] = 0

        # Centre the process rewards of every group on the mean of its valid
        # tokens. The subtraction also touches padding, hence the re-masking.
        for i in range(0, q_tensor.shape[0], n_samples):
            q_tensor_mean = q_tensor[i:i + n_samples, :][eos_mask[i:i + n_samples].bool()].mean()
            q_tensor[i:i + n_samples] -= q_tensor_mean
        q_tensor[eos_mask == 0] = 0

        V_last = q_tensor.sum(dim=-1)
        # Exclusive prefix sum built from a separate buffer; shifting a tensor
        # onto itself (Q[:, 1:] = Q[:, :-1]) copies between overlapping memory,
        # which PyTorch does not define.
        prefix_sum = q_tensor.cumsum(dim=-1)
        Q_tensor = torch.zeros_like(prefix_sum)
        Q_tensor[:, 1:] = prefix_sum[:, :-1]

        for i in range(0, Q_tensor.shape[0], n_samples):
            # Win rate of every prefix against the final scores of the group:
            # (sample, token, 1) - (1, 1, opponent) averaged over opponents.
            win_rate_all = torch.sigmoid(
                Q_tensor[i:i + n_samples].unsqueeze(-1)
                - V_last[i:i + n_samples].unsqueeze(0).unsqueeze(0)).mean(dim=-1)
            global_acc = data.batch['acc'][i:i + n_samples].mean()
            # An alternative, 2 * win_rate + global_acc - 1, can leave [0, 1];
            # the form below is used instead.
            # NOTE: 0 / 0 is possible when global_acc is 1 and a win rate is
            # exactly 0 (or global_acc is 0 and a win rate is exactly 1).
            estimate_acc = (global_acc * win_rate_all) / (
                global_acc * win_rate_all + (1 - global_acc) * (1 - win_rate_all))
            Q_tensor[i:i + n_samples] = estimate_acc

        Q_tensor[eos_mask == 0] = 0

        # Outcome reward on the last valid response token only.
        token_level_rewards = torch.zeros_like(q_tensor)
        token_level_rewards[
            torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
            valid_response_length - 1] = data.batch['acc']

        # GAE: delta_t = r_t + gamma * V_{t+1} - V_t,
        #      A_t = delta_t + gamma * lam * A_{t+1}.
        lastgaelam = 0
        advantages_reversed = []
        gen_len = q_tensor.shape[1]

        for t in reversed(range(gen_len)):
            nextvalues = Q_tensor[:, t + 1] if t < gen_len - 1 else 0.0
            delta = token_level_rewards[:, t] + gamma * nextvalues - Q_tensor[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + Q_tensor

        # Advantage contribution of the plain outcome reward.
        advantages_acc = eos_mask * data.batch['acc'].unsqueeze(-1)

        # Combine both sources and whiten over the valid tokens.
        advantages = verl_F.masked_whiten(
            advantages * config.algorithm.reward_dpo_coef + advantages_acc * config.algorithm.reward_gt_coef,
            eos_mask)

        metrics = compute_value_model_metrics(Q_tensor, eos_mask, data.batch['acc'], returns)

    return advantages, returns, metrics

def compute_middle_prime_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config, linear=False, trust_acc=True):
    """GAE with a per-group constant value derived from the group accuracy.

    For every group of ``n_samples`` consecutive rows the value logit is
    ``margin = beta * log(p / (1 - p))`` with ``p = mean(acc) / 2`` (0 when
    ``p == 0``), clamped to at most 0. It is mapped to a value with
    ``exp(margin / beta)`` when ``linear`` is true and ``2 * sigmoid(margin)``
    otherwise, then zeroed on padding.

    As written, the values of ``rm_scores`` do not enter the value estimate;
    only their shape, dtype and device are used.

    Side effects:
        ``data.batch['rm_scores']`` is zeroed in place where ``eos_mask == 0``.

    Args:
        data: batch providing ``prompts``, ``attention_mask``, ``rm_scores`` and ``acc``.
        eos_mask: response mask; positions equal to 0 are treated as padding.
        n_samples: number of consecutive rows forming one prompt group.
        config: reads ``algorithm.lam``, ``algorithm.reward_dpo_coef``,
            ``algorithm.reward_gt_coef`` and ``reward_model.model.beta_test``
            (default 0.05).
        linear: choose the ``exp`` mapping instead of the sigmoid mapping.
        trust_acc: not used by this estimator.

    Returns:
        ``(advantages, returns, metrics)``.
    """
    prompt_ids = data.batch['prompts']
    prompt_length = prompt_ids.shape[-1]
    valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)
    gamma = 1
    lam = config.algorithm.lam
    beta = config.reward_model.model.get('beta_test', 0.05)

    with torch.no_grad():
        assert 'rm_scores' in data.batch.keys() and 'acc' in data.batch.keys()
        q_tensor = data.batch['rm_scores']
        q_tensor[eos_mask == 0] = 0

        # NOTE(review): the original built the shifted prefix sum of rm_scores
        # (and V_last) here, but every row was then overwritten by the
        # assignment in the loop below, so that computation was dead and has
        # been removed. If `+=` on top of the prefix sum was intended instead
        # of `=`, restore it.
        Q_tensor = torch.zeros_like(q_tensor)

        for i in range(0, Q_tensor.shape[0], n_samples):
            # Margin estimated from the halved group accuracy.
            group_acc = data.batch['acc'][i:i + n_samples].mean()
            warped_group_acc = group_acc / 2
            margin = beta * torch.log(warped_group_acc / (1 - warped_group_acc))
            if warped_group_acc == 0:
                # log(0) is -inf; fall back to a zero margin for all-wrong groups.
                margin = torch.zeros_like(margin)

            Q_tensor[i:i + n_samples] = margin

        Q_tensor_reward = Q_tensor.clone()
        Q_tensor_reward[eos_mask == 0] = 0
        Q_tensor_reward = Q_tensor_reward.clamp(max=0.)

        # Q_tensor_reward now holds logits in (-inf, 0].
        # linear: convert to a probability-like ratio with exp(logit / beta).
        # non-linear: sigmoid of the logit, multiplied by 2.
        if linear:
            Q_tensor = torch.exp(Q_tensor_reward / beta)
        else:
            Q_tensor = torch.sigmoid(Q_tensor_reward) * 2
        Q_tensor[eos_mask == 0] = 0

        # Outcome reward on the last valid response token only.
        token_level_rewards = torch.zeros_like(q_tensor)
        token_level_rewards[
            torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
            valid_response_length - 1] = data.batch['acc']

        # GAE: delta_t = r_t + gamma * V_{t+1} - V_t,
        #      A_t = delta_t + gamma * lam * A_{t+1}.
        lastgaelam = 0
        advantages_reversed = []
        gen_len = q_tensor.shape[1]

        for t in reversed(range(gen_len)):
            nextvalues = Q_tensor[:, t + 1] if t < gen_len - 1 else 0.0
            delta = token_level_rewards[:, t] + gamma * nextvalues - Q_tensor[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + Q_tensor

        # Advantage contribution of the plain outcome reward.
        advantages_acc = eos_mask * data.batch['acc'].unsqueeze(-1)

        # Combine both sources and whiten over the valid tokens.
        advantages = verl_F.masked_whiten(
            advantages * config.algorithm.reward_dpo_coef + advantages_acc * config.algorithm.reward_gt_coef,
            eos_mask)

        metrics = compute_value_model_metrics(Q_tensor, eos_mask, data.batch['acc'], returns)

    return advantages, returns, metrics

def compute_simple_upv(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config, linear=False, trust_acc=True):
    """GAE with value = clamp(leave-one-out group accuracy * exp(prefix score), 0, 1).

    ``ratio`` is the exponential of the exclusive prefix sum of the masked
    ``rm_scores``. For every group of ``n_samples`` consecutive rows, each
    sample's baseline is the sum of the *other* samples' ``acc`` divided by
    ``n_samples - 1``, so a sample's own outcome does not enter its value.
    The baseline is scaled by ``ratio`` and clamped to ``[0, 1]``.

    Side effects:
        ``data.batch['rm_scores']`` is zeroed in place where ``eos_mask == 0``.

    Args:
        data: batch providing ``prompts``, ``attention_mask``, ``rm_scores`` and ``acc``.
        eos_mask: response mask; positions equal to 0 are treated as padding.
        n_samples: number of consecutive rows forming one prompt group.
        config: reads ``algorithm.lam``, ``algorithm.reward_dpo_coef`` and
            ``algorithm.reward_gt_coef``.
        linear: not used by this estimator.
        trust_acc: not used by this estimator.

    Returns:
        ``(advantages, returns, metrics)``.
    """
    prompt_ids = data.batch['prompts']
    prompt_length = prompt_ids.shape[-1]
    valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)
    gamma = 1
    lam = config.algorithm.lam

    with torch.no_grad():
        assert 'rm_scores' in data.batch.keys() and 'acc' in data.batch.keys()
        q_tensor = data.batch['rm_scores']
        q_tensor[eos_mask == 0] = 0

        # Exclusive prefix sum built from a separate buffer; shifting a tensor
        # onto itself (Q[:, 1:] = Q[:, :-1]) copies between overlapping memory,
        # which PyTorch does not define.
        prefix_sum = q_tensor.cumsum(dim=-1)
        Q_tensor = torch.zeros_like(prefix_sum)
        Q_tensor[:, 1:] = prefix_sum[:, :-1]

        ratio = torch.exp(Q_tensor)
        # The accumulated scores can overflow in exp; cap those entries at 100.
        ratio[torch.isnan(ratio)] = 100
        ratio[torch.isinf(ratio)] = 100
        ratio[eos_mask == 0] = 0

        V_tensor = torch.zeros_like(Q_tensor)

        for i in range(0, Q_tensor.shape[0], n_samples):
            # For a fair comparison with other value models the estimate must
            # not know its own reward in advance, hence leave-one-out.
            group_acc = data.batch['acc'][i:i + n_samples]
            group_acc_rloo = (group_acc.sum(dim=-1, keepdims=True) - group_acc) / (n_samples - 1)

            V_tensor[i:i + n_samples] = torch.clamp(group_acc_rloo.unsqueeze(-1) * ratio[i:i + n_samples], 0, 1)

        V_tensor[eos_mask == 0] = 0

        # Outcome reward on the last valid response token only.
        token_level_rewards = torch.zeros_like(V_tensor)
        token_level_rewards[
            torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
            valid_response_length - 1] = data.batch['acc']

        # GAE: delta_t = r_t + gamma * V_{t+1} - V_t,
        #      A_t = delta_t + gamma * lam * A_{t+1}.
        lastgaelam = 0
        advantages_reversed = []
        gen_len = V_tensor.shape[1]

        for t in reversed(range(gen_len)):
            nextvalues = V_tensor[:, t + 1] if t < gen_len - 1 else 0.0
            delta = token_level_rewards[:, t] + gamma * nextvalues - V_tensor[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + V_tensor

        # Advantage contribution of the plain outcome reward.
        advantages_acc = eos_mask * data.batch['acc'].unsqueeze(-1)

        # Combine both sources and whiten over the valid tokens.
        advantages = verl_F.masked_whiten(
            advantages * config.algorithm.reward_dpo_coef + advantages_acc * config.algorithm.reward_gt_coef,
            eos_mask)

        metrics = compute_value_model_metrics(V_tensor, eos_mask, data.batch['acc'], returns)

    return advantages, returns, metrics

def compute_adaptive_upv(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config, linear=False, trust_acc=True):
    """GAE with a ratio-scaled group-accuracy value plus an adaptively sized spread term.

    ``ratio`` is the exponential of the exclusive prefix sum of the masked
    ``rm_scores``. The base value is ``clamp(mean(group acc) * ratio, 0, 1)``
    ("cheat mode": the group mean includes the sample's own outcome). A term
    ``(ratio - 1) * xi_ratio`` is then added, where ``xi_ratio`` is the 10%
    quantile of the per-token largest multiplier that keeps the value inside
    ``[0, 1]``, so roughly 10% of the valid tokens may leave that range.

    Side effects:
        * ``data.batch['rm_scores']`` is zeroed in place where ``eos_mask == 0``.
        * Prints ratio / value ranges and the chosen multiplier.

    Args:
        data: batch providing ``prompts``, ``attention_mask``, ``rm_scores`` and ``acc``.
        eos_mask: response mask; positions equal to 0 are treated as padding,
            positions equal to 1 as valid tokens.
        n_samples: number of consecutive rows forming one prompt group.
        config: reads ``algorithm.lam``, ``algorithm.reward_dpo_coef`` and
            ``algorithm.reward_gt_coef``.
        linear: not used by this estimator.
        trust_acc: not used by this estimator.

    Returns:
        ``(advantages, returns, metrics)``.
    """
    prompt_ids = data.batch['prompts']
    prompt_length = prompt_ids.shape[-1]
    valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)
    gamma = 1
    lam = config.algorithm.lam

    with torch.no_grad():
        assert 'rm_scores' in data.batch.keys() and 'acc' in data.batch.keys()
        q_tensor = data.batch['rm_scores']
        q_tensor[eos_mask == 0] = 0

        # Exclusive prefix sum built from a separate buffer; shifting a tensor
        # onto itself (Q[:, 1:] = Q[:, :-1]) copies between overlapping memory,
        # which PyTorch does not define.
        prefix_sum = q_tensor.cumsum(dim=-1)
        Q_tensor = torch.zeros_like(prefix_sum)
        Q_tensor[:, 1:] = prefix_sum[:, :-1]

        V_tensor = torch.zeros_like(Q_tensor)
        ratio = torch.exp(Q_tensor)
        # The accumulated scores can overflow in exp; cap those entries at 100.
        ratio[torch.isnan(ratio)] = 100
        ratio[torch.isinf(ratio)] = 100
        ratio[eos_mask == 0] = 0

        for i in range(0, Q_tensor.shape[0], n_samples):
            group_acc = data.batch['acc'][i:i + n_samples]
            # Cheat mode: the group mean includes the sample's own outcome
            # (the leave-one-out variant is used in compute_simple_upv).
            V_tensor[i:i + n_samples] = torch.clamp(group_acc.mean() * ratio[i:i + n_samples], 0, 1)

        # Find the multiplier xi for (ratio - 1) that keeps the value within
        # [0, 1] at the valid tokens, spreading the values apart.
        r = (ratio - 1)[eos_mask == 1]
        V = V_tensor[eos_mask == 1]
        print(f'max_ratio: {r.max()+1} min_ratio: {r.min()+1}')
        print(f'max_V: {V.max()} min_V: {V.min()}')
        # Largest multiplier before V + xi * r exceeds 1 (r > 0) or drops below 0 (r < 0).
        th_pos = torch.where(r > 0, (1 - V) / r, torch.full_like(r, float('inf')))
        th_neg = torch.where(r < 0, -V / r, torch.full_like(r, float('inf')))

        # The binding constraint at each position is the smaller of the two.
        th_each = torch.min(th_pos, th_neg)

        # Allow 10% of the positions to exceed the limits.
        xi_ratio = torch.quantile(th_each, 0.1)

        if torch.isinf(xi_ratio):
            xi_ratio = 1e-6

        print(f'xi: {1/xi_ratio+1}')

        V_tensor += (ratio - 1) * xi_ratio
        # ratio is 0 on padding, so the update above leaves -xi_ratio there.
        # Zero it again so padded positions carry no value into the GAE
        # recursion (they would bias the advantages whenever lam < 1).
        V_tensor[eos_mask == 0] = 0

        # Outcome reward on the last valid response token only.
        token_level_rewards = torch.zeros_like(V_tensor)
        token_level_rewards[
            torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
            valid_response_length - 1] = data.batch['acc']

        # GAE: delta_t = r_t + gamma * V_{t+1} - V_t,
        #      A_t = delta_t + gamma * lam * A_{t+1}.
        lastgaelam = 0
        advantages_reversed = []
        gen_len = V_tensor.shape[1]

        for t in reversed(range(gen_len)):
            nextvalues = V_tensor[:, t + 1] if t < gen_len - 1 else 0.0
            delta = token_level_rewards[:, t] + gamma * nextvalues - V_tensor[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + V_tensor

        # Advantage contribution of the plain outcome reward.
        advantages_acc = eos_mask * data.batch['acc'].unsqueeze(-1)

        # Combine both sources and whiten over the valid tokens.
        advantages = verl_F.masked_whiten(
            advantages * config.algorithm.reward_dpo_coef + advantages_acc * config.algorithm.reward_gt_coef,
            eos_mask)

        metrics = compute_value_model_metrics(V_tensor, eos_mask, data.batch['acc'], returns)

    return advantages, returns, metrics

def compute_single_upv(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config, linear=False, trust_acc=True):
    """GAE whose value is ``clamp(leave-one-out group acc * pi_theta / pi_ref, 0, 1)``.

    Unlike the estimators built on ``rm_scores``, the ratio here is the
    per-token ``exp(old_log_probs - ref_log_prob)``. Each sample's baseline is
    the sum of the other samples' ``acc`` in its group of ``n_samples`` rows,
    divided by ``n_samples - 1``.

    Args:
        data: batch providing ``prompts``, ``attention_mask``,
            ``old_log_probs``, ``ref_log_prob`` and ``acc``.
        eos_mask: response mask; positions equal to 0 are treated as padding.
        n_samples: number of consecutive rows forming one prompt group.
        config: reads ``algorithm.lam``, ``algorithm.reward_dpo_coef``,
            ``algorithm.reward_gt_coef`` and ``reward_model.model.beta_test``.
        linear: not used by this estimator.
        trust_acc: not used by this estimator.

    Returns:
        ``(advantages, returns, metrics)``: the whitened mixed advantage, the
        un-mixed GAE advantage plus value, and the output of
        ``compute_value_model_metrics``.
    """
    batch = data.batch
    prompt_length = batch['prompts'].shape[-1]
    valid_response_length = batch['attention_mask'][:, prompt_length:].sum(-1)
    gamma = 1
    lam = config.algorithm.lam
    beta = config.reward_model.model.get('beta_test', 0.05)  # looked up but not used below

    with torch.no_grad():
        padding = eos_mask == 0

        # Per-token policy / reference probability ratio, zero on padding.
        ratio = torch.exp(batch['old_log_probs'] - batch['ref_log_prob'])
        ratio[padding] = 0

        V_tensor = torch.zeros_like(ratio)
        for start in range(0, ratio.shape[0], n_samples):
            stop = start + n_samples
            # The value must not see the sample's own reward, so average the
            # outcomes of the other members of the group.
            group_acc = batch['acc'][start:stop]
            others_acc = (group_acc.sum(dim=-1, keepdims=True) - group_acc) / (n_samples - 1)
            scaled = others_acc.unsqueeze(-1) * ratio[start:stop]
            V_tensor[start:stop] = torch.clamp(scaled, 0, 1)
        V_tensor[padding] = 0

        # Outcome reward sits on the last valid response token of every row.
        row_index = torch.arange(0, valid_response_length.shape[0], dtype=torch.long,
                                 device=valid_response_length.device)
        token_level_rewards = torch.zeros_like(V_tensor)
        token_level_rewards[row_index, valid_response_length - 1] = batch['acc']

        # Backward GAE sweep, filling the per-step list in place.
        gen_len = V_tensor.shape[1]
        per_step = [None] * gen_len
        running = 0
        for t in range(gen_len - 1, -1, -1):
            next_value = 0.0 if t == gen_len - 1 else V_tensor[:, t + 1]
            td_error = token_level_rewards[:, t] + gamma * next_value - V_tensor[:, t]
            running = td_error + gamma * lam * running
            per_step[t] = running
        advantages = torch.stack(per_step, dim=1)
        returns = advantages + V_tensor

        # Mix the value-based advantage with the plain outcome reward, then whiten.
        advantages_acc = eos_mask * batch['acc'].unsqueeze(-1)
        mixed = advantages * config.algorithm.reward_dpo_coef + advantages_acc * config.algorithm.reward_gt_coef
        advantages = verl_F.masked_whiten(mixed, eos_mask)

        metrics = compute_value_model_metrics(V_tensor, eos_mask, batch['acc'], returns)

    return advantages, returns, metrics

def compute_rloo_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config):
    """RLOO returns/advantages from the process (``rm_scores``) and outcome (``acc``) rewards.

    Each available reward source is centred with a leave-one-out baseline over
    groups of ``n_samples`` consecutive rows, scaled by its coefficient and
    summed. Returns are the reversed cumulative sum of the masked total over
    the token axis; advantages are the returns after ``verl_F.masked_whiten``.

    A source is used only if its key is present in ``data.batch`` and its
    coefficient (``config.algorithm.reward_dpo_coef`` for ``rm_scores``,
    ``config.algorithm.reward_gt_coef`` for ``acc``) is non-zero. A negative
    ``reward_gt_coef`` switches the outcome term to ``acc - sum(rm_scores)``
    weighted by ``reward_dpo_coef``, i.e. the ground truth only corrects the
    final reward.

    Returns:
        ``(advantages, returns)``; no metrics are produced.
    """

    def masked_rloo(reward_tensor_original, mask_tensor):
        """Leave-one-out centring of the masked entries of a copy of the input."""
        rewards = reward_tensor_original.clone()
        rewards[~mask_tensor] = 0
        for start in range(0, rewards.shape[0], n_samples):
            stop = start + n_samples
            # Mean over the masked entries of each row in the group.
            row_means = [
                rewards[row:row + 1][mask_tensor[row:row + 1]].mean(dim=0, keepdim=True)
                for row in range(start, stop)
            ]
            baseline = torch.cat(row_means, dim=0).sum() / (n_samples - 1)
            # Both slices are views, so the masked write lands in `rewards`.
            group_rewards = rewards[start:stop]
            group_mask = mask_tensor[start:stop]
            group_rewards[group_mask] = group_rewards[group_mask] * (n_samples / (n_samples - 1)) - baseline
        return rewards

    weighted_rewards = []

    with torch.no_grad():

        # Process reward: every valid token carries its own score.
        if 'rm_scores' in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.:
            process_rewards = masked_rloo(data.batch['rm_scores'], eos_mask.bool())
            weighted_rewards.append(process_rewards * config.algorithm.reward_dpo_coef)

        # Outcome reward: a single entry on the last valid response token.
        if 'acc' in data.batch.keys() and config.algorithm.reward_gt_coef != 0.:
            outcome_tensor = torch.zeros_like(eos_mask, dtype=torch.float32)
            outcome_mask = torch.zeros_like(eos_mask, dtype=torch.bool)

            prompt_length = data.batch['prompts'].shape[-1]
            valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)

            outcome_reward = data.batch['acc'].clone()
            coef = config.algorithm.reward_gt_coef
            if coef < 0:
                # The ground truth is only used to fix numerical errors: the
                # final reward stays acc once the process rewards are added.
                outcome_reward -= data.batch['rm_scores'].sum(dim=-1)
                coef = config.algorithm.reward_dpo_coef

            rows = torch.arange(0, valid_response_length.shape[0], dtype=torch.long,
                                device=valid_response_length.device)
            last_token = valid_response_length - 1
            outcome_mask[rows, last_token] = True
            outcome_tensor[rows, last_token] = outcome_reward

            weighted_rewards.append(masked_rloo(outcome_tensor, outcome_mask) * coef)

        total_reward = sum(weighted_rewards)

        # Reward-to-go: suffix sum along the token axis.
        masked_total = total_reward * eos_mask
        returns = masked_total.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

        advantages = verl_F.masked_whiten(returns.clone(), eos_mask)

        return advantages, returns


def compute_ce_dpo_loss_rm(token_level_scores, acc, eos_mask, beta):
    """Binary cross-entropy between the sequence-level implicit reward and ``acc``.

    The sequence logit is the masked sum of ``token_level_scores`` over dim 1,
    multiplied by ``beta``. The loss is computed directly from the logit:
    applying ``sigmoid`` first and then ``binary_cross_entropy`` saturates for
    large logits, which clamps the loss and zeroes the gradient.

    Args:
        token_level_scores: per-token scores, summed over dim 1.
        acc: per-sequence targets in ``[0, 1]``, same shape as the summed scores.
        eos_mask: multiplicative mask applied to the scores before summing.
        beta: scale applied to the summed scores.

    Returns:
        The mean binary cross-entropy as a scalar tensor.
    """
    logits = (token_level_scores * eos_mask).sum(dim=1) * beta
    return torch.nn.functional.binary_cross_entropy_with_logits(logits, acc)

def compute_middle_ce_loss_rm(token_level_scores, acc, eos_mask, beta, margin=None):
    """Binary cross-entropy of the implicit reward (plus margin) against ``acc / 2``.

    The sequence logit is the masked sum of ``token_level_scores`` over dim 1,
    multiplied by ``beta``, plus ``margin``. The target is ``acc / 2``. The
    loss is computed directly from the logit so that large logits neither
    clamp the loss nor zero the gradient.

    Args:
        token_level_scores: per-token scores, summed over dim 1.
        acc: per-sequence labels; halved to form the target.
        eos_mask: multiplicative mask applied to the scores before summing.
        beta: scale applied to the summed scores.
        margin: optional per-sequence offset added to the logit; ``None``
            means no offset.

    Returns:
        The mean binary cross-entropy as a scalar tensor.
    """
    # Identity check: `== None` on a tensor is not a reliable None test.
    if margin is None:
        margin = torch.zeros_like(acc)
    logits = (token_level_scores * eos_mask).sum(dim=1) * beta + margin
    return torch.nn.functional.binary_cross_entropy_with_logits(logits, acc / 2)

def compute_margin_ce_dpo_loss_rm(token_level_scores, acc, eos_mask, beta, gamma=1.0):
    """Binary cross-entropy against ``acc`` with a label-dependent margin on the logit.

    The sequence logit is the masked sum of ``token_level_scores`` over dim 1,
    multiplied by ``beta``, minus ``gamma * (2 * acc - 1)``. For ``acc`` of 1
    the logit is lowered by ``gamma`` and for ``acc`` of 0 it is raised by
    ``gamma``. The loss is computed directly from the logit so that large
    logits neither clamp the loss nor zero the gradient.

    Args:
        token_level_scores: per-token scores, summed over dim 1.
        acc: per-sequence targets in ``[0, 1]``.
        eos_mask: multiplicative mask applied to the scores before summing.
        beta: scale applied to the summed scores.
        gamma: margin size.

    Returns:
        The mean binary cross-entropy as a scalar tensor.
    """
    logits = (token_level_scores * eos_mask).sum(dim=1) * beta - gamma * (acc * 2 - 1)
    return torch.nn.functional.binary_cross_entropy_with_logits(logits, acc)

def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, eos_mask, beta, bon_mode='none', use_ce=False):
    """DPO-style loss of each sample's reward against a per-sample reference reward.

    For sample ``i`` the reward is ``beta * sum(masked token scores)``. The
    reference is ``beta`` times the mean of the entries of ``Q_bc[i]`` whose
    ``acc_bc[i]`` is strictly lower than ``acc[i]`` (when ``acc[i] > 0``) or
    strictly higher (otherwise). The reference is 0 when no such entry exists
    or when ``use_ce`` is set. The per-sample loss is
    ``-logsigmoid((reward - reference) * sign)`` with ``sign = +1`` for
    ``acc > 0`` and ``-1`` otherwise.

    The original code assumes the BoN size equals ``n_samples``, which is
    read from ``acc_bc.shape[1]``.

    Args:
        token_level_scores: per-token scores; assumed (batch, seq) -- TODO confirm.
        acc: per-sample labels, indexed along dim 0.
        Q_bc: per-sample row of group rewards; ``Q_bc[i]`` is masked by
            comparisons on ``acc_bc[i]``, so both need matching row lengths.
        acc_bc: per-sample row of group labels, 2-D (``shape[1]`` is read).
        eos_mask: multiplicative mask, broadcastable to ``token_level_scores``.
        beta: scale applied to summed scores and to the reference mean.
        bon_mode: ``'none'`` for a plain mean; ``'bon_rm'`` weights samples by
            rank among ``Q_bc[i]``; ``'bon_acc'`` weights by rank among
            ``acc_bc[i]``.
        use_ce: when true, the reference reward is fixed at 0.

    Returns:
        Scalar loss tensor.

    Raises:
        NotImplementedError: if ``bon_mode`` is not one of the values above.
    """
    if bon_mode not in ('none', 'bon_rm', 'bon_acc'):
        # Reject unsupported modes before doing any per-sample work.
        raise NotImplementedError(f'unsupported bon_mode: {bon_mode!r}')

    batch_size = token_level_scores.shape[0]
    cur_Q = (token_level_scores * eos_mask).sum(dim=1) * beta
    other_Q = torch.zeros_like(cur_Q)

    # With use_ce the reference stays at zero, so the selection loop is skipped.
    if not use_ce:
        for i in range(batch_size):
            if acc[i] > 0:
                Q_chosen = Q_bc[i][acc_bc[i] < acc[i]]
            else:
                Q_chosen = Q_bc[i][acc_bc[i] > acc[i]]
            if len(Q_chosen) > 0:
                other_Q[i] = Q_chosen.mean() * beta

    direction = (acc > 0).float() * 2 - 1
    # logsigmoid stays finite where log(sigmoid(x)) underflows to -inf for
    # strongly negative x.
    dpo_loss = -torch.nn.functional.logsigmoid((cur_Q - other_Q) * direction)

    if bon_mode == 'none':
        return dpo_loss.mean()

    n_samples = acc_bc.shape[1]
    n_samples_valid = 2
    weight = torch.zeros_like(dpo_loss)
    for i in range(batch_size):
        if bon_mode == 'bon_rm':
            group_values = Q_bc[i]
            not_above_current = Q_bc[i] * beta <= cur_Q[i]
        else:  # 'bon_acc'
            group_values = acc_bc[i]
            not_above_current = acc_bc[i] <= acc[i]
        # For every group member: fraction of the group that does not exceed it.
        cur_response_weights = n_samples_valid * torch.pow(
            (group_values.unsqueeze(0) <= group_values.unsqueeze(1)).float().mean(dim=-1),
            n_samples_valid - 1)
        # Same statistic for the current sample, normalised by the group total.
        weight[i] = n_samples_valid * torch.pow(
            not_above_current.float().mean(),
            n_samples_valid - 1) / cur_response_weights.sum() * n_samples / batch_size

    return (dpo_loss * weight).sum()


def compute_dpo_accuracy(token_level_scores, acc, eos_mask, n_samples):
    """Pairwise ranking accuracy of sequence scores against ``acc``, averaged over groups.

    The batch is cut into consecutive groups of ``n_samples`` rows. Within a
    group, every pair ``(i, j)`` with ``i < j`` contributes with weight
    ``|acc[i] - acc[j]|``; the pair counts as correct when the sign test
    ``score[i] - score[j] > 0`` agrees with ``acc[i] - acc[j] > 0``. A group
    whose total pair weight is zero (all labels equal, or fewer than two rows)
    scores 0.5.

    Args:
        token_level_scores: per-token scores; assumed (batch, seq) -- TODO confirm.
        acc: per-sequence labels, indexed along dim 0.
        eos_mask: multiplicative mask, same leading dimension as the scores.
        n_samples: number of consecutive rows forming one group.

    Returns:
        Scalar tensor: mean of the per-group accuracies.
    """

    def get_upper_triangle(tensor_x):
        # Differences x[i] - x[j] for all i < j, flattened.
        diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)
        upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1)
        return diff_matrix[upper_tri_indices]

    dpo_acc = []
    for start_id in range(0, token_level_scores.shape[0], n_samples):
        group = slice(start_id, start_id + n_samples)
        cur_scores = (token_level_scores[group] * eos_mask[group]).sum(dim=1)

        cur_acc_diff = get_upper_triangle(acc[group])  # in range [-1,1]
        cur_score_diff = get_upper_triangle(cur_scores)  # in R
        pair_weight = cur_acc_diff.abs()
        total_weight = pair_weight.sum()

        if total_weight == 0:
            # No informative pair. Build the 0.5 fallback directly so that a
            # group without any pair (e.g. n_samples == 1) does not index an
            # empty tensor.
            cur_acc = torch.full((), 0.5, dtype=torch.float32, device=cur_scores.device)
        else:
            agreement = ((cur_score_diff > 0) == (cur_acc_diff > 0)).float()
            cur_acc = (agreement * pair_weight).sum() / total_weight

        dpo_acc.append(cur_acc.unsqueeze(0))

    return torch.cat(dpo_acc, dim=0).mean()

def compute_dpo_continual_accuracy(token_level_scores, acc, eos_mask, n_samples):
    """Soft pairwise accuracy: mean sigmoid of label-signed score gaps.

    The batch is cut into consecutive groups of ``n_samples`` rows. For every
    pair ``(i, j)`` with ``i < j`` inside a group, the score gap
    ``score[i] - score[j]`` is multiplied by the label gap
    ``acc[i] - acc[j]`` and passed through a sigmoid. Pairs with equal labels
    therefore contribute exactly 0.5.

    Original author note: this measures a continuous level of correctness,
    later used for Platt calibration, under the assumption that the existing
    model has a consistent continual accuracy for each sample.

    Args:
        token_level_scores: per-token scores; assumed (batch, seq) -- TODO confirm.
        acc: per-sequence labels, indexed along dim 0.
        eos_mask: multiplicative mask, same leading dimension as the scores.
        n_samples: number of consecutive rows forming one group.

    Returns:
        Scalar tensor: mean over all groups and pairs.
    """

    def pairwise_upper_diffs(values):
        # values[i] - values[j] for every i < j, flattened.
        all_diffs = values.unsqueeze(1) - values.unsqueeze(0)
        keep = torch.triu(torch.ones_like(all_diffs).bool(), diagonal=1)
        return all_diffs[keep]

    per_group = []
    total_rows = token_level_scores.shape[0]
    for begin in range(0, total_rows, n_samples):
        end = begin + n_samples
        group_scores = (token_level_scores[begin:end] * eos_mask[begin:end]).sum(dim=1)

        label_gap = pairwise_upper_diffs(acc[begin:end])  # in range [-1,1]
        score_gap = pairwise_upper_diffs(group_scores)  # in R
        # Same-label pairs have a zero gap, which the sigmoid maps to 0.5.
        soft_correct = torch.sigmoid(score_gap * label_gap)

        per_group.append(soft_correct.unsqueeze(0))

    return torch.cat(per_group, dim=0).mean()
def compute_dpo_abs_accuracy(token_level_scores, acc, eos_mask, n_samples):
    """Fraction of sequences whose masked score sum has the same sign as ``2 * acc - 1``.

    Args:
        token_level_scores: per-token scores, summed over the last dimension.
        acc: per-sequence labels; mapped to a sign via ``2 * acc - 1``.
        eos_mask: multiplicative mask, broadcastable to ``token_level_scores``.
        n_samples: unused; kept for a signature parallel to the other metrics.

    Returns:
        Scalar tensor with the mean sign agreement.
    """
    sequence_score = (token_level_scores * eos_mask).sum(dim=-1)
    predicted_sign = torch.sign(sequence_score)
    target_sign = torch.sign(acc * 2 - 1)
    agreement = predicted_sign == target_sign
    return agreement.float().mean()


def compute_return_abs_accuracy(returns, acc):
    """Fraction of rows whose first-column return has the same sign as ``2 * acc - 1``.

    Args:
        returns: 2-D tensor; only column 0 of each row is inspected.
        acc: per-row labels; mapped to a sign via ``2 * acc - 1``.

    Returns:
        Scalar tensor with the mean sign agreement.
    """
    first_step_sign = torch.sign(returns[:, 0])
    label_sign = torch.sign(acc * 2 - 1)
    hits = (first_step_sign == label_sign).float()
    return hits.mean()


def compute_return_smoothness(returns):
    """Mean over rows of the summed squared difference between adjacent columns.

    Args:
        returns: 2-D tensor; differences are taken along the last dimension.

    Returns:
        Scalar tensor; 0 when every row is constant.
    """
    earlier = returns[:, :-1]
    later = returns[:, 1:]
    squared_step = (earlier - later) ** 2
    per_row_roughness = squared_step.sum(dim=-1)
    return per_row_roughness.mean()
