import argparse
import torch
import torch.optim as optim
import os
import numpy as np
import numpy.linalg as la
import time
from utils import utils
from models import MVAE
import data
from tqdm import tqdm
from scipy.sparse.csgraph import connected_components


def _str_to_bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``, so any
    non-empty string would enable the option. This accepts the usual spellings
    (true/false, yes/no, t/f, y/n, 1/0, on/off), case-insensitively.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser(description='Parameters for training the PTC network.')
# In this script --epochs only selects which MVAE checkpoint is loaded and tags the results file.
parser.add_argument('--epochs', type=int, default=100,
                    help='Training epoch of the MVAE checkpoint to load (ckpt_<model_name>-<epochs>.pt). '
                         '(default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=64,
                    help='The batch size to be used during training. (default: %(default)s)')

parser.add_argument('--dir_data', type=str, default='./data', help='Directory to the datasets.')
parser.add_argument('--dataset', type=str, default='ace2', help='The dataset to train the network on.')  # da_10906555

parser.add_argument('--n_latent', type=int, default=[2, 16, 32], nargs='+',
                    help='Latent space dimensions. (default: %(default)s)')
# The kernel-list defaults are filled in after parsing and scaled by 4 // n_heads.
parser.add_argument('--bond_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the bond encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--dihedrals_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the dihedral encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--ca_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the CA encoder. '
                         '(default: [12, 24, 48, 96, 96] scaled by 4 // n_heads)')
parser.add_argument('--xyz_dec_krn', type=int, nargs='+', default=None,
                    help='The number of kernels in each layer of the xyz decoder. '
                         '(default: [128, 128, 64, 32, 16] scaled by 4 // n_heads)')
parser.add_argument('--n_heads', type=int, default=4,
                    help='Number of heads of the attention mechanism in a GAT layer. (default: %(default)s)')
parser.add_argument('--conv_type', type=str, default='gat', help='The non-Euclidean convolution to use.')
parser.add_argument('--kernel_rad', type=float, default=2.5,
                    help='The base radius of the convolutional kernel on the surface; it is scaled per '
                         'downsampling level. (default: %(default)s)')
parser.add_argument('--seed', help='The random seed for execution.', type=int, default=0)
parser.add_argument('--model_name', help='The name associated with the model with saving results.', type=str,
                    default=None)
parser.add_argument('--model_vers', help='Model version to append to name', type=str, default='fin3_reg')
parser.add_argument('--model_dir', help='The directory for saving the results.', type=str, default='./models')
# type=bool would turn "--use_cuda False" into True; parse the string explicitly instead.
parser.add_argument('--use_cuda', type=_str_to_bool, default=True,
                    help='Train on GPU if available (true/false). (default: %(default)s)')

args = parser.parse_args()
# Dataset location on disk.
dir_dataset = os.path.join(args.dir_data, args.dataset)
# Results for each dataset go to their own sub-directory of --model_dir; when no
# name is given, one is derived as <sum of latent dims>_<version>_s<seed>
# (e.g. '50_fin3_reg_s00' with the default arguments).
args.model_dir = '%s/%s' % (args.model_dir, args.dataset)
args.model_name = (args.model_name if args.model_name is not None
                   else '%02d_%s_s%02d' % (sum(args.n_latent), args.model_vers, args.seed))

blocks = 1  # number of consecutive layers that share each default channel width
num_channels = [None] * 2  # [encoder channel lists, decoder channel list]
transform_flag = None

# Default architectures. Each kernel list is filled in independently, so that
# overriding one of them on the command line neither discards the user's other
# lists nor leaves them as None (which would crash at .insert/.append below).
_DEFAULT_KRN = {
    'bond_enc_krn': [12, 24, 48, 96, 96],
    'dihedrals_enc_krn': [12, 24, 48, 96, 96],
    'ca_enc_krn': [12, 24, 48, 96, 96],
    'xyz_dec_krn': [128, 128, 64, 32, 16],
}
for _krn_name, _base_krn in _DEFAULT_KRN.items():
    if getattr(args, _krn_name) is not None:
        continue
    # 4 // n_heads is 0 for n_heads > 4 (zero-width layers), negative for
    # n_heads < 0 and undefined for 0, so the defaults only make sense for 1..4.
    if not 1 <= args.n_heads <= 4:
        parser.error('the default --%s requires 1 <= --n_heads <= 4; pass --%s explicitly'
                     % (_krn_name, _krn_name))
    _head_scale = 4 // args.n_heads
    # Scale every width by 4 // n_heads, then repeat each width `blocks` times in a row.
    setattr(args, _krn_name, [w * _head_scale for w in _base_krn for _ in range(blocks)])

# Prepend 1 / 3 / 1 to the encoder lists -- presumably the input channel count of
# each encoder (TODO confirm against MVAE).
args.bond_enc_krn.insert(0, 1)
args.dihedrals_enc_krn.insert(0, 3)
args.ca_enc_krn.insert(0, 1)
# Encoder order is [bond, ca, dihedrals]; index 1 (ca) is the encoder whose output
# the classifier below consumes via inference_net[1].
num_channels[0] = [args.bond_enc_krn, args.ca_enc_krn, args.dihedrals_enc_krn]

# Append 3 output channels to the decoder (presumably x/y/z coordinates).
args.xyz_dec_krn.append(3)
num_channels[1] = args.xyz_dec_krn

# Use the first GPU only if one is present and the user did not opt out.
use_cuda = torch.cuda.is_available() and args.use_cuda
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
print('Device: %s' % device)

# Seed the NumPy and torch RNGs (plus CUDA when in use) for repeatable runs.
# Note: cudnn.benchmark lets cuDNN auto-tune kernels, which may still be
# nondeterministic despite the seeds.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if use_cuda:
    torch.cuda.manual_seed(args.seed)
torch.backends.cudnn.benchmark = True

# Load the dataset and split it with dataset_split_traj_end (dataset_split_random()
# was the commented-out alternative).
protein_dataset = data.Dataset(dir_dataset)
train_dataset, val_dataset, test_dataset = protein_dataset.dataset_split_traj_end(args.dataset)

# All three loaders share these settings; only the training loader shuffles.
_loader_kwargs = {'batch_size': args.batch_size, 'num_workers': 4, 'pin_memory': True}
train_data = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_data = torch.utils.data.DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_data = torch.utils.data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

drug_class = 0
# Test-split coordinates with axes 1 and 2 swapped. test_dataset presumably is a
# torch Subset (it exposes .dataset and .indices) -- TODO confirm.
# NOTE(review): neither coords nor drug_class is used later in this script.
coords = np.swapaxes(test_dataset.dataset.coords[test_dataset.indices], 1, 2)

# Downsampled point sets, presumably one per decoder level (the count passed is
# len(num_channels[1])). Each level's kernel radius is --kernel_rad multiplied by
# the ratio weighted_adj.shape[0] / len(level).
sampled_pts = utils.create_downsamples(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                       len(num_channels[1]))
_n_full = protein_dataset.weighted_adj.shape[0]
rad = [_n_full / len(level_pts) * args.kernel_rad for level_pts in sampled_pts]

# Distance matrix computed with max_distance at 90% of the last (coarsest) radius,
# plus its connected-component labels (seg_color is not used later in this script).
dist = utils.compute_distance_mat(protein_dataset.weighted_adj, args.dir_data, args.dataset,
                                  max_distance=rad[-1] * 0.9)
_, seg_color = connected_components(dist)

# Extra keyword arguments for each supported convolution type. An explicit registry
# replaces the old locals()['<conv_type>_kwargs'] lookup, which only worked at module
# scope and failed with an opaque KeyError for unknown types.
gat_kwargs = {'n_heads': args.n_heads}
conv_kwargs = {'gat': gat_kwargs}
if args.conv_type not in conv_kwargs:
    parser.error('unsupported --conv_type %r; choose from: %s'
                 % (args.conv_type, ', '.join(sorted(conv_kwargs))))
kwargs = conv_kwargs[args.conv_type]
model = MVAE(sampled_pts, protein_dataset.weighted_adj, num_channels, rad, dist, args.n_latent, args.conv_type, device,
             ignore_int_fine=True, **kwargs)

# Restore the trained weights from <model_dir>/ckpt_<model_name>-<epochs>.pt.
# NOTE: torch.load deserializes via pickle (unless restricted with weights_only);
# only load checkpoints from a trusted source.
t = time.time()
ckpt_path = '%s/ckpt_%s-%d.pt' % (args.model_dir, args.model_name, args.epochs)
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict['model_state'])

print('Model loaded: %0.2fs' % (time.time() - t))

# Move every encoder and the decoder explicitly (re-assigning them in place), then the
# whole model. Iterate over a snapshot so re-assignment cannot disturb the iteration.
for i, net in enumerate(list(model.inference_net)):
    model.inference_net[i] = net.to(device)
model.generative_net = model.generative_net.to(device)

model = model.to(device)
model.eval()

criterion = utils.loss_func_cfan  # NOTE(review): not used further in this script

# Linear classifier on the latent space: maps an n_latent[1]-dimensional latent
# vector (mu_int_c from inference_net[1] in the training/eval functions) to one
# logit per drug class.
#
# NOTE(review): the property table below lists per-molecule descriptors, which reads
# like targets for a regression variant; the model built here is a classifier, so
# the table may be a leftover.
# 0  Molecular Weight                   1  XLogP3
# 2  Hydrogen Bond Donor Count          3  Hydrogen Bond Acceptor Count
# 4  Rotatable Bond Count               5  Topological Polar Surface Area
# 6  Heavy Atom Count                   7  Formal Charge  x
# 8  Complexity                         9  Isotope Atom Count  x
# 10 Defined Atom Stereocenter Count    11 Undefined Atom Stereocenter Count
# 12 Defined Bond Stereocenter Count    13 Undefined Bond Stereocenter Count x
# 14 Covalently-Bonded Unit Count
n_drugs = 75 if args.dataset == 'ace2' else 50
reg_model = torch.nn.Linear(args.n_latent[1], n_drugs)


def train_reg_model(protein_mod, reg_mod, dat, epochs, target_device=None):
    """Train ``reg_mod`` on latent codes from a frozen ``protein_mod``.

    Each batch of ``dat`` is unpacked as ``(lens, dihedrals, ca_lens, labels, _)``.
    ``ca_lens`` is encoded with ``protein_mod.inference_net[1]``; the second value it
    returns (``mu_int_c``) is fed to ``reg_mod``. ``reg_mod`` is optimised with Adam
    (lr 1e-3, weight decay 5e-5) on the cross-entropy against ``labels``. The encoder
    runs under ``torch.no_grad()``, so its parameters receive no gradients. Per-epoch
    mean loss and accuracy over ``dat.dataset`` are printed.

    Args:
        protein_mod: model whose ``inference_net[1]`` provides the latent codes;
            put in eval mode and not updated.
        reg_mod: classifier head mapping latent codes to class logits.
        dat: DataLoader-like iterable that also exposes ``.dataset``.
        epochs: number of passes over ``dat``.
        target_device: device to train on; defaults to the module-level ``device``.

    Returns:
        A list with one ``(mean_loss, accuracy_percent)`` tuple per epoch.
    """
    dev = device if target_device is None else target_device
    protein_mod.to(dev)
    reg_mod.to(dev)
    protein_mod.eval()
    reg_mod.train()

    optimizer = optim.Adam(filter(lambda param: param.requires_grad, reg_mod.parameters()), 1E-3, weight_decay=5E-5)
    loss_func = torch.nn.functional.cross_entropy
    # Guard the per-sample averages against an empty dataset.
    n_samples = max(len(dat.dataset), 1)

    history = []
    for ep in range(epochs):
        loss_sum = 0.0
        correct = 0
        for _lens, _dihedrals, ca_lens, labels, _ in dat:
            # Only the CA encoder is used, so only its input and the labels are moved.
            ca_lens = ca_lens.to(dev)
            labels = labels.to(dev)
            # The encoder is frozen: without no_grad, backward() would also compute
            # (and accumulate) gradients in protein_mod's parameters.
            with torch.no_grad():
                _, mu_int_c, _, _ = protein_mod.inference_net[1](ca_lens)
            preds = reg_mod(mu_int_c)
            optimizer.zero_grad()
            loss = loss_func(preds, labels)
            loss.backward()
            optimizer.step()
            # cross_entropy averages over the batch; re-weight by batch size to get a sum.
            loss_sum += loss.item() * labels.shape[0]
            correct += (preds.argmax(dim=1) == labels).sum().item()
        acc = correct / n_samples * 100
        loss_sum = loss_sum / n_samples
        print('Epoch %03d: Loss %.2E, Accuracy %.2f %%' % (ep, loss_sum, acc))
        history.append((loss_sum, acc))
    return history


def eval_reg_model(protein_mod, reg_mod, dat, target_device=None):
    """Evaluate ``reg_mod`` on latent codes from ``protein_mod`` without updating weights.

    Each batch of ``dat`` is unpacked as ``(lens, dihedrals, ca_lens, labels, _)``.
    ``ca_lens`` is encoded with ``protein_mod.inference_net[1]`` and its second output
    (``mu_int_c``) is classified by ``reg_mod``. Everything runs under
    ``torch.no_grad()``, so no autograd graph is built.

    Args:
        protein_mod: model whose ``inference_net[1]`` provides the latent codes.
        reg_mod: classifier head mapping latent codes to class logits.
        dat: DataLoader-like iterable that also exposes ``.dataset``.
        target_device: device to evaluate on; defaults to the module-level ``device``.

    Returns:
        ``(mean_loss, accuracy_percent)`` over ``dat.dataset``.
    """
    dev = device if target_device is None else target_device
    protein_mod.to(dev)
    reg_mod.to(dev)
    protein_mod.eval()
    reg_mod.eval()

    loss_func = torch.nn.functional.cross_entropy
    # Guard the per-sample averages against an empty dataset.
    n_samples = max(len(dat.dataset), 1)
    correct = 0
    loss_sum = 0.0
    with torch.no_grad():
        for _lens, _dihedrals, ca_lens, labels, _ in tqdm(dat):
            ca_lens = ca_lens.to(dev)
            labels = labels.to(dev)
            _, mu_int_c, _, _ = protein_mod.inference_net[1](ca_lens)
            preds = reg_mod(mu_int_c)
            # cross_entropy averages over the batch; re-weight by batch size to get a sum.
            loss_sum += loss_func(preds, labels).item() * labels.shape[0]
            correct += (preds.argmax(dim=1) == labels).sum().item()
    acc = correct / n_samples * 100
    loss_sum = loss_sum / n_samples
    # Accuracy uses %.2f as in training (was %.2E, which printed e.g. "9.87E+01 %").
    print('Evaluation: Loss %.2E, Accuracy %.2f %%' % (loss_sum, acc))
    return loss_sum, acc


# Fit the latent-space classifier on the training split, score it on the test split,
# and store the resulting loss/accuracy in --model_dir (next to the checkpoint).
print('Training Latent space regression model')
train_reg_model(model, reg_model, train_data, epochs=5)
loss, acc = eval_reg_model(model, reg_model, test_data)

results_file = '%s/drug_classifier_results_%s_%03d.npz' % (args.model_dir, args.model_name, args.epochs)
np.savez(results_file, loss=loss, acc=acc)
print('Done')
