

import numpy as np
import torch
from collections import defaultdict

import verl.utils.torch_functional as verl_F
import torch.nn.functional as F

class AdaptiveKLController:
    """KL coefficient that is multiplicatively rescaled toward a target KL.

    ``value`` holds the current coefficient and is modified in place by
    :meth:`update`.
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        # Current coefficient, the KL value to steer toward, and the horizon
        # that damps the size of every multiplicative step.
        self.value, self.target, self.horizon = init_kl_coef, target_kl, horizon

    def update(self, current_kl, n_steps):
        """Rescale ``self.value`` according to how far ``current_kl`` is from the target.

        The relative error ``current_kl / target - 1`` is clipped to
        ``[-0.2, 0.2]`` so a single update changes the coefficient by at most
        ``0.2 * n_steps / horizon`` in relative terms.

        Args:
            current_kl: Observed KL for the latest batch.
            n_steps: Number of steps the observation accounts for.
        """
        relative_error = current_kl / self.target - 1
        clipped_error = np.clip(relative_error, -0.2, 0.2)
        self.value *= 1 + clipped_error * n_steps / self.horizon


class FixedKLController:
    """KL controller whose coefficient ``value`` never changes."""

    def __init__(self, kl_coef):
        # The coefficient is stored once and never modified afterwards.
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        """Accept the same arguments as the adaptive controller and do nothing."""
        return None


def get_kl_controller(kl_ctrl):
    """Build the KL controller selected by ``kl_ctrl.type``.

    Args:
        kl_ctrl: Config object exposing ``type`` and ``kl_coef``; for the
            ``'adaptive'`` type it must also expose ``target_kl`` and a
            positive ``horizon``.

    Returns:
        A ``FixedKLController`` for ``'fixed'`` or an ``AdaptiveKLController``
        for ``'adaptive'``.

    Raises:
        AssertionError: If the adaptive ``horizon`` is not positive.
        NotImplementedError: If ``kl_ctrl.type`` is neither supported value.
    """
    if kl_ctrl.type == 'fixed':
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)
    if kl_ctrl.type == 'adaptive':
        assert kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {kl_ctrl.horizon}'
        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef,
                                    target_kl=kl_ctrl.target_kl,
                                    horizon=kl_ctrl.horizon)
    # Name the offending value so a misconfiguration can be diagnosed from the
    # traceback alone.
    raise NotImplementedError(
        f"Unsupported KL controller type: {kl_ctrl.type!r}. Choices are: 'fixed', 'adaptive'.")


def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor,
                                 gamma: torch.Tensor, lam: torch.Tensor):
    """Compute GAE advantages and returns with a backward sweep over the last axis.

    For every step ``t`` (last to first)::

        delta_t = r_t + gamma * V_{t+1} - V_t      (V past the final step is 0)
        A_t     = delta_t + gamma * lam * A_{t+1}

    Args:
        token_level_rewards: Rewards indexed as ``[:, t]``; the last dimension
            is the step axis.
        values: Value estimates indexed the same way as the rewards.
        response_mask: Forwarded to ``verl_F.masked_whiten`` only; it is not
            applied inside the recursion itself.
        gamma: Discount factor.
        lam: GAE lambda.

    Returns:
        ``(advantages, returns)`` where ``returns`` is the raw advantage plus
        ``values`` and ``advantages`` is the output of
        ``verl_F.masked_whiten`` (presumably a mask-aware normalisation -
        confirm against that helper).
    """
    with torch.no_grad():
        horizon = token_level_rewards.shape[-1]
        running_gae = 0
        per_step = [None] * horizon

        for step in range(horizon - 1, -1, -1):
            # Bootstrap from the next value, or from zero at the final step.
            if step < horizon - 1:
                next_value = values[:, step + 1]
            else:
                next_value = 0.0
            td_error = token_level_rewards[:, step] + gamma * next_value - values[:, step]
            running_gae = td_error + gamma * lam * running_gae
            per_step[step] = running_gae

        advantages = torch.stack(per_step, dim=1)
        # Returns are taken from the un-normalised advantages.
        returns = advantages + values
        advantages = verl_F.masked_whiten(advantages, response_mask)
    return advantages, returns



def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
                                   response_mask: torch.Tensor,
                                   index: np.ndarray,
                                   epsilon: float = 1e-6):
    """Group-normalised outcome advantage (GRPO).

    Each row's scalar score is the sum of its token-level rewards. Rows that
    share the same ``index`` entry form a group; a row's advantage is
    ``(score - group_mean) / (group_std + epsilon)``. A group with a single
    member uses mean 0 and std 1, so its score passes through almost
    unchanged. The scalar is then broadcast over the last axis and multiplied
    by ``response_mask``.

    Args:
        token_level_rewards: Rewards; summed over the last dimension.
        response_mask: Mask multiplied into the broadcast advantage.
        index: Per-row group key; entries must be hashable.
        epsilon: Added to the group std to avoid division by zero.

    Returns:
        The same tensor twice, as ``(advantages, returns)``.

    Raises:
        ValueError: If a group ends up with no scores.
    """
    scores = token_level_rewards.sum(dim=-1)

    scores_by_group = defaultdict(list)
    group_mean = {}
    group_std = {}

    with torch.no_grad():
        num_rows = scores.shape[0]
        for row in range(num_rows):
            scores_by_group[index[row]].append(scores[row])

        for idx, members in scores_by_group.items():
            group_size = len(members)
            if group_size > 1:
                group_mean[idx] = torch.mean(torch.tensor(members))
                group_std[idx] = torch.std(torch.tensor([members]))
            elif group_size == 1:
                # A lone sample has no peers to compare against.
                group_mean[idx] = torch.tensor(0.0)
                group_std[idx] = torch.tensor(1.0)
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for row in range(num_rows):
            key = index[row]
            scores[row] = (scores[row] - group_mean[key]) / (group_std[key] + epsilon)

        # Broadcast the per-row scalar over the last axis and zero masked positions.
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores



def _compute_confidence_ci(old_log_prob: torch.Tensor,
                           response_mask: torch.Tensor,
                           epsilon: float = 1e-8) -> torch.Tensor:



    masked_log_prob = old_log_prob * response_mask


    sequence_lengths = response_mask.sum(dim=-1)


    sum_log_prob = masked_log_prob.sum(dim=-1)


    mean_log_prob = sum_log_prob / (sequence_lengths + epsilon)


    ci = torch.exp(mean_log_prob)

    return ci




def compute_advantage_CCPO_BCE(token_level_rewards: torch.Tensor,
                                 old_log_prob: torch.Tensor,
                                 response_mask: torch.Tensor,
                                 index: np.ndarray,
                                 epsilon: float = 1e-6) -> tuple[torch.Tensor, torch.Tensor]:
    """Confidence-scaled, group-centred outcome advantage.

    For row ``i`` with scalar score ``R_i`` (sum of its token rewards)::

        c_i = exp(masked mean of old_log_prob over the last axis)
        A_i = (R_i - mean of R over rows sharing index[i]) / (1 - c_i + epsilon)

    A group with a single member has mean equal to its own score, so its
    advantage is zero. The scalar is broadcast over the last axis and
    multiplied by ``response_mask``.

    Args:
        token_level_rewards: Rewards; summed over the last dimension.
        old_log_prob: Log-probabilities used to build the confidence ``c_i``.
        response_mask: Mask used both for the confidence mean and the output.
        index: Per-row group key; entries must be hashable.
        epsilon: Stabiliser added to the mask count and to ``1 - c_i``.

    Returns:
        The same tensor twice, as ``(advantages, returns)``.
    """
    scores = token_level_rewards.sum(dim=-1)
    device = scores.device

    # Per-row confidence from the behaviour policy's log-probabilities.
    with torch.no_grad():
        token_counts = response_mask.sum(dim=-1)
        log_prob_totals = (old_log_prob * response_mask).sum(dim=-1)
        ci = torch.exp(log_prob_totals / (token_counts + epsilon))

    with torch.no_grad():
        bsz = scores.shape[0]

        grouped_scores = defaultdict(list)
        for row in range(bsz):
            grouped_scores[index[row]].append(scores[row].item())

        group_mean = {
            key: torch.tensor(np.mean(values), device=device)
            for key, values in grouped_scores.items()
        }

        advantages_scalar = torch.zeros_like(scores)
        for row in range(bsz):
            centred = scores[row] - group_mean[index[row]]
            # The denominator shrinks as confidence approaches 1.
            confidence_gap = 1.0 - ci[row]
            advantages_scalar[row] = centred / (confidence_gap + epsilon)

    advantages = advantages_scalar.unsqueeze(-1) * response_mask

    return advantages, advantages


def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: torch.Tensor,
                                                           response_mask: torch.Tensor,
                                                           index: np.ndarray,
                                                           epsilon: float = 1e-6):
    """Group-baseline outcome advantage followed by masked whitening.

    Each row's scalar score (sum of its token rewards) has the mean score of
    its group subtracted. A single-member group uses a baseline of 0. The
    result is tiled over the response length, masked, and passed through
    ``verl_F.masked_whiten``.

    Args:
        token_level_rewards: Rewards; summed over the last dimension.
        response_mask: Mask applied to the tiled scores and forwarded to
            ``verl_F.masked_whiten``.
        index: Per-row group key. Entries are used as dict keys, so they must
            hash by value (e.g. a NumPy array of strings or ints). Elements of
            a ``torch.Tensor`` hash by object identity and would place every
            row in its own group.
        epsilon: Unused here; kept for signature parity with the other
            outcome-advantage estimators.

    Returns:
        The same tensor twice, as ``(advantages, returns)``.

    Raises:
        ValueError: If a group ends up with no scores.
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])

        for idx, group in id2score.items():
            if len(group) == 1:
                # A lone sample has no peers; leave its score un-centred.
                id2mean[idx] = torch.tensor(0.0)
            elif len(group) > 1:
                id2mean[idx] = torch.mean(torch.tensor(group))
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for i in range(bsz):
            scores[i] = scores[i] - id2mean[index[i]]

        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
        scores = verl_F.masked_whiten(scores, response_mask)

    return scores, scores


def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
                                   response_mask: torch.Tensor,
                                   index: np.ndarray,
                                   epsilon: float = 1e-6):
    """Leave-one-out (RLOO) outcome advantage.

    Each row's scalar score is the sum of its token rewards. For a group of
    ``n > 1`` rows sharing the same ``index`` entry, the advantage is
    ``score * n / (n - 1) - group_mean * n / (n - 1)``, which equals the score
    minus the mean of the *other* members. Rows in a single-member group keep
    their raw score. The scalar is broadcast over the last axis and multiplied
    by ``response_mask``.

    Args:
        token_level_rewards: Rewards; summed over the last dimension.
        response_mask: Mask multiplied into the broadcast advantage.
        index: Per-row group key; entries must be hashable.
        epsilon: Unused by this estimator.

    Returns:
        The same tensor twice, as ``(advantages, returns)``.

    Raises:
        ValueError: If a group ends up with no scores.
    """
    scores = token_level_rewards.sum(dim=-1)

    scores_by_group = defaultdict(list)
    group_mean = {}

    with torch.no_grad():
        num_rows = scores.shape[0]
        for row in range(num_rows):
            scores_by_group[index[row]].append(scores[row])

        for idx, members in scores_by_group.items():
            group_size = len(members)
            if group_size > 1:
                group_mean[idx] = torch.mean(torch.tensor(members))
            elif group_size == 1:
                group_mean[idx] = torch.tensor(0.0)
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for row in range(num_rows):
            key = index[row]
            response_num = len(scores_by_group[key])
            if response_num <= 1:
                # Nothing to leave out: the raw score is kept.
                continue
            peers = response_num - 1
            own_term = scores[row] * response_num / peers
            baseline_term = group_mean[key] * response_num / peers
            scores[row] = own_term - baseline_term

        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores


def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor,
                                                  gamma: torch.Tensor):
    """Discounted reward-to-go with masked whitening (REINFORCE++).

    Sweeps the second axis from the last step to the first, accumulating
    ``G_t = r_t + gamma * G_{t+1}``. After a step is stored, the running value
    is multiplied by that step's mask so nothing accumulated at masked
    positions carries over to earlier steps.

    Args:
        token_level_rewards: Rewards indexed as ``[:, t]``.
        response_mask: Mask indexed the same way; also forwarded to
            ``verl_F.masked_whiten``.
        gamma: Discount factor.

    Returns:
        ``(advantages, returns)`` where ``returns`` is the discounted
        reward-to-go and ``advantages`` is ``verl_F.masked_whiten(returns,
        response_mask)`` multiplied by the mask.
    """
    with torch.no_grad():
        returns = torch.zeros_like(token_level_rewards)
        carry = 0

        num_steps = token_level_rewards.shape[1]
        for step in range(num_steps - 1, -1, -1):
            carry = token_level_rewards[:, step] + gamma * carry
            returns[:, step] = carry
            # Reset the accumulator wherever this step is masked out.
            carry = carry * response_mask[:, step]

        advantages = verl_F.masked_whiten(returns, response_mask) * response_mask

    return advantages, returns


def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor,
                                    response_mask: torch.Tensor):
    """ReMax advantage: masked reward-to-go minus a per-row baseline.

    Args:
        token_level_rewards: Rewards; the last dimension is the step axis.
        reward_baselines: One baseline per row; broadcast over the last axis.
        response_mask: Mask applied to the rewards before the suffix sum and
            to the baseline before subtraction.

    Returns:
        ``(advantages, returns)`` where ``returns`` is the suffix sum of the
        masked rewards along the last axis and ``advantages`` is ``returns``
        minus the masked baseline.
    """
    with torch.no_grad():
        masked_rewards = token_level_rewards * response_mask
        # Suffix sum: reverse, take the running sum, reverse back.
        reversed_running_sum = masked_rewards.flip(dims=[-1]).cumsum(dim=-1)
        returns = reversed_running_sum.flip(dims=[-1])

        masked_baseline = reward_baselines.unsqueeze(-1) * response_mask
        advantages = returns - masked_baseline

    return advantages, returns


def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    """Subtract a KL penalty, ``(old_log_prob - ref_log_prob) * kl_ratio``, from the scores."""
    return token_level_scores - (old_log_prob - ref_log_prob) * kl_ratio


def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
    """Reduce a per-token loss matrix to a scalar.

    Args:
        loss_mat: Per-token losses.
        loss_mask: Mask with the same layout as ``loss_mat``.
        loss_agg_mode: One of

            * ``"token-mean"``: ``verl_F.masked_mean`` over all tokens.
            * ``"seq-mean-token-sum"``: masked sum along the last axis, then
              the mean of those sums.
            * ``"seq-mean-token-mean"``: masked mean along the last axis, then
              the mean of those means. A row whose mask sums to zero divides
              by zero.

    Returns:
        The aggregated scalar loss.

    Raises:
        ValueError: If ``loss_agg_mode`` is not one of the modes above.
    """
    if loss_agg_mode == "token-mean":
        return verl_F.masked_mean(loss_mat, loss_mask)

    if loss_agg_mode in ("seq-mean-token-sum", "seq-mean-token-mean"):
        per_sequence = torch.sum(loss_mat * loss_mask, dim=-1)
        if loss_agg_mode == "seq-mean-token-mean":
            per_sequence = per_sequence / torch.sum(loss_mask, dim=-1)
        return torch.mean(per_sequence)

    raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")




def compute_policy_loss(old_log_prob,
                        log_prob,
                        advantages,
                        response_mask,
                        token_level_rewards=None,
                        cliprange=None,
                        cliprange_low=None,
                        cliprange_high=None,
                        clip_ratio_c=3.0,
                        loss_agg_mode="token-mean",
                        enable_confidence_loss=False,
                        confidence_target_source="reward",
                        confidence_loss_type="bce",
                        lambda_confidence=1.0,
                        confidence_reward_scale_factor=1.0
                        ):
    """Dual-clip PPO policy loss with an optional confidence-calibration term.

    Policy-gradient part, with ``ratio = exp(log_prob - old_log_prob)``::

        l1 = -A * ratio
        l2 = -A * clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
        standard  = max(l1, l2)
        dual_clip = min(-A * clip_ratio_c, standard)
        loss      = dual_clip where A < 0 else standard

    The per-token loss is reduced with ``agg_loss``.

    Confidence part (only when ``enable_confidence_loss``): the per-sample
    confidence is the masked mean over the last axis of ``exp(log_prob)``. It
    is regressed onto ``sigmoid(scale * target)``, where ``target`` is either
    the per-sample reward sum or the per-sample masked-mean advantage. The
    target is built under ``torch.no_grad()``.

    Args:
        old_log_prob: Log-probabilities under the behaviour policy.
        log_prob: Log-probabilities under the current policy.
        advantages: Per-token advantages.
        response_mask: Mask selecting the tokens that contribute.
        token_level_rewards: Required when ``confidence_target_source`` is
            ``"reward"``; summed over the last axis.
        cliprange: Fallback for ``cliprange_low`` / ``cliprange_high``.
        cliprange_low: Lower clip width; defaults to ``cliprange``.
        cliprange_high: Upper clip width; defaults to ``cliprange``.
        clip_ratio_c: Dual-clip bound; must be greater than 1.
        loss_agg_mode: Forwarded to ``agg_loss``.
        enable_confidence_loss: Whether to add the confidence term.
        confidence_target_source: ``"reward"`` or ``"advantage"``.
        confidence_loss_type: ``"bce"`` or ``"mse"``.
        lambda_confidence: Weight of the confidence term in the total loss.
        confidence_reward_scale_factor: Multiplier applied before the sigmoid.

    Returns:
        ``(total_policy_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower,
        confidence_loss)``.

    Raises:
        AssertionError: If ``clip_ratio_c <= 1``.
        ValueError: If the confidence target source or loss type is unknown,
            or rewards are missing for the ``"reward"`` target.
    """
    assert clip_ratio_c > 1.0, f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}."

    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_losses1 = -advantages * ratio
    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low,
                                           1 + cliprange_high)
    # Standard PPO clip: keep the larger of the clipped and unclipped losses.
    clip_pg_losses1 = torch.maximum(pg_losses1,
                                    pg_losses2)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

    # Dual clip: cap the loss for negative advantages at -A * clip_ratio_c.
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask)

    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    confidence_loss = torch.tensor(0.0, device=pg_loss.device, dtype=pg_loss.dtype)
    if enable_confidence_loss:
        with torch.no_grad():
            if confidence_target_source == 'reward':
                if token_level_rewards is None:
                    raise ValueError("use 'reward' as target, must provide `token_level_rewards`")
                rewards = token_level_rewards.sum(dim=-1)
                rewards_scaled = rewards * confidence_reward_scale_factor
                target_prob = torch.sigmoid(rewards_scaled)
            elif confidence_target_source == "advantage":
                # Reduce along the last axis only, so there is one target per
                # sample, matching the shape of `confidence` below.
                sample_advantages = verl_F.masked_mean(advantages, response_mask, axis=-1)
                advantages_scaled = sample_advantages * confidence_reward_scale_factor
                target_prob = torch.sigmoid(advantages_scaled)
            else:
                # Fail here instead of hitting an unbound `target_prob` later.
                raise ValueError(f"unknown confidence_target_source: {confidence_target_source}")

        new_probs = torch.exp(log_prob)
        confidence = verl_F.masked_mean(new_probs, response_mask, axis=-1)
        if confidence_loss_type == "bce":
            # Keep the input strictly inside (0, 1) so the log terms stay finite.
            epsilon = 1e-8
            confidence_clamped = torch.clamp(confidence, min=epsilon, max=1.0 - epsilon)
            confidence_loss = F.binary_cross_entropy(
                input=confidence_clamped,
                target=target_prob
            )
        elif confidence_loss_type == "mse":
            confidence_loss = F.mse_loss(confidence, target_prob)
        else:
            raise ValueError(f"unknown confidence_loss_type: {confidence_loss_type}")

    total_policy_loss = pg_loss + lambda_confidence * confidence_loss

    return total_policy_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower, confidence_loss





def compute_policy_loss_new(
        algorithm_name: str,
        old_log_prob: torch.Tensor,
        log_prob: torch.Tensor,
        advantages: torch.Tensor,
        response_mask: torch.Tensor,

        token_level_rewards: torch.Tensor = None,
        beta: float = 0.01,
        loss_agg_mode: str = "token-mean",
        epsilon: float = 1e-8,
        cliprange=None,
        cliprange_low=None,
        cliprange_high=None,
        clip_ratio_c=3.0,
):
    """Policy loss dispatcher for the ``"grpo"`` and ``"ccpo_bce"`` objectives.

    ``"grpo"`` is the dual-clip PPO surrogate, with
    ``ratio = exp(log_prob - old_log_prob)``::

        standard = max(-A * ratio, -A * clamp(ratio, 1 - low, 1 + high))
        loss     = min(-A * clip_ratio_c, standard) where A < 0 else standard

    reduced with ``agg_loss``.

    ``"ccpo_bce"`` works per sample. With ``c`` the clamped confidence from
    ``_compute_confidence_ci(log_prob, response_mask, epsilon)``, ``A`` the
    advantage at the first token position, and ``r`` the summed reward::

        objective = log(c) * A + beta * (r * log(c) + (1 - r) * log(1 - c))
        loss      = -mean(objective)

    ``loss_agg_mode`` and the clip arguments are not used by this branch.

    Args:
        algorithm_name: ``"grpo"`` or ``"ccpo_bce"``.
        old_log_prob: Log-probabilities under the behaviour policy.
        log_prob: Log-probabilities under the current policy.
        advantages: Per-token advantages; ``"ccpo_bce"`` reads ``[:, 0]`` only.
        response_mask: Mask selecting the tokens that contribute.
        token_level_rewards: Required for ``"ccpo_bce"``.
        beta: Weight of the BCE term in ``"ccpo_bce"``.
        loss_agg_mode: Forwarded to ``agg_loss`` for ``"grpo"``.
        epsilon: Stabiliser for the confidence mean and its clamp.
        cliprange: Fallback for ``cliprange_low`` / ``cliprange_high``.
        cliprange_low: Lower clip width; defaults to ``cliprange``.
        cliprange_high: Upper clip width; defaults to ``cliprange``.
        clip_ratio_c: Dual-clip bound; must be greater than 1 for ``"grpo"``.

    Returns:
        ``(pg_loss, pg_clipfrac, approx_kl, pg_clipfrac_lower,
        regularization_loss, positive_bce_part, negative_bce_part)``. Entries
        that do not apply to the selected algorithm are zero tensors.

    Raises:
        AssertionError: If ``clip_ratio_c <= 1`` for ``"grpo"``.
        ValueError: If ``algorithm_name`` is unknown, or rewards are missing
            for ``"ccpo_bce"``.
    """
    approx_kl = verl_F.masked_mean(old_log_prob - log_prob, response_mask)

    # Defaults for metrics that only one of the two algorithms fills in.
    device = log_prob.device
    pg_clipfrac = torch.tensor(0.0, device=device)
    pg_clipfrac_lower = torch.tensor(0.0, device=device)
    regularization_loss = torch.tensor(0.0, device=device)

    positive_bce_part = torch.tensor(0.0, device=device)
    negative_bce_part = torch.tensor(0.0, device=device)

    if algorithm_name == "grpo":

        assert clip_ratio_c > 1.0, f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}."

        ratio = torch.exp(log_prob - old_log_prob)

        pg_losses1 = -advantages * ratio
        if cliprange_low is None:
            cliprange_low = cliprange
        if cliprange_high is None:
            cliprange_high = cliprange
        pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low,
                                               1 + cliprange_high)
        # Standard PPO clip: keep the larger of the clipped and unclipped losses.
        clip_pg_losses1 = torch.maximum(pg_losses1,
                                        pg_losses2)
        pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

        # Dual clip: cap the loss for negative advantages at -A * clip_ratio_c.
        pg_losses3 = -advantages * clip_ratio_c
        clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
        pg_clipfrac_lower = verl_F.masked_mean(
            torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask)

        pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
        pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    elif algorithm_name == "ccpo_bce":

        if token_level_rewards is None:
            raise ValueError("`token_level_rewards` must be provided for 'ccpo_bce'.")

        ci_new = _compute_confidence_ci(log_prob, response_mask, epsilon)
        rewards_scalar = token_level_rewards.sum(dim=-1)
        # One advantage value per sample, read from the first token position.
        advantages_scalar = advantages[:, 0]

        # Keep the confidence strictly inside (0, 1) so both logs stay finite.
        ci_clamped = torch.clamp(ci_new, min=epsilon, max=1.0 - epsilon)

        policy_term = torch.log(ci_clamped) * advantages_scalar

        positive_bce_term_unscaled = rewards_scalar * torch.log(ci_clamped)
        negative_bce_term_unscaled = (1 - rewards_scalar) * torch.log(1 - ci_clamped)
        bce_term = beta * (positive_bce_term_unscaled + negative_bce_term_unscaled)

        positive_bce_part = positive_bce_term_unscaled.mean()
        negative_bce_part = negative_bce_term_unscaled.mean()
        regularization_loss = -bce_term.mean()

        objective = policy_term + bce_term
        pg_loss = -objective.mean()

    else:
        raise ValueError(f"Unknown algorithm_name: '{algorithm_name}'. "
                         "Choices are: 'grpo', 'ccpo_bce'.")

    return pg_loss, pg_clipfrac, approx_kl, pg_clipfrac_lower, regularization_loss, positive_bce_part, negative_bce_part





def compute_entropy_loss(logits, response_mask):
    """Return the masked mean of the entropy derived from ``logits``.

    Args:
        logits: Passed to ``verl_F.entropy_from_logits`` (presumably yields one
            entropy value per position - confirm against that helper).
        response_mask: Mask passed to ``verl_F.masked_mean``.

    Returns:
        The masked-mean entropy.
    """
    per_position_entropy = verl_F.entropy_from_logits(logits)
    return verl_F.masked_mean(per_position_entropy, mask=response_mask)


def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
    """Clipped value-function loss.

    The new prediction is clipped (via ``verl_F.clip_by_value``) to within
    ``cliprange_value`` of the old ``values``. The loss takes the larger of the
    clipped and unclipped squared errors at each position.

    Args:
        vpreds: Current value predictions.
        returns: Regression targets.
        values: Old value predictions that centre the clip range.
        response_mask: Mask passed to ``verl_F.masked_mean``.
        cliprange_value: Half-width of the clip range around ``values``.

    Returns:
        ``(vf_loss, vf_clipfrac)``: half the masked mean of the element-wise
        maximum squared error, and the masked fraction of positions where the
        clipped error exceeds the unclipped one.
    """
    lower_bound = values - cliprange_value
    upper_bound = values + cliprange_value
    clipped_preds = verl_F.clip_by_value(vpreds, lower_bound, upper_bound)

    unclipped_sq_err = (vpreds - returns)**2
    clipped_sq_err = (clipped_preds - returns)**2

    worst_case = torch.max(unclipped_sq_err, clipped_sq_err)
    vf_loss = 0.5 * verl_F.masked_mean(worst_case, response_mask)

    clip_active = torch.gt(clipped_sq_err, unclipped_sq_err).float()
    vf_clipfrac = verl_F.masked_mean(clip_active, response_mask)
    return vf_loss, vf_clipfrac


def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Element-wise KL-style penalty between a policy and a reference.

    Args:
        logprob: Log-probabilities under the policy.
        ref_logprob: Log-probabilities under the reference.
        kl_penalty: Estimator name:

            * ``"kl"``: ``logprob - ref_logprob``
            * ``"abs"``: absolute value of that difference
            * ``"mse"``: half the squared difference
            * ``"low_var_kl"``: ``exp(d) - d - 1`` with
              ``d = ref_logprob - logprob``, clamped to ``[-10, 10]``

    Returns:
        The penalty, element-wise.

    Raises:
        NotImplementedError: For ``"full"`` and for any unrecognised name.
    """
    mode = kl_penalty

    if mode == "kl":
        return logprob - ref_logprob
    elif mode == "abs":
        return (logprob - ref_logprob).abs()
    elif mode == "mse":
        return 0.5 * (logprob - ref_logprob).square()
    elif mode == 'low_var_kl':
        log_ratio = ref_logprob - logprob
        estimate = (torch.exp(log_ratio) - log_ratio - 1).contiguous()
        return torch.clamp(estimate, min=-10, max=10)

    # "full" is a recognised name without an implementation; it fails the same
    # way as an unknown name.
    raise NotImplementedError
