import argparse
import os
import json
import itertools

import numpy as np
import pandas as pd
from scipy.stats import ttest_ind
import statsmodels.api as sm


def make_table(data, store_dir):
    envs = data.keys()
    all_means_plus_minus = {}
    np_idx = [9, 24, 49, 99]
    for e in envs:
        models = data[e].keys()
        means_plus_minus = {}
        for m in models:
            #clipped = y_models[m].clip(-200,200)
            clipped = data[e][m]
            mean = np.nanmean(clipped, axis=0)
            sigma = np.nanstd(clipped, axis=0)
            mean_pl_mn = [str(np.round(mean[idx],2))+'±'+str(np.round(sigma[idx],2)) 
                for idx in np_idx]
            means_plus_minus[m] = mean_pl_mn
        all_means_plus_minus[e] = means_plus_minus
    ac_batch = [10, 25, 50, 100]
    envs = list(envs)
    envs.sort()
    envs = ['hetero', 'bimodal', 'Pendulum-v0', 'Hopper-v2', 'Ant-v2', 'Humanoid-v2']
    idx = list(itertools.product(envs, ac_batch))  
    idx = pd.MultiIndex.from_tuples(idx, names=["Env", "Acquisition Batch"])
    models = list(models)
    models.sort()
    df = pd.DataFrame(np.zeros((len(ac_batch)*len(envs), len(models))), columns = models ,index=idx)
    for m in models:
        data = all_means_plus_minus[envs[0]][m]
        data += all_means_plus_minus[envs[1]][m]
        data += all_means_plus_minus[envs[2]][m]
        data += all_means_plus_minus[envs[3]][m]
        data += all_means_plus_minus[envs[4]][m]
        data += all_means_plus_minus[envs[5]][m]
        df[m] = data
    df = df.rename({'random':'Random', 'sample_bald':'Monte Carlo', 'kl_exp':'KL', 
        'bhatt_exp':'Bhatt', 'bmdal_batchbald':'BatchBALD', 'bmdal_badge':'BADGE',
        'bmdal_bait':'BAIT', 'bmdal_lcmd':'LCMD'}, axis=1)
    df = df.loc[:,['Random', 'BatchBALD', 'BADGE', 'BAIT', 'LCMD', 'Monte Carlo', 'KL', 'Bhatt']]
    print(df.to_latex())
    import pdb; pdb.set_trace()
    return df

def stats_testing(data, envs):
    """Compare every baseline against the KL and Bhatt criteria with t-tests.

    For each environment and each reported batch size (columns 9, 24, 49, 99,
    labelled 10, 25, 50, 100) an unequal-variance two-sample t-test
    (``ttest_ind(..., equal_var=False)``) is run for each baseline vs.
    ``kl_exp`` and vs. ``bhatt_exp``. The 12 p-values of one
    (env, batch) cell are Holm-corrected together. The corrected p-values and
    the rounded test statistics are printed as LaTeX tables.

    Args:
        data: mapping ``env -> {method key -> 2-D array}``; column ``j`` of
            each array is the sample used at batch index ``j``.
        envs: iterable of environment names to process, in table row order.

    Returns:
        dict with key ``'order'`` (the 12 comparison names) and, per env, a
        dict ``batch size -> {'test_stat', 'p-value', 'corrected_p-value'}``.

    Raises:
        KeyError: if an environment or one of the eight method keys is absent.
    """
    # (column prefix, key in `data`) pairs; the product below reproduces the
    # column order rand_kl, mc_kl, ..., lcmd_kl, rand_bhatt, ..., lcmd_bhatt.
    baselines = [('rand', 'random'), ('mc', 'sample_bald'),
                 ('bbald', 'bmdal_batchbald'), ('badge', 'bmdal_badge'),
                 ('bait', 'bmdal_bait'), ('lcmd', 'bmdal_lcmd')]
    proposed = [('kl', 'kl_exp'), ('bhatt', 'bhatt_exp')]
    order = [b_name + '_' + p_name
             for p_name, _ in proposed for b_name, _ in baselines]
    results = {'order': order}

    np_idx = [9, 24, 49, 99]
    ac_batch = [col + 1 for col in np_idx]
    index = pd.MultiIndex.from_tuples(list(itertools.product(envs, ac_batch)),
                                      names=["Env", "Acquisition Batch"])

    def _blank():
        # Zero-filled frame with one row per (env, batch) and one column per test.
        return pd.DataFrame(np.zeros((len(index), len(order))),
                            columns=order, index=index)

    df_pvals = _blank()
    df_corpvals = _blank()
    df_test_stat = _blank()

    for env in envs:
        env_data = data[env]
        results[env] = {}
        for col in np_idx:
            batch = col + 1
            # NOTE(review): ttest_ind runs with its default NaN handling, so a
            # NaN in either sample yields NaN results -- confirm inputs are clean.
            tests = [ttest_ind(env_data[b_key][:, col], env_data[p_key][:, col],
                               equal_var=False)
                     for _, p_key in proposed for _, b_key in baselines]
            test_stat = np.array([t.statistic for t in tests])
            pvalues = np.array([t.pvalue for t in tests])
            # Holm step-down correction over the 12 tests of this cell.
            _, corrected_p_values, _, _ = sm.stats.multipletests(pvalues, method='holm')
            df_corpvals.loc[(env, batch)] = corrected_p_values
            df_pvals.loc[(env, batch)] = pvalues
            df_test_stat.loc[(env, batch)] = np.round(test_stat, 2)
            results[env][batch] = {'test_stat': test_stat, 'p-value': pvalues,
                                   'corrected_p-value': corrected_p_values}

    print(df_corpvals.to_latex())
    #print(df_pvals.to_latex())
    print(df_test_stat.to_latex())
    return results

    

if __name__ == '__main__':
    # Script entry point: load per-run RMSE / test-loss curves from disk for
    # every environment and acquisition criterion, then print the statistical
    # tests and summary tables for likelihood and RMSE.
    #base_dir = '/home/nwaftp23/scratch/uncertainty_estimation/mujoco'
    #base_dir = '/home/nwaftp23/projects/def-dpmeger/nwaftp23/uncertainty_estimation/mujoco'
    base_dir = '/home/lucas/uncertainty_estimation/results/scp_vs_zip/mujoco'
    acq_crit = ['random', 'sample_bald', 'bmdal_batchbald',
        'bmdal_badge', 'bmdal_bait', 'bmdal_lcmd', 'kl_exp', 'bhatt_exp']
    envs = [ 'hetero', 'bimodal', 'Pendulum-v0', 'Hopper-v2', 'Ant-v2', 'Humanoid-v2']
    models_to_graph = 'nflows_ensemble_fixedmasks'
    #models_to_graph = 'nn_ensemble_fixedmasks'
    all_likelihoods = {}
    all_rmses = {}
    # Runs shorter than `cutoff` entries are skipped; longer ones are truncated.
    cutoff = 100
    for env in envs:
        env_dir = os.path.join(base_dir, env+'_test_aquisition')
        # Sorted so the processing (and printing) order is reproducible.
        run_dirs = [os.path.join(env_dir, name) for name in sorted(os.listdir(env_dir))]
        run_dirs = [d for d in run_dirs if os.path.isdir(d)]
        rmses_models = {}
        likelihoods_models = {}
        for crit in acq_crit:
            model_dirs = [d for d in run_dirs
                if crit in os.path.basename(os.path.normpath(d))
                and models_to_graph in d]
            rmses = []
            likelihoods = []
            for d in model_dirs:
                # This seed is deliberately excluded from the analysis.
                if d.split('_')[-1] == 'seed585':
                    continue
                results_dir = os.path.join(d, 'results')
                rmse = np.load(os.path.join(results_dir, 'rmse_array.npy'))
                test_loss = np.load(os.path.join(results_dir, 'test_loss_array.npy'))
                # The stored test loss is negated to obtain the likelihood.
                likelihood = -test_loss
                print(d)
                print(rmse.shape)
                if len(rmse) >= cutoff:
                    rmses.append(rmse[:cutoff])
                    likelihoods.append(likelihood[:cutoff])
            if not rmses:
                # np.stack([]) would fail with an uninformative message.
                raise ValueError(
                    "no usable runs for criterion '{}' in '{}'".format(crit, env_dir))
            rmses_models[crit] = np.stack(rmses)
            likelihoods_models[crit] = np.stack(likelihoods)
        all_rmses[env] = rmses_models
        all_likelihoods[env] = likelihoods_models
    print('likelihood stat tests')
    stats_testing(all_likelihoods, envs)
    print('likelihood table')
    likelihood_table = make_table(all_likelihoods, base_dir)
    print('RMSE stat tests')
    stats_testing(all_rmses, envs)
    print('RMSE table')
    rmse_table = make_table(all_rmses, base_dir)
