from collections.abc import Iterable, Sequence

import numpy as np
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from pandas import DataFrame

from utils.utils import HiddenPrints

def check_constant_columns_numpy(arr: np.ndarray) -> list[int]:
    """
    Find the columns of a 2-D NumPy array whose entries are all equal.

    A column is reported as constant when every entry compares equal (``==``)
    to the entry in the first row. NaN never compares equal to itself, so a
    column containing NaN is not reported as constant.

    Args:
        arr (np.ndarray): The 2-D input array; columns are checked independently.

    Returns:
        list[int]: Indices (plain Python ints, increasing) of the constant
        columns. An array with no rows yields an empty list.
    """
    if arr.shape[0] == 0:
        # No rows means there is no first row to compare against; report nothing
        # instead of raising IndexError.
        return []
    # Compare every row with the first row in a single vectorized pass.
    is_constant = np.all(arr == arr[0], axis=0)
    return np.flatnonzero(is_constant).tolist()

def fci_algorithm(
    data: DataFrame,
    edge_constraints: Sequence[Sequence[str]],
    seed: int = 0,
    alpha: float = 0.05,
) -> GeneralGraph:
    """
    Run the FCI algorithm on ``data`` while requiring a given set of edges.

    Small Gaussian noise (std 1e-4) is first added to the whole dataset if any
    column is constant. More noise is added if the determinant of the
    correlation matrix is below 1e-2, because otherwise the C.I. tests fail on
    such (near-)singular data. The noise comes from a generator seeded with
    ``seed``, so identical inputs are perturbed identically across calls.

    After FCI finishes, every required edge is added to the resulting graph as
    a directed edge ``source --> target``.

    Args:
        data (DataFrame): The input data. Its column names are used as node names.
        edge_constraints (Sequence[Sequence[str]]): Required directed edges,
            each a ``(source, target)`` pair of column names.
        seed (int): Seed for the noise generator; also passed to ``fci``.
        alpha (float): Significance level passed to ``fci``.

    Returns:
        GeneralGraph: The graph returned by FCI with the required edges added.
    """
    # Materialize once: the edges are iterated twice below.
    required_edges = list(edge_constraints)
    data_np = data.to_numpy().astype(float)
    # Seeded generator so the jitter (and hence the result) is reproducible.
    rng = np.random.default_rng(seed)

    if check_constant_columns_numpy(data_np):
        # Add noise to the whole dataset if there are any constant columns.
        data_np = data_np + rng.normal(0, 0.0001, data_np.shape)

    # If the correlation matrix is (near-)singular, add noise to the data.
    # This is needed; otherwise the C.I. tests fail. np.corrcoef returns a
    # scalar for single-column data, and np.linalg.det rejects scalars, so the
    # result is promoted to 2-D first.
    corr_matrix = np.atleast_2d(np.corrcoef(data_np, rowvar=False))
    if np.linalg.det(corr_matrix) < 1e-2:
        data_np = data_np + rng.normal(0, 0.0001, data_np.shape)

    background = extract_edge_constraints(required_edges)

    # Run the FCI algorithm with its console output suppressed.
    with HiddenPrints():
        # NOTE(review): causal-learn's `fci` appears to name its test parameter
        # `independence_test_method`. If so, `independence_test` (and `seed`)
        # may be forwarded through **kwargs and ignored, and the default test
        # would run instead of chi-square. Chi-square also expects discrete
        # data, which the jitter above makes continuous. Confirm against the
        # installed causal-learn version before relying on "chisq".
        g, _ = fci(
            data_np,
            independence_test="chisq",
            seed=seed,
            background_knowledge=background,
            verbose=False,
            show_progress=False,
            alpha=alpha,
            node_names=data.columns.tolist(),
        )

    # Force every required edge into the output as source --> target.
    # Nodes are rebuilt from their names, which presumes GraphNode
    # equality/hashing is name-based so they match the nodes in `g`.
    # TODO: confirm.
    for edge in required_edges:
        g.add_directed_edge(GraphNode(edge[0]), GraphNode(edge[1]))

    return g


def extract_edge_constraints(
    graph: Iterable[Sequence[str]],
) -> BackgroundKnowledge:
    """
    Build causal-learn background knowledge that requires each given edge.

    Args:
        graph (Iterable[Sequence[str]]): Directed edges of the proposed causal
            graph. Each edge is a ``(source, target)`` pair of node names; only
            its first two items are read.

    Returns:
        BackgroundKnowledge: Background knowledge in which every edge
        ``source -> target`` is registered as required.
    """
    edge_constraints = BackgroundKnowledge()
    for edge in graph:
        # Nodes are created from names alone; presumably causal-learn matches
        # them to the data's nodes by name. TODO: confirm.
        edge_constraints.add_required_by_node(GraphNode(edge[0]), GraphNode(edge[1]))
    return edge_constraints

if __name__ == "__main__":
    # Example: three perfectly correlated columns and a proposed chain A -> B -> C.
    example_data = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
    example_edges = [("A", "B"), ("B", "C")]

    result_graph = fci_algorithm(example_data, example_edges)
    print(result_graph)