
import torch
from torch import nn
import torch.nn.functional as F


class CriterionWrapper(nn.Module):
    """Adapt a tensor-level criterion to dict-shaped model outputs and targets.

    The wrapped ``criterion`` is called with ``model_output[output_key]`` and,
    unless ``target_key`` is ``None``, with ``target[target_key]`` as well.
    When ``weight_key`` is given, the resulting loss is multiplied in place by
    ``target[weight_key]`` and reduced with ``mean()``; otherwise the
    criterion's result is returned untouched.
    """

    def __init__(self, criterion, output_key='logits', target_key='y', weight_key=None):
        super(CriterionWrapper, self).__init__()
        self.criterion = criterion
        self.weight_key = weight_key
        self.target_key = target_key
        self.output_key = output_key

    def forward(self, model_output, target):
        """Return the (optionally weighted and averaged) criterion value."""
        prediction = model_output[self.output_key]

        # Criteria configured without a target key only ever see the prediction.
        if self.target_key is not None:
            loss = self.criterion(prediction, target[self.target_key])
        else:
            loss = self.criterion(prediction)

        # Without a weight key the criterion's own reduction is kept as is.
        if self.weight_key is None:
            return loss

        loss *= target[self.weight_key]
        return loss.mean()


class SoftmaxCELoss(nn.Module):
    """Cross-entropy between soft targets and the softmax of the input.

    The softmax is taken over the last dimension, the per-row cross-entropy
    is summed over that same dimension, and the result is averaged over all
    remaining entries.
    """

    def __init__(self):
        super(SoftmaxCELoss, self).__init__()

    def forward(self, input, target):
        """Return the mean of ``-sum(target * log_softmax(input))``."""
        log_probs = F.log_softmax(input, dim=-1)
        per_row = -(target * log_probs).sum(dim=-1)
        return per_row.mean()


class GanDiscriminatorCriterion(nn.Module):
    """Discriminator loss: average of the BCE on two prediction/target pairs.

    Reads ``"logits_val"`` / ``"logits_hat_det"`` from the model output and
    ``"y_val"`` / ``"y_hat"`` from the target dict. Despite the "logits" in
    the key names, the values are fed to ``nn.BCELoss``, which expects
    probabilities in ``[0, 1]``.
    """

    def __init__(self):
        super(GanDiscriminatorCriterion, self).__init__()
        self.criterion = nn.BCELoss()

    def forward(self, model_output, target):
        """Return ``(BCE(logits_val, y_val) + BCE(logits_hat_det, y_hat)) / 2``."""
        bce = self.criterion
        loss_on_real = bce(model_output["logits_val"], target["y_val"])
        loss_on_fake = bce(model_output["logits_hat_det"], target["y_hat"])
        return (loss_on_real + loss_on_fake) / 2


class MixupCriterionWrapper(nn.Module):
    """Apply a criterion to mixup batches described by dicts.

    ``forward`` reads ``input["logits"]`` and, when both ``'y'`` and ``'y2'``
    are present in ``target``, combines the criterion evaluated on each label
    set as ``l * criterion(logits, y) + (1 - l) * criterion(logits, y2)``
    with ``l = target['l']``. Without ``'y2'`` the criterion is applied to
    ``target['y']`` alone. The result is always reduced with ``mean()``.

    For per-sample mixing coefficients to weight each sample individually,
    the wrapped criterion has to return unreduced losses
    (e.g. ``reduction='none'``).
    """

    def __init__(self, criterion):
        super(MixupCriterionWrapper, self).__init__()
        self.criterion = criterion

    def forward(self, input, target):
        """Return the mean (mixup-weighted) criterion value."""
        logits = input["logits"]

        if 'y' in target and 'y2' in target:
            lam = target['l']
            # Mixup loss is the convex combination of the two losses; the
            # terms must be added, not multiplied.
            loss = (lam * self.criterion(logits, target['y'])
                    + (1.0 - lam) * self.criterion(logits, target['y2']))
        else:
            loss = self.criterion(logits, target['y'])

        return loss.mean()


class MixupCCE(nn.Module):
    """Categorical cross-entropy for (optionally mixed-up) integer labels.

    ``forward`` reads ``input["logits"]`` (softmax is taken over dim 1) and
    class indices from ``target['y']``. When ``target`` also holds ``'y2'``,
    the loss per sample is ``l * CE(y) + (1 - l) * CE(y2)`` with
    ``l = target['l']``; this also gives the correct full-weight loss when
    both labels of a sample coincide. The per-sample losses are averaged.

    NOTE(review): assumes logits are ``(batch, classes)`` and ``y`` / ``y2`` /
    ``l`` hold one entry per sample -- confirm against the data loader.
    """

    def __init__(self):
        super(MixupCCE, self).__init__()

    def forward(self, input, target):
        """Return the mean (mixup-weighted) negative log-likelihood."""
        log_probs = input["logits"].log_softmax(1)

        def picked_nll(labels):
            # gather needs int64 column indices of shape (batch, 1).
            index = labels.reshape(-1, 1).long()
            return -log_probs.gather(1, index).squeeze(1)

        if 'y' in target and 'y2' in target:
            lam = torch.as_tensor(target['l'], dtype=log_probs.dtype,
                                  device=log_probs.device).reshape(-1)
            # Summing the two weighted terms (instead of writing both weights
            # into one target matrix) keeps the total weight at 1 even when
            # y and y2 name the same class.
            loss = lam * picked_nll(target['y']) + (1.0 - lam) * picked_nll(target['y2'])
        else:
            loss = picked_nll(target['y'])

        return loss.mean()


class FocalLoss(nn.Module):
    """Binary focal loss on probabilities.

    Each element's binary cross-entropy is scaled by ``(1 - p_t) ** gamma``
    where ``p_t = target * p + (1 - target) * (1 - p)``. The scaling factor
    is detached, i.e. treated as a constant weight: gradients flow only
    through the cross-entropy term.

    Args:
        gamma: exponent of the modulating factor.
    """

    def __init__(self, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return the mean focal loss of probabilities ``input`` vs ``target``."""
        p = input

        p_t = target * p + (1.0 - target) * (1.0 - p)
        # detach() yields a constant weight on the same device and dtype as
        # the input, so this works on CPU and any GPU without a host
        # round-trip.
        focal_weights = ((1.0 - p_t) ** self.gamma).detach()

        loss = F.binary_cross_entropy(p, target, reduction='none', weight=focal_weights)

        return loss.mean()


class ClassWeightedBCELoss(nn.Module):
    """Binary cross-entropy on probabilities with an optional weight.

    Args:
        weight: value multiplied (with broadcasting) onto the element-wise
            BCE before averaging, or ``None`` for an unweighted loss.
    """

    def __init__(self, weight=None):
        super(ClassWeightedBCELoss, self).__init__()
        self.weight = weight

    def forward(self, input, target):
        """Return the mean of the (optionally weighted) element-wise BCE."""
        loss = F.binary_cross_entropy(input, target, reduction='none')

        # The default weight=None means "unweighted"; multiplying by None
        # would raise a TypeError.
        if self.weight is not None:
            weight = torch.as_tensor(self.weight, dtype=loss.dtype, device=loss.device)
            loss = loss * weight

        return loss.mean()


class BCECCELoss(nn.Module):
    """Weighted sum of a binary and a categorical cross-entropy term.

    Args:
        weight_bce: factor applied to the binary cross-entropy term.
        weigth_cce: factor applied to the categorical cross-entropy term
            (the parameter name is misspelled; kept for backward
            compatibility).
    """

    def __init__(self, weight_bce=1.0, weigth_cce=1.0):
        super(BCECCELoss, self).__init__()
        self.weight_bce = weight_bce
        self.weigth_cce = weigth_cce

    def forward(self, input, target):
        """Return ``weight_bce * BCE + weigth_cce * CCE`` for the batch."""
        p = input

        # NOTE(review): the BCE term treats the input as logits while the CCE
        # term below inverts a sigmoid, i.e. treats the same input as
        # probabilities. One of the two is probably unintended -- confirm
        # what the model emits before changing it.
        loss_bce = F.binary_cross_entropy_with_logits(p, target, reduction='mean')

        # Inverse sigmoid; only finite for inputs strictly inside (0, 1).
        net_out = torch.log(p / (1.0 - p))
        # Normalise each target row to sum to 1 (rows that sum to 0 yield NaN).
        soft_target = target / torch.sum(target, 1, keepdim=True)
        loss_cce = torch.mean(torch.sum(-soft_target * F.log_softmax(net_out, dim=1), 1))

        return self.weight_bce * loss_bce + self.weigth_cce * loss_cce


if __name__ == "__main__":
    # Smoke test: evaluate the same mixup batch with the wrapper around
    # PyTorch's CrossEntropyLoss and with MixupCCE, and print both losses
    # for comparison.
    import numpy as np

    pred = np.asarray([[0.1, 0, 0, 0, 0.7, 0.66],
                       [0.2, 0, 0, 0, 0.8, 0.9]], dtype=np.float32)
    # Dense soft-label form of the mixup targets in targ_dict below
    # (not used by the active code).
    targ = np.asarray([[0, 0, 0.8, 0, 0, 0.2],
                       [0.7, 0, 0, 0, 0.3, 0]], dtype=np.float32)

    pred_dict = {'logits': torch.from_numpy(pred)}
    # np.int was removed in NumPy 1.24; CrossEntropyLoss needs int64 class
    # indices anyway.
    targ_dict = {'y': torch.from_numpy(np.asarray([2, 0], dtype=np.int64)),
                 'y2': torch.from_numpy(np.asarray([5, 4], dtype=np.int64)),
                 'l': torch.from_numpy(np.asarray([0.8, 0.7], dtype=np.float32))}

    pytorch_criterion = MixupCriterionWrapper(nn.CrossEntropyLoss(reduction='none'))
    loss = pytorch_criterion(pred_dict, targ_dict)
    print(loss)

    lasagne_criterion = MixupCCE()
    loss = lasagne_criterion(pred_dict, targ_dict)
    print(loss)

    # gamma = 2.0
    #
    # pred = np.asarray([[0.1, 0, 0, 0, 0.7, 0.66],
    #                    [0.2, 0, 0, 0, 0.8, 0.9]], dtype=np.float32)
    # targ = np.asarray([[1, 0, 1, 0, 0, 1],
    #                    [1, 0, 1, 0, 1, 0]], dtype=np.float32)
    #
    # eps = 1e-8
    # pred = pred.clip(eps, 1. - eps)
    #
    # pred = torch.from_numpy(pred)
    # targ = torch.from_numpy(targ)
    #
    # # -------------------------------------------------
    # # class weighted loss
    #
    # weight = torch.from_numpy(0.5 * np.ones((1, targ.shape[1]), dtype=np.float32))
    # criterion = ClassWeightedBCELoss(weight=weight)
    # loss = criterion.forward(pred, targ)
    # print(loss.mean())

    # -------------------------------------------------
    # focal loss

    # criterion = torch.nn.BCELoss(reduction='none')
    # loss = criterion.forward(pred, targ)
    # print(loss.mean())
    #
    # focal_loss = FocalLoss(gamma=gamma)
    # loss = focal_loss.forward(pred, targ)
    # print(loss.mean())
    #
    # p_t = targ * pred + (1.0 - targ) * (1.0 - pred)
    # focal_weights = (1 - p_t) ** gamma
    #
    # criterion = torch.nn.BCELoss(reduction='none', weight=focal_weights)
    # loss = criterion.forward(pred, targ)
    # print(loss.mean())
