import torch_geometric.data as pygData
from torch_geometric.utils import to_dense_batch, to_dense_adj
from torch_geometric.utils import is_undirected
import torch

def batch2dense(batch: pygData.Batch, batch_size: int=None, max_num_nodes: int=None, permute: bool=False,
                aligned_size: int=32):
    """Convert a PyG mini-batch into padded dense node-feature and adjacency tensors.

    Node features are scattered with ``to_dense_batch`` and edges with
    ``to_dense_adj``. Both are padded to the same node count ``N``, which is
    rounded up to a multiple of ``aligned_size``.

    Args:
        batch: PyG ``Batch``. Reads ``ptr``, ``x``, ``batch``, ``edge_index``
            and ``edge_attr``.
        batch_size: Forwarded to ``to_dense_batch``. ``None`` lets it infer
            the number of graphs.
        max_num_nodes: Optional minimum padded node count. The padded size is
            never smaller than the largest graph in ``batch``, so no node is
            dropped. It is still rounded up to a multiple of ``aligned_size``.
            ``None`` means "size to the largest graph".
        permute: If True, draw one random permutation of the padded node axis
            (via ``torch.randperm``) and apply it to ``x``, the mask and both
            node axes of ``A``. The same permutation is used for every graph.
        aligned_size: Positive alignment for the padded node count.

    Returns:
        Tuple ``(A, x, nodemask, max_num_nodes)``:
        - ``A``: dense adjacency from ``to_dense_adj``; dims 1 and 2 index nodes.
        - ``x``: dense node features from ``to_dense_batch``; dim 0 is the
          graph index and dim 1 the padded node index.
        - ``nodemask``: True at padding positions (i.e. NOT a real node).
        - ``max_num_nodes``: the padded node count ``x.shape[1]`` as an int.

    Raises:
        ValueError: if ``aligned_size`` is not positive.
    """
    if aligned_size <= 0:
        raise ValueError(f"aligned_size must be positive, got {aligned_size}")

    # Per-graph node counts from the batch offsets. Guard the empty case,
    # since torch.max raises on an empty tensor. int() turns the 0-dim
    # tensor into a plain Python int for the downstream size arguments.
    counts = torch.diff(batch.ptr)
    largest = int(counts.max()) if counts.numel() > 0 else 0
    if max_num_nodes is None:
        target = largest
    else:
        # Honor the caller's request without truncating any graph.
        target = max(int(max_num_nodes), largest)
    target = ((target + aligned_size - 1) // aligned_size) * aligned_size

    x, nodemask = to_dense_batch(x=batch.x, batch=batch.batch, batch_size=batch_size, max_num_nodes=target)
    # to_dense_batch marks real nodes with True; invert so True means padding.
    nodemask = torch.logical_not(nodemask)
    max_num_nodes = x.shape[1]

    # Optional (expensive) sanity check, disabled by default:
    # assert is_undirected(batch.edge_index, batch.edge_attr)
    A = to_dense_adj(
        batch.edge_index,
        batch=batch.batch,
        edge_attr=batch.edge_attr,
        max_num_nodes=max_num_nodes,
    ).contiguous()

    if permute:
        perm = torch.randperm(max_num_nodes, device=x.device)
        x = x[:, perm]
        nodemask = nodemask[:, perm]
        # Permute rows and columns of the adjacency consistently.
        A = A[:, perm][:, :, perm]
    return A, x, nodemask, max_num_nodes
