'''Contains visualization tools'''
from sklearn.decomposition import PCA
import torchvision
import math
import pdb
import torch
import utils
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import kde
import seaborn as sns
from moviepy.editor import *
from argparse import Namespace
import os
from matplotlib.ticker import MaxNLocator
from matplotlib.gridspec import GridSpec
import matplotlib.cm as cm
import seaborn as sns
# Torch device used by the helpers below when creating or moving tensors:
# CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

'''Plot losses'''


def plot_MMD_loglik(ls_all_test, args, end_of_block_iter=False):
    '''
        Plot the test MMD curve and the test log-likelihood curve side by side.

        ls_all_test: nested list/array of test metrics whose last axis holds
            the MMD value at index 0 and the log-likelihood at index 1.
        args: namespace providing `niters`, `num_blocks` and `p` (used only
            in the figure title).
        end_of_block_iter: if True, we plot test metrics over all iterations for all blocks so far. Default is False because we plot test metric at EACH iteration.

        Returns the created matplotlib figure.
    '''
    fig, (ax_mmd, ax_loglik) = plt.subplots(1, 2, figsize=(8, 4))
    metrics = np.array(ls_all_test)
    # With end_of_block_iter the array carries an extra leading axis; either
    # way the selected column is flattened into one curve below.
    if end_of_block_iter:
        mmd_vals = metrics[:, :, 0]
        loglik_vals = metrics[:, :, 1]
    else:
        mmd_vals = metrics[:, 0]
        loglik_vals = metrics[:, 1]
    panels = (
        (ax_mmd, mmd_vals, r'Median Trick MMD of $X$ and $\hat{X}$'),
        (ax_loglik, loglik_vals, 'Test log-likelihood'),
    )
    for panel, values, title in panels:
        panel.plot(values.flatten())
        panel.set_title(title)
    fig.suptitle(
        f'Test metrics over {args.niters} training epochs per block over {args.num_blocks} blocks_phase{args.p}', y=1)
    fig.tight_layout()
    return fig


def plot_losses(ls_all, args):
    '''
        Plot the training objective and its three components on a 2x2 grid.

        ls_all: nested list/array indexed as ls_all[a][b][k] (3-D after
            np.array); components k = 0, 1, 2 are the individual terms and
            k = -1 is plotted as their sum. Presumably a = block and
            b = iteration (see the figure title); each curve is flattened
            over both leading axes.
        args: namespace providing `niters`, `num_blocks` and `p` for the title.

        Returns the created matplotlib figure.
    '''
    fig, ax = plt.subplots(2, 2, figsize=(8, 8))
    losses = np.array(ls_all)
    # (grid position, component index, panel title)
    layout = [
        ((0, 0), -1, 'Sum of three'),
        ((0, 1), 0, r'$-\int_0^1 \nabla \cdot f_b(X_b(s),s)ds$'),
        ((1, 0), 1, r'$V(X_b+\int_0^1 f_b(X_b(s),s)ds)/2$'),
        ((1, 1), 2, r'$W_2^2(f_b)/(2T_b)$'),
    ]
    for position, component, title in layout:
        panel = ax[position]
        panel.plot(losses[:, :, component].flatten())
        panel.set_title(title)
    fig.suptitle(
        f'Training metrics over {args.niters} training epochs per block over {args.num_blocks} blocks_phase{args.p}', y=1)
    fig.tight_layout()
    return fig


def plot_losses_over_phases(args, full_container):
    '''
        Plot, per block, the final-iteration divergence loss, V loss and W2
        cost for every training phase (one line per phase in each panel).

        args: namespace with `T_dict` (phase index -> sequence of per-block
            T_b values) and `num_phase`.
        full_container: object whose `ls_all_dict[p]` holds one loss history
            per block for phase p; `loss[-1]` is the last iteration's entry,
            with the divergence term at index 0, the V term at index 1 and
            W2^2/(2T) at index 2.

        Returns the created matplotlib figure.
    '''
    print('T_b at each phase')
    print(args.T_dict)
    fig = plt.figure(figsize=(10, 9))
    gs = GridSpec(4, 1, figure=fig)
    # W2 panel spans the bottom two rows; the two loss panels share its x-axis.
    ax2 = fig.add_subplot(gs[2:])
    ax0 = fig.add_subplot(gs[0], sharex=ax2)
    ax1 = fig.add_subplot(gs[1], sharex=ax2)
    for p in range(args.num_phase):
        block_losses = full_container.ls_all_dict[p]
        blocks = range(1, len(args.T_dict[p]) + 1)
        label = f'Phase {p+1}'
        ax0.plot(blocks, [loss[-1][0] for loss in block_losses], label=label)
        ax1.plot(blocks, [loss[-1][1] for loss in block_losses], label=label)
        lossW2old = np.array([loss[-1][2] for loss in block_losses])
        T_old_ls = np.array(args.T_dict[p])
        # The stored term is W2^2/(2T); invert it to recover W2 itself.
        W2_allblocks = np.sqrt(2*T_old_ls*lossW2old)
        ax2.plot(blocks, W2_allblocks, label=label)
    # Decoration does not depend on the phase, so apply it once.
    ax0.set_title(
        r'Divergence: $-\int_{t_k}^{t_{k+1}} \nabla \cdot f_k(X_k(s),s)ds$', fontsize=20, pad=10)
    ax1.set_title(
        r'$V$ loss: $||X_k+\int_{t_k}^{t_{k+1}} f_k(X_k(s),s)ds)||_2^2/2$', fontsize=20, pad=10)
    ax2.set_title(
        r'$W_2(f_k):=|| \int_{t_k}^{t_{k+1}} f_k(X_k(s),s)ds ||_2$', fontsize=20)
    ax2.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5, -0.05),
               fancybox=True, shadow=True, ncol=3)
    ax2.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax2.set_xlabel('Block', fontsize=20)
    for a in [ax0, ax1, ax2]:
        a.tick_params(axis='both', which='major', labelsize=15)
        a.tick_params(axis='both', which='minor', labelsize=15)
    fig.tight_layout()
    return fig


'''Plot forward and backward'''


def get_PCA_plot(self, X_test, X_test_hat):
    '''
        Project X_test and X_test_hat onto the top-2 principal components of
        X_test and draw them side by side with `visualize_after_PCA`.

        The PCA basis and the projected X_test are cached on `self`
        (`self.V_two_dim`, `self.X_test_PCA`) and reused on later calls;
        set `self.X_test_PCA = None` to force a refit.
        Reads `self.args.word` (dataset name) and `self.use_kde`.

        X_test, X_test_hat: torch tensors; every non-batch dimension is
            flattened into features. They may have different batch sizes.
    '''
    if self.X_test_PCA is None:
        pca = PCA(n_components=2)
        # reshape (not view) also works for non-contiguous tensors
        X_test_tmp = X_test.reshape(X_test.shape[0], -1).cpu().detach().numpy()
        pca.fit(X_test_tmp)
        explained_var = pca.explained_variance_ratio_.sum()
        print(f'Top 2 components explain {explained_var*100:.2f}% variance')
        self.V_two_dim = pca.components_.T
        self.X_test_PCA = X_test_tmp.dot(self.V_two_dim)
    # Project to 2D by PCA, using X_test_hat's own batch size
    X_test_hat_PCA = X_test_hat.reshape(
        X_test_hat.shape[0], -1).cpu().detach().numpy().dot(self.V_two_dim)
    visualize_after_PCA(self.X_test_PCA, X_test_hat_PCA,
                        self.args.word, self.use_kde)


def visualize_after_PCA(X_test_PCA, X_test_hat_PCA, dataname, use_kde=False):
    '''
        Scatter the 2-D PCA projections of true and generated data side by side
        (shared axes), save as f'{dataname}_PCA_2D_compare.png' and show.

        X_test_PCA, X_test_hat_PCA: arrays whose first two columns are plotted.
        dataname: dataset name; for 'gas', points with first coordinate >= 10
            are dropped from X_test_PCA before plotting.
        use_kde: colour points by kernel density estimate (slow).
    '''
    fig, (ax_true, ax_gen) = plt.subplots(
        1, 2, figsize=(10, 3), sharex=True, sharey=True)
    if dataname == 'gas':
        # Remove strange outlier
        keep = X_test_PCA[:, 0] < 10
        X_test_PCA = X_test_PCA[keep]
    for panel, points in ((ax_true, X_test_PCA), (ax_gen, X_test_hat_PCA)):
        KDE(points, panel, use_kde)
    title_size = 22
    ax_true.set_title(r'Projected $X$', fontsize=title_size)
    ax_gen.set_title(r'Projected $F^{-1}(Z)$', fontsize=title_size)
    fig.savefig(f'{dataname}_PCA_2D_compare.png', dpi=150,
                bbox_inches='tight', pad_inches=0)
    plt.show()
    plt.close()


def KDE(data, ax, use_kde=False, cmap='viridis'):
    '''
        Scatter the first two columns of `data` on `ax` and hide both axes.

        data: array indexable as data[:, 0] and data[:, 1].
        ax: matplotlib Axes to draw on.
        use_kde: if True, colour each point by a Gaussian kernel density
            estimate evaluated at that point. This is slow for many points,
            so it is off by default.
        cmap: colormap used when use_kde is True.
    '''
    # Public SciPy entry point; the `scipy.stats.kde` namespace is deprecated.
    from scipy.stats import gaussian_kde

    x, y = data[:, 0], data[:, 1]
    if use_kde:
        xy = np.vstack([x, y])
        density = gaussian_kde(xy)(xy)
        ax.scatter(x, y, c=density, s=2, cmap=cmap)
    else:
        ax.scatter(x, y, s=2)
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)


def slice_data_plt(X_test, X_test_hat, dim1, dim2):
    '''
        Show the (dim1, dim2) slice of true vs. estimated samples as
        KDE-coloured scatter plots ('inferno') on a black background.

        X_test, X_test_hat: torch tensors, moved to CPU numpy here.
        dim1, dim2: column indices of the slice.
    '''
    true_np = X_test.cpu().detach().numpy()
    est_np = X_test_hat.cpu().detach().numpy()
    fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True, sharey=True)
    slice_cols = [dim1, dim2]
    for panel, arr in zip(axes, (true_np, est_np)):
        KDE(arr[:, slice_cols], panel, use_kde=True, cmap='inferno')
    for panel in axes:
        panel.set_facecolor('black')
    plt.show()
    plt.close()


def vision_data_visualize(self, args, type='interpolate', num_rand_fig=16, rand_seeds=None, rows=None, cols=None):
    '''
        Visualize a trained flow on image data (args.word is 'MNIST' or 'CIFAR10').

        Type controls how we visualize performance, which supports:
        1. 'interpolate': for 4 random labels, take two test images with that
           label, map them forward through the flow (after the encoder when
           'FC' is in args.netname), linearly interpolate in latent space with
           alpha in (0.8, 0.5, 0.2), map back and decode.
        2. 'random_sample': generate from random N(0,I_d) and then flow back.
        CIFAR10 always uses 'random_sample'.

        num_rand_fig: number of generated samples for 'random_sample'.
        rand_seeds: optional per-sample torch seeds (for reproducible grids);
            drawn at random when None.
        rows, cols: MNIST grid layout; each defaults (when None) to a
            near-square layout of num_rand_fig images. CIFAR10 always uses
            8 columns.

        Returns (fig, rand_seeds); rand_seeds is returned unchanged for
        'interpolate'. Raises ValueError for an unsupported `type`.
    '''
    FlowNet = self.FlowNet
    # DataLoader iterators implement __next__, not a `.next()` method.
    test_batch = next(iter(self.test_loader))
    d = args.Xdim
    dataname = args.word  # 'MNIST' or 'CIFAR10'
    if dataname == 'CIFAR10':
        type = 'random_sample'
    if type == 'interpolate':
        num_digits = 4
        # Plot 2 original images AT THE END of each row and interpolate three in the middle
        fig, ax = plt.subplots(num_digits, 5, figsize=(5*2, num_digits*2))
        ALL_digits = np.random.choice(range(10), 10, replace=False)
        imgs, labels = test_batch
        if 'FC' in args.netname:
            imgs = imgs.view(imgs.shape[0], 28, 28)  # For AutoEncoder
        alphas = [0.8, 0.5, 0.2]
        map_type = 'gray'
        fsize = 16
        for i, digit in enumerate(ALL_digits[:num_digits]):
            # Img -> Enc(Img) -> Z_tilde = Interpolate F(Enc(Img))
            two_img = imgs[labels == digit][:2]
            if 'FC' in args.netname:
                two_img_enc = self.through_AE(two_img.view(
                    2, -1).to(device), args, 0, encode=True)
            else:
                two_img_enc = two_img.clone().to(device)
            z_est_ls = utils.map_for_or_back(input=two_img_enc,
                                             num_blocks=len(FlowNet), FlowNet=FlowNet,
                                             args=args, reverse=False)
            zhat_enc = z_est_ls[-1]
            interpolated_img_enc = torch.stack(
                [alpha*zhat_enc[0]+(1-alpha)*zhat_enc[1] for alpha in alphas])
            # F^{-1}(Z_tilde) -> Dec(F^{-1}(Z_tilde))
            x_est_ls = utils.map_for_or_back(input=interpolated_img_enc,
                                             num_blocks=len(FlowNet), FlowNet=FlowNet,
                                             args=args, reverse=True)
            xhat_enc = x_est_ls[-1]
            img_orig_gen = self.through_AE(xhat_enc, args, 1, encode=False)
            ax[i, 0].imshow(two_img[0].reshape(28, 28),
                            cmap=map_type)  # True image 1
            ax[i, -1].imshow(two_img[1].reshape(28, 28),
                             cmap=map_type)  # True image 2
            if i == 0:
                ax[i, 0].set_title(r'$X_1$', fontsize=fsize)
                ax[i, -1].set_title(r'$X_2$', fontsize=fsize)
            for j, alpha in enumerate(alphas):
                ax[i, j+1].imshow(img_orig_gen[j].cpu().reshape(28,
                                  28), cmap=map_type)
                if i == 0:
                    ax[i, j+1].set_title(r'$\alpha=$'
                                         + f'{alpha}', fontsize=fsize)
        # Figure-wide formatting, applied once after all rows are drawn
        for a in ax.ravel():
            a.get_yaxis().set_visible(False)
            a.get_xaxis().set_visible(False)
            a.set_aspect('equal')
        fig.suptitle(
            r'$D(F^{-1}[\alpha F(E(X_1))+(1-\alpha)F(E(X_2))]), X_1,X_2 \sim D_{X|Y}$', y=0.96, fontsize=18)
    elif type == 'random_sample':
        # Randomly sample from N(0,I_d), d=low dim
        if rand_seeds is None:
            rand_seeds = np.random.choice(100000, num_rand_fig, replace=False)
        # Respect a caller-provided layout; only fill in what is missing.
        if rows is None:
            rows = int(np.sqrt(num_rand_fig))
        if cols is None:
            cols = int(num_rand_fig/rows)
        Z_ls = []
        for seed in rand_seeds:
            torch.manual_seed(seed)
            if 'FC' in args.netname:
                # If through AutoEncoder
                Z_ls.append(torch.randn(d))
            else:
                Z_ls.append(torch.randn(d, 28, 28))
        Z = torch.stack(Z_ls).to(device)
        Z_back_ls = utils.map_for_or_back(input=Z,
                                          num_blocks=len(FlowNet), FlowNet=FlowNet,
                                          args=args, reverse=True)
        Z_back = Z_back_ls[-1]
        if 'FC' in args.netname:
            X_gen = self.through_AE(Z_back, args, 0, encode=False)
        else:
            X_gen = Z_back
        if dataname == 'MNIST':
            fig = display_mult_mnist_images(
                X_gen.cpu().reshape(num_rand_fig, 28, 28), rows, cols)
        else:
            cols = 8
            rows = int(num_rand_fig/cols)
            fig = display_mult_cifar10_images(X_gen, rows, cols)
    else:
        raise ValueError('Other visualization not supported yet')
    plt.show()
    plt.close()
    return fig, rand_seeds


def for_and_back(self, args, base_dist_ls=None):
    '''
        Push samples through the trained flow in both directions and plot them.

        Backward: base samples z are mapped to xhat. For conditional generation
        (args.cond_gen), z is drawn per label from `base_dist_ls`; otherwise
        z ~ N(0, I) with args.Xdim columns.
        Forward: the test inputs xinput are mapped to zhat.
        Only args.num_blocks - args.offset blocks are passed to
        utils.map_for_or_back.

        For conditional generation on 'solar', per-label generation figures,
        MMD metrics and PCA plots are produced for the three most frequent
        label rows. Otherwise X, X_hat, Z, Z_hat are drawn as four scatter
        panels (histograms when args.Xdim == 1) and saved to
        f'{args.word}_true_vs_gen.png'.

        base_dist_ls: indexable by integer label; each entry must provide
            `.rsample(sample_shape=...)`. Only used when args.cond_gen.
        Returns (x_est_ls, z_est_ls): the reverse and forward outputs of
            utils.map_for_or_back.
    '''
    if base_dist_ls is None:
        base_dist_ls = []
    FlowNet = self.FlowNet
    if args.cond_gen:
        if args.word == 'two_moon':
            xinput, yraw = utils.inf_train_gen_cond_gen(args, train=False)
        else:
            xinput, yraw = self.xraw, self.yraw
        xinput, yraw = xinput.to(device), yraw.to(device)
        yraw = yraw.flatten(start_dim=0, end_dim=1)
        unique_Y = torch.unique(yraw)
        colors = cm.CMRmap(np.linspace(0, 0.5, len(unique_Y)))
    else:
        xinput = utils.inf_train_gen(args, train=False)
        if args.color_X:
            xinput, yraw = xinput
            xinput = xinput.to(device)
            unique_Y = torch.unique(yraw.to(device))
            colors = cm.viridis(np.linspace(0.5, 1, len(unique_Y)))
        else:
            xinput = xinput.to(device)
            yraw = torch.zeros(xinput.shape[0]).to(device)
            colors = cm.binary(np.ones(1))
    # Get what is needed for plotting
    plot_dict = {}
    if args.cond_gen:
        plot_dict[0] = xinput.flatten(
            start_dim=0, end_dim=1).cpu().detach().numpy()
    else:
        plot_dict[0] = xinput.cpu().detach().numpy()
    offset = args.offset  # Because last several blocks are not good
    num_blocks = args.num_blocks - offset
    if args.cond_gen:
        N, V, C = xinput.shape
        NV = int(N*V)
        z = torch.zeros(NV, C).to(device)
        for i in unique_Y:
            # NOTE, this replaces "get_reindex" in iGNN
            base_disti = base_dist_ls[int(i.cpu().detach().numpy())]
            idx_i = yraw == i
            z[idx_i] = base_disti.rsample(
                sample_shape=(int(idx_i.sum().item()),))
        z = z.reshape(N, V, C)
    else:
        N, V, C = xinput.shape[0], args.Xdim, 1
        z = torch.randn(N, V).to(device)
    x_est_ls = utils.map_for_or_back(input=z,
                                     num_blocks=num_blocks, FlowNet=FlowNet,
                                     args=args, reverse=True, refinement=self.refinement)
    xhat = x_est_ls[-1]
    z_est_ls = utils.map_for_or_back(input=xinput,
                                     num_blocks=num_blocks, FlowNet=FlowNet,
                                     args=args, reverse=False, refinement=self.refinement)
    zhat = z_est_ls[-1]
    if args.cond_gen and args.word == 'solar':
        # Here, we conditionally visualize the results based on membership of Y
        # Taken from IGNN "viz_generation"
        num_viz = 3
        Y = self.yraw
        Unique_Y, counts_Y = torch.unique(Y, return_counts=True, dim=0)
        counts_Y, idx = torch.sort(counts_Y, descending=True)
        Unique_Y = Unique_Y[idx]
        for pp, Y_row in enumerate(Unique_Y[:num_viz]):
            which_rows = (Y == Y_row).all(dim=1)
            X = xinput[which_rows]
            X_pred = xhat[which_rows]
            Y_plt = Y[which_rows]
            H_full = z[which_rows]
            H_pred = zhat[which_rows]
            plt_args = Namespace(
                V=X.shape[1], C=args.Xdim, final_viz=True, plot_sub=True)
            plt_generation_fig(plt_args, X, X_pred, Y_plt, H_full, H_pred)
            plt_args.fig_gen.savefig(
                f'Top_{pp+1}_{args.word}_generation.png', dpi=150, bbox_inches='tight', pad_inches=0)
            # Get MMD
            utils.get_MMD_dict(self, X, X_pred)
            self.alpha_MMD = None  # o/w the median by class is incorrect
            MMD_metric_dict = self.args.MMD_test
            print('###############Test Metrics###############')
            for alpha, MMD_metric in MMD_metric_dict.items():
                print(
                    f'--Test MMD loss at alpha={alpha}: {MMD_metric.item():.2e}')
            # Check PCA (temporarily renames the dataset for the output file)
            self.args.word = f'solar_top{pp}'
            get_PCA_plot(self, X, X_pred)
            self.X_test_PCA = None  # o/w the PCA by class is incorrect
            self.args.word = 'solar'
    else:
        # All 2D visualization except solar conditional generation
        if args.cond_gen:
            plot_dict[1] = xhat.flatten(
                start_dim=0, end_dim=1).cpu().detach().numpy()
            plot_dict[2] = z.flatten(
                start_dim=0, end_dim=1).cpu().detach().numpy()
            plot_dict[3] = zhat.flatten(
                start_dim=0, end_dim=1).cpu().detach().numpy()
        else:
            plot_dict[1] = xhat.cpu().detach().numpy()
            plot_dict[2] = z.cpu().detach().numpy()
            plot_dict[3] = zhat.cpu().detach().numpy()
        if args.cond_gen:
            title_dict = {0: r'$X|Y$',
                          1: r'$\hat{X}|Y=F^{-1}(H|Y)$', 2: r'$H|Y$', 3: r'$\hat{H}|Y=F(X|Y)$'}
        else:
            title_dict = {
                0: r'$X$', 1: r'$\hat{X}=F^{-1}(Z)$', 2: r'$Z$', 3: r'$\hat{Z}=F(X)$'}
        if C > 2:
            V = int(C/2)
            C = 2
            for key in plot_dict.keys():
                plot_dict[key] = plot_dict[key].reshape(N, V, C)
        # One RGBA row per point, looked up from its (integer-truncated) label.
        # Computed once and vectorised instead of per panel / per element.
        label_idx = yraw.reshape(-1).long().cpu().numpy()
        point_colors = colors[label_idx]
        # Start plotting
        fig, ax = plt.subplots(1, 4, figsize=(14, 4))
        # X / X_hat and Z / Z_hat share limits pairwise. Axes.sharex/sharey
        # replaces get_shared_*_axes().join, deprecated since Matplotlib 3.6.
        ax[1].sharex(ax[0])
        ax[1].sharey(ax[0])
        ax[3].sharex(ax[2])
        ax[3].sharey(ax[2])
        for i in range(4):
            # Plot structure: X, \hat{X}, Z, \hat{Z}
            if i in (1, 2) and not args.cond_gen:
                color_plt = 'black'
            else:
                color_plt = point_colors
            plt_data = plot_dict[i]
            if args.Xdim == 1:
                ax[i].hist(plt_data, bins=100)
            else:
                ax[i].scatter(plt_data[:, 0], plt_data[:, 1],
                              s=2, color=color_plt)
                if V > 1 and C > 1:
                    # Connect dots
                    lwidth = 0.075
                    if 'solar' in args.savename or 'traffic' in args.savename:
                        lwidth = 0.025
                    ax[i].plot(plt_data[:, 0], plt_data[:, 1],
                               linestyle='dashed', linewidth=lwidth)
            ax[i].set_title(title_dict[i], fontsize=24)
        for a in ax:
            a.axes.get_yaxis().set_visible(False)
            a.axes.get_xaxis().set_visible(False)
        fig.tight_layout()
        fig.savefig(f'{args.word}_true_vs_gen.png', dpi=150,
                    bbox_inches='tight', pad_inches=0)
        plt.show()
        plt.close()
    return x_est_ls, z_est_ls


def quick_scatter(self, X_test, X_test_hat):
    '''
        Quick side-by-side plot of true data X_test and generated X_test_hat.

        One-column data is drawn as seaborn KDE curves; otherwise columns 0 and
        1 are scattered in black. Both panels share limits. The figure is saved
        to f'{self.args.word}_true_vs_gen.png' and shown.

        X_test, X_test_hat: torch tensors (converted to numpy here).
    '''
    d1, d2 = 0, 1
    s = 2 if self.args.word == 'img_rose.png' else 1
    titlesize = 24
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    # Axes.sharex/sharey replaces get_shared_*_axes().join, which is
    # deprecated since Matplotlib 3.6.
    axs[1].sharex(axs[0])
    axs[1].sharey(axs[0])
    X_np = X_test.detach().cpu().numpy()
    X_hat_np = X_test_hat.detach().cpu().numpy()
    if X_test.shape[1] == 1:
        # One-dimensional example
        sns.kdeplot(X_np.flatten(), ax=axs[0])
        sns.kdeplot(X_hat_np.flatten(), ax=axs[1])
    else:
        # Two-D examples
        axs[0].scatter(X_np[:, d1], X_np[:, d2], s=s, color='black')
        axs[1].scatter(X_hat_np[:, d1], X_hat_np[:, d2], s=s, color='black')
    axs[0].set_title(r'$X$', fontsize=titlesize)
    axs[1].set_title(r'$\hat{X}=F^{-1}(Z)$', fontsize=titlesize)
    for a in axs:
        a.axes.get_yaxis().set_visible(False)
        a.axes.get_xaxis().set_visible(False)
    fig.tight_layout()
    fig.savefig(f'{self.args.word}_true_vs_gen.png', dpi=150,
                bbox_inches='tight', pad_inches=0)
    plt.show()
    plt.close()


def plt_generation_fig(plt_args, X, X_pred, Y, H_full, H_pred):
    '''
        Draw graph-structured samples (modified from IGNN).

        Panels: X and X_pred, plus the latent H_full and H_pred unless both
        plt_args.final_viz and plt_args.plot_sub are set (in that case only
        two panels are made and the X panel is hidden). Points are coloured
        with a coolwarm map driven by per-node variances of X.

        plt_args: Namespace with V, C, final_viz, plot_sub; the created figure
            is stored back as `plt_args.fig_gen`.
        X, X_pred, H_full, H_pred: torch tensors whose leading dimension is
            the number of graphs; X must be 3-D for torch.var(dim=[0, 2]).
        Y: not used here.
    '''
    plt_dict = {0: X, 1: X_pred, 2: H_full, 3: H_pred}
    N = X.shape[0]
    if plt_args.C > 2:
        # NOTE: this is because FC treated graph example in R^V-x-C as a vector in \R^V-by-C, so that we need reshaping for visualization
        V_tmp = int(plt_args.C/2)
        C_tmp = 2
        for key in plt_dict.keys():
            plt_dict[key] = plt_dict[key].reshape(N, V_tmp, C_tmp)
    if plt_args.final_viz and plt_args.plot_sub:
        title_dict = {
            0: r'$X|Y$', 1: r'$\hat{X}|Y=F^{-1}(H|Y)$'}
        fig, axs = plt.subplots(1, 2, figsize=(2 * 4, 4))
    else:
        title_dict = {
            0: r'$X|Y$', 1: r'$\hat{X}|Y=F^{-1}(H|Y)$', 2: r'$H|Y$', 3: r'$\hat{H}|Y=F(X|Y)$'}
        fig, axs = plt.subplots(1, 4, figsize=(4 * 4, 4))
    # Share limits pairwise (X with X_pred, H_full with H_pred). Axes.sharex/
    # sharey replaces get_shared_*_axes().join, deprecated since Matplotlib 3.6.
    axs[1].sharex(axs[0])
    axs[1].sharey(axs[0])
    if len(axs) > 2:
        axs[3].sharex(axs[2])
        axs[3].sharey(axs[2])
    # Plot X and X_pred=F^-1(H)
    which_cmap = cm.coolwarm
    markersize = 20
    lwidth = 0.025
    X = plt_dict[0]
    node_var = torch.var(X, dim=[0, 2]).cpu().detach()
    # NOTE(review): the sort indices are discarded, so colours follow the
    # variance rank rather than each node's own variance — confirm intended.
    node_var, _ = torch.sort(node_var, descending=True)
    # Min-max scale to [0, 1]; a zero range (e.g. a single node) would give
    # 0/0 = NaN, which the colormap renders as fully transparent.
    var_range = node_var.max() - node_var.min()
    if var_range > 0:
        node_var = (node_var - node_var.min()) / var_range
    else:
        node_var = torch.zeros_like(node_var)
    cutoff = 0.7
    node_var[node_var > cutoff] = node_var[node_var > cutoff]**2  # Make them lighter
    node_var = torch.flip(node_var, dims=(0,)).numpy()
    print(f'{len(X)} graphs, 1st Var to Last Var, lightest to darkest: {node_var}')
    # One colour per node, repeated for every graph (matches flatten order)
    colors = np.tile(which_cmap(node_var), (X.shape[0], 1))
    for j in range(len(title_dict)):
        ax = axs[j]
        XorXpred_tmp = plt_dict[j].flatten(
            start_dim=0, end_dim=1).cpu().detach().numpy()
        if plt_args.C == 1:
            XorXpred_tmp = np.c_[XorXpred_tmp, np.zeros(XorXpred_tmp.shape)]
        if plt_args.V > 1 or (plt_args.V == 1 and plt_args.C > 2):
            ax.plot(XorXpred_tmp[:, 0], XorXpred_tmp[:, 1],
                    linestyle='dashed', linewidth=lwidth)
        ax.scatter(XorXpred_tmp[:, 0],
                   XorXpred_tmp[:, 1], color=colors, s=markersize)
        ax.set_title(title_dict[j], fontsize=24)
        # X-space panels get fixed axes according to the true X
        if j < 2:
            X_tmp = plt_dict[0].flatten(
                start_dim=0, end_dim=1).cpu().detach().numpy()
            ax.set_xlim(left=X_tmp[:, 0].min(), right=X_tmp[:, 0].max())
            ax.set_ylim(bottom=X_tmp[:, 1].min(), top=X_tmp[:, 1].max())
    fig.tight_layout()
    if plt_args.final_viz and plt_args.plot_sub:
        axs[0].set_visible(False)
        axs[1].tick_params(axis='both', which='major', labelsize=14)
        axs[1].tick_params(axis='both', which='minor', labelsize=14)
    plt_args.fig_gen = fig
    plt.show()
    plt.close()


"""## Plot transformations as .mp4 (forward and backward, nearly the same as IGNN visualization)"""


def plt_and_save(data_true, data_est_ls, args, est_X=True):
    '''
        Render one frame per integration step of a flow trajectory and save the
        frames as f'{args.savename}_estX/0000.jpg', 0001.jpg, ... (suffix
        `_estZ` when est_X is False).

        Each frame shows the running per-step W2 cost, the target data, a
        KDE of the current estimate, the estimate itself and its displacement
        field. Note: this updates global `plt.rcParams` font sizes.

        data_true: true data whose distribution is to be matched by estimates
        data_est_ls: 3D tensor with shape (num_int_pts*num_blocks X num_samples X dimension)
        args: namespace with `word`, `num_int_pts` and `savename`.
        est_X: selects titles and the output directory suffix.
    '''
    # Public SciPy entry point; the `scipy.stats.kde` namespace is deprecated.
    from scipy.stats import gaussian_kde

    plt.rcParams['axes.titlesize'] = 18
    plt.rcParams['legend.fontsize'] = 13
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 14
    plt.rcParams['figure.titlesize'] = 20
    num_fig, N, dim = data_est_ls.shape
    # Displacement over each integration step
    data_est_diff = torch.diff(data_est_ls, dim=0)
    xmin = data_true[:, 0].min().cpu().detach().numpy()
    xmax = data_true[:, 0].max().cpu().detach().numpy()
    data_true = data_true.cpu().detach().numpy()
    data_est_ls = data_est_ls.cpu().detach().numpy()
    data_est_diff = data_est_diff.cpu().detach().numpy()
    # Mean squared displacement per step (plotted as the W2 cost)
    data_est_diff_norm = (np.linalg.norm(
        data_est_diff, axis=2)**2).mean(axis=1)
    if dim > 1:
        ymin, ymax = data_true[:, 1].min(), data_true[:, 1].max()
    # Output directory for the frames (created once, not per frame)
    dir_suff = '_estX' if est_X else '_estZ'
    out_dir = f'{args.savename}{dir_suff}'
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)
        print("The new directory is created!")
    # Layout depends only on args.word
    if args.word != '' and 'img' not in args.word:
        figsize, n_grid_rows, w2_len, gen_len = (8, 8), 8, 2, 3
    else:
        figsize, n_grid_rows, w2_len, gen_len = (8, 11), 5, 1, 2
    true_name = r'$X$' if est_X else r'$Z$'
    est_name = r'$\hat{X}$' if est_X else r'$\hat{Z}$'
    direction = r'$Z \rightarrow X$' if est_X else r'$X \rightarrow Z$'
    for i in range(num_fig-1):
        fig = plt.figure(figsize=figsize)
        spec = fig.add_gridspec(n_grid_rows, 2)
        # Plot W2
        ax = fig.add_subplot(spec[0:w2_len, :])
        ax.plot(range(1, i+2), data_est_diff_norm[:i+1], '-o')
        a, b = i % (args.num_int_pts-1), i // (args.num_int_pts-1)
        ax.set_title(r'$W_2$ cost of '+direction+' at \n'
                     + f'{a+1}th step in block {b+1}')
        ax.set_facecolor('lightblue')
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        # Plot true data
        ax = fig.add_subplot(spec[w2_len:w2_len+gen_len, 0])
        ax.set_facecolor('lightblue')
        ax.set_xlim(xmin, xmax)
        if dim > 1:
            ax.set_ylim(ymin, ymax)
            ax.scatter(data_true[:, 0], data_true[:, 1], s=2)
        else:
            ax.hist(data_true.flatten(), bins=200)
        ax.set_title('Target '+true_name)
        # Plot density of estimate
        ax = fig.add_subplot(spec[w2_len:w2_len+gen_len, 1])
        ax.set_facecolor('lightblue')
        ax.set_xlim(xmin, xmax)
        data_est = data_est_ls[i+1]
        if dim > 1:
            ax.set_ylim(ymin, ymax)
            x, y = data_est[:, 0], data_est[:, 1]
            xy = np.vstack([x, y])
            k = gaussian_kde(xy)(xy)
            ax.scatter(x, y, c=k, s=2)
        else:
            # kdeplot replaces the deprecated sns.distplot(..., hist=False)
            sns.kdeplot(data_est.flatten(), ax=ax)
        ax.set_title('Density of '+est_name)
        # Plot estimates
        ax = fig.add_subplot(spec[w2_len+gen_len:, 0])
        ax.set_facecolor('lightblue')
        ax.set_xlim(xmin, xmax)
        if dim > 1:
            ax.set_ylim(ymin, ymax)
            ax.scatter(x, y, s=2)
        else:
            sns.histplot(data_est.flatten(), bins=200, ax=ax)
        ax.set_title('Estimates '+est_name)
        # Plot vector field.
        ax = fig.add_subplot(spec[w2_len+gen_len:, 1])
        ax.set_facecolor('lightblue')
        ax.set_xlim(xmin, xmax)
        est_diff = data_est_diff[i]
        directions = est_diff
        if dim == 1:
            # 1-D: arrows at the new points; vertical component is the change
            # in KDE density between the previous and current estimate.
            prev = data_est_ls[i].flatten()
            cur = data_est_ls[i+1].flatten()
            x = data_est.flatten()
            y = gaussian_kde([prev])(prev)
            est_diff_density = gaussian_kde([cur])(cur)-y
            directions = np.c_[est_diff, est_diff_density]
        if dim > 1:
            ax.set_ylim(ymin, ymax)
        # Squared arrow length as colour; equals exp(2*log(|d|)) without
        # the log(0) warning for zero-length arrows.
        sq_mag = np.hypot(directions[:, 0], directions[:, 1])**2
        ax.quiver(
            x, y, directions[:, 0], directions[:, 1],
            sq_mag, cmap="coolwarm", scale=3.5, width=0.015, pivot="mid")
        ax.set_title('Vector Field')
        fig.tight_layout()
        # Savefig
        fig.savefig(f'{out_dir}/{i:04d}.jpg')
        plt.show()
        plt.close()


def trajectory_to_mp4andgif(args, est_X=True):
    '''
        Stitch the frames written by `plt_and_save` (f'{args.savename}_estX/%04d.jpg',
        or `_estZ` when est_X is False) into f'output_phase{args.p}.mp4' in the
        same directory using the external `ffmpeg` binary. An existing output
        file is deleted first.
    '''
    # NOTE: cannot include suptitle in figure, as o/w cannot save properly
    import subprocess
    suff = '_estX' if est_X else '_estZ'
    frame_dir = f'{args.savename}{suff}'
    out_path = os.path.join(frame_dir, f'output_phase{args.p}')
    mp4_path = out_path + '.mp4'
    if os.path.exists(mp4_path):
        os.remove(mp4_path)
        print('An earlier .mp4 deleted')
    # Smaller framerate reduces picture speed (desirable if num blocks small)
    # 10 for 40 blocks in discrete ResNet was pretty fast
    # Arguments are passed as a list so paths containing spaces stay intact.
    command = ['ffmpeg', '-framerate', '10',
               '-i', os.path.join(frame_dir, '%04d.jpg'), mp4_path]
    result = subprocess.run(command, stdout=subprocess.PIPE)
    if result.returncode != 0:
        print(f'ffmpeg exited with code {result.returncode}; '
              f'{mp4_path} may not have been written')
    # # If want a .gif, then run these
    # clip = (VideoFileClip(out_path+'.mp4'))
    # clip.write_gif(out_path + '.gif')


"""Other quick/minor plots"""


def display_mult_cifar10_images(images, rows, cols):
    '''
        Show `rows` rows of `cols` images each, one torchvision grid per row.

        images: batch of image tensors accepted by
            torchvision.utils.make_grid (presumably (B, C, H, W) — confirm);
            only the first rows*cols are shown.
        Returns the created figure.
    '''
    # squeeze=False keeps a 2-D axes array even when rows == 1
    fig, ax = plt.subplots(rows, 1, figsize=(2*cols, 2*rows), squeeze=False)
    ax = ax[:, 0]
    for i in range(rows):
        row_imgs = images[i*cols:(i+1)*cols]
        grid_img_gen = torchvision.utils.make_grid(row_imgs, nrow=cols)
        grid_img_gen = grid_img_gen.permute(1, 2, 0).detach().cpu().numpy()
        ax[i].imshow(grid_img_gen)
    for a in ax:
        a.get_yaxis().set_visible(False)
        a.get_xaxis().set_visible(False)
        a.set_aspect('equal')
    # Was `figure.tight_layout(...)`, a NameError: the figure is `fig`.
    fig.tight_layout(h_pad=0.0, w_pad=0.0)
    return fig


def display_mult_mnist_images(images, rows, cols, titles=None):
    '''
        Display MNIST images (tensor inputs without gradients are needed)
        on a rows x cols grid of grayscale panels with hidden axes.

        images: iterable of 2-D images; at most rows*cols are supported.
        titles: optional sequence of per-image titles; image idx gets
            titles[idx] when that entry exists.
        Returns the created figure.
    '''
    # https://jamesmccaffrey.wordpress.com/2022/03/14/displaying-multiple-mnist-images-in-a-single-figure/
    if titles is None:
        titles = []
    # squeeze=False keeps a 2-D axes array even for a 1x1 grid
    figure, ax = plt.subplots(rows, cols, figsize=(
        2*cols, 2*rows), squeeze=False)
    axes = ax.ravel()
    for idx, img in enumerate(images):
        axes[idx].imshow(img, cmap='gray')
        if idx < len(titles):
            axes[idx].set_title(titles[idx])
    for a in axes:
        a.get_yaxis().set_visible(False)
        a.get_xaxis().set_visible(False)
        a.set_aspect('equal')
    figure.tight_layout(h_pad=0.0, w_pad=0.0)
    return figure


def plotAutoEncCIFAR10(x, xRecreate, sPath):
    '''
        Save a 4-row comparison figure (from OT-Flow): rows 1-2 show the
        first 10 originals in `x`, rows 3-4 the matching reconstructions in
        `xRecreate`, 5 images per row via torchvision.utils.make_grid.

        sPath: output file path; its parent directory is created if needed.
    '''
    fig, ax = plt.subplots(4, 1, figsize=(10, 8))
    fig.suptitle("Rows 1 and 2 originals. Rows 3 and 4 are generations.")
    num_per_row = 5
    for i in range(2):
        start = i*num_per_row
        end = (i+1)*num_per_row
        grid_img = torchvision.utils.make_grid(x[start:end], nrow=num_per_row)
        grid_img_gen = torchvision.utils.make_grid(
            xRecreate[start:end], nrow=num_per_row)
        grid_img = grid_img.permute(1, 2, 0).detach().cpu().numpy()
        grid_img_gen = grid_img_gen.permute(
            1, 2, 0).detach().cpu().numpy()
        ax[i].imshow(grid_img)
        ax[i+2].imshow(grid_img_gen)
    for a in ax.ravel():
        a.get_yaxis().set_visible(False)
        a.get_xaxis().set_visible(False)
        a.set_aspect('equal')
    plt.subplots_adjust(wspace=0.0, hspace=0.0)
    # A bare filename has no directory part; os.makedirs('') would raise.
    out_dir = os.path.dirname(sPath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(sPath, dpi=300)
    plt.show()
    plt.close()


def plotAutoEnc(x, xRecreate, sPath):
    '''
        Save a 4x4 grayscale comparison figure (from OT-Flow): rows 1-2 show
        the first 8 originals in `x`, rows 3-4 the matching reconstructions in
        `xRecreate`.

        x, xRecreate: 2-D tensors (num_images, s*s) of flattened square images;
            s = int(sqrt(x.shape[1])).
        sPath: output file path; its parent directory is created if needed.
    '''
    s = int(math.sqrt(x.shape[1]))
    nex = 8
    half = nex // 2

    fig, axs = plt.subplots(4, half)
    fig.set_size_inches(9, 9)
    fig.suptitle("first 2 rows originals. Rows 3 and 4 are generations.")

    # Row r of the figure shows source[offset + i] in column i
    row_sources = ((x, 0), (x, half), (xRecreate, 0), (xRecreate, half))
    for i in range(half):
        for row, (source, offset) in enumerate(row_sources):
            axs[row, i].imshow(source[offset + i, :].reshape(
                s, s).detach().cpu().numpy(), cmap='gray')

    for a in axs.ravel():
        a.get_yaxis().set_visible(False)
        a.get_xaxis().set_visible(False)
        a.set_aspect('equal')

    plt.subplots_adjust(wspace=0.0, hspace=0.0)

    # A bare filename has no directory part; os.makedirs('') would raise.
    out_dir = os.path.dirname(sPath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(sPath, dpi=300)
    plt.show()
    plt.close()


########################
########################
########################
########################
########################
########################
########################
########################
########################
########################
########################
########################
########################
########################
########################
########################
########################
########################
########################
