"""
Dynamics Workspace - World Model 训练工作空间
仿照 robobase/robobase/dynamics_workspace.py 实现
"""

import shutil
import signal
import sys
import os
import json
import time
import copy
import random
import logging
from typing import Callable, Any, Optional, Dict, List, Tuple
from functools import partial
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict

from config import Config
from models.world_model import WorldModel
from dataprocess import create_loader, analyze_episodes, split_train_val
from utils import (
    set_seed,
    Timer,
    Every,
    Until,
    Logger,
    create_output_dir,
    save_config,
    put_text,
    to_torch,
    to_numpy,
    merge_delta_actions,
    normalize_action_controller,
    unnormalize_action_controller,
    convert_delta_to_absolute_actions,
    compute_eef_error,
    compute_eef_error_ratio,
    interpolate_pos_quat,
    extract_eef_from_obs,
    quat2axisangle,
    axisangle2quat,

    apply_patches,
    revert_patches
)


class DynamicsWorkspace:
    """World Model 训练工作空间"""
    
    def __init__(
        self,
        cfg: Config,
        work_dir: Optional[str] = None,
        train: bool = True
    ):
        """Build the workspace: seed, pick a device, set up output/logging,
        optionally load data, and create the world model.

        Args:
            cfg: experiment configuration. When ``train`` is True, ``load_data``
                writes data-derived dimensions back into ``cfg.model``.
            work_dir: output directory. When None, one is created with
                ``create_output_dir(cfg.log.output_dir, cfg.log.exp_name)``;
                otherwise the given path is created if missing.
            train: when True, load episodes and build the train/val split (and
                normalization statistics). When False no data is loaded, so
                ``train_episodes`` / ``val_episodes`` and the statistics
                attributes are not set by this constructor.
        """
        self.cfg = cfg
        
        # Seed RNGs for reproducibility.
        set_seed(cfg.training.seed)
        
        # Device selection: any non-"cpu" request becomes "mps" on macOS;
        # elsewhere the configured device string is used as-is.
        # NOTE(review): device availability (MPS/CUDA) is not checked here.
        dev = "cpu"
        if cfg.training.device != "cpu":
            if sys.platform == "darwin":
                dev = "mps"
            else:
                dev = cfg.training.device
        self.device = torch.device(dev)
        print(f"使用设备: {self.device}")
        
        # Create the working directory.
        if work_dir is None:
            self.work_dir = create_output_dir(
                cfg.log.output_dir,
                cfg.log.exp_name
            )
        else:
            self.work_dir = Path(work_dir)
            self.work_dir.mkdir(parents=True, exist_ok=True)
        print(f"工作目录: {self.work_dir}")
        
        # Create the metrics logger (wandb / tensorboard backends per config).
        self.logger = Logger(
            log_dir=self.work_dir,
            use_wandb=cfg.log.use_wandb,
            use_tensorboard=cfg.log.use_tensorboard
        )
        
        # Initialise wandb when enabled (imported lazily so it is optional).
        if cfg.log.use_wandb:
            import wandb
            wandb.init(
                project=cfg.log.wandb_project,
                entity=cfg.log.wandb_entity,
                name=cfg.log.wandb_name or cfg.log.exp_name,
                mode=cfg.log.wandb_mode,
                config=cfg.to_dict()
            )
        
        # Load data. This must happen before create_model(), because
        # load_data() overwrites the model dimensions in cfg.model.
        if train:
            print("\n加载数据...")
            self.load_data()
            # Build the train/val split directly from the episode list.
            print("\n准备训练数据...")
            self.create_data_sampler()
        
        # Create the world model.
        print("\n创建模型...")
        self.create_model()
        
        # Wall-clock timer used for logging throughput and total time.
        self._timer = Timer()
        
        # Training-progress counters (persisted by save_snapshot).
        self._pretrain_step = 0
        self._main_loop_iterations = 0
        self._global_env_episode = 0
        
        # Best metrics seen so far (used for best-checkpoint selection).
        self.best_metrics = {
            "best_val_loss": float('inf'),
        }
        
        # Shutdown flag. NOTE(review): not read anywhere in the code shown here.
        self._shutting_down = False
        # Persist the (possibly data-updated) configuration next to the outputs.
        save_config(cfg.to_dict(), self.work_dir)

    def load_data(self):
        """Load demonstration episodes and align ``self.cfg.model`` with the data.

        A loader is created via ``create_loader`` for ``cfg.data.env_name`` and
        ``cfg.data.num_demos`` episodes are read. The statistics returned by
        ``analyze_episodes`` then overwrite ``cfg.model.obs_dim`` (0 when the data
        has no low-dim observations, i.e. RGB only) and ``cfg.model.action_dim``.
        When pixels are enabled and the data contains RGB, the camera count,
        channel count and image size are taken from ``rgb_shape``, which is
        expected to be ``(num_cameras, C, H, W)``.

        The episodes are kept on ``self.episodes``. For LIBERO data, per-task
        episode counts are printed.
        """
        data_cfg = self.cfg.data
        model_cfg = self.cfg.model

        loader = create_loader(
            env_name=data_cfg.env_name,
            data_path=data_cfg.data_dir,
            **data_cfg.loader_kwargs
        )
        episodes = loader.load_episodes(num_episodes=data_cfg.num_demos)
        data_stats = analyze_episodes(episodes)

        # Data-derived dimensions always take precedence over the config values.
        model_cfg.obs_dim = data_stats['obs_dim'] if data_stats.get('has_obs') else 0
        model_cfg.action_dim = data_stats['action_dim']

        if model_cfg.use_pixels and data_stats.get('has_rgb'):
            rgb_shape = data_stats['rgb_shape']  # (num_cameras, C, H, W)
            model_cfg.num_cameras = rgb_shape[0]
            model_cfg.image_channels = rgb_shape[1]
            model_cfg.image_size = (rgb_shape[2], rgb_shape[3])

        stacked = model_cfg.framestack > 1
        print(f"\n数据维度:")
        print(f"  低维观测维度: {model_cfg.obs_dim}")
        if stacked:
            print(f"  Framestack: {model_cfg.framestack} (实际输入维度: {model_cfg.obs_dim * model_cfg.framestack})")
        print(f"  动作维度: {model_cfg.action_dim}")
        if model_cfg.use_pixels:
            print(f"  使用 RGB: True")
            print(f"  相机数量: {model_cfg.num_cameras}")
            print(f"  图像尺寸: {model_cfg.image_size}")
            if stacked:
                print(f"  Framestack: {model_cfg.framestack} (实际输入通道数: {model_cfg.image_channels * model_cfg.framestack})")

        self.episodes = episodes

        # Multi-task (LIBERO) data: report how many episodes each task contributes.
        if data_cfg.env_name != "libero" or 'info' not in episodes[0]:
            return
        task_counts = defaultdict(int)
        for ep in episodes:
            if 'info' in ep and 'task' in ep['info']:
                task_counts[ep['info']['task']] += 1
        if not task_counts:
            return
        print(f"\n任务分布:")
        for task, count in sorted(task_counts.items()):
            print(f"  {task}: {count} episodes")
    
    def create_model(self):
        """Instantiate the ``WorldModel`` described by ``self.cfg`` and leave it in eval mode.

        When pixels are enabled, the model is told to expect
        ``image_channels * framestack`` input channels, because stacked frames
        are concatenated along the channel axis. Optional DINOv2 settings fall
        back to defaults when the config does not define them. Optimisation
        hyper-parameters (learning rate, weight decay, grad clip) are forwarded
        to the model, which is created on ``self.device``.
        """
        model_cfg = self.cfg.model
        train_cfg = self.cfg.training

        # Non-positive values mean "no low-dim observations".
        obs_dim = model_cfg.obs_dim if model_cfg.obs_dim > 0 else 0
        if model_cfg.use_pixels:
            input_channels = model_cfg.image_channels * model_cfg.framestack
        else:
            input_channels = model_cfg.image_channels

        # Optional DINOv2 settings; older configs may not define them.
        use_dinov2 = getattr(model_cfg, 'use_dinov2', False)
        dinov2_model_type = getattr(model_cfg, 'dinov2_model_type', 'dinov2_vits14')

        self.dynamics = WorldModel(
            obs_dim=obs_dim,
            action_dim=model_cfg.action_dim,
            use_pixels=model_cfg.use_pixels,
            image_channels=input_channels,
            image_size=model_cfg.image_size,
            num_cameras=model_cfg.num_cameras,
            use_dinov2=use_dinov2,
            dinov2_model_type=dinov2_model_type,
            dinov2_visual_feature_dim=getattr(model_cfg, 'dinov2_visual_feature_dim', 64),
            dinov2_mlp_hidden_dims=getattr(model_cfg, 'dinov2_mlp_hidden_dims', [256, 64]),
            dinov2_use_cls_token=getattr(model_cfg, 'dinov2_use_cls_token', True),
            dinov2_dropout=getattr(model_cfg, 'dinov2_dropout', 0.0),
            hidden_dim=model_cfg.hidden_dim,
            rnn_num_layers=model_cfg.rnn_num_layers,
            dropout=model_cfg.dropout,
            learning_rate=train_cfg.learning_rate,
            weight_decay=train_cfg.weight_decay,
            grad_clip=train_cfg.grad_clip,
            use_symlog=model_cfg.use_symlog,
            use_var=model_cfg.use_var,
            use_residual=model_cfg.use_residual,
            framestack=model_cfg.framestack,
            device=str(self.device)
        )
        # Eval mode is the workspace's resting state; updates switch to train mode temporarily.
        self.dynamics.train(False)

        num_params = sum(p.numel() for p in self.dynamics.parameters())
        print(f"\n模型参数量: {num_params:,}")
        if model_cfg.framestack > 1:
            print(f"Framestack: {model_cfg.framestack} (实际输入维度: obs_dim={obs_dim}, image_channels={input_channels})")
        if model_cfg.use_pixels and use_dinov2:
            print(f"使用 DINOv2 编码器: {dinov2_model_type}")
    
    def create_data_sampler(self):
        """Split ``self.episodes`` into train/validation sets and prepare normalization stats.

        The shuffled split (``cfg.data.train_ratio``) is stored on
        ``self.train_episodes`` and ``self.val_episodes``. When
        ``cfg.data.normalize_obs`` is set, statistics are computed from the
        training split only. Otherwise every statistics attribute is set to None.
        Finally, the split sizes are printed together with an estimate of how many
        ``seq_len`` windows each split offers.
        """
        self.train_episodes, self.val_episodes = split_train_val(
            self.episodes,
            train_ratio=self.cfg.data.train_ratio,
            shuffle=True
        )

        if not self.cfg.data.normalize_obs:
            self.obs_mean = self.obs_std = None
            self.action_mean = self.action_std = None
        else:
            # Training split only, so validation data never leaks into the statistics.
            self._compute_statistics(self.train_episodes)

        print(f"\n数据准备完成:")
        print(f"  训练 episodes 数: {len(self.train_episodes)}")
        print(f"  验证 episodes 数: {len(self.val_episodes)}")

        def count_windows(eps):
            # Number of valid start positions for a seq_len window, summed over episodes.
            window = self.cfg.training.seq_len
            return sum(max(0, len(ep['actions']) - window + 1) for ep in eps)

        train_windows = count_windows(self.train_episodes)
        val_windows = count_windows(self.val_episodes)
        print(f"  训练序列数（估计）: {train_windows}")
        print(f"  验证序列数（估计）: {val_windows}")
    
    def _compute_statistics(self, episodes: List[Dict[str, np.ndarray]]):
        """计算标准化统计信息"""
        all_obs = []
        all_actions = []
        
        for episode in episodes:
            if 'obs' in episode and self.cfg.model.obs_dim > 0:
                all_obs.append(episode['obs'][:-1])  # 排除最后一个状态
            all_actions.append(episode['actions'])
        
        if all_obs:
            all_obs_concat = np.concatenate(all_obs, axis=0)
            self.obs_mean = np.mean(all_obs_concat, axis=0)
            self.obs_std = np.std(all_obs_concat, axis=0) + 1e-8
        else:
            self.obs_mean = None
            self.obs_std = None
        
        if all_actions:
            all_actions_concat = np.concatenate(all_actions, axis=0)
            self.action_mean = np.mean(all_actions_concat, axis=0)
            self.action_std = np.std(all_actions_concat, axis=0) + 1e-8
        
        if self.cfg.data.normalize_obs:
            print("\n数据统计信息:")
            if self.obs_mean is not None:
                print(f"  观测均值: {self.obs_mean[:5]}... (显示前5维)")
                print(f"  观测标准差: {self.obs_std[:5]}...")
            print(f"  动作均值: {self.action_mean[:5]}...")
            print(f"  动作标准差: {self.action_std[:5]}...")
    
    def _apply_framestack(self, obs_sequence: np.ndarray, framestack: int, first_frame: Optional[np.ndarray] = None) -> np.ndarray:
        """对观测序列应用 framestack
        
        Args:
            obs_sequence: [T, ...] 观测序列
            framestack: 堆叠的帧数
            first_frame: 第一帧观测（用于补全历史，如果为 None 则使用 obs_sequence[0]）
        
        Returns:
            stacked_obs: [T, ...*framestack] 堆叠后的观测
        """
        if framestack == 1:
            return obs_sequence
        
        T = len(obs_sequence)
        if T == 0:
            return obs_sequence
        
        # 如果没有提供 first_frame，使用序列的第一帧
        if first_frame is None:
            first_frame = obs_sequence[0]
        
        stacked_obs = []
        
        for t in range(T):
            # 获取前 framestack 帧（从当前帧往前回溯）
            # 如果历史帧不足，用第一帧补全
            frames = []
            for i in range(framestack):
                # 回溯索引：当前帧是 t，往前找 framestack-1-i 帧
                # 例如 framestack=3, t=0: 需要 [t-2, t-1, t] = [-2, -1, 0]，用第一帧补全
                # 例如 framestack=3, t=1: 需要 [t-2, t-1, t] = [-1, 0, 1]，用第一帧补全第一个
                idx = t - (framestack - 1 - i)
                if idx < 0:
                    # 历史帧不足，用第一帧补全
                    frames.append(first_frame)
                else:
                    frames.append(obs_sequence[idx])
            
            # 堆叠：对于低维观测在特征维度堆叠，对于 RGB 在通道维度堆叠
            if len(obs_sequence.shape) == 2:
                # 低维观测：[obs_dim] -> [obs_dim * framestack]
                stacked = np.concatenate(frames, axis=-1)
            elif len(obs_sequence.shape) == 5:
                # RGB 观测：[num_cameras, C, H, W] -> [num_cameras, C * framestack, H, W]
                # 在通道维度（axis=2）堆叠
                stacked = np.concatenate(frames, axis=2)  # axis=2 是通道维度
            else:
                # 其他情况，尝试在最后一个维度堆叠
                stacked = np.concatenate(frames, axis=-1)
            
            stacked_obs.append(stacked)
        
        return np.array(stacked_obs)
    
    def _sample_batch(self, episodes: List[Dict[str, np.ndarray]], batch_size: int) -> Dict[str, torch.Tensor]:
        """从 episodes 中随机采样一个 batch"""
        seq_len = self.cfg.training.seq_len
        framestack = self.cfg.model.framestack
        
        # 使用列表推导式随机选择 episodes 和起始位置
        selected = [
            (ep, random.randint(0, len(ep['actions']) - seq_len))
            for ep in [random.choice(episodes) for _ in range(batch_size)]
            if len(ep['actions']) >= seq_len
        ]
        
        # 如果过滤后数量不足，补充采样
        while len(selected) < batch_size:
            ep = random.choice(episodes)
            if len(ep['actions']) >= seq_len:
                selected.append((ep, random.randint(0, len(ep['actions']) - seq_len)))
        
        # 使用列表推导式提取序列（不标准化）
        has_obs = 'obs' in episodes[0] and self.cfg.model.obs_dim > 0
        has_rgb = self.cfg.model.use_pixels and 'rgb_obs' in episodes[0]
        
        # 提取序列（包含真实历史帧）
        batch_obs = []
        batch_next_obs = []
        batch_rgb_obs = []
        batch_next_rgb_obs = []
        
        for ep, start in selected:
            if has_obs:
                # 计算需要的起始索引（包含 framestack-1 个历史帧）
                history_start = max(0, start - (framestack - 1))
                # 提取 obs 序列（包含历史）
                obs_seq_with_history = ep['obs'][history_start:start + seq_len + 1]  # [?, obs_dim]
                
                # 如果历史不足，用第一帧补全
                if history_start == 0 and start > 0:
                    # 需要补全
                    num_missing = (framestack - 1) - start
                    if num_missing > 0:
                        first_frame = ep['obs'][0:1]  # [1, obs_dim]
                        padding = np.repeat(first_frame, num_missing, axis=0)  # [num_missing, obs_dim]
                        obs_seq_with_history = np.concatenate([padding, obs_seq_with_history], axis=0)
                
                # 应用 framestack（现在有真实历史数据）
                obs_stacked = self._apply_framestack(
                    obs_seq_with_history[:len(obs_seq_with_history)-1], 
                    framestack, 
                    first_frame=None  # 不需要补全，已经有完整历史
                )[-seq_len:]  # 只取最后 seq_len 个
                next_obs_stacked = obs_seq_with_history[-seq_len:]  # [seq_len, obs_dim]
                batch_obs.append(obs_stacked)
                batch_next_obs.append(next_obs_stacked)
            
            if has_rgb:
                # 计算需要的起始索引（包含 framestack-1 个历史帧）
                history_start = max(0, start - (framestack - 1))
                # 提取 RGB 序列（包含历史）
                rgb_seq_with_history = ep['rgb_obs'][history_start:start + seq_len + 1]  # [?, V, C, H, W]
                
                # 如果历史不足，用第一帧补全
                if history_start == 0 and start > 0:
                    num_missing = (framestack - 1) - start
                    if num_missing > 0:
                        first_frame = ep['rgb_obs'][0:1]  # [1, V, C, H, W]
                        padding = np.repeat(first_frame, num_missing, axis=0)  # [num_missing, V, C, H, W]
                        rgb_seq_with_history = np.concatenate([padding, rgb_seq_with_history], axis=0)
                
                # 应用 framestack（现在有真实历史数据）
                rgb_stacked = self._apply_framestack(
                    rgb_seq_with_history[:len(rgb_seq_with_history)-1], 
                    framestack,
                    first_frame=None  # 不需要补全，已经有完整历史
                )[-seq_len:]  # 只取最后 seq_len 个
                next_rgb_stacked = rgb_seq_with_history[-seq_len:]  # [seq_len, V, C, H, W]
                batch_rgb_obs.append(rgb_stacked)
                batch_next_rgb_obs.append(next_rgb_stacked)
        
        batch_actions = [ep['actions'][start:start+seq_len] for ep, start in selected]
        
        # 堆叠并转换为 numpy array
        batch = {}
        if has_obs and batch_obs:
            batch_obs_array = np.stack(batch_obs)  # [batch_size, seq_len, obs_dim*framestack]
            batch_next_obs_array = np.stack(batch_next_obs)  # [batch_size, seq_len, obs_dim*framestack]
            
            # 批量标准化（整个 batch 统一标准化，速度更快）
            if self.cfg.data.normalize_obs and self.obs_mean is not None:
                # 对于堆叠的观测，需要对每个时间步的每个帧分别标准化
                # 重塑为 [batch_size, seq_len, framestack, obs_dim] 进行标准化，然后重新堆叠
                original_obs_dim = self.cfg.model.obs_dim
                obs_reshaped = batch_obs_array.reshape(batch_size, seq_len, framestack, original_obs_dim)
                next_obs_reshaped = batch_next_obs_array.reshape(batch_size, seq_len, original_obs_dim)
                obs_normalized = (obs_reshaped - self.obs_mean) / self.obs_std
                next_obs_normalized = (next_obs_reshaped - self.obs_mean) / self.obs_std
                batch_obs_array = obs_normalized.reshape(batch_size, seq_len, -1)
                batch_next_obs_array = next_obs_normalized.reshape(batch_size, seq_len, -1)
            
            batch['obs'] = torch.from_numpy(batch_obs_array).float()
            batch['next_obs'] = torch.from_numpy(batch_next_obs_array).float()
        
        batch_actions_array = np.stack(batch_actions)  # [batch_size, seq_len, action_dim]
        
        # 批量标准化动作（整个 batch 统一标准化，速度更快）
        if self.cfg.data.normalize_obs and self.action_mean is not None:
            batch_actions_array = (batch_actions_array - self.action_mean) / self.action_std
        
        batch['actions'] = torch.from_numpy(batch_actions_array).float()
        
        if has_rgb and batch_rgb_obs:
            batch['rgb_obs'] = torch.from_numpy(np.stack(batch_rgb_obs)).float()
            batch['next_rgb_obs'] = torch.from_numpy(np.stack(batch_next_rgb_obs)).float()
        
        return batch
    
    @property
    def pretrain_steps(self):
        return self._pretrain_step
    
    @property
    def main_loop_iterations(self):
        return self._main_loop_iterations
    
    @property
    def global_env_episodes(self):
        return self._global_env_episode
    
    @property
    def global_env_steps(self):
        """总的环境步数"""
        return self.pretrain_steps
    
    
 
    
    def train(self):
        """Run the world-model training loop.

        Each iteration samples a batch from ``self.train_episodes``, performs one
        update and then advances ``self._pretrain_step``. Logging, evaluation and
        checkpointing follow ``Every`` schedules built from ``cfg.training``. A
        separate "best" snapshot is written whenever ``val_loss`` improves.
        After the loop, a final evaluation is logged, a final snapshot is saved
        and the logger is closed.
        """
        print("\n" + "="*60)
        print("开始训练 World Model")
        print("="*60 + "\n")

        # Schedules controlling the loop and its periodic side tasks.
        should_train = Until(self.cfg.training.num_train_steps)
        should_log = Every(self.cfg.training.log_every_steps)
        should_eval = Every(self.cfg.training.eval_every_steps)
        should_save = Every(self.cfg.training.save_every_steps)

        while should_train(self.pretrain_steps):
            batch = self._sample_batch(self.train_episodes, self.cfg.training.batch_size)
            metrics = self._perform_updates(batch)

            # Log only on logging steps. log_metrics used to sit outside this
            # branch, so every step was logged and should_log was ignored.
            if should_log(self.pretrain_steps):
                elapsed, total_time = self._timer.reset()
                metrics['elapsed_time'] = elapsed
                metrics['total_time'] = total_time
                # Guard against a zero interval from a coarse clock.
                metrics['steps_per_sec'] = (
                    self.cfg.training.log_every_steps / elapsed if elapsed > 0 else 0.0
                )
                metrics.update(self._get_common_metrics())
                self.logger.log_metrics(
                    metrics, self.pretrain_steps, prefix="pretrain"
                )

            if should_eval(self.pretrain_steps):
                print(f"\n评估 (步骤 {self.pretrain_steps})...")
                eval_metrics = self._eval()
                eval_metrics.update(self._get_common_metrics())
                self.logger.log_metrics(
                    eval_metrics, self.pretrain_steps, prefix="pretrain_eval"
                )

                # Keep a separate snapshot of the best model by validation loss.
                if eval_metrics['val_loss'] < self.best_metrics['best_val_loss']:
                    self.best_metrics['best_val_loss'] = eval_metrics['val_loss']
                    self.save_snapshot(best_ckpt=True)
                    print(f"保存最佳模型 (val_loss: {self.best_metrics['best_val_loss']:.4f})")

            if should_save(self.pretrain_steps):
                self.save_snapshot()

            self._pretrain_step += 1

        print("\n" + "="*60)
        print("训练完成!")
        print("="*60)

        print("\n最终评估...")
        final_metrics = self._eval()
        final_metrics.update(self._get_common_metrics())
        self.logger.log_metrics(
            final_metrics, self.pretrain_steps, prefix="final"
        )

        self.save_snapshot()

        self.logger.close()
    
    def _perform_updates(self, batch) -> Dict[str, Any]:
        """执行一次训练更新"""
        metrics = {}
        self.dynamics.train(True)
        
        # 提取数据
        low_dim_obs = batch['obs'].to(self.device) if 'obs' in batch else None
        next_low_dim_obs = batch['next_obs'].to(self.device) if 'next_obs' in batch else None
        actions = batch['actions'].to(self.device)
        
        # RGB 观测
        rgb_obs = batch['rgb_obs'].to(self.device) if 'rgb_obs' in batch else None
        next_rgb_obs = batch['next_rgb_obs'].to(self.device) if 'next_rgb_obs' in batch else None
        
        # 更新模型
        update_metrics = self.dynamics.update(
            low_dim_obs=low_dim_obs,
            rgb_obs=rgb_obs,
            next_low_dim_obs=next_low_dim_obs,
            next_rgb_obs=next_rgb_obs,
            actions=actions
        )
        
        metrics.update(update_metrics)
        self.dynamics.train(False)
        
        return metrics
    
    @torch.no_grad()
    def _eval(self):
        """Evaluate the world model on randomly sampled validation batches.

        Runs ``max(10, len(val_episodes) // batch_size)`` batches through
        ``self.dynamics.compute_loss`` and averages ``total_loss``, ``raw_mse``
        and every metric whose key contains ``raw_mse_step``. Metrics from
        ``_eval_transitions_if_available`` are merged in when available.

        The model's previous train/eval mode is restored on exit, including on
        error. Previously the method always ended in training mode, unlike the
        eval-mode default used elsewhere in this workspace.

        Returns:
            Dict with ``val_loss``, ``val_mse``, the per-step MSE keys and any
            transition metrics.
        """
        # nn.Module exposes `training`; fall back to eval mode (the workspace default).
        was_training = getattr(self.dynamics, 'training', False)
        self.dynamics.eval()
        try:
            batch_size = self.cfg.training.batch_size
            num_batches = max(10, len(self.val_episodes) // batch_size)

            total_loss = 0.0
            total_mse = 0.0
            total_mse_perstep = defaultdict(float)

            for _ in range(num_batches):
                batch = self._sample_batch(self.val_episodes, batch_size)

                _, metrics = self.dynamics.compute_loss(
                    low_dim_obs=batch['obs'].to(self.device) if 'obs' in batch else None,
                    rgb_obs=batch['rgb_obs'].to(self.device) if 'rgb_obs' in batch else None,
                    next_low_dim_obs=batch['next_obs'].to(self.device) if 'next_obs' in batch else None,
                    next_rgb_obs=batch['next_rgb_obs'].to(self.device) if 'next_rgb_obs' in batch else None,
                    actions=batch['actions'].to(self.device),
                )
                for k, v in metrics.items():
                    if 'raw_mse_step' in k:
                        total_mse_perstep[k] += v
                total_loss += metrics['total_loss']
                total_mse += metrics['raw_mse']

            eval_metrics = {
                'val_loss': total_loss / num_batches,
                'val_mse': total_mse / num_batches,
            }
            for k, v in total_mse_perstep.items():
                eval_metrics[k] = v / num_batches

            # Optional evaluation on pre-extracted transition data.
            transition_metrics = self._eval_transitions_if_available()
            if transition_metrics:
                eval_metrics.update(transition_metrics)
        finally:
            self.dynamics.train(was_training)

        return eval_metrics
    
    def _eval_transitions_if_available(self) -> Dict[str, float]:
        """如果 transition 数据存在，则进行评估
        
        Returns:
            评估指标字典，如果数据不存在则返回空字典
        """
        # 从 work_dir 推断 dataset_name
        # work_dir 通常是 exp_output/{dataset_name}/world_model
        work_dir_parts = Path(self.work_dir).parts
        dataset_name = None
        
        # 尝试从路径中提取 dataset_name
        if 'exp_output' in work_dir_parts:
            exp_output_idx = work_dir_parts.index('exp_output')
            if exp_output_idx + 1 < len(work_dir_parts):
                dataset_name = work_dir_parts[exp_output_idx + 1]
        
        if dataset_name is None:
            return {}
        
        # 检查 transition 数据是否存在
        transitions_dir = Path("exp_output") / dataset_name / "transitions"
        if not transitions_dir.exists():
            return {}
        
        # 检查是否有 transition 文件
        transition_files = list(transitions_dir.glob("transitions_ds*.pkl"))
        if not transition_files:
            return {}
        
        # 提取所有降采样率
        downsample_rates = []
        for file in transition_files:
            # 从文件名提取 ds 值，例如 "transitions_ds1.pkl" -> 1
            try:
                ds_str = file.stem.replace("transitions_ds", "")
                ds = int(ds_str)
                downsample_rates.append(ds)
            except ValueError:
                continue
        
        if not downsample_rates:
            return {}
        
        # 评估 transition 数据（静默模式，不打印详细信息）
        transition_results = self.eval_world_model_on_transitions(
            dataset_name=dataset_name,
            downsample_rates=sorted(downsample_rates),
            verbose=False,  # 训练评估时使用静默模式
        )
        
        # 将结果转换为评估指标格式
        metrics = {}
        for ds_key, ds_results in transition_results.items():
            # 添加每个降采样率的 MAE 指标
            metrics[f'transition_{ds_key}_mae'] = ds_results['avg_mae']
            metrics[f'transition_{ds_key}_num_sequences'] = ds_results['num_transitions']
        
        return metrics

    
    def _get_common_metrics(self) -> Dict[str, Any]:
        """获取通用指标"""
        _, total_time = self._timer.reset()
        train_seq_count = sum(max(0, len(ep['actions']) - self.cfg.training.seq_len + 1) 
                              for ep in self.train_episodes) if hasattr(self, 'train_episodes') else 0
        metrics = {
            "total_time": total_time,
            "env_steps": self.global_env_steps,
            "env_episodes": self.global_env_episodes,
            "train_episodes": len(self.train_episodes) if hasattr(self, 'train_episodes') else 0,
            "train_seq_count": train_seq_count,
        }
        return metrics
    
    def shutdown(self):
        """关闭工作空间"""
        if hasattr(self, 'logger'):
            self.logger.close()
    
    def save_snapshot(self, best_ckpt: bool = False):
        """保存快照"""
        snapshot_dir = self.work_dir / "snapshots"
        snapshot_dir.mkdir(parents=True, exist_ok=True)
        
        if best_ckpt:
            snapshot = snapshot_dir / "best_snapshot.pt"
        else:
            snapshot = snapshot_dir / f"{self.global_env_steps}_snapshot.pt"
        
        keys_to_save = [
            "_pretrain_step",
            "_main_loop_iterations",
            "_global_env_episode",
        ]
        payload = {k: self.__dict__[k] for k in keys_to_save}
        payload["dynamics"] = self.dynamics.state_dict()
        payload["cfg"] = self.cfg.to_dict()
        
        # 保存统计信息
        if self.cfg.data.normalize_obs:
            stats = {
                'obs_mean': self.obs_mean,
                'obs_std': self.obs_std,
                'action_mean': self.action_mean,
                'action_std': self.action_std
            }
            payload["data_statistics"] = stats
        
        with snapshot.open("wb") as f:
            torch.save(payload, f)
        
        # 同时保存为 latest
        latest_snapshot = snapshot_dir / "latest_snapshot.pt"
        shutil.copy(snapshot, latest_snapshot)
        
        print(f"快照已保存至: {snapshot}")
    
    def load_snapshot(self, path_to_snapshot_to_load: Optional[str] = None):
        """Restore model weights, normalization statistics, counters and config from a snapshot.

        Args:
            path_to_snapshot_to_load: snapshot file to read. Defaults to
                ``<work_dir>/snapshots/latest_snapshot.pt``.

        Raises:
            ValueError: if the path is not an existing file.
        """
        if path_to_snapshot_to_load is None:
            path_to_snapshot_to_load = (
                self.work_dir / "snapshots" / "latest_snapshot.pt"
            )
        else:
            path_to_snapshot_to_load = Path(path_to_snapshot_to_load)

        if not path_to_snapshot_to_load.is_file():
            raise ValueError(
                f"提供的文件 '{str(path_to_snapshot_to_load)}' 不是快照文件。"
            )

        # SECURITY: weights_only=False unpickles arbitrary objects (the payload
        # holds numpy arrays and a config dict). Only load trusted snapshot files.
        with path_to_snapshot_to_load.open("rb") as f:
            payload = torch.load(f, map_location="cpu", weights_only=False)

        self.dynamics.load_state_dict(payload.pop("dynamics"))

        if "data_statistics" in payload:
            stats = payload.pop("data_statistics")
            self.obs_mean = stats.get('obs_mean')
            self.obs_std = stats.get('obs_std')
            self.action_mean = stats.get('action_mean')
            self.action_std = stats.get('action_std')
            print(f"已加载数据统计信息:")
            if self.obs_mean is not None:
                print(f"  观测均值: {self.obs_mean[:5]}... (显示前5维)")
                print(f"  观测标准差: {self.obs_std[:5]}...")
            if self.action_mean is not None:
                print(f"  动作均值: {self.action_mean[:5]}...")
                print(f"  动作标准差: {self.action_std[:5]}...")
        else:
            print("警告: 快照中未找到数据统计信息，将使用默认值（None）")
            # Snapshots saved with normalize_obs=False carry no statistics.
            if not hasattr(self, 'obs_mean'):
                self.obs_mean = None
                self.obs_std = None
                self.action_mean = None
                self.action_std = None

        # Parse the config before the generic restore below. Otherwise that loop
        # would briefly set self.cfg to a raw dict, and a failing from_dict
        # would leave it that way.
        restored_cfg = Config.from_dict(payload.pop("cfg"))

        # Restore remaining saved attributes (step counters etc.) that exist on self.
        for k, v in payload.items():
            if k in self.__dict__:
                self.__dict__[k] = v

        self.cfg = restored_cfg
        print(f"快照已从 {path_to_snapshot_to_load} 加载")
    
    # ==================== AdaDS 评估相关方法 ====================
    
    def _detect_task_phase(self, pos_interp: np.ndarray, quat_interp: np.ndarray,
                           gripper_interp: np.ndarray, manipulation_threshold: float,
                           movement_threshold: float, min_ds: int) -> float:
        """基于预测轨迹检测任务阶段：移动 vs 操作"""
        # 位置速度分析
        pos_diffs = np.diff(pos_interp, axis=0)
        pos_speeds = np.linalg.norm(pos_diffs, axis=1)
        avg_pos_speed = np.mean(pos_speeds)
        
        # 夹爪状态分析
        gripper_diffs = np.abs(np.diff(gripper_interp))
        gripper_change_total = np.sum(gripper_diffs)
        
        # 检查是否有停顿
        POS_SPEED_THRESHOLD = 0.005
        POS_SPEED_VERY_SLOW_RATIO = 0.1
        GRIPPER_CHANGE_THRESHOLD = 0.02
        
        very_slow_ratio = np.sum(pos_speeds < POS_SPEED_THRESHOLD) / len(pos_speeds)
        
        # 操作阶段的特征
        is_manipulation = (
            very_slow_ratio > POS_SPEED_VERY_SLOW_RATIO or
            gripper_change_total > GRIPPER_CHANGE_THRESHOLD
        )
        # print(f"avg_pos: {avg_pos_speed:.4f}, "
        #     f"slow_ratio: {very_slow_ratio:.2f}, "
        #     f"gripper_changes: {gripper_change_total:.4f}, is_manip: {is_manipulation}")
        if is_manipulation:
            return manipulation_threshold
        else:
            return movement_threshold
    
    def _compute_chunk_eef_interpolation_with_world_model(
        self, obs_history: List[np.ndarray], rgb_obs_history: Optional[List[np.ndarray]],
        action_chunk: np.ndarray, ds: int, min_ds: int = 1, minimum_decay_steps: int = 2
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Roll the world model over a downsampled action chunk and interpolate the EEF path.

        The chunk is cut into windows, and each window is merged into one action:

        1. windows of ``ds`` consecutive actions, stopping before the last
           ``minimum_decay_steps`` actions of the chunk;
        2. windows of ``min_ds`` actions for the remainder. Trailing actions
           that do not fill a whole ``min_ds`` window are merged into one final,
           shorter window, so every action of the chunk is represented.

        Merging runs ``normalize_action_controller`` -> ``merge_delta_actions``
        -> ``unnormalize_action_controller``. All merged actions are fed to
        ``self.dynamics.predict`` in a single call. The predicted EEF states at
        window boundaries are interpolated back onto every step ``0..T``.

        Args:
            obs_history: low-dim observations; the last entry is the current one.
                Indices used here: ``[0:3]`` position, ``[3:6]`` axis-angle,
                ``[6]`` gripper.
            rgb_obs_history: RGB observations (last entry is current) or ``None``.
                Frames whose last axis is 3 are transposed to channel-first.
            action_chunk: ``[T, action_dim]`` actions.
            ds: downsample rate used in phase 1.
            min_ds: downsample rate used in phase 2.
            minimum_decay_steps: number of trailing actions excluded from phase 1.

        Returns:
            pos_interp, quat_interp, gripper_interp: output of
                ``interpolate_pos_quat`` evaluated at steps ``0..T``.
            ds_actions: ``[num_windows, action_dim]`` merged, un-normalized actions.
        """
        self.dynamics.eval()
        framestack = self.cfg.model.framestack
        normalize = self.cfg.data.normalize_obs

        def stack_recent(frames, axis):
            # Concatenate the last `framestack` frames along `axis`; when the
            # history is shorter, left-pad it by repeating the oldest frame.
            if framestack <= 1:
                return frames[-1]
            num_pad = max(0, framestack - len(frames))
            window = [frames[0]] * num_pad + list(frames[-framestack:])
            return np.concatenate(window, axis=axis)

        def merge_window(batch):
            # Collapse consecutive delta actions into a single action.
            merged = merge_delta_actions(normalize_action_controller(batch))
            return unnormalize_action_controller(merged[None, :])[0]

        # Current (raw, un-normalized) observation supplies the initial EEF state.
        current_obs = np.asarray(obs_history[-1])

        # Low-dim model input, normalized and framestacked.
        if normalize and self.obs_mean is not None:
            obs_history_normalized = [(obs - self.obs_mean) / self.obs_std for obs in obs_history]
        else:
            obs_history_normalized = obs_history
        init_obs_stacked = stack_recent(obs_history_normalized, axis=-1)

        # Optional RGB model input, channel-first and framestacked on the channel axis.
        init_rgb_tensor = None
        if rgb_obs_history is not None and len(rgb_obs_history) > 0 and self.cfg.model.use_pixels:
            rgb_obs_list = []
            for rgb_obs in rgb_obs_history:
                if rgb_obs.shape[-1] == 3:
                    if rgb_obs.ndim == 3:
                        rgb_obs = rgb_obs.transpose(2, 0, 1)[None]  # [1, C, H, W]
                    else:
                        rgb_obs = rgb_obs.transpose(0, 3, 1, 2)  # [V, C, H, W]
                rgb_obs_list.append(rgb_obs)
            init_rgb_stacked = stack_recent(rgb_obs_list, axis=1)
            init_rgb_tensor = torch.from_numpy(init_rgb_stacked).float().unsqueeze(0).to(self.device)

        num_steps = len(action_chunk)
        phase1_end = num_steps - minimum_decay_steps
        ds_actions = []
        all_time_indices = [0]  # step index reached after each merged action
        covered = 0  # number of leading actions already merged

        # Phase 1: windows of `ds` actions.
        for t in range(ds - 1, phase1_end, ds):
            ds_actions.append(merge_window(action_chunk[t - ds + 1:t + 1]))
            covered = t + 1
            all_time_indices.append(covered)

        # Phase 2: windows of `min_ds` actions for the rest of the chunk.
        for i in range(covered + min_ds - 1, num_steps, min_ds):
            ds_actions.append(merge_window(action_chunk[i - min_ds + 1:i + 1]))
            covered = i + 1
            all_time_indices.append(covered)

        # Leftover actions (only possible when min_ds > 1): merge them as one
        # shorter window so the chunk is fully covered and the interpolation
        # does not have to extrapolate past the last predicted state.
        if covered < num_steps:
            ds_actions.append(merge_window(action_chunk[covered:]))
            all_time_indices.append(num_steps)

        # Phase 3: predict the whole merged action sequence in one call.
        ds_actions_array = np.array(ds_actions)  # [num_windows, action_dim]
        if normalize and self.action_mean is not None:
            ds_actions_array = (ds_actions_array - self.action_mean) / self.action_std
        action_sequence = torch.from_numpy(ds_actions_array).float().unsqueeze(0).to(self.device)
        obs_tensor = torch.from_numpy(init_obs_stacked[None, :]).float().to(self.device)

        with torch.no_grad():
            mean = self.dynamics.predict(
                low_dim_obs=obs_tensor,
                rgb_obs=init_rgb_tensor,
                actions=action_sequence
            )  # [1, num_windows, obs_dim]
        predicted_obs = mean[0].cpu().numpy()
        if normalize and self.obs_mean is not None:
            predicted_obs = predicted_obs * self.obs_std + self.obs_mean

        # Prepend the current state to the predicted states at window boundaries.
        init_pos, init_angle, init_gripper = current_obs[:3], current_obs[3:6], current_obs[6]
        position_actuals = np.concatenate((init_pos[None, :], predicted_obs[:, :3]), axis=0)
        angle_actuals = np.concatenate((init_angle[None, :], predicted_obs[:, 3:6]), axis=0)
        gripper_actuals = np.concatenate((init_gripper[None], predicted_obs[:, 6]), axis=0)
        quat_actuals = np.array([axisangle2quat(aa) for aa in angle_actuals])

        target_indices = np.arange(num_steps + 1)
        pos_interp, quat_interp, gripper_interp = interpolate_pos_quat(
            position_actuals, quat_actuals, gripper_actuals, all_time_indices, target_indices
        )

        return pos_interp, quat_interp, gripper_interp, np.array(ds_actions)
    
    def _select_downsample_rate(
        self, obs_history: List[np.ndarray], rgb_obs_history: Optional[List[np.ndarray]],
        action_chunk: np.ndarray, downsample_rates: List[int],
        manipulation_threshold: float, movement_threshold: float,
        threshold: Optional[float] = None,
        minimum_decay_steps: int = 2
    ) -> Tuple[int, np.ndarray]:
        """Pick the largest downsample rate whose predicted EEF path stays close to the baseline.

        The baseline is the world-model rollout at the smallest rate in
        ``downsample_rates``. Larger rates are tried from largest to smallest.
        The first rate whose ``compute_eef_error`` against the baseline is
        ``<= threshold`` is selected. If none qualifies, the baseline rate is
        used.

        Args:
            obs_history: low-dim observations; the last entry is the current one.
            rgb_obs_history: RGB observations (last entry current) or ``None``.
            action_chunk: ``[T, action_dim]`` actions.
            downsample_rates: candidate rates (must be non-empty).
            manipulation_threshold: threshold used when a manipulation phase is detected.
            movement_threshold: threshold used when a movement phase is detected.
            threshold: explicit error threshold. If ``None``, it is chosen by
                ``_detect_task_phase`` on the baseline trajectory.
            minimum_decay_steps: trailing actions excluded from the large-rate
                phase. This is applied to the baseline and every candidate alike.

        Returns:
            selected_ds: chosen downsample rate. If the chunk is shorter than
                the smallest rate, returns ``1`` together with the raw chunk.
            selected_actions: merged actions for that rate.
        """
        downsample_list = sorted(downsample_rates, reverse=True)
        min_ds = min(downsample_list)

        if len(action_chunk) < min_ds:
            return 1, action_chunk

        baseline_pos, baseline_quat, baseline_gripper, baseline_actions = \
            self._compute_chunk_eef_interpolation_with_world_model(
                obs_history, rgb_obs_history, action_chunk, min_ds, min_ds, minimum_decay_steps
            )

        if threshold is None:
            threshold = self._detect_task_phase(
                baseline_pos, baseline_quat, baseline_gripper,
                manipulation_threshold, movement_threshold, min_ds
            )

        # Largest rate first: the first acceptable candidate wins.
        for ds in downsample_list:
            if ds == min_ds or ds > len(action_chunk):
                continue
            # Use the same minimum_decay_steps as the baseline so that both
            # rollouts differ only in their downsample rate.
            chunk_pos, chunk_quat, _, chunk_actions = \
                self._compute_chunk_eef_interpolation_with_world_model(
                    obs_history, rgb_obs_history, action_chunk, ds, min_ds, minimum_decay_steps
                )
            chunk_eef_error = compute_eef_error(
                chunk_pos, chunk_quat, baseline_pos, baseline_quat
            )
            logging.debug(
                "AdaDS candidate ds=%s eef_error=%s threshold=%s", ds, chunk_eef_error, threshold
            )
            if chunk_eef_error <= threshold:
                return ds, chunk_actions

        return min_ds, baseline_actions
    
    def compute_state_deviation(
        self,
        obs_history: List[np.ndarray],
        rgb_obs_history: Optional[List[np.ndarray]],
        action_chunk: np.ndarray,
        min_ds: int = 1,
        minimum_decay_steps: int = 2,
        ds: int = 2,
    ) -> Dict[str, Any]:
        """Measure how far a downsampled rollout drifts from the baseline rollout.

        Both rollouts go through ``_compute_chunk_eef_interpolation_with_world_model``.
        The baseline uses rate ``min_ds``; the candidate uses rate ``ds``
        (2 by default, matching the previous hard-coded behaviour). They are
        compared with ``compute_eef_error``.

        Args:
            obs_history: low-dim observations; the last entry is the current one.
            rgb_obs_history: RGB observations (last entry current) or ``None``.
            action_chunk: ``[T, action_dim]`` actions.
            min_ds: baseline downsample rate (default 1).
            minimum_decay_steps: trailing actions excluded from the large-rate phase.
            ds: candidate downsample rate to compare against the baseline (default 2).

        Returns:
            ``{'deviation_mean': float, 'deviation_max': float}``: the EEF error
            with ``mode="mean"`` and ``mode="max"`` respectively.
        """
        # Remember the caller's mode: the rollouts put the model in eval mode,
        # and we restore the previous mode (even on error) instead of forcing
        # train mode on e.g. an inference-only server.
        was_training = getattr(self.dynamics, "training", True)
        self.dynamics.eval()
        try:
            baseline_pos, baseline_quat, _, _ = \
                self._compute_chunk_eef_interpolation_with_world_model(
                    obs_history, rgb_obs_history, action_chunk, min_ds, min_ds, minimum_decay_steps
                )
            candidate_pos, candidate_quat, _, _ = \
                self._compute_chunk_eef_interpolation_with_world_model(
                    obs_history, rgb_obs_history, action_chunk, ds, min_ds, minimum_decay_steps
                )

            deviation_mean = compute_eef_error(
                candidate_pos, candidate_quat,
                baseline_pos, baseline_quat, mode="mean"
            )
            deviation_max = compute_eef_error(
                candidate_pos, candidate_quat,
                baseline_pos, baseline_quat, mode="max"
            )
        finally:
            self.dynamics.train(was_training)

        return {
            'deviation_mean': float(deviation_mean),
            'deviation_max': float(deviation_max),
        }

    
    def start_server(
        self,
        host: str = "0.0.0.0",
        port: int = 8888,
        downsample_rates: List[int] = None,
        manipulation_threshold: float = 0.002,
        movement_threshold: float = 0.05,
        threshold: Optional[float] = None,
        minimum_decay_steps: int = 2,
    ):
        """Serve AdaDS downsample-rate selection for this workspace over WebSocket.

        Thin wrapper: every argument is forwarded unchanged to the module-level
        ``start_adads_agent_server``, with ``workspace=self``.

        Args:
            host: interface to bind; "0.0.0.0" listens on all interfaces.
            port: TCP port to listen on.
            downsample_rates: candidate rates; ``None`` lets the server fall
                back to its own default of ``[1, 2]``.
            manipulation_threshold: error threshold for manipulation phases.
            movement_threshold: error threshold for movement phases.
            threshold: fixed error threshold, or ``None`` for phase detection.
            minimum_decay_steps: trailing actions excluded from the large-rate phase.

        Example:
            workspace = DynamicsWorkspace(cfg, work_dir="...", train=False)
            workspace.load_snapshot()
            workspace.start_server(host="0.0.0.0", port=8888, downsample_rates=[1, 2])
        """
        server_options = {
            "host": host,
            "port": port,
            "downsample_rates": downsample_rates,
            "manipulation_threshold": manipulation_threshold,
            "movement_threshold": movement_threshold,
            "threshold": threshold,
            "minimum_decay_steps": minimum_decay_steps,
        }
        start_adads_agent_server(workspace=self, **server_options)
    
    def collect_transitions_with_different_ds(
        self,
        task_suite_name: str = "libero_spatial",
        num_episodes: int = 10,
        downsample_rates: Optional[List[int]] = None,
        dataset_name: str = "libero_spatial",
        num_steps_wait: int = 10,
        seed: int = 7,
        resize_size: int = 224,
        control_freq: int = 20,
        original_control_freq: int = 20,
        policy_host: str = "0.0.0.0",
        policy_port: int = 8000,
    ) -> Optional[Dict[int, List[Dict[str, Any]]]]:
        """Roll out a policy in LIBERO and record transitions for each downsample rate.

        For every rate ``ds`` in ``downsample_rates`` (default ``[1, 2, 3]``) and
        every task in the suite, ``num_episodes`` episodes are run. Each policy
        query returns an action chunk. For ``ds > 1``, each group of ``ds``
        consecutive actions (the last group may be shorter) is merged with
        ``merge_delta_actions`` before execution. Every executed env step is
        recorded. Data is pickled to
        ``exp_output/<dataset_name>/transitions/transitions_ds<ds>.pkl``, and
        ``metadata.json`` is written last.

        Args:
            task_suite_name: LIBERO task suite name.
            num_episodes: episodes per task and per rate (indexes the task's init states).
            downsample_rates: rates to collect; ``None`` means ``[1, 2, 3]``.
            dataset_name: output sub-directory name.
            num_steps_wait: dummy steps before each episode (scaled like max_steps).
            seed: seed for ``np.random`` and each environment.
            resize_size: side length of the images sent to the policy.
            control_freq: environment control frequency.
            original_control_freq: reference frequency; step budgets are scaled
                by ``control_freq // original_control_freq``.
            policy_host: policy server host.
            policy_port: policy server port.

        Returns:
            ``{ds: [episode_dict, ...]}`` for the collected data, or ``None`` if
            a finished collection (``metadata.json``) already exists.

        Raises:
            ValueError: if ``task_suite_name`` has no entry in the step-limit table.
        """
        import pickle

        if downsample_rates is None:
            downsample_rates = [1, 2, 3]

        save_dir = Path("exp_output") / dataset_name / "transitions"
        metadata_path = save_dir / "metadata.json"
        # metadata.json is written only after all pickles are saved, so it marks
        # a finished collection. Checking the bare directory would also skip a
        # directory left behind by an interrupted run (it is created up front).
        if metadata_path.exists():
            print("Transitions already exist")
            return None
        save_dir.mkdir(parents=True, exist_ok=True)

        repeats = control_freq // original_control_freq
        np.random.seed(seed)

        # NOTE(review): `benchmark`, `get_libero_path`, `OffScreenRenderEnv`,
        # `image_tools` and `websocket_client_policy` are not imported in this
        # module's header; confirm they are brought into scope elsewhere.
        benchmark_dict = benchmark.get_benchmark_dict()
        task_suite = benchmark_dict[task_suite_name]()
        num_tasks_in_suite = task_suite.n_tasks

        libero_dummy_action = [0.0] * 6 + [-1.0]
        libero_env_resolution = 256
        use_pixels = self.cfg.model.use_pixels

        policy_client = websocket_client_policy.WebsocketClientPolicy(policy_host, policy_port)
        apply_patches()

        max_steps_dict = {
            "libero_spatial": 220,
            "libero_object": 280,
            "libero_goal": 300,
            "libero_10": 520,
            "libero_90": 400,
        }
        if task_suite_name not in max_steps_dict:
            raise ValueError(f"Unknown task suite: {task_suite_name}")
        max_steps = max_steps_dict[task_suite_name] * repeats
        wait_steps = num_steps_wait * repeats

        def flipped_views(o):
            # Both camera images reversed along their first two axes.
            return (
                np.ascontiguousarray(o["agentview_image"][::-1, ::-1]),
                np.ascontiguousarray(o["robot0_eye_in_hand_image"][::-1, ::-1]),
            )

        def eef_state(o):
            # EEF position, axis-angle orientation, gripper joint positions.
            return np.concatenate((
                o["robot0_eef_pos"],
                quat2axisangle(o["robot0_eef_quat"]),
                o["robot0_gripper_qpos"],
            ))

        def actions_for_rate(action_chunk, ds):
            # ds == 1 executes the chunk unchanged; otherwise every `ds`
            # consecutive actions (last group possibly shorter) become one.
            if ds == 1:
                return action_chunk
            merged = []
            for i in range(0, len(action_chunk), ds):
                batch_raw = normalize_action_controller(action_chunk[i:i + ds])
                action_raw = merge_delta_actions(batch_raw)
                merged.append(unnormalize_action_controller(action_raw[None, :])[0])
            return merged

        def roll_out_episode(env, initial_state, task_description, ds):
            # Run one episode; returns per-step lists (possibly empty).
            env.reset()
            obs = env.set_init_state(initial_state)
            for _ in range(wait_steps):
                obs, _, _, _ = env.step(libero_dummy_action)

            ep_obs, ep_next_obs, ep_actions = [], [], []
            ep_rgb, ep_next_rgb = [], []
            t = 0
            done = False
            while t < max_steps and not done:
                img, wrist_img = flipped_views(obs)
                img_processed = image_tools.convert_to_uint8(
                    image_tools.resize_with_pad(img, resize_size, resize_size)
                )
                wrist_img_processed = image_tools.convert_to_uint8(
                    image_tools.resize_with_pad(wrist_img, resize_size, resize_size)
                )
                state = eef_state(obs)
                element = {
                    "observation/image": img_processed,
                    "observation/wrist_image": wrist_img_processed,
                    "observation/state": state,
                    "prompt": str(task_description),
                }
                action_chunk = policy_client.infer(element)["actions"]

                # The whole (possibly merged) chunk is executed unless the
                # episode ends; max_steps is only re-checked per policy query.
                for action in actions_for_rate(action_chunk, ds):
                    current_state = state.copy()
                    # RGB is recorded only when the model uses pixels; building
                    # it unconditionally paired real frames with `None` next
                    # frames and broke the uint8 conversion below.
                    current_rgb = np.stack([img, wrist_img], axis=0) if use_pixels else None

                    next_obs, _, done, _ = env.step(action.tolist())
                    next_img, next_wrist_img = flipped_views(next_obs)
                    next_state = eef_state(next_obs)

                    ep_obs.append(current_state)
                    ep_next_obs.append(next_state)
                    ep_actions.append(action)
                    if current_rgb is not None:
                        ep_rgb.append(current_rgb)
                        ep_next_rgb.append(np.stack([next_img, next_wrist_img], axis=0))

                    obs, state = next_obs, next_state
                    img, wrist_img = next_img, next_wrist_img
                    t += 1
                    if done:
                        break
            return ep_obs, ep_next_obs, ep_actions, ep_rgb, ep_next_rgb

        transitions_by_ds = {ds: [] for ds in downsample_rates}
        for ds in downsample_rates:
            if ds == 1:
                revert_patches()
            else:
                apply_patches()
            for task_id in tqdm(range(num_tasks_in_suite), desc="Collecting transitions"):
                task = task_suite.get_task(task_id)
                initial_states = task_suite.get_task_init_states(task_id)
                task_description = task.language

                task_bddl_file = Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
                env_args = {
                    "bddl_file_name": task_bddl_file,
                    "camera_heights": libero_env_resolution,
                    "camera_widths": libero_env_resolution,
                    "control_freq": control_freq
                }
                env = OffScreenRenderEnv(**env_args)
                try:
                    env.seed(seed)
                    for episode_idx in tqdm(range(num_episodes), desc=f"Task {task_id} DS={ds}", leave=False):
                        ep_obs, ep_next_obs, ep_actions, ep_rgb, ep_next_rgb = roll_out_episode(
                            env, initial_states[episode_idx], task_description, ds
                        )
                        if not ep_obs:
                            continue
                        transition = {
                            'obs': np.array(ep_obs, dtype=np.float32),
                            'next_obs': np.array(ep_next_obs, dtype=np.float32),
                            'actions': np.array(ep_actions, dtype=np.float32),
                            'task_id': task_id,
                            'task_description': task_description,
                            'episode_idx': episode_idx,
                        }
                        if ep_rgb:
                            transition['rgb_obs'] = np.array(ep_rgb, dtype=np.uint8)
                            transition['next_rgb_obs'] = np.array(ep_next_rgb, dtype=np.uint8)
                        transitions_by_ds[ds].append(transition)
                finally:
                    # One env is created per task and rate; close it so they do
                    # not accumulate over the whole collection.
                    env.close()

        for ds, transitions in transitions_by_ds.items():
            save_path = save_dir / f"transitions_ds{ds}.pkl"
            with open(save_path, 'wb') as f:
                pickle.dump(transitions, f)
            print(f"已保存 {len(transitions)} 个 DS={ds} 的 transitions 到 {save_path}")

        # Written last: its presence marks a complete collection (see the check above).
        metadata = {
            'dataset_name': dataset_name,
            'downsample_rates': list(transitions_by_ds.keys()),
            'num_transitions_by_ds': {str(ds): len(transitions) for ds, transitions in transitions_by_ds.items()},
        }
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        print(f"已保存元数据到 {metadata_path}")

        return transitions_by_ds

    
    def eval_world_model_on_transitions(
        self,
        dataset_name: str,
        downsample_rates: Optional[List[int]] = None,
        verbose: bool = True,
    ) -> Dict[str, Any]:
        """Evaluate open-loop world-model predictions on collected transitions.

        For every rate in ``downsample_rates`` (default ``[1, 2, 3]``), loads
        ``exp_output/<dataset_name>/transitions/transitions_ds<ds>.pkl``. Every
        window of ``seq_len`` consecutive actions in each episode is evaluated:
        the model starts from the framestacked observation at the window start,
        predicts ``seq_len`` steps, and is scored by MAE against ``next_obs``
        for those steps. The last two state dimensions (gripper) are excluded
        from the MAE.

        Args:
            dataset_name: sub-directory of ``exp_output`` holding ``transitions/``.
            downsample_rates: rates to evaluate; ``None`` means ``[1, 2, 3]``.
                Missing files are skipped with a warning.
            verbose: print per-rate summaries.

        Returns:
            ``{'ds_<rate>': {'avg_mae', 'num_episodes', 'num_transitions',
            'episode_errors'}}``.

        Raises:
            ValueError: if the transitions directory does not exist.
        """
        import pickle

        if downsample_rates is None:
            downsample_rates = [1, 2, 3]

        transitions_dir = Path("exp_output") / dataset_name / "transitions"
        if not transitions_dir.exists():
            raise ValueError(f"Transitions 目录不存在: {transitions_dir}")

        seq_len = self.cfg.training.seq_len
        framestack = self.cfg.model.framestack
        normalize = self.cfg.data.normalize_obs

        def history_window(frames, start_idx):
            # Frames feeding the framestack for the window starting at start_idx:
            # up to framestack-1 history frames plus the seq_len input frames.
            # Slicing stops at start_idx + seq_len, so the last window gets
            # exactly as many frames as every other window.
            history_start = max(0, start_idx - (framestack - 1))
            window = frames[history_start:start_idx + seq_len]
            # Left-pad short histories with the first frame.
            # NOTE(review): start_idx == 0 is left unpadded (as before) and
            # relies on _apply_framestack's own handling — confirm.
            if history_start == 0 and start_idx > 0:
                num_missing = (framestack - 1) - start_idx
                if num_missing > 0:
                    padding = np.repeat(frames[0:1], num_missing, axis=0)
                    window = np.concatenate([padding, window], axis=0)
            return window

        # Evaluation runs in eval mode; the caller's mode is restored afterwards.
        was_training = getattr(self.dynamics, "training", False)
        self.dynamics.eval()

        results = {}
        try:
            for ds in downsample_rates:
                transitions_path = transitions_dir / f"transitions_ds{ds}.pkl"
                if not transitions_path.exists():
                    print(f"警告: 未找到 DS={ds} 的 transitions 文件: {transitions_path}")
                    continue

                # NOTE(review): pickle.load executes arbitrary code; only load
                # transition files produced locally by this workspace.
                with open(transitions_path, 'rb') as f:
                    transitions = pickle.load(f)

                if verbose:
                    print(f"\n评估 DS={ds} 的 transitions ({len(transitions)} 个 episodes)...")

                total_mae = 0.0
                total_transitions = 0
                episode_errors = []

                for transition in tqdm(transitions, desc=f"Evaluating DS={ds}"):
                    obs = transition['obs']  # [T, obs_dim]
                    next_obs = transition['next_obs']  # [T, obs_dim]
                    actions = transition['actions']  # [T, action_dim]

                    if len(actions) < seq_len:
                        continue

                    if normalize and self.obs_mean is not None:
                        obs_normalized = (obs - self.obs_mean) / self.obs_std
                    else:
                        obs_normalized = obs

                    rgb_obs_array = None
                    if 'rgb_obs' in transition and self.cfg.model.use_pixels:
                        rgb_obs_array = transition['rgb_obs']
                        if rgb_obs_array.shape[-1] == 3:
                            # [T, V, H, W, C] -> [T, V, C, H, W]
                            rgb_obs_array = rgb_obs_array.transpose(0, 1, 4, 2, 3)

                    episode_mae = 0.0
                    episode_transitions = 0

                    for start_idx in range(len(actions) - seq_len + 1):
                        # Only the first stacked observation seeds the rollout.
                        # NOTE(review): assumes _apply_framestack returns one
                        # stacked frame per input frame — confirm.
                        seq_obs_stacked = self._apply_framestack(
                            history_window(obs_normalized, start_idx),
                            framestack,
                            first_frame=None
                        )[-seq_len:]

                        # Targets are the recorded next observations of these
                        # steps (raw scale); deriving them from shifted `obs`
                        # misaligned the final window of every episode.
                        target_obs = next_obs[start_idx:start_idx + seq_len]

                        seq_actions = actions[start_idx:start_idx + seq_len]
                        if normalize and self.action_mean is not None:
                            seq_actions = (seq_actions - self.action_mean) / self.action_std

                        first_obs_tensor = torch.from_numpy(seq_obs_stacked[0:1]).float().to(self.device)
                        actions_tensor = torch.from_numpy(seq_actions).float().unsqueeze(0).to(self.device)

                        first_rgb = None
                        if rgb_obs_array is not None:
                            rgb_seq_stacked = self._apply_framestack(
                                history_window(rgb_obs_array, start_idx),
                                framestack,
                                first_frame=None
                            )[-seq_len:]
                            first_rgb = torch.from_numpy(rgb_seq_stacked[0:1]).float().to(self.device)

                        with torch.no_grad():
                            predicted_seq = self.dynamics.predict(
                                low_dim_obs=first_obs_tensor,
                                rgb_obs=first_rgb,
                                actions=actions_tensor
                            )  # [1, seq_len, obs_dim]

                        predicted_obs_seq = predicted_seq[0].cpu().numpy()
                        if normalize and self.obs_mean is not None:
                            predicted_obs_seq = predicted_obs_seq * self.obs_std + self.obs_mean

                        # MAE without the last two (gripper) dimensions.
                        seq_mae = np.mean(np.abs(predicted_obs_seq - target_obs)[:, :-2])

                        episode_mae += seq_mae
                        episode_transitions += 1
                        total_mae += seq_mae
                        total_transitions += 1

                    if episode_transitions > 0:
                        episode_errors.append({
                            'mae': float(episode_mae / episode_transitions),
                            'num_sequences': episode_transitions,
                            'num_steps': len(actions),
                        })

                avg_mae = total_mae / total_transitions if total_transitions > 0 else 0.0

                results[f'ds_{ds}'] = {
                    'avg_mae': float(avg_mae),
                    'num_episodes': len(transitions),
                    'num_transitions': total_transitions,
                    'episode_errors': episode_errors,
                }

                if verbose:
                    print(f"DS={ds} 评估结果:")
                    print(f"  平均 MAE: {avg_mae:.6f}")
                    print(f"  Episodes: {len(transitions)}")
                    print(f"  Transitions: {total_transitions}")
        finally:
            self.dynamics.train(was_training)

        return results


def _parse_observation_histories(
    data: Dict[str, Any],
    framestack: int,
) -> Tuple[List[np.ndarray], Optional[List[np.ndarray]]]:
    """Turn the raw ``s_env`` / ``rgb_obs`` request fields into history lists.

    ``s_env`` is converted to float32. If it has more than one dimension it is
    split along its first axis (one entry per frame); otherwise the single
    observation is repeated ``max(1, framestack)`` times.

    ``rgb_obs`` is optional and converted to uint8. If it has ``ndim >= 4``
    and more than one entry on axis 0 it is split along axis 0; otherwise the
    whole array is repeated ``max(1, framestack)`` times.

    Args:
        data: Decoded request; must contain ``"s_env"``.
        framestack: Number of copies used when a single observation is sent.

    Returns:
        ``(obs_history, rgb_obs_history)``; ``rgb_obs_history`` is ``None``
        when the request has no ``"rgb_obs"`` field or it is ``None``.
    """
    s_env = np.array(data["s_env"], dtype=np.float32)
    if s_env.ndim > 1:
        obs_history = list(s_env)
    else:
        obs_history = [s_env] * max(1, framestack)

    rgb_obs_history = None
    if data.get("rgb_obs") is not None:
        rgb_obs = np.array(data["rgb_obs"], dtype=np.uint8)
        # NOTE(review): a >=4-D array with a single entry on axis 0 is repeated
        # as-is (keeping its leading axis) rather than split; confirm this is
        # what the workspace methods expect.
        if rgb_obs.ndim >= 4 and rgb_obs.shape[0] > 1:
            rgb_obs_history = list(rgb_obs)
        else:
            rgb_obs_history = [rgb_obs] * max(1, framestack)

    return obs_history, rgb_obs_history


def _handle_adads_request(
    workspace,
    data: Any,
    *,
    downsample_rates: List[int],
    manipulation_threshold: float,
    movement_threshold: float,
    threshold: Optional[float],
    minimum_decay_steps: int,
) -> Dict[str, Any]:
    """Dispatch one decoded AdaDS request and build its response dict.

    Request types (``data["type"]``, default ``"predict_k"``):

    * ``"health"``: echo the configured downsample rates.
    * ``"state_deviation"``: call ``workspace.compute_state_deviation``.
      The request may override ``min_ds`` (default 1) and
      ``minimum_decay_steps`` (default: the server-level value).
    * ``"predict_k"``: call ``workspace._select_downsample_rate`` with the
      server-level settings and return the chosen rate as ``predicted_k``.

    Both model-backed requests require ``"s_env"`` and ``"A"``; a missing
    field, an unknown type, or a non-map payload yields an error response.
    Exceptions raised by the workspace propagate to the caller.
    """
    from collections.abc import Mapping

    if not isinstance(data, Mapping):
        return {"status": "error", "message": "Request must be a map of fields"}

    request_type = data.get("type", "predict_k")

    if request_type == "health":
        return {"status": "success", "downsample_rates": downsample_rates}

    if request_type not in ("state_deviation", "predict_k"):
        return {
            "status": "error",
            "message": f"Unknown request type: {request_type}",
        }

    if "s_env" not in data or "A" not in data:
        return {
            "status": "error",
            "message": "Request must contain 's_env' and 'A' fields",
        }

    obs_history, rgb_obs_history = _parse_observation_histories(
        data, workspace.cfg.model.framestack
    )
    action_chunk = np.array(data["A"], dtype=np.float32)

    if request_type == "state_deviation":
        # Per-request override is kept under its own name so it can never
        # leak into (or shadow) the server-level value used by "predict_k".
        request_decay_steps = data.get("minimum_decay_steps", minimum_decay_steps)
        result = workspace.compute_state_deviation(
            obs_history=obs_history,
            rgb_obs_history=rgb_obs_history,
            action_chunk=action_chunk,
            min_ds=data.get("min_ds", 1),
            minimum_decay_steps=request_decay_steps,
        )
        return {
            "status": "success",
            "deviation_mean": result['deviation_mean'],
            "deviation_max": result['deviation_max'],
        }

    # request_type == "predict_k"
    selected_ds, _selected_actions = workspace._select_downsample_rate(
        obs_history=obs_history,
        rgb_obs_history=rgb_obs_history,
        action_chunk=action_chunk,
        downsample_rates=downsample_rates,
        manipulation_threshold=manipulation_threshold,
        movement_threshold=movement_threshold,
        threshold=threshold,
        minimum_decay_steps=minimum_decay_steps,
    )
    return {"status": "success", "predicted_k": int(selected_ds)}


def start_adads_agent_server(
    workspace,
    host: str = "0.0.0.0",
    port: int = 8888,
    downsample_rates: Optional[List[int]] = None,
    manipulation_threshold: float = 0.002,
    movement_threshold: float = 0.05,
    threshold: Optional[float] = None,
    minimum_decay_steps: int = 2,
):
    """Run a blocking WebSocket server exposing AdaDS downsample-rate selection.

    Other Python processes can query the loaded dynamics model over a
    WebSocket connection instead of HTTP. The original author notes this
    avoids conflicts with HTTP proxies.

    Protocol:
        * Messages are binary frames encoded with ``msgpack_numpy_utils``.
        * On connect, the server sends
          ``{"status": "ready", "downsample_rates": [...]}``.
        * Each binary request then receives exactly one binary response.
        * A text frame from the client closes the connection.
        * If handling a request raises, the server sends the traceback as a
          text frame and closes the connection.

    Args:
        workspace: A ``DynamicsWorkspace`` instance with a loaded dynamics model.
        host: Interface to bind; the default ``"0.0.0.0"`` listens on all interfaces.
        port: TCP port to listen on (default 8888).
        downsample_rates: Candidate downsample rates; defaults to ``[1, 2]``.
        manipulation_threshold: Forwarded to ``_select_downsample_rate``.
        movement_threshold: Forwarded to ``_select_downsample_rate``.
        threshold: Forwarded to ``_select_downsample_rate`` (``None`` lets the
            workspace pick it).
        minimum_decay_steps: Used for ``predict_k`` requests. Also the default
            for ``state_deviation`` requests that do not supply their own value.

    Raises:
        ImportError: If the ``websockets`` package is not installed.
    """
    try:
        import websockets.sync.server
        from websockets.exceptions import ConnectionClosed
    except ImportError as err:
        raise ImportError(
            "websockets is required to run the server. "
            "Please install it with: pip install websockets"
        ) from err

    import traceback

    if downsample_rates is None:
        downsample_rates = [1, 2]

    from msgpack_numpy_utils import Packer, unpackb

    # Server-level settings shared by every request on every connection.
    request_options = dict(
        downsample_rates=downsample_rates,
        manipulation_threshold=manipulation_threshold,
        movement_threshold=movement_threshold,
        threshold=threshold,
        minimum_decay_steps=minimum_decay_steps,
    )

    def handle_client(websocket):
        """Serve one client until it disconnects or a request fails."""
        logging.info("Connection opened")
        packer = Packer()

        try:
            # Greet the client with the server metadata.
            metadata = {
                "status": "ready",
                "downsample_rates": downsample_rates
            }
            websocket.send(packer.pack(metadata))

            while True:
                try:
                    message = websocket.recv()
                    if isinstance(message, str):
                        # A text frame signals an error from the peer; close.
                        logging.error(f"Received string message: {message}")
                        break

                    response = _handle_adads_request(
                        workspace, unpackb(message), **request_options
                    )
                    websocket.send(packer.pack(response))

                except ConnectionClosed:
                    # Peer went away (normally or not): nobody left to reply to.
                    break

                except Exception:
                    # Report the failure as a text frame (the protocol's error
                    # signal) and drop the connection.
                    error_traceback = traceback.format_exc()
                    logging.error(f"Error handling message: {error_traceback}")
                    websocket.send(error_traceback)
                    break

        except Exception as e:
            logging.error(f"Error in client handler: {e}")

        logging.info("Connection closed")

    # 启动服务器
    uri = f"ws://{host}:{port}"
    print(f"启动 AdaDS Agent WebSocket 服务器...")
    print(f"服务器地址: {uri}")
    print(f"降采样率列表: {downsample_rates}")
    print(f"按 Ctrl+C 停止服务器")

    # NOTE(review): max_size=None disables the incoming-message size limit
    # (presumably for large RGB payloads); combined with the default
    # 0.0.0.0 bind, only expose this server on a trusted network.
    with websockets.sync.server.serve(handle_client, host, port, compression=None, max_size=None) as server:
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            print("\n正在关闭服务器...")

