              
                                                      
                                                                 

from contextlib import nullcontext, contextmanager
from copy import deepcopy
from enum import Enum
from typing import Literal, Optional, Tuple, Union, Dict, List, Callable, Iterator
import logging
import itertools

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import Tensor
from einops import rearrange

from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core import mpu

from gpatch.core.aligner_helper import masked_mean
from gpatch.core.device_type import is_wxacc1
from gpatch.core.tensor_parallel.mappings import all_gather_to_context_parallel_region


def calculate_advantages_and_returns(
    values,
    rewards,
    discount_factor,
    gae_lambda,
    mask=None,
    per_token_rewards=None,
    per_token_rewards_factor=1.0,
):
    """calculate the per token advantages and returns for the entire sequence

    Args:
        values, rewards (torch.Tensor): shape of B x (S-1)
    """
    assert mask is not None
    if mask is not None:
                                                                                              
        values = values * mask
        rewards = rewards * mask

    last_gae_lam = 0
    advantages = torch.zeros_like(rewards)
    max_seq_len = values.size(-1)

    for i in reversed(range(max_seq_len)):
        if i == max_seq_len - 1:
            next_values = 0.0                                    
        else:
            next_values = values[:, i + 1]                                 
        delta = rewards[:, i] + discount_factor * next_values - values[:, i]
        last_gae_lam = delta + discount_factor * gae_lambda * last_gae_lam
        advantages[:, i] = last_gae_lam

    if per_token_rewards is not None:
        advantages += (per_token_rewards * per_token_rewards_factor)

    returns = advantages + values
    return advantages, returns


def calculate_grpo_advantages(
    rewards: List[torch.Tensor],
    mask: List[torch.Tensor],
    grpo_sampling_times=1,
    grpo_advantage_epsilon=1e-6,
):
    """Compute GRPO group-normalized advantages and broadcast them over each response mask.

    Consecutive runs of ``grpo_sampling_times`` rewards form one group. Each scalar reward
    is normalized as ``(r - group_mean) / (group_std + grpo_advantage_epsilon)``. The result
    is tiled to the length of the sample's 1-D mask and multiplied by that mask.

    Args:
        rewards: one single-element tensor per sample; ``len(rewards)`` must be a
            multiple of ``grpo_sampling_times`` (otherwise ``view`` raises).
        mask: one 1-D mask per sample, aligned with ``rewards``.
        grpo_sampling_times: number of samples per group. A group of 1 yields zero
            advantages (there is no spread to normalize against).
        grpo_advantage_epsilon: added to the group std to avoid division by zero.

    Returns:
        ``(advantages, returns)``: the same list object returned twice (advantages and
        returns are identical here). Each entry has its mask's shape.
    """
    for reward in rewards:
        assert reward.numel() == 1
    scores = torch.stack([reward.sum() for reward in rewards]).view(-1)

    grouped_scores = scores.view(-1, grpo_sampling_times)
    mean_grouped_rewards = grouped_scores.mean(dim=1)
    if grpo_sampling_times > 1:
        std_grouped_rewards = grouped_scores.std(dim=1)
    else:
        # The unbiased std of a single sample is NaN (zero degrees of freedom), which
        # would turn every advantage into NaN. A singleton group has no spread, so use 0;
        # the numerator is then exactly 0 and the advantages come out as 0.
        std_grouped_rewards = torch.zeros_like(mean_grouped_rewards)
    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(grpo_sampling_times, dim=0)
    std_grouped_rewards = std_grouped_rewards.repeat_interleave(grpo_sampling_times, dim=0)
    advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + grpo_advantage_epsilon)

    advantages_mask = []
    for advantage, m in zip(advantages.chunk(len(rewards)), mask):
        assert m.ndim == 1
        # `advantage` has shape [1]; repeat it across every mask position.
        advantages_mask.append(advantage.tile([m.shape[-1]]) * m)

    return advantages_mask, advantages_mask


def calculate_entropy(log_probs, mask=None):
    """Return the mean entropy of the distributions encoded by ``log_probs``.

    Entropy is ``-sum(p * log p)`` over the last dim. It is then averaged over all
    remaining positions, or via ``masked_mean`` when ``mask`` is given.

    Args:
        log_probs (torch.Tensor): Tensor of log probs with shape [B x S x V]
        mask (torch.Tensor): Tensor of masks on the sequence length with shape B x S
    """
    probs = log_probs.exp()
    per_position_entropy = -(probs * log_probs).sum(dim=-1)
    if mask is None:
        return per_position_entropy.mean()
    return masked_mean(per_position_entropy, mask)


def calculate_ppo_rewards(values,
                          rewards,
                          per_token_rewards,
                          sequence_lengths,
                          init_policy_kl,
                          penalty_factor=0.0):
    """Build a dense per-token reward tensor with the KL penalty subtracted.

    The reward should be defined on the last valid action. Each sample's scalar reward
    is placed at column ``sequence_lengths - 2`` (clamped to >= 0) of a zero tensor
    shaped like ``values``. Optional ``per_token_rewards`` are added, and
    ``penalty_factor * init_policy_kl`` is subtracted.
    """
    dense_rewards = torch.zeros_like(values)
    reward_columns = torch.clamp(sequence_lengths - 2, min=0)
    row_ids = torch.arange(dense_rewards.size(0))
    dense_rewards[row_ids, reward_columns] = rewards.flatten()

    if per_token_rewards is not None:
        dense_rewards += per_token_rewards

    kl_term = penalty_factor * init_policy_kl
    return dense_rewards - kl_term


def calculate_kl_penalty(log_probs_a: List[torch.Tensor],
                         log_probs_b: List[torch.Tensor],
                         use_absolute_kl=True):
    """Return per-token log-prob differences ``a - b`` for each paired sample.

    The difference is the per-token log-ratio estimate of KL. It is made absolute
    when ``use_absolute_kl`` is set. Pairs are formed with ``zip``, so extra
    entries in the longer list are ignored.
    """
    def _log_ratio(lp_a, lp_b):
        diff = lp_a - lp_b
        return diff.abs() if use_absolute_kl else diff

    return [_log_ratio(lp_a, lp_b) for lp_a, lp_b in zip(log_probs_a, log_probs_b)]


def calculate_kl_loss(cur_log_probs,
                      ref_log_probs,
                      use_absolute_kl=True,
                      use_low_var_kl=False,
                      clamp_kl_loss=False,
                      clamp_kl_val=None):
    """Per-token KL estimate between the current and reference policies.

    Two estimators are supported:

    * ``use_low_var_kl=True``: with ``d = ref - cur`` (optionally clamped to
      ``[-clamp_kl_val, clamp_kl_val]``), returns ``exp(d) - d - 1``. The result is
      optionally clamped to ``[-10, 10]`` when ``clamp_kl_loss`` is set.
    * otherwise: the plain log-ratio ``cur - ref``.

    Previously the second branch was missing, so the default call raised
    ``UnboundLocalError``.

    Args:
        cur_log_probs: log-probs under the current policy.
        ref_log_probs: log-probs under the reference policy (same shape).
        use_absolute_kl: take ``abs`` of the estimate.
        use_low_var_kl: select the low-variance estimator.
        clamp_kl_loss: clamp the low-variance estimate to ``[-10, 10]``.
        clamp_kl_val: optional bound on ``d`` before exponentiation (low-variance only).

    Returns:
        Tensor with the broadcast shape of the inputs.
    """
    kl = ref_log_probs - cur_log_probs

    if use_low_var_kl:
        if clamp_kl_val is not None:
            # Bound the log-ratio so exp() cannot overflow.
            kl = torch.clamp(kl, min=-clamp_kl_val, max=clamp_kl_val)
        ratio = torch.exp(kl)
        kl_loss = (ratio - kl - 1).contiguous()
        if clamp_kl_loss:
            kl_loss = torch.clamp(kl_loss, min=-10, max=10)
    else:
        # Plain log-ratio estimator log(pi_cur / pi_ref).
        kl_loss = cur_log_probs - ref_log_probs

    if use_absolute_kl:
        kl_loss = kl_loss.abs()

    return kl_loss


def create_mask(values: List[torch.Tensor],
                prompt_lengths: List[torch.Tensor],
                sequence_lengths: List[torch.Tensor],
                dtype=None):
    """Build per-sample 0/1 masks that keep only the span between prompt and sequence end.

    For sample ``k`` the mask has the shape of ``values[k]``. Along its first dimension it
    is 1 on ``[prompt_lengths[k] - 1, sequence_lengths[k] - 1)`` and 0 elsewhere. This
    removes the prompt tokens and the padding at the end of the sequence.

    Args:
        values: per-sample tensors whose shape (and, when ``dtype`` is None, dtype) the
            masks copy.
        prompt_lengths: per-sample prompt lengths.
        sequence_lengths: per-sample sequence end positions.
        dtype: mask dtype; ``None`` keeps each value's dtype.

    Returns:
        List of masks, one per zipped sample.
    """
    def _span_mask(value, prompt_len, seq_len):
        span = torch.zeros_like(value, dtype=dtype)
        # NOTE(review): a prompt length of 0 gives a start index of -1, which Python
        # slicing counts from the end; confirm callers never pass 0.
        span[prompt_len - 1:seq_len - 1] = 1.0
        return span

    return [
        _span_mask(value, prompt_len, seq_len)
        for value, prompt_len, seq_len in zip(values, prompt_lengths, sequence_lengths)
    ]


class _VocabParallelEntropy(torch.autograd.Function):
    """Entropy of softmax(logits) when the vocab (last dim) is sharded over tensor-parallel ranks.

    Each rank holds a slice of the last dimension. The max, sum-exp and sum(p * x) terms
    are all-reduced over the tensor-model-parallel group, so every rank in the group
    obtains the same full-vocab entropy ``H = logsumexp(x) - sum_i softmax(x)_i * x_i``.
    """

    @staticmethod
    def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
        """Return the entropy per position, shaped ``vocab_parallel_logits.shape[:-1]``.

        Args:
            vocab_parallel_logits: this rank's logits shard; reductions run over the
                last dim.
        """

        # NOTE(review): defined inside forward, so torch.compile wraps a new function
        # object on every call; presumably dynamo's code-object cache avoids recompiles.
        @torch.compile(dynamic=True)
        def mul_reduce(a, b):
            # sum over the last (vocab) dim, kept for broadcasting
            return (a * b).sum(dim=-1, keepdim=True)

        # Global max over the full vocab, for a numerically stable softmax.
        logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
        dist.all_reduce(logits_max,
                        op=dist.ReduceOp.MAX,
                        group=mpu.get_tensor_model_parallel_group())
        # The subtraction yields a fresh tensor, so exp_ below does not touch the input logits.
        normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max
        normalized_exp_logits = normalized_vocab_parallel_logits.exp_()
        # Global normalizer sum_j exp(x_j - max).
        normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)
        dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group())
        # In place: normalized_exp_logits now holds this rank's slice of the softmax.
        softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits)
        # sum_i p_i * x_i: partial on each rank, summed across the group.
        sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits)
        dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group())
        # H = max + log(sum exp(x - max)) - sum p*x = logsumexp(x) - sum p*x
        entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits
        ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)
        return entropy.squeeze(dim=-1)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Gradient dH/dx_i = -p_i * (x_i - sum_j p_j x_j), scaled by ``grad_output``.

        The saved softmax buffer is reused (in place) as the gradient tensor.
        """
        vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
        # Temporarily shift the saved logits in place to x - sum p*x (undone below).
        # NOTE(review): the sub_/add_ round trip may not restore the logits bit-exactly, and
        # in-place ops bump the version counter; confirm no other autograd node saved them.
        vocab_parallel_logits.sub_(sum_softmax_times_logits)
        softmax_logits.mul_(vocab_parallel_logits)
        softmax_logits.mul_(grad_output.unsqueeze(dim=-1))
        # Restore the saved logits.
        vocab_parallel_logits.add_(sum_softmax_times_logits)
        softmax_logits.mul_(-1)
        return softmax_logits


def vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor, mask=None) -> torch.Tensor:
    """Mean softmax entropy of tensor-parallel (vocab-sharded) logits.

    Per-position entropy comes from ``_VocabParallelEntropy``. When context parallelism
    is active (world size > 1), it is gathered across the context-parallel region first.
    The last index along dim 1 is dropped (presumably the final sequence position;
    TODO confirm). The result is then averaged, or reduced with ``masked_mean`` when a
    mask is given.
    """
    gather_over_cp = mpu.get_context_parallel_world_size() > 1
    per_token_entropy = _VocabParallelEntropy.apply(vocab_parallel_logits)
    if gather_over_cp:
        per_token_entropy = all_gather_to_context_parallel_region(per_token_entropy)
    per_token_entropy = per_token_entropy[:, :-1]

    if mask is None:
        return per_token_entropy.mean()
    return masked_mean(per_token_entropy, mask)
