from torch_geometric.graphgym.register import register_config
from yacs.config import CfgNode as CN


@register_config("cfg_sa")
def set_cfg_sa(cfg):
    """Add the ``cfg.sa`` argument group to a GraphGym config.

    Registered with GraphGym under the name ``"cfg_sa"``. It attaches a fresh
    ``sa`` sub-node to ``cfg`` with default hyperparameters (``depth`` and
    ``n_heads``). It also attaches a nested ``sa.node_decoder`` sub-node with its
    own ``depth`` and ``n_heads`` defaults. Judging by the ``n_heads`` keys, the
    group is presumably for a (self-)attention stack; confirm against the model
    code that reads these keys.

    NOTE(review): an earlier docstring listed SAN, vanilla Transformer/Performer
    and GPS as consumers of this group; that is not visible from this file, so
    verify before relying on it.

    Args:
        cfg: The GraphGym/yacs configuration node to extend. It is modified in
            place; any pre-existing ``cfg.sa`` node is replaced.
    """

    # "sa" argument group and its node-decoder sub-group. Both are created
    # empty here and then populated with defaults below.
    cfg.sa = CN()
    cfg.sa.node_decoder = CN()

    # Default depth of the main stack. Presumably the number of layers;
    # TODO confirm against the model that reads ``cfg.sa.depth``.
    cfg.sa.depth = 3

    # Default number of attention heads for the main stack (presumably).
    cfg.sa.n_heads = 1

    # Defaults for the node-decoder sub-group; the keys mirror the main group
    # (presumably same meaning: layer count and attention-head count).
    cfg.sa.node_decoder.depth = 2
    cfg.sa.node_decoder.n_heads = 2
