{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ssCOanHc8JH_"
      },
      "source": [
        "# Create Environments with Braxlines Composer\n",
        "\n",
        "[Braxlines Composer](https://github.com/google/brax/blob/main/brax/experimental/composer) allows modular composition of Brax environments. Let's try it out! "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VYe1kc3a4Oxc"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/composer.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rlVNS8JstMRr"
      },
      "outputs": [],
      "source": [
        "#@title Colab setup and imports\n",
        "#@markdown ## ⚠️ PLEASE NOTE:\n",
        "#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime \u003e Change Runtime Type, then select **'TPU'** in the dropdown.\n",
        "from datetime import datetime\n",
        "import functools\n",
        "import os\n",
        "import pprint\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "# from jax.config import config\n",
        "# config.update(\"jax_debug_nans\", True)\n",
        "from IPython.display import HTML, clear_output\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "try:\n",
        "  import brax\n",
        "except ImportError:\n",
        "  !pip install git+https://github.com/google/brax.git@main\n",
        "  clear_output()\n",
        "  import brax\n",
        "\n",
        "from brax.io import html\n",
        "from brax.experimental.composer import composer\n",
        "from brax.experimental.composer.training import mappo\n",
        "from brax.experimental.braxlines import experiments\n",
        "from brax.experimental.braxlines.common import evaluators\n",
        "from brax.experimental.braxlines.common import logger_utils\n",
        "from brax.experimental.braxlines.training import ppo\n",
        "from brax.experimental.braxlines import experiments\n",
        "\n",
        "if \"COLAB_TPU_ADDR\" in os.environ:\n",
        "  from jax.tools import colab_tpu\n",
        "  colab_tpu.setup_tpu()\n",
        "\n",
        "def show_env(env, mode):\n",
        "  \"\"\"Prints info about `env` or renders it, depending on `mode`.\n",
        "\n",
        "  Args:\n",
        "    env: environment to inspect.\n",
        "    mode: 'print_obs' pretty-prints `composer.get_obs_dict_shape(env)`;\n",
        "      'print_sys' pretty-prints `metadata.config_json` of the unwrapped\n",
        "      env; any other value renders the state returned by `env.reset`.\n",
        "\n",
        "  Returns:\n",
        "    An `HTML` view of the reset state in render mode, otherwise `None`.\n",
        "  \"\"\"\n",
        "  if mode == 'print_obs':\n",
        "    pprint.pprint(composer.get_obs_dict_shape(env))\n",
        "    return None\n",
        "  if mode == 'print_sys':\n",
        "    pprint.pprint(composer.unwrap(env).metadata.config_json)\n",
        "    return None\n",
        "  # Fallback (e.g. 'viewer'): render the state right after a jitted reset.\n",
        "  reset_key = jax.random.PRNGKey(seed=0)\n",
        "  initial_state = jax.jit(env.reset)(rng=reset_key)\n",
        "  clear_output(wait=True)\n",
        "  return HTML(html.render(env.sys, [initial_state.qp]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qL2eK6Ipj04B"
      },
      "outputs": [],
      "source": [
        "# @title List [registered environments](https://github.com/google/brax/blob/main/brax/experimental/composer/envs)\n",
        "# `integer` (not `number`) because the value is used as a slice bound below;\n",
        "# a fractional entry would make the slice raise a TypeError.\n",
        "max_n_env = 10 # @param {'type': 'integer'}\n",
        "\n",
        "# Everything `composer.list_env()` reports; only the first `max_n_env`\n",
        "# entries are printed.\n",
        "env_list = composer.list_env()\n",
        "print(f'# registered envs = {len(env_list)}, e.g.')\n",
        "pprint.pprint(env_list[:max_n_env])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "riA5oBKFK5B7"
      },
      "outputs": [],
      "source": [
        "#@title Specify an environment\n",
        "env_name = 'ant_run' # @param ['humanoid_run', 'squidgame', 'sumo', 'follow', 'chase', 'pro_ant_run', 'ant_run', 'ant_chase', 'ant_push']\n",
        "output_path = '' # @param {type: 'string'}\n",
        "show_params = True # @param {'type':'boolean'}\n",
        "\n",
        "# A non-empty output_path is extended with the current date (YYYYMMDD)\n",
        "# and the env name; an empty one is left untouched.\n",
        "if output_path:\n",
        "  output_path = f'{output_path}/{datetime.now().strftime(\"%Y%m%d\")}/{env_name}'\n",
        "  print(f'Saving outputs to {output_path}')\n",
        "\n",
        "# Optionally report what `composer.inspect_env` returns for this env.\n",
        "if show_params:\n",
        "  supported_params, support_kwargs = composer.inspect_env(env_name=env_name)\n",
        "  print(f'supported_params for \"{env_name}\" =')\n",
        "  pprint.pprint(supported_params)\n",
        "  print(f'support unlisted kwargs? (i.e. **kwargs): {support_kwargs}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T1ZJ2jZDKH8Y"
      },
      "outputs": [],
      "source": [
        "#@title Create a custom env\n",
        "# @markdown put a `None` or a `Dict` as `env_params`, based on `supported_params` above\n",
        "env_params =  None# @param{'type': 'raw'}\n",
        "mode = 'viewer'# @param ['print_step', 'print_obs', 'print_sys', 'viewer']\n",
        "ignore_kwargs = True # @param {'type':'boolean'}\n",
        "# NOTE(review): `show_env` has no dedicated 'print_step' branch, so that\n",
        "# choice currently renders the viewer.\n",
        "\n",
        "# Normalize a `None` (or empty) form value to an empty dict, then validate it.\n",
        "if not env_params:\n",
        "  env_params = {}\n",
        "composer.assert_env_params(env_name, env_params, ignore_kwargs)\n",
        "\n",
        "# Build the env factory, instantiate one env, and display it.\n",
        "env_fn = composer.create_fn(env_name=env_name, **env_params)\n",
        "env = env_fn()\n",
        "show_env(env, mode)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WGRizNxK3MtF"
      },
      "outputs": [],
      "source": [
        "#@title Training the custom env\n",
        "num_timesteps_multiplier =   3# @param {type: 'number'}\n",
        "seed = 0 # @param{type: 'integer'}\n",
        "skip_training = False # @param {type: 'boolean'}\n",
        "\n",
        "# Training curves go to a CSV under output_path; an empty output_path\n",
        "# leaves log_path empty as well.\n",
        "log_path = output_path\n",
        "if log_path:\n",
        "  log_path = f'{log_path}/training_curves.csv'\n",
        "tab = logger_utils.Tabulator(output_path=log_path,\n",
        "    append=False)\n",
        "\n",
        "# Multi-agent envs are trained with mappo, all others with ppo.\n",
        "# NOTE(review): PPO defaults are always looked up under 'ant', whatever\n",
        "# env_name is selected — confirm this is intended.\n",
        "ppo_lib = mappo if env.is_multiagent else ppo\n",
        "ppo_params = experiments.defaults.get_ppo_params(\n",
        "    'ant', num_timesteps_multiplier)\n",
        "train_fn = functools.partial(ppo_lib.train, **ppo_params)\n",
        "\n",
        "# `times` starts with the launch time; the timing printout below reads\n",
        "# later entries (presumably appended by the progress callback — confirm).\n",
        "times = [datetime.now()]\n",
        "plotpatterns = ['eval/episode_reward', 'eval/episode_score']\n",
        "\n",
        "progress, _, _, _ = experiments.get_progress_fn(\n",
        "    plotpatterns, times, tab=tab, max_ncols=5,\n",
        "    xlim=[0, train_fn.keywords['num_timesteps']],\n",
        "    pre_plot_fn = lambda : clear_output(wait=True),\n",
        "    post_plot_fn = plt.show)\n",
        "\n",
        "if skip_training:\n",
        "  # No training: build parameters and an inference function directly.\n",
        "  action_size = (env.group_action_shapes if \n",
        "    env.is_multiagent else env.action_size)\n",
        "  params, inference_fn = ppo_lib.make_params_and_inference_fn(\n",
        "    env.observation_size, action_size,\n",
        "    normalize_observations=True)\n",
        "  inference_fn = jax.jit(inference_fn)\n",
        "else:\n",
        "  inference_fn, params, _ = train_fn(\n",
        "    environment_fn=env_fn, seed=seed,\n",
        "    extra_step_kwargs=False, progress_fn=progress)\n",
        "  print(f'time to jit: {times[1] - times[0]}')\n",
        "  print(f'time to train: {times[-1] - times[1]}')\n",
        "  # Only claim logs were saved when a log file was actually configured.\n",
        "  if log_path:\n",
        "    print(f'Saved logs to {log_path}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P-0VYySqOEk0"
      },
      "outputs": [],
      "source": [
        "#@title Visualizing a trajectory of the learned inference function\n",
        "eval_seed = 0  # @param {'type': 'integer'}\n",
        "batch_size =  0# @param {type: 'integer'}\n",
        "\n",
        "# `visualize_env` hands back an env and a sequence of states; note that\n",
        "# this rebinds the notebook-level `env`.\n",
        "env, states = evaluators.visualize_env(\n",
        "    env_fn=env_fn,\n",
        "    inference_fn=inference_fn,\n",
        "    params=params,\n",
        "    batch_size=batch_size,\n",
        "    seed=eval_seed,\n",
        "    output_path=output_path,\n",
        "    verbose=True)\n",
        "# Render the `qp` of every returned state.\n",
        "HTML(html.render(env.sys, [s.qp for s in states]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-48ybSUcyMJu"
      },
      "outputs": [],
      "source": [
        "#@title Plot information of the trajectory\n",
        "experiments.plot_states(\n",
        "    states[1:],  # the first state is left out of the plots\n",
        "    max_ncols=5)\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "composer.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1BCqjiaBc13bQK1gQiEMUQGrxjPTov2EN",
          "timestamp": 1639370433605
        },
        {
          "file_id": "1ZaAO4BS2tJ_03CIXdBCFibZR2yLl6dtv",
          "timestamp": 1630801484981
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
