from abc import ABC, abstractmethod

from torch.nn import Module


class ModelConfig(ABC):
    """
    Abstract base class for model configurations.

    For each model class, a subclass of this class should exist that can hold the arguments for the model.
    Instances of that subclass can then directly be used to instantiate the model via `instantiate_model()`.

    Subclasses of this class should be annotated with `@yaml_object(YAML())` and `@dataclass()`, e.g.
    ```python
    from dataclasses import dataclass

    from ruamel.yaml import YAML, yaml_object

    @yaml_object(YAML())
    @dataclass()
    class MyModelConfig(ModelConfig):
        MODEL_NAME = "my-model"
        MODEL_CLASS = MyModel
        ...
    ```

    `MODEL_NAME` and `MODEL_CLASS` are declared abstract here, so `ABC` raises `TypeError` when instantiating a
    subclass that does not define both of them.
    """

    # Why plain abstract properties: chaining `@classmethod` around `@property` is deprecated since Python 3.11
    # and removed in 3.13. An abstract property still lets `ABC` enforce that concrete subclasses provide a
    # value. A subclass's class-level constant overrides it for both class and instance attribute access.

    @property
    @abstractmethod
    def MODEL_NAME(self) -> str:
        """
        The name of the model, e.g. used in file names.

        Subclasses should override this as a class-level constant without a type hint, like this:
        ```python
        class MyModelConfig(ModelConfig):
            MODEL_NAME = "my-model"
        ```
        The absence of a type hint means that no argument will be added to the constructor generated by the
        `@dataclass` decorator, which would allow the constant to accidentally be overridden.
        """
        ...

    @property
    @abstractmethod
    def MODEL_CLASS(self) -> type[Module]:
        """
        The model class that this config class is intended to configure.
        The constructor of that model class should take an instance of this config class as an argument, and no other
        arguments.

        Subclasses should override this as a class-level constant without a type hint, like this:
        ```python
        class MyModelConfig(ModelConfig):
            MODEL_CLASS = MyModel
        ```
        The absence of a type hint means that no argument will be added to the constructor generated by the
        `@dataclass` decorator, which would allow the constant to accidentally be overridden.
        """
        ...

    def instantiate_model(self) -> Module:
        """
        Instantiate the model based on this configuration.

        Returns:
            A new instance of `self.MODEL_CLASS`, constructed with this config object as its only argument.
        """
        return self.MODEL_CLASS(self)
