import torch
import torch.nn as nn
from torch.nn import functional as F
from .baseline import baseline

class LVAE(baseline):
    """Laplace variational autoencoder.

    The encoder maps ``x`` to a location ``mu`` and a log-scale parameter
    ``logvar`` per latent dimension. Latents are sampled as
    ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ Laplace(0, 1)``, i.e. the
    approximate posterior is ``Laplace(mu, b)`` with ``b = exp(0.5 * logvar)``.
    The prior is the standard Laplace ``Laplace(0, 1)``.

    Args:
        n_dim: Data dimensionality (input width of the encoder, output width
            of the decoder).
        m_dim: Latent dimensionality.
        nu: Stored as ``self.nu``; not read anywhere in this class
            (presumably kept for signature parity with sibling models —
            TODO confirm).
        recon_sigma: Divisor applied to the summed squared reconstruction
            error in ``recon_loss``.
        reg_weight: Multiplier of the KL regulariser in ``total_loss``.
        num_hidden: Width of the single hidden layer in encoder and decoder.
        device: Stored as ``self.device``. Sampled noise follows the
            device/dtype of the module's own tensors, so moving the module
            with ``.to(...)`` is sufficient.
    """

    def __init__(self, n_dim=1, m_dim=1, nu=3, recon_sigma=1, reg_weight=1, num_hidden=64, device='cpu'):
        super().__init__(n_dim=n_dim, m_dim=m_dim)
        self.model_name = "LVAE"
        self.nu = nu
        self.n_dim = n_dim
        self.m_dim = m_dim
        self.recon_sigma = recon_sigma
        self.reg_weight = reg_weight
        self.num_hidden = num_hidden
        # Legacy alias: despite its name this holds the hidden *width*, not a
        # layer count; kept so existing code reading ``num_layers`` still works.
        self.num_layers = num_hidden
        self.device = device

        # Encoder: x -> hidden features; two linear heads give (mu, logvar).
        self.encoder = nn.Sequential(
            nn.Linear(n_dim, num_hidden),
            nn.LeakyReLU(),
            nn.BatchNorm1d(num_hidden),
        )
        self.latent_mu = nn.Linear(num_hidden, m_dim)
        self.latent_logvar = nn.Linear(num_hidden, m_dim)

        # Decoder: z -> hidden features -> reconstruction in data space.
        self.decoder = nn.Sequential(
            nn.Linear(m_dim, num_hidden),
            nn.LeakyReLU(),
            nn.BatchNorm1d(num_hidden),
            nn.Linear(num_hidden, n_dim),
        )
        # Standard Laplace: the prior, and the source of reparameterisation noise.
        self.laplace_dist = torch.distributions.Laplace(loc=0.0, scale=1.0)

    def _standard_laplace(self, shape, like):
        """Draw ``Laplace(0, 1)`` noise of ``shape`` on ``like``'s device and dtype."""
        # Follow the tensor we combine with rather than ``self.device`` so the
        # module keeps working after ``.to(...)`` / ``.double()`` etc.
        return self.laplace_dist.sample(sample_shape=shape).to(like)

    def encoder_reparameterize(self, mu, logvar):
        """Sample ``z ~ Laplace(mu, exp(0.5 * logvar))`` with the reparameterisation trick.

        Args:
            mu: Posterior location.
            logvar: Posterior log-scale parameter; the Laplace scale used is
                ``exp(0.5 * logvar)`` (same convention as ``reg_loss``).

        Returns:
            A tensor shaped like ``mu``, differentiable w.r.t. ``mu`` and ``logvar``.
        """
        scale = torch.exp(0.5 * logvar)
        eps = self._standard_laplace(scale.shape, scale)
        return mu + scale * eps

    def encode(self, x):
        """Encode ``x`` and draw one latent sample.

        ``x`` is expected to be 2-D ``(batch, n_dim)`` (first layer is
        ``nn.Linear(n_dim, ...)``; losses reduce over dim 1 then dim 0).

        Returns:
            ``(z, mu, logvar)``: the sampled latent and the posterior parameters.
        """
        hidden = self.encoder(x)
        mu = self.latent_mu(hidden)
        logvar = self.latent_logvar(hidden)
        z = self.encoder_reparameterize(mu, logvar)
        return z, mu, logvar

    def decode(self, z):
        """Map latents ``z`` back to data space."""
        return self.decoder(z)

    def recon_loss(self, x, recon_x):
        """Squared reconstruction error, summed over features and averaged over the batch.

        The result is divided by ``self.recon_sigma``.
        NOTE(review): a Gaussian likelihood would divide by ``2 * sigma**2``;
        left unchanged here — confirm the intended scaling against the other
        models derived from ``baseline``.
        """
        return F.mse_loss(recon_x, x, reduction='none').sum(dim=1).mean(dim=0) / self.recon_sigma

    def reg_loss(self, mu, logvar):
        """KL( Laplace(mu, b) || Laplace(0, 1) ), including the constant term.

        ``b = exp(0.5 * logvar)`` is the same scale ``encoder_reparameterize``
        samples with, so the regulariser is evaluated for the distribution that
        is actually drawn from. Per latent dimension the closed form is::

            |mu| + b * exp(-|mu| / b) - log(b) - 1

        Returns:
            The KL summed over latent dimensions and averaged over the batch.
        """
        # Previously the KL used b = exp(logvar) while sampling used
        # b = exp(0.5 * logvar); both now share the same scale.
        log_scale = 0.5 * logvar
        scale = log_scale.exp()
        mu_abs = mu.abs()
        kl = scale * torch.exp(-mu_abs / scale) + mu_abs - log_scale - 1
        return kl.sum(dim=1).mean()

    def total_loss(self, x, recon_x, mu, logvar):
        """Return ``(recon, reg, recon + reg_weight * reg)``."""
        recon = self.recon_loss(x, recon_x)
        reg = self.reg_loss(mu, logvar)
        return recon, reg, recon + self.reg_weight * reg

    def generate(self, N=1000):
        """Decode ``N`` latents drawn from the ``Laplace(0, 1)`` prior.

        Returns a tensor of shape ``(N, n_dim)``. Autograd and the module's
        train/eval mode are left untouched (in train mode BatchNorm uses the
        batch's statistics), so callers should set ``eval()`` /
        ``torch.no_grad()`` themselves if needed.
        """
        # Place the noise wherever the decoder's weights live.
        eps = self._standard_laplace((N, self.m_dim), self.decoder[0].weight)
        return self.decoder(eps)

    def recon_data(self, x):
        """Reconstruct ``x`` through one stochastic encode/decode pass."""
        enc_z, _, _ = self.encode(x)
        return self.decode(enc_z)

    def forward(self, x, verbose=False):
        """Run encode -> sample -> decode and compute the losses.

        Args:
            x: Input batch.
            verbose: When falsy, return only the loss tuple.

        Returns:
            If not ``verbose``: ``(recon, reg, total)`` from ``total_loss``.
            Otherwise a list ``[loss_tuple, mu, exp(logvar), total]`` where the
            middle two entries are flattened, detached NumPy arrays.
        """
        enc_z, mu, logvar = self.encode(x)
        recon_x = self.decode(enc_z)
        # Compute the losses once and reuse them in both return paths.
        losses = self.total_loss(x, recon_x, mu, logvar)
        if not verbose:
            return losses
        return [
            losses,
            mu.detach().flatten().cpu().numpy(),
            logvar.detach().flatten().exp().cpu().numpy(),
            losses[2],
        ]

        

        


