""" 
Compare causal structure learning performance on Sachs et al. (2005) data.
"""
import CDExperimentSuite_DEV as CDES
from CDExperimentSuite_DEV import utils
from copy import deepcopy
from auxiliary.causal_discovery import *
import numpy as np
import pandas as pd
import pickle as pk
from sklearn.utils import resample


def write_sachs(n_datasets=30):
    """
    Write bootstrap resamples of the Sachs data in ExperimentSuite format.

    Reads the observations from ``src/data/Sachs/sachs.csv`` and the
    adjacency matrix from ``src/data/Sachs/consensus_adj_mat_raw.csv``. For
    each repetition the observation rows are resampled with
    ``sklearn.utils.resample`` (same number of rows, with replacement), the
    var- and R2-sortability are printed, and a ``utils.Dataset`` is pickled to
    ``src/results/CausalDiscovery/Sachs/Sachs_raw/_data/Sachs/dataset_{i}.pk``.

    Parameters
    ----------
    n_datasets : int, optional
        Number of resampled datasets to write. Defaults to 30, the same
        value used for ``n_repetitions`` in ``run``.
    """
    X_orig = pd.read_csv("src/data/Sachs/sachs.csv").to_numpy()

    fname = "src/results/CausalDiscovery/Sachs"
    data_dir = fname + "/Sachs_raw/_data/Sachs"
    # Create the folder hierarchy parent-first.
    for folder in (
        fname,
        fname + "/Sachs_raw",
        fname + "/Sachs_raw/_data",
        data_dir,
    ):
        utils.create_folder(folder)

    # The reference graph is identical for every repetition: load it once.
    B = pd.read_csv("src/data/Sachs/consensus_adj_mat_raw.csv").to_numpy()

    for i in range(n_datasets):
        # random_state=None, so sklearn draws from NumPy's global RNG; the
        # resamples are reproducible only if that RNG was seeded beforehand.
        X = resample(X_orig)

        # Compute each sortability once so the printed and stored values
        # are guaranteed to be the same numbers.
        varsortability = utils.var_sortability(X, B)
        r2sortability = utils.r2_sortability(X, B)
        print("var-sortability:", varsortability)
        print("R2-sortability:", r2sortability)

        variances = np.var(X, axis=0)

        # generate dataset from sachs in our fashion
        params = utils.DatasetParameters(
            graph_type="sachs",
            edge_type="fixed",
            x=None,
            noise_dist=None,
            noise_sigma_dist=None,
            noise_sigma_lims=None,
            edge_weight_range=None,
            n_nodes=len(B),
            n_obs=len(X),
            random_seed=None,
        )
        dataset = utils.Dataset(
            parameters=params,
            W_true=B,
            B_true=B,
            data=X,
            # NOTE(review): every resample is stored with the same hash;
            # confirm downstream code does not rely on hashes being unique.
            hash=123,
            scaler=None,
            scaling_factors=np.ones(len(B)),
            sigma=None,
            vars=variances,
            R2=utils.r2coef(X.T),
            # Separate array so the two fields never alias each other.
            scaled_vars=variances.copy(),
            varsortability=varsortability,
            R2sortability=r2sortability,
            CEVsortability=utils.cev_sortability(X, B),
        )
        dname = f"{data_dir}/dataset_{i}.pk"
        with open(dname, "wb") as f:
            pk.dump(dataset, f)


def run():
    """
    Run the complete Sachs experiment on the raw (identity-scaled) data.

    Assembles the ExperimentSuite options and derives the ``"_raw"`` variant,
    which uses ``CDES.Scalers.Identity()``. It then calls the
    ``sortnregressIC_R2``, ``sortnregressIC`` and ``randomregressIC`` methods
    of a ``CDES.ExperimentRunner``, in that order, and finally passes the
    options to ``evaluate``.
    """
    base_options = {
        "overwrite": True,
        "base_dir": "src/results/CausalDiscovery/Sachs",
        # ---
        "MEC": False,
        "thres": 0,
        "thres_type": "standard",
        "vsb_function": utils.var_sortability,
        "R2sb_function": utils.r2_sortability,
        "CEVsb_function": utils.cev_sortability,
        # ---
        "n_repetitions": 30,
        "graphs": ["ER"],
        "edges": [2],
        "edge_types": ["fixed"],
        "noise_distributions": [
            utils.NoiseDistribution("gauss", "uniform", (0.5, 2.0))
        ],
        "edge_weights": [(0.5, 2.0)],
        "n_nodes": [11],
        "n_obs": [1000],
    }

    # Raw-data variant: copy the base options and add the experiment name
    # plus the identity scaler.
    raw_options = utils.Options(
        **{
            **deepcopy(base_options),
            "exp_name": "_raw",
            "scaler": CDES.Scalers.Identity(),
        }
    )

    runner = CDES.ExperimentRunner(raw_options)
    for method_name in ("sortnregressIC_R2", "sortnregressIC", "randomregressIC"):
        getattr(runner, method_name)()

    evaluate(raw_options)


def show_results():
    """
    Summarize the Sachs evaluation and export it as a LaTeX table.

    Reads ``src/results/CausalDiscovery/Sachs/Sachs_raw/_eval/standard_0.csv``
    and averages SID and SHD per algorithm. SHD is formatted as
    ``"<mean>\\; [<min>, <max>]"``, and the internal algorithm names are mapped
    to display names. The table is printed and written to
    ``src/results/CausalDiscovery/Sachs/Sachs_raw/sachs_table.tex``.

    It also prints the ``shd`` stored at row label 1 (labelled "empty graph")
    and the mean/std of the ``varsortability`` and ``R2sortability`` columns.
    """
    res = pd.read_csv(
        "src/results/CausalDiscovery/Sachs/Sachs_raw/_eval/standard_0.csv"
    )
    dff = res.loc[:, ["algorithm", "sid", "shd"]].rename(
        columns={"algorithm": "Algorithm", "sid": "SID", "shd": "SHD"}
    )

    # A single named aggregation (groups sorted by algorithm name) gives the
    # means and the SHD range together. Nothing relies on positional
    # alignment between separate groupbys, and no NumPy callables are passed
    # to agg (deprecated in recent pandas).
    df = (
        dff.groupby("Algorithm")
        .agg(
            SID=("SID", "mean"),
            SHD=("SHD", "mean"),
            SHD_min=("SHD", "min"),
            SHD_max=("SHD", "max"),
        )
        .reset_index()
    )

    # "\\;" yields a literal backslash-semicolon for LaTeX. The previous bare
    # "\;" was an invalid Python escape sequence (a SyntaxWarning since 3.12).
    df["SHD"] = [
        f"{mean:.2f}\\; [{lo:.0f}, {hi:.0f}]"
        for mean, lo, hi in zip(df["SHD"], df["SHD_min"], df["SHD_max"])
    ]
    df = df.drop(columns=["SHD_min", "SHD_max"])
    df["Algorithm"] = df["Algorithm"].replace(
        {
            "sortnregressIC_R2": "R2-SortnRegress",
            "randomregressIC": "RandomRegress",
            "sortnregressIC": "Var-SortnRegress",
        }
    )
    print(df)
    with open("src/results/CausalDiscovery/Sachs/Sachs_raw/sachs_table.tex", "w") as f:
        f.write(df.to_latex(index=False))
    # NOTE(review): assumes the row with index label 1 of the evaluation file
    # is the empty-graph baseline; confirm against the evaluation code.
    print("empty graph:", res.loc[1, ["shd"]].to_dict())
    print("vsb:", res["varsortability"].mean(), res["varsortability"].std())
    print("rsb:", res["R2sortability"].mean(), res["R2sortability"].std())


if __name__ == "__main__":
    # Seed once up front, before write_sachs draws its bootstrap resamples.
    utils.set_random_seed(42)
    for step in (write_sachs, run, show_results):
        step()
