
#!/usr/bin/env python3

"""

CovDrift-MR (fixed, mechanism-correct) downstream runner.



Key design (integrity-safe, mechanism-correct):

- Uses ONLY Zs_train, ys_train from the leakage-safe packs (no target data required).

- Split Zs_train into train / unlab / test.

- Train a linear head on train only.

- Work in *source-whitened* space so Cs = I on train by construction.

- Apply synthetic rank-r covariance drift to (unlab,test) in whitened space with IDENTITY complement.

- Estimate Ct from drifted unlab **in the r-drift subspace** (stable, avoids harmful full-d noise).

- Run the same method list as alignment (itspace, bw_geodesic, bw_gd, euclidean, logeuclid, airm, coral, sinkhorn, sinkhorn_gaus).

- Apply correction as a subspace-only map (complement untouched).

- Add two oracles:

    Oracle-Undo (true inverse drift) and Oracle-CORAL(true Ct),

  which must match Undrifted(reference) if everything is correct.



This fixes:

- contracted-direction deletion (no rank truncation by largest eigenvalues),

- complement rescaling (complement is exactly identity),

- Camelyon AUROC invariance to the drift (via drift_subspace=headmix, which rotates the drift axes within the head-aligned subspace).

"""

from __future__ import annotations



import argparse

import json

import time

from dataclasses import asdict, dataclass

from pathlib import Path

from typing import Dict, List, Tuple



import numpy as np

import torch



from sklearn.linear_model import RidgeClassifier

from sklearn.metrics import roc_auc_score

from sklearn.preprocessing import StandardScaler





# -----------------------------

# helpers

# -----------------------------

def parse_int_list(s: str) -> List[int]:
    """Parse a comma-separated string such as ``"1, 2,5"`` into a list of ints.

    Space characters are removed before splitting on commas, and empty
    fields (e.g. from a trailing comma) are skipped.
    """
    values: List[int] = []
    for token in s.replace(" ", "").split(","):
        if not token.strip():
            continue
        values.append(int(token))
    return values



def sym(A: torch.Tensor) -> torch.Tensor:
    """Return the symmetric part of ``A``: the average of ``A`` and its transpose over the last two dims."""
    A_t = A.transpose(-1, -2)
    return 0.5 * (A + A_t)



def eigh_spd(A: torch.Tensor, clamp: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Eigendecompose the symmetrized ``A`` and floor its eigenvalues at ``clamp``.

    Returns ``(eigenvalues, eigenvectors)`` as produced by
    ``torch.linalg.eigh``, with the eigenvalues clamped from below.
    """
    evals, evecs = torch.linalg.eigh(sym(A))
    evals = torch.clamp(evals, min=clamp)
    return evals, evecs



def sqrtm_spd(A: torch.Tensor, clamp: float) -> torch.Tensor:
    """Matrix square root of ``A`` built from its clamped eigendecomposition."""
    evals, evecs = eigh_spd(A, clamp)
    root_diag = torch.diag(torch.sqrt(evals))
    return evecs @ root_diag @ evecs.T



def invsqrtm_spd(A: torch.Tensor, clamp: float) -> torch.Tensor:
    """Inverse matrix square root of ``A`` built from its clamped eigendecomposition."""
    evals, evecs = eigh_spd(A, clamp)
    inv_root_diag = torch.diag(1.0 / torch.sqrt(evals))
    return evecs @ inv_root_diag @ evecs.T



def logm_spd(A: torch.Tensor, clamp: float) -> torch.Tensor:
    """Matrix logarithm of ``A`` built from its clamped eigendecomposition."""
    evals, evecs = eigh_spd(A, clamp)
    log_diag = torch.diag(torch.log(evals))
    return evecs @ log_diag @ evecs.T



def expm_sym(A: torch.Tensor) -> torch.Tensor:
    """Matrix exponential of the symmetrized ``A`` via eigendecomposition (no eigenvalue clamping)."""
    evals, evecs = torch.linalg.eigh(sym(A))
    exp_diag = torch.diag(torch.exp(evals))
    return evecs @ exp_diag @ evecs.T



def powm_spd(A: torch.Tensor, t: float, clamp: float) -> torch.Tensor:
    """Matrix power ``A ** t`` built from the clamped eigendecomposition of ``A``."""
    evals, evecs = eigh_spd(A, clamp)
    pow_diag = torch.diag(evals ** t)
    return evecs @ pow_diag @ evecs.T



def polar_factor(A: torch.Tensor) -> torch.Tensor:
    """Return ``U @ Vh`` from the thin SVD ``A = U S Vh`` (the singular values are discarded)."""
    left, _sing, right_h = torch.linalg.svd(A, full_matrices=False)
    return left @ right_h



def cov_np(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the column mean (kept as a 1×d row) and the sample covariance of ``X``.

    The covariance divides by ``n - 1``, floored at 1 so a single-row input
    does not divide by zero.
    """
    n_rows = X.shape[0]
    mean = X.mean(axis=0, keepdims=True)
    centered = X - mean
    denom = max(1, (n_rows - 1))
    cov = (centered.T @ centered) / denom
    return mean, cov



def shrink_cov(C: np.ndarray, gamma: float) -> np.ndarray:
    """Shrink ``C`` toward a scaled identity with weight ``gamma``.

    The target is ``(trace(C) / d) * I``. A non-positive ``gamma`` returns
    ``C`` itself, unchanged.
    """
    if gamma > 0:
        dim = C.shape[0]
        trace_val = float(np.trace(C))
        target = gamma * (trace_val / dim) * np.eye(dim, dtype=C.dtype)
        return (1.0 - gamma) * C + target
    return C



def make_orthonormal(M: np.ndarray) -> np.ndarray:
    """Return the Q factor of the QR decomposition of ``M``."""
    return np.linalg.qr(M)[0]



def pad_orthonormal(U: np.ndarray, r: int, rng: np.random.Generator) -> np.ndarray:
    """Truncate or extend the column set ``U`` (d×k) to exactly ``r`` columns.

    With ``k >= r`` the first ``r`` columns are returned and ``rng`` is not
    consumed. Otherwise Gaussian columns are drawn, their component along
    ``U`` is projected out, they are orthonormalized by QR, and the result is
    appended to ``U``.
    """
    dim, have = U.shape
    missing = r - have
    if missing <= 0:
        return U[:, :r]
    fill = rng.standard_normal((dim, missing))
    # Remove the part of the random columns that lies in span(U).
    fill = fill - U @ (U.T @ fill)
    extra = np.linalg.qr(fill)[0]
    return np.concatenate([U, extra[:, :missing]], axis=1)



def head_aligned_subspace(head: RidgeClassifier, r: int, drift_seed: int) -> np.ndarray:
    """Build a d×r basis whose leading columns come from the head's weight matrix.

    The right singular vectors of ``head.coef_`` supply up to ``r`` columns;
    any remaining columns are filled by ``pad_orthonormal`` using a generator
    seeded with ``drift_seed``. The head-derived columns are ordered by
    descending singular value.
    """
    coef = head.coef_
    if coef.ndim == 1:
        coef = coef.reshape(1, -1)

    _, sing_vals, right_vecs = np.linalg.svd(coef, full_matrices=False)
    n_head = min(r, right_vecs.shape[0])

    head_basis = make_orthonormal(right_vecs[:n_head, :].T)  # d×k
    filler_rng = np.random.default_rng(int(drift_seed))
    basis = pad_orthonormal(head_basis, r, filler_rng)

    if n_head > 1:
        by_strength = np.argsort(-sing_vals[:n_head])
        basis[:, :n_head] = basis[:, :n_head][:, by_strength]
    return basis



def random_subspace(d: int, r: int, drift_seed: int) -> np.ndarray:
    """Orthonormalize a d×r Gaussian matrix drawn from a generator seeded with ``drift_seed``."""
    gen = np.random.default_rng(int(drift_seed))
    gaussian = gen.standard_normal((d, r))
    return make_orthonormal(gaussian)



def sample_log_scales(r: int, smax: float, drift_seed: int, sort_by_magnitude: bool = True) -> np.ndarray:
    """Draw ``r`` scale factors whose logs are uniform on ``[-log(smax), log(smax)]``.

    The generator is seeded with ``drift_seed + 1337``. When
    ``sort_by_magnitude`` is true the scales are ordered by decreasing
    ``|log scale|``, i.e. strongest stretch or contraction first.
    """
    gen = np.random.default_rng(int(drift_seed) + 1337)
    bound = float(np.log(smax))
    log_scales = gen.uniform(-bound, bound, size=r)
    if sort_by_magnitude:
        strongest_first = np.argsort(-np.abs(log_scales))
        log_scales = log_scales[strongest_first]
    scales = np.exp(log_scales)
    return scales.astype(np.float64)



def apply_subspace_scaling(X: np.ndarray, U: np.ndarray, s: np.ndarray) -> np.ndarray:
    """Scale the coordinates of ``X`` along the columns of ``U`` by ``s``.

    Only the change in the ``U`` coordinates is added back to ``X``, so the
    part of ``X`` outside ``span(U)`` is left as it was.
    """
    coords = X @ U
    scaled = coords * s.reshape(1, -1)
    delta = scaled - coords
    return X + delta @ U.T



def metric_name_and_value(head: RidgeClassifier, X: np.ndarray, y: np.ndarray) -> Tuple[str, float]:
    """Score ``head`` on ``(X, y)`` and return ``(metric_name, value_in_percent)``.

    With exactly two classes the metric is ROC AUC computed from
    ``head.decision_function``; otherwise it is plain accuracy from
    ``head.predict``. The class list comes from ``head.classes_`` and falls
    back to the unique values of ``y``.
    """
    labels = getattr(head, "classes_", np.unique(y))
    if len(labels) != 2:
        preds = head.predict(X)
        accuracy = (preds == y).mean()
        return "acc", float(accuracy * 100.0)

    scores = head.decision_function(X)
    return "roc_auc", float(roc_auc_score(y, scores) * 100.0)





# -----------------------------

# methods in r×r space

# -----------------------------

def cov_path(method: str, Ct: torch.Tensor, k: int, Kmax: int, clamp: float) -> torch.Tensor:
    """Evaluate a closed-form path from the identity to ``Ct`` at ``t = k / Kmax``.

    Supported ``method`` values are ``"euclidean"``, ``"logeuclid"``,
    ``"airm"`` and ``"bw_geodesic"``; anything else raises ``ValueError``.
    ``clamp`` is the eigenvalue floor handed to the SPD helpers.
    """
    frac = float(k) / float(Kmax)
    eye = torch.eye(Ct.shape[0], device=Ct.device, dtype=Ct.dtype)

    if method == "bw_geodesic":
        # Interpolate the square roots linearly, then square.
        root = sqrtm_spd(Ct, clamp)
        interp = (1.0 - frac) * eye + frac * root
        return sym(interp @ interp)
    elif method == "airm":
        return powm_spd(Ct, frac, clamp)
    elif method == "logeuclid":
        return expm_sym(frac * logm_spd(Ct, clamp))
    elif method == "euclidean":
        return sym((1.0 - frac) * eye + frac * Ct)

    raise ValueError(method)



def itspace_step(Y: torch.Tensor, Ct_s: torch.Tensor, lam: float) -> torch.Tensor:
    """One ITSPACE update: blend ``Y`` with ``Ct_s @ polar_factor(Ct_s @ Y)``.

    The blend weight on the new term is ``2*lam / (1 + 2*lam)``.
    """
    weight = (2.0 * lam) / (1.0 + 2.0 * lam)
    rotation = polar_factor(Ct_s @ Y)
    target = Ct_s @ rotation
    return weight * target + (1.0 - weight) * Y



def bw2(A: torch.Tensor, B: torch.Tensor, clamp: float) -> torch.Tensor:
    """Return ``tr(A) + tr(B) - 2 tr((A^{1/2} B A^{1/2})^{1/2})`` for the symmetrized inputs.

    ``clamp`` is the eigenvalue floor used by both matrix square roots.
    """
    A_sym = sym(A)
    B_sym = sym(B)
    root_a = sqrtm_spd(A_sym, clamp)
    cross_root = sqrtm_spd(root_a @ B_sym @ root_a, clamp)
    return torch.trace(A_sym) + torch.trace(B_sym) - 2.0 * torch.trace(cross_root)



def transport_map_bw(A: torch.Tensor, B: torch.Tensor, clamp: float) -> torch.Tensor:
    """Return ``A^{-1/2} (A^{1/2} B A^{1/2})^{1/2} A^{-1/2}`` for the symmetrized inputs.

    ``clamp`` is the eigenvalue floor used by every matrix root involved.
    """
    A_sym = sym(A)
    B_sym = sym(B)
    root_a = sqrtm_spd(A_sym, clamp)
    inv_root_a = invsqrtm_spd(A_sym, clamp)
    cross_root = sqrtm_spd(root_a @ B_sym @ root_a, clamp)
    return inv_root_a @ cross_root @ inv_root_a



def bw_gd_step(X: torch.Tensor, Ct: torch.Tensor, eta: float, max_bt: int, bt_shrink: float, clamp: float) -> torch.Tensor:
    """Take one backtracking step from ``X`` toward ``Ct`` under the ``bw2`` objective.

    The candidate is ``G X G`` with ``G = (1 - step) I + step T``, where ``T``
    is ``transport_map_bw(X, Ct)``. The step starts at ``eta`` and is
    multiplied by ``bt_shrink`` after each rejected attempt, for at most
    ``max_bt + 1`` attempts.

    Returns:
        The first candidate whose objective does not exceed the current one
        (within ``1e-12``), or ``X`` unchanged if every attempt is rejected.
    """
    eye = torch.eye(Ct.shape[0], device=Ct.device, dtype=Ct.dtype)
    obj0 = float(bw2(X, Ct, clamp).detach().cpu())

    # T depends only on (X, Ct), neither of which changes while backtracking,
    # so compute its eigendecompositions once instead of once per attempt.
    T = transport_map_bw(X, Ct, clamp)

    step = float(eta)
    shrink = float(bt_shrink)
    for _ in range(max_bt + 1):
        G = (1.0 - step) * eye + step * T
        X_try = sym(G @ X @ G)
        obj1 = float(bw2(X_try, Ct, clamp).detach().cpu())
        if obj1 <= obj0 + 1e-12:
            return X_try
        step *= shrink
    return X



def sinkhorn_barycentric_step(Xs: torch.Tensor, Xt: torch.Tensor, sigma0: torch.Tensor, eps: float, iters: int) -> torch.Tensor:
    """Call the repo-local ``sinkhorn_barycentric_cov`` and return its symmetrized covariance.

    ``sigma0`` is accepted for interface compatibility but is not forwarded
    (see the note below). If the backend returns a tuple, its first element
    is taken as the covariance.
    """
    import run_experiments_icml as gap  # repo-local

    # IMPORTANT: In this repo version, sinkhorn_barycentric_cov does NOT accept sigma0 / warm-start.
    # Passing sigma0 as the 3rd positional arg breaks because eps must be a scalar.
    result = gap.sinkhorn_barycentric_cov(Xs, Xt, float(eps), int(iters))

    if isinstance(result, tuple):
        result = result[0]
    return sym(result)



def find_pack(dataset_dir: Path, seed: int) -> Path:
    """Locate the downstream ``.npz`` pack for ``seed`` under ``dataset_dir``.

    Three known layouts are probed in priority order and the first existing
    path is returned.

    Raises:
        FileNotFoundError: if none of the candidate paths exists.
    """
    downstream = dataset_dir / "downstream"
    candidates = (
        downstream / f"pack_seed{seed}.npz",
        downstream / f"pack_seed{seed}_closedset_eval.npz",
        dataset_dir / "downstream_packs" / f"seed_{seed}.npz",
    )
    found = next((path for path in candidates if path.exists()), None)
    if found is None:
        raise FileNotFoundError(f"Could not find downstream pack for seed={seed} under {dataset_dir}")
    return found



def load_Zs_ys(dataset_dir: Path, seed: int) -> Tuple[np.ndarray, np.ndarray]:
    """Load the source features and labels for ``seed`` from its downstream pack.

    Returns:
        ``(Zs, ys)``: the ``"Zs_train"`` array cast to float64 and the
        ``"ys_train"`` array cast to int64.

    Raises:
        FileNotFoundError: propagated from ``find_pack`` when no pack exists.
        KeyError: if the pack lacks ``"Zs_train"`` or ``"ys_train"``.
    """
    pack_path = find_pack(dataset_dir, seed)
    # np.load on an .npz returns a lazy NpzFile that keeps the archive open;
    # the with-block closes it. astype() copies, so the arrays outlive the file.
    with np.load(pack_path) as data:
        Zs = data["Zs_train"].astype(np.float64)
        ys = data["ys_train"].astype(np.int64)
    return Zs, ys





# Resolved run configuration. main() builds one instance from the parsed CLI
# arguments and writes it to config.json via dataclasses.asdict.
@dataclass
class Cfg:
    # Dataset directory; downstream packs are looked up beneath it.
    dataset: str
    # Output root; a timestamped run directory is created inside it.
    out: str
    # Seeds to run; each seed selects its own pack and its own data split.
    seeds: List[int]
    # String form of the resolved torch device.
    device: str
    # "float64" or "float32" (torch dtype used for the linear algebra).
    dtype: str

    # Split fractions. main() derives the train and unlab sizes from the first
    # two and gives the remainder to test; test_frac itself is only recorded.
    train_frac: float
    unlab_frac: float
    test_frac: float

    # Dimension r of the drift / correction subspace.
    rank: int
    # Step counts K at which results are recorded.
    Ks: List[int]
    # Method identifiers selected on the command line.
    methods: List[str]

    # Shrinkage weight (gamma) applied to the estimated r×r target covariance.
    shrinkage: float
    # Multiplied by the mean train variance to get the whitening eigenvalue
    # floor; also used directly as the eigenvalue clamp in r×r space.
    lam_floor_mult: float
    # Whether to fit a StandardScaler on the train split before whitening.
    use_standard_scaler: bool

    # main() requires drift_rank to equal rank.
    drift_rank: int
    # Scale factors have logs drawn uniformly from [-log(smax), log(smax)].
    drift_smax: float
    # Seed for the drift subspace and the drift scales.
    drift_seed: int
    # One of "random", "head", "headmix".
    drift_subspace: str

    # lam passed to itspace_step.
    itspace_lambda: float
    # Initial step size, attempt limit and shrink factor for bw_gd_step.
    bw_gd_eta: float
    bw_gd_max_backtracks: int
    bw_gd_bt_shrink: float

    # Maximum number of samples per side handed to the Sinkhorn methods.
    sinkhorn_n: int
    # eps and iteration count forwarded to sinkhorn_barycentric_step.
    sinkhorn_eps: float
    sinkhorn_iters_per_step: int





def main() -> None:
    """Run the CovDrift-MR downstream experiment from the command line.

    For every seed: load the source pack, split it into train / unlab / test,
    whiten with the train statistics, fit a ridge head on train, apply a
    synthetic rank-r scaling drift to unlab and test, estimate the drifted
    covariance in the drift subspace from unlab, then evaluate each selected
    correction method on the drifted test split.

    Outputs, written under a timestamped run directory inside ``--out``:
    ``config.json``, ``results.csv`` and ``diagnostics.csv``. The run
    directory path is also written to ``<out>/run_dir.txt``.

    Raises:
        ValueError: if ``--rank`` differs from ``--drift-rank``, if the rank
            exceeds the feature dimension, or if the split leaves no test data.
    """
    # Imported up front so a missing pandas fails before any computation
    # rather than after the last seed has finished.
    import pandas as pd

    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--seeds", default="0,1,2")
    ap.add_argument("--device", default="auto")
    ap.add_argument("--dtype", default="float64", choices=["float64", "float32"])

    ap.add_argument("--train-frac", type=float, default=0.70)
    ap.add_argument("--unlab-frac", type=float, default=0.15)
    ap.add_argument("--test-frac", type=float, default=0.15)

    ap.add_argument("--rank", type=int, default=16)
    ap.add_argument("--drift-rank", type=int, default=16)
    ap.add_argument("--drift-smax", type=float, default=1.70)
    ap.add_argument("--drift-seed", type=int, default=0)
    ap.add_argument("--drift-subspace", default="headmix", choices=["random", "head", "headmix"])

    ap.add_argument("--Ks", default="1,2,5,20")
    ap.add_argument("--methods", default="no_adapt,itspace,bw_geodesic,bw_gd,euclidean,logeuclid,airm,coral,sinkhorn,sinkhorn_gaus")

    ap.add_argument("--shrinkage", type=float, default=0.05)
    ap.add_argument("--lam-floor-mult", type=float, default=1e-6)

    ap.add_argument("--use-standard-scaler", action="store_true")
    ap.add_argument("--no-standard-scaler", dest="use_standard_scaler", action="store_false")
    ap.set_defaults(use_standard_scaler=True)

    ap.add_argument("--itspace-lambda", type=float, default=6.0)
    ap.add_argument("--bw-gd-eta", type=float, default=0.3)
    ap.add_argument("--bw-gd-max-backtracks", type=int, default=30)
    ap.add_argument("--bw-gd-bt-shrink", type=float, default=0.5)

    ap.add_argument("--sinkhorn-n", type=int, default=512)
    ap.add_argument("--sinkhorn-eps", type=float, default=0.3)
    ap.add_argument("--sinkhorn-iters-per-step", type=int, default=10)

    args = ap.parse_args()

    if int(args.rank) != int(args.drift_rank):
        raise ValueError(f"Matched-rank required: --rank must equal --drift-rank. Got {args.rank} vs {args.drift_rank}.")

    seeds = parse_int_list(args.seeds)
    Ks = sorted(set(parse_int_list(args.Ks)))
    methods = [m.strip() for m in args.methods.split(",") if m.strip()]
    Kmax = max(Ks) if Ks else 1

    # "auto" is not a valid torch device string, so it must be resolved to a
    # concrete device even when CUDA is unavailable (the default is "auto").
    if args.device == "auto":
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        dev = torch.device(args.device)
    dtype = torch.float64 if args.dtype == "float64" else torch.float32

    cfg = Cfg(
        dataset=args.dataset,
        out=args.out,
        seeds=seeds,
        device=str(dev),
        dtype=args.dtype,
        train_frac=float(args.train_frac),
        unlab_frac=float(args.unlab_frac),
        test_frac=float(args.test_frac),
        rank=int(args.rank),
        Ks=Ks,
        methods=methods,
        shrinkage=float(args.shrinkage),
        lam_floor_mult=float(args.lam_floor_mult),
        use_standard_scaler=bool(args.use_standard_scaler),
        drift_rank=int(args.drift_rank),
        drift_smax=float(args.drift_smax),
        drift_seed=int(args.drift_seed),
        drift_subspace=str(args.drift_subspace),
        itspace_lambda=float(args.itspace_lambda),
        bw_gd_eta=float(args.bw_gd_eta),
        bw_gd_max_backtracks=int(args.bw_gd_max_backtracks),
        bw_gd_bt_shrink=float(args.bw_gd_bt_shrink),
        sinkhorn_n=int(args.sinkhorn_n),
        sinkhorn_eps=float(args.sinkhorn_eps),
        sinkhorn_iters_per_step=int(args.sinkhorn_iters_per_step),
    )

    out_root = Path(cfg.out)
    out_root.mkdir(parents=True, exist_ok=True)

    ds_id = Path(cfg.dataset).name
    ts = time.strftime("%Y%m%d_%H%M%S")
    run_dir = out_root / f"{ds_id}_covdrift_mr_fixed_r{cfg.rank}_smax{cfg.drift_smax:g}_{ts}"
    run_dir.mkdir(parents=True, exist_ok=True)
    (out_root / "run_dir.txt").write_text(str(run_dir))
    (run_dir / "config.json").write_text(json.dumps(asdict(cfg), indent=2))

    rows: List[Dict[str, object]] = []
    diag_rows: List[Dict[str, object]] = []

    # Eigenvalue floor for every r×r matrix function below.
    clamp_r = float(cfg.lam_floor_mult)
    r = int(cfg.rank)

    for seed in cfg.seeds:
        t_shared0 = time.perf_counter()

        Zs, ys = load_Zs_ys(Path(cfg.dataset), seed)
        N, d = Zs.shape
        if r > d:
            # An r-column orthonormal basis cannot exist in fewer than r dimensions.
            raise ValueError(f"--rank ({r}) cannot exceed the feature dimension ({d}).")

        rng = np.random.default_rng(int(seed))
        perm = rng.permutation(N)

        # Train and unlab sizes come from their fractions; test takes the remainder.
        n_train = int(round(cfg.train_frac * N))
        n_unlab = int(round(cfg.unlab_frac * N))
        n_train = min(n_train, N - 2)
        n_unlab = min(n_unlab, N - n_train - 1)
        n_test = N - n_train - n_unlab
        if n_test <= 0:
            raise ValueError("Bad split; adjust fractions.")

        idx_tr = perm[:n_train]
        idx_ul = perm[n_train:n_train + n_unlab]
        idx_te = perm[n_train + n_unlab:]

        Ztr, ytr = Zs[idx_tr], ys[idx_tr]
        Zul = Zs[idx_ul]  # labels of the unlab split are never read
        Zte, yte = Zs[idx_te], ys[idx_te]

        # StandardScaler on train only (source-only)
        if cfg.use_standard_scaler:
            sc = StandardScaler(with_mean=True, with_std=True)
            Ztr_s = sc.fit_transform(Ztr)
            Zul_s = sc.transform(Zul)
            Zte_s = sc.transform(Zte)
        else:
            Ztr_s, Zul_s, Zte_s = Ztr, Zul, Zte

        # Full whitening from train covariance: X = (Z - mu) Cs^{-1/2}
        mu_s, Cs = cov_np(Ztr_s)
        lam_floor = cfg.lam_floor_mult * float(np.trace(Cs) / d)
        Cs_t = torch.as_tensor(Cs, device=dev, dtype=dtype)
        Cs_is = invsqrtm_spd(Cs_t, clamp=float(lam_floor))
        mu_s_t = torch.as_tensor(mu_s, device=dev, dtype=dtype)

        Xtr = (torch.as_tensor(Ztr_s, device=dev, dtype=dtype) - mu_s_t) @ Cs_is
        Xul = (torch.as_tensor(Zul_s, device=dev, dtype=dtype) - mu_s_t) @ Cs_is
        Xte = (torch.as_tensor(Zte_s, device=dev, dtype=dtype) - mu_s_t) @ Cs_is

        Xtr_np = Xtr.detach().cpu().numpy()
        Xul_np = Xul.detach().cpu().numpy()
        Xte_np = Xte.detach().cpu().numpy()

        # Train head on undrifted train
        head = RidgeClassifier(alpha=1.0)
        head.fit(Xtr_np, ytr)

        # Drift subspace Udr in whitened space
        if cfg.drift_subspace in ("head", "headmix"):
            Uspan = head_aligned_subspace(head, r=r, drift_seed=cfg.drift_seed)  # d×r
            if cfg.drift_subspace == "headmix":
                # Rotate the basis inside its own span with a random orthogonal r×r matrix.
                rngm = np.random.default_rng(int(cfg.drift_seed))
                Q, _ = np.linalg.qr(rngm.standard_normal((r, r)))
                Udr = Uspan @ Q
            else:
                Udr = Uspan
            s = sample_log_scales(r=r, smax=cfg.drift_smax, drift_seed=cfg.drift_seed, sort_by_magnitude=True)
        else:
            Udr = random_subspace(d=d, r=r, drift_seed=cfg.drift_seed)
            s = sample_log_scales(r=r, smax=cfg.drift_smax, drift_seed=cfg.drift_seed, sort_by_magnitude=False)

        # Apply drift to unlab/test (train remains undrifted)
        Xul_d = apply_subspace_scaling(Xul_np, Udr, s)
        Xte_d = apply_subspace_scaling(Xte_np, Udr, s)

        # Oracles (pointwise, uses known drift)
        Xte_oracle_undo = apply_subspace_scaling(Xte_d, Udr, 1.0 / s)
        Ct_true = torch.diag(torch.as_tensor((s ** 2).astype(np.float64), device=dev, dtype=dtype))

        # Estimate Ct in drift subspace (stable)
        Xul_dU = Xul_d @ Udr  # n_unlab×r
        _, CtU = cov_np(Xul_dU)
        CtU = shrink_cov(CtU, gamma=cfg.shrinkage)
        Ct = torch.as_tensor(CtU, device=dev, dtype=dtype)

        # --- Diagnostics (label-free): Ct_est spectrum + distance to Ct_true in r-space ---
        # Best-effort by design: a failure here is recorded and must not abort the run.
        try:
            eig = np.linalg.eigvalsh(CtU)
            eigmin = float(eig[0])
            eigmax = float(eig[-1])
            cond = float(eigmax / max(eigmin, 1e-12))
            true_eig = (s ** 2).astype(np.float64)
            loge = np.sort(np.log(np.clip(eig, 1e-12, None)))
            logt = np.sort(np.log(np.clip(true_eig, 1e-12, None)))
            logeig_rmse = float(np.linalg.norm(loge - logt) / np.sqrt(len(eig)))
            bw2_to_true = float(bw2(Ct, Ct_true, clamp=clamp_r).detach().cpu())
            diag_rows.append(dict(
                seed=int(seed),
                n_train=int(n_train), n_unlab=int(n_unlab), n_test=int(n_test),
                ct_eigmin=eigmin, ct_eigmax=eigmax, ct_cond=cond,
                logeig_rmse=logeig_rmse,
                bw2_to_true=bw2_to_true,
            ))
        except Exception as e:
            diag_rows.append(dict(seed=int(seed), error=str(e)))

        t_shared = time.perf_counter() - t_shared0

        # metric references
        mname, undrifted = metric_name_and_value(head, Xte_np, yte)
        _, no_adapt = metric_name_and_value(head, Xte_d, yte)
        _, oracle_undo = metric_name_and_value(head, Xte_oracle_undo, yte)

        # Evaluate with a given r×r covariance Xk (map = invsqrt(Xk) in drift coords)
        def eval_with_cov(Xk: torch.Tensor) -> Tuple[float, float]:
            t0 = time.perf_counter()
            Xk_is = invsqrtm_spd(Xk, clamp=clamp_r).detach().cpu().numpy()
            ZU = Xte_d @ Udr
            ZU_corr = ZU @ Xk_is
            Zcorr = Xte_d + (ZU_corr - ZU) @ Udr.T
            t_map = time.perf_counter() - t0
            _, mval = metric_name_and_value(head, Zcorr, yte)
            return mval, t_map

        def record(method_name: str, k: int, value: float, t_adapt: float) -> None:
            rows.append(dict(
                seed=int(seed),
                method=method_name,
                K=int(k),
                metric=mname,
                value=float(value),
                t_shared_s=float(t_shared),
                t_adapt_s=float(t_adapt),
            ))

        oracle_coral_true, _ = eval_with_cov(Ct_true)

        # record refs
        for ref_name, ref_val in [
            ("Undrifted (reference)", undrifted),
            ("No adapt", no_adapt),
            ("Oracle-Undo (sanity)", oracle_undo),
            ("Oracle-CORAL (true Ct)", oracle_coral_true),
        ]:
            for k in Ks:
                record(ref_name, k, ref_val, 0.0)

        I = torch.eye(r, device=dev, dtype=dtype)

        # NOTE on timing: t_adapt_s is (time to produce the covariance) + t_map.
        # The covariance time is stopped BEFORE eval_with_cov is called, so the
        # map is not counted twice and metric evaluation is not counted at all.

        # CORAL endpoint (estimated Ct)
        if "coral" in methods:
            for k in Ks:
                mval, t_map = eval_with_cov(Ct)
                record("CORAL", k, mval, t_map)

        # Closed-form paths
        for pm, name in [("bw_geodesic", "BW-geodesic"), ("euclidean", "Euclidean"), ("logeuclid", "Log-Euclidean"), ("airm", "AIRM")]:
            if pm not in methods:
                continue
            for k in Ks:
                t0 = time.perf_counter()
                Xk = cov_path(pm, Ct, k=k, Kmax=Kmax, clamp=clamp_r)
                t_path = time.perf_counter() - t0
                mval, t_map = eval_with_cov(Xk)
                record(name, k, mval, t_path + t_map)

        # ITSPACE (anytime)
        if "itspace" in methods:
            Ct_s = sqrtm_spd(Ct, clamp=clamp_r)
            Y = I.clone()
            t_iter = 0.0  # cumulative iteration time up to step k
            for k in range(1, Kmax + 1):
                t0 = time.perf_counter()
                Y = itspace_step(Y, Ct_s, lam=float(cfg.itspace_lambda))
                t_iter += time.perf_counter() - t0
                if k in Ks:
                    t0 = time.perf_counter()
                    Xk = sym(Y @ Y.T)
                    t_cov = time.perf_counter() - t0
                    mval, t_map = eval_with_cov(Xk)
                    record("ITSPACE", k, mval, t_iter + t_cov + t_map)

        # BW-GD (anytime)
        if "bw_gd" in methods:
            X = I.clone()
            t_iter = 0.0
            for k in range(1, Kmax + 1):
                t0 = time.perf_counter()
                X = bw_gd_step(X, Ct, eta=float(cfg.bw_gd_eta), max_bt=int(cfg.bw_gd_max_backtracks),
                               bt_shrink=float(cfg.bw_gd_bt_shrink), clamp=clamp_r)
                t_iter += time.perf_counter() - t0
                if k in Ks:
                    mval, t_map = eval_with_cov(X)
                    record("BW-GD", k, mval, t_iter + t_map)

        # Sinkhorn (sample-based) in r-dim
        if "sinkhorn" in methods:
            n = int(cfg.sinkhorn_n)
            rr = np.random.default_rng(int(seed) + 999)
            XsU = Xtr_np @ Udr
            XtU = Xul_d @ Udr
            if XsU.shape[0] > n:
                XsU = XsU[rr.choice(XsU.shape[0], size=n, replace=False)]
            if XtU.shape[0] > n:
                XtU = XtU[rr.choice(XtU.shape[0], size=n, replace=False)]
            Xs = torch.as_tensor(XsU, device=dev, dtype=torch.float32)
            Xt = torch.as_tensor(XtU, device=dev, dtype=torch.float32)
            sigma0 = I.clone()
            t_iter = 0.0
            for k in range(1, Kmax + 1):
                t0 = time.perf_counter()
                sigma0 = sinkhorn_barycentric_step(Xs, Xt, sigma0, eps=float(cfg.sinkhorn_eps), iters=int(cfg.sinkhorn_iters_per_step)).to(dev, dtype)
                t_iter += time.perf_counter() - t0
                if k in Ks:
                    mval, t_map = eval_with_cov(sigma0)
                    record("Sinkhorn", k, mval, t_iter + t_map)

        # Sinkhorn-Gaussian
        if "sinkhorn_gaus" in methods:
            n = int(cfg.sinkhorn_n)
            # Draw from a seeded CPU generator (then move to the device) so the
            # Gaussian samples, and hence the results, are reproducible per seed
            # regardless of the global torch RNG state or the device in use.
            gen = torch.Generator().manual_seed(int(seed) + 2999)
            Xs = torch.randn(n, r, generator=gen, dtype=torch.float32).to(dev)
            Ct_s32 = sqrtm_spd(Ct, clamp=clamp_r).to(torch.float32)
            Xt = torch.randn(n, r, generator=gen, dtype=torch.float32).to(dev) @ Ct_s32.T
            sigma0 = I.clone()
            t_iter = 0.0
            for k in range(1, Kmax + 1):
                t0 = time.perf_counter()
                sigma0 = sinkhorn_barycentric_step(Xs, Xt, sigma0, eps=float(cfg.sinkhorn_eps), iters=int(cfg.sinkhorn_iters_per_step)).to(dev, dtype)
                t_iter += time.perf_counter() - t0
                if k in Ks:
                    mval, t_map = eval_with_cov(sigma0)
                    record("Sinkhorn-Gaussian", k, mval, t_iter + t_map)

    df = pd.DataFrame(rows)
    df.to_csv(run_dir / "results.csv", index=False)
    if diag_rows:
        pd.DataFrame(diag_rows).to_csv(run_dir / "diagnostics.csv", index=False)

    print(f"WROTE: {run_dir}")
    print("  - config.json")
    print("  - results.csv")
    print("run_dir:", run_dir)





# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

