{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ssCOanHc8JH_"
      },
      "source": [
        "# Training Mutual Information Maximization (MI-Max) RL algorithms in Brax\n",
        "\n",
        "In [Brax Training](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb) we tried out [gym](https://gym.openai.com/)-like environments and PPO, SAC, evolutionary search, and trajectory optimization algorithms. We can build various RL algorithms on top of these ultra-fast implementations. This colab runs a family of [variational GCRL](https://arxiv.org/abs/2106.01404) algorithms, also called MI-maximization (MI-max) algorithms, which include [goal-conditioned RL](http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.51.3077) and [DIAYN](https://arxiv.org/abs/1802.06070) as special cases. They are also known as *unsupervised* RL algorithms because they learn without task rewards. Let's try them out!\n",
        "\n",
        "This notebook provides a bare-bones implementation based on minimal modifications to the\n",
        "baseline [PPO](https://github.com/google/brax/blob/main/brax/training/ppo.py),\n",
        "enabling training in a few minutes. More features, examples, and benchmarked results will be added."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VYe1kc3a4Oxc"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/mimax.ipynb)"
      ]
    },
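    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The MI-max family rewards the agent for reaching states from which its latent `z` can be inferred. As a minimal sketch of that idea, assuming the reward form from the [DIAYN](https://arxiv.org/abs/1802.06070) paper, `log q(z|s) - log p(z)`, with a uniform prior over skills (the braxlines implementation may differ in its details), the intrinsic reward can be computed from discriminator logits as follows. The names `log_softmax`, `diayn_reward`, and the example logits are hypothetical and not part of the Brax API.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def log_softmax(logits):\n",
        "  # Numerically stable log-softmax over the last axis.\n",
        "  shifted = logits - logits.max(axis=-1, keepdims=True)\n",
        "  return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n",
        "\n",
        "\n",
        "def diayn_reward(disc_logits, z_onehot):\n",
        "  # Intrinsic reward log q(z|s) - log p(z), assuming a uniform prior p(z) = 1/K.\n",
        "  num_skills = z_onehot.shape[-1]\n",
        "  log_q = (log_softmax(disc_logits) * z_onehot).sum(axis=-1)\n",
        "  log_p = -np.log(num_skills)\n",
        "  return log_q - log_p\n",
        "\n",
        "\n",
        "# Hypothetical discriminator outputs for 2 states and K = 4 skills.\n",
        "logits = np.array([[4.0, 0.0, 0.0, 0.0],   # confident that the skill is 0\n",
        "                   [0.0, 0.0, 0.0, 0.0]])  # cannot tell the skills apart\n",
        "z = np.eye(4)[[0, 0]]  # both states were generated by skill 0\n",
        "rewards = diayn_reward(logits, z)\n",
        "print(rewards)\n",
        "```\n",
        "\n",
        "A state that lets the discriminator identify the skill earns a positive reward, while an uninformative state earns zero."
      ]
    },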
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rlVNS8JstMRr"
      },
      "outputs": [],
      "source": [
        "#@title Colab setup and imports\n",
        "#@markdown ## ⚠️ PLEASE NOTE:\n",
        "#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime \u003e Change Runtime Type, then select **'TPU'** in the dropdown.\n",
        "\n",
        "from datetime import datetime\n",
        "import functools\n",
        "import math\n",
        "import os\n",
        "import pprint\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from IPython.display import HTML, clear_output\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "try:\n",
        "  import brax\n",
        "except ImportError:\n",
        "  !pip install git+https://github.com/google/brax.git@main\n",
        "  clear_output()\n",
        "  import brax\n",
        "\n",
        "from brax.io import html\n",
        "from brax.experimental.composer import composer\n",
        "from brax.experimental.braxlines import experiments\n",
        "from brax.experimental.braxlines.common import evaluators\n",
        "from brax.experimental.braxlines.common import logger_utils\n",
        "from brax.experimental.braxlines.envs.obs_indices import OBS_INDICES\n",
        "from brax.experimental.braxlines.training import ppo\n",
        "from brax.experimental.braxlines.vgcrl import evaluators as vgcrl_evaluators\n",
        "from brax.experimental.braxlines.vgcrl import utils as vgcrl_utils\n",
        "\n",
        "import tensorflow_probability as tfp\n",
        "\n",
        "tfp = tfp.substrates.jax\n",
        "tfd = tfp.distributions\n",
        "\n",
        "if \"COLAB_TPU_ADDR\" in os.environ:\n",
        "  from jax.tools import colab_tpu\n",
        "  colab_tpu.setup_tpu()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gh4QsRPnX770"
      },
      "outputs": [],
      "source": [
        "#@title Define task and experiment parameters\n",
        "\n",
        "#@markdown **Task Parameters**\n",
        "#@markdown \n",
        "#@markdown As in [DIAYN](https://arxiv.org/abs/1802.06070)\n",
        "#@markdown and [VGCRL](https://arxiv.org/abs/2106.01404),\n",
        "#@markdown we assume some task knowledge about interesting dimensions\n",
        "#@markdown of the environment `obs_indices` and their range `obs_scale`.\n",
        "#@markdown This is also used for evaluation and visualization.\n",
        "#@markdown\n",
        "#@markdown When the **task parameters** are the same, the metrics computed by\n",
        "#@markdown [vgcrl/evaluators.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/vgcrl/evaluators.py)\n",
        "#@markdown are directly comparable across experiment runs with different\n",
        "#@markdown **experiment parameters**. \n",
        "env_name = 'ant'  # @param ['ant', 'humanoid', 'halfcheetah', 'uni_ant', 'bi_ant']\n",
        "obs_indices = 'vel'  # @param ['vel']\n",
        "obs_scale = 10.0 #@param{'type': 'number'}\n",
        "obs_indices_str = obs_indices\n",
        "obs_indices = OBS_INDICES[obs_indices][env_name]\n",
        "\n",
        "#@markdown **Experiment Parameters**\n",
        "#@markdown See [vgcrl/utils.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/vgcrl/utils.py)\n",
        "evaluate_mi = False # @param{'type': 'boolean'}\n",
        "evaluate_lgr = False # @param{'type': 'boolean'}\n",
        "algo_name = 'diayn'  # @param ['gcrl', 'cdiayn', 'diayn', 'diayn_full', 'fixed_gcrl']\n",
        "env_reward_multiplier =   0# @param{'type': 'number'}\n",
        "obs_norm_reward_multiplier =   0# @param{'type': 'number'}\n",
        "normalize_obs_for_disc = False  # @param {'type': 'boolean'}\n",
        "seed =   0# @param {type: 'integer'}\n",
        "diayn_num_skills = 8  # @param {type: 'integer'}\n",
        "spectral_norm = True  # @param {'type': 'boolean'}\n",
        "output_path = '' # @param {'type': 'string'}\n",
        "task_name = \"\" # @param {'type': 'string'}\n",
        "exp_name = '' # @param {'type': 'string'}\n",
        "if output_path:\n",
        "  output_path = output_path.format(\n",
        "    date=datetime.now().strftime('%Y%m%d'))\n",
        "  task_name = task_name or f'{env_name}_{obs_indices_str}_{obs_scale}'\n",
        "  exp_name = exp_name or algo_name \n",
        "  output_path = f'{output_path}/{task_name}/{exp_name}'\n",
        "print(f'output_path={output_path}')"
      ]
    },
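    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The task parameters above pick out which observation dimensions the discriminator and the evaluators look at, and over what range. As a rough illustration only (the helper `project_obs`, the example indices, and the division by `obs_scale` are assumptions made for this sketch, not the braxlines code), selecting and rescaling those dimensions might look like this:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def project_obs(obs, obs_indices, obs_scale):\n",
        "  # Keep only the dimensions of interest and divide them by the expected range.\n",
        "  return obs[..., list(obs_indices)] / obs_scale\n",
        "\n",
        "\n",
        "obs = np.arange(8, dtype=np.float32)  # hypothetical 8-D observation\n",
        "indices = (2, 5)  # hypothetical: two velocity dimensions\n",
        "projected = project_obs(obs, indices, obs_scale=10.0)\n",
        "print(projected)\n",
        "```\n",
        "\n",
        "Values that stay within `obs_scale` of zero land in `[-1, 1]` after this rescaling, which is one reason the range needs to be chosen per task."
      ]
    },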
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NaJDZqhCLovU"
      },
      "outputs": [],
      "source": [
        "# @title Initialize Brax environment\n",
        "visualize = False  # @param {'type': 'boolean'}\n",
        "\n",
        "# Create baseline environment to get observation specs\n",
        "base_env_fn = composer.create_fn(env_name=env_name)\n",
        "base_env = base_env_fn()\n",
        "\n",
        "# Create discriminator-parameterized environment\n",
        "disc = vgcrl_utils.create_disc_fn(\n",
        "    algo_name=algo_name,\n",
        "    observation_size=base_env.observation_size,\n",
        "    obs_indices=obs_indices,\n",
        "    scale=obs_scale,\n",
        "    diayn_num_skills=diayn_num_skills,\n",
        "    spectral_norm=spectral_norm,\n",
        "    env=base_env,\n",
        "    normalize_obs=normalize_obs_for_disc)()\n",
        "extra_params = disc.init_model(rng=jax.random.PRNGKey(seed=seed))\n",
        "env_fn = vgcrl_utils.create_fn(env_name=env_name, wrapper_params=dict(\n",
        "    disc=disc, env_reward_multiplier=env_reward_multiplier,\n",
        "    obs_norm_reward_multiplier=obs_norm_reward_multiplier,\n",
        "    ))\n",
        "eval_env_fn = functools.partial(env_fn, auto_reset=False)\n",
        "\n",
        "# Make inference functions and goals for the LGR metric\n",
        "core_env = env_fn()\n",
        "params, inference_fn = ppo.make_params_and_inference_fn(\n",
        "    core_env.observation_size, core_env.action_size,\n",
        "    normalize_observations=True, extra_params=extra_params)\n",
        "inference_fn = jax.jit(inference_fn)\n",
        "goals = tfd.Uniform(low=-disc.obs_scale, high=disc.obs_scale).sample(\n",
        "    seed=jax.random.PRNGKey(0), sample_shape=(10,))\n",
        "\n",
        "# Visualize\n",
        "if visualize:\n",
        "  from IPython.display import display\n",
        "  env = env_fn()\n",
        "  jit_env_reset = jax.jit(env.reset)\n",
        "  state = jit_env_reset(rng=jax.random.PRNGKey(seed=seed))\n",
        "  clear_output()  # clear out jax.lax warning before rendering\n",
        "  # An expression nested in an `if` block is not auto-displayed, so call display().\n",
        "  display(HTML(html.render(env.sys, [state.qp])))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4vgMSWODfyMC"
      },
      "outputs": [],
      "source": [
        "#@title Training\n",
        "num_timesteps_multiplier = 6  # @param {type: 'number'}\n",
        "ncols = 5  # @param {type: 'integer'}\n",
        "\n",
        "tab = logger_utils.Tabulator(\n",
        "    output_path=f'{output_path}/training_curves.csv',\n",
        "    append=False)\n",
        "\n",
        "# We determined some reasonable hyperparameters offline and share them here.\n",
        "n = num_timesteps_multiplier\n",
        "ppo_params = experiments.defaults.get_ppo_params(\n",
        "    env_name, num_timesteps_multiplier, default='ant')\n",
        "train_fn = functools.partial(ppo.train, **ppo_params)\n",
        "\n",
        "times = [datetime.now()]\n",
        "plotpatterns = ['eval/episode_reward', 'losses/disc_loss', 'metrics/lgr',\n",
        "            'metrics/entropy_all_', 'metrics/entropy_z_', 'metrics/mi_']\n",
        "\n",
        "def update_metrics_fn(num_steps, metrics, params):\n",
        "  # Optionally add the MI and LGR estimates to the logged metrics (in place).\n",
        "  if evaluate_mi:\n",
        "    metrics.update(vgcrl_evaluators.estimate_empowerment_metric(\n",
        "      env_fn=env_fn, disc=disc,\n",
        "      inference_fn=inference_fn, params=params,\n",
        "      # custom_obs_indices = list(range(core_env.observation_size))[:30],\n",
        "      # custom_obs_scale = obs_scale,\n",
        "    ))\n",
        "  if evaluate_lgr:\n",
        "    metrics.update(vgcrl_evaluators.estimate_latent_goal_reaching_metric(\n",
        "      params=params, env_fn=env_fn, disc=disc, inference_fn=inference_fn,\n",
        "      goals=goals))\n",
        "\n",
        "# Lay out the training curves in at most `ncols` columns.\n",
        "progress, plot, _, _ = experiments.get_progress_fn(\n",
        "    plotpatterns, times, tab=tab, max_ncols=ncols,\n",
        "    xlim=[0, train_fn.keywords['num_timesteps']],\n",
        "    update_metrics_fn=update_metrics_fn,\n",
        "    pre_plot_fn=lambda: clear_output(wait=True),\n",
        "    post_plot_fn=plt.show)\n",
        "\n",
        "extra_loss_fns = dict(disc_loss=disc.disc_loss_fn) if extra_params else None\n",
        "_, params, _ = train_fn(\n",
        "    environment_fn=env_fn, progress_fn=progress, extra_params=extra_params,\n",
        "    extra_loss_fns=extra_loss_fns, seed=seed)\n",
        "clear_output(wait=True)\n",
        "plot(output_path=output_path)\n",
        "\n",
        "print(f'time to jit: {times[1] - times[0]}')\n",
        "print(f'time to train: {times[-1] - times[1]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p5eWOxg7RmQQ"
      },
      "outputs": [],
      "source": [
        "#@title Visualizing skills of the learned inference function in 2D plot\n",
        "num_z = 5  # @param {type: 'integer'}\n",
        "num_samples_per_z = 5  # @param {type: 'integer'}\n",
        "time_subsampling = 10  # @param {type: 'integer'}\n",
        "time_last_n = 500 # @param {type: 'integer'}\n",
        "eval_seed = 0  # @param {type: 'integer'}\n",
        "\n",
        "vgcrl_evaluators.visualize_skills(\n",
        "    env_fn=eval_env_fn,\n",
        "    disc=disc,\n",
        "    inference_fn=inference_fn,\n",
        "    params=params,\n",
        "    output_path=output_path,\n",
        "    verbose=True,\n",
        "    num_z=num_z,\n",
        "    num_samples_per_z=num_samples_per_z,\n",
        "    time_subsampling=time_subsampling,\n",
        "    time_last_n=time_last_n,\n",
        "    save_video=True,\n",
        "    seed=eval_seed)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VpAxzRnRu_ej"
      },
      "outputs": [],
      "source": [
        "# @title Estimate [Latent Goal Reaching metric](https://arxiv.org/abs/2106.01404)\n",
        "num_samples_per_z = 10  # @param {type: 'integer'}\n",
        "time_subsampling = 1  # @param {type: 'integer'}\n",
        "time_last_n = 500  # @param {type: 'integer'}\n",
        "eval_seed = 0  # @param {type: 'integer'}\n",
        "\n",
        "metrics = vgcrl_evaluators.estimate_latent_goal_reaching_metric(\n",
        "    params=params,\n",
        "    env_fn=eval_env_fn,\n",
        "    disc=disc,\n",
        "    inference_fn=inference_fn,\n",
        "    goals=goals,\n",
        "    num_samples_per_z=num_samples_per_z,\n",
        "    time_subsampling=time_subsampling,\n",
        "    time_last_n=time_last_n,\n",
        "    seed=eval_seed,\n",
        ")\n",
        "pprint.pprint(metrics)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Uf5Jvf11NWUm"
      },
      "outputs": [],
      "source": [
        "#@title Estimate empowerment metrics using 1D/2D binning\n",
        "num_z = 10  # @param {type: 'integer'}\n",
        "num_samples_per_z = 10  # @param {type: 'integer'}\n",
        "time_subsampling = 1  # @param {type: 'integer'}\n",
        "time_last_n = 500  # @param {type: 'integer'}\n",
        "eval_seed = 0  # @param {type: 'integer'}\n",
        "num_1d_bins = 1000  # @param {type: 'integer'}\n",
        "num_2d_bins = 30  # @param {type: 'integer'}\n",
        "\n",
        "metrics = vgcrl_evaluators.estimate_empowerment_metric(\n",
        "    env_fn=eval_env_fn,\n",
        "    disc=disc,\n",
        "    inference_fn=inference_fn,\n",
        "    params=params,\n",
        "    num_z=num_z,\n",
        "    num_samples_per_z=num_samples_per_z,\n",
        "    time_subsampling=time_subsampling,\n",
        "    time_last_n=time_last_n,\n",
        "    num_1d_bins = num_1d_bins,\n",
        "    num_2d_bins = num_2d_bins,\n",
        "    custom_obs_indices = tuple(range(base_env.observation_size))[:25],\n",
        "    custom_obs_scale = obs_scale,\n",
        "    verbose = True,\n",
        "    seed=eval_seed)\n",
        "mi = {k.split('/')[-1]: float(v)\n",
        " for k, v in metrics.items() if 'mi_1d' in k}\n",
        "mi = sorted(mi.items(),\n",
        "            key=lambda x: x[1], reverse=True)\n",
        "pprint.pprint(mi)"
      ]
    },
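    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Binning-based estimators such as the one above approximate mutual information from histograms. Here is a minimal 1D sketch, assuming a uniform distribution over `z` and the decomposition `I(Z; S) = H(S) - H(S | Z)`. The function names and synthetic data are hypothetical, and `vgcrl_evaluators.estimate_empowerment_metric` may be implemented differently.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def entropy_from_counts(counts):\n",
        "  # Shannon entropy (in nats) of the empirical distribution given by bin counts.\n",
        "  p = counts / counts.sum()\n",
        "  p = p[p > 0]\n",
        "  return float(-(p * np.log(p)).sum())\n",
        "\n",
        "\n",
        "def binned_mi_1d(samples_per_z, low, high, num_bins):\n",
        "  # samples_per_z has shape [num_z, num_samples], one row per latent z.\n",
        "  # Estimates I(Z; S) = H(S) - H(S | Z), assuming z is uniformly distributed.\n",
        "  edges = np.linspace(low, high, num_bins + 1)\n",
        "  counts = np.stack([np.histogram(row, bins=edges)[0] for row in samples_per_z])\n",
        "  h_all = entropy_from_counts(counts.sum(axis=0))\n",
        "  h_given_z = np.mean([entropy_from_counts(c) for c in counts])\n",
        "  return h_all - h_given_z\n",
        "\n",
        "\n",
        "# Two hypothetical skills observed through a single velocity dimension.\n",
        "rng = np.random.default_rng(0)\n",
        "separated = np.stack([rng.uniform(-10, 0, 1000), rng.uniform(0, 10, 1000)])\n",
        "overlapping = np.stack([rng.uniform(-10, 10, 1000), rng.uniform(-10, 10, 1000)])\n",
        "mi_separated = binned_mi_1d(separated, -10.0, 10.0, num_bins=20)\n",
        "mi_overlapping = binned_mi_1d(overlapping, -10.0, 10.0, num_bins=20)\n",
        "print(mi_separated, mi_overlapping)\n",
        "```\n",
        "\n",
        "When the two skills occupy disjoint regions, the estimate equals `log 2` (about 0.69 nats), the entropy of a uniform choice between two skills. When they overlap completely, it is close to zero."
      ]
    },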
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RNMLEyaTspEM"
      },
      "outputs": [],
      "source": [
        "#@title Visualizing a trajectory of the learned inference function\n",
        "#@markdown If `z_value` is `None`, `z` is sampled; otherwise `z` is fixed to `z_value`. For the DIAYN variants, `z_value` is the integer index of the skill.\n",
        "z_value = 0  # @param {'type': 'raw'}\n",
        "eval_seed = 0  # @param {'type': 'integer'}\n",
        "\n",
        "z = {\n",
        "    'fixed_gcrl': jnp.ones(disc.z_size) * z_value,\n",
        "    'gcrl': jnp.ones(disc.z_size) * z_value,\n",
        "    'cdiayn': jnp.ones(disc.z_size) * z_value,\n",
        "    'diayn': jax.nn.one_hot(jnp.array(int(z_value)), disc.z_size),\n",
        "    'diayn_full': jax.nn.one_hot(jnp.array(int(z_value)), disc.z_size),\n",
        "}[algo_name] if z_value is not None else None\n",
        "\n",
        "env, states = evaluators.visualize_env(\n",
        "    env_fn=eval_env_fn,\n",
        "    inference_fn=inference_fn,\n",
        "    params=params,\n",
        "    batch_size=0,\n",
        "    seed = eval_seed,\n",
        "    reset_args = (z,),\n",
        "    step_args = (params['normalizer'], params['extra']),\n",
        "    output_path=output_path,\n",
        "    output_name=f'video_z_{z_value}',\n",
        ")\n",
        "HTML(html.render(env.sys, [state.qp for state in states]))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "mimax.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1ZaAO4BS2tJ_03CIXdBCFibZR2yLl6dtv",
          "timestamp": 1635166389289
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
