# Copyright 2021 Juan L Gamella

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:

# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""The main module, containing the implementation of GES, including
the logic for the insert, delete and turn operators. The
implementation is directly based on two papers:

  1. The 2002 GES paper by Chickering, "Optimal Structure
  Identification With Greedy Search" -
  https://www.jmlr.org/papers/volume3/chickering02b/chickering02b.pdf

  2. For the turn operator, the 2012 GIES paper by Hauser & Bühlmann,
  "Characterization and Greedy Learning of Interventional Markov
  Equivalence Classes of Directed Acyclic Graphs" -
  https://www.jmlr.org/papers/volume13/hauser12a/hauser12a.pdf

Further credit is given where due.

Additional modules / packages:

  - ges.utils contains auxiliary functions and the logic to transform
    a PDAG into a CPDAG, used after each application of an operator.
  - ges.scores contains the modules with the score classes:
      - ges.scores.decomposable_score contains the base class for
        decomposable score classes (see that module for more details).
      - ges.scores.gauss_obs_l0_pen contains a cached implementation
        of the gaussian likelihood BIC score used in the original GES
        paper.
   - ges.test contains the modules with the unit tests and tests
     comparing against the algorithm's implementation in the R package
     'pcalg'.

"""

import numpy as np
import os, time
from scipy.stats import norm
from csl.search.ScoreBased import utils
from csl.graph.GeneralGraph import GeneralGraph
from csl.graph.GraphNode import GraphNode
from csl.graph.Edge import Edge
from csl.graph.Endpoint import Endpoint
MAXTIMEPERPHASEPERNODE = 30 # np.inf; 30 may still be too slow: i.e., for a 30 nodes graph, it takes max 30min (15min per pahse)
BIC_L0_PENALTY = 2

def fisherz_stat(rho, N): # here we omit the N-3-d's d term
    return np.sqrt(N - 3) * abs(0.5 * np.log((1 + rho) / (1 - rho)))
def fisherz_rho(stat, N):
    return 1. - 2. / (1. + np.exp(2 * stat / np.sqrt(N - 3)))  # here we omit the N-3-d's d term
def bic_score_change(rho, N):
    return -0.5 * N * np.log(1 - rho ** 2) - np.log(N) / BIC_L0_PENALTY


def ges(data, indep_test, pval_alpha, A0=None, phases=['forward', 'backward', 'turning'], iterate=False, debug=0, bic_l0_penal=2):
    """Run GES on the given data, using the Gaussian BIC score
    (l0-penalized Gaussian Likelihood). The data is not assumed to be
    centered, i.e. an intercept is fitted.

    To use a custom score, see ges.fit.

    Parameters
    ----------
    data : numpy.array
        The n x p array containing the observations, where columns
        correspond to variables and rows to observations.
    indep_test : CIT in {'fisherz', 'zerodel_fisherz', 'kci', 'zerodel_kci'}
        For 'fisherz', it is equivalent to the BIC score with L0 penalty used (in the original version of this package).
    A0 : numpy.ndarray, optional
        The initial CPDAG on which GES will run, where where `A0[i,j]
        != 0` implies the edge `i -> j` and `A[i,j] != 0 & A[j,i] !=
        0` implies the edge `i - j`. Defaults to the empty graph
        (i.e. matrix of zeros).
    phases : [{'forward', 'backward', 'turning'}*], optional
        Which phases of the GES procedure are run, and in which
        order. Defaults to `['forward', 'backward', 'turning']`.
    iterate : bool, default=False
        Indicates whether the given phases should be iterated more
        than once.
    debug : int, optional
        If larger than 0, debug are traces printed. Higher values
        correspond to increased verbosity.

    Returns
    -------
    estimate : numpy.ndarray
        The adjacency matrix of the estimated CPDAG.
    total_score : float
        The score of the estimate.

    Raises
    ------
    TypeError:
        If the type of some of the parameters was not expected,
        e.g. if data is not a numpy array.
    ValueError:
        If the value of some of the parameters is not appropriate,
        e.g. a wrong phase is specified.
    """
    # Initialize Gaussian BIC score (precomputes scatter matrices, sets up cache)
    # cache = GaussObsL0Pen(data, cov=cov)
    # Unless indicated otherwise, initialize to the empty graph
    global BIC_L0_PENALTY; BIC_L0_PENALTY = bic_l0_penal
    A0 = np.zeros((data.shape[1], data.shape[1])) if A0 is None else A0
    return fit(indep_test, pval_alpha, None, A0, phases, iterate, debug)


def fit(indep_test, pval_alpha, completion_algorithm=None, A0=None, phases=['forward', 'backward', 'turning'], iterate=False, debug=0):
    """Run GES using a user defined score.

    Parameters
    ----------
    score_class : ges.DecomposableScore
        an instance of a class which inherits from
        ges.decomposable_score.DecomposableScore (or defines a
        local_score function and a p attribute, see
        ges.decomposable_score for more info).
    completion_algorithm : function, optional
        the "completion algorithm" used to go from PDAG to CPDAG after
        the application of each operator. If `None`, the algorithm
        used in the original GES paper is used.
    A0 : np.array, optional
        the initial CPDAG on which GES will run, where where A0[i,j]
        != 0 implies i -> j and A[i,j] != 0 & A[j,i] != 0 implies i -
        j. Defaults to the empty graph.
    phases : [{'forward', 'backward', 'turning'}*], optional
        which phases of the GES procedure are run, and in which
        order. Defaults to ['forward', 'backward', 'turning'].
    iterate : bool, default=False
        Indicates whether the given phases should be iterated more
        than once.
    debug : int, optional
        if larger than 0, debug are traces printed. Higher values
        correspond to increased verbosity.

    Returns
    -------
    estimate : np.array
        the adjacency matrix of the estimated CPDAG
    total_score : float
        the score of the estimate

    """
    # Select the desired phases
    if len(phases) == 0:
        raise ValueError("Must specify at least one phase")
    # Set the completion algorithm
    completion_algorithm = utils.pdag_to_cpdag if completion_algorithm is None else completion_algorithm
    # Unless indicated otherwise, initialize to the empty graph
    A0 = np.zeros((indep_test.data.shape[1], indep_test.data.shape[1])) if A0 is None else A0
    # GES procedure
    total_score = 0 # just a placeholder here; we only need local consistency; score_class.full_score(A0)
    A, score_change = A0, np.Inf
    visited_As = {''.join(map(str, A.astype(int).flatten()))}
    # Run each phase
    while True:
        start = time.time()
        last_total_score = total_score
        for phase in phases:
            if phase == 'forward':
                fun = forward_step
            elif phase == 'backward':
                fun = backward_step
            elif phase == 'turning':
                fun = turning_step
            else:
                raise ValueError('Invalid phase "%s" specified' % phase)
            print("\nGES %s phase start" % phase) if debug else None
            print("-------------------------") if debug else None
            tic = time.time()
            while True:
                if time.time() - tic > MAXTIMEPERPHASEPERNODE * A0.shape[0]: break
                valid_operators = fun(pval_alpha, A, indep_test, max(0, debug - 1))
                valid_operators = [(score_change, new_A) for score_change, new_A, x, y, H in valid_operators if score_change > 1e-6]
                valid_operators.sort(key=lambda op: op[0], reverse=True)
                # find the operator with the biggest positive score change with the unseen A
                updated_succesfully = False
                for score_change, new_A in valid_operators:
                    new_A = completion_algorithm(new_A)
                    new_A_str = ''.join(map(str, new_A.astype(int).flatten()))
                    if new_A_str not in visited_As:
                        visited_As.add(new_A_str)
                        A = new_A
                        total_score += score_change
                        updated_succesfully = True
                        break
                if not updated_succesfully:
                    break

            print("-----------------------") if debug else None
            print("GES %s phase end" % phase) if debug else None
            print("Total score: %0.4f" % total_score) if debug else None
            [print(row) for row in A] if debug else None
        if total_score <= last_total_score or not iterate:
            break

        end = time.time()
        print(f"time cost for one step is: {end-start}")

    # return A, total_score
    directed_edges = {(i, j) for i in range(A.shape[0]) for j in range(A.shape[0]) if A[i, j] != 0 and A[j, i] == 0}
    undirected_edges = {(i, j) for i in range(A.shape[0]) for j in range(i + 1, A.shape[0]) if A[i, j] != 0 and A[j, i] != 0}

    nodes = [GraphNode(f'X{1+i}') for i in range(A.shape[0])]
    general_graph = GeneralGraph(nodes=nodes)
    for i, j in directed_edges: general_graph.add_edge(Edge(nodes[i], nodes[j], Endpoint.TAIL, Endpoint.ARROW))
    for i, j in undirected_edges: general_graph.add_edge(Edge(nodes[i], nodes[j], Endpoint.TAIL, Endpoint.TAIL))

    return A, general_graph, total_score


def forward_step(pval_alpha, A, indep_test, debug=0):
    """
    Scores all valid insert operators that can be applied to the current
    CPDAG A, and applies the highest scoring one.

    Parameters
    ----------
    A : np.array
        the adjacency matrix of a CPDAG, where A[i,j] != 0 => i -> j
        and A[i,j] != 0 & A[j,i] != 0 => i - j.
    cache : DecomposableScore
        an instance of the score class, which computes the change in
        score and acts as cache.
    debug : int, optional
        if larger than 0, debug are traces printed. Higher values
        correspond to increased verbosity

    Returns
    -------
    score : float
        the change in score resulting from applying the highest
        scoring operator (note, can be smaller than 0).
    new_A : np.array
        the adjacency matrix of the PDAG resulting from applying the
        operator (not yet a CPDAG).

    """
    # Construct edge candidates (i.e. edges between non-adjacent nodes)
    fro, to = np.where((A + A.T + np.eye(len(A))) == 0)
    edge_candidates = list(zip(fro, to))
    # For each edge, enumerate and score all valid operators
    valid_operators = []
    print("  %d candidate edges" % len(edge_candidates)) if debug > 1 else None
    for (x, y) in edge_candidates:
        valid_operators += score_valid_insert_operators(pval_alpha, x, y, A, indep_test, debug=max(0, debug - 1))
    # Pick the edge/operator with the highest score
    return valid_operators
    # if len(valid_operators) == 0:
    #     print("  No valid insert operators remain") if debug else None
    #     return 0, A
    # else:
    #     scores = [op[0] for op in valid_operators]
    #     score, new_A, x, y, T = valid_operators[np.argmax(scores)]
    #     print("  Best operator: insert(%d, %d, %s) -> (%0.4f)" %
    #           (x, y, T, score)) if debug else None
    #     # print(new_A.astype(int))
    #     return score, new_A


def backward_step(pval_alpha, A, indep_test, debug=0):
    """
    Scores all valid delete operators that can be applied to the current
    CPDAG A, and applies the highest scoring one.

    Parameters
    ----------
    A : np.array
        the adjacency matrix of a CPDAG, where A[i,j] != 0 => i -> j
        and A[i,j] != 0 & A[j,i] != 0 => i - j.
    cache : DecomposableScore
        an instance of the score class, which computes the change in
        score and acts as cache.
    debug : int
        if larger than 0, debug are traces printed. Higher values
        correspond to increased verbosity

    Returns
    -------
    score : float
        the change in score resulting from applying the highest
        scoring operator (note, can be smaller than 0).
    new_A : np.array
        the adjacency matrix of the PDAG resulting from applying the
        operator (not yet a CPDAG).

    """
    # Construct edge candidates:
    #   - directed edges
    #   - undirected edges, counted only once
    fro, to = np.where(utils.only_directed(A))
    directed_edges = zip(fro, to)
    fro, to = np.where(utils.only_undirected(A))
    undirected_edges = filter(lambda e: e[0] > e[1], zip(fro, to))  # zip(fro,to)
    edge_candidates = list(directed_edges) + list(undirected_edges)
    assert len(edge_candidates) == utils.skeleton(A).sum() / 2
    # For each edge, enumerate and score all valid operators
    valid_operators = []
    print("  %d candidate edges" % len(edge_candidates)) if debug > 1 else None
    for (x, y) in edge_candidates:
        valid_operators += score_valid_delete_operators(pval_alpha, x, y, A, indep_test, debug=max(0, debug - 1))
    # Pick the edge/operator with the highest score
    return valid_operators
    # if len(valid_operators) == 0:
    #     print("  No valid delete operators remain") if debug else None
    #     return 0, A
    # else:
    #     scores = [op[0] for op in valid_operators]
    #     score, new_A, x, y, H = valid_operators[np.argmax(scores)]
    #     print("  Best operator: delete(%d, %d, %s) -> (%0.4f)" %
    #           (x, y, H, score)) if debug else None
    #     # print(new_A.astype(int))
    #     return score, new_A


def turning_step(pval_alpha, A, indep_test, debug=0):
    """
    Scores all valid turn operators that can be applied to the current
    CPDAG A, and applies the highest scoring one.

    Parameters
    ----------
    A : np.array
        the adjacency matrix of a CPDAG, where A[i,j] != 0 => i -> j
        and A[i,j] != 0 & A[j,i] != 0 => i - j.
    cache : DecomposableScore
        an instance of the score class, which computes the change in
        score and acts as cache.
    debug : int
        if larger than 0, debug are traces printed. Higher values
        correspond to increased verbosity

    Returns
    -------
    score : float
        the change in score resulting from applying the highest
        scoring operator (note, can be smaller than 0).
    new_A : np.array
        the adjacency matrix of the PDAG resulting from applying the
        operator (not yet a CPDAG).

    """
    # Construct edge candidates:
    #   - directed edges, reversed
    #   - undirected edges
    fro, to = np.where(A != 0)
    edge_candidates = list(zip(to, fro))
    # For each edge, enumerate and score all valid operators
    valid_operators = []
    print("  %d candidate edges" % len(edge_candidates)) if debug > 1 else None
    for (x, y) in edge_candidates:
        valid_operators += score_valid_turn_operators(x, y, A, indep_test, debug=max(0, debug - 1))
    # Pick the edge/operator with the highest score
    return valid_operators
    # if len(valid_operators) == 0:
    #     print("  No valid turn operators remain") if debug else None
    #     return 0, A
    # else:
    #     scores = [op[0] for op in valid_operators]
    #     score, new_A, x, y, C = valid_operators[np.argmax(scores)]
    #     print("  Best operator: turn(%d, %d, %s) -> (%0.4f)" % (x, y, C, score)) if debug else None
    #     # print(new_A.astype(int))
    #     return score, new_A

# --------------------------------------------------------------------
# Insert operator
#    1. definition in function insert
#    2. enumeration logic (to enumerate and score only valid
#    operators) function in score_valid_insert_operators


def insert(x, y, T, A):
    """
    Applies the insert operator:
      1) adds the edge x -> y
      2) for all t in T, orients the previously undirected edge t -> y

    Parameters
    ----------
    x : int
        the origin node (i.e. x -> y)
    y : int
        the target node
    T : iterable of ints
        a subset of the neighbors of y which are not adjacent to x
    A : np.array
        the current adjacency matrix

    Returns
    -------
    new_A : np.array
        the adjacency matrix resulting from applying the operator

    """
    # Check inputs
    T = sorted(T)
    if A[x, y] != 0 or A[y, x] != 0:
        raise ValueError("x=%d and y=%d are already connected" % (x, y))
    if len(T) == 0:
        pass
    elif not (A[T, y].all() and A[y, T].all()):
        raise ValueError("Not all nodes in T=%s are neighbors of y=%d" % (T, y))
    elif A[T, x].any() or A[x, T].any():
        raise ValueError("Some nodes in T=%s are adjacent to x=%d" % (T, x))
    # Apply operator
    new_A = A.copy()
    # Add edge x -> y
    new_A[x, y] = 1
    # Orient edges t - y to t -> y, for t in T
    new_A[T, y] = 1
    new_A[y, T] = 0
    return new_A


def score_valid_insert_operators(pval_alpha, x, y, A, indep_test, debug=0):
    """Generate and score all valid insert(x,y,T) operators involving the edge
    x-> y, and all possible subsets T of neighbors of y which
    are NOT adjacent to x.

    Parameters
    ----------
    x : int
        the origin node (i.e. x -> y)
    y : int
        the target node
    A : np.array
        the current adjacency matrix
    cache : instance of ges.scores.DecomposableScore
        the score cache to compute the score of the
        operators that are valid
    debug : int
        if larger than 0, debug are traces printed. Higher values
        correspond to increased verbosity

    Returns
    -------
    valid_operators : list of tuples
        a list of tubles, each containing a valid operator, its score
        and the resulting connectivity matrix

    """
    p = len(A)
    if A[x, y] != 0 or A[y, x] != 0:
        raise ValueError("x=%d and y=%d are already connected" % (x, y))
    # One-hot encode all subsets of T0, plus one extra column to mark
    # if they pass validity condition 2 (see below)
    T0 = sorted(utils.neighbors(y, A) - utils.adj(x, A))
    if len(T0) == 0:
        subsets = np.zeros((1, p + 1), dtype=bool)
    else:
        subsets = np.zeros((2**len(T0), p + 1), dtype=bool)
        subsets[:, T0] = utils.cartesian([np.array([False, True])] * len(T0), dtype=bool)
    valid_operators = []
    print("    insert(%d,%d) T0=" % (x, y), set(T0)) if debug > 1 else None
    while len(subsets) > 0:
        print("      len(subsets)=%d, len(valid_operators)=%d" %
              (len(subsets), len(valid_operators))) if debug > 1 else None
        # Access the next subset
        T = np.where(subsets[0, :-1])[0]
        passed_cond_2 = subsets[0, -1]
        subsets = subsets[1:]
        # Check that the validity conditions hold for T
        na_yxT = utils.na(y, x, A) | set(T)
        # Condition 1: Test that NA_yx U T is a clique
        cond_1 = utils.is_clique(na_yxT, A)
        if not cond_1:
            # Remove from consideration all other sets T' which
            # contain T, as the clique condition will also not hold
            supersets = subsets[:, T].all(axis=1)
            subsets = utils.delete(subsets, supersets, axis=0)
        # Condition 2: Test that all semi-directed paths from y to x contain a
        # member from NA_yx U T
        if passed_cond_2:
            # If a subset of T satisfied condition 2, so does T
            cond_2 = True
        else:
            # Check condition 2
            cond_2 = True
            for path in utils.semi_directed_paths(y, x, A):
                if len(na_yxT & set(path)) == 0:
                    cond_2 = False
                    break
            if cond_2:
                # If condition 2 holds for NA_yx U T, then it holds for all supersets of T
                supersets = subsets[:, T].all(axis=1)
                subsets[supersets, -1] = True
        print("      insert(%d,%d,%s)" % (x, y, T), "na_yx U T = ",
              na_yxT, "validity:", cond_1, cond_2) if debug > 1 else None
        # If both conditions hold, apply operator and compute its score
        if cond_1 and cond_2:
            # Apply operator
            new_A = insert(x, y, T, A)
            # Compute the change in score
            aux = na_yxT | utils.pa(y, A)

            ########## OPTION 1: to rescale rho and use BIC score ##########
            # which partial correlation would produce the same test statistics in the full observations case?
            # and then calculate the BIC score change by this rescaled rho and the full samplesize.
            #  1. for comparison between score changes: it is still comparing the test statistics given by fisherz.
            #     the higher test statistic corresponds to the smaller pvalue, i.e., more dependent.
            #  2. for thresholding between fisherz's p-value and alpha, we do not explicitly set an alpha.
            #     instead, this is handled by the BIC score change thresholding with zero (increases or not).
            #     empirically, fisherz with alpha=0.01 and BIC-change with L0-penalty=2 gives almost the same
            #     (samplesize, minimum_rho_for_rejection) curve for samplesize between 100 and 10000.
            #     see https://gist.github.com/MarkDana/c24ad7872515cbbc0b09282b8e2483fa for details.
            # print(indep_test.__class__.__name__)
            if indep_test.__class__.__name__ == 'KCI' or 'ZeroDel_KCI':
                pvalue = indep_test(x, y, aux - {x, y})
                score_change = pval_alpha - pvalue
            else:
                pval, rho, spsz = indep_test(x, y, aux - {x, y}, return_rho=True, return_sample_size=True)
                if spsz != indep_test.sample_size:
                    rho = fisherz_rho(fisherz_stat(rho, spsz), indep_test.sample_size)  # rescale samplesize to full observations
                score_change = bic_score_change(rho, indep_test.sample_size)

            ########## OPTION 2: to directly use FisherZ test statistics and compare with alpha. no concept of BIC. ##########
            # PVAL_ALPHA = 0.01
            # STAT_ALPHA = norm.ppf(1 - PVAL_ALPHA / 2)
            # score_change = stat - STAT_ALPHA
            # # the highest score_change corresponds to the smallest pvalue, i.e., the largest partial correlation
            # #   (given current y's parents' info, x increases the most info for y) in the same samplesize case;
            # # when samplesizes are not the same, fisherz's pval already takes care of it; just like
            # # in the original BIC score: score_change = -N * np.log(1 - rho**2) - np.log(N)/2, where rho is sample partial correlation
            # # here between score_change and rho, we use fisherz to handle the samplesize difference.
            # # note that we use the tau-consistency; i.e., see GES completely as a constraint-based method.
            # # here score_change does not have any global BIC meaning; it is just for local comparison: <pval_alpha then insert edge.

            ########## OPTION 3: to directly calculate BIC score change from rho and N. ##########
            # kinda like OPTION 1, in the sense that for fisherz(rho1, N1)==fisherz(rho2, N2), we also have BIC_change(rho1, N1)~=BIC_change(rho2, N2).

            print("        new: s(%d, %s) old: s(%d, %s) new-old= %0.6f" %
                  (y, aux | {x}, y, aux, score_change)) if debug > 1 else None
            # Add to the list of valid operators
            valid_operators.append((score_change, new_A, x, y, T))

            print("    insert(%d,%d,%s) -> %0.16f" % (x, y, T, score_change)) if debug else None
    # Return all the valid operators
    return valid_operators

# --------------------------------------------------------------------
# Delete operator
#    1. definition in function delete
#    2. enumeration logic (to enumerate and score only valid
#    operators) function in valid_delete_operators


def delete(x, y, H, A):
    """
    Applies the delete operator:
      1) deletes the edge x -> y or x - y
      2) for every node h in H
           * orients the edge y -> h
           * if the edge with x is undirected, orients it as x -> h

    Note that H must be a subset of the neighbors of y which are
    adjacent to x. A ValueError exception is thrown otherwise.

    Parameters
    ----------
    x : int
        the "origin" node (i.e. x -> y or x - y)
    y : int
        the "target" node
    H : iterable of ints
        a subset of the neighbors of y which are adjacent to x
    A : np.array
        the current adjacency matrix

    Returns
    -------
    new_A : np.array
        the adjacency matrix resulting from applying the operator

    """
    H = set(H)
    # Check inputs
    if A[x, y] == 0:
        raise ValueError("There is no (un)directed edge from x=%d to y=%d" % (x, y))
    # neighbors of y which are adjacent to x
    na_yx = utils.na(y, x, A)
    if not H <= na_yx:
        raise ValueError(
            "The given set H is not valid, H=%s is not a subset of NA_yx=%s" % (H, na_yx))
    # Apply operator
    new_A = A.copy()
    # delete the edge between x and y
    new_A[x, y], new_A[y, x] = 0, 0
    # orient the undirected edges between y and H towards H
    new_A[list(H), y] = 0
    # orient any undirected edges between x and H towards H
    n_x = utils.neighbors(x, A)
    new_A[list(H & n_x), x] = 0
    return new_A


def score_valid_delete_operators(pval_alpha, x, y, A, indep_test, debug=0):
    """Generate and score all valid delete(x,y,H) operators involving the edge
    x -> y or x - y, and all possible subsets H of neighbors of y which
    are adjacent to x.

    Parameters
    ----------
    x : int
        the "origin" node (i.e. x -> y or x - y)
    y : int
        the "target" node
    A : np.array
        the current adjacency matrix
    cache : instance of ges.scores.DecomposableScore
        the score cache to compute the score of the
        operators that are valid
    debug : int
        if larger than 0, debug are traces printed. Higher values
        correspond to increased verbosity

    Returns
    -------
    valid_operators : list of tuples
        a list of tubles, each containing a valid operator, its score
        and the resulting connectivity matrix

    """
    # Check inputs
    if A[x, y] == 0:
        raise ValueError("There is no (un)directed edge from x=%d to y=%d" % (x, y))
    # One-hot encode all subsets of H0, plus one column to mark if
    # they have already passed the validity condition
    na_yx = utils.na(y, x, A)
    H0 = sorted(na_yx)
    p = len(A)
    if len(H0) == 0:
        subsets = np.zeros((1, (p + 1)), dtype=bool)
    else:
        subsets = np.zeros((2**len(H0), (p + 1)), dtype=bool)
        subsets[:, H0] = utils.cartesian([np.array([False, True])] * len(H0), dtype=bool)
    valid_operators = []
    print("    delete(%d,%d) H0=" % (x, y), set(H0)) if debug > 1 else None
    while len(subsets) > 0:
        print("      len(subsets)=%d, len(valid_operators)=%d" %
              (len(subsets), len(valid_operators))) if debug > 1 else None
        # Access the next subset
        H = np.where(subsets[0, :-1])[0]
        cond_1 = subsets[0, -1]
        subsets = subsets[1:]
        # Check if the validity condition holds for H, i.e. that
        # NA_yx \ H is a clique.
        # If it has not been tested previously for a subset of H,
        # check it now
        if not cond_1 and utils.is_clique(na_yx - set(H), A):
            cond_1 = True
            # For all supersets H' of H, the validity condition will also hold
            supersets = subsets[:, H].all(axis=1)
            subsets[supersets, -1] = True
        # If the validity condition holds, apply operator and compute its score
        print("      delete(%d,%d,%s)" % (x, y, H), "na_yx - H = ",
              na_yx - set(H), "validity:", cond_1) if debug > 1 else None
        if cond_1:
            # Apply operator
            new_A = delete(x, y, H, A)
            # Compute the change in score
            aux = (na_yx - set(H)) | utils.pa(y, A) | {x}
            # print(x,y,H,"na_yx:",na_yx,"old:",aux,"new:", aux - {x})
            if indep_test.__class__.__name__ == 'KCI' or 'ZeroDel_KCI':
                pvalue = indep_test(x, y, aux - {x, y})
                score_change = -1. * (pval_alpha - pvalue)
            else:
                _, rho, spsz = indep_test(x, y, aux - {x, y}, return_rho=True, return_sample_size=True)
                if spsz != indep_test.sample_size:
                    rho = fisherz_rho(fisherz_stat(rho, spsz), indep_test.sample_size)  # rescale samplesize to full observations
                score_change = -1. * bic_score_change(rho, indep_test.sample_size)
            print("        new: s(%d, %s) old: s(%d, %s) new-old= %0.6f" %
                  (y, aux - {x}, y, aux, score_change)) if debug > 1 else None
            valid_operators.append((score_change, new_A, x, y, H))

            print("    delete(%d,%d,%s) -> %0.16f" % (x, y, H, score_change)) if debug else None
    # Return all the valid operators
    return valid_operators

# --------------------------------------------------------------------
# Turn operator
#    1. definition in function turn
#    2. enumeration logic (to enumerate and score only valid
#    operators) function in score_valid_insert_operators


def turn(x, y, C, A):
    """
    Applies the turning operator: For an edge x - y or x <- y,
      1) orients the edge as x -> y
      2) for all c in C, orients the previously undirected edge c -> y

    Parameters
    ----------
    x : int
        the origin node (i.e. x -> y)
    y : int
        the target node
    C : iterable of ints
        a subset of the neighbors of y
    A : np.array
        the current adjacency matrix

    Returns
    -------
    new_A : np.array
        the adjacency matrix resulting from applying the operator

    """
    # Check inputs
    if A[x, y] != 0 and A[y, x] == 0:
        raise ValueError("The edge %d -> %d is already exists" % (x, y))
    if A[x, y] == 0 and A[y, x] == 0:
        raise ValueError("x=%d and y=%d are not connected" % (x, y))
    if not C <= utils.neighbors(y, A):
        raise ValueError("Not all nodes in C=%s are neighbors of y=%d" % (C, y))
    if len({x, y} & C) > 0:
        raise ValueError("C should not contain x or y")
    # Apply operator
    new_A = A.copy()
    # Turn edge x -> y
    new_A[y, x] = 0
    new_A[x, y] = 1
    # Orient edges c -> y for c in C
    new_A[y, list(C)] = 0
    return new_A


def score_valid_turn_operators(x, y, A, indep_test, debug=0):
    """Generate and score all valid turn(x,y,C) operators that can be
    applied to the edge x <- y or x - y, iterating through the valid
    subsets C of neighbors of y.

    Parameters
    ----------
    x : int
        the origin node (i.e. orient x -> y)
    y : int
        the target node
    A : np.array
        the current adjacency matrix
    cache : instance of ges.scores.DecomposableScore
        the score cache to compute the score of the
        operators that are valid
    debug : int
        if larger than 0, debug are traces printed. Higher values
        correspond to increased verbosity

    Returns
    -------
    valid_operators : list of tuples
        a list of tubles, each containing a valid operator, its score
        and the resulting connectivity matrix

    """
    if A[x, y] != 0 and A[y, x] == 0:
        raise ValueError("The edge %d -> %d already exists" % (x, y))
    if A[x, y] == 0 and A[y, x] == 0:
        raise ValueError("x=%d and y=%d are not connected" % (x, y))
    # Different validation/scoring logic when the edge to be turned is
    # essential (x <- x) or not (x - y)
    if A[x, y] != 0 and A[y, x] != 0:
        return score_valid_turn_operators_undir(x, y, A, indep_test, debug=debug)
    else:
        return score_valid_turn_operators_dir(x, y, A, indep_test, debug=debug)


def score_valid_turn_operators_dir(x, y, A, indep_test, debug=0):
    """Logic for finding and scoring the valid turn operators that can be
    applied to the edge x <- y.

    Parameters
    ----------
    x : int
        the origin node (i.e. x -> y)
    y : int
        the target node
    A : np.array
        the current adjacency matrix
    cache : instance of ges.scores.DecomposableScore
        the score cache to compute the score of the
        operators that are valid
    debug : bool or string
        if debug traces should be printed (True/False). If a non-empty
        string is passed, traces are printed with the given string as
        prefix (useful for indenting the prints from a calling
        function)

    Returns
    -------
    valid_operators : list of tuples
        a list of tubles, each containing a valid operator, its score
        and the resulting connectivity matrix

    """
    # One-hot encode all subsets of T0, plus one extra column to mark
    # if they pass validity condition 2 (see below). The set C passed
    # to the turn operator will be C = NAyx U T.
    p = len(A)
    T0 = sorted(utils.neighbors(y, A) - utils.adj(x, A))
    if len(T0) == 0:
        subsets = np.zeros((1, p + 1), dtype=bool)
    else:
        subsets = np.zeros((2**len(T0), p + 1), dtype=bool)
        subsets[:, T0] = utils.cartesian([np.array([False, True])] * len(T0), dtype=bool)
    valid_operators = []
    print("    turn(%d,%d) T0=" % (x, y), set(T0)) if debug > 1 else None
    while len(subsets) > 0:
        print("      len(subsets)=%d, len(valid_operators)=%d" %
              (len(subsets), len(valid_operators))) if debug > 1 else None
        # Access the next subset
        T = np.where(subsets[0, :-1])[0]
        passed_cond_2 = subsets[0, -1]
        subsets = subsets[1:]  # update the list of remaining subsets
        # Check that the validity conditions hold for T
        C = utils.na(y, x, A) | set(T)
        # Condition 1: Test that C = NA_yx U T is a clique
        cond_1 = utils.is_clique(C, A)
        if not cond_1:
            # Remove from consideration all other sets T' which
            # contain T, as the clique condition will also not hold
            supersets = subsets[:, T].all(axis=1)
            subsets = utils.delete(subsets, supersets, axis=0)
        # Condition 2: Test that all semi-directed paths from y to x contain a
        # member from C U neighbors(x)
        if passed_cond_2:
            # If a subset of T satisfied condition 2, so does T
            cond_2 = True
        else:
            # otherwise, check condition 2
            cond_2 = True
            for path in utils.semi_directed_paths(y, x, A):
                if path == [y, x]:
                    pass
                elif len((C | utils.neighbors(x, A)) & set(path)) == 0:
                    cond_2 = False
                    break
            if cond_2:
                # If condition 2 holds for C U neighbors(x), that is,
                # for C = NAyx U T U neighbors(x), then it holds for
                # all supersets of T
                supersets = subsets[:, T].all(axis=1)
                subsets[supersets, -1] = True
        # If both conditions hold, apply operator and compute its score
        print("      turn(%d,%d,%s)" % (x, y, C), "na_yx =", utils.na(y, x, A),
              "T =", T, "validity:", cond_1, cond_2) if debug > 1 else None
        if cond_1 and cond_2:
            # Apply operator
            new_A = turn(x, y, C, A)
            # Compute the change in score
            # here we don't care about rejection/acceptance; we just compare between test statistics and see which is more dependent.
            _, rho1, spsz1 = indep_test(x, y, (utils.pa(y, A) | C) - {x, y}, return_rho=True, return_sample_size=True)
            _, rho2, spsz2 = indep_test(x, y, (utils.pa(x, A)) - {x, y}, return_rho=True, return_sample_size=True)
            score_change = fisherz_stat(rho1, spsz1) - fisherz_stat(rho2, spsz2)

            print("        score new-old = %0.6f, y=%d, C=%s" % (score_change, y, C)) if debug > 1 else None
            # Add to the list of valid operators
            valid_operators.append((score_change, new_A, x, y, C))

            print("    turn(%d,%d,%s) -> %0.16f" % (x, y, C, score_change)) if debug else None
    # Return all the valid operators
    return valid_operators


def score_valid_turn_operators_undir(x, y, A, indep_test, debug=0):
    """Logic for finding and scoring the valid turn operators that can be
    applied to the edge x - y.

    Parameters
    ----------
    x : int
        the origin node (i.e. x -> y)
    y : int
        the target node
    A : np.array
        the current adjacency matrix
    cache : instance of ges.scores.DecomposableScore
        the score cache to compute the score of the
        operators that are valid
    debug : bool or string
        if debug traces should be printed (True/False). If a non-empty
        string is passed, traces are printed with the given string as
        prefix (useful for indenting the prints from a calling
        function)

    Returns
    -------
    valid_operators : list of tuples
        a list of tubles, each containing a valid operator, its score
        and the resulting connectivity matrix

    """
    # Proposition 31, condition (ii) in GIES paper (Hauser & Bühlmann
    # 2012) is violated if:
    #   1. all neighbors of y are adjacent to x, or
    #   2. y has no neighbors (besides u)
    # then there are no valid operators.
    non_adjacents = list(utils.neighbors(y, A) - utils.adj(x, A) - {x})
    if len(non_adjacents) == 0:
        print("    turn(%d,%d) : ne(y) \\ adj(x) = Ø => stopping" % (x, y)) if debug > 1 else None
        return []
    # Otherwise, construct all the possible subsets which will satisfy
    # condition (ii), i.e. all subsets of neighbors of y with at least
    # one which is not adjacent to x
    p = len(A)
    C0 = sorted(utils.neighbors(y, A) - {x})
    subsets = np.zeros((2**len(C0), p + 1), dtype=bool)
    subsets[:, C0] = utils.cartesian([np.array([False, True])] * len(C0), dtype=bool)
    # Remove all subsets which do not contain at least one non-adjacent node to x
    to_remove = (subsets[:, non_adjacents] == False).all(axis=1)
    subsets = utils.delete(subsets, to_remove, axis=0)
    # With condition (ii) guaranteed, we now check conditions (i,iii)
    # for each subset
    valid_operators = []
    print("    turn(%d,%d) C0=" % (x, y), set(C0)) if debug > 1 else None
    while len(subsets) > 0:
        print("      len(subsets)=%d, len(valid_operators)=%d" %
              (len(subsets), len(valid_operators))) if debug > 1 else None
        # Access the next subset
        C = set(np.where(subsets[0, :])[0])
        subsets = subsets[1:]
        # Condition (i): C is a clique in the subgraph induced by the
        # chain component of y. Because C is composed of neighbors of
        # y, this is equivalent to C being a clique in A. NOTE: This
        # is also how it is described in Alg. 5 of the paper
        cond_1 = utils.is_clique(C, A)
        if not cond_1:
            # Remove from consideration all other sets C' which
            # contain C, as the clique condition will also not hold
            supersets = subsets[:, list(C)].all(axis=1)
            subsets = utils.delete(subsets, supersets, axis=0)
            continue
        # Condition (iii): Note that condition (iii) from proposition
        # 31 appears to be wrong in the GIES paper; instead we use the
        # definition of condition (iii) from Alg. 5 of the paper:
        # Let na_yx (N in the GIES paper) be the neighbors of Y which
        # are adjacent to X. Then, {x,y} must separate C and na_yx \ C
        # in the subgraph induced by the chain component of y,
        # i.e. all the simple paths from one set to the other contain
        # a node in {x,y}.
        subgraph = utils.induced_subgraph(utils.chain_component(y, A), A)
        na_yx = utils.na(y, x, A)
        if not utils.separates({x, y}, C, na_yx - C, subgraph):
            continue
        # At this point C passes both conditions
        #   Apply operator
        new_A = turn(x, y, C, A)
        #   Compute the change in score
        _, rho1, spsz1 = indep_test(x, y, (utils.pa(y, A) | C) - {x, y}, return_rho=True, return_sample_size=True)
        _, rho2, spsz2 = indep_test(x, y, (utils.pa(x, A) | (C & na_yx)) - {x, y}, return_rho=True, return_sample_size=True)
        score_change = fisherz_stat(rho1, spsz1) - fisherz_stat(rho2, spsz2)

        print("        score new-old = %0.6f, y=%d, C=%s" % (score_change, y, C)) if debug > 1 else None
        #   Add to the list of valid operators
        valid_operators.append((score_change, new_A, x, y, C))
        print("    turn(%d,%d,%s) -> %0.16f" % (x, y, C, score_change)) if debug else None

    # Return all valid operators
    return valid_operators
