import torch

class FakeBN(torch.nn.Module):
    """Randomly shift and rescale a tensor with a fresh scalar pair per call.

    On every forward pass one scalar ``shift`` ~ N(shift_mean, shift_std**2)
    and one scalar ``scaling`` ~ N(scaling_mean, scaling_std**2) are drawn on
    the input's device, and the output is ``(x + shift) / scaling``. The same
    pair is broadcast over the whole input tensor. The name suggests this
    imitates the perturbation batch-norm statistics introduce — NOTE(review):
    confirm intended use with callers.

    The noise is applied regardless of ``self.training``; the module holds no
    learnable parameters or buffers.

    Args:
        shift_mean: Mean of the additive shift.
        scaling_mean: Mean of the divisor. Keep it well away from 0, since a
            sample near 0 makes the output blow up.
        shift_std: Standard deviation of the shift (``torch.normal`` rejects
            negative values).
        scaling_std: Standard deviation of the divisor (``torch.normal``
            rejects negative values).
    """

    def __init__(self, shift_mean: float = 0, scaling_mean: float = 1,
                 shift_std: float = 2.5e-2, scaling_std: float = 1e-2):
        super().__init__()

        self.shift_mean = shift_mean
        self.shift_std = shift_std
        self.scaling_mean = scaling_mean
        self.scaling_std = scaling_std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(x + shift) / scaling`` with freshly sampled scalars.

        Implemented as ``forward`` (not ``__call__``) so ``module(x)`` goes
        through ``nn.Module.__call__`` and registered hooks run.
        """
        device = x.device
        shift = torch.normal(mean=self.shift_mean, std=self.shift_std, size=[1], device=device)
        scaling = torch.normal(mean=self.scaling_mean, std=self.scaling_std, size=[1], device=device)
        if x.is_floating_point():
            # The samples are 1-element (not 0-dim) tensors, so without this
            # cast they would promote e.g. a float16 input to float32.
            # Sampling stays in the default dtype so the drawn values are
            # unchanged for float32/float64 inputs.
            shift = shift.to(x.dtype)
            scaling = scaling.to(x.dtype)
        return (x + shift) / scaling

    def extra_repr(self) -> str:
        # Shows the sampling configuration in print(module).
        return (f"shift_mean={self.shift_mean}, shift_std={self.shift_std}, "
                f"scaling_mean={self.scaling_mean}, scaling_std={self.scaling_std}")

