from typing import Optional

import torch.nn as nn
import torch_geometric.nn as pyg

from global_pooling_layers import TemporalGlobalPoolingLayer

def get_graph_pooling(pooling_type: str, emb_dim: Optional[int] = None):
    """
    Return the graph-level pooling operator for the given pooling type.

    Args:
        pooling_type (str): Type of pooling. One of 'sum', 'mean', 'max',
            'attention', or 'temporal_global_pooling'.
        emb_dim (int, optional): Node embedding dimension. Required for the
            learnable poolers ('attention' and 'temporal_global_pooling');
            ignored for 'sum', 'mean' and 'max'.

    Returns:
        callable: The pooling operator.

        - 'sum', 'mean' and 'max' return a PyG pooling function
          (``global_add_pool`` / ``global_mean_pool`` / ``global_max_pool``).
        - 'attention' returns a newly constructed ``GlobalAttention`` whose
          gate is ``nn.Linear(emb_dim, 1)``. The gate has learnable
          parameters, so the caller should register the returned object
          (e.g. as a submodule) for it to be trained.
        - 'temporal_global_pooling' returns a new
          ``TemporalGlobalPoolingLayer`` instance. It is presumably also a
          module with parameters; confirm against its definition.

    Raises:
        ValueError: If ``pooling_type`` is not recognised, or if ``emb_dim``
            is not given for a pooling type that requires it.
    """
    if pooling_type == 'sum':
        return pyg.global_add_pool
    if pooling_type == 'mean':
        return pyg.global_mean_pool
    if pooling_type == 'max':
        return pyg.global_max_pool

    if pooling_type in ('attention', 'temporal_global_pooling'):
        # Both learnable poolers size their layers from emb_dim. Fail early
        # with a clear message instead of passing None into a constructor.
        if emb_dim is None:
            raise ValueError(f"emb_dim must be provided for {pooling_type} pooling")

        if pooling_type == 'attention':
            # The gate maps each emb_dim-sized node embedding to one scalar score.
            return pyg.GlobalAttention(gate_nn=nn.Linear(emb_dim, 1))

        # NOTE(review): the meaning of the positional args (4, 0.1) is defined
        # by TemporalGlobalPoolingLayer -- presumably number of heads and a
        # dropout rate; confirm before changing them.
        return TemporalGlobalPoolingLayer(emb_dim, emb_dim, 4, 0.1)

    raise ValueError(
        f"Invalid graph pooling type: {pooling_type!r}. Expected one of "
        "'sum', 'mean', 'max', 'attention', 'temporal_global_pooling'."
    )