from torch_geometric.graphgym.register import register_config
from yacs.config import CfgNode as CN


@register_config("extended_model")
def extended_model_cfg(cfg):
    """Register default options for the extended model under ``cfg.model``.

    The config node passed in is mutated in place; nothing is returned.

    Args:
        cfg: Root yacs ``CfgNode`` being populated by the config registry.
    """
    # The base config is expected to define ``cfg.model`` already; only create
    # it when absent so an existing node (and its other keys) is left intact.
    if "model" not in cfg:
        cfg.model = CN()

    cfg.model.type = "PerceiverGraph_SingleDataset"

    # NOTE(review): this key was previously assigned "LapPE" here and then
    # immediately overwritten with "LinearNode", so the effective default has
    # always been "LinearNode". The dead "LapPE" store was removed; confirm
    # which default is actually intended.
    cfg.model.node_pos_encoder_name = "LinearNode"

    cfg.model.node_feat_encoder_name = "LinearNode"

    # Dropout rates (all disabled by default).
    cfg.model.ffn_dropout = 0.0
    cfg.model.attn_dropout = 0.0
    cfg.model.lin_dropout = 0.0

    cfg.model.loss_fun = "cross_entropy"

    # Size-related defaults.
    cfg.model.num_latents = 128
    cfg.model.latent_dim = 128
    cfg.model.dim_data_emb = 4
    cfg.model.tok_emb_dim = 4

    cfg.model.data_emb_init_scale = 0.02

    cfg.model.hop_cutoff = 50

    cfg.model.use_memory_efficient_attn = True

    # 0 presumably disables k-means clustering — confirm with the consumer.
    cfg.model.kmeans_clusters = 0

    # Pretrained-checkpoint selection. -1 and the *string* "None" (not Python
    # None) look like "no pretrained model" sentinels; consumers presumably
    # compare against these exact values, so keep them as-is.
    cfg.model.pretrained_epoch = -1
    cfg.model.pretrained_model_run_id = "None"
