from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
import seaborn as sns
from matplotlib.patches import Rectangle
import matplotlib.gridspec as gridspec
import scipy
from scipy import stats
from matplotlib.ticker import FuncFormatter
from scipy.stats import spearmanr


def prediction_error_downstream(
    train_val_pairs=None,
    model_dir="exp_data/models_tok",
    de_model_dir="exp_data/models_tok_de_en",
    eval_dir="exp_data/evals_tok",
    small_mult=16.0,
    fit_models=None,
    Ms=None,
    Ns=None,
    irreducible_error_estimate_cc=1.0,
    push="avg",
    prefix="err_",
):
    """Draw heatmaps of the relative error of predicted downstream metrics.

    One heatmap panel is drawn per ``(train_dataset, val_dataset)`` pair, plus
    a narrow extra axis that holds the shared colorbar. For every pair:

    1. Results are loaded with ``parse_model_jsons`` and grouped with
       ``split_df_by_mult``.
    2. ``curve_fit_powlaw_approach2`` is fitted on ``(N, tok_mult)`` against
       the ``loss_{val_dataset}`` column, using the rows of the
       ``irreducible_error_estimate_cc`` group plus the ``"d=96_l=8_h=4"``
       rows of the ``small_mult`` group.
    3. ``E - a * exp(-b * loss)`` is fitted from loss to the
       ``f"{prefix}{push}"`` column on the same points plus the
       ``open_lm_1b`` row loaded with ``cc_mults=[1.0]``.
    4. For every ``(N, M)`` in ``Ns x Ms`` the two fits are chained and
       ``abs(gt - pred) / gt`` is stored; cells without exactly one
       ground-truth row become NaN.

    After all panels are drawn, the number of (prediction, ground truth)
    pairs, the pairs themselves and their Spearman rank correlation are
    printed.

    Args:
        train_val_pairs: Sequence of ``(train_dataset, val_dataset)`` names.
            Defaults to C4, RPJ and RW training sets, all evaluated on
            ``"c4_val"``.
        model_dir: Directory passed to ``parse_model_jsons``.
        de_model_dir: Directory used instead of ``model_dir`` for the main
            load when ``"de-en"`` occurs in the validation dataset name.
        eval_dir: Evaluation directory passed to ``parse_model_jsons``.
        small_mult: Key of the group from which the extra
            ``"d=96_l=8_h=4"`` fit points are taken.
        fit_models: Model names passed to ``split_df_by_mult``. Defaults to
            the four smallest configurations.
        Ms: ``cc_mult`` values forming the heatmap columns. Defaults to
            ``[0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0]``.
        Ns: Model names forming the heatmap rows. Defaults to the four
            ``fit_models`` defaults plus ``"open_lm_1b"`` and
            ``"open_lm_7b"``.
        irreducible_error_estimate_cc: Key of the main group of fit points.
        push: Downstream metric name; combined with ``prefix`` to form the
            column name and used to look up the panel title.
        prefix: Column-name prefix of the downstream metric.

    Side effects:
        Sets the global matplotlib ``figure.dpi`` and font size, creates a
        new figure and prints to stdout.
    """
    # Resolve defaults here rather than in the signature so that no list
    # object is shared between calls (the lists are handed to helpers that
    # are outside this function's control).
    if train_val_pairs is None:
        train_val_pairs = [
            ("c4_original", "c4_val"),
            ("rpj", "c4_val"),
            ("rw_original", "c4_val"),
        ]
    if fit_models is None:
        fit_models = ["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"]
    if Ms is None:
        Ms = [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0]
    if Ns is None:
        Ns = ["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8", "open_lm_1b", "open_lm_7b"]

    def loss_to_metric(t, a, b, E):
        # Functional form fitted from validation loss to the downstream metric.
        return E - a * np.exp(-b * t)

    metric_col = f"{prefix}{push}"
    smallest_model = "d=96_l=8_h=4"
    # Factor applied to each M before it is fed to the power-law fit and used
    # for the x tick labels; presumably converts cc_mult to the tok_mult scale
    # the fit was trained on -- TODO confirm.
    tok_mult_per_cc_mult = 20.0

    mpl.rcParams["figure.dpi"] = 300
    mpl.rc("font", size=22)

    # One equally wide panel per pair plus a thin axis for the colorbar. The
    # ratios are derived from the number of pairs so that they always match
    # ``ncols``.
    n_panels = len(train_val_pairs)
    fig, axes = plt.subplots(
        nrows=1,
        ncols=n_panels + 1,
        constrained_layout=True,
        gridspec_kw={"width_ratios": [1] * n_panels + [0.03]},
    )
    cbar_ax = axes[-1]  # the placement of the colorbar
    percent_formatter = FuncFormatter(lambda x, pos: "{:.0%}".format(x))

    spearman_x = []
    spearman_y = []

    for i, (train_dataset, val_dataset) in enumerate(train_val_pairs):
        loss_col = f"loss_{val_dataset}"

        df = parse_model_jsons(
            model_dir if "de-en" not in val_dataset else de_model_dir,
            cc_mults=Ms,
            datasets=[train_dataset],
            eval_dir=eval_dir,
        )
        df_mults, names = split_df_by_mult(df, fit_models)
        df_by_mult = dict(zip(names, df_mults))

        # Fit points: the whole reference group plus the smallest model from
        # the ``small_mult`` group.
        base = df_by_mult[irreducible_error_estimate_cc]
        small = df_by_mult[small_mult]
        small = small[small["model_name"] == smallest_model]

        losses = base[loss_col].tolist() + small[loss_col].tolist()
        metrics = base[metric_col].tolist() + small[metric_col].tolist()
        tok_mults = base["tok_mult"].tolist() + small["tok_mult"].tolist()
        param_counts = base["N"].tolist() + small["N"].tolist()

        # The loss fit deliberately happens BEFORE the 1b point is appended.
        popt_approach2 = curve_fit_powlaw_approach2(
            np.array([param_counts, tok_mults]), np.array(losses).astype(float)
        )

        # The loss -> downstream-metric fit additionally uses the 1b run.
        # NOTE(review): this load always reads from ``model_dir``, even when
        # the main load above switched to ``de_model_dir`` -- confirm intended.
        df_1b = parse_model_jsons(model_dir, cc_mults=[1.0], datasets=[train_dataset], eval_dir=eval_dir)
        df_mults_1b, _ = split_df_by_mult(df_1b, ["open_lm_1b"])
        for df_mult in df_mults_1b:
            row_1b = df_mult[df_mult["model_name"] == "open_lm_1b"]
            assert len(row_1b) == 1, "expected exactly one open_lm_1b row"
            losses.extend(row_1b[loss_col].tolist())
            metrics.extend(row_1b[metric_col].tolist())

        (a, b, E), _ = scipy.optimize.curve_fit(
            loss_to_metric,
            losses,
            metrics,
            maxfev=10000,
        )

        errors = []
        for N in Ns:
            row = []
            for M in Ms:
                # Ground truth is only defined when exactly one run matches.
                gt_values = df[(df["model_name"] == N) & (df["cc_mult"] == M)][metric_col].tolist()
                gt = gt_values[0] if len(gt_values) == 1 else float("NaN")

                pred_loss = powlaw_approach2(
                    np.array([[NAME_PARAMS[N]], [M * tok_mult_per_cc_mult]]), *popt_approach2
                )
                assert len(pred_loss) == 1
                pred_metric = loss_to_metric(pred_loss[0], a, b, E)

                # NaN ground truth propagates to a NaN (blank) heatmap cell.
                row.append(abs(gt - pred_metric) / gt)

                if not math.isnan(gt):
                    spearman_x.append(pred_metric)
                    spearman_y.append(gt)

            errors.append(row)

        is_first = i == 0
        is_last = i == n_panels - 1
        hm = sns.heatmap(
            errors,
            ax=axes[i],
            annot=True,
            vmin=0.0,
            vmax=0.1,
            cmap="coolwarm",
            xticklabels=[int(m * 20) for m in Ms],
            yticklabels=[MODEL_FRINDLIES[n] for n in Ns] if is_first else [],
            fmt=".1%",
            cbar=is_last,  # only the last panel draws the shared colorbar
            cbar_ax=cbar_ax,
            cbar_kws={"format": percent_formatter},
        )
        axes[i].set_ylabel("$N$" if is_first else "")

        hm.set_facecolor("linen")
        hm.set_yticklabels(labels=hm.get_yticklabels(), va="center")

        axes[i].set_title(f"Train: {DATASET_FRIENDLIES[train_dataset]},\nDownstream: {VAL_FRIENDLIES[push]}")
        axes[i].set_xlabel("$M$")
        # Yellow outlines, presumably marking the cells used for fitting.
        # NOTE(review): the coordinates are hard-coded (column 1 over 5 rows,
        # and column ``len(fit_models)`` in row 0) and are not derived from
        # ``Ms``, ``Ns``, ``small_mult`` or ``irreducible_error_estimate_cc``
        # -- verify they match the intended cells when those arguments change.
        axes[i].add_patch(Rectangle((1, 0), 1, 5, fill=False, edgecolor="yellow", lw=3))
        axes[i].add_patch(Rectangle((len(fit_models), 0), 1, 1, fill=False, edgecolor="yellow", lw=3))

    cbar_ax.set_ylabel("Relative error")

    fig.set_size_inches(22.5, 7.0)
    print(len(spearman_x))
    res = spearmanr(spearman_x, spearman_y)
    print(spearman_x, spearman_y)
    print(push + " rank corr: " + str(res.correlation))
