# (c) 2023 anonymous authors, not to be distributed or used for commercial purposes.

import numpy as np
import torch

# Cross-entropy that returns one loss value per element (reduction="none");
# any masking or averaging is left to the caller.
ce_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
# Softmax taken over the last dimension of its input.
softmax_fn = torch.nn.Softmax(dim=-1)


def _perplexity(encoding, logits, T=1.0):
    """Return the masked mean next-token cross-entropy of each sequence.

    Note that no exponential is applied: the result is the average
    cross-entropy per token (i.e. the log of the perplexity).

    Args:
        encoding: Object exposing ``input_ids`` and ``attention_mask``
            tensors of shape ``(batch, seq_len)``.
        logits: Model scores of shape ``(batch, seq_len, vocab_size)``.
        T: Temperature; the logits are divided by it before the loss.

    Returns:
        A float32 NumPy array of shape ``(batch,)``. A row whose shifted
        attention mask sums to zero is divided by zero and comes out as NaN.
    """
    # Position t predicts token t + 1: drop the last logit and the first label
    # (and the first mask entry, so the mask lines up with the labels).
    next_token_logits = logits[..., :-1, :].contiguous() / T
    next_token_ids = encoding.input_ids[..., 1:].contiguous()
    valid = encoding.attention_mask[..., 1:].contiguous()

    # cross_entropy wants the class dimension at index 1: (batch, vocab, seq).
    token_ce = torch.nn.functional.cross_entropy(
        next_token_logits.transpose(1, 2), next_token_ids, reduction="none"
    )

    # Average only over positions the attention mask marks as real tokens.
    mean_ce = (token_ce * valid).sum(1) / valid.sum(1)

    return mean_ce.to("cpu").float().numpy()


def perplexity(encoding, logits, T=1.0):
    """Compute the per-sequence mean next-token cross-entropy.

    Args:
        encoding: Object exposing ``input_ids`` and ``attention_mask``
            tensors of shape ``(batch, seq_len)``.
        logits: Model scores of shape ``(batch, seq_len, vocab_size)``.
        T: Temperature; the logits are divided by it before the loss.

    Returns:
        A 3-tuple ``(None, None, ppl)``. The first two elements are always
        ``None`` (placeholders that keep the three-value return shape);
        ``ppl`` is a float32 NumPy array of shape ``(batch,)``.
    """
    return None, None, _perplexity(encoding, logits, T=T)


def entropy(
    p_scores,
    q_scores,
    encoding,
    pad_token_id,
    padding_side,
    sample_p=False,
    T=1.0,
):
    """Return the per-sequence mean cross-entropy of ``q`` against ``p``.

    ``p_scores`` are turned into a distribution with a softmax and used as the
    target for the cross-entropy computed from ``q_scores``. Positions whose
    input id equals ``pad_token_id`` are excluded from the average.

    Args:
        p_scores: Scores of shape ``(batch, seq_len, vocab_size)`` defining
            the target distribution.
        q_scores: Scores of the same shape, used as the cross-entropy input.
        encoding: Object exposing an ``input_ids`` tensor of shape
            ``(batch, seq_len)``.
        pad_token_id: Token id treated as padding.
        padding_side: Accepted for interface compatibility only; left- and
            right-padded batches are processed identically.
        sample_p: If true, draw one token per position from the softmax of
            ``p_scores`` and use it as a hard target instead of the full
            distribution.
        T: Temperature; both score tensors are divided by it.

    Returns:
        A float32 NumPy array of shape ``(batch,)``. A row consisting only of
        padding is divided by zero and comes out as NaN.
    """
    vocab_size = p_scores.shape[-1]
    seq_len = q_scores.shape[-2]

    # reshape (rather than view) so non-contiguous score tensors are accepted.
    target = torch.softmax(p_scores / T, dim=-1).reshape(-1, vocab_size)
    q_logits = (q_scores / T).reshape(-1, vocab_size)

    if sample_p:
        # Replace the soft target by one sampled class index per position.
        target = torch.multinomial(
            target, replacement=True, num_samples=1
        ).reshape(-1)

    token_ce = torch.nn.functional.cross_entropy(
        q_logits, target, reduction="none"
    ).reshape(-1, seq_len)

    # 1 for real tokens, 0 for padding.
    padding_mask = (encoding.input_ids != pad_token_id).type(torch.uint8)

    mean_ce = (token_ce * padding_mask).sum(1) / padding_mask.sum(1)

    return mean_ce.to("cpu").float().numpy()
