# Five-modality PolyMNIST multi-modal model specification
import os
import shutil

import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
from numpy import sqrt, prod
from torch.utils.data import DataLoader
from torchnet.dataset import TensorDataset, ResampleDataset
from torchvision.utils import save_image, make_grid
from torchvision import transforms
from datasets_PolyMNIST import PolyMNISTDataset
from utils import Constants

from vis import plot_embeddings, plot_kls_df
from .mmvae import MMVAE
#from .vae_mnist import MNIST
#from .vae_svhn import SVHN
from .vae_polymnist_resnet import PolyMNIST


class PolyMNIST_5modalities(MMVAE):
    """MMVAE built from five ``PolyMNIST`` VAEs, plus image-dumping helpers.

    Most public methods call a parent-class method, move the returned tensors
    to the CPU and either write PNG files below ``runPath`` or return image
    grids. The exact structures returned by the parent methods are defined in
    ``MMVAE`` (not visible from this file); this class treats reconstruction
    results as nested sequences indexed ``[input modality][output modality]``.
    """

    def __init__(self, params):
        """Build the five-modality model.

        Args:
            params: configuration object; this constructor reads
                ``latent_dim_w``, ``latent_dim_u`` and ``learn_prior`` from it
                and forwards it to ``MMVAE.__init__``.
        """
        super().__init__(dist.Laplace, params, PolyMNIST, PolyMNIST, PolyMNIST, PolyMNIST, PolyMNIST)
        latent_dim = params.latent_dim_w + params.latent_dim_u
        self._pz_params = nn.ParameterList([
            # First prior parameter ("mu" in the original code); never trained.
            nn.Parameter(torch.zeros(1, latent_dim), requires_grad=False),
            # Raw second prior parameter; transformed in ``pz_params``.
            nn.Parameter(torch.zeros(1, latent_dim), requires_grad=params.learn_prior),
        ])
        # NOTE: likelihood scaling between modalities is deliberately not
        # configured here (params.llik_scaling is not consulted).
        self.modelName = 'polymnist-5modalities'

        # Give each unimodal VAE a distinct name so they can be saved separately.
        for idx, vae in enumerate(self.vaes):
            vae.modelName = 'polymnist_resnet_m' + str(idx)

        self.tmpdir = None

    @property
    def pz_params(self):
        """Return the two prior parameters as a tuple.

        The first is the stored tensor as-is. The second is a softmax over
        dim 1 of the raw parameter multiplied by the size of its last
        dimension, so its entries are positive and sum to that size per row.
        """
        raw = self._pz_params[1]
        return self._pz_params[0], F.softmax(raw, dim=1) * raw.size(-1)

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _as_cpu_batch(samples, N):
        """Detach ``samples``, move to CPU and view with leading dimension ``N``."""
        samples = samples.detach().cpu()
        return samples.view(N, *samples.size()[1:])

    @staticmethod
    def _pick_recon(recons_mat, r_in, o_in):
        """Return entry ``[r_in][o_in]`` of ``recons_mat``, squeezed on dim 0, on CPU.

        Raises:
            IndexError: if ``r_in`` or ``o_in`` does not address an entry.
        """
        rows = list(recons_mat)
        if not 0 <= r_in < len(rows):
            raise IndexError('r_in={} is out of range for {} input modalities'.format(r_in, len(rows)))
        row = list(rows[r_in])
        if not 0 <= o_in < len(row):
            raise IndexError('o_in={} is out of range for {} output modalities'.format(o_in, len(row)))
        return row[o_in].squeeze(0).cpu()

    # ------------------------------------------------------------------ #
    # Data
    # ------------------------------------------------------------------ #
    def setTmpDir(self, tmpdir):
        """Set the directory under which ``PolyMNIST/{train,test}/m<i>`` is looked up."""
        self.tmpdir = tmpdir

    def getDataLoaders(self, batch_size, shuffle=True, device='cuda'):
        """Create train and test ``DataLoader`` objects for all modalities.

        Args:
            batch_size: batch size for both loaders.
            shuffle: passed to both the train and the test loader.
            device: when its string form starts with ``'cuda'`` the loaders
                use two workers and pinned memory.

        Returns:
            ``(train_loader, test_loader)``.

        Raises:
            ValueError: if ``setTmpDir`` has not been called yet.
        """
        if self.tmpdir is None:
            raise ValueError('tmpdir is not set; call setTmpDir() before getDataLoaders()')
        tx = transforms.ToTensor()
        modality_ids = range(len(self.vaes))
        unim_train_datapaths = ['{}/PolyMNIST/train/m{}'.format(self.tmpdir, i) for i in modality_ids]
        unim_test_datapaths = ['{}/PolyMNIST/test/m{}'.format(self.tmpdir, i) for i in modality_ids]
        dataset_PolyMNIST_train = PolyMNISTDataset(unim_train_datapaths, transform=tx)
        dataset_PolyMNIST_test = PolyMNISTDataset(unim_test_datapaths, transform=tx)
        # Accept 'cuda', 'cuda:0' and torch.device('cuda') alike.
        on_cuda = str(device).startswith('cuda')
        kwargs = {'num_workers': 2, 'pin_memory': True} if on_cuda else {}
        train = DataLoader(dataset_PolyMNIST_train, batch_size=batch_size, shuffle=shuffle, **kwargs)
        test = DataLoader(dataset_PolyMNIST_test, batch_size=batch_size, shuffle=shuffle, **kwargs)
        return train, test

    # ------------------------------------------------------------------ #
    # Unconditional generation
    # ------------------------------------------------------------------ #
    def generate(self, runPath, epoch):
        """Save one 10x10 grid of 100 parent-generated samples per modality.

        Files are written as ``<runPath>/gen_samples_<modality>_<epoch>.png``.
        """
        N = 100
        samples_list = super().generate(N)
        for i, samples in enumerate(samples_list):
            save_image(self._as_cpu_batch(samples, N),
                       '{}/gen_samples_{}_{:03d}.png'.format(runPath, i, epoch),
                       nrow=int(sqrt(N)))

    def generate_tb(self):
        """Return a list with one 10x10 sample grid (from ``make_grid``) per modality."""
        N = 100
        samples_list = super().generate(N)
        return [make_grid(self._as_cpu_batch(samples, N), nrow=int(sqrt(N)))
                for samples in samples_list]

    def generate_parametric(self, runPath, epoch, factor_u, factor_w):
        """Like ``generate`` but samples via the parent's ``generate_parametric``.

        ``factor_u`` and ``factor_w`` are forwarded unchanged; files are
        written as ``<runPath>/gen_samples_parametric_<modality>_<epoch>.png``.
        """
        N = 100
        samples_list = super().generate_parametric(N, factor_u=factor_u, factor_w=factor_w)
        for i, samples in enumerate(samples_list):
            save_image(self._as_cpu_batch(samples, N),
                       '{}/gen_samples_parametric_{}_{:03d}.png'.format(runPath, i, epoch),
                       nrow=int(sqrt(N)))

    def generate_for_coherence(self, N):
        """Return ``N`` parametric samples per modality as a list of CPU tensors.

        The parent's ``generate_parametric`` is called with the fixed
        positional factors ``1.0`` and ``0.65`` (presumably factor_u and
        factor_w, matching the positional use in ``generate_pretty_parametric``).
        """
        samples_list = super().generate_parametric(N, 1.0, 0.65)
        return [samples.detach().cpu() for samples in samples_list]

    def generate_for_fid(self, runPath, num_samples, tranche, r_in):
        """Write generated samples of modality ``r_in`` as individual PNGs.

        The directory ``<runPath>/fid_samples_<r_in>`` is deleted if present
        and recreated, then filled with
        ``<tranche>_<image index>_<r_in>.png`` files. Nothing is written when
        ``r_in`` matches no modality index.
        """
        samples_list = super().generate(num_samples)
        for i, samples in enumerate(samples_list):
            if i != r_in:
                continue
            out_dir = os.path.join(runPath, 'fid_samples_{}'.format(i))
            # Start from an empty directory so stale samples never mix in.
            if os.path.exists(out_dir):
                shutil.rmtree(out_dir)
            os.makedirs(out_dir)
            samples = samples.detach().cpu()
            for image, sample in enumerate(samples):
                save_image(sample, '{}/fid_samples_{}/{}_{}_{}.png'.format(runPath, i, tranche, image, i))

    def generate_for_fid_tb(self, runPath, num_samples, tranche):
        """Write generated samples of every modality as individual PNGs.

        Files go to ``<runPath>/random/m<modality>/<tranche>_<image index>.png``;
        the target directories are created when missing.
        """
        samples_list = super().generate(num_samples)
        for i, samples in enumerate(samples_list):
            os.makedirs('{}/random/m{}'.format(runPath, i), exist_ok=True)
            samples = samples.detach().cpu()
            for image, sample in enumerate(samples):
                save_image(sample, '{}/random/m{}/{}_{}.png'.format(runPath, i, tranche, image))

    # ------------------------------------------------------------------ #
    # Conditional generation / reconstruction for one (r_in, o_in) pair
    # ------------------------------------------------------------------ #
    def shift_reconstruct(self, data, runPath, epoch, r_in, o_in):
        """Save shifted reconstructions from modality ``r_in`` to ``o_in``.

        The parent's ``shift_reconstruct`` is called on the first 10 samples
        for ``shift`` = 0..9. The saved image has rows of 10: the ``r_in``
        inputs, the results for shifts 1..9, the result for shift 0, and
        finally the ``o_in`` inputs.

        Raises:
            IndexError: if ``(r_in, o_in)`` does not address a reconstruction.
        """
        recon_tries = []
        for i in range(10):
            recons_mat = super().shift_reconstruct([d[:10] for d in data], shift=i)
            recon_tries.append(self._pick_recon(recons_mat, r_in, o_in))
        _data_r = data[r_in][:10].cpu()
        _data_o = data[o_in][:10].cpu()
        comp = torch.cat([_data_r] + recon_tries[1:] + [recon_tries[0]] + [_data_o])
        save_image(comp, '{}/shift_reconstruction_{}x{}_{:03d}.png'.format(runPath, r_in, o_in, epoch), nrow=10)

    def cross_reconstruct(self, data, runPath, epoch, r_in, o_in):
        """Save cross reconstructions from modality ``r_in`` to ``o_in``.

        The parent's ``cross_reconstruct`` is called on the first 10 samples
        for ``shift`` = 0..9. The image is laid out one sample per row:
        ``r_in`` input, the ten results in shift order, then the ``o_in`` input.

        Raises:
            IndexError: if ``(r_in, o_in)`` does not address a reconstruction.
        """
        recon_tries = []
        for i in range(10):
            recons_mat = super().cross_reconstruct([d[:10] for d in data], shift=i)
            recon_tries.append(self._pick_recon(recons_mat, r_in, o_in))
        _data_r = data[r_in][:10].cpu()
        _data_o = data[o_in][:10].cpu()
        stacked = [_data_r] + recon_tries + [_data_o]
        comp = torch.cat(stacked)
        # Regroup from (variant, sample) to (sample, variant) order so each
        # saved row shows one sample; sizes are derived instead of hard-coded.
        n_variants = len(stacked)
        n_samples = _data_r.size(0)
        img_shape = comp.size()[1:]
        comp1 = (comp.view(n_variants, n_samples, *img_shape)
                 .transpose(0, 1)
                 .reshape(n_variants * n_samples, *img_shape))
        save_image(comp1, '{}/cross_reconstruction_{}x{}_{:03d}.png'.format(runPath, r_in, o_in, epoch),
                   nrow=n_variants)

    def private_conditioned_self_generate(self, data, runPath, epoch, r_in, o_in):
        """Save eight private-conditioned generations for modality ``o_in``.

        The parent's ``private_conditioned_self_generate`` is called eight
        times on the first 10 samples and entry ``o_in`` of each result is
        kept. ``r_in`` is only range-checked and used in the file name. The
        image has rows of 10: the ``o_in`` inputs followed by the eight results.

        Raises:
            IndexError: if ``r_in`` or ``o_in`` is not a valid modality index.
        """
        recon_tries = []
        for _ in range(8):
            recons_list = list(super().private_conditioned_self_generate([d[:10] for d in data]))
            n_mod = len(recons_list)
            if not (0 <= r_in < n_mod and 0 <= o_in < n_mod):
                raise IndexError('(r_in, o_in)=({}, {}) is out of range for {} modalities'.format(r_in, o_in, n_mod))
            recon_tries.append(recons_list[o_in].squeeze(0).cpu())
        _data = data[o_in][:10].cpu()
        comp = torch.cat([_data] + recon_tries)
        save_image(comp, '{}/private_conditioned_{}x{}_{:03d}.png'.format(runPath, r_in, o_in, epoch), nrow=10)

    def cross_generate(self, data, runPath, epoch, r_in, o_in):
        """Save ten ``jointprior`` generations from modality ``r_in`` to ``o_in``.

        The parent's ``reconstruct_options`` is called ten times with
        ``option="jointprior"`` on the first 10 samples. The image has rows of
        10: the ``r_in`` inputs followed by the ten results.

        Raises:
            IndexError: if ``(r_in, o_in)`` does not address a reconstruction.
        """
        N = 10
        recon_tries = []
        for _ in range(10):
            recons_mat = super().reconstruct_options([d[:N] for d in data], option="jointprior")
            recon_tries.append(self._pick_recon(recons_mat, r_in, o_in))
        _data_r = data[r_in][:N].cpu()
        comp = torch.cat([_data_r] + recon_tries)
        save_image(comp, '{}/cross_generation_{}x{}_{:03d}.png'.format(runPath, r_in, o_in, epoch), nrow=N)

    # ------------------------------------------------------------------ #
    # Reconstruction for all modality pairs
    # ------------------------------------------------------------------ #
    def reconstruct(self, data, runPath, epoch):
        """Save inputs and reconstructions of the first 8 samples for every pair.

        For each input modality ``r`` and output modality ``o`` the inputs of
        ``r`` are concatenated with the reconstruction and written to
        ``<runPath>/recon_<r>x<o>_<epoch>.png``.
        """
        n = 8
        recons_mat = super().reconstruct([d[:n] for d in data])
        for r, recons_list in enumerate(recons_mat):
            originals = data[r][:n].cpu()
            for o, recon in enumerate(recons_list):
                comp = torch.cat([originals, recon.squeeze(0).cpu()])
                save_image(comp, '{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch))

    def reconstruct_options(self, data, runPath, epoch, option, factor=1.0):
        """Like ``reconstruct`` but via the parent's ``reconstruct_options``.

        ``option`` and ``factor`` are forwarded unchanged. Note that the
        output file names are the same as those written by ``reconstruct``.
        """
        n = 8
        recons_mat = super().reconstruct_options([d[:n] for d in data], option, factor)
        for r, recons_list in enumerate(recons_mat):
            originals = data[r][:n].cpu()
            for o, recon in enumerate(recons_list):
                comp = torch.cat([originals, recon.squeeze(0).cpu()])
                save_image(comp, '{}/recon_{}x{}_{:03d}.png'.format(runPath, r, o, epoch))

    def reconstruct_for_fid(self, data, runPath, i, n_sam, r_in, factor):
        """Write ``jointprior`` cross-generations from ``r_in`` as single PNGs.

        Uses the first ``n_sam`` samples. For every output modality ``o``
        other than ``r_in`` the images are written to
        ``<runPath>/fid_samples_recon_<r_in>x<o>/<image index>_<i>.png``;
        the directories are created when missing.
        """
        recons_mat = super().reconstruct_options([d[:n_sam] for d in data],
                                                 option='jointprior', factor=factor)
        for r, recons_list in enumerate(recons_mat):
            if r != r_in:
                continue
            for o, recon in enumerate(recons_list):
                if o == r_in:
                    continue
                os.makedirs('{}/fid_samples_recon_{}x{}'.format(runPath, r, o), exist_ok=True)
                recon = recon.squeeze(0).cpu()
                for image, sample in enumerate(recon):
                    save_image(sample,
                               '{}/fid_samples_recon_{}x{}/{}_{}.png'.format(runPath, r, o, image, i))

    def reconstruct_for_fid_tb(self, data, runPath, i):
        """Write ``jointprior`` generations for every modality pair as single PNGs.

        Uses all of ``data`` with ``factor=1.0``. Images are written to
        ``<runPath>/m<r>/m<o>/<image index>_<i>.png``; the directories are
        created when missing.
        """
        recons_mat = super().reconstruct_options(list(data),
                                                 option='jointprior', factor=1.0)
        for r, recons_list in enumerate(recons_mat):
            for o, recon in enumerate(recons_list):
                os.makedirs('{}/m{}/m{}'.format(runPath, r, o), exist_ok=True)
                recon = recon.squeeze(0).cpu()
                for image, sample in enumerate(recon):
                    save_image(sample,
                               '{}/m{}/m{}/{}_{}.png'.format(runPath, r, o, image, i))

    def cross_generate_tb(self, data):
        """Return image grids of ten ``jointprior`` generations for every pair.

        Returns:
            A 10x10 nested list. Entry ``[r][o]`` is a ``make_grid`` image
            (inputs of ``r`` followed by ten generations for ``o``) for every
            pair produced by the parent; remaining entries stay empty lists.
        """
        N = 10
        # The containers are N x N (not modalities x modalities); the shape is
        # kept as in the original so callers indexing the result still work.
        recon_triess = [[[] for _ in range(N)] for _ in range(N)]
        outputss = [[[] for _ in range(N)] for _ in range(N)]
        for _ in range(10):
            recons_mat = super().reconstruct_options([d[:N] for d in data], option="jointprior")
            for r, recons_list in enumerate(recons_mat):
                for o, recon in enumerate(recons_list):
                    recon_triess[r][o].append(recon.squeeze(0).cpu())
        # The last result is only used to enumerate the available (r, o) pairs.
        for r, recons_list in enumerate(recons_mat):
            originals = data[r][:N].cpu()
            for o, _ in enumerate(recons_list):
                outputss[r][o] = make_grid(torch.cat([originals] + recon_triess[r][o]), nrow=N)
        return outputss

    def cross_generate_pretty(self, data, runPath, epoch, samples_to_select, start_mod, factor):
        """Save a hand-picked selection of ``jointprior`` cross-generations.

        Ten generation rounds are run from modality ``start_mod`` on the first
        10 samples; the entry for ``start_mod`` itself is dropped from each
        round.

        Args:
            samples_to_select: iterable of ``(o, j)`` pairs, where ``o`` indexes
                the remaining output modalities (after dropping ``start_mod``)
                and ``j`` the generation round (0..9).
            factor: forwarded to the parent's ``reconstruct_options``.
        """
        N = 10
        M = 10
        recon_tries = [[None] * M for _ in range(len(self.vaes) - 1)]
        _data = data[start_mod][:N].cpu()
        for i in range(M):
            recons_mat = super().reconstruct_options([d[:N] for d in data], option="jointprior", factor=factor)
            # Work on a copy so the structure returned by the parent is not mutated.
            recons_list = list(recons_mat[start_mod])
            del recons_list[start_mod]
            for o, recon in enumerate(recons_list):
                recon_tries[o][i] = recon.squeeze(0).cpu()
        selected_samples = [recon_tries[o][j] for (o, j) in samples_to_select]
        comp = torch.cat([_data] + selected_samples)
        save_image(comp, '{}/cross_generation_pretty_jointprior_{}_{}_{:03d}.png'.format(runPath, factor, start_mod, epoch), nrow=N)

    def generate_pretty_parametric(self, runPath, epoch, factor_u, factor_w):
        """Save five parametric samples per modality in one image (rows of 5)."""
        N = 5
        samples_list = super().generate_parametric(N, factor_u, factor_w)
        samples = torch.cat(samples_list, dim=0).detach().cpu()
        # One group of N samples per modality; count taken from the result
        # rather than hard-coded.
        samples = samples.view(N * len(samples_list), *samples.size()[1:])
        save_image(samples,
                   '{}/gen_samples_pretty_parametric{:03d}_{}_{}.png'.format(runPath, epoch, factor_u, factor_w),
                   nrow=5)

    # ------------------------------------------------------------------ #
    # Analysis
    # ------------------------------------------------------------------ #
    def analyse(self, data, runPath, epoch):
        """Plot latent embeddings and KL statistics from the parent's ``analyse``.

        Writes ``<runPath>/emb_umap_<epoch>.png`` and
        ``<runPath>/kl_distance_<epoch>.png``.
        """
        zemb, zsl, kls_df = super().analyse(data, K=10)
        labels = ['Prior', *[vae.modelName.lower() for vae in self.vaes]]
        plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch))
        plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch))


def resize_img(img, refsize):
    """Pad ``img`` by two pixels on each side of its last two dims, then expand.

    Args:
        img: tensor whose first dimension is the batch dimension.
        refsize: per-sample target shape; the result has shape
            ``(img.size(0), *refsize)``.

    Returns:
        The zero-padded tensor expanded (without copying) to the target shape.
    """
    padded = F.pad(img, (2, 2, 2, 2))
    target_shape = (img.size(0),) + tuple(refsize)
    return padded.expand(*target_shape)
