import numpy as np
import pandas as pd
import os
import re
import matplotlib.pyplot as plt

# Marker glyphs used to tell individual series apart in line plots.
MARKERS = [
    "o", "v", "x", "s", "+", "d",
    "1", "*", "^", "p", "h",
]
# Entries of matplotlib's "tab" colormaps, materialised as plain lists.
TAB10, TAB20, TAB20B, TAB20C = (
    list(plt.get_cmap(cmap_name).colors)
    for cmap_name in ("tab10", "tab20", "tab20b", "tab20c")
)
# The "rainbow" colormap evaluated at the integer positions 0..255.
RAINBOW = list(map(plt.cm.rainbow, range(256)))
# Fixed palette of 24 hex colors (used by get_colors for 11-24 clusters).
CATEGORICAL = [
    "#4f8c9d", "#fa756b", "#20b465", "#ce2bbc",
    "#51f310", "#660081", "#c9dd87", "#3f3369",
    "#f6bb86", "#0c4152", "#edb4ec", "#0b5313",
    "#b0579a", "#f4d403", "#7d0af6", "#698e4e",
    "#fb2076", "#65e6f9", "#74171f", "#b7c8e2",
    "#473a0a", "#7363e7", "#9f6c3b", "#1f84ec",
]


def get_colors(n, color_map=None):
    """Get colors for plotting cell clusters.
    The colors can either be a categorical colormap or a continuous colormap.

    Args:
        n (int):
            Number of cell clusters
        color_map (str, optional):
            Name of a user-defined matplotlib colormap.
            If not set, the colors are taken from the module palettes:
            TAB10 (n <= 10), CATEGORICAL (n <= 24), TAB20B + TAB20C (n <= 40)
            and RAINBOW otherwise.
            Defaults to None.

    Returns:
        list: list of color parameters. At most n entries are returned; a
        user-defined colormap with fewer than n entries is returned in full.
    """
    if color_map is None:  # default color
        if n <= 10:
            return TAB10[:n]
        if n <= 24:
            return CATEGORICAL[:n]
        if n <= 40:
            return (TAB20B + TAB20C)[:n]
        print("Warning: Number of colors exceeds the maximum (40)! Use a continuous colormap (256) instead.")
        return RAINBOW[:n]

    cmap = plt.get_cmap(color_map)
    # Listed colormaps expose their entries through `.colors`; continuous
    # colormaps do not, so their lookup table is sampled index by index.
    if hasattr(cmap, "colors"):
        palette = list(cmap.colors)
    else:
        palette = [cmap(i) for i in range(cmap.N)]
    if n <= 0:  # nothing requested; also avoids dividing by zero below
        return []
    step = len(palette) // n
    if step == 0:  # fewer entries than requested: hand back all of them
        return palette
    # Take evenly spaced entries and cap the result at exactly n colors.
    return palette[::step][:n]


class PerfLogger:
    """Class for saving and plotting the performance metrics.

    Two data frames are maintained:
        df: rows indexed by (Metrics, Model), one column per dataset.
        df_multi: rows indexed by (Metrics, Model, Param Name, Param Value),
            one column per dataset. Holds hyperparameter sweeping results.
    """
    def __init__(self, save_path='perf', checkpoints=None):
        """Constructor

        Args:
            save_path (str, optional):
                Path for saving the data frames to .csv files. Defaults to 'perf'.
                The folder is created if it does not exist.
            checkpoints (list[str], optional):
                Existing results to load (.csv): checkpoints[0] is read into
                self.df and checkpoints[1] into self.df_multi.
                Defaults to None.
        """
        self.save_path = save_path
        os.makedirs(save_path, exist_ok=True)
        self.n_dataset = 0
        self.n_model = 0
        self.metrics = ["Precision",
                        "Recall",
                        "F1",
                        "SHD"]
        self.metrics_type = ["Precision",
                             "Recall",
                             "F1",
                             "SHD"]
        if checkpoints is None:
            self._create_empty_df()
        else:
            self.df = pd.read_csv(checkpoints[0], header=[0], index_col=[0, 1])
            self.df_multi = pd.read_csv(checkpoints[1], header=[0], index_col=[0, 1, 2, 3])
            # Count over both tables so the totals agree with insert(),
            # which treats a dataset/model as known if either table has it.
            self.n_dataset = len(set(self.df.columns) | set(self.df_multi.columns))
            self.n_model = len(set(self.df.index.unique(level=1))
                               | set(self.df_multi.index.unique(level=1)))

    def _create_empty_df(self):
        """Initialize self.df and self.df_multi as empty data frames."""
        row_mindex = pd.MultiIndex.from_arrays([[], []], names=["Metrics", "Model"])
        row_mindex_2 = pd.MultiIndex.from_arrays([[], [], [], []],
                                                 names=["Metrics", "Model", "Param Name", "Param Value"])
        col_index = pd.Index([], name='Dataset')
        self.df = pd.DataFrame(index=row_mindex, columns=col_index)
        self.df_multi = pd.DataFrame(index=row_mindex_2, columns=col_index)

    def insert(self, data_name, method, perf, param=None):
        """Insert the performance evaluation results

        Args:
            data_name (str):
                Name of the dataset
            method (str):
                Name of the method
            perf (dict):
                Dictionary of performance metrics
            param (tuple[str, int]):
                Tuple of the parameter name and its value.
                Defaults to None. If set to None, self.df will be updated.
                Otherwise, self.df_multi will be updated and perf will be considered
                a result from hyperparameter sweeping.
        """
        if data_name not in self.df_multi.columns and data_name not in self.df.columns:
            self.n_dataset += 1

        if (method not in self.df.index.unique(level=1)
                and method not in self.df_multi.index.unique(level=1)):
            self.n_model += 1

        if param is None:
            for metric, value in perf.items():
                self.df.loc[(metric, method), data_name] = value
        else:
            for metric, value in perf.items():
                self.df_multi.loc[(metric, method, param[0], param[1]), data_name] = value
        # Sorting once after all rows are written gives the same final order
        # as sorting after every single assignment.
        self.df.sort_index(inplace=True)
        self.df_multi.sort_index(inplace=True)

    def plot_summary(self, metrics=None, methods=None, figure_path=None, dpi=100):
        """Generate boxplots showing the overall performance metrics.
        Each plot shows one metric over all datasets, with methods as x-axis labels
        and one box per method.

        Args:
            metrics (list[str], optional):
                Performance metric to plot.
                If set to None, all metrics will be plotted.
            methods (list[str], optional):
                Methods to compare.
                If set to None, all existing methods will be included.
            figure_path (str, optional):
                Path to the folder for saving figures.
                The folder is created if it does not exist.
                If set to None, figures will not be saved.
                Defaults to None.
            dpi (int, optional):
                Resolution of the saved figures. Defaults to 100.
        """
        if methods is None:
            methods = np.array(self.df.index.unique(level=1)).astype(str)
        if metrics is None:
            metrics = self.metrics
        # Derive the count from the methods actually plotted; self.n_model also
        # counts methods that only appear in the sweep table.
        n_model = len(methods)
        if n_model == 0:
            return
        colors = get_colors(n_model)
        if figure_path is not None:
            os.makedirs(figure_path, exist_ok=True)
        for metric in metrics:
            if metric not in self.df.index:
                continue
            df_plot = self.df.loc[metric].loc[methods]
            # One sample per method, with missing datasets dropped: boxplot
            # statistics become NaN (and the box invisible) if NaNs are kept.
            scores = df_plot.to_numpy(dtype=float)
            vals = [row[~np.isnan(row)] for row in scores]
            fig, ax = plt.subplots(figsize=(1.6*n_model+3, 4))
            # Vertical boxes (the default) filled with color; the x-tick
            # labels are assigned below through set_xticks.
            bplot = ax.boxplot(vals, patch_artist=True)
            for patch, color in zip(bplot['boxes'], colors):
                patch.set_facecolor(color)
            for line in bplot['medians']:
                line.set_color('black')
            for line in bplot['means']:
                line.set_color('black')
            ax.set_xlabel("")
            ax.set_title(metric)
            ax.set_xticks(range(1, n_model+1), methods, rotation=0)
            ax.tick_params(axis='both', which='major', labelsize=15)
            fig.tight_layout()
            if figure_path is not None:
                fig.savefig(f'{figure_path}/{metric}_summary.png', dpi=dpi, bbox_inches='tight')

    def _plot_metrics_ax(self,
                         metric,
                         datasets,
                         methods,
                         show_legend=True,
                         ax=None,
                         **kwargs):
        """Draw a grouped bar plot of one metric.

        Args:
            metric (str):
                Metric to plot; must be present in self.df.
            datasets (list[str]):
                Datasets shown along the x axis.
            methods (list[str]):
                Methods to compare; each method gets one bar per dataset.
            show_legend (bool, optional):
                Whether to draw the legend. Defaults to True.
            ax (matplotlib.axes.Axes, optional):
                Axes to draw on, forwarded to the pandas plotting call.
                Defaults to None.
            **kwargs:
                Required keys: 'figsize', 'title_fontsize', 'tick_fontsize'
                and, if show_legend is set, 'legend_fontsize'.
                Optional keys: 'ncols', 'bbox_to_anchor'.

        Returns:
            matplotlib.axes.Axes: the axes containing the plot
        """
        datasets = list(datasets)
        methods = list(methods)
        colors = get_colors(len(methods))
        # Select the requested datasets as well, so that the bars match the
        # tick labels assigned below.
        df_plot = self.df.loc[metric].loc[methods, datasets].T
        ax = df_plot.plot.bar(color=colors,
                              figsize=kwargs['figsize'],
                              legend=show_legend,
                              ax=ax)
        ax.set_xlabel("")
        ax.set_xticklabels(datasets, rotation=0)
        ax.set_title(metric, fontsize=kwargs['title_fontsize'])
        ax.tick_params(axis='both', which='major', labelsize=kwargs['tick_fontsize'])
        if show_legend:
            ncols = kwargs.get('ncols', 1)
            if 'bbox_to_anchor' in kwargs:
                loc = 'center' if len(kwargs['bbox_to_anchor']) == 4 else 1
                ax.legend(fontsize=kwargs['legend_fontsize'],
                          loc=loc,
                          ncol=ncols,
                          bbox_to_anchor=kwargs['bbox_to_anchor'])
            else:
                ax.legend(fontsize=kwargs['legend_fontsize'], ncol=ncols)
        return ax

    def plot_metrics(self,
                     metrics=None,
                     datasets=None,
                     methods=None,
                     figure_path=None,
                     figsize=(12, 6),
                     title_fontsize=20,
                     legend_fontsize=16,
                     tick_fontsize=15,
                     bbox_to_anchor=(1.25, 1.0),
                     dpi=100):
        """Generate bar plots showing all scalar performance metrics.
        Each plot has different datasets as x-axis labels and different bars represent methods.

        Args:
            metrics (list[str], optional):
                Performance metrics to plot.
                If None or empty, all metrics will be plotted. Defaults to None.
            datasets (list[str], optional):
                Datasets to plot.
                If None or empty, all datasets will be plotted. Defaults to None.
            methods (list[str], optional):
                Methods to compare.
                If None or empty, all methods will be included. Defaults to None.
            figure_path (str, optional):
                Path to the folder for saving figures.
                The folder is created if it does not exist.
                If set to None, figures will not be saved.
                Defaults to None.
            figsize (tuple, optional):
                Size of the figure. Defaults to (12, 6).
            title_fontsize (int, optional):
                Font size of the title. Defaults to 20.
            legend_fontsize (int, optional):
                Font size of the legend. Defaults to 16.
            tick_fontsize (int, optional):
                Font size of the ticks. Defaults to 15.
            bbox_to_anchor (tuple, optional):
                Location of the legend. Defaults to (1.25, 1.0).
            dpi (int, optional):
                Resolution of the saved figures. Defaults to 100.
        """
        # len() rather than truthiness: the arguments may be numpy arrays or
        # pandas indexes, whose truth value is ambiguous.
        if datasets is None or len(datasets) == 0:
            datasets = self.df.columns.unique()
        if methods is None or len(methods) == 0:
            methods = np.array(self.df.index.unique(level=1)).astype(str)
        if metrics is None or len(metrics) == 0:
            metrics = self.metrics

        if figure_path is not None:
            os.makedirs(figure_path, exist_ok=True)

        for metric in metrics:
            # Skip unknown metrics and metrics without any recorded row.
            if metric not in self.metrics or metric not in self.df.index:
                continue
            # Nothing to draw if every recorded value is missing.
            if self.df.loc[metric].isna().to_numpy().all():
                continue
            # File-name friendly version of the metric name.
            fig_name = re.sub(r'\W+', ' ', metric.lower())
            fig_name = '_'.join(fig_name.rstrip().split())
            ax = self._plot_metrics_ax(metric,
                                       datasets,
                                       methods,
                                       figsize=figsize,
                                       title_fontsize=title_fontsize,
                                       legend_fontsize=legend_fontsize,
                                       tick_fontsize=tick_fontsize,
                                       bbox_to_anchor=bbox_to_anchor)
            fig = ax.get_figure()
            fig.tight_layout()
            if figure_path is not None:
                fig.savefig(f'{figure_path}/perf_{fig_name}.png', bbox_inches='tight', dpi=dpi)

    def _plot_metrics_sweep_ax(self,
                               param_name,
                               metric,
                               method,
                               datasets,
                               show_legend=True,
                               ax=None,
                               **kwargs):
        """Draw one metric of one method against a swept parameter.

        Args:
            param_name (str):
                Name of the swept parameter (x axis).
            metric (str):
                Metric to plot (y axis).
            method (str):
                Method whose sweeping results are plotted.
            datasets (list[str]):
                Datasets to plot; each dataset is drawn as one line.
            show_legend (bool, optional):
                Whether to draw the legend. Defaults to True.
            ax (matplotlib.axes.Axes, optional):
                Axes to draw on, forwarded to the pandas plotting call.
                Defaults to None.
            **kwargs:
                Required keys: 'markersize', 'linewidth', 'figsize',
                'title_fontsize', 'ylabel_fontsize', 'labelpad',
                'tick_fontsize' and, if show_legend is set, 'legend_fontsize'.
                Optional keys: 'ncols', 'bbox_to_anchor'.

        Returns:
            matplotlib.axes.Axes: the axes containing the plot
        """
        colors = get_colors(len(datasets))
        # Read the x values from the metric being plotted, not from the first
        # known metric, which may not have been swept at all.
        sweep = self.df_multi.loc[(metric, method, param_name)]
        param_vals = sweep.index.to_numpy().astype(float)
        for i, dataset in enumerate(datasets):
            series = self.df_multi.loc[(metric, method, param_name), dataset]
            ax = series.plot(marker=MARKERS[i % len(MARKERS)],  # reuse markers when they run out
                             markersize=kwargs['markersize'],
                             color=colors[i],
                             label=dataset,
                             ax=ax,
                             linewidth=kwargs['linewidth'],
                             figsize=kwargs['figsize'])

        ax.set_xticks(param_vals, param_vals, rotation=0)
        # Automatically check log scale: switch when the values are positive
        # and the mean gap between consecutive values exceeds 0.3 in log10.
        if param_vals.size > 1 and np.all(param_vals > 0):
            logvals = np.log10(param_vals)
            if np.diff(logvals).mean() > 0.3:
                ax.set_xscale('log')
        if show_legend:
            ncols = kwargs.get('ncols', 1)
            if 'bbox_to_anchor' in kwargs:
                ax.legend(fontsize=kwargs['legend_fontsize'],
                          ncol=ncols,
                          loc='center',
                          bbox_to_anchor=kwargs['bbox_to_anchor'])
            else:
                ax.legend(fontsize=kwargs['legend_fontsize'], ncol=ncols)
        ax.set_title(method, fontsize=kwargs['title_fontsize'])
        ax.set_xlabel(param_name, fontsize=kwargs['ylabel_fontsize'])
        ax.set_ylabel(metric, fontsize=kwargs['ylabel_fontsize'], labelpad=kwargs['labelpad'])
        ax.tick_params(axis='both', which='major', labelsize=kwargs['tick_fontsize'])
        return ax

    def save(self, file_name=None):
        """Save data frames to .csv files.

        self.df is written to '<file_name>.csv' and self.df_multi to
        '<file_name>_multi.csv', both inside self.save_path.

        Args:
            file_name (str, optional):
                Name of the csv file for saving. Does not need the path
                as the path is specified when an object is created.
                If set to None, will pick 'perf' as the default name.
                Defaults to None.
        """
        if file_name is None:
            file_name = "perf"
        self.df.to_csv(f"{self.save_path}/{file_name}.csv")
        self.df_multi.to_csv(f"{self.save_path}/{file_name}_multi.csv")
