import jax
from typing import Callable
import jax.numpy as jnp

def make_mse_objective(
    g: Callable[..., jax.Array],
) -> Callable[[jax.Array, jax.Array], jax.Array]:
    """Wrap the emulator ``g`` in a mean-squared-error objective.

    Args:
        g: emulator, invoked as ``g(θ, batch)``; its output is compared
            element-wise against ``batch`` itself.

    Returns:
        A function ``mse_objective(θ, batch)`` returning the scalar mean of
        the squared element-wise difference between ``g(θ, batch)`` and
        ``batch``.
    """

    def mse_objective(θ: jax.Array, batch: jax.Array) -> jax.Array:
        """
        Args:
            θ: emulator parameters
            batch: ground-truth trajectories [B, T, d]

        Returns:
            MSE loss between predicted and true trajectories
        """
        prediction = g(θ, batch)
        # Mean over every element of the (prediction - target) residual.
        return jnp.mean((prediction - batch) ** 2)

    return mse_objective

def make_ot_objective(
    g: Callable[..., jax.Array],
    f: Callable[..., jax.Array],
    D: Callable[..., jax.Array],
    ot_horizon: int | None = None,
) -> Callable[[jax.Array, jax.Array, jax.Array, jax.Array], jax.Array]:

    def ot_objective(
        θ: jax.Array,
        ψ: jax.Array,
        ω: jax.Array,
        batch: jax.Array,
    ) -> jax.Array:
        """
        Args:
            θ: emulator parameters
            ψ: summary parameters
            ω: critic parameters
            batch: ground-truth trajectories [B, T, d]

        Returns:
            OT loss between predicted and true trajectories
        """
        T = batch.shape[1]
        H = T if ot_horizon is None else min(ot_horizon, T)
        
        u = batch[:, :H, :]
        u_hat = g(θ, u)

        z_hat = f(ψ, u_hat)
        z = f(ψ, u)

        return D(ω, z_hat, z)

    return ot_objective
