import torch
import torch.nn as nn
from torch_geometric.nn.aggr import AttentionalAggregation, SetTransformerAggregation, GraphMultisetTransformer


class Aggregator(nn.Module):
    """Pool a batch of equally sized sets into one vector per batch element.

    The constructor selects one PyTorch Geometric aggregation layer by name
    and stores it as ``self.pool``; ``forward`` flattens the batch and set
    dimensions and tells the layer which rows belong to which batch element.

    Args:
        input_dim: Feature size of every set element.
        hidden_dim: Width of the hidden layer in the gate and feature MLPs.
            Only used by ``'attentional_aggregation'``.
        output_dim: Output size of the feature MLP. Only used by
            ``'attentional_aggregation'``; the transformer-based branches are
            built from ``input_dim`` alone (their output width is presumably
            ``input_dim`` -- confirm against the PyG documentation).
        pooling_method: One of ``'attentional_aggregation'``,
            ``'set_transformer'``, ``'last_set_transformer'``,
            ``'graph_multiset_transformer'`` or
            ``'last_graph_multiset_transformer'``. The ``last_*`` names build
            the same layer as their un-prefixed counterparts.

    Raises:
        ValueError: If ``pooling_method`` is not one of the names above.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 pooling_method: str) -> None:
        super().__init__()
        if pooling_method == 'attentional_aggregation':
            # gate_nn scores each element with a single scalar; nn maps each
            # element to output_dim features before the weighted reduction.
            self.pool = AttentionalAggregation(
                gate_nn=nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.SiLU(),
                    nn.Linear(hidden_dim, 1),
                ),
                nn=nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.SiLU(),
                    nn.Linear(hidden_dim, output_dim),
                ),
            )
        elif pooling_method in ('set_transformer', 'last_set_transformer'):
            # NOTE(review): heads=8 presumably requires input_dim to be
            # divisible by 8 -- confirm against the PyG implementation.
            self.pool = SetTransformerAggregation(
                input_dim, heads=8, num_encoder_blocks=4, num_decoder_blocks=4)
        elif pooling_method in ('graph_multiset_transformer',
                                'last_graph_multiset_transformer'):
            self.pool = GraphMultisetTransformer(input_dim, k=8, heads=8)
        else:
            raise ValueError(f'Unknown pooling method {pooling_method}')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Aggregate over the second dimension of ``x``.

        Args:
            x: Tensor of shape ``(batch, set_size, features)``.

        Returns:
            Whatever the selected pooling layer returns for the flattened
            input; one row per batch element is requested via ``dim_size``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                'Expected a 3-D input (batch, set, features), '
                f'got shape {tuple(x.shape)}')
        num_sets, set_size = x.shape[0], x.shape[1]
        # Row i of the flattened input belongs to batch element i // set_size.
        # The resulting index is sorted, as produced by arange + repeat_interleave.
        batch_idx = torch.arange(num_sets, device=x.device).repeat_interleave(set_size)
        # Pass dim_size explicitly so the number of output rows is the batch
        # size rather than something inferred from the index (which would be
        # wrong for an empty set dimension).
        return self.pool(x.flatten(0, 1), index=batch_idx, dim_size=num_sets)
