import json
from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
import seaborn as sns
from matplotlib.patches import Rectangle
from pathlib import Path
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import InsetPosition

# Seed NumPy's global RNG at import time so that the bootstrap resampling done
# with np.random.choice inside grid_full_plot() draws the same indices on every
# run. NOTE(review): `np` is not imported explicitly in this file; presumably it
# is provided by one of the star imports above -- confirm.
np.random.seed(0)


def pareto_frontier(points_dict):
    """Return the points on the lower-left Pareto frontier of ``points_dict``.

    A point is kept when no other point has both a smaller-or-equal x and a
    smaller-or-equal y, i.e. both coordinates are treated as "lower is better".

    Args:
        points_dict: Mapping from an arbitrary key to an ``(x, y)`` pair.

    Returns:
        A new dict containing only the frontier entries, inserted in ascending
        ``x`` order. Values are the same objects as in ``points_dict``. An
        empty input yields an empty dict.
    """
    if not points_dict:
        return {}

    # Sort by x, breaking ties by y, so that for any given x the best (lowest)
    # y is visited first and later duplicates at that x are rejected below.
    sorted_points = sorted(points_dict.items(), key=lambda kv: (kv[1][0], kv[1][1]))
    result_dict = {}

    # Lowest y seen so far. ``None`` means "nothing accepted yet", which makes
    # the first (lowest-x) point always part of the frontier: nothing can
    # dominate it because nothing has a smaller x.
    best_y = None
    for name, point in sorted_points:
        # Walking in ascending x, a point is on the frontier only if it
        # strictly improves on every y seen before it.
        if best_y is None or point[1] < best_y:
            result_dict[name] = point
            best_y = point[1]

    return result_dict


def grid_full_plot():
    """Draw the three-panel "Search" / "Filter" / "Fit" grid-search figure.

    Reads ``exp_data/grid.json`` (a mapping from run name to the value plotted
    on the y axis, labelled "Loss: OpenLM eval") and builds a 1x3 figure with
    shared log-log axes:

    * "Search": every run as (compute, loss).
    * "Filter": only the runs that survive ``pareto_frontier`` after keeping
      the best run per model size.
    * "Fit": a subset of the filtered runs, a power-law fit through them with a
      bootstrap percentile band, a second fit through the hand-picked
      "selected" configurations, and a zoomed inset around the 6.9B target.

    The two target models are drawn on every panel and a legend is attached to
    the last panel. Nothing visible here saves or returns the figure.

    NOTE(review): ``plt``, ``mpl``, ``np``, ``tqdm``, ``Line2D``,
    ``curve_fit_powlaw_irreducible`` and ``powlaw_irreducible`` are not
    imported explicitly in this file; presumably they come from the star
    imports at the top -- confirm.
    """
    data = {}
    with open("exp_data/grid.json", "r") as f:
        data = json.load(f)

    # Compute / loss of every run, in file order (used for the "Search" panel).
    xs = []
    ys = []

    # Best (lowest) loss seen so far for each distinct compute value.
    xs_mini = {}

    # N -> (best loss, run name, D) for the best run at that N.
    by_param = {}

    # (compute, loss) of the two target models. Compute is 6 * N * D with
    # D = 20 * N, matching the "Compute ($6ND$)" axis label.
    open_lm_7b_tup = (6 * 6889410560 * 6889410560 * 20, 1.805)
    open_lm_1b_tup = (6 * 1439795200 * 1439795200 * 20, 2.09)

    for e in data:
        # The 4th-from-last "-"-separated field of the run name is parsed as D
        # and N is derived as D / 20.
        # NOTE(review): presumably D is the token count and N the parameter
        # count of a 20-tokens-per-parameter run -- confirm against the naming
        # scheme of the keys in grid.json.
        D = int(e.split("-")[-4])
        N = D / 20
        xs.append(6 * N * D)
        ys.append(data[e])

        # Keep only the lowest-loss run per compute value. Compute is a
        # function of D alone here, so this is also the best run per N.
        if xs[-1] not in xs_mini:
            xs_mini[xs[-1]] = ys[-1]
            by_param[N] = (ys[-1], e, D)
        else:
            if ys[-1] < xs_mini[xs[-1]]:
                xs_mini[xs[-1]] = ys[-1]
                by_param[N] = (ys[-1], e, D)

    # Re-key the per-N winners by run name as (compute, loss) and filter them
    # through pareto_frontier; the surviving run names are printed for
    # inspection.
    points_dict = {}
    for n in by_param:
        l, name, d = by_param[n]
        points_dict[name] = (6 * n * d, l)
    points_dict = pareto_frontier(points_dict)
    print(list(points_dict.keys()))

    _, axes = plt.subplots(nrows=1, ncols=3, constrained_layout=True, figsize=(90 / 7, 4.0), sharex=True, sharey=True)

    for i, ax in enumerate(axes):
        # Axis styling common to all three panels.
        ax.set_xlabel("Compute ($6ND$) [FLOPs]")
        ax.set_yscale("log")
        ax.set_xscale("log")
        # Plain decimal tick labels on the log-scaled y axis.
        ax.get_yaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
        ax.set_yticks([2.0, 3.0, 4.0, 5.0, 6.0])
        ax.set_xlim(left=8e15, right=2e22)
        ax.grid(which="major", ls="-")

        if i == 0:
            # Panel 0: every grid-search run.
            ax.set_title("Search")
            ax.scatter(xs, ys, zorder=9, alpha=0.3, color="tab:orange", label="Grid search models", s=25)
            ax.set_ylabel("Loss: OpenLM eval")
        elif i == 1:
            # Panel 1: only the runs kept by pareto_frontier.
            ax.set_title("Filter")
            xs_m = []
            ys_m = []

            for k in points_dict:
                c, l = points_dict[k]
                xs_m.append(c)
                ys_m.append(l)

            ax.scatter(xs_m, ys_m, zorder=9, alpha=0.7, color="tab:orange", label="Grid search models", s=64)

        else:
            # Panel 2: power-law fits, bootstrap band and zoomed inset.
            ax.set_title("Fit")
            axins = zoomed_inset_axes(ax, 6)

            # sub region of the original image
            x1, x2, y1, y2 = 2e21, 1.5e22, 1.65, 1.95
            axins.set_xlim(x1, x2)
            axins.set_ylim(y1, y2)
            # Override the default inset placement with an explicit
            # [left, bottom, width, height] box relative to the parent axes.
            axins.set_axes_locator(InsetPosition(ax, [0.12, 0.12, 0.35, 0.25]))
            axins.set_yscale("log")
            axins.set_xscale("log")
            axins.minorticks_off()
            axins.set_yticks([1.9, 1.8, 1.7])
            axins.get_yaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
            axins.get_yaxis().set_minor_formatter(mpl.ticker.ScalarFormatter())

            axins.grid(which="major", ls="-")

            # draw a bbox of the region of the inset axes in the parent axes and
            # connecting lines between the bbox and the inset axes area
            mark_inset(ax, axins, loc1=1, loc2=4, fc="none", ec="0.5")

            xs_m = []
            ys_m = []

            # Use only the filtered runs whose compute lies outside the closed
            # interval [5.361317869436928e16, 5.18223548264448e17].
            # NOTE(review): the reason for excluding this compute window is not
            # visible here -- confirm and consider naming these constants.
            for k in points_dict:
                c, l = points_dict[k]
                if c < 5.361317869436928e16 or c > 5.18223548264448e17:
                    xs_m.append(c)
                    ys_m.append(l)

            ax.scatter(xs_m, ys_m, zorder=9, alpha=0.7, color="tab:orange", label="Grid search models", s=64)

            # Evaluation grid for the fitted curves: {1, 3, 5} x 10^k for
            # k = 15..23, in ascending order.
            popts = []
            x_range_compute = []
            for ii in [1e15, 1e16, 1e17, 1e18, 1e19, 1e20, 1e21, 1e22, 1e23]:
                for jj in [1, 3, 5]:
                    x_range_compute.append(ii * jj)

            # Bootstrap: refit on 10,000 resamples (indices drawn with
            # replacement via the global NumPy RNG) of the same size as the
            # data, collecting one parameter set per resample.
            for _ in tqdm(range(10_000)):
                idx = np.random.choice(len(xs_m), size=len(xs_m))
                popt = curve_fit_powlaw_irreducible(
                    np.array(xs_m).astype(float)[idx].T, np.array(ys_m).astype(float)[idx]
                )

                popts.append(popt)

            # Evaluate every bootstrap fit on the compute grid.
            y_samps = []
            for popt in tqdm(popts):
                y_samps.append(powlaw_irreducible(x_range_compute, *popt))

            stuff = np.vstack(y_samps)

            # Point-wise 2.5th / 97.5th percentiles across bootstrap fits,
            # i.e. a 95% percentile band at each grid point.
            lowers = np.percentile(stuff, 2.5, axis=0)
            uppers = np.percentile(stuff, 97.5, axis=0)

            # Fit on the full (non-resampled) data; drawn on both the main
            # panel and the inset.
            popt = curve_fit_powlaw_irreducible(np.array(xs_m).astype(float).T, np.array(ys_m).astype(float))
            for _, a in enumerate([ax, axins]):
                # Dashed line over the whole grid (extrapolation) ...
                a.plot(
                    x_range_compute, powlaw_irreducible(x_range_compute, *popt), color="tab:orange", linestyle="dashed"
                )
                # ... overdrawn with a solid line between the first and last
                # fitted points (interpolation). pareto_frontier returns its
                # entries in ascending-compute order, so xs_m[0] / xs_m[-1]
                # are the smallest / largest compute values used in the fit.
                a.plot(
                    np.linspace(xs_m[0], xs_m[-1], 2000),
                    powlaw_irreducible(np.linspace(xs_m[0], xs_m[-1], 2000).T, *popt),
                    color="tab:orange",
                )
                # Bootstrap band: dashed edges plus a light fill.
                a.plot(x_range_compute, lowers, linestyle="dashed", color="tab:orange", alpha=0.4)
                a.plot(x_range_compute, uppers, linestyle="dashed", color="tab:orange", alpha=0.4)
                a.fill_between(x_range_compute, lowers, uppers, color="tab:orange", alpha=0.1)

            # Hand-picked configurations, matched against the part of the run
            # name before the first "-".
            ours = ["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"]

            xs_m = []
            ys_m = []

            for k in points_dict:
                c, l = points_dict[k]
                if k.split("-")[0] in ours:
                    xs_m.append(c)
                    ys_m.append(l)

            # Highlight the selected models on all three panels (the first two
            # panels have already been drawn, so they are addressed via axes).
            ax.scatter(xs_m, ys_m, zorder=9, alpha=0.7, color="tab:blue", marker="P", label="Selected models", s=64)
            axes[0].scatter(
                xs_m, ys_m, zorder=9, alpha=0.7, color="tab:blue", marker="P", label="Selected models", s=64
            )
            axes[1].scatter(
                xs_m, ys_m, zorder=9, alpha=0.7, color="tab:blue", marker="P", label="Selected models", s=64
            )

            # Second fit, through the selected models only (no bootstrap band).
            popt = curve_fit_powlaw_irreducible(np.array(xs_m).astype(float).T, np.array(ys_m).astype(float))
            for a in [ax, axins]:
                a.plot(
                    x_range_compute, powlaw_irreducible(x_range_compute, *popt), color="tab:blue", linestyle="dashed"
                )
                a.plot(
                    np.linspace(xs_m[0], xs_m[-1], 2000),
                    powlaw_irreducible(np.linspace(xs_m[0], xs_m[-1], 2000).T, *popt),
                    color="tab:blue",
                )
            # Larger 6.9B target marker inside the inset.
            axins.scatter(
                [open_lm_7b_tup[0]],
                [open_lm_7b_tup[1]],
                marker="*",
                zorder=9,
                alpha=0.7,
                color="tab:purple",
                label="Target 6.9B model",
                s=128,
            )

        # Target models are drawn on every panel.
        ax.scatter(
            [open_lm_1b_tup[0]],
            [open_lm_1b_tup[1]],
            marker="p",
            zorder=9,
            alpha=0.7,
            color="tab:purple",
            label="Target 1.4B model",
            s=64,
        )
        ax.scatter(
            [open_lm_7b_tup[0]],
            [open_lm_7b_tup[1]],
            marker="*",
            zorder=9,
            alpha=0.7,
            color="tab:purple",
            label="Target 6.9B model",
            s=64,
        )

    # After the loop `ax` still refers to the last ("Fit") panel; the legend is
    # built from that panel's labelled artists.
    # NOTE(review): `ax` is always bound here because the loop runs over three
    # axes, so the `else` branch below looks unreachable -- confirm and drop.
    handles = None
    if ax is not None:
        handles, _ = ax.get_legend_handles_labels()
    else:
        handles, _ = plt.gca().get_legend_handles_labels()

    # Proxy artists explaining the solid vs. dashed line styles used for the
    # fitted curves.
    more_handles = []

    more_handles.append(
        Line2D(
            [0],
            [0],
            label="Interpolation",
            color="grey",
            marker="",
            linestyle="-",
        )
    )

    more_handles.append(
        Line2D(
            [0],
            [0],
            label="Extrapolation",
            color="grey",
            marker="",
            linestyle="--",
        )
    )

    # add manual symbols to auto legend
    handles.extend(more_handles)
    ax.legend(handles=handles, loc="upper right", bbox_to_anchor=(1.25, 1.0))

    # NOTE(review): the figure was created with constrained_layout=True;
    # calling tight_layout() on top of that may conflict with or replace the
    # constrained layout engine -- confirm which one is intended.
    plt.tight_layout()
