import pandas as pd
import numpy as np
import scipy.stats
import h5py
import seaborn as sns
import matplotlib.pyplot as plt

def load_ensemble(f, cancer=None, split='test'):
    """Load the median ensemble of all stored runs from an HDF5 results file.

    Two file layouts are tried in order:

    * current: ``<split>/<run>/mean`` and ``<split>/<run>/std`` for every
      all-digit ``<run>`` key, with ``y_true`` / ``chr_locs`` alongside;
    * legacy (used when the current layout raises ``KeyError``): flat
      ``gp_mean_XX`` / ``gp_std_XX`` datasets numbered from 01, with
      ``true`` / ``idxs`` alongside.

    Args:
        f: Path of the HDF5 file to read.
        cancer: Optional name of a top-level group to read from; when falsy
            the file root is used.
        split: Name of the group holding the predictions.

    Returns:
        Tuple ``(train_idx, y_true, idx, gp_mean, gp_std)``. ``y_true``,
        ``gp_mean`` and ``gp_std`` are column vectors; the latter two are the
        element-wise medians over the stacked runs.
    """
    # The context manager closes the file even when a read fails.
    with h5py.File(f, 'r') as data_pred:
        dset = data_pred[cancer] if cancer else data_pred

        try:
            # Every all-digit key is a run. Reading the keys directly handles
            # ten or more runs as well as gaps in the numbering.
            run_keys = [key for key in dset[split].keys() if key.isdigit()]
            train_idx = dset['train']['chr_locs'][:]
            y_true = dset[split]['y_true'][:].reshape(-1, 1)
            idx = dset[split]['chr_locs'][:]
            gp_mean_lst = [dset[split][key]['mean'][:] for key in run_keys]
            gp_std_lst = [dset[split][key]['std'][:] for key in run_keys]

        except KeyError:
            # Legacy layout: runs are numbered 01..reruns, so the range must
            # end at reruns + 1 for every stored run to enter the ensemble.
            reruns = len([key for key in dset[split].keys() if key.startswith('gp_mean')])
            run_ids = range(1, reruns + 1)
            train_idx = dset['train']['idxs'][:]
            y_true = dset[split]['true'][0, :].reshape(-1, 1)
            idx = dset[split]['idxs'][:]
            gp_mean_lst = [dset[split]['gp_mean_{:02d}'.format(run)][:] for run in run_ids]
            gp_std_lst = [dset[split]['gp_std_{:02d}'.format(run)][:] for run in run_ids]

    # One row per run; the median across rows is the ensemble value.
    gp_mean_nd = np.vstack(gp_mean_lst)
    gp_mean = np.median(gp_mean_nd, axis=0).reshape(-1, 1)

    gp_std_nd = np.vstack(gp_std_lst)
    gp_std = np.median(gp_std_nd, axis=0).reshape(-1, 1)

    return train_idx, y_true, idx, gp_mean, gp_std

def load_run(f, run, cancer=None, split='test'):
    """Load the predictions of a single run from an HDF5 results file.

    The current layout (``<split>/<run>/mean`` and ``<split>/<run>/std`` with
    ``y_true`` / ``chr_locs``) is tried first; when it raises ``KeyError`` the
    legacy layout (``gp_mean_XX`` / ``gp_std_XX`` with ``true`` / ``idxs``)
    is read instead.

    Args:
        f: Path of the HDF5 file to read.
        run: Run identifier. It is formatted with ``'{}'`` for the current
            layout and with ``'{:02d}'`` for the legacy layout, so the legacy
            layout needs an integer.
        cancer: Optional name of a top-level group to read from; when falsy
            the file root is used.
        split: Name of the group holding the predictions.

    Returns:
        Tuple ``(train_idx, test_Y, test_idx, test_Yhat, test_std)``;
        ``test_Y``, ``test_Yhat`` and ``test_std`` are column vectors.
    """
    # The context manager closes the file even when a read fails.
    with h5py.File(f, 'r') as hf:
        dset = hf[cancer] if cancer else hf

        try:
            train_idx = dset['train']['chr_locs'][:]
            test_Y = dset[split]['y_true'][:].reshape(-1, 1)
            test_idx = dset[split]['chr_locs'][:]
            run_grp = dset[split]['{}'.format(run)]
            test_Yhat = run_grp['mean'][:].reshape(-1, 1)
            test_std = run_grp['std'][:].reshape(-1, 1)
        except KeyError:
            # A missing key means the file uses the legacy flat layout.
            train_idx = dset['train']['idxs'][:]
            test_Y = dset[split]['true'][0, :].reshape(-1, 1)
            test_idx = dset[split]['idxs'][:]
            test_Yhat = dset[split]['gp_mean_{:02d}'.format(run)][:].reshape(-1, 1)
            test_std = dset[split]['gp_std_{:02d}'.format(run)][:].reshape(-1, 1)

    return train_idx, test_Y, test_idx, test_Yhat, test_std

def load_fold(f, cancer=None, run=None, split='test', reruns=10):
    """Load one fold's predictions as a DataFrame.

    Args:
        f: Path of the HDF5 results file.
        cancer: Optional top-level group name, forwarded to the loaders.
        run: Which predictions to load. ``None`` selects a run with
            ``pick_gp_by_calibration``; the string ``'ensemble'`` loads the
            ensemble via ``load_ensemble``; anything else is passed to
            ``load_run`` as the run identifier.
        split: Name of the group holding the predictions.
        reruns: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        DataFrame with columns ``CHROM, START, END, Y_TRUE, Y_PRED, STD``.
        The first three columns come from the index array returned by the
        loader, which therefore has to have three columns.
    """
    # Identity test: run 0 is a valid choice and must not trigger selection.
    if run is None:
        run = pick_gp_by_calibration(f, cancer=cancer, dataset=split)

    # Only compare against the keyword when run is a string, so an integer
    # (possibly numpy) run id is never compared with a str.
    if isinstance(run, str) and run == 'ensemble':
        loaded = load_ensemble(f, cancer=cancer, split=split)
    else:
        loaded = load_run(f, run, cancer=cancer, split=split)

    # The training index is returned by both loaders but is not needed here.
    _, test_Y, test_idx, test_Yhat, test_std = loaded

    vals = np.hstack([test_idx, test_Y, test_Yhat, test_std])
    df = pd.DataFrame(vals, columns=['CHROM', 'START', 'END', 'Y_TRUE', 'Y_PRED', 'STD'])

    return df

def plot_qq_log(pvals, label='', ax=None, rasterized=False):
    """Draw a QQ plot of p-values on a -log10 scale.

    The sorted p-values are plotted against the quantiles ``i / n``
    (``i = 1..n``), both transformed with ``-log10``, together with the
    red identity line.

    Args:
        pvals: Array-like of p-values.
        label: Legend label for the points; a legend is drawn only when the
            label is non-empty.
        ax: Axes to draw on. A new figure and axes are created when ``None``.
        rasterized: Forwarded to ``ax.plot`` for the points.

    Returns:
        The axes that were drawn on, so a figure created here stays reachable.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    n = len(pvals)
    exp = -np.log10(np.arange(1, n + 1) / n)
    pvals_log10_sort = -np.log10(np.sort(pvals))

    ax.plot(exp, pvals_log10_sort, '.', label=label, rasterized=rasterized)
    ax.plot(exp, exp, 'r-')

    if label:
        ax.legend()

    return ax

def plot_qq(pvals, label='', ax=None, rasterized=False):
    """Draw a QQ plot of p-values on a linear scale.

    The sorted p-values are plotted against the quantiles ``i / n``
    (``i = 1..n``), together with the red identity line.

    Args:
        pvals: Array-like of p-values.
        label: Legend label for the points; a legend is drawn only when the
            label is non-empty.
        ax: Axes to draw on. A new figure and axes are created when ``None``.
        rasterized: Forwarded to ``ax.plot`` for the points.

    Returns:
        The axes that were drawn on, so a figure created here stays reachable.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    n = len(pvals)
    exp = np.arange(1, n + 1) / n
    pvals_sort = np.sort(pvals)

    ax.plot(exp, pvals_sort, '.', label=label, rasterized=rasterized)
    ax.plot(exp, exp, 'r-')

    if label:
        ax.legend()

    return ax

def calc_pvals(test_Y, gp_mean, gp_std, cutoff=1e9, onesided=True):
    """Convert residuals into p-values under a standard normal.

    The residual ``test_Y - gp_mean`` is divided by ``gp_std`` and its
    absolute value is looked up in the standard normal distribution.

    Args:
        test_Y: Observed values.
        gp_mean: Predicted means, broadcastable against ``test_Y``.
        gp_std: Predicted standard deviations, broadcastable likewise.
        cutoff: Accepted but not used by the computation.
        onesided: When true, return the upper-tail probability of the
            absolute z-score. When false, add the matching lower-tail
            probability, which gives the two-sided p-value.

    Returns:
        P-values with the broadcast shape of the inputs.
    """
    z_abs = abs((test_Y - gp_mean) / gp_std)
    upper_tail = scipy.stats.norm.sf(z_abs)

    if onesided:
        return upper_tail

    lower_tail = scipy.stats.norm.cdf(-z_abs)
    return upper_tail + lower_tail

def calibration_score_by_pvals(pvals):
    """Score how far empirical p-value tail fractions are from nominal levels.

    For each level ``a`` in ``(0.05, 0.01, 0.001, 0.0001)`` the fraction of
    p-values strictly below ``a`` is compared with ``a``; the score is the
    sum of the squared differences, so lower is better.

    Args:
        pvals: Array-like of p-values of any shape.

    Returns:
        The sum of squared differences as a float.

    Raises:
        ValueError: If ``pvals`` is empty.
    """
    pvals = np.asarray(pvals)
    # Divide by the total element count, not the length of the first axis,
    # so a (1, n) or (n, 1) array scores the same as the flat (n,) array.
    n_total = pvals.size
    if n_total == 0:
        raise ValueError('calibration score is undefined for an empty set of p-values')

    alpha = (0.05, 0.01, 0.001, 0.0001)
    alpha_emp = [int(np.count_nonzero(pvals < a)) / n_total for a in alpha]

    return sum((a - ap) ** 2 for a, ap in zip(alpha, alpha_emp))

def pick_gp_by_calibration(f_h5, cancer=None, dataset='test'):
    """Pick the stored run with the lowest calibration-weighted loss.

    Every run is scored as ``calibration_score_by_pvals(two-sided p-values)``
    multiplied by the run's stored loss; the run with the smallest score wins.

    Two file layouts are tried in order:

    * current: one group per all-digit run key holding ``mean`` / ``std``
      datasets and a ``loss`` attribute, next to ``y_true``;
    * legacy (used when the current layout raises ``KeyError``): flat
      ``gp_mean_XX`` / ``gp_std_XX`` datasets and ``gp_loss_XX`` attributes
      numbered from 01, next to ``true``.

    Args:
        f_h5: Path of the HDF5 results file.
        cancer: Optional name of a top-level group to read from; when falsy
            the file root is used.
        dataset: Name of the group holding the predictions.

    Returns:
        The identifier of the selected run as an ``int``: the integer value
        of the run's key in the current layout, the 1-based run number in
        the legacy layout.
    """
    # The context manager closes the file even when a read fails.
    with h5py.File(f_h5, 'r') as f:
        dset = f[cancer] if cancer else f

        try:
            # Keep keys and ids side by side: the keys come back in the
            # file's own order and may have gaps, so a position in this list
            # is not a run id.
            run_keys = [key for key in dset[dataset].keys() if key.isdigit()]
            run_ids = [int(key) for key in run_keys]
            test_Y = dset[dataset]['y_true'][:]
            gp_mean_lst = [dset[dataset][key]['mean'][:] for key in run_keys]
            gp_std_lst = [dset[dataset][key]['std'][:] for key in run_keys]
            gp_loss_lst = [dset[dataset][key].attrs['loss'] for key in run_keys]

        except KeyError:
            # Legacy flat layout, read from the same group the other loaders
            # use so that ``cancer`` is honoured here as well.
            n_gps = len([key for key in dset[dataset].keys() if key.startswith('gp_mean')])
            run_ids = list(range(1, n_gps + 1))

            # NOTE(review): load_run / load_ensemble read only row 0 of
            # 'true'; the full array is kept here - confirm which is intended.
            test_Y = dset[dataset]['true'][:]
            gp_mean_lst = [dset[dataset]['gp_mean_{:02d}'.format(i)][:] for i in run_ids]
            gp_std_lst = [dset[dataset]['gp_std_{:02d}'.format(i)][:] for i in run_ids]
            gp_loss_lst = [dset[dataset].attrs['gp_loss_{:02d}'.format(i)] for i in run_ids]

    sse_lst = [
        calibration_score_by_pvals(calc_pvals(test_Y, gp_mean, gp_std, onesided=False))
        for gp_mean, gp_std in zip(gp_mean_lst, gp_std_lst)
    ]

    scores = np.array(sse_lst) * np.asarray(gp_loss_lst)
    # argsort places NaN scores last, so a run with a NaN score is only
    # chosen when every score is NaN.
    best = int(np.argsort(scores)[0])

    # Translate the winning position back into the run's identifier.
    return run_ids[best]

def merge_windows(df, idx_new):
    """Aggregate per-window predictions into the larger windows of ``idx_new``.

    A row of ``df`` belongs to a new window ``(chrom, start, end)`` when its
    ``CHROM`` equals ``chrom`` and ``start <= START < end``. Within each new
    window ``Y_TRUE`` and ``Y_PRED`` are summed and ``STD`` is combined as the
    square root of the sum of squares. A window matching no rows gets zeros.

    Args:
        df: DataFrame with columns ``CHROM``, ``START``, ``Y_TRUE``,
            ``Y_PRED`` and ``STD``.
        idx_new: Sequence of rows ``(chrom, start, end)``, one per new
            window; it must be a 2-D array-like with three columns because it
            is stacked next to the aggregated columns.

    Returns:
        DataFrame with columns ``CHROM, START, END, Y_TRUE, Y_PRED, STD`` and
        one row per row of ``idx_new``, in the same order.
    """
    y_sums = []
    yhat_sums = []
    std_sums = []

    for row in idx_new:
        # Build the membership mask once and reuse it for all three columns.
        in_window = (df.CHROM == row[0]) & (df.START >= row[1]) & (df.START < row[2])
        window = df[in_window]

        y_sums.append(window.Y_TRUE.sum())
        yhat_sums.append(window.Y_PRED.sum())
        std_sums.append(np.sqrt((window.STD ** 2).sum()))

    a_merge = np.hstack([
        idx_new,
        np.array(y_sums).reshape(-1, 1),
        np.array(yhat_sums).reshape(-1, 1),
        np.array(std_sums).reshape(-1, 1),
    ])

    df_merge = pd.DataFrame(a_merge, columns=['CHROM', 'START', 'END', 'Y_TRUE', 'Y_PRED', 'STD'])

    return df_merge
