# cub multi-modal model specification
import matplotlib.pyplot as plt
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from numpy import sqrt, prod
from torch.utils.data import DataLoader
from torchnet.dataset import TensorDataset, ResampleDataset
from torchvision.utils import save_image, make_grid

from utils import Constants
from vis import plot_embeddings, plot_kls_df
from .mmvae_cub import MMVAE
from .vae_cub_image import CUB_Image
from .vae_cub_sent_convolution import CUB_Sentence

from torchvision import transforms

from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
import numpy as np
import textwrap

import os
import shutil

# Constants. Neither name is referenced by the code shown in this file; from
# their names they presumably configure caption length / minimum word
# frequency for the sentence data (TODO confirm before relying on this).
maxSentLen, minOccur = 32, 3


# This is required because there are 10 captions per image.
# Allows easier reuse of the same image for the corresponding set of captions.
def resampler(dataset, idx):
    """Return the image index that caption index ``idx`` belongs to.

    Captions come in groups of ten per image, so caption ``idx`` maps to image
    ``idx // 10``. ``dataset`` is not used; it is only part of the
    ``(dataset, idx)`` call signature.
    """
    image_idx, _caption_slot = divmod(idx, 10)
    return image_idx


class CUB_Image_Sentence(MMVAE):
    """Two-modality MMVAE pairing CUB images with their captions.

    ``self.vaes[0]`` is the image VAE (``CUB_Image``) and ``self.vaes[1]`` the
    caption VAE (``CUB_Sentence``). In the reconstruction / cross-generation
    methods, ``(r, o)`` means "encode modality ``r``, decode modality ``o``",
    with 0 = image and 1 = caption.
    """

    # (r, o) pairs whose reconstructions are rendered as images; the remaining
    # caption -> caption pair is written out as a text file.
    _IMAGE_RECON_PAIRS = ((0, 0), (0, 1), (1, 0))

    def __init__(self, params):
        # TODO check if prior dist is compatible with the one used for text
        super().__init__(dist.Normal, params, CUB_Image, CUB_Sentence)
        grad = {'requires_grad': params.learn_prior}
        latent_dim = params.latent_dim_w + params.latent_dim_u
        # Prior parameters: a fixed zero mean, plus a second tensor that
        # ``pz_params`` passes through softplus (+ eta). The second one is only
        # trainable when ``params.learn_prior`` is true.
        self._pz_params = nn.ParameterList([
            nn.Parameter(torch.zeros(1, latent_dim), requires_grad=False),  # mu
            nn.Parameter(torch.zeros(1, latent_dim), **grad),  # pre-softplus scale
        ])
        # Image likelihood weight: params.llik_scaling == 0 selects the default
        # maxSentLen / prod(image dataSize); any other value is used verbatim.
        if params.llik_scaling == 0:
            self.vaes[0].llik_scaling = self.vaes[1].maxSentLen / prod(self.vaes[0].dataSize)
        else:
            self.vaes[0].llik_scaling = params.llik_scaling
        self.vaes[1].llik_scaling = params.llik_scaling_sent
        self.modelName = 'cubIS_conv'

    @property
    def pz_params(self):
        """Return the prior's ``(loc, scale)``; scale = softplus(raw) + eta keeps it positive."""
        return self._pz_params[0], F.softplus(self._pz_params[1]) + Constants.eta

    def getDataLoaders(self, batch_size, shuffle=True, device='cuda'):
        """Build paired image/caption ``(train_loader, test_loader)``.

        The image dataset is resampled to ten entries per image (see
        ``resampler``) so that it has as many items as the caption dataset and
        item k of each is meant to describe the same bird (assumes captions are
        stored consecutively, ten per image).
        """
        t1, s1 = self.vaes[0].getDataLoaders(batch_size, shuffle, device)
        t2, s2 = self.vaes[1].getDataLoaders(batch_size, shuffle, device)

        kwargs = {'num_workers': 2, 'pin_memory': True} if device == 'cuda' else {}

        def paired(image_loader, caption_loader):
            images = ResampleDataset(image_loader.dataset, resampler,
                                     size=len(image_loader.dataset) * 10)
            return DataLoader(TensorDataset([images, caption_loader.dataset]),
                              batch_size=batch_size, shuffle=shuffle, **kwargs)

        return paired(t1, t2), paired(s1, s2)

    # ------------------------------------------------------------------ #
    # Unconditional generation
    # ------------------------------------------------------------------ #

    def _save_image_caption_figure(self, images, captions, N, path):
        """Save a matplotlib figure of (image, caption) pairs laid out by ``_imshow``."""
        fig = plt.figure(figsize=(15, 12))
        for i, (image, caption) in enumerate(zip(images, captions)):
            fig = self._imshow(image, caption, i, fig, N)
        plt.savefig(path)
        plt.close()

    def _save_generated_samples(self, samples, N, path):
        """Decode generated ``(images, captions)`` (argmax over the last caption dim) and save them."""
        images, captions = [sample.data.cpu() for sample in samples]
        captions = self._sent_process(captions.argmax(-1))
        self._save_image_caption_figure(images, captions, N, path)

    def generate(self, runPath, epoch):
        """Save 8 unconditional image/caption samples to ``gen_samples_<epoch>.png``."""
        N = 8
        samples = super().generate(N)
        self._save_generated_samples(samples, N, '{}/gen_samples_{:03d}.png'.format(runPath, epoch))

    def generate_parametric(self, runPath, epoch, factor_u, factor_w_img, factor_w_sent):
        """Like ``generate`` but via the parent's ``generate_parametric`` with the given factors."""
        N = 8
        samples = super().generate_parametric(N, factor_u, factor_w_img, factor_w_sent)
        self._save_generated_samples(samples, N,
                                     '{}/gen_samples_parametric{:03d}.png'.format(runPath, epoch))

    def generate_sampled(self, runPath, epoch):
        """Like ``generate`` but via the parent's ``generate_sampled``."""
        N = 8
        samples = super().generate_sampled(N)
        self._save_generated_samples(samples, N, '{}/gen_samples_sampled_{:03d}.png'.format(runPath, epoch))

    def _words(self, token_seq):
        """Map a sequence of token ids to vocabulary words via ``vaes[1].i2w``."""
        i2w = self.vaes[1].i2w
        return [i2w[str(token)] for token in token_seq]

    def _render_captions(self, token_seqs, imgsize, font):
        """Render each token-id sequence as an image tensor sized like ``imgsize`` and stack them.

        Words are handed to ``text_to_pil_cub`` as a list because it filters
        '{pad}'/'{eos}' element by element and joins its input with spaces; a
        pre-joined string would be processed character by character.
        """
        return torch.stack([text_to_pil_cub(self._words(seq), imgsize, font) for seq in token_seqs])

    def generate_text(self, runPath, epoch, font):
        """Render the captions of 100 unconditional samples as a 10-column image grid."""
        N = 100
        samples = super().generate(N)
        images, captions = [sample.data.cpu() for sample in samples]
        imgsize = images.size()[-3:]
        token_seqs = self._sent_process(captions.argmax(-1).squeeze())
        sentences_tensor = self._render_captions(token_seqs, imgsize, font)
        save_image(make_grid(sentences_tensor, nrow=10),
                   fp='{}/generation_pretty_text_{}.png'.format(runPath, epoch))

    def generate_for_fid(self, runPath, num_samples, tranche):
        """Save ``num_samples`` generated images, one PNG each, under ``fid_samples_0/``."""
        samples = super().generate(num_samples)
        images, _ = [sample.data.cpu() for sample in samples]
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(os.path.join(runPath, 'fid_samples_{}'.format(0)), exist_ok=True)
        for counter, image in enumerate(images):
            save_image(image,
                       '{}/fid_samples_{}/{}_{}_{}.png'.format(runPath, 0, tranche, counter, 0))

    def generate_nips(self, runPath, epoch):
        """Save 100 generated images as a 10x10 grid and display them."""
        N = 100
        samples = super().generate(N)
        images, _ = [sample.data.cpu() for sample in samples]
        save_image(images,
                   '{}/gen_nips_samples_{:03d}.png'.format(runPath, epoch),
                   nrow=int(sqrt(N)))
        plt.imshow(make_grid(images, nrow=10).permute(1, 2, 0))
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close()

    # ------------------------------------------------------------------ #
    # Reconstruction
    # ------------------------------------------------------------------ #

    def reconstruct_for_fid(self, data, runPath, i, n_sam):
        """Save caption -> image reconstructions (joint prior) of the first ``n_sam`` items for FID."""
        recons_mat = super().reconstruct_options([d[:n_sam] for d in data], option='jointprior')
        out_dir = '{}/fid_samples_recon_{}x{}'.format(runPath, 1, 0)
        # save_image does not create directories.
        os.makedirs(out_dir, exist_ok=True)
        recon = recons_mat[1][0].squeeze(0).cpu()
        for image in range(recon.size(0)):
            save_image(recon[image, :, :, :], '{}/{}_{}.png'.format(out_dir, image, i))

    def reconstruct_for_fid_coherence(self, data, i, save_path, factor):
        """Save caption -> image reconstructions (joint prior, ``factor``) of the whole batch."""
        recons_mat = super().reconstruct_options([d for d in data], option='jointprior', factor=factor)
        os.makedirs(save_path, exist_ok=True)
        recon = recons_mat[1][0].squeeze(0).cpu()
        for image in range(recon.size(0)):
            save_image(recon[image, :, :, :], '{}/{}_{}.png'.format(save_path, i, image))

    def _write_sentence_pairs(self, data, recon, path):
        """Write aligned ``[DATA]`` / ``[RECON]`` caption pairs to a text file."""
        with open(path, "w+") as txt_file:
            for r_sent, d_sent in zip(recon, data):
                txt_file.write('[DATA]  ==> {}\n'.format(' '.join(self._words(d_sent))))
                txt_file.write('[RECON] ==> {}\n\n'.format(' '.join(self._words(r_sent))))

    def _save_recon_pair(self, r, o, data, recon, N, path):
        """Save one ``r -> o`` reconstruction next to its input ``data`` at ``path``.

        image->image: input and reconstruction rows concatenated in one PNG;
        image->caption and caption->image: a matplotlib figure of pairs;
        caption->caption: a text file of sentence pairs.
        """
        if r == 0 and o == 0:
            data = self._img_process(data.squeeze())
            recon = self._img_process(recon.squeeze())
            save_image(torch.cat([data, recon]), path)
        elif r == 0 and o == 1:
            images = self._img_process(data.squeeze())
            captions = self._sent_process(recon.argmax(-1).squeeze())
            self._save_image_caption_figure(images, captions, N, path)
        elif r == 1 and o == 0:
            captions = self._sent_process(data.argmax(-1))
            images = self._img_process(recon.squeeze())
            self._save_image_caption_figure(images, captions, N, path)
        else:
            data = self._sent_process(data.argmax(-1))
            recon = self._sent_process(recon.argmax(-1).squeeze())
            self._write_sentence_pairs(data, recon, path)

    def reconstruct(self, raw_data, runPath, epoch):
        """Write caption -> caption reconstructions of the first 8 items to a text file.

        The image-producing pairs are intentionally not saved by this entry
        point (use ``reconstruct_options`` / ``reconstruct_sampled`` for those).
        """
        N = 8
        recons_mat = super().reconstruct([d for d in raw_data])
        for r, recons_list in enumerate(recons_mat):
            for o, recon in enumerate(recons_list):
                if (r, o) in self._IMAGE_RECON_PAIRS:
                    continue
                self._save_recon_pair(r, o, raw_data[r][:N], recon, N,
                                      '{}/recon_{}x{}_{:03d}.txt'.format(runPath, r, o, epoch))

    def reconstruct_sampled(self, raw_data, runPath, epoch):
        """Save all four ``r -> o`` sampled reconstructions of the first 8 items."""
        N = 8
        recons_mat = super().reconstruct_sampled([d[:N] for d in raw_data])
        for r, recons_list in enumerate(recons_mat):
            for o, recon in enumerate(recons_list):
                if (r, o) == (0, 0):
                    path = '{}/recon_sampled_{}x{}_{:03d}.png'.format(runPath, r, o, epoch)
                else:
                    ext = 'png' if (r, o) in self._IMAGE_RECON_PAIRS else 'txt'
                    path = '{}/recon_{}x{}_sampled_{:03d}.{}'.format(runPath, r, o, epoch, ext)
                self._save_recon_pair(r, o, raw_data[r][:N], recon, N, path)

    def reconstruct_options(self, raw_data, runPath, epoch, option, factor=1.0):
        """Save all four ``r -> o`` reconstructions of the first 8 items for ``option``/``factor``.

        Every output name, including the caption -> caption ``.txt``, carries
        ``option`` and ``factor``; previously the text file did not, so runs
        with different options overwrote each other and ``reconstruct``'s file.
        """
        N = 8
        recons_mat = super().reconstruct_options([d[:N] for d in raw_data], option=option, factor=factor)
        for r, recons_list in enumerate(recons_mat):
            for o, recon in enumerate(recons_list):
                ext = 'png' if (r, o) in self._IMAGE_RECON_PAIRS else 'txt'
                path = '{}/recon_{}x{}_{:03d}_{}_{}.{}'.format(runPath, r, o, epoch, option, factor, ext)
                self._save_recon_pair(r, o, raw_data[r][:N], recon, N, path)

    def reconstruct_options_sampled(self, raw_data, runPath, epoch, option, factor=1.0):
        """Sampled variant of ``reconstruct_options``; outputs are prefixed ``recon_sampled_``."""
        N = 8
        recons_mat = super().reconstruct_options_sampled([d[:N] for d in raw_data],
                                                         option=option, factor=factor)
        for r, recons_list in enumerate(recons_mat):
            for o, recon in enumerate(recons_list):
                ext = 'png' if (r, o) in self._IMAGE_RECON_PAIRS else 'txt'
                path = '{}/recon_sampled_{}x{}_{:03d}_{}_{}.{}'.format(runPath, r, o, epoch, option, factor, ext)
                self._save_recon_pair(r, o, raw_data[r][:N], recon, N, path)

    # ------------------------------------------------------------------ #
    # Cross-modal generation
    # ------------------------------------------------------------------ #

    def _draw_cross_generation_row(self, fig, raw_data, recons_mat, r_in, o_in, j, N):
        """Draw sampling round ``j`` of an ``r_in -> o_in`` cross generation onto a 9x8 grid.

        Row 0 holds the inputs (drawn only when ``j == 0``); round ``j`` goes to
        row ``j + 1``. Pairs other than (0, 1) and (1, 0) draw nothing.
        """
        if r_in == 0 and o_in == 1:
            images = self._img_process(raw_data[0][:N].squeeze())
            captions = self._sent_process(recons_mat[0][1].argmax(-1).squeeze())
            for i, (image, caption) in enumerate(zip(images, captions)):
                if j == 0:
                    fig = self._imshow_data(image, 'image', i, fig, N)
                fig = self._imshow_cg(caption, 'text', i, j + 1, fig, N)
        elif r_in == 1 and o_in == 0:
            captions = self._sent_process(raw_data[1][:N].argmax(-1))
            images = self._img_process(recons_mat[1][0].squeeze())
            for i, (caption, image) in enumerate(zip(captions, images)):
                if j == 0:
                    fig = self._imshow_data(caption, 'text', i, fig, N)
                fig = self._imshow_cg(image, 'image', i, j + 1, fig, N)
        return fig

    def cross_generate(self, raw_data, runPath, epoch, r_in, o_in, factor):
        """Show 8 joint-prior cross generations ``r_in -> o_in`` for the first 8 items (not saved)."""
        N = 8
        fig = plt.figure(figsize=(15, 12))
        for j in range(8):
            # Only the first N items are plotted, so only those are reconstructed.
            recons_mat = super().reconstruct_options([d[:N] for d in raw_data], 'jointprior',
                                                     factor=factor)
            fig = self._draw_cross_generation_row(fig, raw_data, recons_mat, r_in, o_in, j, N)
        plt.show()
        plt.close()

    def cross_generate_sampled(self, raw_data, runPath, epoch, r_in, o_in):
        """Save and show 8 sampled joint-prior cross generations ``r_in -> o_in``."""
        N = 8
        fig = plt.figure(figsize=(15, 12))
        for j in range(8):
            recons_mat = super().reconstruct_options_sampled([d[:N] for d in raw_data],
                                                             'jointprior', 1.0)
            fig = self._draw_cross_generation_row(fig, raw_data, recons_mat, r_in, o_in, j, N)
        plt.savefig('{}/cross_generation_sampled_{}x{}_{:03d}.png'.format(runPath, r_in, o_in, epoch))
        plt.show()
        plt.close()

    def cross_generate_images_for_coherence(self, raw_data, factor):
        """Return caption -> image joint-prior reconstructions of ``raw_data`` on the CPU."""
        recons_mat = super().reconstruct_options(raw_data, 'jointprior', factor)
        return self._img_process(recons_mat[1][0].squeeze())

    def cross_generate_pretty(self, raw_data, runPath, epoch, r_in, o_in, font, pi, factor):
        """Save a grid: the 8 inputs, then 10 joint-prior cross generations ``r_in -> o_in``.

        Captions are rendered to images with ``text_to_pil_cub`` so every row
        is a tensor image. Raises ValueError unless ``(r_in, o_in)`` is
        ``(0, 1)`` or ``(1, 0)``.
        """
        imgsize = raw_data[0].size()[-3:]
        N = 8
        M = 10
        if r_in == 0 and o_in == 1:
            comp = [self._img_process(raw_data[r_in][:N].squeeze())]
            for _ in range(M):
                recons_mat = super().reconstruct_options([d[:N] for d in raw_data], 'jointprior', factor)
                token_seqs = self._sent_process(recons_mat[r_in][o_in].argmax(-1).squeeze())
                comp.append(self._render_captions(token_seqs, imgsize, font))
            path = '{}/cross_generation_pretty_{}x{}_{:03d}_{}.png'.format(runPath, r_in, o_in, epoch, pi)
        elif r_in == 1 and o_in == 0:
            token_seqs = self._sent_process(raw_data[r_in][:N].argmax(-1))
            comp = [self._render_captions(token_seqs, imgsize, font)]
            for _ in range(M):
                recons_mat = super().reconstruct_options([d[:N] for d in raw_data], 'jointprior', factor)
                comp.append(self._img_process(recons_mat[r_in][o_in].squeeze()))
            path = '{}/cross_generation_pretty_jointprior_{}_{}x{}_{:03d}_{}.png'.format(
                runPath, factor, r_in, o_in, epoch, pi)
        else:
            raise ValueError("Incorrect values for r_in and o_in, must be either r_in = 0 and o_in = 1 or vice versa.")
        # Inputs row first, then the generated rows in reverse generation order.
        comp = comp[:1] + comp[:0:-1]
        save_image(make_grid(torch.cat(comp, dim=0), nrow=N), fp=path)

    def cross_generate_pretty_select(self, data, runPath, epoch, samples_to_select, start_mod, factor):
        """Save the inputs of ``start_mod`` followed by hand-picked cross generations.

        ``samples_to_select`` is an iterable of ``(o, j)``: target index ``o``
        among the *other* modalities and sampling round ``j`` (0..7).
        """
        N = 8
        M = 8
        # recon_tries[o][j]: round-j reconstruction into the o-th other modality
        # (pre-filled with placeholders that are overwritten below).
        recon_tries = [list(range(M)) for _ in range(len(self.vaes) - 1)]
        _data = data[start_mod][:N].cpu()
        for i in range(M):
            recons_mat = super().reconstruct_options([d[:N] for d in data], option="jointprior", factor=factor)
            others = [rec for o, rec in enumerate(recons_mat[start_mod]) if o != start_mod]
            for o, recon in enumerate(others):
                recon_tries[o][i] = recon.squeeze(0).cpu()
        selected_samples = [recon_tries[o][j] for (o, j) in samples_to_select]
        comp = torch.cat([_data] + selected_samples)
        save_image(comp, '{}/cross_generation_pretty_jointprior_{}_{}_{:03d}.png'.format(
            runPath, factor, start_mod, epoch), nrow=N)

    # ------------------------------------------------------------------ #
    # Analysis and plotting helpers
    # ------------------------------------------------------------------ #

    def analyse(self, data, runPath, epoch):
        """Plot latent embeddings and KL distances (K=10) for ``data`` and print the KL table."""
        zemb, zsl, kls_df = super().analyse(data, K=10)
        labels = ['Prior', *[vae.modelName.lower() for vae in self.vaes]]
        plot_embeddings(zemb, zsl, labels, '{}/emb_umap_{:03d}.png'.format(runPath, epoch))
        plot_kls_df(kls_df, '{}/kl_distance_{:03d}.png'.format(runPath, epoch))
        print(kls_df)

    def _sent_process(self, sentences):
        """Apply the caption VAE's ``fn_2i`` then ``fn_trun`` to each sentence."""
        return [self.vaes[1].fn_trun(self.vaes[1].fn_2i(s)) for s in sentences]

    def _img_process(self, images):
        """Detach ``images`` and move them to the CPU."""
        return images.data.cpu()

    def _caption_text(self, caption):
        """Join caption tokens into words, with a line break after every fifth word."""
        i2w = self.vaes[1].i2w
        return ' '.join(i2w[str(token)] + '\n' if (n + 1) % 5 == 0 else i2w[str(token)]
                        for n, token in enumerate(caption))

    def _draw_caption(self, ax, caption, fontsize):
        """Hide ``ax`` and write ``caption`` at its centre."""
        pos = ax.get_position()
        ax.axis('off')
        # NOTE(review): both `ha` and `horizontalalignment` are passed with
        # conflicting values; which one wins depends on matplotlib's handling
        # of aliases. Kept as-is to preserve the current rendering.
        plt.text(
            x=0.5 * (pos.x0 + pos.x1),
            y=0.5 * (pos.y0 + pos.y1),
            ha='left',
            s=self._caption_text(caption),
            fontsize=fontsize,
            verticalalignment='center',
            horizontalalignment='center'
        )

    def _draw_image(self, ax, image):
        """Hide ``ax``'s axes and show a ``(C, H, W)`` tensor image on it."""
        ax.axis('off')
        plt.imshow(image.numpy().transpose((1, 2, 0)))

    def _imshow(self, image, caption, i, fig, N):
        """Draw pair ``i`` as two cells (image, caption) of an ``(N // 2) x 4`` grid."""
        self._draw_image(fig.add_subplot(N // 2, 4, i * 2 + 1), image)
        self._draw_caption(fig.add_subplot(N // 2, 4, i * 2 + 2), caption, fontsize=6)
        return fig

    def _imshow_data(self, data, mod, i, fig, N):
        """Draw input ``i`` (``mod`` is 'text' or image) in row 0 of the 9x8 grid."""
        return self._imshow_cg(data, mod, i, 0, fig, N)

    def _imshow_cg(self, data, mod, i, j, fig, N):
        """Draw item ``i`` (``mod`` is 'text' or image) in row ``j`` of the 9x8 grid."""
        ax = fig.add_subplot(9, 8, j * 8 + i + 1)
        if mod == 'text':
            self._draw_caption(ax, data, fontsize=5)
        else:
            self._draw_image(ax, data)
        return fig


def text_to_pil_cub(t, imgsize, font, w=192, h=300, linewidth: int = 15, max_nbr_lines: int = 10, text_cleanup=True):
    """Render a caption onto a white canvas and return it as an image tensor.

    Args:
        t: the caption, either a sequence of word strings or a single
            whitespace-separated string (which is split into words).
        imgsize: target ``(C, H, W)``. ``C == 3`` keeps colour; any other
            channel count yields a single-channel (grayscale) tensor.
        font: font object passed to ``ImageDraw.multiline_text``.
        w, h: dimensions of the blank canvas tensor, created as ``(C, w, h)``
            (``w`` rows by ``h`` columns) before resizing to ``imgsize``.
        linewidth: unused; wrapping uses a fixed width of 33 characters.
            Kept for backward compatibility.
        max_nbr_lines: maximum number of wrapped lines drawn.
        text_cleanup: drop '{pad}' and '{eos}' tokens before rendering.

    Returns:
        Tensor of shape ``(3, H, W)`` when ``imgsize[0] == 3``, else ``(1, H, W)``.
    """
    if isinstance(t, str):
        # A pre-joined sentence must be split; otherwise the cleanup filter and
        # the join below would operate on single characters.
        words = t.split()
    else:
        words = list(t)
    if text_cleanup:
        words = [word for word in words if word not in ('{pad}', '{eos}')]
    text_sample = ' '.join(words).translate({ord('*'): None})  # strip '*' characters
    lines = '\n'.join(textwrap.wrap(text_sample, width=33)[:max_nbr_lines])

    blank_img = torch.ones([imgsize[0], w, h])
    pil_img = transforms.ToPILImage()(blank_img.cpu()).convert("RGB")
    ImageDraw.Draw(pil_img).multiline_text((10, 10), lines, font=font, fill=(0, 0, 0))

    # PIL sizes are (width, height) whereas imgsize is (C, H, W). LANCZOS is the
    # filter formerly aliased as ANTIALIAS, which Pillow 10 removed.
    resized = pil_img.resize((imgsize[2], imgsize[1]), Image.LANCZOS)
    if imgsize[0] != 3:
        resized = resized.convert('L')
    return transforms.ToTensor()(resized)