import os
import sys
import umap
import h5py
import scipy
import torch
import numpy as np
import forestci as fci
import statsmodels.api as sm
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.decomposition import PCA
from statsmodels.discrete.discrete_model import NegativeBinomial, Poisson

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../region_model/trainers'))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../sequence_model'))

from data_loader import *
#from gp_trainer import *
import nb_model

# Small offset added to feature values before np.log() in the regression helpers
# below, so that zero-valued features do not produce -inf.
epsilon = 1e-5
# Per-label sample counts; run_binomial_regression multiplies the entry for the
# label by the window size to rescale mutation counts.
# NOTE(review): presumably the number of donors in each PCAWG cohort - confirm.
samp_num_dict = {'PCAWG_Eso-AdenoCa_SNV': 98, 'PCAWG_Stomach-AdenoCA_SNV': 37, 'PCAWG_Skin-Melanoma_SNV': 70}


def get_datasets(args, kfold_data=None, k=None):
    """Return the ``(train, test)`` datasets for one run.

    Args:
        args: run configuration. ``label_ids`` is always read; ``data_file``,
            ``track_file``, ``split_method`` and ``train_ratio`` are read only
            when no folds are supplied.
        kfold_data: optional sequence of fold dictionaries, each holding
            ``'x'``, ``'y'`` (a mapping from label id to array) and ``'locs'``.
        k: index of the fold held out for testing (used with ``kfold_data``).

    Returns:
        A ``(train_ds, test_ds)`` tuple. With folds, ``test_ds`` is fold ``k``
        itself and ``train_ds`` concatenates every other fold in fold order.
    """
    if kfold_data is not None:
        held_out = kfold_data[k]
        # Indices of every fold except the held-out one, in ascending order.
        remaining = np.delete(np.arange(len(kfold_data)), k)
        merged = {'x': np.concatenate([kfold_data[idx]['x'] for idx in remaining])}
        merged['y'] = {
            label: np.concatenate([kfold_data[idx]['y'][label] for idx in remaining])
            for label in args.label_ids
        }
        merged['locs'] = np.concatenate([kfold_data[idx]['locs'] for idx in remaining])
        return merged, held_out

    # No folds given: let the project DataLoader perform the split.
    loader = DataLoader(args.data_file, args.label_ids, args.track_file)
    train_ds, test_ds = loader.get_datasets(args.split_method, args.train_ratio)
    return train_ds, test_ds


def get_infinitesimal_jackknife_variance(inbag, forest, n_estimators, X_test):
    """Compute the uncorrected infinitesimal-jackknife variance of a forest.

    Args:
        inbag: in-bag count matrix; assumed to be (n_train, n_trees) so that it
            can be multiplied with the per-tree predictions - TODO confirm.
        forest: iterable of fitted trees, each exposing ``predict``.
        n_estimators: number of trees, used as the normalising constant.
        X_test: samples for which the variance is estimated.

    Returns:
        ``(variance, correction)``: the raw variance per test sample and the
        Monte-Carlo bias term produced by ``bias_correction``. The caller is
        expected to subtract the second from the first.
    """
    # One row per test sample, one column per tree (after the transpose).
    per_tree = np.array([estimator.predict(X_test) for estimator in forest]).T
    # The mean is taken down axis 0, i.e. one mean per tree column.
    # NOTE(review): reference formulations centre each sample across trees
    # instead - confirm that this axis is intended.
    centered = per_tree - per_tree.mean(axis=0)
    covariance_term = np.dot(inbag - 1, centered.T) / n_estimators
    variance = np.sum(covariance_term ** 2, 0)
    correction = bias_correction(variance, inbag, centered, n_estimators)
    return variance, correction


def bias_correction(V_IJ, inbag, pred_centered, n_trees):
    """Return the bias term to subtract from the jackknife variance.

    Args:
        V_IJ: uncorrected variance estimate; accepted for interface
            compatibility but not read.
        inbag: in-bag count matrix; only its first dimension (the number of
            training samples) is used.
        pred_centered: centred per-tree predictions, one row per test sample.
        n_trees: number of trees in the forest.

    Returns:
        Array with one correction value per row of ``pred_centered``.
    """
    train_count = inbag.shape[0]
    # Mean squared deviation of the trees' predictions for each test sample.
    bootstrap_variance = np.square(pred_centered).sum(axis=1) / n_trees
    return train_count * bootstrap_variance / n_trees


def run_gp(device, train_set, test_set, h5_grp, run_num, model_id):
    """Train a GP ``run_num`` times and store each successful run's test results.

    Every run starts with 2000 inducing points. When training raises a
    ``RuntimeError`` the run is retried with 200 fewer inducing points until it
    succeeds or the count reaches zero, in which case nothing is saved for
    that run.

    Args:
        device: device handed to ``GPTrainer``.
        train_set: training data handed to ``GPTrainer``.
        test_set: test data handed to ``GPTrainer``.
        h5_grp: not used by this function; kept for interface compatibility.
        run_num: number of GP runs to perform.
        model_id: identifier used to build the output file name
            ``run_<model_id>.h5``.

    Returns:
        The open ``h5py.File`` holding the saved results. The caller owns the
        handle and is responsible for closing it.

    NOTE(review): ``GPTrainer`` must be in scope; the ``gp_trainer`` import at
    the top of this file is commented out - confirm it is provided elsewhere.
    """
    # Open explicitly in append mode: h5py >= 3.0 defaults to read-only, which
    # can neither create the file nor accept new results.
    h5f = h5py.File('run_{}.h5'.format(model_id), 'a')
    for j in range(run_num):
        print('GP run {}/{}...'.format(j, run_num))
        n_inducing = 2000
        while n_inducing > 0:
            gp_trainer = GPTrainer(device, train_set, test_set, n_inducing=n_inducing)
            try:
                print('Running GP with {} inducing points...'.format(n_inducing))
                gp_test_results, _ = gp_trainer.run()
            except RuntimeError as err:
                print('Run failed with {} inducing points. Encountered run-time error in training: {}'
                      .format(n_inducing, err))
                # Retry with a smaller inducing set.
                n_inducing -= 200
            else:
                # Saving happens outside the try body so that errors raised
                # while writing are not mistaken for training failures.
                gp_trainer.save_results(gp_test_results, None, h5f, str(j))
                break

    return h5f


def run_random_forest(args, l, n_jobs, r, kfold_data=None):
    """Fit a random forest for label ``l`` and estimate predictive uncertainty.

    Args:
        args: run configuration; reads ``estimators_num``, ``estimators_const``
            and ``max_depth`` here, plus whatever ``get_datasets`` needs.
        l: label id selecting the target in the ``'y'`` mappings.
        n_jobs: parallelism passed to ``RandomForestRegressor``.
        r: fold index forwarded to ``get_datasets``.
        kfold_data: optional folds forwarded to ``get_datasets``.

    Returns:
        ``(y_pred, std_pred, feature_importances, sklearn_r2, pearson_r2,
        train_locs, test_locs)``.
    """
    train_ds, test_ds = get_datasets(args, kfold_data, r)
    print('Computing random forest regression for {}...'.format(l))

    # The number of estimators should be at least on the order of sqrt(n) for
    # the confidence estimation to be reliable.
    explicit_count = args.estimators_num > 0
    tree_num = args.estimators_num if explicit_count else int(len(train_ds['x']) / args.estimators_const)
    print('Number of estimators: {}'.format(tree_num))

    x_train, x_test = train_ds['x'], test_ds['x']
    rf = RandomForestRegressor(max_depth=args.max_depth, n_estimators=tree_num, n_jobs=n_jobs)
    rf.fit(x_train, train_ds['y'][l])
    y_pred = rf.predict(x_test)

    inbag = fci.calc_inbag(x_train.shape[0], rf)
    y_err, bias_adjustment = get_infinitesimal_jackknife_variance(inbag, rf, tree_num, x_test)

    # Negative corrected variances become NaN under sqrt and are then mapped
    # to zero by nan_to_num.
    corrected_variance = y_err - bias_adjustment
    std_pred = np.nan_to_num(np.sqrt(corrected_variance))

    y_true = test_ds['y'][l]
    sklearn_r2 = r2_score(y_true, y_pred)
    pearson_r2 = scipy.stats.pearsonr(y_true, y_pred)[0]**2
    print('Random forest {}: sklearn R^2 = {}, pearson R^2 = {}'.format(l, sklearn_r2, pearson_r2))
    zero_ci_count = len(np.where(corrected_variance <= 0)[0])
    print('Number of zero confidence intervals: {}'.format(zero_ci_count))

    return y_pred, std_pred, rf.feature_importances_, sklearn_r2, pearson_r2, train_ds['locs'], test_ds['locs']


def run_negative_binomial(args, l, r, kfold_data=None):
    """Fit a negative binomial regression on 90 PCA components of the log-features.

    When ``args.trinuc_file`` is set, expected mutation counts from
    ``nb_model.expected_mutations_by_context`` are used as the exposure for
    both fitting and prediction.

    Args:
        args: run configuration; reads ``trinuc_file`` here, plus whatever
            ``get_datasets`` needs.
        l: label id selecting the target in the ``'y'`` mappings.
        r: fold index forwarded to ``get_datasets``.
        kfold_data: optional folds forwarded to ``get_datasets``.

    Returns:
        ``(y_pred, std_pred, coefficients, sklearn_r2, pearson_r2, train_locs,
        test_locs)``; the coefficients exclude the first (intercept) and last
        (dispersion) fitted parameters.
    """
    train_ds, test_ds = get_datasets(args, kfold_data, r)
    print('Computing negative binomial regression for {}...'.format(l))

    reducer = PCA(n_components=90)
    train_components = reducer.fit_transform(np.log(train_ds['x'] + epsilon))
    test_components = reducer.transform(np.log(test_ds['x'] + epsilon))
    train_x_nb = sm.add_constant(train_components)
    test_x_nb = sm.add_constant(test_components)
    print(sum(reducer.explained_variance_ratio_))
    y_train = train_ds['y'][l]
    y_test = test_ds['y'][l]

    # Optional exposure arguments; empty when no trinucleotide file is given.
    fit_exposure, predict_exposure = {}, {}
    if args.trinuc_file is not None:
        exp_train, exp_test = nb_model.expected_mutations_by_context(
            train_ds['locs'], test_ds['locs'], args.trinuc_file, N=1, key_prefix=l)
        # The tiny offset keeps the exposure strictly positive.
        fit_exposure['exposure'] = exp_train.values + 1e-16
        predict_exposure['exposure'] = exp_test.values + 1e-16

    nb_fit = NegativeBinomial(y_train, train_x_nb, **fit_exposure).fit(method='bfgs', maxiter=1000)
    y_pred = nb_fit.predict(test_x_nb, **predict_exposure)

    # Variance of the fitted distribution: mu + alpha * mu^2.
    alpha = np.exp(nb_fit.lnalpha)
    std_pred = np.sqrt(y_pred + y_pred**2*alpha)
    sklearn_r2 = r2_score(y_test, y_pred)
    pearson_r2 = scipy.stats.pearsonr(y_test, y_pred)[0]**2
    print('Negative binomial {}: sklearn R^2 = {}, pearson R^2 = {}'.format(l, sklearn_r2, pearson_r2))

    return y_pred, std_pred, nb_fit.params[1:-1], sklearn_r2, pearson_r2, train_ds['locs'], test_ds['locs']


def run_poisson_regression(args, l, r, kfold_data=None):
    """Fit a Poisson regression on 20 PCA components of the log-features.

    When ``args.trinuc_file`` is set, expected mutation counts from
    ``nb_model.expected_mutations_by_context`` are used as the exposure for
    both fitting and prediction.

    Args:
        args: run configuration; reads ``trinuc_file`` here, plus whatever
            ``get_datasets`` needs.
        l: label id selecting the target in the ``'y'`` mappings.
        r: fold index forwarded to ``get_datasets``.
        kfold_data: optional folds forwarded to ``get_datasets``.

    Returns:
        ``(y_pred, std_pred, coefficients, sklearn_r2, pearson_r2, train_locs,
        test_locs)``; the coefficients exclude the first (intercept) fitted
        parameter and ``std_pred`` is the square root of the predicted mean.
    """
    train_ds, test_ds = get_datasets(args, kfold_data, r)
    print('Computing poisson regression for {}...'.format(l))

    pca = PCA(n_components=20)
    train_pca_x = pca.fit_transform(np.log(train_ds['x'] + epsilon))
    test_pca_x = pca.transform(np.log(test_ds['x'] + epsilon))
    train_x_nb = sm.add_constant(train_pca_x)
    test_x_nb = sm.add_constant(test_pca_x)
    print(sum(pca.explained_variance_ratio_))
    y_train = train_ds['y'][l]
    y_test = test_ds['y'][l]

    if args.trinuc_file is not None:
        exp_train, exp_test = nb_model.expected_mutations_by_context(
            train_ds['locs'], test_ds['locs'], args.trinuc_file, N=1, key_prefix=l)
        # The model works with log(exposure), so a zero expected count would
        # yield -inf and a non-finite likelihood. Use the same tiny offset as
        # run_negative_binomial to keep the exposure strictly positive.
        psn = Poisson(y_train, train_x_nb, exposure=exp_train.values + 1e-16)
        psn_fit = psn.fit(method='bfgs', maxiter=1000)
        y_pred = psn_fit.predict(test_x_nb, exposure=exp_test.values + 1e-16)
    else:
        psn = Poisson(y_train, train_x_nb)
        psn_fit = psn.fit(method='bfgs', maxiter=1000)
        y_pred = psn_fit.predict(test_x_nb)

    # For a Poisson distribution the variance equals the mean.
    std_pred = np.sqrt(y_pred)
    sklearn_r2 = r2_score(y_test, y_pred)
    pearson_r2 = scipy.stats.pearsonr(y_test, y_pred)[0]**2
    print('Poisson {}: sklearn R^2 = {}, pearson R^2 = {}'.format(l, sklearn_r2, pearson_r2))

    return y_pred, std_pred, psn_fit.params[1:], sklearn_r2, pearson_r2, train_ds['locs'], test_ds['locs']


def get_window_size(data_path):
    """Infer the window size encoded in a data path.

    The path is split on underscores and each known size is looked up as a
    whole token. Sizes are tried in ascending order, so the smallest matching
    size wins when several are present.

    Args:
        data_path: path string containing the window size as one of its
            underscore-separated tokens.

    Returns:
        The matching window size as an ``int``.

    Raises:
        Exception: if none of the known sizes appears as a token.
    """
    known_sizes = (100, 500, 1000, 2000, 5000, 10000, 25000, 50000, 75000, 100000, 1000000)
    tokens = data_path.split('_')
    for size in known_sizes:
        if str(size) in tokens:
            return size
    raise Exception('Not sure about window scale from data path {}'.format(data_path))


def run_binomial_regression(args, l, r, kfold_data=None):
    """Fit a binomial GLM on 20 PCA components of the log-features.

    Targets are the label counts divided by ``window size * sample count``,
    with the window size parsed from ``args.data_file`` and the sample count
    taken from ``samp_num_dict``. When ``args.trinuc_file`` is set, expected
    mutation counts normalised by their train + test total are passed as the
    GLM offset.

    Args:
        args: run configuration; reads ``data_file`` and ``trinuc_file`` here,
            plus whatever ``get_datasets`` needs.
        l: label id; must be a key of ``samp_num_dict``.
        r: fold index forwarded to ``get_datasets``.
        kfold_data: optional folds forwarded to ``get_datasets``.

    Returns:
        ``(y_pred, std_pred, coefficients, sklearn_r2, pearson_r2, train_locs,
        test_locs)``; the coefficients exclude the first (intercept) fitted
        parameter and the scores are computed on the rescaled targets.
    """
    train_ds, test_ds = get_datasets(args, kfold_data, r)
    print('Computing binomial regression for {}...'.format(l))

    reducer = PCA(n_components=20)
    train_components = reducer.fit_transform(np.log(train_ds['x'] + epsilon))
    test_components = reducer.transform(np.log(test_ds['x'] + epsilon))
    train_x_bn = sm.add_constant(train_components)
    test_x_bn = sm.add_constant(test_components)
    print(sum(reducer.explained_variance_ratio_))

    ws = get_window_size(args.data_file)
    samp_num = samp_num_dict[l]
    y_train = train_ds['y'][l] / (ws * samp_num)
    y_test = test_ds['y'][l] / (ws * samp_num)

    use_offset = args.trinuc_file is not None
    if use_offset:
        exp_train, exp_test = nb_model.expected_mutations_by_context(train_ds['locs'], test_ds['locs'],
                                                                 args.trinuc_file, N=1, key_prefix=l)
        # Normalise the expected counts by their total over both splits.
        exp_denom = sum(exp_train) + sum(exp_test)
        binom_fit = sm.GLM(y_train, train_x_bn, family=sm.families.Binomial(),
                           offset=exp_train/exp_denom).fit(method='bfgs', maxiter=1000)
        y_pred = binom_fit.predict(test_x_bn, offset=exp_test/exp_denom)
    else:
        binom_fit = sm.GLM(y_train, train_x_bn,
                           family=sm.families.Binomial()).fit(method='bfgs', maxiter=1000)
        y_pred = binom_fit.predict(test_x_bn)

    # Standard deviation of a single Bernoulli trial with probability y_pred.
    std_pred = np.sqrt(y_pred * (1 - y_pred))
    sklearn_r2 = r2_score(y_test, y_pred)
    pearson_r2 = scipy.stats.pearsonr(y_test, y_pred)[0]**2
    print('Binomial {}: sklearn R^2 = {}, pearson R^2 = {}'.format(l, sklearn_r2, pearson_r2))

    return y_pred, std_pred, binom_fit.params[1:], sklearn_r2, pearson_r2, train_ds['locs'], test_ds['locs']


def run_linear_regression(args, l, r, kfold_data=None):
    """Fit ordinary least squares on the log-features for label ``l``.

    Args:
        args: run configuration forwarded to ``get_datasets``.
        l: label id selecting the target in the ``'y'`` mappings.
        r: fold index forwarded to ``get_datasets``.
        kfold_data: optional folds forwarded to ``get_datasets``.

    Returns:
        ``(y_pred, std_pred, coefficients, sklearn_r2, pearson_r2, train_locs,
        test_locs)``. ``std_pred`` is constant across samples and equals the
        root-mean-square residual on the test set.
    """
    train_ds, test_ds = get_datasets(args, kfold_data, r)
    print('Computing linear regression for {}...'.format(l))
    lr = LinearRegression()
    lr.fit(np.log(train_ds['x'] + epsilon), train_ds['y'][l])
    y_test = test_ds['y'][l]
    y_pred = lr.predict(np.log(test_ds['x'] + epsilon))

    # Residuals must be squared and averaged before the square root: the plain
    # sum of signed residuals cancels out and turns into NaN when negative.
    residuals = y_test - y_pred
    residual_std = np.sqrt(np.mean(np.square(residuals)))
    std_pred = np.full(len(y_pred), residual_std)

    sklearn_r2 = r2_score(y_test, y_pred)
    pearson_r2 = scipy.stats.pearsonr(y_test, y_pred)[0]**2
    print('Linear regression {}: sklearn R^2 = {}, pearson R^2 = {}'.format(l, sklearn_r2, pearson_r2))

    return y_pred, std_pred, lr.coef_, sklearn_r2, pearson_r2, train_ds['locs'], test_ds['locs']


def run_pca_gp(args, device_str, l, r):
    """Reduce the log-features with PCA and run GP regression on the result.

    Args:
        args: run configuration; reads ``data_file``, ``label_ids``,
            ``track_file``, ``split_method``, ``train_ratio``, ``dims`` and
            ``gp_reruns``.
        device_str: torch device string. For ``'cuda'`` the GPU index is ``r``
            modulo the number of visible CUDA devices.
        l: label id selecting the target in the ``'y'`` mappings.
        r: run index; selects the GPU and is part of the output file id.

    Returns:
        Whatever ``run_gp`` returns (the results file handle).
    """
    data_loader = DataLoader(args.data_file, args.label_ids, args.track_file)
    train_ds, test_ds = data_loader.get_datasets(args.split_method, args.train_ratio)
    print('Computing GP regression over after PCA reduction...')
    if device_str == 'cuda':
        device = torch.device(device_str + ':' + str(r % torch.cuda.device_count()))
    else:
        device = torch.device(device_str)
    pca = PCA(n_components=args.dims)
    train_pca_x = pca.fit_transform(np.log(train_ds['x'] + epsilon))
    test_pca_x = pca.transform(np.log(test_ds['x'] + epsilon))

    train_set = (train_pca_x, train_ds['y'][l], train_ds['locs'])
    test_set = (test_pca_x, test_ds['y'][l], test_ds['locs'])

    # run_gp requires h5_grp, run_num and model_id; passing only the rerun
    # count raises TypeError. run_gp does not read h5_grp, so None is passed.
    # NOTE(review): the model_id naming (and thus the output file name
    # run_<model_id>.h5) is chosen here - confirm the desired convention.
    model_id = 'pca_gp_{}_{}'.format(l, r)
    return run_gp(device, train_set, test_set, h5_grp=None, run_num=args.gp_reruns, model_id=model_id)


def run_umap_gp(args, device_str, l, r):
    """Reduce the log-features with UMAP and run GP regression on the result.

    Args:
        args: run configuration; reads ``data_file``, ``label_ids``,
            ``track_file``, ``split_method``, ``train_ratio``, ``dims`` and
            ``gp_reruns``.
        device_str: torch device string. For ``'cuda'`` the GPU index is ``r``
            modulo the number of visible CUDA devices.
        l: label id selecting the target in the ``'y'`` mappings.
        r: run index; selects the GPU and is part of the output file id.

    Returns:
        Whatever ``run_gp`` returns (the results file handle).
    """
    data_loader = DataLoader(args.data_file, args.label_ids, args.track_file)
    train_ds, test_ds = data_loader.get_datasets(args.split_method, args.train_ratio)
    print('Computing GP regression over after UMAP reduction...')
    if device_str == 'cuda':
        device = torch.device(device_str + ':' + str(r % torch.cuda.device_count()))
    else:
        device = torch.device(device_str)
    umapper = umap.UMAP(n_components=args.dims, n_neighbors=20, metric='euclidean')
    # The mapper is fitted on the training features only; both splits are then
    # projected with transform().
    umapper.fit(np.log(train_ds['x'] + epsilon))
    train_umap_x = umapper.transform(np.log(train_ds['x'] + epsilon))
    test_umap_x = umapper.transform(np.log(test_ds['x'] + epsilon))

    train_set = (train_umap_x, train_ds['y'][l], train_ds['locs'])
    test_set = (test_umap_x, test_ds['y'][l], test_ds['locs'])

    # run_gp requires h5_grp, run_num and model_id; passing only the rerun
    # count raises TypeError. run_gp does not read h5_grp, so None is passed.
    # NOTE(review): the model_id naming (and thus the output file name
    # run_<model_id>.h5) is chosen here - confirm the desired convention.
    model_id = 'umap_gp_{}_{}'.format(l, r)
    return run_gp(device, train_set, test_set, h5_grp=None, run_num=args.gp_reruns, model_id=model_id)
