import torch
import torch.nn as nn
from models.pretrained_classifier import get_pretrained_classifier_bn


class ReconstructionLoss(nn.Module):
    """Element-wise reconstruction loss, mean-reduced over all elements.

    Args:
        loss_mode: ``'l1'`` selects ``nn.L1Loss`` (mean absolute error);
            ``'l2'`` selects ``nn.MSELoss`` (mean squared error).

    Raises:
        NotImplementedError: if ``loss_mode`` is neither ``'l1'`` nor ``'l2'``.
    """

    def __init__(self, loss_mode):
        super().__init__()
        if loss_mode == 'l1':
            print("use L1 loss")
            self.loss = nn.L1Loss(reduction="mean")
        elif loss_mode == 'l2':
            self.loss = nn.MSELoss(reduction="mean")
        else:
            raise NotImplementedError('loss mode %s not implemented' % loss_mode)

    def forward(self, reconstruction, real_data):
        """Return the mean-reduced loss between ``reconstruction`` and ``real_data``."""
        # Implemented as ``forward`` rather than overriding ``__call__`` so that
        # nn.Module's call machinery (forward hooks etc.) still runs.
        return self.loss(reconstruction, real_data)


class KLDLoss(nn.Module):
    """Closed-form KL divergence of N(mu, exp(logvar)) from a standard normal.

    The per-element term is ``-0.5 * (1 + logvar - mu**2 - exp(logvar))``.
    It is averaged over dim 1 and then over all remaining elements.

    Note that dim 1 is *averaged*, not summed. For 2-D ``(batch, latent)``
    inputs the result is therefore the usual batch-mean KL divided by
    ``size(1)``. The 2-D layout is an assumption; TODO confirm against callers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, mu, logvar):
        """Return the scalar KL term for ``mu`` and ``logvar`` (same shape)."""
        # Implemented as ``forward`` rather than overriding ``__call__`` so that
        # nn.Module's call machinery (forward hooks etc.) still runs.
        per_element = 1 + logvar - mu.pow(2) - logvar.exp()
        return -0.5 * torch.mean(torch.mean(per_element, dim=1))


class DistillationLoss(nn.Module):
    """Cross-entropy of a pretrained, eval-mode classifier's output on ``features``.

    The classifier is built once with ``get_pretrained_classifier_bn``, moved
    to ``device``, and kept in eval mode whenever the loss is evaluated.

    Args:
        device: device the classifier is moved to.
        num_classes: number of classifier outputs (default keeps the original
            hard-coded value of 50).
        weights_file: path of the pretrained weights passed to
            ``get_pretrained_classifier_bn``. The default is a relative path,
            presumably resolved against the current working directory.

    NOTE(review): the classifier's parameters still have ``requires_grad``
    set. Gradients flow into them, and they appear in this module's
    ``parameters()``. Confirm they are never handed to an optimizer.
    """

    def __init__(
        self,
        device,
        num_classes=50,
        weights_file="../pretrained_models/mobilenet_classifier_bn_NC.pth",
    ):
        super().__init__()
        self.classifier = get_pretrained_classifier_bn(
            num_classes=num_classes, weights_file=weights_file
        )
        self.classifier.to(device)
        self.classifier.eval()
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, features, labels):
        """Return the cross-entropy between the classifier's output and ``labels``."""
        # ``self.classifier`` is a registered submodule, so any ``.train()`` on
        # this loss (or on a parent module holding it) flips the classifier back
        # to training mode. Its normalization layers (presumably BatchNorm, per
        # the ``_bn`` loader name) would then update their running statistics.
        # Re-assert eval mode on every call.
        self.classifier.eval()
        output = self.classifier(features)
        return self.cross_entropy(output, labels)


class DistillationLossWithClassifier(nn.Module):
    """Cross-entropy of a caller-supplied classifier's eval-mode output on ``features``.

    The classifier is called as ``classifier(None, features,
    return_lat_acts=False)``. Presumably the second positional argument feeds
    features in at a latent layer; confirm against the classifier definition.
    """

    def __init__(self):
        super().__init__()
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, classifier, features, labels):
        """Return cross-entropy of ``classifier``'s eval-mode prediction vs ``labels``.

        The classifier runs in eval mode for this call only. Every submodule's
        train/eval flag is restored afterwards, even if the forward raises.
        """
        # Snapshot per-submodule modes. A blanket ``classifier.train(True)``
        # afterwards would also flip children the caller deliberately froze.
        saved_modes = [(module, module.training) for module in classifier.modules()]
        classifier.eval()
        try:
            output = classifier(None, features, return_lat_acts=False)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        return self.cross_entropy(output, labels)


class DistillationLossWithClassifierAndSoftLabels(nn.Module):
    """KL-divergence between a classifier's eval-mode predictions and soft labels.

    The classifier is called as ``classifier(None, features,
    return_lat_acts=False)``. Its output is converted to log-probabilities
    over dim 1 and compared with ``soft_labels`` using ``KLDivLoss``
    (``reduction="batchmean"``).

    Assumes ``soft_labels`` are probabilities, not log-probabilities, with the
    same ``(batch, classes)`` shape as the classifier output. TODO confirm
    against callers.
    """

    def __init__(self):
        super().__init__()
        self.kldiv_loss = torch.nn.KLDivLoss(reduction="batchmean")
        # Explicit class dimension. The implicit-dim form is deprecated, emits
        # a warning on every call, and chooses the axis from the input's rank.
        self.log_softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, classifier, features, soft_labels):
        """Return KL(soft_labels || softmax(classifier output)), batch-averaged.

        The classifier runs in eval mode for this call only. Every submodule's
        train/eval flag is restored afterwards, even if the forward raises.
        """
        # Snapshot per-submodule modes. A blanket ``classifier.train(True)``
        # afterwards would also flip children the caller deliberately froze.
        saved_modes = [(module, module.training) for module in classifier.modules()]
        classifier.eval()
        try:
            output = classifier(None, features, return_lat_acts=False)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        # KLDivLoss takes log-probabilities as input and, by default
        # (log_target=False), plain probabilities as target.
        return self.kldiv_loss(self.log_softmax(output), soft_labels)
