import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Tuple, Optional, Union, Dict

from model.vae_models.vae import Decoder as MLPDecoder


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor, eps: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Return ``mu + eps * sigma`` with ``sigma = exp(0.5 * logvar)``.

    Args:
        mu: Mean of the diagonal Gaussian.
        logvar: Log-variance of the diagonal Gaussian.
        eps: Noise to scale by sigma. When omitted, standard-normal noise with
            the shape, dtype and device of sigma is drawn.

    Returns:
        The reparameterized sample.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma) if eps is None else eps
    return mu + noise * sigma


def _clamp_logvar(logvar: torch.Tensor, min_val: float = -10.0, max_val: float = 10.0) -> torch.Tensor:
    """Bound log-variance values to the closed interval ``[min_val, max_val]``.

    Args:
        logvar: Tensor of log-variances.
        min_val: Lower bound applied element-wise.
        max_val: Upper bound applied element-wise.

    Returns:
        A tensor of the same shape with every element clamped into the interval.
    """
    return logvar.clamp(min=min_val, max=max_val)


def gaussian_kl(q_mu: torch.Tensor, q_logvar: torch.Tensor, p_mu: torch.Tensor, p_logvar: torch.Tensor) -> torch.Tensor:
    """KL(q || p) between two diagonal Gaussians.

    Both log-variances are clamped with ``_clamp_logvar`` before use. The
    per-dimension KL is summed over the last dimension and the result is
    averaged over all remaining dimensions.

    Args:
        q_mu: Mean of q.
        q_logvar: Log-variance of q.
        p_mu: Mean of p.
        p_logvar: Log-variance of p.

    Returns:
        A scalar tensor.
    """
    lv_q = _clamp_logvar(q_logvar)
    lv_p = _clamp_logvar(p_logvar)

    ratio = torch.exp(lv_q - lv_p)      # var_q / var_p
    delta = q_mu - p_mu
    precision_p = torch.exp(-lv_p)      # 1 / var_p

    per_dim = 0.5 * (lv_p - lv_q + ratio + delta * delta * precision_p - 1.0)
    return per_dim.sum(dim=-1).mean()


def combine_gaussians(mu_u: torch.Tensor, logvar_u: torch.Tensor,
                      mu_p: torch.Tensor, logvar_p: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precision-weighted product of two diagonal Gaussians (ladder-style merge):
      q(z) ∝ N(z; mu_u, var_u) * N(z; mu_p, var_p)

    Both log-variances are clamped with ``_clamp_logvar`` first. The combined
    precision is the sum of the two precisions and the combined mean is the
    precision-weighted average of the two means.

    Returns:
        ``(mu, logvar)`` of the resulting Gaussian.
    """
    lv_u = _clamp_logvar(logvar_u)
    lv_p = _clamp_logvar(logvar_p)

    # precision = 1 / var = exp(-logvar)
    precision_u, precision_p = torch.exp(-lv_u), torch.exp(-lv_p)
    combined_var = 1.0 / (precision_u + precision_p)
    combined_mu = (mu_u * precision_u + mu_p * precision_p) * combined_var
    return combined_mu, torch.log(combined_var)


class HierarchicalMLPVAE(nn.Module):
    """
    Ladder-style hierarchical VAE (MLP) with L latent layers.
    Generative: p(z_L)=N(0,I); p(z_l | z_{l+1}); p(x|z_1)
    Inference: q(z_L|x) (combined with prior) and q(z_l|x,z_{l+1}) via precision-weighted combination.

    All per-layer lists are indexed bottom-up: index 0 is z_1 (the latent fed to
    the decoder) and index ``num_layers - 1`` is the top latent z_L.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        latent_dims: List[int],
        *,
        encoder_hidden_dims: Optional[List[int]] = None,
        decoder_hidden_dims: Optional[List[int]] = None,
    ) -> None:
        """
        Args:
            input_dim: Size of the input vector consumed by the encoder trunk.
            hidden_dims: Widths used for the encoder trunk and passed to the
                decoder, unless overridden by the keyword arguments below.
            latent_dims: Width of each latent layer, ordered z_1 (bottom) to z_L (top).
            encoder_hidden_dims: Optional widths for the encoder trunk only.
            decoder_hidden_dims: Optional widths passed to the decoder only.

        Raises:
            ValueError: If ``latent_dims`` is empty.
        """
        super().__init__()
        self.input_dim = int(input_dim)
        self.latent_dims = [int(d) for d in latent_dims]
        if not self.latent_dims:
            # Fail early with a clear message instead of an IndexError further down.
            raise ValueError("latent_dims must contain at least one latent dimension")
        self.num_layers = len(self.latent_dims)

        enc_dims = encoder_hidden_dims if encoder_hidden_dims is not None else hidden_dims
        dec_dims = decoder_hidden_dims if decoder_hidden_dims is not None else hidden_dims

        # Encoder trunk: x -> feat (Linear -> ReLU -> BatchNorm1d per hidden width)
        enc_layers: List[nn.Module] = []
        prev = self.input_dim
        for h in enc_dims:
            enc_layers.append(nn.Linear(prev, h))
            enc_layers.append(nn.ReLU())
            enc_layers.append(nn.BatchNorm1d(h))
            prev = h
        self.encoder_trunk = nn.Sequential(*enc_layers)
        enc_feat_dim = prev

        # Upward heads for each latent layer: feat -> (mu_u, logvar_u)
        self.u_mu_heads = nn.ModuleList([
            nn.Linear(enc_feat_dim, ld) for ld in self.latent_dims
        ])
        self.u_logvar_heads = nn.ModuleList([
            nn.Linear(enc_feat_dim, ld) for ld in self.latent_dims
        ])

        # Top-down prior conditionals p(z_l | z_{l+1}) for l=1..L-1.
        # Entry l maps a sample of layer l+1 to the parameters of layer l.
        self.p_mu_layers = nn.ModuleList()
        self.p_logvar_layers = nn.ModuleList()
        for l in range(self.num_layers - 1):
            parent_dim = self.latent_dims[l + 1]
            child_dim = self.latent_dims[l]
            self.p_mu_layers.append(self._make_prior_mlp(parent_dim, child_dim))
            self.p_logvar_layers.append(self._make_prior_mlp(parent_dim, child_dim))

        # Decoder p(x|z1)
        self.decoder = MLPDecoder(self.latent_dims[0], dec_dims, self.input_dim)

    @staticmethod
    def _make_prior_mlp(parent_dim: int, child_dim: int) -> nn.Sequential:
        """Build the 2-layer MLP that maps a parent latent to one prior parameter of its child."""
        width = max(parent_dim, child_dim)
        return nn.Sequential(
            nn.Linear(parent_dim, width),
            nn.ReLU(),
            nn.Linear(width, child_dim),
        )

    # ---------- Helper: upward parameters ----------
    def _compute_upward_params(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]:
        """Run the encoder trunk and every upward head.

        Returns:
            ``(mu_u, logvar_u, feat)`` where the two lists hold one tensor per
            latent layer (log-variances already clamped) and ``feat`` is the
            trunk output.
        """
        feat = self.encoder_trunk(x)
        mu_u = [head(feat) for head in self.u_mu_heads]
        logvar_u = [
            _clamp_logvar(head(feat)) for head in self.u_logvar_heads
        ]
        return mu_u, logvar_u, feat

    # ---------- Helper: shared top-down sampling pass ----------
    def _top_down_pass(
        self,
        mu_u: List[torch.Tensor],
        logvar_u: List[torch.Tensor],
        num_samples: Optional[int] = None,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """Sample every latent from top to bottom, keeping the prior parameters used.

        The prior networks are evaluated exactly once per layer; callers reuse
        the returned prior parameters for KL / log-density terms instead of
        re-running the networks on the same parent samples.

        Args:
            mu_u: Upward means, one ``[B, D_l]`` tensor per layer.
            logvar_u: Upward log-variances, one ``[B, D_l]`` tensor per layer.
            num_samples: When given, every tensor gets a sample axis and has
                shape ``[B, num_samples, D_l]``; otherwise shapes stay ``[B, D_l]``.

        Returns:
            ``(q_mus, q_logvars, zs, p_mus, p_logvars)``, each a list with one
            entry per layer. Prior log-variances are returned unclamped; the
            functions that consume them clamp internally.
        """
        def per_sample(t: torch.Tensor) -> torch.Tensor:
            # [B, D] -> [B, K, D] as a broadcast view; identity without a sample axis.
            if num_samples is None:
                return t
            return t.unsqueeze(1).expand(-1, num_samples, -1)

        n = self.num_layers
        top = n - 1
        q_mus: List[torch.Tensor] = [None] * n  # type: ignore
        q_logvars: List[torch.Tensor] = [None] * n  # type: ignore
        zs: List[torch.Tensor] = [None] * n  # type: ignore
        p_mus: List[torch.Tensor] = [None] * n  # type: ignore
        p_logvars: List[torch.Tensor] = [None] * n  # type: ignore

        # Prior at top: N(0, I)
        p_mus[top] = per_sample(torch.zeros_like(mu_u[top]))
        p_logvars[top] = per_sample(torch.zeros_like(logvar_u[top]))

        for l in range(top, -1, -1):
            if l < top:
                # Conditional prior given the sampled parent; flatten any sample
                # axis so the MLP always sees a 2-D batch.
                z_parent = zs[l + 1]
                flat_parent = z_parent.reshape(-1, z_parent.size(-1))
                lead_shape = z_parent.shape[:-1]
                p_mus[l] = self.p_mu_layers[l](flat_parent).view(*lead_shape, -1)
                p_logvars[l] = self.p_logvar_layers[l](flat_parent).view(*lead_shape, -1)
            q_mus[l], q_logvars[l] = combine_gaussians(
                per_sample(mu_u[l]), per_sample(logvar_u[l]), p_mus[l], p_logvars[l]
            )
            zs[l] = reparameterize(q_mus[l], q_logvars[l])

        return q_mus, q_logvars, zs, p_mus, p_logvars

    # ---------- Forward for convenience ----------
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Reconstruct ``x`` from a sampled z1.

        Returns:
            ``(x_recon, q_mu_1, q_logvar_1)`` — the reconstruction and the
            posterior parameters of the bottom latent.
        """
        q_mus, q_logvars, _ = self._infer_posterior(x)
        # A fresh z1 is drawn here (rather than reusing the one sampled during
        # inference) so the random-number stream matches earlier behaviour.
        z1 = reparameterize(q_mus[0], q_logvars[0])
        x_recon = self.decoder(z1)
        return x_recon, q_mus[0], q_logvars[0]

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return approximate posterior parameters for bottom latent z1."""
        q_mus, q_logvars, _ = self._infer_posterior(x)
        return q_mus[0], q_logvars[0]

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Run the decoder on a bottom-layer latent ``z``."""
        return self.decoder(z)

    def sample(self, num_samples: int, device: torch.device) -> torch.Tensor:
        """Ancestral sampling: z_L ~ N(0,I), then z_{l} ~ p(z_l|z_{l+1}), finally decode z_1.

        The conditional-prior log-variance is clamped exactly as it is inside
        ``combine_gaussians`` / ``gaussian_kl`` during training, so samples come
        from the same generative model that was optimized and ``exp`` cannot
        overflow on extreme network outputs.
        """
        top = self.num_layers - 1
        z = torch.randn(num_samples, self.latent_dims[top], device=device)
        # Downward through conditional priors
        for l in range(top - 1, -1, -1):
            p_mu_l = self.p_mu_layers[l](z)
            p_logvar_l = _clamp_logvar(self.p_logvar_layers[l](z))
            z = reparameterize(p_mu_l, p_logvar_l)
        return self.decode(z)

    # ---------- Posterior inference (returns lists per layer, 0 is z1) ----------
    def _infer_posterior(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """Infer q for every layer and draw one sample per layer.

        Returns:
            ``(q_mus, q_logvars, z_samples)``, each a list indexed by layer.
        """
        mu_u, logvar_u, _ = self._compute_upward_params(x)
        q_mus, q_logvars, z_samples, _, _ = self._top_down_pass(mu_u, logvar_u)
        return q_mus, q_logvars, z_samples

    # ---------- ELBO ----------
    def compute_elbo_loss(self, x: torch.Tensor, beta: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Single-sample negative ELBO.

        The reconstruction term is the element-wise mean binary cross-entropy
        between the decoder output and ``x``; the KL term is the sum over layers
        of ``gaussian_kl`` (summed over latent dims, averaged over the batch).

        Args:
            x: Input batch; also used as the BCE target.
            beta: Weight applied to the total KL.

        Returns:
            ``(total_loss, recon_loss, total_kl)`` as scalar tensors.
        """
        mu_u, logvar_u, _ = self._compute_upward_params(x)
        q_mus, q_logvars, z_samples, p_mus, p_logvars = self._top_down_pass(mu_u, logvar_u)

        # Recon from z1
        x_recon = self.decoder(z_samples[0])
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='mean')

        # KL summed across layers, top first. Gradients flow through the parent
        # samples into the upper latents (nothing is detached).
        total_kl = 0.0
        for l in range(self.num_layers - 1, -1, -1):
            total_kl = total_kl + gaussian_kl(q_mus[l], q_logvars[l], p_mus[l], p_logvars[l])

        total_loss = recon_loss + beta * total_kl
        return total_loss, recon_loss, torch.as_tensor(total_kl, device=x.device)

    # ---------- IWAE ----------
    def _log_normal_diag(self, z: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Return log N(z; mu, exp(logvar)) summed over the last dimension (logvar is clamped)."""
        logvar = _clamp_logvar(logvar)
        diff = z - mu
        return -0.5 * (
            logvar + diff * diff * torch.exp(-logvar) + math.log(2.0 * math.pi)
        ).sum(dim=-1)

    def compute_iwae_loss(
        self,
        x: torch.Tensor,
        k: Union[int, List[int]] = 5,
        *,
        mode: str = 'joint',
        max_joint_k: int = 50,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the IWAE bound for the hierarchical model using K joint latent trajectories.

        How K is chosen:
        - ``k`` is an int: K = k.
        - ``k`` is a list and ``mode == 'shared'``: K = max(k).
        - ``k`` is a list otherwise: K = product of the entries, falling back to
          max(k) when the product exceeds ``max_joint_k``.
        In every case the K trajectories are sampled jointly top-down; the list
        form only determines the value of K.

        Returns:
            Dict with ``'loss'`` (negative IWAE bound averaged over the batch),
            ``'recon_loss'`` (importance-weighted BCE per input dimension) and
            ``'kl_loss'`` (sum over layers of the analytic KL between q and p at
            the sampled parents; a diagnostic, not part of ``'loss'``).

        Raises:
            ValueError: If the resolved K is smaller than 1.
        """
        if isinstance(k, list):
            if mode == 'shared':
                K = int(max(k))
            else:
                K = 1
                for v in k:
                    K *= int(v)
                if K > max_joint_k:
                    K = int(max(k))
        else:
            K = int(k)
        if K < 1:
            # K <= 0 would otherwise yield an empty logsumexp and log(0): NaN / -inf losses.
            raise ValueError(f"number of importance samples must be >= 1, got {K}")

        batch_size = x.size(0)

        # Upward params, then one top-down pass with a sample axis: tensors are [B, K, D_l]
        mu_u, logvar_u, _ = self._compute_upward_params(x)
        q_mus, q_logvars, zs, p_mus, p_logvars = self._top_down_pass(mu_u, logvar_u, num_samples=K)
        layers_top_down = range(self.num_layers - 1, -1, -1)

        # Decode from z1
        z1_flat = zs[0].reshape(batch_size * K, -1)
        x_recon_flat = torch.clamp(self.decoder(z1_flat), 1e-6, 1.0 - 1e-6)
        # Flatten the target (no-op for 2-D input) and repeat it along the sample axis.
        x_flat = x.reshape(batch_size, -1)
        x_rep = x_flat.unsqueeze(1).expand(-1, K, -1).reshape_as(x_recon_flat)
        bce = F.binary_cross_entropy(x_recon_flat, x_rep, reduction='none').view(batch_size, K, -1).sum(dim=2)
        log_px_z = -bce  # [B, K]

        # Log priors and posteriors accumulated over layers, top first
        log_p = torch.zeros_like(log_px_z)
        log_q = torch.zeros_like(log_px_z)
        for l in layers_top_down:
            log_p = log_p + self._log_normal_diag(zs[l], p_mus[l], p_logvars[l])
            log_q = log_q + self._log_normal_diag(zs[l], q_mus[l], q_logvars[l])

        log_w = log_px_z + log_p - log_q  # [B, K]
        iwae_bound = torch.logsumexp(log_w, dim=1) - math.log(K)  # [B]
        loss = -iwae_bound.mean()

        # Diagnostics: reconstruction weighted by the normalized importance weights
        w_tilde = torch.softmax(log_w, dim=1).detach()
        recon_loss = (w_tilde * (bce / x_flat.size(1))).sum(dim=1).mean()

        # Analytic KL proxy between q and p at the sampled parents
        kl_proxy = 0.0
        for l in layers_top_down:
            kl_proxy = kl_proxy + gaussian_kl(q_mus[l].reshape(batch_size * K, -1),
                                              q_logvars[l].reshape(batch_size * K, -1),
                                              p_mus[l].reshape(batch_size * K, -1),
                                              p_logvars[l].reshape(batch_size * K, -1))

        return {
            'loss': loss,
            'recon_loss': recon_loss,
            'kl_loss': torch.as_tensor(kl_proxy, device=x.device),
        }


class HierarchicalConvVAE(nn.Module):
    """
    Ladder-style hierarchical VAE (CNN encoder trunk + MLP conditional priors, Conv decoder for p(x|z1)).

    All per-layer lists are indexed bottom-up: index 0 is z_1 (the latent fed to
    the decoder) and index ``num_layers - 1`` is the top latent z_L.
    """

    def __init__(
        self,
        in_channels: int,
        encoder_channels: List[int],
        decoder_channels: List[int],
        latent_dims: List[int],
    ) -> None:
        """
        Args:
            in_channels: Number of input image channels.
            encoder_channels: Output channels of each encoder conv layer. The
                first two layers use stride 2, the remaining ones stride 1.
            decoder_channels: Channel list passed to ``ConvDecoder``.
            latent_dims: Width of each latent layer, ordered z_1 (bottom) to z_L (top).

        Raises:
            ValueError: If ``latent_dims`` is empty.
        """
        super().__init__()
        self.in_channels = int(in_channels)
        self.latent_dims = [int(d) for d in latent_dims]
        if not self.latent_dims:
            # Fail early with a clear message instead of an IndexError further down.
            raise ValueError("latent_dims must contain at least one latent dimension")
        self.num_layers = len(self.latent_dims)

        # CNN encoder trunk -> flattened feature.
        # NOTE(review): the spatial bookkeeping below assumes 28x28 inputs; other
        # sizes would make ``flat_dim`` disagree with the actual conv output.
        layers: List[nn.Module] = []
        prev_c = self.in_channels
        spatial = 28
        downsamples = 0
        for out_c in encoder_channels:
            if downsamples < 2:
                layers.append(nn.Conv2d(prev_c, out_c, kernel_size=4, stride=2, padding=1))
                spatial = spatial // 2
                downsamples += 1
            else:
                layers.append(nn.Conv2d(prev_c, out_c, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.BatchNorm2d(out_c))
            prev_c = out_c
        self.enc_cnn = nn.Sequential(*layers)
        self.enc_feat_channels = prev_c
        self.enc_feat_size = spatial
        flat_dim = self.enc_feat_channels * self.enc_feat_size * self.enc_feat_size
        feat_dim = max(256, self.latent_dims[0] * 2)
        self.enc_fc = nn.Sequential(
            nn.Linear(flat_dim, feat_dim),
            nn.ReLU(),
        )

        # Upward heads per latent
        self.u_mu_heads = nn.ModuleList([nn.Linear(feat_dim, d) for d in self.latent_dims])
        self.u_logvar_heads = nn.ModuleList([nn.Linear(feat_dim, d) for d in self.latent_dims])

        # Top-down conditional priors p(z_l | z_{l+1}).
        # Entry l maps a sample of layer l+1 to the parameters of layer l.
        self.p_mu_layers = nn.ModuleList()
        self.p_logvar_layers = nn.ModuleList()
        for l in range(self.num_layers - 1):
            parent_dim = self.latent_dims[l + 1]
            child_dim = self.latent_dims[l]
            self.p_mu_layers.append(self._make_prior_mlp(parent_dim, child_dim))
            self.p_logvar_layers.append(self._make_prior_mlp(parent_dim, child_dim))

        # Conv decoder to image from z1
        from model.vae_models.conv_vae import ConvDecoder, ConvEncoder  # reuse decoder shape logic
        # A throwaway encoder is built only to read the feature geometry the
        # decoder needs; it is a local variable, so it is not registered as a submodule.
        dummy_encoder = ConvEncoder(self.in_channels, encoder_channels, self.latent_dims[0])
        self.feature_channels = dummy_encoder.feature_channels
        self.feature_size = dummy_encoder.feature_size
        self.decoder = ConvDecoder(self.latent_dims[0], decoder_channels, out_channels=self.in_channels,
                                   feature_channels=self.feature_channels, feature_size=self.feature_size)

    @staticmethod
    def _make_prior_mlp(parent_dim: int, child_dim: int) -> nn.Sequential:
        """Build the 2-layer MLP that maps a parent latent to one prior parameter of its child."""
        width = max(parent_dim, child_dim)
        return nn.Sequential(
            nn.Linear(parent_dim, width),
            nn.ReLU(),
            nn.Linear(width, child_dim),
        )

    def _compute_upward_params(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Run the CNN trunk, the FC layer and every upward head.

        Returns:
            ``(mu_u, logvar_u)``, one tensor per latent layer. Log-variances are
            not clamped here; ``combine_gaussians`` clamps them when they are used.
        """
        h = self.enc_cnn(x)
        h_flat = h.view(h.size(0), -1)
        feat = self.enc_fc(h_flat)
        mu_u = [head(feat) for head in self.u_mu_heads]
        logvar_u = [head(feat) for head in self.u_logvar_heads]
        return mu_u, logvar_u

    def _top_down_pass(
        self,
        mu_u: List[torch.Tensor],
        logvar_u: List[torch.Tensor],
        num_samples: Optional[int] = None,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """Sample every latent from top to bottom, keeping the prior parameters used.

        The prior networks are evaluated exactly once per layer; callers reuse
        the returned prior parameters for KL / log-density terms instead of
        re-running the networks on the same parent samples.

        Args:
            mu_u: Upward means, one ``[B, D_l]`` tensor per layer.
            logvar_u: Upward log-variances, one ``[B, D_l]`` tensor per layer.
            num_samples: When given, every tensor gets a sample axis and has
                shape ``[B, num_samples, D_l]``; otherwise shapes stay ``[B, D_l]``.

        Returns:
            ``(q_mus, q_logvars, zs, p_mus, p_logvars)``, each a list with one
            entry per layer. Prior log-variances are returned unclamped; the
            functions that consume them clamp internally.
        """
        def per_sample(t: torch.Tensor) -> torch.Tensor:
            # [B, D] -> [B, K, D] as a broadcast view; identity without a sample axis.
            if num_samples is None:
                return t
            return t.unsqueeze(1).expand(-1, num_samples, -1)

        n = self.num_layers
        top = n - 1
        q_mus: List[torch.Tensor] = [None] * n  # type: ignore
        q_logvars: List[torch.Tensor] = [None] * n  # type: ignore
        zs: List[torch.Tensor] = [None] * n  # type: ignore
        p_mus: List[torch.Tensor] = [None] * n  # type: ignore
        p_logvars: List[torch.Tensor] = [None] * n  # type: ignore

        # Prior at top: N(0, I)
        p_mus[top] = per_sample(torch.zeros_like(mu_u[top]))
        p_logvars[top] = per_sample(torch.zeros_like(logvar_u[top]))

        for l in range(top, -1, -1):
            if l < top:
                # Conditional prior given the sampled parent; flatten any sample
                # axis so the MLP always sees a 2-D batch.
                z_parent = zs[l + 1]
                flat_parent = z_parent.reshape(-1, z_parent.size(-1))
                lead_shape = z_parent.shape[:-1]
                p_mus[l] = self.p_mu_layers[l](flat_parent).view(*lead_shape, -1)
                p_logvars[l] = self.p_logvar_layers[l](flat_parent).view(*lead_shape, -1)
            q_mus[l], q_logvars[l] = combine_gaussians(
                per_sample(mu_u[l]), per_sample(logvar_u[l]), p_mus[l], p_logvars[l]
            )
            zs[l] = reparameterize(q_mus[l], q_logvars[l])

        return q_mus, q_logvars, zs, p_mus, p_logvars

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Reconstruct ``x`` from a sampled z1.

        Returns:
            ``(x_recon, q_mu_1, q_logvar_1)`` — the reconstruction and the
            posterior parameters of the bottom latent.
        """
        q_mus, q_logvars, _ = self._infer_posterior(x)
        # A fresh z1 is drawn here (rather than reusing the one sampled during
        # inference) so the random-number stream matches earlier behaviour.
        z1 = reparameterize(q_mus[0], q_logvars[0])
        x_recon = self.decode(z1)
        return x_recon, q_mus[0], q_logvars[0]

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return approximate posterior parameters for bottom latent z1."""
        q_mus, q_logvars, _ = self._infer_posterior(x)
        return q_mus[0], q_logvars[0]

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Run the conv decoder on a bottom-layer latent ``z``."""
        return self.decoder(z, self.feature_channels, self.feature_size)

    def sample(self, num_samples: int, device: torch.device) -> torch.Tensor:
        """Ancestral sampling for CNN variant.

        The conditional-prior log-variance is clamped exactly as it is inside
        ``combine_gaussians`` / ``gaussian_kl`` during training, so samples come
        from the same generative model that was optimized and ``exp`` cannot
        overflow on extreme network outputs.
        """
        top = self.num_layers - 1
        z = torch.randn(num_samples, self.latent_dims[top], device=device)
        for l in range(top - 1, -1, -1):
            p_mu_l = self.p_mu_layers[l](z)
            p_logvar_l = _clamp_logvar(self.p_logvar_layers[l](z))
            z = reparameterize(p_mu_l, p_logvar_l)
        return self.decode(z)

    def _infer_posterior(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """Infer q for every layer and draw one sample per layer.

        Returns:
            ``(q_mus, q_logvars, z_samples)``, each a list indexed by layer.
        """
        mu_u, logvar_u = self._compute_upward_params(x)
        q_mus, q_logvars, z_samples, _, _ = self._top_down_pass(mu_u, logvar_u)
        return q_mus, q_logvars, z_samples

    def compute_elbo_loss(self, x: torch.Tensor, beta: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Single-sample negative ELBO.

        The reconstruction term is the element-wise mean binary cross-entropy
        between the decoder output and ``x``; the KL term is the sum over layers
        of ``gaussian_kl`` (summed over latent dims, averaged over the batch).

        Args:
            x: Input image batch; also used as the BCE target.
            beta: Weight applied to the total KL.

        Returns:
            ``(total_loss, recon_loss, total_kl)`` as scalar tensors.
        """
        mu_u, logvar_u = self._compute_upward_params(x)
        q_mus, q_logvars, z_samples, p_mus, p_logvars = self._top_down_pass(mu_u, logvar_u)

        x_recon = self.decode(z_samples[0])
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='mean')

        # KL summed across layers, top first. Gradients flow through the parent
        # samples into the upper latents (nothing is detached).
        total_kl = 0.0
        for l in range(self.num_layers - 1, -1, -1):
            total_kl = total_kl + gaussian_kl(q_mus[l], q_logvars[l], p_mus[l], p_logvars[l])

        total_loss = recon_loss + beta * total_kl
        return total_loss, recon_loss, torch.as_tensor(total_kl, device=x.device)

    def _log_normal_diag(self, z: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Return log N(z; mu, exp(logvar)) summed over the last dimension (logvar is clamped)."""
        logvar = _clamp_logvar(logvar)
        diff = z - mu
        return -0.5 * (
            logvar + diff * diff * torch.exp(-logvar) + math.log(2.0 * math.pi)
        ).sum(dim=-1)

    def compute_iwae_loss(
        self,
        x: torch.Tensor,
        k: Union[int, List[int]] = 5,
        *,
        mode: str = 'joint',
        max_joint_k: int = 50,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the IWAE bound using K joint latent trajectories.

        A single shared K is always used for the CNN path to control memory:
        K = max(k) when ``k`` is a list, otherwise K = k. ``mode`` and
        ``max_joint_k`` are accepted for signature parity with the MLP variant
        and are not used.

        Args:
            x: Input image batch indexed as ``[B, C, H, W]``; also the BCE target.
            k: Number of importance samples, or a list whose maximum is used.

        Returns:
            Dict with ``'loss'`` (negative IWAE bound averaged over the batch),
            ``'recon_loss'`` (importance-weighted BCE per input element) and
            ``'kl_loss'`` (sum over layers of the analytic KL between q and p at
            the sampled parents; a diagnostic, not part of ``'loss'``).

        Raises:
            ValueError: If the resolved K is smaller than 1.
        """
        K = int(max(k)) if isinstance(k, list) else int(k)
        if K < 1:
            # K <= 0 would otherwise yield an empty logsumexp and log(0): NaN / -inf losses.
            raise ValueError(f"number of importance samples must be >= 1, got {K}")

        batch_size = x.size(0)

        # One top-down pass with a sample axis: tensors are [B, K, D_l]
        mu_u, logvar_u = self._compute_upward_params(x)
        q_mus, q_logvars, zs, p_mus, p_logvars = self._top_down_pass(mu_u, logvar_u, num_samples=K)
        layers_top_down = range(self.num_layers - 1, -1, -1)

        # Decode
        z1_flat = zs[0].reshape(batch_size * K, -1)
        x_recon_flat = torch.clamp(self.decode(z1_flat), 1e-6, 1.0 - 1e-6)
        # Repeat each image along the sample axis to line up with the decoded batch.
        x_rep = x.unsqueeze(1).expand(-1, K, -1, -1, -1).reshape_as(x_recon_flat)
        bce = F.binary_cross_entropy(x_recon_flat, x_rep, reduction='none').view(batch_size, K, -1).sum(dim=2)
        log_px_z = -bce  # [B, K]

        # Log priors and posteriors accumulated over layers, top first
        log_p = torch.zeros_like(log_px_z)
        log_q = torch.zeros_like(log_px_z)
        for l in layers_top_down:
            log_p = log_p + self._log_normal_diag(zs[l], p_mus[l], p_logvars[l])
            log_q = log_q + self._log_normal_diag(zs[l], q_mus[l], q_logvars[l])

        log_w = log_px_z + log_p - log_q  # [B, K]
        iwae_bound = torch.logsumexp(log_w, dim=1) - math.log(K)  # [B]
        loss = -iwae_bound.mean()

        # Diagnostics: reconstruction weighted by the normalized importance weights.
        # Normalize by every summed element (C * H * W), not only H * W, so the
        # value is per element like the mean-reduced BCE in compute_elbo_loss.
        w_tilde = torch.softmax(log_w, dim=1).detach()
        elems_per_image = x.size(1) * x.size(2) * x.size(3)
        recon_loss = (w_tilde * (bce / elems_per_image)).sum(dim=1).mean()

        # Analytic KL proxy between q and p at the sampled parents
        kl_proxy = 0.0
        for l in layers_top_down:
            kl_proxy = kl_proxy + gaussian_kl(q_mus[l].reshape(batch_size * K, -1),
                                              q_logvars[l].reshape(batch_size * K, -1),
                                              p_mus[l].reshape(batch_size * K, -1),
                                              p_logvars[l].reshape(batch_size * K, -1))

        return {
            'loss': loss,
            'recon_loss': recon_loss,
            'kl_loss': torch.as_tensor(kl_proxy, device=x.device),
        }


