"""
Show algorithm performance for graphs with different R²-sortability, and plot
the distribution of R²-sortability across the generated graphs.
"""
from CDExperimentSuite_DEV import *
from copy import deepcopy
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Global figure styling applied before any plotting.
# NOTE: "text.usetex" makes matplotlib render all text through LaTeX, so a
# working LaTeX installation is required; the preamble loads the `bm` package
# for bold math symbols.
plt.rcParams.update(
    {
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{bm}",
        "axes.labelsize": 25,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
        "legend.fontsize": 16,
        "legend.title_fontsize": 16,
        "lines.linewidth": 3,
        "lines.markersize": 12,
    }
)


def data(opt):
    """Generate the datasets described by ``opt`` and write them out.

    Thin wrapper around ``DataGenerator().generate_and_save(opt)``.
    """
    generator = DataGenerator()
    generator.generate_and_save(opt)


def experiment(opt):
    """Run every method under test on the data configured by ``opt``.

    Calls the ``ExperimentRunner`` methods ``sortnregressIC``,
    ``sortnregressIC_R2`` and ``randomregressIC`` one after another, in
    that order.
    """
    runner = ExperimentRunner(opt)
    method_names = (
        "sortnregressIC",
        "sortnregressIC_R2",
        "randomregressIC",
    )
    for method_name in method_names:
        # Look up and invoke each method in turn, preserving call order.
        getattr(runner, method_name)()


def evaluate(opt):
    """Score the experiment results for ``opt``.

    The thresholding scheme passed to ``Evaluator.evaluate`` is read from
    ``opt.thres_type``.
    """
    evaluator = Evaluator(opt)
    evaluator.evaluate(thresholding=opt.thres_type)


def visualize(opt):
    """Draw the sortability plots for ``opt``.

    Calls ``Visualizer.sortability`` once per accuracy measure (currently
    only ``"shd"``), with the ``"R2sortability"`` statistic and the ``"_R2"``
    name suffix.
    """
    visualizer = Visualizer(opt)
    measures = ("shd",)
    for measure in measures:
        visualizer.sortability(
            acc_measure=measure,
            custom_name="_R2",
            stb_name="R2sortability",
        )


def ER(opt):
    """Run the complete pipeline on ER graphs.

    A deep copy of ``opt`` is specialised with ``graphs=["ER"]``,
    ``edge_weights=[(0.5, 1)]``, ``exp_name="_std"`` and a
    ``Scalers.Normalizer()`` scaler, and its output directory is set to
    ``<base_dir>/ER``. The copy is wrapped in ``utils.Options`` and passed to
    ``run``. The caller's ``opt`` dict is not modified.

    Args:
        opt: dict of experiment options; must contain a ``"base_dir"`` key.
    """
    import os

    opt_raw = deepcopy(opt)
    # os.path.join works whether or not base_dir ends in a separator; plain
    # string concatenation would glue "ER" onto the last path component.
    opt_raw["base_dir"] = os.path.join(opt_raw["base_dir"], "ER")
    opt_raw["graphs"] = ["ER"]
    opt_raw["edge_weights"] = [(0.5, 1)]
    opt_raw["exp_name"] = "_std"
    opt_raw["scaler"] = Scalers.Normalizer()
    opt_raw = utils.Options(**opt_raw)
    run(opt_raw)


def run(opt):
    """Execute the whole pipeline for ``opt``.

    Stages run strictly in order: data generation, experiment, evaluation,
    visualization. Each stage receives the same ``opt``.
    """
    for stage in (data, experiment, evaluate, visualize):
        stage(opt)


if __name__ == "__main__":
    opt = {
        "overwrite": False,
        "base_dir": "src/results/PerformanceRsb/",
        # ---
        "MEC": False,
        "thres": 0,
        "thres_type": "standard",
        "vsb_function": utils.var_sortability,
        "R2sb_function": utils.r2_sortability,
        "CEVsb_function": utils.cev_sortability,
        # --- "graphs" and "edge_weights" are filled in per experiment by ER()
        "n_repetitions": 100,
        "graphs": None,
        "edges": [2],
        "edge_types": ["fixed"],
        "noise_distributions": [
            utils.NoiseDistribution("gauss", "uniform", (0.5, 2.0)),
        ],
        "edge_weights": None,
        "n_nodes": [20],
        "n_obs": [1000],
    }

    ER(opt)

    # Smaller fonts and thinner lines for the compact distribution plots below.
    plt.rcParams["axes.labelsize"] = 15
    plt.rcParams["xtick.labelsize"] = 12
    plt.rcParams["ytick.labelsize"] = 12
    plt.rcParams["lines.linewidth"] = 2

    # Show the distribution of R²-sortability read from each evaluation CSV.
    paths = ["src/results/PerformanceRsb/ER/ER_std/_eval/standard_0.csv"]
    names = ["ER"]
    for path, name in zip(paths, names):
        # Close whatever figure is current before starting a fresh one.
        plt.close()
        df = pd.read_csv(path)
        plt.figure(figsize=(5, 2))
        # histplot returns the Axes it drew on; keep it separate from `path`.
        ax = sns.histplot(df.R2sortability, kde=True)
        ax.set_xlabel(r"$R^2$-sortability")
        plt.tight_layout()
        plt.savefig(
            f"src/results/PerformanceRsb/R2distributions_{name}.pdf",
            bbox_inches="tight",
        )
