from ._base_task import Base_Task
from .utils import *
import sapien
import math


class stack_bowls_two(Base_Task):

    def setup_demo(self, **kwags):
        super()._init_task_env_(**kwags)

    def load_actors(self):
        bowl_pose_lst = []
        for i in range(2):
            bowl_pose = rand_pose(
                xlim=[-0.3, 0.3],
                ylim=[-0.15, 0.15],
                qpos=[0.5, 0.5, 0.5, 0.5],
                ylim_prop=True,
                rotate_rand=False,
            )

            def check_bowl_pose(bowl_pose):
                for j in range(len(bowl_pose_lst)):
                    if (np.sum(pow(bowl_pose.p[:2] - bowl_pose_lst[j].p[:2], 2)) < 0.0169):
                        return False
                return True

            while (abs(bowl_pose.p[0]) < 0.09 or np.sum(pow(bowl_pose.p[:2] - np.array([0, -0.1]), 2)) < 0.0169
                   or not check_bowl_pose(bowl_pose)):
                bowl_pose = rand_pose(
                    xlim=[-0.3, 0.3],
                    ylim=[-0.15, 0.15],
                    qpos=[0.5, 0.5, 0.5, 0.5],
                    ylim_prop=True,
                    rotate_rand=False,
                )
            bowl_pose_lst.append(deepcopy(bowl_pose))

        bowl_pose_lst = sorted(bowl_pose_lst, key=lambda x: x.p[1])

        def create_bowl(bowl_pose):
            return create_actor(self, pose=bowl_pose, modelname="002_bowl", model_id=3, convex=True)

        self.bowl1 = create_bowl(bowl_pose_lst[0])
        self.bowl2 = create_bowl(bowl_pose_lst[1])

        self.add_prohibit_area(self.bowl1, padding=0.07)
        self.add_prohibit_area(self.bowl2, padding=0.07)
        target_pose = [-0.1, -0.15, 0.1, -0.05]
        self.prohibited_area.append(target_pose)
        self.bowl1_target_pose = np.array([0, -0.1, 0.76])
        self.quat_of_target_pose =  [0, 0.707, 0.707, 0]

    def move_bowl(self, actor, target_pose):
        actor_pose = actor.get_pose().p
        arm_tag = ArmTag("left" if actor_pose[0] < 0 else "right")

        if self.las_arm is None or arm_tag == self.las_arm:
            self.move(
                self.grasp_actor(
                    actor,
                    arm_tag=arm_tag,
                    contact_point_id=[0, 2][int(arm_tag == "left")],
                    pre_grasp_dis=0.1,
                ))
        else:
            self.move(
                self.grasp_actor(
                    actor,
                    arm_tag=arm_tag,
                    contact_point_id=[0, 2][int(arm_tag == "left")],
                    pre_grasp_dis=0.1,
                ),  # arm_tag
                self.back_to_origin(arm_tag=arm_tag.opposite),  # arm_tag.opposite
            )
        self.move(self.move_by_displacement(arm_tag, z=0.1))
        self.move(
            self.place_actor(
                actor,
                target_pose=target_pose.tolist() + self.quat_of_target_pose,
                arm_tag=arm_tag,
                functional_point_id=0,
                pre_dis=0.09,
                dis=0,
                constrain="align",
            ))
        self.move(self.move_by_displacement(arm_tag, z=0.09))
        self.las_arm = arm_tag
        return arm_tag

    def play_once(self):
        # Initialize last arm used to None
        self.las_arm = None

        # Move bowl1 to position [0, -0.1, 0.76]
        arm_tag1 = self.move_bowl(self.bowl1, self.bowl1_target_pose)
        # Move bowl2 to be 0.05m above bowl1's position
        arm_tag2 = self.move_bowl(self.bowl2, self.bowl1.get_pose().p + [0, 0, 0.05])

        # Store information about the bowls and arms used in the info dictionary
        self.info["info"] = {
            "{A}": f"002_bowl/base3",
            "{B}": f"002_bowl/base3",
            "{a}": str(arm_tag1),
            "{b}": str(arm_tag2),
        }
        return self.info
    
    def get_info(self):
        
        #arm_tag1 = self.move_bowl(self.bowl1, self.bowl1_target_pose)
        
        actor_pose = self.bowl1.get_pose().p
        arm_tag1 = ArmTag("left" if actor_pose[0] < 0 else "right")
        
        # Move bowl2 to be 0.05m above bowl1's position
        #arm_tag2 = self.move_bowl(self.bowl2, self.bowl1.get_pose().p + [0, 0, 0.05])
        
        actor_pose = self.bowl2.get_pose().p
        arm_tag2 = ArmTag("left" if actor_pose[0] < 0 else "right")
        
        info =   {
            "{A}": f"002_bowl/base3",
            "{B}": f"002_bowl/base3",
            "{a}": str(arm_tag1),
            "{b}": str(arm_tag2),
        }
        return info

    def check_success(self):
        bowl1_pose = self.bowl1.get_pose().p
        bowl2_pose = self.bowl2.get_pose().p
        bowl1_pose, bowl2_pose = sorted([bowl1_pose, bowl2_pose], key=lambda x: x[2])
        target_height = [
            0.74 + self.table_z_bias,
            0.77 + self.table_z_bias,
        ]
        eps = 0.02
        eps2 = 0.04
        return (np.all(abs(bowl1_pose[:2] - bowl2_pose[:2]) < eps2)
                and np.all(np.array([bowl1_pose[2], bowl2_pose[2]]) - target_height < eps)
                and self.is_left_gripper_open() and self.is_right_gripper_open())