import os

import tensorflow as tf
# Pin TensorFlow to one physical GPU (by index) and let it allocate GPU memory
# on demand instead of reserving the whole card up front.
gpu = 1
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpu < len(gpus):
    tf.config.experimental.set_visible_devices(gpus[gpu], 'GPU')
    device = gpus[gpu]
else:
    # Fewer GPUs than the requested index (or none at all): keep TensorFlow's
    # default device visibility instead of failing with an IndexError.
    device = None

for visible_gpu in tf.config.experimental.get_visible_devices('GPU'):
    tf.config.experimental.set_memory_growth(visible_gpu, True)


# Silence TensorFlow's Python-side logging by disabling its logger outright.
import logging
logging.getLogger('tensorflow').disabled = True

# Ignore every Python warning for the rest of the process. Note that this is
# global: it also hides warnings raised by NumPy/SciPy/pandas in the analysis
# code below, not just TensorFlow's.
import warnings
warnings.filterwarnings('ignore')

import pandas as pd

import sys

import numpy as np
import random as rn
from scipy import stats
import pandas as pd
import matplotlib.pyplot as plt
from get_data import get_data

# Seed every random number generator this script touches so runs are
# repeatable.

# Python's built-in `random` module.
rn.seed(12345)

# TensorFlow's global seed.
tf.random.set_seed(1234)

# NumPy's global generator. A preceding np.random.seed(42) call was dropped:
# nothing drew from NumPy before this line re-seeded it, so 0 has always been
# the effective seed.
np.random.seed(0)
#sample_models=np.random.choice(100, size=25, replace=False)
#np.save("ensemble/sample_models_FMNIST.npy", sample_models)

#from util import *
def load_data(strr, exp_type, num=500, strr0=None):
    """Load saved model outputs from ``.npy`` files.

    Parameters
    ----------
    strr : str
        For ``exp_type == "RS"`` a ``str.format`` template that is filled with
        the indices ``0 .. num-1`` to obtain one file per model. For any other
        ``exp_type`` the path of a single ``.npy`` file.
    exp_type : str
        ``"RS"`` selects the one-file-per-model layout; anything else selects
        the single-file layout.
    num : int
        Number of per-model files to load (only used for ``"RS"``).
    strr0 : str or None
        Dataset name. ``"warafin"`` (single-file layout only) additionally
        loads ``../data/war_loo_test_same_seed_0100.npy`` and appends it along
        axis 0.

    Returns
    -------
    numpy.ndarray
        For ``"RS"`` the per-file arrays stacked along a new leading axis with
        all length-1 axes squeezed out; otherwise the loaded array(s).
    """
    if exp_type == "RS":
        per_model = [np.load(strr.format(i)) for i in range(num)]
        return np.array(per_model).squeeze()

    loaded = np.load(strr)
    if strr0 != "warafin":
        return loaded

    # The original literal was "..data/..." (missing slash), which does not
    # match the "../data/" directory used for the other inputs in this file.
    second_part = np.load("../data/war_loo_test_same_seed_0100.npy")
    return np.concatenate((loaded, second_part), axis=0)


def _binom_two_sided_pvalue(k, n):
    """Two-sided p-value for ``k`` successes in ``n`` fair (p=0.5) trials.

    Uses ``scipy.stats.binomtest`` when available and falls back to the older
    ``scipy.stats.binom_test`` (removed from recent SciPy releases) otherwise.
    """
    if hasattr(stats, "binomtest"):
        return stats.binomtest(int(k), int(n), p=0.5).pvalue
    return stats.binom_test(k, n, p=0.5)


def make_agg_multi(outputs, n, alpha, ranges=None, num_avg=None, num_models=None, step=None):
    """Build majority-vote ensembles and decide, per point, whether to abstain.

    ``outputs`` must already hold class labels (argmax taken), indexed as
    ``outputs[model, point]``. Ensemble ``i`` is formed from the rows
    ``outputs[ranges[i]:ranges[i + 1]]``; its prediction for a point is the
    most common label, and the vote count of that label is tested against a
    fair coin with a two-sided binomial test on ``n`` trials.

    Parameters
    ----------
    outputs : numpy.ndarray
        Label predictions, one row per model.
    n : int
        Number of trials passed to the binomial test (the ensemble size).
    alpha : float
        Significance level; a point is predicted only when its p-value is
        strictly below ``alpha``.
    ranges : sequence of int, optional
        Row boundaries of the ensembles; needs ``num_models + 1`` entries.
        Defaults to ``0, num_avg, 2*num_avg, ...`` up to and including
        ``outputs.shape[0]`` (step 1 when ``num_avg`` is None).
    num_avg : int, optional
        Step of the default ``ranges``.
    num_models : int, optional
        Number of ensembles to build; defaults to ``outputs.shape[0] // n``.
    step : unused
        Accepted for backward compatibility only.

    Returns
    -------
    avg_res : numpy.ndarray
        Majority-vote label per ensemble and point.
    idxs_arr : numpy.ndarray of bool
        True where the ensemble predicts, False where it abstains.
    abstain_percent : list of float
        Fraction of points each ensemble abstains on.
    """
    if ranges is None:
        boundary_step = 1 if num_avg is None else num_avg
        # Include the end boundary so the last ensemble has a closing index.
        ranges = np.arange(0, outputs.shape[0] + 1, boundary_step)
    if num_models is None:
        num_models = outputs.shape[0] // n

    stats_test = []
    avg_res = []
    for i in range(num_models):
        mode, counts = stats.mode(outputs[ranges[i]:ranges[i + 1]], axis=0)
        avg_res.append(np.asarray(mode).reshape(-1))

        # The p-value depends only on the vote count, so evaluate the test
        # once per distinct count instead of once per point.
        counts = np.asarray(counts).reshape(-1)
        distinct, inverse = np.unique(counts, return_inverse=True)
        pvalues = np.array([_binom_two_sided_pvalue(c, n) for c in distinct])
        stats_test.append(pvalues[inverse])

    stats_test = np.array(stats_test)
    avg_res = np.array(avg_res)
    idxs_arr = np.array([stats_test[i] < alpha for i in range(num_models)])
    abstain_percent = [1 - idxs_arr[i].sum() / idxs_arr[i].shape[0]
                       for i in range(num_models)]

    return avg_res, idxs_arr, abstain_percent
    
def get_avg_acc_multi(outputs, true_lab, pred_idx=None):
    """Average accuracy over several models, optionally with abstention.

    ``outputs`` must already hold class labels (argmax taken), one row per
    model; ``true_lab`` is one-hot encoded and is reduced with
    ``argmax(axis=1)``.

    Parameters
    ----------
    outputs : numpy.ndarray
        Predicted labels, indexed as ``outputs[model, point]``.
    true_lab : numpy.ndarray
        One-hot ground-truth labels, one row per point.
    pred_idx : numpy.ndarray of bool, optional
        Mask with the same layout as ``outputs``; True where the model
        predicts, False where it abstains.

    Returns
    -------
    float
        Mean accuracy over models when ``pred_idx`` is None.
    (float, float)
        When ``pred_idx`` is given: the mean accuracy counting abstentions as
        errors, and the mean accuracy on the predicted points only. A model
        that abstains everywhere contributes 0.0 to both averages (previously
        this produced NaN from a mean over an empty selection).
    """
    labels = true_lab.argmax(axis=1)

    if pred_idx is None:
        accs = [(outputs[i] == labels).mean() for i in range(outputs.shape[0])]
        return np.array(accs).mean()

    accs = []
    non_abs_accs = []
    for i in range(outputs.shape[0]):
        keep = pred_idx[i]
        if keep.any():
            non_abs_acc = (outputs[i][keep] == labels[keep]).mean()
        else:
            # Nothing was predicted, so there is no accuracy to measure.
            non_abs_acc = 0.0
        non_abs_accs.append(non_abs_acc)
        # Coverage times accuracy-on-covered == accuracy with abstain as wrong.
        accs.append(keep.mean() * non_abs_acc)
    return np.array(accs).mean(), np.array(non_abs_accs).mean()

def get_std_and_min_max_acc_multi(base, y):
    """Summarise how much accuracy varies across models.

    ``base`` holds label predictions (one row per model) and ``y`` holds
    one-hot ground-truth labels, reduced here with ``argmax(axis=1)``.

    Returns a 4-tuple: the per-model accuracies, their standard deviation,
    the gap between the best and worst model, and the largest absolute
    distance of any model from the mean accuracy.
    """
    labels = y.argmax(axis=1)
    per_model_hits = [np.equal(base[k], labels) for k in range(base.shape[0])]
    accs = np.array([hits.mean() for hits in per_model_hits])

    spread = np.std(accs)
    best_minus_worst = np.max(accs) - np.min(accs)
    farthest_from_mean = np.abs(accs - np.mean(accs)).max()
    return accs, spread, best_minus_worst, farthest_from_mean

def get_epsilon_multi(array, model_idx, predict, conf_idx=0, conf_arr=None):
    """Measure how often pairs of models disagree on each point.

    ``array`` must already hold class labels (argmax taken), indexed as
    ``array[model, point]``. Every pair of distinct entries of ``model_idx``
    is compared; for each point the fraction of pairs that disagree is its
    "flip rate".

    Parameters
    ----------
    array : numpy.ndarray
        Label predictions, one row per model.
    model_idx : array-like of int
        Row indices of the models to compare.
    predict, conf_idx, conf_arr : unused
        Accepted for backward compatibility only; they do not affect the
        result.

    Returns
    -------
    tuple of 5 floats
        ``(epsilon, avg_eps, median_eps, percent_chance_flip,
        percent_chance_flip_05)``: the maximum, mean and median flip rate over
        points, and the fraction of points whose flip rate is above 0 and
        above 0.5 respectively.
    """
    members = np.asarray(model_idx).flatten()

    # Disagreement is symmetric, so each unordered pair is compared once. The
    # per-point averages are identical to averaging over both (i, j) and
    # (j, i), at half the work and memory.
    dont_agree = np.array([
        array[first] != array[second]
        for pos, first in enumerate(members)
        for second in members[pos + 1:]
        if first != second
    ])

    flip_rate = dont_agree.mean(axis=0)

    epsilon = np.max(flip_rate)
    avg_eps = flip_rate.mean()
    median_eps = np.median(flip_rate)
    percent_chance_flip = (flip_rate > 0).mean()
    percent_chance_flip_05 = (flip_rate > 0.5).mean()
    return (epsilon, avg_eps, median_eps, percent_chance_flip, percent_chance_flip_05)

def diff_pred_multi(yp_train_loo,  yp_test_loo,
                  base_train, base_test, conf_arr_test=None, conf_arr_train=None, train=True):
    """Percentage of points whose label flips, bucketed by confidence.

    All prediction arguments must already hold class labels (argmax taken).
    ``yp_*_loo[m]`` is compared against ``base_*`` for every model ``m`` (the
    number of models is ``yp_test_loo.shape[0]``); a point counts as flipped
    if any model disagrees with the base prediction. The confidence of a
    point is the largest value in its row of ``conf_arr_*`` (which falls back
    to ``base_*`` when omitted, and is reduced along axis 1, so it has to be
    2-D).

    Returns a length-10 array: entry ``k`` is the percentage of all points
    that flipped and have confidence strictly above ``0.1 * k``. With
    ``train=True`` train and test points are pooled; with ``train=False``
    only the test arguments are used.
    """
    if conf_arr_test is None:
        conf_arr_test = base_test
    if conf_arr_train is None:
        conf_arr_train = base_train

    model_count = yp_test_loo.shape[0]
    cutoffs = [0.1 * k for k in range(10)]

    def _flipped_above_cutoffs(loo_preds, reference, conf):
        # A point is flipped when at least one model disagrees with reference.
        flipped_anywhere = np.array(
            [loo_preds[m] != reference for m in range(model_count)]).max(axis=0)
        top_confidence = np.sort(conf, axis=1)[:, -1]
        flipped_confidence = top_confidence[flipped_anywhere]
        return np.array([(flipped_confidence > c).sum() for c in cutoffs])

    counts_train = (_flipped_above_cutoffs(yp_train_loo, base_train, conf_arr_train)
                    if train else None)
    counts_test = _flipped_above_cutoffs(yp_test_loo, base_test, conf_arr_test)

    if not train:
        return (counts_test / yp_test_loo.shape[1]) * 100

    total_points = yp_train_loo.shape[1] + yp_test_loo.shape[1]
    return ((counts_train + counts_test) / total_points) * 100
 

def script(exp_type, alpha, sample_models_strr, ds_name, mdl_file, labels_file, num_models=5, num=500):
    """Run the ensemble/abstention analysis for one dataset and write results.

    Loads the saved per-model test predictions, builds majority-vote ensembles
    of 5, 10, 15 and 20 models with ``make_agg_multi``, and writes accuracy,
    disagreement ("epsilon") and flip-by-confidence ("LUF") tables and plots.
    All outputs go under ``ensemble/`` except the accuracy plot, which is
    written to the current directory.

    Parameters
    ----------
    exp_type : str
        Passed to ``load_data`` to choose the file layout and embedded in the
        output file names (the calls in this file use "RS" and "LF").
    alpha : float
        Significance level for the abstention test.
    sample_models_strr : str
        Path of a ``.npy`` file holding the model indices used for the
        single-model baseline comparison.
    ds_name : str
        Dataset name; selects how labels are obtained and prefixes every
        output file.
    mdl_file : str
        Prediction file path or path template, passed to ``load_data``.
    labels_file : str
        Unused.
    num_models : int
        Number of ensembles to build for each ensemble size.
    num : int
        Number of base models; passed to ``load_data`` and used to build the
        ensemble row boundaries.

    Returns
    -------
    dict
        Paths of the written ``.npy`` files for the largest ensemble size (20)
        and of the CSV tables.
    """
    strr0 = ds_name
    alpha_tag = "alpha_{}".format(alpha)
    os.makedirs("ensemble", exist_ok=True)

    def _npy_path(stem, size):
        # stem carries two placeholders: ensemble size and alpha.
        return "ensemble/" + strr0 + stem.format(size, alpha) + exp_type + ".npy"

    def _csv_path(stem):
        return "ensemble/" + strr0 + stem + exp_type + alpha_tag + ".csv"

    sample_models = np.load(sample_models_strr)

    if ds_name == "FMNIST":
        # NOTE(review): FashionMnist is neither defined nor imported in this
        # file as shown; unless it is provided elsewhere this branch raises
        # NameError.
        FM_data = FashionMnist()
        y_test = FM_data.y_te_1hot
        del FM_data
    elif ds_name == "colon":
        import tensorflow_datasets as tfds
        from sklearn.model_selection import train_test_split

        image, label = tfds.as_numpy(tfds.load(
            'colorectal_histology',
            split='train',
            batch_size=-1,
            as_supervised=True,
        ))
        image = np.moveaxis(image, 3, 1)
        image = tf.keras.applications.resnet.preprocess_input(image)

        X_train, X_test, y_train, y_test = train_test_split(
            image,
            label, shuffle=True, random_state=0)
        del X_train, X_test, image, label
        y_test = tf.keras.utils.to_categorical(y_test, num_classes=8)
    else:
        (X_train, y_train), (X_test, y_test), (_, _) = get_data(ds_name)
        if len(y_test.shape) < 2:
            y_test = tf.keras.utils.to_categorical(y_test, num_classes=int(y_test.max()) + 1)

    # Load the predictions once: keep the raw scores for confidences and take
    # the argmax over axis 2 for the hard labels.
    base_test_seed_w_conf = load_data(mdl_file, exp_type=exp_type, num=num, strr0=strr0)
    base_test_seed = base_test_seed_w_conf.argmax(axis=2)

    avg_base_acc = get_avg_acc_multi(base_test_seed, y_test)
    _, accs_std_gc_rs, md_gc_rs, md_avg = get_std_and_min_max_acc_multi(base_test_seed, y_test)
    base_summary = [avg_base_acc, md_gc_rs, accs_std_gc_rs, md_avg]

    # The original appended to a module-level list that is not defined in this
    # file (NameError on every call). Create it on first use so the
    # accumulation still works and still honours a list defined elsewhere.
    globals().setdefault("avg_base_accs_min_max_std", []).append(list(base_summary))

    pd.DataFrame(data=base_summary, index=["Avg Base Acc", "Min Max", "Std", "Max Diff From Avg"],
                 columns=[strr0]).to_csv("ensemble/" + strr0 + "datasets_min_max_std_" + exp_type + ".csv")

    ensemble_sizes = range(5, 25, 5)
    avg_tes, pred_idxs, abs_percs, pred_intersections = [], [], [], []
    accs = []
    for i in ensemble_sizes:
        avg_te, pred_idx, abs_perc = make_agg_multi(
            base_test_seed, i, alpha,
            ranges=np.arange(0, num + 1, i), num_models=num_models)
        avg_tes.append(avg_te)
        pred_idxs.append(pred_idx)
        abs_percs.append(abs_perc)
        np.save(_npy_path("_avg_predictions_{}_ens_mdls_alpha_{}", i), avg_te)
        # 1s are predictions, 0s abstain
        np.save(_npy_path("_abstain_idxs_{}_ens_mdls_alpha_{}", i), pred_idx)
        np.save(_npy_path("_abstain_percentages_{}_ens_mdl_alpha_{}", i), abs_perc)
        avg_acc_abs, avg_acc_on_nonabs = get_avg_acc_multi(avg_te, y_test, pred_idx=pred_idx)
        accs.append([avg_acc_abs, get_avg_acc_multi(avg_te, y_test),
                     avg_acc_on_nonabs, np.array(abs_perc).mean()])
        # Points on which every ensemble of this size makes a prediction.
        pred_intersections.append(np.all(pred_idx, axis=0))
    largest_size = ensemble_sizes[-1]

    accs = np.array(accs)
    size_axis = np.arange(5, 25, 5)
    plt.plot(size_axis, 1 - accs[:, 0], label="Err Abst as 0")
    plt.plot(size_axis, 1 - accs[:, 1], label="Err No Abst")
    plt.plot(size_axis, 1 - accs[:, 2], label="Err Only on Pred")
    plt.plot(size_axis, accs[:, 3], label="Abst Perc")
    plt.plot(size_axis, [1 - avg_base_acc] * len(ensemble_sizes), label="Oneshot Avg Err")
    plt.xlabel("Num Models Agg")
    plt.ylabel("Error/Abstain Perc")
    plt.legend()
    plt.savefig(strr0 + 'Accs_alpha{}.png'.format(alpha))
    plt.close()

    # Single-model baseline (plotted at x=0): disagreement among the sampled
    # base models, and flips relative to reference model 42.
    ref_labels = base_test_seed[42]
    ref_conf = base_test_seed_w_conf[42]
    baseline_eps = get_epsilon_multi(base_test_seed, sample_models,
                                     np.ones_like(base_test_seed), conf_arr=ref_conf)
    # base_test must be hard labels; the "no abstain" baseline previously
    # passed the 2-D score array here, unlike the otherwise identical
    # "abstain" baseline.
    baseline_luf = diff_pred_multi([], base_test_seed[sample_models], [], ref_labels,
                                   conf_arr_test=ref_conf, train=False)

    get_eps_no_abst, get_eps = [baseline_eps], [baseline_eps]
    lufs_no_abst, lufs = [baseline_luf], [baseline_luf]

    for (arr, pred_idx, pred_intersect) in zip(avg_tes, pred_idxs, pred_intersections):
        if pred_intersect.any():
            agreed = arr[:, pred_intersect]
            get_eps.append(get_epsilon_multi(agreed, np.arange(num_models), pred_idx,
                                             conf_arr=ref_labels[pred_intersect]))
            # diff_pred_multi iterates models along axis 0, so pass the
            # (model, point) layout directly; the earlier np.moveaxis swapped
            # the two axes.
            lufs.append(diff_pred_multi([], agreed, [], arr[0][pred_intersect],
                                        conf_arr_test=ref_conf[pred_intersect], train=False))
        else:
            get_eps.append((0, 0, 0, 0, 0))
            lufs.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
        get_eps_no_abst.append(get_epsilon_multi(arr, np.arange(num_models), pred_idx,
                                                 conf_arr=ref_labels))
        # Use this ensemble size's predictions (arr); the stale loop variable
        # avg_te always referred to the largest ensemble size.
        lufs_no_abst.append(diff_pred_multi([], arr, [], arr[0],
                                            conf_arr_test=ref_conf, train=False))
    get_eps = np.array(get_eps)
    lufs = np.array(lufs)
    get_eps_no_abst = np.array(get_eps_no_abst)
    lufs_no_abst = np.array(lufs_no_abst)

    eps_axis = np.arange(0, 25, 5)

    def _plot_eps_table(eps_table, suffix):
        # One figure for the flip-chance columns, one for the max epsilon.
        plt.plot(eps_axis, eps_table[:, 3], label="% with Avg Flip Chnge > 0")
        plt.plot(eps_axis, eps_table[:, 4], label="Avg Flip Chnc >0.5")
        plt.xlabel("Num Models Agg")
        plt.ylabel("Avg Flip Chance")
        plt.legend()
        plt.savefig('ensemble/' + strr0 + 'avg_flip_chance_' + exp_type + alpha_tag + suffix)
        plt.close()

        plt.plot(eps_axis, eps_table[:, 0], label="Eps")
        plt.xlabel("Num Models Agg")
        plt.ylabel("Epsilon")
        plt.legend()
        plt.savefig('ensemble/' + strr0 + 'Epsilon_' + exp_type + alpha_tag + suffix)
        plt.close()

    _plot_eps_table(get_eps, '_abstain.png')
    _plot_eps_table(get_eps_no_abst, '_no_abstain.png')

    eps_columns = ["epsilon", "avg_eps", "median_eps", "percent_chance_flip",
                   "percent_chance_flip_05"]
    luf_columns = ["total LUF", "over 0.1", "over 0.2", "over 0.3",
                   "over 0.4", "over 0.5", "over 0.6", "over 0.7", "over 0.8", "over 0.9"]

    pd.DataFrame(data=accs, columns=["Avg Acc with Abstention as Wrong",
                                     "Avg Acc No Abstention",
                                     "Avg Acc on Predicted Pts", "Abstain Percent"]).to_csv(
        _csv_path("_Accs_"))
    pd.DataFrame(data=get_eps.squeeze(), columns=eps_columns).to_csv(_csv_path("_Avg_Epsilons_"))
    pd.DataFrame(data=lufs, columns=luf_columns).to_csv(_csv_path("_LUFs_"))
    pd.DataFrame(data=get_eps_no_abst.squeeze(), columns=eps_columns).to_csv(
        _csv_path("_Avg_Epsilons_no_abst"))
    pd.DataFrame(data=lufs_no_abst, columns=luf_columns).to_csv(_csv_path("_LUFs_no_abst"))

    return {"model_preds": _npy_path("_avg_predictions_{}_ens_mdls_alpha_{}", largest_size),
            "abs_or_pred_true_false": _npy_path("_abstain_idxs_{}_ens_mdls_alpha_{}", largest_size),
            "abstain_percents": _npy_path("_abstain_percentages_{}_ens_mdl_alpha_{}", largest_size),
            "accs": _csv_path("_Accs_"),
            "get_eps_output": _csv_path("_Avg_Epsilons_"),
            "lufs_output": _csv_path("_LUFs_"),
            "get_eps_no_abst_output": _csv_path("_Avg_Epsilons_no_abst"),
            "lufs_no_abst_output": _csv_path("_LUFs_no_abst"),
            }


# Run the analysis for every dataset, for both experiment types ("RS" and
# "LF") and both significance levels. Several of the original calls spelled
# the data directory "..data/" (missing slash); all inputs now use "../data/".
#
# NOTE(review): the warafin "RS" file name has no "{}" placeholder, so
# load_data's per-index formatting loads the same file for every index -
# confirm that this is intended.
# NOTE(review): with num=200 the ensemble boundaries stop at row 200, which
# is fewer than the 24 * size rows needed for warafin at ensemble sizes 10 and
# above - confirm the intended num for that dataset.
_EXPERIMENTS = [
    # (dataset, sampled-model index file, "RS" predictions, "LF" predictions, num_models)
    ("FMNIST", "ensemble/sample_models_FMNIST.npy",
     '../data/fmnist_base_test_sgd_batch64_drop_seed{}.npy',
     '../data/fmnist_yp_test_loo_batch64.npy', 10),
    ("warafin", "ensemble/sample_models.npy",
     '../data/war_yp_test_same_seed.npy',
     '../data/war_loo_test_same_seed.npy', 24),
    ("colon", "ensemble/sample_models_FMNIST.npy",
     '../data/colon_res/colon_yp_test_same_seed_seed{}.npy',
     '../data/colon_res/colon_loo_test_same_seed.npy', 10),
]

for _ds_name, _sample_file, _rs_file, _lf_file, _n_models in _EXPERIMENTS:
    for _exp_type, _pred_file in (("RS", _rs_file), ("LF", _lf_file)):
        for _alpha in (0.01, 0.05):
            script(_exp_type, _alpha, _sample_file, _ds_name,
                   _pred_file, '', num_models=_n_models, num=200)

