import torch
from torch import nn
from torch_scatter import scatter_sum


class CST(nn.Module):
    """Regularizer whose forward pass always evaluates to zero.

    Keyword Args:
        alpha: Weight stored on the module (default ``0.1``). It is kept as
            an attribute but is not read by :meth:`forward`.
    """

    def __init__(self, **kargs):
        super().__init__()
        self.alpha = kargs["alpha"] if "alpha" in kargs else 0.1

    def forward(self, logits):
        """Return a new zero-valued scalar tensor.

        Args:
            logits: Accepted so the call signature matches the other
                regularizers in this file; its value is ignored.

        Returns:
            A 0-dim tensor holding ``0.0`` in the default dtype.
        """
        return torch.zeros(())


class NLS(nn.Module):
    """Regularizer returning ``alpha`` times the mean log-probability.

    Keyword Args:
        alpha: Multiplier applied to the mean log-probability
            (default ``0.1``).
    """

    def __init__(self, **kargs):
        super().__init__()
        self.alpha = kargs["alpha"] if "alpha" in kargs else 0.1

    def forward(self, logits):
        """Compute ``alpha * mean(log_softmax(logits))``.

        Args:
            logits: Tensor of unnormalized scores; the softmax is taken over
                the last dimension.

        Returns:
            A scalar tensor. The mean runs over every element of the
            log-probability tensor, not just the last dimension.
        """
        log_p = logits.log_softmax(dim=-1)
        avg_log_p = log_p.mean()
        return self.alpha * avg_log_p


class LS(nn.Module):
    """Regularizer returning ``-alpha`` times the mean log-probability.

    Keyword Args:
        alpha: Multiplier applied (with a sign flip) to the mean
            log-probability (default ``0.1``).
    """

    def __init__(self, **kargs):
        super().__init__()
        self.alpha = kargs["alpha"] if "alpha" in kargs else 0.1

    def forward(self, logits):
        """Compute ``-alpha * mean(log_softmax(logits))``.

        Args:
            logits: Tensor of unnormalized scores; the softmax is taken over
                the last dimension.

        Returns:
            A scalar tensor. The mean runs over every element of the
            log-probability tensor, not just the last dimension.
        """
        log_p = logits.log_softmax(dim=-1)
        weight = -self.alpha
        return weight * log_p.mean()


class CR(nn.Module):
    """Regularizer weighting log-probabilities by the empirical class prior.

    The prior is the normalized label histogram of
    ``dataloader.dataset.targets``. It is registered as a buffer named
    ``prior`` with shape ``(1, C)``, where ``C`` is the largest label plus
    one.

    Keyword Args:
        alpha: Multiplier applied to the result (default ``0.1``).
        dataloader: Required. Any object exposing ``dataset.targets`` as a
            collection of non-negative integer class labels (tensor, list or
            array).

    Raises:
        ValueError: If ``dataloader`` is missing or ``None``.
    """

    def __init__(self, **kargs):
        super().__init__()
        self.alpha = kargs.get("alpha", 0.1)
        dataloader = kargs.get("dataloader")
        if dataloader is None:
            # Explicit raise rather than ``assert`` so the check still runs
            # when Python is started with ``-O``.
            raise ValueError("CR requires a 'dataloader' keyword argument")

        # ``as_tensor`` leaves tensors untouched and also accepts list/array
        # labels; ``bincount`` needs a 1-D integer tensor, hence the flatten.
        targets = torch.as_tensor(dataloader.dataset.targets).long().reshape(-1)
        prior = torch.bincount(targets).float()
        prior /= prior.sum()
        # Leading singleton dim lets the prior broadcast over the batch.
        self.register_buffer("prior", prior.unsqueeze(0))

    def forward(self, logits):
        """Compute ``alpha * mean_batch(sum_c prior[c] * log_softmax(logits)[c])``.

        Args:
            logits: Tensor of unnormalized scores whose last dimension must
                broadcast against the ``C`` entries of the prior.

        Returns:
            A scalar tensor: the prior-weighted log-probability summed over
            the last dimension, averaged over the remaining ones, and scaled
            by ``alpha``.
        """
        lprobs = torch.log_softmax(logits, dim=-1)
        weighted = torch.sum(self.prior * lprobs, dim=-1)
        return self.alpha * weighted.mean()


class RMSE(nn.Module):
    """Regularizer returning ``alpha`` times the mean squared probability.

    Note that, despite the class name, no square root is applied.

    Keyword Args:
        alpha: Multiplier applied to the mean squared probability
            (default ``0.1``).
    """

    def __init__(self, **kargs):
        super().__init__()
        self.alpha = kargs["alpha"] if "alpha" in kargs else 0.1

    def forward(self, logits):
        """Compute ``alpha * mean(softmax(logits) ** 2)``.

        Args:
            logits: Tensor of unnormalized scores; the softmax is taken over
                the last dimension.

        Returns:
            A scalar tensor. The mean runs over every element of the squared
            probability tensor.
        """
        p = logits.softmax(dim=-1)
        mean_sq = p.square().mean()
        return self.alpha * mean_sq
