"""
Present the results of evaluating visualization (saliency) methods, including
the area under the insertion and deletion curves.
Metric: Insertion/Deletion
Refer to: https://arxiv.org/abs/1806.07421
"""
## -------------------
## --- Third-Party ---
## -------------------
import sys
sys.path.append('..')
import numpy as np
import pandas as pd
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.collections import LineCollection
from scipy.ndimage.filters import gaussian_filter1d

### -----------
### --- Own ---
### -----------
from metrics.insertion_deletion import MetricInsertDelete, gaussian_kernel, auc
from metrics.insertion_deletion import gaussian_kernel_, quantile_values_like, zero_values_like
from trainhelper.dataset import Dataset
from utils import read_dataset_ts, load_model, throw_out_wrong_classified
from visualize_mechanism.visual_utils import SaliencyConstructor, min_max_normalize
from models.models import FCN, TCN


def plot_vis_plt_ucr(feature_map, x, data, label):
    """Plot every channel of a sample together with its saliency values.

    One figure is created per entry ``i`` along the first axis of ``data``:

    * top panel: the signal ``data[i, :]`` drawn as a line, overlaid with
      markers coloured by ``feature_map[i]`` (colormap ``'hot_r'``);
    * bottom panel: ``feature_map[i]`` itself drawn as a line whose segments
      are coloured by the saliency value.

    Colouring along a line follows the matplotlib "multicolored line" example:
    https://matplotlib.org/stable/gallery/lines_bars_and_markers/multicolored_line.html

    Args:
        feature_map: saliency values indexed like ``data``; ``feature_map[i]``
            is flattened and must then have the same length as ``x``.
        x: x-coordinates of the time steps (e.g. ``np.arange(length)``).
        data: signal to plot, iterated over its first axis.
        label: class label, shown as the y-axis label of every top panel.
    """
    for i in range(data.shape[0]):
        fig, axs = plt.subplots(2, 1, sharex=True,
                                gridspec_kw=dict(height_ratios=[3, 1]))
        # One flat view of this channel's saliency, used by both panels.
        saliency = np.ravel(feature_map[i])

        ## first subplot: signal + saliency-coloured markers
        axs[0].plot(x, data[i, :], linewidth=2)
        color = axs[0].scatter(x, data[i, :], cmap='hot_r', marker='.', c=saliency,
                               s=100, vmin=np.min(saliency), vmax=np.max(saliency),
                               linewidths=3.0)
        fig.colorbar(color, ax=axs[0])
        # Label every figure, not only the last one created by the loop.
        axs[0].set_ylabel(ylabel=str(label),
                          fontdict=dict(fontsize=20, color='b', rotation='horizontal'))

        ## second subplot: saliency drawn as a colour-mapped line
        points = np.column_stack([x, saliency]).reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        lc = LineCollection(segments=segments, cmap='hot_r')
        # Set the values used for colormapping
        lc.set_array(saliency)
        lc.set_linewidth(2)
        line = axs[1].add_collection(lc)
        # add_collection only updates the data limits; rescale the view so the
        # line is not clipped by the default axis limits.
        axs[1].autoscale_view()
        fig.colorbar(line, ax=axs[1])
        fig.tight_layout()
    plt.show()


## Data and Model Loading
# Which trained deep-learning model (results sub-directory) to evaluate,
# and whether that model was built with its fully-connected head.
use_fc = False
dl_selected_model = "FCN_withoutFC"
# Read the UCR dataset and unpack its train/test splits.
root_dir = "../"
dataset_name = "GunPoint"
dataset = read_dataset_ts(root_dir, dataset_name)
train_x, test_x, train_y, test_y, labels_dict = dataset[dataset_name]
# Wrap the test split in the project's torch Dataset.
testset = Dataset(test_x, test_y)
# Directory holding the training report and checkpoints of the selected model.
path_2_parameters = f"{root_dir}results/{dataset_name}/{dl_selected_model}/"
report = pd.read_csv(f"{path_2_parameters}reports.csv")

## Model setting and loading from checkpoint
# List-valued hyper-parameters are stored in the report as bracketed,
# comma-separated strings (e.g. "[8, 5, 3]"), hence the [1:-1] slice.
ckp_path = path_2_parameters + "checkpoints/checkpoint_{}.ckp".format(report["best_epoch"][0])
kernel_size = [int(k) for k in report["kernel_size"][0][1:-1].split(',')]
ch_out = [int(k) for k in report["Filter_numbers"][0][1:-1].split(',')]
model = FCN(ch_in=int(test_x.shape[1]), ch_out=ch_out,
            dropout_rate=report["dropout_rate"][0],
            num_classes=report["num_classes"][0],
            kernel_size=kernel_size,
            use_fc=use_fc)

## Load the model weights
model = load_model(model=model, ckp_path=ckp_path)
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
# Put the network in inference mode (and on the target device) before the
# first forward pass below: FCN is built with a dropout rate, and dropout must
# not be active while deciding which test samples are classified correctly.
# Both calls are idempotent, so this is harmless if load_model already did it.
model = model.to(device).eval()

## Clean out the wrong labels (drop test samples the model misclassifies)
cleandata, cleanlabels = throw_out_wrong_classified(model=model,
                                                    data=testset.data,
                                                    labels=testset.labels,
                                                    device=device)
cleantestset = Dataset(cleandata, cleanlabels)

## Create Saliency Object
# Built on the bare network, i.e. before Softmax is appended below.
saliency_constructor = SaliencyConstructor(model, data=cleantestset,
                                           use_prediction=True,
                                           device=device)

## Add Softmax as the last layer so the insertion/deletion metrics receive
## class probabilities instead of the raw network outputs.
model = nn.Sequential(model, nn.Softmax(dim=1))
model = model.eval()
model = model.to(device)

## Substrate Functions
## Insertion substrate: a Gaussian-blurred copy of the input (`blur`).
## Deletion substrate: `zero_passon` (built on `zero_values_like`).
##   `quantile_passon` (quantile=0.2) is kept as an alternative deletion
##   substrate but is not used below.
klen = 15  # Gaussian kernel length (odd, so padding klen // 2 keeps length)
nsig = 7   # Gaussian sigma (also passed to gaussian_filter1d below)
gkernel = gaussian_kernel_(klen, nsig, testset.data[0])


def blur(x):
    """Convolve ``x`` with the Gaussian kernel ``gkernel`` (1-D blur).

    ``x`` is handed to ``conv1d`` as-is; the example below adds a leading
    batch axis with ``unsqueeze(dim=0)``. With stride 1 and
    ``padding=klen // 2`` the output length equals the input length for the
    odd ``klen`` used here.
    """
    return nn.functional.conv1d(x, gkernel.to(device), padding=klen // 2)


## Example: smooth one sample with scipy's gaussian_filter1d and with `blur`
## so the two filters can be compared (see the commented-out plot below).
sample = testset.data[0]
gaussian_sample = gaussian_filter1d(sample, nsig)
noise_sample = blur(t.tensor(sample).float().to(device).unsqueeze(dim=0)).squeeze(dim=0)
noise_sample = noise_sample.cpu().numpy()
## Plot figure
# plt.figure(figsize=(12, 4))
# plt.subplot(131)
# plt.title("Compare raw with gaussian filter")
# plt.plot(range(sample.shape[-1]), sample[0], label= "Raw Signal")
# plt.plot(range(sample.shape[-1]), gaussian_sample[0], label="gaussian_filter1d")
# plt.plot(range(sample.shape[-1]), noise_sample[0], label="blur")
# plt.legend()
#
# plt.subplot(132)
# plt.title("Gaussian kernel")
# plt.axis('off')
# plt.imshow(gkernel[0])


def quantile_passon(x):
    """Substrate from ``quantile_values_like`` with quantile=0.2 on ``x``."""
    return quantile_values_like(quantile=0.2, dataset=x.cpu().numpy())


def zero_passon(x):
    """Substrate from ``zero_values_like`` on ``x``."""
    return zero_values_like(dataset=x.cpu().numpy())


## Insertion and Deletion Objects
# NOTE(review): an earlier comment said deletion should use the 0.2-quantile
# substrate, but `zero_passon` is what is passed here — confirm intent.
insertion = MetricInsertDelete(model=model, mode="ins", step=1,
                               substrate_fn=blur, device=device)
deletion = MetricInsertDelete(model=model, mode="del", step=1,
                              substrate_fn=zero_passon, device=device)

## single sample run
# `saliency_constructor` was built on `cleantestset` (misclassified samples
# removed), so `idx` is presumably an index into that set. The signal handed
# to the metrics must come from the same set; indexing `testset` instead
# pairs the map with a different sample as soon as any earlier test sample
# has been filtered out.
idx = 50
# NOTE(review): despite its name, this map comes from `gradient_saliency`,
# not Grad-CAM (the batch run below uses `grad_cam`) — confirm intent.
gradCAM_map = saliency_constructor.gradient_saliency(idx=idx)[0]
# plt.subplot(133)
# plt.title("GradCAM")
# plot_vis_plt_ucr(gradCAM_map, np.arange(0, cleantestset.data[idx].shape[-1]),
#                 cleantestset.data[idx], cleantestset.labels[idx])
# plt.show()
#
ins_scores = insertion.single_run(cleantestset.data[idx], gradCAM_map, verbose=2)
del_scores = deletion.single_run(cleantestset.data[idx], gradCAM_map, verbose=2)

## run the saliency map for the whole (cleaned) dataset
# gradcam_maps = np.zeros(cleantestset.data.shape)
# for i in range(len(cleantestset.labels)):
#     gradcam_maps[i] = saliency_constructor.grad_cam(idx=i)[0]
# scores_ins, auc_ins = insertion.evaluation(cleantestset.data, cleantestset.labels,
#                                     gradcam_maps,
#                                     batch_size=30)
# scores_del, auc_del = deletion.evaluation(cleantestset.data, cleantestset.labels,
#                                     gradcam_maps,
#                                     batch_size=30)


