from torch import nn


class BaseModel(nn.Module):
    """Base class for models with an extra ``inferencing`` flag.

    Besides the usual ``training`` flag managed by ``nn.Module``, every
    submodule (including the model itself) carries an ``inferencing``
    attribute that is kept in sync by :meth:`inference` and :meth:`train`.
    """

    def __init__(self, config, global_config):
        super().__init__()
        self.config = config
        self._global_config = global_config
        # Initialise the flag so it is readable before ``build()`` /
        # ``train()`` / ``inference()`` have been called.
        self.inferencing = False

    def _build(self):
        """Function to be implemented by the child class, in case they need to
        build their model separately than ``__init__``. All model related
        downloads should also happen here.
        """
        raise NotImplementedError(
            "Build method not implemented in the child model class."
        )

    def build(self):
        """Build the model via ``_build`` and clear the inferencing flag."""
        self._build()
        self.inference(False)

    def _set_inferencing(self, mode):
        """Set ``inferencing = mode`` on this module and all its submodules."""
        # ``self.modules()`` yields ``self`` first, so the root is covered too.
        for module in self.modules():
            module.inferencing = mode

    def inference(self, mode=True):
        """Toggle inference mode.

        When ``mode`` is truthy, the model is also switched to eval mode via
        ``nn.Module.train(False)``. A falsy ``mode`` only clears the
        ``inferencing`` flag and leaves the training/eval mode untouched.

        Args:
            mode: Value assigned to ``inferencing`` on every submodule.

        Returns:
            ``self``, mirroring ``nn.Module.train`` so calls can be chained.
        """
        if mode:
            # Call the base implementation directly so this class's ``train``
            # override does not run for ``self``.
            super().train(False)
        self._set_inferencing(mode)
        return self

    def train(self, mode=True):
        """Set training mode; entering training also clears ``inferencing``.

        Args:
            mode: Forwarded to ``nn.Module.train``. When truthy, the
                ``inferencing`` flag is set to ``False`` on every submodule
                first; when falsy (e.g. via ``eval()``), it is left as is.

        Returns:
            ``self``. ``nn.Module.eval()`` returns whatever ``train(False)``
            returns, so this must be returned for ``eval()`` to work as
            expected.
        """
        if mode:
            self._set_inferencing(False)
        return super().train(mode)
