from typing import List, Optional
from torch import Tensor
from torch_geometric.typing import Adj, OptTensor, SparseTensor

from networkx.algorithms import connected_components as nx_connected_components
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

def connected_components(
    edge_index: Adj,
    num_nodes: int,
    edge_weight: OptTensor = None,
    largest: bool = False,
    single: bool = False,
    dtype: Optional[torch.dtype] = None,
) -> List[int]:
    """Label the (weakly) connected components of a graph.

    Every edge ``(u, v)`` joins ``u`` and ``v`` regardless of which
    orientation(s) appear in ``edge_index``, so the input does not need to be
    symmetric (one direction per edge is enough).

    Args:
        edge_index: Graph connectivity. Either a ``Tensor`` of shape ``[2, E]``
            (row 0 = source ids, row 1 = target ids), an object exposing
            ``coo() -> (row, col, value)`` such as ``torch_sparse.SparseTensor``,
            or any pair of integer sequences ``(sources, targets)``.
        num_nodes: Number of nodes; node ids must lie in ``[0, num_nodes)``.
            Nodes without edges form their own singleton components.
        edge_weight: Ignored; connectivity does not depend on edge weights.
        largest: If ``True``, return a 0/1 indicator of the largest component
            instead of per-node component labels. On ties, the component
            containing the smallest node id wins.
        single: Currently unused; kept for signature compatibility.
        dtype: Currently unused; kept for signature compatibility.

    Returns:
        A list of ``num_nodes`` ints. Without ``largest``, entry ``i`` is the
        component label of node ``i``, with labels ``0, 1, ...`` assigned in
        order of each component's smallest node id. With ``largest``, entry
        ``i`` is ``1`` if node ``i`` belongs to the largest component, else
        ``0``.

    Raises:
        IndexError: If an edge references a node id outside ``[0, num_nodes)``.
    """
    # Pull both endpoint lists out as plain Python ints.
    if isinstance(edge_index, Tensor):
        rows, cols = edge_index[0].tolist(), edge_index[1].tolist()
    elif hasattr(edge_index, "coo"):
        # Sparse adjacency (e.g. torch_sparse.SparseTensor): coo() -> (row, col, value).
        row, col, _ = edge_index.coo()
        rows, cols = row.tolist(), col.tolist()
    else:
        rows = [int(u) for u in edge_index[0]]
        cols = [int(v) for v in edge_index[1]]

    parent = list(range(num_nodes))

    def find(node: int) -> int:
        # Iterative root lookup with path halving (no recursion-depth limit).
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    for u, v in zip(rows, cols):
        if not (0 <= u < num_nodes and 0 <= v < num_nodes):
            raise IndexError(
                f"edge ({u}, {v}) references a node outside [0, {num_nodes})"
            )
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            # Which root becomes the parent is arbitrary: labels are
            # renumbered below by first appearance.
            parent[max(root_u, root_v)] = min(root_u, root_v)

    # Number components in order of their smallest node id. Scanning nodes
    # 0..n-1 and labelling each root on first sight achieves this.
    label_of_root = {}
    labels = []
    for node in range(num_nodes):
        labels.append(label_of_root.setdefault(find(node), len(label_of_root)))

    if not largest:
        return labels
    if not labels:
        # Empty graph: there is no largest component to mark.
        return []

    sizes = [0] * len(label_of_root)
    for label in labels:
        sizes[label] += 1
    # max() returns the first maximum, i.e. the lowest label on ties.
    best = max(range(len(sizes)), key=sizes.__getitem__)
    return [1 if label == best else 0 for label in labels]