{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "realworldrl_notebook.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      }
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OF9qRa7D8643"
      },
      "source": [
        "**To use this Jupyter notebook, make sure the following packages are installed (in the order listed):**\n",
        "1. Python 3\n",
        "2. Matplotlib\n",
        "3. NumPy\n",
        "4. SciPy\n",
        "5. tqdm\n",
        "6. MuJoCo (make sure MuJoCo is installed before installing the Realworld RL Suite)\n",
        "7. The Realworld RL Suite\n",
        "\n",
        "**We recommend using the `realworldrl_venv` virtual environment that you used when installing the `realworldrl_suite` package. To do so, you may need to run the following commands:**\n",
        "\n",
        "```\n",
        "pip3 install --user ipykernel\n",
        "python3 -m ipykernel install --user --name=realworldrl_venv\n",
        "```\n",
        "\n",
        "Then, in this notebook, click 'Kernel' in the menu, click 'Change Kernel', and select `realworldrl_venv`.\n",
        "\n",
        "**Note**: You may need to restart the Jupyter kernel to see the updated virtual environment in the Kernel list.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yLMLidYq8646"
      },
      "source": [
        "**Import the necessary libraries**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "both",
        "colab_type": "code",
        "id": "icAkv-zLplqP",
        "colab": {}
      },
      "source": [
        "#@title \n",
        "# Copyright 2020 Google LLC.\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "\n",
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tqdm\n",
        "\n",
        "import collections\n",
        "import os  # Needed for os.path.join when building the log path.\n",
        "\n",
        "import realworldrl_suite.environments as rwrl\n",
        "from realworldrl_suite.utils import evaluators"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DDcZhPBB865C"
      },
      "source": [
        "**Define the environment and the policy**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "7gDwzTFz0Tak",
        "colab": {}
      },
      "source": [
        "total_episodes = 1000  # The analysis tools require at least 100 episodes.\n",
        "domain_name = 'cartpole'\n",
        "task_name = 'realworld_swingup'\n",
        "\n",
        "# Define the challenge dictionaries.\n",
        "# safety_coeff scales how strict the safety constraints are.\n",
        "safety_spec_dict = dict(enable=True, safety_coeff=0.5)\n",
        "# actions=20 delays each action by 20 steps.\n",
        "delay_spec_dict = dict(enable=True, actions=20)\n",
        "\n",
        "log_safety_violations = True\n",
        "\n",
        "def random_policy(action_spec):\n",
        "  \"\"\"Returns a policy that samples uniformly within the action bounds.\"\"\"\n",
        "\n",
        "  def _act(timestep):\n",
        "    del timestep  # The random policy ignores the observation.\n",
        "    return np.random.uniform(low=action_spec.minimum,\n",
        "                             high=action_spec.maximum,\n",
        "                             size=action_spec.shape)\n",
        "  return _act\n",
        "\n",
        "\n",
        "env = rwrl.load(\n",
        "    domain_name=domain_name,\n",
        "    task_name=task_name,\n",
        "    safety_spec=safety_spec_dict,\n",
        "    delay_spec=delay_spec_dict,\n",
        "    log_output=os.path.join('/tmp/', 'log.npz'),\n",
        "    environment_kwargs=dict(log_safety_vars=log_safety_violations,\n",
        "                            flat_observation=True,\n",
        "                            log_every=10))\n",
        "\n",
        "policy = random_policy(action_spec=env.action_spec())"
      ],
      "execution_count": 0,
      "outputs": []
    },
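    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The random policy above can be tried without MuJoCo. The following is a minimal sketch, not part of the suite: `FakeActionSpec` is a hypothetical stand-in that only mimics the `shape`, `minimum` and `maximum` attributes the policy reads from `env.action_spec()`, and the bounds of -1 and 1 are made up for illustration.\n",
        "\n",
        "```python\n",
        "import collections\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical stand-in for a bounded action spec (not the suite's class).\n",
        "FakeActionSpec = collections.namedtuple(\n",
        "    'FakeActionSpec', ['shape', 'minimum', 'maximum'])\n",
        "\n",
        "\n",
        "def random_policy(action_spec):\n",
        "  # Same policy as in the notebook: uniform samples within the bounds.\n",
        "  def _act(timestep):\n",
        "    del timestep  # Unused by a random policy.\n",
        "    return np.random.uniform(low=action_spec.minimum,\n",
        "                             high=action_spec.maximum,\n",
        "                             size=action_spec.shape)\n",
        "  return _act\n",
        "\n",
        "\n",
        "spec = FakeActionSpec(shape=(1,), minimum=-1.0, maximum=1.0)\n",
        "policy = random_policy(spec)\n",
        "action = policy(None)  # The timestep is ignored, so None is enough here.\n",
        "print(action.shape)\n",
        "```\n",
        "\n",
        "Because the policy ignores the timestep, passing `None` works in this sketch; in the notebook it receives the timestep returned by `env.reset()` or `env.step()`."
      ]
    },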
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lj0CFq1k865I"
      },
      "source": [
        "**Run the main loop**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ot9xUSM8z8Ib",
        "colab": {}
      },
      "source": [
        "rewards = []\n",
        "episode_counter = 0\n",
        "for _ in tqdm.tqdm(range(total_episodes)):\n",
        "    timestep = env.reset()\n",
        "    total_reward = 0.\n",
        "    while not timestep.last():\n",
        "        action = policy(timestep)\n",
        "        timestep = env.step(action)\n",
        "        total_reward += timestep.reward\n",
        "    rewards.append(total_reward)\n",
        "    episode_counter += 1\n",
        "print('Random policy total reward per episode: {:.2f} +- {:.2f}'.format(\n",
        "    np.mean(rewards), np.std(rewards)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Ne2i3ZFRsnVw",
        "colab": {}
      },
      "source": [
        "# Load the statistics that the environment logged during the run.\n",
        "f = open(env.logs_path, \"rb\")\n",
        "stats = np.load(f, allow_pickle=True)\n",
        "evaluator = evaluators.Evaluators(stats)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xU-jhT4W865T"
      },
      "source": [
        "**Plot the average return as a function of the number of episodes**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "AEHttr9JyttU",
        "colab": {}
      },
      "source": [
        "fig = evaluator.get_return_plot()\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "327LSohpug29",
        "colab_type": "text"
      },
      "source": [
        "**Compute regret**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nuyf6byUufCP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "fig = evaluator.get_convergence_plot()\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xk8vY0b6u2um",
        "colab_type": "text"
      },
      "source": [
        "**Compute instability**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QoqFSJDZu6Eo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "fig = evaluator.get_stability_plot()\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RkPTOgzA865Z"
      },
      "source": [
        "**Safety violations plot (left figure) and the mean evolution of safety constraint violations during an episode (right figure)**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "i7BRSJsWzJEB",
        "colab": {}
      },
      "source": [
        "fig = evaluator.get_safety_plot()\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TSxbfXDOyV2q",
        "colab_type": "text"
      },
      "source": [
        "**Multiple training seeds can be aggregated.**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dmi6k1vyxsLn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# We emulate multiple runs by copying the same logs with added noise.\n",
        "\n",
        "all_evaluators = []\n",
        "for _ in range(10):\n",
        "  another_evaluator = evaluators.Evaluators(stats)\n",
        "  v = another_evaluator.stats['return_stats']['episode_totals']\n",
        "  v += np.random.randn(*v.shape) * 100.\n",
        "  all_evaluators.append(another_evaluator)"
      ],
      "execution_count": 0,
      "outputs": []
    },
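    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To give a feel for what aggregating over seeds involves, here is a minimal NumPy sketch under stated assumptions: `base_returns` is made-up data, the noise mimics the emulation in the cell above, and the mean and standard deviation summary is an illustration rather than the computation performed by `evaluators`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "\n",
        "# Hypothetical per-episode returns for a single run (made-up data).\n",
        "base_returns = np.linspace(0., 100., num=50)\n",
        "\n",
        "# Emulate 10 seeds by adding Gaussian noise to copies of the same run.\n",
        "runs = np.stack([\n",
        "    base_returns + rng.normal(scale=10., size=base_returns.shape)\n",
        "    for _ in range(10)\n",
        "])\n",
        "\n",
        "# Summarize across seeds (axis 0) for every episode.\n",
        "mean_per_episode = runs.mean(axis=0)\n",
        "std_per_episode = runs.std(axis=0)\n",
        "print(runs.shape, mean_per_episode.shape)\n",
        "```\n",
        "\n",
        "Stacking the runs along a new first axis keeps one row per seed, so any per-episode summary reduces over `axis=0`."
      ]
    },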
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "36t1n8k17oIp",
        "colab_type": "text"
      },
      "source": [
        "**Normalized regret across all runs.**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "J2XujjsfykR5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "evaluators.get_regret_plot(all_evaluators)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BNgepDFv7nna",
        "colab_type": "text"
      },
      "source": [
        "**Return across all runs.**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "khHpm_Gy5nOP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "evaluators.get_return_plot(all_evaluators, stride=500)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HY3NkWxNRceA",
        "colab_type": "text"
      },
      "source": [
        "**Additional useful functions**\n",
        "\n",
        "Multi-objective runs can be analyzed using:\n",
        "\n",
        "```\n",
        "evaluator.get_multiobjective_plot()  # For a single run.\n",
        "evaluators.get_multiobjective_plot(all_evaluators)  # For multiple runs.\n",
        "```"
      ]
    }
  ]
}
