import math
from typing import Any, Dict

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from . import register_criterion


def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
    reduction: str = "mean",
):
    """Compute cross-entropy between ``logits`` and integer ``labels``.

    Thin pass-through to :func:`torch.nn.functional.cross_entropy`.

    Args:
        logits: Unnormalized class scores, forwarded unchanged.
        labels: Target class indices, forwarded unchanged.
        ignore_index: Target value that contributes nothing to the loss.
        reduction: Reduction mode (``"mean"``, ``"sum"`` or ``"none"``).

    Returns:
        Whatever ``F.cross_entropy`` returns for these arguments.
    """
    return F.cross_entropy(
        logits,
        labels,
        ignore_index=ignore_index,
        reduction=reduction,
    )


def accuracy_prediction(
    logits: torch.Tensor,
    labels: torch.Tensor,
):
    """Count positions where the top-1 prediction equals the label.

    The arg-max over the last dimension of ``logits`` is flattened and
    compared element-wise with the flattened ``labels``. Predicted indices
    are never negative, so a negative label (e.g. an ignore value of -100)
    always counts as a miss.

    Returns:
        A 0-dim integer tensor holding the number of matches (a count, not
        a ratio).
    """
    with torch.no_grad():
        top1 = logits.topk(1, -1).indices
        hits = top1.view(-1) == labels.view(-1)
    return hits.sum()


@register_criterion("cross_entropy")
class CrossEntropyCriterion(_Loss):
    """Cross-entropy criterion over per-position logits.

    Registered under the name ``"cross_entropy"``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, model: torch.nn.Module, batch: Dict[str, torch.Tensor]):
        """Run ``model`` on ``batch`` and compute the cross-entropy loss.

        Args:
            model: Called as ``model(tokens=batch["input_ids"])``. Its return
                value is used directly as the logits, with the class dimension
                last; it is flattened to ``(-1, num_classes)``.
            batch: Mapping that must contain ``"input_ids"`` and ``"labels"``.
                Label positions equal to ``-100`` (``cross_entropy_loss``'s
                default ignore index) are excluded from the loss.

        Returns:
            Dict with ``"loss"`` (scalar tensor, mean over non-ignored
            positions) and ``"tokens_num"`` (total number of label positions,
            ignored ones included).
        """
        logits = model(tokens=batch["input_ids"])

        # ``reshape`` instead of ``view``: model outputs or labels may be
        # non-contiguous (e.g. after a transpose), which ``view`` rejects.
        logits = logits.reshape(-1, logits.size(-1))
        labels = batch["labels"].reshape(-1)
        # NOTE(review): counts every label position, including ignored (-100)
        # ones; confirm callers don't use this as the loss denominator.
        tokens_num = labels.numel()

        loss = cross_entropy_loss(logits, labels)

        return {
            "loss": loss,
            "tokens_num": tokens_num,
        }

    def logging_outputs(self, outputs: Dict[str, Any]):
        """Format the loss and its perplexity for logging.

        Args:
            outputs: Must contain ``"loss"``, either a Python number or a
                single-element tensor such as the one returned by ``forward``.

        Returns:
            Dict of strings with ``"loss"`` and ``"ppl"`` (``exp(loss)``), both
            rounded to two decimal places. ``"ppl"`` is ``"inf"`` when
            ``exp`` overflows.
        """
        loss = float(outputs["loss"])
        try:
            ppl = math.exp(loss)
        except OverflowError:
            # A diverged run can push loss past ~709; report inf, don't crash.
            ppl = float("inf")
        return {
            "loss": "{:.2f}".format(loss),
            "ppl": "{:.2f}".format(ppl),
        }