from torch_scatter import scatter

from torch_geometric.graphgym.register import register_pooling


@register_pooling('example')
def global_example_pool(x, batch, size=None):
    size = batch.max().item() + 1 if size is None else size
    return scatter(x, batch, dim=0, dim_size=size, reduce='add')
