## ------------------
## --- Third-Party ---
## ------------------
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch as t

## -----------
## --- own ---
## -----------
from utils import read_dataset_ts, summarize_label, create_directory, generate_results_csv, plot_dataset_fromDataset
from models.models import TCN, FCN
from trainhelper.dataset import Dataset
from visualize_mechanism.cam import CAM_UCR, Grad_Cam_UCR
from visualize_mechanism.plot_vis_plt import plot_histogram, plot_confusion_matrix
from metrics.temporal_instability import Temporal_instability
from metrics.feature_importance import Feature_importance

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Classifier whose saved checkpoint is evaluated. Only "FCN" is constructed
# further down in this script; "TCN" is imported but not wired up.
dl_selected_model = "FCN"
# Saliency method. The script dispatches on "CAM" and "GradCAM" only.
vis_method = "CAM"
# Perturbation applied to the highlighted time steps. The threshold loop
# handles "Quantile", "Mean", "Zero", "Random" and "Time_shift".
metric_choice = "Mean"

# Dataset selection: read_dataset_ts returns a mapping indexed by dataset name,
# whose entry is the (train_x, test_x, train_y, test_y, labels_dict) tuple.
root_dir = "../"
dataset_name = "GunPoint"
dataset = read_dataset_ts(root_dir, dataset_name)
(train_x, test_x,
 train_y, test_y,
 labels_dict) = dataset[dataset_name]

# Only the test split is needed for the explanation/perturbation study.
label_summary = np.unique(test_y)  # sorted distinct class labels of the test split
num_cls = label_summary.shape[0]
# Wrap the test split in the project's torch Dataset wrapper.
testset = Dataset(test_x, test_y)

# Training artefacts for this dataset/model live under results/<dataset>/<model>/.
path_2_parameters = f"{root_dir}results/{dataset_name}/{dl_selected_model}/"
report = pd.read_csv(path_2_parameters + "reports.csv")

best_epoch = report["best_epoch"][0]
print(f"[Report] Best Epoch: {best_epoch}")
# test_accuracy is a comma-separated per-epoch history; it is indexed with
# best_epoch - 1, i.e. epoch numbers in the report are treated as 1-based.
test_acc = report["test_accuracy"][0].split(",")
print(f"[Report] Accuracy from best epoch: {test_acc[best_epoch - 1]}")


# ---------------------------------------------------------------------------
# Model construction and saliency-method setup
# ---------------------------------------------------------------------------
# Checkpoint saved at the best epoch recorded in reports.csv.
ckp_path = path_2_parameters + "checkpoints/checkpoint_{}.ckp".format(report["best_epoch"][0])

# Strings are compared with ``==``: ``is`` tests object identity and only
# "works" through string interning (CPython emits a SyntaxWarning for it).
if dl_selected_model == "FCN":
    # kernel_size is read from a string whose first and last characters
    # (presumably list brackets) are stripped before splitting on commas.
    kernel_size = [int(k) for k in report["kernel_size"][0][1:-1].split(',')]
    model = FCN(ch_in=int(test_x.shape[1]), dropout_rate=report["dropout_rate"][0],
                num_classes=report["num_classes"][0],
                kernel_size=kernel_size,
                use_fc=True)
else:
    # Fail fast here instead of a NameError on ``model`` further down.
    raise NotImplementedError(
        f"Model {dl_selected_model!r} is not supported by this script; use 'FCN'.")

# Both wrappers target the same final conv layer and receive the checkpoint
# path with model_loaded=False (presumably they load the weights themselves).
if vis_method == "CAM":
    visualized = CAM_UCR(testset, model, checkpoint=ckp_path, target_layer="gap_softmax.conv1",
                         fc_layer="gap_softmax.fc", model_loaded=False)
elif vis_method == "GradCAM":
    visualized = Grad_Cam_UCR(testset, model, checkpoint=ckp_path, target_layer="gap_softmax.conv1",
                              before_block="convblock4.relu", model_loaded=False)
else:
    # Fail fast here instead of a NameError on ``visualized`` in the sweep.
    raise NotImplementedError(
        f"Visualization method {vis_method!r} is not supported; use 'CAM' or 'GradCAM'.")

# ---------------------------------------------------------------------------
# Threshold sweep: highlight salient points -> perturb them -> re-evaluate
# ---------------------------------------------------------------------------
# Map each metric_choice to its Feature_importance perturbation call. This
# replaces an ``is``-based if-chain (identity, not equality, on strings) that
# left ``testset_perturbated`` undefined for an unrecognised choice.
perturbation_dispatch = {
    "Quantile": lambda fi: fi.quantile_perturbate(importance=True),
    "Mean": lambda fi: fi.mean_perturbate(importance=True),
    "Zero": lambda fi: fi.zero_perturbate(importance=True),
    "Random": lambda fi: fi.random_perturbate(importance=True),
    "Time_shift": lambda fi: fi.time_shift_perturbate(),
}
# Validate before the (expensive) loop rather than failing mid-sweep.
if metric_choice not in perturbation_dispatch:
    raise ValueError(
        f"Unknown metric_choice {metric_choice!r}; "
        f"expected one of {sorted(perturbation_dispatch)}.")
perturbation_method = metric_choice
perturbate = perturbation_dispatch[perturbation_method]

# The output directory depends only on loop-invariant settings: build it once.
path = root_dir + "results/" + dataset_name + "/" + dl_selected_model + "/visualizations/metrics/" \
       + vis_method + "/perturbation_" + perturbation_method
create_directory(path)

thresholds = np.linspace(0, 1, 11)  # 0.0, 0.1, ..., 1.0
perturbation_accuracys = []
for threshold in thresholds:
    # The same threshold selects highlighted points and drives the count.
    threshold_highlight = threshold
    threshold_count = threshold

    # NOTE(review): Temporal_instability and get_visualize() receive no
    # threshold-dependent arguments; if get_histogram() does not mutate the
    # instance, these could be hoisted out of the loop — confirm first.
    temporal_instability = Temporal_instability(testset, vis_method=[vis_method, visualized])
    temporal_instability.get_visualize()
    histogram = temporal_instability.get_histogram(threshold=threshold_highlight)  ## histogram [C, dim, len]

    feature_importance = Feature_importance(testset, histogram, trained_model=model,
                                            checkpoint=ckp_path, threshold=threshold_count)
    testset_perturbated = perturbate(feature_importance)
    criterions = feature_importance.test_accuracy(testset_perturbated, perturbation_method=perturbation_method)
    perturbation_accuracys.append(criterions[f"{perturbation_method}_acc"])

    # Attach run metadata so the stored CSV and plots are self-describing.
    criterions["Dataset"] = dataset_name
    criterions["Classifier"] = dl_selected_model
    criterions["vis_method"] = vis_method
    criterions["threshold_highlightpoint"] = threshold_highlight
    criterions["histogram"] = histogram
    criterions["label_summary"] = labels_dict

    # Persist results and plots for this threshold.
    generate_results_csv(criterions, store_path=path)
    plot_confusion_matrix(criterions, save_path=path, metric_name=perturbation_method)
    plot_dataset_fromDataset(testset_perturbated, save_path=path, criterions=criterions,
                             method_name=perturbation_method)
    plot_histogram(histogram, label_summary, labels_dict, save_path=path, vis_method=vis_method, threshold=threshold)

# Summary plot: accuracy on the perturbed test set for each swept threshold,
# saved next to the per-threshold outputs.
fig, ax = plt.subplots()
ax.plot(thresholds, perturbation_accuracys)
ax.set(xlabel="Threshold",
       ylabel="Accuracy",
       title=f"elbow points with threshold for {metric_choice}")
fig.savefig(path + "/elbow_points.png")