import h5py
import pickle as pkl
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import scipy.stats
import itertools as it
import multiprocessing as mp

from sklearn import linear_model
from sklearn.metrics import r2_score
from sklearn import preprocessing
from sklearn import decomposition

import simulate_detection

def load_data_cnn(f, cancer, n_runs=10):
    """Load CNN features, targets and run-aggregated GP predictions for one group.

    Reads the ``train`` and ``held-out`` subgroups of ``cancer`` from the HDF5
    file ``f``. It also reads the per-run GP predictions stored under the
    held-out subgroups ``'0'`` .. ``str(n_runs - 1)`` and aggregates them with
    an element-wise median across runs.

    Parameters
    ----------
    f : str or path-like
        Path to the HDF5 prediction file.
    cancer : str
        Top-level HDF5 group to read.
    n_runs : int, optional
        Number of per-run prediction subgroups to aggregate (default 10, the
        value that used to be hard-coded).

    Returns
    -------
    tuple
        ``(y_train_true, train_feat, train_idx,
        y_hld_true, hld_feat, hld_idx,
        gp_mean, gp_std)``. The ``*_idx`` arrays are the ``chr_locs`` datasets.
        ``gp_mean`` / ``gp_std`` are the medians over runs of the per-run
        ``mean`` / ``std`` datasets.

    Raises
    ------
    KeyError
        If an expected group or dataset is missing (raised by h5py).
    """
    # The context manager guarantees the file handle is released even when a
    # missing dataset raises mid-read; the original leaked it in that case.
    with h5py.File(f, 'r') as data_pred:
        group = data_pred[cancer]
        hld = group['held-out']
        train = group['train']

        y_hld_true = hld['y_true'][:]
        y_train_true = train['y_true'][:]

        hld_feat = hld['nn_features'][:]
        train_feat = train['nn_features'][:]

        hld_idx = hld['chr_locs'][:]
        train_idx = train['chr_locs'][:]

        # `[:]` materialises each dataset as an in-memory array, so these
        # remain valid after the file is closed.
        gp_mean_lst = [hld[str(i)]['mean'][:] for i in range(n_runs)]
        gp_std_lst = [hld[str(i)]['std'][:] for i in range(n_runs)]

    # Assumes each per-run dataset is 1-D, so vstack gives (n_runs, n_sites).
    # TODO confirm.
    gp_mean = np.median(np.vstack(gp_mean_lst), axis=0)
    gp_std = np.median(np.vstack(gp_std_lst), axis=0)

    return y_train_true, train_feat, train_idx, \
           y_hld_true, hld_feat, hld_idx, \
           gp_mean, gp_std

def load_data_vector(f, cancer, model, run, n_runs=10):
    """Load train/test data and predictions of one model for one run.

    For the GP-style models ``'pca_gp'``, ``'umap_gp'`` and ``'ae_gp'``, the
    function reads targets, features and locations. It also reads the per-run
    predictions stored under ``'0'`` .. ``str(n_runs - 1)`` and averages them
    (arithmetic mean across runs).

    For any other model, only the train/test locations and the stored
    ``mean`` / ``std`` predictions are read. The targets and features are
    returned as ``None``.

    Parameters
    ----------
    f : str or path-like
        Path to the HDF5 prediction file.
    cancer : str
        Top-level HDF5 group to read.
    model : str
        Model subgroup name.
    run : str
        Run subgroup key (used directly as an HDF5 key).
    n_runs : int, optional
        Number of per-run prediction subgroups averaged for GP-style models
        (default 10, the value that used to be hard-coded).

    Returns
    -------
    tuple
        ``(y_train, train_feat, train_idx,
        y_test, test_feat, test_idx,
        model_mean, model_std)``. ``y_*`` and ``*_feat`` are ``None`` for
        non-GP models.

    Raises
    ------
    KeyError
        If an expected group or dataset is missing (raised by h5py).
    """
    # The context manager releases the file handle even if a lookup raises.
    with h5py.File(f, 'r') as data_pred:
        run_grp = data_pred[cancer][run]

        if model in ('pca_gp', 'umap_gp', 'ae_gp'):
            test = run_grp[model]['test']
            train = run_grp[model]['train']

            y_test = test['y_true'][:]
            y_train = train['y_true'][:]

            test_feat = test['nn_features'][:]
            train_feat = train['nn_features'][:]

            test_idx = test['chr_locs'][:]
            train_idx = train['chr_locs'][:]

            model_mean_lst = [test[str(i)]['mean'][:] for i in range(n_runs)]
            model_std_lst = [test[str(i)]['std'][:] for i in range(n_runs)]

            # Assumes each per-run dataset is 1-D, so vstack gives
            # (n_runs, n_sites). TODO confirm.
            model_mean = np.mean(np.vstack(model_mean_lst), axis=0)
            model_std = np.mean(np.vstack(model_std_lst), axis=0)

        else:
            # Non-GP models store only locations and final predictions.
            y_test = None
            y_train = None

            test_feat = None
            train_feat = None

            test_idx = run_grp['test_locs'][:]
            train_idx = run_grp['train_locs'][:]

            model_mean = run_grp[model]['mean'][:]
            model_std = run_grp[model]['std'][:]

    return y_train, train_feat, train_idx, \
           y_test, test_feat, test_idx, \
           model_mean, model_std

def estimate_empirical_var_cnn(fbase, run, cancer):
    """Compare CNN/GP held-out predictions against a nearest-neighbour estimate.

    The file path is built as ``fbase.format(run)``. The function prints three
    squared Pearson correlations, in order:

    1. held-out truth vs GP mean,
    2. empirical neighbour mean vs GP mean,
    3. empirical neighbour variance vs GP variance.

    Returns
    -------
    tuple
        ``(gp_mean, gp_std, y_hld_true, emean_test, estd_test)``.
    """
    loaded = load_data_cnn(fbase.format(run), cancer)
    y_train, feat_train, _, y_heldout, feat_heldout, _, mean_gp, std_gp = loaded

    std_emp, mean_emp = _estimate_emp_var(feat_heldout, feat_train, y_train)

    def squared_r(a, b):
        return scipy.stats.pearsonr(a, b)[0]**2

    comparisons = (
        (y_heldout, mean_gp),
        (mean_emp, mean_gp),
        (std_emp**2, std_gp**2),
    )
    for lhs, rhs in comparisons:
        print(squared_r(lhs, rhs))

    return mean_gp, std_gp, y_heldout, mean_emp, std_emp

def estimate_empirical_var_vector(f, cancer, model, run):
    """Compare a model's test predictions against a nearest-neighbour estimate.

    The function prints three squared Pearson correlations, in order:

    1. test truth vs model mean,
    2. empirical neighbour mean vs model mean,
    3. empirical neighbour variance vs model variance.

    Returns
    -------
    tuple
        ``(model_mean, model_std, y_test, emean_test, estd_test)``.

    Raises
    ------
    ValueError
        If ``load_data_vector`` returns no features/targets for ``model``,
        which it does for models other than the GP-style ones.
    """
    # load_data_vector returns 8 values (including train/test locations).
    # The original unpacked only 6, which raised on every call.
    (y_train, train_feat, _,
     y_test, test_feat, _,
     model_mean, model_std) = load_data_vector(f, cancer, model, run)

    if y_train is None or train_feat is None or test_feat is None or y_test is None:
        raise ValueError(
            "model {!r} has no stored features/targets, so the empirical "
            "variance cannot be estimated".format(model))

    estd_test, emean_test = _estimate_emp_var(test_feat, train_feat, y_train)

    print(scipy.stats.pearsonr(y_test, model_mean)[0]**2)
    print(scipy.stats.pearsonr(emean_test, model_mean)[0]**2)
    print(scipy.stats.pearsonr(estd_test**2, model_std**2)[0]**2)

    return model_mean, model_std, y_test, emean_test, estd_test

def _estimate_emp_var(test_feat, train_feat, y_train, k=500):
    """Estimate per-test-point std and mean of ``y_train`` over k nearest neighbours.

    For each row of ``test_feat``, the ``k`` rows of ``train_feat`` closest in
    Euclidean distance are selected. The population standard deviation
    (``ddof=0``) and mean of their ``y_train`` values are returned. If there
    are at most ``k`` training points, all of them are used.

    Parameters
    ----------
    test_feat : array-like, shape (n_test, d)
    train_feat : array-like, shape (n_train, d)
    y_train : array-like, shape (n_train,)
    k : int, optional
        Number of neighbours (default 500, the value that used to be
        hard-coded).

    Returns
    -------
    (estd, emean) : tuple of ndarray, each shape (n_test,)
        Note the order: standard deviation first, then mean.

    Raises
    ------
    ValueError
        If ``k`` is less than 1.
    """
    # Import explicitly: `import scipy.stats` alone does not guarantee that
    # the `scipy.spatial` submodule is available as an attribute.
    import scipy.spatial

    if k < 1:
        raise ValueError("k must be >= 1, got {}".format(k))

    y_train = np.asarray(y_train)
    dist = scipy.spatial.distance_matrix(test_feat, train_feat)
    n_train = dist.shape[1]

    if k < n_train:
        # Only the k smallest distances are needed, and mean/std do not
        # depend on neighbour order. A partial selection (O(n)) therefore
        # replaces the full per-row sort (O(n log n)).
        nn_ix = np.argpartition(dist, k - 1, axis=1)[:, :k]
    else:
        # Fewer than k training points: every one of them is a neighbour.
        nn_ix = np.broadcast_to(np.arange(n_train), dist.shape)

    neighbours = y_train[nn_ix]
    return neighbours.std(axis=1), neighbours.mean(axis=1)

def sim_nb_kernel(emp_mean, emp_std, model_mean, model_std,
                  pdriver=0.1, nsamp=100, nsim=1000000, seed=37):
    """Run the negative-binomial detection simulation against a model's predictions.

    A ``simulate_detection.SimDetection`` is built from the empirical mean and
    variance (``emp_std**2``). Its ``simulate_batch_nb_from_model`` is then
    called with the model mean and variance (``model_std**2``).

    The simulation knobs were previously hard-coded; they are now keyword
    parameters whose defaults equal the old values, matching
    ``sim_nb_kernel2``.

    Returns
    -------
    tuple
        ``(tp_rate, pr_miscal)`` as returned by the simulator (presumably the
        true-positive rate and the miscalibration probability; see
        ``simulate_detection``). The simulator's other two outputs are
        discarded.
    """
    sim = simulate_detection.SimDetection(emp_mean, emp_std**2)
    pr_miscal, tp_rate, _, _ = sim.simulate_batch_nb_from_model(
        model_mean, model_std**2,
        pdriver=pdriver, nsamp=nsamp,
        nsim=nsim, seed=seed)

    return tp_rate, pr_miscal

def sim_nb_kernel2(emp_mean, emp_std, model_mean, model_std, pdriver=0.1, nsamp=100,
                   pz=0.01, p_tilde=1, nsim=1000000, seed=37):
    """Run the extended negative-binomial detection simulation.

    A ``simulate_detection.SimDetection`` is built from the empirical mean and
    variance (``emp_std**2``). Its ``simulate_batch_nb_from_model2`` is then
    called with the model mean and variance (``model_std**2``). All simulation
    knobs (``pdriver``, ``nsamp``, ``nsim``, ``p_tilde``, ``pz``) are forwarded
    unchanged.

    ``seed`` was previously hard-coded to 37 and is now a trailing keyword
    with the same default.

    Returns
    -------
    tuple of 10
        ``(tp_true, fn_true, fp_true, f1_true, mcc_true,
        tp_pred, fn_pred, fp_pred, f1_pred, mcc_pred)``. Presumably these are
        detection statistics under the empirical ("true") and model ("pred")
        distributions; see ``simulate_detection``.
    """
    sim = simulate_detection.SimDetection(emp_mean, emp_std**2)
    # Unpack explicitly so a change in the simulator's arity fails loudly.
    (tp_true, fn_true, fp_true, f1_true, mcc_true,
     tp_pred, fn_pred, fp_pred, f1_pred, mcc_pred) = sim.simulate_batch_nb_from_model2(
        model_mean, model_std**2,
        pdriver=pdriver, nsamp=nsamp,
        nsim=nsim, seed=seed, p_tilde=p_tilde, pz=pz)

    return tp_true, fn_true, fp_true, f1_true, mcc_true, tp_pred, fn_pred, fp_pred, f1_pred, mcc_pred

def sim_normal_kernel(emp_mean, emp_std, model_mean, model_std,
                      pdriver=0.1, nsamp=100, nsim=1000000, seed=37):
    """Run the normal-model detection simulation against a model's predictions.

    A ``simulate_detection.SimDetection`` is built from the empirical mean and
    variance (``emp_std**2``). Its ``simulate_batch_normal_from_model`` is then
    called with the model mean and variance (``model_std**2``).

    The simulation knobs were previously hard-coded; they are now keyword
    parameters whose defaults equal the old values, matching
    ``sim_nb_kernel2``.

    Returns
    -------
    tuple
        ``(tp_rate, pr_miscal)`` as returned by the simulator (presumably the
        true-positive rate and the miscalibration probability; see
        ``simulate_detection``). The simulator's other two outputs are
        discarded.
    """
    sim = simulate_detection.SimDetection(emp_mean, emp_std**2)
    pr_miscal, tp_rate, _, _ = sim.simulate_batch_normal_from_model(
        model_mean, model_std**2,
        pdriver=pdriver, nsamp=nsamp,
        nsim=nsim, seed=seed)

    return tp_rate, pr_miscal
