{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ssCOanHc8JH_"
      },
      "source": [
        "# [BIG-Gym](https://github.com/google/brax/blob/main/brax/experimental/biggym) RL training\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/biggym/biggym_rl.ipynb)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rlVNS8JstMRr"
      },
      "outputs": [],
      "source": [
        "#@title Colab setup and imports\n",
        "#@markdown ## ⚠️ PLEASE NOTE:\n",
        "#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime \u003e Change Runtime Type, then select **'TPU'** in the dropdown.\n",
        "from datetime import datetime\n",
        "import functools\n",
        "import os\n",
        "import pprint\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "# from jax.config import config\n",
        "# config.update(\"jax_debug_nans\", True)\n",
        "from IPython.display import HTML, clear_output\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "try:\n",
        "  import brax\n",
        "except ImportError:\n",
        "  !pip install git+https://github.com/google/brax.git@main\n",
        "  clear_output()\n",
        "  import brax\n",
        "\n",
        "from brax.io import html\n",
        "from brax.experimental import biggym\n",
        "from brax.experimental.composer import composer\n",
        "from brax.experimental.composer.training import mappo\n",
        "from brax.experimental.braxlines import experiments\n",
        "from brax.experimental.braxlines.common import evaluators\n",
        "from brax.experimental.braxlines.common import logger_utils\n",
        "from brax.experimental.braxlines.training import ppo\n",
        "\n",
        "if \"COLAB_TPU_ADDR\" in os.environ:\n",
        "  from jax.tools import colab_tpu\n",
        "  colab_tpu.setup_tpu()\n",
        "\n",
        "def show_env(env, seed: int = 0):\n",
        "  \"\"\"Resets `env` and returns its first state rendered via `brax.io.html`.\n",
        "\n",
        "  Args:\n",
        "    env: environment exposing `reset(rng=...)` and a `sys` attribute.\n",
        "    seed: integer seed for the PRNG key handed to `env.reset`.\n",
        "\n",
        "  Returns:\n",
        "    An `IPython.display.HTML` object wrapping the rendered state.\n",
        "  \"\"\"\n",
        "  jit_env_reset = jax.jit(env.reset)\n",
        "  state = jit_env_reset(rng=jax.random.PRNGKey(seed=seed))\n",
        "  # wait=True holds the old output until the new one is ready to replace it.\n",
        "  clear_output(wait=True)\n",
        "  return HTML(html.render(env.sys, [state.qp]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j-HZevAii3z-"
      },
      "outputs": [],
      "source": [
        "# @title Register a BIG-Gym registry\n",
        "registry_name = 'proant' # @param {type: 'string'}\n",
        "register_all = True # @param {type: 'boolean'}\n",
        "\n",
        "# When requested, run the verbose bulk registration first and pretty-print\n",
        "# `biggym.ENVS_BY_TRACKS` for reference.\n",
        "if register_all:\n",
        "  biggym.register_all(verbose=True)\n",
        "  pprint.pprint(biggym.ENVS_BY_TRACKS)\n",
        "\n",
        "# Register the selected registry; the fourth returned value is not used here.\n",
        "env_names, comp_names, task_env_names, _ = biggym.register(registry_name)\n",
        "\n",
        "# Report the names returned by the registration call.\n",
        "print(f'env_names: {env_names}')\n",
        "print(f'comp_names: {comp_names}')\n",
        "print(f'task_env_names: {task_env_names}')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "riA5oBKFK5B7"
      },
      "outputs": [],
      "source": [
        "#@title Specify an environment\n",
        "env_name = 'sumo__proant__ant' # @param {type: 'string'}\n",
        "output_path = '' # @param {type: 'string'}\n",
        "show_params = True # @param {'type':'boolean'}\n",
        "\n",
        "# An empty `output_path` skips this block; otherwise the path is extended to\n",
        "# <output_path>/<YYYYMMDD>/<env_name>.\n",
        "if output_path:\n",
        "  output_path = f'{output_path}/{datetime.now().strftime(\"%Y%m%d\")}'\n",
        "  output_path = f'{output_path}/{env_name}'\n",
        "  print(f'Saving outputs to {output_path}')\n",
        "\n",
        "# Optionally print what `biggym.inspect_env` reports for the chosen env.\n",
        "if show_params:\n",
        "  supported_params, support_kwargs = biggym.inspect_env(env_name=env_name)\n",
        "  print(f'supported_params for \"{env_name}\" =')\n",
        "  pprint.pprint(supported_params)\n",
        "  print(f'support variable-length kwargs? (i.e. **kwargs): {support_kwargs}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T1ZJ2jZDKH8Y"
      },
      "outputs": [],
      "source": [
        "#@title Create a custom env\n",
        "env_params =  {'num_legs': 2}# @param{'type': 'raw'}\n",
        "mode = 'viewer'# @param ['print_step', 'print_obs', 'print_sys', 'viewer']\n",
        "ignore_kwargs = True # @param {'type':'boolean'}\n",
        "\n",
        "# NOTE(review): `mode` is not read by any code in this cell; the viewer is\n",
        "# shown whichever option is selected.\n",
        "\n",
        "# A falsy form value (e.g. None) is treated as \"no overrides\" before the\n",
        "# parameters are checked against the env.\n",
        "env_params = env_params or {}\n",
        "biggym.assert_env_params(env_name, env_params, ignore_kwargs)\n",
        "\n",
        "# Build the env factory, instantiate one env and display it. The last\n",
        "# expression is what the cell renders.\n",
        "env_fn = composer.create_fn(env_name=env_name, **env_params)\n",
        "# env_fn = biggym.create_fn(env_name=env_name, **env_params)\n",
        "env = env_fn()\n",
        "show_env(env)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WGRizNxK3MtF"
      },
      "outputs": [],
      "source": [
        "#@title Training the custom env\n",
        "num_timesteps_multiplier =   2# @param {type: 'number'}\n",
        "seed = 0 # @param{type: 'integer'}\n",
        "skip_training = False # @param {type: 'boolean'}\n",
        "\n",
        "# `log_path` stays empty when no `output_path` was configured; it is handed\n",
        "# to the Tabulator as its output path.\n",
        "log_path = output_path\n",
        "if log_path:\n",
        "  log_path = f'{log_path}/training_curves.csv'\n",
        "tab = logger_utils.Tabulator(output_path=log_path, append=False)\n",
        "\n",
        "# Multi-agent envs use the `mappo` module, everything else uses `ppo`.\n",
        "if biggym.is_multiagent(env):\n",
        "  ppo_lib = mappo\n",
        "else:\n",
        "  ppo_lib = ppo\n",
        "\n",
        "# The 'ant' parameter preset is used regardless of `env_name`.\n",
        "ppo_params = experiments.defaults.get_ppo_params(\n",
        "    'ant', num_timesteps_multiplier)\n",
        "train_fn = functools.partial(ppo_lib.train, **ppo_params)\n",
        "\n",
        "# `times[0]` marks the moment this cell started.\n",
        "times = [datetime.now()]\n",
        "plotpatterns = ['eval/episode_reward', 'eval/episode_score']\n",
        "\n",
        "# Only the first returned value (the progress callback) is needed here.\n",
        "progress, _, _, _ = experiments.get_progress_fn(\n",
        "    plotpatterns,\n",
        "    times,\n",
        "    tab=tab,\n",
        "    max_ncols=5,\n",
        "    xlim=[0, train_fn.keywords['num_timesteps']],\n",
        "    pre_plot_fn=lambda: clear_output(wait=True),\n",
        "    post_plot_fn=plt.show)\n",
        "\n",
        "if skip_training:\n",
        "  # No training: build params and a jitted inference function directly so\n",
        "  # `params` and `inference_fn` are still defined for later cells.\n",
        "  action_size = (\n",
        "      env.group_action_shapes if biggym.is_multiagent(env)\n",
        "      else env.action_size)\n",
        "  params, inference_fn = ppo_lib.make_params_and_inference_fn(\n",
        "      env.observation_size, action_size,\n",
        "      normalize_observations=True)\n",
        "  inference_fn = jax.jit(inference_fn)\n",
        "else:\n",
        "  inference_fn, params, _ = train_fn(\n",
        "      environment_fn=env_fn, seed=seed,\n",
        "      extra_step_kwargs=False, progress_fn=progress)\n",
        "  # `times` starts with a single entry; later entries are presumably appended\n",
        "  # by the progress callback (TODO confirm), so guard against it never firing.\n",
        "  if len(times) > 1:\n",
        "    print(f'time to jit: {times[1] - times[0]}')\n",
        "    print(f'time to train: {times[-1] - times[1]}')\n",
        "  # Only claim logs were saved when a log path was actually configured.\n",
        "  if log_path:\n",
        "    print(f'Saved logs to {log_path}')\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P-0VYySqOEk0"
      },
      "outputs": [],
      "source": [
        "#@title Visualizing a trajectory of the learned inference function\n",
        "eval_seed = 0  # @param {'type': 'integer'}\n",
        "batch_size =  0# @param {type: 'integer'}\n",
        "\n",
        "# Note: this rebinds `env` to the environment returned by the evaluator.\n",
        "env, states = evaluators.visualize_env(\n",
        "    env_fn=env_fn,\n",
        "    inference_fn=inference_fn,\n",
        "    params=params,\n",
        "    batch_size=batch_size,\n",
        "    seed=eval_seed,\n",
        "    output_path=output_path,\n",
        "    verbose=True,\n",
        ")\n",
        "\n",
        "# Render the `qp` of every returned state; the HTML object is the cell output.\n",
        "trajectory_qps = [step_state.qp for step_state in states]\n",
        "HTML(html.render(env.sys, trajectory_qps))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-48ybSUcyMJu"
      },
      "outputs": [],
      "source": [
        "#@title Plot information of the trajectory\n",
        "# The first state is left out of the plots (presumably the reset state -\n",
        "# TODO confirm).\n",
        "plotted_states = states[1:]\n",
        "experiments.plot_states(plotted_states, max_ncols=5)\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "biggym_rl.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1PWXVD5BforifYfej0R-PQaW8T6hlFuJ6",
          "timestamp": 1639457023559
        },
        {
          "file_id": "1BCqjiaBc13bQK1gQiEMUQGrxjPTov2EN",
          "timestamp": 1639058480819
        },
        {
          "file_id": "1ZaAO4BS2tJ_03CIXdBCFibZR2yLl6dtv",
          "timestamp": 1630801484981
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
