
import networkx as nx
import numpy as np

def add_path_to_matrix(matrix, path, tau=1):
    """Mark every consecutive step of ``path`` in ``matrix`` with ``tau``.

    For each pair ``path[k - 1] -> path[k]`` the entry
    ``matrix[path[k - 1], path[k]]`` is overwritten with ``tau``.  The matrix
    is modified in place and the same object is returned.  A path with fewer
    than two nodes leaves the matrix unchanged.
    """
    for k in range(1, len(path)):
        prev_node, next_node = path[k - 1], path[k]
        matrix[prev_node, next_node] = tau
    return matrix

def transitive_reduction(G):
    """Return the transitive reduction of the directed graph ``G``.

    An edge ``(u, v)`` is kept only when ``v`` cannot be reached from any
    *other* direct successor ``w`` of ``u``.  Otherwise the edge is implied by
    a longer path ``u -> w -> ... -> v`` and is dropped.  When ``G`` is acyclic
    this is exactly the transitive reduction.  Cycles are neither detected nor
    rejected.

    All nodes of ``G`` (including isolated ones) are copied to the result.
    Node and edge attributes are not copied.

    Parameters
    ----------
    G : networkx.DiGraph
        Input graph; expected to be a DAG.

    Returns
    -------
    networkx.DiGraph
        A new graph with the same nodes and only the non-redundant edges.
    """
    TR = nx.DiGraph()
    TR.add_nodes_from(G.nodes())

    # Memoised descendant sets.  G is never modified here, so each node's set
    # can be computed once and reused.  Previously a full DFS ran for every
    # (edge, sibling successor) pair.
    descendants = {}

    def reachable_nodes(u):
        """Return the set of nodes reachable from ``u`` by a path of length >= 1."""
        cached = descendants.get(u)
        if cached is not None:
            return cached
        seen = set()
        stack = [u]
        while stack:
            node = stack.pop()
            for neighbor in G[node]:
                if neighbor not in seen:
                    seen.add(neighbor)
                    stack.append(neighbor)
        descendants[u] = seen
        return seen

    for u in G.nodes():
        for v in G[u]:
            # Keep (u, v) unless v is also reachable through some other
            # successor w of u; in that case the direct edge is redundant.
            if not any(v in reachable_nodes(w) for w in G[u] if w != v):
                TR.add_edge(u, v)

    return TR


def find_all_paths(graph, start_vertex, path=None):
    """Yield every path from ``start_vertex`` to a sink of ``graph``.

    A sink is a node with no successors.  Paths are produced depth-first, in
    the order ``graph.successors`` returns neighbours.

    Parameters
    ----------
    graph :
        Object with a ``successors(node)`` method (e.g. ``networkx.DiGraph``).
        There is no visited-set check: a cycle reachable from
        ``start_vertex`` leads to unbounded recursion (``RecursionError``).
    start_vertex :
        Node at which the walk continues.
    path : sequence, optional
        Prefix prepended to every yielded path.  It is not modified.

    Yields
    ------
    list
        A fresh list ``[*path, start_vertex, ..., sink]`` for each sink reached.
    """
    # Build a new list instead of appending, so a caller-supplied prefix is
    # never mutated.
    current = [start_vertex] if path is None else [*path, start_vertex]

    # Enumerate successors once.  The original iterated them twice, the first
    # time only to count them.
    successors = list(graph.successors(start_vertex))
    if not successors:
        # start_vertex is a sink: this path is complete.
        yield current
        return

    for next_vertex in successors:
        yield from find_all_paths(graph, next_vertex, current)

def extract_paths(dag):
    """Collect every source-to-sink path of ``dag``.

    Sources are nodes whose in-degree is 0.  Each one is expanded with
    ``find_all_paths`` into all maximal paths ending at a sink.  In a DAG every
    edge lies on at least one such path, so the union of the paths' transitive
    closures covers the transitive closure of the DAG.

    Returns
    -------
    list of list
        Paths grouped by source, with sources in ``dag.nodes()`` order.
    """
    roots = [node for node in dag.nodes() if dag.in_degree(node) == 0]
    return [walk for root in roots for walk in find_all_paths(dag, root)]

def get_tr(A):
    """Build a DiGraph from adjacency matrix ``A`` and return its transitive reduction.

    Every non-zero entry ``A[i, j]`` becomes an edge ``i -> j``.  The indices
    returned by ``np.argwhere`` are NumPy integer scalars.  They are converted
    to plain Python ``int`` so the graph's node labels (and any paths derived
    from it) do not carry NumPy scalar types.

    NOTE(review): only vertices that appear in at least one edge become nodes;
    an index whose row and column are all zeros is dropped.  Confirm whether
    isolated vertices should be kept (e.g. ``G.add_nodes_from(range(len(A)))``).
    """
    G = nx.DiGraph()
    edges = np.argwhere(A)
    G.add_edges_from((int(i), int(j)) for i, j in edges)
    return transitive_reduction(G)

def get_paths(A):
    """Return the source-to-sink paths of the transitive reduction of ``A``.

    ``A`` is an adjacency matrix whose non-zero entries are edges.  Graph
    construction and reduction are done by ``get_tr``; path enumeration is
    done by ``extract_paths``.
    """
    return extract_paths(get_tr(A))


if __name__ == "__main__":
    # Example: adjacency matrix with edges 0->1, 0->2, 1->2 and 1->3.
    # The edge 0->2 is implied by 0->1->2, so the transitive reduction drops it.
    A = np.array([
        [0, 1, 1, 0],
        [0, 0, 1, 1],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
    ])

    # Enumerate the source-to-sink paths of the transitive reduction of A
    # (here 0->1->2 and 0->1->3).
    paths = get_paths(A)

    # Print the paths (lists of node labels), not the reduced graph's edges.
    print(paths)
