import torch


class Base(torch.nn.Module):
    """
    The base class for modules. That contains a disable round mode
    """

    def __init__(self):
        super().__init__()

    def _set_child_attribute(self, attr, value):
        r"""Sets the module in rounding mode.

        This has any effect only on certain modules if variable type is
        discrete.

        Returns:
            Module: self
        """
        if hasattr(self, attr):
            setattr(self, attr, value)

        for module in self.modules():
            if hasattr(module, attr):
                setattr(module, attr, value)
        return self

    def set_temperature(self, value):
        self._set_child_attribute("temperature", value)

    def enable_hard_round(self, mode=True):
        self._set_child_attribute("hard_round", mode)

    def disable_hard_round(self, mode=True):
        self.enable_hard_round(not mode)
