
import igraph
import networkx as nx
from tqdm import tqdm
# from graphgym.utils import plot_igraph
import numpy as np
from .utils import remove_edges_with_attribute_value, torch_geometric_to_igraph

def is_graph_valid(graph, subg=False):
    """
    Check the basic structural validity of a graph.

    A graph passes when all of the following hold:
      * no vertex has total degree 0 (no isolated vertices);
      * every vertex with out-degree 0 carries the end type;
      * exactly one start-typed and exactly one end-typed vertex exist.

    Args:
        graph (igraph.Graph): The input graph; vertices carry a 'type' attribute.
        subg (bool): Selects the type codes. Start/end are 0/1 when True and
            8/9 otherwise.

    Returns:
        bool: True if the graph is valid, False otherwise.
    """
    # Any isolated vertex disqualifies the graph outright.
    if 0 in graph.degree():
        return False

    start_type, end_type = (0, 1) if subg else (8, 9)

    start_count = 0
    end_count = 0
    for vertex in graph.vs:
        vertex_type = vertex['type']
        if vertex_type == start_type:
            start_count += 1
        elif vertex_type == end_type:
            end_count += 1
        # Only the end vertex is allowed to be a sink.
        if vertex.outdegree() == 0 and vertex_type != end_type:
            return False

    return start_count == 1 and end_count == 1


def is_valid_circuit_graph(g, subg=False):
    """
        Check if the given igraph g is a valid circuit graph.

        Same metric as CktGNN, but the DAG check is replaced by is_graph_valid,
        and every shortest path with as many vertices as the diameter path is
        tried as the main path (not only the diameter path itself).

        Args:
            g (igraph.Graph): The input graph. Vertices carry a 'type'
                attribute, plus a 'pos' attribute when subg is True.
            subg (bool): If the graph is a subgraph.
        Returns:
            bool: True if the graph is valid, False otherwise.
    """
    if not is_graph_valid(g, subg=subg):
        return False

    if subg:
        # Vertices sitting at main-path positions (2, 3, 4) must not be
        # typed 8 or 9.
        return not any(v['pos'] in (2, 3, 4) and v['type'] in (8, 9) for v in g.vs)

    # The main path must start/end at the pseudo input/output vertices.
    # NOTE(review): the diameter is measured on the undirected graph, while
    # calculate_all_shortest_paths calls get_all_shortest_paths without an
    # explicit mode -- confirm both are meant to use the same direction.
    diameter_path = g.get_diameter(directed=False)
    all_paths = calculate_all_shortest_paths(g, len(diameter_path))

    for paths in all_paths.values():
        for path in paths:
            cond = True
            last = len(path) - 1
            for i, v_ in enumerate(path):
                v_type = g.vs[v_]['type']
                if i == 0:
                    if v_type != 8:
                        cond = False
                elif i == last:
                    if v_type != 9:
                        cond = False
                elif v_type in (4, 5):
                    # A type-4/5 vertex on the path is accepted only when a
                    # type-0/1 vertex shares one of its predecessors and one
                    # of its successors.
                    # NOTE(review): this assignment replaces any earlier
                    # failure recorded for this path (including the start
                    # check); kept as-is to preserve the metric -- confirm
                    # whether failures should be sticky.
                    successors_ = set(g.successors(v_))
                    cond = any(
                        g.vs[v_cand]['type'] in (0, 1)
                        and successors_ & set(g.successors(v_cand))
                        for v_p in g.predecessors(v_)
                        for v_cand in g.successors(v_p)
                    )
            if cond:
                return True

    # No candidate main path qualified (previously fell through returning None).
    return False

def is_valid_DAG(g, subg=False):
    """
        Check if the given igraph g is a valid DAG computation graph.
        Same metrics as CktGNN.

        The graph passes when:
          * it has no directed cycles (g.is_dag());
          * every vertex with out-degree 0 carries the end type;
          * it has exactly one start-typed vertex;
          * it has exactly one end-typed vertex.

        Args:
            g (igraph.Graph): The input graph; vertices carry a 'type' attribute.
            subg (bool): Selects the type codes. Start/end are 0/1 when True
                and 8/9 otherwise.
        Returns:
            bool: True if the graph is a valid DAG, False otherwise.
    """
    start_type, end_type = (0, 1) if subg else (8, 9)

    acyclic = g.is_dag()

    start_count = 0
    end_count = 0
    for vertex in g.vs:
        vertex_type = vertex['type']
        if vertex_type == start_type:
            start_count += 1
        elif vertex_type == end_type:
            end_count += 1
        # Only the end vertex may have no outgoing edge.
        if vertex.outdegree() == 0 and vertex_type != end_type:
            return False

    return acyclic and start_count == 1 and end_count == 1

def calculate_all_shortest_paths(graph, diameter_length):
    """
    Collect the shortest paths whose vertex count equals ``diameter_length``.

    Every ordered pair of distinct vertices is queried with
    ``graph.get_all_shortest_paths``. A pair is kept only when at least one
    path exists and the first returned path has exactly ``diameter_length``
    vertices.

    Args:
        graph (igraph.Graph): The input graph.
        diameter_length (int): Required number of vertices in a path.

    Returns:
        dict: A dictionary where the keys are tuples (source, target),
              and the values are the lists of shortest paths between them.
    """
    vertex_count = graph.vcount()
    selected = {}

    for src in range(vertex_count):
        for dst in range(vertex_count):
            if src == dst:
                # A vertex is never paired with itself.
                continue
            candidates = graph.get_all_shortest_paths(src, to=dst)
            if candidates and len(candidates[0]) == diameter_length:
                selected[(src, dst)] = candidates

    return selected

def is_valid_Circuit(g, subg=False):
    """
        Check if the given igraph g is a valid circuit graph.
        Same metrics as CktGNN, but checked over all the shortest paths that
        have as many vertices as the diameter path.

        The check has two stages: the topology must be a valid DAG
        (is_valid_DAG), then the vertex types on the main path are inspected.

        Args:
            g (igraph.Graph): The input graph. Vertices carry a 'type'
                attribute, plus a 'pos' attribute when subg is True.
            subg (bool): If the graph is a subgraph.
        Returns:
            bool: True if the graph is valid, False otherwise.
    """
    if not is_valid_DAG(g, subg=subg):
        return False

    if subg:
        # Vertices sitting at main-path positions (2, 3, 4) must not be
        # typed 8 or 9.
        return not any(v['pos'] in (2, 3, 4) and v['type'] in (8, 9) for v in g.vs)

    diameter_path = g.get_diameter(directed=True)
    all_paths = calculate_all_shortest_paths(g, len(diameter_path))

    for paths in all_paths.values():
        for path in paths:
            cond = True
            # Only interior vertices are inspected; the types of the first
            # and last vertex of a candidate path are deliberately not
            # checked in this variant.
            for v_ in path[1:-1]:
                if g.vs[v_]['type'] in (4, 5):
                    # A type-4/5 vertex on the path is accepted only when a
                    # type-0/1 vertex shares one of its predecessors and one
                    # of its successors.
                    # NOTE(review): the result for the last type-4/5 vertex
                    # replaces the result of earlier ones on the same path;
                    # kept as-is to preserve the metric.
                    successors_ = set(g.successors(v_))
                    cond = any(
                        g.vs[v_cand]['type'] in (0, 1)
                        and successors_ & set(g.successors(v_cand))
                        for v_p in g.predecessors(v_)
                        for v_cand in g.successors(v_p)
                    )
            if cond:
                return True

    return False

def is_valid_circuit_cktgnn(g, subg=True):
    """
        Same metrics as CktGNN.
        Check if the given igraph g is a valid circuit graph.

        The check has two stages: the topology must be a valid DAG
        (is_valid_DAG), then the vertex types on the main path (the directed
        diameter path) are inspected.

        Args:
            g (igraph.Graph): The input graph. Vertices carry a 'type'
                attribute, plus a 'pos' attribute when subg is True.
            subg (bool): If the graph is a subgraph.
        Returns:
            bool: True if the graph is valid, False otherwise.
    """
    if subg:
        cond1 = is_valid_DAG(g, subg=True)
        # Vertices sitting at main-path positions (2, 3, 4) must not be
        # typed 8 or 9.
        cond2 = not any(v['pos'] in (2, 3, 4) and v['type'] in (8, 9) for v in g.vs)
        return cond1 and cond2

    cond1 = is_valid_DAG(g, subg=False)

    # The main path must start/end at the pseudo input/output vertices and
    # contain at least one vertex in between.
    diameter_path = g.get_diameter(directed=True)
    cond2 = len(diameter_path) >= 3
    last = len(diameter_path) - 1

    for i, v_ in enumerate(diameter_path):
        v_type = g.vs[v_]['type']
        if i == 0:
            if v_type != 8:
                cond2 = False
        elif i == last:
            if v_type != 9:
                cond2 = False
        elif v_type in (4, 5):
            # A type-4/5 vertex on the main path is accepted only when a
            # type-0/1 vertex shares one of its predecessors and one of its
            # successors.
            # The neighbourhood is looked up by vertex id (v_); it was
            # previously looked up by the position i within the path, which
            # inspects an unrelated vertex unless the two happen to coincide.
            # NOTE(review): this assignment replaces any earlier failure
            # recorded above; kept as-is to preserve the metric.
            successors_ = set(g.successors(v_))
            cond2 = any(
                g.vs[v_cand]['type'] in (0, 1)
                and successors_ & set(g.successors(v_cand))
                for v_p in g.predecessors(v_)
                for v_cand in g.successors(v_p)
            )

    return cond1 and cond2
    

def our_is_valid_circuit(g: igraph.Graph):
    """
        Check if the given igraph g is a valid circuit graph.

        Steps:
          1. The graph must pass is_graph_valid.
          2. The start vertex (type 8) and the end vertex (type 9) are located.
          3. The graph is converted to an undirected graph and check_valid_dfs
             searches for a main path between the two vertices, using its
             default forbidden types and depth limit.

        NOTE(review): earlier documentation described the main path as one
        avoiding types 4 and 5 (resistor and capacitor); the search actually
        runs with check_valid_dfs' default forbidden_types -- confirm which
        type codes are intended.

    Input:
        g: igraph.Graph
    Output:
        bool: True if the graph is valid, False otherwise.
    """
    if not is_graph_valid(g):
        return False

    start_vertex = end_vertex = None
    for vertex in g.vs:
        vertex_type = vertex['type']
        if vertex_type == 8:
            start_vertex = vertex
        elif vertex_type == 9:
            end_vertex = vertex

    if start_vertex is None or end_vertex is None:
        return False

    # The path search ignores edge direction.
    return check_valid_dfs(g.as_undirected(), start_vertex.index, end_vertex.index)


def ratio_same_DAG(G0, G1):
    """
        Same metrics as CktGNN.
        Compute the ratio of the number of graphs in G1 that are the same DAG as some graph in G0.
        Args:
            G0 (list): List of igraph.Graph objects.
            G1 (list): List of igraph.Graph objects.
        Returns:
            float: The ratio of the number of graphs in G1 that are the same
                DAG as some graph in G0; 0.0 when G1 is empty.
    """
    # An empty G1 has no graphs to match; avoid dividing by zero.
    if len(G1) == 0:
        return 0.0

    # how many G1 are in G0
    matched = sum(
        1 for g1 in tqdm(G1) if any(is_same_DAG(g1, g0) for g0 in G0)
    )
    return matched / len(G1)

def ratio_same_graphs(G0, G1, save_novelty=False):
    """
        Compute the ratio of the number of graphs in G1 that are the same as some graph in G0.

        Args:
            G0 (list): Graphs to search in.
            G1 (list): Graphs to look up.
            save_novelty (bool): When True, every graph in G0 gets a
                ``novel_graph`` attribute: True by default, and False for
                every G0 graph that equals (are_igraphs_equal) some graph
                in G1.
        Returns:
            float: Matched fraction of G1; 0.0 when G1 is empty.
    """
    # init all the generated graphs with novelty
    if save_novelty:
        for g0 in G0:
            g0.novel_graph = True

    # Nothing to look up; avoid dividing by zero.
    if len(G1) == 0:
        return 0.0

    # how many G1 are in G0
    res = 0
    for g1 in G1:
        matched = False
        for g0 in G0:
            if are_igraphs_equal(g1, g0):
                matched = True
                if not save_novelty:
                    # Only the count matters, so one hit is enough.
                    break
                # Keep scanning so that duplicates in G0 are flagged too,
                # not just the first match.
                g0.novel_graph = False
        if matched:
            res += 1

    return res / len(G1)


def unique_ratio(graphs):
    """
        Compute the ratio of the number of unique graphs in the list.

        Each graph gets a ``unique_graph`` attribute: True for the first
        occurrence of a graph (according to are_igraphs_equal) and False for
        later repeats.

        Args:
            graphs (list): List of graphs to compare with each other.
        Returns:
            float: Number of unique graphs divided by the number of graphs;
                0.0 when the list is empty.
    """
    # Nothing to compare; avoid dividing by zero.
    if len(graphs) == 0:
        return 0.0

    # how many unique graphs in the list
    unique_graphs = []
    for g in graphs:
        is_repeat = any(are_igraphs_equal(g, ug) for ug in unique_graphs)
        g.unique_graph = not is_repeat
        if not is_repeat:
            unique_graphs.append(g)

    return len(unique_graphs) / len(graphs)

def get_training_graphs(train_loader):
    """
        Collect every graph served by a data loader as an igraph graph.

        Each batch is split into its individual examples
        (batch.get_example). Every example is passed through
        remove_edges_with_attribute_value(graph, 0) and then converted with
        torch_geometric_to_igraph.
        NOTE(review): presumably the first helper drops edges whose
        attribute equals 0 -- confirm against its definition in .utils.

        Args:
            train_loader: Iterable of batches exposing ``num_graphs`` and
                ``get_example(i)``.
        Returns:
            list: The converted graphs, in loader order.
    """
    collected = []
    for batch in train_loader:
        for idx in range(batch.num_graphs):
            pruned = remove_edges_with_attribute_value(batch.get_example(idx), 0)
            collected.append(torch_geometric_to_igraph(pruned))
    return collected


def novelty_ratio(generated_graphs, train_loader):
    """
        Compute the novelty ratio of the generated graphs with respect to the training set.

        The training graphs are extracted from the loader and compared with
        the generated graphs through ratio_same_graphs(..., save_novelty=True),
        which also sets a ``novel_graph`` attribute on every generated graph.
        The returned value is 1 minus that ratio.

        NOTE(review): ratio_same_graphs divides by the number of graphs in
        its second argument, i.e. the training graphs here, so the result is
        1 minus the fraction of training graphs that were reproduced. If
        novelty is meant to be the fraction of generated graphs absent from
        the training set, the denominator should be confirmed.

        Args:
            generated_graphs (list): Generated igraph graphs.
            train_loader: Data loader yielding the training batches.
        Returns:
            float: The novelty ratio.
    """
    reference_graphs = get_training_graphs(train_loader)
    overlap = ratio_same_graphs(generated_graphs, reference_graphs, save_novelty=True)
    return 1 - overlap

def compute_VUN(generated_graphs):
    """
        Compute the fraction of generated graphs that are valid, unique and novel.

        A graph counts when its ``valid_graph``, ``valid_circuit``,
        ``unique_graph`` and ``novel_graph`` attributes are all truthy; these
        must have been set on every graph beforehand. Each graph receives a
        ``VUN`` attribute holding the outcome (True or False), so the flag is
        always defined and never left stale from an earlier call.

        Args:
            generated_graphs (list): Graphs carrying the four flags above.
        Returns:
            float: Number of VUN graphs divided by the number of graphs, or 0
                when the list is empty.
    """
    count_VUN = 0
    for g in generated_graphs:
        g.VUN = bool(
            g.valid_graph and g.valid_circuit and g.unique_graph and g.novel_graph
        )
        if g.VUN:
            count_VUN += 1

    return count_VUN / len(generated_graphs) if len(generated_graphs) > 0 else 0
def is_same_DAG(g0, g1):
    """
        Same metrics as CktGNN.
        Check if two igraph.Graph objects are the same DAG.

        The comparison is index-based and is not an isomorphism test: the
        graphs must have the same number of vertices, and for every vertex
        index both the 'type' attribute and the set of incoming neighbours
        must agree.

        Args:
            g0 (igraph.Graph): First graph.
            g1 (igraph.Graph): Second graph.
        Returns:
            bool: True if the graphs agree vertex by vertex, False otherwise.
    """
    vertex_count = g0.vcount()
    if vertex_count != g1.vcount():
        return False

    return all(
        g0.vs[idx]['type'] == g1.vs[idx]['type']
        and set(g0.neighbors(idx, 'in')) == set(g1.neighbors(idx, 'in'))
        for idx in range(vertex_count)
    )

def are_igraphs_equal(g0, g1, property_name="type"):
    """
        Check if two igraphs are equal.

        The comparison is signature-based, not an isomorphism test:
          * both graphs must have the same number of vertices and edges;
          * the multisets of vertex property values must be the same;
          * every vertex of g0 must be paired with a distinct vertex of g1
            that has the same property value and the same sorted list of
            neighbour property values.

        Args:
            g0 (igraph.Graph): First graph.
            g1 (igraph.Graph): Second graph.
            property_name (str): Vertex attribute used for the comparison.
        Returns:
            bool: True if the graphs are considered equal, False otherwise.
    """
    if g0.vcount() != g1.vcount() or g0.ecount() != g1.ecount():
        return False

    sorted_values0 = np.sort(np.array(g0.vs[property_name]))
    sorted_values1 = np.sort(np.array(g1.vs[property_name]))
    if not np.array_equal(sorted_values0, sorted_values1):
        return False

    # Indices of g1 vertices already paired with a g0 vertex; each may be
    # used only once.
    claimed = set()

    for v0 in g0.vs:
        wanted_value = v0[property_name]
        wanted_neighbors = get_neighbors_property_names(g0, v0.index, property_name)

        for v1 in g1.vs:
            if v1.index in claimed:
                continue
            if v1[property_name] == wanted_value:
                candidate_neighbors = get_neighbors_property_names(g1, v1.index, property_name)
                if wanted_neighbors == candidate_neighbors:
                    claimed.add(v1.index)
                    break
        else:
            # No unused g1 vertex matches this g0 vertex.
            return False

    return True

def get_neighbors_property_names(graph, vertex, property_name):
    """
        Get the property values of the neighbors of a vertex.

        Args:
            graph (igraph.Graph): The graph to query.
            vertex (int): Index of the vertex whose neighbours are inspected.
            property_name (str): Vertex attribute to read from each neighbour.
        Returns:
            list: The neighbours' property values, sorted in ascending order.
    """
    return sorted(
        graph.vs[neighbor][property_name] for neighbor in graph.neighbors(vertex)
    )


def check_valid_dfs(graph, start_vertex_idx, end_vertex_idx, forbidden_types=(0, 1), max_depth=4):
    """
        Search for a main path between two vertices with a depth-limited DFS.

        The edge list and the 'type' attribute of every vertex are extracted
        from the graph and handed to dfs_step. The search succeeds when the
        end vertex can be reached from the start vertex through vertices whose
        type is not forbidden. It fails immediately when the end vertex is
        directly adjacent to the start vertex, and it abandons a branch once
        the recursion depth reaches max_depth.

        Args:
            graph (igraph.Graph): The graph to search; edges are followed in
                both directions.
            start_vertex_idx (int): Index of the vertex the search starts at.
            end_vertex_idx (int): Index of the vertex to reach.
            forbidden_types: Vertex types that may not appear on the path.
                Only membership tests are performed on it; the default is an
                immutable tuple so it cannot be shared and mutated across calls.
            max_depth (int): Recursion depth at which a branch is abandoned.
        Returns:
            bool: True if such a path exists, False otherwise.
    """
    edge_list = graph.get_edgelist()
    node_types = graph.vs['type']
    # Only the success flag is needed; the path itself is discarded.
    _, valid = dfs_step(start_vertex_idx, [], edge_list, node_types,
                        forbidden_types, 0, max_depth, end_vertex_idx)

    return valid


def dfs_step(node_idx, visited, edge_list, node_types, forbidden_types, depth_idx, max_depth, final_node_idx):
    """
        One recursive step of the depth-limited search for a path to final_node_idx.

        Edges are followed in both directions. Neighbours that are already in
        ``visited`` or whose type is forbidden are skipped, and the remaining
        ones are explored in ascending index order.

        Args:
            node_idx (int): Vertex currently being expanded.
            visited (list): Vertices on the path leading to node_idx
                (node_idx itself excluded).
            edge_list (list): (source, target) index pairs.
            node_types (list): Type of every vertex, indexed by vertex index.
            forbidden_types: Types that may not appear on the path.
            depth_idx (int): Current recursion depth (0 at the start vertex).
            max_depth (int): Depth at which the branch is abandoned.
            final_node_idx (int): Vertex to reach.
        Returns:
            tuple: (path, True) with the full path when the final vertex is
                reached, otherwise (visited[:-1], False).
    """
    if depth_idx == max_depth:
        return visited[:-1], False

    # Neighbours in either edge direction.
    neighbours = [
        dst if src == node_idx else src
        for (src, dst) in edge_list
        if node_idx in (src, dst)
    ]
    eligible = np.unique([
        n for n in neighbours
        if (n not in visited) and (node_types[n] not in forbidden_types)
    ])

    if final_node_idx in eligible:
        if depth_idx == 0:
            # The final vertex is directly adjacent to the start vertex:
            # the whole search is rejected.
            return visited[:-1], False
        if depth_idx > 0:
            return visited + [node_idx, final_node_idx], True

    for child_idx in eligible.tolist():
        path, found = dfs_step(child_idx, visited + [node_idx], edge_list, node_types,
                               forbidden_types, depth_idx + 1, max_depth, final_node_idx)
        if found:
            return path, True

    return visited[:-1], False  # dead-end, remove last index
