import torch
import torch.nn.functional as F


# Lookup table from a criterion name to the loss *class* implementing it.
# Values are classes, not instances, so they must be instantiated before use.
kd_criterion_dict = {'mse': torch.nn.MSELoss}

def calc_class_probs(outputs, target_seq, num_classes):
    """Return length-normalised sequence log-likelihoods grouped per class.

    For every target sequence the per-token log-probabilities of the target
    ids are summed and divided by the number of non-ignored tokens; the
    resulting flat vector is then reshaped to ``(-1, num_classes)``.

    Args:
        outputs: Object exposing a ``logits`` tensor whose last dimension
            indexes the candidate token ids. Its leading dimensions must
            flatten to the same number of positions as ``target_seq``.
        target_seq: Integer tensor of target ids; its last dimension is the
            sequence axis. Positions equal to ``-100`` are ignored.
        num_classes: Size of the last dimension of the result. The number of
            sequences must be divisible by it.

    Returns:
        Tensor of shape ``(-1, num_classes)`` holding, per sequence, the mean
        log-probability of its non-ignored target tokens.

    Note:
        A sequence whose targets are all ``-100`` has length 0, so its entry
        is ``0 / 0`` and comes out as NaN.
    """
    vocab_size = outputs.logits.size(-1)
    seq_len = target_seq.size(-1)

    # ``reshape`` instead of ``view``: identical for contiguous tensors, but
    # also accepts non-contiguous inputs (e.g. permuted or sliced logits).
    token_log_probs = -F.cross_entropy(
        outputs.logits.reshape(-1, vocab_size),
        target_seq.reshape(-1), ignore_index=-100, reduction='none'
    )

    # Ignored positions contribute 0 to the per-sequence sum.
    log_probs = token_log_probs.reshape(-1, seq_len).sum(dim=-1)

    # Flatten the lengths so they line up with the flat ``log_probs`` vector
    # even when ``target_seq`` carries extra leading dimensions.
    seq_lengths = ((target_seq != -100).sum(dim=-1) * 1.0).reshape(-1)

    # In-place division keeps the dtype of ``log_probs``.
    log_probs /= seq_lengths
    return log_probs.view(-1, num_classes)

def calc_task_loss(logits, targets, reduction='mean', class_weights=None):
    """Compute the cross-entropy task loss between ``logits`` and ``targets``.

    Args:
        logits: Prediction scores handed to ``F.cross_entropy`` as its input.
        targets: Targets handed to ``F.cross_entropy``; must have the same
            length (first dimension) as ``logits``.
        reduction: Reduction mode forwarded to ``F.cross_entropy``.
        class_weights: Optional per-class weights forwarded as ``weight``.

    Returns:
        The loss tensor produced by ``F.cross_entropy``.

    Raises:
        AssertionError: If ``logits`` and ``targets`` differ in length.
    """
    num_logits, num_targets = len(logits), len(targets)
    assert num_logits == num_targets

    loss = F.cross_entropy(
        logits,
        targets,
        reduction=reduction,
        weight=class_weights,
    )
    return loss

class LabelSmoothingLoss(torch.nn.Module):
    """Cross-entropy loss with label smoothing.

    The result blends two terms: the mean negative log-probability over all
    classes (scaled by ``smoothing``) and the negative log-likelihood of the
    target class (scaled by ``1 - smoothing``).

    ``weight``, when given, is passed only to the NLL term; the all-classes
    term is computed without it.
    """

    def __init__(self, smoothing: float = 0.1, 
                 reduction="mean", weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction
        self.weight = weight

    def reduce_loss(self, loss):
        """Apply ``self.reduction`` ('mean', 'sum', anything else = none)."""
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    def linear_combination(self, x, y):
        """Return ``smoothing * x + (1 - smoothing) * y``."""
        eps = self.smoothing
        return eps * x + (1 - eps) * y

    def forward(self, preds, target):
        """Compute the smoothed loss for ``preds`` against ``target``.

        Args:
            preds: Unnormalised scores; the last dimension is treated as the
                class axis.
            target: Target class indices, as accepted by ``F.nll_loss``.

        Returns:
            The loss, reduced according to ``self.reduction``.

        Raises:
            AssertionError: If ``smoothing`` is outside ``[0, 1)``.
        """
        assert 0 <= self.smoothing < 1

        # Keep the class weights on the same device as the predictions; the
        # moved tensor is stored back on the module.
        if self.weight is not None:
            self.weight = self.weight.to(preds.device)

        log_preds = F.log_softmax(preds, dim=-1)
        num_classes = preds.size(-1)

        # Term 1: negative log-probability summed over classes, reduced, then
        # averaged over the number of classes.
        uniform_term = self.reduce_loss(-log_preds.sum(dim=-1)) / num_classes

        # Term 2: standard (optionally class-weighted) NLL of the target.
        target_term = F.nll_loss(
            log_preds, target, weight=self.weight, reduction=self.reduction
        )
        return self.linear_combination(uniform_term, target_term)