from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
import scipy


def _saturating_error(t, a, b, E):
    """Return ``E - a * exp(-b * t)``, the curve fitted to (loss, error) pairs."""
    return E - a * np.exp(-b * np.asarray(t, dtype=float))


def downstream_emperical(
    train_sets=["c4_original", "rpj", "rw_original"],
    model_dir="exp_data/models_tok",
    Ms=[
        1.0,
        2.0,
        4.0,
        8.0,
        16.0,
        32.0,
    ],
    eval_dir="exp_data/evals_tok",
    x_axis=["loss_c4_val", "loss_c4_val", "loss_c4_val"],
    y_axis="err_avg_subset",
    x_labels=["Loss: C4 eval", "Loss: C4 eval", "Loss: C4 eval"],
    y_labels="Average top-1 error: 17-task split",
):
    """Plot downstream error against loss, one panel per training set.

    For panel ``i`` the models trained on ``train_sets[i]`` are loaded with
    ``parse_model_jsons``, scattered as (``x_axis[i]``, ``y_axis``) points and
    fitted with ``E - a * exp(-b * x)``. The fit is drawn twice: solid over the
    observed x range ("Interpolation") and dashed over a fixed, wider range
    ("Extrapolation"). The legend is attached to the last panel only.

    The figure is created through ``plt.subplots`` and is not returned; the
    call also changes the global matplotlib font size to 12 via ``mpl.rc``.

    Args:
        train_sets: Training-set keys, one per panel. Each must be a key of
            ``DATASET_FRIENDLIES``. Needs at least ``len(x_axis)`` entries.
        model_dir: Forwarded to ``parse_model_jsons`` as its first argument.
        Ms: Forwarded to ``parse_model_jsons`` as ``cc_mults``.
        eval_dir: Forwarded to ``parse_model_jsons`` as ``eval_dir``.
        x_axis: Column name used for the x values of each panel. Its length
            sets the number of panels.
        y_axis: Column name used for the y values of every panel.
        x_labels: X-axis label for each panel (at least ``len(x_axis)``).
        y_labels: Y-axis label, shown on the first panel only.

    Note:
        The list defaults are shared between calls; this function does not
        mutate them itself.
    """
    # Imported explicitly: a bare ``import scipy`` does not guarantee that the
    # ``scipy.optimize`` submodule is loaded on every SciPy version.
    from scipy.optimize import curve_fit

    # x range over which the dashed (extrapolated) curve is drawn, per metric.
    extrapolation_ranges = {
        "loss_c4_val": (2.5, 6.0),
        "loss_paloma_c4_en": (2.5, 6.0),
        "loss_openlm": (2, 8),
        "loss_paloma_redpajama": (1.8, 8.2),
        "loss_paloma_falcon-refinedweb": (2.8, 6.8),
    }

    mpl.rc("font", size=12)

    n_panels = len(x_axis)
    # squeeze=False always yields a 2-D array of axes, so a single panel can be
    # iterated exactly like several (the default would return a bare Axes).
    fig, axes_grid = plt.subplots(
        nrows=1,
        ncols=n_panels,
        constrained_layout=True,
        figsize=(10, 4),
        squeeze=False,
    )
    axes = axes_grid[0]
    if n_panels != 2:
        fig.set_size_inches(n_panels * 30 / 7, 4.0)

    for i, ax in enumerate(axes):
        train_dataset = train_sets[i]
        ax.set_title(f"Training set: {DATASET_FRIENDLIES[train_dataset]}")
        ax.invert_xaxis()

        df = parse_model_jsons(
            model_dir,
            cc_mults=Ms,
            datasets=[train_dataset],
            eval_dir=eval_dir,
        )
        # Drop the two named reference models and keep only runs whose
        # tok_mult exceeds 19.
        keep = ~df["model_name"].isin(["open_lm_1b", "open_lm_7b"]) & (df["tok_mult"] > 19.0)
        df = df[keep]

        x = df[x_axis[i]].to_list()
        y = df[y_axis].to_list()

        ax.scatter(x, y, alpha=0.7, zorder=9, color="tab:blue", marker="o")

        (a, b, E), _ = curve_fit(_saturating_error, x, y, maxfev=10000)

        # Dashed curve: fixed range for known metrics; for any other metric
        # fall back to the observed range instead of failing in np.linspace.
        x_lo, x_hi = extrapolation_ranges.get(x_axis[i], (min(x), max(x)))
        extrapolation_x = np.linspace(x_lo, x_hi, num=40)
        ax.plot(
            extrapolation_x,
            _saturating_error(extrapolation_x, a, b, E),
            color="tab:blue",
            linestyle="dashed",
            zorder=8,
        )

        # Solid curve: the same fit restricted to the observed x range.
        interpolation_x = np.linspace(min(x), max(x), num=40)
        ax.plot(
            interpolation_x,
            _saturating_error(interpolation_x, a, b, E),
            color="tab:blue",
            zorder=8,
        )

        ax.margins(y=0.01, x=0.0)
        ax.autoscale()

        ax.set_ylim(bottom=0.66 if y_axis == "err_avg" else 0.5)
        ax.set_xlabel(x_labels[i])
        if i == 0:
            ax.set_ylabel(y_labels)

        ax.grid()

    # The legend lives on the last panel: any handles matplotlib already knows
    # about, followed by manually built proxy entries.
    legend_ax = axes[-1]
    auto_handles, _ = legend_ax.get_legend_handles_labels()
    handles = [
        *auto_handles,
        Line2D(
            [0],
            [0],
            label="Model",
            color="tab:blue",
            marker="o",
            linestyle="",
        ),
        Line2D(
            [0],
            [0],
            label="Interpolation",
            color="tab:blue",
            marker="",
        ),
        Line2D(
            [0],
            [0],
            label="Extrapolation",
            color="tab:blue",
            linestyle="--",
        ),
    ]

    legend_ax.legend(
        handles=handles,
        loc="lower left",
        ncol=1,
    )
