import math
import torch
import torch.nn as nn


def expected_delta(k, *, mu=0.0, sigma=1.0):
    """Return a closed-form estimate of the expected margin for ``k`` classes.

    The value computed is::

        mu - log((k - 1) * m) + (k - 1) * v / (2 * ((k - 1) * m) ** 2)

    where ``m = exp(mu + sigma**2 / 2)`` and
    ``v = (exp(sigma**2) - 1) * exp(2 * mu + sigma**2)`` are the mean and
    variance of a log-normal variable with parameters ``(mu, sigma)``.

    NOTE(review): this looks like a second-order approximation of
    ``E[z_y - log(sum_{j != y} exp(z_j))]`` for logits drawn i.i.d. from
    ``N(mu, sigma**2)`` -- confirm against the method's derivation.

    Args:
        k: Number of classes; must be greater than 1.
        mu: Mean parameter used in the formula (default 0, the value that
            was previously hard-coded).
        sigma: Standard-deviation parameter used in the formula (default 1,
            the value that was previously hard-coded).

    Returns:
        The estimated margin as a Python float.

    Raises:
        ValueError: If ``k <= 1``; the logarithm of the (k - 1)-term sum is
            undefined in that case.
    """
    if k <= 1:
        raise ValueError(f"expected_delta requires k > 1 classes, got k={k!r}")

    # Variance and mean of exp(z) for z ~ N(mu, sigma^2).
    var_exp = (math.exp(sigma ** 2) - 1) * math.exp(2 * mu + sigma ** 2)
    mean_exp = math.exp(mu + 0.5 * sigma ** 2)

    # Mean of the sum over the (k - 1) remaining classes.
    sum_mean = (k - 1) * mean_exp
    return mu - math.log(sum_mean) + (k - 1) * var_exp / (2 * sum_mean ** 2)


def get_delta(logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the margin between the target logit and all other logits.

    For every position the result is
    ``logits[..., y] - logsumexp(logits[..., j] for j != y)``, taken over
    the last dimension of ``logits``.

    Args:
        logits: Tensor whose last dimension indexes the classes.
        y: Integer index tensor with the shape of ``logits`` minus its last
            dimension; it is used as the index for ``gather``/``scatter``.

    Returns:
        Tensor with the same shape as ``y``.
    """
    index = y.unsqueeze(-1)
    # Squeeze only the gathered class axis. A bare ``squeeze()`` would also
    # drop other singleton dimensions (e.g. logits of shape (N, 1, C)) and
    # make the subtraction below broadcast to the wrong shape.
    target = torch.gather(logits, -1, index).squeeze(-1)
    # Out-of-place scatter: mask the target logit with -inf so that it does
    # not contribute to the log-sum-exp; ``logits`` itself is not modified.
    others = logits.scatter(-1, index, float("-inf"))
    return target - torch.logsumexp(others, -1)


def p_to_delta(p):
    """Convert probabilities ``p`` to margins via ``log(odds + 1e-7)``.

    ``odds`` is computed as ``-p / (p - 1)``, which is algebraically
    ``p / (1 - p)``. The ``1e-7`` offset keeps the logarithm finite when
    ``p`` is 0.

    Args:
        p: Tensor of probabilities.

    Returns:
        Tensor of the same shape as ``p``.
    """
    odds = -p / (p - 1)
    return (odds + 1e-7).log()


class WeightedLoss(nn.Module):
    """Base class for losses of the form ``-delta * w(delta) / max_w``.

    Subclasses implement :meth:`w`. ``max_w`` is the normaliser applied to
    the weights; if it is not already available on the instance (for
    example as a class attribute) it is estimated in ``__init__``.
    """

    def __init__(self):
        super().__init__()
        # Respect a max_w that is already defined and not None.
        if getattr(self, "max_w", None) is not None:
            return
        # Otherwise take the largest weight over 2000 evenly spaced
        # probabilities in [1e-6, 1 - 1e-6], mapped to margins.
        probs = torch.linspace(1e-6, 1-1e-6, 2000)
        weights = self.w(p_to_delta(probs))
        self.max_w = weights.max().item()

    def w(self, delta_y):
        """Return the weight for each margin; must be overridden."""
        raise NotImplementedError

    def forward(self, logits, y):
        """Return ``-delta * w(delta) / max_w`` for each element.

        The weight is evaluated on a detached copy of the margin, so
        gradients flow only through the ``-delta`` factor.
        """
        margin = get_delta(logits, y)
        weight = self.w(margin.detach()) / self.max_w
        return -margin * weight


class ShiftedWeightedLoss(nn.Module):
    """Wrap a loss and evaluate its weight function on shifted margins.

    The shift is ``expected_delta(num_class) + a``. The wrapped ``loss``
    must provide a ``w`` method and a ``max_w`` attribute.
    """

    def __init__(self, loss, num_class, a):
        super().__init__()
        baseline = expected_delta(num_class)
        self.shift = baseline + a
        self.loss = loss

    def w(self, delta_y):
        """Return the wrapped loss's normalised weight at ``delta_y - shift``."""
        inner = self.loss
        moved = delta_y - self.shift
        return inner.w(moved) / inner.max_w

    def forward(self, logits, y):
        """Return ``-delta * w(delta)`` with the weight computed on detached margins."""
        margin = get_delta(logits, y)
        weight = self.w(margin.detach())
        return -margin * weight


class ScaledWeightedLoss(nn.Module):
    """Wrap a loss and evaluate its weight function on rescaled margins.

    The scale factor is ``a / |expected_delta(num_class)|``. The wrapped
    ``loss`` must provide a ``w`` method and a ``max_w`` attribute.
    """

    def __init__(self, loss, num_class, a):
        super().__init__()
        baseline = math.fabs(expected_delta(num_class))
        self.scale = a / baseline
        self.loss = loss

    def w(self, delta_y):
        """Return the wrapped loss's normalised weight at ``delta_y * scale``."""
        inner = self.loss
        stretched = delta_y * self.scale
        return inner.w(stretched) / inner.max_w

    def forward(self, logits, y):
        """Return ``-delta * w(delta)`` with the weight computed on detached margins."""
        margin = get_delta(logits, y)
        weight = self.w(margin.detach())
        return -margin * weight
