"""
Plotting helpers for visualization-method results: feature maps overlaid on the
original data, temporal histograms, and confusion matrices.

"""
## -------------------
## --- Standard library & third-party ---
## -------------------
import sys
sys.path.append('..')
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.collections import LineCollection
from typing import Dict

## -----------
## --- Own ---
## -----------
from utils import create_directory

def plot_vis_plt(feature_map, x, data, cla, name: str, save_path: str, method_name: str):
    """
    Scatter rows 0, 1 and 2 of ``data`` (the X, Y, Z axes) against ``x``, colour
    every point by ``feature_map``, and save the figure as a PNG.

    feature_map: colour values, squeezed before use; presumably one value per
        position in ``x`` -- TODO confirm against callers.
    x: the length of data (ex: arange())
    data: array whose first three rows are plotted; any further rows are ignored.
    cla: not used by this function; kept so existing callers keep working.
    name: used in the title and in the output file name.
    save_path: directory the PNG is written to.
    method_name: prefix of the output file name.

    The figure is written to ``{save_path}/{method_name}_{name}.png`` and then
    closed, so repeated calls do not accumulate open pyplot figures.
    """
    fig = plt.figure(figsize=(13, 7))
    # The colour values are the same for every axis, so squeeze only once.
    colors = feature_map.squeeze()
    for axis_idx in range(3):
        # Fixed 0-100 colour range, so all three axes share one colour scale.
        plt.scatter(x, data[axis_idx, :], cmap='jet', marker='.', c=colors,
                    s=5, vmin=0, vmax=100, linewidths=3.0)
    plt.title(f'{name} X Y Z axis')
    plt.colorbar()
    save_path = save_path + "/" + f"{method_name}_{name}.png"
    plt.savefig(save_path)
    # This function never shows the figure. Without an explicit close, every
    # call would leave one pyplot figure open (matplotlib warns after 20).
    plt.close(fig)


def plot_vis_plt_ucr(feature_map, x, data, label: int, predict_label: int,
                     save_path: str = None, vis_name: str = None):
    """
    Plot each row of ``data`` with markers coloured by its feature map (top
    panel), and the feature map itself as a colour-mapped line (bottom panel).

    how to plot colorful line:
    https://matplotlib.org/stable/gallery/lines_bars_and_markers/multicolored_line.html

    feature_map (np.array) : The feature map from vis. Methods; row ``i`` colours
        row ``i`` of ``data``
    x: the length of data (ex: arange())
    data (np.array) : The raw data; one line is drawn per row
    label (int) : The label of the data
    predict_label (int) : The label that model predicted
    save_path: directory to save the figure in; nothing is saved when None.
        Existing files are never overwritten: ``_1``, ``_2``, ... is appended
        until the file name is free.
    vis_name: name of the visualization method, used in the title and file name
    """
    fig, axs = plt.subplots(2, 1, sharex=True, gridspec_kw=dict(height_ratios=[3, 1]))
    for i in range(data.shape[0]):
        row_map = feature_map[i]

        ## first subplot: raw series, with markers coloured by the feature map
        axs[0].plot(x, data[i, :], linewidth=2)
        color = axs[0].scatter(x, data[i, :], cmap='hot_r', marker='.', c=row_map.squeeze(),
                               s=100, vmin=np.min(row_map), vmax=np.max(row_map),
                               linewidths=3.0)
        fig.colorbar(color, ax=axs[0])

        ## second subplot: feature map drawn as a line whose colour follows its value
        points = np.array([x, feature_map[i, :]]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        lc = LineCollection(segments=segments, cmap='hot_r')
        # Set the values used for colormapping
        lc.set_array(feature_map[i, :])
        lc.set_linewidth(2)
        line = axs[1].add_collection(lc)
        fig.colorbar(line, ax=axs[1])

    # add_collection() updates the data limits but does not rescale the view.
    # Without this call the bottom panel keeps its default y-range and can
    # clip the feature-map line.
    axs[1].autoscale_view()
    # Put the title on the top panel, and set it whether or not the figure is
    # saved. The previous plt.title() acted on the current axes, which is the
    # bottom panel after plt.subplots().
    axs[0].set_title(f'{vis_name}: Target : {label}, Prediction : {predict_label}')
    # Lay out once, after all panels and colorbars exist.
    fig.tight_layout()

    if save_path is not None:
        base_path = save_path + "/" + f"{vis_name}_prediction_{predict_label}"
        new_save_path = base_path + ".png"
        suffix = 1
        while os.path.exists(new_save_path):
            new_save_path = base_path + "_{}.png".format(suffix)
            suffix += 1
        fig.savefig(new_save_path, bbox_inches='tight', dpi=200)
    plt.show()

def plot_histogram(histogram, classes, label_dicts,
                   save_path: str, vis_method: str, threshold: float):
    """
    For each class, plot one line per feature dimension of its temporal
    histogram. Save one PNG per class under ``{save_path}/temporal_histograms/``,
    then show all figures.

    histogram: indexed as ``histogram[class_idx, dim, t]`` (3-D)
    classes: class keys; the i-th key is paired with ``histogram[i]``
    label_dicts: maps each entry of ``classes`` to the label shown in the
        title and file name
    save_path: base output directory
    vis_method: visualization method name, used in the title and file name
    threshold: used in the title and file name
    """
    num_dims = histogram.shape[1]
    # Pick one colour per dimension, evenly spaced through the default colormap.
    # Calling the colormap with integers indexes its lookup table, which works
    # for every colormap type (``.colors`` exists only on ListedColormap). The
    # max() guards keep the step non-zero when there are more dimensions than
    # colours, and avoid dividing by zero when there are no dimensions.
    base_cmap = plt.get_cmap()
    step = max(1, base_cmap.N // max(num_dims, 1))
    cmap = base_cmap(np.arange(num_dims) * step)

    # Output directory and x positions do not depend on the class: compute once.
    hist_dir = save_path + "/temporal_histograms/"
    create_directory(hist_dir)
    x = np.arange(0, np.shape(histogram)[-1])

    for i, c in enumerate(classes):
        plt.figure(figsize=(13, 7))
        for j in range(num_dims):
            plt.plot(x, histogram[i, j, :], label=f"dim_{j}", c=cmap[j])
        plt.legend()
        plt.title(f"Histogram with Class_{label_dicts[c]}_threshold_{threshold}_{vis_method}")
        file_path = hist_dir + f"class_{label_dicts[c]}_threshold_{threshold}_{vis_method}.png"
        plt.savefig(file_path, bbox_inches="tight", dpi=720)
    plt.show()

def plot_hist(histogram: np.ndarray):
    """
    Draw a grid of bar charts with one panel per (class, feature dimension)
    pair, then show the figure.

    Parameters
    ----------
    histogram (np.ndarray) : the histogram, created from Temporal Instability;
        indexed as ``histogram[class_idx, dim, t]`` (3-D). Row ``i`` of the grid
        is class ``i``, and column ``j`` is dimension ``j``.
    """
    num_classes, num_dims = histogram.shape[0], histogram.shape[1]
    # Every panel uses the same x positions: one bar per time step.
    x = np.arange(int(histogram.shape[-1]))

    fig = plt.figure()
    plots_idx = 1
    for i in range(num_classes):
        for j in range(num_dims):
            plt.subplot(num_classes, num_dims, plots_idx)
            plt.bar(x, histogram[i, j, :])
            plt.xlabel("time axis")
            plt.ylabel("number of hits")
            plt.title("Histogram Class: {}, Features: {}".format(i, j))
            plots_idx += 1
    # Lay out once, after every panel exists.
    fig.tight_layout()

    plt.show()

def plot_confusion_matrix(criterions: Dict, save_path: str, metric_name: str):
    """
    Draw the confusion matrix stored in ``criterions`` with a count in each
    cell, save it as a PNG in ``save_path``, and show it.

    criterions: must contain ``f"{metric_name}_confusion_matrix"`` (square,
        indexable as ``cm[i][j]``), ``"label_summary"`` (a dict whose values
        become the tick labels, in order), ``"Dataset"`` and ``"Classifier"``.
        ``"threshold_highlightpoint"`` and ``"threshold_count_histogram"`` are
        optional; when absent they appear as ``None`` in the title and file name.
    save_path: directory the PNG is written to
    metric_name: selects the confusion matrix; also used in the title and file name
    """
    # Use .get() rather than a guarded assignment. Previously a missing key left
    # the name unbound, and the f-strings below raised UnboundLocalError.
    threshold_count = criterions.get("threshold_count_histogram")
    threshold_highlight = criterions.get("threshold_highlightpoint")
    cm = criterions[f"{metric_name}_confusion_matrix"]
    new_save_path = save_path + f"/confusion_matrix_thres_highlightpoint_{threshold_highlight}_thres_count_hist_{threshold_count}_{metric_name}_perturbation.png"

    plt.clf()
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Wistia)
    class_names = list(criterions["label_summary"].values())

    plt.title(f'Confusion Matrix of threshold_highlightpoint {threshold_highlight} thres_count_hist_{threshold_count} {metric_name}_Perturbation of'
              f' {criterions["Dataset"]} with {criterions["Classifier"]}')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    for i in range(len(class_names)):
        for j in range(len(class_names)):
            # Centre each count in its cell. The default left/baseline
            # alignment anchors the text at the cell centre, shifting it off-centre.
            plt.text(j, i, str(cm[i][j]), ha="center", va="center")
    plt.savefig(new_save_path)
    plt.show()