import logging

import torch.nn as nn

from models.networks import Conv1d, ConvDW, ConvTransposeDW, Flatten, Unflatten
from models.base_vae_model import BaseModel
from models.classifier import ClassifierBN


class AE(BaseModel):
    """Deterministic autoencoder sharing the ``BaseModel`` interface.

    Encoder choice:
    - by default, a stack of project ``Conv1d``/``ConvDW`` layers (named after
      the conv5_4/sep .. conv6 stages in the comments) followed by a 4x4
      max-pool and ``Flatten``;
    - with ``pretrained_encoder=True``, the ``features`` module of a
      ``ClassifierBN`` network followed by ``Flatten``.

    The decoder mirrors the hand-built encoder: ``Unflatten(1024, 1, 1)``,
    4x bilinear upsampling, ReLU, then ``ConvTransposeDW``/``Conv1d`` layers
    back to 512 channels.

    ``encode``/``reparameterize``/``forward`` keep the same call shape as the
    stochastic models but return placeholders instead of distribution
    parameters.
    """

    def __init__(self, latent_size: int = 1024, num_classes: int = 50, pretrained_encoder: bool = False):
        """
        :param latent_size: forwarded to ``BaseModel``. NOTE(review): the
            layers below hard-code a 1024-wide latent (``ConvDW(1024, 1024, 1)``
            and ``Unflatten(1024, 1, 1)``), so values other than 1024 are
            presumably unsupported -- confirm.
        :param num_classes: ``n_classes`` for the ``ClassifierBN`` whose
            feature extractor is reused when ``pretrained_encoder`` is True;
            unused otherwise.
        :param pretrained_encoder: if True, build the encoder from
            ``ClassifierBN`` features instead of the hand-built conv stack.
        """
        super().__init__(latent_size)

        if pretrained_encoder:
            self.encoder = nn.Sequential(
                self._create_pretrained_encoder(num_classes),
                Flatten(),
            )
        else:
            self.encoder = nn.Sequential(
                Conv1d(512, 512),  # conv5_4/sep
                ConvDW(512, 512, 1),  # conv5_5
                ConvDW(512, 1024, 2),  # conv5_6
                ConvDW(1024, 1024, 1),  # conv6
                nn.MaxPool2d(kernel_size=4),
                Flatten(),
            )

        self.decoder = nn.Sequential(
            Unflatten(1024, 1, 1),
            # nn.UpsamplingBilinear2d is deprecated; this is its documented
            # equivalent (bilinear interpolation with align_corners=True).
            nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True),
            nn.ReLU(),
            ConvTransposeDW(1024, 1024, 1),  # deconv6
            ConvTransposeDW(1024, 512, 2),  # deconv5_6
            ConvTransposeDW(512, 512, 1),  # deconv5_5
            Conv1d(512, 512),  # conv5_4/sep
        )

    def forward(self, x, y_one_hot=None):
        """Encode ``x`` and decode it back.

        :param x: the batch of features to be reconstructed.
        :param y_one_hot: one-hot class representation; accepted for interface
            compatibility and ignored by the AE.
        :return: tuple ``(reconstruction, 0, 0)``. The two trailing zeros are
            placeholders, presumably standing in for the (mu, logvar) pair a
            VAE returns -- confirm against callers.
        """
        latent_vector, _ = self.encode(x)
        return self.decode(latent_vector), 0, 0

    def encode(self, x):
        """Map ``x`` to its latent vector.

        :return: tuple ``(latent, 0)``; the second element is a constant
            placeholder kept so callers can unpack two values.
        """
        return self.encoder(x), 0

    def decode(self, z):
        """Map latent vector ``z`` back to feature space via the decoder."""
        return self.decoder(z)

    def reparameterize(self, mu, logvar):
        """Return ``mu`` unchanged; ``logvar`` is ignored (no sampling in an AE)."""
        return mu

    @staticmethod
    def _create_pretrained_encoder(num_classes: int = 50):
        """Build the feature-extraction part of a ``ClassifierBN`` network.

        Only ``classifier.network.features`` is returned; the rest of the
        classifier is discarded.

        NOTE(review): ``weights_file=None`` is passed, so whether the returned
        features carry pretrained (e.g. ImageNet MobileNetV1) weights depends
        on ``ClassifierBN``'s behaviour in that case -- confirm.

        :param num_classes: ``n_classes`` passed to ``ClassifierBN``; defaults
            to 50, the value previously hard-coded here.
        :return: the ``features`` module of the constructed classifier.
        """
        logging.getLogger(__name__).debug("pretrained ae create pretrained encoder")
        classifier = ClassifierBN(n_classes=num_classes, weights_file=None)
        return classifier.network.features
