import shutil
import signal
import sys
import os
import json
import time
import copy
import random
from typing import Callable, Any
from functools import partial
import logging
from omegaconf import DictConfig, OmegaConf

from gymnasium import spaces
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from tqdm import tqdm

from robobase import utils
from robobase.envs.env import EnvFactory
from robobase.logger import Logger

from robobase.envs.wrappers import RescaleFromTanhWithMinMax
from robobase.envs.utils.bigym_utils import ErrorCalculator
from pathlib import Path
from robobase.method.utils import (
    extract_many_from_batch,
    extract_from_batch,
    extract_many_from_spec
)
import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym
from torch.utils.data import DataLoader
import imageio, cv2

# Process-wide setting applied at import time: turns on torch.backends.cudnn.benchmark
# for every user of this module.
torch.backends.cudnn.benchmark = True

import pdb

# How to recover the name of each entry in the observation and action spaces:
# env = eval_env.unwrapped
# obs space: [joint_names.qpos, joint_names.qvel, floating_base.qpos, gripper.qpos]
# joint_names = [joint.mjcf.name for joint in env.robot._joints]
# floating_names = [mjcf.name for mjcf in env.robot.floating_base.all_actuators]
# gripper_names = [k for k, gripper in env.robot.grippers.items()]
# act space: [floating_base.qpos, limb_actuators.qpos, gripper.qpos] 
# env.robot.qpos_actuated
import os
import numpy as np
from PIL import Image

def _create_default_envs(cfg: DictConfig) -> EnvFactory:
    factory = None
    if cfg.env.env_name == "rlbench":
        from robobase.envs.rlbench import RLBenchEnvFactory

        factory = RLBenchEnvFactory()
    elif cfg.env.env_name == "dmc":
        from robobase.envs.dmc import DMCEnvFactory

        factory = DMCEnvFactory()
    elif cfg.env.env_name == "bigym":
        from robobase.envs.bigym import BiGymEnvFactory

        factory = BiGymEnvFactory()
        # NOTE: here default high gain
        factory.HIGH_GAIN = True
    elif cfg.env.env_name == "d4rl":
        from robobase.envs.d4rl import D4RLEnvFactory

        factory = D4RLEnvFactory()
    else:
        ValueError()
    return factory

def set_seed(seed):
    """Seed PyTorch's and NumPy's global random number generators with ``seed``."""
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)

def put_text(img, text, font_size=1, thickness=2, resize=False, position="top"):
    """Return a copy of ``img`` with ``text`` drawn on it via ``cv2.putText``.

    Args:
        img: image array to annotate; the caller's array is not modified.
        text: string to draw.
        font_size: font scale passed to OpenCV.
        thickness: stroke thickness passed to OpenCV.
        resize: if True, cast the copy to uint8 and resize it to 256x256 before drawing.
        position: ``"top"`` anchors the text at (10, 30); ``"bottom"`` anchors it at
            x=300, 20 px above the bottom edge of the (possibly resized) image.

    Returns:
        The annotated image.

    Raises:
        ValueError: if ``position`` is neither ``"top"`` nor ``"bottom"``.
    """
    # Validate up front: an unknown position used to leave the anchor unbound and
    # crash inside cv2.putText with an UnboundLocalError.
    if position not in ("top", "bottom"):
        raise ValueError(f"position must be 'top' or 'bottom', got {position!r}")
    img = img.copy()
    if resize:
        img = cv2.resize(np.uint8(img), (256, 256))
    # The bottom anchor depends on the final image height, so compute it after resizing.
    anchor = (10, 30) if position == "top" else (300, img.shape[0] - 20)
    return cv2.putText(
        img,
        text,
        anchor,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_size,
        (0, 255, 255),
        thickness,
        cv2.LINE_AA,
    )


class PolicyWorkspace:
    def __init__(
        self,
        cfg: DictConfig,
        env_factory: EnvFactory = None,
        work_dir: str = None,
    ):  
        """Set up device, logger, demos, evaluation envs, the pretrained agent and the dynamics model.

        Args:
            cfg: run configuration. Fields read here include ``seed``, ``num_gpus``,
                ``demos`` and ``env.task_name``.
            env_factory: factory for environments and demos. When None, one is chosen
                from ``cfg.env.env_name`` by ``_create_default_envs``.
            work_dir: output directory. When None, the current Hydra run's
                ``runtime.output_dir`` is used.

        NOTE(review): the agent and dynamics snapshots are loaded from hard-coded
        ``exp_local/...`` paths relative to the process working directory; consider
        moving them into ``cfg``.
        """
        if env_factory is None:
            env_factory = _create_default_envs(cfg)
    
        self.work_dir = Path(
            hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
            if work_dir is None
            else work_dir
        )
        print(f"workspace: {self.work_dir}")

        self.cfg = cfg
        utils.set_seed_everywhere(cfg.seed)
        # Device selection: CPU when no GPUs are requested, MPS on macOS, otherwise a GPU
        # index. When a Hydra job number is available, jobs are spread over the GPUs
        # round-robin (job_num % num_gpus).
        dev = "cpu"
        if cfg.num_gpus > 0:
            if sys.platform == "darwin":
                dev = "mps"
            else:
                dev = 0
                job_num = False
                try:
                    job_num = HydraConfig.get().job.get("num", False)
                except ValueError:
                    # presumably raised when HydraConfig is not initialised
                    # (e.g. not launched through Hydra) — confirm
                    pass
                if job_num:
                    dev = job_num % cfg.num_gpus
        self.device = torch.device(dev)

        # create logger
        self.logger = Logger(self.work_dir, cfg=self.cfg)
        self.env_factory = env_factory

        # Demos are collected/fetched first; the evaluation envs are created afterwards.

        if (num_demos := cfg.demos) != 0:
            # Collect demos or fetch saved demos before making environments
            # to consider demo-based action space (e.g., standardization)
            # NOTE(review): self.demo_data is only defined when cfg.demos != 0.
            self.demo_data = self.env_factory.collect_or_fetch_demos(cfg, num_demos)

        # Placeholder: load_agent_snapshot() below replaces it with an env built from the
        # agent's own saved config.
        self.agent_eval_env = None
        # Env built from the current cfg. load_dynamics_snapshot() instantiates the
        # dynamics model against its spaces, so it must exist before that call.
        self.dynamics_eval_env = self.env_factory.make_eval_env(cfg, self.work_dir)
        self.error_calculator = ErrorCalculator(self.dynamics_eval_env)

        # Create the RL Agent
        # NOTE(review): observation_space / action_space are unused locals; the agent is
        # instantiated in load_agent_snapshot() against agent_eval_env's spaces.
        observation_space = self.dynamics_eval_env.observation_space
        action_space = self.dynamics_eval_env.action_space

        path_to_agent = f"exp_local/pixel_act/bigym_{self.cfg.env.task_name}_speedup_qv/snapshots/best_snapshot.pt"
        print("Loading agents...", path_to_agent)
        self.load_agent_snapshot(path_to_agent)

        path_to_dynamics_snapshot = f"exp_local/models_eefnp/bigym_{self.cfg.env.task_name}/snapshots/best_snapshot.pt"
        self.load_dynamics_snapshot(path_to_dynamics_snapshot)
        
        if num_demos != 0:
            # Post-process demos using the information from environments
            self.env_factory.post_collect_or_fetch_demos(cfg, self.work_dir)
        
        # self.obs_stats = copy.deepcopy(self.env_factory._obs_stats)
        # Private copy of the factory's action statistics (presumably populated during
        # demo collection / post-processing — confirm).
        self.act_stats = copy.deepcopy(self.env_factory._action_stats)

        # Bookkeeping counters, exposed through the pretrain_steps,
        # main_loop_iterations and global_env_episodes properties.
        self._timer = utils.Timer()
        self._pretrain_step = 0
        self._main_loop_iterations = 0
        self._global_env_episode = 0
        # Set to True by _signal_handler on Ctrl+C.
        self._shutting_down = False 
        self.best_metrics = {
            "best_episode_success": 0,  
            "best_episode_len": 0,  
        }
    


    @property
    def pretrain_steps(self):
        return self._pretrain_step

    @property
    def main_loop_iterations(self):
        return self._main_loop_iterations

    @property
    def global_env_episodes(self):
        return self._global_env_episode

    @property
    def global_env_steps(self):
        """Total number of environment steps taken."""
        return self.pretrain_steps


    def train(self):
        signal.signal(signal.SIGINT, self._signal_handler)
        try:
            self._train()
        except Exception as e:
            self.shutdown()
            raise e


    def get_feedback_eef_interpolated(self, eef_model, observation, action_seq, ds):
        def interpolate_pos_quat(position, quat, gripper, time_indices, target_indices):
            """插值位置、四元数和夹爪（使用NLERP）"""
            # 转换为 numpy 
            pos_pts = np.array(position)
            quat_pts = np.array(quat)
            gripper_pts = np.array(gripper)
            
            # 修正四元数符号翻转 (Crucial for NLERP)
            # 确保相邻四元数点积为正，保证走最短路径
            for i in range(1, len(quat_pts)):
                if np.dot(quat_pts[i-1], quat_pts[i]) < 0:
                    quat_pts[i] *= -1

            # 向量化插值
            # 位置插值 (Linear)
            pos_interp = np.zeros((len(target_indices), 3))
            for i in range(3):
                pos_interp[:, i] = np.interp(target_indices, time_indices, pos_pts[:, i])
            
            # 姿态插值 (NLERP)
            quat_interp = np.zeros((len(target_indices), 4))
            for i in range(4):
                quat_interp[:, i] = np.interp(target_indices, time_indices, quat_pts[:, i])
            
            # NLERP 必须做的归一化
            norms = np.linalg.norm(quat_interp, axis=1, keepdims=True)
            quat_interp /= (norms + 1e-12) # 防止除零
            
            # 夹爪插值 (Linear)
            gripper_interp = np.interp(target_indices, time_indices, gripper_pts)

            return pos_interp, quat_interp, gripper_interp
        
        rgb_names = self.dynamics_cfg.env.cameras
        rgb_obs = np.stack([observation['rgb_' + name] for name in rgb_names], axis=1) # 1, ...
        low_dim_obs = observation["low_dim_state"] # 1, ...
        rgb_obs_torch = torch.from_numpy(rgb_obs).float().to(self.device)
        low_dim_obs_torch = torch.from_numpy(low_dim_obs).float().to(self.device)
        action_seq_torch = torch.from_numpy(action_seq).float().unsqueeze(0).to(self.device) # 1, seq_len, act_dim
        if not getattr(self.dynamics_cfg, "use_pixels", True):
            rgb_obs_torch = None 
        feedback_eefs, next_qpos_pred = eef_model.step_multi(rgb_obs_torch, low_dim_obs_torch, action_seq_torch) # seq_len, eef_dim
        feedback_eefs = feedback_eefs[0].cpu().numpy() # seq_len, eef_dim
        
        # next_qpos_pred = next_qpos_pred[0].cpu().numpy() # seq_len, qpos_dim
        feedback_eefs_addfirst = np.concatenate([observation['eef'], feedback_eefs], axis=0)
        next_state = feedback_eefs_addfirst[-1]

        # action_seq 已经是下采样后的序列，长度为 len(action_seq)
        # feedback_eefs 对应下采样后的时间点，长度也是 len(action_seq)
        # 需要插值回原始完整序列，原始序列长度为 len(action_seq) * ds
        
        # 1. 设置时间索引
        # time_indices: 下采样点的时间索引 [0, ds, 2*ds, ..., (len(action_seq)-1)*ds]
        seq_len = len(action_seq)
        time_indices = np.arange(0, seq_len+1) * ds  # 下采样后的时间索引
        # target_indices: 原始完整序列的所有时间步 [0, 1, 2, ..., seq_len*ds-1]
        target_indices = np.arange(1, seq_len * ds + 1)

        # 2. 提取左右臂的位置、四元数和夹爪
        # eef 格式: [left_pos(3), left_quat(4), left_grip(1), right_pos(3), right_quat(4), right_grip(1)]
        left_pos_pts = feedback_eefs_addfirst[:, :3]
        left_quat_pts = feedback_eefs_addfirst[:, 3:7]  # w, x, y, z
        left_grip_pts = feedback_eefs_addfirst[:, 7]
        right_pos_pts = feedback_eefs_addfirst[:, 8:11]
        right_quat_pts = feedback_eefs_addfirst[:, 11:15]  # w, x, y, z
        right_grip_pts = feedback_eefs_addfirst[:, 15]

        # 3. 使用插值函数分别处理左右臂
        left_pos_interp, left_quat_interp, left_grip_interp = interpolate_pos_quat(
            left_pos_pts, left_quat_pts, left_grip_pts, time_indices, target_indices
        )
        right_pos_interp, right_quat_interp, right_grip_interp = interpolate_pos_quat(
            right_pos_pts, right_quat_pts, right_grip_pts, time_indices, target_indices
        )
        
        # 4. 合并结果
        feedback_eefs_interpolated = np.concatenate([
            left_pos_interp,      # seq_len*ds, 3
            left_quat_interp,     # seq_len*ds, 4
            left_grip_interp[:, None],  # seq_len*ds, 1
            right_pos_interp,     # seq_len*ds, 3
            right_quat_interp,    # seq_len*ds, 4
            right_grip_interp[:, None]  # seq_len*ds, 1
        ], axis=1)  # seq_len*ds, 16

        # next_state = next_qpos_pred[-1]

        return feedback_eefs_interpolated, next_state
    
    def get_feedback_eef_interpolated_batch(self, eef_model, observations, action_seqs, ds):
        """
        批量版本的get_feedback_eef_interpolated
        Args:
            eef_model: dynamics模型
            observations: list of observation dicts, 每个包含'low_dim_state'和'eef'
            action_seqs: list of action sequences, 每个shape为(seq_len, action_dim)
            ds: downsample rate
        Returns:
            feedback_eefs_list: list of interpolated eef sequences
            next_states: array of next states (batch_size, eef_dim)
        """
        batch_size = len(observations)
        assert batch_size == len(action_seqs), "observations and action_seqs must have same length"
        
        def interpolate_pos_quat(position, quat, gripper, time_indices, target_indices):
            """插值位置、四元数和夹爪（使用NLERP）"""
            pos_pts = np.array(position)
            quat_pts = np.array(quat)
            gripper_pts = np.array(gripper)
            
            # 修正四元数符号翻转
            for i in range(1, len(quat_pts)):
                if np.dot(quat_pts[i-1], quat_pts[i]) < 0:
                    quat_pts[i] *= -1

            # 向量化插值
            pos_interp = np.zeros((len(target_indices), 3))
            for i in range(3):
                pos_interp[:, i] = np.interp(target_indices, time_indices, pos_pts[:, i])
            
            quat_interp = np.zeros((len(target_indices), 4))
            for i in range(4):
                quat_interp[:, i] = np.interp(target_indices, time_indices, quat_pts[:, i])
            
            norms = np.linalg.norm(quat_interp, axis=1, keepdims=True)
            quat_interp /= (norms + 1e-12)
            
            gripper_interp = np.interp(target_indices, time_indices, gripper_pts)
            return pos_interp, quat_interp, gripper_interp
        
        # 准备批量输入
        rgb_names = self.dynamics_cfg.env.cameras
        low_dim_obs_list = [obs["low_dim_state"] for obs in observations]
        eef_list = [obs["eef"] for obs in observations]
        
        # 检查序列长度是否一致（用于batch处理）
        seq_lens = [len(seq) for seq in action_seqs]
        if len(set(seq_lens)) == 1:
            # 所有序列长度相同，可以真正batch处理
            max_seq_len = seq_lens[0]
            low_dim_obs_batch = np.concatenate(low_dim_obs_list, axis=0)  # (batch_size, low_dim_dim)
            eef_batch = np.concatenate(eef_list, axis=0)  # (batch_size, eef_dim)
            
            # 准备RGB obs (如果使用pixels)
            if getattr(self.dynamics_cfg, "use_pixels", True):
                # 每个obs的rgb是(1, 3, 84, 84)，先stack成(1, num_cameras, 3, 84, 84)
                # 然后concatenate第一个维度，得到(batch_size, num_cameras, 3, 84, 84)
                rgb_obs_list = [np.stack([obs[f'rgb_{name}'] for name in rgb_names], axis=1) for obs in observations]
                rgb_obs_batch = np.concatenate(rgb_obs_list, axis=0)  # (batch_size, num_cameras, C, H, W)
                rgb_obs_torch = torch.from_numpy(rgb_obs_batch).float().to(self.device)
            else:
                rgb_obs_torch = None
            
            # 堆叠action sequences
            action_seqs_array = np.stack(action_seqs, axis=0)  # (batch_size, seq_len, action_dim)
            
            low_dim_obs_torch = torch.from_numpy(low_dim_obs_batch).float().to(self.device)
            action_seqs_torch = torch.from_numpy(action_seqs_array).float().to(self.device)
            
            # 批量调用模型
            feedback_eefs_batch, _ = eef_model.step_multi(rgb_obs_torch, low_dim_obs_torch, action_seqs_torch)
            feedback_eefs_batch = feedback_eefs_batch.cpu().numpy()  # (batch_size, seq_len, eef_dim)
            
            # 处理每个样本的插值
            feedback_eefs_list = []
            next_states = []
            for i in range(batch_size):
                feedback_eefs = feedback_eefs_batch[i]  # (seq_len, eef_dim)
                eef = eef_list[i]  # (1, eef_dim)
                feedback_eefs_addfirst = np.concatenate([eef, feedback_eefs], axis=0)
                next_state = feedback_eefs_addfirst[-1]
                next_states.append(next_state)
                
                # 插值处理
                seq_len = len(action_seqs[i])
                time_indices = np.arange(0, seq_len+1) * ds
                target_indices = np.arange(1, seq_len * ds + 1)
                
                left_pos_pts = feedback_eefs_addfirst[:, :3]
                left_quat_pts = feedback_eefs_addfirst[:, 3:7]
                left_grip_pts = feedback_eefs_addfirst[:, 7]
                right_pos_pts = feedback_eefs_addfirst[:, 8:11]
                right_quat_pts = feedback_eefs_addfirst[:, 11:15]
                right_grip_pts = feedback_eefs_addfirst[:, 15]
                
                left_pos_interp, left_quat_interp, left_grip_interp = interpolate_pos_quat(
                    left_pos_pts, left_quat_pts, left_grip_pts, time_indices, target_indices
                )
                right_pos_interp, right_quat_interp, right_grip_interp = interpolate_pos_quat(
                    right_pos_pts, right_quat_pts, right_grip_pts, time_indices, target_indices
                )
                
                feedback_eefs_interpolated = np.concatenate([
                    left_pos_interp,
                    left_quat_interp,
                    left_grip_interp[:, None],
                    right_pos_interp,
                    right_quat_interp,
                    right_grip_interp[:, None]
                ], axis=1)
                feedback_eefs_list.append(feedback_eefs_interpolated)
            
            return feedback_eefs_list, np.array(next_states)
        else:
            # 序列长度不同，回退到逐个处理
            feedback_eefs_list = []
            next_states = []
            for i in range(batch_size):
                eef, next_state = self.get_feedback_eef_interpolated(
                    eef_model, observations[i], action_seqs[i], ds
                )
                feedback_eefs_list.append(eef)
                next_states.append(next_state)
            return feedback_eefs_list, np.array(next_states)
    
    def compute_diff_eef(self, desired_eefs, feedback_eefs, mode="mean"):
        error = self.error_calculator.compute_eef_error_from_array(desired_eefs, feedback_eefs)
        if mode == "mean":
            return np.mean(error)
        elif mode == "max":
            return np.max(error)
        else:
            return np.min(error)

    def _signal_handler(self, sig, frame):
        print("\nCtrl+C detected. Preparing to shutdown...")
        self._shutting_down = True


    def _get_common_metrics(self) -> dict[str, Any]:
        _, total_time = self._timer.reset()
        metrics = {
            "total_time": total_time,
            "env_steps": self.global_env_steps,
            "env_episodes": self.global_env_episodes,
        }

        return metrics

    def shutdown(self):
        if self.dynamics_eval_env:
            self.dynamics_eval_env.close()
        if self.agent_eval_env:
            self.agent_eval_env.close()

    def save_snapshot(self, best_ckpt=False):
        snapshot = self.work_dir / "snapshots" / f"{self.global_env_steps}_snapshot.pt"
        if best_ckpt:
            snapshot = self.work_dir / "snapshots" / f"best_snapshot.pt"
        snapshot.parent.mkdir(parents=True, exist_ok=True)
        keys_to_save = [
            # "obs_stats",
            # "act_stats",
            "_pretrain_step",
            "_main_loop_iterations",
            "_global_env_episode",
            "cfg",
        ]
        payload = {k: self.__dict__[k] for k in keys_to_save}
        payload["dynamics"] = self.dynamics.state_dict()
        with snapshot.open("wb") as f:
            torch.save(payload, f)
        latest_snapshot = self.work_dir / "snapshots" / "latest_snapshot.pt"
        shutil.copy(snapshot, latest_snapshot)

    def load_snapshot(self, path_to_snapshot_to_load=None):
        if path_to_snapshot_to_load is None:
            path_to_snapshot_to_load = (
                self.work_dir / "snapshots" / "latest_snapshot.pt"
            )
        else:
            path_to_snapshot_to_load = Path(path_to_snapshot_to_load)
        if not path_to_snapshot_to_load.is_file():
            raise ValueError(
                f"Provided file '{str(path_to_snapshot_to_load)}' is not a snapshot."
            )
        with path_to_snapshot_to_load.open("rb") as f:
            payload = torch.load(f, map_location="cpu", weights_only=False)
        self.dynamics.load_state_dict(payload.pop("dynamics"))
        
        for k, v in payload.items():
            self.__dict__[k] = v
    
    def load_dynamics_snapshot(self, path_to_dynamics_snapshot):
        path_to_dynamics_snapshot = Path(path_to_dynamics_snapshot)
        with path_to_dynamics_snapshot.open("rb") as f:
            payload = torch.load(f, map_location="cpu", weights_only=False)
        self.dynamics_cfg = payload.pop("cfg")
        self.dynamics = hydra.utils.instantiate(
            self.dynamics_cfg.method,
            device=self.device,
            observation_space=self.dynamics_eval_env.observation_space, # original space
            action_space=self.dynamics_eval_env.action_space, 
            use_var=self.dynamics_cfg.use_var,
            use_qpos_pred=getattr(self.dynamics_cfg, "use_qpos_pred", False),
            use_pixels=self.dynamics_cfg.use_pixels,
        )
        self.dynamics.load_state_dict(payload.pop("dynamics"))
        self.dynamics.train(False)
    
    def save_safe_iql_snapshot(self, agent, epoch, best_success_rate=0.0, best_ckpt=False):
        """Save Safe-IQL agent snapshot"""
        snapshot_dir = self.work_dir / "snapshots"
        snapshot_dir.mkdir(parents=True, exist_ok=True)
        
        if best_ckpt:
            snapshot = snapshot_dir / "safe_iql_best_snapshot.pt"
        else:
            snapshot = snapshot_dir / f"safe_iql_epoch_{epoch}_snapshot.pt"
        
        safe_iql_cfg = OmegaConf.select(self.cfg, 'safe_iql', default={})
        payload = {
            'cfg': self.cfg,
            'safe_iql_cfg': safe_iql_cfg,
            'agent_state': agent.state_dict(),
            'epoch': epoch,
            'best_success_rate': best_success_rate,
        }
        
        with snapshot.open("wb") as f:
            torch.save(payload, f)
        
        # Also update latest snapshot
        latest_snapshot = snapshot_dir / "safe_iql_latest_snapshot.pt"
        shutil.copy(snapshot, latest_snapshot)
    
    def load_safe_iql_snapshot(self, agent, path_to_snapshot=None, eval=False):
        """Load Safe-IQL agent snapshot and return epoch, best_success_rate"""
        if path_to_snapshot is None:
            snapshot_dir = self.work_dir / "snapshots"
            path_to_snapshot = snapshot_dir / "safe_iql_latest_snapshot.pt"
            if eval:
                path_to_snapshot = snapshot_dir / "safe_iql_best_snapshot.pt"
        else:
            path_to_snapshot = Path(path_to_snapshot)
        
        if not path_to_snapshot.exists():
            return 0, 0.0
        
        with path_to_snapshot.open("rb") as f:
            payload = torch.load(f, map_location=self.device, weights_only=False)
        
        agent.load_state_dict(payload['agent_state'])
        epoch = payload.get('epoch', 0)
        best_success_rate = payload.get('best_success_rate', 0.0)
        return epoch, best_success_rate

    def load_agent_snapshot(self, path_to_agent):
        path_to_agent = Path(path_to_agent)
        
        with path_to_agent.open("rb") as f:
            payload = torch.load(f, map_location="cpu", weights_only=False)
        self.agent_cfg = payload.pop("cfg")

        self.agent_eval_env = self.env_factory.make_eval_env(self.agent_cfg, self.work_dir)
        self.agent = hydra.utils.instantiate(
            self.agent_cfg.method,
            current_task=self.cfg.env.task_name,
            device=self.device,
            observation_space=self.agent_eval_env.observation_space,
            action_space=self.agent_eval_env.action_space,
            num_train_envs=self.agent_cfg.num_train_envs,
            replay_alpha=self.agent_cfg.replay.alpha,
            replay_beta=self.agent_cfg.replay.beta,
            frame_stack_on_channel=self.agent_cfg.frame_stack_on_channel,
        )
        self.agent.load_state_dict(payload.pop("agent"))
        if self.agent_cfg.load_ema:
            print("Load ema...")
            self.agent.actor.ema.load_state_dict(payload.pop("ema"))
    
    def _sample_safe_iql_data(self, save_path=None):
        """Sample safe transitions from demos and save to file"""
        assert hasattr(self.env_factory, '_demos'), (
            "env_factory._demos not found. Please ensure collect_or_fetch_demos and "
            "post_collect_or_fetch_demos are called first."
        )
        
        if save_path is None:
            save_path = self.work_dir / "safe_iql_data.pt"
        
        if os.path.exists(save_path):
            print(f"Safe-IQL data already exists at {save_path}. Loading...")
            return save_path
        
        safe_iql_cfg = OmegaConf.select(self.cfg, 'safe_iql', default={})
        k_low = OmegaConf.select(safe_iql_cfg, 'k_low', default=2)
        k_high = OmegaConf.select(safe_iql_cfg, 'k_high', default=4)
        epsilon = OmegaConf.select(safe_iql_cfg, 'epsilon', default=0.01)
        k_values = list(range(k_low, k_high + 1))  # [2, 3, 4] if low=2, high=4
        
        # Action sequence lengths: only even lengths from 20 to 30
        A_lengths = list(range(20, 31, 2))  # [20, 22, 24, 26, 28, 30]
        
        print(f"Sampling safe transitions with k in {k_values}, epsilon={epsilon}, A_lengths={A_lengths}...")
        safe_transitions = []
        
        # Statistics for each k value: {k: {'safe': count, 'unsafe': count, 'total': count}}
        k_stats = {k: {'safe': 0, 'unsafe': 0, 'total': 0} for k in k_values}
        
        # Filter successful demonstrations (same logic as load_demos_into_replay)
        all_demos = self.env_factory._demos
        demos = []
        for i, demo in enumerate(all_demos):
            successful = (demo[0][-1]["demo"] == 1)
            if successful:
                demos.append(demo)
            else:
                print(f"Skipping failed demonstration {i}")
        
        print(f"Using {len(demos)} successful demos out of {len(all_demos)} total demos")
        
        # Process each demo
        for demo_idx, demo in tqdm(enumerate(demos), desc="Processing demos", total=len(demos)):
            # Extract observations and actions from demo
            # Demo format: first element is (observation, info), rest are (observation, reward, term, trunc, info)
            # Action is in info["demo_action"] of the next step
            observations = []
            actions = []
            
            # First observation
            obs, info = demo[0]
            observations.append(obs)
            
            # Extract actions and subsequent observations
            for i in range(1, len(demo)):
                obs, reward, term, trunc, info = demo[i]
                observations.append(obs)
                if "demo_action" in info:
                    actions.append(info["demo_action"])
            
            # Convert observations to low_dim_state format
            low_dim_states = np.array([np.concatenate([obs['proprioception'], obs['proprioception_floating_base'], obs['proprioception_grippers']]) for obs in observations])
            eef_states = np.array([obs['eef'] for obs in observations])
            
            if len(actions) == 0:
                continue  # Skip if no actions
            
            rgb_names = self.dynamics_cfg.env.cameras
            
            # Batch processing: outer loop over A_lengths, inner batch processing over states
            for A_len in A_lengths:
                # Collect all valid (state_idx, A) pairs for this A_len
                valid_states = []
                valid_state_indices = []
                valid_actions = []
                
                for state_idx in range(len(low_dim_states)):
                    if state_idx + A_len > len(actions):
                        continue  # Not enough future actions
                    
                    A = np.array(actions[state_idx:state_idx + A_len])  # (A_len, action_dim)
                    s_env = low_dim_states[state_idx:state_idx+1]
                    eef = eef_states[state_idx:state_idx+1]
                    
                    # Create observation dict
                    observation = {
                        'low_dim_state': s_env,
                        'eef': eef
                    }
                    for name in rgb_names:
                        observation[f'rgb_{name}'] = np.zeros((1, 3, 84, 84), dtype=np.uint8)
                    
                    valid_states.append(observation)
                    valid_state_indices.append(state_idx)
                    valid_actions.append(A)
                
                if len(valid_states) == 0:
                    continue  # No valid states for this A_len
                
                # Batch compute baseline (using k_low) for all valid states
                baseline_action_seqs = []
                valid_baseline_mask = []
                for A in valid_actions:
                    baseline_actions = A[k_low-1::k_low]
                    baseline_action_seqs.append(baseline_actions)
                    valid_baseline_mask.append(len(baseline_actions) > 0)
                
                if not any(valid_baseline_mask):
                    continue
                
                # Filter to only valid baselines
                valid_states_baseline = [obs for i, obs in enumerate(valid_states) if valid_baseline_mask[i]]
                valid_state_indices_baseline = [idx for i, idx in enumerate(valid_state_indices) if valid_baseline_mask[i]]
                valid_actions_baseline = [A for i, A in enumerate(valid_actions) if valid_baseline_mask[i]]
                baseline_action_seqs = [seq for seq in baseline_action_seqs if len(seq) > 0]
                
                # Batch compute baselines
                baseline_eefs_list, baseline_next_states = self.get_feedback_eef_interpolated_batch(
                    self.dynamics, valid_states_baseline, baseline_action_seqs, k_low
                )
                
                # For each k value, batch compute feedbacks and compare
                for k in k_values:
                    # Prepare downsampled action sequences
                    ds_action_seqs = []
                    valid_ds_mask = []
                    for A in valid_actions_baseline:
                        ds_actions = A[k-1::k]
                        ds_action_seqs.append(ds_actions)
                        valid_ds_mask.append(len(ds_actions) > 0)
                    
                    if not any(valid_ds_mask):
                        continue
                    
                    # Filter to only valid downsampled sequences
                    valid_states_ds = [obs for i, obs in enumerate(valid_states_baseline) if valid_ds_mask[i]]
                    valid_state_indices_ds = [idx for i, idx in enumerate(valid_state_indices_baseline) if valid_ds_mask[i]]
                    baseline_eefs_list_ds = [eef for i, eef in enumerate(baseline_eefs_list) if valid_ds_mask[i]]
                    baseline_next_states_ds = baseline_next_states[valid_ds_mask]
                    ds_action_seqs = [seq for seq in ds_action_seqs if len(seq) > 0]
                    
                    # Batch compute feedbacks
                    feedback_eefs_list, feedback_next_states = self.get_feedback_eef_interpolated_batch(
                        self.dynamics, valid_states_ds, ds_action_seqs, k
                    )
                    
                    # Compare each feedback with its baseline
                    for i, (baseline_eef, feedback_eef) in enumerate(zip(baseline_eefs_list_ds, feedback_eefs_list)):
                        min_length = min(len(baseline_eef), len(feedback_eef))
                        error = self.compute_diff_eef(baseline_eef[:min_length], feedback_eef[:min_length], mode="max")
                        # Update statistics
                        k_stats[k]['total'] += 1
                        if error <= epsilon or k == k_low:
                            k_stats[k]['safe'] += 1
                            reward = k / k_high
                        else:
                            k_stats[k]['unsafe'] += 1
                            reward = -1.0
                        state_idx = valid_state_indices_ds[i]
                        eef = eef_states[state_idx]
                        s_env_next = feedback_next_states[i]
                        safe_transitions.append({
                            's_env': eef,
                            'A': ds_action_seqs[i],
                            'k': k,
                            'r': reward,
                            's_env_next': s_env_next,
                        })
        # Save data
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        with save_path.open("wb") as f:
            torch.save({
                'transitions': safe_transitions,
                'k_low': k_low,
                'k_high': k_high,
            }, f)
        print(f"Saved {len(safe_transitions)} safe transitions to {save_path}")
        
        # Print statistics for each k value
        print("\n=== Safety Statistics by k value ===")
        for k in sorted(k_values):
            stats = k_stats[k]
            total = stats['total']
            safe = stats['safe']
            unsafe = stats['unsafe']
            if total > 0:
                safe_rate = safe / total * 100
                unsafe_rate = unsafe / total * 100
                print(f"k={k}: Total={total}, Safe={safe} ({safe_rate:.2f}%), Unsafe={unsafe} ({unsafe_rate:.2f}%)")
            else:
                print(f"k={k}: Total=0, Safe=0, Unsafe=0")
        print("=" * 40 + "\n")
        
        return save_path
    
    def _train_safe_iql_scheduling_policy(self, data_path=None, eval=False):
        """Train the Safe-IQL scheduling policy offline from saved transitions.

        The policy chooses a down-sampling factor ``k`` for an action chunk.
        Training is IQL-style:

        * the value network ``V(s_env)`` is fit by expectile regression to a
          soft-updated target Q-network, and
        * the Q-network ``Q(s_env, A')`` is fit to ``r + gamma * V(s_env_next)``.

        At inference ``SafeIQLAgent.predict_k`` picks the ``k`` whose
        down-sampled chunk ``A[k-1::k]`` has the highest Q value.

        Hyper-parameters are read from the optional ``safe_iql`` config group.
        These keys are read: hidden_dims, expectile, learning_rate, num_epochs,
        batch_size, gamma, tau, grad_clip_norm, save_checkpoint,
        checkpoint_every_n_epochs, save_best_checkpoint, eval_every_n_epochs,
        eval_episodes and log_every_n_epochs.

        Args:
            data_path: File written with ``torch.save`` that holds
                ``transitions``, ``k_low`` and ``k_high``. Defaults to
                ``self.work_dir / "safe_iql_data.pt"``.
            eval: When True, load a checkpoint via ``self.load_safe_iql_snapshot``,
                evaluate it with media recording, and return the metrics
                without training.

        Returns:
            The evaluation metric dict when ``eval`` is True. ``None`` when the
            data file contains no transitions. Otherwise the trained agent.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

        def safe_iql_weight_init(m):
            """Orthogonal weights and zero bias for a hidden ``nn.Linear`` layer."""
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight.data)
                if hasattr(m.bias, "data"):
                    m.bias.data.fill_(0.0)

        def safe_iql_output_init(m):
            """Orthogonal weights scaled by 1e-3 and zero bias for an output layer.

            The tiny weights keep the initial V/Q predictions close to zero.
            """
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight.data)
                m.weight.data.mul_(1e-3)
                if hasattr(m.bias, "data"):
                    m.bias.data.fill_(0.0)

        class SimpleMLP(nn.Module):
            """MLP: [Linear -> (LayerNorm) -> activation] per hidden dim, then a Linear head."""

            def __init__(self, input_dim, hidden_dims, output_dim, activation=nn.ReLU, use_layernorm=True):
                super().__init__()
                layers = []
                prev_dim = input_dim
                for hidden_dim in hidden_dims:
                    layers.append(nn.Linear(prev_dim, hidden_dim))
                    if use_layernorm:
                        layers.append(nn.LayerNorm(hidden_dim))
                    layers.append(activation())
                    prev_dim = hidden_dim

                output_layer = nn.Linear(prev_dim, output_dim)
                layers.append(output_layer)
                self.net = nn.Sequential(*layers)

                # Hidden Linear layers: orthogonal init; output layer: tiny weights.
                for module in self.net:
                    if module is output_layer:
                        safe_iql_output_init(module)
                    elif isinstance(module, nn.Linear):
                        safe_iql_weight_init(module)

            def forward(self, x):
                return self.net(x)

        class QNetworkRNN(nn.Module):
            """Q(s_env, A'): an MLP state encoder plus a GRU summary of a variable-length action sequence."""

            def __init__(self, env_dim, action_dim, hidden_dims, rnn_hidden_size=128, num_rnn_layers=2):
                super().__init__()
                self.rnn_hidden_size = rnn_hidden_size
                self.env_encoder = SimpleMLP(env_dim, hidden_dims, rnn_hidden_size)
                self.rnn = nn.GRU(
                    input_size=action_dim,
                    hidden_size=rnn_hidden_size,
                    num_layers=num_rnn_layers,
                    batch_first=True,
                )
                # Maps [state encoding, RNN summary] to a scalar Q value.
                self.q_head = SimpleMLP(rnn_hidden_size * 2, hidden_dims, 1)

            def forward(self, s_env, A_prime_list):
                """
                Args:
                    s_env: (batch, env_dim) state tensor.
                    A_prime_list: list of ``batch`` tensors, each (seq_len, action_dim).
                        seq_len may differ between entries and may be 0.
                Returns:
                    q_values: (batch, 1)
                """
                batch_size = s_env.shape[0]
                s_encoded = self.env_encoder(s_env)  # (batch, rnn_hidden_size)

                # Empty sequences cannot be packed, so they keep an all-zero RNN summary.
                rnn_outs = torch.zeros(batch_size, self.rnn_hidden_size, device=s_env.device)
                valid_indices = [i for i, a_seq in enumerate(A_prime_list) if a_seq.shape[0] > 0]
                if valid_indices:
                    valid_seqs = [A_prime_list[i] for i in valid_indices]
                    padded_seqs = pad_sequence(valid_seqs, batch_first=True)  # (valid_batch, max_len, action_dim)
                    # pack_padded_sequence requires lengths on CPU with int64 dtype.
                    lengths = torch.tensor([seq.shape[0] for seq in valid_seqs], dtype=torch.int64, device='cpu')
                    packed_seqs = pack_padded_sequence(padded_seqs, lengths, batch_first=True, enforce_sorted=False)
                    # With enforce_sorted=False the GRU returns h_n in the original input order.
                    _, hidden = self.rnn(packed_seqs)  # (num_layers, valid_batch, rnn_hidden_size)
                    index = torch.tensor(valid_indices, dtype=torch.long, device=s_env.device)
                    rnn_outs[index] = hidden[-1]  # last layer's final hidden state

                combined = torch.cat([s_encoded, rnn_outs], dim=-1)  # (batch, rnn_hidden_size * 2)
                return self.q_head(combined)

        class SafeIQLAgent:
            """Holds the V, Q and target-Q networks and their optimizers; selects k by argmax Q."""

            def __init__(self, env_dim, action_dim, k_values, hidden_dims=(256, 256),
                         expectile=0.7, device='cuda', learning_rate=3e-4, tau=0.005):
                self.device = device
                self.k_values = k_values  # candidate down-sampling factors
                self.expectile = expectile
                self.tau = tau  # soft-update coefficient for the target Q-network

                # V_phi(s_env)
                self.value_net = SimpleMLP(env_dim, hidden_dims, 1).to(device)
                # Q_theta(s_env, A'), with a GRU over the variable-length sequence A'
                self.q_net = QNetworkRNN(env_dim, action_dim, hidden_dims).to(device)
                # Frozen, soft-updated copy of Q used as the V regression target
                self.target_q_net = QNetworkRNN(env_dim, action_dim, hidden_dims).to(device)
                self.target_q_net.load_state_dict(self.q_net.state_dict())
                for param in self.target_q_net.parameters():
                    param.requires_grad = False

                self.value_opt = torch.optim.Adam(self.value_net.parameters(), lr=learning_rate)
                self.q_opt = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)

            def update_target_q_net(self):
                """Soft-update the target Q network towards the online Q network."""
                utils.soft_update_params(self.q_net, self.target_q_net, self.tau)

            def state_dict(self):
                """Return everything needed to resume training."""
                return {
                    'value_net': self.value_net.state_dict(),
                    'q_net': self.q_net.state_dict(),
                    'target_q_net': self.target_q_net.state_dict(),
                    'value_opt': self.value_opt.state_dict(),
                    'q_opt': self.q_opt.state_dict(),
                    'k_values': self.k_values,
                    'expectile': self.expectile,
                    'tau': self.tau,
                }

            def load_state_dict(self, state_dict):
                """Restore from ``state_dict()`` output.

                Older checkpoints without a target network are accepted; in that
                case the target is copied from Q.
                """
                self.value_net.load_state_dict(state_dict['value_net'])
                self.q_net.load_state_dict(state_dict['q_net'])
                if 'target_q_net' in state_dict:
                    self.target_q_net.load_state_dict(state_dict['target_q_net'])
                else:
                    self.target_q_net.load_state_dict(self.q_net.state_dict())
                self.value_opt.load_state_dict(state_dict['value_opt'])
                self.q_opt.load_state_dict(state_dict['q_opt'])
                self.k_values = state_dict.get('k_values', self.k_values)
                self.expectile = state_dict.get('expectile', self.expectile)
                self.tau = state_dict.get('tau', self.tau)

            @torch.no_grad()
            def predict_k(self, s_env, A):
                """Return the k whose down-sampled chunk ``A[k-1::k]`` has the highest Q.

                Args:
                    s_env: numpy state of a single environment (flattened to one row).
                    A: numpy action chunk, (chunk_len, action_dim).
                Returns:
                    A 1-element tensor on ``self.device`` holding the selected k.
                """
                s_env_tensor = torch.from_numpy(s_env).float().to(self.device).reshape(1, -1)
                q_values_list = []
                for k in self.k_values:
                    A_prime = A[k-1::k]  # same slicing the evaluation loop executes
                    if len(A_prime) == 0:
                        # Effectively excluded from the argmax unless every k is empty.
                        q_vals = torch.tensor([[-1e10]], device=self.device)
                    else:
                        A_prime_tensor = torch.from_numpy(A_prime).float().to(self.device)
                        q_vals = self.q_net(s_env_tensor, [A_prime_tensor])  # (1, 1)
                    q_values_list.append(q_vals)
                q_matrix = torch.cat(q_values_list, dim=1)  # (1, num_k_values)
                k_indices = q_matrix.argmax(dim=1)
                return torch.tensor([self.k_values[idx.item()] for idx in k_indices], device=self.device)

        def _stack_states(transitions, key):
            """Stack ``t[key]`` of each transition into a (len(transitions), -1) float tensor.

            Uses reshape rather than squeeze so a batch of one keeps its batch dimension.
            """
            stacked = torch.from_numpy(np.array([t[key] for t in transitions])).float()
            return stacked.reshape(len(transitions), -1).to(self.device)

        def _action_seqs(transitions):
            """Convert each transition's variable-length action sequence 'A' to a float tensor."""
            return [torch.from_numpy(t['A']).float().to(self.device) for t in transitions]

        # ---- Load data -------------------------------------------------------
        if data_path is None:
            data_path = self.work_dir / "safe_iql_data.pt"
        data_path = Path(data_path)

        if not data_path.exists():
            raise FileNotFoundError(f"Data file not found: {data_path}. Please run _sample_safe_iql_data first.")

        with data_path.open("rb") as f:
            data = torch.load(f, map_location="cpu", weights_only=False)
        safe_transitions = data['transitions']
        k_low = data['k_low']
        k_high = data['k_high']
        k_values = list(range(k_low, k_high + 1))

        if len(safe_transitions) == 0:
            print("Warning: No safe transitions found! Cannot train.")
            return

        # ---- Dimensions from the environment ---------------------------------
        sample_obs = self.dynamics_eval_env.observation_space.sample()
        if isinstance(sample_obs, dict):
            env_dim = sample_obs['eef'].shape[-1]
        else:
            env_dim = sample_obs.shape[-1]
        action_dim = self.dynamics_eval_env.action_space.shape[-1]

        # ---- Config ------------------------------------------------------------
        safe_iql_cfg = OmegaConf.select(self.cfg, 'safe_iql', default=None)
        if safe_iql_cfg is None:
            # OmegaConf.select only accepts OmegaConf containers; a plain ``{}``
            # default would make every lookup below raise.
            safe_iql_cfg = OmegaConf.create({})
        hidden_dims = list(OmegaConf.select(safe_iql_cfg, 'hidden_dims', default=[256, 256]))
        expectile = OmegaConf.select(safe_iql_cfg, 'expectile', default=0.8)
        learning_rate = OmegaConf.select(safe_iql_cfg, 'learning_rate', default=3e-4)
        num_epochs = OmegaConf.select(safe_iql_cfg, 'num_epochs', default=100)
        batch_size = OmegaConf.select(safe_iql_cfg, 'batch_size', default=64)
        gamma = OmegaConf.select(safe_iql_cfg, 'gamma', default=0.99)
        tau = OmegaConf.select(safe_iql_cfg, 'tau', default=0.005)
        grad_clip_norm = OmegaConf.select(safe_iql_cfg, 'grad_clip_norm', default=1.0)
        save_checkpoint = OmegaConf.select(safe_iql_cfg, 'save_checkpoint', default=True)
        checkpoint_every_n_epochs = OmegaConf.select(safe_iql_cfg, 'checkpoint_every_n_epochs', default=10)
        save_best_checkpoint = OmegaConf.select(safe_iql_cfg, 'save_best_checkpoint', default=True)
        eval_every_n_epochs = OmegaConf.select(safe_iql_cfg, 'eval_every_n_epochs', default=10)
        eval_episodes = OmegaConf.select(safe_iql_cfg, 'eval_episodes', default=50)
        log_every_n_epochs = OmegaConf.select(safe_iql_cfg, 'log_every_n_epochs', default=10)

        agent = SafeIQLAgent(
            env_dim=env_dim,
            action_dim=action_dim,
            k_values=k_values,
            hidden_dims=hidden_dims,
            expectile=expectile,
            device=self.device,
            learning_rate=learning_rate,
            tau=tau,
        )

        # Resume from an existing checkpoint if the loader finds one.
        start_epoch, best_success_rate = self.load_safe_iql_snapshot(agent, eval=eval)
        if eval:
            return self._eval_safe_iql_policy(agent, num_episodes=eval_episodes, record_media=True)

        if start_epoch > 0:
            metrics = {
                'safe_iql_resume_epoch': start_epoch,
                'safe_iql_best_success_rate': best_success_rate,
            }
            self.logger.log_metrics(metrics, start_epoch, prefix="eval")

        # ---- Train / test split ----------------------------------------------
        transitions_by_k = {k: [] for k in k_values}
        for t in safe_transitions:
            transitions_by_k[t['k']].append(t)

        # Hold out 50 random samples per k for Q/V monitoring. A k with 50 or
        # fewer samples is used only for monitoring and excluded from training.
        test_samples_by_k = {}
        test_sample_size_per_k = 50
        for k in k_values:
            if len(transitions_by_k[k]) > test_sample_size_per_k:
                k_indices = np.arange(len(transitions_by_k[k]))
                np.random.shuffle(k_indices)
                test_indices = k_indices[:test_sample_size_per_k]
                test_samples_by_k[k] = [transitions_by_k[k][i] for i in test_indices]
                train_indices = k_indices[test_sample_size_per_k:]
                transitions_by_k[k] = [transitions_by_k[k][i] for i in train_indices]
            else:
                test_samples_by_k[k] = transitions_by_k[k].copy()
                transitions_by_k[k] = []

        safe_transitions = []
        for k in k_values:
            safe_transitions.extend(transitions_by_k[k])

        # ---- Training loop ----------------------------------------------------
        for epoch in tqdm(range(start_epoch, num_epochs), desc="Training Safe-IQL"):
            epoch_vf_loss = 0.0
            epoch_q_loss = 0.0
            num_batches = 0

            all_indices = np.arange(len(safe_transitions))
            np.random.shuffle(all_indices)

            for batch_start in range(0, len(safe_transitions), batch_size):
                batch_indices = all_indices[batch_start:batch_start + batch_size]
                batch_transitions = [safe_transitions[i] for i in batch_indices]

                s_env_batch = _stack_states(batch_transitions, 's_env')
                A_batch = _action_seqs(batch_transitions)
                r_batch = torch.tensor([t['r'] for t in batch_transitions]).float().to(self.device)
                s_env_next_batch = _stack_states(batch_transitions, 's_env_next')

                # 1. Value network: L_V = E[L_tau(Q_target(s, f(A,k)) - V(s))],
                #    an asymmetric (expectile) squared loss against the target Q.
                with torch.no_grad():
                    q_values = agent.target_q_net(s_env_batch, A_batch)

                v_values = agent.value_net(s_env_batch)
                vf_err = q_values - v_values
                vf_weight = torch.where(vf_err > 0, agent.expectile, 1 - agent.expectile)
                vf_loss = (vf_weight * (vf_err ** 2)).mean()

                agent.value_opt.zero_grad()
                vf_loss.backward()
                torch.nn.utils.clip_grad_norm_(agent.value_net.parameters(), max_norm=grad_clip_norm)
                agent.value_opt.step()

                # 2. Q network: L_Q = E[(r(k) + gamma * V(s') - Q(s, f(A,k)))^2]
                with torch.no_grad():
                    v_next = agent.value_net(s_env_next_batch)
                    target_q = r_batch.unsqueeze(1) + gamma * v_next

                q_values = agent.q_net(s_env_batch, A_batch)
                q_loss = F.mse_loss(q_values, target_q)

                agent.q_opt.zero_grad()
                q_loss.backward()
                torch.nn.utils.clip_grad_norm_(agent.q_net.parameters(), max_norm=grad_clip_norm)
                agent.q_opt.step()

                agent.update_target_q_net()

                epoch_vf_loss += vf_loss.item()
                epoch_q_loss += q_loss.item()
                num_batches += 1

            avg_vf_loss = epoch_vf_loss / num_batches if num_batches > 0 else 0.0
            avg_q_loss = epoch_q_loss / num_batches if num_batches > 0 else 0.0
            metrics = {
                'safe_iql_vf_loss': avg_vf_loss,
                'safe_iql_q_loss': avg_q_loss,
            }
            self.logger.log_metrics(metrics, epoch + 1, prefix="train")

            # Periodically log mean V and per-k mean Q on the held-out samples.
            if (epoch + 1) % log_every_n_epochs == 0:
                with torch.no_grad():
                    all_test_samples = []
                    for k in k_values:
                        all_test_samples.extend(test_samples_by_k.get(k, []))

                    if len(all_test_samples) > 0:
                        s_all_sample = _stack_states(all_test_samples, 's_env')
                        v_mean = agent.value_net(s_all_sample).mean().item()
                    else:
                        v_mean = 0.0

                    metrics = {
                        'safe_iql_v_mean': v_mean,
                        'safe_iql_epoch': epoch + 1,
                    }
                    for k in k_values:
                        k_test_samples = test_samples_by_k.get(k, [])
                        if len(k_test_samples) > 0:
                            s_k_sample = _stack_states(k_test_samples, 's_env')
                            A_k_sample = _action_seqs(k_test_samples)
                            metrics[f'safe_iql_q_mean_k{k}'] = agent.q_net(s_k_sample, A_k_sample).mean().item()
                        else:
                            metrics[f'safe_iql_q_mean_k{k}'] = 0.0

                    self.logger.log_metrics(metrics, epoch + 1, prefix="eval")

            # Periodic policy evaluation and checkpointing.
            if (epoch + 1) % eval_every_n_epochs == 0:
                eval_metrics = self._eval_safe_iql_policy(agent, num_episodes=eval_episodes, record_media=False)
                success_rate = eval_metrics['safe_iql_success_rate']
                avg_steps = eval_metrics['safe_iql_avg_steps']

                metrics = {
                    'safe_iql_eval_success_rate': success_rate,
                    'safe_iql_eval_avg_steps': avg_steps,
                }
                for k in k_values:
                    key = f'safe_iql_k{k}_selection_ratio'
                    if key in eval_metrics:
                        metrics[key] = eval_metrics[key]
                self.logger.log_metrics(metrics, epoch + 1, prefix="eval")

                if save_checkpoint:
                    self.save_safe_iql_snapshot(agent, epoch + 1, best_success_rate, best_ckpt=False)

                    if save_best_checkpoint and success_rate > best_success_rate:
                        best_success_rate = success_rate
                        self.save_safe_iql_snapshot(agent, epoch + 1, best_success_rate, best_ckpt=True)
                        metrics = {
                            'safe_iql_best_success_rate': best_success_rate,
                            'safe_iql_best_epoch': epoch + 1,
                        }
                        self.logger.log_metrics(metrics, epoch + 1, prefix="train")
            elif save_checkpoint and (epoch + 1) % checkpoint_every_n_epochs == 0:
                # Checkpoint without evaluation when the two cadences differ.
                self.save_safe_iql_snapshot(agent, epoch + 1, best_success_rate, best_ckpt=False)

        if save_checkpoint:
            self.save_safe_iql_snapshot(agent, num_epochs, best_success_rate, best_ckpt=False)

        return agent
    
    def _eval_safe_iql_policy(self, agent, num_episodes=50, record_media=False):
        """Roll out the base agent while Safe-IQL picks the chunk down-sampling factor k.

        For each action chunk produced by ``self.agent``, ``agent.predict_k`` picks
        ``k``. Only ``chunk[k-1::k]`` is then executed in a non-chunking
        evaluation env. Execution stops as soon as the episode terminates or
        is truncated. An episode counts as a success when its summed reward
        is > 0.

        Args:
            agent: Trained Safe-IQL agent (uses ``k_values`` and ``predict_k``).
            num_episodes: Number of evaluation episodes.
            record_media: When True, save one mp4 per episode under
                ``self.work_dir/th<safe_iql.epsilon>``. Frames are only captured
                when the env has a ``render_mode``.

        Returns:
            dict with ``safe_iql_success_rate``, ``safe_iql_avg_steps`` (mean env
            steps over successful episodes, 0 if none), and
            ``safe_iql_k{k}_selection_ratio`` for every k in ``agent.k_values``.

        Raises:
            RuntimeError: If the selected k leaves no action of the chunk to execute.
        """
        media_dir = None
        if record_media:
            media_dir = os.path.join(self.work_dir, f"th{self.cfg.safe_iql.epsilon}")
            os.makedirs(media_dir, exist_ok=True)

        # Evaluation env without the action-chunk wrapper: one env step per action.
        wrapped_env = self.env_factory._wrap_env(
            self.env_factory._create_env(self.agent_cfg, self.work_dir),
            self.cfg,
            demo_env=False,
            train=False,
            chunk_env=False
        )

        k_selection_counts = {k: 0 for k in agent.k_values}
        total_selections = 0
        num_steps_list = []
        suc_list = []

        try:
            for i in tqdm(range(num_episodes), desc="Testing Safe-IQL Policy"):
                observation, _ = wrapped_env.reset()
                self.agent.reset(self.main_loop_iterations, [0])
                done = False
                rewards = 0
                num_steps = 0
                frames = []
                while not done:
                    with torch.no_grad(), utils.eval_mode(self.agent):
                        torch_observations = {
                            k: torch.from_numpy(v).unsqueeze(0).to(self.device) for k, v in observation.items()
                        }
                        action = self.agent.act(
                            torch_observations, self.main_loop_iterations, eval_mode=True
                        )
                        action = action[0].cpu().detach().numpy()

                    # Safe-IQL scores every candidate k from the end-effector state and the chunk.
                    s_env = observation['eef'].squeeze()
                    with torch.no_grad():
                        k_selected = agent.predict_k(s_env, action).item()

                    if k_selected in k_selection_counts:
                        k_selection_counts[k_selected] += 1
                    total_selections += 1

                    ds_actions = action[k_selected-1::k_selected]
                    if len(ds_actions) == 0:
                        # Otherwise the env is never stepped and the loop cannot progress.
                        raise RuntimeError(
                            f"Safe-IQL selected k={k_selected} but the action chunk has only "
                            f"{len(action)} steps; nothing to execute."
                        )

                    for ds_action in ds_actions:
                        observation, reward, termination, truncation, _ = wrapped_env.step(ds_action)
                        rewards += reward
                        num_steps += 1
                        if record_media and wrapped_env.render_mode:
                            frame = wrapped_env.render()
                            frame = utils.put_text(frame, f"{num_steps},{k_selected}", font_size=1, resize=False)
                            frames.append(frame)
                        done = bool(termination | truncation)
                        if done:
                            # Gymnasium requires reset() after terminated/truncated;
                            # stepping further is undefined and would skew rewards/steps.
                            break

                if record_media and frames:
                    imageio.mimsave(
                        os.path.join(media_dir, f"{i}_{rewards > 0}.mp4"),
                        np.array(frames),
                        fps=wrapped_env.unwrapped._control_frequency
                    )
                suc_list.append(rewards > 0)
                if rewards > 0:
                    num_steps_list.append(num_steps)
        finally:
            wrapped_env.close()

        k_selection_ratios = {
            f"safe_iql_k{k}_selection_ratio": (
                k_selection_counts[k] / total_selections if total_selections > 0 else 0.0
            )
            for k in agent.k_values
        }

        return {
            "safe_iql_success_rate": sum(suc_list) / num_episodes if num_episodes > 0 else 0.0,
            "safe_iql_avg_steps": np.mean(num_steps_list) if num_steps_list else 0,
            **k_selection_ratios,
        }
        
