from torch_geometric.graphgym.register import register_config


@register_config('custom_gnn')
def custom_gnn_cfg(cfg):
    """Add the extra ``cfg.gnn`` options used by the CustomGNN model.

    Extends GraphGym's built-in ``gnn`` config group with new keys and
    their default values. The config is mutated in place; nothing is
    returned.

    Args:
        cfg: The GraphGym config object whose ``gnn`` sub-node receives
            the new keys.
    """
    # Alias the sub-node once; every key below is set on it.
    gnn = cfg.gnn

    # Residual connections between the GNN layers (disabled by default).
    gnn.residual = False
    # Dropout rate, 0.0 by default.
    # NOTE(review): where this "global" dropout is applied is decided by
    # the model code, which is not visible here.
    gnn.global_dropout = 0.0
    # Pooling-layer selector; 'mean' by default.
    gnn.pooling_layer = 'mean'
    # Per-layer pooling mode: one of 'add_mean', 'subtract_mean',
    # 'learnable_mean', or None (the default; presumably means "off").
    gnn.layer_pooling = None
    # Feed-forward toggle, off by default. TODO confirm exact semantics.
    gnn.feedforward = False
    # Norm-weighting toggle, off by default. TODO confirm exact semantics.
    gnn.norm_weighting = False
