import einops
import torch


class CrossEntropyLossExtraBatch(torch.nn.Module):
    """Cross-entropy loss for inputs that carry an extra sample dimension.

    ``torch.nn.CrossEntropyLoss`` takes logits of shape ``(N, C)`` and class
    indices of shape ``(N,)``. This wrapper accepts logits of shape
    ``(batch_size, num_samples, num_classes)`` and targets of shape
    ``(batch_size, num_samples)``, folds the two leading dimensions into one,
    and delegates to ``CrossEntropyLoss``.

    Args:
        label_smoothing: Label-smoothing amount forwarded to
            ``CrossEntropyLoss``. Defaults to ``0.0`` (no smoothing), which is
            also PyTorch's default.
        **loss_kwargs: Extra keyword arguments forwarded verbatim to
            ``CrossEntropyLoss`` (e.g. ``weight``, ``ignore_index``,
            ``reduction``).
    """

    def __init__(self, label_smoothing: float = 0.0, **loss_kwargs):
        super().__init__()

        self.loss = torch.nn.CrossEntropyLoss(
            label_smoothing=label_smoothing, **loss_kwargs
        )

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the cross-entropy loss over every (batch, sample) pair.

        Compared to the original CrossEntropyLoss, accepts
        ``(batch_size, num_samples)`` as the batch.

        Args:
            input: Logits of shape ``(batch_size, num_samples, num_classes)``.
            target: Class indices of shape ``(batch_size, num_samples)``.

        Returns:
            The loss, reduced as configured on ``self.loss``. With
            ``reduction="none"`` the per-element losses are reshaped back to
            ``(batch_size, num_samples)``.

        Raises:
            ValueError: If ``input`` is not 3-D, ``target`` is not 2-D, or
                their leading ``(batch_size, num_samples)`` dimensions differ.
        """
        if input.dim() != 3:
            raise ValueError(
                "expected input of shape (batch_size, num_samples, num_classes), "
                f"got {tuple(input.shape)}"
            )
        if target.dim() != 2:
            raise ValueError(
                "expected target of shape (batch_size, num_samples), "
                f"got {tuple(target.shape)}"
            )

        batch_size, num_samples, _ = input.shape
        # Compare the full leading shape, not just its product: e.g. (2, 3)
        # and (3, 2) both flatten to 6 rows, which would silently pair logits
        # with the wrong targets.
        if target.shape != input.shape[:2]:
            raise ValueError(
                f"target shape {tuple(target.shape)} does not match the leading "
                f"(batch_size, num_samples) dims of input {tuple(input.shape)}"
            )

        flat_input = input.flatten(0, 1)  # (batch_size * num_samples, num_classes)
        flat_target = target.flatten()  # (batch_size * num_samples,)
        loss = self.loss(flat_input, flat_target)

        # Unreduced losses come back flat; restore the caller's layout.
        if self.loss.reduction == "none":
            loss = loss.reshape(batch_size, num_samples)
        return loss



