import torch
import torch.nn.functional as F

from typing import Dict
from transformers import AutoModelForCausalLM
from . import register_criterion


def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
    reduction: str = "mean",
):
    """Return the cross-entropy between ``logits`` and ``labels``.

    Thin wrapper around ``torch.nn.functional.cross_entropy``; every argument
    is forwarded unchanged.

    Args:
        logits: Unnormalized class scores, passed as the ``input`` argument.
        labels: Targets, passed as the ``target`` argument.
        ignore_index: Target value forwarded as ``ignore_index`` (positions
            carrying it do not contribute to the loss).
        reduction: Reduction mode forwarded as ``reduction``
            (``"mean"`` by default).

    Returns:
        The tensor produced by ``F.cross_entropy``.
    """
    return F.cross_entropy(
        input=logits,
        target=labels,
        ignore_index=ignore_index,
        reduction=reduction,
    )


def accuracy_prediction(
    logits: torch.Tensor,
    labels: torch.Tensor,
):
    """Count the positions whose top-1 predicted class equals the label.

    Args:
        logits: Class scores; the prediction is the index of the largest
            value along the last dimension.
        labels: Target class indices. After flattening it must have as many
            elements as there are predictions.

    Returns:
        A scalar tensor holding the number of matches. This is a count, not a
        ratio; no division by the number of tokens happens here.
    """
    # The comparison is bookkeeping only, so keep it out of the autograd graph.
    with torch.no_grad():
        top1 = logits.topk(1, -1).indices
        hits = top1.view(-1).eq(labels.view(-1))
    return hits.sum()


@register_criterion("cross_entropy")
def cross_entropy_criterion(model: AutoModelForCausalLM, batch: Dict[str, torch.Tensor]):
    """Compute cross-entropy loss and correct-prediction counts for two outputs.

    ``compute(batch)`` is expected to return a pair of logits tensors. The
    first is scored against ``batch["labels"]`` (reported under the
    ``*_sparse`` keys) and the second against ``batch["standard"]["labels"]``
    (reported under the ``*_std`` keys).

    Args:
        model: Object exposing ``compute(batch)`` either directly or through
            a ``module`` attribute (presumably a DDP-style wrapper; verify
            against the caller).
        batch: Mapping that provides ``"labels"`` and a nested
            ``"standard"`` mapping with its own ``"labels"`` entry.

    Returns:
        A dict with the keys ``loss`` (sum of both losses), ``loss_sparse``,
        ``acc_sparse``, ``tokens_num``, ``loss_std``, ``acc_std`` and
        ``tokens_num_std``. The ``acc_*`` values are counts of correct
        predictions, and the ``tokens_num*`` values are the total number of
        label elements.
    """
    # Call compute() on the model itself when available, else on its wrapped module.
    compute_owner = model if hasattr(model, "compute") else model.module
    logits, logits_std = compute_owner.compute(batch)

    # First output: collapse all leading dimensions so each row is one position.
    flat_logits = logits.view(-1, logits.size(-1))
    flat_labels = batch["labels"].view(-1)
    loss_sparse = cross_entropy_loss(flat_logits, flat_labels)
    acc_sparse = accuracy_prediction(flat_logits, flat_labels)

    # Second output, scored against the labels stored under batch["standard"].
    flat_logits_std = logits_std.view(-1, logits_std.size(-1))
    flat_labels_std = batch["standard"]["labels"].view(-1)
    loss_std = cross_entropy_loss(flat_logits_std, flat_labels_std)
    acc_std = accuracy_prediction(flat_logits_std, flat_labels_std)

    # NOTE(review): numel() also counts positions equal to the loss's
    # ignore_index, so the token counts may exceed the number of positions
    # that contribute to the mean loss -- confirm this is what consumers expect.
    return {
        "loss": loss_sparse + loss_std,
        "loss_sparse": loss_sparse,
        "acc_sparse": acc_sparse,
        "tokens_num": flat_labels.numel(),
        "loss_std": loss_std,
        "acc_std": acc_std,
        "tokens_num_std": flat_labels_std.numel(),
    }


