import copy
import networkx as nx
import numpy as np
import torch
from torch_geometric.utils import to_networkx



def get_transformation_mask(pyg_data):
    """Find the edges whose removal cuts off a piece of the graph, and that piece.

    ``pyg_data.edge_index`` must list every undirected edge twice, as
    consecutive reverse pairs ``(u, v), (v, u)`` (partially checked by an
    assert). For each pair the edge is removed from an undirected copy of the
    graph. If that separates ``u`` from ``v``, the smaller of the two resulting
    pieces is selected; ties are broken by ``nx.connected_components`` order.
    A piece of a single node is ignored. The selected node set is attached to
    the directed copy whose source node lies outside the piece; the other copy
    gets nothing.

    Args:
        pyg_data: PyG-style data object accepted by ``to_networkx`` and
            exposing ``edge_index`` of shape ``(2, E)``.

    Returns:
        mask_edges: bool ``np.ndarray`` of shape ``(E,)``. True for every
            directed edge that carries a node set.
        mask_rotate: bool ``np.ndarray`` of shape ``(mask_edges.sum(), N)``,
            with ``N`` the number of graph nodes. Row ``k`` marks the nodes
            attached to the ``k``-th True entry of ``mask_edges``.
    """
    G = to_networkx(pyg_data, to_undirected=False)
    # Built once; each iteration works on a cheap copy because it removes an edge.
    G_undirected = G.to_undirected()
    # .cpu() so that .numpy() also works for tensors living on a GPU.
    edges = pyg_data.edge_index.T.cpu().numpy()

    to_rotate = []
    for i in range(0, edges.shape[0], 2):
        assert edges[i, 0] == edges[i + 1, 1]
        u, v = int(edges[i, 0]), int(edges[i, 1])

        G2 = G_undirected.copy()
        G2.remove_edge(u, v)
        # Keep only the components touching this edge. For a connected graph this
        # is every component; for a graph that was already disconnected it stops
        # unrelated fragments from being mistaken for the piece cut off here.
        pieces = [c for c in nx.connected_components(G2) if u in c or v in c]
        if len(pieces) == 2:  # u and v are now separated: the edge was a bridge
            moved = list(sorted(pieces, key=len)[0])
            if len(moved) > 1:
                # Attach the piece to the directed copy pointing into it.
                if u in moved:
                    to_rotate += [[], moved]
                else:
                    to_rotate += [moved, []]
                continue
        to_rotate += [[], []]

    mask_edges = np.asarray([len(atoms) > 0 for atoms in to_rotate], dtype=bool)
    mask_rotate = np.zeros((int(np.sum(mask_edges)), G.number_of_nodes()), dtype=bool)
    # Iterate over to_rotate itself (one entry per edge_index column); counting
    # G.edges() can fall short when the DiGraph collapses duplicate edges.
    row = 0
    for atoms in to_rotate:
        if atoms:
            mask_rotate[row, np.asarray(atoms, dtype=int)] = True
            row += 1

    return mask_edges, mask_rotate

def get_component_batch(G):
    """
    Map every node of G to the index of its connected component.

    Components are numbered in the order ``nx.connected_components`` yields
    them. Nodes are assumed to be the integers ``0 .. N-1`` (standard PyG
    format). The result is a list whose ``i``-th entry is the component index
    of node ``i``. A ``KeyError`` is raised if some integer in that range is
    not a node of G.
    """
    component_of = {
        node: component_idx
        for component_idx, members in enumerate(nx.connected_components(G))
        for node in members
    }
    # Read back in node-index order rather than in component order.
    return [component_of[idx] for idx in range(G.number_of_nodes())]

def get_subgraphs(data):
    """
    Label each node with the id of the fragment it falls into once all
    edges flagged in ``data.mask_edges`` are cut.

    The graph in ``data`` is converted to an undirected NetworkX graph. Every
    edge selected by ``data.mask_edges`` is removed, and the connected
    components of what remains are numbered with ``get_component_batch``.

    Args:
        data: PyG-style data object with ``edge_index`` of shape ``(2, E)``
            and a boolean-like ``mask_edges`` of length ``E``.

    Returns:
        ``torch.long`` tensor of length ``num_nodes``; entry ``i`` is the
        component id of node ``i``.

    Raises:
        AssertionError: if ``data`` has no ``mask_edges`` attribute.
    """
    assert hasattr(data, 'mask_edges'), "data must have 'mask_edges' field"

    # to_networkx returns a brand-new graph that nothing else references,
    # so it can be cut in place without copying.
    G_cut = to_networkx(data, to_undirected=True)

    edge_index = data.edge_index
    # Coerce to a bool mask on the same device: a 0/1 integer array would
    # otherwise be interpreted as integer row indices, not as a mask.
    mask = torch.as_tensor(data.mask_edges, dtype=torch.bool, device=edge_index.device)
    rotatable_edges = edge_index.t()[mask]

    # remove_edges_from silently skips absent edges. In an undirected graph the
    # reverse copy (v, u) of an already-removed (u, v) is such an edge.
    G_cut.remove_edges_from(rotatable_edges.tolist())

    # Explicit dtype: torch.tensor([]) would default to float32 for an empty graph.
    return torch.tensor(get_component_batch(G_cut), dtype=torch.long)

