from magni.src.modules.compute_graph_magnitude import *
import numpy as np
import random
from magni.src.modules.utils import to_numpy
from magni.src.modules.compute_graph_magnitude import compute_magnitude_subgraphs, get_magdiff

def get_scores_edge(this_graph, edges, dist_fn, ts, original_magni, method="cholesky"):
    """Score every candidate edge by the magnitude change its contraction causes.

    For each ``(node_a, node_b)`` pair a fresh copy of ``this_graph`` is
    contracted: node_a inherits the neighbours that only node_b had, node_b is
    removed, and the magnitude of the contracted copy is compared with
    ``original_magni``. ``this_graph`` itself is left untouched.

    Args:
        this_graph: Graph exposing ``copy``, ``neighbors``, ``add_edge`` and
            ``remove_node`` (presumably a networkx graph -- confirm with callers).
        edges: Iterable of ``(node_a, node_b)`` pairs to evaluate.
        dist_fn: Forwarded unchanged to ``compute_magnitude_subgraphs``.
        ts: Forwarded to ``compute_magnitude_subgraphs``; when it holds more
            than one entry the score comes from ``get_magdiff``, otherwise from
            the absolute difference of the first magnitude values.
        original_magni: Reference magnitude values the contracted graphs are
            compared against.
        method: Forwarded unchanged to ``compute_magnitude_subgraphs``.

    Returns:
        list: One score per edge, in the same order as ``edges``.
    """
    edge_scores = []
    for node_a, node_b in edges:
        contracted = this_graph.copy()

        # Neighbours node_b has that node_a lacks, ignoring the pair itself.
        kept_by_a = set(contracted.neighbors(node_a)) - {node_b}
        owned_by_b = set(contracted.neighbors(node_b)) - {node_a}
        for inherited in owned_by_b.difference(kept_by_a):
            contracted.add_edge(node_a, inherited)

        # Dropping node_b also drops every edge still attached to it.
        contracted.remove_node(node_b)

        contracted_magni, _ = compute_magnitude_subgraphs(
            contracted, dist_fn=dist_fn, ts=ts, get_weights=False, method=method
        )
        edge_scores.append(
            get_magdiff(contracted_magni, original_magni, ts)
            if len(ts) > 1
            else abs(original_magni[0] - contracted_magni[0])
        )
    return edge_scores

def edge_pooling_magnitude_repeated(g, ts, dist_fn, original_magni=None, n_steps=None, method="cholesky"):
    """Apply ``edge_pooling_magnitude`` round after round and chain the results.

    Each round pools the graph produced by the previous one; the per-round
    assignment matrices are multiplied together so the returned matrix maps
    the nodes of ``g`` onto the nodes of the final graph.

    Args:
        g: Input graph exposing ``number_of_nodes``, ``number_of_edges`` and
            ``copy`` (presumably a networkx graph -- confirm with callers).
        ts: Forwarded to the magnitude / pooling helpers.
        dist_fn: Forwarded to the magnitude / pooling helpers.
        original_magni: Optional precomputed magnitude of ``g``; computed here
            when ``None``.
        n_steps: Requested number of merges; defaults to ``number_of_nodes - 1``.
        method: Forwarded to the magnitude / pooling helpers.

    Returns:
        tuple: ``(pooled_graph, None, None, S, edges_merged)`` where ``S`` is
        the product of the per-round assignment matrices (identity if no round
        ran) and ``edges_merged`` concatenates the per-round edge lists.
    """
    n_nodes = g.number_of_nodes()
    if n_steps is None:
        n_steps = n_nodes - 1

    if original_magni is None:
        # NOTE(review): this value is never forwarded to the per-round call
        # below, so each round works from its own reference -- confirm whether
        # that is intended or whether the computation here is redundant.
        original_magni, _ = compute_magnitude_subgraphs(
            g, dist_fn=dist_fn, ts=ts, get_weights=False, method=method
        )

    assignment = np.eye(n_nodes)
    merged_so_far = []
    current = g.copy()

    while True:
        # NOTE(review): the "+ 1" means the loop only continues while more
        # than one step is still outstanding (n_steps=1 performs no merge);
        # confirm this bound is deliberate.
        if current.number_of_nodes() <= g.number_of_nodes() + 1 - n_steps:
            break
        if current.number_of_edges() <= 0:
            break

        remaining = n_steps - (n_nodes - current.number_of_nodes())
        current, _, _, round_assignment, round_merged = edge_pooling_magnitude(
            current, ts, dist_fn, n_steps=remaining, method=method
        )

        # Compose: the newest round acts on the output of all earlier rounds.
        assignment = np.dot(round_assignment, assignment)
        merged_so_far = merged_so_far + round_merged

    return current, None, None, assignment, merged_so_far


def edge_pooling_magnitude(g, ts, dist_fn, original_magni=None, n_steps=None, method="cholesky"):
    """Greedily contract the lowest-scoring edges of ``g``.

    Every edge of ``g`` is scored once, up front, with ``get_scores_edge``.
    Then, up to ``n_steps`` times, an edge with the lowest remaining score is
    picked (ties broken at random via ``random.choice``), its second endpoint
    is merged into its first, and every candidate edge touching either
    endpoint is discarded. Scores are NOT recomputed after a merge.

    Args:
        g: Input graph; it is copied, never modified. Must expose ``copy``,
            ``number_of_nodes``, ``number_of_edges``, ``nodes``, ``edges``,
            ``neighbors``, ``add_edge`` and ``remove_node`` (presumably a
            networkx graph -- confirm with callers).
        ts: Forwarded to ``compute_magnitude_subgraphs`` / ``get_scores_edge``.
        dist_fn: Forwarded to ``compute_magnitude_subgraphs`` / ``get_scores_edge``.
        original_magni: Reference magnitude of ``g``; computed here when ``None``.
        n_steps: Maximum number of merges; defaults to ``number_of_nodes - 1``
            and is capped at the number of scored edges.
        method: Forwarded to ``compute_magnitude_subgraphs`` / ``get_scores_edge``.

    Returns:
        tuple: ``(pooled_graph, None, nodes_removed, S, edges_merged)``.
        ``S`` has one row per remaining node and one column per node of ``g``;
        each row marks the original nodes merged into that remaining node and
        is normalised to sum to 1. ``edges_merged`` lists the contracted edges
        in the order they were merged.

    Raises:
        IndexError: If the row index of the node being removed falls outside ``S``.
    """
    if original_magni is None:
        original_magni, _ = compute_magnitude_subgraphs(
            g, dist_fn=dist_fn, ts=ts, get_weights=False, method=method
        )
    this_graph = g.copy()
    n = g.number_of_nodes()

    if n_steps is None:
        n_steps = n - 1

    nodes_removed = []
    S = np.eye(n)
    edges_merged = []
    edges = list(this_graph.edges())  # candidate edges, scored once below

    scores = get_scores_edge(this_graph, edges, dist_fn, ts, original_magni, method=method)

    # Cannot perform more merges than there are scored edges.
    n_steps = min(n_steps, len(scores))

    for l in range(n_steps):
        # Row i of S corresponds to the i-th node of the current graph.
        this_nodes = list(this_graph.nodes())

        # Compare on an ndarray: ``list == scalar`` is only element-wise when
        # the scalar happens to be a numpy type, and is plain ``False`` for
        # Python floats, which would leave no candidate to choose from.
        score_arr = np.asarray(scores)
        lowest = score_arr[np.argsort(score_arr)[0]]
        min_score_indices = np.where(score_arr == lowest)[0]
        chosen_index = int(random.choice(min_score_indices))  # random tie-break
        chosen_edge = edges[chosen_index]
        node_a, node_b = chosen_edge

        indx_a = this_nodes.index(node_a)
        indx_b = this_nodes.index(node_b)

        if not 0 <= indx_b < S.shape[0]:
            raise IndexError(f"Index indx_b={indx_b} is out of bounds for array S with shape {S.shape}")

        # Fold node_b's row into node_a's row, then drop node_b's row.
        S[indx_a, :] += S[indx_b, :]
        S = np.delete(S, indx_b, axis=0)

        # Record the contraction so callers can see which edges were merged.
        edges_merged.append(chosen_edge)

        # node_a inherits the neighbours that only node_b had.
        neighbors_a = set(this_graph.neighbors(node_a)) - {node_b}
        neighbors_b = set(this_graph.neighbors(node_b)) - {node_a}
        for neighbor in neighbors_b.difference(neighbors_a):
            this_graph.add_edge(node_a, neighbor)

        # Removing node_b also removes all edges still attached to it.
        this_graph.remove_node(node_b)
        nodes_removed.append(node_b)

        if this_graph.number_of_edges() == 0:
            break

        # Discard every candidate touching either endpoint (this includes the
        # edge just merged): their scores were computed on the original graph.
        kept = [
            (edge, score)
            for edge, score in zip(edges, scores)
            if node_a not in edge and node_b not in edge
        ]
        edges = [edge for edge, _ in kept]
        scores = [score for _, score in kept]

        if not edges and l < n_steps - 1:
            print("No edges left to merge")
            break

    # Normalize each row in S by its sum
    row_sums = S.sum(axis=1, keepdims=True)
    S = S / row_sums

    return this_graph, None, nodes_removed, S, edges_merged
