import argparse
import torch
import os
import numpy as np
import numpy.linalg as la
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
from sklearn.metrics import davies_bouldin_score
import time
from utils import utils, cca
from models import MVAE
import data
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from matplotlib.widgets import Slider, Button, RadioButtons
from tqdm import tqdm
from matplotlib.animation import FuncAnimation
import matplotlib.animation as animation
import matplotlib.tri as tri
from sklearn.manifold import SpectralEmbedding, TSNE
from scipy.sparse.csgraph import connected_components
from mpl_toolkits.axes_grid1 import make_axes_locatable


def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` converts any non-empty string -- including
    ``'False'`` and ``'0'`` -- to ``True``, so ``--use_cuda False`` used to keep
    CUDA enabled. This accepts the usual spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser(description='Parameters for training the PTC network.')
parser.add_argument('--epochs', help='The number of epochs the network was trained for; selects the checkpoint '
                                     'ckpt_<model_name>-<epochs>.pt to load. (default: 100)', type=int, default=100)
parser.add_argument('--batch_size', help='The batch size to be used during training. (default: 32)', type=int,
                    default=32)

parser.add_argument('--dir_data', type=str, default='data', help='Directory to the datasets.')
parser.add_argument('--dataset', type=str, default='da_10906555', help='The dataset to train the network on.')

parser.add_argument('--n_latent', type=int, default=[2, 16, 32], nargs='+',
                    help='Latent space dimensions. (default: {[z_x], [z_c, z_n]})')
parser.add_argument('--bond_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the bond encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--dihedrals_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the dihedrals encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--ca_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the CA encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--xyz_dec_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the xyz decoder. '
                         '(default: [128, 128, 64, 32, 16] scaled by 4 // n_heads)')
parser.add_argument('--n_heads', type=int, default=4, help='Number of heads of the attention mechanism in a GAT layer. '
                                                           '(default: 4)')
parser.add_argument('--conv_type', type=str, default='gat', help='The non-Euclidean convolution to use.')
parser.add_argument('--kernel_rad', type=float, default=2.5, help='The radius of the convolutional kernel on the '
                                                                    'surface. (default: 2.5A)')
parser.add_argument('--seed', help='The random seed for execution.', type=int, default=0)
parser.add_argument('--model_name', help='The name associated with the model with saving results.', type=str,
                    default=None)
parser.add_argument('--model_vers', help='Model version to append to name', type=str, default='fin3_reg')
parser.add_argument('--model_dir', help='The directory for saving the results.', type=str, default='./models')
parser.add_argument('--use_cuda', type=_str2bool, default=True,
                    help='Use the GPU if available (true/false, 1/0, yes/no). (default: true)')

args = parser.parse_args()
args.model_dir += '/' + args.dataset
if args.model_name is None:
    # Presumably mirrors the training script's naming so that the checkpoint
    # path built below resolves -- keep the two in sync.
    args.model_name = '%02d_%s_s%02d' % (sum(args.n_latent), args.model_vers, args.seed)

blocks = 1
num_channels = [None] * 2
transform_flag = None


def _default_kernels(base_kernels, n_heads, n_blocks):
    """Return the default per-layer channel list for one network branch.

    Each base width is multiplied by ``4 // n_heads``. Each entry is then
    repeated ``n_blocks`` times in place, e.g. ``[a, b]`` becomes
    ``[a, a, b, b]`` for ``n_blocks=2``.

    NOTE(review): ``4 // n_heads`` is 0 for ``n_heads > 4``, which would give
    zero-width layers -- confirm only n_heads <= 4 is supported.
    """
    scaled = [width * (4 // n_heads) for width in base_kernels]
    return [width for width in scaled for _ in range(n_blocks)]


# Default architectures. Each branch is defaulted independently, so overriding
# one kernel list on the command line neither leaves the others as None (which
# would crash on insert/append below) nor discards the user's other overrides.
if args.bond_enc_krn is None:
    args.bond_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.dihedrals_enc_krn is None:
    args.dihedrals_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.ca_enc_krn is None:
    args.ca_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.xyz_dec_krn is None:
    args.xyz_dec_krn = _default_kernels([128, 128, 64, 32, 16], args.n_heads, blocks)

# Prepend each encoder's first channel count (presumably its input feature
# width: 1 for bonds and CA, 3 for dihedrals).
args.bond_enc_krn.insert(0, 1)
args.dihedrals_enc_krn.insert(0, 3)
args.ca_enc_krn.insert(0, 1)
num_channels[0] = [args.bond_enc_krn, args.ca_enc_krn, args.dihedrals_enc_krn]

# The decoder's final layer has 3 channels (presumably x, y, z coordinates).
args.xyz_dec_krn.append(3)
num_channels[1] = args.xyz_dec_krn

# Run on the first GPU only when one exists AND the user has not disabled it.
use_cuda = torch.cuda.is_available() and args.use_cuda
device_name = "cuda:0" if use_cuda else "cpu"
device = torch.device(device_name)
print(f'Device: {device}')

# Seed every RNG the script touches.
torch.backends.cudnn.benchmark = True
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if use_cuda:
    torch.cuda.manual_seed(args.seed)

protein_dataset = data.Dataset(os.path.join(args.dir_data, args.dataset))
# Alternative splits previously used: protein_dataset.dataset_split_random()
# and protein_dataset.dataset_split().
train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_traj_end(args.dataset)


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the script's shared batch/worker settings."""
    return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle, num_workers=4,
                                       pin_memory=True)


train_data = _make_loader(train_dataset, shuffle=True)
val_data = _make_loader(val_dataset, shuffle=False)
test_data = _make_loader(test_dataset, shuffle=False)

# One down-sampled point set per entry of the decoder channel list
# (presumably one per resolution level).
sampled_pts = utils.create_downsamples(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                       len(num_channels[1]))
# Kernel radius grows linearly with the down-sampling ratio
# weighted_adj.shape[0] / len(s). A square-root scaling was tried previously.
rad = [(protein_dataset.weighted_adj.shape[0] / len(s)) * args.kernel_rad for s in sampled_pts]
dist = utils.compute_distance_mat(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                  max_distance=rad[-1]*0.9)

# Extra keyword arguments for each supported convolution type. An explicit
# table replaces the former locals()['<type>_kwargs'] string lookup, so an
# unsupported --conv_type now fails with a clear message.
gat_kwargs = {'n_heads': args.n_heads}
conv_kwargs_by_type = {'gat': gat_kwargs}
try:
    kwargs = conv_kwargs_by_type[args.conv_type]
except KeyError:
    raise ValueError('Unsupported --conv_type %r; expected one of %s'
                     % (args.conv_type, sorted(conv_kwargs_by_type))) from None
model = MVAE(sampled_pts, protein_dataset.weighted_adj, num_channels, rad, dist, args.n_latent, args.conv_type, device,
             True, **kwargs)

# Restore the trained weights for the requested epoch.
load_start = time.time()
checkpoint_path = f'{args.model_dir}/ckpt_{args.model_name}-{args.epochs}.pt'
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict['model_state'])

print(f'Model loaded: {time.time() - load_start:0.2f}s')

# Place every encoder and the decoder on the target device (a multi-GPU
# utils.MyDataParallel wrapper was used here at one point).
for enc_idx, encoder in enumerate(list(model.inference_net)):
    model.inference_net[enc_idx] = encoder.to(device)
model.generative_net = model.generative_net.to(device)

model = model.to(device)
model.eval()

# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param)

criterion = utils.loss_func_mse

print('Computing latent embeddings')
with torch.no_grad():
    # Per-batch outputs of the fine-intrinsic, coarse-intrinsic and extrinsic
    # encoders. Batches are collected in lists and concatenated once after the
    # loop; concatenating inside the loop re-copied the whole accumulated array
    # on every batch, which is quadratic in the number of batches.
    mu_int_f_parts, sigma_int_f_parts = [], []
    mu_int_c_parts, sigma_int_c_parts = [], []
    mu_ext_parts, sigma_ext_parts = [], []
    latent_labels_id = []
    for signal_int_fine, signal_ext, signal_int_course, labels, pts in tqdm(test_data):
        signal_int_fine = signal_int_fine.to(device)
        signal_int_course = signal_int_course.to(device)
        signal_ext = signal_ext.to(device)
        # Each encoder returns four values; only the 2nd (kept as mu) and the
        # 4th (kept as sigma) are used.
        _, mu_int_f, _, sigma_int_f = model.inference_net[0](signal_int_fine)
        _, mu_int_c, _, sigma_int_c = model.inference_net[1](signal_int_course)
        _, mu_ext, _, sigma_ext = model.inference_net[2](signal_ext)

        latent_labels_id += labels
        mu_int_f_parts.append(mu_int_f.cpu().numpy())
        sigma_int_f_parts.append(sigma_int_f.cpu().numpy())
        mu_int_c_parts.append(mu_int_c.cpu().numpy())
        sigma_int_c_parts.append(sigma_int_c.cpu().numpy())
        mu_ext_parts.append(mu_ext.cpu().numpy())
        sigma_ext_parts.append(sigma_ext.cpu().numpy())

    if not mu_int_f_parts:
        raise RuntimeError('The test split is empty; no latent embeddings were computed.')
    latent_data_int_f = np.concatenate(mu_int_f_parts, axis=0)
    latent_sigma_int_f = np.concatenate(sigma_int_f_parts, axis=0)
    latent_data_int_c = np.concatenate(mu_int_c_parts, axis=0)
    latent_sigma_int_c = np.concatenate(sigma_int_c_parts, axis=0)
    latent_data_ext = np.concatenate(mu_ext_parts, axis=0)
    latent_sigma_ext = np.concatenate(sigma_ext_parts, axis=0)
# Centre each latent code on its test-set mean. Per-dimension standardisation
# (dividing by the std) was tried and is disabled; the std arrays are kept for
# the commented-out variants further down.
z_score_0_int_f = latent_data_int_f - np.mean(latent_data_int_f, axis=0, keepdims=True)
print('Mean (Intrinsic Fine)', np.mean(latent_data_int_f, axis=0))
# Average of the per-sample sigmas returned by the encoder.
print('Sigma (Intrinsic Fine)', np.mean(latent_sigma_int_f, axis=0))

latent_int_c_mu = np.mean(latent_data_int_c, axis=0, keepdims=True)
latent_int_c_std = np.std(latent_data_int_c, axis=0, keepdims=True)
z_score_0_int_c = latent_data_int_c - latent_int_c_mu
print('Mean (Intrinsic Coarse)', np.mean(latent_data_int_c, axis=0))
print('Sigma (Intrinsic Coarse)', np.mean(latent_sigma_int_c, axis=0))

latent_ext_mu = np.mean(latent_data_ext, axis=0, keepdims=True)
latent_ext_std = np.std(latent_data_ext, axis=0, keepdims=True)
z_score_0_ext = latent_data_ext - latent_ext_mu
print('Mean (Extrinsic)', np.mean(latent_data_ext, axis=0))
print('Sigma (Extrinsic)', np.mean(latent_sigma_ext, axis=0))

n_comp = latent_data_int_f.shape[1]

# CCA alternative (disabled):
# u_x, s_x, vh_x = la.svd(z_score_0_int_c, full_matrices=False)
# u_y, s_y, vh_y = la.svd(z_score_0_ext, full_matrices=False)
# covar = np.matmul(np.transpose(u_x), u_y)
# u_xy, s_xy, vh_xy = la.svd(covar, full_matrices=False)
#
# a_cca = np.matmul(np.transpose(vh_x), np.diag(1 / s_x))
# a_cca = np.matmul(a_cca, u_xy)
# b_cca = np.matmul(np.transpose(vh_y), np.diag(1 / s_y))
# b_cca = np.matmul(b_cca, np.transpose(vh_xy))
#
# int_embed = np.matmul(u_x, u_xy)
# ext_embed = np.matmul(u_y, np.transpose(vh_xy))

# PCA via thin SVD of the centred data (X = U S Vh): rows of vh are the
# principal axes and u * s are the principal-component scores.
u_int_c, s_int_c, vh_int_c = np.linalg.svd(z_score_0_int_c, full_matrices=False)
u_ext, s_ext, vh_ext = np.linalg.svd(z_score_0_ext, full_matrices=False)

# 50x50 grid over the range of the first two columns of u, scaled by 0.95
# toward zero. x spans the intrinsic range, y the extrinsic range.
# NOTE(review): the intrinsic-only sweep later uses y (the extrinsic range) for
# its second intrinsic PC -- confirm that is intended.
x = np.linspace(u_int_c[:, :2].min() * 0.95, u_int_c[:, :2].max() * 0.95, 50)
y = np.linspace(u_ext[:, :2].min() * 0.95, u_ext[:, :2].max() * 0.95, 50)
# CCA variant: x/y spanning int_embed / ext_embed scaled by 1.1.

# meshgrid(x, y) gives arrays of shape (len(y), len(x)); flattening is
# row-major, so pca_scores has len(y) * len(x) rows of (x, y) pairs.
xv, yv = np.meshgrid(x, y)
pca_scores = np.stack([xv.flatten(), yv.flatten()]).transpose()


def get_bonded_dist(pca_int, pca_ext, grid_shape=None):
    """Decode latent grid points and score their deviation from equilibrium.

    For each row pair of ``pca_int`` / ``pca_ext`` the two latent codes are
    concatenated, decoded with ``model.generative_net`` and un-normalised with
    ``protein_dataset.unnormalize_coordinates_torch``. Two scores are computed
    against ``protein_dataset.equilibrium_coords[0]``:

    * bonded distance: mean absolute difference between the generated and
      equilibrium ``compute_ca_signal`` values (de-standardised with
      ``ca_bonds_stats``), divided by the number of non-zero generated entries;
    * nonbonded distance: mean absolute difference between the full pairwise
      distance matrices, excluding the diagonal and the bonded pairs given by
      ``bond_idx0`` / ``bond_idx1``.

    Args:
        pca_int: intrinsic latent codes, one row per grid point.
        pca_ext: extrinsic latent codes, same number of rows as ``pca_int``.
        grid_shape: ``(n_rows, n_cols)`` used to reshape the flat results.
            Defaults to ``(len(y), len(x))``, the layout produced by
            ``np.meshgrid(x, y)`` and expected by ``imshow(origin='lower')``.

    Returns:
        ``(bonded_dist, nonbonded_dist)``, each reshaped to ``grid_shape``.
    """
    if grid_shape is None:
        # meshgrid(x, y) is (len(y), len(x)); the old (len(x), len(y)) order
        # only worked because both grids have the same length.
        grid_shape = (y.shape[0], x.shape[0])
    bonded_dist = np.zeros(pca_int.shape[0])
    nonbonded_dist = np.zeros(pca_int.shape[0])

    # eq is indexed as eq[:, atom], i.e. assumed (3, n_atoms) -- TODO confirm.
    eq = torch.tensor(protein_dataset.equilibrium_coords[0], device=device)
    eq_coords_diff = eq[:, protein_dataset.bond_idx1] - eq[:, protein_dataset.bond_idx0]
    eq_edge_lens = torch.norm(eq_coords_diff, dim=0, keepdim=True)
    n_bonds = eq_edge_lens.shape[-1]
    n_atoms = eq.shape[-1]
    # Off-diagonal entries of the symmetric distance matrix that are not bonds.
    n_nonbonded = n_atoms ** 2 - n_atoms - 2 * n_bonds

    # Equilibrium pairwise distances via the Gram matrix,
    # |a - b|^2 = a.a + b.b - 2 a.b. This is loop-invariant, so it is computed
    # once. Round-off can make entries (notably the diagonal) slightly negative;
    # clamp before the square root so they do not become NaN and poison every
    # nonbonded score.
    eq_square = torch.matmul(eq.t(), eq)
    eq_diag = torch.diag(eq_square)
    eq_dist_mat = eq_diag.unsqueeze(1) + eq_diag.unsqueeze(0) - 2 * eq_square
    eq_dist_mat = torch.clamp(eq_dist_mat, min=0) ** 0.5

    mu_ca = torch.tensor(protein_dataset.ca_bonds_stats['mu'], device=device)
    std_ca = torch.tensor(protein_dataset.ca_bonds_stats['std'], device=device)
    eq_ca_lens = protein_dataset.compute_ca_signal(eq.unsqueeze(0)) * std_ca + mu_ca

    model.eval()
    with torch.no_grad():
        for idx, (l_int, l_ext) in enumerate(zip(tqdm(pca_int), pca_ext)):
            l_int = torch.tensor(l_int, device=device, dtype=torch.float32).unsqueeze(0)
            l_ext = torch.tensor(l_ext, device=device, dtype=torch.float32).unsqueeze(0)
            z_latent = torch.cat([l_int, l_ext], dim=-1)
            pts_norm = model.generative_net(z_latent)
            pts_gen = protein_dataset.unnormalize_coordinates_torch(pts_norm)

            # Mean absolute bond-length error. It is not reported directly; it
            # is used to remove the bonded pairs from the nonbonded score below.
            gen_coords_diff = pts_gen[:, :, protein_dataset.bond_idx1] - pts_gen[:, :, protein_dataset.bond_idx0]
            gen_edge_lens = torch.norm(gen_coords_diff, dim=1, keepdim=True).squeeze(0)
            bond_err = torch.sum(torch.abs(eq_edge_lens - gen_edge_lens)) / n_bonds

            # Reported "bonded" distance: CA-signal error over the non-zero
            # generated entries.
            gen_ca_lens = protein_dataset.compute_ca_signal(pts_gen) * std_ca + mu_ca
            ca_z = torch.sum(torch.abs(gen_ca_lens - eq_ca_lens)) / torch.sum(gen_ca_lens != 0)
            bonded_dist[idx] = ca_z.cpu().numpy()

            pts_square = torch.matmul(pts_gen.transpose(-1, -2), pts_gen)
            pts_diag = torch.diagonal(pts_square, dim1=-2, dim2=-1)
            pts_dist_mat = pts_diag.unsqueeze(2) + pts_diag.unsqueeze(1) - 2 * pts_square
            pts_dist_mat = torch.clamp(pts_dist_mat, min=0) ** 0.5
            # Sum of |delta d| over all pairs, minus the bonded pairs (each
            # appears twice in the symmetric matrix), averaged over the
            # remaining nonbonded pairs.
            nb_dist = torch.sum(torch.abs(eq_dist_mat - pts_dist_mat))
            nb_dist = (nb_dist - 2 * n_bonds * bond_err) / n_nonbonded
            nb_dist = nb_dist.cpu().numpy()
            if np.isnan(nb_dist):
                print('Error; nonbonded distance is nan')
            nonbonded_dist[idx] = nb_dist
    bonded_dist = np.reshape(bonded_dist, grid_shape)
    nonbonded_dist = np.reshape(nonbonded_dist, grid_shape)
    return bonded_dist, nonbonded_dist
def _scores_to_latent(scores, singular_values, right_vectors, n_pcs):
    """Map left-singular-vector coordinates back into centred latent space.

    Computes ``scores @ diag(singular_values[:n_pcs]) @ right_vectors[:n_pcs]``,
    inverting the SVD projection for the leading ``n_pcs`` components.
    """
    return scores @ np.diag(singular_values[:n_pcs]) @ right_vectors[:n_pcs]


# Sweep 1: first two intrinsic PCs vary; the extrinsic code is held at its mean.
pca_int_grid = _scores_to_latent(pca_scores, s_int_c, vh_int_c, 2)
pca_ext_grid = _scores_to_latent(pca_scores, s_ext, vh_ext, 2)
# CCA variant: pca_scores @ la.pinv(a_cca)[:2] / la.pinv(b_cca)[:2].
pca_int_grid0 = pca_int_grid + latent_int_c_mu
pca_ext_grid0 = np.zeros_like(pca_ext_grid) + latent_ext_mu
bonded_dist_int, nonbonded_dist_int = get_bonded_dist(pca_int_grid0, pca_ext_grid0)

# Sweep 2: first two extrinsic PCs vary; the intrinsic code is held at its mean.
pca_int_grid1 = np.zeros_like(pca_int_grid) + latent_int_c_mu
pca_ext_grid1 = pca_ext_grid + latent_ext_mu
bonded_dist_ext, nonbonded_dist_ext = get_bonded_dist(pca_int_grid1, pca_ext_grid1)

# Sweep 3: first intrinsic PC (grid x) against first extrinsic PC (grid y).
pca_int_grid = _scores_to_latent(pca_scores[:, :1], s_int_c, vh_int_c, 1)
pca_ext_grid = _scores_to_latent(pca_scores[:, 1:], s_ext, vh_ext, 1)
pca_int_grid2 = pca_int_grid + latent_int_c_mu
pca_ext_grid2 = pca_ext_grid + latent_ext_mu
bonded_dist_intext, nonbonded_dist_intext = get_bonded_dist(pca_int_grid2, pca_ext_grid2)

# Shared colour limits so the three panels of each figure are directly comparable.
vmin_bd = np.min([bonded_dist_int, bonded_dist_ext, bonded_dist_intext])
vmax_bd = np.max([bonded_dist_int, bonded_dist_ext, bonded_dist_intext])
vmin_nd = np.min([nonbonded_dist_int, nonbonded_dist_ext, nonbonded_dist_intext])
vmax_nd = np.max([nonbonded_dist_int, nonbonded_dist_ext, nonbonded_dist_intext])

interpolation = 'bicubic'


def _plot_distance_panels(maps, titles, vmin, vmax):
    """Draw three distance maps side by side on a common colour scale.

    Panels 0 and 1 are labelled as first/second principal component. Panel 2 is
    labelled first intrinsic vs first extrinsic PC, and a single colour bar is
    attached to its right.

    Returns:
        The ``(figure, axes)`` pair created by ``plt.subplots``.
    """
    panel_fig, panel_axes = plt.subplots(1, 3, figsize=[6.4 * 3, 4.8])
    extent = (xv.min(), xv.max(), yv.min(), yv.max())
    axis_labels = [
        ('First Principal Component', 'Second Principal Component'),
        ('First Principal Component', 'Second Principal Component'),
        ('First Intrinsic Principal Component', 'First Extrinsic Principal Component'),
    ]
    image = None
    for panel_ax, values, title, (xlabel, ylabel) in zip(panel_axes, maps, titles, axis_labels):
        image = panel_ax.imshow(values, cmap='jet', interpolation=interpolation, origin='lower',
                                extent=extent, vmin=vmin, vmax=vmax)
        panel_ax.set_xlabel(xlabel, fontsize='large')
        panel_ax.set_ylabel(ylabel, fontsize='large')
        panel_ax.set_title(title, fontsize='x-large')
    colorbar_ax = make_axes_locatable(panel_axes[2]).append_axes("right", size="5%", pad=0.05)
    panel_fig.colorbar(image, cax=colorbar_ax)
    plt.tight_layout()
    return panel_fig, panel_axes


fig, ax = _plot_distance_panels(
    [bonded_dist_int, bonded_dist_ext, bonded_dist_intext],
    ['Bonded Distance; Intrinsic', 'Bonded Distance; Extrinsic', 'Bonded Distance; Intrinsic/Extrinsic'],
    vmin_bd, vmax_bd)

fig2, ax2 = _plot_distance_panels(
    [nonbonded_dist_int, nonbonded_dist_ext, nonbonded_dist_intext],
    ['Nonbonded Distance; Intrinsic', 'Nonbonded Distance; Extrinsic', 'Nonbonded Distance; Intrinsic/Extrinsic'],
    vmin_nd, vmax_nd)

plt.show()

print('Done')
