import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np

import pdb

from lib.utils.utils import to_one_hot


class CrossEntropy(nn.Module):
    """Module wrapper that forwards directly to ``F.cross_entropy``."""

    def __init__(self):
        super().__init__()

    def forward(self, output, label, reduction='mean'):
        """Return the cross-entropy between logits and class-index labels.

        Args:
            output: logits, passed as the ``input`` of ``F.cross_entropy``.
            label: targets, passed as the ``target`` of ``F.cross_entropy``.
            reduction: forwarded unchanged to ``F.cross_entropy``.
        """
        return F.cross_entropy(output, label, reduction=reduction)


class CrossEntropy_binary(nn.Module):
    """Per-class binary cross-entropy on logits against encoded labels."""

    def __init__(self):
        super().__init__()

    def forward(self, output, label, reduction='mean'):
        """Return the sigmoid/BCE classification loss.

        Args:
            output: logits; ``output.size(1)`` is used as the number of classes.
            label: labels handed to ``to_one_hot`` (presumably class indices
                that come back as an ``(N, C)`` target tensor -- confirm
                against ``lib.utils.utils.to_one_hot``).
            reduction: if the string contains ``"mean"`` the per-sample losses
                are averaged; any other value returns the per-sample vector.
        """
        num_classes = output.size(1)
        # The encoding is done on a CPU copy of the labels and then moved back
        # to the labels' original device.
        encoded = to_one_hot(label.cpu(), num_classes)
        encoded = encoded.to(label.device)

        elementwise = F.binary_cross_entropy_with_logits(
            input=output, target=encoded, reduction='none')
        per_sample = elementwise.sum(dim=1)  # sum over dim 1 (classes)

        return per_sample.mean() if "mean" in reduction else per_sample


class loss_fn_kd_KL(nn.Module):
    """Temperature-softened KL-divergence distillation loss."""

    def __init__(self):
        super().__init__()

    def forward(self, scores, target_scores, T=2., reduction='mean'):
        """Return KL(softmax(target_scores / T) || softmax(scores / T)).

        Args:
            scores: logits of the model being trained.
            target_scores: logits providing the target distribution.
            T: temperature dividing both sets of logits before normalisation.
            reduction: if the string contains ``'mean'`` the per-sample values
                are averaged; otherwise the per-sample vector is returned.

        The result is not rescaled by ``T ** 2``.
        """
        student_log_probs = F.log_softmax(scores / T, dim=1)
        teacher_probs = F.softmax(target_scores / T, dim=1)

        # Element-wise KL terms, then summed over dim 1.
        per_sample = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(dim=1)

        return per_sample.mean() if 'mean' in reduction else per_sample


def loss_fn_kd(scores, target_scores, T=2., reduction="mean"):
    """Softmax knowledge-distillation loss between two sets of logits.

    Computes the cross-entropy between ``softmax(target_scores / T)`` and
    ``softmax(scores / T)`` (both over dim 1), summed over dim 1 and multiplied
    by ``T ** 2`` (see e.g., Li and Hoiem, 2017).

    Args:
        scores: logits of the model being trained.
        target_scores: logits providing the soft targets; must have the same
            length and the same size along dim 1 as ``scores``.
        T: temperature hyperparameter.
        reduction: exactly ``"mean"`` averages over the batch; any other value
            returns the per-sample vector.
    """
    student_log_probs = F.log_softmax(scores / T, dim=1)
    teacher_probs = F.softmax(target_scores / T, dim=1)

    # Both inputs must agree in batch size and number of classes.
    assert len(scores) == len(target_scores) and scores.size(1) == target_scores.size(1)

    # Soft-target cross-entropy, summed over classes.
    distill = (-(teacher_probs * student_log_probs)).sum(dim=1)
    if reduction == "mean":
        distill = distill.mean()  # average over batch

    # Rescale by T^2.
    return distill * T ** 2


def loss_fn_kd_binary(scores, target_scores, T=2., reduction="mean"):
    """Compute binary knowledge-distillation (KD) loss given [scores] and [target_scores].

    Every class is treated as an independent sigmoid output: the soft target
    is ``sigmoid(target_scores / T)`` and the prediction is
    ``sigmoid(scores / T)``. The binary cross-entropy between them is summed
    over dim 1 and multiplied by ``T ** 2``.

    Args:
        scores: logits of the model being trained.
        target_scores: logits providing the soft targets; must have the same
            length and the same size along dim 1 as ``scores``. (The original
            note says these "should be repackaged" -- presumably detached from
            the graph; confirm with callers.)
        T: temperature hyperparameter.
        reduction: exactly ``"mean"`` averages over the batch; any other value
            returns the per-sample vector.

    Returns:
        A scalar tensor for ``reduction="mean"``, otherwise one value per sample.

    Raises:
        AssertionError: if the two inputs disagree in length or in size along dim 1.
    """
    assert len(scores) == len(target_scores) and scores.size(1) == target_scores.size(1)

    scaled_scores = scores / T
    targets_norm = torch.sigmoid(target_scores / T)

    # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)) are evaluated
    # with logsigmoid instead of log(sigmoid(.)): for saturated logits the
    # sigmoid rounds to exactly 0 or 1, the log becomes -inf, and a zero
    # weight times -inf turns the whole loss into NaN.
    log_prob_pos = F.logsigmoid(scaled_scores)
    log_prob_neg = F.logsigmoid(-scaled_scores)

    # Calculate distillation loss
    KD_loss_unnorm = -(targets_norm * log_prob_pos + (1 - targets_norm) * log_prob_neg)
    KD_loss_unnorm = KD_loss_unnorm.sum(dim=1)  # --> sum over classes
    if reduction == "mean":
        KD_loss_unnorm = KD_loss_unnorm.mean()  # --> average over batch

    # normalize
    KD_loss = KD_loss_unnorm * T ** 2

    return KD_loss


def ib_loss(input_values, ib, reduction="mean"):
    """Scale per-sample losses by the ``ib`` weights.

    Args:
        input_values: loss values to be re-weighted.
        ib: multiplicative weights applied element-wise to ``input_values``.
        reduction: if the string contains ``"mean"`` the weighted losses are
            averaged; otherwise they are returned unreduced.
    """
    weighted = input_values * ib
    return weighted.mean() if "mean" in reduction else weighted


class IBLoss(nn.Module):
    """Cross-entropy re-weighted per sample by ``alpha / (|grad| * |feature| + eps)``."""

    def __init__(self, weight=None, alpha=10000., active_classes_num=100):
        super().__init__()
        assert alpha > 0
        self.weight = weight                          # forwarded to F.cross_entropy as class weights
        self.alpha = alpha                            # numerator of the per-sample weight
        self.epsilon = 0.001                          # keeps the denominator away from zero
        self.active_classes_num = active_classes_num  # width of the one-hot target encoding

    def forward(self, input, target, features, reduction="mean"):
        """Return the re-weighted cross-entropy.

        Args:
            input: logits; softmax is taken over dim 1.
            target: class indices, one-hot encoded to ``active_classes_num`` columns.
            features: feature tensor whose absolute values are summed over
                dim 1 (assumed ``(N, D)`` so one value per sample results --
                TODO confirm with callers).
            reduction: forwarded to ``ib_loss``.
        """
        # L1 magnitude of the features, flattened to a vector.
        feature_l1 = features.abs().sum(dim=1).reshape(-1)

        # L1 distance between the predicted distribution and the one-hot target.
        probs = F.softmax(input, dim=1)
        onehot = F.one_hot(target, self.active_classes_num)
        grad_l1 = (probs - onehot).abs().sum(dim=1)

        sample_weight = self.alpha / (grad_l1 * feature_l1 + self.epsilon)

        ce = F.cross_entropy(input, target, reduction='none', weight=self.weight)
        return ib_loss(ce, sample_weight, reduction=reduction)


class trade_off_IB_Loss(nn.Module):
    """Re-weights a precomputed loss vector using classification and distillation gradient magnitudes."""

    def __init__(self, weight=None, alpha=10000., beta=0.000001, active_classes_num=100):
        super().__init__()
        assert alpha > 0
        self.weight = weight                          # stored; not read by forward()
        self.alpha = alpha                            # numerator of the per-sample weight
        self.beta = beta                              # scale of the distillation term
        self.epsilon = 0.001                          # keeps the denominator away from zero
        self.active_classes_num = active_classes_num  # width of the one-hot target encoding

    def forward(self, loss_vec, input, previous_model_output, target, features, reduction="mean"):
        """Return ``loss_vec`` scaled by ``alpha / (grads * |features| + eps)``.

        Args:
            loss_vec: loss values to be re-weighted.
            input: current-model logits; softmax is taken over dim 1.
            previous_model_output: logits whose dim-1 size defines how many
                leading columns of the current softmax are compared for
                distillation.
            target: class indices, one-hot encoded to ``active_classes_num`` columns.
            features: feature tensor whose absolute values are summed over
                dim 1 (assumed ``(N, D)`` -- TODO confirm with callers).
            reduction: forwarded to ``ib_loss``.
        """
        feature_l1 = features.abs().sum(dim=1).reshape(-1)

        n_old = previous_model_output.size(1)
        probs = F.softmax(input, dim=1)
        old_probs = F.softmax(previous_model_output, dim=1)

        # Distillation term: the leading n_old columns of the current softmax
        # (taken over all columns) against the previous model's softmax.
        kd_grads = (probs[:, :n_old] - old_probs).abs().sum(dim=1)

        # Classification term: current softmax against the one-hot target.
        onehot = F.one_hot(target, self.active_classes_num)
        cls_grads = (probs - onehot).abs().sum(dim=1)

        grads = cls_grads + self.beta * kd_grads
        sample_weight = self.alpha / (grads * feature_l1 + self.epsilon)
        return ib_loss(loss_vec, sample_weight, reduction=reduction)


class mixup_trade_off_IB_Loss(nn.Module):
    """Variant of the trade-off re-weighting that builds the target from two mixed labels."""

    def __init__(self, weight=None, alpha=10000., beta=0.000001, active_classes_num=100):
        super().__init__()
        assert alpha > 0
        self.weight = weight                          # stored; not read by forward()
        self.alpha = alpha                            # numerator of the per-sample weight
        self.beta = beta                              # scale of the distillation term
        self.epsilon = 0.001                          # keeps the denominator away from zero
        self.active_classes_num = active_classes_num  # width of the one-hot target encoding

    def forward(self, loss_vec, input, previous_model_output, mixup_label_a, mixup_label_b, all_lams, features,
                reduction="mean"):
        """Return ``loss_vec`` scaled by ``alpha / (grads * |features| + eps)``.

        Args:
            loss_vec: loss values to be re-weighted.
            input: current-model logits; softmax is taken over dim 1.
            previous_model_output: logits whose dim-1 size defines how many
                leading columns of the current softmax are compared for
                distillation.
            mixup_label_a: first set of class indices.
            mixup_label_b: second set of class indices.
            all_lams: mixing coefficients, reshaped to ``(input.size(0), 1)``;
                label ``a`` gets weight ``lam`` and label ``b`` gets ``1 - lam``.
            features: feature tensor whose absolute values are summed over
                dim 1 (assumed ``(N, D)`` -- TODO confirm with callers).
            reduction: forwarded to ``ib_loss``.
        """
        feature_l1 = features.abs().sum(dim=1).reshape(-1)

        n_old = previous_model_output.size(1)
        probs = F.softmax(input, dim=1)
        old_probs = F.softmax(previous_model_output, dim=1)

        # Soft target: convex combination of the two one-hot encodings.
        onehot_a = F.one_hot(mixup_label_a, self.active_classes_num)
        onehot_b = F.one_hot(mixup_label_b, self.active_classes_num)
        lam = all_lams.reshape((input.size(0), 1))
        mixed_target = lam * onehot_a + (1 - lam) * onehot_b

        # Distillation term on the leading n_old columns; classification term
        # on all columns against the mixed target.
        kd_grads = (probs[:, :n_old] - old_probs).abs().sum(dim=1)
        cls_grads = (probs - mixed_target).abs().sum(dim=1)

        grads = cls_grads + self.beta * kd_grads
        sample_weight = self.alpha / (grads * feature_l1 + self.epsilon)
        return ib_loss(loss_vec, sample_weight, reduction=reduction)
