import json
from overtraining.plotting.shared import *
from overtraining.plotting.constants import *
import seaborn as sns
from matplotlib.patches import Rectangle
from pathlib import Path
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import InsetPosition

# Seed NumPy's global RNG at import time so any random draws made through it
# are reproducible. NOTE(review): `np` is presumably provided by one of the
# star-imports above -- confirm.
np.random.seed(0)


def pareto_frontier(points_dict):
    """Return the subset of points on the lower-left Pareto frontier.

    A point is kept when no other point has both a smaller-or-equal first
    coordinate (x) and a smaller-or-equal second coordinate (y), i.e. both
    coordinates are minimised.

    Args:
        points_dict: Mapping from an arbitrary key to a sequence whose first
            two entries are ``(x, y)``. Any further entries are carried
            through untouched.

    Returns:
        A new dict holding the frontier entries, inserted in order of
        increasing x. An empty input yields an empty dict.
    """
    if not points_dict:
        return {}

    # Sort by x, breaking ties by y, so that for equal x the lowest y is seen
    # first and the higher-y duplicates are rejected below.
    ordered = sorted(points_dict.items(), key=lambda kv: (kv[1][0], kv[1][1]))

    frontier = {}
    best_y = None  # lowest y among the points accepted so far
    for key, point in ordered:
        # The first (lowest-x) point is always on the frontier; afterwards a
        # point only qualifies if it strictly improves on the best y seen.
        if best_y is None or point[1] < best_y:
            frontier[key] = point
            best_y = point[1]

    return frontier


def grid_full_plot_color():
    """Plot grid-search losses against compute, coloured by optimisation steps.

    Reads ``exp_data/grid.json`` (a mapping from run name to loss) and draws a
    two-panel figure on log-log axes:

    * left panel: every run in the grid;
    * right panel: for each compute budget the lowest-loss run, further
      filtered by ``pareto_frontier``.

    Point colours encode the number of optimisation steps on a logarithmic
    scale, and the shared colorbar uses that same scale. The names of the
    runs shown in the right panel are printed to stdout.

    The figure is left as the current matplotlib figure; nothing is saved or
    returned.
    """
    with open("exp_data/grid.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    xs = []  # compute, 6 * N * D
    ys = []  # loss
    zs = []  # optimisation steps (used for colouring)

    # compute -> (loss, run name, steps) of the lowest-loss run at that compute
    best_by_compute = {}

    for name, loss in data.items():
        fields = name.split("-")
        # NOTE(review): the run name is assumed to carry the token count as
        # its 4th-from-last dash-separated field and a batch-size component as
        # its 2nd field -- confirm against the naming scheme.
        D = int(fields[-4])
        N = D / 20  # parameter count is derived as tokens / 20

        # NOTE(review): the factor 8 presumably scales a per-device batch to
        # the global batch, and 2048 is presumably the sequence length --
        # confirm.
        bs = int(fields[1]) * 8
        steps = D // (bs * 2048)

        compute = 6 * N * D
        xs.append(compute)
        ys.append(loss)
        zs.append(steps)

        # First run seen at a compute budget wins until a strictly lower
        # loss shows up at that same budget.
        if compute not in best_by_compute or loss < best_by_compute[compute][0]:
            best_by_compute[compute] = (loss, name, steps)

    cmap = plt.get_cmap("plasma")
    # A single log-normalised mappable drives both the scatter colours and
    # the colorbar, so the colorbar matches what is drawn.
    cNorm = colors.LogNorm(vmin=min(zs), vmax=max(zs))
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cmap)

    points_dict = {
        name: (compute, loss, steps)
        for compute, (loss, name, steps) in best_by_compute.items()
    }
    points_dict = pareto_frontier(points_dict)
    print(list(points_dict.keys()))

    fig, axes = plt.subplots(nrows=1, ncols=2, constrained_layout=True, figsize=(10, 4), sharex=True)

    for i, ax in enumerate(axes):
        ax.set_xlabel("Compute ($6ND$) [FLOPs]")
        ax.set_yscale("log")
        ax.set_xscale("log")

        # Show plain decimal tick labels on the log-scaled loss axis.
        ax.get_yaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
        ax.set_yticks([2.0, 3.0, 4.0, 5.0, 6.0])

        ax.grid(which="major", ls="-")

        if i == 0:
            # Left panel: every run in the grid.
            ax.scatter(xs, ys, zorder=9, alpha=0.7, color=scalarMap.to_rgba(zs), label="Grid search models", s=64)
            ax.set_ylabel("Loss: OpenLM eval")
        elif i == 1:
            # Right panel: only the frontier runs.
            xs_m = [point[0] for point in points_dict.values()]
            ys_m = [point[1] for point in points_dict.values()]
            zs_m = [point[2] for point in points_dict.values()]

            ax.scatter(xs_m, ys_m, zorder=9, alpha=0.7, color=scalarMap.to_rgba(zs_m), label="Grid search models", s=64)

    cbar = plt.colorbar(scalarMap, ax=axes, aspect=40)
    cbar.set_label("Number of optimization steps", labelpad=15)

    # Points are multi-coloured, so the legend uses a neutral grey proxy
    # marker instead of an auto-generated handle.
    handles = [
        Line2D(
            [0],
            [0],
            label="Grid search models",
            color="grey",
            marker="o",
            linestyle="",
        )
    ]
    # The legend goes on the last (right-hand) panel.
    axes[-1].legend(handles=handles)
