

from utils.MAG_tools_copy import Checker




def get_equiv_class(
        nodenum,                # int
        dag_edges,              # list of tuples (i,j) with 0 <= i,j <= nodenum-1
        interv_targets,         # list of subsets of {0,1,...,nodenum-1}
        selection_parents,      # list of subsets of {0,1,...,nodenum-1}
):
    '''
    Compute the equivalence class for a DAG with interventions and selection,
    delegating the actual work to ``Checker`` (from ``utils.MAG_tools_copy``).

    All node indices at this function's boundary are 0-indexed. ``Checker`` is
    fed 1-indexed labels (every index shifted by +1, each target / parent set
    converted to a sorted tuple), and every pair in its result is shifted back
    by -1 before returning.

    :return: a dictionary with keys as follows, and value (i,j)s with 0 <= i,j <= nodenum-1
        'NO': set of (i,j) that must be nonadjacent.                                        [(i,j) is in iff (j,i) is in]
        '->': set of (i,j) that must be directed i->j edge in original DAG
        '<->': set of (i,j) that must be nonadjacent, and being selected in original DAG.   [(i,j) is in iff (j,i) is in]
        '⚬->': set of (i,j) that is either i->j or selected in original DAG.
        '⚬-⚬': set of (i,j) that each endpoint may vary in the equivalence class.           [(i,j) is in iff (j,i) is in]
    '''
    def _to_one_based_sorted(node_subsets):
        # Shift each subset to 1-indexed labels and freeze it as a sorted tuple.
        return [tuple(sorted(node + 1 for node in subset)) for subset in node_subsets]

    checker = Checker(
        nodenum,
        [(src + 1, dst + 1) for src, dst in dag_edges],
        _to_one_based_sorted(interv_targets),
        _to_one_based_sorted(selection_parents),
    )
    one_based = checker.run_algo()

    # Translate every (i, j) pair in the checker's output back to 0-indexed.
    zero_based = {}
    for label, pairs in one_based.items():
        zero_based[label] = {(a - 1, b - 1) for a, b in pairs}
    return zero_based



if __name__ == '__main__':
    # Small worked example on 6 nodes (0-indexed), printed as the raw result dict.
    demo_result = get_equiv_class(
        nodenum=6,
        dag_edges=[(0, 2), (0, 3), (1, 2), (3, 2), (4, 3), (5, 1)],
        interv_targets=[(0, 1), (4,)],
        selection_parents=[(4, 5)],
    )
    print(demo_result)