import os
import numpy as np
import matplotlib.pyplot as plt
from pprint import pformat
import torch
import torch.nn as nn
from utilities.utils import set_requires_grad
from torch.utils.data import Subset, DataLoader
import torch.nn.functional as F

from data.data_process import PDEDataProcessor, PDEDataset
from exp.exp_basic import Exp_Basic, ExpConfigs
from utilities.losses import LpLoss
from utilities.vis import visualize_pred_vs_gt
from utilities.data_proc import delay_stack_last_channel
from utilities.read_file import load_matrix_from_path

from config import MemKNOParamBundle
from config import build_encoder, build_set_encoder, build_field_decoder, build_fourier_decoder, build_latent_process
from config import build_latent_process_discrete

"""
TODO:
"""


class Exp_MemKNO(Exp_Basic):
    def __init__(self, args, exp_cfg: ExpConfigs, model_cfg: MemKNOParamBundle, data_processor: PDEDataProcessor):
        """Resolve the model configuration, cache hyper-parameters and build the modules.

        ``model_cfg`` is accepted either as an object exposing
        ``as_model_kwargs`` or as a plain dict (e.g. loaded from
        model_cfg.json); any other type raises ``TypeError``.  The resolved
        config must agree with ``data_processor`` on ``n_frames_cond``.
        """
        super().__init__(args, exp_cfg, model_cfg, data_processor)

        # Normalize the model configuration to a dict.
        if hasattr(model_cfg, "as_model_kwargs"):
            resolved_cfg = model_cfg.as_model_kwargs(include_meta=True)
        elif isinstance(model_cfg, dict):
            resolved_cfg = model_cfg
        else:
            raise TypeError("model_cfg must be a MemKNOParamBundle or a dict loaded from model_cfg.json")
        self.model_cfg = resolved_cfg
        assert self.model_cfg["n_frames_cond"] == data_processor.n_frames_cond

        # Latent layout and dynamics type.
        self.state_dim = self.model_cfg["state_dim"]
        self.latent_dim = self.model_cfg["latent_dim"]
        self.code_dim = self.model_cfg["code_dim"]
        self.latent_type = self.model_cfg["latent_type"]
        memory_variants = {"linear+memory", "gru_memory", "lstm_memory"}
        self.use_memory = self.latent_type in memory_variants

        # Dataloaders are expected to be initialized by Exp_Basic.
        # Teacher-forcing parameters.
        self.tf_epsilon = exp_cfg.tf_epsilon
        self.epsilon = exp_cfg.epsilon

        self.dt_eval = exp_cfg.dt_eval

        # Architecture / loss switches read from the experiment config.
        self.enc_mode = self.cfg.enc_mode
        self.dec_mode = self.cfg.dec_mode
        self.latent_mode = self.cfg.latent_mode
        self.loss_with_mask = self.cfg.loss_with_mask

        # Loss weights (the residual weight is optional and defaults to 0).
        self.lambda_dyn = self.cfg.lambda_dyn
        self.lambda_pred = self.cfg.lambda_pred
        self.lambda_corr = self.cfg.lambda_corr
        self.lambda_resid = getattr(self.cfg, "lambda_residual", 0.0)

        # Diagonal whitening options; the scale itself is filled in later.
        self.use_diag_whiten = getattr(self.cfg, "use_diag_whiten", True)
        self.whiten_eps = getattr(self.cfg, "whiten_eps", 1e-8)
        self.whiten_clamp = getattr(self.cfg, "whiten_clamp", 1e-3)
        self.whiten_scale = None  # [D]

        # Low-dimensional projector state (disabled until explicitly set).
        self.U_proj = None           # [D, d] or None
        self.use_projector = False
        self.latent_dim_y = None

        # Build the networks and report their sizes.
        self.load_model()
        self.log_param_table()

    
    def build_dataloader(self, group: str):
        sample_map = {
            "train": self.train_sample_idx,
            "train_eval": self.train_eval_sample_idx,
            "test": self.test_sample_idx,
        }
        dataloader = self.data_processor.get_dataloader(group=group, samples=sample_map[group])
        if group == "train":
            self.train_loader = dataloader
        elif group == "train_eval":
            self.train_eval_loader = dataloader
        elif group == "test":
            self.test_loader = dataloader

    
    def load_model(self):
        if self.enc_mode == "galerkin_transformer":
            self.encoder = build_encoder(model_cfg=self.model_cfg["encoder"]).to(self.device)
        elif self.enc_mode == "set_transformer":
            self.encoder = build_set_encoder(model_cfg=self.model_cfg["set_encoder"]).to(self.device)

        if self.dec_mode == "fouriernet":
            self.decoder = build_field_decoder(x_grid=self.pos_feat, model_cfg=self.model_cfg["field_decoder"]).to(self.device)
        elif self.dec_mode == "fouriermlp":
            self.decoder = build_fourier_decoder(model_cfg=self.model_cfg["fourier_decoder"]).to(self.device)

        if self.latent_mode == "continuous":
            self.latent_process = build_latent_process(model_cfg=self.model_cfg["latent_process"]).to(self.device)
        elif self.latent_mode == "discrete":
            self.latent_process = build_latent_process_discrete(model_cfg=self.model_cfg["latent_process_discrete"]).to(self.device)

    
    def count_parameters(self) -> dict:
        """Return a dict with per-module and total trainable parameter counts."""
        def ntrainable(module):
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        counts = {
            "encoder": ntrainable(self.encoder) if hasattr(self, "encoder") else 0,
            "latent_process": ntrainable(self.latent_process) if hasattr(self, "latent_process") else 0,
            "decoder": ntrainable(self.decoder) if hasattr(self, "decoder") else 0,
        }
        counts["total"] = sum(counts.values())
        return counts


    def log_param_table(self, title: str = "Trainable parameters"):
        c = self.count_parameters()
        lines = [
            f"{title}:",
            f"  encoder       : {c['encoder']:,}",
            f"  latent_process: {c['latent_process']:,}",
            f"  decoder       : {c['decoder']:,}",
            f"  ---------------------------",
            f"  TOTAL         : {c['total']:,}",
        ]
        msg = "\n".join(lines)
        if hasattr(self, "logger") and self.logger is not None:
            self.logger.info(msg)
        else:
            print(msg)

    
    def switch_to_train(self):
        self.encoder.train()
        self.decoder.train()
        self.latent_process.train()

    
    def switch_to_eval(self):
        self.encoder.eval()
        self.decoder.eval()
        self.latent_process.eval()
        

    def init_optim(self):
        self.optim_enc = torch.optim.Adam([{'params': self.encoder.parameters(), 'lr': self.lr}])
        self.optim_dec = torch.optim.Adam([{'params': self.decoder.parameters(), 'lr': self.lr}])
        ######################################
        self.optim_dyn = torch.optim.Adam([{'params': self.latent_process.parameters(), 'lr': self.lr, 'weight_decay': self.cfg.weight_decay}])

        if self.cfg.scheduler == 'OneCycleLR':
            self.scheduler_enc = torch.optim.lr_scheduler.OneCycleLR(self.optim_enc, max_lr=self.cfg.lr, epochs=self.cfg.epochs,
                                                            steps_per_epoch=len(self.train_loader),
                                                            pct_start=self.cfg.pct_start)
            self.scheduler_dec = torch.optim.lr_scheduler.OneCycleLR(self.optim_dec, max_lr=self.cfg.lr, epochs=self.cfg.epochs,
                                                            steps_per_epoch=len(self.train_loader),
                                                            pct_start=self.cfg.pct_start)
            ######################################
            self.scheduler_dyn = torch.optim.lr_scheduler.OneCycleLR(self.optim_dyn, max_lr=self.cfg.lr, epochs=self.cfg.epochs,
                                                            steps_per_epoch=len(self.train_loader),
                                                            pct_start=self.cfg.pct_start)
        elif self.cfg.scheduler == 'CosineAnnealingLR':
            self.scheduler_enc = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim_enc, T_max=self.cfg.epochs)
            self.scheduler_dec = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim_dec, T_max=self.cfg.epochs)
            self.scheduler_dyn = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim_dyn, T_max=self.cfg.epochs)
        elif self.cfg.scheduler == 'StepLR':
            self.scheduler_enc = torch.optim.lr_scheduler.StepLR(self.optim_enc, step_size=self.cfg.step_size, gamma=self.cfg.gamma)
            self.scheduler_dec = torch.optim.lr_scheduler.StepLR(self.optim_dec, step_size=self.cfg.step_size, gamma=self.cfg.gamma)
            self.scheduler_dyn = torch.optim.lr_scheduler.StepLR(self.optim_dyn, step_size=self.cfg.step_size, gamma=self.cfg.gamma)


    def train_recon(self, log_every: int | None = None, verbose: str | None = None):
        """
        Train the encoder and decoder on field reconstruction only.

        Each batch is turned into delay-stacked windows of ``n_frames_cond``
        frames, the points kept by the mask are encoded, and the decoder output
        is compared with the ground-truth frames from index
        ``n_frames_cond - 1`` onward.  Both a plain MSE and a mask-weighted MSE
        are computed; ``self.loss_with_mask`` selects which one is
        back-propagated.  Only the encoder and decoder optimizers are stepped.

        Expected batch layout (taken from the original notes; not validated here):
            batch['data']: [B, T, H, W, C]
            batch['mask']: [B, T, H, W, C]  (same over time and channels)

        Args:
            log_every: log both losses whenever ``epoch * len(train_loader) + i``
                is a multiple of this value; ``None`` disables loss logging.
            verbose: optional message logged once before training starts.

        Raises:
            ValueError: if ``self.dec_mode`` is neither "fouriernet" nor "fouriermlp".
        """
        self.setup_logger()
        self.save_repro_artifacts()
        self.log_param_table()
        criterion = nn.MSELoss()

        assert self.data_processor.mode == "interpolation", "Mismatched dataloaders"
        self.build_dataloader(group="train")
        self.init_optim()
        self._save_split_and_samples()
        self.logger.info(f"Begin training using decoder backbone: {self.dec_mode}")

        if verbose is not None:
            self.logger.info(f"{verbose}")

        n_cond = self.n_frames_cond
        for epoch in range(1, self.cfg.epochs + 1):
            for i, batch in enumerate(self.train_loader):
                ground_truth = batch["data"].to(self.device)    # [B, T, H, W, n_ch]
                masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]    (same over time and channels)
                bs, train_len, _, _, _ = ground_truth.shape
                assert train_len == self.n_frames_train
                n_win = train_len - n_cond + 1                  # number of delay windows per sample

                # ---- gather the observed points of every delay window ----
                mask_index = self.mask_to_bs_index(mask=masks)    # [B, S]
                mask_index = mask_index.unsqueeze(1).expand(-1, n_win, -1).flatten(0, 1)  # [B*t, S]
                delay_data = delay_stack_last_channel(x=ground_truth, d=n_cond)    # [B, T-nf_cond+1, ..., n_ch*nf_cond]
                data_in = delay_data.flatten(0, 1).flatten(1, 2)    # [B*t, H*W, n_ch*nf_cond]

                data_in = self.index_points(data_in, mask_index)    # [B*t, S, n_ch*nf_cond]
                pos_in = self.pos_feat.flatten(0, 1).unsqueeze(0).expand(data_in.shape[0], -1, -1).to(self.device)
                pos_in = self.index_points(pos_in, mask_index)    # [B*t, S, 2]
                data_in = torch.cat((data_in, pos_in), dim=-1)

                latent_token = self.encoder(data_in, pos_in)    # [B*t, K, latent_token]

                # ---- decode; only this step differs between decoder backbones ----
                if self.dec_mode == "fouriernet":
                    latent_feats = latent_token.flatten(1, 2).unflatten(0, (bs, n_win))    # [B, t, K*latent_token]
                    assert latent_feats.shape[-1] == self.latent_dim
                    latent_feats = latent_feats.view(bs, n_win, self.state_dim, self.code_dim)    # [B, t, s, code_dim]
                    recon_field = self.decoder(latent_feats)    # [B, t, H, W, s]
                elif self.dec_mode == "fouriermlp":
                    latent_feats = latent_token.flatten(1, 2)    # [B*t, K*latent_token]
                    grid_dim = self.pos_feat.shape[-1]
                    grid = self.pos_feat.reshape(-1, grid_dim).to(self.device)    # [N_pt, grid_dim]
                    recon_field = self.decoder(grid=grid, latent_feat=latent_feats)    # [B*t, N_pts, out_dim]
                    recon_field = recon_field.reshape(bs, n_win, *self.shapelist, self.state_dim)
                else:
                    # Previously this fell through silently and surfaced later as a NameError.
                    raise ValueError(f"Unsupported dec_mode for reconstruction training: {self.dec_mode!r}")

                # ---- losses (shared by both backbones) ----
                target = ground_truth[:, n_cond - 1:, ...]
                target_mask = masks[:, n_cond - 1:, ...]
                recon_loss = criterion(target, recon_field)
                sqerr = (target - recon_field).pow(2) * target_mask
                sqerr_sum = sqerr.sum(dim=(2, 3))
                # Clamp so frames with an all-zero mask do not divide by zero.
                denom = target_mask.sum(dim=(2, 3)).clamp_(min=1e-6)
                recon_loss_wz_mask = (sqerr_sum / denom).mean()

                # ---- optimization step (encoder + decoder only) ----
                self.optim_enc.zero_grad()
                self.optim_dec.zero_grad()
                train_loss = recon_loss_wz_mask if self.loss_with_mask else recon_loss
                train_loss.backward()
                self.optim_enc.step()
                self.optim_dec.step()
                if self.cfg.scheduler == "OneCycleLR":
                    # OneCycleLR is stepped per iteration.
                    self.scheduler_enc.step()
                    self.scheduler_dec.step()

                if log_every is not None and (epoch * len(self.train_loader) + i) % log_every == 0:
                    self.logger.info(f"Epoch {epoch:04d}/{self.cfg.epochs} | iteration {i+1:03d} |"
                        f"| field reconstruction loss {recon_loss.item():.8f} | field reconstruction loss (with mask) {recon_loss_wz_mask.item():.8f}")
            # Epoch-based schedulers are stepped once per epoch.
            if self.cfg.scheduler == 'CosineAnnealingLR' or self.cfg.scheduler == 'StepLR':
                self.scheduler_enc.step()
                self.scheduler_dec.step()


    def mask_to_bs_index(self, mask: torch.Tensor, S: int | None = None, threshold: float = 0.5) -> torch.Tensor:
        # mask: [B, T, H, W, C]  (1=keep)
        B, T, H, W, C = mask.shape
        device = mask.device

        m2d = (mask[:, 0, :, :, 0] > threshold)
        lin = (torch.arange(H, device=device).view(H, 1) * W + torch.arange(W, device=device).view(1, W))
        lin = lin.view(1, H, W).expand(B, -1, -1)
        idx_list = [lin[b][m2d[b]].reshape(-1).long() for b in range(B)]
        S = min(x.numel() for x in idx_list) if S is None else S
        if any(x.numel() < S for x in idx_list):
            raise ValueError(f"Some batches have fewer kept points than S={S}.")
        idx = torch.stack([x[:S] for x in idx_list], dim=0)   # [B, S]
        return idx


    def index_points(self, points, idx):
        """
        Input:
            points: input points data, [B, N, C]
            idx: sample index data, [B, S]
        Return:
            new_points: indexed points data, [B, S, C]
        """
        device = points.device
        B = points.shape[0]
        view_shape = list(idx.shape)
        view_shape[1:] = [1] * (len(view_shape) - 1)
        repeat_shape = list(idx.shape)
        repeat_shape[0] = 1
        batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
        new_points = points[batch_indices, idx, :]
        return new_points

    
    def one_step_loss(self,
                      alpha_gt,           # [T, B, D]
                      t_eval,             # [T]
                      *,
                      train_latent: bool = True,
                      weight_per_step: torch.Tensor | None = None,  # [T-1] or [T-1,B]
                      reduction: str = "mean"):
        device = alpha_gt.device
        T, B, D = alpha_gt.shape

        # Mask that ends every segment at each step: [T-1] all True except the last False
        tf_mask = torch.ones(T - 1, dtype=torch.bool, device=device)
        tf_mask[-1] = False

        # Run once with teacher forcing; memory is carried internally for memory variants.
        # (alpha_gt is detached inside your _odeint_with_tf implementation.)
        out = self.latent_process.forward(
            alpha_0=alpha_gt[0],             # not used when teacher_forcing=True, but required by signature
            t_eval=t_eval,
            memory_init=None,                # optional: pass learned/init memory state if you have one
            teacher_forcing=True,
            tf_alpha=alpha_gt,
            tf_epsilon=0.0,
            tf_mask=tf_mask,
            tf_detach_alpha_starts=not(train_latent),
        )

        # Unpack (handles both return_aux=False/True)
        if isinstance(out, tuple):
            alpha_pred = out[0]  # [T, B, D]
        else:
            alpha_pred = out     # [T, B, D]

        pred_next = alpha_pred[1:]         # [T-1, B, D]
        gt_next   = alpha_gt[1:]           # [T-1, B, D]

        # Per (k,b) MSE before reduction
        per_step_mse = (pred_next - gt_next).pow(2).mean(dim=-1)   # [T-1, B]

        if weight_per_step is not None:
            # Broadcast weights to [T-1, B]
            if weight_per_step.ndim == 1:      # [T-1]
                w = weight_per_step[:, None].to(per_step_mse)
            elif weight_per_step.ndim == 2:    # [T-1, B]
                w = weight_per_step.to(per_step_mse)
            else:
                raise ValueError("weight_per_step must be shape [T-1] or [T-1, B]")
            per_step_mse = per_step_mse * w

        if reduction == "mean":
            loss = per_step_mse.mean()
        elif reduction == "sum":
            loss = per_step_mse.sum()
        elif reduction == "none":
            loss = per_step_mse
        else:
            raise ValueError(f"Unknown reduction: {reduction}")
        return loss, pred_next, gt_next, per_step_mse


    def kstep_rollout_loss(
        self,
        alpha_gt: torch.Tensor,         # [T, B, D]
        t_eval: torch.Tensor,           # [T]
        K: int,
        *,
        num_starts: int | None = None,
        stride: int = 1,
        train_latent: bool = True,
        target_stopgrad: bool = True,
        discount_gamma: float = 1.0,
        reduction: str = "mean",
        # NEW: carry memory across windows
        carry_memory: bool = True,
        # Whether to detach the precomputed memory trace before using it as memory_init
        carry_memory_detach: bool = False,
    ):
        device = alpha_gt.device
        T, B, D = alpha_gt.shape
        assert K >= 1 and T >= K + 1, "Sequence too short for K-step rollout."

        # Candidate window starts s ∈ [0, T-K-1] with the chosen stride.
        all_starts = torch.arange(0, T - K, stride, device=device)

        # Optionally subsample a random subset of starts.
        if num_starts is not None and num_starts < all_starts.numel():
            perm = torch.randperm(all_starts.numel(), device=device)
            starts = all_starts[perm[:num_starts]]
        else:
            starts = all_starts

        t_rel = (t_eval[:K + 1] - t_eval[0]).to(device=device, dtype=alpha_gt.dtype)

        # Optional horizon discount weights: shape [K] or None.
        if discount_gamma == 1.0:
            h_w = None
        else:
            h_w = discount_gamma ** torch.arange(1, K + 1, device=device, dtype=alpha_gt.dtype)

        mem_trace = None
        if carry_memory and self.use_memory:
            tf_mask = torch.ones(T - 1, dtype=torch.bool, device=device)
            tf_mask[-1] = False  # keep last interval connected
            out_tf = self.latent_process.forward(
                alpha_0=alpha_gt[0],     # not used when teacher_forcing=True
                t_eval=t_eval,
                memory_init=None,
                teacher_forcing=True,
                tf_alpha=alpha_gt,
                tf_epsilon=0.0,
                tf_mask=tf_mask,
                tf_detach_alpha_starts=not(train_latent),
                detach_memory_between_segments=False,   # carry memory (and grads) across 1-step segments
            )
            if isinstance(out_tf, tuple):
                # out_tf = (alpha_t, memory_t, [aux])
                mem_trace = out_tf[1]   # [T, B, Dm]
            if mem_trace is not None and carry_memory_detach:
                mem_trace = mem_trace.detach()
        W = starts.numel()
        if W == 0:
            return alpha_gt.new_tensor(0.0), {
                "num_windows": 0, "K": int(K),
                "carry_memory": bool(carry_memory),
                "carry_memory_detach": bool(carry_memory_detach),
            }
        a0 = alpha_gt[starts] if train_latent else alpha_gt[starts].detach()  # [W, B, D]
        a0 = a0.reshape(W * B, D)
        if self.use_memory:
            if carry_memory and (mem_trace is not None):
                mem0 = mem_trace[starts].reshape(W * B, self.latent_process.memory_dim)
            else:
                mem0 = torch.zeros(W * B, self.latent_process.memory_dim, device=device, dtype=alpha_gt.dtype)
        else:
            mem0 = None
        out = self.latent_process.forward(
            alpha_0=a0,
            t_eval=t_rel,
            memory_init=mem0,
            teacher_forcing=False,
        )
        alpha_pred = out[0] if isinstance(out, tuple) else out  # [K+1, W*B, D]
        pred = alpha_pred[1:].reshape(K, W, B, D)               # [K, W, B, D]

        idx = (starts[None, :] + torch.arange(1, K + 1, device=device)[:, None]).reshape(-1)  # [K*W]
        target = alpha_gt.index_select(0, idx).reshape(K, W, B, D)
        if target_stopgrad:
            target = target.detach()

        mse = (pred - target).pow(2).mean(dim=-1)   # [K, W, B]
        if h_w is not None:
            mse = mse * h_w[:, None, None]

        if reduction == "mean":
            loss = mse.mean()
        elif reduction == "sum":
            loss = mse.sum()
        elif reduction == "none":
            loss = mse
        else:
            raise ValueError(f"Unknown reduction: {reduction}")

        return loss, {
            "num_windows": int(W),
            "K": int(K),
            "carry_memory": bool(carry_memory),
            "carry_memory_detach": bool(carry_memory_detach),
        }


    @staticmethod
    @torch.no_grad()
    def solve_Ab_ridge(Z0: torch.Tensor, Zp: torch.Tensor, ridge: float = 1e-3, add_bias: bool = True):
        """
        Closed-form ridge regression for one-step latent linear dynamics:
        min_{A,b} ||Zp - (A Z0 + b)||_F^2 + ridge * ||[A b]||_F^2
        Inputs:
        Z0: [N, D]  flattened (t,b) -> N samples at time t
        Zp: [N, D]  flattened (t,b) -> N samples at time t+1
        Returns:
        A*: [D, D], b*: [D] or None
        """
        N, D = Z0.shape
        if add_bias:
            X = torch.cat([Z0, torch.ones(N, 1, device=Z0.device, dtype=Z0.dtype)], dim=1)  # [N, D+1]
            p = D + 1
        else:
            X = Z0; p = D
        Y = Zp                                          # [N, D]
        G = X.T @ X + ridge * torch.eye(p, device=X.device, dtype=X.dtype)   # [p, p]
        YXt = Y.T @ X                                   # [D, p]
        Theta = torch.linalg.solve(G, YXt.T).T          # [D, p]
        A = Theta[:, :D]                                # [D, D]
        b = Theta[:, D] if add_bias else None           # [D]
        return A, b
    

    @staticmethod
    @torch.no_grad()
    def ema_update_stats(Sxx: torch.Tensor, Syx: torch.Tensor,
                        Z0: torch.Tensor, Zp: torch.Tensor,
                        ema_beta: float = 0.95, add_bias: bool = True):
        """
        EMA sufficient statistics across batches:
        Sxx ≈ E[X^T X],  Syx ≈ E[Y^T X], where X=[Z0; 1] (if add_bias)
        In-place updates with EMA.
        """
        N, D = Z0.shape
        if add_bias:
            X = torch.cat([Z0, torch.ones(N, 1, device=Z0.device, dtype=Z0.dtype)], dim=1)  # [N, D+1]
        else:
            X = Z0
        Y = Zp
        Sxx.mul_(ema_beta).add_(X.T @ X, alpha=(1 - ema_beta))
        Syx.mul_(ema_beta).add_(Y.T @ X, alpha=(1 - ema_beta))
        return Sxx, Syx


    @staticmethod
    @torch.no_grad()
    def solve_global_A_from_stats(Sxx: torch.Tensor, Syx: torch.Tensor,
                                ridge: float = 1e-3, D: int | None = None, add_bias: bool = True):
        """
        Solve global A_ema (+ b_ema) from EMA sufficient statistics.
        """
        p = Sxx.size(0)
        I = torch.eye(p, device=Sxx.device, dtype=Sxx.dtype)
        Theta = torch.linalg.solve(Sxx + ridge * I, Syx.T).T  # [D, p]
        if D is None:
            D = Theta.size(0)
        A_ema = Theta[:, :D]
        b_ema = Theta[:, D] if add_bias else None
        return A_ema, b_ema
    
    
    @torch.no_grad()
    def solve_global_A_fullpass(self, dataloader, ridge: float = 5e-3, use_bias: bool = True):
        enc_was_train = self.encoder.training
        dec_was_train = self.decoder.training
        self.encoder.eval(); self.decoder.eval()
        Z0_list, Zp_list = [], []
        for batch in dataloader:
            latent_states, _, _ = self._encode_and_recon(batch)  # [T',B,D]
            T1, B, D = latent_states.shape
            if T1 <= 1:
                continue
            Z0_list.append(latent_states[:-1].reshape(-1, D))
            Zp_list.append(latent_states[ 1:].reshape(-1, D))

        if len(Z0_list) == 0:
            raise RuntimeError("[solve_global_A_fullpass] No (Z0,Zp) pairs collected. Check dataloader or n_frames_cond.")
        Z0 = torch.cat(Z0_list, dim=0).to(self.device)
        Zp = torch.cat(Zp_list, dim=0).to(self.device)
        A_full, b_full = self.solve_Ab_ridge(Z0, Zp, ridge=ridge, add_bias=use_bias)

        if enc_was_train: self.encoder.train()
        if dec_was_train: self.decoder.train()
        return A_full.detach(), (b_full.detach() if b_full is not None else None)


    def train_phase1_linear(self, 
                            epochs: int,
                            ridge: float = 5e-3,
                            ema_beta: float = 0.97,
                            use_bias: bool = True,
                            use_pred_loss: bool = True,
                            lambda_pred: float = 0.1,
                            lambda_dyn: float = 0.05,
                            log_every: int | None = 5,
                            eval_every: int | None = 20,
                            verbose: str = None,
                            diag_enable: bool = True,                    # print diagnostics
                            ms_consistency_enable: bool = False,         # latent multi-step consistency with batch A*
                            freq_ms_enable: bool = False,                # turn on spectral+multiscale penalty
                            global_A_mode: str | None = None,            # "fullpass" | "ema"
                            ):
        """
        Phase-I: train encoder/decoder with closed-form linear skeleton on latent.
        - Per-batch closed-form A* only defines losses (do NOT write into model params).
        - Maintain EMA sufficient statistics to get a global A_ema (+ b_ema) each epoch.
        """
        mode = (global_A_mode or getattr(self.cfg, "global_A_mode", "fullpass")).lower()
        assert mode in {"fullpass", "ema"}, f"global_A_mode must be 'fullpass' or 'ema', got {mode}"

        self.setup_logger()
        self.save_repro_artifacts()
        self.log_param_table()
        assert self.data_processor.mode == "interpolation", "Mismatched dataloaders"
        self.build_dataloader(group="train")
        if eval_every is not None:
            self.build_dataloader(group="test")
            self.build_dataloader(group="train_eval")
        self._save_split_and_samples()

        # Switch modules
        self.encoder.train()
        self.decoder.train()
        set_requires_grad(self.latent_process, False)  # freeze latent dynamics in Phase-I

        # Optimizers (enc/dec only)
        self.init_optim()

        if verbose is not None:
            self.logger.info(f"{verbose}")

        if mode == "ema":    # EMA stats buffers for a global A
            D = self.latent_dim
            p = D + (1 if use_bias else 0)
            Sxx = torch.zeros(p, p, device=self.device)
            Syx = torch.zeros(D, p, device=self.device)
        else:
            Sxx = Syx = None

        num_epochs = epochs if epochs is not None else self.cfg.epochs
        best_rec, best_metrics = float("inf"), None

        # ===== NEW: buffers for within-epoch diagnostics =====
        rho_list_epoch = []        # spectral radii of A* per batch
        gap_list_epoch = []        # ||A* - A_ema|| / ||A_ema||
        """dynA_list_epoch = []       # dyn loss with A*
        dynE_list_epoch = []       # dyn loss with current A_ema from stats"""
        A_global_prev = getattr(self, "A_phase1_ema", None)

        for epoch in range(1, num_epochs + 1):
            for it, batch in enumerate(self.train_loader):
                masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
                # print(batch["index"])
                # print(masks[0, 0, ...])
                # ------------------------------------------------------
                # 1) Encode inputs -> latent_states [T,B,D] and recon loss
                # ------------------------------------------------------
                latent_states, recon_loss, recon_loss_wzmask, x_rec, x_gt, _ = self._encode_and_recon(batch, return_recon=True)
                T, B, D = latent_states.shape

                # ------------------------------------------------------
                # 2) Closed-form A*, b* (only for loss) on flattened pairs
                # ------------------------------------------------------
                Z0 = latent_states[:-1].reshape(-1, D)  # [N, D]
                Zp = latent_states[ 1:].reshape(-1, D)  # [N, D]
                with torch.no_grad():
                    A_star, b_star = self.solve_Ab_ridge(Z0, Zp, ridge=ridge, add_bias=use_bias)

                rho_Astar = self._spectral_radius(A_star)
                rho_list_epoch.append(rho_Astar)

                # compute a *current* A_ema from stats BEFORE we update them, as a diagnostic baseline
                """with torch.no_grad():
                    A_ema_tmp, b_ema_tmp = self.solve_global_A_from_stats(Sxx, Syx, ridge=ridge, D=D, add_bias=use_bias)
                if torch.isfinite(A_ema_tmp).all():
                    denom = torch.norm(A_ema_tmp) + 1e-12
                    gap = torch.norm(A_star - A_ema_tmp) / denom
                    gap_list_epoch.append(float(gap))"""
                if (A_global_prev is not None) and torch.isfinite(A_global_prev).all():
                    with torch.no_grad():
                        denom = torch.norm(A_global_prev) + 1e-12
                        gap_list_epoch.append(float(torch.norm(A_star - A_global_prev) / denom))

                # Latent one-step consistency: Zp ≈ A*Z0 (+b)
                Zp_hat = (Z0 @ A_star.T) + (b_star if use_bias else 0.0)      # [N, D]
                dyn_loss = F.mse_loss(Zp_hat, Zp)
                # ------------------------------------------------------
                # 3) Optional: one-step in observation space
                # ------------------------------------------------------
                if use_pred_loss:
                    a1_hat = Zp_hat.view(T - 1, B, D)                  # [T-1,B,D]
                    x1_hat = self._decode_latent(a1_hat)               # [B,T-1,H,W,C]
                    x1_true = batch["data"][:, self.n_frames_cond:, ...].to(self.device)
                    pred_loss, pred_loss_wzmask = self._mse_loss(x1_hat, x1_true, masks[:, self.n_frames_cond:, ...])
                else:
                    pred_loss, pred_loss_wzmask = torch.tensor(0.0, device=self.device), torch.tensor(0.0, device=self.device)


                # long term prediction loss
                lambda_lt_pred = self.cfg.lambda_lt_pred if (self.cfg.lambda_lt_pred is not None) and ms_consistency_enable else 0.0
                if ms_consistency_enable and (self.cfg.lambda_lt_pred > 0.0):
                    ms_loss = self._multistep_latent_consistency(latent_states, A_star, b_star if use_bias else None,
                                                                 H=self.cfg.rollout_steps, gamma=self.cfg.gamma_decay)
                else:
                    ms_loss = torch.tensor(0.0, device=self.device)

                # loss on the frequency domain
                lambda_freq = self.cfg.lambda_freq if (self.cfg.lambda_freq is not None) and freq_ms_enable else 0.0
                if freq_ms_enable and (lambda_freq > 0.0):
                    xf_hat, xf_true = x_rec, x_gt  # use reconstruction frames
                    fft_loss = self._fft_mse_on_frames(
                        xf_hat, xf_true,
                        use_log_mag=True,
                        hf_power=self.cfg.freq_hf_power
                    )
                    ms_loss_img = self._multiscale_spatial_mse(
                        xf_hat, xf_true,
                        pool_scales=self.cfg.ms_pool_scales
                    )
                    freq_ms_loss = fft_loss + ms_loss_img
                else:
                    freq_ms_loss = torch.tensor(0.0, device=self.device)


                # ------------------------------------------------------
                # 4) Total loss and backward (enc/dec only)
                # ------------------------------------------------------
                loss = recon_loss_wzmask + lambda_dyn * dyn_loss + lambda_pred * pred_loss_wzmask
                loss += lambda_lt_pred * ms_loss
                loss += lambda_freq * freq_ms_loss

                self.optim_enc.zero_grad()
                self.optim_dec.zero_grad()
                loss.backward()
                if self.cfg.max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), self.cfg.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), self.cfg.max_grad_norm)
                self.optim_enc.step()
                self.optim_dec.step()
                if self.cfg.scheduler == "OneCycleLR":
                    self.scheduler_enc.step()
                    self.scheduler_dec.step()
                # ------------------------------------------------------
                # 5) EMA stats for a global A
                # ------------------------------------------------------
                if mode == "ema":
                    self.ema_update_stats(Sxx, Syx, Z0, Zp, ema_beta=ema_beta, add_bias=use_bias)
                if log_every and ((epoch * len(self.train_loader) + it) % log_every == 0):
                    self.logger.info(
                        f"[Phase-I] epoch {epoch:03d} it {it:04d} | "
                        f"rec {recon_loss.item():.8f} | dyn {dyn_loss.item():.8f} | pred {pred_loss.item():.8f} | "
                        f"rec(with mask) {recon_loss_wzmask.item():.8f} | pred(with mask) {pred_loss_wzmask.item():.8f} |"
                        f"dyn(long term) {ms_loss.item():.8f} | freq_ms_loss {freq_ms_loss.item():.8f}"
                    )
            # Epoch end: compute global A_ema (+ b_ema) for Phase-II init/logging
            with torch.no_grad():
                if mode == "fullpass":
                    A_glob, b_glob = self.solve_global_A_fullpass(self.train_loader, ridge=ridge, use_bias=use_bias)
                else:
                    A_glob, b_glob = self.solve_global_A_from_stats(Sxx, Syx, ridge=ridge, D=D, add_bias=use_bias)
                self.A_phase1_ema = A_glob.clone()
                self.b_phase1_ema = b_glob.clone() if b_glob is not None else None
                # Optionally persist for Phase-II (e.g., pH-dense init)
                out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}")
                torch.save({'A_ema': self.A_phase1_ema, 'b_ema': self.b_phase1_ema}, os.path.join(out_dir, "A_phase1_ema.pt"))
            if self.cfg.scheduler == 'CosineAnnealingLR' or self.cfg.scheduler == 'StepLR':
                self.scheduler_enc.step()
                self.scheduler_dec.step()

            # ===== epoch-end diagnostics print (distributions) =====
            if diag_enable:
                def _stats(v):
                    import math
                    if len(v) == 0: return dict(min=float("nan"), max=float("nan"), mean=float("nan"), std=float("nan"))
                    m = sum(v) / len(v)
                    var = sum((x - m) ** 2 for x in v) / max(1, len(v) - 1)
                    return dict(min=min(v), max=max(v), mean=m, std=math.sqrt(var))
                rs = _stats(rho_list_epoch)
                gs = _stats(gap_list_epoch) if len(gap_list_epoch) else None
            
                # keep your "[P1/Eval-TrainEval]" style but mark "diag"
                diag_str = {
                    "rho(A*_batch)": f"{rs['min']:.4f}/{rs['mean']:.4f}/{rs['max']:.4f}/{rs['std']:.4f}",
                    "gap(A*,A_ema)": (f"{gs['mean']:.3f}" if gs is not None else "NA"),
                    "rho(A_ema)":    (f"{self._spectral_radius(self.A_phase1_ema):.6f}" if hasattr(self, "A_phase1_ema") else "NA"),
                }
                self.logger.info(f"[P1/Diag-Epoch] E={epoch:03d} | "
                                f"rho(A*_batch) min/mean/max/std={diag_str['rho(A*_batch)']} | "
                                f"gap ||A*-A_ema||/||A_ema|| mean={diag_str['gap(A*,A_ema)']} | "
                                f"rho(A_ema)={diag_str['rho(A_ema)']}")

            # reset epoch buffers
            rho_list_epoch.clear()
            gap_list_epoch.clear()
            A_global_prev = self.A_phase1_ema

            ###################### Evaluation ######################
            if eval_every is not None and epoch % eval_every == 0:
                if hasattr(self, "train_eval_loader") and self.train_eval_loader is not None:
                    metrics_tr = self.evaluate_phase1(self.train_eval_loader, use_global_A=True)
                    self.logger.info(f"[P1/Eval-TrainEval] rec(mask)={metrics_tr['rec_masked']:.8f} | "
                                    f"dyn={metrics_tr['dyn_mse']:.8f} | pred(mask)={metrics_tr['pred_masked']:.8f} | "
                                    f"diag={metrics_tr['diag']}")
                    # Track the best by masked reconstruction error
                    if metrics_tr["rec_masked"] < best_rec:
                        best_rec = metrics_tr["rec_masked"]
                        best_metrics = {"split": "train_eval", **metrics_tr}
                        self.save_phase1_checkpoint(epoch=epoch, metrics=best_metrics, pth_name="phase1_best_rec.pth")
                if hasattr(self, "test_loader") and self.test_loader is not None:
                    metrics_ts = self.evaluate_phase1(self.test_loader, use_global_A=True)
                    self.logger.info(f"[P1/Eval-Test] rec(mask)={metrics_ts['rec_masked']:.8f} | "
                                    f"dyn={metrics_ts['dyn_mse']:.8f} | pred(mask)={metrics_ts['pred_masked']:.8f} | "
                                    f"diag={metrics_ts['diag']}")
                    
        self.logger.info("[Phase-I] finished. Global A_ema saved to self.A_phase1_ema.")
        self.save_phase1_checkpoint(epoch=None, metrics=best_metrics, pth_name="phase1_final.pth")

    
    @torch.no_grad()
    def evaluate_phase1(self,
                        dataloader=None,
                        *,
                        use_global_A: bool = True,   # use A_ema if available; else fall back to per-batch closed-form
                        use_bias: bool = True,
                        use_pred_loss: bool = True,  # include one-step decoded loss
                        ridge: float = 5e-3):
        """Evaluate Phase-I reconstruction and latent one-step linear consistency.

        Args:
            dataloader: Loader to iterate. When None, ``self.test_loader`` is
                preferred, then ``self.train_eval_loader``; if neither is set,
                a loader is built through ``self.build_dataloader``.
            use_global_A: Use the stored ``self.A_phase1_ema`` (and
                ``self.b_phase1_ema``) when it is set; otherwise A and b are
                solved per batch with ``self.solve_Ab_ridge``.
            use_bias: Add the bias term to the latent one-step prediction.
            use_pred_loss: Also decode the one-step latent prediction and score
                it against the frames from index ``n_frames_cond`` onward.
            ridge: Ridge coefficient forwarded to the per-batch solve.

        Returns:
            dict with the batch-size-weighted averages ``rec_masked``,
            ``dyn_mse`` and ``pred_masked`` (None when ``use_pred_loss`` is
            False), the flag ``use_global_A`` (True only if a stored A was
            actually used) and ``diag`` (contains ``rho(A_ema)`` when the
            eigenvalues could be computed).

        Note:
            The encoder and decoder are switched to eval mode and are not
            switched back; callers that keep training afterwards must put them
            back into train mode themselves.
        """
        self.encoder.eval()
        self.decoder.eval()

        # choose loader
        if dataloader is None:
            # prefer a held-out split if present
            if getattr(self, "test_loader", None) is not None:
                dataloader = self.test_loader
            elif getattr(self, "train_eval_loader", None) is not None:
                dataloader = self.train_eval_loader
            else:
                # fallback to building a test/eval loader
                try:
                    self.build_dataloader(group="test")
                    dataloader = self.test_loader
                except Exception:
                    self.build_dataloader(group="train_eval")
                    dataloader = self.train_eval_loader

        tot_rec_m = 0.0   # masked recon
        tot_dyn   = 0.0   # latent linear consistency
        tot_pred_m = 0.0  # masked one-step decoded
        nsamples  = 0

        # Cache the global A if requested and available. The attribute may exist
        # but still be None, in which case the per-batch solve is used instead.
        A_fix = None
        b_fix = None
        A_ema = getattr(self, "A_phase1_ema", None) if use_global_A else None
        if A_ema is not None:
            A_fix = A_ema.to(self.device, dtype=torch.float32)
            b_ema = getattr(self, "b_phase1_ema", None)
            if b_ema is not None:
                b_fix = b_ema.to(self.device, dtype=torch.float32)

        for batch in dataloader:
            masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
            # 1) encode & recon: latent_states [t_eff,B,D], recon loss on [nf_cond-1:]
            latent_states, recon_loss, recon_loss_wzmask = self._encode_and_recon(batch)  # recon_loss_wzmask is masked MSE
            T_eff, B, D = latent_states.shape
            nsamples += B

            # 2) closed-form A,b (per-batch) OR fixed A_ema,b_ema
            Z0 = latent_states[:-1].reshape(-1, D)  # [N, D], N=(T_eff-1)*B
            Zp = latent_states[ 1:].reshape(-1, D)  # [N, D]
            if A_fix is None:
                A_use, b_use = self.solve_Ab_ridge(Z0, Zp, ridge=ridge, add_bias=use_bias)
            else:
                A_use, b_use = A_fix, b_fix

            # 3) latent linear consistency loss: MSE(Zp_hat, Zp)
            Zp_hat = (Z0 @ A_use.T) + (b_use if (use_bias and b_use is not None) else 0.0)
            dyn_mse = F.mse_loss(Zp_hat, Zp)

            # 4) optional: one-step decoded loss in observation space
            if use_pred_loss:
                a1_hat = Zp_hat.view(T_eff - 1, B, D)                # [T_eff-1, B, D]
                x1_hat = self._decode_latent(a1_hat)                 # [B, T_eff-1, H, W, C]
                x1_true = batch["data"][:, self.n_frames_cond:, ...].to(self.device)  # next frames [nf_cond : T-1]
                _, pred_mse_wzmask = self._mse_loss(x1_hat, x1_true, masks[:, self.n_frames_cond:, ...])
            else:
                pred_mse_wzmask = torch.tensor(0.0, device=self.device)

            # accumulate (weight by batch size for fair averaging)
            tot_rec_m += float(recon_loss_wzmask) * B
            tot_dyn   += float(dyn_mse)            * B
            tot_pred_m += float(pred_mse_wzmask)   * B

        # finalize averages (max(1, .) keeps an empty loader from dividing by zero)
        rec_masked = tot_rec_m / max(1, nsamples)
        dyn_avg    = tot_dyn   / max(1, nsamples)
        pred_masked= tot_pred_m/ max(1, nsamples)

        # optional: spectral diagnostics when using A_ema
        diag = {}
        if A_fix is not None:
            try:
                eigvals = torch.linalg.eigvals(A_fix).detach().cpu()
                diag["rho(A_ema)"] = float(eigvals.abs().max())       # discrete spectral radius
            except Exception as exc:
                # Best-effort diagnostic: keep evaluating, but report the failure
                # instead of dropping it silently.
                logger = getattr(self, "logger", None)
                if logger is not None:
                    logger.info(f"[evaluate_phase1] rho(A_ema) unavailable: {exc}")

        return {
            "rec_masked": rec_masked,
            "dyn_mse": dyn_avg,
            "pred_masked": pred_masked if use_pred_loss else None,
            "use_global_A": (A_fix is not None),
            "diag": diag
        }


    @torch.no_grad()
    def _phase1_ckpt_payload(self, epoch: int | None, metrics: dict | None = None):
        def sd_cpu(module: nn.Module):
            return {k: v.detach().cpu() for k, v in module.state_dict().items()}

        payload = {
            "epoch": epoch,
            "run_id": getattr(self, "run_id", None),
            "model_cfg": getattr(self, "model_cfg", None),
            "encoder_state": sd_cpu(self.encoder),
            "decoder_state": sd_cpu(self.decoder),
            # Stash Phase-I linear stats if available
            "A_ema": (self.A_phase1_ema.detach().cpu()
                    if hasattr(self, "A_phase1_ema") and self.A_phase1_ema is not None else None),
            "b_ema": (self.b_phase1_ema.detach().cpu()
                    if hasattr(self, "b_phase1_ema") and self.b_phase1_ema is not None else None),
            "metrics": metrics,
        }
        return payload


    def save_phase1_checkpoint(self, epoch: int | None, metrics: dict | None, pth_name: str):
        out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}")
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, pth_name)
        torch.save(self._phase1_ckpt_payload(epoch, metrics), path)
        if hasattr(self, "logger") and self.logger is not None:
            self.logger.info(f"[Phase-I] checkpoint saved: {path}")


    def load_phase1_ckpt(
        self,
        path: str,
        *,
        restore_modules: bool = True,
        init_linear: bool = True,
        clip_positive_symmetric: bool = False,
        max_pos_real: float = 0.0,
        eps_eye: float = 1e-9,
        strict: bool = True,
    ):
        # ---- 0) Load checkpoint on CPU first (portable), then move what we need. ----
        ckpt = torch.load(path, map_location="cpu")

        # ---- 1) Optionally restore encoder/decoder weights. ----
        if restore_modules:
            enc_state = ckpt.get("encoder_state", None)
            dec_state = ckpt.get("decoder_state", None)
            if enc_state is not None:
                self.encoder.load_state_dict(enc_state, strict=strict)
            if dec_state is not None:
                self.decoder.load_state_dict(dec_state, strict=strict)

        Ad = None
        b = None
        z_star = None
        rho = float("nan")

        # ---- 2) Optionally initialize continuous-time linear dynamics from Ad. ----
        if init_linear:
            Ad_cpu = ckpt.get("A_ema", None)
            if Ad_cpu is None:
                raise KeyError("[load_phase1_ckpt] 'A_ema' is missing in the checkpoint.")
            Ad = Ad_cpu.to(self.device, dtype=torch.float32)

            # Basic shape checks
            if Ad.dim() != 2 or Ad.size(0) != Ad.size(1):
                raise ValueError(f"[load_phase1_ckpt] Ad must be square, got {tuple(Ad.shape)}")
            if Ad.size(0) != self.latent_dim:
                raise ValueError(f"[load_phase1_ckpt] Ad dim {Ad.size(0)} != latent_dim {self.latent_dim}")

            if self.latent_mode == "continuous":
                # Initialize continuous-time drift from Ad
                self.latent_process.init_linear_from_Ad(
                    Ad, self.dt_eval,
                    clip_positive_symmetric=clip_positive_symmetric,
                    max_pos_real=max_pos_real,
                )
            elif self.latent_mode == "discrete":
                print(Ad.shape)
                print(self.latent_process.A.shape)
                with torch.no_grad():
                    self.latent_process.A.copy_(Ad)

            # Try to compute bias center if b_ema exists
            b_cpu = ckpt.get("b_ema", None)
            if b_cpu is not None:
                b = b_cpu.to(self.device, dtype=torch.float32).view(-1)
                if b.numel() != self.latent_dim:
                    raise ValueError(f"[load_phase1_ckpt] b dim {b.numel()} != latent_dim {self.latent_dim}")

                # Solve (I - Ad) z* = b with small Tikhonov for stability
                D = Ad.size(0)
                I = torch.eye(D, device=self.device, dtype=Ad.dtype)
                M = I - Ad
                M_reg = M + eps_eye * I
                try:
                    z_star = torch.linalg.solve(M_reg, b)
                except RuntimeError:
                    z_star = torch.linalg.pinv(M_reg) @ b

                # Cache for later centering use
                self.latent_center = z_star.detach().clone()
            else:
                self.latent_center = None

            # Keep copies for diagnostics/exports
            self.Ad_phase1 = Ad.detach().clone()
            self.b_phase1 = b.detach().clone() if b is not None else None

            # Spectral radius of Ad (discrete)
            try:
                rho = torch.max(torch.abs(torch.linalg.eigvals(Ad))).item()
            except Exception:
                rho = float("nan")

        # ---- 3) Logging ----
        if hasattr(self, "logger") and self.logger is not None:
            self.logger.info(
                f"[load_phase1_ckpt] from {path} | "
                f"restore_modules={restore_modules} | init_linear={init_linear} | "
                f"bias={'yes' if b is not None else 'no'} | "
                f"center={'yes' if z_star is not None else 'no'} | "
                f"rho(Ad)={rho:.4f}"
            )

        return {"Ad": Ad, "b": b, "z_star": z_star, "rho_Ad": rho}

    
    def _center_latent(self, z: torch.Tensor) -> torch.Tensor:
        """Subtract the learned fixed-point center z* if available, preserving shape.
        - Input:  z [..., D]  (e.g., [D], [B,D], [T,B,D])
        - Output: same shape as input
        """
        zc = getattr(self, "latent_center", None)
        if zc is None:
            return z
        # Safety: last dim must match
        assert z.shape[-1] == zc.numel(), f"center dim {zc.numel()} != z last dim {z.shape[-1]}"
        if z.ndim == 1:
            return z - zc.to(z)
        # Shape like [1, 1, ..., D] with (z.ndim-1) leading ones
        view_shape = [1] * (z.ndim - 1) + [-1]
        return z - zc.to(z).view(*view_shape)


    def _decenter_latent(self, z: torch.Tensor) -> torch.Tensor:
        """Add the learned fixed-point center z* back if available, preserving shape.
        - Input:  z [..., D]
        - Output: same shape as input
        """
        zc = getattr(self, "latent_center", None)
        if zc is None:
            return z
        assert z.shape[-1] == zc.numel(), f"center dim {zc.numel()} != z last dim {z.shape[-1]}"
        if z.ndim == 1:
            return z + zc.to(z)
        view_shape = [1] * (z.ndim - 1) + [-1]
        return z + zc.to(z).view(*view_shape)
    

    def _encode_and_recon(self, batch, return_recon: bool = False):
        """Encode every delay-stacked window of a batch and score its reconstruction.

        Only ``batch["data"]`` and ``batch["mask"]`` are read.

        Args:
            batch: Mapping with ``"data"`` and ``"mask"`` tensors, both laid out
                as [B, T, H, W, n_ch].
            return_recon: Also return the reconstructed frames, the matching
                ground truth and the matching mask slice.

        Returns:
            ``(latent_states, recon_loss, recon_loss_wz_mask)`` where
            ``latent_states`` is [t, B, latent_dim] with
            ``t = T - n_frames_cond + 1`` and the losses are the plain and the
            mask-weighted MSE from ``self._mse_loss`` on frames
            ``n_frames_cond - 1`` onward. With ``return_recon=True`` the tuple
            is extended by ``(x_rec, x_gt, m_rec)``.
        """
        ground_truth = batch["data"].to(self.device)    # [B, T, H, W, n_ch]
        masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
        bs, train_len, _, _, _ = ground_truth.shape
        # assert train_len == self.n_frames_train
        n_windows = train_len - self.n_frames_cond + 1  # number of delay-stacked windows per sequence

        mask_index = self.mask_to_bs_index(mask=masks)    # [B, S]
        mask_index = mask_index.unsqueeze(1).expand(-1, n_windows, -1).flatten(0, 1)  # [B*t, S]
        delay_data = delay_stack_last_channel(x=ground_truth, d=self.n_frames_cond)    # [B, T-nf_cond+1, ..., n_ch*nf_cond]
        data_in = delay_data.flatten(0, 1).flatten(1, 2)    # [B*t, H*W, n_ch*nf_cond]

        data_in = self.index_points(data_in, mask_index)    # [B*t, S, n_ch*nf_cond]
        pos_in = self.pos_feat.flatten(0, 1).unsqueeze(0).expand(data_in.shape[0], -1, -1).to(self.device)
        pos_in = self.index_points(pos_in, mask_index)    # [B*t, S, 2]
        data_in = torch.cat((data_in, pos_in), dim=-1)

        # latent_states: encoded from encoder / dyn_states: predicted using latent ode
        latent_token = self.encoder(data_in, pos_in)    # [B*t, K, latent_token] (t=T-nf_cond+1)
        # reshape to [B, t, latent_dim], then put time first
        latent_states = latent_token.reshape((bs, n_windows, -1))    # [B, t, latent_dim]
        latent_states = latent_states.permute(1, 0, 2)    # [t, B, latent_dim]

        # Window i ends at frame i + nf_cond - 1, so targets start at nf_cond - 1.
        x_gt = ground_truth[:, self.n_frames_cond - 1:, ...]
        m_rec = masks[:, self.n_frames_cond - 1:, ...]
        recon_encdec = self._decode_latent(latent_seqs=latent_states)    # [B, t, H, W, c]
        recon_loss, recon_loss_wz_mask = self._mse_loss(recon_encdec, x_gt, mask=m_rec)
        if return_recon:
            # Return both losses and tensors needed for spectral/multiscale penalties
            return latent_states, recon_loss, recon_loss_wz_mask, recon_encdec, x_gt, m_rec
        return latent_states, recon_loss, recon_loss_wz_mask

    
    @torch.no_grad()    # for evaluation
    def _encode_cond_batch(self, batch):
        """Encode only the first ``n_frames_cond`` frames of each sequence.

        Reads ``batch["data"]`` and ``batch["mask"]`` (both [B, T, H, W, n_ch])
        and returns one latent vector per sequence, shaped [B, latent_dim].
        Runs without gradient tracking.
        """
        frames = batch["data"].to(self.device)      # [B, T, H, W, n_ch]
        obs_mask = batch["mask"].to(self.device)    # [B, T, H, W, n_ch]
        n_seq, _, _, _, _ = frames.shape

        point_idx = self.mask_to_bs_index(mask=obs_mask)    # [B, S]
        # A single delay-stacked window built from the conditioning frames.
        cond_window = delay_stack_last_channel(
            x=frames[:, :self.n_frames_cond, ...], d=self.n_frames_cond
        )    # [B, 1, ..., n_ch*nf_cond]
        feats = cond_window.flatten(0, 1).flatten(1, 2)     # [B*1, H*W, n_ch*nf_cond]
        feats = self.index_points(feats, point_idx)         # [B*1, S, n_ch*nf_cond]

        coords = self.pos_feat.flatten(0, 1).unsqueeze(0).expand(feats.shape[0], -1, -1).to(self.device)
        coords = self.index_points(coords, point_idx)       # [B*1, S, 2]
        feats = torch.cat((feats, coords), dim=-1)

        tokens = self.encoder(feats, coords)                # [B*1, K, latent_token]
        # collapse the tokens of each sequence into one vector
        return tokens.reshape((n_seq, -1))                  # [B, latent_dim]
        
        
    def _mse_loss(self, data1: torch.Tensor, data2: torch.Tensor, mask: torch.Tensor):
        # datas: [B, T, H, W, s], mask: [B, T, H, W, s] or None
        # check!!!!!!!!!!
        criterion = nn.MSELoss(reduction="none")
        loss = criterion(data1, data2)
        mse = loss.mean()
        
        sqerr = (data1 - data2).pow(2) * mask
        sqerr_sum = sqerr.sum(dim=(2, 3))
        denom = mask.sum(dim=(2, 3)).clamp_(min=1e-6)
        mse_wzmask = (sqerr_sum / denom).mean()
        return mse, mse_wzmask
    

    
    def _fft_mse_on_frames(
        self,
        x_hat: torch.Tensor,   # [B, T, H, W, C]
        x_true: torch.Tensor,  # [B, T, H, W, C]
        *,
        use_log_mag: bool = True,
        hf_power: float = 0.0
    ) -> torch.Tensor:
        B, T, H, W, C = x_hat.shape

        # Flatten to [B*T*C, H, W]
        def flat_hw(x: torch.Tensor) -> torch.Tensor:
            return x.permute(0, 1, 4, 2, 3).contiguous().view(B * T * C, H, W)
        Xh = flat_hw(x_hat)
        Xt = flat_hw(x_true)
        # 1) Per-frame de-mean in spatial domain (kills DC)
        Xh = Xh - Xh.mean(dim=(1, 2), keepdim=True)
        Xt = Xt - Xt.mean(dim=(1, 2), keepdim=True)
        # 2) 2D rFFT with orthonormal scaling
        Fh = torch.fft.rfft2(Xh, norm="ortho")   # [B*T*C, H, W_r]
        Ft = torch.fft.rfft2(Xt, norm="ortho")   # [B*T*C, H, W_r]
        W_r = Fh.size(-1)
        # 3) Magnitude (optionally log1p)
        Mh = torch.abs(Fh)
        Mt = torch.abs(Ft)
        if use_log_mag:
            Mh = torch.log1p(Mh)
            Mt = torch.log1p(Mt)
        # 4) Optional radial high-frequency emphasis
        if hf_power > 0.0:
            fy = torch.fft.fftfreq(H, d=1.0).to(Xh.device, Xh.dtype)      # [H]
            fx = torch.fft.rfftfreq(W, d=1.0).to(Xh.device, Xh.dtype)     # [W_r]
            gy = fy[:, None]                                              # [H, 1]
            gx = fx[None, :]                                              # [1, W_r]
            r  = torch.sqrt(gx * gx + gy * gy)                            # [H, W_r]
            r  = r / (r.max() + 1e-12)
            w  = r.pow(hf_power)
        else:
            w = Xh.new_ones((H, W_r))                                     # [H, W_r]
        # (robust) completely remove DC contribution in frequency loss
        w = w.clone()
        w[0, 0] = 0.0
        Mh = Mh * w
        Mt = Mt * w
        # 5) Per-sample spectral normalization (keeps loss scale tidy)
        def normalize(a: torch.Tensor) -> torch.Tensor:
            scale = a.mean(dim=(1, 2), keepdim=True)                      # [B*T*C, 1, 1]
            return a / (scale + 1e-6)
        Mh = normalize(Mh)
        Mt = normalize(Mt)
        # 6) MSE over spectra
        return F.mse_loss(Mh, Mt)


    def _multiscale_spatial_mse(self,
                                x_hat: torch.Tensor,  # [B, T, H, W, C]
                                x_true: torch.Tensor, # [B, T, H, W, C]
                                *,
                                pool_scales: tuple[int, ...] = (2, 4)) -> torch.Tensor:
        """
        Multi-scale MSE via average pooling pyramid.
        For each s in pool_scales, downsample by s in H and W and compute MSE.
        """
        if not pool_scales:
            return x_hat.new_zeros(())
        B, T, H, W, C = x_hat.shape
        xh = x_hat.permute(0, 1, 4, 2, 3).contiguous().view(B * T * C, 1, H, W)  # [BTC,1,H,W]
        xt = x_true.permute(0, 1, 4, 2, 3).contiguous().view(B * T * C, 1, H, W)

        loss = x_hat.new_zeros(())
        n = 0
        for s in pool_scales:
            if H // s < 1 or W // s < 1:
                continue
            pool = torch.nn.AvgPool2d(kernel_size=s, stride=s, ceil_mode=False, count_include_pad=False)
            xh_s = pool(xh)
            xt_s = pool(xt)
            loss = loss + F.mse_loss(xh_s, xt_s)
            n += 1
        if n == 0:
            return x_hat.new_zeros(())
        return loss / n


    def _decode_latent(self, latent_seqs: torch.Tensor):
        """
        Decode a latent sequence a_seq [T',B,D] into field tensor aligned per your decoder.
        Returns [B,T',H,W,C] (or your state_dim layout).
        """
        T1, B, D = latent_seqs.shape
        if self.dec_mode == "fouriermlp":
            latent_feats_ = latent_seqs.permute(1, 0, 2).flatten(0, 1)    # [B*t, latent_dim]
            grid_dim = self.pos_feat.shape[-1]
            grid = self.pos_feat.reshape(-1, grid_dim).to(self.device)    # [N_pt, grid_dim]
            recon_field_ = self.decoder(grid=grid, latent_feat=latent_feats_)    # [B*t, N_pts, out_dim]
            recon_field_ = recon_field_.reshape(B, T1, *self.shapelist, self.state_dim)
            return recon_field_
        elif self.dec_mode == "fouriernet":
            latent_feats_ = latent_seqs.permute(1, 0, 2).view(B, T1, self.state_dim, self.code_dim)  # [B, t, s, code_dim]
            recon_field_ = self.decoder(latent_feats_)  # [B, t, H, W, s]
            return recon_field_
        else:
            raise ValueError(f"Unknown dec_mode: {self.dec_mode}")


    @torch.no_grad()
    def print_latent_of_index(self, group: str, seq_index: int, num_steps: int = 8, show_norms: bool = True,):
        """Encode one dataset sequence and print its latent trajectory.

        Args:
            group: Which loader to take the dataset from: ``"train"``,
                ``"train_eval"`` or ``"test"``.
            seq_index: Index of the sequence inside that loader's dataset.
            num_steps: Number of leading time steps to print (None prints all).
            show_norms: Also print per-step L2 / L-inf / mean-abs norms and a
                summary (min / max / mean) over the whole sequence.

        Returns:
            The latent trajectory as a CPU tensor of shape [T', D].

        Raises:
            KeyError: ``group`` is not one of the three known names.
            RuntimeError: The requested loader is missing or None.
        """
        self._ensure_loader(group)
        # Look up only the requested loader, so an unrelated loader that was
        # never built cannot break this call.
        loader_attr = {"train": "train_loader",
                       "train_eval": "train_eval_loader",
                       "test": "test_loader"}[group]
        loader = getattr(self, loader_attr, None)
        if loader is None:
            raise RuntimeError(f"[print_latent_of_index] loader for {group} is None")
        dataset = loader.dataset
        assert 0 <= seq_index < len(dataset), f"seq_index cross the border: 0..{len(dataset)-1}"
        single = DataLoader(Subset(dataset, [seq_index]), batch_size=1, shuffle=False)
        batch = next(iter(single))
        lat, _, _ = self._encode_and_recon(batch)   # [T', B=1, D]
        lat = lat[:, 0].detach().cpu()              # [T', D]
        Tp, D = lat.shape
        print(f"[latent] shape = [T'={Tp}, D={D}]  (group={group}, idx={seq_index})")
        n = Tp if (num_steps is None) else max(0, min(num_steps, Tp))
        if n > 0:
            print(lat[:n])
        if show_norms:
            # Per-step norms across feature dimension D
            # L2 (Euclidean), Linf (max-abs), and mean absolute value
            l2 = torch.linalg.norm(lat, ord=2, dim=-1)                # [T']
            linf = torch.linalg.norm(lat, ord=float("inf"), dim=-1)   # [T']
            mean_abs = lat.abs().mean(dim=-1)                         # [T']

            k = n if n > 0 else min(8, Tp)  # how many time steps to list explicitly
            print(f"\nPer-step norms for first {k} steps (L2, L∞, mean|.|):")
            for t in range(k):
                print(
                    f"t={t:02d}: "
                    f"L2={float(l2[t]):.6f}  "
                    f"Linf={float(linf[t]):.6f}  "
                    f"mean|.|={float(mean_abs[t]):.6f}"
                )

            # Summary statistics over the whole sequence
            print("\nSummary across all steps:")
            for label, values in (("  L2:    ", l2), ("  Linf:  ", linf), ("  mean|.|: ", mean_abs)):
                print(
                    f"{label}"
                    f"min={float(values.min()):.6f}  "
                    f"max={float(values.max()):.6f}  "
                    f"mean={float(values.mean()):.6f}"
                )
        return lat



   
    @torch.no_grad()
    def set_whitening_scale(self, s: torch.Tensor | None):
        """Register per-dim whitening scales."""
        if s is None:
            self.whiten_scale = None
            return
        s = s.view(-1).to(device=self.device, dtype=torch.float32)
        assert s.numel() == self.latent_dim
        self.whiten_scale = torch.clamp(s, min=float(self.whiten_clamp))


    def _whiten_latent(self, centered: torch.Tensor) -> torch.Tensor:
        """w = S^{-1}(z - z*); shape-preserving."""
        s = getattr(self, "whiten_scale", None)
        if (not self.use_diag_whiten) or (s is None):
            return centered
        view = [1]*(centered.ndim - 1) + [-1]
        return centered / s.view(*view)


    def _unwhiten_latent(self, w: torch.Tensor) -> torch.Tensor:
        """(z - z*) = S w; shape-preserving."""
        s = getattr(self, "whiten_scale", None)
        if (not self.use_diag_whiten) or (s is None):
            return w
        view = [1]*(w.ndim - 1) + [-1]
        return w * s.view(*view)


    @torch.no_grad()
    def fit_diag_whitening_from_phase1(self, group: str = "train_eval", max_batches: int | None = 16):
        """Estimate per-dim RMS of centered latents from Phase-I encoder outputs.

        Args:
            group: Loader to draw batches from: ``"train"``, ``"train_eval"``
                or ``"test"``.
            max_batches: Upper bound on the number of batches used (None uses
                the whole loader).

        Returns:
            The per-dimension scale ``s`` (float32, length ``latent_dim``); it
            is also registered through ``set_whitening_scale``.

        Raises:
            KeyError: ``group`` is not one of the three known names.
            RuntimeError: The requested loader is missing/None, or it yielded
                no samples.

        Note:
            The encoder is switched to eval mode and is not switched back.
        """
        # choose dataloader; only the requested attribute is read, so an
        # unrelated loader that was never built cannot break this call
        self._ensure_loader(group)
        loader_attr = {"train": "train_loader",
                       "train_eval": "train_eval_loader",
                       "test": "test_loader"}[group]
        loader = getattr(self, loader_attr, None)
        if loader is None:
            raise RuntimeError(f"[fit_diag_whitening_from_phase1] loader for {group} is None")
        self.encoder.eval()
        # accumulate the sum of squares in float64 to limit rounding error
        m2 = torch.zeros(self.latent_dim, device=self.device, dtype=torch.float64)
        n = 0
        for b_idx, batch in enumerate(loader):
            if max_batches is not None and b_idx >= max_batches:
                break
            lat, _, _ = self._encode_and_recon(batch)                # [T',B,D]
            y = self._center_latent(lat).reshape(-1, self.latent_dim).to(torch.float64)
            m2 += (y*y).sum(dim=0)
            n += y.shape[0]
        if n == 0:
            raise RuntimeError("[fit_diag_whitening_from_phase1] no samples")
        mean_sq = (m2 / n).to(torch.float32)
        # floor the mean square before the sqrt, then floor the resulting scale
        s = torch.sqrt(torch.clamp(mean_sq, min=self.whiten_eps)).clamp_min(self.whiten_clamp)
        self.set_whitening_scale(s)
        return s


    @torch.no_grad()
    def _reinit_linear_in_whiten_space(self):
        """Apply Aw = S^{-1} Ad S and re-init latent_process linear skeleton."""
        if (not self.use_diag_whiten) or (self.whiten_scale is None):
            return
        if not hasattr(self, "Ad_phase1") or self.Ad_phase1 is None:
            return
        Ad = self.Ad_phase1.to(self.device, dtype=torch.float32)  # [D,D]
        s  = self.whiten_scale.to(self.device, dtype=Ad.dtype)    # [D]
        # Aw = S^{-1} Ad S  (column scale by s, then row divide by s)
        Aw = (Ad * s.view(1, -1)) / s.view(-1, 1)
        if self.latent_mode == "continuous":
            self.latent_process.init_linear_from_Ad(
                Aw, self.dt_eval, clip_positive_symmetric=False, max_pos_real=0.0
            )
        elif self.latent_mode == "discrete":
            self.latent_process.A.copy_(Aw)
        self.Ad_phase1_whiten = Aw.detach().clone()

    
    def _get_U(self) -> torch.Tensor:
        """Return orthonormal U [D,d] on correct device/dtype."""
        if self.U_proj is None:
            raise RuntimeError("U_proj is None. Train or load projector first.")
        p0 = next(self.latent_process.parameters(), None)
        dtype = p0.dtype if p0 is not None else torch.float32
        return self.U_proj.to(device=self.device, dtype=dtype)


    def _project_latent(self, w: torch.Tensor) -> torch.Tensor:
        """w [..., D] -> y [..., d] via U."""
        U = self._get_U()
        return torch.matmul(w, U)  # broadcast matmul: (...,D) x (D,d) -> (...,d)


    def _lift_latent(self, y: torch.Tensor) -> torch.Tensor:
        """y [..., d] -> w [..., D] via U^T."""
        U = self._get_U()
        return torch.matmul(y, U.t())


    @torch.no_grad()
    def set_projector(self, U: torch.Tensor):
        """Register an orthonormal projector U [D,d].

        ``U`` is orthonormalised with a reduced QR factorisation. QR is only
        unique up to the sign of each column of Q, so the signs are pinned by
        requiring diag(R) >= 0. With that convention an already orthonormal
        ``U`` (for example one restored by ``load_projector``) is registered
        unchanged up to round-off, instead of coming back with some columns
        negated, which would change the reduced coordinates.

        Args:
            U: 2-D tensor of shape [D, d] with d <= D; its columns span the
                subspace to keep.

        Raises:
            ValueError: ``U`` is not 2-D or has more columns than rows.

        Side effects:
            Sets ``U_proj`` (float32, on ``self.device``), ``use_projector``
            to True and ``latent_dim_y`` to d.
        """
        U = U.to(device=self.device, dtype=torch.float32)
        if U.ndim != 2 or U.size(1) > U.size(0):
            raise ValueError(
                f"projector must be a 2-D [D, d] tensor with d <= D, got shape {tuple(U.shape)}"
            )
        d = U.size(1)
        Q, R = torch.linalg.qr(U)    # reduced mode: Q [D,d], R [d,d]
        diag = torch.diagonal(R)
        # Map sign(0) to +1 so a rank-deficient column is never zeroed out.
        signs = torch.where(diag < 0, -torch.ones_like(diag), torch.ones_like(diag))
        self.U_proj = (Q * signs.unsqueeze(0)).contiguous()
        self.use_projector = True
        self.latent_dim_y = d


    def save_projector(self, path: str):
        """Save the current projector to ``path``.

        The file holds ``{"U_proj": tensor-or-None}`` with the tensor detached
        and moved to CPU. The parent directory is created when ``path`` has
        one.

        Args:
            path: destination file; a bare filename is written to the current
                working directory.
        """
        out_dir = os.path.dirname(path)
        # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        U = self.U_proj
        payload = {"U_proj": U.detach().cpu() if U is not None else None}
        torch.save(payload, path)


    def load_projector(self, path: str):
        """Load a projector saved under the "U_proj" key and register it.

        Args:
            path: file readable by ``torch.load``; it is mapped to CPU.

        Raises:
            KeyError: the payload has no "U_proj" entry, or the entry is None.
        """
        checkpoint = torch.load(path, map_location="cpu")
        projector = checkpoint.get("U_proj", None)
        if projector is not None:
            self.set_projector(projector)
            return
        raise KeyError(f"'U_proj' missing in {path}")


    def train_projector_from_phase1_ckpt(self, phase1_path: str, d: int,
                                         epochs: int = 3, lr: float = 1e-3, lr_dec: float = 0.0,
                                         lambda_dyn: float = 0.1, lambda_ortho: float = 1e-2,
                                         log_every: int | None = None, eval_every: int | None = None
                                         ):
        """Learn a rank-``d`` orthonormal projector on top of a frozen Phase-I model.

        Steps:
            1. Load the Phase-I checkpoint and freeze the encoder. The decoder
               is fine-tuned only when ``cfg.lr_dec != 0``.
            2. Optionally fit the diagonal whitening scale and express the
               Phase-I operator in whitened coordinates (``Aw``).
            3. Optimise a free matrix ``U_param``; each step its reduced QR
               factor ``U`` projects the centered/whitened latents. The loss is
               masked reconstruction + ``lambda_dyn`` * one-step linear
               dynamics residual in the reduced space + ``lambda_ortho`` *
               orthogonality penalty. The Phase-I bias (``b_phase1``) is not
               part of the dynamics residual.
            4. Store the final projector (``U_proj``, ``use_projector``,
               ``latent_dim_y``), save it together with ``A_eff = U^T Aw U``
               to ``<cfg.out_dir>/<run_id>/U_proj_d{d}.pt`` and rebuild
               ``latent_process`` at dimension ``d``, initialised from
               ``A_eff``.

        Args:
            phase1_path: Phase-I checkpoint passed to ``load_phase1_ckpt``.
            d: reduced latent dimension; must satisfy ``d <= latent_dim`` and
                be a multiple of ``state_dim``.
            epochs: passes over the train loader.
            lr: Adam learning rate for the projector parameter.
            lr_dec: decoder learning rate. Decoder fine-tuning is switched on
                by ``cfg.lr_dec != 0``; when this argument is 0 the decoder
                optimiser falls back to ``cfg.lr_dec``.
            lambda_dyn: weight of the reduced dynamics loss.
            lambda_ortho: weight of the orthogonality penalty.
            log_every: log every this many steps; ``None`` disables logging.
            eval_every: currently unused.

        Raises:
            AssertionError: ``Ad_phase1`` is missing after loading, ``d`` is
                invalid, or the data processor is not in "interpolation" mode.
        """
        from copy import deepcopy

        self.setup_logger()
        self.logger.info("Training low dimensional projectors based on pretrained encoder-decoder models!")
        self.save_repro_artifacts()

        # ---- load encoder/decoder and the Phase-I linear operator ----
        self.load_phase1_ckpt(path=phase1_path, clip_positive_symmetric=False)
        assert self.Ad_phase1 is not None
        Ad = self.Ad_phase1

        # Validate d up front so a bad value fails before training and saving.
        D = self.latent_dim
        assert d <= D, f"d={d} must be <= latent_dim={D}"
        assert d % self.state_dim == 0, f"d={d} must be a multiple of state_dim={self.state_dim}"

        # cfg.lr_dec decides whether the decoder is fine-tuned. The optimiser is
        # built from the same flag, so it can never be None while train_dec is True.
        train_dec = (self.cfg.lr_dec != 0)
        dec_lr = lr_dec if lr_dec != 0 else self.cfg.lr_dec
        self.logger.info(f"finetune decoder = {train_dec}")
        set_requires_grad(self.encoder, False)
        set_requires_grad(self.decoder, bool(train_dec))

        # ---- estimate whitening scales (Phase-I encoder) & re-init Aw ----
        if self.use_diag_whiten:
            self.fit_diag_whitening_from_phase1(group="train", max_batches=16)
            self._reinit_linear_in_whiten_space()
            self.logger.info(f"whiten vector: {self.whiten_scale}")
        Aw = Ad.to(self.device, torch.float32)
        scale = getattr(self, "whiten_scale", None)
        # Same gate as _unwhiten_latent, so the operator and the latents share coordinates.
        if self.use_diag_whiten and scale is not None:
            s = scale.to(self.device)
            Aw = (Aw * s.view(1, -1)) / s.view(-1, 1)

        assert self.data_processor.mode == "interpolation", "Mismatched dataloaders"
        self.build_dataloader(group="train")
        self._save_split_and_samples()

        ######## Build Optimizer ########
        U_param = torch.nn.Parameter(torch.randn(D, d, device=self.device) / (D**0.5))
        optim_proj = torch.optim.Adam([U_param], lr=lr)
        optim_dec = (
            torch.optim.Adam([{'params': self.decoder.parameters(), 'lr': dec_lr}])
            if train_dec else None
        )

        eye_d = torch.eye(d, device=self.device)
        steps_per_epoch = len(self.train_loader)
        first_frame = self.n_frames_cond - 1
        for epoch in range(1, epochs + 1):
            for i, batch in enumerate(self.train_loader):
                latent_states, _, _ = self._encode_and_recon(batch)     # [t, B, latent_dim]
                latent_states_ = self._center_latent(latent_states)
                latent_states_ = self._whiten_latent(latent_states_)

                Q, _ = torch.linalg.qr(U_param)                         # reduced QR: [D,d]
                U = Q[:, :d]                                            # [D,d]
                eff_states_ = torch.matmul(latent_states_, U)           # [T',B,d]
                latent_hat_ = torch.matmul(eff_states_, U.t())          # [T',B,D]
                latent_hat = self._decenter_latent(self._unwhiten_latent(latent_hat_))
                x_hat = self._decode_latent(latent_hat)                 # [B,T',H,W,C]
                x_gt = batch["data"][:, first_frame:, ...].to(self.device)
                m_gt = batch["mask"][:, first_frame:, ...].to(self.device)
                _, rec_loss = self._mse_loss(x_hat, x_gt, m_gt)         # masked MSE

                A_eff = U.t() @ Aw @ U                                  # [d,d]
                eff_next_ = torch.matmul(eff_states_[:-1], A_eff.t())   # [T'-1,B,d]
                dyn_loss = torch.nn.functional.mse_loss(eff_states_[1:], eff_next_)
                ortho_loss = torch.norm(U.t() @ U - eye_d)**2

                loss = rec_loss + lambda_dyn * dyn_loss + lambda_ortho * ortho_loss
                optim_proj.zero_grad()
                if optim_dec is not None:
                    optim_dec.zero_grad()
                loss.backward()
                optim_proj.step()
                if optim_dec is not None:
                    optim_dec.step()

                if log_every is not None and (epoch * steps_per_epoch + i) % log_every == 0:
                    self.logger.info(f"[Projector] epoch {epoch:03d}/{epochs} | rec(mask)={float(rec_loss):.6f} | dyn={float(dyn_loss):.6f} | ortho={float(ortho_loss):.6f}")

        # ---- freeze the final projector and the reduced operator ----
        with torch.no_grad():
            Q, _ = torch.linalg.qr(U_param.data)
            self.U_proj = Q[:, :d].contiguous()
            self.use_projector = True
            self.latent_dim_y = d
            Ay = self.U_proj.t() @ Aw @ self.U_proj                     # [d,d]

        out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}")
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, f"U_proj_d{d}.pt")
        torch.save({"U_proj": self.U_proj.detach().cpu(),
                    "A_eff": Ay.detach().cpu()}, save_path)
        self.logger.info(f"[Projector] saved to: {save_path}")

        ##### rebuild latent process #####
        continuous = (self.latent_mode == "continuous")
        cfg_key = "latent_process" if continuous else "latent_process_discrete"
        lp_cfg = deepcopy(self.model_cfg[cfg_key])
        lp_cfg["code_dim"] = d // self.state_dim
        self.model_cfg[cfg_key] = lp_cfg
        if continuous:
            self.latent_process = build_latent_process(model_cfg=lp_cfg).to(self.device)
            self.latent_process.init_linear_from_Ad(Ay, self.dt_eval)
        else:
            self.latent_process = build_latent_process_discrete(model_cfg=lp_cfg).to(self.device)
            with torch.no_grad():
                self.latent_process.A.copy_(Ay)
        self.logger.info(f"[Projector] latent_process rebuilt at dim d={d}. Now ready for Phase-2.")


    def _phase2_param_groups(self, lr_mem=1e-3, lr_lin=0.0, wd_mem=0.0, wd_lin=1e-4,):
        lp = self.latent_process
        mem_params, lin_params = [], []

        # ---- linear skeleton ----
        if self.latent_mode == "continuous":
            if getattr(lp, "linear_param", "free") == "free":
                if hasattr(lp, "linear"):
                    lin_params += [lp.linear]
            elif lp.linear_param == "pH_dense":
                for n in ["ph_B", "ph_C", "ph_L", "ph_W", "ph_omega_raw"]:
                    if hasattr(lp, n):
                        lin_params += [getattr(lp, n)]
        elif self.latent_mode == "discrete":
            lin_params += [lp.A]

        # ---- memory correction----
        if self.latent_mode == "continuous":
            if getattr(self, "latent_type", None) in {"linear+memory", "gru_memory", "lstm_memory"}:
                for n in ["memory_encoder", "memory_decoder",
                          "gru_r", "gru_z", "gru_h",
                          "lstm_i", "lstm_f", "lstm_o", "lstm_cand"]:
                    if hasattr(lp, n):
                        mem_params += list(getattr(lp, n).parameters())
                for n in ["_raw_lambda", "_raw_tau_m_inv", "_raw_tau_c_inv", "_raw_tau_h_inv"]:
                    if hasattr(lp, n):
                        mem_params += [getattr(lp, n)]
                if hasattr(lp, "_raw_mem_scale"):
                    mem_params += [lp._raw_mem_scale]
        elif self.latent_mode == "discrete":
            for n in ["memory_encoder", "memory_decoder"]:     
                if hasattr(lp, n):
                    mem_params += list(getattr(lp, n).parameters())  
            if hasattr(lp, "memory"):                           
                mem_params += list(lp.memory.parameters())     
            if hasattr(lp, "_raw_gate"):                       
                mem_params += [lp._raw_gate]    
            for n in ["ode_func", "rnn_cell"]:  
                if hasattr(lp, n):
                    mem_params += list(getattr(lp, n).parameters()) 
            if hasattr(lp, "res_lstm"):
                mem_params += list(lp.res_lstm.parameters()) 

        groups = []
        if mem_params:
            groups.append({"name": "memory", "params": mem_params, "lr": lr_mem, "weight_decay": wd_mem})
        if lin_params:
            groups.append({"name": "linear", "params": lin_params, "lr": lr_lin, "weight_decay": wd_lin})
        return groups


    def init_optim_phase2(self, lr_mem=None, lr_lin=None, lr_dec=None):
        """Create the Phase-2 optimisers and learning-rate schedulers.

        Learning rates left as ``None`` are read from ``self.cfg``
        (``lr_dyn_mem`` falling back to ``self.lr``, ``lr_dyn_lin`` and
        ``lr_dec`` falling back to 0.0).

        Sets ``optim_dyn`` (Adam over the groups from ``_phase2_param_groups``
        with zero weight decay, or None when there are no groups),
        ``optim_dec`` (Adam over the decoder, or None when ``lr_dec`` is 0),
        ``scheduler_dyn`` / ``scheduler_dec`` (built for 'CosineAnnealingLR'
        or 'StepLR', otherwise None) and ``scheduler_enc`` (always None).
        """
        cfg = self.cfg
        lr_mem = getattr(cfg, "lr_dyn_mem", self.lr) if lr_mem is None else lr_mem
        lr_lin = getattr(cfg, "lr_dyn_lin", 0.0) if lr_lin is None else lr_lin
        lr_dec = getattr(cfg, "lr_dec", 0.0) if lr_dec is None else lr_dec

        def build_scheduler(optimizer):
            # Scheduler selected by cfg.scheduler; None for any other setting.
            kind = self.cfg.scheduler
            if kind == 'CosineAnnealingLR':
                return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg.epochs)
            if kind == 'StepLR':
                return torch.optim.lr_scheduler.StepLR(
                    optimizer, step_size=self.cfg.step_size, gamma=self.cfg.gamma
                )
            return None

        # ------- optimizers -------
        groups = self._phase2_param_groups(
            lr_mem=lr_mem, lr_lin=lr_lin,
            wd_mem=0.0, wd_lin=0.0,
        )
        self.optim_dyn = torch.optim.Adam(groups) if groups else None
        if lr_dec != 0:
            self.optim_dec = torch.optim.Adam([{'params': self.decoder.parameters(), 'lr': lr_dec}])
        else:
            self.optim_dec = None

        # ------- scheduler -------
        self.scheduler_dyn = None
        self.scheduler_enc, self.scheduler_dec = None, None
        if self.optim_dyn is not None:
            """if self.cfg.scheduler == 'OneCycleLR':
                steps_per_epoch = len(self.train_loader)
                max_lr = max(g.get("lr", 0.0) for g in groups)
                self.scheduler_dyn = torch.optim.lr_scheduler.OneCycleLR(
                    self.optim_dyn, max_lr=max_lr,
                    epochs=self.cfg.epochs, steps_per_epoch=steps_per_epoch,
                    pct_start=self.cfg.pct_start
                )"""
            self.scheduler_dyn = build_scheduler(self.optim_dyn)
        if lr_dec != 0:
            self.scheduler_dec = build_scheduler(self.optim_dec)


    @torch.no_grad()
    def _compute_memory_lengths(self, latent_states_: torch.Tensor, memory_states: torch.Tensor | None):
        """
        Return a small dict of memory time constants measured in *steps*.
        - latent_states_: [T, B, D]  (centered/whitened version used by latent_process)
        - memory_states : [T, B, Dm] or None
        """
        out = {}
        if not self.use_memory or memory_states is None:
            return out

        lp = self.latent_process
        dt = float(self.dt_eval)

        try:
            if lp.latent_type == "linear+memory":
                # tau = 1 / lambda; report stats in steps
                lam = lp.lambda_diag                      # [Dm]
                tau_steps = (1.0 / (lam + 1e-12)) / dt    # [Dm]
                v = tau_steps.detach().cpu().float()
                out.update({
                    "tau_steps_mean": float(v.mean()),
                    "tau_steps_med":  float(v.median()),
                    "tau_steps_p90":  float(v.quantile(0.9)),
                    "tau_steps_max":  float(v.max()),
                })

            elif lp.latent_type == "gru_memory":
                # Effective tau uses mean update gate z over time & batch
                Dm = lp.memory_dim
                # concat over all (t>0, b)
                concat = torch.cat([latent_states_[1:], memory_states[1:]], dim=-1)
                concat = concat.reshape(-1, lp.latent_dim + Dm)           # [(T-1)*B, D+D_m]
                z = torch.sigmoid(lp.gru_z(concat))                       # [(T-1)*B, Dm]
                z_mean = z.mean(dim=0)                                    # [Dm]
                tau_eff_steps = 1.0 / ((lp.tau_m_inv * z_mean) + 1e-12) / dt
                v = tau_eff_steps.detach().cpu().float()
                out.update({
                    "tau_eff_steps_mean": float(v.mean()),
                    "tau_eff_steps_p90":  float(v.quantile(0.9)),
                    "tau_raw_steps_mean": float((1.0 / (lp.tau_m_inv + 1e-12) / dt).mean().item()),
                })

            elif lp.latent_type == "lstm_memory":
                # Use forget gate f to estimate leakage; larger f -> longer memory.
                Dm = lp.memory_dim
                concat = torch.cat([latent_states_[1:], memory_states[1:]], dim=-1)
                concat = concat.reshape(-1, lp.latent_dim + Dm)           # [(T-1)*B, D+D_m]
                f = torch.sigmoid(lp.lstm_f(concat))                      # [(T-1)*B, Dm]
                leak = (1.0 - f.mean(dim=0)).clamp_min(1e-4)              # [Dm], avoid div-by-zero
                tau_eff_steps = 1.0 / ((lp.tau_c_inv * leak) + 1e-12) / dt
                v = tau_eff_steps.detach().cpu().float()
                out.update({
                    "tau_c_eff_steps_mean": float(v.mean()),
                    "tau_c_eff_steps_p90":  float(v.quantile(0.9)),
                    "tau_c_raw_steps_mean": float((1.0 / (lp.tau_c_inv + 1e-12) / dt).mean().item()),
                })
        except Exception:
            # Keep training robust even if shape/branch mismatch happens
            pass

        return out


    def train_phase2(self, phase1_path: str, log_every: int | None = None, eval_every: int | None = None, 
                     verbose: str | None = None):
        """Phase-2 training of the latent process on top of a frozen encoder.

        Setup:
            - The encoder is frozen, the latent process is trainable, and the
              decoder is trainable only when ``cfg.lr_dec != 0``.
            - When no projector is active, the Phase-I checkpoint is loaded
              from ``phase1_path``; if diagonal whitening is also enabled, the
              whitening scale is fitted on the train loader and the linear
              part is re-initialised in whitened coordinates.
            - Optimisers/schedulers come from ``init_optim_phase2``.

        Per training step:
            Frames are delay-stacked, sub-sampled at the masked points and
            encoded; the latents are centered, whitened and (optionally)
            projected; the latent process is rolled out with teacher forcing.
            The loss is ``dyn_loss + lambda_pred * masked prediction loss``,
            plus ``lambda_corr * corr`` when both are available, plus
            ``lambda_resid * residual_loss`` in discrete mode.

        Schedules:
            ``tf_epsilon`` is multiplied by ``self.epsilon`` every
            ``cfg.update_every`` steps and floored at ``cfg.tf_epsilon_min``.
            Every ``eval_every`` steps the model is evaluated on the
            train-eval and test loaders (when present) and the best
            checkpoints are saved as "model_tr_best.pth" / "model_ts_best.pth".
            The final model is saved as "final_model.pth".

        Args:
            phase1_path: Phase-I checkpoint path; only used when no projector
                is active.
            log_every: log training losses every this many steps; None
                disables step logging.
            eval_every: evaluate every this many steps; None skips building
                the test / train-eval loaders here and disables periodic
                evaluation.
            verbose: optional message logged once before training starts.
        """
        self.setup_logger()
        self.save_repro_artifacts()
        self.log_param_table()
        criterion = nn.MSELoss()

        # load model from phase1
        # load encoder-decoder, and linear parameter
        if not self.use_projector:
            phase1_info = self.load_phase1_ckpt(path=phase1_path, clip_positive_symmetric=False)
        # train_encdec = (self.cfg.lr_encdec != 0)
        train_dec = (self.cfg.lr_dec != 0)
        self.logger.info(f"finetune decoder = {train_dec}")
        # set_requires_grad(self.encoder, True) if train_encdec else set_requires_grad(self.encoder, False)    ##################
        set_requires_grad(self.encoder, False)
        set_requires_grad(self.decoder, True) if train_dec else set_requires_grad(self.decoder, False)    ##################
        set_requires_grad(self.latent_process, True)
        # self.latent_process.init_memory_time_constant(tau_in_steps=8, dt=self.dt_eval)    ##########################################

        # ---- estimate whitening scales (Phase-I encoder) & re-init Aw ----
        # Skipped when a projector is active.
        if self.use_diag_whiten and not self.use_projector:
            # prefer train_eval; fallback to train
            """try:
                if getattr(self, "train_eval_loader", None) is None:
                    self.build_dataloader(group="train_eval")
                self.fit_diag_whitening_from_phase1(group="train_eval", max_batches=16)
            except Exception:
                self.fit_diag_whitening_from_phase1(group="train", max_batches=16)"""
            self.fit_diag_whitening_from_phase1(group="train", max_batches=16)
            self._reinit_linear_in_whiten_space()

        self.logger.info(f"whiten vector: {self.whiten_scale}")
        
        assert self.data_processor.mode == "interpolation", f"Mismatched dataloaders"
        self.build_dataloader(group="train")
        self.init_optim_phase2()
        # Evaluation loaders are only built here when periodic evaluation is requested.
        if eval_every is not None:
            self.build_dataloader(group="test")
            self.build_dataloader(group="train_eval")
        self._save_split_and_samples()

        if verbose is not None:
            self.logger.info(f"{verbose}")

        ######## Initial Evaluation ########
        if self.train_eval_loader is not None:
            self.logger.info("--------Begin Evaluation on Train--------")
            train_eval_errs = self.evaluate(dataloader=self.train_eval_loader)
            self.logger.info("Evaluation on train:\n%s", pformat(train_eval_errs, width=100, compact=False))
        # out-of-domain evaluation
        if self.test_loader is not None:
            self.logger.info("--------Begin Evaluation on Test--------")
            test_errs = self.evaluate(dataloader=self.test_loader)
            self.logger.info("Evaluation on test:\n%s", pformat(test_errs, width=100, compact=False))

        # Teacher-forcing epsilon starts at the configured value and decays during training.
        tf_epsilon = self.tf_epsilon
        loss_tr_min, loss_ts_min = float('inf'), float('inf')
        for epoch in range(1, self.cfg.epochs + 1):
            self.latent_process.train()
            for i, batch in enumerate(self.train_loader):
                ground_truth = batch["data"].to(self.device)    # [B, T, H, W, n_ch]
                sample_idx = batch["index"].to(self.device)     # [B,]
                masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
                # The time grid of the first sample is used for the whole batch.
                t_eval = batch['t'][0].to(self.device)          # [T]
                bs, train_len, H, W, _ = ground_truth.shape
                assert train_len == self.n_frames_train

                # ---- build encoder inputs: delay-stacked frames at the masked points ----
                mask_index = self.mask_to_bs_index(mask=masks)    # [B, S]
                mask_index = mask_index.unsqueeze(1).expand(-1, train_len-self.n_frames_cond+1, -1).flatten(0, 1)  # [B*t, S]
                delay_data = delay_stack_last_channel(x=ground_truth, d=self.n_frames_cond)    # [B, T-nf_cond+1, ..., n_ch*nf_cond]
                data_in = delay_data.flatten(0, 1).flatten(1, 2)    # [B*t, H*W, n_ch*nf_cond]
                data_in = self.index_points(data_in, mask_index)    # [B*t, S, n_ch*nf_cond]
                pos_in = self.pos_feat.flatten(0, 1).unsqueeze(0).expand(data_in.shape[0], -1, -1).to(self.device)
                pos_in = self.index_points(pos_in, mask_index)    # [B*t, S, 2]
                data_in = torch.cat((data_in, pos_in), dim=-1) 

                # latent_states: encoded from encoder / dyn_states: predicted using latent ode
                latent_token = self.encoder(data_in, pos_in)    # [B*t, K, latent_token]
                # reshape to [B, t, latent_dim]
                latent_states = latent_token.reshape((bs, train_len-self.n_frames_cond+1, -1))    # [B, t, latent_dim]
                latent_states = latent_states.permute(1, 0, 2)    # [t, B, latent_dim]

                latent_states_ = self._center_latent(latent_states)
                latent_states_ = self._whiten_latent(latent_states_)

                # ---- teacher-forced rollout of the latent process ----
                if self.use_projector:
                    # The dynamics loss is computed in the reduced space; the rollout
                    # is lifted back afterwards for decoding.
                    latent_states_ = self._project_latent(latent_states_)
                    dyn_y, memory_states, aux = self.latent_process(
                        alpha_0=latent_states_[0], t_eval=t_eval[self.n_frames_cond-1:],
                        memory_init=None, teacher_forcing=True,
                        tf_alpha=latent_states_, tf_epsilon=tf_epsilon, tf_mask=None
                    )                                                          # [T',B,d]
                    dyn_loss = criterion(dyn_y, latent_states_.detach())
                    dyn_states_ = self._lift_latent(dyn_y)                     # lift back to [T',B,D]
                else:
                    dyn_states_, memory_states, aux = self.latent_process(
                        alpha_0=latent_states_[0], t_eval=t_eval[self.n_frames_cond-1:],
                        memory_init=None,
                        teacher_forcing=True, tf_alpha=latent_states_, tf_epsilon=tf_epsilon, tf_mask=None
                    )    # [t, B, latent_dim]
                    dyn_loss = criterion(dyn_states_, latent_states_.detach())
                    """long_term_loss, _ = self.kstep_rollout_loss(alpha_gt=latent_states_, t_eval=t_eval[self.n_frames_cond-1:], K=6, 
                                                            train_latent=False)"""
                corr = aux["phi_dec_l2"] if self.use_memory else None
                # ---- discrete mode: supervise the residual left by the linear step A ----
                if self.latent_mode == "discrete" and self.latent_process.memory_type != "residual":
                    T_eff, B, D = latent_states_.shape
                    A = self.latent_process.A
                    # Target: what the one-step linear map fails to explain.
                    res_gt = latent_states_[1:] - (latent_states_[:-1] @ A.T)
                    mem_flatten = memory_states[:-1].reshape(-1, self.latent_process.memory_dim)
                    res_flatten = self.latent_process.memory_decoder(mem_flatten)
                    res_pred = (self.latent_process.gate * res_flatten).view(T_eff-1, B, D)
                    residual_loss = F.mse_loss(res_pred, res_gt)
                elif self.latent_mode == "discrete" and self.latent_process.memory_type == "residual":
                    A = self.latent_process.A
                    res_gt = latent_states_[1:] - (latent_states_[:-1] @ A.T)
                    # Here the memory states themselves are compared with the residual.
                    res_pred = memory_states[:-1]
                    residual_loss = F.mse_loss(res_pred, res_gt)

                # mem_info = self._compute_memory_lengths(latent_states_, memory_states)
                # ---- map the rollout back to the encoder's latent space and decode ----
                dyn_states_ = self._unwhiten_latent(dyn_states_)   
                dyn_states = self._decenter_latent(dyn_states_)

                pred_field = self._decode_latent(latent_seqs=dyn_states)    # [B, t, H, W, C]
                pred_loss, pred_loss_wzmask = self._mse_loss(
                    pred_field, ground_truth[:, self.n_frames_cond-1:, ...], masks[:, self.n_frames_cond-1:, ...]
                )

                # Only the second value returned by _mse_loss enters the objective;
                # the first is logged below.
                loss = dyn_loss + self.lambda_pred * pred_loss_wzmask
            
                if self.lambda_corr is not None and corr is not None:
                    loss += self.lambda_corr * corr
                """if self.cfg.lambda_lt_pred is not None:
                    loss += self.cfg.lambda_lt_pred * long_term_loss"""
                if self.latent_mode == "discrete":
                    loss += self.lambda_resid * residual_loss

                # ---- optimisation step ----
                self.optim_dyn.zero_grad()
                if train_dec:
                    # self.optim_enc.zero_grad()
                    self.optim_dec.zero_grad()
                loss.backward()
                # Gradients are clipped separately for the latent process and the decoder.
                if self.cfg.max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.latent_process.parameters(), self.cfg.max_grad_norm)    #######
                    if train_dec:
                        # torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), self.cfg.max_grad_norm)
                        torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), self.cfg.max_grad_norm)
                self.optim_dyn.step()
                if train_dec:
                    # self.optim_enc.step()
                    self.optim_dec.step()
                """if self.cfg.scheduler == "OneCycleLR":
                    self.scheduler_dyn.step()"""
                # NOTE: epoch is 1-based, so the step counters below start at
                # len(self.train_loader) rather than 0.
                if log_every is not None and (epoch * len(self.train_loader) + i) % log_every == 0:
                    msg = [
                        f"Epoch {epoch:04d}/{self.cfg.epochs} | iteration {i+1:03d}",
                        f"| pred {pred_loss.item():.8f}",
                        f"| pred(mask) {pred_loss_wzmask.item():.8f}",
                        f"| dyn {dyn_loss.item():.8f}",
                        # f"| dyn (long term) {long_term_loss.item():.8f}"
                    ]
                    if self.latent_mode == "discrete":
                        msg += [f"| residual {residual_loss.item():.8f}"]
                        msg += [f"| correction ratio mean {aux['mem_ratio_mean']:.8f}"]
                        msg += [f"| correction ratio p90 {aux['mem_ratio_p90']:.8f}"]
                    if self.use_memory and (corr is not None):
                        msg += [f"| memory_corr {corr.item():.8f}"]
                        # msg += [f"| memory (relative) {L_mem_rel.item():.8f}"]
                        if self.latent_mode == "continuous":
                            msg += [F"| memory scale {float(self.latent_process.mem_scale.mean().item()):.4f}"]

                    msg += [f"| epsilon {tf_epsilon:.4f}"]
                    self.logger.info(" ".join(msg))
                # Geometric decay of the teacher-forcing epsilon, floored at cfg.tf_epsilon_min.
                if (epoch * len(self.train_loader) + i + 1) % self.cfg.update_every == 0:
                    tf_epsilon = max(tf_epsilon * self.epsilon, self.cfg.tf_epsilon_min)

                # ---- periodic evaluation and best-checkpoint tracking ----
                if eval_every is not None and (epoch * len(self.train_loader) + i + 1) % eval_every == 0:
                    if self.train_eval_loader is not None:
                        self.logger.info("--------Begin Evaluation on Train--------")
                        train_eval_errs = self.evaluate(dataloader=self.train_eval_loader)
                        self.logger.info("Evaluation on train:\n%s", pformat(train_eval_errs, width=100, compact=False))
                        losses_tr = train_eval_errs["mse_losses"]
                        if loss_tr_min > losses_tr["loss"]:
                            loss_tr_min = losses_tr["loss"]
                            self.save_model(epoch, train_eval_errs, pth_name="model_tr_best.pth")
                    # out-of-domain evaluation
                    if self.test_loader is not None:
                        self.logger.info("--------Begin Evaluation on Test--------")
                        test_errs = self.evaluate(dataloader=self.test_loader)
                        self.logger.info("Evaluation on test:\n%s", pformat(test_errs, width=100, compact=False))
                        losses_ts = test_errs["mse_losses"]
                        if loss_ts_min > losses_ts["loss"]:
                            loss_ts_min = losses_ts["loss"]
                            self.save_model(epoch, test_errs, pth_name="model_ts_best.pth")
            # Epoch-level schedulers are stepped once per epoch.
            if self.cfg.scheduler == 'CosineAnnealingLR' or self.cfg.scheduler == 'StepLR':
                self.scheduler_dyn.step()
                if train_dec:
                    # self.scheduler_enc.step()
                    self.scheduler_dec.step()
        self.logger.info("Training Finished! Saving model...")
        self.save_model(epoch=None, losses=None, pth_name="final_model.pth")


    @torch.no_grad()
    def evaluate(self, dataloader, model_pth: str | None = None):
        if model_pth is not None:
            _ = self.load_from_ckpt(ckpt_path=model_pth, device=self.device)
        
        err_dict = {}
        rel_err, mse_err = 0.0, 0.0
        rel_err_in_t, rel_err_out_t, mse_err_in_t, mse_err_out_t = 0.0, 0.0, 0.0, 0.0
        rmse_err, rmse_err_in_t, rmse_err_out_t = 0.0, 0.0, 0.0  
        rel_criterion = LpLoss(size_average=False)
        loss, loss_out_t, loss_in_t = 0.0, 0.0, 0.0
        num_samples = int(0)

        for batch in dataloader:
            ground_truth = batch["data"].to(self.device)    # [B, T, H, W, n_ch] (T = n_frames_train + n_frames_out)
            t_eval = batch['t'][0][self.n_frames_cond-1:].to(self.device)
            masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
            bs, length, H, W, _ = ground_truth.shape
            num_samples += bs
            
            with torch.no_grad():
                mask_index = self.mask_to_bs_index(mask=masks)    # [B, S]
                # mask_index = mask_index.unsqueeze(1).expand(-1, length-self.n_frames_cond+1, -1).flatten(0, 1)  # [B*t, S]
                delay_data = delay_stack_last_channel(x=ground_truth[:, :self.n_frames_cond, ...], d=self.n_frames_cond)    # [B, 1, ..., n_ch*nf_cond]
                data_in = delay_data.flatten(0, 1).flatten(1, 2)    # [B*1, H*W, n_ch*nf_cond]
                data_in = self.index_points(data_in, mask_index)    # [B*1, S, n_ch*nf_cond]
                pos_in = self.pos_feat.flatten(0, 1).unsqueeze(0).expand(data_in.shape[0], -1, -1).to(self.device)
                pos_in = self.index_points(pos_in, mask_index)    # [B*1, S, 2]
                data_in = torch.cat((data_in, pos_in), dim=-1) 

                # latent_states: encoded from encoder / dyn_states: predicted using latent ode
                latent_token = self.encoder(data_in, pos_in)    # [B*1, K, latent_token]
                # reshape to [B, latent_dim]
                latent_state = latent_token.reshape((bs, -1))    # [B, latent_dim]
                latent_state_ = self._center_latent(latent_state)
                latent_state_ = self._whiten_latent(latent_state_)
                if self.use_projector:
                    latent_state_ = self._project_latent(latent_state_)
                dyn_states_, _, _ = self.latent_process(alpha_0=latent_state_, t_eval=t_eval, teacher_forcing=False)    # [T, B, latent_dim]
                if self.use_projector:
                    dyn_states_ = self._lift_latent(dyn_states_)
                dyn_states_ = self._unwhiten_latent(dyn_states_)
                dyn_states = self._decenter_latent(dyn_states_)
                recon_seq = self._decode_latent(dyn_states)

                # compute losses
                # recon_seq, ground_truth: [B, T, H, W, s]
                n_cond = self.n_frames_cond - 1
                ground_truth_ = ground_truth[:, n_cond:, ...]
                masks_ = masks[:, n_cond:, ...]
                pred_in_t_, pred_out_t_ = recon_seq[:, :self.n_frames_train-n_cond, ...], recon_seq[:, self.n_frames_train-n_cond:, ...]
                gt_in_t_, gt_out_t_ = ground_truth_[:, :self.n_frames_train-n_cond, ...], ground_truth_[:, self.n_frames_train-n_cond:, ...]

                rel_err += rel_criterion(recon_seq.reshape(bs, -1), ground_truth_.reshape(bs, -1)).item()
                rel_err_in_t += rel_criterion(pred_in_t_.reshape(bs, -1), gt_in_t_.reshape(bs, -1)).item()
                rel_err_out_t += rel_criterion(pred_out_t_.reshape(bs, -1), gt_out_t_.reshape(bs, -1)).item()
                mse_err += self._compute_loss(recon_seq, ground_truth_) * bs
                mse_err_in_t += self._compute_loss(pred_in_t_, gt_in_t_) * bs
                mse_err_out_t += self._compute_loss(pred_out_t_, gt_out_t_) * bs
                rmse_err += torch.sqrt(self._compute_loss(recon_seq, ground_truth_)) * bs
                rmse_err_in_t += torch.sqrt(self._compute_loss(pred_in_t_, gt_in_t_)) * bs
                rmse_err_out_t += torch.sqrt(self._compute_loss(pred_out_t_, gt_out_t_)) * bs

        rel_err = rel_err / num_samples
        mse_err = mse_err / num_samples
        rmse_err = rmse_err / num_samples
        rel_err_in_t, rel_err_out_t = rel_err_in_t / num_samples, rel_err_out_t / num_samples
        mse_err_in_t, mse_err_out_t = mse_err_in_t / num_samples, mse_err_out_t / num_samples
        rmse_err_in_t, rmse_err_out_t = rmse_err_in_t / num_samples, rmse_err_out_t / num_samples

        rel_losses = {
            "loss": rel_err, "loss_in_t": rel_err_in_t, "loss_out_t": rel_err_out_t
        }
        mse_losses = {
            "loss": mse_err, "loss_in_t": mse_err_in_t, "loss_out_t": mse_err_out_t
        }
        rmse_losses = {
            "loss": rmse_err, "loss_in_t": rmse_err_in_t, "loss_out_t": rmse_err_out_t
        }
        err_dict.update({"rel_losses": rel_losses})
        err_dict.update({"mse_losses": mse_losses})
        err_dict.update({"rmse_losses": rmse_losses})
        return err_dict

            
    def _compute_loss(self, data1: torch.Tensor, data2: torch.Tensor, mask: torch.Tensor | None = None):
        # datas: [B, T, H, W, s], mask: [B, T, H, W, s] or None
        # check!!!!!!!!!!
        with torch.no_grad():
            if mask is None:
                criterion = nn.MSELoss(reduction="none")
                loss = criterion(data1, data2)
                mse = loss.mean()
            else:
                sqerr = (data1 - data2).pow(2) * mask
                sqerr_sum = sqerr.sum(dim=(2, 3))
                denom = mask.sum(dim=(2, 3)).clamp_(min=1e-6)
                mse = (sqerr_sum / denom).mean()
        return mse

    
    def save_model(self, epoch: int | None = None, losses: dict | None = None, pth_name: str = "model_tr.pth"):
        save_dict = {
            "args": vars(self.args) if self.args is not None else None,
            "epoch": epoch,
            "encoder": self.encoder.state_dict(),
            "latent_process": self.latent_process.state_dict(),
            "decoder": self.decoder.state_dict(),
            "losses": dict(losses) if isinstance(losses, dict) else losses
        }
        lc = getattr(self, "latent_center", None)
        if lc is not None:
            save_dict["latent_center"] = lc.detach().cpu().view(-1)
        ws = getattr(self, "whiten_scale", None)
        if ws is not None:
            save_dict["whiten_scale"] = ws.detach().cpu().view(-1)
        U = getattr(self, "U_proj", None)
        if U is not None:
            save_dict["U_proj"] = U.detach().cpu()
        out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}")
        torch.save(save_dict, os.path.join(out_dir, f'{pth_name}'))


    def load_from_ckpt(self, ckpt_path: str, device: str | None = None):
        try:
            ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        except TypeError:
            ckpt = torch.load(ckpt_path, map_location=device)

        device = device if device is not None else self.device

        self.encoder.load_state_dict(ckpt["encoder"])
        self.latent_process.load_state_dict(ckpt["latent_process"])
        self.decoder.load_state_dict(ckpt["decoder"])
        
        self.encoder.to(self.device)
        self.latent_process.to(device)
        self.decoder.to(self.device)

        lc = ckpt.get("latent_center", None)
        if lc is not None:
            # match dtype of latent_process params if possible
            p0 = next(self.latent_process.parameters(), None)
            dtype = p0.dtype if p0 is not None else torch.float32
            lc = lc.to(device=self.device, dtype=dtype).view(-1)
            if lc.numel() == self.latent_dim:
                self.latent_center = lc
            else:
                self.latent_center = None
                if hasattr(self, "logger") and self.logger is not None:
                    self.logger.warning(
                        f"[load_from_ckpt] latent_center dim {lc.numel()} != latent_dim {self.latent_dim}; dropped."
                    )
        else:
            self.latent_center = None

        ws = ckpt.get("whiten_scale", None) 
        if ws is not None:
            p0 = next(self.latent_process.parameters(), None)
            dtype = p0.dtype if p0 is not None else torch.float32
            ws = ws.to(device=self.device, dtype=dtype).view(-1)
            if ws.numel() == self.latent_dim and (ws > 0).all():
                self.set_whitening_scale(ws)
            else:
                if hasattr(self, "logger") and self.logger is not None:
                    self.logger.warning("whiten_scale invalid; ignored.")
        else:
            self.set_whitening_scale(None)

        U = ckpt.get("U_proj", None)
        if U is not None:
            try:
                self.set_projector(U)  # turn on use_projector
            except Exception as e:
                if hasattr(self, "logger") and self.logger is not None:
                    self.logger.warning(f"U_proj load failed; ignoring. err={e}")
        
        return {"epoch": ckpt.get("epoch", -1), "losses": ckpt.get("losses", None), "args": ckpt.get("args", None)}


    def _ensure_loader(self, group: str) -> None:
        """
        Lazily build dataloader for the given group if not yet built.
        group ∈ {"train", "train_eval", "test"}.
        """
        if group == "train" and self.train_loader is None:
            self.build_dataloader(group="train")
        elif group == "train_eval" and self.train_eval_loader is None:
            self.build_dataloader(group="train_eval")
        elif group == "test" and self.test_loader is None:
            self.build_dataloader(group="test")

    # ---------------------------------------------
    # Utilities for diagnostics & regularizers
    # ---------------------------------------------

    @staticmethod
    @torch.no_grad()
    def _spectral_radius(A: torch.Tensor) -> float:
        """Return spectral radius max|lambda_i(A)| for a real square matrix."""
        eig = torch.linalg.eigvals(A)  # complex
        return float(eig.abs().max().item())


    @staticmethod
    @torch.no_grad()
    def _fit_affine_Qc(X: torch.Tensor, Y: torch.Tensor, ridge: float = 1e-6) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Fit an affine map Y ≈ X @ Q.T + c using ridge least squares (row-vector convention).
        X, Y: [N, D]
        Returns:
            Q: [D, D], c: [D]
        """
        assert X.shape == Y.shape
        N, D = X.shape
        ones = torch.ones(N, 1, device=X.device, dtype=X.dtype)
        X1 = torch.cat([X, ones], dim=1)                 # [N, D+1]
        R = torch.zeros(D + 1, D + 1, device=X.device, dtype=X.dtype)
        R[:D, :D] = ridge * torch.eye(D, device=X.device, dtype=X.dtype)
        Theta = torch.linalg.solve(X1.T @ X1 + R, X1.T @ Y)  # [D+1, D]
        Q = Theta[:D, :].T.contiguous()
        c = Theta[D, :].contiguous()
        return Q, c


    @torch.no_grad()
    def _procrustes_error_affine(self, Z_old: torch.Tensor, Z_new: torch.Tensor, ridge: float = 1e-6) -> float:
        """
        Relative Procrustes error with affine map:
        err = ||Z_new - (Z_old @ Q.T + c)|| / ||Z_new||
        where (Q,c) is fitted by ridge regression.
        Z_old, Z_new: [N, D]
        """
        N = min(Z_old.size(0), Z_new.size(0))
        Z_old = Z_old[:N]
        Z_new = Z_new[:N]
        Q, c = self._fit_affine_Qc(Z_old, Z_new, ridge=ridge)
        pred = Z_old @ Q.T + c
        err = torch.norm(Z_new - pred) / (torch.norm(Z_new) + 1e-12)
        return float(err.item())


    
    
    def _multistep_latent_consistency(self,
                                      latent_states: torch.Tensor,  # [T, B, D]
                                      A: torch.Tensor,              # [D, D]
                                      b: torch.Tensor | None = None,
                                      H: int = 4,
                                      gamma: float = 1.0):
        T, B, D = latent_states.shape
        if H <= 0 or T < 2:
            return latent_states.new_tensor(0.0)
        H = min(H, T - 1)

        A_T = A.transpose(-1, -2)
        w_sum = 0.0
        loss_acc = latent_states.new_tensor(0.0)

        z_pred = latent_states[:-1]                 # [T-1, B, D]

        for h in range(1, H + 1):
            z_pred = z_pred @ A_T + (b if b is not None else 0.0)
            T_h = T - h
            z_use = z_pred[:T_h]                    # [T-h, B, D]
            target = latent_states[h:]              # [T-h, B, D]
            mse_h = (z_use - target).pow(2).mean()
            w = (gamma ** h)
            loss_acc = loss_acc + w * mse_h
            w_sum += w
            if h < H:
                z_pred = z_pred[:-1]                
        return loss_acc / max(w_sum, 1e-12)


    