""" 4. Visualize evaluated experiment results. """
import os
import numpy as np
import pandas as pd
import matplotlib

matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import seaborn as sns
from pathlib import Path
from CDExperimentSuite_DEV.src import utils
import sys
from operator import itemgetter
import re

# Array printing: scientific notation is ok, and big arrays are never truncated.
np.set_printoptions(suppress=False, threshold=sys.maxsize)


class Visualizer:
    def __init__(self, opt, accept=lambda x: True):
        """
        Load the evaluated result tables of one or several experiments and
        set up the output folder for the plots.

        Args:
            opt: Options object. The attributes read here are
                exp_name(String or list): Name(s) of the experiment(s);
                    more than one name gives a comparison across experiments
                base_dir(String): parent dir of data and viz
                thres_type, thres: select the result files, which are
                    searched as "<thres_type>_<thres>.csv" below "_eval"
                n_nodes(list): must hold exactly one entry
                overwrite: passed on to utils.create_folder for the
                    output folder
            accept(func): allows to subfilter experiments by file path, e.g.
                lambda x: ('_5_' in x or '_10_' in x) and ('gauss' in x or 'uniform' in x)
                plots only gauss uniform d=5 d=10
        """
        self.opt = opt
        self.exp_names = (
            opt.exp_name if isinstance(opt.exp_name, list) else [opt.exp_name]
        )
        base_name = os.path.basename(opt.base_dir)

        def beautify_exp_name(x):
            # experiments labelled "normalized"/"norm" are shown as "standardized"
            return x.replace("normalized", "standardized").replace(
                "norm", "standardized"
            )

        # read from all experiments provided
        self.res_dfs = []
        self.res_df_names = []
        for exp in self.exp_names:
            eval_dir = Path(os.path.join(opt.base_dir, base_name + exp, "_eval"))
            for csv_path in eval_dir.rglob(f"{opt.thres_type}_{opt.thres}.csv"):
                if not accept(str(csv_path)):
                    continue
                print("Reading from", csv_path, "...")
                df = utils.load_results(csv_path)
                df["experiment"] = beautify_exp_name(exp)
                name = csv_path.stem

                # merge datasets with same name from different experiments to allow for comparison
                try:
                    pos = self.res_df_names.index(name)
                except ValueError:
                    # first table with this name
                    self.res_dfs.append(df)
                    self.res_df_names.append(name)
                else:
                    self.res_dfs[pos] = pd.concat((self.res_dfs[pos], df), axis=0)
                    print(f"...merging across experiments into {name}")
            print()

        # output
        if len(self.exp_names) > 1:  # comparison of multiple experiments
            self.exp_name = base_name + "-vs-".join(opt.exp_name)
        else:
            self.exp_name = os.path.join(base_name + self.exp_names[0], "_viz")
        head_folder = os.path.join(opt.base_dir, self.exp_name)
        utils.create_folder(head_folder)
        assert len(opt.n_nodes) == 1
        self.output_folder = os.path.join(
            head_folder, f"thr_{opt.thres}_{max(opt.n_nodes)}nodes"
        )
        utils.create_folder(self.output_folder, opt.overwrite)

        # rename for pretty axis names
        self.pretty_algo_names = {
            "sortnregressIC": "Var-SortnRegress",
            "sortnregressIC_R2": r"$R^2$-SortnRegress",
            "randomregressIC": "RandomRegress",
        }
        self.pretty_axes_names = {
            "sid": "Structural Intervention Distance",
            "shd": "Structural Hamming Distance",
            "algorithm": "Algorithm",
        }
        # reverse lookup: pretty axis name -> original column name
        self.ugly_axes_names = {
            pretty: ugly for ugly, pretty in self.pretty_axes_names.items()
        }

    def _decorator(func, filters={}, custom_name="", acc_measure="SID"):
        """
        Common decorator for all plotting functions
        Args:
            acc_measure(String): Evaluation metric to plot
            filters(dict): Filter arguments, e.g. "graph: 'ER-2'"
        """

        def function_wrapper(
            self, custom_name="", filters=filters, acc_measure=acc_measure, **kwargs
        ):
            plt.close("all")
            function_name = func.__name__

            # do preprocessing
            for i, res_df in enumerate(self.res_dfs):

                def _select_data(filters, df):
                    res = res_df
                    for k, v in filters.items():
                        v_filter = [v] if not isinstance(v, list) else v
                        mask = res[k].apply(lambda x: x in v_filter)
                        res = res.loc[mask, :]
                    return res

                subset = _select_data(filters, res_df).copy()

                # check if empty
                if len(subset) == 0:
                    continue

                # ordering
                if "algorithm" in filters.keys():
                    subset["algorithm"] = subset["algorithm"].apply(
                        lambda x: str(filters["algorithm"].index(x)) + x
                    )
                # '_raw' always first, keep algo order from filters list
                subset.scaler = subset.scaler.replace({"Identity": "AAA"})
                subset.sort_values(by=["algorithm", "scaler"], inplace=True)
                subset.scaler = subset.scaler.replace({"AAA": "Identity"})
                if "algorithm" in filters.keys():
                    subset["algorithm"] = subset["algorithm"].apply(
                        lambda x: re.sub("^[0-9]", "", x)
                    )

                # beautify names
                subset["algorithm"].replace(self.pretty_algo_names, inplace=True)
                subset.rename(columns=self.pretty_axes_names, inplace=True)
                long_acc_measure = self.pretty_axes_names[acc_measure]

                g = func(
                    self,
                    subset,
                    res_df,
                    custom_name=custom_name,
                    acc_measure=long_acc_measure,
                    **kwargs,
                )
                if g is None:
                    print("no figure produced")
                    continue

                # save plot
                name = (
                    "_".join([acc_measure.upper()] + list(filters.keys()))
                    + f"_{''.join(self.opt.exp_name)}{custom_name}.pdf"
                ).replace(" ", "")
                g.savefig(
                    f"{self.output_folder}/{function_name}_{self.res_df_names[i]}_{name}",
                    dpi=200,
                )
                plt.close("all")

        return function_wrapper

    @_decorator
    def boxplot(self, dataset, raw_res_df, acc_measure, custom_name, static_xlim=False):
        """
        Draw split violins of the accuracy measure, one row per algorithm,
        with the hue taken from the last "_"-separated part of the experiment
        name.

        Args:
            dataset: Filtered result table with pretty column names
                ("Algorithm", acc_measure, "scaler", "experiment").
            raw_res_df: Unfiltered result table with the original column
                names; only read when static_xlim is set.
            acc_measure(String): Pretty name of the metric column in dataset.
            custom_name(String): If it contains "selection", two shaded bands
                are drawn behind the first four and the next three rows and
                explained in the legend.
            static_xlim(bool): Take the upper x limit from raw_res_df instead
                of from the filtered dataset.
        Returns:
            The matplotlib figure.
        """
        if static_xlim:
            upper_xlim = np.max(raw_res_df[self.ugly_axes_names[acc_measure]])
        else:
            upper_xlim = np.max(dataset[acc_measure])
        # one flag for both the bands and their legend entries
        is_selection = "selection" in custom_name

        fig, ax = plt.subplots(figsize=(9, 4.5))  # (10, 5)

        # background grid
        ax.set_facecolor("#FFFFFF")
        ax.grid(True, linewidth=0.5, color="#999999", linestyle="-")
        ax.set_axisbelow(True)

        # shaded bands behind the algorithm groups
        if is_selection:
            for y_start, height, color in ((-0.5, 4, "blue"), (3.5, 3, "red")):
                ax.add_patch(
                    patches.Rectangle(
                        (0, y_start),
                        upper_xlim,
                        height,
                        linewidth=1,
                        edgecolor=color,
                        facecolor=color,
                        alpha=0.1,
                        zorder=0,
                    )
                )

        # axvlines
        # median of Var-SortnRegress on "Normalizer"-scaled data, shown as the
        # "RandomRegress" reference line
        snr_name = self.pretty_algo_names["sortnregressIC"]
        random_results = dataset.loc[
            (dataset["Algorithm"] == snr_name) & (dataset["scaler"] == "Normalizer"),
            acc_measure,
        ]
        draw_random = len(random_results) > 0
        if draw_random:
            ax.axvline(x=np.median(random_results), color="grey", linestyle="--")

        # median of the "empty" pseudo-algorithm as a line instead of a violin
        is_empty_algo = dataset["Algorithm"] == "empty"
        draw_empty = False
        if is_empty_algo.any():
            norm_scaler = dataset["scaler"].apply(lambda x: "norm" in x.lower())
            empty_results = dataset.loc[is_empty_algo & norm_scaler, acc_measure]
            # without matching rows the median would be NaN: draw no line then
            if len(empty_results) > 0:
                ax.axvline(x=np.median(empty_results), color="red", linestyle="-")
                draw_empty = True
            dataset = dataset.loc[~is_empty_algo, :]

        p = itemgetter(1, 0)(sns.color_palette("tab10"))  # (1, 9)
        sns.violinplot(
            x=dataset[acc_measure],
            y=dataset["Algorithm"],
            hue=(
                dataset["experiment"]
                .apply(lambda x: x.split("_")[-1])
                .astype("category")
            ),
            split=True,
            inner=None,
            ax=ax,
            palette=p,
            zorder=1,
        )

        sns.despine(left=True, bottom=True)
        ax.set_xlabel(ax.get_xlabel(), labelpad=10)

        # legend handles: one entry per element that was actually drawn
        handles, labels = ax.get_legend_handles_labels()
        if draw_random:
            handles.append(Line2D([0], [0], color="grey", lw=2, linestyle="--"))
            labels.append("RandomRegress")
        if is_selection:
            handles.append(patches.Patch(facecolor="blue", alpha=0.3, zorder=2))
            labels.append("combinatorial")
            handles.append(patches.Patch(facecolor="red", alpha=0.3, zorder=2))
            labels.append("continuous")
        if draw_empty:
            handles.append(Line2D([0], [0], color="red", lw=1, linestyle="-"))
            labels.append("empty graph")
        ax.legend(
            handles=handles,
            labels=labels,
            bbox_to_anchor=(0.5, 1.15),
            loc="upper center",
            ncol=3,
            frameon=True,
            facecolor="white",
        )

        ax.set_xlim(0, upper_xlim)

        fig.tight_layout()
        return fig

    @_decorator
    def sortability(self, dataset, raw_res_df, acc_measure, custom_name, stb_name):
        """
        Show performance for different levels of sortability.

        For every algorithm the results are smoothed with a rolling window
        over the stb_name column: windows of width 0.1 are centered on 21
        evenly spaced points between 0 and 1, and only windows that lie fully
        inside [0, 1] and hold more than one result are plotted. One line
        (mean with 95% confidence interval) is drawn per algorithm.

        Args:
            dataset: Filtered result table with pretty column names; it must
                hold the columns "Algorithm", "edge_weight_range", stb_name
                and acc_measure. It is re-ordered in place.
            raw_res_df: Unfiltered result table (not used here).
            acc_measure(String): Pretty name of the metric column in dataset.
            custom_name(String): Not used here.
            stb_name(String): Column with the sortability values;
                "R2sortability" and "CEVsortability" get a dedicated x label.
        Returns:
            The matplotlib figure.
        """
        r2name = r"$\boldsymbol{R^2}$\textbf{-SortnRegress}"
        # fixed display order, enforced through a temporary digit prefix
        order_dict = {
            r2name: "0" + r2name,
            "PC": "1PC",
            "FGES": "2FGES",
            "RandomRegress": "3RandomRegress",
            "Var-SortnRegress\n(on raw data)": "4Var-SortnRegress\n(on raw data)",
        }
        # switch to the bold R^2 label first, so that its order_dict entry applies
        dataset.Algorithm = dataset.Algorithm.replace(
            {"$R^2$-SortnRegress": r2name}
        ).replace(order_dict)
        dataset.sort_values(by=["Algorithm", "edge_weight_range"], inplace=True)
        dataset.Algorithm = dataset.Algorithm.replace(
            {v: k for k, v in order_dict.items()}
        )

        # start plot
        fig, ax = plt.subplots(figsize=(9, 6))

        colors = [
            "#1f77b4",
            "#ff7f0e",
            "#2ca02c",
            "#808080",
            "#808080",
            "#d62728",
            "#9467bd",
        ]
        # smooth by rolling window of fixed size
        window_size = 0.1
        nwindows = 21
        half_window = 0.5 * window_size
        # the window centers are the same for every algorithm
        centers = [
            c
            for c in np.linspace(0, 1, nwindows)
            if c - half_window >= 0 and c + half_window <= 1
        ]
        for idx, a in enumerate(dataset.Algorithm.unique()):
            to_plot = {stb_name: [], "Algorithm": [], acc_measure: []}
            df = dataset.loc[dataset.Algorithm == a, :].sort_values(by=stb_name)
            for center in centers:
                in_window = (df[stb_name] >= center - half_window) & (
                    df[stb_name] < center + half_window
                )
                window_y = df.loc[in_window, acc_measure]
                if len(window_y) > 1:
                    to_plot[stb_name] += len(window_y) * [center]
                    to_plot["Algorithm"] += len(window_y) * [a]
                    to_plot[acc_measure] += list(window_y)
            if not to_plot[acc_measure]:
                # no window holds more than one result: nothing to draw
                continue
            sns.lineplot(
                data=pd.DataFrame(to_plot),
                x=stb_name,
                y=acc_measure,
                hue="Algorithm",
                style="Algorithm",
                # cycle through the colors if there are more algorithms than colors
                palette=sns.color_palette([colors[idx % len(colors)]]),
                linewidth=3,
                errorbar=("ci", 95),
                ax=ax,
            )

        if stb_name == "R2sortability":
            ax.set_xlabel(r"$R^2$-sortability")
        elif stb_name == "CEVsortability":
            ax.set_xlabel("CEV-sortability")
        ax.set_xlim(None, 1 - half_window)
        ax.legend(loc="upper right")
        fig.tight_layout()
        return fig
