from overtraining.plotting.shared import *
from overtraining.plotting.constants import DATASET_SHAPES


def _bootstrap_trend_band(log_m, exponents, log_endpoints, n_bootstrap, rng):
    """Return a 95% bootstrap band for a linear trend of exponent vs. log(M).

    The ``(log_m, exponents)`` points are resampled with replacement
    ``n_bootstrap`` times.  A degree-1 polynomial is fitted to every resample
    and its negated value is evaluated at ``log_endpoints`` (negated because
    the caller plots the exponents with flipped sign).

    Args:
        log_m: 1-D array with the natural log of the token multipliers.
        exponents: 1-D array of fitted exponents, same length as ``log_m``.
        log_endpoints: array of log-x positions at which each fit is evaluated.
        n_bootstrap: number of resamples to draw (must be >= 1).
        rng: object exposing ``choice`` -- either the ``np.random`` module or
            an ``np.random.RandomState`` instance.

    Returns:
        ``(lowers, uppers)``: the 2.5th and 97.5th percentiles of the
        resampled trend values, one entry per endpoint.
    """
    n_points = len(log_m)
    samples = []
    for _ in tqdm(range(n_bootstrap)):
        idx = rng.choice(n_points, size=n_points)
        coeffs = np.polyfit(log_m[idx], exponents[idx], 1)
        samples.append(-np.polyval(coeffs, log_endpoints))

    stacked = np.vstack(samples)
    return np.percentile(stacked, 2.5, axis=0), np.percentile(stacked, 97.5, axis=0)


def slopes(
    model_dir="exp_data/models_tok",
    de_model_dir="exp_data/models_tok_de_en",
    eval_dir="exp_data/evals_tok",
    cc_mults=[1.0, 2.0, 4.0, 8.0, 16.0, 32.0],
    models=["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"],
    dataset_val_pairs=[
        ("c4_original", "c4_val"),
        ("rpj", "c4_val"),
        ("rw_original", "c4_val"),
    ],
    irreducible_error_estimate_cc=1.0,
    n_bootstrap=10_000,
    bootstrap_seed=None,
):
    """Plot the fitted scaling exponent against the token multiplier.

    For every ``(dataset, val_suffix)`` pair the function:

    1. loads the runs with ``parse_model_jsons`` and splits them per
       multiplier with ``split_df_by_mult``;
    2. estimates the irreducible error ``E`` by fitting
       ``curve_fit_powlaw_irreducible`` to the ``irreducible_error_estimate_cc``
       split, and prints it;
    3. fits ``curve_fit_powlaw`` to ``loss - E`` vs. flops for each split and
       records the last fitted parameter against ``M = 20 * cc_mult``.

    The recorded points are drawn (with flipped sign) as a scatter plot on a
    log-x axis, together with a linear-in-log(M) trend line and a 95% bootstrap
    band per dataset.

    Side effects: updates global ``mpl.rcParams`` (dpi and font size), creates
    a new figure and prints to stdout.  The figure is neither saved nor shown
    here; that is left to the caller.

    NOTE: the list defaults are shared between calls.  This function only reads
    them and hands them to ``parse_model_jsons`` / ``split_df_by_mult``.

    Args:
        model_dir: directory passed to ``parse_model_jsons`` unless the
            validation suffix contains ``"de-en"``.
        de_model_dir: directory used instead of ``model_dir`` when the
            validation suffix contains ``"de-en"``.
        eval_dir: evaluation directory forwarded to ``parse_model_jsons``.
        cc_mults: multipliers forwarded to ``parse_model_jsons``.
        models: model names forwarded to ``split_df_by_mult``.
        dataset_val_pairs: ``(training dataset, validation suffix)`` pairs; the
            loss is read from the ``loss_{val_suffix}`` column.  Results are
            keyed by dataset, so a dataset listed twice keeps only its last
            pair.
        irreducible_error_estimate_cc: name of the split (as returned by
            ``split_df_by_mult``) used to estimate ``E``.
        n_bootstrap: number of bootstrap resamples for the confidence band.
        bootstrap_seed: seed for a private ``RandomState`` used by the
            bootstrap.  ``None`` (default) draws from NumPy's global random
            state, exactly as before.

    Raises:
        ValueError: if ``n_bootstrap`` is smaller than 1.
        KeyError: if ``irreducible_error_estimate_cc`` is not among the splits
            returned by ``split_df_by_mult``.
    """
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be >= 1, got {n_bootstrap}")

    mpl.rcParams["figure.dpi"] = 300
    mpl.rc("font", size=13)

    _, ax3 = plt.subplots(nrows=1, ncols=1, figsize=(11, 2.5), constrained_layout=True)

    # dataset -> {"x": token multipliers M, "y": fitted exponents}
    d_m_slope_pairs = {}

    for dataset, val_suffix in dataset_val_pairs:
        points = {"x": [], "y": []}
        d_m_slope_pairs[dataset] = points
        loss_col = f"loss_{val_suffix}"

        df = parse_model_jsons(
            model_dir if "de-en" not in val_suffix else de_model_dir,
            cc_mults=cc_mults,
            datasets=[dataset],
            eval_dir=eval_dir,
        )

        df_mults, names = split_df_by_mult(df, models)
        df_mults_dict = dict(zip(names, df_mults))

        if irreducible_error_estimate_cc not in df_mults_dict:
            raise KeyError(
                f"irreducible_error_estimate_cc={irreducible_error_estimate_cc!r} "
                f"not found among splits {list(df_mults_dict)!r} for dataset {dataset!r}"
            )
        df_irr = df_mults_dict[irreducible_error_estimate_cc]
        popt = curve_fit_powlaw_irreducible(
            np.array(df_irr["flops"].tolist()).astype(float),
            np.array(df_irr[loss_col].tolist()).astype(float),
            fixed_E=None,
        )

        _, _, E = popt
        print(f"emperical, irreducible error train on: {dataset} val on: {val_suffix} E: {E}")

        for df_mult in df_mults:
            flops = np.array(df_mult["flops"].tolist()).astype(float)
            reducible_loss = np.array(df_mult[loss_col].tolist()).astype(float) - E

            popt = curve_fit_powlaw(flops, reducible_loss)

            # Token multiplier M is 20x the cc multiplier (presumably the
            # ~20 tokens-per-parameter compute-optimal ratio -- confirm).
            points["x"].append(df_mult["cc_mult"].tolist()[0] * 20)
            points["y"].append(popt[-1])

    # Manual legend: one marker-only entry per dataset plus a generic line.
    handles = [
        Line2D(
            [0],
            [0],
            label=DATASET_FRIENDLIES[e],
            color=DATASET_COLORS[e],
            marker=DATASET_SHAPES[e],
            linestyle="",
        )
        for e in d_m_slope_pairs
    ]
    handles.append(
        Line2D(
            [0],
            [0],
            label="Trend",
            color="grey",
            marker="",
            linestyle="-",
        )
    )

    # Attach the legend to ax3 explicitly instead of relying on pyplot's
    # notion of the current axes.
    ax3.legend(
        handles=handles,
        loc="upper right",
        bbox_to_anchor=(1.22, 0.85),
        ncol=1,
    )

    ax3.set_ylim(bottom=0.11, top=0.15)
    ax3.set_xscale("log")
    # Ticks match the default cc_mults scaled by 20.
    ax3.set_xticks([20, 40, 80, 160, 320, 640])
    ax3.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax3.set_ylabel(r"Scaling exponent $\eta$")
    ax3.set_xlabel(r"Token multiplier $M$")
    ax3.grid(which="major", ls="-", axis="y")

    # With no seed, keep drawing from NumPy's global state (original behaviour).
    rng = np.random if bootstrap_seed is None else np.random.RandomState(bootstrap_seed)
    band_x = [20, 640]
    log_band_x = np.log([20.0, 640.0])

    # Bands are drawn for all datasets first, then the points and trend lines,
    # so that the artist order of the figure is unchanged.
    for e, points in d_m_slope_pairs.items():
        lowers, uppers = _bootstrap_trend_band(
            np.log(points["x"]),
            np.array(points["y"]),
            log_band_x,
            n_bootstrap,
            rng,
        )

        ax3.plot(band_x, lowers, linestyle="dashed", color=DATASET_COLORS[e], alpha=0.4)
        ax3.plot(band_x, uppers, linestyle="dashed", color=DATASET_COLORS[e], alpha=0.4)
        ax3.fill_between(band_x, lowers, uppers, color=DATASET_COLORS[e], alpha=0.1)

    for e, points in d_m_slope_pairs.items():
        log_m = np.log(points["x"])

        # Exponents are plotted with flipped sign.
        ax3.scatter(
            points["x"],
            [-f for f in points["y"]],
            color=DATASET_COLORS[e],
            marker=DATASET_SHAPES[e],
        )

        # Least-squares trend of the exponent against log(M).
        trend = np.polyfit(log_m, points["y"], 1)
        ax3.plot(points["x"], -np.polyval(trend, log_m), color=DATASET_COLORS[e])
