{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ssCOanHc8JH_"
      },
      "source": [
        "# Experiment Viewer\n",
        "\n",
        "We can visualize hyperparameter sweep results directly on Colab."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VYe1kc3a4Oxc"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/experiment_viewer.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rlVNS8JstMRr"
      },
      "outputs": [],
      "source": [
        "# @title Colab setup and imports\n",
        "import numpy as np\n",
        "import os\n",
        "from typing import Tuple\n",
        "from IPython.display import clear_output\n",
        "\n",
        "# Install brax on demand when it is not already importable in this runtime.\n",
        "try:\n",
        "  import brax\n",
        "except ImportError:\n",
        "  !pip install git+https://github.com/google/brax.git@main\n",
        "  clear_output()\n",
        "  import brax\n",
        "\n",
        "# These imports must come after the install guard above.\n",
        "from brax.io import file\n",
        "from brax.experimental.braxlines import experiments\n",
        "\n",
        "# Root directories to search for result files. Empty entries are skipped.\n",
        "# To load from more locations, add output_path4, ... and append them to\n",
        "# `output_paths` below.\n",
        "output_path1 = '' #@param{'type': 'string'}\n",
        "output_path2 = '' #@param{'type': 'string'}\n",
        "output_path3 = '' #@param{'type': 'string'}\n",
        "output_paths = [output_path1, output_path2, output_path3]\n",
        "# Glob pattern, relative to each output path, that selects the CSV files.\n",
        "output_structure = '**/**/**/training_curves.csv' #@param{type: 'string'}\n",
        "\n",
        "csv_files = []\n",
        "for output_path in output_paths:\n",
        "  # Form values are free text: drop surrounding whitespace and any trailing\n",
        "  # '/' so the pattern below never contains a leading space or a doubled '/'.\n",
        "  output_path = output_path.strip().rstrip('/')\n",
        "  if not output_path:\n",
        "    continue\n",
        "  pattern = f'{output_path}/{output_structure}'\n",
        "  csv_files_ = file.Glob(pattern)\n",
        "  csv_files += csv_files_\n",
        "  print(f'Found {len(csv_files_)} files matching {pattern}')\n",
        "print(f'Total: Found {len(csv_files)} files')\n",
        "if not csv_files:\n",
        "  # Make the empty case explicit instead of leaving only a zero count.\n",
        "  print('No CSV files found: set at least one output_path and check '\n",
        "        'output_structure before running the cells below.')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9Mn_Iml-w71b"
      },
      "outputs": [],
      "source": [
        "# @title Load data\n",
        "# Requires `csv_files` and `experiments` from the setup cell; run that first.\n",
        "data = experiments.load_data(csv_files)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HRXt97Qj_mLh"
      },
      "outputs": [],
      "source": [
        "# @title Compute data statistics\n",
        "# Requires `data` from the 'Load data' cell. `statistics` is consumed by\n",
        "# the plotting cell below.\n",
        "statistics, filepaths = experiments.compute_statistics(data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yEwJnB4F3sYk"
      },
      "outputs": [],
      "source": [
        "# @title Plot data\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Form fields; every value is forwarded unchanged to plot_statistics below.\n",
        "# The '_re' fields are presumably regular expressions (judging by the\n",
        "# suffix) -- confirm against experiments.plot_statistics.\n",
        "ncols = 4  # @param{type: 'integer'}\n",
        "xmax = 3e8  # @param{type: 'number'}\n",
        "xlabel = 'num_steps'  # @param['num_steps']\n",
        "ylabel_re = '(episode_reward|energy_dist)'  # @param['eval/episode_reward', 'metrics/entropy_z', 'metrics/entropy_all', 'losses/disc_loss', 'metrics/lgr', 'metrics/mi', '(lgr|episode_reward|mi|entropy_all)', '(episode_reward|energy_dist)']\n",
        "key_include_re = 'ant'  # @param{'type': 'string'}\n",
        "key_exclude_re = ''  # @param{'type': 'string'}\n",
        "legend_tags = None  # @param{type: 'raw'}\n",
        "\n",
        "# Requires `statistics` from the 'Compute data statistics' cell.\n",
        "summaries = experiments.plot_statistics(\n",
        "    statistics,\n",
        "    ncols=ncols,\n",
        "    xmax=xmax,\n",
        "    xlabel=xlabel,\n",
        "    ylabel_re=ylabel_re,\n",
        "    key_include_re=key_include_re,\n",
        "    key_exclude_re=key_exclude_re,\n",
        "    legend_tags=legend_tags)\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "experiment_viewer.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1-fplVbCqf5xQaosZXC9UcaiKfNpsNKjx",
          "timestamp": 1632951144835
        },
        {
          "file_id": "1zvUdazhGU7ZjPl-Vb2GSESCWtEgiw2bJ",
          "timestamp": 1629749582973
        },
        {
          "file_id": "1ZaAO4BS2tJ_03CIXdBCFibZR2yLl6dtv",
          "timestamp": 1629608669428
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
