{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/rlworkgroup/garage/blob/master/examples/jupyter/custom_env.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Za0nLGHy5jyP"
   },
   "source": [
    "# Using a custom openai/gym environment with rlworkgroup/garage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "acXH8kwHsAYT"
   },
   "source": [
    "This example demonstrates the usage of [garage](https://github.com/rlworkgroup/garage) with a custom `openai/gym` environment in a Jupyter notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GmPLYXrXH7A5"
   },
   "source": [
    "## Install pre-requisites"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 901
    },
    "colab_type": "code",
    "id": "Aj3bL9HaG5HL",
    "outputId": "21508810-1b19-4f63-c1da-fdde3cb17e88"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fatal: destination path 'garage' already exists and is not an empty directory.\n",
      "bash: scripts/setup_colab.sh: No such file or directory\n"
     ]
    }
   ],
   "source": [
    "# Create the dummy key file that is handed to the setup script via --mjkey.\n",
    "!echo \"abcd\" > mujoco_fake_key\n",
    "\n",
    "# Clone garage only when the checkout is missing, so re-running this cell does not fail.\n",
    "!test -d garage || git clone --depth 1 https://github.com/rlworkgroup/garage/\n",
    "\n",
    "# Each `!` line runs in its own subshell, so a standalone `!cd garage` does not affect\n",
    "# the next line. Change directory and run the setup script in the same shell command.\n",
    "!cd garage && bash scripts/setup_colab.sh --mjkey ../mujoco_fake_key --no-modify-bashrc > /dev/null"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_FWFcZWCgjrm"
   },
   "outputs": [
    {
     "ename": "Exception",
     "evalue": "Please restart your runtime so that the installed dependencies for 'garage' can be loaded, and then resume running the notebook",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mException\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-0a7126b9b224>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Please restart your runtime so that the installed dependencies for 'garage' can be loaded, and then resume running the notebook\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mException\u001b[0m: Please restart your runtime so that the installed dependencies for 'garage' can be loaded, and then resume running the notebook"
     ]
    }
   ],
   "source": [
    "# Deliberate stop: halts \"Run all\" here so the runtime can be restarted before the cells below are run.\n",
    "raise Exception(\"Please restart your runtime so that the installed dependencies for 'garage' can be loaded, and then resume running the notebook\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wCuJ9d-Jgk1Y"
   },
   "source": [
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "\n",
    "---\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "34oLOKVb5t8l"
   },
   "source": [
    "## Define a custom gym environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "oBv4mojWRT6v"
   },
   "outputs": [],
   "source": [
    "# Create a gym env that simulates a coin-flip guessing game\n",
    "# Based on https://github.com/openai/gym/blob/master/gym/envs/toy_text/nchain.py\n",
    "import os\n",
    "import gym\n",
    "from gym import spaces\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "\n",
    "\n",
    "# Gym env\n",
    "class MyEnv(gym.Env):\n",
    "    \"\"\"Coin-flip guessing game implemented as a custom gym environment.\n",
    "\n",
    "    Observation: Outcome of the most recent coin flip (Discrete binary: 0/1)\n",
    "\n",
    "    Actions: Guess of coin flip outcome (Discrete binary: 0/1)\n",
    "\n",
    "    Reward: 1 if the guess matches the coin flip, otherwise 0\n",
    "\n",
    "    Episode termination: After `target_score` correct guesses or after\n",
    "    `max_attempts` attempts, whichever happens first\n",
    "\n",
    "    Args:\n",
    "        max_attempts (int): Maximum number of guesses per episode.\n",
    "        target_score (int): Number of correct guesses that ends the episode.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_attempts=20, target_score=5):\n",
    "        # set action/observation spaces\n",
    "        self.action_space = spaces.Discrete(2)\n",
    "        self.observation_space = spaces.Discrete(2)\n",
    "        # episode limits (the defaults reproduce the previously hard-coded values)\n",
    "        self.max_attempts = max_attempts\n",
    "        self.target_score = target_score\n",
    "        self.reset()\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Flip the coin and score the guess.\n",
    "\n",
    "        Args:\n",
    "            action (int): Guess of the coin flip outcome (0 or 1).\n",
    "\n",
    "        Returns:\n",
    "            tuple: `(observation, reward, done, info)` where `observation` is\n",
    "                the flipped coin (0/1), `reward` is 1 for a correct guess and\n",
    "                0 otherwise, `done` tells whether the episode has ended and\n",
    "                `info` is an empty dict.\n",
    "        \"\"\"\n",
    "        assert self.action_space.contains(action), \"action not in action space!\"\n",
    "\n",
    "        # flip a coin; cast to a plain int so the observation is 0/1 rather than a numpy bool\n",
    "        self.state = int(np.random.rand() < 0.5)\n",
    "\n",
    "        # increment number of attempts\n",
    "        self.attempt += 1\n",
    "\n",
    "        # reward of this attempt: 1 for a correct guess, 0 otherwise\n",
    "        reward = int(action == self.state)\n",
    "        self.score += reward\n",
    "\n",
    "        # allow a maximum number of attempts or reach max score\n",
    "        done = (self.attempt >= self.max_attempts\n",
    "                or self.score >= self.target_score)\n",
    "\n",
    "        return self.state, reward, done, {}\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Start a new episode.\n",
    "\n",
    "        Returns:\n",
    "            int: The initial observation, which is always 0.\n",
    "        \"\"\"\n",
    "        # accumulate score\n",
    "        self.score = 0\n",
    "        # count number of attempts\n",
    "        self.attempt = 0\n",
    "        # no coin has been flipped yet; keep the attribute defined before the first step\n",
    "        self.state = 0\n",
    "\n",
    "        return self.state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 272
    },
    "colab_type": "code",
    "id": "Xi3SwHSJO3xm",
    "outputId": "d8557925-1e7c-4173-d2f2-ce035e5811be"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0: action=0, observation=1 => reward = 0, done = False\n",
      "step 1: action=1, observation=0 => reward = 0, done = False\n",
      "step 2: action=0, observation=1 => reward = 0, done = False\n",
      "step 3: action=0, observation=1 => reward = 0, done = False\n",
      "step 4: action=1, observation=0 => reward = 0, done = False\n",
      "step 5: action=1, observation=0 => reward = 0, done = False\n",
      "step 6: action=1, observation=1 => reward = 1, done = False\n",
      "step 7: action=0, observation=0 => reward = 1, done = False\n",
      "step 8: action=0, observation=0 => reward = 1, done = False\n",
      "step 9: action=0, observation=1 => reward = 0, done = False\n",
      "step 10: action=1, observation=0 => reward = 0, done = False\n",
      "step 11: action=0, observation=0 => reward = 1, done = False\n",
      "step 12: action=1, observation=0 => reward = 0, done = False\n",
      "step 13: action=0, observation=1 => reward = 0, done = False\n",
      "step 14: action=0, observation=1 => reward = 0, done = False\n",
      "step 15: action=1, observation=1 => reward = 1, done = True\n"
     ]
    }
   ],
   "source": [
    "# some smoke testing\n",
    "env_test = MyEnv()\n",
    "observation = env_test.reset()\n",
    "\n",
    "for step in range(40):\n",
    "    # random guess as a plain int (0 or 1): a numpy bool is not reliably\n",
    "    # accepted by the action-space check inside MyEnv.step\n",
    "    action = int(np.random.rand() < 0.5)\n",
    "    observation, reward, done, _ = env_test.step(action)\n",
    "    print(\"step %i: action=%i, observation=%i => reward = %i, done = %s\" % (step, action, observation, reward, done))\n",
    "    if done:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Nd-RbAhH6Kx8"
   },
   "source": [
    "## Import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9j8M1L9S6rvU"
   },
   "outputs": [],
   "source": [
    "# The contents of this cell are mostly copied from garage/examples/...\n",
    "\n",
    "from garage.np.baselines import LinearFeatureBaseline # <<<<<< requires restarting the runtime in colab after the 1st dependency installation above\n",
    "from garage.envs import GarageEnv\n",
    "from garage.envs import normalize\n",
    "from garage.tf.algos import TRPO\n",
    "from garage.tf.policies import GaussianMLPPolicy\n",
    "from garage.tf.policies import CategoricalMLPPolicy\n",
    "\n",
    "import gym # already imported before\n",
    "\n",
    "\n",
    "from garage.experiment import LocalTFRunner\n",
    "from garage.experiment.deterministic import set_seed\n",
    "from dowel import logger, StdOutput"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "v7LQO4zBp8h4"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import garage\n",
    "\n",
    "# snapshot settings for the training run; snapshots go to ./data\n",
    "log_dir = os.path.join(os.getcwd(), 'data')\n",
    "ctxt = garage.experiment.SnapshotConfig(snapshot_dir=log_dir,\n",
    "                                        snapshot_mode='last',\n",
    "                                        snapshot_gap=1)\n",
    "\n",
    "# log to stdout; attach the output only once, otherwise re-running this cell\n",
    "# makes every log line show up multiple times\n",
    "if not logger.has_output_type(StdOutput):\n",
    "    logger.add_output(StdOutput())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Register the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "UNoyVv0tA23n"
   },
   "outputs": [],
   "source": [
    "# register the env with gym\n",
    "# https://github.com/openai/gym/tree/master/gym/envs#how-to-create-new-environments-for-gym\n",
    "from gym.envs.registration import register, registry\n",
    "\n",
    "# gym refuses to register an id twice, so drop any registration left over from an\n",
    "# earlier run of this cell; re-registering also picks up an edited MyEnv class.\n",
    "# NOTE(review): assumes `registry.env_specs` is the id -> spec dict of the installed gym release - confirm when upgrading gym\n",
    "registry.env_specs.pop('MyEnv-v0', None)\n",
    "\n",
    "register(\n",
    "    id='MyEnv-v0',\n",
    "    entry_point=MyEnv,\n",
    ")\n",
    "\n",
    "# test registration was successful\n",
    "env = gym.make(\"MyEnv-v0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tx8_cmOe63QK"
   },
   "outputs": [],
   "source": [
    "# Wrap the environment so that every observation is converted to an integer numpy array.\n",
    "# TODO: find out why this conversion is necessary.\n",
    "# Based on https://github.com/openai/gym/blob/5404b39d06f72012f562ec41f60734bd4b5ceb4b/gym/wrappers/dict.py\n",
    "\n",
    "      \n",
    "from gym import wrappers\n",
    "\n",
    "class NpWrapper(gym.ObservationWrapper):\n",
    "    \"\"\"Observation wrapper that turns each observation into an integer numpy array.\"\"\"\n",
    "\n",
    "    def observation(self, observation):\n",
    "        \"\"\"Return `observation` as a numpy array cast to the 'int' dtype.\"\"\"\n",
    "        as_array = np.array(observation)\n",
    "        return as_array.astype('int')\n",
    "      \n",
    "# Build the wrapped env from a fresh gym instance instead of re-wrapping `env`,\n",
    "# so that re-running this cell does not stack the wrappers on top of each other.\n",
    "env = GarageEnv(normalize(NpWrapper(gym.make('MyEnv-v0'))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the result analysis function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_test_result(env, policy, obs_var):\n",
    "\n",
    "    # test results\n",
    "    n_experiments = 10\n",
    "    row_all = []\n",
    "\n",
    "    for i in range(n_experiments):\n",
    "      #print(\"experiment \", i+1)\n",
    "\n",
    "      policy.build(obs_var)\n",
    "      # reset\n",
    "      obs_initial = env.reset()\n",
    "\n",
    "      # start\n",
    "      done = False\n",
    "      obs_i = obs_initial\n",
    "      while not done:\n",
    "        row_i = {}\n",
    "        row_i['exp'] = i + 1\n",
    "        row_i['obs'] = obs_i\n",
    "        act_i, _ = policy.get_action(obs_i.flatten())\n",
    "        row_i['act'] = act_i\n",
    "        obs_i, rew_i, done, _ = env.step(act_i)\n",
    "        row_i['obs'] = obs_i\n",
    "        row_i['rew'] = rew_i\n",
    "        row_all.append(row_i)\n",
    "\n",
    "        if done: break\n",
    "\n",
    "    env.close()\n",
    "    \n",
    "    return row_all\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hFgHLyD7oRaH"
   },
   "source": [
    "## Define and train the algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "hyper_parameters = {\n",
    "    'hidden_sizes': [32, 32],  # hidden layers of the policy network\n",
    "    'max_kl': 0.01,  # passed to TRPO as max_kl_step\n",
    "    'gae_lambda': 0.97,\n",
    "    'discount': 0.99,\n",
    "    'max_path_length': 100,\n",
    "    'n_epochs': 2,  # passed to runner.train\n",
    "    'batch_size': 10000,  # passed to runner.train\n",
    "}\n",
    "\n",
    "\n",
    "def test_gym_environment(env, ctxt=None, seed=1):\n",
    "    \"\"\"Train TRPO with a categorical MLP policy on `env`.\n",
    "\n",
    "    All tunable values are read from the module-level `hyper_parameters` dict.\n",
    "\n",
    "    Args:\n",
    "        env: Environment to train on; its `spec` is used to build the policy,\n",
    "            the baseline and the algorithm.\n",
    "        ctxt: Snapshot configuration handed to `LocalTFRunner`.\n",
    "        seed (int): Seed handed to `set_seed`.\n",
    "    \"\"\"\n",
    "    set_seed(seed)\n",
    "    with LocalTFRunner(snapshot_config=ctxt) as runner:\n",
    "        # take the layer sizes from hyper_parameters instead of repeating them here\n",
    "        policy = CategoricalMLPPolicy(\n",
    "            name=\"policy\",\n",
    "            env_spec=env.spec,\n",
    "            hidden_sizes=tuple(hyper_parameters['hidden_sizes']))\n",
    "\n",
    "        baseline = LinearFeatureBaseline(env_spec=env.spec)\n",
    "\n",
    "        algo = TRPO(env_spec=env.spec,\n",
    "                    policy=policy,\n",
    "                    baseline=baseline,\n",
    "                    max_path_length=hyper_parameters['max_path_length'],\n",
    "                    discount=hyper_parameters['discount'],\n",
    "                    gae_lambda=hyper_parameters['gae_lambda'],\n",
    "                    max_kl_step=hyper_parameters['max_kl'])\n",
    "\n",
    "        # train the algorithm\n",
    "        runner.setup(algo, env)\n",
    "        runner.train(n_epochs=hyper_parameters['n_epochs'],\n",
    "                     batch_size=hyper_parameters['batch_size'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-06-17 23:21:19 | Setting seed to 1\n",
      "2020-06-17 23:21:19 | Setting seed to 1\n",
      "WARNING:tensorflow:From /Users/irisliu/src/resl/garage/src/garage/tf/models/mlp.py:84: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /Users/irisliu/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:From /Users/irisliu/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/ops/resource_variable_ops.py:1666: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n",
      "WARNING:tensorflow:From /Users/irisliu/src/resl/garage/src/garage/tf/models/model.py:345: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Prefer Variable.assign which has equivalent behavior in 2.X.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/irisliu/src/resl/garage/src/garage/experiment/local_tf_runner.py:182: LoggerWarning: \u001b[33mLog data of type Graph was not accepted by any output\u001b[0m\n",
      "  logger.log(self.sess.graph)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-06-17 23:21:19 | Obtaining samples...\n",
      "2020-06-17 23:21:19 | Obtaining samples...\n",
      "2020-06-17 23:21:19 | epoch #0 | Obtaining samples for iteration 0...\n",
      "2020-06-17 23:21:19 | epoch #0 | Obtaining samples for iteration 0...\n",
      "2020-06-17 23:21:21 | epoch #0 | Logging diagnostics...\n",
      "2020-06-17 23:21:21 | epoch #0 | Logging diagnostics...\n",
      "2020-06-17 23:21:21 | epoch #0 | Optimizing policy...\n",
      "2020-06-17 23:21:21 | epoch #0 | Optimizing policy...\n",
      "2020-06-17 23:21:21 | epoch #0 | Computing loss before\n",
      "2020-06-17 23:21:21 | epoch #0 | Computing loss before\n",
      "2020-06-17 23:21:21 | epoch #0 | Computing KL before\n",
      "2020-06-17 23:21:21 | epoch #0 | Computing KL before\n",
      "2020-06-17 23:21:21 | epoch #0 | Optimizing\n",
      "2020-06-17 23:21:21 | epoch #0 | Optimizing\n",
      "2020-06-17 23:21:21 | epoch #0 | Start CG optimization: #parameters: 1218, #inputs: 992, #subsample_inputs: 992\n",
      "2020-06-17 23:21:21 | epoch #0 | Start CG optimization: #parameters: 1218, #inputs: 992, #subsample_inputs: 992\n",
      "2020-06-17 23:21:21 | epoch #0 | computing loss before\n",
      "2020-06-17 23:21:21 | epoch #0 | computing loss before\n",
      "2020-06-17 23:21:21 | epoch #0 | computing gradient\n",
      "2020-06-17 23:21:21 | epoch #0 | computing gradient\n",
      "2020-06-17 23:21:22 | epoch #0 | gradient computed\n",
      "2020-06-17 23:21:22 | epoch #0 | gradient computed\n",
      "2020-06-17 23:21:22 | epoch #0 | computing descent direction\n",
      "2020-06-17 23:21:22 | epoch #0 | computing descent direction\n",
      "2020-06-17 23:21:22 | epoch #0 | descent direction computed\n",
      "2020-06-17 23:21:22 | epoch #0 | descent direction computed\n",
      "2020-06-17 23:21:23 | epoch #0 | backtrack iters: 0\n",
      "2020-06-17 23:21:23 | epoch #0 | backtrack iters: 0\n",
      "2020-06-17 23:21:23 | epoch #0 | optimization finished\n",
      "2020-06-17 23:21:23 | epoch #0 | optimization finished\n",
      "2020-06-17 23:21:23 | epoch #0 | Computing KL after\n",
      "2020-06-17 23:21:23 | epoch #0 | Computing KL after\n",
      "2020-06-17 23:21:23 | epoch #0 | Computing loss after\n",
      "2020-06-17 23:21:23 | epoch #0 | Computing loss after\n",
      "2020-06-17 23:21:23 | epoch #0 | Fitting baseline...\n",
      "2020-06-17 23:21:23 | epoch #0 | Fitting baseline...\n",
      "2020-06-17 23:21:23 | epoch #0 | Saving snapshot...\n",
      "2020-06-17 23:21:23 | epoch #0 | Saving snapshot...\n",
      "2020-06-17 23:21:23 | epoch #0 | Saved\n",
      "2020-06-17 23:21:23 | epoch #0 | Saved\n",
      "2020-06-17 23:21:23 | epoch #0 | Time 3.63 s\n",
      "2020-06-17 23:21:23 | epoch #0 | Time 3.63 s\n",
      "2020-06-17 23:21:23 | epoch #0 | EpochTime 3.63 s\n",
      "2020-06-17 23:21:23 | epoch #0 | EpochTime 3.63 s\n",
      "---------------------------------------  ---------------\n",
      "EnvExecTime                                  0.382032\n",
      "Evaluation/AverageDiscountedReturn           4.74729\n",
      "Evaluation/AverageReturn                     4.98992\n",
      "Evaluation/CompletionRate                    1\n",
      "Evaluation/Iteration                         0\n",
      "Evaluation/MaxReturn                         5\n",
      "Evaluation/MinReturn                         2\n",
      "Evaluation/NumTrajs                        992\n",
      "Evaluation/StdReturn                         0.134326\n",
      "Extras/EpisodeRewardMean                     5\n",
      "LinearFeatureBaseline/ExplainedVariance     -1.60825e-09\n",
      "PolicyExecTime                               1.18518\n",
      "ProcessExecTime                              0.0848358\n",
      "TotalEnvSteps                            10009\n",
      "policy/Entropy                               6.80431\n",
      "policy/KL                                    0.00966321\n",
      "policy/KLBefore                              0\n",
      "policy/LossAfter                            -0.392955\n",
      "policy/LossBefore                           -0.383091\n",
      "policy/Perplexity                          901.721\n",
      "policy/dLoss                                 0.00986421\n",
      "---------------------------------------  ---------------\n",
      "---------------------------------------  ---------------\n",
      "EnvExecTime                                  0.382032\n",
      "Evaluation/AverageDiscountedReturn           4.74729\n",
      "Evaluation/AverageReturn                     4.98992\n",
      "Evaluation/CompletionRate                    1\n",
      "Evaluation/Iteration                         0\n",
      "Evaluation/MaxReturn                         5\n",
      "Evaluation/MinReturn                         2\n",
      "Evaluation/NumTrajs                        992\n",
      "Evaluation/StdReturn                         0.134326\n",
      "Extras/EpisodeRewardMean                     5\n",
      "LinearFeatureBaseline/ExplainedVariance     -1.60825e-09\n",
      "PolicyExecTime                               1.18518\n",
      "ProcessExecTime                              0.0848358\n",
      "TotalEnvSteps                            10009\n",
      "policy/Entropy                               6.80431\n",
      "policy/KL                                    0.00966321\n",
      "policy/KLBefore                              0\n",
      "policy/LossAfter                            -0.392955\n",
      "policy/LossBefore                           -0.383091\n",
      "policy/Perplexity                          901.721\n",
      "policy/dLoss                                 0.00986421\n",
      "---------------------------------------  ---------------\n",
      "2020-06-17 23:21:23 | epoch #1 | Obtaining samples for iteration 1...\n",
      "2020-06-17 23:21:23 | epoch #1 | Obtaining samples for iteration 1...\n",
      "2020-06-17 23:21:25 | epoch #1 | Logging diagnostics...\n",
      "2020-06-17 23:21:25 | epoch #1 | Logging diagnostics...\n",
      "2020-06-17 23:21:25 | epoch #1 | Optimizing policy...\n",
      "2020-06-17 23:21:25 | epoch #1 | Optimizing policy...\n",
      "2020-06-17 23:21:25 | epoch #1 | Computing loss before\n",
      "2020-06-17 23:21:25 | epoch #1 | Computing loss before\n",
      "2020-06-17 23:21:25 | epoch #1 | Computing KL before\n",
      "2020-06-17 23:21:25 | epoch #1 | Computing KL before\n",
      "2020-06-17 23:21:25 | epoch #1 | Optimizing\n",
      "2020-06-17 23:21:25 | epoch #1 | Optimizing\n",
      "2020-06-17 23:21:25 | epoch #1 | Start CG optimization: #parameters: 1218, #inputs: 995, #subsample_inputs: 995\n",
      "2020-06-17 23:21:25 | epoch #1 | Start CG optimization: #parameters: 1218, #inputs: 995, #subsample_inputs: 995\n",
      "2020-06-17 23:21:25 | epoch #1 | computing loss before\n",
      "2020-06-17 23:21:25 | epoch #1 | computing loss before\n",
      "2020-06-17 23:21:25 | epoch #1 | computing gradient\n",
      "2020-06-17 23:21:25 | epoch #1 | computing gradient\n",
      "2020-06-17 23:21:25 | epoch #1 | gradient computed\n",
      "2020-06-17 23:21:25 | epoch #1 | gradient computed\n",
      "2020-06-17 23:21:25 | epoch #1 | computing descent direction\n",
      "2020-06-17 23:21:25 | epoch #1 | computing descent direction\n",
      "2020-06-17 23:21:26 | epoch #1 | descent direction computed\n",
      "2020-06-17 23:21:26 | epoch #1 | descent direction computed\n",
      "2020-06-17 23:21:26 | epoch #1 | backtrack iters: 0\n",
      "2020-06-17 23:21:26 | epoch #1 | backtrack iters: 0\n",
      "2020-06-17 23:21:26 | epoch #1 | optimization finished\n",
      "2020-06-17 23:21:26 | epoch #1 | optimization finished\n",
      "2020-06-17 23:21:26 | epoch #1 | Computing KL after\n",
      "2020-06-17 23:21:26 | epoch #1 | Computing KL after\n",
      "2020-06-17 23:21:26 | epoch #1 | Computing loss after\n",
      "2020-06-17 23:21:26 | epoch #1 | Computing loss after\n",
      "2020-06-17 23:21:26 | epoch #1 | Fitting baseline...\n",
      "2020-06-17 23:21:26 | epoch #1 | Fitting baseline...\n",
      "2020-06-17 23:21:26 | epoch #1 | Saving snapshot...\n",
      "2020-06-17 23:21:26 | epoch #1 | Saving snapshot...\n",
      "2020-06-17 23:21:26 | epoch #1 | Saved\n",
      "2020-06-17 23:21:26 | epoch #1 | Saved\n",
      "2020-06-17 23:21:26 | epoch #1 | Time 6.44 s\n",
      "2020-06-17 23:21:26 | epoch #1 | Time 6.44 s\n",
      "2020-06-17 23:21:26 | epoch #1 | EpochTime 2.80 s\n",
      "2020-06-17 23:21:26 | epoch #1 | EpochTime 2.80 s\n",
      "---------------------------------------  --------------\n",
      "EnvExecTime                                  0.358832\n",
      "Evaluation/AverageDiscountedReturn           4.75296\n",
      "Evaluation/AverageReturn                     4.99497\n",
      "Evaluation/CompletionRate                    1\n",
      "Evaluation/Iteration                         1\n",
      "Evaluation/MaxReturn                         5\n",
      "Evaluation/MinReturn                         3\n",
      "Evaluation/NumTrajs                        995\n",
      "Evaluation/StdReturn                         0.0837253\n",
      "Extras/EpisodeRewardMean                     4.99\n",
      "LinearFeatureBaseline/ExplainedVariance      0.656556\n",
      "PolicyExecTime                               1.04336\n",
      "ProcessExecTime                              0.0791256\n",
      "TotalEnvSteps                            20013\n",
      "policy/Entropy                               6.81075\n",
      "policy/KL                                    0.00970041\n",
      "policy/KLBefore                              0\n",
      "policy/LossAfter                            -0.0241581\n",
      "policy/LossBefore                           -0.0102119\n",
      "policy/Perplexity                          907.55\n",
      "policy/dLoss                                 0.0139461\n",
      "---------------------------------------  --------------\n",
      "---------------------------------------  --------------\n",
      "EnvExecTime                                  0.358832\n",
      "Evaluation/AverageDiscountedReturn           4.75296\n",
      "Evaluation/AverageReturn                     4.99497\n",
      "Evaluation/CompletionRate                    1\n",
      "Evaluation/Iteration                         1\n",
      "Evaluation/MaxReturn                         5\n",
      "Evaluation/MinReturn                         3\n",
      "Evaluation/NumTrajs                        995\n",
      "Evaluation/StdReturn                         0.0837253\n",
      "Extras/EpisodeRewardMean                     4.99\n",
      "LinearFeatureBaseline/ExplainedVariance      0.656556\n",
      "PolicyExecTime                               1.04336\n",
      "ProcessExecTime                              0.0791256\n",
      "TotalEnvSteps                            20013\n",
      "policy/Entropy                               6.81075\n",
      "policy/KL                                    0.00970041\n",
      "policy/KLBefore                              0\n",
      "policy/LossAfter                            -0.0241581\n",
      "policy/LossBefore                           -0.0102119\n",
      "policy/Perplexity                          907.55\n",
      "policy/dLoss                                 0.0139461\n",
      "---------------------------------------  --------------\n"
     ]
    }
   ],
   "source": [
    "# Train on the wrapped custom environment; the call's return value is not used.\n",
    "test_gym_environment(env, ctxt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Display test result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 378
    },
    "colab_type": "code",
    "id": "LfG7tMG4LbtF",
    "outputId": "c6356fd6-0e95-46dd-e335-8021bd9df25e"
   },
   "outputs": [],
   "source": [
    "# # pandas test results\n",
    "# ! pip install pandas\n",
    "# import pandas as pd\n",
    "# df = pd.DataFrame(result)\n",
    "# pd.DataFrame({\n",
    "#     'score': df.groupby('exp')['rew'].sum(),\n",
    "#     'nstep': df.groupby('exp')['rew'].count()\n",
    "# })"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "custom_env.ipynb",
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
