import torch
from models.ae import AutoEncoder


class LaplaceVAE(AutoEncoder):
    """Variational autoencoder with a Laplace posterior and a Laplace(0, 1) prior.

    The encoder produces two heads per latent dimension: a location ``mu`` and a
    second head named ``logvar`` (the name is kept from the Gaussian VAE
    convention). In this class ``logvar`` is interpreted as ``2 * log(b)``, i.e.
    the Laplace scale is ``b = exp(0.5 * logvar)``. Sampling
    (``_reparameterize``), the KL term (``regularization``) and the prior used
    by ``generate`` all share this parameterization.

    ``encoder``, ``enc_to_latent1``, ``enc_to_latent2``, ``decode``,
    ``latent_dim`` and ``device`` are not defined in this class; they are
    assumed to be provided by ``AutoEncoder``.
    """

    def __init__(
        self,
        nu=1.1,
        input_shape=1,
        latent_dim=32,
        reconstruction="l1",
        mode="cnn",
        # Resolved once, when the class body is executed.
        device="cuda" if torch.cuda.is_available() else "cpu",
        normalization="batchnorm",
        activation="relu",
        dropout_rate=0.0,
    ):
        """Build the model; every argument is forwarded unchanged to ``AutoEncoder``.

        See ``AutoEncoder`` for the meaning of each option (including ``nu``).
        """
        super().__init__(
            input_shape=input_shape,
            latent_dim=latent_dim,
            nu=nu,
            reconstruction=reconstruction,
            mode=mode,
            device=device,
            normalization=normalization,
            activation=activation,
            dropout_rate=dropout_rate,
        )
        self.name = "LaplaceVAE"

    def forward(self, x):
        """Encode ``x``, decode the sampled latent, and return ``(z, x_hat, (mu, logvar))``."""
        z, (mu, logvar) = self.encode(x)
        x_hat = self.decode(z)
        return z, x_hat, (mu, logvar)

    def encode(self, x):
        """Map ``x`` to posterior parameters and a reparameterized latent sample.

        Returns:
            ``(z, (mu, logvar))`` where ``z ~ Laplace(mu, exp(0.5 * logvar))``.
        """
        x = self.encoder(x)
        mu = self.enc_to_latent1(x)
        logvar = self.enc_to_latent2(x)
        # Reparameterization trick keeps z differentiable w.r.t. mu and logvar.
        z = self._reparameterize(mu, logvar)
        return z, (mu, logvar)

    @staticmethod
    def _sample_standard_laplace(shape, device=None, dtype=None):
        """Return a tensor of i.i.d. Laplace(0, 1) samples with the given shape.

        Samples are created directly on ``device`` with ``dtype``, so no
        host-to-device copy or dtype promotion is needed afterwards.
        """
        loc = torch.zeros(shape, device=device, dtype=dtype)
        return torch.distributions.Laplace(loc, torch.ones_like(loc)).sample()

    def _reparameterize(self, mu, logvar):
        """Draw ``z ~ Laplace(mu, b)`` with ``b = exp(0.5 * logvar)``, differentiably.

        Computes ``z = mu + b * eps`` with ``eps ~ Laplace(0, 1)``. The noise
        carries no gradient, so gradients reach ``mu`` and ``logvar`` only
        through this deterministic expression.
        """
        scale = torch.exp(0.5 * logvar)
        # Sample on the same device/dtype as the encoder output rather than
        # on the CPU followed by a copy.
        eps = self._sample_standard_laplace(
            scale.shape, device=scale.device, dtype=scale.dtype
        )
        return mu + eps * scale

    def regularization(self, derivatives):
        """Return the KL divergence KL(Laplace(mu, b) || Laplace(0, 1)).

        Closed form per latent dimension, with ``b = exp(0.5 * logvar)``::

            b * exp(-|mu| / b) + |mu| - log(b) - 1

        Args:
            derivatives: the ``(mu, logvar)`` pair returned by ``encode`` /
                ``forward``; assumed shaped (batch, latent_dim) — TODO confirm.

        Returns:
            Scalar tensor: the KL summed over dim 1, then averaged over the
            remaining elements (the batch, for 2-D inputs).
        """
        mu, logvar = derivatives
        # Must use the same scale as _reparameterize (b = exp(0.5 * logvar),
        # so log(b) = 0.5 * logvar); otherwise the KL penalizes a different
        # distribution from the one that is actually sampled.
        log_scale = 0.5 * logvar
        scale = log_scale.exp()
        mu_abs = mu.abs()
        kl_divergence = torch.sum(
            scale * torch.exp(-mu_abs / scale) + mu_abs - log_scale - 1, dim=1
        )
        return torch.mean(kl_divergence)

    def generate(self, n_gen=64):
        """Decode ``n_gen`` latent vectors drawn from the Laplace(0, 1) prior.

        Returns:
            Detached CPU tensor holding the decoder output for the samples.
        """
        # The prior is Laplace(0, 1) (see regularization), so generation must
        # sample from it, not from a standard normal.
        z_gen = self._sample_standard_laplace(
            (n_gen, self.latent_dim), device=self.device
        )
        # The result is detached anyway, so skip building an autograd graph.
        with torch.no_grad():
            x_gen = self.decode(z_gen)
        return x_gen.detach().cpu()
