import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from .utils import conjugate_gradient
from .constants import CUDA, GPU_MAYBE, Q_SAMPLES_VAE, DIM_HIDDEN, DIM_LATENT

class CustomLinear(torch.nn.Module):
    """Bias-free linear map ``x -> x @ W``.

    The weight is stored with shape ``(dim_in, dim_out)`` (the transpose of the
    ``torch.nn.Linear`` layout) and initialised uniformly in ``[-0.1, 0.1]``.

    Args:
        dim_in: size of the last axis of the input.
        dim_out: size of the last axis of the output.
        marker: optional free-form tag stored on the module; this class does
            not read it itself.
    """

    def __init__(self, dim_in, dim_out, marker=None):
        super().__init__()

        initial_weight = torch.Tensor(size=(dim_in, dim_out)).uniform_(-0.1, 0.1)
        self.W = torch.nn.Parameter(initial_weight)  # learnable weight
        self.marker = marker

    def forward(self, x):
        """Return ``x @ W`` (``F.linear`` multiplies by the transpose of its weight)."""
        return F.linear(x, self.W.T)


def matrix_vector_product_reparameterise(W, L, p):
    """Apply ``W^T diag(L_b) W`` to each row ``p_b`` of ``p``.

    Args:
        W: ``(dim_latent, dim_out)`` matrix.
        L: ``(batch, dim_latent)`` per-row diagonal weights.
        p: ``(batch, dim_out)`` vectors.

    Returns:
        ``(batch, dim_out)`` tensor whose row ``b`` is ``((p_b W^T) * L_b) W``.
    """
    # project each row into latent space: (batch, dim_latent)
    latent = torch.einsum("bo,lo->bl", p, W)
    latent *= L  # per-row diagonal scaling, in place
    # map back to output space: (batch, dim_out)
    return torch.einsum("bl,lo->bo", latent, W)


def dot_product_reparameterise(a, b):
    """Row-wise dot product of ``a`` and ``b``, kept as a ``(batch, 1)`` column."""
    elementwise = torch.mul(a, b)
    return torch.sum(elementwise, dim=1, keepdim=True)


class Reparameteriser(torch.nn.Module):
    """Sample a diagonal Gaussian and project it through a linear layer.

    The output is ``act(W^T v)`` row-wise, where
    ``v = v_mean + sqrt(v_var) * eps`` and ``act`` is ``tanh`` when
    ``use_tanh`` is set (identity otherwise). In training mode with
    ``reparam == "r2g2"`` the forward value is the same, but the gradient of
    the noise term is routed through a conjugate-gradient-derived surrogate
    (see ``_r2g2_projection``). Every other combination uses the plain
    reparameterisation trick.

    Args:
        dim_latent: size of the sampled latent vector.
        dim_out: size of the projected output.
        reparam: ``"r2g2"`` enables the surrogate gradient during training;
            any other value uses the standard reparameterisation.
        use_tanh: apply ``torch.tanh`` to the projected output.
    """

    def __init__(self, dim_latent, dim_out, reparam="r2g2", use_tanh=True):
        super(Reparameteriser, self).__init__()

        # Decoder parameters (first layer)
        self.W = torch.nn.Parameter(
            torch.Tensor(size=(dim_latent, dim_out)).uniform_(-0.1, 0.1)
        )

        self.reparam = reparam
        # Tensor.to is not in-place, so calling it on self.W would do nothing;
        # device placement is left to Module.to on this module or its owner.
        self.gpu_maybe = GPU_MAYBE
        self.use_tanh = use_tanh

    def forward(self, eps, v_mean, v_var):
        """Project a reparameterised Gaussian sample through ``W``.

        Args:
            eps: standard-normal noise with the same shape as ``v_mean``.
            v_mean: ``(batch, dim_latent)`` means.
            v_var: ``(batch, dim_latent)`` variances (non-negative).

        Returns:
            ``(batch, dim_out)`` tensor, passed through ``tanh`` when
            ``use_tanh`` is set.
        """
        v_std = v_var.sqrt()

        if self.training and self.reparam == "r2g2":
            wv = self._r2g2_projection(eps, v_mean, v_var, v_std)
        else:
            # reparameterisation step, then decoder step
            v = v_mean + v_std * eps
            wv = F.linear(v, self.W.T)

        return torch.tanh(wv) if self.use_tanh else wv

    def _r2g2_projection(self, eps, v_mean, v_var, v_std):
        """Return ``W^T (v_mean + v_std * eps)`` with the R2G2 gradient surrogate.

        The returned value equals the plain projection. Gradients of the noise
        term flow only through ``W^T (r2g2_eps * v_std)``, where ``r2g2_eps``
        is computed without gradient from a conjugate-gradient solve.
        """
        W = self.W

        with torch.no_grad():
            W_detach = W.detach()
            v_var_detach = v_var.detach()
            v_std_detach = v_var_detach.sqrt()
            # forward value of the noise term: W^T (sqrt(v_var) * eps)
            fwd_wv_std_eps = F.linear(v_std_detach * eps, W_detach.T)

            # conjugate_gradient (from .utils) is assumed to solve
            # A_b beta_b = fwd_b per row, with A_b p = W^T diag(v_var_b) W p
            # as given by matrix_vector_product_reparameterise — confirm in utils.
            dim_latent, dim_out = W.shape
            r2g2_beta = conjugate_gradient(
                W=W_detach,
                V=v_var_detach,
                b=fwd_wv_std_eps,
                matrix_vector_product_function=matrix_vector_product_reparameterise,
                dot_product_function=dot_product_reparameterise,
                iters=min(dim_latent, dim_out),
            )  # (batch, dim_out)
            r2g2_eps = torch.einsum("bo,lo->bl", r2g2_beta, W_detach)
            r2g2_eps *= v_std_detach

        # Differentiable surrogate. The detached difference makes the forward
        # value equal fwd_wv_std_eps while gradients flow only through
        # r2g2_wv_std_eps (stop-gradient trick).
        r2g2_wv_std_eps = F.linear(r2g2_eps * v_std, W.T)
        wv_std_eps = (fwd_wv_std_eps - r2g2_wv_std_eps).detach() + r2g2_wv_std_eps

        wv_mean = F.linear(v_mean, W.T)
        return wv_mean + wv_std_eps


class x_to_embedding(torch.nn.Module):
    """Two-layer tanh MLP from a 28*28 input vector to a ``DIM_HIDDEN`` embedding."""

    def __init__(self):
        super().__init__()

        self.net = torch.nn.Sequential(
            CustomLinear(dim_in=28 * 28, dim_out=DIM_HIDDEN),
            torch.nn.Tanh(),
            CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_HIDDEN),
            torch.nn.Tanh(),
        )

        # Module.to moves the parameters in place.
        self.gpu_maybe = GPU_MAYBE
        self.net.to(self.gpu_maybe)

    def forward(self, x):
        """Embed ``x``, whose last axis must have size 28*28."""
        return self.net(x)


class embedding_to_x(torch.nn.Module):
    """Two-layer MLP from a ``DIM_HIDDEN`` embedding to 28*28 sigmoid outputs."""

    def __init__(self):
        super().__init__()

        self.net = torch.nn.Sequential(
            CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_HIDDEN),
            torch.nn.Tanh(),
            CustomLinear(dim_in=DIM_HIDDEN, dim_out=28 * 28),
            torch.nn.Sigmoid(),
        )

        # Module.to moves the parameters in place.
        self.gpu_maybe = GPU_MAYBE
        self.net.to(self.gpu_maybe)

    def forward(self, x):
        """Decode ``x`` (last axis of size ``DIM_HIDDEN``) to per-pixel sigmoid values."""
        return self.net(x)


class x_flattener(torch.nn.Module):
    """Flatten every axis after the batch axis (wraps ``torch.nn.Flatten``)."""

    def __init__(self):
        super().__init__()

        self.net = torch.nn.Sequential(torch.nn.Flatten())

        # Flatten has no parameters; the move is kept for symmetry with the
        # sibling modules.
        self.gpu_maybe = GPU_MAYBE
        self.net.to(self.gpu_maybe)

    def forward(self, x):
        """Return ``x`` reshaped to ``(batch, -1)``."""
        return self.net(x)


class embedding_to_embedding(torch.nn.Module):
    """Single tanh layer mapping a ``DIM_HIDDEN`` embedding to another of the same size."""

    def __init__(self):
        super().__init__()

        self.net = torch.nn.Sequential(
            CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_HIDDEN),
            torch.nn.Tanh(),
        )

        # Module.to moves the parameters in place.
        self.gpu_maybe = GPU_MAYBE
        self.net.to(self.gpu_maybe)

    def forward(self, x):
        """Apply ``tanh(x @ W)``."""
        return self.net(x)


class VAE(torch.nn.Module):
    """Single-latent-layer VAE on flattened 28x28 images.

    Inference: x -> mean and (softplus) variance of a diagonal Gaussian q(v|x).
    Generation: v ~ q(v|x) via ``Reparameteriser`` -> sigmoid pixel outputs
    from ``embedding_to_x``.

    Args:
        reparam: gradient estimator for the latent sample during training,
            forwarded to ``Reparameteriser`` (``"r2g2"`` or plain
            reparameterisation for any other value).
    """

    def __init__(self, reparam="r2g2"):
        super().__init__()

        self.gpu_maybe = GPU_MAYBE

        # x flattener
        self.x_flattener = x_flattener().to(self.gpu_maybe)

        # encoder + reparameterisation mappings
        self.embed_x_q_v = x_to_embedding().to(self.gpu_maybe)
        self.mean_q_v = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.var_q_v = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT, marker="top").to(self.gpu_maybe)
        self.reparam_q_v_embed = Reparameteriser(dim_latent=DIM_LATENT, dim_out=DIM_HIDDEN, reparam=reparam, use_tanh=True).to(self.gpu_maybe)

        # decoder mappings
        self.embedding_to_x = embedding_to_x().to(self.gpu_maybe)

    def encode(self, x):
        """Return ``(means, vars)`` of q(v|x) for flattened inputs ``x``.

        Variances are ``softplus`` of a linear output, hence non-negative.
        """
        embedding_x_q_v = self.embed_x_q_v(x)
        q_v_means = self.mean_q_v(embedding_x_q_v)
        q_v_vars = F.softplus(self.var_q_v(embedding_x_q_v))
        return q_v_means, q_v_vars

    def decode(self, q_v_means, q_v_vars):
        """Sample v ~ N(q_v_means, q_v_vars) and map it to pixel outputs."""
        eps_q_v = torch.empty(q_v_vars.size(), device=q_v_vars.device).normal_(0.0, 1.0)
        embedding_v = self.reparam_q_v_embed(eps_q_v, q_v_means, q_v_vars)
        return self.embedding_to_x(embedding_v)

    def _forward_impl(self, x):
        """Flatten, encode and decode ``x``; returns ``(x_hat, q_v_means, q_v_vars)``."""
        x = self.x_flattener(x)
        q_v_means, q_v_vars = self.encode(x)
        x_hat = self.decode(q_v_means, q_v_vars)
        return x_hat, q_v_means, q_v_vars

    def forward(self, x):
        """Run the model; outside training mode gradients are disabled.

        Returns:
            ``(x_hat, q_v_means, q_v_vars)``.
        """
        if self.training:
            return self._forward_impl(x)
        with torch.no_grad():
            return self._forward_impl(x)

    @staticmethod
    def _binarise(X):
        """Dynamically binarise ``X`` against U(0, 1) noise; float64, moved to CUDA if enabled."""
        # noise is drawn on X's device so already-on-GPU inputs also work
        X = (X > torch.rand(X.shape, device=X.device)).double()
        if CUDA:
            X = X.cuda()
        return X

    @staticmethod
    def _kl_normal(q_means, q_vars, p_means=None, p_vars=None):
        """Per-row KL(N(q_means, q_vars) || N(p_means, p_vars)), summed over dim 1.

        A missing prior mean / variance defaults to 0 / 1 (standard normal).
        """
        q_stddevs = q_vars.sqrt()
        p_means = torch.zeros_like(q_means) if p_means is None else p_means
        p_stddevs = torch.ones_like(q_stddevs) if p_vars is None else p_vars.sqrt()
        q = torch.distributions.Normal(q_means, q_stddevs)
        p = torch.distributions.Normal(p_means, p_stddevs)
        return torch.distributions.kl.kl_divergence(q, p).sum(dim=1)

    def train_one_step(self, X_train, nll_loss, opt):
        """Take one optimiser step on the negative ELBO of a batch.

        Args:
            X_train: image batch, assumed to hold values in [0, 1]; it is
                re-binarised on every call.
            nll_loss: reconstruction loss called as ``nll_loss(x_hat, x)`` on
                ``(batch, pixels)`` tensors. Presumably sum-reduced, since the
                result is divided by the batch size — confirm with the caller.
            opt: optimiser over this model's parameters.
        """
        self.train()

        X_train = self._binarise(X_train)
        batch_size = X_train.shape[0]

        opt.zero_grad()
        X_hat, q_v_means, q_v_vars = self.forward(X_train)
        recon_loss = nll_loss(X_hat.view(batch_size, -1), X_train.view(batch_size, -1)) / batch_size
        kl = self._kl_normal(q_v_means, q_v_vars).mean()

        vi_loss = recon_loss + kl
        vi_loss.backward()
        opt.step()

    def print_elbo(self, step, elbo_loader, nll_loss, elbo_out=None):
        """Compute a sampled negative log-likelihood bound and optionally log it.

        For every example x, Q_SAMPLES_VAE = K single-sample negative ELBOs
        L_k(x) (reconstruction NLL + analytic KL) are combined as
        ``-log((1/K) * sum_k exp(-L_k(x)))``, and these values are summed over
        every example yielded by ``elbo_loader``.

        Args:
            step: label written next to the value and shown in the progress bar.
            elbo_loader: iterable of ``(images, labels)`` with 4-D image batches.
            nll_loss: reconstruction loss returning a ``(rows, pixels)``
                tensor, i.e. without reduction.
            elbo_out: optional writable text stream that receives one
                ``step value`` line.
        """
        self.eval()
        log_k = float(np.log(Q_SAMPLES_VAE))

        # float64 accumulator: summing per-example bounds over a whole dataset
        # in float32 loses precision.
        vi_loss_total = torch.zeros((), dtype=torch.float64)
        if CUDA:
            vi_loss_total = vi_loss_total.cuda()

        for X_test, _ in tqdm(elbo_loader, desc=f"step {step} elbo"):
            X_test = self._binarise(X_test)
            batch_size, channels, height, width = X_test.shape
            n_rows = batch_size * Q_SAMPLES_VAE
            # repeat every image Q_SAMPLES_VAE times: (batch * samples, C, H, W)
            X_test = X_test.unsqueeze(1).expand(-1, Q_SAMPLES_VAE, -1, -1, -1)
            X_test = X_test.reshape(n_rows, channels, height, width)

            with torch.no_grad():
                X_hat, q_v_means, q_v_vars = self.forward(X_test)
                recon_loss = nll_loss(X_hat.view(n_rows, -1), X_test.view(n_rows, -1)).sum(dim=1)
                kl = self._kl_normal(q_v_means, q_v_vars)  # one value per row

                # neg ELBO per (example, sample)
                vi_loss = (recon_loss + kl).view(batch_size, Q_SAMPLES_VAE)
                # -log mean_k exp(-L_k), summed over the examples actually seen
                vi_loss_total += (log_k - torch.logsumexp(-vi_loss, dim=1)).sum()

        if elbo_out is not None:
            elbo_out.write('{} {} \n'.format(step, vi_loss_total.item()))
            elbo_out.flush()


class TwoLayerVAE(torch.nn.Module):
    """Two-latent-layer hierarchical VAE on flattened 28x28 images.

    Inference (bottom-up): x -> q(v1|x); (x, v1) -> q(v2|x, v1).
    Generation (top-down): v2 ~ q(v2|x, v1) -> p(v1|v2); v1 ~ p(v1|v2);
    (v2, v1) -> sigmoid pixel outputs.

    Only the generative reparameteriser for v1 uses ``reparam``; the other
    samples always use the plain reparameterisation ("rt").

    NOTE(review): the v1 fed to the decoder is drawn from p(v1|v2), not from
    q(v1|x) — confirm this is the intended estimator.

    Args:
        reparam: gradient estimator passed to the generative v1 reparameteriser.
    """

    def __init__(self, reparam="r2g2"):
        super().__init__()

        self.gpu_maybe = GPU_MAYBE

        # mappings
        # x flattener
        self.x_flattener = x_flattener().to(self.gpu_maybe)

        # bottom-up inference process
        # one shared projection of x, split into the q_v1 and q_v2 inputs
        self.embed_x = CustomLinear(dim_in=28*28, dim_out=2 * DIM_HIDDEN).to(self.gpu_maybe)

        # x->v1
        self.embed_embedding_q_v1 = embedding_to_embedding().to(self.gpu_maybe)
        self.mean_q_v1 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.var_q_v1 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.reparam_q_v1_embed = Reparameteriser(dim_latent=DIM_LATENT, dim_out=DIM_HIDDEN, reparam="rt", use_tanh=False).to(self.gpu_maybe)

        # x,v1->v2
        # NOTE(review): embed_v1_q_v2 is not used by encode/decode (the v1
        # embedding comes from reparam_q_v1_embed); it is kept so the
        # parameter layout and initialisation order stay unchanged.
        self.embed_v1_q_v2 = CustomLinear(dim_in=DIM_LATENT, dim_out=DIM_HIDDEN).to(self.gpu_maybe)
        self.embed_embedding_q_v2 = embedding_to_embedding().to(self.gpu_maybe)
        self.mean_q_v2 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.var_q_v2 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.reparam_q_v2_embed = Reparameteriser(dim_latent=DIM_LATENT, dim_out=2 * DIM_HIDDEN, reparam="rt", use_tanh=False).to(self.gpu_maybe)

        # top-down generative process
        # v2->v1
        self.embed_embedding_p_v1 = embedding_to_embedding().to(self.gpu_maybe)
        self.mean_p_v1 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.var_p_v1 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT, marker="top").to(self.gpu_maybe)
        self.reparam_p_v1_embed = Reparameteriser(dim_latent=DIM_LATENT, dim_out=DIM_HIDDEN, reparam=reparam, use_tanh=False).to(self.gpu_maybe)

        # v2,v1->x
        self.embedding_to_x = embedding_to_x().to(self.gpu_maybe)

    def encode(self, x):
        """Return ``(q_v1_means, q_v1_vars, q_v2_means, q_v2_vars)`` for flattened ``x``."""
        # embed x for q_v1 and q_v2
        embedding_x = self.embed_x(x)
        embedding_x_q_v1 = embedding_x[:, :DIM_HIDDEN]
        embedding_x_q_v2 = embedding_x[:, DIM_HIDDEN:]

        # x->v1 params
        embeddings_q_v1 = self.embed_embedding_q_v1(torch.tanh(embedding_x_q_v1))
        q_v1_means = self.mean_q_v1(embeddings_q_v1)
        q_v1_vars = F.softplus(self.var_q_v1(embeddings_q_v1))

        # sample v1 and embed for q_v2
        eps_q_v1 = torch.empty(q_v1_vars.size(), device=q_v1_vars.device).normal_(0.0, 1.0)
        embedding_v1_q_v2 = self.reparam_q_v1_embed(eps_q_v1, q_v1_means, q_v1_vars)

        # x,v1->v2 params
        embedding_xv1_q_v2 = embedding_x_q_v2 + embedding_v1_q_v2
        embeddings_q_v2 = self.embed_embedding_q_v2(torch.tanh(embedding_xv1_q_v2))
        q_v2_means = self.mean_q_v2(embeddings_q_v2)
        q_v2_vars = F.softplus(self.var_q_v2(embeddings_q_v2))

        return q_v1_means, q_v1_vars, q_v2_means, q_v2_vars

    def decode(self, q_v2_means, q_v2_vars):
        """Sample v2, then v1 ~ p(v1|v2), and decode; returns ``(x_hat, p_v1_means, p_v1_vars)``."""
        # sample v2 and embed for p_v1 and p_x
        eps_q_v2 = torch.empty(q_v2_vars.size(), device=q_v2_vars.device).normal_(0.0, 1.0)
        embedding_v2 = self.reparam_q_v2_embed(eps_q_v2, q_v2_means, q_v2_vars)
        embedding_v2_p_v1 = embedding_v2[:, :DIM_HIDDEN]
        embedding_v2_p_x = embedding_v2[:, DIM_HIDDEN:]

        # v2->v1 params
        embeddings_p_v1 = self.embed_embedding_p_v1(torch.tanh(embedding_v2_p_v1))
        p_v1_means = self.mean_p_v1(embeddings_p_v1)
        p_v1_vars = F.softplus(self.var_p_v1(embeddings_p_v1))

        # sample v1 from p(v1|v2) and embed for p_x
        eps_p_v1 = torch.empty(p_v1_vars.size(), device=p_v1_vars.device).normal_(0.0, 1.0)
        embedding_v1_p_x = self.reparam_p_v1_embed(eps_p_v1, p_v1_means, p_v1_vars)

        # v2,v1->x params
        embedding_v2v1_p_x = embedding_v2_p_x + embedding_v1_p_x
        x_hat = self.embedding_to_x(torch.tanh(embedding_v2v1_p_x))

        return x_hat, p_v1_means, p_v1_vars

    def _forward_impl(self, x):
        """Flatten, encode and decode ``x``; returns the 7-tuple documented on ``forward``."""
        x = self.x_flattener(x)
        q_v1_means, q_v1_vars, q_v2_means, q_v2_vars = self.encode(x)
        x_hat, p_v1_means, p_v1_vars = self.decode(q_v2_means, q_v2_vars)
        return x_hat, q_v1_means, q_v1_vars, q_v2_means, q_v2_vars, p_v1_means, p_v1_vars

    def forward(self, x):
        """Run the model; outside training mode gradients are disabled.

        Returns:
            ``(x_hat, q_v1_means, q_v1_vars, q_v2_means, q_v2_vars, p_v1_means, p_v1_vars)``.
        """
        if self.training:
            return self._forward_impl(x)
        with torch.no_grad():
            return self._forward_impl(x)

    @staticmethod
    def _binarise(X):
        """Dynamically binarise ``X`` against U(0, 1) noise; float64, moved to CUDA if enabled."""
        # noise is drawn on X's device so already-on-GPU inputs also work
        X = (X > torch.rand(X.shape, device=X.device)).double()
        if CUDA:
            X = X.cuda()
        return X

    @staticmethod
    def _kl_normal(q_means, q_vars, p_means=None, p_vars=None):
        """Per-row KL(N(q_means, q_vars) || N(p_means, p_vars)), summed over dim 1.

        A missing prior mean / variance defaults to 0 / 1 (standard normal).
        """
        q_stddevs = q_vars.sqrt()
        p_means = torch.zeros_like(q_means) if p_means is None else p_means
        p_stddevs = torch.ones_like(q_stddevs) if p_vars is None else p_vars.sqrt()
        q = torch.distributions.Normal(q_means, q_stddevs)
        p = torch.distributions.Normal(p_means, p_stddevs)
        return torch.distributions.kl.kl_divergence(q, p).sum(dim=1)

    def train_one_step(self, X_train, nll_loss, opt):
        """Take one optimiser step on the negative ELBO of a batch.

        Args:
            X_train: image batch, assumed to hold values in [0, 1]; it is
                re-binarised on every call.
            nll_loss: reconstruction loss called as ``nll_loss(x_hat, x)`` on
                ``(batch, pixels)`` tensors. Presumably sum-reduced, since the
                result is divided by the batch size — confirm with the caller.
            opt: optimiser over this model's parameters.
        """
        self.train()

        X_train = self._binarise(X_train)
        batch_size = X_train.shape[0]

        opt.zero_grad()
        X_hat, q_v1_means, q_v1_vars, q_v2_means, q_v2_vars, p_v1_means, p_v1_vars = self.forward(X_train)
        recon_loss = nll_loss(X_hat.view(batch_size, -1), X_train.view(batch_size, -1)) / batch_size

        # KL(q(v1|x) || p(v1|v2)) and KL(q(v2|x, v1) || N(0, I)), batch-averaged
        kl_v1 = self._kl_normal(q_v1_means, q_v1_vars, p_v1_means, p_v1_vars).mean()
        kl_v2 = self._kl_normal(q_v2_means, q_v2_vars).mean()

        vi_loss = recon_loss + kl_v1 + kl_v2
        vi_loss.backward()
        opt.step()

    def print_elbo(self, step, elbo_loader, nll_loss, elbo_out=None):
        """Compute a sampled negative log-likelihood bound and optionally log it.

        For every example x, Q_SAMPLES_VAE = K single-sample negative ELBOs
        L_k(x) (reconstruction NLL + analytic KL terms) are combined as
        ``-log((1/K) * sum_k exp(-L_k(x)))``, and these values are summed over
        every example yielded by ``elbo_loader``.

        Args:
            step: label written next to the value and shown in the progress bar.
            elbo_loader: iterable of ``(images, labels)`` with 4-D image batches.
            nll_loss: reconstruction loss returning a ``(rows, pixels)``
                tensor, i.e. without reduction.
            elbo_out: optional writable text stream that receives one
                ``step value`` line.
        """
        self.eval()
        log_k = float(np.log(Q_SAMPLES_VAE))

        # float64 accumulator: summing per-example bounds over a whole dataset
        # in float32 loses precision.
        vi_loss_total = torch.zeros((), dtype=torch.float64)
        if CUDA:
            vi_loss_total = vi_loss_total.cuda()

        for X_test, _ in tqdm(elbo_loader, desc=f"step {step} elbo"):
            X_test = self._binarise(X_test)
            batch_size, channels, height, width = X_test.shape
            n_rows = batch_size * Q_SAMPLES_VAE
            # repeat every image Q_SAMPLES_VAE times: (batch * samples, C, H, W)
            X_test = X_test.unsqueeze(1).expand(-1, Q_SAMPLES_VAE, -1, -1, -1)
            X_test = X_test.reshape(n_rows, channels, height, width)

            with torch.no_grad():
                X_hat, q_v1_means, q_v1_vars, q_v2_means, q_v2_vars, p_v1_means, p_v1_vars = self.forward(X_test)
                recon_loss = nll_loss(X_hat.view(n_rows, -1), X_test.view(n_rows, -1)).sum(dim=1)

                # per-row KL terms
                kl_v1 = self._kl_normal(q_v1_means, q_v1_vars, p_v1_means, p_v1_vars)
                kl_v2 = self._kl_normal(q_v2_means, q_v2_vars)

                # neg ELBO per (example, sample)
                vi_loss = (recon_loss + kl_v1 + kl_v2).view(batch_size, Q_SAMPLES_VAE)
                # -log mean_k exp(-L_k), summed over the examples actually seen
                vi_loss_total += (log_k - torch.logsumexp(-vi_loss, dim=1)).sum()

        if elbo_out is not None:
            elbo_out.write('{} {} \n'.format(step, vi_loss_total.item()))
            elbo_out.flush()


class ThreeLayerVAE(torch.nn.Module):
    """Three-latent-layer hierarchical VAE on flattened 28x28 images.

    Inference (bottom-up): x -> q(v1|x); (x, v1) -> q(v2|x, v1);
    (x, v2) -> q(v3|x, v2).
    Generation (top-down): v3 ~ q(v3|x, v2) -> p(v2|v3); v2 ~ p(v2|v3) ->
    p(v1|v2); v1 ~ p(v1|v2); (v3, v2, v1) -> sigmoid pixel outputs.

    Only the generative reparameterisers (for v2 and v1) use ``reparam``; the
    other samples always use the plain reparameterisation ("rt").

    NOTE(review): the v2 / v1 fed to the decoder are drawn from the generative
    conditionals p(v2|v3) / p(v1|v2), not from q(.|x) — confirm this is the
    intended estimator.

    Args:
        reparam: gradient estimator passed to the generative reparameterisers.
    """

    def __init__(self, reparam="r2g2"):
        super().__init__()

        self.gpu_maybe = GPU_MAYBE

        # mappings
        # x flattener
        self.x_flattener = x_flattener().to(self.gpu_maybe)

        # bottom-up inference process
        # one shared projection of x, split into the q_v1, q_v2 and q_v3 inputs
        self.embed_x = CustomLinear(dim_in=28*28, dim_out=3 * DIM_HIDDEN).to(self.gpu_maybe)

        # x->v1
        self.embed_embedding_q_v1 = embedding_to_embedding().to(self.gpu_maybe)
        self.mean_q_v1 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.var_q_v1 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.reparam_q_v1_embed = Reparameteriser(dim_latent=DIM_LATENT, dim_out=DIM_HIDDEN, reparam="rt", use_tanh=False).to(self.gpu_maybe)

        # x,v1->v2
        self.embed_embedding_q_v2 = embedding_to_embedding().to(self.gpu_maybe)
        self.mean_q_v2 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.var_q_v2 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.reparam_q_v2_embed = Reparameteriser(dim_latent=DIM_LATENT, dim_out=DIM_HIDDEN, reparam="rt", use_tanh=False).to(self.gpu_maybe)

        # x,v2->v3
        self.embed_embedding_q_v3 = embedding_to_embedding().to(self.gpu_maybe)
        self.mean_q_v3 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.var_q_v3 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.reparam_q_v3_embed = Reparameteriser(dim_latent=DIM_LATENT, dim_out=2 * DIM_HIDDEN, reparam="rt", use_tanh=False).to(self.gpu_maybe)

        # top-down generative process
        # v3->v2
        self.embed_embedding_p_v2 = embedding_to_embedding().to(self.gpu_maybe)
        self.mean_p_v2 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.var_p_v2 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT, marker="top").to(self.gpu_maybe)
        self.reparam_p_v2_embed_p_v1 = Reparameteriser(dim_latent=DIM_LATENT, dim_out=DIM_HIDDEN, reparam=reparam, use_tanh=False).to(self.gpu_maybe)
        self.reparam_p_v2_embed_p_x = Reparameteriser(dim_latent=DIM_LATENT, dim_out=DIM_HIDDEN, reparam=reparam, use_tanh=False).to(self.gpu_maybe)

        # v2->v1
        self.embed_embedding_p_v1 = embedding_to_embedding().to(self.gpu_maybe)
        self.mean_p_v1 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT).to(self.gpu_maybe)
        self.var_p_v1 = CustomLinear(dim_in=DIM_HIDDEN, dim_out=DIM_LATENT, marker="top").to(self.gpu_maybe)
        self.reparam_p_v1_embed = Reparameteriser(dim_latent=DIM_LATENT, dim_out=DIM_HIDDEN, reparam=reparam, use_tanh=False).to(self.gpu_maybe)

        # v3,v2,v1->x
        self.embedding_to_x = embedding_to_x().to(self.gpu_maybe)

    def encode(self, x):
        """Return means and variances of q(v1|x), q(v2|x, v1), q(v3|x, v2) for flattened ``x``."""
        # embed x for q_v1, q_v2 and q_v3
        embedding_x = self.embed_x(x)
        embedding_x_q_v1 = embedding_x[:, :DIM_HIDDEN]
        embedding_x_q_v2 = embedding_x[:, DIM_HIDDEN:2 * DIM_HIDDEN]
        embedding_x_q_v3 = embedding_x[:, 2 * DIM_HIDDEN:]

        # x->v1 params
        embeddings_q_v1 = self.embed_embedding_q_v1(torch.tanh(embedding_x_q_v1))
        q_v1_means = self.mean_q_v1(embeddings_q_v1)
        q_v1_vars = F.softplus(self.var_q_v1(embeddings_q_v1))

        # sample v1 and embed for q_v2
        eps_q_v1 = torch.empty(q_v1_vars.size(), device=q_v1_vars.device).normal_(0.0, 1.0)
        embedding_v1_q_v2 = self.reparam_q_v1_embed(eps_q_v1, q_v1_means, q_v1_vars)

        # x,v1->v2 params
        embedding_xv1_q_v2 = embedding_x_q_v2 + embedding_v1_q_v2
        embeddings_q_v2 = self.embed_embedding_q_v2(torch.tanh(embedding_xv1_q_v2))
        q_v2_means = self.mean_q_v2(embeddings_q_v2)
        q_v2_vars = F.softplus(self.var_q_v2(embeddings_q_v2))

        # sample v2 and embed for q_v3
        eps_q_v2 = torch.empty(q_v2_vars.size(), device=q_v2_vars.device).normal_(0.0, 1.0)
        embedding_v2_q_v3 = self.reparam_q_v2_embed(eps_q_v2, q_v2_means, q_v2_vars)

        # x,v2->v3 params
        embedding_xv2_q_v3 = embedding_x_q_v3 + embedding_v2_q_v3
        embeddings_q_v3 = self.embed_embedding_q_v3(torch.tanh(embedding_xv2_q_v3))
        q_v3_means = self.mean_q_v3(embeddings_q_v3)
        q_v3_vars = F.softplus(self.var_q_v3(embeddings_q_v3))

        return q_v1_means, q_v1_vars, q_v2_means, q_v2_vars, q_v3_means, q_v3_vars

    def decode(self, q_v3_means, q_v3_vars):
        """Sample v3, v2 ~ p(v2|v3), v1 ~ p(v1|v2) and decode.

        Returns:
            ``(x_hat, p_v1_means, p_v1_vars, p_v2_means, p_v2_vars)``.
        """
        # sample v3 and embed for p_v2 and p_x
        eps_q_v3 = torch.empty(q_v3_vars.size(), device=q_v3_vars.device).normal_(0.0, 1.0)
        embedding_v3 = self.reparam_q_v3_embed(eps_q_v3, q_v3_means, q_v3_vars)
        embedding_v3_p_v2 = embedding_v3[:, :DIM_HIDDEN]
        embedding_v3_p_x = embedding_v3[:, DIM_HIDDEN:]

        # v3->v2 params
        embeddings_p_v2 = self.embed_embedding_p_v2(torch.tanh(embedding_v3_p_v2))
        p_v2_means = self.mean_p_v2(embeddings_p_v2)
        p_v2_vars = F.softplus(self.var_p_v2(embeddings_p_v2))

        # sample v2 once (shared eps) and embed it separately for p_v1 and p_x
        eps_p_v2 = torch.empty(p_v2_vars.size(), device=p_v2_vars.device).normal_(0.0, 1.0)
        embedding_v2_p_v1 = self.reparam_p_v2_embed_p_v1(eps_p_v2, p_v2_means, p_v2_vars)
        embedding_v2_p_x = self.reparam_p_v2_embed_p_x(eps_p_v2, p_v2_means, p_v2_vars)

        # v2->v1 params
        embeddings_p_v1 = self.embed_embedding_p_v1(torch.tanh(embedding_v2_p_v1))
        p_v1_means = self.mean_p_v1(embeddings_p_v1)
        p_v1_vars = F.softplus(self.var_p_v1(embeddings_p_v1))

        # sample v1 and embed for p_x
        eps_p_v1 = torch.empty(p_v1_vars.size(), device=p_v1_vars.device).normal_(0.0, 1.0)
        embedding_v1_p_x = self.reparam_p_v1_embed(eps_p_v1, p_v1_means, p_v1_vars)

        # v3,v2,v1->x params
        embedding_v3v2v1_p_x = embedding_v3_p_x + embedding_v2_p_x + embedding_v1_p_x
        x_hat = self.embedding_to_x(torch.tanh(embedding_v3v2v1_p_x))

        return x_hat, p_v1_means, p_v1_vars, p_v2_means, p_v2_vars

    def _forward_impl(self, x):
        """Flatten, encode and decode ``x``; returns the 11-tuple documented on ``forward``."""
        x = self.x_flattener(x)
        q_v1_means, q_v1_vars, q_v2_means, q_v2_vars, q_v3_means, q_v3_vars = self.encode(x)
        x_hat, p_v1_means, p_v1_vars, p_v2_means, p_v2_vars = self.decode(q_v3_means, q_v3_vars)
        return x_hat, q_v1_means, q_v1_vars, q_v2_means, q_v2_vars, q_v3_means, q_v3_vars, p_v1_means, p_v1_vars, p_v2_means, p_v2_vars

    def forward(self, x):
        """Run the model; outside training mode gradients are disabled.

        Returns:
            ``(x_hat, q_v1_means, q_v1_vars, q_v2_means, q_v2_vars, q_v3_means,
            q_v3_vars, p_v1_means, p_v1_vars, p_v2_means, p_v2_vars)``.
        """
        if self.training:
            return self._forward_impl(x)
        with torch.no_grad():
            return self._forward_impl(x)

    @staticmethod
    def _binarise(X):
        """Dynamically binarise ``X`` against U(0, 1) noise; float64, moved to CUDA if enabled."""
        # noise is drawn on X's device so already-on-GPU inputs also work
        X = (X > torch.rand(X.shape, device=X.device)).double()
        if CUDA:
            X = X.cuda()
        return X

    @staticmethod
    def _kl_normal(q_means, q_vars, p_means=None, p_vars=None):
        """Per-row KL(N(q_means, q_vars) || N(p_means, p_vars)), summed over dim 1.

        A missing prior mean / variance defaults to 0 / 1 (standard normal).
        """
        q_stddevs = q_vars.sqrt()
        p_means = torch.zeros_like(q_means) if p_means is None else p_means
        p_stddevs = torch.ones_like(q_stddevs) if p_vars is None else p_vars.sqrt()
        q = torch.distributions.Normal(q_means, q_stddevs)
        p = torch.distributions.Normal(p_means, p_stddevs)
        return torch.distributions.kl.kl_divergence(q, p).sum(dim=1)

    def train_one_step(self, X_train, nll_loss, opt):
        """Take one optimiser step on the negative ELBO of a batch.

        Args:
            X_train: image batch, assumed to hold values in [0, 1]; it is
                re-binarised on every call.
            nll_loss: reconstruction loss called as ``nll_loss(x_hat, x)`` on
                ``(batch, pixels)`` tensors. Presumably sum-reduced, since the
                result is divided by the batch size — confirm with the caller.
            opt: optimiser over this model's parameters.
        """
        self.train()

        X_train = self._binarise(X_train)
        batch_size = X_train.shape[0]

        opt.zero_grad()
        X_hat, q_v1_means, q_v1_vars, q_v2_means, q_v2_vars, q_v3_means, q_v3_vars, p_v1_means, p_v1_vars, p_v2_means, p_v2_vars = self.forward(X_train)
        recon_loss = nll_loss(X_hat.view(batch_size, -1), X_train.view(batch_size, -1)) / batch_size

        # batch-averaged KL terms
        kl_v1 = self._kl_normal(q_v1_means, q_v1_vars, p_v1_means, p_v1_vars).mean()
        kl_v2 = self._kl_normal(q_v2_means, q_v2_vars, p_v2_means, p_v2_vars).mean()
        kl_v3 = self._kl_normal(q_v3_means, q_v3_vars).mean()

        vi_loss = recon_loss + kl_v1 + kl_v2 + kl_v3
        vi_loss.backward()
        opt.step()

    def print_elbo(self, step, elbo_loader, nll_loss, elbo_out=None):
        """Compute a sampled negative log-likelihood bound and optionally log it.

        For every example x, Q_SAMPLES_VAE = K single-sample negative ELBOs
        L_k(x) (reconstruction NLL + analytic KL terms) are combined as
        ``-log((1/K) * sum_k exp(-L_k(x)))``, and these values are summed over
        every example yielded by ``elbo_loader``.

        Args:
            step: label written next to the value and shown in the progress bar.
            elbo_loader: iterable of ``(images, labels)`` with 4-D image batches.
            nll_loss: reconstruction loss returning a ``(rows, pixels)``
                tensor, i.e. without reduction.
            elbo_out: optional writable text stream that receives one
                ``step value`` line.
        """
        self.eval()
        log_k = float(np.log(Q_SAMPLES_VAE))

        # float64 accumulator: summing per-example bounds over a whole dataset
        # in float32 loses precision.
        vi_loss_total = torch.zeros((), dtype=torch.float64)
        if CUDA:
            vi_loss_total = vi_loss_total.cuda()

        for X_test, _ in tqdm(elbo_loader, desc=f"step {step} elbo"):
            X_test = self._binarise(X_test)
            batch_size, channels, height, width = X_test.shape
            n_rows = batch_size * Q_SAMPLES_VAE
            # repeat every image Q_SAMPLES_VAE times: (batch * samples, C, H, W)
            X_test = X_test.unsqueeze(1).expand(-1, Q_SAMPLES_VAE, -1, -1, -1)
            X_test = X_test.reshape(n_rows, channels, height, width)

            with torch.no_grad():
                X_hat, q_v1_means, q_v1_vars, q_v2_means, q_v2_vars, q_v3_means, q_v3_vars, p_v1_means, p_v1_vars, p_v2_means, p_v2_vars = self.forward(X_test)
                recon_loss = nll_loss(X_hat.view(n_rows, -1), X_test.view(n_rows, -1)).sum(dim=1)

                # Per-row KL terms. They must NOT be averaged here: every
                # (example, sample) row needs its own KL for the log-sum-exp.
                kl_v1 = self._kl_normal(q_v1_means, q_v1_vars, p_v1_means, p_v1_vars)
                kl_v2 = self._kl_normal(q_v2_means, q_v2_vars, p_v2_means, p_v2_vars)
                kl_v3 = self._kl_normal(q_v3_means, q_v3_vars)

                # neg ELBO per (example, sample)
                vi_loss = (recon_loss + kl_v1 + kl_v2 + kl_v3).view(batch_size, Q_SAMPLES_VAE)
                # -log mean_k exp(-L_k), summed over the examples actually seen
                vi_loss_total += (log_k - torch.logsumexp(-vi_loss, dim=1)).sum()

        if elbo_out is not None:
            elbo_out.write('{} {} \n'.format(step, vi_loss_total.item()))
            elbo_out.flush()
