import numpy as np
from modules.distance.wgan_distance import make_wgan_distance
from modules.distance.sinkhorn_distance import make_sinkhorn_distance

def _make_no_ot_distance(x):
    """Build the placeholder distance used when no OT term is wanted.

    The argument ``x`` is accepted but never read. The returned dict has
    two callables under the keys ``"init"`` and ``"apply"``.
    """

    def _init(rng, dtype):
        # Nothing to initialise for the placeholder distance.
        return None

    def _apply(ω, z_pred, z_true):
        # Constant zero, independent of every argument.
        return np.array(0.0)

    return {"init": _init, "apply": _apply}


# Maps a distance name to the factory that builds it.
# NOTE(review): the 'sinkhorn' and 'wgan' factories are defined in other
# modules; presumably they return the same {"init", "apply"} dict shape and
# take the same single argument as the 'no_ot' factory — confirm against
# modules.distance.
DISTANCE_REGISTRY = {
    'no_ot': _make_no_ot_distance,
    'sinkhorn': make_sinkhorn_distance,
    'wgan': make_wgan_distance,
}