import networkx as nx
import numpy as np
from causallearn.graph.GraphNode import GraphNode
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from pandas import DataFrame

from utils.causalgraph_utils import edgelist2nodelist
from state import ConquerState


def ci_test(state: ConquerState, alpha: float = 0.05) -> ConquerState:
    """Split the proposed causal graph into data-supported and unsupported edges.

    The proposed edges are read from ``state["causal_graph"]`` and the data from
    ``state["dataset"]``. The dataset is restricted to the variables returned by
    ``edgelist2nodelist`` for the proposed graph, and ``edges_without_evidence``
    decides which proposed edges lack support.

    Args:
        state (ConquerState): Mapping that holds at least the keys
            ``"causal_graph"`` and ``"dataset"``. It is updated in place.
        alpha (float): Significance level forwarded to the conditional
            independence tests.

    Returns:
        ConquerState: The same ``state`` object. When the proposed graph is
        empty (or ``None``) it is returned untouched; otherwise the keys
        ``"unsupported_edges"`` and ``"supported_edges"`` are set.
    """
    critic_graph = state["causal_graph"]
    # A falsy check also covers ``None`` and empty tuples, not only ``[]``.
    if not critic_graph:
        return state

    # Normalise every edge to a tuple (hashable) and drop duplicates while
    # keeping the order in which the edges were proposed.
    proposed_edges = list(dict.fromkeys(tuple(edge) for edge in critic_graph))

    # Subset data to only include variables in the proposed graph
    data = state["dataset"].copy()
    data = data[edgelist2nodelist(critic_graph)]

    # Perform analysis using FCI and add unsupported edges to the state
    unsupported_edges = edges_without_evidence(
        proposed_edges, data, significance_level=alpha
    )
    state["unsupported_edges"] = unsupported_edges

    # Filter in proposal order so the result does not depend on set/hash order.
    rejected = set(unsupported_edges)
    state["supported_edges"] = [
        edge for edge in proposed_edges if edge not in rejected
    ]

    return state

def edges_without_evidence(proposal_edges: list[tuple[str, str]], data: DataFrame, significance_level: float = 0.05
) -> list[tuple[str, str]]:
    """
    Identify edges of the proposal graph that are not supported by the data.

    The proposal is turned into a directed graph over the columns of ``data``
    and handed to ``run_fci`` together with the data. Every proposed edge that
    is missing from the edges returned by ``run_fci`` is reported.

    Args:
        proposal_edges (list[tuple[str, str]]): Directed edges of the proposed
            causal graph, each given as (source, target).
        data (DataFrame): Dataset used to estimate the causal graph. Its
            columns become the nodes of the proposal graph.
        significance_level (float): The significance level for the conditional
            independence tests.

    Returns:
        list[tuple[str, str]]: The proposed edges without supporting evidence,
        de-duplicated and listed in the order they appear in
        ``proposal_edges``.
    """
    # Tuples are hashable (lists are not); dict.fromkeys removes duplicates
    # while preserving the proposal order.
    proposed = list(dict.fromkeys(tuple(edge) for edge in proposal_edges))

    # Convert proposal to networkx graph
    nx_proposal_graph = nx.DiGraph()
    nx_proposal_graph.add_nodes_from(data.columns)
    nx_proposal_graph.add_edges_from(proposed)

    # Estimate graph from data
    edges_with_evidence = set(
        run_fci(data, nx_proposal_graph, seed=0, alpha=significance_level)
    )

    # Keep the proposed edges that are absent from the estimated graph. A
    # plain set difference would return them in hash-dependent order.
    return [edge for edge in proposed if edge not in edges_with_evidence]

def check_constant_columns_numpy(arr: np.ndarray) -> list[int]:
    """
    Find the columns of a 2-D NumPy array that hold a single repeated value.

    Args:
        arr (np.ndarray): Array indexed as ``arr[row, column]``.

    Returns:
        list[int]: Indices of the constant columns, in ascending order.
    """
    column_count = arr.shape[1]
    # A column counts as constant when every entry equals its first-row entry.
    return [
        col
        for col in range(column_count)
        if np.all(arr[:, col] == arr[0, col])
    ]

def run_fci(data, initial_condition: nx.DiGraph, seed: int = 0, alpha: float = 0.05) -> list[tuple[str, str]]:
    """
    Run the FCI algorithm on the given data and initial condition.

    A small amount of Gaussian noise is added to the data when a column is
    constant, and again when the determinant of the correlation matrix is
    below 1e-2. This is needed, otherwise C.I. tests will fail.

    Args:
        data (DataFrame): The input data. All values must be convertible to
            float.
        initial_condition (nx.DiGraph): The initial condition graph. It is
            forwarded to ``extract_edge_constraints`` to build the background
            knowledge given to FCI.
        seed (int): Random seed for reproducibility. It seeds the noise
            generator (so it must be non-negative) and is forwarded to ``fci``.
        alpha (float): Significance level forwarded to ``fci``.

    Returns:
        list[tuple[str, str]]: One (node1, node2) pair per edge reported by
        FCI, expressed with the column labels of ``data``. Endpoint marks of
        the FCI edges are not inspected.
    """
    data_np = data.to_numpy().astype(float)

    # Draw the noise from a generator seeded with ``seed`` so that repeated
    # runs perturb the data identically.
    rng = np.random.default_rng(seed)
    noise_scale = 0.0001

    if check_constant_columns_numpy(data_np):
        # At least one column is constant: perturb the data so that the
        # correlation matrix below is well defined.
        data_np = data_np + rng.normal(0, noise_scale, data_np.shape)

    # If the correlation matrix is singular, add noise to the data
    # It is needed, otherwise C.I. tests will fail
    corr_matrix = np.corrcoef(data_np, rowvar=False)
    if np.linalg.det(corr_matrix) < 1e-2:
        data_np = data_np + rng.normal(0, noise_scale, data_np.shape)

    # Run FCI algorithm
    background = extract_edge_constraints(list(data.columns), initial_condition)
    # NOTE(review): causal-learn documents the test selector as
    # `independence_test_method`; confirm that `independence_test=` and
    # `seed=` are honoured by the installed version rather than being
    # silently absorbed by **kwargs.
    _, edges = fci(
        data_np,
        independence_test="chisq",
        seed=seed,
        background_knowledge=background,
        verbose=False,
        show_progress=False,
        alpha=alpha,
        node_names=data.columns.tolist(),
    )

    # Convert the resulting edges to a NetworkX graph over the real labels
    estimated = nx.DiGraph()
    estimated.add_nodes_from(data.columns)

    # Create mapping from X1, X2.. labels to the real labels.
    placeholder_to_label = {
        f"X{k + 1}": label for k, label in enumerate(data.columns)
    }

    for edge in edges:
        # NOTE(review): presumably some causal-learn versions already name the
        # nodes after `node_names`; fall back to the reported name instead of
        # raising KeyError in that case.
        source = placeholder_to_label.get(edge.node1.name, edge.node1.name)
        target = placeholder_to_label.get(edge.node2.name, edge.node2.name)
        estimated.add_edge(source, target)

    return list(estimated.edges())


def extract_edge_constraints(
    nodes: list[str], graph: "nx.DiGraph | list[tuple[str, str]]"
) -> BackgroundKnowledge:
    """
    Build background knowledge that forbids every directed edge absent from ``graph``.

    Every ordered pair of distinct nodes is a candidate edge; the pairs that
    appear as edges of ``graph`` stay allowed and all remaining pairs are
    registered as forbidden.

    Args:
        nodes (list[str]): Labels of all nodes to consider.
        graph (nx.DiGraph | list[tuple[str, str]]): The allowed edges, given
            either as a graph object exposing ``edges()`` or as a plain
            iterable of (source, target) pairs.

    Returns:
        BackgroundKnowledge: Constraints with one forbidden rule per directed
        edge that is not present in ``graph``.
    """
    # Iterating a networkx graph yields its nodes, not its edges, so the edge
    # view has to be requested explicitly; plain edge lists are used as given.
    edge_source = graph.edges() if hasattr(graph, "edges") else graph

    # Tuples make list-shaped edges comparable and give O(1) membership tests.
    allowed_edges = {(edge[0], edge[1]) for edge in edge_source}

    # Drop repeated labels while keeping the given node order.
    unique_nodes = list(dict.fromkeys(nodes))

    edge_constraints = BackgroundKnowledge()
    for from_label in unique_nodes:
        for to_label in unique_nodes:
            if from_label == to_label or (from_label, to_label) in allowed_edges:
                continue
            edge_constraints.add_forbidden_by_node(
                GraphNode(from_label), GraphNode(to_label)
            )
    return edge_constraints

if __name__ == "__main__":
    # Small demo: three columns and a two-edge chain A -> B -> C.
    demo_columns = ["A", "B", "C"]
    demo_values = {"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}
    demo_frame = DataFrame(demo_values, columns=demo_columns)

    demo_edges = [("A", "B"), ("B", "C")]
    unsupported = edges_without_evidence(demo_edges, demo_frame)
    print(unsupported)