{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MpkYHwCqk7W-"
      },
      "source": [
        "# \u003cdiv align=\"center\"\u003e**`dm_control` tutorial**\u003c/div\u003e\n",
        "# \u003cdiv align=\"center\"\u003e[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb)\u003c/div\u003e\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_UbO9uhtBSX5"
      },
      "source": [
        "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eCopyright 2020 The dm_control Authors.\u003c/small\u003e\u003c/p\u003e\n",
        "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at \u003ca href=\"http://www.apache.org/licenses/LICENSE-2.0\"\u003ehttp://www.apache.org/licenses/LICENSE-2.0\u003c/a\u003e.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e\n",
        "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aThGKGp0cD76"
      },
      "source": [
        "This notebook provides an overview tutorial of DeepMind's `dm_control` package, hosted at the [deepmind/dm_control](https://github.com/deepmind/dm_control) repository on GitHub.\n",
        "\n",
        "It is adjunct to this [tech report](http://arxiv.org/abs/2006.12983)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YvyGCsgSCxHQ"
      },
      "source": [
        "### Installing `dm_control`\n",
        "\n",
        "If your runtime kernel has `dm_control` installed, you can proceed to the next section (\"Imports\"). \n",
        "\n",
        "Otherwise, follow the instructions that apply to you:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VEEj3Qw60y73"
      },
      "source": [
        "#### Edit and run the following cell if you have an institutional MuJoCo license."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "IbZxYDxzoz5R"
      },
      "outputs": [],
      "source": [
        "#@title Edit and run\n",
        "#@test {\"skip\": true}\n",
        "mjkey = \"\"\"\n",
        "\n",
        "REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY\n",
        "\n",
        "\"\"\".strip()\n",
        "\n",
        "mujoco_dir = \"$HOME/.mujoco\"\n",
        "\n",
        "# Install OpenGL deps\n",
        "!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
        "  libgl1-mesa-glx libosmesa6 libglew2.0\n",
        "\n",
        "# Fetch MuJoCo binaries from Roboti\n",
        "!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
        "!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
        "\n",
        "# Copy over MuJoCo license\n",
        "!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
        "\n",
        "\n",
        "# Configure dm_control to use the OSMesa rendering backend\n",
        "%env MUJOCO_GL=osmesa\n",
        "\n",
        "# Install dm_control\n",
        "!pip install dm_control\n",
        "\n",
        "# Kick the tyres a bit, check that the installation succeeded\n",
        "try:\n",
        "  from dm_control import suite\n",
        "  env = suite.load('cartpole', 'swingup')\n",
        "  pixels = env.physics.render()\n",
        "except Exception as e:\n",
        "  raise e from RuntimeError(\n",
        "      'Something went wrong during installation. Check the shell output above '\n",
        "      'for more information.')\n",
        "else:\n",
        "  from IPython.display import clear_output\n",
        "  clear_output()\n",
        "  del suite, env, pixels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-_7tVg-zzjzW"
      },
      "source": [
        "#### Follow these instructions if you have a machine-locked MuJoCo license."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xxMh7EoJU0Y3"
      },
      "source": [
        "Colab supports using a Jupyter kernel on your local machine, as detailed at https://research.google.com/colaboratory/local-runtimes.html.\n",
        "\n",
        "1. Install dm_control as usual by following instructions in https://github.com/deepmind/dm_control\n",
        "2. Install Jupyter: \n",
        "```\n",
        "pip install jupyterlab jupyter_http_over_ws\n",
        "```\n",
        "3. Enable WebSocket extension for Jupyter:\n",
        "```\n",
        "jupyter serverextension enable --py jupyter_http_over_ws\n",
        "```\n",
        "4. Start the Jupyter server:\n",
        "```\n",
        "jupyter notebook \\\n",
        "     --NotebookApp.allow_origin='https://colab.research.google.com' \\\n",
        "     --port=8888 \\\n",
        "     --NotebookApp.port_retries=0\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wtDN43hIJh2C"
      },
      "source": [
        "# Imports\n",
        "\n",
        "Run both of these cells:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "T5f4w3Kq2X14"
      },
      "outputs": [],
      "source": [
        "#@title All `dm_control` imports required for this tutorial\n",
        "\n",
        "# The basic mujoco wrapper.\n",
        "from dm_control import mujoco\n",
        "\n",
        "# Access to enums and MuJoCo library functions.\n",
        "from dm_control.mujoco.wrapper.mjbindings import enums\n",
        "from dm_control.mujoco.wrapper.mjbindings import mjlib\n",
        "\n",
        "# PyMJCF\n",
        "from dm_control import mjcf\n",
        "\n",
        "# Composer high level imports\n",
        "from dm_control import composer\n",
        "from dm_control.composer.observation import observable\n",
        "from dm_control.composer import variation\n",
        "\n",
        "# Imports for Composer tutorial example\n",
        "from dm_control.composer.variation import distributions\n",
        "from dm_control.composer.variation import noises\n",
        "from dm_control.locomotion.arenas import floors\n",
        "\n",
        "# Control Suite\n",
        "from dm_control import suite\n",
        "\n",
        "# Run through corridor example\n",
        "from dm_control.locomotion.walkers import cmu_humanoid\n",
        "from dm_control.locomotion.arenas import corridors as corridor_arenas\n",
        "from dm_control.locomotion.tasks import corridors as corridor_tasks\n",
        "\n",
        "# Soccer\n",
        "from dm_control.locomotion import soccer\n",
        "\n",
        "# Manipulation\n",
        "from dm_control import manipulation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "gKc1FNhKiVJX"
      },
      "outputs": [],
      "source": [
        "#@title Other imports and helper functions\n",
        "\n",
        "# General\n",
        "import copy\n",
        "import os\n",
        "from IPython.display import clear_output\n",
        "import numpy as np\n",
        "\n",
        "# Graphics-related\n",
        "import matplotlib\n",
        "import matplotlib.animation as animation\n",
        "import matplotlib.pyplot as plt\n",
        "from IPython.display import HTML\n",
        "import PIL.Image\n",
        "\n",
        "# Use svg backend for figure rendering\n",
        "%config InlineBackend.figure_format = 'svg'\n",
        "\n",
        "# Font sizes\n",
        "SMALL_SIZE = 8\n",
        "MEDIUM_SIZE = 10\n",
        "BIGGER_SIZE = 12\n",
        "plt.rc('font', size=SMALL_SIZE)          # controls default text sizes\n",
        "plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title\n",
        "plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels\n",
        "plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels\n",
        "plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels\n",
        "plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize\n",
        "plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
        "\n",
        "# Inline video helper function\n",
        "if os.environ.get('COLAB_NOTEBOOK_TEST', False):\n",
        "  # We skip video generation during tests, as it is quite expensive.\n",
        "  display_video = lambda *args, **kwargs: None\n",
        "else:\n",
        "  def display_video(frames, framerate=30):\n",
        "    height, width, _ = frames[0].shape\n",
        "    dpi = 70\n",
        "    orig_backend = matplotlib.get_backend()\n",
        "    matplotlib.use('Agg')  # Switch to headless 'Agg' to inhibit figure rendering.\n",
        "    fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)\n",
        "    matplotlib.use(orig_backend)  # Switch back to the original backend.\n",
        "    ax.set_axis_off()\n",
        "    ax.set_aspect('equal')\n",
        "    ax.set_position([0, 0, 1, 1])\n",
        "    im = ax.imshow(frames[0])\n",
        "    def update(frame):\n",
        "      im.set_data(frame)\n",
        "      return [im]\n",
        "    interval = 1000/framerate\n",
        "    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,\n",
        "                                   interval=interval, blit=True, repeat=False)\n",
        "    return HTML(anim.to_html5_video())\n",
        "\n",
        "# Seed numpy's global RNG so that cell outputs are deterministic. We also try to\n",
        "# use RandomState instances that are local to a single cell wherever possible.\n",
        "np.random.seed(42)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jZXz9rPYGA-Y"
      },
      "source": [
        "# Model definition, compilation and rendering\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MRBaZsf1d7Gb"
      },
      "source": [
        "We begin by describing some basic concepts of the [MuJoCo](http://mujoco.org/) physics simulation library, but recommend the [official documentation](http://mujoco.org/book/) for details.\n",
        "\n",
        "Let's define a simple model with two geoms and a light."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "ZS2utl59ZTsr"
      },
      "outputs": [],
      "source": [
        "#@title A static model {vertical-output: true}\n",
        "\n",
        "static_model = \"\"\"\n",
        "\u003cmujoco\u003e\n",
        "  \u003cworldbody\u003e\n",
        "    \u003clight name=\"top\" pos=\"0 0 1\"/\u003e\n",
        "    \u003cgeom name=\"red_box\" type=\"box\" size=\".2 .2 .2\" rgba=\"1 0 0 1\"/\u003e\n",
        "    \u003cgeom name=\"green_sphere\" pos=\".2 .2 .2\" size=\".1\" rgba=\"0 1 0 1\"/\u003e\n",
        "  \u003c/worldbody\u003e\n",
        "\u003c/mujoco\u003e\n",
        "\"\"\"\n",
        "physics = mujoco.Physics.from_xml_string(static_model)\n",
        "pixels = physics.render()\n",
        "PIL.Image.fromarray(pixels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "p4vPllljTJh8"
      },
      "source": [
        "`static_model` is written in MuJoCo's XML-based [MJCF](http://www.mujoco.org/book/modeling.html) modeling language. The `from_xml_string()` method invokes the model compiler, which instantiates the library's internal data structures. These can be accessed via the `physics` object, see below."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MdUF2UYmR4TA"
      },
      "source": [
        "## Adding DOFs and simulating\n",
        "This is a perfectly legitimate model, but if we simulate it, nothing will happen except for time advancing. This is because this model has no degrees of freedom (DOFs). We add DOFs by adding **joints** to bodies, specifying how they can move with respect to their parents. Let us add a hinge joint and re-render, visualizing the joint axis."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "R7zokzd_yeEg"
      },
      "outputs": [],
      "source": [
        "#@title A child body with a joint { vertical-output: true }\n",
        "\n",
        "swinging_body = \"\"\"\n",
        "\u003cmujoco\u003e\n",
        "  \u003cworldbody\u003e\n",
        "    \u003clight name=\"top\" pos=\"0 0 1\"/\u003e\n",
        "    \u003cbody name=\"box_and_sphere\" euler=\"0 0 -30\"\u003e  \n",
        "      \u003cjoint name=\"swing\" type=\"hinge\" axis=\"1 -1 0\" pos=\"-.2 -.2 -.2\"/\u003e\n",
        "      \u003cgeom name=\"red_box\" type=\"box\" size=\".2 .2 .2\" rgba=\"1 0 0 1\"/\u003e\n",
        "      \u003cgeom name=\"green_sphere\" pos=\".2 .2 .2\" size=\".1\" rgba=\"0 1 0 1\"/\u003e\n",
        "    \u003c/body\u003e\n",
        "  \u003c/worldbody\u003e\n",
        "\u003c/mujoco\u003e\n",
        "\"\"\"\n",
        "physics = mujoco.Physics.from_xml_string(swinging_body)\n",
        "# Visualize the joint axis.\n",
        "scene_option = mujoco.wrapper.core.MjvOption()\n",
        "scene_option.flags[enums.mjtVisFlag.mjVIS_JOINT] = True\n",
        "pixels = physics.render(scene_option=scene_option)\n",
        "PIL.Image.fromarray(pixels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "INOGGV0PTQus"
      },
      "source": [
        "The things that move (and which have inertia) are called *bodies*. The body's child `joint` specifies how that body can move with respect to its parent, in this case `box_and_sphere` w.r.t the `worldbody`. \n",
        "\n",
        "Note that the body's frame is **rotated** with an `euler` directive, and its children, the geoms and the joint, rotate with it. This is to emphasize the local-to-parent-frame nature of position and orientation directives in MJCF.\n",
        "\n",
        "Let's make a video, to get a sense of the dynamics and to see the body swinging under gravity."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Z_57VMUDpGrj"
      },
      "outputs": [],
      "source": [
        "#@title Making a video {vertical-output: true}\n",
        "\n",
        "duration = 2    # (seconds)\n",
        "framerate = 30  # (Hz)\n",
        "\n",
        "# Visualize the joint axis\n",
        "scene_option = mujoco.wrapper.core.MjvOption()\n",
        "scene_option.flags[enums.mjtVisFlag.mjVIS_JOINT] = True\n",
        "\n",
        "# Simulate and display video.\n",
        "frames = []\n",
        "physics.reset()  # Reset state and time\n",
        "while physics.data.time \u003c duration:\n",
        "  physics.step()\n",
        "  if len(frames) \u003c physics.data.time * framerate:\n",
        "    pixels = physics.render(scene_option=scene_option)\n",
        "    frames.append(pixels)\n",
        "display_video(frames, framerate)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yYvS1UaciMX_"
      },
      "source": [
        "Note how we collect the video frames. Because physics simulation timesteps are generally much smaller than framerates (the default timestep is 2ms), we don't render after each step."
      ]
    },
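    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The frame-collection rule used in the video loop can be checked without MuJoCo. The following is a minimal standalone sketch, not part of `dm_control`: it assumes the default 2 ms timestep (500 steps per second), and the counters `n_steps` and `n_frames` are hypothetical stand-ins for calls to `physics.step()` and `physics.render()`. Integer arithmetic is used to avoid floating-point rounding.\n",
        "\n",
        "```python\n",
        "steps_per_second = 500  # assumed: 1 / (2 ms default timestep)\n",
        "duration = 2            # seconds\n",
        "framerate = 30          # Hz\n",
        "\n",
        "n_steps = 0   # stands in for calls to physics.step()\n",
        "n_frames = 0  # stands in for calls to physics.render()\n",
        "while n_steps \u003c duration * steps_per_second:\n",
        "  n_steps += 1\n",
        "  # Same test as `len(frames) \u003c physics.data.time * framerate`, with\n",
        "  # time = n_steps / steps_per_second, multiplied through to stay in integers.\n",
        "  if n_frames * steps_per_second \u003c n_steps * framerate:\n",
        "    n_frames += 1\n",
        "\n",
        "print(n_steps, n_frames)\n",
        "```\n",
        "\n",
        "With these settings a frame is rendered roughly once every 16 to 17 physics steps, so most steps skip the comparatively expensive render call."
      ]
    },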
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "nQ8XOnRQx7T1"
      },
      "source": [
        "### Additional rendering options\n",
        "\n",
        "Like joint visualisation, additional rendering options are exposed as parameters to the `render` method."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "AQITZiIgx7T2"
      },
      "outputs": [],
      "source": [
        "#@title Enable transparency and frames visualization. {vertical-output: true}\n",
        "\n",
        "scene_option = mujoco.wrapper.core.MjvOption()\n",
        "scene_option.frame = enums.mjtFrame.mjFRAME_GEOM\n",
        "scene_option.flags[enums.mjtVisFlag.mjVIS_TRANSPARENT] = True\n",
        "pixels = physics.render(scene_option=scene_option)\n",
        "PIL.Image.fromarray(pixels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "PDDgY48vx7T6"
      },
      "outputs": [],
      "source": [
        "#@title Depth rendering {vertical-output: true}\n",
        "\n",
        "# depth is a float array, in meters.\n",
        "depth = physics.render(depth=True)\n",
        "# Shift nearest values to the origin.\n",
        "depth -= depth.min()\n",
        "# Scale by 2 mean distances of near rays.\n",
        "depth /= 2*depth[depth \u003c= 1].mean()\n",
        "# Scale to [0, 255]\n",
        "pixels = 255*np.clip(depth, 0, 1)\n",
        "PIL.Image.fromarray(pixels.astype(np.uint8))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "PNwiIrgpx7T8"
      },
      "outputs": [],
      "source": [
        "#@title Segmentation rendering {vertical-output: true}\n",
        "\n",
        "seg = physics.render(segmentation=True)\n",
        "# Display the contents of the first channel, which contains object\n",
        "# IDs. The second channel, seg[:, :, 1], contains object types.\n",
        "geom_ids = seg[:, :, 0]\n",
        "# Infinity is mapped to -1\n",
        "geom_ids = geom_ids.astype(np.float64) + 1\n",
        "# Scale to [0, 1]\n",
        "geom_ids = geom_ids / geom_ids.max()\n",
        "pixels = 255*geom_ids\n",
        "PIL.Image.fromarray(pixels.astype(np.uint8))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "gf9h_wi9weet"
      },
      "source": [
        "# MuJoCo basics and named indexing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NCcZxrDDB1Cj"
      },
      "source": [
        "## `mjModel`\n",
        "MuJoCo's `mjModel`, encapsulated in `physics.model`, contains the *model description*, including the default initial state and other fixed quantities which are not a function of the state, e.g. the positions of geoms in the frame of their parent body. The (x, y, z) offsets of the `box` and `sphere` geoms, relative their parent body `box_and_sphere` are given by `model.geom_pos`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "wx8NANvOF8g1"
      },
      "outputs": [],
      "source": [
        "physics.model.geom_pos"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "v9O1VNdmHn_K"
      },
      "source": [
        "Docstrings of attributes provide short descriptions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "8_0MwaeYHn_N"
      },
      "outputs": [],
      "source": [
        "help(type(physics.model).geom_pos)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Wee5ATLtIQn_"
      },
      "source": [
        "The `model.opt` structure contains global quantities like"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "BhzbZIfDIU2-"
      },
      "outputs": [],
      "source": [
        "print('timestep', physics.model.opt.timestep)\n",
        "print('gravity', physics.model.opt.gravity)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "t5hY0fyXFLcf"
      },
      "source": [
        "## `mjData`\n",
        "`mjData`, encapsulated in `physics.data`, contains the *state* and quantities that depend on it. The state is made up of time, generalized positions and generalised velocities. These are respectively `data.time`, `data.qpos` and `data.qvel`. \n",
        "\n",
        "Let's print the state of the swinging body where we left it:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "acwZtDwp9mQU"
      },
      "outputs": [],
      "source": [
        "print(physics.data.time, physics.data.qpos, physics.data.qvel)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7YlmcLcA-WQu"
      },
      "source": [
        "`physics.data` also contains functions of the state, for example the cartesian positions of objects in the world frame. The (x, y, z) positions of our two geoms are in `data.geom_xpos`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CPwDcAQ0-uUE"
      },
      "outputs": [],
      "source": [
        "print(physics.data.geom_xpos)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Z0UodCxS_v49"
      },
      "source": [
        "## Named indexing\n",
        "\n",
        "The semantics of the above arrays are made clearer using the `named` wrapper, which assigns names to rows and type names to columns."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cLARcaK6-xCU"
      },
      "outputs": [],
      "source": [
        "print(physics.named.data.geom_xpos)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wgXOUZNZHIx6"
      },
      "source": [
        "Note how `model.geom_pos` and `data.geom_xpos` have similar semantics but very different meanings."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "-cW61ClRHS8a"
      },
      "outputs": [],
      "source": [
        "print(physics.named.model.geom_pos)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-lQ0AChVASMv"
      },
      "source": [
        "Name strings can be used to index **into** the relevant quantities, making code much more readable and robust."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Rj4ad9fQAnFZ"
      },
      "outputs": [],
      "source": [
        "physics.named.data.geom_xpos['green_sphere', 'z']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "axr_p6APAzFn"
      },
      "source": [
        "Joint names can be used to index into quantities in configuration space (beginning with the letter `q`):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "hluF9aDG9O1W"
      },
      "outputs": [],
      "source": [
        "physics.named.data.qpos['swing']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3IhfyD2Q1pjv"
      },
      "source": [
        "We can mix NumPy slicing operations with named indexing. As an example, we can set the color of the box using its name (`\"red_box\"`) as an index into the rows of the `geom_rgba` array. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "f5vVUullUvWH"
      },
      "outputs": [],
      "source": [
        "#@title Changing colors using named indexing{vertical-output: true}\n",
        "\n",
        "random_rgb = np.random.rand(3)\n",
        "physics.named.model.geom_rgba['red_box', :3] = random_rgb\n",
        "pixels = physics.render()\n",
        "PIL.Image.fromarray(pixels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "elzPPdq-KhLI"
      },
      "source": [
        "Note that while `physics.model` quantities will not be changed by the engine, we can change them ourselves between steps. This however is generally not recommended, the preferred approach being to modify the model at the XML level using the PyMJCF library, see below."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "22ENjtVuhwsm"
      },
      "source": [
        "## Setting the state with `reset_context()`\n",
        "\n",
        "In order for `data` quantities that are functions of the state to be in sync with the state, MuJoCo's `mj_step1()` needs to be called. This is facilitated by the `reset_context()` context, please see in-depth discussion in Section 2.1 of the [tech report](https://arxiv.org/abs/2006.12983)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WBPprCtWgXFN"
      },
      "outputs": [],
      "source": [
        "physics.named.data.qpos['swing'] = np.pi\n",
        "print('Without reset_context, spatial positions are not updated:',\n",
        "      physics.named.data.geom_xpos['green_sphere', ['z']])\n",
        "with physics.reset_context():\n",
        "  physics.named.data.qpos['swing'] = np.pi\n",
        "print('After reset_context, positions are up-to-date:',\n",
        "      physics.named.data.geom_xpos['green_sphere', ['z']])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "SHppAOjvSupc"
      },
      "source": [
        "## Free bodies: the self-inverting \"tippe-top\"\n",
        "\n",
        "A free body is a body with a `free` joint, with 6 movement DOFs: 3 translations and 3 rotations. We could give our `box_and_sphere` body a free joint and watch it fall, but let's look at something more interesting. A \"tippe top\" is a spinning toy which flips itself on its head ([Wikipedia](https://en.wikipedia.org/wiki/Tippe_top)). We model it as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "xasXQpVMjIwA"
      },
      "outputs": [],
      "source": [
        "#@title The \"tippe-top\" model{vertical-output: true}\n",
        "\n",
        "tippe_top = \"\"\"\n",
        "\u003cmujoco model=\"tippe top\"\u003e\n",
        "  \u003coption integrator=\"RK4\"/\u003e\n",
        "  \u003casset\u003e\n",
        "    \u003ctexture name=\"grid\" type=\"2d\" builtin=\"checker\" rgb1=\".1 .2 .3\" \n",
        "     rgb2=\".2 .3 .4\" width=\"300\" height=\"300\"/\u003e\n",
        "    \u003cmaterial name=\"grid\" texture=\"grid\" texrepeat=\"8 8\" reflectance=\".2\"/\u003e\n",
        "  \u003c/asset\u003e\n",
        "  \u003cworldbody\u003e\n",
        "    \u003cgeom size=\".2 .2 .01\" type=\"plane\" material=\"grid\"/\u003e\n",
        "    \u003clight pos=\"0 0 .6\"/\u003e\n",
        "    \u003ccamera name=\"closeup\" pos=\"0 -.1 .07\" xyaxes=\"1 0 0 0 1 2\"/\u003e\n",
        "    \u003cbody name=\"top\" pos=\"0 0 .02\"\u003e\n",
        "      \u003cfreejoint/\u003e\n",
        "      \u003cgeom name=\"ball\" type=\"sphere\" size=\".02\" /\u003e\n",
        "      \u003cgeom name=\"stem\" type=\"cylinder\" pos=\"0 0 .02\" size=\"0.004 .008\"/\u003e\n",
        "      \u003cgeom name=\"ballast\" type=\"box\" size=\".023 .023 0.005\"  pos=\"0 0 -.015\" \n",
        "       contype=\"0\" conaffinity=\"0\" group=\"3\"/\u003e\n",
        "    \u003c/body\u003e\n",
        "  \u003c/worldbody\u003e\n",
        "  \u003ckeyframe\u003e\n",
        "    \u003ckey name=\"spinning\" qpos=\"0 0 0.02 1 0 0 0\" qvel=\"0 0 0 0 1 200\" /\u003e\n",
        "  \u003c/keyframe\u003e\n",
        "\u003c/mujoco\u003e\n",
        "\"\"\"\n",
        "physics = mujoco.Physics.from_xml_string(tippe_top)\n",
        "PIL.Image.fromarray(physics.render(camera_id='closeup'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bvHlr6maJYIG"
      },
      "source": [
        "Note several new features of this model definition:\n",
        "0. The free joint is added with the `\u003cfreejoint/\u003e` clause, which is similar to `\u003cjoint type=\"free\"/\u003e`, but prohibits unphysical attributes like friction or stiffness.\n",
        "1. We use the `\u003coption/\u003e` clause to set the integrator to the more accurate Runge Kutta 4th order.\n",
        "2. We define the floor's grid material inside the `\u003casset/\u003e` clause and reference it in the floor geom. \n",
        "3. We use an invisible and non-colliding box geom called `ballast` to move the top's center-of-mass lower. Having a low center of mass is  (counter-intuitively) required for the flipping behaviour to occur.\n",
        "4. We save our initial spinning state as a keyframe. It has a high rotational velocity around the z-axis, but is not perfectly oriented with the world.\n",
        "5. We define a `\u003ccamera\u003e` in our model, and then render from it using the `camera_id` argument to `render()`.\n",
        "Let us examine the state:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "o4S9nYhHOKmb"
      },
      "outputs": [],
      "source": [
        "print('positions', physics.data.qpos)\n",
        "print('velocities', physics.data.qvel)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "71UgzBAqWdtZ"
      },
      "source": [
        "The velocities are easy to interpret: 6 zeros, one for each DOF. What about the length-7 positions? The first three numbers are the body's position, where we can see its initial 2cm height; the subsequent four numbers are the 3D orientation, defined by a *unit quaternion*. These normalized four-vectors, which preserve the topology of the orientation group, are the reason that `data.qpos` can be longer than `data.qvel`: 3D orientations are represented with **4** numbers while angular velocities are **3** numbers."
      ]
    },
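    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough illustration of why an orientation takes four numbers while an angular velocity takes three, the sketch below advances a unit quaternion using a 3-vector angular velocity. It is plain NumPy, not MuJoCo's own integrator. The helper names `quat_multiply` and `integrate_quat` are hypothetical, and the sketch assumes `[w, x, y, z]` ordering with the angular velocity expressed in the body frame.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def quat_multiply(a, b):\n",
        "  # Hamilton product of two quaternions in [w, x, y, z] order.\n",
        "  aw, ax, ay, az = a\n",
        "  bw, bx, by, bz = b\n",
        "  return np.array([\n",
        "      aw * bw - ax * bx - ay * by - az * bz,\n",
        "      aw * bx + ax * bw + ay * bz - az * by,\n",
        "      aw * by - ax * bz + ay * bw + az * bx,\n",
        "      aw * bz + ax * by - ay * bx + az * bw,\n",
        "  ])\n",
        "\n",
        "\n",
        "def integrate_quat(quat, omega, dt):\n",
        "  # Rotate `quat` by the constant angular velocity `omega` (3 numbers)\n",
        "  # applied for `dt` seconds, and return a unit quaternion (4 numbers).\n",
        "  quat = np.asarray(quat, dtype=float)\n",
        "  omega = np.asarray(omega, dtype=float)\n",
        "  speed = np.linalg.norm(omega)\n",
        "  angle = speed * dt\n",
        "  if np.isclose(angle, 0.0):\n",
        "    return quat / np.linalg.norm(quat)\n",
        "  axis = omega / speed\n",
        "  delta = np.concatenate([[np.cos(angle / 2)], np.sin(angle / 2) * axis])\n",
        "  new_quat = quat_multiply(quat, delta)\n",
        "  # Re-normalise so that the result stays on the unit sphere.\n",
        "  return new_quat / np.linalg.norm(new_quat)\n",
        "```\n",
        "\n",
        "For example, `integrate_quat([1, 0, 0, 0], [0, 0, np.pi], 1.0)` is a half turn about the z-axis, which is approximately `[0, 0, 0, 1]`."
      ]
    },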
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "5P4HkhKNGQvs"
      },
      "outputs": [],
      "source": [
        "#@title Video of the tippe-top {vertical-output: true}\n",
        "#@test {\"timeout\": 600}\n",
        "\n",
        "duration = 7    # (seconds)\n",
        "framerate = 60  # (Hz)\n",
        "\n",
        "# Simulate and display video.\n",
        "frames = []\n",
        "physics.reset(0)  # Reset to keyframe 0 (load a saved state).\n",
        "while physics.data.time \u003c duration:\n",
        "  physics.step()\n",
        "  if len(frames) \u003c (physics.data.time) * framerate:\n",
        "    pixels = physics.render(camera_id='closeup')\n",
        "    frames.append(pixels)\n",
        "\n",
        "display_video(frames, framerate)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rRuFKD2ubPgu"
      },
      "source": [
        "### Measuring values from `physics.data`\n",
        "The `physics.data` structure contains all of the dynamic variables and intermediate results produced by the simulation. These are expected to change on each timestep. \n",
        "\n",
        "Below we simulate for `duration` seconds and plot the angular velocity of the top and the height of its stem as a function of time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1XXB6asJoZ2N"
      },
      "outputs": [],
      "source": [
        "#@title Measuring values {vertical-output: true}\n",
        "\n",
        "timevals = []\n",
        "angular_velocity = []\n",
        "stem_height = []\n",
        "\n",
        "# Simulate and save data\n",
        "physics.reset(0)\n",
        "while physics.data.time \u003c duration:\n",
        "  physics.step()\n",
        "  timevals.append(physics.data.time)\n",
        "  angular_velocity.append(physics.data.qvel[3:6].copy())\n",
        "  stem_height.append(physics.named.data.geom_xpos['stem', 'z'])\n",
        "\n",
        "dpi = 100\n",
        "width = 480\n",
        "height = 640\n",
        "figsize = (width / dpi, height / dpi)\n",
        "_, ax = plt.subplots(2, 1, figsize=figsize, dpi=dpi, sharex=True)\n",
        "\n",
        "ax[0].plot(timevals, angular_velocity)\n",
        "ax[0].set_title('angular velocity')\n",
        "ax[0].set_ylabel('radians / second')\n",
        "\n",
        "ax[1].plot(timevals, stem_height)\n",
        "ax[1].set_xlabel('time (seconds)')\n",
        "ax[1].set_ylabel('meters')\n",
        "_ = ax[1].set_title('stem height')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UAMItwu8e1WR"
      },
      "source": [
        "# PyMJCF tutorial\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hPiY8m3MssKM"
      },
      "source": [
        "This library provides a Python object model for MuJoCo's XML-based\n",
        "[MJCF](http://www.mujoco.org/book/modeling.html) physics modeling language. The\n",
        "goal of the library is to allow users to easily interact with and modify MJCF\n",
        "models in Python, similarly to what the JavaScript DOM does for HTML.\n",
        "\n",
        "A key feature of this library is the ability to easily compose multiple separate\n",
        "MJCF models into a larger one. Disambiguation of duplicated names from different\n",
        "models, or multiple instances of the same model, is handled automatically.\n",
        "\n",
        "One typical use case is when we want robots with a variable number of joints. This is a fundamental change to the kinematics, requiring a new XML descriptor and new binary model to be compiled. \n",
        "\n",
        "The following snippets realise this scenario and provide a quick example of this library's use case."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gKny5EJ4uVzu"
      },
      "outputs": [],
      "source": [
        "class Leg(object):\n",
        "  \"\"\"A 2-DoF leg with position actuators.\"\"\"\n",
        "  def __init__(self, length, rgba):\n",
        "    self.model = mjcf.RootElement()\n",
        "\n",
        "    # Defaults:\n",
        "    self.model.default.joint.damping = 2\n",
        "    self.model.default.joint.type = 'hinge'\n",
        "    self.model.default.geom.type = 'capsule'\n",
        "    self.model.default.geom.rgba = rgba  # Continued below...\n",
        "\n",
        "    # Thigh:\n",
        "    self.thigh = self.model.worldbody.add('body')\n",
        "    self.hip = self.thigh.add('joint', axis=[0, 0, 1])\n",
        "    self.thigh.add('geom', fromto=[0, 0, 0, length, 0, 0], size=[length/4])\n",
        "\n",
        "    # Shin:\n",
        "    self.shin = self.thigh.add('body', pos=[length, 0, 0])\n",
        "    self.knee = self.shin.add('joint', axis=[0, 1, 0])\n",
        "    self.shin.add('geom', fromto=[0, 0, 0, 0, 0, -length], size=[length/5])\n",
        "\n",
        "    # Position actuators:\n",
        "    self.model.actuator.add('position', joint=self.hip, kp=10)\n",
        "    self.model.actuator.add('position', joint=self.knee, kp=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cFJerI--UtTy"
      },
      "source": [
        "The `Leg` class describes an abstract articulated leg, with two joints and corresponding proportional-derivative actuators. \n",
        "\n",
        "Note that:\n",
        "\n",
        "- MJCF attributes correspond directly to arguments of the `add()` method.\n",
        "- When referencing elements, e.g. when specifying the joint to which an actuator is attached, the MJCF element itself is used, rather than the name string."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "SESlL_TidKHx"
      },
      "outputs": [],
      "source": [
        "BODY_RADIUS = 0.1\n",
        "BODY_SIZE = (BODY_RADIUS, BODY_RADIUS, BODY_RADIUS / 2)\n",
        "random_state = np.random.RandomState(42)\n",
        "\n",
        "def make_creature(num_legs):\n",
        "  \"\"\"Constructs a creature with `num_legs` legs.\"\"\"\n",
        "  rgba = random_state.uniform([0, 0, 0, 1], [1, 1, 1, 1])\n",
        "  model = mjcf.RootElement()\n",
        "  model.compiler.angle = 'radian'  # Use radians.\n",
        "\n",
        "  # Make the torso geom.\n",
        "  model.worldbody.add(\n",
        "      'geom', name='torso', type='ellipsoid', size=BODY_SIZE, rgba=rgba)\n",
        "\n",
        "  # Attach legs to equidistant sites on the circumference.\n",
        "  for i in range(num_legs):\n",
        "    theta = 2 * i * np.pi / num_legs\n",
        "    hip_pos = BODY_RADIUS * np.array([np.cos(theta), np.sin(theta), 0])\n",
        "    hip_site = model.worldbody.add('site', pos=hip_pos, euler=[0, 0, theta])\n",
        "    leg = Leg(length=BODY_RADIUS, rgba=rgba)\n",
        "    hip_site.attach(leg.model)\n",
        "\n",
        "  return model"
      ]
    },
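    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The hip placement used in `make_creature` can be checked in isolation. The following is a minimal NumPy sketch of the same arithmetic; the function name `hip_positions` is hypothetical and is not part of PyMJCF.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def hip_positions(num_legs, radius):\n",
        "  # Angles evenly spaced around the circle, as in `make_creature`.\n",
        "  theta = 2 * np.pi * np.arange(num_legs) / num_legs\n",
        "  positions = np.stack(\n",
        "      [radius * np.cos(theta), radius * np.sin(theta), np.zeros(num_legs)],\n",
        "      axis=1)\n",
        "  # `theta` doubles as the yaw angle that points each leg outwards.\n",
        "  return positions, theta\n",
        "\n",
        "\n",
        "positions, yaw = hip_positions(num_legs=4, radius=0.1)\n",
        "print(np.round(positions, 3))\n",
        "```\n",
        "\n",
        "Every row lies at distance `radius` from the origin, and consecutive hips are separated by an angle of `2 * pi / num_legs`."
      ]
    },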
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "elyuJiI3U3kM"
      },
      "source": [
        "The `make_creature` function uses PyMJCF's `attach()` method to procedurally attach legs to the torso. Note that at this stage both the torso and hip attachment sites are children of the `worldbody`, since their parent body has yet to be instantiated. We'll now make an arena with a chequered floor and two lights, and place our creatures in a grid."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "F7_Tx9P9U_VJ"
      },
      "outputs": [],
      "source": [
        "#@title Six Creatures on a floor.{vertical-output: true}\n",
        "\n",
        "arena = mjcf.RootElement()\n",
        "chequered = arena.asset.add('texture', type='2d', builtin='checker', width=300,\n",
        "                            height=300, rgb1=[.2, .3, .4], rgb2=[.3, .4, .5])\n",
        "grid = arena.asset.add('material', name='grid', texture=chequered,\n",
        "                       texrepeat=[5, 5], reflectance=.2)\n",
        "arena.worldbody.add('geom', type='plane', size=[2, 2, .1], material=grid)\n",
        "for x in [-2, 2]:\n",
        "  arena.worldbody.add('light', pos=[x, -1, 3], dir=[-x, 1, -2])\n",
        "\n",
        "# Instantiate 6 creatures with 3 to 8 legs.\n",
        "creatures = [make_creature(num_legs=num_legs) for num_legs in range(3, 9)]\n",
        "\n",
        "# Place them on a grid in the arena.\n",
        "height = .15\n",
        "grid = 5 * BODY_RADIUS\n",
        "xpos, ypos, zpos = np.meshgrid([-grid, 0, grid], [0, grid], [height])\n",
        "for i, model in enumerate(creatures):\n",
        "  # Place spawn sites on a grid.\n",
        "  spawn_pos = (xpos.flat[i], ypos.flat[i], zpos.flat[i])\n",
        "  spawn_site = arena.worldbody.add('site', pos=spawn_pos, group=3)\n",
        "  # Attach to the arena at the spawn sites, with a free joint.\n",
        "  spawn_site.attach(model).add('freejoint')\n",
        "\n",
        "# Instantiate the physics and render.\n",
        "physics = mjcf.Physics.from_mjcf_model(arena)\n",
        "PIL.Image.fromarray(physics.render())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cMfDaD7PfuoI"
      },
      "source": [
        "Multi-legged creatures, ready to roam! Let's inject some controls and watch them move. We'll generate a sinusoidal open-loop control signal of fixed frequency and random phase, recording both video frames and the horizontal positions of the torso geoms, in order to plot the movement trajectories."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "8Gx39DMEUZDt"
      },
      "outputs": [],
      "source": [
        "#@title Video of the movement{vertical-output: true}\n",
        "#@test {\"timeout\": 600}\n",
        "\n",
        "duration = 10   # (Seconds)\n",
        "framerate = 30  # (Hz)\n",
        "video = []\n",
        "pos_x = []\n",
        "pos_y = []\n",
        "torsos = []  # List of torso geom elements.\n",
        "actuators = []  # List of actuator elements.\n",
        "for creature in creatures:\n",
        "  torsos.append(creature.find('geom', 'torso'))\n",
        "  actuators.extend(creature.find_all('actuator'))\n",
        "\n",
        "# Control signal frequency, phase, amplitude.\n",
        "freq = 5\n",
        "phase = 2 * np.pi * random_state.rand(len(actuators))\n",
        "amp = 0.9\n",
        "\n",
        "# Simulate, saving video frames and torso locations.\n",
        "physics.reset()\n",
        "while physics.data.time \u003c duration:\n",
        "  # Inject controls and step the physics.\n",
        "  physics.bind(actuators).ctrl = amp * np.sin(freq * physics.data.time + phase)\n",
        "  physics.step()\n",
        "\n",
        "  # Save torso horizontal positions using bind().\n",
        "  pos_x.append(physics.bind(torsos).xpos[:, 0].copy())\n",
        "  pos_y.append(physics.bind(torsos).xpos[:, 1].copy())\n",
        "\n",
        "  # Save video frames.\n",
        "  if len(video) \u003c physics.data.time * framerate:\n",
        "    pixels = physics.render()\n",
        "    video.append(pixels.copy())\n",
        "\n",
        "display_video(video, framerate)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "u09JfenWYLZu"
      },
      "outputs": [],
      "source": [
        "#@title Movement trajectories{vertical-output: true}\n",
        "\n",
        "creature_colors = physics.bind(torsos).rgba[:, :3]\n",
        "fig, ax = plt.subplots(figsize=(4, 4))\n",
        "ax.set_prop_cycle(color=creature_colors)\n",
        "_ = ax.plot(pos_x, pos_y, linewidth=4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kggQyvNpf_Y9"
      },
      "source": [
        "The plot above shows the movement trajectories of the creatures' torsos. Note how `physics.bind(torsos)` was used to access both `xpos` and `rgba` values. Once the `Physics` has been instantiated by `from_mjcf_model()`, the `bind()` method exposes both the associated `mjData` and `mjModel` fields of an `mjcf` element, providing unified access to all quantities in the simulation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wcRX_wu_8q8u"
      },
      "source": [
        "# Composer tutorial"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1DMhNPE5tSdw"
      },
      "source": [
        "In this tutorial we will create a task requiring our \"creature\" above to press a colour-changing button on the floor with a prescribed force. We begin by implementing our creature as a `composer.Entity`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WwfzIqgNuFKt"
      },
      "outputs": [],
      "source": [
        "#@title The `Creature` class\n",
        "\n",
        "\n",
        "class Creature(composer.Entity):\n",
        "  \"\"\"A multi-legged creature derived from `composer.Entity`.\"\"\"\n",
        "  def _build(self, num_legs):\n",
        "    self._model = make_creature(num_legs)\n",
        "\n",
        "  def _build_observables(self):\n",
        "    return CreatureObservables(self)\n",
        "\n",
        "  @property\n",
        "  def mjcf_model(self):\n",
        "    return self._model\n",
        "\n",
        "  @property\n",
        "  def actuators(self):\n",
        "    return tuple(self._model.find_all('actuator'))\n",
        "\n",
        "\n",
        "# Add simple observable features for joint angles and velocities.\n",
        "class CreatureObservables(composer.Observables):\n",
        "\n",
        "  @composer.observable\n",
        "  def joint_positions(self):\n",
        "    all_joints = self._entity.mjcf_model.find_all('joint')\n",
        "    return observable.MJCFFeature('qpos', all_joints)\n",
        "\n",
        "  @composer.observable\n",
        "  def joint_velocities(self):\n",
        "    all_joints = self._entity.mjcf_model.find_all('joint')\n",
        "    return observable.MJCFFeature('qvel', all_joints)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CXZOBK6RkjxH"
      },
      "source": [
        "The `Creature` Entity includes generic Observables for joint angles and velocities. Because `find_all()` is called on the `Creature`'s MJCF model, it will only return the creature's leg joints, and not the \"free\" joint with which it will be attached to the world.\n",
        "\n",
        "Note that Composer Entities should override the `_build` and `_build_observables` methods rather than `__init__`. The implementation of `__init__` in the base class calls `_build` and `_build_observables`, in that order, to ensure that the entity's MJCF model is created before its observables. This was a design choice which allows the user to refer to an observable as an attribute (`entity.observables.foo`) while still making it clear which attributes are observables.\n",
        "\n",
        "The stateful `Button` class derives from `composer.Entity` and implements the `initialize_episode` and `after_substep` callbacks."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "BE9VU2EOvR-u"
      },
      "outputs": [],
      "source": [
        "#@title The `Button` class\n",
        "\n",
        "NUM_SUBSTEPS = 25  # The number of physics substeps per control timestep.\n",
        "\n",
        "\n",
        "class Button(composer.Entity):\n",
        "  \"\"\"A button Entity which changes colour when pressed with certain force.\"\"\"\n",
        "  def _build(self, target_force_range=(5, 10)):\n",
        "    self._min_force, self._max_force = target_force_range\n",
        "    self._mjcf_model = mjcf.RootElement()\n",
        "    self._geom = self._mjcf_model.worldbody.add(\n",
        "        'geom', type='cylinder', size=[0.25, 0.02], rgba=[1, 0, 0, 1])\n",
        "    self._site = self._mjcf_model.worldbody.add(\n",
        "        'site', type='cylinder', size=self._geom.size*1.01, rgba=[1, 0, 0, 0])\n",
        "    self._sensor = self._mjcf_model.sensor.add('touch', site=self._site)\n",
        "    self._num_activated_steps = 0\n",
        "\n",
        "  def _build_observables(self):\n",
        "    return ButtonObservables(self)\n",
        "\n",
        "  @property\n",
        "  def mjcf_model(self):\n",
        "    return self._mjcf_model\n",
        "\n",
        "  # Update the activation (and colour) if the desired force is applied.\n",
        "  def _update_activation(self, physics):\n",
        "    current_force = physics.bind(self.touch_sensor).sensordata[0]\n",
        "    self._is_activated = (current_force \u003e= self._min_force and\n",
        "                          current_force \u003c= self._max_force)\n",
        "    physics.bind(self._geom).rgba = (\n",
        "        [0, 1, 0, 1] if self._is_activated else [1, 0, 0, 1])\n",
        "    self._num_activated_steps += int(self._is_activated)\n",
        "\n",
        "  def initialize_episode(self, physics, random_state):\n",
        "    self._reward = 0.0\n",
        "    self._num_activated_steps = 0\n",
        "    self._update_activation(physics)\n",
        "\n",
        "  def after_substep(self, physics, random_state):\n",
        "    self._update_activation(physics)\n",
        "\n",
        "  @property\n",
        "  def touch_sensor(self):\n",
        "    return self._sensor\n",
        "\n",
        "  @property\n",
        "  def num_activated_steps(self):\n",
        "    return self._num_activated_steps\n",
        "\n",
        "\n",
        "class ButtonObservables(composer.Observables):\n",
        "  \"\"\"A touch sensor which averages contact force over physics substeps.\"\"\"\n",
        "  @composer.observable\n",
        "  def touch_force(self):\n",
        "    return observable.MJCFFeature('sensordata', self._entity.touch_sensor,\n",
        "                                  buffer_size=NUM_SUBSTEPS, aggregator='mean')"
      ]
    },
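    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `touch_force` observable above is configured with `buffer_size=NUM_SUBSTEPS` and `aggregator='mean'`. The idea of averaging the most recent readings can be sketched in a few lines. This is only an illustration of the concept under simplifying assumptions, not Composer's actual buffering code, and the class name `MeanBuffer` is hypothetical.\n",
        "\n",
        "```python\n",
        "import collections\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "class MeanBuffer:\n",
        "  # Keeps the last `buffer_size` readings and reports their mean.\n",
        "\n",
        "  def __init__(self, buffer_size):\n",
        "    # A bounded deque silently drops the oldest reading when full.\n",
        "    self._values = collections.deque(maxlen=buffer_size)\n",
        "\n",
        "  def insert(self, value):\n",
        "    self._values.append(float(value))\n",
        "\n",
        "  def read(self):\n",
        "    return float(np.mean(list(self._values)))\n",
        "\n",
        "\n",
        "buffer = MeanBuffer(buffer_size=25)\n",
        "for force in range(25):  # One made-up reading per physics substep.\n",
        "  buffer.insert(force)\n",
        "print(buffer.read())\n",
        "```\n",
        "\n",
        "Averaging over the substeps smooths out contact forces that fluctuate from one physics step to the next."
      ]
    },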
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "D9vB5nCwkyIW"
      },
      "source": [
        "Note how the Button counts the number of sub-steps during which it is pressed with the desired force. It also exposes an `Observable` of the force being applied to the button, whose value is an average of the readings over the physics time-steps.\n",
        "\n",
        "Using the `variation` modules and an arena factory, we now define a random initialiser for the button position and then the task itself:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "aDTTQMHtVawM"
      },
      "outputs": [],
      "source": [
        "#@title Random initialiser using `composer.variation`\n",
        "\n",
        "\n",
        "class UniformCircle(variation.Variation):\n",
        "  \"\"\"A uniformly sampled horizontal point on a circle of radius `distance`.\"\"\"\n",
        "  def __init__(self, distance):\n",
        "    self._distance = distance\n",
        "    self._heading = distributions.Uniform(0, 2*np.pi)\n",
        "\n",
        "  def __call__(self, initial_value=None, current_value=None, random_state=None):\n",
        "    distance, heading = variation.evaluate(\n",
        "        (self._distance, self._heading), random_state=random_state)\n",
        "    return (distance*np.cos(heading), distance*np.sin(heading), 0)"
      ]
    },
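    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The geometry behind `UniformCircle` can be reproduced without Composer. The sketch below draws a distance and a heading directly from a NumPy `RandomState`; the function name `sample_on_circle` is hypothetical, and the `variation.evaluate` machinery used above is deliberately left out.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def sample_on_circle(min_distance, max_distance, random_state):\n",
        "  # Draw a radius and a heading, then convert to Cartesian coordinates.\n",
        "  distance = random_state.uniform(min_distance, max_distance)\n",
        "  heading = random_state.uniform(0, 2 * np.pi)\n",
        "  return np.array(\n",
        "      [distance * np.cos(heading), distance * np.sin(heading), 0.0])\n",
        "\n",
        "\n",
        "point = sample_on_circle(0.5, 0.75, np.random.RandomState(0))\n",
        "print(point, np.hypot(point[0], point[1]))\n",
        "```\n",
        "\n",
        "The sampled point always lies on the ground plane, at a horizontal distance from the origin between the two bounds."
      ]
    },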
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "dgZwP-pvxJdt"
      },
      "outputs": [],
      "source": [
        "#@title The `PressWithSpecificForce` task\n",
        "\n",
        "\n",
        "class PressWithSpecificForce(composer.Task):\n",
        "\n",
        "  def __init__(self, creature):\n",
        "    self._creature = creature\n",
        "    self._arena = floors.Floor()\n",
        "    self._arena.add_free_entity(self._creature)\n",
        "    self._arena.mjcf_model.worldbody.add('light', pos=(0, 0, 4))\n",
        "    self._button = Button()\n",
        "    self._arena.attach(self._button)\n",
        "\n",
        "    # Configure initial poses\n",
        "    self._creature_initial_pose = (0, 0, 0.15)\n",
        "    button_distance = distributions.Uniform(0.5, .75)\n",
        "    self._button_initial_pose = UniformCircle(button_distance)\n",
        "\n",
        "    # Configure variators\n",
        "    self._mjcf_variator = variation.MJCFVariator()\n",
        "    self._physics_variator = variation.PhysicsVariator()\n",
        "\n",
        "    # Configure and enable observables\n",
        "    pos_corruptor = noises.Additive(distributions.Normal(scale=0.01))\n",
        "    self._creature.observables.joint_positions.corruptor = pos_corruptor\n",
        "    self._creature.observables.joint_positions.enabled = True\n",
        "    vel_corruptor = noises.Multiplicative(distributions.LogNormal(sigma=0.01))\n",
        "    self._creature.observables.joint_velocities.corruptor = vel_corruptor\n",
        "    self._creature.observables.joint_velocities.enabled = True\n",
        "    self._button.observables.touch_force.enabled = True\n",
        "\n",
        "    def to_button(physics):\n",
        "      button_pos, _ = self._button.get_pose(physics)\n",
        "      return self._creature.global_vector_to_local_frame(physics, button_pos)\n",
        "\n",
        "    self._task_observables = {}\n",
        "    self._task_observables['button_position'] = observable.Generic(to_button)\n",
        "\n",
        "    for obs in self._task_observables.values():\n",
        "      obs.enabled = True\n",
        "\n",
        "    self.control_timestep = NUM_SUBSTEPS * self.physics_timestep\n",
        "\n",
        "  @property\n",
        "  def root_entity(self):\n",
        "    return self._arena\n",
        "\n",
        "  @property\n",
        "  def task_observables(self):\n",
        "    return self._task_observables\n",
        "\n",
        "  def initialize_episode_mjcf(self, random_state):\n",
        "    self._mjcf_variator.apply_variations(random_state)\n",
        "\n",
        "  def initialize_episode(self, physics, random_state):\n",
        "    self._physics_variator.apply_variations(physics, random_state)\n",
        "    creature_pose, button_pose = variation.evaluate(\n",
        "        (self._creature_initial_pose, self._button_initial_pose),\n",
        "        random_state=random_state)\n",
        "    self._creature.set_pose(physics, position=creature_pose)\n",
        "    self._button.set_pose(physics, position=button_pose)\n",
        "\n",
        "  def get_reward(self, physics):\n",
        "    return self._button.num_activated_steps / NUM_SUBSTEPS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "dRuuZdLpthbv"
      },
      "outputs": [],
      "source": [
        "#@title Instantiating an environment{vertical-output: true}\n",
        "\n",
        "creature = Creature(num_legs=4)\n",
        "task = PressWithSpecificForce(creature)\n",
        "env = composer.Environment(task, random_state=np.random.RandomState(42))\n",
        "\n",
        "env.reset()\n",
        "PIL.Image.fromarray(env.physics.render())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "giTL_6euZFlw"
      },
      "source": [
        "# The *Control Suite*"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zfIcrDECtdB2"
      },
      "source": [
        "The **Control Suite** is a set of stable, well-tested tasks designed to serve as a benchmark for continuous control learning agents. Tasks are written using the basic MuJoCo wrapper interface. Standardised action, observation and reward structures make suite-wide benchmarking simple and learning curves easy to interpret. Control Suite domains are not meant to be modified, in order to facilitate benchmarking. For full details regarding benchmarking, please refer to our original [publication](https://arxiv.org/abs/1801.00690).\n",
        "\n",
        "A video of solved benchmark tasks is [available here](https://www.youtube.com/watch?v=rAai4QzcYbs\u0026feature=youtu.be).\n",
        "\n",
        "The suite comes with convenient module-level tuples for iterating over tasks:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "a_whTJG8uTp1"
      },
      "outputs": [],
      "source": [
        "#@title Iterating over tasks{vertical-output: true}\n",
        "\n",
        "max_len = max(len(d) for d, _ in suite.BENCHMARKING)\n",
        "for domain, task in suite.BENCHMARKING:\n",
        "  print(f'{domain:\u003c{max_len}}  {task}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "qN8y3etfZFly"
      },
      "outputs": [],
      "source": [
        "#@title Loading and simulating a `suite` task{vertical-output: true}\n",
        "\n",
        "# Load the environment\n",
        "random_state = np.random.RandomState(42)\n",
        "env = suite.load('hopper', 'stand', task_kwargs={'random': random_state})\n",
        "\n",
        "# Simulate episode with random actions\n",
        "duration = 4  # Seconds\n",
        "frames = []\n",
        "ticks = []\n",
        "rewards = []\n",
        "observations = []\n",
        "\n",
        "spec = env.action_spec()\n",
        "time_step = env.reset()\n",
        "\n",
        "while env.physics.data.time \u003c duration:\n",
        "\n",
        "  action = random_state.uniform(spec.minimum, spec.maximum, spec.shape)\n",
        "  time_step = env.step(action)\n",
        "\n",
        "  camera0 = env.physics.render(camera_id=0, height=200, width=200)\n",
        "  camera1 = env.physics.render(camera_id=1, height=200, width=200)\n",
        "  frames.append(np.hstack((camera0, camera1)))\n",
        "  rewards.append(time_step.reward)\n",
        "  observations.append(copy.deepcopy(time_step.observation))\n",
        "  ticks.append(env.physics.data.time)\n",
        "\n",
        "html_video = display_video(frames, framerate=1./env.control_timestep())\n",
        "\n",
        "# Show video and plot reward and observations\n",
        "num_sensors = len(time_step.observation)\n",
        "\n",
        "_, ax = plt.subplots(1 + num_sensors, 1, sharex=True, figsize=(4, 8))\n",
        "ax[0].plot(ticks, rewards)\n",
        "ax[0].set_ylabel('reward')\n",
        "ax[-1].set_xlabel('time')\n",
        "\n",
        "for i, key in enumerate(time_step.observation):\n",
        "  data = np.asarray([observations[j][key] for j in range(len(observations))])\n",
        "  ax[i+1].plot(ticks, data, label=key)\n",
        "  ax[i+1].set_ylabel(key)\n",
        "\n",
        "html_video"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ggVbQr5hZFl5"
      },
      "outputs": [],
      "source": [
        "#@title Visualizing an initial state of one task per domain in the Control Suite\n",
        "domains_tasks = {domain: task for domain, task in suite.ALL_TASKS}\n",
        "random_state = np.random.RandomState(42)\n",
        "num_domains = len(domains_tasks)\n",
        "n_col = num_domains // int(np.sqrt(num_domains))\n",
        "n_row = num_domains // n_col + int(0 \u003c num_domains % n_col)\n",
        "_, ax = plt.subplots(n_row, n_col, figsize=(12, 12))\n",
        "for a in ax.flat:\n",
        "  a.axis('off')\n",
        "  a.grid(False)\n",
        "\n",
        "print(f'Iterating over all {num_domains} domains in the Suite:')\n",
        "for j, [domain, task] in enumerate(domains_tasks.items()):\n",
        "  print(domain, task)\n",
        "\n",
        "  env = suite.load(domain, task, task_kwargs={'random': random_state})\n",
        "  timestep = env.reset()\n",
        "  pixels = env.physics.render(height=200, width=200, camera_id=0)\n",
        "\n",
        "  ax.flat[j].imshow(pixels)\n",
        "  ax.flat[j].set_title(domain + ': ' + task)\n",
        "\n",
        "clear_output()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JHSvxHiaopDb"
      },
      "source": [
        "# Locomotion"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yTn3C03dpHzL"
      },
      "source": [
        "## Humanoid running along corridor with obstacles\n",
        "\n",
        "As an illustrative example of using the Locomotion infrastructure to build an RL environment, consider placing a humanoid in a corridor with walls, and a task specifying that the humanoid will be rewarded for running along this corridor, navigating around the wall obstacles using vision. We instantiate the environment as a composition of the Walker, Arena, and Task as follows. First, we build a position-controlled CMU humanoid walker. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gE8rrB7PpN9X"
      },
      "outputs": [],
      "source": [
        "#@title A position controlled `cmu_humanoid`\n",
        "\n",
        "walker = cmu_humanoid.CMUHumanoidPositionControlledV2020(\n",
        "    observable_options={'egocentric_camera': dict(enabled=True)})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3fYbaDflBrgE"
      },
      "source": [
        "Next, we construct a corridor-shaped arena that is obstructed by walls."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "t-O17Fnm3E6R"
      },
      "outputs": [],
      "source": [
        "#@title A corridor arena with wall obstacles\n",
        "\n",
        "arena = corridor_arenas.WallsCorridor(\n",
        "    wall_gap=3.,\n",
        "    wall_width=distributions.Uniform(2., 3.),\n",
        "    wall_height=distributions.Uniform(2.5, 3.5),\n",
        "    corridor_width=4.,\n",
        "    corridor_length=30.,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "970nN38eBx-R"
      },
      "source": [
        "The task constructor places the walker in the arena."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "dz4Jy2UGpQ4Z"
      },
      "outputs": [],
      "source": [
        "#@title A task to navigate the arena\n",
        "\n",
        "task = corridor_tasks.RunThroughCorridor(\n",
        "    walker=walker,\n",
        "    arena=arena,\n",
        "    walker_spawn_position=(0.5, 0, 0),\n",
        "    target_velocity=3.0,\n",
        "    physics_timestep=0.005,\n",
        "    control_timestep=0.03,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "r-Oy-qTSB4HW"
      },
      "source": [
        "Finally, a task that rewards the agent for running down the corridor at a specific velocity is instantiated as a `composer.Environment`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "sQXlaEZk3ytl"
      },
      "outputs": [],
      "source": [
        "#@title The `RunThroughCorridor` environment\n",
        "\n",
        "env = composer.Environment(\n",
        "    task=task,\n",
        "    time_limit=10,\n",
        "    random_state=np.random.RandomState(42),\n",
        "    strip_singleton_obs_buffer_dim=True,\n",
        ")\n",
        "env.reset()\n",
        "pixels = []\n",
        "for camera_id in range(3):\n",
        "  pixels.append(env.physics.render(camera_id=camera_id, width=240))\n",
        "PIL.Image.fromarray(np.hstack(pixels))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HuuQLm8YopDe"
      },
      "source": [
        "## Multi-Agent Soccer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OPNshDEEopDf"
      },
      "source": [
        "Building on Composer and Locomotion libraries, the Multi-agent soccer environments, introduced in [this paper](https://arxiv.org/abs/1902.07151), follow a consistent task structure of Walkers, Arena, and Task where instead of a single walker, we inject multiple walkers that can interact with each other physically in the same scene. The code snippet below shows how to instantiate a 2-vs-2 Multi-agent Soccer environment with the simple, 5 degree-of-freedom `BoxHead` walker type."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "zAb3je0DAeQo"
      },
      "outputs": [],
      "source": [
        "#@title 2-v-2 `Boxhead` soccer\n",
        "\n",
        "random_state = np.random.RandomState(42)\n",
        "env = soccer.load(\n",
        "    team_size=2,\n",
        "    time_limit=45.,\n",
        "    random_state=random_state,\n",
        "    disable_walker_contacts=False,\n",
        "    walker_type=soccer.WalkerType.BOXHEAD,\n",
        ")\n",
        "env.reset()\n",
        "pixels = []\n",
        "# Select a random subset of 6 cameras (soccer envs have lots of cameras)\n",
        "cameras = random_state.choice(env.physics.model.ncam, 6, replace=False)\n",
        "for camera_id in cameras:\n",
        "  pixels.append(env.physics.render(camera_id=camera_id, width=240))\n",
        "image = np.vstack((np.hstack(pixels[:3]), np.hstack(pixels[3:])))\n",
        "PIL.Image.fromarray(image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "J_5C2k0NGvxE"
      },
      "source": [
        " It can trivially be replaced by e.g. the `WalkerType.ANT` walker:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WDIGodhBG-Mn"
      },
      "outputs": [],
      "source": [
        "#@title 3-v-3 `Ant` soccer\n",
        "\n",
        "random_state = np.random.RandomState(42)\n",
        "env = soccer.load(\n",
        "    team_size=3,\n",
        "    time_limit=45.,\n",
        "    random_state=random_state,\n",
        "    disable_walker_contacts=False,\n",
        "    walker_type=soccer.WalkerType.ANT,\n",
        ")\n",
        "env.reset()\n",
        "\n",
        "pixels = []\n",
        "cameras = random_state.choice(env.physics.model.ncam, 6, replace=False)\n",
        "for camera_id in cameras:\n",
        "  pixels.append(env.physics.render(camera_id=camera_id, width=240))\n",
        "image = np.vstack((np.hstack(pixels[:3]), np.hstack(pixels[3:])))\n",
        "PIL.Image.fromarray(image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MvK9BW4A5c9p"
      },
      "source": [
        "# Manipulation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jPt27n2Dch_m"
      },
      "source": [
        "The `manipulation` module provides a robotic arm, a set of simple objects, and tools for building reward functions for manipulation tasks."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cZxmJoovahCA"
      },
      "outputs": [],
      "source": [
        "#@title Listing all `manipulation` tasks{vertical-output: true}\n",
        "\n",
        "# `ALL` is a tuple containing the names of all of the environments in the suite.\n",
        "print('\\n'.join(manipulation.ALL))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "oj0cJFlR5nTS"
      },
      "outputs": [],
      "source": [
        "#@title Listing `manipulation` tasks that use vision{vertical-output: true}\n",
        "print('\\n'.join(manipulation.get_environments_by_tag('vision')))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "e_6q4FqFIKxy"
      },
      "outputs": [],
      "source": [
        "#@title Loading and simulating a `manipulation` task{vertical-output: true}\n",
        "#@test {\"timeout\": 180}\n",
        "\n",
        "env = manipulation.load('stack_2_of_3_bricks_random_order_vision', seed=42)\n",
        "action_spec = env.action_spec()\n",
        "\n",
        "def sample_random_action():\n",
        "  return env.random_state.uniform(\n",
        "      low=action_spec.minimum,\n",
        "      high=action_spec.maximum,\n",
        "  ).astype(action_spec.dtype, copy=False)\n",
        "\n",
        "# Step the environment through a full episode using random actions and record\n",
        "# the camera observations.\n",
        "frames = []\n",
        "timestep = env.reset()\n",
        "frames.append(timestep.observation['front_close'])\n",
        "while not timestep.last():\n",
        "  timestep = env.step(sample_random_action())\n",
        "  frames.append(timestep.observation['front_close'])\n",
        "all_frames = np.concatenate(frames, axis=0)\n",
        "display_video(all_frames, 30)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "YvyGCsgSCxHQ",
        "VEEj3Qw60y73",
        "-_7tVg-zzjzW",
        "jZXz9rPYGA-Y",
        "MdUF2UYmR4TA",
        "nQ8XOnRQx7T1",
        "gf9h_wi9weet",
        "UAMItwu8e1WR",
        "wcRX_wu_8q8u",
        "giTL_6euZFlw",
        "JHSvxHiaopDb",
        "MvK9BW4A5c9p"
      ],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "dm_control",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1EBFHfEqrhxufldesVECPZfJF9HsKxAd2",
          "timestamp": 1588282893099
        },
        {
          "file_id": "0B8iLF4FF6KlDd0I3X2xTSHpQVEk",
          "timestamp": 1573234197672
        },
        {
          "file_id": "0B7GWNc3xBQH5ZjlPRzZGelNTcjg",
          "timestamp": 1483004131812
        },
        {
          "file_id": "0B7GWNc3xBQH5bHBRVkhOa1g3a28",
          "timestamp": 1482848217314
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
