import math
import torch
from .base import WeightedLoss, get_delta


# non-robust losses
class CE(WeightedLoss):
    """Cross-entropy weighting (non-robust baseline).

    The per-sample weight is ``1 - sigmoid(delta)``, which lies in (0, 1);
    ``max_w`` records that supremum. With ``p = sigmoid(delta)`` this is
    ``-d/d(delta)`` of ``-log p``.
    NOTE(review): the meaning of ``delta`` comes from ``get_delta`` in
    ``.base`` — confirm there.
    """

    def __init__(self, **kargs):
        # Assigned before the base initialiser, like every class in this
        # module (the base may depend on it — unverified from here).
        self.max_w = 1
        super().__init__()

    def w(self, delta):
        prob = delta.sigmoid()
        return 1 - prob


class FL(WeightedLoss):
    """Focal-loss weighting with focusing exponent ``q``.

    With ``p = sigmoid(delta)`` the weight is

        (1 - p)**q * (1 - p - q * p * log p),

    which is ``-d/d(delta)`` of the focal loss ``-(1 - p)**q * log p``.

    kargs:
        q: focusing exponent, must be > 0 (default 1.0).
    """

    def __init__(self, **kargs):
        self.q = kargs.get("q", 1.0)
        assert self.q > 0
        super().__init__()

    def w(self, delta):
        p = torch.sigmoid(delta)
        # log p is taken from logsigmoid(delta) rather than p.log(): sigmoid
        # underflows to exactly 0 for sufficiently negative delta (very soon
        # in half precision), and 0 * log(0) = 0 * -inf = NaN would poison the
        # whole weight. logsigmoid stays finite there, so p * log_p -> 0.
        log_p = torch.nn.functional.logsigmoid(delta)
        return (1 - p) ** self.q * (1 - p - self.q * p * log_p)


# symmetric
class MAE(WeightedLoss):
    """Mean-absolute-error weighting (symmetric family).

    The weight is ``sigmoid(delta) * (1 - sigmoid(delta))``. It peaks at
    1/4 when ``delta == 0``, hence ``max_w = 0.25``.
    """

    def __init__(self, **kargs):
        self.max_w = 0.25
        super().__init__()

    def w(self, delta):
        prob = delta.sigmoid()
        complement = 1 - prob
        return prob * complement


class NCES(WeightedLoss):
    """Normalised cross-entropy weighting.

    NOTE(review): assumes this implements NCE = CE_y / sum_k CE_k
    (Ma et al., 2020) — confirm intent.

    ``gamma = sum_k -log softmax(logits)_k`` is computed from detached logits
    and treated as a constant. The per-sample weight is therefore the CE
    weight ``1 - sigmoid(delta_y)`` divided by ``gamma``.

    Since ``sum_k -log p_k >= k * log(k)`` (equality at the uniform
    distribution), that weight never exceeds ``1 / (k * log k)``. This is the
    value stored in ``q`` / ``max_w``, so ``w / max_w`` lies in [0, 1], like
    the other losses in this module.

    kargs:
        q: ``1 / (k * log k)`` for ``k`` classes. The default assumes k = 10.
    """

    def __init__(self, **kargs):
        self.q = kargs.get("q", 0.1 / math.log(10))  # 1/(k log k), k = 10
        self.max_w = self.q
        super().__init__()

    def forward(self, logits, y):
        delta_y = get_delta(logits, y)
        # NOTE(review): assumes classes lie on the last axis of ``logits``
        # (e.g. (batch, num_classes)) — confirm against callers.
        gamma = -torch.log_softmax(logits.detach(), dim=-1).sum(dim=-1)
        wce = 1 - torch.sigmoid(delta_y.detach())
        # NCE = CE / gamma, so the CE weight is divided (not multiplied) by
        # gamma; this keeps the weight bounded by max_w even as some p_k -> 0.
        w = wce / gamma / self.max_w
        return -delta_y * w


# asymmetric
class AGCE(WeightedLoss):
    """Asymmetric generalised cross-entropy weighting.

    With ``p = sigmoid(delta)`` the weight is
    ``p * (a + p)**(q - 1) * (1 - p)``. This is ``-d/d(delta)`` of
    ``((a + 1)**q - (a + p)**q) / q``.

    kargs:
        a: offset, must be > 0 (default 1).
        q: exponent, must be > 0 (default 2).
    """

    def __init__(self, **kargs):
        a = kargs.get("a", 1)
        q = kargs.get("q", 2)
        assert a > 0
        assert q > 0
        self.a, self.q = a, q
        super().__init__()

    def w(self, delta):
        prob = delta.sigmoid()
        shifted = (self.a + prob).pow(self.q - 1)
        return prob * shifted * (1 - prob)


class AUL(WeightedLoss):
    """Asymmetric unhinged-style weighting.

    With ``p = sigmoid(delta)`` the weight is
    ``p * (1 - p) * (a - p)**(q - 1)``. This is ``-d/d(delta)`` of
    ``((a - p)**q - (a - 1)**q) / q``. Requiring ``a > 1`` keeps
    ``a - p`` strictly positive.

    kargs:
        a: offset, must be > 1 (default 1.5).
        q: exponent, must be > 0 (default 0.9).
    """

    def __init__(self, **kargs):
        a = kargs.get("a", 1.5)
        q = kargs.get("q", 0.9)
        assert a > 1
        assert q > 0
        self.a, self.q = a, q
        super().__init__()

    def w(self, delta):
        prob = delta.sigmoid()
        tail = (self.a - prob).pow(self.q - 1)
        return prob * (1 - prob) * tail


class AEL(WeightedLoss):
    """Asymmetric exponential weighting.

    With ``p = sigmoid(delta)`` the weight is
    ``p * (1 - p) * exp(-p / q) / q``, i.e. ``-d/d(delta)`` of ``exp(-p / q)``.

    kargs:
        q: temperature, must be > 0 (default 3).
    """

    def __init__(self, **kargs):
        q = kargs.get("q", 3)
        assert q > 0
        self.q = q
        super().__init__()

    def w(self, delta):
        prob = delta.sigmoid()
        decay = torch.exp(-prob / self.q)
        return prob * (1 - prob) * decay / self.q


class GCE(WeightedLoss):
    """Generalised cross-entropy weighting.

    With ``p = sigmoid(delta)`` the weight is ``p**q * (1 - p)``. This is
    ``-d/d(delta)`` of ``(1 - p**q) / q``.

    kargs:
        q: exponent in (0, 1] (default 0.7).
    """

    def __init__(self, **kargs):
        q = kargs.get("q", 0.7)
        assert 0 < q <= 1
        self.q = q
        super().__init__()

    def w(self, delta):
        prob = delta.sigmoid()
        return prob.pow(self.q) * (1 - prob)


class TCE(WeightedLoss):
    """Truncated Taylor-series cross-entropy weighting.

    With ``p = sigmoid(delta)`` the weight is
    ``p * sum_{i=1..int(q)} (1 - p)**i``. This is ``-d/d(delta)`` of
    ``sum_{i=1..int(q)} (1 - p)**i / i``. A non-integer ``q`` is truncated
    by ``int``.

    kargs:
        q: series order, must be >= 1 (default 2).
    """

    def __init__(self, **kargs):
        order = kargs.get("q", 2)
        assert order >= 1
        self.q = order
        super().__init__()

    def w(self, delta):
        prob = delta.sigmoid()
        series = 0
        for power in range(1, int(self.q) + 1):
            series = series + (1 - prob) ** power
        return prob * series


class SCE(WeightedLoss):
    """Blend of the CE and MAE weights.

    With ``p = sigmoid(delta)`` the weight is ``(1 - q + q * p) * (1 - p)``.
    This equals ``(1 - q) * (1 - p) + q * p * (1 - p)``, a convex mix of
    the CE weight and the MAE weight.

    kargs:
        q: mixing coefficient in [0, 1] (default 0.95).
    """

    def __init__(self, **kargs):
        mix = kargs.get("q", 0.95)
        assert 0 <= mix <= 1
        self.q = mix
        super().__init__()

    def w(self, delta):
        prob = delta.sigmoid()
        blend = 1 - self.q + self.q * prob
        return blend * (1 - prob)
