from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.colors import LogNorm, Normalize
import matplotlib.gridspec as gridspec
from matplotlib.ticker import FuncFormatter


def prediction_error(
    train_val_pairs=(("c4_original", "c4_val"), ("rpj", "c4_val"), ("rw_original", "c4_val")),
    model_dir="exp_data/models_tok",
    de_model_dir="exp_data/models_tok_de_en",
    eval_dir="exp_data/eval_tok",
    small_mult=16.0,
    fit_models=("d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"),
    Ms=(0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0),
    Ns=("d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8", "open_lm_1b", "open_lm_7b"),
    irreducible_error_estimate_cc=1.0,
    fixed_Es=None,
    small_model="d=96_l=8_h=4",
):
    """Plot heatmaps of the relative error of scaling-law loss predictions.

    For every ``(train_dataset, val_dataset)`` pair this does three things:

    - loads the runs trained on ``train_dataset`` with ``parse_model_jsons``;
    - fits ``powlaw_approach2`` (via ``curve_fit_powlaw_approach2``) to the
      ``fit_models`` runs at multiplier ``irreducible_error_estimate_cc``, plus
      the ``small_model`` run at ``small_mult``;
    - draws one annotated heatmap of ``|gt - pred| / gt``, with one row per
      model in ``Ns`` and one column per multiplier in ``Ms``.

    Cells whose runs feed the fit are outlined in yellow. Cells without exactly
    one ground-truth run are NaN and show the ``"linen"`` face colour. A single
    shared colorbar, with its colour scale spanning 0-10%, is drawn in an extra
    thin column on the right.

    Side effects: sets ``mpl.rcParams["figure.dpi"] = 300`` and the global font
    size to 22. The new figure is left as pyplot's current figure; nothing is
    returned or saved.

    Args:
        train_val_pairs: ``(train_dataset, val_dataset)`` name pairs, one panel
            each. Losses are read from the ``f"loss_{val_dataset}"`` column.
        model_dir: Model directory passed to ``parse_model_jsons``.
        de_model_dir: Used instead of ``model_dir`` when ``"de-en"`` occurs in
            ``val_dataset``.
        eval_dir: Evaluation directory passed to ``parse_model_jsons``.
        small_mult: Multiplier at which the ``small_model`` run is added to the
            fit data.
        fit_models: Model names passed to ``split_df_by_mult`` to select the
            runs used for the fit.
        Ms: Multipliers (``cc_mult`` values) shown as heatmap columns. Tick
            labels show ``int(M * 20)``.
        Ns: Model names shown as heatmap rows. Each must be a key of
            ``NAME_PARAMS`` and ``MODEL_FRINDLIES``.
        irreducible_error_estimate_cc: Multiplier whose ``fit_models`` runs are
            used for the fit.
        fixed_Es: Optional sequence with one value per pair, forwarded as
            ``fixed_E`` to ``curve_fit_powlaw_approach2``. Presumably this pins
            the irreducible-loss term; confirm in that helper.
        small_model: Model whose run at ``small_mult`` is added to the fit data.

    Raises:
        ValueError: If ``train_val_pairs`` is empty.
    """
    # Fresh list copies: the defaults are immutable tuples, and the project
    # helpers below receive private lists they may safely mutate.
    train_val_pairs = list(train_val_pairs)
    fit_models = list(fit_models)
    Ms = list(Ms)
    Ns = list(Ns)
    if not train_val_pairs:
        raise ValueError("train_val_pairs must not be empty")
    n_panels = len(train_val_pairs)

    mpl.rcParams["figure.dpi"] = 300
    mpl.rc("font", size=22)

    fig, axes = plt.subplots(
        nrows=1,
        ncols=n_panels + 1,
        constrained_layout=True,
        # One equal-width column per panel plus a thin column for the shared colorbar.
        gridspec_kw={"width_ratios": [1] * n_panels + [0.03]},
    )
    cbar_ax = axes[-1]  # the placement of the colorbar
    percent = FuncFormatter(lambda x, _pos: "{:.1%}".format(x))

    # Heatmap cells whose runs feed the fit, as (column, rows) pairs.
    # These are the same for every panel.
    fit_cells = []
    if irreducible_error_estimate_cc in Ms:
        fit_rows = [Ns.index(m) for m in fit_models if m in Ns]
        fit_cells.append((Ms.index(irreducible_error_estimate_cc), fit_rows))
    if small_mult in Ms and small_model in Ns:
        fit_cells.append((Ms.index(small_mult), [Ns.index(small_model)]))

    for i, (train_dataset, val_dataset) in enumerate(train_val_pairs):
        ax = axes[i]
        is_first = i == 0
        is_last = i == n_panels - 1
        fixed_E = None if fixed_Es is None else fixed_Es[i]

        df = parse_model_jsons(
            de_model_dir if "de-en" in val_dataset else model_dir,
            cc_mults=Ms,
            datasets=[train_dataset],
            eval_dir=eval_dir,
        )
        popt = _fit_scaling_law(
            df,
            val_dataset,
            fit_models,
            irreducible_error_estimate_cc,
            small_mult,
            small_model,
            fixed_E,
        )
        errors = _relative_error_grid(df, popt, Ns, Ms, val_dataset)

        hm = sns.heatmap(
            errors,
            ax=ax,
            annot=True,
            vmin=0.0,
            vmax=0.1,  # colour scale tops out at 10% relative error
            cmap="coolwarm",
            xticklabels=[int(m * 20) for m in Ms],
            # Only the left-most panel carries model labels.
            yticklabels=[MODEL_FRINDLIES[n] for n in Ns] if is_first else [],
            fmt=".1%",
            cbar=is_last,  # draw the shared colorbar once, from the last panel
            cbar_ax=cbar_ax,
            cbar_kws={"format": percent},
        )
        ax.set_ylabel("$N$" if is_first else "")
        hm.set_facecolor("linen")  # visible behind NaN (missing ground-truth) cells
        hm.set_yticklabels(labels=hm.get_yticklabels(), va="center")

        ax.set_title(f"Train: {DATASET_FRIENDLIES[train_dataset]}\nEval: {VAL_FRIENDLIES[val_dataset]}")
        ax.set_xlabel("$M$")
        for col, rows in fit_cells:
            _highlight_cells(ax, col, rows)

    cbar_ax.set_ylabel("Relative error")

    # 7.5 inches per panel (22.5 for the default three panels).
    fig.set_size_inches(7.5 * n_panels, 7.0)


def _fit_scaling_law(df, val_dataset, fit_models, irr_mult, small_mult, small_model, fixed_E):
    """Fit ``powlaw_approach2`` and return the fitted parameters.

    The fit data is the ``fit_models`` runs at multiplier ``irr_mult`` plus the
    ``small_model`` run at ``small_mult``. Inputs are the ``N`` and
    ``tok_mult`` columns; targets are ``f"loss_{val_dataset}"``.
    """
    loss_col = f"loss_{val_dataset}"
    df_mults, names = split_df_by_mult(df, fit_models)
    by_mult = dict(zip(names, df_mults))

    base = by_mult[irr_mult]
    small = by_mult[small_mult]
    small = small[small["model_name"] == small_model]

    ns = base["N"].tolist() + small["N"].tolist()
    ms = base["tok_mult"].tolist() + small["tok_mult"].tolist()
    ys = base[loss_col].tolist() + small[loss_col].tolist()

    return curve_fit_powlaw_approach2(np.array([ns, ms]), np.array(ys).astype(float), fixed_E=fixed_E)


def _relative_error_grid(df, popt, model_names, cc_mults, val_dataset):
    """Return ``|gt - pred| / gt``, one row per model name and one column per multiplier.

    An entry is NaN when ``df`` does not contain exactly one matching run.
    """
    loss_col = f"loss_{val_dataset}"
    grid = []
    for model_name in model_names:
        row = []
        for cc_mult in cc_mults:
            hits = df[(df["model_name"] == model_name) & (df["cc_mult"] == cc_mult)][loss_col].tolist()
            # Only trust a cell backed by exactly one run; missing or duplicated runs become NaN.
            gt = hits[0] if len(hits) == 1 else float("NaN")
            # NOTE(review): the fit uses the "tok_mult" column, while this uses
            # cc_mult * 20. Presumably tok_mult == 20 * cc_mult -- confirm.
            pred = powlaw_approach2(np.array([[NAME_PARAMS[model_name]], [cc_mult * 20.0]]), *popt)
            assert len(pred) == 1  # one input point -> one prediction
            row.append(abs(gt - pred[0]) / gt)
        grid.append(row)
    return grid


def _highlight_cells(ax, col, rows):
    """Outline heatmap cells ``(row, col)`` in yellow, one box per run of adjacent rows."""
    spans = []  # half-open [top, bottom) row ranges
    for row in sorted(set(rows)):
        if spans and spans[-1][1] == row:
            spans[-1][1] = row + 1
        else:
            spans.append([row, row + 1])
    for top, bottom in spans:
        # Heatmap cell (r, c) is the unit square with corner (c, r) in data coordinates.
        ax.add_patch(Rectangle((col, top), 1, bottom - top, fill=False, edgecolor="yellow", lw=3))
