import argparse
import torch
import os
import numpy as np
import numpy.linalg as la
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
from sklearn.metrics import davies_bouldin_score
import time
from utils import utils, cca
from models import MVAE
import data
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from matplotlib.widgets import Slider, Button, RadioButtons
from tqdm import tqdm
from matplotlib.animation import FuncAnimation
import matplotlib.animation as animation
import matplotlib.tri as tri
from sklearn.manifold import SpectralEmbedding, TSNE
from scipy.sparse.csgraph import connected_components


def _str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``'False'``) into ``True``, so the text is interpreted explicitly here.

    Args:
        value: The raw command-line string, or an already-parsed bool.

    Returns:
        bool: ``True`` for 'true'/'t'/'yes'/'y'/'1' and ``False`` for
        'false'/'f'/'no'/'n'/'0' (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r.' % value)


parser = argparse.ArgumentParser(description='Parameters for training the PTC network.')
# The epoch number selects which checkpoint file is loaded (ckpt_<model_name>-<epochs>.pt).
parser.add_argument('--epochs', help='The number of epochs the network was trained for; selects the checkpoint to load. '
                                     '(default: 100)', type=int, default=100)
parser.add_argument('--batch_size', help='The batch size to be used during training. (default: 32)', type=int,
                    default=32)

parser.add_argument('--dir_data', type=str, default='data', help='Directory to the datasets.')
parser.add_argument('--dataset', type=str, default='ace2', help='The dataset to train the network on.')  # da_10906555

parser.add_argument('--n_latent', type=int, default=[2, 16, 32], nargs='+',
                    help='Latent space dimensions. (default: [2, 16, 32])')
parser.add_argument('--bond_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the bond encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--dihedrals_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the dihedrals encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--ca_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the CA encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--xyz_dec_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the xyz decoder. '
                         '(default: [128, 128, 64, 32, 16] scaled by 4 // n_heads)')
parser.add_argument('--n_heads', type=int, default=4, help='Number of heads of the attention mechanism in a GAT layer. '
                                                           '(default: 4)')
parser.add_argument('--conv_type', type=str, default='gat', help='The non-Euclidean convolution to use.')
parser.add_argument('--kernel_rad', type=float, default=2.5, help='The radius of the convolutional kernel on the '
                                                                    'surface. (default: 2.5A)')
parser.add_argument('--seed', help='The random seed for execution.', type=int, default=1)
parser.add_argument('--model_name', help='The name associated with the model with saving results.', type=str,
                    default=None)
parser.add_argument('--model_vers', help='Model version to append to name', type=str, default='fin3_fixed_4head')
parser.add_argument('--model_dir', help='The directory for saving the results.', type=str, default='./models')
# type=bool would treat the string 'False' as True; parse the text explicitly instead.
parser.add_argument('--use_cuda', type=_str2bool, default=True, help='Train on GPU if available.')

args = parser.parse_args()
# Checkpoints live in a per-dataset subdirectory of model_dir.
args.model_dir += '/' + args.dataset
if args.model_name is None:
    # Default name: total latent size, version tag and seed, e.g. '50_fin3_fixed_4head_s01'.
    args.model_name = '%02d_%s_s%02d' % (sum(args.n_latent), args.model_vers, args.seed)

blocks = 1
num_channels = [None] * 2
transform_flag = None


def _default_kernels(base, n_heads, n_blocks):
    """Build a default per-layer kernel-width list for one sub-network.

    Every width in ``base`` is multiplied by ``4 // n_heads`` and then repeated
    ``n_blocks`` times in place (e.g. ``[a, b]`` with 2 blocks gives ``[a, a, b, b]``).

    Args:
        base: Base kernel widths, one per layer.
        n_heads: Number of attention heads (``--n_heads``).
        n_blocks: How many times each layer is repeated.

    Returns:
        list[int]: The scaled and repeated widths.

    Raises:
        ValueError: If ``n_heads`` is outside 1..4. Larger values would make
            ``4 // n_heads`` zero and give zero-width layers.
    """
    if not 1 <= n_heads <= 4:
        raise ValueError('The default architecture requires 1 <= n_heads <= 4 (got %d); '
                         'pass explicit kernel lists instead.' % n_heads)
    scale = 4 // n_heads
    return [width * scale for width in base for _ in range(n_blocks)]


# Default architectures: fill in every sub-network whose widths were not given on the
# command line. Each one is checked separately, so passing only some of them does not
# leave the others as None.
if args.bond_enc_krn is None:
    args.bond_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.dihedrals_enc_krn is None:
    args.dihedrals_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.ca_enc_krn is None:
    args.ca_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.xyz_dec_krn is None:
    args.xyz_dec_krn = _default_kernels([128, 128, 64, 32, 16], args.n_heads, blocks)

# Prepend the leading channel count of each encoder (presumably the per-node input
# feature size: 1 for bonds and CA, 3 for dihedrals - confirm against data.Dataset).
args.bond_enc_krn.insert(0, 1)
args.dihedrals_enc_krn.insert(0, 3)
args.ca_enc_krn.insert(0, 1)
num_channels[0] = [args.bond_enc_krn, args.ca_enc_krn, args.dihedrals_enc_krn]

# The xyz decoder ends with 3 output channels.
args.xyz_dec_krn.append(3)
num_channels[1] = args.xyz_dec_krn

use_cuda = torch.cuda.is_available() and args.use_cuda
device = torch.device("cuda:0" if use_cuda else "cpu")
print('Device: %s' % device)

# Seed NumPy and PyTorch (CPU and, when used, CUDA) for repeatability.
np.random.seed(args.seed)
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
if use_cuda:
    torch.cuda.manual_seed(args.seed)

protein_dataset = data.Dataset(os.path.join(args.dir_data, args.dataset))
# data.Dataset also offers dataset_split_random() and dataset_split() as alternative splits.
train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_traj_end(args.dataset)
train_data = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                         pin_memory=True)
val_data = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,
                                       pin_memory=True)
test_data = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,
                                        pin_memory=True)

# One downsampled point set is requested per decoder layer. The kernel radius at each level
# is kernel_rad scaled linearly by (number of full-graph nodes / number of sampled nodes).
sampled_pts = utils.create_downsamples(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                       len(num_channels[1]))
rad = [(protein_dataset.weighted_adj.shape[0] / len(s)) * args.kernel_rad for s in sampled_pts]
dist = utils.compute_distance_mat(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                  max_distance=rad[-1]*0.9)

# Extra keyword arguments forwarded to MVAE for each convolution type. An explicit table
# replaces the old locals()['<type>_kwargs'] lookup; unlisted types get no extra kwargs.
gat_kwargs = {'n_heads': args.n_heads}
conv_kwargs = {'gat': gat_kwargs}
kwargs = conv_kwargs.get(args.conv_type, {})
model = MVAE(sampled_pts, protein_dataset.weighted_adj, num_channels, rad, dist, args.n_latent, args.conv_type, device,
             True, **kwargs)

# Restore the weights saved after `args.epochs` training epochs and time the load.
t = time.time()
ckpt_path = '%s/ckpt_%s-%d.pt' % (args.model_dir, args.model_name, args.epochs)
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict['model_state'])
print('Model loaded: %0.2fs' % (time.time() - t))

# Place every encoder and the decoder on the target device, then the model as a whole,
# and switch to evaluation mode for inference.
for net_idx, net in enumerate(list(model.inference_net)):
    model.inference_net[net_idx] = net.to(device)
model.generative_net = model.generative_net.to(device)

model = model.to(device)
model.eval()

criterion = utils.loss_func_mse

with torch.no_grad():
    # Encode the whole test split with the three inference networks and keep the
    # per-sample means and sigmas. Per-batch results are collected in lists and
    # concatenated once at the end; concatenating inside the loop would copy the
    # growing arrays on every batch (quadratic in the number of batches).
    mu_int_f_parts, sigma_int_f_parts = [], []
    mu_int_c_parts, sigma_int_c_parts = [], []
    mu_ext_parts, sigma_ext_parts = [], []
    latent_labels_id = []
    for signal_int_fine, signal_ext, signal_int_course, labels, _ in tqdm(test_data):
        signal_int_fine = signal_int_fine.to(device)
        signal_int_course = signal_int_course.to(device)
        signal_ext = signal_ext.to(device)
        # Each inference net returns four values; the 2nd and 4th are kept as mu and sigma.
        _, mu_int_f, _, sigma_int_f = model.inference_net[0](signal_int_fine)
        _, mu_int_c, _, sigma_int_c = model.inference_net[1](signal_int_course)
        _, mu_ext, _, sigma_ext = model.inference_net[2](signal_ext)
        # NOTE(review): assumes `labels` iterates to one label per sample - confirm.
        latent_labels_id += labels
        mu_int_f_parts.append(mu_int_f.cpu().numpy())
        sigma_int_f_parts.append(sigma_int_f.cpu().numpy())
        mu_int_c_parts.append(mu_int_c.cpu().numpy())
        sigma_int_c_parts.append(sigma_int_c.cpu().numpy())
        mu_ext_parts.append(mu_ext.cpu().numpy())
        sigma_ext_parts.append(sigma_ext.cpu().numpy())

    if not mu_int_f_parts:
        raise RuntimeError('The test split is empty; there are no latent codes to analyse.')
    latent_data_int_f = np.concatenate(mu_int_f_parts, axis=0)
    latent_sigma_int_f = np.concatenate(sigma_int_f_parts, axis=0)
    latent_data_int_c = np.concatenate(mu_int_c_parts, axis=0)
    latent_sigma_int_c = np.concatenate(sigma_int_c_parts, axis=0)
    latent_data_ext = np.concatenate(mu_ext_parts, axis=0)
    latent_sigma_ext = np.concatenate(sigma_ext_parts, axis=0)
def _zscore(latent):
    """Standardise each column of ``latent`` to zero mean and unit variance.

    Columns with zero spread are only centred. Dividing them by a zero standard
    deviation would produce NaNs, which PCA cannot handle.

    Args:
        latent: 2-D array, one row per sample.

    Returns:
        The standardised array, same shape as ``latent``.
    """
    centred = latent - np.mean(latent, axis=0, keepdims=True)
    std = np.std(centred, axis=0, keepdims=True)
    std[std == 0] = 1.0
    return centred / std


def _pca_with_cumvar(z_score):
    """Project ``z_score`` onto all of its principal components.

    The number of components is sized from this array itself, because the
    latent spaces have different dimensionalities.

    Returns:
        tuple: (projected samples, cumulative explained-variance ratio).
    """
    pca_model = PCA(n_components=min(z_score.shape))
    scores = pca_model.fit_transform(z_score)
    return scores, np.cumsum(pca_model.explained_variance_ratio_)


z_score_0_int_f = _zscore(latent_data_int_f)
print('Mean (Intrinsic Fine)', np.mean(latent_data_int_f, axis=0))
print('Sigma (Intrinsic Fine)', np.mean(latent_sigma_int_f, axis=0))

z_score_0_int_c = _zscore(latent_data_int_c)
print('Mean (Intrinsic Course)', np.mean(latent_data_int_c, axis=0))
print('Sigma (Intrinsic Course)', np.mean(latent_sigma_int_c, axis=0))

z_score_0_ext = _zscore(latent_data_ext)
print('Mean (Extrinsic)', np.mean(latent_data_ext, axis=0))
print('Sigma (Extrinsic)', np.mean(latent_sigma_ext, axis=0))

principalComponents0_int_f, cum_var = _pca_with_cumvar(z_score_0_int_f)
print('XYZ (Intrinsic Fine)', cum_var)

principalComponents0_int_c, cum_var = _pca_with_cumvar(z_score_0_int_c)
print('XYZ (Intrinsic Course)', cum_var)

principalComponents0_ext, cum_var = _pca_with_cumvar(z_score_0_ext)
print('XYZ (Extrinsic)', cum_var)

# Sort samples by label so the plots and the per-label animation see them in label order.
d_id = np.array(latent_labels_id)
idx = np.argsort(d_id)
d_id = d_id[idx]
principalComponents0_int_f = principalComponents0_int_f[idx]
principalComponents0_int_c = principalComponents0_int_c[idx]
principalComponents0_ext = principalComponents0_ext[idx]

# Scatter the first two principal components of each latent space, coloured by label.
cmap = 'jet'
s = 10  # base marker size; the fine-intrinsic panel uses 4 * s
fig, ax = plt.subplots(1, 3, figsize=[6.4 * 3, 4.8])

ax[0].set_xlabel('Principal Component 1')
ax[0].set_ylabel('Principal Component 2')
ax[0].set_title('Latent PCA, Fine Intrinsic')

# Caption updated by the animation callback. It is attached to ax[0] explicitly:
# plt.text() targets the current axes (the last subplot after plt.subplots) while
# being positioned in ax[0]'s axes coordinates.
tbox = ax[0].text(0.25, 0.05, "Drug Class", fontsize=11, transform=ax[0].transAxes)

# Optional property table whose rows are indexed by label. Column `spread_var` can
# replace the label as the colour, e.g. c=spread[d_id, spread_var]. The path is
# machine-specific, so the file is only loaded when it exists.
spread_path = '/home/laigpu/JoeyProjects/ugd_protein/data/da_10906555/chem_properties.csv'
spread = np.loadtxt(spread_path, delimiter=',') if os.path.exists(spread_path) else None
spread_var = 10  # 1, 4, 5 #2, 3, 8, ignore 7, 9, 11, 12, 13, 14

scat_int_f = ax[0].scatter(principalComponents0_int_f[:, 0], principalComponents0_int_f[:, 1], c=d_id, cmap=cmap, s=4*s)

ax[1].set_xlabel('Principal Component 1')
ax[1].set_ylabel('Principal Component 2')
ax[1].set_title('Latent PCA, Course Intrinsic')
scat_int_c = ax[1].scatter(principalComponents0_int_c[:, 0], principalComponents0_int_c[:, 1],
                           c=d_id, cmap=cmap, s=s)

scat_ext = ax[2].scatter(principalComponents0_ext[:, 0], principalComponents0_ext[:, 1], c=d_id, cmap=cmap, s=s)
ax[2].set_xlabel('Principal Component 1')
ax[2].set_ylabel('Principal Component 2')
ax[2].set_title('Latent PCA, Extrinsic')

plt.colorbar(scat_int_f)
# NOTE(review): with an interactive backend this blocks until the window is closed,
# after which the figure is likely gone and the animation set up later has nothing to show.
plt.show()

def update(val):
    """Show only the samples whose label equals ``val`` in all three panels.

    Intended as a widget callback (e.g. a label slider's ``on_changed``); the
    figure is redrawn afterwards.

    Args:
        val: Label value to display.
    """
    keep = d_id == val
    panels = ((scat_int_f, principalComponents0_int_f),
              (scat_int_c, principalComponents0_int_c),
              (scat_ext, principalComponents0_ext))
    for scat, components in panels:
        scat.set_offsets(components[keep, :2])
        # Keep the colour array aligned with the new offsets. Otherwise the kept
        # points are drawn with the colours of the first points of the full set.
        scat.set_array(d_id[keep])
    plt.draw()

def update2(val):
    """Animation frame: show only the samples labelled ``val`` and update the caption.

    Args:
        val: Label value for this frame.

    Returns:
        tuple: The modified artists, as required by FuncAnimation with ``blit=True``.
    """
    keep = d_id == val
    for scat, components in ((scat_int_f, principalComponents0_int_f),
                             (scat_int_c, principalComponents0_int_c),
                             (scat_ext, principalComponents0_ext)):
        scat.set_offsets(components[keep, :2])
        # Without this, the remaining points reuse the colours of the first points
        # of the full (label-sorted) data set instead of their own label's colour.
        scat.set_array(d_id[keep])
    # '%02d' zero-pads the label; the previous '% 02d' inserted a stray sign space.
    tbox.set_text("Drug Class %02d" % val)
    return scat_int_f, scat_int_c, scat_ext, tbox

# Animate through the labels actually present in the data, one frame per label,
# instead of a fixed range of 50 values that may skip labels or produce empty frames.
ani = FuncAnimation(fig, update2, frames=np.unique(d_id), blit=True)

# The ffmpeg writer is only needed for ani.save(). Look it up only when it is
# installed, because animation.writers['ffmpeg'] raises if it is missing.
writer = None
if animation.writers.is_available('ffmpeg'):
    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=2)
    # ani.save('pca_drug_class.mp4', writer=writer)
plt.show()

print('Done')
