import torch
from models.ae import AutoEncoder


class VAE(AutoEncoder):
    """Variational autoencoder built on top of ``AutoEncoder``.

    The encoder output is mapped to the mean and log-variance of a
    diagonal Gaussian posterior. A latent sample is drawn with the
    reparameterization trick and then decoded. ``regularization`` returns
    the batch-mean KL divergence between that posterior and a
    standard-normal prior.

    NOTE(review): ``encoder``, ``enc_to_latent1``, ``enc_to_latent2``,
    ``decode``, ``latent_dim`` and ``device`` are not defined here and are
    presumably set up by ``AutoEncoder.__init__``. Confirm against
    ``models.ae``.
    """

    def __init__(
        self,
        nu=1.1,
        input_shape=1,
        latent_dim=32,
        reconstruction="l1",
        mode="cnn",
        device="cuda" if torch.cuda.is_available() else "cpu",
        normalization="batchnorm",
        activation="relu",
        dropout_rate=0.0,
    ):
        """Forward every configuration argument unchanged to ``AutoEncoder``.

        The default ``device`` is chosen once, when the module is imported:
        "cuda" if available, otherwise "cpu".
        """
        super().__init__(
            input_shape=input_shape,
            latent_dim=latent_dim,
            nu=nu,
            reconstruction=reconstruction,
            mode=mode,
            device=device,
            normalization=normalization,
            activation=activation,
            dropout_rate=dropout_rate,
        )
        self.name = "VAE"

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            ``(z, x_hat, (mu, logvar))``: the sampled latent, the
            reconstruction, and the posterior parameters. The parameters
            are in the form ``regularization`` expects.
        """
        z, (mu, logvar) = self.encode(x)
        x_hat = self.decode(z)
        return z, x_hat, (mu, logvar)

    def encode(self, x):
        """Map ``x`` to a sampled latent and its posterior parameters.

        Returns:
            ``(z, (mu, logvar))``. ``z`` is always a random sample, even
            when the module is in eval mode.
        """
        x = self.encoder(x)
        mu = self.enc_to_latent1(x)
        logvar = self.enc_to_latent2(x)
        # Reparameterization trick: keeps the sample differentiable w.r.t. mu/logvar.
        z = self._reparameterize(mu, logvar)
        return z, (mu, logvar)

    def _reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def regularization(self, derivatives):
        """Return the batch-mean KL(N(mu, exp(logvar)) || N(0, I)).

        Args:
            derivatives: a ``(mu, logvar)`` pair whose first dimension is
                the batch dimension.

        Returns:
            A scalar tensor. For each sample, the closed-form KL terms are
            summed over all non-batch dimensions, so latents of any rank
            (e.g. ``(B, D)`` or ``(B, C, H, W)``) are handled. The per-sample
            values are then averaged over the batch.
        """
        mu, logvar = derivatives
        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        # flatten(1) makes the per-sample sum cover every latent dimension,
        # not just dim 1; for 2-D latents this equals sum(dim=1).
        kl_per_sample = -0.5 * kl_terms.flatten(start_dim=1).sum(dim=1)
        return torch.mean(kl_per_sample)

    def generate(self, n_gen=64):
        """Decode ``n_gen`` latent vectors drawn from N(0, I).

        Returns:
            The decoded samples as a detached CPU tensor.

        NOTE(review): this does not switch the module to eval mode. If the
        decoder contains batchnorm/dropout layers (see ``normalization`` /
        ``dropout_rate``), call ``.eval()`` first to use inference behavior.
        """
        # Sampling needs no autograd graph, and the output is detached anyway.
        with torch.no_grad():
            # Allocate directly on the target device instead of host -> device copy.
            z_gen = torch.randn(n_gen, self.latent_dim, device=self.device)
            x_gen = self.decode(z_gen)
        return x_gen.detach().cpu()
