import json
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import os
import json
import numpy as np
import matplotlib as mpl
import matplotlib.colors as colors
from tqdm import tqdm
from scipy.optimize import curve_fit
import pandas as pd
import matplotlib.cm as cmx
from matplotlib import patches
from matplotlib import text as mtext
import math
from pathlib import Path
import scipy

from overtraining.plotting.constants import (
    VAL_FRIENDLIES,
    MODEL_SHAPES,
    MODEL_FRINDLIES,
    DATASET_FRIENDLIES,
    DATASET_COLORS,
    SUBSET,
    DOWNSTREAM_TO_LOSS,
    RANDOM_BASELINE,
)

# Process-wide matplotlib defaults applied when this module is imported:
# high-resolution figures and an 11pt base font.
mpl.rcParams.update({"figure.dpi": 300})
font = dict(size=11)
mpl.rc("font", **font)


"""different functional forms"""


def powlaw_irreducible(x, a, b, c):
    """Evaluate the offset power law ``a * x**b + c``.

    ``c`` is a constant added to every output (the "irreducible" term).
    """
    scaled = a * np.power(x, b)
    return scaled + c


def powlaw_approach2(x, alpha, beta, b, E):
    """Evaluate ``alpha * N**b + beta * 6**(b/2) * (M*N)**b + E``.

    Args:
        x: Pair ``(N, M)`` indexable as ``x[0]`` / ``x[1]``; each entry may be a
            scalar or an array (they are broadcast together).
        alpha: Coefficient of the ``N**b`` term.
        beta: Coefficient of the ``(M*N)**b`` term.
        b: Exponent shared by both terms.
        E: Constant offset added to the result.

    Returns:
        The model value(s), broadcast over ``N`` and ``M``.
    """
    N = x[0]
    M = x[1]
    return alpha * np.power(N, b) + beta * np.power(6, b / 2) * np.power(M * N, b) + E


def powlaw_approach2_no_E(x, alpha, beta, b):
    """Evaluate ``alpha * N**b + beta * 6**(b/2) * (M*N)**b`` (no offset).

    This is ``powlaw_approach2`` with ``E = 0``; it is used when ``E`` is held
    fixed and subtracted from the targets before fitting.

    Args:
        x: Pair ``(N, M)`` indexable as ``x[0]`` / ``x[1]``; scalars or arrays.
        alpha: Coefficient of the ``N**b`` term.
        beta: Coefficient of the ``(M*N)**b`` term.
        b: Exponent shared by both terms.

    Returns:
        The model value(s), broadcast over ``N`` and ``M``.
    """
    N = x[0]
    M = x[1]
    return alpha * np.power(N, b) + beta * np.power(6, b / 2) * np.power(M * N, b)


def powlaw(x, a, b):
    """Evaluate the pure power law ``a * x**b`` (no additive offset)."""
    return np.power(x, b) * a


def linlaw(x, a, b):
    """Evaluate the straight line with intercept ``a`` and slope ``b`` at ``x``."""
    return b * x + a


"""curve fitting functions"""


def curve_fit_powlaw_irreducible(xdata, ydata, fixed_E=None):
    """Fit ``y = a * x**b + c`` from several starting points and keep the best.

    Every combination of the initial guesses below is tried (a multi-start
    search, as done in Chinchilla). The parameters with the smallest sum of
    squared residuals against ``ydata`` are returned.

    Args:
        xdata: Independent variable, array-like.
        ydata: Observed values, array-like with the same length as ``xdata``.
        fixed_E: If given, the offset ``c`` is held at this value and only
            ``a`` and ``b`` are fitted (to ``ydata - fixed_E``).

    Returns:
        ``np.ndarray`` ``[a, b, c]`` of the best fit (``c == fixed_E`` when it
        is given), or ``None`` if no converged fit had a finite residual.

    Raises:
        RuntimeError: If ``curve_fit`` failed to converge for every start.
    """
    a_p0 = [3e1, 3e2, 3e3]
    b_p0 = [-1e-1]
    c_p0 = [0.0]

    ydata = np.asarray(ydata, dtype=float)

    best_rss = float("inf")
    best = None
    n_converged = 0
    last_error = None

    for a0 in a_p0:
        for b0 in b_p0:
            for c0 in c_p0:
                try:
                    if fixed_E is None:
                        popt, _ = curve_fit(
                            powlaw_irreducible,
                            xdata,
                            ydata,
                            p0=[a0, b0, c0],
                            maxfev=10000,
                        )
                    else:
                        popt_ab, _ = curve_fit(
                            powlaw,
                            xdata,
                            ydata - fixed_E,
                            p0=[a0, b0],
                            maxfev=10000,
                        )
                        popt = np.array([popt_ab[0], popt_ab[1], fixed_E])
                except RuntimeError as err:
                    # This start did not converge within maxfev; try the next one.
                    last_error = err
                    continue
                n_converged += 1

                # Score the full model, including any fixed offset, against the
                # original data. Every start is then compared on the same target.
                # (Scoring the offset-free fit against ydata biased every residual by fixed_E.)
                residuals = ydata - powlaw_irreducible(xdata, *popt)
                # Sum of squares has the same argmin as an RMSE with a constant
                # dof denominator, and cannot divide by zero.
                rss = float(np.sum(residuals**2))
                if rss < best_rss:
                    best_rss = rss
                    best = popt

    if n_converged == 0:
        raise last_error
    return best


def curve_fit_powlaw(xdata, ydata):
    """Fit ``y = a * x**b`` by a straight-line fit in log10-log10 space.

    Taking log10 of both sides gives ``log10(y) = log10(a) + b * log10(x)``.
    This is linear in the unknowns, so it is solved exactly by linear least
    squares rather than by an iterative optimizer. Residuals are measured in log
    space, i.e. data points are weighted according to a log scale. See
    https://stackoverflow.com/questions/41109122/fitting-a-curve-to-a-power-law-distribution-with-curve-fit-does-not-work

    Args:
        xdata: Independent variable; values must be strictly positive.
        ydata: Observations; values must be strictly positive.

    Returns:
        ``np.ndarray`` ``[a, b]``.
    """
    x_log = np.log10(np.asarray(xdata, dtype=float))
    y_log = np.log10(np.asarray(ydata, dtype=float))
    # polyfit(deg=1) returns coefficients highest power first: [slope, intercept].
    slope, intercept = np.polyfit(x_log, y_log, 1)
    return np.array([np.power(10, intercept), slope])


def curve_fit_powlaw_approach2(xdata, ydata, fixed_E=None):
    """Fit ``powlaw_approach2`` from many starting points and keep the best.

    All combinations of the initial guesses below are tried (a multi-start
    search, as done in Chinchilla). The parameters with the smallest sum of
    squared residuals against ``ydata`` are returned.

    Args:
        xdata: Pair/array ``(N, M)`` as expected by ``powlaw_approach2``.
        ydata: Observed values, one per column of ``xdata``.
        fixed_E: If given, ``E`` is held at this value and only
            ``(alpha, beta, b)`` are fitted (to ``ydata - fixed_E``).

    Returns:
        ``np.ndarray`` ``[alpha, beta, b, E]`` of the best fit, or ``None`` if
        no converged fit had a finite residual.

    Raises:
        RuntimeError: If ``curve_fit`` failed to converge for every start.
    """
    alpha_p0 = [1e2, 3e2, 1e3, 3e3]
    beta_p0 = [1e2, 3e2, 1e3, 3e3]
    b_p0 = [
        -1e-1,
        -3e-1,
    ]
    # The E start only matters when E is fitted. With a fixed E, every e0 would
    # rerun the identical fit, so that case is tried once.
    e_p0 = [0.0, 1.0, 2.0, 10.0] if fixed_E is None else [None]

    ydata = np.asarray(ydata, dtype=float)

    best_rss = float("inf")
    best = None
    n_converged = 0
    last_error = None

    for alpha0 in alpha_p0:
        for beta0 in beta_p0:
            for b0 in b_p0:
                for e0 in e_p0:
                    try:
                        if fixed_E is None:
                            popt, _ = curve_fit(
                                powlaw_approach2,
                                xdata,
                                ydata,
                                p0=[alpha0, beta0, b0, e0],
                                maxfev=10000,
                            )
                        else:
                            popt3, _ = curve_fit(
                                powlaw_approach2_no_E,
                                xdata,
                                ydata - fixed_E,
                                p0=[alpha0, beta0, b0],
                                maxfev=10000,
                            )
                            popt = np.array([popt3[0], popt3[1], popt3[2], fixed_E])
                    except RuntimeError as err:
                        # This start did not converge within maxfev; try the next one.
                        last_error = err
                        continue
                    n_converged += 1

                    # Score the full model (with E) against the original data.
                    residuals = ydata - powlaw_approach2(xdata, *popt)
                    rss = float(np.sum(residuals**2))
                    if rss < best_rss:
                        best_rss = rss
                        best = popt

    if n_converged == 0:
        raise last_error
    return best


"""data parsing"""


def _summarize_downstream_errors(eval_json):
    """Read one evaluation JSON and return per-task and averaged ICL error rates.

    Only tasks listed in ``DOWNSTREAM_TO_LOSS`` are used. For each task the
    error is ``1 - <reported ICL score>``, stored under ``err_<task>``. The
    "centered" error additionally subtracts ``1 - RANDOM_BASELINE[task]``.
    Averages are taken over all such tasks, and over those also in ``SUBSET``.

    Raises:
        ValueError: If the file does not cover every task in
            ``DOWNSTREAM_TO_LOSS`` / ``SUBSET``.
    """
    with open(eval_json, encoding="utf-8") as f:
        eval_data = json.load(f)

    metrics = {}
    err_sum = 0.0
    err_subset_sum = 0.0
    centered_sum = 0.0
    centered_subset_sum = 0.0
    n_tasks = 0
    n_subset = 0

    for task, score in eval_data["eval_metrics"]["icl"].items():
        if task not in DOWNSTREAM_TO_LOSS:
            continue
        err = 1.0 - score
        centered = err - (1 - RANDOM_BASELINE[task])
        if task in SUBSET:
            err_subset_sum += err
            centered_subset_sum += centered
            n_subset += 1
        err_sum += err
        centered_sum += centered
        n_tasks += 1
        metrics[f"err_{task}"] = err

    # Explicit checks rather than asserts, so they still run under `python -O`.
    if n_tasks != len(DOWNSTREAM_TO_LOSS):
        raise ValueError(
            f"{eval_json}: found {n_tasks} of {len(DOWNSTREAM_TO_LOSS)} downstream tasks"
        )
    if n_subset != len(SUBSET):
        raise ValueError(f"{eval_json}: found {n_subset} of {len(SUBSET)} subset tasks")

    metrics["err_avg"] = err_sum / n_tasks
    metrics["err_avg_subset"] = err_subset_sum / n_subset
    metrics["err_centered_avg"] = centered_sum / n_tasks
    metrics["err_centered_avg_subset"] = centered_subset_sum / n_subset
    return metrics


def parse_model_json(model_json, cc_mults, datasets, eval_dir=None):
    """Parse one model-run JSON into a flat dict (one future DataFrame row).

    Args:
        model_json: Path to the run's JSON file.
        cc_mults: Collection of allowed ``chinchilla_multiplier`` values.
        datasets: Collection of allowed ``dataset_name`` values.
        eval_dir: Optional directory with ``evaluation_<stem>_heavy.json``
            files. If given, downstream error metrics are merged into the row.

    Returns:
        The row dict. It contains run metadata, ``flops = 6 * N * D``, and
        ``loss_<suffix>`` / ``loss_upper_<suffix>`` / ``loss_lower_<suffix>`` per
        validation result. Returns ``None`` if the run's multiplier or dataset
        is filtered out.

    Raises:
        FileNotFoundError: If ``eval_dir`` is given but lacks this run's file.
        ValueError: If the evaluation file does not cover the expected tasks.
    """
    with open(model_json, encoding="utf-8") as f:
        data = json.load(f)

    hparams = data["hyperparameters"]
    payload = {
        "cc_mult": hparams["chinchilla_multiplier"],
        "dataset_name": data["dataset_name"],
    }

    if payload["cc_mult"] not in cc_mults:
        return None
    if payload["dataset_name"] not in datasets:
        return None

    payload["name"] = data["name"]
    payload["model_name"] = hparams["model"].split("/")[-1].split(".")[0]
    payload["N"] = hparams["params"]
    payload["D"] = hparams["tokens"]
    payload["flops"] = 6 * payload["N"] * payload["D"]
    payload["color"] = DATASET_COLORS[payload["dataset_name"]]
    payload["shape"] = MODEL_SHAPES[payload["model_name"]] if payload["model_name"] in MODEL_SHAPES else "o"
    payload["tok_mult"] = payload["cc_mult"] * 20

    for result in data["results"]:
        # Name the validation set after the parent directory of its first file.
        # For de-en data, use the part of the file name before the first dot instead.
        suffix = result["val_data"][0].split("/")[-2]
        if "de-en" in suffix:
            suffix = result["val_data"][0].split("/")[-1].split(".")[0]
        payload[f"loss_{suffix}"] = result["loss"]
        payload[f"loss_upper_{suffix}"] = result["loss_sequences_upper_95"]
        payload[f"loss_lower_{suffix}"] = result["loss_sequences_lower_95"]

    if eval_dir is not None:
        eval_json = os.path.join(eval_dir, f"evaluation_{Path(model_json).stem}_heavy.json")
        if not os.path.exists(eval_json):
            raise FileNotFoundError(f"missing evaluation file: {eval_json}")
        payload.update(_summarize_downstream_errors(eval_json))

    return payload


def parse_model_jsons(
    model_dir,
    datasets,
    cc_mults=(1.0,),
    eval_dir=None,
):
    """Parse every run JSON in ``model_dir`` into a DataFrame sorted by FLOPs.

    Args:
        model_dir: Directory whose entries are passed to ``parse_model_json``.
        datasets: Allowed ``dataset_name`` values.
        cc_mults: Allowed ``chinchilla_multiplier`` values. The default is an
            immutable tuple instead of a shared mutable list.
        eval_dir: Optional evaluation directory forwarded to ``parse_model_json``.

    Returns:
        ``pd.DataFrame`` with one row per kept run, sorted by ``flops``.

    Raises:
        ValueError: If no run in ``model_dir`` passes the filters.
    """
    payloads = []
    # os.listdir order is arbitrary; sort it so the resulting row order is reproducible.
    for fname in sorted(os.listdir(model_dir)):
        parsed = parse_model_json(os.path.join(model_dir, fname), cc_mults, datasets, eval_dir)
        if parsed is not None:
            payloads.append(parsed)

    if not payloads:
        raise ValueError(
            f"no runs in {model_dir!r} match datasets={datasets!r}, cc_mults={cc_mults!r}"
        )

    # A stable sort keeps the filename order among runs with equal FLOPs.
    return pd.DataFrame.from_dict(payloads).sort_values(by=["flops"], kind="mergesort")


def split_df_by_dataset(df):
    """Return one frame per distinct ``dataset_name``, in order of first appearance.

    Each frame is sorted by ``flops`` and re-indexed from 0.
    """

    def _rows_for(dataset):
        rows = df.loc[df["dataset_name"] == dataset]
        return rows.sort_values(by=["flops"]).reset_index(drop=True)

    return [_rows_for(dataset) for dataset in df["dataset_name"].unique().tolist()]


def split_df_by_mult(df, included_models):
    """Group rows of the given models by ``cc_mult``.

    Args:
        df: Frame with ``model_name``, ``cc_mult`` and ``flops`` columns.
        included_models: Model names to keep; all other rows are dropped.

    Returns:
        ``(frames, mults)``: the distinct multipliers in order of first
        appearance, and for each, its rows sorted by ``flops`` and re-indexed.
    """
    kept = df.loc[df["model_name"].isin(included_models)]
    mults = kept["cc_mult"].unique().tolist()
    frames = [
        kept.loc[kept["cc_mult"] == mult].sort_values(by=["flops"]).reset_index(drop=True)
        for mult in mults
    ]
    return frames, mults


def split_df_by_model(df, min_only_field=""):
    """Group rows by ``model_name``, optionally keeping only the best row per model.

    Args:
        df: Frame with ``model_name``, ``dataset_name`` and ``cc_mult`` columns.
        min_only_field: If non-empty, keep only the row with the smallest value
            in this column for each model.

    Returns:
        ``(frames, models)``: the distinct model names in order of first
        appearance, and for each, its rows sorted by ``(dataset_name, cc_mult)``
        and re-indexed from 0.
    """
    models = df["model_name"].unique().tolist()
    frames = []
    for model in models:
        rows = df.loc[df["model_name"] == model].sort_values(by=["dataset_name", "cc_mult"])
        if len(min_only_field) > 0:
            best_label = rows[min_only_field].idxmin()
            rows = rows.loc[[best_label]]
        frames.append(rows.reset_index(drop=True))
    return frames, models


def setup_ax_scaling(suffix):
    """Create a single log-log axes for plotting reducible loss against compute.

    Args:
        suffix: Key into ``VAL_FRIENDLIES`` naming the validation set on the y label.

    Returns:
        ``(fig, ax)`` from ``plt.subplots``.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1)

    val_name = VAL_FRIENDLIES[suffix]
    ax.set_ylabel(f"Reducible loss: {val_name} ($L(C) - E$)")
    ax.set_xlabel("Compute ($6ND, D=MN$) [FLOPs]")

    # y axis: log scale with ticks at 1..10 rendered as plain numbers.
    ax.set_yscale("log")
    ax.set_yticks([float(tick) for tick in range(1, 11)])
    ax.get_yaxis().set_major_formatter(mpl.ticker.ScalarFormatter())

    ax.set_xscale("log")
    ax.grid(which="major", ls="-")
    return fig, ax


def setup_legend(datasets, bbox_to_anchor=(1.11, 1.05), ax=None, loc=None):
    """Draw a legend with the plotted entries plus model-shape and dataset-color keys.

    The legend lists the axes' existing labeled artists first. It then adds one
    grey marker per ``MODEL_SHAPES`` entry, labeled via ``MODEL_FRINDLIES`` or
    "grid" if missing. Finally it adds one colored dot per dataset that is in
    both ``DATASET_COLORS`` and ``datasets``.

    Args:
        datasets: Dataset keys to include in the color key.
        bbox_to_anchor: Legend anchor, forwarded to ``Axes.legend``.
        ax: Axes to read handles from and draw the legend on. Defaults to the
            current axes. (Previously the legend was always drawn on the current
            axes, even when a different ``ax`` was given.)
        loc: Legend location; defaults to "upper right".
    """
    target = ax if ax is not None else plt.gca()
    handles, _ = target.get_legend_handles_labels()

    # Manual marker entries for each model architecture.
    for ms in MODEL_SHAPES:
        handles.append(
            Line2D(
                [0],
                [0],
                label=MODEL_FRINDLIES[ms] if ms in MODEL_FRINDLIES else "grid",
                color="grey",
                marker=MODEL_SHAPES[ms],
                linestyle="",
            )
        )

    # Manual color entries for the datasets actually being plotted.
    for ds in DATASET_COLORS:
        if ds in datasets:
            handles.append(
                Line2D(
                    [0],
                    [0],
                    label=DATASET_FRIENDLIES[ds],
                    markersize=8,
                    color=DATASET_COLORS[ds],
                    marker="o",
                    linestyle="",
                )
            )

    target.legend(
        handles=handles,
        loc=loc if loc is not None else "upper right",
        bbox_to_anchor=bbox_to_anchor,
        ncol=1,
    )


def make_latex_table(cols: "list[str]", rows: "list[list[str]]", title: str, caption: str, label: str):
    """Render a booktabs-style LaTeX ``table`` environment as a string.

    The first column is left-aligned and the remaining ``len(cols) - 1`` columns
    are centered. A ``?`` separates them in the column spec.
    NOTE(review): ``?`` is not a built-in LaTeX column type; presumably the
    including document defines it (e.g. via newcolumntype) — confirm.

    Args:
        cols: Header cells, one per column.
        rows: Table rows, each a sequence of already-formatted cell strings.
        title: Bold lead-in of the caption.
        caption: Remaining caption text.
        label: Key for the table's label command.

    Returns:
        The LaTeX source, including the surrounding newlines.
    """
    header = f"""
\\begin{{table}}
\\centering
\\begin{{tabular}}{{l?{'c'*(len(cols)-1)}}}
\\toprule
"""

    # Header row followed by a midrule, then one line per data row.
    lines = [" & ".join(cols) + "\\\\\\midrule\n"]
    lines.extend(" & ".join(row) + "\\\\\n" for row in rows)
    body = "".join(lines)

    footer = f"""
\\bottomrule
\\end{{tabular}}
\\caption{{\\textbf{{{title}}} {caption}}}
\\label{{{label}}}
\\end{{table}}
"""

    return header + body + footer


def fit_ds(
    train_dataset,
    downstream,
    add_1b,
    model_dir="exp_data/models_tok",
    eval_dir="exp_data/evals_tok",
    de_model_dir="exp_data/models_tok_de_en",
    val_dataset="c4_val",
    Ms=(0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0),
):
    """Fit a loss scaling law and a loss -> downstream-metric curve for one dataset.

    Steps:
      1. Load runs trained on ``train_dataset`` at multipliers ``Ms``. Runs come
         from ``de_model_dir`` when ``val_dataset`` contains "de-en", otherwise
         from ``model_dir``.
      2. Take the four ``fit_models`` at multiplier 1.0 plus ``d=96_l=8_h=4`` at
         multiplier 16.0 (5 points). Fit ``powlaw_approach2`` to their
         ``loss_<val_dataset>`` against ``(N, tok_mult)``.
      3. If ``add_1b``, append the ``open_lm_1b`` run at multiplier 1.0. It is
         used only by the next step, not by the scaling-law fit above.
      4. Fit ``downstream = E - a * exp(-b * loss)`` over the collected points.

    Args:
        train_dataset: Training dataset name to select runs by.
        downstream: Column holding the downstream metric (e.g. an ``err_*`` column).
        add_1b: Whether to add the 1B run to the downstream fit.
        model_dir, eval_dir, de_model_dir: Data directories.
        val_dataset: Validation-set suffix of the loss column to use.
        Ms: Token multipliers to load. The default is an immutable tuple
            instead of a shared list.

    Returns:
        ``(popt_approach2, popt_ds, ys_irr)``:
          - ``popt_approach2``: ``[alpha, beta, b, E]`` from the scaling-law fit.
          - ``popt_ds``: ``[a, b, E]`` of the exponential downstream fit.
          - ``ys_irr``: the loss values used as that fit's inputs.
    """
    df = parse_model_jsons(
        model_dir if "de-en" not in val_dataset else de_model_dir,
        cc_mults=Ms,
        datasets=[train_dataset],
        eval_dir=eval_dir,
    )
    fit_models = ["d=96_l=8_h=4", "d=512_l=8_h=4", "d=576_l=24_h=8", "d=1024_l=24_h=8"]
    df_mults, names = split_df_by_mult(df, fit_models)
    df_mults_dict = dict(zip(names, df_mults))

    loss_col = f"loss_{val_dataset}"
    xs_irr, ys_irr, ys2_irr, ms_irr, ns_irr = [], [], [], [], []

    def _extend(frame):
        # Append one point per row of `frame` to every collected series.
        xs_irr.extend(frame["flops"].tolist())
        ys_irr.extend(frame[loss_col].tolist())
        ys2_irr.extend(frame[downstream].tolist())
        ms_irr.extend(frame["tok_mult"].tolist())
        ns_irr.extend(frame["N"].tolist())

    _extend(df_mults_dict[1.0])
    mult16 = df_mults_dict[16.0]
    _extend(mult16[mult16["model_name"] == "d=96_l=8_h=4"])

    assert len(xs_irr) == 5

    popt_approach2 = curve_fit_powlaw_approach2(np.array([ns_irr, ms_irr]), np.array(ys_irr).astype(float))

    if add_1b:
        # NOTE(review): the 1B run is always read from `model_dir`, even when
        # `val_dataset` is a de-en set (step 1 reads `de_model_dir`) — confirm intended.
        df_double = parse_model_jsons(model_dir, cc_mults=[1.0], datasets=[train_dataset], eval_dir=eval_dir)
        df_mults_double, _ = split_df_by_mult(df_double, ["open_lm_1b"])
        for df_mult in df_mults_double:
            run_1b = df_mult[df_mult["model_name"] == "open_lm_1b"]
            assert len(run_1b) == 1
            _extend(run_1b)

        assert len(xs_irr) == 6

    popt_ds, _ = scipy.optimize.curve_fit(
        lambda t, a, b, E: E - a * np.exp(-b * t), ys_irr, ys2_irr, maxfev=10000, p0=[1.0, 1.0, 0.75]
    )

    return popt_approach2, popt_ds, ys_irr
