"""
Visualization (explanation) of trained time-series classifiers.
Intended for UCR datasets,
for example: GunPoint.
"""

## ------------------
## --- Third-Party ---
## ------------------
import math
from copy import deepcopy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch as t
import torchviz

from captum.attr import Lime
from torchray.attribution.excitation_backprop import excitation_backprop
from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward
from captum.attr._utils.visualization import visualize_image_attr


## -----------
## --- own ---
## -----------
from utils import read_dataset_ts, summarize_label, create_directory, load_model, get_model_weights
from utils import throw_out_wrong_classified
from visualize_mechanism.visual_utils import min_max_normalize
from models.models import TCN, FCN
from models.lstm import LSTM
from trainhelper.dataset import Dataset
from metrics.insertion_deletion import MetricInsertDelete, gaussian_kernel, gaussian_kernel_
from metrics.insertion_deletion import quantile_values_like, auc
from visualize_mechanism.visual_utils import SaliencyConstructor
from visualize_mechanism.lrp import LRP_individual
from visualize_mechanism.vanillabackprop import VanillaBackprop
from visualize_mechanism.integrated_gradients import IntegratedGradients
from visualize_mechanism.smoothgrad import SmoothGrad
from visualize_mechanism.cam import interpolate_smooth_ucr, grad_cam_ucr, cam_ucr
from metrics.robustness import IntermodelCheck
from metrics.temporal_sequence_eval import TemporalSequenceEval

from visualize_mechanism.plot_vis_plt import plot_vis_plt_ucr
from visualize_mechanism.tsr import temporalsaliencyrescaling

## --- Experiment configuration ---
## Architecture whose checkpoint gets explained. The model construction code
## in this script recognises "FCN_withoutFC", "FCN_laststep",
## "TCN_withoutFC", "TCN_laststep" and "LSTM".
dl_selected_model = "FCN_withoutFC"
use_fc, use_pooling = False, True
## Saliency method used for the per-sample plots, e.g.
## ["CAM", "GradCAM", "Excitation_backprop", "Rise", "LRP", "Gradient",
##  "IntegratedGradients", "SmoothGrad", "LIME"]
vis_method = "LIME"

## --- Dataset ---
root_dir, dataset_name = "../../", "GunPointAgeSpan"
## Name of the results folder the trained model is read from.
model_name = "GunPointAgeSpan_Cluster"
dataset = read_dataset_ts(root_dir, dataset_name)
(train_x, test_x, train_y, test_y, labels_dict) = dataset[dataset_name]
## Wrap the test split into the project's Torch Dataset.
testset = Dataset(test_x, test_y)

## For visualization only samples of the test split are needed; the number of
## classes is derived from the distinct test labels.
label_summary = np.unique(test_y)
num_cls = label_summary.shape[0]
## Label Mapping
# _, test_y_merged = summarize_label(None, test_label=test_y)
## Label Mapping
# _, test_y_merged = summarize_label(None, test_label=test_y)


## Parameters
## Folder of the stored training run: it holds the hyper-parameter report
## ("reports.csv") and the checkpoints of the selected architecture.
path_2_parameters = root_dir + "results/" + model_name + "/" + dl_selected_model + "/" + \
                    "experiment_7/"
report = pd.read_csv(path_2_parameters + "reports.csv")


def _parse_int_list(text):
    """Convert a stringified list from the report (e.g. "[8, 5, 3]") to ints."""
    return [int(item) for item in text[1:-1].split(',')]


## model setting and loading from checkpoint
ckp_path = path_2_parameters + "checkpoints/checkpoint_{}.ckp".format(report["best_epoch"][0])
if dl_selected_model in ['FCN_withoutFC', 'FCN_laststep']:
    kernel_size = _parse_int_list(report["kernel_size"][0])
    ch_out = _parse_int_list(report["Filter_numbers"][0])
    model = FCN(ch_in=int(test_x.shape[1]),
                ch_out=ch_out,
                dropout_rate=report["dropout_rate"][0],
                num_classes=report["num_classes"][0],
                kernel_size=kernel_size,
                use_fc=use_fc,
                use_pooling=use_pooling)
    # t.backends.cudnn.enabled = False
elif dl_selected_model in ["TCN_withoutFC", "TCN_laststep"]:
    ## dilation and kernel_size should always have the same size as ch_out
    dilation = _parse_int_list(report["dilation"][0])
    kernel_size = _parse_int_list(report["kernel_size"][0])
    ch_out = _parse_int_list(report["Filter_numbers"][0])
    model = TCN(ch_in=int(train_x.shape[1]), ch_out=ch_out,
                kernel_size=kernel_size,
                dropout_rate=report["dropout_rate"][0],
                use_fc=use_fc,
                use_pooling=use_pooling,
                num_classes=report["num_classes"][0])
elif dl_selected_model == "LSTM":
    ## `==` instead of `is`: string identity is an implementation detail and
    ## raises a SyntaxWarning on Python >= 3.8.
    hidden_size = int(report["Hidden_size"][0])
    num_layers = int(report["num_layers"][0])
    dropout = float(report["dropout_rate"][0])
    bidirectional = bool(report["bidirectional"][0])
    model = LSTM(ch_in=int(train_x.shape[1]),
                 hidden_size=hidden_size,
                 num_layers=num_layers,
                 dropout=dropout,
                 bidirectional=bidirectional,
                 num_classes=num_cls)
    t.backends.cudnn.enabled = False
else:
    ## Fail here with a clear message instead of a NameError on `model` later.
    raise ValueError("Unsupported dl_selected_model: {!r}".format(dl_selected_model))

## this is for sanity checks
# layer_names, model_weights = get_model_weights(model)
# num_layers = len(layer_names)
idx = 0
## Restore the trained weights from the best checkpoint.
model = load_model(model=model, ckp_path=ckp_path)
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
## add softmax to create probability
model_softmax = t.nn.Sequential(model, t.nn.Softmax(dim=1))
model_softmax = model_softmax.eval()
## Move to the selected device rather than calling .cuda() unconditionally,
## so the script also runs on hosts without a GPU. `model` is a sub-module of
## the Sequential and is moved along with it.
model_softmax = model_softmax.to(device)

## Keep only the test samples that `throw_out_wrong_classified` returns for
## this model; the explanations below are computed on this subset.
cleandata, cleanlabels = throw_out_wrong_classified(model_softmax,
                                                    test_x,
                                                    test_y,
                                                    device)
cleantestset = Dataset(cleandata, cleanlabels)

## Single forward pass on the second clean sample (input for the optional
## torchviz graph below).
## NOTE(review): reshape(1, 1, -1) assumes a single-channel series - confirm.
y = model(t.tensor(cleantestset.data[1].reshape(1, 1, -1)).to(device))
# z = torchviz.make_dot(y, params=dict(model.named_parameters()))
# print(z)
# tsr_gradCAM = temporalsaliencyrescaling(saliency_method="ig",
#                                      input=testset.data,
#                                      labels=testset.labels,
#                                      dl_model=model)
# while idx < num_layers:
#     random_gaussian_model_, idx = load_model(model=model,
#                                         ckp_path=ckp_path,
#                                         randomized=True,
#                                         independent=True,
#                                         idx_layer=idx)
#
#     ## For TCN
#     if idx > 4 and idx < num_layers:
#         layer_name = layer_names[idx].split('.')
#         if layer_name[-2] in ["conv1", "downsample"]:
#             print(layer_names[idx - 3])
#         else:
#             print(layer_names[idx - 2])
#     elif idx >= num_layers:
#         print(layer_names[idx - 3])
#     else:
#         print(layer_names[idx - 2])
#     # idx += 1

# samples_idx = [i for i in np.arange(15, 25)]
# samples_idx = [100, 151, 140, 150, 0, 1, 2, 3, 6, 7, 10, 222, 111]

## Two explainers over the correctly classified samples: `_gc` wraps the raw
## model, the other one wraps the model followed by the softmax layer.
saliency_constructor_gc = SaliencyConstructor(
    model=model,
    data=cleantestset,
    labels=["Gun", "Point"],
    use_prediction=True,
    device=device,
)
saliency_constructor = SaliencyConstructor(
    model=model_softmax,
    data=cleantestset,
    labels=["Gun", "Point"],
    use_prediction=True,
    device=device,
)

## One zero-initialised buffer per saliency method, each shaped like the
## clean test data.
grads = np.zeros(cleantestset.data.shape)
(igs,
 smoothgrads,
 lrp_maps_selbst,
 lrp_maps,
 gradCam_maps,
 g_gradcam_maps,
 gbp_maps,
 lime_maps,
 shap_maps) = (np.zeros_like(grads) for _ in range(9))
## Fill the saliency buffers, one map per correctly classified sample.
## Only the methods that are not commented out are computed; the buffers of
## the disabled methods keep their initial zeros.
num_clean_samples = cleantestset.data.shape[0]
for idx in range(num_clean_samples):
    grads[idx] = saliency_constructor.gradient_saliency(idx=idx)[0]
    # igs[idx] = saliency_constructor.integrated_gradients(idx=idx, ig_steps=60)[0]
    # smoothgrads[idx] = saliency_constructor.smooth_gradients(idx=idx,
    #                                                          nt_samples=60,
    #                                                          stdevs=0.2)[0]
    # lrp_maps[idx] = saliency_constructor_gc.lrp4lstm_(idx=idx,
    #                                                   absolute=False)
    # lrp_maps_selbst[idx] = saliency_constructor_gc.lrp_selbst(idx=idx,
    #                                                    rule='epsilon',
    #                                                    absolute=False)
    lrp_maps[idx] = saliency_constructor_gc.lrp_(idx=idx,
                                                 rule="epsilon",
                                                 absolute=False)[0]
    # gradCam_maps[idx] = saliency_constructor_gc.grad_cam(idx=idx,
    #                                                      layer_to_grad="gap_softmax.conv1",
    #                                                      use_relu=False,
    #                                                      attribute_to_layer_input=True)[0]
    # g_gradcam_maps[idx] = saliency_constructor_gc.guided_gradCAM_(idx=idx,
    #                                                               layer_to_grad="gap_softmax.conv1",
    #                                                               use_relu=False,
    #                                                               attribute_to_layer_input=True)[0]
    # gbp_maps[idx] = saliency_constructor.guided_backprop(idx=idx)[0]

    # shap_maps[idx] = saliency_constructor.kernelshap_(idx=idx,
    #                                                   num_features=75,
    #                                                   n_sample=500,
    #                                                   baseline="noise")[0]
    # lime_maps[idx] = saliency_constructor.lime_(idx=idx,
    #                                             num_features=75,
    #                                             n_sample=500,
    #                                             baseline="mean",
    #                                             kernel_width=10.0)[0]

    # lime_maps[idx] = saliency_constructor.lime_ts(idx=idx,
    #                                               num_slices=75,
    #                                               num_features=75,
    #                                               n_sample=50,
    #                                               replacement_method="total_mean")

## Labels of the explained samples, in the same order as the saliency maps.
sample_labels = [cleantestset.labels[i] for i in range(num_clean_samples)]

## Min-Max Normalization (currently disabled)
# grads = min_max_normalize(grads)
# igs = min_max_normalize(igs)
# smoothgrads = min_max_normalize(smoothgrads)
# lrp_maps_selbst = min_max_normalize(lrp_maps_selbst)
# lrp_maps = min_max_normalize(lrp_maps)
# gradCam_maps = min_max_normalize(gradCam_maps)
# g_gradcam_maps = min_max_normalize(g_gradcam_maps)
# gbp_maps = min_max_normalize(gbp_maps)
# lime_maps = min_max_normalize(lime_maps)
# shap_maps = min_max_normalize(shap_maps)
## Saliency maps of every method, keyed by method name.
normal_saliency = {"grads": grads, "smoothgrads": smoothgrads,
                   "igs": igs, "gradCAM": gradCam_maps, "lrp": lrp_maps,
                   "guided_gradcam": g_gradcam_maps, "guided_backprop": gbp_maps,
                   "lime": lime_maps,
                   "kernel_shap": shap_maps}
## try to evaluate the stability of saliency methods
## create objects for model check: one on the raw model (used for the
## CAM-based maps) and one on the model with the softmax layer.
intermodelcheck_gc = IntermodelCheck(model=model,
                                     device=device)
intermodelcheck = IntermodelCheck(model=model_softmax,
                                  device=device)

## evaluate the temporal sequence importance of saliency methods
## create objects for temporal sequence
temporal_sequence_check = TemporalSequenceEval(model=model_softmax,
                                               eval_mode="mean",
                                               length=0.1,
                                               device=device)

stability_distances = {}
## Temporal-sequence scores per method. Previously `gap_scores` was
## overwritten on every iteration, so only the last method's result survived.
temporal_gap_scores = {}
for key, saliency_maps in normal_saliency.items():
    if key in ("gradCAM", "guided_gradcam"):
        stability_distances[key] = intermodelcheck_gc.stability_check(saliency_maps=saliency_maps,
                                                                      labels=sample_labels,
                                                                      similar_metric="l2")
        continue

    gap_scores = temporal_sequence_check.evaluation(batch_samples=cleantestset.data,
                                                    batch_labels=cleantestset.labels,
                                                    batch_size=1,
                                                    batch_saliency_maps=saliency_maps,
                                                    verbose=1)
    temporal_gap_scores[key] = gap_scores
    stability_distances[key] = intermodelcheck.stability_check(saliency_maps=saliency_maps,
                                                               labels=sample_labels,
                                                               similar_metric="l2")



## Explain individual test samples with the method selected in `vis_method`,
## score the explanation with the insertion metric and plot it on top of the
## series.

## The earlier `samples_idx` definitions in this file are commented out, so
## the loop used to fail with a NameError. Use a definition made elsewhere if
## there is one, otherwise fall back to the commented-out default (samples
## 15..24), limited to the indices that exist in the test set.
samples_idx = globals().get("samples_idx")
if samples_idx is None:
    samples_idx = [i for i in range(15, 25) if i < testset.data.shape[0]]

device = t.device('cuda' if t.cuda.is_available() else 'cpu')

for sample in samples_idx:
    ## Add a leading batch dimension of size 1 to the selected sample.
    testdata = t.tensor(testset.data[sample, :, :]).float().to(device).unsqueeze(0)
    target = testset.labels[sample]
    prediction = None

    ## Method dispatch. Every branch stores its result in `class_map`.
    ## Strings are compared with `==`; `is` only tests object identity.
    if vis_method == "GradCAM":
        ## NOTE(review): `saliency_constructor` was built on the clean test
        ## subset while `sample` indexes the full test set - confirm that the
        ## two index spaces agree for the selected samples.
        class_map = saliency_constructor.grad_cam(idx=sample,
                                                  use_relu=False,
                                                  attribute_to_layer_input=True)[0]
    elif vis_method == "CAM":
        class_map, prediction, target = cam_ucr(testdata=testset,
                                                model=model,
                                                checkpoint=ckp_path,
                                                samples_idx=sample,
                                                class_idx=False,
                                                target_layer="gap_softmax.conv1",
                                                fc_layer="gap_softmax.fc",
                                                used_relu=True,
                                                model_loaded=False
                                                )
    elif vis_method == "LRP":
        class_map, prediction = LRP_individual(model=model, checkpoint=ckp_path, X=testdata, used_normalized=True)
    elif vis_method == "Gradient":
        vanillaBackprop = VanillaBackprop(model=model, checkpoint=ckp_path)
        class_map, prediction = vanillaBackprop.generate_gradients(X=testdata)
    elif vis_method == "IntegratedGradients":
        integratedGradients = IntegratedGradients(model=model, checkpoint=ckp_path)
        class_map, prediction = integratedGradients.generate_integrated_gradients(X=testdata, steps=100)
    elif vis_method == "SmoothGrad":
        noise_sigma = 2
        num_samples = 80
        smoothgrad = SmoothGrad(model=model, checkpoint=ckp_path, noise_sigma=noise_sigma,
                                num_samples=num_samples)
        class_map, prediction = smoothgrad.generate_gradients(X=testdata)
    elif vis_method == "LIME":  ## ??? work, but seems wrong
        ## NOTE(review): the attribution target is hard-coded to class 0
        ## instead of the sample's label - possibly why the result looks wrong.
        model = load_model(model, ckp_path=ckp_path, use_cuda=t.cuda.is_available())
        lime = Lime(model)
        class_map = lime.attribute(inputs=testdata, target=0, n_perturb_samples=200)
        class_map = class_map.cpu().detach().numpy()[0]
        print("Normalization")
        ## Row-wise min-max scaling to [0, 100]. keepdims keeps the statistics
        ## broadcastable when the map has more than one row; constant rows are
        ## mapped to 0 instead of dividing by zero.
        row_min = class_map.min(axis=1, keepdims=True)
        row_span = class_map.max(axis=1, keepdims=True) - row_min
        row_span[row_span == 0] = 1
        class_map = (class_map - row_min) / row_span * 100
        prediction = 0
    elif vis_method == "Excitation_backprop":  ##fail
        ## The undefined names `module`/`fmap` are replaced by the model and
        ## the input tensor.
        ## NOTE(review): torchray targets image models; unverified for 1-D series.
        class_map = excitation_backprop(model, testdata, testset.labels[sample],
                                        saliency_layer="conv1")
    elif vis_method == "Perturbation":  ## fail
        ## `input` has to be the sample tensor, not its integer index.
        masks_1, _ = extremal_perturbation(
            model=model, input=testdata, target=int(target),
            reward_func=contrastive_reward,
            debug=True
        )
        ## NOTE(review): using the returned masks as saliency map - confirm shape.
        class_map = masks_1
    else:
        ## "SHAP", "ScoreCAM", "GradCAMpp", "Rise" and unknown names end up here.
        raise NotImplementedError("vis_method {!r} is not implemented".format(vis_method))

    gkernel = gaussian_kernel_(15, 7, testdata[0])
    blur = lambda x: t.nn.functional.conv1d(x, gkernel.to(device), padding=15 // 2)
    quantile_passon = lambda x: quantile_values_like(quantile=0.2, dataset=x.cpu().numpy())
    insertion = MetricInsertDelete(model=model, mode="ins", step=1, substrate_fn=blur, device=device)
    deletion = MetricInsertDelete(model=model, mode="del", step=1,
                                  substrate_fn=quantile_passon, device=device)
    ## Score the map of the selected method (previously always `gradCAM_map`,
    ## which only exists for vis_method == "GradCAM").
    scores = insertion.single_run(testdata[0].cpu().numpy(),
                                  saliency_map=class_map)

    ## Bring the series and the saliency map to host arrays for plotting.
    data = testdata[0].cpu().numpy()
    if t.is_tensor(class_map):
        saliency = class_map.detach().cpu().numpy()
    else:
        saliency = np.asarray(class_map)
    saliency = np.atleast_1d(saliency)
    ## assumes one saliency row per input channel - TODO confirm for all methods
    saliency = saliency.reshape(-1, saliency.shape[-1])

    fig, axes = plt.subplots(data.shape[0], 1, sharex=True)
    fig.subplots_adjust(hspace=0)
    ## plt.subplots returns a bare Axes for a single row; treat both cases alike.
    for i, ax in enumerate(np.atleast_1d(axes)):
        ax.imshow(data[i:i + 1], aspect='auto')
        ## Overlay the saliency row. `cmap` must be a colormap, not an array,
        ## and aspect='auto' keeps the overlay from resetting the axes aspect.
        heat_map = ax.imshow(saliency[i:i + 1], cmap='Blues',
                             vmin=np.min(saliency[i]), vmax=np.max(saliency[i]),
                             alpha=0.5, aspect='auto')
        # axis_separator = make_axes_locatable(ax)
        # colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.1)
        # fig.colorbar(heat_map, orientation="horizontal", cax=colorbar_axis)
    plt.show()


# def make_axes_locatable(axes):
#     divider = AxesDivider(axes)
#     locator = divider.new_locator(nx=0, ny=0)
#     axes.set_axes_locator(locator)
#
#     return divider
    # cbar = plt.colorbar()
    ## interpolation
    # win_len, inter_cam, inter_sample = interpolate_smooth_ucr(dataset=testset, cam_out=class_map, num_idx=sample,
    #                                                           num_sampling=10)
    #
    # path = root_dir + "results/" + dataset_name + "/" + dl_selected_model + "/visualizations/" + vis_method
    # path_done = create_directory(path)
    # ## plot and save plot
    # plot_vis_plt_ucr(inter_cam, win_len, inter_sample, f'prediction_{prediction}', f'target_{target}',
    #                  sample_idx=sample,
    #                  vis_method=vis_method, save_path=path)


# plt.show()

