# Gradient flow on mixed-curvature space using MCSTSW

from __future__ import annotations
import os
import io
import math
import argparse
from typing import Tuple, Optional, List, Dict, Any

import torch
from torch import nn
from tqdm.auto import trange
import matplotlib.pyplot as plt
import numpy as np

import sys
sys.path.append('../')

# Local imports
from utils.MCS import MCS, project_inside_ball, _lambda_x_K_single
from methods.mcstsw import MCSTSW

# Optional mixture helpers. If importing utils.mixture_mcs fails for any
# reason, the three names are bound to None; the functions below that need
# them check for None before calling.
try:
    from utils.mixture_mcs import learn_mus_by_spread, mixture_sample, mixture_log_prob
except Exception as _e:
    learn_mus_by_spread = None
    mixture_sample = None
    mixture_log_prob = None

def create_swd(p):
    """Return a two-argument callable that forwards to ``swd(X, Y, p)``.

    NOTE(review): ``swd`` is neither defined nor imported in the visible part
    of this file, so calling the returned function presumably raises
    ``NameError`` unless ``swd`` is provided elsewhere — TODO confirm.
    """
    def pseudo_swd(X, Y):
        return swd(X, Y, p)
    return pseudo_swd


def set_seed(seed: int | None):
    """Seed torch's CPU and CUDA random generators; a ``None`` seed is a no-op."""
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def parse_K(K_str: str, ncomp: int, device: str) -> torch.Tensor:
    """Turn a comma-separated curvature string into a 1-D tensor of ``ncomp`` values.

    Args:
        K_str: text such as ``"-1,0,1"``; empty fields are skipped. ``None``
            selects ``-1.0`` for every component.
        ncomp: number of curvature values expected.
        device: device on which the tensor is created.

    Returns:
        Tensor of the default torch dtype holding one curvature per component.
        A single parsed value is replicated ``ncomp`` times.

    Raises:
        AssertionError: if the number of parsed values does not equal ``ncomp``.
    """
    if K_str is not None:
        curvatures = []
        for field in K_str.split(','):
            if field.strip() != '':
                curvatures.append(float(field.strip()))
        # A lone value stands for "same curvature everywhere".
        if ncomp > 1 and len(curvatures) == 1:
            curvatures = curvatures * ncomp
        assert len(curvatures) == ncomp, f"Expected {ncomp} curvatures, got {len(curvatures)}"
    else:
        curvatures = [-1.0] * ncomp
    return torch.tensor(curvatures, dtype=torch.get_default_dtype(), device=device)


def rgrad_step(mcs: MCS, x: torch.Tensor, grad_eucl: torch.Tensor, lr: float) -> torch.Tensor:
    """Take one descent step of size ``lr`` from ``x`` along ``-grad_eucl``.

    With ``mcs`` set, the Euclidean gradient is divided by the square of
    ``_lambda_x_K_single(x, mcs.K)``, the point is moved with
    ``mcs.exp_map_pairwise`` and the result is passed through
    ``project_inside_ball``. With ``mcs`` equal to ``None`` the update is the
    plain ``x - lr * grad_eucl``.

    Returns:
        The updated points as a new tensor; ``x`` itself is not modified.
    """
    if mcs is None:
        # Flat fallback: ordinary gradient-descent update.
        return x + (-lr * grad_eucl)

    lam_factor = _lambda_x_K_single(x, mcs.K)
    # Trailing axis added so the squared factor broadcasts over coordinates.
    metric_scale = (lam_factor * lam_factor).unsqueeze(-1)
    tangent_step = -lr * (grad_eucl / metric_scale)
    moved = mcs.exp_map_pairwise(x.detach(), tangent_step)
    return project_inside_ball(moved, mcs.K, eps=1e-7)

def _extract_2d(arr: torch.Tensor, comp: int = 0, dims: Tuple[int, int] = (0, 1)):
    """Return two coordinates of one component as a CPU NumPy array.

    ``arr`` is indexed as ``arr[:, comp, :]`` and only columns ``dims[0]`` and
    ``dims[1]`` are kept, giving one row per entry along the first axis.
    """
    first, second = dims[0], dims[1]
    columns = [int(first), int(second)]
    component = arr[:, comp, :]
    return component[:, columns].detach().cpu().numpy()

def _get_radius_from_K_val(K_val: float):
    """Return ``1 / sqrt(|K_val|)``, or ``None`` when ``K_val`` is zero or not finite."""
    usable = K_val != 0.0 and np.isfinite(K_val)
    if usable:
        return float(1.0 / math.sqrt(abs(K_val)))
    return None

def _draw_background(ax, K_val: float):
    """Outline the circle of radius ``_get_radius_from_K_val(K_val)`` on ``ax``.

    Nothing happens when that radius is ``None``. The outline is dashed for
    positive ``K_val`` and solid otherwise, and both axis limits are set to
    1.05 times the radius on either side of the origin.
    """
    radius = _get_radius_from_K_val(K_val)
    if radius is not None:
        outline_style = '--' if K_val > 0 else '-'
        boundary = plt.Circle((0.0, 0.0), radius, fill=False, linewidth=1.5, linestyle=outline_style)
        ax.add_patch(boundary)
        half_span = radius * 1.05
        ax.set_xlim(-half_span, half_span)
        ax.set_ylim(-half_span, half_span)

def plot_dist_2d(arr: torch.Tensor, *, comp: int = 0, dims: Tuple[int, int] = (0, 1),
                 fig_path: str | None = None, label: str | None = None, K: torch.Tensor | None = None):
    """Scatter two coordinates of one component of ``arr`` and optionally save the figure.

    Args:
        arr: point cloud, sliced by ``_extract_2d`` with ``comp`` and ``dims``.
        comp: component index to plot; also selects the entry of ``K``.
        dims: the two coordinate indices shown on the x and y axes.
        fig_path: when truthy, the figure is written there (parent directory
            is created if needed).
        label: title prefix; ``"distribution"`` is used when falsy.
        K: curvature tensor; entry ``comp`` (after flattening) decides the
            boundary circle and the title suffix. ``None`` is treated as 0.

    The figure is always closed before returning.
    """
    points = _extract_2d(arr, comp=comp, dims=dims)
    fig, ax = plt.subplots(figsize=(6, 6))

    if K is None:
        K_val = 0.0
    else:
        K_val = float(K.view(-1)[comp].detach().cpu())
    if K_val != 0.0:
        _draw_background(ax, K_val)

    ax.scatter(points[:, 0], points[:, 1], s=8, alpha=0.85, linewidths=0)
    ax.set_xlabel(f"comp {comp}, dim {dims[0]}")
    ax.set_ylabel(f"comp {comp}, dim {dims[1]}")
    # Fixed limits are applied after any limits chosen by _draw_background.
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)

    title = label or "distribution"
    if K is not None:
        title += f" | K={K_val:+.3g}"
    if K_val < 0:
        title += " (Poincaré disk)"
    elif K_val > 0:
        title += " (sphere chart)"
    ax.set_title(title)
    ax.axis('equal')
    fig.tight_layout()

    if fig_path:
        os.makedirs(os.path.dirname(fig_path) or '.', exist_ok=True)
        fig.savefig(fig_path, dpi=200)
    plt.close(fig)

def plot_snapshots_2d(Y: torch.Tensor, snapshots: dict, *, comp: int = 0, dims: Tuple[int, int] = (0, 1),
                      fig_path: str | None = None, title: str | None = None):
    """Overlay the target cloud and every stored snapshot in one scatter plot.

    Args:
        Y: target points, sliced by ``_extract_2d`` with ``comp`` and ``dims``.
        snapshots: mapping from step to point cloud; plotted in sorted key order.
        comp: component index to plot.
        dims: the two coordinate indices shown on the x and y axes.
        fig_path: when truthy, the figure is written there (parent directory
            is created if needed).
        title: optional plot title.

    The snapshot with the largest key is drawn with a larger marker. The
    current pyplot figure is closed before returning.
    """
    target_xy = _extract_2d(Y, comp=comp, dims=dims)
    ordered_steps = sorted(snapshots)

    plt.figure(figsize=(7, 7))
    plt.scatter(target_xy[:, 0], target_xy[:, 1], s=8, alpha=0.4, label='target ν', linewidths=0)
    for step in ordered_steps:
        cloud_xy = _extract_2d(snapshots[step], comp=comp, dims=dims)
        if step == ordered_steps[-1]:
            marker_size = 10
        else:
            marker_size = 6
        plt.scatter(cloud_xy[:, 0], cloud_xy[:, 1], s=marker_size, alpha=0.7,
                    label=f"μ @ step {step}", linewidths=0)

    plt.xlabel(f"comp {comp}, dim {dims[0]}")
    plt.ylabel(f"comp {comp}, dim {dims[1]}")
    if title:
        plt.title(title)
    plt.legend(loc='best', fontsize=8)
    plt.axis('equal')
    plt.tight_layout()

    if fig_path:
        os.makedirs(os.path.dirname(fig_path) or '.', exist_ok=True)
        plt.savefig(fig_path, dpi=200)
    plt.close()

def _render_overlay_image(X: torch.Tensor, Y: torch.Tensor, *, comp: int, dims: Tuple[int, int], K: Optional[torch.Tensor]) -> Any:
    """Build a figure showing the current cloud ``X`` on top of the target ``Y``.

    Both clouds are sliced by ``_extract_2d`` with ``comp`` and ``dims``. When
    entry ``comp`` of the flattened ``K`` is non-zero, ``_draw_background`` is
    called first.

    Returns:
        The matplotlib figure, left open for the caller.
    """
    current_xy = _extract_2d(X, comp=comp, dims=dims)
    target_xy = _extract_2d(Y, comp=comp, dims=dims)
    fig, ax = plt.subplots(figsize=(6, 6))

    if K is None:
        K_val = 0.0
    else:
        K_val = float(K.view(-1)[comp].detach().cpu())
    if K_val != 0.0:
        _draw_background(ax, K_val)

    ax.scatter(target_xy[:, 0], target_xy[:, 1], s=6, alpha=0.35, linewidths=0, label='ν (target)')
    ax.scatter(current_xy[:, 0], current_xy[:, 1], s=8, alpha=0.9, linewidths=0, label='μ (current)')
    ax.set_xlabel(f"comp {comp}, dim {dims[0]}")
    ax.set_ylabel(f"comp {comp}, dim {dims[1]}")
    ax.set_title("Overlay: μ vs ν")
    ax.legend(loc='best', fontsize=8)
    ax.axis('equal')
    fig.tight_layout()
    return fig

def save_frames(Y: torch.Tensor, snapshots: dict, *, comp: int, dims: Tuple[int, int], fig_dir: str, K: torch.Tensor | None = None):
    """Write one PNG per snapshot, plus ``target.png``, into ``fig_dir``.

    Each image is produced by ``plot_dist_2d``; snapshot files are named
    ``step_NNNN.png`` with the step zero-padded to four digits.
    """
    os.makedirs(fig_dir, exist_ok=True)
    target_file = os.path.join(fig_dir, 'target.png')
    plot_dist_2d(Y, comp=comp, dims=dims, fig_path=target_file, label='target ν', K=K)

    for key in sorted(snapshots):
        step_no = int(key)
        frame_file = os.path.join(fig_dir, f'step_{step_no:04d}.png')
        plot_dist_2d(snapshots[key], comp=comp, dims=dims, fig_path=frame_file,
                     label=f'μ @ step {step_no}', K=K)

def prepare_or_load_dataset(*, mcstsw_like, mcs: MCS, n_samples: int, r_per_comp: float,
                            s: float, dataset_path: str, seed: int | None):
    """Load the target dataset from ``dataset_path`` or build and cache it.

    Args:
        mcstsw_like: accepted but not used by this function.
        mcs: geometry object; its ``device`` and ``N`` attributes are read.
        n_samples: number of samples drawn when building the dataset.
        r_per_comp: forwarded to ``learn_mus_by_spread`` as ``radius_per_comp``.
        s: value filling the ``sigmas`` vector of the mixture.
        dataset_path: cache file; its parent directory is created if missing.
        seed: passed to ``set_seed`` before building. The ``seed`` argument of
            ``learn_mus_by_spread`` is fixed at 42 regardless of this value.

    Returns:
        ``(Y, prior_params)`` where ``prior_params`` has the keys ``"mus"``,
        ``"sigmas"`` and ``"weights"`` when freshly built.

    Raises:
        RuntimeError: if the dataset must be built but the mixture helpers
            could not be imported.
    """
    target_device = mcs.device
    os.makedirs(os.path.dirname(dataset_path) or ".", exist_ok=True)

    # Fast path: reuse the cached bundle, moving tensors to the working device.
    if os.path.isfile(dataset_path):
        cached = torch.load(dataset_path, map_location=target_device)
        Y = cached["Y"].to(target_device)
        loaded_params = {}
        for name, value in cached["prior_params"].items():
            loaded_params[name] = value.to(target_device) if torch.is_tensor(value) else value
        return Y, loaded_params

    if learn_mus_by_spread is None or mixture_sample is None:
        raise RuntimeError("utils/mixture_mcs.py (learn_mus_by_spread, mixture_sample) not found/importable.")

    set_seed(seed)
    n_centers = 6  # number of mixture components is fixed
    fit = learn_mus_by_spread(
        mcs,
        n_components=n_centers,
        radius_per_comp=r_per_comp,
        steps=10000, lr=0.2, momentum=0.9, seed=42, verbose=True
    )
    dtype = torch.get_default_dtype()
    prior_params = {
        "mus": fit["mus"],
        "sigmas": torch.full((mcs.N,), s, device=target_device, dtype=dtype),
        "weights": torch.full((n_centers,), 1.0 / n_centers, device=target_device, dtype=dtype),
    }

    samples, _ = mixture_sample(mcs, prior_params, n_samples=n_samples)

    # Cache CPU copies so the file can be loaded on any device.
    cpu_params = {}
    for name, value in prior_params.items():
        cpu_params[name] = value.detach().cpu() if torch.is_tensor(value) else value
    torch.save({"Y": samples.detach().cpu(), "prior_params": cpu_params}, dataset_path)
    return samples, prior_params


def compute_nll(mcs: MCS, prior_params: Dict[str, Any], X: torch.Tensor) -> float:
    """Return the mean negative log-probability of ``X`` under the mixture prior.

    The log-probabilities come from ``mixture_log_prob(mcs, prior_params, X)``
    and are evaluated without gradient tracking. If ``mixture_log_prob`` could
    not be imported, ``nan`` is returned instead.
    """
    if mixture_log_prob is None:
        return float('nan')
    with torch.no_grad():
        log_density = mixture_log_prob(mcs, prior_params, X)
        # Drop a trailing axis when the result has more than one dimension.
        if log_density.ndim > 1:
            log_density = log_density.squeeze(-1)
        mean_log_density = log_density.mean().item()
    return -mean_log_density

def prepare_or_load_source(*, mcs: MCS, n_samples: int, s_source: float,
                           source_path: str, seed: int | None, ncomp: int, dcomp: int) -> torch.Tensor:
    """Load the initial particles from ``source_path`` or sample and cache them.

    Args:
        mcs: geometry object providing ``device``, ``K`` and ``sample_wrap_normal``.
        n_samples: number of particles to draw when sampling.
        s_source: value filling the per-component ``sigma`` vector.
        source_path: cache file; its parent directory is created if missing.
        seed: passed to ``set_seed`` before sampling.
        ncomp: number of components (first axis of the zero origin).
        dcomp: dimension per component (second axis of the zero origin).

    Returns:
        The source tensor. A cached file may hold either the tensor itself or
        a dict with key ``"X0"``; loaded data is moved to ``mcs.device`` and
        cast to the default dtype.
    """
    dtype = torch.get_default_dtype()
    os.makedirs(os.path.dirname(source_path) or ".", exist_ok=True)

    if os.path.isfile(source_path):
        stored = torch.load(source_path, map_location=mcs.device)
        if isinstance(stored, dict) and "X0" in stored:
            stored = stored["X0"]
        return stored.to(device=mcs.device, dtype=dtype)

    set_seed(seed)
    origin = torch.zeros(ncomp, dcomp, device=mcs.device, dtype=dtype)
    sigma_vec = torch.full((ncomp,), s_source, device=mcs.device, dtype=dtype)
    drawn = mcs.sample_wrap_normal(origin, sigma=sigma_vec, batch=(n_samples,))
    drawn = project_inside_ball(drawn, mcs.K, eps=1e-7)

    torch.save({"X0": drawn.detach().cpu()}, source_path)
    return drawn

def append_metrics_line(metrics_path: str, step: int, log2_w2: Optional[float], nll: Optional[float], write_header_if_empty: bool=True):
    """Append one ``step,log2_w2,nll`` row to the CSV-style metrics file.

    Args:
        metrics_path: file to append to; its parent directory is created if missing.
        step: value written in the first column.
        log2_w2: second column; written with 8 decimals, or left empty when
            ``None`` or not finite.
        nll: third column; same formatting rule as ``log2_w2``.
        write_header_if_empty: when true and the file is missing or has size
            zero, the header line is written before the row.
    """
    def _cell(value):
        # Missing or non-finite numbers become an empty CSV field.
        if value is None or not np.isfinite(value):
            return ""
        return f"{value:.8f}"

    os.makedirs(os.path.dirname(metrics_path) or ".", exist_ok=True)
    needs_header = write_header_if_empty and (
        not os.path.isfile(metrics_path) or os.path.getsize(metrics_path) == 0
    )
    with open(metrics_path, "a", encoding="utf-8") as out:
        if needs_header:
            out.write("step,log2_w2,nll\n")
        out.write(f"{step},{_cell(log2_w2)},{_cell(nll)}\n")

def run_gradient_flow(
    *,
    loss_type: str,
    ncomp: int,
    dcomp: int,
    K: torch.Tensor,
    npoints: int,
    sigma: float,
    steps: int,
    lr: float,
    ntrees: int,
    nlines: int,
    p: int,
    delta: float,
    resample_trees: bool,
    fixed_trees_seed: int | None,
    device: str,
    seed: int | None,
    save_path: str | None,
    verbose: bool = True,
    snapshot_every: int = 100,
    eval: bool = False,
    eval_every: int = 50,
    log_every: int = 10,
    wandb_run: Optional[Any] = None,
    wandb_log_images_every: int = 0,
    vis_comp: int = 0,
    vis_dims: Tuple[int, int] = (0, 1),
    # metrics & prior
    metrics_path: Optional[str] = None,
    prior_params: Optional[Dict[str, Any]] = None,
    # NEW: pass target directly
    Y_target: torch.Tensor | None = None,
    X_source: torch.Tensor | None = None,
):
    """Move a particle cloud toward ``Y_target`` by descending the MCSTSW loss.

    Args:
        loss_type: only ``"MCSTSW"`` is supported.
        ncomp, dcomp, K, ntrees, nlines, p, delta: forwarded to ``MCSTSW``
            (``K`` is cloned first).
        steps: number of descent steps.
        lr: step size passed to ``rgrad_step``.
        resample_trees: draw new tree frames every step; otherwise one set is
            drawn once after seeding with ``fixed_trees_seed``.
        device: device for the particles and the target.
        seed: global seed applied via ``set_seed`` at the start.
        save_path: if truthy, a ``.pt`` bundle with source, final particles,
            target, hyper-parameters and snapshots is written there.
        verbose: show the progress bar and loss postfix.
        snapshot_every: store a CPU copy of the particles every k steps
            (0 disables periodic snapshots; step 0 and the final step are
            always stored).
        eval: when true, evaluate ``mcs.wasserstein`` every ``eval_every`` steps.
        eval_every: evaluation period; 0 disables evaluation.
        log_every: progress-bar update period; 0 disables the postfix.
        metrics_path: if given, each evaluation appends a row via
            ``append_metrics_line``.
        prior_params: if given, ``compute_nll`` is evaluated alongside W2.
        Y_target: target particles (required).
        X_source: initial particles. If ``None``, they are sampled with
            ``mcs.sample_wrap_normal`` around the origin with sigma 0.1.
            The caller's tensor is never modified.
        npoints, sigma, wandb_run, wandb_log_images_every, vis_comp, vis_dims:
            accepted for interface compatibility; not used by this function.

    Returns:
        ``(X_final, Y, last_loss, snapshots)``; ``last_loss`` is ``None`` when
        no step was executed.

    Raises:
        ValueError: unknown ``loss_type`` or an empty ``X_source``.
        RuntimeError: ``Y_target`` is ``None``.
    """
    set_seed(seed)

    # Instantiate geometry + distance
    if loss_type != "MCSTSW":
        raise ValueError(f"Unknown loss_type: {loss_type}")
    mcstsw = MCSTSW(ncomp=ncomp, dcomp=dcomp, initK=K.clone(), ntrees=ntrees, nlines=nlines,
                    p=p, delta=delta, fixK=True, device=device)
    mcs = mcstsw.mcs

    # ---- Target ν ----
    if Y_target is None:
        raise RuntimeError("Y_target must be provided to run_gradient_flow().")
    dtype = torch.get_default_dtype()
    # The converted copy is used for the whole run so that X and Y always
    # share device and dtype.
    Y = Y_target.to(device=device, dtype=dtype)
    A = Y.shape[0]

    # ---- Source μ0 ----
    if X_source is None:
        origin = torch.zeros(ncomp, dcomp, device=device, dtype=dtype)
        sigma_vec = torch.full((ncomp,), 0.1, device=device, dtype=dtype)
        X = mcs.sample_wrap_normal(origin, sigma=sigma_vec, batch=(A,))
    else:
        X = X_source.to(device=device, dtype=dtype)
        n_src = X.shape[0]
        if n_src > A:
            X = X[:A]
        elif n_src < A:
            if n_src == 0:
                raise ValueError("X_source is empty; cannot align it to the target size.")
            # Tile the source as often as needed, then cut to exactly A rows
            # (a single repeat is not enough when A > 2 * n_src).
            reps = math.ceil(A / n_src)
            X = X.repeat(reps, *([1] * (X.dim() - 1)))[:A]

    # `.to()` and slicing may return the caller's own storage; work on a
    # private leaf copy so the in-place updates below never touch X_source.
    X = X.detach().clone()
    X.requires_grad_(True)

    # Prepare (fixed) tree frames if requested
    fixed_frames: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
    if not resample_trees:
        set_seed(fixed_trees_seed)
        fixed_frames = tuple(mcstsw.generate_trees_frames())

    # Snapshots (step 0 = initial particles)
    snapshots: Dict[int, torch.Tensor] = {0: X.detach().cpu().clone()}

    last_loss = None
    pbar = trange(1, steps + 1, disable=not verbose, desc="GF (mcstsw)")

    eval_steps: List[int] = []
    log_w2_losses: List[float] = []

    # -------- Inner step loop --------
    for step in pbar:
        X.grad = None

        if resample_trees:
            root, intercept, intercept_index = mcstsw.generate_trees_frames()
        else:
            root, intercept, intercept_index = fixed_frames  # type: ignore

        loss = mcstsw(X, Y, root, intercept, intercept_index)
        if loss.ndim > 0:
            loss = loss.mean()

        loss.backward()

        with torch.no_grad():
            X_new = rgrad_step(mcs, X, X.grad, lr)
            X.data.copy_(X_new)

        last_loss = float(loss.detach().cpu())
        if verbose and log_every and (step % log_every == 0):
            pbar.set_postfix_str(f"loss={last_loss:.6f}")

        if snapshot_every and (step % snapshot_every == 0):
            snapshots[step] = X.detach().cpu().clone()

        # Evaluate metrics
        if eval and eval_every and (step % eval_every == 0):
            with torch.no_grad():
                w2 = mcs.wasserstein(X, Y)
                if w2.ndim > 0:
                    w2 = w2.mean()
                w2_val = float(w2.item())
            # NOTE(review): the value is stored under the name "log2_w2" but is
            # computed with log10 — confirm the intended base before changing.
            log2_w2_val = math.log10(max(w2_val, 1e-12))
            log_w2_losses.append(log2_w2_val)
            eval_steps.append(step)

            # NLL under fixed mixture prior (if available)
            nll_val = compute_nll(mcs, prior_params, X) if prior_params is not None else None

            # Log to text
            if metrics_path is not None:
                append_metrics_line(metrics_path, step, log2_w2_val, nll_val)

    # Final snapshot
    snapshots[steps] = X.detach().cpu().clone()

    # Save .pt bundle
    if save_path:
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        torch.save({
            'X_source': X_source.detach().cpu() if (X_source is not None) else None,
            'X_final': X.detach().cpu(),
            'Y_target': Y.detach().cpu(),
            'K': K.detach().cpu(),
            'ncomp': ncomp,
            'dcomp': dcomp,
            'ntrees': ntrees,
            'nlines': nlines,
            'steps': steps,
            'lr': lr,
            'p': p,
            'delta': delta,
            'snapshots': {int(k): v for k, v in snapshots.items()},
        }, save_path)

    # Optional W2 curve figure (written next to the bundle, or into the
    # current directory when no save_path was given)
    if eval and log_w2_losses:
        w2_dir = (os.path.dirname(save_path) if save_path else '') or '.'
        os.makedirs(w2_dir, exist_ok=True)
        w2_fig_path = os.path.join(w2_dir, 'W2.png')
        plt.figure()
        # x values are the actual steps at which W2 was evaluated.
        plt.plot(eval_steps, log_w2_losses)
        plt.xlabel('Step')
        plt.ylabel('log(W2)')
        plt.title('Convergence (log W2)')
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(w2_fig_path, dpi=200)
        plt.close()

    return X.detach(), Y.detach(), last_loss, snapshots

def main():
    """Command-line entry point: build or load the data, run the flow, write plots.

    All run outputs (bundle, metrics, snapshot panel, per-step frames) go into
    ``--out_dir``. The target dataset and the source particles are loaded from
    ``--dataset_path`` / ``--source_path`` when those files exist and are
    created there otherwise.
    """
    parser = argparse.ArgumentParser("Gradient flow on mixed-curvature space with MCSTSW")
    parser.add_argument('--loss_type', type=str, default='MCSTSW', help='MCSTSW, MCSSWSeparate, MCSTSWSeparateTree')
    parser.add_argument('--device', type=str, default='cuda', help='cuda or cpu')
    parser.add_argument('--ncomp', type=int, default=1, help='number of components (N)')
    parser.add_argument('--dcomp', type=int, default=3, help='dimension per component (M)')
    parser.add_argument('--K', type=str, default='-1', help='CSV curvatures per component, e.g. "-1,0,1"')
    parser.add_argument('--npoints', type=int, default=2400, help='number of particles in μ and ν')
    parser.add_argument('--sigma', type=float, default=0.35, help='std for init μ (and fallback default)')
    parser.add_argument('--steps', type=int, default=500, help='gradient flow steps')
    parser.add_argument('--lr', type=float, default=5e-2, help='Riemannian step size')
    parser.add_argument('--ntrees', type=int, default=200, help='number of random trees')
    parser.add_argument('--nlines', type=int, default=5, help='lines per tree')
    parser.add_argument('--p', type=int, default=1, help='L^p ground metric in TW')
    parser.add_argument('--delta', type=float, default=2.0, help='softmax temperature inverse (mass division)')

    # Logging / eval cadence
    parser.add_argument('--eval', action='store_true', help='periodically compute W2(X,Y) for logging')
    parser.add_argument('--eval_every', type=int, default=50, help='evaluate metric every k steps when --eval is set')
    parser.add_argument('--log_every', type=int, default=10, help='update tqdm postfix & log scalars every k steps')

    # Trees
    parser.add_argument('--fixed_trees', action='store_true', help='use one fixed set of trees for all steps')
    parser.add_argument('--fixed_trees_seed', type=int, default=123, help='seed to draw fixed trees')

    # General
    parser.add_argument('--seed', type=int, default=0, help='global seed')
    parser.add_argument('--out_dir', type=str, default='outputs/gf_run', help='directory to collect all outputs for this run')

    # Snapshot / plotting
    parser.add_argument('--snapshot_every', type=int, default=100, help='store μ every k steps plus step 0 and final')
    parser.add_argument('--comp', type=int, default=0, help='component index to visualize')
    parser.add_argument('--dim0', type=int, default=0, help='first dimension to plot')
    parser.add_argument('--dim1', type=int, default=1, help='second dimension to plot')

    parser.add_argument('--dataset_path', type=str, default='./data.pt',
                        help='Path to a saved 6-Gaussians mixture dataset; if absent, it will be created here')
    parser.add_argument('--r_per_comp', type=float, default=0.5, help='Equal radius from origin for the 6 centers')
    parser.add_argument('--mixture_sigma', type=float, default=None,
                        help='Per-component tangent std for the mixture (defaults to --sigma if None)')

    parser.add_argument('--source_path', type=str, default='./source_data.pt',
                        help='Path to a saved initial source μ0; if absent, it will be created here')
    parser.add_argument('--source_sigma', type=float, default=0.1,
                        help='Per-component tangent std for the initial source wrapped normal')

    args = parser.parse_args()

    # Route every output into run_dir
    run_dir = args.out_dir
    os.makedirs(run_dir, exist_ok=True)
    bundle_path = os.path.join(run_dir, 'bundle.pt')
    snapshots_panel_path = os.path.join(run_dir, 'snapshots.png')
    frames_dir = os.path.join(run_dir, 'frames')
    metrics_path = os.path.join(run_dir, 'metrics.txt')

    device = args.device
    torch.set_default_dtype(torch.float32)
    torch.backends.cudnn.benchmark = True

    # Prepare K
    K = parse_K(args.K, args.ncomp, device=device)

    # An empty --dataset_path / --source_path falls back to a file in run_dir.
    dataset_path = args.dataset_path or os.path.join(run_dir, 'dataset.pt')
    source_path = args.source_path or os.path.join(run_dir, 'source.pt')

    # Build a temporary mcs to construct the dataset with the same geometry
    tmp_mcstsw = MCSTSW(ncomp=args.ncomp, dcomp=args.dcomp, initK=K.clone(),
                        ntrees=args.ntrees, nlines=args.nlines, p=args.p, delta=args.delta,
                        fixK=True, device=device)
    mcs_for_data = tmp_mcstsw.mcs

    mixture_s = args.mixture_sigma if args.mixture_sigma is not None else args.sigma
    Y_target, prior_params = prepare_or_load_dataset(
        mcstsw_like=tmp_mcstsw, mcs=mcs_for_data, n_samples=args.npoints,
        r_per_comp=args.r_per_comp, s=mixture_s, dataset_path=dataset_path, seed=args.seed
    )

    X0 = prepare_or_load_source(
        mcs=mcs_for_data, n_samples=args.npoints, s_source=args.source_sigma,
        source_path=source_path, seed=args.seed, ncomp=args.ncomp, dcomp=args.dcomp
    )

    vis_dims = (args.dim0, args.dim1)

    # ------------------
    # Run gradient flow
    # ------------------
    X_final, Y_used, last_loss, snapshots = run_gradient_flow(
        loss_type=args.loss_type, ncomp=args.ncomp, dcomp=args.dcomp, K=K,
        npoints=args.npoints, sigma=args.sigma, steps=args.steps, lr=args.lr,
        ntrees=args.ntrees, nlines=args.nlines, p=args.p, delta=args.delta,
        resample_trees=not args.fixed_trees, fixed_trees_seed=args.fixed_trees_seed,
        device=device, seed=args.seed, save_path=bundle_path, verbose=True,
        snapshot_every=args.snapshot_every, eval=args.eval, eval_every=args.eval_every,
        log_every=args.log_every, wandb_run=None, wandb_log_images_every=0,
        vis_comp=args.comp, vis_dims=vis_dims,
        metrics_path=metrics_path, prior_params=prior_params,
        Y_target=Y_target, X_source=X0,
    )

    # Plot combined snapshots (μ at several steps + ν)
    plot_snapshots_2d(Y_used, snapshots, comp=args.comp, dims=vis_dims,
                      fig_path=snapshots_panel_path,
                      title=f"Progress: comp {args.comp} | dims ({args.dim0},{args.dim1})")

    # Save per-snapshot frames: one folder for a single component, otherwise
    # one sub-folder per component.
    if args.ncomp == 1:
        save_frames(Y_used, snapshots, comp=args.comp, dims=vis_dims, fig_dir=frames_dir, K=K)
    else:
        for c in range(args.ncomp):
            subdir = os.path.join(frames_dir, f'comp_{c:02d}')
            save_frames(Y_used, snapshots, comp=c, dims=vis_dims, fig_dir=subdir, K=K)

    # last_loss is None when no step ran (e.g. --steps 0); formatting None
    # with ':.6f' would raise TypeError.
    loss_text = f"{last_loss:.6f}" if last_loss is not None else "n/a"
    print(f"Done. Final loss ≈ {loss_text}. All outputs saved under: {run_dir}")


# Script entry point: enable cuDNN autotuning, then run the CLI.
if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True
    main()
