"""
Combine TSR method with backpropagation based methods
"""
## ------------------------------------
## --- Standard Library/Third-Party ---
## ------------------------------------
import sys
import os
sys.path.append('../../')
fileDir = os.path.dirname(os.path.abspath(__file__))
parentDir = os.path.dirname(fileDir)
sys.path.append(parentDir)
from typing import List
import argparse
import numpy as np

### -----------
### --- Own ---
### -----------
from gun_point.evaluation.visual_interpretability import load_saliencymaps

def load_tsr_saliency(path2folder, experiments: list):
    """Load the rescaled TSR gradient file of every requested experiment.

    For each experiment name the file
    ``path2folder + "rescaled_grads_" + <name> + ".npy"`` is read with
    ``np.load``. ``path2folder`` is concatenated as-is, so it has to end with
    a path separator.

    Args:
        path2folder: Folder prefix that contains the ``.npy`` files.
        experiments: Experiment names; one file is loaded per name.

    Returns:
        A list with the loaded content of each file, in the order of
        ``experiments``.
    """
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only point
    # this at files from a trusted source.
    return [
        np.load(path2folder + "rescaled_grads_" + experiment + ".npy",
                allow_pickle=True)
        for experiment in experiments
    ]

def save_mod_saliencymaps(args, saliencymaps, experiment_names=None):
    """Save the combined saliency maps, one ``.npy`` file per experiment.

    Files are written with ``np.save`` to
    ``<parentDir>/results/<Dataset_name_save>/<DLModel>/`` and are named
    ``modsaliencymaps<experiment>.npy``.

    Args:
        args: Namespace providing ``Dataset_name_save`` and ``DLModel``,
            which select the output folder.
        saliencymaps: Sequence whose i-th entry is saved under the i-th
            experiment name.
        experiment_names: Experiment name, or list of names, used in the
            output file names. Defaults to ``["experiment_11"]``, the value
            this function used to hard-code, so existing callers get the
            same files as before.

    Raises:
        IndexError: If fewer saliency maps than experiment names are given.
    """
    if experiment_names is None:
        # NOTE(review): args.Experiments is intentionally not consulted here
        # to keep the previous output unchanged; pass experiment_names
        # explicitly to save under different names.
        experiment_names = ["experiment_11"]
    elif isinstance(experiment_names, str):
        # A bare string would otherwise be iterated character by character.
        experiment_names = [experiment_names]

    # Fail before anything is written instead of after a partial save.
    if len(saliencymaps) < len(experiment_names):
        raise IndexError(
            "got %d saliency maps for %d experiment names"
            % (len(saliencymaps), len(experiment_names))
        )

    path_2_save = "/".join(
        [parentDir, "results", args.Dataset_name_save, args.DLModel, ""]
    )
    for i, experiment in enumerate(experiment_names):
        np.save(path_2_save + "modsaliencymaps" + experiment + ".npy",
                saliencymaps[i])

def combine_tsr_with_visual(args,
                            saliency_maps,
                            tsr_maps):
    """Weight each backpropagation-based saliency map with its TSR map.

    For every index ``i`` of ``tsr_maps``, each supported method present in
    ``saliency_maps[i]`` is multiplied (``*``) with ``tsr_maps[i]``. Methods
    that are absent are skipped, and keys other than the supported ones are
    ignored.

    Args:
        args: Not used by this function.
        saliency_maps: Sequence of mappings from method name to saliency map,
            indexed in parallel with ``tsr_maps``.
        tsr_maps: Sequence of TSR maps, one per experiment.

    Returns:
        A list with one dict per entry of ``tsr_maps``, mapping each method
        name found to its product with the TSR map.
    """
    # Checked in this fixed order, which is also the key order of each
    # resulting dict.
    method_keys = (
        "grads",
        "smoothgrads",
        "igs",
        "lrp_epsilon",
        "gradCAM",
        "guided_gradcam",
        "guided_backprop",
    )

    combined = []
    for idx, tsr_weights in enumerate(tsr_maps):
        per_experiment = saliency_maps[idx]
        combined.append({
            method: per_experiment[method] * tsr_weights
            for method in method_keys
            if method in per_experiment.keys()
        })
    return combined


def parse_arguments(argv):
    """Parse the command-line options of this script.

    Args:
        argv: List of argument strings to parse (typically ``sys.argv[1:]``).
            ``None`` makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``Dataset_name``, ``Dataset_name_save``,
        ``Experiments`` (always a list of strings) and ``DLModel``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--Dataset_name", type=str, default='GunPointAgeSpan')
    parser.add_argument("--Dataset_name_save", type=str, default='GunPointAgeSpan_Cluster')
    # nargs='+' yields a list when the flag is given, so the default must be a
    # list too; a bare string would be indexed character by character.
    parser.add_argument("--Experiments", nargs='+', default=['experiment_13'])
    parser.add_argument("--DLModel", type=str, default='TCN_laststep')

    # Parse the arguments that were handed in rather than implicitly
    # re-reading sys.argv.
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: load the backpropagation-based saliency maps and the
    # TSR maps, multiply them, and save the combined maps.
    args = parse_arguments(sys.argv[1:])

    # Both kinds of maps are read from the per-dataset, per-model results folder.
    path2folder = "/".join(
        [parentDir, "results", args.Dataset_name_save, args.DLModel, ""]
    )

    # NOTE(review): the experiment list is pinned here, so the --Experiments
    # command-line option currently has no effect on what is loaded; confirm
    # whether that is intentional.
    experiments = ['experiment_11']

    saliencymaps = load_saliencymaps(path2folder, experiments)
    tsr_maps = load_tsr_saliency(path2folder, experiments)

    mod_saliencys = combine_tsr_with_visual(
        args,
        saliency_maps=saliencymaps,
        tsr_maps=tsr_maps,
    )
    save_mod_saliencymaps(args, mod_saliencys)