import torch.nn as nn
from antgine.callback import Callback


class SetScalesCallback(Callback):
    """
        BNN+ callback that initialises the model's scaling factors.

        When training starts, the wrapped model's ``set_scales()`` method is
        invoked once so the scaling factors hold their initial values.
    """

    def __init__(self, model: nn.Module) -> None:
        """
        :param torch.nn.Module model: Model to prepare. It must provide a
            ``set_scales()`` method, which is called when training begins.
        """
        super(SetScalesCallback, self).__init__()
        self._model = model

    def on_train_begin(self) -> None:
        """
            Initialise the scaling factors right before training starts.
        """
        # Delegate entirely to the model; this callback holds no other state.
        set_scales = self._model.set_scales
        set_scales()
