# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from matplotlib.pyplot import axis
import numpy as np
import os
import random
import torch
import pickle
import math

from utils.torch_jit_utils import *

# from isaacgym.torch_utils import *

from tasks.hand_base.base_task import BaseTask
from isaacgym import gymtorch
from isaacgym import gymapi

import matplotlib.pyplot as plt
from PIL import Image as Im
from utils import o3dviewer
import cv2
from torch import nn
import torch.nn.functional as F
# from utils_pb.isaac_blender_recorder import IsaacBlenderRecorder
from scipy.interpolate import interp1d

class TrajEstimator(nn.Module):
    """Fully connected network with three ELU hidden layers (512, 256, 128).

    ``forward`` returns both the final linear projection and the activations
    of the last hidden layer.
    """

    def __init__(self, input_dim, output_dim):
        """Create the layers.

        Args:
            input_dim: Size of the last dimension of the input.
            output_dim: Size of the last dimension of the first returned tensor.
        """
        super().__init__()
        first_width, second_width, third_width = 512, 256, 128

        # Layers are created in forward order so seeded initialisation is stable.
        self.linear1 = nn.Linear(input_dim, first_width)
        self.linear2 = nn.Linear(first_width, second_width)
        self.linear3 = nn.Linear(second_width, third_width)
        self.output_layer = nn.Linear(third_width, output_dim)

        self.activate_func = nn.ELU()

    def forward(self, inputs):
        """Run the network.

        Returns:
            A ``(outputs, features)`` tuple: ``outputs`` is the result of the
            output layer and ``features`` is the 128-wide activation of the
            last hidden layer that was fed into it.
        """
        features = inputs
        for hidden_layer in (self.linear1, self.linear2, self.linear3):
            features = self.activate_func(hidden_layer(features))

        return self.output_layer(features), features


class TemporaryGrad:
    """Context manager that turns autograd on inside the ``with`` body.

    The grad mode that was active on entry is remembered in ``self.prev`` and
    put back on exit, whether or not the body raised. Nothing is bound by
    ``as``, and exceptions are never suppressed.
    """

    def __enter__(self):
        # Record the current mode first so __exit__ can restore it.
        was_enabled = torch.is_grad_enabled()
        self.prev = was_enabled
        torch.set_grad_enabled(True)

    def __exit__(self, exc_type, exc_value, traceback):
        # Deliberately not returned: a truthy return would swallow exceptions.
        torch.set_grad_enabled(self.prev)


class AllegroHandArcticMultiObjectUse01WithArm(BaseTask):
    def __init__(
        self,
        cfg,
        sim_params,
        physics_engine,
        device_type,
        device_id,
        headless,
        agent_index=None,
        is_multi_agent=False,
    ):
        """Read the task configuration, build the simulation and wrap its state tensors.

        The first half only reads ``cfg`` and fills in derived entries; the
        simulation itself is created by ``super().__init__``. Everything after
        that call relies on attributes that are not assigned in this method
        (``self.gym``, ``self.sim``, ``self.device``, ``self.obs_buf``,
        ``self.obj_params``, ``self.trans_r`` ...) and are presumably set up by
        the base class and the env-creation code -- verify against ``BaseTask``.

        Args:
            cfg: Nested mapping with at least "task" and "env" sections. It is
                also written to (numObservations, numStates, numActions,
                device_type, device_id, headless) before being handed to the
                base class.
            sim_params: Simulation parameters; ``sim_params.dt`` is read when
                "resetTime" is configured.
            physics_engine: Physics backend identifier, stored as-is.
            device_type: Copied into ``cfg["device_type"]``.
            device_id: Copied into ``cfg["device_id"]``.
            headless: Copied into ``cfg["headless"]``.
            agent_index: Stored as ``self.agent_index``. When omitted, a fresh
                ``[[[0, 1, 2, 3, 4, 5]], [[0, 1, 2, 3, 4, 5]]]`` list is
                created for this instance.
            is_multi_agent: When true the task reports 2 agents and 30
                actions, otherwise 1 agent and 60 actions.

        Raises:
            AssertionError: If ``cfg["env"]["objectType"]`` is not an accepted type.
            KeyError: If a required config key is missing, or the observation
                type has no entry in ``num_obs_dict``.
            ValueError: If ``used_hand_type`` has no known base body name.
        """
        if agent_index is None:
            # Built per call so that instances never share one mutable default.
            agent_index = [[[0, 1, 2, 3, 4, 5]], [[0, 1, 2, 3, 4, 5]]]

        self.cfg = cfg
        self.sim_params = sim_params
        self.physics_engine = physics_engine
        self.agent_index = agent_index

        self.is_multi_agent = is_multi_agent

        self.randomize = self.cfg["task"]["randomize"]
        self.randomization_params = self.cfg["task"]["randomization_params"]
        self.aggregate_mode = self.cfg["env"]["aggregateMode"]

        # Reward shaping terms.
        self.dist_reward_scale = self.cfg["env"]["distRewardScale"]
        self.rot_reward_scale = self.cfg["env"]["rotRewardScale"]
        self.action_penalty_scale = self.cfg["env"]["actionPenaltyScale"]
        self.success_tolerance = self.cfg["env"]["successTolerance"]
        self.reach_goal_bonus = self.cfg["env"]["reachGoalBonus"]
        self.fall_dist = self.cfg["env"]["fallDistance"]
        self.fall_penalty = self.cfg["env"]["fallPenalty"]
        self.rot_eps = self.cfg["env"]["rotEps"]

        self.vel_obs_scale = 0.2  # scale factor of velocity based observations
        self.force_torque_obs_scale = 10.0  # scale factor of force/torque based observations

        # Reset noise.
        self.reset_position_noise = self.cfg["env"]["resetPositionNoise"]
        self.reset_rotation_noise = self.cfg["env"]["resetRotationNoise"]
        self.reset_dof_pos_noise = self.cfg["env"]["resetDofPosRandomInterval"]
        self.reset_dof_vel_noise = self.cfg["env"]["resetDofVelRandomInterval"]

        self.allegro_hand_dof_speed_scale = self.cfg["env"]["dofSpeedScale"]
        self.use_relative_control = self.cfg["env"]["useRelativeControl"]
        self.act_moving_average = self.cfg["env"]["actionsMovingAverage"]

        self.debug_viz = self.cfg["env"]["enableDebugVis"]

        self.max_episode_length = self.cfg["env"]["episodeLength"]
        self.reset_time = self.cfg["env"].get("resetTime", -1.0)
        self.print_success_stat = self.cfg["env"]["printNumSuccesses"]
        self.max_consecutive_successes = self.cfg["env"]["maxConsecutiveSuccesses"]
        self.av_factor = self.cfg["env"].get("averFactor", 0.01)
        print("Averaging factor: ", self.av_factor)

        # A positive reset time (seconds) takes precedence over "episodeLength".
        control_freq_inv = self.cfg["env"].get("controlFrequencyInv", 1)
        if self.reset_time > 0.0:
            self.max_episode_length = int(
                round(self.reset_time / (control_freq_inv * self.sim_params.dt))
            )
            print("Reset time: ", self.reset_time)
            print("New episode length: ", self.max_episode_length)

        self.object_type = self.cfg["env"]["objectType"]
        assert self.object_type in [
            "block",
            "egg",
            "pen",
            "ycb/banana",
            "ycb/can",
            "ycb/mug",
            "ycb/brick",
        ]

        self.ignore_z = self.object_type == "pen"

        # Asset file of every supported ARCTIC object, relative to the asset root.
        self.asset_files_dict = {
            "box": "arctic_assets/object_urdf/box.urdf",
            "scissors": "arctic_assets/object_urdf/scissors.urdf",
            "microwave": "arctic_assets/object_urdf/microwave.urdf",
            "laptop": "arctic_assets/object_urdf/laptop.urdf",
            "capsulemachine": "arctic_assets/object_urdf/capsulemachine.urdf",
            "ketchup": "arctic_assets/object_urdf/ketchup.urdf",
            "mixer": "arctic_assets/object_urdf/mixer.urdf",
            "notebook": "arctic_assets/object_urdf/notebook.urdf",
            "phone": "arctic_assets/object_urdf/phone.urdf",
            "waffleiron": "arctic_assets/object_urdf/waffleiron.urdf",
            "espressomachine": "arctic_assets/object_urdf/espressomachine.urdf",
        }

        self.used_training_objects = self.cfg["env"]["used_training_objects"]
        self.used_hand_type = self.cfg["env"]["used_hand_type"]
        self.traj_index = self.cfg["env"]["traj_index"]

        # The first entry may be a keyword that is expanded or remapped below.
        if self.used_training_objects[0] == "all":
            self.used_training_objects = ["box", "capsulemachine", "espressomachine", "ketchup", "laptop", "microwave", "mixer", "notebook", "phone", "scissors", "waffleiron"]

        self.is_microwave_real = False
        if self.used_training_objects[0] == "microwave_real":
            self.used_training_objects = ["microwave"]
            self.is_microwave_real = True

        self.is_notebook_real = False
        if self.used_training_objects[0] == "notebook_real":
            self.used_training_objects = ["notebook"]
            self.is_notebook_real = True
            self.asset_files_dict["notebook"] = "arctic_assets/object_urdf/notebook_real.urdf"

        self.is_box_traj_laptop = False
        if self.used_training_objects[0] == "box_traj_laptop":
            self.used_training_objects = ["laptop"]
            self.is_box_traj_laptop = True

        # Only keys present in num_obs_dict below are usable; anything else
        # raises KeyError when the observation size is looked up.
        self.obs_type = self.cfg["env"]["observationType"]

        # if not (self.obs_type in ["point_cloud", "full_state"]):
        #     raise Exception(
        #         "Unknown type of observations!\nobservationType should be one of: [point_cloud, full_state]")

        print("Obs type:", self.obs_type)

        self.num_point_cloud_feature_dim = 384
        self.one_frame_num_obs = 178
        self.num_obs_dict = {
            "full_state": 571,
        }
        # self.num_obs_dict = {
        #     "point_cloud": 111 + self.num_point_cloud_feature_dim * 3,
        #     "point_cloud_for_distill": 111 + self.num_point_cloud_feature_dim * 3,
        #     "full_state": 111
        # }
        self.contact_sensor_names = ["wrist_2_link", "wrist_1_link", "shoulder_link", "upper_arm_link", "forearm_link"]
        # self.contact_sensor_names = ["ffdistal", "mfdistal", "rfdistal", "lfdistal", "thdistal"]
        # self.contact_sensor_names = ["wrist"]

        self.up_axis = 'z'

        self.use_vel_obs = False
        self.fingertip_obs = True
        self.asymmetric_obs = self.cfg["env"]["asymmetric_observations"]

        # A separate state buffer is only sized when asymmetric observations are on.
        # num_states = 215 + 384 * 3
        num_states = 571 if self.asymmetric_obs else 0

        self.cfg["env"]["numObservations"] = self.num_obs_dict[self.obs_type]
        self.cfg["env"]["numStates"] = num_states
        if self.is_multi_agent:
            self.num_agents = 2
            self.cfg["env"]["numActions"] = 30
        else:
            self.num_agents = 1
            self.cfg["env"]["numActions"] = 60

        self.cfg["device_type"] = device_type
        self.cfg["device_id"] = device_id
        self.cfg["headless"] = headless

        self.enable_camera_sensors = self.cfg["env"]["enableCameraSensors"]
        self.camera_debug = self.cfg["env"].get("cameraDebug", False)
        self.point_cloud_debug = self.cfg["env"].get("pointCloudDebug", False)
        self.num_envs = cfg["env"]["numEnvs"]

        # ablation study
        self.use_fingertip_ik = self.cfg["env"]["use_fingertip_ik"]
        self.use_joint_space_ik = self.cfg["env"]["use_joint_space_ik"]
        self.use_fingertip_reward = self.cfg["env"]["use_fingertip_reward"]
        self.use_hierarchy = self.cfg["env"]["use_hierarchy"]
        self.use_p_c_impro_loop = self.cfg["env"]["use_p_c_impro_loop"]

        if self.point_cloud_debug:
            # Imported lazily so open3d is only required when the debug view is on.
            import open3d as o3d
            from utils.o3dviewer import PointcloudVisualizer

            self.pointCloudVisualizer = PointcloudVisualizer()
            self.pointCloudVisualizerInitialized = False
            self.o3d_pc = o3d.geometry.PointCloud()
        else:
            self.pointCloudVisualizer = None

        super().__init__(cfg=self.cfg, enable_camera_sensors=self.enable_camera_sensors)

        if self.viewer is not None:
            cam_pos = gymapi.Vec3(0.5, -0.0, 1.2)
            cam_target = gymapi.Vec3(-0.5, -0.0, 0.2)

            # cam_pos = gymapi.Vec3(self.obj_params[1, 4]+ 0.5, self.obj_params[1, 5] + 0.5, self.obj_params[1, 6] + 0.5 + 0.5)
            # cam_target = gymapi.Vec3(self.obj_params[1, 4], self.obj_params[1, 5], self.obj_params[1, 6])

            self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)

        # get gym GPU state tensors
        actor_root_state_tensor = self.gym.acquire_actor_root_state_tensor(self.sim)
        dof_state_tensor = self.gym.acquire_dof_state_tensor(self.sim)
        rigid_body_tensor = self.gym.acquire_rigid_body_state_tensor(self.sim)
        contact_tensor = self.gym.acquire_net_contact_force_tensor(self.sim)
        self.jacobian_tensor = gymtorch.wrap_tensor(
            self.gym.acquire_jacobian_tensor(self.sim, "hand")
        )
        self.another_jacobian_tensor = gymtorch.wrap_tensor(
            self.gym.acquire_jacobian_tensor(self.sim, "another_hand")
        )

        # Per-env DOF layout used throughout: [hand | another hand | extra DOFs].
        dof_force_tensor = self.gym.acquire_dof_force_tensor(self.sim)
        self.dof_force_tensor = gymtorch.wrap_tensor(dof_force_tensor).view(self.num_envs, self.num_allegro_hand_dofs*2 + 2)

        self.gym.refresh_actor_root_state_tensor(self.sim)
        self.gym.refresh_dof_state_tensor(self.sim)
        self.gym.refresh_rigid_body_state_tensor(self.sim)
        self.gym.refresh_net_contact_force_tensor(self.sim)
        self.gym.refresh_dof_force_tensor(self.sim)

        # ur10e ok bottom
        # [-1.596, -1.814, -2.055, -0.871, 1.571, 0.0]

        # Default joint targets: only the first six DOFs get a non-zero default.
        self.another_allegro_hand_default_dof_pos = torch.zeros(
            self.num_allegro_hand_dofs, dtype=torch.float, device=self.device
        )

        self.another_allegro_hand_default_dof_pos[:6] = to_torch(
            [3.4991441036750577, -1.310780687961321, -2.128748927522598, -2.84180679300243, -1.2157104341775433, 3.1342631916289605-3.1415],
            # [-0.6073, -1.5811,  1.8363, -0.3685,  5.7037, -6.1832],
            # [-0.6629, -1.4276,  2.1661, -3.8718,  0.6387, -3.1889],
            dtype=torch.float,
            device=self.device,
        )

        self.allegro_hand_default_dof_pos = torch.zeros(
            self.num_allegro_hand_dofs, dtype=torch.float, device=self.device
        )

        self.allegro_hand_default_dof_pos[:6] = to_torch(
            [-0.4235312584306925, -1.8417856713793022, 2.1118022259904565, -0.26705746630618066, 1.1434836562123438, -3.150733285519455],
            # [-0.0468, -1.7186,  2.1906, -2.0794,  3.2086, -4.8109],
            dtype=torch.float,
            device=self.device,
        )

        self.object_default_dof_pos = to_torch(
            [self.obj_params[0, 0, 0]], dtype=torch.float, device=self.device
        )

        # Slices of the DOF state tensor, taken per environment.
        self.dof_state = gymtorch.wrap_tensor(dof_state_tensor)
        self.allegro_hand_dof_state = self.dof_state.view(self.num_envs, -1, 2)[
            :, : self.num_allegro_hand_dofs
        ]
        self.allegro_hand_dof_pos = self.allegro_hand_dof_state[..., 0]
        self.allegro_hand_dof_vel = self.allegro_hand_dof_state[..., 1]

        self.allegro_hand_another_dof_state = self.dof_state.view(self.num_envs, -1, 2)[
            :, self.num_allegro_hand_dofs : self.num_allegro_hand_dofs * 2
        ]
        self.allegro_hand_another_dof_pos = self.allegro_hand_another_dof_state[..., 0]
        self.allegro_hand_another_dof_vel = self.allegro_hand_another_dof_state[..., 1]

        self.env_dof_state = self.dof_state.view(self.num_envs, -1, 2)

        # The single DOF right after both hands is treated as the object's joint.
        self.object_dof_state = self.dof_state.view(self.num_envs, -1, 2)[
            :, self.num_allegro_hand_dofs * 2 : self.num_allegro_hand_dofs * 2 + 1
        ]
        self.object_dof_pos = self.object_dof_state[..., 0]
        self.object_dof_vel = self.object_dof_state[..., 1]

        self.rigid_body_states = gymtorch.wrap_tensor(rigid_body_tensor).view(self.num_envs, -1, 13)
        self.num_bodies = self.rigid_body_states.shape[1]

        # 13 columns per actor: position (3), orientation (4), lin vel (3), ang vel (3).
        self.root_state_tensor = gymtorch.wrap_tensor(actor_root_state_tensor).view(-1, 13)
        self.hand_positions = self.root_state_tensor[:, 0:3]
        self.hand_orientations = self.root_state_tensor[:, 3:7]
        self.hand_linvels = self.root_state_tensor[:, 7:10]
        self.hand_angvels = self.root_state_tensor[:, 10:13]
        self.saved_root_tensor = self.root_state_tensor.clone()

        self.contact_tensor = gymtorch.wrap_tensor(contact_tensor).view(self.num_envs, -1)

        self.num_dofs = self.gym.get_sim_dof_count(self.sim) // self.num_envs
        self.prev_targets = torch.zeros(
            (self.num_envs, self.num_dofs), dtype=torch.float, device=self.device
        )
        self.cur_targets = torch.zeros(
            (self.num_envs, self.num_dofs), dtype=torch.float, device=self.device
        )
        self.object_init_quat = torch.zeros(
            (self.num_envs, 4), dtype=torch.float, device=self.device
        )

        self.x_unit_tensor = to_torch([1, 0, 0], dtype=torch.float, device=self.device).repeat(
            (self.num_envs, 1)
        )
        self.y_unit_tensor = to_torch([0, 1, 0], dtype=torch.float, device=self.device).repeat(
            (self.num_envs, 1)
        )
        self.z_unit_tensor = to_torch([0, 0, 1], dtype=torch.float, device=self.device).repeat(
            (self.num_envs, 1)
        )

        self.reset_goal_buf = self.reset_buf.clone()
        self.successes = torch.zeros(self.num_envs, dtype=torch.float, device=self.device)
        self.consecutive_successes = torch.zeros(1, dtype=torch.float, device=self.device)

        self.av_factor = to_torch(self.av_factor, dtype=torch.float, device=self.device)
        self.object_pose_for_open_loop = torch.zeros_like(
            self.root_state_tensor[self.object_indices, 0:7]
        )

        self.total_successes = 0
        self.total_resets = 0

        # Three zero-filled history frames for observations and states.
        self.state_buf_stack_frames = []
        self.obs_buf_stack_frames = []

        for _ in range(3):
            self.obs_buf_stack_frames.append(
                torch.zeros_like(self.obs_buf[:, 0 : self.one_frame_num_obs])
            )
            self.state_buf_stack_frames.append(torch.zeros_like(self.states_buf[:, 0:215]))

        self.object_seq_len = 20
        self.object_state_stack_frames = torch.zeros(
            (self.num_envs, self.object_seq_len * 3), dtype=torch.float, device=self.device
        )

        self.proprioception_close_loop = torch.zeros_like(self.allegro_hand_dof_pos[:, 0:22])

        # Rigid body used as the hand base, per hand type:
        # (body name on the "another" hand, body name on the primary hand).
        base_body_names = {
            "shadow": ("wrist", "wrist"),
            "allegro": ("link_0.0", "link_0.0"),
            "schunk": ("left_hand_k", "right_hand_k"),
            "ability": ("index_L1", "index_L1"),
        }
        if self.used_hand_type not in base_body_names:
            # Fail here with a clear message instead of an AttributeError below.
            raise ValueError(
                "Unrecognized hand type! Hand type should be one of: [shadow, allegro, schunk, ability]"
            )
        another_base_name, base_name = base_body_names[self.used_hand_type]
        self.another_hand_base_rigid_body_index = self.gym.find_actor_rigid_body_index(
            self.envs[0], self.another_hand_indices[0], another_base_name, gymapi.DOMAIN_ENV
        )
        self.hand_base_rigid_body_index = self.gym.find_actor_rigid_body_index(
            self.envs[0], self.hand_indices[0], base_name, gymapi.DOMAIN_ENV
        )
        print("hand_base_rigid_body_index: ", self.hand_base_rigid_body_index)
        print("another_hand_base_rigid_body_index: ", self.another_hand_base_rigid_body_index)

        # with open("./demo_throw.pkl", "rb") as f:
        #     self.demo_throw = pickle.load(f)
        # # self.demo_throw = to_torch(self.demo_throw['qpos'], dtype=torch.float, device=self.device).unsqueeze(0).repeat(self.num_envs, 1, 1)
        # self.demo_throw = to_torch(self.demo_throw['qpos'], dtype=torch.float, device=self.device)
        self.rb_forces = torch.zeros(
            (self.num_envs, self.num_bodies, 3), dtype=torch.float, device=self.device
        )
        # NOTE(review): hard-coded index, presumably the object's rigid body -- confirm.
        self.object_rb_handles = 94
        self.perturb_direction = torch_rand_float(
            -1, 1, (self.num_envs, 6), device=self.device
        ).squeeze(-1)

        self.predict_pose = self.goal_init_state[:, 0:3].clone()

        self.apply_forces = torch.zeros(
            (self.num_envs, self.num_bodies, 3), device=self.device, dtype=torch.float
        )
        self.apply_torque = torch.zeros(
            (self.num_envs, self.num_bodies, 3), device=self.device, dtype=torch.float
        )

        # Frame 0 of the reference trajectories serves as the initial pose.
        self.r_pos_global_init = self.trans_r[:, 0].clone()
        self.r_rot_global_init = self.rot_r_quat[:, 0].clone()
        self.l_pos_global_init = self.trans_l[:, 0].clone()
        self.l_rot_global_init = self.rot_l_quat[:, 0].clone()
        # NOTE(review): unlike its siblings this is a view, not a clone -- confirm intended.
        self.obj_pos_global_init = self.obj_params[:, 0, 4:7]
        self.obj_rot_global_init = self.obj_rot_quat[:, 0, 0:4].clone()
        self.obj_joint_init = self.obj_params[:, 0, 0:1].clone()

        # Overrides the config / reset-time episode length computed above.
        self.max_episode_length = self.trans_r.shape[1]
        self.init_step_buf = torch.zeros_like(self.progress_buf)
        self.end_step_buf = torch.zeros_like(self.progress_buf)

        self.last_actions = torch.zeros(
            (self.num_envs, self.num_actions), device=self.device, dtype=torch.float
        )

        # NOTE(review): body index 6 is presumably the link the hand is mounted on -- confirm.
        self.allegro_right_hand_pos = self.rigid_body_states[:, 6, 0:3]
        self.allegro_right_hand_rot = self.rigid_body_states[:, 6, 3:7]
        self.allegro_left_hand_pos = self.rigid_body_states[:, 6 + self.num_allegro_hand_bodies, 0:3]
        self.allegro_left_hand_rot = self.rigid_body_states[:, 6 + self.num_allegro_hand_bodies, 3:7]

        # The module-level ``pickle`` import is used below. A function-local
        # ``import pickle`` later in this method would turn the name into a
        # local and make this earlier use raise UnboundLocalError.
        self.use_calibrated_init_state = False
        if self.use_calibrated_init_state:
            # NOTE: pickle must only be used on trusted files.
            with open("./calibrated_init_state.pkl", "rb") as f:
                self.init_dof_state = pickle.load(f).clone()

        self.train_teacher_policy = True
        self.apply_perturbation = False

        self.sim2real_record = False
        if self.sim2real_record:
            recorded_keys = (
                "left_hand_targets",
                "right_hand_targets",
                "right_hand_ee_pos",
                "right_hand_ee_rot",
                "r_pos_global",
                "r_rot_global",
                "left_hand_ee_pos",
                "left_hand_ee_rot",
                "l_pos_global",
                "l_rot_global",
                "obs_buf",
            )
            self.data = {key: [] for key in recorded_keys}

        self.test_record = False

        self.load_sim_record = False
        if self.load_sim_record:
            # The path is empty by default and must be filled in before enabling this flag.
            file_path = ''
            # Open the file in binary read mode
            with open(file_path, 'rb') as file:
                data = pickle.load(file)

            self.f_actor_root_state_tensor = data["root_state_tensor"]
            self.f_dof_state_tensor = data["dof_state"]

        # Fixed transforms between the simulation and real-robot frames.
        self.r2 = torch.tensor([[-0.999999,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0,  0.999999],
         [ 0.0000e+00,  0.999999,  0.0]], device=self.device)
        self.sim_to_real_rotation_quaternion = torch.tensor([[0.0000, 0.0000, -0.7071, 0.7071]], device=self.device)
        self.sim_to_real_translation_matrix = torch.tensor([[0.0, -1.0, -0.0, -0.0],
                                            [1.0, 0.0, 0.0, 0.0],
                                            [0.0, -0.0, 1.0, 0],
                                            [0.0, 0.0, 0.0, 1.0]], dtype=torch.float32, device=self.device)
        self.sim_to_real_object_quaternion = torch.tensor([[-0.0000, -0.0000, -0.3825,  0.9240]], device=self.device)

        self.allegro_right_hand_targets_pos_real = torch.tensor([[0.54, 0.0, 0.5]], device=self.device).repeat(self.num_envs, 1)
        self.allegro_right_hand_targets_rot_real = torch.tensor([[0.5,  -0.5, 0.5, -0.5]], device=self.device).repeat(self.num_envs, 1)
        self.allegro_left_hand_targets_pos_real = torch.tensor([[0.54, -0.0,  0.5]], device=self.device).repeat(self.num_envs, 1)
        self.allegro_left_hand_targets_rot_real = torch.tensor([[0.5,  -0.5, 0.5, -0.5]], device=self.device).repeat(self.num_envs, 1)

        self.complete_percentage = torch.zeros_like(self.progress_buf).float()

        self.left_hand_fingertip_pos_list = torch.zeros((5, self.num_envs, 3), device=self.device, dtype=torch.float32)
        self.right_hand_fingertip_pos_list = torch.zeros((5, self.num_envs, 3), device=self.device, dtype=torch.float32)

        # Optional IK back-ends for the ablations; fingertip IK wins if both are set.
        if self.use_fingertip_ik:
            from high_level_planner.pybullet_ik_solver import PybulletIKSolver
            self.ik_solver = PybulletIKSolver("", "", hand_type="shadow")
        elif self.use_joint_space_ik:
            from high_level_planner.anyteleop_solver import TestOptimizer
            self.ik_solver = TestOptimizer(hand_type="shadow")

    def create_sim(self):
        """Create the simulation, load the object assets, then add ground and envs.

        Stores the simulation step in ``self.dt``, the up-axis index in
        ``self.up_axis_idx`` and the simulation handle in ``self.sim``.
        Environments are laid out with ``int(sqrt(num_envs))`` per row.
        """
        self.dt = self.sim_params.dt
        self.up_axis_idx = self.set_sim_params_up_axis(self.sim_params, self.up_axis)
        # self.sim_params.physx.max_gpu_contact_pairs = self.sim_params.physx.max_gpu_contact_pairs

        self.sim = super().create_sim(
            self.device_id,
            self.graphics_device_id,
            self.physics_engine,
            self.sim_params,
        )

        # Object assets are resolved relative to this source file, not the CWD.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        self.create_object_asset_dict(os.path.join(module_dir, '../../assets'))

        self._create_ground_plane()

        env_spacing = self.cfg["env"]['envSpacing']
        envs_per_row = int(np.sqrt(self.num_envs))
        self._create_envs(self.num_envs, env_spacing, envs_per_row)

    def _create_ground_plane(self):
        """Add a ground plane with a +Z normal to the simulation."""
        ground_params = gymapi.PlaneParams()
        up_direction = gymapi.Vec3(0.0, 0.0, 1.0)
        ground_params.normal = up_direction
        self.gym.add_ground(self.sim, ground_params)

    def create_object_asset_dict(self, asset_root):
        """Load the assets of every training object into ``self.object_asset_dict``.

        For each name in ``self.used_training_objects`` the file listed in
        ``self.asset_files_dict`` is loaded three times: once as the free,
        gravity-affected object and twice (goal and predicted goal) as a
        fixed-base, gravity-free copy. ``self.object_asset`` ends up holding
        the free asset of the last object processed.

        Args:
            asset_root: Root directory handed to ``gym.load_asset``.

        Raises:
            KeyError: If an object name is missing from ``self.asset_files_dict``.
        """

        def make_dynamic_options():
            # Options for the object that is actually simulated.
            opts = gymapi.AssetOptions()
            opts.density = 1000
            opts.fix_base_link = False
            opts.flip_visual_attachments = False
            opts.collapse_fixed_joints = True
            opts.disable_gravity = False
            opts.thickness = 0.001
            opts.angular_damping = 0.01
            opts.mesh_normal_mode = gymapi.COMPUTE_PER_VERTEX
            opts.override_com = True
            opts.override_inertia = True
            opts.vhacd_enabled = True
            opts.vhacd_params = gymapi.VhacdParams()
            opts.vhacd_params.resolution = 100000
            return opts

        def make_static_options():
            # Options shared by the two goal copies: fixed in place, no gravity.
            opts = gymapi.AssetOptions()
            opts.density = 2000
            opts.disable_gravity = True
            opts.fix_base_link = True
            return opts

        self.object_asset_dict = {}
        print("ENTER ASSET CREATING!")

        for object_name in self.used_training_objects:
            asset_file = self.asset_files_dict[object_name]

            dynamic_options = make_dynamic_options()
            self.object_asset = self.gym.load_asset(
                self.sim, asset_root, asset_file, dynamic_options
            )

            static_options = make_static_options()
            goal_asset = self.gym.load_asset(
                self.sim, asset_root, asset_file, static_options
            )
            predict_goal_asset = self.gym.load_asset(
                self.sim, asset_root, asset_file, static_options
            )

            # In the "box_traj_laptop" setup the assets are filed under "box".
            dict_key = "box" if self.is_box_traj_laptop else object_name

            self.object_asset_dict[dict_key] = {
                'obj': self.object_asset,
                'goal': goal_asset,
                'predict goal': predict_goal_asset,
            }

    def _create_envs(self, num_envs, spacing, num_per_row):
        lower = gymapi.Vec3(-spacing, -spacing, 0.0)
        upper = gymapi.Vec3(spacing, spacing, spacing)

        asset_root = "../assets"

        if self.used_hand_type == "shadow":
            allegro_hand_asset_file = "urdf/shadow_hand_description/ur10e_shadowhand_right_digital_twin.urdf"
            allegro_hand_another_asset_file = "urdf/shadow_hand_description/ur10e_shadowhand_left_digital_twin.urdf"
            if self.use_fingertip_ik:
                allegro_hand_asset_file = "urdf/shadow_hand_description/ur10e_shadowhand_right_fingertip_ik.urdf"
                allegro_hand_another_asset_file = "urdf/shadow_hand_description/ur10e_shadowhand_left_fingertip_ik.urdf"
                
        elif self.used_hand_type == "allegro":
            allegro_hand_asset_file = "urdf/shadow_hand_description/ur10e_allegrohand_right_digital_twin.urdf"
            allegro_hand_another_asset_file = "urdf/shadow_hand_description/ur10e_allegrohand_left_digital_twin.urdf"
        elif self.used_hand_type == "schunk":
            allegro_hand_asset_file = "urdf/shadow_hand_description/ur10e_schunkhand_right_digital_twin.urdf"
            allegro_hand_another_asset_file = "urdf/shadow_hand_description/ur10e_schunkhand_left_digital_twin.urdf"
        elif self.used_hand_type == "ability":
            allegro_hand_asset_file = "urdf/shadow_hand_description/ur10e_abilityhand_right_digital_twin.urdf"
            allegro_hand_another_asset_file = "urdf/shadow_hand_description/ur10e_abilityhand_left_digital_twin.urdf"
        else:
            raise Exception(
        "Unrecognized hand type!\Hand type should be one of: [shadow, allegro, schunk, ability]"
    )
        # load shadow hand_ asset
        asset_options = gymapi.AssetOptions()
        asset_options.flip_visual_attachments = False
        asset_options.fix_base_link = True
        asset_options.collapse_fixed_joints = True
        asset_options.disable_gravity = True
        asset_options.thickness = 0.001
        asset_options.angular_damping = 0.01
        # asset_options.mesh_normal_mode = gymapi.COMPUTE_PER_VERTEX
        # asset_options.override_com = True
        # asset_options.override_inertia = True
        # asset_options.vhacd_enabled = True
        # asset_options.vhacd_params = gymapi.VhacdParams()
        # asset_options.vhacd_params.resolution = 100000
        # asset_options.default_dof_drive_mode = gymapi.DOF_MODE_EFFORT

        if self.physics_engine == gymapi.SIM_PHYSX:
            asset_options.use_physx_armature = True
        asset_options.default_dof_drive_mode = gymapi.DOF_MODE_POS
        allegro_hand_asset = self.gym.load_asset(
            self.sim, asset_root, allegro_hand_asset_file, asset_options
        )
        allegro_hand_another_asset = self.gym.load_asset(
            self.sim, asset_root, allegro_hand_another_asset_file, asset_options
        )

        self.num_allegro_hand_bodies = self.gym.get_asset_rigid_body_count(allegro_hand_asset)
        self.num_allegro_hand_shapes = self.gym.get_asset_rigid_shape_count(allegro_hand_asset)
        self.num_allegro_hand_dofs = self.gym.get_asset_dof_count(allegro_hand_asset)
        self.num_allegro_hand_actuators = self.gym.get_asset_dof_count(allegro_hand_asset)
        self.num_allegro_hand_tendons = self.gym.get_asset_tendon_count(allegro_hand_asset)

        print("self.num_allegro_hand_bodies: ", self.num_allegro_hand_bodies)
        print("self.num_allegro_hand_shapes: ", self.num_allegro_hand_shapes)
        print("self.num_allegro_hand_dofs: ", self.num_allegro_hand_dofs)
        print("self.num_allegro_hand_actuators: ", self.num_allegro_hand_actuators)
        print("self.num_allegro_hand_tendons: ", self.num_allegro_hand_tendons)

        self.actuated_dof_indices = [i for i in range(16)]

        # set allegro_hand dof properties
        allegro_hand_dof_props = self.gym.get_asset_dof_properties(allegro_hand_asset)
        allegro_hand_another_dof_props = self.gym.get_asset_dof_properties(
            allegro_hand_another_asset
        )

        self.allegro_hand_dof_lower_limits = []
        self.allegro_hand_dof_upper_limits = []
        self.a_allegro_hand_dof_lower_limits = []
        self.a_allegro_hand_dof_upper_limits = []
        self.allegro_hand_dof_default_pos = []
        self.allegro_hand_dof_default_vel = []
        self.allegro_hand_dof_stiffness = []
        self.allegro_hand_dof_damping = []
        self.allegro_hand_dof_effort = []
        self.sensors = []
        sensor_pose = gymapi.Transform()

        for i in range(self.num_allegro_hand_dofs):
            self.allegro_hand_dof_lower_limits.append(allegro_hand_dof_props['lower'][i])
            self.allegro_hand_dof_upper_limits.append(allegro_hand_dof_props['upper'][i])
            self.a_allegro_hand_dof_lower_limits.append(allegro_hand_another_dof_props['lower'][i])
            self.a_allegro_hand_dof_upper_limits.append(allegro_hand_another_dof_props['upper'][i])
            self.allegro_hand_dof_default_pos.append(0.0)
            self.allegro_hand_dof_default_vel.append(0.0)

            allegro_hand_dof_props['driveMode'][i] = gymapi.DOF_MODE_NONE
            allegro_hand_another_dof_props['driveMode'][i] = gymapi.DOF_MODE_NONE
            if i < 6:
                allegro_hand_dof_props['stiffness'][i] = 1000
                allegro_hand_dof_props['effort'][i] = 2000
                allegro_hand_dof_props['damping'][i] = 100
                allegro_hand_dof_props['velocity'][i] = 4
                allegro_hand_another_dof_props['stiffness'][i] = 1000
                allegro_hand_another_dof_props['effort'][i] = 2000
                allegro_hand_another_dof_props['damping'][i] = 100
                allegro_hand_another_dof_props['velocity'][i] = 4
                if self.use_fingertip_ik or self.use_joint_space_ik:
                    allegro_hand_dof_props['stiffness'][i] = 4000
                    allegro_hand_dof_props['effort'][i] = 2000
                    allegro_hand_dof_props['damping'][i] = 100
                    allegro_hand_dof_props['velocity'][i] = 4
                    allegro_hand_another_dof_props['stiffness'][i] = 4000
                    allegro_hand_another_dof_props['effort'][i] = 2000
                    allegro_hand_another_dof_props['damping'][i] = 100
                    allegro_hand_another_dof_props['velocity'][i] = 4

            else:
                allegro_hand_dof_props['velocity'][i] = 3.0
                allegro_hand_dof_props['stiffness'][i] = 30
                allegro_hand_dof_props['effort'][i] = 5
                allegro_hand_dof_props['damping'][i] = 1
                allegro_hand_another_dof_props['velocity'][i] = 3.0
                allegro_hand_another_dof_props['stiffness'][i] = 30
                allegro_hand_another_dof_props['effort'][i] = 5
                allegro_hand_another_dof_props['damping'][i] = 1

            if self.used_hand_type == "shadow":
                if 8 > i > 6:
                    allegro_hand_dof_props['velocity'][i] = 3.0
                    allegro_hand_dof_props['stiffness'][i] = 150
                    allegro_hand_dof_props['effort'][i] = 25
                    allegro_hand_dof_props['damping'][i] = 10
                    allegro_hand_another_dof_props['velocity'][i] = 3.0
                    allegro_hand_another_dof_props['stiffness'][i] = 150
                    allegro_hand_another_dof_props['effort'][i] = 25
                    allegro_hand_another_dof_props['damping'][i] = 10

        self.actuated_dof_indices = to_torch(
            self.actuated_dof_indices, dtype=torch.long, device=self.device
        )
        self.allegro_hand_dof_lower_limits = to_torch(
            self.allegro_hand_dof_lower_limits, device=self.device
        )
        self.allegro_hand_dof_upper_limits = to_torch(
            self.allegro_hand_dof_upper_limits, device=self.device
        )
        self.a_allegro_hand_dof_lower_limits = to_torch(
            self.a_allegro_hand_dof_lower_limits, device=self.device
        )
        self.a_allegro_hand_dof_upper_limits = to_torch(
            self.a_allegro_hand_dof_upper_limits, device=self.device
        )
        self.allegro_hand_dof_default_pos = to_torch(
            self.allegro_hand_dof_default_pos, device=self.device
        )
        self.allegro_hand_dof_default_vel = to_torch(
            self.allegro_hand_dof_default_vel, device=self.device
        )

        self.object_name = self.used_training_objects
                
        if self.traj_index == "all":
            self.functional = ["use", "grab"]
            used_seq_list = ["01", "02", "04", "07", "08", "09"]
            used_sub_seq_list = ["01", "02", "03", "04"]
        elif self.traj_index == "unseen":
            self.functional = ["use", "grab"]
            used_seq_list = ["10", "08"]
            used_sub_seq_list = ["01", "02", "03", "04"]      
            
            
        from high_level_planner.data_utils import DataLoader
        self.interpolate_time = 1
        if self.is_microwave_real:
            self.interpolate_time = 2
        if self.is_notebook_real:
            self.interpolate_time = 1
        if self.is_box_traj_laptop:
            self.object_name = ["box"]
            
        self.dl = DataLoader(self.gym, self.sim, used_seq_list, self.functional, used_sub_seq_list, self.object_name, self.use_fingertip_ik, self.use_joint_space_ik, self.device, interpolate_time=self.interpolate_time)
        self.seq_list, self.texture_list, self.obj_name_seq = self.dl.load_arctic_data(is_microwave_real=self.is_microwave_real, is_notebook_real=self.is_notebook_real)


        print("seq_num: ", len(self.seq_list))
        # exit()
        self.seq_list_i = [i for i in range(len(self.seq_list))]
        self.traj_len = 1000
        if self.is_microwave_real:
            self.traj_len = 200
        if self.is_notebook_real:
            self.traj_len = 200
            
        self.rot_r = torch.zeros((self.num_envs, self.traj_len, 3), device=self.device, dtype=torch.float)
        self.trans_r = torch.zeros((self.num_envs, self.traj_len, 3), device=self.device, dtype=torch.float)
        self.rot_l = torch.zeros((self.num_envs, self.traj_len, 3), device=self.device, dtype=torch.float)
        self.trans_l = torch.zeros((self.num_envs, self.traj_len, 3), device=self.device, dtype=torch.float)
        self.obj_params = torch.zeros(
            (self.num_envs, self.traj_len, 7), device=self.device, dtype=torch.float
        )
        self.obj_rot_quat = torch.zeros(
            (self.num_envs, self.traj_len, 4), device=self.device, dtype=torch.float
        )
        self.rot_r_quat = torch.zeros(
            (self.num_envs, self.traj_len, 4), device=self.device, dtype=torch.float
        )
        self.rot_l_quat = torch.zeros(
            (self.num_envs, self.traj_len, 4), device=self.device, dtype=torch.float
        )
        self.left_fingertip = torch.zeros(
            (self.num_envs, self.traj_len, 15), device=self.device, dtype=torch.float
        )
        self.right_fingertip = torch.zeros(
            (self.num_envs, self.traj_len, 15), device=self.device, dtype=torch.float
        )
        self.left_middle_finger = torch.zeros(
            (self.num_envs, self.traj_len, 15), device=self.device, dtype=torch.float
        )
        self.right_middle_finger = torch.zeros(
            (self.num_envs, self.traj_len, 15), device=self.device, dtype=torch.float
        )
        
        for i in range(self.num_envs):
            seq_idx = i % len(self.seq_list)
            self.seq_idx_tensor = to_torch([range(self.num_envs)], dtype=int, device=self.device)
            self.rot_r[i] = self.seq_list[seq_idx]["rot_r"][:self.traj_len].clone()
            self.trans_r[i] = self.seq_list[seq_idx]["trans_r"][:self.traj_len].clone()
            self.rot_l[i] = self.seq_list[seq_idx]["rot_l"][:self.traj_len].clone()
            self.trans_l[i] = self.seq_list[seq_idx]["trans_l"][:self.traj_len].clone()
            self.obj_params[i] = self.seq_list[seq_idx]["obj_params"][:self.traj_len].clone()
            self.obj_rot_quat[i] = self.seq_list[seq_idx]["obj_rot_quat"][:self.traj_len].clone()
            self.rot_r_quat[i] = self.seq_list[seq_idx]["rot_r_quat"][:self.traj_len].clone()
            self.rot_l_quat[i] = self.seq_list[seq_idx]["rot_l_quat"][:self.traj_len].clone()
            self.left_fingertip[i] = self.seq_list[seq_idx]["left_fingertip"][:self.traj_len].clone()
            self.right_fingertip[i] = self.seq_list[seq_idx]["right_fingertip"][:self.traj_len].clone()
            self.left_middle_finger[i] = self.seq_list[seq_idx]["left_middle_finger"][:self.traj_len].clone()
            self.right_middle_finger[i] = self.seq_list[seq_idx]["right_middle_finger"][:self.traj_len].clone()
            
        if self.used_training_objects[0] == "box":
            self.object_joint_tolerate = 0.1
            self.object_joint_reset = 0.5
        elif self.used_training_objects[0] == "espressomachine":
            self.object_joint_tolerate = 0.05
            self.object_joint_reset = 0.2
        elif self.used_training_objects[0] == "ketchup":
            self.object_joint_tolerate = 0.1
            self.object_joint_reset = 0.5
        elif self.used_training_objects[0] == "mixer":
            self.object_joint_tolerate = 0.1
            self.object_joint_reset = 0.5
        elif self.used_training_objects[0] == "phone":
            self.object_joint_tolerate = 0.1
            self.object_joint_reset = 0.5
        elif self.used_training_objects[0] == "scissors":
            self.object_joint_tolerate = 0.1
            self.object_joint_reset = 0.5
        elif self.used_training_objects[0] == "waffleiron":
            self.object_joint_tolerate = 0.1
            self.object_joint_reset = 0.5
        elif self.used_training_objects[0] == "notebook":
            self.object_joint_tolerate = 0.1
            self.object_joint_reset = 0.5
        else:
            self.object_joint_tolerate = 0.1
            self.object_joint_reset = 0.5
            
        # load manipulated object and goal assets
        object_asset_options = gymapi.AssetOptions()
        object_asset_options.density = 500

        self.object_radius = 0.06
        object_asset = self.gym.create_sphere(self.sim, 0.12, object_asset_options)

        object_asset_options.disable_gravity = True
        goal_asset = self.gym.create_sphere(self.sim, 0.04, object_asset_options)

        # allegro_hand_start_pose = gymapi.Transform()
        # allegro_hand_start_pose.p = gymapi.Vec3(-self.obj_params[1, 4] - 0.84, -self.obj_params[1, 5] - 0.64, -self.obj_params[1, 6] + 1.5)
        # allegro_hand_start_pose.r = gymapi.Quat().from_euler_zyx(-self.obj_params[1, 1] - 1.571, -self.obj_params[1, 2] - 1.571, -self.obj_params[1, 3])

        # allegro_another_hand_start_pose = gymapi.Transform()
        # allegro_another_hand_start_pose.p = gymapi.Vec3(-self.obj_params[1, 4] - 0.84, -self.obj_params[1, 5] - 0.64, -self.obj_params[1, 6] + 1.5)
        # allegro_another_hand_start_pose.r = gymapi.Quat().from_euler_zyx(-self.obj_params[1, 1] - 1.571, -self.obj_params[1, 2] - 1.571, -self.obj_params[1, 3])

        # object_start_pose = gymapi.Transform()
        # object_start_pose.p = gymapi.Vec3(-self.obj_params[1, 4] - 0.84, -self.obj_params[1, 5] - 0.64, -self.obj_params[1, 6] + 1.5)
        # object_start_pose.r = gymapi.Quat().from_euler_zyx(-self.obj_params[1, 1] - 1.571, -self.obj_params[1, 2] - 1.571, -self.obj_params[1, 3])
        # object_start_pose.r = gymapi.Quat().from_euler_zyx(-self.obj_params[1, 1] - 1.571, -self.obj_params[1, 2] - 1.571, -self.obj_params[1, 3])

        allegro_hand_start_pose = gymapi.Transform()
        allegro_hand_start_pose.p = gymapi.Vec3(-0.5, 0.95, 0.7)
        allegro_hand_start_pose.r = gymapi.Quat().from_euler_zyx(0, 0, 1.571)

        allegro_another_hand_start_pose = gymapi.Transform()
        allegro_another_hand_start_pose.p = gymapi.Vec3(0.5, 0.95, 0.7)
        allegro_another_hand_start_pose.r = gymapi.Quat().from_euler_zyx(0, 0, 1.571)

        object_start_pose = gymapi.Transform()
        object_start_pose.p = gymapi.Vec3(0, -0.25, 0)
        object_start_pose.r = gymapi.Quat().from_euler_zyx(0, 0, 0.0)

        self.goal_displacement = gymapi.Vec3(-0.0, 0.0, 0.0)
        self.goal_displacement_tensor = to_torch(
            [self.goal_displacement.x, self.goal_displacement.y, self.goal_displacement.z],
            device=self.device,
        )
        goal_start_pose = gymapi.Transform()
        goal_start_pose.p = object_start_pose.p + self.goal_displacement

        goal_start_pose.p.z -= 0.0

        # create table asset
        table_dims = gymapi.Vec3(1.5, 2.2, 0.76)
        table_asset_options = gymapi.AssetOptions()
        table_asset_options.fix_base_link = True
        table_asset_options.flip_visual_attachments = True
        table_asset_options.collapse_fixed_joints = True
        table_asset_options.disable_gravity = True
        table_asset_options.thickness = 0.001

        table_asset = self.gym.create_box(
            self.sim, table_dims.x, table_dims.y, table_dims.z, table_asset_options
        )
        table_pose = gymapi.Transform()
        table_pose.p = gymapi.Vec3(0.0, 0.0, 0.5 * table_dims.z)
        table_pose.r = gymapi.Quat().from_euler_zyx(-0, 0, 0.0)

        # create support box asset
        support_box_dims = gymapi.Vec3(0.15, 0.15, 0.20)
        support_box_asset_options = gymapi.AssetOptions()
        support_box_asset_options.fix_base_link = True
        support_box_asset_options.flip_visual_attachments = True
        support_box_asset_options.collapse_fixed_joints = True
        support_box_asset_options.disable_gravity = True
        support_box_asset_options.thickness = 0.001

        support_box_asset = self.gym.create_box(
            self.sim,
            support_box_dims.x,
            support_box_dims.y,
            support_box_dims.z,
            support_box_asset_options,
        )
        support_box_pose = gymapi.Transform()
        support_box_pose.p = gymapi.Vec3(0.0, -0.10, 0.5 * (2 * table_dims.z + support_box_dims.z))
        if self.is_box_traj_laptop:
            support_box_pose.p = gymapi.Vec3(0.05, -0.15, 0.5 * (2 * table_dims.z + support_box_dims.z) + 0.05)
        if self.used_training_objects[0] == "scissors":
            support_box_pose.p = gymapi.Vec3(0.0, 0.00, 0.5 * (2 * table_dims.z + support_box_dims.z))
        if self.used_training_objects[0] in ["espressomachine", "microwave"]:
            support_box_pose.p = gymapi.Vec3(0.1, -0.10, 0.5 * (2 * table_dims.z + support_box_dims.z))
        support_box_pose.r = gymapi.Quat().from_euler_zyx(-0, 0, 0.0)

        # compute aggregate size
        max_agg_bodies = self.num_allegro_hand_bodies * 2 + 2 + 50
        max_agg_shapes = self.num_allegro_hand_shapes * 2 + 2 + 50

        self.allegro_hands = []
        self.envs = []

        self.object_init_state = []
        self.hand_start_states = []
        self.another_hand_start_states = []

        self.hand_indices = []
        self.another_hand_indices = []
        self.fingertip_indices = []
        self.object_indices = []
        self.goal_object_indices = []
        self.predict_goal_object_indices = []

        self.table_indices = []
        self.support_box_indices = []

        if self.enable_camera_sensors:
            self.cameras = []
            self.camera_tensors = []
            self.camera_view_matrixs = []
            self.camera_proj_matrixs = []

            self.camera_props = gymapi.CameraProperties()
            self.camera_props.width = 1920
            self.camera_props.height = 1080
            self.camera_props.enable_tensors = True

            self.env_origin = torch.zeros((self.num_envs, 3), device=self.device, dtype=torch.float)
            self.pointCloudDownsampleNum = 384
            self.camera_u = torch.arange(0, self.camera_props.width, device=self.device)
            self.camera_v = torch.arange(0, self.camera_props.height, device=self.device)
            self.point_clouds = torch.zeros(
                (self.num_envs, self.pointCloudDownsampleNum, 3), device=self.device
            )

            self.camera_v2, self.camera_u2 = torch.meshgrid(
                self.camera_v, self.camera_u, indexing='ij'
            )

            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            
            self.video_out_list = []
            for i in range(len(self.seq_list)):
                save_video_dir = ''
                if not os.path.exists(save_video_dir):
                    os.makedirs(save_video_dir)
                    
                self.out = cv2.VideoWriter(save_video_dir + '{}_{}_{}.mp4'.format(self.used_hand_type, self.used_training_objects[0], i), fourcc, 30.0, (1920, 1080))
                self.video_out_list.append(self.out)
                
            if self.point_cloud_debug:
                import open3d as o3d
                from utils.o3dviewer import PointcloudVisualizer

                self.pointCloudVisualizer = PointcloudVisualizer()
                self.pointCloudVisualizerInitialized = False
                self.o3d_pc = o3d.geometry.PointCloud()
            else:
                self.pointCloudVisualizer = None

        import open3d as o3d

        self.origin_point_clouds = torch.zeros((self.num_envs, 10000, 3), device=self.device)
        self.pointCloudDownsampleNum = 384
        self.point_clouds = torch.zeros(
            (self.num_envs, self.pointCloudDownsampleNum, 3), device=self.device
        )

        for i in range(self.num_envs):
            # create env instance
            env_ptr = self.gym.create_env(self.sim, lower, upper, num_per_row)

            if self.aggregate_mode >= 1:
                self.gym.begin_aggregate(env_ptr, max_agg_bodies, max_agg_shapes, True)

            # add hand - collision filter = -1 to use asset collision filters set in mjcf loader
            allegro_hand_actor = self.gym.create_actor(
                env_ptr, allegro_hand_asset, allegro_hand_start_pose, "hand", i, 0, 0
            )
            allegro_hand_another_actor = self.gym.create_actor(
                env_ptr,
                allegro_hand_another_asset,
                allegro_another_hand_start_pose,
                "another_hand",
                i,
                0,
                0,
            )

            self.hand_start_states.append(
                [
                    allegro_hand_start_pose.p.x,
                    allegro_hand_start_pose.p.y,
                    allegro_hand_start_pose.p.z,
                    allegro_hand_start_pose.r.x,
                    allegro_hand_start_pose.r.y,
                    allegro_hand_start_pose.r.z,
                    allegro_hand_start_pose.r.w,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                ]
            )

            self.another_hand_start_states.append(
                [
                    allegro_another_hand_start_pose.p.x,
                    allegro_another_hand_start_pose.p.y,
                    allegro_another_hand_start_pose.p.z,
                    allegro_another_hand_start_pose.r.x,
                    allegro_another_hand_start_pose.r.y,
                    allegro_another_hand_start_pose.r.z,
                    allegro_another_hand_start_pose.r.w,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                ]
            )

            self.gym.set_actor_dof_properties(env_ptr, allegro_hand_actor, allegro_hand_dof_props)
            hand_idx = self.gym.get_actor_index(env_ptr, allegro_hand_actor, gymapi.DOMAIN_SIM)
            self.hand_indices.append(hand_idx)

            self.gym.set_actor_dof_properties(
                env_ptr, allegro_hand_another_actor, allegro_hand_another_dof_props
            )
            another_hand_idx = self.gym.get_actor_index(
                env_ptr, allegro_hand_another_actor, gymapi.DOMAIN_SIM
            )
            self.another_hand_indices.append(another_hand_idx)

            self.gym.enable_actor_dof_force_sensors(env_ptr, allegro_hand_actor)
            self.gym.enable_actor_dof_force_sensors(env_ptr, allegro_hand_another_actor)
            
            # randomize colors and textures for rigid body
            num_bodies = self.gym.get_actor_rigid_body_count(env_ptr, allegro_hand_actor)
            hand_rigid_body_index = [
                [0, 1, 2, 3],
                [4, 5, 6, 7],
                [8, 9, 10, 11],
                [12, 13, 14, 15],
                [16, 17, 18, 19, 20],
                [21, 22, 23, 24, 25],
            ]

            # add object
            index = i % len(self.obj_name_seq)
            select_obj = self.obj_name_seq[index]

            object_handle = self.gym.create_actor(
                env_ptr,
                self.object_asset_dict[select_obj]['obj'],
                object_start_pose,
                "object",
                i,
                0,
                0,
            )

            # object_handle = self.gym.create_actor(env_ptr, object_asset, object_start_pose, "object", i, 0, 0)
            self.object_init_state.append(
                [
                    object_start_pose.p.x,
                    object_start_pose.p.y,
                    object_start_pose.p.z,
                    object_start_pose.r.x,
                    object_start_pose.r.y,
                    object_start_pose.r.z,
                    object_start_pose.r.w,
                    0,
                    0,
                    0,
                    0,
                    0,
                    0,
                ]
            )
            object_idx = self.gym.get_actor_index(env_ptr, object_handle, gymapi.DOMAIN_SIM)
            self.gym.set_rigid_body_texture(
                env_ptr, object_handle, 0, gymapi.MESH_VISUAL, self.texture_list[index]
            )
            self.gym.set_rigid_body_texture(
                env_ptr, object_handle, 1, gymapi.MESH_VISUAL, self.texture_list[index]
            )
            self.object_indices.append(object_idx)

            lego_body_props = self.gym.get_actor_rigid_body_properties(env_ptr, object_handle)
            for lego_body_prop in lego_body_props:
                lego_body_prop.mass *= 1.0
                if self.used_training_objects[0] == "notebook":
                    lego_body_prop.mass *= 1.5
            self.gym.set_actor_rigid_body_properties(env_ptr, object_handle, lego_body_props)

            object_dof_props = self.gym.get_actor_dof_properties(env_ptr, object_handle)
            for object_dof_prop in object_dof_props:
                object_dof_prop[4] = 100
                object_dof_prop[5] = 50
                object_dof_prop[6] = 5
                object_dof_prop[7] = 1
            self.gym.set_actor_dof_properties(env_ptr, object_handle, object_dof_props)

            object_shape_props = self.gym.get_actor_rigid_shape_properties(env_ptr, object_handle)
            for object_shape_prop in object_shape_props:
                object_shape_prop.restitution = 0.0
                # object_shape_prop.friction = 0
            self.gym.set_actor_rigid_shape_properties(env_ptr, object_handle, object_shape_props)
            
            
            # hand_shape_props = self.gym.get_actor_rigid_shape_properties(
            #     env_ptr, allegro_hand_actor
            # )
            # for hand_shape_prop in hand_shape_props:
            #     hand_shape_prop.restitution = 0.0
            #     # hand_shape_prop.friction = 0
            # self.gym.set_actor_rigid_shape_properties(env_ptr, allegro_hand_actor, hand_shape_props)

            # another_hand_shape_props = self.gym.get_actor_rigid_shape_properties(
            #     env_ptr, allegro_hand_another_actor
            # )
            # for another_hand_shape_prop in another_hand_shape_props:
            #     another_hand_shape_prop.restitution = 0.0
            #     # another_hand_shape_prop.friction = 0
            # self.gym.set_actor_rigid_shape_properties(env_ptr, allegro_hand_another_actor, another_hand_shape_props)


            # generate offline point cloud
            # pcd = o3d.io.read_triangle_mesh(self.asset_point_cloud_files_dict[select_obj])
            # self.origin_point_clouds[i] = torch.tensor([pcd.vertices], dtype=torch.float, device=self.device)

            # add goal object
            goal_handle = self.gym.create_actor(
                env_ptr,
                self.object_asset_dict[select_obj]['goal'],
                goal_start_pose,
                "goal_object",
                i + self.num_envs,
                0,
                0,
            )
            goal_object_idx = self.gym.get_actor_index(env_ptr, goal_handle, gymapi.DOMAIN_SIM)
            self.goal_object_indices.append(goal_object_idx)

            # add goal object
            # predict_goal_handle = self.gym.create_actor(env_ptr, self.object_asset_dict[select_obj]['predict goal'], goal_start_pose, "predict_goal_object", i + self.num_envs * 2, 0, 0)
            # predict_goal_object_idx = self.gym.get_actor_index(env_ptr, predict_goal_handle, gymapi.DOMAIN_SIM)
            # self.predict_goal_object_indices.append(predict_goal_object_idx)
            # self.gym.set_rigid_body_color(env_ptr, predict_goal_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.8, 0.4, 0.))

            # add table
            table_handle = self.gym.create_actor(env_ptr, table_asset, table_pose, "table", i, 0, 0)
            table_idx = self.gym.get_actor_index(env_ptr, table_handle, gymapi.DOMAIN_SIM)
            self.gym.set_rigid_body_color(
                env_ptr, table_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(1, 0.9, 0.8)
            )
            self.table_indices.append(table_idx)

            # add support box
            support_box_handle = self.gym.create_actor(
                env_ptr, support_box_asset, support_box_pose, "support_box", i, 0, 0
            )
            support_box_idx = self.gym.get_actor_index(
                env_ptr, support_box_handle, gymapi.DOMAIN_SIM
            )
            self.gym.set_rigid_body_color(
                env_ptr, support_box_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(1, 0.9, 0.8)
            )
            self.support_box_indices.append(support_box_idx)

            if self.enable_camera_sensors:
                camera_handle = self.gym.create_camera_sensor(env_ptr, self.camera_props)
                origin = self.gym.get_env_origin(env_ptr)
                self.gym.set_camera_location(
                    camera_handle, env_ptr, gymapi.Vec3(0.0, 0.5, 1.6), gymapi.Vec3(0, -0.5, 0.5)
                )
                
                camera_tensor = self.gym.get_camera_image_gpu_tensor(
                    self.sim, env_ptr, camera_handle, gymapi.IMAGE_COLOR
                )
                torch_cam_tensor = gymtorch.wrap_tensor(camera_tensor)
                cam_vinv = torch.inverse(
                    (
                        torch.tensor(
                            self.gym.get_camera_view_matrix(self.sim, env_ptr, camera_handle)
                        )
                    )
                ).to(self.device)
                cam_proj = torch.tensor(
                    self.gym.get_camera_proj_matrix(self.sim, env_ptr, camera_handle),
                    device=self.device,
                )

            if self.object_type != "block":
                self.gym.set_rigid_body_color(
                    env_ptr, object_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.6, 0.72, 0.98)
                )
                self.gym.set_rigid_body_color(
                    env_ptr, goal_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.6, 0.72, 0.98)
                )

            if self.aggregate_mode > 0:
                self.gym.end_aggregate(env_ptr)

            self.envs.append(env_ptr)
            self.allegro_hands.append(allegro_hand_actor)

            if self.enable_camera_sensors:
                origin = self.gym.get_env_origin(env_ptr)
                self.env_origin[i][0] = origin.x
                self.env_origin[i][1] = origin.y
                self.env_origin[i][2] = origin.z
                self.camera_tensors.append(torch_cam_tensor)
                self.camera_view_matrixs.append(cam_vinv)
                self.camera_proj_matrixs.append(cam_proj)
                self.cameras.append(camera_handle)

        another_sensor_handles = [
            self.gym.find_actor_rigid_body_handle(env_ptr, allegro_hand_another_actor, sensor_name)
            for sensor_name in self.contact_sensor_names
        ]

        sensor_handles = [
            self.gym.find_actor_rigid_body_handle(env_ptr, allegro_hand_actor, sensor_name)
            for sensor_name in self.contact_sensor_names
        ]

        object_sensor_handles = [
            self.gym.find_actor_rigid_body_handle(env_ptr, object_handle, sensor_name)
            for sensor_name in ["bottom", "top"]
        ]

        self.sensor_handle_indices = to_torch(sensor_handles, dtype=torch.int64, device=self.device)
        self.another_sensor_handle_indices = to_torch(another_sensor_handles, dtype=torch.int64, device=self.device)
        self.object_sensor_handles_indices = to_torch(object_sensor_handles, dtype=torch.int64, device=self.device)

        object_rb_props = self.gym.get_actor_rigid_body_properties(env_ptr, object_handle)
        self.object_rb_masses = [prop.mass for prop in object_rb_props]

        self.object_init_state = to_torch(
            self.object_init_state, device=self.device, dtype=torch.float
        ).view(self.num_envs, 13)
        self.goal_states = self.object_init_state.clone()
        self.goal_pose = self.goal_states[:, 0:7]
        self.goal_pos = self.goal_states[:, 0:3]
        self.goal_rot = self.goal_states[:, 3:7]
        # self.goal_states[:, self.up_axis_idx] -= 0.04
        self.goal_init_state = self.goal_states.clone()
        self.hand_start_states = to_torch(self.hand_start_states, device=self.device).view(
            self.num_envs, 13
        )
        self.another_hand_start_states = to_torch(
            self.another_hand_start_states, device=self.device
        ).view(self.num_envs, 13)

        self.hand_indices = to_torch(self.hand_indices, dtype=torch.long, device=self.device)
        self.another_hand_indices = to_torch(
            self.another_hand_indices, dtype=torch.long, device=self.device
        )

        self.object_indices = to_torch(self.object_indices, dtype=torch.long, device=self.device)
        self.goal_object_indices = to_torch(
            self.goal_object_indices, dtype=torch.long, device=self.device
        )
        self.predict_goal_object_indices = to_torch(
            self.predict_goal_object_indices, dtype=torch.long, device=self.device
        )

        self.table_indices = to_torch(self.table_indices, dtype=torch.long, device=self.device)
        self.support_box_indices = to_torch(
            self.support_box_indices, dtype=torch.long, device=self.device
        )

        self.init_object_tracking = True
        self.test_for_robot_controller = False

        self.p_gain_val = 100.0
        self.d_gain_val = 4.0
        self.p_gain = (
            torch.ones(
                (self.num_envs, self.num_allegro_hand_dofs * 2),
                device=self.device,
                dtype=torch.float,
            )
            * self.p_gain_val
        )
        self.d_gain = (
            torch.ones(
                (self.num_envs, self.num_allegro_hand_dofs * 2),
                device=self.device,
                dtype=torch.float,
            )
            * self.d_gain_val
        )

        self.pd_previous_dof_pos = (
            torch.zeros(
                (self.num_envs, self.num_allegro_hand_dofs * 2),
                device=self.device,
                dtype=torch.float,
            )
            * self.p_gain_val
        )
        self.pd_dof_pos = (
            torch.zeros(
                (self.num_envs, self.num_allegro_hand_dofs * 2),
                device=self.device,
                dtype=torch.float,
            )
            * self.p_gain_val
        )

        self.debug_target = []
        self.debug_qpos = []

        self.traj_estimator = TrajEstimator(input_dim=60, output_dim=3).to(self.device)
        for param in self.traj_estimator.parameters():
            param.requires_grad_(True)
        self.is_test = True

        self.traj_estimator_optimizer = torch.optim.Adam(
            self.traj_estimator.parameters(), lr=0.0003
        )
        self.traj_estimator_save_path = "./demonstration/traj_e/"
        os.makedirs(self.traj_estimator_save_path, exist_ok=True)
        self.bce_logits_loss = torch.nn.BCEWithLogitsLoss()

        if self.is_test:
            self.traj_estimator.eval()
        else:
            self.traj_estimator.train()

        self.total_steps = 0
        self.success_buf = torch.zeros_like(self.rew_buf)
        self.hit_success_buf = torch.zeros_like(self.rew_buf)

    def get_internal_state(self):
        """Return columns 3:7 of the root-state rows selected by ``object_indices``.

        This is the same slice that ``compute_observations`` stores as
        ``object_base_rot``.
        """
        rotation_columns = slice(3, 7)
        return self.root_state_tensor[self.object_indices, rotation_columns]

    def get_internal_info(self, key):
        """Look up one of the debug/inspection attributes by name.

        Args:
            key: ``'target'`` selects ``debug_target``, ``'qpos'`` selects
                ``debug_qpos`` and ``'contact'`` selects ``finger_contacts``.

        Returns:
            The selected attribute, or ``None`` for any other key.
        """
        attribute_by_key = (
            ('target', 'debug_target'),
            ('qpos', 'debug_qpos'),
            ('contact', 'finger_contacts'),
        )
        for known_key, attribute_name in attribute_by_key:
            if key == known_key:
                # Resolve lazily so that only the requested attribute has to exist.
                return getattr(self, attribute_name)
        return None

    def compute_reward(self, actions):
        """Evaluate ``compute_hand_reward`` and write its results back in place.

        The six values returned by ``compute_hand_reward`` are copied into the
        existing ``rew_buf``, ``reset_buf``, ``reset_goal_buf``,
        ``progress_buf``, ``successes`` and ``consecutive_successes`` buffers.
        Afterwards ``last_actions`` is refreshed from ``self.actions``, the
        success counters are published through ``self.extras`` and
        ``total_steps`` is incremented. When ``print_success_stat`` is set,
        running success averages are accumulated and printed.

        Args:
            actions: Not read by this method; ``self.actions`` is what is
                handed to ``compute_hand_reward``.
        """
        (
            new_rewards,
            new_resets,
            new_goal_resets,
            new_progress,
            new_successes,
            new_consecutive_successes,
        ) = compute_hand_reward(
            self.rew_buf,
            self.reset_buf,
            self.reset_goal_buf,
            self.progress_buf,
            self.successes,
            self.consecutive_successes,
            self.object_contacts,
            self.left_contacts,
            self.right_contacts,
            self.allegro_left_hand_pos,
            self.allegro_right_hand_pos,
            self.allegro_left_hand_rot,
            self.allegro_right_hand_rot,
            self.max_episode_length,
            self.object_base_pos,
            self.object_base_rot,
            self.goal_pos,
            self.goal_rot,
            self.allegro_left_hand_dof,
            self.allegro_right_hand_dof,
            self.object_dof,
            self.trans_r,
            self.trans_l,
            self.rot_r_quat,
            self.rot_l_quat,
            self.obj_params,
            self.obj_rot_quat,
            self.dist_reward_scale,
            self.rot_reward_scale,
            self.rot_eps,
            self.actions,
            self.action_penalty_scale,
            self.a_hand_palm_pos,
            self.last_actions,
            self.success_tolerance,
            self.reach_goal_bonus,
            self.fall_dist,
            self.fall_penalty,
            self.right_hand_energy_penalty,
            self.left_hand_energy_penalty,
            self.end_step_buf,
            self.seq_idx_tensor,
            self.max_consecutive_successes,
            self.av_factor,
            (self.object_type == "pen"),
            self.object_joint_tolerate,
            self.object_joint_reset,
            self.use_fingertip_reward,
            self.use_hierarchy,
            self.left_fingertip_global,
            self.right_fingertip_global,
            self.left_hand_fingertip_pos_list,
            self.right_hand_fingertip_pos_list,
        )

        # Slice-assign so the existing buffer objects are updated rather than rebound.
        self.rew_buf[:] = new_rewards
        self.reset_buf[:] = new_resets
        self.reset_goal_buf[:] = new_goal_resets
        self.progress_buf[:] = new_progress
        self.successes[:] = new_successes
        self.consecutive_successes[:] = new_consecutive_successes

        self.last_actions = self.actions.clone()

        self.extras['successes'] = self.successes
        self.extras['consecutive_successes'] = self.consecutive_successes

        self.total_steps += 1

        if not self.print_success_stat:
            return

        resets_this_step = self.reset_buf.sum()
        successes_this_step = self.successes.sum()
        successes_at_reset = (self.successes * self.reset_buf).sum()

        self.total_resets = self.total_resets + resets_this_step
        # Built from the running total as it was before this step's update below.
        direct_average_successes = self.total_successes + successes_this_step
        self.total_successes = self.total_successes + successes_at_reset

        # The direct average shows the overall result more quickly, but slightly
        # undershoots long term policy performance.
        direct_message = "Direct average consecutive successes = {:.1f}".format(
            direct_average_successes / (self.total_resets + self.num_envs)
        )
        print(direct_message)
        if self.total_resets > 0:
            post_reset_message = "Post-Reset average consecutive successes = {:.1f}".format(
                self.total_successes / self.total_resets
            )
            print(post_reset_message)

    def compute_observations(self):
        """Refresh the simulator tensors and rebuild all per-step derived state.

        In order, this method:

        1. Refreshes the DOF-state, actor-root-state, rigid-body-state,
           net-contact-force, Jacobian and DOF-force tensors.
        2. Caches hand base, hand body, object, object-part and goal
           poses/velocities as attributes on ``self``.
        3. Applies the hand-type specific position offset / ``r2`` rotation
           and reads the fingertip bodies (see
           ``_update_hand_type_frames``).
        4. Computes per-hand energy penalties and binary contact flags.
        5. Draws ``rand_floats`` and passes them to
           ``compute_sim2real_asymmetric_obs`` and
           ``compute_sim2real_observation``; before that,
           ``step_collect_success_traj`` is called when
           ``use_p_c_impro_loop`` is set.

        Returns nothing; every result is stored on ``self``.
        """
        self.gym.refresh_dof_state_tensor(self.sim)
        self.gym.refresh_actor_root_state_tensor(self.sim)
        self.gym.refresh_rigid_body_state_tensor(self.sim)
        self.gym.refresh_net_contact_force_tensor(self.sim)
        self.gym.refresh_jacobian_tensors(self.sim)
        self.gym.refresh_dof_force_tensor(self.sim)

        # Per-environment unit axes, built once per call and only ever scaled
        # (never modified in place) by the offset computations below.
        x_axis = to_torch([1, 0, 0], device=self.device).repeat(self.num_envs, 1)
        y_axis = to_torch([0, 1, 0], device=self.device).repeat(self.num_envs, 1)
        z_axis = to_torch([0, 0, 1], device=self.device).repeat(self.num_envs, 1)

        # Actor root states of both hands and of the object.
        self.allegro_right_hand_base_pos = self.root_state_tensor[self.hand_indices, 0:3]
        self.allegro_right_hand_base_rot = self.root_state_tensor[self.hand_indices, 3:7]

        self.allegro_left_hand_base_pos = self.root_state_tensor[self.another_hand_indices, 0:3]
        self.allegro_left_hand_base_rot = self.root_state_tensor[self.another_hand_indices, 3:7]

        self.object_base_pos = self.root_state_tensor[self.object_indices, 0:3]
        self.object_base_rot = self.root_state_tensor[self.object_indices, 3:7]

        # Hand pose/velocity come from rigid body 6 of each hand, not from the
        # actor root; the left hand's bodies start num_allegro_hand_bodies later.
        right_hand_body = 6
        left_hand_body = 6 + self.num_allegro_hand_bodies

        self.allegro_right_hand_pos = self.rigid_body_states[:, right_hand_body, 0:3]
        self.allegro_right_hand_rot = self.rigid_body_states[:, right_hand_body, 3:7]
        self.allegro_right_hand_linvel = self.rigid_body_states[:, right_hand_body, 7:10]
        self.allegro_right_hand_angvel = self.rigid_body_states[:, right_hand_body, 10:13]

        self.allegro_left_hand_pos = self.rigid_body_states[:, left_hand_body, 0:3]
        self.allegro_left_hand_rot = self.rigid_body_states[:, left_hand_body, 3:7]
        self.allegro_left_hand_linvel = self.rigid_body_states[:, left_hand_body, 7:10]
        self.allegro_left_hand_angvel = self.rigid_body_states[:, left_hand_body, 10:13]

        # Snapshot taken before the hand-type specific offset is applied below.
        self.a_hand_palm_pos = self.allegro_left_hand_pos.clone()

        self.object_pose = self.root_state_tensor[self.object_indices, 0:7]
        self.object_pos = self.root_state_tensor[self.object_indices, 0:3]
        self.object_rot = self.root_state_tensor[self.object_indices, 3:7]
        self.object_linvel = self.root_state_tensor[self.object_indices, 7:10]
        self.object_angvel = self.root_state_tensor[self.object_indices, 10:13]

        # The two object parts are the rigid bodies that directly follow the
        # bodies of both hands. They differ only in the sign of the x shift.
        object_bottom_body = self.num_allegro_hand_bodies * 2
        (
            self.object_bottom_pose,
            self.object_bottom_pos,
            self.object_bottom_rot,
        ) = self._object_part_frame(
            object_bottom_body, ((z_axis, 0.08), (y_axis, 0.13), (x_axis, -0.03))
        )
        (
            self.object_top_pose,
            self.object_top_pos,
            self.object_top_rot,
        ) = self._object_part_frame(
            object_bottom_body + 1, ((z_axis, 0.08), (y_axis, 0.13), (x_axis, 0.03))
        )

        self.allegro_right_hand_dof = self.allegro_hand_dof_pos.clone()
        self.allegro_left_hand_dof = self.allegro_hand_another_dof_pos.clone()
        self.object_dof = self.object_dof_pos.clone()

        self.goal_pose = self.goal_states[:, 0:7]
        self.goal_pos = self.goal_states[:, 0:3]
        self.goal_rot = self.goal_states[:, 3:7]

        self._update_hand_type_frames(y_axis, z_axis)

        # NOTE(review): the "allegro" hand type never assigns the *_hand_mf_*
        # attributes, yet they are read unconditionally here. They must be
        # initialised elsewhere or this raises AttributeError for that hand
        # type - confirm.
        self.right_hand_fingertip_pos_list[0] = self.right_hand_th_pos
        self.right_hand_fingertip_pos_list[1] = self.right_hand_ff_pos
        self.right_hand_fingertip_pos_list[2] = self.right_hand_mf_pos
        self.right_hand_fingertip_pos_list[3] = self.right_hand_rf_pos
        self.right_hand_fingertip_pos_list[4] = self.right_hand_lf_pos

        self.left_hand_fingertip_pos_list[0] = self.left_hand_th_pos
        self.left_hand_fingertip_pos_list[1] = self.left_hand_ff_pos
        self.left_hand_fingertip_pos_list[2] = self.left_hand_mf_pos
        self.left_hand_fingertip_pos_list[3] = self.left_hand_rf_pos
        self.left_hand_fingertip_pos_list[4] = self.left_hand_lf_pos

        # Per-hand velocities and torques, skipping the first 6 DOFs of each
        # hand (presumably the arm joints - confirm). Despite the attribute
        # name, the velocities are read directly from the DOF velocity tensors.
        num_hand_dofs = self.num_allegro_hand_dofs
        self.right_hand_dof_vel_finite_diff = self.allegro_hand_dof_vel[:, 6:num_hand_dofs].clone()
        self.left_hand_dof_vel_finite_diff = self.allegro_hand_another_dof_vel[:, 6:num_hand_dofs].clone()
        self.right_hand_dof_torque = self.dof_force_tensor[:, 6:num_hand_dofs].clone()
        self.left_hand_dof_torque = self.dof_force_tensor[:, num_hand_dofs + 6:num_hand_dofs * 2].clone()

        # Energy penalty: squared sum over DOFs of torque * velocity.
        self.right_hand_energy_penalty = ((self.right_hand_dof_torque * self.right_hand_dof_vel_finite_diff).sum(-1)) ** 2
        self.left_hand_energy_penalty = ((self.left_hand_dof_torque * self.left_hand_dof_vel_finite_diff).sum(-1)) ** 2

        # One 3-vector of net contact force per rigid body and environment,
        # reduced to 0/1 flags for the object, right-hand and left-hand sensor
        # bodies and for every body.
        contacts = self.contact_tensor.reshape(self.num_envs, -1, 3)
        self.object_contacts = self._contact_flags(contacts[:, self.object_sensor_handles_indices, :])
        self.right_contacts = self._contact_flags(contacts[:, self.sensor_handle_indices, :])
        self.left_contacts = self._contact_flags(contacts[:, self.another_sensor_handle_indices, :])
        self.all_contact = self._contact_flags(contacts)

        # Random values drawn with torch_rand_float(-1.0, 1.0, ...) and shared
        # by both observation builders below.
        rand_floats = torch_rand_float(-1.0, 1.0, (self.num_envs, 63), device=self.device)

        if self.use_p_c_impro_loop:
            self.step_collect_success_traj()

        self.compute_sim2real_asymmetric_obs(rand_floats)
        self.compute_sim2real_observation(rand_floats)

    def _object_part_frame(self, body_index, axis_shifts):
        """Return ``(pose, pos, rot)`` for one object rigid body.

        Args:
            body_index: Index of the body along dim 1 of ``rigid_body_states``.
            axis_shifts: Sequence of ``(unit_axis, distance)`` pairs. The shifts
                are applied one after another, each along the given axis
                rotated by the returned ``rot``.

        Returns:
            ``pose`` is the unmodified ``0:7`` slice of the body state, ``rot``
            is the body rotation multiplied by
            ``sim_to_real_object_quaternion``, and ``pos`` is the body position
            after all shifts.
        """
        pose = self.rigid_body_states[:, body_index, 0:7]
        pos = self.rigid_body_states[:, body_index, 0:3]
        rot = quaternion_multiply(
            self.rigid_body_states[:, body_index, 3:7], self.sim_to_real_object_quaternion
        )
        for unit_axis, distance in axis_shifts:
            pos = pos + quat_apply(rot, unit_axis * distance)
        return pose, pos, rot

    def _update_hand_type_frames(self, y_axis, z_axis):
        """Apply the ``used_hand_type`` specific hand offset and read fingertips.

        For both hands this:

        * shifts ``allegro_<side>_hand_pos`` along the hand's local y axis by
          a per-type distance, using the rotation as it is before the next
          step;
        * replaces ``allegro_<side>_hand_rot`` with the quaternion of
          ``rotation_matrix(rot) @ self.r2``;
        * stores ``<side>_hand_<finger>_pos``, ``<side>_hand_<finger>_rot`` and
          ``<side>_hand_<finger>_state`` for every finger listed for the hand
          type (``ff``, ``lf``, ``mf``, ``rf``, ``th``), read from
          ``rigid_body_states``. For hand types with a tip offset the position
          is additionally shifted along the fingertip's local z axis.

        Hand types that are not in the table leave every attribute untouched.

        Args:
            y_axis: ``(num_envs, 3)`` tensor holding the unit y axis per env.
            z_axis: ``(num_envs, 3)`` tensor holding the unit z axis per env.
        """
        # hand type -> (hand offset along local y,
        #               fingertip offset along local z or None,
        #               ((finger, right-hand rigid body index), ...))
        layouts = {
            "shadow": (0.36, 0.02, (("ff", 12), ("lf", 17), ("mf", 21), ("rf", 25), ("th", 30))),
            # NOTE(review): no "mf" entry, so *_hand_mf_* is not refreshed here.
            "allegro": (0.2, 0.02, (("ff", 10), ("lf", 14), ("rf", 18), ("th", 22))),
            "schunk": (0.18, None, (("ff", 9), ("lf", 13), ("mf", 17), ("rf", 22), ("th", 26))),
            "ability": (0.13, None, (("ff", 8), ("lf", 10), ("mf", 12), ("rf", 14), ("th", 16))),
        }
        layout = layouts.get(self.used_hand_type)
        if layout is None:
            return
        hand_offset, tip_offset, tip_bodies = layout

        # Left-hand bodies follow the right-hand ones in rigid_body_states.
        sides = (("right", 0), ("left", self.num_allegro_hand_bodies))
        for side, body_shift in sides:
            pos_attr = "allegro_{}_hand_pos".format(side)
            rot_attr = "allegro_{}_hand_rot".format(side)
            hand_pos = getattr(self, pos_attr)
            hand_rot = getattr(self, rot_attr)

            # The position shift uses the rotation from before the r2 update.
            setattr(self, pos_attr, hand_pos + quat_apply(hand_rot, y_axis * hand_offset))
            rotated = quaternion_to_rotation_matrix(hand_rot) @ self.r2
            setattr(self, rot_attr, rotation_matrix_to_quaternion(rotated))

            for finger, right_body in tip_bodies:
                body_index = right_body + body_shift
                tip_pos = self.rigid_body_states[:, body_index, 0:3]
                tip_rot = self.rigid_body_states[:, body_index, 3:7]
                if tip_offset is not None:
                    tip_pos = tip_pos + quat_apply(tip_rot, z_axis * tip_offset)

                prefix = "{}_hand_{}".format(side, finger)
                setattr(self, prefix + "_pos", tip_pos)
                setattr(self, prefix + "_rot", tip_rot)
                setattr(self, prefix + "_state", self.rigid_body_states[:, body_index, 0:13])

    @staticmethod
    def _contact_flags(forces):
        """Return 1.0 where the force norm over the last dim is >= 0.1, else 0.0."""
        return torch.where(torch.norm(forces, dim=-1) >= 0.1, 1.0, 0.0)

    def compute_sim2real_observation(self, rand_floats):
        """Write the policy observation for the current step into ``self.obs_buf``.

        Column layout produced here:
          * ``6 : num_dofs`` and ``num_dofs + 6 : 2 * num_dofs`` -- unscaled
            DOF positions of the hand and of the "another" hand; the first
            six DOFs of each are left untouched.
          * ``96:110`` -- right hand position (3) and rotation (4), then the
            same for the left hand.
          * ``110:125`` -- object bottom pose (7), object DOF position (1),
            object top pose (7).
          * ``130:131`` -- object DOF minus the reference value
            ``obj_params[..., 0:1]`` at the current step.
          * ``144 + 22 * i`` for ``i`` in ``range(self.stack_frame)`` -- the
            reference frame at ``progress_buf + i``: object position, object
            quaternion, object DOF, left translation, left quaternion, right
            translation, right quaternion (22 columns per frame).

        If ``self.train_teacher_policy`` is truthy, ``self.obs_buf`` is replaced
        by a clone of ``self.states_buf`` at the end.

        Args:
            rand_floats: Not used by this method.
        """
        num_dofs = self.num_allegro_hand_dofs
        lower_limits = self.allegro_hand_dof_lower_limits[6:num_dofs]
        upper_limits = self.allegro_hand_dof_upper_limits[6:num_dofs]

        # Joint positions of both hands, normalised by the joint limits.
        self.obs_buf[:, 6:num_dofs] = unscale(
            self.allegro_hand_dof_pos[:, 6:num_dofs], lower_limits, upper_limits
        )
        self.obs_buf[:, num_dofs + 6 : 2 * num_dofs] = unscale(
            self.allegro_hand_another_dof_pos[:, 6:num_dofs], lower_limits, upper_limits
        )

        # Hand poses.
        self.obs_buf[:, 96:99] = self.allegro_right_hand_pos
        self.obs_buf[:, 99:103] = self.allegro_right_hand_rot
        self.obs_buf[:, 103:106] = self.allegro_left_hand_pos
        self.obs_buf[:, 106:110] = self.allegro_left_hand_rot

        # Articulated-object parts.
        self.obs_buf[:, 110:117] = self.object_bottom_pose
        self.obs_buf[:, 117:118] = self.object_dof_pos
        self.obs_buf[:, 118:125] = self.object_top_pose

        seq = self.seq_idx_tensor

        # Articulation error with respect to the reference trajectory.
        reference_dof = self.obj_params[seq, self.progress_buf, 0:1].squeeze(0)
        self.obs_buf[:, 130:131] = self.object_dof - reference_dof

        # Upcoming reference frames, 22 columns each.
        self.stack_frame = 3
        for frame in range(self.stack_frame):
            base = 144 + 22 * frame
            step = self.progress_buf + frame

            self.obs_buf[:, base : base + 3] = self.obj_params[seq, step, 4:7].squeeze(0)
            self.obs_buf[:, base + 3 : base + 7] = self.obj_rot_quat[seq, step].squeeze(0)
            self.obs_buf[:, base + 7 : base + 8] = self.obj_params[seq, step, 0:1].squeeze(0)

            self.obs_buf[:, base + 8 : base + 11] = self.trans_l[seq, step]
            self.obs_buf[:, base + 11 : base + 15] = self.rot_l_quat[seq, step]
            self.obs_buf[:, base + 15 : base + 18] = self.trans_r[seq, step]
            self.obs_buf[:, base + 18 : base + 22] = self.rot_r_quat[seq, step]

        if self.train_teacher_policy:
            self.obs_buf = self.states_buf.clone()

    def compute_sim2real_asymmetric_obs(self, rand_floats):
        """Write the privileged (asymmetric) state into ``self.states_buf``.

        Column layout produced here:
          * ``0 : 2 * num_dofs`` -- unscaled DOF positions of the hand, then
            of the "another" hand.
          * ``84:96`` -- linear and angular velocity of the right hand, then
            of the left hand.
          * ``96:110`` -- position and rotation of the right hand, then of the
            left hand.
          * ``110:123`` -- object pose, linear velocity, angular velocity.
          * ``123:126`` -- object position minus the reference position
            ``obj_params[..., 4:7]`` at the current step.
          * ``126:141`` -- object bottom pose, object DOF position, top pose.
          * ``144 + 22 * i`` for ``i`` in ``range(10)`` -- reference frame at
            ``progress_buf + i`` (object position, object quaternion, left
            translation, left quaternion, right translation, right quaternion,
            object DOF); written only when ``self.use_hierarchy`` is falsy.
          * ``index_after_future_frame`` (364) -- object DOF minus reference.
          * the next ``13 * 10`` columns -- ten 13-wide fingertip-state slots
            (five per hand); the two middle-finger slots are not written for
            the ``"allegro"`` hand type, and nothing is written for a hand
            type other than shadow/allegro/schunk/ability.
          * the columns after that -- ``self.all_contact``.

        Also sets ``self.index_after_future_frame`` and
        ``self.index_after_fingertip``.

        Args:
            rand_floats: Not used by this method.
        """
        num_dofs = self.num_allegro_hand_dofs
        lower_limits = self.allegro_hand_dof_lower_limits[0:num_dofs]
        upper_limits = self.allegro_hand_dof_upper_limits[0:num_dofs]

        # Joint positions of both hands, normalised by the joint limits.
        self.states_buf[:, 0:num_dofs] = unscale(
            self.allegro_hand_dof_pos[:, 0:num_dofs], lower_limits, upper_limits
        )
        self.states_buf[:, num_dofs : 2 * num_dofs] = unscale(
            self.allegro_hand_another_dof_pos[:, 0:num_dofs], lower_limits, upper_limits
        )

        # Fixed-position blocks copied straight from attributes:
        # (first column, end column, attribute name).
        direct_copies = (
            (84, 87, "allegro_right_hand_linvel"),
            (87, 90, "allegro_right_hand_angvel"),
            (90, 93, "allegro_left_hand_linvel"),
            (93, 96, "allegro_left_hand_angvel"),
            (96, 99, "allegro_right_hand_pos"),
            (99, 103, "allegro_right_hand_rot"),
            (103, 106, "allegro_left_hand_pos"),
            (106, 110, "allegro_left_hand_rot"),
            (110, 117, "object_pose"),
            (117, 120, "object_linvel"),
            (120, 123, "object_angvel"),
            (126, 133, "object_bottom_pose"),
            (133, 134, "object_dof_pos"),
            (134, 141, "object_top_pose"),
        )
        for first, end, attr_name in direct_copies:
            self.states_buf[:, first:end] = getattr(self, attr_name)

        seq = self.seq_idx_tensor

        # Object position error with respect to the reference trajectory.
        reference_pos = self.obj_params[seq, self.progress_buf, 4:7].squeeze(0)
        self.states_buf[:, 123:126] = self.object_pos - reference_pos

        # Ten upcoming reference frames, 22 columns each.
        if not self.use_hierarchy:
            skip_frame = 1
            for frame in range(10):
                base = 144 + 22 * frame
                step = self.progress_buf + frame * skip_frame

                self.states_buf[:, base : base + 3] = self.obj_params[seq, step, 4:7].squeeze(0)
                self.states_buf[:, base + 3 : base + 7] = self.obj_rot_quat[seq, step].squeeze(0)
                self.states_buf[:, base + 7 : base + 10] = self.trans_l[seq, step].squeeze(0)
                self.states_buf[:, base + 10 : base + 14] = self.rot_l_quat[seq, step].squeeze(0)
                self.states_buf[:, base + 14 : base + 17] = self.trans_r[seq, step].squeeze(0)
                self.states_buf[:, base + 17 : base + 21] = self.rot_r_quat[seq, step].squeeze(0)
                self.states_buf[:, base + 21 : base + 22] = self.obj_params[seq, step, 0:1].squeeze(0)

        self.index_after_future_frame = 166 + 22 * 9
        dof_column = self.index_after_future_frame

        # Articulation error with respect to the reference trajectory.
        reference_dof = self.obj_params[seq, self.progress_buf, 0:1].squeeze(0)
        self.states_buf[:, dof_column : dof_column + 1] = self.object_dof - reference_dof

        # Fingertip states: one 13-wide slot per finger, left hand first.
        fingertip_base = dof_column + 1
        fingertip_attrs = (
            "left_hand_ff_state",
            "left_hand_lf_state",
            "left_hand_mf_state",
            "left_hand_rf_state",
            "left_hand_th_state",
            "right_hand_ff_state",
            "right_hand_lf_state",
            "right_hand_mf_state",
            "right_hand_rf_state",
            "right_hand_th_state",
        )
        hand_type = self.used_hand_type
        if hand_type == "shadow" or hand_type == "schunk" or hand_type == "ability":
            active_slots = tuple(range(10))
        elif hand_type == "allegro":
            # Slots 2 and 7 (middle fingers) are skipped; the corresponding
            # attributes are never read for this hand type.
            active_slots = (0, 1, 3, 4, 5, 6, 8, 9)
        else:
            active_slots = ()

        for slot in active_slots:
            first = fingertip_base + 13 * slot
            self.states_buf[:, first : first + 13] = getattr(self, fingertip_attrs[slot])

        # Binary contact flags follow the ten fingertip slots.
        contact_first = fingertip_base + 13 * 10
        contact_end = contact_first + self.all_contact.shape[1]
        self.states_buf[:, contact_first:contact_end] = self.all_contact

        # Original note: 364 + 1 + 130 + 68 = 563
        # (68 is presumably the width of all_contact -- TODO confirm).
        self.index_after_fingertip = 563

    def reset_target_pose(self, env_ids, apply_reset=False):
        """Re-initialise the goal position of the given environments.

        The goal position is restored from ``self.goal_init_state``, then
        offset: x by a random value in ``[-0.05, 0.05]``, y by
        ``-(0.55 + r * 0.05)`` with ``r`` random in ``[-1, 1]``, and z by
        ``+10.1``. The result (plus ``self.goal_displacement_tensor``) is
        written into the root-state rows of the goal objects, and
        ``self.reset_goal_buf`` is cleared for ``env_ids``.

        Args:
            env_ids: Indices of the environments whose goal is reset.
            apply_reset: When true, push the updated root states of the goal
                actors to the simulator immediately.
        """
        # Four columns are drawn although only the first two are consumed.
        jitter = torch_rand_float(-1.0, 1.0, (len(env_ids), 4), device=self.device)
        x_shift = jitter[:, 0] * 0.05
        y_shift = 0.55 + jitter[:, 1] * 0.05

        self.goal_states[env_ids, 0:3] = self.goal_init_state[env_ids, 0:3]
        self.goal_states[env_ids, 0] += x_shift
        self.goal_states[env_ids, 1] -= y_shift
        self.goal_states[env_ids, 2] += 10.1

        goal_rows = self.goal_object_indices[env_ids]
        displaced_goal = self.goal_states[env_ids, 0:3] + self.goal_displacement_tensor
        self.root_state_tensor[goal_rows, 0:3] = displaced_goal

        if apply_reset:
            goal_rows_i32 = goal_rows.to(torch.int32)
            self.gym.set_actor_root_state_tensor_indexed(
                self.sim,
                gymtorch.unwrap_tensor(self.root_state_tensor),
                gymtorch.unwrap_tensor(goal_rows_i32),
                len(env_ids),
            )

        self.reset_goal_buf[env_ids] = 0

    def calc_succ_rate(self, env_ids):
        """Print success and completion statistics, then terminate the process.

        An environment counts as successful when its object joint has moved
        more than 0.5 away from ``self.obj_joint_init``. Completion is the
        current progress step divided by ``init_step + 500 - 11``, clamped to
        ``[0, 1]``.

        Args:
            env_ids: Indices of the environments to evaluate.

        Note:
            This method always ends by calling ``exit()``.
        """
        joint_travel = (self.object_dof[:, 0:1] - self.obj_joint_init).abs()
        moved_enough = joint_travel[env_ids] > 0.5

        # NOTE(review): `moved_enough` keeps a trailing axis of size 1, so the
        # result is broadcast against success_buf's shape -- confirm intended.
        self.success_buf = torch.where(
            moved_enough,
            torch.ones_like(self.success_buf),
            torch.zeros_like(self.success_buf),
        )

        # NOTE(review): 500 is hard-coded here -- confirm it matches the
        # trajectory length used elsewhere.
        episode_steps = float(self.init_step + 500 - 11)
        progress_ratio = self.progress_buf[env_ids].float() / episode_steps
        self.complete_percentage[env_ids] = progress_ratio.clamp(0, 1)

        print("success_buf: ", self.success_buf[:].mean())
        print("complete_percentage: ", self.complete_percentage[:].mean())
        if self.traj_index == "all":
            print(self.complete_percentage[:].item())
        exit()

    def pre_collect_success_traj(self):
        """Create the ``collect_*`` buffers as clones of the reference trajectories.

        For each of ``trans_r``, ``trans_l``, ``obj_params``, ``obj_rot_quat``,
        ``rot_r_quat`` and ``rot_l_quat`` an independent copy is stored on
        ``self`` under the same name prefixed with ``collect_``.
        """
        reference_buffers = (
            "trans_r",
            "trans_l",
            "obj_params",
            "obj_rot_quat",
            "rot_r_quat",
            "rot_l_quat",
        )
        for buffer_name in reference_buffers:
            snapshot = getattr(self, buffer_name).clone()
            setattr(self, "collect_" + buffer_name, snapshot)
        
    def step_collect_success_traj(self):
        """Record the current simulated state into the ``collect_*`` buffers.

        For every environment the entry at
        ``[seq_idx_tensor, progress_buf]`` of each ``collect_*`` buffer is
        overwritten with the current right/left hand position and rotation,
        the object rotation, the object position and the object DOF position.

        The object position is stored in columns ``4:7`` of
        ``collect_obj_params`` and the DOF position in column ``0``, matching
        how ``obj_params`` is read everywhere else in this class
        (``obj_params[..., 4:7]`` for position, ``obj_params[..., 0:1]`` for
        the joint value). Columns ``1:4`` are left untouched.

        Requires ``pre_collect_success_traj`` to have created the buffers.
        """
        seq = self.seq_idx_tensor
        step = self.progress_buf

        self.collect_trans_r[seq, step] = self.allegro_right_hand_pos.clone()
        self.collect_trans_l[seq, step] = self.allegro_left_hand_pos.clone()
        self.collect_rot_r_quat[seq, step] = self.allegro_right_hand_rot.clone()
        self.collect_rot_l_quat[seq, step] = self.allegro_left_hand_rot.clone()
        # Position lives in columns 4:7 of obj_params (not 1:4).
        self.collect_obj_params[seq, step, 4:7] = self.object_pos.clone()
        self.collect_obj_params[seq, step, 0:1] = self.object_dof_pos.clone()
        self.collect_obj_rot_quat[seq, step] = self.object_rot.clone()
        
    def post_collect_success_traj(self, env_ids=None):
        """Gather the collected trajectories of sufficiently complete episodes.

        Completion is the current progress step divided by
        ``init_step + 500 - 11``, clamped to ``[0, 1]``. For every environment
        in ``env_ids`` whose completion is at least 0.8, a dict holding clones
        of its ``collect_*`` buffers (the states recorded by
        ``step_collect_success_traj``) is appended to ``self.append_seq_list``.

        Args:
            env_ids: Indices of the environments to evaluate. Defaults to all
                environments, so the method can also be called without
                arguments.

        Note:
            This method prints the number of gathered trajectories and then
            always ends by calling ``exit()``. It requires
            ``pre_collect_success_traj`` to have created the ``collect_*``
            buffers.
        """
        if env_ids is None:
            env_ids = torch.arange(
                self.complete_percentage.shape[0],
                device=self.complete_percentage.device,
            )

        # NOTE(review): 500 is hard-coded here -- confirm it matches the
        # trajectory length used elsewhere.
        episode_steps = float(self.init_step + 500 - 11)
        self.complete_percentage[env_ids] = torch.clamp(
            self.progress_buf[env_ids].float() / episode_steps, 0, 1
        )

        self.append_seq_list = []
        for env_id in env_ids:
            if self.complete_percentage[env_id] >= 0.8:
                # Store the recorded rollout, not the reference trajectory.
                self.append_seq_list.append(
                    {
                        "trans_r": self.collect_trans_r[env_id].clone(),
                        "trans_l": self.collect_trans_l[env_id].clone(),
                        "obj_params": self.collect_obj_params[env_id].clone(),
                        "obj_rot_quat": self.collect_obj_rot_quat[env_id].clone(),
                        "rot_r_quat": self.collect_rot_r_quat[env_id].clone(),
                        "rot_l_quat": self.collect_rot_l_quat[env_id].clone(),
                    }
                )

        print(len(self.append_seq_list))
        exit()

        
    def reset(self, env_ids, goal_env_ids):
        """Reset the given environments to the first frame of their trajectory.

        The method restores the root states of the object and both hands,
        resets hand and object DOF positions, velocities and position targets,
        places the object at the reference pose of step ``self.init_step``,
        pushes DOF states, DOF targets and root states to the simulator, and
        clears the per-environment bookkeeping buffers.

        When ``self.use_p_c_impro_loop`` is set, the trajectory-collection
        hooks are run at ``total_steps == 0`` (buffer creation) and at
        ``total_steps == 1000`` (gathering, which terminates the process).

        Args:
            env_ids: Indices of the environments to reset.
            goal_env_ids: Indices of environments whose goal object is
                included in the root-state update.
        """
        # randomization can happen only at reset time, since it can reset actor positions on GPU
        # if self.total_steps > 0:
        #     self.calc_succ_rate(env_ids)

        # if self.total_steps >0:
        #     exit()
        if self.use_p_c_impro_loop:
            if self.total_steps == 0:
                self.pre_collect_success_traj()
            if self.total_steps == 1000:
                # post_collect_success_traj needs the environments to evaluate;
                # calling it without arguments raised a TypeError.
                self.post_collect_success_traj(env_ids)

        if self.randomize:
            self.apply_randomizations(self.randomization_params)

        self.perturb_direction[env_ids] = torch_rand_float(
            -1, 1, (len(env_ids), 6), device=self.device
        ).squeeze(-1)

        # generate random values
        # NOTE: currently only referenced by commented-out code below; the
        # call is kept so that random-number consumption stays unchanged.
        rand_floats = torch_rand_float(
            -1.0, 1.0, (len(env_ids), self.num_allegro_hand_dofs * 2 + 5), device=self.device
        )

        # randomize start object poses
        self.reset_target_pose(env_ids)

        # self.root_state_tensor[self.another_hand_indices[env_ids], 2] = -0.05 + rand_floats[:, 4] * 0.01

        # reset object and hand root states to their stored initial values
        self.root_state_tensor[self.object_indices[env_ids]] = self.object_init_state[
            env_ids
        ].clone()
        self.root_state_tensor[self.hand_indices[env_ids]] = self.hand_start_states[env_ids].clone()
        self.root_state_tensor[self.another_hand_indices[env_ids]] = self.another_hand_start_states[
            env_ids
        ].clone()

        self.object_pose_for_open_loop[env_ids] = self.root_state_tensor[
            self.object_indices[env_ids], 0:7
        ].clone()

        # actors whose root state is pushed to the simulator further below
        object_indices = torch.unique(
            torch.cat(
                [
                    self.object_indices[env_ids],
                    self.goal_object_indices[env_ids],
                    self.table_indices[env_ids],
                    self.support_box_indices[env_ids],
                    self.goal_object_indices[goal_env_ids],
                ]
            ).to(torch.int32)
        )

        # self.gym.set_actor_root_state_tensor_indexed(self.sim,
        #                                              gymtorch.unwrap_tensor(self.root_state_tensor),
        #                                              gymtorch.unwrap_tensor(object_indices), len(object_indices))

        # reset both hands to their default DOF positions and velocities
        pos = self.allegro_hand_default_dof_pos
        another_pos = self.another_allegro_hand_default_dof_pos

        self.allegro_hand_dof_pos[env_ids, :] = pos
        self.allegro_hand_another_dof_pos[env_ids, :] = another_pos

        self.allegro_hand_dof_vel[env_ids, :] = self.allegro_hand_dof_default_vel
        self.allegro_hand_another_dof_vel[env_ids, :] = self.allegro_hand_dof_default_vel

        self.prev_targets[env_ids, : self.num_allegro_hand_dofs] = pos
        self.cur_targets[env_ids, : self.num_allegro_hand_dofs] = pos

        self.prev_targets[
            env_ids, self.num_allegro_hand_dofs : self.num_allegro_hand_dofs * 2
        ] = another_pos
        self.cur_targets[
            env_ids, self.num_allegro_hand_dofs : self.num_allegro_hand_dofs * 2
        ] = another_pos

        # reset object DOFs
        self.object_dof_pos[env_ids, :] = self.object_default_dof_pos
        self.object_dof_vel[env_ids, :] = torch.zeros_like(self.object_dof_vel[env_ids, :])

        self.prev_targets[env_ids, 2 * self.num_allegro_hand_dofs :] = self.object_default_dof_pos
        self.cur_targets[env_ids, 2 * self.num_allegro_hand_dofs :] = self.object_default_dof_pos

        # actors whose DOF state and DOF targets are pushed to the simulator
        hand_indices = self.hand_indices[env_ids].to(torch.int32)
        another_hand_indices = self.another_hand_indices[env_ids].to(torch.int32)
        all_hand_indices = torch.unique(
            torch.cat(
                [
                    hand_indices,
                    another_hand_indices,
                    self.object_indices[env_ids],
                    self.goal_object_indices[env_ids],
                ]
            ).to(torch.int32)
        )

        # load calibrated init state
        if self.use_calibrated_init_state:
            self.env_dof_state[env_ids] = self.init_dof_state.view(self.num_envs, -1, 2)[
                env_ids
            ].clone()

        self.gym.set_dof_state_tensor_indexed(
            self.sim,
            gymtorch.unwrap_tensor(self.dof_state),
            gymtorch.unwrap_tensor(all_hand_indices),
            len(all_hand_indices),
        )

        self.gym.set_dof_position_target_tensor_indexed(
            self.sim,
            gymtorch.unwrap_tensor(self.prev_targets),
            gymtorch.unwrap_tensor(all_hand_indices),
            len(all_hand_indices),
        )

        all_indices = torch.unique(torch.cat([all_hand_indices, object_indices]).to(torch.int32))

        # Episodes always start at the first reference frame.
        # self.init_step = random.randint(0, self.max_episode_length - 211)
        self.init_step = 0
        self.progress_buf[env_ids] = self.init_step
        # self.end_step_buf[env_ids] = self.init_step + 0
        self.end_step_buf[env_ids] = self.init_step + self.traj_len - 11

        # reference poses at the initial step
        self.r_pos_global_init[env_ids] = self.trans_r[env_ids, self.init_step]
        self.r_rot_global_init[env_ids] = self.rot_r_quat[env_ids, self.init_step]
        self.l_pos_global_init[env_ids] = self.trans_l[env_ids, self.init_step]
        self.l_rot_global_init[env_ids] = self.rot_l_quat[env_ids, self.init_step]
        self.obj_pos_global_init[env_ids] = self.obj_params[env_ids, self.init_step, 4:7]
        self.obj_rot_global_init[env_ids] = self.obj_rot_quat[env_ids, self.init_step, 0:4]

        # place the object at the reference pose and zero root-state columns 7:13
        self.root_state_tensor[self.object_indices[env_ids], 0:3] = self.obj_pos_global_init[
            env_ids
        ]
        self.root_state_tensor[self.object_indices[env_ids], 3:7] = self.obj_rot_global_init[
            env_ids
        ]
        self.root_state_tensor[self.object_indices[env_ids], 7:10] = 0
        self.root_state_tensor[self.object_indices[env_ids], 10:13] = 0

        # self.root_state_tensor[self.hand_indices[env_ids], 0:3] = self.r_pos_global_init[env_ids]
        # self.root_state_tensor[self.hand_indices[env_ids], 3:7] = self.r_rot_global_init[env_ids]

        # self.root_state_tensor[self.another_hand_indices[env_ids], 0:3] = self.l_pos_global_init[env_ids]
        # self.root_state_tensor[self.another_hand_indices[env_ids], 3:7] = self.l_rot_global_init[env_ids]

        self.gym.set_actor_root_state_tensor_indexed(
            self.sim,
            gymtorch.unwrap_tensor(self.root_state_tensor),
            gymtorch.unwrap_tensor(all_indices),
            len(all_indices),
        )

        # post reset
        # self.calibrate_init_state(all_hand_indices)

        # self.progress_buf[env_ids] = 0
        self.reset_buf[env_ids] = 0
        self.successes[env_ids] = 0
        self.last_actions[env_ids] = torch.zeros_like(self.actions[env_ids])

        self.proprioception_close_loop[env_ids] = self.allegro_hand_dof_pos[env_ids, 0:22].clone()

        self.object_state_stack_frames[env_ids] = torch.zeros_like(
            self.object_state_stack_frames[env_ids]
        )

    def pre_physics_step(self, actions):
        # if self.progress_buf[0] > 20:
        #     self.tarject_predict(0, 19)
        self.actions = actions.clone().to(self.device)

        self.r_pos_global = self.trans_r[self.seq_idx_tensor, self.progress_buf].clone().squeeze(0)
        self.r_rot_global = (
            self.rot_r_quat[self.seq_idx_tensor, self.progress_buf].clone().squeeze(0)
        )
        self.l_pos_global = self.trans_l[self.seq_idx_tensor, self.progress_buf].clone().squeeze(0)
        self.l_rot_global = (
            self.rot_l_quat[self.seq_idx_tensor, self.progress_buf].clone().squeeze(0)
        )
        self.left_fingertip_global = self.left_fingertip[self.seq_idx_tensor, self.progress_buf].clone().squeeze(0)
        self.right_fingertip_global = (
            self.right_fingertip[self.seq_idx_tensor, self.progress_buf].clone().squeeze(0)
        )
        self.left_middle_finger_global = self.left_middle_finger[self.seq_idx_tensor, self.progress_buf].clone().squeeze(0)
        self.right_middle_finger_global = (
            self.right_middle_finger[self.seq_idx_tensor, self.progress_buf].clone().squeeze(0)
        )
        
        if self.test_record:
            absolute_difference = torch.abs(self.obs_buf - self.recorded_obs[self.progress_buf[0]].unsqueeze(0).cuda())
            absolute_sum = torch.sum(absolute_difference)
            print(absolute_sum.item())
            
            self.actions = self.policy.predict(torch.clamp(self.recorded_obs[self.progress_buf[0]].unsqueeze(0).cuda(), -5, 5))

        if self.sim2real_record and self.total_steps > 0:
            self.data["left_hand_targets"].append([self.allegro_hand_another_dof_pos[0].cpu().tolist()[i] for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19, 20, 21, 22, 23, 24, 12, 13, 14, 15, 16, 25, 26, 27, 28, 29]])
            self.data["right_hand_targets"].append([self.allegro_hand_dof_pos[0].cpu().tolist()[i] for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19, 20, 21, 22, 23, 24, 12, 13, 14, 15, 16, 25, 26, 27, 28, 29]])

            self.data["right_hand_ee_pos"].append((self.allegro_right_hand_pos - self.root_state_tensor[self.hand_indices, 0:3])[0].cpu().tolist())
            self.data["right_hand_ee_rot"].append(self.allegro_right_hand_rot[0].cpu().tolist())

            self.data["r_pos_global"].append(self.r_pos_global[0].cpu().tolist())
            self.data["r_rot_global"].append(self.r_rot_global[0].cpu().tolist())

            self.data["left_hand_ee_pos"].append((self.allegro_left_hand_pos - self.root_state_tensor[self.another_hand_indices, 0:3])[0].cpu().tolist())
            self.data["left_hand_ee_rot"].append(self.allegro_left_hand_rot[0].cpu().tolist())

            self.data["l_pos_global"].append(self.l_pos_global[0].cpu().tolist())
            self.data["l_rot_global"].append(self.l_rot_global[0].cpu().tolist())

            self.data["obs_buf"].append(self.obs_buf[0].cpu())

            print("total_steps: ", self.total_steps)
            if self.reset_buf[0] == 1 and self.total_steps != 0:
                with open("./arctic_box_trajs_joint_ee_obs_residual_{}_{}.pkl".format(self.used_training_objects[0], self.traj_index), "wb") as f:
                    pickle.dump(self.data, f)
                print("Record finish!!")
                exit()

        if self.enable_camera_sensors:
            env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1)
            
            for i in self.seq_list_i:
                if i in env_ids and self.total_steps > 0:
                    self.video_out_list[i].release()
                    self.seq_list_i.remove(i)
                    if len(self.seq_list_i) == 0:
                        cv2.destroyAllWindows()
                        self.calc_succ_rate(env_ids)
                        exit()
                    
            self.gym.render_all_camera_sensors(self.sim)
            self.gym.start_access_image_tensors(self.sim)

            for i in self.seq_list_i:
                camera_rgba_image = self.camera_rgb_visulization(self.camera_tensors, env_id=i, is_depth_image=False)
                self.video_out_list[i].write(camera_rgba_image)   

            cv2.imshow("DEBUG_RGB_VIS", camera_rgba_image)
            cv2.waitKey(1)
                
            self.gym.end_access_image_tensors(self.sim)

        env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1)
        goal_env_ids = self.reset_goal_buf.nonzero(as_tuple=False).squeeze(-1)
        
        if len(env_ids) > 0:
            self.reset(env_ids, goal_env_ids)

        if self.used_hand_type == "shadow":
            self.cur_targets[:, 6 : self.num_allegro_hand_dofs] = scale(
                self.actions[:, 6:self.num_allegro_hand_dofs],
                self.allegro_hand_dof_lower_limits[6 : self.num_allegro_hand_dofs],
                self.allegro_hand_dof_upper_limits[6 : self.num_allegro_hand_dofs],
            )
            
            self.cur_targets[
                :, self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2
            ] = scale(
                self.actions[:, self.num_allegro_hand_dofs+6:self.num_allegro_hand_dofs*2],
                self.allegro_hand_dof_lower_limits[6 : self.num_allegro_hand_dofs],
                self.allegro_hand_dof_upper_limits[6 : self.num_allegro_hand_dofs],
            )
        elif self.used_hand_type == "allegro":
            self.act_moving_average = 0.2
            self.cur_targets[:, 6 : self.num_allegro_hand_dofs] = scale(
                self.actions[:, 6:self.num_allegro_hand_dofs],
                self.allegro_hand_dof_lower_limits[6 : self.num_allegro_hand_dofs],
                self.allegro_hand_dof_upper_limits[6 : self.num_allegro_hand_dofs],
            )
            self.cur_targets[:, 6 : self.num_allegro_hand_dofs] = self.act_moving_average * self.cur_targets[:,
                                                                                                        6 : self.num_allegro_hand_dofs] + (1.0 - self.act_moving_average) * self.prev_targets[:, 6 : self.num_allegro_hand_dofs]
            self.cur_targets[
                :, self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2
            ] = scale(
                self.actions[:, self.num_allegro_hand_dofs+6:self.num_allegro_hand_dofs*2],
                self.allegro_hand_dof_lower_limits[6 : self.num_allegro_hand_dofs],
                self.allegro_hand_dof_upper_limits[6 : self.num_allegro_hand_dofs],
            )
            self.cur_targets[:,self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2] = self.act_moving_average * self.cur_targets[:,
                                                                                                       self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2] + (1.0 - self.act_moving_average) * self.prev_targets[:, self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2]
        
        elif self.used_hand_type == "schunk":
            self.act_moving_average = 0.2
            self.cur_targets[:, 6 : self.num_allegro_hand_dofs] = scale(
                self.actions[:, 6:self.num_allegro_hand_dofs],
                self.allegro_hand_dof_lower_limits[6 : self.num_allegro_hand_dofs],
                self.allegro_hand_dof_upper_limits[6 : self.num_allegro_hand_dofs],
            )
            self.cur_targets[:, 6 : self.num_allegro_hand_dofs] = self.act_moving_average * self.cur_targets[:,
                                                                                                        6 : self.num_allegro_hand_dofs] + (1.0 - self.act_moving_average) * self.prev_targets[:, 6 : self.num_allegro_hand_dofs]
            self.cur_targets[
                :, self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2
            ] = scale(
                self.actions[:, self.num_allegro_hand_dofs+6:self.num_allegro_hand_dofs*2],
                self.allegro_hand_dof_lower_limits[6 : self.num_allegro_hand_dofs],
                self.allegro_hand_dof_upper_limits[6 : self.num_allegro_hand_dofs],
            )
            self.cur_targets[:,self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2] = self.act_moving_average * self.cur_targets[:,
                                                                                                       self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2] + (1.0 - self.act_moving_average) * self.prev_targets[:, self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2]

        elif self.used_hand_type == "ability":
            self.act_moving_average = 0.5
            self.cur_targets[:, 6 : self.num_allegro_hand_dofs] = scale(
                self.actions[:, 6:self.num_allegro_hand_dofs],
                self.allegro_hand_dof_lower_limits[6 : self.num_allegro_hand_dofs],
                self.allegro_hand_dof_upper_limits[6 : self.num_allegro_hand_dofs],
            )
            self.cur_targets[:, 6 : self.num_allegro_hand_dofs] = self.act_moving_average * self.cur_targets[:,
                                                                                                        6 : self.num_allegro_hand_dofs] + (1.0 - self.act_moving_average) * self.prev_targets[:, 6 : self.num_allegro_hand_dofs]
            self.cur_targets[
                :, self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2
            ] = scale(
                self.actions[:, self.num_allegro_hand_dofs+6:self.num_allegro_hand_dofs*2],
                self.allegro_hand_dof_lower_limits[6 : self.num_allegro_hand_dofs],
                self.allegro_hand_dof_upper_limits[6 : self.num_allegro_hand_dofs],
            )
            self.cur_targets[:,self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2] = self.act_moving_average * self.cur_targets[:,
                                                                                                       self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2] + (1.0 - self.act_moving_average) * self.prev_targets[:, self.num_allegro_hand_dofs + 6 : self.num_allegro_hand_dofs * 2]
        
        # object curriculum
        # self.cur_targets[:, 60] = (
        #     self.obj_params[self.seq_idx_tensor, self.progress_buf, 0].clone().squeeze(0)
        # )
        self.cur_targets[:, self.num_allegro_hand_dofs*2:self.num_allegro_hand_dofs*2+1] = self.object_dof_pos
        # fixed WR1 WR2
        # self.cur_targets[:, 6:8] = torch.zeros_like(self.cur_targets[:, 6:8])
        # self.cur_targets[:, 36:38] = torch.zeros_like(self.cur_targets[:, 36:38])
        if self.used_hand_type == "shadow":
            self.trans_range = 0.04
            self.rot_range = 0.5
        elif self.used_hand_type == "schunk":
            self.trans_range = 0.04
            self.rot_range = 0.5
        elif self.used_hand_type == "allegro":
            self.trans_range = 0.04
            self.rot_range = 0.5
        elif self.used_hand_type == "ability":
            self.trans_range = 0.04
            self.rot_range = 0.5

        right_pos_err = (self.r_pos_global - self.allegro_right_hand_pos)
        right_target_rot = self.r_rot_global
        right_rot_err = orientation_error(right_target_rot, self.allegro_right_hand_rot)
        right_pos_err += self.actions[:, 0:3] * self.trans_range
        right_rot_err += self.actions[:, 3:6] * self.rot_range
        
        if self.use_hierarchy:
            right_pos_err = self.actions[:, 0:3] * self.trans_range
            right_rot_err = self.actions[:, 3:6] * self.rot_range
            
            
        right_dpose = torch.cat([right_pos_err, right_rot_err], -1).unsqueeze(-1)
        right_delta = control_ik(
            self.jacobian_tensor[:, self.hand_base_rigid_body_index - 1, :, :6],
            self.device,
            right_dpose,
            self.num_envs,
        )
                
        right_targets = self.allegro_hand_dof_pos[:, 0:6] + right_delta[:, :6]

        left_pos_err = (self.l_pos_global - self.allegro_left_hand_pos)
        left_target_rot = self.l_rot_global
        left_rot_err = orientation_error(left_target_rot, self.allegro_left_hand_rot)
        left_pos_err += self.actions[:, self.num_allegro_hand_dofs:self.num_allegro_hand_dofs+3] * self.trans_range
        left_rot_err += self.actions[:, self.num_allegro_hand_dofs+3:self.num_allegro_hand_dofs+6] * self.rot_range
            
        if self.use_hierarchy:
            left_pos_err = self.actions[:, self.num_allegro_hand_dofs:self.num_allegro_hand_dofs+3] * self.trans_range
            left_rot_err = self.actions[:, self.num_allegro_hand_dofs+3:self.num_allegro_hand_dofs+6] * self.rot_range
            
        left_dpose = torch.cat([left_pos_err, left_rot_err], -1).unsqueeze(-1)
        left_delta = control_ik(
            self.another_jacobian_tensor[:, self.hand_base_rigid_body_index - 1, :, :6],
            self.device,
            left_dpose,
            self.num_envs,
        )
        left_targets = self.allegro_hand_another_dof_pos[:, 0:6] + left_delta[:, :6]


        self.cur_targets[:, :6] = right_targets[:, :6].clone()
        self.cur_targets[:, self.num_allegro_hand_dofs:self.num_allegro_hand_dofs+6] = left_targets[:, :6].clone()

        self.cur_targets[:, : self.num_allegro_hand_dofs] = tensor_clamp(
            self.cur_targets[:, 0 : self.num_allegro_hand_dofs],
            self.allegro_hand_dof_lower_limits[:],
            self.allegro_hand_dof_upper_limits[:],
        )

        self.cur_targets[
            :, self.num_allegro_hand_dofs : self.num_allegro_hand_dofs * 2
        ] = tensor_clamp(
            self.cur_targets[:, self.num_allegro_hand_dofs : self.num_allegro_hand_dofs * 2],
            self.allegro_hand_dof_lower_limits[:],
            self.allegro_hand_dof_upper_limits[:],
        )

        self.prev_targets[:, :] = self.cur_targets[:, :].clone()

        if self.use_fingertip_ik:            
            left_hand_joint, right_hand_joint = self.ik_solver.solve_ik(self.left_fingertip_global[0, :].view(5, 3), self.allegro_left_hand_pos.view(3), self.dl.left_fingertip_rot[self.progress_buf[0]].view(4), self.right_fingertip_global[0, :].view(5, 3), self.allegro_right_hand_pos.view(3), self.dl.right_fingertip_rot[self.progress_buf[0]].view(4))
            left_hand_joint = to_torch(list(left_hand_joint), device=self.device)
            right_hand_joint = to_torch(list(right_hand_joint), device=self.device)
            self.cur_targets[:, 6:8] = 0
            self.cur_targets[:, 8:30] = right_hand_joint
            self.cur_targets[:, 36:38] = 0
            self.cur_targets[:, 8+self.num_allegro_hand_dofs:30+self.num_allegro_hand_dofs] = left_hand_joint
        if self.use_joint_space_ik:
            left_hand_joint, right_hand_joint = self.ik_solver.bimanual_position_optimizer(torch.concat((self.left_fingertip_global[0, :].view(5, 3) - self.l_pos_global.view(3), self.left_middle_finger_global[0, :].view(5, 3) - self.l_pos_global.view(3)), dim=0), torch.concat((self.right_fingertip_global[0, :].view(5, 3) - self.r_pos_global.view(3), self.right_middle_finger_global[0, :].view(5, 3) - self.r_pos_global.view(3)), dim=0))
            left_hand_joint = to_torch(list(left_hand_joint), device=self.device)
            right_hand_joint = to_torch(list(right_hand_joint), device=self.device)
            self.cur_targets[:, 6:30] = right_hand_joint
            self.cur_targets[:, 6+self.num_allegro_hand_dofs:30+self.num_allegro_hand_dofs] = left_hand_joint
        
        self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(self.cur_targets))
        
        
        if self.load_sim_record:
            self.gym.set_actor_root_state_tensor(
                self.sim,
                gymtorch.unwrap_tensor(self.f_actor_root_state_tensor[self.progress_buf[0]])
            )
            self.gym.set_dof_state_tensor(
                self.sim,
                gymtorch.unwrap_tensor(self.f_dof_state_tensor[self.progress_buf[0]])
            )
        ############################################################
        
        if self.apply_perturbation:
            rand_floats = torch_rand_float(-1.0, 1.0, (self.num_envs, 3), device=self.device)
            self.apply_forces[:, 31*2 + 0, :] = rand_floats * 20

            self.gym.apply_rigid_body_force_tensors(self.sim, gymtorch.unwrap_tensor(self.apply_forces), gymtorch.unwrap_tensor(self.apply_torque), gymapi.ENV_SPACE)
        #############################################################


    def post_physics_step(self):
        """Advance the step counters, recompute observations and reward, and redraw debug lines.

        Both `progress_buf` and `randomize_buf` are incremented in place before
        `compute_observations` and `compute_reward(self.actions)` run. The viewer
        lines are then cleared and the rigid-body state tensor refreshed. When a
        viewer exists and `debug_viz` is set, a debug frame is drawn per
        environment at `allegro_hand_another_thmub_pos` / `..._rot`.
        """
        self.progress_buf += 1
        self.randomize_buf += 1

        self.compute_observations()
        self.compute_reward(self.actions)

        gym = self.gym
        gym.clear_lines(self.viewer)
        gym.refresh_rigid_body_state_tensor(self.sim)

        # Nothing more to do unless debug drawing is requested and a viewer exists.
        if not (self.viewer and self.debug_viz):
            return

        # NOTE(review): this clear/refresh pair repeats the unconditional pair
        # issued just above; kept as-is to preserve the existing call sequence.
        gym.clear_lines(self.viewer)
        gym.refresh_rigid_body_state_tensor(self.sim)

        for env_idx in range(self.num_envs):
            self.add_debug_lines(
                self.envs[env_idx],
                self.allegro_hand_another_thmub_pos[env_idx],
                self.allegro_hand_another_thmub_rot[env_idx],
                line_width=2,
            )

    def add_debug_lines(self, env, pos, rot, line_width=1):
        """Draw three coloured line segments from `pos` along the axes rotated by `rot`.

        Each segment starts at `pos` and ends at
        `pos + quat_apply(rot, unit_axis * 0.2)` for the unit axes
        (1, 0, 0), (0, 1, 0) and (0, 0, 1); they are coloured reddish, greenish
        and bluish respectively.

        Args:
            env: Environment handle forwarded to `gym.add_lines`.
            pos: Origin of the segments; indexed as a length-3 vector.
            rot: Quaternion handed to `quat_apply`.
            line_width: Kept for backward compatibility only and ignored. The
                third argument of `gym.add_lines` is the *number of lines*
                described by the vertex/colour buffers, not a width. Forwarding
                `line_width` there (as was done before) makes any value other
                than 1 disagree with the single segment supplied per call.
        """
        axis_length = 0.2
        unit_axes = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
        colors = ([0.85, 0.1, 0.1], [0.1, 0.85, 0.1], [0.1, 0.1, 0.85])

        tips = [
            (pos + quat_apply(rot, to_torch(axis, device=self.device) * axis_length)).cpu().numpy()
            for axis in unit_axes
        ]
        origin = pos.cpu().numpy()

        for tip, color in zip(tips, colors):
            # Exactly one segment (6 vertex floats, 3 colour floats) per call.
            self.gym.add_lines(
                self.viewer,
                env,
                1,
                [origin[0], origin[1], origin[2], tip[0], tip[1], tip[2]],
                color,
            )

    #####################################################################
    ###=========================jit functions=========================###
    #####################################################################

    def camera_rgb_visulization(self, camera_tensors, env_id=0, is_depth_image=False):
        """Return the camera image of one environment as a NumPy array.

        The entry `camera_tensors[env_id]` is cloned, moved to the CPU,
        converted to NumPy and passed through
        `cv2.cvtColor(..., cv2.COLOR_BGR2RGB)`.

        `is_depth_image` is accepted but not used by this method.
        """
        frame = camera_tensors[env_id].clone().cpu().numpy()
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


    def camera_visulization(self, is_depth_image=False):
        """Fetch the image of `self.cameras[0]` in `self.envs[0]` as a PIL image.

        Args:
            is_depth_image: When true, the `IMAGE_DEPTH` buffer is read, clamped
                to [-1, 1] and passed through `scale` with bounds 0 and 256.
                Otherwise the `IMAGE_COLOR` buffer is used unmodified.

        Returns:
            The result of `PIL.Image.fromarray` on the CPU copy of the buffer.
        """
        first_env = self.envs[0]
        first_camera = self.cameras[0]

        if is_depth_image:
            raw_depth = self.gym.get_camera_image_gpu_tensor(
                self.sim, first_env, first_camera, gymapi.IMAGE_DEPTH
            )
            depth = torch.clamp(gymtorch.wrap_tensor(raw_depth), -1, 1)
            depth = scale(
                depth,
                to_torch([0], dtype=torch.float, device=self.device),
                to_torch([256], dtype=torch.float, device=self.device),
            )
            pixels = depth.cpu().numpy()
        else:
            raw_color = self.gym.get_camera_image_gpu_tensor(
                self.sim, first_env, first_camera, gymapi.IMAGE_COLOR
            )
            pixels = gymtorch.wrap_tensor(raw_color).cpu().numpy()

        return Im.fromarray(pixels)

    def rand_row(self, tensor, dim_needed):
        """Pick `dim_needed` rows of `tensor` at random (indices may repeat).

        Row indices are drawn with `torch.randint` over `[0, tensor.shape[0])`.
        """
        num_rows = tensor.shape[0]
        picks = torch.randint(low=0, high=num_rows, size=(dim_needed,))
        return tensor[picks, :]

    def sample_points(self, points, sample_num=1000, sample_mathed='furthest'):
        """Sub-sample `sample_num` points from a point set.

        Rows whose third column is greater than 0.04 are preferred; if fewer
        than `sample_num` such rows exist, sampling falls back to all of
        `points`.

        Args:
            points: 2-D tensor indexed as `points[:, 2]` for the filter.
            sample_num: Number of points to return.
            sample_mathed: `'random'` (uses `self.rand_row`) or `'furthest'`
                (uses `pointnet2_utils.furthest_point_sample`).

        Returns:
            The selected rows of the (possibly filtered) point set.

        Raises:
            ValueError: If `sample_mathed` is not one of the supported values.
                Previously this fell through both branches and failed with an
                `UnboundLocalError` at the `return`.
        """
        if sample_mathed not in ('random', 'furthest'):
            raise ValueError(
                f"Unknown sample_mathed {sample_mathed!r}; expected 'random' or 'furthest'"
            )

        eff_points = points[points[:, 2] > 0.04]
        if eff_points.shape[0] < sample_num:
            eff_points = points

        if sample_mathed == 'random':
            return self.rand_row(eff_points, sample_num)

        # NOTE(review): `pointnet2_utils` is not imported in the visible part of
        # this module - confirm it is in scope before relying on this branch.
        sampled_points_id = pointnet2_utils.furthest_point_sample(
            eff_points.reshape(1, *eff_points.shape), sample_num
        )
        return eff_points.index_select(0, sampled_points_id[0].long())


@torch.jit.script
def depth_image_to_point_cloud_GPU(
    camera_tensor,
    camera_view_matrix_inv,
    camera_proj_matrix,
    u,
    v,
    width: float,
    height: float,
    depth_bar: float,
    device: torch.device,
):
    """Deproject a depth image into a set of 3-D points.

    Args:
        camera_tensor: Depth buffer; moved to `device` before use.
        camera_view_matrix_inv: Matrix the homogeneous points are
            right-multiplied by.
        camera_proj_matrix: Projection matrix; only entries [0, 0] and [1, 1]
            are read.
        u, v: Pixel coordinate tensors combined element-wise with the depth.
        width, height: Image size used to centre and normalise `u` / `v`.
        depth_bar: Points are kept only where depth > -depth_bar.
        device: Device for the depth buffer and the homogeneous ones row.

    Returns:
        Tensor with one row per kept pixel and three columns.
    """
    depth = camera_tensor.to(device)

    # Deprojection scale factors from the projection-matrix diagonal.
    scale_u = 2 / camera_proj_matrix[0, 0]
    scale_v = 2 / camera_proj_matrix[1, 1]

    half_width = width / 2
    half_height = height / 2

    cam_x = -(u - half_width) / width * depth * scale_u
    cam_y = (v - half_height) / height * depth * scale_v

    flat_z = depth.view(-1)
    keep = flat_z > -depth_bar
    flat_x = cam_x.view(-1)
    flat_y = cam_y.view(-1)

    # Rows are x, y, z, 1; columns are pixels. Drop the rejected pixels.
    homogeneous = torch.vstack(
        (flat_x, flat_y, flat_z, torch.ones(len(flat_x), device=device))
    )[:, keep]

    transformed = homogeneous.permute(1, 0) @ camera_view_matrix_inv
    return transformed[:, 0:3]


@torch.jit.script
def compute_hand_reward(
    rew_buf,
    reset_buf,
    reset_goal_buf,
    progress_buf,
    successes,
    consecutive_successes,
    object_contact,
    left_contact,
    right_contact,
    allegro_left_hand_pos,
    allegro_right_hand_pos,
    allegro_left_hand_rot,
    allegro_right_hand_rot,
    max_episode_length: float,
    object_pos,
    object_rot,
    target_pos,
    target_rot,
    allegro_left_hand_dof,
    allegro_right_hand_dof,
    object_dof,
    trans_r,
    trans_l,
    rot_r_quat,
    rot_l_quat,
    obj_params,
    obj_quat,
    dist_reward_scale: float,
    rot_reward_scale: float,
    rot_eps: float,
    actions,
    action_penalty_scale: float,
    a_hand_palm_pos,
    last_actions,
    success_tolerance: float,
    reach_goal_bonus: float,
    fall_dist: float,
    fall_penalty: float,
    right_hand_energy_penalty,
    left_hand_energy_penalty,
    end_step_buf,
    seq_idx_tensor,
    max_consecutive_successes: int,
    av_factor: float,
    ignore_z_rot: bool,
    object_joint_tolerate: float,
    object_joint_reset: float,
    use_fingertip_reward: int,
    use_hierarchy: int,
    left_fingertip_global,
    right_fingertip_global,
    left_fingertip_pos_list,
    right_fingertip_pos_list,
):
    """Compute the reward, reset flags and success statistics for the current step.

    Reference values are looked up with `[seq_idx_tensor, progress_buf]`:
    `obj_params[..., 0]` is compared against `object_dof[:, 0]`,
    `obj_params[..., 4:7]` against `object_pos`, `obj_quat` against
    `object_rot`, and `trans_l` / `trans_r` / `rot_l_quat` / `rot_r_quat`
    against the left/right hand position and rotation.

    Reward:
        `exp(-2 * rot - 20 * pos - 2 * joint)` on the object errors (each error
        first reduced by a dead zone and clamped at zero) plus a small negative
        energy term. If `use_fingertip_reward` is set, the reward is multiplied
        by `exp(-20 * mean fingertip distance)` for each hand. If
        `use_hierarchy` is set, it is also multiplied by the left and right
        hand tracking rewards. When `max_consecutive_successes > 0`,
        `0.5 * fall_penalty` is added where `progress_buf >= max_episode_length`.

    Resets:
        Set where the object z is <= -10.15, where any dead-zone-adjusted object
        error reaches its threshold (0.05 position, 0.5 rotation,
        `object_joint_reset` joint), where any entry of `left_contact` or
        `right_contact` along dim 1 equals 1, or where
        `progress_buf >= end_step_buf`.

    Returns:
        Tuple `(reward, resets, goal_resets, progress_buf, successes,
        cons_successes)`; `progress_buf` is returned unchanged.

    Note:
        Many parameters (e.g. `rew_buf`, `object_contact`, `target_pos`,
        `target_rot`, `dist_reward_scale`, `rot_reward_scale`, `rot_eps`,
        `action_penalty_scale`, `success_tolerance`, `reach_goal_bonus`,
        `fall_dist`, `ignore_z_rot`) are not referenced in the body, and
        `jittering_penalty` is computed but never added to the reward.
    """
    # --- Object tracking errors against the reference at the current step ---
    object_pos_dist = torch.norm(
        object_pos - obj_params[seq_idx_tensor, progress_buf, 4:7].squeeze(0), p=2, dim=-1
    )
    object_quat_diff = quat_mul(
        object_rot, quat_conjugate(obj_quat[seq_idx_tensor, progress_buf].squeeze(0))
    )
    # Angle from the norm of the first three components of the quaternion difference.
    object_rot_dist = 2.0 * torch.asin(
        torch.clamp(torch.norm(object_quat_diff[:, 0:3], p=2, dim=-1), max=1.0)
    )
    object_joint_dist = torch.abs(
        object_dof[:, 0] - obj_params[seq_idx_tensor, progress_buf, 0].squeeze(0)
    )
    
    # Dead zones: errors below these tolerances count as zero.
    object_pos_dist = torch.clamp(object_pos_dist - 0.05, 0, None)
    object_rot_dist = torch.clamp(object_rot_dist - 0.1, 0, None)
    object_joint_dist = torch.clamp(object_joint_dist - object_joint_tolerate, 0, None)

    # --- Left hand tracking errors (dead zones 0.15 position / 0.5 rotation) ---
    left_hand_pos_dist = torch.norm(
        allegro_left_hand_pos - trans_l[seq_idx_tensor, progress_buf].squeeze(0), p=2, dim=-1
    )
    left_hand_quat_diff = quat_mul(
        allegro_left_hand_rot, quat_conjugate(rot_l_quat[seq_idx_tensor, progress_buf].squeeze(0))
    )
    left_hand_rot_dist = 2.0 * torch.asin(
        torch.clamp(torch.norm(left_hand_quat_diff[:, 0:3], p=2, dim=-1), max=1.0)
    )

    left_hand_pos_dist = torch.clamp(left_hand_pos_dist - 0.15, 0, None)
    left_hand_rot_dist = torch.clamp(left_hand_rot_dist - 0.5, 0, None)

    # --- Right hand tracking errors (same dead zones as the left hand) ---
    right_hand_pos_dist = torch.norm(
        allegro_right_hand_pos - trans_r[seq_idx_tensor, progress_buf].squeeze(0), p=2, dim=-1
    )
    right_hand_quat_diff = quat_mul(
        allegro_right_hand_rot, quat_conjugate(rot_r_quat[seq_idx_tensor, progress_buf].squeeze(0))
    )
    right_hand_rot_dist = 2.0 * torch.asin(
        torch.clamp(torch.norm(right_hand_quat_diff[:, 0:3], p=2, dim=-1), max=1.0)
    )

    right_hand_pos_dist = torch.clamp(right_hand_pos_dist - 0.15, 0, None)
    right_hand_rot_dist = torch.clamp(right_hand_rot_dist - 0.5, 0, None)

    # Exponential rewards; each equals 1 when all its adjusted errors are zero.
    object_reward = 1 * torch.exp(
        -2 * object_rot_dist - 20 * object_pos_dist - 2 * object_joint_dist
    )
    left_hand_reward = 1 * torch.exp(-1 * left_hand_rot_dist - 20 * left_hand_pos_dist)
    right_hand_reward = 1 * torch.exp(-1 * right_hand_rot_dist - 20 * right_hand_pos_dist)
    # left_hand_reward = torch.ones_like(object_reward)
    # right_hand_reward = torch.ones_like(object_reward)
    
    # True for environments where any entry along dim 1 of the contact tensor equals 1.
    is_left_contact = (left_contact == 1).any(dim=1, keepdim=False)
    is_right_contact = (right_contact == 1).any(dim=1, keepdim=False)

    # NOTE: jittering_penalty is computed but not used in the reward below.
    jittering_penalty = 0.003 * torch.sum(actions**2, dim=-1)
    energy_penalty = -0.000001 * (right_hand_energy_penalty + left_hand_energy_penalty)
    
    reward = (object_reward + energy_penalty)

    if use_fingertip_reward:
        # Multiply by exp(-20 * mean distance) over the five 3-column fingertip
        # slices of each hand versus the matching entry of the *_pos_list.
        reward *= torch.exp((torch.norm(left_fingertip_global[:, 0:3] - left_fingertip_pos_list[0], p=2, dim=-1) + torch.norm(left_fingertip_global[:, 3:6] - left_fingertip_pos_list[1], p=2, dim=-1) +
                   torch.norm(left_fingertip_global[:, 6:9] - left_fingertip_pos_list[2], p=2, dim=-1) + torch.norm(left_fingertip_global[:, 9:12] - left_fingertip_pos_list[3], p=2, dim=-1) + 
                   torch.norm(left_fingertip_global[:, 12:15] - left_fingertip_pos_list[4], p=2, dim=-1)) / 5 * (-20))                                                                                                   
        reward *= torch.exp((torch.norm(right_fingertip_global[:, 0:3] - right_fingertip_pos_list[0], p=2, dim=-1) + torch.norm(right_fingertip_global[:, 3:6] - right_fingertip_pos_list[1], p=2, dim=-1) +
                   torch.norm(right_fingertip_global[:, 6:9] - right_fingertip_pos_list[2], p=2, dim=-1) + torch.norm(right_fingertip_global[:, 9:12] - right_fingertip_pos_list[3], p=2, dim=-1) + 
                   torch.norm(right_fingertip_global[:, 12:15] - right_fingertip_pos_list[4], p=2, dim=-1)) / 5 * (-20))
        
    if use_hierarchy:
        # Additionally gate the reward by both hand-tracking rewards.
        reward *= right_hand_reward * left_hand_reward
        
    # print("reward: ", reward[0].item())
    # print("object_reward: ", object_reward[0].item())
    # print("energy_penalty: ", energy_penalty[0].item())
    # reward = (object_reward) * progress_buf + left_contact_reward + left_contact_reward
    
    # print("if_contact_left: ", if_contact_left[0].item())
    # print("if_contact_right: ", if_contact_right[0].item())
    # print("right_hand_reward: ", right_hand_reward[0].item())
    # print("left_hand_reward: ", left_hand_reward[0].item())

    # print("object_rot_dist: ", object_rot_dist[0].item())
    # print("right_hand_rot_dist: ", right_hand_rot_dist[0].item())
    # print("left_hand_rot_dist: ", left_hand_rot_dist[0].item())

    # print("object_pos_dist: ", object_pos_dist[0].item())
    # print("right_hand_pos_dist: ", right_hand_pos_dist[0].item())
    # print("left_hand_pos_dist: ", left_hand_pos_dist[0].item())

    # print("object_joint_dist: ", object_joint_dist[0].item())
    # print("jittering_penalty: ", jittering_penalty[0].item())

    # Termination conditions. Start from reset_buf and flag environments whose
    # object z coordinate is <= -10.15.
    resets = torch.where(object_pos[:, 2] <= -10.15, torch.ones_like(reset_buf), reset_buf)

    # Thresholds below apply to the dead-zone-adjusted errors computed above.
    resets = torch.where(object_pos_dist >= 0.05, torch.ones_like(resets), resets)
    resets = torch.where(object_rot_dist >= 0.5, torch.ones_like(resets), resets)
    resets = torch.where(object_joint_dist >= object_joint_reset, torch.ones_like(resets), resets)
    # print(resets.shape)
    # Flagged contacts on either side also trigger a reset.
    resets = torch.where(is_left_contact, torch.ones_like(resets), resets)
    resets = torch.where(is_right_contact, torch.ones_like(resets), resets)
    # print(resets.shape)
    # resets = torch.where(left_hand_pos_dist >= 0.35, torch.ones_like(resets), resets)
    # resets = torch.where(left_hand_rot_dist >= 5.0, torch.ones_like(resets), resets)

    # resets = torch.where(right_hand_pos_dist >= 0.35, torch.ones_like(resets), resets)
    # resets = torch.where(right_hand_rot_dist >= 5.0, torch.ones_like(resets), resets)

    # reward = torch.where(resets == 1, reward - 10, reward)

    # Goal resets: start from reset_goal_buf and flag environments whose object
    # z coordinate is <= -10; the result is accumulated into `successes`.
    goal_resets = torch.where(
        object_pos[:, 2] <= -10, torch.ones_like(reset_goal_buf), reset_goal_buf
    )
    successes = successes + goal_resets

    # Reset once the step counter reaches the per-environment end step.
    resets = torch.where(progress_buf >= end_step_buf, torch.ones_like(resets), resets)

    # When enabled, add 0.5 * fall_penalty where the episode length is reached.
    if max_consecutive_successes > 0:
        reward = torch.where(
            progress_buf >= max_episode_length, reward + 0.5 * fall_penalty, reward
        )

    num_resets = torch.sum(resets)
    finished_cons_successes = torch.sum(successes * resets.float())

    # Blend the mean success count of the environments being reset into the
    # running value with weight av_factor; unchanged when nothing resets.
    cons_successes = torch.where(
        num_resets > 0,
        av_factor * finished_cons_successes / num_resets
        + (1.0 - av_factor) * consecutive_successes,
        consecutive_successes,
    )

    return reward, resets, goal_resets, progress_buf, successes, cons_successes


@torch.jit.script
def randomize_rotation(rand0, rand1, x_unit_tensor, y_unit_tensor):
    """Compose two angle-axis quaternions scaled from random inputs.

    The first factor uses angle `rand0 * pi` with `x_unit_tensor`, the second
    uses angle `rand1 * pi` with `y_unit_tensor`; the result is
    `quat_mul(first, second)`.
    """
    first = quat_from_angle_axis(rand0 * np.pi, x_unit_tensor)
    second = quat_from_angle_axis(rand1 * np.pi, y_unit_tensor)
    return quat_mul(first, second)


@torch.jit.script
def randomize_rotation_pen(rand0, rand1, max_angle, x_unit_tensor, y_unit_tensor, z_unit_tensor):
    """Compose two angle-axis quaternions, both driven by `rand0`.

    The first factor uses angle `0.5 * pi + rand0 * max_angle` with
    `x_unit_tensor`; the second uses angle `rand0 * pi` with `z_unit_tensor`.

    `rand1` and `y_unit_tensor` are accepted but not used.
    """
    first = quat_from_angle_axis(0.5 * np.pi + rand0 * max_angle, x_unit_tensor)
    second = quat_from_angle_axis(rand0 * np.pi, z_unit_tensor)
    return quat_mul(first, second)

@torch.jit.script
def orientation_error(desired, current):
    """Return a 3-component orientation error between two quaternion batches.

    Computes `q = quat_mul(desired, quat_conjugate(current))` and returns the
    first three columns of `q` multiplied by the sign of its fourth column.
    """
    delta = quat_mul(desired, quat_conjugate(current))
    sign = torch.sign(delta[:, 3]).unsqueeze(-1)
    return delta[:, 0:3] * sign

@torch.jit.script
def control_ik(j_eef, device: str, dpose, num_envs):
    """Solve a damped least-squares step for a batch of Jacobians.

    Evaluates `J^T (J J^T + damping^2 * I_6)^-1 dpose` with `damping = 0.05`
    and reshapes the result to `(num_envs, -1)`.

    Args:
        j_eef: Batched Jacobian; dims 1 and 2 are transposed to form `J^T`.
        device: Device on which the 6x6 identity is created.
        dpose: Right-hand side multiplied by the damped pseudo-inverse.
        num_envs: Leading size used for the final `view`.
    """
    damping = 0.05
    jac_t = j_eef.transpose(1, 2)
    regularizer = torch.eye(6, device=device) * (damping**2)
    damped_pinv = jac_t @ torch.inverse(j_eef @ jac_t + regularizer)
    return (damped_pinv @ dpose).view(num_envs, -1)

@torch.jit.script
def quaternion_to_euler_xyz(quaternion):
    """Convert (x, y, z, w) quaternions to (roll, pitch, yaw) angles.

    Args:
        quaternion (torch.Tensor): Quaternions whose last dimension holds the
            four components in (x, y, z, w) order.

    Returns:
        torch.Tensor: Roll, pitch and yaw stacked along a new last dimension.
        Pitch saturates at +/- pi/2 when the magnitude of its sine term is >= 1.
    """
    qx, qy, qz, qw = torch.unbind(quaternion, dim=-1)

    # Roll.
    roll = torch.atan2(2 * (qw * qx + qy * qz), 1 - 2 * (qx * qx + qy * qy))

    # Pitch, saturated where asin would be out of its domain.
    sin_pitch = 2 * (qw * qy - qz * qx)
    saturated = torch.sign(sin_pitch) * (torch.pi / 2)
    pitch = torch.where(torch.abs(sin_pitch) >= 1, saturated, torch.asin(sin_pitch))

    # Yaw.
    yaw = torch.atan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz))

    return torch.stack((roll, pitch, yaw), dim=-1)


@torch.jit.script
def xyzw_quaternion_to_euler_xyz(quaternion):
    """Turn quaternions given as (x, y, z, w) into roll / pitch / yaw angles.

    Args:
        quaternion (torch.Tensor): Tensor whose last dimension has the four
            quaternion components ordered (x, y, z, w).

    Returns:
        torch.Tensor: Tensor with (roll, pitch, yaw) along its last dimension.
        Where the pitch sine term has magnitude >= 1, pitch is set to
        +/- pi/2 instead of evaluating asin.
    """
    x, y, z, w = quaternion.unbind(dim=-1)
    half_pi = torch.pi / 2

    roll_num = 2 * (w * x + y * z)
    roll_den = 1 - 2 * (x * x + y * y)

    pitch_sin = 2 * (w * y - z * x)
    out_of_domain = torch.abs(pitch_sin) >= 1

    yaw_num = 2 * (w * z + x * y)
    yaw_den = 1 - 2 * (y * y + z * z)

    roll = torch.atan2(roll_num, roll_den)
    pitch = torch.where(out_of_domain, torch.sign(pitch_sin) * half_pi, torch.asin(pitch_sin))
    yaw = torch.atan2(yaw_num, yaw_den)

    return torch.stack((roll, pitch, yaw), dim=-1)

@torch.jit.script
def quaternion_to_rotation_matrix(quaternion):
    """Build 3x3 rotation matrices from a batch of (x, y, z, w) quaternions.

    Args:
        quaternion (torch.Tensor): Quaternions of shape (batch_size, 4), with
            the scalar component last.

    Returns:
        torch.Tensor: Matrices of shape (batch_size, 3, 3) with the same dtype
        and device as the input.
    """
    qx, qy, qz, qw = quaternion.unbind(dim=-1)
    batch_size = quaternion.size(0)

    # Nine entries in row-major order, stacked then folded into 3x3.
    entries = torch.stack(
        [
            1 - 2*qy*qy - 2*qz*qz,
            2*qx*qy - 2*qw*qz,
            2*qx*qz + 2*qw*qy,
            2*qx*qy + 2*qw*qz,
            1 - 2*qx*qx - 2*qz*qz,
            2*qy*qz - 2*qw*qx,
            2*qx*qz - 2*qw*qy,
            2*qy*qz + 2*qw*qx,
            1 - 2*qx*qx - 2*qy*qy,
        ],
        dim=-1,
    )
    return entries.view(batch_size, 3, 3)