#!/usr/bin/env python
import os
import sys
import h5py
import argparse
import datetime
import torch
import pandas as pd
import numpy as np

from functools import partial
from multiprocessing import Pool, cpu_count
from joblib import Parallel, delayed
from datetime import datetime
# Script start time; main() prints the total elapsed time relative to it.
# Note: `datetime` here is the class, because `from datetime import datetime`
# above rebinds the name that `import datetime` had bound to the module.
startTime = datetime.now()

# Make ../pytorch/trainers (resolved relative to this file, not the CWD)
# importable; the star imports of data_loader and vec_models below are
# presumably resolved from there.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../pytorch/trainers'))

from data_loader import *
#from gp_trainer import *
from vec_models import *

# Small numerical tolerance constant.
# NOTE(review): not referenced anywhere in this script -- confirm whether it is
# still needed or can be removed.
epsilon = 1e-5


def get_cmd_arguments():
    ap = argparse.ArgumentParser()

    # Required cancer type argument
    ap.add_argument('-c', '--cancer-id', required=True, nargs='*', action='store', type=str, dest='label_ids',
                    help='A list of the h5 file mutation count dataset IDs (e.g. SNV_skin_melanoma_MELAU_AU)')
    ap.add_argument('-m', '--models', required=True, nargs='*', action='store', type=str, dest='models',
                    help='A list of models to run (Random Forest = rf, Negative Binomial = nb, Binomial = bn, \
                    Linear regression = lr, PCA+GP = pc, UMAP+GP = um)')

    # Path arguments
    ap.add_argument('-d', "--data", required=False, nargs='?', action='store', type=str, dest='data_file',
                    default='/storage/datasets/cancer/data_vecs_PCAWG_1000000_0_0.7.h5', help='Path to h5 data file')
    ap.add_argument('-o', "--out-dir", required=False, nargs='?', action='store', type=str, dest='out_dir',
                    default='/storage/yaari/mutation-density-outputs', help='Path to output directory')
    ap.add_argument('-u', "--held-out", required=False, nargs='?', action='store', type=str, dest='heldout_file',
                    default=None, help='Path to file of held-out samples file')
    ap.add_argument('-t', "--tracks", required=False, nargs='?', action='store', type=str, dest='track_file',
                    default=None, help='Path to predictor tracks selection file')
    ap.add_argument('-tn', "--trinuc", required=False, nargs='?', action='store', type=str, dest='trinuc_file',
                    default=None, help='Path to trinucleotide count file')

    # Run type parameters
    ap.add_argument('-s', "--split", required=False, nargs='?', action='store', type=str, dest='split_method',
                    default='random', help='Dataset split method (random/chr)')
    ap.add_argument('-r', "--train-ratio", required=False, nargs='?', action='store', type=float, dest='train_ratio',
                    default=0.8, help='Train set split size ratio')
    ap.add_argument('-ec', "--estimators_const", required=False, nargs='?', action='store', type=int, dest='estimators_const',
                    default=20, help='Number of RF estimators constant (number of trees will be this divided by number of samples)')
    ap.add_argument('-en', "--estimators_number", required=False, nargs='?', action='store', type=int, dest='estimators_num',
                    default=-1, help='Number of RF estimators (if set estimator-const will be ignored)')
    ap.add_argument('-dp', "--depth", required=False, nargs='?', action='store', type=int, dest='max_depth',
                    default=50, help='RF maximum depth')
    ap.add_argument('-di', "--dims", required=False, nargs='?', action='store', type=int, dest='dims',
                    default=16, help='Number of dimensions for PCA and UMAP')
    ap.add_argument('-gr', "--gp_reruns", required=False, nargs='?', action='store', type=int, dest='gp_reruns',
                    default=10, help='Number of GP reinitializations and training runs')
    ap.add_argument('-re', "--reruns", required=False, nargs='?', action='store', type=int, dest='reruns',
                    default=1, help='Number of models retraining runs')

    return ap.parse_args()


def save_to_h5_group(parent_grp, model_id, pred_out):
    """Store every rerun of one model under ``parent_grp/<model_id>/<run index>``.

    Each entry of ``pred_out`` is indexed positionally:
    0 -> 'mean', 1 -> 'std', 2 -> 'params' (datasets),
    3 -> 'sklearn_r2', 4 -> 'pearson_r2' (group attributes),
    5 -> 'train_locs', 6 -> 'test_locs' (datasets).
    """
    model_grp = parent_grp.create_group(model_id)
    for run_idx, run_out in enumerate(pred_out):
        run_grp = model_grp.create_group(str(run_idx))
        for ds_name, pos in (('mean', 0), ('std', 1), ('params', 2)):
            run_grp.create_dataset(ds_name, data=run_out[pos])
        # The two R^2 scores are scalars, so they live as attributes.
        run_grp.attrs['sklearn_r2'] = run_out[3]
        run_grp.attrs['pearson_r2'] = run_out[4]
        for ds_name, pos in (('train_locs', 5), ('test_locs', 6)):
            run_grp.create_dataset(ds_name, data=run_out[pos])


def save_to_df(acc_df, pred_out, model_id, label):
    """Record the R^2 scores of each rerun into ``acc_df`` in place.

    Row ``i`` receives the scores of ``pred_out[i]``: element 3 goes to column
    ``<label>_<model_id>_sklearn_r2`` and element 4 to
    ``<label>_<model_id>_pearson_r2``. Missing rows/columns are created by
    ``.loc`` assignment.
    """
    col_prefix = '{}_{}'.format(label, model_id)
    sklearn_col = col_prefix + '_sklearn_r2'
    pearson_col = col_prefix + '_pearson_r2'
    for run_idx, run_out in enumerate(pred_out):
        acc_df.loc[run_idx, sklearn_col] = run_out[3]
        acc_df.loc[run_idx, pearson_col] = run_out[4]


def main():
    """Train the requested vector models for every requested label and save results.

    Reads the CLI arguments (see get_cmd_arguments) and creates the output
    directory ``<out_dir>/vector_predictions/<timestamp>``, into which it writes:

    * ``run_params.txt``: every parsed argument as ``name: value`` lines;
    * ``vector_models_output.h5``: one group per label ID, with one sub-group
      per model written by save_to_h5_group;
    * ``predictions_r2.csv``: the R^2 table filled by save_to_df.

    For each label, every requested model is run ``args.reruns`` times through
    a process pool of ``args.reruns`` workers. Models run in the fixed order
    rf, nb, bn, lr, po, regardless of the order given on the command line.

    Returns:
        pandas.DataFrame: the R^2 table (one row per rerun index).

    Raises:
        ValueError: if ``--reruns`` is below 1 or a requested model code is
            not implemented.
    """
    args = get_cmd_arguments()

    # Validate before creating any output, so bad input leaves no stray
    # directories or half-written h5 files behind.
    if args.reruns < 1:
        raise ValueError('--reruns must be at least 1, got {}'.format(args.reruns))

    num_cores = cpu_count()
    # Cores are split between the concurrent reruns; clamp to 1 so that
    # requesting more reruns than cores never hands the random forest 0 jobs.
    rf_jobs = max(1, num_cores // args.reruns)

    # Model code -> (output identifier, runner, extra positional args passed
    # after (args, label)). Tuple order fixes the execution order. The run_*
    # functions come from the star imports at the top of the file.
    # PCA+GP ('pc') and UMAP+GP ('um') runs are currently disabled.
    model_specs = (
        ('rf', 'random_forest', run_random_forest, (rf_jobs,)),
        ('nb', 'negative_binomial', run_negative_binomial, ()),
        ('bn', 'binomial', run_binomial_regression, ()),
        ('lr', 'linear_regression', run_linear_regression, ()),
        ('po', 'poisson', run_poisson_regression, ()),
    )
    implemented = [code for code, _, _, _ in model_specs]
    unsupported = [m for m in args.models if m not in implemented]
    if unsupported:
        raise ValueError('Not all models in {} are implemented (unsupported: {})'.format(
            args.models, unsupported))

    out_dir = os.path.join(args.out_dir, 'vector_predictions', str(datetime.now()))
    print('Saving results under: \'{}\''.format(out_dir))
    os.makedirs(out_dir)
    with open(os.path.join(out_dir, 'run_params.txt'), 'w') as f:
        for key, value in vars(args).items():
            f.write('{}: {}\n'.format(key, value))

    acc_df = pd.DataFrame()
    # Context managers ensure the HDF5 file and each worker pool are released
    # even when a model run raises.
    with h5py.File(os.path.join(out_dir, 'vector_models_output.h5'), 'w') as out_h5f:
        for label in args.label_ids:
            lbl_grp = out_h5f.create_group(label)
            # A fresh pool per label (one worker per rerun), shut down after
            # use instead of leaking its worker processes.
            with Pool(args.reruns) as pool:
                for code, model_id, runner, extra_args in model_specs:
                    if code not in args.models:
                        continue
                    func = partial(runner, args, label, *extra_args)
                    model_out = pool.map(func, range(args.reruns))
                    save_to_h5_group(lbl_grp, model_id, model_out)
                    save_to_df(acc_df, model_out, model_id, label)
                # Graceful shutdown; the with-exit terminate() is then a no-op.
                pool.close()
                pool.join()

    acc_df.to_csv(os.path.join(out_dir, 'predictions_r2.csv'))
    print('Done!')
    print('Time elapsed: {}'.format(datetime.now() - startTime))

    return acc_df


# Run the training pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
