import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import shutil
import sys
sys.path.append('./')
sys.path.append('../')

from deel.utils.yaml_to_params import load_yaml_config,getParams, getFunctionFromModules, dumdict2yaml
from deel.datasets.load_dataset import load_dataset
from deel.utils.yaml_loader import load_model, loadFunctionList
from deel.utils.yaml_loader import load_optimizer_and_loss
from deel.lip.losses import MulticlassHinge
from tensorflow.keras.callbacks import ReduceLROnPlateau, LearningRateScheduler

import tensorflow as tf
from tensorflow.keras.optimizers import Adam,SGD
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K
from deel.utils.lip_utils import explain_for_dummies, explain_for_binary_dummies,load_compiled_model,add_softmax,get_data, mesure_metric, spearman_dist,mesure_stability,evaluate_kolmo,evaluate_dist
from xplique.metrics import MuFidelity,AverageStability
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
                                  GradCAMPP, Lime, KernelShap)
from matplotlib import pyplot as plt
from xplique.plots import plot_attributions,plot_attribution
import matplotlib
import time


# Use the non-interactive Agg backend: every plot in this script is written to
# disk with plt.savefig, so no display is required.
matplotlib.use('Agg')
# Silence TensorFlow's Python-side logger below ERROR level (complements
# TF_CPP_MIN_LOG_LEVEL set at the top of the file for the C++ side).
tf.get_logger().setLevel('ERROR')


def mesure_mu_fidelity(model, X, Y, expe_name, nb=1000, nb_samples=None,
                       grid_size=None, folder=None, batch_size=None):
    """Compute MuFidelity of Saliency maps for several grid settings.

    Saliency explanations are computed once on the first ``nb`` samples, then
    scored with xplique's ``MuFidelity`` for every ``(grid_size[k],
    nb_samples[k])`` pair. One line per pair is written to
    ``<folder><expe_name>/curves/fidelity.txt``.

    Side effect: the last layer of ``model`` is switched to a linear
    activation, and this change persists on the passed model.

    Args:
        model: Keras model to explain.
        X, Y: inputs and targets; only the first ``nb`` entries are used.
        expe_name: experiment name used to build the output path.
        nb: number of samples to evaluate.
        nb_samples: sequence of MuFidelity ``nb_samples`` values.
        grid_size: sequence of MuFidelity ``grid_size`` values, paired
            element-wise with ``nb_samples``.
        folder: results root (with trailing "/"); defaults to the
            script-level ``folder``.
        batch_size: batch size for MuFidelity; defaults to the script-level
            ``batch_size``.

    Raises:
        ValueError: if ``grid_size`` and ``nb_samples`` differ in length.
    """
    # None defaults avoid the shared-mutable-default pitfall.
    nb_samples = [] if nb_samples is None else list(nb_samples)
    grid_size = [] if grid_size is None else list(grid_size)
    if len(grid_size) != len(nb_samples):
        # zip() would otherwise silently drop the unmatched settings.
        raise ValueError("grid_size and nb_samples must have the same length")
    # Fall back to the module-level settings so existing callers keep working.
    if folder is None:
        folder = globals()["folder"]
    if batch_size is None:
        batch_size = globals()["batch_size"]

    model.layers[-1].activation = tf.keras.activations.linear
    out_dir = folder + expe_name + "/curves/"
    os.makedirs(out_dir, exist_ok=True)

    x_sub, y_sub = X[:nb], Y[:nb]
    with open(out_dir + "fidelity.txt", "w") as f:
        explainer = Saliency(model)
        explanations = explainer(x_sub, y_sub).numpy()

        for size, sample in zip(grid_size, nb_samples):
            print("fidelity ", size, sample, nb)
            start = time.perf_counter()
            metric = MuFidelity(model, x_sub, y_sub, batch_size,
                                grid_size=size, nb_samples=sample)
            fidelity_score = metric(explanations)
            end = time.perf_counter()
            print("computation time :", (end - start))
            label = f"fidelity grid {size} :"
            f.write(f"{label:20s} {fidelity_score:.4f}\n")
    
    

def save_saliency(model, X, Y, expe_name, nb=10, nb_images=3, folder="./results/"):
    """Plot Saliency maps for consecutive slices of X and save each figure.

    Figure ``i`` shows samples ``X[i*nb_images:(i+1)*nb_images]``. It is
    saved to ``<folder><expe_name>/curves/saliency/saliency_<i>.jpg``. The
    loop stops early once the slices run past the end of ``X``.

    Args:
        model: Keras model to explain.
        X, Y: inputs and targets (array-like supporting slicing and ``.copy()``).
        expe_name: experiment name used to build the output path.
        nb: maximum number of figures to save.
        nb_images: samples per figure. With 1, ``plot_attribution`` is used
            instead of ``plot_attributions``.
        folder: results root (with trailing "/").
    """
    rep = folder + expe_name + "/curves/saliency/"
    # exist_ok avoids the exists()/makedirs() race.
    os.makedirs(rep, exist_ok=True)
    explainer = Saliency(model)
    for i in range(nb):
        deb = i * nb_images
        end = deb + nb_images
        x = X[deb:end]
        y = Y[deb:end]
        if len(x) == 0:
            # Ran past the end of the data: nothing left to explain.
            break
        explanations = explainer.explain(x.copy(), y)
        if nb_images == 1:
            plot_attribution(explanations[0, :, :], x[0, :, :, :], cmap='jet',
                             alpha=0.4, absolute_value=True, clip_percentile=0)
        else:
            plot_attributions(explanations, x, cmap='jet', alpha=0.4,
                              cols=nb_images, absolute_value=True,
                              clip_percentile=0)

        plt.savefig(rep + "saliency_" + str(i) + ".jpg",
                    bbox_inches='tight', pad_inches=0)
        plt.close()

def save_smooth(model, X, Y, expe_name, nb=10, nb_images=3, folder="./results/",
                batch_size=None):
    """Plot SmoothGrad maps for consecutive slices of X and save each figure.

    The SmoothGrad noise level is 10% of the value range of ``X``, with 50
    noisy samples. Multi-channel explanations (rank > 3) are averaged over
    axis 3 before plotting. Figure ``i`` is saved to
    ``<folder><expe_name>/curves/smooth/smooth_<i>.jpg``. The loop stops
    early once the slices run past the end of ``X``.

    Args:
        model: Keras model to explain.
        X, Y: inputs and targets (``X`` must support ``.max()``/``.min()``).
        expe_name: experiment name used to build the output path.
        nb: maximum number of figures to save.
        nb_images: samples per figure.
        folder: results root (with trailing "/").
        batch_size: SmoothGrad batch size; defaults to the script-level
            ``batch_size``.
    """
    if batch_size is None:
        batch_size = globals()["batch_size"]
    rep = folder + expe_name + "/curves/smooth/"
    os.makedirs(rep, exist_ok=True)
    x_max, x_min = X.max(), X.min()
    explainer = SmoothGrad(model, nb_samples=50, batch_size=batch_size,
                           noise=0.1 * (x_max - x_min))

    for i in range(nb):
        deb = i * nb_images
        end = deb + nb_images
        x = X[deb:end]
        y = Y[deb:end]
        if len(x) == 0:
            # Ran past the end of the data: nothing left to explain.
            break
        explanations = explainer.explain(x.copy(), y)
        if len(explanations.shape) > 3:
            explanations = np.mean(explanations, axis=3)
        plot_attributions(explanations, x, cmap='jet', alpha=0.4,
                          cols=nb_images, absolute_value=True)
        plt.savefig(rep + "smooth_" + str(i) + ".jpg", bbox_inches='tight')
        plt.close()

def save_gradient(model, X, Y, expe_name, nb=10, folder="./results/"):
    """Save gradient explanations for the first ``nb`` samples.

    Each sample ``i`` is passed to ``explain_for_binary_dummies``, which
    receives the output prefix ``<folder><expe_name>/curves/grads/grad_<i>``.
    At most ``len(X)`` samples are processed.

    Args:
        model: Keras model to explain.
        X, Y: indexable inputs and targets.
        expe_name: experiment name used to build the output path.
        nb: maximum number of samples to explain.
        folder: results root (with trailing "/").
    """
    rep = folder + expe_name + "/curves/grads/"
    os.makedirs(rep, exist_ok=True)
    # Clamp so that asking for more samples than available does not raise IndexError.
    for i in range(min(nb, len(X))):
        explain_for_binary_dummies(model, X[i], Y[i], counter_coeff=.1,
                                   filename=rep + "grad_" + str(i))
       
def evaluate_all_metrics(model, data, expe_name, folder=None, batch_size=None):
    """Score several attribution methods with MuFidelity and save the results.

    Up to 2000 samples are drawn with ``get_data``. Labels equal to 0 are
    remapped to -1. MuFidelity (grid 9, 50 samples) is built on the first
    1000 of them. ``mesure_metric`` then writes the scores for Saliency,
    SmoothGrad, IntegratedGradients, GradientInput and GradCAM to
    ``<folder><expe_name>/curves/mufidelity.txt``.

    Args:
        model: Keras model to explain.
        data: dataset passed to ``get_data``.
        expe_name: experiment name used to build the output path.
        folder: results root (with trailing "/"); defaults to the
            script-level ``folder``.
        batch_size: batch size for the metric and explainers; defaults to
            the script-level ``batch_size``.
    """
    # Fall back to the module-level settings so existing callers keep working.
    if folder is None:
        folder = globals()["folder"]
    if batch_size is None:
        batch_size = globals()["batch_size"]
    X, Y = get_data(data, nb_images=2000)
    Y[Y == 0] = -1
    nb = 1000
    metric = MuFidelity(model, X[:nb], Y[:nb], batch_size, grid_size=9, nb_samples=50)
    explainers = [Saliency(model),
                  SmoothGrad(model, nb_samples=50, batch_size=batch_size),
                  # Rise(model, nb_samples=4000, batch_size=batch_size),
                  IntegratedGradients(model, steps=50, batch_size=batch_size),
                  GradientInput(model),
                  GradCAM(model)]
    out_dir = folder + expe_name + "/curves/"
    os.makedirs(out_dir, exist_ok=True)
    mesure_metric(model, X, Y, out_dir + "mufidelity.txt", metric, explainers, nb=nb)

def evaluate_mu_uniform(model, data, expe_name, folder=None, batch_size=None):
    """Score attributions with MuFidelity using a uniform-noise baseline.

    Up to 6000 samples are drawn with ``get_data``. Labels equal to 0 are
    remapped to -1. MuFidelity (grid 9, 50 samples) is built on the first
    1000 of them. Masked regions are replaced by uniform noise over the
    observed input range ``[X.min(), X.max())``. Scores are written by
    ``mesure_metric`` to ``<folder><expe_name>/curves/mufidelityunif.txt``.

    Args:
        model: Keras model to explain.
        data: dataset passed to ``get_data``.
        expe_name: experiment name used to build the output path.
        folder: results root (with trailing "/"); defaults to the
            script-level ``folder``.
        batch_size: batch size for the metric and explainers; defaults to
            the script-level ``batch_size``.
    """
    # Fall back to the module-level settings so existing callers keep working.
    if folder is None:
        folder = globals()["folder"]
    if batch_size is None:
        batch_size = globals()["batch_size"]
    X, Y = get_data(data, nb_images=6000)
    x_max, x_min = X.max(), X.min()
    Y[Y == 0] = -1

    def baseline_uniform(x):
        # Uniform noise spanning [x_min, x_max). The previous form
        # `u * (x_max - x_min) - x_min` was shifted by -2 * x_min.
        return np.random.uniform(low=x_min, high=x_max, size=x.shape)

    nb = 1000
    metric = MuFidelity(model, X[:nb], Y[:nb], batch_size, grid_size=9,
                        nb_samples=50, baseline_mode=baseline_uniform)
    explainers = [
        # Saliency(model),
        # SmoothGrad(model, nb_samples=50, batch_size=batch_size),
        # Rise(model, nb_samples=4000, batch_size=batch_size),
        IntegratedGradients(model, steps=50, batch_size=batch_size),
        # GradientInput(model),
        # GradCAM(model)
    ]
    out_dir = folder + expe_name + "/curves/"
    os.makedirs(out_dir, exist_ok=True)
    mesure_metric(model, X, Y, out_dir + "mufidelityunif.txt", metric, explainers, nb=nb)


def evaluate_stability(model, data, expe_name, logit=False, folder=None, batch_size=None):
    """Measure the AverageStability of Saliency maps and save the result.

    Up to 2000 samples are drawn with ``get_data``. Labels equal to 0 are
    remapped to -1. xplique's ``AverageStability`` is evaluated on the first
    1000 of them with the Spearman distance, 10 perturbations and radius
    ``0.15 * X.max()``. The result is written by ``mesure_stability`` to
    ``<folder><expe_name>/curves/stability_spear.txt``.

    Args:
        model: Keras model to explain.
        data: dataset passed to ``get_data``.
        expe_name: experiment name used to build the output path.
        logit: if True, set the model's last-layer activation to sigmoid
            before evaluating. This change persists on the passed model.
        folder: results root (with trailing "/"); defaults to the
            script-level ``folder``.
        batch_size: batch size for the metric; defaults to the
            script-level ``batch_size``.
    """
    # Fall back to the module-level settings so existing callers keep working.
    if folder is None:
        folder = globals()["folder"]
    if batch_size is None:
        batch_size = globals()["batch_size"]
    X, Y = get_data(data, nb_images=2000)
    x_max = X.max()
    if logit:
        model.layers[-1].activation = tf.keras.activations.sigmoid
    Y[Y == 0] = -1
    nb = 1000
    # NOTE(review): radius scales with X.max() only, whereas save_smooth
    # scales its noise by the full range (x_max - x_min); confirm intent.
    metric = AverageStability(model, X[:nb], Y[:nb], batch_size,
                              radius=0.15 * x_max, distance=spearman_dist,
                              nb_samples=10)
    explainers = [
        Saliency(model),
        # SmoothGrad(model, nb_samples=50, batch_size=batch_size),
        # IntegratedGradients(model, steps=50, batch_size=batch_size)
    ]
    out_dir = folder + expe_name + "/curves/"
    os.makedirs(out_dir, exist_ok=True)
    mesure_stability(model, X, Y, out_dir + "stability_spear.txt", metric, explainers, nb=nb)


# Script entry: the optional first CLI argument is the experiment path, e.g.
# "./results/cifar/cifar_cat_vs_dog". It is split into `folder` (parent
# directory, with trailing "/") and `filename` (experiment name). The functions
# above read `folder` and `batch_size` as module-level globals.
filename = "./results/cifar/cifar_cat_vs_dog"
if len(sys.argv) >= 2:
    filename = sys.argv[1]
# Drop one trailing slash so "a/b/" and "a/b" name the same experiment;
# endswith() is also safe on an empty string.
if filename.endswith("/"):
    filename = filename[:-1]
if "/" in filename:
    pos = filename.rindex("/")
    folder = filename[:pos + 1]
    filename = filename[pos + 1:]
else:
    # A bare experiment name previously crashed in rindex(); resolve it
    # relative to the current directory instead.
    folder = "./"

model, full_config = load_compiled_model(filename, vanilla=True, folder=folder)
dtset = load_dataset(getParams(full_config, 'dataset'))
batch_size = dtset['batch_size']
if 'test_XY' in dtset:
    X, Y = dtset['test_XY']
else:
    # Only the first test batch is used.
    test_iter = iter(dtset['test'])
    X, Y = next(test_iter)
    X = X.numpy()
    Y = Y.numpy()
# Binary labels are expected in {-1, 1} by the analyses below.
Y[Y == 0] = -1

# Uncomment the analyses to run:
#model_soft = add_softmax(model)
#evaluate_stability(model,dtset['test'],filename)
#save_saliency(model,X,Y,filename,nb = 100,nb_images = 1,folder = folder)
#save_smooth(model,X,Y,filename,nb = 30,nb_images = 1,folder = folder)
#save_gradient(model,X,Y,filename, nb = 100,folder = folder)
#evaluate_kolmo(model,dtset['test'],filename,nb_images=100)
#evaluate_dist(model,dtset['test'],filename,nb_images=100)

#mesure_mu_fidelity(model,X,Y,filename,nb = 1000,nb_samples=[50], grid_size=[9])
