import os, time
import torch
from pprint import pformat
from torch.utils.data import Subset, DataLoader

from data.data_process import PDEDataProcessor
from baselines.model_factory import get_model
from exp.exp_basic import Exp_Basic, ExpConfigs
from utilities.losses import LpLoss
from utilities.vis import visualize_pred_vs_gt


class Exp_Dynamic_Autoregressive(Exp_Basic):
    def __init__(self, args, exp_cfg: ExpConfigs, model_cfg, data_processor: PDEDataProcessor):
        """Initialise the base experiment, then build the model and move it to ``self.device``.

        The model is created by ``get_model`` from ``exp_cfg.model_name`` and
        ``model_cfg``. Everything else is delegated to ``Exp_Basic.__init__``.
        """
        super().__init__(args, exp_cfg, model_cfg, data_processor)
        # NOTE(review): the original comment says dataloaders are initialised in
        # Exp_Basic; that class is not visible here, so confirm before relying on it.
        # Model initializer.
        network = get_model(model_name=exp_cfg.model_name, model_cfg=model_cfg)
        self.model = network
        self.model.to(self.device)
    

    def build_dataloader(self, group: str):
        sample_map = {
            "train": self.train_sample_idx,
            "train_eval": self.train_eval_sample_idx,
            "test": self.test_sample_idx,
        }
        dataloader = self.data_processor.get_dataloader(group=group, samples=sample_map[group])
        if group == "train":
            self.train_loader = dataloader
        elif group == "train_eval":
            self.train_eval_loader = dataloader
        elif group == "test":
            self.test_loader = dataloader

    
    def count_parameters(self):
        total_params = 0
        for name, parameter in self.model.named_parameters():
            if not parameter.requires_grad: continue
            params = parameter.numel()
            total_params += params
        print(f"Total Trainable Params: {total_params}")
        if hasattr(self, "logger") and self.logger is not None:
            self.logger.info(f"Total Trainable Params: {total_params}")
        return total_params

        
    def init_optim(self):
        if self.cfg.optimizer == 'AdamW':
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay)
        elif self.cfg.optimizer == 'Adam':
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay)
        else: 
            raise ValueError('Optimizer only AdamW or Adam')
        if self.cfg.scheduler == 'OneCycleLR':
            self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=self.cfg.lr, epochs=self.cfg.epochs,
                                                            steps_per_epoch=len(self.train_loader),
                                                            pct_start=self.cfg.pct_start)
        elif self.cfg.scheduler == 'CosineAnnealingLR':
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.cfg.epochs)
        elif self.cfg.scheduler == 'StepLR':
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.cfg.step_size, gamma=self.cfg.gamma)
        

    def train(self, log_every: int | None = None, eval_every: int | None = None):
        """
        Train the model autoregressively, optimising the summed per-step relative loss.

        Each batch provides ``batch['data'] = (data_xx, data_y)``:
            - data_xx: [B, ..., T_in*C], T_in: conditional frames used (n_cond + 1)
            - data_y:  [B, ..., T_out, C]

        For every output step the model predicts one frame from the current input
        window; the window then drops its oldest ``C`` channels and appends either
        the ground-truth frame (``cfg.teacher_forcing``) or the prediction.

        Args:
            log_every: log the epoch training losses every ``log_every`` epochs
                (None disables this logging).
            eval_every: evaluate on the train_eval and test loaders every
                ``eval_every`` epochs and save "model_tr_best.pth" whenever the
                train_eval relative error improves (None disables evaluation).

        Raises:
            RuntimeError: if the training dataloader yields no batches.
        """

        self.setup_logger()
        self.save_repro_artifacts()
        self.count_parameters()
        # NOTE(review): LpLoss(size_average=False) is assumed to return the sum over
        # the batch, as the original inline comments state; confirm in utilities.losses.
        rel_criterion = LpLoss(size_average=False)
        loss_tr_min = float('inf')
        self.build_dataloader(group="train")
        self.init_optim()
        if eval_every is not None:
            self.build_dataloader(group="test")
            self.build_dataloader(group="train_eval")
        self._save_split_and_samples()

        # The positional features do not change between batches: flatten and move
        # them to the device once.  [(*spatial_dims), d] -> [N_pt, d]
        pos_feat_flat = self.pos_feat.reshape(-1, self.spatial_dim).to(self.device)

        for epoch in range(1, self.cfg.epochs + 1):
            self.model.train()
            rel_l2_step, rel_l2_full = 0.0, 0.0

            num_samples = 0
            for batch in self.train_loader:
                loss = 0.0
                data_xx, data_y = batch['data']    # data_xx: [B, ..., T_in*C], data_y: [B, ..., T_out, C]
                data_xx, data_y = data_xx.to(self.device), data_y.to(self.device)
                bs, num_channels, t_out = data_y.shape[0], data_y.shape[-1], data_y.shape[-2]
                pos_feat = pos_feat_flat.expand(bs, -1, -1)    # [B, N_pt, d]

                pred = []
                for t_indice in range(t_out):
                    next_y = self._from_grid(data_y[..., t_indice, :])    # [B, ..., C] -> [B, N_pt, C]
                    im = self.model(fx=self._from_grid(data_xx), x=pos_feat)    # ([B, N_pt, T_in*C], [B, N_pt, d]) -> [B, N_pt, C]
                    loss += rel_criterion(x=im.reshape(bs, -1), y=next_y.reshape(bs, -1))    # sum over batch and times
                    im_grid = self._to_grid(im)
                    pred.append(im_grid)

                    # Slide the input window: drop the oldest frame and append the next one.
                    next_frame = self._to_grid(next_y) if self.cfg.teacher_forcing else im_grid
                    data_xx = torch.cat((data_xx[..., num_channels:], next_frame), dim=-1)

                num_samples += bs
                rel_l2_step += loss.item()    # sum over samples and times
                pred_tensor = torch.stack(pred, dim=-2)    # stack [B, ..., C] to form [B, ..., T_out, C]
                rel_l2_full += rel_criterion(pred_tensor.reshape(bs, -1), data_y.reshape(bs, -1)).item()

                loss = loss / bs
                self.optimizer.zero_grad()
                loss.backward()

                if self.cfg.max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm)
                self.optimizer.step()

                # OneCycleLR is stepped once per optimizer step.
                if self.cfg.scheduler == "OneCycleLR":
                    self.scheduler.step()

            if num_samples == 0:
                # Without this guard the averages below fail with an opaque NameError.
                raise RuntimeError("Training dataloader produced no batches; cannot compute epoch losses.")

            # Epoch-based schedulers are stepped once per epoch.
            if self.cfg.scheduler == 'CosineAnnealingLR' or self.cfg.scheduler == 'StepLR':
                self.scheduler.step()

            train_loss_step, train_loss_full = rel_l2_step / (num_samples * t_out), rel_l2_full / num_samples
            if log_every is not None and epoch % log_every == 0:
                self.logger.info(f"Epoch {epoch:04d}/{self.cfg.epochs} |"
                    f"| train loss step {train_loss_step:.8f}, train loss full {train_loss_full:.8f}")
            if eval_every is not None and epoch % eval_every == 0:
                train_eval_errs = self.evaluate(self.train_eval_loader)
                test_errs = self.evaluate(self.test_loader)
                train_eval_rel_err = train_eval_errs["rel_losses"]["loss"]

                self.logger.info("Evaluation on train:\n%s", pformat(train_eval_errs, width=100, compact=False))
                self.logger.info("Evaluation on test:\n%s", pformat(test_errs, width=100, compact=False))

                # The best checkpoint is selected on the train_eval relative error.
                if train_eval_rel_err < loss_tr_min:
                    loss_tr_min = train_eval_rel_err
                    self.save_model(epoch, losses=None, pth_name="model_tr_best.pth")

        self.logger.info("Training Finished! Saving model...")
        self.save_model(epoch=None, losses=None, pth_name="final_model.pth")

    
    def evaluate(self, dataloader, model_pth: str | None = None):
        """
        Roll the model out autoregressively over ``dataloader`` and report errors.

        The rollout always feeds predictions back in (no teacher forcing). Each
        metric is reported for the full horizon ("loss"), for the first
        ``n_frames_train - n_frames_cond`` predicted frames ("loss_in_t") and for
        the remaining frames ("loss_out_t").

        Args:
            dataloader: iterable of batches with ``batch['data'] = (data_xx, data_y)``,
                data_xx: [B, ..., T_in*C], data_y: [B, ..., T_out, C].
            model_pth: optional checkpoint path loaded into ``self.model`` first.

        Returns:
            ``{"rel_losses": {...}, "mse_losses": {...}, "rmse_losses": {...}}``,
            each inner dict holding "loss", "loss_in_t" and "loss_out_t" averaged
            over samples. MSE/RMSE are per-batch values weighted by batch size.

        Raises:
            ValueError: if ``dataloader`` yields no samples.
        """
        if model_pth is not None:
            # load_from_ckpt always loads into self.model and accepts no `model`
            # keyword; passing one raised TypeError.
            self.load_from_ckpt(model_pth, device=self.device)
        self.model.eval()
        # NOTE(review): LpLoss(size_average=False) is assumed to sum over the batch,
        # which is why its totals are divided by the sample count below; confirm.
        rel_criterion = LpLoss(size_average=False)
        mse_criterion = torch.nn.MSELoss()

        # Index along the output-time axis that separates the two reported segments.
        in_t_id_end = self.n_frames_train - self.n_frames_cond

        metric_names = ("rel_losses", "mse_losses", "rmse_losses")
        segment_names = ("loss", "loss_in_t", "loss_out_t")
        totals = {metric: dict.fromkeys(segment_names, 0.0) for metric in metric_names}

        num_samples = 0
        with torch.no_grad():
            # Loop-invariant: [(*spatial_dims), d] -> [N_pt, d], moved to the device once.
            pos_feat_flat = self.pos_feat.reshape(-1, self.spatial_dim).to(self.device)
            for batch in dataloader:
                data_xx, data_y = batch['data']    # data_xx: [B, ..., T_in*C], data_y: [B, ..., T_out, C]
                data_xx, data_y = data_xx.to(self.device), data_y.to(self.device)
                bs, num_channels, t_out = data_y.shape[0], data_y.shape[-1], data_y.shape[-2]
                num_samples += bs
                pos_feat = pos_feat_flat.expand(bs, -1, -1)    # [B, N_pt, d]

                pred = []
                for _ in range(t_out):
                    im = self.model(fx=self._from_grid(data_xx), x=pos_feat)    # -> [B, N_pt, C]
                    im_grid = self._to_grid(im)
                    pred.append(im_grid)
                    # Slide the window: drop the oldest frame, append the prediction.
                    data_xx = torch.cat((data_xx[..., num_channels:], im_grid), dim=-1)

                pred_tensor = torch.stack(pred, dim=-2)    # [B, ..., T_out, C]
                segments = {
                    "loss": (pred_tensor, data_y),
                    "loss_in_t": (pred_tensor[..., :in_t_id_end, :], data_y[..., :in_t_id_end, :]),
                    "loss_out_t": (pred_tensor[..., in_t_id_end:, :], data_y[..., in_t_id_end:, :]),
                }
                for segment, (seg_pred, seg_gt) in segments.items():
                    flat_pred, flat_gt = seg_pred.reshape(bs, -1), seg_gt.reshape(bs, -1)
                    batch_mse = mse_criterion(flat_pred, flat_gt)    # computed once, reused for RMSE
                    totals["rel_losses"][segment] += rel_criterion(flat_pred, flat_gt).item()
                    totals["mse_losses"][segment] += batch_mse.item() * bs
                    totals["rmse_losses"][segment] += torch.sqrt(batch_mse).item() * bs

        if num_samples == 0:
            raise ValueError("evaluate() received a dataloader that yielded no samples.")

        return {
            metric: {segment: value / num_samples for segment, value in totals[metric].items()}
            for metric in metric_names
        }

  
    def save_model(self, epoch: int | None = None, losses: dict | None = None, pth_name: str = "model_tr.pth"):
        save_dict = {
            "args": vars(self.args) if self.args is not None else None,
            "epoch": epoch,
            "model": self.model.state_dict(),
            "losses": dict(losses) if isinstance(losses, dict) else losses
        }
        # out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}_{self.dataset}")
        # torch.save(save_dict, os.path.join(out_dir, f'{self.dataset}_{pth_name}'))
        out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}")
        torch.save(save_dict, os.path.join(out_dir, f'{pth_name}'))


    """def load_from_ckpt(self, ckpt_path: str, device: str | None = None):
        try:
            ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        except TypeError:
            ckpt = torch.load(ckpt_path, map_location=device)

        self.model.load_state_dict(ckpt["model"])
        epoch  = ckpt.get("epoch", -1)
        losses = ckpt.get("losses", None)
        args   = ckpt.get("args", None)

        self.model.to(device) if device is not None else self.model.to(self.device)
        return {"model": self.model, "epoch": epoch, "losses": losses, "args": args}"""


    def load_from_ckpt(self, ckpt_path: str, device: str | None = None):
        import torch, pickle
        from torch.serialization import add_safe_globals
        try:
            ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        except pickle.UnpicklingError:
            try:
                add_safe_globals([torch._C._nn.gelu])
            except Exception:
                pass
            try:
                ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
            except Exception:
                ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        except TypeError:
            ckpt = torch.load(ckpt_path, map_location=device)
        state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
        if not isinstance(state, dict):
            raise RuntimeError(f"Unexpected checkpoint format: {type(state)}")
        if "_metadata" in state:
            state = {k: v for k, v in state.items() if k != "_metadata"}
        self.model.load_state_dict(state, strict=True)
        (self.model.to(device) if device is not None else self.model.to(self.device))
        epoch  = ckpt.get("epoch", -1) if isinstance(ckpt, dict) else -1
        losses = ckpt.get("losses", None) if isinstance(ckpt, dict) else None
        args   = ckpt.get("args", None) if isinstance(ckpt, dict) else None
        return {"model": self.model, "epoch": epoch, "losses": losses, "args": args}


    def _to_grid(self, pts: torch.Tensor) -> torch.Tensor:
        # pts: [B, N, C] -> [B, *shapelist, C]
        B = pts.size(0)
        return pts.reshape(B, *self.shapelist, pts.shape[-1])


    def _from_grid(self, grid: torch.Tensor) -> torch.Tensor:
        # grid: [B, *shapelist, C] -> [B, N, C]
        B, C = grid.size(0), grid.size(-1)
        return grid.reshape(B, -1, C)


    def _ensure_loader(self, group: str) -> None:
        """
        Lazily build dataloader for the given group if not yet built.
        group ∈ {"train", "train_eval", "test"}.
        """
        if group == "train" and self.train_loader is None:
            self.build_dataloader(group="train")
        elif group == "train_eval" and self.train_eval_loader is None:
            self.build_dataloader(group="train_eval")
        elif group == "test" and self.test_loader is None:
            self.build_dataloader(group="test")

    
    def sample_batch(self, group: str, batch_size: int, replace: bool = False):
        """
        Randomly sample one batch from a given split, independent of the
        loader's inherent batch size.

        Args:
            group: One of {"train", "train_eval", "test"}.
            batch_size: Desired batch size for visualization/rollout.
            replace: If True, sample with replacement when batch_size > N.

        Returns:
            (data_xx, data_y) matching the training interface:
            - data_xx: [B, ..., T_in*C]
            - data_y : [B, ..., T_out, C]
        """
        self._ensure_loader(group)
        loader = {"train": self.train_loader,
                  "train_eval": self.train_eval_loader,
                  "test": self.test_loader}[group]
        dataset = loader.dataset
        N = len(dataset)

        if batch_size > N:
            replace = True

        # Use the same RNG style as data processor for reproducibility
        np_gen = getattr(self.data_processor, "np_gen", None)
        if np_gen is None:
            import numpy as np
            np_gen = np.random.default_rng(self.cfg.seed)

        indices = np_gen.choice(N, size=batch_size, replace=replace).tolist()
        subset = Subset(dataset, indices)

        # Build a temporary loader that yields exactly one batch
        tmp_loader = DataLoader(
            subset, batch_size=batch_size, shuffle=False,
            num_workers=0, pin_memory=True
        )
        batch = next(iter(tmp_loader))
        data_xx, data_y = batch['data']
        return data_xx, data_y

    
    def make_subset_from_saved(
        self,
        group: str,
        saved_json_path: str,
        batch_size: int | None = None
    ):    ########################################################################################################
        """
        Build a temporary DataLoader that contains the exact (traj_id, t0) pairs
        stored under `group` in split_metadata.json. This is useful for strict
        reproducible visualization/rollouts.

        Args:
            group: "train" / "train_eval" / "test"
            saved_json_path: path to split_metadata.json
            batch_size: optional batch size; if None, use len(index_list).

        Returns:
            A one-shot DataLoader that yields exactly those samples.
        """
        import json
        from torch.utils.data import Subset, DataLoader

        # Ensure dataset is constructed
        self._ensure_loader(group)
        dataset = {"train": self.train_loader.dataset,
                   "train_eval": self.train_eval_loader.dataset,
                   "test": self.test_loader.dataset}[group]

        with open(saved_json_path, "r") as f:
            payload = json.load(f)

        idx_list = payload.get("samples", {}).get(group, None)
        if idx_list is None:
            raise ValueError(f"No saved indices found for group '{group}' in {saved_json_path}")

        # Map (traj_id, t0) -> dataset indices: dataset.samples is a list of (traj_id, t0)
        # We build a position lookup for O(1) probing.
        position = { (int(tid), int(t0)): i for i, (tid, t0) in enumerate(dataset.samples) }
        take = []
        for tid, t0 in idx_list:
            key = (int(tid), int(t0))
            if key not in position:
                raise KeyError(f"Sample {key} not present in current dataset.samples")
            take.append(position[key])

        subset = Subset(dataset, take)
        if batch_size is None:
            batch_size = len(take)

        tmp_loader = DataLoader(
            subset, batch_size=batch_size, shuffle=False,
            num_workers=0, pin_memory=True
        )
        return tmp_loader


    def rollout_one_batch(self, data_xx: torch.Tensor, data_y: torch.Tensor, rollout_steps: int):
        # data_xx: [B, ..., T_in*C], data_y: [B, ..., T_out_y, C] (T_out_y >= rollout_steps)
        self.model.eval()
        assert rollout_steps <= data_y.shape[-2]
        with torch.no_grad():
            pos_feat = self.pos_feat.reshape(-1, self.spatial_dim)     # [(*spatial_dims), d] -> [N_pt, d]
            pos_feat, data_xx, data_y = pos_feat.to(self.device), data_xx.to(self.device), data_y.to(self.device)
            bs, num_channels = data_y.shape[0], data_y.shape[-1]
            pos_feat = pos_feat.expand(bs, -1, -1)    # [B, N_pt, C]

            pred = []
            for t_indice in range(rollout_steps):
                im = self.model(fx=self._from_grid(data_xx), x=pos_feat)    # ([B, N_pt, T_in*C], [B, N_pt, d]) -> [B, N_pt, C]
                pred.append(self._to_grid(im))
                data_xx = torch.cat((data_xx[..., num_channels:], self._to_grid(im)), dim=-1)
        
            pred_tensor = torch.stack(pred, dim=-2)    # stack [B, ..., C] to form [B, ..., T_out, C]
        
        return pred_tensor, data_y[..., :rollout_steps, :]    # [B, ..., T_out, C]


    def illustrate_one_frame_pred(self, data_xx: torch.Tensor, data_y: torch.Tensor, rollout_steps: int, out_dir: str):
        """
        Roll out ``rollout_steps`` frames and plot the last step against ground truth.

        Only 2-D spatial data is supported (asserted after the rollout). The figure
        is produced by ``visualize_pred_vs_gt`` for batch element 0 and the path it
        returns is printed.

        Args:
            data_xx: [B, ..., T_in*C]
            data_y: [B, ..., T_out, C]
            rollout_steps: number of predicted frames.
            out_dir: directory passed on (joined with 'one_frame_pred') as save path.
        """
        rollout = self.rollout_one_batch(data_xx, data_y, rollout_steps)    # each [B, H, W, T, C]
        pred_tensor, true_tensor = rollout
        assert self.spatial_dim == 2
        target_path = os.path.join(out_dir, 'one_frame_pred')
        layout = dict(
            header_time_ratio=0.28,   # smaller -> time label sits closer to the panels
            header_cols_ratio=0.34,   # shrinks the column-title row as well
            time_fontsize=12,         # slightly smaller time label
        )
        out_png = visualize_pred_vs_gt(
            pred_tensor, true_tensor,
            mode="last", batch_index=0,
            save_path=target_path,
            show_error=True, err_type="abs", add_colorbar=True,
            **layout
        )
        print("Saved last-step prediction to:", out_png)

    
    def illustrate_long_term_pred(self, data_xx: torch.Tensor, data_y: torch.Tensor, rollout_steps: int, out_dir: str):
        """
        Roll out ``rollout_steps`` frames and export the whole sequence as a GIF.

        Only 2-D spatial data is supported (asserted after the rollout). The
        animation is produced by ``visualize_pred_vs_gt`` for batch element 0 and
        the path it returns is printed.

        Args:
            data_xx: [B, ..., T_in*C]
            data_y: [B, ..., T_out, C]
            rollout_steps: number of predicted frames.
            out_dir: directory in which 'long_term_seq.gif' is requested.
        """
        rollout = self.rollout_one_batch(data_xx, data_y, rollout_steps)    # each [B, H, W, T, C]
        pred_tensor, true_tensor = rollout
        assert self.spatial_dim == 2
        gif_path = os.path.join(out_dir, 'long_term_seq.gif')
        out_gif = visualize_pred_vs_gt(
            pred_tensor,
            true_tensor,
            mode="all",
            batch_index=0,
            save_path=gif_path,
            as_type="gif",
            fps=3,
            show_error=True,
            err_type="abs",
            add_colorbar=True,
            header_time_ratio=0.28,
            header_cols_ratio=0.34,
        )
        print("Saved GIF to:", out_gif)


    def visualize_random_rollout(self, group: str, batch_size: int, rollout_steps: int,
                                 out_dir: str, mode: str = "last", show_error: bool = True):
        """
        Sample a random batch from the given split and visualize rollout.
        - group: "train" (seen IC), "train_eval", or "test" (new IC)
        - batch_size: arbitrary batch size for visualization (independent of training batch_size)
        - rollout_steps: number of prediction steps (<= T_out in data_y)
        - mode: "last" -> plot last-step panel; "all" -> export GIF/MP4 via your visualize utility
        """
        os.makedirs(out_dir, exist_ok=True)
        data_xx, data_y = self.sample_batch(group=group, batch_size=batch_size)

        if self.spatial_dim != 2:
            # For non-2D, just run rollout and return tensors (no visualization here).
            pred_tensor, true_tensor = self.rollout_one_batch(data_xx, data_y, rollout_steps)
            return pred_tensor, true_tensor

        # 2D visualization using your existing helpers
        if mode == "last":
            self.illustrate_one_frame_pred(data_xx, data_y, rollout_steps, out_dir)
        else:
            self.illustrate_long_term_pred(data_xx, data_y, rollout_steps, out_dir)


    def visualize_rollout_by_index(
        self,
        group: str,              # "train" | "train_eval" | "test"
        seq_index: int,          # zero-based index in the dataset order
        rollout_steps: int,
        out_dir: str,
        mode: str = "last",      # "last": plot last step; "all": export full GIF
    ):
        """
        Pick the `seq_index`-th sample (by dataset order) from the specified split,
        run a K-step rollout, and visualize the results.

        This bypasses any DataLoader shuffling by addressing the dataset directly.
        """
        import os
        from torch.utils.data import Subset, DataLoader

        os.makedirs(out_dir, exist_ok=True)

        # Ensure the corresponding loader (and hence dataset) exists
        self._ensure_loader(group)
        loader = {"train": self.train_loader,
                "train_eval": self.train_eval_loader,
                "test": self.test_loader}[group]
        dataset = loader.dataset
        N = len(dataset)
        if not (0 <= seq_index < N):
            raise IndexError(f"seq_index={seq_index} out of bounds (0..{N-1})")

        # Build a temporary DataLoader that yields exactly this single sample (batch_size=1)
        subset = Subset(dataset, [seq_index])
        tmp_loader = DataLoader(subset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
        batch_samples = next(iter(tmp_loader))
        data_xx, data_y = batch_samples['data']

        # Visualize rollout
        if mode == "last":
            self.illustrate_one_frame_pred(data_xx, data_y, rollout_steps, out_dir)
        else:
            self.illustrate_long_term_pred(data_xx, data_y, rollout_steps, out_dir)


    def _build_long_eval_loader(self, group: str = "test",
                                rollout_steps: int | None = None,
                                batch_size: int | None = None):
        """
        Build an unshuffled loader with one sample per trajectory, starting at t0 = 0.

        Args:
            group: "train"/"train_eval" use the training trajectories, anything
                else the test trajectories. "train" is passed to the dataset as
                "train_eval".
            rollout_steps: number of output frames per sample; defaults to
                ``cfg.n_frames_out``.
            batch_size: loader batch size; defaults to ``cfg.train_bs`` for the
                training groups and ``cfg.test_bs`` otherwise.

        Returns:
            A DataLoader over a freshly constructed ``PDEDataset``.
        """
        from torch.utils.data import DataLoader
        from data.data_process import PDEDataset

        processor = self.data_processor
        cfg = processor.cfg
        on_train_split = group in {"train", "train_eval"}
        traj_indices = processor.train_traj_ids if on_train_split else processor.test_traj_ids
        # One rollout per trajectory, always starting from the first frame.
        samples = [(int(traj_id), 0) for traj_id in traj_indices]

        source_tensor = processor.data_norm if cfg.normalize else processor.data
        n_frames_out = cfg.n_frames_out if rollout_steps is None else int(rollout_steps)
        dataset_group = "train_eval" if group == "train" else group

        dataset = PDEDataset(
            data_tensor=source_tensor,
            cfg=cfg,
            n_frames_train=cfg.n_frames_train,
            n_frames_cond=cfg.n_frames_cond,
            n_frames_out=n_frames_out,
            traj_indices=traj_indices,
            n_sample_per_traj=1,
            sample_strategy="disjoint",
            mode="autoregressive",
            group=dataset_group,
            samples=samples,
            mask_tensor=processor.mask_tensor,
            np_rng=processor.np_gen
        )

        default_bs = cfg.train_bs if group in {"train", "train_eval"} else cfg.test_bs
        effective_bs = default_bs if batch_size is None else int(batch_size)
        return DataLoader(
            dataset,
            batch_size=effective_bs,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=True,
            drop_last=False,
        )


    @torch.no_grad()
    def evaluate_long_trajs(
        self, out_dir: str,
        group: str = "test",
        rollout_steps: int | None = None,
        batch_size: int | None = None,
        loader=None,
        save_pt: bool = True,
        save_png: bool = False,
        vis_cols: int = 8,
        c_vis: int = 0,
        save_time_curve: bool = True
    ):
        import os, json, numpy as np
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        os.makedirs(out_dir, exist_ok=True)
        self.model.eval()

        if loader is None:
            loader = self._build_long_eval_loader(group=group, rollout_steps=rollout_steps, batch_size=batch_size)
        dataset = loader.dataset
        base_dataset = getattr(dataset, "dataset", dataset)
        samples_list = base_dataset.samples  # List[(traj_id, t0)]

        n_cond = int(self.n_frames_cond)
        all_mse = []
        t_eval_ref   = None
        curve_sum    = None
        curve_sumsq  = None
        curve_count  = None

        begin = time.time()
        for batch in loader:
            data_xx, data_y = batch['data']    # data_xx: [B, ..., T_in*C], data_y: [B, ..., T_out, C]
            t_eval = torch.arange(n_cond, n_cond+rollout_steps) * 1.0
            # data_xx: [B, ..., T_in*C], data_y: [B, ..., T_out_y, C] (T_out_y >= rollout_steps)
            pred, gt = self.rollout_one_batch(data_xx, data_y, rollout_steps=rollout_steps)  # [B,H,W,K,C], [B,H,W,K,C]
            mask_slice = None
            if "mask" in batch and batch["mask"] is not None:
                mask_full = batch["mask"].to(self.device)          # [B, T, H, W, C]
                mask_slice = mask_full[:, n_cond:n_cond+rollout_steps, ...].permute(0, 2, 3, 1, 4).contiguous()

            if mask_slice is None:
                se     = (pred - gt).pow(2)                           # [B,H,W,K,C]
                mse_b  = se.mean(dim=(1,2,3,4))                       # [B]
                mse_t  = se.mean(dim=(0,1,2,4)).detach().cpu()        # [K] （batch 均值）
                mse_bt = se.mean(dim=(1,2,4))                         # [B,K]（每样本逐时间）
                valid_bt = torch.ones_like(mse_bt, dtype=mse_bt.dtype)
            else:
                se     = (pred - gt).pow(2) * mask_slice
                num_b  = se.sum(dim=(1,2,3,4)); den_b  = mask_slice.sum(dim=(1,2,3,4)).clamp_min(1e-6)
                mse_b  = (num_b / den_b)                               # [B]
                num_t  = se.sum(dim=(0,1,2,4)); den_t = mask_slice.sum(dim=(0,1,2,4)).clamp_min(1e-6)
                mse_t  = (num_t / den_t).detach().cpu()                # [K]
                num_bt = se.sum(dim=(1,2,4));   den_bt = mask_slice.sum(dim=(1,2,4)).clamp_min(1e-6)
                mse_bt = (num_bt / den_bt)                             # [B,K]
                valid_bt = (den_bt > 0).to(mse_bt.dtype)

            all_mse.extend([float(x) for x in mse_b])

            # 聚合“误差-时间”曲线
            if save_time_curve:
                mse_bt_cpu   = mse_bt.detach().cpu()
                valid_bt_cpu = valid_bt.detach().cpu()
                if curve_sum is None:
                    import torch as _torch
                    Tprime = int(mse_bt_cpu.shape[1])
                    curve_sum   = _torch.zeros(Tprime)
                    curve_sumsq = _torch.zeros(Tprime)
                    curve_count = _torch.zeros(Tprime)
                    t_eval_ref  = t_eval.detach().cpu()
                curve_sum   += (mse_bt_cpu * valid_bt_cpu).sum(dim=0)
                curve_sumsq += ((mse_bt_cpu**2) * valid_bt_cpu).sum(dim=0)
                curve_count += valid_bt_cpu.sum(dim=0)

            # 导出每条样本
            if save_pt or save_png:
                idx_vec = batch["index"].tolist()   # 与 dataset.samples 对齐
                for b, ds_idx in enumerate(idx_vec):
                    traj_id, t0 = samples_list[int(ds_idx)]
                    stem = f"traj{int(traj_id):04d}_t0{int(t0):04d}"

                    if save_pt:
                        payload = {
                            "traj_id": int(traj_id),
                            "t0": int(t0),
                            "pred": pred[b].detach().cpu(),                      # [H,W,K,C]
                            "gt":   gt[b].detach().cpu(),                        # [H,W,K,C]
                            "mse_all": float(mse_b[b].item()),
                            "mse_t_batch_mean": mse_t,                           # [K]（本 batch 均值，便于参考）
                            "mse_t": mse_bt[b].detach().cpu(),                   # [K]（该样本自己的逐时间曲线）
                            "normalized": bool(self.data_processor.cfg.normalize),
                        }
                        torch.save(payload, os.path.join(out_dir, f"{stem}.pt"))

                    if save_png and self.spatial_dim == 2:
                        try:
                            import numpy as _np
                            Tsel = min(vis_cols, rollout_steps)
                            idx  = _np.linspace(0, rollout_steps-1, Tsel, dtype=int)
                            gt_np  = gt[b].detach().cpu().numpy()                # [H,W,K,C]
                            pr_np  = pred[b].detach().cpu().numpy()              # [H,W,K,C]
                            err_np = _np.abs(gt_np - pr_np)

                            rows = 3
                            fig, axes = plt.subplots(rows, Tsel, figsize=(2.6*Tsel, 2.6*rows), squeeze=False)
                            for j, ti in enumerate(idx):
                                # axes[0, j].imshow(gt_np[:, :, ti, c_vis]); axes[0, j].set_title(f"GT t={float(t_eval[ti]):.2f}")
                                axes[0, j].imshow(gt_np[:, :, ti, c_vis])
                                axes[1, j].imshow(pr_np[:, :, ti, c_vis]); axes[1, j].set_title("Pred")
                                axes[2, j].imshow(err_np[:, :, ti, c_vis]); axes[2, j].set_title("|Err|")
                                for r in range(rows): axes[r, j].axis('off')
                            plt.tight_layout()
                            plt.savefig(os.path.join(out_dir, f"{stem}.png"), dpi=160)
                            plt.close(fig)
                        except Exception:
                            pass
            
            print(f"[{group}] batch_size={pred.shape[0]} | K={rollout_steps} | MSE_mean(batch)={float(mse_b.mean().item()):.6e}")
        end = time.time()
        # 汇总整体标量
        import numpy as _np
        mean_mse = float(_np.mean(all_mse)) if all_mse else float("nan")
        std_mse  = float(_np.std(all_mse))  if all_mse else float("nan")
        summary = {"group": group, "n_samples": len(all_mse), "MSE_mean": mean_mse, "MSE_std": std_mse, "run_time": begin - end}
        with open(os.path.join(out_dir, f"summary_{group}.json"), "w") as f:
            json.dump(summary, f, indent=2)
        print(f"==> [{group}] Long-eval(batched) summary: MSE_mean={mean_mse:.6e} over {len(all_mse)} samples")

        # 保存“误差-时间”曲线（JSON + 可视化）
        if save_time_curve and (curve_count is not None) and (curve_count.max().item() > 0):
            mean_curve = (curve_sum / curve_count).numpy()
            var_curve  = (curve_sumsq / curve_count).numpy() - mean_curve**2
            var_curve  = np.clip(var_curve, 0.0, None)
            std_curve  = np.sqrt(var_curve)

            t_np = t_eval_ref.numpy()
            # 可视化（可选）
            fig, ax = plt.subplots(figsize=(7, 4))
            ax.plot(t_np, mean_curve, label="MSE (mean across traj)")
            ax.fill_between(t_np, mean_curve-std_curve, mean_curve+std_curve, alpha=0.2, label="±1 std")
            ax.set_xlabel("t"); ax.set_ylabel("MSE")
            ax.set_title(f"MSE vs time — {group}")
            ax.grid(True, alpha=0.3); ax.legend()
            fig.tight_layout()
            plt.savefig(os.path.join(out_dir, f"mse_over_time_{group}.png"), dpi=180)
            plt.close(fig)

            # JSON 导出
            out_json = os.path.join(out_dir, f"mse_over_time_{group}.json")
            series = {
                "group": group,
                "t":    [float(x) for x in t_np.tolist()],
                "mean": [float(x) for x in mean_curve.tolist()],
                "std":  [float(x) for x in std_curve.tolist()],
                "count":[int(x)   for x in curve_count.detach().cpu().numpy().astype(int).tolist()]
            }
            with open(out_json, "w") as f:
                json.dump(series, f, indent=2)

        return summary


    # NEW
    def evaluate_long_by_indices(self, out_dir: str, group: str, indices: list[int],
                                rollout_steps: int | None = None, **kwargs):
        """
        Evaluate only the samples at the given ``indices``.

        The indices refer to positions in the dataset built by
        ``_build_long_eval_loader``; the chosen samples are wrapped in a Subset
        and evaluated as one batch. Extra keyword arguments are forwarded to
        ``evaluate_long_trajs``, whose summary is returned.
        """
        full_loader = self._build_long_eval_loader(group=group, rollout_steps=rollout_steps)
        chosen = Subset(full_loader.dataset, indices)
        sub_loader = DataLoader(
            chosen,
            batch_size=len(indices),
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )
        return self.evaluate_long_trajs(
            out_dir=out_dir,
            group=group,
            rollout_steps=rollout_steps,
            loader=sub_loader,
            **kwargs,
        )
