
import jax
from typing import Any, Callable
from modules.architectures.mlp import mlp_module

def make_critic(distance_config: dict[str, Any]) -> tuple[Callable[..., list[jax.Array]], Callable[..., jax.Array]]:
    """Build the critic network used by the WGAN distance.

    The critic is an MLP; ``distance_config`` is handed to ``mlp_module``
    unchanged and its result is returned as-is.

    Args:
        distance_config: MLP configuration forwarded to ``mlp_module``.

    Returns:
        Whatever ``mlp_module`` returns -- annotated here as an
        ``(init, apply)`` pair of callables.
    """
    critic_fns = mlp_module(distance_config)
    return critic_fns

def make_wgan_distance(distance_config: dict[str, Any]) -> dict[str, Any]:
    """Build a learned, critic-based (WGAN-style) distance between sample sets.

    The critic is created with ``make_critic(distance_config)``. The returned
    dict exposes:

    - ``"has_params"``: ``True`` -- the distance carries trainable critic weights.
    - ``"init"``: ``init(rng, dtype) -> list[jax.Array]`` -- critic parameters.
    - ``"apply"``: ``apply(ω, z_pred, z_true) -> scalar`` -- critic estimate,
      with ``z_pred`` / ``z_true`` nominally shaped [B, T, d].

    Args:
        distance_config: Configuration forwarded unchanged to ``make_critic``.

    Returns:
        A dict with the ``"has_params"``, ``"init"`` and ``"apply"`` entries.
    """
    init_critic, apply_critic = make_critic(distance_config)

    def init(rng: jax.Array, dtype: jax.numpy.dtype) -> list[jax.Array]:
        """Initialise critic parameters ω by delegating to the critic's initialiser."""
        return init_critic(rng, dtype)

    def _ensure_feature_axis(z: jax.Array) -> jax.Array:
        # Any operand whose rank is not 3 gets a trailing singleton axis, so a
        # [B, T] input is treated as [B, T, 1]. Doing this per operand (rather
        # than expanding both whenever either is not rank 3) keeps a [B, T]
        # input and an equivalent [B, T, 1] input compatible.
        # NOTE(review): a rank-4+ input also gains an extra axis here, matching
        # earlier behaviour -- confirm whether such inputs should be rejected.
        return z if z.ndim == 3 else z[..., None]

    def distance_apply(ω: Any, z_pred: jax.Array, z_true: jax.Array) -> jax.Array:
        """Evaluate the Kantorovich-Rubinstein dual objective for critic ω.

            W_1(ν, ν_hat) = sup_{‖φ‖_L ≤ 1} { E_{z~ν}[φ(z)] - E_{z_hat~ν_hat}[φ(z_hat)] }

        Here φ = ``apply_critic(ω, ·)``, ν is the distribution of ``z_true``
        and ν_hat that of ``z_pred``. This function only evaluates the
        bracketed expression for the given ω; the supremum (critic training)
        and any Lipschitz constraint on φ are not handled here.

        Args:
            ω: Critic parameters, as produced by ``init``.
            z_pred: Predicted samples, nominally [B, T, d] ([B, T] is accepted).
            z_true: Reference samples, same shape as ``z_pred``.

        Returns:
            A scalar array: mean critic score on ``z_true`` minus mean critic
            score on ``z_pred``.

        Raises:
            ValueError: If the inputs differ in shape after rank normalisation.
        """
        # Shapes/ranks are static under jit, so Python branching on them is safe.
        z_pred = _ensure_feature_axis(z_pred)
        z_true = _ensure_feature_axis(z_true)

        if z_pred.shape != z_true.shape:
            raise ValueError(
                "z_pred and z_true must have identical shapes. "
                f"Got z_pred.shape={z_pred.shape}, z_true.shape={z_true.shape}."
            )

        φ_compose_z = apply_critic(ω, z_true)
        φ_compose_z_hat = apply_critic(ω, z_pred)

        # Average critic scores over axis 1 per leading index, take the
        # difference, then average what remains to a scalar.
        # Assumes apply_critic keeps axis 1 (time) in its output -- TODO confirm
        # against mlp_module.
        dists = jax.numpy.mean(φ_compose_z, axis=1) - jax.numpy.mean(φ_compose_z_hat, axis=1)
        return jax.numpy.mean(dists)

    return {
        "has_params": True,
        "init": init,
        "apply": distance_apply,
    }