from __future__ import annotations
import jax
import jax.numpy as jnp
from typing import Any
from ott.geometry import pointcloud, costs
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

def make_sinkhorn_distance(distance_config: dict):
    """Build a batched Sinkhorn-divergence distance between point-cloud sequences.

    Each batch element is treated as a uniform empirical measure over its T
    points. The distance is

        S_eps(x, y) - 1/2 S_eps(x, x) - 1/2 S_eps(y, y),

    averaged over the batch. Here S_eps is the regularized OT cost
    (``reg_ot_cost``) returned by OTT's log-domain Sinkhorn solver with a
    squared-Euclidean ground cost.

    Args:
        distance_config: Mapping with these keys:
            - ``"epsilon"``, ``"max_iters"`` and ``"threshold"``: required.
            - ``"normalize_embeddings"``: optional, default False.
            - ``"normalize_eps"``: optional, default 1e-6.
            - ``"stop_gradient_self_terms"``: optional, default True.
              When True, gradients flow only through the cross term
              S_eps(z_pred, z_true). When False, gradients also flow
              through the two self terms.

    Returns:
        A dict with these entries:
            - ``"nas_params"``: False.
            - ``"init"``: returns None, because there are no learnable
              parameters.
            - ``"apply"``: the distance function
              ``apply(ω, z_pred, z_true) -> scalar``.
    """
    epsilon = float(distance_config["epsilon"])
    max_iters = int(distance_config["max_iters"])
    threshold = float(distance_config["threshold"])
    normalize_embeddings = bool(distance_config.get("normalize_embeddings", False))
    normalize_eps = float(distance_config.get("normalize_eps", 1e-6))
    # Defaults to True to preserve the historical behavior: the self terms
    # affect the loss value but not its gradient. Set to False for gradients
    # of the full debiased divergence.
    stop_gradient_self_terms = bool(
        distance_config.get("stop_gradient_self_terms", True)
    )

    solver = sinkhorn.Sinkhorn(
        lse_mode=True,
        threshold=threshold,
        max_iterations=max_iters,
    )

    def init(rng: jax.Array, dtype: jnp.dtype) -> None:
        """This distance has no parameters; always returns None."""
        return None

    def _uniform_weights(n: int, dtype) -> jnp.ndarray:
        """Return a length-n probability vector with equal mass 1/n."""
        return jnp.full((n,), 1.0 / n, dtype=dtype)

    def _ot_cost(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Regularized OT cost between point clouds x: [Tx, d] and y: [Ty, d]."""
        # Each cloud gets its own uniform marginal, sized by its own length.
        a = _uniform_weights(x.shape[0], x.dtype)
        b = _uniform_weights(y.shape[0], x.dtype)

        geom = pointcloud.PointCloud(
            x,
            y,
            epsilon=epsilon,
            cost_fn=costs.SqEuclidean(),
        )
        prob = linear_problem.LinearProblem(geom, a=a, b=b)
        return solver(prob).reg_ot_cost

    # Map the per-sample OT cost over the leading batch axis of both inputs.
    ot_batch = jax.vmap(_ot_cost, in_axes=(0, 0), out_axes=0)

    def _normalize(z: jnp.ndarray) -> jnp.ndarray:
        """Standardize each batch element over all of its (T, d) entries."""
        mean = jnp.mean(z, axis=(1, 2), keepdims=True)
        std = jnp.std(z, axis=(1, 2), keepdims=True)
        return (z - mean) / (std + normalize_eps)

    def _as_point_clouds(z: jnp.ndarray, name: str) -> jnp.ndarray:
        """Promote [B, T] to [B, T, 1], pass [B, T, d] through, reject other ranks."""
        if z.ndim == 2:
            return z[..., None]
        if z.ndim == 3:
            return z
        raise ValueError(
            f"{name} must have shape [B, T] or [B, T, d]; got shape {z.shape}."
        )

    def distance_apply(ω: Any, z_pred: jnp.ndarray, z_true: jnp.ndarray) -> jnp.ndarray:
        """Batch-mean Sinkhorn divergence.

        Computes S_eps(z_pred, z_true) - 1/2 S_eps(z_pred, z_pred)
        - 1/2 S_eps(z_true, z_true) for each batch element, then averages.

        Args:
            ω: Unused. Accepted only for interface compatibility with
                parameterized distances.
            z_pred: Array of shape [B, T] or [B, T, d].
            z_true: Array of shape [B, T] or [B, T, d]. After 2-D inputs
                are promoted, it must have the same shape as z_pred.

        Returns:
            A scalar: the mean divergence over the batch.

        Raises:
            ValueError: If either input is not 2-D or 3-D, or if the shapes
                differ after promotion.
        """
        # Promote each input on its own, so that a 3-D input is never
        # expanded merely because the other input is 2-D.
        z_pred = _as_point_clouds(z_pred, "z_pred")
        z_true = _as_point_clouds(z_true, "z_true")

        if z_pred.shape != z_true.shape:
            raise ValueError("z_pred and z_true must have identical shapes.")

        if normalize_embeddings:
            z_pred = _normalize(z_pred)
            z_true = _normalize(z_true)

        xy = ot_batch(z_pred, z_true)
        xx = ot_batch(z_pred, z_pred)
        yy = ot_batch(z_true, z_true)
        if stop_gradient_self_terms:
            xx = jax.lax.stop_gradient(xx)
            yy = jax.lax.stop_gradient(yy)

        return jnp.mean(xy - 0.5 * xx - 0.5 * yy)

    return {
        "nas_params": False,
        "init": init,
        "apply": distance_apply,
    }
