import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def loss_kd(outputs, labels, teacher_outputs, params):
    """Knowledge Distillation (KD) loss.

    Blends hard-label cross-entropy with a temperature-softened KL term
    computed against the teacher's predictions::

        (1 - alpha) * CE(outputs, labels)
        + alpha * KLDiv(log_softmax(outputs / T), softmax(teacher_outputs / T)) * T^2

    Args:
        outputs: student logits; the softmax is taken over dim 1.
        labels: class targets accepted by ``F.cross_entropy``.
        teacher_outputs: teacher logits, same layout as ``outputs``.
        params: object exposing ``alpha`` (mixing weight) and
            ``temperature`` (softening temperature ``T``).

    Returns:
        Scalar loss tensor.
    """
    temperature = params.temperature
    weight = params.alpha

    soft_student = F.log_softmax(outputs / temperature, dim=1)
    soft_teacher = F.softmax(teacher_outputs / temperature, dim=1)
    # reduction='mean' (the KLDivLoss default) averages over every element,
    # not only over the batch dimension.
    distill_term = F.kl_div(soft_student, soft_teacher, reduction='mean')
    # T^2 rescaling is the usual KD convention (presumably to keep the
    # soft-target gradient scale comparable across temperatures).
    distill_term = distill_term * (temperature * temperature)

    hard_term = F.cross_entropy(outputs, labels)
    return (1. - weight) * hard_term + weight * distill_term

def loss_kd_self(outputs, labels, teacher_outputs, params):
    """Self-training loss Tf-KD_{self}.

    Same form as the standard KD loss, but the distillation term is
    additionally scaled by ``params.multiplier``::

        (1 - alpha) * CE(outputs, labels)
        + alpha * KLDiv(log_softmax(outputs / T), softmax(teacher_outputs / T)) * T^2 * multiplier

    Args:
        outputs: student logits; the softmax is taken over dim 1.
        labels: class targets accepted by ``F.cross_entropy``.
        teacher_outputs: teacher logits, same layout as ``outputs``.
        params: object exposing ``alpha``, ``temperature`` and ``multiplier``.

    Returns:
        Scalar loss tensor.
    """
    T = params.temperature
    mix = params.alpha

    log_p_student = F.log_softmax(outputs / T, dim=1)
    p_teacher = F.softmax(teacher_outputs / T, dim=1)
    # Per the original author's note, the multiplier is 1.0 in most runs
    # and 10 or 50 in a few.
    soft_term = F.kl_div(log_p_student, p_teacher, reduction='mean') * (T * T) * params.multiplier

    ce_term = F.cross_entropy(outputs, labels)
    return (1. - mix) * ce_term + mix * soft_term

#
def loss_kd_regularization(outputs, labels, params, correct_prob=0.99):
    """Loss for the manually designed regularization Tf-KD_{reg}.

    Builds a hand-crafted "virtual teacher" u(k) per sample: ``correct_prob``
    on the ground-truth class and ``(1 - correct_prob) / (K - 1)`` on each of
    the other K - 1 classes. It is softened with ``softmax(u / T)`` and its KL
    term is mixed with cross-entropy::

        (1 - alpha) * CE(outputs, labels)
        + alpha * KLDiv(log_softmax(outputs), softmax(u / T)) * multiplier

    As in the original code, only the virtual teacher is divided by T; the
    student log-probabilities are not.

    Args:
        outputs: logits; the class dimension is dim 1.
        labels: int64 class indices, one per row of ``outputs``, on the same
            device as ``outputs``.
        params: object exposing ``alpha_L``, ``temperature_L`` and
            ``multiplier_L``.
        correct_prob: probability assigned to the true class in u(k).
            Defaults to 0.99, the previously hard-coded value.

    Returns:
        Scalar loss tensor.
    """
    alpha = params.alpha_L
    T = params.temperature_L
    loss_CE = F.cross_entropy(outputs, labels)
    K = outputs.size(1)

    # p^d(k): allocate on outputs' own device/dtype. The previous `.cuda()`
    # crashed on CPU and mismatched devices when outputs lived on another GPU.
    teacher_soft = torch.full_like(outputs, (1 - correct_prob) / (K - 1))
    # Vectorised replacement for the former per-sample Python loop.
    teacher_soft.scatter_(1, labels.unsqueeze(1), correct_prob)
    loss_soft_regu = nn.KLDivLoss()(
        F.log_softmax(outputs, dim=1),
        F.softmax(teacher_soft / T, dim=1),
    ) * params.multiplier_L

    return (1. - alpha) * loss_CE + alpha * loss_soft_regu




# best
def divergence(student_logits, teacher_logits):
    """Average of ``-sum_k student_logits[..., k] * teacher_logits[..., k]``.

    The sum runs over the last dimension, and the mean is taken over all
    remaining positions. Despite the parameter names, ``loss_pseudo_kd`` in
    this file passes log-probabilities as ``student_logits`` and
    probabilities as ``teacher_logits``. In that case the value is the
    cross-entropy H(teacher, student). That cross-entropy differs from
    KL(teacher || student) by the teacher's entropy, which is constant with
    respect to the student when the teacher is detached.
    """
    per_position = (student_logits * teacher_logits).sum(dim=-1)
    return per_position.neg().mean()


def loss_pseudo_kd(outputs, labels, args):
    """Self-distillation loss against a rescaled copy of the student's own predictions.

    The "pseudo teacher" is ``softmax(log_softmax(outputs) / beta)``. That is
    the student's predicted distribution raised to the power ``1 / beta`` and
    renormalized, then detached from the graph. The loss mixes the mean
    cross-entropy to ``labels`` with ``divergence`` between the student's
    log-probabilities and that pseudo teacher::

        (1 - alpha) * CE(outputs, labels) + alpha * divergence(log_p, pseudo_teacher)

    Args:
        outputs: logits; the class dimension is dim 1.
        labels: class targets accepted by ``F.cross_entropy``.
        args: object exposing ``alpha_L`` (mixing weight) and ``beta_L``
            (divisor applied to the log-probabilities; must be non-zero).

    Returns:
        Scalar loss tensor.
    """
    alpha = args.alpha_L
    beta = args.beta_L
    # log_softmax instead of log(softmax(.)): softmax can underflow to 0 for
    # very confident rows, giving log(0) = -inf and then -inf * 0 = NaN in
    # the divergence below.
    log_probs = F.log_softmax(outputs, dim=1)
    # Per-sample CE averaged manually (kept from the original; differs from
    # reduction='mean' only when labels contain ignore_index entries).
    loss_CE = F.cross_entropy(outputs, labels, reduction='none')
    # Pseudo teacher p ** (1 / beta), renormalized; detached so it is a fixed target.
    teacher_sfm = F.softmax(log_probs / beta, dim=1).detach()
    KL_cross = divergence(log_probs, teacher_sfm)
    return (1. - alpha) * loss_CE.mean() + alpha * KL_cross



def loss_label_smoothing(outputs, labels, args, smoothing=0.1):
    """Cross-entropy against label-smoothed targets.

    Each row's target puts ``1 - smoothing`` on the true class and
    ``smoothing / (C - 1)`` on each of the other C - 1 classes. The loss is
    ``-sum(log_softmax(outputs) * targets) / N``, i.e. averaged over the batch.

    Args:
        outputs: logits of shape (N, C).
        labels: int64 class indices of shape (N,), on the same device as
            ``outputs``.
        args: unused; accepted so the signature matches the other loss
            functions that take ``(outputs, labels, args)``.
        smoothing: probability mass moved off the true class. Defaults to
            0.1, the previously hard-coded value.

    Returns:
        Scalar loss tensor.
    """
    N = outputs.size(0)  # batch size
    C = outputs.size(1)  # number of classes
    # Allocate on outputs' device instead of calling .cuda(), which failed
    # on CPU-only runs and could land on a different GPU than `outputs`.
    smoothed_labels = torch.full(size=(N, C), fill_value=smoothing / (C - 1),
                                 device=outputs.device)
    smoothed_labels.scatter_(dim=1, index=labels.unsqueeze(1), value=1 - smoothing)

    log_prob = F.log_softmax(outputs, dim=1)
    return -torch.sum(log_prob * smoothed_labels) / N


def loss_ce(outputs, labels, args):
    """Plain cross-entropy baseline; ``args`` is accepted for a uniform signature but ignored."""
    return F.cross_entropy(input=outputs, target=labels)
