

import numpy as np
import torch
from collections import defaultdict

import verl.utils.torch_functional as verl_F


class AdaptiveKLController:
    """KL-penalty coefficient that is nudged toward a target KL after every update.

    Proportional rule: the relative KL error ``current / target - 1`` is clipped to
    [-0.2, 0.2] and scaled by ``n_steps / horizon``; ``value`` is multiplied by
    ``1 + that amount``.
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        # Current coefficient, read by callers as ``.value``.
        self.value = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        """Rescale ``self.value`` given the observed KL over ``n_steps`` steps."""
        relative_error = current_kl / self.target - 1
        bounded_error = np.clip(relative_error, -0.2, 0.2)
        self.value *= 1 + bounded_error * n_steps / self.horizon


class FixedKLController:
    """KL-penalty coefficient that stays at its initial value forever."""

    def __init__(self, kl_coef):
        # Exposed under the same attribute name as AdaptiveKLController.
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        """Do nothing; the arguments exist only to match AdaptiveKLController.update."""
        return None


def get_kl_controller(config):
    """Build the KL-coefficient controller described by ``config.critic.kl_ctrl``.

    ``config.critic.kl_ctrl.type`` selects the controller:
      * ``'fixed'``    -> FixedKLController(kl_coef)
      * ``'adaptive'`` -> AdaptiveKLController(kl_coef, target_kl, horizon); horizon must be > 0.

    Raises:
        AssertionError: adaptive controller with a non-positive horizon.
        ValueError: any other ``type``.
    """
    # Every field, including horizon, is read from the same config node.
    kl_cfg = config.critic.kl_ctrl

    if kl_cfg.type == 'fixed':
        return FixedKLController(kl_coef=kl_cfg.kl_coef)

    if kl_cfg.type == 'adaptive':
        assert kl_cfg.horizon > 0, f'horizon must be larger than 0. Got {kl_cfg.horizon}'
        return AdaptiveKLController(init_kl_coef=kl_cfg.kl_coef,
                                    target_kl=kl_cfg.target_kl,
                                    horizon=kl_cfg.horizon)

    raise ValueError(f'Unknown kl_ctrl type: {kl_cfg.type}')


def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor,
                                 gamma: torch.Tensor, lam: torch.Tensor):
    """Generalized Advantage Estimation over the last axis.

    Runs the backward recursion
        delta_t = r_t + gamma * V_{t+1} - V_t      (V beyond the last position is 0)
        A_t     = delta_t + gamma * lam * A_{t+1}
    over every position ``0 .. token_level_rewards.shape[-1] - 1``. Inputs are indexed as
    ``[:, t]``, so they are 2-D (batch, position).

    NOTE(review): the recursion does not consult ``eos_mask``, so rewards/values at padded
    positions feed into earlier advantages; presumably callers zero them — confirm.

    Returns:
        (advantages, returns): ``returns = A + values`` built from the raw advantages, and
        ``advantages`` whitened over ``eos_mask`` via ``verl_F.masked_whiten``.
    """
    with torch.no_grad():
        seq_len = token_level_rewards.shape[-1]
        per_step = [None] * seq_len
        running_gae = 0

        for t in range(seq_len - 1, -1, -1):
            next_value = values[:, t + 1] if t + 1 < seq_len else 0.0
            td_error = token_level_rewards[:, t] + gamma * next_value - values[:, t]
            running_gae = td_error + gamma * lam * running_gae
            per_step[t] = running_gae

        raw_advantages = torch.stack(per_step, dim=1)
        returns = raw_advantages + values
        advantages = verl_F.masked_whiten(raw_advantages, eos_mask)
    return advantages, returns



def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
                                   eos_mask: torch.Tensor,
                                   index: torch.Tensor,
                                   epsilon: float = 1e-6):
    """GRPO outcome advantage: per-group z-score of each sample's summed reward.

    Samples sharing the same ``index[i]`` key form a group. Each sample's score
    (sum of its non-zero token rewards) is normalised as
    ``(score - group_mean) / (group_std + epsilon)``, where ``group_std`` is ``torch.std``
    (its default correction). A group with a single sample uses mean 0 and std 1, so that
    score is only divided by ``1 + epsilon``. The scalar is then broadcast over the last axis
    and multiplied by ``eos_mask``.

    NOTE(review): group keys are used as dict keys, so ``index`` elements must hash by
    value (e.g. a NumPy array of strings/ints), despite the ``torch.Tensor`` annotation.

    Returns:
        (scores, scores): the same tensor object twice (advantages double as returns).
    """
    response_length = token_level_rewards.shape[-1]
    keep = token_level_rewards != 0
    scores = (token_level_rewards * keep).sum(dim=-1)

    members = defaultdict(list)
    group_mean = {}
    group_std = {}

    with torch.no_grad():
        batch_size = scores.shape[0]
        for i in range(batch_size):
            members[index[i]].append(scores[i])

        for uid, group in members.items():
            if not group:
                raise ValueError(f"no score in prompt index: {uid}")
            if len(group) == 1:
                group_mean[uid] = torch.tensor(0.0)
                group_std[uid] = torch.tensor(1.0)
            else:
                group_mean[uid] = torch.mean(torch.tensor(group))
                group_std[uid] = torch.std(torch.tensor([group]))

        for i in range(batch_size):
            uid = index[i]
            scores[i] = (scores[i] - group_mean[uid]) / (group_std[uid] + epsilon)

        scores = scores.unsqueeze(-1).expand(-1, response_length) * eos_mask

    return scores, scores


def _compute_confidence_ci(old_log_prob: torch.Tensor,
                           response_mask: torch.Tensor,
                           epsilon: float = 1e-8) -> torch.Tensor:


    masked_log_prob = old_log_prob * response_mask


    sequence_lengths = response_mask.sum(dim=-1)


    sum_log_prob = masked_log_prob.sum(dim=-1)


    mean_log_prob = sum_log_prob / (sequence_lengths + epsilon)


    ci = torch.exp(mean_log_prob)

    return ci



def compute_advantage_CCPO_BCE(token_level_rewards: torch.Tensor,
                                 old_log_prob: torch.Tensor,
                                 response_mask: torch.Tensor,
                                 index: np.ndarray,
                                 epsilon: float = 1e-6) -> tuple[torch.Tensor, torch.Tensor]:
    """Group-centred outcome advantage scaled by the inverse of (1 - confidence).

    For sample i with outcome score ``s_i = token_level_rewards[i].sum()``, group mean ``m``
    over all samples sharing ``index[i]`` (computed with ``np.mean``), and confidence
    ``c_i = _compute_confidence_ci(old_log_prob, response_mask)[i]``::

        A_i = (s_i - m) / ((1 - c_i) + epsilon)

    ``A_i`` is broadcast over the last axis and multiplied by ``response_mask``.

    NOTE(review): as ``c_i`` approaches 1 the denominator shrinks to ``epsilon``, so very
    confident samples get advantages up to ~1/epsilon times their centred score — confirm
    this is intended.

    Returns:
        (advantages, advantages): the same tensor twice (advantages double as returns).
    """
    outcome = token_level_rewards.sum(dim=-1)
    target_device = outcome.device
    confidence = _compute_confidence_ci(old_log_prob, response_mask)

    with torch.no_grad():
        n_samples = outcome.shape[0]

        members = defaultdict(list)
        for i in range(n_samples):
            members[index[i]].append(outcome[i].item())
        # Means come from NumPy over Python floats, so they are float64 scalars.
        group_mean = {uid: torch.tensor(np.mean(vals), device=target_device)
                      for uid, vals in members.items()}

        scaled = torch.zeros_like(outcome)
        for i in range(n_samples):
            centred = outcome[i] - group_mean[index[i]]
            scaled[i] = centred / ((1.0 - confidence[i]) + epsilon)

    advantages = scaled.unsqueeze(-1) * response_mask
    return advantages, advantages




def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor,
                                                  eos_mask: torch.Tensor,
                                                  gamma: torch.Tensor,
                                                  epsilon: float = 1e-6):
    """REINFORCE++ advantage: whitened discounted reward-to-go.

    Walks positions from last to first, accumulating ``G_t = r_t + gamma * G_{t+1}``.
    After ``G_t`` is stored, the carry is multiplied by ``eos_mask[:, t]``, so a masked
    position passes nothing back to earlier positions. The returns are whitened over
    ``eos_mask`` with ``verl_F.masked_whiten`` and re-masked. ``epsilon`` is unused.

    Returns:
        (advantages, returns)
    """
    with torch.no_grad():
        seq_len = token_level_rewards.shape[1]
        returns = torch.zeros_like(token_level_rewards)
        carry = 0

        for t in range(seq_len - 1, -1, -1):
            carry = token_level_rewards[:, t] + gamma * carry
            returns[:, t] = carry
            carry = carry * eos_mask[:, t]

        advantages = verl_F.masked_whiten(returns, eos_mask) * eos_mask

    return advantages, returns


def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    """Return ``token_level_scores`` minus ``kl_ratio`` times the log-ratio ``old - ref``."""
    penalty = (old_log_prob - ref_log_prob) * kl_ratio
    return token_level_scores - penalty


def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
    """Reduce a per-token loss matrix to a scalar.

    Args:
        loss_mat: per-token losses; the last axis is the token axis (reduced with dim=-1).
        loss_mask: token weights/mask multiplied into ``loss_mat``.
        loss_agg_mode:
            ``"token-mean"``          masked mean over all tokens (``verl_F.masked_mean``);
            ``"seq-mean-token-sum"``  sum tokens per sequence, then mean over sequences;
            ``"seq-mean-token-mean"`` mean tokens per sequence, then mean over sequences.
              A sequence with an all-zero mask contributes 0 here, as it does in
              ``"seq-mean-token-sum"``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: unknown ``loss_agg_mode``.
    """
    if loss_agg_mode == "token-mean":
        return verl_F.masked_mean(loss_mat, loss_mask)

    if loss_agg_mode == "seq-mean-token-sum":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
        return torch.mean(seq_losses)

    if loss_agg_mode == "seq-mean-token-mean":
        seq_sums = torch.sum(loss_mat * loss_mask, dim=-1)
        seq_counts = torch.sum(loss_mask, dim=-1)
        # A fully masked row would otherwise be 0/0 = NaN and poison the batch mean.
        # Only exact zeros are replaced, so every non-empty row (including soft-weighted
        # ones) keeps its original denominator.
        seq_counts = seq_counts.masked_fill(seq_counts == 0, 1)
        return torch.mean(seq_sums / seq_counts)

    raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")



def compute_policy_loss(old_log_prob,
                        log_prob,
                        advantages,
                        response_mask,
                        cliprange=None,
                        cliprange_low=None,
                        cliprange_high=None,
                        clip_ratio_c=3.0,
                        loss_agg_mode="token-mean",
                        ):
    """Dual-clip PPO policy-gradient loss.

    The standard PPO surrogate ``max(-A * ratio, -A * clamp(ratio, 1 - low, 1 + high))`` is
    computed per token. For tokens with ``A < 0`` it is further capped from above by
    ``-A * clip_ratio_c``. ``cliprange_low``/``cliprange_high`` default to ``cliprange``.

    Returns:
        (pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower):
          pg_loss           per-token losses reduced by ``agg_loss(loss_agg_mode)``;
          pg_clipfrac       masked fraction where the clipped surrogate is larger;
          ppo_kl            masked mean of ``old_log_prob - log_prob``;
          pg_clipfrac_lower masked fraction of negative-advantage tokens hitting the dual cap.

    Raises:
        AssertionError: ``clip_ratio_c <= 1.0``.
    """
    assert clip_ratio_c > 1.0, f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}."

    low = cliprange if cliprange_low is None else cliprange_low
    high = cliprange if cliprange_high is None else cliprange_high

    log_ratio = log_prob - old_log_prob
    ratio = torch.exp(log_ratio)
    ppo_kl = verl_F.masked_mean(-log_ratio, response_mask)

    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1 - low, 1 + high)
    ppo_losses = torch.maximum(unclipped, clipped)
    pg_clipfrac = verl_F.masked_mean(torch.gt(clipped, unclipped).float(), response_mask)

    # Dual clip: bound the loss of negative-advantage tokens at -A * clip_ratio_c.
    dual_bound = -advantages * clip_ratio_c
    dual_clipped = torch.min(dual_bound, ppo_losses)
    is_negative = advantages < 0
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(ppo_losses, dual_bound) * is_negative.float(), response_mask)

    pg_losses = torch.where(is_negative, dual_clipped, ppo_losses)
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


def compute_policy_loss_new(
        algorithm_name: str,
        old_log_prob: torch.Tensor,
        log_prob: torch.Tensor,
        advantages: torch.Tensor,
        response_mask: torch.Tensor,

        token_level_rewards: torch.Tensor = None,
        beta: float = 0.01,
        positive_scale: float = 1.0,
        negative_scale: float = 1.0,
        loss_agg_mode: str = "token-mean",
        scale_method: str = "",
        epsilon: float = 1e-8,
        cliprange=None,
        cliprange_low=None,
        cliprange_high=None,
        clip_ratio_c=3.0,
        old_denominators=None
):
    """Compute the actor loss for ``algorithm_name`` in {"grpo", "ccpo_bce"}.

    "grpo": dual-clip PPO surrogate, delegated to ``compute_policy_loss`` with the same
    clip arguments and ``loss_agg_mode``.

    "ccpo_bce": per-sequence objective built from the current-policy confidence
    ``ci = _compute_confidence_ci(log_prob, response_mask, epsilon)``, the scalar
    advantage ``A = advantages[:, 0]`` and the outcome reward
    ``r = token_level_rewards.sum(-1)``::

        policy_term = log(ci) * A
        bce_term    = beta_eff * (positive_scale * r * log(ci)
                                  + negative_scale * (1 - r) * log(1 - ci))
        pg_loss     = -mean(policy_term + bce_term)

    ``ci`` is clamped to [epsilon, 1 - epsilon] and ``1 - ci`` to >= epsilon before the logs.
    With ``scale_method == 'beta_dynamic'``, ``beta_eff`` equals ``beta`` except it is 0 for
    samples where A > 0 and r < ci, or A < 0 and r > ci (the fraction of such samples is
    returned as ``n_conflict``). Otherwise ``beta_eff == beta``. ``loss_agg_mode`` is not used
    by this branch.
    NOTE(review): ``advantages[:, 0]`` assumes the advantage is constant along the response
    and position 0 is unmasked; the BCE form assumes r lies in [0, 1] — confirm both.

    Args:
        old_denominators: unused.

    Returns:
        (pg_loss, pg_clipfrac, approx_kl, pg_clipfrac_lower, regularization_loss,
         positive_bce_part, negative_bce_part, n_conflict). ``approx_kl`` is the masked
        mean of ``old_log_prob - log_prob``. Values not produced by the chosen algorithm
        stay at 0.

    Raises:
        ValueError: unknown ``algorithm_name``, or ``token_level_rewards`` missing for
            "ccpo_bce".
        AssertionError: ``clip_ratio_c <= 1.0`` for "grpo".
    """
    approx_kl = verl_F.masked_mean(old_log_prob - log_prob, response_mask)

    device = log_prob.device
    pg_clipfrac = torch.tensor(0.0, device=device)
    pg_clipfrac_lower = torch.tensor(0.0, device=device)
    regularization_loss = torch.tensor(0.0, device=device)
    positive_bce_part = torch.tensor(0.0, device=device)
    negative_bce_part = torch.tensor(0.0, device=device)
    n_conflict = 0.0

    if algorithm_name == "grpo":
        # Same dual-clip computation as compute_policy_loss; its ppo_kl equals approx_kl.
        pg_loss, pg_clipfrac, _, pg_clipfrac_lower = compute_policy_loss(
            old_log_prob=old_log_prob,
            log_prob=log_prob,
            advantages=advantages,
            response_mask=response_mask,
            cliprange=cliprange,
            cliprange_low=cliprange_low,
            cliprange_high=cliprange_high,
            clip_ratio_c=clip_ratio_c,
            loss_agg_mode=loss_agg_mode,
        )

    elif algorithm_name == "ccpo_bce":
        if token_level_rewards is None:
            raise ValueError("`token_level_rewards` must be provided for 'ccpo_bce'.")

        ci_new = _compute_confidence_ci(log_prob, response_mask, epsilon)
        rewards_scalar = token_level_rewards.sum(dim=-1)
        advantages_scalar = advantages[:, 0]

        if scale_method == 'beta_dynamic':
            with torch.no_grad():
                calibration_error = rewards_scalar - ci_new
                # Advantage sign and calibration-error sign disagree -> drop the BCE term.
                conflict_mask = (((advantages_scalar > 0) & (calibration_error < 0))
                                 | ((advantages_scalar < 0) & (calibration_error > 0)))
                effective_beta = torch.full_like(advantages_scalar, beta)
                effective_beta[conflict_mask] = 0.0
                n_conflict = conflict_mask.sum().item() / rewards_scalar.numel()
        else:
            effective_beta = beta

        ci_clamped = torch.clamp(ci_new, min=epsilon, max=1.0 - epsilon)
        log_ci = torch.log(ci_clamped)
        # In float32/bf16, ``1.0 - epsilon`` rounds to 1.0 for epsilon=1e-8, so the clamp above
        # does not keep ci below 1. When ci == 1.0 (e.g. an all-masked row, where
        # _compute_confidence_ci yields exp(0)), log(1 - ci) would be -inf, and
        # (1 - r) * -inf is NaN for r == 1. Clamping the complement keeps it finite.
        log_one_minus_ci = torch.log(torch.clamp(1 - ci_clamped, min=epsilon))

        policy_term = log_ci * advantages_scalar
        positive_bce_term_unscaled = rewards_scalar * log_ci
        negative_bce_term_unscaled = (1 - rewards_scalar) * log_one_minus_ci

        bce_term = effective_beta * (positive_scale * positive_bce_term_unscaled
                                     + negative_scale * negative_bce_term_unscaled)

        positive_bce_part = positive_bce_term_unscaled.mean()
        negative_bce_part = negative_bce_term_unscaled.mean()
        regularization_loss = -bce_term.mean()

        pg_loss = -(policy_term + bce_term).mean()

    else:
        raise ValueError(f"Unknown algorithm_name: '{algorithm_name}'. "
                         "Choices are: 'grpo', 'ccpo_bce'.")

    return pg_loss, pg_clipfrac, approx_kl, pg_clipfrac_lower, regularization_loss, positive_bce_part, negative_bce_part, n_conflict


def compute_entropy_loss(logits, eos_mask):
    """Mean per-token entropy of ``logits`` over positions selected by ``eos_mask``."""
    return verl_F.masked_mean(verl_F.entropy_from_logits(logits), mask=eos_mask)


def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
    """PPO clipped value loss.

    ``vpreds`` is clipped to ``values +/- cliprange_value``. The per-token loss is the larger
    of the unclipped and clipped squared errors against ``returns``, masked-averaged over
    ``eos_mask`` and halved.

    Returns:
        (vf_loss, vf_clipfrac): the loss and the masked fraction of tokens where the clipped
        error is strictly larger.
    """
    lower = values - cliprange_value
    upper = values + cliprange_value
    vpreds_clipped = verl_F.clip_by_value(vpreds, lower, upper)

    sq_err = (vpreds - returns) ** 2
    sq_err_clipped = (vpreds_clipped - returns) ** 2

    vf_clipfrac = verl_F.masked_mean(torch.gt(sq_err_clipped, sq_err).float(), eos_mask)
    vf_loss = 0.5 * verl_F.masked_mean(torch.max(sq_err, sq_err_clipped), eos_mask)
    return vf_loss, vf_clipfrac


def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Per-token KL penalty between policy and reference log-probs.

    Modes:
        "kl"         logprob - ref_logprob
        "abs"        |logprob - ref_logprob|
        "mse"        0.5 * (logprob - ref_logprob)^2
        "low_var_kl" exp(d) - d - 1 with d = clamp(ref_logprob - logprob, -5, 5),
                     result clamped to [-10, 10]

    Raises:
        NotImplementedError: for "full" and for any unrecognised mode.
    """
    mode = kl_penalty

    if mode in ("kl", "abs", "mse"):
        log_ratio = logprob - ref_logprob
        if mode == "kl":
            return log_ratio
        if mode == "abs":
            return log_ratio.abs()
        return 0.5 * log_ratio.square()

    if mode == 'low_var_kl':
        ref_log_ratio = torch.clamp(ref_logprob - logprob, min=-5, max=5)
        estimate = (torch.exp(ref_log_ratio) - ref_log_ratio - 1).contiguous()
        return torch.clamp(estimate, min=-10, max=10)

    # "full" and unknown modes are both unsupported.
    raise NotImplementedError
