""" 
Compare the causal structure learning performance of a method that exploits R² with that of established algorithms.
"""
import CDExperimentSuite_DEV as CDES
from copy import deepcopy
from auxiliary.causal_discovery import *
import matplotlib.pyplot as plt

# Global matplotlib styling: render text with LaTeX (loading the `bm` package
# for bold math symbols) and enlarge fonts, line widths and marker sizes.
plt.rcParams.update(
    {
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{bm}",
        "axes.labelsize": 22,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "legend.fontsize": 16,
        "lines.linewidth": 2,
        "lines.markersize": 12,
    }
)


def main(opt):
    """Run the experiment on raw and on standardized data, then compare.

    Two option sets are derived from ``opt``:
    - one that leaves the data unscaled (``CDES.Scalers.Identity``);
    - one that normalizes it (``CDES.Scalers.Normalizer``).

    Each option set is executed end to end via ``run``. The two results are
    then passed to ``viz_compare``.

    Args:
        opt: Plain dict of experiment options. It is deep-copied first, so the
            caller's dict is never mutated. ``"MAIN"`` is appended to the
            copy's ``"base_dir"``.
    """
    opt = deepcopy(opt)
    opt["base_dir"] += "MAIN"
    # Both variants share every setting except the experiment suffix and
    # the scaler applied to the data.
    opt_raw = _make_options(opt, "_raw", CDES.Scalers.Identity())
    opt_std = _make_options(opt, "_std", CDES.Scalers.Normalizer())
    ## run
    run(opt_raw)
    run(opt_std)
    ## visualize
    viz_compare(opt, opt_raw, opt_std)


def _make_options(opt, exp_name, scaler):
    """Build ``CDES.utils.Options`` from a copy of ``opt`` with overrides.

    Args:
        opt: Plain dict of experiment options; it is deep-copied, not mutated.
        exp_name: Value stored under ``"exp_name"`` in the copy.
        scaler: Value stored under ``"scaler"`` in the copy.

    Returns:
        ``CDES.utils.Options`` constructed from the modified copy.
    """
    variant = deepcopy(opt)
    variant["exp_name"] = exp_name
    variant["scaler"] = scaler
    return CDES.utils.Options(**variant)


def run(opt):
    """Run one complete experiment for a single option set.

    Calls ``data``, ``experiment`` and ``evaluate`` (provided by
    ``auxiliary.causal_discovery``) on ``opt``, in that order.

    Args:
        opt: Option object for one experiment variant; ``main`` passes a
            ``CDES.utils.Options`` instance.
    """
    data(opt)
    experiment(opt)
    evaluate(opt)


if __name__ == "__main__":
    # Experiment configuration handed to main(). main() deep-copies it and
    # builds one Options object per scaling variant (raw / standardized).
    opt = {
        # NOTE(review): presumably controls whether existing results are
        # recomputed — confirm against CDES.
        "overwrite": False,
        # Plain string (the f-prefix had no placeholders). Keep the trailing
        # slash: main() appends "MAIN" directly to this path.
        "base_dir": "src/results/CausalDiscovery/",
        # ---
        "MEC": False,
        "thres": 0,
        "thres_type": "standard",
        # Sortability measures from CDES.utils (variance / R² / CEV).
        "vsb_function": CDES.utils.var_sortability,
        "R2sb_function": CDES.utils.r2_sortability,
        "CEVsb_function": CDES.utils.cev_sortability,
        # ---
        "n_repetitions": 10,
        "graphs": ["ER"],
        "edges": [2],
        "edge_types": ["fixed"],
        "noise_distributions": [
            CDES.utils.NoiseDistribution("gauss", "uniform", (0.5, 2.0))
        ],
        "edge_weights": [(0.5, 2.0)],
        "n_nodes": [10],
        "n_obs": [1000],
    }

    ## Runs
    main(opt)
