import os
from pprint import pformat
import torch
import torch.nn as nn
from utilities.utils import set_requires_grad

import os
import torch
from torch.utils.data import Subset, DataLoader

from data.data_process import PDEDataProcessor
from exp.exp_basic import Exp_Basic, ExpConfigs
from utilities.losses import LpLoss
from utilities.vis import visualize_pred_vs_gt

from baselines.model_factory import DINOParamBundle
from baselines.DINO import build_field_decoder, build_latent_ode, build_set_encoder

"""
TODO:
    load dataloader logic unexpected when not train_eval??

    specialize self.pos_feat for DINO????
"""


class Exp_DINO(Exp_Basic):
    def __init__(self, args, exp_cfg: ExpConfigs, model_cfg: DINOParamBundle, data_processor: PDEDataProcessor):
        """
        Set up the DINO experiment: cache model hyper-parameters, build the
        sub-networks and log their parameter counts.

        Args:
            args: forwarded unchanged to the base experiment class.
            exp_cfg: experiment configuration; `tf_epsilon` is read from it.
            model_cfg: parameter bundle; expanded with `as_model_kwargs(include_meta=True)`.
            data_processor: data processor; its `n_frames_cond` must match the model config.
        """
        super().__init__(args, exp_cfg, model_cfg, data_processor)

        # Replace the bundle stored by the base class with its expanded kwargs dict.
        model_kwargs = model_cfg.as_model_kwargs(include_meta=True)
        self.model_cfg = model_kwargs
        self.use_delay = model_kwargs["use_delay"]

        # Model and data pipeline must agree on the number of conditioning frames.
        n_frames_cond = data_processor.n_frames_cond
        assert model_kwargs["n_frames_cond"] == n_frames_cond
        self.n_cond = n_frames_cond - 1

        self.state_dim = model_kwargs["state_dim"]
        self.code_dim = model_kwargs["code_dim"]
        self.latent_dim = model_kwargs["latent_dim"]

        # Teacher-forcing parameter (the dataloaders themselves are handled by Exp_Basic).
        tf_eps = exp_cfg.tf_epsilon
        self.tf_epsilon = tf_eps
        self.epsilon = tf_eps

        # Build the networks, then report their sizes.
        self.load_model()
        self.log_param_table("Params after modules init (with states)")


    def build_dataloader(self, group: str):
        sample_map = {
            "train": self.train_sample_idx,
            "train_eval": self.train_eval_sample_idx,
            "test": self.test_sample_idx,
        }
        dataloader = self.data_processor.get_dataloader(group=group, samples=sample_map[group])
        if group == "train":
            self.train_loader = dataloader
        elif group == "train_eval":
            self.train_eval_loader = dataloader
        elif group == "test":
            self.test_loader = dataloader

    
    def load_model(self):
        """
        Instantiate the sub-networks and the per-sequence latent codes on `self.device`.

        Creates `field_decoder`, `latent_process`, `states_params` (one zero-initialised
        [n_frames_train, latent_dim] parameter per training sequence) and, when
        `use_delay` is set, `set_encoder` (otherwise it is set to None).
        """
        cfg = self.model_cfg
        device = self.device

        self.field_decoder = build_field_decoder(x_grid=self.pos_feat, model_cfg=cfg["field_decoder"]).to(device)
        self.latent_process = build_latent_ode(
            model_cfg_ode=cfg["latent_ode"],
            model_cfg_latent=cfg["latent_process"],
        ).to(device)

        # One learnable latent trajectory per training sequence: [N_seqs, n_frames_train, latent_dim]
        latent_codes = []
        for _ in range(self.n_seqs_tr):
            latent_codes.append(nn.Parameter(torch.zeros(self.n_frames_train, self.latent_dim).to(device)))
        self.states_params = nn.ParameterList(latent_codes)

        # The set encoder only exists in the delay variant.
        if self.use_delay:
            self.set_encoder = build_set_encoder(model_cfg=cfg["set_encoder"]).to(device)
        else:
            self.set_encoder = None

    
    def count_parameters(self, include_states: bool = True) -> dict:
        """Return a dict with per-module and total trainable parameter counts."""
        def ntrainable(module):
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        counts = {
            "field_decoder": ntrainable(self.field_decoder) if hasattr(self, "field_decoder") else 0,
            "latent_process": ntrainable(self.latent_process) if hasattr(self, "latent_process") else 0,
            "set_encoder": ntrainable(self.set_encoder) if (self.use_delay and hasattr(self, "set_encoder")) else 0,
            "states_params": 0,
        }
        if include_states and isinstance(self.states_params, torch.nn.ParameterList):
            counts["states_params"] = sum(p.numel() for p in self.states_params)

        counts["total"] = sum(counts.values())
        return counts


    def log_param_table(self, title: str = "Trainable parameters", include_states: bool = True):
        c = self.count_parameters(include_states=include_states)
        lines = [
            f"{title}:",
            f"  field_decoder : {c['field_decoder']:,}",
            f"  latent_process: {c['latent_process']:,}",
            f"  set_encoder   : {c['set_encoder']:,}",
            f"  states_params : {c['states_params']:,}",
            f"  ---------------------------",
            f"  TOTAL         : {c['total']:,}",
        ]
        msg = "\n".join(lines)
        if hasattr(self, "logger") and self.logger is not None:
            self.logger.info(msg)
        else:
            print(msg)

    
    def switch_to_train(self):
        self.field_decoder.train()
        self.latent_process.train()
        self.set_encoder.train() if self.use_delay else None

    
    def switch_to_eval(self):
        self.field_decoder.eval()
        self.latent_process.eval()
        self.set_encoder.eval() if self.use_delay else None
        

    def init_optim(self):
        self.optim_field = torch.optim.Adam([{'params': self.field_decoder.parameters(), 'lr': self.lr}])
        self.optim_dyn = torch.optim.Adam([{'params': self.latent_process.parameters(), 'lr': self.lr / 10}])
        self.optim_latent = torch.optim.Adam([{'params': self.states_params, 'lr': self.lr / 10}])
        if self.use_delay:
            self.optim_setenc = torch.optim.Adam([{'params': self.set_encoder.parameters(), 'lr': self.lr / 10}])


    def train(self, log_every: int | None = None, eval_every: int | None = None):
        """
        Run the full training loop and save the final model.

        Each batch first fits the per-sequence latent codes (and, indirectly, the
        field decoder) with a masked reconstruction loss, then fits the latent
        dynamics with an MSE loss against the detached latent codes.

        Args:
            log_every: if not None, log the losses whenever the global iteration
                counter `epoch * len(train_loader) + i` is a multiple of it.
            eval_every: if not None, build the "test" and "train_eval" loaders and
                run `evaluate` whenever `epoch * len(train_loader) + i + 1` is a
                multiple of it; the best train/test checkpoints (by the "loss"
                entry) are saved as "model_tr_best.pth" / "model_ts_best.pth".

        Optimizer schedule as written:
            * `optim_latent` steps on every batch.
            * with `use_delay`: `optim_field` steps (and is zeroed) every 4th
              global iteration; `optim_dyn` and `optim_setenc` step on every batch.
            * without `use_delay`: `optim_field` and `optim_dyn` step once at the
              end of each epoch, so their gradients accumulate over the epoch.

        Expected batch layout (as documented by the original author):
        batch['data']: [B, T, H, W, C]
        batch['t']: tensor of len n_frames_train / n_frames_train + n_frames_out
        batch['index']: sample idx 
        batch['mask']: [B, T, H, W, C]
        """
        
        self.setup_logger()    
        self.save_repro_artifacts()
        self.log_param_table("Params after modules init (with states)")
        criterion = nn.MSELoss()
        
        # Training requires the data processor to be in "interpolation" mode.
        assert self.data_processor.mode == "interpolation", f"Mismatched dataloaders"
        self.build_dataloader(group="train")
        self.init_optim()
        if eval_every is not None:
            self.build_dataloader(group="test")
            self.build_dataloader(group="train_eval")
        self._save_split_and_samples()

        # tf_epsilon is decayed multiplicatively by self.epsilon every cfg.update_every epochs (see end of epoch loop).
        tf_epsilon = self.tf_epsilon
        loss_tr_min, loss_ts_min = float('inf'), float('inf')
        for epoch in range(1, self.cfg.epochs + 1):
            self.switch_to_train()
            for i, batch in enumerate(self.train_loader):
                ground_truth = batch["data"].to(self.device)    # [B, T, H, W, n_ch]
                sample_idx = batch["index"].to(self.device)     # [B,]
                masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
                bs, train_len, H, W, _ = ground_truth.shape
                assert train_len == self.n_frames_train
    
                # ---- Reconstruction step: decode this batch's latent codes and compare to the data ----
                latent_feats_params = torch.stack([self.states_params[d] for d in sample_idx], dim=0)    # [B, n_frames_train, latent_dim]
                latent_feats = latent_feats_params.view(bs, train_len, self.state_dim, self.code_dim)    # [B, n_frames_train, s, code_dim]
                recon_field = self.field_decoder(latent_feats)    # [B, T, H, W, s]
                # Masked squared error, normalised by the mask count over dims 2 and 3.
                sqerr = (ground_truth - recon_field).pow(2) * masks
                sqerr_sum = sqerr.sum(dim=(2, 3))
                denom = masks.sum(dim=(2, 3)).clamp_(min=1e-6)
                recon_loss = (sqerr_sum / denom).mean()

                # Only the latent optimizer is zeroed here; field-decoder gradients keep accumulating
                # until optim_field steps (every 4th iteration with use_delay, else once per epoch).
                self.optim_latent.zero_grad()
                recon_loss.backward()
                self.optim_latent.step()
                if self.use_delay and (epoch * len(self.train_loader) + i) % 4 == 0:
                    self.optim_field.step()
                    self.optim_field.zero_grad()

                # Update dynamics
                if self.use_delay:
                    # Encode every sliding window of n_cond consecutive (detached) latent frames.
                    extra_states = []
                    # latent_feats_params: [B, n_frames_train, latent_dim]
                    for jjj in range(self.n_frames_train - self.n_cond):
                        extra_states.append(
                            self.set_encoder(latent_feats_params[:, jjj:jjj+self.n_cond, :].detach().clone())
                        )    # [B, n_cond, latent_dim] -> [B, latent_dim]
                    extra_states = torch.stack(extra_states, dim=0)    # [T-n_cond, B, latent_dim]
                    # Augmented state = [window encoding, latent code of the frame following the window].
                    augmented_states = torch.cat([extra_states, latent_feats_params[:, self.n_cond:, :].detach().clone().permute(1, 0, 2)],
                                                   dim=-1)    # [T-n_cond, B, latent_dim*2]
                    latent_feats_dyn, _ = self.latent_process(
                        alpha_0=augmented_states[0],    # [B, latent_dim*2]
                        t_eval=batch['t'][0][self.n_cond:].to(self.device),
                        teacher_forcing=True,
                        tf_alpha=augmented_states,
                        tf_epsilon=tf_epsilon
                    )    # [T-n_cond, B, latent_dim*2]
                    # Only the last latent_dim channels of the rollout are supervised.
                    dyn_loss = criterion(latent_feats_dyn[:, :, self.latent_dim:], 
                                         latent_feats_params[:, self.n_cond:, :].detach().clone().permute(1, 0, 2))
                    dyn_loss.backward()
                    self.optim_dyn.step()
                    self.optim_setenc.step()
                    self.optim_dyn.zero_grad()
                    self.optim_setenc.zero_grad()
                else:
                    # NOTE(review): alpha_0 is not detached here, so dyn_loss.backward() also writes
                    # gradients into states_params; they are cleared by the next optim_latent.zero_grad().
                    latent_feats_dyn, _ = self.latent_process(
                        alpha_0=latent_feats_params[:, 0, :],
                        t_eval=batch['t'][0].to(self.device),
                        teacher_forcing=True,
                        tf_alpha=latent_feats_params.detach().clone().permute(1, 0, 2),
                        tf_epsilon=tf_epsilon
                    )    # [T, B, latent_dim]
                    dyn_loss = criterion(latent_feats_dyn, latent_feats_params.detach().clone().permute(1, 0, 2))
                    # No optimizer step here: optim_dyn / optim_field step once at the end of the epoch.
                    dyn_loss.backward()

                if log_every is not None and (epoch * len(self.train_loader) + i) % log_every == 0:
                    self.logger.info(f"Epoch {epoch:04d}/{self.cfg.epochs} | iteration {i+1:03d} |"
                        f"| field reconstruction loss {recon_loss.item():.8f}, dyn_loss {dyn_loss.item():.8f} | epsilon {tf_epsilon:.4f}")
                # The string below is disabled code (per-iteration epsilon decay) kept as a no-op literal.
                """if (epoch * len(self.train_loader) + i + 1) % self.cfg.update_every == 0:
                    tf_epsilon *= self.epsilon"""

                if eval_every is not None and (epoch * len(self.train_loader) + i + 1) % eval_every == 0:
                    # In-domain evaluation: reuse the learned latent codes (no latent adaptation).
                    if self.train_eval_loader is not None:
                        self.logger.info("--------Begin Evaluation on Train--------")
                        self.switch_to_eval()
                        train_eval_errs = self.evaluate(
                            dataloader=self.train_eval_loader, states_params=self.states_params, 
                            learn_latent=False, lr_adapt=0.0, optim_steps=300
                        )
                        self.logger.info("Evaluation on train:\n%s", pformat(train_eval_errs, width=100, compact=False))
                        losses_tr = train_eval_errs["mse_losses"]
                        if loss_tr_min > losses_tr["loss"]:
                            loss_tr_min = losses_tr["loss"]
                            self.save_model(epoch, train_eval_errs, pth_name="model_tr_best.pth")
                        self.switch_to_train()
                    # out-of-domain evaluation
                    # Latent codes are fitted from scratch inside evaluate (learn_latent=True, lr=cfg.lr_adapt).
                    if self.test_loader is not None:
                        self.logger.info("--------Begin Evaluation on Test--------")
                        self.switch_to_eval()
                        test_errs = self.evaluate(
                            dataloader=self.test_loader, states_params=self.states_params, 
                            learn_latent=True, lr_adapt=self.cfg.lr_adapt, optim_steps=300
                        )
                        self.logger.info("Evaluation on test:\n%s", pformat(test_errs, width=100, compact=False))
                        losses_ts = test_errs["mse_losses"]
                        if loss_ts_min > losses_ts["loss"]:
                            loss_ts_min = losses_ts["loss"]
                            self.save_model(epoch, test_errs, pth_name="model_ts_best.pth")
                        self.switch_to_train()
            # Per-epoch teacher-forcing decay.
            if epoch % self.cfg.update_every == 0:
                tf_epsilon *= self.epsilon
            # Non-delay variant: apply the gradients accumulated over the whole epoch.
            if not self.use_delay:
                self.optim_field.step()
                self.optim_dyn.step()
                self.optim_field.zero_grad()
                self.optim_dyn.zero_grad()
        self.logger.info("Training Finished! Saving model...")
        self.save_model(epoch=None, losses=None, pth_name="final_model.pth")


    def evaluate(self, dataloader, model_pth: str | None = None, states_params: torch.Tensor | None = None,
                 learn_latent: bool = False, lr_adapt: float = 0.0, optim_steps: int = 300):
        """
        Roll the latent dynamics out on every batch of `dataloader` and report MSE losses.

        The initial latent codes (first n_cond+1 frames) are obtained in one of two ways:
          * `learn_latent` and `lr_adapt != 0`: zero-initialised codes are fitted to the
            first n_cond+1 frames with Adam for `optim_steps` steps, keeping the codes
            that achieved the lowest reconstruction loss. Gradients are taken with
            respect to these codes only, so no gradient is accumulated in the field
            decoder (this method is called in the middle of training).
          * otherwise: the codes are looked up in `states_params` by `batch["index"]`.

        Args:
            dataloader: yields dicts with "data", "t", "index" and "mask".
            model_pth: optional checkpoint to load before evaluating.
            states_params: stored latent codes [n_seqs_tr, n_frames_train, latent_dim];
                only used when the codes are not learned.
            learn_latent: fit fresh latent codes instead of using `states_params`.
            lr_adapt: Adam learning rate for the latent fit; 0 disables the fit.
            optim_steps: number of Adam steps for the latent fit.

        Returns:
            {"mse_losses": {...}} with the per-sample-averaged entries
        ----------------------------------------------------------------------------------------------
        In_t: loss within train horizon.
        Out_t: loss outside train horizon.
        In_s: loss within observation grid.
        Out_s: loss outside observation grid.
        loss: loss averaged across in_t/out_t and in_s/out_s
        loss_in_t: loss averaged across in_s/out_s for in_t.
        loss_in_t_in_s, loss_in_t_out_s: loss in_t + in_s / out_s
        The four *_in_s / *_out_s entries are only computed when `self.mask_ratio != 0`
        and stay 0.0 otherwise.
        """
        if model_pth is not None:
            _ = self.load_from_ckpt(ckpt_path=model_pth, device=self.device)

        err_dict = {}
        loss, loss_out_t, loss_in_t, loss_in_t_in_s, loss_in_t_out_s,\
                loss_out_t_in_s, loss_out_t_out_s = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        num_samples = 0
        n_init = self.n_cond + 1                        # frames used to obtain the initial codes
        n_in = self.n_frames_train - self.n_cond        # predicted frames inside the train horizon

        for batch in dataloader:
            ground_truth = batch["data"].to(self.device)    # [B, T, H, W, n_ch] (T = n_frames_train + n_frames_out)
            t_eval = batch['t'][0][self.n_cond:].to(self.device) if self.use_delay\
                  else batch['t'][0].to(self.device)        # should be of len [T]
            sample_idx = batch["index"].to(self.device)     # [B,]
            masks = batch["mask"].to(self.device)           # [B, T, H, W, n_ch]
            bs, length = ground_truth.shape[0], ground_truth.shape[1]
            num_samples += bs

            if learn_latent and lr_adapt != 0:
                latent_shape = (n_init, self.latent_dim)
                states_params_out = nn.ParameterList(
                    [nn.Parameter(torch.zeros(latent_shape).to(self.device)) for _ in range(bs)]
                )    # [B, n_cond+1, latent_dim]
                latent_params = list(states_params_out)
                optim_latent_out = torch.optim.Adam(latent_params, lr=lr_adapt)

                gt_init = ground_truth[:, 0:n_init, :]
                masks_init = masks[:, 0:n_init, :]
                denom = masks_init.sum(dim=(2, 3)).clamp_(min=1e-6)

                loss_best = float("inf")
                # Fallback so the result is defined even if no step improves on `loss_best`
                # (optim_steps <= 0 or a non-finite loss).
                best_latent_feats = torch.stack(latent_params, dim=0).detach().clone()
                # enable_grad makes the fit work even if the caller wrapped evaluate in no_grad.
                with torch.enable_grad():
                    for _ in range(optim_steps):
                        latent_feats_out = torch.stack(latent_params, dim=0).view(
                            bs, n_init, self.state_dim, self.code_dim
                        )    # [B, n_cond+1, s, code_dim]
                        recon_field = self.field_decoder(latent_feats_out)    # [B, n_cond+1, H, W, s]
                        sqerr = (gt_init - recon_field).pow(2) * masks_init
                        recon_loss = (sqerr.sum(dim=(2, 3)) / denom).mean()
                        if recon_loss < loss_best:
                            loss_best = recon_loss.item()
                            best_latent_feats = latent_feats_out.detach().clone()
                        # Differentiate w.r.t. the latent codes only: a plain backward() would
                        # also accumulate into field_decoder.grad and leak into training.
                        grads = torch.autograd.grad(recon_loss, latent_params)
                        optim_latent_out.zero_grad()
                        for param, grad in zip(latent_params, grads):
                            param.grad = grad
                        optim_latent_out.step()
                latent_feats_params = best_latent_feats.view(bs, n_init, self.latent_dim)    # [B, 1+n_cond, latent_dim]
            else:
                with torch.no_grad():
                    sp = self._states_to_tensor(states_params)              # [N, T, latent_dim]
                    latent_feats_params = sp.index_select(0, sample_idx)    # [B, T, latent_dim]
                    latent_feats_params = latent_feats_params[:, 0:n_init, :]  # [B, 1+n_cond, latent_dim]

            with torch.no_grad():
                if self.use_delay:
                    extra_state = self.set_encoder(latent_feats_params[:, :self.n_cond, :].detach().clone())    # [B, latent_dim]
                    augmented_state = torch.cat([extra_state, latent_feats_params[:, -1, :]], dim=-1)
                    dyn_codes, _ = self.latent_process(alpha_0=augmented_state, t_eval=t_eval, teacher_forcing=False)    # [T, B, latent_dim*2]
                    alpha_codes = dyn_codes[:, :, self.latent_dim:].permute(1, 0, 2).view(bs, t_eval.numel(), self.state_dim, self.code_dim)
                    recon_seq = self.field_decoder(alpha_codes)    # [B, T, H, W, s]
                else:
                    latent_feats_params = latent_feats_params.squeeze(1)
                    dyn_codes, _ = self.latent_process(alpha_0=latent_feats_params, t_eval=t_eval, teacher_forcing=False)    # [T, B, latent]
                    dyn_codes = dyn_codes.permute(1, 0, 2).view(bs, length, self.state_dim, self.code_dim)
                    recon_seq = self.field_decoder(dyn_codes)    # [B, T, H, W, s]

                # compute losses
                # recon_seq, ground_truth: [B, T, H, W, s]
                ground_truth_ = ground_truth[:, self.n_cond:, ...]
                masks_ = masks[:, self.n_cond:, ...]
                pred_in, pred_out = recon_seq[:, :n_in, ...], recon_seq[:, n_in:, ...]
                true_in, true_out = ground_truth_[:, :n_in, ...], ground_truth_[:, n_in:, ...]

                loss += self._compute_loss(recon_seq, ground_truth_) * bs
                loss_in_t += self._compute_loss(pred_in, true_in) * bs
                loss_out_t += self._compute_loss(pred_out, true_out) * bs
                if self.mask_ratio != 0.0:
                    mask_in, mask_out = masks_[:, :n_in, ...], masks_[:, n_in:, ...]
                    loss_in_t_in_s += self._compute_loss(pred_in, true_in, mask=mask_in) * bs
                    loss_in_t_out_s += self._compute_loss(pred_in, true_in, mask=1.0-mask_in) * bs
                    loss_out_t_in_s += self._compute_loss(pred_out, true_out, mask=mask_out) * bs
                    loss_out_t_out_s += self._compute_loss(pred_out, true_out, mask=1.0-mask_out) * bs

        loss /= num_samples
        loss_in_t /= num_samples
        loss_out_t /= num_samples
        loss_out_t_in_s /= num_samples
        loss_out_t_out_s /= num_samples
        loss_in_t_in_s /= num_samples
        loss_in_t_out_s /= num_samples
        losses = {
            "loss": loss, "loss_in_t": loss_in_t, "loss_out_t": loss_out_t,
            "loss_out_t_in_s": loss_out_t_in_s, "loss_out_t_out_s": loss_out_t_out_s,
            "loss_in_t_in_s": loss_in_t_in_s, "loss_in_t_out_s": loss_in_t_out_s
        }
        # Leave every sub-network trainable for the caller (training resumes after evaluation).
        set_requires_grad(self.field_decoder, True)
        set_requires_grad(self.latent_process, True)
        if self.use_delay:
            set_requires_grad(self.set_encoder, True)
        err_dict["mse_losses"] = losses
        return err_dict

            
    def _compute_loss(self, data1: torch.Tensor, data2: torch.Tensor, mask: torch.Tensor | None = None):
        # datas: [B, T, H, W, s], mask: [B, T, H, W, s] or None
        # check!!!!!!!!!!
        with torch.no_grad():
            if mask is None:
                criterion = nn.MSELoss(reduce="none")
                loss = criterion(data1, data2)
                mse = loss.mean()
            else:
                sqerr = (data1 - data2).pow(2) * mask
                sqerr_sum = sqerr.sum(dim=(2, 3))
                denom = mask.sum(dim=(2, 3)).clamp_(min=1e-6)
                mse = (sqerr_sum / denom).mean()
        return mse

        
    def save_model(self, epoch: int | None = None, losses: dict | None = None, pth_name: str = "model_tr.pth"):
        save_dict = {
            "args": vars(self.args) if self.args is not None else None,
            "epoch": epoch,
            "field_decoder": self.field_decoder.state_dict(),
            "states_params": torch.stack([p.detach().cpu() for p in self.states_params], dim=0),  # [N_seqs, T_tr, D]
            "latent_process": self.latent_process.state_dict(),
            "set_encoder": self.set_encoder.state_dict() if self.use_delay else None,
            "losses": dict(losses) if isinstance(losses, dict) else losses
        }
        # out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}_{self.dataset}")
        # torch.save(save_dict, os.path.join(out_dir, f'{self.dataset}_{pth_name}'))
        out_dir = os.path.join(self.cfg.out_dir, f"{self.run_id}")
        torch.save(save_dict, os.path.join(out_dir, f'{pth_name}'))


    """def load_from_ckpt(self, ckpt_path: str, device: str | None = None):
        try:
            ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        except TypeError:
            ckpt = torch.load(ckpt_path, map_location=device)

        device = device if device is not None else self.device
        self.field_decoder.load_state_dict(ckpt["field_decoder"]).to(device)
        self.latent_process.load_state_dict(ckpt["latent_process"]).to(device)
        self.set_encoder.load_state_dict(ckpt["set_encoder"]).to(device) if self.use_delay else None
        self.states_params = nn.Parameter(ckpt['states_params'].to(device))
        
        epoch  = ckpt.get("epoch", -1)
        losses = ckpt.get("losses", None)
        args   = ckpt.get("args", None)
        return {"epoch": epoch, "losses": losses, "args": args}"""


    def load_from_ckpt(self, ckpt_path: str, device: str | None = None):
        try:
            ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        except TypeError:
            ckpt = torch.load(ckpt_path, map_location=device)

        device = device if device is not None else self.device

        self.field_decoder.load_state_dict(ckpt["field_decoder"])
        self.field_decoder.to(device)

        self.latent_process.load_state_dict(ckpt["latent_process"])
        self.latent_process.to(device)

        if self.use_delay and ckpt.get("set_encoder") is not None:
            self.set_encoder.load_state_dict(ckpt["set_encoder"])
            self.set_encoder.to(device)

        if "states_params" in ckpt:
            arr = ckpt["states_params"].to(device)  # [N_seqs, T_tr, D]
            self.states_params = nn.ParameterList([nn.Parameter(arr[i]) for i in range(arr.shape[0])])

        return {"epoch": ckpt.get("epoch", -1), "losses": ckpt.get("losses", None), "args": ckpt.get("args", None)}


    def _ensure_loader(self, group: str) -> None:
        """
        Lazily build dataloader for the given group if not yet built.
        group ∈ {"train", "train_eval", "test"}.
        """
        if group == "train" and self.train_loader is None:
            self.build_dataloader(group="train")
        elif group == "train_eval" and self.train_eval_loader is None:
            self.build_dataloader(group="train_eval")
        elif group == "test" and self.test_loader is None:
            self.build_dataloader(group="test")

    
    def sample_batch(self, group: str, batch_size: int, replace: bool = False):
        """
        Randomly sample one batch from a given split, independent of the
        loader's inherent batch size.

        Args:
            group: One of {"train", "train_eval", "test"}.
            batch_size: Desired batch size for visualization/rollout.
            replace: If True, sample with replacement when batch_size > N.

        Returns: one batch
        """
        self._ensure_loader(group)
        loader = {"train": self.train_loader,
                  "train_eval": self.train_eval_loader,
                  "test": self.test_loader}[group]
        dataset = loader.dataset
        N = len(dataset)

        if batch_size > N:
            replace = True

        # Use the same RNG style as data processor for reproducibility
        np_gen = getattr(self.data_processor, "np_gen", None)
        if np_gen is None:
            import numpy as np
            np_gen = np.random.default_rng(self.cfg.seed)

        indices = np_gen.choice(N, size=batch_size, replace=replace).tolist()
        subset = Subset(dataset, indices)

        # Build a temporary loader that yields exactly one batch
        tmp_loader = DataLoader(
            subset, batch_size=batch_size, shuffle=False,
            num_workers=0, pin_memory=True
        )
        batch = next(iter(tmp_loader))
        return batch

    
    def make_subset_from_saved(
        self,
        group: str,
        saved_json_path: str,
        batch_size: int | None = None
    ):    ########################################################################################################
        """
        Build a temporary DataLoader that contains the exact (traj_id, t0) pairs
        stored under `group` in split_metadata.json. This is useful for strict
        reproducible visualization/rollouts.

        Args:
            group: "train" / "train_eval" / "test"
            saved_json_path: path to split_metadata.json
            batch_size: optional batch size; if None, use len(index_list).

        Returns:
            A one-shot DataLoader that yields exactly those samples.
        """
        import json
        from torch.utils.data import Subset, DataLoader

        # Ensure dataset is constructed
        self._ensure_loader(group)
        dataset = {"train": self.train_loader.dataset,
                   "train_eval": self.train_eval_loader.dataset,
                   "test": self.test_loader.dataset}[group]

        with open(saved_json_path, "r") as f:
            payload = json.load(f)

        idx_list = payload.get("samples", {}).get(group, None)
        if idx_list is None:
            raise ValueError(f"No saved indices found for group '{group}' in {saved_json_path}")

        # Map (traj_id, t0) -> dataset indices: dataset.samples is a list of (traj_id, t0)
        # We build a position lookup for O(1) probing.
        position = { (int(tid), int(t0)): i for i, (tid, t0) in enumerate(dataset.samples) }
        take = []
        for tid, t0 in idx_list:
            key = (int(tid), int(t0))
            if key not in position:
                raise KeyError(f"Sample {key} not present in current dataset.samples")
            take.append(position[key])

        subset = Subset(dataset, take)
        if batch_size is None:
            batch_size = len(take)

        tmp_loader = DataLoader(
            subset, batch_size=batch_size, shuffle=False,
            num_workers=0, pin_memory=True
        )
        return tmp_loader


    def rollout_one_batch(self, batch_samples, rollout_steps: int, states_params: torch.Tensor | None = None,
                          learn_latent: bool = False, lr_adapt: float = 0.0, optim_steps: int = 300):
        # batch_samples, containing data_tensor: [B, T, C, H, W], using first n_cond frames (=n_frames_cond-1) to generate history state,
        # record prediciton at n_cond_frames + 1, ..., n_cond_frames + rollout_steps (< T)
        # using states_params only for seqs in training dataset
        self.switch_to_eval()
        with torch.no_grad():
            ground_truth = batch_samples["data"].to(self.device)    # [B, T, H, W, n_ch] (T = n_frames_train + n_frames_out)
            t_eval = batch_samples['t'][0][self.n_cond:].to(self.device)
            sample_idx = batch_samples["index"].to(self.device)     # [B,]
            masks = batch_samples["mask"].to(self.device)           # [B, T, H, W, n_ch]
            assert rollout_steps + self.data_processor.n_frames_cond <= ground_truth.shape[1]
            bs, length, H, W, _ = ground_truth.shape
            if learn_latent and lr_adapt != 0:
                loss_best = 1e30
                latent_shape = (self.n_cond+1, self.latent_dim)
                states_params_out = nn.ParameterList(
                    [nn.Parameter(torch.zeros(latent_shape).to(self.device)) for _ in range(bs)]
                )    # [B, 1, latent_dim] or [B, n_cond+1, latent_dim]
                optim_latent_out = torch.optim.Adam(states_params_out, lr=lr_adapt)
                latent_feats_out = torch.stack(list(states_params_out), dim=0)
                latent_feats_out = latent_feats_out.view(bs, latent_shape[0], self.state_dim, self.code_dim)    # [B, 1/n_cond+1, s, code_dim]

                for i in range(optim_steps):
                    recon_field = self.field_decoder(latent_feats_out)    # [B, 1/n_cond+1, H, W, s]
                    sqerr = (ground_truth[:, 0:self.n_cond+1, :] - recon_field).pow(2) * masks[:, 0:self.n_cond+1, :]
                    sqerr_sum = sqerr.sum(dim=(2, 3))
                    denom = masks[:, 0:self.n_cond+1, :].sum(dim=(2, 3)).clamp_(min=1e-6)
                    recon_loss = (sqerr_sum / denom).mean()
                    if recon_loss < loss_best:
                        loss_best = recon_loss.item()
                        best_latent_feats = latent_feats_out
                    optim_latent_out.zero_grad()
                    recon_loss.backward()
                    optim_latent_out.step()
                latent_feats_params = best_latent_feats.view(bs, latent_shape[0], self.latent_dim)    # [B, 1+n_cond, latent_dim]
            else:
                with torch.no_grad():
                    """states_params_copy = states_params.detach().clone()
                    latent_feats_params = torch.stack(
                        [states_params_copy[d][0:self.n_cond+1] for d in sample_idx], dim=0
                    )    # [B, 1+n_cond, latent_dim]"""
                    sp = self._states_to_tensor(states_params)              # [N, T, D]
                    latent_feats_params = sp.index_select(0, sample_idx)    # [B, T, D]
                    latent_feats_params = latent_feats_params[:, 0:self.n_cond+1, :]  # [B, 1+n_cond, D]
            with torch.no_grad():
                if self.use_delay:
                    extra_state = self.set_encoder(latent_feats_params[:, :self.n_cond, :].detach().clone())    # [B, latent_dim]
                    augmented_state = torch.cat([extra_state, latent_feats_params[:, -1, :]], dim=-1)
                    dyn_codes, _ = self.latent_process(alpha_0=augmented_state, t_eval=t_eval[:rollout_steps+1], teacher_forcing=False)    # [T, B, latent_dim*2]
                    alpha_codes = dyn_codes[:, :, self.latent_dim:].permute(1, 0, 2).view(bs, t_eval.numel(), self.state_dim, self.code_dim)
                    recon_seq = self.field_decoder(alpha_codes)    # [B, T, H, W, s]
                else:
                    latent_feats_params = latent_feats_params.squeeze(1)
                    dyn_codes, _ = self.latent_process(alpha_0=latent_feats_params, t_eval=t_eval[:rollout_steps+1], teacher_forcing=False)    # [T, B, latent]
                    dyn_codes = dyn_codes.permute(1, 0, 2).view(bs, length, self.state_dim, self.code_dim)
                    recon_seq = self.field_decoder(dyn_codes)    # [B, T, H, W, s]
                
                ground_truth_ = ground_truth[:, self.n_cond+1:self.n_cond+1+rollout_steps, ...].permute(0, 2, 3, 1, 4)
                recon_seq_ = recon_seq[:, 1:, ...].permute(0, 2, 3, 1, 4)
        
        return recon_seq_, ground_truth_    # [B, ..., rollout_steps, C]


    def illustrate_one_frame_pred(self, batch_samples, rollout_steps: int, out_dir: str):
        """
        Roll out one batch and render the last predicted step next to the ground truth.

        - batch_samples: forwarded unchanged to `self.rollout_one_batch`
        - rollout_steps: forwarded unchanged to `self.rollout_one_batch`
        - out_dir: directory joined with the base name 'one_frame_pred' to form `save_path`
          (no extension is appended here; NOTE(review): presumably the visualize utility
          picks the file type -- confirm)

        Only sample 0 of the batch is drawn (`batch_index=0`). Returns nothing; the value
        returned by `visualize_pred_vs_gt` is printed.
        Raises AssertionError when `self.spatial_dim` is not 2.
        """
        # The rollout's own comment gives the layout as [B, H, W, T, C].
        prediction, reference = self.rollout_one_batch(batch_samples, rollout_steps)
        assert self.spatial_dim == 2

        figure_target = os.path.join(out_dir, 'one_frame_pred')
        render_options = dict(
            mode="last",
            batch_index=0,
            save_path=figure_target,
            show_error=True,
            err_type="abs",
            add_colorbar=True,
            header_time_ratio=0.28,   # original author's note: smaller -> closer
            header_cols_ratio=0.34,   # original author's note: shrinks the column-title row too
            time_fontsize=12,         # original author's note: slightly smaller time label
        )
        saved_to = visualize_pred_vs_gt(prediction, reference, **render_options)
        print("Saved last-step prediction to:", saved_to)

    
    def illustrate_long_term_pred(self, batch_samples, rollout_steps: int, out_dir: str):
        """
        Roll out one batch and export the whole predicted sequence as an animation.

        - batch_samples: forwarded unchanged to `self.rollout_one_batch`
        - rollout_steps: forwarded unchanged to `self.rollout_one_batch`
        - out_dir: directory joined with 'long_term_seq.gif' to form `save_path`

        The visualize utility is called with mode="all", as_type="gif" and fps=3, for
        sample 0 of the batch only (`batch_index=0`). Returns nothing; the value returned
        by `visualize_pred_vs_gt` is printed.
        Raises AssertionError when `self.spatial_dim` is not 2.
        """
        # The rollout's own comment gives the layout as [B, H, W, T, C].
        prediction, reference = self.rollout_one_batch(batch_samples, rollout_steps)
        assert self.spatial_dim == 2

        animation_target = os.path.join(out_dir, 'long_term_seq.gif')
        render_options = dict(
            mode="all",
            batch_index=0,
            save_path=animation_target,
            as_type="gif",
            fps=3,
            show_error=True,
            err_type="abs",
            add_colorbar=True,
            header_time_ratio=0.28,
            header_cols_ratio=0.34,
        )
        saved_to = visualize_pred_vs_gt(prediction, reference, **render_options)
        print("Saved GIF to:", saved_to)


    def visualize_random_rollout(self, group: str, batch_size: int, rollout_steps: int,
                                 out_dir: str, mode: str = "last", show_error: bool = True):
        """
        Draw one batch via `self.sample_batch` and visualize its rollout.

        - group: split name handed to `self.sample_batch` (original note: "train" = seen IC,
          "train_eval", or "test" = new IC)
        - batch_size: batch size handed to `self.sample_batch`
        - rollout_steps: number of prediction steps (original note: <= T_out in data_y)
        - out_dir: output directory; created if it does not exist
        - mode: "last" -> `illustrate_one_frame_pred`; any other value ->
          `illustrate_long_term_pred`
        - show_error: accepted for interface compatibility but not read by this method

        Raises AssertionError when `self.spatial_dim` is not 2 (checked after the batch
        has been sampled).
        """
        os.makedirs(out_dir, exist_ok=True)
        sampled = self.sample_batch(group=group, batch_size=batch_size)
        assert self.spatial_dim == 2

        # Pick the renderer first, then invoke it once with the shared arguments.
        render = (
            self.illustrate_one_frame_pred
            if mode == "last"
            else self.illustrate_long_term_pred
        )
        render(sampled, rollout_steps, out_dir)

    
    def _states_to_tensor(self, states) -> torch.Tensor:
        """
        Convert states parameters to a Tensor [N, T, D] on current device.
        - nn.ParameterList -> stack + detach
        - Tensor           -> detach copy
        """
        if isinstance(states, nn.ParameterList):
            return torch.stack([p.detach().clone() for p in states], dim=0).to(self.device)
        elif torch.is_tensor(states):
            return states.detach().clone().to(self.device)
        else:
            raise TypeError(f"Unsupported states type: {type(states)}")


    def visualize_rollout_by_index(
        self,
        group: str,              # "train" | "train_eval" | "test"
        seq_index: int,          # zero-based index in the dataset order
        rollout_steps: int,
        out_dir: str,
        mode: str = "last",      # "last": plot last step; "all": export full GIF
    ):
        """
        Pick the `seq_index`-th sample (by dataset order) from the specified split,
        run a rollout of `rollout_steps` steps, and visualize the results.

        This bypasses any DataLoader shuffling by addressing the dataset directly.

        - group: split whose loader's dataset is indexed
        - seq_index: position in `loader.dataset`; must satisfy 0 <= seq_index < len(dataset)
        - rollout_steps: forwarded to the illustrate_* method
        - out_dir: output directory; created if it does not exist
        - mode: "last" -> `illustrate_one_frame_pred`; any other value ->
          `illustrate_long_term_pred`

        Raises:
            KeyError: `group` is not one of the three known split names
                (unless `self._ensure_loader` rejects it first).
            IndexError: `seq_index` is outside the dataset bounds.
        """
        os.makedirs(out_dir, exist_ok=True)

        # Ensure the corresponding loader (and hence dataset) exists
        self._ensure_loader(group)

        # Resolve the attribute *name* first and read only that attribute. Building a
        # dict of the three loader objects would touch every `self.*_loader`, and a
        # split that was never built would raise AttributeError even though it is
        # not the one requested.
        loader_attr = {
            "train": "train_loader",
            "train_eval": "train_eval_loader",
            "test": "test_loader",
        }[group]
        dataset = getattr(self, loader_attr).dataset

        n_samples = len(dataset)
        if not (0 <= seq_index < n_samples):
            raise IndexError(f"seq_index={seq_index} out of bounds (0..{n_samples-1})")

        # Build a temporary DataLoader that yields exactly this single sample (batch_size=1),
        # so the sample is collated the same way a regular batch would be.
        subset = Subset(dataset, [seq_index])
        tmp_loader = DataLoader(subset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
        batch_samples = next(iter(tmp_loader))

        # Visualize rollout
        if mode == "last":
            self.illustrate_one_frame_pred(batch_samples, rollout_steps, out_dir)
        else:
            self.illustrate_long_term_pred(batch_samples, rollout_steps, out_dir)







    