import torch_geometric.data as pygData
from torch_geometric.utils import to_dense_batch, to_dense_adj
from torch_geometric.utils import is_undirected
import torch

def batch2dense(batch: pygData.Batch, batch_size: int=None, max_num_nodes: int=None, permute: bool=False):
    """Convert a PyG mini-batch of sparse graphs into padded dense tensors.

    Args:
        batch: PyG ``Batch``; reads ``x``, ``batch``, ``edge_index``,
            ``edge_attr`` and ``pos`` (``pos`` may be None).
        batch_size: Number of graphs to pad to. It is forwarded to
            ``to_dense_batch`` and inferred there when None.
        max_num_nodes: Number of node slots per graph. It is forwarded to
            ``to_dense_batch`` and inferred there when None.
        permute: If True, draw one random permutation of the node slots
            (shared by every graph in the batch). Apply it consistently to
            ``x``, ``nodemask``, ``pos`` and both node axes of ``A``.

    Returns:
        Tuple ``(A, x, nodemask, pos, max_num_nodes)``:
            A: dense adjacency produced by ``to_dense_adj``. Node axes are
                dims 1 and 2.
            x: dense node features, padded along dim 1.
            nodemask: boolean mask where True marks a padding slot, i.e. the
                inverse of the mask returned by ``to_dense_batch``.
            pos: dense node positions aligned with ``x``, or None when
                ``batch.pos`` is None.
            max_num_nodes: the padded node dimension actually used
                (``x.shape[1]``).
    """
    x, nodemask = to_dense_batch(batch.x, batch.batch, batch_size=batch_size, max_num_nodes=max_num_nodes)
    # to_dense_batch marks real nodes with True; invert so True means "not a node".
    nodemask = torch.logical_not(nodemask)
    # Use the realized sizes so pos and A are padded exactly like x.
    batch_size = x.shape[0]
    max_num_nodes = x.shape[1]

    pos = None
    if batch.pos is not None:
        pos, _ = to_dense_batch(batch.pos, batch.batch, batch_size=batch_size, max_num_nodes=max_num_nodes)

    # NOTE: undirectedness of batch.edge_index is not checked here
    # (the is_undirected assertion was disabled in the original code).
    # NOTE(review): batch_size is not forwarded to to_dense_adj, so A's leading
    # dim may be smaller than x's if trailing graphs have no nodes -- confirm.
    A = to_dense_adj(batch.edge_index, batch.batch, batch.edge_attr, max_num_nodes).contiguous()

    if permute:
        perm = torch.randperm(max_num_nodes, device=x.device)
        x = x[:, perm]
        nodemask = nodemask[:, perm]
        if pos is not None:
            # Keep positions aligned with the permuted node slots of x / A.
            pos = pos[:, perm]
        # Permute rows (dim 1), then columns (dim 2), of every adjacency matrix.
        A = A[:, perm][:, :, perm]
    return A, x, nodemask, pos, max_num_nodes
