{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 03 - Continual Learning\n",
    "\n",
    "It is common that a vanilla learning from scratch is interrupted and to be carried on. For such, we can start a continual learning process based on the pretrained model.  \n",
    "To demonstrate, we will start from where we left off in the previous tutorial `Tutorial 02 - Vanilla Learning`. You will see in the following:\n",
    "* how to load a pretrained model and a saved environment wrapper\n",
    "* how to carry out a continual learning\n",
    "\n",
    "Note that all funcionalities/code snippets from this tutorial are provided in `./commonroad_rl/train_model.py`. Please use the full functionality of `train_model.py` to perform your other experiments."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Preparation\n",
    "\n",
    "Please make sure `Tutorial 02 - Vanilla Learning` has been followed correctly and the followings:\n",
    "* current path is at `commonroad-rl/commonroad_rl`, i.e. one upper layer to the `tutorials` folder\n",
    "* interactive python kernel is triggered from the correct environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check current path\n",
    "%cd ..\n",
    "%pwd\n",
    "\n",
    "# Check interactive python kernel\n",
    "import sys\n",
    "sys.executable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Reuse configurations and hyperparameters\n",
    "\n",
    "Again, before the learning begins, both the RL environment and the RL model have to be set. We simply reuse the environment configurations and model hyperparameters which are saved previously."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import yaml\n",
    "import copy\n",
    "\n",
    "log_path = \"tutorials/logs/\"\n",
    "\n",
    "# Read in environment configurations\n",
    "env_configs = {}\n",
    "with open(os.path.join(log_path, \"environment_configurations.yml\"), \"r\") as config_file:\n",
    "    env_configs = yaml.safe_load(config_file)\n",
    "\n",
    "# Read in model hyperparameters\n",
    "hyperparams = {}\n",
    "with open(os.path.join(log_path, \"model_hyperparameters.yml\"), \"r\") as hyperparam_file:\n",
    "    hyperparams = yaml.safe_load(hyperparam_file)\n",
    "\n",
    "# Deduce `policy` from the pretrained model\n",
    "if \"policy\" in hyperparams:\n",
    "    del hyperparams[\"policy\"]\n",
    "    \n",
    "# Remove `normalize` as it will be handled explicitly later\n",
    "if \"normalize\" in hyperparams:\n",
    "    del hyperparams[\"normalize\"]\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create a training and testing environment\n",
    "\n",
    "Similar to the previous tutorial, we create a training and testing environment for the continual learning process. However, note that we need not to explicitly wrap the vectorized training environment with the `VecNormalize` wrapper because we will later load it from the saved file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "from stable_baselines.bench import Monitor\n",
    "from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize\n",
    "from stable_baselines.common.callbacks import BaseCallback, EvalCallback\n",
    "\n",
    "import commonroad_rl.gym_commonroad\n",
    "\n",
    "# Create a Gym-based RL environment with specified data paths and environment configurations\n",
    "meta_scenario_path = \"tutorials/data/highD/pickles/meta_scenario\"\n",
    "training_data_path = \"tutorials/data/highD/pickles/problem_train\"\n",
    "training_env = gym.make(\"commonroad-v1\", \n",
    "                        meta_scenario_path=meta_scenario_path,\n",
    "                        train_reset_config_path= training_data_path,\n",
    "                        **env_configs)\n",
    "\n",
    "# Wrap the environment with a monitor to keep an record of the learning process\n",
    "info_keywords=tuple([\"is_collision\", \\\n",
    "                     \"is_time_out\", \\\n",
    "                     \"is_off_road\", \\\n",
    "                     \"is_friction_violation\", \\\n",
    "                     \"is_goal_reached\"])\n",
    "training_env = Monitor(training_env, log_path + \"0\", info_keywords=info_keywords)\n",
    "\n",
    "# Vectorize the environment with a callable argument\n",
    "def make_training_env():\n",
    "    return training_env\n",
    "training_env = DummyVecEnv([make_training_env])\n",
    "\n",
    "# Append the additional key for the testing environment\n",
    "env_configs_test = copy.deepcopy(env_configs)\n",
    "env_configs_test[\"test_env\"] = True\n",
    "\n",
    "# Create the testing environment\n",
    "testing_data_path = \"tutorials/data/highD/pickles/problem_test\"\n",
    "testing_env = gym.make(\"commonroad-v1\", \n",
    "                        meta_scenario_path=meta_scenario_path,\n",
    "                        test_reset_config_path= testing_data_path,\n",
    "                        **env_configs_test)\n",
    "\n",
    "# Wrap the environment with a monitor to keep an record of the testing episodes \n",
    "log_path_test = \"tutorials/logs/test\"\n",
    "os.makedirs(log_path_test, exist_ok=True)\n",
    "\n",
    "testing_env = Monitor(testing_env, log_path_test + \"/0\", info_keywords=info_keywords)\n",
    "\n",
    "# Vectorize the environment with a callable argument\n",
    "def make_testing_env():\n",
    "    return testing_env\n",
    "testing_env = DummyVecEnv([make_testing_env])\n",
    "\n",
    "# Normalize only observations during testing\n",
    "testing_env = VecNormalize(testing_env, norm_obs=True, norm_reward=False)\n",
    "\n",
    "# Define a customized callback function to save the vectorized and normalized training environment\n",
    "class SaveVecNormalizeCallback(BaseCallback):\n",
    "    def __init__(self, save_path: str, verbose=1):\n",
    "        super(SaveVecNormalizeCallback, self).__init__(verbose)\n",
    "        self.save_path = save_path\n",
    "        \n",
    "    def _init_callback(self) -> None:\n",
    "        if self.save_path is not None:\n",
    "            os.makedirs(self.save_path, exist_ok=True)\n",
    "    \n",
    "    def _on_step(self) -> bool:\n",
    "        save_path_name = os.path.join(self.save_path, \"vecnormalize.pkl\")\n",
    "        self.model.get_vec_normalize_env().save(save_path_name)\n",
    "        print(\"Saved vectorized and normalized environment to {}\".format(save_path_name))\n",
    "    \n",
    "# Pass the testing environment and customized saving callback to an evaluation callback\n",
    "# Note that the evaluation callback will triggers three evaluating episodes after every 500 training steps\n",
    "save_vec_normalize_callback = SaveVecNormalizeCallback(save_path=log_path)\n",
    "eval_callback = EvalCallback(testing_env, \n",
    "                             log_path=log_path, \n",
    "                             eval_freq=500, \n",
    "                             n_eval_episodes=3, \n",
    "                             callback_on_new_best=save_vec_normalize_callback)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Load the saved environment wrapper and pretrained agent\n",
    "\n",
    "Now we load the saved environment wrapper and the pretrained agent for the continual learning process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines.common.vec_env import VecNormalize\n",
    "from stable_baselines import PPO2\n",
    "\n",
    "# Load saved environment\n",
    "training_env = VecNormalize.load(\"tutorials/logs/vecnormalize.pkl\", training_env)\n",
    "\n",
    "# Load pretrained model\n",
    "model_continual = PPO2.load(\"tutorials/logs/intermediate_model\", env=training_env, **hyperparams)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Conduct learning and save results\n",
    "Finally, we trigger the continual learning process in a similar fashion as before, and save the final model as well as the environment wrapper for later inspection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set learning steps and trigger learning with the evaluation callback\n",
    "n_timesteps=5000\n",
    "model_continual.learn(n_timesteps, eval_callback)\n",
    "\n",
    "# Save the continual-learning model\n",
    "# Note that we use the name \"best_model\" here as it will be fetched in the next tutorials\n",
    "model_continual.save(\"tutorials/logs/best_model\")\n",
    "model_continual.get_vec_normalize_env().save(\"tutorials/logs/vecnormalize.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Python (cr37_release)",
   "language": "python",
   "name": "cr37-release"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
