
from .base_task import Base_task
from .utils import *
import sapien
import math

class blocks_stack_hard(Base_task):

    def setup_demo(self, is_test=False, **kwags):
        super()._init(**kwags)
        self.create_table_and_wall()
        self.load_robot()
        self.setup_planner()
        self.load_camera()
        self.pre_move()
        self.load_actors()
        self.step_lim = 850
        self.is_test = is_test

    def pre_move(self):
        render_freq = self.render_freq
        self.render_freq=0
        self.together_open_gripper(save_freq=None)
        self.render_freq = render_freq

    def load_actors(self):
        block_pose = rand_pose(
            xlim=[-0.25,0.25],
            ylim=[-0.15,0.05],
            zlim=[0.76],
            qpos=[0.5, 0.5, 0.5, 0.5],
            ylim_prop=True,
            rotate_rand=True,
            rotate_lim=[0,1.57,0],
        )

        while abs(block_pose.p[0]) < 0.05 or np.sum(pow(block_pose.p[:2] - np.array([0,-0.1]),2)) < 0.0225:
            block_pose = rand_pose(
                xlim=[-0.25,0.25],
                ylim=[-0.15,0.05],
                zlim=[0.76],
                qpos=[0.5, 0.5, 0.5, 0.5],
                ylim_prop=True,
                rotate_rand=True,
                rotate_lim=[0,1.57,0],
            )

        self.block1 = create_box(
            scene = self.scene,
            pose = block_pose,
            half_size=(0.025,0.025,0.025),
            color=(1,0,0),
            name="box"
        )

        block_pose = rand_pose(
            xlim=[-0.25,0.25],
            ylim=[-0.15,0.05],
            zlim=[0.76],
            qpos=[0.5, 0.5, 0.5, 0.5],
            ylim_prop=True,
            rotate_rand=True,
            rotate_lim=[0,1.57,0],
        )

        while abs(block_pose.p[0]) < 0.05 or np.sum(pow(block_pose.p[:2] - self.block1.get_pose().p[:2],2)) < 0.01 \
              or np.sum(pow(block_pose.p[:2] - np.array([0,-0.1]),2)) < 0.0225:
            block_pose = rand_pose(
                xlim=[-0.25,0.25],
                ylim=[-0.15,0.05],
                zlim=[0.76],
                qpos=[0.5, 0.5, 0.5, 0.5],
                ylim_prop=True,
                rotate_rand=True,
                rotate_lim=[0,1.57,0],
            )


        self.block2 = create_box(
            scene = self.scene,
            pose = block_pose,
            half_size=(0.025,0.025,0.025),
            color=(0,1,0),
            name="box"
        )

        block_pose = rand_pose(
            xlim=[-0.25,0.25],
            ylim=[-0.15,0.05],
            zlim=[0.76],
            qpos=[0.5, 0.5, 0.5, 0.5],
            ylim_prop=True,
            rotate_rand=True,
            rotate_lim=[0,1.57,0],
        )

        while abs(block_pose.p[0]) < 0.05 or np.sum(pow(block_pose.p[:2] - self.block1.get_pose().p[:2],2)) < 0.01 or\
              np.sum(pow(block_pose.p[:2] - self.block2.get_pose().p[:2],2)) < 0.01 or np.sum(pow(block_pose.p[:2] - np.array([0,-0.1]),2)) < 0.0225:
            block_pose = rand_pose(
                xlim=[-0.25,0.25],
                ylim=[-0.15,0.05],
                zlim=[0.76],
                qpos=[0.5, 0.5, 0.5, 0.5],
                ylim_prop=True,
                rotate_rand=True,
                rotate_lim=[0,1.57,0],
            )


        self.block3 = create_box(
            scene = self.scene,
            pose = block_pose,
            half_size=(0.025,0.025,0.025),
            color=(0,0,1),
            name="box"
        )
        self.block1.find_component_by_type(sapien.physx.PhysxRigidDynamicComponent).mass = 0.01
        self.block2.find_component_by_type(sapien.physx.PhysxRigidDynamicComponent).mass = 0.01
        self.block3.find_component_by_type(sapien.physx.PhysxRigidDynamicComponent).mass = 0.01

    def move_block(self,actor,id = 0, las_arm = None):
        actor_rpy = actor.get_pose().get_rpy()
        actor_pose = actor.get_pose().p
        actor_euler = math.fmod(actor_rpy[2], math.pi / 2)
        grasp_euler = actor_euler - math.pi/2  if actor_euler > math.pi/4 else actor_euler
        grasp_trans_quat = t3d.euler.euler2quat(0,0,grasp_euler)
        grasp_qpose = t3d.quaternions.qmult(grasp_trans_quat, [-0.5,0.5,-0.5,-0.5]).tolist()
        if actor_pose[0] > 0:
            now_arm = 'right'
            pose1 = list(actor_pose + [0,0,0.2]) + grasp_qpose
            if now_arm == las_arm or las_arm is None:
                if now_arm == las_arm:
                    pose0 = list(self.right_endpose.global_pose.p + [0,0,0.05]) + [-0.5,0.5,-0.5,-0.5]
                    self.right_move_to_pose_with_screw(pose0)
                self.right_move_to_pose_with_screw(pose1)
            else:
                self.together_move_to_pose_with_screw(left_target_pose=self.left_original_pose,right_target_pose=pose1)
            pose1[2] -= 0.05
            self.right_move_to_pose_with_screw(pose1)
            self.close_right_gripper()
            pose1[2] += 0.05
            self.right_move_to_pose_with_screw(pose1)
            traget_pose = [0.01,-0.097,0.95 + id * 0.05,-0.5,0.5,-0.5,-0.5]
            self.right_move_to_pose_with_screw(traget_pose)
            traget_pose[2] -= 0.04
            self.right_move_to_pose_with_screw(traget_pose)
            self.open_right_gripper()
            traget_pose[2] += 0.04
            self.right_move_to_pose_with_screw(traget_pose)
        else:
            now_arm = 'left'
            pose1 = list(actor_pose + [0,0,0.2]) + grasp_qpose
            if now_arm == las_arm or las_arm is None:
                if now_arm == las_arm:
                    pose0 = list(self.left_endpose.global_pose.p + [0,0,0.05]) + [-0.5,0.5,-0.5,-0.5]
                    self.left_move_to_pose_with_screw(pose0)
                self.left_move_to_pose_with_screw(pose1)
            else:
                self.together_move_to_pose_with_screw(left_target_pose=pose1,right_target_pose=self.right_original_pose)
            pose1[2] -= 0.05
            self.left_move_to_pose_with_screw(pose1)
            self.close_left_gripper()
            pose1[2] += 0.05
            self.left_move_to_pose_with_screw(pose1)
            traget_pose = [0,-0.1,0.95 + id * 0.05,-0.5,0.5,-0.5,-0.5]
            self.left_move_to_pose_with_screw(traget_pose)
            traget_pose[2] -= 0.04
            self.left_move_to_pose_with_screw(traget_pose)
            self.open_left_gripper()
            traget_pose[2] += 0.04
            self.left_move_to_pose_with_screw(traget_pose)
        return now_arm
    
    def play_once(self):
        las_arm = self.move_block(self.block1, id = 0, las_arm = None)
        las_arm = self.move_block(self.block2, id = 1, las_arm = las_arm)
        las_arm = self.move_block(self.block3, id = 2, las_arm = las_arm)
        
    def check_success(self):
        block1_pose = self.block1.get_pose().p
        block2_pose = self.block2.get_pose().p
        block3_pose = self.block3.get_pose().p
        target_pose = [0,-0.1]
        if self.is_test:
            target_pose = [block1_pose[0], block1_pose[1]]
        eps = [0.025,0.025,0.01]
        return np.all(abs(block1_pose - np.array(target_pose + [0.765])) < eps) and \
               np.all(abs(block2_pose - np.array(target_pose + [0.815])) < eps) and \
               np.all(abs(block3_pose - np.array(target_pose + [0.865])) < eps) and self.is_left_gripper_open() and self.is_right_gripper_open()