import argparse
import torch
import os
import numpy as np
import time
from utils import utils
from models import MVAE
import data
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from tqdm import tqdm
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from matplotlib.animation import FuncAnimation
import matplotlib.animation as animation
import matplotlib
from scipy.sparse.csgraph import connected_components


def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` is not usable with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line.

    Args:
        value: The raw token (or an actual bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r' % (value,))


# Command-line interface. The kernel-list options default to None here; the
# actual default architectures are filled in after parsing.
parser = argparse.ArgumentParser(description='Parameters for training the PTC network.')
parser.add_argument('--epochs', help='The number of epochs to train the network. (default: 100)', type=int, default=100)
parser.add_argument('--batch_size', help='The batch size to be used during training. (default: 32)', type=int,
                    default=32)

parser.add_argument('--dir_data', type=str, default='data', help='Directory to the datasets.')
parser.add_argument('--dataset', type=str, default='da_10906555', help='The dataset to train the network on.')
# parser.add_argument('--ref_surf', type=int, default=16000, help='The reference mesh to be used during training. '
#                                                                 '(default: [16000, 4])') #16000, 4

parser.add_argument('--n_latent', type=int, default=[2, 16, 32], nargs='+',
                    help='Latent space dimensions. (default: [2, 16, 32])')
parser.add_argument('--bond_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the bond encoder. '
                         '(default: [12, 24, 48, 96, 96], scaled by 4 // n_heads)')
parser.add_argument('--dihedrals_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the dihedrals encoder. '
                         '(default: [12, 24, 48, 96, 96], scaled by 4 // n_heads)')
parser.add_argument('--ca_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the ca encoder. '
                         '(default: [12, 24, 48, 96, 96], scaled by 4 // n_heads)')
parser.add_argument('--xyz_dec_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the xyz decoder. '
                         '(default: [128, 128, 64, 32, 16], scaled by 4 // n_heads)')
parser.add_argument('--n_heads', type=int, default=4, help='Number of heads of the attention mechanism in a GAT layer. '
                                                           '(default: 4)')
parser.add_argument('--conv_type', type=str, default='gat', help='The non-Euclidean convolution to use.')
parser.add_argument('--kernel_rad', type=float, default=2.5, help='The radius of the convolutional kernel on the '
                                                                    'surface. (default: 2.5A)')  #5E-2 / 2E-2
parser.add_argument('--seed', help='The random seed for execution.', type=int, default=1)
parser.add_argument('--model_name', help='The name associated with the model with saving results.', type=str,
                    default=None)
parser.add_argument('--model_vers', help='Model version to append to name', type=str, default='fin3_reg')
parser.add_argument('--model_dir', help='The directory for saving the results.', type=str, default='./models')
# _str2bool (not bool) so that "--use_cuda False" really disables CUDA.
parser.add_argument('--use_cuda', type=_str2bool, default=True, help='Train on GPU if available.')

# Parse the command line, nest the model directory under the dataset name and,
# when no explicit name was given, build one from the total latent size, the
# model version and the seed.
args = parser.parse_args()
args.model_dir = args.model_dir + '/' + args.dataset
if args.model_name is None:
    name_fields = (sum(args.n_latent), args.model_vers, args.seed)
    args.model_name = '%02d_%s_s%02d' % name_fields

blocks = 1
num_channels = [None] * 2
transform_flag = None


def _default_kernels(base_widths, n_heads, n_blocks):
    """Build a default per-layer kernel list.

    Each base width is multiplied by ``4 // n_heads`` and repeated ``n_blocks``
    times in a row.

    NOTE: ``4 // n_heads`` is 0 for ``n_heads > 4``, which yields zero-width
    layers; that behaviour is kept unchanged here.
    """
    scale = 4 // n_heads
    return [width * scale for width in base_widths for _ in range(n_blocks)]


# Default architectures. Each list is defaulted independently, so a kernel list
# given on the command line is never overwritten, and supplying only some of
# the lists no longer leaves the others as None.
if args.bond_enc_krn is None:
    args.bond_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.dihedrals_enc_krn is None:
    args.dihedrals_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.ca_enc_krn is None:
    args.ca_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.xyz_dec_krn is None:
    args.xyz_dec_krn = _default_kernels([128, 128, 64, 32, 16], args.n_heads, blocks)

# Prepend the input channel count of each encoder (1, 3 and 1 respectively).
args.bond_enc_krn.insert(0, 1)
args.dihedrals_enc_krn.insert(0, 3)
args.ca_enc_krn.insert(0, 1)
# Encoder order expected downstream: bond, ca, dihedrals.
num_channels[0] = [args.bond_enc_krn, args.ca_enc_krn, args.dihedrals_enc_krn]

# The decoder ends with 3 output channels.
args.xyz_dec_krn.append(3)
num_channels[1] = args.xyz_dec_krn

# Use the first CUDA device only when CUDA is available and was requested.
use_cuda = torch.cuda.is_available() and args.use_cuda
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
print('Device: %s' % device)

# Seed NumPy and PyTorch (CPU, plus CUDA when in use) from --seed, and enable
# the cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(args.seed)
if use_cuda:
    torch.cuda.manual_seed(args.seed)

# Load the dataset and split it into train/validation/test subsets with
# dataset_split_traj_end (protein_dataset.dataset_split_random() is the
# alternative that was used previously).
protein_dataset = data.Dataset(os.path.join(args.dir_data, args.dataset))
train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_traj_end(args.dataset)

# All loaders share the same settings; only the training loader shuffles.
loader_options = dict(batch_size=args.batch_size, num_workers=4, pin_memory=True)
train_data = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_options)
val_data = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_options)
test_data = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_options)

drug_class = 0

# Stored coordinates of the test subset with axes 1 and 2 swapped (they are
# left normalised; protein_dataset.unnormalize_coordinates is not applied).
coords = np.swapaxes(test_dataset.dataset.coords[test_dataset.indices], 1, 2)

# One set of sampled points per entry of the decoder channel list.
sampled_pts = utils.create_downsamples(
    protein_dataset.weighted_adj, args.dir_data, args.dataset, len(num_channels[1]))

# Kernel radius per level: grows linearly with the downsampling ratio.
n_full = protein_dataset.weighted_adj.shape[0]
rad = [(n_full / len(s)) * args.kernel_rad for s in sampled_pts]

# Distance matrix limited to 90% of the last level's radius; its connected
# components give one integer label per node, used later to colour the bonds.
dist = utils.compute_distance_mat(
    protein_dataset.weighted_adj, args.dir_data, args.dataset, max_distance=rad[-1]*0.9)
_, seg_color = connected_components(dist)

# Extra MVAE keyword arguments, keyed by convolution type. An explicit table
# replaces the former ``locals()['<conv_type>_kwargs']`` lookup, which raised
# an opaque KeyError for any convolution type other than 'gat'.
gat_kwargs = {'n_heads': args.n_heads}
conv_kwargs = {'gat': gat_kwargs}
if args.conv_type not in conv_kwargs:
    raise ValueError('Unsupported --conv_type %r; expected one of: %s'
                     % (args.conv_type, ', '.join(sorted(conv_kwargs))))
kwargs = conv_kwargs[args.conv_type]
model = MVAE(sampled_pts, protein_dataset.weighted_adj, num_channels, rad, dist, args.n_latent, args.conv_type, device,
             ignore_int_fine=True, **kwargs)

# Restore the weights saved at epoch --epochs for this model name.
t = time.time()
ckpt_path = '%s/ckpt_%s-%d.pt' % (args.model_dir, args.model_name, args.epochs)
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict['model_state'])

print('Model loaded: %0.2fs' % (time.time() - t))

# Move every encoder, the decoder and the wrapping model to the target device,
# then switch to evaluation mode.
for idx, net in enumerate(model.inference_net):
    model.inference_net[idx] = net.to(device)
model.generative_net = model.generative_net.to(device)

model = model.to(device)
model.eval()

criterion = utils.loss_func_cfan

# Reconstruct every test batch without tracking gradients. All three encoders
# are run, but only the latent means of encoders 1 and 2 are concatenated and
# decoded; the output of encoder 0 is not used here.
with torch.no_grad():
    latent_data_int = latent_sigma_int = None
    latent_data_ext = latent_sigma_ext = None
    pts_arr = []
    for signal_int_fine, signal_ext, signal_int_course, _, _ in tqdm(test_data):
        signal_int_fine, signal_int_course, signal_ext = (
            batch.to(device) for batch in (signal_int_fine, signal_int_course, signal_ext))

        _, mu_int_f, _, sigma_int_f = model.inference_net[0](signal_int_fine)
        _, mu_int_c, _, sigma_int_c = model.inference_net[1](signal_int_course)
        _, mu_ext, _, sigma_ext = model.inference_net[2](signal_ext)

        z = torch.cat((mu_int_c, mu_ext), dim=-1)
        decoded = protein_dataset.unnormalize_coordinates_torch(model.generative_net(z))
        pts = decoded.detach().cpu().numpy()
        pts_arr.append(pts)

# Join the per-batch outputs along axis 0, then swap axes 1 and 2 (the same
# swap that is applied to `coords`).
pts_arr = np.swapaxes(np.concatenate(pts_arr, axis=0), 1, 2)

# Bond connectivity: one (index0, index1) pair per bond.
edges = np.stack([protein_dataset.bond_idx0, protein_dataset.bond_idx1], axis=-1)

cmap_color = 'jet'
# Two side-by-side 3D axes: reconstruction on the left, ground truth on the right.
fig = plt.figure(figsize=[12.8, 4.8])
ax1, ax2 = (fig.add_subplot(1, 2, col, projection='3d') for col in (1, 2))

s = 10

# Maps the connected-component label of each bond's first endpoint to a colour.
cmap = matplotlib.cm.ScalarMappable(cmap=cmap_color)

# First frame of the reconstruction.
ax1.set_title('Trajectory Reconstruction')
lc = Line3DCollection(pts_arr[0, edges], colors=cmap.to_rgba(seg_color[edges[:, 0]]))
ax1.add_collection3d(lc)

# First frame of the ground truth.
ax2.set_title('Trajectory Ground Truth')
lc_2 = Line3DCollection(coords[0, edges], colors=cmap.to_rgba(seg_color[edges[:, 0]]), cmap=cmap_color)
ax2.add_collection3d(lc_2)

# Fix each axis range to the extent of its data over all frames, so the view
# stays put during the animation.
for axis, cloud in ((ax1, pts_arr), (ax2, coords)):
    for dim, set_lim in enumerate((axis.set_xlim, axis.set_ylim, axis.set_zlim)):
        set_lim(cloud[:, :, dim].min(), cloud[:, :, dim].max())

# One animation frame per reconstructed frame (no interpolated frames).
num_int_frames = 1
num_frames = num_int_frames * pts_arr.shape[0]
# samp_frame = np.arange(0, num_frames, num_int_frames)
# upsamp_frame = interp1d(samp_frame, principalComponents0_ext_traj[:, :2], axis=0)


def update(val):
    """Redraw both 3D axes for animation frame ``val``.

    Replaces the line collection on ``ax1`` with the reconstructed bonds of
    frame ``val`` and the one on ``ax2`` with the ground-truth bonds of the
    same frame.

    Args:
        val: Frame index into the first axis of ``pts_arr`` and ``coords``.

    Returns:
        A tuple ``(lc, lc_2, ax1, ax2)``: the two new line collections followed
        by the two axes.
    """
    # The bond colours do not depend on the frame; compute them once per call.
    seg_rgba = cmap.to_rgba(seg_color[edges[:, 0]])
    lc = Line3DCollection(pts_arr[val, edges], colors=seg_rgba, cmap=cmap_color)
    lc_2 = Line3DCollection(coords[val, edges], colors=seg_rgba, cmap=cmap_color)
    for axis, collection in ((ax1, lc), (ax2, lc_2)):
        # ``axis.collections = []`` raises AttributeError on Matplotlib >= 3.5,
        # where the attribute is a read-only view; remove the artists instead.
        for stale in list(axis.collections):
            stale.remove()
        axis.add_collection3d(collection)
    return lc, lc_2, ax1, ax2


# Build the animation (one call to `update` per frame), encode it to MP4 at
# 30 fps with the registered ffmpeg writer, then show the figure.
print('Generating animation')
frame_count = num_frames - (num_int_frames - 1)
ani = FuncAnimation(fig, update, frames=frame_count, blit=False)
writer = animation.writers['ffmpeg'](fps=30)
ani.save('protein_animation_testf.mp4', writer=writer)
plt.show()

print('Done')
