"""
Implementation of Focal Loss with adaptive gamma.
Reference:
[1]  T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollar, Focal loss for dense object detection.
     arXiv preprint arXiv:1708.02002, 2017.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from scipy.special import lambertw
import numpy as np


def get_gamma(p=0.2):
    """
    Return the gamma at which g(p, gamma) = 1 for the probability ``p``.

    The value is obtained in closed form from the ``k=-1`` branch of the
    Lambert W function; only the real part of the (complex) result is kept.
    NOTE(review): ``g`` itself is not defined in this file -- presumably it is
    the focal-loss / cross-entropy gradient ratio; confirm against the paper.

    Args:
        p: probability used in the formula. The formula takes ``log(p)`` and
            ``log(1 - p)``, so it is only meaningful for ``0 < p < 1``.

    Returns:
        The real-valued gamma; for ``p_t > p`` this gamma gives
        ``g(p_t, gamma) < 1``.
    """
    log_p = np.log(p)
    log_one_minus_p = np.log(1 - p)
    # This ratio appears both as the exponent offset and as the additive term.
    ratio = (1 - p) / (p * log_p)
    y = ((1 - p) ** (1 - ratio) / (p * log_p)) * log_one_minus_p
    # The tiny positive offset nudges the argument before evaluating the
    # k=-1 branch of W.
    w_branch = lambertw(-y + 1e-12, k=-1)
    return np.real(ratio + w_branch / log_one_minus_p)


# Lookup table used by FocalLossAdaptive.get_gamma_list: each key is an upper
# probability threshold and its value is the gamma applied to samples whose
# true-class probability falls below that threshold (the smallest matching
# threshold wins). ``ps`` and ``gammas`` are paired by position.
ps = [0.2, 0.5]
gammas = [5.0, 3.0]
gamma_dic = dict(zip(ps, gammas))


class FocalLossAdaptive(nn.Module):
    """Focal loss whose focusing parameter gamma is chosen per sample.

    For every sample the loss is ``-(1 - pt) ** gamma * log(pt)``, where ``pt``
    is the softmax probability of the target class. ``gamma`` is looked up in
    the module-level ``gamma_dic`` table for low-probability samples and falls
    back to the constructor's ``gamma`` otherwise (see ``get_gamma_list``).

    Args:
        device: device on which the per-sample gamma tensor is returned.
        gamma: default focusing parameter, used when ``pt >= 0.5`` or when no
            threshold in ``gamma_dic`` applies.
        size_average: if true ``forward`` returns the mean of the per-sample
            losses, otherwise their sum.
    """

    def __init__(self, device, gamma=0, size_average=False):
        super(FocalLossAdaptive, self).__init__()
        self.size_average = size_average
        self.gamma = gamma
        self.device = device

    def get_gamma_list(self, pt):
        """Return the gamma to use for each entry of ``pt``.

        Selection rule, per sample:
          * ``pt >= 0.5``: the default ``self.gamma``;
          * otherwise the value stored under the smallest key of ``gamma_dic``
            that is strictly greater than ``pt``;
          * if neither applies (e.g. ``pt`` is NaN, or every key is <= ``pt``)
            the default ``self.gamma`` is used, so the result always has one
            entry per sample.

        The lookup is vectorized: no per-sample Python loop or ``.item()``
        call (each of which would force a device-to-host sync on GPU).

        Args:
            pt: tensor of target-class probabilities, one value per sample.

        Returns:
            1-D tensor of gammas in the default floating dtype, placed on
            ``self.device``. It carries no gradient.
        """
        probs = pt.detach().reshape(-1)
        default_gamma = float(self.gamma)
        # A float fill value makes torch.full use the default dtype, which is
        # what ``torch.Tensor(list_of_floats)`` produced before.
        gamma = torch.full(probs.shape, default_gamma, device=probs.device)
        # Apply thresholds from largest to smallest so that the smallest
        # threshold exceeding pt is written last and therefore wins.
        for threshold in sorted(gamma_dic, reverse=True):
            gamma = gamma.masked_fill(probs < threshold, float(gamma_dic[threshold]))
        # Confident samples always use the default, whatever the table holds.
        gamma = gamma.masked_fill(probs >= 0.5, default_gamma)
        return gamma.to(self.device)

    def forward(self, input, target):
        """Compute the adaptive focal loss.

        Args:
            input: raw class scores of shape ``(N, C)``, or ``(N, C, ...)``
                with trailing dimensions that are flattened into the batch
                dimension (e.g. ``N, C, H, W`` becomes ``N*H*W, C``).
            target: integer class indices with one entry per flattened sample
                (it is used as the index argument of ``gather``).

        Returns:
            Scalar tensor: the mean of the per-sample losses when
            ``size_average`` is set, otherwise their sum.
        """
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        # reshape (rather than view) also accepts non-contiguous targets.
        target = target.reshape(-1, 1)
        # Log-probability of the target class for every sample.
        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        pt = logpt.exp()
        gamma = self.get_gamma_list(pt)
        loss = -1 * (1 - pt) ** gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()
