
import gym
from gym import spaces
import pybullet as pb
import numpy as np
import os
from collections import deque

# Absolute directory of this file (symlinks resolved); used to locate the
# robot URDF that is loaded relative to this module.
FILE_DIR = os.path.dirname(os.path.realpath(__file__))


class ObstacleEnv(gym.Env):
    """Image-based navigation task simulated with PyBullet.

    A block robot loaded from ``block_robot.urdf`` is driven with velocity
    commands on joints 0 and 1 (treated by this class as the x and y axes)
    inside a walled arena.  Two inner walls are shifted along x on every
    reset.  The goal is reached once the y joint position exceeds
    ``goal_thres_y``.

    ``reset`` and ``step`` return observations as a dict with two entries:

    * ``"state"`` -- array ``[x, y]`` holding the two joint positions.
    * ``"video"`` -- the ``n_frames`` most recent RGB camera frames
      concatenated along the channel axis, newest frame first.
    """

    # Side length in pixels of the square top-down camera image.
    _IMAGE_SIZE = 64

    def __init__(self, render=False, n_frames=3):
        """Connect to PyBullet, build the scene and define the gym spaces.

        Args:
            render: If true, connect to the PyBullet GUI; otherwise run
                headless in ``DIRECT`` mode.
            n_frames: Number of recent camera frames stacked into the
                ``"video"`` observation.  Must be at least 1.

        Raises:
            ValueError: If ``n_frames`` is smaller than 1.
        """
        if n_frames < 1:
            # Fail early: an empty frame buffer cannot be concatenated into
            # an observation later on.
            raise ValueError("n_frames must be at least 1, got {}".format(n_frames))

        # Counts calls to step(); it is not cleared by reset().
        self.steps = 0
        self.n_frames = n_frames
        self.action_dim = 2
        self.max_action = 1.0
        self.goal_thres_y = 1.8
        self.frames_buffer = deque(maxlen=self.n_frames)
        # NOTE(review): this instance attribute hides the ``render`` method
        # that gym.Env defines.  Kept as-is because external code may read
        # the flag -- confirm before renaming.
        self.render = render
        if render:
            pb.connect(pb.GUI,
                       options='--background_color_red=1.0 --background_color_green=1.0 --background_color_blue=1.0')
            pb.configureDebugVisualizer(pb.COV_ENABLE_GUI, 0)
        else:
            pb.connect(pb.DIRECT)

        pb.resetSimulation()

        # Enable gravity
        pb.setGravity(0, 0, -10)

        # Add the static ground plane (mass 0 makes the body immovable)
        floor_half_ext = [2.0, 2.0, 0.025]
        floor_visual = pb.createVisualShape(shapeType=pb.GEOM_BOX,
                                            halfExtents=floor_half_ext,
                                            visualFramePosition=[0.0, 0.0, 0.0],
                                            rgbaColor=[0.8, 0.8, 0.9, 0.9])
        floor_collision = pb.createCollisionShape(shapeType=pb.GEOM_BOX,
                                                  halfExtents=floor_half_ext,
                                                  collisionFramePosition=[0.0, 0.0, 0.0])
        pb.createMultiBody(baseMass=0,
                           baseVisualShapeIndex=floor_visual,
                           baseCollisionShapeIndex=floor_collision,
                           basePosition=[0.0, 0.0, 0.025])

        # Add block robot
        offset = [0.0, 0.0, 0.2]
        self.robot = pb.loadURDF(os.path.join(FILE_DIR, "block_robot.urdf"), offset, useFixedBase=True)

        self.goal_pos = np.array([-0.175, 0.0])
        self.goal_radius = 0.25

        # Add the goal region marker.  It is visual only: no collision shape
        # is attached, so the robot can move into it.
        goal_visual = pb.createVisualShape(shapeType=pb.GEOM_BOX,
                                           halfExtents=[2.0, 0.1, 0.05],
                                           visualFramePosition=[0.0, 0.0, 0.0],
                                           rgbaColor=[1.0, 0.2, 0.2, 1.0])
        pb.createMultiBody(baseMass=0,
                           baseVisualShapeIndex=goal_visual,
                           basePosition=[0.0, 1.9, 0.025])

        # Add walls to scene.  The first four entries enclose the arena; the
        # last two are the inner walls that reset() shifts along x.
        self.obstacles_all = []
        wall_pos = [[0.0, -2.05], [0.0, 2.05], [-2.05, 0.0], [2.05, 0.0], [0.0, -0.75], [0.0, 0.5]]
        wall_half_ext = [[2, 0.05], [2, 0.05], [0.05, 2.1], [0.05, 2.1], [1.5, 0.1], [1.5, 0.1]]
        for (pos_x, pos_y), (half_x, half_y) in zip(wall_pos, wall_half_ext):
            wall_visual = pb.createVisualShape(shapeType=pb.GEOM_BOX,
                                               halfExtents=[half_x, half_y, 0.2],
                                               visualFramePosition=[0.0, 0.0, 0.0],
                                               rgbaColor=[0.5, 0.5, 0.5, 1.0])
            wall_collision = pb.createCollisionShape(shapeType=pb.GEOM_BOX,
                                                     halfExtents=[half_x, half_y, 0.2],
                                                     collisionFramePosition=[0.0, 0.0, 0.0])
            wall = pb.createMultiBody(baseMass=0,
                                      baseVisualShapeIndex=wall_visual,
                                      baseCollisionShapeIndex=wall_collision,
                                      basePosition=[pos_x, pos_y, 0.2])
            self.obstacles_all.append(wall)

        # Set camera properties: the camera sits above the origin looking
        # straight down at the arena.
        self.viewMatrix = pb.computeViewMatrix(
            cameraEyePosition=[0, 0.0, 5.1],
            cameraTargetPosition=[0, 0, 0],
            cameraUpVector=[0, 1, 0])

        self.projectionMatrix = pb.computeProjectionMatrixFOV(
            fov=45.0,
            aspect=1.0,
            nearVal=0.1,
            farVal=5.1)

        # Define action space
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

        # Define observation space; shapes are taken from a first observation.
        # NOTE(review): the keys declared here differ from the keys of the
        # dict returned by reset()/step() ("state", "video"), and the video
        # frames are raw camera pixels rather than values in [-1, 1] --
        # confirm what downstream code expects before changing either side.
        obs = self.reset()
        self.observation_space = spaces.Dict(dict(
            desired_goal=spaces.Box(-1.0, 1.0, shape=obs['state'].shape, dtype='float32'),
            achieved_goal=spaces.Box(-1.0, 1.0, shape=obs['state'].shape, dtype='float32'),
            observation=spaces.Box(-1.0, 1.0, shape=obs['video'].shape, dtype='float32'), ))

    def simulate_n_steps(self, n=75):
        """Advance the physics simulation by ``n`` steps."""
        for _ in range(n):
            pb.stepSimulation()

    def step(self, action):
        """Apply a velocity command, simulate, and return the transition.

        Args:
            action: Sequence ``[v_x, v_y]``.  Values are clipped to
                ``[-1, 1]`` and then scaled by 0.5 before being used as
                target joint velocities.

        Returns:
            Tuple ``(obs, reward, done, info)``.  ``reward`` is 1.0 when the
            goal is reached and 0.0 otherwise, ``done`` is true exactly when
            the reward is positive, and ``info`` is an empty dict.
        """
        # np.clip returns a new array, so the caller's action is left
        # untouched by the scaling.
        scaled = np.clip(action, -1.0, 1.0) * 0.5

        # Set action
        for joint_index in (0, 1):
            pb.setJointMotorControl2(self.robot, joint_index, pb.VELOCITY_CONTROL,
                                     targetVelocity=scaled[joint_index], force=0.5)

        # Run simulation
        self.simulate_n_steps()

        # Append observation to image buffer
        self._add_image_to_buffer()

        # Get new observations
        obs = self._get_obs()

        reward = self.is_goal_reached()
        self.steps += 1
        done = reward > 0.0
        info = {}

        return obs, reward, done, info

    def is_goal_reached(self):
        """Return 1.0 if the y joint position exceeds ``goal_thres_y``, else 0.0."""
        p_y = pb.getJointState(self.robot, 1)[0]
        return float(p_y > self.goal_thres_y)

    def _render_rgb_image(self):
        """Render the camera view and return it as an (H, W, 3) RGB array."""
        size = self._IMAGE_SIZE
        rgba = pb.getCameraImage(
            width=size,
            height=size,
            viewMatrix=self.viewMatrix,
            projectionMatrix=self.projectionMatrix)[2]
        # PyBullet builds without NumPy support return the pixels as a flat
        # sequence instead of an (H, W, 4) array; reshape so the channel
        # slicing below works in both cases.
        rgba = np.reshape(np.asarray(rgba, dtype=np.uint8), (size, size, 4))
        # Drop the alpha channel.
        return rgba[:, :, 0:3]

    def _add_image_to_buffer(self):
        """Render a frame and push it to the front of the frame buffer."""
        img = self._render_rgb_image()
        # Copy because the rendered frame is a slice of a larger array.
        self.frames_buffer.appendleft(img.copy())

    def _clear_buffer(self):
        """Remove all frames from the frame buffer."""
        self.frames_buffer.clear()

    def _shift_wall_x(self, wall_id, new_x):
        """Move a wall to ``new_x`` keeping its y, z and orientation."""
        (_, wall_y, wall_z), wall_orn = pb.getBasePositionAndOrientation(wall_id)
        pb.resetBasePositionAndOrientation(wall_id, [new_x, wall_y, wall_z], wall_orn)

    def _set_robot_position(self, position):
        """Place the robot at ``position`` with zero velocity and step once."""
        pb.resetJointState(self.robot, 0, position[0], targetVelocity=0.0)
        pb.resetJointState(self.robot, 1, position[1], targetVelocity=0.0)
        pb.stepSimulation()

    def _robot_touches_wall(self):
        """Return True if the robot has a contact point with any wall."""
        return any(len(pb.getContactPoints(self.robot, wall)) > 0
                   for wall in self.obstacles_all)

    def reset(self, start=None, reset_info=None):
        """Start a new episode and return the first observation.

        Args:
            start: Optional ``[x, y]`` start position of the robot.  If
                omitted, positions are sampled until one is found that is
                neither in the goal region nor touching a wall.
            reset_info: Optional dict.  ``"mode"`` (default ``"train"``):
                when equal to ``"eval"`` a random start position near the
                bottom of the arena is used and ``start`` is ignored.
                ``"context_params"`` (default ``None``): pair of x-positions
                for the two inner walls; sampled uniformly from
                ``[-0.5, 0.5]`` when ``None``.

        Returns:
            Observation dict with the keys ``"state"`` and ``"video"``.
        """
        self._clear_buffer()

        if reset_info is None:
            reset_info = {}

        # Missing keys fall back to the same defaults used when no
        # reset_info is given at all.
        context_params = reset_info.get("context_params")
        if context_params is None:
            obs_pos = np.random.uniform(-0.5, 0.5, 2)
        else:
            obs_pos = context_params

        if reset_info.get("mode", "train") == "eval":
            start = [np.random.uniform(-0.5, 0.5), np.random.uniform(-1.7, -1.5)]

        # Shift the two inner walls along x.
        self._shift_wall_x(self.obstacles_all[-1], obs_pos[0])
        self._shift_wall_x(self.obstacles_all[-2], obs_pos[1])

        # Set initial state
        if start is None:
            # Randomly sample until the robot is outside the goal region and
            # not in contact with any wall.
            while True:
                self._set_robot_position(np.random.uniform(-2.0, 1.75, 2))
                if not self.is_goal_reached() and not self._robot_touches_wall():
                    break
        else:
            self._set_robot_position(start)

        # Fill image buffer
        for _ in range(self.n_frames):
            self.simulate_n_steps()
            self._add_image_to_buffer()

        # Get first observation
        return self._get_obs()

    def _get_obs(self):
        """Build the observation dict from joint positions and buffered frames."""
        pos_x = pb.getJointState(self.robot, 0)[0]
        pos_y = pb.getJointState(self.robot, 1)[0]
        return {
            "state": np.array([pos_x, pos_y]),
            # Frames are stacked along the channel axis, newest first.
            "video": np.concatenate(self.frames_buffer, axis=2),
        }

    def close(self):
        """Disconnect from the physics server; safe to call more than once."""
        if pb.isConnected():
            pb.disconnect()


# Test environment
if __name__ == '__main__':

    # Smoke test: run episodes in the GUI with a clipped random walk in
    # action space.
    env = ObstacleEnv(render=True)
    try:
        for _episode in range(1000):
            env.reset()

            action = np.random.uniform(-1., 1., 2)
            for _step in range(100000):
                action += np.random.normal(0.0, 0.25, 2)
                action = np.clip(action, -1.0, 1.0)
                _, _, done, _ = env.step(action=action)
                if done:
                    # Goal reached: start the next episode instead of
                    # continuing to step a finished one.
                    break
    finally:
        # Always release the physics server connection.
        env.close()
