
import os

import tensorflow as tf
# Index (into the list of physical GPUs) of the single GPU this script uses.
# NOTE(review): hard-coded; `gpus[gpu]` raises IndexError on machines with
# fewer than three GPUs (including CPU-only machines).
gpu=2
gpus = tf.config.experimental.list_physical_devices('GPU')
# Restrict TensorFlow to the chosen GPU. Per the tf.config docs this has to run
# before any GPU is initialized, which is why it happens right after import.
tf.config.experimental.set_visible_devices(gpus[gpu], 'GPU')
# NOTE(review): this binding is immediately overwritten by the loop variable below.
device = gpus[gpu]

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
for device in tf.config.experimental.get_visible_devices('GPU'):
    tf.config.experimental.set_memory_growth(device, True)


# Silence TensorFlow's Python logger so its warnings do not clutter the output.
import logging
logging.getLogger('tensorflow').disabled = True

# Suppress *all* Python warnings for the rest of the run.
# NOTE(review): this also hides numpy RuntimeWarnings such as "Mean of empty
# slice", which would otherwise flag NaN values in the computed metrics.
import warnings
warnings.filterwarnings('ignore')
# Make the parent directory importable so the project-local `util` module
# (which provides diff_pred) can be found.
import sys
sys.path.append("..")
from util import diff_pred 


import numpy as np
import random as rn

# Seed numpy's global RNG so np.random-based choices (e.g. the dataset
# permutations created when no cached one exists) are reproducible.
# NOTE: the numpy global RNG is re-seeded again further down in this file,
# right before the baseline sample models are drawn.
np.random.seed(42)

# Seed Python's built-in `random` module so core-Python random numbers start
# in a well-defined state.

rn.seed(12345)

# Seed TensorFlow's global RNG.
tf.random.set_seed(1234)
from scipy import stats
import pandas as pd
import matplotlib.pyplot as plt

from get_data import get_data



def _load_seed_files(strr, num):
    """Load ``num`` per-seed prediction files named ``strr.format(0..num-1)``."""
    return [np.load(strr.format(i)) for i in range(num)]


def _squeeze_if_needed(preds):
    """Convert to an array and squeeze singleton axes only when it has > 2 dims."""
    preds = np.array(preds)
    if preds.ndim > 2:
        return preds.squeeze()
    return preds


def load_data(strr, exp_type, ds_name="", num=500, from_script=False):
    """Load saved model predictions for one dataset / experiment type.

    Args:
        strr: path specification. For ``exp_type == "LF"`` it is a single path,
            or (outside the taiwanese dataset) several paths joined by ",".
            Otherwise it is a format string filled with ``0..num-1`` to produce
            one file per model.
        exp_type: ``"LF"`` selects the single/comma-joined file layout; any
            other value selects the per-file-``format`` layout.
        ds_name: dataset name; ``"taiwanese"`` uses its own layout rules.
        num: number of per-seed files to load for non-LF experiments.
        from_script: load a single file, squeeze it and swap its first two axes.

    Returns:
        numpy array of predictions. The transposes below are shape heuristics
        based on the 500/499/501 thresholds. NOTE(review): they presumably aim
        for a (models, points) layout — confirm against the producers.
    """
    if ds_name == 'taiwanese':
        if exp_type == "LF":
            base_test_seed = np.load(strr)
            if base_test_seed.shape[0] > 500:
                base_test_seed = np.moveaxis(np.array(base_test_seed), 0, 1)
            return np.array(base_test_seed)
        return _squeeze_if_needed(_load_seed_files(strr, num))

    if from_script:
        return np.moveaxis(np.load(strr).squeeze(), 0, 1)

    if exp_type == "LF":
        base_test_seed = None
        for path in strr.split(","):
            preds = np.load(path)
            if preds.shape[1] > 500:
                preds = np.moveaxis(preds, 0, 1).squeeze()
            if len(preds.shape) > 2:
                preds = preds[:, 0]
            # Stack the *normalized* predictions. Previously the raw file was
            # re-loaded here, so only the first file got the transpose/squeeze.
            base_test_seed = preds if base_test_seed is None else np.hstack((base_test_seed, preds))
        if base_test_seed.shape[0] > 501 or base_test_seed.shape[0] < 499:
            base_test_seed = np.moveaxis(np.array(base_test_seed), 0, 1)
        return _squeeze_if_needed(base_test_seed)

    return _squeeze_if_needed(_load_seed_files(strr, num))


def _binom_pvalue(k, n, p=0.5):
    """Two-sided exact binomial-test p-value for ``k`` successes out of ``n``.

    Uses ``scipy.stats.binomtest`` when available (``binom_test`` was removed in
    SciPy 1.12) and falls back to ``binom_test`` on older SciPy.
    """
    binomtest = getattr(stats, "binomtest", None)
    if binomtest is not None:
        return binomtest(int(k), int(n), p=p).pvalue
    return stats.binom_test(k, n, p=p)


def make_agg(outputs, n, alpha, ranges=None, num_avg=None, num_models=None, step=None):
    """Build majority-vote aggregated models with binomial-test abstention.

    Rows of ``outputs`` are individual models and columns are test points; a
    model votes positive where its score is > 0.5. Aggregated model ``i`` is the
    vote over rows ``ranges[i]:ranges[i+1]``.

    Args:
        outputs: 2-D array of model scores (models x points).
        n: number of trials used in the binomial test (the intended window
            size). Also used to derive ``num_models`` when that is not given.
        alpha: significance level; a point gets a prediction when p < alpha.
        ranges: window boundaries; defaults to ``0, num_avg, 2*num_avg, ...``
            up to and including ``outputs.shape[0]``.
        num_avg: window stride used when ``ranges`` is None.
        num_models: number of aggregated models; defaults to ``rows // n``.
        step: unused; kept for backward compatibility.

    Returns:
        avg_res: (num_models, points) fraction of positive votes per window.
        idxs_arr: boolean (num_models, points); True = predict, False = abstain.
        abstain_percent: list with each aggregated model's abstention fraction.
    """
    if ranges is None:
        # +1 keeps the final boundary, so the last window has an upper bound.
        ranges = np.arange(0, outputs.shape[0] + 1, num_avg)
    if num_models is None:
        num_models = outputs.shape[0] // n

    avg_res, stats_test = [], []
    for i in range(num_models):
        votes = outputs[ranges[i]:ranges[i + 1]] > 0.5
        avg_res.append(votes.mean(axis=0))
        # Count positive votes once per window (not once per point).
        counts = votes.sum(axis=0)
        stats_test.append(np.array([_binom_pvalue(c, n) for c in counts]))
    avg_res = np.array(avg_res)
    stats_test = np.array(stats_test)

    idxs_arr = stats_test < alpha
    abstain_percent = [1 - idxs_arr[i].sum() / idxs_arr[i].shape[0]
                       for i in range(num_models)]
    return avg_res, idxs_arr, abstain_percent
    
def get_avg_acc(outputs, true_lab, pred_idx=None):
    """Average accuracy over models, optionally accounting for abstentions.

    Both scores and labels are thresholded at 0.5.

    Args:
        outputs: 2-D array of scores (models x points).
        true_lab: labels for the points.
        pred_idx: optional boolean (models x points) mask; True = the model
            made a prediction, False = it abstained.

    Returns:
        Without ``pred_idx``: the mean per-model accuracy.
        With ``pred_idx``: a tuple of
          * mean accuracy with abstentions counted as wrong (0 for a model
            that abstains everywhere), and
          * mean accuracy on predicted points only, averaged over models that
            predicted at least one point (NaN if every model abstained
            everywhere).
    """
    correct = (np.asarray(outputs) > 0.5) == (np.asarray(true_lab) > 0.5)
    if pred_idx is None:
        return correct.mean(axis=1).mean()

    predicted = np.asarray(pred_idx, dtype=bool)
    hits = (correct & predicted).sum(axis=1)
    n_predicted = predicted.sum(axis=1)
    # Correct-and-predicted over all points == predicted fraction * accuracy
    # on predicted points, but stays well defined (0) when nothing is predicted.
    accs = hits / predicted.shape[1]
    has_pred = n_predicted > 0
    if not has_pred.any():
        return accs.mean(), np.nan
    return accs.mean(), (hits[has_pred] / n_predicted[has_pred]).mean()

def get_std_and_min_max_acc(base, y):
    """Summarize how much accuracy varies across individual models.

    Args:
        base: 2-D array of scores, one row per model; a model predicts the
            positive class where its score is > 0.5.
        y: labels for the points, thresholded at 0.5 exactly like
            ``get_avg_acc`` does, so 0/1, boolean and soft labels all agree.

    Returns:
        accs: per-model accuracy array.
        std: standard deviation of ``accs``.
        max_diff: ``accs.max() - accs.min()``.
        max_diff_mean: largest absolute deviation of a model from the mean.
    """
    correct = (np.asarray(base) > 0.5) == (np.asarray(y) > 0.5)
    accs = correct.mean(axis=1)
    std = accs.std()
    max_diff = accs.max() - accs.min()
    max_diff_mean = np.max(np.abs(accs - accs.mean()))
    return accs, std, max_diff, max_diff_mean

def get_epsilon(array, model_idx, predict, conf_idx=0, conf_arr=None):
    """Measure how often pairs of models disagree ("flip") on each point.

    Every ordered pair (i, j) of selected models with i != j is compared;
    predictions are ``array[k] > 0.5``. A point's flip rate is the fraction of
    pairs that disagree on it.

    Args:
        array: 2-D array of scores (models x points).
        model_idx: index or indices of the models (rows of ``array``) to compare.
        predict: unused; kept for backward compatibility.
        conf_idx: row of ``array`` used as the reference score when
            ``conf_arr`` is None.
        conf_arr: optional reference scores per point.

    Returns:
        (epsilon, avg_eps, median_eps, percent_chance_flip,
         percent_chance_flip_05, avg_distance_bdry_flips) where epsilon is the
        max flip rate, avg/median are over points, the percentages are the
        fraction of points with flip rate > 0 and > 0.5, and the last is the
        mean |reference score - 0.5| over points that ever flip (0 if none do).
        With fewer than two distinct models there are no pairs and all six
        values are 0.
    """
    idx = np.asarray(model_idx).ravel()
    pairs = [(i, j) for i in idx for j in idx if i != j]
    if not pairs:
        return (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

    disagree = np.array([(array[i] > 0.5) != (array[j] > 0.5) for i, j in pairs])
    flip_rate = disagree.mean(axis=0)

    epsilon = np.max(flip_rate)
    avg_eps = flip_rate.mean()
    median_eps = np.median(flip_rate)
    flips = flip_rate > 0
    percent_chance_flip = flips.mean()
    percent_chance_flip_05 = (flip_rate > 0.5).mean()

    conf = array[conf_idx] if conf_arr is None else conf_arr
    # Mean of an empty selection is NaN, so report 0 when nothing flips
    # (for both the conf_arr and the conf_idx reference).
    if flips.any():
        avg_distance_bdry_flips = np.abs(conf[flips] - 0.5).mean()
    else:
        avg_distance_bdry_flips = 0
    return (epsilon, avg_eps, median_eps, percent_chance_flip,
            percent_chance_flip_05, avg_distance_bdry_flips)



# Module-level accumulator: each `script` call appends one row of
# [avg base acc, min-max acc spread, acc std, max diff from avg acc].
avg_base_accs_min_max_std = list()


def _ens_output_paths(strr0, size, alpha, exp_type):
    """Return the (avg-predictions, abstain-mask, abstain-percentages) .npy
    paths written for one ensemble size."""
    return (
        "ensemble/" + strr0 + "_avg_predictions_{}_ens_mdls_alpha_{}".format(size, alpha) + exp_type + ".npy",
        "ensemble/" + strr0 + "_abstain_idxs_{}_ens_mdls_alpha_{}".format(size, alpha) + exp_type + ".npy",
        "ensemble/" + strr0 + "_abstain_percentages_{}_ens_mdl_alpha_{}".format(size, alpha) + exp_type + ".npy",
    )


def _save_line_plot(x, series, xlabel, ylabel, path):
    """Plot each ``(y, label)`` in ``series`` against ``x``, save to ``path``
    and close the figure so later plots start from a clean canvas."""
    for y, label in series:
        plt.plot(x, y, label=label)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(path)
    plt.close()


def script(exp_type, alpha, sample_models_strr, ds_name, mdl_file, labels_file, num=500, from_script=False):
    """Run the ensemble-with-abstention analysis for one dataset and setting.

    Loads the individual models' predictions, shuffles the model order with a
    cached per-dataset permutation, and records single-model accuracy
    statistics. It then builds 24 aggregated (majority-vote) models for each
    ensemble size 5, 10, 15 and 20, with binomial-test abstention at ``alpha``,
    and measures accuracy, pairwise disagreement (epsilon) and ``diff_pred``
    LUF statistics. Results are written as .npy/.csv/.png files under
    ``ensemble/``.

    Args:
        exp_type: experiment type forwarded to ``load_data`` and used in output
            names. NOTE(review): presumably "RS" = random seeds and
            "LF" = leave-one-out models, judging by the input file names.
        alpha: significance level for abstention.
        sample_models_strr: .npy file with indices of the individual models
            used as the un-aggregated baseline.
        ds_name: dataset name (also the output file prefix).
        mdl_file: prediction path specification passed to ``load_data``.
        labels_file: .npy file with test labels (unused when ``from_script``).
        num: number of individual models; bounds the aggregation windows and
            is forwarded to ``load_data``.
        from_script: take labels from ``get_data`` instead of ``labels_file``.

    Returns:
        dict mapping output kinds to file paths. The three per-ensemble .npy
        paths refer to the largest ensemble size (20).

    Side effects:
        Appends one row to the module-level ``avg_base_accs_min_max_std``.
    """
    strr0 = ds_name
    suffix = exp_type + "alpha_{}".format(alpha)
    os.makedirs("ensemble", exist_ok=True)

    base_test_seed = load_data(mdl_file, exp_type=exp_type, ds_name=strr0, num=num,
                               from_script=from_script)

    # Shuffle the model order with a permutation cached per dataset, so reruns
    # use the same model grouping.
    # NOTE(review): the cache key ignores exp_type, so RS and LF runs of one
    # dataset share a permutation; that only works if both have the same
    # number of models.
    perm_path = "ensemble/permutation_for_dataset_" + strr0 + ".npy"
    try:
        perm = np.load(perm_path)
    except (OSError, ValueError):
        perm = np.random.permutation(base_test_seed.shape[0])
        np.save(perm_path, perm)
    base_test_seed = base_test_seed[perm]

    if from_script:
        (_, _), (_, y_test), (_, _) = get_data(ds_name)
    else:
        y_test = np.load(labels_file, allow_pickle=False)

    # Randomly selected individual models used as the un-aggregated baseline.
    sample_models = np.load(sample_models_strr)

    # Single-model accuracy statistics.
    avg_base_acc = get_avg_acc(base_test_seed, y_test)
    _, accs_std, min_max_diff, max_diff_from_avg = get_std_and_min_max_acc(base_test_seed, y_test)
    baseline = [avg_base_acc, min_max_diff, accs_std, max_diff_from_avg]
    avg_base_accs_min_max_std.append(baseline)
    pd.DataFrame(data=baseline, index=["Avg Base Acc", "Min Max", "Std", "Max Diff From Avg"],
                 columns=[strr0]).to_csv("ensemble/" + strr0 + "datasets_min_max_std_" + exp_type + ".csv")

    ens_sizes = list(range(5, 24, 5))  # 5, 10, 15, 20 models per aggregate
    num_agg_models = 24  # aggregated models built for every ensemble size

    avg_tes, pred_idxs, pred_intersections = [], [], []
    accs = []
    for size in ens_sizes:
        avg_te, pred_idx, abs_perc = make_agg(base_test_seed, size, alpha,
                                              ranges=np.arange(0, num, size),
                                              num_models=num_agg_models)
        preds_path, abstain_path, abs_perc_path = _ens_output_paths(strr0, size, alpha, exp_type)
        np.save(preds_path, avg_te)
        np.save(abstain_path, pred_idx)  # True = prediction made, False = abstained
        np.save(abs_perc_path, abs_perc)
        avg_tes.append(avg_te)
        pred_idxs.append(pred_idx)
        avg_acc_abs, avg_acc_on_nonabs = get_avg_acc(avg_te, y_test, pred_idx=pred_idx)
        accs.append([avg_acc_abs, get_avg_acc(avg_te, y_test), avg_acc_on_nonabs,
                     np.array(abs_perc).mean()])
        # Points on which every aggregated model makes a prediction.
        pred_intersections.append(np.all(pred_idx, axis=0))
    accs = np.array(accs)

    _save_line_plot(np.array(ens_sizes),
                    [(1 - accs[:, 0], "Err Abst as 0"),
                     (1 - accs[:, 1], "Err No Abst"),
                     (1 - accs[:, 2], "Err Only on Pred"),
                     (accs[:, 3], "Abst Perc"),
                     ([1 - avg_base_acc] * len(ens_sizes), "Oneshot Avg Err")],
                    "Num Models Agg", "Error/Abstain Perc",
                    "ensemble/" + strr0 + "Accs_" + suffix + ".png")

    # Row 0 of every table is the un-aggregated baseline (sampled individual
    # models, reference scores from model 42). It is identical with and without
    # abstention, so it is computed once. LUF points are measured per
    # confidence of that baseline reference model.
    ref_conf = base_test_seed[42]
    base_eps = get_epsilon(base_test_seed, sample_models, np.ones_like(base_test_seed), conf_arr=ref_conf)
    base_luf = diff_pred([], np.moveaxis(base_test_seed[sample_models], 0, 1), [], ref_conf, train=False)
    get_eps, get_eps_no_abst = [base_eps], [base_eps]
    lufs, lufs_no_abst = [base_luf], [base_luf]

    agg_models = np.arange(num_agg_models)
    for arr, pred_idx, pred_intersect in zip(avg_tes, pred_idxs, pred_intersections):
        # With abstention: restrict to points where no aggregated model abstains.
        if pred_intersect.any():
            get_eps.append(get_epsilon(arr[:, pred_intersect], agg_models, pred_idx,
                                       conf_arr=ref_conf[pred_intersect]))
            lufs.append(diff_pred([], np.moveaxis(arr[:, pred_intersect], 0, 1), [],
                                  arr[0][pred_intersect] > 0.5,
                                  conf_arr_test=ref_conf[pred_intersect] > 0.5, train=False))
        else:
            get_eps.append((0, 0, 0, 0, 0, 0))
            lufs.append([0, 0, 0, 0, 0])
        get_eps_no_abst.append(get_epsilon(arr, agg_models, pred_idx, conf_arr=ref_conf))
        lufs_no_abst.append(diff_pred([], np.moveaxis(arr, 0, 1), [], arr[0] > 0.5,
                                      conf_arr_test=ref_conf > 0.5, train=False))
    get_eps = np.array(get_eps)
    lufs = np.array(lufs)
    get_eps_no_abst = np.array(get_eps_no_abst)
    lufs_no_abst = np.array(lufs_no_abst)

    x_eps = np.array([0] + ens_sizes)  # 0 = un-aggregated baseline
    for eps, tag in ((get_eps, "_abstain"), (get_eps_no_abst, "_no_abstain")):
        _save_line_plot(x_eps,
                        [(eps[:, 3], "% with Avg Flip Chnge > 0"),
                         (eps[:, 4], "Avg Flip Chnc >0.5")],
                        "Num Models Agg", "Avg Flip Chance",
                        "ensemble/" + strr0 + "avg_flip_chance_" + suffix + tag + ".png")
        _save_line_plot(x_eps, [(eps[:, 0], "Eps")], "Num Models Agg", "Epsilon",
                        "ensemble/" + strr0 + "Epsilon_" + suffix + tag + ".png")

    eps_columns = ["epsilon", "avg_eps", "median_eps", "percent_chance_flip",
                   "percent_chance_flip_05", "avg_distance_bdry_flips"]
    luf_columns = ["total LUF", "over 0.1", "over 0.2", "over 0.3", "over 0.4"]
    accs_path = "ensemble/" + strr0 + "_Accs_" + suffix + ".csv"
    eps_path = "ensemble/" + strr0 + "_Avg_Epsilons_" + suffix + ".csv"
    lufs_path = "ensemble/" + strr0 + "_LUFs_" + suffix + ".csv"
    eps_no_abst_path = "ensemble/" + strr0 + "_Avg_Epsilons_no_abst" + suffix + ".csv"
    lufs_no_abst_path = "ensemble/" + strr0 + "_LUFs_no_abst" + suffix + ".csv"

    pd.DataFrame(data=accs, columns=["Avg Acc with Abstention as Wrong",
                                     "Avg Acc No Abstention",
                                     "Avg Acc on Predicted Pts", "Abstain Percent"]).to_csv(accs_path)
    pd.DataFrame(data=get_eps, columns=eps_columns).to_csv(eps_path)
    pd.DataFrame(data=lufs, columns=luf_columns).to_csv(lufs_path)
    pd.DataFrame(data=get_eps_no_abst, columns=eps_columns).to_csv(eps_no_abst_path)
    pd.DataFrame(data=lufs_no_abst, columns=luf_columns).to_csv(lufs_no_abst_path)

    preds_path, abstain_path, abs_perc_path = _ens_output_paths(strr0, ens_sizes[-1], alpha, exp_type)
    return {"model_preds": preds_path,
            "abs_or_pred_true_false": abstain_path,
            "abstain_percents": abs_perc_path,
            "accs": accs_path,
            "get_eps_output": eps_path,
            "lufs_output": lufs_path,
            "get_eps_no_abst_output": eps_no_abst_path,
            "lufs_no_abst_output": lufs_no_abst_path,
            }

######
num=500
alpha=0.05
np.random.seed(0)
# Indices of the individually-trained models used as the un-aggregated
# ("0 models aggregated") baseline in `script`. They are cached on disk so
# every run compares against the same models.
# NOTE(review): the original drew a single scalar (no `size`), which `script`
# cannot use as a model set: `np.moveaxis(base_test_seed[sample_models], 0, 1)`
# fails on a single 1-D row. 24 matches the number of aggregated models it is
# compared against; confirm the intended baseline size.
NUM_SAMPLE_MODELS = 24
try:
    sample_models=np.load("ensemble/sample_models.npy")
except (OSError, ValueError):
    os.makedirs("ensemble", exist_ok=True)
    sample_models=np.random.choice(np.arange(100), size=NUM_SAMPLE_MODELS, replace=False)
    np.save("ensemble/sample_models.npy", sample_models)
######

# Experiments in execution order: (dataset, prediction file(s), test-label
# file, experiment type). LF runs with several prediction files pass them as
# one comma-separated string. Each experiment runs at alpha 0.05 then 0.01.
_experiments = [
    ("german_credit", "../resubmit_data/gc_yp_test_same_seed_seed{}.npy",
     '../data/german_y_test.npy', "RS"),
    ("adult", "../resubmit_data/adult_yp_test_same_seed_seed{}.npy",
     '../data/adult_y_test_norm.npy', "RS"),
    ("seizure", "../resubmit_data/seizure_yp_test_same_seed_seed{}.npy",
     '../data/data_redone/seizure_y_test.npy', "RS"),
    ("german_credit", ",".join(["../resubmit_data/gc_loo_test_same_seed.npy",
                                "../resubmit_data/gc_loo_test_same_seed_seed_ex.npy",
                                "../resubmit_data/gc_loo_test_same_seed_seed_ex_0100.npy"]),
     '../data/german_y_test.npy', "LF"),
    ("adult", ",".join(["../resubmit_data/adult_loo_test_same_seed.npy",
                        "../resubmit_data/adult_loo_test_same_seed_ext.npy"]),
     '../data/adult_y_test_norm.npy', "LF"),
    ("seizure", ",".join(["../resubmit_data/seizure_loo_test_same_seed.npy",
                          "../resubmit_data/seizure_loo_test_same_seed_ext.npy"]),
     '../data/data_redone/seizure_y_test.npy', "LF"),
    ("taiwanese", "../data/tai_loo_test_same_seed.npy",
     '../data/y_test_taiwanese.npy', "LF"),
    ("taiwanese", '../data/tai_yp_test_same_seed_seed{}.npy',
     '../data/y_test_taiwanese.npy', "RS"),
]

for strr0, strr_1, strr_2, _exp_type in _experiments:
    for _alpha in (0.05, 0.01):
        script(exp_type=_exp_type, alpha=_alpha, sample_models_strr="ensemble/sample_models.npy",
               ds_name=strr0, mdl_file=strr_1, labels_file=strr_2, num=500)
