{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ssCOanHc8JH_"
      },
      "source": [
        "# Training Divergence Minimization (D-min) RL  algorithms in Brax\n",
        "\n",
        "In [Brax Training](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb) we tried out [gym](https://gym.openai.com/)-like environments and PPO, SAC, evolutionary search, and trajectory optimization algorithms. We can build various RL algorithms on top of these ultra-fast implementations. This colab runs a family of [adversarial inverse RL](https://arxiv.org/abs/1911.02256) algorithms, which includes [GAIL](https://papers.nips.cc/paper/2016/hash/cc7e2b878868cbae992d1fb743995d8f-Abstract.html) and [AIRL](https://arxiv.org/abs/1710.11248) as special cases. These algorithms minimize D(p(s,a), p\\*(s,a)) or D(p(s), p\\*(s)), the divergence D between the policy's state(-action) marginal distribution p(s,a) or p(s), and a given target distribution p\\*(s,a) or p\\*(s). As discussed in [f-MAX](https://arxiv.org/abs/1911.02256), these algorithms could also be used for [state-marginal matching](https://arxiv.org/abs/1906.05274) RL besides imitation learning. Let's try them out!\n",
        "\n",
        "This provides a bare bone implementation based on minimal modifications to the\n",
        "baseline [PPO](https://github.com/google/brax/blob/main/brax/training/ppo.py),\n",
        "enabling training in a few minutes. More features, tunings, and benchmarked results will be added."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VYe1kc3a4Oxc"
      },
      "source": [
        "\n",
        "\n",
        "```\n",
        "# This is formatted as code\n",
        "```\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/dmin.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_sOmCoOrF0F8"
      },
      "outputs": [],
      "source": [
        "#@title Install Brax and some helper modules\n",
        "#@markdown ## ⚠️ PLEASE NOTE:\n",
        "#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime \u003e Change runtime type, then select 'TPU' in the dropdown.\n",
        "\n",
        "from datetime import datetime\n",
        "import functools\n",
        "import os\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from IPython.display import HTML, clear_output \n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "try:\n",
        "  import brax\n",
        "except ImportError:\n",
        "  !pip install git+https://github.com/google/brax.git@main\n",
        "  clear_output()\n",
        "  import brax\n",
        "\n",
        "import tensorflow_probability as tfp\n",
        "from brax.io import html\n",
        "from brax.experimental.composer import composer\n",
        "from brax.experimental.braxlines import experiments\n",
        "from brax.experimental.braxlines.common import evaluators\n",
        "from brax.experimental.braxlines.envs.obs_indices import OBS_INDICES\n",
        "from brax.experimental.braxlines.training import ppo\n",
        "from brax.experimental.braxlines.irl_smm import evaluators as irl_evaluators\n",
        "from brax.experimental.braxlines.irl_smm import utils as irl_utils\n",
        "\n",
        "tfp = tfp.substrates.jax\n",
        "tfd = tfp.distributions\n",
        "\n",
        "if 'COLAB_TPU_ADDR' in os.environ:\n",
        "  from jax.tools import colab_tpu\n",
        "  colab_tpu.setup_tpu()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NaJDZqhCLovU"
      },
      "outputs": [],
      "source": [
        "#@title Define task and experiment parameters\n",
        "\n",
        "#@markdown **Task Parameters**\n",
        "#@markdown \n",
        "#@markdown As in [SMM](https://arxiv.org/abs/1906.05274)\n",
        "#@markdown and [f-MAX](https://arxiv.org/abs/1911.02256),\n",
        "#@markdown we assume some task knowledge about interesting dimensions\n",
        "#@markdown of the environment `obs_indices` and their range `obs_scale`.\n",
        "#@markdown This is also used for evaluation and visualization\n",
        "#@markdown\n",
        "#@markdown When the **task parameters** are the same, the metrics computed by\n",
        "#@markdown [irl_smm/evaluators.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/vgcrl/evaluators.py)\n",
        "#@markdown are directly comparable across experiment runs with different\n",
        "#@markdown **experiment parameters**. \n",
        "env_name = 'ant'  # @param ['ant', 'humanoid', 'halfcheetah', 'uni_ant', 'bi_ant']\n",
        "obs_indices = 'vel'  # @param ['vel']\n",
        "target_num_modes =   2# @param{'type': 'integer'}\n",
        "obs_scale = 8.0 #@param{'type': 'number'}\n",
        "obs_indices_str = obs_indices\n",
        "obs_indices = OBS_INDICES[obs_indices][env_name]\n",
        "\n",
        "#@markdown **Experiment Parameters**\n",
        "#@markdown See [irl_smm/utils.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/irl_smm/utils.py)\n",
        "reward_type = \"mle\"  # @param ['gail', 'airl', 'gail2', 'fairl', 'mle']\n",
        "logits_clip_range = 10.0# @param {'type': 'number'}\n",
        "normalize_obs_for_disc = False # @param {'type': 'boolean'}\n",
        "normalize_obs_for_rl = True # @param {'type': 'boolean'}\n",
        "spectral_norm = False  # @param {'type': 'boolean'}\n",
        "gradient_penalty_weight = 0.0 #@param {type: 'number'}\n",
        "env_reward_multiplier = 0.0 # @param {'type': 'number'}\n",
        "evaluate_dist = False # @param{type: 'boolean'}\n",
        "seed = 0 # @param{type: 'integer'}\n",
        "\n",
        "output_path = '' # @param {'type': 'string'}\n",
        "task_name = \"\" # @param {'type': 'string'}\n",
        "exp_name = '' # @param {'type': 'string'}\n",
        "if output_path:\n",
        "  output_path = output_path.format(\n",
        "    date=datetime.now().strftime('%Y%m%d'))\n",
        "  task_name = task_name or f'{env_name}_{obs_indices}_{obs_scale}_{target_num_modes}'\n",
        "  exp_name = exp_name or f'{reward_type}'\n",
        "  output_path = f'{output_path}/{task_name}/{exp_name}'\n",
        "print(f'output_path={output_path}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fUqotLSex9Z3"
      },
      "outputs": [],
      "source": [
        "# @title Generate target distribution to match\n",
        "target_num_samples = 250  # @param{type: 'integer'}\n",
        "\n",
        "rng = jax.random.PRNGKey(seed=seed)\n",
        "jit_get_dist = jax.jit(\n",
        "    functools.partial(\n",
        "        irl_utils.get_multimode_dist,\n",
        "        indexed_obs_dim=len(obs_indices),\n",
        "        num_modes=target_num_modes, scale=obs_scale))\n",
        "target_dist = jit_get_dist()\n",
        "target_data = target_dist.sample(\n",
        "    seed=rng, sample_shape=(target_num_samples,))\n",
        "target_data_2d = irl_utils.make_2d(target_data)\n",
        "\n",
        "print(f'target_data={target_data.shape}')\n",
        "plt.scatter(\n",
        "    x=target_data_2d[:, 0], y=target_data_2d[:, 1], c=jnp.array([0, 0, 1]))\n",
        "plt.xlim((-obs_scale, obs_scale))\n",
        "plt.ylim((-obs_scale, obs_scale))\n",
        "plt.title('target')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rM7nNiXJU-4s"
      },
      "outputs": [],
      "source": [
        "# @title Make environment and inference_fn\n",
        "visualize = False # @param {'type': 'boolean'}\n",
        "\n",
        "base_env_fn = composer.create_fn(env_name=env_name)\n",
        "base_env = base_env_fn()\n",
        "disc = irl_utils.IRLDiscriminator(\n",
        "    input_size=len(obs_indices),\n",
        "    obs_indices=obs_indices,\n",
        "    obs_scale=obs_scale,\n",
        "    include_action=False,\n",
        "    logits_clip_range=logits_clip_range,\n",
        "    spectral_norm=spectral_norm,\n",
        "    gradient_penalty_weight=gradient_penalty_weight,\n",
        "    reward_type=reward_type,\n",
        "    normalize_obs=normalize_obs_for_disc,\n",
        "    target_data=target_data,\n",
        "    target_dist_fn=jit_get_dist,\n",
        "    env=base_env)\n",
        "extra_params = disc.init_model(rng=jax.random.PRNGKey(seed=0))\n",
        "env_fn = irl_utils.create_fn(\n",
        "    env_name=env_name,\n",
        "    wrapper_params=dict(\n",
        "        disc=disc,\n",
        "        env_reward_multiplier=env_reward_multiplier,\n",
        "    ))\n",
        "eval_env_fn = functools.partial(env_fn, auto_reset=False)\n",
        "\n",
        "# make inference functions and goals for evaluation\n",
        "core_env = env_fn()\n",
        "params, inference_fn = ppo.make_params_and_inference_fn(\n",
        "    core_env.observation_size,\n",
        "    core_env.action_size,\n",
        "    normalize_observations=normalize_obs_for_rl,\n",
        "    extra_params=extra_params)\n",
        "inference_fn = jax.jit(inference_fn)\n",
        "\n",
        "# Visualize in 3D\n",
        "if visualize:\n",
        "  env = env_fn()\n",
        "  jit_env_reset = jax.jit(env.reset)\n",
        "  state = jit_env_reset(rng=jax.random.PRNGKey(seed=0))\n",
        "  clear_output()  # clear out jax.lax warning before rendering\n",
        "  HTML(html.render(env.sys, [state.qp]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4vgMSWODfyMC"
      },
      "outputs": [],
      "source": [
        "#@title Training\n",
        "num_timesteps_multiplier =   4# @param {type: 'number'}\n",
        "\n",
        "# We determined some reasonable hyperparameters offline and share them here.\n",
        "ppo_params = experiments.defaults.get_ppo_params(\n",
        "    env_name, num_timesteps_multiplier, default='ant')\n",
        "train_fn = functools.partial(ppo.train, **ppo_params)\n",
        "\n",
        "times = [datetime.now()]\n",
        "plotdata = {}\n",
        "plotkeys = [\n",
        "    'eval/episode_reward', 'losses/disc_loss', 'losses/total_loss',\n",
        "    'losses/policy_loss', 'losses/value_loss', 'losses/entropy_loss',\n",
        "    'metrics/energy_dist']\n",
        "\n",
        "def progress(num_steps, metrics, params):\n",
        "  times.append(datetime.now())\n",
        "  if evaluate_dist:\n",
        "    metrics.update(irl_evaluators.estimate_energy_distance_metric(\n",
        "        params=params, disc=disc, target_data=target_data, env_fn=env_fn,\n",
        "        inference_fn=inference_fn))\n",
        "\n",
        "  for key, v in metrics.items():\n",
        "    plotdata[key] = plotdata.get(key, dict(x=[], y=[]))\n",
        "    plotdata[key]['x'] += [num_steps]\n",
        "    plotdata[key]['y'] += [v]\n",
        "  clear_output(wait=True)\n",
        "  num_figs = len(plotkeys) + 1\n",
        "  fig, axs = plt.subplots(ncols=num_figs, figsize=(3.5 * num_figs, 3))\n",
        "  # plot learning curves\n",
        "  for i, key in enumerate(plotkeys):\n",
        "    if key in plotdata:\n",
        "      axs[i].plot(plotdata[key]['x'], plotdata[key]['y'])\n",
        "    axs[i].set(xlabel='# environment steps', ylabel=key)\n",
        "    axs[i].set_xlim([0, train_fn.keywords['num_timesteps']])\n",
        "  irl_evaluators.visualize_disc(\n",
        "      params=params, disc=disc, num_grid=25, fig=fig, axs=axs)\n",
        "  plt.show()\n",
        "\n",
        "extra_loss_fns = dict(disc_loss=disc.disc_loss_fn)\n",
        "inference_fn, params, _ = train_fn(\n",
        "    seed=seed,\n",
        "    environment_fn=env_fn,\n",
        "    progress_fn=progress,\n",
        "    extra_params=extra_params,\n",
        "    extra_loss_fns=extra_loss_fns)\n",
        "\n",
        "print(f'time to jit: {times[1] - times[0]}')\n",
        "print(f'time to train: {times[-1] - times[1]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p5eWOxg7RmQQ"
      },
      "outputs": [],
      "source": [
        "#@title Visualizing skills of the learned inference function in 2D plot\n",
        "num_samples = 10  # @param {type: 'integer'}\n",
        "time_subsampling = 10  # @param {type: 'integer'}\n",
        "time_last_n = 500 # @param {type: 'integer'}\n",
        "eval_seed = 0  # @param {type: 'integer'}\n",
        "\n",
        "metrics = irl_evaluators.estimate_energy_distance_metric(\n",
        "    params=params,\n",
        "    disc=disc,\n",
        "    target_data=target_data,\n",
        "    env_fn=eval_env_fn,\n",
        "    inference_fn=inference_fn,\n",
        "    num_samples=num_samples,\n",
        "    time_subsampling=time_subsampling,\n",
        "    time_last_n=time_last_n,\n",
        "    visualize=True,\n",
        "    figsize=(3.5,3),\n",
        "    seed=eval_seed,\n",
        "    output_path=output_path,\n",
        ")\n",
        "print(metrics)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RNMLEyaTspEM"
      },
      "outputs": [],
      "source": [
        "#@title Visualizing a trajectory of the learned inference function\n",
        "eval_seed = 0  # @param {'type': 'integer'}\n",
        "\n",
        "env, states = evaluators.visualize_env(\n",
        "    env_fn=eval_env_fn,\n",
        "    inference_fn=inference_fn,\n",
        "    params=params,\n",
        "    batch_size=0,\n",
        "    seed = eval_seed,\n",
        "    step_args = (params['normalizer'], params['extra']),\n",
        "    output_path=output_path,\n",
        ")\n",
        "HTML(html.render(env.sys, [state.qp for state in states]))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "dmin.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1VwIb25nv6nJT52pSuZn4ldtAWKGvENwl",
          "timestamp": 1635166450696
        },
        {
          "file_id": "1Gu8SgV7reDUv8weq2P6PRq_YQ2f88Ahv",
          "timestamp": 1628752357019
        },
        {
          "file_id": "1ZaAO4BS2tJ_03CIXdBCFibZR2yLl6dtv",
          "timestamp": 1628294539853
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
