import os
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"]="2"

import tensorflow as tf
# Restrict TensorFlow to a single physical GPU (index `gpu`) and enable
# memory growth on the GPU(s) that remain visible.
# NOTE(review): gpus[gpu] raises IndexError on machines with fewer than
# gpu + 1 GPUs.  TensorFlow presumably only accepts these settings before its
# GPUs are initialised, which would explain why this runs at import time —
# confirm before moving it.
gpu=1  # index into the list of physical GPUs
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpus[gpu], 'GPU')
device = gpus[gpu]  # rebound immediately by the loop variable below

for device in tf.config.experimental.get_visible_devices('GPU'):
    tf.config.experimental.set_memory_growth(device, True)


# Silence TensorFlow's logger and every Python warning for the whole process,
# so the experiment output is not flooded with library warnings.
import logging
logging.getLogger('tensorflow').disabled = True

import warnings
warnings.filterwarnings('ignore')  # global: no category given, so all warnings are ignored


import sys

import numpy as np
import random as rn
import pandas as pd

# NOTE(review): NumPy's global RNG is not seeded (the call below is commented
# out), so np.random.permutation results — e.g. a regenerated seed-permutation
# file — differ between runs.
#np.random.seed(42)

# Seed Python's `random` module and TensorFlow's global RNG so both start from
# a well-defined state.

rn.seed(12345)

tf.random.set_seed(1234)

from trulens.nn.models import get_model_wrapper as ModelWrapper

from trulens.nn.attribution import InternalInfluence

from scipy.stats import mode

from scipy import stats
from skimage.metrics import structural_similarity as ssim
import numpy as np
from get_data import get_data
from model_strings import GermanCredit_str, Adult_str, Seizure_str, Taiwanese_str, Warafin_str
from model_strings import HELOC_str, CTG_str, Thyroid_str, Colon_str
from datalib import FashionMnist
def top_k_intersection(arr1,
                       arr2,
                       k,
                       axis=-1,
                       return_ranks=False,
                       return_values=False):
    """Return the fraction of top-``k`` indices shared by ``arr1`` and ``arr2``.

    Entries are ranked along ``axis``.  For every 1-D slice along that axis the
    indices of the ``k`` largest entries of each array are selected (ties go to
    the lower index, the same rule as ``tf.math.top_k``) and the size of the
    intersection of the two index sets is divided by ``k``.

    Args:
        arr1, arr2: Array-likes whose shapes agree on every axis except ``axis``.
        k: Number of top entries compared; must satisfy
            ``1 <= k <= arr.shape[axis]`` for both arrays.
        axis: Axis along which entries are ranked (default: last axis).
        return_ranks, return_values: Ignored; kept so existing call sites that
            pass them keep working.

    Returns:
        A Python float in ``[0, 1]`` for 1-D inputs; otherwise a float array
        shaped like the inputs with ``axis`` removed (one value per row for a
        2-D ``(n_points, n_features)`` input).

    Raises:
        ValueError: if an input is 0-D, ``k`` is out of range, or the
            non-ranked axes of the two inputs differ.
    """
    a1 = np.asarray(arr1, dtype=float)
    a2 = np.asarray(arr2, dtype=float)
    if a1.ndim == 0 or a2.ndim == 0:
        raise ValueError("top_k_intersection needs inputs with at least one axis")
    a1 = np.moveaxis(a1, axis, -1)
    a2 = np.moveaxis(a2, axis, -1)
    if a1.shape[:-1] != a2.shape[:-1]:
        raise ValueError(
            "non-ranked axes differ: {} vs {}".format(a1.shape[:-1], a2.shape[:-1]))
    max_k = min(a1.shape[-1], a2.shape[-1])
    if not 1 <= k <= max_k:
        raise ValueError("k={} is outside [1, {}]".format(k, max_k))

    # Flatten all leading axes so that every row is one slice to compare.
    rows1 = a1.reshape(-1, a1.shape[-1])
    rows2 = a2.reshape(-1, a2.shape[-1])
    # A stable ascending sort of the negated scores is a descending sort in
    # which ties keep the lower index first.
    top1 = np.argsort(-rows1, axis=-1, kind="stable")[:, :k]
    top2 = np.argsort(-rows2, axis=-1, kind="stable")[:, :k]

    # Mark the selected indices in boolean masks and count the overlap per row.
    width = max(rows1.shape[1], rows2.shape[1])
    hit1 = np.zeros((rows1.shape[0], width), dtype=bool)
    hit2 = np.zeros_like(hit1)
    np.put_along_axis(hit1, top1, True, axis=-1)
    np.put_along_axis(hit2, top2, True, axis=-1)
    shared = np.count_nonzero(hit1 & hit2, axis=-1)

    fractions = (shared / k).reshape(a1.shape[:-1])
    if a1.ndim == 1:
        return float(fractions)
    return fractions

def mse(A, B, ax):
    """Mean squared difference between ``A`` and ``B``.

    ``ax`` is forwarded to ``.mean`` as ``axis``; pass ``None`` to average over
    every element.
    """
    residual = A - B
    squared = residual ** 2
    return squared.mean(axis=ax)


def get_metrics(expls1, expls2, k, SSIM_plz=False, c=False):
    """Compare two aligned batches of explanations pair by pair.

    ``expls1[i]`` is compared with ``expls2[i]`` for every ``i``.  Correlation,
    MSE, top-k and L2 metrics are computed on the flattened explanations, so
    1-D feature vectors and 2-D images are both supported.

    Args:
        expls1, expls2: Batches of explanations with the same leading length.
        k: Size of the top-k sets used for the top-k intersection metric.
        SSIM_plz: If True, each batch is first reshaped to
            ``(batch, shape[1], -1)`` and squeezed, and the first returned row
            is SSIM (on the un-flattened explanations) instead of Spearman.
        c: Only used with ``SSIM_plz``; additionally averages the reshaped
            explanations over axis 1.

    Returns:
        ``np.ndarray`` of shape ``(5, batch)`` whose rows are
        [Spearman (or SSIM), Pearson, MSE, top-k intersection, L2 distance].
    """
    if SSIM_plz:
        expls1 = np.reshape(expls1, (expls1.shape[0], expls1.shape[1], -1)).squeeze()
        expls2 = np.reshape(expls2, (expls2.shape[0], expls2.shape[1], -1)).squeeze()
        if c:
            expls1 = expls1.mean(axis=1)
            expls2 = expls2.mean(axis=1)

    first, pear, mses, topk, l2 = [], [], [], [], []
    for e1, e2 in zip(expls1, expls2):
        # ravel is a no-op for 1-D explanations and lets 2-D images go through
        # the vector metrics (pearsonr/spearmanr expect 1-D input).
        v1 = np.ravel(e1)
        v2 = np.ravel(e2)
        if SSIM_plz:
            # SSIM keeps the spatial layout; the data range spans both images.
            data_range = max(e1.max(), e2.max()) - min(e1.min(), e2.min())
            first.append(ssim(e1, e2, data_range=data_range))
        else:
            first.append(stats.spearmanr(v1, v2)[0])
        pear.append(stats.pearsonr(v1, v2)[0])
        mses.append(mse(v1, v2, ax=None))
        # Pass a one-row batch so the result is a single overlap fraction
        # rather than a per-rank array.
        topk.append(top_k_intersection(v1[None, :], v2[None, :], k)[0])
        l2.append(np.linalg.norm(v1 - v2))

    return np.array([first, pear, mses, topk, l2])




def get_attrs2(mdl_string, inputs, sd_start=0, seed_stop=15,
    seed_nrs=None, qoi=0, from_script=False):
    """Average sign-aligned attributions over an ensemble of binary classifiers.

    For every seed a Keras model is loaded, its probabilities and 0/1
    predictions (threshold 0.5) are recorded, and InternalInfluence
    attributions (``cuts=(0, -1)``, point DoI, quantity of interest ``qoi``)
    are computed.  Each model's attributions are kept as-is for inputs whose
    majority-vote label across the ensemble is 1 and negated where it is 0,
    then averaged over models.

    Args:
        mdl_string: Model path.  With ``from_script`` it is a template with one
            ``{}`` for the seed; otherwise ``"<seed>.h5"`` is appended.  If
            loading fails, ``mdl_string + "<seed - 100>_ex.h5"`` is tried.
        inputs: Batch passed to ``predict`` and to the explainer.
        sd_start, seed_stop: Seed range used when ``seed_nrs`` is None.
        seed_nrs: Explicit iterable of seeds.
        qoi: Quantity of interest for InternalInfluence (default 0).
        from_script: If True the model output is passed through a sigmoid and
            the whole model is explained; otherwise the output is used as the
            probability directly and the model's last layer is dropped for
            the explanation.

    Returns:
        ``(attrs, preds_prob)``: model-averaged attributions, and the list of
        per-model probability arrays (squeezed ``predict`` output).
    """
    path_template = mdl_string if from_script else mdl_string + "{}.h5"
    if seed_nrs is None:
        seed_nrs = np.arange(sd_start, seed_stop)

    attrs_per_model = []
    labels_per_model = []
    preds_prob = []
    for seed in seed_nrs:
        try:
            mod = tf.keras.models.load_model(path_template.format(seed))
        except (OSError, ValueError):
            # Fallback naming for the extra models: "<prefix><seed-100>_ex.h5".
            mod = tf.keras.models.load_model(mdl_string + "{}_ex.h5".format(seed - 100))

        # Predict once per model; probabilities and labels share it.
        outputs = mod.predict(inputs)
        if from_script:
            prob = tf.nn.sigmoid(outputs).numpy().squeeze()
            mod_tru = ModelWrapper(mod)
        else:
            prob = outputs.squeeze()
            mod_tru = ModelWrapper(tf.keras.models.Model(mod.input, mod.layers[-2].output))
        preds_prob.append(prob)
        # reshape(-1) keeps a 1-D label vector even for a single input.
        labels_per_model.append((np.reshape(prob, -1) > 0.5).astype('int'))

        explainer = InternalInfluence(mod_tru, cuts=(0, -1), qoi=qoi, doi='point',
                                      multiply_activation=False)
        attrs_per_model.append(explainer.attributions(inputs))

    # Majority vote over models.  Flatten the mode result: SciPy < 1.11 returns
    # shape (1, n_inputs) and newer versions return (n_inputs,).
    majority = np.asarray(mode(np.array(labels_per_model), axis=0)[0]).reshape(-1)
    attrs = np.array(attrs_per_model)  # (n_models, n_inputs, ...)
    keep_sign = majority.reshape((1, -1) + (1,) * (attrs.ndim - 2))
    attrs = np.where(keep_sign, attrs, -attrs)
    return attrs.mean(axis=0), preds_prob

def get_attrs_oneshot(filepath, seed_nr, data, qoi="max", mul_act=False, from_script=False):
    """Sign-aligned InternalInfluence attributions of a single binary classifier.

    Attributions are kept as-is for inputs the model predicts as class 1
    (probability > 0.5) and negated for inputs predicted as class 0.

    Args:
        filepath: Model path.  With ``from_script`` it is a template with one
            ``{}`` for the seed; otherwise ``"<seed_nr>.h5"`` is appended.  If
            loading fails, ``filepath + "<seed_nr - 100>_ex.h5"`` is tried.
        seed_nr: Seed identifying the model file.
        data: Batch to predict on and explain.
        qoi: Quantity of interest for InternalInfluence.
        mul_act: Forwarded as ``multiply_activation``.
        from_script: If True the model output is passed through a sigmoid and
            the whole model is explained; otherwise the output is used as the
            probability directly and the last layer is dropped for the
            explanation.

    Returns:
        ``np.ndarray`` of sign-aligned attributions, one entry per input.
    """
    if from_script:
        model_path = filepath.format(seed_nr)
    else:
        model_path = filepath + "{}.h5".format(seed_nr)
    try:
        mod = tf.keras.models.load_model(model_path)
    except (OSError, ValueError):
        # Only the suffix is formatted, as in get_attrs2, so a template that
        # already contains "{}" (from_script) no longer raises IndexError.
        mod = tf.keras.models.load_model(filepath + "{}_ex.h5".format(seed_nr - 100))

    if from_script:
        mod_tru = ModelWrapper(mod)
        probs = tf.nn.sigmoid(mod.predict(data)).numpy()
    else:
        mod_tru = ModelWrapper(tf.keras.models.Model(mod.input, mod.layers[-2].output))
        probs = mod.predict(data)
    # reshape(-1) instead of squeeze() keeps a 1-D label vector for one input.
    preds = (np.reshape(probs, -1) > 0.5).astype('int')

    exp = InternalInfluence(mod_tru, cuts=(0, -1), qoi=qoi, doi='point', multiply_activation=mul_act)
    attrs_0 = np.array(exp.attributions(data))
    # One sign per input, broadcast over every remaining attribution axis.
    keep_sign = preds.reshape((-1,) + (1,) * (attrs_0.ndim - 1))
    return np.where(keep_sign, attrs_0, -attrs_0)


def get_attrs_multi(mdl_string, inputs, n_classes, sd_start=0, seed_stop=15, seed_nrs=None, 
    mult_act=False, avg=False, y_test=None, SSIM_plz=False, from_script=False, batch=125, n_batch=10):
    """Ensemble-averaged attributions for each input's majority-vote class.

    For every seed the model ``mdl_string.format(seed)`` is loaded, its argmax
    predictions are recorded, and InternalInfluence attributions are computed
    for every class ``0..n_classes-1``.  Per input, only the attributions of
    the class predicted most often across the ensemble are kept; they are then
    averaged over models.

    Args:
        mdl_string: Model path template with one ``{}`` for the seed.
        inputs: Batch to predict on and explain.
        n_classes: Number of classes (length of the class axis).
        sd_start, seed_stop: Seed range used when ``seed_nrs`` is None.
        seed_nrs: Explicit iterable of seeds.
        mult_act: Forwarded as ``multiply_activation``.
        SSIM_plz: If True, attributions are computed in ``n_batch`` chunks of
            ``batch`` inputs and concatenated.
        avg, y_test, from_script: Unused; kept for call-site compatibility.

    Returns:
        ``(attrs, preds_mode)``: attributions with the class axis removed and
        averaged over models, and the one-hot majority vote with shape
        ``(n_classes, n_inputs, 1)``.
    """
    if seed_nrs is None:
        seed_nrs = np.arange(sd_start, seed_stop)

    attrs = []
    preds = []
    for seed in seed_nrs:
        mod = tf.keras.models.load_model(mdl_string.format(seed))
        # argmax is unchanged by a softmax, so raw outputs suffice here.
        preds.append(mod.predict(inputs).argmax(axis=1))
        wrapped = ModelWrapper(mod)
        per_class = []
        for cls in range(n_classes):
            # NOTE(review): this cuts the full model at -2, whereas
            # get_attrs_oneshot_multi cuts at -1 (on a model without its last
            # layer unless from_script); confirm for from_script models.
            exp = InternalInfluence(wrapped,
                cuts=(0, -2), qoi=cls, doi='point', multiply_activation=mult_act)
            if SSIM_plz:
                per_class.append(np.concatenate(
                    [exp.attributions(inputs[batch * b:batch * (b + 1)]) for b in range(n_batch)],
                    axis=0))
            else:
                per_class.append(exp.attributions(inputs))
        attrs.append(per_class)
    attrs = np.array(attrs)  # (n_models, n_classes, n_inputs, ...)

    # Majority vote per input.  Flatten the mode result (SciPy < 1.11 returns
    # (1, n_inputs), newer versions (n_inputs,)), and pin num_classes so the
    # one-hot class axis matches attrs even if some class is never predicted.
    majority = np.asarray(mode(np.array(preds), axis=0)[0]).reshape(-1)
    one_hot = tf.keras.utils.to_categorical(majority, num_classes=n_classes).T  # (n_classes, n_inputs)
    # Zero every class except the voted one, sum the class axis away, then
    # average over the models.
    weights = one_hot.reshape((1,) + one_hot.shape + (1,) * (attrs.ndim - 3))
    attrs_only_predicted_class = (attrs * weights).sum(axis=1).mean(axis=0)

    return attrs_only_predicted_class, one_hot[:, :, None]


def get_attrs_oneshot_multi(filepath, data, seed_nr, n_classes=10, mul_act=False,
                            filepath2=None, y_test=None, SSIM_plz=False, from_script=False,
                            batch=125, n_batch=10):
    """Attributions of one multi-class model for each input's predicted class.

    Loads ``filepath.format(seed_nr)``, prints its summary, computes
    InternalInfluence attributions (``cuts=(0, -1)``) for every class, and
    keeps, per input, only those of the model's argmax class.

    Args:
        filepath: Model path template with one ``{}`` for the seed.
        data: Batch to predict on and explain.
        seed_nr: Seed identifying the model file.
        n_classes: Number of classes (length of the class axis).
        mul_act: Forwarded as ``multiply_activation``.
        SSIM_plz: If True, attributions are computed in ``n_batch`` chunks of
            ``batch`` inputs and concatenated.
        from_script: If True the whole model is explained; otherwise its last
            layer is dropped first.
        filepath2, y_test: Unused; kept for call-site compatibility.

    Returns:
        ``np.ndarray`` of attributions for the predicted class of each input.
    """
    mod = tf.keras.models.load_model(filepath.format(seed_nr))
    mod.summary()
    # num_classes pins the one-hot class axis to n_classes even when some
    # class is never predicted, so it always broadcasts against attrs.
    preds = tf.keras.utils.to_categorical(
        mod.predict(data).argmax(axis=1), num_classes=n_classes).T  # (n_classes, n_inputs)
    if from_script:
        # Scripted models have no final softmax layer, so explain them whole.
        wrapped = ModelWrapper(mod)
    else:
        # Drop the final layer (presumably the softmax) before explaining.
        wrapped = ModelWrapper(tf.keras.models.Model(mod.input, mod.layers[-2].output))

    attrs = []
    for cls in range(n_classes):
        exp = InternalInfluence(wrapped, cuts=(0, -1), qoi=cls, doi='point', multiply_activation=mul_act)
        if SSIM_plz:
            attrs.append(np.concatenate(
                [exp.attributions(data[batch * b:batch * (b + 1)]) for b in range(n_batch)], axis=0))
        else:
            attrs.append(exp.attributions(data))
    attrs = np.array(attrs)  # (n_classes, n_inputs, ...)

    # Zero every non-predicted class, then sum the class axis away.
    weights = preds.reshape(preds.shape + (1,) * (attrs.ndim - 2))
    return (attrs * weights).sum(axis=0)

def func(k, attrs, filepath=None, mul_act=False, num_mods=24, 
         SSIM_plz=False, multi=False, y_preds_str=None):
    """Pairwise per-input disagreement metrics between ensembles' attributions.

    For every pair ``(i, j)`` with ``j < i`` among the first
    ``min(num_mods, len(attrs))`` entries of ``attrs``, computes per input:
    Spearman and Pearson correlation (pandas ``corrwith`` along rows), MSE,
    top-``k`` intersection and L2 distance.  Each input's attribution is
    flattened first, so image-shaped attributions are accepted too.

    Args:
        k: Size of the top-k sets.
        attrs: Sequence of per-ensemble attribution arrays, each with inputs
            along axis 0.
        num_mods: Maximum number of ensembles compared.
        filepath, mul_act, SSIM_plz, multi, y_preds_str: Unused; kept for
            call-site compatibility.  NOTE(review): SSIM is not computed here,
            so row 0 is always Spearman.

    Returns:
        ``np.ndarray`` of shape ``(n_pairs, 5, n_inputs)``.
    """
    # Bounding by len(attrs) fixes the former hard-coded range(24), which
    # raised IndexError when fewer ensembles were passed (e.g. num_mods=5 runs).
    n_models = min(num_mods, len(attrs))
    flat = [np.reshape(attrs[m], (len(attrs[m]), -1)) for m in range(n_models)]

    spear, pear, mses, topk, l2 = [], [], [], [], []
    for i in range(n_models):
        df_i = pd.DataFrame(flat[i])
        for j in range(i):
            df_j = pd.DataFrame(flat[j])
            spear.append(df_i.corrwith(df_j, axis=1, method='spearman'))
            pear.append(df_i.corrwith(df_j, axis=1))
            mses.append(mse(flat[i], flat[j], ax=1))
            topk.append(top_k_intersection(flat[i], flat[j], k))
            l2.append(np.linalg.norm(flat[i] - flat[j], axis=1))

    return np.moveaxis(np.array([spear, pear, mses, topk, l2]), 0, 1)

def _load_test_split(dataset_string, SSIM_plz):
    """Return ``(X_test, y_test)`` for ``dataset_string``.

    With ``SSIM_plz``, a failure of ``get_data`` falls back to the
    ``FashionMnist`` test split (one-hot labels).
    """
    if not SSIM_plz:
        _, (X_test, y_test), _ = get_data(dataset_string)
        return X_test, y_test
    try:
        _, (X_test, y_test), _ = get_data(dataset_string)
    except Exception:
        FM_data = FashionMnist()
        y_test = FM_data.y_te_1hot
        X_test = FM_data.x_te
    return X_test, y_test


def _load_or_create_permutation(dataset_string, n):
    """Return the cached permutation of ``range(n)`` used to draw model seeds.

    Stored in ``ensemble/permutation_for_dataset_<dataset>.npy``; regenerated
    and re-saved when the file is missing or unreadable, or does not hold
    exactly ``n`` entries.
    """
    path = "ensemble/permutation_for_dataset_" + dataset_string + ".npy"
    try:
        perm = np.load(path)
    except (OSError, ValueError):
        perm = None
    # Require exactly n entries: a file saved for another n (larger or
    # smaller) is stale.  The former `perm.max() > n` test missed both an
    # (n + 1)-permutation and any shorter one.
    if perm is None or perm.shape != (n,):
        perm = np.random.permutation(n)
        np.save(path, perm)
    return perm


def _ensemble_attributions(ensemble_nr, mdl_string, inputs, n_classes, perm, num_mods,
                           multi, SSIM_plz, from_script, batch, n_batch):
    """Attributions for ``num_mods`` disjoint ensembles of ``ensemble_nr`` seeds.

    Returns ``(all_attrs, all_attrs_scaled, all_preds_prob)``; each ensemble's
    scaled attributions are divided by their maximum, and ``all_preds_prob``
    is only filled when ``ensemble_nr > 1``.
    """
    needed = ensemble_nr * num_mods
    if len(perm) < needed:
        raise ValueError(
            "need {} seeds for {} ensembles of {}, but the permutation has only {}".format(
                needed, num_mods, ensemble_nr, len(perm)))

    all_attrs, all_attrs_scaled, all_preds_prob = [], [], []
    for sed_nrs in np.split(perm[0:needed], num_mods):
        if ensemble_nr > 1:
            if multi:
                attrs, preds_prob = get_attrs_multi(
                    mdl_string, inputs, n_classes, seed_nrs=sed_nrs, SSIM_plz=SSIM_plz,
                    from_script=from_script, batch=batch, n_batch=n_batch)
            else:
                attrs, preds_prob = get_attrs2(
                    mdl_string, inputs, seed_nrs=sed_nrs, qoi=0, from_script=from_script)
            all_preds_prob.append(preds_prob)
        elif multi:
            attrs = get_attrs_oneshot_multi(
                mdl_string, inputs, sed_nrs[0], n_classes=n_classes, SSIM_plz=SSIM_plz,
                from_script=from_script, batch=batch, n_batch=n_batch)
        else:
            attrs = get_attrs_oneshot(mdl_string, sed_nrs[0], inputs, qoi=0, from_script=from_script)
        all_attrs.append(attrs)
        # NOTE(review): dividing by attrs.max() flips signs if the maximum is
        # negative; np.abs(attrs).max() may have been intended.
        all_attrs_scaled.append(attrs / attrs.max())
    return np.array(all_attrs), np.array(all_attrs_scaled), all_preds_prob


def _scaled_metrics(all_attrs_scaled, SSIM_plz):
    """Pairwise top-5 metrics between scaled ensembles, with NaNs set to 0."""
    metrics = func(5, all_attrs_scaled, SSIM_plz=SSIM_plz)
    # Undefined entries (e.g. correlations of constant rows) count as 0.
    metrics[np.isnan(metrics)] = 0
    return metrics


def _load_saved(path, legacy_path):
    """Load ``path``; fall back to ``legacy_path`` (older name without exp_type)."""
    try:
        return np.load(path)
    except FileNotFoundError:
        return np.load(legacy_path)


def _plot_disagreement(all_attrs, all_attrs_scaled, n_feats, dataset_string, exp_type,
                       ensemble_nr, slicee=102,
                       random_sel=(4, 20, 11, 17, 22, 18, 17, 1, 12, 6)):
    """Save bar charts of input ``slicee``'s attributions across nine ensembles.

    One chart uses the scaled and one the unscaled attributions; features are
    ordered by the first ensemble's unscaled attribution for that input.  Only
    the first nine entries of ``random_sel`` are plotted.
    """
    import matplotlib.pyplot as plt

    x = np.arange(n_feats)
    width = 0.25  # the width of the bars
    offsets = (-8, -7, -6, -5, -4, -3, -2, -1, 1)  # bar shifts, in units of width / 9
    all_attrs = np.asarray(all_attrs)
    order_1s = np.argsort(all_attrs[0][slicee])

    for array, sc in zip([all_attrs_scaled, all_attrs], ["scaled", "unsc"]):
        array = np.asarray(array)
        fig, ax = plt.subplots(figsize=(30, 15))
        for label, (ens, off) in enumerate(zip(random_sel, offsets), start=1):
            ax.bar(x + off * width / 9, array[ens][slicee][order_1s], width, label=str(label))
        ax.set_ylabel('Inf')
        ax.set_title('Feats')
        ax.legend()
        plt.savefig('ensemble/attrs/' + dataset_string + "_" + exp_type
                    + 'disagreement_of_attrs_example_pt{}_num_ens_{}'.format(slicee, ensemble_nr)
                    + exp_type + "_" + sc + '.png')
        plt.close(fig)


def _summarize(attr_metrics_scaled, prefix, ensemble_nr, SSIM_plz):
    """Result dict: each metric's mean over pairs and inputs, plus file names."""
    means = [row.mean() for row in attr_metrics_scaled.mean(axis=0)]
    return {
        ("SSIM" if SSIM_plz else "spear_rk_sc"): means[0],
        "pear_rk_sc": means[1],
        "mses_sc": means[2],
        "topk_int_sc": means[3],
        "l2_dist_sc": means[4],
        # Unscaled metrics are not computed; placeholders keep the schema.
        "spear_rk": 0.0,
        "mses": 0.0,
        "l2_dist": 0.0,
        "result_filename": prefix + "metrics_for_ensembles_of_{}.npy".format(ensemble_nr),
        "attr_files": prefix + "attrs_ensembles_of_{}.npy".format(ensemble_nr),
        "result_filename_scaled": prefix + "metrics_for_ensembles_of_{}_scaled.npy".format(ensemble_nr),
        # Key spelling kept as-is for consumers of this dict.
        "attr_files_sclaed": prefix + "attrs_ensembles_of_{}_scaled.npy".format(ensemble_nr),
    }


def script(ensemble_nr, mdl_string, exp_type, dataset_string, from_script=False,
    multi=False, num_mods=24, SSIM_plz=False, n=500, load_attrs=False, batch=125, n_batch=10):
    """Run one ensemble-disagreement experiment and return summary metrics.

    ``num_mods`` disjoint ensembles of ``ensemble_nr`` models are drawn from a
    per-dataset seed permutation.  Their attributions on the test split are
    computed (or loaded when ``load_attrs``) and saved under
    ``ensemble/attrs/``, then compared pairwise with ``func`` after each
    ensemble's attributions are divided by their maximum.  For non-image data
    a bar chart of one input's attributions across ensembles is also saved.

    Args:
        ensemble_nr: Models per ensemble (1 = single models).
        mdl_string: Model path prefix/template for the get_attrs* helpers.
        exp_type: Tag inserted into every output file name (e.g. "RS", "LF").
        dataset_string: Dataset name for ``get_data`` and the file names.
        from_script, multi, SSIM_plz, batch, n_batch: Forwarded to the
            attribution helpers; ``multi`` selects the multi-class helpers and
            ``SSIM_plz`` the image path.
        num_mods: Number of ensembles compared.
        n: Size of the seed permutation; must be >= ``ensemble_nr * num_mods``.
        load_attrs: Reuse saved attributions instead of recomputing them.

    Returns:
        Dict of mean scaled metrics and the related output file names.
    """
    X_test, y_test = _load_test_split(dataset_string, SSIM_plz)
    n_feats = X_test.shape[1]
    if SSIM_plz:
        n_classes = int(y_test.argmax(axis=1).max() + 1)  # one-hot labels
    else:
        n_classes = int(y_test.max() + 1)

    os.makedirs("ensemble/attrs", exist_ok=True)
    prefix = "ensemble/attrs/" + dataset_string + "_" + exp_type

    if not load_attrs:
        perm = _load_or_create_permutation(dataset_string, n)
        all_attrs, all_attrs_scaled, all_preds_prob = _ensemble_attributions(
            ensemble_nr, mdl_string, X_test, n_classes, perm, num_mods,
            multi, SSIM_plz, from_script, batch, n_batch)
        np.save(prefix + "attrs_ensembles_of_{}_scaled.npy".format(ensemble_nr), all_attrs_scaled)
        np.save(prefix + "attrs_ensembles_of_{}.npy".format(ensemble_nr), all_attrs)
        attr_metrics_scaled = _scaled_metrics(all_attrs_scaled, SSIM_plz)
        np.save(prefix + "metrics_for_ensembles_of_{}_scaled.npy".format(ensemble_nr), attr_metrics_scaled)
        np.save(prefix + "pred_probs_ensembles_of_{}.npy".format(ensemble_nr), all_preds_prob)
    else:
        # Load the files the branch above writes (with "_<exp_type>"); fall
        # back to the older names without it.  Unscaled metrics and prediction
        # probabilities are not available here, so they are not re-saved.
        legacy_prefix = "ensemble/attrs/" + dataset_string
        all_attrs = _load_saved(
            prefix + "attrs_ensembles_of_{}.npy".format(ensemble_nr),
            legacy_prefix + "attrs_ensembles_of_{}.npy".format(ensemble_nr))
        all_attrs_scaled = _load_saved(
            prefix + "attrs_ensembles_of_{}_scaled.npy".format(ensemble_nr),
            legacy_prefix + "attrs_ensembles_of_{}_scaled.npy".format(ensemble_nr))
        attr_metrics_scaled = _scaled_metrics(all_attrs_scaled, SSIM_plz)
        np.save(prefix + "metrics_for_ensembles_of_{}_scaled.npy".format(ensemble_nr), attr_metrics_scaled)

    if not SSIM_plz:
        _plot_disagreement(all_attrs, all_attrs_scaled, n_feats, dataset_string, exp_type, ensemble_nr)

    return _summarize(attr_metrics_scaled, prefix, ensemble_nr, SSIM_plz)


# Binary tabular datasets: single models, then ensembles of 5, 10, 15 and 20
# models, for both the `rs` ("RS") and `loo` ("LF") model sets.
for clss in [GermanCredit_str(), Adult_str(), Seizure_str(), Taiwanese_str()]:
    script(1, clss.rs, "RS", clss.name, multi=False, num_mods=24, load_attrs=False)
    script(1, clss.loo, "LF", clss.name, multi=False, num_mods=24, load_attrs=False)
    for i in range(5, 24, 5):
        script(i, clss.rs, "RS", clss.name, multi=False, num_mods=24, load_attrs=False)
        script(i, clss.loo, "LF", clss.name, multi=False, num_mods=24, load_attrs=False)


# Multi-class tabular dataset; the model paths are full "{}.h5" templates.
for clss in [Warafin_str()]:
    script(1, clss.rs + "{}.h5", "RS", clss.name, multi=True, num_mods=24, n=500, SSIM_plz=False)
    script(1, clss.loo + "{}.h5", "LF", clss.name, multi=True, num_mods=24, n=500, SSIM_plz=False)
    for i in range(5, 24, 5):
        script(i, clss.rs + "{}.h5", "RS", clss.name, multi=True, num_mods=24, n=500, SSIM_plz=False)
        script(i, clss.loo + "{}.h5", "LF", clss.name, multi=True, num_mods=24, n=500, SSIM_plz=False)


# Fashion-MNIST image experiments: 5 ensembles per size from a 100-seed
# permutation; attributions are computed in n_batch chunks of `batch` images.
script(1, "../resubmit_data/models/fmnist_base_sgd_drop_seed{}.h5", "RS", "FMNIST",
       multi=True, SSIM_plz=True, num_mods=5, n=100, batch=1000, n_batch=10)

for i in range(5, 24, 5):
    script(i, "../resubmit_data/models/fmnist_base_sgd_drop_seed{}.h5", "RS", "FMNIST",
           multi=True, SSIM_plz=True, num_mods=5, n=100, batch=1000, n_batch=10)

# Use the same batching as every other FMNIST run.  The defaults
# (batch=125, n_batch=10) attribute only the first 1250 images, while the
# predictions they are masked with cover the whole test split.
script(1, "../resubmit_data/models/fmnist_LOO_{}_batch64.h5", "LF", "FMNIST",
       SSIM_plz=True, multi=True, num_mods=5, n=100, batch=1000, n_batch=10)

for i in range(5, 24, 5):
    script(i, "../resubmit_data/models/fmnist_LOO_{}_batch64.h5", "LF", "FMNIST",
           multi=True, SSIM_plz=True, num_mods=5, n=100, batch=1000, n_batch=10)