import os
import numpy as np
import matplotlib.pyplot as plt
from pprint import pformat
import torch
import torch.nn as nn
from utilities.utils import set_requires_grad
from torch.utils.data import Subset, DataLoader
import torch.nn.functional as F

from data.data_process import PDEDataProcessor, PDEDataset
from exp.exp_basic import Exp_Basic, ExpConfigs
from utilities.losses import LpLoss
from utilities.vis import visualize_pred_vs_gt
from utilities.data_proc import delay_stack_last_channel
from utilities.read_file import load_matrix_from_path

from config import MemKNOParamBundle
from config import build_encoder, build_set_encoder, build_field_decoder, build_fourier_decoder, build_latent_process
from config import build_latent_process_discrete

"""
TODO:
"""


class Exp_MemKNO(Exp_Basic):
    def __init__(self, args, exp_cfg: ExpConfigs, model_cfg: MemKNOParamBundle, data_processor: PDEDataProcessor):
        """Configure the MemKNO experiment and build its modules.

        ``model_cfg`` is accepted either as an object exposing
        ``as_model_kwargs`` (converted with ``include_meta=True``) or as a plain
        dict; any other type raises ``TypeError``. The resolved config must
        agree with ``data_processor`` on ``n_frames_cond``.
        """
        super(Exp_MemKNO, self).__init__(args, exp_cfg, model_cfg, data_processor)

        # Resolve the model configuration to a plain mapping.
        if hasattr(model_cfg, "as_model_kwargs"):
            resolved_cfg = model_cfg.as_model_kwargs(include_meta=True)
        elif isinstance(model_cfg, dict):
            resolved_cfg = model_cfg
        else:
            raise TypeError("model_cfg must be a MemKNOParamBundle or a dict loaded from model_cfg.json")
        self.model_cfg = resolved_cfg
        assert self.model_cfg["n_frames_cond"] == data_processor.n_frames_cond

        # Latent-space sizes and latent-process flavour, copied from the model config.
        for key in ("state_dim", "latent_dim", "code_dim", "latent_type"):
            setattr(self, key, self.model_cfg[key])
        self.use_memory = self.latent_type in {"linear+memory", "gru_memory", "lstm_memory"}

        # Dataloaders are initialised in Exp_Basic.
        # Teacher-forcing parameters.
        self.tf_epsilon = exp_cfg.tf_epsilon
        self.epsilon = exp_cfg.epsilon

        self.dt_eval = exp_cfg.dt_eval

        # Backbone selectors and loss masking switch.
        for name in ("enc_mode", "dec_mode", "latent_mode", "loss_with_mask"):
            setattr(self, name, getattr(self.cfg, name))

        # Loss weights.
        self.lambda_dyn = self.cfg.lambda_dyn
        self.lambda_pred = self.cfg.lambda_pred
        self.lambda_corr = self.cfg.lambda_corr
        self.lambda_resid = getattr(self.cfg, "lambda_residual", 0.0)

        # Diagonal whitening settings (each falls back to a default when absent).
        self.use_diag_whiten = getattr(self.cfg, "use_diag_whiten", True)
        self.whiten_eps = getattr(self.cfg, "whiten_eps", 1e-8)
        self.whiten_clamp = getattr(self.cfg, "whiten_clamp", 1e-3)
        self.whiten_scale = None  # [D]

        ################# Params for Low-dimensional Projectors #################
        self.U_proj = None  # [D, d] or None
        self.use_projector = False
        self.latent_dim_y = None

        # Build the modules, then report their sizes.
        self.load_model()
        self.log_param_table()

    
    def build_dataloader(self, group: str):
        sample_map = {
            "train": self.train_sample_idx,
            "train_eval": self.train_eval_sample_idx,
            "test": self.test_sample_idx,
        }
        dataloader = self.data_processor.get_dataloader(group=group, samples=sample_map[group])
        if group == "train":
            self.train_loader = dataloader
        elif group == "train_eval":
            self.train_eval_loader = dataloader
        elif group == "test":
            self.test_loader = dataloader

    
    def load_model(self):
        if self.enc_mode == "galerkin_transformer":
            self.encoder = build_encoder(model_cfg=self.model_cfg["encoder"]).to(self.device)
        elif self.enc_mode == "set_transformer":
            self.encoder = build_set_encoder(model_cfg=self.model_cfg["set_encoder"]).to(self.device)

        if self.dec_mode == "fouriernet":
            self.decoder = build_field_decoder(x_grid=self.pos_feat, model_cfg=self.model_cfg["field_decoder"]).to(self.device)
        elif self.dec_mode == "fouriermlp":
            self.decoder = build_fourier_decoder(model_cfg=self.model_cfg["fourier_decoder"]).to(self.device)

        if self.latent_mode == "continuous":
            self.latent_process = build_latent_process(model_cfg=self.model_cfg["latent_process"]).to(self.device)
        elif self.latent_mode == "discrete":
            self.latent_process = build_latent_process_discrete(model_cfg=self.model_cfg["latent_process_discrete"]).to(self.device)

    
    def count_parameters(self) -> dict:
        """Return a dict with per-module and total trainable parameter counts."""
        def ntrainable(module):
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        counts = {
            "encoder": ntrainable(self.encoder) if hasattr(self, "encoder") else 0,
            "latent_process": ntrainable(self.latent_process) if hasattr(self, "latent_process") else 0,
            "decoder": ntrainable(self.decoder) if hasattr(self, "decoder") else 0,
        }
        counts["total"] = sum(counts.values())
        return counts


    def log_param_table(self, title: str = "Trainable parameters"):
        c = self.count_parameters()
        lines = [
            f"{title}:",
            f"  encoder       : {c['encoder']:,}",
            f"  latent_process: {c['latent_process']:,}",
            f"  decoder       : {c['decoder']:,}",
            f"  ---------------------------",
            f"  TOTAL         : {c['total']:,}",
        ]
        msg = "\n".join(lines)
        if hasattr(self, "logger") and self.logger is not None:
            self.logger.info(msg)
        else:
            print(msg)

    
    def switch_to_train(self):
        self.encoder.train()
        self.decoder.train()
        self.latent_process.train()

    
    def switch_to_eval(self):
        self.encoder.eval()
        self.decoder.eval()
        self.latent_process.eval()
        

    def init_optim(self):
        self.optim_enc = torch.optim.Adam([{'params': self.encoder.parameters(), 'lr': self.lr}])
        self.optim_dec = torch.optim.Adam([{'params': self.decoder.parameters(), 'lr': self.lr}])
        ######################################
        self.optim_dyn = torch.optim.Adam([{'params': self.latent_process.parameters(), 'lr': self.lr, 'weight_decay': self.cfg.weight_decay}])

        if self.cfg.scheduler == 'OneCycleLR':
            self.scheduler_enc = torch.optim.lr_scheduler.OneCycleLR(self.optim_enc, max_lr=self.cfg.lr, epochs=self.cfg.epochs,
                                                            steps_per_epoch=len(self.train_loader),
                                                            pct_start=self.cfg.pct_start)
            self.scheduler_dec = torch.optim.lr_scheduler.OneCycleLR(self.optim_dec, max_lr=self.cfg.lr, epochs=self.cfg.epochs,
                                                            steps_per_epoch=len(self.train_loader),
                                                            pct_start=self.cfg.pct_start)
            ######################################
            self.scheduler_dyn = torch.optim.lr_scheduler.OneCycleLR(self.optim_dyn, max_lr=self.cfg.lr, epochs=self.cfg.epochs,
                                                            steps_per_epoch=len(self.train_loader),
                                                            pct_start=self.cfg.pct_start)
        elif self.cfg.scheduler == 'CosineAnnealingLR':
            self.scheduler_enc = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim_enc, T_max=self.cfg.epochs)
            self.scheduler_dec = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim_dec, T_max=self.cfg.epochs)
            self.scheduler_dyn = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim_dyn, T_max=self.cfg.epochs)
        elif self.cfg.scheduler == 'StepLR':
            self.scheduler_enc = torch.optim.lr_scheduler.StepLR(self.optim_enc, step_size=self.cfg.step_size, gamma=self.cfg.gamma)
            self.scheduler_dec = torch.optim.lr_scheduler.StepLR(self.optim_dec, step_size=self.cfg.step_size, gamma=self.cfg.gamma)
            self.scheduler_dyn = torch.optim.lr_scheduler.StepLR(self.optim_dyn, step_size=self.cfg.step_size, gamma=self.cfg.gamma)


    def train_recon(self, log_every: int | None = None, verbose: str = None):
        """
        Train the encoder and decoder jointly on field reconstruction.

        Every training batch is delay-stacked over ``n_frames_cond`` frames,
        sub-sampled at the masked-in points, encoded, decoded back to the full
        grid and compared against the ground-truth frames from index
        ``n_frames_cond - 1`` onwards. The latent process is not used here.

        Parameters
        ----------
        log_every : log the losses whenever ``epoch * len(train_loader) + i``
            is a multiple of this value; ``None`` disables per-iteration logging.
        verbose : optional free-text message written to the log before training.

        Batch layout (per the shape notes kept from the original code):
        batch['data']: [B, T, H, W, C]
        batch['t']: tensor of len n_frames_train / n_frames_train + n_frames_out
        batch['mask']: [B, T, H, W, C]

        Either the plain MSE or the mask-weighted MSE is back-propagated,
        depending on ``self.loss_with_mask``; both are computed and logged.
        """
        self.setup_logger()
        self.save_repro_artifacts()
        self.log_param_table()
        criterion = nn.MSELoss()

        assert self.data_processor.mode == "interpolation", f"Mismatched dataloaders"
        self.build_dataloader(group="train")
        self.init_optim()
        self._save_split_and_samples()
        self.logger.info(f"Begin training using decoder backbone: {self.dec_mode}")

        if verbose is not None:
            self.logger.info(f"{verbose}")

        def reconstruction_losses(target, recon, mask):
            """Return (plain MSE, mask-weighted MSE) between ``target`` and ``recon``."""
            plain = criterion(target, recon)
            sqerr_sum = ((target - recon).pow(2) * mask).sum(dim=(2, 3))
            # Clamp guards against division by zero when a sample has no kept points.
            denom = mask.sum(dim=(2, 3)).clamp_(min=1e-6)
            return plain, (sqerr_sum / denom).mean()

        # The first n_frames_cond - 1 frames have no full delay window, so they are not targets.
        first_target = self.n_frames_cond - 1

        for epoch in range(1, self.cfg.epochs + 1):
            for i, batch in enumerate(self.train_loader):
                ground_truth = batch["data"].to(self.device)    # [B, T, H, W, n_ch]
                masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]    (same over time and channels)
                bs, train_len, _, _, _ = ground_truth.shape
                assert train_len == self.n_frames_train
                n_windows = train_len - self.n_frames_cond + 1

                # Indices of kept points, repeated for every delay window of the same sample.
                mask_index = self.mask_to_bs_index(mask=masks)    # [B, S]
                mask_index = mask_index.unsqueeze(1).expand(-1, n_windows, -1).flatten(0, 1)  # [B*t, S]
                delay_data = delay_stack_last_channel(x=ground_truth, d=self.n_frames_cond)    # [B, T-nf_cond+1, ..., n_ch*nf_cond]
                data_in = delay_data.flatten(0, 1).flatten(1, 2)    # [B*t, H*W, n_ch*nf_cond]

                data_in = self.index_points(data_in, mask_index)    # [B*t, S, n_ch*nf_cond]
                pos_in = self.pos_feat.flatten(0, 1).unsqueeze(0).expand(data_in.shape[0], -1, -1).to(self.device)
                pos_in = self.index_points(pos_in, mask_index)    # [B*t, S, 2]
                data_in = torch.cat((data_in, pos_in), dim=-1)

                latent_token = self.encoder(data_in, pos_in)    # [B*t, K, latent_token]

                # Only the decoding step differs between the two decoder backbones.
                if self.dec_mode == "fouriernet":
                    latent_feats = latent_token.flatten(1, 2).unflatten(0, (bs, n_windows))    # [B, t, K*latent_token]
                    assert latent_feats.shape[-1] == self.latent_dim
                    latent_feats = latent_feats.view(bs, n_windows, self.state_dim, self.code_dim)    # [B, t, s, code_dim]
                    recon_field = self.decoder(latent_feats)    # [B, t, H, W, s]
                elif self.dec_mode == "fouriermlp":
                    latent_feats = latent_token.flatten(1, 2)    # [B*t, K*latent_token]
                    grid_dim = self.pos_feat.shape[-1]
                    grid = self.pos_feat.reshape(-1, grid_dim).to(self.device)    # [N_pt, grid_dim]
                    recon_field = self.decoder(grid=grid, latent_feat=latent_feats)    # [B*t, N_pts, out_dim]
                    recon_field = recon_field.reshape(bs, n_windows, *self.shapelist, self.state_dim)
                else:
                    # Any other decoder mode performs no update, as before.
                    recon_field = None

                if recon_field is not None:
                    recon_loss, recon_loss_wz_mask = reconstruction_losses(
                        ground_truth[:, first_target:, ...],
                        recon_field,
                        masks[:, first_target:, ...],
                    )

                    self.optim_enc.zero_grad()
                    self.optim_dec.zero_grad()
                    if self.loss_with_mask:
                        recon_loss_wz_mask.backward()
                    else:
                        recon_loss.backward()
                    self.optim_enc.step()
                    self.optim_dec.step()
                    # OneCycleLR advances once per batch; the others advance once per epoch below.
                    if self.cfg.scheduler == "OneCycleLR":
                        self.scheduler_enc.step()
                        self.scheduler_dec.step()

                if log_every is not None and (epoch * len(self.train_loader) + i) % log_every == 0:
                    self.logger.info(f"Epoch {epoch:04d}/{self.cfg.epochs} | iteration {i+1:03d} |"
                        f"| field reconstruction loss {recon_loss.item():.8f} | field reconstruction loss (with mask) {recon_loss_wz_mask.item():.8f}")
            if self.cfg.scheduler == 'CosineAnnealingLR' or self.cfg.scheduler == 'StepLR':
                self.scheduler_enc.step()
                self.scheduler_dec.step()


    def mask_to_bs_index(self, mask: torch.Tensor, S: int | None = None, threshold: float = 0.5) -> torch.Tensor:
        # mask: [B, T, H, W, C]  (1=keep)
        B, T, H, W, C = mask.shape
        device = mask.device

        m2d = (mask[:, 0, :, :, 0] > threshold)
        lin = (torch.arange(H, device=device).view(H, 1) * W + torch.arange(W, device=device).view(1, W))
        lin = lin.view(1, H, W).expand(B, -1, -1)
        idx_list = [lin[b][m2d[b]].reshape(-1).long() for b in range(B)]
        S = min(x.numel() for x in idx_list) if S is None else S
        if any(x.numel() < S for x in idx_list):
            raise ValueError(f"Some batches have fewer kept points than S={S}.")
        idx = torch.stack([x[:S] for x in idx_list], dim=0)   # [B, S]
        return idx


    def index_points(self, points, idx):
        """
        Input:
            points: input points data, [B, N, C]
            idx: sample index data, [B, S]
        Return:
            new_points: indexed points data, [B, S, C]
        """
        device = points.device
        B = points.shape[0]
        view_shape = list(idx.shape)
        view_shape[1:] = [1] * (len(view_shape) - 1)
        repeat_shape = list(idx.shape)
        repeat_shape[0] = 1
        batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
        new_points = points[batch_indices, idx, :]
        return new_points

    
    def one_step_loss(self,
                      alpha_gt,           # [T, B, D]
                      t_eval,             # [T]
                      *,
                      train_latent: bool = True,
                      weight_per_step: torch.Tensor | None = None,  # [T-1] or [T-1,B]
                      reduction: str = "mean"):
        """
        Compute the training loss that matches ODE one-step predictions alpha_{k+1}^{pred}
        against ground-truth alpha_{k+1} for all k.

        Implementation detail:
        - We call process.forward with teacher forcing and a mask that cuts at every step
            (except the very last index, per your implementation). This yields the exact
            0->1, 1->2, ..., (T-2)->(T-1) one-step integration with memory carried across.

        Returns
        -------
        loss        : scalar tensor suitable for backward()
        pred_next   : [T-1, B, D] predicted alpha at t_{k+1}
        gt_next     : [T-1, B, D] ground-truth alpha at t_{k+1}
        per_step_mse: [T-1, B]    per (k,b) MSE before reduction
        """
        device = alpha_gt.device
        T, B, D = alpha_gt.shape

        # Mask that ends every segment at each step: [T-1] all True except the last False
        tf_mask = torch.ones(T - 1, dtype=torch.bool, device=device)
        tf_mask[-1] = False

        # Run once with teacher forcing; memory is carried internally for memory variants.
        # (alpha_gt is detached inside your _odeint_with_tf implementation.)
        out = self.latent_process.forward(
            alpha_0=alpha_gt[0],             # not used when teacher_forcing=True, but required by signature
            t_eval=t_eval,
            memory_init=None,                # optional: pass learned/init memory state if you have one
            teacher_forcing=True,
            tf_alpha=alpha_gt,
            tf_epsilon=0.0,
            tf_mask=tf_mask,
            tf_detach_alpha_starts=not(train_latent),
        )

        # Unpack (handles both return_aux=False/True)
        if isinstance(out, tuple):
            alpha_pred = out[0]  # [T, B, D]
        else:
            alpha_pred = out     # [T, B, D]

        pred_next = alpha_pred[1:]         # [T-1, B, D]
        gt_next   = alpha_gt[1:]           # [T-1, B, D]

        # Per (k,b) MSE before reduction
        per_step_mse = (pred_next - gt_next).pow(2).mean(dim=-1)   # [T-1, B]

        if weight_per_step is not None:
            # Broadcast weights to [T-1, B]
            if weight_per_step.ndim == 1:      # [T-1]
                w = weight_per_step[:, None].to(per_step_mse)
            elif weight_per_step.ndim == 2:    # [T-1, B]
                w = weight_per_step.to(per_step_mse)
            else:
                raise ValueError("weight_per_step must be shape [T-1] or [T-1, B]")
            per_step_mse = per_step_mse * w

        if reduction == "mean":
            loss = per_step_mse.mean()
        elif reduction == "sum":
            loss = per_step_mse.sum()
        elif reduction == "none":
            loss = per_step_mse
        else:
            raise ValueError(f"Unknown reduction: {reduction}")
        return loss, pred_next, gt_next, per_step_mse


    def kstep_rollout_loss(
        self,
        alpha_gt: torch.Tensor,         # [T, B, D]
        t_eval: torch.Tensor,           # [T]
        K: int,
        *,
        num_starts: int | None = None,
        stride: int = 1,
        train_latent: bool = True,
        target_stopgrad: bool = True,
        discount_gamma: float = 1.0,
        reduction: str = "mean",
        # NEW: carry memory across windows
        carry_memory: bool = True,
        # Whether to detach the precomputed memory trace before using it as memory_init
        carry_memory_detach: bool = False,
    ):
        """
        Sample multiple starting windows from a single latent sequence and perform
        free integration (teacher_forcing=False) for K steps in each window.
        Supervise every step against the ground-truth latent.

        If `carry_memory=True` and the latent process has memory, we first run a
        teacher-forced pass over the *whole* sequence with 1-step cuts (so memory is
        carried across steps while alpha is reset to GT at every step). The memory at
        time t_s is then used as `memory_init` for the rollout window starting at s.

        NOTE (LSTM): forward() returns the hidden state h only (not cell c). Thus we
        carry h across windows and still initialize c=0. If you need to carry (h,c),
        you'd need to extend the model to return c_t or accept an extra arg.
        """
        device = alpha_gt.device
        T, B, D = alpha_gt.shape
        assert K >= 1 and T >= K + 1, "Sequence too short for K-step rollout."

        # Candidate window starts s ∈ [0, T-K-1] with the chosen stride.
        all_starts = torch.arange(0, T - K, stride, device=device)

        # Optionally subsample a random subset of starts.
        if num_starts is not None and num_starts < all_starts.numel():
            perm = torch.randperm(all_starts.numel(), device=device)
            starts = all_starts[perm[:num_starts]]
        else:
            starts = all_starts

        t_rel = (t_eval[:K + 1] - t_eval[0]).to(device=device, dtype=alpha_gt.dtype)

        # Optional horizon discount weights: shape [K] or None.
        if discount_gamma == 1.0:
            h_w = None
        else:
            h_w = discount_gamma ** torch.arange(1, K + 1, device=device, dtype=alpha_gt.dtype)

        # ----------------------------------------------------------------------
        # NEW: precompute memory trace over the whole sequence (teacher-forced).
        #      We cut at every single step (except the last) so memory is carried,
        #      but alpha is reset to the ground-truth at each segment start.
        mem_trace = None
        if carry_memory and self.use_memory:
            tf_mask = torch.ones(T - 1, dtype=torch.bool, device=device)
            tf_mask[-1] = False  # keep last interval connected
            out_tf = self.latent_process.forward(
                alpha_0=alpha_gt[0],     # not used when teacher_forcing=True
                t_eval=t_eval,
                memory_init=None,
                teacher_forcing=True,
                tf_alpha=alpha_gt,
                tf_epsilon=0.0,
                tf_mask=tf_mask,
                # match one_step_loss behavior: allow gradient to flow to alpha when training latent
                tf_detach_alpha_starts=not(train_latent),
                detach_memory_between_segments=False,   # carry memory (and grads) across 1-step segments
            )
            if isinstance(out_tf, tuple):
                # out_tf = (alpha_t, memory_t, [aux])
                mem_trace = out_tf[1]   # [T, B, Dm]
            if mem_trace is not None and carry_memory_detach:
                mem_trace = mem_trace.detach()

        """losses = []

        # --- Iterate over windows one by one.
        for s in starts.tolist():
            # Free rollout starts from a0 at time t_s.
            a0 = alpha_gt[s] if train_latent else alpha_gt[s].detach()

            # Initialize memory for this window.
            if self.use_memory:
                if carry_memory and (mem_trace is not None):
                    mem0 = mem_trace[s]  # memory at t_s computed from the teacher-forced pass
                else:
                    # fall back to zeros if no carry or no memory module
                    mem0 = torch.zeros(B, self.latent_process.memory_dim, device=device, dtype=alpha_gt.dtype)
            else:
                mem0 = None

            # Window-specific relative time axis
            t_win = (t_eval[s:s + K + 1] - t_eval[s])

            # Free integrate K steps (teacher_forcing=False).
            out = self.latent_process.forward(
                alpha_0=a0,
                t_eval=t_win,
                memory_init=mem0,
                teacher_forcing=False,
                tf_detach_alpha_starts=not(train_latent)
            )
            alpha_pred = out[0] if isinstance(out, tuple) else out  # [K+1, B, D]

            # Predictions for the next K steps and the corresponding targets.
            pred = alpha_pred[1:]                 # [K, B, D]
            target = alpha_gt[s + 1:s + K + 1]    # [K, B, D]
            if target_stopgrad:
                target = target.detach()

            # MSE per (h, b); average over feature dimension D.
            mse = (pred - target).pow(2).mean(dim=-1)  # [K, B]

            # Apply optional horizon discount γ^h.
            if h_w is not None:
                mse = mse * h_w[:, None]

            # Reduce over K and B for this window.
            losses.append(mse.mean())

        # Aggregate over windows.
        if len(losses) == 0:
            loss = alpha_gt.new_tensor(0.0)
        else:
            stacked = torch.stack(losses)  # [num_windows]
            if reduction == "mean":
                loss = stacked.mean()
            elif reduction == "sum":
                loss = stacked.sum()
            elif reduction == "none":
                loss = stacked
            else:
                raise ValueError(f"Unknown reduction: {reduction}")

        return loss, {
            "num_windows": int(starts.numel()),
            "K": int(K),
            "carry_memory": bool(carry_memory),
            "carry_memory_detach": bool(carry_memory_detach),
        }"""

        W = starts.numel()
        if W == 0:
            return alpha_gt.new_tensor(0.0), {
                "num_windows": 0, "K": int(K),
                "carry_memory": bool(carry_memory),
                "carry_memory_detach": bool(carry_memory_detach),
            }
        a0 = alpha_gt[starts] if train_latent else alpha_gt[starts].detach()  # [W, B, D]
        a0 = a0.reshape(W * B, D)
        if self.use_memory:
            if carry_memory and (mem_trace is not None):
                mem0 = mem_trace[starts].reshape(W * B, self.latent_process.memory_dim)
            else:
                mem0 = torch.zeros(W * B, self.latent_process.memory_dim, device=device, dtype=alpha_gt.dtype)
        else:
            mem0 = None
        out = self.latent_process.forward(
            alpha_0=a0,
            t_eval=t_rel,
            memory_init=mem0,
            teacher_forcing=False,
        )
        alpha_pred = out[0] if isinstance(out, tuple) else out  # [K+1, W*B, D]
        pred = alpha_pred[1:].reshape(K, W, B, D)               # [K, W, B, D]

        idx = (starts[None, :] + torch.arange(1, K + 1, device=device)[:, None]).reshape(-1)  # [K*W]
        target = alpha_gt.index_select(0, idx).reshape(K, W, B, D)
        if target_stopgrad:
            target = target.detach()

        mse = (pred - target).pow(2).mean(dim=-1)   # [K, W, B]
        if h_w is not None:
            mse = mse * h_w[:, None, None]

        if reduction == "mean":
            loss = mse.mean()
        elif reduction == "sum":
            loss = mse.sum()
        elif reduction == "none":
            loss = mse
        else:
            raise ValueError(f"Unknown reduction: {reduction}")

        return loss, {
            "num_windows": int(W),
            "K": int(K),
            "carry_memory": bool(carry_memory),
            "carry_memory_detach": bool(carry_memory_detach),
        }


    @staticmethod
    @torch.no_grad()
    def solve_Ab_ridge(Z0: torch.Tensor, Zp: torch.Tensor, ridge: float = 1e-3, add_bias: bool = True):
        """
        Closed-form ridge regression for one-step latent linear dynamics:
        min_{A,b} ||Zp - (A Z0 + b)||_F^2 + ridge * ||[A b]||_F^2
        Inputs:
        Z0: [N, D]  flattened (t,b) -> N samples at time t
        Zp: [N, D]  flattened (t,b) -> N samples at time t+1
        Returns:
        A*: [D, D], b*: [D] or None
        """
        N, D = Z0.shape
        if add_bias:
            X = torch.cat([Z0, torch.ones(N, 1, device=Z0.device, dtype=Z0.dtype)], dim=1)  # [N, D+1]
            p = D + 1
        else:
            X = Z0; p = D
        Y = Zp                                          # [N, D]
        G = X.T @ X + ridge * torch.eye(p, device=X.device, dtype=X.dtype)   # [p, p]
        YXt = Y.T @ X                                   # [D, p]
        Theta = torch.linalg.solve(G, YXt.T).T          # [D, p]
        A = Theta[:, :D]                                # [D, D]
        b = Theta[:, D] if add_bias else None           # [D]
        return A, b
    

    @staticmethod
    @torch.no_grad()
    def ema_update_stats(Sxx: torch.Tensor, Syx: torch.Tensor,
                        Z0: torch.Tensor, Zp: torch.Tensor,
                        ema_beta: float = 0.95, add_bias: bool = True):
        """
        EMA sufficient statistics across batches:
        Sxx ≈ E[X^T X],  Syx ≈ E[Y^T X], where X=[Z0; 1] (if add_bias)
        In-place updates with EMA.
        """
        N, D = Z0.shape
        if add_bias:
            X = torch.cat([Z0, torch.ones(N, 1, device=Z0.device, dtype=Z0.dtype)], dim=1)  # [N, D+1]
        else:
            X = Z0
        Y = Zp
        Sxx.mul_(ema_beta).add_(X.T @ X, alpha=(1 - ema_beta))
        Syx.mul_(ema_beta).add_(Y.T @ X, alpha=(1 - ema_beta))
        return Sxx, Syx


    @staticmethod
    @torch.no_grad()
    def solve_global_A_from_stats(Sxx: torch.Tensor, Syx: torch.Tensor,
                                ridge: float = 1e-3, D: int | None = None, add_bias: bool = True):
        """
        Solve global A_ema (+ b_ema) from EMA sufficient statistics.
        """
        p = Sxx.size(0)
        I = torch.eye(p, device=Sxx.device, dtype=Sxx.dtype)
        Theta = torch.linalg.solve(Sxx + ridge * I, Syx.T).T  # [D, p]
        if D is None:
            D = Theta.size(0)
        A_ema = Theta[:, :D]
        b_ema = Theta[:, D] if add_bias else None
        return A_ema, b_ema
    
    
    @torch.no_grad()
    def solve_global_A_fullpass(self, dataloader, ridge: float = 5e-3, use_bias: bool = True):
        enc_was_train = self.encoder.training
        dec_was_train = self.decoder.training
        self.encoder.eval(); self.decoder.eval()
        Z0_list, Zp_list = [], []
        for batch in dataloader:
            latent_states, _, _ = self._encode_and_recon(batch)  # [T',B,D]
            T1, B, D = latent_states.shape
            if T1 <= 1:
                continue
            Z0_list.append(latent_states[:-1].reshape(-1, D))
            Zp_list.append(latent_states[ 1:].reshape(-1, D))

        if len(Z0_list) == 0:
            raise RuntimeError("[solve_global_A_fullpass] No (Z0,Zp) pairs collected. Check dataloader or n_frames_cond.")
        Z0 = torch.cat(Z0_list, dim=0).to(self.device)
        Zp = torch.cat(Zp_list, dim=0).to(self.device)
        A_full, b_full = self.solve_Ab_ridge(Z0, Zp, ridge=ridge, add_bias=use_bias)

        if enc_was_train: self.encoder.train()
        if dec_was_train: self.decoder.train()
        return A_full.detach(), (b_full.detach() if b_full is not None else None)


    def train_phase1_linear(self,
                            epochs: int | None,
                            ridge: float = 5e-3,
                            ema_beta: float = 0.97,
                            use_bias: bool = True,
                            use_pred_loss: bool = True,
                            lambda_pred: float = 0.1,
                            lambda_dyn: float = 0.05,
                            log_every: int | None = 5,
                            eval_every: int | None = 20,
                            verbose: str | None = None,
                            diag_enable: bool = True,                    # print diagnostics
                            ms_consistency_enable: bool = False,         # latent multi-step consistency with batch A*
                            freq_ms_enable: bool = False,                # turn on spectral+multiscale penalty
                            global_A_mode: str | None = None,            # "fullpass" | "ema"
                            ):
        """
        Phase-I: train encoder/decoder with a closed-form linear skeleton on the latent.

        Per batch:
        - Encode the batch and solve a ridge-regularized A*, b* on consecutive latent
          pairs under no_grad. A*/b* only define losses; they are never written into
          model parameters.
        - loss = masked recon + lambda_dyn * latent one-step MSE
                 + lambda_pred * masked decoded one-step MSE
                 + cfg.lambda_lt_pred * multi-step latent consistency (optional)
                 + cfg.lambda_freq * (FFT + multiscale spatial) loss (optional).
        Per epoch:
        - Compute a global A (+ b), either with a full pass over the train loader
          ("fullpass") or from EMA sufficient statistics ("ema"). The result is cached on
          self.A_phase1_ema / self.b_phase1_ema and saved as "A_phase1_ema.pt" in the
          run directory.
        - Every `eval_every` epochs evaluate with the global A; the best masked
          reconstruction on the train_eval split is checkpointed as "phase1_best_rec.pth".
        A final checkpoint "phase1_final.pth" is written when training ends.

        Args:
            epochs: number of epochs; falls back to self.cfg.epochs when None.
            ridge: ridge coefficient passed to the closed-form solvers.
            ema_beta: decay passed to `ema_update_stats` (only used in "ema" mode).
            use_bias: fit / apply a bias term b alongside A.
            use_pred_loss: add the decoded one-step prediction loss.
            lambda_pred: weight of the masked decoded one-step loss.
            lambda_dyn: weight of the latent one-step consistency loss.
            log_every: log every this many global iterations (falsy disables logging).
            eval_every: evaluate every this many epochs (None disables evaluation).
            verbose: optional free-form message logged once before training.
            diag_enable: log per-epoch spectral-radius / gap diagnostics.
            ms_consistency_enable: enable the multi-step latent consistency loss.
            freq_ms_enable: enable the frequency + multiscale loss on reconstructions.
            global_A_mode: "fullpass" or "ema"; defaults to cfg.global_A_mode, else "fullpass".
        """
        mode = (global_A_mode or getattr(self.cfg, "global_A_mode", "fullpass")).lower()
        assert mode in {"fullpass", "ema"}, f"global_A_mode must be 'fullpass' or 'ema', got {mode}"

        self.setup_logger()
        self.save_repro_artifacts()
        self.log_param_table()
        assert self.data_processor.mode == "interpolation", "Mismatched dataloaders"
        self.build_dataloader(group="train")
        if eval_every is not None:
            self.build_dataloader(group="test")
            self.build_dataloader(group="train_eval")
        self._save_split_and_samples()

        # Switch modules
        self.encoder.train()
        self.decoder.train()
        set_requires_grad(self.latent_process, False)  # freeze latent dynamics in Phase-I

        # Optimizers (enc/dec only)
        self.init_optim()

        if verbose is not None:
            self.logger.info(f"{verbose}")

        if mode == "ema":    # EMA stats buffers for a global A
            D = self.latent_dim
            p = D + (1 if use_bias else 0)
            Sxx = torch.zeros(p, p, device=self.device)
            Syx = torch.zeros(D, p, device=self.device)
        else:
            Sxx = Syx = None

        num_epochs = epochs if epochs is not None else self.cfg.epochs
        best_rec, best_metrics = float("inf"), None

        # Effective weights of the optional losses: zero when the feature is disabled
        # or when the config value is None. Resolving them once also keeps the
        # "> 0.0" checks below from ever comparing None with a float.
        lambda_lt_pred = self.cfg.lambda_lt_pred if (self.cfg.lambda_lt_pred is not None) and ms_consistency_enable else 0.0
        lambda_freq = self.cfg.lambda_freq if (self.cfg.lambda_freq is not None) and freq_ms_enable else 0.0

        # Buffers for within-epoch diagnostics
        rho_list_epoch = []        # spectral radii of A* per batch
        gap_list_epoch = []        # ||A* - A_global_prev|| / ||A_global_prev||
        A_global_prev = getattr(self, "A_phase1_ema", None)

        def _stats(v):
            """Return min / max / mean / sample-std of a list of floats (NaNs when empty)."""
            import math
            if len(v) == 0: return dict(min=float("nan"), max=float("nan"), mean=float("nan"), std=float("nan"))
            m = sum(v) / len(v)
            var = sum((x - m) ** 2 for x in v) / max(1, len(v) - 1)
            return dict(min=min(v), max=max(v), mean=m, std=math.sqrt(var))

        for epoch in range(1, num_epochs + 1):
            # evaluate_phase1() puts encoder/decoder into eval mode; make sure every
            # training epoch actually runs in train mode.
            self.encoder.train()
            self.decoder.train()

            for it, batch in enumerate(self.train_loader):
                masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
                # ------------------------------------------------------
                # 1) Encode inputs -> latent_states [T,B,D] and recon loss
                # ------------------------------------------------------
                latent_states, recon_loss, recon_loss_wzmask, x_rec, x_gt, _ = self._encode_and_recon(batch, return_recon=True)
                T, B, D = latent_states.shape

                # ------------------------------------------------------
                # 2) Closed-form A*, b* (only for loss) on flattened pairs
                # ------------------------------------------------------
                Z0 = latent_states[:-1].reshape(-1, D)  # [N, D]
                Zp = latent_states[ 1:].reshape(-1, D)  # [N, D]
                with torch.no_grad():
                    A_star, b_star = self.solve_Ab_ridge(Z0, Zp, ridge=ridge, add_bias=use_bias)

                rho_Astar = self._spectral_radius(A_star)
                rho_list_epoch.append(rho_Astar)

                # Diagnostic: relative gap between the batch A* and the previous epoch's global A
                if (A_global_prev is not None) and torch.isfinite(A_global_prev).all():
                    with torch.no_grad():
                        denom = torch.norm(A_global_prev) + 1e-12
                        gap_list_epoch.append(float(torch.norm(A_star - A_global_prev) / denom))

                # Latent one-step consistency: Zp ≈ A*Z0 (+b)
                Zp_hat = (Z0 @ A_star.T) + (b_star if use_bias else 0.0)      # [N, D]
                dyn_loss = F.mse_loss(Zp_hat, Zp)
                # ------------------------------------------------------
                # 3) Optional: one-step in observation space
                # ------------------------------------------------------
                if use_pred_loss:
                    a1_hat = Zp_hat.view(T - 1, B, D)                  # [T-1,B,D]
                    x1_hat = self._decode_latent(a1_hat)               # [B,T-1,H,W,C]
                    x1_true = batch["data"][:, self.n_frames_cond:, ...].to(self.device)
                    pred_loss, pred_loss_wzmask = self._mse_loss(x1_hat, x1_true, masks[:, self.n_frames_cond:, ...])
                else:
                    pred_loss, pred_loss_wzmask = torch.tensor(0.0, device=self.device), torch.tensor(0.0, device=self.device)

                # long term prediction loss (lambda_lt_pred is 0.0 unless enabled and configured)
                if lambda_lt_pred > 0.0:
                    ms_loss = self._multistep_latent_consistency(latent_states, A_star, b_star if use_bias else None,
                                                                 H=self.cfg.rollout_steps, gamma=self.cfg.gamma_decay)
                else:
                    ms_loss = torch.tensor(0.0, device=self.device)

                # loss on the frequency domain (lambda_freq is 0.0 unless enabled and configured)
                if lambda_freq > 0.0:
                    xf_hat, xf_true = x_rec, x_gt  # use reconstruction frames
                    fft_loss = self._fft_mse_on_frames(
                        xf_hat, xf_true,
                        use_log_mag=True,
                        hf_power=self.cfg.freq_hf_power
                    )
                    ms_loss_img = self._multiscale_spatial_mse(
                        xf_hat, xf_true,
                        pool_scales=self.cfg.ms_pool_scales
                    )
                    freq_ms_loss = fft_loss + ms_loss_img
                else:
                    freq_ms_loss = torch.tensor(0.0, device=self.device)

                # ------------------------------------------------------
                # 4) Total loss and backward (enc/dec only)
                # ------------------------------------------------------
                loss = recon_loss_wzmask + lambda_dyn * dyn_loss + lambda_pred * pred_loss_wzmask
                loss += lambda_lt_pred * ms_loss
                loss += lambda_freq * freq_ms_loss

                self.optim_enc.zero_grad()
                self.optim_dec.zero_grad()
                loss.backward()
                if self.cfg.max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), self.cfg.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), self.cfg.max_grad_norm)
                self.optim_enc.step()
                self.optim_dec.step()
                if self.cfg.scheduler == "OneCycleLR":
                    # OneCycleLR is stepped per iteration; the other schedulers per epoch (below)
                    self.scheduler_enc.step()
                    self.scheduler_dec.step()
                # ------------------------------------------------------
                # 5) EMA stats for a global A
                # ------------------------------------------------------
                if mode == "ema":
                    self.ema_update_stats(Sxx, Syx, Z0, Zp, ema_beta=ema_beta, add_bias=use_bias)
                if log_every and ((epoch * len(self.train_loader) + it) % log_every == 0):
                    self.logger.info(
                        f"[Phase-I] epoch {epoch:03d} it {it:04d} | "
                        f"rec {recon_loss.item():.8f} | dyn {dyn_loss.item():.8f} | pred {pred_loss.item():.8f} | "
                        f"rec(with mask) {recon_loss_wzmask.item():.8f} | pred(with mask) {pred_loss_wzmask.item():.8f} |"
                        f"dyn(long term) {ms_loss.item():.8f} | freq_ms_loss {freq_ms_loss.item():.8f}"
                    )
            # Epoch end: compute global A (+ b) for Phase-II init/logging
            with torch.no_grad():
                if mode == "fullpass":
                    A_glob, b_glob = self.solve_global_A_fullpass(self.train_loader, ridge=ridge, use_bias=use_bias)
                else:
                    A_glob, b_glob = self.solve_global_A_from_stats(Sxx, Syx, ridge=ridge, D=D, add_bias=use_bias)
                self.A_phase1_ema = A_glob.clone()
                self.b_phase1_ema = b_glob.clone() if b_glob is not None else None
                # Optionally persist for Phase-II (e.g., pH-dense init)
                out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}")
                os.makedirs(out_dir, exist_ok=True)  # do not rely on an earlier step having created it
                torch.save({'A_ema': self.A_phase1_ema, 'b_ema': self.b_phase1_ema}, os.path.join(out_dir, "A_phase1_ema.pt"))
            if self.cfg.scheduler == 'CosineAnnealingLR' or self.cfg.scheduler == 'StepLR':
                self.scheduler_enc.step()
                self.scheduler_dec.step()

            # ===== epoch-end diagnostics print (distributions) =====
            if diag_enable:
                rs = _stats(rho_list_epoch)
                gs = _stats(gap_list_epoch) if len(gap_list_epoch) else None

                diag_str = {
                    "rho(A*_batch)": f"{rs['min']:.4f}/{rs['mean']:.4f}/{rs['max']:.4f}/{rs['std']:.4f}",
                    "gap(A*,A_ema)": (f"{gs['mean']:.3f}" if gs is not None else "NA"),
                    "rho(A_ema)":    (f"{self._spectral_radius(self.A_phase1_ema):.6f}" if hasattr(self, "A_phase1_ema") else "NA"),
                }
                self.logger.info(f"[P1/Diag-Epoch] E={epoch:03d} | "
                                f"rho(A*_batch) min/mean/max/std={diag_str['rho(A*_batch)']} | "
                                f"gap ||A*-A_ema||/||A_ema|| mean={diag_str['gap(A*,A_ema)']} | "
                                f"rho(A_ema)={diag_str['rho(A_ema)']}")

            # reset epoch buffers; the fresh global A is the gap baseline for the next epoch
            rho_list_epoch.clear()
            gap_list_epoch.clear()
            A_global_prev = self.A_phase1_ema

            ###################### Evaluation ######################
            if eval_every is not None and epoch % eval_every == 0:
                if hasattr(self, "train_eval_loader") and self.train_eval_loader is not None:
                    metrics_tr = self.evaluate_phase1(self.train_eval_loader, use_global_A=True)
                    self.logger.info(f"[P1/Eval-TrainEval] rec(mask)={metrics_tr['rec_masked']:.8f} | "
                                    f"dyn={metrics_tr['dyn_mse']:.8f} | pred(mask)={metrics_tr['pred_masked']:.8f} | "
                                    f"diag={metrics_tr['diag']}")
                    # Track the best by masked reconstruction error
                    if metrics_tr["rec_masked"] < best_rec:
                        best_rec = metrics_tr["rec_masked"]
                        best_metrics = {"split": "train_eval", **metrics_tr}
                        self.save_phase1_checkpoint(epoch=epoch, metrics=best_metrics, pth_name="phase1_best_rec.pth")
                if hasattr(self, "test_loader") and self.test_loader is not None:
                    metrics_ts = self.evaluate_phase1(self.test_loader, use_global_A=True)
                    self.logger.info(f"[P1/Eval-Test] rec(mask)={metrics_ts['rec_masked']:.8f} | "
                                    f"dyn={metrics_ts['dyn_mse']:.8f} | pred(mask)={metrics_ts['pred_masked']:.8f} | "
                                    f"diag={metrics_ts['diag']}")

        self.logger.info("[Phase-I] finished. Global A_ema saved to self.A_phase1_ema.")
        self.save_phase1_checkpoint(epoch=None, metrics=best_metrics, pth_name="phase1_final.pth")

    
    @torch.no_grad()
    def evaluate_phase1(self,
                        dataloader=None,
                        *,
                        use_global_A: bool = True,   # use A_ema if available; else fall back to per-batch closed-form
                        use_bias: bool = True,
                        use_pred_loss: bool = True,  # include one-step decoded loss
                        ridge: float = 5e-3):
        """
        Phase-I evaluation without using latent_process.
        Metrics:
        - rec_wzmask: masked recon MSE on frames [nf_cond-1 : T-1]
        - dyn_mse   : latent linear consistency MSE (Z_{t+1} vs A Z_t + b)
        - pred_wzmask (optional): one-step decoded masked MSE using A,b
        - (optional) spectral stats of A_ema if used

        If use_global_A=True and self.A_phase1_ema exists, we use that fixed A (+b) for all batches;
        otherwise we solve closed-form A*, b* per batch.
        """
        self.encoder.eval()
        self.decoder.eval()

        # choose loader
        if dataloader is None:
            # prefer a held-out split if present
            if hasattr(self, "test_loader") and self.test_loader is not None:
                dataloader = self.test_loader
            elif hasattr(self, "train_eval_loader") and self.train_eval_loader is not None:
                dataloader = self.train_eval_loader
            else:
                # fallback to building a test/eval loader
                try:
                    self.build_dataloader(group="test")
                    dataloader = self.test_loader
                except Exception:
                    self.build_dataloader(group="train_eval")
                    dataloader = self.train_eval_loader

        tot_rec_m = 0.0   # masked recon
        tot_dyn   = 0.0   # latent linear consistency
        tot_pred_m = 0.0  # masked one-step decoded
        nsamples  = 0

        # cache global A if requested and available
        A_fix = None
        b_fix = None
        if use_global_A and hasattr(self, "A_phase1_ema"):
            A_fix = self.A_phase1_ema.to(self.device, dtype=torch.float32)
            b_fix = (self.b_phase1_ema.to(self.device, dtype=torch.float32)
                    if getattr(self, "b_phase1_ema", None) is not None else None)

        for batch in dataloader:
            masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
            # 1) encode & recon: latent_states [t_eff,B,D], recon loss on [nf_cond-1:]
            latent_states, recon_loss, recon_loss_wzmask = self._encode_and_recon(batch)  # recon_loss_wzmask is masked MSE
            T_eff, B, D = latent_states.shape
            nsamples += B

            # 2) closed-form A,b (per-batch) OR fixed A_ema,b_ema
            Z0 = latent_states[:-1].reshape(-1, D)  # [N, D], N=(T_eff-1)*B
            Zp = latent_states[ 1:].reshape(-1, D)  # [N, D]
            if A_fix is None:
                A_use, b_use = self.solve_Ab_ridge(Z0, Zp, ridge=ridge, add_bias=use_bias)
            else:
                A_use, b_use = A_fix, b_fix

            # 3) latent linear consistency loss: MSE(Zp_hat, Zp)
            Zp_hat = (Z0 @ A_use.T) + (b_use if (use_bias and b_use is not None) else 0.0)
            dyn_mse = F.mse_loss(Zp_hat, Zp)

            # 4) optional: one-step decoded loss in observation space
            if use_pred_loss:
                a1_hat = Zp_hat.view(T_eff - 1, B, D)                # [T_eff-1, B, D]
                x1_hat = self._decode_latent(a1_hat)                 # [B, T_eff-1, H, W, C]
                x1_true = batch["data"][:, self.n_frames_cond:, ...].to(self.device)  # next frames [nf_cond : T-1]
                pred_mse, pred_mse_wzmask = self._mse_loss(x1_hat, x1_true, masks[:, self.n_frames_cond:, ...])
            else:
                pred_mse_wzmask = torch.tensor(0.0, device=self.device)

            # accumulate (weight by batch size for fair averaging)
            tot_rec_m += float(recon_loss_wzmask) * B
            tot_dyn   += float(dyn_mse)            * B
            tot_pred_m += float(pred_mse_wzmask)   * B

        # finalize averages
        rec_masked = tot_rec_m / max(1, nsamples)
        dyn_avg    = tot_dyn   / max(1, nsamples)
        pred_masked= tot_pred_m/ max(1, nsamples)

        # optional: spectral diagnostics when using A_ema
        diag = {}
        if A_fix is not None:
            try:
                eigvals = torch.linalg.eigvals(A_fix).detach().cpu()
                diag["rho(A_ema)"] = float(eigvals.abs().max())       # discrete spectral radius
            except Exception:
                pass

        return {
            "rec_masked": rec_masked,
            "dyn_mse": dyn_avg,
            "pred_masked": pred_masked if use_pred_loss else None,
            "use_global_A": (A_fix is not None),
            "diag": diag
        }


    @torch.no_grad()
    def _phase1_ckpt_payload(self, epoch: int | None, metrics: dict | None = None):
        """
        Pack a portable Phase-I checkpoint. Everything is moved to CPU tensors
        so the file can be loaded on any device later.
        """
        def sd_cpu(module: nn.Module):
            return {k: v.detach().cpu() for k, v in module.state_dict().items()}

        payload = {
            "epoch": epoch,
            "run_id": getattr(self, "run_id", None),
            "model_cfg": getattr(self, "model_cfg", None),
            "encoder_state": sd_cpu(self.encoder),
            "decoder_state": sd_cpu(self.decoder),
            # Stash Phase-I linear stats if available
            "A_ema": (self.A_phase1_ema.detach().cpu()
                    if hasattr(self, "A_phase1_ema") and self.A_phase1_ema is not None else None),
            "b_ema": (self.b_phase1_ema.detach().cpu()
                    if hasattr(self, "b_phase1_ema") and self.b_phase1_ema is not None else None),
            "metrics": metrics,
        }
        return payload


    def save_phase1_checkpoint(self, epoch: int | None, metrics: dict | None, pth_name: str):
        """
        Write a Phase-I checkpoint file under this run's output directory.
        """
        out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}")
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, pth_name)
        torch.save(self._phase1_ckpt_payload(epoch, metrics), path)
        if hasattr(self, "logger") and self.logger is not None:
            self.logger.info(f"[Phase-I] checkpoint saved: {path}")


    def load_phase1_ckpt(
        self,
        path: str,
        *,
        restore_modules: bool = True,
        init_linear: bool = True,
        clip_positive_symmetric: bool = False,
        max_pos_real: float = 0.0,
        eps_eye: float = 1e-9,
        strict: bool = True,
    ):
        """
        Load a Phase-I checkpoint and optionally:
        1) restore encoder/decoder weights,
        2) initialize the latent linear dynamics from Ad (=A_ema),
        3) compute the fixed-point center z* from (I - Ad) z* = b_ema for de-biasing.

        Args
        ----
        path : str
            Path to a Phase-I checkpoint saved by `save_phase1_checkpoint`.
        restore_modules : bool
            If True, load encoder/decoder from ckpt["encoder_state"/"decoder_state"].
        init_linear : bool
            If True, read Ad (= ckpt["A_ema"]) and call
            `self.latent_process.init_linear_from_Ad(Ad, dt, ...)`.
        clip_positive_symmetric, max_pos_real : see your latent_process.init_linear_from_Ad.
        eps_eye : float
            Small Tikhonov regularizer for solving (I - Ad) z* = b when near-singular.
        strict : bool
            Passed to load_state_dict for encoder/decoder.

        Returns
        -------
        info : dict
            {
            "Ad": torch.Tensor | None,
            "b":  torch.Tensor | None,
            "z_star": torch.Tensor | None,
            "dt": float | None,
            "rho_Ad": float
            }
        """
        # ---- 0) Load checkpoint on CPU first (portable), then move what we need. ----
        ckpt = torch.load(path, map_location="cpu")

        # ---- 1) Optionally restore encoder/decoder weights. ----
        if restore_modules:
            enc_state = ckpt.get("encoder_state", None)
            dec_state = ckpt.get("decoder_state", None)
            if enc_state is not None:
                self.encoder.load_state_dict(enc_state, strict=strict)
            if dec_state is not None:
                self.decoder.load_state_dict(dec_state, strict=strict)

        Ad = None
        b = None
        z_star = None
        rho = float("nan")

        # ---- 2) Optionally initialize continuous-time linear dynamics from Ad. ----
        if init_linear:
            Ad_cpu = ckpt.get("A_ema", None)
            if Ad_cpu is None:
                raise KeyError("[load_phase1_ckpt] 'A_ema' is missing in the checkpoint.")
            Ad = Ad_cpu.to(self.device, dtype=torch.float32)

            # Basic shape checks
            if Ad.dim() != 2 or Ad.size(0) != Ad.size(1):
                raise ValueError(f"[load_phase1_ckpt] Ad must be square, got {tuple(Ad.shape)}")
            if Ad.size(0) != self.latent_dim:
                raise ValueError(f"[load_phase1_ckpt] Ad dim {Ad.size(0)} != latent_dim {self.latent_dim}")

            if self.latent_mode == "continuous":
                # Initialize continuous-time drift from Ad
                self.latent_process.init_linear_from_Ad(
                    Ad, self.dt_eval,
                    clip_positive_symmetric=clip_positive_symmetric,
                    max_pos_real=max_pos_real,
                )
            elif self.latent_mode == "discrete":
                print(Ad.shape)
                print(self.latent_process.A.shape)
                with torch.no_grad():
                    self.latent_process.A.copy_(Ad)

            # Try to compute bias center if b_ema exists
            b_cpu = ckpt.get("b_ema", None)
            if b_cpu is not None:
                b = b_cpu.to(self.device, dtype=torch.float32).view(-1)
                if b.numel() != self.latent_dim:
                    raise ValueError(f"[load_phase1_ckpt] b dim {b.numel()} != latent_dim {self.latent_dim}")

                # Solve (I - Ad) z* = b with small Tikhonov for stability
                D = Ad.size(0)
                I = torch.eye(D, device=self.device, dtype=Ad.dtype)
                M = I - Ad
                M_reg = M + eps_eye * I
                try:
                    z_star = torch.linalg.solve(M_reg, b)
                except RuntimeError:
                    z_star = torch.linalg.pinv(M_reg) @ b

                # Cache for later centering use
                self.latent_center = z_star.detach().clone()
            else:
                self.latent_center = None

            # Keep copies for diagnostics/exports
            self.Ad_phase1 = Ad.detach().clone()
            self.b_phase1 = b.detach().clone() if b is not None else None

            # Spectral radius of Ad (discrete)
            try:
                rho = torch.max(torch.abs(torch.linalg.eigvals(Ad))).item()
            except Exception:
                rho = float("nan")

        # ---- 3) Logging ----
        if hasattr(self, "logger") and self.logger is not None:
            self.logger.info(
                f"[load_phase1_ckpt] from {path} | "
                f"restore_modules={restore_modules} | init_linear={init_linear} | "
                f"bias={'yes' if b is not None else 'no'} | "
                f"center={'yes' if z_star is not None else 'no'} | "
                f"rho(Ad)={rho:.4f}"
            )

        return {"Ad": Ad, "b": b, "z_star": z_star, "rho_Ad": rho}


    """def plot_Ad_spectrum(
        self,
        Ad: torch.Tensor | None = None,
        ckpt_path: str | None = None,
        key: str = "A_ema",
        save_dir: str | None = None,
        fname_prefix: str = "Ad_spectrum"
    ):
        # -------- 1) Locate / validate Ad --------
        if Ad is None:
            if ckpt_path is not None:
                ckpt = torch.load(ckpt_path, map_location="cpu")
                if key not in ckpt:
                    raise KeyError(f"Key '{key}' not found in checkpoint: {ckpt_path}")
                Ad = ckpt[key]
            elif hasattr(self, "Ad_phase1") and (self.Ad_phase1 is not None):
                Ad = self.Ad_phase1
            elif hasattr(self, "A_phase1_ema") and (self.A_phase1_ema is not None):
                Ad = self.A_phase1_ema
            else:
                raise ValueError("Ad is None and neither ckpt nor cached Ad found.")

        Ad = Ad.detach().to("cpu", dtype=torch.float64)
        if Ad.dim() != 2 or Ad.size(0) != Ad.size(1):
            raise ValueError(f"Ad must be square [D,D], got {tuple(Ad.shape)}")
        D = Ad.size(0)

        # -------- 2) Eigen-decomposition (discrete spectrum) --------
        # Use torch for complex eigvals, then convert to numpy
        lam = torch.linalg.eigvals(Ad).cpu().numpy()  # complex128
        mod = np.abs(lam)
        rho = float(mod.max()) if lam.size else float("nan")

        # -------- 3) Complex-plane scatter with unit circle --------
        fig1 = plt.figure()
        ax1 = fig1.gca()
        ax1.scatter(lam.real, lam.imag, s=16, alpha=0.85)  # default matplotlib color
        # Unit circle
        theta = np.linspace(0.0, 2.0 * np.pi, 512)
        ax1.plot(np.cos(theta), np.sin(theta), linestyle="--", linewidth=1.0)
        # Axes & aspect
        ax1.axhline(0.0, linewidth=0.8)
        ax1.axvline(0.0, linewidth=0.8)
        ax1.set_aspect("equal", adjustable="box")
        ax1.set_xlabel("Re(λ)")
        ax1.set_ylabel("Im(λ)")
        ax1.set_title(f"Spec(Ad)  D={D},  ρ(Ad)={rho:.4f}  (unit circle shown)")

        out_paths = []
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            p1 = os.path.join(save_dir, f"{fname_prefix}_complex.png")
            fig1.savefig(p1, dpi=600, bbox_inches="tight")
            out_paths.append(p1)

        # -------- 4) Magnitude histogram (|λ|) --------
        fig2 = plt.figure()
        ax2 = fig2.gca()
        nbins = int(np.clip(D // 2, 10, 60))  # a reasonable bin count
        ax2.hist(mod, bins=nbins)
        ax2.set_xlabel("|λ|")
        ax2.set_ylabel("count")
        ax2.set_title("Spectrum magnitude histogram |λ|")

        if save_dir:
            p2 = os.path.join(save_dir, f"{fname_prefix}_abs_hist.png")
            fig2.savefig(p2, dpi=600, bbox_inches="tight")
            out_paths.append(p2)"""
    

    def plot_Ad_spectrum(
        self,
        Ad: torch.Tensor | None = None,
        ckpt_path: str | None = None,
        key: str = "A_ema",
        save_dir: str | None = None,
        fname_prefix: str = "Ad_spectrum",
        *,
        # visibility knobs (unchanged except defaults you set before)
        s_min: float = 14.0,
        s_max: float = 110.0,
        alpha_min: float = 0.55,
        alpha_max: float = 0.95,
        dist_percentile: float = 95.0,
        dist_gamma: float = 0.7,
        figsize_complex=(6.2, 6.2),
        figsize_hist=(6.2, 3.6),
        hist_bins: int | None = None,
        save_matrix: bool = True,
        save_numpy_copy: bool = True,
        # --- new style knobs (for “four spines + bigger ticks/caption”) ---
        spine_lw: float = 1.2,        # linewidth for all four spines
        tick_labelsize: int = 14,     # axis tick numbers size
        axis_labelsize: int = 15,     # axis label size
        rho_fontsize: int = 15,       # caption font size for ρ(Ad)
        dpi: int = 450
    ):
        """
        Complex plane: hue = arg(λ) (cyclic scientific colormap), size/alpha ∝ |λ−1|.
        Unit circle = solid, thicker; spectral-radius circle dashed; ρ(Ad) annotated.
        No title, bold LaTeX axis labels, no colorbar. Also saves Ad to disk.
        """
        import os, numpy as np, matplotlib.pyplot as plt
        import matplotlib.ticker as mticker
        from matplotlib import rcParams
        from matplotlib.cm import get_cmap
        import torch

        # --- Resolve / validate Ad ---
        if Ad is None:
            if ckpt_path is not None:
                ckpt = torch.load(ckpt_path, map_location="cpu")
                if key not in ckpt:
                    raise KeyError(f"Key '{key}' not found in checkpoint: {ckpt_path}")
                Ad = ckpt[key]
            elif hasattr(self, "Ad_phase1") and (self.Ad_phase1 is not None):
                Ad = self.Ad_phase1
            elif hasattr(self, "A_phase1_ema") and (self.A_phase1_ema is not None):
                Ad = self.A_phase1_ema
            else:
                raise ValueError("Ad is None and neither ckpt nor cached Ad found.")
        Ad = Ad.detach().to("cpu", dtype=torch.float64)
        if Ad.dim() != 2 or Ad.size(0) != Ad.size(1):
            raise ValueError(f"Ad must be square [D,D], got {tuple(Ad.shape)}")
        D = Ad.size(0)

        # Save matrix copies
        out_paths = []
        if save_dir and save_matrix:
            os.makedirs(save_dir, exist_ok=True)
            pt_path = os.path.join(save_dir, f"{fname_prefix}_Ad.pt")
            torch.save(Ad, pt_path); out_paths.append(pt_path)
            if save_numpy_copy:
                npy_path = os.path.join(save_dir, f"{fname_prefix}_Ad.npy")
                np.save(npy_path, Ad.numpy()); out_paths.append(npy_path)

        # --- Spectrum ---
        lam = torch.linalg.eigvals(Ad).cpu().numpy()      # complex128
        if lam.size == 0:
            raise RuntimeError("No eigenvalues found.")
        re, im = lam.real, lam.imag
        mod = np.abs(lam)
        rho = float(mod.max())

        # Hue by phase (cyclic)
        phase = np.angle(lam)                             # [-π, π]
        phase_norm = (phase + np.pi) / (2.0 * np.pi)      # [0, 1]
        try:
            cmap = get_cmap("twilight_shifted")
        except Exception:
            cmap = get_cmap("twilight")
        colors = cmap(phase_norm)                         # RGBA
        colors[:, :3] = np.clip(colors[:, :3] * 0.85, 0, 1)  # slightly darken

        # Size/alpha by |λ−1|
        dist1 = np.abs(lam - 1.0)
        scale = float(np.percentile(dist1, dist_percentile)) or 1.0
        dist_norm = np.clip((dist1 / scale) ** dist_gamma, 0.0, 1.0)
        sizes  = s_min + (s_max - s_min) * dist_norm
        alphas = alpha_min + (alpha_max - alpha_min) * dist_norm
        colors[:, 3] = alphas

        # --- Global style (keep grid; enable LaTeX math) ---
        rcParams.update({
            "axes.grid": True,
            "grid.linestyle": "--",
            "grid.alpha": 0.22,
            "axes.formatter.use_mathtext": True,
        })

        # ======================= Complex plane =======================
        fig1, ax1 = plt.subplots(figsize=figsize_complex)

        # Four spines visible and thickened
        for side in ("left", "right", "top", "bottom"):
            ax1.spines[side].set_visible(True)
            ax1.spines[side].set_linewidth(spine_lw)

        # Larger tick labels (and minor ticks for niceness)
        ax1.tick_params(axis="both", which="major", labelsize=tick_labelsize, direction="out", length=4.5, width=0.9)
        ax1.tick_params(axis="both", which="minor", labelsize=tick_labelsize-1, direction="out", length=3.0, width=0.7)
        ax1.xaxis.set_minor_locator(mticker.AutoMinorLocator(4))
        ax1.yaxis.set_minor_locator(mticker.AutoMinorLocator(4))

        theta = np.linspace(0.0, 2.0 * np.pi, 720)
        # Unit circle
        ax1.plot(np.cos(theta), np.sin(theta),
                linestyle="-", linewidth=2.0, color="#222222", zorder=1)
        # Spectral-radius circle
        ax1.plot(rho * np.cos(theta), rho * np.sin(theta),
                linestyle="--", linewidth=1.6, color="#c44e52", zorder=1)

        # Eigenvalues
        ax1.scatter(re, im,
                    s=sizes, c=colors,
                    linewidths=0.25, edgecolors=(0, 0, 0, 0.18), zorder=3)

        # Axes & limits
        ax1.axhline(0.0, linewidth=0.8, alpha=0.6, color="#4c4c4c")
        ax1.axvline(0.0, linewidth=0.8, alpha=0.6, color="#4c4c4c")
        ax1.set_aspect("equal", adjustable="box")
        R = 1.05 * max(1.0,
                    np.max(np.abs(re)) if re.size else 1.0,
                    np.max(np.abs(im)) if im.size else 1.0,
                    rho)
        ax1.set_xlim([-R, R]); ax1.set_ylim([-R, R])

        # Bold labels; bigger font sizes
        ax1.set_xlabel(r"$\mathbf{Re}(\lambda)$", fontweight="bold", fontsize=axis_labelsize)
        ax1.set_ylabel(r"$\mathbf{Im}(\lambda)$", fontweight="bold", fontsize=axis_labelsize)

        # Spectral radius caption (bigger)
        ax1.text(
            0.02, 0.98, rf"$\rho(\mathbf{{A}}_d) = {rho:.4f}$",
            transform=ax1.transAxes, ha="left", va="top",
            fontsize=rho_fontsize,
            bbox=dict(boxstyle="round,pad=0.25", facecolor="white",
                    edgecolor="#c44e52", alpha=0.95)
        )

        if save_dir:
            p1 = os.path.join(save_dir, f"{fname_prefix}_complex.png")
            fig1.savefig(p1, dpi=dpi, bbox_inches="tight")
            out_paths.append(p1)

        # ======================= |λ| histogram =======================
        fig2, ax2 = plt.subplots(figsize=figsize_hist)

        # Four spines visible and thickened
        for side in ("left", "right", "top", "bottom"):
            ax2.spines[side].set_visible(True)
            ax2.spines[side].set_linewidth(spine_lw)

        # Larger tick labels & minor ticks
        ax2.tick_params(axis="both", which="major", labelsize=tick_labelsize, direction="out", length=4.5, width=0.9)
        ax2.tick_params(axis="both", which="minor", labelsize=tick_labelsize-1, direction="out", length=3.0, width=0.7)
        ax2.xaxis.set_minor_locator(mticker.AutoMinorLocator(4))
        ax2.yaxis.set_minor_locator(mticker.AutoMinorLocator(4))

        nbins = int(np.clip(D // 2, 16, 60)) if hist_bins is None else int(hist_bins)
        counts, edges = np.histogram(mod, bins=nbins)
        ax2.step(edges[:-1], counts, where="post", linewidth=1.6, color="#4c72b0")
        ax2.fill_between(edges[:-1], counts, step="post", alpha=0.15, color="#4c72b0")
        ax2.axvline(1.0, linestyle="-", linewidth=1.2, color="#222222")   # unit radius
        ax2.axvline(rho,  linestyle="--", linewidth=1.2, color="#c44e52") # spectral radius

        ax2.set_xlabel(r"$|\lambda|$", fontweight="bold", fontsize=axis_labelsize)
        ax2.set_ylabel(f"count", fontweight="bold", fontsize=axis_labelsize)

        if save_dir:
            p2 = os.path.join(save_dir, f"{fname_prefix}_abs_hist.png")
            fig2.savefig(p2, dpi=dpi, bbox_inches="tight")
            out_paths.append(p2)

        plt.close(fig1); plt.close(fig2)
        return {"rho": rho, "paths": out_paths}


    def _center_latent(self, z: torch.Tensor) -> torch.Tensor:
        """Subtract the fixed-point center ``self.latent_center`` from ``z``, preserving shape.

        Args:
            z: Tensor whose last dimension is the latent dimension,
               e.g. [D], [B, D] or [T, B, D].

        Returns:
            ``z`` itself when no ``latent_center`` attribute is set (or it is None);
            otherwise a tensor with exactly the same shape as ``z``.

        Raises:
            AssertionError: if the number of elements of the center differs from
                the last dimension of ``z``.
        """
        zc = getattr(self, "latent_center", None)
        if zc is None:
            return z
        # Safety: last dim must match
        assert z.shape[-1] == zc.numel(), f"center dim {zc.numel()} != z last dim {z.shape[-1]}"
        # Flatten the center to [D] so that broadcasting against [..., D] can never
        # add a leading axis (a center stored as [1, D] used to turn a [D] input
        # into [1, D]); reshape also tolerates a non-contiguous center.
        return z - zc.to(z).reshape(-1)


    def _decenter_latent(self, z: torch.Tensor) -> torch.Tensor:
        """Add the fixed-point center ``self.latent_center`` back onto ``z``, preserving shape.

        Args:
            z: Tensor whose last dimension is the latent dimension ([..., D]).

        Returns:
            ``z`` itself when no ``latent_center`` attribute is set (or it is None);
            otherwise a tensor with exactly the same shape as ``z``.

        Raises:
            AssertionError: if the number of elements of the center differs from
                the last dimension of ``z``.
        """
        zc = getattr(self, "latent_center", None)
        if zc is None:
            return z
        assert z.shape[-1] == zc.numel(), f"center dim {zc.numel()} != z last dim {z.shape[-1]}"
        # Flatten the center to [D]: broadcasting then keeps the shape of z even
        # when the center is stored as [1, D], and reshape tolerates a
        # non-contiguous center where view would raise.
        return z + zc.to(z).reshape(-1)
    

    def _encode_and_recon(self, batch, return_recon: bool = False):
        """Encode every delay-stacked window of a batch and score its encode/decode reconstruction.

        Args:
            batch: Mapping providing "data" and "mask" tensors (annotated in this
                file as [B, T, H, W, n_ch]).
            return_recon: When True, additionally return the reconstruction and
                the ground truth / mask slices it was compared against.

        Returns:
            ``(latent_states, recon_loss, recon_loss_wz_mask)`` where
            ``latent_states`` is [t, B, latent_dim] with t = T - n_frames_cond + 1
            and the two losses come from ``self._mse_loss``. With
            ``return_recon=True`` the tuple is extended by ``(x_rec, x_gt, m_rec)``.
        """
        ground_truth = batch["data"].to(self.device)    # [B, T, H, W, n_ch]
        masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
        bs, train_len = ground_truth.shape[:2]
        n_cond = self.n_frames_cond
        n_win = train_len - n_cond + 1                  # delay windows per sequence (t)

        mask_index = self.mask_to_bs_index(mask=masks)    # [B, S]
        mask_index = mask_index.unsqueeze(1).expand(-1, n_win, -1).flatten(0, 1)  # [B*t, S]
        delay_data = delay_stack_last_channel(x=ground_truth, d=n_cond)    # [B, t, ..., n_ch*nf_cond]
        data_in = delay_data.flatten(0, 1).flatten(1, 2)    # [B*t, H*W, n_ch*nf_cond]
        data_in = self.index_points(data_in, mask_index)    # [B*t, S, n_ch*nf_cond]

        # Move the grid to the device BEFORE expanding: copying an expanded
        # (stride-0) tensor across devices would materialise B*t copies of it.
        pos_in = self.pos_feat.to(self.device).flatten(0, 1).unsqueeze(0).expand(data_in.shape[0], -1, -1)
        pos_in = self.index_points(pos_in, mask_index)    # [B*t, S, 2]
        data_in = torch.cat((data_in, pos_in), dim=-1)

        # latent_states: encoded from encoder / dyn_states: predicted using latent ode
        latent_token = self.encoder(data_in, pos_in)
        latent_states = latent_token.reshape((bs, n_win, -1)).permute(1, 0, 2)    # [t, B, latent_dim]

        # The first window ends at frame n_cond-1, so targets start there.
        x_gt = ground_truth[:, n_cond - 1:, ...]
        m_rec = masks[:, n_cond - 1:, ...]
        recon_encdec = self._decode_latent(latent_seqs=latent_states)    # [B, t, H, W, c]
        recon_loss, recon_loss_wz_mask = self._mse_loss(recon_encdec, x_gt, mask=m_rec)
        if return_recon:
            # Return both losses and tensors needed for spectral/multiscale penalties
            return latent_states, recon_loss, recon_loss_wz_mask, recon_encdec, x_gt, m_rec
        return latent_states, recon_loss, recon_loss_wz_mask

    
    @torch.no_grad()    # for evaluation
    def _encode_cond_batch(self, batch):
        """Encode only the first ``n_frames_cond`` frames of each sequence into one latent vector.

        Gradients are disabled by the decorator. ``batch`` must provide "data"
        and "mask" tensors (annotated in this file as [B, T, H, W, n_ch]).
        Returns the encoder output reshaped to [B, -1].
        """
        frames = batch["data"].to(self.device)    # [B, T, H, W, n_ch]
        masks = batch["mask"].to(self.device)     # [B, T, H, W, n_ch]
        n_batch, _, _, _, _ = frames.shape        # unpacking requires a 5-D input
        n_cond = self.n_frames_cond

        point_idx = self.mask_to_bs_index(mask=masks)    # [B, S]
        # A single delay window built from the conditioning frames only.
        stacked = delay_stack_last_channel(x=frames[:, :n_cond, ...], d=n_cond)    # [B, 1, ..., n_ch*nf_cond]
        feats = self.index_points(stacked.flatten(0, 1).flatten(1, 2), point_idx)  # [B*1, S, n_ch*nf_cond]

        coords = self.pos_feat.flatten(0, 1).unsqueeze(0).expand(feats.shape[0], -1, -1).to(self.device)
        coords = self.index_points(coords, point_idx)    # [B*1, S, 2]

        tokens = self.encoder(torch.cat((feats, coords), dim=-1), coords)
        return tokens.reshape((n_batch, -1))    # [B, latent_dim]
        
        
    def _mse_loss(self, data1: torch.Tensor, data2: torch.Tensor, mask: torch.Tensor):
        # datas: [B, T, H, W, s], mask: [B, T, H, W, s] or None
        # check!!!!!!!!!!
        criterion = nn.MSELoss(reduction="none")
        loss = criterion(data1, data2)
        mse = loss.mean()
        
        sqerr = (data1 - data2).pow(2) * mask
        sqerr_sum = sqerr.sum(dim=(2, 3))
        denom = mask.sum(dim=(2, 3)).clamp_(min=1e-6)
        mse_wzmask = (sqerr_sum / denom).mean()
        return mse, mse_wzmask
    

    """def _fft_mse_on_frames(self,
                           x_hat: torch.Tensor,   # [B, T, H, W, C]
                           x_true: torch.Tensor,  # [B, T, H, W, C]
                           *,
                           use_log_mag: bool = True,
                           hf_power: float = 0.0) -> torch.Tensor:
        B, T, H, W, C = x_hat.shape

        # Flatten to [B*T*C, H, W]
        def flat_hw(x):
            return x.permute(0, 1, 4, 2, 3).contiguous().view(B * T * C, H, W)

        Xh = flat_hw(x_hat)
        Xt = flat_hw(x_true)

        # 2D FFT (real-to-complex, orthonormal normalization)
        Fh = torch.fft.rfft2(Xh, norm="ortho")
        Ft = torch.fft.rfft2(Xt, norm="ortho")

        Mh = torch.abs(Fh)
        Mt = torch.abs(Ft)
        if use_log_mag:
            Mh = torch.log(Mh + 1e-6)
            Mt = torch.log(Mt + 1e-6)

        # Optional radial HF emphasis
        if hf_power > 0.0:
            fy = torch.fft.fftfreq(H, d=1.0).to(Xh.device)      # [-.5, .5)
            fx = torch.fft.rfftfreq(W, d=1.0).to(Xh.device)     # [0, .5]
            gy = fy[:, None].expand(H, fx.numel())
            gx = fx[None, :].expand(H, fx.numel())
            r  = torch.sqrt(gx * gx + gy * gy)
            r  = (r / (r.max() + 1e-12)).pow(hf_power)          # [H, W_rfft]
            Mh = Mh * r
            Mt = Mt * r

        return F.mse_loss(Mh, Mt)"""

    
    def _fft_mse_on_frames(
        self,
        x_hat: torch.Tensor,   # [B, T, H, W, C]
        x_true: torch.Tensor,  # [B, T, H, W, C]
        *,
        use_log_mag: bool = True,
        hf_power: float = 0.0
    ) -> torch.Tensor:
        """
        MSE between the normalised 2D rFFT magnitude spectra of two frame stacks.

        Pipeline applied to both inputs: reshape to [B*T*C, H, W], subtract each
        frame's spatial mean, orthonormal rfft2, magnitude, optional log1p,
        frequency weighting (radial weight raised to ``hf_power`` when
        ``hf_power > 0``, otherwise ones; the DC bin is always zeroed), division
        by the per-frame mean of the weighted spectrum. The result is the scalar
        ``F.mse_loss`` of the two normalised spectra.
        """
        B, T, H, W, C = x_hat.shape

        def centred_frames(x: torch.Tensor) -> torch.Tensor:
            # [B, T, H, W, C] -> [B*T*C, H, W], then remove the per-frame mean.
            frames = x.permute(0, 1, 4, 2, 3).contiguous().view(B * T * C, H, W)
            return frames - frames.mean(dim=(1, 2), keepdim=True)

        def magnitude(frames: torch.Tensor) -> torch.Tensor:
            mag = torch.abs(torch.fft.rfft2(frames, norm="ortho"))    # [B*T*C, H, W_r]
            return torch.log1p(mag) if use_log_mag else mag

        def unit_mean(spec: torch.Tensor) -> torch.Tensor:
            # Divide each frame's spectrum by its own mean (eps avoids 0/0).
            return spec / (spec.mean(dim=(1, 2), keepdim=True) + 1e-6)

        frames_hat = centred_frames(x_hat)
        frames_true = centred_frames(x_true)
        mag_hat = magnitude(frames_hat)
        mag_true = magnitude(frames_true)

        if hf_power > 0.0:
            # Radial frequency, normalised to [0, 1], raised to hf_power.
            gy = torch.fft.fftfreq(H, d=1.0).to(frames_hat.device, frames_hat.dtype)[:, None]    # [H, 1]
            gx = torch.fft.rfftfreq(W, d=1.0).to(frames_hat.device, frames_hat.dtype)[None, :]   # [1, W_r]
            radius = torch.sqrt(gx * gx + gy * gy)                                               # [H, W_r]
            weight = (radius / (radius.max() + 1e-12)).pow(hf_power)
        else:
            weight = frames_hat.new_ones((H, mag_hat.size(-1)))                                  # [H, W_r]
        # The weight tensor is freshly created above, so zeroing the DC bin in
        # place cannot touch any caller-visible data.
        weight[0, 0] = 0.0

        return F.mse_loss(unit_mean(mag_hat * weight), unit_mean(mag_true * weight))


    def _multiscale_spatial_mse(self,
                                x_hat: torch.Tensor,  # [B, T, H, W, C]
                                x_true: torch.Tensor, # [B, T, H, W, C]
                                *,
                                pool_scales: tuple[int, ...] = (2, 4)) -> torch.Tensor:
        """
        Average of MSEs computed on average-pooled copies of the two inputs.

        Each scale ``s`` pools H and W with kernel = stride = ``s``; scales
        larger than H or W are skipped. Returns a zero scalar when
        ``pool_scales`` is empty or every scale was skipped.
        """
        zero = x_hat.new_zeros(())
        if not pool_scales:
            return zero

        B, T, H, W, C = x_hat.shape

        def as_images(x: torch.Tensor) -> torch.Tensor:
            # One single-channel image per (batch, time, channel): [BTC, 1, H, W]
            return x.permute(0, 1, 4, 2, 3).contiguous().view(B * T * C, 1, H, W)

        pred, target = as_images(x_hat), as_images(x_true)

        per_scale = []
        for s in pool_scales:
            if min(H // s, W // s) < 1:
                continue    # pooling window would not fit
            pooled_pred = F.avg_pool2d(pred, kernel_size=s, stride=s, ceil_mode=False, count_include_pad=False)
            pooled_target = F.avg_pool2d(target, kernel_size=s, stride=s, ceil_mode=False, count_include_pad=False)
            per_scale.append(F.mse_loss(pooled_pred, pooled_target))

        if not per_scale:
            return x_hat.new_zeros(())
        return sum(per_scale, zero) / len(per_scale)


    def _decode_latent(self, latent_seqs: torch.Tensor):
        """
        Run the decoder selected by ``self.dec_mode`` on a latent sequence [T', B, D].

        - "fouriermlp": the decoder is called with the flattened position grid and
          the latents flattened to [B*T', D]; its output is reshaped to
          [B, T', *self.shapelist, self.state_dim].
        - "fouriernet": the latents are viewed as [B, T', state_dim, code_dim] and
          the decoder output is returned unchanged.

        Raises ValueError for any other ``dec_mode``.
        """
        n_steps, n_batch, _ = latent_seqs.shape
        batch_first = latent_seqs.permute(1, 0, 2)    # [B, t, latent_dim]

        if self.dec_mode == "fouriermlp":
            coord_dim = self.pos_feat.shape[-1]
            query_pts = self.pos_feat.reshape(-1, coord_dim).to(self.device)    # [N_pt, grid_dim]
            fields = self.decoder(grid=query_pts, latent_feat=batch_first.flatten(0, 1))    # [B*t, N_pts, out_dim]
            return fields.reshape(n_batch, n_steps, *self.shapelist, self.state_dim)

        if self.dec_mode == "fouriernet":
            codes = batch_first.view(n_batch, n_steps, self.state_dim, self.code_dim)    # [B, t, s, code_dim]
            return self.decoder(codes)    # [B, t, H, W, s]

        raise ValueError(f"Unknown dec_mode: {self.dec_mode}")


    @torch.no_grad()
    def print_latent_of_index(self, group: str, seq_index: int, num_steps: int = 8, show_norms: bool = True):
        """Encode one dataset sequence, print its latent trajectory and optional norm statistics.

        Args:
            group: "train", "train_eval" or "test" (selects the loader whose dataset is used).
            seq_index: Position of the sequence inside that dataset.
            num_steps: How many leading steps to print; ``None`` prints all steps.
            show_norms: Also print per-step L2 / L-inf / mean-abs values and
                their min/max/mean over the whole sequence.

        Returns:
            The latent trajectory as a CPU tensor [T', D].
        """
        self._ensure_loader(group)
        loaders = {"train": self.train_loader,
                   "train_eval": self.train_eval_loader,
                   "test": self.test_loader}
        dataset = loaders[group].dataset
        assert 0 <= seq_index < len(dataset), f"seq_index cross the border: 0..{len(dataset)-1}"

        # Wrap the single sequence in a loader so it is collated like a real batch.
        single = DataLoader(Subset(dataset, [seq_index]), batch_size=1, shuffle=False)
        latents, _, _ = self._encode_and_recon(next(iter(single)))   # [T', B=1, D]
        latents = latents[:, 0].detach().cpu()                        # [T', D]

        n_total, n_feat = latents.shape
        print(f"[latent] shape = [T'={n_total}, D={n_feat}]  (group={group}, idx={seq_index})")
        n_show = n_total if num_steps is None else max(0, min(num_steps, n_total))
        if n_show > 0:
            print(latents[:n_show])
        if not show_norms:
            return latents

        # Per-step norms across the feature dimension D.
        l2 = torch.linalg.norm(latents, ord=2, dim=-1)                # [T']
        linf = torch.linalg.norm(latents, ord=float("inf"), dim=-1)   # [T']
        mean_abs = latents.abs().mean(dim=-1)                         # [T']

        # When nothing was printed above, still list up to 8 steps of norms.
        k = n_show if n_show > 0 else min(8, n_total)
        print(f"\nPer-step norms for first {k} steps (L2, L∞, mean|.|):")
        for t, (v_l2, v_linf, v_abs) in enumerate(zip(l2[:k], linf[:k], mean_abs[:k])):
            print(
                f"t={t:02d}: "
                f"L2={float(v_l2):.6f}  "
                f"Linf={float(v_linf):.6f}  "
                f"mean|.|={float(v_abs):.6f}"
            )

        # Summary statistics over the whole sequence.
        print("\nSummary across all steps:")
        for prefix, values in (("  L2:    ", l2), ("  Linf:  ", linf), ("  mean|.|: ", mean_abs)):
            print(
                f"{prefix}"
                f"min={float(values.min()):.6f}  "
                f"max={float(values.max()):.6f}  "
                f"mean={float(values.mean()):.6f}"
            )
        return latents
    

    """@torch.no_grad()
    def visualize_latent_pca2d(
        self,
        group: str,
        n_traj: int = 16,
        max_time: int | None = None,
        save_path: str | None = None,
        time_style: str = "color+arrows",   # "color", "arrows", or "color+arrows"
        arrow_every: int = 5,               # place an arrow every k steps
        show_colorbar: bool = True,         # show colorbar when using color/gradient
    ):
    
        import os
        import numpy as np
        import matplotlib.pyplot as plt
        from matplotlib.collections import LineCollection

        assert time_style in {"color", "arrows", "color+arrows"}

        # ---------- 1) Collect latent sequences ----------
        self._ensure_loader(group)
        loader = {"train": self.train_loader,
                "train_eval": self.train_eval_loader,
                "test": self.test_loader}[group]
        dataset = loader.dataset
        N = len(dataset)
        n_traj = min(n_traj, N)

        from torch.utils.data import Subset, DataLoader
        idxs = np.arange(N)[:n_traj].tolist()   # deterministic; replace with np.random.choice for random pick
        tmp_loader = DataLoader(Subset(dataset, idxs), batch_size=1, shuffle=False, num_workers=0)

        lat_list = []
        for batch in tmp_loader:
            lat, _, _ = self._encode_and_recon(batch)   # [T',1,D]
            lat_list.append(lat[:, 0].detach().cpu())   # [T',D] on CPU

        # ---------- 2) Fit PCA on stacked latents and project to 2D ----------
        X = torch.cat(lat_list, dim=0).float()          # [sum_T', D]
        mu = X.mean(dim=0, keepdim=True)                # [1,D]
        Xc = X - mu
        C = Xc.T @ Xc / max(1, Xc.size(0) - 1)          # [D,D]
        vals, vecs = torch.linalg.eigh(C)               # ascending
        W = vecs[:, -2:]                                 # take top-2 eigenvectors -> [D,2]

        proj_list = []
        T_plot_list = []
        for lat in lat_list:
            Z = (lat.float() - mu).mm(W)                 # [T',2]
            if max_time is not None:
                Z = Z[:max_time]
            proj_list.append(Z.numpy())
            T_plot_list.append(len(Z))

        # For color normalization across different lengths
        Tmax = max(T_plot_list) if T_plot_list else 0

        # ---------- 3) Plot ----------
        fig, ax = plt.subplots()
        used_color = ("color" in time_style)

        for Z in proj_list:
            if len(Z) < 2:
                # Not enough points to draw a polyline; scatter start/end only
                ax.scatter(Z[0, 0], Z[0, 1], s=30)
                continue

            # (a) color-gradient polyline for time direction
            if used_color:
                # Build line segments [(p0->p1), (p1->p2), ...]
                segments = np.stack([Z[:-1], Z[1:]], axis=1)  # [T-1, 2, 2]
                lc = LineCollection(
                    segments,
                    array=np.arange(segments.shape[0]),  # color by step index
                    cmap="viridis",                      # default colormap
                    linewidth=1.5,
                )
                ax.add_collection(lc)
            else:
                # Regular polyline without gradient
                ax.plot(Z[:, 0], Z[:, 1], linewidth=1.0, alpha=0.9)

            # (b) optional arrows every k steps (avoid clutter)
            if ("arrows" in time_style) and (len(Z) >= 2) and (arrow_every is not None) and (arrow_every > 0):
                idx = np.arange(0, len(Z) - 1, arrow_every)
                for i in idx:
                    ax.annotate(
                        "", xy=(Z[i + 1, 0], Z[i + 1, 1]), xytext=(Z[i, 0], Z[i, 1]),
                        arrowprops=dict(arrowstyle="->", lw=0.8, alpha=0.9)
                    )

            # (c) mark start/end
            ax.scatter(Z[0, 0],  Z[0, 1],  s=32)  # start
            ax.scatter(Z[-1, 0], Z[-1, 1], s=32, marker="x")  # end

        # Single colorbar keyed to "time step" (0 .. Tmax-2) if using color
        if used_color and show_colorbar and Tmax >= 2:
            import matplotlib as mpl
            norm = mpl.colors.Normalize(vmin=0, vmax=Tmax - 2)
            sm = mpl.cm.ScalarMappable(cmap="viridis", norm=norm)
            sm.set_array([])
            cb = fig.colorbar(sm, ax=ax)
            cb.set_label("time step")

        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        ax.set_title(f"PCA(α_t)  group={group}, trajectories={len(proj_list)}")
        ax.axhline(0.0, linewidth=0.5)
        ax.axvline(0.0, linewidth=0.5)
        ax.set_aspect("equal", adjustable="box")

        if save_path:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            fig.savefig(save_path, dpi=200, bbox_inches="tight")
            print("Saved latent PCA plot to:", save_path)

        plt.close(fig)"""


    @torch.no_grad()
    def visualize_latent_pca2d(
        self,
        group: str,
        n_traj: int = 16,
        max_time: int | None = None,
        save_path: str | None = None,
        time_style: str = "color+arrows",   # "color", "arrows", or "color+arrows"
        arrow_every: int = 5,
        show_colorbar: bool = True,
        center: str = "none",               # "none" | "zstar" | "dataset"
        overlay_compare: bool = False,      # draw un-centered and centered trajectories together
    ):
        """Plot latent trajectories of a split in the plane of their top-2 principal components.

        The first ``n_traj`` sequences of the split are encoded one at a time
        with ``_encode_and_recon``; a PCA basis is fitted on all collected latent
        vectors and every trajectory is projected onto it and drawn.

        Args:
            group: "train", "train_eval" or "test".
            n_traj: Number of leading dataset sequences to plot.
            max_time: If given, keep only the first ``max_time`` steps of each trajectory.
            save_path: If given, the figure is written there (dpi=200).
            time_style: "color" draws a time-coloured polyline, "arrows" adds
                direction arrows, "color+arrows" does both.
            arrow_every: Place an arrow every this many steps (<= 0 / None disables).
            show_colorbar: Add a "time step" colorbar when colouring by time.
            center: Offset subtracted from each trajectory before projection:
                "none", "dataset" (mean of the collected latents) or "zstar"
                (``self.latent_center``; silently un-centered if unavailable or
                of the wrong size, after a printed warning).
            overlay_compare: With ``center != "none"``, draw the un-centered
                trajectory (dashed) together with the centered one (solid).

        Returns:
            None. The figure is always closed before returning.
        """
        import matplotlib as mpl
        from matplotlib.collections import LineCollection

        assert time_style in {"color", "arrows", "color+arrows"}
        assert center in {"none", "zstar", "dataset"}

        # ---------- 1) Collect a few latent trajectories ----------
        self._ensure_loader(group)
        loader = {"train": self.train_loader, "train_eval": self.train_eval_loader, "test": self.test_loader}[group]
        dataset = loader.dataset
        N = len(dataset)
        n_traj = min(n_traj, N)
        idxs = np.arange(N)[:n_traj].tolist()
        tmp_loader = DataLoader(Subset(dataset, idxs), batch_size=1, shuffle=False, num_workers=0)

        lat_list = []
        for batch in tmp_loader:
            lat, _, _ = self._encode_and_recon(batch)   # [T',1,D]
            z = lat[:, 0].detach().cpu().float()        # [T',D]
            if max_time is not None:
                z = z[:max_time]
            lat_list.append(z)

        # Nothing to fit a PCA on (no sequences, or all truncated to length 0).
        if len(lat_list) == 0 or all(z.size(0) == 0 for z in lat_list):
            print("[visualize_latent_pca2d] no data.")
            return

        # ---------- 2) PCA projection basis ----------
        X = torch.cat(lat_list, dim=0).float()     # [sum_T', D]
        mu_ds = X.mean(dim=0, keepdim=True)        # mean of the collected latents
        Xc = X - mu_ds
        C = Xc.T @ Xc / max(1, Xc.size(0) - 1)
        _, vecs = torch.linalg.eigh(C)             # eigenvalues in ASCENDING order
        # The last column is the largest-variance direction; flip so that
        # column 0 is PC1 and column 1 is PC2, matching the axis labels below.
        W = vecs[:, -2:].flip(-1)

        # ---------- 3) Centering offset ----------
        zstar = getattr(self, "latent_center", None)
        use_zstar = (center == "zstar") and (zstar is not None) and (zstar.numel() == X.shape[1])
        if center == "zstar" and not use_zstar:
            print("[visualize_latent_pca2d] center='zstar' requested but no compatible "
                  "latent_center is available; trajectories are left un-centered.")
        zstar_row = zstar.detach().cpu().float().reshape(1, -1) if use_zstar else None

        def _centerize(Z: torch.Tensor, how: str) -> torch.Tensor:
            if how == "dataset":
                return Z - mu_ds
            if how == "zstar" and use_zstar:
                return Z - zstar_row
            return Z  # "none", or fallback when z* is unusable

        def _project(Z: torch.Tensor):
            # NOTE(review): the projection itself subtracts mu_ds, so
            # center="dataset" removes the dataset mean twice and center="none"
            # is already mean-centered. Kept as in the original; confirm intent.
            return (Z - mu_ds).mm(W).numpy()

        # ---------- 4) Project and draw ----------
        Tmax = max(z.size(0) for z in lat_list)
        # One colour scale for all trajectories so line colours agree with the colorbar.
        time_norm = mpl.colors.Normalize(vmin=0, vmax=max(Tmax - 2, 0))
        do_overlay = overlay_compare and (center != "none")

        fig, ax = plt.subplots()
        try:
            def _plot_one(Z_np, label=None, style=None):
                if len(Z_np) == 0:
                    return
                if len(Z_np) < 2:
                    ax.scatter(Z_np[0, 0], Z_np[0, 1], s=30, label=label)
                    return
                if "color" in time_style:
                    seg = np.stack([Z_np[:-1], Z_np[1:]], axis=1)  # [T-1, 2, 2]
                    lc = LineCollection(seg, array=np.arange(seg.shape[0]), cmap="viridis",
                                        norm=time_norm, linewidth=1.5, linestyles=style or "-")
                    ax.add_collection(lc)
                else:
                    ax.plot(Z_np[:, 0], Z_np[:, 1], linewidth=1.0, alpha=0.9, label=label, linestyle=style or "-")
                if "arrows" in time_style and arrow_every and arrow_every > 0:
                    for i in np.arange(0, len(Z_np) - 1, arrow_every):
                        ax.annotate("", xy=(Z_np[i+1, 0], Z_np[i+1, 1]), xytext=(Z_np[i, 0], Z_np[i, 1]),
                                    arrowprops=dict(arrowstyle="->", lw=0.8, alpha=0.9))
                ax.scatter(Z_np[0, 0], Z_np[0, 1], s=32)                # start
                ax.scatter(Z_np[-1, 0], Z_np[-1, 1], s=32, marker="x")  # end

            for z in lat_list:
                # Both variants are projected with the same PCA basis W.
                Zu = _project(z)
                Zc = _project(_centerize(z, center))
                if do_overlay:
                    _plot_one(Zu, label=None, style="--")  # un-centered (dashed)
                    _plot_one(Zc, label=None, style="-")   # centered (solid)
                else:
                    _plot_one(Zc if center != "none" else Zu, label=None, style="-")

            if "color" in time_style and show_colorbar and Tmax >= 2:
                sm = mpl.cm.ScalarMappable(cmap="viridis", norm=time_norm)
                sm.set_array([])
                cb = fig.colorbar(sm, ax=ax)
                cb.set_label("time step")

            ax.set_xlabel("PC1")
            ax.set_ylabel("PC2")
            ttl = f"PCA(α_t) group={group}, n_traj={len(lat_list)}, center={center}"
            if do_overlay:
                ttl += " (overlay)"
            ax.set_title(ttl)
            ax.axhline(0.0, linewidth=0.5)
            ax.axvline(0.0, linewidth=0.5)
            ax.set_aspect("equal", adjustable="box")

            if save_path:
                # dirname is "" for a bare file name, and os.makedirs("") raises.
                out_dir = os.path.dirname(save_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                fig.savefig(save_path, dpi=200, bbox_inches="tight")
                print("Saved latent PCA plot to:", save_path)
        finally:
            # Release the figure even if drawing or saving failed.
            plt.close(fig)


    @torch.no_grad()
    def save_all_latents(self, group: str, out_dir: str,
                        mode: str = "aggregate",   # "per_sample" or "aggregate"
                        fname: str = "latents_all.pt"):
        """
        Encode every sample of a split and save its latent sequence (alpha_t) to disk.

        Each sample yields a record ``{"alpha": [T', D], "t": [T'], "index": ..., "group": group}``
        where "t" is the batch time axis with the first ``n_frames_cond - 1``
        entries dropped, and "index" is ``dataset.samples[k]`` when the dataset
        has a ``samples`` attribute (``(k, 0)`` if that lookup fails), else None.

        - per_sample: one file ``latent_{k:06d}.pt`` per record inside ``out_dir``.
        - aggregate:  a single file ``fname`` in ``out_dir`` holding the list of records.
        """
        assert mode in {"per_sample", "aggregate"}
        self._ensure_loader(group)
        loaders = {"train": self.train_loader,
                   "train_eval": self.train_eval_loader,
                   "test": self.test_loader}
        dataset = loaders[group].dataset
        os.makedirs(out_dir, exist_ok=True)

        # Unshuffled, batch size 1: iteration order equals dataset order.
        ordered_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

        records = []  # only filled in "aggregate" mode
        for k, batch in enumerate(ordered_loader):
            latents, _, _ = self._encode_and_recon(batch)   # [T',1,D]
            alpha = latents[:, 0].detach().cpu()            # [T',D]

            # Align the time axis with the latents: drop the first nf_cond-1 stamps.
            t_all = batch['t'][0].detach().cpu()            # [T]
            first_kept = self.n_frames_cond - 1
            t_kept = t_all[first_kept:]                     # [T']

            # Sample metadata, if the dataset exposes a `samples` attribute.
            sample_meta = None
            if hasattr(dataset, "samples"):
                try:
                    sample_meta = dataset.samples[k]
                except Exception:
                    sample_meta = (int(k), 0)

            record = {"alpha": alpha, "t": t_kept, "index": sample_meta, "group": group}
            if mode != "per_sample":
                records.append(record)
                continue
            torch.save(record, os.path.join(out_dir, f"latent_{k:06d}.pt"))

        if mode == "aggregate":
            torch.save(records, os.path.join(out_dir, fname))
        print(f"[save_all_latents] done. group={group}, N={len(dataset)}, mode={mode}, out={out_dir}")


    @torch.no_grad()
    def latent_channel_diagnostics(
        self,
        group: str,
        n_traj: int = 64,
        max_time: int | None = None,
        center: str = "none",          # "none" | "zstar" | "dataset"
        save_dir: str | None = None,
        topk: int = 10,
    ):
        """
        Compute per-dimension stats (mean/std/RMS) and inter-channel relation (corr / 1-|corr|).
        Also quantifies how centering changes per-step norms if center="zstar" or "dataset".

        - group:     "train" | "train_eval" | "test"
        - n_traj:    number of sequences to use (each with batch_size=1 for stable masks)
        - max_time:  truncate each sequence if provided
        - center:    "none": as-is; "zstar": subtract self.latent_center; "dataset": subtract dataset-wide mean
        - save_dir:  if given, save heatmaps & bar plots
        - topk:      print top-k most similar/dissimilar channel pairs by |corr|
        """
        import os
        import numpy as np
        import matplotlib.pyplot as plt

        # ---------- 1) Collect latents as [N, D] ----------
        self._ensure_loader(group)
        loader = {"train": self.train_loader,
                "train_eval": self.train_eval_loader,
                "test": self.test_loader}[group]
        dataset = loader.dataset
        Nall = len(dataset)
        n_traj = min(n_traj, Nall)

        from torch.utils.data import Subset, DataLoader
        idxs = np.arange(Nall)[:n_traj].tolist()  # deterministic; change to random if preferred
        tmp_loader = DataLoader(Subset(dataset, idxs), batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

        lat_list = []
        for batch in tmp_loader:
            lat, _, _ = self._encode_and_recon(batch)  # [T',1,D]
            z = lat[:, 0].detach().cpu().float()       # [T',D]
            if max_time is not None:
                z = z[:max_time]
            lat_list.append(z)
        if len(lat_list) == 0:
            print("[latent_channel_diagnostics] No samples collected.")
            return None

        X = torch.cat(lat_list, dim=0)                # [sum_T', D]
        Tsum, D = X.shape

        # ---------- 2) Centering options ----------
        X_raw = X.clone()
        used_center = "none"
        if center == "zstar" and getattr(self, "latent_center", None) is not None:
            zc = self.latent_center.detach().cpu().view(1, -1).to(X)
            if zc.size(1) == D:
                X = X - zc
                used_center = "zstar"
        elif center == "dataset":
            X = X - X.mean(dim=0, keepdim=True)
            used_center = "dataset"

        # ---------- 3) Per-channel stats ----------
        mu = X.mean(dim=0)                             # [D]
        std = X.std(dim=0, unbiased=True).clamp_min(1e-12)   # [D]
        rms = (X.pow(2).mean(dim=0)).sqrt()            # [D]
        energy_total = float(X.pow(2).sum())           # scalar

        # ---------- 4) Corr (Pearson) and a simple "difference" score ----------
        # cov = E[(x-μ)(x-μ)^T]; corr = cov / (σ_i σ_j)
        Xc = X - mu.view(1, -1)
        cov = (Xc.T @ Xc) / max(1, Tsum - 1)          # [D,D]
        denom = std.view(-1, 1) * std.view(1, -1)
        corr = (cov / denom).clamp(min=-1.0, max=1.0) # [D,D]
        diff = 1.0 - corr.abs()                       # "difference" score ∈ [0,1]

        # ---------- 5) Top-k similar / dissimilar pairs ----------
        # Use upper triangle without diagonal
        iu = torch.triu_indices(D, D, offset=1)
        corr_abs = corr.abs()[iu[0], iu[1]]           # [M]
        top_sim_val, top_sim_idx = torch.topk(corr_abs, k=min(topk, corr_abs.numel()))
        top_dis_val, top_dis_idx = torch.topk(-corr_abs, k=min(topk, corr_abs.numel()))  # smallest |corr|

        def pairs_from_indices(idx_tensor):
            pairs = []
            for j in idx_tensor.tolist():
                i0, i1 = iu[0][j].item(), iu[1][j].item()
                pairs.append((i0, i1))
            return pairs

        print(f"[latent_channel_diagnostics] group={group} | D={D} | N(time-steps)={Tsum} | center={used_center}")
        print("Per-dim summary: show first 8 dims (mean/std/RMS):")
        for d in range(min(8, D)):
            print(f"  dim {d:03d}: mean={float(mu[d]):.4e}  std={float(std[d]):.4e}  rms={float(rms[d]):.4e}")

        sim_pairs = pairs_from_indices(top_sim_idx)
        dis_pairs = pairs_from_indices(top_dis_idx)
        print(f"\nTop-{len(sim_pairs)} most similar pairs by |corr|:")
        for (i, j), v in zip(sim_pairs, top_sim_val):
            print(f"  ({i:03d},{j:03d})  |corr|={float(v):.4f}  corr={float(corr[i,j]):.4f}")

        print(f"\nTop-{len(dis_pairs)} most dissimilar pairs by |corr| (smallest |corr|):")
        for (i, j), v in zip(dis_pairs, -top_dis_val):
            print(f"  ({i:03d},{j:03d})  |corr|={float(v):.4f}  corr={float(corr[i,j]):.4f}")

        # ---------- 6) Optional figures ----------
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

            # (a) correlation heatmap
            fig1, ax1 = plt.subplots()
            im1 = ax1.imshow(corr.numpy(), vmin=-1.0, vmax=1.0, cmap="coolwarm", interpolation="none")
            ax1.set_title(f"Latent corr (center={used_center})")
            ax1.set_xlabel("dim")
            ax1.set_ylabel("dim")
            fig1.colorbar(im1, ax=ax1)
            p1 = os.path.join(save_dir, f"latent_corr_{used_center}.png")
            fig1.savefig(p1, dpi=200, bbox_inches="tight")
            plt.close(fig1)

            # (b) difference heatmap (1-|corr|)
            fig2, ax2 = plt.subplots()
            im2 = ax2.imshow(diff.numpy(), vmin=0.0, vmax=1.0, cmap="viridis", interpolation="none")
            ax2.set_title(f"Latent difference (1-|corr|), center={used_center}")
            ax2.set_xlabel("dim")
            ax2.set_ylabel("dim")
            fig2.colorbar(im2, ax=ax2)
            p2 = os.path.join(save_dir, f"latent_diff_{used_center}.png")
            fig2.savefig(p2, dpi=200, bbox_inches="tight")
            plt.close(fig2)

            # (c) per-dim bar of RMS (energy proxy)
            fig3, ax3 = plt.subplots()
            ax3.bar(np.arange(D), rms.numpy())
            ax3.set_title(f"Per-dim RMS, center={used_center}")
            ax3.set_xlabel("dimension")
            ax3.set_ylabel("RMS")
            p3 = os.path.join(save_dir, f"latent_rms_{used_center}.png")
            fig3.savefig(p3, dpi=200, bbox_inches="tight")
            plt.close(fig3)

        # ---------- 7) If we centered, quantify the reduction in per-step energy ----------
        info = {}
        if center in {"zstar", "dataset"}:
            l2_before = torch.linalg.norm(X_raw, dim=1)   # [N]
            l2_after  = torch.linalg.norm(X, dim=1)       # [N]
            red = 1.0 - (l2_after.mean() / l2_before.mean())
            print(f"\nEnergy reduction by centering ({used_center}): "
                f"{100.0*float(red):.2f}%  (mean L2 per step)")
            info["energy_reduction_meanL2"] = float(red)

        # Return a compact dict in case you want to log programmatically
        info.update({
            "center": used_center,
            "per_dim_mean": mu,
            "per_dim_std": std,
            "per_dim_rms": rms,
            "corr": corr,
        })
        return info


    @torch.no_grad()
    def latent_center_report(self, topk: int = 20, save_dir: str | None = None):
        """
        Summarize the learned latent_center (z*), if available:
        - print L2/Linf norms
        - list top-|z*| dimensions
        - optional bar plot
        """
        import os
        import numpy as np
        import matplotlib.pyplot as plt

        zc = getattr(self, "latent_center", None)
        if zc is None:
            print("[latent_center_report] latent_center is None (not set).")
            return None

        z = zc.detach().cpu().view(-1).float()
        D = z.numel()
        l2 = float(torch.linalg.norm(z, ord=2))
        linf = float(torch.linalg.norm(z, ord=float("inf")))
        print(f"[latent_center_report] D={D} | ||z*||_2={l2:.6f} | ||z*||_inf={linf:.6f}")

        # Top-|z| dims
        topk = min(topk, D)
        vals, idx = torch.topk(z.abs(), k=topk)
        print(f"Top-{topk} dims by |z*|:")
        for r in range(topk):
            d = int(idx[r])
            print(f"  rank {r+1:02d}: dim={d:03d}  z*={float(z[d]):.6f}  |z*|={float(vals[r]):.6f}")

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            fig, ax = plt.subplots()
            ax.bar(np.arange(D), z.numpy())
            ax.set_title("latent_center (z*) values per dimension")
            ax.set_xlabel("dimension")
            ax.set_ylabel("value")
            p = os.path.join(save_dir, "latent_center_bar.png")
            fig.savefig(p, dpi=200, bbox_inches="tight")
            plt.close(fig)

        return {"z_star": z, "l2": l2, "linf": linf}


    @torch.no_grad()
    def set_whitening_scale(self, s: torch.Tensor | None):
        """Register per-dim whitening scales."""
        if s is None:
            self.whiten_scale = None
            return
        s = s.view(-1).to(device=self.device, dtype=torch.float32)
        assert s.numel() == self.latent_dim
        self.whiten_scale = torch.clamp(s, min=float(self.whiten_clamp))


    def _whiten_latent(self, centered: torch.Tensor) -> torch.Tensor:
        """w = S^{-1}(z - z*); shape-preserving."""
        s = getattr(self, "whiten_scale", None)
        if (not self.use_diag_whiten) or (s is None):
            return centered
        view = [1]*(centered.ndim - 1) + [-1]
        return centered / s.view(*view)


    def _unwhiten_latent(self, w: torch.Tensor) -> torch.Tensor:
        """(z - z*) = S w; shape-preserving."""
        s = getattr(self, "whiten_scale", None)
        if (not self.use_diag_whiten) or (s is None):
            return w
        view = [1]*(w.ndim - 1) + [-1]
        return w * s.view(*view)


    @torch.no_grad()
    def fit_diag_whitening_from_phase1(self, group: str = "train_eval", max_batches: int | None = 16):
        """Estimate per-dim RMS of centered latents from Phase-I encoder outputs.

        Args:
            group: which loader to read, one of "train", "train_eval", "test".
            max_batches: stop after this many batches; None consumes the loader.

        Returns:
            The per-dim scale vector, which is also registered through
            ``set_whitening_scale``.

        Raises:
            RuntimeError: if the selected loader is None or yields no samples.
        """
        self._ensure_loader(group)
        loaders_by_group = {
            "train": self.train_loader,
            "train_eval": self.train_eval_loader,
            "test": self.test_loader,
        }
        loader = loaders_by_group[group]
        if loader is None:
            raise RuntimeError(f"[fit_diag_whitening_from_phase1] loader for {group} is None")

        self.encoder.eval()
        # accumulate the sum of squares in float64
        sq_sum = torch.zeros(self.latent_dim, device=self.device, dtype=torch.float64)
        n_rows = 0
        for batch_no, batch in enumerate(loader):
            budget_spent = max_batches is not None and batch_no >= max_batches
            if budget_spent:
                break
            latents, _, _ = self._encode_and_recon(batch)                # [T',B,D]
            rows = self._center_latent(latents).reshape(-1, self.latent_dim).to(torch.float64)
            sq_sum += rows.square().sum(dim=0)
            n_rows += rows.shape[0]

        if n_rows == 0:
            raise RuntimeError("[fit_diag_whitening_from_phase1] no samples")

        mean_square = (sq_sum / n_rows).to(torch.float32)
        # floor the mean square before the sqrt, then floor the resulting RMS
        scale = mean_square.clamp(min=self.whiten_eps).sqrt()
        scale = scale.clamp_min(self.whiten_clamp)
        self.set_whitening_scale(scale)
        return scale


    @torch.no_grad()
    def _reinit_linear_in_whiten_space(self):
        """Apply Aw = S^{-1} Ad S and re-init latent_process linear skeleton."""
        if (not self.use_diag_whiten) or (self.whiten_scale is None):
            return
        if not hasattr(self, "Ad_phase1") or self.Ad_phase1 is None:
            return
        Ad = self.Ad_phase1.to(self.device, dtype=torch.float32)  # [D,D]
        s  = self.whiten_scale.to(self.device, dtype=Ad.dtype)    # [D]
        # Aw = S^{-1} Ad S  (column scale by s, then row divide by s)
        Aw = (Ad * s.view(1, -1)) / s.view(-1, 1)
        if self.latent_mode == "continuous":
            self.latent_process.init_linear_from_Ad(
                Aw, self.dt_eval, clip_positive_symmetric=False, max_pos_real=0.0
            )
        elif self.latent_mode == "discrete":
            self.latent_process.A.copy_(Aw)
        self.Ad_phase1_whiten = Aw.detach().clone()

    
    def _get_U(self) -> torch.Tensor:
        """Return orthonormal U [D,d] on correct device/dtype."""
        if self.U_proj is None:
            raise RuntimeError("U_proj is None. Train or load projector first.")
        p0 = next(self.latent_process.parameters(), None)
        dtype = p0.dtype if p0 is not None else torch.float32
        return self.U_proj.to(device=self.device, dtype=dtype)


    def _project_latent(self, w: torch.Tensor) -> torch.Tensor:
        """w [..., D] -> y [..., d] via U."""
        U = self._get_U()
        return torch.matmul(w, U)  # broadcast matmul: (...,D) x (D,d) -> (...,d)


    def _lift_latent(self, y: torch.Tensor) -> torch.Tensor:
        """y [..., d] -> w [..., D] via U^T."""
        U = self._get_U()
        return torch.matmul(y, U.t())


    @torch.no_grad()
    def set_projector(self, U: torch.Tensor):
        """Register an orthonormal projector U [D,d]."""
        U = U.to(device=self.device, dtype=torch.float32)
        Q, _ = torch.linalg.qr(U)
        d = U.size(1)
        self.U_proj = Q[:, :d].contiguous()
        self.use_projector = True
        self.latent_dim_y = d


    def save_projector(self, path: str):
        import os
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({"U_proj": (self.U_proj.detach().cpu() if self.U_proj is not None else None)}, path)


    def load_projector(self, path: str):
        payload = torch.load(path, map_location="cpu")
        U = payload.get("U_proj", None)
        if U is None:
            raise KeyError(f"'U_proj' missing in {path}")
        self.set_projector(U)


    def train_projector_from_phase1_ckpt(self, phase1_path: str, d: int,
                                         epochs: int = 3, lr: float = 1e-3, lr_dec: float = 0.0,
                                         lambda_dyn: float = 0.1, lambda_ortho: float = 1e-2,
                                         log_every: int | None = None, eval_every: int | None = None
                                         ):
        """
        Learn an orthonormal projector U [D,d] on top of the Phase-I model, then
        rebuild ``self.latent_process`` in the projected d-dimensional space.

        Steps, in order:
          1. load the Phase-I checkpoint and freeze the encoder;
          2. if ``self.use_diag_whiten``, fit the diagonal whitening scale and
             re-initialise the linear operator in whitened coordinates;
          3. optimise U with Adam on
             ``rec + lambda_dyn * dyn + lambda_ortho * ortho`` over the train loader;
          4. store the projector on ``self`` and save it together with the reduced
             operator ``U^T Aw U`` to ``<cfg.out_dir>/<run_id>/U_proj_d{d}.pt``;
          5. rebuild the latent process with ``code_dim = d // state_dim`` and
             initialise it from the reduced operator.

        Args:
            phase1_path: path handed to ``load_phase1_ckpt``.
            d: projected latent dimension; must satisfy d <= latent_dim and be a
                multiple of ``state_dim``.
            epochs: number of passes over the train loader.
            lr: Adam learning rate for the projector parameter.
            lr_dec: Adam learning rate for the decoder. The decoder is fine-tuned
                only when both this argument and ``cfg.lr_dec`` are non-zero.
            lambda_dyn: weight of the projected one-step dynamics loss.
            lambda_ortho: weight of the ``U^T U = I`` penalty.
            log_every: log the losses every this many iterations (None = never).
            eval_every: accepted for signature compatibility; not used here.
        """
        self.setup_logger()
        self.logger.info("Training low dimensional projectors based on pretrained encoder-decoder models!")
        self.save_repro_artifacts()

        # load model from phase1
        # load encoder-decoder, and linear parameter
        self.load_phase1_ckpt(path=phase1_path, clip_positive_symmetric=False)
        assert self.Ad_phase1 is not None
        Ad = self.Ad_phase1

        # The flag used to come from cfg.lr_dec alone while the optimizer came from
        # the lr_dec argument; with cfg.lr_dec != 0 and lr_dec == 0 the loop then
        # called .zero_grad() on None. Fine-tune only when both agree.
        cfg_wants_dec = (self.cfg.lr_dec != 0)
        arg_wants_dec = (lr_dec != 0)
        train_dec = bool(cfg_wants_dec and arg_wants_dec)
        if cfg_wants_dec != arg_wants_dec:
            self.logger.warning(
                f"cfg.lr_dec={self.cfg.lr_dec} and lr_dec={lr_dec} disagree; decoder stays frozen"
            )
        self.logger.info(f"finetune decoder = {train_dec}")
        set_requires_grad(self.encoder, False)
        set_requires_grad(self.decoder, train_dec)

        # ---- estimate whitening scales (Phase-I encoder) & re-init Aw ----
        if self.use_diag_whiten:
            self.fit_diag_whitening_from_phase1(group="train", max_batches=16)
            self._reinit_linear_in_whiten_space()
            self.logger.info(f"whiten vector: {self.whiten_scale}")
        if self.whiten_scale is not None:
            s = self.whiten_scale.to(self.device)
            # Aw = S^{-1} Ad S
            Aw = (Ad.to(self.device, torch.float32) * s.view(1, -1)) / s.view(-1, 1)
        else:
            Aw = Ad.to(self.device, torch.float32)
        # The dynamics loss below uses the linear operator only (no bias term).

        assert self.data_processor.mode == "interpolation", "Mismatched dataloaders"
        self.build_dataloader(group="train")
        self._save_split_and_samples()

        ######## Build Optimizer ########
        D = self.latent_dim
        assert d <= D, f"d={d} must be <= latent_dim={D}"
        U_param = torch.nn.Parameter(torch.randn(D, d, device=self.device) / (D**0.5))
        optim_proj = torch.optim.Adam([U_param], lr=lr)
        optim_dec = (
            torch.optim.Adam([{'params': self.decoder.parameters(), 'lr': lr_dec}])
            if train_dec else None
        )

        for epoch in range(1, epochs + 1):
            for i, batch in enumerate(self.train_loader):
                latent_states, _, _ = self._encode_and_recon(batch)     # [t, B, latent_dim]
                latent_states_ = self._center_latent(latent_states)
                latent_states_ = self._whiten_latent(latent_states_)

                Q, _ = torch.linalg.qr(U_param)                         # [D,d] (reduced QR)
                U = Q[:, :d]                                            # [D,d]
                eff_states_ = torch.matmul(latent_states_, U)           # [T',B,d]
                latent_hat_ = torch.matmul(eff_states_, U.t())          # [T',B,D]
                latent_hat = self._decenter_latent(self._unwhiten_latent(latent_hat_))
                x_hat = self._decode_latent(latent_hat)                 # [B,T',H,W,C]
                x_gt = batch["data"][:, self.n_frames_cond-1:, ...].to(self.device)
                m_gt = batch["mask"][:, self.n_frames_cond-1:, ...].to(self.device)
                _, rec_loss = self._mse_loss(x_hat, x_gt, m_gt)         # masked MSE

                A_eff = U.t() @ Aw @ U                                  # [d,d]
                eff_next_ = torch.matmul(eff_states_[:-1], A_eff.t())   # [T'-1,B,d]
                dyn_loss = torch.nn.functional.mse_loss(eff_states_[1:], eff_next_)
                ortho_loss = torch.norm(U.t() @ U - torch.eye(d, device=U.device))**2

                loss = rec_loss + lambda_dyn * dyn_loss + lambda_ortho * ortho_loss
                optim_proj.zero_grad()
                if optim_dec is not None:
                    optim_dec.zero_grad()
                loss.backward()
                optim_proj.step()
                if optim_dec is not None:
                    optim_dec.step()

                if log_every is not None and (epoch * len(self.train_loader) + i) % log_every == 0:
                    self.logger.info(f"[Projector] epoch {epoch:03d}/{epochs} | rec(mask)={float(rec_loss):.6f} | dyn={float(dyn_loss):.6f} | ortho={float(ortho_loss):.6f}")

        with torch.no_grad():
            Q, _ = torch.linalg.qr(U_param.data)
            self.U_proj = Q[:, :d].contiguous()
            self.use_projector = True
            self.latent_dim_y = d

        # Reduced operator in the projected space; computed once, saved and reused.
        Ay = self.U_proj.t().to(Aw) @ Aw @ self.U_proj.to(Aw)
        A_eff = Ay.detach().cpu()
        out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}")
        os.makedirs(out_dir, exist_ok=True)
        proj_path = os.path.join(out_dir, f"U_proj_d{d}.pt")
        torch.save({"U_proj": self.U_proj.detach().cpu(),
                    "A_eff": A_eff}, proj_path)
        if hasattr(self, "logger") and self.logger is not None:
            self.logger.info(f"[Projector] saved to: {proj_path}")

        ##### rebuild latent process #####
        from copy import deepcopy
        if self.latent_mode == "continuous":
            lp_cfg = deepcopy(self.model_cfg["latent_process"])
            assert d % self.state_dim == 0
            lp_cfg["code_dim"] = d // self.state_dim
            self.model_cfg["latent_process"] = lp_cfg
            self.latent_process = build_latent_process(model_cfg=lp_cfg).to(self.device)
            self.latent_process.init_linear_from_Ad(Ay, self.dt_eval)
        else:
            lp_cfg = deepcopy(self.model_cfg["latent_process_discrete"])
            assert d % self.state_dim == 0
            lp_cfg["code_dim"] = d // self.state_dim
            self.model_cfg["latent_process_discrete"] = lp_cfg
            self.latent_process = build_latent_process_discrete(model_cfg=lp_cfg).to(self.device)
            with torch.no_grad():
                self.latent_process.A.copy_(Ay)
        if hasattr(self, "logger") and self.logger is not None:
            self.logger.info(f"[Projector] latent_process rebuilt at dim d={d}. Now ready for Phase-2.")


    def _phase2_param_groups(self, lr_mem=1e-3, lr_lin=0.0, wd_mem=0.0, wd_lin=1e-4,):
        lp = self.latent_process
        mem_params, lin_params = [], []

        # ---- linear skeleton ----
        if self.latent_mode == "continuous":
            if getattr(lp, "linear_param", "free") == "free":
                if hasattr(lp, "linear"):
                    lin_params += [lp.linear]
            elif lp.linear_param == "pH_dense":
                for n in ["ph_B", "ph_C", "ph_L", "ph_W", "ph_omega_raw"]:
                    if hasattr(lp, n):
                        lin_params += [getattr(lp, n)]
        elif self.latent_mode == "discrete":
            lin_params += [lp.A]

        # ---- memory correction----
        if self.latent_mode == "continuous":
            if getattr(self, "latent_type", None) in {"linear+memory", "gru_memory", "lstm_memory"}:
                for n in ["memory_encoder", "memory_decoder",
                          "gru_r", "gru_z", "gru_h",
                          "lstm_i", "lstm_f", "lstm_o", "lstm_cand"]:
                    if hasattr(lp, n):
                        mem_params += list(getattr(lp, n).parameters())
                for n in ["_raw_lambda", "_raw_tau_m_inv", "_raw_tau_c_inv", "_raw_tau_h_inv"]:
                    if hasattr(lp, n):
                        mem_params += [getattr(lp, n)]
                if hasattr(lp, "_raw_mem_scale"):
                    mem_params += [lp._raw_mem_scale]
        elif self.latent_mode == "discrete":
            for n in ["memory_encoder", "memory_decoder"]:     
                if hasattr(lp, n):
                    mem_params += list(getattr(lp, n).parameters())  
            if hasattr(lp, "memory"):                           
                mem_params += list(lp.memory.parameters())     
            if hasattr(lp, "_raw_gate"):                       
                mem_params += [lp._raw_gate]    
            for n in ["ode_func", "rnn_cell"]:  
                if hasattr(lp, n):
                    mem_params += list(getattr(lp, n).parameters()) 
            if hasattr(lp, "res_lstm"):
                mem_params += list(lp.res_lstm.parameters()) 

        groups = []
        if mem_params:
            groups.append({"name": "memory", "params": mem_params, "lr": lr_mem, "weight_decay": wd_mem})
        if lin_params:
            groups.append({"name": "linear", "params": lin_params, "lr": lr_lin, "weight_decay": wd_lin})
        return groups


    def init_optim_phase2(self, lr_mem=None, lr_lin=None, lr_dec=None):
        if lr_mem is None:
            lr_mem = getattr(self.cfg, "lr_dyn_mem", self.lr)
        if lr_lin is None:
            lr_lin = getattr(self.cfg, "lr_dyn_lin", 0.0)
        if lr_dec is None:
            lr_dec = getattr(self.cfg, "lr_dec", 0.0)

        groups = self._phase2_param_groups(
            lr_mem=lr_mem, lr_lin=lr_lin,
            wd_mem=0.0, wd_lin=0.0,
        )

        self.optim_dyn = torch.optim.Adam(groups) if groups else None
        # self.optim_enc = torch.optim.Adam([{'params': self.encoder.parameters(), 'lr': lr_encdec}]) if lr_encdec != 0 else None
        self.optim_dec = torch.optim.Adam([{'params': self.decoder.parameters(), 'lr': lr_dec}]) if lr_dec != 0 else None

        # ------- scheduler -------
        self.scheduler_dyn = None
        self.scheduler_enc, self.scheduler_dec = None, None
        if self.optim_dyn is not None:
            """if self.cfg.scheduler == 'OneCycleLR':
                steps_per_epoch = len(self.train_loader)
                max_lr = max(g.get("lr", 0.0) for g in groups)
                self.scheduler_dyn = torch.optim.lr_scheduler.OneCycleLR(
                    self.optim_dyn, max_lr=max_lr,
                    epochs=self.cfg.epochs, steps_per_epoch=steps_per_epoch,
                    pct_start=self.cfg.pct_start
                )"""
            if self.cfg.scheduler == 'CosineAnnealingLR':
                self.scheduler_dyn = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim_dyn, T_max=self.cfg.epochs)
            elif self.cfg.scheduler == 'StepLR':
                self.scheduler_dyn = torch.optim.lr_scheduler.StepLR(
                    self.optim_dyn, step_size=self.cfg.step_size, gamma=self.cfg.gamma
                )
        if lr_dec != 0:
            if self.cfg.scheduler == 'CosineAnnealingLR':
                # self.scheduler_enc = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim_enc, T_max=self.cfg.epochs)
                self.scheduler_dec = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim_dec, T_max=self.cfg.epochs)
            elif self.cfg.scheduler == 'StepLR':
                # self.scheduler_enc = torch.optim.lr_scheduler.StepLR(self.optim_enc, step_size=self.cfg.step_size, gamma=self.cfg.gamma)
                self.scheduler_dec = torch.optim.lr_scheduler.StepLR(self.optim_dec, step_size=self.cfg.step_size, gamma=self.cfg.gamma)


    @torch.no_grad()
    def _compute_memory_lengths(self, latent_states_: torch.Tensor, memory_states: torch.Tensor | None):
        """
        Return a small dict of memory time constants measured in *steps*.
        - latent_states_: [T, B, D]  (centered/whitened version used by latent_process)
        - memory_states : [T, B, Dm] or None
        """
        out = {}
        if not self.use_memory or memory_states is None:
            return out

        lp = self.latent_process
        dt = float(self.dt_eval)

        try:
            if lp.latent_type == "linear+memory":
                # tau = 1 / lambda; report stats in steps
                lam = lp.lambda_diag                      # [Dm]
                tau_steps = (1.0 / (lam + 1e-12)) / dt    # [Dm]
                v = tau_steps.detach().cpu().float()
                out.update({
                    "tau_steps_mean": float(v.mean()),
                    "tau_steps_med":  float(v.median()),
                    "tau_steps_p90":  float(v.quantile(0.9)),
                    "tau_steps_max":  float(v.max()),
                })

            elif lp.latent_type == "gru_memory":
                # Effective tau uses mean update gate z over time & batch
                Dm = lp.memory_dim
                # concat over all (t>0, b)
                concat = torch.cat([latent_states_[1:], memory_states[1:]], dim=-1)
                concat = concat.reshape(-1, lp.latent_dim + Dm)           # [(T-1)*B, D+D_m]
                z = torch.sigmoid(lp.gru_z(concat))                       # [(T-1)*B, Dm]
                z_mean = z.mean(dim=0)                                    # [Dm]
                tau_eff_steps = 1.0 / ((lp.tau_m_inv * z_mean) + 1e-12) / dt
                v = tau_eff_steps.detach().cpu().float()
                out.update({
                    "tau_eff_steps_mean": float(v.mean()),
                    "tau_eff_steps_p90":  float(v.quantile(0.9)),
                    "tau_raw_steps_mean": float((1.0 / (lp.tau_m_inv + 1e-12) / dt).mean().item()),
                })

            elif lp.latent_type == "lstm_memory":
                # Use forget gate f to estimate leakage; larger f -> longer memory.
                Dm = lp.memory_dim
                concat = torch.cat([latent_states_[1:], memory_states[1:]], dim=-1)
                concat = concat.reshape(-1, lp.latent_dim + Dm)           # [(T-1)*B, D+D_m]
                f = torch.sigmoid(lp.lstm_f(concat))                      # [(T-1)*B, Dm]
                leak = (1.0 - f.mean(dim=0)).clamp_min(1e-4)              # [Dm], avoid div-by-zero
                tau_eff_steps = 1.0 / ((lp.tau_c_inv * leak) + 1e-12) / dt
                v = tau_eff_steps.detach().cpu().float()
                out.update({
                    "tau_c_eff_steps_mean": float(v.mean()),
                    "tau_c_eff_steps_p90":  float(v.quantile(0.9)),
                    "tau_c_raw_steps_mean": float((1.0 / (lp.tau_c_inv + 1e-12) / dt).mean().item()),
                })
        except Exception:
            # Keep training robust even if shape/branch mismatch happens
            pass

        return out


    def train_phase2(self, phase1_path: str, log_every: int | None = None, eval_every: int | None = None, 
                     verbose: str = None):
        self.setup_logger()
        self.save_repro_artifacts()
        self.log_param_table()
        criterion = nn.MSELoss()

        # load model from phase1
        # load encoder-decoder, and linear parameter
        if not self.use_projector:
            phase1_info = self.load_phase1_ckpt(path=phase1_path, clip_positive_symmetric=False)
        # train_encdec = (self.cfg.lr_encdec != 0)
        train_dec = (self.cfg.lr_dec != 0)
        self.logger.info(f"finetune decoder = {train_dec}")
        # set_requires_grad(self.encoder, True) if train_encdec else set_requires_grad(self.encoder, False)    ##################
        set_requires_grad(self.encoder, False)
        set_requires_grad(self.decoder, True) if train_dec else set_requires_grad(self.decoder, False)    ##################
        set_requires_grad(self.latent_process, True)
        # self.latent_process.init_memory_time_constant(tau_in_steps=8, dt=self.dt_eval)    ##########################################

        # ---- estimate whitening scales (Phase-I encoder) & re-init Aw ----
        if self.use_diag_whiten and not self.use_projector:
            # prefer train_eval; fallback to train
            """try:
                if getattr(self, "train_eval_loader", None) is None:
                    self.build_dataloader(group="train_eval")
                self.fit_diag_whitening_from_phase1(group="train_eval", max_batches=16)
            except Exception:
                self.fit_diag_whitening_from_phase1(group="train", max_batches=16)"""
            self.fit_diag_whitening_from_phase1(group="train", max_batches=16)
            self._reinit_linear_in_whiten_space()

        self.logger.info(f"whiten vector: {self.whiten_scale}")
        
        assert self.data_processor.mode == "interpolation", f"Mismatched dataloaders"
        self.build_dataloader(group="train")
        self.init_optim_phase2()
        if eval_every is not None:
            self.build_dataloader(group="test")
            self.build_dataloader(group="train_eval")
        self._save_split_and_samples()

        if verbose is not None:
            self.logger.info(f"{verbose}")

        ######## Initial Evaluation ########
        if self.train_eval_loader is not None:
            self.logger.info("--------Begin Evaluation on Train--------")
            train_eval_errs = self.evaluate(dataloader=self.train_eval_loader)
            self.logger.info("Evaluation on train:\n%s", pformat(train_eval_errs, width=100, compact=False))
        # out-of-domain evaluation
        if self.test_loader is not None:
            self.logger.info("--------Begin Evaluation on Test--------")
            test_errs = self.evaluate(dataloader=self.test_loader)
            self.logger.info("Evaluation on test:\n%s", pformat(test_errs, width=100, compact=False))

        tf_epsilon = self.tf_epsilon
        loss_tr_min, loss_ts_min = float('inf'), float('inf')
        for epoch in range(1, self.cfg.epochs + 1):
            self.latent_process.train()
            for i, batch in enumerate(self.train_loader):
                ground_truth = batch["data"].to(self.device)    # [B, T, H, W, n_ch]
                sample_idx = batch["index"].to(self.device)     # [B,]
                masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
                t_eval = batch['t'][0].to(self.device)          # [T]
                bs, train_len, H, W, _ = ground_truth.shape
                assert train_len == self.n_frames_train

                mask_index = self.mask_to_bs_index(mask=masks)    # [B, S]
                mask_index = mask_index.unsqueeze(1).expand(-1, train_len-self.n_frames_cond+1, -1).flatten(0, 1)  # [B*t, S]
                delay_data = delay_stack_last_channel(x=ground_truth, d=self.n_frames_cond)    # [B, T-nf_cond+1, ..., n_ch*nf_cond]
                data_in = delay_data.flatten(0, 1).flatten(1, 2)    # [B*t, H*W, n_ch*nf_cond]
                data_in = self.index_points(data_in, mask_index)    # [B*t, S, n_ch*nf_cond]
                pos_in = self.pos_feat.flatten(0, 1).unsqueeze(0).expand(data_in.shape[0], -1, -1).to(self.device)
                pos_in = self.index_points(pos_in, mask_index)    # [B*t, S, 2]
                data_in = torch.cat((data_in, pos_in), dim=-1) 

                # latent_states: encoded from encoder / dyn_states: predicted using latent ode
                latent_token = self.encoder(data_in, pos_in)    # [B*t, K, latent_token]
                # reshape to [B, t, latent_dim]
                latent_states = latent_token.reshape((bs, train_len-self.n_frames_cond+1, -1))    # [B, t, latent_dim]
                latent_states = latent_states.permute(1, 0, 2)    # [t, B, latent_dim]

                latent_states_ = self._center_latent(latent_states)
                latent_states_ = self._whiten_latent(latent_states_)

                if self.use_projector:
                    latent_states_ = self._project_latent(latent_states_)
                    dyn_y, memory_states, aux = self.latent_process(
                        alpha_0=latent_states_[0], t_eval=t_eval[self.n_frames_cond-1:],
                        memory_init=None, teacher_forcing=True,
                        tf_alpha=latent_states_, tf_epsilon=tf_epsilon, tf_mask=None
                    )                                                          # [T',B,d]
                    dyn_loss = criterion(dyn_y, latent_states_.detach())
                    dyn_states_ = self._lift_latent(dyn_y)                     # lift back to [T',B,D]
                else:
                    dyn_states_, memory_states, aux = self.latent_process(
                        alpha_0=latent_states_[0], t_eval=t_eval[self.n_frames_cond-1:],
                        memory_init=None,
                        teacher_forcing=True, tf_alpha=latent_states_, tf_epsilon=tf_epsilon, tf_mask=None
                    )    # [t, B, latent_dim]
                    dyn_loss = criterion(dyn_states_, latent_states_.detach())
                    """long_term_loss, _ = self.kstep_rollout_loss(alpha_gt=latent_states_, t_eval=t_eval[self.n_frames_cond-1:], K=6, 
                                                            train_latent=False)"""
                corr = aux["phi_dec_l2"] if self.use_memory else None

                """if self.latent_mode == "discrete":
                    if self.use_projector:
                        A = self.latent_process.A                           
                        res_gt = latent_states_[1:] - (latent_states_[:-1] @ A.t())
                        if self.latent_process.memory_type != "residual":
                            Dm = self.latent_process.memory_dim
                            mem_flatten = memory_states[:-1].reshape(-1, Dm)
                            res_flatten = self.latent_process.memory_decoder(mem_flatten)
                            res_pred = (self.latent_process.gate * res_flatten).view(res_gt.shape)
                        else:
                            res_pred = memory_states[:-1]
                        residual_loss = F.mse_loss(res_pred, res_gt)
                    else:
                        A = self.latent_process.A
                        res_gt = latent_states_[1:] - (latent_states_[:-1] @ A.t())
                        if self.latent_process.memory_type != "residual":
                            Dm = self.latent_process.memory_dim
                            mem_flatten = memory_states[:-1].reshape(-1, Dm)
                            res_flatten = self.latent_process.memory_decoder(mem_flatten)
                            res_pred = (self.latent_process.gate * res_flatten).view(res_gt.shape)
                        else:
                            res_pred = memory_states[:-1]
                        residual_loss = F.mse_loss(res_pred, res_gt)"""

                if self.latent_mode == "discrete" and self.latent_process.memory_type != "residual":
                    T_eff, B, D = latent_states_.shape
                    A = self.latent_process.A
                    res_gt = latent_states_[1:] - (latent_states_[:-1] @ A.T)
                    mem_flatten = memory_states[:-1].reshape(-1, self.latent_process.memory_dim)
                    res_flatten = self.latent_process.memory_decoder(mem_flatten)
                    res_pred = (self.latent_process.gate * res_flatten).view(T_eff-1, B, D)
                    residual_loss = F.mse_loss(res_pred, res_gt)
                elif self.latent_mode == "discrete" and self.latent_process.memory_type == "residual":
                    A = self.latent_process.A
                    res_gt = latent_states_[1:] - (latent_states_[:-1] @ A.T)
                    res_pred = memory_states[:-1]
                    residual_loss = F.mse_loss(res_pred, res_gt)




                # mem_info = self._compute_memory_lengths(latent_states_, memory_states)
                dyn_states_ = self._unwhiten_latent(dyn_states_)   
                dyn_states = self._decenter_latent(dyn_states_)

                pred_field = self._decode_latent(latent_seqs=dyn_states)    # [B, t, H, W, C]
                pred_loss, pred_loss_wzmask = self._mse_loss(
                    pred_field, ground_truth[:, self.n_frames_cond-1:, ...], masks[:, self.n_frames_cond-1:, ...]
                )

                loss = dyn_loss + self.lambda_pred * pred_loss_wzmask
                """if self.use_memory and (memory_states is not None):
                    T, B, D   = dyn_states_.shape
                    Dm        = memory_states.shape[-1]
                    dt        = float(self.dt_eval)

                    f_lin = self.latent_process._apply_linear(
                        latent_states_[:-1].detach()
                    )                                            # [T-1, B, D]
                    mem_flat   = memory_states[1:].reshape(-1, Dm)             # [T-1*B, Dm]
                    f_mem_raw  = self.latent_process.memory_decoder(mem_flat)  # [T-1*B, D]
                    f_mem_raw  = f_mem_raw.view(T-1, B, D)                     # [T-1, B, D]
                    mem_gate   = self.latent_process.mem_scale                 # scalar ∈ (0,1)
                    f_mem      = mem_gate * f_mem_raw                          # [T-1, B, D]

                    e_lin = (f_lin**2).sum(dim=-1)           # [T-1, B]
                    e_mem = (f_mem**2).sum(dim=-1)           # [T-1, B]
                    r     = e_mem / (e_lin + e_mem + 1e-8)   # [T-1, B]

                    r_max = getattr(self.cfg, "mem_ratio_max", 0.5)
                    L_mem_rel_cap = F.relu(r - r_max)**2     # [T-1, B]
                    r_min = getattr(self.cfg, "mem_ratio_min", 0.02)
                    L_mem_rel_floor = F.relu(r_min - r)**2
                    L_mem_rel_band = L_mem_rel_cap + L_mem_rel_floor
                    L_mem_rel = L_mem_rel_band.mean()

                    lambda_rel = getattr(self.cfg, "lambda_corr", 1e-3)
                    loss = loss + lambda_rel * L_mem_rel"""
                if self.lambda_corr is not None and corr is not None:
                    loss += self.lambda_corr * corr
                """if self.cfg.lambda_lt_pred is not None:
                    loss += self.cfg.lambda_lt_pred * long_term_loss"""
                if self.latent_mode == "discrete":
                    loss += self.lambda_resid * residual_loss

                self.optim_dyn.zero_grad()
                if train_dec:
                    # self.optim_enc.zero_grad()
                    self.optim_dec.zero_grad()
                loss.backward()
                if self.cfg.max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.latent_process.parameters(), self.cfg.max_grad_norm)    #######
                    if train_dec:
                        # torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), self.cfg.max_grad_norm)
                        torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), self.cfg.max_grad_norm)
                self.optim_dyn.step()
                if train_dec:
                    # self.optim_enc.step()
                    self.optim_dec.step()
                """if self.cfg.scheduler == "OneCycleLR":
                    self.scheduler_dyn.step()"""
                if log_every is not None and (epoch * len(self.train_loader) + i) % log_every == 0:
                    msg = [
                        f"Epoch {epoch:04d}/{self.cfg.epochs} | iteration {i+1:03d}",
                        f"| pred {pred_loss.item():.8f}",
                        f"| pred(mask) {pred_loss_wzmask.item():.8f}",
                        f"| dyn {dyn_loss.item():.8f}",
                        # f"| dyn (long term) {long_term_loss.item():.8f}"
                    ]
                    if self.latent_mode == "discrete":
                        msg += [f"| residual {residual_loss.item():.8f}"]
                        msg += [f"| correction ratio mean {aux['mem_ratio_mean']:.8f}"]
                        msg += [f"| correction ratio p90 {aux['mem_ratio_p90']:.8f}"]
                    if self.use_memory and (corr is not None):
                        msg += [f"| memory_corr {corr.item():.8f}"]
                        # msg += [f"| memory (relative) {L_mem_rel.item():.8f}"]
                        if self.latent_mode == "continuous":
                            msg += [F"| memory scale {float(self.latent_process.mem_scale.mean().item()):.4f}"]

                        """if "tau_steps_mean" in mem_info:
                            msg += [f"| tau(mem)_mean {mem_info['tau_steps_mean']:.2f} steps",
                                    f"| tau_p90 {mem_info['tau_steps_p90']:.2f}"]
                        if "tau_eff_steps_mean" in mem_info:
                            msg += [f"| tau_eff_mean {mem_info['tau_eff_steps_mean']:.2f} steps",
                                    f"| tau_eff_p90 {mem_info['tau_eff_steps_p90']:.2f}"]
                        if "tau_c_eff_steps_mean" in mem_info:
                            msg += [f"| tau_c_eff_mean {mem_info['tau_c_eff_steps_mean']:.2f} steps",
                                    f"| tau_c_eff_p90 {mem_info['tau_c_eff_steps_p90']:.2f}"]"""
                    msg += [f"| epsilon {tf_epsilon:.4f}"]
                    self.logger.info(" ".join(msg))
                if (epoch * len(self.train_loader) + i + 1) % self.cfg.update_every == 0:
                    tf_epsilon = max(tf_epsilon * self.epsilon, self.cfg.tf_epsilon_min)

                if eval_every is not None and (epoch * len(self.train_loader) + i + 1) % eval_every == 0:
                    if self.train_eval_loader is not None:
                        self.logger.info("--------Begin Evaluation on Train--------")
                        train_eval_errs = self.evaluate(dataloader=self.train_eval_loader)
                        self.logger.info("Evaluation on train:\n%s", pformat(train_eval_errs, width=100, compact=False))
                        losses_tr = train_eval_errs["mse_losses"]
                        if loss_tr_min > losses_tr["loss"]:
                            loss_tr_min = losses_tr["loss"]
                            self.save_model(epoch, train_eval_errs, pth_name="model_tr_best.pth")
                    # out-of-domain evaluation
                    if self.test_loader is not None:
                        self.logger.info("--------Begin Evaluation on Test--------")
                        test_errs = self.evaluate(dataloader=self.test_loader)
                        self.logger.info("Evaluation on test:\n%s", pformat(test_errs, width=100, compact=False))
                        losses_ts = test_errs["mse_losses"]
                        if loss_ts_min > losses_ts["loss"]:
                            loss_ts_min = losses_ts["loss"]
                            self.save_model(epoch, test_errs, pth_name="model_ts_best.pth")
            if self.cfg.scheduler == 'CosineAnnealingLR' or self.cfg.scheduler == 'StepLR':
                self.scheduler_dyn.step()
                if train_dec:
                    # self.scheduler_enc.step()
                    self.scheduler_dec.step()
        self.logger.info("Training Finished! Saving model...")
        self.save_model(epoch=None, losses=None, pth_name="final_model.pth")


    @torch.no_grad()
    def evaluate(self, dataloader, model_pth: str | None = None):
        """
        Roll the model out without teacher forcing on every batch of
        `dataloader` and aggregate prediction errors.

        For each batch the first `n_frames_cond` frames are delay-stacked and
        encoded into one latent state, which is rolled forward by
        `self.latent_process` and decoded with `self._decode_latent`. The
        prediction is compared with the ground truth starting at frame
        `n_frames_cond - 1`.

        Args:
            dataloader: iterable of batches with keys "data", "t" and "mask".
            model_pth: optional checkpoint path; when given it is loaded via
                `self.load_from_ckpt` before evaluating.

        Returns:
            dict with keys "rel_losses", "mse_losses" and "rmse_losses". Each
            maps to a dict with keys "loss" (all predicted frames),
            "loss_in_t" (the first `n_frames_train - (n_frames_cond - 1)`
            predicted frames) and "loss_out_t" (the remaining frames).
            Values are averaged over the number of samples seen.

        Raises:
            ZeroDivisionError: if `dataloader` yields no samples.
        """
        if model_pth is not None:
            _ = self.load_from_ckpt(ckpt_path=model_pth, device=self.device)

        rel_criterion = LpLoss(size_average=False)
        rel_err, rel_err_in_t, rel_err_out_t = 0.0, 0.0, 0.0
        mse_err, mse_err_in_t, mse_err_out_t = 0.0, 0.0, 0.0
        rmse_err, rmse_err_in_t, rmse_err_out_t = 0.0, 0.0, 0.0
        num_samples = 0

        # Index of the last conditioning frame, which is also the first predicted frame.
        n_cond = self.n_frames_cond - 1
        # Number of predicted frames that lie inside the training horizon.
        n_in_t = self.n_frames_train - n_cond

        for batch in dataloader:
            ground_truth = batch["data"].to(self.device)    # [B, T, H, W, n_ch] (T = n_frames_train + n_frames_out)
            t_eval = batch['t'][0][n_cond:].to(self.device)
            masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
            bs = ground_truth.shape[0]
            num_samples += bs

            # ---- encode the conditioning window into a single latent state ----
            mask_index = self.mask_to_bs_index(mask=masks)    # [B, S]
            delay_data = delay_stack_last_channel(x=ground_truth[:, :self.n_frames_cond, ...], d=self.n_frames_cond)    # [B, 1, ..., n_ch*nf_cond]
            data_in = delay_data.flatten(0, 1).flatten(1, 2)    # [B*1, H*W, n_ch*nf_cond]
            data_in = self.index_points(data_in, mask_index)    # [B*1, S, n_ch*nf_cond]
            pos_in = self.pos_feat.flatten(0, 1).unsqueeze(0).expand(data_in.shape[0], -1, -1).to(self.device)
            pos_in = self.index_points(pos_in, mask_index)    # [B*1, S, 2]
            data_in = torch.cat((data_in, pos_in), dim=-1)

            latent_token = self.encoder(data_in, pos_in)    # [B*1, K, latent_token]
            latent_state = latent_token.reshape((bs, -1))    # [B, latent_dim]

            # ---- normalise, roll out, and undo the normalisation ----
            latent_state_ = self._center_latent(latent_state)
            latent_state_ = self._whiten_latent(latent_state_)
            if self.use_projector:
                latent_state_ = self._project_latent(latent_state_)
            dyn_states_, _, _ = self.latent_process(alpha_0=latent_state_, t_eval=t_eval, teacher_forcing=False)    # [T, B, latent_dim]
            if self.use_projector:
                dyn_states_ = self._lift_latent(dyn_states_)
            dyn_states_ = self._unwhiten_latent(dyn_states_)
            dyn_states = self._decenter_latent(dyn_states_)
            recon_seq = self._decode_latent(dyn_states)

            # ---- split prediction / ground truth into in-horizon and out-of-horizon parts ----
            # recon_seq, ground_truth_: [B, T, H, W, s]
            ground_truth_ = ground_truth[:, n_cond:, ...]
            pred_in_t_, pred_out_t_ = recon_seq[:, :n_in_t, ...], recon_seq[:, n_in_t:, ...]
            gt_in_t_, gt_out_t_ = ground_truth_[:, :n_in_t, ...], ground_truth_[:, n_in_t:, ...]

            rel_err += rel_criterion(recon_seq.reshape(bs, -1), ground_truth_.reshape(bs, -1)).item()
            rel_err_in_t += rel_criterion(pred_in_t_.reshape(bs, -1), gt_in_t_.reshape(bs, -1)).item()
            rel_err_out_t += rel_criterion(pred_out_t_.reshape(bs, -1), gt_out_t_.reshape(bs, -1)).item()

            # Compute each batch MSE once and reuse it for the RMSE accumulator.
            mse_all = self._compute_loss(recon_seq, ground_truth_)
            mse_in_t = self._compute_loss(pred_in_t_, gt_in_t_)
            mse_out_t = self._compute_loss(pred_out_t_, gt_out_t_)

            # Weight by batch size so the final division by num_samples is a sample-weighted mean.
            mse_err += mse_all * bs
            mse_err_in_t += mse_in_t * bs
            mse_err_out_t += mse_out_t * bs
            rmse_err += torch.sqrt(mse_all) * bs
            rmse_err_in_t += torch.sqrt(mse_in_t) * bs
            rmse_err_out_t += torch.sqrt(mse_out_t) * bs

        rel_losses = {
            "loss": rel_err / num_samples,
            "loss_in_t": rel_err_in_t / num_samples,
            "loss_out_t": rel_err_out_t / num_samples,
        }
        mse_losses = {
            "loss": mse_err / num_samples,
            "loss_in_t": mse_err_in_t / num_samples,
            "loss_out_t": mse_err_out_t / num_samples,
        }
        rmse_losses = {
            "loss": rmse_err / num_samples,
            "loss_in_t": rmse_err_in_t / num_samples,
            "loss_out_t": rmse_err_out_t / num_samples,
        }
        return {"rel_losses": rel_losses, "mse_losses": mse_losses, "rmse_losses": rmse_losses}

            
    def _compute_loss(self, data1: torch.Tensor, data2: torch.Tensor, mask: torch.Tensor | None = None):
        # datas: [B, T, H, W, s], mask: [B, T, H, W, s] or None
        # check!!!!!!!!!!
        with torch.no_grad():
            if mask is None:
                criterion = nn.MSELoss(reduction="none")
                loss = criterion(data1, data2)
                mse = loss.mean()
            else:
                sqerr = (data1 - data2).pow(2) * mask
                sqerr_sum = sqerr.sum(dim=(2, 3))
                denom = mask.sum(dim=(2, 3)).clamp_(min=1e-6)
                mse = (sqerr_sum / denom).mean()
        return mse

    
    def save_model(self, epoch: int | None = None, losses: dict | None = None, pth_name: str = "model_tr.pth"):
        save_dict = {
            "args": vars(self.args) if self.args is not None else None,
            "epoch": epoch,
            "encoder": self.encoder.state_dict(),
            "latent_process": self.latent_process.state_dict(),
            "decoder": self.decoder.state_dict(),
            "losses": dict(losses) if isinstance(losses, dict) else losses
        }
        lc = getattr(self, "latent_center", None)
        if lc is not None:
            save_dict["latent_center"] = lc.detach().cpu().view(-1)
        ws = getattr(self, "whiten_scale", None)
        if ws is not None:
            save_dict["whiten_scale"] = ws.detach().cpu().view(-1)
        U = getattr(self, "U_proj", None)
        if U is not None:
            save_dict["U_proj"] = U.detach().cpu()
        out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}")
        torch.save(save_dict, os.path.join(out_dir, f'{pth_name}'))


    def load_from_ckpt(self, ckpt_path: str, device: str | None = "cuda:0"):
        try:
            ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        except TypeError:
            ckpt = torch.load(ckpt_path, map_location=device)

        device = device if device is not None else self.device

        self.encoder.load_state_dict(ckpt["encoder"])
        self.latent_process.load_state_dict(ckpt["latent_process"])
        self.decoder.load_state_dict(ckpt["decoder"])
        
        self.encoder.to(self.device)
        self.latent_process.to(device)
        self.decoder.to(self.device)

        lc = ckpt.get("latent_center", None)
        if lc is not None:
            # match dtype of latent_process params if possible
            p0 = next(self.latent_process.parameters(), None)
            dtype = p0.dtype if p0 is not None else torch.float32
            lc = lc.to(device=self.device, dtype=dtype).view(-1)
            if lc.numel() == self.latent_dim:
                self.latent_center = lc
            else:
                self.latent_center = None
                if hasattr(self, "logger") and self.logger is not None:
                    self.logger.warning(
                        f"[load_from_ckpt] latent_center dim {lc.numel()} != latent_dim {self.latent_dim}; dropped."
                    )
        else:
            self.latent_center = None

        ws = ckpt.get("whiten_scale", None) 
        if ws is not None:
            p0 = next(self.latent_process.parameters(), None)
            dtype = p0.dtype if p0 is not None else torch.float32
            ws = ws.to(device=self.device, dtype=dtype).view(-1)
            if ws.numel() == self.latent_dim and (ws > 0).all():
                self.set_whitening_scale(ws)
            else:
                if hasattr(self, "logger") and self.logger is not None:
                    self.logger.warning("whiten_scale invalid; ignored.")
        else:
            self.set_whitening_scale(None)

        U = ckpt.get("U_proj", None)
        if U is not None:
            try:
                self.set_projector(U)  # turn on use_projector
            except Exception as e:
                if hasattr(self, "logger") and self.logger is not None:
                    self.logger.warning(f"U_proj load failed; ignoring. err={e}")
        
        return {"epoch": ckpt.get("epoch", -1), "losses": ckpt.get("losses", None), "args": ckpt.get("args", None)}


    def _ensure_loader(self, group: str) -> None:
        """
        Lazily build dataloader for the given group if not yet built.
        group ∈ {"train", "train_eval", "test"}.
        """
        if group == "train" and self.train_loader is None:
            self.build_dataloader(group="train")
        elif group == "train_eval" and self.train_eval_loader is None:
            self.build_dataloader(group="train_eval")
        elif group == "test" and self.test_loader is None:
            self.build_dataloader(group="test")

    
    def sample_batch(self, group: str, batch_size: int, replace: bool = False):
        """
        Draw one random batch from a split, regardless of the batch size the
        split's own loader was built with.

        Args:
            group: One of {"train", "train_eval", "test"}.
            batch_size: Number of samples in the returned batch.
            replace: Sample with replacement. Forced to True whenever
                `batch_size` exceeds the dataset size.

        Returns: one collated batch
        """
        self._ensure_loader(group)
        loaders = {
            "train": self.train_loader,
            "train_eval": self.train_eval_loader,
            "test": self.test_loader,
        }
        dataset = loaders[group].dataset
        n_total = len(dataset)

        # More draws than items is only possible with replacement.
        replace = replace or batch_size > n_total

        # Reuse the data processor's generator when it has one, so draws follow
        # the same random stream; otherwise seed a fresh generator from the config.
        rng = getattr(self.data_processor, "np_gen", None)
        if rng is None:
            rng = np.random.default_rng(self.cfg.seed)

        picked = rng.choice(n_total, size=batch_size, replace=replace).tolist()

        # A throw-away loader over the picked indices yields exactly one batch.
        one_shot = DataLoader(
            Subset(dataset, picked),
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )
        return next(iter(one_shot))

    
    def make_subset_from_saved(
        self,
        group: str,
        saved_json_path: str,
        batch_size: int | None = None
    ):    ########################################################################################################
        """
        Build a temporary DataLoader that contains the exact (traj_id, t0) pairs
        stored under `group` in split_metadata.json. This is useful for strict
        reproducible visualization/rollouts.

        Args:
            group: "train" / "train_eval" / "test"
            saved_json_path: path to split_metadata.json
            batch_size: optional batch size; if None, use len(index_list).

        Returns:
            A one-shot DataLoader that yields exactly those samples.
        """
        import json
        from torch.utils.data import Subset, DataLoader

        # Ensure dataset is constructed
        self._ensure_loader(group)
        dataset = {"train": self.train_loader.dataset,
                   "train_eval": self.train_eval_loader.dataset,
                   "test": self.test_loader.dataset}[group]

        with open(saved_json_path, "r") as f:
            payload = json.load(f)

        idx_list = payload.get("samples", {}).get(group, None)
        if idx_list is None:
            raise ValueError(f"No saved indices found for group '{group}' in {saved_json_path}")

        # Map (traj_id, t0) -> dataset indices: dataset.samples is a list of (traj_id, t0)
        # We build a position lookup for O(1) probing.
        position = { (int(tid), int(t0)): i for i, (tid, t0) in enumerate(dataset.samples) }
        take = []
        for tid, t0 in idx_list:
            key = (int(tid), int(t0))
            if key not in position:
                raise KeyError(f"Sample {key} not present in current dataset.samples")
            take.append(position[key])

        subset = Subset(dataset, take)
        if batch_size is None:
            batch_size = len(take)

        tmp_loader = DataLoader(
            subset, batch_size=batch_size, shuffle=False,
            num_workers=0, pin_memory=True
        )
        return tmp_loader

    
    def recon_one_batch(self, batch_samples,):
        """
        Encode a batch and decode it straight back, without latent dynamics.

        Args:
            batch_samples: batch dict; `batch_samples["data"]` is indexed as
                [B, T, ...] and is also passed to `self._encode_and_recon`.

        Returns:
            (recon, ground_truth), both permuted with (0, 2, 3, 1, 4) and both
            on `self.device`. The ground truth starts at frame
            `n_frames_cond - 1`.
        """
        self.switch_to_eval()
        with torch.no_grad():
            latent_states, recon_loss, recon_loss_wz_mask = self._encode_and_recon(batch_samples)
            print(f"recon loss = {recon_loss}, recon loss (mask) = {recon_loss_wz_mask}")
            recon_seq = self._decode_latent(latent_seqs=latent_states)    # [B, t, H, W, c]
            recon_seq_ = recon_seq.permute(0, 2, 3, 1, 4)
            # Return the ground truth on the same device as the prediction,
            # as the rollout methods do.
            ground_truth = batch_samples["data"].to(self.device)
            ground_truth_ = ground_truth[:, self.n_frames_cond-1:, ...].permute(0, 2, 3, 1, 4)
        return recon_seq_, ground_truth_


    def linear_rollout_one_batch_with_Ab(self, batch_samples, rollout_steps: int, return_gt: bool = False):
        """
        Roll a batch forward with the affine latent map
        `z[t+1] = z[t] @ A.T + b`, using `self.Ad_phase1` / `self.b_phase1`.

        The initial latent state comes from `self._encode_cond_batch`; the
        rolled-out latents are decoded according to `self.dec_mode`.

        Args:
            batch_samples: batch dict with keys "data" and "mask".
            rollout_steps: number of steps after the initial state;
                `rollout_steps + n_frames_cond` must not exceed the number of
                frames in the batch.
            return_gt: also return the full ground-truth sequence.

        Returns:
            (pred, gt), each permuted with (0, 2, 3, 1, 4) and covering
            `rollout_steps + 1` frames starting at frame `n_frames_cond - 1`.
            With `return_gt=True` a third element holds the whole ground
            truth, permuted the same way.

        Raises:
            AssertionError: if the phase-I matrices are missing or the batch
                is too short.
            ValueError: if `self.dec_mode` is not "fouriermlp" or "fouriernet".
        """
        self.switch_to_eval()
        device = self.device
        assert hasattr(self, "Ad_phase1") and hasattr(self, "b_phase1"), "load phase I checkpoint first"
        A, b = self.Ad_phase1, self.b_phase1
        with torch.no_grad():
            n_cond = self.n_frames_cond - 1
            ground_truth = batch_samples["data"].to(self.device)    # [B, T, H, W, n_ch] (T = n_frames_train + n_frames_out)
            masks = batch_samples["mask"].to(self.device)           # [B, T, H, W, n_ch]
            assert rollout_steps + self.data_processor.n_frames_cond <= ground_truth.shape[1]
            bs = ground_truth.shape[0]

            latent_state = self._encode_cond_batch(batch_samples)

            # Align A and b with the latent state so the matmul cannot fail on a
            # device or dtype mismatch.
            A = torch.as_tensor(A, device=device, dtype=latent_state.dtype)
            if b is not None:
                b = b.view(1, self.latent_dim).to(device=device, dtype=latent_state.dtype)
            offset = b if b is not None else 0.0

            dyn_states = torch.empty((rollout_steps+1, bs, self.latent_dim), device=device, dtype=latent_state.dtype)
            dyn_states[0] = latent_state
            for t in range(rollout_steps):
                dyn_states[t + 1] = dyn_states[t] @ A.T + offset

            if self.dec_mode == "fouriermlp":
                latent_feats_dyn = dyn_states.permute(1, 0, 2).flatten(0, 1)
                grid_dim = self.pos_feat.shape[-1]
                grid = self.pos_feat.reshape(-1, grid_dim).to(self.device)    # [N_pt, grid_dim]
                recon_seq = self.decoder(grid=grid, latent_feat=latent_feats_dyn)
                recon_seq = recon_seq.reshape(bs, -1, *self.shapelist, self.state_dim)
            elif self.dec_mode == "fouriernet":
                latent_feats_dyn = dyn_states.permute(1, 0, 2).reshape(bs, -1, self.state_dim, self.code_dim)
                recon_seq = self.decoder(latent_feats_dyn)
            else:
                # Without this guard `recon_seq` would be undefined below.
                raise ValueError(f"Unsupported dec_mode: {self.dec_mode!r}")
            recon_seq_ = recon_seq.permute(0, 2, 3, 1, 4)

            # recon_seq, gt_window: [B, T, H, W, s]
            window = slice(n_cond, n_cond + rollout_steps + 1)
            gt_window = ground_truth[:, window, ...]
            ground_truth_ = gt_window.permute(0, 2, 3, 1, 4)

            pred_err = self._compute_loss(recon_seq, gt_window)
            pred_err_wzmask = self._compute_loss(recon_seq, gt_window, masks[:, window, ...])
            print(f"pred loss = {pred_err}, pred loss (mask) = {pred_err_wzmask}")
        if return_gt:
            return recon_seq_, ground_truth_, ground_truth.permute(0, 2, 3, 1, 4)
        return recon_seq_, ground_truth_    # [B, ..., rollout_steps+1, C]


    def rollout_one_batch(self, batch_samples, rollout_steps: int, return_gt: bool = False):
        """
        Roll a batch forward with the learned latent process (no teacher forcing).

        The first `n_frames_cond` frames are encoded by
        `self._encode_cond_batch` into the initial latent state. That state is
        centred, whitened and (optionally) projected, rolled out by
        `self.latent_process`, mapped back, and decoded according to
        `self.dec_mode`.

        Args:
            batch_samples: batch dict with keys "data" ([B, T, H, W, n_ch]),
                "t" and "mask".
            rollout_steps: number of steps after the initial state;
                `rollout_steps + n_frames_cond` must not exceed T.
            return_gt: also return the full ground-truth sequence.

        Returns:
            (pred, gt), each permuted with (0, 2, 3, 1, 4) and covering
            `rollout_steps + 1` frames starting at frame `n_frames_cond - 1`.
            With `return_gt=True` a third element holds the whole ground
            truth, permuted the same way.

        Raises:
            AssertionError: if the batch is too short for the rollout.
            ValueError: if `self.dec_mode` is not "fouriermlp" or "fouriernet".
        """
        self.switch_to_eval()
        with torch.no_grad():
            n_cond = self.n_frames_cond - 1
            ground_truth = batch_samples["data"].to(self.device)    # [B, T, H, W, n_ch] (T = n_frames_train + n_frames_out)
            t_eval = batch_samples['t'][0][n_cond:].to(self.device)
            masks = batch_samples["mask"].to(self.device)           # [B, T, H, W, n_ch]
            assert rollout_steps + self.data_processor.n_frames_cond <= ground_truth.shape[1]
            bs = ground_truth.shape[0]

            # ---- initial latent state in the normalised (and projected) space ----
            latent_state = self._encode_cond_batch(batch_samples)
            latent_state_ = self._center_latent(latent_state)
            latent_state_ = self._whiten_latent(latent_state_)
            if self.use_projector:
                latent_state_ = self._project_latent(latent_state_)

            # ---- free-running rollout, then undo projection / whitening / centring ----
            dyn_states_, _, _ = self.latent_process(alpha_0=latent_state_, t_eval=t_eval[:rollout_steps+1], teacher_forcing=False)    # [T, B, latent_dim]
            if self.use_projector:
                dyn_states_ = self._lift_latent(dyn_states_)
            dyn_states_ = self._unwhiten_latent(dyn_states_)
            dyn_states = self._decenter_latent(dyn_states_)

            if self.dec_mode == "fouriermlp":
                latent_feats_dyn = dyn_states.permute(1, 0, 2).flatten(0, 1)
                grid_dim = self.pos_feat.shape[-1]
                grid = self.pos_feat.reshape(-1, grid_dim).to(self.device)    # [N_pt, grid_dim]
                recon_seq = self.decoder(grid=grid, latent_feat=latent_feats_dyn)
                recon_seq = recon_seq.reshape(bs, -1, *self.shapelist, self.state_dim)
            elif self.dec_mode == "fouriernet":
                latent_feats_dyn = dyn_states.permute(1, 0, 2).reshape(bs, -1, self.state_dim, self.code_dim)
                recon_seq = self.decoder(latent_feats_dyn)
            else:
                # Without this guard `recon_seq` would be undefined below.
                raise ValueError(f"Unsupported dec_mode: {self.dec_mode!r}")
            recon_seq_ = recon_seq.permute(0, 2, 3, 1, 4)

            # recon_seq, gt_window: [B, T, H, W, s]
            window = slice(n_cond, n_cond + rollout_steps + 1)
            gt_window = ground_truth[:, window, ...]
            ground_truth_ = gt_window.permute(0, 2, 3, 1, 4)

            pred_err = self._compute_loss(recon_seq, gt_window)
            pred_err_wzmask = self._compute_loss(recon_seq, gt_window, masks[:, window, ...])
            print(f"pred loss = {pred_err}, pred loss (mask) = {pred_err_wzmask}")

        if return_gt:
            return recon_seq_, ground_truth_, ground_truth.permute(0, 2, 3, 1, 4)
        return recon_seq_, ground_truth_    # [B, ..., rollout_steps+1, C]


    def illustrate_one_frame_pred(self, batch_samples, rollout_steps: int, out_dir: str, dyn_type: str = "memory"):
        """
        Plot prediction vs. ground truth for the last frame of batch item 0.

        Args:
            batch_samples: batch handed to the selected prediction routine.
            rollout_steps: rollout length (not used when `dyn_type == "recon"`).
            out_dir: directory in which the figure is saved as 'one_frame_pred'.
            dyn_type: "linear" (phase-I affine rollout), "memory" (learned
                latent process) or "recon" (encode/decode only).
        """
        assert dyn_type in {"linear", "memory", "recon"}
        if dyn_type == "recon":
            outputs = self.recon_one_batch(batch_samples)
        elif dyn_type == "memory":
            outputs = self.rollout_one_batch(batch_samples, rollout_steps)    # [B, H, W, T, C]
        else:
            outputs = self.linear_rollout_one_batch_with_Ab(batch_samples, rollout_steps)
        prediction, reference = outputs
        print(prediction.shape, reference.shape)
        assert self.spatial_dim == 2

        figure_options = dict(
            mode="last",
            batch_index=0,
            save_path=os.path.join(out_dir, 'one_frame_pred'),
            show_error=True,
            err_type="abs",
            add_colorbar=True,
            header_time_ratio=0.28,   # smaller -> header sits closer to the plots
            header_cols_ratio=0.34,   # shrinks the column-title row as well
            time_fontsize=12,         # slightly smaller time label
        )
        saved_to = visualize_pred_vs_gt(prediction, reference, **figure_options)
        print("Saved last-step prediction to:", saved_to)

    
    def illustrate_long_term_pred(self, batch_samples, rollout_steps: int, out_dir: str, dyn_type: str = "memory"):
        """
        Predict on one batch and pass the whole sequence to `visualize_pred_vs_gt`
        (mode="all", as_type="gif").

        - dyn_type: "linear" -> `linear_rollout_one_batch_with_Ab`,
                    "memory" -> `rollout_one_batch`,
                    "recon"  -> `recon_one_batch` (called without `rollout_steps`)
        - out_dir: the save path handed to the visualizer is `<out_dir>/long_term_seq.gif`
        """
        assert dyn_type in {"linear", "memory", "recon"}
        if dyn_type == "memory":
            outputs = self.rollout_one_batch(batch_samples, rollout_steps)    # [B, H, W, T, C]
        elif dyn_type == "linear":
            outputs = self.linear_rollout_one_batch_with_Ab(batch_samples, rollout_steps)
        else:
            outputs = self.recon_one_batch(batch_samples)
        pred_seq, gt_seq = outputs
        print(pred_seq.shape, gt_seq.shape)
        assert self.spatial_dim == 2

        vis_kwargs = dict(
            mode="all", batch_index=0,
            save_path=os.path.join(out_dir, 'long_term_seq.gif'),
            as_type="gif", fps=3,
            show_error=True, err_type="abs", add_colorbar=True,
            header_time_ratio=0.28, header_cols_ratio=0.34,
        )
        saved_to = visualize_pred_vs_gt(pred_seq, gt_seq, **vis_kwargs)
        print("Saved GIF to:", saved_to)


    def visualize_random_rollout(self, group: str, batch_size: int, rollout_steps: int,
                                 out_dir: str, mode: str = "last", dyn_type: str = "memory"):
        """
        Draw a batch via `sample_batch` and visualize a rollout on it.

        - group: split name forwarded to `sample_batch`
        - batch_size: size of the sampled batch (forwarded to `sample_batch`)
        - rollout_steps: number of prediction steps forwarded to the illustrate_* method
        - out_dir: output directory, created if it does not exist
        - mode: "last" -> `illustrate_one_frame_pred`; anything else -> `illustrate_long_term_pred`
        - dyn_type: forwarded unchanged to the illustrate_* method
        """
        os.makedirs(out_dir, exist_ok=True)
        sampled = self.sample_batch(group=group, batch_size=batch_size)
        assert self.spatial_dim == 2
        # Only the selected renderer is looked up, exactly like an if/else would do.
        render = self.illustrate_one_frame_pred if mode == "last" else self.illustrate_long_term_pred
        render(sampled, rollout_steps, out_dir, dyn_type)


    def visualize_rollout_by_index(
        self,
        group: str,              # "train" | "train_eval" | "test"
        seq_index: int,          # zero-based index in the dataset order
        rollout_steps: int,
        out_dir: str,
        mode: str = "last",      # "last": plot last step; anything else: full-sequence export
        dyn_type: str = "memory",
    ):
        """
        Visualize a rollout for one specific dataset entry.

        The entry is addressed directly on the split's dataset (position `seq_index`),
        wrapped in a one-element `Subset`, and loaded as a batch of size 1, so the
        shuffling of the split's own loader does not matter.

        Raises IndexError when `seq_index` is outside `0..len(dataset)-1`.
        """
        import os
        from torch.utils.data import Subset, DataLoader

        os.makedirs(out_dir, exist_ok=True)

        # Make sure the loader for this split has been created before reading it.
        self._ensure_loader(group)
        loaders_by_group = {
            "train": self.train_loader,
            "train_eval": self.train_eval_loader,
            "test": self.test_loader,
        }
        source_dataset = loaders_by_group[group].dataset
        n_items = len(source_dataset)
        if not (0 <= seq_index < n_items):
            raise IndexError(f"seq_index={seq_index} out of bounds (0..{n_items-1})")

        # One-element loader -> exactly one batch with B=1.
        single_loader = DataLoader(
            Subset(source_dataset, [seq_index]),
            batch_size=1, shuffle=False, num_workers=0, pin_memory=True,
        )
        batch_samples = next(iter(single_loader))

        render = self.illustrate_one_frame_pred if mode == "last" else self.illustrate_long_term_pred
        render(batch_samples, rollout_steps, out_dir, dyn_type)

    
    def save_rollout_comparison(
        self,
        group: str,              # "train" | "train_eval" | "test"
        rollout_steps: int,
        out_dir: str,
        seq_index: int | None = None,          # zero-based index in the dataset order
        save_linear: bool = True
    ):
        import os
        from torch.utils.data import Subset, DataLoader
        os.makedirs(out_dir, exist_ok=True)
        if seq_index is None:
            batch_samples = self.sample_batch(group=group, batch_size=1)
        else:
            # Ensure the corresponding loader (and hence dataset) exists
            self._ensure_loader(group)
            loader = {"train": self.train_loader,
                    "train_eval": self.train_eval_loader,
                    "test": self.test_loader}[group]
            dataset = loader.dataset
            N = len(dataset)
            if not (0 <= seq_index < N):
                raise IndexError(f"seq_index={seq_index} out of bounds (0..{N-1})")
            # Build a temporary DataLoader that yields exactly this single sample (batch_size=1)
            subset = Subset(dataset, [seq_index])
            tmp_loader = DataLoader(subset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
            batch_samples = next(iter(tmp_loader))

        if save_linear:
            pred_tensor, true_tensor, true_tensor_full = self.linear_rollout_one_batch_with_Ab(batch_samples, rollout_steps, return_gt=True)
            payload = {
                "pred": pred_tensor.detach().cpu().contiguous(),
                "true": true_tensor.detach().cpu().contiguous(),
                "true_full": true_tensor_full.detach().cpu().contiguous(),
            }
            torch.save(payload, os.path.join(out_dir, "phase1"))

        pred_tensor, true_tensor, true_tensor_full = self.rollout_one_batch(batch_samples, rollout_steps, return_gt=True)
        payload = {
            "pred": pred_tensor.detach().cpu().contiguous(),
            "true": true_tensor.detach().cpu().contiguous(),
            "true_full": true_tensor_full.detach().cpu().contiguous(),
        }
        torch.save(payload, os.path.join(out_dir, "phase2"))


    def sample_from_fix(self, traj_id: int, t0: int, rollout_steps: int | None = None):
        """
        Build a batch-size-1 DataLoader over a `PDEDataset` that holds the single
        sample `(traj_id, t0)`.

        - rollout_steps: None -> the dataset uses `cfg.n_frames_out`; otherwise
          `n_frames_out = rollout_steps - cfg.n_frames_train + cfg.n_frames_cond`
        Returns the DataLoader (not a batch).
        """
        from torch.utils.data import DataLoader

        processor = self.data_processor
        cfg = processor.cfg

        # Normalized data is used whenever the processor config asks for normalization.
        source = processor.data_norm if cfg.normalize else processor.data
        if rollout_steps is None:
            n_out = cfg.n_frames_out
        else:
            n_out = rollout_steps - cfg.n_frames_train + cfg.n_frames_cond

        fixed_dataset = PDEDataset(
            data_tensor=source,
            cfg=cfg,
            n_frames_train=cfg.n_frames_train,
            n_frames_cond=cfg.n_frames_cond,
            n_frames_out=n_out,
            traj_indices=[traj_id],
            n_sample_per_traj=1,
            sample_strategy="disjoint",
            mode="interpolation",
            group="test",
            samples=[(int(traj_id), t0)],
            mask_tensor=processor.mask_tensor,
            np_rng=processor.np_gen,
        )
        return DataLoader(
            fixed_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=True,
            drop_last=False,
        )


    def save_rollout_tensors(
        self,
        out_dir: str,
        traj_id: int, t0: int, rollout_steps: int | None = None,
        phase: str = "phase1",
    ):
        os.makedirs(out_dir, exist_ok=True)
        self.switch_to_eval()
        loader = self.sample_from_fix(traj_id, t0, rollout_steps)
        for batch in loader:
            masks = batch["mask"].permute(0, 2, 3, 1, 4)    # [1, ..., T, C]
            if phase == "phase1":
                recon_seq_, ground_truth_, ground_truth = self.linear_rollout_one_batch_with_Ab(
                    batch, rollout_steps, return_gt=True
                )
            else:
                recon_seq_, ground_truth_, ground_truth = self.rollout_one_batch(
                    batch, rollout_steps, return_gt=True
                )
            print(f"recon_seq_: {recon_seq_.shape}, ground_truth_: {ground_truth_.shape}, ground_truth: {ground_truth.shape}")
            payload = {
                "traj_id": int(traj_id),
                "t0": int(t0),
                "pred": recon_seq_.squeeze(0).detach().cpu(),                      # [H, W, rollout_steps+1, C]
                "gt": ground_truth_.squeeze(0).detach().cpu(),                     # [H, W, rollout_steps+1, C]
                "gt_full": ground_truth.squeeze(0).detach().cpu(),                 # [H, W, T', C]
                "mask_full": masks.squeeze(0).detach().cpu(),                      # [H, W, T', C]
            }
            torch.save(payload, os.path.join(out_dir, f"traj_id{traj_id}_t0_{t0}.pt"))
    

    def _build_long_eval_loader(self, group: str = "test", rollout_steps: int | None = None, batch_size: int | None = None):
        """
        Build a non-shuffled DataLoader with one sample per trajectory, each starting at t0 = 0.

        - group: "train" / "train_eval" use `train_traj_ids`, anything else uses `test_traj_ids`;
          "train" is passed to the dataset as "train_eval"
        - rollout_steps: None -> `cfg.n_frames_out`; otherwise used directly as `n_frames_out`
        - batch_size: None -> `cfg.test_bs`
        """
        from torch.utils.data import DataLoader

        processor = self.data_processor
        cfg = processor.cfg

        if group in {"train", "train_eval"}:
            traj_ids = processor.train_traj_ids
        else:
            traj_ids = processor.test_traj_ids
        start_samples = [(int(tid), 0) for tid in traj_ids]

        source = processor.data_norm if cfg.normalize else processor.data
        n_out = cfg.n_frames_out if rollout_steps is None else rollout_steps
        dataset_group = "train_eval" if group == "train" else group

        long_dataset = PDEDataset(
            data_tensor=source,
            cfg=cfg,
            n_frames_train=cfg.n_frames_train,
            n_frames_cond=cfg.n_frames_cond,
            n_frames_out=n_out,
            traj_indices=traj_ids,
            n_sample_per_traj=1,
            sample_strategy="disjoint",
            mode="interpolation",
            group=dataset_group,
            samples=start_samples,
            mask_tensor=processor.mask_tensor,
            np_rng=processor.np_gen,
        )

        if batch_size is None:
            effective_bs = cfg.test_bs
        else:
            effective_bs = int(batch_size)
        return DataLoader(
            long_dataset,
            batch_size=effective_bs,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=True,
            drop_last=False,
        )


    @torch.no_grad()
    def evaluate_long_trajs(
        self, out_dir: str, 
        group: str = "test",
        rollout_steps: int | None = None,
        batch_size: int | None = None,
        loader=None,
        save_pt: bool = True,
        save_png: bool = False,
        vis_cols: int = 8,
        c_vis: int = 0,
        save_time_curve: bool = True
    ):
        import os, json, numpy as np
        import matplotlib.pyplot as plt
        t_eval_ref = None
        curve_sum   = None
        curve_sumsq = None
        curve_count = None

        os.makedirs(out_dir, exist_ok=True)
        self.switch_to_eval()
        if loader is None:
            loader = self._build_long_eval_loader(group=group, rollout_steps=rollout_steps, batch_size=batch_size)
        dataset = loader.dataset
        base_dataset = getattr(dataset, "dataset", dataset)
        samples_list = base_dataset.samples  # List[(traj_id, t0)]

        all_mse = []
        for batch in loader:
            latent_state = self._encode_cond_batch(batch)          # [B, latent_dim]
            latent_state_ = self._center_latent(latent_state)
            latent_state_ = self._whiten_latent(latent_state_)
            if self.use_projector:
                latent_state_ = self._project_latent(latent_state_)
            n_cond_m1 = int(self.n_frames_cond) - 1
            t_vec = batch["t"][0].to(self.device)                  # [T]
            t_eval = t_vec[n_cond_m1:]                             # [T' = length - (n_cond-1)]
            dyn_states_, _, _ = self.latent_process(alpha_0=latent_state_, t_eval=t_eval, teacher_forcing=False)
            if self.use_projector:
                dyn_states_ = self._lift_latent(dyn_states_)
            dyn_states_ = self._unwhiten_latent(dyn_states_)
            dyn_states = self._decenter_latent(dyn_states_)        # [T', B, latent_dim]
            bs = latent_state.shape[0]
            if self.dec_mode == "fouriermlp":
                latent_feats_dyn = dyn_states.permute(1, 0, 2).flatten(0, 1)          # [B*T', D]
                grid_dim = self.pos_feat.shape[-1]
                grid = self.pos_feat.reshape(-1, grid_dim).to(self.device)            # [H*W, P]
                recon_seq = self.decoder(grid=grid, latent_feat=latent_feats_dyn)     # [B*T', H, W, C]
                H, W, C = self.shapelist[0], self.shapelist[1], self.state_dim
                recon_seq = recon_seq.reshape(bs, -1, H, W, C)
            elif self.dec_mode == "fouriernet":
                latent_feats_dyn = dyn_states.permute(1, 0, 2).view(bs, -1, self.state_dim, self.code_dim)
                recon_seq = self.decoder(latent_feats_dyn)                               # [B, T', H, W, C]
            recon_seq_sp_last = recon_seq.permute(0, 2, 3, 1, 4).contiguous()           # [B, H, W, T', C]
            gt_full = batch["data"].to(self.device)                                     # [B, T, H, W, C]
            gt_slice = gt_full[:, n_cond_m1:, ...]                                      # [B, T', H, W, C]
            mask_full = batch.get("mask", None)
            if mask_full is not None:
                mask_slice = mask_full.to(self.device)[:, n_cond_m1:, ...]              # [B, T', H, W, C]
            else:
                mask_slice = None

            if mask_slice is None:
                se = (recon_seq - gt_slice).pow(2)                                      # [B,T',H,W,C]
                mse_b = se.mean(dim=(1,2,3,4))                                          # [B]
                mse_t = se.mean(dim=(0,2,3,4)).detach().cpu()                           # [T']
                mse_bt = se.mean(dim=(2,3,4))                                           # [time-curve] -> [B,T']
                valid_bt = torch.ones_like(mse_bt, dtype=mse_bt.dtype)                  
            else:
                se = (recon_seq - gt_slice).pow(2) * mask_slice
                num_b = se.sum(dim=(1,2,3,4)); den_b = mask_slice.sum(dim=(1,2,3,4)).clamp_min(1e-6)
                mse_b = (num_b / den_b)                                                 # [B]
                num_t = se.sum(dim=(0,2,3,4)); den_t = mask_slice.sum(dim=(0,2,3,4)).clamp_min(1e-6)
                mse_t = (num_t / den_t).detach().cpu()                                  # [T']
                num_bt = se.sum(dim=(2,3,4))                                            # NEW [time-curve] -> [B,T']
                den_bt = mask_slice.sum(dim=(2,3,4)).clamp_min(1e-6)                    # NEW [time-curve]
                mse_bt = (num_bt / den_bt)                                              # NEW [time-curve]
                valid_bt = (den_bt > 0).to(mse_bt.dtype)                                # NEW [time-curve]

            all_mse.extend([float(x) for x in mse_b])

            # -------- NEW [time-curve]: 累计跨样本的逐时间 MSE 统计 --------
            if save_time_curve:
                mse_bt_cpu   = mse_bt.detach().cpu()
                valid_bt_cpu = valid_bt.detach().cpu()
                if curve_sum is None:
                    Tprime = int(mse_bt_cpu.shape[1])
                    import torch as _torch
                    curve_sum   = _torch.zeros(Tprime)
                    curve_sumsq = _torch.zeros(Tprime)
                    curve_count = _torch.zeros(Tprime)
                    t_eval_ref  = t_eval.detach().cpu()
                else:
                    assert mse_bt_cpu.shape[1] == curve_sum.shape[0], "T' mismatch across batches"

                curve_sum   += (mse_bt_cpu * valid_bt_cpu).sum(dim=0)     # sum_b MSE
                curve_sumsq += ((mse_bt_cpu**2) * valid_bt_cpu).sum(dim=0)# sum_b MSE^2
                curve_count += valid_bt_cpu.sum(dim=0)                    # 有效样本计数
            # ----------------------------------------------------------------


            if save_pt or save_png:
                idx_vec = batch["index"].tolist() 
                for b, ds_idx in enumerate(idx_vec):
                    traj_id, t0 = samples_list[int(ds_idx)]
                    stem = f"traj{int(traj_id):04d}_t0{int(t0):04d}"

                    if save_pt:
                        payload = {
                            "traj_id": int(traj_id),
                            "t0": int(t0),
                            "t": t_eval.detach().cpu(),                                  # [T']
                            "pred": recon_seq_sp_last[b].detach().cpu(),                 # [H,W,T',C]
                            "gt":   gt_slice[b].permute(1,2,0,3).detach().cpu(),         # [H,W,T',C]
                            "mse_all": float(mse_b[b].item()),
                            "mse_t": mse_t,                                              # [T']
                            "normalized": bool(self.data_processor.cfg.normalize),
                        }
                        torch.save(payload, os.path.join(out_dir, f"{stem}.pt"))

                    if save_png:
                        try:
                            import numpy as np
                            Tsel = min(vis_cols, recon_seq.shape[1])
                            idx  = np.linspace(0, recon_seq.shape[1]-1, Tsel, dtype=int)
                            gt_np  = gt_slice[b].detach().cpu().numpy()                  # [T',H,W,C]
                            pr_np  = recon_seq[b].detach().cpu().numpy()                 # [T',H,W,C]
                            err_np = np.abs(gt_np - pr_np)

                            rows = 3
                            fig, axes = plt.subplots(rows, Tsel, figsize=(2.6*Tsel, 2.6*rows), squeeze=False)
                            for j, ti in enumerate(idx):
                                axes[0, j].imshow(gt_np[ti, :, :, c_vis]); axes[0, j].set_title(f"GT t={float(t_eval[ti]):.2f}")
                                axes[1, j].imshow(pr_np[ti, :, :, c_vis]); axes[1, j].set_title("Pred")
                                axes[2, j].imshow(err_np[ti, :, :, c_vis]); axes[2, j].set_title("|Err|")
                                for r in range(rows): axes[r, j].axis('off')
                            plt.tight_layout()
                            plt.savefig(os.path.join(out_dir, f"{stem}.png"), dpi=160)
                            plt.close(fig)
                        except Exception:
                            pass

            print(f"[{group}] batch_size={recon_seq.shape[0]} | T'={recon_seq.shape[1]} | MSE_mean(batch)={float(mse_b.mean().item()):.6e}")

        import numpy as np
        mean_mse = float(np.mean(all_mse)) if all_mse else float("nan")
        std_mse  = float(np.std(all_mse))  if all_mse else float("nan")
        summary = {"group": group, "n_samples": len(all_mse), "MSE_mean": mean_mse, "MSE_std": std_mse}
        with open(os.path.join(out_dir, f"summary_{group}.json"), "w") as f:
            json.dump(summary, f, indent=2)
        print(f"==> [{group}] Long-eval(batched) summary: MSE_mean={mean_mse:.6e} over {len(all_mse)} samples")

        # -------- NEW [time-curve]: 汇总并保存“误差-时间”曲线（均值±标准差） --------
        if save_time_curve and (curve_count is not None) and (curve_count.max().item() > 0):
            mean_curve = (curve_sum / curve_count).numpy()                 # [T']
            var_curve  = (curve_sumsq / curve_count).numpy() - mean_curve**2
            var_curve  = np.clip(var_curve, 0.0, None)
            std_curve  = np.sqrt(var_curve)

            fig, ax = plt.subplots(figsize=(7, 4))
            t_np = t_eval_ref.numpy()
            ax.plot(t_np, mean_curve, label="MSE (mean across traj)")
            ax.fill_between(t_np, mean_curve-std_curve, mean_curve+std_curve, alpha=0.2, label="±1 std")
            ax.set_xlabel("t")
            ax.set_ylabel("MSE")
            ax.set_title(f"MSE vs time — {group}")
            ax.grid(True, alpha=0.3)
            ax.legend()
            fig.tight_layout()
            plt.savefig(os.path.join(out_dir, f"mse_over_time_{group}.png"), dpi=180)
            plt.close(fig)

            # np.savez(os.path.join(out_dir, f"mse_over_time_{group}.npz"),
            #         t=t_np, mean=mean_curve, std=std_curve, count=curve_count.numpy())
            out_json = os.path.join(out_dir, f"mse_over_time_{group}.json")                         # NEW [time-curve-json]
            series = {                                                                              # NEW [time-curve-json]
                "group": group,                                                                     # NEW [time-curve-json]
                "t":    [float(x) for x in t_np.tolist()],                                          # NEW [time-curve-json]
                "mean": [float(x) for x in mean_curve.tolist()],                                    # NEW [time-curve-json]
                "std":  [float(x) for x in std_curve.tolist()],                                     # NEW [time-curve-json]
                "count":[int(x)   for x in curve_count.detach().cpu().numpy().astype(int).tolist()] # NEW [time-curve-json]
            }                                                                                       # NEW [time-curve-json]
            with open(out_json, "w") as f:                                                          # NEW [time-curve-json]
                json.dump(series, f, indent=2)                                                      # NEW [time-curve-json]
        # --------------------------------------------------------------------

        return summary


    def evaluate_long_by_indices(self, out_dir: str, group: str, indices: list[int],
                                 rollout_steps: int | None = None,
                                 **kwargs):
        from torch.utils.data import Subset, DataLoader
        base_loader = self._build_long_eval_loader(group=group, rollout_steps=rollout_steps)
        subset = Subset(base_loader.dataset, indices)
        sub_loader = DataLoader(subset, batch_size=len(indices), shuffle=False,
                                num_workers=0, pin_memory=True)
        return self.evaluate_long_trajs(group=group, out_dir=out_dir, loader=sub_loader, **kwargs)


    @torch.no_grad()
    def plot_centering_effects(
        self,
        group: str,
        seq_index: int = 0,
        rollout_steps: int = 32,
        save_dir: str | None = None,
        use_mask: bool = True,
    ):
        """
        Compare centered vs uncentered for both DISCRETE (Ad,b) and CONTINUOUS (latent_process) rollouts.
        Produce two figures:
        (1) per-time-step masked MSE curves
        (2) per-time-step latent L2-norm curves

        Args
        ----
        group         : "train" | "train_eval" | "test"
        seq_index     : take the seq_index-th sample (by dataset order)
        rollout_steps : number of predicted steps (<= T_out)
        save_dir      : if provided, figures will be saved here
        use_mask      : use dataset masks when computing MSE (recommended)
        """
        import os
        import numpy as np
        import matplotlib.pyplot as plt

        # ---------------------- helpers ----------------------
        def _one_sample_from_group(group: str, seq_index: int):
            """Return a batch (B=1) and some handy tensors."""
            from torch.utils.data import Subset, DataLoader
            self._ensure_loader(group)
            loader = {"train": self.train_loader,
                    "train_eval": self.train_eval_loader,
                    "test": self.test_loader}[group]
            dataset = loader.dataset
            if not (0 <= seq_index < len(dataset)):
                raise IndexError(f"seq_index={seq_index} out of bounds (0..{len(dataset)-1})")
            tmp = DataLoader(Subset(dataset, [seq_index]), batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
            batch = next(iter(tmp))
            return batch, dataset

        def _per_time_mse(pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor | None):
            """
            pred, gt: [B, T, H, W, C]; return [T] masked average MSE per time step.
            """
            B, T, H, W, C = pred.shape
            if mask is None or not use_mask:
                se = (pred - gt).pow(2)                       # [B,T,H,W,C]
                num = se.sum(dim=(0, 2, 3, 4))                # [T]
                den = torch.full_like(num, fill_value=B*H*W*C, dtype=pred.dtype)
            else:
                se = (pred - gt).pow(2) * mask                # [B,T,H,W,C]
                num = se.sum(dim=(0, 2, 3, 4))                # [T]
                den = mask.sum(dim=(0, 2, 3, 4)).clamp_min(1e-6)  # [T]
            return (num / den).cpu()                          # [T]

        def _latent_l2_per_time(lat_seq: torch.Tensor):
            """
            lat_seq: [T, B, D] -> return [T] average L2 over B.
            """
            l2 = torch.linalg.norm(lat_seq, dim=-1)           # [T,B]
            return l2.mean(dim=1).cpu()                       # [T]

        def _encode_initial_and_targets(batch):
            """
            Build initial latent a0 from the first nf_cond frames, and assemble ground-truth targets and masks.
            Returns:
            a0         : [B, D]
            gt_seq     : [B, T', H, W, C] where T' = rollout_steps+1 aligned with predictions
            mask_seq   : same shape as gt_seq or None
            t_eval_sub : [T'] relative evaluation times for the rollout
            """
            ground_truth = batch["data"].to(self.device)    # [B, T, H, W, C]
            masks = batch["mask"].to(self.device)           # [B, T, H, W, C]
            t_full = batch["t"][0].to(self.device)          # [T]
            B, T, H, W, C = ground_truth.shape
            n_cond = self.n_frames_cond - 1
            assert rollout_steps + self.data_processor.n_frames_cond <= T, \
                f"rollout_steps too long for this sequence (T={T})"

            # ---- encode a0 from the first nf_cond frames (your standard pipeline) ----
            mask_index = self.mask_to_bs_index(mask=masks)    # [B, S]
            delay_data = delay_stack_last_channel(x=ground_truth[:, :self.n_frames_cond, ...], d=self.n_frames_cond)  # [B,1,...]
            data_in = delay_data.flatten(0, 1).flatten(1, 2)   # [B, H*W, n_ch*nf_cond]
            data_in = self.index_points(data_in, mask_index)   # [B, S, n_ch*nf_cond]
            pos_in = self.pos_feat.flatten(0, 1).unsqueeze(0).expand(B, -1, -1).to(self.device)
            pos_in = self.index_points(pos_in, mask_index)     # [B, S, 2]
            data_in = torch.cat((data_in, pos_in), dim=-1)
            latent_token = self.encoder(data_in, pos_in)       # [B, K, token]
            a0 = latent_token.reshape(B, -1)                   # [B, D]

            # ---- ground-truth subseq aligned with predictions ----
            t_eval_sub = t_full[n_cond:n_cond+rollout_steps+1]  # [T']
            gt_seq = ground_truth[:, n_cond:n_cond+rollout_steps+1, ...]  # [B,T',H,W,C]
            mask_seq = masks[:, n_cond:n_cond+rollout_steps+1, ...]       # [B,T',H,W,C]
            return a0, gt_seq, mask_seq, t_eval_sub

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # ---------------------- fetch one sample ----------------------
        batch, _ = _one_sample_from_group(group, seq_index)
        a0, gt_seq, mask_seq, t_eval_sub = _encode_initial_and_targets(batch)
        B, Tprime, H, W, C = gt_seq.shape
        D = a0.shape[-1]

        # convenience
        zstar = getattr(self, "latent_center", None)
        has_center = (zstar is not None) and (zstar.numel() == D)

        # ---------------------- rollouts we will compare ----------------------
        curves = {
            "disc_uncentered": None,
            "disc_centered": None,
            "cont_uncentered": None,
            "cont_centered": None,
        }
        l2_curves = {}   # per-time latent L2 curves

        # ===== DISCRETE: Ad,b linear rollout =====
        if hasattr(self, "Ad_phase1") and (self.Ad_phase1 is not None):
            Ad = self.Ad_phase1.to(self.device, dtype=a0.dtype)
            b  = getattr(self, "b_phase1", None)
            b  = (b.to(self.device, dtype=a0.dtype).view(1, -1) if b is not None else None)

            # --- uncentered: z_{t+1} = A z_t + b ---
            z = torch.empty((Tprime, B, D), device=self.device, dtype=a0.dtype)
            z[0] = a0
            for t in range(Tprime - 1):
                z[t+1] = z[t] @ Ad.T + (b if b is not None else 0.0)
            l2_curves["disc_uncentered"] = _latent_l2_per_time(z)

            # decode to fields
            if self.dec_mode == "fouriermlp":
                latent_flat = z.permute(1, 0, 2).flatten(0, 1)        # [B*T', D]
                grid_dim = self.pos_feat.shape[-1]
                grid = self.pos_feat.reshape(-1, grid_dim).to(self.device)
                pred = self.decoder(grid=grid, latent_feat=latent_flat).reshape(B, Tprime, *self.shapelist, self.state_dim)
            else:
                latent_4 = z.permute(1, 0, 2).view(B, Tprime, self.state_dim, self.code_dim)
                pred = self.decoder(latent_4)
            curves["disc_uncentered"] = _per_time_mse(pred, gt_seq, mask_seq)

            # --- centered: y=z-z*, y_{t+1}=A y_t, final add z* back ---
            if has_center:
                zc = zstar.to(self.device, dtype=a0.dtype).view(1, -1)  # [1,D]
                y = torch.empty((Tprime, B, D), device=self.device, dtype=a0.dtype)
                y[0] = a0 - zc
                for t in range(Tprime - 1):
                    y[t+1] = y[t] @ Ad.T   # no bias if z* solves (I-Ad)z*=b
                z_cent = y + zc
                l2_curves["disc_centered"] = _latent_l2_per_time(z_cent)

                if self.dec_mode == "fouriermlp":
                    latent_flat = z_cent.permute(1, 0, 2).flatten(0, 1)
                    grid_dim = self.pos_feat.shape[-1]
                    grid = self.pos_feat.reshape(-1, grid_dim).to(self.device)
                    pred_c = self.decoder(grid=grid, latent_feat=latent_flat).reshape(B, Tprime, *self.shapelist, self.state_dim)
                else:
                    latent_4 = z_cent.permute(1, 0, 2).view(B, Tprime, self.state_dim, self.code_dim)
                    pred_c = self.decoder(latent_4)
                curves["disc_centered"] = _per_time_mse(pred_c, gt_seq, mask_seq)
            else:
                print("[plot_centering_effects] latent_center is None -> skip discrete-centered curve.")
        else:
            print("[plot_centering_effects] Ad_phase1/b_phase1 not found -> skip DISCRETE comparison.")

        # ===== CONTINUOUS: latent_process ODE rollout =====
        # uncentered: feed a0 directly (no center/decenter)
        out_uc = self.latent_process(alpha_0=a0, t_eval=t_eval_sub - t_eval_sub[0], teacher_forcing=False)
        z_uc = out_uc[0] if isinstance(out_uc, tuple) else out_uc   # [T', B, D]
        l2_curves["cont_uncentered"] = _latent_l2_per_time(z_uc)
        # decode
        if self.dec_mode == "fouriermlp":
            latent_flat = z_uc.permute(1, 0, 2).flatten(0, 1)
            grid_dim = self.pos_feat.shape[-1]
            grid = self.pos_feat.reshape(-1, grid_dim).to(self.device)
            pred_uc = self.decoder(grid=grid, latent_feat=latent_flat).reshape(B, Tprime, *self.shapelist, self.state_dim)
        else:
            latent_4 = z_uc.permute(1, 0, 2).view(B, Tprime, self.state_dim, self.code_dim)
            pred_uc = self.decoder(latent_4)
        curves["cont_uncentered"] = _per_time_mse(pred_uc, gt_seq, mask_seq)

        # centered: a0-z*, integrate, then add z* back (your train_phase2 style)
        if has_center:
            a0c = a0 - zstar.to(self.device, dtype=a0.dtype).view(1, -1)
            out_c = self.latent_process(alpha_0=a0c, t_eval=t_eval_sub - t_eval_sub[0], teacher_forcing=False)
            z_c = (out_c[0] if isinstance(out_c, tuple) else out_c) + zstar.to(self.device, dtype=a0.dtype).view(1, 1, -1)
            l2_curves["cont_centered"] = _latent_l2_per_time(z_c)
            if self.dec_mode == "fouriermlp":
                latent_flat = z_c.permute(1, 0, 2).flatten(0, 1)
                grid_dim = self.pos_feat.shape[-1]
                grid = self.pos_feat.reshape(-1, grid_dim).to(self.device)
                pred_c = self.decoder(grid=grid, latent_feat=latent_flat).reshape(B, Tprime, *self.shapelist, self.state_dim)
            else:
                latent_4 = z_c.permute(1, 0, 2).view(B, Tprime, self.state_dim, self.code_dim)
                pred_c = self.decoder(latent_4)
            curves["cont_centered"] = _per_time_mse(pred_c, gt_seq, mask_seq)
        else:
            print("[plot_centering_effects] latent_center is None -> skip continuous-centered curve.")

        # ---------------------- plot: MSE curves ----------------------
        t_axis = np.arange(Tprime)
        fig1, ax1 = plt.subplots()
        if curves["disc_uncentered"] is not None: ax1.plot(t_axis, curves["disc_uncentered"].numpy(), label="Discrete - Uncentered")
        if curves["disc_centered"]   is not None: ax1.plot(t_axis, curves["disc_centered"].numpy(),   label="Discrete - Centered")
        if curves["cont_uncentered"] is not None: ax1.plot(t_axis, curves["cont_uncentered"].numpy(), label="Continuous - Uncentered")
        if curves["cont_centered"]   is not None: ax1.plot(t_axis, curves["cont_centered"].numpy(),   label="Continuous - Centered")
        ax1.set_xlabel("time step")
        ax1.set_ylabel("masked MSE" if use_mask else "MSE")
        ax1.set_title(f"Per-step error  (group={group}, idx={seq_index})")
        ax1.legend()
        ax1.grid(True, linestyle="--", alpha=0.3)
        if save_dir:
            p1 = os.path.join(save_dir, f"center_vs_nocenter_mse_{group}_idx{seq_index}.png")
            fig1.savefig(p1, dpi=200, bbox_inches="tight")
            print("Saved:", p1)
        plt.close(fig1)

        # ---------------------- plot: latent L2 curves ----------------------
        fig2, ax2 = plt.subplots()
        for k in ["disc_uncentered", "disc_centered", "cont_uncentered", "cont_centered"]:
            v = l2_curves.get(k, None)
            if v is not None:
                ax2.plot(t_axis, v.numpy(), label=k.replace("_", " "))
        ax2.set_xlabel("time step")
        ax2.set_ylabel("‖latent‖₂ (avg over batch)")
        ax2.set_title(f"Latent L2 per step  (group={group}, idx={seq_index})")
        ax2.legend()
        ax2.grid(True, linestyle="--", alpha=0.3)
        if save_dir:
            p2 = os.path.join(save_dir, f"center_vs_nocenter_l2_{group}_idx{seq_index}.png")
            fig2.savefig(p2, dpi=200, bbox_inches="tight")
            print("Saved:", p2)
        plt.close(fig2)

        # ---------------------- quick console summary ----------------------
        def _summ(v):
            if v is None: return None
            return dict(mean=float(v.mean()), last=float(v[-1]), min=float(v.min()), max=float(v.max()))
        summary = {k: _summ(v) for k, v in curves.items()}
        print("[plot_centering_effects] summary (per-step MSE):", summary)
        return summary

    # ---------------------------------------------
    # Utilities for diagnostics & regularizers
    # ---------------------------------------------

    @staticmethod
    @torch.no_grad()
    def _spectral_radius(A: torch.Tensor) -> float:
        """Return spectral radius max|lambda_i(A)| for a real square matrix."""
        eig = torch.linalg.eigvals(A)  # complex
        return float(eig.abs().max().item())


    @staticmethod
    @torch.no_grad()
    def _fit_affine_Qc(X: torch.Tensor, Y: torch.Tensor, ridge: float = 1e-6) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Fit an affine map Y ≈ X @ Q.T + c using ridge least squares (row-vector convention).
        X, Y: [N, D]
        Returns:
            Q: [D, D], c: [D]
        """
        assert X.shape == Y.shape
        N, D = X.shape
        ones = torch.ones(N, 1, device=X.device, dtype=X.dtype)
        X1 = torch.cat([X, ones], dim=1)                 # [N, D+1]
        R = torch.zeros(D + 1, D + 1, device=X.device, dtype=X.dtype)
        R[:D, :D] = ridge * torch.eye(D, device=X.device, dtype=X.dtype)
        Theta = torch.linalg.solve(X1.T @ X1 + R, X1.T @ Y)  # [D+1, D]
        Q = Theta[:D, :].T.contiguous()
        c = Theta[D, :].contiguous()
        return Q, c


    @torch.no_grad()
    def _procrustes_error_affine(self, Z_old: torch.Tensor, Z_new: torch.Tensor, ridge: float = 1e-6) -> float:
        """
        Relative Procrustes error with affine map:
        err = ||Z_new - (Z_old @ Q.T + c)|| / ||Z_new||
        where (Q,c) is fitted by ridge regression.
        Z_old, Z_new: [N, D]
        """
        N = min(Z_old.size(0), Z_new.size(0))
        Z_old = Z_old[:N]
        Z_new = Z_new[:N]
        Q, c = self._fit_affine_Qc(Z_old, Z_new, ridge=ridge)
        pred = Z_old @ Q.T + c
        err = torch.norm(Z_new - pred) / (torch.norm(Z_new) + 1e-12)
        return float(err.item())


    """def _multistep_latent_consistency(self, Z: torch.Tensor, A: torch.Tensor, b: torch.Tensor | None, H: int = 0) -> torch.Tensor:
        if H is None or H <= 1:
            return Z.new_zeros(())
        T, B, D = Z.shape
        loss = Z.new_zeros(())
        for h in range(2, min(H, T - 1) + 1):
            Zt  = Z[:-h].reshape(-1, D)    # all (t,b) that have t+h
            Zth = Z[h:].reshape(-1, D)
            roll = Zt
            for _ in range(h):
                roll = roll @ A.T + (b if b is not None else 0.0)
            loss = loss + F.mse_loss(roll, Zth)
        return loss / max(1, min(H, T - 1) - 1)"""
    
    
    def _multistep_latent_consistency(self,
                                      latent_states: torch.Tensor,  # [T, B, D]
                                      A: torch.Tensor,              # [D, D]
                                      b: torch.Tensor | None = None,
                                      H: int = 4,
                                      gamma: float = 1.0):
        T, B, D = latent_states.shape
        if H <= 0 or T < 2:
            return latent_states.new_tensor(0.0)
        H = min(H, T - 1)

        A_T = A.transpose(-1, -2)
        w_sum = 0.0
        loss_acc = latent_states.new_tensor(0.0)

        z_pred = latent_states[:-1]                 # [T-1, B, D]

        for h in range(1, H + 1):
            z_pred = z_pred @ A_T + (b if b is not None else 0.0)
            T_h = T - h
            z_use = z_pred[:T_h]                    # [T-h, B, D]
            target = latent_states[h:]              # [T-h, B, D]
            mse_h = (z_use - target).pow(2).mean()
            w = (gamma ** h)
            loss_acc = loss_acc + w * mse_h
            w_sum += w
            if h < H:
                z_pred = z_pred[:-1]                
        return loss_acc / max(w_sum, 1e-12)


    ########################## visualizations ##########################
    @torch.no_grad()
    def plot_dyn_energy_stats(
        self,
        dataloader,
        save_dir: str,
        fname_prefix: str = "dyn_energy",
        center_and_whiten: bool = True,
        use_projector_space: bool = True,
        # ---- style knobs ----
        fig_w: float = 7.2,
        fig_h: float = 4.0,
        line_w: float = 2.4,            # mean-curve line width
        band_alpha: float = 0.18,       # ±std band transparency
        tick_labelsize: int = 14,       # tick numbers size (bigger axis numbers)
        axis_labelsize: int = 15,       # x/y label size
        spine_lw: float = 1.2,          # all four spines width
        major_grid_alpha: float = 0.32, # grid strength
        minor_grid_alpha: float = 0.22,
        use_minor_grid: bool = True,
        color_linear: str = "#1f77b4",  # vivid blue
        color_memory: str = "#d62728",  # vivid red
        dpi: int = 450,
    ) -> None:
        """
        Plot per-timestep L2 norms of the linear and memory contributions of the latent process.

        For every batch the sequence is encoded with ``_encode_and_recon``, optionally
        centered/whitened and projected, and passed once through the latent process
        with teacher forcing. The linear term and the memory term are then evaluated
        for every step, their L2 norm over the feature axis is taken, and mean ± std
        over all collected samples is plotted per step.

        When the latent process returns no memory states, the memory contribution is
        treated as zero in both discrete and continuous mode.

        Args:
            dataloader: iterable of batches; each batch is indexed with "data" and "t".
            save_dir: output directory (created if missing).
            fname_prefix: file-name stem for all outputs.
            center_and_whiten: apply ``_center_latent`` and ``_whiten_latent`` first.
            use_projector_space: apply ``_project_latent`` when ``self.use_projector`` is set.
            fig_w, fig_h: figure size passed to ``plt.subplots``.
            line_w: line width of the mean curves.
            band_alpha: opacity of the ±std bands.
            tick_labelsize, axis_labelsize: font sizes for ticks and axis labels.
            spine_lw: line width of the four axes spines.
            major_grid_alpha, minor_grid_alpha: grid opacities.
            use_minor_grid: draw the minor grid as well.
            color_linear, color_memory: curve colors.
            dpi: resolution of the saved PNG.

        Saves:
            save_dir/{fname_prefix}.png   (skipped when no timestep was collected)
            save_dir/{fname_prefix}.json
            save_dir/{fname_prefix}_raw.pt

        Raises:
            RuntimeError: in continuous mode when the latent process has no ``_apply_linear``.
        """
        import json
        import matplotlib.ticker as mticker
        from matplotlib import rcParams

        os.makedirs(save_dir, exist_ok=True)

        # --------------------------- accumulate stats ---------------------------
        # One bucket per step; each bucket collects one [B] tensor per batch.
        lin_list = []
        mem_list = []

        for batch in dataloader:
            B = batch["data"].shape[0]                 # batch size (only the shape is needed)
            t_eval = batch["t"][0].to(self.device)     # time stamps of the first sample

            latent_states, _, _ = self._encode_and_recon(batch)  # presumably [T', B, D_full] — TODO confirm
            Tprime = latent_states.size(0)  # timesteps used by latent process

            # Canonicalize to internal latent process space (center/whiten/project)
            z_ = latent_states
            if center_and_whiten:
                z_ = self._center_latent(z_)
                z_ = self._whiten_latent(z_)
            if use_projector_space and self.use_projector:
                z_ = self._project_latent(z_)

            # One teacher-forced pass to expose memory states if any
            out = self.latent_process.forward(
                alpha_0=z_[0],
                t_eval=t_eval[self.n_frames_cond-1:],
                memory_init=None,
                teacher_forcing=True,
                tf_alpha=z_,
                tf_epsilon=0.0,
                tf_mask=None,
            )
            # A tuple result carries the memory states in second position.
            memory_states = out[1] if isinstance(out, tuple) else None

            # Separate linear vs memory contributions (discrete / continuous)
            if self.latent_mode == "discrete":
                A = self.latent_process.A
                lin = torch.matmul(z_[:-1], A.T)          # A z_k for every step that has a successor

                if memory_states is None:
                    # No memory branch exposed: zero contribution (same fallback as continuous mode).
                    mem = lin.new_zeros(lin.shape)
                elif getattr(self.latent_process, "memory_type", "decoder") == "residual":
                    mem = memory_states[:-1]
                else:
                    Dm = memory_states.shape[-1]
                    mem_flat = memory_states[:-1].reshape(-1, Dm)
                    corr = self.latent_process.memory_decoder(mem_flat).view(Tprime - 1, B, -1)
                    gate = getattr(self.latent_process, "gate", 1.0)
                    mem = gate * corr
            else:
                # continuous-time: need linear drift & memory branch
                if not hasattr(self.latent_process, "_apply_linear"):
                    raise RuntimeError("latent_process._apply_linear is required for continuous mode.")
                lin = self.latent_process._apply_linear(z_[:-1])
                if memory_states is not None and hasattr(self.latent_process, "memory_decoder"):
                    Dm = memory_states.shape[-1]
                    mem_flat = memory_states[1:].reshape(-1, Dm)               # align with training aux
                    f_mem = self.latent_process.memory_decoder(mem_flat).view(Tprime - 1, B, -1)
                    mem_scale = getattr(self.latent_process, "mem_scale", 1.0)
                    mem = mem_scale * f_mem
                else:
                    mem = lin.new_zeros(lin.shape)

            # L2 norms per step across feature dim
            lin_l2 = torch.linalg.norm(lin, ord=2, dim=-1)
            mem_l2 = torch.linalg.norm(mem, ord=2, dim=-1)

            # Grow the buckets on demand so a later, longer batch cannot index out of range.
            n_steps = Tprime - 1
            while len(lin_list) < n_steps:
                lin_list.append([])
                mem_list.append([])
            for k in range(n_steps):
                lin_list[k].append(lin_l2[k].detach().cpu())
                mem_list[k].append(mem_l2[k].detach().cpu())

        # --------------------------- aggregate mean/std ---------------------------
        # Concatenate once; the same tensors feed both the statistics and the raw dump.
        lin_cat = [torch.cat(v, dim=0) for v in lin_list]   # list of [N_total]
        mem_cat = [torch.cat(v, dim=0) for v in mem_list]

        lin_mean = [float(v.float().mean()) for v in lin_cat]
        mem_mean = [float(v.float().mean()) for v in mem_cat]
        lin_std = [float(v.float().std(unbiased=False)) for v in lin_cat]
        mem_std = [float(v.float().std(unbiased=False)) for v in mem_cat]

        # --------------------------- save raw + json ---------------------------
        raw_path = os.path.join(save_dir, f"{fname_prefix}_raw.pt")
        torch.save({
            "lin_l2_per_step": lin_cat,
            "mem_l2_per_step": mem_cat,
        }, raw_path)

        json_path = os.path.join(save_dir, f"{fname_prefix}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump({
                "t_index": list(range(1, len(lin_list) + 1)),
                "linear": {"mean": lin_mean, "std": lin_std},
                "memory": {"mean": mem_mean, "std": mem_std},
                "meta": {
                    "latent_mode": self.latent_mode,
                    "use_projector": bool(self.use_projector and use_projector_space),
                    "center_whiten": bool(center_and_whiten),
                }
            }, f, indent=2)

        if not lin_list:
            # Nothing to draw (empty loader or single-step sequences); the data files are still written.
            print(f"[dyn_energy] no timesteps collected; plot skipped.\n- json: {json_path}\n- raw:  {raw_path}")
            return

        # --------------------------- plot (publication style) ---------------------------
        # NOTE(review): rcParams.update changes matplotlib settings process-wide, not only
        # for this figure; kept as in the original so later figures are unaffected by this edit.
        rcParams.update({
            "axes.formatter.use_mathtext": True,
            "mathtext.default": "regular",
        })

        fig, ax = plt.subplots(figsize=(fig_w, fig_h))
        x = np.arange(1, len(lin_list) + 1)

        # Show all four spines with consistent width
        for side in ("left", "right", "top", "bottom"):
            ax.spines[side].set_visible(True)
            ax.spines[side].set_linewidth(spine_lw)

        # Larger tick labels (axis numbers bigger)
        ax.tick_params(axis="both", which="major", labelsize=tick_labelsize, direction="out", length=4.5, width=0.9)
        ax.tick_params(axis="both", which="minor", labelsize=tick_labelsize-1, direction="out", length=3.0, width=0.7)

        # Dense major/minor ticks on x, minor ticks on y as well
        ax.xaxis.set_major_locator(mticker.MaxNLocator(nbins=8, steps=[1, 2, 2.5, 5, 10]))
        ax.xaxis.set_minor_locator(mticker.AutoMinorLocator(4))
        ax.yaxis.set_minor_locator(mticker.AutoMinorLocator(4))

        # Grid (major + optional minor)
        ax.grid(True, which="major", linestyle="--", alpha=major_grid_alpha)
        if use_minor_grid:
            ax.grid(True, which="minor", linestyle=":", alpha=minor_grid_alpha)

        def _band(ax, x, mean, std, color, label):
            """Draw the mean curve and its ±std band in one fixed color."""
            m = np.asarray(mean)
            s = np.asarray(std)
            # The label is attached to the line so a legend can pick it up if one is enabled.
            ax.plot(x, m, color=color, linewidth=line_w, label=label)
            ax.fill_between(x, m - s, m + s, color=color, alpha=band_alpha, linewidth=0.0)

        _band(ax, x, lin_mean, lin_std, color_linear, label="linear")
        _band(ax, x, mem_mean, mem_std, color_memory, label="memory")

        # Axis labels (bold-ish, larger)
        ax.set_xlabel("time step", fontsize=axis_labelsize, fontweight="bold")
        ax.set_ylabel(r"$\ell_{2}$ norm", fontsize=axis_labelsize, fontweight="bold")

        # With a single step, low == high; leave the automatic limits in that case.
        if len(x) > 1:
            ax.set_xlim(x[0], x[-1])

        # Legend inside axes with subtle background
        # leg = ax.legend(
        #     loc="best",
        #     frameon=True, fancybox=True,
        #     framealpha=0.95, edgecolor="#444444",
        #     facecolor="white",
        #     fontsize=tick_labelsize-1
        # )

        fig.tight_layout()
        png_path = os.path.join(save_dir, f"{fname_prefix}.png")
        fig.savefig(png_path, dpi=dpi, bbox_inches="tight", pad_inches=0.06)
        plt.close(fig)

        print(f"[dyn_energy] saved: {png_path}\n- json: {json_path}\n- raw:  {raw_path}")

    
    @torch.no_grad()
    def visualize_sample_evolution(
        self,
        group: str,
        save_dir: str,
        *,
        traj_id: int | None = None,
        t0: int | None = None,
        seq_index: int | None = None,
        steps: int | None = None,
        fname_prefix: str = "latent_evolution",
        # ---------- background options ----------
        bg_mode: str = "none",            # "none" | "discrete_map" | "vector_field" | "landscape"
        grid_res: int = 25,
        field_stride: int = 2,
        field_scale: float = 1.0,
        bg_cmap: str | None = "spiral",
        bg_alpha: float = 0.65,
        # ---- contour overlay on landscape ----
        contour_levels: int = 28,
        contour_color: str = "white",
        contour_lw: float = 0.8,
        contour_alpha: float = 0.55,
        # ---------- extra random samples ----------
        extra_k: int = 3,
        extra_seed: int | None = 123,
        extra_alpha: float = 0.3,        # lighter extras
        # ---------- smoothing (spline) ----------
        smooth_curve: bool = True,
        samples_per_step: int = 12,
        # ---------- markers ----------
        show_markers_main: bool = True,   # keep markers for main
        show_markers_extra: bool = False, # <-- TURNED OFF for extras
        marker_size_main: float = 18.0,
        marker_size_extra: float = 12.0,
        marker_alpha_main: float = 0.9,
        marker_alpha_extra: float = 0.45,
        # ---------- arrows (hollow, sharp) ----------
        arrow_count_main: int = 12,
        arrow_count_extra: int = 8,
        arrow_size_main: float = 14.0,
        arrow_size_extra: float = 11.0,
        # linewidths (main solid; extras thinner)
        lw_main: float = 3.2,
        lw_extra: float = 1.8,            # <-- slightly thinner than before
        lw_arrow_main: float | None = None,   # if None -> 0.8 * lw_main
        lw_arrow_extra: float | None = None,  # if None -> 0.75 * lw_extra
        # ---------- axes & frame ----------
        spine_lw: float = 1.2,
        tick_labelsize: int = 14,
    ):
        """
        Main trajectory (by traj_id,t0 or seq_index) + several random extras,
        projected to PCA-2D fitted on the main GT latents.

        This version:
        - MAIN curves are SOLID (no gradient).
        - EXTRAS are thinner (lw_extra) and draw NO discrete markers by default.
        - Hollow, sharp arrow heads; four spines; DPI=450.
        """
        import os, json, numpy as np, matplotlib.pyplot as plt, matplotlib.ticker as mticker
        from matplotlib import rcParams
        from matplotlib.colors import LinearSegmentedColormap
        from matplotlib.collections import LineCollection
        from matplotlib.patches import FancyArrowPatch
        import torch
        from torch.utils.data import Subset, DataLoader
        rng = np.random.default_rng(extra_seed)

        # Colors: GT / Linear / Full (sky blue)
        COL_GT   = "#000000"
        COL_LIN  = "#ff7f0e"
        COL_FULL = "#2FA4FF"

        def _get_bg_cmap(name: str | None):
            if name in (None, "", "spiral"):
                colors = [
                    "#ecff3a", "#bce86b", "#7dd672", "#38c27f", "#1da1a2",
                    "#2d79c7", "#4653b3", "#5b3a9e", "#4a2d84", "#2b1b53"
                ]
                return LinearSegmentedColormap.from_list("spiral", colors, N=256)
            return plt.get_cmap(name)

        # ---------- dataset plumbing ----------
        self._ensure_loader(group)
        loader = {"train": self.train_loader,
                "train_eval": self.train_eval_loader,
                "test": self.test_loader}[group]
        dataset = loader.dataset

        if seq_index is None:
            if not hasattr(dataset, "samples"):
                raise ValueError("dataset has no 'samples' attribute; please pass seq_index.")
            pos = {(int(tid), int(s0)): i for i, (tid, s0) in enumerate(dataset.samples)}
            key = (int(traj_id), int(t0))
            if key not in pos:
                raise KeyError(f"Sample (traj_id={traj_id}, t0={t0}) not found.")
            seq_index = pos[key]

        os.makedirs(save_dir, exist_ok=True)

        def _fetch_by_index(idx: int):
            tmp = DataLoader(Subset(dataset, [idx]), batch_size=1, shuffle=False, num_workers=0)
            batch = next(iter(tmp))
            t_full = batch["t"][0].to(self.device)              # [T]
            t_eff  = t_full[self.n_frames_cond-1:]              # align with encoded seq
            z_gt_full, _, _ = self._encode_and_recon(batch)     # [T',1,D]
            z_gt_full = z_gt_full[:, 0]
            return batch, t_eff, z_gt_full

        def _to_internal(z_seq: torch.Tensor) -> torch.Tensor:
            z_ = self._center_latent(z_seq)
            z_ = self._whiten_latent(z_)
            if self.use_projector:
                z_ = self._project_latent(z_)
            return z_

        # ---------- main sample ----------
        batch_main, t_eff_main, z_gt_full = _fetch_by_index(seq_index)
        Tprime_main = z_gt_full.size(0)
        steps_eff = int(steps) if steps is not None else int(Tprime_main)
        steps_eff = max(2, min(steps_eff, int(Tprime_main)))

        z_gt_int_main = _to_internal(z_gt_full)
        Dz = z_gt_int_main.size(-1)

        # linear-only rollout
        if self.latent_mode == "discrete":
            A_int = self.latent_process.A.to(z_gt_int_main)
            z_lin = [z_gt_int_main[0]]
            for _ in range(1, steps_eff):
                z_lin.append(z_lin[-1] @ A_int.T)
            z_lin_main = torch.stack(z_lin, dim=0)
        else:
            if hasattr(self.latent_process, "mem_scale"):
                old = float(self.latent_process.mem_scale.detach().cpu().item())
                try:
                    self.latent_process.mem_scale.fill_(0.0)
                    z_lin_full, *_ = self.latent_process(
                        alpha_0=z_gt_int_main[0].unsqueeze(0),
                        t_eval=t_eff_main[:steps_eff],
                        teacher_forcing=False
                    )
                    z_lin_main = z_lin_full[:, 0]
                finally:
                    self.latent_process.mem_scale.fill_(old)
            else:
                if not hasattr(self.latent_process, "_apply_linear"):
                    raise RuntimeError("Need mem_scale or _apply_linear for continuous mode.")
                dt = float(self.dt_eval)
                z_lin = [z_gt_int_main[0]]
                for _ in range(1, steps_eff):
                    f_lin = self.latent_process._apply_linear(z_lin[-1].unsqueeze(0)).squeeze(0)
                    z_lin.append(z_lin[-1] + dt * f_lin)
                z_lin_main = torch.stack(z_lin, dim=0)

        # full rollout
        z_full_main, *_ = self.latent_process(
            alpha_0=z_gt_int_main[0].unsqueeze(0),
            t_eval=t_eff_main[:steps_eff],
            teacher_forcing=False
        )
        z_full_main = z_full_main[:, 0]

        # PCA on main GT ONLY
        X = z_gt_int_main[:steps_eff].float()
        mu = X.mean(dim=0, keepdim=True)
        Xc = X - mu
        C = Xc.T @ Xc / max(1, Xc.size(0) - 1)
        _, vecs = torch.linalg.eigh(C)
        W = vecs[:, -2:]                                        # [Dz, 2]

        def _proj2(Z: torch.Tensor) -> np.ndarray:
            return ((Z.float() - mu) @ W).cpu().numpy()

        P_gt_main   = _proj2(z_gt_int_main[:steps_eff])
        P_lin_main  = _proj2(z_lin_main)
        P_full_main = _proj2(z_full_main)

        # bounds from main
        all_xy = np.vstack([P_gt_main, P_lin_main, P_full_main])
        xmin, xmax = float(all_xy[:, 0].min()), float(all_xy[:, 0].max())
        ymin, ymax = float(all_xy[:, 1].min()), float(all_xy[:, 1].max())
        xmid, ymid = 0.5 * (xmin + xmax), 0.5 * (ymin + ymax)
        L = max(xmax - xmin, ymax - ymin); pad = 0.05 * L
        xlim = (xmid - 0.5 * L - pad, xmid + 0.5 * L + pad)
        ylim = (ymid - 0.5 * L - pad, ymid + 0.5 * L + pad)

        # ---------- background prep ----------
        A_lin = None
        G_gen = None

        def _try_build_A_lin_continuous() -> torch.Tensor | None:
            if not hasattr(self.latent_process, "_apply_linear"):
                return None
            I = torch.eye(Dz, device=z_gt_int_main.device, dtype=z_gt_int_main.dtype)
            F = self.latent_process._apply_linear(I)
            return F if F.shape == (Dz, Dz) else F.reshape(Dz, Dz)

        def _logm_via_scipy(A: torch.Tensor) -> torch.Tensor:
            from scipy.linalg import logm
            A_np = A.detach().to("cpu", dtype=torch.float64).numpy()
            Lm = logm(A_np)
            Lm = np.real_if_close(Lm, tol=1e-7)
            if np.iscomplexobj(Lm):
                Lm = Lm.real
            return torch.from_numpy(Lm).to(device=A.device, dtype=A.dtype)

        if bg_mode in {"discrete_map", "vector_field", "landscape"}:
            if self.latent_mode == "discrete":
                A_lin = self.latent_process.A.detach().to(z_gt_int_main)
                try:
                    G_gen = _logm_via_scipy(A_lin) / float(self.dt_eval)
                except Exception:
                    G_gen = None
            else:
                A_lin = _try_build_A_lin_continuous()
                if A_lin is not None:
                    G_gen = A_lin.detach().clone()

        bg = {"mode": bg_mode, "xlim": xlim, "ylim": ylim}
        if bg_mode != "none":
            Xg = np.linspace(xlim[0], xlim[1], grid_res)
            Yg = np.linspace(ylim[0], ylim[1], grid_res)
            XX, YY = np.meshgrid(Xg, Yg)
            P_grid = np.stack([XX, YY], axis=-1).reshape(-1, 2)
            W_np = W.cpu().numpy(); mu_np = mu.cpu().numpy().reshape(-1)
            Z_grid = (mu_np[None, :] + P_grid @ W_np.T)
            Z_grid_t = torch.from_numpy(Z_grid).to(z_gt_int_main)

            if bg_mode in {"discrete_map", "vector_field"}:
                if bg_mode == "discrete_map":
                    if A_lin is None:
                        U = V = np.zeros_like(XX)
                    else:
                        Z_next = (Z_grid_t @ A_lin.T)
                        P_next = _proj2(Z_next)
                        P_curr = _proj2(Z_grid_t)
                        dP = P_next - P_curr
                        U = dP[:, 0].reshape(XX.shape); V = dP[:, 1].reshape(XX.shape)
                else:
                    if G_gen is None and A_lin is not None:
                        try:
                            G_gen = _logm_via_scipy(A_lin) / float(self.dt_eval)
                        except Exception:
                            G_gen = None
                    if G_gen is not None:
                        Vz = (Z_grid_t @ G_gen.T)
                        dP = _proj2(Vz)
                    else:
                        if A_lin is None:
                            dP = np.zeros_like(P_grid)
                        else:
                            Z_next = (Z_grid_t @ A_lin.T)
                            dP = _proj2(Z_next) - _proj2(Z_grid_t)
                    U = dP[:, 0].reshape(XX.shape); V = dP[:, 1].reshape(XX.shape)
                bg.update({"grid_X": XX, "grid_Y": YY, "U": U, "V": V})

            elif bg_mode == "landscape":
                if G_gen is not None:
                    Vz = (Z_grid_t @ G_gen.T)
                    S = torch.linalg.norm(Vz, dim=-1).cpu().numpy().reshape(XX.shape)
                else:
                    if A_lin is None:
                        S = np.zeros_like(XX)
                    else:
                        Z_next = (Z_grid_t @ A_lin.T)
                        disp = Z_next - Z_grid_t
                        S = torch.linalg.norm(disp, dim=-1).cpu().numpy().reshape(XX.shape)
                # mild smoothing for nicer rings
                try:
                    from scipy.ndimage import gaussian_filter
                    S_plot = gaussian_filter(S, sigma=0.8)
                except Exception:
                    S_plot = S
                bg.update({"grid_X": XX, "grid_Y": YY, "S": S_plot})

        # ---------- collect extra samples ----------
        extra_seq_indices, P_gt_extra, P_lin_extra, P_full_extra = [], [], [], []
        if hasattr(dataset, "samples") and extra_k > 0:
            pool = np.arange(len(dataset.samples))
            pool = pool[pool != int(seq_index)]
            if len(pool) > 0:
                rng.shuffle(pool)
                chosen = pool[:min(extra_k, len(pool))]
                for idx in chosen:
                    try:
                        _, t_eff_e, z_gt_full_e = _fetch_by_index(int(idx))
                        Tprime_e = int(z_gt_full_e.size(0))
                        steps_e = max(2, min(steps_eff, Tprime_e))

                        z_gt_int_e = _to_internal(z_gt_full_e)
                        # linear-only for extra
                        if self.latent_mode == "discrete":
                            A_int_e = self.latent_process.A.to(z_gt_int_e)
                            tmp = [z_gt_int_e[0]]
                            for _ in range(1, steps_e):
                                tmp.append(tmp[-1] @ A_int_e.T)
                            z_lin_e = torch.stack(tmp, dim=0)
                        else:
                            if hasattr(self.latent_process, "mem_scale"):
                                old = float(self.latent_process.mem_scale.detach().cpu().item())
                                try:
                                    self.latent_process.mem_scale.fill_(0.0)
                                    z_lin_full_e, *_ = self.latent_process(
                                        alpha_0=z_gt_int_e[0].unsqueeze(0),
                                        t_eval=t_eff_e[:steps_e],
                                        teacher_forcing=False
                                    )
                                    z_lin_e = z_lin_full_e[:, 0]
                                finally:
                                    self.latent_process.mem_scale.fill_(old)
                            else:
                                if not hasattr(self.latent_process, "_apply_linear"):
                                    continue
                                dt = float(self.dt_eval)
                                tmp = [z_gt_int_e[0]]
                                for _ in range(1, steps_e):
                                    f_lin = self.latent_process._apply_linear(tmp[-1].unsqueeze(0)).squeeze(0)
                                    tmp.append(tmp[-1] + dt * f_lin)
                                z_lin_e = torch.stack(tmp, dim=0)

                        z_full_e, *_ = self.latent_process(
                            alpha_0=z_gt_int_e[0].unsqueeze(0),
                            t_eval=t_eff_e[:steps_e],
                            teacher_forcing=False
                        )
                        z_full_e = z_full_e[:, 0]

                        P_gt_extra.append(_proj2(z_gt_int_e[:steps_e]))
                        P_lin_extra.append(_proj2(z_lin_e))
                        P_full_extra.append(_proj2(z_full_e))
                        extra_seq_indices.append(int(idx))

                        all_xy = np.vstack([all_xy, P_gt_extra[-1], P_lin_extra[-1], P_full_extra[-1]])
                    except Exception:
                        continue

        # update bounds including extras
        xmin, xmax = float(all_xy[:, 0].min()), float(all_xy[:, 0].max())
        ymin, ymax = float(all_xy[:, 1].min()), float(all_xy[:, 1].max())
        xmid, ymid = 0.5 * (xmin + xmax), 0.5 * (ymin + ymax)
        L = max(xmax - xmin, ymax - ymin); pad = 0.05 * L
        xlim = (xmid - 0.5 * L - pad, xmid + 0.5 * L + pad)
        ylim = (ymid - 0.5 * L - pad, ymid + 0.5 * L + pad)

        # ---------- smoothing ----------
        def _smooth_polyline(P: np.ndarray, n_per_seg: int) -> np.ndarray:
            """Return smoothed & upsampled polyline P (N,2)."""
            if not smooth_curve or P.shape[0] < 3 or n_per_seg <= 1:
                t = np.arange(P.shape[0], dtype=float)
                ti = np.linspace(0, t[-1], (P.shape[0]-1)*max(1, n_per_seg)+1)
                xi = np.interp(ti, t, P[:,0]); yi = np.interp(ti, t, P[:,1])
                return np.stack([xi, yi], axis=-1)
            try:
                from scipy.interpolate import CubicSpline
                s = np.concatenate([[0.0], np.cumsum(np.linalg.norm(np.diff(P, axis=0), axis=1))])
                if s[-1] <= 0: return P.copy()
                s /= s[-1]
                si = np.linspace(0.0, 1.0, (P.shape[0]-1)*n_per_seg + 1)
                csx = CubicSpline(s, P[:,0], bc_type="natural")
                csy = CubicSpline(s, P[:,1], bc_type="natural")
                xi = csx(si); yi = csy(si)
                return np.stack([xi, yi], axis=-1)
            except Exception:
                # Hermite fallback
                N = P.shape[0]
                Tt = np.zeros_like(P)
                Tt[0]  = P[1]   - P[0]
                Tt[-1] = P[-1]  - P[-2]
                if N > 2: Tt[1:-1] = 0.5 * (P[2:] - P[:-2])
                out = [P[0]]
                for i in range(N-1):
                    p0, p1 = P[i], P[i+1]; m0, m1 = Tt[i], Tt[i+1]
                    ts = np.linspace(0.0, 1.0, n_per_seg+1)
                    h00 = (2*ts**3 - 3*ts**2 + 1)[:, None]
                    h10 = (ts**3 - 2*ts**2 + ts)[:, None]
                    h01 = (-2*ts**3 + 3*ts**2)[:, None]
                    h11 = (ts**3 - ts**2)[:, None]
                    seg = h00*p0 + h10*m0 + h01*p1 + h11*m1
                    if i < N-2: seg = seg[:-1]
                    out.append(seg)
                return np.vstack(out)

        # gradient segment drawer (used for EXTRAS only)
        def _line_segments(P: np.ndarray):
            return np.stack([P[:-1], P[1:]], axis=1)

        def _grad_line(ax, P: np.ndarray, c0: str, c1: str, lw: float, alpha: float):
            from matplotlib.colors import LinearSegmentedColormap
            segs = _line_segments(P)
            lc = LineCollection(segs, linewidths=lw, capstyle="round")
            cmap = LinearSegmentedColormap.from_list("tmp", [c0, c1])
            t = np.linspace(0.0, 1.0, max(2, len(segs)))
            lc.set_cmap(cmap); lc.set_array(t); lc.set_alpha(alpha)
            ax.add_collection(lc)
            return lc

        # hollow, sharp arrows
        def _add_arrows(ax, P: np.ndarray, color: str, n: int, ms: float, lw: float, alpha: float):
            N = len(P)
            if N < 2 or n <= 0: return
            idxs = np.linspace(1, N-1, num=min(n, max(1, N-1)), dtype=int)
            idxs = np.unique(np.clip(idxs, 1, N-1))
            for i in idxs:
                x0, y0 = P[i-1]; x1, y1 = P[i]
                arr = FancyArrowPatch(
                    (x0, y0), (x1, y1),
                    arrowstyle='-|>',
                    mutation_scale=ms,
                    linewidth=lw,
                    facecolor="none",
                    edgecolor=color,
                    shrinkA=0.0, shrinkB=0.0,
                    alpha=alpha,
                    joinstyle="miter", capstyle="round",
                    zorder=6
                )
                ax.add_patch(arr)

        # smoothed polylines (for drawing)
        P_gt_main_s   = _smooth_polyline(P_gt_main,   samples_per_step)
        P_lin_main_s  = _smooth_polyline(P_lin_main,  samples_per_step)
        P_full_main_s = _smooth_polyline(P_full_main, samples_per_step)

        P_gt_extra_s, P_lin_extra_s, P_full_extra_s = [], [], []
        for Pgt, Plin, Pfull in zip(P_gt_extra, P_lin_extra, P_full_extra):
            P_gt_extra_s.append(_smooth_polyline(Pgt,   samples_per_step))
            P_lin_extra_s.append(_smooth_polyline(Plin, samples_per_step))
            P_full_extra_s.append(_smooth_polyline(Pfull, samples_per_step))

        # arrow widths
        if lw_arrow_main is None: lw_arrow_main = max(0.6, 0.8 * lw_main)
        if lw_arrow_extra is None: lw_arrow_extra = max(0.5, 0.75 * lw_extra)

        # ===================== PLOTTING =====================
        rcParams.update({
            "axes.grid": True,
            "grid.linestyle": "--",
            "grid.alpha": 0.25,
            "axes.formatter.use_mathtext": True,
        })
        fig, ax = plt.subplots(figsize=(6.2, 6.2))

        # four spines
        for side in ("left", "right", "top", "bottom"):
            ax.spines[side].set_visible(True)
            ax.spines[side].set_linewidth(spine_lw)

        ax.tick_params(axis="both", which="major", labelsize=tick_labelsize,
                    direction="out", length=4.5, width=0.9)
        ax.tick_params(axis="both", which="minor", labelsize=tick_labelsize-1,
                    direction="out", length=3.0, width=0.7)
        ax.xaxis.set_minor_locator(mticker.AutoMinorLocator(4))
        ax.yaxis.set_minor_locator(mticker.AutoMinorLocator(4))

        # limits BEFORE background
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlim(*xlim); ax.set_ylim(*ylim)

        # -------- background --------
        if bg_mode == "discrete_map":
            U = bg["U"]; V = bg["V"]; XX = bg["grid_X"]; YY = bg["grid_Y"]
            ax.quiver(XX[::field_stride, ::field_stride],
                    YY[::field_stride, ::field_stride],
                    (U*field_scale)[::field_stride, ::field_stride],
                    (V*field_scale)[::field_stride, ::field_stride],
                    width=0.003, alpha=0.65, minlength=0.0, zorder=1)
        elif bg_mode == "vector_field":
            U = bg["U"]; V = bg["V"]; XX = bg["grid_X"]; YY = bg["grid_Y"]
            ax.quiver(XX[::field_stride, ::field_stride],
                    YY[::field_stride, ::field_stride],
                    (U*field_scale)[::field_stride, ::field_stride],
                    (V*field_scale)[::field_stride, ::field_stride],
                    width=0.003, alpha=0.70, minlength=0.0, zorder=1)
        elif bg_mode == "landscape":
            S = bg["S"]
            ax.imshow(
                S,
                extent=(xlim[0], xlim[1], ylim[0], ylim[1]),
                origin="lower",
                cmap=_get_bg_cmap(bg_cmap),
                alpha=float(max(0.0, min(1.0, bg_alpha))),
                interpolation="bilinear",
                zorder=0,
                aspect="auto"
            )
            # overlay contour rings
            cx = np.linspace(xlim[0], xlim[1], S.shape[1])
            cy = np.linspace(ylim[0], ylim[1], S.shape[0])
            CCX, CCY = np.meshgrid(cx, cy)
            ax.contour(
                CCX, CCY, S,
                levels=int(contour_levels),
                colors=contour_color,
                linewidths=contour_lw,
                alpha=contour_alpha,
                zorder=1
            )

        # ------- MAIN (SOLID colors, no gradient) -------
        ax.plot(P_gt_main_s[:,0],   P_gt_main_s[:,1],   color=COL_GT,   lw=lw_main,  zorder=5)
        ax.plot(P_lin_main_s[:,0],  P_lin_main_s[:,1],  color=COL_LIN,  lw=lw_main,  zorder=5)
        ax.plot(P_full_main_s[:,0], P_full_main_s[:,1], color=COL_FULL, lw=lw_main,  zorder=5)
        _add_arrows(ax, P_gt_main_s,   COL_GT,   arrow_count_main,  arrow_size_main,  lw_arrow_main, 1.0)
        _add_arrows(ax, P_lin_main_s,  COL_LIN,  arrow_count_main,  arrow_size_main,  lw_arrow_main, 1.0)
        _add_arrows(ax, P_full_main_s, COL_FULL, arrow_count_main,  arrow_size_main,  lw_arrow_main, 1.0)

        # start markers (main) + optional per-step markers
        ax.scatter(P_gt_main[0,0],   P_gt_main[0,1],   s=36, color=COL_GT,   zorder=7)
        ax.scatter(P_lin_main[0,0],  P_lin_main[0,1],  s=36, color=COL_LIN,  zorder=7)
        ax.scatter(P_full_main[0,0], P_full_main[0,1], s=36, color=COL_FULL, zorder=7)
        if show_markers_main:
            ax.scatter(P_gt_main[:,0],   P_gt_main[:,1],   s=marker_size_main,  color=COL_GT,   alpha=marker_alpha_main,  zorder=8)
            ax.scatter(P_lin_main[:,0],  P_lin_main[:,1],  s=marker_size_main,  color=COL_LIN,  alpha=marker_alpha_main,  zorder=8)
            ax.scatter(P_full_main[:,0], P_full_main[:,1], s=marker_size_main,  color=COL_FULL, alpha=marker_alpha_main,  zorder=8)

        # ------- EXTRAS (gradient, thinner, NO markers by default) -------
        for Pgt, Plin, Pfull, Pgt_s, Plin_s, Pfull_s in zip(
            P_gt_extra, P_lin_extra, P_full_extra, P_gt_extra_s, P_lin_extra_s, P_full_extra_s
        ):
            _grad_line(ax, Pgt_s,   "#D9D9D9", COL_GT,   lw_extra, extra_alpha)
            _grad_line(ax, Plin_s,  "#FFDAB8", COL_LIN,  lw_extra, extra_alpha)
            _grad_line(ax, Pfull_s, "#D7EEFF", COL_FULL, lw_extra, extra_alpha)
            _add_arrows(ax, Pgt_s,   COL_GT,   arrow_count_extra, arrow_size_extra, lw_arrow_extra, extra_alpha)
            _add_arrows(ax, Plin_s,  COL_LIN,  arrow_count_extra, arrow_size_extra, lw_arrow_extra, extra_alpha)
            _add_arrows(ax, Pfull_s, COL_FULL, arrow_count_extra, arrow_size_extra, lw_arrow_extra, extra_alpha)

            if show_markers_extra:
                ax.scatter(Pgt[:,0],   Pgt[:,1],   s=marker_size_extra,  color=COL_GT,   alpha=marker_alpha_extra,  zorder=7)
                ax.scatter(Plin[:,0],  Plin[:,1],  s=marker_size_extra,  color=COL_LIN,  alpha=marker_alpha_extra,  zorder=7)
                ax.scatter(Pfull[:,0], Pfull[:,1], s=marker_size_extra,  color=COL_FULL, alpha=marker_alpha_extra,  zorder=7)

        # optional crosshair
        ax.axhline(0, lw=0.8, alpha=0.5); ax.axvline(0, lw=0.8, alpha=0.5)

        # save (dpi=450, no pad_inches)
        png_path = os.path.join(save_dir, f"{fname_prefix}.png")
        fig.savefig(png_path, dpi=450, bbox_inches="tight")
        plt.close(fig)

        # ---------- save data ----------
        data_path = os.path.join(save_dir, f"{fname_prefix}_data.pt")
        torch.save({
            "z_gt_int_main": z_gt_int_main[:steps_eff].detach().cpu(),
            "z_lin_int_main": z_lin_main.detach().cpu(),
            "z_full_int_main": z_full_main.detach().cpu(),
            "P_gt_main": torch.from_numpy(P_gt_main),
            "P_lin_main": torch.from_numpy(P_lin_main),
            "P_full_main": torch.from_numpy(P_full_main),
            "W": W.detach().cpu(),
            "mu": mu.detach().cpu(),
            "A_lin": (A_lin.detach().cpu() if A_lin is not None else None),
            "G_gen": (G_gen.detach().cpu() if G_gen is not None else None),
            "bg": bg,
            "extras": [
                {"seq_index": int(i),
                "P_gt": torch.from_numpy(p0),
                "P_lin": torch.from_numpy(p1),
                "P_full": torch.from_numpy(p2)}
                for i, p0, p1, p2 in zip(extra_seq_indices, P_gt_extra, P_lin_extra, P_full_extra)
            ],
            "meta": {
                "group": group,
                "seq_index_main": int(seq_index),
                "extra_seq_indices": [int(i) for i in extra_seq_indices],
                "steps_main": int(steps_eff),
                "bg_mode": bg_mode,
                "grid_res": int(grid_res),
                "field_stride": int(field_stride),
                "dt_eval": float(self.dt_eval),
                "bg_cmap": bg_cmap,
                "bg_alpha": float(bg_alpha),
                "colors": {"gt": COL_GT, "lin": COL_LIN, "full": COL_FULL},
                "style": {
                    "lw_main": float(lw_main), "lw_extra": float(lw_extra),
                    "lw_arrow_main": float(lw_arrow_main), "lw_arrow_extra": float(lw_arrow_extra),
                    "arrow_count_main": int(arrow_count_main),
                    "arrow_count_extra": int(arrow_count_extra),
                    "arrow_size_main": float(arrow_size_main),
                    "arrow_size_extra": float(arrow_size_extra),
                    "spine_lw": float(spine_lw),
                    "tick_labelsize": int(tick_labelsize),
                    "extra_alpha": float(extra_alpha),
                    "smooth_curve": bool(smooth_curve),
                    "samples_per_step": int(samples_per_step),
                    "show_markers_main": bool(show_markers_main),
                    "show_markers_extra": bool(show_markers_extra),
                    "marker_size_main": float(marker_size_main),
                    "marker_size_extra": float(marker_size_extra),
                    "marker_alpha_main": float(marker_alpha_main),
                    "marker_alpha_extra": float(marker_alpha_extra),
                    "contour_levels": int(contour_levels),
                    "contour_color": str(contour_color),
                    "contour_lw": float(contour_lw),
                    "contour_alpha": float(contour_alpha),
                }
            }
        }, data_path)

        print(f"[evolution] saved: {png_path}\n- data: {data_path}")



    @torch.no_grad()
    def plot_lowdim_time_series(
        self,
        time_scale: float = 4.0,
        *,
        traj_id: int | None = None,
        t0: int | None = None,
        use_center: bool = True,
        use_whiten: bool = True,
        use_projector: bool = True,        # requires U_proj
        steps: int | None = None,
        save_dir: str = "./vis",
        fname_prefix: str = "lowdim_timeseries",
        # --- visual knobs ---
        lw: float = 3.0,                   # unified line width
        alpha: float = 1.0,                # unified alpha
        legend: bool = True,
        legend_max: int = 20,
        legend_loc: str = "upper left",    # inside-axes legend location
        downsample: int = 1,               # temporal stride on time axis
        # --- axes & layout ---
        hide_y_ticks: bool = True,         # remove y ticks/labels
        x_tick_step: float | None = None,  # major x tick step; None -> auto
        fig_w: float = 7.2,
        fig_h: float = 3.6,
        tick_labelsize: int = 13,
        # --- readability helpers (viz only) ---
        standardize: bool = False,         # per-dim z-score (only affects viz)
        smooth_window: int = 1,            # odd >=1; 1 = no smoothing
        # --- NEW: dimension selection ---
        max_dims: int | None = None,       # keep first K dims (after projection)
        dim_indices: list[int] | None = None,  # explicit dim indices to plot
        topk_by_var: int | None = None,    # choose K dims with largest variance (after viz ops)
        # --- output ---
        dpi: int = 450
    ):
        """
        Overlay time series of low-dimensional latent coordinates (after U_proj) on one axes.

        One sample is fetched via ``self.sample_from_fix(traj_id, t0, steps)``, encoded with
        ``self._encode_and_recon``, optionally centered / whitened, projected with
        ``self._project_latent`` and drawn as one line per selected dimension.

        Args:
            time_scale: factor multiplied onto the batch time stamps before plotting.
            traj_id, t0: forwarded unchanged to ``self.sample_from_fix``.
            use_center, use_whiten: apply ``self._center_latent`` / ``self._whiten_latent``.
            use_projector: must be True (this plot is defined on projected latents only).
            steps: number of leading frames to show; clamped to ``[2, T']`` where ``T'`` is
                the encoded sequence length. ``None`` means all frames.
            save_dir, fname_prefix: output directory (created if missing) and file stem.
            lw, alpha: line width and opacity shared by every curve.
            legend, legend_max, legend_loc: the legend is drawn only when ``legend`` is set
                and at most ``legend_max`` dimensions are plotted.
            downsample: stride on the time axis; keep every s-th frame (values < 1 act as 1).
            hide_y_ticks: drop y tick marks and labels.
            x_tick_step: spacing of major x ticks; ``None`` or non-positive -> automatic.
            fig_w, fig_h, tick_labelsize, dpi: figure size (inches), font size, raster dpi.
            standardize: per-dimension z-score of the displayed data (visualisation only).
            smooth_window: moving-average width; even values are bumped to the next odd
                value, values <= 1 disable smoothing.
            max_dims, dim_indices, topk_by_var: dimension selection, see below.

        Dimension limiting (priority):
            1) dim_indices: use exactly these dims (validated & de-duplicated)
            2) topk_by_var: pick K dims with largest variance on the displayed data
            3) max_dims:    keep first K dims (0..K-1)
            4) otherwise:   plot all dims

        Saves:
            - {fname_prefix}.png
            - {fname_prefix}_data.pt
            - {fname_prefix}_data.npz
            - {fname_prefix}_data.csv

        Raises:
            AssertionError: if ``use_projector`` is False or no projector is available.
            ValueError: if ``dim_indices`` contains no index inside ``[0, d)``.
        """
        # Local imports keep this plotting utility self-contained.
        import os
        import json
        import colorsys
        import numpy as np
        import matplotlib.pyplot as plt
        import matplotlib.ticker as mticker
        import torch

        assert use_projector, "This function visualizes post-projection time series; set use_projector=True."
        assert getattr(self, "use_projector", False) and (self.U_proj is not None), \
            "U_proj not set. Train or load a projector first."

        os.makedirs(save_dir, exist_ok=True)

        # ------------------ 1) fetch one fixed sample ------------------
        loader = self.sample_from_fix(traj_id, t0, steps)
        batch = next(iter(loader))
        t_full = batch["t"][0].to(self.device)                # [T]
        # drop the conditioning frames so the time axis starts at the last one
        t_eff = t_full[self.n_frames_cond - 1:] * time_scale

        # ------------------ 2) encode -> (center, whiten) -> project ------------------
        z_full, _, _ = self._encode_and_recon(batch)          # [T',1,D]
        z_full = z_full[:, 0]                                 # [T',D]
        Tprime = int(z_full.size(0))
        if steps is None:
            steps = Tprime
        steps = max(2, min(int(steps), Tprime))

        z = z_full[:steps]
        if use_center:
            z = self._center_latent(z)
        if use_whiten:
            z = self._whiten_latent(z)
        z = self._project_latent(z) if use_projector else z   # [T', d]
        z = z.detach().cpu().float()                          # [T', d]

        # ------------------ 3) time downsample / standardize / smooth ------------------
        s = max(1, int(downsample))
        z_plot = z[::s]                                       # [T_ds, d]
        t_plot = t_eff[:steps].detach().cpu().numpy()[::s]    # [T_ds]

        if standardize:
            mu = z_plot.mean(0, keepdim=True)
            sd = z_plot.std(0, keepdim=True).clamp_min(1e-8)  # avoid division by ~0
            z_plot = (z_plot - mu) / sd

        def _movavg(x: np.ndarray, w: int) -> np.ndarray:
            """Centered moving average along axis 0 with reflect padding (length preserved)."""
            if w <= 1:
                return x
            if w % 2 == 0:
                w += 1  # an odd window keeps the filter centered
            pad = w // 2
            xpad = np.pad(x, ((pad, pad), (0, 0)), mode="reflect")
            kernel = np.ones((w,), dtype=np.float64) / float(w)
            return np.apply_along_axis(lambda col: np.convolve(col, kernel, mode="valid"), axis=0, arr=xpad)

        win = int(smooth_window)
        if win > 1:
            z_plot = torch.from_numpy(_movavg(z_plot.numpy(), win)).float()
        # Window that was really applied (even requests are rounded up inside _movavg);
        # stored in the metadata so the saved data can be reproduced exactly.
        smooth_window_eff = 1 if win <= 1 else (win + 1 if win % 2 == 0 else win)

        Tds, d_all = z_plot.shape

        # ------------------ 4) choose which dims to plot ------------------
        # Priority: dim_indices > topk_by_var > max_dims > all
        if dim_indices is not None:
            idx_np = np.array([int(i) for i in dim_indices], dtype=int)
            idx_np = idx_np[(idx_np >= 0) & (idx_np < d_all)]
            # de-duplicate while keeping order
            _, first_pos = np.unique(idx_np, return_index=True)
            idx_np = idx_np[np.sort(first_pos)]
            if idx_np.size == 0:
                raise ValueError("dim_indices filtered to empty set (out of range?).")
        elif topk_by_var is not None:
            k = int(max(1, min(topk_by_var, d_all)))
            # variance is undefined for a single frame -> treat every dim as equal
            var = z_plot.var(dim=0) if Tds > 1 else torch.zeros(d_all)
            idx_np = torch.topk(var, k=k).indices.cpu().numpy()
            idx_np.sort()  # keep ascending order for labels
        elif max_dims is not None:
            k = int(max(1, min(max_dims, d_all)))
            idx_np = np.arange(k, dtype=int)
        else:
            idx_np = np.arange(d_all, dtype=int)

        z_plot = z_plot[:, idx_np]   # [T_ds, d_sel]
        d = z_plot.shape[1]
        labels = [int(i) for i in idx_np]   # original dim id of every plotted column

        # ------------------ 5) vivid color cycle (HSV with golden-ratio jumps) ------------------
        def bright_cycle(n: int):
            """Return n RGB tuples whose hues advance by the golden-ratio conjugate."""
            out = []
            phi = 0.6180339887498949  # golden ratio conjugate
            h = 0.0
            for _ in range(n):
                h = (h + phi) % 1.0
                out.append(colorsys.hsv_to_rgb(h, 0.95, 1.0))  # high saturation & brightness
            return out

        colors = bright_cycle(d)

        # ------------------ 6) figure style + drawing ------------------
        # The style overrides are applied through rc_context so they are restored on exit
        # and do not leak into figures created elsewhere in the process.
        style = {
            "axes.formatter.use_mathtext": True,
            "font.size": tick_labelsize,
            "axes.spines.top": True,
            "axes.spines.right": True,
        }
        out_png = os.path.join(save_dir, f"{fname_prefix}.png")
        show_legend = legend and d <= legend_max

        with plt.rc_context(style):
            fig, ax = plt.subplots(figsize=(fig_w, fig_h))
            try:
                # full frame with consistent linewidth (show all four spines)
                for side in ("left", "right", "top", "bottom"):
                    ax.spines[side].set_visible(True)
                    ax.spines[side].set_linewidth(1.1)

                # ticks & grids
                ax.set_xlim(float(t_plot[0]), float(t_plot[-1]))
                if x_tick_step is not None and x_tick_step > 0:
                    ax.xaxis.set_major_locator(mticker.MultipleLocator(base=float(x_tick_step)))
                else:
                    ax.xaxis.set_major_locator(mticker.MaxNLocator(nbins=6, prune=None, steps=[1, 2, 2.5, 5, 10]))
                ax.xaxis.set_minor_locator(mticker.AutoMinorLocator(4))
                ax.yaxis.set_minor_locator(mticker.AutoMinorLocator(4))

                ax.tick_params(axis="x", which="both", labelsize=tick_labelsize)
                ax.tick_params(axis="y", which="both", labelsize=tick_labelsize)

                if hide_y_ticks:
                    ax.tick_params(axis="y", which="both", length=0)
                    ax.set_yticklabels([])

                # dense grids (no y=0 baseline)
                ax.grid(True, which="major", linestyle="--", linewidth=0.8, alpha=0.35)
                ax.grid(True, which="minor", linestyle=":",  linewidth=0.55, alpha=0.25)

                # unified lines; legend entries are 1-based ("obs k")
                for j, dim_id in enumerate(labels):
                    ax.plot(
                        t_plot, z_plot[:, j].numpy(),
                        color=colors[j], lw=lw, alpha=alpha, solid_capstyle="round",
                        label=(f"obs {dim_id+1}" if show_legend else None)
                    )

                # legend INSIDE the axes
                if show_legend:
                    ax.legend(
                        loc=legend_loc, frameon=True, framealpha=0.85,
                        facecolor="white", edgecolor="none",
                        ncol=min(len(ax.get_legend_handles_labels()[1]), 4),
                        fontsize=tick_labelsize - 1
                    )

                fig.tight_layout()
                fig.savefig(out_png, dpi=dpi, bbox_inches="tight", pad_inches=0.06)
            finally:
                # always release the figure, even when drawing or saving fails
                plt.close(fig)

        # ------------------ 7) save data (exact reproduction) ------------------
        meta = {
            "traj_id": int(traj_id) if traj_id is not None else None,
            "t0": int(t0) if t0 is not None else None,
            "steps": int(steps),
            "downsample": int(s),
            "use_center": bool(use_center),
            "use_whiten": bool(use_whiten),
            "use_projector": bool(use_projector),
            "standardize": bool(standardize),
            "smooth_window": int(smooth_window),
            "smooth_window_effective": int(smooth_window_eff),
            "dpi": int(dpi),
            "dt_eval": float(getattr(self, "dt_eval", 1.0)),
            "time_scale": float(time_scale),
            "hide_y_ticks": bool(hide_y_ticks),
            "x_tick_step": (None if x_tick_step is None else float(x_tick_step)),
            "fig_w": float(fig_w),
            "fig_h": float(fig_h),
            "tick_labelsize": int(tick_labelsize),
            "lw": float(lw),
            "alpha": float(alpha),
            "legend_loc": str(legend_loc),
            # NEW: record selection strategy & indices
            "max_dims": (None if max_dims is None else int(max_dims)),
            "topk_by_var": (None if topk_by_var is None else int(topk_by_var)),
            "dim_indices": [int(x) for x in labels],   # actual plotted original dim ids (after selection)
        }

        torch_path = os.path.join(save_dir, f"{fname_prefix}_data.pt")
        torch.save({"t": torch.from_numpy(np.asarray(t_plot)),
                    "z_plot": z_plot.detach().cpu(),   # [T_ds, d_selected]
                    "meta": meta}, torch_path)

        # the metadata travels inside the npz as UTF-8 encoded JSON bytes
        npz_path = os.path.join(save_dir, f"{fname_prefix}_data.npz")
        np.savez_compressed(
            npz_path,
            t=np.asarray(t_plot),
            z_plot=z_plot.numpy(),
            meta_json=np.frombuffer(json.dumps(meta).encode("utf-8"), dtype=np.uint8)
        )

        # CSV columns: time followed by one column per plotted dim (0-based dim ids)
        csv_path = os.path.join(save_dir, f"{fname_prefix}_data.csv")
        header = ",".join(["t"] + [f"dim_{k}" for k in labels])
        table  = np.concatenate([np.asarray(t_plot)[:, None], z_plot.numpy()], axis=1)
        np.savetxt(csv_path, table, delimiter=",", header=header, comments="", fmt="%.10g")

        print(f"[lowdim] saved figure: {out_png}")
        print(f"[lowdim] saved data  : {torch_path}")
        print(f"[lowdim] saved data  : {npz_path}")
        print(f"[lowdim] saved data  : {csv_path}")

