from gym.envs.robotics.fetch.pick_and_place import FetchPickAndPlaceEnv
from gym.envs.robotics.utils import robot_get_obs
from gym.utils import EzPickle
from gym.spaces import Box
import numpy as np
from gym.wrappers import TimeLimit


class PickAndPlace(FetchPickAndPlaceEnv, EzPickle):
    """Fetch pick-and-place task with a fixed goal and a fixed object start position.

    The goal and the object's start x/y position are deterministic, optionally perturbed by uniform noise in
    ``[-x_noise, x_noise]``. Observations are flattened into a single vector so the environment exposes a
    ``Box`` observation space instead of the goal-based ``Dict`` space of the base class.

    Args:
        reward_type: Forwarded to ``FetchPickAndPlaceEnv``.
        with_noise: If True, the goal (all three coordinates) and the object's start x/y position are perturbed
            on every reset (the object is not moved while it touches a gripper finger).
        trajectory: Optional indexable collection of simulator states. When it and ``self.context`` are both
            set, ``_reset_sim`` restores ``trajectory[context]``: one position per robot joint (joints whose
            name starts with ``'robot'``, in model order) followed by the object joint values.
        img_obs: If True, observations are 10 selected state values followed by a flattened 128x128 RGB render
            scaled to [0, 1]; otherwise they are ``observation`` concatenated with ``desired_goal``.
    """

    def __init__(self, reward_type='sparse', with_noise=True, trajectory=None, img_obs=False):
        # Noise is disabled while the base-class constructor runs, so that anything it samples (presumably the
        # initial goal via ``_sample_goal``) uses the deterministic reference positions.
        self.with_noise = False

        FetchPickAndPlaceEnv.__init__(self, reward_type=reward_type)
        # Record the constructor arguments so that pickling / deep-copying re-creates the environment with this
        # configuration instead of the defaults.
        EzPickle.__init__(self, reward_type=reward_type, with_noise=with_noise, trajectory=trajectory,
                          img_obs=img_obs)

        self.with_noise = with_noise
        # Half-width of the uniform noise applied to the goal and to the object's start x/y position.
        self.x_noise = 0.03

        self.img_obs = img_obs
        if self.img_obs:
            # Although a Tuple observation would be more appropriate, we simply flatten the data in order to work
            # with e.g. SAC: 10 state values followed by a 128x128x3 image with values in [0, 1].
            img_size = 128 * 128 * 3
            img_low = np.zeros(img_size)
            img_high = np.ones(img_size)
            self.observation_space = Box(low=np.concatenate((-np.inf * np.ones(10), img_low)),
                                         high=np.concatenate((np.inf * np.ones(10), img_high)))
        else:
            # Flatten the Dict space built by the base class into observation + desired goal.
            self.observation_space = Box(
                low=np.concatenate(
                    (self.observation_space["observation"].low, self.observation_space["desired_goal"].low)),
                high=np.concatenate(
                    (self.observation_space["observation"].high, self.observation_space["desired_goal"].high)))

        self.trajectory = trajectory
        # Index into ``trajectory`` consumed by ``_reset_sim``; None keeps the default start state.
        self.context = None

    def _sample_goal(self):
        """Return the fixed goal position, perturbed per coordinate if ``with_noise`` is set.

        Raises:
            RuntimeError: If the environment has no object.
        """
        if not self.has_object:
            raise RuntimeError("PickAndPlace needs to have an object")
        goal = np.array([1.34193226, 0.89910037, 0.62469975])
        if self.with_noise:
            goal += np.random.uniform(-self.x_noise, self.x_noise, size=(3,))
        return goal.copy()

    def _object_touches_gripper(self):
        """Return True if any active contact is between the object and one of the two gripper fingers."""
        model = self.sim.model
        data = self.sim.data
        # Resolve the geom ids once instead of on every contact.
        object_id = model.geom_name2id("object0")
        finger_ids = (model.geom_name2id("robot0:r_gripper_finger_link"),
                      model.geom_name2id("robot0:l_gripper_finger_link"))
        for i in range(data.ncon):
            contact = data.contact[i]
            if (contact.geom1 == object_id and contact.geom2 in finger_ids) or \
                    (contact.geom2 == object_id and contact.geom1 in finger_ids):
                return True
        return False

    def _reset_sim(self):
        """Restore the initial simulator state and place the robot and object; always returns True.

        Raises:
            RuntimeError: If the environment has no object.
        """
        self.sim.set_state(self.initial_state)

        if not self.has_object:
            raise RuntimeError("PickAndPlace needs to have an object")

        # Start the object at the initial gripper x/y position, shifted by -obj_range along y.
        object_xpos = self.initial_gripper_xpos[:2] - np.array([0., self.obj_range])
        object_qpos = self.sim.data.get_joint_qpos('object0:joint')
        assert object_qpos.shape == (7,)
        object_qpos[:2] = object_xpos
        self.sim.data.set_joint_qpos('object0:joint', object_qpos)

        if self.trajectory is not None and self.context is not None:
            joint_state = self.trajectory[self.context]
            robot_joints = [n for n in self.sim.model.joint_names if n.startswith('robot')]
            for i, joint in enumerate(robot_joints):
                self.sim.data.set_joint_qpos(joint, joint_state[i])
            # The object state follows the robot joint positions in each trajectory entry.
            self.sim.data.set_joint_qpos('object0:joint', joint_state[len(robot_joints):])

        # Propagate the new state so the contact list below reflects it.
        self.sim.forward()

        if self.with_noise:
            object_qpos = self.sim.data.get_joint_qpos('object0:joint')
            assert object_qpos.shape == (7,)
            # Only move the object if it is not in contact with the end effector.
            if not self._object_touches_gripper():
                object_qpos[:2] += np.random.uniform(-self.x_noise, self.x_noise, size=(2,))
                self.sim.data.set_joint_qpos('object0:joint', object_qpos)

        self.sim.forward()
        return True

    def reset(self):
        """Reset the simulation, sample a new goal and return the flattened observation."""
        did_reset_sim = False
        while not did_reset_sim:
            did_reset_sim = self._reset_sim()
        self.goal = self._sample_goal().copy()
        obs = self._get_obs()
        return self.process_observation(obs)

    def step(self, action):
        """Step the base env, add ``success``/``reward`` to ``info`` and end the episode on success."""
        obs, reward, done, info = FetchPickAndPlaceEnv.step(self, action)
        self.set_info_reward(info)
        done = done or info["is_success"] != 0
        return self.process_observation(obs), reward, done, info

    def _viewer_setup(self):
        """Point the viewer camera at a spot 0.3 above the table body."""
        body_id = self.sim.model.body_name2id('table0')
        lookat = self.sim.data.body_xpos[body_id] + np.array([0., 0., 0.3])
        for idx, value in enumerate(lookat):
            self.viewer.cam.lookat[idx] = value
        self.viewer.cam.distance = 1.
        self.viewer.cam.azimuth = 150.
        self.viewer.cam.elevation = -20.

    def set_info_reward(self, info):
        """Mirror ``is_success`` into ``info["success"]`` and store a 0/1 success reward in ``info["reward"]``."""
        info["success"] = info["is_success"]
        if info["is_success"] != 0:
            info["reward"] = 1.
        else:
            info["reward"] = 0.

    def process_observation(self, obs):
        """Flatten a goal-based observation dict into a single vector.

        With ``img_obs`` the result is the 10 selected entries of ``obs["observation"]`` followed by a flattened
        128x128 RGB render divided by 255; otherwise ``obs["observation"]`` followed by ``obs["desired_goal"]``.
        """
        if self.img_obs:
            img_obs = self.render("rgb_array", width=128, height=128).astype(np.float64) / 255.
            return np.concatenate((obs["observation"][[0, 1, 2, 9, 10, 20, 21, 22, 23, 24]], img_obs.reshape(-1)))
        else:
            return np.concatenate([obs["observation"], obs["desired_goal"]], axis=0)


def _move_action(diff, gripper_command):
    """Build a 4-D action that moves the gripper along ``diff`` with the given gripper command.

    Each axis moves with magnitude at least 0.2 in the sign direction of ``diff``. An axis whose remaining
    distance is below 0.01 moves by exactly that remaining distance instead.
    """
    movement = np.sign(diff) * np.clip(np.abs(diff), 0.2, np.inf)
    movement = np.where(np.abs(diff) < 0.01, diff, movement)
    return np.concatenate((movement, [gripper_command]), axis=0)


def generate_demonstration():
    """Roll out a hard-coded pick-and-place demonstration in a noise-free ``PickAndPlace`` environment.

    The scripted controller:
      0. moves the gripper 0.1 above the object (gripper command +1);
      1. descends to 0.01 above the object;
      2. closes the gripper for five steps;
      3. carries the object towards the goal (gripper command -1).

    The rollout ends when the episode is done (success or the 150-step time limit).

    Returns:
        internal_states: One entry per visited state, including the initial one. Each entry is the robot joint
            positions from ``robot_get_obs`` concatenated with the ``object0:joint`` qpos.
        external_states: ``obs[:6]`` per visited state (gripper position followed by object position).
        actions: Array with one row per executed step, so ``len(actions) == len(internal_states) - 1``.
    """
    env = TimeLimit(PickAndPlace(with_noise=False), max_episode_steps=150)

    def internal_state():
        # Robot joint positions followed by the object joint qpos.
        return np.concatenate((robot_get_obs(env.env.sim)[0],
                               env.env.sim.data.get_joint_qpos('object0:joint')), axis=0)

    obs = env.reset()
    internal_states = [internal_state()]
    external_states = [obs[:6]]
    actions = []

    mode = 0
    close_steps = 0
    done = False
    while not done:
        gripper_pos = obs[:3]
        object_pos = obs[3:6]

        # Mode transitions `continue` without stepping. Every iteration that reaches env.step therefore records
        # exactly one action/state pair, and each action is computed from the current observation.
        if mode == 0:
            diff = object_pos + np.array([0., 0., 0.1]) - gripper_pos
            if np.linalg.norm(diff) > 1e-2:
                action = _move_action(diff, 1.)
            else:
                mode = 1
                continue
        elif mode == 1:
            diff = object_pos + np.array([0., 0., 0.01]) - gripper_pos
            if np.linalg.norm(diff) > 1e-2:
                action = _move_action(diff, 1.)
            else:
                mode = 2
                continue
        elif mode == 2:
            if close_steps == 5:
                mode = 3
                continue
            action = np.array([0., 0., 0., -1.])
            close_steps += 1
        else:
            action = _move_action(env.goal - gripper_pos, -1.)

        obs, reward, done, info = env.step(action)
        if mode == 3:
            # Also stop once the sparse reward signals success, but never discard the time-limit termination.
            done = done or reward > -1.

        actions.append(action)
        internal_states.append(internal_state())
        external_states.append(obs[:6])

    return internal_states, external_states, np.array(actions)
