from causallearn.search.ConstraintBased.PC import pc
import rpy2.robjects as robjects
robjects.r('.libPaths(c("/work/p1ux195/Rpackages/42", .libPaths()))')

import os
import time

from causalAssembly.drf_fitting import fit_drf
from causalAssembly.models_dag import ProductionLineGraph

from sklearn.metrics import precision_recall_fscore_support

import numpy as np
import pandas as pd

import warnings
warnings.filterwarnings("ignore")

import argparse

def skeleton_vector(adjacency):
    """Return the undirected skeleton of a square adjacency matrix as a 0/1 vector.

    An unordered pair (i, j) counts as connected when either ``adjacency[i, j]``
    or ``adjacency[j, i]`` is non-zero, so the result does not depend on the
    edge orientation or on the sign convention of the entries (causal-learn
    marks edge endpoints with -1 / 1).

    Args:
        adjacency: Square 2-D array-like.

    Returns:
        1-D int array with one entry per unordered node pair (strict upper
        triangle, row-major order).
    """
    nonzero = np.asarray(adjacency) != 0
    # Symmetrise first: an edge stored only below the diagonal must still be
    # counted, otherwise the result depends on the node ordering.
    connected = nonzero | nonzero.T
    return connected[np.triu_indices_from(connected, k=1)].astype("int")


def score_skeleton(true_skeleton, est_skeleton):
    """Return (precision, recall, f1) of the "edge present" class.

    Args:
        true_skeleton: 0/1 vector of ground-truth edges.
        est_skeleton: 0/1 vector of estimated edges, same length.

    Returns:
        Tuple ``(precision, recall, f1)``. When neither vector contains an
        edge the estimate is counted as perfect, i.e. ``(1, 1, 1)``.
    """
    if not np.any(true_skeleton) and not np.any(est_skeleton):
        return 1, 1, 1
    # Fix the label set explicitly so that index 1 always refers to the
    # "edge present" class, whatever labels occur in the data.
    pre, rec, f1, _ = precision_recall_fscore_support(
        true_skeleton, est_skeleton, labels=[0, 1], zero_division=0
    )
    return pre[1], rec[1], f1[1]


def main():
    """Benchmark PC with KCI and FastKCI on data sampled from one station.

    Fits a DRF to the selected station, then repeatedly samples a data set,
    runs both PC variants on the standardised sample, stores the true and
    estimated graphs under ``saved_dags/<SLURM_JOB_ID>/`` and writes the
    accumulated precision / recall / F1 / runtime table to
    ``simulation_results_<SLURM_JOB_ID>.csv``.
    """
    parser = argparse.ArgumentParser(description='Script to run causal ssembly on stations.')
    # Required: without it the cell lookup below fails on "StationNone".
    parser.add_argument('--station', type=int, required=True, help='Station Number')
    args = parser.parse_args()

    print("Starting")

    # Create the output directory before the expensive fit so that path
    # problems surface immediately. makedirs also creates "saved_dags" itself
    # and tolerates a directory left over from an earlier run.
    # Outside of SLURM the variable is unset and the id is the string "None".
    job_id = os.getenv('SLURM_JOB_ID')
    out_dir = f"saved_dags/{job_id}"
    os.makedirs(out_dir, exist_ok=True)
    results_path = f"simulation_results_{job_id}.csv"

    n_obs = 2000
    n_reps = 50

    line_data = ProductionLineGraph.get_data()
    df_line_data = line_data.sample(n_obs, replace=False)

    assembly_line = ProductionLineGraph.get_ground_truth()
    cell = assembly_line.cells[f"Station{args.station}"]
    cell.drf = fit_drf(graph=cell, data=df_line_data)

    # The ground truth does not change between repetitions.
    dag = cell.ground_truth.values
    true_skeleton = skeleton_vector(dag)

    # (label in the result table, file prefix, method-specific pc() arguments)
    methods = (
        ("KCI", "kci", {"indep_test": "kci"}),
        ("Fast KCI (K=10)", "fkcik10", {"indep_test": "fastkci", "K": 10, "J": 8}),
    )

    res = []
    for i in range(n_reps):
        print(f"Rep: {i}")
        df = cell.sample_from_drf(size=n_obs)
        df = (df - df.mean()) / df.std()
        X = df.values
        np.savetxt(f"{out_dir}/true_{i}.csv", dag, fmt='%d')

        for label, prefix, extra in methods:
            # Time only the structure search itself.
            start_time = time.perf_counter()
            g = pc(X, alpha=0.05, approx=False, use_gp=False, show_progress=False, **extra)
            elapsed = time.perf_counter() - start_time
            np.savetxt(f"{out_dir}/{prefix}_{i}.csv", g.G.graph, fmt='%d')

            pre, rec, f1 = score_skeleton(true_skeleton, skeleton_vector(g.G.graph))
            res.append((label, pre, rec, f1, elapsed))

        # Rewrite the table after every repetition so partial results survive
        # an aborted job.
        pd.DataFrame(res, columns=["type", "precision", "recall",
                                   "f1", "time"]).to_csv(results_path, index=False)


if __name__ == "__main__":
    main()
