from causallearn.search.ConstraintBased.PC import pc
import rpy2.robjects as robjects
robjects.r('.libPaths(c("/work/p1ux195/Rpackages/42", .libPaths()))')

import os
import time

from causalAssembly.drf_fitting import fit_drf
from causalAssembly.models_dag import ProductionLineGraph

from sklearn.metrics import precision_recall_fscore_support

import numpy as np
import pandas as pd

import warnings
warnings.filterwarnings("ignore")

def upper_skeleton(adjacency):
    """Return the undirected skeleton of a square adjacency matrix as a 0/1 column.

    An unordered node pair ``{i, j}`` is marked 1 when either ``adjacency[i, j]``
    or ``adjacency[j, i]`` is non-zero. Absolute values are taken first so that
    sign-encoded endpoint marks (e.g. causal-learn's ``-1``/``1`` convention)
    cannot cancel out. Edge orientation and edge weights are therefore ignored.
    Only the strict upper triangle is returned, so every unordered pair appears
    exactly once and the diagonal (self-loops) is excluded.

    Parameters
    ----------
    adjacency : array-like of shape (p, p)
        Square adjacency / endpoint matrix.

    Returns
    -------
    numpy.ndarray of shape (p * (p - 1) // 2, 1) with dtype int
        1 where the pair is adjacent, 0 otherwise.
    """
    adj = np.abs(np.asarray(adjacency))
    undirected = (adj + adj.T) != 0
    return undirected[np.triu_indices_from(undirected, k=1)].reshape(-1, 1).astype("int")


if __name__ == "__main__":
    print("Starting")
    n_obs = 5000
    n_reps = 20
    alpha = 0.05
    n_kernel = 50  # `K` argument of the fastkci test
    method = f"Fast KCI (K={n_kernel})"

    line_data = ProductionLineGraph.get_data()
    df_line_data = line_data.sample(n_obs, replace=False)

    assembly_line = ProductionLineGraph.get_ground_truth()
    assembly_line.drf = fit_drf(graph=assembly_line, data=df_line_data)

    SLURM_JOB_ID = os.getenv('SLURM_JOB_ID')
    out_dir = f"saved_dags/{SLURM_JOB_ID}"
    # makedirs also creates `saved_dags/` and tolerates reruns of the same job id.
    os.makedirs(out_dir, exist_ok=True)

    # The ground-truth graph does not change between repetitions, so its
    # skeleton is computed once. It is symmetrized like the estimate so that
    # edges stored in the lower triangle are not silently dropped.
    dag = assembly_line.ground_truth.values
    true_skeleton = upper_skeleton(dag)

    res = []
    for i in range(n_reps):
        print(f"Rep: {i}")
        df = assembly_line.sample_from_drf(size=n_obs)
        # Guard constant columns: a zero std would otherwise produce NaN input.
        std = df.std().replace(0, 1)
        df = (df - df.mean()) / std
        X = df.values
        np.savetxt(f"{out_dir}/true_{i}.csv", dag, fmt='%d')

        start_time = time.perf_counter()
        cg = pc(X, indep_test="fastkci", alpha=alpha, approx=False, use_gp=False,
                K=n_kernel, J=8, show_progress=False)
        kci_time = time.perf_counter() - start_time
        est_skeleton = upper_skeleton(cg.G.graph)
        # NOTE(review): the filename says "k10" although K=50 is used; kept
        # unchanged so existing downstream readers of these files still work.
        np.savetxt(f"{out_dir}/fkcik10_{i}.csv", cg.G.graph, fmt='%d')

        if not true_skeleton.any() and not est_skeleton.any():
            # Both graphs are empty: treat as a perfect recovery.
            res.append((method, 1, 1, 1, kci_time))
        else:
            # labels=[0, 1] guarantees index 1 is the "edge present" class even
            # when one of the classes is absent from both vectors.
            pre, rec, f1, _ = precision_recall_fscore_support(
                true_skeleton.ravel(), est_skeleton.ravel(),
                labels=[0, 1], zero_division=0,
            )
            res.append((method, pre[1], rec[1], f1[1], kci_time))

        # Rewritten every repetition so partial results survive a killed job.
        pd.DataFrame(res, columns=["type", "precision", "recall", "f1", "time"]).to_csv(
            f"simulation_results_{SLURM_JOB_ID}.csv", index=False
        )
