import gym
from gym import spaces
import numpy as np
from widowx_expert.env.widowx_lift_cube import WidowXLiftCubeBase  # assuming the class is inside this file

from mani_skill.utils import common

# class WidowXLiftCubeGymWrapper(gym.Env):
#     """
#     Wrapper class for the WidowXLiftCubeEnv to conform to the Gym environment interface.
#     """

#     def __init__(self, *args, **kwargs):
#         """
#         Initialize the wrapped environment.
#         """
#         # Initialize the WidowXLiftCubeEnv environment
#         self.env = WidowXLiftCubeBase(*args, **kwargs)
        
#         # Define the action and observation space
#         # Assuming the action space is a continuous space of joint angles and the observation space
#         # includes joint positions and other relevant information.
#         self.action_space = None
        
#         self.observation_space = None
    
#     def reset(self):
#         """
#         Reset the environment at the start of an episode.
#         """
#         # Reset the internal WidowXLiftCubeEnv environment and return the initial observation
#         obs = self.env.reset()
#         return obs

#     def step(self, action):
#         """
#         Take a step in the environment by applying the given action.
#         """
#         # Apply the action in the WidowXLiftCubeEnv environment
#         obs, reward, done, info = self.env.step(action)
#         return obs, reward, done, info

#     def render(self, mode="rgb_array"):
#         """
#         Render the environment (optional).
#         """
#         # You can implement rendering based on the underlying environment.
#         self.env.render()

#     def close(self):
#         """
#         Close the environment (optional).
#         """
#         self.env.close()


class WidowXLiftCubeGym(gym.Env):
    """
    Legacy-Gym style wrapper around ``WidowXLiftCubeBase``.

    The wrapper exposes a 4-dimensional action ``[dx, dy, dz, gripper]`` in
    ``[-1, 1]`` and a flat observation vector (declared as 28 ``float32``
    values) assembled in :meth:`_get_obs`.  :meth:`step` returns the legacy
    4-tuple ``(obs, reward, done, info)``.
    """

    # Indices into ``robot.get_links()`` used when building the observation.
    # NOTE(review): these depend on the link ordering of the robot model --
    # confirm them whenever the robot description changes.
    _EE_LINK_INDEX = 6
    _LEFT_FINGER_LINK_INDEX = 11
    _RIGHT_FINGER_LINK_INDEX = 12

    def __init__(self, *args, **kwargs):
        """
        Build the wrapped environment with a fixed single-env CPU configuration.

        Args:
            *args: Accepted for signature compatibility only; positional
                arguments are NOT forwarded to ``WidowXLiftCubeBase``.
            **kwargs: Forwarded to ``WidowXLiftCubeBase`` after the settings
                below are forced.  ``constrained_action_space`` (default
                ``False``) is consumed by the wrapper itself and selects the
                single-axis action mode of :meth:`action_transform`.
        """
        self.constrained_action_space = kwargs.pop("constrained_action_space", False)

        # These settings are always overridden, whatever the caller passed.
        kwargs["control_mode"] = "pd_ee_delta_pose"
        kwargs["render_mode"] = "sensors"
        kwargs["num_envs"] = 1

        kwargs["sensor_configs"] = dict(shader_pack="default")
        kwargs["human_render_camera_configs"] = dict(shader_pack="default")
        kwargs["viewer_camera_configs"] = dict(shader_pack="default")
        kwargs["sim_backend"] = 'cpu'
        kwargs["render_backend"] = 'cpu'

        self.env = WidowXLiftCubeBase(**kwargs)

        # Action: end-effector translation delta (x, y, z) plus one gripper command.
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)

        # Observation: flat vector produced by ``_get_obs``.
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(28,), dtype=np.float32)

    def reset(self, **kwargs):
        """
        Reset the wrapped environment and return the initial flat observation.

        Args:
            **kwargs: Optional keyword arguments forwarded unchanged to the
                wrapped environment's ``reset``.

        Returns:
            The observation vector built by :meth:`_get_obs`.
        """
        obs = self.env.reset(**kwargs)
        return self._get_obs(obs)

    def action_transform(self, action):
        """
        Convert a 4-d wrapper action into the action expected by the simulator.

        Args:
            action: Sequence ``[dx, dy, dz, gripper]``.

        Returns:
            The action produced by the agent controller's ``from_action_dict``.
        """
        # The gripper command is the last element of the action vector.
        gripper_action = action[-1]

        # [x, y, z, roll, pitch, yaw]; the rotational part always stays 0.
        ee_action = np.zeros(6)

        raw_translation = np.asarray(action[:3], dtype=np.float64)
        # Clip to the action space limits in BOTH modes.
        translation = np.clip(raw_translation, -1.0, 1.0)

        if not self.constrained_action_space:
            # Unconstrained: move along x, y and z simultaneously.
            ee_action[:3] = translation
        else:
            # Constrained: keep only the axis with the largest requested
            # magnitude (chosen on the unclipped values so that clipping
            # cannot introduce ties); the other axes stay at 0.
            largest_move_dim = int(np.argmax(np.abs(raw_translation)))
            ee_action[largest_move_dim] = translation[largest_move_dim]

        # Base and body are not driven by this wrapper.
        base_action = np.zeros([2])
        body_action = np.zeros([3])

        action_dict = dict(
            base=base_action,
            arm=ee_action,
            body=body_action,
            gripper=gripper_action
        )

        # Convert the dict entries to tensors, then let the controller
        # assemble the action it expects.
        action_dict = common.to_tensor(action_dict)
        return self.env.agent.controller.from_action_dict(action_dict)

    def _to_python(self, value):
        """
        Convert ``value`` to plain Python/NumPy data where possible.

        Dicts are converted recursively, single-element tensors/arrays become
        Python scalars, multi-element tensors become NumPy arrays, and
        anything else is returned unchanged.
        """
        if isinstance(value, dict):
            return {key: self._to_python(item) for key, item in value.items()}

        to_scalar = getattr(value, "item", None)
        if not callable(to_scalar):
            return value
        try:
            return to_scalar()
        except (ValueError, RuntimeError):
            # ``.item()`` only works for single-element containers.
            if hasattr(value, "cpu"):
                return value.detach().cpu().numpy()
            return value

    def step(self, action):
        """
        Apply ``action`` and return the legacy ``(obs, reward, done, info)`` tuple.

        The ``truncated`` flag returned by the wrapped environment is discarded
        to keep the 4-tuple API; ``done`` is the wrapped ``terminated`` flag.
        """
        sim_action = self.action_transform(action)
        obs, reward, terminated, _truncated, info = self.env.step(sim_action)

        # Convert tensors to plain Python values.
        reward = reward.item()
        terminated = terminated.item()
        # Tolerant conversion: entries that are not single-element tensors no
        # longer raise.
        info = {key: self._to_python(value) for key, value in info.items()}

        return self._get_obs(obs), reward, terminated, info

    def get_render_args(self):
        """Return the keyword arguments callers should pass to :meth:`render`."""
        return {'mode': 'human'}

    def render(self, mode='human', **kwargs):
        """
        Render the wrapped environment.

        ``mode`` and ``kwargs`` are accepted for API compatibility but ignored.
        A 4-dimensional result is reduced to its first element.
        """
        img = self.env.render()
        if len(img.shape) == 4:
            return img[0]
        return img

    def _get_obs(self, old_obs):
        """
        Build the flat observation vector from the current simulator state.

        Args:
            old_obs: Observation returned by the wrapped environment; unused,
                the state is read directly from the simulator instead.

        Returns:
            1-D NumPy array: joint positions, joint velocities, cube position,
            end-effector link position, left and right finger link positions.
        """
        robot = self.env.agent.robot
        links = robot.get_links()  # fetched once and indexed below

        agent_pos = robot.get_qpos().cpu().numpy().flatten()
        agent_vel = robot.get_qvel().cpu().numpy().flatten()
        # ``[0]`` selects the first (only) environment of the batch.
        cube_position = self.env.cube.pose.p[0].cpu().numpy().flatten()
        ee_pos = links[self._EE_LINK_INDEX].pose.p[0].cpu().numpy().flatten()
        left_finger_pos = links[self._LEFT_FINGER_LINK_INDEX].pose.p[0].cpu().numpy().flatten()
        right_finger_pos = links[self._RIGHT_FINGER_LINK_INDEX].pose.p[0].cpu().numpy().flatten()

        return np.concatenate([
            agent_pos,
            agent_vel,
            cube_position,
            ee_pos,
            left_finger_pos,
            right_finger_pos,
        ])

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def _is_success(self):
        """
        Return ``True`` when the cube has been lifted.

        Success means the third component of the cube position (presumably its
        height) is at least ``self.env.float_thresh``.
        """
        is_success = (self.env.cube.pose.p[0][2].item() >= self.env.float_thresh)
        return is_success


class WidowXLiftCubeGymV2(gym.Env):
    """
    Legacy-Gym style wrapper around ``WidowXLiftCubeBase`` with a grasp flag.

    The wrapper exposes a 4-dimensional action ``[dx, dy, dz, gripper]`` in
    ``[-1, 1]`` and a flat ``float32`` observation assembled in
    :meth:`_get_obs`, whose last entry reports whether the agent is grasping
    the cube.  :meth:`step` returns the legacy 4-tuple
    ``(obs, reward, done, info)``.
    """

    # Indices into ``robot.get_links()`` used when building the observation.
    # NOTE(review): these depend on the link ordering of the robot model --
    # confirm them whenever the robot description changes.
    _EE_LINK_INDEX = 6
    _LEFT_FINGER_LINK_INDEX = 11
    _RIGHT_FINGER_LINK_INDEX = 12

    def __init__(self, *args, **kwargs):
        """
        Build the wrapped environment with a fixed single-env CPU configuration.

        Args:
            *args: Accepted for signature compatibility only; positional
                arguments are NOT forwarded to ``WidowXLiftCubeBase``.
            **kwargs: Forwarded to ``WidowXLiftCubeBase`` after the settings
                below are forced.  ``constrained_action_space`` (default
                ``False``) is consumed by the wrapper itself and selects the
                single-axis action mode of :meth:`action_transform`.
        """
        self.constrained_action_space = kwargs.pop("constrained_action_space", False)

        # These settings are always overridden, whatever the caller passed.
        kwargs["control_mode"] = "pd_ee_delta_pose"
        kwargs["render_mode"] = "sensors"
        kwargs["num_envs"] = 1

        kwargs["sensor_configs"] = dict(shader_pack="default")
        kwargs["human_render_camera_configs"] = dict(shader_pack="default")
        kwargs["viewer_camera_configs"] = dict(shader_pack="default")
        kwargs["sim_backend"] = 'cpu'
        kwargs["render_backend"] = 'cpu'

        self.env = WidowXLiftCubeBase(**kwargs)

        # Action: end-effector translation delta (x, y, z) plus one gripper command.
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)

        # Observation: the 28 state features plus one trailing grasp flag
        # appended by ``_get_obs``, hence 29 entries.
        # NOTE(review): assumes ``is_grasping`` yields a single value for the
        # single environment -- confirm.
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(29,), dtype=np.float32)

    def reset(self, **kwargs):
        """
        Reset the wrapped environment and return the initial flat observation.

        Args:
            **kwargs: Optional keyword arguments forwarded unchanged to the
                wrapped environment's ``reset``.

        Returns:
            The observation vector built by :meth:`_get_obs`.
        """
        obs = self.env.reset(**kwargs)
        return self._get_obs(obs)

    def action_transform(self, action):
        """
        Convert a 4-d wrapper action into the action expected by the simulator.

        Args:
            action: Sequence ``[dx, dy, dz, gripper]``.

        Returns:
            The action produced by the agent controller's ``from_action_dict``.
        """
        # The gripper command is the last element of the action vector.
        gripper_action = action[-1]

        # [x, y, z, roll, pitch, yaw]; the rotational part always stays 0.
        ee_action = np.zeros(6)

        raw_translation = np.asarray(action[:3], dtype=np.float64)
        # Clip to the action space limits in BOTH modes.
        translation = np.clip(raw_translation, -1.0, 1.0)

        if not self.constrained_action_space:
            # Unconstrained: move along x, y and z simultaneously.
            ee_action[:3] = translation
        else:
            # Constrained: keep only the axis with the largest requested
            # magnitude (chosen on the unclipped values so that clipping
            # cannot introduce ties); the other axes stay at 0.
            largest_move_dim = int(np.argmax(np.abs(raw_translation)))
            ee_action[largest_move_dim] = translation[largest_move_dim]

        # Base and body are not driven by this wrapper.
        base_action = np.zeros([2])
        body_action = np.zeros([3])

        action_dict = dict(
            base=base_action,
            arm=ee_action,
            body=body_action,
            gripper=gripper_action
        )

        # Convert the dict entries to tensors, then let the controller
        # assemble the action it expects.
        action_dict = common.to_tensor(action_dict)
        return self.env.agent.controller.from_action_dict(action_dict)

    def _to_python(self, value):
        """
        Convert ``value`` to plain Python/NumPy data where possible.

        Dicts are converted recursively, single-element tensors/arrays become
        Python scalars, multi-element tensors become NumPy arrays, and
        anything else is returned unchanged.
        """
        if isinstance(value, dict):
            return {key: self._to_python(item) for key, item in value.items()}

        to_scalar = getattr(value, "item", None)
        if not callable(to_scalar):
            return value
        try:
            return to_scalar()
        except (ValueError, RuntimeError):
            # ``.item()`` only works for single-element containers.
            if hasattr(value, "cpu"):
                return value.detach().cpu().numpy()
            return value

    def step(self, action):
        """
        Apply ``action`` and return the legacy ``(obs, reward, done, info)`` tuple.

        The ``truncated`` flag returned by the wrapped environment is discarded
        to keep the 4-tuple API; ``done`` is the wrapped ``terminated`` flag.
        """
        sim_action = self.action_transform(action)
        obs, reward, terminated, _truncated, info = self.env.step(sim_action)

        # Convert tensors to plain Python values.
        reward = reward.item()
        terminated = terminated.item()
        # Tolerant conversion: entries that are not single-element tensors no
        # longer raise.
        info = {key: self._to_python(value) for key, value in info.items()}

        return self._get_obs(obs), reward, terminated, info

    def get_render_args(self):
        """Return the keyword arguments callers should pass to :meth:`render`."""
        return {'mode': 'human'}

    def render(self, mode='human', **kwargs):
        """
        Render the wrapped environment.

        ``mode`` and ``kwargs`` are accepted for API compatibility but ignored.
        A 4-dimensional result is reduced to its first element.
        """
        img = self.env.render()
        if len(img.shape) == 4:
            return img[0]
        return img

    def _get_obs(self, old_obs):
        """
        Build the flat observation vector from the current simulator state.

        Args:
            old_obs: Observation returned by the wrapped environment; unused,
                the state is read directly from the simulator instead.

        Returns:
            1-D NumPy array: joint positions, joint velocities, cube position,
            end-effector link position, left and right finger link positions,
            and a trailing grasp flag.
        """
        robot = self.env.agent.robot
        links = robot.get_links()  # fetched once and indexed below

        agent_pos = robot.get_qpos().cpu().numpy().flatten()
        agent_vel = robot.get_qvel().cpu().numpy().flatten()
        # ``[0]`` selects the first (only) environment of the batch.
        cube_position = self.env.cube.pose.p[0].cpu().numpy().flatten()
        ee_pos = links[self._EE_LINK_INDEX].pose.p[0].cpu().numpy().flatten()
        left_finger_pos = links[self._LEFT_FINGER_LINK_INDEX].pose.p[0].cpu().numpy().flatten()
        right_finger_pos = links[self._RIGHT_FINGER_LINK_INDEX].pose.p[0].cpu().numpy().flatten()

        # Whether the agent is grasping the cube, converted explicitly to a
        # 1-D float32 array so that concatenation does not depend on implicit
        # tensor-to-array conversion.
        is_grasped = self.env.agent.is_grasping(self.env.cube)
        if hasattr(is_grasped, "cpu"):
            is_grasped = is_grasped.cpu().numpy()
        is_grasped = np.asarray(is_grasped, dtype=np.float32).reshape(-1)

        return np.concatenate([
            agent_pos,
            agent_vel,
            cube_position,
            ee_pos,
            left_finger_pos,
            right_finger_pos,
            is_grasped,
        ])

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def _is_success(self):
        """
        Return whether the cube is both lifted and grasped.

        "Lifted" means the third component of the cube position (presumably
        its height) is at least ``self.env.float_thresh``.  The result is the
        ``&`` of that Python bool with the value returned by ``is_grasping``,
        so its type follows ``is_grasping``.
        NOTE(review): this is likely a one-element tensor rather than a bool
        -- confirm what callers expect.
        """
        is_obj_picked = (self.env.cube.pose.p[0][2].item() >= self.env.float_thresh)
        is_grasped = self.env.agent.is_grasping(self.env.cube)
        is_success = is_obj_picked & is_grasped
        return is_success


# class WidowXLiftCubeObsWrapper(gym.Env):
#     """
#     Wrapper class for the WidowXLiftCubeEnv to conform to the Gym environment interface.
#     """

#     def __init__(self, *args, **kwargs):
#         """
#         Initialize the wrapped environment.
#         """
#         # Initialize the WidowXLiftCubeEnv environment
#         self.env = WidowXLiftCubeBase(*args, **kwargs)
        
#         # Define the action and observation space
#         # Assuming the action space is a continuous space of joint angles and the observation space
#         # includes joint positions and other relevant information.
#         self.action_space = self.env.action_space
        
#         # self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
#         self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(6,), dtype=np.float32)
    
#     def reset(self):
#         """
#         Reset the environment at the start of an episode.
#         """
#         # Reset the internal WidowXLiftCubeEnv environment and return the initial observation
#         obs = self.env.reset()
#         return self._get_obs(obs)

#     def step(self, action):
#         """
#         Take a step in the environment by applying the given action.
#         """
#         # Apply the action in the WidowXLiftCubeEnv environment
#         obs, reward, done, info = self.env.step(action)
#         return self._get_obs(obs), reward, done, info

#     def render(self, mode="rgb_array"):
#         """
#         Render the environment (optional).
#         """
#         # You can implement rendering based on the underlying environment.
#         self.env.render()

#     def _get_obs(self, info):
#         """
#         Process the raw observation info and return the formatted observation
#         in a dictionary form suitable for the Gym environment.
#         """
#         # positions
#         # grip_pos = self.sim.data.get_site_xpos('robot0:grip')
#         # dt = self.sim.nsubsteps * self.sim.model.opt.timestep
#         # grip_velp = self.sim.data.get_site_xvelp('robot0:grip') * dt
#         # robot_qpos, robot_qvel = utils.robot_get_obs(self.sim)
#         # if self.has_object:
#         #     object_pos = self.sim.data.get_site_xpos('object0')
#         #     # rotations
#         #     object_rot = rotations.mat2euler(self.sim.data.get_site_xmat('object0'))
#         #     # velocities
#         #     object_velp = self.sim.data.get_site_xvelp('object0') * dt
#         #     object_velr = self.sim.data.get_site_xvelr('object0') * dt
#         #     # gripper state
#         #     object_rel_pos = object_pos - grip_pos
#         #     object_velp -= grip_velp
#         # else:
#         #     object_pos = object_rot = object_velp = object_velr = object_rel_pos = np.zeros(0)
#         # gripper_state = robot_qpos[-2:]
#         # gripper_vel = robot_qvel[-2:] * dt  # change to a scalar if the gripper is made symmetric

#         # if not self.has_object:
#         #     achieved_goal = grip_pos.copy()
#         # else:
#         #     achieved_goal = np.squeeze(object_pos.copy())

#         # positions
#         # ee pos
#         # box pos
#         # Joint displacement of the gripper fingers
#         # gripper state
#         return info
#         obs = np.concatenate([
#             grip_pos, object_pos.ravel(), object_rel_pos.ravel(), gripper_state, object_rot.ravel(),
#             object_velp.ravel(), object_velr.ravel(), grip_velp, gripper_vel,
#         ])
#         return obs

#     def close(self):
#         """
#         Close the environment (optional).
#         """
#         self.env.close()

# class WidowXLiftCubeActionSpaceWrapper(gym.Env):
#     """
#     Wrapper class for the WidowXLiftCubeEnv to modify the action space.
#     """

#     def __init__(self, *args, **kwargs):
#         """
#         Wrap the WidowXLiftCubeEnv with one of two types of action space:
#             unconstrained:
#                 End effector's movement (x, y, z) + gripper's movement (open, close)
#             constrained:
#                 Same dimensionality, but only the largest of (x, y, z) is allowed to move + gripper's movement
#         """

#         self.constrained_action_space = kwargs.pop("constrained_action_space", False)
#         self.env = WidowXLiftCubeObsWrapper(*args, **kwargs)
        
#         # Define the action space
#         self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)

#     def step(self, action):
#         # Initialize the end-effector (ee) action with zero values
#         ee_action = np.zeros(6)  # [x, y, z, roll, pitch, yaw]. roll, pitch, yaw will remains to be 0
#         gripper_action = 0  # Gripper action: -1 for open, 1 for close, 0 for no action
        
#         # Extract gripper action from the last element of the action
#         gripper_action = action[-1]  # Assuming the gripper action is the last element of the action vector
        
#         # Unconstrained action space: allow all end effector movements
#         if not self.constrained_action_space:
#             ee_action[:3] = action[:3]  # Move in x, y, z directions
#             ee_action[:3] = np.clip(ee_action[:3], -1.0, 1.0)  # Clip to the action space limits
#         else:
#             # Constrained action space: allow only the largest movement (x, y, or z) and the gripper action
#             largest_move_dim = np.argmax(np.abs(action[:3]))  # Find the largest movement in x, y, or z
            
#             # Set all movements to 0, except for the largest one
#             ee_action[:3] = [0.0, 0.0, 0.0]  # Clear all dimensions
#             ee_action[largest_move_dim] = action[largest_move_dim]  # Only move in the largest direction

#         # Create the action dict
#         base_action = np.zeros([2])  # Assuming base action is not used in this case
#         body_action = np.zeros([3])  # Assuming body action is not used in this case
        
#         action_dict = dict(
#             base=base_action,
#             arm=ee_action,
#             body=body_action,
#             gripper=gripper_action
#         )
        
#         # Print the action_dict for debugging purposes
#         print("action_dict", action_dict)
        
#         # Convert the action_dict to tensors if necessary
#         action_dict = common.to_tensor(action_dict)
#         print("action_dict after conversion", action_dict)
        
#         # Use the controller to execute the action
#         action = self.env.agent.controller.from_action_dict(action_dict)
        
#         # Take a step in the environment using the action
#         obs, reward, done, info = self.env.step(action)
        
#         return obs, reward, done, info


    


