#!/usr/bin/env python
import os
import sys
import h5py
import argparse
import datetime
import torch
import umap
import numpy as np
import pandas as pd
import forestci as fci
import statsmodels.api as sm
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.decomposition import PCA
from statsmodels.discrete.discrete_model import NegativeBinomial

# Make the sibling '../pytorch/trainers' directory (relative to this file) importable
# so that the two project modules below can be resolved.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../pytorch/trainers'))

# Star imports: this script uses DataLoader and GPTrainer without defining them, so they
# are expected to come from these modules — TODO confirm which module provides which.
from data_loader import *
from gp_trainer import *

# Small constant added to the feature matrix before np.log() in main();
# presumably guards against log(0) — confirm the features are non-negative.
epsilon = 1e-5


def get_cmd_arguments():
    """Parse the command-line options of this script from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``label_ids`` (non-empty list of
        str), ``data_file``, ``out_dir``, ``heldout_file``, ``track_file``,
        ``split_method``, ``train_ratio``, ``dims``, ``gp_reruns``, ``run_pca``,
        ``run_umap`` and ``reruns``.

    Raises:
        SystemExit: raised by argparse when the required ``-c/--cancer-id``
            option is missing, is given without any ID, or when a value cannot
            be converted to the declared type.

    Note:
        Every optional flag uses ``nargs='?'`` without a ``const``, so passing
        a flag with no value (e.g. a bare ``-d``) stores ``None`` rather than
        the default.
    """
    ap = argparse.ArgumentParser()

    # Required cancer type argument.
    # nargs='+' (rather than '*') so that a bare `-c` with no IDs is rejected up front
    # instead of producing an empty list that makes the whole run a silent no-op.
    ap.add_argument('-c', '--cancer-id', required=True, nargs='+', action='store', type=str, dest='label_ids',
                    help='A list of the h5 file mutation count dataset IDs (e.g. SNV_skin_melanoma_MELAU_AU)')

    # Path arguments
    ap.add_argument('-d', "--data", required=False, nargs='?', action='store', type=str, dest='data_file',
                    default='/storage/datasets/cancer/data_vecs_PCAWG_1000000_0_0.7.h5', help='Path to h5 data file')
    ap.add_argument('-o', "--out-dir", required=False, nargs='?', action='store', type=str, dest='out_dir',
                    default='/storage/yaari/mutation-density-outputs', help='Path to output directory')
    ap.add_argument('-u', "--held-out", required=False, nargs='?', action='store', type=str, dest='heldout_file',
                    default=None, help='Path to file of held-out samples file')
    ap.add_argument('-t', "--tracks", required=False, nargs='?', action='store', type=str, dest='track_file',
                    default=None, help='Path to predictor tracks selection file')

    # Run type parameters
    ap.add_argument('-s', "--split", required=False, nargs='?', action='store', type=str, dest='split_method',
                    default='random', help='Dataset split method (random/chr)')
    ap.add_argument('-r', "--train-ratio", required=False, nargs='?', action='store', type=float, dest='train_ratio',
                    default=0.8, help='Train set split size ratio')
    ap.add_argument('-m', "--dims", required=False, nargs='?', action='store', type=int, dest='dims',
                    default=16, help='Number of dimensions for PCA and UMAP')
    ap.add_argument('-gr', "--gp_reruns", required=False, nargs='?', action='store', type=int, dest='gp_reruns',
                    default=10, help='Number of GP reinitializations and training runs')
    # The two switches below are integers used as truthy flags (0 disables the step).
    ap.add_argument('-pc', "--run_pca", required=False, nargs='?', action='store', type=int, dest='run_pca',
                    default=1, help='Run GP prediction over PCA reduced vectors')
    ap.add_argument('-um', "--run_umap", required=False, nargs='?', action='store', type=int, dest='run_umap',
                    default=1, help='Run GP prediction over UMAP reduced vectors')
    ap.add_argument('-re', "--reruns", required=False, nargs='?', action='store', type=int, dest='reruns',
                    default=1, help='Number of models retraining runs')

    return ap.parse_args()


def run_gp(device, train_set, test_set, h5_grp, run_num):
    """Train a GP ``run_num`` times and store each successful run's test results.

    Each run starts with 2000 inducing points. When ``GPTrainer.run()`` raises
    a ``RuntimeError`` the run is retried with 200 fewer inducing points, down
    to 200. The first successful attempt is saved through
    ``GPTrainer.save_results`` under the key ``str(j)`` (0-based run index).

    Args:
        device: passed through to ``GPTrainer`` unchanged.
        train_set: passed through to ``GPTrainer`` unchanged.
        test_set: passed through to ``GPTrainer`` unchanged.
        h5_grp: destination handed to ``GPTrainer.save_results``.
        run_num: number of independent GP runs; ``0`` makes this a no-op.

    Returns:
        None. A run for which every inducing-point setting fails is reported
        on stdout and skipped.
    """
    for j in range(run_num):
        # Report 1-based progress so the last run reads "N/N" rather than "N-1/N".
        print('GP run {}/{}...'.format(j + 1, run_num))

        # Same schedule as before: 2000, 1800, ..., 200 inducing points.
        for n_inducing in range(2000, 0, -200):
            # A fresh trainer is built for every attempt; errors raised by the
            # constructor are intentionally not caught here.
            gp_trainer = GPTrainer(device, train_set, test_set, n_inducing=n_inducing)
            try:
                print('Running GP with {} inducing points...'.format(n_inducing))
                # run() returns a pair; only the first element (test results) is stored.
                gp_test_results, _ = gp_trainer.run()
            except RuntimeError as err:
                print('Run failed with {} inducing points. Encountered run-time error in training: {}'
                      .format(n_inducing, err))
                continue

            # Saving happens outside the try so that a failure while writing is not
            # mistaken for a training failure and retried.
            gp_trainer.save_results(gp_test_results, None, h5_grp, str(j))
            break
        else:
            # Every setting failed: make the gap in the output visible instead of silent.
            print('GP run {}/{} failed for all inducing point settings; no results saved.'
                  .format(j + 1, run_num))


def main():
    """Run GP regression over PCA- and/or UMAP-reduced features for each cancer ID.

    Steps:
      1. Parse the command line and create a fresh, timestamped output
         directory ``<out_dir>/vector_predictions/<now>``.
      2. Dump the parsed arguments to ``run_params.txt``.
      3. For every label ID and every rerun, obtain a train/test split from the
         data loader, store the split locations, then (depending on
         ``--run_pca`` / ``--run_umap``) reduce ``log(x + epsilon)`` to
         ``--dims`` dimensions and hand the reduced sets to ``run_gp``.
      4. Write ``predictions_r2.csv`` and close the HDF5 output.

    The HDF5 layout is ``<label>/run_<r>/{train_locs, test_locs, pca_gp, umap_gp}``.
    """
    args = get_cmd_arguments()
    out_dir = os.path.join(args.out_dir, 'vector_predictions', str(datetime.datetime.now()))

    data_loader = DataLoader(args.data_file, args.label_ids, args.track_file)

    print('Saving results under: \'{}\''.format(out_dir))
    os.makedirs(out_dir)
    with open(os.path.join(out_dir, 'run_params.txt'), 'w') as f:
        # Plain loop: the writes are side effects, not values to collect.
        for key, value in vars(args).items():
            f.write('{}: {}\n'.format(key, value))

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # NOTE(review): nothing is ever added to acc_df in this function, so the CSV
    # written below is empty; it is kept so the set of output files is unchanged.
    acc_df = pd.DataFrame()

    # Context manager guarantees the HDF5 file is flushed and closed even when a
    # run raises part-way through, so results already written are not lost.
    with h5py.File(os.path.join(out_dir, 'gp_models_output.h5'), 'w') as out_h5f:
        for label_id in args.label_ids:
            lbl_grp = out_h5f.create_group(label_id)
            for r in range(args.reruns):
                print('Run {}/{} over cancer type \'{}\'...'.format(r + 1, args.reruns, label_id))
                run_grp = lbl_grp.create_group('run_{}'.format(r))
                train_ds, test_ds = data_loader.get_datasets(args.split_method, args.train_ratio)
                run_grp.create_dataset('train_locs', data=train_ds['locs'])
                run_grp.create_dataset('test_locs', data=test_ds['locs'])

                if args.run_pca:
                    print('Computing GP regression after PCA reduction...')

                    # The projection is fitted on the training split only and then
                    # applied to the test split.
                    pca = PCA(n_components=args.dims)
                    train_pca_x = pca.fit_transform(np.log(train_ds['x'] + epsilon))
                    test_pca_x = pca.transform(np.log(test_ds['x'] + epsilon))

                    train_set = (train_pca_x, train_ds['y'][label_id], train_ds['locs'])
                    test_set = (test_pca_x, test_ds['y'][label_id], test_ds['locs'])
                    pca_grp = run_grp.create_group('pca_gp')
                    run_gp(device, train_set, test_set, pca_grp, args.gp_reruns)

                if args.run_umap:
                    print('Computing GP regression after UMAP reduction...')

                    # Log-transform the training features once and reuse them for
                    # both fit() and transform().
                    train_log_x = np.log(train_ds['x'] + epsilon)
                    test_log_x = np.log(test_ds['x'] + epsilon)

                    umapper = umap.UMAP(n_components=args.dims, n_neighbors=20, metric='euclidean')
                    # fit() followed by transform() is kept as-is (not fit_transform())
                    # so the training embedding is obtained the same way as before.
                    umapper.fit(train_log_x)
                    train_umap_x = umapper.transform(train_log_x)
                    test_umap_x = umapper.transform(test_log_x)

                    train_set = (train_umap_x, train_ds['y'][label_id], train_ds['locs'])
                    test_set = (test_umap_x, test_ds['y'][label_id], test_ds['locs'])
                    umap_grp = run_grp.create_group('umap_gp')
                    run_gp(device, train_set, test_set, umap_grp, args.gp_reruns)

            print()

        acc_df.to_csv(os.path.join(out_dir, 'predictions_r2.csv'))

    print('Done!')


# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
