import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.ticker import LogLocator, FixedLocator
from matplotlib.lines import Line2D
import os
from load_all_result import load_all_result
import math
import numpy as np
from utils import make_title, get_position_dist

def create_epsilon_op_plot(df, dist, cond_and_vis_settings, figsize=(6, 5), ax=None, ylabel=True):
    """
    The function to create epsilon_op plot

    For every entry of ``cond_and_vis_settings``, the rows of ``df`` matching
    ``dist`` and the entry's ``cond["sort"]`` / ``cond["n"]`` are plotted as
    "op" / "n" against "epsilon" (log-scaled x-axis), with a shaded band of
    +/- "op std" / "n". Entries whose ``vis_setting`` has ``hline=True`` are
    instead drawn as a horizontal line.

    Args:
        df: the dataframe of the result (reads the columns "dist", "sort", "n",
            "epsilon", "op" and "op std")
        dist: the distribution name
        cond_and_vis_settings: the list of conditions and visualization settings;
            each item is a dict with a "cond" dict (keys "sort", "n") and a
            "vis_setting" dict (keys "label", "color", "linestyle", and for
            non-hline entries "marker", "markersize", optional "alpha")
        figsize: the size of the figure (only used if ax is not specified)
        ax: the axes object to draw on (if None, a new figure is created)
        ylabel: whether to draw the y-axis label (an empty label otherwise)

    Returns:
        fig, ax: the figure and axes object of matplotlib, or (None, None) if
            ``df`` has no rows for ``dist``
    """
    df_d = df[df["dist"] == dist]
    if df_d.empty:
        print(f"No data for distribution: {dist}")
        return None, None

    # if ax is not specified, a new figure is created
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        new_fig = True
    else:
        fig = ax.figure
        new_fig = False

    # largest plotted value over all conditions; drives the upper y-limit
    max_y = 0

    for cond_and_vis_setting in cond_and_vis_settings:
        cond = cond_and_vis_setting["cond"]
        vis_setting = cond_and_vis_setting["vis_setting"]
        n = cond["n"]
        df_dm = df_d[(df_d["sort"] == cond["sort"]) & (df_d["n"] == n)]
        if df_dm.empty:
            continue
        df_dm = df_dm.sort_values(by="epsilon")

        if vis_setting.get("hline", False):
            # Drawn at the value of the first matching row (smallest epsilon);
            # presumably a baseline whose "op" does not depend on epsilon.
            hline = df_dm["op"].values[0] / n
            max_y = max(max_y, hline)
            ax.axhline(
                y=hline,
                color=vis_setting["color"],
                linestyle=vis_setting["linestyle"],
                linewidth=3,
                label=vis_setting["label"]
            )
            continue

        epsilons = []
        op_ns = []
        op_stds = []
        # unique() keeps the ascending order established by sort_values above
        for epsilon in df_dm["epsilon"].unique():
            df_dm_epsilon = df_dm[df_dm["epsilon"] == epsilon]
            epsilons.append(epsilon)
            op_ns.append(df_dm_epsilon["op"] / df_dm_epsilon["n"])
            op_stds.append(df_dm_epsilon["op std"] / df_dm_epsilon["n"])
        max_y = max(max_y, max(max(op_n) for op_n in op_ns))

        epsilons = np.array(epsilons)
        # NOTE(review): flatten() assumes exactly one row per epsilon; with
        # duplicate rows op_ns would be longer than epsilons - confirm upstream.
        op_ns = np.array(op_ns).flatten()
        op_stds = np.array(op_stds).flatten()

        ax.plot(
            epsilons,
            op_ns,
            label=vis_setting["label"],
            marker=vis_setting["marker"],
            markersize=vis_setting["markersize"],
            linestyle=vis_setting["linestyle"],
            color=vis_setting["color"],
            alpha=vis_setting.get("alpha", 1.0)
        )
        ax.fill_between(
            epsilons,
            (op_ns - op_stds),
            (op_ns + op_stds),
            alpha=0.2,
            color=vis_setting["color"]
        )

    ax.set_xscale("log")
    ax.xaxis.set_major_locator(FixedLocator([1, 10, 100, 1000]))
    ax.xaxis.set_minor_locator(LogLocator(base=10.0, subs=range(2, 10), numticks=100))
    ax.yaxis.set_major_locator(FixedLocator([0, 5, 10, 15, 20, 25]))
    ax.grid(True, which="both", ls="--", c='gray')
    if new_fig:  # if a new figure is created, the title is set
        ax.set_title(make_title(dist), fontsize=32)
    ax.tick_params(axis='both', which='major', labelsize=22)
    ax.set_xlabel(r"$\varepsilon$", fontsize=32)
    if ylabel:
        ax.set_ylabel("#Comparisons / n", fontsize=32)
    else:
        ax.set_ylabel("")
    # Only pin the y-range when something was drawn: set_ylim(0, 0) makes
    # matplotlib warn about a singular range and silently auto-expand it.
    if max_y > 0:
        ax.set_ylim(0, max_y * 1.1)

    return fig, ax

def plot_epsilon_op_per_dist(df, dist, cond_and_vis_settings, fig_path):
    """
    The function to create and save the epsilon_op plot for a single distribution

    Args:
        df: the dataframe of the result
        dist: the distribution name
        cond_and_vis_settings: the list of conditions and visualization settings
        fig_path: output file path; its parent directory is created if needed

    Nothing is written (and no figure is left open) when ``df`` has no rows
    for ``dist``.
    """
    fig, _ = create_epsilon_op_plot(df, dist, cond_and_vis_settings)
    if fig is None:
        return

    fig.tight_layout()
    fig_dir = os.path.dirname(fig_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when fig_path actually has one (not e.g. for a bare "plot.pdf").
    if fig_dir:
        os.makedirs(fig_dir, exist_ok=True)
    fig.savefig(fig_path, bbox_inches='tight', transparent=True)
    print("Saved to ", fig_path)
    plt.close(fig)

def save_legend_only_epsilon_op(cond_and_vis_settings, fig_path, ncol=1):
    """
    The function to save the legend only

    Builds one proxy ``Line2D`` handle per entry of ``cond_and_vis_settings``
    and saves a figure containing nothing but the legend.

    Args:
        cond_and_vis_settings: the list of conditions and visualization settings;
            only each item's "vis_setting" dict is read
        fig_path: output file path; its parent directory is created if needed
        ncol: number of legend columns
    """
    handles = []
    labels = []
    for cond_and_vis_setting in cond_and_vis_settings:
        vis_setting = cond_and_vis_setting["vis_setting"]
        if vis_setting.get("hline", False):
            # same look as the axhline in create_epsilon_op_plot: no marker, width 3
            style = {"linewidth": 3}
        else:
            # same look as the data lines in create_epsilon_op_plot, incl. alpha
            style = {
                "marker": vis_setting["marker"],
                "markersize": vis_setting["markersize"],
                "alpha": vis_setting.get("alpha", 1.0),
            }
        handles.append(
            Line2D(
                [0],
                [0],
                color=vis_setting["color"],
                linestyle=vis_setting["linestyle"],
                label=vis_setting["label"],
                **style,
            )
        )
        labels.append(vis_setting["label"])

    fig = plt.figure(figsize=(10, 2.6))
    fig.legend(
        handles=handles,
        labels=labels,
        loc="center",
        ncol=ncol,
        frameon=False,
        fontsize=18,
        markerscale=1.0,
    )
    fig.tight_layout()
    fig_dir = os.path.dirname(fig_path)
    # os.makedirs("") raises FileNotFoundError for a bare file name
    if fig_dir:
        os.makedirs(fig_dir, exist_ok=True)
    fig.savefig(fig_path, bbox_inches="tight")
    plt.close(fig)

def plot_epsilon_op(df, cond_and_vis_settings, dists, fig_path, nrows, ncols):
    """
    The function to create the epsilon_op plot for multiple distributions

    Distributions are ordered by ``get_position_dist`` and drawn one per cell of
    an ``nrows`` x ``ncols`` grid with a shared y-axis; unused cells are hidden.
    All visible axes get the same y-range (the largest upper limit produced by
    any distribution). The figure is saved to ``fig_path`` and closed.

    Args:
        df: the dataframe of the result
        cond_and_vis_settings: the list of conditions and visualization settings
        dists: the distribution names to plot
        fig_path: output file path; its parent directory is created if needed
        nrows: number of subplot rows
        ncols: number of subplot columns

    Raises:
        ValueError: if there are more distributions than grid cells
    """
    dists = sorted(dists, key=get_position_dist)
    # Validate before creating the figure so no figure is left open on error.
    if len(dists) > nrows * ncols:
        raise ValueError(
            f"{len(dists)} distributions do not fit in a {nrows}x{ncols} grid"
        )

    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows), sharey=True)
    # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray otherwise
    axes = axes.flatten() if isinstance(axes, (list, np.ndarray)) else [axes]
    for ax in axes[len(dists):]:
        ax.set_visible(False)

    max_y = 0

    for dist_i, (dist, ax) in enumerate(zip(dists, axes)):
        # create the epsilon_op plot for each distribution
        _, temp_ax = create_epsilon_op_plot(df, dist, cond_and_vis_settings, ax=ax)
        if temp_ax is None:
            continue

        # With sharey=True each call rescales every axis, so record this
        # distribution's own upper limit right after it has been drawn.
        max_y = max(max_y, temp_ax.get_ylim()[1])

        # the label is only displayed for the left-most column
        # NOTE(review): "\#" renders as a literal backslash unless
        # text.usetex is enabled - confirm rcParams are set elsewhere.
        if dist_i % ncols == 0:
            ax.set_ylabel(r"\# comparisons / $n$", fontsize=44)
        else:
            ax.set_ylabel("")

    # unify the y-axis range of all axes (skip if nothing was drawn, since
    # set_ylim(0, 0) is a singular range that matplotlib warns about)
    if max_y > 0:
        for ax in axes:
            if ax.get_visible():
                ax.set_ylim(0, max_y)

    fig.tight_layout()
    fig_dir = os.path.dirname(fig_path)
    # os.makedirs("") raises FileNotFoundError for a bare file name
    if fig_dir:
        os.makedirs(fig_dir, exist_ok=True)
    fig.savefig(fig_path, bbox_inches='tight', transparent=True)
    print("Saved to ", fig_path)
    plt.close(fig)
