import scipy

from overtraining.plotting.shared import *
from overtraining.plotting.constants import *

# On-disk locations handed to ``parse_model_jsons`` further down.
model_dir = "exp_data/models_tok"
eval_dir = "exp_data/evals_tok"

# Pre-training datasets to evaluate. Each entry gets its own subplot column
# and is passed both to ``parse_model_jsons`` and to ``fit_ds``.
datasets = ["c4_original", "rpj", "rw_original"]

# Values forwarded as the third argument of ``fit_ds``.
# NOTE(review): presumably toggles whether the 1B-scale runs take part in the
# fit -- confirm against ``fit_ds``.
with_1b = [True, False]

# Downstream error columns. Each name is passed to ``fit_ds`` and is also used
# as a column key when reading the ground-truth value from the parsed frame.
downstream = [
    "err_arc_easy",
    "err_lambada_openai",
    "err_openbook_qa",
    "err_hellaswag_zeroshot",
    "err_avg_subset",
]
# Alternative selections kept for experimentation:
# downstream += [f"err_{e}" for e in SUBSET]
# downstream = ["err_hellaswag_zeroshot"]

# One panel per pre-training dataset. ``squeeze=False`` makes ``subplots``
# always return a 2-D array, so taking row 0 yields a 1-D array of axes even
# when ``datasets`` is trimmed down to a single entry (without it a lone
# ``Axes`` object is returned and ``axes[ii]`` fails).
fig, axes = plt.subplots(
    nrows=1,
    ncols=len(datasets),
    constrained_layout=True,
    figsize=(10, 4),
    squeeze=False,
)
axes = axes[0]

# ``rows[k]`` holds the formatted relative errors (in percent, one per entry
# of ``downstream``) for the configuration named by ``row_names[k]``.
rows = []
row_names = []
for ii, dataset in enumerate(datasets):
    ax = axes[ii]
    ax.invert_xaxis()
    ax.margins(y=0.0, x=0.0)
    ax.autoscale()

    # Evaluation results restricted to the current dataset.
    # NOTE(review): the meaning of ``cc_mults=[1.0]`` is defined by
    # ``parse_model_jsons`` -- confirm there.
    df = parse_model_jsons(
        model_dir,
        cc_mults=[1.0],
        datasets=[dataset],
        eval_dir=eval_dir,
    )

    # (fit coefficients, relative error) pairs; only consumed by the optional
    # plotting snippet below. Accumulates across both ``with_1b`` settings.
    plot_err_pairs = []

    for w in with_1b:
        row = []
        for ds in downstream:
            # ``fit_ds`` yields the loss-fit parameters and the (al, b, e)
            # coefficients of the loss -> downstream-error mapping.
            popt_approach2, (al, b, e), _ = fit_ds(dataset, ds, w)

            # Predicted loss for the ``open_lm_7b`` entry of ``NAME_PARAMS``.
            # NOTE(review): 20.0 is presumably the token multiplier of that
            # run -- confirm against ``powlaw_approach2``.
            loss_probe = powlaw_approach2(
                np.array([NAME_PARAMS["open_lm_7b"], 20.0]),
                *popt_approach2,
            )
            # Map the predicted loss to a predicted downstream error.
            error_probe = e - al * np.exp(loss_probe) ** (-b)

            # Exactly one measured value is expected for the 7B model. Raise
            # explicitly (instead of ``assert``) so the check survives ``-O``.
            matches = df[df["model_name"] == "open_lm_7b"][ds].tolist()
            if len(matches) != 1:
                raise ValueError(
                    f"expected exactly one 'open_lm_7b' row for dataset "
                    f"{dataset!r} and metric {ds!r}, found {len(matches)}"
                )
            error_gt = matches[0]

            abs_err = abs(error_probe - error_gt)
            plot_err_pairs.append(((al, b, e), abs_err / error_gt))
            row.append(f"{abs_err * 100 / error_gt:.2f}")

        row_names.append(f"with1b={w}_{dataset}")
        rows.append(row)

        print(row_names[-1])
        # Emit a LaTeX table row: cells separated by "\% &", terminated by
        # "\% \\". Backslashes are escaped explicitly; a bare "\%" is an
        # invalid escape sequence and triggers a SyntaxWarning on 3.12+.
        print("\\% &".join(rows[-1]) + "\\% \\\\")

        # Optional: overlay the fitted error curves on the current panel
        # (first fit in red, the remaining ones dashed blue).
        # for f, badness in plot_err_pairs[1:]:
        #     a, b, E = f
        #     xseq = np.linspace(2.5, 5.0, num=40)
        #     axes[ii].plot(
        #         xseq,
        #         E - a * np.exp(xseq) ** (-b),
        #         color="tab:blue",
        #         linestyle="dashed",
        #         zorder=8,
        #     )
        # f, badness = plot_err_pairs[0]
        # a, b, E = f
        # xseq = np.linspace(2.5, 5.0, num=40)
        # axes[ii].plot(
        #     xseq,
        #     E - a * np.exp(xseq) ** (-b),
        #     color="tab:red",
        #     zorder=8,
        # )

        print("----")

# plt.savefig("ind_and_avg.png")
# plt.savefig("ind_and_avg.pdf")
# print(row_names)
# print(rows)
