from mad_td.utils.jax import torch_he_uniform

# Call the imported helper once, at import time, passing None as its only
# positional argument. Whatever it returns is deliberately not kept.
# NOTE(review): presumably a smoke check that torch_he_uniform accepts None
# (e.g. as a default/unspecified argument) without raising — confirm intent.
torch_he_uniform(
    None,
)
