#!/usr/bin/env python
import os
import sys
import h5py
import copy
import argparse
import datetime
import numpy as np
import pandas as pd
from torch import nn, optim
from tensorboardX import SummaryWriter

# Make the project's helper directories importable regardless of the current
# working directory; the star-imports below rely on these search-path entries.
file_path = os.path.dirname(os.path.abspath(__file__))
sys.path.extend(
    os.path.join(file_path, subdir)
    for subdir in ('nets', 'trainers', 'data_aux', '../sequence_model')
)

from cnn_predictors import *

from nn_trainer import *
from gp_trainer import *
from dataset_generator import *

from gp_tools import *

def get_cmd_arguments():
    """Parse the command-line options of the training script.

    Boolean switches (``--attention``, ``--save-model``, ``--save-training``)
    accept an explicit value (``true``/``false``, ``yes``/``no``, ``on``/``off``
    or any number, where zero means false) and evaluate to ``True`` when the
    flag is given without a value.

    Returns:
        argparse.Namespace: the parsed options, stored under each argument's
        ``dest`` name.
    """

    def parse_bool(value):
        # ``type=bool`` is a classic argparse trap: bool('False') is True, so
        # every non-empty string would enable the option. Parse it explicitly.
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', 'off'):
            return False
        try:
            # Numeric values stay accepted ('1', '0', '0.0', ...).
            return bool(float(text))
        except ValueError:
            raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    ap = argparse.ArgumentParser()

    # Required cancer type argument ('+' so that an empty list is rejected)
    ap.add_argument('-c', '--cancer-id', required=True, nargs='+', action='store', type=str, dest='label_ids',
                    help='A list of the h5 file mutation count dataset IDs (e.g. SNV_skin_melanoma_MELAU_AU)')

    # Path arguments
    ap.add_argument('-d', "--data", required=False, nargs='?', action='store', type=str, dest='data_file',
                    default='/storage/datasets/cancer/unzipped_data_matrices_PCAWG_10000_0_0.0.h5', help='Path to h5 data file')
    ap.add_argument('-o', "--out-dir", required=False, nargs='?', action='store', type=str, dest='out_dir',
                    default='/storage/yaari/mutation-density-outputs', help='Path to output directory')
    ap.add_argument('-u', "--held-out", required=False, nargs='?', action='store', type=str, dest='heldout_file',
                    default=None, help='Path to file of held-out samples file')
    ap.add_argument('-t', "--tracks", required=False, nargs='?', action='store', type=str, dest='track_file',
                    default=None, help='Path to predictor tracks selection file')

    # Run type parameters
    ap.add_argument('-s', "--split", required=False, nargs='?', action='store', type=str, dest='split_method',
                    default='random', help='Dataset split method (random/chr)')
    ap.add_argument('-m', "--mappability", required=False, nargs='?', action='store', type=float, dest='mappability',
                    default=0.7, help='Mappability lower bound')
    ap.add_argument('-a', "--attention", required=False, nargs='?', action='store', type=parse_bool, dest='get_attention',
                    const=True, default=False, help='True: train with attention map training')
    ap.add_argument('-gp', "--gaussian", required=False, nargs='?', action='store', type=int, dest='run_gaussian',
                    default=0, help='Number of GP reinitializations and training runs')
    ap.add_argument('-n', "--network", required=False, nargs='?', action='store', type=str, dest='net',
                    default='cnn', help='The type of neural network model to use (\'fc\' or \'cnn\')')

    # Train parameters
    ap.add_argument('-r', "--train-ratio", required=False, nargs='?', action='store', type=float, dest='train_ratio',
                    default=0.8, help='Train set split size ratio')
    ap.add_argument('-ho', "--heldout-ratio", required=False, nargs='?', action='store', type=float, dest='heldout_ratio',
                    default=0.2, help='Held-out set split size ratio (will be extracted prior to train validation split)')
    ap.add_argument('-e', "--epochs", required=False, nargs='?', action='store', type=int, dest='epochs',
                    default=20, help='Number of epochs')
    ap.add_argument('-b', "--batch", required=False, nargs='?', action='store', type=int, dest='bs',
                    default=128, help='Batch size')
    ap.add_argument('-re', "--reruns", required=False, nargs='?', action='store', type=int, dest='nn_reruns',
                    default=1, help='Number of NN reinitializations and training runs')

    # Run management parameters
    ap.add_argument('-sm', "--save-model", required=False, nargs='?', action='store', type=parse_bool, dest='save_model',
                    const=True, default=False, help='True: save best model across all reruns')
    ap.add_argument('-st', "--save-training", required=False, nargs='?', action='store', type=parse_bool, dest='save_training',
                    const=True, default=False, help='True: save training process and results to Tensorboard file')
    # NOTE: passing '-g' with no value yields None (nargs='?' without const).
    ap.add_argument('-g', "--gpus", required=False, nargs='?', action='store', type=str, dest='gpus',
                    default='all', help='GPUs devices (all/comma separted list)')

    return ap.parse_args()

def predict(model, data_ds, label_ids):
    """Run ``model`` over ``data_ds`` and collect its per-label outputs.

    Args:
        model: callable network; ``model(X)`` must return a 3-tuple whose
            first two items are indexable per label (predictions and feature
            vectors). The third item is ignored.
        data_ds: dataset wrapped in a ``DataLoader`` that yields
            ``(X, t_lst)`` batches, ``t_lst`` holding one target per label.
        label_ids: sequence of label identifiers; only its length is used.

    Returns:
        tuple: ``(all_preds, all_true, all_features, scores)`` where the first
        two are per-label lists of values, ``all_features`` is a per-label
        list of arrays concatenated along axis 0, and ``scores`` holds
        ``r2_score(all_true[i], all_preds[i])`` for each label.
    """
    data_loader = DataLoader(data_ds, batch_size=128, shuffle=False, drop_last=False, pin_memory=True, num_workers=4)
    n_labels = len(label_ids)
    all_preds = [[] for _ in range(n_labels)]
    all_features = [[] for _ in range(n_labels)]
    all_true = [[] for _ in range(n_labels)]

    # Feed the inputs to wherever the model's weights live instead of assuming
    # CUDA; fall back to CUDA (the previous behaviour) for parameter-less models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # NOTE(review): the model's train/eval mode is left untouched here, as in
    # the original code -- confirm the caller sets the mode it wants.
    # The forward pass is inside no_grad so no autograd graph is built for
    # pure inference.
    with torch.no_grad():
        for X, t_lst in data_loader:
            y_lst, features_lst, _ = model(X.to(device))
            for i, t in enumerate(t_lst):
                all_features[i].append(features_lst[i].detach().cpu().numpy())
                all_preds[i].extend(y_lst[i].detach().cpu().numpy().tolist())
                all_true[i].extend(t.detach().cpu().numpy().tolist())

    all_features = [np.concatenate(label_features, axis=0) for label_features in all_features]
    scores = [r2_score(all_true[i], all_preds[i]) for i in range(n_labels)]
    return all_preds, all_true, all_features, scores


def run_gp(device, train_set, test_set, ho_set=None):
    """Train a GP, retrying with fewer inducing points after run-time errors.

    Starts at 2000 inducing points and drops by 200 after every
    ``RuntimeError`` raised while running the trainer or scoring its
    calibration, giving up once the count reaches zero.

    Args:
        device: passed through to ``GPTrainer``.
        train_set: passed through to ``GPTrainer``.
        test_set: passed through to ``GPTrainer``; ``test_set[1]`` is used as
            the true values for the calibration p-values.
        ho_set: optional held-out tuple, forwarded as ``heldout_tup``.

    Returns:
        tuple: ``(gp_trainer, gp_test_results, gp_ho_results)`` of the first
        successful run, or ``(None, None, None)`` if every attempt failed.
    """
    for inducing_count in range(2000, 0, -200):
        trainer = GPTrainer(device, train_set, test_set, heldout_tup=ho_set, n_inducing=inducing_count)
        try:
            print('Running GP with {} inducing points...'.format(inducing_count))
            test_results, heldout_results = trainer.run()
            p_values = calc_pvals(
                np.array(test_set[1]),
                test_results['gp_mean'],
                test_results['gp_std'],
                onesided=False,
            )
            print('Test set calibration score: {}'.format(calibration_score_by_pvals(p_values)))
        except RuntimeError as err:
            print('Run failed with {} inducing points. Encountered run-time error in training: {}'
                  .format(inducing_count, err))
        else:
            # First attempt that completes without a run-time error wins.
            return trainer, test_results, heldout_results
    return None, None, None


def main():
    """Train, evaluate and optionally post-process models for the given labels.

    For every rerun a fresh dataset split and model are created, the model is
    trained for the requested number of epochs, and the epoch with the best
    first-label test score is kept. That best model is then evaluated on the
    held-out set and, depending on the command-line options, its attention
    maps, weights and Gaussian-process results are written under the run's
    output directory.
    """
    args = get_cmd_arguments()
    labels_str = '-'.join(args.label_ids)
    out_dir = os.path.join(args.out_dir, labels_str, str(datetime.datetime.now()))
    print('Generating prediction for cancer types: {}'.format(args.label_ids))
    # Configure GPUs
    if args.gpus is None:
        print('Using CPU device.')
        device = torch.device('cpu')
    else:
        print('Using GPU device: \'{}\''.format(args.gpus))
        device = torch.device('cuda')
        if args.gpus != 'all':
            os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    # Create output directory
    if args.save_model or args.save_training or args.get_attention or args.run_gaussian:
        print('Saving results under: \'{}\''.format(out_dir))
        os.makedirs(out_dir)
        with open(os.path.join(out_dir, 'run_params.txt'), 'w') as f:
            for key, value in vars(args).items():
                f.write('{}: {}\n'.format(key, value))

    # Train model for multiple reruns and choose best as final
    accs_df = pd.DataFrame()
    for r in range(args.nn_reruns):
        # Initialize new dataset split
        data_generator = DatasetGenerator(args.data_file,
                                          args.label_ids,
                                          args.mappability,
                                          args.heldout_ratio,
                                          heldout_file=args.heldout_file,
                                          track_file=args.track_file)
        train_ds, test_ds = data_generator.get_datasets(args.split_method, args.train_ratio)
        ho_ds = data_generator.get_heldout_dataset()
        print('Using {} predictors for prediction.'.format(train_ds.get_data_shape()[2]))

        # Initialize a new model
        print('Setting model and optimizers for run {}/{}...'.format(r + 1, args.nn_reruns))
        if args.net == 'fc':
            model = FCNet(train_ds.get_data_shape(), len(args.label_ids))
        else:
            model = SimpleMultiTaskResNet(train_ds.get_data_shape(),
                                          len(args.label_ids),
                                          get_attention_maps=args.get_attention)

        optimizer = optim.Adam(model.parameters(), lr=1e-3, amsgrad=False)
        loss_fn = nn.MSELoss()
        if args.gpus is not None:
            model = nn.DataParallel(model)

        if args.save_training:
            writer = SummaryWriter(logdir=out_dir, comment=labels_str)
            writer.add_text('configurations', str(args), 0)
            writer.add_text('model', str(model), 0)
        else:
            writer = None
        nn_trainer = NNTrainer(model,
                               optimizer,
                               loss_fn,
                               args.bs,
                               args.label_ids,
                               train_ds,
                               test_ds,
                               device,
                               writer,
                               get_attention_maps=args.get_attention)

        # Run datasplit training and evaluation.
        # Start from -inf (not 0) so that a best model is always recorded, even
        # when every epoch yields a non-positive score; otherwise the best_*
        # names used below would be unbound.
        best_run_acc = float('-inf')
        best_epoch = 0
        best_run_model = None
        for epoch in range(1, args.epochs + 1):
            print('Running epoch {}/{}'.format(epoch, args.epochs))
            train_losses, train_accs, train_features_lst, train_pred_lst, train_true_lst = nn_trainer.train(epoch, r)
            test_losses, test_accs, test_features_lst, test_pred_lst, test_true_lst, test_attention = nn_trainer.test(epoch, r)

            # Test GP over the new feature vectors, ignore run if GP fails
            # NOTE(review): the `continue` below only skips the current label,
            # not the epoch as the message says -- confirm the intent.
            for l in range(len(args.label_ids)):
                train_set = (np.array(train_features_lst[l]), np.array(train_true_lst[l]), train_ds.get_chromosome_locations())
                test_set = (np.array(test_features_lst[l]), np.array(test_true_lst[l]), test_ds.get_chromosome_locations())
                gp_trainer, gp_test_results, gp_ho_results = run_gp(device, train_set, test_set)
                if gp_test_results is None:
                    print('GP Run failed, skipping to next epoch.')
                    continue
                print('#######', r2_score(test_ds.get_stds(), gp_test_results['gp_std']), '#######')

            # Keep only the best model according to test performance
            if test_accs[0] > best_run_acc:
                print('Changing run model since best R2 was {} compared to previous {}'.format(test_accs[0], best_run_acc))
                best_run_acc, best_epoch = test_accs[0], epoch
                best_train_accs, best_test_accs = train_accs, test_accs
                best_run_model, best_run_att = copy.deepcopy(model), test_attention
                # Keep the predictions that belong to the stored attention maps.
                best_test_pred_lst, best_test_true_lst = test_pred_lst, test_true_lst

        # Training for this run is over: flush and release the Tensorboard file.
        if writer is not None:
            writer.close()

        if best_run_model is None:
            # Happens with zero epochs or when no epoch produced a comparable score.
            print('No model was selected for run {}/{}, skipping its evaluation.'.format(r + 1, args.nn_reruns))
            continue

        # Evaluate model performance over held-out set
        print('Best validation accuracy for run {}/{} was: {}.'.format(r + 1, args.nn_reruns, best_run_acc))
        print('Running best model over {} held-out set samples...'.format(ho_ds.get_data_shape()[0]))
        ho_preds, ho_labels, ho_features, ho_accs = predict(best_run_model, ho_ds, args.label_ids)
        print('Model held-out accuracy: {}'.format(ho_accs))

        # Save run performance
        for j, l in enumerate(args.label_ids):
            accs_df.loc[r, 'Train_{}'.format(l)] = best_train_accs[j]
            accs_df.loc[r, 'Test_{}'.format(l)] = best_test_accs[j]
            accs_df.loc[r, 'Held-out_{}'.format(l)] = ho_accs[j]

        # Save attention maps from best overall model
        if args.get_attention:
            with h5py.File(os.path.join(out_dir, 'attention_maps_{}.h5'.format(r)), 'w') as h5f:
                h5f.create_dataset('attention_maps', data=best_run_att)
                h5f.create_dataset('chr_locs', data=test_ds.get_chromosome_locations())
                h5f.create_dataset('idxs', data=test_ds.get_set_indices())
                h5f.create_dataset('pred_lbls', data=best_test_pred_lst)
                h5f.create_dataset('true_lbls', data=best_test_true_lst)

        # Save best run model
        if args.save_model:
            print('Saving model and test indices from run {} to {}...'.format(r, out_dir))
            np.save(os.path.join(out_dir, 'test_indices_{}'.format(r)), ho_ds.get_set_indices())
            torch.save(best_run_model.state_dict(), os.path.join(out_dir, 'best_model_{}.pt'.format(r)))

        # Run GP on best overall model
        if args.run_gaussian > 0:
            print('Computing {} train set features...'.format(train_ds.get_data_shape()[0]))
            train_preds, train_labels, train_features, train_acc = predict(best_run_model, train_ds, args.label_ids)
            print('Model train accuracy: {}'.format(train_acc))
            print('Computing {} validation set features...'.format(test_ds.get_data_shape()[0]))
            test_preds, test_labels, test_features, test_acc = predict(best_run_model, test_ds, args.label_ids)
            print('Model validation accuracy: {}'.format(test_acc))
            # The context manager guarantees the results file is flushed and
            # closed even if one of the GP runs raises.
            with h5py.File(os.path.join(out_dir, 'gp_results_{}.h5'.format(r)), 'w') as gp_h5:
                for l in range(len(args.label_ids)):
                    print('Running gaussian process model for {}...'.format(args.label_ids[l]))
                    lbl_grp = gp_h5.create_group(args.label_ids[l])
                    train_set = (np.array(train_features[l]), np.array(train_labels[l]), train_ds.get_chromosome_locations())
                    test_set = (np.array(test_features[l]), np.array(test_labels[l]), test_ds.get_chromosome_locations())
                    ho_set = (np.array(ho_features[l]), np.array(ho_labels[l]), ho_ds.get_chromosome_locations())

                    # Run and store multiple GPs
                    for j in range(args.run_gaussian):
                        # One-based progress counter, like the other messages;
                        # the stored result key stays zero-based.
                        print('GP run {}/{}...'.format(j + 1, args.run_gaussian))
                        gp_trainer, gp_test_results, gp_ho_results = run_gp(device, train_set, test_set, ho_set)
                        if gp_test_results is not None:
                            gp_trainer.save_results(gp_test_results, gp_ho_results, lbl_grp, str(j))

    if args.save_training:
        accs_df.to_csv(os.path.join(out_dir, 'run_accuracies.csv'))

    print('Results summary for {} runs:\n {}'.format(args.nn_reruns, accs_df.describe()))

    print('Done!')


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
