from torch_geometric.graphgym.register import register_config
from yacs.config import CfgNode as CN


@register_config('derivative_encoder_cfg')
def derivative_encoder_cfg(cfg):
    """Attach the default ``derivative_encoder`` config node to ``cfg``.

    Registered with GraphGym under the name ``'derivative_encoder_cfg'``.
    A fresh ``CN`` node is created, assigned to ``cfg.derivative_encoder``,
    and then filled with the defaults listed below, in declaration order.

    Args:
        cfg: The root config being extended; it is mutated in place.
    """
    defaults = {
        'enable': False,
        'num_layers': 10,
        'in_dim': 21,
        'emb_dim': 4,
        'hidden_dim': 64,
        'derivative_hidden_dim': 32,
        'add_residual': False,
        'track_running_stats': True,
        'activation': "relu",
        'use_features': True,
        'centrality_init': True,
        'max_degree': 1,
        'first_embedding_dim': 28,
        'x_0_embedding_dim': 28,
        # NOTE(review): "derivate" looks like a typo for "derivative", but the
        # key name is part of the config interface, so it is kept unchanged.
        'derivate_embedding_dim': 64,
        # NOTE(review): efficient_derivative_encoder names its equivalent key
        # ``encoder_dropout``; confirm which name each consumer reads.
        'dropout': 0.0,
        'derivative_batchnorm': False,
        # Everything below is declared but, per the original author's note,
        # not used for now.
        'sparse': False,
        'num_tasks': 1,
        'max_variables_per_node': None,
        'use_GINE_activation': False,
        'num_edge_emb': None,
    }

    node = CN()
    cfg.derivative_encoder = node
    for key, value in defaults.items():
        # setattr is equivalent to ``node.<key> = value``, so every entry goes
        # through CN's own __setattr__ just like a literal attribute assignment.
        setattr(node, key, value)


@register_config('efficient_derivative_encoder_cfg')
def efficient_derivative_encoder_cfg(cfg):
    """Attach the default ``efficient_derivative_encoder`` node to ``cfg``.

    Registered with GraphGym under the name
    ``'efficient_derivative_encoder_cfg'``. A fresh ``CN`` node is created,
    assigned to ``cfg.efficient_derivative_encoder``, and then filled with
    the defaults listed below, in declaration order.

    Args:
        cfg: The root config being extended; it is mutated in place.
    """
    defaults = {
        'enable': False,
        'num_layers': 10,
        'in_dim': 21,
        'emb_dim': 4,
        'hidden_dim': 64,
        'derivative_hidden_dim': 32,
        'add_residual': False,
        'track_running_stats': True,
        'activation': "relu",
        'use_features': True,
        'centrality_init': True,
        'max_degree': 1,
        'first_embedding_dim': 28,
        'x_0_embedding_dim': 28,
        # NOTE(review): "derivate" looks like a typo for "derivative", but the
        # key name is part of the config interface, so it is kept unchanged.
        'derivate_embedding_dim': 64,
        # NOTE(review): derivative_encoder names its equivalent key
        # ``dropout``; confirm which name each consumer reads.
        'encoder_dropout': 0.0,
        'derivative_batchnorm': False,
        # Everything below is declared but, per the original author's note,
        # not used for now.
        'sparse': False,
        'num_tasks': 1,
        'max_variables_per_node': None,
        'use_GINE_activation': False,
        'num_edge_emb': None,
    }

    node = CN()
    cfg.efficient_derivative_encoder = node
    for key, value in defaults.items():
        # setattr is equivalent to ``node.<key> = value``, so every entry goes
        # through CN's own __setattr__ just like a literal attribute assignment.
        setattr(node, key, value)
