# CMVAE class definition
import torch
import torch.nn as nn
import torch.distributions as dist
from utils import get_mean


class CMVAE(nn.Module):
    """
    CMVAE class definition. Multimodal VAE with clustering in the latent space.

    Each modality has its own unimodal VAE (``self.vaes``). A latent code ``u`` is the
    concatenation ``[w, z]`` along the last dim of a modality-specific (private) latent
    ``w`` of size ``params.latent_dim_w`` and a shared latent ``z`` of size
    ``params.latent_dim_z``.

    This class reads ``pz_params``, ``pc_params``, ``pc_params_pruning`` and
    ``_pw_params`` without defining them; they are presumably provided by subclasses.
    """

    def __init__(self, prior_dist, params, *vaes):
        super().__init__()
        self.pz = prior_dist  # Prior distribution (shared latent)
        self.pw = prior_dist  # Prior distribution (modality-specific latent)
        self.vaes = nn.ModuleList([vae(params) for vae in vaes])  # Unimodal VAEs
        self.modelName = None  # Filled-in in subclass
        self.params = params  # Model parameters (i.e. args passed to main script)

    @staticmethod
    def getDataSets(batch_size, shuffle=True, device="cuda"):
        # Handle getting individual datasets appropriately in sub-class
        raise NotImplementedError

    def pw_params(self, z=None):
        """Handled in multimodal VAE subclass, optionally conditioned on z."""
        return self._pw_params

    # ------------------------------------------------------------------ helpers

    def _split_latents(self, us):
        """Split ``us`` along its last dim into ``(w, z)`` of sizes (latent_dim_w, latent_dim_z)."""
        return torch.split(us, [self.params.latent_dim_w, self.params.latent_dim_z], dim=-1)

    def _encode_and_self_reconstruct(self, x, K):
        """
        Encode every modality and decode its self-reconstruction.

        Returns ``(qu_xs, px_us, uss)``; only the diagonal ``px_us[m][m]`` is filled,
        off-diagonal entries are left as ``None``.
        """
        n_mod = len(self.vaes)
        qu_xs, uss = [], []
        px_us = [[None] * n_mod for _ in range(n_mod)]
        use_disen = getattr(self.params, "use_disen", False)

        for m, vae in enumerate(self.vaes):
            if not use_disen:
                qu_x, px_u, us = vae(x[m], K=K)
                qu_xs.append(qu_x)
                uss.append(us)
                px_us[m][m] = px_u
                continue

            # Disentangled path: sample the shared latent z first, then infer the
            # private latent w from the encoder's forward_disen conditioned on that z.
            qu_x_params = vae.enc(x[m])
            vae._qu_x_params = qu_x_params  # cached on the unimodal VAE for later use
            qu_x_r_mean, qu_x_r_lv = qu_x_params
            _, qz_x_mean = self._split_latents(qu_x_r_mean)
            _, qz_x_lv = self._split_latents(qu_x_r_lv)
            qz_x = vae.qu_x(qz_x_mean, qz_x_lv)
            z_m = qz_x.rsample(torch.Size([K]))  # (K,B,Dz)
            mu_w, lv_w = vae.enc.forward_disen(x[m], z_m)  # (K,B,Dw) params
            qw_x = vae.qu_x(mu_w, lv_w)
            ws = qw_x.rsample()  # (K,B,Dw)

            us = torch.cat((ws, z_m), dim=-1)
            uss.append(us)
            qu_xs.append({"z": qz_x, "w": qw_x})
            px_us[m][m] = vae.px_u(*vae.dec(us))

        return qu_xs, px_us, uss

    def _cross_reconstruct(self, uss, px_us, pw_params_attr):
        """
        Fill the off-diagonal entries ``px_us[e][d]`` (e != d) in place.

        The shared latent ``z`` inferred from modality ``e`` is combined with a private
        latent ``w`` sampled from the target VAE's prior ``vae.pw``, whose parameters are
        read from the attribute named ``pw_params_attr`` on that VAE.
        """
        for e, us in enumerate(uss):
            _, z_e = self._split_latents(us)
            for d, vae in enumerate(self.vaes):
                if e == d:
                    continue
                pw = vae.pw(*getattr(vae, pw_params_attr))
                # squeeze(2) drops a size-1 axis 2, presumably from prior params shaped
                # (1, Dw) -- TODO confirm against the unimodal VAE definition.
                latents_w = pw.rsample(torch.Size([us.size(0), us.size(1)])).squeeze(2)
                # Tensor.to()/.cuda() are not in-place: assign the result. Target z_e's
                # device so the concatenation below cannot hit a device mismatch.
                latents_w = latents_w.to(z_e.device)
                us_combined = torch.cat((latents_w, z_e), dim=-1)
                px_us[e][d] = vae.px_u(*vae.dec(us_combined))

    def _sample_shared_latents(self, cluster_idxs, n_per_cluster):
        """Draw ``n_per_cluster`` shared latents from each cluster prior and stack them on dim 0."""
        samples = [
            self.pz(*self.pz_params(idx)).rsample(torch.Size([n_per_cluster]))
            for idx in cluster_idxs
        ]
        return torch.cat(samples, dim=0)

    def _decode_shared_latents(self, latents_z_all):
        """Decode shared latents into every modality, pairing them with fresh private latents."""
        data = []
        for vae in self.vaes:
            pw = self.pw(*self.pw_params())
            latents_w = pw.rsample(torch.Size([latents_z_all.size(0)]))
            latents = torch.cat((latents_w, latents_z_all), dim=-1)
            px_u = vae.px_u(*vae.dec(latents))
            # Merge the two leading dims into a single generation axis.
            data.append(px_u.mean.view(-1, *px_u.mean.size()[2:]))
        return data

    def _generate_from_cluster_probs(self, N, probs):
        """Sample N cluster indexes from ``probs``, one shared latent each, and decode them."""
        idxs = dist.Categorical(probs=probs).sample([N])
        latents_z_all = self._sample_shared_latents(idxs, 1)
        return self._decode_shared_latents(latents_z_all)

    # --------------------------------------------------------------- public API

    def forward(self, x, K=1):
        """
        Forward function.
        Input:
            - x: list of data samples for each modality
            - K: number of samples for reparameterization in latent space

        Returns:
            - qu_xs: List of encoding distributions (one per encoder). When
                    ``params.use_disen`` is set, each entry is a dict {"z": qz_x, "w": qw_x}.
            - px_us: Matrix of self- and cross- reconstructions. px_us[m][n] contains
                    m --> n reconstruction. Cross-reconstructions sample the private
                    latent from each target VAE's ``pw_params_aux``.
            - uss: List of latent codes, one for each modality. uss[m] contains latents inferred
                   from modality m. Note these latents are the concatenation of private and shared latents.
        """
        qu_xs, px_us, uss = self._encode_and_self_reconstruct(x, K)
        self._cross_reconstruct(uss, px_us, "pw_params_aux")
        return qu_xs, px_us, uss

    def generate_random_unconditional(self, N):
        """
        Unconditional random generation.
        Args:
            N: Number of samples to generate.
        Returns:
            Generations (one tensor per modality)
        """
        with torch.no_grad():
            return self._generate_from_cluster_probs(N, self.pc_params)

    def generate_random_unconditional_with_pruning(self, N, idxs_to_prune):
        """
        Unconditional random generation with pruned clusters.
        Args:
            N: Number of samples to generate
            idxs_to_prune: Indexes of pruned latent clusters; if None, the unpruned
                cluster probabilities ``self.pc_params`` are used.

        Returns:
            Generations (one tensor per modality)
        """
        with torch.no_grad():
            if idxs_to_prune is None:
                probs = self.pc_params
            else:
                probs = self.pc_params_pruning(idxs_to_prune)
            return self._generate_from_cluster_probs(N, probs)

    def generate_unconditional(self, N):
        """
        Unconditional generation from each latent cluster.
        Args:
            N: Number of samples to generate per cluster

        Returns:
            Generations (one tensor per modality), clusters stacked in index order
        """
        with torch.no_grad():
            latents_z_all = self._sample_shared_latents(range(self.params.latent_dim_c), N)
            return self._decode_shared_latents(latents_z_all)

    def generate_unconditional_with_input_latent_clusters(self, N, indexes):
        """
        Unconditional generation from selected latent clusters.
        Args:
            N: Number of samples to generate per selected cluster
            indexes: list of latent clusters indexes to sample from

        Returns:
            Generations (one tensor per modality), clusters stacked in the order given
        """
        with torch.no_grad():
            latents_z_all = self._sample_shared_latents(indexes, N)
            return self._decode_shared_latents(latents_z_all)

    def self_and_cross_modal_generation_forward(self, x, K=1):
        """
        Test-time self- and cross-modal generation forward function.
        Input:
            - x: list of data samples for each modality
            - K: number of samples for reparameterization in latent space

        Returns:
            - qu_xs: List of encoding distributions (one per encoder). When
                    ``params.use_disen`` is set, each entry is a dict {"z": qz_x, "w": qw_x}.
            - px_us: Matrix of test-time self- and cross- reconstructions. px_us[m][n] contains
                    m --> n reconstruction. Cross-reconstructions sample the private
                    latent from each target VAE's ``pw_params_std``.
            - uss: List of latent codes, one for each modality. uss[m] contains latents inferred
                   from modality m. Note these latents are the concatenation of private and shared latents.
        """
        qu_xs, px_us, uss = self._encode_and_self_reconstruct(x, K)
        self._cross_reconstruct(uss, px_us, "pw_params_std")
        return qu_xs, px_us, uss

    def self_and_cross_modal_generation(self, data):
        """
        Test-time self- and cross-reconstruction.
        Args:
            data: Input

        Returns:
            Matrix of self- and cross-modal reconstructions

        """
        with torch.no_grad():
            _, px_us, _ = self.self_and_cross_modal_generation_forward(data)
            recons = [[get_mean(px_u) for px_u in r] for r in px_us]
        return recons
