"""
Explore the relationship between the sortability scores v_{R2}, v_{Var}, and v_{CEV}
(R^2-, variance-, and CEV-sortability) by plotting heatmaps of their ratios.
"""

from auxiliary.vsb_emergence import vsb_investigation, approx_b
from CDExperimentSuite_DEV import Scalers, utils
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from copy import deepcopy
import numpy as np

# Global matplotlib styling: render all text with LaTeX (loading the bm and
# AMS packages used by the axis/colorbar labels) and enlarge labels/ticks.
plt.rcParams.update(
    {
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{bm}\usepackage{amssymb}\usepackage{amsmath}",
        "axes.labelsize": 28,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
    }
)


def vsb_r2_heatmap(df):
    """Plot a heatmap of the mean ratio between two sortability scores.

    Used as the ``plot_fun`` callback of ``vsb_investigation`` on the data it
    generates. The input DataFrame is not modified; all preprocessing happens
    on a new frame.

    The plotted ratio comes from the first pair of sortability columns present
    in ``df``, checked in this order: ``R2 / var``, ``var / CEV``,
    ``R2 / CEV``. Ratios are averaged over all rows that share the same ``x``
    value and the same ``Eln|w|`` value (rounded to one decimal). They are then
    laid out with ``x`` along the horizontal axis and ``Eln|w|`` increasing
    upwards.

    Args:
        df: DataFrame with at least the columns ``"d"``, ``"iter"``, ``"x"``,
            ``"Eln|w|"`` and two of ``"R2"``, ``"var"``, ``"CEV"``.

    Raises:
        KeyError: if ``"d"`` or ``"iter"`` is missing from ``df``.
        NotImplementedError: if no supported pair of sortability columns is
            present.
    """
    # (numerator, denominator, colorbar label); the first pair whose columns
    # are both present decides which ratio is plotted.
    ratio_specs = (
        ("R2", "var", r"$\dfrac{\mathbf{v}_{R^2}}{\mathbf{v}_{Var}}$"),
        ("var", "CEV", r"$\dfrac{\mathbf{v}_{Var}}{\mathbf{v}_{CEV}}$"),
        ("R2", "CEV", r"$\dfrac{\mathbf{v}_{R^2}}{\mathbf{v}_{CEV}}$"),
    )

    # ``drop`` without ``inplace`` returns a new frame, so nothing below
    # touches the caller's DataFrame.
    df = df.drop(columns=["d", "iter"])
    df["Eln|w|"] = np.round(df["Eln|w|"], 1)

    for numerator, denominator, ratio_name in ratio_specs:
        if numerator in df.columns and denominator in df.columns:
            break
    else:
        raise NotImplementedError(
            "Expected two of the columns 'R2', 'var', 'CEV'; got "
            f"{list(df.columns)}"
        )
    df["ratio"] = df[numerator] / df[denominator]

    # Average only the ratio: any other leftover columns are not plotted and
    # could make a frame-wide mean fail (e.g. non-numeric metadata).
    mean_ratio = df.groupby(by=["x", "Eln|w|"])["ratio"].mean().reset_index()
    table = mean_ratio.pivot(index="Eln|w|", columns="x", values="ratio")
    # pivot sorts Eln|w| ascending; flip rows so the largest value is on top.
    img = np.flip(table.to_numpy(), axis=0)

    ## plot
    _, ax = plt.subplots(figsize=(9, 6))
    p = ax.imshow(X=img, cmap="RdGy_r")

    ax.set_xticks(
        ticks=list(range(img.shape[1])), labels=np.round(np.array(table.columns), 2)
    )
    # Labels reversed to match the row flip applied to ``img`` above.
    ax.set_yticks(
        ticks=list(range(img.shape[0])), labels=np.round(list(reversed(table.index)), 2)
    )
    ax.set_xlabel(r"Average node in-degree $\gamma$")  # Graph density
    ax.set_ylabel(r"$\mathrm{\mathbb{E}}[\ln|V|]$")  # Geometric weight mean

    ## colorbar
    # Append a colorbar axes to the right of ``ax``: 5% of its width, with a
    # fixed 0.2 inch gap between the two.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.2)
    cbar = plt.colorbar(p, cax=cax)
    # The label is drawn unrotated, so push it further away from the bar.
    cbar.ax.get_yaxis().labelpad = 45
    cbar.ax.set_ylabel(ratio_name, rotation=0)
    # Anchor the colour scale at 0; the upper limit keeps its current value.
    p.set_clim(vmin=0)
    plt.tight_layout()


if __name__ == "__main__":
    # Lower bound shared by every edge-weight range; the upper bound comes from
    # approx_b for each k.
    # NOTE(review): presumably approx_b picks b so that E[ln|w|] is close to k
    # (the heatmap's y-axis) -- confirm against auxiliary.vsb_emergence.
    a = 0.2
    w_ranges = [(a, approx_b(a=a, k=k)[0]) for k in np.arange(-1.0, 3.5, 0.5)]

    opt = {
        "overwrite": False,
        "base_dir": "src/results/SortabilityRelationships/",
        "scaler": Scalers.Identity(),
        # ---
        "MEC": False,
        "thres": 0.0,
        "thres_type": "standard",
        "vsb_function": utils.var_sortability,
        "R2sb_function": utils.r2_sortability,
        "CEVsb_function": utils.cev_sortability,
        # ---
        "n_repetitions": 2,
        "edge_types": ["fixed"],
        "edge_weights": w_ranges,
        "n_nodes": [50],
        "n_obs": [1000],
    }

    for graph in ["ER", "SF"]:
        # SF graphs need at least one edge per node (fewer would yield an
        # empty graph), so their sweep starts at 1; ER can go below that.
        if graph == "ER":
            opt["edges"] = [0.4] + list(np.arange(1.5, 9, 1))
        else:
            opt["edges"] = list(np.arange(1, 9, 1))

        # Per-graph copy; "edge_types" = ["fixed"] is inherited from ``opt``.
        opt_r2_vsb = deepcopy(opt)
        opt_r2_vsb["base_dir"] += graph
        opt_r2_vsb["exp_name"] = "r2_var"
        opt_r2_vsb["graphs"] = [graph]
        opt_r2_vsb["noise_distributions"] = [
            utils.NoiseDistribution("gauss", "uniform", (0.5, 2))
        ]
        # Compare R^2-sortability against var-sortability.
        vsb_functions = {"R2": utils.r2_sortability, "var": utils.var_sortability}
        vsb_investigation(
            utils.Options(**opt_r2_vsb),
            vsb_functions=vsb_functions,
            plot_fun=vsb_r2_heatmap,
        )
