from causaldag import igsp as igsp_func
from causaldag import unknown_target_igsp
import causaldag as cd
import random
from causaldag import partial_correlation_suffstat, partial_correlation_test, MemoizedCI_Tester
from causaldag import gauss_invariance_suffstat, gauss_invariance_test, MemoizedInvarianceTester


def igsp(all_data, interv_targets, alpha):
    """Estimate a DAG with IGSP from observational data plus known-target interventions.

    Args:
        all_data: indexable collection of sample arrays. ``all_data[0]`` is the
            observational data set; ``all_data[1:]`` are the interventional data
            sets, in the same order as ``interv_targets``. ``all_data[0].shape[1]``
            is used as the node count (presumably a 2-D samples-by-nodes array —
            TODO confirm).
        interv_targets: one iterable of intervened node indices per
            interventional data set.
        alpha: significance level used for both the conditional-independence
            tester and the invariance tester.

    Returns:
        dict with key ``'dag'`` holding the first element of the estimated
        DAG's ``to_amat()`` result (presumably its adjacency matrix).

    Raises:
        ValueError: if the number of intervention target sets differs from the
            number of interventional data sets.
    """
    obs_samples = all_data[0]
    iv_samples_list = all_data[1:]
    nodes = set(range(obs_samples.shape[1]))

    setting_list = [dict(interventions=list(interv_target)) for interv_target in interv_targets]
    # Each setting is paired by position with one interventional data set in the
    # invariance statistics below. A length mismatch would otherwise silently
    # drop data (too few targets) or fail deep inside causaldag (too many).
    if len(setting_list) != len(iv_samples_list):
        raise ValueError(
            f"got {len(setting_list)} intervention target sets for "
            f"{len(iv_samples_list)} interventional data sets; expected one per data set"
        )

    # Form sufficient statistics
    obs_suffstat = partial_correlation_suffstat(obs_samples)
    invariance_suffstat = gauss_invariance_suffstat(obs_samples, iv_samples_list)
    # Create conditional independence tester and invariance tester
    ci_tester = MemoizedCI_Tester(partial_correlation_test, obs_suffstat, alpha=alpha)
    invariance_tester = MemoizedInvarianceTester(gauss_invariance_test, invariance_suffstat, alpha=alpha)
    # Run IGSP
    dag_est = igsp_func(setting_list, nodes, ci_tester, invariance_tester)
    params_est = {'dag': dag_est.to_amat()[0]}
    return params_est


def ut_igsp(all_data, alpha):
    """Run UT-IGSP, learning a DAG while also estimating the intervention targets.

    Args:
        all_data: indexable collection of sample arrays. The first entry is the
            observational data; each later entry is an interventional data set
            whose targets are treated as unknown. ``all_data[0].shape[1]`` is
            used as the node count (presumably a 2-D samples-by-nodes array —
            TODO confirm).
        alpha: significance level shared by the CI tester and the invariance
            tester.

    Returns:
        dict with ``'dag'`` -> first element of the estimate's ``to_amat()``
        result, and ``'targets_list'`` -> the targets returned by
        ``unknown_target_igsp``.
    """
    observed = all_data[0]
    interventional = all_data[1:]
    node_set = set(range(observed.shape[1]))

    ci = MemoizedCI_Tester(
        partial_correlation_test,
        partial_correlation_suffstat(observed),
        alpha=alpha,
    )
    inv = MemoizedInvarianceTester(
        gauss_invariance_test,
        gauss_invariance_suffstat(observed, interventional),
        alpha=alpha,
    )

    # No interventional data set comes with known targets; UT-IGSP infers them.
    settings = [{'known_interventions': []} for _ in interventional]
    estimated_dag, estimated_targets = unknown_target_igsp(settings, node_set, ci, inv)

    return {
        'dag': estimated_dag.to_amat()[0],
        'targets_list': estimated_targets,
    }
