#!/usr/bin/env python
import os
import sys
import h5py
import copy
import argparse
import datetime
import numpy as np
import pandas as pd
from torch import nn, optim
from tensorboardX import SummaryWriter

# Make the project's helper packages importable by their bare module names:
# the star imports further down rely on these directories being on sys.path.
file_path = os.path.dirname(os.path.abspath(__file__))
sys.path.extend(os.path.join(file_path, sub_dir)
                for sub_dir in ('nets', 'trainers', 'data_aux', '../sequence_model'))

from rnn_predictors import *
from cnn_predictors import *

from nn_trainer import *
from gp_trainer import *
from dataset_generator import *

from gp_tools import *

def get_cmd_arguments(text=None):
    """Parse the k-fold training options from the command line or from a string.

    Args:
        text: Optional whitespace-separated argument string
            (e.g. ``'-c SNV_a SNV_b -k 3'``). When it is None or empty the
            arguments are read from ``sys.argv`` instead.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including a
            missing ``-c`` option or a ``-c`` given without any dataset ID.
    """
    ap = argparse.ArgumentParser()

    # Required cancer type argument.
    # nargs='+' (not '*'): a bare '-c' would otherwise parse to an empty label
    # list and only fail much later, deep inside dataset generation.
    ap.add_argument('-c', '--cancer-id', required=True, nargs='+', action='store', type=str, dest='label_ids',
                    help='A list of the h5 file mutation count dataset IDs (e.g. SNV_skin_melanoma_MELAU_AU)')

    # Path arguments
    ap.add_argument('-d', '--data', required=False, nargs='?', action='store', type=str, dest='data_file',
                    default='/storage/datasets/cancer/unzipped_data_matrices_pcawg_10k.h5', help='Path to h5 data file')
    ap.add_argument('-o', '--out-dir', required=False, nargs='?', action='store', type=str, dest='out_dir',
                    default='/storage/yaari/mutation-density-outputs', help='Path to output directory')
    ap.add_argument('-t', '--tracks', required=False, nargs='?', action='store', type=str, dest='track_file',
                    default=None, help='Path to predictor tracks selection file')

    # Run type parameters
    ap.add_argument('-s', '--split', required=False, nargs='?', action='store', type=str, dest='split_method',
                    default='random', help='Dataset split method (random/chr)')
    ap.add_argument('-m', '--mappability', required=False, nargs='?', action='store', type=float, dest='mappability',
                    default=0.7, help='Mappability lower bound')
    ap.add_argument('-a', '--attention', required=False, action='store_true', dest='get_attention',
                    help='True: train with attention map training and save attention maps')
    # This option is a run count, not a flag; the old help text described it as a boolean.
    ap.add_argument('-gp', '--gaussian', required=False, nargs='?', action='store', type=int, dest='run_gaussian',
                    default=10,
                    help='Number of gaussian process regression runs on the best performing model (0 disables)')

    # Train parameters
    ap.add_argument('-k', required=False, nargs='?', action='store', type=int, dest='k',
                    default=5, help='Number of folds')
    ap.add_argument('-e', '--epochs', required=False, nargs='?', action='store', type=int, dest='epochs',
                    default=20, help='Number of epochs')
    ap.add_argument('-b', '--batch', required=False, nargs='?', action='store', type=int, dest='bs',
                    default=128, help='Batch size')
    ap.add_argument('-re', '--reruns', required=False, nargs='?', action='store', type=int, dest='nn_reruns',
                    default=1, help='Number of model reinitializations and training runs')

    # Run management parameters
    ap.add_argument('-sm', '--save-model', required=False, action='store_true', dest='save_model',
                    help='True: save best model across all reruns')
    ap.add_argument('-st', '--save-training', required=False, action='store_true', dest='save_training',
                    help='True: save training process and results to Tensorboard file')
    ap.add_argument('-g', '--gpus', required=False, nargs='?', action='store', type=str, dest='gpus',
                    default='all', help='GPUs devices (all/comma separted list)')
    ap.add_argument('-u', '--sub_mapp', required=False, action='store_true', dest='sub_mapp',
                    help='True: run model on regions below mappability threshold')

    # An explicit argument string takes precedence over sys.argv.
    if text:
        return ap.parse_args(text.split())
    return ap.parse_args()


def predict(model, data_ds, label_ids, batch_size=128, num_workers=4):
    """Run ``model`` over ``data_ds`` and collect per-label outputs.

    The dataset is read in order (no shuffling, no dropped batches) and every
    batch is moved to the GPU with ``.cuda()``, so a CUDA device is required.

    Args:
        model: Callable returning ``(predictions, features, attention)`` for a
            batch, where the first two are indexable per label.
        data_ds: Dataset yielding ``(X, targets)`` pairs, ``targets`` holding
            one entry per label.
        label_ids: Label identifiers; only their count is used here.
        batch_size: Batch size of the internal data loader.
        num_workers: Number of data loader worker processes.

    Returns:
        A tuple ``(all_preds, all_true, all_features, scores, attention)``:
        per-label prediction lists, per-label target lists, per-label feature
        arrays concatenated along axis 0, per-label ``r2_score`` values, and
        the attention outputs concatenated along axis 0 (None when the model
        produced no attention output).

    NOTE(review): the model's train/eval mode is left exactly as the caller set
    it - confirm callers switch to eval mode if the network uses dropout or
    batch normalisation.
    """
    data_loader = DataLoader(data_ds, batch_size=batch_size, shuffle=False, drop_last=False,
                             pin_memory=True, num_workers=num_workers)
    n_labels = len(label_ids)
    all_preds = [[] for _ in range(n_labels)]
    all_features = [[] for _ in range(n_labels)]
    all_true = [[] for _ in range(n_labels)]
    all_att = []

    # Inference only: the forward pass itself must run under no_grad, otherwise
    # autograd keeps the whole graph of every batch alive for nothing.
    with torch.no_grad():
        for X, t_lst in data_loader:
            y_lst, features_lst, attention = model(X.cuda())
            # Presumably attention can be absent when attention maps are
            # disabled - guard instead of crashing on None (TODO confirm).
            if attention is not None:
                all_att.append(attention.detach().cpu().numpy())
            for i, t in enumerate(t_lst):
                all_features[i].append(features_lst[i].detach().cpu().numpy())
                all_preds[i].extend(y_lst[i].detach().cpu().numpy().tolist())
                all_true[i].extend(t.detach().cpu().numpy().tolist())

    all_features = [np.concatenate(label_features, axis=0) for label_features in all_features]
    scores = [r2_score(true, preds) for true, preds in zip(all_true, all_preds)]
    attention_maps = np.concatenate(all_att, axis=0) if all_att else None
    return all_preds, all_true, all_features, scores, attention_maps


def run_gp(device, train_set, test_set, ho_set=None):
    """Train a GP, retrying with fewer inducing points after run-time errors.

    The first attempt uses 2000 inducing points; every ``RuntimeError`` raised
    during an attempt lowers the count by 200 until an attempt completes or the
    count reaches zero. The calibration score of the test set is printed for a
    completed attempt.

    Args:
        device: Device handed to ``GPTrainer``.
        train_set: Training tuple handed to ``GPTrainer``.
        test_set: Test tuple handed to ``GPTrainer``; element 1 holds the
            observed values used for the calibration p-values.
        ho_set: Optional held-out tuple, forwarded as ``heldout_tup``.

    Returns:
        ``(gp_trainer, gp_test_results, gp_ho_results)`` of the first attempt
        that completed, or ``(None, None, None)`` when every attempt failed.
    """
    for n_inducing in range(2000, 0, -200):
        gp_trainer = GPTrainer(device, train_set, test_set, heldout_tup=ho_set, n_inducing=n_inducing)
        try:
            print('Running GP with {} inducing points...'.format(n_inducing))
            gp_test_results, gp_ho_results = gp_trainer.run()
            observed = np.array(test_set[1])
            pvals = calc_pvals(observed, gp_test_results['gp_mean'], gp_test_results['gp_std'], onesided=False)
            calibration = calibration_score_by_pvals(pvals)
            print('Test set calibration score: {}'.format(calibration))
        except RuntimeError as err:
            failure_msg = 'Run failed with {} inducing points. Encountered run-time error in training: {}'
            print(failure_msg.format(n_inducing, err))
        else:
            # First attempt that got through training and calibration wins.
            return gp_trainer, gp_test_results, gp_ho_results
    return None, None, None


def _save_attention_h5(path, attention, dataset, pred_lbls, true_lbls):
    """Write attention maps with their locations, indices and labels to an h5 file."""
    with h5py.File(path, 'w') as h5f:
        h5f.create_dataset('attention_maps', data=attention)
        h5f.create_dataset('chr_locs', data=dataset.get_chromosome_locations())
        h5f.create_dataset('idxs', data=dataset.get_set_indices())
        h5f.create_dataset('pred_lbls', data=pred_lbls)
        h5f.create_dataset('true_lbls', data=true_lbls)


def _run_gps_to_h5(h5_path, label_ids, n_runs, device, train_tup, eval_tup, banner):
    """Fit ``n_runs`` GPs per label and store every successful run in ``h5_path``.

    ``train_tup`` and ``eval_tup`` are ``(features, labels, dataset)`` triples;
    ``banner`` is the per-label progress message with one ``{}`` placeholder.
    """
    train_features, train_labels, train_ds = train_tup
    eval_features, eval_labels, eval_ds = eval_tup
    # The context manager guarantees the results file is flushed and closed.
    with h5py.File(h5_path, 'w') as gp_h5:
        for l, label_id in enumerate(label_ids):
            print(banner.format(label_id))
            lbl_grp = gp_h5.create_group(label_id)
            train_set = (np.array(train_features[l]), np.array(train_labels[l]),
                         train_ds.get_chromosome_locations())
            eval_set = (np.array(eval_features[l]), np.array(eval_labels[l]),
                        eval_ds.get_chromosome_locations())

            # Run and store multiple GPs (stored keys stay 0-based).
            for j in range(n_runs):
                print('GP run {}/{}...'.format(j + 1, n_runs))
                gp_trainer, gp_test_results, gp_ho_results = run_gp(device, train_set, eval_set)
                if gp_test_results is not None:
                    gp_trainer.save_results(gp_test_results, gp_ho_results, lbl_grp, str(j))


def main(input_args = None):
    """Run k-fold training and, per fold, the GP evaluation of the best model.

    For every fold the network is trained ``nn_reruns`` times; after each epoch
    a GP is fitted on the new feature vectors of every label and the epoch is
    ignored when any of these GP fits fails. The model with the best validation
    score of the first label across all reruns is then used to optionally save
    attention maps, save the model, run the final GPs and evaluate regions
    below the mappability threshold. All outputs go to a fresh time-stamped
    directory under ``<out_dir>/kfold/<labels>``.

    Args:
        input_args: Optional pre-built namespace with the same attributes that
            ``get_cmd_arguments`` produces. When None, the command line is
            parsed instead.
    """
    args = get_cmd_arguments() if input_args is None else input_args

    labels_str = '-'.join(args.label_ids)
    out_dir = os.path.join(args.out_dir, 'kfold', labels_str, str(datetime.datetime.now()))
    print('Generating prediction for cancer types: {}'.format(args.label_ids))

    # The parser defaults gpus to 'all', so the CPU branch is only reachable
    # through a caller-supplied namespace with gpus=None.
    if args.gpus is None:
        print('Using CPU device.')
        device = torch.device('cpu')
    else:
        print('Using GPU device: \'{}\''.format(args.gpus))
        device = torch.device('cuda')
        if args.gpus != 'all':
            os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    os.makedirs(out_dir)
    with open(os.path.join(out_dir, 'run_params.txt'), 'w') as f:
        for key, value in vars(args).items():
            f.write('{}: {}\n'.format(key, value))

    best_model_file = os.path.join(out_dir, 'best_model_fold_{}.pt')
    test_set_file = os.path.join(out_dir, 'test_indices_fold_{}')

    if args.save_model or args.save_training:
        print('Saving results under: \'{}\''.format(out_dir))

    data_generator = KFoldDatasetGenerator(args.data_file,
                                           args.label_ids,
                                           args.k,
                                           args.mappability,
                                           args.split_method,
                                           track_file=args.track_file)

    print('Running {}-fold prediction...'.format(args.k))
    for k in range(args.k):
        train_ds, test_ds = data_generator.get_datasets(k)
        # Start below any reachable score and with no model, so a fold can never
        # silently reuse the best model of the previous fold.
        best_overall_acc = float('-inf')
        best_overall_model = None
        best_overall_att = best_overall_preds = best_overall_true = None
        for r in range(args.nn_reruns):
            print('Setting model and optimizers for run {}/{} and fold {}/{}...'.format(r + 1, args.nn_reruns, k + 1, args.k))
            model = SimpleMultiTaskResNet(train_ds.get_data_shape(), len(args.label_ids), get_attention_maps=args.get_attention)
            optimizer = optim.Adam(model.parameters(), lr=1e-3, amsgrad=False)
            loss_fn = nn.MSELoss()
            if args.gpus is not None:
                model = nn.DataParallel(model)

            if args.save_training:
                writer = SummaryWriter(logdir=out_dir, comment=labels_str)
                writer.add_text('configurations', str(args), 0)
                writer.add_text('model', str(model), 0)
            else:
                writer = None
            trainer = NNTrainer(model,
                                optimizer,
                                loss_fn,
                                args.bs,
                                args.label_ids,
                                train_ds,
                                test_ds,
                                device,
                                writer,
                                get_attention_maps=args.get_attention)

            best_run_acc = float('-inf')
            best_run_model = None
            best_run_att = best_run_preds = best_run_true = None
            for epoch in range(1, args.epochs + 1):
                print('Running epoch {}/{}'.format(epoch, args.epochs))
                train_losses, train_accs, train_features_lst, train_pred_lst, train_true_lst = trainer.train(epoch, r)
                test_losses, test_accs, test_features_lst, test_pred_lst, test_true_lst, test_attention = trainer.test(epoch, r)

                # Test GP over the new feature vectors, ignore the epoch if a GP fails
                gp_failed = False
                for l in range(len(args.label_ids)):
                    train_set = (np.array(train_features_lst[l]), np.array(train_true_lst[l]), train_ds.get_chromosome_locations())
                    test_set = (np.array(test_features_lst[l]), np.array(test_true_lst[l]), test_ds.get_chromosome_locations())
                    _, gp_test_results, _ = run_gp(device, train_set, test_set)
                    if gp_test_results is None:
                        gp_failed = True
                        break
                if gp_failed:
                    # The skip has to happen at epoch level; a `continue` inside
                    # the label loop only moved on to the next label.
                    print('GP Run failed, skipping to next epoch.')
                    continue

                # Model selection looks at the first label's score only.
                if test_accs[0] > best_run_acc:
                    best_run_acc = test_accs[0]
                    print('Best model validation accuracy is now: {}'.format(best_run_acc))
                    best_run_model = copy.deepcopy(model)
                    # Keep the outputs that belong to this very model.
                    best_run_att, best_run_preds, best_run_true = test_attention, test_pred_lst, test_true_lst

            if writer is not None:
                writer.close()

            if best_run_model is not None and best_run_acc > best_overall_acc:
                best_overall_acc = best_run_acc
                best_overall_model = best_run_model
                best_overall_att, best_overall_preds, best_overall_true = best_run_att, best_run_preds, best_run_true

            print('Best test accuracy for run {}/{} was: {}.'.format(r + 1, args.nn_reruns, best_run_acc))
        print('Best overall accuract over {} reruns was: {}.'.format(args.nn_reruns, best_overall_acc))

        if best_overall_model is None:
            print('No usable model was trained for fold {}/{}, skipping its outputs.'.format(k + 1, args.k))
            continue

        # Save attention maps from best overall model
        if args.get_attention:
            _save_attention_h5(os.path.join(out_dir, 'attention_maps_{}.h5'.format(k)),
                               best_overall_att, test_ds, best_overall_preds, best_overall_true)

        # Save best run model
        if args.save_model:
            print('Saving model and test indices for future evaluations to {}...'.format(test_set_file.format(k)))
            np.save(test_set_file.format(k), test_ds.get_set_indices())
            torch.save(best_overall_model.state_dict(), best_model_file.format(k))

        # Run GP on best overall model
        if args.run_gaussian > 0:
            print('Computing {} train set features...'.format(train_ds.get_data_shape()[0]))
            train_preds, train_labels, train_features, train_acc, _ = predict(best_overall_model, train_ds, args.label_ids)
            print('Model train accuracy: {}'.format(train_acc))
            print('Computing {} validation set features...'.format(test_ds.get_data_shape()[0]))
            test_preds, test_labels, test_features, test_acc, _ = predict(best_overall_model, test_ds, args.label_ids)
            print('Model validation accuracy: {}'.format(test_acc))
            _run_gps_to_h5(os.path.join(out_dir, 'gp_results_fold_{}.h5'.format(k)),
                           args.label_ids, args.run_gaussian, device,
                           (train_features, train_labels, train_ds),
                           (test_features, test_labels, test_ds),
                           'Running gaussian process model for {}...')

        if args.sub_mapp:
            print('Running model on sub-mappabbility-theshold regions')
            sub_ds = data_generator.get_below_mapp()
            print('Computing {} train set features...'.format(train_ds.get_data_shape()[0]))
            train_preds, train_labels, train_features, train_acc, _ = predict(best_overall_model, train_ds, args.label_ids)
            print('Model train accuracy: {}'.format(train_acc))
            # Report the size of the sub-threshold set itself (assumes it exposes
            # get_data_shape like the train/test sets - TODO confirm).
            print('Computing {} sub-theshold features...'.format(sub_ds.get_data_shape()[0]))
            sub_preds, sub_labels, sub_features, sub_acc, sub_attention = predict(best_overall_model, sub_ds, args.label_ids)
            # Save attention maps from unmappable regions (written once, by the
            # first fold that gets here)
            sub_att_path = os.path.join(out_dir, 'attention_maps_submapp.h5')
            if args.get_attention and not os.path.exists(sub_att_path):
                _save_attention_h5(sub_att_path, sub_attention, sub_ds, sub_preds, sub_labels)

            print('Model accuracy on sub-theshold regions: {}'.format(sub_acc))
            _run_gps_to_h5(os.path.join(out_dir, 'sub_mapp_results_fold_{}.h5'.format(k)),
                           args.label_ids, args.run_gaussian, device,
                           (train_features, train_labels, train_ds),
                           (sub_features, sub_labels, sub_ds),
                           'Running gaussian process model for {} on below threshold regions...')
    print('Done!')


# Run the k-fold pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
