from torch_geometric.graphgym.register import register_config
from yacs.config import CfgNode as CN


@register_config('derivative_preprocessing_cfg')
def derivative_preprocessing_cfg(cfg):
    """Attach the default ``derivative_preprocessing`` option group to ``cfg``.

    A new ``CfgNode`` is filled with the default value of every option in the
    group and then stored on ``cfg`` as ``cfg.derivative_preprocessing``.
    ``cfg`` is modified in place; nothing is returned.

    NOTE(review): the meaning of each option is defined by whatever code reads
    ``cfg.derivative_preprocessing``, which is not visible in this file; the
    comments below only describe the defaults themselves.

    Args:
        cfg: root config node to extend -- presumably the GraphGym global
            ``cfg`` handed to functions registered via ``register_config``
            (confirm against the caller).
    """
    node = CN()

    node.enable = False
    node.learned = False
    # Empty string by default; presumably means "no path given" -- confirm
    # with the code that consumes this option.
    node.original_model_path = ''
    node.num_layers = 10
    node.in_dim = 21
    node.emb_dim = 4
    node.hidden_dim = 64
    node.add_residual = False
    node.track_running_stats = True
    node.activation = "relu"
    node.use_features = True
    node.max_degree = 1
    node.first_embedding_dim = 28
    node.x_0_embedding_dim = 22
    # NOTE: key is spelled "derivate_..." (not "derivative_..."); it is kept
    # verbatim since it is presumably looked up by this exact name elsewhere.
    node.derivate_embedding_dim = 70
    # None defaults; recent yacs versions allow a None default to be replaced
    # by any valid value type when a YAML config is merged in.
    node.max_variables_per_node = None
    node.centrality_init = False
    node.raw_norm_type = None
    node.batchnorm = False
    node.derivative_hidden_dim = 32

    # Attach the fully populated group to the root config in one step.
    cfg.derivative_preprocessing = node

    
    