import argparse
import torch
import os
import numpy as np
import time
from utils import utils
from models import MVAE
import data
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from tqdm import tqdm
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from matplotlib.animation import FuncAnimation
import matplotlib.animation as animation
import matplotlib
from scipy.sparse.csgraph import connected_components
import vtkplotter as vtk


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` treats every non-empty string (including
    'False') as True, so an explicit parser is needed for flags like
    ``--use_cuda False`` to work.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r.' % value)


parser = argparse.ArgumentParser(description='Parameters for training the PTC network.')
parser.add_argument('--epochs', help='The number of epochs to train the network. (default: 100)', type=int, default=100)
parser.add_argument('--batch_size', help='The batch size to be used during training. (default: 64)', type=int,
                    default=64)

parser.add_argument('--dir_data', type=str, default='data', help='Directory to the datasets.')
parser.add_argument('--dataset', type=str, default='ace2', help='The dataset to train the network on.')  # da_10906555

parser.add_argument('--n_latent', type=int, default=[2, 16, 32], nargs='+',
                    help='Latent space dimensions. (default: {[z_x], [z_c, z_n]})')
# The kernel-list defaults below are filled in after parsing (base list scaled by 4 // n_heads).
parser.add_argument('--bond_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the bond encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--dihedrals_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the dihedral encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--ca_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the CA encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--xyz_dec_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the xyz decoder. '
                         '(default: [128, 128, 64, 32, 16] scaled by 4 // n_heads)')
parser.add_argument('--n_heads', type=int, default=4, help='Number of heads of the attention mechanism in a GAT layer. '
                                                           '(default: 4)')
parser.add_argument('--conv_type', type=str, default='gat', help='The non-Euclidean convolution to use.')
parser.add_argument('--kernel_rad', type=float, default=2.5, help='The radius of the convolutional kernel on the '
                                                                    'surface. (default: 2.5A)')  # 5E-2 / 2E-2
parser.add_argument('--seed', help='The random seed for execution.', type=int, default=0)
parser.add_argument('--model_name', help='The name associated with the model with saving results.', type=str,
                    default=None)
parser.add_argument('--model_vers', help='Model version to append to name', type=str, default='fin3_reg')
parser.add_argument('--model_dir', help='The directory for saving the results.', type=str, default='./models')
parser.add_argument('--use_cuda', type=_str2bool, default=True, help='Train on GPU if available.')

args = parser.parse_args()
# Checkpoints live in a per-dataset subdirectory.
args.model_dir += '/' + args.dataset
if args.model_name is None:
    # e.g. '50_fin3_reg_s00': total latent size, version tag, seed.
    args.model_name = '%02d_%s_s%02d' % (sum(args.n_latent), args.model_vers, args.seed)

blocks = 1
num_channels = [None] * 2
transform_flag = None


def _default_krn(base_krn):
    """Return the default kernel-width list derived from ``base_krn``.

    Each width is multiplied by ``4 // args.n_heads`` (presumably to keep the
    concatenated multi-head output width roughly constant — confirm), then
    every entry is repeated ``blocks`` times in place. With ``blocks == 1``
    the scaled list is returned unchanged.
    NOTE(review): for n_heads > 4 the factor is 0, giving zero-width layers.
    """
    scaled = [width * (4 // args.n_heads) for width in base_krn]
    # zip(*[scaled] * blocks) groups `blocks` copies of each width; flattening
    # yields [w0] * blocks + [w1] * blocks + ...
    return [width for group in zip(*[scaled] * blocks) for width in group]


# Default architectures. Each list falls back independently so that overriding
# one kernel list on the command line neither discards nor breaks the others.
if args.bond_enc_krn is None:
    args.bond_enc_krn = _default_krn([12, 24, 48, 96, 96])
if args.dihedrals_enc_krn is None:
    args.dihedrals_enc_krn = _default_krn([12, 24, 48, 96, 96])
if args.ca_enc_krn is None:
    args.ca_enc_krn = _default_krn([12, 24, 48, 96, 96])
if args.xyz_dec_krn is None:
    args.xyz_dec_krn = _default_krn([128, 128, 64, 32, 16])

# Prepend the first channel count of each encoder (presumably the input
# feature width: 1 for bonds, 3 for dihedrals, 1 for CA — confirm).
args.bond_enc_krn.insert(0, 1)
args.dihedrals_enc_krn.insert(0, 3)
args.ca_enc_krn.insert(0, 1)
# Order matters: generate_protein uses inference_net[1] for CA and [2] for dihedrals.
num_channels[0] = [args.bond_enc_krn, args.ca_enc_krn, args.dihedrals_enc_krn]

# Decoder ends with 3 channels (xyz coordinates).
args.xyz_dec_krn.append(3)
num_channels[1] = args.xyz_dec_krn

# Use the first GPU only when CUDA is present AND the user asked for it.
use_cuda = torch.cuda.is_available() and args.use_cuda
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
print('Device: %s' % device)

# Seed the NumPy and PyTorch (CPU and, if used, CUDA) random generators.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if use_cuda:
    torch.cuda.manual_seed(args.seed)
# Let cuDNN autotune convolution algorithms (favors speed over bitwise determinism).
torch.backends.cudnn.benchmark = True

protein_dataset = data.Dataset(os.path.join(args.dir_data, args.dataset))
# Alternative split: protein_dataset.dataset_split_random()
train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_traj_end(args.dataset)

# Indices of the test samples to visualise. Other presets that were used:
# [0, 600, 1200], [0, 110, 220, 330], [400, 450, 500, 550], [0, 60, 120, 180].
idx = [0, 800, 1600]
bond_len = [[] for _ in idx]
dihedrals = [[] for _ in idx]
ca_lens = [[] for _ in idx]
coords = [[] for _ in idx]
labels = [[] for _ in idx]

for i, sample_idx in enumerate(idx):
    # Swap test_dataset for train_dataset to visualise training samples instead.
    bond_len[i], dihedrals[i], ca_lens[i], labels[i], coords[i] = test_dataset[sample_idx]
    # bond_len[i], dihedrals[i], ca_lens[i], labels[i], coords[i] = train_dataset.__getitem__(id)
# Downsampled point sets, presumably one per decoder level (count = len(num_channels[1])).
sampled_pts = utils.create_downsamples(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                       len(num_channels[1]))
# Kernel radius per level, scaled up by the ratio of full to sampled point counts.
n_points = protein_dataset.weighted_adj.shape[0]
rad = [(n_points / len(s)) * args.kernel_rad for s in sampled_pts]

dist = utils.compute_distance_mat(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                  max_distance=rad[-1]*0.9)
# Connected-component label per node; used later to colour segments.
_, seg_color = connected_components(dist)

# Extra constructor kwargs per convolution type. An explicit mapping replaces the
# former locals() lookup and fails with a readable message for unsupported types.
gat_kwargs = {'n_heads': args.n_heads}
conv_kwargs = {'gat': gat_kwargs}
if args.conv_type not in conv_kwargs:
    raise ValueError('Unsupported --conv_type %r; expected one of %s.'
                     % (args.conv_type, sorted(conv_kwargs)))
kwargs = conv_kwargs[args.conv_type]
model = MVAE(sampled_pts, protein_dataset.weighted_adj, num_channels, rad, dist, args.n_latent, args.conv_type, device,
             ignore_int_fine=True, **kwargs)

# Restore the trained weights for the requested epoch and report load time.
t_load = time.time()
ckpt_path = '%s/ckpt_%s-%d.pt' % (args.model_dir, args.model_name, args.epochs)
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict['model_state'])
print('Model loaded: %0.2fs' % (time.time() - t_load))

# Place every encoder and the decoder on the target device explicitly.
# (For multi-GPU runs these were wrapped with utils.MyDataParallel here.)
for k, net in enumerate(list(model.inference_net)):
    model.inference_net[k] = net.to(device)
model.generative_net = model.generative_net.to(device)

# Inference only: move the whole model and switch off training-mode layers.
model = model.to(device)
model.eval()


def generate_protein(dihedrals, ca_lens, pts, gen_model):
    """Reconstruct coordinates from per-sample dihedral and CA inputs.

    Args:
        dihedrals: sequence of per-sample dihedral arrays; stacked into a batch.
        ca_lens: sequence of per-sample CA arrays; stacked into a batch.
        pts: sequence of per-sample ground-truth coordinate arrays.
        gen_model: model with ``inference_net`` (index 1 encodes ``ca_lens``,
            index 2 encodes ``dihedrals``; each returns a 4-tuple) and
            ``generative_net``.

    Returns:
        Tuple ``(pts, pts_new)``: the ground truth passed through
        ``protein_dataset.unnormalize_coordinates`` and the reconstruction
        passed through ``unnormalize_coordinates_torch`` and converted to NumPy.
    """
    batch_dih = np.stack(dihedrals)
    batch_ca = np.stack(ca_lens)
    batch_pts = np.stack(pts)
    with torch.no_grad():
        dih_t = torch.tensor(batch_dih, dtype=torch.float32, device=device)
        ca_t = torch.tensor(batch_ca, dtype=torch.float32, device=device)
        # Keep only the second encoder output (named mu, presumably the posterior mean).
        _, mu_ca, _, _ = gen_model.inference_net[1](ca_t)
        _, mu_dih, _, _ = gen_model.inference_net[2](dih_t)

        latent = torch.cat((mu_ca, mu_dih), dim=-1)
        recon = protein_dataset.unnormalize_coordinates_torch(gen_model.generative_net(latent))
        truth = protein_dataset.unnormalize_coordinates(batch_pts)

        recon = recon.cpu().numpy()
    return truth, recon


pts_true, pts_gen = generate_protein(dihedrals, ca_lens, coords, model)

# Row 0: ground truth, row 1: reconstruction (indexed [row, sample] when plotting).
pts_agg = np.stack([pts_true, pts_gen])

# Per-atom log reconstruction error. Coordinates are assumed (batch, 3, n_atoms),
# consistent with the .transpose() applied before plotting — TODO confirm.
pts_err = np.linalg.norm(pts_gen - pts_true, axis=1)
# NOTE(review): an exact reconstruction gives log(0) = -inf, which propagates into np.min below.
pts_err = np.log(pts_err)

edges = np.stack([protein_dataset.bond_idx0, protein_dataset.bond_idx1], axis=-1)
# Per-bond error: mean of the log-errors of its two endpoint atoms.
edge_error = (pts_err[:, edges[:, 0]] + pts_err[:, edges[:, 1]]) / 2

cmap_color = 'jet'
cmap = matplotlib.cm.ScalarMappable(cmap=cmap_color)
cmap_err = matplotlib.cm.ScalarMappable(cmap=cmap_color)
# RGB colour (alpha dropped) per node from its connected-component label.
seg_colorm = cmap.to_rgba(seg_color)[:, :-1]
ver_color = list(seg_colorm)
edge_color = list(seg_colorm[edges[:, 0]])

# Compare by value: `is` against a string literal tests identity, which fails for
# strings read from argv and raises a SyntaxWarning on Python 3.8+.
if args.dataset == "da_10906555":
    # Restrict to the last three segment labels.
    segs = np.unique(seg_color)[-3:]
    seg_idx0 = np.isin(seg_color, segs)
    seg_idx = np.arange(len(seg_idx0))[seg_idx0]
    seg_idx2 = np.intersect1d(protein_dataset.ca_idx_list, seg_idx)
    pts_err2 = pts_err[:, seg_idx2]

    edge_idx = np.isin(edges[:, 0], np.arange(len(seg_idx0))[seg_idx0])
    edge_error2 = edge_error[:, edge_idx]
else:
    seg_idx = protein_dataset.ca_idx_list
    pts_err2 = pts_err[:, seg_idx]
    edge_error2 = edge_error

# Clip both point and bond errors to the bond-error range (capped at the 99th
# percentile) so they share one colour scale.
pts_err2 = np.clip(pts_err2, a_min=np.min(edge_error2), a_max=np.percentile(edge_error2, 99))
edge_error2 = np.clip(edge_error2, a_min=np.min(edge_error2), a_max=np.percentile(edge_error2, 99))
color_edge = cmap_err.to_rgba(edge_error2)[:, :, :-1]
color_err = cmap_err.to_rgba(pts_err2)[:, :, :-1]


def get_mesh(pts_plot, color_error=False, ver_color=None, edge_color=None):
    """Build vtkplotter actors (CA points plus bond lines) for one structure.

    Args:
        pts_plot: coordinate array indexed by atom along its first axis
            (assumed (n_atoms, 3) — confirm).
        color_error: if True, colour points/bonds with ``ver_color`` /
            ``edge_color`` instead of by segment.
        ver_color: per-point colours, used only when ``color_error`` is True.
        edge_color: per-bond colours, used only when ``color_error`` is True.

    Returns:
        List of actors: one ``vtk.Points`` followed by ``vtk.Lines`` (one per
        segment) or ``vtk.Line`` (one per bond, in error mode).

    Reads module-level ``args``, ``seg_color``, ``seg_colorm``, ``cmap``,
    ``edges``, ``edge_error2`` and ``protein_dataset``.
    """
    shapes = []

    # Value comparison; `is not` against a literal was always True, so the
    # segmented-dataset branch below was never taken.
    if args.dataset != "da_10906555":
        segs = np.unique(seg_color)
        seg_idx = protein_dataset.ca_idx_list
        if not color_error:
            mesh = vtk.Points(pts_plot[seg_idx], c='red', r=3)
        else:
            ver_color = list(ver_color)
            mesh = vtk.Points(pts_plot[seg_idx], c=ver_color, r=3).pointColors(ver_color, cmap='jet', vmin=np.min(edge_error2), vmax=np.max(edge_error2))
        shapes += [mesh]
    else:
        # Only the last three segment labels are drawn for this dataset.
        segs = np.unique(seg_color)[-3:]
        seg_idx = np.isin(seg_color, segs)
        seg_idx = np.arange(len(seg_idx))[seg_idx]
        seg_idx = np.intersect1d(protein_dataset.ca_idx_list, seg_idx)
        ver_color_seg = [seg_colorm[i] for i in seg_idx]
        if not color_error:
            mesh = vtk.Points(pts_plot[seg_idx], c=ver_color_seg)
        else:
            ver_color = list(ver_color)
            mesh = vtk.Points(pts_plot[seg_idx], c=ver_color).pointColors(ver_color, cmap='jet', vmin=np.min(edge_error2), vmax=np.max(edge_error2))
        shapes += [mesh]

    for seg in segs:
        seg_idx = seg_color == seg
        # Bonds whose first atom belongs to this segment.
        edge_idx = np.isin(edges[:, 0], np.arange(len(seg_idx))[seg_idx])
        pts_start = pts_plot[edges[edge_idx, 0]]
        pts_end = pts_plot[edges[edge_idx, 1]]
        edge_col = cmap.to_rgba(seg)[:-1]
        if not color_error:
            line = vtk.Lines(pts_start, pts_end, c=edge_col, lw=2)
            shapes += [line]
        else:
            # NOTE(review): edge_color is not filtered by edge_idx, so zip pairs
            # this segment's bonds with the first colours of the full list.
            # Filtering needs care: in the da_10906555 case edge_color covers
            # only a subset of bonds (see edge_error2).
            edge_colerr = edge_color
            for pstart, pend, ecol in zip(pts_start, pts_end, edge_colerr):
                line = vtk.Line(pstart, pend, c=ecol, lw=2)
                shapes += [line]
    return shapes


# One subplot per (row, sample): row 0 is the ground truth, row 1 the
# reconstruction. np.ndindex walks the grid in row-major order.
n_rows, n_cols = pts_agg.shape[:2]
shape_list = []
for i, j in np.ndindex(n_rows, n_cols):
    shape_list.append(get_mesh(pts_agg[i, j].transpose()))
    # Error-coloured variant for the reconstruction row (currently disabled):
    # get_mesh(pts_agg[i, j].transpose(), color_error=True,
    #          ver_color=color_err[j], edge_color=color_edge[j])

print('Error in alpha carbon bond lengths, to Source 0')

# Alternative: vtk.Plotter(shape=(n_rows, n_cols), axes=0, bg='white') plus
# addScalarBar(vmin=np.min(edge_error2), vmax=np.max(edge_error2)) for the error view.
vtk.show(shape_list, shape=(n_rows, n_cols), bg='white', axes=0)

print('Done')
