{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hyyN-2qyK_T2",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# PPO for Cart Pole and Mountain Car\n",
    "\n",
    "# Introduction\n",
    "This notebook contains the PPO agent used in the paper \"Go-Explore with a guide: Speeding up search in sparse reward domains with goal-directed intrinsic rewards\".\n",
    "\n",
    "The PPO agent was implemented using Stable-Baselines 3, and this notebook was modified from one of Stable-Baselines 3's tutorial notebooks.\n",
    "\n",
    "Tutorial GitHub repo: https://github.com/araffin/rl-tutorial-jnrr19\n",
    "\n",
    "Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3\n",
    "\n",
    "It also contains the two continuous-state environments used:\n",
    "- Cart Pole\n",
    "- Mountain Car"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "gWskDE2c9WoN",
    "outputId": "b566c4e7-7986-4ca4-e5af-e5e552a6d800"
   },
   "outputs": [],
   "source": [
    "!apt-get install ffmpeg freeglut3-dev xvfb  # For visualization\n",
    "!pip install stable-baselines3[extra]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "U29X1-B-AIKE",
    "outputId": "8071276b-11c4-4317-ba55-7f375f30615f"
   },
   "outputs": [],
   "source": [
    "import stable_baselines3\n",
    "stable_baselines3.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FtY8FhliLsGm"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gcX8hEcaUpR0"
   },
   "source": [
    "Stable-Baselines3 works on environments that follow the [gym interface](https://stable-baselines.readthedocs.io/en/master/guide/custom_env.html).\n",
    "You can find a list of available environments [here](https://gym.openai.com/envs/#classic_control).\n",
    "\n",
    "It is also worth checking the [source code](https://github.com/openai/gym) to learn more about the observation and action spaces of each environment, as gym's documentation is sparse.\n",
    "Not all algorithms work with all action spaces; see this [recap table](https://stable-baselines.readthedocs.io/en/master/guide/algos.html) for details."
   ]
  },
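  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick way to check the observation and action spaces mentioned above, the following minimal sketch prints them for the two environments used in this notebook. It assumes a classic `gym` release in which the `CartPole-v0` and `MountainCar-v0` IDs are registered; the printed format depends on the installed version.\n",
    "\n",
    "```python\n",
    "import gym\n",
    "\n",
    "# Inspect the spaces of the two environments used in this notebook.\n",
    "for name in ['CartPole-v0', 'MountainCar-v0']:\n",
    "    env = gym.make(name)\n",
    "    # observation_space describes the state vector, action_space the allowed moves\n",
    "    print(name, env.observation_space, env.action_space)\n",
    "    env.close()\n",
    "```\n",
    "\n",
    "Both environments have a continuous (`Box`) observation space and a `Discrete` action space, which is why PPO with `MlpPolicy` can be applied to them."
   ]
  },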
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BIedd7Pz9sOs"
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ae32CtgzTG3R"
   },
   "source": [
    "The first thing you need to import is the RL model. Check the documentation to see which algorithms can be used on which problems."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "R7tKaBFrTR0a"
   },
   "outputs": [],
   "source": [
    "from stable_baselines3 import PPO\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-0_8OQbOTTNT"
   },
   "source": [
    "The next thing you need to import is the policy class that will be used to create the networks (for the policy/value functions).\n",
    "This step is optional as you can directly use strings in the constructor: \n",
    "\n",
    "```PPO('MlpPolicy', env)``` instead of ```PPO(MlpPolicy, env)```\n",
    "\n",
    "Note that some algorithms, like `SAC`, have their own `MlpPolicy`, which is why passing the policy as a string is the recommended option."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ROUJr675TT01"
   },
   "outputs": [],
   "source": [
    "# from stable_baselines3.ppo import MlpPolicy\n",
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "import copy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xENwTP_fNcNS",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Helper functions\n",
    "- FirstSolve: Trains a fresh PPO agent in rounds and returns the first round after which an evaluation episode solves the environment\n",
    "- ManyRuns: Repeats FirstSolve over several seeded trials and reports the average first solve run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oku1b7kgAZgr"
   },
   "outputs": [],
   "source": [
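    "def FirstSolve(env, seed = None):\n",
    "  # set the random seed before building the model so that weight initialization is seeded too\n",
    "  if seed is not None:\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "  model = PPO('MlpPolicy', copy.deepcopy(env), verbose=0)\n",
    "\n",
    "  # Train in rounds (up to 100), evaluating on one episode after each round.\n",
    "  # Note: learn() is asked for 200 timesteps, but PPO always collects full rollouts\n",
    "  # of n_steps (2048 by default), so each round may train on more than 200 steps.\n",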
    "  solved = False\n",
    "  for i in range(100):\n",
    "    model.learn(total_timesteps=200)\n",
    "    mean_reward, std_reward = evaluate_policy(model, copy.deepcopy(env), n_eval_episodes=1)\n",
    "    # print(f\"Run {i+1}: mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}\")\n",
    "    if mean_reward == 1: \n",
    "      solved = True\n",
    "      break\n",
    "\n",
    "  if solved:\n",
    "    print(f'Solved at Run {i+1}')\n",
    "    return i+1\n",
    "  else:\n",
    "    print('Unsolved')\n",
    "    return -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8C9iDEdRYnF4"
   },
   "outputs": [],
   "source": [
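    "def ManyRuns(env, trials=5):\n",
    "  # collect the first solve run of every trial that was solved\n",
    "  solvedlist = []\n",
    "  for i in range(trials):\n",
    "    trial_num = FirstSolve(copy.deepcopy(env), i)  # the trial index doubles as the seed\n",
    "    if trial_num > 0:  # FirstSolve returns -1 for unsolved trials\n",
    "      solvedlist.append(trial_num)\n",
    "\n",
    "  if len(solvedlist) > 0:\n",
    "    avgsolve = sum(solvedlist)/len(solvedlist)\n",
    "    print(f'Average solved run number is {avgsolve}')\n",
    "    # print the solved runs and their average separated by ' & ' (e.g. for a LaTeX table row)\n",
    "    print(*solvedlist, sep = ' & ', end = ' & '+str(avgsolve))"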
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B8KlklgHNmcw",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Cart Pole"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VGJGnUPj1i7y"
   },
   "outputs": [],
   "source": [
    "class CartPoleEnv:\n",
    "    metadata = {'render.modes': ['human']}\n",
    "    def __init__(self, env_name = 'CartPole-v0', goal_steps = 50, start_state = None, normalizer = np.array([0.01, 0.001, 0.01, 0.001]), cap = np.array([42])):\n",
    "        self.name = env_name\n",
    "        self.numsteps = 0\n",
    "        self.reward = 0\n",
    "        self.done = False\n",
    "        self.start_state = start_state\n",
    "        self.goal_steps = goal_steps\n",
    "        self.normalizer = normalizer\n",
    "        self.cap = cap\n",
    "        self.env = gym.make(env_name)\n",
    "        self.env.reset()\n",
    "        self.action_space = self.env.action_space\n",
    "        self.observation_space = self.env.observation_space\n",
    "        if self.start_state is not None:\n",
    "            self.env.unwrapped.state = self.start_state\n",
    "      \n",
    "    def render(self):\n",
    "        return self.env.render()\n",
    "            \n",
    "    def reset(self):\n",
    "        self.env.reset()\n",
    "        if self.start_state is not None:\n",
    "            self.env.unwrapped.state = self.start_state\n",
    "        self.reward = 0\n",
    "        self.done = False\n",
    "        self.numsteps = 0\n",
    "        return self.staterep()\n",
    "        \n",
    "    # returns state representation\n",
    "    def staterep(self):\n",
    "        if self.normalizer is not None:\n",
    "            return ((self.env.state)//self.normalizer).clip(-self.cap, self.cap)\n",
    "        else:\n",
    "            return self.env.state\n",
    "            \n",
    "    # returns the list of valid moves\n",
    "    def getvalidmoves(self):\n",
    "        return list(range(self.env.action_space.n))\n",
    "    \n",
    "    def step(self, move):\n",
    "        if self.done:\n",
    "            return self.staterep(), self.reward, self.done, {}\n",
    "    \n",
    "        validmoves = self.getvalidmoves()\n",
    "        # randomly sample a move if not in validmoves\n",
    "        if move not in validmoves:\n",
    "            move = validmoves[np.random.randint(len(validmoves))]\n",
    "        self.numsteps += 1\n",
    "        \n",
    "        # do your move\n",
    "        state, reward, done, _ = self.env.step(move)\n",
    "        \n",
    "        # success only if the episode lasted more than goal_steps steps\n",
    "        if done:\n",
    "            self.done = True\n",
    "            if self.numsteps > self.goal_steps:\n",
    "                self.reward = 1\n",
    "            else:\n",
    "                self.reward = 0\n",
    "            self.env.close()\n",
    "\n",
    "        return self.staterep(), self.reward, done, {}\n",
    "    \n",
    "    def sample(self):\n",
    "        return np.random.choice(self.getvalidmoves())\n",
    "    \n",
    "    def print(self):\n",
    "        print(self.env.state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K78eFR0xBef8",
    "tags": []
   },
   "source": [
    "## Discrete states"
   ]
  },
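  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The discrete-state runs use the `staterep` method of `CartPoleEnv`, which floor-divides each state dimension by a bin width (`normalizer`) and clips the resulting bin index to `[-cap, cap]`. A minimal sketch of that binning in isolation, assuming the default `normalizer` and `cap` of `CartPoleEnv`; the helper name `discretize` and the sample state are made up for illustration:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def discretize(state, normalizer, cap):\n",
    "    # hypothetical helper mirroring staterep: floor-divide by the bin width, then clip\n",
    "    return (np.asarray(state, dtype=float) // normalizer).clip(-cap, cap)\n",
    "\n",
    "normalizer = np.array([0.01, 0.001, 0.01, 0.001])  # default bin widths of CartPoleEnv\n",
    "cap = np.array([42])                                # default cap of CartPoleEnv\n",
    "state = np.array([0.025, -0.0004, 0.9, 0.0])        # made-up state for illustration\n",
    "print(discretize(state, normalizer, cap))\n",
    "```\n",
    "\n",
    "Passing `normalizer = None` to the environment skips this step and returns the raw continuous state, as in the continuous-state runs."
   ]
  },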
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "llMFRHSVYcTw"
   },
   "outputs": [],
   "source": [
    "TRIAL_NUM = 10 # this tells us how many times to evaluate FirstSolve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "pUWGZp3i9wyf",
    "outputId": "31a7e1fe-f091-4af0-ccb9-393a33d25523"
   },
   "outputs": [],
   "source": [
    "env = CartPoleEnv(goal_steps = 50, start_state = np.array([0, 0, 0, 0]))\n",
    "ManyRuns(env, TRIAL_NUM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "OdBEMHjI9tl9",
    "outputId": "57e9445b-ee9f-4417-87fa-e7df24325565"
   },
   "outputs": [],
   "source": [
    "env = CartPoleEnv(goal_steps = 100, start_state = np.array([0, 0, 0, 0]))\n",
    "ManyRuns(env, TRIAL_NUM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ynQslLQG-QrC",
    "outputId": "86e0d013-52f4-470b-dadb-7be5da68b3e4"
   },
   "outputs": [],
   "source": [
    "env = CartPoleEnv(goal_steps = 175, start_state = np.array([0, 0, 0, 0]))\n",
    "ManyRuns(env, TRIAL_NUM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fMFYkgPsXt94"
   },
   "source": [
    "## Continuous states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Q4-1Z-5EXtYA"
   },
   "outputs": [],
   "source": [
    "env = CartPoleEnv(goal_steps = 50, start_state = np.array([0, 0, 0, 0]), normalizer = None)\n",
    "ManyRuns(env, TRIAL_NUM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZQb7R7uwXxUK"
   },
   "outputs": [],
   "source": [
    "env = CartPoleEnv(goal_steps = 100, start_state = np.array([0, 0, 0, 0]), normalizer = None)\n",
    "ManyRuns(env, TRIAL_NUM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HsMKTKEkX5kZ"
   },
   "outputs": [],
   "source": [
    "env = CartPoleEnv(goal_steps = 175, start_state = np.array([0, 0, 0, 0]), normalizer = None)\n",
    "ManyRuns(env, TRIAL_NUM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mJYpJ17dNpvX",
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Mountain Car"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VPKNhoSC-WEM"
   },
   "outputs": [],
   "source": [
    "class MountainCarEnv:\n",
    "    metadata = {'render.modes': ['human']}\n",
    "    def __init__(self, env_name = 'MountainCar-v0', start_state = None, normalizer = np.array([0.01, 0.001]), cap = np.array([50])):\n",
    "        self.env_name = env_name\n",
    "        self.numsteps = 0\n",
    "        self.reward = 0\n",
    "        self.done = False\n",
    "        self.start_state = start_state\n",
    "        self.normalizer = normalizer\n",
    "        self.cap = cap\n",
    "        self.env = gym.make(env_name)\n",
    "        self.env.reset()\n",
    "        self.action_space = self.env.action_space\n",
    "        self.observation_space = self.env.observation_space\n",
    "        if self.start_state is not None:\n",
    "            self.env.unwrapped.state = self.start_state\n",
    "        \n",
    "    def render(self):\n",
    "        return self.env.render()\n",
    "\n",
    "    def reset(self):\n",
    "        self.env.reset()\n",
    "        if self.start_state is not None:\n",
    "            self.env.unwrapped.state = self.start_state\n",
    "        self.reward = 0\n",
    "        self.done = False\n",
    "        self.numsteps = 0\n",
    "        return self.staterep()\n",
    "        \n",
    "    # returns state representation\n",
    "    def staterep(self):\n",
    "        if self.normalizer is not None:\n",
    "            return ((self.env.state)//self.normalizer).clip(-self.cap, self.cap)\n",
    "        else:\n",
    "            return self.env.state\n",
    "            \n",
    "    # returns the list of valid moves\n",
    "    def getvalidmoves(self):\n",
    "        return list(range(self.env.action_space.n))\n",
    "    \n",
    "    def step(self, move):\n",
    "        if self.done:\n",
    "            return self.staterep(), self.reward, self.done, {}\n",
    "    \n",
    "        validmoves = self.getvalidmoves()\n",
    "        # randomly sample a move if not in validmoves\n",
    "        if move not in validmoves:\n",
    "            move = validmoves[np.random.randint(len(validmoves))]\n",
    "        self.numsteps += 1\n",
    "        \n",
    "        # do your move\n",
    "        state, reward, done, _ = self.env.step(move)\n",
    "        \n",
    "        # success only if the car has reached x >= 0.5 when the episode ends\n",
    "        if done:\n",
    "            self.done = True\n",
    "            if self.env.state[0] >= 0.5:\n",
    "                self.reward = 1\n",
    "            else:\n",
    "                self.reward = 0\n",
    "            self.env.close()\n",
    "\n",
    "        return self.staterep(), self.reward, done, {}\n",
    "        \n",
    "    \n",
    "    def sample(self):\n",
    "        return np.random.choice(self.getvalidmoves())\n",
    "    \n",
    "    def print(self):\n",
    "        print(self.env.state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qD3PIzwXYK_P"
   },
   "source": [
    "## Discrete states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LTpM-ZNxOsDg"
   },
   "outputs": [],
   "source": [
    "env = MountainCarEnv(start_state = np.array([-0.5, 0]))\n",
    "ManyRuns(env, TRIAL_NUM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EbrTR-tCYMzo"
   },
   "source": [
    "## Continuous states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Kapf4ap0OxDj"
   },
   "outputs": [],
   "source": [
    "env = MountainCarEnv(start_state = np.array([-0.5, 0]), normalizer = None)\n",
    "ManyRuns(env, TRIAL_NUM)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
