from torch_geometric.graphgym.register import register_config


@register_config('custom_gnn')
def custom_gnn_cfg(cfg):
    """Set the default option values used by the CustomGNN network model.

    Most options are written into the ``gnn`` group of the given config;
    ``devices_ids`` is written at the top level of the config instead.

    Args:
        cfg: The config object to populate. It is modified in place and
            nothing is returned.
    """

    # Top-level option (not part of the ``gnn`` group); empty by default.
    cfg.devices_ids = []

    # Every remaining option belongs to the ``gnn`` group.
    gnn = cfg.gnn

    gnn.layers = 1

    # Normalization to apply: layer, batch or None.
    gnn.norm_type = None

    # No layer type is selected by default.
    gnn.layer_type = None

    # Only relevant when GAT is used.
    gnn.n_heads = 2

    gnn.dropout = 0.0

    # NOTE(review): the role of ``alpha`` is not visible in this file --
    # confirm against the model code.
    gnn.alpha = 0.5

    # NOTE(review): presumably degree information required by PNA layers --
    # confirm against the model code.
    gnn.pna_degrees = None

    # NOTE(review): ``jk`` presumably toggles jumping knowledge and
    # ``jk_mode`` selects how it aggregates -- confirm against the model code.
    gnn.jk = True
    gnn.jk_mode = 'max'

