{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Tree weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'cdt'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-22640e19339c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"..\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrl\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPPO\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/research/Explainability/XRL_BorealisAI/src/rl/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mPPO\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPPO\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0menv_wrapper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mStateNormWrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/research/Explainability/XRL_BorealisAI/src/rl/PPO.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"..\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mcdt\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCDT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msdt\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSDT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'cdt'"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import sys\n",
    "import json\n",
    "import torch\n",
    "import gym\n",
    "sys.path.append(\"..\")\n",
    "from src.rl import PPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/quantumiracle/.conda/envs/robo/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated.  Call .resolve and .require separately.\n",
      "  result = entry_point.load(False)\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'PPO' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-13-b9e7e0a2bfac>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     10\u001b[0m     \u001b[0mrl_confs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mread_file\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# hyperparameters for il training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m model = PPO(state_dim, action_dim, rl_confs[\"General\"][\"policy_approx\"], rl_confs[EnvName][\"learner_args\"],\\\n\u001b[0m\u001b[1;32m     13\u001b[0m             **rl_confs[EnvName][\"alg_confs\"]).to(torch.device(rl_confs[EnvName][\"learner_args\"][\"device\"]))\n\u001b[1;32m     14\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'PPO' is not defined"
     ]
    }
   ],
   "source": [
    "EnvName = 'MountainCar-v0'\n",
    "m = 'cdt'\n",
    "\n",
    "env = gym.make(EnvName).unwrapped\n",
    "state_dim = env.observation_space.shape[0]\n",
    "action_dim = env.action_space.n  # discrete\n",
    "\n",
    "conf_path = '../src/'+m+'/'+m+'_rl_train.json'\n",
    "with open(conf_path, \"r\") as read_file:\n",
    "    rl_confs = json.load(read_file)  # hyperparameters for il training\n",
    "\n",
    "model = PPO(state_dim, action_dim, rl_confs[\"General\"][\"policy_approx\"], rl_confs[EnvName][\"learner_args\"],\\\n",
    "            **rl_confs[EnvName][\"alg_confs\"]).to(torch.device(rl_confs[EnvName][\"learner_args\"][\"device\"]))\n",
    "i=0\n",
    "model_path = rl_confs[EnvName][\"train_confs\"][\"model_path\"]+str(i)\n",
    "model.load_model(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
