from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import InsetPosition
import scipy


def add_zoom(ax):
    """Attach a log-log zoomed inset to ``ax`` and return the inset axes.

    The inset is created with a zoom factor of 6, shows the data window
    ``3e21 <= x <= 1.3e22`` / ``0.56 <= y <= 0.72`` and is positioned with
    ``InsetPosition(ax, [0.6, 0.6, 0.35, 0.35])``.  ``mark_inset`` is then
    used to outline the zoomed region on ``ax`` and link it to the inset.

    NOTE(review): ``InsetPosition`` appears to be deprecated in recent
    matplotlib releases -- confirm against the pinned matplotlib version.
    """
    zoom_factor = 6
    x_window = (3e21, 1.3e22)
    y_window = (0.56, 0.72)

    inset = zoomed_inset_axes(ax, zoom_factor, loc="lower left")

    # Region of the parent plot that the inset displays.
    inset.set_xlim(*x_window)
    inset.set_ylim(*y_window)

    # Explicit placement of the inset; presumably supersedes the locator
    # that `zoomed_inset_axes` derived from `loc` -- TODO confirm.
    inset.set_axes_locator(InsetPosition(ax, [0.6, 0.6, 0.35, 0.35]))

    # Scales are switched to log before ticks/formatters are customised,
    # and minor ticks are disabled afterwards.
    inset.set_yscale("log")
    inset.set_xscale("log")
    inset.minorticks_off()

    # Fixed tick positions, rendered with plain scalar formatters.
    inset.set_yticks([0.56, 0.60, 0.64, 0.68, 0.72])
    inset.set_xticks([1e22])
    for axis in (inset.get_yaxis(), inset.get_xaxis()):
        axis.set_major_formatter(mpl.ticker.ScalarFormatter())
        axis.set_minor_formatter(mpl.ticker.ScalarFormatter())
    # The single x tick gets an explicit mathtext label.
    inset.set_xticklabels(["$10^{22}$"])

    inset.grid(which="major", ls="-")

    # Draw a box around the zoomed region in the parent axes plus the
    # connector lines between that box and the inset.
    mark_inset(ax, inset, loc1=3, loc2=4, fc="none", ec="0.5")

    return inset


def add_zoom2(ax):
    """Attach a linear-scale zoomed inset to ``ax`` and return the inset axes.

    The inset is created with a zoom factor of 6 and shows the data window
    x between 2.3 and 2.7, y between 0.43 and 0.53.  The x limits are passed
    in descending order (2.7 first), so the inset x-axis runs in the same
    reversed direction as a parent whose x-axis has been inverted (as
    ``figure1`` in this file does before calling this function).  The inset
    is positioned with ``InsetPosition(ax, [0.12, 0.12, 0.35, 0.55])`` and
    linked to the parent with ``mark_inset``.

    NOTE(review): ``InsetPosition`` appears to be deprecated in recent
    matplotlib releases -- confirm against the pinned matplotlib version.
    """
    zoom_factor = 6
    x_window = (2.7, 2.3)  # descending on purpose: reversed x-axis
    y_window = (0.43, 0.53)

    inset = zoomed_inset_axes(ax, zoom_factor, loc="lower left")

    # Region of the parent plot that the inset displays.
    inset.set_xlim(*x_window)
    inset.set_ylim(*y_window)

    # Explicit placement of the inset; presumably supersedes the locator
    # that `zoomed_inset_axes` derived from `loc` -- TODO confirm.
    inset.set_axes_locator(InsetPosition(ax, [0.12, 0.12, 0.35, 0.55]))

    inset.minorticks_off()
    inset.grid(which="major", ls="-")

    # Box around the zoomed region in the parent axes plus connector lines
    # from that box to the inset.
    mark_inset(ax, inset, loc1=4, loc2=1, fc="none", ec="0.5")

    return inset


def _plot_markers(targets, x, y, **style):
    """Draw the point ``(x, y)`` with ``errorbar`` on every axes in ``targets``.

    The first axes gets ``markersize=8``; every further axes (the zoomed
    insets in ``figure1``) gets ``markersize=12``.  ``alpha=0.9`` and
    ``zorder=9`` are always applied; remaining keyword arguments are
    forwarded to ``errorbar`` unchanged.
    """
    for position, target in enumerate(targets):
        target.errorbar(
            x,
            y,
            markersize=8 if position == 0 else 12,
            alpha=0.9,
            zorder=9,
            **style,
        )


def _figure1_legend_handles(points_1x, scalarMap):
    """Build the hand-made legend entries (model markers, line styles, M colors)."""
    # One grey marker per model listed in `points_1x`.
    handles = [
        Line2D(
            [0],
            [0],
            label="$N = $" + MODEL_FRINDLIES[ms],
            color="grey",
            marker=MODEL_SHAPES[ms],
            linestyle="",
        )
        for ms in points_1x
    ]

    # Hollow marker = value predicted from the fit.
    handles.append(
        Line2D([0], [0], label="Prediction", color="grey", marker="o", fillstyle="none", linestyle="")
    )
    # Solid vs. dashed grey lines.
    handles.append(Line2D([0], [0], label="Interpolation", color="grey", marker="", linestyle="-"))
    handles.append(Line2D([0], [0], label="Extrapolation", color="grey", marker="", linestyle="--"))

    # Color key: the label shows M while the color is looked up by multiplier.
    handles.extend(
        Line2D(
            [0],
            [0],
            label=label,
            color=scalarMap.to_rgba(mult),
            marker="o",
            linestyle="",
            alpha=0.9,
        )
        for label, mult in (("$M = 20$", 1.0), ("$M = 320$", 16.0), ("$M = 640$", 32.0))
    )
    return handles


def figure1(
    val_suffix="c4_val",
    model_dir="exp_data/models_tok",
    cc_mults=None,
    points_1x=None,
    small_mults=None,
    fit_models=None,
    irreducible_error_estimate_cc=1.0,
    dataset="rpj",
    eval_dir="exp_data/evals_tok",
    prefix="loss_",
    downstream="err_avg_subset",
):
    """Draw the two-panel scaling figure and return the fitted constant ``E``.

    Left panel: reducible loss (``loss - E``) against compute on log-log
    axes, with measured points, fitted trend curves and hollow "prediction"
    markers.  Right panel: the ``downstream`` metric against loss on an
    inverted x-axis, with the fitted curve ``e - al * exp(-b * loss)``.
    Each panel also gets a zoomed inset (``add_zoom`` / ``add_zoom2``).

    Side effects: sets ``mpl.rcParams["figure.dpi"]`` and the global font
    size, and creates a new matplotlib figure.  The figure is neither shown,
    saved nor returned here.

    Args:
        val_suffix: Validation-set suffix; the loss column read from the data
            frames is ``f"{prefix}{val_suffix}"`` and the value is also used
            as a key into ``VAL_FRIENDLIES`` for the axis labels.
        model_dir: Forwarded to ``parse_model_jsons``.
        cc_mults: Multipliers forwarded to ``parse_model_jsons`` for the main
            data frame and iterated when drawing prediction markers.
            Defaults to ``[1.0, 32.0]``.
        points_1x: Model names whose measured rows are plotted and that get a
            legend entry.  Defaults to the six models
            ``d=96_l=8_h=4`` ... ``open_lm_7b``.
        small_mults: Multipliers forwarded to ``parse_model_jsons`` for the
            second data frame; multipliers in this list are skipped when
            drawing prediction markers, and their maximum ends the solid
            M-sweep curve of the smallest model.  Defaults to ``[16.0]``.
        fit_models: Forwarded to ``split_df_by_mult``.  Defaults to the four
            ``d=..._l=..._h=...`` models.
        irreducible_error_estimate_cc: The multiplier group (as returned by
            ``split_df_by_mult``) for which the grey trend curves are drawn.
        dataset: Forwarded to ``fit_ds`` and ``parse_model_jsons``.
        eval_dir: Forwarded to ``parse_model_jsons``.
        prefix: Prefix of the loss column name.
        downstream: Name of the downstream-metric column; also forwarded to
            ``fit_ds``.

    Returns:
        ``E``, the last element of the parameter vector returned by
        ``fit_ds`` (subtracted from every loss shown in the left panel).
    """
    # Resolve list defaults here instead of in the signature so the default
    # objects are never shared between calls or with the helpers they are
    # passed to.
    if cc_mults is None:
        cc_mults = [1.0, 32.0]
    if points_1x is None:
        points_1x = ["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8", "open_lm_1b", "open_lm_7b"]
    if small_mults is None:
        small_mults = [16.0]
    if fit_models is None:
        fit_models = ["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"]

    loss_col = f"{prefix}{val_suffix}"

    # Parameter counts used for the per-model curves and prediction markers.
    smallest_n = 10569312
    largest_n = 1439795200
    extrapolated_n = 6889410560
    params = [smallest_n, 78914048, 153677376, 411616256, largest_n]

    popt_approach2, (al, b, e), ys_irr = fit_ds(dataset, downstream, True)
    E = popt_approach2[-1]

    def predicted_loss(n_values, m_values):
        """Evaluate the fitted loss law on paired (N, M) values."""
        grid = np.array([n_values, m_values]).astype(float)
        return powlaw_approach2(grid, *popt_approach2)

    mpl.rcParams["figure.dpi"] = 300
    mpl.rc("font", size=11)

    # Marker colors are looked up by multiplier on a log scale over [0.5, 32].
    cmap = plt.get_cmap("cool")
    cNorm = colors.LogNorm(vmin=0.5, vmax=32.0)
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cmap)

    df = parse_model_jsons(
        model_dir,
        cc_mults=cc_mults,
        datasets=[dataset],
        eval_dir=eval_dir,
    )
    dfs_model, model_names = split_df_by_model(df)

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5), constrained_layout=True)
    ax = axes[0]
    ax_down = axes[1]

    # ---- left panel: reducible loss vs. compute (log-log) ----
    ax.set_ylabel(f"Reducible loss: {VAL_FRIENDLIES[val_suffix]}")
    ax.set_xlabel("Compute ($6ND, D=MN$) [FLOPs]")
    ax.set_yscale("log")
    ax.set_yticks([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
    ax.get_yaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xscale("log")
    ax.grid(which="major", ls="-")

    # ---- right panel: downstream error vs. loss ----
    # NOTE(review): the y label is hard-coded and does not follow `downstream`.
    ax_down.set_ylabel("Average top-1 error: 17-task split")
    ax_down.set_xlabel(f"Loss: {VAL_FRIENDLIES[val_suffix]}")
    ax_down.grid(which="major", ls="-")
    ax_down.get_yaxis().set_major_formatter(mpl.ticker.ScalarFormatter())

    # NOTE(review): the figure was created with constrained_layout=True and
    # tight_layout() is requested as well; confirm which layout is intended.
    fig.tight_layout()

    ax_down.invert_xaxis()

    axins = add_zoom(ax)
    axins2 = add_zoom2(ax_down)

    df_mults, names = split_df_by_mult(df, fit_models)

    df2 = parse_model_jsons(model_dir, cc_mults=small_mults, datasets=[dataset], eval_dir=eval_dir)
    df_mults2, _ = split_df_by_mult(df2, fit_models)

    # Downstream fit: dashed over the fixed loss range [2, 6] (main + inset),
    # solid over the range spanned by `ys_irr` (main panel only).
    x_range = np.linspace(2.0, 6.0, 100)
    ax_down.plot(x_range, e - al * np.exp(x_range) ** (-b), linestyle="dashed", color="tab:grey")
    axins2.plot(x_range, e - al * np.exp(x_range) ** (-b), linestyle="dashed", color="tab:grey")

    x_range = np.linspace(min(ys_irr), max(ys_irr), 100)
    ax_down.plot(x_range, e - al * np.exp(x_range) ** (-b), color="tab:grey")

    # Measured points of the "d=96_l=8_h=4" model at the `small_mults`
    # multipliers (first matching row of each group).
    for df_mult in df_mults2:
        tmp = df_mult[(df_mult["model_name"] == "d=96_l=8_h=4")]
        loss = tmp[loss_col].tolist()[0]
        color = scalarMap.to_rgba(tmp["cc_mult"].tolist()[0])
        marker = tmp["shape"].tolist()[0]

        _plot_markers(
            [ax, axins],
            tmp["flops"].tolist()[0],
            (tmp[loss_col] - E).tolist()[0],
            color=color,
            ecolor=color,
            marker=marker,
        )
        _plot_markers(
            [ax_down],
            loss,
            tmp[downstream].tolist()[0],
            color=color,
            ecolor=color,
            marker=marker,
        )

    for df_mult, mult in zip(df_mults, names):
        xs = df_mult["flops"].tolist()
        ns = df_mult["N"].tolist()

        # A multiplier group needs at least two points to be considered.
        if len(xs) < 2:
            continue

        if mult == irreducible_error_estimate_cc:
            # Dashed trends at fixed M over a fixed compute grid: M = 640
            # first, then M = 20.  N follows from C = 6 * M * N**2.
            computes = [3e15, 3e17, 3e18, 3e19, 3e20, 3e22]
            for token_mult in (640.0, 20.0):
                n = np.sqrt(np.array(computes).astype(float) / (6 * token_mult)).tolist()
                y = predicted_loss(n, [token_mult] * len(computes)) - E
                ax.plot(computes, y, linestyle="dashed", color="tab:grey")
                axins.plot(computes, y, linestyle="dashed", color="tab:grey")

            # Solid segment between the first and last row of this group,
            # evaluated at M = 20.
            y = predicted_loss([ns[0], ns[-1]], [20.0, 20.0]) - E
            ax.plot([xs[0], xs[-1]], y, color="tab:grey")

            # Smallest model, M swept from 20 to 20 * max(small_mults): solid.
            sweep_m = np.linspace(1.0, max(small_mults), 100) * 20
            sweep_x = [6 * m * smallest_n**2 for m in sweep_m]
            sweep_y = predicted_loss([smallest_n] * len(sweep_m), sweep_m.tolist()) - E
            ax.plot(sweep_x, sweep_y, color="tab:grey")

            # Smallest model, M swept from 320 to 640: dashed.
            # NOTE(review): the 16..32 range is hard-coded and does not follow
            # `small_mults` / `cc_mults`; confirm that is intended.
            sweep_m = np.linspace(16.0, 32.0, 3) * 20
            sweep_x = [6 * m * smallest_n**2 for m in sweep_m]
            sweep_y = predicted_loss([smallest_n] * len(sweep_m), sweep_m.tolist()) - E
            ax.plot(sweep_x, sweep_y, color="tab:grey", linestyle="dashed")
            axins.plot(sweep_x, sweep_y, color="tab:grey", linestyle="dashed")

            # Every larger model, M swept over 20 * {1, 2, 4, 8, 16, 32}: dashed.
            for N in params[1:]:
                sweep_m = [20 * m for m in [1.0, 2.0, 4.0, 8.0, 16.0, 32.0]]
                sweep_x = ((N**2) * np.array(sweep_m) * 6).tolist()
                sweep_y = predicted_loss([N] * len(sweep_m), sweep_m) - E
                ax.plot(sweep_x, sweep_y, color="tab:grey", linestyle="dashed")
                axins.plot(sweep_x, sweep_y, color="tab:grey", linestyle="dashed")

            # Every model, continued from its M = 640 compute up to 3e22
            # FLOPs on 10 log-spaced compute values: dashed.
            for N in params:
                start_compute = (N**2) * 20 * 32.0 * 6
                end_compute = 3e22

                sweep_m = [
                    np.power(10, c) / ((N**2) * 6)
                    for c in np.linspace(np.log10(start_compute), np.log10(end_compute), 10)
                ]
                sweep_x = ((N**2) * np.array(sweep_m) * 6).tolist()
                sweep_y = predicted_loss([N] * len(sweep_x), sweep_m) - E

                ax.plot(sweep_x, sweep_y, color="tab:grey", linestyle="dashed")
                axins.plot(sweep_x, sweep_y, color="tab:grey", linestyle="dashed")

        # Hollow prediction markers.
        # NOTE(review): this loop runs once per multiplier group that has at
        # least two points, so identical markers are drawn repeatedly; it
        # probably belongs outside the enclosing loop -- left in place to
        # keep the rendered figure unchanged.
        for N in params + [extrapolated_n]:
            for m in cc_mults:
                # Multipliers in `small_mults` get no prediction marker, and
                # the largest model is only predicted at m == 1.0.
                if m in small_mults or (N == extrapolated_n and m != 1.0):
                    continue

                compute = (N**2) * 20 * m * 6
                loss_hat = predicted_loss([N], [20 * m])
                color = scalarMap.to_rgba(m)

                _plot_markers(
                    [ax, axins],
                    [compute],
                    [loss_hat - E],
                    color=color,
                    marker=PARAM_SHAPES[N],
                    fillstyle="none",
                )

                # Only two predictions are carried over to the downstream panel.
                if (N == largest_n and m == 32.0) or (N > largest_n and m == 1.0):
                    _plot_markers(
                        [ax_down, axins2],
                        [loss_hat],
                        [e - al * np.exp(-b * iii) for iii in loss_hat],
                        color=color,
                        marker=PARAM_SHAPES[N],
                        fillstyle="none",
                    )

    # Measured points of every model listed in `points_1x`.
    for df_model, model_name in zip(dfs_model, model_names):
        if model_name not in points_1x:
            continue

        flops = df_model["flops"].tolist()
        if not flops:
            continue
        losses = df_model[loss_col].tolist()
        reducible = (df_model[loss_col] - E).tolist()
        mults = df_model["cc_mult"].tolist()
        # All rows of a model share the marker of its first row.
        marker = df_model["shape"].tolist()[0]

        for i, (x, y, loss, cc_mult) in enumerate(zip(flops, reducible, losses, mults)):
            color = scalarMap.to_rgba(cc_mult)
            _plot_markers([ax, axins], x, y, color=color, ecolor=color, marker=marker)

            # Downstream panel: rows at multiplier 1.0, plus every row of the
            # model drawn with the marker of the 1439795200-parameter model.
            if cc_mult == 1.0 or marker == PARAM_SHAPES[largest_n]:
                _plot_markers(
                    [ax_down, axins2],
                    loss,
                    df_model[downstream].tolist()[i],
                    color=color,
                    ecolor=color,
                    marker=marker,
                )

    ax.margins(y=0.0, x=0.0)
    ax.autoscale()

    ax_down.margins(y=0.0, x=0.0)
    ax_down.autoscale()

    # Automatic legend entries (if any) followed by the hand-made ones.
    handles, _ = ax.get_legend_handles_labels()
    handles.extend(_figure1_legend_handles(points_1x, scalarMap))

    ax.legend(
        handles=handles,
        loc="lower left",
        ncol=2,
        columnspacing=0.8,
    ).set_zorder(102)

    return E
