"""
Sinkhorn divergence wrapper using GeomLoss for debiased Sinkhorn.

Feature standardization is applied before distance computation.
"""
from __future__ import annotations

from typing import Tuple

import numpy as np

import torch
from geomloss import SamplesLoss


def _standardize_pair(X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Z-score two sample sets using statistics pooled over both of them.

    The rows of ``X`` and ``Y`` are stacked, and the per-column mean and
    (population) standard deviation of the stacked array are used to
    standardize each input. Columns whose pooled standard deviation is
    exactly zero are divided by 1 instead, so they come out as zeros
    rather than NaN.

    Returns
    -------
    (Xs, Ys, mean, std)
        The standardized copies of ``X`` and ``Y``, followed by the pooled
        mean and the (zero-guarded) pooled standard deviation, both kept
        as single-row arrays so they broadcast against the inputs.
    """
    pooled = np.vstack((X, Y))
    center = np.mean(pooled, axis=0, keepdims=True)
    scale = np.std(pooled, axis=0, keepdims=True)

    # Constant features have no spread; use a unit divisor for them.
    constant = scale == 0.0
    scale[constant] = 1.0

    return (X - center) / scale, (Y - center) / scale, center, scale


def sinkhorn_divergence(
    X: np.ndarray,
    Y: np.ndarray,
    epsilon: float = 0.05,
    n_iters: int = 300,
    p: int = 2,
    debiased: bool = True,
    device: str = "cpu",
    seed: int | None = None,
) -> float:
    """Compute the (debiased) Sinkhorn divergence between two point clouds.

    Both inputs are standardized with statistics pooled over ``X`` and ``Y``
    before the loss is evaluated with GeomLoss' ``SamplesLoss("sinkhorn")``
    in double precision.

    Parameters
    ----------
    X, Y : np.ndarray with shapes (n, d), (m, d)
        Non-empty 2-D sample arrays sharing the same number of features.
    epsilon : float
        Passed to GeomLoss as ``blur`` and must be strictly positive.
        Per the GeomLoss documentation the entropic temperature is
        ``blur ** p``, so this value is on the scale of a length in the
        standardized feature space rather than the temperature itself.
    n_iters : int
        Accepted for API compatibility only; it is currently unused.
        GeomLoss derives its iteration schedule from ``blur`` and ``scaling``.
    p : int
        Ground cost exponent (default 2 for squared Euclidean).
    debiased : bool
        Use debiased Sinkhorn divergence (forwarded as ``debias``).
    device : str
        'cpu' (recommended for reproducibility) or 'cuda'.
    seed : Optional[int]
        When given, ``torch.manual_seed(seed)`` is called. Note that this
        reseeds the *global* torch RNG as a side effect.

    Returns
    -------
    float
        The loss value, clamped from below at 0.0.

    Raises
    ------
    ValueError
        If either input is not 2-D, the feature dimensions differ, either
        point cloud is empty, or ``epsilon`` is not strictly positive.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)

    # Validate with explicit exceptions: `assert` statements are stripped
    # under `python -O`, which would let malformed inputs through silently.
    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError(
            f"X and Y must be 2-D arrays of shape (n, d) and (m, d); "
            f"got ndim={X.ndim} and ndim={Y.ndim}"
        )
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"X and Y must have the same number of features; "
            f"got {X.shape[1]} and {Y.shape[1]}"
        )
    if X.shape[0] == 0 or Y.shape[0] == 0:
        raise ValueError("X and Y must each contain at least one sample")
    # `not epsilon > 0` also rejects NaN.
    if not epsilon > 0:
        raise ValueError(f"epsilon must be strictly positive; got {epsilon!r}")

    # Important: do NOT share normalization across OT and MMD implicitly.
    # Standardize here locally only for OT stability. MMD has its own normalization rules.
    Xs, Ys, _, _ = _standardize_pair(X, Y)

    if seed is not None:
        torch.manual_seed(seed)

    # Use double precision for stability
    dtype = torch.float64
    dev = torch.device(device)

    x = torch.tensor(Xs, dtype=dtype, device=dev)
    y = torch.tensor(Ys, dtype=dtype, device=dev)

    # Debiased Sinkhorn divergence is an OT proxy (Cuturi; Feydy et al.). It is distinct from MMD.
    loss = SamplesLoss(
        loss="sinkhorn",
        p=p,
        blur=epsilon,  # GeomLoss convention: temperature is blur ** p
        scaling=0.9,
        debias=debiased,
        backend="tensorized",
        reach=None,
    )
    # GeomLoss does not expose iterations directly; blur+scaling controls convergence.
    # `n_iters` is kept in the signature for API compatibility but is not used.
    val = loss(x, y).item()

    # Clamp small negative values to zero; a NaN result is passed through as-is.
    return float(max(val, 0.0))




