import numpy as np
import os
import time
import random
import glob
import argparse
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from tqdm import tqdm

from nets.load_net import gnn_model  # import GNNs
from data.data import LoadData  # import dataset
from train import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network


def gpu_setup(use_gpu, gpu_id):
    """Select the torch device to run on.

    Sets the CUDA_DEVICE_ORDER and CUDA_VISIBLE_DEVICES environment
    variables (the latter to ``str(gpu_id)``), then returns a CUDA device
    when CUDA is available and ``use_gpu`` is truthy, otherwise the CPU
    device. A one-line status message is printed either way.
    """
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": str(gpu_id),
    })

    cuda_selected = torch.cuda.is_available() and use_gpu
    if not cuda_selected:
        print('cuda not available')
        return torch.device("cpu")

    print('cuda available with GPU:', torch.cuda.get_device_name(0))
    return torch.device("cuda")


def view_model_param(MODEL_NAME, net_params):
    """Instantiate the model and report its total number of parameter elements.

    Args:
        MODEL_NAME: model identifier passed to ``gnn_model``.
        net_params: network hyper-parameters passed to ``gnn_model``.

    Returns:
        int: sum of ``numel()`` over every tensor in ``model.parameters()``.
    """
    model = gnn_model(MODEL_NAME, net_params)
    # numel() returns a built-in int, so the total stays a plain int instead
    # of a NumPy scalar (or a float, which np.prod yields for an empty shape).
    total_param = sum(param.numel() for param in model.parameters())
    print('MODEL/Total parameters:', MODEL_NAME, total_param)
    return total_param


def _build_optimizer(model, params, net_params):
    """Create the optimizer named by net_params['optimizer'] ('Adam' or 'SGD')."""
    optimizer_name = net_params['optimizer']
    if optimizer_name == 'Adam':
        return optim.Adam(model.parameters(), lr=params['init_lr'],
                          weight_decay=params['weight_decay'])
    if optimizer_name == 'SGD':
        return optim.SGD(model.parameters(), lr=params['init_lr'], momentum=0.9,
                         weight_decay=params['weight_decay'], nesterov=True)
    raise NotImplementedError("Unsupported optimizer: {}".format(optimizer_name))


def _prune_old_checkpoints(ckpt_dir, epoch):
    """Delete every '*.pkl' checkpoint in ckpt_dir whose epoch number is below epoch - 1."""
    for ckpt_file in glob.glob(ckpt_dir + '/*.pkl'):
        # Checkpoints are saved as '<ckpt_dir>/epoch_<N>.pkl'.
        epoch_nb = int(ckpt_file.split('_')[-1].split('.')[0])
        if epoch_nb < epoch - 1:
            os.remove(ckpt_file)


def train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, dirs):
    """Train and evaluate a model on each of the 10 dataset splits and write a report.

    For every split a fresh model, optimizer and ReduceLROnPlateau scheduler
    are created, the RNGs are re-seeded with ``params['seed']``, and training
    runs until ``params['epochs']`` is reached, the learning rate drops below
    ``params['min_lr']``, or the per-split share of ``params['max_time']``
    (hours) has elapsed. Ctrl+C stops training early; the summary is still
    printed and written for the splits that completed.

    Args:
        MODEL_NAME: model identifier passed to ``gnn_model``.
        DATASET_NAME: dataset identifier passed to ``LoadData``.
        params: optimisation settings; reads 'seed', 'epochs', 'batch_size',
            'init_lr', 'weight_decay', 'lr_reduce_factor',
            'lr_schedule_patience', 'min_lr' and 'max_time'.
        net_params: network settings; reads 'device', 'optimizer',
            'n_classes' and 'total_param' here and is passed to ``gnn_model``.
        dirs: tuple ``(root_log_dir, root_ckpt_dir, write_file_name,
            write_config_file)``.

    Raises:
        NotImplementedError: if ``net_params['optimizer']`` is neither
            'Adam' nor 'SGD'.

    Side effects:
        Writes ``<write_config_file>.txt`` and ``<write_file_name>.txt``,
        TensorBoard logs under ``root_log_dir/RUN_<i>`` and model checkpoints
        under ``root_ckpt_dir/RUN_<i>``.
    """
    # Number of train/val/test splits; also used to share max_time between runs.
    num_splits = 10

    avg_test_acc = []
    avg_train_acc = []
    avg_convergence_epochs = []

    t0 = time.time()
    per_epoch_time = []

    dataset = LoadData(DATASET_NAME)

    print("[!] Adding graph self-loops for GCN/GAT models (central node trick).")
    dataset._add_self_loops()

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    # Bound up-front so the final report works even when training is
    # interrupted before the first split has created them.
    model = None
    trainset = None
    writer = None

    # Write the network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n""".format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param']))

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for split_number in range(num_splits):
            t0_split = time.time()
            log_dir = os.path.join(root_log_dir, "RUN_" + str(split_number))
            writer = SummaryWriter(log_dir=log_dir)

            # Re-seed every RNG so each split starts from the same seed.
            random.seed(params['seed'])
            np.random.seed(params['seed'])
            torch.manual_seed(params['seed'])
            if device.type == 'cuda':
                torch.cuda.manual_seed(params['seed'])

            print("RUN NUMBER: ", split_number)
            trainset, valset, testset = dataset.train[split_number], dataset.val[split_number], dataset.test[split_number]
            print("Training Graphs: ", len(trainset))
            print("Validation Graphs: ", len(valset))
            print("Test Graphs: ", len(testset))
            print("Number of Classes: ", net_params['n_classes'])

            model = gnn_model(MODEL_NAME, net_params)
            model = model.to(device)

            optimizer = _build_optimizer(model, params, net_params)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                             factor=params['lr_reduce_factor'],
                                                             patience=params['lr_schedule_patience'],
                                                             verbose=True)

            epoch_train_losses, epoch_val_losses = [], []
            epoch_train_accs, epoch_val_accs = [], []

            train_loader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, drop_last=False, collate_fn=dataset.collate)
            val_loader = DataLoader(valset, batch_size=params['batch_size'], shuffle=False, drop_last=False, collate_fn=dataset.collate)
            test_loader = DataLoader(testset, batch_size=params['batch_size'], shuffle=False, drop_last=False, collate_fn=dataset.collate)

            # The checkpoint directory does not change within a split.
            ckpt_dir = os.path.join(root_ckpt_dir, "RUN_" + str(split_number))
            os.makedirs(ckpt_dir, exist_ok=True)

            # Keeps `epoch` defined for the report if params['epochs'] is 0.
            epoch = 0

            with tqdm(range(params['epochs'])) as t:
                for epoch in t:

                    t.set_description('Epoch %d' % epoch)

                    start = time.time()

                    # The three extra loss components returned by train_epoch are not used here.
                    epoch_train_loss, epoch_train_acc, optimizer, _, _, _ = train_epoch(model, optimizer, device, train_loader, epoch)

                    # Result is discarded. NOTE(review): this pass still iterates the
                    # shuffled train_loader, so dropping it may alter the random
                    # stream of later epochs - confirm before removing.
                    _, _, = evaluate_network(model, device, train_loader)

                    epoch_val_loss, epoch_val_acc = evaluate_network(model, device, val_loader)
                    _, epoch_test_acc = evaluate_network(model, device, test_loader)

                    epoch_train_losses.append(epoch_train_loss)
                    epoch_val_losses.append(epoch_val_loss)
                    epoch_train_accs.append(epoch_train_acc)
                    epoch_val_accs.append(epoch_val_acc)

                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                    writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                    writer.add_scalar('train/_acc', epoch_train_acc, epoch)
                    writer.add_scalar('val/_acc', epoch_val_acc, epoch)
                    writer.add_scalar('test/_acc', epoch_test_acc, epoch)
                    writer.add_scalar('learning_rate', current_lr, epoch)

                    t.set_postfix(time=time.time()-start, lr=current_lr,
                                  train_loss=epoch_train_loss, val_loss=epoch_val_loss,
                                  train_acc=epoch_train_acc, val_acc=epoch_val_acc,
                                  test_acc=epoch_test_acc)

                    per_epoch_time.append(time.time()-start)

                    # Save this epoch's checkpoint, then drop all but the two newest.
                    torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))
                    _prune_old_checkpoints(ckpt_dir, epoch)

                    scheduler.step(epoch_val_loss)

                    if optimizer.param_groups[0]['lr'] < params['min_lr']:
                        print("\n!! LR EQUAL TO MIN LR SET.")
                        break

                    # Stop training after this split's share of params['max_time'] hours
                    if time.time()-t0_split > params['max_time']*3600/num_splits:
                        print('-' * 89)
                        print("Max_time for one train-val-test split experiment elapsed {:.3f} hours, so stopping".format(params['max_time']/num_splits))
                        break

            _, test_acc = evaluate_network(model, device, test_loader)
            _, train_acc = evaluate_network(model, device, train_loader)
            avg_test_acc.append(test_acc)
            avg_train_acc.append(train_acc)
            avg_convergence_epochs.append(epoch)

            print("Test Accuracy [LAST EPOCH]: {:.4f}".format(test_acc))
            print("Train Accuracy [LAST EPOCH]: {:.4f}".format(train_acc))
            print("Convergence Time (Epochs): {:.4f}".format(epoch))

            # Each split has its own writer; close it as soon as the split is done.
            writer.close()
            writer = None

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')
    finally:
        # Closes the writer of a split that was interrupted or failed part-way.
        if writer is not None:
            writer.close()

    print("TOTAL TIME TAKEN: {:.4f}hrs".format((time.time()-t0)/3600))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))
    print("AVG CONVERGENCE Time (Epochs): {:.4f}".format(np.mean(np.array(avg_convergence_epochs))))
    # Final test accuracy value averaged over 10-fold
    print("""\n\n\nFINAL RESULTS\n\nTEST ACCURACY averaged: {:.4f} with s.d. {:.4f}"""          .format(np.mean(np.array(avg_test_acc))*100, np.std(avg_test_acc)*100))
    print("\nAll splits Test Accuracies:\n", avg_test_acc)
    print("""\n\n\nFINAL RESULTS\n\nTRAIN ACCURACY averaged: {:.4f} with s.d. {:.4f}"""          .format(np.mean(np.array(avg_train_acc))*100, np.std(avg_train_acc)*100))
    print("\nAll splits Train Accuracies:\n", avg_train_acc)

    # Length of the edge-feature tensor of the first training graph of the last
    # split that was started; None when no split ran or the graph carries no
    # 'feat' edge data, so the report is still written in those cases.
    edge_num = None
    if trainset is not None:
        first_graph = trainset[0][0]
        if 'feat' in first_graph.edata:
            edge_num = len(first_graph.edata['feat'])

    # Write the results in out/results folder
    with open(write_file_name + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n edge_num: {}\n\n
    FINAL RESULTS\nTEST ACCURACY averaged: {:.4f} with s.d. {:.4f}\nTRAIN ACCURACY averaged: {:.4f} with s.d. {:.4f}\n\n
    Average Convergence Time (Epochs): {:.4f} with s.d. {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\nAll Splits Test Accuracies: {}""" \
                .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'], edge_num,
                        np.mean(np.array(avg_test_acc))*100, np.std(avg_test_acc)*100,
                        np.mean(np.array(avg_train_acc))*100, np.std(avg_train_acc)*100,
                        np.mean(avg_convergence_epochs), np.std(avg_convergence_epochs),
                        (time.time()-t0)/3600, np.mean(per_epoch_time), avg_test_acc))


def _str2bool(value):
    """Convert a command-line boolean such as 'True', 'false', '1' or 'no' to a bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def main():
    """Parse the command line, assemble the configuration and launch training.

    Builds the ``params`` (optimisation) and ``net_params`` (network)
    dictionaries from the command-line arguments and from the loaded dataset,
    creates the results/configs output folders, counts the model parameters
    and finally calls ``train_val_pipeline``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_id', help="Please give a value for gpu id", default=0)
    parser.add_argument('--model', help="Please give a value for model name", default="GCN")
    parser.add_argument('--dataset', help="Please give a value for dataset name", default="MUTAG")
    parser.add_argument('--out_dir', help="Please give a value for out_dir", default="out/TUs_graph_classification/")
    parser.add_argument('--seed', help="Please give a value for seed", default=41)
    parser.add_argument('--epochs', help="Please give a value for epochs", default=1000)
    parser.add_argument('--batch_size', help="Please give a value for batch_size", default=20)
    parser.add_argument('--init_lr', help="Please give a value for init_lr", default=7e-4)
    parser.add_argument('--lr_reduce_factor', help="Please give a value for lr_reduce_factor", default=0.5)
    parser.add_argument('--lr_schedule_patience', help="Please give a value for lr_schedule_patience", default=25)
    parser.add_argument('--min_lr', help="Please give a value for min_lr", default=1e-6)
    parser.add_argument('--weight_decay', help="Please give a value for weight_decay", default=0.0)
    parser.add_argument('--print_epoch_interval', help="Please give a value for print_epoch_interval", default=5)
    parser.add_argument('--L', help="Please give a value for L", default=4)
    parser.add_argument('--hidden_dim', help="Please give a value for hidden_dim", default=146)
    parser.add_argument('--out_dim', help="Please give a value for out_dim", default=146)
    # Boolean flags are parsed explicitly: without a type, "--residual False"
    # would arrive as the non-empty (hence truthy) string 'False'.
    parser.add_argument('--residual', type=_str2bool, help="Please give a value for residual", default=True)
    parser.add_argument('--readout', help="Please give a value for readout", default='mean')
    parser.add_argument('--in_feat_dropout', help="Please give a value for in_feat_dropout", default=0.0)
    parser.add_argument('--dropout', help="Please give a value for dropout", default=0.0)
    parser.add_argument('--batch_norm', type=_str2bool, help="Please give a value for batch_norm", default=True)
    parser.add_argument('--self_loop', type=_str2bool, help="Please give a value for self_loop", default=False)
    parser.add_argument('--max_time', help="Please give a value for max_time", default=30)
    parser.add_argument('--optimizer', help="Please choose an optimizer", default='Adam')
    parser.add_argument('--cluster', default=False, action='store_true')

    args = parser.parse_args()
    config = {}

    # device
    config['gpu'] = {'id': int(args.gpu_id), 'use': True}
    device = gpu_setup(config['gpu']['use'], config['gpu']['id'])

    # model, dataset, out_dir
    MODEL_NAME = args.model
    DATASET_NAME = args.dataset
    dataset = LoadData(DATASET_NAME)
    out_dir = args.out_dir

    # parameters
    params = {}
    params['seed'] = int(args.seed)
    params['epochs'] = int(args.epochs)
    params['batch_size'] = int(args.batch_size)
    params['init_lr'] = float(args.init_lr)
    params['lr_reduce_factor'] = float(args.lr_reduce_factor)
    params['lr_schedule_patience'] = int(args.lr_schedule_patience)
    params['min_lr'] = float(args.min_lr)
    params['weight_decay'] = float(args.weight_decay)
    params['print_epoch_interval'] = int(args.print_epoch_interval)
    params['max_time'] = float(args.max_time)
    # Mirror the optimizer name here too, so it is logged with the other
    # optimisation settings and is available to code that looks it up in params.
    params['optimizer'] = args.optimizer

    # network parameters
    net_params = {}
    net_params['device'] = device
    net_params['gpu_id'] = config['gpu']['id']
    net_params['batch_size'] = params['batch_size']
    net_params['L'] = int(args.L)
    net_params['hidden_dim'] = int(args.hidden_dim)
    net_params['out_dim'] = int(args.out_dim)
    net_params['residual'] = args.residual
    net_params['readout'] = args.readout
    net_params['in_feat_dropout'] = float(args.in_feat_dropout)
    net_params['dropout'] = float(args.dropout)
    net_params['batch_norm'] = args.batch_norm
    net_params['self_loop'] = args.self_loop
    net_params['optimizer'] = args.optimizer
    net_params['cluster'] = args.cluster

    # TUs: input sizes are read from the first graph of the dataset
    first_graph = dataset.all.graph_lists[0]
    net_params['in_dim'] = first_graph.ndata['feat'][0].shape[0]
    net_params['edge_dim'] = first_graph.edata['feat'][0].shape[0] \
        if 'feat' in first_graph.edata else None
    net_params['max_num_node'] = dataset.max_num_node
    num_classes = len(np.unique(dataset.all.graph_labels))
    net_params['n_classes'] = num_classes

    # Build the run tag once so the four output paths always share the same
    # timestamp, even if the clock ticks over while they are being assembled.
    run_tag = MODEL_NAME + "_" + DATASET_NAME + "_GPU" + str(config['gpu']['id']) + "_" + time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')
    root_log_dir = out_dir + 'logs/' + run_tag
    root_ckpt_dir = out_dir + 'checkpoints/' + run_tag
    write_file_name = out_dir + 'results/result_' + run_tag
    write_config_file = out_dir + 'configs/config_' + run_tag
    dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file

    os.makedirs(out_dir + 'results', exist_ok=True)
    os.makedirs(out_dir + 'configs', exist_ok=True)

    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
    train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, dirs)


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
