import matplotlib.pyplot as plt
import os
from load_all_result import load_all_result
from make_title import make_title
import numpy as np
from matplotlib.lines import Line2D

def create_n_time_plot(df, dist, vis_setting, figsize=(6, 5), ax=None, y_lim=None, ylabel=True,
                       x_lim=None):
    """
    Create the time-per-n plot (one point per input size n) for one distribution.

    For every sort listed in ``vis_setting["sorts"]`` that has data for ``dist``,
    the value ``time / n`` is drawn as a line. A shaded band of +/- ``time std / n``
    is drawn around it. The x axis is log-scaled.

    Args:
        df: the dataframe of the result data; the columns "dist", "sort", "n",
            "time" and "time std" are read
        dist: the distribution name (rows are selected where df["dist"] == dist)
        vis_setting: the visualization setting; ``vis_setting["sorts"]`` maps each
            sort name to a dict with the keys "label", "marker", "markersize",
            "linestyle" and "color"
        figsize: the size of the figure (only used if ax is not specified)
        ax: the axes object to draw on (if None, a new figure is created and a
            title is set from ``make_title(dist)``)
        y_lim: (bottom, top) y-axis limits; defaults to (0.0, 1.5e-7). A warning is
            printed when a plotted value falls outside them, but the limits are
            applied regardless.
        ylabel: whether to label the y axis
        x_lim: (left, right) x-axis limits; defaults to (1e4 * 0.7, 1e7 / 0.7)

    Returns:
        fig, ax: the figure and axes object of matplotlib, or (None, None) when
        ``df`` has no rows for ``dist``
    """
    df_d = df[df["dist"] == dist]
    if df_d.empty:
        print(f"No data for distribution: {dist}")
        return None, None

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        new_fig = True
    else:
        fig = ax.figure
        new_fig = False

    y_min = np.inf
    y_max = -np.inf

    for sort, settings in vis_setting["sorts"].items():
        df_dm = df_d[df_d["sort"] == sort]
        if df_dm.empty:
            continue

        df_dm = df_dm.sort_values(by="n")
        sizes = df_dm["n"]
        # Normalize by n in local series instead of writing back into the
        # filtered dataframe (avoids chained-assignment side effects).
        time_per_n = df_dm["time"] / sizes
        time_std = df_dm["time std"] / sizes

        ax.plot(
            sizes,
            time_per_n,
            label=settings["label"],
            marker=settings["marker"],
            markersize=settings["markersize"],
            linestyle=settings["linestyle"],
            color=settings["color"]
        )
        ax.fill_between(
            sizes,
            time_per_n - time_std,
            time_per_n + time_std,
            alpha=0.2,
            color=settings["color"]
        )

        y_min = min(y_min, time_per_n.min())
        y_max = max(y_max, time_per_n.max())

    ax.set_xscale("log")
    if new_fig:
        ax.set_title(make_title(dist), fontsize=32)
    ax.tick_params(axis='both', which='major', labelsize=22)
    ax.set_xlabel("n", fontsize=32)
    if ylabel:
        ax.set_ylabel("Time [s] / n", fontsize=32)
    else:
        ax.set_ylabel("")
    ax.grid(True, which="both", linestyle="--", c="gray")

    if x_lim is None:
        x_lim = (10000 * 0.7, 1e7 / 0.7)
    ax.set_xlim(x_lim)

    if y_lim is None:
        y_lim = (0.0, 1.5e-7)
    # Out-of-range data is only reported; the requested limits are always applied.
    if y_min < y_lim[0]:
        print(f"[Warning] y_min {y_min} < {y_lim[0]} for {dist}")
    if y_max > y_lim[1]:
        print(f"[Warning] y_max {y_max} > {y_lim[1]} for {dist}")
    ax.set_ylim(y_lim)

    return fig, ax


def plot_n_time_per_n(df, dist, vis_setting, fig_path):
    """
    Draw the time-per-n plot for one distribution and save it to ``fig_path``.

    Nothing is written when ``create_n_time_plot`` returns no figure (no rows for
    ``dist``). Missing parent directories of ``fig_path`` are created.

    Args:
        df: the dataframe of the result data
        dist: the distribution name
        vis_setting: the visualization setting passed to ``create_n_time_plot``
        fig_path: output file path for the saved figure
    """
    fig, _ = create_n_time_plot(df, dist, vis_setting)
    if fig is None:
        return

    try:
        fig.tight_layout()
        out_dir = os.path.dirname(fig_path)
        # os.makedirs("") raises, so only create a directory when the path has one.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(fig_path, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails, to avoid accumulating open figures.
        plt.close(fig)


def save_legend_only_n_time(vis_setting, fig_path, ncol=1):
    """
    Save a figure that contains only the legend for the sorts in ``vis_setting``.

    One legend entry is created per item of ``vis_setting["sorts"]``, in dict
    order. Each entry uses that sort's "color", "marker", "linestyle",
    "markersize" and "label" settings. Missing parent directories of
    ``fig_path`` are created.

    Args:
        vis_setting: the visualization setting; ``vis_setting["sorts"]`` maps sort
            names to style dicts
        fig_path: output file path for the saved legend figure
        ncol: number of legend columns
    """
    styles = list(vis_setting["sorts"].values())
    # Zero-length proxy artists: they exist only to give the legend its styling.
    handles = [
        Line2D(
            [0],
            [0],
            color=settings["color"],
            marker=settings["marker"],
            linestyle=settings["linestyle"],
            markersize=settings["markersize"],
            label=settings["label"],
        )
        for settings in styles
    ]
    labels = [settings["label"] for settings in styles]

    fig = plt.figure(figsize=(10, 2.6))
    try:
        fig.legend(
            handles=handles,
            labels=labels,
            loc="center",
            ncol=ncol,
            frameon=False,
            fontsize=18,
            markerscale=1.0,
        )
        fig.tight_layout()
        out_dir = os.path.dirname(fig_path)
        # os.makedirs("") raises, so only create a directory when the path has one.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(fig_path, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
