{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ssCOanHc8JH_"
      },
      "source": [
        "# Sweep Training\n",
        "\n",
        "We can perform a hyperparameter sweep directly on Colab."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VYe1kc3a4Oxc"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/experiment_sweep.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rlVNS8JstMRr"
      },
      "outputs": [],
      "source": [
        "#@title Colab setup and imports\n",
        "#@markdown ## ⚠️ PLEASE NOTE:\n",
        "#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime \u003e Change Runtime Type, then select **'TPU'** in the dropdown.\n",
        "\n",
        "#@markdown See [config_utils.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/common/config_utils.py)\n",
        "#@markdown for the configuration format.\n",
        "#@markdown See [experiments/](https://github.com/google/brax/blob/main/brax/experimental/braxlines/experiments)\n",
        "#@markdown for the example configurations.\n",
        "from datetime import datetime\n",
        "import importlib\n",
        "import os\n",
        "import pprint\n",
        "from IPython.display import HTML, clear_output\n",
        "\n",
        "# Install Brax from GitHub only when it is not already importable.\n",
        "try:\n",
        "  import brax\n",
        "except ImportError:\n",
        "  !pip install git+https://github.com/google/brax.git@main\n",
        "  clear_output()\n",
        "  import brax\n",
        "\n",
        "# Form fields. `start_count` / `end_count` are forwarded unchanged to\n",
        "# `run_experiment` by the launch cell.\n",
        "experiment = 'custom'# @param ['custom', 'mimax_sweep', 'ant_push_sweep', 'dmin_sweep']\n",
        "output_path = '/tmp' #@param{'type': 'string'}\n",
        "start_count = 0 # @param{'type': 'integer'}\n",
        "end_count = 100000000 # @param{'type': 'integer'}\n",
        "experiment_path = '' #@param{'type': 'string'}\n",
        "# An empty experiment path falls back to a timestamped directory name.\n",
        "if not experiment_path:\n",
        "  experiment_path = datetime.now().strftime('%Y%m%d_%H%M%S')\n",
        "output_path = f'{output_path}/{experiment_path}'\n",
        "\n",
        "# Agent module and configuration used when `experiment` is 'custom'.\n",
        "# NOTE(review): list-valued entries appear to be sweep axes; confirm against\n",
        "# config_utils.py.\n",
        "custom_agent_module = 'brax.experimental.braxlines.vgcrl.train'\n",
        "custom_config = [{\n",
        "    'env_name': ['ant'],\n",
        "    'obs_indices': 'vel',\n",
        "    'algo_name': ['gcrl', 'diayn', 'cdiayn', 'diayn_full'],\n",
        "    'obs_scale': [5.0],\n",
        "    'seed': [0],\n",
        "    'normalize_obs_for_disc': False,\n",
        "    'evaluate_mi': False,\n",
        "    'evaluate_lgr': False,\n",
        "    'env_reward_multiplier': 0.0,\n",
        "    'spectral_norm': [True],\n",
        "    'ppo_params': {\n",
        "        'num_timesteps': 250_000_000,\n",
        "        'reward_scaling': 10,\n",
        "        'episode_length': 1000,\n",
        "        'normalize_observations': True,\n",
        "        'action_repeat': 1,\n",
        "        'unroll_length': 5,\n",
        "        'num_minibatches': 32,\n",
        "        'num_update_epochs': 4,\n",
        "        'discounting': 0.95,\n",
        "        'learning_rate': 3e-4,\n",
        "        'entropy_cost': 1e-2,\n",
        "        'num_envs': 2048,\n",
        "        'batch_size': 1024,\n",
        "    },\n",
        "}]\n",
        "\n",
        "# These imports stay below the install step because they require Brax.\n",
        "from brax.experimental.braxlines.common import config_utils\n",
        "from brax.experimental.braxlines.experiments import load_experiment, run_experiment\n",
        "\n",
        "# Pick the (agent module path, config) pair for the selected experiment.\n",
        "if experiment == 'custom':\n",
        "  agent_module, config = custom_agent_module, custom_config\n",
        "else:\n",
        "  agent_module, config = load_experiment(experiment)\n",
        "# Swap the dotted module path for the imported module object.\n",
        "agent_module = importlib.import_module(agent_module)\n",
        "\n",
        "# Set up the TPU when the runtime exposes COLAB_TPU_ADDR.\n",
        "if 'COLAB_TPU_ADDR' in os.environ:\n",
        "  from jax.tools import colab_tpu\n",
        "  colab_tpu.setup_tpu()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NaJDZqhCLovU"
      },
      "outputs": [],
      "source": [
        "#@title Launch experiments\n",
        "ignore_errors = False # @param{'type': 'boolean'}\n",
        "# Hand the selections made in the setup cell to the experiment runner.\n",
        "run_experiment(\n",
        "    experiment_name=experiment,\n",
        "    agent_module=agent_module,\n",
        "    config=config,\n",
        "    output_path=output_path,\n",
        "    start_count=start_count,\n",
        "    end_count=end_count,\n",
        "    ignore_errors=ignore_errors,\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "experiment_sweep.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1zvUdazhGU7ZjPl-Vb2GSESCWtEgiw2bJ",
          "timestamp": 1632487390636
        },
        {
          "file_id": "1ZaAO4BS2tJ_03CIXdBCFibZR2yLl6dtv",
          "timestamp": 1629608669428
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
