from typing import Any, Dict, Union

import numpy as np
import sapien
import torch
from scipy.spatial.transform import Rotation as R

import mani_skill.envs.utils.randomization as randomization
from mani_skill.agents.robots import WidowX250S
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import sapien_utils
from mani_skill.utils.building import actors
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs.pose import Pose

from widowx_expert.agent.widowx import WidowX250SCustom


# Joint configuration written to the robot at every episode reset by the
# environments below (via `self.agent.robot.set_qpos`). Angles in radians.
# It has 8 entries — presumably the arm joints followed by the gripper finger
# joints; TODO confirm the ordering against WidowX250SCustom's active joints.
STRAIGHT_DOWN_QPOS = [0.0, 0.0, 0.0, 0.0, 1.57, 0.0, 0.0, 0.0]


@register_env("WidowXLiftCubeBase-v1", max_episode_steps=50)
class WidowXLiftCubeBase(BaseEnv):
    """
    **Task Description:**
    A simple task where the objective is to grasp a red cube and lift it up to a certain height.

    **Randomizations:**
    - There are 3 candidate positions on the table; at every reset the cube is placed at one of them, chosen uniformly at random
    - the cube's orientation is fixed: a 90 degree rotation about the z-axis (indistinguishable from 0 for a cube)
    - the lift target is a fixed height (`float_thresh`) for the z-coordinate of the cube's center

    **Success Conditions:**
    - the z-coordinate of the cube's center is at or above `float_thresh`

    NOTE: the "WidowXLiftCubeBase-v2" environment defined later in this module
    reuses this class name, so after import the module attribute
    `WidowXLiftCubeBase` refers to the v2 class. Reach this variant through its
    registered id "WidowXLiftCubeBase-v1".
    """

    _sample_video_link = "https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/PickCube-v1_rt.mp4"
    SUPPORTED_ROBOTS = [
        "widowx250s_custom",
    ]
    agent: WidowX250SCustom
    # Half edge length of the cube; also its resting center height on the table.
    cube_half_size = 0.018
    # NOTE(review): not read anywhere in this class; kept for compatibility.
    goal_thresh = 0.1
    # World z the cube's center must reach to count as lifted.
    float_thresh = 0.05

    def __init__(self, *args, robot_uids="widowx250s_custom", robot_init_qpos_noise=0.02, **kwargs):
        """Create the env.

        Args:
            robot_uids: robot to load; must be listed in SUPPORTED_ROBOTS.
            robot_init_qpos_noise: forwarded to TableSceneBuilder.
            *args, **kwargs: forwarded to BaseEnv.
        """
        self.robot_init_qpos_noise = robot_init_qpos_noise
        # Candidate cube (x, y) positions; one row is sampled per env at reset.
        self.cubic_init_position_xy = torch.tensor([[0.05, 0.05], [0.05, -0.05], [-0.05, 0.05]])
        super().__init__(*args, robot_uids=robot_uids, **kwargs)

    @property
    def _default_sensor_configs(self):
        """Single 1024x1024 observation camera with a 90 degree field of view."""
        pose = sapien_utils.look_at(eye=[0.2, 0.05, 0.3], target=[-0.1, 0, 0.1])
        return [CameraConfig("base_camera", pose, 1024, 1024, np.pi / 2, 0.01, 100)]

    @property
    def _default_human_render_camera_configs(self):
        """1024x1024 camera used for human-facing renders/videos."""
        pose = sapien_utils.look_at([1, -1, 0], [0.3, 0.5, 0.35])
        return CameraConfig("render_camera", pose, 1024, 1024, 1, 0.01, 100)

    def _load_agent(self, options: dict):
        # Robot base is placed 0.2 m behind the world origin along -x.
        super()._load_agent(options, sapien.Pose(p=[-0.2, 0, 0]))

    def _load_scene(self, options: dict):
        """Build the table scene and the red cube actor."""
        self.table_scene = TableSceneBuilder(
            self, robot_init_qpos_noise=self.robot_init_qpos_noise
        )
        self.table_scene.build()

        self.cube = actors.build_cube(
            self.scene,
            half_size=self.cube_half_size,
            color=[1, 0, 0, 1],
            name="cube",
            initial_pose=sapien.Pose(p=[0, 0, self.cube_half_size]),
        )

    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
        """Reset the envs listed in `env_idx`: table scene, cube pose and robot joints."""
        with torch.device(self.device):
            b = len(env_idx)
            self.table_scene.initialize(env_idx)

            # Pick one candidate (x, y) per env and rest the cube on the table.
            # The lookup table is built on the CPU in __init__, so move it to the
            # sim device before indexing it with device-side indices.
            candidates = self.cubic_init_position_xy.to(self.device)
            choice = torch.randint(0, len(candidates), (b,))
            xyz = torch.zeros((b, 3))
            xyz[:, :2] = candidates[choice]
            xyz[:, 2] = self.cube_half_size

            # Fixed orientation for every cube: 90 deg about z. scipy returns
            # (x, y, z, w); reorder to (w, x, y, z) before building the Pose.
            quat_xyzw = R.from_euler("z", 90, degrees=True).as_quat()
            quat_wxyz = torch.tensor(quat_xyzw[[3, 0, 1, 2]], dtype=torch.float32)
            qs = quat_wxyz.repeat(b, 1)

            self.cube.set_pose(Pose.create_from_pq(xyz, qs))

            # Put the arm in the fixed start configuration.
            self.agent.robot.set_qpos(STRAIGHT_DOWN_QPOS)

    def _get_obs_extra(self, info: Dict):
        """Extra per-env observations: grasp flag and cube height relative to the target."""
        # in reality some people hack is_grasped into observations by checking if the gripper can close fully or not
        return dict(
            is_grasped=info["is_grasped"],
            # One value per env; non-negative exactly when the task is succeeded.
            # Kept on the sim device (no CPU round trip through torch.tensor).
            object_rel_height=self.cube.pose.p[:, 2] - self.float_thresh,
        )

    def evaluate(self):
        """Return per-env task flags; success means the cube center is at/above float_thresh."""
        # pose.p is indexed as (env, xyz): compare every env, not just env 0.
        is_obj_picked = self.cube.pose.p[:, 2] >= self.float_thresh
        is_grasped = self.agent.is_grasping(self.cube)
        return {
            "success": is_obj_picked,
            "is_obj_picked": is_obj_picked,
            "is_grasped": is_grasped,
        }

    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
        """Per-env shaped reward for lifting the cube toward float_thresh.

        Equals 1 when the cube center is exactly at the threshold, tends to 0
        far below it and to 2 far above it.
        """
        # Signed distance below the target height (negative once above it).
        obj_to_goal_dist = self.float_thresh - self.cube.pose.p[:, 2]
        return 1 - torch.tanh(5 * obj_to_goal_dist)

    def compute_normalized_dense_reward(
        self, obs: Any, action: torch.Tensor, info: Dict
    ):
        """Dense reward divided by 5.

        NOTE(review): compute_dense_reward is bounded above by 2, so this value
        stays below 0.4; the divisor is kept unchanged to preserve reward scale.
        """
        return self.compute_dense_reward(obs=obs, action=action, info=info) / 5
    




@register_env("WidowXLiftCubeBase-v2", max_episode_steps=50)
class WidowXLiftCubeBase(BaseEnv):
    """
    **Task Description:**
    A simple task where the objective is to grasp a red cube and lift it up to a certain height.

    **Randomizations:**
    - There are 3 candidate positions on the table; at every reset the cube is placed at one of them, chosen uniformly at random
    - the cube's orientation is fixed: a 90 degree rotation about the z-axis (indistinguishable from 0 for a cube)
    - the lift target is a fixed height (`float_thresh`) for the z-coordinate of the cube's center

    **Success Conditions:**
    - none: this variant always reports `success` as False. `is_obj_picked` still
      reports whether the cube's center is at or above `float_thresh`.

    NOTE: this class reuses the name of the "WidowXLiftCubeBase-v1" class defined
    earlier in this module and therefore shadows it at module level.
    """

    _sample_video_link = "https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/PickCube-v1_rt.mp4"
    SUPPORTED_ROBOTS = [
        "widowx250s_custom",
    ]
    agent: WidowX250SCustom
    # Half edge length of the cube; also its resting center height on the table.
    cube_half_size = 0.018
    # NOTE(review): not read anywhere in this class; kept for compatibility.
    goal_thresh = 0.1
    # World z the cube's center must reach to count as lifted.
    float_thresh = 0.05

    def __init__(self, *args, robot_uids="widowx250s_custom", robot_init_qpos_noise=0.02, **kwargs):
        """Create the env.

        Args:
            robot_uids: robot to load; must be listed in SUPPORTED_ROBOTS.
            robot_init_qpos_noise: forwarded to TableSceneBuilder.
            *args, **kwargs: forwarded to BaseEnv.
        """
        self.robot_init_qpos_noise = robot_init_qpos_noise
        # Candidate cube (x, y) positions; one row is sampled per env at reset.
        self.cubic_init_position_xy = torch.tensor([[0.05, 0.05], [0.05, -0.05], [-0.05, 0.05]])
        super().__init__(*args, robot_uids=robot_uids, **kwargs)

    @property
    def _default_sensor_configs(self):
        """Single 1024x1024 observation camera with a 90 degree field of view."""
        pose = sapien_utils.look_at(eye=[0.2, 0.05, 0.3], target=[-0.1, 0, 0.1])
        return [CameraConfig("base_camera", pose, 1024, 1024, np.pi / 2, 0.01, 100)]

    @property
    def _default_human_render_camera_configs(self):
        """1024x1024 camera used for human-facing renders/videos."""
        pose = sapien_utils.look_at([1, -1, 0], [0.3, 0.5, 0.35])
        return CameraConfig("render_camera", pose, 1024, 1024, 1, 0.01, 100)

    def _load_agent(self, options: dict):
        # Robot base is placed 0.2 m behind the world origin along -x.
        super()._load_agent(options, sapien.Pose(p=[-0.2, 0, 0]))

    def _load_scene(self, options: dict):
        """Build the table scene and the red cube actor."""
        self.table_scene = TableSceneBuilder(
            self, robot_init_qpos_noise=self.robot_init_qpos_noise
        )
        self.table_scene.build()

        self.cube = actors.build_cube(
            self.scene,
            half_size=self.cube_half_size,
            color=[1, 0, 0, 1],
            name="cube",
            initial_pose=sapien.Pose(p=[0, 0, self.cube_half_size]),
        )

    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
        """Reset the envs listed in `env_idx`: table scene, cube pose and robot joints."""
        with torch.device(self.device):
            b = len(env_idx)
            self.table_scene.initialize(env_idx)

            # Pick one candidate (x, y) per env and rest the cube on the table.
            # The lookup table is built on the CPU in __init__, so move it to the
            # sim device before indexing it with device-side indices.
            candidates = self.cubic_init_position_xy.to(self.device)
            choice = torch.randint(0, len(candidates), (b,))
            xyz = torch.zeros((b, 3))
            xyz[:, :2] = candidates[choice]
            xyz[:, 2] = self.cube_half_size

            # Fixed orientation for every cube: 90 deg about z. scipy returns
            # (x, y, z, w); reorder to (w, x, y, z) before building the Pose.
            quat_xyzw = R.from_euler("z", 90, degrees=True).as_quat()
            quat_wxyz = torch.tensor(quat_xyzw[[3, 0, 1, 2]], dtype=torch.float32)
            qs = quat_wxyz.repeat(b, 1)

            self.cube.set_pose(Pose.create_from_pq(xyz, qs))

            # Put the arm in the fixed start configuration.
            self.agent.robot.set_qpos(STRAIGHT_DOWN_QPOS)

    def _get_obs_extra(self, info: Dict):
        """Extra per-env observations: grasp flag and cube height relative to the target."""
        # in reality some people hack is_grasped into observations by checking if the gripper can close fully or not
        return dict(
            is_grasped=info["is_grasped"],
            # One value per env; non-negative once the cube is lifted high enough.
            # Kept on the sim device (no CPU round trip through torch.tensor).
            object_rel_height=self.cube.pose.p[:, 2] - self.float_thresh,
        )

    def evaluate(self):
        """Return per-env task flags; `success` is always False in this variant."""
        # pose.p is indexed as (env, xyz): compare every env, not just env 0.
        is_obj_picked = self.cube.pose.p[:, 2] >= self.float_thresh
        is_grasped = self.agent.is_grasping(self.cube)
        return {
            # Never succeed, but with the same per-env shape/device as the other
            # flags (a bare torch.tensor(False) is a 0-d CPU tensor).
            "success": torch.zeros_like(is_obj_picked),
            "is_obj_picked": is_obj_picked,
            "is_grasped": is_grasped,
        }

    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
        """Per-env shaped reward for lifting the cube toward float_thresh.

        Equals 1 when the cube center is exactly at the threshold, tends to 0
        far below it and to 2 far above it.
        """
        # Signed distance below the target height (negative once above it).
        obj_to_goal_dist = self.float_thresh - self.cube.pose.p[:, 2]
        return 1 - torch.tanh(5 * obj_to_goal_dist)

    def compute_normalized_dense_reward(
        self, obs: Any, action: torch.Tensor, info: Dict
    ):
        """Dense reward divided by 5.

        NOTE(review): compute_dense_reward is bounded above by 2, so this value
        stays below 0.4; the divisor is kept unchanged to preserve reward scale.
        """
        return self.compute_dense_reward(obs=obs, action=action, info=info) / 5




