from probjax.distributions.divergences.kl import kl_divergence
from probjax.distributions.divergences.wasserstein import wasserstein_distance
