import gc
import hashlib
import json
import os
import time
from pathlib import Path
import collections.abc
from types import SimpleNamespace

import arrow
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import torchvision.transforms.v2
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from loguru import logger
from pydantic import BaseModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from einops import rearrange, repeat
import mediapy as media

from .metrics.codebook import CodebookMetric
from .metrics.loss import LossMetric
from .models import BaseOutput, PretrainedModel
from .utils.count_params import count_parameters
from .utils.json_encoder import UniversalJSONEncoder
from .utils.checkpointing import CheckpointSaver, ResumeResult, unwrap_model
from .utils.optimization import configure_optimizer
from .utils.tracking import Metrics, TrackingConfig
from .utils.inference_utils import prepare_multi_resolution_info
from .training_mode_manager import TrainingModeManager
from .trainer_probe_head import OfflineActionProber, OfflineActionProberOneLayer, ActionToVQAdapter


def dict_to_namespace(d):
    """Recursively turn nested dicts into ``SimpleNamespace`` objects.

    Dicts become namespaces whose attributes hold the converted values, lists
    are converted element by element (and stay lists), and any other value is
    returned unchanged.
    """
    if isinstance(d, list):
        return [dict_to_namespace(item) for item in d]
    if not isinstance(d, dict):
        return d
    attributes = {}
    for name, value in d.items():
        attributes[name] = dict_to_namespace(value)
    return SimpleNamespace(**attributes)

def namespace_to_dict(obj):
    """Recursively convert namespaces and containers into plain dicts/lists.

    ``SimpleNamespace`` objects and mappings become dicts, any other iterable
    except ``str``/``bytes`` becomes a list, and every remaining value is
    returned as is.
    """
    # A namespace is handled exactly like a mapping over its attribute dict.
    source = vars(obj) if isinstance(obj, SimpleNamespace) else obj
    if isinstance(source, collections.abc.Mapping):
        return {key: namespace_to_dict(value) for key, value in source.items()}
    # Strings are iterable but must stay scalar values.
    if isinstance(source, (str, bytes)) or not isinstance(source, collections.abc.Iterable):
        return source
    return [namespace_to_dict(item) for item in source]

def torch_gc():
    """Release cached CUDA memory around a Python garbage-collection pass.

    The CUDA cache is emptied once before ``gc.collect()`` and once after it.
    """
    release_cuda_cache = torch.cuda.empty_cache
    release_cuda_cache()
    gc.collect()
    release_cuda_cache()

def prepare_dataset_cfg(ds_config, wm_model_cfg):
    """Derive per-dataset sampling settings from the world-model config.

    ``ds_config`` is modified in place: the common section receives the sample
    length, padding and resolution settings, and every entry of
    ``per_dataset_config`` gets its sampling interval, micro frame size and
    valid-frame thresholds recomputed.

    Returns:
        A ``(target, ds_config)`` tuple where ``target`` is
        ``ds_config['target']`` and ``ds_config`` is the same (mutated) object.
    """
    import math
    from .datasets.auto.mixture.ego import EGO_FULL_V2

    ds_config['common_config']['sample_length'] = wm_model_cfg.sample_length

    for dataset_name, per_cfg in ds_config['per_dataset_config'].items():
        raw_interval = per_cfg['sample_interval']
        # Intervals the VAE can absorb collapse to 1; larger ones are divided
        # by the maximum compression rate (rounded up).
        if raw_interval <= wm_model_cfg.vae_max_compress_rate:
            per_cfg['sample_interval'] = 1
        else:
            per_cfg['sample_interval'] = math.ceil(raw_interval / wm_model_cfg.vae_max_compress_rate)

        per_cfg['micro_frame_size'] = round(raw_interval / per_cfg['sample_interval'])
        per_cfg['expect_valid_frames'] = wm_model_cfg.train_wm_seq_length * per_cfg['micro_frame_size']

        # Without padding every expected frame must be valid; with padding
        # half of them suffice (a quarter for "egohod_ego4d").
        if not wm_model_cfg.allow_padding_mask:
            per_cfg['min_valid_frames'] = per_cfg['expect_valid_frames']
        elif dataset_name == "egohod_ego4d":
            per_cfg['min_valid_frames'] = int(per_cfg['expect_valid_frames'] / 4)
        else:
            per_cfg['min_valid_frames'] = int(per_cfg['expect_valid_frames'] / 2)
        per_cfg['min_right_samples'] = per_cfg['min_valid_frames']

        # Datasets listed in the EGO_FULL_V2 mixture use the fixed policy for
        # both training and validation sampling.
        if dataset_name in dict(EGO_FULL_V2):
            for policy_key in ('sample_policy', 'val_sample_policy'):
                per_cfg[policy_key] = 'fixed'

        ds_config['per_dataset_config'][dataset_name] = per_cfg

    common_overrides = {
        'allow_padding': wm_model_cfg.allow_padding_mask,
        'padding_side': 'right',
        'resolution': (wm_model_cfg.input_height, wm_model_cfg.input_width),
        'export_data_statistics': False,
    }
    ds_config['common_config'].update(common_overrides)

    return ds_config['target'], ds_config


class TrainerConfigFDM(BaseModel):
    """Pydantic settings consumed by ``TrainerFDM``.

    The field order is significant: ``set_run_id_by_path`` hashes
    ``model_dump_json()`` to derive a run id, so reordering or renaming fields
    changes generated run ids.
    """

    flatten_state_action: bool = True
    n_max_state_action: int | None = None
    # Optimisation and schedule.
    gradient_accumulate_steps: int = 1
    num_workers: int = 8
    batch_size: int = 64
    base_lr: float = 1.5e-4
    min_lr: float = 0.0
    warmup_epochs: int = 2
    warmup_steps: int = 2000
    max_epochs: int = 40
    val_samples: int = 20000
    # Falls back to ``batch_size`` in ``model_post_init`` when left unset.
    val_batch_size: int | None = None
    val_interval: int = 1000
    weight_decay: float = 0.01
    betas: tuple[float, float] = (0.9, 0.95)
    scheduler: str = "linear-warmup+cosine-decay"
    optimizer: str = "adamw"
    wm_lr: float | None = None
    # Run identity and output locations.
    run_id: str | None = None
    run_dir: str = "outputs"
    # Both directories are overwritten from the environment in ``model_post_init``.
    model_save_dir: str | None = None
    ckpt_dir: str | None = None
    init_checkpoint: str | None = None
    input_keys: list[str] = ["observation.image_primary", "observation.pad_mask"]
    tracking: TrackingConfig = TrackingConfig()
    datasetwise_itv: bool = True
    start_new_wandb_run_for_new_finetune: bool = False
    seed: int = 1024
    # Gradient clipping is only applied when this is > 0.
    clip_grad: float = 0.0
    val_video_steps: int = 15
    train_prober_steps: int = 1000
    # Evaluation-only / adapter-only run modes checked by the trainer.
    validate_probing_only: bool = False
    validate_video_only: bool = False
    train_adapter_only: bool = False
    train_adapter_steps: int = 3000
    finetune_wm_adapt_realaction: bool = False
    finetune_wm_use_adapter: bool = True
    load_adapter_path: str | None = None
    realaction_add_noise_level: float = 0.0
    enable_validate_during_train: bool = False
    dataset_target: str | None = None
    validate_bootstrap_only: bool = False
    wm_resize_obs_64: bool = False
    load_idm_only: bool = False

    def model_post_init(self, __context):
        """Propagate the run id to the tracking config and fill derived fields.

        The checkpoint directories are always taken from the ``SYNC_DIR`` and
        ``OUTPUT_DIR`` environment variables (defaulting to the current
        directory), replacing whatever was configured.
        """
        run_id = self.run_id
        tracking = self.tracking
        tracking.wandb_resume_id = run_id
        tracking.wandb_run_name = run_id
        tracking.wandb_dir = run_id
        self.model_save_dir = Path(os.environ.get("SYNC_DIR", "./")) / "ckpt"
        self.ckpt_dir = Path(os.environ.get("OUTPUT_DIR", "./")) / "ckpt"
        tracking.log_file = f"{run_id}.log"
        if self.val_batch_size is None:
            self.val_batch_size = self.batch_size

    def set_run_id_by_path(self, path: str, checksum_input: str = ""):
        """Generate ``run_id`` as ``<UTC timestamp>-run-<checksum>`` if unset.

        The checksum is the first six hex digits of the MD5 of
        ``checksum_input`` followed by this config's JSON dump. ``path`` is
        accepted but not used. An existing ``run_id`` is left untouched.
        """
        if self.run_id is not None:
            return
        digest_source = checksum_input + self.model_dump_json()
        short_digest = hashlib.md5(digest_source.encode()).hexdigest()[:6]
        timestamp = arrow.utcnow().strftime("%Y%m%d%H%M%S")
        self.run_id = f'{timestamp}-{f"run-{short_digest}"}'

    @classmethod
    def load(cls, path: str, use_base_dir_as_run_id: bool = True, checksum_input: str = ""):
        """Load and validate a config from the YAML file at ``path``.

        When ``use_base_dir_as_run_id`` is true and the file sets no
        ``run_id``, one is generated via ``set_run_id_by_path``.
        ``model_post_init`` is then run again so the tracking fields pick up
        the final run id.
        """
        import yaml

        with open(path) as f:
            raw_config = yaml.safe_load(f)
        config = cls.model_validate(raw_config)
        needs_run_id = use_base_dir_as_run_id and config.run_id is None
        if needs_run_id:
            config.set_run_id_by_path(path, checksum_input)
        config.model_post_init(None)
        return config


def resolve_keys(x: dict, keys: list[str]) -> list:
    """Look up dotted key paths in a nested mapping.

    Each entry of ``keys`` (e.g. ``"observation.pad_mask"``) is split on
    ``"."`` and followed one level at a time starting from ``x``. The values
    are returned in the same order as ``keys``; a missing part raises whatever
    the indexed container raises (``KeyError`` for dicts).
    """

    def _follow(dotted_key):
        node = x
        for segment in dotted_key.split("."):
            node = node[segment]
        return node

    return [_follow(dotted_key) for dotted_key in keys]


class TrainerFDM:
    def __init__(
        self,
        model: PretrainedModel,
        input_keys: list[str] | None = None,
        train_dataloader: DataLoader | None = None,
        val_dataloader: DataLoader | None = None,
        cfg: TrainerConfigFDM | None = None,
        config_to_log: dict | None = None,
    ):
        """Store the model, data loaders and configuration for later ``setup``.

        Args:
            model: Model to train or evaluate.
            input_keys: Batch keys fed to the model; defaults to the trainer
                config's ``input_keys``.
            train_dataloader: Training loader (may be ``None``).
            val_dataloader: Validation loader (may be ``None``).
            cfg: Trainer configuration; a default ``TrainerConfigFDM`` is
                created when omitted.
            config_to_log: Configuration sections to log. Despite the ``None``
                default it is required in practice: its ``'dataset'`` and
                ``'model'`` entries are read here.

        Raises:
            TypeError: If ``config_to_log`` is ``None``.
            KeyError: If ``config_to_log`` lacks ``'dataset'`` or ``'model'``.
        """
        if config_to_log is None:
            # Indexing None used to fail with an opaque "'NoneType' object is
            # not subscriptable"; keep the exception type, explain the cause.
            raise TypeError(
                "config_to_log must be a dict providing 'dataset' and 'model' entries"
            )
        self._cfg = cfg or TrainerConfigFDM()
        self.dataset_cfg, self.model_cfg = config_to_log['dataset'], config_to_log['model']
        self._model = model
        self._resume_step = 0
        self._train_dl = train_dataloader
        self._val_dl = val_dataloader
        # Use the resolved config: ``cfg`` itself may be None when the caller
        # relies on the default trainer configuration.
        self._input_keys = input_keys or self._cfg.input_keys
        self._configs_to_log = config_to_log
        self.training_mode_manager = TrainingModeManager(self.model_cfg)

    def setup(self):
        """Prepare everything needed before training or evaluation starts.

        Creates the ``Accelerator``, configures logging per rank, builds the
        optimizer/scheduler and checkpoint saver, restores state from a
        resumable checkpoint or from ``cfg.init_checkpoint``, wraps model,
        optimizer and scheduler with ``accelerator.prepare``, initialises the
        metrics tracker and, when real-action fine-tuning with a saved adapter
        is configured, loads that adapter in eval mode.
        """
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
        self._device_id = self.accelerator.device
        self._world_size = self.accelerator.num_processes
        self._is_rank_zero = self.accelerator.is_main_process

        # Only the main process logs: other ranks drop every loguru sink,
        # the main rank additionally writes to a daily-rotated log file.
        if not self._is_rank_zero:
            logger.remove()
        else:
            logger.add(self._cfg.tracking.log_file, rotation="1 day")

        self._train_batches = len(self._train_dl) if self._train_dl is not None else 0
        self._val_batches = len(self._val_dl) if self._val_dl is not None else 0
        # NOTE(review): the directory created here is named after run_id;
        # cfg.run_dir is not consulted, and a None run_id makes os.makedirs
        # raise TypeError — confirm this is intended.
        run_dir = self._cfg.run_id
        os.makedirs(run_dir, exist_ok=True)

        # Samples consumed per optimizer update across all processes.
        self._effective_bsz = (
            self._cfg.batch_size * self._world_size * self._cfg.gradient_accumulate_steps
        )

        # Optimizer updates per epoch; warmup is capped at warmup_epochs worth of updates.
        gradient_steps = self._train_batches // self._cfg.gradient_accumulate_steps
        warmup_steps = min(self._cfg.warmup_steps, self._cfg.warmup_epochs * gradient_steps)
        self._optimizer, self._lr_scheduler = configure_optimizer(
            model=self._model,
            base_lr=self._cfg.base_lr,
            batch_size=self._effective_bsz,
            min_lr=self._cfg.min_lr,
            warmup_steps=warmup_steps,
            max_steps=self._cfg.max_epochs * gradient_steps,
            weight_decay=self._cfg.weight_decay,
            betas=self._cfg.betas,
            scheduler=self._cfg.scheduler,
            optimizer=self._cfg.optimizer,
            wm_lr=self._cfg.wm_lr,
        )

        self.saver = CheckpointSaver(
            local_ckpt_dir=self._cfg.model_save_dir,
            ckpt_dir=self._cfg.ckpt_dir,
            strategy=self._cfg.tracking.save_strategy,
            steps=self._cfg.tracking.save_steps,
            n_batches=self._train_batches,
            is_rank_zero=self._is_rank_zero,
            run_id=self._cfg.run_id,
        )
        # A checkpoint found by the saver takes precedence over init_checkpoint.
        resume = self.saver.resume(
            model=self._model, opt=self._optimizer, scheduler=self._lr_scheduler
        )
        if resume.checkpoint is not None:
            self._cfg.tracking.wandb_resume_id = resume.run_id
        elif self._cfg.init_checkpoint is not None:
            if not self._cfg.start_new_wandb_run_for_new_finetune:
                # Continue the original run: restore model, optimizer and
                # scheduler through the saver and reuse its wandb run id.
                resume = self.saver.load(self._cfg.init_checkpoint, model=self._model, opt=self._optimizer, scheduler=self._lr_scheduler)
                if resume.checkpoint is not None:
                    self._cfg.tracking.wandb_resume_id = resume.run_id
            else:
                # Fresh run: load model weights only, from a safetensors file.
                from safetensors.torch import load_file
                # init_checkpoint is either the .safetensors file itself or a
                # directory containing "model.safetensors".
                filename = Path(self._cfg.init_checkpoint) / "model.safetensors" \
                    if not self._cfg.init_checkpoint.endswith(".safetensors") else Path(self._cfg.init_checkpoint)
                state_dict = load_file(str(filename))
                # Rename the codebook weight when the checkpoint uses the
                # 'vq.' prefix but the current model uses 'quantizer.'.
                if 'vq.embedding.weight' in state_dict and 'quantizer.embedding.weight' in self._model.state_dict():
                    state_dict['quantizer.embedding.weight'] = state_dict.pop('vq.embedding.weight')
                if self._cfg.load_idm_only:
                    # Drop every 'wm.'-prefixed weight before loading.
                    state_dict = {k: v for k, v in state_dict.items() if not k.startswith('wm.')}
                missing, unexpected = self._model.load_state_dict(state_dict, strict=False)

                # Without any 'wm*' weights in the checkpoint, only 'wm*' keys
                # may be missing and only 'decoder*'/'prober*' keys may be
                # unexpected; otherwise the load must match exactly.
                if len([k for k in state_dict if k.startswith('wm')]) == 0:
                    for _ in missing:
                        assert _.startswith('wm')
                    for _ in unexpected:
                        assert _.startswith('decoder') or _.startswith('prober')
                else:
                    assert len(missing) == 0, "no missing states allowed"
                    assert len(unexpected) == 0, "no unexpected states allowed"

                # Training restarts from step 0 with no inherited run id.
                resume = ResumeResult(checkpoint=self._cfg.init_checkpoint, step=0, duration=0, run_id=None)

        # The EMA object is detached while accelerator.prepare runs and is
        # re-attached (moved to this process's device) afterwards.
        # NOTE(review): presumably this keeps the EMA copy out of the DDP
        # wrapper — confirm. ema.to(...) fails if wm.ema is None.
        ema = self._model.wm.ema
        self._model.wm.ema = None
        (
            self._model,
            self._optimizer,
            self._lr_scheduler,
        ) = self.accelerator.prepare(
            self._model,
            self._optimizer,
            self._lr_scheduler,
        )
        unwrapped_model = unwrap_model(self._model)
        unwrapped_model.wm.ema = ema.to(self._device_id)

        # Caller-supplied sections are merged last and override the keys above.
        config_to_log = {
            "model": unwrap_model(self._model).__class__.__name__,
            "model_parameters": count_parameters(self._model),
            "world_size": self._world_size,
            "effective_bsz": self._effective_bsz,
            "resume_step": resume.step,
            "resume_checkpoint": resume.checkpoint,
            "trainer": self._cfg.model_dump(),
            **self._configs_to_log,
        }
        self.metrics = Metrics(
            self._cfg.tracking,
            run_id=self._cfg.run_id,
            config_to_log=config_to_log,
            is_rank_zero=self._is_rank_zero,
            window=self._cfg.gradient_accumulate_steps,
        )
        # NOTE(review): msg is serialised but never logged or otherwise used —
        # a logging call was presumably intended here.
        if self._is_rank_zero:
            msg = json.dumps(config_to_log, indent=4, ensure_ascii=False, cls=UniversalJSONEncoder)

        # resume.step is divided by the accumulation factor before being
        # handed to the metrics tracker.
        self.metrics.resume(
            step=resume.step // self._cfg.gradient_accumulate_steps,
            resume_time=resume.duration,
        )
        self.metrics.commit("train", lr=self._lr_scheduler.get_last_lr())
        self._resume_step = resume.step

        # Optional adapter used when fine-tuning the world model on real
        # actions; it is loaded strictly and kept in eval mode.
        if self._cfg.finetune_wm_adapt_realaction and self._cfg.load_adapter_path is not None:
            self._adapter = ActionToVQAdapter(
                real_action_dim=self.model_cfg['probe_action_dim'],
                n_real_actions=self._cfg.n_max_state_action,
                num_learned_tokens=self.model_cfg['num_learned_tokens'],
                vq_codebook_size=self.model_cfg['n_codes'],
                hidden_dim=self.model_cfg['probe_dim']
            )
            # NOTE(review): torch.load unpickles its input; only point
            # load_adapter_path at trusted files.
            self._adapter.load_state_dict(
                torch.load(self._cfg.load_adapter_path, map_location=lambda storage, loc: storage), strict=True
            )
            self._adapter = self.accelerator.prepare(self._adapter)
            self._adapter.eval()

    def train(self):
        """Run ``_train`` and log any failure before propagating it.

        Raises:
            Exception: Whatever ``_train`` raised, re-raised unchanged after
                it has been logged.
        """
        try:
            self._train()
        except Exception as e:
            # logger.exception already records the active traceback, so no
            # manual formatting is needed; the bare ``raise`` re-raises the
            # same exception without adding this frame a second time.
            logger.exception(e)
            raise

    def _train(self):
        """Set up the trainer, then run a one-shot mode or the training loop.

        After ``setup()``, each enabled one-shot mode (``validate_video_only``,
        ``validate_probing_only``, ``train_adapter_only``,
        ``validate_bootstrap_only``) is executed in that order and the method
        returns without training. Otherwise the loop runs from the resumed step
        up to ``max_epochs * len(train_dataloader)`` micro-steps, stepping the
        optimizer every ``gradient_accumulate_steps`` micro-steps, optionally
        validating, and calling the checkpoint saver after every micro-step.
        """
        self.setup()
        epoch = self._resume_step // self._train_batches
        status = self.metrics.get_status(epoch)
        pbar = tqdm(
            total=self._cfg.max_epochs * self._train_batches,
            initial=self._resume_step,
            desc=status,
            leave=False,
            disable=not self._is_rank_zero,
        )
        # NOTE(review): the iterator is created once and never rebuilt, so the
        # loop below relies on the loader yielding at least
        # max_epochs * len(loader) batches — confirm the loader repeats.
        train_iter = iter(self._train_dl)

        job_name = os.environ.get("JOB_NAME", "local")
        output_dir = Path(os.environ.get("OUTPUT_DIR", "."))

        def gather_and_finalize(values):
            # Gather every metric across processes and reduce with LossMetric.
            # This is a collective call: all ranks must execute it.
            metric = LossMetric()
            for k, v in values.items():
                metric.add(k, self.accelerator.gather(torch.tensor(v, device=self._device_id)))
            return metric.finalize()

        def append_report(path, line):
            # Only the main process writes the report file.
            if self._is_rank_zero:
                with open(path, "a") as f:
                    f.write(line)

        if self._cfg.validate_video_only:
            metrics_write_path = output_dir / f"metrics_video_{job_name}.txt"
            eval_on_real_action = self._cfg.finetune_wm_adapt_realaction and self._cfg.load_adapter_path is not None
            eval_metrics = self.evaluate(0, self._val_dl, infer_use_real_action=eval_on_real_action)
            info = gather_and_finalize(eval_metrics)
            append_report(metrics_write_path, f"{job_name}, {self._cfg.init_checkpoint}, dataset metrics: {info}\n")
            time.sleep(10)

        if self._cfg.validate_probing_only:
            metrics_write_path = output_dir / f"metrics_probing_{job_name}.txt"
            probing_info = self.evaluate_probing(0, self._train_dl, self._val_dl)
            info = gather_and_finalize(probing_info)
            append_report(metrics_write_path, f"{job_name}, {self._cfg.init_checkpoint}, probing_info: {info}\n")
            time.sleep(10)

        if self._cfg.train_adapter_only:
            adapter_info = self.train_adapter_main(0, self._train_dl, self._val_dl)
            # The reduced result is not written anywhere; the gather is still
            # executed on every rank, as before.
            gather_and_finalize(adapter_info)
            time.sleep(10)

        if self._cfg.validate_bootstrap_only:
            metrics_write_path = output_dir / f"metrics_video_{job_name}.txt"
            bootstrap_metrics = self.evaluate_bootstrap(0, self._val_dl)
            # evaluate_bootstrap returns one metrics dict per bootstrap round
            # (a list); calling .items() on it raised AttributeError. Flatten
            # the rounds into "bootstrap<round>_<metric>" keys; a plain dict is
            # passed through unchanged.
            if isinstance(bootstrap_metrics, dict):
                flat_metrics = bootstrap_metrics
            else:
                flat_metrics = {
                    f"bootstrap{round_idx}_{k}": v
                    for round_idx, round_metrics in enumerate(bootstrap_metrics)
                    for k, v in round_metrics.items()
                }
            info = gather_and_finalize(flat_metrics)
            append_report(metrics_write_path, f"{job_name}, {self._cfg.init_checkpoint}, dataset metrics: {info}\n")
            time.sleep(10)

        if self._cfg.validate_video_only or self._cfg.validate_probing_only or self._cfg.train_adapter_only or self._cfg.validate_bootstrap_only:
            return

        step_s_t = time.time()
        for step in range(self._resume_step, self._cfg.max_epochs * self._train_batches):
            self._model.train()
            load_s_t = time.time()
            batch = next(train_iter)
            self.metrics.commit("train", load_time=time.time() - load_s_t)

            t_tmp = time.time()
            if self._cfg.finetune_wm_adapt_realaction and self._cfg.load_adapter_path is not None and self._cfg.finetune_wm_use_adapter:
                # Real actions are mapped to latent actions by the loaded adapter.
                obs4lam, obs4vaewm, pad_mask_in, obs4vaewm_mask, real_action, _, _ = self.process_batch_withaction(batch)
                latent_action = self.adapter_predict(self._adapter, real_action, pad_mask_in)
                # Merge the last two dims into one and insert a singleton dim before it.
                latent_action = latent_action.reshape(*latent_action.shape[:-2], 1, -1)
            else:
                obs4lam, obs4vaewm, pad_mask_in, obs4vaewm_mask = self.process_batch(batch)
                latent_action = None
            self.metrics.commit("train", processbatch_time=time.time() - t_tmp)

            # Push the per-step component state from the mode manager to the model.
            component_state = self.training_mode_manager.get_component_state(step)
            self.accelerator.unwrap_model(self._model).update_training_state(component_state)

            # NOTE(review): stage_info is fetched but not used; a logging call
            # was presumably intended. The calls are kept in case they have
            # side effects inside the mode manager.
            if self.training_mode_manager.should_log_stage_info(step, log_interval=1):
                stage_info = self.training_mode_manager.get_current_stage_info()

            t_tmp = time.time()
            with self.accelerator.autocast():
                output: BaseOutput = self._model(**{"obs4lam": obs4lam, "obs4vaewm": obs4vaewm, "pad_mask_in": pad_mask_in, "obs4vaewm_mask": obs4vaewm_mask, "latent_action": latent_action})
            self.metrics.commit("train", forward_time=time.time() - t_tmp)

            self.metrics.commit("train", values={"loss": output.loss, **output.loss_dict})
            if hasattr(output, "codebook_indices") and output.codebook_indices is not None and unwrap_model(self._model).codebook_size is not None:
                cb_metric = CodebookMetric(unwrap_model(self._model).codebook_size).add(output.codebook_indices.cpu())
                self.metrics.commit("train", values=cb_metric.finalize())
            if hasattr(unwrap_model(self._model), "codebook") and unwrap_model(self._model).codebook is not None:
                self.metrics.commit(
                    "train",
                    values={
                        "codebook/mean": unwrap_model(self._model).codebook.mean().item(),
                        "codebook/std": unwrap_model(self._model).codebook.std().item(),
                    },
                )

            # Scale the loss by the accumulation factor before backward.
            loss = output.loss / self._cfg.gradient_accumulate_steps
            t_tmp = time.time()
            self.accelerator.backward(loss)
            self.metrics.commit("train", backward_time=time.time() - t_tmp)

            # val_interval counts optimizer updates, hence the conversion to micro-steps.
            should_validate = (step + 1) % (
                self._cfg.val_interval * self._cfg.gradient_accumulate_steps
            ) == 0 and self._cfg.enable_validate_during_train

            if (step + 1) % self._cfg.gradient_accumulate_steps == 0:
                t_tmp = time.time()

                if self._cfg.clip_grad > 0:
                    self.accelerator.clip_grad_norm_(self._model.parameters(), max_norm=self._cfg.clip_grad)
                self._optimizer.step()
                # NOTE(review): this relies on the scheduler's step() returning
                # the learning rate; stock torch schedulers return None — confirm.
                lr = self._lr_scheduler.step()
                self._optimizer.zero_grad()
                self.metrics.commit("train", optimizer_time=time.time() - t_tmp)

                t_tmp = time.time()
                self.accelerator.unwrap_model(self._model).post_process()
                self.metrics.commit("train", updateema_time=time.time() - t_tmp)

                self.metrics.commit("train", step_time=time.time() - step_s_t, lr=lr)
                step_s_t = time.time()
                # On validation steps the push happens below via push_values(push_all=True).
                if not should_validate:
                    status = self.metrics.push(epoch, "train")

            pbar.update()
            pbar.set_description(status)

            if (step + 1) % self._train_batches == 0:
                epoch += 1

            if should_validate:
                info = self.validate(step + 1)
                self.metrics.push_values(epoch, info, group="val", push_all=True)
                # Restart the step timer so validation time is not counted as training time.
                step_s_t = time.time()

            # The saver is invoked after every micro-step.
            s_t = time.time()
            self.saver.save(
                step=step + 1,
                model=self._model,
                optimizer=self._optimizer,
                scheduler=self._lr_scheduler,
                duration=self.metrics.duration,
                train_loss=self.metrics.get("loss", "train"),
                val_loss=self.metrics.get("loss", "val"),
            )
            self.metrics.commit("train", checkpoint_time=time.time() - s_t)

        self.metrics.finalize()

    def validate(self, step):
        """Run video evaluation for ``step`` and return the reduced metrics.

        Evaluation only happens when the component state for ``step`` has
        ``train_fdm`` set. With real-action fine-tuning and a loaded adapter,
        the latent-action metrics are additionally recorded under a ``gtlam_``
        prefix and a second evaluation using real actions is recorded under a
        ``realaction_`` prefix.
        """
        self._model.eval()
        torch_gc()
        metric = LossMetric()

        def _record(name, value):
            # Gather the per-process value across ranks before accumulating it.
            metric.add(name, self.accelerator.gather(torch.tensor(value, device=self._device_id)))

        stage_state = self.training_mode_manager.get_component_state(step)

        if stage_state.train_fdm:
            use_real_action = self._cfg.finetune_wm_adapt_realaction and self._cfg.load_adapter_path is not None
            latent_metrics = self.evaluate(step, self._val_dl)
            for name, value in latent_metrics.items():
                _record(name, value)
                if use_real_action:
                    _record(f"gtlam_{name}", value)
            if use_real_action:
                real_metrics = self.evaluate(step, self._val_dl, infer_use_real_action=True)
                for name, value in real_metrics.items():
                    _record(f"realaction_{name}", value)
            torch_gc()

        return metric.finalize()

    @torch.no_grad()
    def evaluate(self, train_step, eval_dataloader, infer_use_real_action=False):
        """Compute video reconstruction metrics over ``cfg.val_video_steps`` batches.

        Per-batch metrics returned by ``visualize_reconstruction`` are averaged
        over the batches, and ``'fvd'`` is computed once from the concatenated
        FVD features of all batches. Gradients are disabled for the whole call
        by the decorator.

        Returns:
            Dict mapping metric name to its value.
        """
        self._model.eval()
        batches = iter(eval_dataloader)
        metric_history = {name: [] for name in ('psnr', 'ssim', 'lpips', 'fvd')}
        recon_feats, orig_feats = [], []

        import lpips
        self.lpips_loss_fn = lpips.LPIPS(net='alex', version='0.1').to(self._device_id)

        fvd_method = "styleganv"
        if fvd_method == 'videogpt':
            from fvd.videogpt.fvd import load_i3d_pretrained
        elif fvd_method == 'styleganv':
            from fvd.styleganv.fvd import load_i3d_pretrained
        # NOTE(review): the I3D network is placed on "cuda:0" on every process,
        # unlike the LPIPS network which follows self._device_id — confirm.
        self.fvd_i3d = load_i3d_pretrained(device="cuda:0")

        def keep_first_half(node):
            # Tensors and lists are cut to their first half; dicts are
            # processed in place, value by value; anything else is untouched.
            if isinstance(node, dict):
                for key in node:
                    node[key] = keep_first_half(node[key])
            elif isinstance(node, (torch.Tensor, list)):
                node = node[:len(node) // 2]
            return node

        for eval_step in range(1, self._cfg.val_video_steps + 1):
            batch_traj = next(batches)
            if self._cfg.dataset_target == "egofull_v2":
                batch_traj = keep_first_half(batch_traj)

            batch_metrics, recon_batch, orig_batch, _ = self.visualize_reconstruction(
                train_step,
                batch_traj,
                return_fvd_feat=True,
                extra_name=f'testdata{eval_step}',
                infer_use_real_action=infer_use_real_action,
            )
            # Rebuilt from the batch's keys, so only metrics the batch reports are kept.
            metric_history = {name: metric_history[name] + [value] for name, value in batch_metrics.items()}
            recon_feats.append(recon_batch)
            orig_feats.append(orig_batch)

        eval_metrics = {name: np.array(values).mean() for name, values in metric_history.items()}
        all_recon_feats = np.concatenate(recon_feats, axis=0)
        all_orig_feats = np.concatenate(orig_feats, axis=0)
        eval_metrics['fvd'] = self.get_fvd_metric(all_recon_feats, all_orig_feats, get_feats=False, get_metric=True)

        return eval_metrics

    @torch.no_grad()
    def evaluate_bootstrap(self, train_step, eval_dataloader, infer_use_real_action=False):
        """Evaluate reconstruction quality when the model is fed its own output.

        For each of ``cfg.val_video_steps`` batches the model reconstructs the
        video 10 times in a row; after every round the batch's
        ``observation.image_primary`` is replaced by the reconstruction, while
        a clone of the initial frames is passed along as ``gt_video``.
        Gradients are disabled for the whole call by the decorator.

        Returns:
            A list with one dict per bootstrap round, mapping metric name to
            its mean over the evaluated batches.
        """
        self._model.eval()
        batches = iter(eval_dataloader)
        bootstrap_steps = 10
        metric_names = ['psnr', 'ssim', 'lpips', 'fvd']
        round_history = [{name: [] for name in metric_names} for _ in range(bootstrap_steps)]

        import lpips
        self.lpips_loss_fn = lpips.LPIPS(net='alex', version='0.1').to(self._device_id)

        fvd_method = "styleganv"
        if fvd_method == 'videogpt':
            from fvd.videogpt.fvd import load_i3d_pretrained
        elif fvd_method == 'styleganv':
            from fvd.styleganv.fvd import load_i3d_pretrained
        # NOTE(review): the I3D network is placed on "cuda:0" on every process,
        # unlike the LPIPS network which follows self._device_id — confirm.
        self.fvd_i3d = load_i3d_pretrained(device="cuda:0")

        def keep_first_half(node):
            # Tensors and lists are cut to their first half; dicts are
            # processed in place, value by value; anything else is untouched.
            if isinstance(node, dict):
                for key in node:
                    node[key] = keep_first_half(node[key])
            elif isinstance(node, (torch.Tensor, list)):
                node = node[:len(node) // 2]
            return node

        for eval_step in range(1, self._cfg.val_video_steps + 1):
            batch_traj = next(batches)
            if self._cfg.dataset_target == "egofull_v2":
                batch_traj = keep_first_half(batch_traj)

            gt_video = batch_traj['observation']['image_primary'].clone()
            for round_idx in range(bootstrap_steps):
                batch_metrics, _, _, recon_video = self.visualize_reconstruction(
                    train_step,
                    batch_traj,
                    return_fvd_feat=True,
                    extra_name=f'testdata{eval_step}',
                    infer_use_real_action=infer_use_real_action,
                    gt_video=gt_video,
                )
                round_history[round_idx] = {
                    name: round_history[round_idx][name] + [value] for name, value in batch_metrics.items()
                }
                # Feed the reconstruction back in as the next round's input,
                # keeping the dtype and device of the current frames.
                current_frames = batch_traj['observation']['image_primary']
                batch_traj['observation']['image_primary'] = torch.as_tensor(
                    recon_video, dtype=current_frames.dtype, device=current_frames.device
                )

        return [
            {name: np.array(values).mean() for name, values in history.items()}
            for history in round_history
        ]
    
    def evaluate_probing(self, train_step, train_dataloader, eval_dataloader):
        """Train a one-layer action prober and report its validation losses.

        The prober is built from ``self.model_cfg``, fitted with
        ``evaluate_probing_training`` and scored with
        ``evaluate_probing_validation``. ``train_step`` is accepted but unused.

        Returns:
            Dict of the validation metrics with ``"_onelayer"`` appended to
            every key.
        """
        self._model.eval()

        model_cfg = self.model_cfg
        probe = OfflineActionProberOneLayer(
            action_latent_dim=model_cfg['action_latent_dim'],
            num_learned_tokens=model_cfg['num_learned_tokens'],
            n_probe_actions=model_cfg['n_probe_actions'],
            probe_action_dim=model_cfg['probe_action_dim'],
        )
        self.evaluate_probing_training(probe, train_dataloader)
        scores = self.evaluate_probing_validation(probe, eval_dataloader)
        return {f"{name}_onelayer": value for name, value in scores.items()}

    def evaluate_probing_training(self, prober, train_dataloader):
        """Fit ``prober`` for ``cfg.train_prober_steps`` optimisation steps.

        The optimizer and scheduler are built with the trainer's own
        hyper-parameters (no warmup) and discarded when training ends. The
        loss for each batch comes from ``get_probing_loss``.
        """
        cfg = self._cfg
        optimizer, scheduler = configure_optimizer(
            model=prober,
            base_lr=cfg.base_lr,
            batch_size=self._effective_bsz,
            min_lr=cfg.min_lr,
            warmup_steps=0,
            max_steps=cfg.max_epochs * self._train_batches,
            weight_decay=cfg.weight_decay,
            betas=cfg.betas,
            scheduler=cfg.scheduler,
            optimizer=cfg.optimizer,
        )
        prober, optimizer, scheduler = self.accelerator.prepare(prober, optimizer, scheduler)

        batches = iter(train_dataloader)
        for _step in range(self._cfg.train_prober_steps):
            prober.train()
            action_loss, _l1_unused = self.get_probing_loss(prober, next(batches))

            optimizer.zero_grad()
            self.accelerator.backward(action_loss)
            optimizer.step()
            scheduler.step()

        del optimizer, scheduler
        torch_gc()

    @torch.no_grad()
    def evaluate_probing_validation(self, prober, eval_dataloader):
        """Score ``prober`` on validation batches and return the reduced losses.

        The number of batches is ``val_samples / val_batch_size / world_size``
        truncated to an integer. Gradients are disabled for the whole call by
        the decorator.

        Returns:
            The finalized ``LossMetric`` holding ``"probing_loss"`` and
            ``"probing_loss_l1"``.
        """
        torch_gc()
        prober.eval()
        batches = iter(eval_dataloader)
        metric = LossMetric()

        val_steps = int(self._cfg.val_samples / self._cfg.val_batch_size / self._world_size)
        for _step in range(val_steps):
            mse_loss, l1_loss = self.get_probing_loss(prober, next(batches))
            metric.add("probing_loss", mse_loss)
            metric.add("probing_loss_l1", l1_loss)

        # Drop the local reference before the cache is cleared.
        del prober
        torch_gc()
        return metric.finalize()

    def get_probing_loss(self, prober, batch):
        """Compute masked MSE and L1 losses of ``prober`` on one batch.

        The model is run without gradients to obtain action tokens; tokens of
        clips flagged in ``batch['has_action']`` are fed to ``prober`` and
        compared with the matching real actions under the combined
        state/action and valid-action masks.

        Returns:
            ``(action_loss, action_l1_loss)``: masked mean squared error and
            masked mean absolute error.
        """
        (
            obs4lam,
            obs4vaewm,
            pad_mask_in,
            obs4vaewm_mask,
            action,
            state_action_mask,
            valid_action_mask,
        ) = self.process_batch_withaction(batch)

        model_inputs = {
            "obs4lam": obs4lam,
            "obs4vaewm": obs4vaewm,
            "pad_mask_in": pad_mask_in,
            "obs4vaewm_mask": obs4vaewm_mask,
        }
        with self.accelerator.autocast(), torch.no_grad():
            output = self._model(**model_inputs, only_return_actions=True)
        action_tokens = output.action_tokens
        clip_len = pad_mask_in.sum(1).tolist()
        has_state_action = batch['has_action']

        # Each clip contributes (length - 1) action tokens along dim 0; keep
        # only the chunks of clips that carry actions.
        tokens_per_clip = torch.split(action_tokens, [n - 1 for n in clip_len], dim=0)
        la = torch.cat(
            [chunk for chunk, flagged in zip(tokens_per_clip, has_state_action) if flagged]
        ).squeeze(1)
        pclip_len = [n for n, flagged in zip(clip_len, has_state_action) if flagged]

        def leading_entries(sequences, lengths, drop):
            # For right-padded sequences keep the first (length - drop) entries of each.
            assert len(sequences) == len(lengths), "Length of x and k must be the same"
            return [sequences[i][: lengths[i] - drop] for i in range(len(lengths))]

        action_out = torch.cat(leading_entries(action, pclip_len, 1), dim=0)
        action_out_mask = torch.cat(leading_entries(state_action_mask, pclip_len, 1), dim=0)

        # Repeat every clip-level row (length - 1) times so it lines up with the tokens.
        repeats = torch.tensor(pclip_len).to(valid_action_mask.device) - 1
        valid_action_mask = torch.repeat_interleave(valid_action_mask, repeats=repeats, dim=0).long()

        pred_action = prober(la)
        squared_error = F.mse_loss(pred_action, action_out, reduction='none')
        action_loss_mask = action_out_mask.unsqueeze(-1) * valid_action_mask.unsqueeze(1)
        action_loss = (squared_error * action_loss_mask).sum() / (action_loss_mask.sum() + 1e-8)

        absolute_error = F.l1_loss(pred_action, action_out, reduction='none')
        action_l1_loss = (absolute_error * action_loss_mask).sum() / (action_loss_mask.sum() + 1e-8)
        return action_loss, action_l1_loss

    def process_batch_withaction(self, batch_traj):
        """Build model inputs and real-action targets from a trajectory batch.

        For every sample the observation stream is turned into:
          * ``obs4lam`` / ``pad_mask_in``: the last frame (and its pad flag) of
            each group of ``micro_frame_size`` frames, up to
            ``expect_valid_frames``;
          * ``obs4vaewm`` / ``obs4vaewm_mask``: all frames regrouped per
            world-model timestep and left-padded with zero frames so that each
            group has ``wm.vae_max_compress_rate`` slots, with a boolean mask
            marking the real (non-padded, valid) slots.

        Actions are regrouped the same way, reduced to the first
        ``state_per_obs`` entries of each frame, then truncated or zero-padded
        to ``self._cfg.n_max_state_action`` entries per timestep;
        ``state_action_mask`` is 0 on the zero-padded entries.

        Args:
            batch_traj: mapping read with the keys ``dataset_name``,
                ``observation`` (``image_primary``, ``pad_mask``), ``action``,
                ``has_action``, ``state_per_obs`` and ``valid_action_mask``.

        Returns:
            Tuple ``(obs4lam, obs4vaewm, pad_mask_in, obs4vaewm_mask,
            action4probingtarget, state_action_mask, valid_action_mask)``, all
            moved to ``self._device_id``.

        Raises:
            AssertionError: if a dataset's frame configuration is inconsistent
                with ``wm.train_wm_seq_length`` or a sample has fewer valid
                frames than ``min_valid_frames``.
        """
        wm = self.accelerator.unwrap_model(self._model).wm
        obs4lam, obs4vaewm, pad_mask_in, obs4vaewm_mask, action4probingtarget, state_action_mask = [], [], [], [], [], []

        for dataset_name, obs_i, pad_mask_i in zip(batch_traj['dataset_name'],
                                                batch_traj["observation"]["image_primary"],
                                                batch_traj["observation"]["pad_mask"]):
            sample_valid_frames = pad_mask_i.sum().item()
            expect_valid_frames = self.per_dataset_config[dataset_name]['expect_valid_frames']
            min_valid_frames = self.per_dataset_config[dataset_name]['min_valid_frames']
            micro_frame_size = self.per_dataset_config[dataset_name]['micro_frame_size']
            assert expect_valid_frames == micro_frame_size * wm.train_wm_seq_length
            assert sample_valid_frames >= min_valid_frames

            valid_frames = min(sample_valid_frames, expect_valid_frames)
            # Keep the last frame of every micro-frame group as the LAM observation.
            obs4lam.append(obs_i[micro_frame_size - 1: expect_valid_frames: micro_frame_size])
            pad_mask_in.append(pad_mask_i[micro_frame_size - 1: expect_valid_frames: micro_frame_size])
            # Number of groups whose last frame lies inside the valid prefix.
            num_valid_wm_timesteps = obs_i[micro_frame_size - 1: valid_frames: micro_frame_size].shape[0]
            assert num_valid_wm_timesteps == valid_frames // micro_frame_size
            assert pad_mask_i[micro_frame_size - 1: expect_valid_frames: micro_frame_size].sum().item() == num_valid_wm_timesteps

            # Regroup frames as (timestep, frame-in-group, <last three dims>).
            tmp = obs_i[:expect_valid_frames].reshape(wm.train_wm_seq_length, micro_frame_size, *obs_i.shape[-3:])
            mask_2d = torch.zeros(wm.train_wm_seq_length, wm.vae_max_compress_rate,
                                device=obs_i.device, dtype=torch.bool)
            # Only the trailing micro_frame_size slots of valid timesteps are real data.
            mask_2d[:num_valid_wm_timesteps, -micro_frame_size:] = True
            # F.pad's tuple runs from the last dim backwards, so the only
            # non-zero entry left-pads the frame-in-group axis (dim 1) up to
            # vae_max_compress_rate slots.
            tmp = F.pad(tmp,
                        (0, 0,
                        0, 0,
                        0, 0,
                        wm.vae_max_compress_rate - micro_frame_size, 0,
                        0, 0),
                        mode='constant',
                        value=0)
            obs4vaewm.append(tmp.reshape(-1, *tmp.shape[-3:]))
            obs4vaewm_mask.append(mask_2d.reshape(-1))

        # NOTE(review): dataset names and state_per_obs are filtered by
        # has_action, but batch_traj['action'] is zipped unfiltered. This only
        # lines up if every sample has actions or 'action' is already restricted
        # to those samples -- confirm against the data loader.
        dataset_name_action = [x for x, y in zip(batch_traj['dataset_name'], batch_traj['has_action']) if y]
        state_per_obs_action = [x for x, y in zip(batch_traj['state_per_obs'], batch_traj['has_action']) if y]
        for dataset_name, action_i, state_per_obs in zip(dataset_name_action,
                                                        batch_traj['action'],
                                                        state_per_obs_action):
            # Accept a 4-D action with a singleton second-to-last axis and drop that axis.
            if len(action_i.shape) != 3:
                assert len(action_i.shape) == 4
                assert action_i.shape[-2] == 1
                action_i = action_i.squeeze(-2)
            # A state_per_obs of 0 is treated as one entry per frame.
            if state_per_obs == 0:
                state_per_obs = 1
            expect_valid_frames = self.per_dataset_config[dataset_name]['expect_valid_frames']
            micro_frame_size = self.per_dataset_config[dataset_name]['micro_frame_size']

            action_tmp = action_i[:expect_valid_frames].reshape(wm.train_wm_seq_length, micro_frame_size, *action_i.shape[-2:])
            # Keep the first state_per_obs entries of each frame, then flatten
            # (frame, entry) into one per-timestep axis.
            action_tmp = action_tmp[:, :, :int(state_per_obs), :].reshape(action_tmp.shape[0], -1, action_tmp.shape[-1])
            n_max_state_action = self._cfg.n_max_state_action
            action_tmp = action_tmp[:, :n_max_state_action]
            state_action_mask_i = torch.ones(action_tmp.shape[0], n_max_state_action)
            if action_tmp.shape[1] < n_max_state_action:
                # Zero-pad missing entries and flag them as invalid in the mask.
                state_action_mask_i[:, action_tmp.shape[1]:] = 0
                action_tmp = F.pad(action_tmp, (0, 0, 0, n_max_state_action - action_tmp.shape[1], 0, 0),
                                mode='constant',
                                value=0)
            action4probingtarget.append(action_tmp)
            state_action_mask.append(state_action_mask_i)

        # NOTE(review): unlike process_batch, obs4lam / obs4vaewm are not cast
        # with .float() here -- confirm this difference is intended.
        obs4lam = torch.stack(obs4lam, dim=0).to(self._device_id, non_blocking=True)
        pad_mask_in = torch.stack(pad_mask_in, dim=0).to(self._device_id, non_blocking=True)
        obs4vaewm = torch.stack(obs4vaewm, dim=0).to(self._device_id, non_blocking=True)
        obs4vaewm_mask = torch.stack(obs4vaewm_mask, dim=0).to(self._device_id, non_blocking=True)
        action4probingtarget = torch.stack(action4probingtarget, dim=0).to(self._device_id, non_blocking=True)
        state_action_mask = torch.stack(state_action_mask, dim=0).to(self._device_id, non_blocking=True)
        valid_action_mask = batch_traj['valid_action_mask'].clone().to(self._device_id, non_blocking=True)

        if self._cfg.wm_resize_obs_64:
            # Resize operates channel-first, so fold time into batch and move channels forward.
            resize_64 = torchvision.transforms.Resize(64)
            obs4vaewm = rearrange(resize_64(rearrange(obs4vaewm, "B T H W C -> (B T) C H W")), "(B T) C H W -> B T H W C", T=obs4vaewm.shape[1])

        return obs4lam, obs4vaewm, pad_mask_in, obs4vaewm_mask, action4probingtarget, state_action_mask, valid_action_mask

    def add_realaction_noise(self, batch_traj):
        """Perturb slot 0 of the real actions in place with scaled Gaussian noise.

        The noise is drawn from a standard normal and multiplied by
        ``self._cfg.realaction_add_noise_level``. Every slot after the first
        along dim 2 must already be all zeros.

        Raises:
            AssertionError: if any action slot beyond index 0 is non-zero.
        """
        actions = batch_traj['action']
        trailing_slots = actions[:, :, 1:]
        assert not trailing_slots.any()

        scale = self._cfg.realaction_add_noise_level
        noise = torch.randn_like(actions[:, :, 0])
        # Writes through to the tensor stored in batch_traj.
        actions[:, :, 0] += scale * noise

    @torch.no_grad()
    def visualize_reconstruction(self, step, batch_traj, action_tokens=None, return_fvd_feat=False, extra_name='train', infer_use_real_action=False, gt_video=None):
        """Roll the world model out from the first frame group and score the result.

        The first ``wm.vae_max_compress_rate`` VAE/WM slots are encoded, the
        remaining latents are sampled by ``self.inference`` conditioned on
        action tokens, and the decoded video is compared against the original
        frames with ``self.get_video_eval_metric``. On rank zero one randomly
        chosen sample is written (original and reconstruction concatenated
        along axis 2) as an mp4 into ``video_reconstruction/``.

        Args:
            step: step index, used only in the output file name.
            batch_traj: trajectory batch; may be modified in place when real
                actions are used and noise is enabled.
            action_tokens: optional precomputed action tokens. When ``None``
                they are obtained from ``self._model``.
            return_fvd_feat: also return the FVD features of both videos.
            extra_name: tag inserted into the output file name.
            infer_use_real_action: if set (and the adapter is configured), map
                real actions to latent actions through ``self._adapter``.
            gt_video: optional replacement for the batch's original video.

        Returns:
            ``(eval_metrics, recon_video)``, or
            ``(eval_metrics, recon_fvd_feats, orig_fvd_feats, recon_video)``
            when ``return_fvd_feat`` is true. ``recon_video`` is a uint8 numpy
            array laid out as ``B T H W C``.
        """
        self._model.eval()
        wm = self.accelerator.unwrap_model(self._model).wm

        orig_video, dataset_names = batch_traj["observation"]["image_primary"].clone(), batch_traj['dataset_name']
        if gt_video is not None:
            orig_video = gt_video.clone()
        batch_micro_frame_sizes = np.array([self.per_dataset_config[name]['micro_frame_size'] for name in dataset_names])
        batch_expect_valid_frames = np.array([self.per_dataset_config[name]['expect_valid_frames'] for name in dataset_names])
        batch_sample_valid_frames = batch_traj["observation"]["pad_mask"].sum(1).cpu().numpy()
        batch_valid_frames = np.minimum(batch_expect_valid_frames, batch_sample_valid_frames)
        batch_valid_wm_timesteps = batch_valid_frames // batch_micro_frame_sizes
        # Only whole micro-frame groups are evaluated.
        batch_recon_eval_valid_frames = batch_valid_wm_timesteps * batch_micro_frame_sizes

        if infer_use_real_action and self._cfg.finetune_wm_adapt_realaction and self._cfg.load_adapter_path is not None:
            if self._cfg.realaction_add_noise_level > 0:
                self.add_realaction_noise(batch_traj)
            obs4lam, obs4vaewm, pad_mask_in, obs4vaewm_mask, real_action, _, _ = self.process_batch_withaction(batch_traj)
            latent_action = self.adapter_predict(self._adapter, real_action, pad_mask_in)
            latent_action = latent_action.reshape(*latent_action.shape[:-2], 1, -1)
        else:
            obs4lam, obs4vaewm, pad_mask_in, obs4vaewm_mask = self.process_batch(batch_traj)
            latent_action = None

        # Conditioning context: the VAE/WM slots of the first world-model timestep.
        obs4vaewm_first = obs4vaewm[:, :wm.vae_max_compress_rate]
        obs4vaewm_mask_first = obs4vaewm_mask[:, :wm.vae_max_compress_rate]
        clip_len = pad_mask_in.sum(1).tolist()
        B, T = obs4lam.shape[:2]

        # Both context managers must be entered; `a() and b()` would evaluate
        # to `b()` alone and silently skip autocast.
        with self.accelerator.autocast(), torch.no_grad():
            if action_tokens is None:
                output: BaseOutput = self._model(**{"obs4lam": obs4lam, "obs4vaewm": obs4vaewm, "pad_mask_in": pad_mask_in, "obs4vaewm_mask": obs4vaewm_mask,
                                                    'latent_action': latent_action},
                                                only_return_actions=True)
                action_tokens_flat = output.action_tokens
                action_tokens = unwrap_model(self._model).get_unflatten_results(action_tokens_flat, clip_len, B, T)

            x = wm.vae_encode(obs4vaewm_first, obs4vaewm_mask_first)
            pred_z = self.inference(x, action_tokens)
            recon = wm.vae_decode(pred_z, batch_micro_frame_sizes)
            if self._cfg.wm_resize_obs_64:
                resize_64 = torchvision.transforms.Resize(64)
                orig_video = rearrange(resize_64(rearrange(orig_video, "B T H W C -> (B T) C H W")), "(B T) C H W -> B T H W C", T=orig_video.shape[1])

        # Map the decoder range [-1, 1] to rounded uint8 pixels.
        lo, hi = -1, 1
        recon_video = recon.clamp(lo, hi).sub(lo).div(hi - lo).mul(255).add(0.5).clamp(0, 255)
        recon_video = recon_video.to(device="cpu", dtype=torch.uint8).numpy()
        recon_video = rearrange(recon_video, "B C T H W -> B T H W C")
        # The first micro-frame group is the given context, so copy it verbatim
        # from the original video.
        for i, (orig_obs_i, micro_frame_size) in enumerate(zip(orig_video.numpy(), batch_micro_frame_sizes)):
            recon_video[i, : micro_frame_size] = orig_obs_i[: micro_frame_size]

        eval_metrics, recon_fvd_feats, orig_fvd_feats = \
            self.get_video_eval_metric(wm.args.video_output.video_eval_metrics, recon_video, orig_video, valid_frames=batch_recon_eval_valid_frames)

        if self._is_rank_zero:
            # Writing the preview is best-effort: a failure must not abort
            # evaluation, but it should be visible in the logs.
            try:
                save_video_dir = 'video_reconstruction'
                os.makedirs(save_video_dir, exist_ok=True)
                sample_idx = np.random.randint(0, orig_video.shape[0])
                recon_eval_valid_frames = batch_recon_eval_valid_frames[sample_idx]
                combined_orig_recon = np.concatenate([orig_video[sample_idx][:recon_eval_valid_frames].numpy(),
                                                    recon_video[sample_idx][:recon_eval_valid_frames]], axis=2)
                lam_type = "gt" if not infer_use_real_action else "adapted"
                media.write_video(f"{save_video_dir}/step{step}-{extra_name}-{lam_type}-{recon_eval_valid_frames}frames.mp4",
                                combined_orig_recon, fps=8)
            except Exception as exc:
                logger.warning("Skipping reconstruction video for step {}: {}", step, exc)

        if return_fvd_feat:
            return eval_metrics, recon_fvd_feats, orig_fvd_feats, recon_video
        return eval_metrics, recon_video

    @torch.no_grad()
    def imagine(self, first_image, action_tokens, imagine_micro_frame_size=3):
        """Generate a video from a single first frame and a sequence of actions.

        When the real-action adapter is configured
        (``finetune_wm_adapt_realaction`` with a ``load_adapter_path``),
        ``action_tokens`` is interpreted as real actions: it is grouped into
        chunks of ``imagine_micro_frame_size`` and mapped to latent actions by
        ``self._adapter``. Otherwise it is used as latent action tokens as is.

        Args:
            first_image: tensor rearranged as ``B 1 H W C``; its ``.numpy()``
                is taken at the end, so it is expected to live on the CPU.
            action_tokens: action sequence; dim 1 is the step axis.
            imagine_micro_frame_size: frames decoded per world-model timestep.

        Returns:
            uint8 numpy array laid out as ``B T H W C`` whose first frame is
            ``first_image``.

        Raises:
            AssertionError: in adapter mode, if ``imagine_micro_frame_size``
                differs from ``n_max_state_action`` or does not divide the
                number of actions.
        """
        self._model.eval()
        device = self._device_id
        B = first_image.shape[0]
        T = action_tokens.shape[1] + 1
        wm = self.accelerator.unwrap_model(self._model).wm
        if self._cfg.finetune_wm_adapt_realaction and self._cfg.load_adapter_path is not None:
            assert imagine_micro_frame_size == self._cfg.n_max_state_action, "micro_frame_size must match n_max_state_action"
            assert action_tokens.shape[1] % imagine_micro_frame_size == 0, "num of real actions must be divisible by micro_frame_size"
            # One latent action per group of imagine_micro_frame_size real actions.
            T = action_tokens.shape[1] // imagine_micro_frame_size + 1
            action_tokens = action_tokens.reshape(B, T - 1, imagine_micro_frame_size, action_tokens.shape[-1])
            action_tokens = action_tokens.to(device, non_blocking=True)
            latent_action = self.adapter_predict(self._adapter, action_tokens, include_last_action=True)
            action_tokens = latent_action.reshape(B, T - 1, 1, -1)

        obs4lam_first = rearrange(first_image, "B 1 H W C -> B H W C")
        # Only the last slot of the padded group holds a real frame.
        # NOTE(review): the mask is built with batch size 1 regardless of B --
        # presumably broadcast inside vae_encode; confirm.
        obs4vaewm_mask_first = torch.zeros(1, wm.vae_max_compress_rate, device=device, dtype=torch.bool)
        obs4vaewm_mask_first[:, -1] = True
        # Left-pad the frame axis (dim 1) with zero frames up to vae_max_compress_rate.
        obs4vaewm_first = F.pad(first_image,
                                (0, 0,
                                0, 0,
                                0, 0,
                                wm.vae_max_compress_rate - 1, 0,
                                0, 0),
                                mode='constant',
                                value=0)
        obs4lam_first = obs4lam_first.to(device, non_blocking=True).float()
        obs4vaewm_first = obs4vaewm_first.to(device, non_blocking=True).float()
        obs4vaewm_mask_first = obs4vaewm_mask_first.to(device, non_blocking=True)

        # Both context managers must be entered; `a() and b()` would evaluate
        # to `b()` alone and silently skip autocast.
        with self.accelerator.autocast(), torch.no_grad():
            x = wm.vae_encode(obs4vaewm_first, obs4vaewm_mask_first)
            pred_z = self.inference(x, action_tokens)
            recon = wm.vae_decode(pred_z, np.array([imagine_micro_frame_size for _ in range(B)]))
            # Drop the leading frames of the first group so the clip starts at
            # the frame aligned with first_image.
            recon = recon[:, :, imagine_micro_frame_size-1: T*imagine_micro_frame_size]

        # Map the decoder range [-1, 1] to rounded uint8 pixels.
        lo, hi = -1, 1
        recon_video = recon.clamp(lo, hi).sub(lo).div(hi - lo).mul(255).add(0.5).clamp(0, 255)
        recon_video = recon_video.to(device="cpu", dtype=torch.uint8).numpy()
        recon_video = rearrange(recon_video, "B C T H W -> B T H W C")
        # The first frame is given, so copy it verbatim.
        recon_video[:, :1] = first_image[:, :1].numpy()

        return recon_video

    @torch.no_grad()
    def inference(self, first_state, action_tokens, fdm_pred=None):
        """Sample future latents conditioned on a first latent and action tokens.

        The first timestep is fixed to ``first_state`` (mask value 0); all later
        timesteps start from Gaussian noise (mask value 1) and are produced by
        ``wm.scheduler.sample_with_latent_action``. The EMA network is used
        when ``wm.args.use_ema_infer`` is set, otherwise ``wm.transformer``.

        Args:
            first_state: latent of the first timestep; dims 0 and 1 are read as
                batch and channels, the last two as the spatial size.
            action_tokens: per-step action conditioning; dim 1 is the step axis.
            fdm_pred: unused.

        Returns:
            Whatever the scheduler's sampler returns.
        """
        wm = self.accelerator.unwrap_model(self._model).wm
        denoiser = wm.ema if wm.args.use_ema_infer else wm.transformer
        device, dtype = first_state.device, first_state.dtype
        batch_size, latent_channels = first_state.shape[:2]
        horizon = action_tokens.shape[1] + 1

        # Timestep 0 keeps the embedder's default embedding; later steps get
        # the supplied action tokens.
        act_cond = repeat(
            denoiser.latent_action_embedder.y_embedding,
            "N D -> B T 1 (N D)",
            B=batch_size, T=horizon,
        ).clone()
        act_cond[:, 1:] = action_tokens.clone()

        latents = torch.randn(
            batch_size, latent_channels, horizon, *first_state.shape[-2:],
            device=device, dtype=dtype,
        )
        latents[:, :, :1] = first_state

        step_mask = torch.ones(batch_size, horizon, dtype=torch.float, device=device)
        step_mask[:, :1] = 0

        resolution_info = prepare_multi_resolution_info(
            wm.args.video_output.multi_resolution,
            1,
            (wm.args.input_height, wm.args.input_width),
            num_frames=wm.args.num_frames,
            fps=wm.args.video_output.fps,
            device=device,
            dtype=dtype,
        )

        return wm.scheduler.sample_with_latent_action(
            denoiser,
            None,
            z=latents,
            latent_actions=act_cond,
            image_cond=None,
            device=device,
            additional_args=resolution_info,
            progress=False,
            mask=step_mask,
        )

    def get_video_eval_metric(self, eval_metrics, recon_video, orig_video, valid_frames):
        """Score reconstructed clips against originals, one sample at a time.

        Every metric name other than ``'fvd'`` is dispatched to
        ``self.get_<name>_metric`` on the first ``valid_frames[i]`` frames of
        sample ``i`` and averaged over the batch. FVD features are collected
        for samples with at least 10 valid frames and turned into a single FVD
        score; the score is NaN when no sample qualifies.

        Returns:
            ``(metrics, recon_fvd_feats, orig_fvd_feats)``. The feature entries
            are concatenated arrays, or empty lists if nothing was collected.

        Raises:
            NotImplementedError: if a requested metric has no matching method.
        """
        per_sample_scores = {name: [] for name in eval_metrics if name != 'fvd'}
        recon_feat_chunks, orig_feat_chunks = [], []

        for recon_clip, orig_clip, n_valid in zip(recon_video, orig_video, valid_frames):
            recon_clip = torch.tensor(recon_clip[:n_valid], device=self._device_id).float().unsqueeze(0)
            orig_clip = torch.tensor(orig_clip[:n_valid], device=self._device_id).float().unsqueeze(0)

            for name in eval_metrics:
                if name == 'fvd':
                    continue
                scorer = getattr(self, f'get_{name}_metric', None)
                if scorer is None:
                    raise NotImplementedError
                per_sample_scores[name].append(scorer(recon_clip, orig_clip))

            if n_valid >= 10:
                recon_chunk, orig_chunk = self.get_fvd_metric(recon_clip, orig_clip, get_feats=True, get_metric=False)
                recon_feat_chunks.append(recon_chunk)
                orig_feat_chunks.append(orig_chunk)

        summary = {name: np.array(scores).mean() for name, scores in per_sample_scores.items()}

        if not recon_feat_chunks:
            summary['fvd'] = float('nan')
            return summary, recon_feat_chunks, orig_feat_chunks

        recon_feats = np.concatenate(recon_feat_chunks, axis=0)
        orig_feats = np.concatenate(orig_feat_chunks, axis=0)
        summary['fvd'] = self.get_fvd_metric(recon_feats, orig_feats, get_feats=False, get_metric=True)
        return summary, recon_feats, orig_feats

    def get_ssim_metric(self, recon_video, orig_video):
        """Return SSIM between two ``B T H W C`` videos as a Python float.

        Frames are flattened into the batch axis and moved to channel-first
        layout; pixel values are treated as spanning a range of 255.
        """
        from torchmetrics.functional import structural_similarity_index_measure as ssim

        frames_first = "B T H W C -> (B T) C H W"
        score = ssim(
            rearrange(recon_video, frames_first),
            rearrange(orig_video, frames_first),
            data_range=255.0,
        )
        return score.item()

    def get_psnr_metric(self, recon_video, orig_video):
        """Return PSNR between two ``B T H W C`` videos as a Python float.

        Frames are flattened into the batch axis and moved to channel-first
        layout; pixel values are treated as spanning a range of 255.
        """
        from torchmetrics.functional import peak_signal_noise_ratio as psnr

        frames_first = "B T H W C -> (B T) C H W"
        score = psnr(
            rearrange(recon_video, frames_first),
            rearrange(orig_video, frames_first),
            data_range=255.0,
        )
        return score.item()

    def get_lpips_metric(self, recon_video, orig_video):
        """Return the mean of ``self.lpips_loss_fn`` over all frames as a float.

        Both ``B T H W C`` videos are flattened to channel-first frames and
        rescaled from [0, 255] to [-1, 1] before being compared.
        """
        def as_signed_frames(video):
            frames = rearrange(video, "B T H W C -> (B T) C H W")
            return frames / 255.0 * 2. - 1.

        distance = self.lpips_loss_fn(as_signed_frames(recon_video), as_signed_frames(orig_video))
        return distance.mean().item()

    def get_fvd_metric(self, recon_video, orig_video, method='styleganv', get_feats=True, get_metric=True):
        """Compute FVD features and/or the FVD score for two sets of videos.

        Args:
            recon_video, orig_video: ``B T H W C`` pixel videos when
                ``get_feats`` is true; otherwise already-extracted features.
            method: ``'styleganv'`` or ``'videogpt'``; selects the backend.
            get_feats: extract features from the videos first.
            get_metric: compute the Frechet distance between the features.

        Returns:
            ``(recon_feats, orig_feats, fvd)`` if both flags are set, ``fvd``
            if only ``get_metric`` is set, ``(recon_feats, orig_feats)`` if
            only ``get_feats`` is set. ``fvd`` is NaN when either feature set
            is empty.

        Raises:
            AssertionError: on an unknown ``method`` or when both flags are off.
        """
        assert method in ['styleganv', 'videogpt']
        if method == 'styleganv':
            from fvd.styleganv.fvd import frechet_distance
        elif method == 'videogpt':
            from fvd.videogpt.fvd import frechet_distance

        assert get_feats or get_metric
        if not get_feats:
            recon_feats, orig_feats = recon_video, orig_video
        else:
            recon_pixels = torch.tensor(recon_video, device=self._device_id)
            orig_pixels = torch.tensor(orig_video, device=self._device_id)
            # Scale to [0, 1] and move channels in front of the spatial dims.
            recon_pixels = rearrange(recon_pixels / 255.0, 'B T H W C -> B T C H W')
            orig_pixels = rearrange(orig_pixels / 255.0, 'B T H W C -> B T C H W')
            recon_feats, orig_feats = self.get_fvd_feats_batch(recon_pixels, orig_pixels, method=method)

        fvd = float('nan')
        if get_metric and recon_feats.shape[0] > 0 and orig_feats.shape[0] > 0:
            fvd = frechet_distance(recon_feats, orig_feats)

        if not get_metric:
            return recon_feats, orig_feats
        if get_feats:
            return recon_feats, orig_feats, fvd
        return fvd

    def get_fvd_feats_batch(self, videos1, videos2, method='styleganv'):
        """Extract FVD features for two equally shaped batches of videos.

        Inputs are indexed as (batch, time, channel, height, width);
        single-channel clips are repeated to three channels and the time and
        channel axes are swapped before feature extraction with
        ``self.fvd_i3d``.

        Returns:
            ``(feats1, feats2)`` as produced by the selected backend.

        Raises:
            AssertionError: if the shapes differ or a clip has fewer than 10
                frames.
        """
        if method == 'styleganv':
            from fvd.styleganv.fvd import get_fvd_feats
        elif method == 'videogpt':
            from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats

        def to_channel_major(clips):
            is_grayscale = clips.shape[-3] == 1
            if is_grayscale:
                clips = clips.repeat(1, 1, 3, 1, 1)
            return clips.permute(0, 2, 1, 3, 4)

        assert videos1.shape == videos2.shape
        videos1, videos2 = to_channel_major(videos1), to_channel_major(videos2)
        assert videos1.shape[2] >= 10, "FVD requires at least 10 frames per clip"

        # NOTE(review): the device is hard-coded to "cuda:0" while the rest of
        # the class uses self._device_id -- confirm this is right on multi-GPU runs.
        feats1 = get_fvd_feats(videos1, i3d=self.fvd_i3d, device="cuda:0")
        feats2 = get_fvd_feats(videos2, i3d=self.fvd_i3d, device="cuda:0")
        return feats1, feats2

    def process_batch(self, batch_traj):
        """Turn a trajectory batch into LAM and VAE/world-model inputs.

        For every sample:
          * the last frame of each group of ``micro_frame_size`` frames (and
            its pad flag) becomes the LAM observation / ``pad_mask_in``;
          * all frames are regrouped per world-model timestep and left-padded
            with zero frames to ``wm.vae_max_compress_rate`` slots, with a
            boolean mask marking the real, valid slots.

        Returns:
            ``(obs4lam, obs4vaewm, pad_mask_in, obs4vaewm_mask)`` on
            ``self._device_id``; the two observation tensors are cast to float.

        Raises:
            AssertionError: if a dataset's frame configuration is inconsistent
                with ``wm.train_wm_seq_length`` or a sample has fewer valid
                frames than ``min_valid_frames``.
        """
        wm = self.accelerator.unwrap_model(self._model).wm
        seq_len = wm.train_wm_seq_length
        slots_per_step = wm.vae_max_compress_rate

        lam_frames, lam_masks, vae_frames, vae_masks = [], [], [], []
        samples = zip(batch_traj['dataset_name'],
                      batch_traj["observation"]["image_primary"],
                      batch_traj["observation"]["pad_mask"])
        for name, frames, frame_mask in samples:
            ds_cfg = self.per_dataset_config[name]
            n_present = frame_mask.sum().item()
            n_expected = ds_cfg['expect_valid_frames']
            n_min = ds_cfg['min_valid_frames']
            group = ds_cfg['micro_frame_size']
            assert n_expected == group * seq_len
            assert n_present >= n_min

            n_usable = min(n_present, n_expected)
            # Last frame of every micro-frame group.
            group_ends = slice(group - 1, n_expected, group)
            lam_frames.append(frames[group_ends])
            lam_masks.append(frame_mask[group_ends])
            n_valid_steps = frames[group - 1: n_usable: group].shape[0]
            assert n_valid_steps == n_usable // group
            assert frame_mask[group_ends].sum().item() == n_valid_steps

            grouped = frames[:n_expected].reshape(seq_len, group, *frames.shape[-3:])
            slot_mask = torch.zeros(seq_len, slots_per_step,
                                    device=frames.device, dtype=torch.bool)
            # Real data sits in the trailing `group` slots of each valid timestep.
            slot_mask[:n_valid_steps, -group:] = True
            # The pad tuple runs from the last dim backwards; only the
            # frame-in-group axis (dim 1) is padded, on the left.
            left_pad = slots_per_step - group
            grouped = F.pad(grouped,
                            (0, 0, 0, 0, 0, 0, left_pad, 0, 0, 0),
                            mode='constant',
                            value=0)
            vae_frames.append(grouped.reshape(-1, *grouped.shape[-3:]))
            vae_masks.append(slot_mask.reshape(-1))

        device = self._device_id
        obs4lam = torch.stack(lam_frames, dim=0).to(device, non_blocking=True).float()
        pad_mask_in = torch.stack(lam_masks, dim=0).to(device, non_blocking=True)
        obs4vaewm = torch.stack(vae_frames, dim=0).to(device, non_blocking=True).float()
        obs4vaewm_mask = torch.stack(vae_masks, dim=0).to(device, non_blocking=True)

        if self._cfg.wm_resize_obs_64:
            # Resize operates channel-first, so fold time into batch first.
            resize_64 = torchvision.transforms.Resize(64)
            n_slots = obs4vaewm.shape[1]
            flat = rearrange(obs4vaewm, "B T H W C -> (B T) C H W")
            obs4vaewm = rearrange(resize_64(flat), "(B T) C H W -> B T H W C", T=n_slots)

        return obs4lam, obs4vaewm, pad_mask_in, obs4vaewm_mask

    def train_adapter_main(self, train_step, train_dataloader, eval_dataloader):
        """Train a fresh real-action-to-VQ adapter, then evaluate it.

        The main model is put in eval mode; a new ``ActionToVQAdapter`` is
        built from ``self.model_cfg`` / ``self._cfg``, trained with
        ``train_adapter`` and scored with ``evaluate_adapter``.

        Args:
            train_step: unused.
            train_dataloader: batches used to fit the adapter.
            eval_dataloader: batches used to evaluate it.

        Returns:
            The metric summary returned by ``evaluate_adapter``.
        """
        self._model.eval()

        adapter_kwargs = dict(
            real_action_dim=self.model_cfg['probe_action_dim'],
            n_real_actions=self._cfg.n_max_state_action,
            num_learned_tokens=self.model_cfg['num_learned_tokens'],
            vq_codebook_size=self.model_cfg['n_codes'],
            hidden_dim=self.model_cfg['probe_dim'],
        )
        adapter = ActionToVQAdapter(**adapter_kwargs)

        self.train_adapter(adapter, train_dataloader)
        return self.evaluate_adapter(adapter, eval_dataloader)

    def train_adapter(self, adapter, train_dataloader):
        """Fit ``adapter`` for ``self._cfg.train_adapter_steps`` steps and save it.

        An optimizer and LR scheduler are built from the trainer config, all
        three objects are passed through ``self.accelerator.prepare``, and each
        step minimizes the loss from ``get_adapter_loss``. The prepared
        adapter's ``state_dict`` is written to
        ``<ckpt_dir>/adapter_model_<train_adapter_steps>.pt``.

        If the dataloader is exhausted before the step budget is reached it is
        restarted instead of aborting with ``StopIteration``.

        Args:
            adapter: the adapter module to train (modified in place).
            train_dataloader: iterable of training batches.
        """
        adapter_optimizer, adapter_lr_scheduler = configure_optimizer(
            model=adapter,
            base_lr=self._cfg.base_lr,
            batch_size=self._effective_bsz,
            min_lr=self._cfg.min_lr,
            warmup_steps=0,
            max_steps=self._cfg.max_epochs * self._train_batches,
            weight_decay=self._cfg.weight_decay,
            betas=self._cfg.betas,
            scheduler=self._cfg.scheduler,
            optimizer=self._cfg.optimizer,
        )
        adapter, adapter_optimizer, adapter_lr_scheduler = self.accelerator.prepare(adapter, adapter_optimizer, adapter_lr_scheduler)

        num_steps = self._cfg.train_adapter_steps
        train_iter = iter(train_dataloader)
        # Nothing inside the loop changes the adapter's mode, so set it once.
        adapter.train()
        for _step in range(num_steps):
            try:
                batch = next(train_iter)
            except StopIteration:
                # A finite loader shorter than the step budget: start another pass.
                train_iter = iter(train_dataloader)
                batch = next(train_iter)
            adapter_loss, _ = self.get_adapter_loss(adapter, batch)

            adapter_optimizer.zero_grad()
            self.accelerator.backward(adapter_loss)
            adapter_optimizer.step()
            adapter_lr_scheduler.step()

        # Name the file after the step budget rather than the loop variable,
        # which is unbound when the budget is 0.
        # NOTE(review): every process writes this file -- confirm whether the
        # save should be limited to rank zero.
        save_path = self._cfg.ckpt_dir / f"adapter_model_{num_steps}.pt"
        torch.save(adapter.state_dict(), save_path)

        del adapter_optimizer, adapter_lr_scheduler
        torch_gc()

    @torch.no_grad()
    def evaluate_adapter(self, adapter, eval_dataloader):
        """Evaluate ``adapter`` on a fixed number of validation batches.

        The adapter is put in eval mode and ``get_adapter_loss`` is run on 100
        batches from ``eval_dataloader``; the loss and every accuracy entry it
        returns are accumulated in a ``LossMetric``.

        Args:
            adapter: the adapter module to evaluate.
            eval_dataloader: iterable that must yield at least 100 batches.

        Returns:
            The result of ``LossMetric.finalize()``.
        """
        torch_gc()
        adapter.eval()
        val_iter = iter(eval_dataloader)
        metric = LossMetric()

        # NOTE(review): a step count derived from val_samples / val_batch_size /
        # world size used to be computed here and was immediately overwritten
        # by this constant. The constant is kept as the effective behaviour --
        # confirm which of the two is intended.
        val_steps = 100
        # The decorator already disables gradients, so no nested no_grad block.
        for _step in range(val_steps):
            batch = next(val_iter)

            adapter_loss, adapter_accuracy_info = self.get_adapter_loss(adapter, batch)
            metric.add("adapter_loss", adapter_loss)
            for k, v in adapter_accuracy_info.items():
                metric.add(k, v)

        # Drops only this frame's reference; the caller may still hold the adapter.
        del adapter
        torch_gc()
        return metric.finalize()

    def get_adapter_loss(self, adapter, batch):
        """Cross-entropy between the adapter's VQ logits and the model's VQ labels.

        The frozen model is run (no gradients) to obtain target code indices
        (``output.label``). Real actions for every transition of every clip
        (all timesteps but the last) are fed to ``adapter``, whose logits are
        scored against those indices.

        Args:
            adapter: module mapping real actions to VQ logits.
            batch: trajectory batch; every sample must have actions.

        Returns:
            ``(loss, info)`` where ``info`` holds ``'top1_accuracy'`` and
            ``'top3_accuracy'``.

        Raises:
            AssertionError: if any sample lacks actions or any action / valid
                mask entry is 0.
        """
        (obs4lam, obs4vaewm, pad_mask_in, obs4vaewm_mask,
         action, state_action_mask, valid_action_mask) = self.process_batch_withaction(batch)

        model_inputs = {
            "obs4lam": obs4lam,
            "obs4vaewm": obs4vaewm,
            "pad_mask_in": pad_mask_in,
            "obs4vaewm_mask": obs4vaewm_mask,
        }
        with self.accelerator.autocast(), torch.no_grad():
            output = self._model(**model_inputs, only_return_actions=True)

        clip_len = pad_mask_in.sum(1).tolist()
        has_state_action = batch['has_action']
        assert has_state_action.all()

        # The flat label tensor holds clip_len - 1 rows (transitions) per clip.
        per_clip_codes = torch.split(output.label, [n - 1 for n in clip_len], dim=0)
        target_codes = torch.cat([codes for codes, keep in zip(per_clip_codes, has_state_action) if keep])
        kept_clip_len = [n for n, keep in zip(clip_len, has_state_action) if keep]

        def drop_final_step(seqs, lengths):
            # Keep the first (length - 1) steps of each right-padded sequence.
            assert len(seqs) == len(lengths), "Length of x and k must be the same"
            return [seq[: n - 1] for seq, n in zip(seqs, lengths)]

        action_input = torch.cat(drop_final_step(action, kept_clip_len), dim=0)
        action_input_mask = torch.cat(drop_final_step(state_action_mask, kept_clip_len), dim=0)
        assert action_input_mask.all(), "action_input_mask should be all 1"

        # One copy of each sample's validity row per transition.
        transitions_per_clip = torch.tensor(kept_clip_len).to(valid_action_mask.device) - 1
        valid_action_mask = torch.repeat_interleave(valid_action_mask, repeats=transitions_per_clip, dim=0).long()
        assert valid_action_mask.all(), "valid_action_mask should be all 1"

        pred_vq_logits = adapter(action_input)
        n_rows, n_tokens = pred_vq_logits.shape[0], pred_vq_logits.shape[1]

        token_loss = F.cross_entropy(
            pred_vq_logits.view(-1, pred_vq_logits.shape[-1]),
            target_codes.view(-1),
            reduction='none',
        ).view(n_rows, n_tokens)

        # A row counts only if all of its action and validity flags are set.
        row_ok = torch.prod(action_input_mask, dim=1) * torch.prod(valid_action_mask, dim=1)
        adapter_loss_mask = row_ok.unsqueeze(-1).expand(-1, token_loss.shape[1])
        assert adapter_loss_mask.all(), "adapter_loss_mask should be all 1"
        denom = adapter_loss_mask.sum() + 1e-8
        adapter_loss = (token_loss * adapter_loss_mask).sum() / denom

        top1_hits = (torch.argmax(pred_vq_logits, dim=-1) == target_codes).float()
        adapter_accuracy = (top1_hits * adapter_loss_mask).sum() / denom

        topk = 3
        _, topk_indices = torch.topk(pred_vq_logits, k=topk, dim=-1)
        topk_hits = torch.any(topk_indices == target_codes.unsqueeze(-1), dim=-1).float()
        top_k_accuracy = (topk_hits * adapter_loss_mask).sum() / denom

        return adapter_loss, {'top1_accuracy': adapter_accuracy, f'top{topk}_accuracy': top_k_accuracy}

    @torch.no_grad()
    def adapter_predict(self, adapter, real_actions, pad_mask_in=None, include_last_action=False):
        """Map real actions to codebook entries via the adapter.

        For each sample the leading ``length`` steps are kept, where
        ``length`` is the row sum of ``pad_mask_in`` (or the full step count
        when no mask is given), minus one unless ``include_last_action`` is
        set. The kept steps of all samples are concatenated, passed through
        ``adapter`` (switched to eval mode), and the arg-max code of every
        token is looked up in the model quantizer's codebook.

        Returns:
            The codebook entries for the predicted indices.

        Raises:
            AssertionError: if the number of samples and lengths differ.
        """
        if pad_mask_in is None:
            pclip_len = [real_actions.shape[1]] * real_actions.shape[0]
        else:
            pclip_len = pad_mask_in.sum(1).tolist()
        assert len(real_actions) == len(pclip_len), "Length of x and k must be the same"

        trim = 0 if include_last_action else 1
        kept_steps = [sample[: n - trim] for sample, n in zip(real_actions, pclip_len)]
        action_input = torch.cat(kept_steps, dim=0)

        adapter.eval()
        pred_vq_indices = torch.argmax(adapter(action_input), dim=-1)
        quantizer = unwrap_model(self._model).quantizer
        return quantizer.get_codebook_entry(pred_vq_indices)