import torch
import torch.nn as nn

class Scale(nn.Module):
    """
        Scaling factor module.

        Multiplies its input by a learnable per-channel scaling factor. The
        factor is broadcast along dimension 1 of the input (the channel /
        feature dimension). The module therefore handles fully-connected
        activations ``(N, C)``, 1-D convolution activations ``(N, C, L)``,
        2-D convolution activations ``(N, C, H, W)`` and higher ranks alike.
    """
    def __init__(self, tensor):
        """
        :param torch.Tensor tensor: Scaling factor vector, one entry per
            channel / feature of the inputs this module will receive. It is
            wrapped in an ``nn.Parameter`` so it is learnable.
        """
        super().__init__()

        # TODO scaling_factor might become negative, no issue in terms of training but inconsistent wrt BNN+ paper
        self.scaling_factor = nn.Parameter(tensor)

    def forward(self, input):
        """
        :param torch.Tensor input: Activations whose dimension 1 has as many
            entries as the scaling factor. For inputs of rank < 2 the factor
            is broadcast against the last dimension instead.
        :return: ``input`` scaled channel-wise by ``scaling_factor``.
        :rtype: torch.Tensor
        """
        # Reshape the factor to (1, C, 1, ..., 1): a singleton batch axis,
        # the channel axis, then one singleton per remaining axis. This lines
        # the factor up with dimension 1 for any input rank. The previous code
        # handled only rank 4 and otherwise aligned with the LAST axis, which
        # broke (or silently mis-scaled) rank-3 inputs.
        trailing = (1,) * max(input.dim() - 2, 0)
        scaling_factor = self.scaling_factor.view(1, -1, *trailing)
        return scaling_factor * input
