import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import anndata
import random
# from conditional_independence import hsic_test 
from copy import deepcopy
from itertools import combinations
from causallearn.utils.cit import CIT
from causallearn.utils.PCUtils import SkeletonDiscovery
import jpype.imports
import argparse
import pydotplus
import os
import igraph
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix

def count_precision_recall_f1(tp, fp, fn):
    """Compute precision, recall and F1 from raw confusion counts.

    Parameters:
    tp (int): Number of true positives.
    fp (int): Number of false positives.
    fn (int): Number of false negatives.

    Returns:
    tuple: ``(precision, recall, f1)``. Precision is ``None`` when
    ``tp + fp == 0`` and recall is ``None`` when ``tp + fn == 0``. F1 is
    ``None`` if either of them is ``None``, ``0.0`` if either is zero, and
    their harmonic mean otherwise.
    """
    predicted = tp + fp
    actual = tp + fn
    precision = None if predicted == 0 else float(tp) / predicted
    recall = None if actual == 0 else float(tp) / actual

    if precision is None or recall is None:
        return precision, recall, None
    if precision == 0 or recall == 0:
        return precision, recall, 0.0
    return precision, recall, float(2 * precision * recall) / (precision + recall)


def pag_metrics(y_pred, y_true):
    """
    Score a predicted matrix against the ground truth as multi-class labels.

    Both matrices are flattened and compared element by element, treating
    each entry value (e.g. 0, 1, 2) as a class label.

    Parameters:
    y_pred (numpy.ndarray): Predicted matrix.
    y_true (numpy.ndarray): Ground-truth matrix with the same number of entries.

    Returns:
    dict: ``{'precision': ..., 'f1': ..., 'recall': ...}`` where 'f1' and
    'recall' are macro-averaged over classes. NOTE: the value stored under
    'precision' is the overall accuracy (``accuracy_score``), not a
    precision score; the key name is kept unchanged for existing callers.
    """
    truth = y_true.flatten()
    guess = y_pred.flatten()
    return {
        'precision': accuracy_score(truth, guess),
        'f1': f1_score(truth, guess, average='macro'),
        'recall': recall_score(truth, guess, average='macro'),
    }


def count_dag_accuracy(B_bin_true, B_bin_est):
    """Compare an estimated adjacency matrix with the true one.

    Non-zero entries of either matrix are treated as edges; an entry and its
    transposed counterpart are treated as the two orientations of one edge.

    Parameters:
    B_bin_true (numpy.ndarray): Ground-truth square adjacency matrix.
    B_bin_est (numpy.ndarray): Estimated square adjacency matrix.

    Returns:
    dict: ``{'f1', 'precision', 'recall', 'shd'}`` where
    - true positives are predicted edges present in the truth with the same orientation,
    - false positives are predicted edges that are reversed or absent from the true skeleton,
    - false negatives are true edges not predicted with the same orientation,
    - ``shd`` (structural Hamming distance) counts extra and missing
      undirected edges plus reversed edges.
    precision/recall/f1 come from ``count_precision_recall_f1`` and may be ``None``.
    """
    # Linear indices of the non-zero entries (edges).
    pred = np.flatnonzero(B_bin_est)
    cond = np.flatnonzero(B_bin_true)
    cond_reversed = np.flatnonzero(B_bin_true.T)
    # Edge present in either orientation. np.unique keeps the
    # ``assume_unique=True`` contract of setdiff1d below even when the true
    # graph contains both i->j and j->i.
    cond_skeleton = np.unique(np.concatenate([cond, cond_reversed]))

    true_pos = np.intersect1d(pred, cond, assume_unique=True)
    # Predicted edges absent from the true skeleton in both orientations.
    false_pos = np.setdiff1d(pred, cond_skeleton, assume_unique=True)
    # Predicted edges that exist in the truth only with the opposite orientation.
    extra = np.setdiff1d(pred, cond, assume_unique=True)
    reverse = np.intersect1d(extra, cond_reversed, assume_unique=True)

    # Structural Hamming distance on the symmetrised lower triangles.
    pred_lower = np.flatnonzero(np.tril(B_bin_est + B_bin_est.T))
    cond_lower = np.flatnonzero(np.tril(B_bin_true + B_bin_true.T))
    extra_lower = np.setdiff1d(pred_lower, cond_lower, assume_unique=True)
    missing_lower = np.setdiff1d(cond_lower, pred_lower, assume_unique=True)
    shd = len(extra_lower) + len(missing_lower) + len(reverse)

    false_neg = np.setdiff1d(cond, true_pos, assume_unique=True)
    precision, recall, f1 = count_precision_recall_f1(tp=len(true_pos),
                                                      fp=len(reverse) + len(false_pos),
                                                      fn=len(false_neg))
    return {'f1': f1, 'precision': precision, 'recall': recall, 'shd': shd}


def get_adjSet(i, G, n_node):
    """Return, in ascending order, every node linked to ``i`` in either direction.

    Node ``j`` (scanning ``0 .. n_node - 1``) is included when
    ``G[i][j] == 1`` or ``G[j][i] == 1``.
    """
    return [j for j in range(n_node) if G[i][j] == 1 or G[j][i] == 1]
def get_adj_ij(i, j, G, n_node):
    """Return the nodes ``k`` (ascending) with ``G[i][k] == 1`` and ``G[k][j] == 1``.

    These are the middle nodes of the length-2 paths i -> k -> j in the
    adjacency matrix ``G``; only indices ``0 .. n_node - 1`` are scanned.
    """
    # ``and`` rather than ``&``: ``&`` binds tighter than ``==``, which made the
    # test the chained comparison ``G[i][k] == (1 & G[k][j]) == 1`` - accepting
    # odd marks such as 3 and raising TypeError on float matrices.
    return [k for k in range(n_node) if G[i][k] == 1 and G[k][j] == 1]
def fisher_z_test(i, j, K, sample, result, alpha=0.05, method="kci"):
    """Test whether variables ``i`` and ``j`` are independent given ``K``.

    Despite the name, this runs causallearn's ``CIT`` with ``method``, which
    defaults to ``"kci"``, not a Fisher-z test.

    Parameters:
    i, j (int): Column indices of ``sample`` to test.
    K (list of int): Conditioning column indices.
    sample (numpy.ndarray): Data matrix handed to ``CIT``.
    result (list): Log; one entry ``['<i>_<j>_<K>___<p-value>']`` is appended.
    alpha (float): Significance level; independence is accepted when
        ``p-value >= alpha``. Defaults to 0.05.
    method (str): ``CIT`` method name. Defaults to ``"kci"``.

    Returns:
    bool: True if independence is accepted.
    """
    ci_test = CIT(sample, method)
    p_value = ci_test(i, j, K)
    result.append([f'{i}_{j}_{K}___{p_value}'])
    return bool(p_value >= alpha)

def skeleton(n_node, sample):
    """Estimate an undirected skeleton with a PC-style adjacency search.

    Starts from a complete graph over ``n_node`` variables. For conditioning
    set sizes ``level = 0, 1, 2, ...`` it tests each remaining edge i - j
    (i < j) against every size-``level`` subset of the current neighbours of
    ``i`` (excluding ``j``). It removes the edge at the first subset for which
    ``fisher_z_test`` accepts independence. The search ends once no edge has
    enough neighbours left for the current level.

    Parameters:
    n_node (int): Number of variables.
    sample (numpy.ndarray): Data matrix passed unchanged to ``fisher_z_test``.

    Returns:
    C (numpy.ndarray): (n_node, n_node) symmetric 0/1 adjacency with zero diagonal.
    S (list): ``S[i][j]`` is the separating set that removed edge i - j, else ``[]``.
    CI_result (list): Log entries appended by ``fisher_z_test``.
    """
    C = np.ones((n_node, n_node))
    np.fill_diagonal(C, 0)
    S = [[[] for _ in range(n_node)] for _ in range(n_node)]

    # Unordered pairs (i, j) with j > i; j runs downward to keep the original test order.
    pairs = [(i, j) for i in range(n_node) for j in range(n_node - 1, i, -1)]

    CI_result = []
    level = 0
    while True:
        any_candidate = False
        for i, j in pairs:
            adj_set = get_adjSet(i, C, n_node)
            if C[i][j] != 1 or len(adj_set) < level:
                continue
            any_candidate = True
            adj_set.remove(j)
            # NOTE(review): only neighbours of i are tried as conditioning sets;
            # textbook PC also tries subsets of adj(j) minus {i}.
            for K in combinations(adj_set, level):
                if fisher_z_test(i, j, list(K), sample, CI_result):
                    C[i][j] = 0
                    C[j][i] = 0
                    S[i][j] = list(K)
                    S[j][i] = list(K)
                    # First separating set found: further (expensive) tests
                    # for this pair are pointless and would overwrite S.
                    break
        if not any_candidate:
            break
        level += 1

    return C, S, CI_result

def dfs_iterative(matrix, start, end):
    """Depth-first search for one path from ``start`` to ``end``.

    Any non-zero ``matrix[u][v]`` is treated as an edge u -> v. Nodes are
    marked visited when popped (not when pushed), and the most recently
    pushed branch is explored first.

    Returns:
    list or None: the node list of the first path popped that ends at
    ``end`` (``[start]`` when they coincide), or ``None`` if none exists.
    """
    seen = [False for _ in matrix]
    frontier = [(start, [start])]
    while frontier:
        current, trail = frontier.pop()
        if current == end:
            return trail
        seen[current] = True
        frontier.extend(
            (nxt, trail + [nxt])
            for nxt, weight in enumerate(matrix[current])
            if weight != 0 and not seen[nxt]
        )
    return None

def find_nodes_on_paths(matrix, i, j):
    """Return one path from ``i`` to ``j`` found by ``dfs_iterative``, or ``None``.

    NOTE: this module defines ``find_nodes_on_paths`` again further down,
    returning a set of nodes instead. That later definition replaces this one
    once the module has finished executing.
    """
    path = dfs_iterative(matrix, start=i, end=j)
    return path

def has_path_length_2(adj_matrix, node_a, node_b):
    """
    Check whether some node C gives a two-edge path A -> C -> B.

    An edge u -> v means ``adj_matrix[u, v] == 1``.

    Parameters:
    - adj_matrix: numpy array (adjacency matrix of the graph)
    - node_a: int (index of node A)
    - node_b: int (index of node B)

    Returns:
    - ``(True, [c])`` for the lowest-indexed middle node ``c``, or
      ``(False, [])`` if no such node exists. Only the first match is reported.
    """
    for mid in range(adj_matrix.shape[0]):
        if adj_matrix[node_a, mid] == 1 and adj_matrix[mid, node_b] == 1:
            return True, [mid]
    return False, []

def get_adjSet(i, G, n_node):
    """List the nodes (ascending) sharing an edge with ``i`` in either direction.

    A node ``j`` in ``0 .. n_node - 1`` qualifies when ``G[i][j] == 1`` or
    ``G[j][i] == 1``. This module contains an identical definition earlier;
    this one is the binding in effect after import.
    """
    return list(filter(lambda node: G[i][node] == 1 or G[node][i] == 1, range(n_node)))

def find_all_paths_and_colliders(matrix, start_node, end_node):
    """
    Enumerate every simple path with at least two edges between two nodes.

    A step from u to v is allowed when ``matrix[u][v] > 0``. For every path
    found, the node sequence and the sequence of traversed matrix values
    ("weights") are recorded. Each interior node ``c`` whose two path
    neighbours both have a positive entry towards it
    (``matrix[prev][c] > 0`` and ``matrix[next][c] > 0``) is reported as a
    collider.

    Parameters:
    - matrix: Square adjacency/mark matrix (nested sequence or 2-D array)
    - start_node: Starting node
    - end_node: Ending node

    Returns:
    - paths: list of ``(node_sequence, weight_sequence)`` tuples; a direct
      one-edge connection start -> end is not included
    - all_nodes: set of every node appearing on any returned path
    - all_colliders: set of every collider found on any returned path
    """
    n = len(matrix)
    paths = []
    all_nodes = set()
    all_colliders = set()

    def dfs(current, target, visited, path_nodes, path_weights):
        if current == target:
            # The target is in ``visited``, so no longer walk can come back to
            # it; stop here instead of exploring past it. Only paths with
            # >= 3 nodes (>= 2 edges) are kept.
            if len(path_nodes) >= 3:
                for pos in range(1, len(path_nodes) - 1):  # interior nodes only
                    prev_node = path_nodes[pos - 1]
                    curr_node = path_nodes[pos]
                    next_node = path_nodes[pos + 1]
                    if matrix[prev_node][curr_node] > 0 and matrix[next_node][curr_node] > 0:
                        all_colliders.add(curr_node)
                all_nodes.update(path_nodes)
                paths.append((path_nodes[:], path_weights[:]))
            return

        for neighbor in range(n):
            if matrix[current][neighbor] > 0 and neighbor not in visited:
                visited.add(neighbor)
                path_nodes.append(neighbor)
                path_weights.append(matrix[current][neighbor])

                dfs(neighbor, target, visited, path_nodes, path_weights)

                # Backtrack.
                path_nodes.pop()
                path_weights.pop()
                visited.remove(neighbor)

    dfs(start_node, end_node, {start_node}, [start_node], [])
    return paths, all_nodes, all_colliders



def causal_validate_path(path, start_set, end_set, rules):
    """
    Validate a mark sequence as a causal inducing path.

    Parameters:
    - path: Sequence of edge marks along the path (e.g. [1, 2, 6, 2])
    - start_set: Allowed first marks (e.g. {1, 3})
    - end_set: Allowed last marks (e.g. {2, 4})
    - rules: Dict mapping a mark to the set of marks allowed to follow it
             (e.g. {-2: {-2, 2, 4, 6}, 2: {2, 6}, 4: {2, 6}, 6: {-2, 2, 4, 6}});
             marks that are not keys put no constraint on their successor

    Returns:
    - is_valid: True if the path satisfies all constraints, False otherwise
    - specific_indices: List of tuples (i, i+1) where path[i] is in {6, -2, 3}
                        and path[i+1] is in {4, -2}. A two-mark path is valid
                        only as [3, 4] and then yields [(0, 1)]. If the path
                        is rejected part-way through the scan, the pairs
                        collected so far are returned together with False.
    """
    specific_indices = []

    # An empty sequence cannot be a path (and path[0] would raise).
    if len(path) == 0:
        return False, specific_indices
    if path[0] not in start_set:
        return False, specific_indices
    if path[-1] not in end_set:
        return False, specific_indices

    # Note ``len(path) == 2``, not ``len(path == 2)``: the latter raises
    # TypeError for a list and is always truthy for a non-empty numpy array.
    if len(path) == 2:
        if path[0] == 3 and path[1] == 4:
            specific_indices.append((0, 1))
            return True, specific_indices
        return False, specific_indices

    # The second mark must be compatible with the first.
    if path[0] == 1 and path[1] not in {2, 6}:
        return False, specific_indices
    if path[0] == 3 and path[1] not in {-2, 2, 4, 6}:
        return False, specific_indices

    # The second-to-last mark must be compatible with the last.
    if path[-1] == 2 and path[-2] not in {2, 4, 6}:
        return False, specific_indices
    if path[-1] == 4 and path[-2] not in {-2, 6}:
        return False, specific_indices

    for i in range(len(path) - 1):
        current_weight = path[i]
        next_weight = path[i + 1]

        if current_weight in rules and next_weight not in rules[current_weight]:
            return False, specific_indices

        if current_weight in {6, -2, 3} and next_weight in {4, -2}:
            specific_indices.append((i, i + 1))

    return True, specific_indices

def selection_validate_path(path, start_set, end_set, rules):
    """
    Validate a mark sequence as a selection inducing path.

    Parameters:
    - path: Sequence of edge marks along the path (e.g. [1, 2, -3])
    - start_set: Allowed first marks (e.g. {1, 3})
    - end_set: Allowed last marks (e.g. {-1, -3})
    - rules: Dict mapping a mark to the set of marks allowed to follow it
             (e.g. {-2: {-2, 2, 4, 6}, 2: {2, 6}, 4: {2, 6}, 6: {-2, 2, 4, 6}}).
             Applied to every transition except the one into the last mark,
             which is checked by the explicit end-mark rules instead.

    Returns:
    - is_valid: True if the path satisfies all constraints, False otherwise
    - effect_specific_indices: tuples (i, i+1) with path[i] in {6, -2, 3}
                               and path[i+1] in {4, -2}
    - cause_specific_indices: tuples (i, i+1) with path[i] in {1, 2, 4}
                              and path[i+1] in {2, -3, 6}
    A two-mark path is valid only as [1, -3] (cause indices [(0, 1)]) or
    [3, -1] (effect indices [(0, 1)]). If a longer path is rejected part-way
    through the scan, the indices collected so far are returned with False.
    """
    effect_specific_indices = []
    cause_specific_indices = []

    # An empty sequence cannot be a path (and path[0] would raise).
    if len(path) == 0:
        return False, effect_specific_indices, cause_specific_indices
    if path[0] not in start_set:
        return False, effect_specific_indices, cause_specific_indices
    if path[-1] not in end_set:
        return False, effect_specific_indices, cause_specific_indices

    # ``len(path) == 2``, not ``len(path == 2)`` (TypeError for lists).
    if len(path) == 2:
        if path[0] == 1 and path[1] == -3:
            cause_specific_indices.append((0, 1))
            return True, effect_specific_indices, cause_specific_indices
        if path[0] == 3 and path[1] == -1:
            effect_specific_indices.append((0, 1))
            return True, effect_specific_indices, cause_specific_indices
        # Every other two-mark path is rejected explicitly instead of
        # implicitly returning None, which callers cannot unpack.
        return False, effect_specific_indices, cause_specific_indices

    # The second mark must be compatible with the first.
    if path[0] == 1 and path[1] not in {2, 6}:
        return False, effect_specific_indices, cause_specific_indices
    if path[0] == 3 and path[1] not in {-2, 2, 4, 6}:
        return False, effect_specific_indices, cause_specific_indices

    # The second-to-last mark must be compatible with the last.
    if path[-1] == -1 and path[-2] not in {-2, 6}:
        return False, effect_specific_indices, cause_specific_indices
    if path[-1] == -3 and path[-2] not in {-2, 2, 4, 6}:
        return False, effect_specific_indices, cause_specific_indices

    last_transition = len(path) - 2
    for i in range(len(path) - 1):
        current_weight = path[i]
        next_weight = path[i + 1]

        # The step into the final (negative) mark was validated just above.
        # The rules dict used in this module has no -1/-3 successors, so
        # applying it there would reject every path longer than two marks.
        if (i != last_transition and current_weight in rules
                and next_weight not in rules[current_weight]):
            return False, effect_specific_indices, cause_specific_indices

        if current_weight in {6, -2, 3} and next_weight in {4, -2}:
            effect_specific_indices.append((i, i + 1))
        if current_weight in {1, 2, 4} and next_weight in {2, -3, 6}:
            cause_specific_indices.append((i, i + 1))

    return True, effect_specific_indices, cause_specific_indices

def correct_inducing_path(matrix):
    """Remove edges of a mark matrix that a validated inducing path may explain.

    Collects directed pairs (``matrix[a][b] == 1`` and ``matrix[b][a] == 0``,
    stored as ``[a, b]``) and selection pairs (both entries equal to 5). For
    each pair it enumerates paths with ``find_all_paths_and_colliders`` and
    validates their mark sequences with ``causal_validate_path`` /
    ``selection_validate_path``. It then runs a ``"kci"`` ``CIT`` test on data
    read from the module-level ``data_final`` dict (keys ``'per_<node>'``).
    When the returned p-value is ``> 0``, ``matrix[a][b]`` is set to 0.

    Parameters:
    matrix (numpy.ndarray): Square mark matrix (``.shape`` is used). It is
        modified in place, including while later pairs are still being processed.

    Returns:
    numpy.ndarray: the same ``matrix`` object.
    """
    causal_s = {1,3}
    causal_e = {2,4}
    selection_s = {1,3}
    selection_e = {-1,-3}
    # Allowed successor marks, passed to both path validators.
    rules = {-2: {-2,2, 4, 6}, 2: {2, 6}, 4: {2, 6}, 6: {-2, 2, 4, 6}}
    n_nodes = matrix.shape[0]
    causal = []
    selection = []
    # Directed pairs are stored as [tail, head]; 5/5 pairs as selection pairs.
    for i in range(n_nodes):
        for j in range(i+1, n_nodes):
            if matrix[i][j] == 1 and matrix[j][i]==0:
                causal.append([i,j])
            elif matrix[j][i] ==1 and matrix[i][j] ==0:
                causal.append([j,i])
            elif matrix[i][j] == 5 and matrix[j][i]==5:
                selection.append([i,j])
    for k in range(len(causal)):
        paths, all_nodes, all_colliders = find_all_paths_and_colliders(matrix, causal[k][0], causal[k][1])
        if len(paths) != 0:
            for m, (node_seq, mark_seq) in enumerate(paths):
                inducing_path, indices = causal_validate_path( mark_seq,causal_s, causal_e, rules)
                if inducing_path:
                    if len(indices) != 0:
                        node_index = node_seq[indices[0][1]]
                        # NOTE(review): this ``break`` leaves the path loop as soon as
                        # indices are found, so the CI test below only runs when
                        # ``indices`` is empty - and then ``node_index`` is unset
                        # (NameError) or stale from an earlier pair. Looks misplaced;
                        # confirm intent.
                        break
                    cause = causal[k][0]
                    i = node_index
                    j = causal[k][1]
                    # NOTE(review): ``c_set`` aliases ``all_nodes``; the removals below
                    # also shrink ``all_nodes`` for later paths of this pair.
                    c_set = all_nodes
                    if c_set is not None:
                        # Conditioning set: nodes of other paths that are neither on
                        # this path nor colliders.
                        for n in node_seq:
                            if n in c_set:
                                c_set.remove(n)
                        for m in all_colliders:
                            if m in c_set:
                                c_set.remove(m)
                        assert cause not in c_set
                        assert j not in c_set
                    c_set = list(c_set)
                    # Columns 0, 1, 2 = node i, node j, last column of the data set.
                    data_i = data_final[f'per_{i}'][:,[i,j,-1]]
                    # NOTE(review): ``data_j`` is never used.
                    data_j = data_final[f'per_{j}'][:,[i,j,-1]]
                    data_p_i = np.concatenate((data_i, data_final[f'per_{i}'][:,c_set]), axis=1)
                    # CIT_obj = CIT(data_p_j, "kci")
                    g_adj = [i for i in range(3,data_p_i.shape[1])]
                    CIT_obi = CIT(data_p_i,"kci")
                    # Tests column 1 (node j) against column 2 (the last data column)
                    # given the c_set columns; column 0 (node i) is not tested.
                    # NOTE(review): confirm these are the intended variables.
                    Upi_value = CIT_obi(1,2, g_adj)
                    # NOTE(review): a p-value is essentially always > 0, so the edge is
                    # removed almost unconditionally; presumably a significance level
                    # was intended - confirm.
                    if Upi_value > 0:
                        matrix[causal[k][0]][causal[k][1]] = 0
    # NOTE(review): find_all_paths_and_colliders only follows entries > 0, so
    # mark sequences never contain the negative end marks in ``selection_e``;
    # as written no selection path can pass validation - confirm.
    for k in range(len(selection)):
        paths, all_nodes, all_colliders = find_all_paths_and_colliders(matrix, selection[k][0], selection[k][1])
        if len(paths) != 0:
            for m, (node_seq, mark_seq) in enumerate(paths):
                inducing_path, effect_indices, cause_indices = selection_validate_path(mark_seq, selection_s, selection_e, rules)
                if inducing_path:
                    if len(effect_indices) != 0:
                        # NOTE(review): unlike the causal branch this is a position in
                        # ``mark_seq``, not ``node_seq[...]`` - confirm it is meant to be
                        # used as a node index.
                        node_index = effect_indices[0][1]
                        selection_1 = selection[k][0]
                        i = node_index
                        j = selection[k][1]
                    elif len(cause_indices) != 0:
                        selection_1 = selection[k][1]
                        node_index = cause_indices[0][1]
                        i = node_index
                        j = selection[k][0]
                    else:
                        break
                    # NOTE(review): aliases ``all_nodes`` as in the causal branch.
                    c_set = all_nodes
                    if c_set is not None:
                        for n in node_seq:
                            if n in c_set:
                                c_set.remove(n)
                        for m in all_colliders:
                            if m in c_set:
                                c_set.remove(m)
                        assert selection_1 not in c_set
                        assert j not in c_set
                    c_set = list(c_set)
                    data_i = data_final[f'per_{i}'][:,[i,j,-1]]
                    data_j = data_final[f'per_{j}'][:,[i,j,-1]]
                    data_p_i = np.concatenate((data_i, data_final[f'per_{i}'][:,c_set]), axis=1)
                    # CIT_obj = CIT(data_p_j, "kci")
                    g_adj = [i for i in range(3,data_p_i.shape[1])]
                    CIT_obi = CIT(data_p_i,"kci")
                    Upi_value = CIT_obi(1,2, g_adj)
                    if Upi_value > 0:
                        matrix[selection[k][0]][selection[k][1]] = 0
                        # if [selection[k][0], selection[k][1]] in result['selection']:
                        #     result['selection'].remove([selection[k][0], selection[k][1]])
                        # if [selection[k][1], selection[k][0]] in result['selection']:
                        #     result['selection'].remove(selection[k][1], selection[k][0])
                # NOTE(review): at this indentation the path loop always stops after
                # the first path, valid or not.
                break
    return matrix

def dag_transformation(dag):
    """Collapse edge-mark codes of ``dag`` into 0/1 entries, in place.

    Applied to every entry of the leading ``dag.shape[0]`` rows and columns:
    negative -> 0, 2 -> 1, 3 -> 1, 4 -> 0, 5 -> 0, 6 -> 1. Any other value
    (e.g. 0 and 1) is left unchanged. The same array object is returned.
    """
    recode = {2: 1, 3: 1, 4: 0, 5: 0, 6: 1}
    size = dag.shape[0]
    for row in range(size):
        for col in range(size):
            mark = dag[row][col]
            if mark < 0:
                dag[row][col] = 0
            elif mark in recode:
                dag[row][col] = recode[mark]
    return dag



def transform_dag_matrix(D, latent_pairs, selection_pairs):
    """
    Transform a matrix representing a DAG based on latent and selection pairs.

    Parameters:
    - D: Square 2-D matrix; ``D[i][j] == 1`` is read as an edge i -> j
    - latent_pairs: List of node groups; every pair of nodes inside one group
      is treated as a latent pair
    - selection_pairs: List of 2-element node pairs under selection

    Returns:
    - A new integer numpy array (``D`` is not modified). It starts at -1
      everywhere (no edge, including the diagonal). Each edge i -> j of D
      sets ``D_new[i][j] = 0`` and ``D_new[j][i] = 1``; writes happen in
      row-major order, so for a 2-cycle the later write wins. Then, for each
      pair i < j, with "x -> a/b" meaning that ``D_new[i][j] == x`` becomes
      ``D_new[i][j] = a, D_new[j][i] = b``:
        * latent and selection: 2/2
        * latent only: 0 -> 2/1, 1 -> 1/2, no edge -> 1/1
        * selection only: 0 -> 0/2, anything else -> 0/0
    """
    n = len(D)
    D_new = np.full((n, n), -1)  # -1 = no edge

    # Symmetric lookup sets of node pairs.
    latent_set = set()
    for group in latent_pairs:
        for a, b in combinations(group, 2):
            latent_set.add((a, b))
            latent_set.add((b, a))

    selection_set = set()
    for a, b in selection_pairs:
        selection_set.add((a, b))
        selection_set.add((b, a))

    # Copy edge information from the original matrix.
    for i in range(n):
        for j in range(n):
            if i != j and D[i][j] == 1:
                D_new[i][j] = 0
                D_new[j][i] = 1

    # Apply latent / selection rules once per unordered pair (i < j).
    # At this point D_new[i][j] can only be -1, 0 or 1.
    for i, j in combinations(range(n), 2):
        if (i, j) in latent_set and (i, j) in selection_set:
            D_new[i][j] = 2
            D_new[j][i] = 2
        elif (i, j) in latent_set:
            if D_new[i][j] == 0:  # edge i -> j
                D_new[i][j] = 2
                D_new[j][i] = 1
            elif D_new[i][j] == 1:  # edge j -> i
                D_new[i][j] = 1
                D_new[j][i] = 2
            else:  # no edge
                D_new[i][j] = 1
                D_new[j][i] = 1
        elif (i, j) in selection_set:
            if D_new[i][j] == 0:  # edge i -> j
                D_new[i][j] = 0
                D_new[j][i] = 2
            else:
                # NOTE(review): a reverse edge j -> i (D_new[i][j] == 1) is
                # cleared to 0/0 here together with "no edge", unlike the latent
                # rule which keeps its orientation - confirm this is intended.
                D_new[i][j] = 0
                D_new[j][i] = 0

    return D_new


def GISL_transform_dag_matrix(D, latent_pairs, selection_pairs):
    """
    Re-encode an adjacency matrix as edge marks and overlay latent/selection pairs.

    Works on a deep copy of ``D``; ``D`` itself is not modified.

    Parameters:
    - D: Square 2-D matrix (nested lists or array); ``D[i][j] == 1`` is read
      as an edge i -> j
    - latent_pairs: List of node groups; every pair of members of one group
      is a latent pair
    - selection_pairs: List of 2-element node pairs under selection

    Returns:
    - A matrix of the same type as ``D``. First every entry is re-encoded
      from ``D``:
        * ``D[i][j] == D[j][i] == 0``: both -1
        * ``D[i][j] == 1`` and ``D[j][i] == 0``: ``[i][j] = 0``, ``[j][i] = 1``
        * ``D[i][j] == D[j][i] == 1``: both 2 (so a diagonal 1 becomes 2)
        * any other diagonal value becomes -1; other off-diagonal values are copied
      Then, for each pair i < j:
        * latent and selection: both 2
        * latent only: ``[i][j] == 0`` -> 2/1; ``[i][j], [j][i] == 1, 0`` -> 1/2
        * selection only: both 2 unless ``[i][j]`` is already 2
    """
    size = len(D)
    marks = deepcopy(D)

    latent_set = set()
    for group in latent_pairs:
        for u, v in combinations(group, 2):
            latent_set.update({(u, v), (v, u)})

    selection_set = set()
    for u, v in selection_pairs:
        selection_set.update({(u, v), (v, u)})

    # Pass 1: re-encode from the untouched input D.
    for row in range(size):
        for col in range(size):
            forward, backward = D[row][col], D[col][row]
            if forward == backward == 0:
                marks[row][col] = marks[col][row] = -1
            elif forward == 1 and backward == 0:
                marks[row][col], marks[col][row] = 0, 1
            elif forward == backward == 1:
                marks[row][col] = marks[col][row] = 2
            elif row == col:
                marks[row][col] = -1

    # Pass 2: overlay latent / selection information once per unordered pair.
    for row, col in combinations(range(size), 2):
        key = (row, col)
        is_latent = key in latent_set
        is_selected = key in selection_set
        if is_latent and is_selected:
            marks[row][col] = marks[col][row] = 2
        elif is_latent:
            if marks[row][col] == 0:
                marks[row][col], marks[col][row] = 2, 1
            elif marks[row][col] == 1 and marks[col][row] == 0:
                marks[row][col], marks[col][row] = 1, 2
        elif is_selected and marks[row][col] != 2:
            marks[row][col] = marks[col][row] = 2

    return marks



def find_all_paths(adjacency_matrix, start_node, end_node, max_length=None):
    """
    Enumerate paths from ``start_node`` to ``end_node`` by depth-first search.

    Parameters:
    - adjacency_matrix: Square matrix; an entry equal to 1 is an edge
    - start_node: Starting node index
    - end_node: Ending node index
    - max_length: A path is only extended while it holds at most this many
      nodes (defaults to the number of nodes)

    Returns:
    - Dictionary with:
        - 'total_paths': Number of paths found
        - 'paths': List of node lists, in DFS discovery order
        - 'path_lengths': Number of edges of each path
    A path needs at least two nodes. Nodes other than ``end_node`` are not
    repeated within a path, but ``end_node`` may be: the search keeps
    extending after reaching it, so a self-loop on ``end_node`` yields both
    ``[..., end]`` and ``[..., end, end]``.
    """
    size = len(adjacency_matrix)
    limit = size if max_length is None else max_length
    found = []

    def extend(node, trail):
        trail.append(node)
        if node == end_node and len(trail) > 1:
            found.append(list(trail))
        if len(trail) <= limit:
            for nxt in range(size):
                if adjacency_matrix[node][nxt] != 1:
                    continue
                # Only end_node may appear more than once in a path.
                if nxt != end_node and nxt in trail:
                    continue
                extend(nxt, trail)
        trail.pop()

    extend(start_node, [])
    return {
        'total_paths': len(found),
        'paths': found,
        'path_lengths': [len(p) - 1 for p in found],
    }

def analyze_graph_selection(adj_matrix, selection_pairs, i, j):
    """
    Relate the selection pairs to the nodes on paths between ``i`` and ``j``.

    Path nodes come from ``find_nodes_on_paths(adj_matrix, i, j)``. Pair
    membership uses list equality (``[i, x] in selection_pairs``), so a pair
    stored as a tuple would not match.

    Returns:
    - i_x_in_selection: True if some path node x != i forms ``[i, x]`` or
      ``[x, i]`` in ``selection_pairs``
    - only_i_in_selection: True if ``i`` occurs in some selection pair and no
      path node other than ``i`` and ``j`` occurs in any selection pair
    """
    on_paths = find_nodes_on_paths(adj_matrix, i, j)
    interior = [node for node in on_paths if node != i and node != j]

    i_x_in_selection = any(
        x != i and ([i, x] in selection_pairs or [x, i] in selection_pairs)
        for x in on_paths
    )

    i_in_pairs = any(i in pair for pair in selection_pairs)
    interior_in_pairs = any(
        node in pair for node in interior for pair in selection_pairs
    )
    return i_x_in_selection, i_in_pairs and not interior_in_pairs

def find_nodes_on_paths(adj_matrix, start, end):
    """Return every node that lies on some simple directed path ``start -> end``.

    All simple paths are enumerated explicitly (exponential in the worst
    case), so this is meant for small or sparse graphs.

    Parameters
    ----------
    adj_matrix : 2-D array-like (list of lists or numpy array)
        Square adjacency matrix; ``adj_matrix[u][v] == 1`` is an edge
        ``u -> v``. Any other value is treated as "no edge".
    start, end : int
        Source and target node indices.

    Returns
    -------
    set of int
        Union of the nodes of all simple paths from ``start`` to ``end``,
        endpoints included. Empty if ``end`` is unreachable; ``{start}`` if
        ``start == end``.
    """
    # Kept in its own name: the previous version reused ``n`` as the loop
    # variable when collecting path nodes, which silently shrank the neighbour
    # range to ``range(end)`` after the first path was found.
    n_nodes = len(adj_matrix)
    on_path = set()

    # Explicit DFS stack (O(1) pops). The visiting order is irrelevant because
    # only the union of path nodes is returned.
    stack = [(start, [start])]
    while stack:
        node, path = stack.pop()
        if node == end:
            # Paths terminate at the target; it is never expanded further.
            on_path.update(path)
            continue
        row = adj_matrix[node]
        for neighbor in range(n_nodes):
            # ``neighbor not in path`` keeps each path simple (no cycles).
            if row[neighbor] == 1 and neighbor not in path:
                stack.append((neighbor, path + [neighbor]))

    return on_path


# Load the processed AnnData and collect the single-gene perturbation names.
raw_data = sc.read_h5ad('./perturb_processed.h5ad')
# Control (unperturbed) cells.
ctr = raw_data[raw_data.obs['condition'] == 'ctrl']
gene_name = raw_data.var['gene_name'].tolist()
# All condition labels except the pure control. Sorted so the order does not
# depend on per-process string hashing.
perturb = sorted(set(raw_data.obs['condition'].tolist()) - {'ctrl'})

character = 'ctrl'
# Labels mentioning 'ctrl' (substring match), e.g. 'GENE+ctrl' / 'ctrl+GENE'.
all_perturb = [item for item in perturb if character in item]
perturbation = set()
for item in all_perturb:
    split = item.split('+')
    # Only exact two-part labels with a 'ctrl' half name a single-gene
    # perturbation. Anything else (e.g. a label that merely contains the
    # substring 'ctrl') is skipped instead of reusing a stale ``name`` or
    # raising NameError/IndexError.
    if len(split) != 2:
        continue
    if split[0] == 'ctrl':
        name = split[1]
    elif split[1] == 'ctrl':
        name = split[0]
    else:
        continue
    perturbation.add(name)
# Sorted for a deterministic gene order. This also fixes the order of
# per_index, the pair orientation in the tests, and the saved per_name array.
perturbation = sorted(perturbation)

# Gene name -> column index of the expression matrix, and the reverse map.
origin_node_index = {node: i for i, node in enumerate(gene_name)}
mapping = {i: gene_name[i] for i in range(len(gene_name))}
# Column indices of the perturbed genes, in ``perturbation`` order.
per_index = [origin_node_index[name] for name in perturbation]
data_final = {}

# The control matrix and its ranking are identical for every perturbation, so
# densify and rank it once instead of on every iteration.
conditions = raw_data.obs['condition']
data_ctr_all = ctr.X.toarray()
# Fraction of zero entries per control cell. argsort is ascending, so the
# cells with the FEWEST zeros come first.
zero_rates = np.sum(data_ctr_all == 0, axis=1) / data_ctr_all.shape[1]
sorted_indices = np.argsort(zero_rates)

for i in per_index:
    per = raw_data[(conditions == 'ctrl+' + gene_name[i]) | (conditions == gene_name[i] + '+ctrl')]
    if per.shape[0] == 0:
        continue
    # Perturbed cells plus a trailing indicator column of ones.
    per = per.X.toarray()
    per = np.concatenate((per, np.ones((per.shape[0], 1))), axis=1)
    sample_size = per.shape[0]

    data_final.setdefault('obs', data_ctr_all)

    # Up to ``sample_size`` of the lowest-zero-rate control cells, with a
    # trailing indicator column of zeros. The indicator is sized to the rows
    # actually taken, so a perturbation with more cells than there are
    # controls no longer crashes the concatenation.
    row_index = sorted_indices[:sample_size]
    data_ctr = data_ctr_all[row_index]
    data_ctr = np.concatenate((data_ctr, np.zeros((data_ctr.shape[0], 1))), axis=1)
    # Rows: controls first, then perturbed cells. Columns: genes, then indicator.
    data_final[f'per_{i}'] = np.concatenate((data_ctr, per), axis=0)
    # print(data_final[f'per_{i}'].shape)


# Load the estimated skeleton and symmetrise it. Every nonzero entry, in
# either direction, becomes a 1 in both directions.
data = pd.read_csv('./dixit_result_p7.csv')
ske = data.values
rows, cols = ske.shape
for r_idx, c_idx in np.argwhere(ske != 0):
    ske[r_idx, c_idx] = 1
    ske[c_idx, r_idx] = 1

# Working graph. dag[u][v] == 1 is read as u -> v and is pruned below.
dag = deepcopy(ske)
thresholdset = [0.05]
s_indicator = np.zeros([rows, cols])
s_without_cause = []
# Pairs [i, j] queued for conditional re-testing, and their labels.
correct_set = []
condition_set = {}
special_latent = []
# p-values keyed by 'i-j': unconditional phase (u) and conditional phase (c).
CI_result_u = {}
CI_result_c = {}
result = {'latent': [], 'selection': []}
count = 0


# ---------------------------------------------------------------------------
# Main procedure, run once per significance level:
#   Phase 1 - KCI tests between each gene of a perturbed pair and the
#             intervention indicator. Prunes dag entries and labels pairs
#             ('S_C', 'L_C', 'F_D') for re-testing.
#   Phase 2 - re-tests each labelled pair, conditioning on nodes that lie on
#             i -> j paths in the current dag.
#   Phase 3 - resolves the pairs collected in special_latent.
# dag[u][v] == 1 is read as an edge u -> v, the convention used by
# find_nodes_on_paths.
# NOTE(review): correct_set, condition_set, special_latent, result, dag and
# count are created once above and are NOT reset per threshold. This is only
# correct while thresholdset holds a single value.
# ---------------------------------------------------------------------------
for thres in thresholdset:
    threshold = thres

    # ---- Phase 1: unconditional / conditional tests for every gene pair ----
    for a in range(len(per_index)):
        for b in range(a + 1, len(per_index)):
            i = per_index[a]
            j = per_index[b]
            print(f'{i}-{j}')
            correct = False
            # Columns: 0 = gene i, 1 = gene j, 2 = indicator
            # (1 = perturbed cell, 0 = matched control cell).
            data_i = data_final[f'per_{i}'][:, [i, j, -1]]
            data_j = data_final[f'per_{j}'][:, [i, j, -1]]
            ########### KCI ###########################
            # CIT(...)(x, y, cond) returns a p-value.
            # "U" = unconditional, "C" = conditioned on the other gene.
            # "pj" = experiment perturbing j (gene i vs indicator).
            # "pi" = experiment perturbing i (gene j vs indicator).
            CIT_obj = CIT(data_j, "kci")
            Upj_value = CIT_obj(0, 2, set([]))
            Cpj_value = CIT_obj(0, 2, set([1]))
            CIT_obi = CIT(data_i, "kci")
            Upi_value = CIT_obi(1, 2, set([]))
            Cpi_value = CIT_obi(1, 2, set([0]))
            CI_result_u[f'{i}-{j}'] = [Upj_value, Cpj_value, Upi_value, Cpi_value]
            value = [Upj_value, Cpj_value, Upi_value, Cpi_value]

            # Number of the four tests whose p-value exceeds the threshold.
            # Note: a p-value exactly equal to the threshold matches none of
            # the strict comparisons below.
            countin = sum(1 for v in value if v > threshold)
            if (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value > threshold):
                dag[j][i] = 0
            elif (Upi_value > threshold) & (Cpi_value < threshold) & (Upj_value < threshold) & (Cpj_value > threshold):
                dag[i][j] = 0
            elif (Upj_value < threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value > threshold):
                dag[j][i] = 0
                correct = True
                correct_set.append([i, j])
                condition_set[f'{i}-{j}'] = 'S_C'
            elif (Upi_value < threshold) & (Cpi_value < threshold) & (Upj_value < threshold) & (Cpj_value > threshold):
                dag[i][j] = 0
                correct = True
                correct_set.append([j, i])
                condition_set[f'{j}-{i}'] = 'S_C'
            elif (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value < threshold):
                dag[j][i] = 0
                correct = True
                correct_set.append([i, j])
                condition_set[f'{i}-{j}'] = 'L_C'
                result['latent'].append([i, j])
            elif (Upi_value > threshold) & (Cpi_value < threshold) & (Upj_value < threshold) & (Cpj_value < threshold):
                dag[i][j] = 0
                correct = True
                correct_set.append([j, i])
                condition_set[f'{j}-{i}'] = 'L_C'
                result['latent'].append([j, i])
            elif (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value > threshold) & (Cpi_value < threshold):
                dag[j][i] = 0
                dag[i][j] = 0
                result['latent'].append([i, j])
            elif (Upj_value > threshold) & (Cpj_value > threshold) & (Upi_value > threshold) & (Cpi_value > threshold):
                dag[j][i] = 0
                dag[i][j] = 0
            elif (Upj_value < threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value < threshold):
                correct_set.append([i, j])
                condition_set[f'{i}-{j}'] = 'F_D'
                result['selection'].append([i, j])
            elif countin > 2:
                dag[j][i] = 0
                dag[i][j] = 0
            else:
                correct = True
                correct_set.append([i, j])
                condition_set[f'{i}-{j}'] = 'F_D'
            count += 1

    print(count)

    # ---- Phase 2: re-test labelled pairs given nodes on i -> j paths -------
    for pair in correct_set:
        i, j = pair[0], pair[1]
        # Nodes on any directed path i -> j in the current dag, endpoints removed.
        c_set = [node for node in find_nodes_on_paths(dag, i, j) if node not in (i, j)]
        # has_path_length_2 is defined elsewhere in this file. Presumably it
        # returns (bool, middle nodes of length-2 paths i -> m -> j); verify.
        path_2, middle_node = has_path_length_2(dag, i, j)
        if path_2:
            others = list(set(c_set) - set(middle_node))
            # At most 3 middle nodes bound the subset sizes tried below.
            lenc = min(len(middle_node), 3)
            # The base slices depend only on the pair, not on the subset.
            data_i = data_final[f'per_{i}'][:, [i, j, -1]]
            data_j = data_final[f'per_{j}'][:, [i, j, -1]]
            # NOTE(review): range(1, lenc) never tries subsets of size lenc,
            # so a single middle node yields no test at all. Confirm whether
            # range(1, lenc + 1) was intended.
            for k in range(1, lenc):
                found = False
                subsets = list(combinations(middle_node, k))
                for element in subsets:
                    # Condition on all off-length-2 path nodes plus this subset.
                    given_set = others + list(element)
                    data_p_i = np.concatenate((data_i, data_final[f'per_{i}'][:, given_set]), axis=1)
                    data_p_j = np.concatenate((data_j, data_final[f'per_{j}'][:, given_set]), axis=1)
                    # Columns 3.. hold the conditioning genes.
                    g_adj = list(range(3, data_p_i.shape[1]))

                    #############kci ####################
                    CIT_obj = CIT(data_p_j, "kci")
                    Upj_value = CIT_obj(0, 2, g_adj)
                    Cpj_value = CIT_obj(0, 2, g_adj + [1])
                    CIT_obi = CIT(data_p_i, "kci")
                    Upi_value = CIT_obi(1, 2, g_adj)
                    Cpi_value = CIT_obi(1, 2, g_adj + [0])
                    value = [Upj_value, Cpj_value, Upi_value, Cpi_value]
                    CI_result_c[f'{i}-{j}'] = [Upj_value, Cpj_value, Upi_value, Cpi_value]
                    if condition_set[f'{i}-{j}'] != 'F_D':
                        countv = sum(1 for v in value if v < threshold)

                        if countv > 2:
                            if condition_set[f'{i}-{j}'] == 'L_C' and countv == 3:
                                special_latent.append([i, j])
                                found = True
                                break
                            if condition_set[f'{i}-{j}'] == 'S_C' and countv == 3:
                                result['selection'].append([i, j])
                                found = True
                                break
                            # Otherwise try the next subset.
                            continue
                        else:
                            if condition_set[f'{i}-{j}'] == 'S_C':
                                if (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value > threshold):
                                    dag[j][i] = 0
                                    found = True
                                    break
                                elif (Upi_value > threshold) & (Cpi_value < threshold) & (Upj_value < threshold) & (Cpj_value > threshold):
                                    dag[i][j] = 0
                                    found = True
                                    break
                            elif condition_set[f'{i}-{j}'] == 'L_C':
                                if (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value > threshold):
                                    result['latent'].remove([i, j])
                                    found = True
                                    break
                                elif (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value > threshold) & (Cpi_value < threshold):
                                    dag[j][i] = 0
                                    dag[i][j] = 0
                                    found = True
                                    break
                    else:
                        countin = sum(1 for v in value if v > threshold)
                        if (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value > threshold):
                            dag[j][i] = 0
                            dag[i][j] = 1
                            found = True
                            try:
                                result['selection'].remove([i, j])
                            except ValueError:
                                pass
                            break
                        elif (Upi_value > threshold) & (Cpi_value < threshold) & (Upj_value < threshold) & (Cpj_value > threshold):
                            dag[i][j] = 0
                            dag[j][i] = 1
                            found = True
                            try:
                                result['selection'].remove([i, j])
                            except ValueError:
                                pass
                            break
                        elif (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value < threshold):
                            dag[j][i] = 0
                            dag[i][j] = 1
                            result['latent'].append([i, j])
                            special_latent.append([i, j])
                            try:
                                result['selection'].remove([i, j])
                            except ValueError:
                                pass
                            found = True
                            break
                        elif (Upi_value > threshold) & (Cpi_value < threshold) & (Upj_value < threshold) & (Cpj_value < threshold):
                            dag[i][j] = 0
                            dag[j][i] = 1
                            result['latent'].append([i, j])
                            special_latent.append([i, j])
                            try:
                                result['selection'].remove([i, j])
                            except ValueError:
                                pass
                            found = True
                            break
                        elif (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value > threshold) & (Cpi_value < threshold):
                            dag[j][i] = 0
                            dag[i][j] = 0
                            result['latent'].append([i, j])
                            try:
                                result['selection'].remove([i, j])
                            except ValueError:
                                pass
                            found = True
                            break
                        elif countin > 2:
                            dag[j][i] = 0
                            dag[i][j] = 0

                if found:
                    break
        else:
            # No length-2 path: condition on every node on an i -> j path at once.
            data_i = data_final[f'per_{i}'][:, [i, j, -1]]
            data_j = data_final[f'per_{j}'][:, [i, j, -1]]
            data_p_i = np.concatenate((data_i, data_final[f'per_{i}'][:, c_set]), axis=1)
            data_p_j = np.concatenate((data_j, data_final[f'per_{j}'][:, c_set]), axis=1)
            g_adj = list(range(3, data_p_i.shape[1]))

            #################kci ######################
            CIT_obj = CIT(data_p_j, "kci")
            Upj_value = CIT_obj(0, 2, g_adj)
            Cpj_value = CIT_obj(0, 2, g_adj + [1])
            CIT_obi = CIT(data_p_i, "kci")
            Upi_value = CIT_obi(1, 2, g_adj)
            Cpi_value = CIT_obi(1, 2, g_adj + [0])
            CI_result_c[f'{i}-{j}'] = [Upj_value, Cpj_value, Upi_value, Cpi_value]
            value = [Upj_value, Cpj_value, Upi_value, Cpi_value]
            if condition_set[f'{i}-{j}'] != 'F_D':
                countv = sum(1 for v in value if v < threshold)
                if countv > 2:
                    if condition_set[f'{i}-{j}'] == 'L_C' and countv == 3:
                        special_latent.append([i, j])
                    if condition_set[f'{i}-{j}'] == 'S_C' and countv == 3:
                        result['selection'].append([i, j])
                    continue
                else:
                    if condition_set[f'{i}-{j}'] == 'S_C':
                        if (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value > threshold):
                            dag[j][i] = 0
                        elif (Upi_value > threshold) & (Cpi_value < threshold) & (Upj_value < threshold) & (Cpj_value > threshold):
                            dag[i][j] = 0
                    elif condition_set[f'{i}-{j}'] == 'L_C':
                        if (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value > threshold):
                            result['latent'].remove([i, j])
                        elif (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value > threshold) & (Cpi_value < threshold):
                            dag[j][i] = 0
                            dag[i][j] = 0
            else:
                # Must come from THIS pair's p-values. Previously the
                # `countin > 2` branch below read a stale count left over from
                # Phase 1 or from an earlier pair.
                countin = sum(1 for v in value if v > threshold)
                if (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value > threshold):
                    dag[j][i] = 0
                    dag[i][j] = 1
                    try:
                        result['selection'].remove([i, j])
                    except ValueError:
                        pass
                elif (Upi_value > threshold) & (Cpi_value < threshold) & (Upj_value < threshold) & (Cpj_value > threshold):
                    dag[i][j] = 0
                    dag[j][i] = 1
                    try:
                        result['selection'].remove([i, j])
                    except ValueError:
                        pass
                elif (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value < threshold) & (Cpi_value < threshold):
                    special_latent.append([i, j])
                    try:
                        result['selection'].remove([i, j])
                    except ValueError:
                        pass
                elif (Upi_value > threshold) & (Cpi_value < threshold) & (Upj_value < threshold) & (Cpj_value < threshold):
                    special_latent.append([i, j])
                    try:
                        result['selection'].remove([i, j])
                    except ValueError:
                        pass
                elif (Upj_value > threshold) & (Cpj_value < threshold) & (Upi_value > threshold) & (Cpi_value < threshold):
                    dag[j][i] = 0
                    dag[i][j] = 0
                    result['latent'].append([i, j])
                    try:
                        result['selection'].remove([i, j])
                    except ValueError:
                        pass
                elif countin > 2:
                    dag[j][i] = 0
                    dag[i][j] = 0

    # ---- Phase 3: resolve special_latent pairs using the selection pairs ---
    for pair in special_latent:
        i = pair[0]
        j = pair[1]
        i_x_ins, i_ins = analyze_graph_selection(dag, result['selection'], i, j)
        if i_x_ins:
            dag[i][j] = 0
            dag[j][i] = 0
            # Pairs added to special_latent by the F_D branch of Phase 2
            # (no length-2 path) were never put in result['latent'], so an
            # unguarded remove() raised ValueError here.
            if [i, j] in result['latent']:
                result['latent'].remove([i, j])
            continue
        if i_ins:
            dag[i][j] = 0
            dag[j][i] = 0
        else:
            dag[j][i] = 0

    print(result['latent'])
    print(result['selection'])
    # Sub-graph restricted to the perturbed genes, in per_index order.
    per_dag = dag[np.ix_(per_index, per_index)]

    # np.save does not create missing directories; create it so the results of
    # the (slow) KCI run are not lost at the very end.
    os.makedirs('./GISL_dixit_new', exist_ok=True)
    np.save(f'./GISL_dixit_new/latent_p3_{threshold}.npy', np.array(result['latent']))
    np.save(f'./GISL_dixit_new/selection_p3_{threshold}.npy', np.array(result['selection']))
    np.save(f'./GISL_dixit_new/gene_name_kci_p3_{threshold}.npy', gene_name)
    np.save(f'./GISL_dixit_new/per_index_kci_p3_{threshold}.npy', per_index)
    np.save(f'./GISL_dixit_new/output_dag_kci_p3_{threshold}.npy', dag)
    np.save(f'./GISL_dixit_new/per_dag_kci_p3_{threshold}.npy', per_dag)
    np.save(f'./GISL_dixit_new/per_name_kci_p3_{threshold}.npy', np.array(perturbation))