"""
Plot the parameter combinations in terms of weight distributions and 
graph densities that give rise to different levels of var-sortability.
"""
import CDExperimentSuite_DEV as CDES
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
from scipy.optimize import fsolve
import os

plt.rcParams["text.usetex"] = True  # use system latex
plt.rcParams["text.latex.preamble"] = r"\usepackage{amssymb}\usepackage{bm}"


def run_check(a, b, k, n=int(1e7)):
    """
    Check if log of weight distribution was approximated correctly.

    Draws ``n`` samples from U(a, b) and prints the Monte Carlo estimate of
    the mean of ln(W) next to its closed-form value, together with ``k``,
    so the three numbers can be compared by eye. Nothing is returned.

    args:
        a: lower bound of the uniform distribution
        b: upper bound of the uniform distribution
        k: reference value; it is only echoed in the printed line
        n (int): number of samples to draw. Defaults to 1e7, the value that
            used to be hard-coded.

    Note: the logarithm of both bounds is taken, so ``a`` and ``b`` have to
    be positive for the printed values to be finite.
    """
    n = int(n)  # numpy requires an integer sample size
    sample = np.random.uniform(a, b, n)
    log_mean = np.mean(np.log(sample))

    def F(x):
        # e^x * (x - 1) / (b - a); evaluated at x = ln(w) this equals
        # w * (ln(w) - 1) / (b - a), the antiderivative of ln(w) / (b - a)
        return np.exp(x) * (x - 1) / (b - a)

    ana_mean = F(np.log(b)) - F(np.log(a))

    print(f"{k=}, {log_mean=}, {ana_mean=}")


def approx_b(a, k):
    r"""
    For W ~ U(a, b), approximate which b to choose for a given E[ln|P_\omega|] = k

    The closed-form mean of ln(W),
        (b * (ln(b) - 1) - a * (ln(a) - 1)) / (b - a),
    is set equal to ``k`` and solved numerically for ``b`` with
    ``scipy.optimize.fsolve``, using exp(k) as the starting point.

    args:
        a: lower bound of the uniform distribution
        k: target value for the mean of ln(W)
    returns:
        the array returned by ``fsolve``; it holds exactly one element, the
        upper bound b.
    raises:
        AssertionError: if the solver does not return exactly one value, or
            if the value found is not strictly greater than ``a``.
    """

    def fun(b):
        return 1 / (b - a) * (b * (np.log(b) - 1) - a * (np.log(a) - 1)) - k

    res = fsolve(fun, np.exp(k))

    # we want a unique solution for b and b > a
    # (raised explicitly instead of via `assert` so the checks are not
    # stripped when Python runs with -O)
    if len(res) != 1:
        raise AssertionError(f"expected a single solution for b, got {len(res)}")
    # written as `not (... > a)` so that a NaN result is rejected as well
    if not res[0] > a:
        raise AssertionError(f"solution b={res[0]} is not greater than a={a}")

    return res


def plot_heatmap(df):
    """
    Plot heatmap from data generated by vsb_investigation.

    The mean of column "R2" is computed for every combination of "x" and
    "Eln|w|" (the latter rounded to one decimal) and drawn as an image on the
    current pyplot figure: "x" along the horizontal axis, "Eln|w|" along the
    vertical axis with the largest value at the top. The colour scale is
    fixed to [0, 1].

    args:
        df (pandas.DataFrame): needs the columns "x", "Eln|w|" and "R2".
            The frame is not modified.
    """
    # Only three columns are needed. Working on a private copy keeps the
    # caller's frame intact, so the function can be called repeatedly on the
    # same data.
    data = df[["x", "Eln|w|", "R2"]].copy()
    data["Eln|w|"] = np.round(data["Eln|w|"], 1)
    df_gb = data.groupby(by=["x", "Eln|w|"])["R2"].mean().reset_index()
    table = df_gb.pivot(index="Eln|w|", columns="x", values="R2")
    # imshow draws row 0 at the top; flip so the largest index value is on top
    img = np.flip(table.to_numpy(), axis=0)

    ## plot
    _ = plt.figure(figsize=(9, 6))  # (7.5, 5)
    ax = plt.gca()
    p = ax.imshow(X=img)

    # one tick per image column, labelled with the corresponding x value
    ax.set_xticks(
        ticks=list(range(img.shape[1])), labels=np.round(np.array(table.columns), 2)
    )

    # row labels are reversed to match the flipped image
    ax.set_yticks(
        ticks=list(range(img.shape[0])), labels=np.round(list(reversed(table.index)), 2)
    )
    ax.set_xlabel(r"Average node in-degree $\gamma$")  # Graph density
    ax.set_ylabel(r"$\mathrm{\mathbb{E}}[\ln|V|]$")  # Geometric weight mean

    ## colorbar
    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.2 inch.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.2)
    cbar = plt.colorbar(p, cax=cax)
    cbar.ax.get_yaxis().labelpad = 30
    cbar.ax.set_ylabel(r"$\mathbf{v}_{R^2}$", rotation=0)
    p.set_clim([0, 1])
    plt.tight_layout()


def vsb_investigation(opt, vsb_functions, plot_fun, verbose=False, plot_and_save=True):
    """
    Compute (or reload) var-sortability results, then plot or return them.

    Results are cached in "<opt.base_dir>/<opt.exp_name>.csv". If that file
    exists and ``opt.overwrite`` is falsy it is simply read back; otherwise
    the datasets are generated through CDES, every function in
    ``vsb_functions`` is evaluated on every dataset, and the resulting table
    is written to the csv (after creating the folder and snapshotting opt).
    args:
        opt: experiment options. ``base_dir``, ``exp_name`` and ``overwrite``
            are read here; the object is also passed on to the CDES data
            generator and snapshot helper.
        vsb_functions (dict): dictionary of {name: function} of different vsb functions.
            The entry named "Bound" is called with the dataset parameters,
            every other entry with the dataset's data and true weight matrix.
        plot_fun (function): a function taking a dataframe as parameter
        verbose (bool): if True, print a progress message whenever the number
            of nodes differs from the last one reported.
        plot_and_save (bool): if True, call ``plot_fun`` and save the figure
            as "<opt.base_dir>/<opt.exp_name>.pdf"; nothing is returned.
            If False, return the dataframe instead.
    """
    file_stem = opt.base_dir + "/" + opt.exp_name
    data_path = file_stem + ".csv"

    if os.path.exists(data_path) and not opt.overwrite:
        # reuse previously computed results
        df = pd.read_csv(data_path)
    else:
        # create all run combinations
        generator = CDES.DataGenerator()
        datasets = generator.generate(opt)

        n_meta = 6  # number of descriptive columns preceding the vsb columns
        colnames = ["x", "d", "Eln|w|", "iter", "w_low", "w_high", *vsb_functions]
        res = np.zeros(shape=(len(datasets), len(colnames)))

        print(f"Starting {len(datasets)} runs")
        prev_d = 0
        for idx, dat in enumerate(datasets):
            params = dat.parameters
            x, d, w = params.x, params.n_nodes, params.edge_weight_range
            elnw = CDES.utils.elnw(params)
            if verbose and d != prev_d:
                prev_d = d
                print(f"\nestimating for {x=}, {d=}, {w=}...")
                print(elnw)
            w_low, w_high = (round(bound, 2) for bound in w)
            res[idx, :n_meta] = x, d, elnw, params.random_seed, w_low, w_high
            for col, (name, fn) in enumerate(vsb_functions.items(), start=n_meta):
                if name == "Bound":
                    res[idx, col] = fn(params)
                else:
                    res[idx, col] = fn(dat.data, dat.W_true)

        df = pd.DataFrame(res, columns=colnames)
        # create output folder, snapshot opt, and save data
        CDES.utils.create_folder(opt.base_dir, overwrite=False)
        CDES.utils.snapshot(opt)
        df.to_csv(data_path, index=False)

    # one decimal is enough for grouping; also normalise a negative zero
    df["Eln|w|"] = np.round(df["Eln|w|"], 1).replace({-0.0: 0.0})
    CDES.utils.create_folder(opt.base_dir, overwrite=False)

    if not plot_and_save:
        return df

    plot_fun(df)
    plt.tight_layout()
    plt.savefig(file_stem + ".pdf")
    plt.close("all")
