{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Installation\n",
    "\n",
    "!curl -L https://raw.githubusercontent.com/facebookresearch/habitat-sim/main/examples/colab_utils/colab_install.sh | NIGHTLY=true bash -s\n",
    "!wget -c http://dl.fbaipublicfiles.com/habitat/mp3d_example.zip && unzip -o mp3d_example.zip -d /content/habitat-sim/data/scene_datasets/mp3d/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip uninstall --yes pyopenssl\n",
    "!pip install pyopenssl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Colab Setup and Imports { display-mode: \"form\" }\n",
    "# @markdown (double click to see the code)\n",
    "\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "\n",
    "import git\n",
    "import numpy as np\n",
    "from gym import spaces\n",
    "\n",
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%cd \"/content/habitat-lab\"\n",
    "\n",
    "\n",
    "if \"google.colab\" in sys.modules:\n",
    "    # This tells imageio to use the system FFMPEG that has hardware acceleration.\n",
    "    os.environ[\"IMAGEIO_FFMPEG_EXE\"] = \"/usr/bin/ffmpeg\"\n",
    "repo = git.Repo(\".\", search_parent_directories=True)\n",
    "dir_path = repo.working_tree_dir\n",
    "%cd $dir_path\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "import habitat\n",
    "from habitat.core.logging import logger\n",
    "from habitat.core.registry import registry\n",
    "from habitat.sims.habitat_simulator.actions import HabitatSimActions\n",
    "from habitat.tasks.nav.nav import NavigationTask\n",
    "from habitat_baselines.common.baseline_registry import baseline_registry\n",
    "from habitat_baselines.config.default import get_config as get_baselines_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Define Observation Display Utility Function { display-mode: \"form\" }\n",
    "\n",
    "# @markdown A convenient function that displays sensor observations with matplotlib.\n",
    "\n",
    "# @markdown (double click to see the code)\n",
    "\n",
    "\n",
    "# Change to do something like this maybe: https://stackoverflow.com/a/41432704\n",
    "def display_sample(\n",
    "    rgb_obs, semantic_obs=np.array([]), depth_obs=np.array([])\n",
    "):  # noqa: B006\n",
    "    from habitat_sim.utils.common import d3_40_colors_rgb\n",
    "\n",
    "    rgb_img = Image.fromarray(rgb_obs, mode=\"RGB\")\n",
    "\n",
    "    arr = [rgb_img]\n",
    "    titles = [\"rgb\"]\n",
    "    if semantic_obs.size != 0:\n",
    "        semantic_img = Image.new(\n",
    "            \"P\", (semantic_obs.shape[1], semantic_obs.shape[0])\n",
    "        )\n",
    "        semantic_img.putpalette(d3_40_colors_rgb.flatten())\n",
    "        semantic_img.putdata((semantic_obs.flatten() % 40).astype(np.uint8))\n",
    "        semantic_img = semantic_img.convert(\"RGBA\")\n",
    "        arr.append(semantic_img)\n",
    "        titles.append(\"semantic\")\n",
    "\n",
    "    if depth_obs.size != 0:\n",
    "        depth_img = Image.fromarray(\n",
    "            (depth_obs / 10 * 255).astype(np.uint8), mode=\"L\"\n",
    "        )\n",
    "        arr.append(depth_img)\n",
    "        titles.append(\"depth\")\n",
    "\n",
    "    plt.figure(figsize=(12, 8))\n",
    "    for i, data in enumerate(arr):\n",
    "        ax = plt.subplot(1, 3, i + 1)\n",
    "        ax.axis(\"off\")\n",
    "        ax.set_title(titles[i])\n",
    "        plt.imshow(data)\n",
    "    plt.show(block=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup PointNav Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat \"./configs/test/habitat_all_sensors_test.yaml\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    config = habitat.get_config(\n",
    "        config_paths=\"./configs/test/habitat_all_sensors_test.yaml\"\n",
    "    )\n",
    "\n",
    "    try:\n",
    "        env.close()\n",
    "    except NameError:\n",
    "        pass\n",
    "    env = habitat.Env(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    action = None\n",
    "    obs = env.reset()\n",
    "    valid_actions = [\"TURN_LEFT\", \"TURN_RIGHT\", \"MOVE_FORWARD\", \"STOP\"]\n",
    "    interactive_control = False  # @param {type:\"boolean\"}\n",
    "    while action != \"STOP\":\n",
    "        display_sample(obs[\"rgb\"])\n",
    "        print(\n",
    "            \"distance to goal: {:.2f}\".format(\n",
    "                obs[\"pointgoal_with_gps_compass\"][0]\n",
    "            )\n",
    "        )\n",
    "        print(\n",
    "            \"angle to goal (radians): {:.2f}\".format(\n",
    "                obs[\"pointgoal_with_gps_compass\"][1]\n",
    "            )\n",
    "        )\n",
    "        if interactive_control:\n",
    "            action = input(\n",
    "                \"enter action out of {}:\\n\".format(\", \".join(valid_actions))\n",
    "            )\n",
    "            assert (\n",
    "                action in valid_actions\n",
    "            ), \"invalid action {} entered, choose one amongst \" + \",\".join(\n",
    "                valid_actions\n",
    "            )\n",
    "        else:\n",
    "            action = valid_actions.pop()\n",
    "        obs = env.step(\n",
    "            {\n",
    "                \"action\": action,\n",
    "            }\n",
    "        )\n",
    "\n",
    "    env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    print(env.get_metrics())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RL Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    config = get_baselines_config(\n",
    "        \"./habitat_baselines/config/pointnav/ppo_pointnav_example.yaml\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set random seeds\n",
    "if __name__ == \"__main__\":\n",
    "    seed = \"42\"  # @param {type:\"string\"}\n",
    "    steps_in_thousands = \"10\"  # @param {type:\"string\"}\n",
    "\n",
    "    config.defrost()\n",
    "    config.TASK_CONFIG.SEED = int(seed)\n",
    "    config.TOTAL_NUM_STEPS = int(steps_in_thousands)\n",
    "    config.LOG_INTERVAL = 1\n",
    "    config.freeze()\n",
    "\n",
    "    random.seed(config.TASK_CONFIG.SEED)\n",
    "    np.random.seed(config.TASK_CONFIG.SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)\n",
    "    trainer = trainer_init(config)\n",
    "    trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "# @markdown (double click to see the code)\n",
    "\n",
    "# example tensorboard visualization\n",
    "# for more details refer to [link](https://github.com/facebookresearch/habitat-lab/tree/main/habitat_baselines#additional-utilities).\n",
    "\n",
    "try:\n",
    "    from IPython import display\n",
    "\n",
    "    with open(\"./res/img/tensorboard_video_demo.gif\", \"rb\") as f:\n",
    "        display.display(display.Image(data=f.read(), format=\"png\"))\n",
    "except ImportError:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Key Concepts\n",
    "\n",
    "All the concepts link to their definitions:\n",
    "\n",
    "1. [`habitat.sims.habitat_simulator.HabitatSim`](https://github.com/facebookresearch/habitat-lab/blob/main/habitat/sims/habitat_simulator/habitat_simulator.py#L159)\n",
    "Thin wrapper over `habitat_sim` providing seamless integration with experimentation framework.\n",
    "\n",
    "\n",
    "2. [`habitat.core.env.Env`](https://github.com/facebookresearch/habitat-lab/blob/main/habitat/core/env.py)\n",
    "Abstraction for the universe of agent, task and simulator. Agents that you train and evaluate operate inside the environment.\n",
    "\n",
    "\n",
    "3. [`habitat.core.env.RLEnv`](https://github.com/facebookresearch/habitat-lab/blob/71d409ab214a7814a9bd9b7e44fd25f57a0443ba/habitat/core/env.py#L278)\n",
    "Extends the `Env` class for reinforcement learning by defining the reward and other required components.\n",
    "\n",
    "\n",
    "4. [`habitat.core.embodied_task.EmbodiedTask`](https://github.com/facebookresearch/habitat-lab/blob/71d409ab214a7814a9bd9b7e44fd25f57a0443ba/habitat/core/embodied_task.py#L242)\n",
    "Defines the task that the agent needs to solve. This class holds the definition of observation space, action space, measures, simulator usage. Eg: PointNav, ObjectNav.\n",
    "\n",
    "\n",
    "5. [`habitat.core.dataset.Dataset`](https://github.com/facebookresearch/habitat-lab/blob/4b6da1c4f8eb287cea43e70c50fe1d615a261198/habitat/core/dataset.py#L63)\n",
    "Wrapper over information required for the dataset of embodied task, contains definition and interaction with an `episode`.\n",
    "\n",
    "\n",
    "6. [`habitat.core.embodied_task.Measure`](https://github.com/facebookresearch/habitat-lab/blob/main/habitat/core/embodied_task.py#L82)\n",
    "Defines the metrics for embodied task, eg: [SPL](https://github.com/facebookresearch/habitat-lab/blob/d0db1b55be57abbacc5563dca2ca14654c545552/habitat/tasks/nav/nav.py#L533).\n",
    "\n",
    "\n",
    "7. [`habitat_baselines`](https://github.com/facebookresearch/habitat-lab/tree/71d409ab214a7814a9bd9b7e44fd25f57a0443ba/habitat_baselines)\n",
    "RL, SLAM, heuristic baseline implementations for the different embodied tasks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a new Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    config = habitat.get_config(\n",
    "        config_paths=\"./configs/test/habitat_all_sensors_test.yaml\"\n",
    "    )\n",
    "\n",
    "\n",
    "@registry.register_task(name=\"TestNav-v0\")\n",
    "class NewNavigationTask(NavigationTask):\n",
    "    def __init__(self, config, sim, dataset):\n",
    "        logger.info(\"Creating a new type of task\")\n",
    "        super().__init__(config=config, sim=sim, dataset=dataset)\n",
    "\n",
    "    def _check_episode_is_active(self, *args, **kwargs):\n",
    "        logger.info(\n",
    "            \"Current agent position: {}\".format(self._sim.get_agent_state())\n",
    "        )\n",
    "        collision = self._sim.previous_step_collided\n",
    "        stop_called = not getattr(self, \"is_stop_called\", False)\n",
    "        return collision or stop_called\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    config.defrost()\n",
    "    config.TASK.TYPE = \"TestNav-v0\"\n",
    "    config.freeze()\n",
    "\n",
    "    try:\n",
    "        env.close()\n",
    "    except NameError:\n",
    "        pass\n",
    "    env = habitat.Env(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    action = None\n",
    "    env.reset()\n",
    "    valid_actions = [\"TURN_LEFT\", \"TURN_RIGHT\", \"MOVE_FORWARD\", \"STOP\"]\n",
    "    interactive_control = False  # @param {type:\"boolean\"}\n",
    "    while env.episode_over is not True:\n",
    "        display_sample(obs[\"rgb\"])\n",
    "        if interactive_control:\n",
    "            action = input(\n",
    "                \"enter action out of {}:\\n\".format(\", \".join(valid_actions))\n",
    "            )\n",
    "            assert (\n",
    "                action in valid_actions\n",
    "            ), \"invalid action {} entered, choose one amongst \" + \",\".join(\n",
    "                valid_actions\n",
    "            )\n",
    "        else:\n",
    "            action = valid_actions.pop()\n",
    "        obs = env.step(\n",
    "            {\n",
    "                \"action\": action,\n",
    "                \"action_args\": None,\n",
    "            }\n",
    "        )\n",
    "        print(\"Episode over:\", env.episode_over)\n",
    "\n",
    "    env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a new Sensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@registry.register_sensor(name=\"agent_position_sensor\")\n",
    "class AgentPositionSensor(habitat.Sensor):\n",
    "    def __init__(self, sim, config, **kwargs):\n",
    "        super().__init__(config=config)\n",
    "        self._sim = sim\n",
    "\n",
    "    # Defines the name of the sensor in the sensor suite dictionary\n",
    "    def _get_uuid(self, *args, **kwargs):\n",
    "        return \"agent_position\"\n",
    "\n",
    "    # Defines the type of the sensor\n",
    "    def _get_sensor_type(self, *args, **kwargs):\n",
    "        return habitat.SensorTypes.POSITION\n",
    "\n",
    "    # Defines the size and range of the observations of the sensor\n",
    "    def _get_observation_space(self, *args, **kwargs):\n",
    "        return spaces.Box(\n",
    "            low=np.finfo(np.float32).min,\n",
    "            high=np.finfo(np.float32).max,\n",
    "            shape=(3,),\n",
    "            dtype=np.float32,\n",
    "        )\n",
    "\n",
    "    # This is called whenver reset is called or an action is taken\n",
    "    def get_observation(self, observations, *args, episode, **kwargs):\n",
    "        return self._sim.get_agent_state().position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    config = habitat.get_config(\n",
    "        config_paths=\"./configs/test/habitat_all_sensors_test.yaml\"\n",
    "    )\n",
    "\n",
    "    config.defrost()\n",
    "    # Now define the config for the sensor\n",
    "    config.TASK.AGENT_POSITION_SENSOR = habitat.Config()\n",
    "    # Use the custom name\n",
    "    config.TASK.AGENT_POSITION_SENSOR.TYPE = \"agent_position_sensor\"\n",
    "    # Add the sensor to the list of sensors in use\n",
    "    config.TASK.SENSORS.append(\"AGENT_POSITION_SENSOR\")\n",
    "    config.freeze()\n",
    "\n",
    "    try:\n",
    "        env.close()\n",
    "    except NameError:\n",
    "        pass\n",
    "    env = habitat.Env(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    obs = env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    obs.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    print(obs[\"agent_position\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a new Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An example agent which can be submitted to habitat-challenge.\n",
    "# To participate and for more details refer to:\n",
    "# - https://aihabitat.org/challenge/2020/\n",
    "# - https://github.com/facebookresearch/habitat-challenge\n",
    "\n",
    "\n",
    "class ForwardOnlyAgent(habitat.Agent):\n",
    "    def __init__(self, success_distance, goal_sensor_uuid):\n",
    "        self.dist_threshold_to_stop = success_distance\n",
    "        self.goal_sensor_uuid = goal_sensor_uuid\n",
    "\n",
    "    def reset(self):\n",
    "        pass\n",
    "\n",
    "    def is_goal_reached(self, observations):\n",
    "        dist = observations[self.goal_sensor_uuid][0]\n",
    "        return dist <= self.dist_threshold_to_stop\n",
    "\n",
    "    def act(self, observations):\n",
    "        if self.is_goal_reached(observations):\n",
    "            action = HabitatSimActions.STOP\n",
    "        else:\n",
    "            action = HabitatSimActions.MOVE_FORWARD\n",
    "        return {\"action\": action}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Other Examples\n",
    "\n",
    "[Create a new action space](https://github.com/facebookresearch/habitat-lab/blob/main/examples/new_actions.py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Sim2Real with Habitat { display-mode: \"form\" }\n",
    "\n",
    "try:\n",
    "    from IPython.display import HTML\n",
    "\n",
    "    HTML(\n",
    "        '<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/Hun2rhgnWLU\" frameborder=\"0\" allow=\"accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture\" allowfullscreen></iframe>'\n",
    "    )\n",
    "except ImportError:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deploy habitat-sim trained models on real robots with the [habitat-pyrobot bridge](https://github.com/facebookresearch/habitat-lab/blob/71d409ab214a7814a9bd9b7e44fd25f57a0443ba/habitat/sims/pyrobot/pyrobot.py)\n",
    "\n",
    "```python\n",
    "# Are we in sim or reality?\n",
    "if args.use_robot: # Use LoCoBot via PyRobot\n",
    "    config.SIMULATOR.TYPE = \"PyRobot-Locobot-v0\"\n",
    "else: # Use simulation\n",
    "    config.SIMULATOR.TYPE = \"Habitat-Sim-v0\"\n",
    "```\n",
    "\n",
    "Paper: [https://arxiv.org/abs/1912.06321](https://arxiv.org/abs/1912.06321)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Habitat Lab",
   "provenance": []
  },
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "nb_python//py:percent,colabs//ipynb",
   "main_language": "python",
   "notebook_metadata_filter": "all"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
