import numpy as np
import pandas as pd
from dgp.dag_data import create_data, create_data_2
from causallearn.search.ConstraintBased.PC import pc
from causallearn.utils.cit import kci, fastkci
from sklearn.metrics import precision_recall_fscore_support
import os
import time
import argparse


def _skeleton_vector(adjacency):
    """Return the undirected skeleton of *adjacency* as a flat 0/1 int vector.

    An entry is 1 when there is a non-zero mark between nodes ``i`` and ``j``
    in either direction.  Only the upper triangle (diagonal included) is
    returned, so every unordered node pair appears exactly once.
    """
    adjacency = np.asarray(adjacency)
    # Symmetrise first: an edge stored only below the diagonal must still
    # count as an edge of the skeleton.
    undirected = (np.abs(adjacency) + np.abs(adjacency.T)) != 0
    return undirected[np.triu_indices_from(adjacency)].astype("int")


def _edge_scores(true_skeleton, est_skeleton):
    """Return ``(precision, recall, f1)`` of the edge class (label 1).

    When neither skeleton contains an edge the estimate is counted as
    perfect and ``(1, 1, 1)`` is returned.
    """
    if not true_skeleton.any() and not est_skeleton.any():
        return 1, 1, 1
    # Pin the label set so index 1 is always the "edge" class, regardless of
    # which labels happen to occur in the two vectors.
    pre, rec, f1, _ = precision_recall_fscore_support(
        true_skeleton, est_skeleton, labels=[0, 1], zero_division=0
    )
    return pre[1], rec[1], f1[1]


def simulation(n_obs, n_nodes, spar, use_gp, n_reps=50):
    """Benchmark PC skeleton recovery with KCI and Fast KCI on simulated data.

    For each repetition a data set and its DAG are drawn with ``create_data``
    and the PC algorithm is run three times on the same data: with ``kci``,
    with ``fastkci`` (K=3, J=16) and with ``fastkci`` (K=10, J=16).  For each
    run the wall-clock time and the precision / recall / F1 of the recovered
    skeleton against the true skeleton are recorded.

    Files written (``<id>`` is the ``SLURM_JOB_ID`` environment variable; when
    it is unset ``os.getenv`` returns ``None`` and the paths contain the
    literal text ``None``):

    * ``saved_dags/<id>/true_<i>.csv``    -- generated DAG of repetition ``i``
    * ``saved_dags/<id>/kci_<i>.csv``     -- graph estimated with KCI
    * ``saved_dags/<id>/fkcik3_<i>.csv``  -- graph estimated with Fast KCI, K=3
    * ``saved_dags/<id>/fkcik10_<i>.csv`` -- graph estimated with Fast KCI, K=10
    * ``simulation_results_<id>.csv``     -- all scores so far; rewritten
      after every repetition

    Args:
        n_obs: Sample size forwarded to ``create_data``; also stored in the
            ``n`` column of the results.
        n_nodes: Number of nodes forwarded to ``create_data``.
        spar: Sparsity forwarded to ``create_data``.
        use_gp: Forwarded to ``pc`` as ``use_gp``.
        n_reps: Number of repetitions (default 50).

    Returns:
        None.  Results are only written to disk.
    """
    res = []
    SLURM_JOB_ID = os.getenv('SLURM_JOB_ID')
    out_dir = f"saved_dags/{SLURM_JOB_ID}"
    # makedirs also creates the parent "saved_dags" folder and tolerates a
    # directory left behind by an earlier run.
    os.makedirs(out_dir, exist_ok=True)

    # (label in the results table, file-name prefix, independence test,
    #  extra keyword arguments for pc)
    methods = (
        ("KCI", "kci", kci, {}),
        ("Fast KCI (K=3)", "fkcik3", fastkci, {"K": 3, "J": 16}),
        ("Fast KCI (K=10)", "fkcik10", fastkci, {"K": 10, "J": 16}),
    )

    for i in range(n_reps):
        X, dag = create_data(sample_size=n_obs, num_nodes=n_nodes, sparsity=spar)
        np.savetxt(f"{out_dir}/true_{i}.csv", dag, fmt='%d')

        # NOTE(review): assumes `dag` is a square adjacency matrix -- confirm
        # against create_data.
        true_skeleton = _skeleton_vector(dag)

        for label, prefix, indep_test, extra in methods:
            # perf_counter is monotonic, so the interval cannot be distorted
            # by system clock adjustments.
            start_time = time.perf_counter()
            g = pc(X, indep_test=indep_test, alpha=0.05, approx=False,
                   use_gp=use_gp, show_progress=False, **extra)
            elapsed = time.perf_counter() - start_time
            np.savetxt(f"{out_dir}/{prefix}_{i}.csv", g.G.graph, fmt='%d')

            est_skeleton = _skeleton_vector(g.G.graph)
            pre, rec, f1 = _edge_scores(true_skeleton, est_skeleton)
            res.append((label, pre, rec, f1, n_obs, elapsed))

        # Rewrite the full table each repetition so the file on disk always
        # reflects every repetition completed so far.
        pd.DataFrame(
            res, columns=["type", "precision", "recall", "f1", "n", "time"]
        ).to_csv(f"simulation_results_{SLURM_JOB_ID}.csv", index=False)


if __name__ == "__main__":
    # Command-line entry point: read the experiment settings, echo them,
    # and run the simulation with them.
    cli = argparse.ArgumentParser(description='Script to run pc experiment.')

    # (flag, value type, help text) for every supported option.
    option_specs = (
        ('--num_obs', int, 'Number of Observations in generation'),
        ('--num_nodes', int, 'Number of Nodes in generation'),
        ('--sparsity', float, 'Sparsity in generation'),
        ('--use_gp', int, 'Use GP in KCI'),
    )
    for flag, value_type, help_text in option_specs:
        cli.add_argument(flag, type=value_type, help=help_text)

    opts = cli.parse_args()

    # --use_gp arrives as an integer (or None when omitted); any non-zero
    # value switches the GP option on.
    gp_enabled = bool(opts.use_gp)
    print(f"n = {opts.num_obs}, nodes = {opts.num_nodes}, sparsity = {opts.sparsity}, GP = {gp_enabled}")

    simulation(opts.num_obs, opts.num_nodes, opts.sparsity, gp_enabled)