import argparse
import torch
import torch.optim as optim
import os
import numpy as np
import numpy.linalg as la
import time
from utils import utils
from models import MVAE
import data
from tqdm import tqdm
from scipy.sparse.csgraph import connected_components


def _str2bool(value):
    """Convert a command-line token such as ``true``/``false``/``1``/``0`` to a bool.

    ``type=bool`` cannot be used for this: argparse would call ``bool('False')``,
    which is ``True`` because the string is non-empty.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r.' % value)


# Command-line interface. The help strings quote the defaults actually set below.
parser = argparse.ArgumentParser(description='Parameters for training the PTC network.')
parser.add_argument('--epochs', help='The number of epochs to train the network. (default: 100)', type=int, default=100)
parser.add_argument('--batch_size', help='The batch size to be used during training. (default: 64)', type=int,
                    default=64)

parser.add_argument('--dir_data', type=str, default='./data', help='Directory to the datasets.')
parser.add_argument('--dataset', type=str, default='ace2', help='The dataset to train the network on.') # da_10906555
# parser.add_argument('--ref_surf', type=int, default=16000, help='The reference mesh to be used during training. '
#                                                                 '(default: [16000, 4])') #16000, 4

parser.add_argument('--n_latent', type=int, default=[2, 16, 32], nargs='+',
                    help='Latent space dimensions. (default: 2 16 32)')
parser.add_argument('--bond_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the bond encoder. '
                         '(default: built-in architecture scaled by --n_heads)')
parser.add_argument('--dihedrals_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the dihedrals encoder. '
                         '(default: built-in architecture scaled by --n_heads)')
parser.add_argument('--ca_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the ca encoder. '
                         '(default: built-in architecture scaled by --n_heads)')
parser.add_argument('--xyz_dec_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the xyz decoder. '
                         '(default: built-in architecture scaled by --n_heads)')
parser.add_argument('--n_heads', type=int, default=4, help='Number of heads of the attention mechanism in a GAT layer. '
                                                           '(default: 4)')
parser.add_argument('--conv_type', type=str, default='gat', help='The non-Euclidean convolution to use.')
parser.add_argument('--kernel_rad', type=float, default=2.5, help='The radius of the convolutional kernel on the '
                                                                    'surface. (default: 2.5A)')  #5E-2 / 2E-2
parser.add_argument('--seed', help='The random seed for execution.', type=int, default=1)
parser.add_argument('--model_name', help='The name associated with the model with saving results.', type=str,
                    default=None)
parser.add_argument('--model_vers', help='Model version to append to name', type=str, default='fin3_reg')
parser.add_argument('--model_dir', help='The directory for saving the results.', type=str, default='./models')
# Parsed with _str2bool so that "--use_cuda False" really disables the GPU.
parser.add_argument('--use_cuda', type=_str2bool, default=True, help='Train on GPU if available. (default: True)')

args = parser.parse_args()
def _default_kernels(base, n_heads, n_blocks):
    """Return ``base`` scaled by ``4 / n_heads`` with every entry repeated ``n_blocks`` times."""
    scaled = [(width * 4) // n_heads for width in base]
    return [width for width in scaled for _ in range(n_blocks)]


# Results for a dataset live in a sub-directory of the model directory.
args.model_dir += '/' + args.dataset
if args.model_name is None:
    # Derived name: total latent size, model version and seed.
    args.model_name = '%02d_%s_s%02d' % (sum(args.n_latent), args.model_vers, args.seed)
dir_dataset = os.path.join(args.dir_data, args.dataset)

blocks = 1
num_channels = [None] * 2
transform_flag = None
# Default architectures.
# Each kernel list falls back to its default independently, so supplying only
# some of the --*_krn options neither crashes on a None list nor discards the
# lists that were supplied. User-supplied lists are used as given (unscaled).
if args.bond_enc_krn is None:
    args.bond_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.dihedrals_enc_krn is None:
    args.dihedrals_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.ca_enc_krn is None:
    args.ca_enc_krn = _default_kernels([12, 24, 48, 96, 96], args.n_heads, blocks)
if args.xyz_dec_krn is None:
    args.xyz_dec_krn = _default_kernels([128, 128, 64, 32, 16], args.n_heads, blocks)

# Prepend the number of input channels of each encoder.
args.bond_enc_krn.insert(0, 1)
args.dihedrals_enc_krn.insert(0, 3) #3)
args.ca_enc_krn.insert(0, 1)
num_channels[0] = [args.bond_enc_krn, args.ca_enc_krn, args.dihedrals_enc_krn]

# Append the number of output channels of the decoder.
args.xyz_dec_krn.append(3)
num_channels[1] = args.xyz_dec_krn

use_cuda = torch.cuda.is_available() and args.use_cuda
if use_cuda:
    # GPU 2 is the preferred card; fall back to GPU 0 on hosts that expose
    # fewer than three devices instead of failing later with an invalid ordinal.
    gpu_index = 2 if torch.cuda.device_count() > 2 else 0
    device = torch.device('cuda:%d' % gpu_index)
else:
    device = torch.device('cpu')
print('Device: %s' % device)

# Seed NumPy and PyTorch from the command-line seed.
np.random.seed(args.seed)
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
if use_cuda:
    torch.cuda.manual_seed(args.seed)

# Load the dataset and build the train / validation / test loaders.
# Only the training loader shuffles; validation and test keep their order.
protein_dataset = data.Dataset(dir_dataset)
# train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_random()
train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_traj_end(args.dataset)
train_data = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                         pin_memory=True)
val_data = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,
                                       pin_memory=True)
test_data = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4,
                                        pin_memory=True)

drug_class = 0
# lens, norms, _, coords = protein_dataset.__getitem__(protein_dataset.labels == drug_class)
# lens = torch.tensor(lens)
# norms = torch.tensor(norms)
# Coordinates of the test samples with axes 1 and 2 swapped.
# NOTE(review): test_dataset looks like a Subset-style wrapper (.dataset / .indices) - confirm.
# NOTE(review): `coords` and `drug_class` are not read again in the visible part of this script.
coords = test_dataset.dataset.coords[test_dataset.indices]
# coords = protein_dataset.unnormalize_coordinates(coords)
coords = np.swapaxes(coords, 1, 2)

# Down-sampled point sets; the last argument is the length of the decoder channel list.
sampled_pts = utils.create_downsamples(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                             len(num_channels[1]))
# rad = [(protein_dataset.weighted_adj.shape[0] / len(s)) ** 0.5 * args.kernel_rad for s in sampled_pts]
# One kernel radius per sampled set: the base radius multiplied by the ratio of
# the adjacency size to the number of points in that set.
rad = [(protein_dataset.weighted_adj.shape[0] / len(s)) * args.kernel_rad for s in sampled_pts]

# Distance matrix limited to 90% of the last radius.
dist = utils.compute_distance_mat(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                  max_distance=rad[-1]*0.9)
# Component label per node of the graph described by `dist`.
# NOTE(review): `seg_color` is not read again in the visible part of this script.
_, seg_color = connected_components(dist)

# Extra constructor arguments are looked up by name as '<conv_type>_kwargs'.
# Only 'gat' has such a dict here, so any other --conv_type raises KeyError.
gat_kwargs = {'n_heads': args.n_heads}
kwargs = locals()['%s_kwargs' % args.conv_type]
model = MVAE(sampled_pts, protein_dataset.weighted_adj, num_channels, rad, dist, args.n_latent, args.conv_type, device,
             ignore_int_fine=True, **kwargs)

# Restore the weights saved as '<model_dir>/ckpt_<model_name>-<epochs>.pt'.
t = time.time()
state_dict = torch.load('%s/ckpt_%s-%d.pt' % (args.model_dir, args.model_name, args.epochs), map_location=device)
model.load_state_dict(state_dict['model_state'])

print('Model loaded: %0.2fs' % (time.time() - t))

# Move every encoder and the decoder to the selected device explicitly.
for i in range(len(model.inference_net)):
    # model.inference_net[i] = utils.MyDataParallel(model.inference_net[i], device_ids=[1, 2])
    model.inference_net[i] = model.inference_net[i].to(device)
# model.generative_net = utils.MyDataParallel(model.generative_net, device_ids=[1, 2])
model.generative_net = model.generative_net.to(device)

model = model.to(device)
# The model is only used for inference below.
model.eval()

# NOTE(review): `criterion` is not read again in the visible part of this script.
criterion = utils.loss_func_cfan

# Create the linear regression head on top of the latent code.
# Columns of chem_properties.csv ("x" marks columns flagged in the original notes):
# 0  Molecular Weight                   1  XLogP3
# 2  Hydrogen Bond Donor Count          3  Hydrogen Bond Acceptor Count
# 4  Rotatable Bond Count               5  Topological Polar Surface Area
# 6  Heavy Atom Count                   7  Formal Charge  x
# 8  Complexity                         9  Isotope Atom Count  x
# 10 Defined Atom Stereocenter Count    11 Undefined Atom Stereocenter Count
# 12 Defined Bond Stereocenter Count    13 Undefined Bond Stereocenter Count x
# 14 Covalently-Bonded Unit Count

properties = np.loadtxt(dir_dataset + '/chem_properties.csv', delimiter=',')
# log_idx = [0, 5, 8] #0 5 8
# properties[:, log_idx] = np.log(properties[:, log_idx])
# Compare by value. An identity test (`is not`) against a string literal depends
# on whether the two strings happen to be the same object, so a dataset name
# typed on the command line could select the wrong column set.
if args.dataset != "ace2":
    idx = [0, 2, 3, 4, 5, 6, 8]
else:
    idx = [0, 1, 2]
properties = properties[:, idx]
# Standardise every selected property; the epsilon guards constant columns.
properties_mu = np.mean(properties, axis=0, keepdims=True)
properties_std = np.std(properties, axis=0, keepdims=True)
properties = (properties - properties_mu) \
             / (properties_std + 1E-12)
# idx = [0, 1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 14]
n_properties = properties.shape[1]
# ext_flag selects which encoder feeds the regressor: True uses the third
# latent size (n_latent[2]), False the second (n_latent[1]).
ext_flag = True
if ext_flag:
    reg_model = torch.nn.Linear(args.n_latent[2], n_properties)
else:
    reg_model = torch.nn.Linear(args.n_latent[1], n_properties)


def train_reg_model(protein_mod, reg_mod, props, dat, epochs):
    """Fit ``reg_mod`` to predict properties from the encoder's latent mean.

    The encoder (``protein_mod``) is kept in eval mode and evaluated without
    gradient tracking; only the parameters of ``reg_mod`` are optimised.

    Args:
        protein_mod: model exposing ``inference_net``; index 2 is used when the
            module-level ``ext_flag`` is true, index 1 otherwise.
        reg_mod: regression module mapping the latent mean to the properties.
        props: array of target properties, indexed by the batch labels.
        dat: loader yielding ``(lens, dihedrals, ca_lens, labels, _)`` batches.
        epochs: number of passes over ``dat``.

    Returns:
        Per-property mean absolute error of the last epoch (all zeros when
        ``epochs`` is 0).
    """
    protein_mod.to(device)
    reg_mod.to(device)
    protein_mod.eval()
    reg_mod.train()

    trainable = [param for param in reg_mod.parameters() if param.requires_grad]
    optimizer = optim.Adam(trainable, 1E-3, weight_decay=5E-5)

    # Defined up front so the function has a well-defined result for epochs == 0.
    error = np.zeros([n_properties])
    for ep in range(epochs):
        error = np.zeros([n_properties])
        loss_sum = 0.0
        for lens, dihedrals, ca_lens, labels, _ in dat:
            # The encoder is frozen: skip autograd so no graph is built through it
            # and no gradients accumulate on its parameters.
            with torch.no_grad():
                if ext_flag:
                    dihedrals = dihedrals.to(device)
                    _, mu, _, _ = protein_mod.inference_net[2](dihedrals)
                else:
                    ca_lens = ca_lens.to(device)
                    _, mu, _, _ = protein_mod.inference_net[1](ca_lens)
            prop = torch.tensor(props[labels], device=device, dtype=torch.float32)
            preds = reg_mod(mu)
            optimizer.zero_grad()
            # Doubling both inputs moves the quadratic/linear switch-over of the
            # smooth L1 loss to |error| = 0.5 (with the default beta).
            loss = torch.nn.functional.smooth_l1_loss(2 * preds, 2 * prop)
            # Weight the batch-mean loss by the batch size for a dataset average.
            loss_sum += loss.item() * lens.shape[0]
            loss.backward()
            optimizer.step()
            err = torch.sum(torch.abs(preds - prop), dim=0)
            error += err.detach().cpu().numpy()
        error = error / len(dat.dataset)
        loss_sum = loss_sum / len(dat.dataset)
        print('Epoch %03d: Loss %.2E' % (ep, loss_sum))
        print('Diff', error)
    return error


def train_reg_model_pca(protein_mod, reg_mod, props, dat, epochs):
    """Fit ``reg_mod`` to predict properties from pre-computed PCA features.

    Args:
        protein_mod: moved to the device and put in eval mode for parity with
            ``train_reg_model``; it is not used to produce predictions here.
        reg_mod: regression module mapping a PCA feature vector to the properties.
        props: array of target properties, indexed by the batch labels.
        dat: loader yielding ``(sig, labels)`` batches.
        epochs: number of passes over ``dat``.

    Returns:
        Per-property mean absolute error of the last epoch (all zeros when
        ``epochs`` is 0).
    """
    protein_mod.to(device)
    reg_mod.to(device)
    protein_mod.eval()
    reg_mod.train()

    trainable = [param for param in reg_mod.parameters() if param.requires_grad]
    optimizer = optim.Adam(trainable, 1E-3, weight_decay=5E-5)

    # Defined up front: without it, `return error` fails with UnboundLocalError
    # when the epoch loop never runs.
    error = np.zeros([n_properties])
    for ep in range(epochs):
        error = np.zeros([n_properties])
        loss_sum = 0.0
        for sig, labels in dat:
            sig = sig.to(device)
            preds = reg_mod(sig)
            prop = torch.tensor(props[labels], device=device, dtype=torch.float32)
            optimizer.zero_grad()
            # Same scaled smooth L1 objective as the latent-space regressor.
            loss = torch.nn.functional.smooth_l1_loss(2 * preds, 2 * prop)
            # Weight the batch-mean loss by the batch size for a dataset average.
            loss_sum += loss.item() * sig.shape[0]
            loss.backward()
            optimizer.step()
            err = torch.sum(torch.abs(preds - prop), dim=0)
            error += err.detach().cpu().numpy()
        error = error / len(dat.dataset)
        loss_sum = loss_sum / len(dat.dataset)
        print('Epoch %03d: Loss %.2E' % (ep, loss_sum))
        print('Diff', error)
    return error


def eval_reg_model(protein_mod, reg_mod, props, dat):
    """Score ``reg_mod`` on ``dat`` using the encoder's latent mean as input.

    Both modules are moved to the module-level ``device`` and switched to eval
    mode; the whole pass runs without gradient tracking.

    Args:
        protein_mod: model exposing ``inference_net``; index 2 is used when the
            module-level ``ext_flag`` is true, index 1 otherwise.
        reg_mod: regression module mapping the latent mean to the properties.
        props: array of target properties, indexed by the batch labels.
        dat: loader yielding ``(lens, dihedrals, ca_lens, labels, _)`` batches.

    Returns:
        Per-property absolute error averaged over ``len(dat.dataset)``.
    """
    protein_mod.to(device)
    reg_mod.to(device)
    protein_mod.eval()
    reg_mod.eval()

    n_samples = len(dat.dataset)
    abs_err_total = np.zeros([n_properties])
    weighted_loss = 0.0
    with torch.no_grad():
        for lens, dihedrals, ca_lens, labels, _ in tqdm(dat):
            # Pick the encoder and its matching input signal.
            if ext_flag:
                signal = dihedrals.to(device)
                encoded = protein_mod.inference_net[2](signal)
            else:
                signal = ca_lens.to(device)
                encoded = protein_mod.inference_net[1](signal)
            mu = encoded[1]
            target = torch.tensor(props[labels], device=device, dtype=torch.float32)
            preds = reg_mod(mu)
            batch_loss = torch.nn.functional.smooth_l1_loss(2 * preds, 2 * target)
            # Batch-mean loss weighted by the batch size.
            weighted_loss += batch_loss.item() * lens.shape[0]
            batch_abs_err = torch.abs(preds - target).sum(dim=0)
            abs_err_total += batch_abs_err.detach().cpu().numpy()
    mean_abs_err = abs_err_total / n_samples
    print('Loss %.2E' % (weighted_loss / n_samples))
    return mean_abs_err


def eval_reg_model_pca(protein_mod, reg_mod, props, dat):
    """Score ``reg_mod`` on pre-computed PCA features.

    Args:
        protein_mod: moved to the device and put in eval mode for parity with
            ``eval_reg_model``; it is not used to produce predictions here.
        reg_mod: regression module mapping a PCA feature vector to the properties.
        props: array of target properties, indexed by the batch labels.
        dat: loader yielding ``(sig, labels)`` batches.

    Returns:
        Per-property absolute error averaged over ``len(dat.dataset)``.
    """
    protein_mod.to(device)
    reg_mod.to(device)
    protein_mod.eval()
    reg_mod.eval()

    error = np.zeros([n_properties])
    loss_sum = 0.0
    # Pure evaluation: disable autograd, as eval_reg_model does, so no graph is
    # built for each batch.
    with torch.no_grad():
        for sig, labels in tqdm(dat):
            sig = sig.to(device)
            preds = reg_mod(sig)
            prop = torch.tensor(props[labels], device=device, dtype=torch.float32)
            loss = torch.nn.functional.smooth_l1_loss(2 * preds, 2 * prop)
            # Weight the batch-mean loss by the batch size for a dataset average.
            loss_sum += loss.item() * sig.shape[0]
            err = torch.sum(torch.abs(preds - prop), dim=0)
            error += err.cpu().numpy()
    error = error / len(dat.dataset)
    loss_sum = loss_sum / len(dat.dataset)
    print('Loss %.2E' % loss_sum)
    return error

# PCA Embedding
# Baseline features: the leading left singular vectors of the standardised,
# flattened test-set signal (same signal family as the encoder chosen by ext_flag).
print('Computing PCA')
pca_idx = test_data.dataset.indices
if ext_flag:
    pca_samp = test_data.dataset.dataset.edge_normal[pca_idx]
else:
    pca_samp = test_data.dataset.dataset.ca_bonds[pca_idx]
# One row per sample, all remaining axes flattened into features.
pca_samp = np.reshape(pca_samp, [pca_samp.shape[0], -1])
# The epsilon keeps constant features at 0 instead of producing NaN/inf (0 / 0),
# which would corrupt the decomposition; it matches the property normalisation.
pca_samp = (pca_samp - np.mean(pca_samp, axis=0, keepdims=True)) \
           / (np.std(pca_samp, axis=0, keepdims=True) + 1E-12)
pca_labels_test = test_data.dataset.dataset.labels[pca_idx]
# u_pca, _, _ = la.svd(pca_samp, full_matrices=False)
with torch.no_grad():
    # torch.svd returns (U, S, V); only the left singular vectors are needed.
    u, _, _ = torch.svd(torch.tensor(pca_samp, device=device))
u_pca = u.detach().cpu().numpy()


# The PCA regressor gets the same input width as the latent-space regressor.
if ext_flag:
    reg_model_pca = torch.nn.Linear(args.n_latent[2], n_properties)
else:
    reg_model_pca = torch.nn.Linear(args.n_latent[1], n_properties)

if ext_flag:
    pca_samp_test = torch.tensor(u_pca[:, :args.n_latent[2]], dtype=torch.float32)
else:
    pca_samp_test = torch.tensor(u_pca[:, :args.n_latent[1]], dtype=torch.float32)
pca_labels_test = torch.tensor(pca_labels_test, dtype=torch.int)
pca_dataset_test = torch.utils.data.TensorDataset(pca_samp_test, pca_labels_test)
test_data_pca = torch.utils.data.DataLoader(pca_dataset_test, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                            pin_memory=True)

# Latent-space regressor: fit on the training split, score on the test split.
print('Training Latent space regression model')
train_reg_model(model, reg_model, properties, train_data, epochs=10)
error_reg = eval_reg_model(model, reg_model, properties, test_data)
print('Normalized Error', error_reg)
# An absolute error is a difference of two standardised values, so the mean
# offset cancels: converting back to property units only scales by the std.
error_reg_unnorm = error_reg * properties_std
print('Error', error_reg_unnorm)

# PCA baseline.
# NOTE(review): this regressor is fitted on the test-set PCA features and the
# reported value is its last-epoch training error on those same samples.
print('Training PCA space regression model')
error_pca = train_reg_model_pca(model, reg_model_pca, properties, test_data_pca, epochs=10)
# error_reg = eval_reg_model_pca(model, reg_model_pca, properties, test_data_pca)
print('Normalized Error', error_pca)
error_pca_unnorm = error_pca * properties_std
print('Error', error_pca_unnorm)

np.savez('%s/drug_regression_results_%s_%03d.npz' % (args.model_dir, args.model_name, args.epochs),
         normalized_error_latent=error_reg, error_latent=error_reg_unnorm, normalized_error_pca=error_pca,
         error_pca=error_pca_unnorm)

print('Done')
