import torch
import torch.nn as nn


class ReconstructionLoss(nn.Module):
    """Element-wise reconstruction loss, averaged over all elements.

    Args:
        loss_mode: ``'l1'`` selects ``nn.L1Loss`` and ``'l2'`` selects
            ``nn.MSELoss``; both use ``reduction="mean"``.

    Raises:
        NotImplementedError: If ``loss_mode`` is neither ``'l1'`` nor ``'l2'``.
    """

    def __init__(self, loss_mode):
        super(ReconstructionLoss, self).__init__()
        if loss_mode == 'l1':
            print("use L1 loss")
            self.loss = nn.L1Loss(reduction="mean")
        elif loss_mode == 'l2':
            self.loss = nn.MSELoss(reduction="mean")
        else:
            raise NotImplementedError('loss mode %s not implemented' % loss_mode)

    # Implemented as ``forward`` rather than ``__call__`` so that
    # ``nn.Module.__call__`` keeps dispatching registered hooks; invoking the
    # instance as ``loss(reconstruction, real_data)`` works exactly as before.
    def forward(self, reconstruction, real_data):
        """Return the selected loss between ``reconstruction`` and ``real_data``.

        Args:
            reconstruction: Tensor produced by the model.
            real_data: Target tensor passed to the wrapped criterion.

        Returns:
            The scalar tensor produced by the wrapped criterion.
        """
        return self.loss(reconstruction, real_data)


class KLDLoss(nn.Module):
    """KL-style regulariser computed from a mean and a log-variance tensor.

    Evaluates ``-0.5 * (1 + logvar - mu**2 - exp(logvar))``, averages it over
    dimension 1, then averages the result over the remaining elements.
    """

    def __init__(self):
        super(KLDLoss, self).__init__()

    # Implemented as ``forward`` rather than ``__call__`` so that
    # ``nn.Module.__call__`` keeps dispatching registered hooks; invoking the
    # instance as ``loss(mu, logvar)`` works exactly as before.
    def forward(self, mu, logvar):
        """Return the averaged KL term for ``mu`` and ``logvar``.

        Args:
            mu: Mean tensor; must have at least two dimensions because the
                inner reduction runs over ``dim=1``.
            logvar: Log-variance tensor, broadcast-compatible with ``mu``.

        Returns:
            A scalar tensor.
        """
        # Note: the per-sample reduction is a mean over dim 1, not a sum.
        per_sample = torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        return -0.5 * torch.mean(per_sample)


class DistillationLoss(nn.Module):
    """Cross-entropy between a classifier's output on ``features`` and ``labels``."""

    def __init__(self):
        super(DistillationLoss, self).__init__()
        self.cross_entropy = nn.CrossEntropyLoss()

    # Implemented as ``forward`` rather than ``__call__`` so that
    # ``nn.Module.__call__`` keeps dispatching registered hooks; invoking the
    # instance as ``loss(classifier, features, labels, latent_layer)`` works
    # exactly as before.
    def forward(self, classifier, features, labels, latent_layer):
        """Run ``classifier`` on ``features`` and score the output against ``labels``.

        Args:
            classifier: Callable module accepting the keyword arguments
                ``x``, ``y``, ``latent_layer`` and ``return_lat_acts``.
            features: Passed to the classifier as ``y`` (``x`` is ``None``);
                presumably intermediate activations -- confirm against the
                classifier's definition.
            labels: Targets handed to ``nn.CrossEntropyLoss``.
            latent_layer: Forwarded unchanged to the classifier.

        Returns:
            The cross-entropy loss tensor.

        Note:
            ``classifier.eval()`` is called and the previous training mode is
            not restored afterwards.
        """
        classifier.eval()
        output = classifier(x=None, y=features, latent_layer=latent_layer, return_lat_acts=False)
        return self.cross_entropy(output, labels)


class DistillationLossWithClassifierAndSoftLabels(nn.Module):
    """KL divergence between a classifier's log-softmax output and soft labels."""

    def __init__(self):
        super(DistillationLossWithClassifierAndSoftLabels, self).__init__()
        self.kldiv_loss = torch.nn.KLDivLoss(reduction="batchmean")
        # An explicit ``dim`` avoids the deprecated implicit-dimension
        # inference (which emits a warning). ``dim=1`` is what that inference
        # selects for 2-D ``(batch, classes)`` outputs.
        self.log_softmax = torch.nn.LogSoftmax(dim=1)

    # Implemented as ``forward`` rather than ``__call__`` so that
    # ``nn.Module.__call__`` keeps dispatching registered hooks; invoking the
    # instance as ``loss(classifier, features, soft_labels)`` works exactly
    # as before.
    def forward(self, classifier, features, soft_labels):
        """Run ``classifier`` on ``features`` and compare the output with ``soft_labels``.

        Args:
            classifier: Callable module invoked as
                ``classifier(None, features, return_lat_acts=False)``.
            features: Second positional argument given to the classifier.
            soft_labels: Target tensor for ``nn.KLDivLoss``; with that loss's
                default settings the target is interpreted as probabilities
                (not log-probabilities).

        Returns:
            The ``batchmean``-reduced KL divergence tensor.

        Note:
            ``classifier.eval()`` is called and the previous training mode is
            not restored afterwards.
        """
        classifier.eval()
        output = classifier(None, features, return_lat_acts=False)
        return self.kldiv_loss(self.log_softmax(output), soft_labels)
