from torch_scatter import scatter

from MegaGNN.graphgym.register import register_pooling


def global_example_pool(x, batch, size=None):
    size = batch.max().item() + 1 if size is None else size
    return scatter(x, batch, dim=0, dim_size=size, reduce='add')


register_pooling('example', global_example_pool)
