
from __future__ import annotations
import math
from typing import Callable, Optional, Tuple, Dict, Union
import ot
import torch
import time
# ---------------------------------------------------------------------------
# FGW utilities: init_matrix, tensor_product, gw_loss, gw_grad (POT conventions)
# ---------------------------------------------------------------------------

def _transform_matrix(C1: torch.Tensor, C2: torch.Tensor, loss_fun: str = 'square_loss') -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the element-wise transforms of the two structure matrices.

    Only the square loss is implemented, for which the transforms are
    ``fC1 = C1**2``, ``fC2 = C2**2``, ``hC1 = C1`` and ``hC2 = 2 * C2``.

    Returns:
        The tuple ``(fC1, fC2, hC1, hC2)``.

    Raises:
        ValueError: if ``loss_fun`` is anything other than ``'square_loss'``.
    """
    if loss_fun == 'square_loss':
        squared_src = C1.pow(2)
        squared_tgt = C2.pow(2)
        doubled_tgt = C2 * 2.0
        # hC1 is the source matrix itself (no copy is made).
        return squared_src, squared_tgt, C1, doubled_tgt
    raise ValueError(f"Unsupported loss_fun={loss_fun} (only 'square_loss')")


def init_matrix(C1: torch.Tensor, C2: torch.Tensor, p: torch.Tensor, q: torch.Tensor, loss_fun: str = 'square_loss') -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Precompute the plan-independent pieces of the GW tensor product.

    Returns:
        ``(constC, hC1, hC2)`` where ``constC[i, j] = (fC1 @ p)[i] + (fC2 @ q)[j]``
        and ``fC*``/``hC*`` are the transforms produced by ``_transform_matrix``.

    Raises:
        ValueError: propagated from ``_transform_matrix`` for an unsupported
            ``loss_fun``.
    """
    squared_src, squared_tgt, h_src, h_tgt = _transform_matrix(C1, C2, loss_fun)
    row_term = torch.unsqueeze(squared_src @ p, 1)   # (ns, 1)
    col_term = torch.unsqueeze(squared_tgt @ q, 0)   # (1, nt)
    # Broadcasting the column vector against the row vector gives (ns, nt).
    return row_term + col_term, h_src, h_tgt


def tensor_product(constC: torch.Tensor, hC1: torch.Tensor, hC2: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """Return ``constC - hC1 @ T @ hC2^T``.

    With the square-loss transforms (``hC1 = C1``, ``hC2 = 2 * C2``) this is
    ``constC - 2 * C1 @ T @ C2^T``.
    """
    projected = hC1 @ T
    cross_term = projected @ hC2.T
    return constC - cross_term


def gw_loss(constC: torch.Tensor, hC1: torch.Tensor, hC2: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """Return the scalar GW objective ``<tensor_product(constC, hC1, hC2, T), T>``."""
    residual = tensor_product(constC, hC1, hC2, T)
    return torch.sum(residual * T)


def gw_grad(constC: torch.Tensor, hC1: torch.Tensor, hC2: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """Return the matrix used as the GW gradient: twice the tensor product at ``T``."""
    residual = tensor_product(constC, hC1, hC2, T)
    return torch.mul(residual, 2.0)


# ---------------------------------------------------------------------------
# Sinkhorn stabilised 
# ---------------------------------------------------------------------------

def sinkhorn_log_stabilized(
    a: torch.Tensor,
    b: torch.Tensor,
    C: torch.Tensor,
    eps: float = 5e-3,
    max_iter: int = 1_000,
    tol: float = 1e-9,
    stop_interval: int = 10,
    log: bool = False,
    f0: Optional[torch.Tensor] = None,
    g0: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
    """Solve entropic OT with *stabilized* Sinkhorn in log-domain.

    Minimizes:  <C, T> + eps * KL(T || a ⊗ b)  s.t. T1=a, T^T1=b, T>=0.

    The iterations alternate the log-domain updates
    ``f = eps * (log a - logsumexp_j((g_j - C_ij) / eps))`` and
    ``g = eps * (log b - logsumexp_i((f_i - C_ij) / eps))``; the plan is
    ``T = exp((f_i + g_j - C_ij) / eps)``. All internal arithmetic is done in
    float64 on ``C``'s device and the outputs are cast back to ``C.dtype``.

    Args:
        a: source weights, 1D of length ``ns`` (renormalised to sum to 1).
        b: target weights, 1D of length ``nt`` (renormalised to sum to 1).
        C: cost matrix of shape ``(ns, nt)``.
        eps: entropic regularisation strength, must be > 0.
        max_iter: maximum number of (f, g) update sweeps; ``0`` returns the
            plan built from the initial potentials.
        tol: stop once the largest absolute marginal violation is below this.
        stop_interval: marginals are checked every ``stop_interval`` sweeps
            (and always on the last sweep); must be >= 1.
        log: when True, also return a dict with ``'f'``, ``'g'``, ``'iters'``
            and ``'err'`` (marginal error of the returned plan, NaN if no
            check was performed).
        f0, g0: optional warm-start potentials. Note that the first sweep
            recomputes ``f`` from ``g`` alone, so only ``g0`` influences the
            iterations when ``max_iter >= 1``.

    Returns:
        ``(T, info)`` where ``info`` is ``None`` unless ``log`` is True.

    Raises:
        ValueError: on inconsistent shapes, ``eps <= 0`` or ``stop_interval < 1``.
    """
    if a.dim() != 1 or b.dim() != 1 or C.dim() != 2:
        raise ValueError("a,b must be 1D and C 2D")
    ns, nt = C.shape
    if ns != a.shape[0] or nt != b.shape[0]:
        raise ValueError("Shape mismatch: C(ns,nt), a(ns), b(nt)")
    if eps <= 0:
        raise ValueError("eps must be > 0")
    if stop_interval < 1:
        raise ValueError("stop_interval must be >= 1")
    if f0 is not None and tuple(f0.shape) != (ns,):
        raise ValueError("Shape mismatch: f0 must have shape (ns,)")
    if g0 is not None and tuple(g0.shape) != (nt,):
        raise ValueError("Shape mismatch: g0 must have shape (nt,)")

    # Everything below lives in float64 on C's device for stability.
    dev = C.device
    f64 = torch.float64
    C64 = C.to(f64)
    a64 = a.to(device=dev, dtype=f64)
    b64 = b.to(device=dev, dtype=f64)

    # Normalise after the float64 cast: the 1e-300 guard would underflow to
    # zero in float32 and no longer protect against an all-zero input.
    a64 = a64 / (a64.sum() + 1e-300)
    b64 = b64 / (b64.sum() + 1e-300)

    f = torch.zeros(ns, dtype=f64, device=dev) if f0 is None else f0.to(device=dev, dtype=f64)
    g = torch.zeros(nt, dtype=f64, device=dev) if g0 is None else g0.to(device=dev, dtype=f64)

    # Zero-mass entries get log-weight -inf, which yields an all-zero row/column.
    neg_inf = -float('inf')
    loga = torch.where(a64 > 0, a64.log(), torch.full_like(a64, neg_inf))
    logb = torch.where(b64 > 0, b64.log(), torch.full_like(b64, neg_inf))

    def build_plan() -> torch.Tensor:
        return torch.exp((f.unsqueeze(1) + g.unsqueeze(0) - C64) / eps)

    it = 0            # stays 0 when max_iter == 0 so the log dict is well-defined
    T = None          # plan matching the current (f, g), when already computed
    last_err = None
    for it in range(1, max_iter + 1):
        # torch.logsumexp performs the max-shift internally and stays finite
        # on slices containing -inf entries.
        f = eps * (loga - torch.logsumexp((g.unsqueeze(0) - C64) / eps, dim=1))
        g = eps * (logb - torch.logsumexp((f.unsqueeze(1) - C64) / eps, dim=0))

        if it % stop_interval == 0 or it == max_iter:
            # Build T only occasionally to check the marginals.
            T = build_plan()
            err_r = (T.sum(dim=1) - a64).abs().max()
            err_c = (T.sum(dim=0) - b64).abs().max()
            last_err = torch.max(err_r, err_c).item()
            if last_err < tol:
                break

    # The loop always exits right after a marginal check (break or last sweep),
    # so T is up to date unless no sweep ran at all.
    if T is None:
        T = build_plan()

    if log:
        return T.to(C.dtype), {
            'f': f.to(C.dtype),
            'g': g.to(C.dtype),
            'iters': torch.tensor(it),
            'err': torch.tensor(last_err if last_err is not None else float('nan')),
        }
    return T.to(C.dtype), None


# ---------------------------------------------------------------------------
# Line search 
# ---------------------------------------------------------------------------

def solve_gromov_linesearch(
    G: torch.Tensor,
    deltaG: torch.Tensor,
    cost_G: torch.Tensor,
    C1: torch.Tensor,
    C2: torch.Tensor,
    M_feat: torch.Tensor,   # (1-alpha) * M
    reg_gw: float,          # alpha
    symmetric: bool = False,
) -> Tuple[torch.Tensor, int, torch.Tensor]:
    """Closed-form line search for a quadratic ``t -> a * t**2 + b * t`` on [0, 1].

    The coefficients are
    ``a = -reg_gw * <C1 @ deltaG @ C2^T, deltaG>`` and
    ``b = <M_feat, deltaG> - reg_gw * (<C1 @ deltaG @ C2^T, G> + <C1 @ G @ C2^T, deltaG>)``
    (the two cross terms are merged into ``2 * <C1 @ deltaG @ C2^T, G>`` when
    ``symmetric`` is True).

    NOTE(review): these coefficients reproduce the quadratic part of
    ``-<C1 @ G @ C2^T, G>``; for them to match ``gw_loss`` the caller
    presumably has to pass the transformed matrices (hC1, hC2) — confirm.

    Returns:
        ``(step, 1, new_cost)`` where ``step`` is a 0-dim tensor in [0, 1] on
        ``G``'s device/dtype and ``new_cost = cost_G + a * step**2 + b * step``.
    """
    cross = C1 @ deltaG @ C2.T
    a = -reg_gw * (cross * deltaG).sum()

    feature_slope = (M_feat * deltaG).sum()
    if symmetric:
        b = feature_slope - 2.0 * reg_gw * (cross * G).sum()
    else:
        b = feature_slope - reg_gw * ( (cross * G).sum() + ((C1 @ G @ C2.T) * deltaG).sum() )

    quad = a.item()
    lin = b.item()
    if quad > 0:
        # Convex parabola: clip the unconstrained minimiser -b / (2a) to [0, 1].
        step = min(1.0, max(0.0, -lin / (2.0 * quad)))
    elif (quad + lin) < 0:
        # Concave or linear: the minimum sits at an endpoint; t=1 wins here.
        step = 1.0
    else:
        step = 0.0
    alpha = torch.tensor(step, device=G.device, dtype=G.dtype)

    updated_cost = cost_G + a * (alpha ** 2) + b * alpha
    return alpha, 1, updated_cost


# ---------------------------------------------------------------------------
# Conditional Gradient 
# ---------------------------------------------------------------------------

def generic_conditional_gradient_sinkhorn(
    a: torch.Tensor,
    b: torch.Tensor,
    M_feat: torch.Tensor,             # (1-alpha) * M
    f: Callable[[torch.Tensor], torch.Tensor],
    df: Callable[[torch.Tensor], torch.Tensor],
    reg_gw: float,                    # alpha
    G0: Optional[torch.Tensor],
    eps: float,
    C1: torch.Tensor,
    C2: torch.Tensor,
    numItermax: int = 200,
    stopThr: float = 1e-9,
    stopThr2: float = 1e-9,
    verbose: bool = False,
    log: bool = False,
    sinkhorn_max_iter: int = 1_000,
    sinkhorn_tol: float = 1e-9,
    polished: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
    """Conditional-gradient loop whose direction step is an entropic OT solve.

    Minimises ``<M_feat, G> + reg_gw * f(G)``. Each iteration linearises the
    objective (``Mi = M_feat + reg_gw * df(G)``), solves the linearised
    problem with ``sinkhorn_log_stabilized`` and moves along ``Gc - G`` with
    the step returned by ``solve_gromov_linesearch``.

    Args:
        a, b: marginals of the plan.
        M_feat: linear cost term.
        f, df: regulariser and the matrix used as its gradient.
        reg_gw: weight of ``f``.
        G0: initial plan; the outer product ``a b^T`` is used when None.
        eps: relative entropic regularisation; the value passed to Sinkhorn is
            ``eps * (median(|Mi|) + 1e-12)``.
        C1, C2: forwarded unchanged to ``solve_gromov_linesearch``.
            NOTE(review): for the line-search quadratic to match ``f`` these
            presumably must be the matrices ``f`` is built on — confirm
            against the caller.
        numItermax: maximum number of outer iterations.
        stopThr, stopThr2: relative / absolute tolerance on the cost change.
        verbose: print a progress table (and the polishing time).
        log: also return a dict with the ``"loss"`` history (plus
            ``"emd_log"`` when ``polished``).
        sinkhorn_max_iter, sinkhorn_tol: inner Sinkhorn settings.
        polished: finish with one extra step towards an exact ``ot.emd``
            solution of the linearised problem.

    Returns:
        ``G``, or ``(G, log_dict)`` when ``log`` is True.
    """
    if log:
        log_dict: Dict = {"loss": []}

    # Init: independent coupling (outer product of the marginals) by default.
    G = torch.outer(a, b) if G0 is None else G0.clone()

    def cost(G_: torch.Tensor) -> torch.Tensor:
        return (M_feat * G_).sum() + reg_gw * f(G_)

    cost_G = cost(G)
    if log:
        log_dict["loss"].append(cost_G.item())

    it = 0
    if verbose:
        print(f"{'It.':<5s}|{'Loss':<12s}|{'Rel.d':<12s}|{'Abs.d':<12s}")
        print("-" * 46)
        print(f"{it:<5d}|{cost_G.item():<12.4e}|{'0.0':<12s}|{'0.0':<12s}")
    f_cache = g_cache = None

    while True:
        it += 1
        old_cost = cost_G

        grad = df(G)
        Mi = M_feat + reg_gw * grad
        # Scale eps by the typical cost magnitude so it acts as a relative value.
        s = torch.median(Mi.abs()).item() + 1e-12
        eps_use = eps * s

        # Direction step: stabilised Sinkhorn, warm-started from the previous
        # potentials.
        Gc, sinkhorn_log = sinkhorn_log_stabilized(
            a, b, Mi, eps=eps_use, max_iter=sinkhorn_max_iter, tol=sinkhorn_tol,
            log=True, f0=f_cache, g0=g_cache,
        )
        # The warm start must not depend on the caller's `log` flag, otherwise
        # the returned plan would change with logging. Potentials containing
        # NaN are dropped so a failed solve cannot poison later iterations.
        new_f, new_g = sinkhorn_log['f'], sinkhorn_log['g']
        if torch.isnan(new_f).any() or torch.isnan(new_g).any():
            f_cache = g_cache = None
        else:
            f_cache, g_cache = new_f, new_g

        deltaG = Gc - G
        alpha_step, _, cost_G = solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M_feat, reg_gw, symmetric=False)

        G = G + alpha_step * deltaG

        # Convergence bookkeeping; done before any break so that the final
        # iteration is logged/printed as well.
        abs_d = (cost_G - old_cost).abs()
        rel_d = abs_d / (cost_G.abs() + 1e-300)
        if log:
            log_dict["loss"].append(cost_G.item())
        if verbose:
            if it % 20 == 0:
                print(f"{'It.':<5s}|{'Loss':<12s}|{'Rel.d':<12s}|{'Abs.d':<12s}")
                print("-" * 46)
            print(f"{it:<5d}|{cost_G.item():<12.4e}|{rel_d.item():<12.4e}|{abs_d.item():<12.4e}")
        if it >= numItermax:
            break
        if rel_d.item() < stopThr or abs_d.item() < stopThr2:
            break

    if polished:
        polish_start = time.time()
        # One last step whose direction comes from an exact EMD solver.
        M_final = M_feat + reg_gw * df(G)
        # Keep the EMD log separate: reusing `log_dict` here would discard the
        # loss history collected above.
        T_exact, emd_log = ot.emd(a, b, M_final, numItermax=1000, log=True)
        deltaG = T_exact - G
        alpha_step, _, cost_G = solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M_feat, reg_gw, symmetric=False)
        G = G + alpha_step * deltaG
        if log:
            log_dict["loss"].append(cost_G.item())
            log_dict["emd_log"] = emd_log
        if verbose:
            print(f"Polished in {time.time() - polish_start:.2f}s")

    if log:
        return G, log_dict
    return G


# ---------------------------------------------------------------------------
# Public API: FGW solved with stabilized Sinkhorn
# ---------------------------------------------------------------------------

def fused_gromov_wasserstein_sinkhorn(
    M: torch.Tensor,
    C1: torch.Tensor,
    C2: torch.Tensor,
    p: torch.Tensor,
    q: torch.Tensor,
    *,
    loss_fun: str = 'square_loss',
    alpha: float = 0.5,
    eps: float = 5e-3,                   # Sinkhorn eps
    armijo: bool = False,               
    G0: Optional[torch.Tensor] = None,
    log: bool = False,
    numItermax: int = 200,
    stopThr: float = 1e-9,
    verbose: bool = False,
    sinkhorn_max_iter: int = 1_000,
    sinkhorn_tol: float = 1e-9,
    polished: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
    """Compute an FGW transport plan with a Sinkhorn-based conditional gradient.

    The objective is ``(1 - alpha) * <M, G> + alpha * gw_loss(G)`` with the GW
    term built from ``init_matrix(C1, C2, p, q, loss_fun)``.

    Args:
        M: feature cost matrix.
        C1, C2: structure matrices of the source / target.
        p, q: source / target weights; renormalised to sum to 1 and cast to
            ``M.dtype``.
        loss_fun: forwarded to ``init_matrix`` (only ``'square_loss'`` is
            supported there).
        alpha: trade-off between the feature term and the GW term.
        eps: relative entropic regularisation of the inner Sinkhorn solver.
        armijo: accepted for signature compatibility; not used by this
            implementation.
        G0: optional initial plan.
        log: also return the solver's log dict.
        numItermax: maximum number of conditional-gradient iterations.
        stopThr: used as both the relative and the absolute stopping threshold.
        verbose: print solver progress.
        sinkhorn_max_iter, sinkhorn_tol: inner Sinkhorn settings.
        polished: request the final exact-EMD polishing step.

    Returns:
        The plan ``G``, or ``(G, log_dict)`` when ``log`` is True.
    """
    p = (p / (p.sum() + 1e-300)).to(M.dtype)
    q = (q / (q.sum() + 1e-300)).to(M.dtype)

    constC, hC1, hC2 = init_matrix(C1, C2, p.to(C1.dtype), q.to(C2.dtype), loss_fun)

    def f(G: torch.Tensor) -> torch.Tensor:
        return gw_loss(constC, hC1, hC2, G)

    def df(G: torch.Tensor) -> torch.Tensor:
        return gw_grad(constC, hC1, hC2, G)

    # POT weighting: (1-alpha) * W + alpha * GW
    M_feat = (1.0 - alpha) * M
    reg_gw = alpha

    # The line search expands -<X1 @ G @ X2^T, G> for the matrices it is given.
    # gw_loss is built on hC1 = C1 and hC2 = 2 * C2, so the transformed
    # matrices must be passed; the raw C1, C2 would halve the GW part of the
    # quadratic and give a wrong step size and tracked cost.
    return generic_conditional_gradient_sinkhorn(
        p, q, M_feat, f, df, reg_gw, G0, eps, hC1, hC2,
        numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr,
        verbose=verbose, log=log,
        sinkhorn_max_iter=sinkhorn_max_iter, sinkhorn_tol=sinkhorn_tol,
        polished=polished
    )


def fused_gromov_wasserstein2_sinkhorn(
    M: torch.Tensor,
    C1: torch.Tensor,
    C2: torch.Tensor,
    p: torch.Tensor,
    q: torch.Tensor,
    *,
    loss_fun: str = 'square_loss',
    alpha: float = 0.5,
    eps: float = 5e-3,
    armijo: bool = False,
    G0: Optional[torch.Tensor] = None,
    log: bool = False,
    numItermax: int = 200,
    stopThr: float = 1e-9,
    verbose: bool = False,
    sinkhorn_max_iter: int = 1_000,
    sinkhorn_tol: float = 1e-9,
    polished: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
    """Return the FGW value of the plan found by ``fused_gromov_wasserstein_sinkhorn``.

    The value is ``(1 - alpha) * <M, T> + alpha * gw_loss(T)``, re-evaluated on
    the returned plan ``T``. All arguments are forwarded to the plan solver.

    Returns:
        The scalar tensor ``fgw``, or ``(fgw, logd)`` when ``log`` is True; in
        that case ``logd`` is the solver's log dict with the extra key
        ``'fgw_dist'`` (a Python float).
    """
    solved = fused_gromov_wasserstein_sinkhorn(
        M, C1, C2, p, q,
        loss_fun=loss_fun, alpha=alpha, eps=eps, armijo=armijo, G0=G0,
        log=bool(log), numItermax=numItermax, stopThr=stopThr, verbose=verbose,
        sinkhorn_max_iter=sinkhorn_max_iter, sinkhorn_tol=sinkhorn_tol,
        polished=polished,
    )
    # With logging on, the solver hands back (plan, log dict); otherwise the plan.
    if log:
        T, logd = solved
    else:
        T = solved

    # Rebuild the GW ingredients from the normalised weights to score the plan.
    p_n = (p / (p.sum() + 1e-300)).to(M.dtype)
    q_n = (q / (q.sum() + 1e-300)).to(M.dtype)
    constC, hC1, hC2 = init_matrix(C1, C2, p_n.to(C1.dtype), q_n.to(C2.dtype), loss_fun)
    feature_cost = (M * T).sum()
    structure_cost = gw_loss(constC, hC1, hC2, T)
    fgw = (1.0 - alpha) * feature_cost + alpha * structure_cost

    if not log:
        return fgw
    logd['fgw_dist'] = fgw.detach().cpu().item()
    return fgw, logd


# ---------------------------------------------------------------------------
# Mini self-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test on a tiny random problem (5 source points, 4 target points).
    torch.set_printoptions(precision=4, sci_mode=True)
    ns, nt = 5, 4
    rng = torch.Generator().manual_seed(0)
    f64 = torch.float64

    # The draw order (C1, C2, X, Y) is kept fixed so the seeded values are reproducible.
    C1 = torch.rand(ns, ns, generator=rng, dtype=f64)
    C1 = 0.5*(C1+C1.T)          # symmetrise
    C2 = torch.rand(nt, nt, generator=rng, dtype=f64)
    C2 = 0.5*(C2+C2.T)          # symmetrise
    X = torch.rand(ns, 3, generator=rng, dtype=f64)
    Y = torch.rand(nt, 3, generator=rng, dtype=f64)

    M = torch.cdist(X, Y, p=2)  # Euclidean feature cost
    p = torch.full((ns,), 1.0/ns, dtype=f64)
    q = torch.full((nt,), 1.0/nt, dtype=f64)

    solver_kwargs = dict(alpha=0.3, eps=2e-2, numItermax=50, sinkhorn_max_iter=500)
    T = fused_gromov_wasserstein_sinkhorn(M, C1, C2, p, q, **solver_kwargs)
    d = fused_gromov_wasserstein2_sinkhorn(M, C1, C2, p, q, **solver_kwargs)
    print("T sum:", T.sum().item(), "FGW:", float(d))
