from MegaGNN.graphgym.register import register_config
from yacs.config import CfgNode as CN


@register_config('cfg_gt')
def set_cfg_gt(cfg):
    """Attach the ``gt`` option group and its default values to ``cfg``.

    A fresh ``CfgNode`` is stored as ``cfg.gt`` and then populated in place.
    Nothing is returned; ``cfg`` is mutated.

    Args:
        cfg: Root config node that receives the new ``gt`` child node.
    """
    # Graph Transformer argument group. The node is read back from `cfg` so
    # every assignment below lands on the object that `cfg` actually holds.
    cfg.gt = CN()
    gt = cfg.gt

    # ---- Architecture ----------------------------------------------------
    # Prediction head. 'default' means: use cfg.dataset.task.
    gt.head = 'default'
    # Type of Graph Transformer layer to use.
    gt.layer_type = 'SANLayer'
    # Number of layers placed before the graph transformer.
    gt.layers_pre_gt = 0
    # Number of Transformer layers in the model.
    gt.layers = 3
    # Number of layers placed after the graph transformer.
    gt.layers_post_gt = 0
    # Number of attention heads in the Graph Transformer.
    gt.attn_heads = 1
    # Size of the hidden node and edge representation.
    gt.dim_hidden = 64

    # ---- Edge handling / SAN-specific options ----------------------------
    # Turn on edge-type based weight.
    gt.edge_weight = True
    # Full-attention SAN transformer including all possible pairwise edges.
    gt.full_graph = True
    # SAN weighting coefficient between real and fake edge attention.
    gt.gamma = 1e-5

    # Histogram of in-degrees of nodes in the training set, used by PNAConv
    # (relevant when `gt.layer_type: PNAConv+...`). When left empty it is
    # precomputed during the dataset loading process.
    gt.pna_degrees = []

    # ---- Dropout ---------------------------------------------------------
    # Dropout for input.
    gt.input_dropout = 0.0
    # Dropout in feed-forward module.
    gt.dropout = 0.0
    # Dropout in self-attention.
    gt.attn_dropout = 0.0

    # ---- Normalization switches (all disabled by default) ----------------
    # NOTE(review): presumably toggle layer norm / batch norm / L2
    # normalization in the model; the consuming code is not visible here.
    gt.layer_norm = False
    gt.batch_norm = False
    gt.l2_norm = False

    # ---- Miscellaneous model options -------------------------------------
    # NOTE(review): the accepted values and exact semantics of the next five
    # options are defined by the model code that reads them -- confirm there.
    gt.residual = 'Fixed'
    gt.jumping_knowledge = False
    gt.act = 'relu'
    gt.ffn = 'none'

    # Attention masking, options: "none", "Edge", "kHop".
    gt.attn_mask = 'Edge'
    # kHop attention parameter.
    gt.hops = 2

    # NOTE(review): presumably the number of virtual nodes to add; 0 by
    # default -- confirm against the consuming code.
    gt.virtual_nodes = 0
