from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
import pandas as pd
import scipy


# Size of the full downstream eval suite; used for the "excluded datasets" axis.
_TOTAL_EVALS = 46

# Model whose accuracies decide which evals clear the inclusion threshold.
_THRESHOLD_MODEL = "d=576_l=24_h=8"


def _error_from_loss(loss, a, b, e):
    """Functional form fitted between loss and downstream error: ``e - a * exp(loss) ** (-b)``."""
    return e - a * np.exp(loss) ** (-b)


def _mean_error(frame, eval_names):
    """Return, for each row of ``frame``, the mean of its ``err_<name>`` columns as a list."""
    per_eval = np.array([frame[f"err_{name}"].tolist() for name in eval_names])
    return (np.sum(per_eval, axis=0) / len(eval_names)).tolist()


def _eval_splits(df, thresholds):
    """Return, for each threshold, the evals that beat random chance by more than it.

    Only rows of ``df`` belonging to ``_THRESHOLD_MODEL`` are considered. An eval
    is kept for a threshold when, for at least one ``tok_mult`` of that model,
    ``(1 - err) - RANDOM_BASELINE[eval]`` exceeds the threshold. Evals without a
    ``RANDOM_BASELINE`` entry are ignored. Names are returned in column order so
    the result is deterministic across runs.
    """
    df = df[df["model_name"] == _THRESHOLD_MODEL]
    kept_cols = [col for col in df.columns if "err_" in col or "tok_mult" in col]
    by_eval = df[kept_cols].set_index("tok_mult").T

    # Margins above random chance are threshold-independent, so compute them once.
    margins = {}
    for index, row in by_eval.iterrows():
        name = index.replace("err_", "")
        if name in RANDOM_BASELINE:
            margins.setdefault(name, []).extend(((1.0 - row) - RANDOM_BASELINE[name]).tolist())

    return [
        [name for name, values in margins.items() if any(v > thresh for v in values)]
        for thresh in thresholds
    ]


def _append_model(frames, model_name, ns, ms, losses, error_curves, splits):
    """Append the single ``model_name`` run of every frame to the fit data, in place.

    ``ns``, ``ms`` and ``losses`` receive the run's ``N``, ``tok_mult`` and
    ``loss_c4_val``; ``error_curves[i]`` receives its mean error over ``splits[i]``.
    """
    for frame in frames:
        run = frame[frame["model_name"] == model_name]
        assert len(run) == 1, f"expected exactly one {model_name} run, found {len(run)}"

        ns.extend(run["N"].tolist())
        ms.extend(run["tok_mult"].tolist())
        losses.extend(run["loss_c4_val"].tolist())
        for curve, split in zip(error_curves, splits):
            curve.extend(_mean_error(run, split))


def non_trivial():
    """Plot 7B error-prediction quality against the eval inclusion threshold.

    For a range of thresholds, the evals on which ``_THRESHOLD_MODEL`` beats
    random chance by more than the threshold are selected. Then, per training
    dataset:

    1. a validation-loss law is fitted (``curve_fit_powlaw_approach2``) on the
       ``fit_models`` runs at ``cc_mult == 1.0`` plus the 16x ``d=96_l=8_h=4`` run;
    2. for every threshold, ``_error_from_loss`` is fitted from the predicted
       losses to the mean error over the selected evals, with the
       ``open_lm_1b`` run added to the fit points;
    3. the mean error of ``open_lm_7b`` is predicted at a token multiplier of
       20.0 and compared with its measured value.

    The resulting relative errors are drawn against the threshold (left axis,
    in percentage points) and against the number of excluded evals (right
    axis). Nothing is returned or saved within this function body.

    Raises:
        ValueError: if some threshold leaves no eval to average over.
    """
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5), constrained_layout=True)
    for ax in axes:
        ax.grid(which="major", ls="-")
        ax.set_yscale("log")
    axes[0].set_ylabel("Relative prediction error")
    axes[0].set_xlabel(
        "Inclusion threshold $t$ (i.e., include evals where any model gets\n"
        "$t$ percentage points above random chance at 0.154B scales)"
    )
    axes[1].set_xlabel(f"Number of excluded datasets (out of {_TOTAL_EVALS}-total)")

    datasets = ["c4_original", "rpj", "rw_original"]
    model_dir = "exp_data/models_tok"
    eval_dir = "exp_data/evals_tok"
    fit_models = ["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"]

    threshold_df = parse_model_jsons(
        model_dir,
        datasets,
        cc_mults=[0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0],
        eval_dir=eval_dir,
    )
    thresholds = np.linspace(-0.05, 0.50, 12).tolist()
    splits = _eval_splits(threshold_df, thresholds)

    # Averaging over zero evals is undefined; fail here with the offending threshold
    # rather than later with an unrelated-looking error.
    for thresh, split in zip(thresholds, splits):
        if not split:
            raise ValueError(f"no evals exceed the inclusion threshold {thresh}; nothing to average over")

    # Both x-axes depend only on the splits, not on the training dataset.
    threshold_pcts = [100 * t for t in thresholds]
    excluded = [_TOTAL_EVALS - len(split) for split in splits]

    for dataset in datasets:
        df = parse_model_jsons(
            model_dir,
            cc_mults=[1.0, 32.0],
            datasets=[dataset],
            eval_dir=eval_dir,
        )

        # The last per-model frame is expected to be the single 7B run used as ground truth.
        dfs_model, _ = split_df_by_model(df)
        sevenb_row = dfs_model[-1]
        assert len(sevenb_row["model_name"].tolist()) == 1, "expected a single row for the 7B model"
        assert sevenb_row["model_name"].tolist()[0] == "open_lm_7b", "last model frame is not open_lm_7b"
        assert sevenb_row["cc_mult"].tolist()[0] == 1.0, "7B run is expected at cc_mult 1.0"

        df_mults, names = split_df_by_mult(df, fit_models)
        base = dict(zip(names, df_mults))[1.0]

        ns = base["N"].tolist()
        ms = base["tok_mult"].tolist()
        losses = base["loss_c4_val"].tolist()
        error_curves = [_mean_error(base, split) for split in splits]
        targets = [_mean_error(sevenb_row, split) for split in splits]

        # Add the 16x run of the smallest model to the fit points.
        df_16x = parse_model_jsons(model_dir, cc_mults=[16.0], datasets=[dataset], eval_dir=eval_dir)
        frames_16x, _ = split_df_by_mult(df_16x, fit_models)
        _append_model(frames_16x, "d=96_l=8_h=4", ns, ms, losses, error_curves, splits)

        # Scaling law for validation loss; it must be fitted BEFORE the 1b run is appended.
        popt_loss = curve_fit_powlaw_approach2(np.array([ns, ms]), np.array(losses).astype(float))

        # The 1b run only takes part in the loss -> error fit below
        # (`losses` is not read again after the fit above).
        df_1b = parse_model_jsons(model_dir, cc_mults=[1.0], datasets=[dataset], eval_dir=eval_dir)
        frames_1b, _ = split_df_by_mult(df_1b, ["open_lm_1b"])
        _append_model(frames_1b, "open_lm_1b", ns, ms, losses, error_curves, splits)

        # These inputs are the same for every threshold, so evaluate them once.
        fitted_losses = powlaw_approach2(np.array([ns, ms]), *popt_loss)
        loss_probe = powlaw_approach2(np.array([NAME_PARAMS["open_lm_7b"], 20.0]), *popt_loss)

        relative_errors = []
        for curve, target in zip(error_curves, targets):
            (a, b, e), _ = scipy.optimize.curve_fit(
                _error_from_loss,
                fitted_losses,
                curve,
                maxfev=10000,
            )
            error_probe = _error_from_loss(loss_probe, a, b, e)
            relative_errors.append(abs(error_probe - target[0]) / target[0])

        axes[0].plot(threshold_pcts, relative_errors, marker="o", zorder=9, color=DATASET_COLORS[dataset])
        axes[1].plot(
            excluded,
            relative_errors,
            marker="o",
            zorder=9,
            color=DATASET_COLORS[dataset],
            label=DATASET_FRIENDLIES[dataset],
        )

    axes[1].legend()
