"""
Explanation Sensitivity Measures
Sensitivity-Max
References: https://arxiv.org/pdf/1901.09392.pdf
also from Captum: https://captum.ai/api/metrics.html

"""
## -------------------
## --- Third-Party ---
## -------------------
import os
import sys
sys.path.append('../../')
fileDir = os.path.dirname(os.path.abspath(__file__))
parentDir = os.path.dirname(fileDir)
sys.path.append(parentDir)
import argparse
import numpy as np
import torch as t

## -----------
## --- Own ---
## -----------
from visual_interpretability import load_data_and_models
from utils import throw_out_wrong_classified
from visualize_mechanism.visual_func import SaliencyFunctions
from visualize_mechanism.tsr import tsr_from_saliencymap
from metrics.sensitivity_max import sensitivity_test
from trainhelper.dataset import Dataset


def sensitivity_test_eval(args,
                          model,
                          dataset,
                          vis_methods: list,
                          use_tsr: bool = False,
                          save_scores_path: str = None
                          ):
    ## setting args
    perturb_radius = args.Perturb_radius
    n_perturb_samples = args.N_perturb_samples

    root_dir = parentDir + '/../'
    # dataset_name = args.Dataset_name
    dataset_name_save = args.Dataset_name_save
    dl_selected_model = args.DLModel
    experiment = args.Experiments
    # experiment = ["experiment_10"]
    experiment = experiment[0]

    ## clean dataset
    normal_data, normal_labels = throw_out_wrong_classified(model=model,
                                                            data=dataset.data,
                                                            labels=dataset.labels,
                                                            )
    device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    normal_data = t.tensor(normal_data).float().to(device)
    normal_labels = t.tensor(normal_labels)
    normal_dataset = Dataset(normal_data, normal_labels)

    ## get tsr
    if use_tsr:
        rescaledSaliency = tsr_from_saliencymap(
            samples=dataset.data,
            labels=dataset.labels,
            dl_model=model,
            time_groups=5,
            feature_groups=1,
            threshold=0.5
        )
        rescaledSaliency = t.tensor(rescaledSaliency).float().to(device)
        rescaledSaliency = rescaledSaliency.unsqueeze(0)
    else:
        rescaledSaliency = None

    ## set saliency functions
    saliency_functions = SaliencyFunctions(model=model,
                                           tsr_saliency=rescaledSaliency,
                                           device=device)

    ## compute sensitivity scores
    sensitivity_scores = {}
    for method in vis_methods:
        sensitivity_scores[method] = t.zeros(len(normal_labels))
    for i in range(len(normal_labels)):
        sample = normal_dataset.data[i]
        sample = sample.reshape(1, sample.shape[0], sample.shape[1])

        if "grads" in vis_methods:
            sensitivity_scores["grads"][i] = sensitivity_test(
                explanation_function=saliency_functions.getGradientSaliency,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch=n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False
            )
        if "integrated_gradients" in vis_methods:
            sensitivity_scores["integrated_gradients"][i] = sensitivity_test(
                explanation_function=saliency_functions.getIntegratedGradients,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch=n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False,
                ig_steps=60
            )
        if "smoothgrads" in vis_methods:
            sensitivity_scores["smoothgrads"][i] = sensitivity_test(
                explanation_function=saliency_functions.getSmoothGradients,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch=n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False,
                nt_samples=60,
                stdevs=0.2
            )
        if "lrp_epsilon" in vis_methods and dl_selected_model not in ["LSTM"]:
            sensitivity_scores["lrp_epsilon"][i] = sensitivity_test(
                explanation_function=saliency_functions.getLRP,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch=n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False,
                rule="epsilon"
            )
        elif "lrp_epsilon" in vis_methods and dl_selected_model in ["LSTM"]:
            sensitivity_scores["lrp_epsilon"][i] = sensitivity_test(
                explanation_function=saliency_functions.getLRP4LSTM,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch=n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False
            )
        if "lrp_gamma" in vis_methods:
            sensitivity_scores["lrp_gamma"][i] = sensitivity_test(
                explanation_function=saliency_functions.getLRP,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch=n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False,
                rule="gamma"
            )
        if "gradCAM" in vis_methods:
            sensitivity_scores["gradCAM"][i] = sensitivity_test(
                explanation_function=saliency_functions.getGradCAM,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch=n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False,
                use_relu=True,
                upsample_to_input=True,
                layer_to_grad="gap_softmax.conv1",
                attribute_to_layer_input=False
            )
        if "g_gradCAM" in vis_methods:
            sensitivity_scores["g_gradCAM"][i] = sensitivity_test(
                explanation_function=saliency_functions.getGuidedGradCAM,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch = n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False,
                layer_to_grad="gap_softmax.conv1",
                attribute_to_layer_input=False
            )
        if "gbp" in vis_methods:
            sensitivity_scores["gbp"][i] = sensitivity_test(
                explanation_function=saliency_functions.getGuidedBackprop,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch=n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False
            )
        if "lime" in vis_methods:
            sensitivity_scores["lime"][i] = sensitivity_test(
                explanation_function=saliency_functions.getLIME,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch=n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False,
                n_sample=1000,
                num_features=50,
                baseline="total_mean",
                kernel_width=5.0
            )
        if "kernelShap" in vis_methods:
            sensitivity_scores["kernelShap"][i] = sensitivity_test(
                explanation_function=saliency_functions.getKernelSHAP,
                inputs=sample,
                perturb_radius=perturb_radius,
                n_perturb_samples=n_perturb_samples,
                max_examples_per_batch=n_perturb_samples,
                label=normal_dataset.labels[i],
                absolute=False,
                n_sample=1000,
                baseline="total_mean",
                num_features=50
            )
    if save_scores_path is not None:
        path_2_save = root_dir + "results/" + dataset_name_save + "/" + dl_selected_model + "/" + experiment + "/"
        name_scores = path_2_save + experiment + "_" + save_scores_path
        np.save(name_scores, sensitivity_scores)
    mean_sensitivity_score = {}
    std_sensitivity_score = {}
    for key in sensitivity_scores.keys():
        mean_sensitivity_score[key] = np.mean(sensitivity_scores[key].cpu().detach().numpy())
        std_sensitivity_score[key] = np.std(sensitivity_scores[key].cpu().detach().numpy())

    return mean_sensitivity_score, std_sensitivity_score

def parse_arguments(argv):
    parser = argparse.ArgumentParser()

    # parser.add_argument("--Root_Dir", type=str, default='../')
    parser.add_argument("--Dataset_name", type=str, default='GunPointAgeSpan')
    parser.add_argument("--Dataset_name_save", type=str, default='GunPointAgeSpan_Cluster')
    parser.add_argument("--Experiments", nargs='+', default='experiment_0')
    parser.add_argument("--DLModel", type=str, default='LSTM')

    parser.add_argument("--Perturb_radius", type=float, default=0.05)
    parser.add_argument("--N_perturb_samples", type=int, default=10)
    parser.add_argument("--Use_tsr", action="store_true", default=True)
    parser.add_argument("--Save_scores_path", type=str, default='sensitivity_max_scores.npy')

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_arguments(sys.argv[1:])

    print("Load Data and Model")
    testsets, models, model_softmaxs, saliency_constructor_gcs, saliency_constructors = load_data_and_models(args=args)

    dl_selected_model = args.DLModel

    device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    if dl_selected_model not in ['LSTM', "MLP"]:
        # methods = ["lrp_epsilon"]
        # methods = ["grads",
        #            "smoothgrads",
        #            "integrated_gradients",
        #            "lrp_epsilon",
        #            "lrp_gamma",
        #            "gradCAM",
        #            "g_gradCAM",
        #            "gbp",
        #            "lime",
        #            "kernelShap"]
        methods = ["grads",
                   "smoothgrads",
                   "integrated_gradients",
                   "lrp_epsilon",
                   # "lrp_gamma",
                   "gradCAM",
                   "g_gradCAM",
                   "gbp"]
    else:
        # methods = ["grads",
        #            "smoothgrads",
        #            "integrated_gradients",
        #            "lrp_epsilon",
        #            "lime",
        #            "kernelShap"]
        methods = ["grads",
                   "smoothgrads",
                   "integrated_gradients",
                   "lrp_epsilon"]

    save_scores_path = args.Save_scores_path
    use_tsr = args.Use_tsr
    mean_scores, std_scores = sensitivity_test_eval(args=args,
                                                    model=models[0],
                                                    dataset=testsets[0],
                                                    vis_methods=methods,
                                                    use_tsr=use_tsr,
                                                    save_scores_path=save_scores_path)
