#!/usr/bin/env python
import os
import sys
import h5py
import argparse
import datetime
import torch
import pandas as pd
import numpy as np

from functools import partial
from multiprocessing import Pool, cpu_count
from joblib import Parallel, delayed

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../pytorch/trainers'))

from data_loader import *
from gp_trainer import *
from vec_models import *

# Small numerical tolerance constant.
# NOTE(review): not referenced by this script's own code; confirm nothing else
# imports it from here before removing it.
epsilon = 1e-5


def get_cmd_arguments(argv=None):
    """Parse the command-line options of the vector-model runner.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options (see ``dest`` names below).
    """
    ap = argparse.ArgumentParser()

    # Required cancer type argument.
    # nargs='+' (not '*'): an empty label/model list can never produce a useful run.
    ap.add_argument('-c', '--cancer-id', required=True, nargs='+', action='store', type=str, dest='label_ids',
                    help='A list of the h5 file mutation count dataset IDs (e.g. SNV_skin_melanoma_MELAU_AU)')
    # Keep this help text in sync with the implemented model IDs checked in main().
    ap.add_argument('-m', '--models', required=True, nargs='+', action='store', type=str, dest='models',
                    help='A list of models to run (Random Forest = rf, Negative Binomial = nb, Binomial = bn, '
                         'Linear regression = lr, Poisson regression = po)')

    # Path arguments
    ap.add_argument('-d', "--data", required=False, nargs='?', action='store', type=str, dest='data_file',
                    default='/storage/datasets/cancer/data_vecs_PCAWG_1000000_0_0.7.h5', help='Path to h5 data file')
    ap.add_argument('-o', "--out-dir", required=False, nargs='?', action='store', type=str, dest='out_dir',
                    default='/storage/yaari/mutation-density-outputs', help='Path to output directory')
    ap.add_argument('-u', "--held-out", required=False, nargs='?', action='store', type=str, dest='heldout_file',
                    default=None, help='Path to file of held-out samples file')
    ap.add_argument('-t', "--tracks", required=False, nargs='?', action='store', type=str, dest='track_file',
                    default=None, help='Path to predictor tracks selection file')
    ap.add_argument('-tn', "--trinuc", required=False, nargs='?', action='store', type=str, dest='trinuc_file',
                    default=None, help='Path to trinucleotide count file')

    # Run type parameters
    ap.add_argument('-s', "--split", required=False, nargs='?', action='store', type=str, dest='split_method',
                    default='random', help='Dataset split method (random/chr)')
    ap.add_argument('-k', "--k-fold", required=False, nargs='?', action='store', type=int, dest='k',
                    default=5, help='Number of folds')
    ap.add_argument('-ec', "--estimators", required=False, nargs='?', action='store', type=int, dest='estimators_const',
                    default=20, help='Number of RF estimators constant (number of trees will be this divided by number of samples)')
    ap.add_argument('-dp', "--depth", required=False, nargs='?', action='store', type=int, dest='max_depth',
                    default=50, help='RF maximum depth')
    ap.add_argument('-di', "--dims", required=False, nargs='?', action='store', type=int, dest='dims',
                    default=16, help='Number of dimensions for PCA and UMAP')
    ap.add_argument('-gr', "--gp_reruns", required=False, nargs='?', action='store', type=int, dest='gp_reruns',
                    default=10, help='Number of GP reinitializations and training runs')
    ap.add_argument('-re', "--reruns", required=False, nargs='?', action='store', type=int, dest='reruns',
                    default=1, help='Number of models retraining runs')

    return ap.parse_args(argv)


def save_to_h5_group(parent_grp, model_id, pred_out):
    """Write every run of one model under ``parent_grp/<model_id>/<run index>``.

    Each run entry in ``pred_out`` is read positionally:
    0 -> ``mean``, 1 -> ``std``, 2 -> ``params`` (datasets),
    3 -> ``sklearn_r2``, 4 -> ``pearson_r2`` (attributes of the run group),
    5 -> ``train_locs``, 6 -> ``test_locs`` (datasets).
    """
    model_grp = parent_grp.create_group(model_id)
    for run_idx, run_out in enumerate(pred_out):
        run_grp = model_grp.create_group(str(run_idx))
        for ds_name, pos in (('mean', 0), ('std', 1), ('params', 2)):
            run_grp.create_dataset(ds_name, data=run_out[pos])
        run_grp.attrs['sklearn_r2'] = run_out[3]
        run_grp.attrs['pearson_r2'] = run_out[4]
        for ds_name, pos in (('train_locs', 5), ('test_locs', 6)):
            run_grp.create_dataset(ds_name, data=run_out[pos])


def save_to_df(acc_df, pred_out, model_id, label):
    """Record each run's R^2 scores into ``acc_df`` in place.

    Row ``r`` receives run ``r``'s sklearn R^2 (element 3) and Pearson R^2
    (element 4) in the columns ``<label>_<model_id>_sklearn_r2`` and
    ``<label>_<model_id>_pearson_r2``. ``.loc`` assignment creates any rows
    or columns that do not exist yet.
    """
    sklearn_col = '{}_{}_sklearn_r2'.format(label, model_id)
    pearson_col = '{}_{}_pearson_r2'.format(label, model_id)
    for row, run_out in enumerate(pred_out):
        acc_df.loc[row, sklearn_col] = run_out[3]
        acc_df.loc[row, pearson_col] = run_out[4]


def main():
    """Train and evaluate the requested vector models on every k-fold split.

    Steps:
      1. Parse the CLI options and validate the requested model IDs.
      2. Create a timestamped output directory under
         ``<out_dir>/vector_predictions`` and record the run parameters in
         ``run_params.txt``.
      3. For each label ID and each fold, run every requested model and
         collect its output.
      4. Write the outputs to ``vector_models_output.h5`` (one group per
         label, one sub-group per model) and the per-fold R^2 scores to
         ``predictions_r2.csv``.

    Raises:
        ValueError: if a requested model ID is not implemented.
    """
    args = get_cmd_arguments()
    imp_models = ['rf', 'nb', 'bn', 'lr', 'po']  # 'pc' and 'um' are not wired up yet
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not all(m in imp_models for m in args.models):
        raise ValueError('Not all models in {} are implemented'.format(args.models))
    out_dir = os.path.join(args.out_dir, 'vector_predictions', str(datetime.datetime.now()))

    print('Saving results under: \'{}\''.format(out_dir))
    os.makedirs(out_dir)
    with open(os.path.join(out_dir, 'run_params.txt'), 'w') as f:
        for key, value in vars(args).items():
            f.write('{}: {}\n'.format(key, value))

    # Split the available cores across the reruns, but never hand RF zero
    # workers when reruns exceeds the core count.
    rf_cores = max(1, cpu_count() // args.reruns)

    data_loader = DataLoader(args.data_file, args.label_ids, args.track_file)
    kfold_ds = data_loader.get_kfold_datasets(args.split_method, args.k)

    # Model runners in their fixed execution order; each maps (label, fold)
    # to the run output later passed to save_to_h5_group / save_to_df.
    runners = {
        'rf': lambda lbl, fold: run_random_forest(args, lbl, rf_cores, fold, kfold_ds),
        'nb': lambda lbl, fold: run_negative_binomial(args, lbl, fold, kfold_ds),
        'bn': lambda lbl, fold: run_binomial_regression(args, lbl, fold, kfold_ds),
        'lr': lambda lbl, fold: run_linear_regression(args, lbl, fold, kfold_ds),
        'po': lambda lbl, fold: run_poisson_regression(args, lbl, fold, kfold_ds),
    }
    selected = [(m_id, run) for m_id, run in runners.items() if m_id in args.models]

    acc_df = pd.DataFrame()
    # The context manager closes the HDF5 file even if a model run raises.
    with h5py.File(os.path.join(out_dir, 'vector_models_output.h5'), 'w') as out_h5f:
        for label in args.label_ids:
            lbl_grp = out_h5f.create_group(label)
            out_dict = {m: [] for m in args.models}
            for fold in range(args.k):
                for m_id, run in selected:
                    # Each model's output goes into its own list (L121 previously
                    # stored the RF output under 'lr').
                    out_dict[m_id].append(run(label, fold))

            for m_id, model_out in out_dict.items():
                save_to_h5_group(lbl_grp, m_id, model_out)
                save_to_df(acc_df, model_out, m_id, label)

    acc_df.to_csv(os.path.join(out_dir, 'predictions_r2.csv'))

    print(acc_df.describe())
    print('Done!')


if __name__ == '__main__':
    main()
