import torch


def change_bn_momentum(m, momentum=0.001, *, bn_types=(torch.nn.BatchNorm2d,)):
    """Set ``momentum`` on every batch-norm layer inside ``m``, in place.

    Args:
        m: Root ``torch.nn.Module`` to walk. The root itself is included, so
            passing a single batch-norm layer updates that layer.
        momentum: Value assigned to each matching layer's ``momentum``
            attribute.
        bn_types: Class or tuple of classes (as accepted by ``isinstance``)
            that select which layers are updated. Defaults to
            ``torch.nn.BatchNorm2d`` only. Pass e.g.
            ``(torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
            torch.nn.SyncBatchNorm)`` to cover other variants.

    Returns:
        None. ``m`` is modified in place.
    """
    # Module.modules() yields ``m`` itself plus every descendant, iteratively
    # and without duplicates. The previous child-only recursion never looked
    # at the root module.
    for module in m.modules():
        if isinstance(module, bn_types):
            module.momentum = momentum
