import os
from builtins import FileNotFoundError
import numpy as np
import pandas as pd
import pickle
import pathlib
import mlflow
import warnings
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.lines as mlines
from matplotlib.legend_handler import HandlerTuple

from hyperparameter_tuning.utils.gpytorch.models.variational_gpr import NUM_INDUCING_INPUTS
from utils.visualization.save_relative_error import get_log_relative_errors_save
from utils.visualization.visualization_constants import acgp_color, cglb_color, exact_color

# this import is necessary to set the tracking uri correct
from utils.result_management.result_management import get_steps_and_values_from_run
from acgp.hooks.stop_hook import StopHook
from utils.result_management.constants import (
    SN2,
    EXACT_SOLUTIONS,
    DIAGONAL,
    EQSOL,
    TEMP_VALUES,
    TEMP,
    OFFDIAGONAL,
    ALGORITHM,
    KERNEL,
    DATASET,
    U_DET,
    L_DET,
    U_QUAD,
    L_QUAD,
    STEP_TIME,
    SEED,
    SETUP_TIME,
)


# Saving plots:
save = True
fig_path = pathlib.Path("output/figures/bounds")
fig_path.mkdir(parents=True, exist_ok=True)

# When True, bounds are converted with get_log_relative_errors_save (relative to
# the exact ACGP solution) before plotting instead of showing raw values.
plot_relative_error = False

methods = ["acgp", "cglb", "exact"]
terms = ["llh", "quad", "log_det"]

bounds_path = "./output/results/mlruns"
mlflow.set_tracking_uri(bounds_path)
mlfc = mlflow.tracking.MlflowClient()
experiment_list = [
    exp for exp in mlfc.list_experiments() if exp.experiment_id != "0"
]  # 0 is the default mlflow experiment--it's empty
# Alternative experiment selections:
# experiment_list = [mlflow.get_experiment(e_id) for e_id in ["6", "8", "13"]]   # ids for main paper
# experiment_list = [mlflow.get_experiment(e_id) for e_id in ["13"]]   # ids for main paper
# experiment_list = [mlflow.get_experiment(e_id) for e_id in ["2", "3", "4", "5"]]
# experiment_list = [mlflow.get_experiment(e_id) for e_id in ["8"]]   # ids for main paper
# experiment_list = [mlflow.get_experiment(e_id) for e_id in ["31"]]

seeds = [0, 1, 2, 3, 4]
# seeds = [0]

plt.rc("text", usetex=True)
plt.rc(
    "text.latex",
    # Must be a single string: recent Matplotlib versions validate this rcParam
    # as ``str`` and no longer accept a list of lines.
    preamble=(
        r"\usepackage{amssymb} \usepackage{amsmath} \usepackage{marvosym} "
        r"\usepackage{bm}"
    ),
)
plt.rc("font", family="serif")
plt.rcParams.update({"font.size": 18})


cglb_bound_label = r"CGLB upper+lower bound"
our_bound_label = r"ACGP upper+lower bound"

show_acgp_for_full_dataset = False

# Dataset size of the experiment currently being processed. It is assigned by
# process_stopped_chol_run and read by process_cglb_run and the plotting loop,
# so ACGP runs have to be processed before CGLB runs.
N = None


def process_cglb_runs(runs):
    """Collect the bounds of every CGLB run in ``runs``.

    Runs whose ``ALGORITHM`` tag does not start with ``"cglb"`` are skipped.
    Each remaining run is processed with ``process_cglb_run``. Results are keyed
    by ``(idx, time)``, so a later run with the same key replaces an earlier one.

    Returns a 7-tuple of lists, aligned by position:
    ``(idx, inducing_points, times, log_det_upper_bounds, log_det_lower_bounds,
    quad_upper_bounds, quad_lower_bounds)``.
    """
    per_run = {}
    for run in runs:
        if not run.data.tags[ALGORITHM].startswith("cglb"):
            continue
        (
            position,
            n_inducing,
            elapsed,
            ldet_up,
            ldet_lo,
            quad_up,
            quad_lo,
        ) = process_cglb_run(run)
        per_run[(position, elapsed)] = (n_inducing, ldet_up, ldet_lo, quad_up, quad_lo)

    # Split the stored records column-wise (insertion order is preserved).
    records = list(per_run.values())
    columns = [[record[col] for record in records] for col in range(5)]
    positions = [key[0] for key in per_run]
    elapsed_times = [key[1] for key in per_run]

    return (
        positions,
        columns[0],
        elapsed_times,
        columns[1],
        columns[2],
        columns[3],
        columns[4],
    )


def process_cglb_run(run):
    """Load the CGLB data of a single MLflow run.

    Reads the final log-determinant and quadratic-term bounds from the run's
    metrics. The run time is ``STEP_TIME + SETUP_TIME``. The number of inducing
    inputs comes from the ``ALGORITHM.NUM_INDUCING_INPUTS`` tag.

    The x-position ``idx = (m**2 * N) ** (1/3)`` uses the module-level ``N``,
    which has to be set (by processing an ACGP run) before this is called.

    Returns ``(idx, inducing_points, times, log_det_upper, log_det_lower,
    quad_upper, quad_lower)``.
    """
    metrics = run.data.metrics
    ldet_up = metrics[U_DET]
    ldet_lo = metrics[L_DET]
    quad_up = metrics[U_QUAD]
    quad_lo = metrics[L_QUAD]
    elapsed = metrics[STEP_TIME] + metrics[SETUP_TIME]
    n_inducing = int(run.data.tags[ALGORITHM + "." + NUM_INDUCING_INPUTS])

    # Debug output kept on purpose.
    print(n_inducing)
    print(ldet_up)

    # Only read here; ``N`` is the global dataset size.
    position = np.power(np.square(n_inducing) * N, 1.0 / 3.0)
    return position, n_inducing, elapsed, ldet_up, ldet_lo, quad_up, quad_lo


def process_stopped_chol_runs(runs, preconditioner=0):
    """Process the first approximate (ACGP) MetaCholesky run in ``runs``.

    A run qualifies if its ``ALGORITHM`` tag starts with ``"MetaCholesky"``, its
    block size is at most 20000 (larger block sizes mark the exact runs), and
    its ``preconditioner_steps`` tag equals ``preconditioner``.

    Returns the result of ``process_stopped_chol_run`` for that run, or ``None``
    if no run qualifies.
    """

    def _is_acgp_run(run):
        tags = run.data.tags
        return (
            tags[ALGORITHM].startswith("MetaCholesky")
            and int(tags[ALGORITHM + ".block_size"]) <= 20000
            and int(tags["preconditioner_steps"]) == preconditioner
        )

    chosen = next((run for run in runs if _is_acgp_run(run)), None)
    return None if chosen is None else process_stopped_chol_run(chosen)


def process_exact_runs(runs):
    """Return the total wall-clock time of the exact Cholesky run, or ``None``.

    The exact run is the first MetaCholesky run with block size >= 20000. Its
    logged ``STEP_TIME`` history is expected to hold exactly two entries: the
    kernel-matrix setup and the MetaCholesky itself. Those two are summed. The
    run's single ``STEP_TIME`` metric alone would ignore the kernel build time.
    """
    for run in runs:
        tags = run.data.tags
        is_exact = (
            tags[ALGORITHM].startswith("MetaCholesky")
            and int(tags[ALGORITHM + ".block_size"]) >= 20000
        )
        if not is_exact:
            continue
        _, step_times = get_steps_and_values_from_run(run.info.run_id, STEP_TIME)
        assert len(step_times) == 2
        return np.sum(step_times)
    return None


def _load_pickle(file_name):
    """Unpickle ``file_name`` and close the file handle again.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use it on
    result files written by this project's own experiments.
    """
    with open(file_name, "rb") as handle:
        return pickle.load(handle)


def process_stopped_chol_run(run, min_noise=None):
    """Replay a stored ACGP (stopped MetaCholesky) run through a ``StopHook``.

    Loads two pickles referenced by the run's tags:

    - the exact solution: Cholesky diagonal ``DIAGONAL`` and solution ``EQSOL``;
    - the intermediate ("temp") values.

    It then feeds them block by block into a fresh ``StopHook`` as if the
    Cholesky were being computed, and returns the hook's bounds.

    Parameters
    ----------
    run : mlflow run of the stopped MetaCholesky algorithm.
    min_noise : optional float passed to ``StopHook``. If ``None``, it is read
        from the ``SN2`` tag of the run's own experiment.

    Side effect: sets the module-level ``N`` to the dataset size.

    Returns
    -------
    (idx, processed_data, times, log_det_upper_bounds, log_det_lower_bounds,
     quad_upper_bounds, quad_lower_bounds)
        ``idx`` holds multiples of the block size (last entry clamped to ``N``).
        ``processed_data`` is the end index of every processed block.
        ``times`` holds the cumulative step times.
        The very first bound estimate is dropped.
    """
    # TODO: HACK! The tag stores a path; only the part after "experiments/" is used.
    d = _load_pickle(run.data.tags[EXACT_SOLUTIONS].split("experiments/")[1])
    diagL = np.squeeze(d[DIAGONAL])
    assert len(diagL.shape) == 1
    alpha = np.squeeze(d[EQSOL])
    assert diagL.shape == alpha.shape
    # log-determinant from the Cholesky diagonal, quadratic term from the solution
    log_det = 2 * np.sum(np.log(diagL))
    quad = np.sum(np.square(alpha))

    try:
        # TODO: HACK!
        d = _load_pickle(run.data.tags[TEMP_VALUES].split("experiments/")[1])
    except FileNotFoundError:
        alternative_file_name = run.data.tags[TEMP_VALUES].replace("results", "results_bounds")
        d = _load_pickle(alternative_file_name)
        warnings.warn("Found a somewhat older result referencing to data in the wrong place.")

    temp_diagK = np.squeeze(d[TEMP + DIAGONAL])
    temp_offdiagK = np.squeeze(d[TEMP + OFFDIAGONAL])
    temp_alpha = np.squeeze(d[TEMP + EQSOL])

    global N
    N = diagL.shape[0]
    block_size = int(run.data.tags[ALGORITHM + "." + "block_size"])
    if min_noise is None:
        # Use the run's own experiment rather than the plotting loop's ``exp``,
        # so this function works regardless of the caller's globals.
        experiment = mlfc.get_experiment(run.info.experiment_id)
        min_noise = float(experiment.tags[SN2])
    hook = StopHook(N=N, min_noise=min_noise)
    # hook.prepare()  # not necessary

    K_ = np.zeros([block_size, block_size])
    processed_data = []
    for idi in range(0, N, block_size):
        advance = min(block_size, N - idi)
        stop = idi + advance
        processed_data.append(stop)

        if advance < block_size:
            # the last block can be smaller than block_size
            K_ = np.zeros([advance, advance])
        # state before the Cholesky step: diagonal + first sub-diagonal of the block
        K_ += np.diag(temp_diagK[idi:stop])
        K_ += np.diag(temp_offdiagK[idi : stop - 1], -1)
        y_ = temp_alpha[idi:stop].reshape(-1, 1)
        hook.pre_chol(idi, K_, y_)

        # now pretend to do the down-date
        K_ *= 0.0
        K_ += np.diag(diagL[idi:stop])
        y_ = alpha[idi:stop].reshape(-1, 1)
        hook.post_chol(idi, K_, y_)
        K_ *= 0.0  # clear matrix for the next block

    hook.ldet = log_det
    hook.quad = quad
    hook.finalize()
    _, bounds = hook.get_bounds()

    # The first bound estimates are computed before even computing the first
    # Cholesky -- they are so bad they screw up the plot.
    start_at = 1
    bounds = bounds[start_at:]
    log_det_upper_bounds = [x for x, _, _, _ in bounds]
    log_det_lower_bounds = [x for _, x, _, _ in bounds]
    quad_upper_bounds = [x for _, _, x, _ in bounds]
    quad_lower_bounds = [x for _, _, _, x in bounds]
    # no setup time---that's just allocating N^2 memory which is at most a second
    _, times = get_steps_and_values_from_run(run.info.run_id, STEP_TIME)
    times = np.cumsum(times)[start_at:]

    steps = len(log_det_upper_bounds)
    idx = np.arange(start_at, steps + 1) * block_size
    idx[-1] = N

    return (
        idx,
        processed_data,
        times,
        log_det_upper_bounds,
        log_det_lower_bounds,
        quad_upper_bounds,
        quad_lower_bounds,
    )


def load_data(exp, method, cache_path, use_cache=False):
    """Load or recompute the bound traces of one method for one seed.

    Parameters
    ----------
    exp : mlflow experiment whose runs are searched (only used on recomputation).
    method : ``"acgp"`` selects the stopped-Cholesky runs; anything else selects
        the CGLB runs.
    cache_path : directory whose last path component is the seed. Cached arrays
        live in ``cache_path / method``.
    use_cache : if True, reuse ``*.npy`` files from a previous call when all of
        them exist. Otherwise recompute and write them. If False (default),
        always recompute and write nothing.

    Returns
    -------
    (idx, num_points, times, exact_time, log_det_upper_bounds,
     log_det_lower_bounds, quad_upper_bounds, quad_lower_bounds)
    """
    path = cache_path / method
    path.mkdir(parents=True, exist_ok=True)

    names = (
        "idx",
        "num_points",
        "times",
        "exact_time",
        "log_det_upper_bounds",
        "log_det_lower_bounds",
        "quad_upper_bounds",
        "quad_lower_bounds",
    )

    if use_cache:
        try:
            # allow_pickle: exact_time can be None, which np.save stores as an
            # object array. The files are written by this function only.
            return tuple(np.load(path / f"{name}.npy", allow_pickle=True) for name in names)
        except FileNotFoundError:
            pass  # cache miss -> recompute below and populate the cache

    result = _recompute_bounds(exp, method, path)
    if use_cache:
        for name, value in zip(names, result):
            np.save(path / f"{name}.npy", value)
    return result


def _recompute_bounds(exp, method, path):
    """Query MLflow for the seed encoded in ``path`` and rebuild the bound traces."""
    print("Recomputing!")
    process_func = process_stopped_chol_runs if method == "acgp" else process_cglb_runs
    seed = path.parts[-2]  # path == cache_path / method, cache_path ends in the seed
    runs = mlfc.search_runs(
        [exp.experiment_id], filter_string=f"tags.{SEED} = '{seed}'"
    )

    processed = process_func(runs)
    if processed is None:
        raise ValueError(
            f"No matching '{method}' run for seed {seed} in experiment {exp.experiment_id}"
        )
    (
        idx,
        num_points,
        times,
        log_det_upper_bounds,
        log_det_lower_bounds,
        quad_upper_bounds,
        quad_lower_bounds,
    ) = processed

    # log_det_lower_bounds is a mix of arrays and floats. Converting all to float:
    log_det_lower_bounds = [float(v) for v in log_det_lower_bounds]

    exact_time = process_exact_runs(runs)
    return (
        idx,
        num_points,
        times,
        exact_time,
        log_det_upper_bounds,
        log_det_lower_bounds,
        quad_upper_bounds,
        quad_lower_bounds,
    )


# Only create the figure once; each finished plot is cleared with ``plt.cla()``.
fig, ax = plt.subplots(figsize=(9, 5), constrained_layout=True)
for exp in experiment_list:
    try:
        dataset_name = exp.tags[DATASET]
        kernel_name = exp.tags[KERNEL]
        log_ls2 = exp.tags[KERNEL + ".log_ls2"]
    except KeyError:
        # Missing tags: seems to be a crashed experiment.
        warnings.warn(f"Crashed experiment? : {exp.name}")
        continue
    # Noise variance of this experiment. Kept as a module-level name because the
    # run-processing helpers of this script may read it as a global.
    sn2 = float(exp.tags[SN2])

    # One DataFrame per (seed, method) trace plus one "exact" row per seed.
    dfs = []
    for s in seeds:
        cache_path = (
            pathlib.Path(os.getcwd()) / "cache" / dataset_name / kernel_name / log_ls2 / str(s)
        )
        cache_path.mkdir(parents=True, exist_ok=True)

        exact_log_det = 0
        exact_quad = 0
        # ACGP first: it sets the global N that the CGLB processing reads.
        for method in ["acgp", "cglb"]:
            (
                idx,
                processed_data,
                times,
                exact_time,
                log_det_upper_bounds,
                log_det_lower_bounds,
                quad_upper_bounds,
                quad_lower_bounds,
            ) = load_data(exp, method, cache_path)

            if method == "acgp":
                # ACGP delivers the exact solution in the last step.
                exact_log_det = log_det_upper_bounds[-1]
                exact_quad = quad_upper_bounds[-1]
                exact_row = {
                    "seed": s,
                    "points": processed_data[-1],
                    "times": exact_time,
                    "log_det_upper": exact_log_det if not plot_relative_error else 0,
                    "quad_upper": exact_quad if not plot_relative_error else 0,
                    "llh_upper": (
                        -exact_quad / 2 - exact_log_det / 2 - N * np.log(2 * np.pi) / 2
                        if not plot_relative_error
                        else 0
                    ),
                    "method": "exact",
                }
                full_dataset_size = processed_data[-1]
                dfs.append(pd.DataFrame(exact_row, index=[0]))

            if plot_relative_error:
                log_det_upper_bounds = get_log_relative_errors_save(
                    log_det_upper_bounds, exact_log_det
                )
                log_det_lower_bounds = get_log_relative_errors_save(
                    log_det_lower_bounds, exact_log_det
                )
                quad_upper_bounds = get_log_relative_errors_save(
                    quad_upper_bounds, exact_quad
                )
                quad_lower_bounds = get_log_relative_errors_save(
                    quad_lower_bounds, exact_quad
                )

            trace = {
                "seed": [s] * len(times),
                "times": times,
                "points": processed_data,
                "log_det_upper": log_det_upper_bounds,
                "log_det_lower": log_det_lower_bounds,
                "quad_upper": quad_upper_bounds,
                "quad_lower": quad_lower_bounds,
                # upper bounds on log-det and quad give a lower bound on the llh and vice versa
                "llh_lower": (-np.array(log_det_upper_bounds) / 2 - np.array(quad_upper_bounds) / 2 - N * np.log(2 * np.pi) / 2).flatten().tolist(),
                "llh_upper": (-np.array(log_det_lower_bounds) / 2 - np.array(quad_lower_bounds) / 2 - N * np.log(2 * np.pi) / 2).flatten().tolist(),
                "method": [method] * len(times),
            }
            dfs.append(pd.DataFrame(trace))

    results = pd.concat(dfs, ignore_index=True)
    results = results.astype({"points": "Int64"})

    for term in terms:
        # Handles picked for the legend: for ACGP and CGLB the upper/lower
        # scatters of the third point group (p == 2), plus all exact scatters.
        scatter_handles = []

        for method in methods:
            subset = results[results["method"] == method]
            n_groups = len(subset.points.unique())
            scatter_size = 80

            if method == "acgp":
                colors = acgp_color(n_groups)
                acgp_legend_handle = ax.scatter(
                    [],
                    [],
                    s=scatter_size,
                    c=[colors[-2]],
                    label="ACGP bounds",
                )
            elif method == "cglb":
                colors = cglb_color(n_groups)
                if n_groups > 0:
                    cglb_legend_handle = ax.scatter(
                        [],
                        [],
                        s=scatter_size,
                        c=[colors[-2]],
                        label="CGLB bounds",
                    )
                else:
                    cglb_legend_handle = None
            elif method == "exact":
                colors = exact_color(steps=2)[1:]
                ax.axhline(subset[term + "_upper"].iloc[0], c=colors[0])
                exact_legend_handle = mlines.Line2D(
                    [],
                    [],
                    color=colors[0],
                    marker=".",
                    markersize=18,
                    label="Exact GPR",
                )

            point_groups = subset.groupby("points")
            for p, (points, row_index) in enumerate(point_groups.groups.items()):
                if method == "acgp" and points == full_dataset_size and not show_acgp_for_full_dataset:
                    continue
                group = subset.loc[row_index]

                for bound in ["upper", "lower"]:
                    if method == "exact":
                        marker = "o"
                    else:
                        marker = "^" if bound == "lower" else "v"
                    scatter_handle = ax.scatter(
                        group["times"],
                        group[term + "_" + bound],
                        marker=marker,
                        s=scatter_size,
                        color=colors[p],
                        edgecolors="white",
                    )

                    if p == 2 or method == "exact":
                        scatter_handles.append(scatter_handle)

        # Annotate every point group with its number of points / inducing inputs.
        groups = results.groupby(["method", "points"])
        means = groups.mean()
        mins = groups.min()
        maxs = groups.max()

        annotations = []
        font_size = 14

        def annotation_text(text):
            return f"\\textsf{{\\bfseries {text}}}"

        if "cglb" in methods:
            cglb_means = means.loc["cglb"]
            for p, points in enumerate(cglb_means.index):
                m = cglb_means.loc[points]
                y = mins.loc["cglb", points][term + "_upper"]
                text = ax.annotate(
                    text=annotation_text(int(points)),
                    xy=(m.times, y),
                    xytext=(0, -18),
                    textcoords="offset points",
                    color=cglb_color(steps=len(cglb_means))[p],
                    fontsize=font_size,
                    horizontalalignment="center",
                    verticalalignment="baseline",
                )
                annotations.append(text)

        if "acgp" in methods:
            acgp_means = means.loc["acgp"]
            if not show_acgp_for_full_dataset:
                acgp_means = acgp_means.drop(index=full_dataset_size)
            for p, points in enumerate(acgp_means.index):
                m = acgp_means.loc[points]
                y = maxs.loc["acgp", points][term + "_upper"]
                text = ax.annotate(
                    text=annotation_text(int(points)),
                    xy=(m.times, y),
                    xytext=(0, 8),
                    textcoords="offset points",
                    color=acgp_color(steps=len(acgp_means))[p],
                    fontsize=font_size,
                    horizontalalignment="center",
                    verticalalignment="baseline",
                )
                annotations.append(text)

        exact_means = means.loc["exact"]
        for p, points in enumerate(exact_means.index):
            m = exact_means.loc[points]
            y = mins.loc["exact", points][term + "_upper"]
            text = ax.annotate(
                text=annotation_text(int(points)),
                xy=(m.times, y),
                xytext=(0, -15),
                textcoords="offset points",
                color=exact_color(steps=len(exact_means) + 1)[-1],
                fontfamily="sans-serif",
                fontsize=font_size,
                horizontalalignment="center",
                verticalalignment="baseline",
            )
            annotations.append(text)

        for text in annotations:
            # Set white outline:
            text.set_path_effects(
                [
                    path_effects.Stroke(linewidth=3, foreground="white", alpha=1),
                    path_effects.Normal(),
                ]
            )

        if "cglb" in methods:
            ax.legend(
                [
                    scatter_handles[4],
                    (scatter_handles[0], scatter_handles[1]),
                    (scatter_handles[2], scatter_handles[3]),
                ],
                ["Exact", "ACGP", "CGLB"],
                numpoints=1,
                handler_map={tuple: HandlerTuple(ndivide=None, pad=-.9)},
            )

        latex_dataset_name = dataset_name.replace("wilson_", "")
        relative_error_label = r"$\mathrm{sign}(r)\mathrm{log}_2 (|r|+1)$"
        if term == "log_det":
            if plot_relative_error:
                ylabel = relative_error_label
            else:
                ylabel = r"$\mathrm{log}|\mathbf{K}|$"
            title = rf"Bounds on the log-determinant term -- \texttt{{{latex_dataset_name}}}"
        elif term == "quad":
            if plot_relative_error:
                # Signed relative errors can be zero or negative: keep a linear axis.
                ylabel = relative_error_label
            else:
                ylabel = r"$\mathbf{y}^\top \mathbf{K}^{-1}\mathbf{y}$"
                ax.set_yscale("log")
            title = rf"Bounds on the quadratic term -- \texttt{{{latex_dataset_name}}}"
        elif term == "llh":
            if plot_relative_error:
                raise NotImplementedError()
            ylabel = r"$\log p(\mathbf{y})$"
            title = rf"Bounds on the marginal log likelihood -- \texttt{{{latex_dataset_name}}}"
            # A log scale would distort the distances of the bounds to the exact value.
        else:
            raise NotImplementedError("Unknown term")
        ax.set_xlabel("time in seconds")
        ax.set_ylabel(ylabel)
        ax.set_title(title)

        if save:
            file_postfix = dataset_name + kernel_name + log_ls2
            fig.savefig(fname=fig_path / f"experiment_4_{term}_{file_postfix}.pdf")
        else:
            plt.show()

        # Reset the shared axes (artists, legend, y-scale) for the next plot.
        plt.cla()
