from spaghettini import quick_register

import torch
from torch import nn
from torch.nn import CrossEntropyLoss


@quick_register
class LabelSmoothedCrossEntropy(nn.Module):
    """Cross entropy loss with uniform label smoothing.

    The loss is a weighted sum of two terms:

    * the standard cross entropy against the hard labels, and
    * the cross entropy between the uniform distribution over the classes
      and the predicted distribution (the "smoothing" term).

    Args:
        alpha: Weight applied to the smoothing term.
        downweight_ce: If True, the hard cross entropy is multiplied by
            ``1 - alpha``; if False, it is used with weight 1.
    """

    def __init__(self, alpha, downweight_ce=True):
        super().__init__()
        self.alpha = alpha
        self.downweight_ce = downweight_ce  # Whether the original cross entropy loss is multiplied by (1 - alpha).
        self.ce = CrossEntropyLoss()

    def forward(self, input, target):
        """Compute the label smoothed cross entropy.

        Args:
            input: Unnormalized scores (logits) with the classes along
                dimension 1, i.e. of shape ``(N, C)`` or ``(N, C, d1, ...)``.
            target: Labels, in any form accepted by ``CrossEntropyLoss`` for
                the given ``input``.

        Returns:
            A scalar tensor: both terms are averaged over every non-class
            position before being combined.

        Raises:
            ValueError: If ``input`` has fewer than two dimensions.
        """
        # Explicit check rather than `assert`, so it survives `python -O`.
        if input.dim() < 2:
            raise ValueError(
                "expected input with at least 2 dimensions (N, C, ...), "
                f"got shape {tuple(input.shape)}"
            )

        # Classes live on dimension 1, matching the CrossEntropyLoss layout.
        class_dim = 1
        num_classes = input.shape[class_dim]

        # ____ Compute the standard cross entropy loss. ____
        hard_ce = self.ce(input=input, target=target)

        # ____ Compute the smoothing term. ____
        # Take the log-softmax of the logits over the class dimension.
        log_ps = torch.log_softmax(input, dim=class_dim)

        # Cross entropy between the uniform distribution and the predictions,
        # averaged over all remaining (batch / spatial) positions.
        smoothing = ((-1 / num_classes) * torch.sum(log_ps, dim=class_dim)).mean()

        # ____ Combine and return. ____
        if self.downweight_ce:
            smoothed_loss = (1.0 - self.alpha) * hard_ce + self.alpha * smoothing
        else:
            smoothed_loss = hard_ce + self.alpha * smoothing

        return smoothed_loss


if __name__ == "__main__":
    # Manual smoke test. Run from the repository root:
    #   python -m src.dl.losses.label_smoothed_cross_entropy
    selected_test = 0

    if selected_test == 0:
        # Class indices of the labels, one per batch row.
        num_classes = 5
        target = torch.tensor([2, 3])

        # Unnormalized scores; the exact values are arbitrary. Start from a
        # 0..9 ramp, push one entry far down, and flatten the second row.
        logits = torch.arange(0, 2 * num_classes).view(2, num_classes).float()
        logits[0, 0] = -1000
        logits[1] = 1.

        # Label smoothing coefficient.
        alpha = 0.1

        # Evaluate the label smoothed cross entropy on the toy batch.
        criterion = LabelSmoothedCrossEntropy(alpha=alpha)
        loss = criterion(input=logits, target=target)
        print(f"Loss is: {loss}")
