from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
import scipy


# Loss window (low, high) over which the fitted curve is drawn as a dashed
# "extrapolation" line, keyed by x-axis metric. Metrics that are not listed
# here get no dashed line.
_EXTRAPOLATION_RANGES = {
    "loss_c4_val": (2.5, 6.0),
    "loss_paloma_c4_en": (2.5, 6.0),
    "loss_openlm": (2, 8),
    "loss_paloma_redpajama": (1.8, 8.2),
    "loss_paloma_falcon-refinedweb": (2.8, 6.8),
}

# Rows whose "model_name" is one of these are dropped before plotting.
_EXCLUDED_MODELS = ["open_lm_1b", "open_lm_7b"]

# The fitted curve has three free parameters (a, b, E), so a fit needs at
# least this many points.
_MIN_FIT_POINTS = 3


def _error_from_loss(t, a, b, E):
    """Return ``E - a * exp(-b * t)``, the curve fitted to (loss, error) points."""
    return E - a * np.exp(-b * t)


def downstream_corr(
    train_sets=["c4_original", "rpj", "rw_original"],
    model_dir="exp_data/models_tok",
    Ms=[
        1.0,
        2.0,
        4.0,
        8.0,
        16.0,
        32.0,
    ],
    eval_dir="exp_data/evals_tok",
    x_axis=["loss_c4_val", "loss_c4_val"],
    y_axis=["err_avg", "err_avg_subset"],
    x_labels=["Loss: C4 eval", "Loss: C4 eval"],
    y_labels=["Average top-1 error: 46-task split", "Average top-1 error: 17-task split"],
):
    """Plot downstream error against loss, with a fitted curve per training set.

    One panel is drawn per entry of ``x_axis``. In each panel, every training
    set contributes a scatter of its (loss, error) points, a solid fitted
    curve ``E - a * exp(-b * loss)`` over the observed loss range, and a
    dashed copy of the same curve over the window configured for that metric
    in ``_EXTRAPOLATION_RANGES`` (omitted when the metric has no window).

    The figure is created with ``plt.subplots`` and left as the current
    figure; it is neither saved nor shown here, and nothing is returned.

    Args:
        train_sets: Training dataset names; each must be a key of
            ``DATASET_COLORS``, ``DATASET_SHAPES`` and ``DATASET_FRIENDLIES``.
        model_dir: Forwarded to ``parse_model_jsons`` as its first argument.
        Ms: Forwarded to ``parse_model_jsons`` as ``cc_mults``.
        eval_dir: Forwarded to ``parse_model_jsons`` as ``eval_dir``.
        x_axis: Column name used for the x values of each panel.
        y_axis: Column name used for the y values of each panel.
        x_labels: X-axis label of each panel.
        y_labels: Y-axis label of each panel. When all labels are equal, only
            the first panel is labelled.

    Raises:
        ValueError: If ``x_axis`` is empty, or if ``y_axis``, ``x_labels`` or
            ``y_labels`` has fewer entries than ``x_axis``.
    """
    import warnings

    # Imported explicitly: a bare ``import scipy`` is not guaranteed to make
    # the ``scipy.optimize`` submodule available on every SciPy version.
    from scipy.optimize import curve_fit

    n_panels = len(x_axis)
    if n_panels == 0:
        raise ValueError("x_axis must name at least one metric to plot")
    if min(len(y_axis), len(x_labels), len(y_labels)) < n_panels:
        raise ValueError("y_axis, x_labels and y_labels need at least one entry per x_axis entry")

    mpl.rc("font", size=12)

    # squeeze=False keeps ``axes`` a 2-D array even for a single panel, so the
    # loop below works for any panel count.
    fig, axes = plt.subplots(
        nrows=1,
        ncols=n_panels,
        constrained_layout=True,
        figsize=(10, 4),
        squeeze=False,
    )
    if n_panels != 2:
        fig.set_size_inches(n_panels * 30 / 7, 4.0)
    axes = axes.ravel()

    # Load and filter each training set once instead of once per panel.
    # NOTE(review): assumes parse_model_jsons returns the same frame for the
    # same arguments - confirm.
    frames = []
    for train_dataset in train_sets:
        df = parse_model_jsons(
            model_dir,
            cc_mults=Ms,
            datasets=[train_dataset],
            eval_dir=eval_dir,
        )
        df = df[~df["model_name"].isin(_EXCLUDED_MODELS)]
        df = df[df["tok_mult"] > 19.0]
        frames.append((train_dataset, df))

    # With identical labels only the left-most panel carries the y label.
    shared_ylabel = all(label == y_labels[0] for label in y_labels)

    for i, ax in enumerate(axes):
        ax.invert_xaxis()
        extrapolation_range = _EXTRAPOLATION_RANGES.get(x_axis[i])

        for train_dataset, df in frames:
            x = df[x_axis[i]].to_list()
            y = df[y_axis[i]].to_list()
            color = DATASET_COLORS[train_dataset]

            ax.scatter(
                x,
                y,
                alpha=0.7,
                zorder=9,
                color=color,
                marker=DATASET_SHAPES[train_dataset],
            )

            if len(x) < _MIN_FIT_POINTS:
                warnings.warn(
                    f"Skipping curve fit for {train_dataset!r} on {x_axis[i]!r}: "
                    f"only {len(x)} point(s) available",
                    stacklevel=2,
                )
                continue

            (a, b, E), _ = curve_fit(_error_from_loss, x, y, maxfev=10000)

            # Dashed: fitted curve over the configured window for this metric.
            if extrapolation_range is not None:
                xseq = np.linspace(extrapolation_range[0], extrapolation_range[1], num=40)
                ax.plot(
                    xseq,
                    _error_from_loss(xseq, a, b, E),
                    color=color,
                    linestyle="dashed",
                    zorder=8,
                )

            # Solid: fitted curve over the observed loss range only.
            xseq = np.linspace(min(x), max(x), num=40)
            ax.plot(
                xseq,
                _error_from_loss(xseq, a, b, E),
                color=color,
                zorder=8,
            )

        # Set margins and rescale once, after every artist is on the axes.
        ax.margins(y=0.01, x=0.0)
        ax.autoscale()

        ax.set_ylim(bottom=0.66 if y_axis[i] == "err_avg" else 0.45)
        ax.set_xlabel(x_labels[i])
        if not shared_ylabel or i == 0:
            ax.set_ylabel(y_labels[i])
        ax.grid()

    # The legend goes on the last panel: any automatically collected handles
    # first, then hand-built entries for the datasets and the two line styles.
    legend_ax = axes[-1]
    auto_handles, _ = legend_ax.get_legend_handles_labels()
    handles = list(auto_handles)

    handles.extend(
        Line2D(
            [0],
            [0],
            label=DATASET_FRIENDLIES[d],
            color=DATASET_COLORS[d],
            marker=DATASET_SHAPES[d],
            linestyle="",
        )
        for d in train_sets
    )
    handles.append(
        Line2D(
            [0],
            [0],
            label="Interpolation",
            color="grey",
            marker="",
        )
    )
    handles.append(
        Line2D(
            [0],
            [0],
            label="Extrapolation",
            color="grey",
            linestyle="--",
        )
    )

    legend_ax.legend(
        handles=handles,
        loc="lower left",
        ncol=1,
    )
