import torch

def reparam(model, prev_layer=None):
    """
    Reparametrization of a network with batch normalization so that it calculates the same function as the
    original network but without relying on batch normalization.

    Every foldable BatchNorm1d/2d/3d layer is merged into the most recent Linear/Conv1d/Conv2d/Conv3d layer.
    Instead of removing the batch-norm module, it is reset to an exact identity (in eval mode):
    - its affine bias and running mean are set to zero;
    - its affine weight is set to one;
    - its running variance is set to ``1 - eps``, so that ``sqrt(running_var + eps) == 1``.

    The model is modified in place. If the layer being folded into has no bias, a zero bias parameter is
    created on it first.

    Batch-norm layers are left untouched (the function is still preserved) in two cases:
    - no Linear/Conv layer was visited before them;
    - they have no running statistics (``track_running_stats=False``).

    NOTE(review): the "preceding layer" is determined by the depth-first order of ``model.children()``.
    Folding is only correct if that order matches the forward data flow and nothing non-linear sits between
    the two layers — confirm for the architecture being reparametrized.

    Args:
        model: input model to be reparameterized
        prev_layer: recursion helper, the last Linear/Conv layer visited before the children of ``model``

    Returns:
        The last Linear/Conv layer visited, or ``prev_layer`` if ``model`` contains none.
    """
    linear_types = (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)
    norm_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)
    for child in model.children():
        prev_layer = reparam(child, prev_layer)
        if isinstance(child, linear_types):
            prev_layer = child
        elif isinstance(child, norm_types):
            _fold_batchnorm(child, prev_layer)
    return prev_layer


def _fold_batchnorm(bn, layer):
    """Fold the eval-mode affine map of batch norm ``bn`` into ``layer`` and reset ``bn`` to an identity."""
    if layer is None or bn.running_mean is None or bn.running_var is None:
        # Nothing to fold into, or the layer always normalizes with batch statistics;
        # leaving it untouched keeps the network's function unchanged.
        return
    with torch.no_grad():
        std = (bn.running_var + bn.eps).sqrt()
        # affine=False batch norm has no weight/bias: treat them as 1 and 0.
        gamma = bn.weight if bn.weight is not None else torch.ones_like(std)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(std)
        scale = gamma / std

        if layer.bias is None:
            # Conv/Linear layers followed by BN are often built with bias=False;
            # the folded shift needs somewhere to live.
            layer.bias = torch.nn.Parameter(
                torch.zeros(layer.weight.shape[0], dtype=layer.weight.dtype, device=layer.weight.device),
                requires_grad=layer.weight.requires_grad,
            )
        layer.bias.copy_(beta + scale * (layer.bias - bn.running_mean))
        # Dim 0 of Linear/Conv weights is the output channel: scale each output channel by scale[i].
        layer.weight.mul_(scale.reshape(-1, *([1] * (layer.weight.dim() - 1))))

        if bn.weight is not None:
            bn.weight.fill_(1)
        if bn.bias is not None:
            bn.bias.fill_(0)
        bn.running_mean.fill_(0)
        # BN divides by sqrt(running_var + eps); choosing 1 - eps makes that divisor exactly 1.
        bn.running_var.fill_(1 - bn.eps)