import shutil
import signal
import sys
import os
import json
import time
import copy
import random
from typing import Callable, Any
from functools import partial
import logging
from omegaconf import DictConfig, OmegaConf

from gymnasium import spaces
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from tqdm import tqdm

from build.lib.robobase.envs.wrappers.action_sequence import ActionSequence, RecedingHorizonControl
from robobase import utils
from robobase.envs.env import EnvFactory
from robobase.logger import Logger
from robobase.replay_buffer.prioritized_replay_buffer import PrioritizedReplayBuffer
from robobase.replay_buffer.replay_buffer import ReplayBuffer
from robobase.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
from robobase.envs.wrappers import RescaleFromTanhWithMinMax
from robobase.envs.utils.bigym_utils import ErrorCalculator
from pathlib import Path
from robobase.method.utils import (
    extract_many_from_batch,
    extract_from_batch,
    extract_many_from_spec
)
import hydra
import numpy as np
import torch
import gymnasium as gym
from torch.utils.data import DataLoader
import imageio, cv2

torch.backends.cudnn.benchmark = True

import pdb

import os
import numpy as np
from PIL import Image

def save_numpy_image_to_folder(
    img_array: np.ndarray,
    save_dir: str,
    img_name: str,
    img_format: str = 'png'
) -> None:
    """Write one channel-first ``uint8`` frame to ``save_dir`` as an image file.

    Args:
        img_array: Array of dtype ``uint8`` with shape exactly ``(3, 84, 84)``.
        save_dir: Destination directory; created when it does not exist.
        img_name: Output file name without extension.
        img_format: File extension; it is lower-cased before use.

    Raises:
        ValueError: If the dtype is not ``uint8`` or the shape is not
            ``(3, 84, 84)``.
        IOError: If building or writing the image fails for any reason.
    """
    expected_shape = (3, 84, 84)
    if img_array.dtype != np.uint8:
        raise ValueError(f"数组数据类型必须是 uint8，当前为 {img_array.dtype}")
    if img_array.shape != expected_shape:
        raise ValueError(f"数组形状必须是 (3, 84, 84)，当前为 {img_array.shape}")

    os.makedirs(save_dir, exist_ok=True)

    # PIL expects channel-last data, so move the channel axis to the end.
    channel_last = img_array.transpose(1, 2, 0)

    try:
        picture = Image.fromarray(channel_last)
        file_name = f"{img_name}.{img_format.lower()}"
        picture.save(os.path.join(save_dir, file_name))
    except Exception as e:
        raise IOError(f"保存图片失败: {str(e)}") from e




def batch_save_images(
    img_arrays: list[np.ndarray],
    save_dir: str='./temp',
    base_name: str = "frame",
    start_idx: int = 0,
    img_format: str = 'png'
) -> None:
    """Save a sequence of frames as ``<base_name>_<idx>.<img_format>`` files.

    Any existing ``save_dir`` is deleted first so the directory only ever
    contains the frames of the current call.

    Args:
        img_arrays: Frames accepted by ``save_numpy_image_to_folder``.
        save_dir: Output directory. An existing directory is only wiped when
            its path contains ``'/temp'``.
        base_name: File-name prefix.
        start_idx: Index assigned to the first frame.
        img_format: File extension of the written images.

    Raises:
        ValueError: If ``save_dir`` exists and its path does not contain
            ``'/temp'`` (nothing is deleted in that case).
    """
    if os.path.exists(save_dir):
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let rmtree delete an arbitrary directory.
        if '/temp' not in save_dir:
            raise ValueError(
                f"Refusing to delete {save_dir!r}: path must contain '/temp'"
            )
        shutil.rmtree(save_dir)
    # makedirs (not mkdir) so a missing parent directory is not an error.
    os.makedirs(save_dir)
    for idx, img in enumerate(img_arrays, start=start_idx):
        save_numpy_image_to_folder(
            img_array=img,
            save_dir=save_dir,
            img_name=f"{base_name}_{idx}",
            img_format=img_format
        )


def images_to_video(
    img_dir: str,
    video_path: str,
    fps: int = 10,
    img_format: str = "png",
    base_name: str = "frame"
) -> None:
    """Assemble ``<base_name>_<idx>.<img_format>`` images into an mp4 video.

    Images are ordered by their integer index. The frame size is taken from
    the first image; images that cannot be read are skipped with a warning.

    Args:
        img_dir: Directory containing the images.
        video_path: Output video file; its parent directory is created when
            missing.
        fps: Frames per second of the output video.
        img_format: Image extension, matched case-insensitively on the
            argument (file names are matched against the lower-cased form).
        base_name: File-name prefix of the images.

    Raises:
        FileNotFoundError: If ``img_dir`` does not exist or holds no matching
            image.
        ValueError: If a matching file name does not carry an integer index.
        IOError: If the first image cannot be read or the writer cannot be
            opened.
    """
    if not os.path.isdir(img_dir):
        raise FileNotFoundError(f"图片目录不存在: {img_dir}")

    # Use one lower-cased suffix for both filtering and index parsing so that
    # e.g. img_format="PNG" does not break the sort key.
    suffix = f".{img_format.lower()}"
    img_files = [f for f in os.listdir(img_dir)
                 if f.startswith(base_name) and f.endswith(suffix)]
    if not img_files:
        raise FileNotFoundError(f"目录{img_dir}中无符合条件的{img_format}图片")

    index_prefix = base_name + "_"

    def extract_idx(filename):
        return int(filename.removeprefix(index_prefix).removesuffix(suffix))
    img_files.sort(key=extract_idx)

    first_img_path = os.path.join(img_dir, img_files[0])
    first_img = cv2.imread(first_img_path)
    if first_img is None:
        raise IOError(f"无法读取图片: {first_img_path}")
    height, width = first_img.shape[:2]

    # The writer cannot be opened when the target directory is missing.
    out_dir = os.path.dirname(video_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    video_writer = cv2.VideoWriter(
        video_path,
        fourcc,
        fps,
        (width, height)
    )
    try:
        if not video_writer.isOpened():
            raise IOError(f"无法创建视频写入器: {video_path}")

        for img_file in img_files:
            img = cv2.imread(os.path.join(img_dir, img_file))
            if img is None:
                print(f"警告：跳过损坏的图片 {img_file}")
                continue
            video_writer.write(img)
    finally:
        # Always release the writer, also when opening or writing fails.
        video_writer.release()
    print(f"视频已保存至: {video_path}")



def _worker_init_fn(worker_id):
    seed = np.random.get_state()[1][0] + worker_id
    np.random.seed(seed)
    random.seed(int(seed))


def _create_default_replay_buffer(
    cfg: DictConfig,
    observation_space: gym.Space,
    action_space: gym.Space,
) -> ReplayBuffer:
    """Build the default ``UniformReplayBuffer`` from the replay settings in ``cfg``.

    Args:
        cfg: Config providing ``demos``, ``batch_size`` and the ``replay.*``
            fields read below.
        observation_space: Passed through as ``observation_elements``.
        action_space: Supplies the action dtype and the size of the last
            action dimension.

    Returns:
        The constructed replay buffer. A scalar ``uint8`` ``"demo"`` element
        is registered as an extra replay element when ``cfg.demos != 0``.
    """
    extras = spaces.Dict({})
    if cfg.demos != 0:
        extras["demo"] = spaces.Box(0, 1, shape=(), dtype=np.uint8)

    replay_cfg = cfg.replay
    buffer_kwargs = dict(
        nstep=replay_cfg.nstep,
        gamma=replay_cfg.gamma,
        save_dir=replay_cfg.save_dir,
        batch_size=cfg.batch_size,
        replay_capacity=replay_cfg.size,
        # Actions are stored with a leading length-1 axis.
        action_shape=(1, action_space.shape[-1]),
        action_dtype=action_space.dtype,
        reward_shape=(),
        reward_dtype=np.float32,
        observation_elements=observation_space,
        extra_replay_elements=extras,
        num_workers=replay_cfg.num_workers,
        sequential=replay_cfg.sequential,
        transition_seq_len=replay_cfg.transition_seq_len,
    )
    return UniformReplayBuffer(**buffer_kwargs)


def _create_default_envs(cfg: DictConfig) -> EnvFactory:
    factory = None
    if cfg.env.env_name == "rlbench":
        from robobase.envs.rlbench import RLBenchEnvFactory

        factory = RLBenchEnvFactory()
    elif cfg.env.env_name == "dmc":
        from robobase.envs.dmc import DMCEnvFactory

        factory = DMCEnvFactory()
    elif cfg.env.env_name == "bigym":
        from robobase.envs.bigym import BiGymEnvFactory

        factory = BiGymEnvFactory()
        factory.HIGH_GAIN = True
    elif cfg.env.env_name == "d4rl":
        from robobase.envs.d4rl import D4RLEnvFactory

        factory = D4RLEnvFactory()
    else:
        ValueError()
    return factory

def set_seed(seed):
    """Seed PyTorch (``torch.manual_seed``) and NumPy's global RNG with ``seed``.

    Python's ``random`` module is not touched.
    """
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(seed)

def put_text(img, text, font_size=1, thickness=2, resize=False,position="top"):
    """Return a copy of ``img`` with ``text`` drawn on it.

    Args:
        img: Image array; it is copied, the input is not modified.
        text: String to render.
        font_size: Font scale passed to ``cv2.putText``.
        thickness: Stroke thickness passed to ``cv2.putText``.
        resize: When true, the copy is cast to ``uint8`` and resized to
            256x256 before drawing.
        position: ``"top"`` anchors the text at ``(10, 30)``; ``"bottom"``
            anchors it at ``(300, height - 20)``.

    Returns:
        The annotated image.

    Raises:
        ValueError: If ``position`` is neither ``"top"`` nor ``"bottom"``.
    """
    # Validate up front: an unknown position used to leave the anchor point
    # unbound and fail later with an UnboundLocalError.
    if position not in ("top", "bottom"):
        raise ValueError(
            f"position must be 'top' or 'bottom', got {position!r}"
        )
    img = img.copy()
    if resize:
        img = cv2.resize(np.uint8(img), (256, 256))
    if position == "top":
        p = (10, 30)
    else:
        # NOTE(review): x=300 lies outside a 256-pixel-wide image (e.g. after
        # resize=True) — confirm the intended anchor.
        p = (300, img.shape[0] - 20)
    return cv2.putText(
        img,
        text,
        p,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_size,
        (0, 255, 255),
        thickness,
        cv2.LINE_AA,
    )


class DynamicsWorkspace:
    def __init__(
        self,
        cfg: DictConfig,
        env_factory: EnvFactory = None,
        create_replay_fn: Callable[[DictConfig], ReplayBuffer] = None,
        work_dir: str = None,
        just_ret_demo = False
    ):
        """Set up logging, demos, the evaluation env, the dynamics model and the replay buffer.

        Args:
            cfg: Run configuration. Fields read here include ``seed``,
                ``num_gpus``, ``demos``, ``method``, ``use_var``,
                ``use_qpos_pred``, ``use_pixels``, ``replay.num_workers`` and
                ``replay.pin_memory``.
            env_factory: Environment factory; defaults to
                ``_create_default_envs(cfg)``.
            create_replay_fn: Called as
                ``create_replay_fn(cfg, observation_space, action_space)``;
                defaults to ``_create_default_replay_buffer``.
                NOTE(review): the annotation declares a single-argument
                callable, but three arguments are passed.
            work_dir: Output directory; defaults to Hydra's runtime output
                directory.
            just_ret_demo: When true, only demos are prepared: those whose
                ``demo[0][-1]["demo"] == 1`` are stored in
                ``self.demo_for_pi05`` and the constructor returns before the
                environment, dynamics model and replay buffer are created.
        """
        self.use_pi05=False
        if env_factory is None:
            env_factory = _create_default_envs(cfg)
        if create_replay_fn is None:
            create_replay_fn = _create_default_replay_buffer
    
        self.work_dir = Path(
            hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
            if work_dir is None
            else work_dir
        )
        print(f"workspace: {self.work_dir}")

        self.cfg = cfg
        utils.set_seed_everywhere(cfg.seed)
        # Device selection: CPU by default, "mps" on macOS, otherwise a CUDA
        # index derived from the Hydra job number when one is available.
        dev = "cpu"
        if cfg.num_gpus > 0:
            if sys.platform == "darwin":
                dev = "mps"
            else:
                dev = 0
                job_num = False
                try:
                    job_num = HydraConfig.get().job.get("num", False)
                except ValueError:
                    pass
                if job_num:
                    # Spread jobs across the available GPUs.
                    dev = job_num % cfg.num_gpus
        self.device = torch.device(dev)

        # Create the logger.
        self.logger = Logger(self.work_dir, cfg=self.cfg)
        self.env_factory = env_factory

        # Demos are handled before any environment is created.

        if (num_demos := cfg.demos) != 0:
            # Collect demos or fetch saved demos before making environments
            # to consider demo-based action space (e.g., standardization)
            self.demo_data = self.env_factory.collect_or_fetch_demos(cfg, num_demos)
        if just_ret_demo:
            # Post-process the demos and keep only those flagged successful.
            demo_data = self.env_factory.post_collect_or_fetch_demos(cfg, self.work_dir)
            demos = []
            for i, demo in enumerate(demo_data):
                successful = (demo[0][-1]["demo"] == 1)
                if successful:
                    demos.append(demo)
                else:
                    print(f"Skipping failed demonstration {i}")
                    continue
            self.demo_for_pi05 = demos
            return
        self.agent_eval_env = None
        self.dynamics_eval_env = self.env_factory.make_eval_env(cfg, self.work_dir)
        self.error_calculator = ErrorCalculator(self.dynamics_eval_env)

        # Instantiate the dynamics model from cfg.method using the spaces of
        # the evaluation environment.
        observation_space = self.dynamics_eval_env.observation_space
        action_space = self.dynamics_eval_env.action_space

        self.dynamics = hydra.utils.instantiate(
            cfg.method,
            device=self.device,
            observation_space=observation_space, # original space
            action_space=action_space, 
            use_var=cfg.use_var,
            use_qpos_pred=cfg.use_qpos_pred,
            use_pixels=cfg.use_pixels,
        )
        # The model starts in evaluation mode.
        self.dynamics.train(False)

        if num_demos != 0:
            # Post-process demos using the information from environments
            self.env_factory.post_collect_or_fetch_demos(cfg, self.work_dir)
        
        # self.obs_stats = copy.deepcopy(self.env_factory._obs_stats)

        # Private copy of the factory's action statistics.
        self.act_stats = copy.deepcopy(self.env_factory._action_stats)
        # sequential replay buffer, with raw env action space
        self.replay_buffer = create_replay_fn(cfg, observation_space, action_space)
        self.extra_replay_elements = self.replay_buffer.extra_replay_elements

        self.replay_loader = DataLoader(
            self.replay_buffer,
            batch_size=self.replay_buffer.batch_size,
            num_workers=cfg.replay.num_workers,
            pin_memory=cfg.replay.pin_memory,
            worker_init_fn=_worker_init_fn,
        )
        # Created lazily by the `replay_iter` property.
        self._replay_iter = None

        # NOTE(review): an earlier comment here said the cv2 import had been
        # moved to this point for RLBench's sake, but cv2 is imported at
        # module level in this file — confirm whether that is still a concern.

        self._timer = utils.Timer()
        self._pretrain_step = 0
        self._main_loop_iterations = 0
        self._global_env_episode = 0
        # self._act_dim = self.dynamics_eval_env.action_space.shape[1]

        self._shutting_down = False

        # Best success rate / episode length seen so far (updated in
        # `_eval_policy`).
        self.best_metrics = {
            "best_episode_success": 0,  
            "best_episode_len": 0,  
        }
    
    def prepare_eval(self, pi05=False):
        path_to_agent = f"exp_local/pixel_act/bigym_{self.cfg.env.task_name}_speedup_qv/snapshots/best_snapshot.pt"
        if pi05:
            path_to_agent = path_to_agent.split("/")
            path_to_agent[1] = "pi05"
            path_to_agent = "/".join(path_to_agent)
        print("Loading agents...", path_to_agent)
        self.load_agent_snapshot(path_to_agent)
        if pi05:
            return
        
        path_to_rollouts = f"exp_local/pixel_act/bigym_{self.cfg.env.task_name}_speedup_qv/rollouts/best_policy_rollouts.pt"
        if pi05:
            path_to_rollouts = path_to_rollouts.split("/")
            path_to_rollouts[1] = "pi05"
            path_to_rollouts = "/".join(path_to_rollouts)
        if not os.path.exists(path_to_rollouts):
            print("Generating rollouts...", path_to_rollouts)
            rollouts = self._get_policy_rollouts(num_episodes=10)
            self._save_policy_rollouts(rollouts, path_to_rollouts)
        print("Loading rollouts...", path_to_rollouts)
        self.rollouts = self._load_policy_rollouts(path_to_rollouts)

        path_to_ds_rollouts = f"exp_local/pixel_act/bigym_{self.cfg.env.task_name}_speedup_qv/rollouts/best_policy_ds_rollouts.pt"
        if pi05:
            path_to_ds_rollouts = path_to_ds_rollouts.split("/")
            path_to_ds_rollouts[1] = "pi05"
            path_to_ds_rollouts = "/".join(path_to_ds_rollouts)
        if not os.path.exists(path_to_ds_rollouts):
            print("Generating rollouts...", path_to_ds_rollouts)
            rollouts = self._get_policy_ds_rollouts(2, self.rollouts) # downsample rate 2
            self._save_policy_rollouts(rollouts, path_to_ds_rollouts)
        print("Loading rollouts...", path_to_ds_rollouts)
        self.ds_rollouts = self._load_policy_rollouts(path_to_ds_rollouts)

    # def preprocess_obs_and_act(self, obs, act):
    #     # normalize and flatten
    #     for k, v in obs.items():
    #         if self.cfg.norm_obs and k in self.obs_stats:
    #             v = (v - self.obs_stats["mean"][k]) / self.obs_stats["std"][k]
    #     rescaled_act = RescaleFromTanhWithMinMax.transform_to_tanh(
    #         act,
    #         action_stats=self.act_stats,
    #         min_max_margin=self.cfg.min_max_margin,
    #     )
    #     return obs, rescaled_act

    @property
    def pretrain_steps(self) -> int:
        """Return the pretraining step counter ``_pretrain_step`` (initialised to 0 in ``__init__``)."""
        return self._pretrain_step

    @property
    def main_loop_iterations(self) -> int:
        """Return the main-loop iteration counter ``_main_loop_iterations`` (initialised to 0 in ``__init__``)."""
        return self._main_loop_iterations

    @property
    def global_env_episodes(self) -> int:
        """Return the environment episode counter ``_global_env_episode`` (initialised to 0 in ``__init__``)."""
        return self._global_env_episode

    @property
    def global_env_steps(self) -> int:
        """Return the step count reported as total environment steps.

        This is simply ``pretrain_steps``; no separate environment-step
        counter is read here.
        """
        return self.pretrain_steps

    @property
    def replay_iter(self):
        if self._replay_iter is None:
            _replay_iter = iter(self.replay_loader)
            self._replay_iter = _replay_iter
        return self._replay_iter

    def train(self):
        signal.signal(signal.SIGINT, self._signal_handler)
        try:
            self._train()
        except Exception as e:
            self.shutdown()
            raise e

    def _train(self):
        # Load Demo
        # self._eval(self.rollouts, seq_len=self.cfg.action_sequence)
        self._load_demos()
        self.prepare_eval()

        # eval before training
        eval_metrics = self._eval(seq_len=self.cfg.action_sequence)
        selected_steps = [1, self.cfg.action_sequence//4, self.cfg.action_sequence//2, self.cfg.action_sequence]
        selected_eval_metrics = {}
        for k, v in eval_metrics.items():
            for selected_step in selected_steps:
                selected_eval_metrics[k+f"_step{selected_step}"] = v[selected_step-1]
        selected_eval_metrics.update(self._get_common_metrics())
        self.logger.log_metrics(
            selected_eval_metrics, self.pretrain_steps, prefix="pretrain_eval"
        )
        
        # Perform pretraining. This is suitable for behaviour cloning or Offline RL
        self._pretrain_on_demos()

        if self.cfg.save_snapshot:
            self.save_snapshot()

        self.shutdown()

    def eval(self) -> dict[str, Any]:
        """Evaluate the dynamics model on the cached rollouts with a fixed seed.

        Seeds via ``set_seed(1000)``, prepares the agent and rollouts with
        ``prepare_eval`` and returns the metrics of ``_eval`` for
        ``cfg.action_sequence`` prediction steps.
        """
        eval_seed = 1000
        set_seed(eval_seed)
        # Loads the snapshot and the rollouts that `_eval` reads.
        self.prepare_eval()
        horizon = self.cfg.action_sequence
        return self._eval(seq_len=horizon)

    def _eval(self, seq_len):
        metric1 = self._eval_eef(self.rollouts, seq_len, ds_rate=1)
        metric2 = self._eval_eef(self.ds_rollouts, seq_len, ds_rate=2)
        return metric1 | metric2

    def _eval_policy(self, num_episodes) -> dict[str, Any]:
        metric1 = self._eval_adads_policy_new(num_episodes=num_episodes, ds_rates=[2,3,4], threshold=0.01, record_media=False)
        metric2 = self._eval_adads_policy_new(num_episodes=num_episodes, ds_rates=[2,3,4], threshold=0.015, record_media=False)
        if self.best_metrics['best_episode_success'] <= metric1['adads_th0.005_success_rate']:
            self.best_metrics = {
                "best_episode_success": metric1['adads_th0.005_success_rate'],  
                "best_episode_len": metric1['adads_th0.005_avg_steps'], 
            }
            self.save_snapshot(best_ckpt = True)
        if self.best_metrics['best_episode_success'] <= metric2['adads_th0.01_success_rate']:
            self.best_metrics = {
                "best_episode_success": metric2['adads_th0.01_success_rate'],  
                "best_episode_len": metric2['adads_th0.01_avg_steps'], 
            }
            self.save_snapshot(best_ckpt = True)

        return metric1 | metric2

    def _eval_original_rollouts(self, num_episodes) -> tuple[dict[str, Any], str]:
        """Roll out the agent in ``agent_eval_env`` and report its success statistics.

        An episode counts as a success when its summed reward is positive.

        Args:
            num_episodes: Number of episodes to run.
                NOTE(review): 0 makes the success-rate division fail.

        Returns:
            A tuple ``(stats, task)`` where ``stats`` holds ``"success_rate"``
            and ``"avg_steps"`` (mean number of recorded actions over
            successful episodes, or ``None`` without any success) and ``task``
            is ``str(self.agent.current_task)``.
        """
        # rollout policy and record rollout data
        self.agent.set_eval_env_running(True)
        env = self.agent_eval_env
        cur_trail_id = 0
        num_suc = 0
        total_steps = 0
        rgb_names = self.cfg.env.cameras
        for i in tqdm(range(num_episodes), desc="Generating Policy Rollouts"):
            observation, infos = env.reset()
            
            # Stack the per-camera images along axis 1.
            rgb_obs = np.stack([observation['rgb_' + name] for name in rgb_names], axis=1)
            low_dim_obs = observation["low_dim_state"]   # unnormalized
            eef = observation["eef"]   # unnormalized

            self.agent.reset(self.main_loop_iterations, [0])
            done = False
            rewards = 0
            # Per-episode record; only the number of actions is used below.
            rollouts = {
                'rgb_obs': [rgb_obs],
                'low_dim_obs': [low_dim_obs],
                'eef': [eef],
                'act': [],
                'seed': env.unwrapped._current_seed,
            }
            while not done:
                with torch.no_grad(), utils.eval_mode(self.agent):
                    # Add a leading batch axis and move every entry to the device.
                    torch_observations = {
                        k: torch.from_numpy(v).unsqueeze(0).to(self.device) for k, v in observation.items()
                    }
                    if self.use_pi05:
                        # The pi05 agent also receives the env; its action is
                        # passed to env.step without conversion.
                        action = self.agent.act(
                            torch_observations, self.main_loop_iterations, eval_mode=True, act_env=env
                        )
                    else:
                        action = self.agent.act(
                            torch_observations, self.main_loop_iterations, eval_mode=True
                        )
                        action = action[0].cpu().detach().numpy()
                next_observation, reward, termination, truncation, next_info = env.step(action)
                # Observations and actions recorded by the env wrapper for this step call.
                all_recorded_data = env.get_wrapper_attr("get_recorded_data")()
                rewards += reward
                done = termination | truncation
                observation = next_observation
                obs, act = all_recorded_data['obs'], all_recorded_data['act']
                rgb_obs = np.stack([obs['rgb_' + name] for name in rgb_names], axis=1)
                low_dim_obs = obs["low_dim_state"]   # unnormalized
                eef = obs["eef"]   # unnormalized
                rollouts['rgb_obs'].append(rgb_obs)
                rollouts['low_dim_obs'].append(low_dim_obs)
                rollouts['eef'].append(eef)
                rollouts['act'].append(act)
            rollouts['rgb_obs'] = np.concatenate(rollouts['rgb_obs'], axis=0)
            rollouts['low_dim_obs'] = np.concatenate(rollouts['low_dim_obs'], axis=0)
            rollouts['eef'] = np.concatenate(rollouts['eef'], axis=0)
            rollouts['act'] = np.concatenate(rollouts['act'], axis=0)
            rollouts['reward'] = rewards
            if rewards > 0:
                num_suc += 1
                print("succ")
                # Only successful episodes contribute to the average length.
                total_steps += len(rollouts['act'])
            cur_trail_id += 1
        print(f"Policy success rate: {num_suc}/{num_episodes}={num_suc/num_episodes}")
        # if self.use_pi05:
        #     exit(0)
        ret = {
            "success_rate": num_suc/num_episodes,
            "avg_steps": total_steps/num_suc if num_suc>0 else None
        }
        # NOTE(review): this closes the shared agent_eval_env — confirm that
        # later callers do not reuse it.
        env.close()
        return ret, str(self.agent.current_task)


    def _get_policy_rollouts(self, num_episodes) -> list[dict[str, Any]]:
        """Roll out the agent in ``agent_eval_env`` and record every episode.

        Args:
            num_episodes: Number of episodes to run.
                NOTE(review): 0 makes the success-rate division fail.

        Returns:
            One dict per episode with the concatenated ``"rgb_obs"``,
            ``"low_dim_obs"``, ``"eef"`` and ``"act"`` arrays, the episode
            ``"seed"`` and the summed ``"reward"``.
        """
        # rollout policy and record rollout data
        self.agent.set_eval_env_running(True)
        env = self.agent_eval_env
        all_rollouts = []
        num_suc = 0
        rgb_names = self.cfg.env.cameras
        for i in tqdm(range(num_episodes), desc="Generating Policy Rollouts"):
            observation, infos = env.reset()
            
            # Stack the per-camera images along axis 1.
            rgb_obs = np.stack([observation['rgb_' + name] for name in rgb_names], axis=1)
            low_dim_obs = observation["low_dim_state"]   # unnormalized
            eef = observation["eef"]   # unnormalized

            self.agent.reset(self.main_loop_iterations, [0])
            done = False
            rewards = 0
            # The lists start with the reset observation, so the observation
            # lists hold one more entry than 'act'.
            rollouts = {
                'rgb_obs': [rgb_obs],
                'low_dim_obs': [low_dim_obs],
                'eef': [eef],
                'act': [],
                'seed': env.unwrapped._current_seed,
            }
            while not done:
                with torch.no_grad(), utils.eval_mode(self.agent):
                    # Add a leading batch axis and move every entry to the device.
                    torch_observations = {
                        k: torch.from_numpy(v).unsqueeze(0).to(self.device) for k, v in observation.items()
                    }
                    if self.use_pi05:
                        # The pi05 agent also receives the env; its action is
                        # passed to env.step without conversion.
                        action = self.agent.act(
                            torch_observations, self.main_loop_iterations, eval_mode=True, act_env=env
                        )
                    else:
                        action = self.agent.act(
                            torch_observations, self.main_loop_iterations, eval_mode=True
                        )
                        action = action[0].cpu().detach().numpy()
                next_observation, reward, termination, truncation, next_info = env.step(action)
                # Observations and actions recorded by the env wrapper for this step call.
                all_recorded_data = env.get_wrapper_attr("get_recorded_data")()
                rewards += reward
                done = termination | truncation
                observation = next_observation
                obs, act = all_recorded_data['obs'], all_recorded_data['act']
                rgb_obs = np.stack([obs['rgb_' + name] for name in rgb_names], axis=1)
                low_dim_obs = obs["low_dim_state"]   # unnormalized
                eef = obs["eef"]   # unnormalized
                rollouts['rgb_obs'].append(rgb_obs)
                rollouts['low_dim_obs'].append(low_dim_obs)
                rollouts['eef'].append(eef)
                rollouts['act'].append(act)
            rollouts['rgb_obs'] = np.concatenate(rollouts['rgb_obs'], axis=0)
            rollouts['low_dim_obs'] = np.concatenate(rollouts['low_dim_obs'], axis=0)
            rollouts['eef'] = np.concatenate(rollouts['eef'], axis=0)
            rollouts['act'] = np.concatenate(rollouts['act'], axis=0)
            rollouts['reward'] = rewards

            if self.use_pi05:
                # Dump the first camera's frames and turn them into a video.
                # NOTE(review): batch_save_images only accepts uint8 frames of
                # shape (3, 84, 84) — confirm the camera output matches.
                save_dir ="./temp"
                batch_save_images(rollouts['rgb_obs'][:, 0], save_dir=save_dir)
                images_to_video(img_dir=save_dir, video_path=f'./temp_videos/pi05_{len(all_rollouts)}.mp4')
            # An episode counts as a success when its summed reward is positive.
            if rewards > 0:
                num_suc += 1
            all_rollouts.append(rollouts)
        print(f"Policy success rate: {num_suc}/{num_episodes}={num_suc/num_episodes}")
        return all_rollouts
    
    def _save_policy_rollouts(self, rollouts, path):
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("wb") as f:
            torch.save(rollouts, f)
        print(f"Saved {len(rollouts)} rollouts to {path}")
    
    def _load_policy_rollouts(self, path):
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Rollouts file not found at {path}")
        with path.open("rb") as f:
            rollouts = torch.load(f, map_location="cpu", weights_only=False)
        print(f"Loaded {len(rollouts)} rollouts from {path}")
        return rollouts

    def _get_policy_ds_rollouts(self, ds_rate, rollouts):
        wrapped_env = self.env_factory._wrap_env(
            self.env_factory._create_env(self.agent_cfg, self.work_dir),
            self.cfg,
            demo_env=False,
            train=False,
            chunk_env=False
        )
        new_rollouts = []
        rgb_names = self.cfg.env.cameras
        num_suc = 0
        for rollout in tqdm(rollouts, desc=f"Generating DS{ds_rate} Policy Rollouts"):
            act, seed = rollout['act'], rollout['seed']
            observation, info = wrapped_env.reset(seed=seed)
            rgb_obs = np.stack([observation['rgb_' + name] for name in rgb_names], axis=1)
            low_dim_obs = observation["low_dim_state"]
            eef = observation["eef"]
            ds_rollout = {
                'rgb_obs': [rgb_obs],
                'low_dim_obs': [low_dim_obs],
                'eef': [eef],
                'act': [],
                'seed': rollout['seed'],
            }

            time_len = len(act)
            rewards = 0
            for i in range(time_len):
                if i % ds_rate == ds_rate -1 and i < time_len:
                    action = act[i]
                    observation, r, _, _, _ = wrapped_env.step(act[i])
                    rewards += r
                    rgb_obs = np.stack([observation['rgb_' + name] for name in rgb_names], axis=1)
                    low_dim_obs = observation["low_dim_state"]
                    eef = observation["eef"]
                    ds_rollout['rgb_obs'].append(rgb_obs)
                    ds_rollout['low_dim_obs'].append(low_dim_obs)
                    ds_rollout['eef'].append(eef)
                    ds_rollout['act'].append(action[np.newaxis, :])
            if rewards > 0:
                num_suc += 1
            ds_rollout['rgb_obs'] = np.concatenate(ds_rollout['rgb_obs'], axis=0)
            ds_rollout['low_dim_obs'] = np.concatenate(ds_rollout['low_dim_obs'], axis=0)
            ds_rollout['eef'] = np.concatenate(ds_rollout['eef'], axis=0)
            ds_rollout['act'] = np.concatenate(ds_rollout['act'], axis=0)
            ds_rollout['reward'] = rewards
            new_rollouts.append(ds_rollout)
        print(f"DS {ds_rate} policy success rate: {num_suc}/{len(self.rollouts)}={num_suc/len(self.rollouts)}")
        wrapped_env.close()
        return new_rollouts

    def _eval_ds_eef_diff(self, rollout, ds_rates=(2, 3, 4), chunk_length=1):
        """Plot predicted vs. measured end-effector tracking error for down-sampled replays.

        For every rate in ``ds_rates`` a fresh environment is reset with the
        rollout's seed and every ``ds_rate``-th recorded action is replayed,
        ``chunk_length`` actions at a time. Per chunk, the dynamics model
        predicts the end-effector states from the observation at the chunk
        start, and ``error_calculator`` compares the desired end-effector
        states against both the prediction and the states the environment
        actually reached. Both curves are saved to
        ``<work_dir>/eef_diff_plot_cl<chunk_length>.png``.

        Args:
            rollout: Rollout dict providing ``"act"`` and ``"seed"``.
            ds_rates: Down-sampling rates to evaluate.
            chunk_length: Number of down-sampled actions predicted at once.
        """
        # given a rollout and ds_rate, eval eef diff predicted by model and env
        wrapped_env = self.env_factory._wrap_env(
            self.env_factory._create_env(self.agent_cfg, self.work_dir),
            self.cfg,
            demo_env=False,
            train=False,
            chunk_env=False
        )
        rgb_names = self.cfg.env.cameras
        use_pixels = getattr(self.cfg, "use_pixels", True)

        # The recorded actions and their unscaled counterparts are the same
        # for every rate, so compute them once.
        act, seed = rollout['act'], rollout['seed']
        unscaled_act = RescaleFromTanhWithMinMax.transform_from_tanh(
            act,
            action_stats=self.act_stats,
            min_max_margin=self.cfg.min_max_margin,
        )

        diff_preds_all = {}
        diff_reals_all = {}

        try:
            for ds_rate in tqdm(ds_rates, "Evaluating EEF Diff with DS Rate"):
                diff_preds = []
                diff_reals = []

                observation, info = wrapped_env.reset(seed=seed)

                ds_actions = act[ds_rate-1::ds_rate]
                unscaled_ds_actions = unscaled_act[ds_rate-1::ds_rate]

                for i in range(0, len(ds_actions), chunk_length):
                    # prediction from the observation at the chunk start
                    rgb_obs = np.stack([observation['rgb_' + name] for name in rgb_names], axis=1)
                    low_dim_obs = observation["low_dim_state"]

                    chunk_act = ds_actions[i : i + chunk_length]
                    unscaled_chunk_act = unscaled_ds_actions[i : i + chunk_length]

                    act_torch = torch.from_numpy(chunk_act).float().unsqueeze(0).to(self.device) # 1, chunk_len, act_dim
                    rgb_obs_torch = torch.from_numpy(rgb_obs).float().to(self.device)
                    low_dim_obs_torch = torch.from_numpy(low_dim_obs).float().to(self.device)
                    if not use_pixels:
                        rgb_obs_torch = None
                    # step_multi returns a pair (see _eval_eef and
                    # get_feedback_eef); unpack it first, then drop the batch
                    # axis. Indexing [0] before unpacking fails for batch 1.
                    with torch.no_grad():
                        chunk_pred, _ = self.dynamics.step_multi(rgb_obs_torch, low_dim_obs_torch, act_torch)
                    chunk_pred = chunk_pred[0].cpu().numpy() # chunk_len, eef_dim

                    desired_eefs = []
                    actual_eefs = []
                    for j in range(len(chunk_act)):
                        # desired: computed from the unscaled action before stepping
                        desired_eef = self.error_calculator._desired_eef_array(wrapped_env, unscaled_chunk_act[j:j+1])
                        desired_eefs.append(desired_eef)
                        # actual: what the environment reports after the step
                        observation, r, _, _, _ = wrapped_env.step(chunk_act[j])
                        actual_eefs.append(observation["eef"])
                    desired_eef = np.concatenate(desired_eefs, axis=0)
                    actual_eef = np.concatenate(actual_eefs, axis=0)
                    pred_diff = self.error_calculator.compute_eef_error_from_array(desired_eef, chunk_pred).mean()
                    real_diff = self.error_calculator.compute_eef_error_from_array(desired_eef, actual_eef).mean()
                    # Repeat each value ds_rate times to stretch the curve
                    # over the original time axis.
                    # NOTE(review): for chunk_length > 1 a chunk spans
                    # ds_rate * len(chunk_act) original steps — confirm the
                    # intended x-axis scaling.
                    diff_preds.extend([pred_diff] * ds_rate)
                    diff_reals.extend([real_diff] * ds_rate)
                diff_preds_all[ds_rate] = diff_preds
                diff_reals_all[ds_rate] = diff_reals
        finally:
            # Release the environment even when a replay fails.
            wrapped_env.close()

        from matplotlib import pyplot as plt
        for k, v in diff_preds_all.items():
            plt.plot(v, label=f"pred ds{k}")
            plt.plot(diff_reals_all[k], label=f"real ds{k}")
            plt.legend()
        plt.savefig(os.path.join(self.work_dir, f"eef_diff_plot_cl{chunk_length}.png"))
        plt.close()

    
    def _eval_eef(self, rollouts, seq_len, ds_rate=1):
        # given rollouts and seq_len, eval dynamics model prediction error
        abs_errors = []
        grip_errors = []
        qpos_errors = []
        # eef_mean = self.obs_stats["mean"]["eef"]
        # eef_std = self.obs_stats["std"]["eef"]
        # tmp_env = self.env_factory._create_env(self.cfg, self.work_dir)
        for rollout in tqdm(rollouts, desc="Evaluating Dynamics Model"):
            rgb_obs, low_dim_state, act = rollout['rgb_obs'], rollout['low_dim_obs'], rollout['act']
            eef = rollout['eef']
            eval_batch_size = self.cfg.batch_size
            time_len = len(act)
            slices = list(range(0, time_len-seq_len, eval_batch_size))
            rest_len = (time_len-seq_len) % eval_batch_size

            for i in slices:
                # prepare batched input
                batch_len = rest_len if i + eval_batch_size >= len(act)-seq_len else eval_batch_size
                act_batch = act[i:i+batch_len+seq_len] 
                rgb_obs_batch = rgb_obs[i:i+batch_len] # b, x
                low_dim_state_batch = low_dim_state[i:i+batch_len]  # b, x 
                act_seq_batch = np.stack([act_batch[j:j+seq_len] for j in range(batch_len)], axis=0) # b, seq_len, x
                # already normalized in data collection
                normalized_act_seq_batch = torch.from_numpy(act_seq_batch).float().to(self.device)
                # rgb_obs does not need normalization
                rgb_obs_batch = torch.from_numpy(rgb_obs_batch).float().to(self.device)
                normalized_low_dim_state_batch = torch.from_numpy(low_dim_state_batch).float().to(self.device)
                if not getattr(self.cfg, "use_pixels", True):
                    rgb_obs_batch = None  
                next_eef_pred_batch, next_qpos_pred_batch = self.dynamics.step_multi(rgb_obs_batch, normalized_low_dim_state_batch, normalized_act_seq_batch) # b, seq_len, x
                
                next_eef_pred_batch = next_eef_pred_batch.cpu().numpy()
                
                all_next_eef_batch = eef[i+1:i+1+batch_len+seq_len]
                
                next_eef_batch = np.stack([all_next_eef_batch[j:j+seq_len] for j in range(batch_len)], axis=0)
                # only examine qpos per step error
                e = np.abs(next_eef_pred_batch - next_eef_batch)
                abs_errors.append(np.mean(e, axis=(0, 2)))
                grip_errors.append(np.mean(e[..., [7,15]], axis=(0, 2)))  # exclude gripper

                if self.cfg.use_qpos_pred:
                    next_qpos_pred_batch = next_qpos_pred_batch.cpu().numpy()
                    all_next_qpos_batch = low_dim_state[i+1:i+1+batch_len+seq_len]
                    next_qpos_batch = np.stack([all_next_qpos_batch[j:j+seq_len] for j in range(batch_len)], axis=0)
                    qpos_error = np.mean(np.abs(next_qpos_pred_batch - next_qpos_batch), axis=(0, 2)) # reserve time length dim
                    qpos_errors.append(qpos_error)
                
        mean_abs_error = np.mean(np.array(abs_errors), axis=0)  # reserve time length dim
        grip_errors = np.mean(np.array(grip_errors), axis=0)  # reserve time length dim
        if self.cfg.use_qpos_pred:
            mean_qpos_error = np.mean(np.array(qpos_errors), axis=0)  # reserve time length dim
        else:
            mean_qpos_error = 0

        return {
            f"mean_eef_error_ds{ds_rate}": mean_abs_error,
            f"mean_eef_grip_error_ds{ds_rate}": grip_errors,
            f"mean_qpos_error_ds{ds_rate}": mean_qpos_error,
        }

    def get_feedback_eef(self, eef_model, observation, action_seq):
        """Query ``eef_model`` for the EEF trajectory predicted for ``action_seq``.

        Args:
            eef_model: Object exposing ``step_multi(rgb, low_dim, actions)``;
                only the first element of its return value is used.
            observation: Mapping holding ``"low_dim_state"`` and one
                ``"rgb_<camera>"`` array per name in ``self.cfg.env.cameras``.
            action_seq: NumPy array of actions; a leading batch axis is added
                before it is handed to the model.

        Returns:
            NumPy array with the first batch element of the model prediction
            (presumably ``(seq_len, eef_dim)`` -- TODO confirm against the model).
        """

        def to_device(array):
            # Everything handed to the model is a float tensor on self.device.
            return torch.from_numpy(array).float().to(self.device)

        camera_frames = [
            observation['rgb_' + camera] for camera in self.cfg.env.cameras
        ]
        # Cameras are stacked on axis 1, right after the leading axis.
        pixels = to_device(np.stack(camera_frames, axis=1))
        proprio = to_device(observation["low_dim_state"])
        actions = torch.from_numpy(action_seq).float().unsqueeze(0).to(self.device)

        # State-only configurations pass no image input to the model.
        if not getattr(self.cfg, "use_pixels", True):
            pixels = None

        predicted, _ = eef_model.step_multi(pixels, proprio, actions)
        return predicted[0].cpu().numpy()
    
    def get_feedback_eef_interpolated(self, eef_model, observation, action_seq, ds):
        """Predict EEF states for a downsampled action sequence and upsample them.

        The model is queried once for ``action_seq``. The current EEF from
        ``observation['eef']`` is prepended as the knot at time 0, prediction
        ``k`` is placed at time ``(k + 1) * ds``, and the trajectory is then
        interpolated at every integer time step ``1 .. len(action_seq) * ds``.

        Args:
            eef_model: Object exposing ``step_multi(rgb, low_dim, actions)``.
            observation: Mapping with ``"eef"``, ``"low_dim_state"`` and one
                ``"rgb_<camera>"`` entry per name in ``self.cfg.env.cameras``.
            action_seq: Downsampled action array.
            ds: Downsample rate used to produce ``action_seq``.

        Returns:
            Array of shape ``(len(action_seq) * ds, 16)`` laid out as
            ``[pos(3), quat(4), grip(1)]`` for the left arm followed by the
            same eight columns for the right arm.
        """

        def interpolate_pos_quat(position, quat, gripper, time_indices, target_indices):
            """Interpolate position, quaternion (NLERP) and gripper values."""
            pos_knots = np.array(position)
            quat_knots = np.array(quat)  # copy, so the sign fix stays local
            grip_knots = np.array(gripper)

            # Make neighbouring quaternions point into the same hemisphere
            # (positive dot product) so the blend follows the shortest path.
            for k in range(1, len(quat_knots)):
                if np.dot(quat_knots[k - 1], quat_knots[k]) < 0:
                    quat_knots[k] *= -1

            def per_column(knots, width):
                # Independent linear interpolation of each component.
                return np.column_stack([
                    np.interp(target_indices, time_indices, knots[:, col])
                    for col in range(width)
                ])

            pos_out = per_column(pos_knots, 3)

            # NLERP: component-wise blend followed by re-normalisation; the
            # epsilon protects against a zero-length quaternion.
            quat_out = per_column(quat_knots, 4)
            quat_out = quat_out / (np.linalg.norm(quat_out, axis=1, keepdims=True) + 1e-12)

            grip_out = np.interp(target_indices, time_indices, grip_knots)
            return pos_out, quat_out, grip_out

        def to_device(array):
            return torch.from_numpy(array).float().to(self.device)

        camera_frames = [
            observation['rgb_' + camera] for camera in self.cfg.env.cameras
        ]
        pixels = to_device(np.stack(camera_frames, axis=1))
        proprio = to_device(observation["low_dim_state"])
        actions = torch.from_numpy(action_seq).float().unsqueeze(0).to(self.device)
        if not getattr(self.cfg, "use_pixels", True):
            pixels = None

        predicted, _ = eef_model.step_multi(pixels, proprio, actions)
        predicted = predicted[0].cpu().numpy()
        # Prepend the currently observed EEF as the knot at time 0.
        knots = np.concatenate([observation['eef'], predicted], axis=0)

        num_actions = len(action_seq)
        # Knot times: 0 for the observation, then ds, 2*ds, ..., num_actions*ds.
        knot_times = np.arange(0, num_actions + 1) * ds
        # Query times: every original step 1, 2, ..., num_actions*ds.
        query_times = np.arange(1, num_actions * ds + 1)

        # Each arm occupies eight columns: pos(3), quat(4), grip(1).
        # NOTE(review): the original comment states the quaternion order is
        # (w, x, y, z) -- confirm against the EEF producer.
        pieces = []
        for offset in (0, 8):  # left arm, then right arm
            arm_pos, arm_quat, arm_grip = interpolate_pos_quat(
                knots[:, offset:offset + 3],
                knots[:, offset + 3:offset + 7],
                knots[:, offset + 7],
                knot_times,
                query_times,
            )
            pieces.extend([arm_pos, arm_quat, arm_grip[:, None]])

        return np.concatenate(pieces, axis=1)
    
    def get_desired_eef(self, env, action_seq):
        """Return the EEF targets that ``action_seq`` commands in ``env``.

        The actions are first passed through
        ``RescaleFromTanhWithMinMax.transform_from_tanh`` using ``self.act_stats``
        and ``self.cfg.min_max_margin``, then handed to
        ``self.error_calculator._desired_eef_array``.

        Returns:
            Whatever ``_desired_eef_array`` produces (presumably an array of
            shape ``(seq_len, eef_dim)`` -- TODO confirm).
        """
        raw_actions = RescaleFromTanhWithMinMax.transform_from_tanh(
            action_seq,
            action_stats=self.act_stats,
            min_max_margin=self.cfg.min_max_margin,
        )
        return self.error_calculator._desired_eef_array(env, raw_actions)
    
    def compute_diff_eef(self, desired_eefs, feedback_eefs, mode="mean"):
        """Reduce the EEF error between two trajectories to a single number.

        Args:
            desired_eefs: Reference EEF array.
            feedback_eefs: EEF array compared against the reference.
            mode: ``"mean"`` or ``"max"``; any other value selects the minimum.

        Returns:
            The chosen reduction of
            ``self.error_calculator.compute_eef_error_from_array(...)``.
        """
        errors = self.error_calculator.compute_eef_error_from_array(
            desired_eefs, feedback_eefs
        )
        # Every mode other than "mean"/"max" falls through to the minimum.
        reducer = np.mean if mode == "mean" else np.max if mode == "max" else np.min
        return reducer(errors)
    
    def _eval_te_policy(self, num_episodes=50):
        """Roll out the agent in ``self.agent_eval_env`` and print summary stats.

        An episode counts as a success when its accumulated reward is positive.
        Episode length is accumulated from ``info["sub_time_count"]`` of every
        env step.

        Args:
            num_episodes: Number of evaluation episodes to run.
        """
        self.agent.set_eval_env_running(True)
        env = self.agent_eval_env
        # NOTE(review): the sibling rollouts in this class gate the env-aware
        # act() call on `use_pi05`, whereas this method read `use_pi0`. Keep
        # honouring `use_pi0` when it exists, otherwise fall back to `use_pi05`
        # instead of raising AttributeError. Confirm which flag is intended.
        env_aware_act = getattr(self, "use_pi0", getattr(self, "use_pi05", False))

        num_suc = 0
        ep_lengths = []
        for _ in tqdm(range(num_episodes), desc="Evaluating Temporal Ensemble Policy Rollouts"):
            observation, _ = env.reset()
            self.agent.reset(self.main_loop_iterations, [0])
            done = False
            rewards = 0
            num_steps = 0
            while not done:
                with torch.no_grad(), utils.eval_mode(self.agent):
                    torch_observations = {
                        k: torch.from_numpy(v).unsqueeze(0).to(self.device)
                        for k, v in observation.items()
                    }
                    if env_aware_act:
                        # The env-aware agent's output is passed to the env as is.
                        action = self.agent.act(
                            torch_observations, self.main_loop_iterations, eval_mode=True, act_env=env
                        )
                    else:
                        action = self.agent.act(
                            torch_observations, self.main_loop_iterations, eval_mode=True
                        )
                        # Drop the batch axis and convert to NumPy for the env.
                        action = action[0].cpu().detach().numpy()
                observation, reward, termination, truncation, next_info = env.step(action)
                rewards += reward
                done = termination | truncation
                num_steps += next_info["sub_time_count"]
            ep_lengths.append(num_steps)
            if rewards > 0:
                num_suc += 1

        # Guard the summaries so a zero-episode call does not divide by zero
        # or average an empty list.
        success_rate = num_suc / num_episodes if num_episodes else 0.0
        avg_length = np.mean(ep_lengths) if ep_lengths else float("nan")
        print(f"Policy success rate: {num_suc}/{num_episodes}={success_rate}")
        print(f"Average episode length: {avg_length}")

        
    def _eval_ds_policy(self, num_episodes=50, ds=2, record_media=False):
        """Roll out the policy while executing only every ``ds``-th predicted action.

        A fresh evaluation env without the chunk wrapper is created (with
        ``high_gain`` enabled whenever ``ds != 1``), so each kept action is
        stepped individually. An episode counts as a success when its
        accumulated reward is positive; the average step count is computed
        over successful episodes only.

        Args:
            num_episodes: Number of evaluation episodes.
            ds: Stride used to subsample each predicted action chunk.
            record_media: When true (and the env has a render mode), one video
                per episode is written to ``<work_dir>/ds<ds>/<episode>.mp4``.

        Returns:
            Dict with ``ds<ds>_success_rate`` and ``ds<ds>_avg_steps``.
        """
        self.agent.set_eval_env_running(True)
        # Build an env without the chunk wrapper.
        wrapped_env = self.env_factory._wrap_env(
            self.env_factory._create_env(
                self.agent_cfg, self.work_dir, high_gain=(ds != 1)
            ),
            self.cfg,
            demo_env=False,
            train=False,
            chunk_env=False,
        )
        media_dir = os.path.join(self.work_dir, f"ds{ds}")
        num_steps_list = []
        suc_list = []

        try:
            for i in tqdm(range(num_episodes), desc=f"Testing DS{ds} Policy"):
                observation, _ = wrapped_env.reset()
                self.agent.reset(self.main_loop_iterations, [0])
                done = False
                rewards = 0
                num_steps = 0
                frames = []
                while not done:
                    with torch.no_grad(), utils.eval_mode(self.agent):
                        torch_observations = {
                            k: torch.from_numpy(v).unsqueeze(0).to(self.device)
                            for k, v in observation.items()
                        }
                        if self.use_pi05:
                            action = self.agent.act(
                                torch_observations, self.main_loop_iterations, eval_mode=True, act_env=wrapped_env
                            )
                        else:
                            action = self.agent.act(
                                torch_observations, self.main_loop_iterations, eval_mode=True
                            )
                            action = action[0].cpu().detach().numpy()
                    for ds_action in action[0::ds]:
                        next_observation, reward, termination, truncation, _ = wrapped_env.step(ds_action)
                        rewards += reward
                        num_steps += 1
                        if record_media and wrapped_env.render_mode:
                            frame = wrapped_env.render()
                            frame = put_text(frame, f"{num_steps},{ds}", font_size=1, resize=False)
                            frames.append(frame)
                        # Stop executing the chunk once the episode has ended:
                        # stepping a finished env is undefined and would both
                        # inflate the step count and overwrite the done flags.
                        if termination or truncation:
                            break
                    done = termination | truncation
                    observation = next_observation

                succeeded = rewards > 0
                suc_list.append(succeeded)
                if succeeded:
                    num_steps_list.append(num_steps)
                # Write each video right away rather than keeping every
                # episode's frames in memory; skip episodes without frames.
                if record_media and frames:
                    os.makedirs(media_dir, exist_ok=True)
                    imageio.mimsave(
                        os.path.join(media_dir, f"{i}.mp4"),
                        np.array(frames),
                        fps=wrapped_env.unwrapped._control_frequency,
                    )
        finally:
            # Release the env even when a rollout raises.
            wrapped_env.close()

        success_rate = sum(suc_list) / num_episodes if num_episodes else 0.0
        # NaN (without a NumPy warning) when no episode succeeded.
        avg_steps = np.mean(num_steps_list) if num_steps_list else float("nan")
        print(f"Policy success rate: {success_rate}")
        print(f"Average number of steps: {avg_steps}")
        return {
            f"ds{ds}_success_rate": success_rate,
            f"ds{ds}_avg_steps": avg_steps,
        }
    
    def _eval_adads_policy(self, num_episodes=50, ds_rates=(2, 3, 4), threshold=0.01, record_media=False, baseline_ds=None):
        """Roll out the policy with an adaptively chosen downsample rate per chunk.

        For every predicted action chunk the mean error between the desired
        EEF (``get_desired_eef``) and the dynamics-model EEF
        (``get_feedback_eef``) is computed at the baseline rate. Rates are then
        tried from largest to smallest; the first one whose mean error differs
        from the baseline by at most ``threshold`` is executed, and the
        smallest rate is used when none qualifies.

        Args:
            num_episodes: Number of evaluation episodes.
            ds_rates: Candidate downsample rates (must be non-empty).
            threshold: Maximum tolerated absolute difference to the baseline error.
            record_media: When true (and the env has a render mode), one video
                per episode is written to ``<work_dir>/adads_th<threshold>/``.
            baseline_ds: Rate used for the baseline error; defaults to the
                smallest entry of ``ds_rates``.

        Returns:
            Dict with ``adads_th<threshold>_success_rate`` and
            ``adads_th<threshold>_avg_steps`` (successful episodes only).

        Raises:
            ValueError: If ``ds_rates`` is empty.
        """
        self.agent.set_eval_env_running(True)
        downsample_list = sorted(ds_rates, reverse=True)
        # Loop invariants; min() raises ValueError on an empty rate list
        # before any env has been created.
        min_ds = min(downsample_list)
        baseline_ds = min_ds if baseline_ds is None else baseline_ds

        # Build an env without the chunk wrapper.
        wrapped_env = self.env_factory._wrap_env(
            self.env_factory._create_env(self.agent_cfg, self.work_dir),
            self.cfg,
            demo_env=False,
            train=False,
            chunk_env=False,
        )
        num_steps_list = []
        suc_list = []

        try:
            media_dir = os.path.join(self.work_dir, f"adads_th{threshold}")
            os.makedirs(media_dir, exist_ok=True)

            for i in tqdm(range(num_episodes), desc="Testing ADADS Policy"):
                observation, _ = wrapped_env.reset()
                self.agent.reset(self.main_loop_iterations, [0])
                done = False
                rewards = 0
                num_steps = 0
                frames = []
                while not done:
                    with torch.no_grad(), utils.eval_mode(self.agent):
                        torch_observations = {
                            k: torch.from_numpy(v).unsqueeze(0).to(self.device)
                            for k, v in observation.items()
                        }
                        if self.use_pi05:
                            action = self.agent.act(
                                torch_observations, self.main_loop_iterations, eval_mode=True, act_env=wrapped_env
                            )
                        else:
                            action = self.agent.act(
                                torch_observations, self.main_loop_iterations, eval_mode=True
                            )
                            action = action[0].cpu().detach().numpy()

                    # Baseline: desired-vs-predicted EEF error at the baseline rate.
                    all_desired_eef = self.get_desired_eef(wrapped_env, action)
                    baseline_actions = action[0::baseline_ds]
                    desired_eef = all_desired_eef[0::baseline_ds]
                    feedback_eef = self.get_feedback_eef(self.dynamics, observation, baseline_actions)
                    eef_baseline_value = self.compute_diff_eef(desired_eef, feedback_eef)

                    # Try the most aggressive rate first; the smallest rate is
                    # the unconditional fallback and needs no model query.
                    for ds in downsample_list:
                        ds_actions = action[0::ds]
                        if ds != min_ds:
                            desired_eef = all_desired_eef[0::ds]
                            feedback_eef = self.get_feedback_eef(self.dynamics, observation, ds_actions)
                            eef_value = self.compute_diff_eef(desired_eef, feedback_eef, mode="mean")
                            difference = abs(eef_value - eef_baseline_value)
                            if difference <= threshold:
                                break

                    for ds_action in ds_actions:
                        next_observation, reward, termination, truncation, _ = wrapped_env.step(ds_action)
                        rewards += reward
                        num_steps += 1
                        if record_media and wrapped_env.render_mode:
                            frame = wrapped_env.render()
                            frame = put_text(frame, f"{num_steps},{ds}", font_size=1, resize=False)
                            frames.append(frame)
                        # Stop executing the chunk once the episode has ended:
                        # stepping a finished env is undefined and would both
                        # inflate the step count and overwrite the done flags.
                        if termination or truncation:
                            break
                    done = termination | truncation
                    observation = next_observation

                succeeded = rewards > 0
                suc_list.append(succeeded)
                if succeeded:
                    num_steps_list.append(num_steps)
                # Skip the video when nothing was rendered.
                if record_media and frames:
                    imageio.mimsave(
                        os.path.join(media_dir, f"{i}_{succeeded}.mp4"),
                        np.array(frames),
                        fps=wrapped_env.unwrapped._control_frequency,
                    )
        finally:
            # Release the env even when a rollout raises.
            wrapped_env.close()

        success_rate = sum(suc_list) / num_episodes if num_episodes else 0.0
        # NaN (without a NumPy warning) when no episode succeeded.
        avg_steps = np.mean(num_steps_list) if num_steps_list else float("nan")
        print(f"Policy success rate: {success_rate}")
        print(f"Average number of steps: {avg_steps}")
        return {
            f"adads_th{threshold}_success_rate": success_rate,
            f"adads_th{threshold}_avg_steps": avg_steps,
        }

    def _eval_adads_policy_new(self, num_episodes=50, ds_rates=(2, 3, 4), threshold=0.01, record_media=False, baseline_ds=None):
        """Adaptive-downsample rollout that compares interpolated EEF predictions.

        For every predicted action chunk the dynamics model's EEF trajectory at
        the baseline rate is interpolated back to full resolution
        (``get_feedback_eef_interpolated``). Rates are tried from largest to
        smallest; the first one whose interpolated trajectory deviates from the
        baseline by at most ``threshold`` (maximum error over the common
        length) is executed, and the smallest rate is used when none qualifies.
        Each rate ``ds`` keeps the actions at indices ``ds-1, 2*ds-1, ...``.

        Args:
            num_episodes: Number of evaluation episodes.
            ds_rates: Candidate downsample rates (must be non-empty).
            threshold: Maximum tolerated deviation from the baseline trajectory.
            record_media: When true (and the env has a render mode), one video
                per episode is written to ``<work_dir>/adads_th<threshold>_new/``.
            baseline_ds: Rate used for the baseline trajectory; defaults to the
                smallest entry of ``ds_rates``.

        Returns:
            Dict with ``adads_th<threshold>_success_rate`` and
            ``adads_th<threshold>_avg_steps`` (successful episodes only).
            NOTE(review): these keys are identical to the ones returned by
            ``_eval_adads_policy``; callers logging both will overwrite one.

        Raises:
            ValueError: If ``ds_rates`` is empty.
        """
        self.agent.set_eval_env_running(True)
        downsample_list = sorted(ds_rates, reverse=True)
        # Loop invariants; min() raises ValueError on an empty rate list
        # before any env has been created.
        min_ds = min(downsample_list)
        baseline_ds = min_ds if baseline_ds is None else baseline_ds

        # Build an env without the chunk wrapper.
        wrapped_env = self.env_factory._wrap_env(
            self.env_factory._create_env(self.agent_cfg, self.work_dir),
            self.cfg,
            demo_env=False,
            train=False,
            chunk_env=False,
        )
        num_steps_list = []
        suc_list = []

        try:
            media_dir = os.path.join(self.work_dir, f"adads_th{threshold}_new")
            os.makedirs(media_dir, exist_ok=True)

            for i in tqdm(range(num_episodes), desc="Testing ADADS Policy"):
                observation, _ = wrapped_env.reset()
                self.agent.reset(self.main_loop_iterations, [0])
                done = False
                rewards = 0
                num_steps = 0
                frames = []
                while not done:
                    with torch.no_grad(), utils.eval_mode(self.agent):
                        torch_observations = {
                            k: torch.from_numpy(v).unsqueeze(0).to(self.device)
                            for k, v in observation.items()
                        }
                        if self.use_pi05:
                            action = self.agent.act(
                                torch_observations, self.main_loop_iterations, eval_mode=True, act_env=wrapped_env
                            )
                        else:
                            action = self.agent.act(
                                torch_observations, self.main_loop_iterations, eval_mode=True
                            )
                            action = action[0].cpu().detach().numpy()

                    # Baseline trajectory at full resolution.
                    baseline_actions = action[baseline_ds - 1::baseline_ds]
                    feedback_eef_baseline = self.get_feedback_eef_interpolated(
                        self.dynamics, observation, baseline_actions, baseline_ds
                    )

                    # Try the most aggressive rate first; the smallest rate is
                    # the unconditional fallback and needs no model query.
                    for ds in downsample_list:
                        ds_actions = action[ds - 1::ds]
                        if ds != min_ds:
                            feedback_eef = self.get_feedback_eef_interpolated(
                                self.dynamics, observation, ds_actions, ds
                            )
                            # The two trajectories may differ in length; compare
                            # only the overlapping prefix.
                            min_length = min(len(feedback_eef_baseline), len(feedback_eef))
                            difference = self.compute_diff_eef(
                                feedback_eef_baseline[:min_length], feedback_eef[:min_length], mode="max"
                            )
                            print(f"ds:{ds}, difference:{difference}")
                            if difference <= threshold:
                                break

                    for ds_action in ds_actions:
                        next_observation, reward, termination, truncation, _ = wrapped_env.step(ds_action)
                        rewards += reward
                        num_steps += 1
                        if record_media and wrapped_env.render_mode:
                            frame = wrapped_env.render()
                            frame = put_text(frame, f"{num_steps},{ds}", font_size=1, resize=False)
                            frames.append(frame)
                        # Stop executing the chunk once the episode has ended:
                        # stepping a finished env is undefined and would both
                        # inflate the step count and overwrite the done flags.
                        if termination or truncation:
                            break
                    done = termination | truncation
                    observation = next_observation

                succeeded = rewards > 0
                # Skip the video when nothing was rendered.
                if record_media and frames:
                    imageio.mimsave(
                        os.path.join(media_dir, f"{i}_{succeeded}.mp4"),
                        np.array(frames),
                        fps=wrapped_env.unwrapped._control_frequency,
                    )
                suc_list.append(succeeded)
                if succeeded:
                    num_steps_list.append(num_steps)
        finally:
            # Release the env even when a rollout raises.
            wrapped_env.close()

        success_rate = sum(suc_list) / num_episodes if num_episodes else 0.0
        # NaN (without a NumPy warning) when no episode succeeded.
        avg_steps = np.mean(num_steps_list) if num_steps_list else float("nan")
        print(f"Policy success rate: {success_rate}")
        print(f"Average number of steps: {avg_steps}")
        return {
            f"adads_th{threshold}_success_rate": success_rate,
            f"adads_th{threshold}_avg_steps": avg_steps,
        }

    def _signal_handler(self, sig, frame):
        """Signal callback: announce the interrupt and set ``_shutting_down``.

        ``sig`` and ``frame`` are accepted to match the handler signature but
        are not used.
        """
        notice = "\nCtrl+C detected. Preparing to shutdown..."
        print(notice)
        self._shutting_down = True

    def _load_demos(self):
        """Load demonstrations into ``self.replay_buffer`` via the env factory.

        Raises:
            NotImplementedError: If ``self.cfg.demos`` is 0.
        """
        if self.cfg.demos == 0:
            raise NotImplementedError

        # NOTE: Currently we do not protect demos from being evicted from replay
        self.env_factory.load_demos_into_replay(
            self.cfg,
            self.replay_buffer,
            is_demo_buffer=False,  # load all demos
        )

    def _perform_updates(self) -> dict[str, Any]:
        """Run one update of the dynamics model and return its metrics.

        The model is put into training mode for the duration of the update and
        back into eval mode afterwards.
        """
        model = self.dynamics
        model.train(True)
        update_metrics = model.update(
            self.replay_iter, self.main_loop_iterations, self.replay_buffer
        )
        # Copy into a fresh dict, as the caller may extend the result.
        metrics = dict(update_metrics)
        model.train(False)
        return metrics

    def _pretrain_on_demos(self):
        """Pre-train the dynamics model on the replay buffer.

        Runs ``self.cfg.num_pretrain_steps`` update steps. Training metrics are
        logged every ``log_pretrain_every`` steps, evaluation (``self._eval``
        and ``self._eval_policy``) runs every ``eval_every_steps`` steps, and a
        snapshot is written every ``snapshot_every_n`` steps when
        ``save_snapshot`` is enabled.

        Raises:
            ValueError: If the replay buffer is empty.
        """
        pre_train_until_step = utils.Until(self.cfg.num_pretrain_steps)
        should_pretrain_log = utils.Every(self.cfg.log_pretrain_every)
        should_pretrain_eval = utils.Every(self.cfg.eval_every_steps)
        snapshot_every_n = self.cfg.snapshot_every_n if self.cfg.save_snapshot else 0
        should_save_snapshot = utils.Every(snapshot_every_n)
        if self.cfg.log_pretrain_every > 0:
            assert self.cfg.num_pretrain_steps % self.cfg.log_pretrain_every == 0
        if len(self.replay_buffer) <= 0:
            raise ValueError(
                "there is no sample to pre-train with in the replay buffer "
                f"but num_pretrain_steps ({self.cfg.num_pretrain_steps}) is > 0"
            )

        def select_eval_steps(eval_metrics, horizon):
            """Pick the per-step errors at steps 1, horizon/4, horizon/2, horizon."""
            # dict.fromkeys de-duplicates while keeping the ascending order.
            candidate_steps = dict.fromkeys(
                (1, horizon // 4, horizon // 2, horizon)
            )
            selected = {}
            for name, per_step in eval_metrics.items():
                if np.ndim(per_step) == 0:
                    # Scalar metrics (e.g. a disabled metric reported as 0)
                    # cannot be indexed per step; log them unchanged.
                    selected[name] = per_step
                    continue
                for step in candidate_steps:
                    # For short horizons `horizon // 4` is 0, which previously
                    # indexed v[-1] and logged the last step as "_step0".
                    if 1 <= step <= len(per_step):
                        selected[name + f"_step{step}"] = per_step[step - 1]
            return selected

        while pre_train_until_step(self.pretrain_steps):
            # Enable the model's logging only on steps that will be logged.
            self.dynamics.logging = False

            if should_pretrain_log(self.pretrain_steps):
                self.dynamics.logging = True
            pretrain_metrics = self._perform_updates()

            if should_pretrain_log(self.pretrain_steps):
                pretrain_metrics.update(self._get_common_metrics())
                self.logger.log_metrics(
                    pretrain_metrics, self.pretrain_steps, prefix="pretrain"
                )
            # TODO: Finish eval logic for dynamics model
            if should_pretrain_eval(self.pretrain_steps):
                horizon = self.cfg.action_sequence
                eval_metrics = self._eval(seq_len=horizon)
                selected_eval_metrics = select_eval_steps(eval_metrics, horizon)
                selected_eval_metrics.update(self._get_common_metrics())
                self.logger.log_metrics(
                    selected_eval_metrics, self.pretrain_steps, prefix="pretrain_eval"
                )

                policy_eval_metrics = self._eval_policy(num_episodes=50)
                self.logger.log_metrics(
                    policy_eval_metrics, self.pretrain_steps, prefix="pretrain_eval"
                )

            # NOTE(review): the loop condition reads `self.pretrain_steps`
            # while the counter advanced here is `self._pretrain_step`;
            # presumably the former is a property over the latter -- confirm.
            if should_save_snapshot(self._pretrain_step):
                self.save_snapshot()

            self._pretrain_step += 1

    def _get_common_metrics(self) -> dict[str, Any]:
        """Collect the bookkeeping values attached to every logged metric set.

        Resets ``self._timer`` and reports the second element of its result as
        ``total_time``, together with the env step/episode counters and the
        current replay-buffer size.
        """
        elapsed = self._timer.reset()[1]
        return {
            "total_time": elapsed,
            "env_steps": self.global_env_steps,
            "env_episodes": self.global_env_episodes,
            "buffer_size": len(self.replay_buffer),
        }

    def shutdown(self):
        """Close the evaluation envs that are set, then shut down the replay buffer."""
        for env_attr in ("dynamics_eval_env", "agent_eval_env"):
            env = getattr(self, env_attr)
            # Falsy (e.g. None) means the env is not set.
            if env:
                env.close()

        self.replay_buffer.shutdown()

    def save_snapshot(self, best_ckpt=False):
        """Serialize the training state and dynamics weights to ``<work_dir>/snapshots``.

        The file is named ``best_snapshot.pt`` when ``best_ckpt`` is true and
        ``<global_env_steps>_snapshot.pt`` otherwise. In both cases the file is
        then copied to ``latest_snapshot.pt``.
        """
        snapshot_dir = self.work_dir / "snapshots"
        file_name = "best_snapshot.pt" if best_ckpt else f"{self.global_env_steps}_snapshot.pt"
        target = snapshot_dir / file_name
        target.parent.mkdir(parents=True, exist_ok=True)

        # Instance attributes persisted next to the model weights
        # (obs_stats / act_stats are intentionally not saved).
        saved_attrs = (
            "_pretrain_step",
            "_main_loop_iterations",
            "_global_env_episode",
            "cfg",
        )
        instance_state = vars(self)
        payload = {attr: instance_state[attr] for attr in saved_attrs}
        payload["dynamics"] = self.dynamics.state_dict()

        with target.open("wb") as handle:
            torch.save(payload, handle)
        shutil.copy(target, snapshot_dir / "latest_snapshot.pt")

    def load_snapshot(self, path_to_snapshot_to_load=None):
        """Restore dynamics weights and saved attributes from a snapshot file.

        Args:
            path_to_snapshot_to_load: Snapshot path; defaults to
                ``<work_dir>/snapshots/latest_snapshot.pt``.

        Raises:
            ValueError: If the resolved path is not an existing file.
        """
        snapshot_path = (
            self.work_dir / "snapshots" / "latest_snapshot.pt"
            if path_to_snapshot_to_load is None
            else Path(path_to_snapshot_to_load)
        )
        if not snapshot_path.is_file():
            raise ValueError(
                f"Provided file '{str(snapshot_path)}' is not a snapshot."
            )

        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # snapshots from trusted sources.
        with snapshot_path.open("rb") as handle:
            payload = torch.load(handle, map_location="cpu", weights_only=False)

        self.dynamics.load_state_dict(payload.pop("dynamics"))
        # Every remaining entry is restored as an instance attribute.
        self.__dict__.update(payload)
        # Observation/action statistics are not stored in the snapshot, so the
        # demos are not re-processed after loading.

    def load_agent_snapshot(self, path_to_agent):
        """Load a trained policy snapshot and build the agent plus its eval env.

        Sets ``self.agent_cfg``, ``self.agent_eval_env`` and ``self.agent``.

        Args:
            path_to_agent: Snapshot path as a string (it is searched for the
                substring ``"pi05"`` and split on ``'/'``).
        """
        pi05 = ("pi05" in path_to_agent)
        if pi05:
            # For pi05 snapshots the second '/'-separated path component is
            # replaced with 'pixel_act' before loading.
            # NOTE(review): this assumes a fixed directory layout -- confirm.
            dual_act_path_agent = path_to_agent.split('/')
            dual_act_path_agent[1] = 'pixel_act'
            dual_act_path_agent = '/'.join(dual_act_path_agent)
            path_to_agent = dual_act_path_agent
        path_to_agent = Path(path_to_agent)
        
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # snapshots from trusted sources.
        with path_to_agent.open("rb") as f:
            payload = torch.load(f, map_location="cpu", weights_only=False)
        # The snapshot carries the config the agent was trained with.
        self.agent_cfg = payload.pop("cfg")
        if pi05:
            # Override the class hydra instantiates for pi05 snapshots.
            self.agent_cfg.method['_target_'] = 'robobase.method.pi05.Pi05BCAgent'

        # The eval env is created from the agent's own config and supplies the
        # observation/action spaces used to construct the agent.
        self.agent_eval_env = self.env_factory.make_eval_env(self.agent_cfg, self.work_dir)
        self.agent = hydra.utils.instantiate(
            self.agent_cfg.method,
            current_task=self.cfg.env.task_name,
            device=self.device,
            observation_space=self.agent_eval_env.observation_space,
            action_space=self.agent_eval_env.action_space,
            num_train_envs=self.agent_cfg.num_train_envs,
            replay_alpha=self.agent_cfg.replay.alpha,
            replay_beta=self.agent_cfg.replay.beta,
            frame_stack_on_channel=self.agent_cfg.frame_stack_on_channel,
        )
        self.agent.load_state_dict(payload.pop("agent"))
        if self.agent_cfg.load_ema:
            # Also restore the actor's EMA state when the config asks for it.
            print("Load ema...")
            self.agent.actor.ema.load_state_dict(payload.pop("ema"))
        
