from abc import abstractmethod
import torch.nn as nn
from models.classifier import ClassifierBN


class BaseModel(nn.Module):
    """
    Base class for autoencoder-like models built around an
    encode -> reparameterize -> decode pipeline.

    Subclasses are expected to override :meth:`encode`, :meth:`reparameterize`
    and :meth:`decode`; :meth:`forward` simply chains the three.

    Note: this class derives only from ``nn.Module`` and does not use
    ``abc.ABCMeta`` as its metaclass, so the ``@abstractmethod`` markers are
    documentation only and are not enforced when the class is instantiated.
    To still fail loudly, the default implementations raise
    ``NotImplementedError`` instead of silently returning ``None``.
    """

    def __init__(self, latent_size):
        """
        :param latent_size: size of the latent representation; stored as
        ``self.latent_size`` for use by subclasses.
        """
        super(BaseModel, self).__init__()
        self.latent_size = latent_size

    @abstractmethod
    def encode(self, x):
        """
        Encode the given tensor into the latent space representation
        :param x: the tensor to be encoded
        :return: a tuple composed of two elements, mu and logvar.
        Mu is the mean of the encoding, and logvar is the log of the variance.
        If no probabilistic latent space is used (as an example in basic autoencoders) just use mu as the
        result of the encoding and set the logvar to zero.
        :raises NotImplementedError: if the subclass does not override this method
        """
        raise NotImplementedError(
            "{} must implement encode()".format(type(self).__name__))

    @abstractmethod
    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick for variational autoencoders
        see Appendix B from VAE paper:
        Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        https://arxiv.org/abs/1312.6114
        If the reparameterization is not required just return mu
        :param mu: the mean of the latent space representation
        :param logvar: the logarithm of the variance of the latent representation
        :return: the latent vector calculated from mu and logvar
        :raises NotImplementedError: if the subclass does not override this method
        """
        raise NotImplementedError(
            "{} must implement reparameterize()".format(type(self).__name__))

    @abstractmethod
    def decode(self, z):
        """
        Decode the latent vector z into a feature map of the same dimension of the input
        :param z: the latent vector to be decoded
        :return: a tensor of the same dimension of the input
        :raises NotImplementedError: if the subclass does not override this method
        """
        raise NotImplementedError(
            "{} must implement decode()".format(type(self).__name__))

    def forward(self, x, y_one_hot=None):
        """
        Run the full pipeline: encode the input, sample/select the latent vector, decode it.
        :param x: the input passed to :meth:`encode`
        :param y_one_hot: not used by this base implementation; it is accepted so that
        callers can use one call signature (presumably for conditional subclasses that
        override ``forward`` -- confirm against the subclasses).
        :return: a tuple ``(reconstruction, mu, logvar)`` where ``reconstruction`` is the
        output of :meth:`decode` and ``mu``, ``logvar`` are the outputs of :meth:`encode`
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    @staticmethod
    def _create_pretrained_encoder(n_classes=50,
                                   weights_file="../pretrained_models/mobilenet_classifier_bn_NC.pth"):
        """
        Build a ``ClassifierBN`` from the given weights file and return its feature extractor.
        With the default arguments this is the MobileNetV1 classifier pretrained on the
        CORe50 dataset.
        :param n_classes: number of classes the classifier was built with (default 50)
        :param weights_file: path of the weights file handed to ``ClassifierBN``. The default
        is a relative path, so it presumably depends on the current working directory --
        pass an explicit path when running from a different location.
        :return: the ``network.features`` attribute of the created classifier
        """
        classifier = ClassifierBN(n_classes=n_classes, weights_file=weights_file)
        return classifier.network.features
