import os
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"]="2"

import tensorflow as tf
# Index into the list of physical GPUs of the single GPU this process may use.
# NOTE(review): hard-coded; `gpus[gpu]` raises IndexError on a machine that
# exposes fewer than gpu+1 GPUs (including CPU-only machines).
gpu=1
gpus = tf.config.experimental.list_physical_devices('GPU')
# Restrict TensorFlow to the selected GPU only.
tf.config.experimental.set_visible_devices(gpus[gpu], 'GPU')
device = gpus[gpu]

# The loop variable rebinds the module-level name `device` assigned above.
for device in tf.config.experimental.get_visible_devices('GPU'):
    # Turn on memory growth for every GPU that is still visible.
    tf.config.experimental.set_memory_growth(device, True)


# Disable the 'tensorflow' Python logger for the rest of the process so its
# warning output does not clutter the experiment logs.
import logging
logging.getLogger('tensorflow').disabled = True

# Ignore every Python warning process-wide (this is not limited to TensorFlow).
import warnings
warnings.filterwarnings('ignore')


import sys

import numpy as np
import random as rn

# NOTE(review): the NumPy seed below is commented out, so NumPy's global RNG
# (used later in this file via np.random.permutation) is not seeded here.
#np.random.seed(42)

# Seed Python's built-in `random` module so that core-Python random numbers
# start from a well-defined state.

rn.seed(12345)

# Seed TensorFlow's global random generator.
tf.random.set_seed(1234)

from trulens.nn.models import get_model_wrapper as ModelWrapper

from trulens.nn.attribution import InternalInfluence

from scipy.stats import mode

from scipy import stats
from skimage.metrics import structural_similarity as ssim
import numpy as np
from get_data import get_data
from model_strings import GermanCredit_str, Adult_str, Seizure_str, Taiwanese_str, Warafin_str
from model_strings import HELOC_str, CTG_str, Thyroid_str, Colon_str
from datalib import FashionMnist
#top k


def top_k_intersection(arr1,
                       arr2,
                       k,
                       axis=-1,
                       return_ranks=False,
                       return_values=False):
    """Measure how much the top-`k` entries of two arrays overlap.

    ``tf.math.top_k`` is applied to each input and the returned index
    tensors are intersected with ``np.intersect1d``.

    Args:
        arr1: First array of scores.
        arr2: Second array of scores.
        k: Number of top entries to take from each array.
        axis: Accepted for interface compatibility; it is never used by this
            function.
        return_ranks: If true, return ``(indices1, indices2, common)``.
        return_values: If true (and `return_ranks` is false), return
            ``(indices1, indices2, values1, values2, common)``.

    Returns:
        By default, the number of common indices divided by `k`; otherwise
        one of the tuples described above.
    """
    first = tf.math.top_k(arr1, k)
    second = tf.math.top_k(arr2, k)

    # Indices that appear in both top-k selections.
    shared = np.intersect1d(first.indices, second.indices)

    if return_ranks:
        return first.indices, second.indices, shared
    if return_values:
        return (first.indices, second.indices,
                first.values, second.values, shared)
    return len(shared) / k
#MSE
def mse(A, B, ax):
    """Return the mean of the squared element-wise difference of `A` and `B`.

    `ax` is forwarded as the ``axis`` argument of ``.mean``; passing ``None``
    averages over every element.
    """
    squared_error = (A - B) ** 2
    return squared_error.mean(axis=ax)

def get_metrics(expl1, expl2, k, SSIM_plz=False, c=False):
    """Compare two attribution arrays with five similarity/distance metrics.

    Args:
        expl1: First attribution array.
        expl2: Second attribution array.
        k: Number of top entries used for the top-k intersection.
        SSIM_plz: If true, both inputs are flattened first and the first
            reported metric is SSIM; if false, the inputs are used as given
            and the first reported metric is the Spearman correlation.
        c: Only consulted when `SSIM_plz` is true: average over axis 0
            before flattening.

    Returns:
        ``np.array([first, pearson, mse, topk, l2])`` where every entry is a
        one-element array and `first` is SSIM or Spearman as described above.
    """
    if SSIM_plz:
        if c:
            expl1 = expl1.mean(axis=0)
            expl2 = expl2.mean(axis=0)
        expl1 = expl1.flatten()
        expl2 = expl2.flatten()

    # Spearman is evaluated in both modes, although only the non-SSIM result
    # reports it.
    spearman = np.array([stats.spearmanr(expl1, expl2)[0]])
    pearson = np.array([stats.pearsonr(expl1, expl2)[0]])
    mean_sq_err = np.array([mse(expl1, expl2, ax=None)])
    topk_overlap = np.array([top_k_intersection(expl1, expl2, k)])
    l2_dist = np.array([np.linalg.norm(np.abs(expl1 - expl2))])

    if not SSIM_plz:
        return np.array(
            [spearman, pearson, mean_sq_err, topk_overlap, l2_dist])

    # SSIM needs the dynamic range spanned jointly by both inputs.
    upper = max(expl1.max(), expl2.max())
    lower = min(expl1.min(), expl2.min())
    similarity = np.array([ssim(expl1, expl2, data_range=upper - lower)])
    return np.array(
        [similarity, pearson, mean_sq_err, topk_overlap, l2_dist])

def get_attrs_oneshot(filepath, seed_nr, data, qoi="max", mul_act=False, from_script=False):
    """Load one saved model and return sign-adjusted input attributions.

    Args:
        filepath: Model path. With `from_script` it is a template whose
            ``{}`` placeholder is filled with `seed_nr`; otherwise `seed_nr`
            and ``".h5"`` are appended to it.
        seed_nr: Identifier of the model to load.
        data: Inputs handed to ``model.predict`` and to the attribution
            method.
        qoi: Quantity of interest forwarded to ``InternalInfluence``.
        mul_act: Forwarded as ``multiply_activation``.
        from_script: If true, the full model is wrapped and its outputs are
            passed through a sigmoid before thresholding. If false, the
            outputs are thresholded directly and attributions are taken with
            respect to the output of the second-to-last layer.

    Returns:
        ``np.ndarray`` of attributions in which the rows of samples whose
        prediction (thresholded at 0.5) is 0 are negated.
    """
    if from_script:
        model_path = filepath.format(seed_nr)
    else:
        model_path = filepath + "{}.h5".format(seed_nr)

    try:
        mod = tf.keras.models.load_model(model_path)
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit. Any ordinary load failure still falls back to the
        # "<seed - 100>_ex.h5" file.
        # NOTE(review): if `filepath` itself contains a "{}" (the
        # from_script convention) this template has two placeholders and
        # formatting it with one argument raises IndexError -- confirm the
        # intended fallback path for that mode.
        fallback_path = (filepath + "{}_ex.h5").format(seed_nr - 100)
        mod = tf.keras.models.load_model(fallback_path)

    if from_script:
        mod_tru = ModelWrapper(mod)
        preds = (tf.nn.sigmoid(mod.predict(data)).numpy().squeeze() > 0.5).astype('int')
    else:
        mod_tru = ModelWrapper(tf.keras.models.Model(mod.input, mod.layers[-2].output))
        preds = (mod.predict(data).squeeze() > 0.5).astype('int')

    exp = InternalInfluence(mod_tru, cuts=(0, -1), qoi=qoi, doi='point', multiply_activation=mul_act)
    raw_attrs = np.array(exp.attributions(data))

    # Keep the attribution where the predicted label is 1, flip its sign
    # where the predicted label is 0.
    return np.where(preds[:, None], raw_attrs, -raw_attrs)

def get_attrs_oneshot_multi(filepath, data, seed_nr, n_classes=10, mul_act=False,
                            filepath2=None, y_test=None, SSIM_plz=False, from_script=False,
                            batch=125, n_batch=10):
    """Return, per sample, the attribution for the class the model predicts.

    Args:
        filepath: Model path template; ``{}`` is filled with `seed_nr`.
        data: Inputs handed to ``model.predict`` and to the attribution
            method.
        seed_nr: Identifier of the model to load.
        n_classes: Number of classes; one attribution pass is run per class.
        mul_act: Forwarded as ``multiply_activation``.
        filepath2: Unused; kept for interface compatibility.
        y_test: Unused; kept for interface compatibility.
        SSIM_plz: If true, attributions are computed in `n_batch` consecutive
            slices of `batch` rows each and concatenated.
        from_script: If true, the full model is wrapped; otherwise the model
            is cut at the output of its second-to-last layer.
        batch: Rows per slice when `SSIM_plz` is true.
        n_batch: Number of slices when `SSIM_plz` is true.

    Returns:
        Array with the leading class axis summed out, i.e. for every sample
        only the attribution of its predicted class remains.
    """
    mod = tf.keras.models.load_model(filepath.format(seed_nr))
    mod.summary()

    # One-hot mask of predicted classes with shape (n_classes, n_samples).
    # num_classes is passed explicitly: left to be inferred, the mask has
    # only max(prediction) + 1 rows and does not broadcast against the
    # n_classes attribution maps whenever the last class is never predicted.
    predicted = mod.predict(data).argmax(axis=1)
    preds = tf.keras.utils.to_categorical(predicted, num_classes=n_classes).T

    if from_script:
        wrapped = ModelWrapper(mod)
    else:
        # Outside from_script mode the attribution model stops at the
        # second-to-last layer's output.
        wrapped = ModelWrapper(tf.keras.models.Model(mod.input, mod.layers[-2].output))

    attrs = []
    for class_idx in range(n_classes):
        exp = InternalInfluence(wrapped, cuts=(0, -1), qoi=class_idx, doi='point', multiply_activation=mul_act)

        if SSIM_plz:
            # Only the first batch * n_batch rows of `data` are covered; the
            # mask above is built from all rows, so the two must match.
            attrs.append(np.concatenate(
                [exp.attributions(data[batch * j:batch * (j + 1)]) for j in range(n_batch)],
                axis=0))
        else:
            attrs.append(exp.attributions(data))
    attrs = np.array(attrs)

    # Append singleton axes so the (class, sample) mask broadcasts over
    # however many feature axes the attributions have (one for flat inputs,
    # three for image-shaped inputs).
    mask = preds.reshape(preds.shape + (1,) * (attrs.ndim - preds.ndim))
    return (attrs * mask).sum(axis=0)

def func(k, attrs, filepath=None, mul_act=False, num_mods=24, 
         SSIM_plz=False, multi=False, y_preds_str=None, c=False):
    """Compute `get_metrics` for every unordered pair of the first `num_mods` entries.

    Pairs are visited as (i, j) with j < i, in order of increasing i and then
    increasing j, and each is evaluated as ``get_metrics(attrs[i], attrs[j])``.

    `filepath`, `mul_act`, `multi` and `y_preds_str` are accepted but not
    used by this function.

    Returns:
        ``np.ndarray`` stacking one `get_metrics` result per pair.
    """
    pair_metrics = [
        get_metrics(attrs[hi], attrs[lo], k, SSIM_plz=SSIM_plz, c=c)
        for hi in range(num_mods)
        for lo in range(hi)
    ]
    return np.array(pair_metrics)


def script(ensemble_nr, mdl_string, exp_type, dataset_string, from_script=False,
    multi=False, num_mods=24, SSIM_plz=False, n=500, load_attrs=False, batch=125,
     n_batch=10, c=False):
    """Compute, save and summarise pairwise attribution metrics for one setting.

    The cached attribution array
    ``ensemble/attrs/<dataset>_<exp_type>attrs_ensembles_of_<ensemble_nr>_scaled.npy``
    is loaded, `func` is run on it with ``k=5``, the resulting metric array
    is saved next to it, and the per-metric means are returned.

    The previous implementation tried to recompute attributions when the
    cache was missing, but that branch only produced attributions for seed 0
    and never defined the array the metric step reads, so it always ended in
    a NameError after the expensive work. A missing cache is now reported
    immediately. The test dataset was only loaded to feed that branch and is
    therefore no longer loaded here.

    Args:
        ensemble_nr: Ensemble size; used to build the file names.
        mdl_string: Unused; kept for interface compatibility.
        exp_type: Experiment tag inserted into the file names.
        dataset_string: Dataset name inserted into the file names.
        from_script: Unused; kept for interface compatibility.
        multi: Unused; kept for interface compatibility.
        num_mods: Number of leading entries of the attribution array that
            are compared pairwise.
        SSIM_plz: Forwarded to `func`; also selects whether the first entry
            of the result is keyed "SSIM" or "spear_rk_sc".
        n: Unused; kept for interface compatibility.
        load_attrs: Unused; kept for interface compatibility.
        batch: Unused; kept for interface compatibility.
        n_batch: Unused; kept for interface compatibility.
        c: Forwarded to `func`.

    Returns:
        dict with the mean of each of the five metrics plus the keys
        "result_filename" and "result_filename_scaled". Only the "_scaled"
        file is written by this function.

    Raises:
        FileNotFoundError: If the cached attribution array does not exist.
    """
    prefix = "ensemble/attrs/" + dataset_string + "_" + exp_type
    attrs_path = prefix + "attrs_ensembles_of_{}_scaled.npy".format(ensemble_nr)
    metrics_path = prefix + "metrics_for_ensembles_of_{}_COMP.npy".format(ensemble_nr)
    metrics_scaled_path = prefix + "metrics_for_ensembles_of_{}_scaled_COMP.npy".format(ensemble_nr)

    try:
        all_attrs_scaled = np.load(attrs_path)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            "Cached attributions not found at {!r}; they must be generated "
            "before the comparison metrics can be computed.".format(attrs_path)
        ) from err

    attr_metrics_scaled = func(5, all_attrs_scaled, SSIM_plz=SSIM_plz, c=c, num_mods=num_mods)
    np.save(metrics_scaled_path, attr_metrics_scaled)

    # Average over all compared pairs once; row m then holds metric m.
    summary = attr_metrics_scaled.mean(axis=0)
    first_key = "SSIM" if SSIM_plz else "spear_rk_sc"
    return {
        first_key: summary[0].mean(),
        "pear_rk_sc": summary[1].mean(),
        "mses_sc": summary[2].mean(),
        "topk_int_sc": summary[3].mean(),
        "l2_dist_sc": summary[4].mean(),
        "result_filename": metrics_path,
        "result_filename_scaled": metrics_scaled_path,
    }


# Experiment driver. Every dataset is evaluated for two experiment tags, "RS"
# (models from `clss.rs`) and "LF" (models from `clss.loo`), first for
# ensemble size 1 and then for a range of larger ensemble sizes.

# Datasets run with multi=False; larger ensemble sizes are 15 and 20.
for clss in [GermanCredit_str(), Adult_str(), Seizure_str(), Taiwanese_str()]:
    script(1, clss.rs, "RS", clss.name, multi=False, num_mods=24, load_attrs=False)
    script(1, clss.loo, "LF", clss.name, multi=False, num_mods=24, load_attrs=False)

    for i in range(15,24,5):
        script(i, clss.rs, "RS", clss.name, multi=False, num_mods=24, load_attrs=False)
        script(i, clss.loo, "LF", clss.name, multi=False, num_mods=24, load_attrs=False)

# Run with multi=True; larger ensemble sizes are 5, 10, 15 and 20.
# NOTE(review): the size-1 calls append "{}.h5" to the model path while the
# calls in the loop do not -- confirm which form is intended.
for clss in [Warafin_str()]:

    script(1, clss.rs+"{}.h5", "RS", clss.name, multi=True, num_mods=24, n=500, SSIM_plz=False)
    script(1, clss.loo+"{}.h5", "LF", clss.name, multi=True, num_mods=24, n=500, SSIM_plz=False)

    for i in range(5,24,5):
        script(i, clss.rs, "RS", clss.name, multi=True, num_mods=24, n=500,SSIM_plz=False)
        script(i, clss.loo, "LF", clss.name, multi=True, num_mods=24, n=500,SSIM_plz=False)


# NOTE(review): NumPy's global RNG does not appear to be seeded in this file,
# so this permutation differs from run to run -- confirm that is intended.
perm=np.random.permutation(100)
np.save("ensemble/permutation_for_dataset_"+"Colon"+".npy", perm)

# Run with SSIM_plz=True.
# NOTE(review): the size-1 calls use num_mods=5 and c=True, the calls in the
# loop use num_mods=10 and leave c at its default -- confirm.
for clss in [Colon_str()]:

    script(1, clss.rs, "RS", clss.name, multi=True, num_mods=5, n=200, SSIM_plz=True, c=True)
    script(1, clss.loo, "LF", clss.name, multi=True, num_mods=5, n=200, SSIM_plz=True, c=True)

    for i in range(5,24,5):
        script(i, clss.rs, "RS", clss.name, multi=True, num_mods=10, n=200,SSIM_plz=True)
        script(i, clss.loo, "LF", clss.name, multi=True, num_mods=10, n=200,SSIM_plz=True)


#FMNIST

script(1, "../data/models/fmnist_base_sgd_drop_seed{}.h5", "RS", "FMNIST",
       multi=True, SSIM_plz=True, num_mods=10, n=200, batch=1000,
       n_batch=10)
script(1, "../data/models/fmnist_LOO_{}_5trials_batch64.h5", "LF", "FMNIST",
       SSIM_plz=True, multi=True, num_mods=10, n=200)

# Larger ensemble sizes: 10, 15 and 20.
for i in range(10,24,5):
    script(i, "../data/models/fmnist_base_sgd_drop_seed{}.h5", "RS", "FMNIST",
           multi=True, SSIM_plz=True, num_mods=10, n=200, batch=1000,
           n_batch=10)
    # This call used to be indented deeper than the statement above it, which
    # is an IndentationError and prevented the whole file from being parsed.
    script(i, "../data/models/fmnist_LOO_{}_5trials_batch64.h5", "LF", "FMNIST",
           multi=True, SSIM_plz=True, num_mods=10, n=200, batch=1000,
           n_batch=10)

