import torch
import torch.nn as nn
import torch.distributions as dist
from torch.nn import functional as F
import pytorch_lightning as pl



# Module-level defaults consumed by ``VAE.__init__`` below.
LATENT_DIM = 64        # default value for the ``latent_dim`` argument
INPUT_DATA_DIM = 784   # default ``input_size`` (presumably flattened 28x28 inputs -- TODO confirm)
MODEL_STR = "VAE"      # default value for the ``model_name`` argument


class VAE(pl.LightningModule):
    """Variational auto-encoder trained by minimising a (beta-weighted) negative ELBO.

    The encoder, decoder and modality objects are injected by the caller:

    * ``encoder(x)`` must return a ``(mu, logvar)`` pair.
    * ``decoder(z)`` returns the reconstruction (or its parameters).
    * ``modality`` must expose a ``name`` attribute, used as a key to pull this
      model's input out of a batch, and a ``log_prob(x, recon)`` method used
      for the reconstruction term.
    """

    def __init__(self, encoder, decoder, modality,
                 latent_dim=LATENT_DIM,
                 input_size=INPUT_DATA_DIM,
                 model_name=MODEL_STR,
                 beta=1.0):
        """Store the sub-networks and hyper-parameters.

        Args:
            encoder: module mapping an input to ``(mu, logvar)``.
            decoder: module mapping a latent sample to a reconstruction.
            modality: object providing ``name`` and ``log_prob(x, recon)``.
            latent_dim: stored on the instance; not used by this class itself.
            input_size: stored on the instance; not used by this class itself.
            model_name: stored as ``self.modelName``.
            beta: weight applied to the KL term during training. The default
                of 1.0 gives the plain (unweighted) ELBO.
        """
        super().__init__()

        self.latent_dim = latent_dim
        self.input_size = input_size

        self.encoder = encoder
        self.decoder = decoder
        self.modality = modality
        self.modelName = model_name
        self.beta = beta

    def training_step(self, x, batch_idx=None):
        """Run one optimisation step and log the loss terms.

        Args:
            x: the batch handed over by the trainer. Its first element is
                indexed with ``self.modality.name`` to obtain the input tensor
                (assumes a mapping keyed by modality name -- TODO confirm
                against the dataloader).
            batch_idx: index of the batch as supplied by Lightning; unused.
                Optional so that direct calls with only the batch keep working.

        Returns:
            ``{"loss": elbo}`` where ``elbo`` is the value to minimise.
        """
        x = x[0][self.modality.name]
        batch_size = x.size(0)

        # Call the module rather than ``forward`` directly so that any
        # registered forward hooks are executed.
        recon, mu, logvar = self(x)

        recon_loss = self.reconstruction_loss(x, recon, batch_size)
        kld = self.kld_loss(mu, logvar, batch_size)
        elbo = self.elbo_objectif(recon_loss, kld, beta=self.beta)

        self.log("loss", elbo)
        self.log("recon_loss", recon_loss)
        self.log("KLD", kld)
        return {"loss": elbo}

    def encode(self, x):
        """Return whatever the injected encoder produces for ``x``."""
        return self.encoder(x)

    def decode(self, z):
        """Return whatever the injected decoder produces for ``z``."""
        return self.decoder(z)

    def reparam(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is ``exp(0.5 * log_var)``. Sampling ``eps`` separately keeps
        the result differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent with ``reparam`` and decode it.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``.
        """
        # Gaussian posterior: the encoder yields its mean and log-variance.
        mu, logvar = self.encode(x)
        z = self.reparam(mu, logvar)
        return self.decode(z), mu, logvar

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters (lr=0.001, betas=(0.9, 0.999))."""
        return torch.optim.Adam(self.parameters(), lr=0.001, betas=(0.9, 0.999))

    def reconstruction_loss(self, x, recon, batch_size):
        """Return ``-modality.log_prob(x, recon) / batch_size``.

        NOTE(review): presumably ``log_prob`` returns a value already summed
        over the batch, making this a per-sample average -- confirm against
        the modality implementation.
        """
        return -self.modality.log_prob(x, recon) / batch_size

    def kld_loss(self, mu, logvar, batch_size):
        """Return the closed-form KL(N(mu, exp(logvar)) || N(0, I)) term.

        The divergence is summed over every element of ``mu``/``logvar`` and
        then divided by ``batch_size``.
        """
        kl_sum = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return kl_sum / batch_size

    def elbo_objectif(self, reconstruction_loss, KLD, beta=1.0):
        """Combine both terms into the objective: ``reconstruction_loss + beta * KLD``."""
        return reconstruction_loss + beta * KLD
    