from torch_geometric.utils import degree
from .pnorm import pnormdiffusion, cut
import numpy as np
import torch


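# TODO: finish debugging -- the function below reads several names (e.g. `col`,
# `nodes`, `delta`, `k`) that are not bound anywhere in its body.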
def separate_pnorm_neighbors(num_nodes: int, edge_index, use_swp_cut: bool = False):
    """Build a padded neighbor table and run a p-norm diffusion over it.

    Work in progress. The visible steps are:

    1. Compute a per-node degree vector with ``torch_geometric.utils.degree``.
    2. Fill a ``(num_nodes, max_degree)`` int32 table whose row ``i`` holds
       the targets of every edge whose source is ``i``.
    3. Call ``pnormdiffusion`` with ``p=4`` on that table.
    4. Build ``mask`` from the indices of the ``k`` largest diffusion values
       plus the seed node itself.
    5. If ``use_swp_cut`` is true, call ``cut`` and store one mask per
       ``(h, n)`` pair into ``hierarchical_diffusions['inx']``.

    Args:
        num_nodes: Number of nodes; used as the row count of the neighbor
            table and as the loop bound.
        edge_index: Indexed as ``edge_index[0]`` (sources) and
            ``edge_index[1]`` (targets); presumably a ``(2, num_edges)``
            torch tensor -- TODO confirm against the caller.
        use_swp_cut: When true, masks are taken from the output of ``cut``
            instead of the top-``k`` selection.

    Returns:
        Nothing in the visible code (there is no ``return`` statement).

    NOTE(review): the following names are read but never bound inside this
    function: ``col``, ``nodes``, ``delta``, ``limit_type``, ``h``, ``n``,
    ``k``, ``List``, ``num_hier``, ``hierarchical_diffusions``. Unless they
    are module-level globals defined elsewhere, the first statement raises
    ``NameError``. They likely need to become parameters or locals.
    """
    # NOTE(review): `col` is not defined here; presumably `edge_index[1]`.
    # The fill loop below counts edges by *source* (`edge_index[0] == i`), so
    # the two only agree when every edge has its reverse -- confirm.
    deg = degree(col, num_nodes).long()
    # Zero-initialised, so slots past deg[i] in a row stay 0 and are only
    # distinguishable from a real neighbor "node 0" by consulting `deg`.
    adjmat = np.zeros((num_nodes, deg.max()), dtype=np.int32)
    for i in range(num_nodes):
        # Targets of all edges whose source is node i.
        neighbors = edge_index[1, edge_index[0] == i]
        # The slice has length deg[i]; the assignment relies on that matching
        # the number of neighbors found above.
        adjmat[i, 0:deg[i]] = neighbors.cpu().numpy().astype(dtype=np.int32)

    # NOTE(review): `nodes`, `delta` and `limit_type` are not bound in this
    # function. `m_indices` and `m_vals` are not used in the visible code.
    x_indices, x_vals, m_indices, m_vals = \
        pnormdiffusion(adjmat, 
                       deg.cpu().detach().numpy().astype(np.int64),
                       num_nodes, 
                       nodes, 
                       delta=np.array(delta, dtype=np.float32),
                       p=4, 
                       limit_type=limit_type)

    # argsort is ascending, so the last k positions index the k largest
    # values; union1d then adds the seed node and returns sorted unique ids.
    # NOTE(review): `h`, `n` and `k` are unbound at this point (h/n only exist
    # as loop variables further down), and this `mask` is overwritten inside
    # the loop below or otherwise never used.
    mask = np.argsort(x_vals[h][n])[-k:]
    mask = np.union1d(mask, [nodes[n]])

    if use_swp_cut:
        # NOTE(review): `List` is not imported in the visible file; possibly
        # numba.typed.List -- TODO confirm.
        cut_ind = cut(num_nodes, adjmat, deg.cpu().numpy(),
                                     List([List(i) for i in x_indices]),
                                     List([List(v) for v in x_vals]))
        # NOTE(review): `num_hier` and `hierarchical_diffusions` are not bound
        # in this function.
        for h in range(num_hier):
            for n in range(len(nodes)):
                # Use the ids returned by `cut` for this (h, n) pair, always
                # including the seed node itself.
                mask = cut_ind[h][n]
                mask = np.union1d(mask, [nodes[n]])
                hierarchical_diffusions['inx'][h][n] = mask