import torch
import torch.nn as nn

class LaMD(nn.Module):
    """Compose an encoder, a decoder and a latent propagator into one model.

    The sub-modules are used as follows (only what ``forward`` relies on):

    * ``encoder(data)`` returns a ``(mu, logvar)`` pair.
    * ``decoder(z)`` returns an iterable of tensors; ``decoder.out_features[i]``
      is the size of the last dimension of the i-th tensor.
    * ``propagator(z)`` returns the propagated latent state; the propagator
      also exposes ``seq_len`` and ``latent_dim``.

    Args:
        encoder: Callable mapping the input data to ``(mu, logvar)``.
        decoder: Callable mapping latents to a sequence of predictions.
        propagator: Callable advancing the latents.
        no_vae: If true, the latent code is ``mu`` itself (no sampling) and
            ``forward`` returns ``None`` for ``mu`` and ``logvar``.
    """

    def __init__(self, encoder, decoder, propagator, no_vae=False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.propagator = propagator
        self.no_vae = no_vae

    def forward(self, data):
        """Encode ``data``, then decode and propagate the resulting latents.

        Returns:
            A 5-tuple ``(latent_target, z_decoded, z_propagated, mu, logvar)``:

            * ``latent_target``: last step of the latent code viewed as
              ``(-1, seq_len + 1, latent_dim)``; the first ``seq_len`` steps
              are the inputs and the last one is the target.
            * ``z_decoded``: tuple of decoder outputs, the i-th one viewed as
              ``(-1, seq_len + 1, decoder.out_features[i])``.
            * ``z_propagated``: raw output of ``propagator(z)``.
            * ``mu``, ``logvar``: encoder outputs viewed as
              ``(-1, seq_len + 1, latent_dim)``, or ``None, None`` when
              ``no_vae`` is set.
        """
        # For now the decoder receives the encoder output, not the states
        # predicted by the propagator.
        # TODO: pass the predicted states of the propagator to the decoder at some point?
        window = self.propagator.seq_len + 1
        latent_dim = self.propagator.latent_dim

        # Encode all data into latent space.
        mu, logvar = self.encoder(data)
        z = mu if self.no_vae else self.reparameterize(mu, logvar)

        # Decode the latents and give every prediction an explicit sequence
        # axis so that the loss is well defined.
        z_decoded = tuple(
            pred.view(-1, window, self.decoder.out_features[i])
            for i, pred in enumerate(self.decoder(z))
        )

        # Propagate the latents.
        z_propagated = self.propagator(z)

        # The target is the last element of each latent sequence.
        latent_target = z.view(-1, window, latent_dim)[:, -1, :]

        if self.no_vae:
            # mu/logvar are not returned in this mode, so they are not
            # reshaped: this keeps encoders that return ``logvar=None`` usable.
            return latent_target, z_decoded, z_propagated, None, None

        return (
            latent_target,
            z_decoded,
            z_propagated,
            mu.view(-1, window, latent_dim),
            logvar.view(-1, window, latent_dim),
        )

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn by ``torch.randn_like``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std