import torch
from torch import nn


class MAE_(nn.Module):
    """Per-sample loss ``1 - p_y``, with ``p`` the softmax of the logits.

    Keyword arguments are accepted so the class can be constructed with the
    same call as its sibling losses; none of them are read here.
    """

    def __init__(self, **kargs):
        super().__init__()

    def forward(self, logits, y):
        """Return ``1 - softmax(logits)[y]`` for every sample.

        Args:
            logits: Unnormalised class scores; classes lie on the last axis.
            y: Integer class indices shaped like ``logits`` without its last
                axis (``torch.gather`` requires an int64 index tensor).

        Returns:
            Unreduced losses with the same shape as ``y``.
        """
        probs = torch.softmax(logits, dim=-1)
        # Drop only the axis added by ``unsqueeze``. A bare ``squeeze()``
        # would also remove a size-1 batch axis and return a 0-d tensor.
        target_probs = torch.gather(probs, -1, y.unsqueeze(-1)).squeeze(-1)
        return 1 - target_probs


class AGCE_(nn.Module):
    """Per-sample loss ``((a + 1)**q - (a + p_y)**q) / q``.

    ``p`` is the softmax of the logits and ``p_y`` the probability of the
    target class. The name suggests the asymmetric generalized cross entropy
    loss (presumably; confirm against the intended reference).

    Keyword Args:
        a: Positive offset added to the target probability. Defaults to 1.
        q: Positive exponent. Defaults to 2.

    Any other keyword arguments are ignored.
    """

    def __init__(self, **kargs):
        super().__init__()
        self.a = kargs.get("a", 1)
        self.q = kargs.get("q", 2)
        # NOTE: asserts are stripped under ``python -O``; they are kept so
        # the exception type seen by existing callers does not change.
        assert self.a > 0, "a must be positive"
        assert self.q > 0, "q must be positive"

    def forward(self, logits, y):
        """Return the unreduced loss for every sample.

        Args:
            logits: Unnormalised class scores; classes lie on the last axis.
            y: Integer class indices shaped like ``logits`` without its last
                axis (``torch.gather`` requires an int64 index tensor).

        Returns:
            Unreduced losses with the same shape as ``y``.
        """
        probs = torch.softmax(logits, dim=-1)
        # Drop only the axis added by ``unsqueeze``. A bare ``squeeze()``
        # would also remove a size-1 batch axis and return a 0-d tensor.
        target_probs = torch.gather(probs, -1, y.unsqueeze(-1)).squeeze(-1)
        upper = (self.a + 1) ** self.q
        return (upper - (self.a + target_probs) ** self.q) / self.q


class NCE(nn.Module):
    """Per-sample ratio ``log p_y / sum_k log p_k`` over log-softmax outputs.

    This equals the cross entropy of the target class divided by the sum of
    the cross entropies of all classes, so each value lies in ``[0, 1]``.

    Keyword arguments are accepted so the class can be constructed with the
    same call as its sibling losses; none of them are read here.
    """

    def __init__(self, **kargs):
        super().__init__()

    def forward(self, logits, y):
        """Return the unreduced normalised loss for every sample.

        Args:
            logits: Unnormalised class scores; classes lie on the last axis.
            y: Integer class indices shaped like ``logits`` without its last
                axis (``torch.gather`` requires an int64 index tensor).

        Returns:
            Unreduced losses with the same shape as ``y``.
        """
        lprobs = torch.log_softmax(logits, dim=-1)
        # Drop only the axis added by ``unsqueeze`` so the numerator keeps
        # the same shape as ``lprobs.sum(dim=-1)``. A bare ``squeeze()``
        # removes every size-1 axis, which made the division broadcast to
        # the wrong shape (e.g. (B,) / (B, 1) -> (B, B)).
        target_lprobs = torch.gather(lprobs, -1, y.unsqueeze(-1)).squeeze(-1)
        return target_lprobs / lprobs.sum(dim=-1)


class NCEAGCE(nn.Module):
    """Sum of the ``NCE`` and ``AGCE_`` per-sample losses.

    All keyword arguments are forwarded unchanged to both sub-losses.
    """

    def __init__(self, **kargs):
        super().__init__()
        self.agce = AGCE_(**kargs)
        self.nce = NCE(**kargs)

    def forward(self, logits, y):
        """Evaluate both sub-losses on ``(logits, y)`` and add the results."""
        nce_term = self.nce(logits, y)
        agce_term = self.agce(logits, y)
        return nce_term + agce_term


class NCEMAE(nn.Module):
    """Sum of the ``NCE`` and ``MAE_`` per-sample losses.

    All keyword arguments are forwarded unchanged to both sub-losses.
    """

    def __init__(self, **kargs):
        super().__init__()
        self.mae = MAE_(**kargs)
        self.nce = NCE(**kargs)

    def forward(self, logits, y):
        """Evaluate both sub-losses on ``(logits, y)`` and add the results."""
        nce_term = self.nce(logits, y)
        mae_term = self.mae(logits, y)
        return nce_term + mae_term
