import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def focal_loss(input_values, gamma):
    """Apply the focal modulating factor to per-sample losses and average them.

    Args:
        input_values: tensor of per-sample losses interpreted as ``-log(p_t)``
            (``FocalLoss`` passes unreduced cross-entropy here).
        gamma: focusing exponent; ``gamma == 0`` reduces this to a plain mean.

    Returns:
        Scalar tensor ``mean((1 - p_t) ** gamma * input_values)``.
    """
    # Recover p_t from -log(p_t).
    p_t = input_values.neg().exp()
    modulating_factor = (1 - p_t) ** gamma
    return (modulating_factor * input_values).mean()

def label_smooth(labels, class_count, epsilon):
    """Turn integer class labels into label-smoothed target distributions.

    The true class receives ``1 - epsilon + epsilon / class_count``; every other
    class receives ``epsilon / class_count``.

    Args:
        labels: tensor of class indices; cast to int64 before one-hot encoding.
        class_count: number of classes (width of the one-hot encoding).
        epsilon: smoothing amount, must lie in [0, 1].

    Returns:
        Float tensor of shape ``labels.shape + (class_count,)``.
    """
    assert 0 <= epsilon <= 1.0
    uniform_share = epsilon / class_count
    one_hot = F.one_hot(labels.to(torch.int64), class_count).float()
    return one_hot * (1.0 - epsilon) + uniform_share

class FocalLoss(nn.Module):
    """Focal loss computed on top of unreduced cross-entropy.

    Args:
        weight: optional per-class weight tensor forwarded to ``F.cross_entropy``.
        gamma: non-negative focusing exponent passed to ``focal_loss``.
    """

    def __init__(self, weight=None, gamma=2.0):
        super().__init__()
        assert gamma >= 0
        self.gamma = gamma
        self.weight = weight

    def forward(self, logit, target):
        # NOTE: when ``weight`` is set, focal_loss derives p_t from the *weighted*
        # cross-entropy, so the class weight also enters the modulating factor.
        per_sample_ce = F.cross_entropy(logit, target, weight=self.weight, reduction='none')
        return focal_loss(per_sample_ce, self.gamma)


class LDAMLoss(nn.Module):
    """Label-distribution-aware margin (LDAM) loss.

    The ground-truth logit of every sample is reduced by ``s * m_c`` before the
    cross-entropy. ``m_c`` is proportional to ``n_c ** -0.25``, where ``n_c`` is
    the sample count of class ``c``. The margins are rescaled so the largest one
    equals ``max_m``, which gives rarer classes larger margins.

    Args:
        cls_num_list: per-class sample counts (tensor, list or NumPy array).
        max_m: margin assigned to the class with the fewest samples.
        s: scale applied to the margin only; the incoming logits are assumed
            to be already scaled.
        weight: optional per-class weight tensor forwarded to
            ``F.cross_entropy`` (``None`` means unweighted).
    """

    def __init__(self, cls_num_list, max_m=0.5, s=30, weight=None):
        super().__init__()
        counts = torch.as_tensor(cls_num_list, dtype=torch.float32)
        m_list = 1.0 / torch.sqrt(torch.sqrt(counts))
        m_list = m_list * (max_m / torch.max(m_list))
        # A non-persistent buffer follows ``module.to(device)`` without adding a
        # state_dict entry, and works on CPU-only machines.
        self.register_buffer('m_list', m_list, persistent=False)
        self.s = s
        self.weight = weight

    def forward(self, logit, target):
        # Margin of each sample's ground-truth class, shaped (N, 1) so it
        # broadcasts across the class dimension. The buffer is moved to the
        # logits' device in case the module itself was never moved.
        batch_m = self.m_list.to(logit.device)[target].view(-1, 1)
        logit_m = logit - batch_m * self.s  # scale only the margin, as the logit is already scaled.

        # torch.where needs a boolean condition; uint8 masks are deprecated /
        # rejected by newer PyTorch releases.
        is_target = F.one_hot(target, num_classes=logit.size(1)).bool()
        output = torch.where(is_target, logit_m, logit)
        return F.cross_entropy(output, target, weight=self.weight)


class ClassBalancedLoss(nn.Module):
    """Cross-entropy re-weighted by the inverse effective number of samples.

    A class with ``n_c`` samples gets the raw weight ``(1 - beta) / (1 - beta ** n_c)``.
    The raw weights are then divided by their mean, so the final weights average to 1.

    Args:
        cls_num_list: tensor of per-class sample counts.
        beta: hyper-parameter of the effective-number formula.
    """

    def __init__(self, cls_num_list, beta=0.9):
        super().__init__()
        denominator = 1.0 - (beta ** cls_num_list)
        raw_weights = (1.0 - beta) / denominator
        self.per_cls_weights = raw_weights / torch.mean(raw_weights)

    def forward(self, logit, target):
        # Cast logits to the weights' dtype (which follows cls_num_list) so
        # cross_entropy sees matching dtypes.
        weights = self.per_cls_weights
        return F.cross_entropy(logit.to(weights.dtype), target, weight=weights)


class GeneralizedReweightLoss(nn.Module):
    """Cross-entropy with per-class weights ``(n_c / N) ** -exp_scale``, rescaled to mean 1.

    ``exp_scale = 1`` is inverse-frequency weighting; ``exp_scale = 0`` yields uniform weights.

    Args:
        cls_num_list: tensor of per-class sample counts.
        exp_scale: exponent applied to the class frequency.
    """

    def __init__(self, cls_num_list, exp_scale=1.0):
        super().__init__()
        class_freq = cls_num_list / torch.sum(cls_num_list)
        raw_weights = 1.0 / (class_freq ** exp_scale)
        self.per_cls_weights = raw_weights / torch.mean(raw_weights)

    def forward(self, logit, target):
        # Cast logits to the weights' dtype (which follows cls_num_list) so
        # cross_entropy sees matching dtypes.
        weights = self.per_cls_weights
        return F.cross_entropy(logit.to(weights.dtype), target, weight=weights)


class BalancedSoftmaxLoss(nn.Module):
    """Balanced softmax: cross-entropy on logits shifted by the log class frequency.

    Args:
        cls_num_list: tensor of per-class sample counts (one entry per class).
    """

    def __init__(self, cls_num_list):
        super().__init__()
        class_freq = cls_num_list / torch.sum(cls_num_list)
        self.log_cls_num = torch.log(class_freq)

    def forward(self, logit, target):
        # Broadcast the per-class log frequency over the batch dimension.
        shifted_logit = logit + self.log_cls_num[None, :]
        return F.cross_entropy(shifted_logit, target)


class LogitAdjustedLoss(nn.Module):
    """Logit-adjusted cross-entropy: logits are shifted by ``tau * log(class frequency)``.

    Args:
        cls_num_list: tensor of per-class sample counts (one entry per class).
        tau: scale of the log-frequency adjustment.
    """

    def __init__(self, cls_num_list, tau=2.0):
        super().__init__()
        class_freq = cls_num_list / torch.sum(cls_num_list)
        self.log_cls_num = torch.log(class_freq)
        self.tau = tau

    def forward(self, logit, target):
        # Broadcast the scaled per-class log frequency over the batch dimension.
        adjustment = self.tau * self.log_cls_num[None, :]
        return F.cross_entropy(logit + adjustment, target)

class LabelSmoothLoss(nn.Module):
    """Wrap a soft-target criterion with label smoothing.

    Integer targets are converted into smoothed distributions before being passed
    on as ``criterion(logit, smooth_target)``. The true class gets
    ``1 - epsilon + epsilon / class_count``; every other class gets
    ``epsilon / class_count``.

    Args:
        criterion: callable taking ``(logit, soft_target)``.
        class_count: number of classes (width of the one-hot encoding).
        epsilon: smoothing amount, must lie in [0, 1].
    """

    def __init__(self, criterion, class_count, epsilon):
        super().__init__()
        assert 0 <= epsilon <= 1.0
        self.criterion = criterion
        self.class_count = class_count
        self.epsilon = epsilon

    def forward(self, logit, target):
        # F.one_hot only accepts int64 indices. Cast so int32/uint8 targets also
        # work, matching the module-level label_smooth helper.
        onehot_target = F.one_hot(target.to(torch.int64), self.class_count).float()

        confidence = 1.0 - self.epsilon
        smooth_target = onehot_target * confidence + self.epsilon / self.class_count

        return self.criterion(logit, smooth_target)
        
        
class LADELoss(nn.Module):
    """Prior-adjusted cross-entropy plus the LADE ReMINE regulariser.

    The returned value is ``CE(logit + log prior) + estim_loss_weight * R``.
    ``R`` is the negated, class-frequency-weighted sum of per-class ReMINE lower bounds.

    Args:
        cls_num_list: tensor of per-class sample counts.
        remine_lambda: weight of the squared second-term penalty in ReMINE.
        estim_loss_weight: weight of the regulariser relative to the cross-entropy.
    """

    def __init__(self, cls_num_list, remine_lambda=0.01, estim_loss_weight=0.1):
        super().__init__()
        self.num_classes = len(cls_num_list)
        # Empirical class frequency.
        self.prior = cls_num_list / torch.sum(cls_num_list)
        # Uniform prior 1/C, on the same device as the empirical prior.
        self.balanced_prior = torch.tensor(1. / self.num_classes).float().to(self.prior.device)
        self.remine_lambda = remine_lambda
        # Per-class weights of the regulariser (same values as self.prior).
        self.cls_weight = cls_num_list / torch.sum(cls_num_list)
        self.estim_loss_weight = estim_loss_weight

    def mine_lower_bound(self, x_p, x_q, num_samples_per_cls):
        """Per-row MINE lower bound.

        In ``forward``, ``x_p`` and ``x_q`` are (C, N) and ``num_samples_per_cls`` is (C,).

        Returns:
            Tuple ``(first_term - second_term, first_term, second_term)`` where
            ``first_term`` is the row sum of ``x_p`` divided by the class count, and
            ``second_term`` is ``logsumexp`` of ``x_q`` over the last axis minus ``log N``.
        """
        batch_size = x_p.size(-1)
        joint_term = x_p.sum(-1) / (num_samples_per_cls + 1e-8)
        marginal_term = x_q.logsumexp(-1) - np.log(batch_size)
        return joint_term - marginal_term, joint_term, marginal_term

    def remine_lower_bound(self, x_p, x_q, num_samples_per_cls):
        """ReMINE bound: the MINE bound minus ``remine_lambda * second_term ** 2``."""
        bound, joint_term, marginal_term = self.mine_lower_bound(x_p, x_q, num_samples_per_cls)
        penalty = (marginal_term ** 2) * self.remine_lambda
        return bound - penalty, joint_term, marginal_term

    def forward(self, logit, target):
        shifted_logit = logit + torch.log(self.prior).unsqueeze(0)
        ce_loss = F.cross_entropy(shifted_logit, target)

        # (C, N) mask: entry [c, i] is True when sample i has label c.
        class_ids = torch.arange(0, self.num_classes).view(-1, 1).type_as(target)
        is_member = target == class_ids

        per_cls_pred_spread = logit.T * is_member  # C x N
        pred_spread = (logit - torch.log(self.prior + 1e-9) + torch.log(self.balanced_prior + 1e-9)).T  # C x N
        num_samples_per_cls = torch.sum(is_member, -1).float()  # C

        bounds, _, _ = self.remine_lower_bound(per_cls_pred_spread, pred_spread, num_samples_per_cls)
        estim_loss = -torch.sum(bounds * self.cls_weight)

        return ce_loss + self.estim_loss_weight * estim_loss

    

class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.

    The student's distillation output may be a single Tensor or a dict of Tensors (one per
    distillation head). For a dict, the per-head distillation losses are averaged.

    Args:
        base_criterion: criterion applied to ``(outputs, labels)``.
        teacher_model: model whose predictions on ``inputs`` supervise the student; it is
            run under ``torch.no_grad()``.
        distillation_type: ``'none'``, ``'soft'`` (temperature-scaled KL divergence) or
            ``'hard'`` (cross-entropy against the teacher's argmax).
        alpha: weight of the distillation term; the base loss is weighted by ``1 - alpha``.
        temperature: softmax temperature used by ``'soft'`` distillation.
        cls_num_list: tensor of per-class sample counts; its log frequency is stored in
            ``self.log_cls_num`` (not used by ``forward``).
    """

    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, temperature: float, cls_num_list):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.temperature = temperature

        # Log class frequency. Not used by forward(); an earlier variant added it to a
        # 'dist_tail' head's logits before computing that head's distillation loss.
        cls_num_ratio = cls_num_list / torch.sum(cls_num_list)
        self.log_cls_num = torch.log(cls_num_ratio)

    def _compute_distillation_loss(self, outputs_kd, teacher_outputs):
        """Return the distillation loss of one student head against the teacher outputs."""
        if self.distillation_type == 'soft':
            T = self.temperature
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications. The teacher targets are given as log-probabilities
            # (log_target=True), as recommended in the PyTorch docs.
            # Dividing by outputs_kd.numel() reproduces the legacy PyTorch behaviour;
            # see https://github.com/facebookresearch/deit/issues/61 for details.
            return F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True,
            ) * (T * T) / outputs_kd.numel()
        if self.distillation_type == 'hard':
            return F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
        raise ValueError(f"Unsupported distillation_type: {self.distillation_type!r}")

    def forward(self, inputs, outputs, labels):     # inputs = images
        """
        Args:
            inputs: The original inputs that are fed to the teacher model.
            outputs: the outputs of the model to be trained. Either a Tensor, or a tuple
                ``(outputs, outputs_kd)`` with the original output first. ``outputs_kd`` is
                a Tensor or a dict of Tensors holding the distillation prediction(s).
            labels: the labels for the base criterion.

        Returns:
            ``base_loss`` when ``distillation_type == 'none'``. Otherwise
            ``(1 - alpha) * base_loss + alpha * mean over heads of the distillation loss``.

        Raises:
            ValueError: if distillation is enabled but no distillation output is provided.
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of (outputs, outputs_kd)
            outputs, outputs_kd = outputs

        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            kd_heads = []
        elif isinstance(outputs_kd, torch.Tensor):
            kd_heads = [outputs_kd]
        else:
            kd_heads = list(outputs_kd.values())
        # Fail before running the teacher, and avoid dividing by zero heads below.
        if not kd_heads:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a tuple (outputs, outputs_kd) where outputs_kd "
                             "is a Tensor or a non-empty dict of Tensors (e.g. the dist_token output)")

        # don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        distillation_loss = sum(
            self._compute_distillation_loss(head, teacher_outputs) for head in kd_heads
        ) / len(kd_heads)
        return base_loss * (1 - self.alpha) + distillation_loss * self.alpha