import numpy as np
import torch
import ot
from geomloss import SamplesLoss

def _validate_histogram(w: torch.Tensor, name: str, atol: float = 1e-12) -> torch.Tensor:
    """
    Flatten a weight tensor, sanity-check it, and rescale it to unit mass.

    Parameters
    ----------
    w : torch.Tensor
        Weights of any shape; they are flattened to 1-D first.
    name : str
        Label inserted into the error messages.
    atol : float
        Entries in [-atol, 0) are tolerated and clipped to zero.

    Returns
    -------
    torch.Tensor
        1-D tensor of clipped weights divided by their sum. No device
        transfer is performed.

    Raises
    ------
    ValueError
        If any entry is below ``-atol``, or if the sum of the raw entries is
        non-finite or not strictly positive.
    """
    flat = w.reshape(-1)

    if (flat < -atol).any():
        raise ValueError(f"{name} has negative entries (below -{atol}).")

    # Positivity is checked on the raw sum; clipping below can only increase
    # it, so the final division is safe.
    total = flat.sum()
    if not torch.isfinite(total) or total <= 0:
        raise ValueError(f"{name} sum must be positive and finite; got {total.item()}.")

    clipped = flat.clamp(min=0.0)
    return clipped / clipped.sum()

@torch.no_grad()
def sinkhorn_divergence_same_atoms(
    a: torch.Tensor,
    b: torch.Tensor,
    X: torch.Tensor,
    eps: float = 0.05,
    scaling: float = 0.9,
    atol: float = 1e-12,
    debias: bool = False,
) -> dict:
    """
    Compute a Sinkhorn divergence and dual potentials with GeomLoss for two
    weight vectors living on the same atoms.

    Parameters
    ----------
    a, b : (n,) torch.Tensor
        Nonnegative weights (not necessarily normalized). They are moved to
        the device/dtype of ``X`` and normalized to sum to 1.
    X : (n, d) torch.Tensor
        Common support atoms (same for both measures).
    eps : float
        Entropic regularization ε, must be > 0. It is passed to GeomLoss as
        ``blur = sqrt(eps)`` together with ``p=2``.
        NOTE(review): GeomLoss documents its p=2 ground cost as
        ||x - y||^2 / 2; confirm this matches the cost used elsewhere.
    scaling : float
        GeomLoss ``scaling`` parameter (ε-annealing ratio).
    atol : float
        Tolerance of the negativity check on ``a`` and ``b``.
    debias : bool
        Forwarded unchanged to ``SamplesLoss(debias=...)``. The combination
        ``S = OT_ab - OT_aa/2 - OT_bb/2`` below is applied in either case, so
        the ``OT_*`` entries are plain entropic OT values only when
        ``debias=False`` (the default).

    Returns
    -------
    dict with keys:
      - "S": ``OT_ab - 0.5 * OT_aa - 0.5 * OT_bb`` (Python float)
      - "OT_ab", "OT_aa", "OT_bb": values returned by GeomLoss for each pair
        (Python floats)
      - "potentials": {"ab": (f_ab, g_ab), "aa": (f_aa, g_aa), "bb": (f_bb, g_bb)}
        the pairs returned by ``SamplesLoss(potentials=True)``

    Raises
    ------
    ValueError
        If ``eps <= 0``, if ``X`` is not 2-D, if the weights are invalid, or
        if their length differs from ``X.shape[0]``.
    """
    if eps <= 0:
        raise ValueError("eps must be > 0.")
    # Fail early with a clear message instead of deep inside GeomLoss.
    if X.ndim != 2:
        raise ValueError(f"X must have shape (n, d); got {tuple(X.shape)}.")

    # Device/dtype alignment: X is the reference.
    device = X.device
    dtype = X.dtype

    a = a.to(device=device, dtype=dtype)
    b = b.to(device=device, dtype=dtype)

    # Normalize weights
    a = _validate_histogram(a, "a", atol=atol)
    b = _validate_histogram(b, "b", atol=atol)

    n = X.shape[0]
    if a.numel() != n or b.numel() != n:
        raise ValueError(f"a,b must have length {n}; got {a.numel()}, {b.numel()}.")

    # Map regularization ε -> GeomLoss blur σ (ε = σ^p with p=2).
    blur = float(eps) ** 0.5

    # Two layers with identical settings: one returns the scalar value, the
    # other returns the dual potentials.
    loss_val = SamplesLoss(loss="sinkhorn", p=2, blur=blur, scaling=scaling, debias=debias, potentials=False)
    loss_pot = SamplesLoss(loss="sinkhorn", p=2, blur=blur, scaling=scaling, debias=debias, potentials=True)

    def ot_and_pot(alpha, beta):
        # Scalar value (entropic OT, or the debiased value if debias=True).
        val = loss_val(alpha, X, beta, X)
        # Dual potentials, returned as a pair of tensors.
        f, g = loss_pot(alpha, X, beta, X)
        return val, f, g

    OT_ab, f_ab, g_ab = ot_and_pot(a, b)
    OT_aa, f_aa, g_aa = ot_and_pot(a, a)
    OT_bb, f_bb, g_bb = ot_and_pot(b, b)

    S = OT_ab - 0.5 * OT_aa - 0.5 * OT_bb

    return {
        "S": S.item(),
        "OT_ab": OT_ab.item(),
        "OT_aa": OT_aa.item(),
        "OT_bb": OT_bb.item(),
        "potentials": {
            "ab": (f_ab, g_ab),
            "aa": (f_aa, g_aa),
            "bb": (f_bb, g_bb),
        },
    }

@torch.no_grad()
def cost_matrix_on_atoms(
    X: torch.Tensor,
    metric: str = "sqeuclidean",
    normalize: bool = False,
    eps: float = 1e-12,
) -> torch.Tensor:
    """
    Compute the (n, n) ground cost matrix between the rows of X.

    Parameters
    ----------
    X : (n, d) torch.Tensor
        Atom locations.
    metric : {"sqeuclidean", "euclidean"}
        Ground cost. "sqeuclidean" returns the HALF squared distance
        ``||xi - xj||^2 / 2``; "euclidean" returns ``||xi - xj||``.
    normalize : bool
        If True, divide C by the median of its entries that exceed ``eps``.
    eps : float
        Threshold above which an entry counts as positive for the median.

    Returns
    -------
    C : (n, n) torch.Tensor
        Cost matrix with an exactly zero diagonal.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    """
    X = X.contiguous()

    if metric == "sqeuclidean":
        # ||xi - xj||^2 = ||xi||^2 + ||xj||^2 - 2 xi·xj
        sq_norms = (X * X).sum(dim=1, keepdim=True)              # (n, 1)
        C = sq_norms + sq_norms.t() - 2.0 * (X @ X.t())          # (n, n)
        C = C.clamp_min_(0.0)                                    # round-off can go negative
        C = C / 2

    elif metric == "euclidean":
        C = torch.cdist(X, X, p=2)

    else:
        raise ValueError(f"Unsupported metric='{metric}'. Use 'sqeuclidean' or 'euclidean'.")

    # The expansion above (and matmul-based cdist) can leave small positive
    # round-off on the diagonal. A point is at zero cost from itself, and the
    # median normalization below relies on the diagonal being excluded.
    C.fill_diagonal_(0.0)

    if normalize:
        # Median of the positive entries only.
        positive = C[C > eps]
        if positive.numel() > 0:
            med = positive.median()
            # avoid divide-by-zero
            if med > 0:
                C = C / med

    return C


def _sym(A: torch.Tensor) -> torch.Tensor:
    """Return the symmetric part (A + A^T) / 2, transposing the last two dims."""
    transposed = A.transpose(-2, -1)
    return (A + transposed) * 0.5

def _sqrtm_psd(A: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Matrix square root of a symmetric PSD matrix via eigen-decomposition.

    The input is symmetrized first; eigenvalues below ``eps`` are raised to
    ``eps`` before the square root is taken. Returns V diag(sqrt(w)) V^T.
    """
    eigvals, eigvecs = torch.linalg.eigh(_sym(A))
    root_vals = eigvals.clamp(min=eps).sqrt()
    # Scale column k of V by sqrt(w_k), then multiply by V^T.
    scaled_vecs = eigvecs * root_vals.unsqueeze(-2)
    return scaled_vecs @ eigvecs.transpose(-1, -2)

def _logdet_psd(A: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Log-determinant of a symmetric PD/PSD matrix from its eigenvalues.

    The input is symmetrized first; eigenvalues below ``eps`` are raised to
    ``eps`` so the logarithm stays finite.

    Returns a 0-dim tensor for a single (d, d) matrix, and one value per
    matrix (shape ``A.shape[:-2]``) for batched input.
    """
    A = _sym(A)
    w = torch.linalg.eigvalsh(A)
    w = torch.clamp(w, min=eps)
    # Reduce over the eigenvalue axis only, so a batch of matrices is not
    # collapsed into a single number.
    return torch.log(w).sum(dim=-1)

def entropic_ot_gaussian_closed_form(
    m0: torch.Tensor, C0: torch.Tensor,
    m1: torch.Tensor, C1: torch.Tensor,
    eps: float,
    eps_eig: float = 1e-12
) -> torch.Tensor:
    """
    Closed-form entropic OT value between the Gaussians N(m0, C0) and N(m1, C1).

    Evaluates

        0.5 * ( ||m0 - m1||^2 + tr(C0) + tr(C1) - tr(D)
                + d * eps * (1 - log(2 * eps)) + eps * logdet(D + eps * I) )

    with ``D = (4 C0^{1/2} C1 C0^{1/2} + eps^2 I)^{1/2}``.

    The bracket has the form of Janati et al. (NeurIPS 2020), Thm 1, with
    their sigma^2 set to ``eps``; the leading 1/2 corresponds to using the
    ground cost ||x - y||^2 / 2. NOTE(review): confirm the parameterization
    against the paper before changing any constant.

    Parameters
    ----------
    m0, m1 : (d,) torch.Tensor
        Means.
    C0, C1 : (d, d) torch.Tensor
        Symmetric PSD/PD covariances (they are symmetrized internally).
    eps : float or 0-dim torch.Tensor
        Regularization strength, must be > 0.
    eps_eig : float
        Eigenvalue floor used by the matrix square root and log-determinant.

    Returns
    -------
    torch.Tensor
        Scalar tensor on the device/dtype of ``C0``.

    Raises
    ------
    ValueError
        If ``eps`` is not strictly positive (the formula contains log(2*eps)).
    """
    device = C0.device
    dtype = C0.dtype

    eps_t = eps if torch.is_tensor(eps) else torch.tensor(eps, device=device, dtype=dtype)
    # A non-positive (or NaN) eps would otherwise propagate silently as NaN.
    if not torch.all(eps_t > 0):
        raise ValueError("eps must be > 0.")

    d = C0.shape[-1]
    I = torch.eye(d, device=device, dtype=dtype)

    C0 = _sym(C0)
    C1 = _sym(C1)

    # D = ( 4 C0^{1/2} C1 C0^{1/2} + eps^2 I )^{1/2}
    C0_sqrt = _sqrtm_psd(C0, eps=eps_eig)
    inner = 4.0 * (C0_sqrt @ C1 @ C0_sqrt) + (eps_t**2) * I
    D = _sqrtm_psd(inner, eps=eps_eig)

    mean_term = torch.sum((m0 - m1) ** 2)

    cov_term = (
        torch.trace(C0)
        + torch.trace(C1)
        - torch.trace(D)
        + d * eps_t * (1.0 - torch.log(2 * eps_t))
        + eps_t * _logdet_psd(D + eps_t * I, eps=eps_eig)
    )

    return (mean_term + cov_term) / 2.0

def sinkhorn_divergence_gaussians(
    m1: torch.Tensor, C1: torch.Tensor,
    m2: torch.Tensor, C2: torch.Tensor,
    eps: torch.Tensor | float,
    eps_eig: float = 1e-12
) -> torch.Tensor:
    
    """
    S_σ = OT(0,1) - 1/2 OT(0,0) - 1/2 OT(1,1)
    """
    device = m1.device
    dtype = m1.dtype

    eps_t = eps if torch.is_tensor(eps) else torch.tensor(eps, device=device, dtype=dtype)
    if eps_t.ndim != 0:
        eps_t = eps_t.squeeze()
        if eps_t.ndim != 0:
            raise ValueError("eps must be a scalar (float or 0-dim tensor).")
        
    OT12 = entropic_ot_gaussian_closed_form(m1, C1, m2, C2, eps_t, eps_eig)
    OT11 = entropic_ot_gaussian_closed_form(m1, C1, m1, C1, eps_t, eps_eig)
    OT22 = entropic_ot_gaussian_closed_form(m2, C2, m2, C2, eps_t, eps_eig)
    G_true_t = OT12 - 0.5 * OT11 - 0.5 * OT22
    return float(G_true_t.item()) if hasattr(G_true_t, "item") else float(G_true_t)
     
def median_pairwise(C: torch.Tensor) -> float:
    """
    Median of ``sqrt(2 * C[i, j])`` over the strict upper triangle (i < j).

    When ``C[i, j]`` holds the half squared distance ``||xi - xj||^2 / 2``
    this is the median pairwise Euclidean distance.
    """
    size = C.shape[0]
    upper = torch.triu_indices(row=size, col=size, offset=1, device=C.device)
    rows, cols = upper[0], upper[1]
    distances = (2 * C[rows, cols]).sqrt()
    return float(distances.median().item())


@torch.no_grad()
def stable_hadamard_M_from_density_Q(
    Q: torch.Tensor,           # (n,n) density wrt mu⊗mu: Pi_ij = w_i w_j Q_ij
    w: torch.Tensor,           # (n,) weights, strictly positive
    tau: float = 1e-8,
    ridge: float = 0.0,
    symmetrize_Q: bool = True,
) -> torch.Tensor:
    """
    Compute M = (I - K_mu^2)^† (P H) through a symmetric eigen-decomposition.

    Convention:
      Pi_ij = w_i w_j Q_ij, where Q_ij = exp((f_i + f_j - c_ij)/eps).

    Then, with D = diag(w):
      H = D^{-1} Pi D^{-1} = Q
      K_mu = D^{-1} Pi      = Q D   (row-stochastic if Pi has marginals w)

    K_mu is similar to the symmetric matrix
      S = D^{1/2} K_mu D^{-1/2} = D^{1/2} Q D^{1/2},
    so with S = U diag(lam) U^T the result is
      M = D^{-1/2} U diag(r) U^T D^{1/2} Hc,   r_k = 1 / (1 - lam_k^2 + ridge),
    where Hc is Q minus its column-wise mean (mean over dim 0).

    Parameters
    ----------
    Q : (n, n) torch.Tensor
        Density matrix.
    w : (n,) torch.Tensor
        Strictly positive weights; they are normalized to sum to 1.
    tau : float
        Modes whose denominator ``1 - lam^2 + ridge`` is below ``tau`` are
        dropped (r = 0), which yields the pseudo-inverse.
    ridge : float
        Constant added to every denominator.
    symmetrize_Q : bool
        If True, replace Q by (Q + Q^T) / 2 before anything else.

    Returns
    -------
    M : (n, n) torch.Tensor

    Raises
    ------
    ValueError
        If Q is not square, w has the wrong shape, or any weight is not
        strictly positive.
    """
    if Q.ndim != 2 or Q.shape[0] != Q.shape[1]:
        raise ValueError("Q must be square (n,n).")
    n = Q.shape[0]
    if w.ndim != 1 or w.shape[0] != n:
        raise ValueError("w must have shape (n,).")
    # Zero weights would produce 1/sqrt(0) = inf below and NaN rows in M, and
    # NaN weights would slip through a `< 0` test; require w > 0 entry-wise.
    if not torch.all(w > 0):
        raise ValueError("weights must be positive.")

    device = Q.device
    w = w.to(device=device, dtype=Q.dtype)

    w = w / w.sum()

    if symmetrize_Q:
        Q = 0.5 * (Q + Q.T)

    # H = Q; subtract the (unweighted) mean of each column.
    Hc = Q - torch.mean(Q, dim=0, keepdim=True)

    # Build S = D^{1/2} Q D^{1/2}
    sqrtw = torch.sqrt(w)
    inv_sqrtw = 1.0 / sqrtw
    S = (sqrtw[:, None] * Q) * sqrtw[None, :]  # symmetric

    lam, U = torch.linalg.eigh(S)

    # Spectral filter: invert 1 - lam^2 (+ ridge) only where it is safely
    # away from zero; the remaining modes get r = 0.
    denom = 1.0 - lam**2 + ridge
    r = torch.zeros_like(denom)
    mask = denom >= tau
    r[mask] = 1.0 / denom[mask]

    # Apply M = D^{-1/2} U diag(r) U^T D^{1/2} Hc
    Y = sqrtw[:, None] * Hc
    Yh = U.T @ Y
    Zh = r[:, None] * Yh
    Z = U @ Zh
    M = inv_sqrtw[:, None] * Z
    return M