import os


import tensorflow as tf
# Index into the list of physical GPUs that this script should run on.
gpu=1
# List the GPUs TensorFlow reports; `gpus[gpu]` below raises IndexError when
# fewer than `gpu + 1` devices are listed.
gpus = tf.config.experimental.list_physical_devices('GPU')
# Make only the selected GPU visible to TensorFlow.
tf.config.experimental.set_visible_devices(gpus[gpu], 'GPU')
# NOTE(review): this binding is overwritten by the loop variable just below.
device = gpus[gpu]

# Switch on TensorFlow's memory-growth setting for every visible GPU.
for device in tf.config.experimental.get_visible_devices('GPU'):
    tf.config.experimental.set_memory_growth(device, True)


# Stop tensorflow from doing its obnoxious warning logging.
import logging
logging.getLogger('tensorflow').disabled = True

import warnings
warnings.filterwarnings('ignore')


import sys

import numpy as np
import random as rn
import pandas as pd

# Seed NumPy's global random number generator.
np.random.seed(42)

# The below is necessary for starting core Python generated random numbers
# in a well-defined state.

rn.seed(12345)

# Seed TensorFlow's global random number generator.
tf.random.set_seed(1234)

from trulens.nn.models import get_model_wrapper as ModelWrapper

from trulens.nn.attribution import InternalInfluence

from scipy.stats import mode

from scipy import stats
from skimage.metrics import structural_similarity as ssim
import numpy as np
from get_data import get_data
from model_strings import GermanCredit_str, Adult_str, Seizure_str, Taiwanese_str, Warafin_str
from model_strings import HELOC_str, CTG_str, Thyroid_str, Colon_str
from datalib import FashionMnist
#top k
def top_k_intersection(arr1,
                       arr2,
                       k,
                       axis=-1,
                       return_ranks=False,
                       return_values=False):
    """Return the share of top-``k`` indices that ``arr1`` and ``arr2`` have in common.

    ``tf.math.top_k`` is applied to both inputs. For every position ``a``
    along the first axis of the resulting index tensors, the number of index
    values present in both is counted with ``np.intersect1d`` and divided by
    ``k``.

    ``axis``, ``return_ranks`` and ``return_values`` are accepted but are not
    used by this implementation.
    """
    # Only the index part of the (values, indices) result is needed.
    _, top_idx1 = tf.math.top_k(arr1, k)
    _, top_idx2 = tf.math.top_k(arr2, k)

    n_rows = top_idx1.shape[0]
    shared_counts = [
        np.intersect1d(top_idx1[a], top_idx2[a]).shape[0]
        for a in range(n_rows)
    ]
    return np.array(shared_counts) / k

#MSE
def mse(A, B, ax):
    """Return the mean of the squared element-wise difference of ``A`` and ``B``.

    ``ax`` is forwarded as the ``axis`` argument of ``.mean``; ``None``
    averages over all elements.
    """
    difference = A - B
    squared = difference ** 2
    return squared.mean(axis=ax)


def get_metrics(expls1, expls2, k, SSIM_plz=False, c=False):
    """Compute similarity metrics between paired explanations.

    ``expls1`` and ``expls2`` are iterated in lockstep; each metric is
    evaluated once per pair and collected into an array.

    Args:
        expls1, expls2: Collections of explanations to compare pairwise.
        k: Passed to ``top_k_intersection``.
        SSIM_plz: When true, both inputs are first reshaped to
            ``(shape[0], shape[1], -1)`` and squeezed, and SSIM takes the
            place of the Spearman correlation in the result.
        c: Only read when ``SSIM_plz`` is true; additionally averages the
            reshaped inputs over axis 1.

    Returns:
        ``np.array([first, pearson, mse, top_k, l2])`` where ``first`` is the
        SSIM array if ``SSIM_plz`` is true and the Spearman array otherwise.
    """
    def _prepare(expls):
        # Collapse everything after the second axis, then drop size-1 axes.
        collapsed = np.reshape(expls, (expls.shape[0], expls.shape[1], -1)).squeeze()
        return collapsed.mean(axis=1) if c else collapsed

    if SSIM_plz:
        expls1 = _prepare(expls1)
        expls2 = _prepare(expls2)

    def _per_pair(metric):
        # Re-zips the inputs on every call, one metric value per pair.
        return np.array([metric(a, b) for a, b in zip(expls1, expls2)])

    # The Spearman correlation is evaluated even when SSIM replaces it below.
    sp_rk = _per_pair(lambda a, b: stats.spearmanr(a, b)[0])
    pr_rk = _per_pair(lambda a, b: stats.pearsonr(a, b)[0])
    mses = _per_pair(lambda a, b: mse(a, b, ax=None))
    topk = _per_pair(lambda a, b: top_k_intersection(a, b, k))
    l2 = _per_pair(lambda a, b: np.linalg.norm(np.abs(a - b)))

    if not SSIM_plz:
        return np.array([sp_rk, pr_rk, mses, topk, l2])

    # data_range spans the joint min/max of the two explanations in a pair.
    SSIM = _per_pair(
        lambda a, b: ssim(
            a, b,
            data_range=max(a.max(), b.max()) - min(a.min(), b.min())))
    return np.array([SSIM, pr_rk, mses, topk, l2])





def get_attrs2(mdl_string, inputs, sd_start=0, seed_stop=15, 
    seed_nrs=None, qoi=0, from_script=False):
    """Average sign-aligned attributions over several saved binary models.

    For each seed a Keras model is loaded, its predictions on ``inputs`` are
    thresholded at 0.5, and ``InternalInfluence`` attributions (``qoi=0``,
    ``cuts=(0, -1)``) are computed. Attributions are negated wherever the
    per-point mode of the models' predictions is 0, then averaged over models.

    Args:
        mdl_string: Model path pattern. When ``from_script`` is false,
            ``"{}.h5"`` is appended before formatting with the seed.
        inputs: Data passed to ``predict`` and ``attributions``.
        sd_start, seed_stop: Bounds of the default seed range.
        seed_nrs: Explicit seeds; overrides ``sd_start``/``seed_stop``.
        qoi: Accepted for interface compatibility; the attribution call uses
            a fixed ``qoi=0``.
        from_script: When true, the loaded model is wrapped as-is and a
            sigmoid is applied to its outputs; otherwise the wrapped model is
            cut at ``layers[-2]`` and outputs are used directly.

    Returns:
        Tuple ``(mean_attributions, preds_prob)`` where ``preds_prob`` is the
        list of per-model (squeezed) prediction scores.
    """
    attrs_c0 = []
    preds = []
    preds_prob = []
    mdl_string2 = mdl_string
    if not from_script:
        mdl_string = mdl_string + "{}.h5"
    if seed_nrs is None:
        seed_nrs = np.arange(sd_start, seed_stop)
    for i in seed_nrs:
        try:
            mod = tf.keras.models.load_model(mdl_string.format(i))
        except Exception:
            # Fall back to the "<seed-100>_ex.h5" naming. ``Exception`` rather
            # than a bare except so Ctrl-C / SystemExit are not swallowed.
            mod = tf.keras.models.load_model(mdl_string2 + "{}_ex.h5".format(i - 100))

        # Run inference once and derive both the scores and the hard labels.
        if from_script:
            mod_tru = ModelWrapper(mod)
            probs = tf.nn.sigmoid(mod.predict(inputs)).numpy().squeeze()
        else:
            mod_tru = ModelWrapper(tf.keras.models.Model(mod.input, mod.layers[-2].output))
            probs = mod.predict(inputs).squeeze()
        preds.append((probs > 0.5).astype('int'))
        preds_prob.append(probs)

        exp_c0 = InternalInfluence(mod_tru, cuts=(0, -1), qoi=0, doi='point', multiply_activation=False)
        attrs_c0.append(exp_c0.attributions(inputs))

    # scipy.stats.mode drops the reduced axis by default from SciPy 1.11 on;
    # atleast_2d restores the (1, n_points) layout the indexing below needs.
    preds_mode = np.atleast_2d(mode(np.array(preds), axis=0)[0])
    stacked = np.array(attrs_c0)
    attrs = np.where(preds_mode[:, :, None], stacked, -stacked)
    return attrs.mean(axis=0), preds_prob

def get_attrs_oneshot(filepath, seed_nr, data, qoi="max", mul_act=False, from_script=False):
    """Compute sign-aligned attributions for one saved binary model.

    Args:
        filepath: Model path. When ``from_script`` is true it is formatted
            with ``seed_nr``; otherwise ``"<seed_nr>.h5"`` is appended.
        seed_nr: Seed identifying the model file.
        data: Inputs passed to ``predict`` and ``attributions``.
        qoi: Forwarded to ``InternalInfluence``.
        mul_act: Forwarded as ``multiply_activation``.
        from_script: When true, the loaded model is wrapped as-is and a
            sigmoid is applied to its outputs; otherwise the wrapped model is
            cut at ``layers[-2]`` and outputs are used directly.

    Returns:
        Array of attributions, negated for every point whose thresholded
        (``> 0.5``) prediction is 0.
    """
    if from_script:
        filepath2 = filepath.format(seed_nr)
    else:
        filepath2 = filepath + "{}.h5".format(seed_nr)
    try:
        mod = tf.keras.models.load_model(filepath2)
    except Exception:
        # Fall back to the "<seed-100>_ex.h5" naming. ``Exception`` rather
        # than a bare except so Ctrl-C / SystemExit are not swallowed.
        filepath2 = filepath + "{}_ex.h5"
        mod = tf.keras.models.load_model(filepath2.format(seed_nr - 100))

    if from_script:
        mod_tru = ModelWrapper(mod)
        preds = ((tf.nn.sigmoid(mod.predict(data)).numpy().squeeze() > 0.5).astype('int'))
    else:
        mod_tru = ModelWrapper(tf.keras.models.Model(mod.input, mod.layers[-2].output))
        preds = ((mod.predict(data).squeeze() > 0.5).astype('int'))
    exp = InternalInfluence(mod_tru, cuts=(0, -1), qoi=qoi, doi='point', multiply_activation=mul_act)
    attrs_0 = exp.attributions(data)
    # Keep the attribution where the predicted label is 1, flip its sign otherwise.
    attrs = np.where(preds[:, None], np.array(attrs_0), -np.array(attrs_0))

    return np.array(attrs)


def get_attrs_multi(mdl_string, inputs, n_classes, sd_start=0, seed_stop=15, seed_nrs=None, 
    mult_act=False, avg=False, y_test=None, SSIM_plz=False, from_script=False, batch=125, n_batch=10):
    """Average per-class attributions over several saved multi-class models.

    For each seed the model at ``mdl_string.format(seed)`` is loaded, its
    argmax predictions are recorded, and ``InternalInfluence`` attributions
    (``cuts=(0, -2)``) are computed for every class index. The attributions
    of the class that is the per-point mode of the models' predictions are
    selected and averaged over models.

    Args:
        mdl_string: Model path pattern formatted with each seed.
        inputs: Data passed to ``predict`` and ``attributions``.
        n_classes: Number of class indices to attribute.
        sd_start, seed_stop: Bounds of the default seed range.
        seed_nrs: Explicit seeds; overrides ``sd_start``/``seed_stop``.
        mult_act: Forwarded as ``multiply_activation``.
        SSIM_plz: When true, attributions are computed in ``n_batch`` slices
            of ``batch`` rows and two extra trailing axes are added to the
            class mask (presumably for image-shaped attributions -- confirm).
        batch, n_batch: Slice size and count used when ``SSIM_plz`` is true.
        avg, y_test, from_script: Accepted for interface compatibility; not
            used in the computation.

    Returns:
        Tuple ``(attrs_only_predicted_class, preds_mode)`` where
        ``preds_mode`` is the transposed one-hot encoding of the mode
        prediction.
    """
    attrs = []
    preds = []
    if seed_nrs is None:
        seed_nrs = np.arange(sd_start, seed_stop)
    for i in seed_nrs:
        mod = tf.keras.models.load_model(mdl_string.format(i))
        # Only the hard labels are needed here, so inference runs once.
        preds.append(mod.predict(inputs).argmax(axis=1))
        mod = ModelWrapper(mod)

        per_class = []
        for j in range(n_classes):
            exp = InternalInfluence(mod,
                cuts=(0, -2), qoi=j, doi='point', multiply_activation=mult_act)
            if SSIM_plz:
                # Distinct index name so the batch loop cannot be confused
                # with the class index ``j``.
                per_class.append(np.concatenate([exp.attributions(
                    inputs[batch * b:batch * (b + 1)]) for b in range(n_batch)], axis=0))
            else:
                per_class.append(exp.attributions(inputs))
        attrs.append(per_class)

    # get mode of preds in the right shape to broadcast,
    # i.e. (num_underlying_models, num_classes, num_data_points, (image_dims)
    attrs = np.array(attrs)
    # scipy.stats.mode drops the reduced axis by default from SciPy 1.11 on;
    # atleast_2d restores the (1, n_points) layout the broadcasting relies on.
    mode_labels = np.atleast_2d(mode(np.array(preds), axis=0)[0])
    # num_classes pins the one-hot width so it matches ``attrs`` even when
    # the highest class index is never the mode prediction.
    preds_mode = tf.keras.utils.to_categorical(mode_labels, num_classes=n_classes).T
    # then multiply with attrs to select only the attrs for the class that's the mode
    # for each point's ensemble prediction. Sum over axis 1 because all but the desired class
    # will be zero. Then take mean over the num_underlying_models to get the agg attr
    if SSIM_plz:
        attrs_only_predicted_class = (
            attrs * preds_mode[None, :, :, None, None]).sum(axis=1).mean(axis=0)
    else:
        attrs_only_predicted_class = (
            attrs * preds_mode[None, :, :]).sum(axis=1).mean(axis=0)
    return attrs_only_predicted_class, preds_mode


def get_attrs_oneshot_multi(filepath, data, seed_nr, n_classes=10, mul_act=False,
                            filepath2=None, y_test=None, SSIM_plz=False, from_script=False,
                            batch=125, n_batch=10):
    """Compute predicted-class attributions for one saved multi-class model.

    The model at ``filepath.format(seed_nr)`` is loaded (its summary is
    printed), ``InternalInfluence`` attributions (``cuts=(0, -1)``) are
    computed for every class index, and for each point only the attribution
    of its argmax-predicted class is kept.

    Args:
        filepath: Model path pattern formatted with ``seed_nr``.
        data: Inputs passed to ``predict`` and ``attributions``.
        seed_nr: Seed identifying the model file.
        n_classes: Number of class indices to attribute.
        mul_act: Forwarded as ``multiply_activation``.
        SSIM_plz: When true, attributions are computed in ``n_batch`` slices
            of ``batch`` rows and the class mask gets three trailing axes
            (presumably for image-shaped attributions -- confirm).
        from_script: When true the loaded model is wrapped as-is; otherwise
            the wrapped model is cut at ``layers[-2]`` (presumably to drop a
            final softmax layer -- confirm).
        batch, n_batch: Slice size and count used when ``SSIM_plz`` is true.
        filepath2, y_test: Accepted for interface compatibility; not used.

    Returns:
        Attributions of the predicted class for every point.
    """
    attrs = []
    mod = tf.keras.models.load_model(filepath.format(seed_nr))
    mod.summary()

    # One-hot mask of the predicted class, transposed to (class, point).
    # num_classes pins the width so it matches ``attrs`` even when the
    # highest class index is never predicted.
    preds = tf.keras.utils.to_categorical(
        mod.predict(data).argmax(axis=1), num_classes=n_classes).T
    if from_script:
        mod = ModelWrapper(mod)
    else:
        mod = ModelWrapper(tf.keras.models.Model(mod.input, mod.layers[-2].output))

    for i in range(n_classes):
        exp = InternalInfluence(mod, cuts=(0, -1), qoi=i, doi='point', multiply_activation=mul_act)

        if SSIM_plz:
            attrs.append(np.concatenate(
                [exp.attributions(data[batch * j:batch * (j + 1)]) for j in range(n_batch)], axis=0))
        else:
            attrs.append(exp.attributions(data))
    attrs = np.array(attrs)

    # All but the predicted class are zeroed by the mask, so summing over the
    # class axis selects the predicted class's attribution.
    if SSIM_plz:
        attrs_only_predicted_class = (
            attrs * preds[:, :, None, None, None]).sum(axis=0)
    else:
        attrs_only_predicted_class = (
            attrs * preds[:, :, None]).sum(axis=0)
    return attrs_only_predicted_class

import pandas as pd
def func(k, metrics, abst, num_mods=24):
    """Average each model pair's metrics over the points selected by both masks.

    ``metrics[p]`` holds the metrics of the ``p``-th model pair; its columns
    are filtered with ``abst[i] & abst[j]`` before averaging over axis 1.

    The pair enumeration depends on how many entries ``metrics`` has:
    more than ``num_mods * (num_mods - 1) / 2`` entries means every ordered
    pair ``i != j`` is visited (``j`` varying slowest); otherwise only the
    pairs with ``j < i`` are visited.

    ``k`` is accepted but not used.

    Returns:
        Array with one row of averaged metrics per visited pair.
    """
    unordered_pair_count = (num_mods * (num_mods - 1)) / 2
    if metrics.shape[0] > unordered_pair_count:
        model_pairs = [
            (i, j)
            for j in range(num_mods)
            for i in range(num_mods)
            if i != j
        ]
    else:
        model_pairs = [(i, j) for i in range(num_mods) for j in range(i)]

    averaged = [
        metrics[pair_idx][:, abst[i] & abst[j]].mean(axis=1)
        for pair_idx, (i, j) in enumerate(model_pairs)
    ]
    return np.array(averaged)


from dbify import dbify
@dbify("expl_instability", "selective_ensemble_saliency")
def script(ensemble_nr, mdl_string, exp_type, dataset_string, from_script=False,
           multi=False, num_mods=24, SSIM_plz=False, n=500, load_attrs=False,
           batch=125, n_batch=10, alpha=0.01):
    """Filter saved pairwise attribution metrics with saved abstention masks.

    Loads the abstention masks and the pairwise attribution metrics for the
    given ensemble size, reduces them with ``func``, saves the result next to
    the input metrics and returns the per-metric averages.

    Args:
        ensemble_nr: Ensemble size used in the file names.
        exp_type: Tag inserted into the file names.
        dataset_string: Dataset name used in the file names ("Colon" is
            lower-cased for the abstention file only).
        num_mods: Forwarded to ``func``.
        SSIM_plz: Selects whether the first returned metric is keyed
            ``"SSIM"`` or ``"spear_rk_sc"``.
        alpha: Value inserted into the abstention and result file names.
        mdl_string, from_script, multi, n, load_attrs, batch, n_batch:
            Accepted but not used in the body.

    Returns:
        Dict of the five averaged metrics plus ``"result_filename"``.
    """
    abst_dataset = "colon" if dataset_string == "Colon" else dataset_string
    abst_path = ("ensemble/" + abst_dataset
                 + "_abstain_idxs_{}_ens_mdls_alpha_{}".format(ensemble_nr, alpha)
                 + exp_type + ".npy")
    abst = np.load(abst_path)

    # Common prefix of the metrics input file and the filtered output file.
    metrics_prefix = "ensemble/attrs/" + dataset_string + "_" + exp_type
    attr_metrics = np.load(
        metrics_prefix + "metrics_for_ensembles_of_{}_scaled.npy".format(ensemble_nr))

    select_attr_metrics = func(5, attr_metrics, abst, num_mods=num_mods)

    result_filename = (
        metrics_prefix
        + "metrics_for_ensembles_of_{}_scaled_SELECT_alpha_{}.npy".format(ensemble_nr, alpha))
    np.save(result_filename, select_attr_metrics)

    # Average over model pairs; one value per metric.
    pair_mean = select_attr_metrics.mean(axis=0)
    first_key = "SSIM" if SSIM_plz else "spear_rk_sc"
    return {
        first_key: pair_mean[0],
        "pear_rk_sc": pair_mean[1],
        "mses_sc": pair_mean[2],
        "topk_int_sc": pair_mean[3],
        "l2_dist_sc": pair_mean[4],
        "result_filename": result_filename,
    }


# Every dataset is evaluated for ensemble sizes 10, 15 and 20; each size is
# run once with the "RS" tag and once with the "LF" tag.

# German Credit, Adult and Taiwanese configurations.
for clss in [GermanCredit_str(), Adult_str(), Taiwanese_str()]:
    for i in range(10, 24, 5):
        script(i, clss.rs, "RS", clss.name,
               multi=False, num_mods=24, load_attrs=False)
        script(i, clss.loo, "LF", clss.name,
               multi=False, num_mods=24, load_attrs=False)


# Warafin configuration.
for clss in [Warafin_str()]:
    for i in range(10, 24, 5):
        script(i, clss.rs, "RS", clss.name,
               multi=True, num_mods=24, n=500, SSIM_plz=False)
        script(i, clss.loo, "LF", clss.name,
               multi=True, num_mods=24, n=500, SSIM_plz=False)


# Draw and store a permutation of 200 indices for the Colon dataset; it is
# only saved here, not consumed by the runs below.
perm = np.random.permutation(200)
np.save("ensemble/permutation_for_dataset_" + "Colon" + ".npy", perm)

# Colon configuration (SSIM replaces the Spearman metric key).
for clss in [Colon_str()]:
    for i in range(10, 24, 5):
        script(i, clss.rs, "RS", clss.name,
               multi=True, num_mods=10, n=200, SSIM_plz=True)
        script(i, clss.loo, "LF", clss.name,
               multi=True, num_mods=10, n=200, SSIM_plz=True)


# FMNIST, addressed through explicit model path patterns.
for i in range(10, 24, 5):
    script(i, "../resubmit_data/models/fmnist_base_sgd_drop_seed{}.h5", "RS", "FMNIST",
           multi=True, SSIM_plz=True, num_mods=10, n=200)
    script(i, "../resubmit_data/models/fmnist_LOO_{}_5trials_batch64.h5", "LF", "FMNIST",
           multi=True, SSIM_plz=True, num_mods=10, n=200)


