import os
import shutil
import sys
import time
sys.path.append('./')
sys.path.append('../')
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'
#os.environ['AUTOGRAPH_VERBOSITY'] = '1'
import tensorflow as tf
import tensorflow_datasets as tfds
tf.get_logger().setLevel('ERROR')

import numpy as np
import pickle
import cv2
from deel.datasets.imagenet_dataset import imagenet_dataset
from deel.utils.yaml_to_params import load_yaml_config,getParams, getFunctionFromModules, dumdict2yaml
from deel.utils.yaml_loader import load_model, loadFunctionList
from deel.utils.yaml_loader import load_optimizer_and_loss
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,SobolAttributionMethod,
                                  GradCAMPP, Lime, KernelShap)

from matplotlib import pyplot as plt
from xplique.plots import plot_attributions,plot_attribution
from deel.utils.lip_utils import explain_for_dummies, load_compiled_model,add_softmax,get_data, mesure_metric, spearman_dist, mesure_stability,evaluate_kolmo,evaluate_dist
from deel.datasets.imagenet_dataset import normalize_vgg
from harmonization.common import load_clickme_val
from harmonization.evaluation import evaluate_clickme
from xplique.attributions.base import WhiteBoxExplainer, sanitize_input_output
from xplique.commons import batch_gradient
from xplique.types import Optional, Union
from xplique.metrics import MuFidelity, AverageStability

class SaliencyL2(WhiteBoxExplainer):
    """
    Gradient saliency explainer that collapses the channel axis with a norm.

    The gradient of the model output with respect to the input is obtained with
    `batch_gradient`. When that gradient has 4 dimensions, its last axis is reduced
    with `tf.norm(..., ord=1)`; gradients of any other rank are returned unchanged.

    Ref. Simonyan & al., Deep Inside Convolutional Networks: Visualising Image Classification
    Models and Saliency Maps (2013).
    https://arxiv.org/abs/1312.6034

    Notes
    -----
    NOTE(review): the class is called "L2" but the reduction below uses `ord=1`
    (sum of absolute values over the last axis), not a Euclidean norm and not the
    channel-wise maximum described in the reference paper. Confirm which one is intended.

    Parameters
    ----------
    model
        The model from which we want to obtain explanations
    output_layer
        Layer to target for the outputs (e.g logits or after softmax).
        If an `int` is provided it will be interpreted as a layer index.
        If a `string` is provided it will look for the layer name.
        Default to the last layer.
        It is recommended to use the layer before Softmax.
    batch_size
        Number of inputs to explain at once, if None compute all at once.
    """

    @sanitize_input_output
    def explain(self,
                inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
                targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> tf.Tensor:
        """
        Compute the norm-reduced gradient maps for a batch of samples.

        Parameters
        ----------
        inputs
            Dataset, Tensor or Array. Input samples to be explained.
            If Dataset, targets should not be provided (included in Dataset).
        targets
            Tensor or Array. One-hot encoding of the model's output from which an
            explanation is desired, one encoding per input.

        Returns
        -------
        explanations
            The gradients, with the last axis reduced by an order-1 norm when the
            gradient tensor is 4-dimensional.
        """
        grads = batch_gradient(self.model, inputs, targets, self.batch_size)

        # Only 4-D gradients have a trailing axis to collapse; everything else is
        # handed back as computed.
        if len(grads.shape) != 4:
            return grads

        return tf.norm(grads, ord=1, axis=-1)

class SaliencyCombine(WhiteBoxExplainer):
    """
    Saliency map combining the gradients of two model outputs: the target class
    and the highest-scoring class among the others.

    For every sample the explainer:
      1. evaluates the model and masks out the target class to find the best
         "other" output,
      2. differentiates the target score and the best-other score with respect
         to the input,
      3. averages both gradients (weights 0.5 / 0.5) and keeps the maximum
         absolute value over axis 3.

    Notes
    -----
    The constructor is inherited from `WhiteBoxExplainer`; this class only reads
    `self.model`. `self.batch_size` is not used: the whole input is processed in
    a single forward/backward pass.
    """

    @sanitize_input_output
    def explain(self,
                inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
                targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> tf.Tensor:
        """
        Compute the combined saliency maps for a batch of samples.

        Parameters
        ----------
        inputs
            Input samples to be explained. The reduction over `axis=3` means the
            gradients (hence the inputs) must be 4-dimensional.
        targets
            One-hot encoding of the class to explain, one row per input.

        Returns
        -------
        explanations
            NumPy array holding, for each sample, the channel-wise maximum of
            `|0.5 * grad(target score) + 0.5 * grad(best other score)|`.
        """
        x = inputs
        y_true = targets
        y_pred = self.model(x, training=False).numpy().squeeze()

        # Push the target class to the most negative float so that the row-wise
        # maximum picks the best of the remaining classes.
        y_pred_second = tf.where(y_true == 1, -tf.float32.max, y_pred)
        max_others = tf.reduce_max(y_pred_second, axis=1)
        # Mask selecting every output equal to that runner-up value.
        # NOTE(review): exact float equality, so ties select several classes.
        y_true_second = tf.where(y_pred == tf.expand_dims(max_others, 1), 1., 0.)

        grads_class = self._gradient_of_masked_score(x, y_true)
        grads_other = self._gradient_of_masked_score(x, y_true_second)

        return np.max(np.abs(0.5 * grads_class + 0.5 * grads_other), axis=3)

    def _gradient_of_masked_score(self, x, class_mask):
        """Return d(sum(model(x) * class_mask, axis=1)) / dx as a NumPy array."""
        with tf.GradientTape() as tape:
            tape.watch(x)
            score = tf.reduce_sum(
                tf.multiply(self.model(x, training=False), class_mask), axis=1)
        return tape.gradient(score, x).numpy()


class SaliencyCombineInv(WhiteBoxExplainer):
    """
    Saliency map of the highest-scoring class other than the target, computed on
    inputs that are first passed through `normalize_vgg`.

    For every sample the explainer normalises the input, finds the best model
    output once the target class is masked out, differentiates that output with
    respect to the normalised input and keeps the maximum absolute gradient over
    axis 3.

    Notes
    -----
    The constructor is inherited from `WhiteBoxExplainer`; this class only reads
    `self.model`. `self.batch_size` is not used: the whole input is processed in
    a single forward/backward pass.
    """

    @sanitize_input_output
    def explain(self,
                inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray],
                targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> tf.Tensor:
        """
        Compute the runner-up-class saliency maps for a batch of samples.

        Parameters
        ----------
        inputs
            Raw (not yet normalised) input samples; `normalize_vgg` is applied
            here. The reduction over `axis=3` requires 4-dimensional gradients.
        targets
            One-hot encoding of the target class, one row per input. It is only
            used to exclude that class when looking for the runner-up.

        Returns
        -------
        explanations
            NumPy array with the channel-wise maximum of the absolute gradient of
            the runner-up score.
        """
        # tf.GradientTape.watch only accepts tensors, so make sure the normalised
        # input is one whatever normalize_vgg returns.
        x = tf.convert_to_tensor(normalize_vgg(np.array(inputs)))
        y_true = targets
        y_pred = self.model(x, training=False).numpy().squeeze()

        # Push the target class to the most negative float so that the row-wise
        # maximum picks the best of the remaining classes.
        y_pred_second = tf.where(y_true == 1, -tf.float32.max, y_pred)
        max_others = tf.reduce_max(y_pred_second, axis=1)
        # NOTE(review): exact float equality, so ties select several classes.
        y_true_second = tf.where(y_pred == tf.expand_dims(max_others, 1), 1., 0.)

        # Only the runner-up gradient contributes to the result, so the target
        # class gradient is not computed at all.
        with tf.GradientTape() as tape:
            tape.watch(x)
            score = tf.reduce_sum(
                tf.multiply(self.model(x, training=False), y_true_second), axis=1)
        grads_other = tape.gradient(score, x).numpy()

        return np.max(np.abs(grads_other), axis=3)

class SaliencyInv(Saliency):
    """`Saliency` explainer that runs `normalize_vgg` on the inputs before explaining."""

    def explain(self,inputs,targets):
        """
        Normalise `inputs` and delegate to `Saliency.explain`.

        `inputs` is converted to a NumPy array, passed through `normalize_vgg`,
        and the parent explanation for `targets` is returned.
        """
        normalized = normalize_vgg(np.array(inputs))
        return super().explain(normalized, targets)

def preprocess_images(x):
    """Return `x` after it has been passed through `normalize_vgg`."""
    normalized = normalize_vgg(x)
    return normalized

def inverse_process(x) :
    """
    Undo the mean subtraction on a batch of images and reverse the channel order.

    The per-channel means (123.68, 116.779, 103.939) are added back to channels
    0, 1 and 2 of the last axis, the three channels are then reordered as
    [2, 1, 0] and the values are clipped to [0, 255].

    The input is never modified: the work is done on a copy. Non floating-point
    inputs (e.g. integer arrays) are promoted to float64 so the means can be
    added without a casting error.

    Parameters
    ----------
    x
        Array-like whose last axis holds at least 3 channels.

    Returns
    -------
    numpy.ndarray
        New array with the 3 restored channels in reversed order, same floating
        dtype as the input when the input is floating point.
    """
    mean = [123.68, 116.779,103.939 ]
    src = np.asarray(x)
    # Keep the caller's float precision, promote everything else to float64.
    work_dtype = src.dtype if np.issubdtype(src.dtype, np.floating) else np.float64
    restored = np.array(src, dtype=work_dtype)  # np.array copies: caller's data stays intact

    restored[..., 0] += mean[0]
    restored[..., 1] += mean[1]
    restored[..., 2] += mean[2]

    restored = restored[..., [2,1,0]]
    return np.clip(restored, a_min=0, a_max=255)

def inverse_process_grad(x) :
    """
    Convert a mean-subtracted tensor into an image scaled to [0, 1].

    `x.numpy()` is taken, the per-channel means (123.68, 116.779, 103.939) are
    added to channels 0, 1 and 2 of the last axis, the result is clipped to
    [0, 255] and divided by 255. The channel order is left as it is.
    """
    channel_means = (123.68, 116.779, 103.939)
    pixels = x.numpy()
    for channel, offset in enumerate(channel_means):
        pixels[..., channel] += offset
    return np.clip(pixels, a_min=0, a_max=255) / 255.
def data_range(data):
    """
    Print the shape and the min / mean / max of the first batch of `data`.

    `data` must be iterable and yield `(x, y)` pairs whose elements expose
    `.numpy()`. "ok1" / "ok2" are printed around the fetch of the first batch.
    """
    batches = iter(data)
    print("ok1")
    first_x, first_y = next(batches)
    print("ok2")
    images = first_x.numpy()
    _ = first_y.numpy()  # labels are materialised but not inspected
    print(images.shape)
    print(images.min(), images.mean(), images.max())
def data_stats(data,prefix = "val"):
    """
    Print statistics of the first batch of `data` and dump its first 5 images.

    The per-channel mean and standard deviation (computed over every axis but
    the last one) are printed together with the global min and max. The batch is
    then passed through `inverse_process` and images 0..4 are written with
    `cv2.imwrite` to "images/<prefix>_<i>.jpeg".
    """
    batches = iter(data)
    print("ok1")
    first_x, first_y = next(batches)
    print("ok2")
    images = first_x.numpy()
    _ = first_y.numpy()  # labels are materialised but not inspected
    print(images.shape)

    # Every axis except the trailing (channel) one.
    leading_axes = tuple(range(images.ndim-1))
    print(images.mean(axis=leading_axes), images.min(), images.max(), images.std(axis=leading_axes))

    restored = inverse_process(images)
    for i in range(5):
        target_path = "images/" + prefix + "_" + str(i) + ".jpeg"
        cv2.imwrite(target_path, restored[i])



def evaluate_results(models, data):
    """
    Print, batch by batch, the accuracy and the mean maximum logit of a model.

    Parameters
    ----------
    models
        The model to evaluate: a callable invoked as `models(x, training=False)`
        whose result exposes `.numpy()`. (The parameter keeps its historical
        plural name; a single model is expected.)
    data
        Iterable yielding `(x, y)` batches where `y` exposes `.numpy()` and is
        one-hot encoded (the class is recovered with `argmax` over axis 1).

    Notes
    -----
    At most 500 batches are processed; iteration simply stops earlier when
    `data` is shorter. The forward pass runs with `training=False` so that the
    printed figures are inference-mode results.
    """
    # zip with range caps the loop at 500 batches without raising StopIteration
    # on shorter datasets.
    for _, (x, y) in zip(range(500), data):
        logits = models(x, training=False).numpy()
        labels = y.numpy()
        correct = logits[np.argmax(labels, axis=1) == np.argmax(logits, axis=1)]
        print(correct.shape[0]/logits.shape[0], logits.max(axis=1).mean())
def save_gradcam(model,data,expe_name,nb_images=10):
    """
    Save SmoothGrad attribution figures for the first batch of `data`.

    Despite its name, this function builds a `SmoothGrad` explainer (100 samples,
    noise equal to 10% of the batch value range). The first batch is split into
    consecutive groups of `nb_images` samples; each group is explained, overlaid
    on the images restored with `inverse_process`, and written to
    "./results/<expe_name>/curves/smoothgrad/smoothg_<i>.jpg".
    """
    batches = iter(data)
    out_dir = "./results/"+expe_name+"/curves/smoothgrad/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    X, Y = next(batches)
    X_values = X.numpy()
    value_range = X_values.max() - X_values.min()
    nb_groups = int(X.shape[0]//nb_images)
    explainer = SmoothGrad(model, nb_samples=100, batch_size=100, noise=0.1*value_range)

    for i in range(nb_groups):
        print("Saliency n°",i)
        window = slice(i*nb_images, i*nb_images + nb_images)
        x_group = X[window]
        y_group = Y[window]
        explanations = explainer.explain(x_group.numpy(), y_group.numpy())
        # Collapse a trailing channel axis, if any, into a single map per image.
        if len(explanations.shape) > 3:
            explanations = np.mean(explanations, axis=3)
        restored = inverse_process(x_group.numpy())
        plot_attributions(explanations, restored, cmap='jet', alpha=0.4, cols=nb_images,
                          absolute_value=True, clip_percentile=0.)
        plt.savefig(f"{out_dir}smoothg_{i}.jpg", bbox_inches='tight')
        plt.close()
def save_saliency(model,data,expe_name,nb_images=10):
    """
    Save `Saliency` attribution figures for the first batch of `data`.

    The first batch is split into consecutive groups of `nb_images` samples.
    Each group is explained, overlaid on the images restored with
    `inverse_process` and written to
    "./results/<expe_name>/curves/saliency/saliency_<i>.jpg". A group of a
    single image is drawn with `plot_attribution`, larger groups with
    `plot_attributions`.
    """
    batches = iter(data)
    out_dir = "./results/"+expe_name+"/curves/saliency/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    X, Y = next(batches)
    nb_groups = int(X.shape[0]//nb_images)
    explainer = Saliency(model)

    for i in range(nb_groups):
        print("Saliency n°",i)
        window = slice(i*nb_images, i*nb_images + nb_images)
        x_group = X[window]
        y_group = Y[window]
        explanations = explainer.explain(x_group.numpy(), y_group.numpy())
        restored = inverse_process(x_group.numpy())
        if nb_images != 1:
            plot_attributions(explanations, restored, cmap='jet', alpha=0.4, cols=nb_images,
                              absolute_value=True, clip_percentile=0)
        else:
            plot_attribution(explanations[0,:,:], restored[0,:,:,:], cmap='jet', alpha=0.4,
                             absolute_value=True, clip_percentile=0)
        plt.savefig(f"{out_dir}saliency_{i}.jpg", bbox_inches='tight', pad_inches=0)
        plt.close()
def save_saliency_combine(model,data,expe_name,nb_images=10):
    """
    Save `SaliencyCombine` attribution figures for the first batch of `data`.

    The first batch is split into consecutive groups of `nb_images` samples.
    Each group is explained, overlaid on the images restored with
    `inverse_process` and written to
    "./results/<expe_name>/curves/grads_comb/saliency_comb_<i>.jpg".
    """
    batches = iter(data)
    out_dir = "./results/"+expe_name+"/curves/grads_comb/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    X, Y = next(batches)
    nb_groups = int(X.shape[0]//nb_images)
    explainer = SaliencyCombine(model)

    for i in range(nb_groups):
        print("Saliency n°",i)
        window = slice(i*nb_images, i*nb_images + nb_images)
        x_group = X[window]
        y_group = Y[window]
        explanations = explainer.explain(x_group.numpy(), y_group.numpy())
        restored = inverse_process(x_group.numpy())
        plot_attributions(explanations, restored, cmap='jet', alpha=0.4, cols=nb_images,
                          absolute_value=True)
        plt.savefig(f"{out_dir}saliency_comb_{i}.jpg", bbox_inches='tight')
        plt.close()

def save_saliency_l2(model,data,expe_name,nb_images=10):
    """
    Save `SaliencyL2` attribution figures for the first batch of `data`.

    The first batch is split into consecutive groups of `nb_images` samples.
    Each group is explained, overlaid on the images restored with
    `inverse_process` and written to
    "./results/<expe_name>/curves/grads_l2/saliency_<i>.jpg".
    """
    batches = iter(data)
    out_dir = "./results/"+expe_name+"/curves/grads_l2/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    X, Y = next(batches)
    nb_groups = int(X.shape[0]//nb_images)
    explainer = SaliencyL2(model)

    for i in range(nb_groups):
        print("Saliency n°",i)
        window = slice(i*nb_images, i*nb_images + nb_images)
        x_group = X[window]
        y_group = Y[window]
        explanations = explainer.explain(x_group.numpy(), y_group.numpy())
        restored = inverse_process(x_group.numpy())
        plot_attributions(explanations, restored, cmap='jet', alpha=0.4, cols=nb_images,
                          absolute_value=True)
        plt.savefig(f"{out_dir}saliency_{i}.jpg", bbox_inches='tight')
        plt.close()
def save_gradient(model,data,expe_name, nb_images = 10):
    """
    Run `explain_for_dummies` on the first images of the first batch of `data`.

    For each of the first `min(nb_images, batch size)` samples,
    `explain_for_dummies` is called with `counter_coeff=0.5`, an output file
    prefix "./results/<expe_name>/curves/grads/grad_<i>" and
    `inverse_process_grad` as post-processing. The output directory is created
    when missing. Nothing is returned.
    """
    rep = "./results/"+expe_name+"/curves/grads/"
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(rep, exist_ok=True)

    X, Y = next(iter(data))
    for i in range(min(nb_images, X.shape[0])):
        explain_for_dummies(model, X[i], Y[i], counter_coeff=0.5,
                            filename=rep + "grad_" + str(i),
                            post_process=inverse_process_grad)
    


def evaluate_mu_zero(model,data,expe_name):
    """
    Measure MuFidelity (default baseline) for a set of attribution methods.

    2000 samples are requested from `get_data`; the metric is built on the first
    1000. Saliency, SmoothGrad, IntegratedGradients, GradientInput and GradCAM
    are handed to `mesure_metric`, which receives
    "./results/<expe_name>/curves/mufidelity_zero.txt" as output path.
    Relies on the module-level `batch_size`.
    """
    images, labels = get_data(data, nb_images=2000)
    highest, lowest = images.max(), images.min()
    print(images.shape)
    nb = 1000

    metric = MuFidelity(model, images[:nb], labels[:nb], batch_size, grid_size=9, nb_samples=50)

    # SmoothGrad noise is 10% of the observed value range.
    smoothgrad_noise = 0.1*(highest-lowest)
    explainers = [
        Saliency(model),
        SmoothGrad(model, nb_samples=50, batch_size=batch_size, noise=smoothgrad_noise),
        IntegratedGradients(model, steps=50, batch_size=batch_size),
        GradientInput(model),
        GradCAM(model),
    ]

    report_path = "./results/"+expe_name+"/curves/mufidelity_zero.txt"
    mesure_metric(model, images, labels, report_path, metric, explainers, nb=nb)

def evaluate_mu_unif(model,data,expe_name):
    """
    Measure MuFidelity with a uniform random baseline for several attribution methods.

    2000 samples are requested from `get_data`; the metric is built on the first
    100. The baseline passed to `MuFidelity` draws, for every input, values
    uniformly in `[x_min, x_max]`, the value range observed on the loaded data.
    Saliency, SmoothGrad, IntegratedGradients, GradientInput and GradCAM are
    handed to `mesure_metric`, which receives
    "./results/<expe_name>/curves/mufidelity_unif.txt" as output path.
    Relies on the module-level `batch_size`.
    """
    X,Y = get_data(data,nb_images = 2000)
    x_max, x_min = X.max(), X.min()
    print(X.shape)
    nb = 100

    def baseline_uniform(x):
        # Uniform noise over the data range. The previous expression,
        # U * (x_max - x_min) - x_min, shifted by the wrong sign and produced
        # values outside [x_min, x_max] whenever x_min != 0.
        return np.random.uniform(x_min, x_max, size=x.shape)

    metric = MuFidelity(model, X[:nb], Y[:nb],  batch_size, grid_size = 9,nb_samples=50,baseline_mode = baseline_uniform)
    explainers = [
                Saliency(model),
                SmoothGrad(model, nb_samples=50, batch_size=batch_size, noise = 0.1*(x_max-x_min)),
                IntegratedGradients(model, steps=50, batch_size=batch_size),
                GradientInput(model),
                GradCAM(model)
                ]
    mesure_metric(model,X,Y, "./results/"+expe_name+"/curves/mufidelity_unif.txt",metric,explainers,nb = nb)

def evaluate_stability(model,data,expe_name):
    """
    Measure AverageStability (Spearman distance) for several attribution methods.

    1000 samples are requested from `get_data` and all of them feed the metric,
    which uses a radius of 15% of the maximum input value and 10 samples.
    Saliency, SmoothGrad and IntegratedGradients are handed to
    `mesure_stability`, which receives
    "./results/<expe_name>/curves/stability_spear.txt" as output path.
    Relies on the module-level `batch_size`.
    """
    images, labels = get_data(data, nb_images=1000)
    highest, lowest = images.max(), images.min()
    print(f"data range : [{lowest},{highest}]")
    nb = 1000

    metric = AverageStability(model, images[:nb], labels[:nb], batch_size,
                              radius=.15*highest, distance=spearman_dist, nb_samples=10)
    explainers = [
        Saliency(model),
        SmoothGrad(model, nb_samples=50, batch_size=batch_size),
        IntegratedGradients(model, steps=50, batch_size=batch_size),
    ]

    report_path = "./results/"+expe_name+"/curves/stability_spear.txt"
    mesure_stability(model, images, labels, report_path, metric, explainers, nb=nb)

# ---------------------------------------------------------------------------
# Script section: executed at import time. Loads the datasets and the model
# used by the (mostly commented-out) experiments further down.
# ---------------------------------------------------------------------------

# ClickMe validation data; not referenced by any active statement in this section.
clickme_dataset = load_clickme_val(batch_size = 128)
# Module-level batch size: read as a global by evaluate_mu_zero, evaluate_mu_unif
# and evaluate_stability, and used for the ImageNet loading below.
batch_size = 256
#print("load 1")
#train, val, info = imagenet_dataset(batch_size = batch_size,
#                                              preprocess = "VGG",
                                              #compute_train_val=True,
#                                              shuffle = 0)
print("analyse 1")

#data_range(train)
#data_range(val)
#print("load 2")
# ImageNet splits with "VGG" preprocessing and shuffling disabled.
train, val,train_val, info = imagenet_dataset(batch_size = batch_size,
                                              preprocess =  "VGG",
                                              compute_train_val=True,
                                              shuffle_files = False,
                                              shuffle = 0)
#data_stats(val,prefix = "valinv")
#print("analyse 2")
#data_stats(train,prefix = "train-2")
#data_stats(val,prefix = "val")
#data_stats(train_val,prefix = "train_val")

# Experiment name and optional suffix: defaults below, overridden by the first
# and second command-line arguments when present.
expe_name = "resnet50_auto_nograd_d"
suffix = ""
if len(sys.argv)>=2:
    expe_name =sys.argv[1]
if len(sys.argv)>=3:
    suffix =sys.argv[2]
print(expe_name,suffix)
# The second value returned by load_compiled_model is ignored.
model,_ = load_compiled_model(expe_name,vanilla=True,softmax = False,suffix =suffix)
model.summary()
# Explainer that normalises its inputs itself (see SaliencyInv); only used by
# the commented-out evaluate_clickme call below.
explainer = SaliencyInv(model)
# Directory of the saved models for this experiment (used by commented-out code).
rep = "./results/"+expe_name+"/models/"
#model.save(rep+expe_name+"vanilla.h5")
#evaluate_mu_unif(model,val,expe_name)
#evaluate_mu_zero(model,val,expe_name)
#evaluate_stability(model,val,expe_name)
#evaluate_kolmo(model,val,expe_name, nb_images=100)
#evaluate_dist(model,val,expe_name, nb_images=100)
#
#vanilla = model.vanilla_export()
#model = tf.keras.models.load_model("results/resnet_50_save/models/resnet_50_save_vanilla/")
#model.compile(loss='categorical_crossentropy',
#                optimizer='adam',
#                metrics=['categorical_accuracy','top_k_categorical_accuracy'])
#results = model.evaluate(val, steps = 100000//batch_size)



#model_soft = add_softmax(model)
#save_saliency(model,val,expe_name,nb_images=1)
#save_gradcam(model,val,expe_name,nb_images=1)
#save_gradient(model,val,expe_name,nb_images=100)
#model_soft.summary()

#save_saliency_combine(model,val,expe_name,nb_images=10)
# scores = evaluate_clickme(model, 
#                           explainer = explainer,
#                           preprocess_inputs=None)
# print('alignement :',scores['alignment_score'])
# f = open("./results/"+expe_name+"/curves/aligement.txt", "w")
# f.write(f"alignement :{scores['alignment_score']:0.3f}")
# f.write('\n')
# f.flush()
# f.close()
#save_gradcam(model_soft,val,expe_name,nb_images=10)

#save_gradient(model,val,expe_name,nb_images=256)
#print("train")
#evaluate_results(model, train)
#print("val")
#evaluate_results(model, val)
#print("trainval")
#evaluate_results(model, val)
#print(results)
#nb = 1000,nb_samples=[128], grid_size=[9])