import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks import Conv1d, ConvTransposeDW, Unflatten
from models.base_vae_model import BaseModel


class PretrainedCVAESharedEncoder(BaseModel):
    """Conditional VAE built around a (pretrained) feature encoder.

    ``encoder`` output is flattened and projected to a 256-d hidden vector,
    from which the posterior mean and log-variance are predicted. The sampled
    latent code is concatenated with a one-hot label vector, projected to
    1024 features by ``ext_dec`` and passed through the convolutional
    ``decoder``.
    """

    def __init__(self, encoder: nn.Module, latent_size=1024, num_classes=50):
        """
        :param encoder: feature extractor; its output, flattened from dim 1,
            must have 1024 features per sample (the input size of ``ext_enc``).
        :param latent_size: dimensionality of the latent code; forwarded to
            ``BaseModel`` (which presumably stores it as ``self.latent_size``,
            read below — confirm in models.base_vae_model).
        :param num_classes: length of the one-hot label vectors concatenated
            to the latent code before decoding.
        """
        super().__init__(latent_size)

        self.encoder = encoder

        # Encoder head: 1024 flattened features -> 256 hidden -> (mu, logvar).
        self.ext_enc = nn.Linear(1024, 256)
        self.mu = nn.Linear(256, self.latent_size)
        self.logvar = nn.Linear(256, self.latent_size)
        # Decoder entry: [z, one-hot label] -> 1024 features for the conv stack.
        self.ext_dec = nn.Linear(self.latent_size + num_classes, 1024)

        self.decoder = nn.Sequential(
            # NOTE(review): presumably reshapes (B, 1024) -> (B, 1024, 1, 1);
            # confirm against models.networks.Unflatten.
            Unflatten(1024, 1, 1),
            nn.UpsamplingBilinear2d(scale_factor=4),
            nn.ReLU(),
            ConvTransposeDW(1024, 1024, 1),  # deconv6
            ConvTransposeDW(1024, 512, 2),  # deconv5_6
            ConvTransposeDW(512, 512, 1),  # deconv5_5
            Conv1d(512, 512)  # conv5_4/sep
        )

    def forward(self, x, y_one_hot=None):
        """
        :param x: the batch of features to be reconstructed
        :param y_one_hot: the one-hot encoding vectors of labels, one row per
            sample in ``x`` (concatenated along dim 1 with the latent code)
        :return: tuple ``(reconstruction, mu, logvar)``
        :raises ValueError: if ``y_one_hot`` is not provided.
        """
        # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
        if y_one_hot is None:
            raise ValueError("labels should be provided to the forward method")
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        # Match z's dtype so integer one-hot tensors (e.g. from F.one_hot)
        # concatenate with the float latent code.
        z = torch.cat((z, y_one_hot.to(z.dtype)), dim=1)
        return self.decode(z), mu, logvar

    def encode(self, x):
        """Return the posterior parameters ``(mu, logvar)`` for the batch ``x``."""
        h = self.encoder(x).flatten(start_dim=1)
        # torch.tanh is the non-deprecated equivalent of F.tanh.
        h = torch.tanh(self.ext_enc(h))
        return self.mu(h), self.logvar(h)

    def decode(self, z):
        """Map a ``[latent, one-hot label]`` batch through ``ext_dec`` and the decoder."""
        z = self.ext_dec(z)
        return self.decoder(z)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` in training mode; return ``mu`` in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def load_decoder(self, weight_file):
        """Load a state dict for ``self.decoder`` from ``weight_file``.

        Tensors are mapped to CPU first so checkpoints saved on a GPU also load
        on CPU-only machines; ``load_state_dict`` then copies them into the
        decoder's existing parameters on whatever device those live.

        NOTE(review): ``torch.load`` unpickles the file — only use it with
        trusted checkpoints.
        """
        state_dict = torch.load(weight_file, map_location="cpu")
        self.decoder.load_state_dict(state_dict)
