import argparse
import torch
import os
import numpy as np
import time
from utils import utils
from models import MVAE
import data
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from tqdm import tqdm
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from matplotlib.animation import FuncAnimation
import matplotlib.animation as animation
import matplotlib
from scipy.sparse.csgraph import connected_components
import vtkplotter as vtk


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` cannot be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as "false" because bool('') was False before.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r.' % value)


# Command-line options. Help texts use %(default)s where the previously
# hand-written default had drifted away from the actual default value.
parser = argparse.ArgumentParser(description='Parameters for training the PTC network.')
parser.add_argument('--epochs', help='The number of epochs to train the network. (default: %(default)s)', type=int,
                    default=100)
parser.add_argument('--batch_size', help='The batch size to be used during training. (default: %(default)s)', type=int,
                    default=64)

parser.add_argument('--dir_data', type=str, default='data', help='Directory to the datasets.')
parser.add_argument('--dataset', type=str, default='da_10906555', help='The dataset to train the network on.') #da_10906555
# parser.add_argument('--ref_surf', type=int, default=16000, help='The reference mesh to be used during training. '
#                                                                 '(default: [16000, 4])') #16000, 4

parser.add_argument('--n_latent', type=int, default=[2, 16, 32], nargs='+',
                    help='Latent space dimensions. (default: {[z_x], [z_c, z_n]})')
parser.add_argument('--bond_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the conformal factor encoder. '
                         '(default: {[12, 24, 24, 48, _], [24, 24, 48, 48, _]})')
parser.add_argument('--dihedrals_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the normal encoder. '
                         '(default: {[12, 24, 24, 48, _], [24, 24, 48, 48, _]})')
parser.add_argument('--ca_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the normal encoder. '
                         '(default: {[12, 24, 24, 48, _], [24, 24, 48, 48, _]})')
parser.add_argument('--xyz_dec_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the xyz decoder. '
                         '(default: [64, 32, 32, 16, _])')
parser.add_argument('--n_heads', type=int, default=4, help='Number of heads of the attention mechanism in a GAT layer. '
                                                           '(default: %(default)s)')
parser.add_argument('--conv_type', type=str, default='gat', help='The non-Euclidean convolution to use.')
parser.add_argument('--kernel_rad', type=float, default=2.5, help='The radius of the convolutional kernel on the '
                                                                    'surface. (default: %(default)sA)')  #5E-2 / 2E-2
parser.add_argument('--seed', help='The random seed for execution.', type=int, default=0)
parser.add_argument('--model_name', help='The name associated with the model with saving results.', type=str,
                    default=None)
parser.add_argument('--model_vers', help='Model version to append to name', type=str, default='fin3_reg')
parser.add_argument('--model_dir', help='The directory for saving the results.', type=str, default='./models')
parser.add_argument('--use_cuda', type=_str2bool, default=True, help='Train on GPU if available.')

args = parser.parse_args()
# Results are stored in one sub-directory per dataset.
args.model_dir += '/' + args.dataset
if args.model_name is None:
    # Default name: <total latent size>_<version>_s<seed>.
    args.model_name = '%02d_%s_s%02d' % (sum(args.n_latent), args.model_vers, args.seed)

blocks = 1
num_channels = [None] * 2
transform_flag = None


def _default_kernels(base_widths):
    """Return the default layer widths derived from ``base_widths``.

    Every width is multiplied by ``4 // args.n_heads`` and then repeated
    ``blocks`` times in place (zipping the list with itself ``blocks`` times
    and flattening), which is the identity for ``blocks == 1``.
    """
    # NOTE(review): 4 // n_heads evaluates to 0 for n_heads > 4 — confirm
    # that such head counts are not expected here.
    scaled = [width * (4 // args.n_heads) for width in base_widths]
    return [val for group in zip(*[scaled] * blocks) for val in group]


# Default architectures.
# Alternative widths that were tried previously:
#   bond_enc_krn / dihedrals_enc_krn = [24, 48, 48, 96, 96]
#   xyz_dec_krn = [128, 128, 64, 64, 32]
_DEFAULT_KERNEL_WIDTHS = {
    'bond_enc_krn': [12, 24, 48, 96, 96],
    'dihedrals_enc_krn': [12, 24, 48, 96, 96],
    'ca_enc_krn': [12, 24, 48, 96, 96],
    'xyz_dec_krn': [128, 128, 64, 32, 16],
}
# Each option falls back to its default independently. Previously the defaults
# were applied only when --bond_enc_krn was missing, so passing just that
# option crashed on `None.insert(...)` below, and passing only one of the other
# options was silently overwritten.
for _opt_name, _base_widths in _DEFAULT_KERNEL_WIDTHS.items():
    if getattr(args, _opt_name) is None:
        setattr(args, _opt_name, _default_kernels(_base_widths))

# Prepend the number of input channels of each encoder.
args.bond_enc_krn.insert(0, 1)
args.dihedrals_enc_krn.insert(0, 3) #3)
args.ca_enc_krn.insert(0, 1)
num_channels[0] = [args.bond_enc_krn, args.ca_enc_krn, args.dihedrals_enc_krn]

# The decoder ends with 3 output channels.
args.xyz_dec_krn.append(3)
num_channels[1] = args.xyz_dec_krn

use_cuda = torch.cuda.is_available() and args.use_cuda
# The script targets the second GPU ("cuda:1"). On hosts that expose a single
# CUDA device that ordinal does not exist, so fall back to the first one.
_gpu_index = 1 if torch.cuda.device_count() > 1 else 0
device = torch.device("cuda:%d" % _gpu_index if use_cuda else "cpu")
print('Device: %s' % device)

# Seed every random number generator used by the script.
np.random.seed(args.seed)
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
if use_cuda:
    # manual_seed() only seeds the *current* CUDA device, which is not the
    # device selected above; seed all of them instead.
    torch.cuda.manual_seed_all(args.seed)

protein_dataset = data.Dataset(os.path.join(args.dir_data, args.dataset))
# train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_random()
train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_traj_end(args.dataset)
# train_data = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
#                                          pin_memory=True)
# val_data = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,
#                                        pin_memory=True)
# test_data = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,
#                                         pin_memory=True)

# Indices of the two test samples that are interpolated further below.
idx_0 = 0
idx_1 = 300 #2000

# Index the dataset directly instead of calling the dunder method by hand.
bond_len_0, dihedrals_0, ca_lens_0, label_0, coords_0 = test_dataset[idx_0]
bond_len_1, dihedrals_1, ca_lens_1, label_1, coords_1 = test_dataset[idx_1]

# One set of sampled points per entry of the decoder channel list.
sampled_pts = utils.create_downsamples(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                       len(num_channels[1]))
# Kernel radius per level: the base radius scaled by (total points / points in the level).
rad = [(protein_dataset.weighted_adj.shape[0] / len(s)) * args.kernel_rad for s in sampled_pts]

dist = utils.compute_distance_mat(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                  max_distance=rad[-1]*0.9)
# The connected-component label of every node is used as its segment id / colour.
_, seg_color = connected_components(dist)

# Extra constructor arguments per convolution type. An explicit table replaces
# the former `locals()['%s_kwargs' % ...]` name lookup; an unsupported
# --conv_type still raises KeyError exactly as before.
gat_kwargs = {'n_heads': args.n_heads}
_conv_kwargs = {'gat': gat_kwargs}
kwargs = _conv_kwargs[args.conv_type]
model = MVAE(sampled_pts, protein_dataset.weighted_adj, num_channels, rad, dist, args.n_latent, args.conv_type, device,
             ignore_int_fine=True, **kwargs)

# Restore the trained weights from the checkpoint of the requested epoch and
# report how long that took.
t = time.time()
ckpt_path = '%s/ckpt_%s-%d.pt' % (args.model_dir, args.model_name, args.epochs)
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict['model_state'])

print('Model loaded: %0.2fs' % (time.time() - t))

# Move every encoder, the decoder and finally the whole model to the target
# device, then switch to evaluation mode.
# (A utils.MyDataParallel wrapper with device_ids=[1, 2] was used here before.)
for idx in range(len(model.inference_net)):
    encoder = model.inference_net[idx]
    model.inference_net[idx] = encoder.to(device)
decoder = model.generative_net
model.generative_net = decoder.to(device)

model = model.to(device)
model.eval()


def generate_protein(dihedrals_in, ca_lens_in, gen_model, num=2):
    """Decode a ``num x num`` grid of latent interpolations between two samples.

    Both samples are encoded with ``gen_model.inference_net[1]`` (fed the
    ``ca_lens`` signal) and ``gen_model.inference_net[2]`` (fed the dihedral
    signal); the second value returned by each encoder is used as latent code.
    The two codes are interpolated linearly and independently of each other,
    concatenated along the last axis and decoded with
    ``gen_model.generative_net``.

    The module-level ``device`` and ``protein_dataset`` are used for tensor
    placement and for un-normalising the decoded coordinates.

    Args:
        dihedrals_in: Pair ``[sample_0, sample_1]`` of dihedral signals.
        ca_lens_in: Pair ``[sample_0, sample_1]`` of ``ca_lens`` signals.
        gen_model: Model exposing ``inference_net`` and ``generative_net``.
        num: Number of interpolation steps along each of the two axes.

    Returns:
        numpy array of shape ``(num, num, d1, d2)`` where ``(d1, d2)`` is the
        shape of one decoded sample. Entry ``[i, j]`` uses the ``i``-th weight
        for the ``ca_lens`` code and the ``j``-th weight for the dihedral code;
        weight index 0 reproduces sample 0 and, for ``num >= 2``, index
        ``num - 1`` reproduces sample 1.
    """

    def _encode(sample_idx):
        """Return the (ca_lens code, dihedral code) pair of one input sample."""
        dihedrals = torch.tensor(dihedrals_in[sample_idx], dtype=torch.float32, device=device).unsqueeze(0)
        ca_lens = torch.tensor(ca_lens_in[sample_idx], dtype=torch.float32, device=device).unsqueeze(0)
        _, mu_int_c, _, _ = gen_model.inference_net[1](ca_lens)
        _, mu_ext, _, _ = gen_model.inference_net[2](dihedrals)
        return mu_int_c, mu_ext

    # Plain Python floats keep the tensor arithmetic independent of how numpy
    # scalars interact with torch tensors.
    weights = [float(w) for w in np.linspace(1, 0, num)]

    pts_arr = []
    with torch.no_grad():
        mu_int_c_0, mu_ext_0 = _encode(0)
        mu_int_c_1, mu_ext_1 = _encode(1)

        for t_int in weights:
            mu_int = t_int * mu_int_c_0 + (1 - t_int) * mu_int_c_1
            for t_ext in weights:
                mu_ext = t_ext * mu_ext_0 + (1 - t_ext) * mu_ext_1
                z = torch.cat([mu_int, mu_ext], dim=-1)
                # Decode with the model that was passed in; the global `model`
                # was used here before, ignoring the `gen_model` argument.
                pts = gen_model.generative_net(z)
                pts = protein_dataset.unnormalize_coordinates_torch(pts)
                # Drop the batch axis that was added for the encoders.
                pts_arr.append(pts.cpu().numpy()[0])

    pts_arr = np.stack(pts_arr)
    return np.reshape(pts_arr, [num, num, pts_arr.shape[1], pts_arr.shape[2]])


# Side length of the interpolation grid.
num = 6
pts_arr = generate_protein(
    dihedrals_in=[dihedrals_0, dihedrals_1],
    ca_lens_in=[ca_lens_0, ca_lens_1],
    gen_model=model,
    num=num,
)

# One (start, end) index pair per bond.
edges = np.stack((protein_dataset.bond_idx0, protein_dataset.bond_idx1), axis=-1)

# Map every segment id to an RGB colour (the alpha channel is dropped).
cmap_color = 'jet'
cmap = matplotlib.cm.ScalarMappable(cmap=cmap_color)
seg_colorm = cmap.to_rgba(seg_color)[:, :-1]
# Per-vertex colours, and per-edge colours taken from each edge's start vertex.
ver_color = list(seg_colorm)
edge_color = list(seg_colorm[edges[:, 0]])


def get_mesh(pts_plot):
    """Build the vtk actors (one point cloud plus bond lines) for one structure.

    For every dataset except ``"da_10906555"`` all points and the bonds of all
    segments are drawn. For ``"da_10906555"`` only the last three segment ids
    are drawn, and the point cloud is further restricted to the indices listed
    in ``protein_dataset.ca_idx_list``.

    Reads the module-level ``args``, ``seg_color``, ``ver_color``, ``edges``,
    ``cmap`` and ``protein_dataset``.

    Args:
        pts_plot: Array of point coordinates, indexed by point along axis 0.

    Returns:
        List whose first element is the ``vtk.Points`` actor, followed by one
        ``vtk.Lines`` actor per drawn segment.
    """
    all_segs = np.unique(seg_color)

    # Compare by value: `is not` tested object identity, which is always true
    # for a dataset name that was typed on the command line.
    if args.dataset != "da_10906555":
        segs = all_segs
        shapes = [vtk.Points(pts_plot, c=ver_color)]
    else:
        segs = all_segs[-3:]
        in_segs = np.flatnonzero(np.isin(seg_color, segs))
        shown_idx = np.intersect1d(protein_dataset.ca_idx_list, in_segs)
        shown_colors = [ver_color[i] for i in shown_idx]
        shapes = [vtk.Points(pts_plot[shown_idx], c=shown_colors)]

    for seg in segs:
        # An edge belongs to a segment when its start vertex does.
        seg_vertices = np.flatnonzero(seg_color == seg)
        edge_idx = np.isin(edges[:, 0], seg_vertices)
        pts_start = pts_plot[edges[edge_idx, 0]]
        pts_end = pts_plot[edges[edge_idx, 1]]
        edge_col = cmap.to_rgba(seg)[:-1]
        shapes.append(vtk.Lines(pts_start, pts_end, c=edge_col, lw=2))
    return shapes


# One list of vtk actors per grid cell, stored row-major (index = i * num + j).
ca_signal = []
shape_list = [
    get_mesh(pts_arr[row, col].transpose())
    for row in range(num)
    for col in range(num)
]
# Disabled per-cell alpha-carbon signal computation:
# ca = protein_dataset.compute_ca_signal(torch.tensor(pts_arr[i, j]).unsqueeze(0))
# ca = protein_dataset.unnormalize_calpha_torch(ca)
# ca_signal += [ca]

print('Error in alpha carbon bond lengths, to Source 0')
# torch.sum(torch.abs(ca_signal[0] - ca_signal[num ** 2 - 1])).item() / np.count_nonzero(protein_dataset.ca_bonds_stats['std'])
# print('Int %d, Ext %d: %.3E' % (0, num - 1, torch.sum(torch.abs(ca_signal[0] - ca_signal[num - 1])).item()
#                                 / np.count_nonzero(protein_dataset.ca_bonds_stats['std'])))
# print('Int %d, Ext %d: %.3E' % (num - 1, 0, torch.sum(torch.abs(ca_signal[0]
#                                                                 - ca_signal[num ** 2 - 1 - (num - 1)])).item()
#                                 / np.count_nonzero(protein_dataset.ca_bonds_stats['std'])))
# print('Int %d, Ext %d: %.3E' % (num - 1, num - 1, torch.sum(torch.abs(ca_signal[0] - ca_signal[num ** 2 - 1])).item()
#                                 / np.count_nonzero(protein_dataset.ca_bonds_stats['std'])))
# print('Error in alpha carbon bond lengths, to Source 1')
# print('Int %d, Ext %d: %.3E' % (0, num - 1, torch.norm(ca_signal[num ** 2 - 1] - ca_signal[num - 1]).item()))
# print('Int %d, Ext %d: %.3E' % (num - 1, 0, torch.norm(ca_signal[num ** 2 - 1] - ca_signal[num ** 2 - 1 - (num - 1)]).item()))

# video = vtk.Video(name='test2.mp4', duration=5)
# # vp = vtk.Plotter(bg='white', interactive=False, offscreen=False)
# vp = vtk.Plotter(bg='white', interactive=False, offscreen=False)
# # vp.show(shape_list[0], resetcam=True, axes=0)
# i = 0

#
# def return_vid_list(m1, m2, m3):
#     m1 = vtk.Assembly(m1)
#     m2 = vtk.Assembly(m2)
#     m3 = vtk.Assembly(m3)
#     m1n = m1.pos(-2.5, 0, 0)
#     m2n = m2.pos(0, 0, 0)
#     m3n = m3.pos(2.5, 0, 0)
#     mlist = [m1n, m2n, m3n]
#     return mlist
#
#
# # mlist = return_vid_list(shape_list[num * i], shape_list[(num * i) + num // 2], shape_list[(num * i) + (num -1)])
# # vp.show(mlist, resetcam=True, axes=0)
# vp.show(shape_list[num * i + (num-1)], resetcam=True, axes=0)
# video.addFrame()
# for i in tqdm(range(1, num)):
#     # mlist = return_vid_list(shape_list[num * i], shape_list[(num * i) + num // 2], shape_list[(num * i) + (num - 1)])
#     # vp.show(mlist, resetcam=False, axes=0)
#     vp.show(shape_list[num * i + (num-1)], resetcam=False, axes=0)
#     video.addFrame()
# video.close()


# vtk.show(shape_list, bg='white', shape=(num, num))
# Show only the diagonal of the grid (both codes interpolated with the same
# weight) side by side in a single row.
shape_list2 = [shape_list[k * num + k] for k in range(num)]
vtk.show(shape_list2, bg='white', shape=(1, num), axes=0)

print('Done')
