import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import math


def loss_kd(outputs, labels, teacher_outputs, params):
    """
    Knowledge Distillation (KD) loss: hard-label cross-entropy blended with a
    temperature-softened KL term between student and teacher.

    Args:
        outputs: student logits, classes along dim 1.
        labels: ground-truth class indices accepted by ``F.cross_entropy``.
        teacher_outputs: teacher logits with the same layout as ``outputs``.
        params: object exposing ``alpha`` (weight of the soft term) and
            ``temperature`` (softmax temperature).

    Returns:
        Scalar tensor ``(1 - alpha) * CE + alpha * T^2 * KLDivLoss``.
    """
    soft_weight = params.alpha
    temperature = params.temperature

    hard_term = F.cross_entropy(outputs, labels)

    # Both distributions are softened with the same temperature before comparison.
    student_log_probs = F.log_softmax(outputs / temperature, dim=1)
    teacher_probs = F.softmax(teacher_outputs / temperature, dim=1)

    # nn.KLDivLoss is deliberately left at its default reduction; the T^2 factor
    # rescales the soft term after the logits were divided by T.
    kl_criterion = nn.KLDivLoss()
    soft_term = kl_criterion(student_log_probs, teacher_probs) * (temperature * temperature)

    return (1. - soft_weight) * hard_term + soft_weight * soft_term


def loss_kd_self(outputs, labels, teacher_outputs, params):
    """
    Self-training distillation loss (Tf-KD_{self}).

    Same structure as a standard KD loss, with the soft term additionally
    scaled by ``params.multiplier``.

    Args:
        outputs: student logits, classes along dim 1.
        labels: ground-truth class indices accepted by ``F.cross_entropy``.
        teacher_outputs: logits of the (self) teacher, same layout as ``outputs``.
        params: object exposing ``alpha``, ``temperature`` and ``multiplier``.

    Returns:
        Scalar tensor ``(1 - alpha) * CE + alpha * T^2 * multiplier * KLDivLoss``.
    """
    soft_weight = params.alpha
    temperature = params.temperature

    hard_term = F.cross_entropy(outputs, labels)

    student_log_probs = F.log_softmax(outputs / temperature, dim=1)
    teacher_probs = F.softmax(teacher_outputs / temperature, dim=1)

    # KLDivLoss keeps its default reduction. The original note says the
    # multiplier is 1.0 in most cases and 10 or 50 in some.
    raw_kl = nn.KLDivLoss()(student_log_probs, teacher_probs)
    soft_term = raw_kl * (temperature * temperature) * params.multiplier

    return (1. - soft_weight) * hard_term + soft_weight * soft_term


#
def loss_kd_regularization(outputs, labels, params):
    """
    Loss for the manually designed regularization: Tf-KD_{reg}.

    A hand-crafted "virtual teacher" assigns probability ``correct_prob`` to the
    ground-truth class and spreads the remainder evenly over the other classes.
    That distribution is softened with the temperature and compared to the
    student's predictions with ``nn.KLDivLoss`` (default reduction).

    Args:
        outputs: student logits, classes along dim 1.
        labels: 1-D tensor of int64 class indices, one per row of ``outputs``.
        params: object exposing ``alpha``, ``temperature`` and ``multiplier``.

    Returns:
        Scalar tensor ``(1 - alpha) * CE + alpha * multiplier * KLDivLoss``.
    """
    alpha = params.alpha
    T = params.temperature
    correct_prob = 0.99  # probability of the correct class in the virtual teacher u(k)
    loss_CE = F.cross_entropy(outputs, labels)
    K = outputs.size(1)

    # Build the teacher on the same device/dtype as ``outputs`` instead of forcing
    # CUDA, so the loss also works on CPU and on non-default GPUs.
    teacher_soft = torch.full_like(outputs, 1 - correct_prob) / (K - 1)  # p^d(k)
    # One vectorised scatter replaces the per-sample Python loop.
    teacher_soft.scatter_(1, labels.unsqueeze(1), correct_prob)

    # Note: only the teacher is temperature-scaled; the student log-softmax is not.
    loss_soft_regu = nn.KLDivLoss()(F.log_softmax(outputs, dim=1),
                                    F.softmax(teacher_soft / T, dim=1)) * params.multiplier

    KD_loss = (1. - alpha) * loss_CE + alpha * loss_soft_regu

    return KD_loss


def loss_kd_regularizationo(outputs, labels, params):
    """
    Experimental self-regularization loss with fixed hyper-parameters.

    The student's own temperature-softened log-probabilities are turned into a
    detached pseudo-teacher distribution, and ``nn.KLDivLoss`` (default
    reduction) between the two is added to the hard-label cross-entropy.

    Args:
        outputs: student logits, classes along dim 1.
        labels: ground-truth class indices accepted by ``F.cross_entropy``.
        params: unused; ``alpha`` and ``T`` are hard-coded in this variant.

    Returns:
        Scalar tensor ``(1 - alpha) * CE + alpha * T^2 * KL_self``.
    """
    alpha = 0.5
    T = 5

    # log_softmax (rather than log(softmax(.))) keeps the log-probabilities finite
    # when logits are far apart; the two-step form underflows to -inf.
    teacher_tmp = F.log_softmax(outputs / T, dim=1)
    teacher_outputs = torch.exp(1 / (T * T) * teacher_tmp)
    loss_CE = F.cross_entropy(outputs, labels)
    teacher_sfm = F.softmax(teacher_outputs, dim=1)

    # The pseudo-teacher is detached so gradients only flow through the student term.
    KL_self = nn.KLDivLoss()(teacher_tmp, teacher_sfm.detach())

    KD_loss = (1. - alpha) * loss_CE + alpha * T * T * KL_self

    return KD_loss


# best
def divergence(student_logits, teacher_logits):
    """
    Mean over all leading dimensions of ``-sum(student_logits * teacher_logits)``
    taken along the last dimension.

    Despite the parameter names, the callers in this file pass student
    log-probabilities as the first argument and teacher probabilities as the
    second, which makes this a cross-entropy-style term.
    """
    per_sample = -(student_logits * teacher_logits).sum(dim=-1)
    return per_sample.mean()





def loss_pseudo_kd_new(outputs, labels, params):
    """
    Pseudo-teacher distillation loss with an extra uniform-prior term.

    The pseudo-teacher is the student's own log-probabilities scaled by
    ``1 / lambda_p`` and re-normalised with a softmax. The loss combines the
    hard-label cross-entropy, a ``divergence`` term against the detached
    pseudo-teacher, and ``nn.KLDivLoss`` (default reduction) between a uniform
    distribution and the pseudo-teacher.

    Args:
        outputs: student logits, classes along dim 1.
        labels: ground-truth class indices accepted by ``F.cross_entropy``.
        params: object exposing ``alpha`` and ``lambda_p``.

    Returns:
        Scalar tensor
        ``(1 - alpha) * CE + alpha * KL_cross + alpha * lambda_p * KL_uniform``.
    """
    alpha = params.alpha
    lambda_p = params.lambda_p
    K = outputs.size(1)

    # log_softmax (rather than log(softmax(.))) avoids -inf log-probabilities,
    # which would otherwise turn the product below into NaN (-inf * 0).
    teacher_tmp = F.log_softmax(outputs, dim=1)
    teacher_outputs = 1 / (lambda_p) * teacher_tmp
    loss_CE = F.cross_entropy(outputs, labels)
    teacher_sfm = F.softmax(teacher_outputs, dim=1)

    # The pseudo-teacher is detached here so this term only trains the student side.
    KL_cross = divergence(teacher_tmp, teacher_sfm.detach())

    # Log of the uniform distribution over K classes; this term is NOT detached.
    log_uniform = torch.full_like(outputs, -math.log(K))
    KL_uniform = nn.KLDivLoss()(log_uniform, teacher_sfm)

    KD_loss = (1. - alpha) * loss_CE + alpha * KL_cross + alpha * lambda_p * KL_uniform

    return KD_loss


def loss_pseudo_kd(outputs, labels, params):
    """
    Pseudo-teacher distillation loss.

    The pseudo-teacher is built from the student's own log-probabilities,
    scaled by ``1 / lambda_p`` and re-normalised with a softmax; it is detached
    before being compared to the student via ``divergence``.

    Args:
        outputs: student logits, classes along dim 1.
        labels: ground-truth class indices accepted by ``F.cross_entropy``.
        params: object exposing ``alpha`` and ``lambda_p``.

    Returns:
        Scalar tensor ``(1 - alpha) * CE + alpha * KL_cross``.
    """
    soft_weight = params.alpha
    lambda_p = params.lambda_p

    student_log_probs = torch.log_softmax(outputs, dim=1)
    scaled_log_probs = 1 / (lambda_p) * student_log_probs
    hard_term = F.cross_entropy(outputs, labels)

    # Detach so the pseudo-teacher acts as a fixed target for this step.
    pseudo_teacher = F.softmax(scaled_log_probs, dim=1).detach()
    soft_term = divergence(student_log_probs, pseudo_teacher)

    return (1. - soft_weight) * hard_term + soft_weight * soft_term


def loss_pseudo_kd_self(outputs, labels, outputs_self, params):
    """
    Pseudo-teacher distillation loss where the teacher comes from a second set
    of logits.

    The pseudo-teacher is the softmax of ``log_softmax(outputs_self) / lambda_p``,
    detached, and compared to the student's log-probabilities via ``divergence``.

    Args:
        outputs: student logits, classes along dim 1.
        labels: ground-truth class indices accepted by ``F.cross_entropy``.
        outputs_self: logits used to build the pseudo-teacher, same layout as
            ``outputs``.
        params: object exposing ``alpha`` and ``lambda_p``.

    Returns:
        Scalar tensor ``(1 - alpha) * CE + alpha * KL_cross``.
    """
    soft_weight = params.alpha
    lambda_p = params.lambda_p

    student_log_probs = torch.log_softmax(outputs, dim=1)
    hard_term = F.cross_entropy(outputs, labels)

    # Teacher side: built from ``outputs_self`` only.
    self_log_probs = torch.log_softmax(outputs_self, dim=1)
    scaled_self = 1 / (lambda_p) * self_log_probs
    pseudo_teacher = F.softmax(scaled_self, dim=1).detach()

    soft_term = divergence(student_log_probs, pseudo_teacher)

    return (1. - soft_weight) * hard_term + soft_weight * soft_term


# Module-level cross-entropy criterion with default settings.
loss_fc = nn.CrossEntropyLoss()


def loss_CE(outputs, labels, params):
    """
    Plain cross-entropy baseline.

    ``params`` is accepted but ignored; only ``outputs`` and ``labels`` are
    forwarded to the module-level ``loss_fc`` criterion.
    """
    plain_ce = loss_fc(input=outputs, target=labels)
    return plain_ce



def loss_label_smoothing(outputs, labels, params):
    """
    Loss for label smoothing regularization.

    The target distribution puts ``1 - alpha`` on the ground-truth class and
    ``alpha / (C - 1)`` on every other class; the loss is the batch-averaged
    cross-entropy between that target and the softmax of ``outputs``.

    Args:
        outputs: logits with batch along dim 0 and classes along dim 1.
        labels: 1-D tensor of int64 class indices, one per row of ``outputs``.
        params: object exposing ``alpha`` (total mass moved off the true class).

    Returns:
        Scalar tensor with the smoothed cross-entropy.
    """
    alpha = params.alpha
    N = outputs.size(0)  # batch size
    C = outputs.size(1)  # number of classes

    # Create the targets directly on outputs' device/dtype rather than forcing
    # CUDA, so the loss works on CPU and on whichever GPU holds the logits.
    smoothed_labels = torch.full_like(outputs, alpha / (C - 1))
    smoothed_labels.scatter_(dim=1, index=torch.unsqueeze(labels, dim=1), value=1 - alpha)

    log_prob = F.log_softmax(outputs, dim=1)
    loss = -torch.sum(log_prob * smoothed_labels) / N

    return loss


def soft_beta_loss(outputs, labels, outputs_orig, params):
    """
    Beta-smoothing loss: cross-entropy blended with an NLL against randomly
    smoothed targets whose sharpness follows the confidence of ``outputs_orig``.

    Per-sample smoothing levels are drawn from ``Beta(beta, 1)`` with NumPy's
    global RNG and sorted ascending. Samples are sorted ascending by the
    probability ``outputs_orig`` gives to their true class, so the i-th least
    confident sample receives the i-th smallest draw ``b`` as its true-class
    target, with ``(1 - b) / (num_classes - 1)`` on every other class.

    Args:
        outputs: logits being trained, batch along dim 0, classes along dim 1.
        labels: class indices, one per row of ``outputs``.
        outputs_orig: reference logits used only to rank the samples.
        params: unused; ``beta`` and ``alpha`` are hard-coded.

    Returns:
        Scalar tensor ``(1 - alpha) * CE(outputs, labels) + alpha * beta_NLL``.
    """
    beta = 3
    alpha = 0.3
    # Keep every helper tensor on outputs' device instead of forcing CUDA.
    device = outputs.device

    softmaxes = F.softmax(outputs, dim=1)
    n, num_classes = softmaxes.shape
    tensor_labels = torch.zeros(n, num_classes, device=device).scatter_(
        1, labels.long().view(-1, 1), 1)

    # Rank samples by the reference model's confidence in the true class.
    softmaxes_orig = F.softmax(outputs_orig, dim=1)
    true_class_conf, _ = (softmaxes_orig * tensor_labels).max(dim=1)
    _, indices = true_class_conf.sort()

    sorted_softmax, sorted_labels = softmaxes[indices], tensor_labels[indices]

    # Draw one smoothing level per sample; sorted so rank i pairs with rank i.
    random_beta = np.random.beta(beta, 1, n)
    random_beta.sort()
    random_beta = torch.from_numpy(random_beta).to(device)

    # Off-class mass and the extra mass added on the true class (b - uniform),
    # so each target row sums to one.
    uniform = (1 - random_beta) / (num_classes - 1)
    peak = (random_beta - uniform).view(-1, 1).float()
    beta_label = sorted_labels * peak + uniform.view(-1, 1).float()

    # NLL against the smoothed targets; the epsilon guards against log(0).
    loss = -beta_label * torch.log(sorted_softmax + 1e-8)
    loss = loss.sum() / n
    loss = (1 - alpha) * F.cross_entropy(outputs, labels) + alpha * loss
    return loss