from torch_geometric.graphgym.register import register_config
from yacs.config import CfgNode as CN


@register_config("slt")
def set_cfg_slt(cfg):
    """Install the default ``cfg.slt`` sub-config on *cfg*.

    A fresh ``CN`` node is stored at ``cfg.slt`` and then populated with
    the default value of every SLT option, in the order listed below.
    The function is registered with GraphGym under the key ``"slt"``
    through ``register_config``.

    Args:
        cfg: The root config node that receives the ``slt`` sub-node.
    """
    # The dict literal lives inside the function so that mutable defaults
    # (e.g. the ``linear_sparsity`` list) are freshly built on every call.
    # NOTE(review): both ``<part>_batch_norm`` / ``<part>_layer_norm`` and
    # ``batchnorm_<part>`` / ``layernorm_<part>`` style keys are defined
    # here; confirm which family the consuming code actually reads.
    option_defaults = {
        "slt": False,
        "debug": False,
        "mpnn": False,
        "msa": False,
        "ffn": False,
        "encoder": False,
        "pred": False,
        "sm": False,
        "mm": False,
        "linear_sparsity": [1.0],
        "init_mode_weight": "signed_constant_SF",
        "init_mode_score": "kaiming_uniform",
        "enable_unshared": False,
        "init_scale_weight": 1.0,
        "init_scale_score": 1.0,
        "sparse_scheduling": "linear_ascending",
        "score_split": "linear",
        "folding": False,
        "pruning": "global",
        "enable_abs_pruning": True,
        "mm_step": False,
        "mm_delayed": False,
        "mpnn_batch_norm": True,
        "mpnn_layer_norm": False,
        "msa_batch_norm": True,
        "msa_layer_norm": False,
        "no_mask": False,
        "sign_mask": False,
        "adj_rand_pruning": False,
        "adj_pruning_rate": 1.0,
        "batch_affine": True,
        "embedding": False,
        "srste": False,
        "srste_decay": 0.0002,
        "learnable_sum": False,
        "minmax_scale": False,
        "learnable_weight_scaling": False,
        "slt_weight_scaling": False,
        "attention_scaling": 1.0,
        "bitlinear": False,
        "tome": False,
        "tome_r": 4,
        "save_fig": False,
        "batchnorm_mpnn": False,
        "layernorm_mpnn": False,
        "pairnorm_mpnn": False,
        "rmsnorm_mpnn": False,
        "batchnorm_msa": False,
        "layernorm_msa": False,
        "pairnorm_msa": False,
        "rmsnorm_msa": False,
        "batchnorm_ffn": False,
        "layernorm_ffn": False,
        "pairnorm_ffn": False,
        "rmsnorm_ffn": False,
        "node_perturbation": 0.0,
        "edge_perturbation": 0.0,
        "train_data_delete": 0.0,
        "remove_ood_classes": 0,
    }

    # Attach the (empty) node first, then fill it attribute by attribute so
    # each value goes through CfgNode's own attribute-assignment path.
    cfg.slt = CN()
    slt_node = cfg.slt
    for option_name, default_value in option_defaults.items():
        setattr(slt_node, option_name, default_value)
