import torch
import numpy as np
import sys
sys.path.append('../')
from utils.MCS import MCS

@torch.no_grad()
def _default_eps(D: torch.Tensor):
    """
    Heuristic regularization strength derived from a cost matrix.

    Returns a 0-dim tensor on D's device and with D's dtype:
    - 5% of the median of the strictly positive entries of D, floored at the
      machine epsilon of D's dtype, or
    - 1.0 when D has no strictly positive entry.

    Runs with gradient tracking disabled.
    """
    positive_costs = D[D > 0]

    # Usual case first: at least one strictly positive cost to take a median of.
    if positive_costs.numel() > 0:
        floor = torch.finfo(D.dtype).eps
        return torch.clamp(0.05 * positive_costs.median(), min=floor)

    # No positive entry (all-zero / non-positive costs): fall back to a fixed scale.
    return torch.as_tensor(1.0, device=D.device, dtype=D.dtype)

def sinkhorn_distance(
    D: torch.Tensor,
    eps: float | torch.Tensor | None = None,
    a: torch.Tensor | None = None,
    b: torch.Tensor | None = None,
    max_iter: int = 10_000,
    tol: float = 1e-9,
    return_plan: bool = False,
    stabilized: bool = False,
):
    """
    Entropic OT ('Sinkhorn distance') between two discrete distributions with cost matrix D (n x m).
    If a or b is None, they are set to uniform. Returns <P, D>.
    - stabilized=False: classic scaling with K=exp(-D/eps) (fast, may underflow for tiny eps).
    - stabilized=True: log-domain updates for numerical stability (slower, more robust).
    """
    assert D.ndim == 2 and D.is_floating_point(), "D must be a 2D floating tensor"
    n, m = D.shape
    device, dtype = D.device, D.dtype

    if a is None:
        a = torch.full((n,), 1.0 / n, device=device, dtype=dtype)
    else:
        a = a.to(device=device, dtype=dtype)
        a = a / a.sum()

    if b is None:
        b = torch.full((m,), 1.0 / m, device=device, dtype=dtype)
    else:
        b = b.to(device=device, dtype=dtype)
        b = b / b.sum()

    eps = _default_eps(D) if eps is None else torch.as_tensor(eps, device=device, dtype=dtype)

    if not stabilized:
        # Classic Sinkhorn–Knopp
        K = torch.exp(-D / eps).clamp_min(torch.finfo(dtype).tiny)
        u = torch.ones(n, device=device, dtype=dtype)
        v = torch.ones(m, device=device, dtype=dtype)

        for _ in range(max_iter):
            u_prev, v_prev = u, v
            Kv = (K @ v).clamp_min(torch.finfo(dtype).tiny)
            u = a / Kv
            KTu = (K.t() @ u).clamp_min(torch.finfo(dtype).tiny)
            v = b / KTu
            if max((u - u_prev).abs().max().item(), (v - v_prev).abs().max().item()) < tol:
                break

        P = (u[:, None] * K) * v[None, :]
    else:
        # Log-domain stabilized Sinkhorn
        # Dual potentials f, g (same units as D), initialize at zeros
        f = torch.zeros(n, device=device, dtype=dtype)
        g = torch.zeros(m, device=device, dtype=dtype)
        log_a = torch.log(a.clamp_min(torch.finfo(dtype).tiny))
        log_b = torch.log(b.clamp_min(torch.finfo(dtype).tiny))

        inv_eps = 1.0 / eps
        for _ in range(max_iter):
            f_prev, g_prev = f, g
            # Update f: enforce row sums
            # logsumexp_j ( (g_j - D_ij)/eps )  then f_i += eps*(log a_i - that)
            f = f + eps * (log_a - torch.logsumexp((g[None, :] - D) * inv_eps, dim=1))
            # Update g: enforce col sums
            g = g + eps * (log_b - torch.logsumexp((f[:, None] - D) * inv_eps, dim=0))
            if max((f - f_prev).abs().max().item(), (g - g_prev).abs().max().item()) < tol:
                break

        P = torch.exp(((f[:, None] + g[None, :]) - D) * inv_eps)

    cost = torch.sum(P * D)
    return (cost, P) if return_plan else cost

def mcs_sinkhorn(X, Y, mcs: MCS):
    """
    Sinkhorn distance between X and Y under the pairwise cost given by `mcs`.

    The cost matrix is whatever `mcs.distance(X, Y)` returns; it is passed to
    `sinkhorn_distance` with all of that function's defaults (uniform marginals,
    automatically chosen eps, classic non-stabilized iterations, cost only).

    NOTE(review): assumes `mcs.distance` yields a 2D floating-point tensor, as
    `sinkhorn_distance` requires — confirm against the MCS implementation.
    """
    return sinkhorn_distance(mcs.distance(X, Y))


def rand_unif(N, d, device):
    """
    Draw N random points on the unit sphere in d dimensions.

    Each row is a standard normal sample rescaled to unit Euclidean length
    (the usual construction for sampling directions uniformly on the sphere).
    Returns a tensor of shape (N, d) on `device`.
    """
    gaussian = torch.randn(N, d, device=device)
    # Per-row lengths, kept as a column so the division broadcasts across each row.
    row_lengths = torch.norm(gaussian, dim=1, keepdim=True)
    return gaussian / row_lengths

def fibonacci_sphere(samples=1000):
    """
    Place `samples` points on the unit sphere along a golden-angle (Fibonacci) spiral.

    The y coordinate runs linearly from 1 (first point) to -1 (last point) and each
    successive point is rotated about the y axis by the golden angle.

    Returns a NumPy array of shape (samples, 3) whose columns are x, y, z.
    `samples == 1` yields the single point (0, 1, 0); `samples == 0` yields an
    empty (0, 3) array.
    """
    golden_angle = np.pi * (3. - np.sqrt(5.))  # golden angle in radians
    index = np.arange(samples)

    # Spacing along y is 2 / (samples - 1). With a single sample that would be 0/0
    # (NaN coordinates), so the denominator is floored at 1, which puts the lone
    # point at y = 1, the same place the first point of any larger set lands.
    y = 1 - (2 * index / max(samples - 1, 1))
    radius = np.sqrt(1 - y*y)  # radius of the horizontal circle at height y
    theta = golden_angle * index

    x = np.cos(theta) * radius
    z = np.sin(theta) * radius
    return np.stack([x, y, z], axis=-1)
