import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# Module-level constant 1e-7 (presumably a numerical-stability epsilon).
# NOTE(review): no definition visible in this file references it; confirm
# whether external code imports it before removing.
eps = 1e-7


def abs_e_delta_y(k, mu=0, sigma=1):
    """Return the absolute value of a closed-form margin estimate for ``k`` classes.

    The value computed is::

        | mu - log((k - 1) * e) + (k - 1) * v / (2 * ((k - 1) * e) ** 2) |

    with ``e = exp(mu + sigma**2 / 2)`` and
    ``v = (exp(sigma**2) - 1) * exp(2 * mu + sigma**2)``, i.e. the mean and
    variance formulas of a log-normal distribution with parameters ``mu``
    and ``sigma``.

    Args:
        k: Number of classes. Must be greater than 1, because the formula
            takes ``log(k - 1)``.
        mu: Location parameter. Defaults to 0, the value that used to be
            hard-coded.
        sigma: Scale parameter. Defaults to 1, the value that used to be
            hard-coded.

    Returns:
        A non-negative float.

    Raises:
        ValueError: If ``k <= 1`` (previously surfaced as an opaque
            "math domain error" from ``math.log``).
    """
    if k <= 1:
        raise ValueError(f"abs_e_delta_y requires k > 1, got {k!r}")
    v = (math.exp(sigma ** 2) - 1) * math.exp(2 * mu + sigma ** 2)
    e = math.exp(mu + 0.5 * sigma ** 2)
    # Mean of the sum of the (k - 1) non-target terms.
    total_mean = (k - 1) * e
    return math.fabs(mu - math.log(total_mean) + (k - 1) * v / (2 * total_mean ** 2))


def get_delta_y(pred, y):
    """Return the target score minus the log-sum-exp of all other scores.

    Args:
        pred: Score tensor whose last dimension indexes the classes.
        y: Integer index tensor selecting one class per position; it is
            expanded with a trailing dimension and used as the index for
            ``torch.gather`` / ``Tensor.scatter`` along the last dimension.

    Returns:
        Tensor with the class dimension removed, holding
        ``pred[..., y] - logsumexp(pred[..., k != y])``.
    """
    index = y.unsqueeze(-1)
    # Drop only the gathered class dimension. A bare squeeze() would also drop
    # other size-1 dimensions (e.g. pred of shape (B, 1, C)), after which the
    # subtraction below silently broadcasts to the wrong shape.
    target_score = torch.gather(pred, -1, index).squeeze(-1)
    # Mask the target entry with -inf so it contributes nothing to logsumexp.
    others_lse = torch.logsumexp(pred.scatter(-1, index, float("-inf")), -1)
    return target_score - others_lse


def gather_target(value, y):
    """Pick the entry of ``value`` at index ``y`` along the last dimension.

    Args:
        value: Tensor whose last dimension is indexed by ``y``.
        y: Integer index tensor; a trailing dimension is added before it is
            passed to ``torch.gather``.

    Returns:
        Tensor with the last dimension of ``value`` removed.
    """
    # squeeze(-1) removes only the gathered dimension. A bare squeeze() would
    # also collapse a batch of size 1 into a 0-d tensor.
    return torch.gather(value, -1, y.unsqueeze(-1)).squeeze(-1)


def get_target_p(pred, y):
    """Return the softmax probability of the class indexed by ``y``.

    The softmax is taken over the last dimension of ``pred`` and the entry at
    index ``y`` is selected with ``gather_target``.
    """
    class_probs = pred.softmax(dim=-1)
    return gather_target(class_probs, y)


class WeightedLoss(nn.Module):
    """Base class for losses expressed as a per-sample weight on the margin.

    Subclasses replace ``self.w`` (a callable mapping ``p = sigmoid(delta_y)``
    to a weight) and implement ``get_loss``. ``forward`` does not call
    ``get_loss``; it returns the surrogate ``-delta_y * w`` with ``w``
    computed from a detached margin.

    Args:
        a: Not used by the base class; accepted so all losses share a signature.
        q: Not used by the base class; accepted so all losses share a signature.
        fix: ``"shift"`` or ``"scale"`` adjusts the margin fed to the weight
            function; any other value leaves it unchanged.
        tau: Constant used by the ``"shift"`` and ``"scale"`` adjustments.
    """

    def __init__(self, a=-1, q=-1, fix="none", tau=1):
        super().__init__()
        self.fix = fix
        self.tau = tau
        # Placeholders until update() is called.
        self.maximum = 1
        self.e = 0
        # Identity weight; subclasses overwrite this attribute.
        self.w = lambda p: p

    def update(self, num_class):
        """Recompute the weight normaliser and the margin constant ``e``."""
        grid = torch.linspace(1e-8, 1, 1000)
        weights_on_grid = self.w(grid)
        self.maximum = weights_on_grid.max()
        self.e = abs_e_delta_y(num_class)

    def get_loss(self, pred, y):
        """Return the per-sample loss value; must be provided by subclasses."""
        raise NotImplementedError

    def get_w(self, delta_y, pred=None, y=None):
        """Map a margin to a weight via ``self.w(sigmoid(delta_y))``.

        ``pred`` and ``y`` are ignored here; they exist for subclasses that
        need them.
        """
        return self.w(torch.sigmoid(delta_y))

    def forward(self, pred, y):
        """Return ``(-delta_y * w, w, delta_y)`` for scores ``pred`` and targets ``y``.

        The weight is computed from a detached copy of the margin, so
        gradients flow only through the ``delta_y`` factor.
        """
        delta_y = get_delta_y(pred, y)
        margin = delta_y.detach()
        if self.fix == "shift":
            margin = margin + (self.e - self.tau)
        elif self.fix == "scale":
            # Divides by self.e, which is 0 until update() has been called.
            margin = margin * (self.tau / self.e)
        weight = self.get_w(margin) / self.maximum
        return -delta_y * weight, weight, delta_y


class CE(WeightedLoss):
    """Cross-entropy: loss ``-log p`` with weight function ``w(p) = 1 - p``.

    ``a`` and ``q`` are unused by this loss and must be left negative.
    """

    def __init__(self, a=-1, q=-1, fix="none", tau=1):
        super(CE, self).__init__(a, q, fix, tau)
        assert a < 0
        assert q < 0
        self.w = lambda p: 1 - p

    def get_loss(self, pred, y):
        """Return the per-sample negative log-probability of the target class."""
        # log_softmax avoids the underflow of log(softmax(...)), which yields
        # inf once the target probability rounds to zero.
        return -gather_target(torch.log_softmax(pred, dim=-1), y)


class FL(WeightedLoss):
    """Loss ``-(1 - p)**q * log(p)`` with exponent ``q > 0``.

    The weight function is ``(1 - p)**q * (1 - p - q * p * log(p))``.
    ``a`` is unused and must be left negative.
    """

    def __init__(self, a=-1, q=-1, fix="none", tau=1):
        super().__init__(a, q, fix, tau)
        assert a < 0
        assert q > 0
        self.q = q

        def focal_weight(p):
            modulator = (1 - p) ** self.q
            bracket = 1 - p - self.q * p * p.log()
            return modulator * bracket

        self.w = focal_weight

    def get_loss(self, pred, y):
        """Return the per-sample loss for scores ``pred`` and targets ``y``."""
        target_prob = get_target_p(pred, y)
        modulator = (1 - target_prob) ** self.q
        return -modulator * torch.log(target_prob)


class MAE(WeightedLoss):
    """Loss ``1 - p`` with weight function ``w(p) = p * (1 - p)``.

    ``a`` and ``q`` are unused by this loss and must be left negative.
    """

    def __init__(self, a=-1, q=-1, fix="none", tau=1):
        super().__init__(a, q, fix, tau)
        assert a < 0
        assert q < 0

        def mae_weight(p):
            return p * (1 - p)

        self.w = mae_weight

    def get_loss(self, pred, y):
        """Return one minus the target-class probability, per sample."""
        target_prob = get_target_p(pred, y)
        return 1 - target_prob


class AGCE(WeightedLoss):
    """Loss ``((a + 1)**q - (a + p)**q) / q`` with ``a > 0`` and ``q > 0``.

    The weight function is ``p * (a + p)**(q - 1) * (1 - p)``.
    """

    def __init__(self, a=1, q=2, fix="none", tau=1):
        super(AGCE, self).__init__(a, q, fix, tau)
        assert a > 0
        assert q > 0
        self.a = a
        self.q = q
        self.w = lambda p: p * torch.pow(self.a + p, self.q - 1) * (1 - p)

    def get_loss(self, pred, y):
        """Return the per-sample loss; it is zero when the target probability is 1."""
        p = get_target_p(pred, y)
        # The constant term must carry the exponent q, mirroring the
        # (a - 1) ** q term in AUL; otherwise the loss is offset by a constant
        # and is non-zero at p == 1. Gradients are unaffected by the offset.
        loss = ((self.a + 1) ** self.q - torch.pow(self.a + p, self.q)) / self.q
        return loss


class AUL(WeightedLoss):
    """Loss ``((a - p)**q - (a - 1)**q) / q`` with ``a > 1`` and ``q > 0``.

    The weight function is ``p * (1 - p) * (a - p)**(q - 1)``.
    """

    def __init__(self, a=1.5, q=0.9, fix="none", tau=1):
        super().__init__(a, q, fix, tau)
        assert a > 1
        assert q > 0
        self.a = a
        self.q = q

        def aul_weight(p):
            tail = torch.pow(self.a - p, self.q - 1)
            return p * (1 - p) * tail

        self.w = aul_weight

    def get_loss(self, pred, y):
        """Return the per-sample loss for scores ``pred`` and targets ``y``."""
        target_prob = get_target_p(pred, y)
        powered = torch.pow(self.a - target_prob, self.q)
        # Constant that makes the loss zero at target probability 1.
        floor = (self.a - 1) ** self.q
        return (powered - floor) / self.q


class AEL(WeightedLoss):
    """Loss ``exp(-p / q)`` with ``q > 0``.

    The weight function is ``p * (1 - p) * exp(-p / q) / q``.
    ``a`` is unused and must be left negative.
    """

    def __init__(self, a=-1, q=3, fix="none", tau=1):
        super().__init__(a, q, fix, tau)
        assert a < 0
        assert q > 0
        self.q = q

        def ael_weight(p):
            decay = torch.exp(-p / self.q)
            return p * (1 - p) * decay / self.q

        self.w = ael_weight

    def get_loss(self, pred, y):
        """Return ``exp(-p / q)`` for the target-class probability ``p``."""
        target_prob = get_target_p(pred, y)
        return torch.exp(-target_prob / self.q)


class GCE(WeightedLoss):
    """Loss ``(1 - p**q) / q`` with ``0 < q <= 1``.

    The weight function is ``p**q * (1 - p)``.
    ``a`` is unused and must be left negative.
    """

    def __init__(self, a=-1, q=0.7, fix="none", tau=1):
        super().__init__(a, q, fix, tau)
        assert a < 0
        assert 0 < q <= 1
        self.q = q

        def gce_weight(p):
            return torch.pow(p, self.q) * (1 - p)

        self.w = gce_weight

    def get_loss(self, pred, y):
        """Return the per-sample loss for scores ``pred`` and targets ``y``."""
        target_prob = get_target_p(pred, y)
        powered = torch.pow(target_prob, self.q)
        return (1 - powered) / self.q


class TCE(WeightedLoss):
    """Loss ``sum_{i=1..int(q)} (1 - p)**i / i`` with ``q >= 1``.

    The weight function is ``p * sum_{i=1..int(q)} (1 - p)**i``.
    ``a`` is unused and must be left negative.
    """

    def __init__(self, a=-1, q=2, fix="none", tau=1):
        super().__init__(a, q, fix, tau)
        assert a < 0
        assert q >= 1
        self.q = q

        def tce_weight(p):
            series = 0
            for order in range(1, int(self.q) + 1):
                series = series + (1 - p) ** order
            return p * series

        self.w = tce_weight

    def get_loss(self, pred, y):
        """Return the truncated series evaluated at the target probability."""
        target_prob = get_target_p(pred, y)
        total = 0
        for order in range(1, int(self.q) + 1):
            total = total + (1 - target_prob) ** order / order
        return total


class SCE(WeightedLoss):
    """Convex mix of ``CE`` and ``MAE``: ``(1 - q) * CE + q * MAE``.

    ``q`` must satisfy ``0 < q <= 1``; ``a`` is unused and must be left
    negative. Both the loss and the weight function use the same mixing
    coefficients.
    """

    def __init__(self, a=-1, q=1, fix="none", tau=1):
        super().__init__(a, q, fix, tau)
        assert a < 0
        assert 0 < q <= 1
        self.alpha = 1 - q
        self.beta = q
        self.ce = CE()
        self.mae = MAE()

        def mixed_weight(p):
            ce_part = self.alpha * self.ce.w(p)
            mae_part = self.beta * self.mae.w(p)
            return ce_part + mae_part

        self.w = mixed_weight

    def get_loss(self, pred, y):
        """Return ``alpha * CE loss + beta * MAE loss`` per sample."""
        ce_part = self.alpha * self.ce.get_loss(pred, y)
        mae_part = self.beta * self.mae.get_loss(pred, y)
        return ce_part + mae_part


class NCE(WeightedLoss):
    """Loss ``log p_y / sum_k log p_k`` computed from log-softmax scores.

    Unlike the other losses, ``forward`` returns the actual loss value
    (divided by ``self.maximum``) rather than ``-delta_y * w``. ``a`` and
    ``q`` are unused and must be left negative; only ``fix="none"`` is
    supported.

    ``update`` must be called to set ``num_class``; until then it is ``-1``
    and ``maximum`` keeps the base-class value of 1.
    """

    def __init__(self, a=-1, q=-1, fix="none", tau=1):
        super(NCE, self).__init__(a, q, fix, tau)
        assert a < 0
        assert q < 0
        self.num_class = -1
        assert self.fix == "none", "fix for underfitting not implemented for NCE"

    def get_loss(self, pred, y):
        """Return the target log-probability divided by the sum of all log-probabilities."""
        lprobs = torch.log_softmax(pred, dim=-1)
        loss = gather_target(lprobs, y) / lprobs.sum(dim=-1)
        return loss

    def get_w(self, delta_y, pred=None, y=None):
        """Return per-sample weights computed from detached ``pred`` and ``y``.

        ``delta_y`` is accepted for signature compatibility but not used.

        Raises:
            ValueError: If any resulting weight is infinite.
        """
        assert pred is not None
        assert y is not None
        logprobs = torch.log_softmax(pred.detach(), dim=-1)
        logp = gather_target(logprobs, y)
        # Shared by gamma and epsilon; reduce once instead of twice.
        logprob_sum = logprobs.sum(dim=-1)
        gamma = -1 / logprob_sum
        epsilon = self.num_class * logp / logprob_sum
        w = 2 * gamma * (1 - logp.exp()) * epsilon
        if torch.isinf(w).any():
            raise ValueError("NCE sample weights contain infinite values")
        return w

    def update(self, num_class):
        """Store ``num_class`` and derive ``maximum`` and ``e`` from it."""
        self.maximum = 1 / (num_class * math.log(num_class))
        self.e = abs_e_delta_y(num_class)
        self.num_class = num_class

    def forward(self, pred, y):
        """Return ``(loss / maximum, w / maximum, delta_y)``.

        ``delta_y`` is computed from detached scores, so only the loss term
        carries gradients.
        """
        loss = self.get_loss(pred, y)
        delta_y = get_delta_y(pred.detach(), y)
        w = self.get_w(delta_y, pred, y)
        return loss / self.maximum, w / self.maximum, delta_y


class NCEMAE(WeightedLoss):
    """Mix of ``NCE`` and ``MAE`` losses with coefficients ``1 - q`` and ``q``.

    Each component loss is divided by its own ``maximum`` before mixing, and
    the result is divided by ``max(alpha, beta)`` in ``forward``. Sample
    weights are taken from the NCE component only. ``a`` is unused and must
    be left negative; only ``fix="none"`` is supported.
    """

    def __init__(self, a=-1, q=1, fix="none", tau=1):
        super().__init__(a, q, fix, tau)
        assert a < 0
        assert 0 < q <= 1
        self.alpha, self.beta = 1 - q, q
        self.nce = NCE()
        self.mae = MAE()
        assert self.fix == "none", "fix for underfitting not implemented for NCEMAE"
        print("Sample weights for NCEMAE is undefined. Use weights of NCE as surrogate to avoid exception.")

    def update(self, num_class):
        """Update both components, then refresh ``e`` and ``maximum``."""
        for component in (self.nce, self.mae):
            component.update(num_class)
        self.e = abs_e_delta_y(num_class)
        self.maximum = max(self.alpha, self.beta)

    def get_w(self, delta_y, pred=None, y=None):
        """Return the NCE component's weights divided by its ``maximum``."""
        # No weight is defined for the mixture itself; the NCE weights stand
        # in as a surrogate so callers always receive a value.
        surrogate = self.nce.get_w(delta_y, pred, y)
        return surrogate / self.nce.maximum

    def get_loss(self, pred, y):
        """Return the mix of the two component losses, each scaled by its ``maximum``."""
        nce_term = self.alpha * self.nce.get_loss(pred, y) / self.nce.maximum
        mae_term = self.beta * self.mae.get_loss(pred, y) / self.mae.maximum
        return nce_term + mae_term

    def forward(self, pred, y):
        """Return ``(loss / maximum, w / maximum, delta_y)``.

        ``delta_y`` is computed from detached scores.
        """
        mixed_loss = self.get_loss(pred, y)
        delta_y = get_delta_y(pred.detach(), y)
        weight = self.get_w(delta_y, pred, y)
        scale = self.maximum
        return mixed_loss / scale, weight / scale, delta_y
