
from safety_gym.envs.engine import Engine
from sklearn.utils.validation import check_random_state
import numpy as np
import pickle
from mujoco_py import MjSimState


class Env(Engine):
    """Safety Gym car/goal environment modelled on Safexp-CarGoal1-v0.

    Differences from the stock ``Engine``:

    * a fixed base configuration (car robot, goal task, 8 hazards, no vases,
      flattened observations); caller-supplied ``config`` entries override it;
    * a reduced observation vector (see :meth:`obs`);
    * a reward taken directly from the goal-distance observation
      (see :meth:`reward_bis`) and a lidar-threshold cost (see :meth:`cost`);
    * helpers to snapshot and restore simulator / world state.

    NOTE(review): ``obs``, ``reward_bis`` and ``cost`` rely on fixed positions
    in the flattened observation produced by the base configuration. Config
    overrides that add or remove observation keys, or that set
    ``observation_flatten`` to False, will break them.
    """

    # Base configuration passed to Engine. It lives on the class and is copied
    # per instance, which avoids a shared mutable default argument.
    _BASE_CONFIG = {
        'robot_base': 'xmls/car.xml',
        'task': 'goal',
        'observe_goal_dist': True,
        'observe_goal_lidar': False,
        'observe_box_lidar': False,
        'observe_sensors': True,  # Observe all sensor data from simulator
        'sensors_hinge_joints': False,  # Observe named joint position / velocity sensors
        'sensors_ball_joints': False,  # Observe named balljoint position / velocity sensors
        'sensors_angle_components': False,  # Observe sin/cos theta instead of theta
        'sensors_obs': ['accelerometer', 'velocimeter', 'gyro'],
        'lidar_max_dist': None,
        'observe_hazards': True,
        'observe_vases': False,
        'constrain_hazards': True,
        'hazards_num': 8,
        'vases_num': 0,
        'observation_flatten': True,
        'observe_goal_comp': True,
        'continue_goal': True,  # If true, draw a new goal after achievement
    }

    # Entries 3..9 of the parent's flattened observation are removed by obs().
    # Per the original note these are the ball-joint angle/velocity readings.
    _DROPPED_OBS_SLICE = slice(3, 10)
    # Position of 'goal_dist' in the reduced observation (per the original
    # note) -- TODO confirm if the observation config changes.
    _GOAL_DIST_INDEX = 5
    # First position of 'hazards_lidar' in the reduced observation (per the
    # original note) -- TODO confirm if the observation config changes.
    _HAZARDS_LIDAR_START = 9
    # A lidar reading above this counts as the robot touching a hazard
    # (readings near 1 mean "on the object", near 0 mean "far away").
    _HAZARD_COST_THRESHOLD = 0.9

    def __init__(self, config=None):
        """Build the environment.

        Args:
            config: optional dict of Engine config entries. They override the
                class's base configuration. ``None`` (the default) or ``{}``
                uses the base configuration unchanged.
        """
        merged_config = dict(self._BASE_CONFIG)
        if config:
            merged_config.update(config)
        super(Env, self).__init__(merged_config)
        self._elapsed_steps = 0
        # Pickled world configs recorded by step(); see step() for when.
        self.world_states_pickle = []
        # When True, the next step() starts a fresh world_states_pickle list.
        self.save_world = True  # original note: "For real-system"

    def seed(self, seed=None):
        """Seed the parent env, then rebuild ``np_random`` from ``seed``.

        ``check_random_state`` accepts None, an int, or an existing
        ``RandomState`` instance, so a ``RandomState`` may be passed here.

        Returns:
            A one-element list containing ``seed`` as given.
        """
        super(Env, self).seed(seed)
        self.np_random = check_random_state(seed)
        return [seed]

    def reset(self):
        """Reset the parent env and this wrapper's step counter.

        Also re-arms world snapshotting so the next step() starts a fresh
        ``world_states_pickle`` list.

        Returns:
            The observation from the parent ``reset()``. It is reduced, because
            the parent obtains it through the overridden :meth:`obs`.
        """
        observation = super(Env, self).reset()
        self._elapsed_steps = 0
        self.save_world = True
        return observation

    def reward_bis(self):
        """Return the goal-distance observation, used as the reward.

        Per the original notes, this observation grows as the robot gets
        closer to the goal, so maximizing it moves the robot toward the goal.
        The value is returned as-is: no inversion or negation is applied.
        """
        return self.obs()[self._GOAL_DIST_INDEX]

    def cost(self):
        """Count hazard-lidar bins whose reading exceeds the contact threshold.

        Returns:
            ``{'cost': n}``, where ``n`` is the number of hazard-lidar bins
            above ``_HAZARD_COST_THRESHOLD``. The dict is also stored on
            ``self._cost``.
        """
        # The hazards lidar has one entry per bin (Engine's 'lidar_num_bins'
        # config). Size the slice from it so that every bin is checked.
        start = self._HAZARDS_LIDAR_START
        lidar_hazard_obs = self.obs()[start:start + self.lidar_num_bins]
        cost = {'cost': np.count_nonzero(lidar_hazard_obs > self._HAZARD_COST_THRESHOLD)}
        self._cost = cost
        return cost

    def obs(self):
        """Return the parent's flattened observation with entries 3-9 removed.

        Per the original note, the removed entries are the ball-joint
        angle/velocity sensors. NOTE(review): assumes the parent returns a
        1-D array, which requires ``observation_flatten=True``.
        """
        return np.delete(super(Env, self).obs(), self._DROPPED_OBS_SLICE)

    def _pickle_world(self):
        """Return the current world config, pickled for later rebuilding."""
        return pickle.dumps(self.get_world_state()['world'])

    def step(self, action):
        """Advance one step and replace the reward with :meth:`reward_bis`.

        World snapshots:

        * On the first step after construction or ``reset()``,
          ``world_states_pickle`` is replaced by a one-item list holding the
          current pickled world config.
        * Whenever the parent's ``info`` contains ``'goal_met'``, another
          snapshot is appended.

        The parent's reward is kept in ``info['original_reward']``.

        Returns:
            ``(observation, reward, done, info)``, where ``reward`` is
            ``reward_bis()``.
        """
        if self.save_world:
            self.world_states_pickle = [self._pickle_world()]
            self.save_world = False
        observation, reward, done, info = super(Env, self).step(action)
        info['original_reward'] = reward
        self._elapsed_steps += 1
        if 'goal_met' in info:
            self.world_states_pickle.append(self._pickle_world())
        return observation, self.reward_bis(), done, info

    def flatten_observation(self, observation):
        """Concatenate an observation dict into one flat array.

        Each value is flattened, and the values are joined in sorted-key order.
        """
        return np.concatenate([observation[k].flatten() for k in sorted(observation.keys())])

    def get_state(self):
        """Return copies of qpos and qvel, plus the elapsed-step counter."""
        return {
            'qpos': self.data.qpos.copy(),
            'qvel': self.data.qvel.copy(),
            '_elapsed_steps': self._elapsed_steps,
        }

    def set_numpy_state(self, state):
        """Restore qpos/qvel from a flat vector produced by get_numpy_state().

        The vector is split at the model's qpos length, which is read from the
        current simulator state. Simulation time, ``act`` and ``udd_state``
        are carried over from the current state.

        Raises:
            ValueError: if ``state`` does not have ``len(qpos) + len(qvel)``
                entries.
        """
        old_state = self.sim.get_state()
        flat = np.asarray(state, dtype=np.float64).ravel()
        nq = old_state.qpos.shape[0]
        nv = old_state.qvel.shape[0]
        if flat.shape[0] != nq + nv:
            raise ValueError(
                "state has %d entries, expected %d (qpos %d + qvel %d)"
                % (flat.shape[0], nq + nv, nq, nv))
        new_state = MjSimState(old_state.time, flat[:nq], flat[nq:],
                               old_state.act, old_state.udd_state)
        self.world.sim.set_state(new_state)

    def get_numpy_state(self):
        """Return ``[qpos, qvel]`` concatenated into one squeezed numpy array."""
        state_dict = self.get_state()
        return np.r_[state_dict['qpos'], state_dict['qvel']].squeeze()

    def get_world_state(self):
        """Return a snapshot dict for rebuild_sim_to_state().

        Returns:
            ``{'world': <copy of world_config_dict>, 'steps': <step count>}``.
            The copy is shallow, so nested values are shared with the live
            config.
        """
        return {"world": self.world_config_dict.copy(),
                'steps': self.steps}

    def rebuild_sim_to_state(self, state, old_state=False):
        """Rebuild the whole world from a snapshot made by get_world_state().

        Args:
            state: dict with ``'world'`` (a world config that includes
                ``'robot_rot'``) and ``'steps'``.
            old_state: forwarded as ``state=`` to ``self.world.rebuild``.
                NOTE(review): semantics are defined by World.rebuild; confirm
                there.
        """
        self.world.rebuild(state['world'], state=old_state)
        self.robot_rot = state['world']['robot_rot']
        self.steps = state['steps']
        # Recompute the cached goal distance for the rebuilt layout.
        self.last_dist_goal = self.dist_goal()
        self.update_layout()
