#!/usr/bin/env python
import os
import sys
import h5py
import copy
import argparse
import datetime
import numpy as np
import pandas as pd
from torch import nn, optim
from tensorboardX import SummaryWriter

# Directory containing this script; the project's helper packages are located
# relative to it and appended to the import search path (in this order) so the
# star-imports below can resolve.
file_path = os.path.dirname(os.path.abspath(__file__))
sys.path.extend(
    os.path.join(file_path, subdir)
    for subdir in ('nets', 'trainers', 'data_aux', '../sequence_model')
)

from cnn_predictors import *

from nn_trainer import *
from dataset_generator import *

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy. This parser accepts the usual
    textual spellings and, for backward compatibility with the former
    ``type=float`` declaration of ``--save-training``, numeric tokens
    (zero is False, anything else is True).

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    try:
        return bool(float(token))
    except ValueError:
        raise argparse.ArgumentTypeError('Expected a boolean value, got \'{}\''.format(value))


def get_cmd_arguments(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with the destinations defined below
        (``label_ids``, ``data_file``, ``out_dir``, ``heldout_file``,
        ``track_file``, ``split_method``, ``mappability``, ``get_attention``,
        ``net``, ``train_ratio``, ``heldout_ratio``, ``epochs``, ``bs``,
        ``nn_reruns``, ``save_model``, ``save_training``, ``gpus``).
    """
    ap = argparse.ArgumentParser()

    # Required cancer type argument
    ap.add_argument('-c', '--cancer-id', required=True, nargs='*', action='store', type=str, dest='label_ids',
                    help='A list of the h5 file mutation count dataset IDs (e.g. SNV_skin_melanoma_MELAU_AU)')

    # Path arguments
    ap.add_argument('-d', "--data", required=False, nargs='?', action='store', type=str, dest='data_file',
                    default='/storage/datasets/cancer/unzipped_data_matrices_PCAWG_10000_0_0.0.h5', help='Path to h5 data file')
    ap.add_argument('-o', "--out-dir", required=False, nargs='?', action='store', type=str, dest='out_dir',
                    default='/storage/yaari/mutation-density-outputs', help='Path to output directory')
    ap.add_argument('-u', "--held-out", required=False, nargs='?', action='store', type=str, dest='heldout_file',
                    default=None, help='Path to file of held-out samples file')
    ap.add_argument('-t', "--tracks", required=False, nargs='?', action='store', type=str, dest='track_file',
                    default=None, help='Path to predictor tracks selection file')

    # Run type parameters
    ap.add_argument('-s', "--split", required=False, nargs='?', action='store', type=str, dest='split_method',
                    default='random', help='Dataset split method (random/chr)')
    ap.add_argument('-m', "--mappability", required=False, nargs='?', action='store', type=float, dest='mappability',
                    default=0.7, help='Mappability lower bound')
    # Boolean switches use _str2bool so that "-a False" really disables the
    # option; const=True makes a bare "-a" (no value) enable it.
    ap.add_argument('-a', "--attention", required=False, nargs='?', action='store', type=_str2bool, dest='get_attention',
                    const=True, default=False, help='True: train with attention map training')
    ap.add_argument('-n', "--network", required=False, nargs='?', action='store', type=str, dest='net',
                    default='cnn', help='The type of neural network model to use (\'fc\' or \'cnn\')')

    # Train parameters
    ap.add_argument('-r', "--train-ratio", required=False, nargs='?', action='store', type=float, dest='train_ratio',
                    default=0.8, help='Train set split size ratio')
    ap.add_argument('-ho', "--heldout-ratio", required=False, nargs='?', action='store', type=float, dest='heldout_ratio',
                    default=0.2, help='Held-out set split size ratio (will be extracted prior to train validation split)')
    ap.add_argument('-e', "--epochs", required=False, nargs='?', action='store', type=int, dest='epochs',
                    default=20, help='Number of epochs')
    ap.add_argument('-b', "--batch", required=False, nargs='?', action='store', type=int, dest='bs',
                    default=128, help='Batch size')
    ap.add_argument('-re', "--reruns", required=False, nargs='?', action='store', type=int, dest='nn_reruns',
                    default=1, help='Number of NN reinitializations and training runs')

    # Run management parameters
    ap.add_argument('-sm', "--save-model", required=False, nargs='?', action='store', type=_str2bool, dest='save_model',
                    const=True, default=False, help='True: save best model across all reruns')
    # Was declared type=float, which rejected the documented "True" value.
    ap.add_argument('-st', "--save-training", required=False, nargs='?', action='store', type=_str2bool, dest='save_training',
                    const=True, default=False, help='True: save training process and results to Tensorboard file')
    ap.add_argument('-g', "--gpus", required=False, nargs='?', action='store', type=str, dest='gpus',
                    default='all', help='GPUs devices (all/comma separted list)')

    return ap.parse_args(argv)

def predict(model, data_ds, label_ids, batch_size=128):
    """Run ``model`` over ``data_ds`` and collect per-label outputs and scores.

    The whole pass, including the forward call, runs under ``torch.no_grad()``
    so no autograd graph is built while predicting.

    Args:
        model: Callable returning a 3-tuple whose first two items are indexable
            per label (predictions and feature vectors); the third item is
            ignored here.
        data_ds: Dataset handed to ``DataLoader``; each batch must unpack as
            ``(X, t_lst)`` where ``t_lst`` iterates over the per-label targets.
        label_ids: Label identifiers; only their count is used.
        batch_size: Number of samples per prediction batch (default 128).

    Returns:
        Tuple ``(all_preds, all_true, all_features, scores)``: per-label lists
        of predictions and targets, per-label feature arrays concatenated
        along axis 0, and one ``r2_score(true, pred)`` value per label.
    """
    data_loader = DataLoader(data_ds, batch_size=batch_size, shuffle=False, drop_last=False,
                             pin_memory=True, num_workers=4)
    n_labels = len(label_ids)
    all_preds = [[] for _ in range(n_labels)]
    all_features = [[] for _ in range(n_labels)]
    all_true = [[] for _ in range(n_labels)]

    # NOTE(review): the model's train/eval mode is left untouched here; callers
    # that rely on dropout/batch-norm inference behaviour should call
    # model.eval() before predicting -- confirm against the trainer.
    with torch.no_grad():
        for X, t_lst in data_loader:
            # Inputs are always moved to the default CUDA device.
            y_lst, features_lst, _ = model(X.cuda())
            for i, t in enumerate(t_lst):
                all_features[i].append(features_lst[i].detach().cpu().numpy())
                all_preds[i].extend(y_lst[i].detach().cpu().numpy().tolist())
                all_true[i].extend(t.detach().cpu().numpy().tolist())

    # Stitch the per-batch feature chunks of every label into one array.
    all_features = [np.concatenate(chunks, axis=0) for chunks in all_features]
    scores = [r2_score(all_true[i], all_preds[i]) for i in range(n_labels)]
    return all_preds, all_true, all_features, scores



def main():
    """Parse the command line and train one freshly initialized model per rerun.

    For every rerun a new ``DatasetGenerator`` split, model, Adam optimizer and
    ``NNTrainer`` are created, and the trainer's ``train``/``test`` methods are
    called once per epoch. When any of the save/attention options is set, the
    output directory is created and the run parameters are written to
    ``run_params.txt``; when ``save_training`` is set, a TensorBoard
    ``SummaryWriter`` is opened per rerun and closed once its epochs finish.

    Best-model selection, held-out evaluation and result saving are currently
    disabled (see the string literal inside the epoch loop).
    """
    args = get_cmd_arguments()
    labels_str = '-'.join(args.label_ids)
    out_dir = os.path.join(args.out_dir, labels_str, str(datetime.datetime.now()))
    print('Generating prediction for cancer types: {}'.format(args.label_ids))
    # Configure GPUs
    # args.gpus defaults to 'all'; it is None only when the option is given
    # without a value, which selects the CPU.
    if args.gpus is None:
        print('Using CPU device.')
        device = torch.device('cpu')
    else:
        print('Using GPU device: \'{}\''.format(args.gpus))
        device = torch.device('cuda')
        if args.gpus != 'all':
            os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    # Create output directory
    if args.save_model or args.save_training or args.get_attention:
        print('Saving results under: \'{}\''.format(out_dir))
        os.makedirs(out_dir)
        # Record the exact configuration of this run next to its outputs.
        with open(os.path.join(out_dir, 'run_params.txt'), 'w') as f:
            for k, v in vars(args).items():
                f.write('{}: {}\n'.format(k, v))

    # Train model for multiple reruns and choose best as final
    # NOTE: best_overall_acc, accs_df, best_run_acc and best_epoch are only
    # referenced by the disabled block inside the epoch loop.
    best_overall_acc = 0
    accs_df = pd.DataFrame()
    for r in range(args.nn_reruns):
        # Initialize new dataset split
        data_generator = DatasetGenerator(args.data_file,
                                          args.label_ids,
                                          args.mappability,
                                          args.heldout_ratio,
                                          heldout_file=args.heldout_file,
                                          track_file=args.track_file)
        train_ds, test_ds = data_generator.get_datasets(args.split_method, args.train_ratio)
        ho_ds = data_generator.get_heldout_dataset()
        print('Using {} predictors for prediction.'.format(train_ds.get_data_shape()[2]))

        # Initialize a new model
        print('Setting model and optimizers for run {}/{}...'.format(r + 1, args.nn_reruns))
        if args.net == 'fc':
            model = FCNet(train_ds.get_data_shape(), len(args.label_ids))
        else:
            model = SimpleMultiTaskResNet(train_ds.get_data_shape(),
                                          len(args.label_ids),
                                          get_attention_maps=args.get_attention)

        optimizer = optim.Adam(model.parameters(), lr=1e-3, amsgrad=False)
        loss_fn = nn.MSELoss()
        if args.gpus is not None:
            model = nn.DataParallel(model)

        if args.save_training:
            writer = SummaryWriter(logdir=out_dir, comment=labels_str)
            writer.add_text('configurations', str(args), 0)
            writer.add_text('model', str(model), 0)
        else:
            writer = None
        nn_trainer = NNTrainer(model,
                               optimizer,
                               loss_fn,
                               args.bs,
                               args.label_ids,
                               train_ds,
                               test_ds,
                               device,
                               writer,
                               get_attention_maps=args.get_attention)

        # Run datasplit training and evaluation
        best_run_acc = 0
        best_epoch = 0
        for epoch in range(1, args.epochs + 1):
            print('Running epoch {}/{}'.format(epoch, args.epochs))
            train_losses, train_obs_accs, train_var_accs, train_features_lst, train_pred_lst, train_true_lst = nn_trainer.train(epoch, r)
            test_losses, test_obs_accs, test_var_accs, test_features_lst, test_pred_lst, test_true_lst, test_attention = nn_trainer.test(epoch, r)

            # NOTE: the string literal below is disabled code kept for
            # reference; it is never executed. It refers to `test_accs` and
            # `train_accs`, which are not assigned anywhere in this function,
            # so it cannot be re-enabled as-is.
            '''
            # Keep only the best model according to test performance 
            if test_accs[0] > best_run_acc:
                print('Changing run model since best R2 was {} compared to previous {}'.format(test_obs_accs[0], best_run_acc))
                best_run_acc, best_epoch = test_obs_accs[0], epoch
                best_train_accs, best_test_accs = train_accs, test_accs
                best_run_model, best_run_att = copy.deepcopy(model), test_attention
            
        # Evaluate model performance over held-out set
        print('Best validation accuracy for run {}/{} was: {}.'.format(r + 1, args.nn_reruns, best_run_acc))
        print('Running best model over {} held-out set samples...'.format(ho_ds.get_data_shape()[0]))
        ho_preds, ho_labels, ho_features, ho_accs = predict(best_run_model, ho_ds, args.label_ids)
        print('Model held-out accuracy: {}'.format(ho_accs))

        # Save run performance
        for j, l in enumerate(args.label_ids):
            accs_df.loc[r, 'Train_{}'.format(l)] = best_train_accs[j]
            accs_df.loc[r, 'Test_{}'.format(l)] = best_test_accs[j]
            accs_df.loc[r, 'Held-out_{}'.format(l)] = ho_accs[j]

        # Save attention maps from best overall model
        if args.get_attention:
            with h5py.File(os.path.join(out_dir, 'attention_maps_{}.h5'.format(r)), 'w') as h5f:
                h5f.create_dataset('attention_maps', data=best_run_att)
                h5f.create_dataset('chr_locs', data=test_ds.get_chromosome_locations())
                h5f.create_dataset('idxs', data=test_ds.get_set_indices())
                h5f.create_dataset('pred_lbls', data=test_pred_lst)
                h5f.create_dataset('true_lbls', data=test_true_lst)

        # Save best run model
        if args.save_model:
            print('Saving model and test indices from run {} to {}...'.format(r, out_dir))
            np.save(os.path.join(out_dir, 'test_indices_{}'.format(r)), ho_ds.get_set_indices())
            torch.save(best_run_model.state_dict(), os.path.join(out_dir, 'best_model_{}.pt'.format(r)))


    if args.save_training: accs_df.to_csv(os.path.join(out_dir, 'run_accuracies.csv'))

    print('Results summary for {} runs:\n {}'.format(args.nn_reruns, accs_df.describe()))
            '''

        # Flush pending events and release the writer before the next rerun
        # opens a new one on the same log directory.
        if writer is not None:
            writer.close()

    print('Done!')


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
