from magni.src.modules.compute_graph_magnitude import *
import numpy as np
import random
from magni.src.modules.utils import to_numpy

def reconnect_neighbours(g, node):
    """Connect every pair of distinct neighbours of ``node`` to each other.

    Meant to be used around the removal of ``node`` so that its neighbours
    stay mutually connected. ``node`` must still be present in ``g`` when
    this is called, because its neighbours are looked up on the graph.

    Args:
        g: Graph object exposing ``neighbors(node)`` and ``add_edge(u, v)``
            (presumably a networkx graph -- confirm against callers).
            Modified in place.
        node: Node whose neighbours are pairwise connected.

    Returns:
        The same graph object ``g``.
    """
    # Snapshot the neighbours first: ``neighbors`` may hand back a one-shot
    # iterator, which the nested loop would exhaust during the first outer
    # iteration, leaving all but the first neighbour unconnected.
    neighbours = list(g.neighbors(node))

    # Both orderings (n, m) and (m, n) are visited, so on a directed graph
    # the edge is added in both directions.
    for n in neighbours:
        for m in neighbours:
            if n != m:
                g.add_edge(n, m)
    return g

def edge_dropping_random(g, n_steps=None):
    """Coarsen a graph by contracting randomly chosen edges.

    At every step one edge ``(a, b)`` with ``a != b`` is drawn uniformly at
    random via ``random.choice``. ``b`` is merged into ``a``: ``a`` gains an
    edge to each neighbour of ``b`` it is not already adjacent to, and ``b``
    is removed. The choice is purely random; no score is computed here.

    Args:
        g: Input graph. It must provide ``copy``, ``number_of_nodes``,
            ``number_of_edges``, ``nodes``, ``edges``, ``neighbors``,
            ``add_edge`` and ``remove_node`` (presumably a networkx graph --
            confirm against callers). It is not modified; all work happens
            on a copy.
        n_steps: Maximum number of contractions. Defaults to the number of
            nodes minus one, and is always capped by the initial number of
            edges.

    Returns:
        A 4-tuple ``(graph, None, nodes_removed, S)`` where

        * ``graph`` is the coarsened copy,
        * the second element is always ``None``,
        * ``nodes_removed`` lists the merged-away nodes in removal order,
        * ``S`` is a row-normalised assignment matrix with one row per
          surviving node and one column per original node; each row holds
          equal weights over the original nodes merged into that survivor.
    """
    this_graph = g.copy()
    n = g.number_of_nodes()

    if n_steps is None:
        n_steps = n - 1

    nodes_removed = []
    # Row i tracks which original nodes have been merged into the i-th
    # node of the working graph (in ``this_graph.nodes()`` order).
    S = np.eye(n)

    n_steps = min(n_steps, this_graph.number_of_edges())

    for _ in range(n_steps):
        # A self-loop cannot be contracted: both endpoints are the same
        # node, so "merging" it would delete the node and drop its row
        # from S. Only edges between distinct nodes are candidates.
        edges = [(u, v) for u, v in this_graph.edges() if u != v]
        if not edges:
            break
        this_nodes = list(this_graph.nodes())

        node_a, node_b = random.choice(edges)

        indx_a = this_nodes.index(node_a)
        indx_b = this_nodes.index(node_b)
        # Merge node_b into node_a in the assignment matrix.
        S[indx_a, :] += S[indx_b, :]
        S = np.delete(S, indx_b, axis=0)

        # Neighbours of node_b that node_a is not yet connected to.
        neighbors_a = set(this_graph.neighbors(node_a)) - {node_b}
        neighbors_b = set(this_graph.neighbors(node_b)) - {node_a, node_b}
        merged_neighbors = neighbors_b - neighbors_a

        for neighbor in merged_neighbors:
            this_graph.add_edge(node_a, neighbor)

        # Remove node_b together with all of its edges.
        this_graph.remove_node(node_b)
        nodes_removed.append(node_b)

    # Every row contains at least its own original node, so the row sums
    # are never zero.
    row_sums = S.sum(axis=1, keepdims=True)
    S = S / row_sums

    return this_graph, None, nodes_removed, S

def to_nx_graph(X, A):
    """Build a graph from the adjacency matrix ``A``.

    ``A`` is passed through ``to_numpy`` and the result is handed to
    ``nx.from_numpy_array``. ``X`` is accepted but not used.

    NOTE(review): ``nx`` is not imported explicitly in this module;
    presumably it is networkx arriving through the star import at the top
    of the file -- confirm.
    """
    adjacency = to_numpy(A)
    return nx.from_numpy_array(adjacency)