## Evaluates saliency maps by computing gap scores when contiguous temporal sequences ranked
## as important are deleted (deletion mode: `swap`, `mean`, or `zero`).
## Reference: https://arxiv.org/abs/1909.07082

## import libraries
## -------------------
## --- Third-Party ---
## -------------------
import os
import sys
sys.path.append('../')
sys.path.append('../../')
fileDir = os.path.dirname(os.path.abspath(__file__))
parentDir = os.path.dirname(fileDir)
sys.path.append(parentDir)
from typing import List
import argparse
import numpy as np
import pandas as pd
import torch as t
import matplotlib.pyplot as plt

### -----------
### --- Own ---
### -----------
from utils import load_saliencies, clean_saliency_list
from gun_point.evaluation.temporal_sequence_importance import temporal_sequence_evaluate
from temporal_saliency_importance import load_data_and_models


def parse_arguments(argv=None):
    """Parse the command-line options of the saliency-evaluation script.

    Args:
        argv: List of argument strings to parse, typically ``sys.argv[1:]``.
            ``None`` makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    # --- data / preprocessing options ---
    parser.add_argument("--Dataset_name", type=str, default='tool_tracking_Cluster')
    parser.add_argument("--Dataset_name_save", type=str, default='tool_tracking_Cluster')
    parser.add_argument("--Data_path", type=str, default='data/tool_tracking_data')
    parser.add_argument("--Detection", action="store_true", default=False)
    # These three flags default to True, so `store_true` alone could never switch them
    # off; each paired `--No_*` switch writes False into the same destination.
    parser.add_argument("--Znorm", action="store_true", default=True)
    parser.add_argument("--No_Znorm", dest="Znorm", action="store_false")
    parser.add_argument("--One_matrix", action="store_true", default=True)
    parser.add_argument("--No_One_matrix", dest="One_matrix", action="store_false")
    parser.add_argument("--Sparse_labels", action="store_true", default=True)
    parser.add_argument("--No_Sparse_labels", dest="Sparse_labels", action="store_false")
    parser.add_argument("--Window_length", type=float, default=0.2)
    parser.add_argument("--Overlap", type=float, default=0.5)

    # --- evaluation options ---
    # A list default keeps the type consistent with nargs='+', which always yields a list.
    parser.add_argument("--Experiments", nargs='+', default=['experiment_0'])
    parser.add_argument("--DLModel", type=str, default='TCN_laststep')
    parser.add_argument("--Evaluation_mode", type=str, default='swap')
    parser.add_argument("--Evaluation_length", type=float, default=0.1)
    parser.add_argument("--Batch_size", type=int, default=1)
    parser.add_argument("--Verbose", type=int, default=1)
    parser.add_argument("--TypeofSaliency", type=str, default='No_abs_norm')
    parser.add_argument("--Save_to", type=str, default=None)

    # Parse the caller-supplied argv (previously ignored in favour of sys.argv).
    return parser.parse_args(argv)

if __name__ == "__main__":
    args = parse_arguments(sys.argv[1:])
    print("Load Data and Model")
    testsets, models, model_softmaxs, saliency_constructor_gcs, saliency_constructors, testset = load_data_and_models(
        args=args
    )
    ## setting
    root_dir = parentDir + '/../'
    dataset_name = args.Dataset_name
    dl_selected_model = args.DLModel
    path_2_saliency = root_dir + "results/" + dataset_name + "/" + dl_selected_model + "/"
    experiments = args.Experiments
    experiments = ["experiment_15"]
    saliency_abs_list, saliency_no_abs_list = load_saliencies(path_2_saliency, experiments)

    # Temporal Sequence Object
    device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    if dl_selected_model not in ['LSTM', 'MLP']:
        # methods = ["lrp_epsilon"]
        methods = ["grads",
                   "smoothgrads",
                   "igs",
                   "lrp_epsilon",
                   # "lrp_gamma",
                   "gradCAM",
                   "guided_gradcam",
                   "guided_backprop",
                   "lime",
                   "kernel_shap"]
        # methods = ["grads",
        #            "smoothgrads",
        #            "igs",
        #            "lrp_epsilon",
        #            # "lrp_gamma",
        #            "gradCAM",
        #            "guided_gradcam",
        #            "guided_backprop"]
    else:
        methods = ["grads",
                   "smoothgrads",
                   "igs",
                   "lrp_epsilon",
                   "lime",
                   "kernel_shap"]
        # methods = ["grads",
        #            "smoothgrads",
        #            "igs",
        #            "lrp_epsilon"]

    typeofsali = args.TypeofSaliency
    if typeofsali in ["No_abs_norm"]:
        print("No Abs is used")
        saliency_list = saliency_no_abs_list
    elif typeofsali in ["Abs_norm"]:
        print("Abs norm is used")
        saliency_list = saliency_abs_list
    else:
        raise ValueError("Type of saliency not found")

    saliency_list = clean_saliency_list(model_softmaxs, testset, saliency_list, testsets)
    temporal_sequence_evaluate(args=args,
                               models=model_softmaxs,
                               datasets=testsets,
                               saliency_list=saliency_list,
                               typeofsaliency=typeofsali,
                               methods=methods,
                               eval_mode=args.Evaluation_mode,
                               length=args.Evaluation_length,
                               batch_size=args.Batch_size,
                               verbose=args.Verbose,
                               device=device,
                               save_to=args.Save_to)

