import argparse
import torch
import os
import numpy as np
import numpy.linalg as la
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
from sklearn.metrics import davies_bouldin_score
import time
from utils import utils, cca
from models import MVAE
import data
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from matplotlib.widgets import Slider, Button, RadioButtons
from tqdm import tqdm
from matplotlib.animation import FuncAnimation
import matplotlib.animation as animation
import matplotlib.tri as tri
from sklearn.manifold import SpectralEmbedding, TSNE
from scipy.sparse.csgraph import connected_components


def _str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True, so any
    non-empty string would enable the option.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r.' % value)


parser = argparse.ArgumentParser(description='Parameters for training the PTC network.')
parser.add_argument('--epochs', help='The number of epochs to train the network; also selects the checkpoint '
                                     'to load. (default: 200)', type=int, default=200)
parser.add_argument('--batch_size', help='The batch size to be used during training. (default: 64)', type=int,
                    default=64)

parser.add_argument('--dir_data', type=str, default='data', help='Directory to the datasets.')
parser.add_argument('--dataset', type=str, default='da_10906555', help='The dataset to train the network on.')

parser.add_argument('--n_latent', type=int, default=[8, 64], nargs='+',
                    help='Latent space dimensions. (default: {[z_x], [z_c, z_n]})')
# The kernel defaults quoted below are the ones filled in after parsing when the option is omitted.
parser.add_argument('--bond_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the bond encoder. '
                         '(default: [24, 48, 96, 192, 192])')
parser.add_argument('--dihedrals_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the dihedrals encoder. '
                         '(default: [24, 48, 96, 192, 192])')
parser.add_argument('--ca_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the ca encoder. '
                         '(default: [24, 48, 96, 192, 192])')
parser.add_argument('--xyz_dec_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the xyz decoder. '
                         '(default: [256, 256, 128, 64, 32])')
parser.add_argument('--n_heads', type=int, default=2, help='Number of heads of the attention mechanism in a GAT layer. '
                                                           '(default: 2)')
parser.add_argument('--conv_type', type=str, default='gat', help='The non-Euclidean convolution to use.')
parser.add_argument('--kernel_rad', type=float, default=3.0, help='The radius of the convolutional kernel on the '
                                                                    'surface. (default: 3.0A)')
parser.add_argument('--seed', help='The random seed for execution.', type=int, default=1)
parser.add_argument('--model_name', help='The name associated with the model with saving results.', type=str,
                    default=None)
parser.add_argument('--model_vers', help='Model version to append to name', type=str, default='cont_traj_kl')
parser.add_argument('--model_dir', help='The directory for saving the results.', type=str, default='./models')
parser.add_argument('--use_cuda', type=_str2bool, default=True, help='Train on GPU if available.')

args = parser.parse_args()
args.model_dir += '/' + args.dataset
if args.model_name is None:
    # e.g. '72_cont_traj_kl_s01' for the default latent sizes, version and seed.
    args.model_name = '%02d_%s_s%02d' % (sum(args.n_latent), args.model_vers, args.seed)
dir_dataset = os.path.join(args.dir_data, args.dataset)

blocks = 1
num_channels = [None] * 2
transform_flag = None


def _default_krn(base_krn):
    """Return a default kernel list: each entry of ``base_krn`` doubled and repeated ``blocks`` times."""
    return [2 * k for k in base_krn for _ in range(blocks)]


# Default architectures. Each option is filled in independently so that overriding only
# some of them on the command line neither crashes (None.insert below) nor silently
# replaces a user-supplied value with the default.
if args.bond_enc_krn is None:
    args.bond_enc_krn = _default_krn([12, 24, 48, 96, 96])
if args.dihedrals_enc_krn is None:
    args.dihedrals_enc_krn = _default_krn([12, 24, 48, 96, 96])
if args.ca_enc_krn is None:
    args.ca_enc_krn = _default_krn([12, 24, 48, 96, 96])
if args.xyz_dec_krn is None:
    args.xyz_dec_krn = _default_krn([128, 128, 64, 32, 16])

# Prepend the input width of each encoder (presumably the number of features per
# input signal -- confirm against data.Dataset).
args.bond_enc_krn.insert(0, 1)
args.dihedrals_enc_krn.insert(0, 3)
args.ca_enc_krn.insert(0, 1)
# Note the encoder order here: bond, ca, dihedrals.
num_channels[0] = [args.bond_enc_krn, args.ca_enc_krn, args.dihedrals_enc_krn]

# The decoder ends in 3 output channels (presumably x, y, z).
args.xyz_dec_krn.append(3)
num_channels[1] = args.xyz_dec_krn

# trans_dict_name = 'latent_transform_orthoproc_edge_ls16_monet.npz'
trans_dir = 'transforms/%s' % args.dataset.lower()
trans_dict_name = '%s/latent_transform_orthoproc_edge_ls%02d_%s_s%02d.npz' % (trans_dir, sum(args.n_latent),
                                                                              args.conv_type, args.seed)

# Device selection and seeding.
use_cuda = torch.cuda.is_available() and args.use_cuda
device = torch.device("cuda:0" if use_cuda else "cpu")
print('Device: %s' % device)

np.random.seed(args.seed)
# benchmark=True lets cuDNN pick the fastest kernels, at the cost of run-to-run determinism.
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
if use_cuda:
    torch.cuda.manual_seed(args.seed)

# Dataset split and loaders.
protein_dataset = data.Dataset(os.path.join(args.dir_data, args.dataset))
train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_traj_end()
train_data = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                         pin_memory=True)
val_data = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,
                                       pin_memory=True)
test_data = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,
                                        pin_memory=True)

# One down-sampled point set per decoder level.
sampled_pts = utils.create_downsamples(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                       len(num_channels[1]))
# Kernel radius grows linearly with the down-sampling ratio at each level.
rad = [(protein_dataset.weighted_adj.shape[0] / len(s)) * args.kernel_rad for s in sampled_pts]
dist = utils.compute_distance_mat(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                  max_distance=rad[-1]*0.9)

# Convolution-specific constructor arguments, looked up explicitly rather than through
# locals() so an unsupported --conv_type fails with a clear message.
gat_kwargs = {'n_heads': args.n_heads}
conv_kwargs = {'gat': gat_kwargs}
if args.conv_type not in conv_kwargs:
    raise ValueError('Unsupported --conv_type %r; expected one of %s.'
                     % (args.conv_type, sorted(conv_kwargs)))
kwargs = conv_kwargs[args.conv_type]
model = MVAE(sampled_pts, protein_dataset.weighted_adj, num_channels, rad, dist, args.n_latent, args.conv_type, device, **kwargs)

# Restore the weights saved after `args.epochs` epochs.
t_load = time.time()
ckpt_path = '%s/ckpt_%s-%d.pt' % (args.model_dir, args.model_name, args.epochs)
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict['model_state'])
print('Model loaded: %0.2fs' % (time.time() - t_load))

# Move each inference network, the generative network and finally the whole model to
# the selected device, then switch to evaluation mode.
for i, inference_module in enumerate(list(model.inference_net)):
    model.inference_net[i] = inference_module.to(device)
model.generative_net = model.generative_net.to(device)

model = model.to(device)
model.eval()

criterion = utils.loss_func_mse

# Encode every test sample with the three inference networks and keep the posterior
# means (mu) and sigmas returned by each, plus the label of every sample.
with torch.no_grad():
    # Batches are collected in lists and concatenated once after the loop; calling
    # np.concatenate inside the loop re-copies everything seen so far (quadratic).
    mu_int_f_batches, sigma_int_f_batches = [], []
    mu_int_c_batches, sigma_int_c_batches = [], []
    mu_ext_batches, sigma_ext_batches = [], []
    latent_labels_id = []
    for signal_int_f, signal_ext, signal_int_c, labels, _ in tqdm(test_data):
        signal_int_f = signal_int_f.to(device)
        signal_int_c = signal_int_c.to(device)
        signal_ext = signal_ext.to(device)
        _, mu_int_f, _, sigma_int_f = model.inference_net[0](signal_int_f)
        _, mu_int_c, _, sigma_int_c = model.inference_net[1](signal_int_c)
        _, mu_ext, _, sigma_ext = model.inference_net[2](signal_ext)
        latent_labels_id += labels
        mu_int_f_batches.append(mu_int_f.cpu().numpy())
        sigma_int_f_batches.append(sigma_int_f.cpu().numpy())
        mu_int_c_batches.append(mu_int_c.cpu().numpy())
        sigma_int_c_batches.append(sigma_int_c.cpu().numpy())
        mu_ext_batches.append(mu_ext.cpu().numpy())
        sigma_ext_batches.append(sigma_ext.cpu().numpy())

if not mu_int_f_batches:
    raise RuntimeError('The test data loader yielded no batches; nothing to embed.')
latent_data_int_f = np.concatenate(mu_int_f_batches, axis=0)
latent_sigma_int_f = np.concatenate(sigma_int_f_batches, axis=0)
latent_data_int_c = np.concatenate(mu_int_c_batches, axis=0)
latent_sigma_int_c = np.concatenate(sigma_int_c_batches, axis=0)
latent_data_ext = np.concatenate(mu_ext_batches, axis=0)
latent_sigma_ext = np.concatenate(sigma_ext_batches, axis=0)

# PCA baseline: SVD of a standardised random 20% subset of the training `edge_length`
# features. u_pca holds the left singular vectors (unit-norm principal scores).
print('Computing PCA')
n_train = train_data.dataset.indices.shape[0]
# Sample without replacement: np.random.choice defaults to replace=True, which puts
# duplicate rows in the sample and over-weights them in the SVD.
pca_idx = np.random.choice(train_data.dataset.indices, int(n_train * 0.2), replace=False)
pca_samp = train_data.dataset.dataset.edge_length[pca_idx]
pca_samp = np.reshape(pca_samp, [pca_samp.shape[0], -1])
# z-score each feature; the epsilon (the same guard used for the chemical properties)
# avoids 0/0 = NaN for features that are constant within the sample.
pca_samp = (pca_samp - np.mean(pca_samp, axis=0, keepdims=True)) \
           / (np.std(pca_samp, axis=0, keepdims=True) + 1E-12)

u_pca, _, _ = la.svd(pca_samp, full_matrices=False)
print('PCA computed.')

# Properties (columns of chem_properties.csv)
# 0  Molecular Weight                   1  XLogP3
# 2  Hydrogen Bond Donor Count          3  Hydrogen Bond Acceptor Count
# 4  Rotatable Bond Count               5  Topological Polar Surface Area
# 6  Heavy Atom Count                   7  Formal Charge  x
# 8  Complexity                         9  Isotope Atom Count  x
# 10 Defined Atom Stereocenter Count    11 Undefined Atom Stereocenter Count
# 12 Defined Bond Stereocenter Count    13 Undefined Bond Stereocenter Count x
# 14 Covalently-Bonded Unit Count
properties = np.loadtxt(dir_dataset + '/chem_properties.csv', delimiter=',')

# z-score every property column over the whole table; the epsilon guards constant columns.
properties_nrm = (properties - np.mean(properties, axis=0, keepdims=True)) \
             / (np.std(properties, axis=0, keepdims=True) + 1E-12)
# Property columns (indices into the table above) used for the CCA.
idx = [0, 2, 3, 4, 5, 6, 8]
# Select from the z-scored table -- re-reading the raw `properties` here would discard
# the normalisation computed just above.
properties_nrm = properties_nrm[:, idx]
# One row per encoded test sample, looked up by its label.
properties_nrm = properties_nrm[latent_labels_id]
n_properties = properties.shape[1]

# Centre the latent means column-wise before the CCA.
latent_data_int_c = latent_data_int_c - np.mean(latent_data_int_c, axis=0, keepdims=True)


def _thin_svd_ranked(m):
    """Thin SVD of ``m`` keeping only the numerically non-zero singular values.

    Uses the same tolerance as ``numpy.linalg.matrix_rank``, so ``1 / s`` is always finite
    and rank-deficient inputs do not contribute arbitrary null-space directions.
    """
    u, s, vh = la.svd(m, full_matrices=False)
    if s.size:
        tol = s.max() * max(m.shape) * np.finfo(s.dtype).eps
        keep = s > tol
        u, s, vh = u[:, keep], s[keep], vh[keep]
    return u, s, vh


def compute_cca(x, y, center=True):
    """Canonical correlation analysis between two data matrices, computed via SVD.

    Both inputs are orthonormalised with a thin SVD. The SVD of the cross-product of the
    two orthonormal bases then gives the canonical correlations (its singular values) and
    the canonical directions. The canonical correlations and their squares (both in
    percent), and the first canonical weight vector of each side, are printed.

    Args:
        x: array of shape (n_samples, n_x).
        y: array of shape (n_samples, n_y), rows aligned with ``x``.
        center: subtract the column means of ``x`` and ``y`` first (default True). CCA is
            defined on centred data; without centring, the constant (mean) direction
            shows up as a spurious perfect correlation whenever both inputs have
            non-zero means.

    Returns:
        (latent_embed, properties_embed): the canonical variates of ``x`` and ``y``. Each
        has shape (n_samples, k) with k = min(rank(x), rank(y)) and unit-norm columns.
        The inner product of column i of the two arrays is the i-th canonical correlation.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if center:
        x = x - np.mean(x, axis=0, keepdims=True)
        y = y - np.mean(y, axis=0, keepdims=True)

    u_x, s_x, vh_x = _thin_svd_ranked(x)
    u_y, s_y, vh_y = _thin_svd_ranked(y)
    print('u_x, u_y shapes', u_x.shape, u_y.shape)
    covar = np.matmul(np.transpose(u_x), u_y)
    u_xy, s_xy, vh_xy = la.svd(covar, full_matrices=False)

    # Canonical weights in the original coordinates: x @ a_cca == u_x @ u_xy and
    # y @ b_cca == u_y @ vh_xy.T.
    a_cca = np.matmul(np.transpose(vh_x), np.diag(1 / s_x))
    print('a_cca shape', a_cca.shape)
    a_cca = np.matmul(a_cca, u_xy)
    b_cca = np.matmul(np.transpose(vh_y), np.diag(1 / s_y))
    b_cca = np.matmul(b_cca, np.transpose(vh_xy))

    print('Canonical correlations:', s_xy * 100)
    print('Canonical variances:', s_xy ** 2 * 100)
    print('X canonical variable 1', a_cca[:, 0])
    print('Chemical properties canonical variable 1', b_cca[:, 0])

    latent_embed = np.matmul(u_x, u_xy)
    properties_embed = np.matmul(u_y, np.transpose(vh_xy))
    return latent_embed, properties_embed


print('Latent and Chemical Properties CCA')
latent_embed, properties_embed = compute_cca(latent_data_int_c, properties_nrm)

print('PCA and Chemical Properties CCA')
# Dataset labels index rows of the full per-molecule `properties` table (the same way
# `properties[latent_labels_id]` is used elsewhere), so the PCA sample's properties must
# be looked up in a full-table array. `properties_nrm` cannot be reused here: it has
# already been restricted to the test-set rows, so indexing it with these labels picks
# the wrong rows or goes out of range.
properties_sel_nrm = properties[:, idx]
properties_sel_nrm = (properties_sel_nrm - np.mean(properties_sel_nrm, axis=0, keepdims=True)) \
                     / (np.std(properties_sel_nrm, axis=0, keepdims=True) + 1E-12)
prop_pca_idx = properties_sel_nrm[train_data.dataset.dataset.labels[pca_idx]]
# Compare against as many principal components as there are latent dimensions.
pca_embed, properties_embed_pca = compute_cca(u_pca[:, :latent_data_int_c.shape[1]], prop_pca_idx)

# Scatter the first two canonical variates of each side, both coloured by the raw value
# of property column `prop_viz` of each test sample.
prop_viz = 0

s = 10
cmap = 'nipy_spectral'
latent_labels_id = np.array(latent_labels_id)
colors_prop = properties[latent_labels_id][:, prop_viz]
colors = latent_labels_id
colors_a = latent_embed[:, 2]
colors_b = properties_embed[:, 2]

fig, ax = plt.subplots(1, 2, figsize=[12.8, 4.8])
panel_specs = (
    (latent_embed, 'Latent Embedding'),
    (properties_embed, 'Chemical Properties Embedding'),
)
panel_scatters = []
for panel_ax, (embedding, panel_title) in zip(ax, panel_specs):
    panel_scatters.append(panel_ax.scatter(embedding[:, 0], embedding[:, 1], c=colors_prop, cmap=cmap, s=s))
    panel_ax.set_xlabel('Canonical Variable 1')
    panel_ax.set_ylabel('Canonical Variable 2')
    panel_ax.set_title(panel_title)
scatter_lat, scatter_prop = panel_scatters

# Both panels share the same colour values, so one colour bar suffices.
plt.colorbar(scatter_lat)

plt.show()

print('Done')
