import math
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as F
import pdb


def to_sigma(logvar):
    """Convert a log-variance into a standard deviation: exp(0.5 * logvar)."""
    half_logvar = 0.5 * logvar
    return torch.exp(half_logvar)

def biased_sigmoid(x, simgoid_bias=4.0, eps=1e-5):
    """
    Sigmoid with an input offset and an additive output floor.

    Adapted from https://github.com/applied-ai-lab/genesis/blob/82eb91de18b4eb50b32a3cf99c09dee0803e327c/modules/blocks.py#L28

    Computes ``sigmoid(x + simgoid_bias) + eps``, so the output lies in
    [eps, 1 + eps]. With the default bias of 4.0 an input of zero maps close
    to 1.0, and the default eps keeps the output away from exactly zero.

    Note: the keyword really is spelled ``simgoid_bias``; callers pass it by
    that name, so the spelling is part of the interface.
    """
    shifted = x + simgoid_bias
    squashed = torch.sigmoid(shifted)
    return squashed + eps

def shifted_softplus(x):
    """Return ``softplus(x + 0.5) + 1e-8``; the additive constant keeps the result above zero."""
    activated = F.softplus(x + 0.5)
    return activated + 1e-8

class PriorSigma(nn.Module):
    """Module wrapper around ``biased_sigmoid`` with a fixed bias and epsilon."""

    def __init__(self, sigmoid_bias=4.0, eps=1e-5):
        super(PriorSigma, self).__init__()
        # Plain Python attributes: this module owns no learnable parameters.
        self.sb, self.eps = sigmoid_bias, eps

    def forward(self, x):
        """Apply ``biased_sigmoid`` to ``x`` using the stored bias and epsilon."""
        settings = {'simgoid_bias': self.sb, 'eps': self.eps}
        return biased_sigmoid(x, **settings)

class PriorSoftplus(nn.Module):
    """Parameter-free module that applies ``shifted_softplus`` to its input."""

    # No custom __init__ is needed: nn.Module's own constructor is sufficient.

    def forward(self, x):
        """Return ``shifted_softplus(x)``."""
        positive = shifted_softplus(x)
        return positive

def layernorm(x):
    """
    Normalise ``x`` to zero mean and (approximately) unit standard deviation
    over its trailing feature axes.

    (function adapted from: https://github.com/MichaelKevinKelly/IODINE)

    :param x: (B, K, L) or (B, K, C, H, W)
    :return: tensor with the same shape as ``x``. Statistics are taken over
        dim 2 for 3-D input and jointly over dims 2, 3 and 4 for 5-D input.
    :raises ValueError: if ``x`` is neither 3-D nor 5-D.

    NOTE: the two branches use different estimators. The 3-D branch calls
    ``Tensor.std`` with its default (Bessel-corrected) behaviour, while the
    5-D branch computes the population standard deviation. This difference is
    inherited from the original code and is kept so numerical results do not
    change.
    """
    ndim = x.dim()
    if ndim == 3:
        layer_mean = x.mean(dim=2, keepdim=True)
        layer_std = x.std(dim=2, keepdim=True)
    elif ndim == 5:
        reduce_dims = (2, 3, 4)
        layer_mean = x.mean(dim=reduce_dims, keepdim=True)
        # Population std, written out by hand via the mean of squared deviations.
        sq_dev = torch.pow(x - layer_mean, 2)
        layer_std = torch.sqrt(sq_dev.mean(dim=reduce_dims, keepdim=True))
    else:
        # An explicit raise (rather than `assert False`) still fires under
        # `python -O`, where assert statements are removed.
        raise ValueError('invalid size for layernorm')

    # The small constant guards against division by zero for constant inputs.
    return (x - layer_mean) / (layer_std + 1e-5)


def kl_exponential(post_mu, post_sigma, z_samples=None, pri_mu=None, pri_sigma=None):
    """
    Sample-based estimate of KL(posterior || prior); Gaussians only for now.

    Evaluates ``log q(z) - log p(z)`` element-wise at ``z``, where
    ``q = Normal(post_mu, post_sigma)`` and ``p = Normal(pri_mu, pri_sigma)``.
    This is a Monte Carlo estimate at the supplied (or freshly drawn) sample,
    not the closed-form Gaussian KL.

    :param post_mu: posterior mean.
    :param post_sigma: posterior standard deviation.
    :param z_samples: points at which to evaluate the densities; when ``None``
        a reparameterised sample is drawn from the posterior.
    :param pri_mu: prior mean; defaults to zeros shaped like ``post_mu``.
    :param pri_sigma: prior std; defaults to ones shaped like ``post_sigma``.
    :return: tensor of per-element log-density differences.
    """
    # The default prior is a constant. It is created without requires_grad so
    # autograd does not track (and accumulate gradients on) throwaway leaves.
    # zeros_like / ones_like already inherit device and dtype.
    if pri_mu is None:
        pri_mu = torch.zeros_like(post_mu)
    if pri_sigma is None:
        pri_sigma = torch.ones_like(post_sigma)

    p_post = dist.Normal(post_mu, post_sigma)
    p_pri = dist.Normal(pri_mu, pri_sigma)
    if z_samples is None:
        z_samples = p_post.rsample()
    return p_post.log_prob(z_samples) - p_pri.log_prob(z_samples)


def Gaussian_ll(x_col, _x, masks, std):
    """
    Mixture log-likelihood with one fixed-std Gaussian per slot.

    x_col: [B,K,C,H,W]  per-slot Gaussian means
    _x:    [B,K,3,H,W]  values scored under each slot's Gaussian
    masks: [B,K,1,H,W]  per-slot weights; 1e-6 is added before taking the log
    std:   scalar standard deviation shared by all K slots

    Returns ``(ll_pix, log_pxz)``: ``log_pxz`` is the per-slot log-density and
    ``ll_pix`` is the log-sum-exp of ``log(masks + 1e-6) + log_pxz`` over the
    slot axis (dim 1, kept with size 1).
    """
    # Unpacking five sizes also enforces that x_col is 5-D.
    _, num_slots, _, _, _ = x_col.size()
    per_slot_std = torch.tensor([std] * num_slots, device=x_col.device,
                                dtype=x_col.dtype, requires_grad=False)
    per_slot_std = per_slot_std.view(1, num_slots, 1, 1, 1)  # broadcasts over B, C, H, W

    likelihood = dist.Normal(x_col, per_slot_std)
    log_pxz = likelihood.log_prob(_x)
    log_weights = torch.log(masks + 1e-6)
    ll_pix = torch.logsumexp(log_weights + log_pxz, dim=1, keepdim=True)  # [B,1,C,H,W]
    # Fails on -inf and on NaN, since NaN compares False against everything.
    assert ll_pix.min().item() > -math.inf
    return ll_pix, log_pxz

def init_weights(net, init_type='normal', init_gain=0.02):
    """
    Adapted from EfficientMORL (https://github.com/pemami4911/EfficientMORL/blob/main/lib/model.py)

    Initialize network weights in place.
    Modified from: https://github.com/baudm/MONet-pytorch/blob/master/models/networks.py

    Modules are selected by class name. Any module that has a ``weight``
    attribute and whose class name contains 'Conv' or 'Linear' gets the chosen
    initialisation and a zeroed bias. Modules whose class name contains
    'BatchNorm2d' get weight ~ N(1, init_gain) and a zeroed bias.

    Parameters:
        net (network)   -- network to be initialized (visited with ``net.apply``)
        init_type (str) -- the name of an initialization method:
                           normal | xavier | kaiming | orthogonal |
                           truncated_normal | zeros
        init_gain (float)    -- scaling factor for normal, xavier, orthogonal
                           and truncated_normal (unused by kaiming and zeros).
    Raises:
        NotImplementedError -- if ``init_type`` is not one of the names above
                           and the network contains a Conv/Linear module.
    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        is_conv_or_linear = 'Conv' in classname or 'Linear' in classname
        if hasattr(m, 'weight') and is_conv_or_linear:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            elif init_type == 'truncated_normal':
                # NOTE(review): truncated_normal_initializer is not defined in
                # the part of the file visible here -- confirm it is defined or
                # imported elsewhere, otherwise this branch raises NameError.
                m.weight.data = truncated_normal_initializer(m.weight.shape, 0.0, stddev=init_gain)
            elif init_type == 'zeros':
                torch.nn.init.constant_(m.weight.data, 0.0)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm2d(affine=False) stores weight and bias as None; skip
            # them instead of crashing on `None.data`.
            if getattr(m, 'weight', None) is not None:
                torch.nn.init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_func)

class GatedConv2d(nn.Module):
    """
    Gated 2-D convolution.

    A single convolution produces ``2 * output_channels`` maps that are split
    in half along the channel axis: the first half goes through an optional
    norm and ELU, the second half through an optional norm and a sigmoid, and
    the two halves are multiplied element-wise.

    ``h_norm`` / ``g_norm`` select the normalisation of the feature / gate
    half: ``'in'`` (affine InstanceNorm2d), ``'bn'`` (BatchNorm2d), or
    ``None`` / ``'none'`` for no normalisation. Any other value raises
    ``ValueError``.
    """

    def __init__(self, input_channels, output_channels, kernel_size, stride,
                 padding, dilation=1,
                 h_norm=None, g_norm=None):
        super().__init__()
        self.activation = nn.ELU()
        self.sigmoid = nn.Sigmoid()
        # One conv yields both halves; they are separated in forward().
        self.conv = nn.Conv2d(input_channels, 2*output_channels, kernel_size,
                              stride, padding, dilation)
        # Feature norm is built before gate norm, matching submodule order.
        self.h_norm = self._build_norm(h_norm, output_channels)
        self.g_norm = self._build_norm(g_norm, output_channels)

    @staticmethod
    def _build_norm(option, channels):
        """Return the normalisation layer named by ``option``, or None."""
        if option is None or option == 'none':
            return None
        if option == 'in':
            return nn.InstanceNorm2d(channels, affine=True, eps=1e-5)
        if option == 'bn':
            return nn.BatchNorm2d(channels, eps=1e-5)
        raise ValueError("Normalisation option not recognised.")

    def forward(self, x):
        """Return ``activation(norm(h)) * sigmoid(norm(g))`` for the two conv halves."""
        features, gates = self.conv(x).chunk(2, dim=1)
        # Feature path
        if self.h_norm is not None:
            features = self.h_norm(features)
        if self.activation is not None:
            features = self.activation(features)
        # Gate path
        if self.g_norm is not None:
            gates = self.g_norm(gates)
        return features * self.sigmoid(gates)

class Encoder(nn.Module):
    """
    Four stride-2, kernel-3, unpadded conv + ReLU stages (32, 32, 64, 64
    channels) followed by a two-layer MLP on the flattened feature map.

    ``image_size`` is indexed as ``(height, width)`` and is used to size the
    first linear layer.
    """

    def __init__(self, input_dim, output_dim, image_size):
        super().__init__()
        height, width = image_size[0], image_size[1]

        channel_plan = (input_dim, 32, 32, 64, 64)
        stages = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(nn.Conv2d(c_in, c_out, 3, stride=2))
            stages.append(nn.ReLU(inplace=True))
            # A kernel-3, stride-2, unpadded conv maps n -> (n - 1) // 2.
            height, width = (height - 1) // 2, (width - 1) // 2
        self.convs = nn.Sequential(*stages)

        flat_dim = 64 * width * height
        self.mlp = nn.Sequential(
            nn.Linear(flat_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, output_dim)
        )

    def forward(self, x):
        """Run the conv stack, flatten per sample, and apply the MLP."""
        feats = self.convs(x)
        flat = feats.view(feats.shape[0], -1)
        return self.mlp(flat)

class GlobalEncoder(nn.Module):
    """
    Encode several views of a scene into one diagonal Gaussian over a global
    latent.

    Each view image goes through a shared conv stack, is concatenated with an
    embedding of its viewpoint vector, and is mapped to a per-view
    ``(mu, logvar)``. The per-view Gaussians are then fused with a product of
    experts (see ``_poe``).

    Args:
        input_dim: number of image channels.
        z_dim: dimension of the global latent.
        v_dim: dimension of the viewpoint embedding.
        image_size: indexed as ``(height, width)``; sizes the first linear layer.
        nf: base channel count; the conv stack uses nf, nf, 2*nf, 2*nf channels.
        linear_dim: hidden width of the output MLP.
        vdim_input: dimension of the raw viewpoint vector.
    """

    def __init__(self, input_dim, z_dim, v_dim, image_size, nf=32, linear_dim=256, vdim_input=3):
        super().__init__()
        height = image_size[0]
        width = image_size[1]
        self.z_dim = z_dim    # dimension of global latent
        self.v_dim = v_dim

        # The stack is built from `nf` so that it agrees with `last_nf` below.
        # With the default nf=32 this is exactly the previous 32/32/64/64
        # layout; other values of nf used to produce a shape mismatch in
        # forward().
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, nf, 3, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(nf, nf, 3, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(nf, nf*2, 3, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(nf*2, nf*2, 3, stride=2),
            nn.ReLU(inplace=True)
        )
        last_nf = nf*2
        # Each kernel-3, stride-2, unpadded conv maps n -> (n - 1) // 2.
        for _ in range(4):
            width = (width - 1) // 2
            height = (height - 1) // 2

        self.view_enc = nn.Sequential(
            nn.Linear(vdim_input, 128),
            nn.ReLU(True),
            nn.Linear(128, v_dim)
        )
        self.mlp = nn.Sequential(
            nn.Linear(last_nf * width * height + v_dim, linear_dim),
            nn.SiLU(inplace=True),
            nn.Linear(linear_dim, z_dim*2)
        )

    def _poe(self, mu, logvar, prior_exp=True, eps=1e-8):
        """
        Product of Gaussian experts along dim 1.

        adapted from MVAE implementation
        https://github.com/mhw32/multimodal-vae-public

        input:
          mu, logvar -> (B, V, z_dim)
          prior_exp  -> when True, one N(0, I) prior expert is included
          eps        -> added to each variance before inversion
        output:
          mu, logvar -> (B, z_dim)
        """
        if prior_exp:
            # Exactly ONE prior expert per batch element. Using
            # zeros_like(mu) here would add V copies of the prior and
            # over-weight it.
            mu_0 = torch.zeros_like(mu[:, :1])
            logvar_0 = torch.zeros_like(logvar[:, :1])
            mu = torch.cat([mu_0, mu], dim=1)
            logvar = torch.cat([logvar_0, logvar], dim=1)

        var = torch.exp(logvar)
        T = 1 / (var + eps)  # precision of i-th Gaussian expert at point x
        precision_sum = torch.sum(T, dim=1)
        pd_mu = torch.sum(mu * T, dim=1) / precision_sum
        pd_var = 1 / precision_sum
        return pd_mu, torch.log(pd_var)

    def forward(self, x, v):
        """
        input
            x: (B, V, C, h, w)
            v: (B, V, view_dim)   view_dim must equal ``vdim_input``
        output
            (mu, logvar) of the fused Gaussian, each (B, z_dim)
        raises
            ValueError if the last dimension of ``v`` does not match the
            input width of the view encoder.
        """
        B, V, C, height, width = x.shape
        # Check against the configured view dimension, not a hard-coded 3/7.
        if v.shape[2] != self.view_enc[0].in_features:
            raise ValueError("wrong view dim in global enc")

        # reshape (not view) so non-contiguous inputs are accepted as well.
        feats = self.encoder(x.reshape(B*V, C, height, width))
        _v = self.view_enc(v)
        params = self.mlp(torch.cat([feats.reshape(B, V, -1), _v], dim=2))    # (B, V, z_dim*2)
        mu, logvar = params.chunk(2, dim=2)    # mu, logvar: (B, V, z_dim)
        mu_g, logvar_g = self._poe(mu, logvar)
        return mu_g, logvar_g


class SpatialBroadcastDec(nn.Module):
    """
    Variational Autoencoder with spatial broadcast decoder.

    The latent vector is tiled over a grid that is ``2 * num_layers`` pixels
    larger than the target image in each dimension, two coordinate channels in
    [-1, 1] are appended, and the result goes through ``num_layers`` unpadded
    3x3 convolutions (each trims one pixel per border) plus a final 1x1
    convolution, so the output is ``(height, width)``.

    Args:
        input_dim: dimension of the latent vector.
        output_dim: number of output channels.
        image_size: indexed as ``(height, width)`` of the decoded image.
        num_layers: number of 3x3 conv layers.
        decoder: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, input_dim, output_dim, image_size, num_layers=4, decoder='sbd'):
        super(SpatialBroadcastDec, self).__init__()
        self.height = image_size[0]
        self.width = image_size[1]

        self.n_layers = num_layers
        convs = [nn.Conv2d(input_dim+2, 32, 3), nn.ReLU(inplace=True)]
        for _ in range(self.n_layers-1):
            convs.append(nn.Conv2d(32, 32, 3))
            convs.append(nn.ReLU(inplace=True))
        convs.append(nn.Conv2d(32, output_dim, 1))
        self.convs = nn.Sequential(*convs)

        # Coordinate grid in "ij" (row, column) layout, built by explicit
        # broadcasting. Calling torch.meshgrid without `indexing` warns on
        # recent torch releases and its default is slated to change.
        grid_h = self.height + 2*self.n_layers
        grid_w = self.width + 2*self.n_layers
        ys = torch.linspace(-1, 1, grid_h).view(grid_h, 1).expand(grid_h, grid_w)
        xs = torch.linspace(-1, 1, grid_w).view(1, grid_w).expand(grid_h, grid_w)
        coord_map = torch.stack((ys, xs)).unsqueeze(0)    # (1, 2, grid_h, grid_w)
        self.register_buffer('coord_map_const', coord_map)

    def forward(self, z):
        """
        Decode latents into feature maps.

        :param z: (B, input_dim)
        :return: (B, output_dim, height, width)
        """
        grid_h = self.height + 2*self.n_layers
        grid_w = self.width + 2*self.n_layers
        # expand() creates views; torch.cat below materialises the input once,
        # so no intermediate full-size copies are needed.
        z_tiled = z.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid_h, grid_w)
        coord_map = self.coord_map_const.expand(z.shape[0], -1, -1, -1)
        inp = torch.cat((z_tiled, coord_map), 1)
        return self.convs(inp)


class RefineNetLSTM(nn.Module):
    """
    function Phi (see Sec 3.3 of the paper)
    (adapted from: https://github.com/MichaelKevinKelly/IODINE)

    A conv ``Encoder`` turns the image-shaped inputs into a 128-d code; the
    code is concatenated with a ``4 * z_dim`` state vector and fed to an
    LSTM cell, whose hidden state is projected to ``2 * z_dim`` outputs.
    """

    def __init__(self, z_dim, channels_in, image_size):
        super().__init__()
        code_dim = 128    # width of the conv code and of the LSTM hidden state
        self.convnet = Encoder(channels_in, code_dim, image_size)
        self.lstm = nn.LSTMCell(code_dim + 4 * z_dim, code_dim, bias=True)
        self.fc_out = nn.Linear(code_dim, 2 * z_dim)

    def forward(self, x, h, c):
        """
        :param x: dict with keys 'img' (input to the conv encoder) and
            'state' (vector concatenated in front of the conv code)
        :param h, c: LSTM cell hidden and cell state
        :return: (projected output, new h, new c)
        """
        images = x['img']
        state_vec = x['state']
        codes = self.convnet(images)
        h, c = self.lstm(torch.cat((state_vec, codes), dim=1), (h, c))
        out = self.fc_out(h)
        return out, h, c

class LightRefineNet(nn.Module):
    """
    referenced EfficientMORL (https://github.com/pemami4911/EfficientMORL/blob/main/lib/model.py)

    This refine network takes no image-sized inputs such as a pixel-wise
    likelihood; it only sees the current parameters, their layer-normalised
    gradient, and a global feature.

    lambda' = lambda + refine_net( lambda, d loss/d lambda , loss)

    loss: (,)  (aggregated over all view-points)
    lambda: (N*K, 2*z_dim)    (one scene has one lambda, shared across all views)
    """

    def __init__(self, z_dim, z_g_dim, K, h_dim=128):
        super().__init__()
        self.K = K
        lambda_dim = z_dim*2
        # MLP input = global feature (2*z_g_dim) + lambda (2*z_dim) + two
        # normalised gradient halves (z_dim each).
        mlp_in_dim = z_dim*4 + z_g_dim*2
        self.mlp = nn.Sequential(
            nn.Linear(mlp_in_dim, h_dim),
            nn.ELU(inplace=True),
            nn.Linear(h_dim, h_dim)
        )
        self.rnn = nn.LSTM(h_dim, lambda_dim)  # EfficientMORL uses GRU, and hidden dim is "z_dim"

        self.loc = nn.Linear(lambda_dim, z_dim)
        self.logvar = nn.Linear(lambda_dim, z_dim)

        # Same initialisation order as before: loc, logvar, then mlp.
        for module in (self.loc, self.logvar, self.mlp):
            init_weights(module, 'xavier')

        self.loc_LN = nn.LayerNorm((z_dim,), elementwise_affine=False)
        self.logvar_LN = nn.LayerNorm((z_dim,), elementwise_affine=False)

    def forward(self, loss, lambd, global_ft, state):
        """
        :param loss: scalar that ``lambd`` is differentiated against
        :param lambd: (N*K, 2*z_dim) current parameters; must be part of the
            graph that produced ``loss``
        :param global_ft: (N, D) per-scene feature, repeated for each of K slots
        :param state: recurrent state passed to / returned from the LSTM
        :return: (concatenated [loc, logvar] of shape (N*K, 2*z_dim), new state)
        """
        num_scenes = lambd.shape[0] // self.K
        grads = torch.autograd.grad(loss, lambd, retain_graph=True, only_inputs=True)
        grad_loc, grad_logvar = grads[0].chunk(2, dim=1)    # each (N*K, z_dim)
        # Gradients are normalised and detached: they act as plain inputs.
        grad_loc = self.loc_LN(grad_loc.contiguous()).detach()
        grad_logvar = self.logvar_LN(grad_logvar.contiguous()).detach()

        feat_dim = global_ft.shape[-1]
        tiled_global = global_ft.unsqueeze(1).expand(-1, self.K, -1).contiguous()
        tiled_global = tiled_global.view(num_scenes*self.K, feat_dim).detach()

        mlp_in = torch.cat([tiled_global, lambd, grad_loc, grad_logvar], dim=1)
        ft = self.mlp(mlp_in).unsqueeze(0)    # (1, N*K, h_dim)

        self.rnn.flatten_parameters()
        rnn_out, state = self.rnn(ft, state)
        rnn_out = rnn_out.squeeze(0)

        refined = torch.cat([self.loc(rnn_out), self.logvar(rnn_out)], dim=1)
        return refined, state

class SequentialEncoder(nn.Module):
    """
    This is a model for
        q(z_k | z_g)

    It maps one global vector per scene to K per-slot parameter vectors.

    Supported ``type`` values (matched by substring, in this order):
    * "mlp"  -- one MLP emits all K slots at once, then a shared linear layer
    * "lstm" -- autoregressive: slot k+1 is produced from [z_g, slot k]
    * "iid"  -- the same MLP output is used for every slot

    A transformer variant is not implemented. An unrecognised ``type`` is
    accepted by the constructor (no sub-modules are created) but calling
    ``forward`` then raises ``ValueError``.
    """

    def __init__(self, K, z_dim_g, z_dim, h_dim=256, type="mlp"):
        super().__init__()

        if "mlp" in type:
            self.model = nn.Sequential(
                nn.Linear(z_dim_g, h_dim), nn.ELU(),
                nn.Linear(h_dim, z_dim*2*K), nn.ELU()
            )
            self.linear = nn.Linear(z_dim*2, z_dim*2)

        elif "lstm" in type:
            # input at iteration k: [z_g, z_k] -> output z_{k+1}
            self.model = nn.LSTM(z_dim_g+z_dim*2, h_dim, batch_first=True)
            self.linear = nn.Linear(h_dim, z_dim*2)

        elif "iid" in type:
            self.model = nn.Sequential(
                nn.Linear(z_dim_g, h_dim), nn.ELU(),
                nn.Linear(h_dim, z_dim*2), nn.ELU()
            )
            self.loc = nn.Linear(z_dim*2, z_dim)
            self.logvar = nn.Linear(z_dim*2, z_dim)

        self.tp = type
        self.K = K
        self.z_dim_g = z_dim_g
        self.z_dim = z_dim

    def forward(self, lmbda):
        """
        input
        :lmbda (B, z_dim_g):

        output
        :lmbda' (B, K, z_dim*2)

        raises
        :ValueError: if the configured type is none of mlp / lstm / iid
        """
        B, _ = lmbda.shape
        if "mlp" in self.tp:
            h = self.model(lmbda).view(B, self.K, self.z_dim*2)
            return self.linear(h)    # lambda: (B, K, z_dim*2)

        if "lstm" in self.tp:
            # The first step sees zeros in place of the previous slot.
            # new_zeros keeps the dtype and device of the input.
            prev_slot = lmbda.new_zeros(B, self.z_dim*2)
            h = torch.cat([lmbda, prev_slot], dim=1).unsqueeze(1)    # h: (B, 1, z_dim_g+z_dim*2)
            state = None
            params = []
            for _ in range(self.K):
                h, state = self.model(h, state)
                h = self.linear(h)    # (B, 1, z_dim*2)
                params.append(h)
                h = torch.cat([lmbda.unsqueeze(1), h], dim=2)
            return torch.cat(params, dim=1)    # (B, K, z_dim*2)

        if "iid" in self.tp:
            lmbda = lmbda.unsqueeze(1).expand((B, self.K, self.z_dim_g))
            h = self.model(lmbda)
            loc = self.loc(h)
            logvar = self.logvar(h)
            return torch.cat([loc, logvar], dim=2)    # lambda: (B, K, z_dim*2)

        # Previously this fell through and silently returned None.
        raise ValueError("unsupported SequentialEncoder type: %r" % (self.tp,))

class ComponentPrior(nn.Module):
    """
    Prior over per-slot latents, p(z_k | z_g), returned as a ``Normal``.

    The K per-slot parameter vectors of one scene are processed jointly,
    either by an MLP over the flattened ``K * 2 * z_dim`` vector
    (``type`` containing "mlp") or by a transformer encoder over the K slots
    (``type`` containing "transformer"). The result is mapped to a mean by a
    linear head and to a scale by a linear head followed by ``PriorSoftplus``.

    Args:
        K: number of slots per scene.
        z_dim: per-slot latent dimension.
        hdim: MLP hidden width; also the transformer model width when ``cfg``
            is None.
        cfg: optional dict with keys "nhead", "dmodel", "fw_dim", "n_layers"
            for the transformer variant.
        type: "mlp" or "transformer" (matched by substring).
        detach: when True the input is detached, so no gradient flows from
            the prior into whatever produced it.
    """

    def __init__(self, K, z_dim, hdim=128, cfg=None, type="mlp", detach=True):
        super().__init__()
        self.K = K
        self.z_dim = z_dim
        self.tp = type
        self.det = detach
        if "mlp" in type:
            # Operates on all K slots of a scene at once (flattened).
            self.model = nn.Sequential(
                nn.Linear(K*z_dim*2, hdim),
                nn.ELU(),
                nn.Linear(hdim, K*z_dim*2)
            )
        elif "transformer" in type:
            if cfg is not None:
                nhead = cfg["nhead"]
                d_model = cfg["dmodel"]
                fw_dim = cfg["fw_dim"]
                n_layers = cfg["n_layers"]
            else:
                nhead = 4
                d_model = hdim
                fw_dim = 1024
                n_layers = 4

            # Created on the default device like every other parameter; it
            # follows the module through .to() / .cuda(). Hard-coding
            # "cuda:0" here failed on hosts without a GPU.
            self.pos_enc = nn.Parameter(torch.zeros(K, 1, d_model))
            self.emb_layer = nn.Linear(z_dim*2, d_model)
            self.dec_layer = nn.Linear(d_model, z_dim*2)
            encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=fw_dim, dropout=0.0, activation="gelu")
            self.model = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # no restriction for output parameters
        self.loc = nn.Sequential(
            nn.Linear(z_dim*2, z_dim)
        )
        self.scale = nn.Sequential(
            nn.Linear(z_dim*2, z_dim), PriorSoftplus()
        )

    def forward(self, lmbda):
        """
        input:
            lmbda(= seq_enc(z_g)) (B*K, z_dim*2)

        output:
            p(z_k | z_g) as ``Normal(loc, scale)`` with parameters of shape
            (B*K, z_dim)

        raises:
            ValueError if the configured type is neither mlp nor transformer
        """
        B = lmbda.shape[0] // self.K
        slot_dim = self.z_dim*2
        if self.det:
            lmbda = lmbda.detach()

        if "mlp" in self.tp:
            # consider dependencies between slots: flatten the K slots of each
            # scene so the input width matches Linear(K*z_dim*2, hdim).
            joint = self.model(lmbda.reshape(B, self.K*slot_dim))
            h = joint.view(B*self.K, slot_dim)
        elif "transformer" in self.tp:
            # Sequence-first layout (K, B, d_model); pos_enc broadcasts over B.
            slots = lmbda.reshape(B, self.K, slot_dim).permute(1, 0, 2).contiguous()
            h = self.emb_layer(slots) + self.pos_enc
            h = self.model(h)
            h = self.dec_layer(h).permute(1, 0, 2).contiguous().view(B*self.K, slot_dim)
        else:
            # Previously this path crashed later with an unbound-local error.
            raise ValueError("unsupported ComponentPrior type: %r" % (self.tp,))

        loc = self.loc(h)
        scale = self.scale(h)
        return dist.Normal(loc, scale)