from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.colors import LogNorm, Normalize
from tqdm import tqdm
import random


def _relative_error(actual, predicted):
    """Return the relative error ``|actual - predicted| / actual``."""
    return abs(actual - predicted) / actual


def _fit_and_predict(points, loss_col, target_params, target_tok_mult):
    """Fit the approach-2 power law to ``points`` and extrapolate to the target model.

    Args:
        points: DataFrame with at least ``N``, ``tok_mult`` and ``loss_col`` columns.
        loss_col: name of the validation-loss column that is fitted.
        target_params: parameter count passed to ``powlaw_approach2`` for the target.
        target_tok_mult: token multiplier passed to ``powlaw_approach2`` for the target.

    Returns:
        The predicted target loss as a Python scalar (``.item()`` of the prediction).
    """
    popt = curve_fit_powlaw_approach2(np.array([points["N"], points["tok_mult"]]), points[loss_col])
    pred = powlaw_approach2(np.array([[target_params], [target_tok_mult]]), *popt)
    return pred.item()


def error_vs_count(
    train_val_pair=("rpj", "c4_val"),
    model_dir="exp_data/models_tok",
    Ms=[0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0],
    Ns=["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"],
    target_N="open_lm_1b",
    max_failed_fits=1000,
):
    """Plot how well randomly sub-sampled scaling fits extrapolate to a held-out model.

    For each subset size in ``[5, 28]``, random subsets of the candidate models are
    drawn. Each subset is fitted with ``curve_fit_powlaw_approach2``, and the fit is
    extrapolated to ``NAME_PARAMS[target_N]`` parameters at token multiplier 640.
    The relative error against the held-out model's observed loss is then recorded.

    Two panels are drawn:
    * left: relative error vs. total training FLOPs of the subset (log-log), with a
      ``curve_fit_powlaw`` trend line;
    * right: relative error vs. number of models in the subset (log y), with a trend line.

    A blue star marks the error of a fixed reference fit. That fit uses the models in
    the split keyed ``1.0`` plus the ``d=96_l=8_h=4`` model from each split of the
    ``cc_mult=16.0`` data. It is labelled "Default setting from Table 2".

    Args:
        train_val_pair: ``(train_dataset, val_dataset)``. Losses are read from the
            ``loss_{val_dataset}`` column.
        model_dir: directory passed to ``parse_model_jsons``.
        Ms: ``cc_mults`` passed to ``parse_model_jsons`` for the sampling data.
        Ns: model names passed to ``split_df_by_mult``.
        target_N: key into ``NAME_PARAMS`` giving the parameter count to extrapolate to.
        max_failed_fits: maximum number of failed fits tolerated per subset size before
            giving up (previously a persistent failure looped forever).

    Raises:
        ValueError: if the held-out target model is not found in the parsed data.
        RuntimeError: if more than ``max_failed_fits`` fits fail for one subset size.
    """
    mpl.rc("font", size=13)

    # Fixed seeds so the random subsets (and therefore the figure) are reproducible.
    random.seed(0)
    np.random.seed(0)

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), sharey=True)

    train_dataset, val_dataset = train_val_pair
    # Use the requested validation column everywhere; it used to be hard-coded to
    # "loss_c4_val", which is dropped below for any other val_dataset.
    loss_col = f"loss_{val_dataset}"

    # Row identifying the held-out target model in the parsed data.
    # NOTE(review): presumably NAME_PARAMS[target_N] == target_n_params for the default
    # target_N — confirm, since the prediction uses the former and the lookup the latter.
    target_tok_mult = 640.0
    target_n_params = 1439795200
    # Models with this parameter count are never sampled (reason not visible here).
    excluded_n_params = 6889410560
    target_params = NAME_PARAMS[target_N]

    df = parse_model_jsons(model_dir, cc_mults=Ms, datasets=[train_dataset])
    df = df[["flops", "tok_mult", "N", loss_col, "model_name", "cc_mult"]].reset_index(drop=True)
    is_target = (df["tok_mult"] == target_tok_mult) & (df["N"] == target_n_params)
    sampling_pool = df[~is_target]
    sampling_pool = sampling_pool[sampling_pool["N"] != excluded_n_params]
    target = df[is_target]
    if target.empty:
        raise ValueError(
            f"target model (N={target_n_params}, tok_mult={target_tok_mult}) not found in {model_dir}"
        )
    target_loss = target[loss_col].tolist()[0]

    df2 = parse_model_jsons(
        model_dir,
        cc_mults=[16.0],
        datasets=[train_dataset],
    )
    df_mults2, _ = split_df_by_mult(df2, Ns)

    df_mults, names = split_df_by_mult(df, Ns)
    df_mults_dict = dict(zip(names, df_mults))

    # Reference ("ours") fit: every model in the split keyed 1.0 ...
    base = df_mults_dict[1.0]
    xs_irr = base["flops"].tolist()
    ys_irr = base[loss_col].tolist()
    ms_irr = base["tok_mult"].tolist()
    ns_irr = base["N"].tolist()

    # ... plus exactly one d=96_l=8_h=4 model from each split of the cc_mult=16 data.
    for df_mult in df_mults2:
        tmp = df_mult[df_mult["model_name"] == "d=96_l=8_h=4"]
        assert len(tmp["flops"].tolist()) == 1

        xs_irr.extend(tmp["flops"].tolist())
        ys_irr.extend(tmp[loss_col].tolist())
        ms_irr.extend(tmp["tok_mult"].tolist())
        ns_irr.extend(tmp["N"].tolist())

    popt_ours = curve_fit_powlaw_approach2(np.array([ns_irr, ms_irr]), np.array(ys_irr).astype(float))

    flopss = []
    errs = []
    num_for_estimate = []

    for num_samples in tqdm(np.arange(5, 29)):
        num_ok = 0
        num_failed = 0
        while num_ok < 5:
            for pool_size in [5, 10, 20, 25, 30]:
                if pool_size < num_samples:
                    continue
                try:
                    # Labels are drawn from range(pool_size); labels removed from
                    # sampling_pool are silently skipped by isin(), so tmp may hold
                    # fewer than num_samples rows.
                    idx = np.random.choice(np.arange(0, pool_size), size=num_samples, replace=False)
                    tmp = sampling_pool[sampling_pool.index.isin(idx.tolist())]
                    pred = _fit_and_predict(tmp, loss_col, target_params, target_tok_mult)
                    flops = float(np.sum(tmp["flops"].tolist()))
                    err = _relative_error(target_loss, pred)
                except Exception as e:
                    # Fits can fail on unlucky subsets; retry, but not forever.
                    print(e)
                    num_failed += 1
                    if num_failed > max_failed_fits:
                        raise RuntimeError(
                            f"more than {max_failed_fits} failed fits for num_samples={num_samples}"
                        ) from e
                    continue
                # Append only after every value is computed so the lists stay aligned.
                flopss.append(flops)
                errs.append(err)
                # Record the number of rows actually fitted, matching the full fit below.
                num_for_estimate.append(len(tmp))
                num_ok += 1

    # One extra estimate from all sampling-pool rows with index labels 0..28.
    all_points = sampling_pool[sampling_pool.index.isin(list(range(29)))]
    pred = _fit_and_predict(all_points, loss_col, target_params, target_tok_mult)
    flopss.append(float(np.sum(all_points["flops"].tolist())))
    errs.append(_relative_error(target_loss, pred))
    num_for_estimate.append(len(all_points))

    ylabel = "Relative error: C4 eval" if val_dataset == "c4_val" else f"Relative error: {val_dataset}"

    for i, ax in enumerate(axes):
        ax.set_yscale("log")

        if i == 0:
            ax.set_xscale("log")

        ax.scatter(
            flopss if i == 0 else num_for_estimate,
            errs,
            color="tab:red",
            marker="o",
            s=70,
            alpha=0.3,
            zorder=9,
            label="Individual estimates" if i == 1 else None,
        )

        if i == 0:
            computes = [5e17, 2.5e21]
            popt = curve_fit_powlaw(flopss, errs)
            y = powlaw(np.array(computes).astype(float), *popt)
            ax.plot(computes, y, color="tab:red")
            ax.set_xlabel("Compute [FLOPs] used for the scaling fit")
            ax.set_ylabel(ylabel)
            ax.grid(which="major", ls="-")
        else:
            popt = curve_fit_powlaw(num_for_estimate, errs)
            y = powlaw(np.array([5, 30]).astype(float), *popt)
            # NOTE(review): the trend is evaluated at x=5,30 but drawn at x=4,31.
            ax.plot([4, 31], y, color="tab:red", label="Trend")
            ax.set_xlabel("Number of samples used for the scaling fit")
            ax.grid(which="major", ls="-")

        # NOTE(review): autoscale() below re-enables autoscaling, which overrides this
        # top=1.0 limit; kept as-is so the rendered figure is unchanged.
        ax.set_ylim(top=1.0)
        ax.margins(y=0.1, x=0.0)
        ax.autoscale()

    # Reference ("ours") estimate, drawn as a blue star on both panels.
    ours_pred = powlaw_approach2(np.array([[target_params], [target_tok_mult]]), *popt_ours).item()
    ours_err = _relative_error(target_loss, ours_pred)
    axes[0].scatter(
        [sum(xs_irr)],
        [ours_err],
        color="tab:blue",
        marker="*",
        s=100,
        alpha=0.8,
        zorder=9,
    )
    axes[1].scatter(
        [5],
        [ours_err],
        color="tab:blue",
        marker="*",
        s=100,
        alpha=0.8,
        zorder=9,
        label="Default setting from Table 2",
    )

    axes[1].legend()
    fig.tight_layout()
