import torch
import math


class CrossEntropyWithZLoss(torch.nn.Module):
    """Cross entropy plus a logit regularizer ("z-loss").

    The returned loss is ``CrossEntropyLoss(inputs, labels) + z_loss_factor * sum(log2(Z))``,
    where ``Z = sum(exp(logits))`` is taken over the last dimension of ``inputs``.

    Args:
        ignore_index: Label value that the cross-entropy term skips.
        z_loss_factor: Weight applied to the summed ``log2(Z)`` regularizer.
    """

    __constants__ = ["ignore_index", "z_loss_factor"]
    ignore_index: int
    z_loss_factor: float

    def __init__(self, ignore_index=-100, z_loss_factor=1e-4):
        super().__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.z_loss_factor = z_loss_factor
        self.ignore_index = ignore_index

    def forward(self, inputs, labels):
        """Return the cross entropy of ``inputs`` vs. ``labels`` plus the z-loss term.

        This implementation assumes that log(Z) is log(sum(exp(logits))) and keeps
        the base-2 logarithm of the original code.

        The regularizer is summed over every row of ``inputs``; unlike the
        cross-entropy term it is neither averaged nor masked by ``ignore_index``.

        NOTE(review): the PaLM paper states its auxiliary loss as
        ``z_loss * log^2(Z)``; this module deliberately keeps the un-squared
        form it always had - confirm which variant is intended.
        """
        # logsumexp subtracts the row maximum internally, so large logits no
        # longer overflow exp() to inf as the naive exp().sum().log2() did.
        # Dividing by ln(2) converts the natural log into log2.
        log2_z = torch.logsumexp(inputs, dim=-1) / math.log(2.0)
        z_reg = log2_z.sum() * self.z_loss_factor
        return self.loss_fn(inputs, labels) + z_reg


class MSELoss(torch.nn.Module):
    """Squared-error loss that can be called like a cross-entropy loss.

    Integer labels are expanded into one-hot rows whose hot entry is ``M = sqrt(num_classes)``.
    The squared differences are summed over all non-ignored examples and all
    classes, then scaled by ``1 / (2 * M * num_classes)``. The sum is not
    divided by the number of examples.
    """

    def __init__(self, ignore_index=-100):
        """Store the label value to skip.

        Parameters loosely follow Hui&Belkin, 2021, with k=1 and M=sqrt(C)
        (so maybe not really Hui&Belkin?).
        """
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, inputs, labels):
        """Return the scaled sum of squared errors over entries whose label is not ignored.

        ``inputs`` carries the class scores in its last dimension; ``labels``
        indexes the leading dimensions of ``inputs``.
        """
        class_count = inputs.shape[-1]
        hot_value = math.sqrt(class_count)
        keep = labels != self.ignore_index

        # Boolean-mask indexing flattens the kept entries into (n_kept, class_count).
        targets = self._label_to_onehot(labels[keep], hot_value, num_classes=class_count)
        squared_error = (inputs[keep] - targets).pow(2).sum()
        return 1 / (2 * hot_value * class_count) * squared_error

    @staticmethod
    @torch.jit.script
    def _label_to_onehot(target, M: float = 1.0, num_classes: int = 100):
        # Start from zeros on the labels' device, then write M at each label's column.
        encoded = torch.zeros(target.shape[0], num_classes, device=target.device)
        return encoded.scatter_(1, target.view(-1, 1), M)


class MSELossFast(torch.nn.Module):
    """MSE Loss as a drop-in replacement for Cross Entropy Loss, without materializing one-hot labels.

    Computes the same value as a squared error against one-hot rows whose hot
    entry is ``M = sqrt(num_classes)``: the squared differences are summed over
    all non-ignored examples and all classes and scaled by
    ``1 / (2 * M * num_classes)``. The sum is not divided by the number of examples.
    """

    def __init__(self, ignore_index=-100):
        """Parameters as in Hui&Belkin, 2021, but k=1, and M=sqrt(C) (so maybe not really Hui&Belkin?)"""
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, inputs, labels):
        """Return the scaled squared error using the expanded form of ``(x - M * onehot)^2``.

        ``inputs`` carries the class scores in its last dimension and ``labels``
        indexes its leading dimensions. Entries whose label equals
        ``ignore_index`` are dropped.
        """
        num_classes = inputs.shape[-1]
        valid_mask = labels != self.ignore_index
        M = math.sqrt(num_classes)

        # Boolean-mask indexing flattens the kept entries to (n_valid, num_classes) / (n_valid,).
        inputs = inputs[valid_mask]
        labels = labels[valid_mask]
        num_valid = labels.shape[-1]

        # sum((x - M * e_y)^2) = sum(x^2) - 2 * M * sum(x[y]) + n_valid * M^2
        x_i = inputs.pow(2).sum()
        # Build the row index on the labels' device so both index tensors live together.
        rows = torch.arange(num_valid, device=labels.device)
        x_j = inputs[rows, labels].sum()
        return 1 / (2 * M * num_classes) * (x_i - 2 * M * x_j + num_valid * M**2)


class L1Loss(torch.nn.Module):
    """Absolute-error loss that can be called like a cross-entropy loss.

    Integer labels become one-hot rows whose hot entry is ``M = sqrt(num_classes)``.
    The absolute differences are summed over all non-ignored examples and all
    classes, then scaled by ``1 / (M * num_classes)``. The sum is not divided
    by the number of examples.
    """

    def __init__(self, ignore_index=-100):
        """Store the label value to skip.

        Parameters loosely follow Hui&Belkin, 2021, with k=1 and M=sqrt(C)
        (so maybe not really Hui&Belkin?).
        """
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, inputs, labels):
        """Return the scaled sum of absolute errors over entries whose label is not ignored.

        ``inputs`` carries the class scores in its last dimension; ``labels``
        indexes the leading dimensions of ``inputs``.
        """
        n_classes = inputs.shape[-1]
        magnitude = math.sqrt(n_classes)
        selected = labels != self.ignore_index

        # Masking flattens the selected entries into (n_selected, n_classes).
        targets = self._label_to_onehot(labels[selected], magnitude, num_classes=n_classes)
        residual = inputs[selected] - targets
        return 1 / (magnitude * n_classes) * residual.abs().sum()

    @staticmethod
    @torch.jit.script
    def _label_to_onehot(target, M: float = 1.0, num_classes: int = 100):
        # Zero matrix on the labels' device with M written at each label's column.
        dense = torch.zeros(target.shape[0], num_classes, device=target.device)
        return dense.scatter_(1, target.view(-1, 1), M)


class SzegedyLoss(torch.nn.Module):
    """Regression directly back to input embedding. Remove the decoding layer if using this loss.

    As mentioned at https://twitter.com/ChrSzegedy/status/1533322132368728064?t=xz00T1YT3-WiE0id-h3MEA&s=19

    Args:
        embedding_layer: Module that maps label ids to embedding vectors; it is
            used only to build the regression targets.
        ignore_index: Label value whose positions are excluded from the loss.
        overrelaxation: Factor multiplied onto the embedded labels.
    """

    def __init__(self, embedding_layer, ignore_index=-100, overrelaxation=2.0):
        """Overrelax parameter is quite a bit speculative..."""
        super().__init__()
        self.embedding = embedding_layer
        self.ignore_index = ignore_index
        self.overrelaxation = overrelaxation

    def forward(self, inputs, labels):
        """This really just does L2(DNN(embed(x[:,:-1]), 2.0 * stop_gradient(embed(x[:,1:]))) as quoted above

        ``inputs`` must be 2-D. The summed squared error over the non-ignored
        rows is divided by ``labels.shape[-1]`` (the unmasked label count) and
        by the size of the last dimension of ``inputs``.
        """
        # The two-name unpack doubles as the 2-D shape check the original enforced.
        _, embedding_dim = inputs.shape
        valid_mask = labels != self.ignore_index

        inputs = inputs[valid_mask]
        with torch.no_grad():
            # Mask before the lookup: ignored labels (e.g. -100) are not valid
            # embedding indices and would raise if passed to the embedding.
            embedded_labels = self.overrelaxation * self.embedding(labels[valid_mask])

        # NOTE(review): the denominator still counts ignored positions, as it
        # always has - confirm whether the number of valid labels is intended.
        return (inputs - embedded_labels).pow(2).sum() / labels.shape[-1] / embedding_dim
