
#!/usr/bin/env python3

"""

ICML experiment runner for ITSPACE.



Implements:

- Exact Gaussian BW evaluation (Eq. 2) with robust PSD handling.

- Prefix-time logging (cumulative) with separate update/eval/projection/total times.

- Early stopping + time-to-threshold summaries.

- Baselines: ITSPACE, BW geodesic, Euclidean, Log-Euclid, AIRM, CORAL, Sinkhorn, Sinkhorn-Gaussian, BW-GD.



Self-contained (does not depend on runners/run_experiments.py).

"""

from __future__ import annotations



import argparse

import csv

import json

import math

import os

import time

from dataclasses import dataclass

from typing import Dict, List, Optional, Tuple



import numpy as np

import torch





# -----------------------------

# Small utilities

# -----------------------------


def _pick_itspace_Q(env):
    """Pick the ITSPACE polar factor variable from a locals() dict (no commas in call site)."""
    for n in ('Q','Qk','Q_polar','Qhat','Q_hat','Qns','Q_ns','Qtilde','Q_tilde'):
        if n in env:
            return env[n]
    return None

def _itspace_orth_residual(Q):
    """Orth residual ||Q^T Q - I||_F (torch or numpy)."""
    import math
    if Q is None:
        return float('nan')
    try:
        import torch
        if isinstance(Q, torch.Tensor):
            QtQ = Q.transpose(-2, -1) @ Q
            I = torch.eye(QtQ.shape[-1], device=Q.device, dtype=Q.dtype)
            return float(torch.linalg.norm(QtQ - I, ord='fro').item())
    except Exception:
        pass
    import numpy as np
    QtQ = Q.T @ Q
    I = np.eye(QtQ.shape[-1], dtype=QtQ.dtype)
    return float(np.linalg.norm(QtQ - I, ord='fro'))
def _ts() -> str:

    return time.strftime("%Y%m%d_%H%M%S", time.localtime())





def _ensure_dir(p: str) -> None:

    os.makedirs(p, exist_ok=True)





def parse_csv_list(s: str, cast=int) -> List:
    """Split a comma-separated string and convert every non-blank item with ``cast``.

    ``None`` and whitespace-only input give an empty list. Each item is stripped
    of surrounding whitespace before casting; empty items (e.g. from ``"a,,b"``)
    are dropped.
    """
    if s is None:
        return []
    tokens = (token.strip() for token in str(s).strip().split(","))
    return [cast(token) for token in tokens if token]





def pick_device(device_str: str) -> torch.device:
    """Resolve a device string; ``"auto"`` picks CUDA when available and CPU otherwise."""
    if device_str != "auto":
        return torch.device(device_str)
    has_cuda = torch.cuda.is_available()
    return torch.device("cuda" if has_cuda else "cpu")





def cuda_sync(device: torch.device) -> None:
    """Wait for queued CUDA work on ``device`` to finish; does nothing for non-CUDA devices."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)





def set_all_seeds(seed: int) -> None:
    """Seed NumPy's global RNG, torch's CPU RNG and every CUDA RNG with ``seed``."""
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)





def sym(A: torch.Tensor) -> torch.Tensor:
    """Return the symmetric part (A + A^T) / 2, transposing the last two dimensions."""
    At = A.transpose(-1, -2)
    return 0.5 * (A + At)





def trace(A: torch.Tensor) -> torch.Tensor:
    """Sum of the main diagonal of ``A``, returned as a 0-dim tensor."""
    diag = A.diagonal()
    return diag.sum()





def eig_clamp_psd(A: torch.Tensor, clamp: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Eigendecompose the symmetric part of ``A``.

    Returns ``(evals, evecs)`` with eigenvalues in ascending order, as produced by
    ``torch.linalg.eigh``. Eigenvalues below ``clamp`` are raised to ``clamp``;
    passing ``clamp=None`` leaves them untouched.
    """
    evals, evecs = torch.linalg.eigh(sym(A))
    clamped = evals if clamp is None else evals.clamp(min=clamp)
    return clamped, evecs





def sqrtm_psd(A: torch.Tensor, clamp: float = 0.0) -> torch.Tensor:
    """Principal square root of a PSD matrix via eigendecomposition.

    Eigenvalues of sym(A) are first raised to at least ``clamp``.
    """
    w, V = eig_clamp_psd(A, clamp=clamp)
    root = w.sqrt()
    return (V * root.unsqueeze(0)) @ V.T





def invsqrtm_psd(A: torch.Tensor, clamp: float = 0.0, eps: float = 0.0) -> torch.Tensor:
    """
    Pseudo-inverse square root of a PSD matrix.

    Eigenvalues of sym(A) are clamped at ``clamp`` (and at 0 before the root).
    Every eigen-direction whose *square-rooted* eigenvalue is <= ``eps``
    contributes 0 instead of 1/sqrt(lambda).
    """
    w, V = eig_clamp_psd(A, clamp=clamp)
    roots = w.clamp(min=0.0).sqrt()
    # Division by zero yields inf here, but torch.where discards those entries.
    inv_roots = torch.where(roots > eps, 1.0 / roots, torch.zeros_like(roots))
    return (V * inv_roots.unsqueeze(0)) @ V.T





def logm_spd(A: torch.Tensor, clamp: float = 1e-12) -> torch.Tensor:
    """Matrix logarithm of an SPD matrix.

    Eigenvalues are floored at ``clamp`` so the logarithm stays finite.
    """
    w, V = eig_clamp_psd(A, clamp=clamp)
    log_w = w.log()
    return (V * log_w.unsqueeze(0)) @ V.T





def expm_sym(A: torch.Tensor) -> torch.Tensor:
    """Matrix exponential of the symmetric part of ``A``, via eigendecomposition (no clamping)."""
    w, V = torch.linalg.eigh(sym(A))
    return (V * w.exp().unsqueeze(0)) @ V.T





def powm_spd(A: torch.Tensor, p: float, clamp: float = 1e-12) -> torch.Tensor:
    """Real matrix power A^p of an SPD matrix; eigenvalues are floored at ``clamp`` first."""
    w, V = eig_clamp_psd(A, clamp=clamp)
    w_pow = w.pow(p)
    return (V * w_pow.unsqueeze(0)) @ V.T





def ensure_spd(A: torch.Tensor, jitter: float = 1e-8) -> Tuple[torch.Tensor, float]:
    """
    Symmetrize ``A`` and shift it by delta*I when its smallest eigenvalue is <= 0.

    In that case delta = -lambda_min + jitter; otherwise delta = 0.0. The shift is
    applied only when delta > 0. Returns ``(A_spd, delta)``. Costs one eigvalsh
    call (O(d^3)), so it is meant for baselines that genuinely need SPD inputs.
    """
    A = sym(A)
    lam_min = float(torch.linalg.eigvalsh(A).min().item())
    delta = (-lam_min) + jitter if lam_min <= 0.0 else 0.0
    if delta > 0.0:
        A = A + delta * torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
    return A, delta





def project_psd_rank(
    X: torch.Tensor,
    rank: Optional[int],
    clamp: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    PSD (and optionally rank-``rank``) eigen-truncation of ``X``.

    Eigenvalues of sym(X) are clamped at ``clamp``. When ``rank`` is a positive
    integer smaller than the dimension, only the ``rank`` largest eigenpairs are
    kept; otherwise all of them are.

    Returns ``(Xp, Y, evals_kept)``:
      Xp         -- sym(V_k diag(w_k) V_k^T), equal to Y Y^T up to rounding for clamp >= 0
      Y          -- kept eigenvectors scaled by sqrt of their eigenvalues
      evals_kept -- the kept eigenvalues (ascending)
    """
    w, V = eig_clamp_psd(X, clamp=clamp)
    n = w.numel()
    truncate = rank is not None and 0 < rank < n
    start = n - rank if truncate else 0
    w_k = w[start:]
    V_k = V[:, start:]
    Xp = sym((V_k * w_k.unsqueeze(0)) @ V_k.T)
    Y = V_k * w_k.sqrt().unsqueeze(0)
    return Xp, Y, w_k





# -----------------------------

# BW objective (exact eval)

# -----------------------------

def bw2_full(
    X: torch.Tensor,
    Sigma: torch.Tensor,
    Sigma_sqrt: torch.Tensor,
    trSigma: torch.Tensor,
) -> torch.Tensor:
    """
    Exact Gaussian BW (squared) between covariances X and Sigma:

      tr(X) + tr(Sigma) - 2 tr( (Sigma^{1/2} X Sigma^{1/2})^{1/2} )

    ``Sigma_sqrt`` and ``trSigma`` are supplied precomputed by the caller; ``Sigma``
    itself is not read (kept for signature parity). Tiny negative eigenvalues of
    the middle PSD term are clamped to 0.

    The trace of the matrix square root equals the sum of the square roots of
    its eigenvalues, so only eigenvalues are computed (``eigvalsh``). No
    eigenvectors and no d x d reconstruction are needed, consistent with
    ``bw2_from_factor``.
    """
    Xs = sym(X)
    M = sym(Sigma_sqrt @ Xs @ Sigma_sqrt)
    evals = torch.clamp(torch.linalg.eigvalsh(M), min=0.0)
    return trace(Xs) + trSigma - 2.0 * torch.sqrt(evals).sum()





def bw2_from_factor(
    Y: torch.Tensor,
    Sigma: torch.Tensor,
    trSigma: torch.Tensor,
) -> torch.Tensor:
    """
    Exact Gaussian BW (squared) for a factored PSD X = Y Y^T.

    Uses tr((Sigma^{1/2} X Sigma^{1/2})^{1/2}) = tr((Y^T Sigma Y)^{1/2}), so only
    the Y.shape[1] x Y.shape[1] Gram matrix is eigendecomposed.
    """
    trX = torch.sum(Y * Y)
    gram = sym(Y.T @ Sigma @ Y)
    lam = eig_clamp_psd(gram, clamp=0.0)[0]
    cross = lam.clamp(min=0.0).sqrt().sum()
    return trX + trSigma - 2.0 * cross





def transport_map(
    Sigma_sqrt: torch.Tensor,
    X: torch.Tensor,
) -> torch.Tensor:
    """
    Symmetric Gaussian transport map

      T(X) = Sigma^{1/2} (Sigma^{1/2} X Sigma^{1/2})^{-1/2} Sigma^{1/2}

    The inverse square root is a pseudo-inverse (``invsqrtm_psd`` with eps=0),
    so zero-eigenvalue directions of the middle term contribute nothing.
    """
    S = Sigma_sqrt
    middle = sym(S @ sym(X) @ S)
    middle_invroot = invsqrtm_psd(middle, clamp=0.0, eps=0.0)
    return sym(S @ middle_invroot @ S)





# -----------------------------

# ITSPACE polar (with diagnostics)

# -----------------------------

def polar_factor(
    A: torch.Tensor,
    mode: str = "gram",   # gram | svd | ns
    ns_iters: int = 6,
) -> Tuple[torch.Tensor, float, float]:
    """
    Orthonormal polar factor Q of a tall matrix A (d x r).

    Modes:
      "svd"  -- Q = U V^T from a thin SVD of A.
      "ns"   -- Q = A (A^T A)^{-1/2}; the inverse root comes from ``ns_iters`` coupled
                Newton–Schulz steps on the Frobenius-normalized Gram matrix.
      other  -- ("gram") Q = A (A^T A)^{-1/2} via an eigendecomposition pseudo-inverse root.

    Returns ``(Q, orth_resid, cert_gap)`` with
      orth_resid = || sym(Q^T Q) - I ||_F
      cert_gap   = max(0, ||A||_* - tr(Q^T A))
    """
    _rows, r = A.shape
    gram = sym(A.T @ A)

    # For tall A, the nuclear norm is the sum of sqrt-eigenvalues of A^T A.
    gram_evals, _ = eig_clamp_psd(gram, clamp=0.0)
    nuclear = gram_evals.clamp(min=0.0).sqrt().sum()

    if mode == "svd":
        U, _, Vh = torch.linalg.svd(A, full_matrices=False)
        Q = U @ Vh
    elif mode == "ns":
        eye_r = torch.eye(r, device=A.device, dtype=A.dtype)
        scale = torch.linalg.norm(gram)
        if float(scale.item()) == 0.0:
            inv_root = eye_r
        else:
            # Coupled iteration: Ycur -> (gram/scale)^{1/2}, Zcur -> (gram/scale)^{-1/2}.
            Ycur, Zcur = gram / scale, eye_r
            for _ in range(ns_iters):
                step = 0.5 * (3.0 * eye_r - Zcur @ Ycur)
                Ycur, Zcur = Ycur @ step, step @ Zcur
            inv_root = Zcur / torch.sqrt(scale)
        Q = A @ inv_root
    else:
        Q = A @ invsqrtm_psd(gram, clamp=0.0, eps=0.0)

    QtQ = sym(Q.T @ Q)
    eye_q = torch.eye(QtQ.shape[0], device=A.device, dtype=A.dtype)
    orth_resid = float(torch.linalg.norm(QtQ - eye_q).item())
    cert_gap = max(0.0, float((nuclear - trace(Q.T @ A)).item()))
    return Q, orth_resid, cert_gap





# -----------------------------

# Sinkhorn (sample OT)

# -----------------------------

@torch.no_grad()
def sinkhorn_log_domain(
    cost: torch.Tensor,         # (n, m) >= 0
    eps: float,
    iters: int,
    log_u: Optional[torch.Tensor] = None,
    log_v: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Balanced log-domain Sinkhorn with uniform marginals a = 1/n and b = 1/m.

    Runs ``iters`` alternating scaling updates on log K = -cost / eps. It is
    warm-started from ``(log_u, log_v)`` when given (zeros otherwise). The
    resulting plan is exp(log_u[:, None] + log K + log_v[None, :]).
    Returns the updated ``(log_u, log_v)``.
    """
    n, m = cost.shape
    opts = dict(device=cost.device, dtype=cost.dtype)

    log_a = -math.log(n) * torch.ones(n, **opts)
    log_b = -math.log(m) * torch.ones(m, **opts)
    logK = -cost / eps          # (n, m)
    logK_T = logK.transpose(0, 1)

    if log_u is None:
        log_u = torch.zeros(n, **opts)
    if log_v is None:
        log_v = torch.zeros(m, **opts)

    for _ in range(iters):
        log_u = log_a - torch.logsumexp(logK + log_v.unsqueeze(0), dim=1)
        log_v = log_b - torch.logsumexp(logK_T + log_u.unsqueeze(0), dim=1)
    return log_u, log_v





@torch.no_grad()
def sinkhorn_barycentric_cov(
    Xs: torch.Tensor,  # (n, d)
    Xt: torch.Tensor,  # (m, d)
    eps: float,
    iters: int,
    log_u: Optional[torch.Tensor] = None,
    log_v: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Refine the entropic plan between source and target samples by ``iters`` Sinkhorn steps.

    Builds the squared-Euclidean cost, runs ``sinkhorn_log_domain`` (warm-started
    from ``log_u``/``log_v``), and maps each source point to the plan-weighted mean
    of the targets. Returns ``(cov, log_u, log_v)``:
      cov          -- float64 sample covariance (denominator n - 1, at least 1)
                      of the barycentric-projected source points
      log_u, log_v -- updated warm-start states
    """
    sq_src = torch.sum(Xs * Xs, dim=1, keepdim=True)            # (n, 1)
    sq_tgt = torch.sum(Xt * Xt, dim=1).unsqueeze(0)             # (1, m)
    cost = (sq_src + sq_tgt - 2.0 * (Xs @ Xt.transpose(0, 1))).clamp(min=0.0)

    log_u, log_v = sinkhorn_log_domain(cost=cost, eps=eps, iters=iters, log_u=log_u, log_v=log_v)
    plan = torch.exp(log_u.unsqueeze(1) + (-cost / eps) + log_v.unsqueeze(0))  # (n, m)

    # Barycentric projection; the floor guards rows whose mass underflowed to 0.
    row_mass = plan.sum(dim=1, keepdim=True).clamp(min=1e-32)
    mapped = ((plan @ Xt) / row_mass).to(torch.float64)

    centered = mapped - mapped.mean(dim=0, keepdim=True)
    denom = max(1, mapped.shape[0] - 1)
    cov = (centered.transpose(0, 1) @ centered) / denom
    return sym(cov), log_u, log_v





# -----------------------------

# Timing

# -----------------------------

@dataclass
class TimeState:
    """Cumulative seconds accumulated per timing bucket; ``total`` is their sum."""

    update: float = 0.0    # method work: updates, precomputed caches, internal checks
    proj: float = 0.0      # rank projections
    eval: float = 0.0      # per-iterate logging evaluations of BW
    overhead: float = 0.0  # extra bucket; not incremented by the runners in this file

    @property
    def total(self) -> float:
        """Sum of the four buckets, in the order update, proj, eval, overhead."""
        u, p, e, o = self.update, self.proj, self.eval, self.overhead
        return u + p + e + o





def _time_block(device: torch.device) -> float:
    """Synchronize ``device`` (CUDA only) and return a ``perf_counter`` start stamp."""
    cuda_sync(device)
    start = time.perf_counter()
    return start





def _time_block_end(device: torch.device, t0: float) -> float:
    """Synchronize ``device`` (CUDA only) and return seconds elapsed since ``t0``."""
    cuda_sync(device)
    elapsed = time.perf_counter() - t0
    return elapsed





# -----------------------------

# Method runners

# -----------------------------

def run_itspace(
    X0: torch.Tensor,
    Y0: torch.Tensor,
    Ct: torch.Tensor,
    Ct_sqrt: torch.Tensor,
    trCt: torch.Tensor,
    K: int,
    lam: float,
    polar_mode: str,
    polar_ns_iters: int,
    device: torch.device,
    rank: Optional[int],
    early_stop_tau: Optional[float],
    eval_fullrank: bool,
) -> Tuple[List[Dict], Dict]:
    """
    Run up to K ITSPACE iterations on the factor Y (X = Y Y^T) towards Ct.

    Each iteration: A = Ct_sqrt @ Y, Q = polar(A), Y <- alpha Ct_sqrt Q + (1 - alpha) Y,
    with alpha = 2 lam / (1 + 2 lam). Update work is timed into t_update and the
    per-iterate BW evaluation into t_eval. The polar diagnostics are untimed.

    Args:
        X0, Y0: initial iterate. Iterates start from ``Y0``. ``X0`` only enters the
            reference value bw0, and only when ``eval_fullrank`` is set.
        Ct, Ct_sqrt, trCt: target covariance, its square root and its trace.
        K: maximum number of iterations; the logged ``t`` is k / K.
        lam: step parameter mapped to alpha as above.
        polar_mode, polar_ns_iters: forwarded to ``polar_factor``.
        device: used for CUDA synchronisation around timed regions.
        rank: accepted for signature parity with the other runners; unused here.
        early_stop_tau: stop once bw2 / bw0 <= tau (None disables early stopping).
        eval_fullrank: evaluate with ``bw2_full`` on X instead of ``bw2_from_factor`` on Y.

    Returns:
        (rows, meta). ``rows`` holds one dict per logged iterate with cumulative
        timings and polar diagnostics. ``meta`` is ``{"bw0": bw0}``, with bw0
        replaced by 1.0 when it is exactly 0.
    """
    Y = Y0.clone()

    alpha = (2.0 * lam) / (1.0 + 2.0 * lam)
    time_state = TimeState()
    rows: List[Dict] = []

    bw0 = bw2_full(X0, Ct, Ct_sqrt, trCt) if eval_fullrank else bw2_from_factor(Y0, Ct, trCt)
    bw0f = float(bw0.item()) if float(bw0.item()) != 0.0 else 1.0

    def _eval(Ycur: torch.Tensor, Xcur: torch.Tensor) -> float:
        t0 = _time_block(device)
        bw = bw2_full(Xcur, Ct, Ct_sqrt, trCt) if eval_fullrank else bw2_from_factor(Ycur, Ct, trCt)
        time_state.eval += _time_block_end(device, t0)
        return float(bw.item())

    def _polar_diagnostics(Q: Optional[torch.Tensor]) -> Dict:
        """Untimed diagnostics of Q; NaN / -1 placeholders before the first polar step."""
        if Q is None:
            return {
                "itspace_polar_orth": float("nan"),
                "itspace_polar_qcols": -1,
                "itspace_polar_qtq_fro": float("nan"),
            }
        QtQ = Q.transpose(-2, -1) @ Q
        I = torch.eye(Q.shape[-1], device=Q.device, dtype=Q.dtype)
        return {
            "itspace_polar_orth": float(torch.linalg.norm(QtQ - I, ord="fro").item()),
            "itspace_polar_qcols": int(Q.shape[-1]),
            "itspace_polar_qtq_fro": float(torch.linalg.norm(QtQ, ord="fro").item()),
        }

    def _log_row(step: int, t: float, bw: float, Q: Optional[torch.Tensor], cert: Optional[float]) -> None:
        # Key order matches the historical CSV column order.
        rows.append({
            "step": step, "t": t, "bw2": bw, "rel_bw2": bw / bw0f,
            "t_update": time_state.update, "t_proj": time_state.proj,
            "t_eval": time_state.eval, "t_total": time_state.total,
            **_polar_diagnostics(Q),
            "itspace_polar_cert_gap": cert,
        })

    X = sym(Y @ Y.T)
    bw = _eval(Y, X)
    _log_row(0, 0.0, bw, None, None)
    if early_stop_tau is not None and (bw / bw0f) <= early_stop_tau:
        return rows, {"bw0": bw0f}

    for k in range(1, K + 1):
        t_iter0 = _time_block(device)
        A = Ct_sqrt @ Y
        Q, _orth, cert = polar_factor(A, mode=polar_mode, ns_iters=polar_ns_iters)
        B = Ct_sqrt @ Q
        Y = alpha * B + (1.0 - alpha) * Y
        X = sym(Y @ Y.T)
        time_state.update += _time_block_end(device, t_iter0)

        bw = _eval(Y, X)
        _log_row(k, k / float(K), bw, Q, cert)
        if early_stop_tau is not None and (bw / bw0f) <= early_stop_tau:
            break

    return rows, {"bw0": bw0f}





def run_path_method(
    name: str,
    X0: torch.Tensor,
    Ct: torch.Tensor,
    Ct_sqrt: torch.Tensor,
    trCt: torch.Tensor,
    K: int,
    device: torch.device,
    rank: Optional[int],
    early_stop_tau: Optional[float],
    path_cache: Dict,
) -> Tuple[List[Dict], Dict]:
    """
    Run a closed-form interpolation baseline from X0 towards Ct at t = k / K, k = 0..K.

    ``name`` selects the path:
      "bw_geodesic" -- X(t) = A_t X0 A_t^T, A_t = (1 - t) I + t T, where T is
                       ``transport_map(Ct_sqrt, X0_spd)`` on the jittered X0.
      "euclidean"   -- X(t) = (1 - t) X0 + t Ct.
      "logeuclid"   -- X(t) = expm((1 - t) logm(X0_spd) + t logm(Ct_spd)).
      "airm"        -- X(t) = X0_spd^{1/2} (X0_spd^{-1/2} Ct_spd X0_spd^{-1/2})^t X0_spd^{1/2}.
      "coral"       -- X(t) = B_t X0 B_t^T, B_t = (1 - t) I + t Ct_sqrt X0_spd^{-1/2}.
    ``X0_spd`` / ``Ct_spd`` are the ``ensure_spd`` (jittered) versions. Any other
    name raises ValueError.

    When ``rank`` is a positive int smaller than d, every logged iterate (including
    X0) is projected with ``project_psd_rank`` and scored with ``bw2_from_factor``;
    that projection is timed into t_proj. Otherwise ``bw2_full`` is applied to the
    unprojected iterate.

    ``path_cache`` is updated in place with this call's cache. That cache always
    holds "jitter_x0" and "jitter_ct", plus method-specific tensors. Its
    construction is timed into t_update. Keys left over from earlier calls are
    not cleared.

    Returns (rows, meta) with meta = {"bw0": bw0, **path_cache}; bw0 is replaced
    by 1.0 when it is exactly 0.
    NOTE(review): meta contains tensors from path_cache; confirm the caller filters
    them before any JSON serialization.
    """
    d = X0.shape[0]
    time_state = TimeState()
    rows: List[Dict] = []

    eval_fullrank = (rank is None) or (rank <= 0) or (rank >= d)

    # Reference value for rel_bw2; computed outside any timed region.
    if eval_fullrank:
        bw0 = bw2_full(X0, Ct, Ct_sqrt, trCt)
        Y0 = None
    else:
        _, Y0, _ = project_psd_rank(X0, rank=rank, clamp=0.0)
        bw0 = bw2_from_factor(Y0, Ct, trCt)
    bw0f = float(bw0.item()) if float(bw0.item()) != 0.0 else 1.0

    # Precompute per-method caches (counted as update time)
    t0 = _time_block(device)
    cache = {"jitter_x0": 0.0, "jitter_ct": 0.0}

    if name == "bw_geodesic":
        X0_spd, delta = ensure_spd(X0, jitter=1e-8)
        cache["jitter_x0"] = delta
        # T is built from the jittered X0, but the path below is applied to the unjittered X0.
        T = transport_map(Ct_sqrt, X0_spd)
        cache["T"] = T
        cache["I"] = torch.eye(d, device=device, dtype=X0.dtype)

    elif name == "logeuclid":
        # Both endpoints must be SPD for the matrix logarithm.
        X0_spd, delta0 = ensure_spd(X0, jitter=1e-8)
        Ct_spd, delta1 = ensure_spd(Ct, jitter=1e-8)
        cache["jitter_x0"] = delta0
        cache["jitter_ct"] = delta1
        cache["L0"] = logm_spd(X0_spd, clamp=1e-12)
        cache["L1"] = logm_spd(Ct_spd, clamp=1e-12)

    elif name == "airm":
        X0_spd, delta0 = ensure_spd(X0, jitter=1e-8)
        Ct_spd, delta1 = ensure_spd(Ct, jitter=1e-8)
        cache["jitter_x0"] = delta0
        cache["jitter_ct"] = delta1
        X0_sqrt = sqrtm_psd(X0_spd, clamp=0.0)
        X0_invsqrt = invsqrtm_psd(X0_spd, clamp=1e-12, eps=0.0)
        cache["X0_sqrt"] = X0_sqrt
        # Whitened target X0^{-1/2} Ct X0^{-1/2}; its t-th power drives the geodesic.
        cache["M"] = sym(X0_invsqrt @ Ct_spd @ X0_invsqrt)

    elif name == "coral":
        X0_spd, delta0 = ensure_spd(X0, jitter=1e-8)
        cache["jitter_x0"] = delta0
        X0_invsqrt = invsqrtm_psd(X0_spd, clamp=1e-12, eps=0.0)
        # Whitening-then-recoloring map; at t = 1, B_t = A.
        A = Ct_sqrt @ X0_invsqrt
        cache["A"] = A
        cache["I"] = torch.eye(d, device=device, dtype=X0.dtype)

    elif name == "euclidean":
        pass
    else:
        raise ValueError(f"Unknown path method: {name}")

    time_state.update += _time_block_end(device, t0)
    path_cache.update(cache)

    def _project_if_needed(X: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Full-rank runs skip projection entirely (no PSD clamp, no t_proj time).
        if eval_fullrank:
            return X, None
        tproj0 = _time_block(device)
        Xp, Yp, _ = project_psd_rank(X, rank=rank, clamp=0.0)
        time_state.proj += _time_block_end(device, tproj0)
        return Xp, Yp

    def _eval(X: torch.Tensor, Y: Optional[torch.Tensor]) -> float:
        te0 = _time_block(device)
        bw = bw2_full(X, Ct, Ct_sqrt, trCt) if eval_fullrank else bw2_from_factor(Y, Ct, trCt)
        time_state.eval += _time_block_end(device, te0)
        return float(bw.item())

    # Step 0: log the starting point (after projection, if rank-constrained).
    Xp0, Yp0 = _project_if_needed(X0)
    bw = _eval(Xp0, Yp0)
    rows.append({
        "step": 0, "t": 0.0, "bw2": bw, "rel_bw2": bw / bw0f,
        "t_update": time_state.update, "t_proj": time_state.proj, "t_eval": time_state.eval, "t_total": time_state.total,
        "jitter_x0": path_cache["jitter_x0"], "jitter_ct": path_cache["jitter_ct"],
    })
    if early_stop_tau is not None and (bw / bw0f) <= early_stop_tau:
        return rows, {"bw0": bw0f, **path_cache}

    for k in range(1, K + 1):
        t = k / float(K)
        t_iter0 = _time_block(device)

        # Each iterate is computed directly from the cached endpoints at time t
        # (no dependence on the previous iterate).
        if name == "bw_geodesic":
            T = path_cache["T"]
            I = path_cache["I"]
            A_t = (1.0 - t) * I + t * T
            X = sym(A_t @ X0 @ A_t.T)
        elif name == "euclidean":
            X = sym((1.0 - t) * X0 + t * Ct)
        elif name == "logeuclid":
            L = (1.0 - t) * path_cache["L0"] + t * path_cache["L1"]
            X = sym(expm_sym(L))
        elif name == "airm":
            X0_sqrt = path_cache["X0_sqrt"]
            M = path_cache["M"]
            Mt = powm_spd(M, p=t, clamp=1e-12)
            X = sym(X0_sqrt @ Mt @ X0_sqrt)
        elif name == "coral":
            I = path_cache["I"]
            A = path_cache["A"]
            B_t = (1.0 - t) * I + t * A
            X = sym(B_t @ X0 @ B_t.T)
        else:
            raise RuntimeError("unreachable")

        time_state.update += _time_block_end(device, t_iter0)

        Xp, Yp = _project_if_needed(X)
        bw = _eval(Xp, Yp)

        rows.append({
            "step": k, "t": t, "bw2": bw, "rel_bw2": bw / bw0f,
            "t_update": time_state.update, "t_proj": time_state.proj, "t_eval": time_state.eval, "t_total": time_state.total,
            "jitter_x0": path_cache["jitter_x0"], "jitter_ct": path_cache["jitter_ct"],
        })
        if early_stop_tau is not None and (bw / bw0f) <= early_stop_tau:
            break

    return rows, {"bw0": bw0f, **path_cache}





def run_bw_gd(
    X0: torch.Tensor,
    Ct: torch.Tensor,
    Ct_sqrt: torch.Tensor,
    trCt: torch.Tensor,
    K: int,
    eta0: float,
    max_backtracks: int,
    bt_shrink: float,
    device: torch.device,
    rank: Optional[int],
    early_stop_tau: Optional[float],
) -> Tuple[List[Dict], Dict]:
    """
    Literature BW/Bures GD baseline (reviewer-safe):

    Uses the Bures–Wasserstein Riemannian GD / BW-SGD geodesic step:
        T = T_{X -> Ct} (optimal Gaussian transport map, symmetric PSD)
        G = (1-eta) I + eta T
        X_next = G X G

    For rank-constrained runs (rank=r): apply Option-B projection (top-r eigens) after the step.
    Backtracking enforces descent AFTER projection: a trial is accepted when its
    BW is <= the previous BW + 1e-12, otherwise eta is multiplied by ``bt_shrink``
    (at most ``max_backtracks`` retries). If no trial is accepted, the iterate is
    kept and the shrunk eta becomes the starting eta for later iterations.

    X0 is always passed through ``project_psd_rank`` first (clamping negative
    eigenvalues even in full-rank mode); that initial projection is timed into t_proj.

    Timing:
      - map + matrix multiplies + internal BW checks => t_update
      - rank projection => t_proj
      - per-iterate shared logging eval => t_eval

    Returns (rows, {"bw0": bw0}); bw0 is replaced by 1.0 when it is exactly 0.
    """
    d = X0.shape[0]
    eval_fullrank = (rank is None) or (rank <= 0) or (rank >= d)
    I = torch.eye(d, device=device, dtype=X0.dtype)

    # initial projection if needed (Option B)
    tproj0 = _time_block(device)
    X, Y, _ = project_psd_rank(X0, rank=rank if not eval_fullrank else None, clamp=0.0)
    proj_dt = _time_block_end(device, tproj0)

    time_state = TimeState(update=0.0, proj=proj_dt, eval=0.0, overhead=0.0)
    rows: List[Dict] = []

    # Reference value for rel_bw2 (untimed); the step-0 logged eval below recomputes it.
    bw0 = bw2_full(X, Ct, Ct_sqrt, trCt) if eval_fullrank else bw2_from_factor(Y, Ct, trCt)
    bw0f = float(bw0.item()) if float(bw0.item()) != 0.0 else 1.0

    def _bw_value(Xcur: torch.Tensor, Ycur: Optional[torch.Tensor]) -> torch.Tensor:
        return bw2_full(Xcur, Ct, Ct_sqrt, trCt) if eval_fullrank else bw2_from_factor(Ycur, Ct, trCt)

    # shared external evaluation bucket (once per iterate)
    def _eval_external(Xcur: torch.Tensor, Ycur: Optional[torch.Tensor]) -> float:
        te0 = _time_block(device)
        bw = _bw_value(Xcur, Ycur if not eval_fullrank else None)
        time_state.eval += _time_block_end(device, te0)
        return float(bw.item())

    # internal method checks bucket (line search / backtracking) counts to UPDATE
    def _eval_internal(Xcur: torch.Tensor, Ycur: Optional[torch.Tensor]) -> float:
        tu0 = _time_block(device)
        bw = _bw_value(Xcur, Ycur if not eval_fullrank else None)
        time_state.update += _time_block_end(device, tu0)
        return float(bw.item())

    bw = _eval_external(X, Y if not eval_fullrank else None)
    rows.append({
        "step": 0, "t": 0.0, "bw2": bw, "rel_bw2": bw / bw0f,
        "t_update": time_state.update, "t_proj": time_state.proj, "t_eval": time_state.eval, "t_total": time_state.total,
        "bw_gd_backtracks": 0, "bw_gd_eta": eta0,
    })
    if early_stop_tau is not None and (bw / bw0f) <= early_stop_tau:
        return rows, {"bw0": bw0f}

    for k in range(1, K + 1):
        bw_prev = bw
        eta = eta0
        accepted = False
        bt_used = 0

        # compute transport map once at current X (update time)
        # NOTE: T comes from the jittered X_spd, while the step below is applied to the unjittered X.
        tu_map = _time_block(device)
        X_spd, _delta = ensure_spd(X, jitter=1e-8)
        T = transport_map(Ct_sqrt, X_spd)  # symmetric map, intended T X T ≈ Ct
        time_state.update += _time_block_end(device, tu_map)

        for bt in range(max_backtracks + 1):
            bt_used = bt

            # geodesic step (update time)
            tu_step = _time_block(device)
            G = (1.0 - eta) * I + eta * T
            X_trial = sym(G @ X @ G)
            time_state.update += _time_block_end(device, tu_step)

            # rank projection (projection time)
            tp = _time_block(device)
            Xp, Yp, _ = project_psd_rank(X_trial, rank=rank if not eval_fullrank else None, clamp=0.0)
            time_state.proj += _time_block_end(device, tp)

            # accept/reject based on post-projection BW (internal check => update bucket)
            bw_trial = _eval_internal(Xp, Yp if not eval_fullrank else None)

            if bw_trial <= bw_prev + 1e-12:
                X, Y, bw = Xp, Yp, bw_trial
                accepted = True
                break

            eta *= bt_shrink

        if not accepted:
            # keep previous iterate. eta was shrunk once more after the last rejected
            # trial, so it is now bt_shrink * (last tried eta); that smaller value becomes
            # the starting eta for all later iterations (more conservative next step).
            eta0 = eta
            # (bw is overwritten by the external eval just below either way)
            bw = bw_prev

        # shared external eval for logging (counts to t_eval exactly once per iterate)
        bw_log = _eval_external(X, Y if not eval_fullrank else None)
        bw = bw_log

        # bw_gd_eta: the accepted step size, or (if nothing was accepted) the shrunk carry-over value.
        rows.append({
            "step": k, "t": k / float(K), "bw2": bw, "rel_bw2": bw / bw0f,
            "t_update": time_state.update, "t_proj": time_state.proj, "t_eval": time_state.eval, "t_total": time_state.total,
            "bw_gd_backtracks": bt_used, "bw_gd_eta": eta,
        })

        if early_stop_tau is not None and (bw / bw0f) <= early_stop_tau:
            break

    return rows, {"bw0": bw0f}

def run_sinkhorn_method(
    name: str,
    X0: torch.Tensor,
    Ct: torch.Tensor,
    Ct_sqrt: torch.Tensor,
    trCt: torch.Tensor,
    Xs: Optional[np.ndarray],
    Xt: Optional[np.ndarray],
    K: int,
    eps: float,
    n_samples: int,
    iters_per_step: int,
    device: torch.device,
    rank: Optional[int],
    seed: int,
    early_stop_tau: Optional[float],
    Cs: torch.Tensor,
) -> Tuple[List[Dict], Dict]:
    """
    Sinkhorn baselines on samples.

    Each logged step runs ``iters_per_step`` more warm-started Sinkhorn iterations.
    It then uses the covariance of the barycentric-projected source samples as
    the iterate (timed into t_update).

    name:
      "sinkhorn"      -- subsample up to ``n_samples`` rows of the dataset samples
                         ``Xs`` / ``Xt`` without replacement, using ``seed``.
      "sinkhorn_gaus" -- draw ``n_samples`` samples from N(0, Cs) and N(0, Ct) with
                         torch's global RNG (timed into t_update).
    Samples are cast to float32 for the Sinkhorn solve. Any other name raises ValueError.

    Step 0 logs the BW of X0 (rank-projected when ``rank`` is a positive int < d).
    Later iterates are projected likewise (timed into t_proj) before evaluation.

    Returns (rows, {"bw0": bw0}); bw0 is replaced by 1.0 when it is exactly 0.
    """
    d = Ct.shape[0]
    eval_fullrank = (rank is None) or (rank <= 0) or (rank >= d)

    time_state = TimeState()
    rows: List[Dict] = []

    rng = np.random.RandomState(seed)

    if name == "sinkhorn":
        if Xs is None or Xt is None:
            raise ValueError("sinkhorn requires Xs.npy and Xt.npy in the dataset pack.")
        ns = min(n_samples, Xs.shape[0])
        nt = min(n_samples, Xt.shape[0])
        idx_s = rng.choice(Xs.shape[0], size=ns, replace=False)
        idx_t = rng.choice(Xt.shape[0], size=nt, replace=False)
        Xs_t = torch.as_tensor(Xs[idx_s], device=device, dtype=torch.float32)
        Xt_t = torch.as_tensor(Xt[idx_t], device=device, dtype=torch.float32)

    elif name == "sinkhorn_gaus":
        ns = n_samples
        nt = n_samples
        t_upd = _time_block(device)
        Cs_cpu = Cs.detach().cpu().to(torch.float64)
        Ct_cpu = Ct.detach().cpu().to(torch.float64)
        Cs_sqrt = sqrtm_psd(Cs_cpu, clamp=0.0)
        Ct_sqrt_cpu = sqrtm_psd(Ct_cpu, clamp=0.0)
        Zs = torch.randn(ns, d, dtype=torch.float64)
        Zt = torch.randn(nt, d, dtype=torch.float64)
        Xs_t = (Zs @ Cs_sqrt.T).to(device=device, dtype=torch.float32)
        Xt_t = (Zt @ Ct_sqrt_cpu.T).to(device=device, dtype=torch.float32)
        time_state.update += _time_block_end(device, t_upd)

    else:
        raise ValueError(name)

    Y0 = None
    if not eval_fullrank:
        _, Y0, _ = project_psd_rank(X0, rank=rank, clamp=0.0)

    # Single timed evaluation of the starting point; it also serves as bw0
    # (previously the identical value was computed twice).
    te0 = _time_block(device)
    bw_init = float((bw2_full(X0, Ct, Ct_sqrt, trCt) if eval_fullrank else bw2_from_factor(Y0, Ct, trCt)).item())
    time_state.eval += _time_block_end(device, te0)
    bw0f = bw_init if bw_init != 0.0 else 1.0

    rows.append({
        "step": 0, "t": 0.0, "bw2": bw_init, "rel_bw2": bw_init / bw0f,
        "t_update": time_state.update, "t_proj": time_state.proj, "t_eval": time_state.eval, "t_total": time_state.total,
        "sinkhorn_eps": eps, "sinkhorn_n": n_samples, "sinkhorn_iters_per_step": iters_per_step,
    })
    if early_stop_tau is not None and (bw_init / bw0f) <= early_stop_tau:
        return rows, {"bw0": bw0f}

    log_u = None
    log_v = None

    for k in range(1, K + 1):
        t_upd0 = _time_block(device)
        cov, log_u, log_v = sinkhorn_barycentric_cov(
            Xs=Xs_t, Xt=Xt_t, eps=eps, iters=iters_per_step, log_u=log_u, log_v=log_v
        )
        # sinkhorn_barycentric_cov returns float64. Match the target's dtype/device,
        # otherwise projection/evaluation against Ct fails on --dtype float32 runs
        # (no-op for float64).
        cov = cov.to(device=Ct.device, dtype=Ct.dtype)
        time_state.update += _time_block_end(device, t_upd0)

        if eval_fullrank:
            Xk, Yk = cov, None
        else:
            tproj = _time_block(device)
            Xk, Yk, _ = project_psd_rank(cov, rank=rank, clamp=0.0)
            time_state.proj += _time_block_end(device, tproj)

        te = _time_block(device)
        bw = bw2_full(Xk, Ct, Ct_sqrt, trCt) if eval_fullrank else bw2_from_factor(Yk, Ct, trCt)
        time_state.eval += _time_block_end(device, te)
        bwf = float(bw.item())

        rows.append({
            "step": k, "t": k / float(K), "bw2": bwf, "rel_bw2": bwf / bw0f,
            "t_update": time_state.update, "t_proj": time_state.proj, "t_eval": time_state.eval, "t_total": time_state.total,
            "sinkhorn_eps": eps, "sinkhorn_n": n_samples, "sinkhorn_iters_per_step": iters_per_step,
        })
        if early_stop_tau is not None and (bwf / bw0f) <= early_stop_tau:
            break

    return rows, {"bw0": bw0f}





# -----------------------------

# Summaries

# -----------------------------

def time_to_threshold(rows: List[Dict], tau: float, time_key: str = "t_total") -> Optional[float]:
    """
    Return ``row[time_key]`` of the first row whose ``rel_bw2`` is <= ``tau``.

    Rows lacking ``rel_bw2`` are treated as rel_bw2 = 1.0. Returns None when no row
    reaches ``tau``. Also returns None when the first qualifying row has no (or a
    None) ``time_key`` entry, instead of raising ``TypeError`` from ``float(None)``.
    """
    for r in rows:
        if r.get("rel_bw2", 1.0) <= tau:
            t = r.get(time_key)
            return None if t is None else float(t)
    return None





def min_bw(rows: List[Dict]) -> float:
    """
    Smallest logged ``bw2`` value, ignoring NaN entries.

    Python's ``min`` is order-dependent when NaNs are present: ``min([nan, 1.0])``
    is NaN but ``min([1.0, nan])`` is 1.0. NaNs (e.g. from a diverging step) are
    therefore skipped. Returns NaN when ``rows`` is empty or every value is NaN.
    """
    values = [float(r["bw2"]) for r in rows]
    non_nan = [v for v in values if not math.isnan(v)]
    return min(non_nan) if non_nan else float("nan")





# -----------------------------

# Main

# -----------------------------

def _icml_build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the ICML runner."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", type=str, required=True, help="Pack dir with Cs.npy/Ct.npy and optional Xs.npy/Xt.npy")
    ap.add_argument("--seeds", type=str, default="0,1,2")
    ap.add_argument("--out", type=str, default="results")
    ap.add_argument("--K", type=int, default=50)
    ap.add_argument("--rank", type=int, default=0, help="0 or >=d => full rank; else project to rank-r")
    ap.add_argument("--methods", type=str, default="itspace,bw_geodesic,euclidean,logeuclid,airm,coral,sinkhorn,sinkhorn_gaus,bw_gd")
    ap.add_argument("--device", type=str, default="auto", help="auto|cpu|cuda")
    ap.add_argument("--dtype", type=str, default="float64", choices=["float64", "float32"])

    ap.add_argument("--early-stop-tau", type=float, default=1e-6, help="Stop once relBW <= this. Set 0 to disable.")
    ap.add_argument("--tau-report", type=str, default="1e-6,1e-9", help="Thresholds to report in summary.")

    ap.add_argument("--itspace-lambda", type=float, default=6.0)
    ap.add_argument("--itspace-polar", type=str, default="gram", choices=["gram", "svd", "ns"])
    ap.add_argument("--itspace-polar-ns-iters", type=int, default=6)

    ap.add_argument("--bw-gd-eta", type=float, default=0.3)
    ap.add_argument("--bw-gd-max-backtracks", type=int, default=10)
    ap.add_argument("--bw-gd-bt-shrink", type=float, default=0.5)

    ap.add_argument("--sinkhorn-eps", type=float, default=0.05)
    ap.add_argument("--sinkhorn-n", type=int, default=1024)
    ap.add_argument("--sinkhorn-iters-per-step", type=int, default=20)
    ap.add_argument("--run-legacy-ntu-posthoc", action="store_true", help="Run legacy NTU posthoc diagnostics (W2, not squared). Outputs are quarantined to legacy_ntu_posthoc/.")
    return ap


def _icml_dispatch_method(method: str, *, args, seed: int, X0, Y0, Cs, Ct, Ct_sqrt, trCt,
                          Xs_np, Xt_np, device, rank, early_stop_tau, eval_fullrank):
    """Run one method by name with the CLI hyper-parameters and return its ``(rows, meta)``.

    Raises:
        ValueError: if ``method`` is not a recognised method name.
    """
    if method == "itspace":
        return run_itspace(
            X0=X0, Y0=Y0, Ct=Ct, Ct_sqrt=Ct_sqrt, trCt=trCt, K=args.K,
            lam=args.itspace_lambda, polar_mode=args.itspace_polar, polar_ns_iters=args.itspace_polar_ns_iters,
            device=device, rank=rank, early_stop_tau=early_stop_tau, eval_fullrank=eval_fullrank,
        )
    if method in ("bw_geodesic", "euclidean", "logeuclid", "airm", "coral"):
        # A fresh dict per call, so no path cache is shared between (seed, method) runs.
        return run_path_method(
            name=method, X0=X0, Ct=Ct, Ct_sqrt=Ct_sqrt, trCt=trCt, K=args.K,
            device=device, rank=rank, early_stop_tau=early_stop_tau, path_cache={},
        )
    if method == "bw_gd":
        return run_bw_gd(
            X0=X0, Ct=Ct, Ct_sqrt=Ct_sqrt, trCt=trCt, K=args.K,
            eta0=args.bw_gd_eta, max_backtracks=args.bw_gd_max_backtracks, bt_shrink=args.bw_gd_bt_shrink,
            device=device, rank=rank, early_stop_tau=early_stop_tau,
        )
    if method in ("sinkhorn", "sinkhorn_gaus"):
        return run_sinkhorn_method(
            name=method, X0=X0, Ct=Ct, Ct_sqrt=Ct_sqrt, trCt=trCt, Xs=Xs_np, Xt=Xt_np,
            K=args.K, eps=args.sinkhorn_eps, n_samples=args.sinkhorn_n, iters_per_step=args.sinkhorn_iters_per_step,
            device=device, rank=rank, seed=seed, early_stop_tau=early_stop_tau, Cs=Cs,
        )
    raise ValueError(f"Unknown method: {method}")


def _icml_global_floor(all_rows: List[Dict]) -> Tuple[float, float]:
    """Return ``(global_min, floor)`` over rows whose ``status`` is ``"ok"``.

    ``global_min`` is the smallest non-NaN ``bw2`` (``inf`` if there is none) and
    ``floor`` is ``2 * global_min``, or NaN when ``global_min`` is ``inf``.
    """
    ok_vals = (float(r.get("bw2", float("inf"))) for r in all_rows if r["status"] == "ok")
    global_min = min((v for v in ok_vals if v == v), default=float("inf"))
    floor = 2.0 * global_min if global_min < float("inf") else float("nan")
    return global_min, floor


def _icml_write_csv(path: str, rows: List[Dict]) -> None:
    """Write ``rows`` to ``path`` with a header of the sorted union of all row keys."""
    fields = sorted({k for r in rows for k in r.keys()})
    with open(path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fields)
        w.writeheader()
        w.writerows(rows)


def _icml_run_finalizers(args, out_dir: str) -> None:
    """Invoke the finalizer scripts that live next to this file on ``out_dir``.

    Their exit codes are not checked (``check=False``). When ``--run-legacy-ntu-posthoc``
    is set, the legacy posthoc script also runs and its known outputs are moved into
    ``out_dir/legacy_ntu_posthoc/``.
    """
    import shutil
    import subprocess
    import sys

    # abspath so the sibling scripts resolve even when __file__ is a bare relative name.
    here = os.path.dirname(os.path.abspath(__file__))

    # Threshold table (Protocol v2)
    thr = os.path.join(here, 'icml_finalize_threshold_table.py')
    cmd2 = [sys.executable, '-u', thr, '--run-dir', out_dir]
    print('[ICML runner] ICML threshold finalize:', ' '.join(cmd2))
    subprocess.run(cmd2, check=False)

    # v2 plots + v2 table.
    # NOTE(review): these taus are fixed rather than taken from --tau-report, while
    # summary.csv only has time_to_tau_* columns for the --tau-report values — confirm
    # the finalizer tolerates a non-default --tau-report.
    v2 = os.path.join(here, 'icml_finalize_v2.py')
    cmdv2 = [sys.executable, '-u', v2,
             '--run-dir', out_dir, '--seed-plot', '0',
             '--taus', '1e-6,1e-9', '--focus-tau', '1e-6',
             '--xmult-update', '5.0', '--xmult-total', '8.0',
             '--max-decades', '10']
    print('[ICML runner] ICML v2 finalize:', ' '.join(cmdv2))
    subprocess.run(cmdv2, check=False)

    # Optional legacy diagnostics (kept but quarantined)
    if not getattr(args, "run_legacy_ntu_posthoc", False):
        return
    posthoc = os.path.join(here, 'run_ntu_icml_posthoc.py')
    cmd1 = [sys.executable, '-u', posthoc, '--dataset', args.dataset, '--out', args.out,
            '--run-dir', out_dir, '--seed', '0']
    print('[ICML runner] LEGACY NTU posthoc (W2):', ' '.join(cmd1))
    subprocess.run(cmd1, check=False)

    legacy_dir = os.path.join(out_dir, "legacy_ntu_posthoc")
    os.makedirs(legacy_dir, exist_ok=True)
    for name in (
        "flow_contraction.png",
        "comparison_table.png", "comparison_table.csv",
        "path_energy_bars.png",
        "intermediate_mmd.png",
    ):
        src = os.path.join(out_dir, name)
        if os.path.exists(src):
            try:
                shutil.move(src, os.path.join(legacy_dir, name))
            except OSError as e:
                # Best effort: report but keep going with the remaining artifacts.
                print(f"[WARN] could not move legacy artifact {src}: {e}")


def main():
    """Run every (seed, method) pair on one dataset pack and write the run artifacts.

    Loads ``Cs.npy`` / ``Ct.npy`` (plus ``Xs.npy`` / ``Xt.npy`` when both exist) from
    ``--dataset``. Runs each requested method for each seed, then writes ``config.json``,
    ``curves.csv``, ``summary.csv`` and ``run_note.json`` into a new timestamped directory
    under ``--out``. Finally it invokes the sibling finalizer scripts; a finalizer failure
    is reported but does not abort the run.

    A method that raises is recorded with ``status="fail"`` and a single NaN row instead
    of stopping the whole run.
    """
    args = _icml_build_arg_parser().parse_args()

    device = pick_device(args.device)
    dtype = torch.float64 if args.dtype == "float64" else torch.float32

    seeds = parse_csv_list(args.seeds, int)
    methods = [m.strip() for m in args.methods.split(",") if m.strip()]
    tau_report = parse_csv_list(args.tau_report, float)
    # A missing or non-positive --early-stop-tau disables early stopping.
    early_stop_tau = args.early_stop_tau if args.early_stop_tau and args.early_stop_tau > 0 else None

    pack_dir = args.dataset
    pack_abs = os.path.abspath(pack_dir)
    # abspath normalises trailing separators, so this is the pack directory's own name.
    ds_id = os.path.basename(pack_abs)

    Cs_np = np.load(os.path.join(pack_dir, "Cs.npy"))
    Ct_np = np.load(os.path.join(pack_dir, "Ct.npy"))
    if Cs_np.ndim != 2 or Cs_np.shape[0] != Cs_np.shape[1] or Cs_np.shape != Ct_np.shape:
        raise ValueError(
            f"Cs.npy and Ct.npy must be square matrices of the same shape; "
            f"got Cs{tuple(Cs_np.shape)} and Ct{tuple(Ct_np.shape)}"
        )

    # Raw samples are optional and only loaded when both files exist.
    # They are memory-mapped read-only rather than read eagerly.
    Xs_np = None
    Xt_np = None
    xs_path = os.path.join(pack_dir, "Xs.npy")
    xt_path = os.path.join(pack_dir, "Xt.npy")
    if os.path.exists(xs_path) and os.path.exists(xt_path):
        Xs_np = np.load(xs_path, mmap_mode="r")
        Xt_np = np.load(xt_path, mmap_mode="r")

    Cs = sym(torch.as_tensor(Cs_np, device=device, dtype=dtype))
    Ct = sym(torch.as_tensor(Ct_np, device=device, dtype=dtype))
    d = Ct.shape[0]

    # --rank 0 or >= d means full rank (rank=None); otherwise project to rank-r.
    rank = args.rank if args.rank and args.rank > 0 and args.rank < d else None
    eval_fullrank = rank is None

    # Precompute target sqrt + trace once; the cost is recorded in config.json.
    cuda_sync(device)
    t_pre0 = time.perf_counter()
    Ct_sqrt = sqrtm_psd(Ct, clamp=0.0)
    trCt = trace(Ct)
    cuda_sync(device)
    pre_dt = time.perf_counter() - t_pre0

    tag = f"{ds_id}_results_icml_d{d}" + (f"_r{rank}" if rank else "_full") + f"_K{args.K}_{_ts()}"
    out_dir = os.path.join(args.out, tag)
    _ensure_dir(out_dir)

    cfg = vars(args).copy()
    cfg.update({
        "resolved_device": str(device),
        "resolved_dtype": str(dtype),
        "d": int(d),
        "rank_resolved": (rank if rank else 0),
        "precompute_time_s": pre_dt,
        "methods_resolved": methods,
        "tau_report_resolved": tau_report,
        "early_stop_tau_resolved": early_stop_tau,
        "pack_dir": pack_abs,
    })
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(cfg, f, indent=2)

    # Common initial X0 (+ factor Y0 for ITSPACE / low-rank eval)
    if rank is None:
        X0 = sym(Cs)
        Y0 = sqrtm_psd(X0, clamp=0.0)
    else:
        X0, Y0, _ = project_psd_rank(Cs, rank=rank, clamp=0.0)

    all_rows: List[Dict] = []
    summary_rows: List[Dict] = []
    rank_out = int(rank) if rank else 0

    for seed in seeds:
        set_all_seeds(seed)
        for method in methods:
            status = "ok"
            err = ""
            try:
                rows, meta = _icml_dispatch_method(
                    method, args=args, seed=seed, X0=X0, Y0=Y0, Cs=Cs, Ct=Ct, Ct_sqrt=Ct_sqrt,
                    trCt=trCt, Xs_np=Xs_np, Xt_np=Xt_np, device=device, rank=rank,
                    early_stop_tau=early_stop_tau, eval_fullrank=eval_fullrank,
                )
                if not rows:
                    # rows[-1] below would otherwise crash the whole run outside this try.
                    raise RuntimeError(f"{method} returned no rows")
            except Exception as e:
                status = "fail"
                err = repr(e)
                print(f"[WARN] method={method} seed={seed} failed: {err}")
                rows = [{
                    "step": 0, "t": 0.0, "bw2": float("nan"), "rel_bw2": float("nan"),
                    "t_update": 0.0, "t_proj": 0.0, "t_eval": 0.0, "t_total": 0.0,
                }]
                meta = {"bw0": float("nan")}

            run_ident = {
                "dataset": ds_id,
                "pack_dir": pack_abs,
                "d": int(d),
                "rank": rank_out,
                "K": int(args.K),
                "method": method,
                "seed": int(seed),
                "status": status,
                "error": err,
            }
            for r in rows:
                r.update(run_ident)
            all_rows.extend(rows)

            method_min = min_bw(rows) if status == "ok" else float("nan")
            last = rows[-1]
            sum_row = {
                "dataset": ds_id,
                "pack_dir": pack_abs,
                "d": int(d),
                "rank": rank_out,
                "K_budget": int(args.K),
                "method": method,
                "seed": int(seed),
                "status": status,
                "error": err,
                "bw0": float(meta.get("bw0", float("nan"))),
                "bw_min": float(method_min),
                "bw_last": float(last.get("bw2", float("nan"))),
                "rel_bw_last": float(last.get("rel_bw2", float("nan"))),
                "t_update_last": float(last.get("t_update", 0.0)),
                "t_proj_last": float(last.get("t_proj", 0.0)),
                "t_eval_last": float(last.get("t_eval", 0.0)),
                "t_total_last": float(last.get("t_total", 0.0)),
            }
            for tau in tau_report:
                sum_row[f"time_to_tau_{tau:g}"] = time_to_threshold(rows, tau=tau, time_key="t_total")
                sum_row[f"update_time_to_tau_{tau:g}"] = time_to_threshold(rows, tau=tau, time_key="t_update")

            # --- legacy fields for NTU posthoc/table compatibility ---
            sum_row["iters"] = int(last.get("step", 0))
            sum_row["W2^2_first"] = float(rows[0].get("bw2", float("nan")))
            sum_row["W2^2_last"] = float(last.get("bw2", float("nan")))
            sum_row["wall_s_last"] = float(last.get("t_total", 0.0))
            sum_row["kernel_s_total"] = float(last.get("t_update", 0.0))
            sum_row["eval_s_total"] = float(last.get("t_eval", 0.0))
            sum_row["proj_s_total"] = float(last.get("t_proj", 0.0))
            summary_rows.append(sum_row)

    # floor = 2x global min BW across all methods and seeds (successful runs only)
    global_min, floor = _icml_global_floor(all_rows)

    by_ms: Dict[Tuple[str, int], List[Dict]] = {}
    for r in all_rows:
        by_ms.setdefault((r["method"], int(r["seed"])), []).append(r)
    for curve in by_ms.values():
        curve.sort(key=lambda z: int(z["step"]))

    for sr in summary_rows:
        sr["floor_bw"] = floor
        bw0 = sr["bw0"]
        # NaN-safe: `not (x > 0)` is True for NaN, and `floor == floor` is False for NaN.
        if sr["status"] != "ok" or not (bw0 > 0.0) or not (floor == floor):
            sr["time_to_floor"] = None
            continue
        # time_to_threshold works on relative BW, so express the absolute floor relative to bw0.
        sr["time_to_floor"] = time_to_threshold(
            by_ms[(sr["method"], int(sr["seed"]))], tau=floor / bw0, time_key="t_total"
        )

    # Write curves.csv and summary.csv
    curves_path = os.path.join(out_dir, "curves.csv")
    _icml_write_csv(curves_path, all_rows)
    summary_path = os.path.join(out_dir, "summary.csv")
    _icml_write_csv(summary_path, summary_rows)

    note = {
        "out_dir": os.path.abspath(out_dir),
        "curves_csv": os.path.abspath(curves_path),
        "summary_csv": os.path.abspath(summary_path),
        "global_min_bw": global_min,
        "floor_bw": floor,
        "precompute_time_s": pre_dt,
    }
    with open(os.path.join(out_dir, "run_note.json"), "w") as f:
        json.dump(note, f, indent=2)

    # --- In-run finalization (Protocol v2) ---
    # Canonical artifacts: icml_threshold_table.*, flow_contraction_update/total.*, icml_v2_table.*
    # Legacy NTU posthoc (W2, not squared) is optional and quarantined into legacy_ntu_posthoc/.
    try:
        _icml_run_finalizers(args, out_dir)
    except Exception as e:
        print(f'[WARN] finalization failed: {e}')

    print(f"[ICML runner] done. out_dir={out_dir}")
    print(f"  curves:  {curves_path}")
    print(f"  summary: {summary_path}")
    # Clamp at 0 for display only (tiny negative BW2 can arise from round-off).
    print(f"  global_min_bw={max(global_min, 0.0):.6g}, floor_bw(2x)={max(floor, 0.0):.6g}")





# === ICML ORTH PATCH v4 START ===
def _pick_itspace_Q(env):
    """Locate the ITSPACE polar factor inside a ``locals()``-style mapping.

    Resolution order:
      1. The first conventional name (``Q``, ``Qk``, ..., ``U1``, ``U2``) present in
         ``env`` wins and its value is returned unchanged, whatever its type.
      2. Otherwise, among entries whose lower-cased name contains ``q`` or ``u`` and
         whose value is a 2-D torch tensor with between 1 and 256 columns, return the
         one with the smallest non-NaN ``||Q^T Q - I||_F``. Ties keep the earliest entry,
         and entries whose residual computation raises are ignored.
      3. ``None`` if torch cannot be imported or no entry qualifies.

    NOTE: this definition overrides the earlier ``_pick_itspace_Q`` near the top of the file.
    """
    conventional = (
        'Q', 'Qk', 'Q_polar', 'Qhat', 'Q_hat', 'Qns', 'Q_ns', 'Qtilde', 'Q_tilde',
        'U', 'Uq', 'U_hat', 'Up', 'U_polar', 'U_k', 'U1', 'U2',
    )
    found = next((name for name in conventional if name in env), None)
    if found is not None:
        return env[found]

    try:
        import torch as _torch
    except Exception:
        return None

    winner, winner_res = None, float('inf')
    for key, val in env.items():
        lowered = str(key).lower()
        if 'q' not in lowered and 'u' not in lowered:
            continue
        if not isinstance(val, _torch.Tensor) or val.ndim != 2:
            continue
        ncols = int(val.shape[-1])
        if not (0 < ncols <= 256):
            continue
        try:
            res = _itspace_orth_residual(val)
            # res == res filters NaN; strict '<' keeps the first of equal residuals.
            if res == res and res < winner_res:
                winner, winner_res = val, res
        except Exception:
            pass
    return winner

def _itspace_orth_residual(Q):
    """Orthogonality residual ``||Q^H Q - I||_F`` for a torch tensor or array-like.

    ``Q`` may be a single matrix ``(..., n, r)`` or a batch of them. For a batch, the
    result is the Frobenius norm taken over the whole stacked difference, i.e.
    ``sqrt(sum_b ||Q_b^H Q_b - I||_F^2)``. For a single real matrix this equals
    ``||Q^T Q - I||_F``. The conjugate transpose makes the residual measure unitarity
    for complex inputs. Integer tensors are promoted to float64 before the norm.

    Returns:
        The residual as a Python float, or NaN when ``Q`` is None.
    """
    if Q is None:
        return float('nan')
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and isinstance(Q, torch.Tensor):
        QtQ = Q.transpose(-2, -1).conj() @ Q
        I = torch.eye(QtQ.shape[-1], device=Q.device, dtype=QtQ.dtype)
        diff = QtQ - I
        if not (diff.is_floating_point() or diff.is_complex()):
            # vector_norm requires a floating/complex dtype.
            diff = diff.to(torch.float64)
        # Flattened 2-norm == Frobenius norm, and it works for any number of batch dims
        # (linalg.norm(ord='fro') rejects inputs with more than 2 dims).
        return float(torch.linalg.vector_norm(diff).item())
    import numpy as np
    A = np.asarray(Q)
    # swapaxes (not .T) so that only the matrix dims are transposed for batched input.
    QtQ = np.swapaxes(A, -1, -2).conj() @ A
    I = np.eye(QtQ.shape[-1], dtype=QtQ.dtype)
    return float(np.linalg.norm((QtQ - I).ravel()))
# === ICML ORTH PATCH v4 END ===

# Script entry point: run the experiment driver only when this file is executed
# directly (not when it is imported, e.g. by tooling that reuses its helpers).
if __name__ == "__main__":

    main()

