import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
# Global plotting defaults (applied once, at import time).
plt.rcParams.update({"font.family": "sans-serif"})
# Serif / LaTeX alternative, disabled by default; uncomment these lines to use it.
# plt.rcParams["font.family"] = "Times"
# plt.rcParams['text.usetex'] = True
# plt.rcParams['text.latex.preamble'] = r'\usepackage{mathptmx} \usepackage{amsmath}'


def plot_single_metric(metric_array:np.ndarray,
                    se_array:np.ndarray=None,
                    indi_data_array:np.ndarray=None,
                    fig_title:str="Accuracy",
                    y_label:str="Accuracy",
                    chance_level:float=0.5,
                    y_lim:tuple=(None, None),
                    block_lines:tuple=(8.5, 16.5)):
    """Plot a per-block metric with optional error bars and individual data points.

    Block ``k`` (1-based) is drawn at x = k as a diamond marker.

    Args:
        metric_array (np.ndarray): one value per block; its first dimension sets
            the number of blocks.
        se_array (np.ndarray, optional): symmetric error-bar half-widths, one per
            block (e.g. standard error of the mean). Defaults to None (no bars).
        indi_data_array (np.ndarray, optional): 2-D array (n_subjects, n_blocks)
            of individual values, drawn as small semi-transparent points with
            horizontal jitter (normal, SD 0.05, from the global NumPy RNG).
            Defaults to None.
        fig_title (str, optional): title of the figure. Defaults to "Accuracy".
        y_label (str, optional): y-axis label. Defaults to "Accuracy".
        chance_level (float, optional): y position of a dashed red reference
            line; a falsy value (None or 0) disables it. Defaults to 0.5.
        y_lim (tuple, optional): (bottom, top) limits; None keeps autoscaling on
            that side. Defaults to (None, None).
        block_lines (tuple, optional): x positions of vertical dashed separators;
            a falsy value disables them. Defaults to (8.5, 16.5).

    Returns:
        tuple: (fig, ax) of the created figure.
    """
    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; use the colormap
    # registry instead. Sampling at 0.2 is what cmap(Normalize(0, 1)(0.2)) gave.
    color = matplotlib.colormaps['Set1'](0.2)
    fig, ax = plt.subplots(1, 1, figsize=(4.2, 4), dpi=170, facecolor='w')

    # Vertical separators between groups of blocks
    if block_lines:
        for line_x in block_lines:
            ax.axvline(x=line_x, color='k', linestyle=(5, (10, 3)), linewidth=1.2, alpha=0.5)

    n_blocks = metric_array.shape[0]
    block_x = np.arange(1, n_blocks + 1)

    # Average per block: one artist for all blocks; linestyle 'None' keeps the
    # markers unconnected, as the former one-call-per-block loop did.
    ax.errorbar(
        x=block_x,
        y=metric_array,
        yerr=se_array,
        color=color,
        markersize=5,
        marker='D',
        linestyle='None',
        alpha=1
    )

    # Individual data, jittered horizontally around each block position
    if indi_data_array is not None:
        n_subjects = indi_data_array.shape[0]
        for block_idx in range(n_blocks):
            jitter = np.random.normal(0, 0.05, n_subjects)
            ax.plot(
                (block_idx + 1) + jitter,
                indi_data_array[:, block_idx],
                color=color,
                markersize=3,
                marker='o',
                alpha=0.5,
                lw=0
            )

    # Chance level
    if chance_level:
        ax.axhline(y=chance_level, color='r', linestyle='--', alpha=0.5)

    # Aesthetics. Plain '#': LaTeX rendering is disabled in this module, so the
    # former '\#' showed a literal backslash (and was an invalid escape sequence).
    ax.set_title(fig_title, fontweight='bold')
    ax.set_xlabel('Block #', fontsize=16, labelpad=10)
    ax.set_ylabel(y_label, fontsize=16, labelpad=10)
    ax.set_ylim(y_lim[0], y_lim[1])
    ax.grid(alpha=0.6)
    ax.set_axisbelow(True)

    # Tick the first block and every 4th block after it. Block numbering starts
    # at 1, so no tick at 0 (which would stretch the axis to an empty column).
    ax.set_xticks(np.concatenate([[1], np.arange(4, n_blocks + 1, 4)]))
    # Hide the y tick marks but keep their labels
    ax.tick_params(axis='y', which='both', length=0, labelsize=13)
    ax.tick_params(axis='x', which='both', labelsize=13)
    ax.margins(x=0.02)
    fig.tight_layout()
    return fig, ax


def plot_three_metrics(metric_array:np.ndarray,
                    se_array:np.ndarray=None,
                    fig_title:str=None,
                    y_label:str="Accuracy",
                    chance_level:float=0.5,
                    y_lim:tuple=(None, None),
                    color_list:tuple =None,
                    marker_list:tuple = ('o', 'o', 'o'),
                    label_list:tuple = ("Top level", "Mid level", "Bot level"),
                    leg_loc:tuple = (.75, .2),
                    block_lines:tuple=(8.5, 16.5),
                    theory_baseline:tuple=None,
                    function_fits:np.ndarray=None,
                    title_margin:float=1.0,
                    line_width:float=1.5,
                    labelpad:int=3,
                    title_fontsize:int=16):
    """Plot one learning curve per metric (usually the three hierarchy levels).

    Each row of ``metric_array`` is drawn as a line, optionally with a shaded
    +/- ``se_array`` band. Without ``theory_baseline``, block ``k`` (1-based) is
    drawn at x = k. With ``theory_baseline``, the baseline values are prepended
    as a point at x = 0 on every curve and marked with black "P" markers
    labelled "Chance".

    Args:
        metric_array (np.ndarray): 2-D array of shape (n_levels, n_blocks).
        se_array (np.ndarray, optional): same shape as ``metric_array``; drawn
            as a shaded band. Defaults to None (no band).
        fig_title (str, optional): Defaults to None (no title).
        y_label (str, optional): Defaults to "Accuracy".
        chance_level (float, optional): y position of a dashed red reference
            line; a falsy value disables it. Defaults to 0.5.
        y_lim (tuple, optional): (bottom, top) limits of the y axis; None keeps
            autoscaling on that side. Defaults to (None, None).
        color_list (sequence, optional): one color per level. Defaults to evenly
            spaced samples of the 'cividis_r' colormap.
        marker_list (tuple, optional): currently unused; kept for backward
            compatibility. Defaults to ('o', 'o', 'o').
        label_list (tuple, optional): legend label per level; a falsy value
            disables labels. Defaults to ("Top level", "Mid level", "Bot level").
        leg_loc (tuple, optional): legend location; None hides the legend.
            Defaults to (.75, .2).
        block_lines (tuple, optional): x positions of vertical dashed
            separators; a falsy value disables them. Defaults to (8.5, 16.5).
        theory_baseline (sequence, optional): one baseline value per level.
            Defaults to None.
        function_fits (np.ndarray, optional): either shape (n_levels, n_points),
            one fit per level drawn in that level's color, or 1-D (n_points,),
            a single fit drawn without an explicit color. Fits are plotted on
            ``np.linspace(0, n_blocks, n_points)``; a 2-D array whose first
            dimension differs from n_levels is ignored. Defaults to None.
        title_margin (float, optional): vertical title position in axes
            coordinates. Defaults to 1.0.
        line_width (float, optional): width of the metric lines. Defaults to 1.5.
        labelpad (int, optional): padding of both axis labels. Defaults to 3.
        title_fontsize (int, optional): Defaults to 16.

    Returns:
        tuple: (fig, ax) of the created figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4.2, 4), dpi=170, facecolor='w')

    # Vertical separators between groups of blocks
    if block_lines:
        for line_x in block_lines:
            ax.axvline(x=line_x, color='k', linestyle=(5, (10, 3)), linewidth=1.2, alpha=0.5)

    n_levels = metric_array.shape[0]
    n_blocks = metric_array.shape[1]
    if color_list is None:
        color_list = matplotlib.colormaps['cividis_r'](np.linspace(.09, .85, n_levels))

    # The len() check keeps the old truthiness semantics for tuples/lists (an
    # empty baseline is ignored) while also accepting NumPy arrays, whose truth
    # value is ambiguous.
    has_baseline = theory_baseline is not None and len(theory_baseline) > 0
    if has_baseline:
        # Prepend the baseline as an extra "block 0" column with zero uncertainty.
        metric_array = np.hstack((np.asarray(theory_baseline)[:, np.newaxis], metric_array))
        if se_array is not None:
            se_array = np.hstack((np.zeros((n_levels, 1)), se_array))
        x_values = np.arange(0, n_blocks + 1)
    else:
        # No extra column: n_blocks points, numbered from 1 (previously this
        # used n_blocks + 1 x-values and raised a length-mismatch ValueError).
        x_values = np.arange(1, n_blocks + 1)

    if chance_level:
        ax.axhline(y=chance_level, color='r', linestyle='--', alpha=0.5)

    # The fit grid follows the fits' own resolution instead of a fixed 100 points.
    if function_fits is not None:
        function_fits = np.asarray(function_fits)
        fit_x = np.linspace(0, n_blocks, function_fits.shape[-1])
    per_level_fits = (function_fits is not None
                      and function_fits.ndim == 2
                      and function_fits.shape[0] == n_levels)

    # One line (plus optional SE band and fit) per level
    for j in range(n_levels):
        ax.plot(x_values, metric_array[j], color=color_list[j],
                label=label_list[j] if label_list else None, linewidth=line_width)
        if se_array is not None:
            ax.fill_between(x_values, metric_array[j] - se_array[j], metric_array[j] + se_array[j],
                            color=color_list[j], alpha=0.5)
        if per_level_fits:
            ax.plot(fit_x, function_fits[j], color=color_list[j], linestyle='-', alpha=0.8)

    # A single (1-D) function fit shared by all levels
    if function_fits is not None and function_fits.ndim == 1:
        ax.plot(fit_x, function_fits, linestyle='-', alpha=0.8)

    # Chance markers at x = 0, one per baseline value (not hard-coded to 3)
    if has_baseline:
        ax.errorbar(
            x=np.zeros(len(theory_baseline)),
            y=theory_baseline,
            color='Black',
            marker='P',
            markersize=8,
            alpha=1,
            label='Chance',
            linestyle='None'
        )

    # Aesthetics
    if fig_title is not None:
        ax.set_title(fig_title, y=title_margin, fontsize=title_fontsize)
    ax.set_xlabel('Block #', fontsize=16, labelpad=labelpad)
    ax.set_ylabel(y_label, fontsize=16, labelpad=labelpad)
    ax.set_ylim(y_lim[0], y_lim[1])
    ax.grid(alpha=0.6)
    ax.set_axisbelow(True)
    # Tick every 4th block; without a baseline there is no x = 0 point, so tick
    # the first block instead of 0.
    if has_baseline:
        x_ticks = np.arange(0, n_blocks + 1, 4)
    else:
        x_ticks = np.concatenate([[1], np.arange(4, n_blocks + 1, 4)])
    ax.set_xticks(x_ticks)
    # Hide tick marks but keep their labels
    ax.tick_params(axis='both', which='both', length=0, labelsize=15)
    ax.margins(x=0.02)

    for spine in ax.spines.values():
        spine.set_linewidth(1.3)

    if leg_loc is not None:
        ax.legend(loc=leg_loc, borderpad=0.5, fontsize=9, frameon=False)

    fig.tight_layout()
    return fig, ax


def _save_figure(fig_folder, fig_name, fig, formats=('png', 'pdf', 'svg'), base_dir=None):
    """Save ``fig`` once per file format into a (possibly new) plot folder.

    Files are written to ``<base_dir>/<fig_folder>/<fig_name>.<ext>`` for each
    extension in ``formats``; the folder is created if it does not exist.

    Args:
        fig_folder: sub-folder below ``base_dir``. Leading/trailing path
            separators are ignored, so '/accuracy/', 'accuracy/' and 'accuracy'
            are equivalent; '' saves directly into ``base_dir``.
        fig_name: file name without extension.
        fig: object with a ``savefig(path)`` method (e.g. a matplotlib Figure).
        formats: file extensions to write, in this order.
            Defaults to ('png', 'pdf', 'svg').
        base_dir: root folder for plots. Defaults to ``../../results/plots``
            relative to the directory containing this module.
    """
    if base_dir is None:
        base_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'results', 'plots')
    # Strip separators so os.path.join treats the folder as relative: a leading
    # separator would otherwise make join discard base_dir entirely.
    folder = str(fig_folder).strip('/' + os.sep)
    target_dir = os.path.join(os.path.abspath(base_dir), folder)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(target_dir, exist_ok=True)
    for ext in formats:
        fig.savefig(os.path.join(target_dir, f'{fig_name}.{ext}'))
        