"""
Tool Tracking
Evaluation:
    Sanity Check
    https://github.com/adebayoj/sanity_checks_saliency/tree/3e24048c570f08ca655fcd332b6128fa069810a0
    Model Randomization Test
"""
## ------------------
## --- Third-Party ---
## ------------------
import os
import sys
sys.path.append('../../')
fileDir = os.path.dirname(os.path.abspath(__file__))
parentDir = os.path.dirname(fileDir)
sys.path.append(parentDir)
import argparse
import torch as t

### -----------
### --- Own ---
### -----------
from utils import load_saliencies
from visual_interpretability import load_data_and_models
from gun_point.evaluation.sanitycheck import saliency_sanitycheck, plot_statistic

def parse_arguments(argv=None):
    """Build the command-line parser for the sanity check and parse *argv*.

    Args:
        argv: Command-line tokens without the program name. When ``None``,
            argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options. ``Experiments`` is always
        a list of experiment names.

    Boolean switches:
        Each switch ``--<Name>`` also has a ``--No_<Name>`` form that sets it to
        False. A plain ``store_true`` flag whose default is True could otherwise
        never be disabled.
    """
    parser = argparse.ArgumentParser()

    def add_switch(name, default):
        # Both options write to the same dest. argparse takes the default from
        # the first one registered (the ``--<Name>`` form).
        parser.add_argument("--" + name, dest=name, action="store_true", default=default)
        parser.add_argument("--No_" + name, dest=name, action="store_false")

    parser.add_argument("--Dataset_name", type=str, default='tool_tracking_Cluster')
    parser.add_argument("--Dataset_name_save", type=str, default='tool_tracking_Cluster')
    # nargs='+' yields a list when given on the CLI, so the default must be a
    # list too. A bare string would be iterated character by character.
    parser.add_argument("--Experiments", nargs='+', default=['experiment_0'])
    parser.add_argument("--DLModel", type=str, default='TCN_withoutFC')
    parser.add_argument("--Evaluation_mode", type=str, default='Cascade')
    parser.add_argument("--Title", type=str, default='Cascade_randomize_correlation')
    add_switch("Use_tsr", True)

    parser.add_argument("--Data_path", type=str, default='data/tool_tracking_data')
    add_switch("Detection", False)
    add_switch("Znorm", True)
    add_switch("One_matrix", True)
    add_switch("Sparse_labels", True)
    parser.add_argument("--Window_length", type=float, default=0.2)

    # Honour the argv argument; previously it was silently ignored.
    return parser.parse_args(argv)

if __name__ == "__main__":
    args = parse_arguments(sys.argv[1:])

    print("Load Data and Model")
    testsets, models, model_softmaxs, saliency_constructor_gcs, saliency_constructors = load_data_and_models(
        args=args
    )
    ## setting
    root_dir = parentDir + '/../'
    dataset_name = args.Dataset_name
    dataset_name_save = args.Dataset_name_save
    dl_selected_model = args.DLModel
    path_2_saliency = root_dir + "results/" + dataset_name_save + "/" + dl_selected_model + "/"
    experiments = args.Experiments
    # experiments = ["experiment_13"]
    saliency_abs_list, saliency_no_abs_list = load_saliencies(path_2_saliency, experiments)

    device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    if dl_selected_model not in ['LSTM', 'MLP']:
        # methods = ["lrp_epsilon"]
        # methods = ["grads",
        #            "smoothgrads",
        #            "integrated_gradients",
        #            "lrp_epsilon",
        #            "lrp_gamma",
        #            "gradCAM",
        #            "g_gradcam",
        #            "gbp",
        #            "lime",
        #            "kernelShap"]
        methods = ["grads",
                   "smoothgrads",
                   "integrated_gradients",
                   "lrp_epsilon",
                   # "lrp_gamma",
                   "gradCAM",
                   "g_gradCAM",
                   "gbp"]
    else:
        # methods = ["grads",
        #            "smoothgrads",
        #            "integrated_gradients",
        #            "lrp_epsilon",
        #            "lime",
        #            "kernelShap"]
        methods = ["grads",
                   "smoothgrads",
                   "integrated_gradients",
                   "lrp_epsilon"]

    random_correlation_stat, random_ssim_stat, rand_names, rand_acc_dict = saliency_sanitycheck(
        args=args,
        models=models,
        datasets=testsets,
        normal_saliency=saliency_no_abs_list[0],
        normal_saliency_abs=saliency_abs_list[0],
        methods=methods,
        save_randsaliency=True
    )
    plot_statistic(args=args,
                   rand_correlation_stat=random_correlation_stat,
                   rand_ssim_stat=random_ssim_stat,
                   rand_names=rand_names,
                   rand_acc_dict=rand_acc_dict,
                   vis_methods=methods,
                   title=args.Title
                   )