import horovod.torch as hvd
import torch

# Compare a plain BatchNorm3d against Horovod's SyncBatchNorm on the same
# random input and print the summed element-wise difference of the outputs.
hvd.init()

# Pin this process to the GPU matching its local rank so that several ranks
# on one node do not all allocate on the default device (cuda:0).
torch.cuda.set_device(hvd.local_rank())

num_features = 100
batchnorm = torch.nn.BatchNorm3d(num_features).cuda()
sync_batchnorm = hvd.SyncBatchNorm(num_features).cuda()

# 5-D input of shape (20, 100, 35, 45, 10); dimension 1 matches num_features.
# Named `sample` rather than `input` to avoid shadowing the builtin.
sample = torch.randn(20, num_features, 35, 45, 10).cuda()

output = batchnorm(sample)
sync_output = sync_batchnorm(sample)

# NOTE: this is a signed sum, so differences of opposite sign can cancel;
# a value near zero does not by itself prove the outputs are element-wise equal.
diff = output - sync_output
print(diff.sum())