# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import os
import torch
import random

from isaacgym import gymtorch
from isaacgym import gymapi
from isaacgym.torch_utils import *
from tasks.hand_base.base_task import BaseTask

from tasks.hand_base.vec_task import VecTask

from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import matplotlib.pyplot as plt
from PIL import Image as Im
import cv2

from einops import rearrange
import pickle
from utils.cnn_module import FeatureTunk

class VValue(nn.Module):
    """Learned value head over stacked image observations.

    Thin wrapper around a ``FeatureTunk`` backbone built with
    ``pretrained=False``; the channel count of the input and the size of the
    output are forwarded as ``input_dim`` / ``output_dim``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        backbone_kwargs = dict(pretrained=False, input_dim=input_dim, output_dim=output_dim)
        self.feature_tunk = FeatureTunk(**backbone_kwargs)

    def forward(self, inputs):
        """Divide ``inputs`` by 255 and return the backbone's output.

        The division presumably maps 8-bit pixel values into [0, 1] — confirm
        against the image producer.
        """
        scaled = inputs / 255.0
        return self.feature_tunk(scaled)

# class VValue(nn.Module):
#     def __init__(self, input_channel, output_channel) :
#         super(VValue, self).__init__()
#         self._input_channel = input_channel
#         self._output_channel = output_channel
#         self.nets = nn.Sequential(
#             torch.nn.Conv2d(input_channel, 64, kernel_size=7, stride=2, padding=3),
#             torch.nn.ELU(),
#             torch.nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0),
#             torch.nn.ELU(),
#             torch.nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
#             torch.nn.ELU(),
#             torch.nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
#             torch.nn.ELU(),
#             torch.nn.Flatten(),
#             nn.Linear(131072, 1),
#             torch.nn.ELU(),
#         )

#     def forward(self, inputs):
#         # inputs = inputs / 255.0
#         x = self.nets(inputs)
#         return x

class TemporaryGrad(object):
    """Context manager that turns autograd on for the duration of a ``with`` block.

    The global grad mode (``torch.is_grad_enabled()``) seen on entry is saved
    and restored on exit, so the manager can be used inside a ``torch.no_grad()``
    region to run an optimisation step. Exceptions raised inside the block are
    never suppressed.
    """

    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(True)
        # Return the manager so ``with TemporaryGrad() as ctx:`` binds it
        # (previously this bound ``None``).
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        torch.set_grad_enabled(self.prev)
        # Falsy return value: let any exception from the block propagate.
        return False


class AllegroHandLegoRetrieveGraspVValue(BaseTask):

    def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, headless, agent_index=[[[0, 1, 2, 3, 4, 5]], [[0, 1, 2, 3, 4, 5]]], is_multi_agent=False):
        """Set up the Allegro-hand LEGO retrieve/grasp task with a learned V-value head.

        Reads task/env hyper-parameters from ``cfg``, writes the derived
        observation/state/action sizes and device settings back into ``cfg``,
        calls ``BaseTask.__init__``, wraps the Isaac Gym state tensors,
        allocates per-env buffers, loads a pre-trained retrieval policy and a
        pickled set of retrieval terminal states, and builds the ``VValue``
        network together with its Adam optimizer.

        Args:
            cfg: nested config dict; ``cfg["task"]`` and ``cfg["env"]`` are read.
            sim_params: Isaac Gym ``SimParams`` (``dt`` is read).
            physics_engine: Isaac Gym physics backend, stored for ``create_sim``.
            device_type, device_id, headless: copied into ``cfg`` for ``BaseTask``.
            agent_index: stored as-is on ``self.agent_index``.
            is_multi_agent: stored as-is on ``self.is_multi_agent``.

        Raises:
            ValueError: if ``cfg["env"]["observationType"]`` is not an accepted type.
            KeyError: if the observation type has no entry in ``num_obs_dict``
                (currently only ``"partial_contact"`` does).
        """
        self.cfg = cfg

        self.sim_params = sim_params
        self.physics_engine = physics_engine
        # NOTE(review): the default ``agent_index`` is a mutable list shared by
        # every call that omits it; it is stored by reference, so never mutate it.
        self.agent_index = agent_index

        self.is_multi_agent = is_multi_agent

        self.randomize = self.cfg["task"]["randomize"]
        self.randomization_params = self.cfg["task"]["randomization_params"]

        self.aggregate_mode = self.cfg["env"]["aggregateMode"]

        # Reward shaping and success / failure criteria.
        self.dist_reward_scale = self.cfg["env"]["distRewardScale"]
        self.rot_reward_scale = self.cfg["env"]["rotRewardScale"]
        self.action_penalty_scale = self.cfg["env"]["actionPenaltyScale"]
        self.success_tolerance = self.cfg["env"]["successTolerance"]
        self.reach_goal_bonus = self.cfg["env"]["reachGoalBonus"]
        self.fall_dist = self.cfg["env"]["fallDistance"]
        self.fall_penalty = self.cfg["env"]["fallPenalty"]
        self.rot_eps = self.cfg["env"]["rotEps"]
        # NOTE(review): overridden with a hard-coded 45 further below.
        self.hand_reset_step = self.cfg["env"]["handResetStep"]

        self.vel_obs_scale = 0.2  # scale factor of velocity based observations
        self.force_torque_obs_scale = 10.0  # scale factor of force/torque based observations

        # Reset-time noise.
        self.reset_position_noise = self.cfg["env"]["resetPositionNoise"]
        self.reset_rotation_noise = self.cfg["env"]["resetRotationNoise"]
        self.reset_dof_pos_noise = self.cfg["env"]["resetDofPosRandomInterval"]
        self.reset_dof_vel_noise = self.cfg["env"]["resetDofVelRandomInterval"]

        # Optional random external forces on the object.
        self.force_scale = self.cfg["env"].get("forceScale", 0.0)
        self.force_prob_range = self.cfg["env"].get("forceProbRange", [0.001, 0.1])
        self.force_decay = self.cfg["env"].get("forceDecay", 0.99)
        self.force_decay_interval = self.cfg["env"].get("forceDecayInterval", 0.08)
        self.rotation_axis = "y"
        # Axis name -> index: x -> 0, y -> 1, anything else -> 2.
        if self.rotation_axis == "x":
            self.rotation_id = 0
        elif self.rotation_axis == "y":
            self.rotation_id = 1
        else:
            self.rotation_id = 2

        self.shadow_hand_dof_speed_scale = self.cfg["env"]["dofSpeedScale"]
        self.use_relative_control = self.cfg["env"]["useRelativeControl"]
        self.act_moving_average = self.cfg["env"]["actionsMovingAverage"]

        self.debug_viz = self.cfg["env"]["enableDebugVis"]

        self.max_episode_length = self.cfg["env"]["episodeLength"]
        self.reset_time = self.cfg["env"].get("resetTime", -1.0)
        self.print_success_stat = self.cfg["env"]["printNumSuccesses"]
        self.max_consecutive_successes = self.cfg["env"]["maxConsecutiveSuccesses"]
        self.av_factor = self.cfg["env"].get("averFactor", 0.1)

        self.object_type = self.cfg["env"]["objectType"]
        self.spin_coef = self.cfg["env"].get("spin_coef", 1.0)
        assert self.object_type in ["block", "egg", "pen"]

        self.ignore_z = (self.object_type == "pen")

        self.robot_asset_files_dict = {
            "normal": "urdf/franka_description/robots/franka_panda_allegro.urdf",
            "large":  "urdf/xarm6/xarm6_allegro_left_fsr_large.urdf"
        }
        self.asset_files_dict = {
            "block": "urdf/objects/cube_multicolor.urdf",
            "egg": "mjcf/box/mobility.urdf",
            "pen": "mjcf/open_ai_assets/hand/pen.xml"
        }

        # Accepted observation types; see ``valid_obs_types`` below.
        self.obs_type = self.cfg["env"]["observationType"]

        valid_obs_types = ["full_no_vel", "full", "full_state", "full_contact", "partial_contact"]
        if self.obs_type not in valid_obs_types:
            # ValueError subclasses Exception, so existing handlers still catch it.
            raise ValueError(
                "Unknown type of observations!\nobservationType should be one of: {}".format(valid_obs_types))

        print("Obs type:", self.obs_type)

        self.palm_name = "palm"
        self.contact_sensor_names = ["link_1.0_fsr", "link_2.0_fsr", "link_3.0_tip_fsr",
                                     "link_5.0_fsr", "link_6.0_fsr", "link_7.0_tip_fsr", "link_9.0_fsr",
                                     "link_10.0_fsr", "link_11.0_tip_fsr", "link_14.0_fsr", "link_15.0_fsr",
                                     "link_15.0_tip_fsr"]
        self.fingertip_names = ["link_3.0_tip",
                                "link_7.0_tip",
                                "link_11.0_tip",
                                "link_15.0_tip"]
        # 11, 13, 16, 20, 22, 24, 27, 29, 32, 36, 39, 40
        # self.contact_sensor_names = ["link_1.0", "link_2.0", "link_3.0_tip",
        #                              "link_5.0", "link_6.0", "link_7.0_tip", "link_9.0",
        #                              "link_10.0", "link_11.0_tip", "link_14.0", "link_15.0",
        #                              "link_15.0_tip"]
        # Number of frames concatenated into one observation / state vector.
        self.stack_obs = 3
        # self.num_obs_dict = {
        #     "full_no_vel": 50,
        #     "full": 72,
        #     "full_state": 88,
        #     "full_contact": 90,
        #     "partial_contact": 74 + 128*128*4
        # }

        # Per-frame observation size; only "partial_contact" is configured, so
        # any other accepted obs type raises KeyError below.
        self.num_obs_dict = {
            "partial_contact": 67
        }
        self.up_axis = 'z'

        self.use_vel_obs = False
        self.fingertip_obs = True
        self.asymmetric_obs = self.cfg["env"]["asymmetric_observations"]

        num_states = 0
        if self.asymmetric_obs:
            num_states = 154

        # BaseTask sizes its buffers from these cfg entries: one frame times
        # the number of stacked frames.
        self.one_frame_num_obs = self.num_obs_dict[self.obs_type]
        self.one_frame_num_states = num_states
        self.cfg["env"]["numObservations"] = self.num_obs_dict[self.obs_type] * self.stack_obs
        self.cfg["env"]["numStates"] = num_states * self.stack_obs
        self.cfg["env"]["numActions"] = 23

        self.cfg["device_type"] = device_type
        self.cfg["device_id"] = device_id
        self.cfg["headless"] = headless

        self.enable_camera_sensors = self.cfg["env"]["enable_camera_sensors"]

        # NOTE(review): attributes used below (num_arm_hand_dofs, envs,
        # hand_indices, object_indices, ...) are presumably set while
        # BaseTask.__init__ drives create_sim/_create_envs — confirm.
        super().__init__(cfg=self.cfg, enable_camera_sensors=self.enable_camera_sensors)

        if self.viewer is not None:
            cam_pos = gymapi.Vec3(0.5, -0.1, 1.5)
            cam_target = gymapi.Vec3(-0.7, -0.1, 0.0)
            self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)

        self.dt = self.sim_params.dt
        control_freq_inv = self.cfg["env"].get("controlFrequencyInv", 1)
        if self.reset_time > 0.0:
            # Convert a reset time in seconds into a number of control steps.
            self.max_episode_length = int(round(self.reset_time/(control_freq_inv * self.dt)))
            print("Reset time: ", self.reset_time)
            print("New episode length: ", self.max_episode_length)

        # get gym GPU state tensors
        actor_root_state_tensor = self.gym.acquire_actor_root_state_tensor(self.sim)
        dof_state_tensor = self.gym.acquire_dof_state_tensor(self.sim)
        rigid_body_tensor = self.gym.acquire_rigid_body_state_tensor(self.sim)
        contact_tensor = self.gym.acquire_net_contact_force_tensor(self.sim)
        self.jacobian_tensor = gymtorch.wrap_tensor(self.gym.acquire_jacobian_tensor(self.sim, "hand"))

        self.gym.refresh_actor_root_state_tensor(self.sim)
        self.gym.refresh_dof_state_tensor(self.sim)
        self.gym.refresh_rigid_body_state_tensor(self.sim)
        self.gym.refresh_net_contact_force_tensor(self.sim)

        # create some wrapper tensors for different slices
        # Default joint configuration: first 7 entries are the arm, the rest the hand.
        self.arm_hand_default_dof_pos = torch.zeros(self.num_arm_hand_dofs, dtype=torch.float, device=self.device)
        self.arm_hand_default_dof_pos[:7] = torch.tensor([0.9467, -0.5708, -2.4997, -2.3102, -0.7739,  2.6616, -0.9208], dtype=torch.float, device=self.device)

        self.arm_hand_default_dof_pos[7:] = to_torch([0.0, -0.174, 0.785, 0.785,
                                            0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785], dtype=torch.float, device=self.device)

        self.arm_hand_prepare_dof_poses = torch.zeros((self.num_envs, self.num_arm_hand_dofs), dtype=torch.float, device=self.device)
        self.end_effector_rotation = torch.zeros((self.num_envs, 4), dtype=torch.float, device=self.device)

        # Candidate "prepare" joint poses and their matching end-effector
        # rotations; entry i of each list belongs together.
        self.arm_hand_prepare_dof_pos_list = []
        self.end_effector_rot_list = []

        # rot = [0, 0.707, 0, 0.707]
        self.arm_hand_prepare_dof_pos = to_torch([-0.3463, -0.3414,  0.4400, -2.7079,  0.2244,  2.3851, -0.0901,
                                                0.0, -0.174, 0.785, 0.785,
                                            0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785], dtype=torch.float, device=self.device)
        self.arm_hand_prepare_dof_pos_list.append(self.arm_hand_prepare_dof_pos)
        self.end_effector_rot_list.append(to_torch([0, 0.707, 0, 0.707], device=self.device))

        # face forward
        self.arm_hand_prepare_dof_pos = to_torch([-1.4528e-02,  2.3290e-01,  1.5519e-02, -2.7374e+00,  8.7328e-04, 4.5402e+00,  3.1363e+00,
                                                0.0, -0.174, 0.785, 0.785,
                                            0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785], dtype=torch.float, device=self.device)
        self.arm_hand_prepare_dof_pos_list.append(self.arm_hand_prepare_dof_pos)
        self.end_effector_rot_list.append(to_torch([1, 0., 0., 0.], device=self.device))

        # face right, [-0.4227, -0.6155, -0.3687, -0.5537]  0.0276,  0.0870, -0.4854, -2.6056,  1.2111,  1.3671, -1.1870
        # self.arm_hand_prepare_dof_pos = to_torch([1.0260,  0.0671, 0.42, -2.4576, -0.25,  3.7172,  1.82,
        #                                         0.0, -0.174, 0.785, 0.785,
        #                                     0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785], dtype=torch.float, device=self.device)
        self.arm_hand_prepare_dof_pos = to_torch([0.1707,  0.0737, -0.5725, -2.4737,  1.2567,  1.3162, -1.0150,
                                                0.0, -0.174, 0.785, 0.785,
                                            0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785], dtype=torch.float, device=self.device)
        self.arm_hand_prepare_dof_pos_list.append(self.arm_hand_prepare_dof_pos)
        self.end_effector_rot_list.append(to_torch([0.5, 0.5, 0.5, 0.5], device=self.device))

        # face left, [ 0.4175, -0.5494,  0.4410, -0.5739] -1.5712, -1.5254,  1.7900, -2.2848,  3.1094,  3.7490, -2.8722
        # self.arm_hand_prepare_dof_pos = to_torch([1.0260,  0.0671, -2.72, -2.4576, -0.25,  3.7172,  -1.32,
        #                                         0.0, -0.174, 0.785, 0.785,
        #                                     0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785], dtype=torch.float, device=self.device)
        self.arm_hand_prepare_dof_pos = to_torch([-0.4006, -0.1464,  0.7419, -2.3031, -1.2898,  1.3568, -0.9339,
                                                0.0, -0.174, 0.785, 0.785,
                                            0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785, 0.0, -0.174, 0.785, 0.785], dtype=torch.float, device=self.device)
        self.arm_hand_prepare_dof_pos_list.append(self.arm_hand_prepare_dof_pos)
        self.end_effector_rot_list.append(to_torch([-0.707, 0.707, 0.0, -0.0], device=self.device))

        # Views into the flat DOF-state tensor: [..., 0] is position, [..., 1] velocity.
        self.dof_state = gymtorch.wrap_tensor(dof_state_tensor)
        self.arm_hand_dof_state = self.dof_state.view(self.num_envs, -1, 2)[:, :self.num_arm_hand_dofs]
        self.arm_hand_dof_pos = self.arm_hand_dof_state[..., 0]
        self.arm_hand_dof_vel = self.arm_hand_dof_state[..., 1]

        self.rigid_body_states = gymtorch.wrap_tensor(rigid_body_tensor).view(self.num_envs, -1, 13)
        self.num_bodies = self.rigid_body_states.shape[1]

        self.root_state_tensor = gymtorch.wrap_tensor(actor_root_state_tensor).view(-1, 13)
        self.contact_tensor = gymtorch.wrap_tensor(contact_tensor).view(self.num_envs, -1)
        print("Contact Tensor Dimension", self.contact_tensor.shape)

        self.num_dofs = self.gym.get_sim_dof_count(self.sim) // self.num_envs
        print("Num dofs: ", self.num_dofs)

        self.prev_targets = torch.zeros((self.num_envs, self.num_dofs), dtype=torch.float, device=self.device)
        self.cur_targets = torch.zeros((self.num_envs, self.num_dofs), dtype=torch.float, device=self.device)

        self.global_indices = torch.arange(self.num_envs * 3, dtype=torch.int32, device=self.device).view(self.num_envs, -1)
        self.x_unit_tensor = to_torch([1, 0, 0], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))
        self.y_unit_tensor = to_torch([0, 1, 0], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))
        self.z_unit_tensor = to_torch([0, 0, 1], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))

        self.reset_goal_buf = self.reset_buf.clone()
        self.successes = torch.zeros(self.num_envs, dtype=torch.float, device=self.device)
        self.consecutive_successes = torch.zeros(1, dtype=torch.float, device=self.device)

        self.av_factor = to_torch(self.av_factor, dtype=torch.float, device=self.device)

        self.total_successes = 0
        self.total_resets = 0
        self.total_steps = 0

        # object apply random forces parameters
        self.force_decay = to_torch(self.force_decay, dtype=torch.float, device=self.device)
        self.force_prob_range = to_torch(self.force_prob_range, dtype=torch.float, device=self.device)
        # Per-env force probability sampled log-uniformly within force_prob_range.
        self.random_force_prob = torch.exp((torch.log(self.force_prob_range[0]) - torch.log(self.force_prob_range[1]))
                                           * torch.rand(self.num_envs, device=self.device) + torch.log(self.force_prob_range[1]))

        self.rb_forces = torch.zeros((self.num_envs, self.num_bodies, 3), dtype=torch.float, device=self.device)
        self.object_pose_for_open_loop = torch.zeros_like(self.root_state_tensor[self.object_indices, 0:7])

        self.hand_base_rigid_body_index = self.gym.find_actor_rigid_body_index(self.envs[0], self.hand_indices[0], "base_link", gymapi.DOMAIN_ENV)
        print("hand_base_rigid_body_index: ", self.hand_base_rigid_body_index)

        # hand_pos_history is allocated once, in the retrieval section below
        # (the earlier allocation sized by max_episode_length was always overwritten).
        self.segmentation_object_center_point_x = torch.zeros((self.num_envs, 1), dtype=torch.int, device=self.device)
        self.segmentation_object_center_point_y = torch.zeros((self.num_envs, 1), dtype=torch.int, device=self.device)
        self.segmentation_object_point_num = torch.zeros((self.num_envs, 1), dtype=torch.int, device=self.device)

        self.meta_obs_buf = torch.zeros(
            (self.num_envs, self.num_obs), device=self.device, dtype=torch.float)
        self.meta_states_buf = torch.zeros(
            (self.num_envs, self.num_states), device=self.device, dtype=torch.float)
        self.meta_rew_buf = torch.zeros(
            self.num_envs, device=self.device, dtype=torch.float)
        self.meta_reset_buf = torch.ones(
            self.num_envs, device=self.device, dtype=torch.long)
        self.meta_progress_buf = torch.zeros(
            self.num_envs, device=self.device, dtype=torch.long)

        # Every env uses prepare pose / rotation index 3 (the "face left" entry above).
        self.arm_hand_prepare_dof_poses[:, :] = self.arm_hand_prepare_dof_pos_list[3]
        self.end_effector_rotation[:, :] = self.end_effector_rot_list[3]
        self.allegro_dof_low_level_action = torch.zeros((self.num_envs, 16), dtype=torch.float, device=self.device)

        # One zero-initialised single-frame buffer per stacked frame.
        self.state_buf_stack_frames = []
        self.obs_buf_stack_frames = []
        self.student_obs_buf_stack_frames = []
        for i in range(self.stack_obs):
            self.obs_buf_stack_frames.append(torch.zeros_like(self.obs_buf[:, 0:self.one_frame_num_obs]))
            self.state_buf_stack_frames.append(torch.zeros_like(self.states_buf[:, 0:self.one_frame_num_states]))
            self.student_obs_buf_stack_frames.append(torch.zeros_like(self.obs_buf[:, 0:self.one_frame_num_obs]))

        # One-hot object id per env, cycling through 12 slots.
        self.multi_object_index = torch.zeros((self.num_envs, 12), device=self.device, dtype=torch.float)
        for i in range(self.num_envs):
            self.multi_object_index[i, i % 12] = 1

        # retrieval config
        import copy
        from utils.robot_controller.nn_builder import build_network
        from utils.robot_controller.nn_controller import NNController
        from rl_games.algos_torch import torch_ext
        import time

        # Pre-trained retrieval policy loaded from a checkpoint (paths are cwd-relative).
        self.retrieve_policy = NNController(num_actors=1, config_path='./utils/robot_controller/network.yaml')
        self.retrieve_policy.load('./utils/robot_controller/models/last_AllegroHandLego_ep_10800_rew_24.16554.pth')

        self.retrieve_num_obs = 81
        self.retrieve_obs_buf = torch.zeros(
            (self.num_envs, self.retrieve_num_obs), device=self.device, dtype=torch.float)
        self.retrieve_num_states = 129
        self.retrieve_states_buf = torch.zeros(
            (self.num_envs, self.retrieve_num_states), device=self.device, dtype=torch.float)

        # NOTE(review): overrides cfg["env"]["handResetStep"] read earlier.
        self.hand_reset_step = 45
        self.hand_pos_history = torch.zeros((self.num_envs, 361, 3), dtype=torch.float, device=self.device)
        self.extras["retri_rew_buf"] = torch.zeros(
            self.num_envs, device=self.device, dtype=torch.float)
        self.extras["retri_reset_buf"] = torch.ones(
            self.num_envs, device=self.device, dtype=torch.long)

        # NOTE(review): pickle.load can execute arbitrary code; only load this
        # file from a trusted source.
        with open("demonstration/saved_retrieval_ternimal_states_good.pkl", "rb") as f:
            self.saved_retrieval_ternimal_states = pickle.load(f)

        # self.v_value = VValue(input_channel=6, output_channel=32).to(self.device)
        self.v_value = VValue(input_dim=6, output_dim=2).to(self.device)
        # Make sure every V-value parameter is trainable.
        for param in self.v_value.parameters():
            param.requires_grad_(True)
        self.retrieval_terminal_image_buf = torch.zeros((self.num_envs, 6, 128, 128), dtype=torch.float, device=self.device)
        self.retrieval_terminal_image_buf_for_training = torch.zeros((self.num_envs, 6, 128, 128), dtype=torch.float, device=self.device)
        # self.is_test = self.cfg["env"]["test"]
        self.is_test = False

        self.v_value_optimizer = optim.Adam(self.v_value.parameters(), lr=0.0003)
        # Sample the clock once so every timestamp field comes from the same instant
        # (six separate localtime() calls could straddle a second/minute boundary).
        now = time.localtime()
        self.v_value_save_path = "./demonstration/v_value/{}-{}-{}_{}:{}:{}".format(*now[:6])
        os.makedirs(self.v_value_save_path)
        # self.v_value.load_state_dict(torch.load("/home/jmji/DexterousHandEnvs/dexteroushandenvs/demonstration/v_value/2023-1-19_1:39:15/model_5.pt", map_location='cuda:0'))
        self.bce_logits_loss = torch.nn.BCEWithLogitsLoss()

        if self.is_test:
            self.v_value.eval()
        else:
            self.v_value.train()

        self.success_buf = torch.zeros((self.num_envs, 2), dtype=torch.float32, device=self.device)
        self.predict_success_confident = torch.zeros((self.num_envs, 2), dtype=torch.float32, device=self.device)

        # First num_envs saved terminal states; presumably shaped like
        # root_state_tensor viewed as (num_envs, actors_per_env, 13) — TODO confirm.
        self.saved_retrieval_ternimal_state = self.saved_retrieval_ternimal_states[0:self.num_envs].clone()
        self.saved_retrieval_ternimal_states_index = 0

        # NOTE(review): creating cv2 windows needs a display, even when headless=True.
        cv2.namedWindow("DEBUG_RGB_VIS", 0)
        cv2.namedWindow("DEBUG_SEG_VIS", 0)

    def create_sim(self):
        """Create the physics simulation, the ground plane and all environments.

        Caches ``dt`` and the up-axis index, coerces the PhysX
        ``max_gpu_contact_pairs`` setting to ``int``, builds ``self.sim`` via the
        parent class, then adds the ground and ``num_envs`` environments with
        ``int(sqrt(num_envs))`` environments per row.
        """
        params = self.sim_params
        self.dt = params.dt
        self.up_axis_idx = self.set_sim_params_up_axis(params, self.up_axis)
        physx_cfg = params.physx
        physx_cfg.max_gpu_contact_pairs = int(physx_cfg.max_gpu_contact_pairs)
        # self.sim_params.dt = 1./120.

        self.sim = super().create_sim(self.device_id, self.graphics_device_id, self.physics_engine, params)
        self._create_ground_plane()

        env_count = self.num_envs
        spacing = self.cfg["env"]['envSpacing']
        per_row = int(np.sqrt(env_count))
        self._create_envs(env_count, spacing, per_row)

    def _create_ground_plane(self):
        """Add a ground plane with normal (0, 0, 1) to ``self.sim``."""
        ground = gymapi.PlaneParams()
        ground.normal = gymapi.Vec3(0.0, 0.0, 1.0)
        self.gym.add_ground(self.sim, ground)

    def _create_envs(self, num_envs, spacing, num_per_row):
        lower = gymapi.Vec3(-spacing, -spacing, 0.0)
        upper = gymapi.Vec3(spacing, spacing, spacing)

        asset_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../assets')

        arm_hand_asset_file = self.robot_asset_files_dict["normal"]
        # arm_hand_asset_file = "urdf/xarm6/xarm6_allegro_left.urdf"
        #"urdf/xarm6/xarm6_allegro_fsr.urdf"

        if "asset" in self.cfg["env"]:
            asset_root = self.cfg["env"]["asset"].get("assetRoot", asset_root)
            # arm_hand_asset_file = self.cfg["env"]["asset"].get("assetFileName", arm_hand_asset_file)

        object_asset_file = self.asset_files_dict[self.object_type]

        # load arm and hand.
        asset_options = gymapi.AssetOptions()
        asset_options.flip_visual_attachments = False
        asset_options.fix_base_link = True
        asset_options.collapse_fixed_joints = False
        asset_options.disable_gravity = True
        asset_options.thickness = 0.001
        asset_options.angular_damping = 0.01
        # asset_options.use_mesh_materials = True
        asset_options.mesh_normal_mode = gymapi.COMPUTE_PER_VERTEX
        asset_options.override_com = True
        asset_options.override_inertia = True
        asset_options.vhacd_enabled = True
        asset_options.vhacd_params = gymapi.VhacdParams()
        asset_options.vhacd_params.resolution = 2000000
        asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE

        if self.physics_engine == gymapi.SIM_PHYSX:
            asset_options.use_physx_armature = True
        arm_hand_asset = self.gym.load_asset(self.sim, asset_root, arm_hand_asset_file, asset_options)
        self.num_arm_hand_bodies = self.gym.get_asset_rigid_body_count(arm_hand_asset)
        self.num_arm_hand_shapes = self.gym.get_asset_rigid_shape_count(arm_hand_asset)
        self.num_arm_hand_dofs = self.gym.get_asset_dof_count(arm_hand_asset)
        print("Num dofs: ", self.num_arm_hand_dofs)
        print("num_arm_hand_shapes: ", self.num_arm_hand_shapes)
        print("num_arm_hand_bodies: ", self.num_arm_hand_bodies)
        self.num_arm_hand_actuators = self.num_arm_hand_dofs #self.gym.get_asset_actuator_count(shadow_hand_asset)

        # Set up each DOF.
        self.actuated_dof_indices = [i for i in range(7, self.num_arm_hand_dofs)]

        self.arm_hand_dof_lower_limits = []
        self.arm_hand_dof_upper_limits = []
        self.arm_hand_dof_default_pos = []
        self.arm_hand_dof_default_vel = []

        robot_lower_qpos = []
        robot_upper_qpos = []

        robot_dof_props = self.gym.get_asset_dof_properties(arm_hand_asset)

        for i in range(23):
            robot_dof_props['driveMode'][i] = gymapi.DOF_MODE_POS
            if i < 3:
                robot_dof_props['stiffness'][i] = 400
                robot_dof_props['effort'][i] = 200
                robot_dof_props['damping'][i] = 80
            elif i < 7:
                robot_dof_props['stiffness'][i] = 400
                robot_dof_props['effort'][i] = 200
                robot_dof_props['damping'][i] = 80
            else:
                robot_dof_props['velocity'][i] = 3.0
                robot_dof_props['stiffness'][i] = 30
                robot_dof_props['effort'][i] = 5
                robot_dof_props['damping'][i] = 1

            robot_lower_qpos.append(robot_dof_props['lower'][i])
            robot_upper_qpos.append(robot_dof_props['upper'][i])

        self.actuated_dof_indices = to_torch(self.actuated_dof_indices, dtype=torch.long, device=self.device)
        self.arm_hand_dof_lower_limits = to_torch(robot_lower_qpos, device=self.device)
        self.arm_hand_dof_upper_limits = to_torch(robot_upper_qpos, device=self.device)
        self.arm_hand_dof_lower_qvel = to_torch(-robot_dof_props["velocity"], device=self.device)
        self.arm_hand_dof_upper_qvel = to_torch(robot_dof_props["velocity"], device=self.device)

        for i in range(self.num_arm_hand_dofs):
            self.arm_hand_dof_default_vel.append(0.0)

        self.arm_hand_dof_default_pos = to_torch(self.arm_hand_dof_default_pos, device=self.device)
        self.arm_hand_dof_default_vel = to_torch(self.arm_hand_dof_default_vel, device=self.device)

        # load manipulated object and goal assets
        object_asset_options = gymapi.AssetOptions()
        object_asset_options.disable_gravity = True
        object_asset_options.fix_base_link = True
        # object_asset_options.mesh_normal_mode = gymapi.COMPUTE_PER_VERTEX
        # object_asset_options.override_com = True
        # object_asset_options.override_inertia = True
        # object_asset_options.vhacd_enabled = True
        # object_asset_options.vhacd_params = gymapi.VhacdParams()
        # object_asset_options.vhacd_params.resolution = 100000
        # object_asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE
        object_asset = self.gym.load_asset(self.sim, asset_root, object_asset_file, object_asset_options)

        object_asset_options.disable_gravity = True
        goal_asset = self.gym.load_asset(self.sim, asset_root, object_asset_file, object_asset_options)

        # Put objects in the scene.
        arm_hand_start_pose = gymapi.Transform()
        arm_hand_start_pose.p = gymapi.Vec3(-0.35, 0.0, 0.7)
        arm_hand_start_pose.r = gymapi.Quat().from_euler_zyx(0, 0, 0.0)

        # create table asset
        table_dims = gymapi.Vec3(1.0, 1.0, 0.6)
        table_asset_options = gymapi.AssetOptions()
        table_asset_options.fix_base_link = True
        table_asset_options.flip_visual_attachments = True
        table_asset_options.collapse_fixed_joints = True
        table_asset_options.disable_gravity = True
        table_asset_options.thickness = 0.001

        table_asset = self.gym.create_box(self.sim, table_dims.x, table_dims.y, table_dims.z, table_asset_options)
        table_pose = gymapi.Transform()
        table_pose.p = gymapi.Vec3(0.0, 0.0, 0.5 * table_dims.z)
        table_pose.r = gymapi.Quat().from_euler_zyx(-0., 0, 0)

        # create box asset
        box_assets = []
        box_start_poses = []

        box_thin = 0.01
        box_xyz = [0.52, 0.72, 0.3]
        box_offset = [0.1, 0, 0]

        box_asset_options = gymapi.AssetOptions()
        box_asset_options.disable_gravity = False
        box_asset_options.fix_base_link = True
        box_asset_options.flip_visual_attachments = True
        box_asset_options.collapse_fixed_joints = True
        box_asset_options.disable_gravity = True
        box_asset_options.thickness = 0.001

        box_bottom_asset = self.gym.create_box(self.sim, box_xyz[0], box_xyz[1], box_thin, table_asset_options)
        box_left_asset = self.gym.create_box(self.sim, box_xyz[0], box_thin, box_xyz[2], table_asset_options)
        box_right_asset = self.gym.create_box(self.sim, box_xyz[0], box_thin, box_xyz[2], table_asset_options)
        box_former_asset = self.gym.create_box(self.sim, box_thin, box_xyz[1], box_xyz[2], table_asset_options)
        box_after_asset = self.gym.create_box(self.sim, box_thin, box_xyz[1], box_xyz[2], table_asset_options)

        box_bottom_start_pose = gymapi.Transform()
        box_bottom_start_pose.p = gymapi.Vec3(0.0 + box_offset[0], 0.0 + box_offset[1], 0.6 + (box_thin) / 2)
        box_left_start_pose = gymapi.Transform()
        box_left_start_pose.p = gymapi.Vec3(0.0 + box_offset[0], (box_xyz[1] - box_thin) / 2 + box_offset[1], 0.6 + (box_xyz[2]) / 2)
        box_right_start_pose = gymapi.Transform()
        box_right_start_pose.p = gymapi.Vec3(0.0 + box_offset[0], -(box_xyz[1] - box_thin) / 2 + box_offset[1], 0.6 + (box_xyz[2]) / 2)
        box_former_start_pose = gymapi.Transform()
        box_former_start_pose.p = gymapi.Vec3((box_xyz[0] - box_thin) / 2 + box_offset[0], 0.0 + box_offset[1], 0.6 + (box_xyz[2]) / 2)
        box_after_start_pose = gymapi.Transform()
        box_after_start_pose.p = gymapi.Vec3(-(box_xyz[0] - box_thin) / 2 + box_offset[0], 0.0 + box_offset[1], 0.6 + (box_xyz[2]) / 2)

        box_assets.append(box_bottom_asset)
        box_assets.append(box_left_asset)
        box_assets.append(box_right_asset)
        box_assets.append(box_former_asset)
        box_assets.append(box_after_asset)
        box_start_poses.append(box_bottom_start_pose)
        box_start_poses.append(box_left_start_pose)
        box_start_poses.append(box_right_start_pose)
        box_start_poses.append(box_former_start_pose)
        box_start_poses.append(box_after_start_pose)

        object_start_pose = gymapi.Transform()
        object_start_pose.p = gymapi.Vec3(0, 0.0, -10.78)
        object_start_pose.r = gymapi.Quat().from_euler_zyx(0, 0, 1.571)

        if self.object_type == "pen":
            object_start_pose.p.z = arm_hand_start_pose.p.z + 0.02

        self.goal_displacement = gymapi.Vec3(-0.2, -0.06, -10.12)
        self.goal_displacement_tensor = to_torch(
            [self.goal_displacement.x, self.goal_displacement.y, self.goal_displacement.z], device=self.device)
        goal_start_pose = gymapi.Transform()
        goal_start_pose.p = object_start_pose.p + self.goal_displacement

        goal_start_pose.p.z -= 0.04

        lego_path = "urdf/leoCAD/urdf/"
        all_lego_files_name = os.listdir("/home/jmji/DexterousHandEnvs/assets/" + lego_path)

        lego_assets = []
        lego_start_poses = []
        self.segmentation_id = 1

        for n in range(15):
            for i, lego_file_name in enumerate(all_lego_files_name):
                lego_asset_options = gymapi.AssetOptions()
                lego_asset_options.disable_gravity = False
                # lego_asset_options.fix_base_link = True
                # lego_asset_options.mesh_normal_mode = gymapi.COMPUTE_PER_VERTEX
                # lego_asset_options.override_com = True
                # lego_asset_options.override_inertia = True
                # lego_asset_options.vhacd_enabled = True
                # lego_asset_options.vhacd_params = gymapi.VhacdParams()
                # lego_asset_options.vhacd_params.resolution = 100000
                # lego_asset_options.thickness = 0.00001
                # lego_asset_options.default_dof_drive_mode = gymapi.DOF_MODE_NONE
                # lego_asset_options.density = 1000
                lego_asset = self.gym.load_asset(self.sim, asset_root, lego_path + lego_file_name, lego_asset_options)

                lego_start_pose = gymapi.Transform()
                # if n > 0:
                #     lego_start_pose.p = gymapi.Vec3(-0.15 + 0.1 * int(i % 4) + 0.1, -0.25 + 0.1 * int(i % 24 / 4), 0.62 + 0.15 * int(i / 24) + n * 0.2 + 0.2)
                # else:
                if n % 2 == 0:
                    lego_start_pose.p = gymapi.Vec3(-0.15 + 0.15 * int(i % 3) + 0.1, -0.2 + 0.15 * int(i / 3), 0.62 + n * 0.06)
                else:
                    lego_start_pose.p = gymapi.Vec3(0.15 - 0.15 * int(i % 3) + 0.1, 0.2 - 0.15 * int(i / 3), 0.62 + n * 0.06)

                lego_start_pose.r = gymapi.Quat().from_euler_zyx(0.0, 0.0, 0.785)
                # Assets visualization
                # lego_start_pose.p = gymapi.Vec3(-0.15 + 0.2 * int(i % 18) + 0.1, 0, 0.62 + 0.2 * int(i / 18) + n * 0.8 + 5.0)
                # lego_start_pose.r = gymapi.Quat().from_euler_zyx(0.0, 0, 0)
                
                lego_assets.append(lego_asset)
                lego_start_poses.append(lego_start_pose)

        # compute aggregate size
        max_agg_bodies = self.num_arm_hand_bodies + 2 + 1 + len(lego_assets) + 5 + 10
        max_agg_shapes = self.num_arm_hand_shapes + 2 + 1 + len(lego_assets) + 5 + 10

        self.arm_hands = []
        self.envs = []

        self.object_init_state = []
        self.lego_init_states = []
        self.hand_start_states = []

        self.hand_indices = []
        self.fingertip_indices = []
        self.object_indices = []
        self.goal_object_indices = []
        self.predict_object_indices = []
        self.table_indices = []
        self.lego_indices = []
        self.lego_segmentation_indices = []

        arm_hand_rb_count = self.gym.get_asset_rigid_body_count(arm_hand_asset)
        object_rb_count = self.gym.get_asset_rigid_body_count(object_asset)
        self.object_rb_handles = list(range(arm_hand_rb_count, arm_hand_rb_count + object_rb_count))

        self.cameras = []
        self.camera_tensors = []
        self.camera_seg_tensors = []
        self.camera_view_matrixs = []
        self.camera_proj_matrixs = []

        self.camera_props = gymapi.CameraProperties()
        self.camera_props.width = 128
        self.camera_props.height = 128
        self.camera_props.enable_tensors = True

        self.env_origin = torch.zeros((self.num_envs, 3), device=self.device, dtype=torch.float)
        self.camera_u = torch.arange(0, self.camera_props.width, device=self.device)
        self.camera_v = torch.arange(0, self.camera_props.height, device=self.device)

        self.camera_v2, self.camera_u2 = torch.meshgrid(self.camera_v, self.camera_u, indexing='ij')

        for i in range(self.num_envs):
            # create env instance
            env_ptr = self.gym.create_env(
                self.sim, lower, upper, num_per_row
            )

            if self.aggregate_mode >= 1:
                self.gym.begin_aggregate(env_ptr, max_agg_bodies, max_agg_shapes, True)

            # add hand - collision filter = -1 to use asset collision filters set in mjcf loader
            arm_hand_actor = self.gym.create_actor(env_ptr, arm_hand_asset, arm_hand_start_pose, "hand", i, -1, 0)
            self.hand_start_states.append([arm_hand_start_pose.p.x,
                                           arm_hand_start_pose.p.y,
                                           arm_hand_start_pose.p.z,
                                           arm_hand_start_pose.r.x,
                                           arm_hand_start_pose.r.y,
                                           arm_hand_start_pose.r.z,
                                           arm_hand_start_pose.r.w,
                                           0, 0, 0, 0, 0, 0])
            self.gym.set_actor_dof_properties(env_ptr, arm_hand_actor, robot_dof_props)
            hand_idx = self.gym.get_actor_index(env_ptr, arm_hand_actor, gymapi.DOMAIN_SIM)
            self.hand_indices.append(hand_idx)

            # add object
            object_handle = self.gym.create_actor(env_ptr, object_asset, object_start_pose, "object", i, 0, 0)
            self.object_init_state.append([object_start_pose.p.x, object_start_pose.p.y, object_start_pose.p.z,
                                           object_start_pose.r.x, object_start_pose.r.y, object_start_pose.r.z, object_start_pose.r.w,
                                           0, 0, 0, 0, 0, 0])
            object_idx = self.gym.get_actor_index(env_ptr, object_handle, gymapi.DOMAIN_SIM)
            self.object_indices.append(object_idx)

            # add goal object
            goal_handle = self.gym.create_actor(env_ptr, goal_asset, goal_start_pose, "goal_object", i + self.num_envs, 0, 0)
            goal_object_idx = self.gym.get_actor_index(env_ptr, goal_handle, gymapi.DOMAIN_SIM)
            self.goal_object_indices.append(goal_object_idx)

            # add table
            table_handle = self.gym.create_actor(env_ptr, table_asset, table_pose, "table", i, -1, 0)
            # self.gym.set_rigid_body_texture(env_ptr, table_handle, 0, gymapi.MESH_VISUAL, table_texture_handle)
            table_idx = self.gym.get_actor_index(env_ptr, table_handle, gymapi.DOMAIN_SIM)
            self.gym.set_rigid_body_color(
                env_ptr, table_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(1, 0.9, 0.8)
            )
            self.table_indices.append(table_idx)

            table_shape_props = self.gym.get_actor_rigid_shape_properties(env_ptr, table_handle)
            for object_shape_prop in table_shape_props:
                object_shape_prop.friction = 1
            self.gym.set_actor_rigid_shape_properties(env_ptr, table_handle, table_shape_props)

            # add box
            for box_i, box_asset in enumerate(box_assets):
                box_handle = self.gym.create_actor(env_ptr, box_asset, box_start_poses[box_i], "box_{}".format(box_i), i, 0, 0)
                # self.lego_init_state.append([lego_init_state.p.x, lego_init_state.p.y, object_start_pose.p.z,
                #                             lego_init_state.r.x, lego_init_state.r.y, object_start_pose.r.z, object_start_pose.r.w,
                #                             0, 0, 0, 0, 0, 0])
                # object_idx = self.gym.get_actor_index(env_ptr, object_handle, gymapi.DOMAIN_SIM)
                # self.object_indices.append(object_idx)
                self.gym.set_rigid_body_color(env_ptr, box_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.8, 0.4, 0))

            # add lego
            color_map = [[0, 0, 0], [1, 1, 1], [1, 1, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0], [1, 0, 1], [1, 0, 0]]
            lego_idx = []
            # self.segmentation_id = i % 12

            for lego_i, lego_asset in enumerate(lego_assets):
                lego_handle = self.gym.create_actor(env_ptr, lego_asset, lego_start_poses[lego_i], "lego_{}".format(lego_i), i, 0, lego_i)
                # lego_handle = self.gym.create_actor(env_ptr, lego_asset, lego_start_poses[lego_i], "lego_{}".format(lego_i), i + self.num_envs + lego_i, -1, 0)
                self.lego_init_states.append([lego_start_poses[lego_i].p.x, lego_start_poses[lego_i].p.y, lego_start_poses[lego_i].p.z,
                                            lego_start_poses[lego_i].r.x, lego_start_poses[lego_i].r.y, lego_start_poses[lego_i].r.z, lego_start_poses[lego_i].r.w,
                                            0, 0, 0, 0, 0, 0])
                idx = self.gym.get_actor_index(env_ptr, lego_handle, gymapi.DOMAIN_SIM)
                if lego_i == self.segmentation_id:
                    self.lego_segmentation_indices.append(idx)

                # lego_body_props = self.gym.get_actor_rigid_body_properties(env_ptr, lego_handle)
                # for lego_body_prop in lego_body_props:
                #     print(lego_body_prop.mass)
                #     lego_body_prop.mass = 2
                # self.gym.set_actor_rigid_body_properties(env_ptr, lego_handle, lego_body_props)

                # lego_shape_props = self.gym.get_actor_rigid_shape_properties(env_ptr, lego_handle)
                # for object_shape_prop in lego_shape_props:
                #     object_shape_prop.friction = 2
                # self.gym.set_actor_rigid_shape_properties(env_ptr, lego_handle, lego_shape_props)

                lego_idx.append(idx)

                color = color_map[lego_i % 8]
                self.gym.set_rigid_body_color(env_ptr, lego_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(color[0], color[1], color[2]))
            self.lego_indices.append(lego_idx)

            if self.enable_camera_sensors:
                camera_handle = self.gym.create_camera_sensor(env_ptr, self.camera_props)
                self.gym.set_camera_location(camera_handle, env_ptr, gymapi.Vec3(0.2, 0, 1.2), gymapi.Vec3(0.0, 0, 0))
                camera_tensor = self.gym.get_camera_image_gpu_tensor(self.sim, env_ptr, camera_handle, gymapi.IMAGE_COLOR)
                torch_cam_tensor = gymtorch.wrap_tensor(camera_tensor)
                camera_seg_tensor = self.gym.get_camera_image_gpu_tensor(self.sim, env_ptr, camera_handle, gymapi.IMAGE_SEGMENTATION)
                torch_cam_seg_tensor = gymtorch.wrap_tensor(camera_seg_tensor)

                cam_vinv = torch.inverse((torch.tensor(self.gym.get_camera_view_matrix(self.sim, env_ptr, camera_handle)))).to(self.device)
                cam_proj = torch.tensor(self.gym.get_camera_proj_matrix(self.sim, env_ptr, camera_handle), device=self.device)
            
            # arm_hand_shape_props = self.gym.get_actor_rigid_shape_properties(env_ptr, arm_hand_actor)
            # for arm_hand_shape_prop in arm_hand_shape_props:
            #     arm_hand_shape_prop.friction = 1.
            # self.gym.set_actor_rigid_shape_properties(env_ptr, arm_hand_actor, arm_hand_shape_props)

            # Set up object...
            if self.object_type != "block":
                self.gym.set_rigid_body_color(
                    env_ptr, object_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.6, 0.72, 0.98)
                )
                self.gym.set_rigid_body_color(
                    env_ptr, goal_handle, 0, gymapi.MESH_VISUAL, gymapi.Vec3(0.6, 0.72, 0.98)
                )

            if self.aggregate_mode > 0:
                self.gym.end_aggregate(env_ptr)

            if self.enable_camera_sensors:
                origin = self.gym.get_env_origin(env_ptr)
                self.env_origin[i][0] = origin.x
                self.env_origin[i][1] = origin.y
                self.env_origin[i][2] = origin.z
                self.camera_tensors.append(torch_cam_tensor)
                self.camera_seg_tensors.append(torch_cam_seg_tensor)
                self.camera_view_matrixs.append(cam_vinv)
                self.camera_proj_matrixs.append(cam_proj)
                self.cameras.append(camera_handle)

            self.envs.append(env_ptr)
            self.arm_hands.append(arm_hand_actor)

        self.emergence_reward = torch.zeros_like(self.rew_buf, device=self.device, dtype=torch.float)
        self.emergence_pixel = torch.zeros_like(self.rew_buf, device=self.device, dtype=torch.float)
        self.last_emergence_pixel = torch.zeros_like(self.rew_buf, device=self.device, dtype=torch.float)

        self.heap_movement_penalty= torch.zeros_like(self.rew_buf, device=self.device, dtype=torch.float)

        # Acquire specific links.
        sensor_handles = range(7)
        self.sensor_handle_indices = to_torch(sensor_handles, dtype=torch.int64)

        self.fingertip_handles = [self.gym.find_actor_rigid_body_handle(env_ptr, arm_hand_actor, name) for name in self.fingertip_names]
        object_rb_props = self.gym.get_actor_rigid_body_properties(env_ptr, object_handle)
        self.object_rb_masses = [prop.mass for prop in object_rb_props]

        self.object_init_state = to_torch(self.object_init_state, device=self.device, dtype=torch.float).view(self.num_envs, 13)
        self.goal_states = self.object_init_state.clone()
        self.goal_states[:, self.up_axis_idx] -= 0.02
        self.goal_init_state = self.goal_states.clone()
        self.hand_start_states = to_torch(self.hand_start_states, device=self.device).view(self.num_envs, 13)

        self.lego_init_states = to_torch(self.lego_init_states, device=self.device).view(self.num_envs, len(lego_assets), 13)

        self.fingertip_handles = to_torch(self.fingertip_handles, dtype=torch.long, device=self.device)
        self.object_rb_handles = to_torch(self.object_rb_handles, dtype=torch.long, device=self.device)
        self.object_rb_masses = to_torch(self.object_rb_masses, dtype=torch.float, device=self.device)

        self.hand_indices = to_torch(self.hand_indices, dtype=torch.long, device=self.device)
        self.object_indices = to_torch(self.object_indices, dtype=torch.long, device=self.device)
        self.goal_object_indices = to_torch(self.goal_object_indices, dtype=torch.long, device=self.device)
        self.lego_indices = to_torch(self.lego_indices, dtype=torch.long, device=self.device)
        self.lego_segmentation_indices = to_torch(self.lego_segmentation_indices, dtype=torch.long, device=self.device)

    def compute_reward(self, actions):
        """Evaluate the grasp and retrieval rewards for the current step and update bookkeeping.

        Two reward functions are evaluated in sequence:

        * ``compute_hand_grasp_reward`` writes ``rew_buf``, ``reset_buf``, ``reset_goal_buf``,
          ``progress_buf``, ``successes`` and ``consecutive_successes`` in place.
        * ``compute_hand_retri_reward`` writes ``extras["retri_rew_buf"]`` in place, rebinds
          ``extras["retri_reset_buf"]``, overwrites ``reset_goal_buf`` / ``progress_buf`` /
          ``successes`` / ``consecutive_successes`` again, and returns the retrieval
          env-reward and v-value-reward terms.

        Afterwards the grasp reward is accumulated into ``meta_rew_buf`` and several terms are
        exported through ``self.extras`` for logging.

        Args:
            actions: Not read here; both reward functions receive ``self.actions`` instead.
        """
        # Build the spin coefficient once, directly on the target device, instead of
        # allocating it on the host and copying it to the device twice per step.
        spin_coef = torch.tensor(self.spin_coef, device=self.device)
        # Horizon passed to the retrieval reward; the same value gates retrieval logging below,
        # so keep them in sync through this single constant.
        retri_horizon = 361

        self.rew_buf[:], self.reset_buf[:], self.reset_goal_buf[:], self.progress_buf[:], self.successes[:], self.consecutive_successes[:] = compute_hand_grasp_reward(
            spin_coef, self.rew_buf, self.reset_buf, self.reset_goal_buf, self.progress_buf, self.successes, self.consecutive_successes, self.hand_reset_step, self.contacts,
            self.max_episode_length, self.object_pos, self.object_rot, self.object_angvel, self.goal_pos, self.goal_rot, self.segmentation_target_pos, self.hand_base_pos, self.emergence_reward, self.arm_hand_ff_pos, self.arm_hand_rf_pos, self.arm_hand_mf_pos, self.arm_hand_th_pos, self.heap_movement_penalty, self.segmentation_target_init_pos,
            self.dist_reward_scale, self.rot_reward_scale, self.rot_eps, self.actions, self.action_penalty_scale,
            self.success_tolerance, self.reach_goal_bonus, self.fall_dist, self.fall_penalty, self.rotation_id,
            self.max_consecutive_successes, self.av_factor, (self.object_type == "pen"), self.segmentation_target_linvel
        )

        # NOTE: "retri_reset_buf" is rebound (no [:]) while "retri_rew_buf" is written in place.
        self.extras["retri_rew_buf"][:], self.extras["retri_reset_buf"], self.reset_goal_buf[:], self.progress_buf[:], self.successes[:], self.consecutive_successes[:], self.retri_env_reward, self.retri_v_value_reward = compute_hand_retri_reward(
            spin_coef, self.extras["retri_rew_buf"], self.extras["retri_reset_buf"], self.reset_goal_buf, self.progress_buf, self.successes, self.consecutive_successes, self.hand_reset_step, self.contacts, self.palm_contacts_z, self.predict_success_confident,
            retri_horizon, self.object_pos, self.object_rot, self.object_angvel, self.goal_pos, self.goal_rot, self.segmentation_target_pos, self.hand_base_pos, self.emergence_reward, self.arm_hand_ff_pos, self.arm_hand_rf_pos, self.arm_hand_mf_pos, self.arm_hand_th_pos, self.heap_movement_penalty,
            self.dist_reward_scale, self.rot_reward_scale, self.rot_eps, self.actions, self.action_penalty_scale,
            self.success_tolerance, self.reach_goal_bonus, self.fall_dist, self.fall_penalty, self.rotation_id,
            self.max_consecutive_successes, self.av_factor, (self.object_type == "pen"), self.init_heap_movement_penalty,
        )

        # In-place add already reads rew_buf before writing meta_rew_buf; no clone needed.
        self.meta_rew_buf += self.rew_buf

        self.extras['emergence_reward'] = self.emergence_reward
        self.extras['heap_movement_penalty'] = self.heap_movement_penalty
        self.extras['meta_reward'] = self.meta_rew_buf

        # Retrieval terms are only reported while env 0 is still inside the retrieval horizon.
        if self.progress_buf[0] < retri_horizon:
            self.extras['retri_env_reward'] = self.retri_env_reward.mean()
            self.extras['retri_v_value_reward'] = self.retri_v_value_reward.mean()

        self.total_steps += 1

        if self.print_success_stat:
            print("Total steps = {}".format(self.total_steps))
            self.total_resets = self.total_resets + self.reset_buf.sum()
            direct_average_successes = self.total_successes + self.successes.sum()
            self.total_successes = self.total_successes + (self.successes * self.reset_buf).sum()

            # The direct average shows the overall result more quickly, but slightly undershoots long term
            # policy performance.
            print("Direct average consecutive successes = {:.1f}".format(direct_average_successes/(self.total_resets + self.num_envs)))
            if self.total_resets > 0:
                print("Post-Reset average consecutive successes = {:.1f}".format(self.total_successes/self.total_resets))

    def retrieval_reset(self):
        """Close one retrieval segment: move the arm, capture camera images and score the scene.

        Steps, applied to all envs at once:

        1. Drive the first 7 DoFs towards ``arm_hand_prepare_dof_poses`` (clamped to the DoF
           limits) for 20 simulation steps.
        2. Reset the arm/hand DoF state and targets to ``arm_hand_default_dof_pos`` and simulate
           20 more steps.
        3. Render the cameras, compute the emergence reward and heap movement penalty, average
           ``hand_pos_history`` into ``hand_pos_history_0`` ... ``hand_pos_history_7`` (one window
           of ``hand_reset_step`` entries each) and show debug RGB / segmentation windows of env 0.
        4. Fill ``retrieval_terminal_image_buf`` (channels 0-2: RGB, channels 3-5: RGB masked to
           the target segmentation id) and store the detached ``v_value`` prediction in
           ``predict_success_confident``.
        5. Reset the arm/hand DoFs and targets to ``arm_hand_prepare_dof_poses`` and clear
           ``extras["retri_reset_buf"]``.
        6. Once env 0's progress exceeds 351, save the root state of every env whose
           ``segmentation_object_point_num`` exceeds 40 into ``saved_retrieval_ternimal_state``
           and log the fraction of such envs as ``extras["retri_success_rate"]``.
        """
        # Converted once; unwrap_tensor needs the int32 tensor alive for each API call.
        hand_indices_i32 = self.hand_indices.to(torch.int32)

        # --- 1. move the first 7 DoFs towards the prepare pose -------------------------------
        current_hand_base_pos = self.arm_hand_dof_pos[:, :].clone()
        # (prepare - current) + current is numerically ~prepare; kept verbatim so the targets are
        # bit-identical to before. Neither operand changes inside the loop (the current pose is
        # a clone), so the clamped target is computed once instead of 20 times.
        arm_targets = tensor_clamp(
            ((self.arm_hand_prepare_dof_poses - current_hand_base_pos) + current_hand_base_pos)[:, :7],
            self.arm_hand_dof_lower_limits[:7],
            self.arm_hand_dof_upper_limits[:7])

        for _ in range(20):
            self.gym.refresh_dof_state_tensor(self.sim)
            self.gym.refresh_actor_root_state_tensor(self.sim)
            self.gym.refresh_rigid_body_state_tensor(self.sim)
            self.gym.refresh_jacobian_tensors(self.sim)
            self.gym.refresh_net_contact_force_tensor(self.sim)

            self.cur_targets[:, :7] = arm_targets
            self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(self.cur_targets))

            self.render()
            self.gym.simulate(self.sim)

        # --- 2. reset the arm/hand to the default pose -------------------------------------
        pos = self.arm_hand_default_dof_pos
        self.arm_hand_dof_pos[:, 0:23] = pos[0:23]
        self.arm_hand_dof_vel[:, :] = self.arm_hand_dof_default_vel
        self.prev_targets[:, :self.num_arm_hand_dofs] = pos
        self.cur_targets[:, :self.num_arm_hand_dofs] = pos
        self.gym.set_dof_position_target_tensor_indexed(self.sim,
                                                        gymtorch.unwrap_tensor(self.prev_targets),
                                                        gymtorch.unwrap_tensor(hand_indices_i32), self.num_envs)

        self.gym.set_dof_state_tensor_indexed(self.sim,
                                              gymtorch.unwrap_tensor(self.dof_state),
                                              gymtorch.unwrap_tensor(hand_indices_i32), self.num_envs)

        for _ in range(20):
            self.render()
            self.gym.simulate(self.sim)

        # --- 3. capture images and score the scene -----------------------------------------
        self.render_for_camera()
        self.gym.fetch_results(self.sim, True)

        self.gym.refresh_dof_state_tensor(self.sim)
        self.gym.refresh_actor_root_state_tensor(self.sim)
        self.gym.refresh_rigid_body_state_tensor(self.sim)
        self.gym.refresh_jacobian_tensors(self.sim)
        self.gym.refresh_net_contact_force_tensor(self.sim)
        self.gym.render_all_camera_sensors(self.sim)
        self.gym.start_access_image_tensors(self.sim)
        # NOTE(review): no matching gym.end_access_image_tensors(sim) is visible in this method;
        # the Isaac Gym API expects start/end access to be paired — confirm where it is closed.

        camera_rgba_image = self.camera_rgb_visulization(self.camera_tensors, env_id=0, is_depth_image=False)
        camera_seg_image = self.camera_segmentation_visulization(self.camera_tensors, self.camera_seg_tensors, env_id=0, is_depth_image=False)

        self.compute_emergence_reward(self.camera_tensors, self.camera_seg_tensors, segmentation_id=self.segmentation_id)
        self.all_lego_brick_pos = self.root_state_tensor[self.lego_indices[:].view(-1), 0:3].clone().view(self.num_envs, -1, 3)
        self.compute_heap_movement_penalty(self.all_lego_brick_pos)

        # Average each of the 8 consecutive windows of the hand position history into
        # self.hand_pos_history_0 ... self.hand_pos_history_7.
        window = self.hand_reset_step
        for k in range(8):
            setattr(self, "hand_pos_history_{}".format(k),
                    torch.mean(self.hand_pos_history[:, k * window:(k + 1) * window, :], dim=1, keepdim=False))

        # NOTE(review): unconditional debug windows; cv2.imshow needs a GUI/display and will
        # likely fail on headless training machines — consider gating this behind a flag.
        cv2.imshow("DEBUG_RGB_VIS", camera_rgba_image)
        cv2.imshow("DEBUG_SEG_VIS", camera_seg_image)
        cv2.waitKey(1)

        # --- 4. build the terminal image buffer and predict success ------------------------
        for i in range(self.num_envs):
            rgb = self.camera_tensors[i][:, :, :3]
            self.retrieval_terminal_image_buf[i, :3, :, :] = rgb.permute(2, 0, 1)
            # Zero every pixel that does not belong to the target segmentation id.
            masked_rgb = rgb.clone()
            masked_rgb[self.camera_seg_tensors[i] != self.segmentation_id] = 0
            self.retrieval_terminal_image_buf[i, 3:6, :, :] = masked_rgb.permute(2, 0, 1)

        self.predict_success_confident = self.v_value(self.retrieval_terminal_image_buf).detach()

        # --- 5. move the arm/hand to the prepare pose --------------------------------------
        self.arm_hand_dof_pos[:, 0:23] = self.arm_hand_prepare_dof_poses
        self.prev_targets[:, :self.num_arm_hand_dofs] = self.arm_hand_prepare_dof_poses
        self.cur_targets[:, :self.num_arm_hand_dofs] = self.arm_hand_prepare_dof_poses

        self.gym.set_dof_position_target_tensor_indexed(self.sim,
                                                        gymtorch.unwrap_tensor(self.prev_targets),
                                                        gymtorch.unwrap_tensor(hand_indices_i32), self.num_envs)

        self.gym.set_dof_state_tensor_indexed(self.sim,
                                              gymtorch.unwrap_tensor(self.dof_state),
                                              gymtorch.unwrap_tensor(hand_indices_i32), self.num_envs)

        self.extras["retri_reset_buf"][:] = 0

        # --- 6. snapshot successful retrieval states ---------------------------------------
        if self.progress_buf[0] > 351:
            self.retri_success_count = 0
            # One snapshot for all envs; previously the full root-state tensor was cloned once
            # per env inside the loop (O(num_envs^2) copying).
            root_states = self.root_state_tensor.clone().view(self.num_envs, -1, 13)
            for i in range(self.num_envs):
                if self.segmentation_object_point_num[i] > 40:
                    self.saved_retrieval_ternimal_state[i] = root_states[i]
                    self.retri_success_count += 1

            self.extras["retri_success_rate"] = self.retri_success_count / self.num_envs

    def grasp_reset(self):
        """Start a grasp attempt from the saved retrieval terminal state.

        If a previous grasp attempt exists (``self.grasp_inited``), label each env a success
        when the target lego rose more than 0.1 above its initial height. Outside test mode,
        take 5 BCE optimisation steps on ``self.v_value`` using the images stored in
        ``retrieval_terminal_image_buf_for_training`` by the previous call.

        Then, for all envs:

        * reset the arm/hand to ``arm_hand_default_dof_pos``;
        * restore the object, goal and lego root states from ``saved_retrieval_ternimal_state``;
        * simulate 2 steps;
        * capture new training images and show debug windows for env 0;
        * switch the prepare pose / end-effector rotation to entry 0 of their lists;
        * record the target lego's initial pose;
        * move the arm/hand to the prepare pose and simulate 2 more steps.

        Every call after the first sets ``reset_buf`` to 1 for all envs.
        """
        saved_object_indices = torch.unique(torch.cat([self.object_indices[:],
                                                       self.goal_object_indices[:],
                                                       self.lego_indices[:].view(-1)]).to(torch.int32))
        if self.grasp_inited:
            # Height gained by the target lego since the start of the previous attempt.
            lift = self.segmentation_target_pos[:, 2] - self.segmentation_target_init_pos[:, 2]
            self.success_buf[:, 0] = torch.where(lift > 0.1, 1.0, 0.0)
            self.success_buf[:, 1] = torch.where(lift <= 0.1, 1.0, 0.0)

            if not self.is_test:
                with TemporaryGrad():
                    for _ in range(5):
                        # forward
                        self.predict_success_confident = self.v_value(self.retrieval_terminal_image_buf_for_training)

                        # update v value
                        loss = self.bce_logits_loss(self.predict_success_confident, self.success_buf)
                        self.v_value_optimizer.zero_grad()
                        loss.backward()
                        self.v_value_optimizer.step()
                    # Detach before storing: the prediction is later fed to compute_reward and
                    # the logged values should not keep the autograd graph alive (matches the
                    # detached prediction produced in retrieval_reset).
                    self.predict_success_confident = self.predict_success_confident.detach()
                    self.extras['BCE_loss'] = loss.detach()
                    self.extras['predict_success_confident'] = self.predict_success_confident[:, 0].mean()
                    self.extras['predict_unsuccess_confident'] = self.predict_success_confident[:, 1].mean()
                    self.extras['success_buf'] = self.success_buf[:, 0].mean()

        # Reset the arm/hand to its default pose.
        pos = self.arm_hand_default_dof_pos
        self.arm_hand_dof_pos[:, 0:23] = pos[0:23]
        self.arm_hand_dof_vel[:, :] = self.arm_hand_dof_default_vel
        self.prev_targets[:, :self.num_arm_hand_dofs] = pos
        self.cur_targets[:, :self.num_arm_hand_dofs] = pos

        # Converted once and reused for every indexed call below.
        hand_indices = self.hand_indices[:].to(torch.int32)
        self.gym.set_dof_position_target_tensor_indexed(self.sim,
                                                        gymtorch.unwrap_tensor(self.prev_targets),
                                                        gymtorch.unwrap_tensor(hand_indices), self.num_envs)

        self.gym.set_dof_state_tensor_indexed(self.sim,
                                              gymtorch.unwrap_tensor(self.dof_state),
                                              gymtorch.unwrap_tensor(hand_indices), self.num_envs)

        # Restore object / goal / lego poses from the saved retrieval terminal state.
        self.gym.set_actor_root_state_tensor_indexed(self.sim,
                                                     gymtorch.unwrap_tensor(self.saved_retrieval_ternimal_state),
                                                     gymtorch.unwrap_tensor(saved_object_indices), len(saved_object_indices))
        for _ in range(2):
            self.render()
            self.gym.simulate(self.sim)

        self.gym.fetch_results(self.sim, True)

        self.gym.refresh_dof_state_tensor(self.sim)
        self.gym.refresh_actor_root_state_tensor(self.sim)
        self.gym.refresh_rigid_body_state_tensor(self.sim)
        self.gym.refresh_jacobian_tensors(self.sim)
        self.gym.refresh_net_contact_force_tensor(self.sim)
        self.gym.render_all_camera_sensors(self.sim)
        self.gym.start_access_image_tensors(self.sim)
        # NOTE(review): no matching gym.end_access_image_tensors(sim) is visible in this method;
        # the Isaac Gym API expects start/end access to be paired — confirm where it is closed.

        # Channels 0-2: RGB; channels 3-5: RGB with non-target pixels zeroed. These images are
        # the training inputs for the v_value update on the next call.
        for i in range(self.num_envs):
            rgb = self.camera_tensors[i][:, :, :3]
            self.retrieval_terminal_image_buf_for_training[i, :3, :, :] = rgb.permute(2, 0, 1)
            masked_rgb = rgb.clone()
            masked_rgb[self.camera_seg_tensors[i] != self.segmentation_id] = 0
            self.retrieval_terminal_image_buf_for_training[i, 3:6, :, :] = masked_rgb.permute(2, 0, 1)

        camera_rgba_image = self.camera_rgb_visulization(self.camera_tensors, env_id=0, is_depth_image=False)
        camera_seg_image = self.camera_segmentation_visulization(self.camera_tensors, self.camera_seg_tensors, env_id=0, is_depth_image=False)

        # NOTE(review): unconditional debug windows; cv2.imshow needs a GUI/display and will
        # likely fail on headless training machines — consider gating this behind a flag.
        cv2.imshow("DEBUG_RGB_VIS", camera_rgba_image)
        cv2.imshow("DEBUG_SEG_VIS", camera_seg_image)
        cv2.waitKey(1)

        self.arm_hand_prepare_dof_poses[:, :] = self.arm_hand_prepare_dof_pos_list[0]
        self.end_effector_rotation[:, :] = self.end_effector_rot_list[0]

        self.gym.refresh_dof_state_tensor(self.sim)
        self.gym.refresh_actor_root_state_tensor(self.sim)
        self.gym.refresh_rigid_body_state_tensor(self.sim)
        self.gym.refresh_jacobian_tensors(self.sim)
        self.gym.refresh_net_contact_force_tensor(self.sim)

        # Reference pose of the target lego; the success label above is measured against it.
        self.segmentation_target_init_pos = self.root_state_tensor[self.lego_segmentation_indices, 0:3].clone()
        self.segmentation_target_init_rot = self.root_state_tensor[self.lego_segmentation_indices, 3:7].clone()

        # Move the arm/hand to the prepare pose for the grasp attempt.
        self.arm_hand_dof_pos[:, 0:23] = self.arm_hand_prepare_dof_poses
        self.prev_targets[:, :self.num_arm_hand_dofs] = self.arm_hand_prepare_dof_poses
        self.cur_targets[:, :self.num_arm_hand_dofs] = self.arm_hand_prepare_dof_poses
        self.arm_hand_dof_vel[:, :] = self.arm_hand_dof_default_vel

        self.gym.set_dof_position_target_tensor_indexed(self.sim,
                                                        gymtorch.unwrap_tensor(self.prev_targets),
                                                        gymtorch.unwrap_tensor(hand_indices), self.num_envs)

        self.gym.set_dof_state_tensor_indexed(self.sim,
                                              gymtorch.unwrap_tensor(self.dof_state),
                                              gymtorch.unwrap_tensor(hand_indices), self.num_envs)

        for _ in range(2):
            self.gym.simulate(self.sim)

        # The very first call only initialises; later calls end the previous episode.
        if self.grasp_inited:
            self.reset_buf[:] = 1
        self.grasp_inited = True

    def compute_observations(self, is_retrieval=False, last_action=0):
        """Refresh simulator tensors, run scheduled sub-resets, and fill observation buffers.

        Besides filling the buffers, this caches slices of the root-state and
        rigid-body-state tensors on ``self``:
        - object, goal, hand-base, segmentation-target and fingertip
          poses/velocities;
        - the hand-base position history during the early steps;
        - thresholded contact-sensor forces in ``self.contacts`` and the palm
          z-force in ``self.palm_contacts_z``.
        It also colours the palm and contact-sensor bodies of environment 0 for
        visual debugging.

        Args:
            is_retrieval: when True only ``compute_retrieval_observations`` runs;
                otherwise, for ``obs_type == "partial_contact"``, the grasp and
                retrieval observations and states are all computed.
            last_action: passed positionally to ``compute_retrieval_observations``.
        """
        self.gym.refresh_dof_state_tensor(self.sim)
        self.gym.refresh_actor_root_state_tensor(self.sim)
        self.gym.refresh_rigid_body_state_tensor(self.sim)
        self.gym.refresh_net_contact_force_tensor(self.sim)
        self.gym.refresh_jacobian_tensors(self.sim)

        # Scheduled sub-resets. Each one is expected to open image-tensor access
        # (start_access_image_tensors). The flags are kept so that exactly one
        # end_access_image_tensors call is issued below per reset that actually ran.
        # Retrieval reset: every 45 steps during the first 360 steps (not at step 0).
        retrieval_reset_ran = bool(self.enable_camera_sensors and self.progress_buf[0] % 45 == 0
                                   and self.progress_buf[0] < 361 and self.progress_buf[0] != 0)
        if retrieval_reset_ran:
            self.retrieval_reset()

        # Grasp reset: at step 360 and every 150 steps after it, below step 960.
        grasp_reset_ran = bool(960 > self.progress_buf[0] >= 360 and (self.progress_buf[0] - 360) % 150 == 0
                               and self.progress_buf[0] != self.max_episode_length)
        if grasp_reset_ran:
            self.grasp_reset()

        self.object_pose = self.root_state_tensor[self.object_indices, 0:7]
        self.object_pos = self.root_state_tensor[self.object_indices, 0:3]
        self.object_rot = self.root_state_tensor[self.object_indices, 3:7]
        self.object_linvel = self.root_state_tensor[self.object_indices, 7:10]
        self.object_angvel = self.root_state_tensor[self.object_indices, 10:13]

        self.goal_pose = self.goal_states[:, 0:7]
        self.goal_pos = self.goal_states[:, 0:3]
        self.goal_rot = self.goal_states[:, 3:7]

        self.hand_base_pose = self.rigid_body_states[:, self.hand_base_rigid_body_index, 0:7]
        self.hand_base_pos = self.rigid_body_states[:, self.hand_base_rigid_body_index, 0:3]
        self.hand_base_rot = self.rigid_body_states[:, self.hand_base_rigid_body_index, 3:7]
        self.hand_base_linvel = self.rigid_body_states[:, self.hand_base_rigid_body_index, 7:10]
        self.hand_base_angvel = self.rigid_body_states[:, self.hand_base_rigid_body_index, 10:13]

        # NOTE(review): if progress_buf[0] is 0 here the index below is -1, which
        # writes the last history slot. Confirm progress_buf is always >= 1 at this point.
        if self.progress_buf[0] - 1 < 361:
            self.hand_pos_history[:, self.progress_buf[0] - 1, :] = self.hand_base_pos.clone()

        self.segmentation_target_pose = self.root_state_tensor[self.lego_segmentation_indices, 0:7]
        self.segmentation_target_pos = self.root_state_tensor[self.lego_segmentation_indices, 0:3]
        self.segmentation_target_rot = self.root_state_tensor[self.lego_segmentation_indices, 3:7]
        self.segmentation_target_linvel = self.root_state_tensor[self.lego_segmentation_indices, 7:10]
        self.segmentation_target_angvel = self.root_state_tensor[self.lego_segmentation_indices, 10:13]

        self.arm_hand_ff_pos = self.rigid_body_states[:, self.fingertip_handles[0], 0:3]
        self.arm_hand_ff_rot = self.rigid_body_states[:, self.fingertip_handles[0], 3:7]
        self.arm_hand_ff_linvel = self.rigid_body_states[:, self.fingertip_handles[0], 7:10]
        self.arm_hand_ff_angvel = self.rigid_body_states[:, self.fingertip_handles[0], 10:13]

        self.arm_hand_mf_pos = self.rigid_body_states[:, self.fingertip_handles[1], 0:3]
        self.arm_hand_mf_rot = self.rigid_body_states[:, self.fingertip_handles[1], 3:7]
        self.arm_hand_mf_linvel = self.rigid_body_states[:, self.fingertip_handles[1], 7:10]
        self.arm_hand_mf_angvel = self.rigid_body_states[:, self.fingertip_handles[1], 10:13]

        self.arm_hand_rf_pos = self.rigid_body_states[:, self.fingertip_handles[2], 0:3]
        self.arm_hand_rf_rot = self.rigid_body_states[:, self.fingertip_handles[2], 3:7]
        self.arm_hand_rf_linvel = self.rigid_body_states[:, self.fingertip_handles[2], 7:10]
        self.arm_hand_rf_angvel = self.rigid_body_states[:, self.fingertip_handles[2], 10:13]

        self.arm_hand_th_pos = self.rigid_body_states[:, self.fingertip_handles[3], 0:3]
        self.arm_hand_th_rot = self.rigid_body_states[:, self.fingertip_handles[3], 3:7]
        self.arm_hand_th_linvel = self.rigid_body_states[:, self.fingertip_handles[3], 7:10]
        self.arm_hand_th_angvel = self.rigid_body_states[:, self.fingertip_handles[3], 10:13]

        contacts = self.contact_tensor.reshape(self.num_envs, -1, 3)
        # Rigid body 8 is treated as the palm.
        palm_contacts = contacts[:, 8, :]
        contacts = torch.norm(contacts[:, self.sensor_handle_indices, :], dim=-1)
        # A contact sensor counts as touching when its net force magnitude is >= 0.1.
        self.contacts = torch.where(contacts >= 0.1, 1.0, 0.0)

        self.palm_contacts_z = palm_contacts[:, 2]

        # Visual debugging for env 0 only. Tint the palm when its z contact force
        # exceeds 100, and each contact-sensor body while it is touching.
        touching_color = gymapi.Vec3(1, 0.3, 0.3)
        free_color = gymapi.Vec3(1, 1, 1)
        palm_color = touching_color if self.palm_contacts_z[0] > 100.0 else free_color
        self.gym.set_rigid_body_color(
            self.envs[0], self.hand_indices[0], 8, gymapi.MESH_VISUAL, palm_color)
        for i, touching in enumerate(self.contacts[0]):
            sensor_color = touching_color if touching == 1.0 else free_color
            self.gym.set_rigid_body_color(
                self.envs[0], self.hand_indices[0], self.sensor_handle_indices[i], gymapi.MESH_VISUAL, sensor_color)

        if is_retrieval:
            self.compute_retrieval_observations(last_action)
        elif self.obs_type == "partial_contact":
            self.compute_grasp_observations(False)
            self.compute_grasp_states()
            self.compute_retri_observations()
            self.compute_retri_states()

        # Close the image-tensor access opened by each sub-reset that ran above.
        if retrieval_reset_ran:
            self.gym.end_access_image_tensors(self.sim)
        if grasp_reset_ran:
            self.gym.end_access_image_tensors(self.sim)

    def compute_grasp_states(self):
        """Write the grasp-phase privileged state into ``states_buf`` and advance its frame stack.

        Columns 0:174 receive the current frame, containing:
        - normalised joint positions and scaled velocities;
        - fingertip positions, rotations and velocities;
        - the actions;
        - hand-base and segmentation-target poses and velocities;
        - the normalised position within the 150-step cycle counted from step 361;
        - target displacement from its initial position;
        - hand/fingertip-to-target offsets;
        - ``arm_hand_finger_dist``.
        The following ``one_frame_num_states``-wide slots receive the previously
        stored frames from ``state_buf_stack_frames``, which then shifts by one frame.
        """
        target_pos = self.segmentation_target_pos
        target_init_pos = self.segmentation_target_init_pos

        # (first column, end column, value) for every fixed-layout entry.
        segments = [
            (0, 23, unscale(self.arm_hand_dof_pos[:, 0:23],
                            self.arm_hand_dof_lower_limits[0:23],
                            self.arm_hand_dof_upper_limits[0:23])),
            (23, 46, self.vel_obs_scale * self.arm_hand_dof_vel[:, 0:23]),
            (58, 81, self.actions),
            (81, 88, self.hand_base_pose),
            (88, 95, self.segmentation_target_pose),
            (95, 98, self.hand_base_linvel),
            (98, 101, self.hand_base_angvel),
            (141, 142, (self.progress_buf.unsqueeze(-1) - 361) % (150) / (150)),
            (142, 145, self.segmentation_target_linvel),
            (145, 148, self.segmentation_target_angvel),
            (148, 151, target_init_pos),
            (151, 154, target_pos - target_init_pos),
            (154, 157, self.hand_base_pos - target_pos),
            (157, 161, quat_mul(self.hand_base_rot, quat_conjugate(self.segmentation_target_rot))),
            (173, 174, self.arm_hand_finger_dist.unsqueeze(-1)),
        ]

        # Fingertip positions (from column 46) and target-minus-fingertip offsets
        # (from column 161), both in ff, rf, mf, th order.
        for slot, finger in enumerate(("ff", "rf", "mf", "th")):
            tip_pos = getattr(self, f"arm_hand_{finger}_pos")
            segments.append((46 + 3 * slot, 49 + 3 * slot, tip_pos))
            segments.append((161 + 3 * slot, 164 + 3 * slot, target_pos - tip_pos))

        # Per-finger rotation (4), linear velocity (3) and angular velocity (3),
        # 10 columns per finger from column 101, in ff, mf, rf, th order.
        for slot, finger in enumerate(("ff", "mf", "rf", "th")):
            base = 101 + 10 * slot
            segments.append((base, base + 4, getattr(self, f"arm_hand_{finger}_rot")))
            segments.append((base + 4, base + 7, getattr(self, f"arm_hand_{finger}_linvel")))
            segments.append((base + 7, base + 10, getattr(self, f"arm_hand_{finger}_angvel")))

        for first, end, value in segments:
            self.states_buf[:, first:end] = value

        # Slot k+1 receives stored frame k; stored frame k then takes slot k
        # (the current frame for k == 0).
        frame_width = self.one_frame_num_states
        history = self.state_buf_stack_frames
        for k in range(len(history) - 1):
            self.states_buf[:, (k + 1) * frame_width:(k + 2) * frame_width] = history[k]
            history[k] = self.states_buf[:, k * frame_width:(k + 1) * frame_width].clone()

    def compute_grasp_observations(self, full_contact=True):
        """Write the grasp-phase policy observation into ``obs_buf`` and advance its frame stack.

        Columns 0:74 receive the current frame, containing:
        - normalised joint positions;
        - the actions;
        - the hand-base pose;
        - the segmentation-target position and rotation;
        - the normalised position within the 150-step cycle counted from step 361;
        - the target's initial position and displacement;
        - the hand-to-target position offset and relative rotation.
        The following ``one_frame_num_obs``-wide slots receive previously stored
        frames from ``obs_buf_stack_frames``.

        Args:
            full_contact: not read by this method.
        """
        obs = self.obs_buf
        target_pos = self.segmentation_target_pos
        target_rot = self.segmentation_target_rot
        target_init_pos = self.segmentation_target_init_pos

        layout = (
            (0, 23, unscale(self.arm_hand_dof_pos[:, 0:23],
                            self.arm_hand_dof_lower_limits[0:23],
                            self.arm_hand_dof_upper_limits[0:23])),
            (23, 46, self.actions),
            (46, 53, self.hand_base_pose),
            (53, 56, target_pos),
            (56, 60, target_rot),
            (60, 61, (self.progress_buf.unsqueeze(-1) - 361) % (150) / (150)),
            (61, 64, target_init_pos),
            (64, 67, target_pos - target_init_pos),
            (67, 70, self.hand_base_pos - target_pos),
            (70, 74, quat_mul(self.hand_base_rot, quat_conjugate(target_rot))),
        )
        for first, end, value in layout:
            obs[:, first:end] = value

        # Slot k+1 receives stored frame k; stored frame k then takes slot k.
        frame_width = self.one_frame_num_obs
        stack = self.obs_buf_stack_frames
        for k in range(len(stack) - 1):
            obs[:, (k + 1) * frame_width:(k + 2) * frame_width] = stack[k]
            stack[k] = obs[:, k * frame_width:(k + 1) * frame_width].clone()

    def compute_retri_states(self):
        """Write the retrieval-phase privileged state into ``retrieve_states_buf`` (columns 0:175).

        Contents:
        - normalised joint positions and scaled velocities;
        - fingertip positions, rotations and velocities;
        - the actions;
        - hand-base and segmentation-target poses and velocities;
        - the step index within the current ``hand_reset_step`` window;
        - the eight ``hand_pos_history_*`` entries;
        - the target's image centroid (scaled by 1/128) and pixel count (scaled by 1/100).
        """
        buf = self.retrieve_states_buf

        # (first column, end column, value) for every fixed-layout entry.
        segments = [
            (0, 23, unscale(self.arm_hand_dof_pos[:, 0:23],
                            self.arm_hand_dof_lower_limits[0:23],
                            self.arm_hand_dof_upper_limits[0:23])),
            (23, 46, self.vel_obs_scale * self.arm_hand_dof_vel[:, 0:23]),
            (58, 81, self.actions),
            (81, 88, self.hand_base_pose),
            (88, 95, self.segmentation_target_pose),
            (95, 96, (self.progress_buf.unsqueeze(-1) - 1) % self.hand_reset_step),
            (120, 121, self.segmentation_object_center_point_x / 128),
            (121, 122, self.segmentation_object_center_point_y / 128),
            (122, 123, self.segmentation_object_point_num / 100),
            (123, 126, self.hand_base_linvel),
            (126, 129, self.hand_base_angvel),
            (169, 172, self.segmentation_target_linvel),
            (172, 175, self.segmentation_target_angvel),
        ]

        # Fingertip positions from column 46, in ff, rf, mf, th order.
        for slot, finger in enumerate(("ff", "rf", "mf", "th")):
            base = 46 + 3 * slot
            segments.append((base, base + 3, getattr(self, f"arm_hand_{finger}_pos")))

        # Hand-position history entries 0..7, three columns each from column 96.
        for k in range(8):
            base = 96 + 3 * k
            segments.append((base, base + 3, getattr(self, f"hand_pos_history_{k}")))

        # Per-finger rotation (4), linear velocity (3) and angular velocity (3),
        # 10 columns per finger from column 129, in ff, mf, rf, th order.
        for slot, finger in enumerate(("ff", "mf", "rf", "th")):
            base = 129 + 10 * slot
            segments.append((base, base + 4, getattr(self, f"arm_hand_{finger}_rot")))
            segments.append((base + 4, base + 7, getattr(self, f"arm_hand_{finger}_linvel")))
            segments.append((base + 7, base + 10, getattr(self, f"arm_hand_{finger}_angvel")))

        for first, end, value in segments:
            buf[:, first:end] = value

    def compute_retrieval_observations(self, full_contact=True):
        """Write the retrieval-phase policy observation into ``retrieve_obs_buf`` (columns 0:81).

        This also recomputes, for every environment, the pixel centroid and pixel
        count of the segmentation target in the segmentation camera image. The
        results are stored in ``segmentation_object_center_point_x`` / ``_y`` and
        ``segmentation_object_point_num``; the centroid is 0 when no pixel carries
        the target's id. ``segmentation_object_point_list`` is left holding the
        last environment's pixel coordinates.

        Args:
            full_contact: not read by this method; kept for interface compatibility.
        """
        obs = self.retrieve_obs_buf
        obs[:, 0:23] = unscale(self.arm_hand_dof_pos[:, 0:23],
                               self.arm_hand_dof_lower_limits[0:23],
                               self.arm_hand_dof_upper_limits[0:23])
        # Action components 4:23 go to columns 23:42; components 0:4 go to the tail (77:81).
        obs[:, 23:42] = self.actions[:, 4:23]
        obs[:, 77:81] = self.actions[:, 0:4]

        obs[:, 42:49] = self.hand_base_pose

        obs[:, 49:50] = (self.progress_buf.unsqueeze(-1) - 1) % self.hand_reset_step

        # Hand-position history entries 0..7, three columns each from column 50.
        for k in range(8):
            base = 50 + 3 * k
            obs[:, base:base + 3] = getattr(self, f"hand_pos_history_{k}")

        for i in range(self.num_envs):
            # Coordinates of every pixel labelled with the target's id. The boolean
            # mask is tested directly, so the lookup stays correct even when
            # segmentation_id is 0. nonzero() over the label values would drop those
            # pixels in that case.
            points = torch.nonzero(self.camera_seg_tensors[i] == self.segmentation_id).float()
            self.segmentation_object_point_list = points
            num_points = points.shape[0]
            if num_points > 0:
                self.segmentation_object_center_point_x[i] = int(torch.mean(points[:, 0]))
                self.segmentation_object_center_point_y[i] = int(torch.mean(points[:, 1]))
            else:
                self.segmentation_object_center_point_x[i] = 0
                self.segmentation_object_center_point_y[i] = 0

            self.segmentation_object_point_num[i] = num_points

        obs[:, 74:75] = self.segmentation_object_center_point_x / 128
        obs[:, 75:76] = self.segmentation_object_center_point_y / 128
        obs[:, 76:77] = self.segmentation_object_point_num / 100

    def reset_target_pose(self, env_ids, apply_reset=False):
        """Move the goal object of ``env_ids`` back to its initial position with a random rotation.

        Args:
            env_ids: indices of the environments whose goal is reset.
            apply_reset: when True, also record the new goal pose in
                ``object_pose_for_open_loop`` and push the goal actors' root states
                to the simulator immediately. Otherwise only the torch-side tensors
                are updated.
        """
        rand_floats_x = torch_rand_float(-1, 1, (len(env_ids), 4), device=self.device)
        rand_floats_y = torch_rand_float(-1, 1, (len(env_ids), 4), device=self.device)

        new_rot = randomize_rotation(rand_floats_x[:, 0], rand_floats_y[:, 1],
                                     self.x_unit_tensor[env_ids],
                                     self.y_unit_tensor[env_ids])

        goal_actor_ids = self.goal_object_indices[env_ids]

        self.goal_states[env_ids, 0:3] = self.goal_init_state[env_ids, 0:3]
        self.goal_states[env_ids, 3:7] = new_rot
        # The goal actor is drawn offset by goal_displacement_tensor and placed at rest.
        self.root_state_tensor[goal_actor_ids, 0:3] = self.goal_states[env_ids, 0:3] + self.goal_displacement_tensor
        self.root_state_tensor[goal_actor_ids, 3:7] = self.goal_states[env_ids, 3:7]
        self.root_state_tensor[goal_actor_ids, 7:13] = torch.zeros_like(self.root_state_tensor[goal_actor_ids, 7:13])

        if apply_reset:
            # Recorded once, from the updated goal state. The earlier copy taken
            # before the update was always overwritten here.
            self.object_pose_for_open_loop[env_ids] = self.goal_states[env_ids, 0:7].clone()
            goal_object_indices = goal_actor_ids.to(torch.int32)
            self.gym.set_actor_root_state_tensor_indexed(self.sim,
                                                         gymtorch.unwrap_tensor(self.root_state_tensor),
                                                         gymtorch.unwrap_tensor(goal_object_indices), len(env_ids))
        self.reset_goal_buf[env_ids] = 0

    # default robot pose: [0.00, 0.782, -1.087, 3.487, 2.109, -1.415]
    def reset_idx(self, env_ids, goal_env_ids):
        """Reset the environments in ``env_ids`` for a new episode.

        When ``total_steps > 0``, the success V-value network is first trained and
        evaluated on the episodes that just ended (see ``_update_v_value``). The
        method then:
        - resets the goal, the object and the lego bricks, placing the
          segmentation-target brick at a random position;
        - returns the arm-hand DOFs to their default pose;
        - pushes everything to the simulator;
        - runs ``post_reset`` and clears the per-episode buffers.

        Args:
            env_ids: indices of the environments to reset.
            goal_env_ids: indices of environments whose goal actor is also pushed
                to the simulator.
        """
        # train V value
        if self.total_steps > 0:
            self._update_v_value()

        # generate random values (one row per entry of env_ids)
        rand_floats = torch_rand_float(-1.0, 1.0, (len(env_ids), self.num_arm_hand_dofs * 2 + 5), device=self.device)

        # randomize start object poses
        self.reset_target_pose(env_ids)

        # reset rigid body forces
        self.rb_forces[env_ids, :, :] = 0.0

        # reset object
        object_ids = self.object_indices[env_ids]
        self.root_state_tensor[object_ids] = self.object_init_state[env_ids].clone()
        self.root_state_tensor[object_ids, 0:2] = self.object_init_state[env_ids, 0:2] + \
            self.reset_position_noise * rand_floats[:, 0:2]
        self.root_state_tensor[object_ids, self.up_axis_idx] = self.object_init_state[env_ids, self.up_axis_idx] + \
            self.reset_position_noise * rand_floats[:, self.up_axis_idx]

        self.root_state_tensor[object_ids, 3:7] = self.object_init_state[env_ids, 3:7].clone()
        self.root_state_tensor[object_ids, 7:13] = torch.zeros_like(self.root_state_tensor[object_ids, 7:13])
        self.object_pose_for_open_loop[env_ids] = self.root_state_tensor[object_ids, 0:7].clone()

        # reset every lego brick to its initial pose, at rest
        lego_ids = self.lego_indices[env_ids].view(-1)
        self.root_state_tensor[lego_ids, 0:7] = self.lego_init_states[env_ids].view(-1, 13)[:, 0:7].clone()
        self.root_state_tensor[lego_ids, 7:13] = torch.zeros_like(self.root_state_tensor[lego_ids, 7:13])

        # Randomize the segmentation object. rand_floats has len(env_ids) rows, so
        # it is indexed positionally; indexing it by env id would pick the wrong
        # rows, or go out of range, for partial resets.
        seg_ids = self.lego_segmentation_indices[env_ids]
        self.root_state_tensor[seg_ids, 0] = rand_floats[:, 0] * 0.2 + 0.1
        self.root_state_tensor[seg_ids, 1] = rand_floats[:, 1] * 0.3
        self.root_state_tensor[seg_ids, 2] = rand_floats[:, 2] * 0.1 + 0.82

        object_indices = torch.unique(torch.cat([object_ids,
                                                 self.goal_object_indices[env_ids],
                                                 self.goal_object_indices[goal_env_ids],
                                                 lego_ids]).to(torch.int32))

        self.saved_object_indices = object_indices

        self.gym.set_actor_root_state_tensor_indexed(self.sim,
                                                     gymtorch.unwrap_tensor(self.root_state_tensor),
                                                     gymtorch.unwrap_tensor(object_indices), len(object_indices))
        # Reset random force probabilities: log-uniform between force_prob_range[1]
        # and force_prob_range[0].
        self.random_force_prob[env_ids] = torch.exp((torch.log(self.force_prob_range[0]) - torch.log(self.force_prob_range[1]))
                                                    * torch.rand(len(env_ids), device=self.device) + torch.log(self.force_prob_range[1]))

        # reset arm and hand to the default DOF pose, with default velocities
        pos = self.arm_hand_default_dof_pos
        self.arm_hand_dof_pos[env_ids, 0:23] = pos[0:23]
        self.arm_hand_dof_vel[env_ids, :] = self.arm_hand_dof_default_vel
        self.prev_targets[env_ids, :self.num_arm_hand_dofs] = pos
        self.cur_targets[env_ids, :self.num_arm_hand_dofs] = pos

        hand_indices = self.hand_indices[env_ids].to(torch.int32)
        self.gym.set_dof_position_target_tensor_indexed(self.sim,
                                                        gymtorch.unwrap_tensor(self.prev_targets),
                                                        gymtorch.unwrap_tensor(hand_indices), len(env_ids))

        self.gym.set_dof_state_tensor_indexed(self.sim,
                                              gymtorch.unwrap_tensor(self.dof_state),
                                              gymtorch.unwrap_tensor(hand_indices), len(env_ids))

        self.post_reset(env_ids, hand_indices, object_indices, rand_floats)

        self.progress_buf[env_ids] = 0
        self.reset_buf[env_ids] = 0
        self.extras["retri_reset_buf"][env_ids] = 0
        self.successes[env_ids] = 0
        self.meta_rew_buf[env_ids] = 0
        self.grasp_inited = False

    def _update_v_value(self):
        """Train the success V-value network on the episodes that just ended, and log diagnostics.

        An episode counts as successful when the segmentation target's z position
        rose more than 0.1 above its initial z. The network is trained on
        ``retrieval_terminal_image_buf_for_training`` unless ``is_test`` is set.
        Every ``10 * (max_episode_length - 1)`` total steps, a checkpoint is saved
        and the prediction accuracy on ``retrieval_terminal_image_buf`` is logged.
        """
        lift = self.segmentation_target_pos[:, 2] - self.segmentation_target_init_pos[:, 2]
        self.success_buf[:, 0] = torch.where(lift > 0.1, 1.0, 0.0)
        self.success_buf[:, 1] = torch.where(lift <= 0.1, 1.0, 0.0)

        if not self.is_test:
            with TemporaryGrad():
                for _ in range(5):
                    # forward
                    self.predict_success_confident = self.v_value(self.retrieval_terminal_image_buf_for_training)

                    # update v value
                    loss = self.bce_logits_loss(self.predict_success_confident, self.success_buf)
                    self.v_value_optimizer.zero_grad()
                    loss.backward()
                    self.v_value_optimizer.step()
                # Detached so the logged values do not keep the autograd graph alive.
                self.extras['BCE_loss'] = loss.detach()
                self.extras['predict_success_confident'] = self.predict_success_confident[:, 0].mean().detach()
                self.extras['predict_unsuccess_confident'] = self.predict_success_confident[:, 1].mean().detach()
                self.extras['success_buf'] = self.success_buf[:, 0].mean()

        save_period = 10 * (self.max_episode_length - 1)
        update_iter = int(self.total_steps / save_period)
        print("v_value_udpate_iter: ", update_iter)
        if self.total_steps % save_period == 0:
            torch.save(self.v_value.state_dict(), self.v_value_save_path + "/model_{}.pt".format(update_iter))

            self.ground_true_success_buf = torch.where(lift > 0.1, 1.0, 0.0).unsqueeze(-1)
            with torch.no_grad():
                logits = self.v_value(self.retrieval_terminal_image_buf)
            # Apply the sigmoid exactly once. Applying it repeatedly drives both
            # columns toward the same fixed point and corrupts the comparison.
            self.predict_success_confident = torch.sigmoid(logits)
            self.predict_success_buf = (self.predict_success_confident[:, 0]
                                        > self.predict_success_confident[:, 1]).float().unsqueeze(-1)

            # Fraction of environments whose predicted outcome matches the ground truth.
            self.valid_v_value_success_rate = torch.count_nonzero(
                self.ground_true_success_buf == self.predict_success_buf) / self.num_envs
            self.extras['valid_v_value_success_rate'] = self.valid_v_value_success_rate

    def post_reset(self, env_ids, hand_indices, object_indices, rand_floats):
        # step physics and render each frame

        ###########################
        for i in range(50):
            self.render()
            self.gym.simulate(self.sim)

        self.render_for_camera()
        self.gym.fetch_results(self.sim, True)

        if self.enable_camera_sensors:
            self.gym.refresh_dof_state_tensor(self.sim)
            self.gym.refresh_actor_root_state_tensor(self.sim)
            self.gym.refresh_rigid_body_state_tensor(self.sim)
            self.gym.refresh_jacobian_tensors(self.sim)
            self.gym.refresh_net_contact_force_tensor(self.sim)
            self.gym.render_all_camera_sensors(self.sim)
            self.gym.start_access_image_tensors(self.sim)

            self.segmentation_target_init_pos = self.root_state_tensor[self.lego_segmentation_indices, 0:3].clone()
            self.segmentation_target_init_rot = self.root_state_tensor[self.lego_segmentation_indices, 3:7].clone()

            camera_rgba_image = self.camera_rgb_visulization(self.camera_tensors, env_id=0, is_depth_image=False)
            camera_seg_image = self.camera_segmentation_visulization(self.camera_tensors, self.camera_seg_tensors, env_id=0, is_depth_image=False)

            self.compute_emergence_reward(self.camera_tensors, self.camera_seg_tensors, segmentation_id=self.segmentation_id)
            # for i in range(self.num_envs):
            #     torch_seg_tensor = self.camera_tensors[i]
            #     self.last_emergence_pixel[i] = torch_seg_tensor[torch_seg_tensor == self.segmentation_id].shape[0]

            self.last_all_lego_brick_pos = self.root_state_tensor[self.lego_indices[:], 0:3].clone()
            
            self.hand_pos_history = torch.zeros_like(self.hand_pos_history)

            self.hand_pos_history_0 = torch.mean(self.hand_pos_history[:, 0*self.hand_reset_step:1*self.hand_reset_step, :], dim=1, keepdim=False)
            self.hand_pos_history_1 = torch.mean(self.hand_pos_history[:, 1*self.hand_reset_step:2*self.hand_reset_step, :], dim=1, keepdim=False)
            self.hand_pos_history_2 = torch.mean(self.hand_pos_history[:, 2*self.hand_reset_step:3*self.hand_reset_step, :], dim=1, keepdim=False)
            self.hand_pos_history_3 = torch.mean(self.hand_pos_history[:, 3*self.hand_reset_step:4*self.hand_reset_step, :], dim=1, keepdim=False)
            self.hand_pos_history_4 = torch.mean(self.hand_pos_history[:, 4*self.hand_reset_step:5*self.hand_reset_step, :], dim=1, keepdim=False)
            self.hand_pos_history_5 = torch.mean(self.hand_pos_history[:, 5*self.hand_reset_step:6*self.hand_reset_step, :], dim=1, keepdim=False)
            self.hand_pos_history_6 = torch.mean(self.hand_pos_history[:, 6*self.hand_reset_step:7*self.hand_reset_step, :], dim=1, keepdim=False)
            self.hand_pos_history_7 = torch.mean(self.hand_pos_history[:, 7*self.hand_reset_step:8*self.hand_reset_step, :], dim=1, keepdim=False)
            # self.camera_rgbd_image_tensors = torch.stack(self.camera_tensors, dim=0).view(self.num_envs, -1)
            # self.camera_seg_image_tensors = ((torch.stack(self.camera_seg_tensors, dim=0) == self.segmentation_id) * 255).view(self.num_envs, -1)

            cv2.imshow("DEBUG_RGB_VIS", camera_rgba_image)
            cv2.imshow("DEBUG_SEG_VIS", camera_seg_image)
            cv2.waitKey(1)

            self.gym.end_access_image_tensors(self.sim)

            self.all_lego_brick_pos = self.root_state_tensor[self.lego_indices[:].view(-1), 0:3].clone().view(self.num_envs, -1, 3)
            self.init_heap_movement_penalty = torch.where(abs(self.all_lego_brick_pos[:self.num_envs, :, 0] - 1) > 0.25,
                                                torch.where(abs(self.all_lego_brick_pos[:self.num_envs, :, 1]) > 0.35, torch.ones_like(self.all_lego_brick_pos[:self.num_envs, :, 0]), torch.zeros_like(self.all_lego_brick_pos[:self.num_envs, :, 0])), torch.zeros_like(self.all_lego_brick_pos[:self.num_envs, :, 0]))
            
            self.init_heap_movement_penalty = torch.sum(self.init_heap_movement_penalty, dim=1, keepdim=False)

        self.arm_hand_prepare_dof_poses[:, :] = self.arm_hand_prepare_dof_pos_list[3]
        self.end_effector_rotation[:, :] = self.end_effector_rot_list[3]

        self.arm_hand_dof_pos[env_ids, 0:23] = self.arm_hand_prepare_dof_poses
        self.prev_targets[env_ids, :self.num_arm_hand_dofs] = self.arm_hand_prepare_dof_poses
        self.cur_targets[env_ids, :self.num_arm_hand_dofs] = self.arm_hand_prepare_dof_poses

        self.gym.set_dof_position_target_tensor_indexed(self.sim,
                                                        gymtorch.unwrap_tensor(self.prev_targets),
                                                        gymtorch.unwrap_tensor(hand_indices), len(env_ids))

        self.gym.set_dof_state_tensor_indexed(self.sim,
                                              gymtorch.unwrap_tensor(self.dof_state),
                                              gymtorch.unwrap_tensor(hand_indices), len(env_ids))
        print("post_reset finish")

    def pre_physics_step(self, actions):
        """Reset flagged envs, turn ``actions`` into dof position targets and send them to the sim.

        Columns of ``actions`` read in the absolute-control path:
          * ``0:3``  -> scaled position error for the arm IK,
          * ``3:6``  -> scaled rotation error for the arm IK,
          * ``7:23`` -> hand joint targets for ``self.actuated_dof_indices``.
        Column ``6`` is not read by this method. In the relative-control path the
        whole ``actions`` tensor is added (times ``shadow_hand_dof_speed_scale * dt``)
        to the current positions of ``self.actuated_dof_indices``.

        Side effects: may call ``reset_idx``; overwrites ``self.actions``,
        ``self.cur_targets`` and ``self.prev_targets``; pushes targets with
        ``set_dof_position_target_tensor``; optionally applies random forces.
        """
        # Resets are suppressed entirely while env 0's step counter is below 951,
        # except on the very first call (total_steps == 0). Only env 0's counter is
        # inspected -- NOTE(review): assumes all envs progress in lockstep.
        if self.progress_buf[0] < 951 and self.total_steps != 0:
            self.reset_buf[:] = 0
        
        env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1)
        goal_env_ids = self.reset_goal_buf.nonzero(as_tuple=False).squeeze(-1)

        # Reset every env flagged in reset_buf. Goal-only resets are forwarded to
        # reset_idx but never trigger a reset on their own.
        if len(env_ids) > 0:
            self.reset_idx(env_ids, goal_env_ids)
        
        self.actions = actions.clone().to(self.device)
        # self.actions = torch.ones_like(actions.clone().to(self.device))
        if self.use_relative_control:
            # Relative control: integrate actions on top of the *measured* dof positions, then clamp to limits.
            # targets = self.prev_targets[:, self.actuated_dof_indices] + self.shadow_hand_dof_speed_scale * self.dt * self.actions
            targets = self.arm_hand_dof_pos[:, self.actuated_dof_indices] + self.shadow_hand_dof_speed_scale * self.dt * self.actions
            self.cur_targets[:, self.actuated_dof_indices] = tensor_clamp(targets,
                                                                          self.arm_hand_dof_lower_limits[self.actuated_dof_indices],
                                                                          self.arm_hand_dof_upper_limits[self.actuated_dof_indices])
        else:
            # self.allegro_dof_low_level_action[:, 0:1] = 0
            # self.allegro_dof_low_level_action[:, 1:4] += self.actions[:, 3].unsqueeze(-1).expand(-1, 3) * self.dt * 5
            # self.allegro_dof_low_level_action[:, 4:5] = 1
            # self.allegro_dof_low_level_action[:, 5:8] += self.actions[:, 4].unsqueeze(-1).expand(-1, 3) * self.dt * 5
            # self.allegro_dof_low_level_action[:, 8:9] = 0
            # self.allegro_dof_low_level_action[:, 9:12] += self.actions[:, 3].unsqueeze(-1).expand(-1, 3) * self.dt * 5
            # self.allegro_dof_low_level_action[:, 12:13] = 0
            # self.allegro_dof_low_level_action[:, 13:16] += self.actions[:, 3].unsqueeze(-1).expand(-1, 3) * self.dt * 5
            # self.cur_targets[:, self.actuated_dof_indices] = scale(self.allegro_dof_low_level_action[:, 0:16],
            # self.cur_targets[:, self.actuated_dof_indices] = scale(self.actions[:, 3:19],
            # Hand: map actions[:, 7:23] into the actuated dofs' joint-limit range with scale()
            # (presumably an affine map from [-1, 1] -- TODO confirm against torch_utils.scale).
            self.cur_targets[:, self.actuated_dof_indices] = scale(self.actions[:, 7:23],
                                                                   self.arm_hand_dof_lower_limits[self.actuated_dof_indices],
                                                                   self.arm_hand_dof_upper_limits[self.actuated_dof_indices])
            # Exponential smoothing: act_moving_average weights the new target, the rest the previous one.
            self.cur_targets[:, self.actuated_dof_indices] = self.act_moving_average * self.cur_targets[:,
                                                                                                        self.actuated_dof_indices] + (1.0 - self.act_moving_average) * self.prev_targets[:, self.actuated_dof_indices]
            self.cur_targets[:, self.actuated_dof_indices] = tensor_clamp(self.cur_targets[:, self.actuated_dof_indices],
                                                                          self.arm_hand_dof_lower_limits[self.actuated_dof_indices],
                                                                          self.arm_hand_dof_upper_limits[self.actuated_dof_indices])
            
            # qpos control robotic arm
            # targets = self.prev_targets[:, :7] + self.shadow_hand_dof_speed_scale * self.dt * self.actions[:, :7]
            # self.cur_targets[:, :7] = tensor_clamp(targets,
            #                                         self.arm_hand_dof_lower_limits[:7],
            #                                         self.arm_hand_dof_upper_limits[:7])

            # IK control robotic arm (dofs 0:7). Gains are 10x larger while env 0's
            # step counter is below 361, then drop for finer motion.
            if self.progress_buf[0] < 361:
                pos_err = self.actions[:, 0:3] * 1.6
                rot_err = self.actions[:, 3:6] * 0.5
            else:
                pos_err = self.actions[:, 0:3] * 0.16
                rot_err = self.actions[:, 3:6] * 0.05
            # target_rot = self.end_effector_rotation

            # rot_err = orientation_error(target_rot, self.rigid_body_states[:, self.hand_base_rigid_body_index, 3:7].clone())
            # Stack into a (num_envs, 6, 1) task-space error for the damped-least-squares solver.
            dpose = torch.cat([pos_err, rot_err], -1).unsqueeze(-1)
            # NOTE(review): the "- 1" on the body index presumably skips the fixed base link,
            # which Isaac Gym omits from fixed-base Jacobians -- confirm.
            delta = control_ik(self.jacobian_tensor[:, self.hand_base_rigid_body_index - 1, :, :7], self.device, dpose, self.num_envs)
            # Deltas are applied to the measured arm dof positions, not to the previous targets.
            targets = self.arm_hand_dof_pos[:, 0:7] + delta[:, :7]

            self.cur_targets[:, :7] = tensor_clamp(targets,
                                                    self.arm_hand_dof_lower_limits[:7],
                                                    self.arm_hand_dof_upper_limits[:7])
            # print(self.arm_hand_dof_pos[0, :7])

        # Remember these targets (used by the smoothing step next call) and send them to the sim.
        self.prev_targets[:, :] = self.cur_targets[:, :]
        self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(self.cur_targets))
        # self.arm_hand_dof_pos[:, 0:7] = self.cur_targets[:, :7]
        # self.gym.set_dof_state_tensor(self.sim,
        #                                       gymtorch.unwrap_tensor(self.dof_state))

        if self.force_scale > 0.0:
            # Decay previously applied forces by force_decay per force_decay_interval of sim time.
            self.rb_forces *= torch.pow(self.force_decay, self.dt / self.force_decay_interval)

            # apply new forces: each env is perturbed with probability random_force_prob,
            # with Gaussian forces scaled by the object mass and force_scale.
            force_indices = (torch.rand(self.num_envs, device=self.device) < self.random_force_prob).nonzero()
            self.rb_forces[force_indices, self.object_rb_handles, :] = torch.randn(
                self.rb_forces[force_indices, self.object_rb_handles, :].shape, device=self.device) * self.object_rb_masses * self.force_scale

            self.gym.apply_rigid_body_force_tensors(self.sim, gymtorch.unwrap_tensor(self.rb_forces), None, gymapi.LOCAL_SPACE)

    def post_physics_step(self):
        """Advance the per-env counters, then recompute observations and rewards.

        When a viewer is attached and ``debug_viz`` is enabled, the debug lines are
        redrawn: one axis triad at the segmentation target and one at the hand
        base of every environment.
        """
        # Bump counters first so compute_observations/compute_reward see the new values.
        self.progress_buf += 1
        self.randomize_buf += 1

        self.compute_observations()
        self.compute_reward(self.actions)

        if not (self.viewer and self.debug_viz):
            return

        # Wipe the previous frame's lines and pull fresh rigid-body states before drawing.
        self.gym.clear_lines(self.viewer)
        self.gym.refresh_rigid_body_state_tensor(self.sim)

        for env_idx in range(self.num_envs):
            env_handle = self.envs[env_idx]
            self.add_debug_lines(env_handle, self.segmentation_target_pos[env_idx], self.segmentation_target_rot[env_idx])
            self.add_debug_lines(env_handle, self.hand_base_pos[env_idx], self.hand_base_rot[env_idx])

    def add_debug_lines(self, env, pos, rot):
        """Draw an axis triad of length 0.2 at ``pos``, oriented by quaternion ``rot``.

        The x, y and z axes are drawn in red, green and blue respectively, as one
        line each via ``self.gym.add_lines`` on ``self.viewer``.
        """
        origin = pos.cpu().numpy()
        axes_and_colors = (
            ([1, 0, 0], [0.85, 0.1, 0.1]),
            ([0, 1, 0], [0.1, 0.85, 0.1]),
            ([0, 0, 1], [0.1, 0.1, 0.85]),
        )
        for axis, color in axes_and_colors:
            tip = (pos + quat_apply(rot, to_torch(axis, device=self.device) * 0.2)).cpu().numpy()
            segment = [origin[0], origin[1], origin[2], tip[0], tip[1], tip[2]]
            self.gym.add_lines(self.viewer, env, 1, segment, color)

    def camera_rgb_visulization(self, camera_tensors, env_id=0, is_depth_image=False):
        """Return env ``env_id``'s camera image as a NumPy array ready for ``cv2.imshow``.

        The image is copied off the tensor and passed through
        ``cv2.COLOR_BGR2RGB``, which swaps the first and third channels.
        ``is_depth_image`` is accepted for interface compatibility but unused.
        """
        frame = camera_tensors[env_id].clone().cpu().numpy()
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def camera_segmentation_visulization(self, camera_tensors, camera_seg_tensors, segmentation_id=None, env_id=0, is_depth_image=False):
        """Return env ``env_id``'s camera image with all pixels outside one segment blacked out.

        Args:
            camera_tensors: per-env colour image tensors, indexed by env id.
            camera_seg_tensors: per-env segmentation tensors aligned with
                ``camera_tensors`` (used as a boolean mask over its leading dims).
            segmentation_id: segment label to keep. ``None`` (the default) uses
                ``self.segmentation_id``; an explicit value is now honoured
                instead of being silently ignored.
            env_id: which environment's image to render.
            is_depth_image: unused; kept for interface compatibility.

        Returns:
            NumPy image passed through ``cv2.COLOR_BGR2RGB`` for OpenCV display.
        """
        if segmentation_id is None:
            segmentation_id = self.segmentation_id

        # Clone so masking does not write into the live camera tensor.
        image = camera_tensors[env_id].clone()
        seg = camera_seg_tensors[env_id]  # read-only, no copy needed
        # Zero every channel of pixels whose label differs from the target segment.
        image[seg != segmentation_id] = 0

        return cv2.cvtColor(image.cpu().numpy(), cv2.COLOR_BGR2RGB)

    def compute_emergence_reward(self, camera_tensors, camera_seg_tensors, segmentation_id=0):
        """Update the per-env "emergence" reward from segmentation pixel counts.

        For each of the first ``self.num_envs`` envs, counts the pixels of its
        segmentation image equal to ``segmentation_id`` and writes the count into
        ``self.emergence_pixel``. Then sets ``self.emergence_reward`` to twice the
        change since the previous call and stores a copy of the counts in
        ``self.last_emergence_pixel``.

        ``camera_tensors`` is unused; kept for interface compatibility.
        NOTE(review): assumes every env's segmentation image has the same shape
        (required by ``torch.stack``).
        """
        if self.num_envs > 0:
            # One batched comparison instead of a per-env loop, which forced a
            # host synchronisation for every env through boolean indexing.
            seg_batch = torch.stack([camera_seg_tensors[i] for i in range(self.num_envs)])
            pixel_counts = (seg_batch == segmentation_id).flatten(start_dim=1).sum(dim=1)
            self.emergence_pixel[:self.num_envs] = pixel_counts

        self.emergence_reward = (self.emergence_pixel - self.last_emergence_pixel) * 2
        self.last_emergence_pixel = self.emergence_pixel.clone()

    def compute_heap_movement_penalty(self, all_lego_brick_pos):
        """Count, per env, the bricks that satisfy ``|x - 1| > 0.25`` AND ``|y| > 0.35``.

        The per-env count (same dtype as the positions) is stored in
        ``self.heap_movement_penalty``. A copy of the positions this call used is
        stored in ``self.last_all_lego_brick_pos``.

        Args:
            all_lego_brick_pos: brick positions indexed ``[env, brick, xyz]``;
                only the first ``self.num_envs`` envs are considered.
        """
        pos = all_lego_brick_pos[:self.num_envs]
        outside = (torch.abs(pos[:, :, 0] - 1) > 0.25) & (torch.abs(pos[:, :, 1]) > 0.35)
        self.heap_movement_penalty = torch.sum(outside.to(pos.dtype), dim=1, keepdim=False)
        # self.heap_movement_penalty = torch.where(self.emergence_reward < 0.05, torch.mean(torch.norm(all_lego_brick_pos - last_all_lego_brick_pos, p=2, dim=-1), dim=-1, keepdim=False), torch.zeros_like(self.heap_movement_penalty))

        # Remember the positions this penalty was actually computed from (the
        # argument), rather than self.all_lego_brick_pos, which may differ.
        self.last_all_lego_brick_pos = all_lego_brick_pos.clone()
        
#####################################################################
###=========================jit functions=========================###
#####################################################################

@torch.jit.script
def compute_hand_grasp_reward(
    spin_coef, rew_buf, reset_buf, reset_goal_buf, progress_buf, successes, consecutive_successes, max_hand_reset_length: int, arm_contacts,
    max_episode_length: float, object_pos, object_rot, object_angvel, target_pos, target_rot, segmentation_target_pos, hand_base_pos, emergence_reward, arm_hand_ff_pos, arm_hand_rf_pos, arm_hand_mf_pos, arm_hand_th_pos, heap_movement_penalty, segmentation_target_init_pos,
    dist_reward_scale: float, rot_reward_scale: float, rot_eps: float,
    actions, action_penalty_scale: float,
    success_tolerance: float, reach_goal_bonus: float, fall_dist: float,
    fall_penalty: float, rotation_id: int, max_consecutive_successes: int, av_factor: float, ignore_z_rot: bool, segmentation_target_linvel
):
    # Distance from the hand to the object
    # goal_dist = torch.norm(hand_base_pos - segmentation_target_pos, p=2, dim=-1)
    # dist_rew = goal_dist * dist_reward_scale

    arm_hand_finger_dist = (torch.norm(segmentation_target_pos - arm_hand_ff_pos, p=2, dim=-1) + torch.norm(segmentation_target_pos - arm_hand_mf_pos, p=2, dim=-1)
                            + torch.norm(segmentation_target_pos - arm_hand_rf_pos, p=2, dim=-1) + torch.norm(segmentation_target_pos - arm_hand_th_pos, p=2, dim=-1))
    # dist_rew = torch.exp(-2 * arm_hand_finger_dist)
    dist_rew = torch.clamp(- arm_hand_finger_dist, None, -0.3)

    # Total reward is: position distance + orientation alignment + action regularization + success bonus + fall penalty
    # emergence_reward = torch.where(arm_hand_finger_dist < 2, emergence_reward / 10, torch.zeros_like(emergence_reward))

    # object_up_reward = 1 / ((torch.clamp_min(abs(0.2 - torch.clamp_min(segmentation_target_pos[:, 2]-segmentation_target_init_pos[:, 2], 0)), 0)) + 0.02) - 1/0.2
    object_up_reward = (0.1 - (0.1 - torch.clamp(segmentation_target_pos[:, 2]-segmentation_target_init_pos[:, 2], 0, None))) * 100
    object_up_reward = torch.clamp(torch.where(arm_hand_finger_dist < 0.5, object_up_reward, torch.zeros_like(object_up_reward)), min=None, max=30) 

    success_bonus = torch.zeros_like(object_up_reward)
    success_bonus = torch.where(segmentation_target_pos[:, 2]-segmentation_target_init_pos[:, 2] > 0.1, 800, 0)

    # dist_rew = torch.where(progress_buf % max_hand_reset_length == 0, dist_rew, torch.zeros_like(dist_rew))
    heap_movement_penalty = torch.clamp(heap_movement_penalty, min=0, max=5)

    move_out_penalty = torch.zeros_like(object_up_reward)
    move_out_penalty = torch.where(abs(hand_base_pos[:, 0] - 0.1) >= 0.3, torch.ones_like(move_out_penalty) * 3, move_out_penalty)
    move_out_penalty = torch.where(abs(hand_base_pos[:, 1]) >= 0.4, torch.ones_like(move_out_penalty) * 3, move_out_penalty)

    action_penalty = torch.sum(actions ** 2, dim=-1) * 0.001
    arm_contacts_penalty = torch.sum(arm_contacts, dim=-1)

    resets = torch.where(arm_hand_finger_dist <= -1, torch.ones_like(reset_buf), reset_buf)

    timed_out = progress_buf >= max_episode_length - 1
    resets = torch.where(timed_out, torch.ones_like(resets), resets)
    success_bonus = torch.where(resets,
                        torch.where(arm_hand_finger_dist < 0.5, success_bonus, torch.zeros_like(success_bonus)), torch.zeros_like(success_bonus))
    # print("segmentation_target_pos[:, 2]-segmentation_target_init_pos[:, 2]: ", segmentation_target_pos[0, 2]-segmentation_target_init_pos[0, 2])

    reward = - move_out_penalty + dist_rew + success_bonus - action_penalty + object_up_reward - arm_contacts_penalty
    # if reward[0] != 0:
        # print(dist_rew[0])
        # print("object_up_reward: ", object_up_reward[0])
        # print("dist_rew: ", dist_rew[0])
        # print("move_out_penalty: ", move_out_penalty[0])
        # print("success_bonus: ", success_bonus[0])
        # print("arm_contacts_penalty: ", arm_contacts_penalty[0])

    # print(object_up_reward[0])
    # # print(spin_reward[0])
    # print("----finish----")

    # Fall penalty: distance to the goal is larger than a threshold
    # Check env termination conditions, including maximum success number

    # Apply penalty for not reaching the goal
    if max_consecutive_successes > 0:
        reward = torch.where(timed_out, reward + 0.5 * fall_penalty, reward)

    num_resets = torch.sum(resets)
    finished_cons_successes = torch.sum(successes * resets.float())

    cons_successes = torch.where(num_resets > 0, av_factor*finished_cons_successes/num_resets + (1.0 - av_factor)*consecutive_successes, consecutive_successes)

    return reward, resets, reset_goal_buf, progress_buf, successes, cons_successes

@torch.jit.script
def compute_hand_retri_reward(
    spin_coef, rew_buf, reset_buf, reset_goal_buf, progress_buf, successes, consecutive_successes, max_hand_reset_length: int, arm_contacts, palm_contacts_z, predict_success_confident,
    max_episode_length: float, object_pos, object_rot, object_angvel, target_pos, target_rot, segmentation_target_pos, hand_base_pos, emergence_reward, arm_hand_ff_pos, arm_hand_rf_pos, arm_hand_mf_pos, arm_hand_th_pos, heap_movement_penalty,
    dist_reward_scale: float, rot_reward_scale: float, rot_eps: float,
    actions, action_penalty_scale: float,
    success_tolerance: float, reach_goal_bonus: float, fall_dist: float,
    fall_penalty: float, rotation_id: int, max_consecutive_successes: int, av_factor: float, ignore_z_rot: bool, init_heap_movement_penalty
):
    """Per-step reward for the retrieval stage.

    Dense terms every step:
      * dist_rew: -0.02 * summed distance from the four finger positions to the target.
      * arm_contacts_penalty: sum of ``arm_contacts`` over the last dim.
      * palm_contacts_penalty: max(palm_contacts_z / 100, 0).
    Sparse terms, only on steps where ``progress_buf % max_hand_reset_length == 0``:
      * emergence_reward (passed through).
      * heap_movement_penalty: clamp(heap_movement_penalty - init_heap_movement_penalty, 0, 15).
      * predict_success_reward: predict_success_confident[:, 0] clamped to [-2.5, 2.5].

    ``env_reward`` is the sum of everything except ``predict_success_reward``.
    ``reward`` is ``env_reward + predict_success_reward``, plus ``0.5 * fall_penalty``
    on timed-out steps when ``max_consecutive_successes > 0``. Envs reset once
    ``progress_buf >= max_episode_length``. Many arguments (including ``actions``)
    are unused and kept only for call-site compatibility.

    Returns:
        (reward, resets, reset_goal_buf, progress_buf, successes, cons_successes,
         env_reward, v_value_reward)
    """
    arm_hand_finger_dist = (torch.norm(segmentation_target_pos - arm_hand_ff_pos, p=2, dim=-1) + torch.norm(segmentation_target_pos - arm_hand_mf_pos, p=2, dim=-1)
                            + torch.norm(segmentation_target_pos - arm_hand_rf_pos, p=2, dim=-1) + torch.norm(segmentation_target_pos - arm_hand_th_pos, p=2, dim=-1))
    dist_rew = arm_hand_finger_dist * (-0.02)

    arm_contacts_penalty = torch.sum(arm_contacts, dim=-1)
    palm_contacts_penalty = torch.clamp(palm_contacts_z / 100, 0, None)

    # Sparse terms are only paid every max_hand_reset_length steps.
    is_eval_step = progress_buf % max_hand_reset_length == 0
    emergence_reward = torch.where(is_eval_step, emergence_reward, torch.zeros_like(emergence_reward))
    heap_movement_penalty = torch.where(is_eval_step, torch.clamp(heap_movement_penalty - init_heap_movement_penalty, min=0, max=15), torch.zeros_like(heap_movement_penalty))

    predict_success_reward = torch.where(is_eval_step, predict_success_confident[:, 0], torch.zeros_like(predict_success_confident[:, 0]))
    predict_success_reward = torch.clamp(predict_success_reward, -2.5, 2.5)

    # Single source for the environment part of the reward; same evaluation order as before.
    env_reward = emergence_reward - heap_movement_penalty + dist_rew - arm_contacts_penalty - palm_contacts_penalty
    reward = env_reward + predict_success_reward

    # A sum of norms is never <= -1, so this keeps reset_buf as-is; retained for parity.
    resets = torch.where(arm_hand_finger_dist <= -1, torch.ones_like(reset_buf), reset_buf)

    timed_out = progress_buf >= max_episode_length
    resets = torch.where(timed_out, torch.ones_like(resets), resets)

    # Apply penalty for not reaching the goal
    if max_consecutive_successes > 0:
        reward = torch.where(timed_out, reward + 0.5 * fall_penalty, reward)

    num_resets = torch.sum(resets)
    finished_cons_successes = torch.sum(successes * resets.float())

    cons_successes = torch.where(num_resets > 0, av_factor*finished_cons_successes/num_resets + (1.0 - av_factor)*consecutive_successes, consecutive_successes)

    v_value_reward = predict_success_reward

    return reward, resets, reset_goal_buf, progress_buf, successes, cons_successes, env_reward, v_value_reward


@torch.jit.script
def randomize_rotation(rand0, rand1, x_unit_tensor, y_unit_tensor):
    """Return ``quat_mul`` of a ``rand0 * pi`` rotation about ``x_unit_tensor`` and a ``rand1 * pi`` rotation about ``y_unit_tensor``."""
    rot_about_x = quat_from_angle_axis(rand0 * np.pi, x_unit_tensor)
    rot_about_y = quat_from_angle_axis(rand1 * np.pi, y_unit_tensor)
    return quat_mul(rot_about_x, rot_about_y)


@torch.jit.script
def randomize_rotation_pen(rand0, rand1, max_angle, x_unit_tensor, y_unit_tensor, z_unit_tensor):
    """Return ``quat_mul`` of a tilt about ``x_unit_tensor`` and a spin about ``z_unit_tensor``.

    The tilt angle is ``0.5 * pi + rand0 * max_angle`` and the spin angle is ``rand0 * pi``.
    NOTE(review): ``rand1`` and ``y_unit_tensor`` are accepted but unused, so
    ``rand0`` drives both angles -- confirm this is intended.
    """
    tilt = quat_from_angle_axis(0.5 * np.pi + rand0 * max_angle, x_unit_tensor)
    spin = quat_from_angle_axis(rand0 * np.pi, z_unit_tensor)
    return quat_mul(tilt, spin)

def orientation_error(desired, current):
    """Return an orientation error between batched quaternions ``desired`` and ``current``.

    Forms the relative quaternion ``desired * conj(current)``. It returns that
    quaternion's components ``0:3``, multiplied by the sign of its component
    ``3``, with shape ``(N, 3)``.

    Assuming the (x, y, z, w) layout used by ``isaacgym.torch_utils``, this is the
    vector part of the shorter of the two equivalent relative rotations. Because
    ``torch.sign(0) == 0``, a relative rotation with ``w == 0`` yields a zero error.
    """
    relative = quat_mul(desired, quat_conjugate(current))
    w_sign = torch.sign(relative[:, 3]).unsqueeze(-1)
    return relative[:, 0:3] * w_sign

def control_ik(j_eef, device, dpose, num_envs, damping=0.05):
    """Damped-least-squares IK: map a task-space error to joint-space deltas.

    Computes ``u = J^T (J J^T + damping**2 * I)^-1 dpose`` for every env in the batch.

    Args:
        j_eef: batched Jacobians of shape ``(num_envs, m, n_joints)``. The task
            dimension ``m`` is taken from the Jacobian (it used to be hard-coded
            to 6), so e.g. position-only IK with ``m == 3`` also works.
        device: device on which the damping identity matrix is created.
        dpose: task-space error of shape ``(num_envs, m, 1)``.
        num_envs: batch size, used to flatten the result.
        damping: damping factor; the regulariser added to ``J J^T`` is
            ``damping ** 2 * I``. Defaults to the previous fixed value 0.05.

    Returns:
        Tensor of shape ``(num_envs, n_joints)`` with joint-position deltas.
    """
    j_eef_T = torch.transpose(j_eef, 1, 2)
    task_dim = j_eef.shape[1]
    lmbda = torch.eye(task_dim, device=device, dtype=j_eef.dtype) * (damping ** 2)
    u = (j_eef_T @ torch.inverse(j_eef @ j_eef_T + lmbda) @ dpose).view(num_envs, -1)
    return u