from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.colors import LogNorm, Normalize
from tqdm import tqdm
import random
import scipy


def _relative_error(actual, predicted):
    """Return ``abs(actual - predicted) / actual``."""
    return abs(actual - predicted) / actual


def _loss_to_downstream(t, a, b, e):
    """Map a loss value ``t`` to a downstream metric via ``e - a * exp(t) ** (-b)``."""
    return e - a * np.exp(t) ** (-b)


def _fit_loss_to_downstream(losses, downstream_values):
    """Fit ``_loss_to_downstream`` to (loss, downstream) pairs and return ``(a, b, e)``."""
    # Import the submodule explicitly: a bare `import scipy` does not load
    # `scipy.optimize` on older SciPy releases.
    from scipy.optimize import curve_fit

    params, _ = curve_fit(_loss_to_downstream, losses, downstream_values, maxfev=10000)
    return params


def _estimate_from_subset(subset, loss_col, downstream, target_point):
    """Fit both scaling stages on ``subset`` and predict the target.

    Returns ``(total_flops, predicted_loss, predicted_downstream)`` as floats.
    Any fitting failure propagates to the caller.
    """
    popt = curve_fit_powlaw_approach2(np.array([subset["N"], subset["tok_mult"]]), subset[loss_col])
    loss_pred = powlaw_approach2(target_point, *popt)
    ds_params = _fit_loss_to_downstream(subset[loss_col], subset[downstream])
    ds_pred = _loss_to_downstream(loss_pred, *ds_params)
    flops = float(np.sum(subset["flops"].tolist()))
    return flops, loss_pred.item(), ds_pred.item()


def error_vs(
    train_val_pair=("rpj", "c4_val"),
    model_dir="exp_data/models_tok",
    eval_dir="exp_data/evals_tok",
    Ms=[0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0],
    Ns=["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"],
    target_N="open_lm_1b",
    target_M=32.0,
    downstream="err_avg_subset",
):
    """Plot prediction error for a target run against the compute spent on the scaling fit.

    Draws a two-panel figure on the current pyplot state. The left panel shows
    the relative error of the predicted validation loss; the right panel shows
    the relative error of the predicted ``downstream`` metric. Both are plotted
    against the summed FLOPs of the runs used for the fit.

    * Blue star: the "default" fit. It uses every run at multiplier 1.0 plus
      the ``d=96_l=8_h=4`` run at multiplier 16.0. The downstream mapping
      additionally uses the ``open_lm_1b`` run at multiplier 1.0.
    * Red dots: fits on random subsets of the sampling pool, plus one fit on
      pool rows with index < 29, with a power-law trend through them.

    Relative error is ``abs(actual - predicted) / actual``.

    Args:
        train_val_pair: ``(train_dataset, val_dataset)``. The loss column read
            is ``f"loss_{val_dataset}"``.
        model_dir: Directory passed to ``parse_model_jsons``.
        eval_dir: Eval directory passed to ``parse_model_jsons``.
        Ms: Multipliers to load. Must include ``1.0`` (used by the default fit).
        Ns: Model names passed to ``split_df_by_mult``.
        target_N: ``model_name`` of the run to predict (also a ``NAME_PARAMS`` key).
        target_M: Multiplier of the target run. It is matched as ``tok_mult == 20 * target_M``.
        downstream: Column name of the downstream metric.

    Raises:
        ValueError: If no run matches ``target_N`` / ``target_M``.
    """
    mpl.rc("font", size=13)

    random.seed(0)
    np.random.seed(0)

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), sharey=True, sharex=True)

    train_dataset, val_dataset = train_val_pair
    loss_col = f"loss_{val_dataset}"

    df = parse_model_jsons(model_dir, cc_mults=Ms, datasets=[train_dataset], eval_dir=eval_dir)
    df = df[["flops", "tok_mult", "N", loss_col, "model_name", "cc_mult", downstream]].reset_index(drop=True)

    # Rows excluded from random subsampling. NOTE(review): presumably the
    # N=1439795200 / tok_mult=640 row is the target run itself (640 == 20 * 32)
    # and N=6889410560 is a larger held-out model. Confirm.
    sampling_pool = df[(df["tok_mult"] != 640.0) | (df["N"] != 1439795200)]
    sampling_pool = sampling_pool[sampling_pool["N"] != 6889410560]

    target = df[(df["model_name"] == target_N) & (df["tok_mult"] == 20 * target_M)]
    if target.empty:
        raise ValueError(f"no run found for target_N={target_N!r} with tok_mult={20 * target_M}")
    target_loss = target[loss_col].tolist()[0]
    target_ds = target[downstream].tolist()[0]
    target_point = np.array([[NAME_PARAMS[target_N]], [target_M * 20]])

    # ---- Default-setting fit (blue stars) ----
    df2 = parse_model_jsons(model_dir, cc_mults=[16.0], datasets=[train_dataset], eval_dir=eval_dir)
    df_mults2, _ = split_df_by_mult(df2, Ns)

    df_mults, names = split_df_by_mult(df, Ns)
    mult_1x = dict(zip(names, df_mults))[1.0]

    xs_irr = mult_1x["flops"].tolist()
    ys_irr = mult_1x[loss_col].tolist()
    ms_irr = mult_1x["tok_mult"].tolist()
    ns_irr = mult_1x["N"].tolist()
    ys2_irr = mult_1x[downstream].tolist()

    def _extend_irr(rows):
        """Append exactly one run from ``rows`` to the default-setting fit data."""
        assert len(rows) == 1, f"expected exactly one matching run, got {len(rows)}"
        xs_irr.extend(rows["flops"].tolist())
        ys_irr.extend(rows[loss_col].tolist())
        ms_irr.extend(rows["tok_mult"].tolist())
        ns_irr.extend(rows["N"].tolist())
        ys2_irr.extend(rows[downstream].tolist())

    for df_mult in df_mults2:
        _extend_irr(df_mult[df_mult["model_name"] == "d=96_l=8_h=4"])

    popt_ours = curve_fit_powlaw_approach2(np.array([ns_irr, ms_irr]), np.array(ys_irr).astype(float))
    pred = powlaw_approach2(target_point, *popt_ours)

    # The loss panel's compute excludes the open_lm_1b run added below.
    axes[0].scatter(
        [sum(xs_irr)],
        [_relative_error(target_loss, pred.item())],
        color="tab:blue",
        marker="*",
        s=100,
        alpha=0.8,
        zorder=9,
    )

    # The downstream mapping additionally uses the open_lm_1b run at 1x.
    df_double = parse_model_jsons(model_dir, cc_mults=[1.0], datasets=[train_dataset], eval_dir=eval_dir)
    df_mults_double, _ = split_df_by_mult(df_double, ["open_lm_1b"])
    for df_mult in df_mults_double:
        _extend_irr(df_mult[df_mult["model_name"] == "open_lm_1b"])

    # Fit downstream against *predicted* loss, then evaluate at the predicted target loss.
    ys_irr_approx = powlaw_approach2(np.array([ns_irr, ms_irr]), *popt_ours)
    ds_params = _fit_loss_to_downstream(ys_irr_approx, ys2_irr)
    pred_ds = _loss_to_downstream(pred, *ds_params)

    axes[1].scatter(
        [sum(xs_irr)],
        [_relative_error(target_ds, pred_ds.item())],
        color="tab:blue",
        marker="*",
        s=100,
        alpha=0.8,
        zorder=9,
        label="Default setting from Table 2",
    )

    # ---- Random-subset fits (red dots) ----
    flopss, errs, errs_ds = [], [], []

    def _record(subset):
        """Fit on ``subset`` and append one aligned (flops, err, err_ds) point.

        Nothing is appended unless every step succeeds, so the three lists
        always stay the same length.
        """
        flops, loss_pred, ds_pred = _estimate_from_subset(subset, loss_col, downstream, target_point)
        err = _relative_error(target_loss, loss_pred)
        err_ds = _relative_error(target_ds, ds_pred)
        flopss.append(flops)
        errs.append(err)
        errs_ds.append(err_ds)

    # Upper bound on passes over the pool sizes, so a persistently failing fit cannot hang.
    max_rounds = 100
    for num_samples in tqdm(np.arange(5, 29)):
        successes = 0
        rounds = 0
        while successes < 5 and rounds < max_rounds:
            rounds += 1
            for pool_size in (5, 10, 20, 25, 30):
                if pool_size < num_samples:
                    continue
                try:
                    # Draw index labels from [0, pool_size). Labels that are not in
                    # sampling_pool are simply skipped by `isin`.
                    idx = np.random.choice(np.arange(0, pool_size), size=num_samples, replace=False)
                    _record(sampling_pool[sampling_pool.index.isin(idx.tolist())])
                except Exception as exc:  # best-effort: a failed fit just drops this sample
                    print(exc)
                    continue
                successes += 1
        if successes < 5:
            print(f"num_samples={num_samples}: only {successes} successful fits after {rounds} rounds")

    # One additional estimate from all pool rows with index < 29 (not guarded: failures propagate).
    _record(sampling_pool[sampling_pool.index.isin(np.arange(0, 29).tolist())])

    # ---- Per-panel plotting ----
    computes = np.array([5e17, 2.5e21], dtype=float)
    panels = (
        (errs, "Relative error: C4 eval", None, None),
        (errs_ds, "Relative error: 17-task split", "Individual estimates", "Trend"),
    )
    for ax, (panel_errs, ylabel, scatter_label, trend_label) in zip(axes, panels):
        ax.set_yscale("log")
        ax.set_xscale("log")

        ax.scatter(
            flopss,
            panel_errs,
            color="tab:red",
            marker="o",
            s=70,
            alpha=0.3,
            zorder=9,
            label=scatter_label,
        )

        popt = curve_fit_powlaw(flopss, panel_errs)
        ax.plot(computes, powlaw(computes, *popt), color="tab:red", label=trend_label)
        ax.set_xlabel("Compute [FLOPs] used for the scaling fit")
        ax.set_ylabel(ylabel)
        ax.grid(which="major", ls="-")

        # NOTE(review): autoscale() re-enables autoscaling and overrides the
        # set_ylim(top=1.0) cap. This is kept as-is so the figure does not change;
        # confirm whether the cap was intended.
        ax.set_ylim(top=1.0)
        ax.margins(y=0.1, x=0.0)
        ax.autoscale()

    axes[-1].legend().set_zorder(102)
    fig.tight_layout()
