import torch
import torch.nn as nn



def register(key, module, module_dict):
    if key in module_dict:
        raise KeyError('Key {} is already pre-defined.'.format(key))
    else:
        module_dict[key] = module

act_dict = {}
def register_act(key, module):
    register(key, module, act_dict)

node_encoder_dict = {}
def register_node_encoder(key, module):
    register(key, module, node_encoder_dict)

edge_encoder_dict = {}
def register_edge_encoder(key, module):
    register(key, module, edge_encoder_dict)

stage_dict = {}
def register_stage(key, module):
    register(key, module, stage_dict)

head_dict = {}
def register_head(key, module):
    register(key, module, head_dict)

layer_dict = {}
def register_layer(key, module):
    register(key, module, layer_dict)

pooling_dict = {}
def register_pooling(key, module):
    register(key, module, pooling_dict)

network_dict = {}
def register_network(key, module):
    register(key, module, network_dict)

config_dict = {}
def register_config(key, module):
    register(key, module, config_dict)

loader_dict = {}
def register_loader(key, module):
    register(key, module, loader_dict)

optimizer_dict = {}
def register_optimizer(key, module):
    register(key, module, optimizer_dict)

scheduler_dict = {}
def register_scheduler(key, module):
    register(key, module, scheduler_dict)

loss_dict = {}
def register_loss(key, module):
    register(key, module, loss_dict)

feature_augment_dict = {}
def register_feature_augment(key, module):
    register(key, module, feature_augment_dict)

train_dict = {}
def register_train(key, module):
    register(key, module, train_dict)