from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
import scipy


def downstream_corr_all(
    train_sets=["c4_original", "rpj", "rw_original"],
    model_dir="exp_data/models_tok",
    Ms=[
        1.0,
        2.0,
        4.0,
        8.0,
        16.0,
        32.0,
    ],
    eval_dir="exp_data/evals_tok",
):
    """Plot downstream top-1 error against C4 eval loss, one panel per task.

    Builds an 8x6 grid of subplots and fills one panel for every key of
    ``FRIENDLY_CITATIONS`` (in sorted order).  Within a panel, each entry of
    ``train_sets`` contributes a scatter of the ``err_<task>`` column against
    the ``loss_c4_val`` column of the frame returned by ``parse_model_jsons``,
    plus a black horizontal line at ``1 - RANDOM_BASELINE[task]``.  A legend
    with one marker per training set and a "Random chance" entry is added at
    the end.

    Args:
        train_sets: Training-set identifiers; each one is used as a key into
            ``DATASET_COLORS``, ``DATASET_SHAPES`` and ``DATASET_FRIENDLIES``
            and is passed to ``parse_model_jsons`` as ``datasets=[name]``.
        model_dir: Forwarded to ``parse_model_jsons`` as its first argument.
        Ms: Forwarded to ``parse_model_jsons`` as ``cc_mults``
            (presumably token multipliers -- confirm against that helper).
        eval_dir: Forwarded to ``parse_model_jsons`` as ``eval_dir``.

    Returns:
        None.  The figure is created through ``plt.subplots`` and is neither
        saved nor returned here.

    Note:
        ``mpl.rc`` is called, so the font size is changed globally rather
        than for this figure only.  The list defaults are never mutated in
        this function.
    """
    mpl.rc("font", size=12)

    _, axes = plt.subplots(nrows=8, ncols=6, constrained_layout=True, figsize=(27, 35))
    axes = axes.flatten()

    # The grid has 48 cells; the last two are blanked out.
    axes[-1].axis("off")
    axes[-2].axis("off")

    # The parsed results depend only on the training set, not on the task
    # being plotted, so load them once instead of once per panel.
    runs = []
    for train_dataset in train_sets:
        df = parse_model_jsons(
            model_dir,
            cc_mults=Ms,
            datasets=[train_dataset],
            eval_dir=eval_dir,
        )
        runs.append((train_dataset, df, df["loss_c4_val"].to_list()))

    # Stays None when there are no tasks, so the legend code below can fall
    # back to the current axes instead of hitting an unbound name.
    ax = None
    for i, ds in enumerate(sorted(FRIENDLY_CITATIONS)):
        ax = axes[i]
        # Flip the x axis so that loss decreases from left to right.
        ax.invert_xaxis()
        ax.set_ylabel(f"Top-1 error: {FRIENDLY_CITATIONS[ds].split('~')[0]}")
        ax.set_xlabel("Loss: C4 eval")

        chance_error = 1.0 - RANDOM_BASELINE[ds]

        for train_dataset, df, x in runs:
            if not x:
                # Nothing to plot for this training set; min()/max() below
                # would raise on an empty list.
                continue

            y = df[f"err_{ds}"].to_list()

            ax.scatter(
                x,
                y,
                alpha=0.7,
                zorder=9,
                color=DATASET_COLORS[train_dataset],
                marker=DATASET_SHAPES[train_dataset],
            )
            # Random-chance reference, spanning this training set's x range.
            ax.plot(
                [min(x), max(x)],
                [chance_error] * 2,
                color="black",
            )
            ax.autoscale()

        ax.grid()

    source_ax = ax if ax is not None else plt.gca()
    handles, _ = source_ax.get_legend_handles_labels()

    # Hand-built legend entries: one marker per training set ...
    handles.extend(
        Line2D(
            [0],
            [0],
            label=DATASET_FRIENDLIES[d],
            color=DATASET_COLORS[d],
            marker=DATASET_SHAPES[d],
            linestyle="",
        )
        for d in train_sets
    )
    # ... and a plain black line for the random-chance reference.
    handles.append(
        Line2D(
            [0],
            [0],
            label="Random chance",
            color="black",
            marker=None,
        )
    )

    # plt.legend draws on pyplot's current axes.
    # NOTE(review): presumably that is the last (blanked) grid cell created by
    # plt.subplots, which is where the legend is meant to sit -- confirm.
    plt.legend(
        handles=handles,
        loc="lower left",
        ncol=1,
    )
