import os
import re
import sapien.core as sapien
from sapien.render import clear_cache as sapien_clear_cache
from sapien.utils.viewer import Viewer
import numpy as np
import gymnasium as gym
import pdb
import toppra as ta
import json
import transforms3d as t3d
from collections import OrderedDict
import torch, random

from .utils import *
import math
from .robot import Robot
from .camera import Camera

from copy import deepcopy
import subprocess
from pathlib import Path
import trimesh
import imageio
import glob


from ._GLOBAL_CONFIGS import *

from typing import Optional, Literal

current_file_path = os.path.abspath(__file__)
parent_directory = os.path.dirname(current_file_path)

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

class Base_Task(gym.Env):

    def __init__(self):
        pass

    # =========================================================== Init Task Env ===========================================================
    def _init_task_env_(self, table_xy_bias=[0, 0], table_height_bias=0, **kwags):
        """
        Initialization TODO
        - `self.FRAME_IDX`: The index of the file saved for the current scene.
        - `self.fcitx5-configtool`: Left gripper pose (close <=0, open >=0.4).
        - `self.ep_num`: Episode ID.
        - `self.task_name`: Task name.
        - `self.save_dir`: Save path.`
        - `self.left_original_pose`: Left arm original pose.
        - `self.right_original_pose`: Right arm original pose.
        - `self.left_arm_joint_id`: [6,14,18,22,26,30].
        - `self.right_arm_joint_id`: [7,15,19,23,27,31].
        - `self.render_fre`: Render frequency.
        """
        super().__init__()
        ta.setup_logging("CRITICAL")  # hide logging
        np.random.seed(kwags.get("seed", 0))
        torch.manual_seed(kwags.get("seed", 0))
        # random.seed(kwags.get('seed', 0))

        self.FRAME_IDX = 0
        self.task_name = kwags.get("task_name")
        self.save_dir = kwags.get("save_path", "data")
        self.ep_num = kwags.get("now_ep_num", 0)
        self.render_freq = kwags.get("render_freq", 10)
        self.data_type = kwags.get("data_type", None)
        self.save_data = kwags.get("save_data", False)
        self.dual_arm = kwags.get("dual_arm", True)
        self.eval_mode = kwags.get("eval_mode", False)

        self.need_topp = True  # TODO

        # Random
        random_setting = kwags.get("domain_randomization")
        self.random_background = random_setting.get("random_background", False)
        self.cluttered_table = random_setting.get("cluttered_table", False)
        self.clean_background_rate = random_setting.get("clean_background_rate", 1)
        self.random_head_camera_dis = random_setting.get("random_head_camera_dis", 0)
        self.random_table_height = random_setting.get("random_table_height", 0)
        self.random_light = random_setting.get("random_light", False)
        self.crazy_random_light_rate = random_setting.get("crazy_random_light_rate", 0)
        self.crazy_random_light = (0 if not self.random_light else np.random.rand() < self.crazy_random_light_rate)
        self.random_embodiment = random_setting.get("random_embodiment", False)  # TODO

        self.file_path = []
        self.plan_success = True
        self.step_lim = None
        self.fix_gripper = False
        self.setup_scene()

        self.left_js = None
        self.right_js = None
        self.raw_head_pcl = None
        self.real_head_pcl = None
        self.real_head_pcl_color = None

        self.now_obs = {}
        self.take_action_cnt = 0
        self.eval_video_path = kwags.get("eval_video_save_dir", None)

        self.save_freq = kwags.get("save_freq")
        self.world_pcd = None

        self.size_dict = list()
        self.cluttered_objs = list()
        self.prohibited_area = list()  # [x_min, y_min, x_max, y_max]
        self.record_cluttered_objects = list()  # record cluttered objects info

        self.eval_success = False
        self.table_z_bias = (np.random.uniform(low=-self.random_table_height, high=0) + table_height_bias)  # TODO
        self.need_plan = kwags.get("need_plan", True)
        self.left_joint_path = kwags.get("left_joint_path", [])
        self.right_joint_path = kwags.get("right_joint_path", [])
        self.left_cnt = 0
        self.right_cnt = 0

        self.instruction = None  # for Eval

        self.create_table_and_wall(table_xy_bias=table_xy_bias, table_height=0.74)
        self.load_robot(**kwags)
        self.load_camera(**kwags)
        self.robot.move_to_homestate()

        render_freq = self.render_freq
        self.render_freq = 0
        self.together_open_gripper(save_freq=None)
        self.render_freq = render_freq

        self.robot.set_origin_endpose()
        self.load_actors()

        if self.cluttered_table:
            self.get_cluttered_table()

        is_stable, unstable_list = self.check_stable()
        if not is_stable:
            raise UnStableError(
                f'Objects is unstable in seed({kwags.get("seed", 0)}), unstable objects: {", ".join(unstable_list)}')

        if self.eval_mode:
            with open(os.path.join(CONFIGS_PATH, "_eval_step_limit.yml"), "r") as f:
                try:
                    data = yaml.safe_load(f)
                    self.step_lim = data[self.task_name]
                except:
                    print(f"{self.task_name} not in step limit file, set to 1000")
                    self.step_lim = 1000

        # info
        self.info = dict()
        self.info["cluttered_table_info"] = self.record_cluttered_objects
        self.info["texture_info"] = {
            "wall_texture": self.wall_texture,
            "table_texture": self.table_texture,
        }
        self.info["info"] = {}

        self.stage_success_tag = False

    def check_stable(self):
        actors_list, actors_pose_list = [], []
        for actor in self.scene.get_all_actors():
            actors_list.append(actor)

        def get_sim(p1, p2):
            return np.abs(cal_quat_dis(p1.q, p2.q) * 180)

        is_stable, unstable_list = True, []

        def check(times):
            nonlocal self, is_stable, actors_list, actors_pose_list
            for _ in range(times):
                self.scene.step()
                for idx, actor in enumerate(actors_list):
                    actors_pose_list[idx].append(actor.get_pose())

            for idx, actor in enumerate(actors_list):
                final_pose = actors_pose_list[idx][-1]
                for pose in actors_pose_list[idx][-200:]:
                    if get_sim(final_pose, pose) > 3.0:
                        is_stable = False
                        unstable_list.append(actor.get_name())
                        break

        is_stable = True
        for _ in range(2000):
            self.scene.step()
        for idx, actor in enumerate(actors_list):
            actors_pose_list.append([actor.get_pose()])
        check(500)
        return is_stable, unstable_list

    def play_once(self):
        pass

    def check_success(self):
        pass

    def setup_scene(self, **kwargs):
        """
        Set the scene
            - Set up the basic scene: light source, viewer.
        """
        self.engine = sapien.Engine()
        # declare sapien renderer
        from sapien.render import set_global_config

        set_global_config(max_num_materials=50000, max_num_textures=50000)
        self.renderer = sapien.SapienRenderer()
        # give renderer to sapien sim
        self.engine.set_renderer(self.renderer)

        sapien.render.set_camera_shader_dir("rt")
        sapien.render.set_ray_tracing_samples_per_pixel(32)
        sapien.render.set_ray_tracing_path_depth(8)
        sapien.render.set_ray_tracing_denoiser("oidn")

        # declare sapien scene
        scene_config = sapien.SceneConfig()
        self.scene = self.engine.create_scene(scene_config)
        # set simulation timestep
        self.scene.set_timestep(kwargs.get("timestep", 1 / 250))
        # add ground to scene
        self.scene.add_ground(kwargs.get("ground_height", 0))
        # set default physical material
        self.scene.default_physical_material = self.scene.create_physical_material(
            kwargs.get("static_friction", 0.5),
            kwargs.get("dynamic_friction", 0.5),
            kwargs.get("restitution", 0),
        )
        # give some white ambient light of moderate intensity
        self.scene.set_ambient_light(kwargs.get("ambient_light", [0.5, 0.5, 0.5]))
        # default enable shadow unless specified otherwise
        shadow = kwargs.get("shadow", True)
        # default spotlight angle and intensity
        direction_lights = kwargs.get("direction_lights", [[[0, 0.5, -1], [0.5, 0.5, 0.5]]])
        self.direction_light_lst = []
        for direction_light in direction_lights:
            if self.random_light:
                direction_light[1] = [
                    np.random.rand(),
                    np.random.rand(),
                    np.random.rand(),
                ]
            self.direction_light_lst.append(
                self.scene.add_directional_light(direction_light[0], direction_light[1], shadow=shadow))
        # default point lights position and intensity
        point_lights = kwargs.get("point_lights", [[[1, 0, 1.8], [1, 1, 1]], [[-1, 0, 1.8], [1, 1, 1]]])
        self.point_light_lst = []
        for point_light in point_lights:
            if self.random_light:
                point_light[1] = [np.random.rand(), np.random.rand(), np.random.rand()]
            self.point_light_lst.append(self.scene.add_point_light(point_light[0], point_light[1], shadow=shadow))

        # initialize viewer with camera position and orientation
        if self.render_freq:
            self.viewer = Viewer(self.renderer)
            self.viewer.set_scene(self.scene)
            self.viewer.set_camera_xyz(
                x=kwargs.get("camera_xyz_x", 0.4),
                y=kwargs.get("camera_xyz_y", 0.22),
                z=kwargs.get("camera_xyz_z", 1.5),
            )
            self.viewer.set_camera_rpy(
                r=kwargs.get("camera_rpy_r", 0),
                p=kwargs.get("camera_rpy_p", -0.8),
                y=kwargs.get("camera_rpy_y", 2.45),
            )

    def create_table_and_wall(self, table_xy_bias=[0, 0], table_height=0.74):
        self.table_xy_bias = table_xy_bias
        wall_texture, table_texture = None, None
        table_height += self.table_z_bias

        if self.random_background:
            texture_type = "seen" if not self.eval_mode else "unseen"
            #directory_path = f"./assets/background_texture/{texture_type}"
            directory_path = os.path.join(ROOT_PATH , f"assets/background_texture/{texture_type}")
            file_count = len(
                [name for name in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, name))])

            # wall_texture, table_texture = random.randint(0, file_count - 1), random.randint(0, file_count - 1)
            wall_texture, table_texture = np.random.randint(0, file_count), np.random.randint(0, file_count)

            self.wall_texture, self.table_texture = (
                f"{texture_type}/{wall_texture}",
                f"{texture_type}/{table_texture}",
            )
            if np.random.rand() <= self.clean_background_rate:
                self.wall_texture = None
            if np.random.rand() <= self.clean_background_rate:
                self.table_texture = None
        else:
            self.wall_texture, self.table_texture = None, None

        self.wall = create_box(
            self.scene,
            sapien.Pose(p=[0, 1, 1.5]),
            half_size=[3, 0.6, 1.5],
            color=(1, 0.9, 0.9),
            name="wall",
            texture_id=self.wall_texture,
            is_static=True,
        )

        self.table = create_table(
            self.scene,
            sapien.Pose(p=[table_xy_bias[0], table_xy_bias[1], table_height]),
            length=1.2,
            width=0.7,
            height=table_height,
            thickness=0.05,
            is_static=True,
            texture_id=self.table_texture,
        )

    def get_cluttered_table(self, cluttered_numbers=10, xlim=[-0.59, 0.59], ylim=[-0.34, 0.34], zlim=[0.741]):
        self.record_cluttered_objects = []  # record cluttered objects

        xlim[0] += self.table_xy_bias[0]
        xlim[1] += self.table_xy_bias[0]
        ylim[0] += self.table_xy_bias[1]
        ylim[1] += self.table_xy_bias[1]

        if np.random.rand() < self.clean_background_rate:
            return

        task_objects_list = []
        for entity in self.scene.get_all_actors():
            actor_name = entity.get_name()
            if actor_name == "":
                continue
            if actor_name in ["table", "wall", "ground"]:
                continue
            task_objects_list.append(actor_name)
        self.obj_names, self.cluttered_item_info = get_available_cluttered_objects(task_objects_list)

        success_count = 0
        max_try = 50
        trys = 0

        while success_count < cluttered_numbers and trys < max_try:
            obj = np.random.randint(len(self.obj_names))
            obj_name = self.obj_names[obj]
            obj_idx = np.random.randint(len(self.cluttered_item_info[obj_name]["ids"]))
            obj_idx = self.cluttered_item_info[obj_name]["ids"][obj_idx]
            obj_radius = self.cluttered_item_info[obj_name]["params"][obj_idx]["radius"]
            obj_offset = self.cluttered_item_info[obj_name]["params"][obj_idx]["z_offset"]
            obj_maxz = self.cluttered_item_info[obj_name]["params"][obj_idx]["z_max"]

            success, self.cluttered_obj = rand_create_cluttered_actor(
                self.scene,
                xlim=xlim,
                ylim=ylim,
                zlim=np.array(zlim) + self.table_z_bias,
                modelname=obj_name,
                modelid=obj_idx,
                modeltype=self.cluttered_item_info[obj_name]["type"],
                rotate_rand=True,
                rotate_lim=[0, 0, math.pi],
                size_dict=self.size_dict,
                obj_radius=obj_radius,
                z_offset=obj_offset,
                z_max=obj_maxz,
                prohibited_area=self.prohibited_area,
            )
            if not success or self.cluttered_obj is None:
                trys += 1
                continue
            self.cluttered_obj.set_name(f"{obj_name}")
            self.cluttered_objs.append(self.cluttered_obj)
            pose = self.cluttered_obj.get_pose().p.tolist()
            pose.append(obj_radius)
            self.size_dict.append(pose)
            success_count += 1
            self.record_cluttered_objects.append({"object_type": obj_name, "object_index": obj_idx})

        if success_count < cluttered_numbers:
            print(f"Warning: Only {success_count} cluttered objects are placed on the table.")

        self.size_dict = None
        self.cluttered_objs = []

    def load_robot(self, **kwags):
        """
        load aloha robot urdf file, set root pose and set joints
        """
        if not hasattr(self, "robot"):
            self.robot = Robot(self.scene, self.need_topp, **kwags)
            self.robot.set_planner(self.scene)
            self.robot.init_joints()
        else:
            self.robot.reset(self.scene, self.need_topp, **kwags)

        for link in self.robot.left_entity.get_links():
            link: sapien.physx.PhysxArticulationLinkComponent = link
            link.set_mass(1)
        for link in self.robot.right_entity.get_links():
            link: sapien.physx.PhysxArticulationLinkComponent = link
            link.set_mass(1)

    def load_camera(self, **kwags):
        """
        Add cameras and set camera parameters
            - Including four cameras: left, right, front, head.
        """

        self.cameras = Camera(
            bias=self.table_z_bias,
            random_head_camera_dis=self.random_head_camera_dis,
            **kwags,
        )
        self.cameras.load_camera(self.scene)
        self.scene.step()  # run a physical step
        self.scene.update_render()  # sync pose from SAPIEN to renderer

    # =========================================================== Sapien ===========================================================

    def _update_render(self):
        """
        Update rendering to refresh the camera's RGBD information
        (rendering must be updated even when disabled, otherwise data cannot be collected).
        """
        if self.crazy_random_light:
            for renderColor in self.point_light_lst:
                renderColor.set_color([np.random.rand(), np.random.rand(), np.random.rand()])
            for renderColor in self.direction_light_lst:
                renderColor.set_color([np.random.rand(), np.random.rand(), np.random.rand()])
            now_ambient_light = self.scene.ambient_light
            now_ambient_light = np.clip(np.array(now_ambient_light) + np.random.rand(3) * 0.2 - 0.1, 0, 1)
            self.scene.set_ambient_light(now_ambient_light)
        self.cameras.update_wrist_camera(self.robot.left_camera.get_pose(), self.robot.right_camera.get_pose())
        self.scene.update_render()

    # =========================================================== Basic APIs ===========================================================

    def get_obs(self):
        self._update_render()
        self.cameras.update_picture()
        pkl_dic = {
            "observation": {},
            "pointcloud": [],
            "joint_action": {},
            "endpose": {},
        }
        
        pkl_dic["observation"] = self.cameras.get_config()
        # rgb
        if self.data_type.get("rgb", False):
            rgb = self.cameras.get_rgb()
            for camera_name in rgb.keys():
                pkl_dic["observation"][camera_name].update(rgb[camera_name])
        
        if self.data_type.get("third_view", False):
            third_view_rgb = self.cameras.get_observer_rgb()
            pkl_dic["third_view_rgb"] = third_view_rgb
        # mesh_segmentation
        if self.data_type.get("mesh_segmentation", False):
            mesh_segmentation = self.cameras.get_segmentation(level="mesh")
            for camera_name in mesh_segmentation.keys():
                pkl_dic["observation"][camera_name].update(mesh_segmentation[camera_name])
        # actor_segmentation
        if self.data_type.get("actor_segmentation", False):
            actor_segmentation = self.cameras.get_segmentation(level="actor")
            for camera_name in actor_segmentation.keys():
                pkl_dic["observation"][camera_name].update(actor_segmentation[camera_name])
        # depth
        if self.data_type.get("depth", False):
            depth = self.cameras.get_depth()
            for camera_name in depth.keys():
                pkl_dic["observation"][camera_name].update(depth[camera_name])
        # endpose
        if self.data_type.get("endpose", False):
            norm_gripper_val = [
                self.robot.get_left_gripper_val(),
                self.robot.get_right_gripper_val(),
            ]
            left_endpose = self.get_arm_pose("left")
            right_endpose = self.get_arm_pose("right")
            pkl_dic["endpose"]["left_endpose"] = left_endpose
            pkl_dic["endpose"]["left_gripper"] = norm_gripper_val[0]
            pkl_dic["endpose"]["right_endpose"] = right_endpose
            pkl_dic["endpose"]["right_gripper"] = norm_gripper_val[1]
        # qpos
        if self.data_type.get("qpos", False):

            left_jointstate = self.robot.get_left_arm_jointState()
            right_jointstate = self.robot.get_right_arm_jointState()

            pkl_dic["joint_action"]["left_arm"] = left_jointstate[:-1]
            pkl_dic["joint_action"]["left_gripper"] = left_jointstate[-1]
            pkl_dic["joint_action"]["right_arm"] = right_jointstate[:-1]
            pkl_dic["joint_action"]["right_gripper"] = right_jointstate[-1]
            pkl_dic["joint_action"]["vector"] = np.array(left_jointstate + right_jointstate)
        # pointcloud
        if self.data_type.get("pointcloud", False):
            pkl_dic["pointcloud"] = self.cameras.get_pcd(self.data_type.get("conbine", False))

        self.now_obs = deepcopy(pkl_dic)
        return pkl_dic

    def save_camera_rgb(self, save_path, camera_name='head_camera'):
        self._update_render()
        self.cameras.update_picture()
        rgb = self.cameras.get_rgb()
        save_img(save_path, rgb[camera_name]['rgb'])

    def _take_picture(self):  # save data
        if not self.save_data:
            return

        print("saving: episode = ", self.ep_num, " index = ", self.FRAME_IDX, end="\r")

        if self.FRAME_IDX == 0:
            self.folder_path = {"cache": f"{self.save_dir}/.cache/episode{self.ep_num}/"}

            for directory in self.folder_path.values():  # remove previous data
                if os.path.exists(directory):
                    file_list = os.listdir(directory)
                    for file in file_list:
                        os.remove(directory + file)

        pkl_dic = self.get_obs()
        save_pkl(self.folder_path["cache"] + f"{self.FRAME_IDX}.pkl", pkl_dic)  # use cache
        self.FRAME_IDX += 1

    def save_traj_data(self, idx):
        file_path = os.path.join(self.save_dir, "_traj_data", f"episode{idx}.pkl")
        traj_data = {
            "left_joint_path": deepcopy(self.left_joint_path),
            "right_joint_path": deepcopy(self.right_joint_path),
        }
        save_pkl(file_path, traj_data)

    def load_tran_data(self, idx):
        assert self.save_dir is not None, "self.save_dir is None"
        file_path = os.path.join(self.save_dir, "_traj_data", f"episode{idx}.pkl")
        with open(file_path, "rb") as f:
            traj_data = pickle.load(f)
        return traj_data

    def merge_pkl_to_hdf5_video(self):
        if not self.save_data:
            return
        cache_path = self.folder_path["cache"]
        target_file_path = f"{self.save_dir}/data/episode{self.ep_num}.hdf5"
        target_video_path = f"{self.save_dir}/video/episode{self.ep_num}.mp4"
        # print('Merging pkl to hdf5: ', cache_path, ' -> ', target_file_path)

        os.makedirs(f"{self.save_dir}/data", exist_ok=True)
        process_folder_to_hdf5_video(cache_path, target_file_path, target_video_path)

    def remove_data_cache(self):
        folder_path = self.folder_path["cache"]
        GREEN = "\033[92m"
        RED = "\033[91m"
        RESET = "\033[0m"
        try:
            shutil.rmtree(folder_path)
            print(f"{GREEN}Folder {folder_path} deleted successfully.{RESET}")
        except OSError as e:
            print(f"{RED}Error: {folder_path} is not empty or does not exist.{RESET}")

    def set_instruction(self, instruction=None):
        self.instruction = instruction

    def get_instruction(self, instruction=None):
        return self.instruction

    def set_path_lst(self, args):
        self.need_plan = args.get("need_plan", True)
        self.left_joint_path = args.get("left_joint_path", [])
        self.right_joint_path = args.get("right_joint_path", [])

    def _set_eval_video_ffmpeg(self, ffmpeg):
        self.eval_video_ffmpeg = ffmpeg

    def close_env(self, clear_cache=False):
        if clear_cache:
            # for actor in self.scene.get_all_actors():
            #     self.scene.remove_actor(actor)
            sapien_clear_cache()
        self.close()

    def _del_eval_video_ffmpeg(self):
        if self.eval_video_ffmpeg:
            self.eval_video_ffmpeg.stdin.close()
            self.eval_video_ffmpeg.wait()
            del self.eval_video_ffmpeg

    def delay(self, delay_time, save_freq=None):
        render_freq = self.render_freq
        self.render_freq = 0

        left_gripper_val = self.robot.get_left_gripper_val()
        right_gripper_val = self.robot.get_right_gripper_val()
        for i in range(delay_time):
            self.together_close_gripper(
                left_pos=left_gripper_val,
                right_pos=right_gripper_val,
                save_freq=save_freq,
            )

        self.render_freq = render_freq

    def set_gripper(self, set_tag="together", left_pos=None, right_pos=None):
        """
        Set gripper posture
        - `left_pos`: Left gripper pose
        - `right_pos`: Right gripper pose
        - `set_tag`: "left" to set the left gripper, "right" to set the right gripper, "together" to set both grippers simultaneously.
        """
        alpha = 0.5

        left_result, right_result = None, None

        if set_tag == "left" or set_tag == "together":
            left_result = self.robot.left_plan_grippers(self.robot.get_left_gripper_val(), left_pos)
            left_gripper_step = left_result["per_step"]
            left_gripper_res = left_result["result"]
            num_step = left_result["num_step"]
            left_result["result"] = np.pad(
                left_result["result"],
                (0, int(alpha * num_step)),
                mode="constant",
                constant_values=left_gripper_res[-1],
            )  # append
            left_result["num_step"] += int(alpha * num_step)
            if set_tag == "left":
                return left_result

        if set_tag == "right" or set_tag == "together":
            right_result = self.robot.right_plan_grippers(self.robot.get_right_gripper_val(), right_pos)
            right_gripper_step = right_result["per_step"]
            right_gripper_res = right_result["result"]
            num_step = right_result["num_step"]
            right_result["result"] = np.pad(
                right_result["result"],
                (0, int(alpha * num_step)),
                mode="constant",
                constant_values=right_gripper_res[-1],
            )  # append
            right_result["num_step"] += int(alpha * num_step)
            if set_tag == "right":
                return right_result

        return left_result, right_result

    def add_prohibit_area(
        self,
        actor: Actor | sapien.Entity | sapien.Pose | list | np.ndarray,
        padding=0.01,
    ):

        if (isinstance(actor, sapien.Pose) or isinstance(actor, list) or isinstance(actor, np.ndarray)):
            actor_pose = transforms._toPose(actor)
            actor_data = {}
        else:
            actor_pose = actor.get_pose()
            if isinstance(actor, Actor):
                actor_data = actor.config
            else:
                actor_data = {}

        scale: float = actor_data.get("scale", 1)
        origin_bounding_size = (np.array(actor_data.get("extents", [0.1, 0.1, 0.1])) * scale / 2)
        origin_bounding_pts = (np.array([
            [-1, -1, -1],
            [-1, -1, 1],
            [-1, 1, -1],
            [-1, 1, 1],
            [1, -1, -1],
            [1, -1, 1],
            [1, 1, -1],
            [1, 1, 1],
        ]) * origin_bounding_size)

        actor_matrix = actor_pose.to_transformation_matrix()
        trans_bounding_pts = actor_matrix[:3, :3] @ origin_bounding_pts.T + actor_matrix[:3, 3].reshape(3, 1)
        x_min = np.min(trans_bounding_pts[0]) - padding
        x_max = np.max(trans_bounding_pts[0]) + padding
        y_min = np.min(trans_bounding_pts[1]) - padding
        y_max = np.max(trans_bounding_pts[1]) + padding
        # add_robot_visual_box(self, [x_min, y_min, actor_matrix[3, 3]])
        # add_robot_visual_box(self, [x_max, y_max, actor_matrix[3, 3]])
        self.prohibited_area.append([x_min, y_min, x_max, y_max])

    def is_left_gripper_open(self):
        return self.robot.is_left_gripper_open()

    def is_right_gripper_open(self):
        return self.robot.is_right_gripper_open()

    def is_left_gripper_open_half(self):
        return self.robot.is_left_gripper_open_half()

    def is_right_gripper_open_half(self):
        return self.robot.is_right_gripper_open_half()

    def is_left_gripper_close(self):
        return self.robot.is_left_gripper_close()

    def is_right_gripper_close(self):
        return self.robot.is_right_gripper_close()

    # =========================================================== Our APIS ===========================================================

    def together_close_gripper(self, save_freq=-1, left_pos=0, right_pos=0):
        left_result, right_result = self.set_gripper(left_pos=left_pos, right_pos=right_pos, set_tag="together")
        control_seq = {
            "left_arm": None,
            "left_gripper": left_result,
            "right_arm": None,
            "right_gripper": right_result,
        }
        self.take_dense_action(control_seq, save_freq=save_freq)

    def together_open_gripper(self, save_freq=-1, left_pos=1, right_pos=1):
        left_result, right_result = self.set_gripper(left_pos=left_pos, right_pos=right_pos, set_tag="together")
        control_seq = {
            "left_arm": None,
            "left_gripper": left_result,
            "right_arm": None,
            "right_gripper": right_result,
        }
        self.take_dense_action(control_seq, save_freq=save_freq)

    def left_move_to_pose(
        self,
        pose,
        constraint_pose=None,
        use_point_cloud=False,
        use_attach=False,
        save_freq=-1,
    ):
        """
        Interpolative planning with screw motion.
        Will not avoid collision and will fail if the path contains collision.
        """
        if not self.plan_success:
            return
        if pose is None:
            self.plan_success = False
            return
        if type(pose) == sapien.Pose:
            pose = pose.p.tolist() + pose.q.tolist()

        if self.need_plan:
            left_result = self.robot.left_plan_path(pose, constraint_pose=constraint_pose)
            self.left_joint_path.append(deepcopy(left_result))
        else:
            left_result = deepcopy(self.left_joint_path[self.left_cnt])
            self.left_cnt += 1

        if left_result["status"] != "Success":
            self.plan_success = False
            return

        return left_result

    def right_move_to_pose(
        self,
        pose,
        constraint_pose=None,
        use_point_cloud=False,
        use_attach=False,
        save_freq=-1,
    ):
        """
        Interpolative planning with screw motion.
        Will not avoid collision and will fail if the path contains collision.
        """
        if not self.plan_success:
            return
        if pose is None:
            self.plan_success = False
            return
        if type(pose) == sapien.Pose:
            pose = pose.p.tolist() + pose.q.tolist()

        if self.need_plan:
            right_result = self.robot.right_plan_path(pose, constraint_pose=constraint_pose)
            self.right_joint_path.append(deepcopy(right_result))
        else:
            right_result = deepcopy(self.right_joint_path[self.right_cnt])
            self.right_cnt += 1

        if right_result["status"] != "Success":
            self.plan_success = False
            return

        return right_result

    def together_move_to_pose(
        self,
        left_target_pose,
        right_target_pose,
        left_constraint_pose=None,
        right_constraint_pose=None,
        use_point_cloud=False,
        use_attach=False,
        save_freq=-1,
    ):
        """
        Interpolative planning with screw motion.
        Will not avoid collision and will fail if the path contains collision.
        """
        if not self.plan_success:
            return
        if left_target_pose is None or right_target_pose is None:
            self.plan_success = False
            return
        if type(left_target_pose) == sapien.Pose:
            left_target_pose = left_target_pose.p.tolist() + left_target_pose.q.tolist()
        if type(right_target_pose) == sapien.Pose:
            right_target_pose = (right_target_pose.p.tolist() + right_target_pose.q.tolist())
        save_freq = self.save_freq if save_freq == -1 else save_freq
        if self.need_plan:
            left_result = self.robot.left_plan_path(left_target_pose, constraint_pose=left_constraint_pose)
            right_result = self.robot.right_plan_path(right_target_pose, constraint_pose=right_constraint_pose)
            self.left_joint_path.append(deepcopy(left_result))
            self.right_joint_path.append(deepcopy(right_result))
        else:
            left_result = deepcopy(self.left_joint_path[self.left_cnt])
            right_result = deepcopy(self.right_joint_path[self.right_cnt])
            self.left_cnt += 1
            self.right_cnt += 1

        try:
            left_success = left_result["status"] == "Success"
            right_success = right_result["status"] == "Success"
            if not left_success or not right_success:
                self.plan_success = False
                # return TODO
        except Exception as e:
            if left_result is None or right_result is None:
                self.plan_success = False
                return  # TODO

        if save_freq != None:
            self._take_picture()

        now_left_id = 0
        now_right_id = 0
        i = 0

        left_n_step = left_result["position"].shape[0] if left_success else 0
        right_n_step = right_result["position"].shape[0] if right_success else 0

        while now_left_id < left_n_step or now_right_id < right_n_step:
            # set the joint positions and velocities for move group joints only.
            # The others are not the responsibility of the planner
            if (left_success and now_left_id < left_n_step
                    and (not right_success or now_left_id / left_n_step <= now_right_id / right_n_step)):
                self.robot.set_arm_joints(
                    left_result["position"][now_left_id],
                    left_result["velocity"][now_left_id],
                    "left",
                )
                now_left_id += 1

            if (right_success and now_right_id < right_n_step
                    and (not left_success or now_right_id / right_n_step <= now_left_id / left_n_step)):
                self.robot.set_arm_joints(
                    right_result["position"][now_right_id],
                    right_result["velocity"][now_right_id],
                    "right",
                )
                now_right_id += 1

            self.scene.step()
            if self.render_freq and i % self.render_freq == 0:
                self._update_render()
                self.viewer.render()

            if save_freq != None and i % save_freq == 0:
                self._update_render()
                self._take_picture()
            i += 1

        if save_freq != None:
            self._take_picture()

    def move(
        self,
        actions_by_arm1: tuple[ArmTag, list[Action]],
        actions_by_arm2: tuple[ArmTag, list[Action]] = None,
        save_freq=-1,
    ):
        """
        Take action for the robot.
        """

        def get_actions(actions, arm_tag: ArmTag) -> list[Action]:
            if actions[1] is None:
                if actions[0][0] == arm_tag:
                    return actions[0][1]
                else:
                    return []
            else:
                if actions[0][0] == actions[0][1]:
                    raise ValueError("")
                if actions[0][0] == arm_tag:
                    return actions[0][1]
                else:
                    return actions[1][1]

        if self.plan_success is False:
            return False

        actions = [actions_by_arm1, actions_by_arm2]
        left_actions = get_actions(actions, "left")
        right_actions = get_actions(actions, "right")

        max_len = max(len(left_actions), len(right_actions))
        left_actions += [None] * (max_len - len(left_actions))
        right_actions += [None] * (max_len - len(right_actions))

        for left, right in zip(left_actions, right_actions):

            if (left is not None and left.arm_tag != "left") or (right is not None
                                                                 and right.arm_tag != "right"):  # check
                raise ValueError(f"Invalid arm tag: {left.arm_tag} or {right.arm_tag}. Must be 'left' or 'right'.")

            if (left is not None and left.action == "move") and (right is not None
                                                                 and right.action == "move"):  # together move
                self.together_move_to_pose(  # TODO
                    left_target_pose=left.target_pose,
                    right_target_pose=right.target_pose,
                    left_constraint_pose=left.args.get("constraint_pose"),
                    right_constraint_pose=right.args.get("constraint_pose"),
                )
                if self.plan_success is False:
                    return False
                continue  # TODO
            else:
                control_seq = {
                    "left_arm": None,
                    "left_gripper": None,
                    "right_arm": None,
                    "right_gripper": None,
                }
                if left is not None:
                    if left.action == "move":
                        control_seq["left_arm"] = self.left_move_to_pose(
                            pose=left.target_pose,
                            constraint_pose=left.args.get("constraint_pose"),
                        )
                    else:  # left.action == 'gripper'
                        control_seq["left_gripper"] = self.set_gripper(left_pos=left.target_gripper_pos, set_tag="left")
                    if self.plan_success is False:
                        return False

                if right is not None:
                    if right.action == "move":
                        control_seq["right_arm"] = self.right_move_to_pose(
                            pose=right.target_pose,
                            constraint_pose=right.args.get("constraint_pose"),
                        )
                    else:  # right.action == 'gripper'
                        control_seq["right_gripper"] = self.set_gripper(right_pos=right.target_gripper_pos,
                                                                        set_tag="right")
                    if self.plan_success is False:
                        return False

            self.take_dense_action(control_seq)

        return True

    def get_gripper_actor_contact_position(self, actor_name):
        contacts = self.scene.get_contacts()
        position_lst = []
        for contact in contacts:
            if (contact.bodies[0].entity.name == actor_name or contact.bodies[1].entity.name == actor_name):
                contact_object = (contact.bodies[1].entity.name
                                  if contact.bodies[0].entity.name == actor_name else contact.bodies[0].entity.name)
                if contact_object in self.robot.gripper_name:
                    for point in contact.points:
                        position_lst.append(point.position)
        return position_lst

    def check_actors_contact(self, actor1, actor2):
        """
        Check if two actors are in contact.
        - actor1: The first actor.
        - actor2: The second actor.
        """
        contacts = self.scene.get_contacts()
        for contact in contacts:
            if (contact.bodies[0].entity.name == actor1
                    and contact.bodies[1].entity.name == actor2) or (contact.bodies[0].entity.name == actor2
                                                                     and contact.bodies[1].entity.name == actor1):
                return True
        return False

    def get_scene_contact(self):
        contacts = self.scene.get_contacts()
        for contact in contacts:
            pdb.set_trace()
            print(dir(contact))
            print(contact.bodies[0].entity.name, contact.bodies[1].entity.name)

    def choose_best_pose(self, res_pose, center_pose, arm_tag: ArmTag = None):
        """
        Choose the best pose from the list of target poses.
        - target_lst: List of target poses.
        """
        if not self.plan_success:
            return [-1, -1, -1, -1, -1, -1, -1]
        if arm_tag == "left":
            plan_multi_pose = self.robot.left_plan_multi_path
        elif arm_tag == "right":
            plan_multi_pose = self.robot.right_plan_multi_path
        target_lst = self.robot.create_target_pose_list(res_pose, center_pose, arm_tag)
        pose_num = len(target_lst)
        traj_lst = plan_multi_pose(target_lst)
        now_pose = None
        now_step = -1
        for i in range(pose_num):
            if traj_lst["status"][i] != "Success":
                continue
            if now_pose is None or len(traj_lst["position"][i]) < now_step:
                now_pose = target_lst[i]
        return now_pose

    # test grasp pose of all contact points
    def _print_all_grasp_pose_of_contact_points(self, actor: Actor, pre_dis: float = 0.1):
        for i in range(len(actor.config["contact_points_pose"])):
            print(i, self.get_grasp_pose(actor, pre_dis=pre_dis, contact_point_id=i))

    def get_grasp_pose(
        self,
        actor: Actor,
        arm_tag: ArmTag,
        contact_point_id: int = 0,
        pre_dis: float = 0.0,
    ) -> list:
        """
        Obtain the grasp pose through the marked grasp point.
        - actor: The instance of the object to be grasped.
        - arm_tag: The arm to be used, either "left" or "right".
        - pre_dis: The distance in front of the grasp point.
        - contact_point_id: The index of the grasp point.
        """
        if not self.plan_success:
            return [-1, -1, -1, -1, -1, -1, -1]

        contact_matrix = actor.get_contact_point(contact_point_id, "matrix")
        if contact_matrix is None:
            return None
        global_contact_pose_matrix = contact_matrix @ np.array([[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0],
                                                                [0, 0, 0, 1]])
        global_contact_pose_matrix_q = global_contact_pose_matrix[:3, :3]
        global_grasp_pose_p = (global_contact_pose_matrix[:3, 3] +
                               global_contact_pose_matrix_q @ np.array([-0.12 - pre_dis, 0, 0]).T)
        global_grasp_pose_q = t3d.quaternions.mat2quat(global_contact_pose_matrix_q)
        res_pose = list(global_grasp_pose_p) + list(global_grasp_pose_q)
        res_pose = self.choose_best_pose(res_pose, actor.get_contact_point(contact_point_id, "list"), arm_tag)
        return res_pose

    def _default_choose_grasp_pose(self, actor: Actor, arm_tag: ArmTag, pre_dis: float) -> list:
        """
        Default grasp pose function.
        - actor: The target actor to be grasped.
        - arm_tag: The arm to be used for grasping, either "left" or "right".
        - pre_dis: The distance in front of the grasp point, default is 0.1.
        """
        id = -1
        score = -1

        for i, contact_point in actor.iter_contact_points("list"):
            pose = self.get_grasp_pose(actor, arm_tag, pre_dis, i)
            now_score = 0
            if not (contact_point[1] < -0.1 and pose[2] < 0.85 or contact_point[1] > 0.05 and pose[2] > 0.92):
                now_score -= 1
            quat_dis = cal_quat_dis(pose[-4:], GRASP_DIRECTION_DIC[str(arm_tag) + "_arm_perf"])

        return self.get_grasp_pose(actor, arm_tag, pre_dis=pre_dis)

    def choose_grasp_pose(
        self,
        actor: Actor,
        arm_tag: ArmTag,
        pre_dis=0.1,
        target_dis=0,
        contact_point_id: list | float = None,
    ) -> list:
        """
        Test the grasp pose function.
        - actor: The actor to be grasped.
        - arm_tag: The arm to be used for grasping, either "left" or "right".
        - pre_dis: The distance in front of the grasp point, default is 0.1.
        """
        if not self.plan_success:
            return
        res_pre_top_down_pose = None
        res_top_down_pose = None
        dis_top_down = 1e9
        res_pre_side_pose = None
        res_side_pose = None
        dis_side = 1e9
        res_pre_pose = None
        res_pose = None
        dis = 1e9

        pref_direction = self.robot.get_grasp_perfect_direction(arm_tag)

        def get_grasp_pose(pre_grasp_pose, pre_grasp_dis):
            grasp_pose = deepcopy(pre_grasp_pose)
            grasp_pose = np.array(grasp_pose)
            direction_mat = t3d.quaternions.quat2mat(grasp_pose[-4:])
            grasp_pose[:3] += [pre_grasp_dis, 0, 0] @ np.linalg.inv(direction_mat)
            grasp_pose = grasp_pose.tolist()
            return grasp_pose

        def check_pose(pre_pose, pose, arm_tag):
            if arm_tag == "left":
                plan_func = self.robot.left_plan_path
            else:
                plan_func = self.robot.right_plan_path
            pre_path = plan_func(pre_pose)
            if pre_path["status"] != "Success":
                return False
            pre_qpos = pre_path["position"][-1]
            return plan_func(pose)["status"] == "Success"

        if contact_point_id is not None:
            if type(contact_point_id) != list:
                contact_point_id = [contact_point_id]
            contact_point_id = [(i, None) for i in contact_point_id]
        else:
            contact_point_id = actor.iter_contact_points()

        for i, _ in contact_point_id:
            pre_pose = self.get_grasp_pose(actor, arm_tag, contact_point_id=i, pre_dis=pre_dis)
            if pre_pose is None:
                continue
            pose = get_grasp_pose(pre_pose, pre_dis - target_dis)
            now_dis_top_down = cal_quat_dis(
                pose[-4:],
                GRASP_DIRECTION_DIC[("top_down_little_left" if arm_tag == "right" else "top_down_little_right")],
            )
            now_dis_side = cal_quat_dis(pose[-4:], GRASP_DIRECTION_DIC[pref_direction])

            if res_pre_top_down_pose is None or now_dis_top_down < dis_top_down:
                res_pre_top_down_pose = pre_pose
                res_top_down_pose = pose
                dis_top_down = now_dis_top_down

            if res_pre_side_pose is None or now_dis_side < dis_side:
                res_pre_side_pose = pre_pose
                res_side_pose = pose
                dis_side = now_dis_side

            now_dis = 0.7 * now_dis_top_down + 0.3 * now_dis_side
            if res_pre_pose is None or now_dis < dis:
                res_pre_pose = pre_pose
                res_pose = pose
                dis = now_dis

        if dis_top_down < 0.15:
            return res_pre_top_down_pose, res_top_down_pose
        if dis_side < 0.15:
            return res_pre_side_pose, res_side_pose
        return res_pre_pose, res_pose

    def grasp_actor(
        self,
        actor: Actor,
        arm_tag: ArmTag,
        pre_grasp_dis=0.1,
        grasp_dis=0,
        gripper_pos=0.0,
        contact_point_id: list | float = None,
    ):
        if not self.plan_success:
            return None, []
        if self.need_plan == False:
            if pre_grasp_dis == grasp_dis:
                return arm_tag, [
                    Action(arm_tag, "move", target_pose=[0, 0, 0, 0, 0, 0, 0]),
                    Action(arm_tag, "close", target_gripper_pos=gripper_pos),
                ]
            else:
                return arm_tag, [
                    Action(arm_tag, "move", target_pose=[0, 0, 0, 0, 0, 0, 0]),
                    Action(
                        arm_tag,
                        "move",
                        target_pose=[0, 0, 0, 0, 0, 0, 0],
                        constraint_pose=[1, 1, 1, 0, 0, 0],
                    ),
                    Action(arm_tag, "close", target_gripper_pos=gripper_pos),
                ]

        pre_grasp_pose, grasp_pose = self.choose_grasp_pose(
            actor,
            arm_tag=arm_tag,
            pre_dis=pre_grasp_dis,
            target_dis=grasp_dis,
            contact_point_id=contact_point_id,
        )
        if pre_grasp_pose == grasp_pose:
            return arm_tag, [
                Action(arm_tag, "move", target_pose=pre_grasp_pose),
                Action(arm_tag, "close", target_gripper_pos=gripper_pos),
            ]
        else:
            return arm_tag, [
                Action(arm_tag, "move", target_pose=pre_grasp_pose),
                Action(
                    arm_tag,
                    "move",
                    target_pose=grasp_pose,
                    constraint_pose=[1, 1, 1, 0, 0, 0],
                ),
                Action(arm_tag, "close", target_gripper_pos=gripper_pos),
            ]

    def get_place_pose(
        self,
        actor: Actor,
        arm_tag: ArmTag,
        target_pose: list | np.ndarray,
        constrain: Literal["free", "align", "auto"] = "auto",
        align_axis: list[np.ndarray] | np.ndarray | list = None,
        actor_axis: np.ndarray | list = [1, 0, 0],
        actor_axis_type: Literal["actor", "world"] = "actor",
        functional_point_id: int = None,
        pre_dis: float = 0.1,
        pre_dis_axis: Literal["grasp", "fp"] | np.ndarray | list = "grasp",
    ):

        if not self.plan_success:
            return [-1, -1, -1, -1, -1, -1, -1]

        actor_matrix = actor.get_pose().to_transformation_matrix()
        if functional_point_id is not None:
            place_start_pose = actor.get_functional_point(functional_point_id, "pose")
            z_transform = False
        else:
            place_start_pose = actor.get_pose()
            z_transform = True

        end_effector_pose = (self.robot.get_left_ee_pose() if arm_tag == "left" else self.robot.get_right_ee_pose())

        if constrain == "auto":
            grasp_direct_vec = place_start_pose.p - end_effector_pose[:3]
            if np.abs(np.dot(grasp_direct_vec, [0, 0, 1])) <= 0.1:
                place_pose = get_place_pose(
                    place_start_pose,
                    target_pose,
                    constrain="align",
                    actor_axis=grasp_direct_vec,
                    actor_axis_type="world",
                    align_axis=[1, 1, 0] if arm_tag == "left" else [-1, 1, 0],
                    z_transform=z_transform,
                )
            else:
                camera_vec = transforms._toPose(end_effector_pose).to_transformation_matrix()[:3, 2]
                place_pose = get_place_pose(
                    place_start_pose,
                    target_pose,
                    constrain="align",
                    actor_axis=camera_vec,
                    actor_axis_type="world",
                    align_axis=[0, 1, 0],
                    z_transform=z_transform,
                )
        else:
            place_pose = get_place_pose(
                place_start_pose,
                target_pose,
                constrain=constrain,
                actor_axis=actor_axis,
                actor_axis_type=actor_axis_type,
                align_axis=align_axis,
                z_transform=z_transform,
            )
        start2target = (transforms._toPose(place_pose).to_transformation_matrix()[:3, :3]
                        @ place_start_pose.to_transformation_matrix()[:3, :3].T)
        target_point = (start2target @ (actor_matrix[:3, 3] - place_start_pose.p).reshape(3, 1)).reshape(3) + np.array(
            place_pose[:3])

        ee_pose_matrix = t3d.quaternions.quat2mat(end_effector_pose[-4:])
        target_grasp_matrix = start2target @ ee_pose_matrix

        res_matrix = np.eye(4)
        res_matrix[:3, 3] = actor_matrix[:3, 3] - end_effector_pose[:3]
        res_matrix[:3, 3] = np.linalg.inv(ee_pose_matrix) @ res_matrix[:3, 3]
        target_grasp_qpose = t3d.quaternions.mat2quat(target_grasp_matrix)

        grasp_bias = target_grasp_matrix @ res_matrix[:3, 3]
        if pre_dis_axis == "grasp":
            target_dis_vec = target_grasp_matrix @ res_matrix[:3, 3]
            target_dis_vec /= np.linalg.norm(target_dis_vec)
        else:
            target_pose_mat = transforms._toPose(target_pose).to_transformation_matrix()
            if pre_dis_axis == "fp":
                pre_dis_axis = [0.0, 0.0, 1.0]
            pre_dis_axis = np.array(pre_dis_axis)
            pre_dis_axis /= np.linalg.norm(pre_dis_axis)
            target_dis_vec = (target_pose_mat[:3, :3] @ np.array(pre_dis_axis).reshape(3, 1)).reshape(3)
            target_dis_vec /= np.linalg.norm(target_dis_vec)
        res_pose = (target_point - grasp_bias - pre_dis * target_dis_vec).tolist() + target_grasp_qpose.tolist()
        return res_pose

    def place_actor(
        self,
        actor: Actor,
        arm_tag: ArmTag,
        target_pose: list | np.ndarray,
        functional_point_id: int = None,
        pre_dis: float = 0.1,
        dis: float = 0.02,
        is_open: bool = True,
        **args,
    ):
        if not self.plan_success:
            return None, []
        if self.need_plan:
            place_pre_pose = self.get_place_pose(
                actor,
                arm_tag,
                target_pose,
                functional_point_id=functional_point_id,
                pre_dis=pre_dis,
                **args,
            )
            place_pose = self.get_place_pose(
                actor,
                arm_tag,
                target_pose,
                functional_point_id=functional_point_id,
                pre_dis=dis,
                **args,
            )
        else:
            place_pre_pose = [0, 0, 0, 0, 0, 0, 0]
            place_pose = [0, 0, 0, 0, 0, 0, 0]

        actions = [
            Action(arm_tag, "move", target_pose=place_pre_pose),
            Action(arm_tag, "move", target_pose=place_pose),
        ]
        if is_open:
            actions.append(Action(arm_tag, "open", target_gripper_pos=1.0))
        return arm_tag, actions

    def move_by_displacement(
        self,
        arm_tag: ArmTag,
        x: float = 0.0,
        y: float = 0.0,
        z: float = 0.0,
        quat: list = None,
        move_axis: Literal["world", "arm"] = "world",
    ):
        if arm_tag == "left":
            origin_pose = np.array(self.robot.get_left_ee_pose(), dtype=np.float64)
        elif arm_tag == "right":
            origin_pose = np.array(self.robot.get_right_ee_pose(), dtype=np.float64)
        else:
            raise ValueError(f'arm_tag must be either "left" or "right", not {arm_tag}')
        displacement = np.zeros(7, dtype=np.float64)
        if move_axis == "world":
            displacement[:3] = np.array([x, y, z], dtype=np.float64)
        else:
            dir_vec = transforms._toPose(origin_pose).to_transformation_matrix()[:3, 0]
            dir_vec /= np.linalg.norm(dir_vec)
            displacement[:3] = -z * dir_vec
        origin_pose += displacement
        if quat is not None:
            origin_pose[3:] = quat
        return arm_tag, [Action(arm_tag, "move", target_pose=origin_pose)]

    def move_to_pose(
        self,
        arm_tag: ArmTag,
        target_pose: list | np.ndarray | sapien.Pose,
    ):
        return arm_tag, [Action(arm_tag, "move", target_pose=target_pose)]

    def close_gripper(self, arm_tag: ArmTag, pos: float = 0.0):
        return arm_tag, [Action(arm_tag, "close", target_gripper_pos=pos)]

    def open_gripper(self, arm_tag: ArmTag, pos: float = 1.0):
        return arm_tag, [Action(arm_tag, "open", target_gripper_pos=pos)]

    def back_to_origin(self, arm_tag: ArmTag):
        if arm_tag == "left":
            return arm_tag, [Action(arm_tag, "move", self.robot.left_original_pose)]
        elif arm_tag == "right":
            return arm_tag, [Action(arm_tag, "move", self.robot.right_original_pose)]
        return None, []

    def get_arm_pose(self, arm_tag: ArmTag):
        if arm_tag == "left":
            return self.robot.get_left_ee_pose()
        elif arm_tag == "right":
            return self.robot.get_right_ee_pose()
        else:
            raise ValueError(f'arm_tag must be either "left" or "right", not {arm_tag}')

    # =========================================================== Control Robot ===========================================================

    def take_dense_action(self, control_seq, save_freq=-1):
        """
        control_seq:
            left_arm, right_arm, left_gripper, right_gripper
        """
        left_arm, left_gripper, right_arm, right_gripper = (
            control_seq["left_arm"],
            control_seq["left_gripper"],
            control_seq["right_arm"],
            control_seq["right_gripper"],
        )

        save_freq = self.save_freq if save_freq == -1 else save_freq
        if save_freq != None:
            self._take_picture()

        max_control_len = 0

        if left_arm is not None:
            max_control_len = max(max_control_len, left_arm["position"].shape[0])
        if left_gripper is not None:
            max_control_len = max(max_control_len, left_gripper["num_step"])
        if right_arm is not None:
            max_control_len = max(max_control_len, right_arm["position"].shape[0])
        if right_gripper is not None:
            max_control_len = max(max_control_len, right_gripper["num_step"])

        for control_idx in range(max_control_len):

            if (left_arm is not None and control_idx < left_arm["position"].shape[0]):  # control left arm
                self.robot.set_arm_joints(
                    left_arm["position"][control_idx],
                    left_arm["velocity"][control_idx],
                    "left",
                )

            if left_gripper is not None and control_idx < left_gripper["num_step"]:
                self.robot.set_gripper(
                    left_gripper["result"][control_idx],
                    "left",
                    left_gripper["per_step"],
                )  # TODO

            if (right_arm is not None and control_idx < right_arm["position"].shape[0]):  # control right arm
                self.robot.set_arm_joints(
                    right_arm["position"][control_idx],
                    right_arm["velocity"][control_idx],
                    "right",
                )

            if right_gripper is not None and control_idx < right_gripper["num_step"]:
                self.robot.set_gripper(
                    right_gripper["result"][control_idx],
                    "right",
                    right_gripper["per_step"],
                )  # TODO

            self.scene.step()

            if self.render_freq and control_idx % self.render_freq == 0:
                self._update_render()
                self.viewer.render()

            if save_freq != None and control_idx % save_freq == 0:
                self._update_render()
                self._take_picture()

        if save_freq != None:
            self._take_picture()

        return True  # TODO: maybe need try error

    # def take_action(self, action, action_type='qpos'):  # action_type: qpos or ee
    #     if self.take_action_cnt == self.step_lim:
    #         return

    #     eval_video_freq = 1  # fixed
    #     if (self.eval_video_path is not None and self.take_action_cnt % eval_video_freq == 0):
    #         self.eval_video_ffmpeg.stdin.write(self.now_obs["observation"]["head_camera"]["rgb"].tobytes())

    #     self.take_action_cnt += 1
    #     print(f"step: \033[92m{self.take_action_cnt} / {self.step_lim}\033[0m", end="\r")

    #     self._update_render()
    #     if self.render_freq:
    #         self.viewer.render()

    #     actions = np.array([action])
    #     left_jointstate = self.robot.get_left_arm_jointState()
    #     right_jointstate = self.robot.get_right_arm_jointState()
    #     left_arm_dim = len(left_jointstate) - 1
    #     right_arm_dim = len(right_jointstate) - 1
    #     current_jointstate = np.array(left_jointstate + right_jointstate)

    #     left_arm_actions, left_gripper_actions, left_current_qpos, left_path = (
    #         [],
    #         [],
    #         [],
    #         [],
    #     )
    #     right_arm_actions, right_gripper_actions, right_current_qpos, right_path = (
    #         [],
    #         [],
    #         [],
    #         [],
    #     )

    #     left_arm_actions, left_gripper_actions = (
    #         actions[:, :left_arm_dim],
    #         actions[:, left_arm_dim],
    #     )
    #     right_arm_actions, right_gripper_actions = (
    #         actions[:, left_arm_dim + 1:left_arm_dim + right_arm_dim + 1],
    #         actions[:, left_arm_dim + right_arm_dim + 1],
    #     )
    #     left_current_qpos, right_current_qpos = (
    #         current_jointstate[:left_arm_dim],
    #         current_jointstate[left_arm_dim + 1:left_arm_dim + right_arm_dim + 1],
    #     )
    #     left_current_gripper, right_current_gripper = (
    #         current_jointstate[left_arm_dim:left_arm_dim + 1],
    #         current_jointstate[left_arm_dim + right_arm_dim + 1:left_arm_dim + right_arm_dim + 2],
    #     )

    #     left_path = np.vstack((left_current_qpos, left_arm_actions))
    #     left_gripper_path = np.hstack((left_current_gripper, left_gripper_actions))
    #     right_path = np.vstack((right_current_qpos, right_arm_actions))
    #     right_gripper_path = np.hstack((right_current_gripper, right_gripper_actions))

    #     # ========== TOPP ==========
    #     # TODO
    #     topp_left_flag, topp_right_flag = True, True

    #     try:
    #         times, left_pos, left_vel, acc, duration = (self.robot.left_mplib_planner.TOPP(left_path,
    #                                                                                        1 / 250,
    #                                                                                        verbose=True))
    #         left_result = dict()
    #         left_result["position"], left_result["velocity"] = left_pos, left_vel
    #         left_n_step = left_result["position"].shape[0]
    #     except Exception as e:
    #         # print("left arm TOPP error: ", e)
    #         topp_left_flag = False
    #         left_n_step = 50  # fixed

    #     if left_n_step == 0:
    #         topp_left_flag = False
    #         left_n_step = 50  # fixed

    #     try:
    #         times, right_pos, right_vel, acc, duration = (self.robot.right_mplib_planner.TOPP(right_path,
    #                                                                                           1 / 250,
    #                                                                                           verbose=True))
    #         right_result = dict()
    #         right_result["position"], right_result["velocity"] = right_pos, right_vel
    #         right_n_step = right_result["position"].shape[0]
    #     except Exception as e:
    #         # print("right arm TOPP error: ", e)
    #         topp_right_flag = False
    #         right_n_step = 50  # fixed

    #     if right_n_step == 0:
    #         topp_right_flag = False
    #         right_n_step = 50  # fixed

    #     # ========== Gripper ==========

    #     left_mod_num = left_n_step % len(left_gripper_actions)
    #     right_mod_num = right_n_step % len(right_gripper_actions)
    #     left_gripper_step = [0] + [
    #         left_n_step // len(left_gripper_actions) + (1 if i < left_mod_num else 0)
    #         for i in range(len(left_gripper_actions))
    #     ]
    #     right_gripper_step = [0] + [
    #         right_n_step // len(right_gripper_actions) + (1 if i < right_mod_num else 0)
    #         for i in range(len(right_gripper_actions))
    #     ]

    #     left_gripper = []
    #     for gripper_step in range(1, left_gripper_path.shape[0]):
    #         region_left_gripper = np.linspace(
    #             left_gripper_path[gripper_step - 1],
    #             left_gripper_path[gripper_step],
    #             left_gripper_step[gripper_step] + 1,
    #         )[1:]
    #         left_gripper = left_gripper + region_left_gripper.tolist()
    #     left_gripper = np.array(left_gripper)

    #     right_gripper = []
    #     for gripper_step in range(1, right_gripper_path.shape[0]):
    #         region_right_gripper = np.linspace(
    #             right_gripper_path[gripper_step - 1],
    #             right_gripper_path[gripper_step],
    #             right_gripper_step[gripper_step] + 1,
    #         )[1:]
    #         right_gripper = right_gripper + region_right_gripper.tolist()
    #     right_gripper = np.array(right_gripper)

    #     now_left_id, now_right_id = 0, 0

    #     # ========== Control Loop ==========
    #     while now_left_id < left_n_step or now_right_id < right_n_step:

    #         if (now_left_id < left_n_step and now_left_id / left_n_step <= now_right_id / right_n_step):
    #             if topp_left_flag:
    #                 self.robot.set_arm_joints(
    #                     left_result["position"][now_left_id],
    #                     left_result["velocity"][now_left_id],
    #                     "left",
    #                 )
    #             self.robot.set_gripper(left_gripper[now_left_id], "left")

    #             now_left_id += 1

    #         if (now_right_id < right_n_step and now_right_id / right_n_step <= now_left_id / left_n_step):
    #             if topp_right_flag:
    #                 self.robot.set_arm_joints(
    #                     right_result["position"][now_right_id],
    #                     right_result["velocity"][now_right_id],
    #                     "right",
    #                 )
    #             self.robot.set_gripper(right_gripper[now_right_id], "right")

    #             now_right_id += 1

    #         self.scene.step()
    #         self._update_render()

    #         if self.check_success():
    #             self.eval_success = True
    #             return

    #     self._update_render()
    #     if self.render_freq:  # UI
    #         self.viewer.render()
    
    def compress_path(self, path, tolerance=1e-5):
        compressed = [path[0]]
        for i in range(1, len(path)):
            if not np.allclose(path[i], compressed[-1], atol=tolerance,rtol=0):
                compressed.append(path[i])
        
        if len(compressed) == 1:
            compressed.append(compressed[0].copy() + np.random.normal(0, 2e-4, 6))  
            
        return np.array(compressed)
    
    def downsample_trajectory(self, positions, velocities, max_points=500):
       
        n_points = positions.shape[0]
        assert positions.shape[0] == velocities.shape[0]
        if n_points <= max_points:
            return positions, velocities
        
        # 使用 linspace 生成均匀分布的索引，包含起点和终点
        indices = np.linspace(0, n_points - 1, max_points).astype(int)
        
        # 去除可能的重复索引
        indices = np.unique(indices)
        
        return positions[indices], velocities[indices]
    
    
    def take_action(self, action, action_type='qpos'):  # action_type: qpos or ee
        if self.take_action_cnt == self.step_lim:
            print(f"Should not reach here self.take_action_cnt == self.step_lim", flush=True)
            return

        # eval_video_freq = 1  # fixed
        # if (self.eval_video_path is not None and self.take_action_cnt % eval_video_freq == 0):
        #     self.eval_video_ffmpeg.stdin.write(self.now_obs["observation"]["head_camera"]["rgb"].tobytes())

        self.take_action_cnt += action.shape[0]
        #print(f"step: \033[92m{self.take_action_cnt} / {self.step_lim}\033[0m", end="\r")

        self._update_render()
        if self.render_freq:
            self.viewer.render()

        #actions = np.array([action])
        actions = action
        left_jointstate = self.robot.get_left_arm_jointState()
        right_jointstate = self.robot.get_right_arm_jointState()
        left_arm_dim = len(left_jointstate) - 1
        right_arm_dim = len(right_jointstate) - 1
        current_jointstate = np.array(left_jointstate + right_jointstate)

        left_arm_actions, left_gripper_actions, left_current_qpos, left_path = (
            [],
            [],
            [],
            [],
        )
        right_arm_actions, right_gripper_actions, right_current_qpos, right_path = (
            [],
            [],
            [],
            [],
        )

        left_arm_actions, left_gripper_actions = (
            actions[:, :left_arm_dim],
            actions[:, left_arm_dim],
        )
        right_arm_actions, right_gripper_actions = (
            actions[:, left_arm_dim + 1:left_arm_dim + right_arm_dim + 1],
            actions[:, left_arm_dim + right_arm_dim + 1],
        )
        left_current_qpos, right_current_qpos = (
            current_jointstate[:left_arm_dim],
            current_jointstate[left_arm_dim + 1:left_arm_dim + right_arm_dim + 1],
        )
        left_current_gripper, right_current_gripper = (
            current_jointstate[left_arm_dim:left_arm_dim + 1],
            current_jointstate[left_arm_dim + right_arm_dim + 1:left_arm_dim + right_arm_dim + 2],
        )

        left_path = np.vstack((left_current_qpos, left_arm_actions))
        left_gripper_path = np.hstack((left_current_gripper, left_gripper_actions))
        right_path = np.vstack((right_current_qpos, right_arm_actions))
        right_gripper_path = np.hstack((right_current_gripper, right_gripper_actions))
        
        left_path = self.compress_path(left_path)
        right_path = self.compress_path(right_path)

        # ========== TOPP ==========
        # TODO
        topp_left_flag, topp_right_flag = True, True

        try:
            times, left_pos, left_vel, acc, duration = (self.robot.left_mplib_planner.TOPP(left_path,
                                                                                           1 / 250,
                                                                                           verbose=True))
            
            left_pos, left_vel = self.downsample_trajectory(left_pos, left_vel)
            
            left_result = dict()
            left_result["position"], left_result["velocity"] = left_pos, left_vel
            left_n_step = left_result["position"].shape[0]
        except Exception as e:
            # print("left arm TOPP error: ", e)
            topp_left_flag = False
            left_n_step = 50  # fixed

        if left_n_step == 0:
            topp_left_flag = False
            left_n_step = 50  # fixed

        try:
            times, right_pos, right_vel, acc, duration = (self.robot.right_mplib_planner.TOPP(right_path,
                                                                                              1 / 250,
                                                                                              verbose=True))
            
            right_pos, right_vel = self.downsample_trajectory(right_pos, right_vel)
            
            right_result = dict()
            right_result["position"], right_result["velocity"] = right_pos, right_vel
            right_n_step = right_result["position"].shape[0]
        except Exception as e:
            # print("right arm TOPP error: ", e)
            topp_right_flag = False
            right_n_step = 50  # fixed

        if right_n_step == 0:
            topp_right_flag = False
            right_n_step = 50  # fixed

        # ========== Gripper ==========

        left_mod_num = left_n_step % len(left_gripper_actions)
        right_mod_num = right_n_step % len(right_gripper_actions)
        left_gripper_step = [0] + [
            left_n_step // len(left_gripper_actions) + (1 if i < left_mod_num else 0)
            for i in range(len(left_gripper_actions))
        ]
        right_gripper_step = [0] + [
            right_n_step // len(right_gripper_actions) + (1 if i < right_mod_num else 0)
            for i in range(len(right_gripper_actions))
        ]

        left_gripper = []
        for gripper_step in range(1, left_gripper_path.shape[0]):
            region_left_gripper = np.linspace(
                left_gripper_path[gripper_step - 1],
                left_gripper_path[gripper_step],
                left_gripper_step[gripper_step] + 1,
            )[1:]
            left_gripper = left_gripper + region_left_gripper.tolist()
        left_gripper = np.array(left_gripper)

        right_gripper = []
        for gripper_step in range(1, right_gripper_path.shape[0]):
            region_right_gripper = np.linspace(
                right_gripper_path[gripper_step - 1],
                right_gripper_path[gripper_step],
                right_gripper_step[gripper_step] + 1,
            )[1:]
            right_gripper = right_gripper + region_right_gripper.tolist()
        right_gripper = np.array(right_gripper)

        now_left_id, now_right_id = 0, 0

        # ========== Control Loop ==========
        while now_left_id < left_n_step or now_right_id < right_n_step:

            if (now_left_id < left_n_step and now_left_id / left_n_step <= now_right_id / right_n_step):
                if topp_left_flag:
                    self.robot.set_arm_joints(
                        left_result["position"][now_left_id],
                        left_result["velocity"][now_left_id],
                        "left",
                    )
                self.robot.set_gripper(left_gripper[now_left_id], "left")

                now_left_id += 1

            if (now_right_id < right_n_step and now_right_id / right_n_step <= now_left_id / left_n_step):
                if topp_right_flag:
                    self.robot.set_arm_joints(
                        right_result["position"][now_right_id],
                        right_result["velocity"][now_right_id],
                        "right",
                    )
                self.robot.set_gripper(right_gripper[now_right_id], "right")

                now_right_id += 1

            self.scene.step()
            self._update_render()

            if self.check_success():
                self.eval_success = True
                return

        self._update_render()
        if self.render_freq:  # UI
            self.viewer.render()


    def save_camera_images(self, task_name, step_name, generate_num_id, save_dir="./camera_images"):
        """
        Save camera images - patched version to ensure consistent episode numbering across all steps.

        Args:
            task_name (str): Name of the task.
            step_name (str): Name of the step.
            generate_num_id (int): Generated ID used to create subfolders under the task directory.
            save_dir (str): Base directory to save images, default is './camera_images'.

        Returns:
            dict: A dictionary containing image data from each camera.
        """
        # print(f"Received generate_num_id in save_camera_images: {generate_num_id}")

        # Create a subdirectory specific to the task
        task_dir = os.path.join(save_dir, task_name)
        os.makedirs(task_dir, exist_ok=True)
        
        # Create a subdirectory for the given generate_num_id
        generate_dir = os.path.join(task_dir, generate_num_id)
        os.makedirs(generate_dir, exist_ok=True)
        
        obs = self.get_obs()
        cam_obs = obs["observation"]
        image_data = {}

        # Extract step number and description from step_name using regex
        match = re.match(r'(step[_]?\d+)(?:_(.*))?', step_name)
        if match:
            step_num = match.group(1)
            step_description = match.group(2) if match.group(2) else ""
        else:
            step_num = None
            step_description = step_name

        # Only process head_camera
        cam_name = "head_camera"
        if cam_name in cam_obs:
            rgb = cam_obs[cam_name]["rgb"]
            if rgb.dtype != np.uint8:
                rgb = (rgb * 255).clip(0, 255).astype(np.uint8)
            
            # Use the instance's ep_num as the episode number
            episode_num = getattr(self, 'ep_num', 0)
            
            # Save image to the subdirectory for the specific generate_num_id
            filename = f"episode{episode_num}_{step_num}_{step_description}.png"
            filepath = os.path.join(generate_dir, filename)
            imageio.imwrite(filepath, rgb)
            image_data[cam_name] = rgb
            
            # print(f"Saving image with episode_num={episode_num}, filename: {filename}, path: {generate_dir}")
        
        return image_data
