import networkx as nx
import numpy as np

from collections import defaultdict

def plan_to_dag(plan):
    '''
    Convert an LLM plan into a networkx MultiDiGraph.

    plan["nodes"] is an iterable of dicts, each with an "id" key; the whole
    dict (including "id") is attached to that node as its attributes.
    plan["edges"] is passed unchanged to MultiDiGraph.add_edges_from.
    No acyclicity check is performed here.
    '''
    node_entries = ((node["id"], node) for node in plan["nodes"])
    graph = nx.MultiDiGraph()
    graph.add_nodes_from(node_entries)
    graph.add_edges_from(plan["edges"])
    return graph

def get_unique_vertices(edges):
    '''
    Return the set of every endpoint appearing in *edges*.

    Each edge must unpack into exactly two items (source, target).
    '''
    return {endpoint for source, target in edges for endpoint in (source, target)}

def add_edges(edges):
    '''
    Build an adjacency list from (source, target) pairs.

    Returns a defaultdict(list) mapping each source to its targets, in edge
    order (duplicates kept). Nodes with no outgoing edge are not keys until
    looked up.
    '''
    adjacency = defaultdict(list)
    for source, target in edges:
        successors = adjacency[source]
        successors.append(target)
    return adjacency

def dfs(graph, v, visited, rec_stack):
    '''
    Depth-first search from *v* looking for a directed cycle.

    Args:
        graph: mapping node -> iterable of successor nodes; graph[node] is
            looked up for every node reached (e.g. a defaultdict(list)).
        v: start node.
        visited: dict node -> bool, updated in place; every reachable node
            must already be a key.
        rec_stack: dict node -> bool marking nodes on the current DFS path,
            updated in place.

    Returns:
        True as soon as an edge back to a node on the current path is found
        (a cycle reachable from v); rec_stack is then left with that path
        still marked. False otherwise, with every node this call pushed
        unmarked in rec_stack again.

    Implemented with an explicit stack rather than recursion so long chains
    do not exceed Python's recursion limit.
    '''
    visited[v] = True
    rec_stack[v] = True
    # Each frame keeps a live iterator so we resume a node's neighbours
    # exactly where we left off after finishing a child.
    stack = [(v, iter(graph[v]))]
    while stack:
        node, neighbors = stack[-1]
        for neighbor in neighbors:
            if not visited[neighbor]:
                visited[neighbor] = True
                rec_stack[neighbor] = True
                stack.append((neighbor, iter(graph[neighbor])))
                break
            if rec_stack[neighbor]:
                # Back edge to a node still on the path -> cycle.
                return True
        else:
            # All neighbours done: node leaves the current path.
            rec_stack[node] = False
            stack.pop()
    return False

def validate_dag(edges):
    '''
    Returns True if the directed graph described by *edges* has no cycle,
    False otherwise.

    *edges* may be any iterable (it is consumed exactly once, so generators
    work). Each edge is (u, v) optionally followed by extra fields such as an
    edge key or attribute dict, as networkx's add_edges_from accepts; only the
    first two items are used. A self-loop (u, u) counts as a cycle. An empty
    edge list is a valid DAG.

    Uses Kahn's algorithm (repeatedly remove nodes with in-degree 0) without
    recursion, so arbitrarily long chains are handled.
    '''
    successors = defaultdict(list)
    in_degree = {}
    for edge in edges:
        u, v, *_ = edge
        successors[u].append(v)
        in_degree.setdefault(u, 0)
        in_degree[v] = in_degree.get(v, 0) + 1

    ready = [node for node, degree in in_degree.items() if degree == 0]
    processed = 0
    while ready:
        node = ready.pop()
        processed += 1
        for nxt in successors.get(node, ()):
            # Parallel edges are counted separately, so decrement per edge.
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                ready.append(nxt)

    # Any node never reaching in-degree 0 lies on (or behind) a cycle.
    return processed == len(in_degree)

def get_num_nodes(graph):
    '''Return the node count reported by graph.number_of_nodes().'''
    node_count = graph.number_of_nodes()
    return node_count

def get_num_edges(graph):
    '''Return the edge count reported by graph.number_of_edges().'''
    edge_count = graph.number_of_edges()
    return edge_count

def is_linear(graph):
    '''
    Return True if *graph* is a single directed chain n0 -> n1 -> ... -> nk.

    Requires exactly one node with no incoming edges (the start), exactly one
    node with no outgoing edges (the end), no node with more than one
    incoming or more than one outgoing edge, and weak connectivity. A graph
    with one node and no edges is linear; an empty graph is not.
    '''
    in_degrees = dict(graph.in_degree())
    out_degrees = dict(graph.out_degree())

    start_nodes = [node for node in graph.nodes if in_degrees[node] == 0]
    end_nodes = [node for node in graph.nodes if out_degrees[node] == 0]

    if len(start_nodes) != 1 or len(end_nodes) != 1:
        return False

    # Bound the endpoints too, not just interior nodes: otherwise a start
    # node fanning out (a->b, a->c, c->b) would pass. With a unique start
    # and end, interior nodes already have in/out >= 1, so <= 1 means == 1.
    if any(d > 1 for d in in_degrees.values()):
        return False
    if any(d > 1 for d in out_degrees.values()):
        return False

    # Degrees <= 1 still allow a path plus disjoint cycles (a->b, c->d->c),
    # since cycle nodes have in = out = 1. Connectivity rules that out.
    return nx.is_weakly_connected(graph)

def avg_degree(graph):
    '''
    Return the mean of graph.degree() over all nodes.

    Raises ZeroDivisionError if the graph has no nodes.
    '''
    total = sum(degree for _, degree in graph.degree())
    return total / graph.number_of_nodes()

def avg_in_degree(graph):
    '''
    Return the mean of graph.in_degree() over all nodes.

    Raises ZeroDivisionError if the graph has no nodes.
    '''
    total = sum(degree for _, degree in graph.in_degree())
    return total / graph.number_of_nodes()

def avg_out_degree(graph):
    '''
    Return the mean of graph.out_degree() over all nodes.

    Raises ZeroDivisionError if the graph has no nodes.
    '''
    total = sum(degree for _, degree in graph.out_degree())
    return total / graph.number_of_nodes()

def centrality(graph):
    '''Return networkx's degree centrality for every node of *graph*.'''
    scores = nx.degree_centrality(graph)
    return scores

def paths_from_src_to_dest(graph, source, destination):
    '''
    Return the number of simple paths from *source* to *destination*, as
    enumerated by nx.all_simple_paths.

    NOTE(review): on a MultiDiGraph, parallel edges may cause networkx to
    report the same node sequence more than once. Confirm whether that
    multiplicity is intended.
    '''
    # The number of simple paths can grow exponentially with graph size, so
    # count them lazily instead of materialising every path in a list.
    path_iter = nx.all_simple_paths(graph, source=source, target=destination)
    return sum(1 for _ in path_iter)
