from causallearn.utils.FastKCI.FastKCI import FastKCI_CInd
from causallearn.utils.KCI.KCI import KCI_CInd
from dgp.dgp import kci_paper_type_2, type2_x_cause_y
import time
import pandas as pd
import os
import argparse
from joblib import Parallel, delayed


def simulation_fast(k, d, n, n_reps,
                    param_sets=((3, 16), (10, 16), (3, 128), (10, 128))):
    """Run FastKCI on freshly generated data for several (K, J) settings.

    For each of ``n_reps`` repetitions a new dataset is drawn with
    ``kci_paper_type_2``. FastKCI is then run once per ``(K, J)`` pair in
    ``param_sets`` on that same dataset. Each run adds one row, and all rows
    collected so far are written to
    ``simulation_results_<SLURM_JOB_ID>_fast.csv`` after every repetition.

    Args:
        k: Number of Gaussians passed to the data generator.
        d: Size of Z passed to the data generator.
        n: Number of observations per generated dataset.
        n_reps: Number of independent repetitions.
        param_sets: Iterable of ``(K, J)`` pairs forwarded to
            ``FastKCI_CInd(K=..., J=...)``. Defaults to the four settings
            that were previously hard-coded.

    Returns:
        None. Results are only persisted to CSV; when ``n_reps`` is 0 no
        file is written.
    """
    columns = ["k", "d", "type", "p value", "tstat", "time", "n", "type 2"]
    csv_path = f"simulation_results_{os.getenv('SLURM_JOB_ID')}_fast.csv"
    # Materialise once: the settings are reused on every repetition, so a
    # one-shot iterator would otherwise be exhausted after the first one.
    param_sets = tuple(param_sets)
    results = []

    for _ in range(n_reps):
        # NOTE(review): data are generated with is_type2=False, yet each row
        # records "type 2" = True (same as simulation_old) — confirm intended.
        data_x, data_y, data_z = kci_paper_type_2(n=n, d=d, is_type2=False, k=k)

        for k_param, j_param in param_sets:
            # perf_counter is monotonic and high-resolution, unlike
            # time.time(), so durations are immune to system clock changes.
            start = time.perf_counter()
            pvalue, tstat = FastKCI_CInd(K=k_param, J=j_param, use_gp=True).compute_pvalue(
                data_x, data_y, data_z)
            elapsed = time.perf_counter() - start
            results.append((k, d, f'FastKCI(K={k_param}, J={j_param})',
                            pvalue, tstat, elapsed, n, True))

        # Rewrite the full CSV after every repetition so results of completed
        # repetitions are on disk even if the job stops part-way through.
        pd.DataFrame(results, columns=columns).to_csv(csv_path, index=False)


def simulation_old(k, d, n):
    """Run the exact (non-approximate) KCI test once on freshly generated data.

    Args:
        k: Number of Gaussians passed to the data generator.
        d: Size of Z passed to the data generator.
        n: Number of observations in the generated dataset.

    Returns:
        Tuple ``(k, d, 'KCI', pvalue, tstat, elapsed_seconds, n, True)``,
        ordered like the CSV columns used for the KCI results.
    """
    # NOTE(review): data are generated with is_type2=False, yet the row
    # records "type 2" = True (same as simulation_fast) — confirm intended.
    data_x, data_y, data_z = kci_paper_type_2(n=n, d=d, is_type2=False, k=k)

    # perf_counter is monotonic and high-resolution, unlike time.time(), so
    # the measured duration is immune to system clock adjustments.
    start = time.perf_counter()
    pvalue, tstat = KCI_CInd(approx=False, use_gp=True).compute_pvalue(data_x, data_y, data_z)
    kci_time = time.perf_counter() - start

    return (k, d, 'KCI', pvalue, tstat, kci_time, n, True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Script to run comparisons between fastkci and kci.')
    # The generation settings are mandatory: if omitted, argparse would pass
    # None downstream (e.g. range(None) raises TypeError far from the CLI).
    parser.add_argument('--num_k', type=int, required=True, help='Number of Gaussians in generation')
    parser.add_argument('--num_d', type=int, required=True, help='Size of Z in generation')
    parser.add_argument('--num_obs', type=int, required=True, help='Number of Obs in generation')
    parser.add_argument('--n_reps', type=int, required=True, help='Number of Repetitions')
    parser.add_argument('--n_jobs', type=int, default=16,
                        help='Number of parallel workers for the KCI runs (default: 16)')
    args = parser.parse_args()

    # FastKCI runs sequentially and writes its own "_fast" CSV.
    simulation_fast(args.num_k, args.num_d, args.num_obs, args.n_reps)

    # Each KCI repetition draws its own data, so repetitions are independent
    # and can be fanned out across workers.
    results = Parallel(n_jobs=args.n_jobs)(
        delayed(simulation_old)(args.num_k, args.num_d, args.num_obs)
        for _ in range(args.n_reps)
    )
    df = pd.DataFrame(results, columns=['k', 'd', 'type', 'p value', 'tstat', 'time', "n", "type 2"])
    df.to_csv(f"simulation_results_{os.getenv('SLURM_JOB_ID')}.csv", index=False)
