"""
Policy Workspace - Safe IQL Policy 训练工作空间
仿照 robobase/robobase/policy_workspace.py 实现
"""

import shutil
import signal
import sys
import os
import json
import time
import copy
import random
from typing import Callable, Any, Optional, Dict, List, Tuple
from functools import partial
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from tqdm import tqdm

from config import Config
from models.world_model import WorldModel
from dataprocess import create_loader, analyze_episodes
from utils import (
    set_seed,
    Timer,
    Every,
    Until,
    Logger,
    create_output_dir,
    save_config,

    normalize_action_controller,
    unnormalize_action_controller,
    merge_delta_actions,
    batch_interpolate_eef,
    process_obs_for_libero,
    process_obs_for_real,
    compute_eef_error,
    compute_eef_error_ratio,
    
    calculate_model_size,
    calculate_loss_iteration_flops
)
import logging



def soft_update_params(net, target_net, tau):
    """Polyak-average the parameters of ``net`` into ``target_net`` in place.

    After the call every target parameter equals
    ``tau * online + (1 - tau) * target`` (using the values from before the
    call). Parameters are paired positionally via ``.parameters()``.

    Args:
        net: source (online) module.
        target_net: module whose parameters are overwritten.
        tau: interpolation weight given to ``net``.
    """
    keep = 1 - tau
    for online_param, target_param in zip(net.parameters(), target_net.parameters()):
        mixed = tau * online_param.data + keep * target_param.data
        target_param.data.copy_(mixed)


class PolicyWorkspace:
    """Safe IQL Policy 训练工作空间"""
    
    def __init__(
        self,
        cfg: Config,
        work_dir: Optional[str] = None,
        train: bool = True,
    ):
        """Set up the Safe-IQL policy workspace.

        The constructor does the following, in order:

        - seeds the RNGs and picks the torch device;
        - prepares the working directory and saves the config into it;
        - creates the logger, plus a wandb run when enabled;
        - optionally loads the demonstration data;
        - loads the dynamics (world) model snapshot.

        Args:
            cfg: experiment configuration object.
            work_dir: working directory. When None, a fresh output directory
                is created from ``cfg.log.output_dir`` and ``cfg.log.exp_name``.
            train: when True, demonstration episodes are loaded via
                ``load_data``.
        """
        self.cfg = cfg
        set_seed(cfg.training.seed)

        # Device selection: "cpu" stays on CPU; any other request maps to MPS
        # on macOS and is used verbatim elsewhere.
        requested_device = cfg.training.device
        if requested_device == "cpu":
            dev = "cpu"
        elif sys.platform == "darwin":
            dev = "mps"
        else:
            dev = requested_device
        self.device = torch.device(dev)
        print(f"使用设备: {self.device}")

        # Working directory: reuse the caller's path, or create a new output dir.
        if work_dir is not None:
            self.work_dir = Path(work_dir)
            self.work_dir.mkdir(parents=True, exist_ok=True)
        else:
            self.work_dir = create_output_dir(cfg.log.output_dir, cfg.log.exp_name)
        print(f"工作目录: {self.work_dir}")

        save_config(cfg.to_dict(), self.work_dir)

        log_cfg = cfg.log
        self.logger = Logger(
            log_dir=self.work_dir,
            use_wandb=log_cfg.use_wandb,
            use_tensorboard=log_cfg.use_tensorboard,
        )
        if log_cfg.use_wandb:
            # Imported lazily so wandb is only required when it is enabled.
            import wandb

            wandb.init(
                project=log_cfg.wandb_project,
                entity=log_cfg.wandb_entity,
                name=log_cfg.wandb_name or log_cfg.exp_name,
                config=cfg.to_dict(),
            )

        if train:
            print("\n加载数据...")
            self.load_data()

        print("\n加载 dynamics model...")
        self.load_dynamics_snapshot()

        self._timer = Timer()

        # Progress counters, all starting at zero.
        self._pretrain_step = 0
        self._main_loop_iterations = 0
        self._global_env_episode = 0

        # Best metrics tracked so far.
        self.best_metrics = {
            "best_episode_success": 0,
            "best_episode_len": 0,
        }

        # Shutdown flag (starts False).
        self._shutting_down = False
    
    def load_data(self):
        """Load demonstration episodes and record basic dataset dimensions.

        Sets the following attributes:

        - ``self.episodes``;
        - ``self.obs_dim`` and ``self.action_dim``, each 0 when the statistics
          returned by ``analyze_episodes`` lack them;
        - ``self.num_cameras``, the first entry of ``rgb_shape`` when the data
          has RGB, else 0.
        """
        data_cfg = self.cfg.data
        loader = create_loader(
            env_name=data_cfg.env_name,
            data_path=data_cfg.data_dir,
            **data_cfg.loader_kwargs
        )

        episodes = loader.load_episodes(num_episodes=data_cfg.num_demos)
        data_stats = analyze_episodes(episodes)

        # Keep the raw episodes for later transition sampling.
        self.episodes = episodes

        self.obs_dim = data_stats.get('obs_dim', 0)
        self.action_dim = data_stats.get('action_dim', 0)
        self.num_cameras = data_stats['rgb_shape'][0] if data_stats.get('has_rgb') else 0
    
    def load_dynamics_snapshot(self, path_to_dynamics_snapshot: Optional[str] = None):
        """Load the trained dynamics (world) model snapshot and its data statistics.

        The world model is rebuilt from the snapshot's own config
        (``payload["cfg"]``), so the architecture matches the saved weights.

        Sets the following attributes:

        - ``self.dynamics_cfg``;
        - ``self.dynamics``, placed in eval mode;
        - ``self.obs_mean`` / ``self.obs_std`` / ``self.action_mean`` /
          ``self.action_std``, all None when the snapshot has no
          ``data_statistics`` entry.

        Args:
            path_to_dynamics_snapshot: snapshot file to load. When None it
                defaults to ``<work_dir>/../world_model/snapshots/latest_snapshot.pt``.

        Raises:
            FileNotFoundError: if the snapshot file does not exist.
        """
        # Only fall back to the default location when no path was given
        # (previously an explicit argument was silently ignored).
        if path_to_dynamics_snapshot is None:
            path_to_dynamics_snapshot = (
                self.work_dir.parent / "world_model" / "snapshots" / "latest_snapshot.pt"
            )
        path_to_dynamics_snapshot = Path(path_to_dynamics_snapshot)

        if not path_to_dynamics_snapshot.exists():
            raise FileNotFoundError(
                f"Dynamics snapshot not found: {path_to_dynamics_snapshot}"
            )

        # weights_only=False unpickles arbitrary objects (config dict, statistics):
        # only load snapshots from trusted sources.
        # map_location="cpu" lets snapshots saved on CUDA load on CPU/MPS hosts;
        # load_state_dict below copies the weights into the model's parameters.
        with path_to_dynamics_snapshot.open("rb") as f:
            payload = torch.load(f, map_location="cpu", weights_only=False)

        # Prefer the snapshot's config: the weights were trained with these parameters.
        self.dynamics_cfg = Config.from_dict(payload.pop("cfg", {}))
        model_cfg = self.dynamics_cfg.model
        training_cfg = self.dynamics_cfg.training

        # Non-positive obs_dim is clamped to 0.
        actual_obs_dim = model_cfg.obs_dim if model_cfg.obs_dim > 0 else 0
        # With pixels, framestack frames are stacked along the channel axis.
        image_channels = (
            model_cfg.image_channels * model_cfg.framestack
            if model_cfg.use_pixels
            else model_cfg.image_channels
        )
        # Older snapshot configs may predate the DINOv2 options; use defaults then.
        use_dinov2 = getattr(model_cfg, 'use_dinov2', False)
        dinov2_model_type = getattr(model_cfg, 'dinov2_model_type', 'dinov2_vits14')

        self.dynamics = WorldModel(
            obs_dim=actual_obs_dim,
            action_dim=model_cfg.action_dim,
            use_pixels=model_cfg.use_pixels,
            image_channels=image_channels,
            image_size=model_cfg.image_size,
            num_cameras=model_cfg.num_cameras,
            use_dinov2=use_dinov2,
            dinov2_model_type=dinov2_model_type,
            dinov2_visual_feature_dim=getattr(model_cfg, 'dinov2_visual_feature_dim', 64),
            dinov2_mlp_hidden_dims=getattr(model_cfg, 'dinov2_mlp_hidden_dims', [256, 64]),
            dinov2_use_cls_token=getattr(model_cfg, 'dinov2_use_cls_token', True),
            dinov2_dropout=getattr(model_cfg, 'dinov2_dropout', 0.0),
            hidden_dim=model_cfg.hidden_dim,
            rnn_num_layers=model_cfg.rnn_num_layers,
            dropout=model_cfg.dropout,
            learning_rate=training_cfg.learning_rate,
            weight_decay=training_cfg.weight_decay,
            grad_clip=training_cfg.grad_clip,
            use_symlog=model_cfg.use_symlog,
            use_var=model_cfg.use_var,
            use_residual=model_cfg.use_residual,
            framestack=model_cfg.framestack,
            device=str(self.device)
        )

        self.dynamics.load_state_dict(payload.pop("dynamics"))
        self.dynamics.train(False)

        # Normalization statistics saved alongside the model (if any).
        if "data_statistics" in payload:
            stats = payload.pop("data_statistics")
            self.obs_mean = stats.get('obs_mean')
            self.obs_std = stats.get('obs_std')
            self.action_mean = stats.get('action_mean')
            self.action_std = stats.get('action_std')
            print(f"已加载数据统计信息:")
            if self.obs_mean is not None:
                print(f"  观测均值: {self.obs_mean[:5]}... (显示前5维)")
                print(f"  观测标准差: {self.obs_std[:5]}...")
            if self.action_mean is not None:
                print(f"  动作均值: {self.action_mean[:5]}...")
                print(f"  动作标准差: {self.action_std[:5]}...")
        else:
            print("警告: 快照中未找到数据统计信息，将使用默认值（None）")
            # Possibly saved with normalize_obs=False; downstream code checks for None.
            self.obs_mean = None
            self.obs_std = None
            self.action_mean = None
            self.action_std = None

        print(f"\nDynamics model loaded from {path_to_dynamics_snapshot}")
        print(f"模型参数量: {sum(p.numel() for p in self.dynamics.parameters()):,}")
        if model_cfg.framestack > 1:
            print(f"Framestack: {model_cfg.framestack} (实际输入维度: obs_dim={actual_obs_dim}, image_channels={image_channels})")
        if model_cfg.use_pixels and use_dinov2:
            print(f"使用 DINOv2 编码器: {dinov2_model_type}")
    
    def _apply_framestack(self, obs_sequence: np.ndarray, framestack: int, first_frame: Optional[np.ndarray] = None) -> np.ndarray:
        """对观测序列应用 framestack（参考 dynamics_workspace.py）
        
        Args:
            obs_sequence: [T, ...] 观测序列
            framestack: 堆叠的帧数
            first_frame: 第一帧观测（用于补全历史，如果为 None 则使用 obs_sequence[0]）
        
        Returns:
            stacked_obs: [T, ...*framestack] 堆叠后的观测
        """
        if framestack == 1:
            return obs_sequence
        
        T = len(obs_sequence)
        if T == 0:
            return obs_sequence
        
        # 如果没有提供 first_frame，使用序列的第一帧
        if first_frame is None:
            first_frame = obs_sequence[0]
        
        stacked_obs = []
        
        for t in range(T):
            # 获取前 framestack 帧（从当前帧往前回溯）
            frames = []
            for i in range(framestack):
                idx = t - (framestack - 1 - i)
                if idx < 0:
                    # 历史帧不足，用第一帧补全
                    frames.append(first_frame)
                else:
                    frames.append(obs_sequence[idx])
            
            # 堆叠：对于低维观测在特征维度堆叠，对于 RGB 在通道维度堆叠
            if len(obs_sequence.shape) == 2:
                # 低维观测：[obs_dim] -> [obs_dim * framestack]
                stacked = np.concatenate(frames, axis=-1)
            elif len(obs_sequence.shape) == 5:
                # RGB 观测：[num_cameras, C, H, W] -> [num_cameras, C * framestack, H, W]
                stacked = np.concatenate(frames, axis=2)  # axis=2 是通道维度
            else:
                # 其他情况，尝试在最后一个维度堆叠
                stacked = np.concatenate(frames, axis=-1)
            
            stacked_obs.append(stacked)
        
        return np.array(stacked_obs)
    
    def get_feedback_eef_interpolated_batch(
        self, eef_model, obs: List[List[np.ndarray]], rgb_obs: Optional[List[Optional[List[np.ndarray]]]], 
        action_chunks: List[np.ndarray], ds: int, min_ds: int = 1, minimum_decay_steps: int = 2
    ):
        """Roll the world model over down-sampled action chunks and return EEF trajectories.

        Each action chunk is down-sampled with stride ``ds``:

        * ``control_mode == "abs"``: every ``ds``-th action is kept.
        * otherwise (delta actions): consecutive groups of ``ds`` actions are
          merged into one action via ``normalize_action_controller`` ->
          ``merge_delta_actions`` -> ``unnormalize_action_controller``. This
          continues while the group end stays before
          ``len(chunk) - minimum_decay_steps``. The remaining tail is merged in
          groups of ``min_ds``.

        All batch elements are rolled out in a single ``eef_model.predict``
        call. Inputs are normalized, and outputs unnormalized, with the
        snapshot statistics when the dynamics config enables ``normalize_obs``
        and the statistics are available.

        Args:
            eef_model: world model providing ``eval()`` and
                ``predict(low_dim_obs=..., rgb_obs=..., actions=...)``.
            obs: start low-dim observation of each batch element. It is
                converted with ``np.asarray``; presumably of shape [B, obs_dim].
            rgb_obs: per-element RGB observation array, or None for elements
                without images. The whole argument may also be None.
            action_chunks: per-element action chunk of shape [T, action_dim].
                All chunks must produce the same number of down-sampled
                actions, because they are stacked into one batch.
            ds: down-sampling stride.
            min_ds: stride used for the tail of delta-mode chunks.
            minimum_decay_steps: tail length excluded from ``ds``-stride
                merging in delta mode.

        Returns:
            A tuple ``(eef_trajectories, batch_actions_array, next_states_array)``:

            - eef_trajectories: output of ``process_obs_for_real`` when
              ``obs_dim >= 14``, else of ``process_obs_for_libero``.
            - batch_actions_array: stacked down-sampled (unnormalized) actions.
            - next_states_array: last predicted (unnormalized) observation of
              each element, [B, obs_dim].

        Raises:
            ValueError: if the batch is empty.
        """
        eef_model.eval()

        batch_size = len(obs)
        if batch_size == 0:
            raise ValueError("get_feedback_eef_interpolated_batch requires a non-empty batch")
        normalize_obs = self.dynamics_cfg.data.normalize_obs if hasattr(self.dynamics_cfg, 'data') else False
        delta_mode = self.cfg.safe_iql.control_mode != "abs"

        def merge_segment(segment):
            # Collapse a run of delta actions into one equivalent delta action.
            segment_raw = normalize_action_controller(segment)
            merged_raw = merge_delta_actions(segment_raw)
            return unnormalize_action_controller(merged_raw[None, :])[0]

        # Down-sample every chunk and record the chunk time index reached after each action.
        batch_ds_actions = []
        batch_all_time_indices = []
        batch_target_indices = []
        for i in range(batch_size):
            action_chunk = action_chunks[i]
            chunk_len = len(action_chunk)
            all_time_indices = [0]
            if not delta_mode:
                ds_actions = action_chunk[ds - 1::ds]
                all_time_indices.extend(range(ds, chunk_len + 1, ds))
            else:
                ds_actions = []
                remove_last_length = chunk_len - minimum_decay_steps
                last_t = -1
                for t in range(ds - 1, remove_last_length, ds):
                    ds_actions.append(merge_segment(action_chunk[t - ds + 1:t + 1]))
                    all_time_indices.append(min(t + 1, remove_last_length))
                    last_t = t
                # Tail: merge the remaining steps with the finer stride min_ds.
                for t in range(max(0, last_t + min_ds), chunk_len, min_ds):
                    ds_actions.append(merge_segment(action_chunk[t - min_ds + 1:t + 1]))
                    all_time_indices.append(min(t + 1, chunk_len))

            batch_ds_actions.append(np.array(ds_actions))
            batch_all_time_indices.append(all_time_indices)
            batch_target_indices.append(np.arange(chunk_len + 1))

        # Low-dim observations; only normalize when statistics exist, matching
        # the action normalization and output unnormalization below.
        batch_obs = np.asarray(obs)
        if normalize_obs and self.obs_mean is not None:
            batch_obs_normalized = (batch_obs - self.obs_mean) / self.obs_std
        else:
            batch_obs_normalized = batch_obs
        batch_obs_tensor = torch.from_numpy(batch_obs_normalized).float().to(self.device)

        # RGB observations: elements without images stay zero-filled.
        rgb_items = rgb_obs if rgb_obs is not None else []
        first_rgb = next((rgb for rgb in rgb_items if rgb is not None), None)
        if first_rgb is None:
            batch_rgb_tensor = None
        else:
            batch_rgb_tensor = torch.zeros(
                (batch_size, *first_rgb.shape), dtype=torch.float32, device=self.device
            )
            for i, rgb in enumerate(rgb_items):
                if rgb is not None:
                    batch_rgb_tensor[i] = torch.from_numpy(rgb).float().to(self.device)

        # Actions: np.stack requires every element to have the same number of steps.
        batch_actions_array = np.stack(batch_ds_actions)
        if normalize_obs and self.action_mean is not None:
            batch_actions_normalized = (batch_actions_array - self.action_mean) / self.action_std
        else:
            batch_actions_normalized = batch_actions_array
        batch_actions_tensor = torch.from_numpy(batch_actions_normalized).float().to(self.device)

        with torch.no_grad():
            predicted_obs_batch = eef_model.predict(
                low_dim_obs=batch_obs_tensor,
                rgb_obs=batch_rgb_tensor,
                actions=batch_actions_tensor
            )

        predicted_np = predicted_obs_batch.cpu().numpy()
        if normalize_obs and self.obs_mean is not None:
            predicted_obs_unnorm_batch = predicted_np * self.obs_std + self.obs_mean
        else:
            predicted_obs_unnorm_batch = predicted_np.copy()

        # Every chunk has the same length, so the first element's indices apply to all.
        time_indices = np.array(batch_all_time_indices[0])
        target_indices = np.array(batch_target_indices[0])

        if self.obs_dim >= 14:
            eef_trajectories = process_obs_for_real(predicted_obs_unnorm_batch, batch_obs, time_indices, target_indices, self.obs_dim <= 14)
        else:
            eef_trajectories = process_obs_for_libero(predicted_obs_unnorm_batch, batch_obs, time_indices, target_indices)
        next_states_array = predicted_obs_unnorm_batch[:, -1]
        return eef_trajectories, batch_actions_array, next_states_array
    
    def compute_diff_eef(self, desired_eefs, feedback_eefs, mode="mean"):
        """计算 eef 差异"""
        error = np.abs(desired_eefs - feedback_eefs)
        if mode == "mean":
            return np.mean(error)
        elif mode == "max":
            return np.max(error)
        else:
            return np.min(error)
    
    def _sample_safe_iql_data(self, save_path: Optional[str] = None):
        """Label world-model rollouts with safety rewards and save them as Safe-IQL transitions.

        Rollouts are generated per episode and per action-sequence length:

        1. For every episode in ``self.episodes`` and every action-sequence
           length ``A_len`` in ``[seq_len - 2, seq_len, seq_len + 2]``, each
           start state is rolled out through ``self.dynamics``.
        2. A baseline rollout uses down-sampling stride ``k_low``. A further
           rollout is made for every stride ``k`` in ``[k_low, k_high]``.
        3. The EEF-trajectory error between the two rollouts decides the
           reward:

           * ``k == k_low``: identical to the baseline, counted as safe,
             reward 0;
           * error <= ``epsilon``: safe, reward ``k / k_high``;
           * otherwise: unsafe, reward -1.0.

        Each transition stores the start state, the down-sampled action
        sequence, ``k``, the reward and the predicted next state. States and
        actions are normalized with the snapshot statistics when these exist.

        If a file already exists at ``save_path`` it is reused without
        resampling.

        Args:
            save_path: output file. Defaults to ``<work_dir>/safe_iql_data.pt``.

        Returns:
            The path of the existing or newly written data file.

        Raises:
            NotImplementedError: if the dynamics model uses framestack != 1.
        """
        if save_path is None:
            save_path = self.work_dir / "safe_iql_data.pt"

        if os.path.exists(save_path):
            print(f"Safe-IQL data already exists at {save_path}. Loading...")
            return save_path

        k_low = self.cfg.safe_iql.k_low
        k_high = self.cfg.safe_iql.k_high
        epsilon = self.cfg.safe_iql.epsilon
        # epsilon >= 1 selects the ratio-based EEF error metric, otherwise the plain error.
        eef_error_func = compute_eef_error_ratio if epsilon >= 1 else compute_eef_error
        k_values = list(range(k_low, k_high + 1))
        penalty_max = 1.0

        # Action sequence lengths
        seq_len = self.cfg.safe_iql.seq_len
        A_lengths = [seq_len - 2, seq_len, seq_len + 2]

        # Start states are fed to the world model as single frames, so only
        # framestack == 1 is supported; checked once before any sampling.
        has_model_cfg = hasattr(self.dynamics_cfg, 'model')
        framestack = self.dynamics_cfg.model.framestack if has_model_cfg else 1
        if framestack != 1:
            raise NotImplementedError("current only framestack = 1 is supported")
        # RGB frames are only passed to the world model when it consumes pixels.
        use_pixels = self.dynamics_cfg.model.use_pixels if has_model_cfg else False

        has_obs_stats = self.obs_mean is not None and self.obs_std is not None
        has_action_stats = self.action_mean is not None and self.action_std is not None

        def trajectory_error(baseline_eef, feedback_eef):
            # Compare the overlapping part of two EEF trajectories. Columns 0:3 hold
            # the position and 3:7 the quaternion. A second arm (if present) uses
            # 7:10 / 10:14, and the two per-arm errors are averaged.
            n = min(len(baseline_eef), len(feedback_eef))
            base = baseline_eef[:n]
            feed = feedback_eef[:n]
            error = eef_error_func(base[:, :3], base[:, 3:7], feed[:, :3], feed[:, 3:7])
            if base.shape[1] > 7:
                second_arm_error = eef_error_func(
                    base[:, 7:10], base[:, 10:14], feed[:, 7:10], feed[:, 10:14]
                )
                error = (error + second_arm_error) / 2
            return error

        print(f"Sampling safe transitions with k in {k_values}, epsilon={epsilon}, A_lengths={A_lengths}...")
        safe_transitions = []

        # Statistics for each k value
        k_stats = {k: {'safe': 0, 'unsafe': 0, 'total': 0} for k in k_values}

        for episode in tqdm(self.episodes, desc="Processing episodes", total=len(self.episodes)):
            # Episodes are accessed as mappings with 'obs' (presumably [T+1, obs_dim])
            # and 'actions' (presumably [T, action_dim]).
            if 'obs' not in episode or 'actions' not in episode:
                continue

            obs = episode['obs']
            actions = episode['actions']
            if len(actions) == 0:
                continue
            # Read RGB frames by key, like 'obs'/'actions'; getattr() on a dict
            # would never find them.
            rgb_obs = episode['rgb_obs'] if use_pixels and 'rgb_obs' in episode else None
            if rgb_obs is None:
                rgb_obs = []

            for A_len in A_lengths:
                # Start states that have a complete A_len-step action window after them.
                num_starts = min(len(obs), len(actions) + 1) - A_len
                if A_len <= 0 or num_starts <= 0:
                    # Episode too short for this sequence length: nothing to roll out.
                    continue

                valid_states = obs[:num_starts]
                valid_rgb_obs = rgb_obs[:num_starts] if len(rgb_obs) > 0 else [None] * num_starts
                valid_actions = np.array([
                    np.array(actions[start:start + A_len]) for start in range(num_starts)
                ])

                # Batch compute baselines (stride k_low)
                baseline_eefs, baseline_actions, baseline_next_states = self.get_feedback_eef_interpolated_batch(
                    self.dynamics, valid_states, valid_rgb_obs,
                    valid_actions, k_low, min_ds=k_low, minimum_decay_steps=0
                )

                for k in k_values:
                    if k == k_low:
                        # The baseline rollout already uses stride k_low.
                        feedback_eefs, feedback_actions, feedback_next_states = (
                            baseline_eefs, baseline_actions, baseline_next_states
                        )
                    else:
                        feedback_eefs, feedback_actions, feedback_next_states = self.get_feedback_eef_interpolated_batch(
                            self.dynamics, valid_states, valid_rgb_obs,
                            valid_actions, k, min_ds=k_low, minimum_decay_steps=0
                        )

                    for i, (baseline_eef, feedback_eef) in enumerate(zip(baseline_eefs, feedback_eefs)):
                        k_stats[k]['total'] += 1
                        if k == k_low:
                            # Identical to the baseline by construction: safe, no bonus.
                            k_stats[k]['safe'] += 1
                            reward = 0
                        elif trajectory_error(baseline_eef, feedback_eef) <= epsilon:
                            # Safe larger stride: reward grows with k.
                            k_stats[k]['safe'] += 1
                            reward = k / (k_high * penalty_max)
                        else:
                            k_stats[k]['unsafe'] += 1
                            reward = -1.0

                        s_env = valid_states[i]
                        s_env_next = feedback_next_states[i]
                        A_i = feedback_actions[i]
                        if has_obs_stats:
                            s_env = (s_env - self.obs_mean) / self.obs_std
                            s_env_next = (s_env_next - self.obs_mean) / self.obs_std
                        if has_action_stats:
                            A_i = (A_i - self.action_mean) / self.action_std

                        safe_transitions.append({
                            's_env': s_env,
                            'A': A_i,
                            'k': k,
                            'r': reward,
                            's_env_next': s_env_next,
                        })

        # Save data
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        with save_path.open("wb") as f:
            torch.save({
                'transitions': safe_transitions,
                'k_low': k_low,
                'k_high': k_high,
            }, f)
        print(f"Saved {len(safe_transitions)} safe transitions to {save_path}")

        # Print statistics
        print("\n=== Safety Statistics by k value ===")
        for k in sorted(k_values):
            stats = k_stats[k]
            total = stats['total']
            safe = stats['safe']
            unsafe = stats['unsafe']
            if total > 0:
                safe_rate = safe / total * 100
                unsafe_rate = unsafe / total * 100
                print(f"k={k}: Total={total}, Safe={safe} ({safe_rate:.2f}%), Unsafe={unsafe} ({unsafe_rate:.2f}%)")
            else:
                print(f"k={k}: Total=0, Safe=0, Unsafe=0")
        print("=" * 40 + "\n")

        return save_path
    
    def _train_safe_iql_scheduling_policy(
        self, 
        data_path: Optional[str] = None, 
        eval: bool = False,
        as_client: bool = False,
        client_host: str = "0.0.0.0",
        client_port: int = 8888,
    ):
        """
        训练 Safe-IQL Scheduling Policy
        
        这个函数实现了完整的 Safe-IQL 训练流程，包括：
        1. 加载数据
        2. 定义网络结构（Value Network, Q-Network）
        3. 训练循环
        4. 评估和保存
        
        Args:
            data_path: 数据路径
            eval: 如果为 True，则只加载模型并调用评估函数，不进行训练
            as_client: 如果为 True 且 eval=True，则在加载完 agent snapshot 后启动 HTTP API 服务器
            client_host: 服务器绑定的 IP 地址（仅在 as_client=True 时使用）
            client_port: 服务器端口（仅在 as_client=True 时使用）
        """
        # 定义网络结构（与 robobase 中的实现相同）
        def safe_iql_weight_init(m):
            """初始化权重：隐藏层使用正交初始化"""
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight.data)
                if hasattr(m.bias, "data"):
                    m.bias.data.fill_(0.0)
        
        def safe_iql_output_init(m):
            """初始化输出层：使用很小的权重"""
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight.data)
                m.weight.data.mul_(1e-3)
                if hasattr(m.bias, "data"):
                    m.bias.data.fill_(0.0)
        
        class SimpleMLP(nn.Module):
            """简单的 MLP 网络"""
            def __init__(self, input_dim, hidden_dims, output_dim, activation=nn.ReLU, use_layernorm=True):
                super().__init__()
                layers = []
                prev_dim = input_dim
                for i, hidden_dim in enumerate(hidden_dims):
                    layers.append(nn.Linear(prev_dim, hidden_dim))
                    if use_layernorm:
                        layers.append(nn.LayerNorm(hidden_dim))
                    layers.append(activation())
                    prev_dim = hidden_dim
                
                output_layer = nn.Linear(prev_dim, output_dim)
                layers.append(output_layer)
                self.net = nn.Sequential(*layers)
                
                # 初始化
                for module in self.net:
                    if isinstance(module, nn.Linear) and module is output_layer:
                        safe_iql_output_init(module)
                    elif isinstance(module, nn.Linear):
                        safe_iql_weight_init(module)
            
            def forward(self, x):
                return self.net(x)
        
        class QNetworkRNN(nn.Module):
            """Q-Network with RNN for variable-length action sequences"""
            def __init__(self, env_dim, action_dim, hidden_dims, rnn_hidden_size=128, num_rnn_layers=2):
                super().__init__()
                self.rnn_hidden_size = rnn_hidden_size
                
                # Encode environment state
                self.env_encoder = SimpleMLP(env_dim, hidden_dims, rnn_hidden_size)
                
                # RNN for processing action sequence
                self.rnn = nn.GRU(
                    input_size=action_dim,
                    hidden_size=rnn_hidden_size,
                    num_layers=num_rnn_layers,
                    batch_first=True
                )
                
                # Final MLP to output Q value
                self.q_head = SimpleMLP(rnn_hidden_size * 2, hidden_dims, 1)
            
            def forward(self, s_env, A_prime_list):
                """
                Args:
                    s_env: (batch, env_dim)
                    A_prime_list: List of variable-length sequences, each (seq_len, action_dim)
                Returns:
                    q_values: (batch, 1)
                """
                batch_size = s_env.shape[0]
                
                # Encode environment state
                s_encoded = self.env_encoder(s_env)  # (batch, rnn_hidden_size)
                
                # Handle empty sequences
                lengths = []
                valid_seqs = []
                valid_indices = []
                for i, a_seq in enumerate(A_prime_list):
                    if a_seq.shape[0] > 0:
                        valid_seqs.append(a_seq)
                        lengths.append(a_seq.shape[0])
                        valid_indices.append(i)
                
                if len(valid_seqs) == 0:
                    rnn_outs = torch.zeros(batch_size, self.rnn_hidden_size, device=s_env.device)
                else:
                    # Pad sequences
                    padded_seqs = pad_sequence(valid_seqs, batch_first=True)
                    lengths_tensor = torch.tensor(lengths, dtype=torch.int64, device='cpu')
                    
                    # Pack padded sequences
                    packed_seqs = pack_padded_sequence(
                        padded_seqs, lengths_tensor, batch_first=True, enforce_sorted=False
                    )
                    
                    # Process with RNN
                    rnn_output, hidden = self.rnn(packed_seqs)
                    rnn_out_valid = hidden[-1]  # (valid_batch, rnn_hidden_size)
                    
                    # Map back to full batch
                    rnn_outs = torch.zeros(batch_size, self.rnn_hidden_size, device=s_env.device)
                    for idx, valid_idx in enumerate(valid_indices):
                        rnn_outs[valid_idx] = rnn_out_valid[idx]
                
                # Concatenate environment encoding and RNN output
                combined = torch.cat([s_encoded, rnn_outs], dim=-1)
                q_values = self.q_head(combined)
                
                return q_values
        
        class SafeIQLAgent:
            """Safe-IQL agent: a state-value net, an RNN Q-net and its soft-updated target.

            ``predict_k`` scores the current state together with the action sequence
            down-sampled by every candidate ``k`` and returns the ``k`` whose Q value
            is highest.
            """

            def __init__(self, env_dim, action_dim, k_values, hidden_dims=None,
                         rnn_hidden_size=128, num_rnn_layers=2,
                         expectile=0.7, device='cuda', learning_rate=3e-4, tau=0.005):
                """
                Args:
                    env_dim: size of the environment-state vector.
                    action_dim: size of a single action.
                    k_values: candidate down-sampling factors evaluated by ``predict_k``.
                    hidden_dims: hidden layer sizes for the networks; ``None`` means
                        ``[256, 256]`` (a ``None`` sentinel avoids a shared mutable default).
                    rnn_hidden_size: hidden size forwarded to ``QNetworkRNN``.
                    num_rnn_layers: number of RNN layers forwarded to ``QNetworkRNN``.
                    expectile: expectile stored on the agent (used by the value loss).
                    device: device all networks are moved to.
                    learning_rate: Adam learning rate for both optimizers.
                    tau: soft-update coefficient for the target Q-network.
                """
                if hidden_dims is None:
                    hidden_dims = [256, 256]
                self.device = device
                self.k_values = k_values
                self.expectile = expectile
                self.tau = tau

                # Value network V(s)
                self.value_net = SimpleMLP(env_dim, hidden_dims, 1).to(device)

                # Q-network Q(s, action sequence)
                self.q_net = QNetworkRNN(env_dim, action_dim, hidden_dims,
                                         rnn_hidden_size=rnn_hidden_size,
                                         num_rnn_layers=num_rnn_layers).to(device)

                # Target Q-network: starts as a copy of q_net and is never trained directly.
                self.target_q_net = QNetworkRNN(env_dim, action_dim, hidden_dims,
                                                rnn_hidden_size=rnn_hidden_size,
                                                num_rnn_layers=num_rnn_layers).to(device)
                self.target_q_net.load_state_dict(self.q_net.state_dict())
                for param in self.target_q_net.parameters():
                    param.requires_grad = False

                # Optimizers
                self.value_opt = torch.optim.Adam(self.value_net.parameters(), lr=learning_rate)
                self.q_opt = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)

            def update_target_q_net(self):
                """Soft-update the target Q-network towards ``q_net`` with coefficient ``tau``."""
                soft_update_params(self.q_net, self.target_q_net, self.tau)

            def state_dict(self):
                """Return a dict with network/optimizer states and hyper-parameters for saving."""
                return {
                    'value_net': self.value_net.state_dict(),
                    'q_net': self.q_net.state_dict(),
                    'target_q_net': self.target_q_net.state_dict(),
                    'value_opt': self.value_opt.state_dict(),
                    'q_opt': self.q_opt.state_dict(),
                    'k_values': self.k_values,
                    'expectile': self.expectile,
                    'tau': self.tau,
                }

            def load_state_dict(self, state_dict):
                """Restore from ``state_dict()`` output.

                A missing ``'target_q_net'`` entry falls back to a copy of ``q_net``.
                Missing hyper-parameters keep their current values.
                """
                self.value_net.load_state_dict(state_dict['value_net'])
                self.q_net.load_state_dict(state_dict['q_net'])
                if 'target_q_net' in state_dict:
                    self.target_q_net.load_state_dict(state_dict['target_q_net'])
                else:
                    self.target_q_net.load_state_dict(self.q_net.state_dict())
                self.value_opt.load_state_dict(state_dict['value_opt'])
                self.q_opt.load_state_dict(state_dict['q_opt'])
                self.k_values = state_dict.get('k_values', self.k_values)
                self.expectile = state_dict.get('expectile', self.expectile)
                self.tau = state_dict.get('tau', self.tau)

            @staticmethod
            def _downsample_actions(A, k, control_mode):
                """Down-sample action sequence ``A`` by factor ``k``.

                ``k == 1`` returns a copy. In ``"abs"`` mode the last action of each
                window of ``k`` is kept. Otherwise each full window of ``k``
                consecutive delta actions is merged into one action, and a trailing
                partial window is dropped. The result is always 2-D, so a sequence
                shorter than ``k`` yields shape ``(0, action_dim)``.
                """
                if k == 1:
                    return A.copy()
                if control_mode == "abs":
                    return A[k - 1::k]
                merged = []
                for t in range(k - 1, len(A), k):
                    window = A[t - k + 1:t + 1]  # [k, action_dim]
                    window_raw = normalize_action_controller(window)
                    merged_raw = merge_delta_actions(window_raw)  # [action_dim]
                    merged.append(unnormalize_action_controller(merged_raw[None, :])[0])
                if not merged:
                    # np.array([]) would be 1-D (0,), which breaks normalization
                    # broadcasting and the RNN input; keep the action axis instead.
                    return np.empty((0, A.shape[-1]), dtype=A.dtype)
                return np.array(merged)  # [num_merged, action_dim]

            def predict_k(self, s_env, A, action_mean=None, action_std=None, control_mode="delta"):
                """Predict k by taking the argmax of Q over every candidate k.

                Training down-samples first and normalizes afterwards, so the same
                order is used here. When both ``action_mean`` and ``action_std`` are
                given, ``A`` is down-sampled and then normalized. Otherwise ``A`` is
                assumed to already be normalized (backward compatible).

                Args:
                    s_env: [obs_dim] current state (already standardized).
                    A: [T, action_dim] action sequence. It holds raw delta actions when
                        normalization stats are given, otherwise an already
                        normalized sequence.
                    action_mean: optional action mean used for normalization.
                    action_std: optional action std used for normalization.
                    control_mode: ``"abs"`` keeps every k-th action; any other value
                        merges windows of k delta actions.

                Returns:
                    Tensor of shape [1] holding the selected k.
                """
                A = np.asarray(A)
                use_normalization = action_mean is not None and action_std is not None
                s_env_tensor = torch.from_numpy(np.asarray(s_env)).float().to(self.device).unsqueeze(0)

                q_values_list = []
                # Pure inference: no autograd graph is needed.
                with torch.no_grad():
                    for k in self.k_values:
                        A_prime = self._downsample_actions(A, k, control_mode)
                        if use_normalization:
                            A_prime = (A_prime - action_mean) / action_std
                        A_prime_tensor = torch.from_numpy(A_prime).float().to(self.device)
                        q_values_list.append(self.q_net(s_env_tensor, [A_prime_tensor]))
                q_matrix = torch.cat(q_values_list, dim=1)  # one column per candidate k
                logging.debug("Safe-IQL Q matrix: %s", q_matrix)
                k_indices = q_matrix.argmax(dim=1)
                return torch.tensor([self.k_values[idx.item()] for idx in k_indices], device=self.device)
        
        # Load data
        if data_path is None:
            data_path = self.work_dir / "safe_iql_data.pt"
        data_path = Path(data_path)
        
        if not data_path.exists():
            raise FileNotFoundError(
                f"Data file not found: {data_path}. Please run _sample_safe_iql_data first."
            )
        
        with data_path.open("rb") as f:
            data = torch.load(f, map_location="cpu", weights_only=False)
            safe_transitions = data['transitions']
            k_low = data['k_low']
            k_high = data['k_high']
            k_values = list(range(k_low, k_high + 1))
        
        if len(safe_transitions) == 0:
            print("Warning: No safe transitions found! Cannot train.")
            return
        
        # Get dimensions
        env_dim = safe_transitions[0]['s_env'].shape[-1] if len(safe_transitions[0]['s_env'].shape) > 0 else len(safe_transitions[0]['s_env'])
        action_dim = safe_transitions[0]['A'].shape[-1] if len(safe_transitions[0]['A'].shape) > 1 else safe_transitions[0]['A'].shape[0]
        
        # Get config values
        safe_iql_cfg = self.cfg.safe_iql
        hidden_dims = safe_iql_cfg.hidden_dims
        rnn_hidden_size = safe_iql_cfg.rnn_hidden_size
        num_rnn_layers = safe_iql_cfg.num_rnn_layers
        expectile = safe_iql_cfg.expectile
        learning_rate = safe_iql_cfg.learning_rate
        num_epochs = safe_iql_cfg.num_epochs
        batch_size = safe_iql_cfg.batch_size
        gamma = safe_iql_cfg.gamma
        tau = safe_iql_cfg.tau
        grad_clip_norm = safe_iql_cfg.grad_clip_norm
        save_checkpoint = safe_iql_cfg.save_checkpoint
        checkpoint_every_n_epochs = safe_iql_cfg.checkpoint_every_n_epochs
        save_best_checkpoint = safe_iql_cfg.save_best_checkpoint
        eval_every_n_epochs = safe_iql_cfg.eval_every_n_epochs
        eval_episodes = safe_iql_cfg.eval_episodes
        
        # Initialize agent
        agent = SafeIQLAgent(
            env_dim=env_dim,
            action_dim=action_dim,
            k_values=k_values,
            hidden_dims=hidden_dims,
            rnn_hidden_size=rnn_hidden_size,
            num_rnn_layers=num_rnn_layers,
            expectile=expectile,
            device=self.device,
            learning_rate=learning_rate,
            tau=tau
        )
        
        # Try to load existing checkpoint
        start_epoch, best_success_rate = self.load_safe_iql_snapshot(agent, eval=eval)
        if start_epoch == 0:
            print("No checkpoint found, starting from scratch")
        if eval:
            # 评估模式：加载快照后调用评估函数
            agent.q_net.eval()
            agent.value_net.eval()
            agent.target_q_net.eval()
            print(f"Evaluation mode: Agent loaded from snapshot (epoch {start_epoch})")
            
            # 如果 as_client=True，启动 HTTP API 服务器
            if as_client:
                print(f"启动 Safe-IQL Agent HTTP API 服务器...")
                print(f"从 workspace 获取标准化参数...")
                start_safe_iql_agent_server(
                    agent=agent,
                    host=client_host,
                    port=client_port,
                    obs_mean=getattr(self, 'obs_mean', None),
                    obs_std=getattr(self, 'obs_std', None),
                    action_mean=getattr(self, 'action_mean', None),
                    action_std=getattr(self, 'action_std', None),
                    control_mode=getattr(self.cfg.safe_iql, 'control_mode', "delta")
                )
                # 服务器会一直运行，不会返回
                return None
            
        if start_epoch > 0:
            metrics = {
                'safe_iql_resume_epoch': start_epoch,
                'safe_iql_best_success_rate': best_success_rate,
            }
            self.logger.log_metrics(metrics, start_epoch, prefix="eval")
        
        # Prepare data for training
        transitions_by_k = {k: [] for k in k_values}
        for t in safe_transitions:
            transitions_by_k[t['k']].append(t)
        
        # Reserve test samples
        test_samples_by_k = {}
        test_sample_size_per_k = 50
        for k in k_values:
            if len(transitions_by_k[k]) > test_sample_size_per_k:
                k_indices = np.arange(len(transitions_by_k[k]))
                np.random.shuffle(k_indices)
                test_indices = k_indices[:test_sample_size_per_k]
                test_samples_by_k[k] = [transitions_by_k[k][i] for i in test_indices]
                train_indices = k_indices[test_sample_size_per_k:]
                transitions_by_k[k] = [transitions_by_k[k][i] for i in train_indices]
            else:
                test_samples_by_k[k] = transitions_by_k[k].copy()
                transitions_by_k[k] = []
        
        # Reconstruct safe_transitions
        safe_transitions = []
        for k in k_values:
            safe_transitions.extend(transitions_by_k[k])
        
        # Training loop
        for epoch in tqdm(range(start_epoch, num_epochs), desc="Training Safe-IQL"):
            epoch_vf_loss = 0.0
            epoch_q_loss = 0.0
            num_batches = 0
            
            # Sample batches
            all_indices = np.arange(len(safe_transitions))
            np.random.shuffle(all_indices)
            
            for batch_start in range(0, len(safe_transitions), batch_size):
                batch_indices = all_indices[batch_start:batch_start + batch_size]
                batch_transitions = [safe_transitions[i] for i in batch_indices]
                
                # Prepare batch
                s_env_batch = torch.from_numpy(
                    np.array([t['s_env'] for t in batch_transitions])
                ).float().to(self.device)
                s_env_batch = s_env_batch.squeeze()
                A_batch = [torch.from_numpy(t['A']).float().to(self.device) for t in batch_transitions]
                r_batch = torch.tensor([t['r'] for t in batch_transitions]).float().to(self.device)
                s_env_next_batch = torch.from_numpy(
                    np.array([t['s_env_next'] for t in batch_transitions])
                ).float().to(self.device)
                s_env_next_batch = s_env_next_batch.squeeze()
                
                # Update Value Network
                with torch.no_grad():
                    q_values = agent.target_q_net(s_env_batch, A_batch)
                
                v_values = agent.value_net(s_env_batch)
                vf_err = q_values - v_values
                vf_weight = torch.where(vf_err > 0, agent.expectile, 1 - agent.expectile)
                vf_loss = (vf_weight * (vf_err ** 2)).mean()
                
                agent.value_opt.zero_grad()
                vf_loss.backward()
                torch.nn.utils.clip_grad_norm_(agent.value_net.parameters(), max_norm=grad_clip_norm)
                agent.value_opt.step()
                
                # Update Q-Network
                with torch.no_grad():
                    v_next = agent.value_net(s_env_next_batch)
                    target_q = r_batch.unsqueeze(1) + gamma * v_next
                
                q_values = agent.q_net(s_env_batch, A_batch)
                q_loss = F.mse_loss(q_values, target_q)
                
                agent.q_opt.zero_grad()
                q_loss.backward()
                torch.nn.utils.clip_grad_norm_(agent.q_net.parameters(), max_norm=grad_clip_norm)
                agent.q_opt.step()
                
                # Update target Q network
                agent.update_target_q_net()
                
                epoch_vf_loss += vf_loss.item()
                epoch_q_loss += q_loss.item()
                num_batches += 1
            
            # Calculate average losses
            avg_vf_loss = epoch_vf_loss / num_batches if num_batches > 0 else 0.0
            avg_q_loss = epoch_q_loss / num_batches if num_batches > 0 else 0.0
            
            # 计算并记录不同 k 值的 Q 值（采样一个 batch 来计算）
            if num_batches > 0:
                # 从训练数据中采样一个 batch 来计算不同 k 值的 Q 值
                sample_batch_size = min(batch_size, len(safe_transitions))
                sample_indices = np.random.choice(len(safe_transitions), sample_batch_size, replace=False)
                sample_transitions = [safe_transitions[i] for i in sample_indices]
                
                sample_s_env = torch.from_numpy(
                    np.array([t['s_env'] for t in sample_transitions])
                ).float().to(self.device).squeeze()
                sample_A = [torch.from_numpy(t['A']).float().to(self.device) for t in sample_transitions]
                
                with torch.no_grad():
                    sample_q_values = agent.q_net(sample_s_env, sample_A).cpu().numpy().flatten()
                    sample_k_values = [t['k'] for t in sample_transitions]
                
                # 计算每个 k 值的平均 Q 值
                k_q_means = {}
                k_q_stds = {}
                for k_val in k_values:
                    k_mask = np.array([k == k_val for k in sample_k_values])
                    if np.any(k_mask):
                        k_qs = sample_q_values[k_mask]
                        k_q_means[f'safe_iql_q_mean_k{k_val}'] = float(np.mean(k_qs))
                        k_q_stds[f'safe_iql_q_std_k{k_val}'] = float(np.std(k_qs))
            else:
                k_q_means = {}
                k_q_stds = {}
            
            # Log loss and Q values
            metrics = {
                'safe_iql_vf_loss': avg_vf_loss,
                'safe_iql_q_loss': avg_q_loss,
                **k_q_means,
                **k_q_stds,
            }
            self.logger.log_metrics(metrics, epoch + 1, prefix="train")
            
            # Evaluate periodically
            if (epoch + 1) % eval_every_n_epochs == 0:
                # 这里需要实现评估函数
                print(f"Epoch {epoch + 1}: Evaluation not implemented yet")
                
                # Save checkpoint
                if save_checkpoint:
                    self.save_safe_iql_snapshot(agent, epoch + 1, best_success_rate, best_ckpt=False)
            elif save_checkpoint and (epoch + 1) % checkpoint_every_n_epochs == 0:
                self.save_safe_iql_snapshot(agent, epoch + 1, best_success_rate, best_ckpt=False)
        
        # Save final checkpoint
        if save_checkpoint:
            self.save_safe_iql_snapshot(agent, num_epochs, best_success_rate, best_ckpt=False)
        
        return agent
    
    def save_safe_iql_snapshot(self, agent, epoch, best_success_rate=0.0, best_ckpt=False):
        """Write a Safe-IQL agent snapshot under ``<work_dir>/snapshots``.

        The file is named ``safe_iql_best_snapshot.pt`` when ``best_ckpt`` is true,
        otherwise ``safe_iql_epoch_<epoch>_snapshot.pt``. The freshly written file
        is then copied to ``safe_iql_latest_snapshot.pt``.

        Args:
            agent: object whose ``state_dict()`` is stored under ``'agent_state'``.
            epoch: epoch number stored in the payload and used in the file name.
            best_success_rate: value stored under ``'best_success_rate'``.
            best_ckpt: use the "best" file name instead of the per-epoch one.
        """
        snapshots = self.work_dir / "snapshots"
        snapshots.mkdir(parents=True, exist_ok=True)

        if not best_ckpt:
            file_name = f"safe_iql_epoch_{epoch}_snapshot.pt"
        else:
            file_name = "safe_iql_best_snapshot.pt"
        destination = snapshots / file_name

        # Build the payload before opening the file, so a failure here leaves no file behind.
        state = dict(
            cfg=self.cfg.to_dict(),
            agent_state=agent.state_dict(),
            epoch=epoch,
            best_success_rate=best_success_rate,
        )

        with destination.open("wb") as fh:
            torch.save(state, fh)

        shutil.copy(destination, snapshots / "safe_iql_latest_snapshot.pt")
    
    def load_safe_iql_snapshot(self, agent, path_to_snapshot=None, eval=False):
        """Restore a Safe-IQL agent from a snapshot file.

        Args:
            agent: object exposing ``load_state_dict``; it receives ``payload['agent_state']``.
            path_to_snapshot: explicit snapshot path. Defaults to
                ``<work_dir>/snapshots/safe_iql_latest_snapshot.pt``.
            eval: currently unused; kept for interface compatibility.

        Returns:
            ``(epoch, best_success_rate)`` read from the payload (defaulting to
            ``0`` / ``0.0``), or ``(0, 0.0)`` when the file does not exist.
        """
        if path_to_snapshot is not None:
            target = Path(path_to_snapshot)
        else:
            target = self.work_dir / "snapshots" / "safe_iql_latest_snapshot.pt"
        print(f"Loading Safe-IQL snapshot from {target}")

        if not target.exists():
            return 0, 0.0

        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted snapshots.
        with target.open("rb") as handle:
            payload = torch.load(handle, map_location=self.device, weights_only=False)

        agent.load_state_dict(payload['agent_state'])
        return payload.get('epoch', 0), payload.get('best_success_rate', 0.0)
    
    def _get_common_metrics(self) -> Dict[str, Any]:
        """获取通用指标"""
        _, total_time = self._timer.reset()
        metrics = {
            "total_time": total_time,
            "env_steps": self._pretrain_step,
            "env_episodes": self.global_env_episodes,
        }
        return metrics
    
    def _signal_handler(self, sig, frame):
        """信号处理器"""
        print("\nCtrl+C detected. Preparing to shutdown...")
        self._shutting_down = True
    
    def shutdown(self):
        """Close the workspace logger if one has been attached; otherwise do nothing."""
        missing = object()
        logger = getattr(self, 'logger', missing)
        if logger is not missing:
            logger.close()


def start_safe_iql_agent_server(
    agent,
    host: str = "0.0.0.0",
    port: int = 8888,
    obs_mean: Optional[np.ndarray] = None,
    obs_std: Optional[np.ndarray] = None,
    action_mean: Optional[np.ndarray] = None,
    action_std: Optional[np.ndarray] = None,
    control_mode: str = "delta",
):
    """
    Start a WebSocket server that exposes a Safe-IQL agent to other Python processes.

    WebSocket is used (instead of HTTP) to avoid conflicts with HTTP proxies.
    Frames are packed/unpacked with ``msgpack_numpy_utils.Packer`` / ``unpackb``.

    Protocol:
      - On connect the server sends ``{"status": "ready", "k_values": [...]}``.
      - Each binary request is a dict whose ``"type"`` is ``"health"`` or
        ``"predict_k"`` (the default). A ``predict_k`` request must contain
        ``"s_env"`` and ``"A"``.
      - A text frame from the client ends the connection.
      - If handling a request raises, the formatted traceback is sent back as a
        text frame and the connection is closed.
      - A client disconnect ends that connection quietly.

    This call blocks in ``serve_forever`` until interrupted with Ctrl+C.

    Args:
        agent: SafeIQLAgent instance (already loaded).
        host: address to bind; defaults to "0.0.0.0" (reachable externally).
        port: port to bind; defaults to 8888.
        obs_mean: observation mean; the state is standardized only if both
            obs_mean and obs_std are given.
        obs_std: observation standard deviation (see obs_mean).
        action_mean: action mean forwarded to ``agent.predict_k``.
        action_std: action standard deviation forwarded to ``agent.predict_k``.
        control_mode: control mode forwarded to ``agent.predict_k``; defaults to "delta".

    Raises:
        ImportError: if the ``websockets`` package is not installed.
    """
    try:
        import websockets.sync.server
        from websockets.exceptions import ConnectionClosed
    except ImportError as err:
        raise ImportError(
            "websockets is required to run the server. "
            "Please install it with: pip install websockets"
        ) from err

    import traceback

    from msgpack_numpy_utils import Packer, unpackb

    def _k_values_list():
        """Return ``agent.k_values`` as a plain list so it can be packed."""
        k_values = agent.k_values
        return k_values.tolist() if hasattr(k_values, 'tolist') else list(k_values)

    def _predict_k_response(data):
        """Build the response dict for a ``predict_k`` request."""
        if "s_env" not in data or "A" not in data:
            return {
                "status": "error",
                "message": "Request must contain 's_env' and 'A' fields",
            }

        s_env = np.array(data["s_env"], dtype=np.float32)
        A = np.array(data["A"], dtype=np.float32)

        # Standardize the state only when both statistics are available.
        if obs_mean is not None and obs_std is not None:
            s_env_normalized = (s_env - obs_mean) / obs_std
        else:
            s_env_normalized = s_env

        with torch.no_grad():
            predicted_k_tensor = agent.predict_k(
                s_env_normalized,
                A,
                action_mean=action_mean,
                action_std=action_std,
                control_mode=control_mode,
            )
            predicted_k = int(predicted_k_tensor[0].cpu().item())

        return {"status": "success", "predicted_k": predicted_k}

    def _dispatch(data):
        """Route a decoded request to its handler and return the response dict."""
        request_type = data.get("type", "predict_k")
        if request_type == "health":
            return {"status": "success", "k_values": _k_values_list()}
        if request_type == "predict_k":
            return _predict_k_response(data)
        return {
            "status": "error",
            "message": f"Unknown request type: {request_type}",
        }

    def handle_client(websocket):
        """Serve one client connection until it closes or a request fails."""
        logging.info("Connection opened")
        packer = Packer()

        try:
            # Send server metadata first.
            websocket.send(packer.pack({"status": "ready", "k_values": _k_values_list()}))

            while True:
                try:
                    message = websocket.recv()
                    if isinstance(message, str):
                        # A text frame from the peer is treated as an error signal.
                        logging.error(f"Received string message: {message}")
                        break
                    response = _dispatch(unpackb(message))
                    websocket.send(packer.pack(response))
                except ConnectionClosed:
                    # Normal client disconnect: not an error, and the socket is
                    # already closed, so do not try to send a traceback on it.
                    break
                except Exception:
                    error_traceback = traceback.format_exc()
                    logging.error(f"Error handling message: {error_traceback}")
                    websocket.send(error_traceback)
                    break
        except ConnectionClosed:
            # The peer went away during the handshake or while we sent an error.
            pass
        except Exception as e:
            logging.error(f"Error in client handler: {e}")

        logging.info("Connection closed")

    # Start the server
    uri = f"ws://{host}:{port}"
    print(f"启动 Safe-IQL Agent WebSocket 服务器...")
    print(f"服务器地址: {uri}")
    print(f"按 Ctrl+C 停止服务器")

    with websockets.sync.server.serve(handle_client, host, port, compression=None, max_size=None) as server:
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            print("\n正在关闭服务器...")

