import math
from typing import Any, Dict

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from . import register_criterion


def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
    reduction: str = "mean",
):
    """Compute the cross-entropy between ``logits`` and ``labels``.

    Delegates directly to ``torch.nn.functional.cross_entropy``.

    Args:
        logits: Unnormalized class scores, passed as the ``input`` argument.
        labels: Target class indices, passed as the ``target`` argument.
        ignore_index: Label value that is forwarded as ``ignore_index`` so
            matching targets do not contribute to the loss.
        reduction: Reduction mode forwarded unchanged (``"mean"`` by default).

    Returns:
        The tensor returned by ``F.cross_entropy`` for the given reduction.
    """
    return F.cross_entropy(
        input=logits,
        target=labels,
        ignore_index=ignore_index,
        reduction=reduction,
    )


def accuracy_prediction(
    logits: torch.Tensor,
    labels: torch.Tensor,
):
    """Count how many top-1 predictions match their labels.

    The highest-scoring index along the last dimension of ``logits`` is
    compared element-wise with the flattened ``labels``. No label masking is
    applied here, so a negative label (such as an ignore marker) can never
    match a predicted index and is simply counted as a miss.

    Args:
        logits: Score tensor; the last dimension is searched for the maximum.
        labels: Target indices; must flatten to the same number of elements
            as there are predictions.

    Returns:
        A scalar tensor holding the number of matches (a count, not a ratio).
    """
    with torch.no_grad():
        top1_indices = logits.topk(k=1, dim=-1).indices
        flat_predictions = top1_indices.view(-1)
        flat_labels = labels.view(-1)
        hits = torch.eq(flat_predictions, flat_labels)
    return hits.sum()


@register_criterion("colm_cross_entropy")
class CoLMCrossEntropyCriterion(_Loss):
    """Cross-entropy criterion averaged over several model outputs.

    Registered under the name ``"colm_cross_entropy"``. The model call is
    expected to return a sized iterable of logits tensors; each one is scored
    against the same labels and the per-output losses are averaged with equal
    weight.
    """

    def __init__(self):
        super().__init__()

    def forward(self, model: torch.nn.Module, batch: Dict[str, torch.Tensor]):
        """Run the model on a batch and compute the averaged cross-entropy.

        Args:
            model: Module called as ``model(tokens=batch["input_ids"])``. Its
                return value must support ``len()`` and iteration, yielding
                one logits tensor per output.
            batch: Mapping that provides ``"input_ids"`` and ``"labels"``.

        Returns:
            A dict with:
              * ``"loss"``: mean of the per-output losses,
              * ``"tokens_num"``: total number of label positions,
              * ``"loss_<i>"``: the loss of the i-th output.

        Raises:
            ZeroDivisionError: If the model returns no outputs.
        """
        outputs = model(tokens=batch["input_ids"])
        labels = batch["labels"].view(-1)

        results = {
            # Placeholder keeps "loss" as the first key; overwritten below.
            "loss": 0.0,
            # Counts every label position, including any that equal the
            # ignore index used by the loss.
            "tokens_num": labels.numel(),
        }

        total_loss = 0.0
        for i, logits in enumerate(outputs):
            # Flatten all leading dimensions so rows line up with ``labels``.
            flat_logits = logits.view(-1, logits.size(-1))
            loss_i = cross_entropy_loss(flat_logits, labels)
            results[f"loss_{i}"] = loss_i
            # Out-of-place accumulation: never mutates a tensor that is part
            # of the autograd graph.
            total_loss = total_loss + loss_i

        results["loss"] = total_loss * (1.0 / len(outputs))
        return results

    def logging_outputs(self, outputs: Dict[str, Any]):
        """Format the criterion outputs as strings for logging.

        Args:
            outputs: Mapping with a ``"loss"`` entry and any number of
                ``"loss_<i>"`` entries. Values may be Python numbers or
                single-element tensors.

        Returns:
            A dict of strings: ``"loss"`` and every ``"loss_*"`` entry with
            six decimals, and ``"ppl"`` (``exp(loss)``) with two decimals.
        """
        loss = float(outputs["loss"])
        try:
            ppl = math.exp(loss)
        except OverflowError:
            # A diverged (very large) loss must not crash the logging path.
            ppl = float("inf")

        logging_output = {
            "loss": f"{loss:.6f}",
            "ppl": f"{ppl:.2f}",
        }
        for key, value in outputs.items():
            if key.startswith("loss_"):
                # float() accepts both tensors and plain numbers.
                logging_output[key] = f"{float(value):.6f}"

        return logging_output