import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
from envs.mujoco_env import MujocoEnv
# from serializable.serializable import Serializable
import os
class AntGoalEnv(MujocoEnv): # , Serializable #utils.EzPickle
    def __init__(self,
                sparse_reward = False, 
                reduced_observation=False, 
                full_state_goal=False,
                goal_sampling_strategy = 'uniform',
                automatically_set_spaces=False,
                reward_by_body_xy_pos = False,
                fixed_goal_qvel = True,
                random_init = False,
                ):
        # Serializable.quick_init(self, locals())
        
        self.sparse_reward = sparse_reward
        self.reduced_observation = reduced_observation
        self.full_state_goal = full_state_goal
        self.goal_sampling_strategy = goal_sampling_strategy
        self.reward_by_body_xy_pos = reward_by_body_xy_pos
        self.fixed_goal_qvel = fixed_goal_qvel
        self.random_init = random_init
        # assert (reduced_observation and not full_state_goal) or (not reduced_observation and full_state_goal)
        self.distance_threshold = 1.0 # 0.5
        MujocoEnv.__init__(self, os.path.dirname(__file__)+'/assets/ant_goal.xml', 5, automatically_set_spaces=automatically_set_spaces)
        # utils.EzPickle.__init__(self)

                
        goal_xy_low = np.array([-5,-5])
        goal_xy_high = np.array([5,5])
        self.goal_xy_space = Box(low = goal_xy_low, high =goal_xy_high)

        qvel_low = -np.ones_like(self.init_qvel)*0.01
        qvel_high = np.ones_like(self.init_qvel)*0.01
        self.goal_qvel_space = Box(low=qvel_low, high=qvel_high)


        # if not self.reduced_observation:
        if self.full_state_goal:
            self.obs_nqpos = self.init_qpos.shape[-1]-2
            qpos = self.init_qpos.copy()
            qpos = qpos[2:]
            if self.fixed_goal_qvel:
                qvel = np.zeros_like(self.init_qvel)
            else :
                qvel = self.goal_qvel_space.sample()
            self._state_goal = np.concatenate([qpos, qvel, self.goal_xy_space.sample()], axis = -1)
        else :
            self.obs_nqpos = 2 # body_xy_pos
            self._state_goal = self.goal_xy_space.sample()
        


        if not automatically_set_spaces:
            self._set_action_space()
            action = np.zeros_like(self.action_space)
            for _ in range(6):
                observation, _reward, done, _info = self.step(action)
            assert not done
            self._set_observation_space(observation)
            
            self.init_qpos_for_goal_sampling = self.data.qpos.copy()

        
        # self.pos_dim = self.init_qpos.shape[-1] -2 + self.goal_xy_space.shape[0]

    
    def _sample_goal(self, goal = None):
        if goal is not None:
            return goal

        if not self.full_state_goal:
            return self.goal_xy_space.sample()

        if self.goal_sampling_strategy == 'fixed_goal_qpos':
            raise NotImplementedError
            assert not self.use_euler
            fixed_goal = self.fixed_goal_qpos
            if self.vel_in_state:
                fixed_goal = np.concatenate((fixed_goal, np.zeros(14)))
            state_goals = np.tile(fixed_goal, (batch_size, 1))
        
        elif self.goal_sampling_strategy == 'uniform':
            # qpos = self.init_qpos.copy()
            qpos = self.init_qpos_for_goal_sampling.copy()
            qpos = qpos[2:]
            if self.fixed_goal_qvel:
                qvel = np.zeros_like(self.init_qvel)
            else :
                qvel= self.goal_qvel_space.sample()
            state_goals = np.concatenate([qpos, qvel, self.goal_xy_space.sample()], axis = -1)

            
        elif self.goal_sampling_strategy == 'uniform_pos_and_rot':
            raise NotImplementedError
            qpos = self.init_qpos.copy().reshape(1, -1)
            qpos = np.tile(qpos, (batch_size, 1))
            qpos[:,:2] = self._sample_uniform_xy(batch_size, mode='goal')

            rots = np.random.randint(4, size=batch_size)
            for i in range(batch_size):
                if rots[i] == 0:
                    qpos[i,3:7] = [1, 0, 0, 0]
                elif rots[i] == 1:
                    qpos[i, 3:7] = [0, 0, 0, 1]
                elif rots[i] == 2:
                    qpos[i, 3:7] = [0.7071068, 0, 0, 0.7071068]
                elif rots[i] == 3:
                    qpos[i, 3:7] = [0.7071068, 0, 0, -0.7071068]

            if self.use_euler:
                pos = self._qpos_to_epos(qpos)
            else:
                pos = qpos

            if self.vel_in_state:
                qvel = np.zeros((batch_size, 14))
                state_goals = np.concatenate((pos, qvel), axis=1)
        elif self.goal_sampling_strategy == 'presampled':
            raise NotImplementedError
            idxs = np.random.randint(
                self.presampled_goals.shape[0], size=batch_size,
            )
            state_goals = self.presampled_goals[idxs, :]
            if self.use_euler:
                qpos = state_goals[:,:15]
                qvel = state_goals[:,15:]
                epos = self._qpos_to_epos(qpos)
                state_goals = np.hstack((epos, qvel))
            if not self.vel_in_state:
                state_goals = state_goals[:, :self.pos_dim]
        else:
            raise NotImplementedError(self.goal_sampling_strategy)    
        
        return state_goals

    def step(self, a):
        # xposbefore = self.get_body_com("torso")[0]
        self.do_simulation(a, self.frame_skip)
        
        ob = self._get_obs()
        done = False
        info = {
            'is_success': self.is_success(ob['achieved_goal'], self._state_goal),
            'l2_distance_to_goal' : np.linalg.norm(ob['desired_goal']-ob['achieved_goal'], ord=2, axis = -1),
            'l1_distance_to_goal' : np.linalg.norm(ob['desired_goal']-ob['achieved_goal'], ord=1, axis = -1),
            'l2_distance_to_goal_of_interest' : np.linalg.norm(ob['desired_goal'][-2:]-ob['achieved_goal'][-2:], ord=2, axis=-1),
            'l1_distance_to_goal_of_interest' : np.linalg.norm(ob['desired_goal'][-2:]-ob['achieved_goal'][-2:], ord=1, axis=-1),
            
        }
        if not self.full_state_goal:
            info.update({'l2_distance_to_goal_for_reward' : info['l2_distance_to_goal'],
                         'l1_distance_to_goal_for_reward' : info['l1_distance_to_goal']})
        elif self.reward_by_body_xy_pos:
            info.update({'l2_distance_to_goal_for_reward' : np.linalg.norm(ob['desired_goal'][-2:]-ob['achieved_goal'][-2:], ord=2, axis=-1),  
                         'l1_distance_to_goal_for_reward' : np.linalg.norm(ob['desired_goal'][-2:]-ob['achieved_goal'][-2:], ord=1, axis=-1),
                         })
        else :
            info.update({'l2_distance_to_goal_for_reward' : np.linalg.norm(
                np.concatenate([ob['desired_goal'][:self.obs_nqpos], ob['desired_goal'][-2:]], axis =-1)-
                np.concatenate([ob['achieved_goal'][:self.obs_nqpos], ob['achieved_goal'][-2:]], axis =-1),
                ord=2, axis = -1
            ),
            'l1_distance_to_goal_for_reward' : np.linalg.norm(
                np.concatenate([ob['desired_goal'][:self.obs_nqpos], ob['desired_goal'][-2:]], axis =-1)-
                np.concatenate([ob['achieved_goal'][:self.obs_nqpos], ob['achieved_goal'][-2:]], axis =-1),
                ord=1, axis = -1
            )})
        
        reward = self.compute_reward(ob['achieved_goal'], self._state_goal, info)
        self._set_goal_marker(self._state_goal)

        return ob, reward, done, info
 
    def is_success(self, achieved_goal, desired_goal):
        d = np.linalg.norm(achieved_goal[-2:]-desired_goal[-2:])
        return (d < self.distance_threshold).astype(np.float32)
    
    # Has no meaning in TDM, LEAP
    def compute_reward(self, achieved_goal, goal, info = None):
        distance = np.linalg.norm(goal[..., -2:]-achieved_goal[..., -2:], axis = -1)
        if not self.sparse_reward:    
            reward = -distance
        else :
            # for batch inputs
            reward = (np.array([distance < self.distance_threshold]).squeeze()).astype(np.float) - 1.0
            # if distance < self.distance_threshold:
            #     reward = 0.0
            # else :
            #     reward = -1.0
        return reward

    def _get_obs(self):
        xpos = np.array([self.get_body_com("torso")[0]]) 
        ypos = np.array([self.get_body_com("torso")[1]])
        # ypos = np.array([self.sim.data.qpos.flat[1]]) 
        body_xy_pos = np.concatenate([xpos, ypos], axis =-1)
        obs = np.concatenate([
                self.sim.data.qpos.flat[2:],
                self.sim.data.qvel.flat,
                # np.clip(self.sim.data.cfrc_ext, -1, 1).flat, # too much increase in obs space
                body_xy_pos
            ])
        if self.reduced_observation:
            obs = body_xy_pos
        else :
            pass

        if self.full_state_goal:
            achieved_goal = obs
        else :
            achieved_goal = body_xy_pos
            
        return {
            'observation' : obs.copy(),
            'achieved_goal' : achieved_goal.copy(),
            'desired_goal' : self._state_goal.copy(), 
        }    
    @property
    def goal(self):
        return self._state_goal.copy()

    # state_goal should be defined in child class
    def set_goal(self, goal):
        self._state_goal = goal
        self._set_goal_marker(goal)

    def convert_goal_for_reward(self, goal):
        if goal.ndim ==1:
            if not self.full_state_goal:
                return goal
            elif self.reward_by_body_xy_pos:
                return goal[-2:]
            else: #exclude qvel and first 2 element of qpos in reward computation in outer wrapper
                return np.concatenate([goal[:self.obs_nqpos], goal[-2:]], axis =-1)
        elif goal.ndim ==2:
            if not self.full_state_goal:
                return goal
            elif self.reward_by_body_xy_pos:
                return goal[:, -2:]
            else: #exclude qvel and first 2 element of qpos in reward computation in outer wrapper
                return np.concatenate([goal[:, :self.obs_nqpos], goal[:, -2:]], axis =-1)
        else :
            raise NotImplementedError
    

    def reset_model(self, goal=None):
        if self.random_init:
            qpos = self.init_qpos[2:] + self.np_random.uniform(size=self.model.nq-2, low=-.1, high=.1)
            qpos = np.concatenate([self.goal_xy_space.sample(), qpos], axis =-1)
        else:
            qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
        self.set_state(qpos, qvel)
        
        xpos = np.array([self.get_body_com("torso")[0]]) 
        ypos = np.array([self.get_body_com("torso")[1]])
        self._state_goal = self._sample_goal(goal)
        body_xy_pos = self._state_goal[-2:]
        dist = np.linalg.norm(body_xy_pos - np.concatenate([xpos, ypos]))
        # while dist < 1:
        #     self._state_goal = self._sample_goal()
        #     body_xy_pos = self._state_goal[-2:]
        #     dist = np.linalg.norm(body_xy_pos- np.concatenate([xpos, ypos]))
        
        self._set_goal_marker(self._state_goal)
        return self._get_obs()

    def viewer_setup(self):
        self.viewer.cam.distance = self.model.stat.extent * 0.5



    def _set_goal_marker(self, goal):
        """
        This should be use ONLY for visualization. Use self._state_goal for
        logging, learning, etc.
        """
        
        self.data.site_xpos[self.model.site_name2id('goal')] = (
            np.concatenate([goal[-2:], np.array([0.5])])
        )
    
    def get_current_goal(self):
        return self._state_goal
    
    def reset_goal(self, goal):
        self._state_goal = goal.copy()
    
    # override for curriculum based goal
    def reset(self, goal = None):
        self.sim.reset()
        ob = self.reset_model(goal)
        return ob
