# Base MMVAE class definition

from itertools import combinations

import torch
import torch.nn as nn

from utils import get_mean, kl_divergence
from vis import embed_umap, tensors_to_df


class MMVAE(nn.Module):
    """Base multimodal VAE whose latent is split into a private and a shared part.

    Each unimodal VAE in ``self.vaes`` encodes its own modality into a latent
    ``z``. This class treats ``z`` as the concatenation ``[w, u]`` along the last
    axis:

    * ``w``: the first ``params.latent_dim_w`` entries, the modality-specific part.
    * ``u``: the remaining ``params.latent_dim_u`` entries, the part shared across
      modalities.

    Cross-modal reconstructions decode the source modality's ``u`` together with a
    ``w`` obtained in one of several ways (prior sample, the target modality's own
    ``w``, ...).

    Sub-classes are expected to set ``modelName`` and ``_pz_params`` and to
    implement ``getDataLoaders``.

    Args:
        prior_dist: Distribution class, called as ``prior_dist(loc, scale)``. It is
            used both for the joint prior ``pz`` and for the ``w`` prior ``pw``.
        params: Experiment configuration. The attributes read here are
            ``latent_dim_w``, ``latent_dim_u`` and ``w_from_prior``.
        *vaes: Unimodal VAE classes, each instantiated as ``vae(params)``. The code
            below assumes every instance is callable as ``vae(x, K=K) ->
            (qz_x, px_z, zs)`` and exposes ``dec``, ``px_z`` and ``pw_params``.
    """

    def __init__(self, prior_dist, params, *vaes):
        super().__init__()
        self.pz = prior_dist  # joint prior over z = [w, u]
        self.pw = prior_dist  # prior over the modality-specific part w
        self.vaes = nn.ModuleList([vae(params) for vae in vaes])
        self.modelName = None  # filled-in per sub-class
        self.params = params
        self._pz_params = None  # defined in subclass

    @property
    def pz_params(self):
        """Parameters of the joint prior ``pz`` (provided by sub-classes)."""
        return self._pz_params

    @staticmethod
    def getDataLoaders(batch_size, shuffle=True, device="cuda"):
        """Return the data loaders. Sub-classes merge the unimodal datasets appropriately."""
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _split_wu(self, z):
        """Split ``z`` along its last axis into ``(w, u)``."""
        return torch.split(z, [self.params.latent_dim_w, self.params.latent_dim_u], dim=-1)

    def _encode_diagonal(self, x, K):
        """Run each unimodal VAE on its own modality.

        Returns:
            ``(qz_xs, px_zs, zss)``. ``px_zs`` is an M x M matrix in which only the
            diagonal (self-reconstructions) is filled; off-diagonal entries are
            ``None`` for the caller to fill.
        """
        n_mod = len(self.vaes)
        px_zs = [[None] * n_mod for _ in range(n_mod)]
        qz_xs, zss = [], []
        for m, vae in enumerate(self.vaes):
            qz_x, px_z, zs = vae(x[m], K=K)
            qz_xs.append(qz_x)
            zss.append(zs)
            px_zs[m][m] = px_z
        return qz_xs, px_zs, zss

    @staticmethod
    def _decode_to_flat_mean(vae, latents):
        """Decode ``latents`` with ``vae`` and return the mean with its two leading axes merged."""
        px_z = vae.px_z(*vae.dec(latents))
        return px_z.mean.view(-1, *px_z.mean.size()[2:])

    @staticmethod
    def _sample_second_modality(px_zs):
        """Turn a matrix of distributions into reconstructions.

        Column 1 gets ``.sample()``; every other entry gets ``get_mean``.
        """
        return [[px_z.sample() if o == 1 else get_mean(px_z) for o, px_z in enumerate(row)]
                for row in px_zs]

    # ------------------------------------------------------------------
    # Forward passes
    # ------------------------------------------------------------------
    def forward(self, x, K=1):
        """Encode every modality and fill the cross-modal reconstruction matrix.

        Args:
            x: Per-modality inputs; ``x[m]`` is fed to ``self.vaes[m]``.
            K: Number of latent samples requested from each unimodal VAE.

        Returns:
            ``(qz_xs, px_zs, zss)``:

            * ``qz_xs``: the per-modality posteriors.
            * ``px_zs``: an M x M matrix where ``px_zs[e][d]`` is decoder ``d``
              applied to a latent built from modality ``e``.
            * ``zss``: the per-modality latent samples, presumably shaped
              ``(K, batch, latent_dim)`` (TODO confirm).

        Off-diagonal latents depend on ``params.w_from_prior``:

        * ``'joint'``: ``[w, u_e]`` with ``w`` split off a sample of ``pz``.
        * ``'single'``: ``[w, u_e]`` with ``w`` drawn from decoder ``d``'s ``pw_params``.
        * ``'natural-no-grad'``: ``[w, u_e]`` with ``w`` from a fresh sample of
          ``q(z|x_d)``, with gradients blocked through ``w``.
        * anything else: modality ``e``'s latent is decoded unchanged.
        """
        qz_xs, px_zs, zss = self._encode_diagonal(x, K)
        for e, zs in enumerate(zss):
            for d, vae in enumerate(self.vaes):
                if e == d:
                    continue  # diagonal already filled
                mode = self.params.w_from_prior
                if mode == 'joint':
                    # Keep only the shared u and resample w from the prior, so that only
                    # the shared space has to carry common features across modalities.
                    _, u_e = self._split_wu(zs)
                    pz = self.pz(*self.pz_params)
                    latents = pz.rsample(torch.Size([zs.size(0), zs.size(1)])).squeeze(2)
                    w_new, _ = self._split_wu(latents)
                    z_cross = torch.cat((w_new, u_e), dim=-1)
                elif mode == 'single':
                    _, u_e = self._split_wu(zs)
                    pw = self.pw(*vae.pw_params)
                    w_new = pw.rsample(torch.Size([zs.size(0), zs.size(1)])).squeeze(2)
                    # Tensor.to returns a new tensor; the result must be assigned.
                    w_new = w_new.to(u_e.device)
                    z_cross = torch.cat((w_new, u_e), dim=-1)
                elif mode == 'natural-no-grad':
                    _, u_e = self._split_wu(zs)
                    with torch.no_grad():
                        zs_target = qz_xs[d].rsample(torch.Size([zs.size(0)]))
                        w_new, _ = self._split_wu(zs_target)
                    z_cross = torch.cat((w_new, u_e), dim=-1)
                else:
                    z_cross = zs
                px_zs[e][d] = vae.px_z(*vae.dec(z_cross))
        return qz_xs, px_zs, zss

    def test_forward(self, x, K=1):
        """Like :meth:`forward`, but each cross reconstruction uses the target's own ``w``.

        ``px_zs[e][d]`` decodes ``[w_d, u_e]``. Here ``w_d`` is the private part of
        ``zss[d]`` and ``u_e`` is the shared part of ``zss[e]``.
        """
        qz_xs, px_zs, zss = self._encode_diagonal(x, K)
        for e, zs in enumerate(zss):
            _, u_e = self._split_wu(zs)
            for d, vae in enumerate(self.vaes):
                if e != d:
                    w_d, _ = self._split_wu(zss[d])
                    px_zs[e][d] = vae.px_z(*vae.dec(torch.cat((w_d, u_e), dim=-1)))
        return qz_xs, px_zs, zss

    def helper_to_plot_different_ws(self, x, K=1):
        """Collect encoded and prior ``w`` samples for every modality.

        Returns:
            ``(wss_natural, wss_prior)``. For each modality ``m``:

            * ``wss_natural[m]``: the ``w`` part of ``self.vaes[m]``'s latent sample.
            * ``wss_prior[m]``: a sample from that VAE's ``pw_params`` prior, drawn
              with sample shape ``zs.shape[:2]`` and then squeezed on axis 2.
        """
        wss_natural, wss_prior = [], []
        for m, vae in enumerate(self.vaes):
            _, _, zs = vae(x[m], K=K)
            ws, _ = self._split_wu(zs)
            wss_natural.append(ws)
            pw = self.pw(*vae.pw_params)
            wss_prior.append(pw.rsample(torch.Size([zs.size(0), zs.size(1)])).squeeze(2))
        return wss_natural, wss_prior

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------
    def generate(self, N):
        """Decode ``N`` joint-prior samples with every VAE.

        Returns one decoded mean per modality.
        """
        with torch.no_grad():
            latents = self.pz(*self.pz_params).rsample(torch.Size([N]))
            data = [self._decode_to_flat_mean(vae, latents) for vae in self.vaes]
        return data  # list of generations---one for each modality

    def generate_parametric(self, N, factor_u, factor_w_img, factor_w_sent):
        """Generate from the joint prior with separately rescaled prior scales.

        Three independent draws are made from ``pz``, each with its scale multiplied
        by a factor:

        * ``factor_u``: the draw supplying ``u``, which every modality shares.
        * ``factor_w_img``: the draw supplying ``w`` for modality 0.
        * ``factor_w_sent``: the draw supplying ``w`` for every other modality.

        Returns one decoded mean per modality.
        """
        with torch.no_grad():
            mean_z, scale_z = self.pz_params

            def sample_scaled(factor):
                return self.pz(mean_z, factor * scale_z).rsample(torch.Size([N]))

            # Draw order: u, then the image w, then the sentence w.
            _, latents_u = self._split_wu(sample_scaled(factor_u))
            latents_w_img, _ = self._split_wu(sample_scaled(factor_w_img))
            latents_w_sent, _ = self._split_wu(sample_scaled(factor_w_sent))
            # z is laid out as [w, u] (see _split_wu), so w must come first.
            latents_z_img = torch.cat([latents_w_img, latents_u], dim=-1)
            latents_z_sent = torch.cat([latents_w_sent, latents_u], dim=-1)
            data = [self._decode_to_flat_mean(vae, latents_z_img if d == 0 else latents_z_sent)
                    for d, vae in enumerate(self.vaes)]
        return data  # list of generations---one for each modality

    def generate_sampled(self, N):
        """Like :meth:`generate`, but modality 1 returns ``px_z.sample()`` instead of the mean."""
        with torch.no_grad():
            latents = self.pz(*self.pz_params).rsample(torch.Size([N]))
            data = []
            for d, vae in enumerate(self.vaes):
                px_z = vae.px_z(*vae.dec(latents))
                out = px_z.sample() if d == 1 else px_z.mean
                data.append(out.view(-1, *out.size()[2:]))
        return data  # list of generations---one for each modality

    # ------------------------------------------------------------------
    # Reconstruction
    # ------------------------------------------------------------------
    def reconstruct(self, data):
        """Return the M x M matrix of ``get_mean`` reconstructions from :meth:`forward`."""
        with torch.no_grad():
            _, px_zs, _ = self.forward(data)
            recons = [[get_mean(px_z) for px_z in row] for row in px_zs]
        return recons

    def reconstruct_sampled(self, data):
        """Like :meth:`reconstruct`, but column 1 holds ``px_z.sample()`` instead of a mean."""
        with torch.no_grad():
            _, px_zs, _ = self.forward(data)
            recons = self._sample_second_modality(px_zs)
        return recons

    def reconstruct_options_forw(self, data, option, factor=1.0, K=1):
        """Fill the cross-modal matrix, choosing the decoded ``w`` according to ``option``.

        ``px_zs[e][d]`` decodes ``[w, u_e]`` with decoder ``d``. Here ``u_e`` is the
        shared part of modality ``e``'s latent sample, and ``w`` is:

        * ``"natural"``: the ``w`` part of modality ``d``'s own latent sample.
        * ``"stdprior"``: a sample from ``pw(0, factor)``, drawn once per source ``e``.
        * ``"jointprior"``: the ``w`` part of a sample from ``pz`` with its scale
          multiplied by ``factor``, drawn once per source ``e``.
        * ``"tunedsingleprior"``: a sample from decoder ``d``'s ``pw_params`` prior
          with its scale multiplied by ``factor``.

        Prior samples are moved to the device of the latent samples.

        Returns:
            ``(qz_xs, px_zs, zss)``, as in :meth:`forward`.

        Raises:
            ValueError: If ``option`` is not one of the values above.
        """
        if option not in ("natural", "stdprior", "jointprior", "tunedsingleprior"):
            raise ValueError("Not a valid reconstruction option")
        qz_xs, px_zs, zss = self._encode_diagonal(data, K)
        dim_w = self.params.latent_dim_w
        for e, zs in enumerate(zss):
            _, us_e = self._split_wu(zs)
            sample_shape = torch.Size([zs.size(0), zs.size(1)])
            # Priors that do not depend on the target decoder: one draw per source modality.
            ws_shared = None
            if option == "stdprior":
                pw = self.pw(torch.zeros(1, dim_w),
                             torch.full([1, dim_w], fill_value=float(factor)))
                ws_shared = pw.rsample(sample_shape).squeeze(2).to(zs.device)
            elif option == "jointprior":
                mean_z, scale_z = self.pz_params
                pz = self.pz(mean_z, factor * scale_z)
                ws_shared, _ = self._split_wu(pz.rsample(sample_shape).squeeze(2))
            for d, vae in enumerate(self.vaes):
                if e == d:
                    continue  # diagonal already filled
                if option == "natural":
                    ws_new, _ = self._split_wu(zss[d])
                elif option == "tunedsingleprior":
                    # The sampled w is decoded by `vae` (decoder d), so use that VAE's prior.
                    mean_w, scale_w = vae.pw_params
                    pw = self.pw(mean_w, factor * scale_w)
                    ws_new = pw.rsample(sample_shape).squeeze(2).to(zs.device)
                else:
                    ws_new = ws_shared
                px_zs[e][d] = vae.px_z(*vae.dec(torch.cat((ws_new, us_e), dim=-1)))
        return qz_xs, px_zs, zss

    def reconstruct_options(self, data, option, factor=1.0):
        """Return ``get_mean`` reconstructions from :meth:`reconstruct_options_forw`."""
        with torch.no_grad():
            _, px_zs, _ = self.reconstruct_options_forw(data, option, factor)
            recons = [[get_mean(px_z) for px_z in row] for row in px_zs]
        return recons

    def reconstruct_options_sampled(self, data, option, factor=1.0):
        """Like :meth:`reconstruct_options`, but column 1 holds ``px_z.sample()``."""
        with torch.no_grad():
            _, px_zs, _ = self.reconstruct_options_forw(data, option, factor)
            recons = self._sample_second_modality(px_zs)
        return recons

    def shift_reconstruct_forw(self, data, shift=0, K=1):
        """Cross-reconstruct using the target's ``w`` rolled by ``shift`` along axis 1.

        ``px_zs[e][d]`` decodes ``[roll(w_d, shift, dims=1), u_e]``. If ``zss`` is
        ``(K, batch, latent_dim)`` (TODO confirm), each ``u_e`` is paired with the
        ``w_d`` of a different batch element.
        """
        qz_xs, px_zs, zss = self._encode_diagonal(data, K)
        for e, zs in enumerate(zss):
            _, us_e = self._split_wu(zs)
            for d, vae in enumerate(self.vaes):
                if e != d:
                    ws_d, _ = self._split_wu(zss[d])
                    ws_d = torch.roll(ws_d, shifts=shift, dims=1)
                    px_zs[e][d] = vae.px_z(*vae.dec(torch.cat((ws_d, us_e), dim=-1)))
        return qz_xs, px_zs, zss

    def shift_reconstruct(self, data, shift=0):
        """Return ``get_mean`` reconstructions from :meth:`shift_reconstruct_forw`."""
        with torch.no_grad():
            _, px_zs, _ = self.shift_reconstruct_forw(data, shift)
            recons = [[get_mean(px_z) for px_z in row] for row in px_zs]
        return recons

    def private_conditioned_self_generate(self, data):
        """Decode each modality's own ``w`` together with a ``u`` drawn from the joint prior.

        A single ``u`` sample, drawn with sample shape ``w.shape[:-1]`` of modality 0,
        is shared by all modalities. Returns one decoded mean per modality.
        """
        with torch.no_grad():
            _, _, zss = self.reconstruct_options_forw(data, "natural")
            wss = [self._split_wu(zs)[0] for zs in zss]
            pz = self.pz(*self.pz_params)
            latents = pz.rsample(torch.Size(wss[0].size()[:-1])).squeeze(2)
            us = self._split_wu(latents)[1]
            output = [self._decode_to_flat_mean(vae, torch.cat((ws, us), dim=-1))
                      for vae, ws in zip(self.vaes, wss)]
        return output

    def generate_multilple_ws_for_u(self, M=8):
        """Generate ``M`` latents that share one prior ``u``, each with a fresh prior ``w``.

        The first latent is the anchor prior sample itself. Returns one decoded mean
        per modality.
        """
        with torch.no_grad():
            pz = self.pz(*self.pz_params)
            anchor = pz.rsample(torch.Size([1]))
            _, latents_u = self._split_wu(anchor)
            latents = [anchor]
            for _draw in range(M - 1):
                w_new, _ = self._split_wu(pz.rsample(torch.Size([1])))
                latents.append(torch.cat((w_new, latents_u), dim=-1))
            latents = torch.cat(latents, dim=0)
            data = [self._decode_to_flat_mean(vae, latents) for vae in self.vaes]
        return data  # list of generations---one for each modality

    def generate_multilple_us_for_w(self, M=8):
        """Generate ``M`` latents that share one prior ``w``, each with a fresh prior ``u``.

        The first latent is the anchor prior sample itself. Returns one decoded mean
        per modality.
        """
        with torch.no_grad():
            pz = self.pz(*self.pz_params)
            anchor = pz.rsample(torch.Size([1]))
            latents_w, _ = self._split_wu(anchor)
            latents = [anchor]
            for _draw in range(M - 1):
                _, u_new = self._split_wu(pz.rsample(torch.Size([1])))
                latents.append(torch.cat((latents_w, u_new), dim=-1))
            latents = torch.cat(latents, dim=0)
            data = [self._decode_to_flat_mean(vae, latents) for vae in self.vaes]
        return data  # list of generations---one for each modality

    # ------------------------------------------------------------------
    # Analysis
    # ------------------------------------------------------------------
    def analyse(self, data, K):
        """Collect latent samples and KL statistics for visualisation.

        Returns:
            ``(embedding, labels, kls_df)``:

            * ``embedding``: ``embed_umap`` of the stacked latents, with joint-prior
              samples first and then each modality's samples.
            * ``labels``: one label per row (0 = prior, ``m + 1`` = modality ``m``).
            * ``kls_df``: the ``tensors_to_df`` table of per-modality
              ``KL(q(z|x_m) || p(z))`` and the pairwise symmetrised
              ``0.5 * (KL(p||q) + KL(q||p))`` between posteriors.
        """
        with torch.no_grad():
            qz_xs, _, zss = self.forward(data, K=K)
            pz = self.pz(*self.pz_params)
            zss = [pz.sample(torch.Size([K, data[0].size(0)])).view(-1, pz.batch_shape[-1]),
                   *[zs.view(-1, zs.size(-1)) for zs in zss]]
            zsl = [torch.zeros(zs.size(0)).fill_(i) for i, zs in enumerate(zss)]
            kls_df = tensors_to_df(
                [*[kl_divergence(qz_x, pz).cpu().numpy() for qz_x in qz_xs],
                 *[0.5 * (kl_divergence(p, q) + kl_divergence(q, p)).cpu().numpy()
                   for p, q in combinations(qz_xs, 2)]],
                head='KL',
                keys=[*[r'KL$(q(z|x_{})\,||\,p(z))$'.format(i) for i in range(len(qz_xs))],
                      *[r'J$(q(z|x_{})\,||\,q(z|x_{}))$'.format(i, j)
                        for i, j in combinations(range(len(qz_xs)), 2)]],
                ax_names=['Dimensions', r'KL$(q\,||\,p)$']
            )
        return embed_umap(torch.cat(zss, 0).cpu().numpy()), \
               torch.cat(zsl, 0).cpu().numpy(), \
               kls_df

    def analyse_difference_ws(self, data, K, m=0):
        """Embed modality ``m``'s encoded ``w`` samples together with its prior ``w`` samples.

        Returns:
            ``(embedding, labels)``. ``embedding`` is the ``embed_umap`` result of the
            stacked rows; ``labels`` is 0 for encoded rows and 1 for prior rows.

        Raises:
            IndexError: If ``m`` is not a valid modality index.
        """
        with torch.no_grad():
            wss_natural, wss_prior = self.helper_to_plot_different_ws(data, K=K)
            ws_nat, ws_pri = wss_natural[m], wss_prior[m]
            wss = [ws_nat.view(-1, ws_nat.size(-1)), ws_pri.view(-1, ws_pri.size(-1))]
            wsl = [torch.zeros(ws.size(0)).fill_(i) for i, ws in enumerate(wss)]
        return embed_umap(torch.cat(wss, 0).cpu().numpy()), \
               torch.cat(wsl, 0).cpu().numpy()