import gym
from gym import spaces
import pybullet as pb
import numpy as np
import os
from collections import deque

FILE_DIR = os.path.dirname(os.path.realpath(__file__))


class SpiralEnv(gym.Env):
    """Navigate a block robot through a walled maze to a square goal region (PyBullet).

    The robot is loaded from ``block_robot.urdf`` with a fixed base and is moved
    by velocity-controlling its joints 0 and 1 (presumably the x/y prismatic
    joints of the URDF — TODO confirm against the URDF).

    Reward is 1.0 and ``done`` is True on a step where both joint positions lie
    within ``goal_margin`` of ``goal_pos``; otherwise reward is 0.0.

    Observations are dicts with:
      * ``"state"`` — ``np.array([pos_x, pos_y])`` read from joints 0 and 1.
      * ``"video"`` — the buffered 64x64 RGB frames from a top-down camera,
        newest first, concatenated along the channel axis. It is present only
        when the frame buffer is non-empty.
    """

    def __init__(self, render=False, n_frames=3):
        """Connect to PyBullet, build the scene and define the action/observation spaces.

        Args:
            render: If True, connect with the PyBullet GUI (white background,
                side panels hidden); otherwise use a headless DIRECT connection.
            n_frames: Number of most recent camera frames kept in the
                ``"video"`` observation.
        """
        self.steps = 0
        self.n_frames = n_frames
        self.action_dim = 2
        self.max_action = 1.0
        # Bounded frame stack; _add_image_to_buffer puts the newest frame at index 0.
        self.frames_buffer = deque(maxlen=self.n_frames)
        # NOTE(review): this instance attribute shadows gym.Env.render(), so
        # env.render() fails on this environment. Kept for backward compatibility
        # with code that reads it as a flag.
        self.render = render
        if render:
            pb.connect(pb.GUI,
                       options='--background_color_red=1.0 --background_color_green=1.0 --background_color_blue=1.0')
            pb.configureDebugVisualizer(pb.COV_ENABLE_GUI, 0)
        else:
            pb.connect(pb.DIRECT)

        pb.resetSimulation()

        # Enable gravity
        pb.setGravity(0, 0, -10)

        # Static 4 x 4 floor slab spanning z in [0, 0.05].
        block_visual = pb.createVisualShape(shapeType=pb.GEOM_BOX,
                                            halfExtents=[2.0, 2.0, 0.025],
                                            visualFramePosition=np.array([0, 0, 0.0]),
                                            rgbaColor=np.array([0.8, 0.8, 0.9, 0.9]))
        block_collision = pb.createCollisionShape(shapeType=pb.GEOM_BOX,
                                                  halfExtents=[2.0, 2.0, 0.025],
                                                  collisionFramePosition=np.array([0, 0, 0.0]))
        pb.createMultiBody(baseMass=0,
                           baseVisualShapeIndex=block_visual,
                           baseCollisionShapeIndex=block_collision,
                           basePosition=[0.0, 0.0, 0.025])

        # Add block robot
        offset = [0.0, 0.0, 0.2]
        self.robot = pb.loadURDF(os.path.join(FILE_DIR, "block_robot.urdf"), offset, useFixedBase=True)

        # Goal is the axis-aligned square goal_pos +/- goal_margin in joint space.
        self.goal_pos = np.array([-0.175, 0.0])
        self.goal_margin = 0.25

        # Static walls: outer boundary (first four) plus interior maze walls.
        # Each entry is an (x, y) centre and an (x, y) half-extent; all walls are 0.4 tall.
        self.obstacles_all = []
        wall_pos = [[0.0, -2.05], [0.0, 2.05], [-2.05, 0.0], [2.05, 0.0], [-0.5, -1.25], [1.1, 0], [0, 1.25],
                    [-1.25, 0.4], [-0.4, -0.45], [0.25, 0], [-0.1, 0.4]]
        wall_half_ext = [[2, 0.05], [2, 0.05], [0.05, 2.1], [0.05, 2.1], [1.5, 0.1], [0.1, 1.35], [1.2, 0.1],
                         [0.1, 0.95], [0.75, 0.1], [0.1, 0.5], [0.3, 0.1]]
        for pos, half_ext in zip(wall_pos, wall_half_ext):
            block_visual = pb.createVisualShape(shapeType=pb.GEOM_BOX,
                                                halfExtents=[half_ext[0], half_ext[1], 0.2],
                                                visualFramePosition=np.array([0, 0, 0.0]),
                                                rgbaColor=np.array([0.5, 0.5, 0.5, 1.0]))
            block_collision = pb.createCollisionShape(shapeType=pb.GEOM_BOX,
                                                      halfExtents=[half_ext[0], half_ext[1], 0.2],
                                                      collisionFramePosition=np.array([0, 0, 0.0]))
            wall = pb.createMultiBody(baseMass=0,
                                      baseVisualShapeIndex=block_visual,
                                      baseCollisionShapeIndex=block_collision,
                                      basePosition=[pos[0], pos[1], 0.2])

            self.obstacles_all.append(wall)

        # Red goal marker: visual shape only (no collision shape), so it does not block the robot.
        block_visual = pb.createVisualShape(shapeType=pb.GEOM_BOX,
                                            halfExtents=[self.goal_margin, self.goal_margin, 0.1],
                                            visualFramePosition=np.array([0, 0, 0.0]),
                                            rgbaColor=np.array([1, 0.0, 0.0, 1]))
        pb.createMultiBody(baseMass=0,
                           baseVisualShapeIndex=block_visual,
                           basePosition=[self.goal_pos[0], self.goal_pos[1], 0.05])

        # Top-down camera 5 units above the origin, with +y pointing up in the image.
        self.viewMatrix = pb.computeViewMatrix(
            cameraEyePosition=[0, 0.0, 5],
            cameraTargetPosition=[0, 0, 0],
            cameraUpVector=[0, 1, 0])

        self.projectionMatrix = pb.computeProjectionMatrixFOV(
            fov=45.0,
            aspect=1.0,
            nearVal=0.1,
            farVal=5.1)

        # Define action space
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

        # A real observation is produced only to derive the space shapes.
        # NOTE(review): the keys (desired_goal/achieved_goal/observation) and the
        # [-1, 1] float32 bounds do not match what _get_obs returns ("state" and a
        # uint8 "video"); only the shapes are taken from the real observation.
        obs = self.reset()
        self.observation_space = spaces.Dict(dict(
            desired_goal=spaces.Box(-1.0, 1.0, shape=obs['state'].shape, dtype='float32'),
            achieved_goal=spaces.Box(-1.0, 1.0, shape=obs['state'].shape, dtype='float32'),
            observation=spaces.Box(-1.0, 1.0, shape=obs['video'].shape, dtype='float32'), ))

    def simulate_n_steps(self, n=75):
        """Advance the physics simulation by ``n`` calls to ``pb.stepSimulation``."""
        for _ in range(n):
            pb.stepSimulation()

    def step(self, action):
        """Apply a 2-D velocity action, simulate, and return the next transition.

        Args:
            action: Array-like of two values. They are clipped to [-1, 1],
                scaled by 0.5, and used as target velocities for joints 0 and 1.

        Returns:
            ``(obs, reward, done, info)``: ``reward`` is 1.0 and ``done`` is True
            when the goal region is reached, else 0.0 / False; ``info`` is empty.
        """
        # Work on a fresh float array. Integer inputs can then be scaled (in-place
        # float scaling of an int array raises), and the caller's array is never touched.
        action = np.clip(np.asarray(action, dtype=np.float64), -1.0, 1.0) * 0.5

        # Velocity-control both joints with a 0.5 force limit.
        v_x = action[0]
        v_y = action[1]
        pb.setJointMotorControl2(self.robot, 0, pb.VELOCITY_CONTROL, targetVelocity=v_x, force=0.5)
        pb.setJointMotorControl2(self.robot, 1, pb.VELOCITY_CONTROL, targetVelocity=v_y, force=0.5)

        # Run simulation
        self.simulate_n_steps()

        # Append observation to image buffer
        self._add_image_to_buffer()

        # Get new observations
        obs = self._get_obs()
        reward = 0.0
        self.steps += 1
        done = False
        info = {}

        # Check if goal is reached
        if self._is_goal_achieved(obs):
            reward += 1.0
            done = True
        return obs, reward, done, info

    def _is_goal_achieved(self, obs):
        """Return True if ``obs["state"]`` is within ``goal_margin`` of ``goal_pos`` on every axis."""
        return (np.abs(obs["state"] - self.goal_pos) <= self.goal_margin).all()

    def _render_rgb_image(self, width=64, height=64):
        """Render the top-down camera and return a ``(height, width, 3)`` uint8 RGB array.

        ``pb.getCameraImage`` returns the RGBA pixels as a NumPy array only when
        PyBullet is built with NumPy support, and as a flat sequence otherwise.
        Both forms are normalised here before the alpha channel is dropped.
        """
        rgba = pb.getCameraImage(
            width=width,
            height=height,
            viewMatrix=self.viewMatrix,
            projectionMatrix=self.projectionMatrix)[2]
        rgba = np.reshape(np.asarray(rgba, dtype=np.uint8), (height, width, 4))
        return rgba[:, :, 0:3]

    def _add_image_to_buffer(self):
        """Render a frame and push a copy onto the front (newest slot) of the frame buffer."""
        img = self._render_rgb_image()
        self.frames_buffer.appendleft(img.copy())

    def _clear_buffer(self):
        """Remove all buffered frames."""
        self.frames_buffer.clear()

    def _set_robot_position(self, pos):
        """Reset joints 0/1 to ``pos[0]``/``pos[1]`` with zero velocity, then take one physics step."""
        pb.resetJointState(self.robot, 0, pos[0], targetVelocity=0.0)
        pb.resetJointState(self.robot, 1, pos[1], targetVelocity=0.0)
        pb.stepSimulation()

    def _in_collision(self):
        """Return True if PyBullet reports any contact point between the robot and a wall."""
        return any(len(pb.getContactPoints(self.robot, wall)) > 0 for wall in self.obstacles_all)

    def reset(self, start=None, reset_info=None):
        """Place the robot at a start position and refill the frame buffer.

        Args:
            start: Optional (x, y) start position. If None (and not in eval
                mode), a start is sampled uniformly in [-2, 2]^2, rejecting
                samples that are inside the goal region or touching a wall.
            reset_info: Optional mapping. If ``reset_info.get("mode") == "eval"``,
                ``start`` is overridden with (-1.75, -1.75) plus Gaussian noise
                (std 0.1) clipped to +/-0.1 per axis.

        Returns:
            The first observation dict, including a full ``"video"`` frame stack.
        """
        self._clear_buffer()

        if reset_info is not None and reset_info.get("mode") == "eval":
            start = np.array([-1.75, -1.75]) + np.random.normal(0.0, 0.1, 2).clip(-0.1, 0.1)

        if start is None:
            # Rejection-sample until the start is outside the goal and clear of all walls.
            while True:
                initial_pos = np.random.uniform(-2.0, 2.0, 2)
                self._set_robot_position(initial_pos)
                if not self._is_goal_achieved(self._get_obs()) and not self._in_collision():
                    break
        else:
            # An explicit start is used as-is (no collision check).
            self._set_robot_position(start)

        # Fill the buffer with n_frames renders, simulating between consecutive frames.
        for _ in range(self.n_frames):
            self.simulate_n_steps()
            self._add_image_to_buffer()

        return self._get_obs()

    def _get_obs(self):
        """Build the observation dict from joint positions and (if any) buffered frames."""
        obs = dict()
        pos_x = pb.getJointState(self.robot, 0)[0]
        pos_y = pb.getJointState(self.robot, 1)[0]
        obs["state"] = np.array([pos_x, pos_y])
        if self.frames_buffer:
            # Frames are ordered newest first; stacked along the channel axis.
            obs["video"] = np.concatenate(self.frames_buffer, axis=2)
        return obs

    def close(self):
        """Disconnect from the PyBullet server."""
        pb.disconnect()


# Test environment
if __name__ == '__main__':
    # Manual smoke test: drive the robot with a smoothed random walk in the GUI.
    import time

    env = SpiralEnv(render=True)
    try:
        for _episode in range(100):
            env.reset()

            # Start from a random action and perturb it with Gaussian noise each step,
            # so consecutive actions are correlated rather than pure white noise.
            action = np.random.uniform(-1., 1., 2)
            for _step in range(25000):
                action = np.clip(action + np.random.normal(0.0, 0.25, 2), -1.0, 1.0)
                obs, reward, done, info = env.step(action=action)
                if done:
                    # Goal reached: move on to the next episode instead of stepping a finished one.
                    break

                # Slow the loop down so the GUI is watchable.
                time.sleep(0.01)
    finally:
        # Always release the PyBullet connection, even if the loop is interrupted.
        env.close()
