import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib
# from src.plotting.human_plotting_functions import _save_figure
# Plotting defaults.
# NOTE: the assignment below runs at import time and changes the global
# matplotlib font family for every figure created afterwards in this process.
# Alternative (serif) font choice:
# plt.rcParams["font.family"] = "Times"
plt.rcParams["font.family"] = 'sans-serif'
# Optional: render all text through LaTeX (matplotlib's usetex mode needs a
# working LaTeX installation), using the mathptmx (Times) and amsmath packages.
# plt.rcParams['text.usetex'] = True
# plt.rcParams['text.latex.preamble'] = r'\usepackage{mathptmx} \usepackage{amsmath}'


def plot_single_metric(metric_array:np.ndarray, 
                    se_array:np.ndarray=None,
                    fig_title:str=None,
                    y_label:str="Accuracy",
                    x_label:str="Steps, $t$",
                    fig_folder:str=None,
                    fig_name:str="accuracy_training", 
                    chance_level:float=0.5,
                    log_interval:int = 250,
                    y_lim:tuple=(None, None), 
                    legend_str:str=None,
                    line_width:float=1.5,
                    labelpad:int=1,
                    figsize:tuple=(4.2, 4),
                    title_pad = 6.0,
                    color = 'C0',
                    title_fontsize:int=16):
    """Plot a single metric curve, with an optional error band, on a new figure.

    The curve is drawn against ``x = index * log_interval``. Tick marks are
    hidden (length 0), the grid is drawn behind the data, and tick labels use
    scientific notation outside the 1e-3 .. 1e3 range.

    Args:
        metric_array: Sequence of metric values, one per logged step.
        se_array: Optional sequence matching ``metric_array``; the band
            ``metric_array +/- se_array`` is shaded in ``color``.
        fig_title: Title text; no title is drawn when falsy.
        y_label: Y-axis label text.
        x_label: X-axis label text.
        fig_folder: Currently unused (saving is not implemented here).
        fig_name: Currently unused (saving is not implemented here).
        chance_level: y-value of a dashed red reference line; any falsy
            value (None, 0, False) disables the line.
        log_interval: Number of steps between consecutive entries.
        y_lim: ``(bottom, top)`` y-limits; ``None`` leaves a side automatic.
        legend_str: Legend label; a legend is drawn only when truthy.
        line_width: Width of the metric line.
        labelpad: Spacing between axis labels and the axes.
        figsize: Figure size in inches.
        title_pad: Spacing between the title and the axes.
        color: Color used for both the line and the error band.
        title_fontsize: Font size of the title.

    Returns:
        tuple: ``(fig, ax)`` -- the created Figure and Axes.
    """
    # np.asarray lets callers pass plain lists as well as arrays.
    metric_array = np.asarray(metric_array)
    steps = np.arange(metric_array.shape[0]) * log_interval

    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=170, facecolor='w')

    # Chance level (disabled by any falsy value).
    if chance_level:
        ax.axhline(y=chance_level, color='r', linestyle='--', alpha=0.5)

    ax.plot(steps, metric_array, label=legend_str or None, linewidth=line_width, color=color)
    if se_array is not None:
        ax.fill_between(steps, metric_array - se_array, metric_array + se_array,
                        alpha=0.5, color=color)

    # Aesthetics
    if fig_title:
        ax.set_title(fig_title, fontsize=title_fontsize, pad=title_pad)
    ax.set_xlabel(x_label, fontsize=16, labelpad=labelpad)
    ax.set_ylabel(y_label, fontsize=16, labelpad=labelpad)
    ax.set_ylim(y_lim[0], y_lim[1])
    ax.grid(alpha=0.6)
    ax.set_axisbelow(True)
    ax.margins(x=0.01)
    # Thicker frame around the axes.
    for spine in ax.spines.values():
        spine.set_linewidth(1.3)

    ax.ticklabel_format(style='sci', scilimits=(-3, 3), axis='both')
    # Hide tick marks but keep (enlarged) tick labels.
    ax.tick_params(axis='both', which='both', length=0, labelsize=15)
    # Make the x-axis scientific-notation offset text readable.
    ax.xaxis.get_offset_text().set_fontsize(12)
    if legend_str:
        ax.legend()

    # Lay out last, once every font size is final; running it earlier
    # computes the layout from the smaller default tick labels.
    fig.tight_layout()
    return fig, ax

def plot_single_metric_axis(ax, metric_array, se_array=None, 
                            fig_title=None, y_label="Accuracy", 
                            x_label="Steps, $t$", chance_level=0.5, 
                            log_interval=250, y_lim=(None, None), 
                            legend_str=None, line_width=1.5, 
                            labelpad=1, title_pad=6.0,
                            title_fontsize=13, 
                            x_tick_fontsize=13,
                            y_tick_fontsize=13,
                            tick_style='sci',
                            color=None,
                            tick_labelsize=15):
    """Plot a single metric curve, with an optional error band, on an existing Axes.

    The curve is drawn against ``x = index * log_interval``.

    Args:
        ax: The matplotlib Axes to draw on.
        metric_array: Sequence of metric values, one per logged step.
        se_array: Optional sequence matching ``metric_array``; the band
            ``metric_array +/- se_array`` is shaded in the line's color.
        fig_title: Title text; no title is drawn when falsy.
        y_label: Y-axis label text.
        x_label: X-axis label text.
        chance_level: y-value of a dashed red reference line; any falsy
            value (None, 0, False) disables the line.
        log_interval: Number of steps between consecutive entries.
        y_lim: ``(bottom, top)`` y-limits; ``None`` leaves a side automatic.
        legend_str: Legend label; a legend is drawn only when truthy.
        line_width: Width of the metric line.
        labelpad: Spacing between axis labels and the axes.
        title_pad: Spacing between the title and the axes.
        title_fontsize: Font size of the title.
        x_tick_fontsize: Font size of the x-axis *label* (despite the name).
        y_tick_fontsize: Font size of the y-axis *label* (despite the name).
        tick_style: When ``'sci'``, tick labels use scientific notation
            outside the 1e-3 .. 1e3 range.
        color: Line color; ``None`` uses the Axes' color cycle.
        tick_labelsize: Font size of the tick labels.

    Returns:
        The same ``ax``.
    """
    # np.asarray lets callers pass plain lists as well as arrays.
    metric_array = np.asarray(metric_array)
    steps = np.arange(metric_array.shape[0]) * log_interval

    # Chance level (disabled by any falsy value).
    if chance_level:
        ax.axhline(y=chance_level, color='r', linestyle='--', alpha=0.5)

    lines = ax.plot(steps, metric_array, label=legend_str or None,
                    linewidth=line_width, color=color)
    if se_array is not None:
        # Reuse the drawn line's color so the band always matches its curve;
        # fill_between would otherwise advance an independent color cycle.
        ax.fill_between(steps, metric_array - se_array, metric_array + se_array,
                        alpha=0.5, color=lines[0].get_color())

    # Aesthetics
    if fig_title:
        ax.set_title(fig_title, fontsize=title_fontsize, pad=title_pad)
    ax.set_xlabel(x_label, fontsize=x_tick_fontsize, labelpad=labelpad)
    ax.set_ylabel(y_label, fontsize=y_tick_fontsize, labelpad=labelpad)
    ax.set_ylim(y_lim[0], y_lim[1])
    ax.grid(alpha=0.6)
    ax.set_axisbelow(True)
    ax.margins(x=0.01)
    # Thicker frame around the axes.
    for spine in ax.spines.values():
        spine.set_linewidth(1.3)

    if tick_style == 'sci':
        ax.ticklabel_format(style='sci', scilimits=(-3, 3), axis='both')
    # Hide tick marks but keep tick labels.
    ax.tick_params(axis='both', which='both', length=0, labelsize=tick_labelsize)

    # Make the x-axis scientific-notation offset text readable.
    ax.xaxis.get_offset_text().set_fontsize(12)
    if legend_str:
        ax.legend()

    return ax

def plot_several_metrics(metric_array:np.ndarray,
                    se_array:np.ndarray=None,
                    fig_title:str=None,
                    y_label:str="Accuracy",
                    fig_folder:str=None,
                    fig_name:str=None,
                    chance_level:float=0.5,
                    y_lim:tuple=(None, None),
                    color_list:tuple =None,
                    unifiorm_color:bool=False,
                    label_list:tuple = ("Top level", "Mid level", "Bot level"),
                    leg_loc:tuple = None,
                    theory_baseline:tuple=None,
                    title_margin:float=1.0,
                    log_interval:int = 250,
                    line_width:float=1.5,
                    linestyle:str='solid', 
                    labelpad:int=3, 
                    title_fontsize:int=16,
                    default_grid:bool=False,
                    uniform_color:bool=False,
                    x_label:str="Steps, $t$"):
    """
    Plots multiple metrics (one line per row) with optional standard error bands on a new figure.

    Args:
        metric_array (np.ndarray): 2-D array-like; row ``i`` is plotted as one line
            against ``x = column index * log_interval``.
        se_array (np.ndarray, optional): Standard errors indexed like ``metric_array``;
            row ``i`` is shaded as ``metric_array[i] +/- se_array[i]``. Defaults to None.
        fig_title (str, optional): The title of the figure. Defaults to None.
        y_label (str, optional): The label for the y-axis. Defaults to "Accuracy".
        fig_folder (str, optional): Currently unused (saving is not implemented here).
        fig_name (str, optional): Currently unused (saving is not implemented here).
        chance_level (float, optional): y-value of a dashed red reference line; any falsy
            value disables it. Defaults to 0.5.
        y_lim (tuple, optional): The y-axis limits. Defaults to (None, None).
        color_list (tuple, optional): One color per row. Defaults to None, which samples
            the 'cividis_r' colormap between 0.09 and 0.85.
        unifiorm_color (bool, optional): Misspelled legacy name of ``uniform_color``; kept
            for backward compatibility. Defaults to False.
        label_list (tuple, optional): Legend label per row; rows beyond the list (or all
            rows when falsy) get no label. Defaults to ("Top level", "Mid level", "Bot level").
        leg_loc (tuple, optional): ``loc`` passed to ``ax.legend``; the legend is only
            drawn when this is truthy. Defaults to None.
        theory_baseline (tuple, optional): Currently unused. Defaults to None.
        title_margin (float, optional): Vertical title position (``y`` of ``set_title``).
            Defaults to 1.0.
        log_interval (int, optional): Steps between consecutive columns. Defaults to 250.
        line_width (float, optional): The width of the lines. Defaults to 1.5.
        linestyle (str, optional): Line style for every line. Defaults to 'solid'.
        labelpad (int, optional): Spacing between axis labels and the axes. Defaults to 3.
        title_fontsize (int, optional): Font size of the title. Defaults to 16.
        default_grid (bool, optional): If True, force x ticks every 2500 in [0, 15000]
            and y ticks every 0.1 in [0.4, 1.0]. Defaults to False.
        uniform_color (bool, optional): Draw every line in 'C0', overriding both
            ``color_list`` and the default colormap. Defaults to False.
        x_label (str, optional): The label for the x-axis. Defaults to "Steps, $t$".

    Returns:
        tuple: ``(fig, ax)`` -- the created Figure and Axes.
    """
    # np.asarray lets callers pass nested lists as well as arrays.
    metric_array = np.asarray(metric_array)
    n_levels = metric_array.shape[0]
    n_epochs = metric_array.shape[1]
    steps = np.arange(n_epochs) * log_interval

    # Define figure
    fig, ax = plt.subplots(1, 1, figsize=(4.2, 4), dpi=170, facecolor='w')

    # Chance level (disabled by any falsy value).
    if chance_level:
        ax.axhline(y=chance_level, color='r', linestyle='--', alpha=0.5)

    # Uniform coloring takes precedence over both an explicit color_list and
    # the default colormap, so it also works when no color_list is given.
    if unifiorm_color or uniform_color:
        color_list = ['C0'] * n_levels
    elif color_list is None:
        color_list = matplotlib.colormaps['cividis_r'](np.linspace(.09, .85, n_levels))

    for i in range(n_levels):
        label = label_list[i] if label_list and i < len(label_list) else None
        ax.plot(steps, metric_array[i], color=color_list[i], label=label,
                linewidth=line_width, linestyle=linestyle, zorder=1)
        if se_array is not None:
            ax.fill_between(steps, metric_array[i] - se_array[i], metric_array[i] + se_array[i],
                            color=color_list[i], alpha=0.5, zorder=1)

    # Aesthetics
    ax.set_title(fig_title, y=title_margin, fontsize=title_fontsize)
    ax.set_xlabel(x_label, fontsize=16, labelpad=labelpad)
    ax.set_ylabel(y_label, fontsize=16, labelpad=labelpad)
    ax.set_ylim(y_lim[0], y_lim[1])

    if default_grid:
        ax.set_xticks(np.arange(0, 15001, 2500))
        ax.set_yticks(np.arange(0.4, 1.01, 0.1))
    # Grid drawn behind the data.
    ax.grid(alpha=0.6)
    ax.set_axisbelow(True)
    ax.margins(x=0.01)

    # Thicker frame around the axes.
    for spine in ax.spines.values():
        spine.set_linewidth(1.3)

    ax.ticklabel_format(style='sci', scilimits=(-3, 4), axis='both')
    # Hide tick marks but keep (enlarged) tick labels.
    ax.tick_params(axis='both', which='both', length=0, labelsize=15)
    # Make the x-axis scientific-notation offset text readable.
    ax.xaxis.get_offset_text().set_fontsize(12)

    if leg_loc:
        ax.legend(loc=leg_loc, borderpad=0.5, fontsize=9, frameon=False)

    fig.tight_layout()

    return fig, ax


def plot_several_metrics_axis(ax, 
                            metric_array:np.ndarray,
                            se_array:np.ndarray=None,
                            fig_title:str=None,
                            y_label:str="Accuracy",
                            fig_folder:str=None,
                            fig_name:str=None,
                            chance_level:float=0.5,
                            y_lim:tuple=(None, None),
                            color_list:tuple =None,
                            uniform_color:bool=False,
                            label_list:tuple = ("Top level", "Mid level", "Bot level"),
                            leg_loc:tuple = None,
                            theory_baseline:tuple=None,
                            title_margin:float=1.0,
                            log_interval:int = 250,
                            line_width:float=1.5,
                            linestyle:str='solid', 
                            labelpad:int=3, 
                            title_fontsize:int=13,
                            default_grid:bool=False,
                            x_fontsize=13,
                            y_fontsize=13,
                            x_label:str="Steps, $t$"):
    """Plot multiple metrics (one line per row) with optional error bands on an existing Axes.

    Args:
        ax: The matplotlib Axes to draw on.
        metric_array: 2-D array-like; row ``i`` is plotted as one line against
            ``x = column index * log_interval``.
        se_array: Optional standard errors indexed like ``metric_array``; row ``i``
            is shaded as ``metric_array[i] +/- se_array[i]``.
        fig_title: Title text; no title is drawn when falsy.
        y_label: Y-axis label text.
        fig_folder: Currently unused.
        fig_name: Currently unused.
        chance_level: y-value of a dashed red reference line; any falsy value
            disables it.
        y_lim: ``(bottom, top)`` y-limits.
        color_list: One color per row; ``None`` samples the 'cividis_r'
            colormap between 0.09 and 0.85.
        uniform_color: Draw every line in 'C0', overriding both ``color_list``
            and the default colormap.
        label_list: Legend label per row; rows beyond the list (or all rows
            when falsy) get no label.
        leg_loc: ``loc`` passed to ``ax.legend``; the legend is only drawn
            when this is truthy.
        theory_baseline: Currently unused.
        title_margin: Vertical title position (``y`` of ``set_title``).
        log_interval: Steps between consecutive columns.
        line_width: Width of the lines.
        linestyle: Line style for every line.
        labelpad: Spacing between axis labels and the axes.
        title_fontsize: Font size of the title.
        default_grid: If True, force x ticks every 2500 in [0, 15000] and
            y ticks every 0.1 in [0.4, 1.0].
        x_fontsize: Font size of the x-axis label.
        y_fontsize: Font size of the y-axis label.
        x_label: X-axis label text.

    Returns:
        The same ``ax``.
    """
    # np.asarray lets callers pass nested lists as well as arrays.
    metric_array = np.asarray(metric_array)
    n_levels = metric_array.shape[0]
    n_epochs = metric_array.shape[1]
    steps = np.arange(n_epochs) * log_interval

    # Chance level (disabled by any falsy value).
    if chance_level:
        ax.axhline(y=chance_level, color='r', linestyle='--', alpha=0.5)

    # Uniform coloring takes precedence over both an explicit color_list and
    # the default colormap, so it also works when no color_list is given.
    if uniform_color:
        color_list = ['C0'] * n_levels
    elif color_list is None:
        color_list = matplotlib.colormaps['cividis_r'](np.linspace(0.09, 0.85, n_levels))

    # Plot each metric
    for i in range(n_levels):
        label = label_list[i] if label_list and i < len(label_list) else None
        ax.plot(steps, metric_array[i], color=color_list[i], label=label,
                linewidth=line_width, linestyle=linestyle)
        if se_array is not None:
            ax.fill_between(steps, metric_array[i] - se_array[i], metric_array[i] + se_array[i],
                            color=color_list[i], alpha=0.5)

    # Set titles and labels
    if fig_title:
        ax.set_title(fig_title, y=title_margin, fontsize=title_fontsize)
    ax.set_xlabel(x_label, fontsize=x_fontsize, labelpad=labelpad)
    ax.set_ylabel(y_label, fontsize=y_fontsize, labelpad=labelpad)
    ax.set_ylim(y_lim)

    # Grid (drawn behind the data) and aesthetics
    if default_grid:
        ax.set_xticks(np.arange(0, 15001, 2500))
        ax.set_yticks(np.arange(0.4, 1.01, 0.1))
    ax.grid(alpha=0.6)
    ax.set_axisbelow(True)
    ax.margins(x=0.01)

    # Thicker frame around the axes.
    for spine in ax.spines.values():
        spine.set_linewidth(1.3)

    ax.ticklabel_format(style='sci', scilimits=(-3, 4), axis='both')
    # Hide tick marks but keep (enlarged) tick labels.
    ax.tick_params(axis='both', which='both', length=0, labelsize=15)
    # Make the x-axis scientific-notation offset text readable.
    ax.xaxis.get_offset_text().set_fontsize(12)

    # Legend
    if leg_loc:
        ax.legend(loc=leg_loc, borderpad=0.5, fontsize=9, frameon=True)

    return ax