from overtraining.plotting.shared import *


def _legend_handles(models):
    """Build proxy legend entries for the model markers and the fit line styles.

    Returns one grey, marker-only entry per model (label from
    ``MODEL_FRINDLIES``, marker from ``MODEL_SHAPES``), followed by a solid
    "Interpolation" line and a dashed "Extrapolation" line. These match the
    solid and dashed styles used for the fitted trends in ``emperical``.
    """
    handles = [
        Line2D(
            [0],
            [0],
            label="$N = $" + MODEL_FRINDLIES[ms],
            color="grey",
            marker=MODEL_SHAPES[ms],
            linestyle="",
        )
        for ms in models
    ]
    for label, linestyle in (("Interpolation", "-"), ("Extrapolation", "--")):
        handles.append(
            Line2D([0], [0], label=label, color="grey", marker="", linestyle=linestyle)
        )
    return handles


def emperical(
    model_dir="exp_data/models_tok",
    de_model_dir="exp_data/models_tok_de_en",
    eval_dir="exp_data/evals_tok",
    cc_mults=[0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0],
    models=["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"],
    plot_points=["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"],
    dataset_val_pairs=[
        ("c4_original", "c4_val"),
        ("rpj", "c4_val"),
        ("rw_original", "c4_val"),
    ],
    irreducible_error_estimate_cc=1.0,
    compute_range=[3e15, 3e17, 3e18, 3e19, 3e21],
    fixed_Es=None,
):
    """Plot reducible validation loss against training compute, one panel per training set.

    For each ``(dataset, val_suffix)`` pair this:

    * loads runs via ``parse_model_jsons``;
    * fits an irreducible term ``E`` on the runs at ``irreducible_error_estimate_cc``
      with ``curve_fit_powlaw_irreducible``;
    * scatters ``loss - E`` for the models in ``plot_points``;
    * draws one ``curve_fit_powlaw`` trend per token multiplier, solid over the
      observed runs and dashed over ``compute_range``.

    A shared colorbar (token multiplier) and a legend are added to the figure.

    Args:
        model_dir: Model directory handed to ``parse_model_jsons``.
        de_model_dir: Used instead of ``model_dir`` when ``val_suffix`` contains "de-en".
        eval_dir: Evaluation directory handed to ``parse_model_jsons``.
        cc_mults: Token multipliers. They are forwarded to ``parse_model_jsons``
            and define the (log-scaled) color mapping. The colorbar ticks are
            labelled ``int(m * 20)``.
        models: Model names forwarded to ``split_df_by_mult`` and shown in the legend.
        plot_points: Only models whose name is listed here get scatter points.
        dataset_val_pairs: Non-empty sequence of ``(training dataset, validation suffix)``.
        irreducible_error_estimate_cc: Multiplier whose runs are used to fit ``E``.
        compute_range: FLOP values at which the dashed extrapolations are evaluated.
        fixed_Es: Optional per-panel values forwarded as ``fixed_E`` to
            ``curve_fit_powlaw_irreducible``.

    Raises:
        ValueError: If ``dataset_val_pairs`` is empty.
    """
    if not dataset_val_pairs:
        raise ValueError("dataset_val_pairs must contain at least one (dataset, val_suffix) pair")

    mpl.rcParams["figure.dpi"] = 300
    mpl.rc("font", size=13)

    n_panels = len(dataset_val_pairs)
    val_sufs = [pair[-1] for pair in dataset_val_pairs]
    # When every panel uses the same validation set, only the first panel gets a y label.
    shared_val = all(suf == val_sufs[0] for suf in val_sufs)

    # squeeze=False keeps ``axs`` indexable even when there is a single panel.
    width = 10 if n_panels == 2 else n_panels * 30 / 7
    _, axs = plt.subplots(
        nrows=1,
        ncols=n_panels,
        squeeze=False,
        constrained_layout=True,
        figsize=(width, 4.0),
    )
    axs = axs[0]

    # One color mapping shared by the points, the fit lines and the colorbar,
    # so they cannot disagree. Log scale because the multipliers (the default
    # list doubles at each step) span orders of magnitude.
    cmap = plt.get_cmap("cool")
    scalar_map = cmx.ScalarMappable(
        norm=colors.LogNorm(vmin=min(cc_mults), vmax=max(cc_mults)), cmap=cmap
    )
    x_extrap = np.array(compute_range).astype(float)

    for ax_it, (dataset, val_suffix) in enumerate(dataset_val_pairs):
        fixed_E = None if fixed_Es is None else fixed_Es[ax_it]

        ax = axs[ax_it]
        ax.set_ylim(top=6)

        df = parse_model_jsons(
            model_dir if "de-en" not in val_suffix else de_model_dir,
            cc_mults=cc_mults,
            datasets=[dataset],
            eval_dir=eval_dir,
        )
        dfs_model, model_names = split_df_by_model(df)

        if not shared_val or ax_it == 0:
            ax.set_ylabel(f"Reducible loss: {VAL_FRIENDLIES[val_suffix]}")

        ax.set_xlabel("Compute ($6ND, D=MN$) [FLOPs]")
        ax.set_yscale("log")
        ax.set_yticks([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        ax.get_yaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
        ax.set_xscale("log")
        ax.grid(which="major", ls="-")
        ax.set_title(f"Training set: {DATASET_FRIENDLIES[dataset]}")

        df_mults, names = split_df_by_mult(df, models)
        df_mults_dict = dict(zip(names, df_mults))

        # Fit the irreducible term E on the runs at the reference multiplier.
        df_irr = df_mults_dict[irreducible_error_estimate_cc]
        _, _, E = curve_fit_powlaw_irreducible(
            np.array(df_irr["flops"].tolist()).astype(float),
            np.array(df_irr[f"loss_{val_suffix}"].tolist()).astype(float),
            fixed_E=fixed_E,
        )
        print(f"emperical, irreducible error train on: {dataset} val on: {val_suffix} E: {E}")

        # Scatter the reducible loss of each selected model, colored by multiplier.
        # (Error bars are currently not drawn, so loss_upper/loss_lower are not read.)
        for df_model, model_name in zip(dfs_model, model_names):
            if model_name not in plot_points:
                continue
            flops = df_model["flops"].tolist()
            if not flops:
                continue
            reducible = (df_model[f"loss_{val_suffix}"] - E).tolist()
            mults = df_model["cc_mult"].tolist()
            marker = df_model["shape"].tolist()[0]
            for x, y, mult in zip(flops, reducible, mults):
                color = scalar_map.to_rgba(mult)
                ax.errorbar(
                    x,
                    y,
                    color=color,
                    ecolor=color,
                    marker=marker,
                    alpha=0.8,
                    zorder=9,
                    capsize=4,
                    markersize=8,
                )

        # One power-law trend per multiplier: solid between the first and last
        # observed run, dashed across ``compute_range``.
        drew_fit = False
        for df_mult in df_mults:
            xs = df_mult["flops"].tolist()
            ys = df_mult[f"loss_{val_suffix}"].tolist()
            if len(xs) < 2:
                continue  # too few runs to fit a trend
            popt = curve_fit_powlaw(np.array(xs).astype(float), np.array(ys).astype(float) - E)
            color = scalar_map.to_rgba(df_mult["cc_mult"].tolist()[0])
            x_interp = np.array([xs[0], xs[-1]]).astype(float)
            ax.plot(x_extrap, powlaw(x_extrap, *popt), linestyle="dashed", color=color)
            ax.plot(x_interp, powlaw(x_interp, *popt), color=color)
            drew_fit = True

        if drew_fit:
            # NOTE(review): autoscale() re-enables y autoscaling, which discards
            # the set_ylim(top=6) applied above — confirm that is intended.
            ax.margins(y=0.0, x=0.0)
            ax.autoscale()

    # Colorbar from the same mappable as the points, ticked at the real multipliers.
    cbar = plt.colorbar(scalar_map, ax=axs, aspect=12, pad=0.01)
    cbar.set_label("token multiplier $M$", labelpad=15)
    cbar.set_ticks(cc_mults)
    cbar.set_ticklabels([f"{int(m*20)}" for m in cc_mults])
    cbar.minorticks_off()

    # Legend on the last panel: any auto handles plus the manual proxies.
    legend_ax = axs[-1]
    handles, _ = legend_ax.get_legend_handles_labels()
    handles.extend(_legend_handles(models))
    legend_ax.legend(
        handles=handles,
        loc="upper right",
        bbox_to_anchor=(1.35, 0.97) if n_panels == 3 else (1.3, 1.03),
        ncol=1,
    )
