import numpy as np
import os
from gym import utils
from .fetch_env import FetchEnv
try:
  import gym_robotics
except:
  import gym.envs.robotics as gym_robotics

# Large ("L") model: the object can move freely, but a wall box around the
# table (the same size as the table) confines both the robot and the object.
# A medium-size ("M") variant of this model is available as
# "../assets/fetch/ergodic_push_walls_M.xml".
MODEL_XML_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "../assets/fetch/ergodic_push_walls.xml",
)


class FetchPushErgodicEnv3(FetchEnv, utils.EzPickle):
    """Fetch push task inside a walled arena the size of the table ("L" layout).

    Compared to v1:
      i)   the virtual limit of object0's position is reduced by 0.04 (along x, y);
      ii)  the x, y joint limits of object0 are hardened (solreflimit, solimplimit);
      iii) the initial position of object0 is "almost" fixed: a small uniform
           jitter around ``_OBJECT_START_XY``.

    With ``reset_at_goal=True`` the roles of start and goal are swapped:
      - object0 is spawned at a location sampled like a goal (around the initial
        gripper position), and the gripper is then closed;
      - the goal is sampled around the object's usual start position.
    """

    # Workspace bounds for the large ("L") arena. The medium ("M") arena used
    # [1.21, 0.59, 0.42] / [1.49, 0.91, 0.55] together with target_range=0.1.
    _DEFAULT_WORKSPACE_MIN = (1.16, 0.54, 0.42)
    _DEFAULT_WORKSPACE_MAX = (1.54, 0.96, 0.55)

    # Nominal (x, y) start of object0 and the half-width of its uniform jitter.
    # Shared by _reset_sim (normal start) and _sample_goal (reset_at_goal goal).
    _OBJECT_START_XY = (1.43, 0.76)
    _OBJECT_START_JITTER = 0.025

    def __init__(self, reward_type="sparse", full_state_goal=False, reset_at_goal=False, workspace_min=None, workspace_max=None):
        """Build the environment.

        Args:
            reward_type: Reward type forwarded to ``FetchEnv``.
            full_state_goal: Forwarded unchanged to ``FetchEnv``.
            reset_at_goal: If True, swap start and goal (see the class docstring).
            workspace_min: Lower workspace bound forwarded to ``FetchEnv``;
                ``None`` selects the large-arena default.
            workspace_max: Upper workspace bound forwarded to ``FetchEnv``;
                ``None`` selects the large-arena default.
        """
        # Set before FetchEnv.__init__, which may already call _reset_sim /
        # _sample_goal; both read this flag.
        self.reset_at_goal = reset_at_goal

        # Previously these parameters were silently ignored in favour of the
        # hard-coded L-arena bounds; those bounds are now only the default.
        ws_min = list(self._DEFAULT_WORKSPACE_MIN) if workspace_min is None else workspace_min
        ws_max = list(self._DEFAULT_WORKSPACE_MAX) if workspace_max is None else workspace_max

        initial_qpos = {
            "robot0:slide0": 0.405,
            "robot0:slide1": 0.48,
            "robot0:slide2": 0.0,
            "object0:joint": [1.33, 0.65, 0.42, 1.0, 0.0, 0.0, 0.0],
        }
        FetchEnv.__init__(
            self,
            MODEL_XML_PATH,
            has_object=True,
            block_gripper=False,
            n_substeps=20,
            gripper_extra_height=0.0,
            target_in_the_air=False,
            target_offset=0.0,
            obj_range=0.0,
            target_range=0.15,  # large ("L") arena; the "M" arena used 0.1
            distance_threshold=0.05,
            initial_qpos=initial_qpos,
            reward_type=reward_type,
            full_state_goal=full_state_goal,
            workspace_min=ws_min,
            workspace_max=ws_max,
            z_range=0.0,
        )
        # Record every constructor argument so that pickling / copying the env
        # rebuilds it with the same configuration (previously only reward_type
        # was recorded).
        utils.EzPickle.__init__(
            self,
            reward_type=reward_type,
            full_state_goal=full_state_goal,
            reset_at_goal=reset_at_goal,
            workspace_min=workspace_min,
            workspace_max=workspace_max,
        )

    def _reset_sim(self):
        """Restore the initial simulator state and place object0.

        Returns:
            True, signalling that the reset succeeded.
        """
        self.sim.set_state(self.initial_state)
        if self.has_object:
            if self.reset_at_goal:
                self._place_object_at_random_goal()
            else:
                self._place_object_at_start()
        self.sim.forward()
        return True

    def _place_object_at_start(self):
        """Put object0 at its nominal start position plus a small uniform jitter."""
        object_xpos = np.array(self._OBJECT_START_XY) + self.np_random.uniform(
            -self._OBJECT_START_JITTER, self._OBJECT_START_JITTER, size=2
        )
        object_qpos = np.array(self.sim.data.get_joint_qpos("object0:joint"))
        assert object_qpos.shape == (7,)
        # Bug fix: the sampled position used to be computed but never written,
        # so the object always stayed where initial_qpos put it.
        object_qpos[:2] = object_xpos
        self.sim.data.set_joint_qpos("object0:joint", object_qpos)

    def _place_object_at_random_goal(self):
        """Spawn object0 at a goal-like location and close the gripper.

        The location is sampled the way FetchEnv samples goals: uniform within
        ``target_range`` of the initial gripper position, at table height.
        The steps are:
          1. Move the mocap just above that location and step the simulation
             with the gripper open.
          2. Teleport the object next to the gripper.
          3. Close the gripper for one simulation step.
        """
        # Pose (position + quaternion) read before stepping; only the position
        # part is overwritten below, so the stored orientation is kept.
        object_qpos = np.array(self.sim.data.get_joint_qpos("object0:joint"))
        assert object_qpos.shape == (7,)

        init_obj_xpos = self.initial_gripper_xpos[:3] + self.np_random.uniform(
            -self.target_range, self.target_range, size=3
        )
        init_obj_xpos += self.target_offset
        init_obj_xpos[2] = self.height_offset
        if self.target_in_the_air and self.np_random.uniform() < 0.5:
            init_obj_xpos[2] += self.np_random.uniform(0, 0.45)

        # Send the end effector slightly above the sampled location.
        gripper_target = init_obj_xpos.copy() + np.array([0, 0, 0.015])
        gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0])
        self.sim.data.set_mocap_pos("robot0:mocap", gripper_target)
        self.sim.data.set_mocap_quat("robot0:mocap", gripper_rotation)

        pos_ctrl = np.zeros(3)  # no positional delta: the mocap target is set directly
        rot_ctrl = [1.0, 0.0, 1.0, 0.0]  # fixed end-effector rotation (quaternion)
        if self.block_gripper:
            self._step_callback()

        for _ in range(5):  # let the end effector travel there with the gripper open
            self.sim.step()

        # Fixed offset of the object from the sampled location.
        # NOTE(review): presumably chosen so the object sits between the
        # fingers; confirm against the model.
        custom_offset = np.array([-0.025, -0.025, -0.025])
        object_qpos[:3] = init_obj_xpos + custom_offset
        # Bug fix: was `self.sim.data_set_joint_qpos` (AttributeError).
        self.sim.data.set_joint_qpos("object0:joint", object_qpos)

        # Close the gripper (ctrl -1 on both fingers) and advance one step.
        gripper_ctrl = np.array([-1.0, -1.0])
        action = np.concatenate([pos_ctrl, rot_ctrl, gripper_ctrl])
        robotics_utils = self._robotics_utils()
        robotics_utils.ctrl_set_action(self.sim, action)
        robotics_utils.mocap_set_action(self.sim, action)
        self.sim.step()

    @staticmethod
    def _robotics_utils():
        """Return the robotics ``utils`` module of whichever package was imported.

        The original code always used ``gym_robotics.envs.utils``. The alias may
        instead be bound to the legacy ``gym.envs.robotics`` package, which has
        no ``envs`` attribute and keeps the module at ``gym_robotics.utils``.
        NOTE(review): confirm both layouts against the installed versions.
        """
        envs = getattr(gym_robotics, "envs", None)
        if envs is not None and hasattr(envs, "utils"):
            return envs.utils
        return gym_robotics.utils

    def _viewer_setup(self):
        """Point the viewer camera at the gripper link."""
        body_id = self.sim.model.body_name2id("robot0:gripper_link")
        lookat = self.sim.data.body_xpos[body_id]
        for idx, value in enumerate(lookat):
            self.viewer.cam.lookat[idx] = value
        self.viewer.cam.distance = 1.5  # closer than the 2.5 used previously
        self.viewer.cam.azimuth = 132.0
        self.viewer.cam.elevation = -14.0

    def _sample_goal(self):
        """Sample the episode goal.

        With ``reset_at_goal`` (and an object), the goal is built as follows:
          - x, y: the object's nominal start, plus the same jitter used by
            ``_place_object_at_start``;
          - z: ``height_offset``.

        Otherwise FetchEnv's sampler is used. Previously, ``reset_at_goal``
        without an object raised ``UnboundLocalError``.

        NOTE(review): the reset_at_goal goal is always 3-D even when
        ``full_state_goal`` is set; confirm that is intended.
        """
        if self.reset_at_goal and self.has_object:
            object_xpos = np.array(self._OBJECT_START_XY) + self.np_random.uniform(
                -self._OBJECT_START_JITTER, self._OBJECT_START_JITTER, size=2
            )
            return np.concatenate([object_xpos, np.array([self.height_offset])], axis=-1)
        return super()._sample_goal()
