"""
Dynamics Workspace - World Model 训练工作空间
仿照 robobase/robobase/dynamics_workspace.py 实现
"""

import shutil
import signal
import sys
import os
import json
import time
import copy
import random
import logging
from typing import Callable, Any, Optional, Dict, List, Tuple
from functools import partial
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict

from config import Config
from models.mlp_world_model import MLPWorldModel
from dataprocess import create_loader, analyze_episodes, split_train_val
from utils import (
    set_seed,
    Timer,
    Every,
    Until,
    Logger,
    create_output_dir,
    save_config,
    put_text,
    to_torch,
    to_numpy,
    merge_delta_actions,
    normalize_action_controller,
    unnormalize_action_controller,
    convert_delta_to_absolute_actions,
    compute_eef_error,
    compute_eef_error_ratio,
    interpolate_pos_quat,
    extract_eef_from_obs,
    quat2axisangle,
    axisangle2quat,

    apply_patches,
    revert_patches
)


class DynamicsWorkspace:
    """World Model 训练工作空间"""
    
    def __init__(
        self,
        cfg: Config,
        work_dir: Optional[str] = None,
        train: bool = True
    ):
        """Set up device, output directory, logging, data and model.

        Args:
            cfg: Experiment configuration object.
            work_dir: Directory for logs and snapshots. When ``None`` a
                directory is created from ``cfg.log.output_dir`` and
                ``cfg.log.exp_name``.
            train: When true, episodes are loaded and split before the model
                is built; otherwise only the model is created.
        """
        self.cfg = cfg

        # Seed the random number generators before anything stochastic runs.
        set_seed(cfg.training.seed)

        # Device selection: an explicit "cpu" request is honoured; any other
        # request maps to "mps" on macOS and is used verbatim elsewhere.
        requested_device = cfg.training.device
        if requested_device == "cpu":
            device_name = "cpu"
        elif sys.platform == "darwin":
            device_name = "mps"
        else:
            device_name = requested_device
        self.device = torch.device(device_name)
        print(f"使用设备: {self.device}")

        # Resolve (and create) the working directory.
        if work_dir is not None:
            self.work_dir = Path(work_dir)
            self.work_dir.mkdir(parents=True, exist_ok=True)
        else:
            self.work_dir = create_output_dir(
                cfg.log.output_dir,
                cfg.log.exp_name
            )
        print(f"工作目录: {self.work_dir}")

        log_cfg = cfg.log
        self.logger = Logger(
            log_dir=self.work_dir,
            use_wandb=log_cfg.use_wandb,
            use_tensorboard=log_cfg.use_tensorboard
        )

        # wandb is imported lazily so it is only required when enabled.
        if log_cfg.use_wandb:
            import wandb
            wandb.init(
                project=log_cfg.wandb_project,
                entity=log_cfg.wandb_entity,
                name=log_cfg.wandb_name or log_cfg.exp_name,
                mode=log_cfg.wandb_mode,
                config=cfg.to_dict()
            )

        # Data is only needed for training; it must be loaded before the
        # model is built because load_data() writes the data-derived
        # dimensions into cfg.model.
        if train:
            print("\n加载数据...")
            self.load_data()
            print("\n准备训练数据...")
            self.create_data_sampler()

        print("\n创建模型...")
        self.create_model()

        self._timer = Timer()

        # Progress counters (persisted in snapshots).
        self._pretrain_step = 0
        self._main_loop_iterations = 0
        self._global_env_episode = 0

        # Best validation result seen so far.
        self.best_metrics = {"best_val_loss": float('inf')}

        self._shutting_down = False

        # Persist the configuration next to the logs.
        save_config(cfg.to_dict(), self.work_dir)

    def load_data(self):
        """Load episodes and write the data-derived dimensions into the config.

        Episodes are read through the loader returned by ``create_loader``,
        summarised with ``analyze_episodes`` and stored on ``self.episodes``.
        ``cfg.model.obs_dim`` and ``cfg.model.action_dim`` are always
        overwritten from the data; the RGB fields are overwritten only when
        pixels are enabled and the data contains RGB.
        """
        data_cfg = self.cfg.data
        model_cfg = self.cfg.model

        loader = create_loader(
            env_name=data_cfg.env_name,
            data_path=data_cfg.data_dir,
            **data_cfg.loader_kwargs
        )
        episodes = loader.load_episodes(num_episodes=data_cfg.num_demos)
        data_stats = analyze_episodes(episodes)

        # Without low-dimensional observations the dimension is 0 (RGB only).
        model_cfg.obs_dim = data_stats['obs_dim'] if data_stats.get('has_obs') else 0
        model_cfg.action_dim = data_stats['action_dim']

        if model_cfg.use_pixels and data_stats.get('has_rgb'):
            shape = data_stats['rgb_shape']  # (num_cameras, C, H, W)
            model_cfg.num_cameras = shape[0]
            model_cfg.image_channels = shape[1]
            model_cfg.image_size = (shape[2], shape[3])

        stacked = model_cfg.framestack > 1
        print(f"\n数据维度:")
        print(f"  低维观测维度: {model_cfg.obs_dim}")
        if stacked:
            print(f"  Framestack: {model_cfg.framestack} (实际输入维度: {model_cfg.obs_dim * model_cfg.framestack})")
        print(f"  动作维度: {model_cfg.action_dim}")
        if model_cfg.use_pixels:
            print(f"  使用 RGB: True")
            print(f"  相机数量: {model_cfg.num_cameras}")
            print(f"  图像尺寸: {model_cfg.image_size}")
            if stacked:
                print(f"  Framestack: {model_cfg.framestack} (实际输入通道数: {model_cfg.image_channels * model_cfg.framestack})")

        # Keep the raw episodes for the train/validation split.
        self.episodes = episodes

        # Only libero data with an 'info' entry gets a per-task breakdown.
        if data_cfg.env_name != "libero" or 'info' not in episodes[0]:
            return

        task_counts = {}
        labelled_tasks = (
            ep['info']['task']
            for ep in episodes
            if 'info' in ep and 'task' in ep['info']
        )
        for task in labelled_tasks:
            task_counts[task] = task_counts.get(task, 0) + 1

        if task_counts:
            print(f"\n任务分布:")
            for task, count in sorted(task_counts.items()):
                print(f"  {task}: {count} episodes")
    
    def create_model(self):
        """Build the MLP world model on ``self.device`` and leave it in eval mode.

        Model and optimiser hyper-parameters come from ``cfg.model`` and
        ``cfg.training``; the hidden layer sizes are fixed at three layers of
        256 units.
        """
        model_cfg = self.cfg.model
        train_cfg = self.cfg.training

        # The MLP only consumes low-dimensional observations (no pixels);
        # a non-positive obs_dim is treated as 0.
        actual_obs_dim = model_cfg.obs_dim
        if not actual_obs_dim > 0:
            actual_obs_dim = 0

        hidden_dims = [256, 256, 256]

        model_kwargs = dict(
            obs_dim=actual_obs_dim,
            action_dim=model_cfg.action_dim,
            hidden_dims=hidden_dims,
            dropout=model_cfg.dropout,
            learning_rate=train_cfg.learning_rate,
            weight_decay=train_cfg.weight_decay,
            grad_clip=train_cfg.grad_clip,
            use_symlog=model_cfg.use_symlog,
            use_var=model_cfg.use_var,
            use_residual=model_cfg.use_residual,
            device=str(self.device),
        )
        self.dynamics = MLPWorldModel(**model_kwargs)
        # Evaluation mode is the default; training mode is switched on only
        # for the duration of an update.
        self.dynamics.train(False)

        num_params = sum(p.numel() for p in self.dynamics.parameters())
        print(f"\n模型参数量: {num_params:,}")
        print(f"MLP隐藏层维度: {hidden_dims}")
        print(f"观测维度: {actual_obs_dim}, 动作维度: {self.cfg.model.action_dim}")
    
    def create_data_sampler(self):
        """Split episodes into train/validation sets and prepare normalisation.

        The split is stored on ``self.train_episodes`` / ``self.val_episodes``.
        When ``cfg.data.normalize_obs`` is set, statistics are computed from
        the training episodes only; otherwise all four statistics attributes
        are set to ``None``.
        """
        data_cfg = self.cfg.data

        train_episodes, val_episodes = split_train_val(
            self.episodes,
            train_ratio=data_cfg.train_ratio,
            shuffle=True
        )
        self.train_episodes, self.val_episodes = train_episodes, val_episodes

        if not data_cfg.normalize_obs:
            self.obs_mean = self.obs_std = None
            self.action_mean = self.action_std = None
        else:
            # Statistics come from the training split only.
            self._compute_statistics(train_episodes)

        print(f"\n数据准备完成:")
        print(f"  训练 episodes 数: {len(train_episodes)}")
        print(f"  验证 episodes 数: {len(val_episodes)}")

        def count_transitions(eps):
            # One transition per action in every episode.
            return sum(len(ep['actions']) for ep in eps)

        train_transitions = count_transitions(train_episodes)
        val_transitions = count_transitions(val_episodes)
        print(f"  训练转换数: {train_transitions}")
        print(f"  验证转换数: {val_transitions}")
    
    def _compute_statistics(self, episodes: List[Dict[str, np.ndarray]]):
        """计算标准化统计信息"""
        all_obs = []
        all_actions = []
        
        for episode in episodes:
            if 'obs' in episode and self.cfg.model.obs_dim > 0:
                all_obs.append(episode['obs'][:-1])  # 排除最后一个状态
            all_actions.append(episode['actions'])
        
        if all_obs:
            all_obs_concat = np.concatenate(all_obs, axis=0)
            self.obs_mean = np.mean(all_obs_concat, axis=0)
            self.obs_std = np.std(all_obs_concat, axis=0) + 1e-8
        else:
            self.obs_mean = None
            self.obs_std = None
        
        if all_actions:
            all_actions_concat = np.concatenate(all_actions, axis=0)
            self.action_mean = np.mean(all_actions_concat, axis=0)
            self.action_std = np.std(all_actions_concat, axis=0) + 1e-8
        
        if self.cfg.data.normalize_obs:
            print("\n数据统计信息:")
            if self.obs_mean is not None:
                print(f"  观测均值: {self.obs_mean[:5]}... (显示前5维)")
                print(f"  观测标准差: {self.obs_std[:5]}...")
            print(f"  动作均值: {self.action_mean[:5]}...")
            print(f"  动作标准差: {self.action_std[:5]}...")
    
    def _sample_batch(self, episodes: List[Dict[str, np.ndarray]], batch_size: int) -> Dict[str, torch.Tensor]:
        """从 episodes 中随机采样一个 batch (单步转换)"""
        # MLP模型使用单步转换，不需要序列
        # 随机选择 episodes 和时间步
        selected = []
        for _ in range(batch_size):
            ep = random.choice(episodes)
            if len(ep['actions']) > 0:
                # 随机选择一个时间步（注意obs比actions多一个）
                t = random.randint(0, len(ep['actions']) - 1)
                selected.append((ep, t))
        
        # 提取单步转换数据
        has_obs = 'obs' in episodes[0] and self.cfg.model.obs_dim > 0
        
        batch_obs = []
        batch_next_obs = []
        batch_actions = []
        
        for ep, t in selected:
            if has_obs:
                # 当前观测和下一观测
                obs = ep['obs'][t]  # [obs_dim]
                next_obs = ep['obs'][t + 1]  # [obs_dim]
                batch_obs.append(obs)
                batch_next_obs.append(next_obs)
            
            # 动作
            action = ep['actions'][t]  # [action_dim]
            batch_actions.append(action)
        
        # 堆叠并转换为 numpy array
        batch = {}
        if has_obs and batch_obs:
            batch_obs_array = np.stack(batch_obs)  # [batch_size, obs_dim]
            batch_next_obs_array = np.stack(batch_next_obs)  # [batch_size, obs_dim]
            
            # 标准化
            if self.cfg.data.normalize_obs and self.obs_mean is not None:
                batch_obs_array = (batch_obs_array - self.obs_mean) / self.obs_std
                batch_next_obs_array = (batch_next_obs_array - self.obs_mean) / self.obs_std
            
            batch['obs'] = torch.from_numpy(batch_obs_array).float()
            batch['next_obs'] = torch.from_numpy(batch_next_obs_array).float()
        
        batch_actions_array = np.stack(batch_actions)  # [batch_size, action_dim]
        
        # 标准化动作
        if self.cfg.data.normalize_obs and self.action_mean is not None:
            batch_actions_array = (batch_actions_array - self.action_mean) / self.action_std
        
        batch['actions'] = torch.from_numpy(batch_actions_array).float()
        
        return batch
    
    @property
    def pretrain_steps(self) -> int:
        """Current value of the training-step counter ``_pretrain_step``."""
        return self._pretrain_step
    
    @property
    def main_loop_iterations(self) -> int:
        """Current value of the ``_main_loop_iterations`` counter."""
        return self._main_loop_iterations
    
    @property
    def global_env_episodes(self) -> int:
        """Current value of the ``_global_env_episode`` counter."""
        return self._global_env_episode
    
    @property
    def global_env_steps(self) -> int:
        """Total number of environment steps; here an alias of ``pretrain_steps``."""
        return self.pretrain_steps
    
    def train(self):
        """Run the training loop, then a final evaluation and snapshot.

        Every iteration samples a batch of training transitions, applies one
        model update and logs the resulting metrics. Evaluation and
        checkpointing follow the schedules in ``cfg.training``; whenever the
        validation loss improves, the model is also stored as the best
        snapshot. The logger is closed at the end.
        """
        rule = "=" * 60
        print("\n" + rule)
        print("开始训练 World Model")
        print(rule + "\n")

        train_cfg = self.cfg.training

        # Schedules that decide, per step, what has to happen.
        keep_training = Until(train_cfg.num_train_steps)
        log_due = Every(train_cfg.log_every_steps)
        eval_due = Every(train_cfg.eval_every_steps)
        save_due = Every(train_cfg.save_every_steps)

        while keep_training(self.pretrain_steps):
            step = self.pretrain_steps

            # One optimisation step on a freshly sampled batch.
            metrics = self._perform_updates(
                self._sample_batch(self.train_episodes, train_cfg.batch_size)
            )

            # Timing and bookkeeping metrics are only attached on log steps.
            if log_due(step):
                elapsed, total_time = self._timer.reset()
                metrics['elapsed_time'] = elapsed
                metrics['total_time'] = total_time
                metrics['steps_per_sec'] = train_cfg.log_every_steps / elapsed
                metrics.update(self._get_common_metrics())

            # NOTE(review): metrics are sent to the logger on every step, not
            # only on log steps -- confirm this is intended.
            self.logger.log_metrics(metrics, step, prefix="pretrain")

            if eval_due(step):
                print(f"\n评估 (步骤 {step})...")
                eval_metrics = self._eval()
                eval_metrics.update(self._get_common_metrics())
                self.logger.log_metrics(eval_metrics, step, prefix="pretrain_eval")

                # Keep a separate copy of the best model seen so far.
                val_loss = eval_metrics['val_loss']
                if val_loss < self.best_metrics['best_val_loss']:
                    self.best_metrics['best_val_loss'] = val_loss
                    self.save_snapshot(best_ckpt=True)
                    print(f"保存最佳模型 (val_loss: {self.best_metrics['best_val_loss']:.4f})")

            if save_due(step):
                self.save_snapshot()

            self._pretrain_step += 1

        print("\n" + rule)
        print("训练完成!")
        print(rule)

        # Final evaluation after the last update.
        print("\n最终评估...")
        final_metrics = self._eval()
        final_metrics.update(self._get_common_metrics())
        self.logger.log_metrics(final_metrics, self.pretrain_steps, prefix="final")

        # Final snapshot, then release the logger.
        self.save_snapshot()
        self.logger.close()
    
    def _perform_updates(self, batch) -> Dict[str, Any]:
        """执行一次训练更新"""
        metrics = {}
        self.dynamics.train(True)
        
        # 提取数据 (MLP模型只使用低维观测)
        obs = batch['obs'].to(self.device)
        next_obs = batch['next_obs'].to(self.device)
        actions = batch['actions'].to(self.device)
        
        # 更新模型
        update_metrics = self.dynamics.update(
            obs=obs,
            next_obs=next_obs,
            actions=actions
        )
        
        metrics.update(update_metrics)
        self.dynamics.train(False)
        
        return metrics
    
    def _sample_sequence_batch(self, episodes: List[Dict[str, np.ndarray]], batch_size: int, seq_len: int) -> Dict[str, torch.Tensor]:
        """从 episodes 中随机采样序列用于多步预测评估"""
        selected = []
        for _ in range(batch_size):
            ep = random.choice(episodes)
            if len(ep['actions']) >= seq_len:
                # 随机选择一个起始位置
                start = random.randint(0, len(ep['actions']) - seq_len)
                selected.append((ep, start))
        
        # 如果没有足够长的episode，补充采样
        while len(selected) < batch_size:
            ep = random.choice(episodes)
            if len(ep['actions']) >= seq_len:
                start = random.randint(0, len(ep['actions']) - seq_len)
                selected.append((ep, start))
        
        has_obs = 'obs' in episodes[0] and self.cfg.model.obs_dim > 0
        
        batch_init_obs = []
        batch_target_obs = []
        batch_actions = []
        
        for ep, start in selected:
            if has_obs:
                # 初始观测
                init_obs = ep['obs'][start]  # [obs_dim]
                # 目标观测序列（从start+1到start+seq_len+1）
                target_obs_seq = ep['obs'][start+1:start+seq_len+1]  # [seq_len, obs_dim]
                batch_init_obs.append(init_obs)
                batch_target_obs.append(target_obs_seq)
            
            # 动作序列
            action_seq = ep['actions'][start:start+seq_len]  # [seq_len, action_dim]
            batch_actions.append(action_seq)
        
        # 堆叠并转换为 numpy array
        batch = {}
        if has_obs and batch_init_obs:
            batch_init_obs_array = np.stack(batch_init_obs)  # [batch_size, obs_dim]
            batch_target_obs_array = np.stack(batch_target_obs)  # [batch_size, seq_len, obs_dim]
            
            # 标准化
            if self.cfg.data.normalize_obs and self.obs_mean is not None:
                batch_init_obs_array = (batch_init_obs_array - self.obs_mean) / self.obs_std
                # 目标观测也需要标准化
                batch_size_dim, seq_dim, obs_dim = batch_target_obs_array.shape
                batch_target_obs_array = batch_target_obs_array.reshape(-1, obs_dim)
                batch_target_obs_array = (batch_target_obs_array - self.obs_mean) / self.obs_std
                batch_target_obs_array = batch_target_obs_array.reshape(batch_size_dim, seq_dim, obs_dim)
            
            batch['init_obs'] = torch.from_numpy(batch_init_obs_array).float()
            batch['target_obs'] = torch.from_numpy(batch_target_obs_array).float()
        
        batch_actions_array = np.stack(batch_actions)  # [batch_size, seq_len, action_dim]
        
        # 标准化动作
        if self.cfg.data.normalize_obs and self.action_mean is not None:
            batch_size_dim, seq_dim, action_dim = batch_actions_array.shape
            batch_actions_array = batch_actions_array.reshape(-1, action_dim)
            batch_actions_array = (batch_actions_array - self.action_mean) / self.action_std
            batch_actions_array = batch_actions_array.reshape(batch_size_dim, seq_dim, action_dim)
        
        batch['actions'] = torch.from_numpy(batch_actions_array).float()
        
        return batch
    
    @torch.no_grad()
    def _eval(self):
        """评估模型 - 使用多步预测"""
        self.dynamics.eval()
        
        # 评估配置
        eval_seq_lengths = [1, 5, 10, 20]  # 评估不同的序列长度
        num_eval_batches = max(10, len(self.val_episodes) // self.cfg.training.batch_size)
        
        eval_metrics = {}
        
        # 对每个序列长度分别评估
        for seq_len in eval_seq_lengths:
            total_mse = 0
            total_mse_per_step = np.zeros(seq_len)
            num_batches = 0
            
            for _ in range(num_eval_batches):
                # 采样序列batch
                batch = self._sample_sequence_batch(self.val_episodes, self.cfg.training.batch_size, seq_len)
                
                init_obs = batch['init_obs'].to(self.device)  # [B, obs_dim]
                target_obs = batch['target_obs'].to(self.device)  # [B, seq_len, obs_dim]
                actions = batch['actions'].to(self.device)  # [B, seq_len, action_dim]
                
                # 使用 predict_all 进行多步预测
                predicted_obs = self.dynamics.predict_all(
                    obs=init_obs,
                    action_sequence=actions
                )  # [B, seq_len, obs_dim]
                
                # 反标准化以计算真实误差
                if self.cfg.data.normalize_obs and self.obs_mean is not None:
                    obs_mean_tensor = torch.from_numpy(self.obs_mean).float().to(self.device)
                    obs_std_tensor = torch.from_numpy(self.obs_std).float().to(self.device)
                    predicted_obs_unnorm = predicted_obs * obs_std_tensor + obs_mean_tensor
                    target_obs_unnorm = target_obs * obs_std_tensor + obs_mean_tensor
                else:
                    predicted_obs_unnorm = predicted_obs
                    target_obs_unnorm = target_obs
                
                # 计算MSE（整体）
                mse = torch.mean((predicted_obs_unnorm - target_obs_unnorm) ** 2).item()
                total_mse += mse
                
                # 计算每一步的MSE
                mse_per_step = torch.mean((predicted_obs_unnorm - target_obs_unnorm) ** 2, dim=(0, 2)).cpu().numpy()
                total_mse_per_step += mse_per_step
                
                num_batches += 1
            
            # 计算平均指标
            avg_mse = total_mse / num_batches
            avg_mse_per_step = total_mse_per_step / num_batches
            
            # 记录指标
            eval_metrics[f'val_mse_seq{seq_len}'] = avg_mse
            
            # 记录关键步骤的MSE（第1步、中间步、最后一步）
            eval_metrics[f'val_mse_seq{seq_len}_step1'] = avg_mse_per_step[0]
            if seq_len > 1:
                eval_metrics[f'val_mse_seq{seq_len}_step_mid'] = avg_mse_per_step[seq_len // 2]
                eval_metrics[f'val_mse_seq{seq_len}_step_last'] = avg_mse_per_step[-1]
        
        # 使用seq_len=10作为主要验证损失
        eval_metrics['val_loss'] = eval_metrics.get('val_mse_seq10', eval_metrics.get('val_mse_seq1', 0))
        eval_metrics['val_mse'] = eval_metrics['val_loss']
        
        self.dynamics.train()
        
        return eval_metrics
    
    def _get_common_metrics(self) -> Dict[str, Any]:
        """获取通用指标"""
        _, total_time = self._timer.reset()
        train_transitions = sum(len(ep['actions']) for ep in self.train_episodes) if hasattr(self, 'train_episodes') else 0
        metrics = {
            "total_time": total_time,
            "env_steps": self.global_env_steps,
            "env_episodes": self.global_env_episodes,
            "train_episodes": len(self.train_episodes) if hasattr(self, 'train_episodes') else 0,
            "train_transitions": train_transitions,
        }
        return metrics
    
    def shutdown(self):
        """关闭工作空间"""
        if hasattr(self, 'logger'):
            self.logger.close()
    
    def save_snapshot(self, best_ckpt: bool = False):
        """保存快照"""
        snapshot_dir = self.work_dir / "snapshots"
        snapshot_dir.mkdir(parents=True, exist_ok=True)
        
        if best_ckpt:
            snapshot = snapshot_dir / "best_snapshot.pt"
        else:
            snapshot = snapshot_dir / f"{self.global_env_steps}_snapshot.pt"
        
        keys_to_save = [
            "_pretrain_step",
            "_main_loop_iterations",
            "_global_env_episode",
        ]
        payload = {k: self.__dict__[k] for k in keys_to_save}
        payload["dynamics"] = self.dynamics.state_dict()
        payload["cfg"] = self.cfg.to_dict()
        
        # 保存统计信息
        if self.cfg.data.normalize_obs:
            stats = {
                'obs_mean': self.obs_mean,
                'obs_std': self.obs_std,
                'action_mean': self.action_mean,
                'action_std': self.action_std
            }
            payload["data_statistics"] = stats
        
        with snapshot.open("wb") as f:
            torch.save(payload, f)
        
        # 同时保存为 latest
        latest_snapshot = snapshot_dir / "latest_snapshot.pt"
        shutil.copy(snapshot, latest_snapshot)
        
        print(f"快照已保存至: {snapshot}")
    
    def load_snapshot(self, path_to_snapshot_to_load: Optional[str] = None):
        """加载快照"""
        if path_to_snapshot_to_load is None:
            path_to_snapshot_to_load = (
                self.work_dir / "snapshots" / "latest_snapshot.pt"
            )
        else:
            path_to_snapshot_to_load = Path(path_to_snapshot_to_load)
        
        if not path_to_snapshot_to_load.is_file():
            raise ValueError(
                f"提供的文件 '{str(path_to_snapshot_to_load)}' 不是快照文件。"
            )
        
        with path_to_snapshot_to_load.open("rb") as f:
            payload = torch.load(f, map_location="cpu", weights_only=False)
        
        self.dynamics.load_state_dict(payload.pop("dynamics"))
        
        # 加载数据统计信息
        if "data_statistics" in payload:
            stats = payload.pop("data_statistics")
            self.obs_mean = stats.get('obs_mean')
            self.obs_std = stats.get('obs_std')
            self.action_mean = stats.get('action_mean')
            self.action_std = stats.get('action_std')
            print(f"已加载数据统计信息:")
            if self.obs_mean is not None:
                print(f"  观测均值: {self.obs_mean[:5]}... (显示前5维)")
                print(f"  观测标准差: {self.obs_std[:5]}...")
            if self.action_mean is not None:
                print(f"  动作均值: {self.action_mean[:5]}...")
                print(f"  动作标准差: {self.action_std[:5]}...")
        else:
            print("警告: 快照中未找到数据统计信息，将使用默认值（None）")
            # 如果快照中没有统计信息，设置为 None（可能是在 normalize_obs=False 时保存的）
            if not hasattr(self, 'obs_mean'):
                self.obs_mean = None
                self.obs_std = None
                self.action_mean = None
                self.action_std = None
        

        # 加载其他状态
        for k, v in payload.items():
            if k in self.__dict__ and k != "cfg":
                self.__dict__[k] = v
        
        # cfg = payload.pop("cfg")
        # remove hidden dims key
        # self.cfg = Config.from_dict(cfg)
        # self.cfg.model.hidden_dims = [256, 256, 256]
        print(f"快照已从 {path_to_snapshot_to_load} 加载")
    
    # ==================== State Deviation 估计方法 ====================
    
    def _compute_chunk_eef_interpolation_with_world_model(
        self, obs_history: List[np.ndarray], rgb_obs_history: Optional[List[np.ndarray]],
        action_chunk: np.ndarray, ds: int, min_ds: int = 1, minimum_decay_steps: int = 2
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Predict the end-effector trajectory of a downsampled action chunk with the world model.

        The chunk is merged into fewer actions in two phases (stride ``ds``
        over the leading part, stride ``min_ds`` over the rest), the merged
        actions are rolled out with ``dynamics.predict_all`` from the current
        observation, and the predicted end-effector states are interpolated
        over the time indices ``0 .. len(action_chunk)``.

        Args:
            obs_history: Past low-dimensional observations; the last element
                is the current observation [obs_dim]. Elements 0-2, 3-5 and 6
                of the observation are read as position, axis-angle rotation
                and gripper state.
            rgb_obs_history: Past RGB observations (last element is current),
                [num_cameras, H, W, C] or [H, W, C], or None.
            action_chunk: Action chunk [T, action_dim].
            ds: Downsampling rate used in the first phase.
            min_ds: Minimum downsampling rate, used for the remaining actions.
            minimum_decay_steps: Number of trailing actions of the chunk that
                are excluded from the first (stride ``ds``) phase.

        Returns:
            pos_interp: Interpolated position trajectory.
            quat_interp: Interpolated quaternion trajectory.
            gripper_interp: Interpolated gripper state.
            ds_actions: Array of the merged (downsampled) actions.
            The interpolated outputs are evaluated at ``len(action_chunk) + 1``
            target indices; their exact shapes depend on
            ``interpolate_pos_quat`` -- TODO confirm.
        """
        self.dynamics.eval()
        
        # Current observation (last element of the history list)
        current_obs = obs_history[-1]  # [obs_dim]
        # Prepare the initial observation (with framestack support)
        framestack = self.cfg.model.framestack
        # Normalise the observation history
        if self.cfg.data.normalize_obs and self.obs_mean is not None:
            obs_history_normalized = [(obs - self.obs_mean) / self.obs_std for obs in obs_history]
        else:
            obs_history_normalized = obs_history
        
        # Apply framestack: use the observation history, padding with the first frame if it is too short
        if framestack > 1:
            # Need at least `framestack` history frames
            if len(obs_history_normalized) >= framestack:
                obs_seq_for_stack = obs_history_normalized[-framestack:]  # take the last `framestack` frames
            else:
                # Not enough history: pad at the front with the first frame
                first_frame = obs_history_normalized[0]
                obs_seq_for_stack = [first_frame] * (framestack - len(obs_history_normalized)) + obs_history_normalized
            
            # Stack along the feature axis
            init_obs_stacked = np.concatenate(obs_seq_for_stack, axis=-1)  # [obs_dim * framestack]
        else:
            init_obs_stacked = obs_history_normalized[-1]
        
        # Prepare the RGB observation (with framestack support)
        # NOTE(review): init_rgb_tensor built below is never passed to
        # predict_all in this method -- confirm whether this is dead code.
        if rgb_obs_history is not None and len(rgb_obs_history) > 0 and self.cfg.model.use_pixels:
            # Convert to [num_cameras, C, H, W] layout
            rgb_obs_list = []
            for rgb_obs in rgb_obs_history:
                if rgb_obs.shape[-1] == 3:  # [H, W, C] or [num_cameras, H, W, C]
                    if len(rgb_obs.shape) == 3:
                        rgb_obs = rgb_obs.transpose(2, 0, 1)[None]  # [1, C, H, W]
                    else:
                        rgb_obs = rgb_obs.transpose(0, 3, 1, 2)  # [num_cameras, C, H, W]
                rgb_obs_list.append(rgb_obs)
            
            # Apply framestack: use the RGB history, padding with the first frame if it is too short
            if framestack > 1:
                if len(rgb_obs_list) >= framestack:
                    rgb_seq_for_stack = rgb_obs_list[-framestack:]  # take the last `framestack` frames
                else:
                    # Not enough history: pad at the front with the first frame
                    first_rgb_frame = rgb_obs_list[0]
                    rgb_seq_for_stack = [first_rgb_frame] * (framestack - len(rgb_obs_list)) + rgb_obs_list
                
                # Stack along the channel axis
                init_rgb_stacked = np.concatenate(rgb_seq_for_stack, axis=1)  # [num_cameras, C * framestack, H, W]
            else:
                init_rgb_stacked = rgb_obs_list[-1]
            
            init_rgb_tensor = torch.from_numpy(init_rgb_stacked).float().unsqueeze(0).to(self.device)  # [1, V, C*framestack, H, W]
        else:
            init_rgb_tensor = None
        
        # Initial EEF state, taken from the (un-normalised) current observation
        init_pos, init_angle, init_gripper = current_obs[:3], current_obs[3:6], current_obs[6]
        
        # Time index 0 corresponds to the initial state; one index is appended per merged action
        all_time_indices = [0]
        ds_actions = []
        
        # Controller configuration (copied from adads3.py)
        
        # Phase 1 only covers action indices below this bound
        remove_last_length = len(action_chunk) - minimum_decay_steps
        
        # Phase 1: collect actions merged with stride `ds`, recording their time indices
        last_t = -1
        for t in range(ds - 1, remove_last_length, ds):
            # Merge the `ds` actions ending at index t into one action
            # (presumably a composition of delta actions in controller space -- verify against utils)
            action_batch = action_chunk[t - ds + 1:t + 1]  # [ds, action_dim]
            action_batch_raw = normalize_action_controller(action_batch)
            action_raw = merge_delta_actions(action_batch_raw)
            action = unnormalize_action_controller(action_raw[None, :])[0]
            ds_actions.append(action)
            all_time_indices.append(min(t + 1, remove_last_length))  # time index after the action has been executed
            last_t = t
        
        # Phase 2: collect the remaining actions with stride `min_ds`, recording their time indices
        for i in range(max(0, last_t + min_ds), len(action_chunk), min_ds):
            action_batch = action_chunk[i - min_ds + 1:i + 1]
            action_batch_raw = normalize_action_controller(action_batch)
            action_raw = merge_delta_actions(action_batch_raw)
            action = unnormalize_action_controller(action_raw[None, :])[0]
            ds_actions.append(action)
            all_time_indices.append(min(i + 1, len(action_chunk)))  # time index after the action has been executed
        

        # Phase 3: feed the whole merged action sequence to the world model in one call
        ds_actions_array = np.array(ds_actions)  # [num_actions, action_dim]
        if self.cfg.data.normalize_obs and self.action_mean is not None:
            ds_actions_array = (ds_actions_array - self.action_mean) / self.action_std
        else:
            # No-op branch: the actions are used as they are
            ds_actions_array = ds_actions_array
        action_sequence = torch.from_numpy(ds_actions_array).float().unsqueeze(0).to(self.device)  # [1, T, action_dim]
        obs_tensor = torch.from_numpy(init_obs_stacked[None, :]).float().to(self.device)  # [1, obs_dim*framestack]
        
        # Roll out the whole sequence with predict_all
        with torch.no_grad():
            predicted_obs_seq = self.dynamics.predict_all(
                obs=obs_tensor,
                action_sequence=action_sequence
            )  # [1, T, obs_dim]
        predicted_obs = predicted_obs_seq[0].cpu().numpy()
        # Post-process the predicted sequence
        
        # Map the predictions back to un-normalised observation units
        if self.cfg.data.normalize_obs and self.obs_mean is not None:
            predicted_obs_unnorm = predicted_obs * self.obs_std + self.obs_mean
        else:
            predicted_obs_unnorm = predicted_obs.copy()
        # Extract the EEF state and prepend the initial state so the arrays line up with all_time_indices
        pos, angle, gripper = predicted_obs_unnorm[:, :3], predicted_obs_unnorm[:, 3:6], predicted_obs_unnorm[:, 6]
        position_actuals = np.concatenate((init_pos[None, :], pos), axis=0)
        angle_actuals = np.concatenate((init_angle[None, :], angle), axis=0)
        gripper_actuals = np.concatenate((init_gripper[None], gripper), axis=0)
        quat_actuals = np.array([axisangle2quat(angle) for angle in angle_actuals])
        
        # Interpolate onto every time step of the original chunk (indices 0 .. len(action_chunk))
        target_indices = np.arange(len(action_chunk)+1)
        pos_interp, quat_interp, gripper_interp = interpolate_pos_quat(
            position_actuals, quat_actuals, gripper_actuals, all_time_indices, target_indices
        )
        
        return pos_interp, quat_interp, gripper_interp, np.array(ds_actions)
    
    def _select_downsample_rate(
        self, obs_history: List[np.ndarray], rgb_obs_history: Optional[List[np.ndarray]],
        action_chunk: np.ndarray, downsample_rates: List[int],
        manipulation_threshold: float, movement_threshold: float,
        threshold: Optional[float] = None,
        minimum_decay_steps: int = 2
    ) -> Tuple[int, np.ndarray]:
        """选择降采样率
        
        Args:
            obs_history: 历史低维观测列表，最后一个元素是当前观测
            rgb_obs_history: 历史 RGB 观测列表，最后一个元素是当前观测
            action_chunk: 动作块
            downsample_rates: 降采样率列表
            manipulation_threshold: 操作阶段阈值
            movement_threshold: 移动阶段阈值
            threshold: 阈值（如果为 None 则自动检测）
        
        Returns:
            selected_ds: 选择的降采样率
            selected_actions: 降采样后的动作列表
        """
        downsample_list = sorted(list(downsample_rates), reverse=True)
        min_ds = min(downsample_list)
        
        if len(action_chunk) < min_ds:
            return 1, action_chunk
        
        # 计算 baseline（使用最小降采样率）
        baseline_pos_interp, baseline_quat_interp, baseline_gripper_interp, baseline_actions = \
            self._compute_chunk_eef_interpolation_with_world_model(
                obs_history, rgb_obs_history, action_chunk, min_ds, min_ds, minimum_decay_steps
            )
        
        use_error = threshold < 1
        
        # 尝试不同的 ds，从大到小
        selected_ds = min_ds
        selected_actions = baseline_actions
        
        for ds in downsample_list:
            if ds > len(action_chunk):
                continue
            if ds != min_ds:
                chunk_pos_interp, chunk_quat_interp, chunk_gripper_interp, chunk_actions = \
                    self._compute_chunk_eef_interpolation_with_world_model(
                        obs_history, rgb_obs_history, action_chunk, ds, min_ds
                    )
                chunk_eef_error = compute_eef_error(
                    chunk_pos_interp, chunk_quat_interp,
                    baseline_pos_interp, baseline_quat_interp
                )
                print(chunk_eef_error)
                if chunk_eef_error <= threshold:
                    selected_ds = ds
                    selected_actions = chunk_actions
                    break
        
        return selected_ds, selected_actions
    
    def compute_state_deviation(
        self,
        obs_history: List[np.ndarray],
        rgb_obs_history: Optional[List[np.ndarray]],
        action_chunk: np.ndarray,
        min_ds: int = 1,
        minimum_decay_steps: int = 2
    ) -> Dict[str, Any]:
        """Measure how far a ds=2 rollout drifts from the baseline rollout.

        Both rollouts go through
        ``_compute_chunk_eef_interpolation_with_world_model``; the EEF error of
        the ds=2 trajectory relative to the ``min_ds`` baseline is then reported
        in two aggregation modes.

        Args:
            obs_history: Low-dimensional observation history; the last element
                is the current observation.
            rgb_obs_history: RGB observation history (optional); the last
                element is the current observation.
            action_chunk: Action chunk to evaluate.
            min_ds: Baseline downsample rate, 1 by default.
            minimum_decay_steps: Forwarded to both rollouts, 2 by default.

        Returns:
            Dict with:
                - deviation_mean: ``compute_eef_error(..., mode="mean")`` as float
                - deviation_max: ``compute_eef_error(..., mode="max")`` as float
        """
        # Remember the mode so a model that is serving in eval mode is not
        # silently switched to train mode. Defaults to True (the historical
        # behaviour of always calling train()) if the attribute is missing.
        was_training = getattr(self.dynamics, "training", True)
        self.dynamics.eval()
        try:
            # Baseline rollout (ds = min_ds).
            baseline_pos_interp, baseline_quat_interp, _, _ = \
                self._compute_chunk_eef_interpolation_with_world_model(
                    obs_history, rgb_obs_history, action_chunk, min_ds, min_ds, minimum_decay_steps
                )

            # Rollout at ds = 2.
            ds2_pos_interp, ds2_quat_interp, _, _ = \
                self._compute_chunk_eef_interpolation_with_world_model(
                    obs_history, rgb_obs_history, action_chunk, 2, min_ds, minimum_decay_steps
                )

            deviation_mean = compute_eef_error(
                ds2_pos_interp, ds2_quat_interp,
                baseline_pos_interp, baseline_quat_interp, mode="mean"
            )
            deviation_max = compute_eef_error(
                ds2_pos_interp, ds2_quat_interp,
                baseline_pos_interp, baseline_quat_interp, mode="max"
            )
        finally:
            # Restore the previous mode even when a rollout raised.
            if was_training:
                self.dynamics.train()

        return {
            'deviation_mean': float(deviation_mean),
            'deviation_max': float(deviation_max),
        }

    
    def start_server(
        self,
        host: str = "0.0.0.0",
        port: int = 8888,
        downsample_rates: List[int] = None,
        manipulation_threshold: float = 0.002,
        movement_threshold: float = 0.05,
        threshold: Optional[float] = None,
        minimum_decay_steps: int = 2,
    ):
        """Serve this workspace through the AdaDS agent WebSocket server.

        Thin convenience wrapper: every argument is handed unchanged to the
        module-level ``start_adads_agent_server`` together with ``self`` as the
        workspace.

        Args:
            host: Address the server binds to ("0.0.0.0" by default).
            port: Port the server listens on (8888 by default).
            downsample_rates: Candidate downsample rates; ``None`` is passed
                through as-is.
            manipulation_threshold: Manipulation-phase threshold (default 0.002).
            movement_threshold: Movement-phase threshold (default 0.05).
            threshold: Error threshold, ``None`` by default.
            minimum_decay_steps: Minimum decay steps (default 2).

        Example:
            workspace = DynamicsWorkspace(cfg, work_dir="...", train=False)
            workspace.load_snapshot()
            workspace.start_server(host="0.0.0.0", port=8888, downsample_rates=[1, 2])
        """
        server_options = {
            "host": host,
            "port": port,
            "downsample_rates": downsample_rates,
            "manipulation_threshold": manipulation_threshold,
            "movement_threshold": movement_threshold,
            "threshold": threshold,
            "minimum_decay_steps": minimum_decay_steps,
        }
        start_adads_agent_server(workspace=self, **server_options)
    
    def collect_transitions_with_different_ds(
        self,
        task_suite_name: str = "libero_spatial",
        num_episodes: int = 10,
        downsample_rates: List[int] = [1, 2, 3],
        dataset_name: str = "libero_spatial",
        num_steps_wait: int = 10,
        seed: int = 7,
        resize_size: int = 224,
        control_freq: int = 20,
        original_control_freq: int = 20,
        policy_host: str = "0.0.0.0",
        policy_port: int = 8000,
    ) -> Optional[Dict[int, List[Dict[str, Any]]]]:
        """Roll out the remote policy at several downsample rates and save the transitions.

        For every rate in ``downsample_rates`` and every task of the suite,
        ``num_episodes`` episodes are executed. The results are written to
        ``exp_output/<dataset_name>/transitions`` as one
        ``transitions_ds<ds>.pkl`` per rate plus a ``metadata.json``.

        Args:
            task_suite_name: Task suite name; must be a key of the max-step table below.
            num_episodes: Episodes collected per task and per downsample rate.
            downsample_rates: Downsample rates to collect (never mutated here).
            dataset_name: Sub-directory of ``exp_output`` used for the output.
            num_steps_wait: Dummy-action steps executed after each reset.
            seed: Seed for NumPy and the environments.
            resize_size: Side length of the images sent to the policy.
            control_freq: Control frequency passed to the environment.
            original_control_freq: Reference frequency; step budgets are scaled
                by ``control_freq // original_control_freq``.
            policy_host: Policy server host.
            policy_port: Policy server port.

        Returns:
            ``{ds: [episode_dict, ...]}`` for the collected data, or ``None``
            when the output directory already holds data and collection is skipped.

        Raises:
            ValueError: If ``task_suite_name`` has no max-step entry.
        """
        import pickle

        save_dir = Path("exp_output") / dataset_name / "transitions"

        # Only a non-empty directory counts as "already collected"; an empty one
        # is the leftover of a run that crashed before saving and is reused.
        if save_dir.exists() and (not save_dir.is_dir() or any(save_dir.iterdir())):
            print("Transitions already exist")
            return None

        save_dir.mkdir(parents=True, exist_ok=True)

        repeats = control_freq // original_control_freq
        np.random.seed(seed)

        # LIBERO task suite
        benchmark_dict = benchmark.get_benchmark_dict()
        task_suite = benchmark_dict[task_suite_name]()
        num_tasks_in_suite = task_suite.n_tasks

        libero_dummy_action = [0.0] * 6 + [-1.0]
        libero_env_resolution = 256

        policy_client = websocket_client_policy.WebsocketClientPolicy(policy_host, policy_port)
        apply_patches()

        # Per-suite step budget
        max_steps_dict = {
            "libero_spatial": 220,
            "libero_object": 280,
            "libero_goal": 300,
            "libero_10": 520,
            "libero_90": 400,
        }
        if task_suite_name not in max_steps_dict:
            raise ValueError(f"Unknown task suite: {task_suite_name}")
        max_steps = max_steps_dict[task_suite_name] * repeats
        num_steps_wait = num_steps_wait * repeats

        # RGB frames are recorded only when the model consumes pixels. Deciding
        # this once keeps the current/next frame lists in lock-step.
        collect_rgb = bool(self.cfg.model.use_pixels)

        transitions_by_ds = {ds: [] for ds in downsample_rates}
        for ds in downsample_rates:
            if ds == 1:
                revert_patches()
            else:
                apply_patches()

            for task_id in tqdm(range(num_tasks_in_suite), desc="Collecting transitions"):
                task = task_suite.get_task(task_id)
                initial_states = task_suite.get_task_init_states(task_id)
                task_description = task.language

                task_bddl_file = Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
                env_args = {
                    "bddl_file_name": task_bddl_file,
                    "camera_heights": libero_env_resolution,
                    "camera_widths": libero_env_resolution,
                    "control_freq": control_freq
                }
                env = OffScreenRenderEnv(**env_args)
                try:
                    env.seed(seed)

                    for episode_idx in tqdm(range(num_episodes), desc=f"Task {task_id} DS={ds}", leave=False):
                        env.reset()
                        obs = env.set_init_state(initial_states[episode_idx])

                        # Let the scene settle before the policy takes over.
                        for _ in range(num_steps_wait):
                            obs, _, _, _ = env.step(libero_dummy_action)

                        (episode_obs, episode_next_obs, episode_actions,
                         episode_rgb_obs, episode_next_rgb_obs) = self._rollout_transition_episode(
                            env, policy_client, obs, task_description, ds,
                            max_steps, resize_size, collect_rgb,
                        )

                        episode_actions = np.array(episode_actions, dtype=np.float32)
                        episode_obs = np.array(episode_obs, dtype=np.float32)
                        episode_next_obs = np.array(episode_next_obs, dtype=np.float32)

                        if len(episode_obs) > 0:
                            transition = {
                                'obs': episode_obs,
                                'next_obs': episode_next_obs,
                                'actions': episode_actions,
                                'task_id': task_id,
                                'task_description': task_description,
                                'episode_idx': episode_idx,
                            }
                            if len(episode_rgb_obs) > 0:
                                transition['rgb_obs'] = np.array(episode_rgb_obs, dtype=np.uint8)
                                transition['next_rgb_obs'] = np.array(episode_next_rgb_obs, dtype=np.uint8)

                            transitions_by_ds[ds].append(transition)
                finally:
                    # One environment is created per task and rate; release it
                    # so they do not pile up over the whole collection.
                    env.close()

        # One pickle per downsample rate
        for ds, transitions in transitions_by_ds.items():
            save_path = save_dir / f"transitions_ds{ds}.pkl"
            with open(save_path, 'wb') as f:
                pickle.dump(transitions, f)
            print(f"已保存 {len(transitions)} 个 DS={ds} 的 transitions 到 {save_path}")

        # Metadata
        metadata = {
            'dataset_name': dataset_name,
            'downsample_rates': list(transitions_by_ds.keys()),
            'num_transitions_by_ds': {str(ds): len(transitions) for ds, transitions in transitions_by_ds.items()},
        }
        metadata_path = save_dir / "metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        print(f"已保存元数据到 {metadata_path}")

        return transitions_by_ds

    def _rollout_transition_episode(
        self, env, policy_client, obs, task_description, ds,
        max_steps, resize_size, collect_rgb
    ):
        """Run one policy-driven episode and return the per-step records as lists.

        Returns ``(obs, next_obs, actions, rgb_obs, next_rgb_obs)``; the two RGB
        lists stay empty unless ``collect_rgb`` is true.
        """
        episode_obs = []
        episode_next_obs = []
        episode_actions = []
        episode_rgb_obs = []
        episode_next_rgb_obs = []

        t = 0
        done = False

        while t < max_steps and not done:
            # Both camera images are flipped along both spatial axes.
            img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
            wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])

            img_processed = image_tools.convert_to_uint8(
                image_tools.resize_with_pad(img, resize_size, resize_size)
            )
            wrist_img_processed = image_tools.convert_to_uint8(
                image_tools.resize_with_pad(wrist_img, resize_size, resize_size)
            )

            # Low-dimensional state: EEF position, axis-angle, gripper qpos.
            state = np.concatenate((
                obs["robot0_eef_pos"],
                quat2axisangle(obs["robot0_eef_quat"]),
                obs["robot0_gripper_qpos"],
            ))
            element = {
                "observation/image": img_processed,
                "observation/wrist_image": wrist_img_processed,
                "observation/state": state,
                "prompt": str(task_description),
            }

            action_chunk = policy_client.infer(element)["actions"]

            if ds == 1:
                actions_to_execute = action_chunk
            else:
                actions_to_execute = self._merge_action_chunk(action_chunk, ds)

            # The whole chunk is executed even if that overshoots max_steps;
            # only `done` interrupts it.
            for action in actions_to_execute:
                current_state = state.copy()
                current_rgb = np.stack([img, wrist_img], axis=0) if collect_rgb else None

                next_obs, reward, done, info = env.step(action.tolist())

                next_img = np.ascontiguousarray(next_obs["agentview_image"][::-1, ::-1])
                next_wrist_img = np.ascontiguousarray(next_obs["robot0_eye_in_hand_image"][::-1, ::-1])
                next_state = np.concatenate((
                    next_obs["robot0_eef_pos"],
                    quat2axisangle(next_obs["robot0_eef_quat"]),
                    next_obs["robot0_gripper_qpos"],
                ))

                episode_obs.append(current_state)
                episode_next_obs.append(next_state)
                episode_actions.append(action)
                if current_rgb is not None:
                    episode_rgb_obs.append(current_rgb)
                    episode_next_rgb_obs.append(np.stack([next_img, next_wrist_img], axis=0))

                obs = next_obs
                state = next_state
                img = next_img
                wrist_img = next_wrist_img
                t += 1

                if done:
                    break

        return episode_obs, episode_next_obs, episode_actions, episode_rgb_obs, episode_next_rgb_obs

    @staticmethod
    def _merge_action_chunk(action_chunk, ds):
        """Merge every ``ds`` consecutive actions of the chunk into one action."""
        merged_actions = []
        for i in range(0, len(action_chunk), ds):
            action_batch_raw = normalize_action_controller(action_chunk[i:i + ds])
            action_raw = merge_delta_actions(action_batch_raw)
            merged_actions.append(unnormalize_action_controller(action_raw[None, :])[0])
        return merged_actions

    
    def eval_world_model_on_transitions(
        self,
        dataset_name: str,
        downsample_rates: List[int] = [1, 2, 3],
        verbose: bool = True,
    ) -> Dict[str, Any]:
        """Evaluate the world model on previously collected transitions.

        For every rate, ``exp_output/<dataset_name>/transitions/transitions_ds<ds>.pkl``
        is loaded. Each episode is cut into every window of
        ``cfg.training.seq_len`` consecutive steps; the model is rolled out with
        ``predict_all`` from the window's first (frame-stacked) observation, and
        the prediction is compared with the recorded ``next_obs`` of the same steps.

        Args:
            dataset_name: Dataset name (sub-directory of ``exp_output``).
            downsample_rates: Rates to evaluate; rates without a file are
                skipped with a warning.
            verbose: Print progress and a per-rate summary.

        Returns:
            ``{"ds_<ds>": {...}}`` with ``avg_mae``, ``num_episodes``,
            ``num_transitions`` (number of evaluated windows) and
            ``episode_errors`` for every evaluated rate.

        Raises:
            ValueError: If the transitions directory does not exist.
        """
        import pickle

        transitions_dir = Path("exp_output") / dataset_name / "transitions"
        if not transitions_dir.exists():
            raise ValueError(f"Transitions 目录不存在: {transitions_dir}")

        self.dynamics.eval()

        results = {}

        for ds in downsample_rates:
            transitions_path = transitions_dir / f"transitions_ds{ds}.pkl"
            if not transitions_path.exists():
                print(f"警告: 未找到 DS={ds} 的 transitions 文件: {transitions_path}")
                continue

            # NOTE: pickle can execute arbitrary code while loading; only point
            # this at files produced by a trusted collection run.
            with open(transitions_path, 'rb') as f:
                transitions = pickle.load(f)

            if verbose:
                print(f"\n评估 DS={ds} 的 transitions ({len(transitions)} 个 episodes)...")

            total_mae = 0.0
            total_transitions = 0
            episode_errors = []
            seq_len = self.cfg.training.seq_len
            framestack = self.cfg.model.framestack
            normalize_obs = self.cfg.data.normalize_obs and self.obs_mean is not None
            normalize_actions = self.cfg.data.normalize_obs and self.action_mean is not None

            for transition in tqdm(transitions, desc=f"Evaluating DS={ds}"):
                obs = transition['obs']
                next_obs = transition['next_obs']
                actions = transition['actions']

                # Episodes shorter than one window cannot be evaluated.
                if len(actions) < seq_len:
                    continue

                obs_normalized = (obs - self.obs_mean) / self.obs_std if normalize_obs else obs

                # Optional RGB frames, moved to channel-first when the last axis has 3 entries.
                rgb_obs_array = None
                if 'rgb_obs' in transition and self.cfg.model.use_pixels:
                    rgb_obs_array = transition['rgb_obs']
                    if rgb_obs_array.shape[-1] == 3:
                        rgb_obs_array = rgb_obs_array.transpose(0, 1, 4, 2, 3)

                # Every window of seq_len consecutive steps is evaluated.
                num_sequences = len(actions) - seq_len + 1
                episode_mae = 0.0
                episode_transitions = 0

                for start_idx in range(num_sequences):
                    # Inputs are the observations of steps start_idx .. start_idx+seq_len-1
                    # (plus history). Slicing one element further and dropping it again,
                    # as was done before, was truncated by the episode end on the last
                    # window and shifted that window by one step.
                    obs_window = self._history_padded_window(
                        obs_normalized, start_idx, seq_len, framestack
                    )
                    seq_obs_stacked = self._apply_framestack(
                        obs_window,
                        framestack,
                        first_frame=None
                    )[-seq_len:]

                    # Targets come straight from next_obs (original scale), which
                    # also exists for the final step of the episode.
                    next_obs_original = next_obs[start_idx:start_idx + seq_len]

                    seq_actions = actions[start_idx:start_idx + seq_len]
                    if normalize_actions:
                        seq_actions_normalized = (seq_actions - self.action_mean) / self.action_std
                    else:
                        seq_actions_normalized = seq_actions

                    first_obs_tensor = torch.from_numpy(seq_obs_stacked[0:1]).float().to(self.device)
                    actions_tensor = torch.from_numpy(seq_actions_normalized).float().unsqueeze(0).to(self.device)

                    # NOTE(review): first_rgb is prepared but not handed to predict_all
                    # below; confirm whether predict_all should receive it when
                    # cfg.model.use_pixels is set.
                    first_rgb = None
                    if rgb_obs_array is not None:
                        rgb_window = self._history_padded_window(
                            rgb_obs_array, start_idx, seq_len, framestack
                        )
                        rgb_seq_stacked = self._apply_framestack(
                            rgb_window,
                            framestack,
                            first_frame=None
                        )[-seq_len:]
                        first_rgb = torch.from_numpy(rgb_seq_stacked[0:1]).float().to(self.device)

                    with torch.no_grad():
                        predicted_seq = self.dynamics.predict_all(
                            obs=first_obs_tensor,
                            action_sequence=actions_tensor
                        )

                    predicted_obs_seq = predicted_seq[0].cpu().numpy()
                    if normalize_obs:
                        # Bring the prediction back to the original scale.
                        predicted_obs_seq = predicted_obs_seq * self.obs_std + self.obs_mean

                    # MAE without the last two columns (gripper, per the original comment).
                    seq_mae = np.mean(np.abs(predicted_obs_seq - next_obs_original)[:, :-2])

                    episode_mae += seq_mae
                    episode_transitions += 1
                    total_mae += seq_mae
                    total_transitions += 1

                if episode_transitions > 0:
                    avg_episode_mae = episode_mae / episode_transitions
                    episode_errors.append({
                        'mae': float(avg_episode_mae),
                        'num_sequences': episode_transitions,
                        'num_steps': len(actions),
                    })

            avg_mae = total_mae / total_transitions if total_transitions > 0 else 0.0

            results[f'ds_{ds}'] = {
                'avg_mae': float(avg_mae),
                'num_episodes': len(transitions),
                'num_transitions': total_transitions,
                'episode_errors': episode_errors,
            }

            if verbose:
                print(f"DS={ds} 评估结果:")
                print(f"  平均 MAE: {avg_mae:.6f}")
                print(f"  Episodes: {len(transitions)}")
                print(f"  Transitions: {total_transitions}")

        return results

    @staticmethod
    def _history_padded_window(frames, start_idx, seq_len, framestack):
        """Return frames for steps [start_idx, start_idx+seq_len) preceded by up to framestack-1 history frames.

        When ``0 < start_idx < framestack - 1`` the missing history is filled
        with copies of frame 0; for ``start_idx == 0`` nothing is prepended.
        """
        history_start = max(0, start_idx - (framestack - 1))
        window = frames[history_start:start_idx + seq_len]
        if history_start == 0 and start_idx > 0:
            num_missing = (framestack - 1) - start_idx
            if num_missing > 0:
                padding = np.repeat(frames[0:1], num_missing, axis=0)
                window = np.concatenate([padding, window], axis=0)
        return window


def start_adads_agent_server(
    workspace,
    host: str = "0.0.0.0",
    port: int = 8888,
    downsample_rates: Optional[List[int]] = None,
    manipulation_threshold: float = 0.002,
    movement_threshold: float = 0.05,
    threshold: Optional[float] = None,
    minimum_decay_steps: int = 2,
):
    """Start a blocking WebSocket server that exposes the AdaDS agent.

    Other Python environments can connect and ask the workspace to pick a
    downsample rate or to compute a state deviation. WebSocket is used so the
    service does not conflict with HTTP proxies.

    Protocol (all binary frames are msgpack-encoded dicts):
      * On connect the server sends ``{"status": "ready", "downsample_rates": ...}``.
      * Each request is a dict whose ``"type"`` key (default ``"predict_k"``)
        selects the operation:
          - ``"health"``: replies with status and the downsample rates.
          - ``"state_deviation"``: needs ``"s_env"`` and ``"A"``; optional
            ``"min_ds"`` (default 1), ``"minimum_decay_steps"`` (default 2)
            and ``"rgb_obs"``. Replies with ``deviation_mean`` /
            ``deviation_max`` from ``workspace.compute_state_deviation``.
          - ``"predict_k"``: needs ``"s_env"`` and ``"A"``; optional
            ``"rgb_obs"``. Replies with ``predicted_k`` from
            ``workspace._select_downsample_rate``.
      * A text frame received from the client ends the connection handler.
      * If handling a request raises, the traceback is sent back as a text
        frame and the connection handler ends.

    Args:
        workspace: DynamicsWorkspace instance (expected to have its dynamics
            model loaded already).
        host: Address to bind to; "0.0.0.0" accepts external connections.
        port: TCP port to listen on.
        downsample_rates: Candidate downsample rates; defaults to [1, 2].
        manipulation_threshold: Forwarded to ``_select_downsample_rate``.
        movement_threshold: Forwarded to ``_select_downsample_rate``.
        threshold: Forwarded to ``_select_downsample_rate`` (the original
            notes say None means "detect automatically" - confirm there).
        minimum_decay_steps: Forwarded to ``_select_downsample_rate`` for
            ``predict_k`` requests.

    Raises:
        ImportError: If ``websockets`` (with the sync server API) is missing.
    """
    import traceback

    try:
        import websockets.sync.server
        from websockets.exceptions import ConnectionClosed
    except ImportError as err:
        raise ImportError(
            "websockets is required to run the server. "
            "Please install it with: pip install websockets"
        ) from err

    if downsample_rates is None:
        downsample_rates = [1, 2]

    from msgpack_numpy_utils import Packer, unpackb

    def build_obs_history(s_env, repeat):
        """Split a stacked observation into frames, or repeat a single one."""
        if s_env.ndim > 1:
            return list(s_env)
        return [s_env] * repeat

    def build_rgb_history(data, repeat):
        """Parse the optional "rgb_obs" field into a list of frames (or None)."""
        raw = data.get("rgb_obs")
        if raw is None:
            return None
        rgb_obs = np.array(raw, dtype=np.uint8)
        # Only a 4-D+ array with more than one leading entry is a frame stack.
        if rgb_obs.ndim >= 4 and rgb_obs.shape[0] > 1:
            return list(rgb_obs)
        return [rgb_obs] * repeat

    def build_response(data):
        """Dispatch one decoded request and return the response dict."""
        request_type = data.get("type", "predict_k")

        if request_type == "health":
            return {
                "status": "success",
                "downsample_rates": downsample_rates
            }

        if request_type not in ("state_deviation", "predict_k"):
            return {
                "status": "error",
                "message": f"Unknown request type: {request_type}"
            }

        if "s_env" not in data or "A" not in data:
            return {
                "status": "error",
                "message": "Request must contain 's_env' and 'A' fields"
            }

        s_env = np.array(data["s_env"], dtype=np.float32)
        A = np.array(data["A"], dtype=np.float32)

        if request_type == "state_deviation":
            min_ds = data.get("min_ds", 1)
            # Deliberately NOT named ``minimum_decay_steps``: assigning that
            # name here would shadow the server-level argument inside this
            # closure and make the predict_k path raise UnboundLocalError.
            request_decay_steps = data.get("minimum_decay_steps", 2)

            result = workspace.compute_state_deviation(
                obs_history=build_obs_history(s_env, 1),
                rgb_obs_history=build_rgb_history(data, 1),
                action_chunk=A,
                min_ds=min_ds,
                minimum_decay_steps=request_decay_steps
            )
            return {
                "status": "success",
                "deviation_mean": result['deviation_mean'],
                "deviation_max": result['deviation_max'],
            }

        # request_type == "predict_k": a single frame is repeated to fill the
        # model's framestack.
        repeat = max(1, workspace.cfg.model.framestack)
        selected_ds, _ = workspace._select_downsample_rate(
            obs_history=build_obs_history(s_env, repeat),
            rgb_obs_history=build_rgb_history(data, repeat),
            action_chunk=A,
            downsample_rates=downsample_rates,
            manipulation_threshold=manipulation_threshold,
            movement_threshold=movement_threshold,
            threshold=threshold,
            minimum_decay_steps=minimum_decay_steps
        )
        return {
            "status": "success",
            "predicted_k": int(selected_ds)
        }

    def handle_client(websocket):
        """Serve one client connection until it closes or an error occurs."""
        logging.info("Connection opened")
        packer = Packer()

        try:
            # Announce server metadata before entering the request loop.
            metadata = {
                "status": "ready",
                "downsample_rates": downsample_rates
            }
            websocket.send(packer.pack(metadata))

            while True:
                try:
                    message = websocket.recv()
                    if isinstance(message, str):
                        # Text frames are treated as an error signal.
                        logging.error(f"Received string message: {message}")
                        break

                    response = build_response(unpackb(message))
                    websocket.send(packer.pack(response))

                except ConnectionClosed:
                    # Peer went away: nothing to report and nobody to send to.
                    break

                except Exception:
                    # Report the failure to the client as a text frame, then
                    # stop serving this connection.
                    error_traceback = traceback.format_exc()
                    logging.error(f"Error handling message: {error_traceback}")
                    websocket.send(error_traceback)
                    break

        except Exception as exc:
            logging.error(f"Error in client handler: {exc}")

        logging.info("Connection closed")

    # Start the server
    uri = f"ws://{host}:{port}"
    print("启动 AdaDS Agent WebSocket 服务器...")
    print(f"服务器地址: {uri}")
    print(f"降采样率列表: {downsample_rates}")
    print("按 Ctrl+C 停止服务器")

    with websockets.sync.server.serve(handle_client, host, port, compression=None, max_size=None) as server:
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            print("\n正在关闭服务器...")

