import os
import sys
import numpy as np
import torch
import torch.optim as optim
import argparse
import time
from tqdm import tqdm
from tabulate import tabulate
import data
from models import MVAE
from utils import utils
# os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

def _str2bool(value):
    """Convert a command-line token into a bool.

    argparse applies ``type`` to the raw string, and ``bool('False')`` is True
    because every non-empty string is truthy. With ``type=bool`` a flag could
    therefore never be switched off via ``--flag False``.

    Args:
        value: The raw argument string (an actual bool is passed through).

    Returns:
        The parsed boolean. The empty string maps to False, as ``bool('')`` did.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r.' % value)


# Help texts use argparse's %(default)s placeholder so the advertised default can
# never drift away from the value that is actually used.
parser = argparse.ArgumentParser(description='Parameters for training the PTC network.')

# --- Schedule ---------------------------------------------------------------
# NOTE: --epochs is also used further down to pick the checkpoint file to load
# (ckpt_<model_name>-<epochs>.pt).
parser.add_argument('--epochs', type=int, default=100,
                    help='The number of epochs to train the network. (default: %(default)s)')
parser.add_argument('--save_freq', type=int, default=50,
                    help='The number of epochs between saving model checkpoints. (default: %(default)s)')
parser.add_argument('--val_epoch', type=int, default=1,
                    help='The number of epochs between validation runs. (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=64,
                    help='The batch size to be used during training. (default: %(default)s)')

# --- Data -------------------------------------------------------------------
parser.add_argument('--dir_data', type=str, default='data', help='Directory to the datasets.')
parser.add_argument('--dataset', type=str, default='ace2', help='The dataset to train the network on.')  #da_10906555
# parser.add_argument('--ref_surf', type=int, default=16000, help='The reference mesh to be used during training. '
#                                                                 '(default: [16000, 4])') #16000, 4

# --- Architecture -----------------------------------------------------------
# The *_krn options default to None; the built-in layer widths are filled in
# after parsing, which is why their defaults are spelled out in the help text.
parser.add_argument('--n_latent', type=int, nargs='+', default=[2, 16, 32],
                    help='Latent space dimensions. (default: {[z_x], [z_c, z_n]})')
parser.add_argument('--bond_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the bond encoder. '
                         '(default: [12, 24, 48, 96, 96], scaled by 4 / n_heads)')
parser.add_argument('--dihedrals_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the dihedrals encoder. '
                         '(default: [12, 24, 48, 96, 96], scaled by 4 / n_heads)')
parser.add_argument('--ca_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the ca encoder. '
                         '(default: [12, 24, 48, 96, 96], scaled by 4 / n_heads)')
parser.add_argument('--xyz_dec_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the xyz decoder. '
                         '(default: [128, 128, 64, 32, 16], scaled by 4 / n_heads)')
parser.add_argument('--n_heads', type=int, default=4,
                    help='Number of heads of the attention mechanism in a GAT layer. (default: %(default)s)')
parser.add_argument('--conv_type', type=str, default='gat', help='The non-Euclidean convolution to use.')
parser.add_argument('--kernel_rad', type=float, default=2.5,
                    help='The radius of the convolutional kernel on the surface. '
                         '(default: %(default)sA)')  #5E-2 / 2E-2

# --- Optimisation -----------------------------------------------------------
parser.add_argument('--lr', type=float, default=1E-3,
                    help='The initial learning rate for the optimizer. (default: %(default)s)')
parser.add_argument('--wd', type=float, default=5E-5, metavar='WD',
                    help='weight decay (default: %(default)s)')  # 0 for VGG since it has dropout
parser.add_argument('--lr_decay', type=float, default=0.995,
                    help='Learning Rate Decay for SGD. (default: %(default)s)')
parser.add_argument('--lr_drop', type=int, default=1000,
                    help='Number of epochs required to decay learning rate by 0.5')

parser.add_argument('--seed', type=int, default=0, help='The random seed for execution.')
parser.add_argument('--kl_lambda', type=float, default=0E-4,
                    help='The value of the kl divergence penalty term in the VAE objective function. '
                         '(default: %(default)s)')

# --- Naming and output locations --------------------------------------------
parser.add_argument('--model_name', type=str, default=None,
                    help='The name associated with the model with saving results.')
parser.add_argument('--model_vers', type=str, default='fin3_reg', help='Model version to append to name')
parser.add_argument('--model_dir', type=str, default='./models', help='The directory for saving the results.')
parser.add_argument('--log_dir', type=str, default='./logs', help='The directory for standard output log.')

parser.add_argument('--ckpt_freq', type=int, default=300, help='Frequency at which to save the current model.')
parser.add_argument('--resume', type=str, default=None, metavar='CKPT',
                    help='checkpoint to resume training from (default: None)')
# type=_str2bool (not bool) so that "--use_cuda False" really disables the GPU.
parser.add_argument('--use_cuda', type=_str2bool, default=True,
                    help='Train on GPU if available. (default: %(default)s)')
parser.add_argument('--loss_func', type=str, default='loss_func_cfan', help='Loss function to train with')

parser.add_argument('--jacobian_penalty', type=float, default=5E-1,
                    help='The penalty associated with the Jacobian norm penalty for disentanglement.')
args = parser.parse_args()

# Each dataset gets its own sub-directory below --model_dir; create it up front.
args.model_dir = '/'.join([args.model_dir, args.dataset])
os.makedirs(args.model_dir, exist_ok=True)

# When no explicit name is given, derive one from the total latent size,
# the model version tag and the random seed.
if args.model_name is None:
    latent_total = sum(args.n_latent)
    args.model_name = '%02d_%s_s%02d' % (latent_total, args.model_vers, args.seed)

blocks = 1
num_channels = [None] * 2
transform_flag = None


def _default_kernels(base_widths, n_heads, n_blocks):
    """Build a default per-layer kernel list.

    Args:
        base_widths: Layer widths specified for a 4-head network.
        n_heads: Number of attention heads; each width is rescaled by 4 / n_heads
            (integer division).
        n_blocks: How many consecutive times each width is repeated.

    Returns:
        A new list with every rescaled width repeated ``n_blocks`` times.
    """
    scaled = [(width * 4) // n_heads for width in base_widths]
    return [width for width in scaled for _ in range(n_blocks)]


# Default architectures. Every option is defaulted on its own: previously all four
# were gated on --bond_enc_krn alone, so passing only --bond_enc_krn crashed on the
# remaining None values, and passing only the other options silently discarded them.
# User-supplied lists are taken as-is (no head rescaling), exactly as before.
_DEFAULT_KERNELS = (
    ('bond_enc_krn', [12, 24, 48, 96, 96]),
    ('dihedrals_enc_krn', [12, 24, 48, 96, 96]),
    ('ca_enc_krn', [12, 24, 48, 96, 96]),
    ('xyz_dec_krn', [128, 128, 64, 32, 16]),
)
for _arg_name, _base_widths in _DEFAULT_KERNELS:
    if getattr(args, _arg_name) is None:
        setattr(args, _arg_name, _default_kernels(_base_widths, args.n_heads, blocks))

# Prepend a leading channel count to every encoder (presumably the width of the
# raw input features: 1 for bond, 3 for dihedrals, 1 for ca).
args.bond_enc_krn.insert(0, 1)
args.dihedrals_enc_krn.insert(0, 3)
args.ca_enc_krn.insert(0, 1)
# Encoder order handed to the model: bond, ca, dihedrals.
num_channels[0] = [args.bond_enc_krn, args.ca_enc_krn, args.dihedrals_enc_krn]

# The decoder ends in 3 channels (presumably the x, y, z coordinates).
args.xyz_dec_krn.append(3)
num_channels[1] = args.xyz_dec_krn

# Echo the fully resolved configuration as a two-column table.
table = [[name, value] for name, value in vars(args).items()]
print(tabulate(table, headers=['Arguments', 'Value']))

# Run on the first GPU only when one is present and the user did not opt out.
use_cuda = torch.cuda.is_available() and args.use_cuda
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Device: %s' % device)

# Seed NumPy and PyTorch (CPU, plus the current CUDA device when used).
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if use_cuda:
    torch.cuda.manual_seed(args.seed)
    # torch.cuda.manual_seed_all(args.seed)
# Let cuDNN auto-tune its convolution algorithms.
torch.backends.cudnn.benchmark = True

# Load the dataset from <dir_data>/<dataset> and split it with the project's
# trajectory-end splitter.
dataset_root = os.path.join(args.dir_data, args.dataset)
protein_dataset = data.Dataset(dataset_root)
train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_traj_end(args.dataset)

# All three loaders share batch size, worker count and pinned memory;
# only the training loader shuffles.
loader_opts = {'batch_size': args.batch_size, 'num_workers': 4, 'pin_memory': True}
train_data = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_opts)
val_data = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_opts)
test_data = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_opts)

# Number of sampling levels requested, derived from the decoder depth.
n_levels = (len(num_channels[1]) - 1) // blocks + 1
sampled_pts = utils.create_downsamples(protein_dataset.weighted_adj, args.dir_data, args.dataset, n_levels)

# Kernel radius per level: the base radius scaled linearly by the ratio of the
# full point count to that level's point count, with every entry repeated
# `blocks` times. (A square-root scaling was tried previously.)
n_full = protein_dataset.weighted_adj.shape[0]
rad = [(n_full / len(level)) * args.kernel_rad
       for level in sampled_pts
       for _ in range(blocks)]

# Distances are capped at 90% of the last level's radius.
dist = utils.compute_distance_mat(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                  max_distance=rad[-1] * 0.9)

protein_dataset.print_stats()

#TODO: Fix the hardcoding of the decimated point size
# Extra constructor arguments for each supported convolution type. An explicit
# table replaces the former locals()['<conv_type>_kwargs'] lookup, which failed
# with an opaque KeyError for anything other than 'gat'.
gat_kwargs = {'n_heads': args.n_heads, 'blocks': blocks}
conv_kwargs = {'gat': gat_kwargs}
if args.conv_type not in conv_kwargs:
    raise ValueError('Unsupported --conv_type %r; expected one of: %s'
                     % (args.conv_type, ', '.join(sorted(conv_kwargs))))
kwargs = conv_kwargs[args.conv_type]

ignore_int_fine = True
model = MVAE(sampled_pts, protein_dataset.weighted_adj, num_channels, rad, dist, args.n_latent, args.conv_type, device,
             ignore_int_fine=ignore_int_fine, **kwargs)

# Restore the weights saved for this model name at epoch `args.epochs`;
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
state_dict = torch.load('%s/ckpt_%s-%d.pt' % (args.model_dir, args.model_name, args.epochs), map_location=device)
model.load_state_dict(state_dict['model_state'])

# Move every sub-network explicitly, then the model itself.
# (utils.MyDataParallel wrapping over device_ids=[0, 1] was used here for multi-GPU runs.)
for net in model.inference_net:
    net.to(device)
model.generative_net.to(device)

model.to(device)

# Count the scalar entries of every parameter that receives gradients.
model_parameters = (p for p in model.parameters() if p.requires_grad)
model_params = sum(np.prod(p.size()) for p in model_parameters)
print('Number of Trainable Parameters: %d' % model_params)

#for name, param in model.named_parameters():
#    if param.requires_grad:
#        print(name, param.shape)

# Look the loss function up by name in the project's utils module
# (e.g. utils.loss_func_cfan, the default).
criterion = getattr(utils, args.loss_func)

# Extra arguments forwarded to the loss: the model and the Jacobian penalty weight.
# (The unused per-epoch penalty schedules, the unused NaN placeholder dict and a
# duplicate 'k_jacob' assignment were dropped; none of them affected the evaluation.)
loss_kwargs = {'model_jacob': model, 'k_jacob': args.jacobian_penalty}

# Evaluate the restored model on the *training* loader and report the loss terms.
test_results = utils.test(train_data, model, criterion, args.kl_lambda, device, protein_dataset,
                          loss_kwargs=loss_kwargs, ignore_int_fine=ignore_int_fine)
print('Train Total/MSE/KLD/JNP/CF Loss: %.3E / %.3E / %.3E / %.3E, / %.3E, '
      'Generalization: %.2f A / %.2f A' %
      (test_results['loss'], test_results['loss_mse'], test_results['loss_kld'], test_results['loss_jacobian'],
       test_results['loss_cf'], test_results['loss_val'], test_results['loss_val_med']))

print('Done')
