from graph_learning.config import Config, config_dispatch, init_context_namespace

import graph_learning.utils as u
import torch

@Config.register('module', is_fold=True,
                 help="""
                 Build neural network model.
                 """)
class ModuleConfig(Config):
    """Base configuration for components that build a neural-network module.

    Subclasses supply :attr:`builder`. :meth:`build` then constructs the module
    from the parsed configuration and stores it on ``context`` under the name
    given with ``--as``.
    """

    def __init__(self, args, context):
        # Nothing beyond the base-class initialization.
        super().__init__(args, context)

    @classmethod
    def define_parser(cls, parser):
        """Add the arguments shared by every module component.

        ``--as``: name under which the built module is registered on the
        context (parsed into the ``name`` attribute).
        """
        super().define_parser(parser)
        parser.add_argument(
            '--as',
            dest='name',
            help='the name of the module, performs as the key to access the module.',
        )

    @property
    def builder(self):
        """Class used to construct this component's module.

        Returns
        -------
        type(torch.nn.Module)
            Module implementation class; concrete subclasses must override
            this property.

        Raises
        ------
        NotImplementedError
            Always, on this base class.
        """
        raise NotImplementedError

    @classmethod
    def build(cls, args, context):
        """Build the module described by ``args`` and register it on ``context``.

        The result is stored under ``config.name`` (the ``--as`` value), so
        later modules can pick it up and use it as a submodule.
        """
        cfg = cls(args, context)
        # NOTE(review): config_dispatch presumably instantiates ``cfg.builder``
        # using fields of ``cfg`` -- confirm in graph_learning.config.
        module = config_dispatch(cfg.builder, cfg)
        register_module(context, cfg.name, module)

def register_module(context, name, content):
    """Store ``content`` as attribute ``name`` of the shared ``context.module`` namespace.

    ``init_context_namespace`` is called first with a fresh ``u.Namespace()``.
    It presumably creates ``context.module`` only when it is absent (confirm in
    graph_learning.config). The attribute is then assigned on whatever
    ``context.module`` holds afterwards.
    """
    fresh_namespace = u.Namespace()
    init_context_namespace(context, 'module', fresh_namespace)
    registry = context.module
    setattr(registry, name, content)

def get_module(context, name):
    """Look up a module stored in ``context.module`` by ``register_module``.

    Parameters
    ----------
    context
        Object that carries (or will carry) the shared ``module`` namespace as
        an attribute.
    name : str or None
        Key the module was registered under.

    Returns
    -------
    object or None
        The registered object. Returns ``None`` when ``name`` is ``None``,
        when nothing is registered under ``name``, or when ``context`` has no
        ``module`` namespace yet (nothing has been registered).
    """
    if name is None:
        return None
    # ``context.module`` only exists after register_module has initialized it.
    # Treat a missing namespace like a missing name instead of raising
    # AttributeError.
    namespace = getattr(context, 'module', None)
    return getattr(namespace, name, None)
