{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Tree weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import sys\n",
    "import json\n",
    "import torch\n",
    "import gym\n",
    "from rl import PPO\n",
    "\n",
    "def softmax(x):\n",
    "    \"\"\"Compute softmax values for each sets of scores in x.\"\"\"\n",
    "    e_x = np.exp(x - np.max(x))\n",
    "    return e_x / e_x.sum()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RL Data Norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
      "CDT parameters:  {'num_intermediate_variables': 1, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 2, 'output_dim': 3, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 128, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'episodes': 40, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0}\n",
      "../data/cdt/model/mountaincar/rl_ppo0\n",
      "OrderedDict([('policy.fl_leaf_weights', tensor([[ 1.2682, -8.0616],\n",
      "        [ 4.4988,  1.1245],\n",
      "        [ 0.7746, -8.9235],\n",
      "        [-1.1001, -2.5657]], device='cuda:0')), ('policy.dc_leaves', tensor([[-1.5820, -1.6395,  2.2960],\n",
      "        [-0.4989,  2.3197, -0.8824],\n",
      "        [ 1.1209, -0.5438, -2.6799],\n",
      "        [ 1.1128,  2.1277,  0.7750]], device='cuda:0')), ('policy.fl_inner_nodes.weight', tensor([[  1.6196,   3.5666,  10.0695],\n",
      "        [ -2.9400,  -4.6343, -11.3407],\n",
      "        [  0.9527,   1.9745,   4.4633]], device='cuda:0')), ('policy.dc_inner_nodes.weight', tensor([[-0.3107, -3.2153],\n",
      "        [ 0.1307, -4.4850],\n",
      "        [ 0.5900,  3.3418]], device='cuda:0')), ('fc1.weight', tensor([[-4.4777, -8.2504],\n",
      "        [ 2.0882,  0.3353],\n",
      "        [ 0.8424,  2.1979],\n",
      "        [ 0.3132,  0.6510],\n",
      "        [-0.3978, -2.0470],\n",
      "        [ 0.3904, -0.1952],\n",
      "        [ 6.3024, 12.6843],\n",
      "        [ 5.9705, -4.7382],\n",
      "        [ 0.0522, -1.2685],\n",
      "        [ 0.3746, -0.4909],\n",
      "        [ 0.3401,  2.5167],\n",
      "        [ 0.2090,  7.5001],\n",
      "        [ 0.6811,  0.2857],\n",
      "        [ 0.0606,  0.0221],\n",
      "        [ 2.5765,  1.8656],\n",
      "        [ 0.8402, 11.6315],\n",
      "        [-0.4691,  0.0974],\n",
      "        [ 0.3507, -0.3691],\n",
      "        [ 2.2504,  1.2020],\n",
      "        [ 0.0874, -0.2258],\n",
      "        [-0.2958,  0.4226],\n",
      "        [ 1.4203, -0.0475],\n",
      "        [-0.4518,  0.5194],\n",
      "        [ 2.1986,  0.8678],\n",
      "        [-0.6884, -1.6875],\n",
      "        [ 1.3956, 13.2145],\n",
      "        [ 2.1777,  0.5697],\n",
      "        [-4.6204, -7.3229],\n",
      "        [ 0.5169,  0.3833],\n",
      "        [-2.9755, -5.6317],\n",
      "        [ 0.1996,  7.4095],\n",
      "        [ 2.3578,  1.8206]], device='cuda:0')), ('fc1.bias', tensor([-1.0527,  2.0826,  1.1687, -0.5016,  0.7829, -0.2217, -1.4266,  1.1556,\n",
      "        -0.0946, -0.4790,  1.6529, -0.5404,  1.7411, -0.1335,  2.2626, -0.3360,\n",
      "        -0.6811, -0.4346,  2.2304, -0.3120, -0.5181,  1.7910, -0.5590,  2.0594,\n",
      "         0.7365, -0.4162,  2.1807, -0.8551, -0.5428, -0.8447, -0.5413,  1.8485],\n",
      "       device='cuda:0')), ('fc_v.weight', tensor([[ 5.0479,  1.0550,  1.8517, -0.1374, -0.5934, -0.1005, 11.7922, -6.0217,\n",
      "          0.0245,  0.1176, -0.9430,  4.7065, -0.5778,  0.0307,  1.4730,  9.6253,\n",
      "         -0.0499, -0.1212,  1.7044,  0.0676,  0.0753, -0.0617,  0.1454,  1.5756,\n",
      "         -0.5214, 12.0224,  1.1129,  4.9488, -0.0204,  2.6058,  4.4900,  1.4710]],\n",
      "       device='cuda:0')), ('fc_v.bias', tensor([0.7431], device='cuda:0'))])\n",
      "[0.01989198 0.01878024 0.9613278 ]\n",
      "[0.05424495 0.9087906  0.03696447]\n",
      "[0.8253573  0.15619408 0.01844862]\n",
      "[0.22359638 0.6169009  0.1595027 ]\n"
     ]
    }
   ],
   "source": [
    "# Inspect the PPO policy trained on MountainCar with the 'cdt' tree model.\n",
    "EnvName = 'MountainCar-v0'\n",
    "m = 'cdt'\n",
    "\n",
    "# The unwrapped env is used here to read the observation/action dimensions.\n",
    "env = gym.make(EnvName).unwrapped\n",
    "state_dim = env.observation_space.shape[0]\n",
    "action_dim = env.action_space.n  # discrete action space\n",
    "\n",
    "# RL training configuration file for this model family.\n",
    "conf_path = '../src/{0}/{0}_rl_train.json'.format(m)\n",
    "with open(conf_path, \"r\") as read_file:\n",
    "    rl_confs = json.load(read_file)\n",
    "\n",
    "env_confs = rl_confs[EnvName]\n",
    "learner_args = env_confs[\"learner_args\"]\n",
    "\n",
    "# Build the model from the configured arguments and move it to the configured device.\n",
    "model = PPO(\n",
    "    state_dim,\n",
    "    action_dim,\n",
    "    rl_confs[\"General\"][\"policy_approx\"],\n",
    "    learner_args,\n",
    "    **env_confs[\"alg_confs\"]\n",
    ").to(torch.device(learner_args[\"device\"]))\n",
    "\n",
    "# Checkpoint path = configured prefix + run index.\n",
    "i = 0\n",
    "model_path = env_confs[\"train_confs\"][\"model_path\"] + str(i)\n",
    "print(model_path)\n",
    "model.load_model(model_path)\n",
    "print(model.state_dict())\n",
    "\n",
    "# Softmax each row of the decision-tree leaf parameters.\n",
    "# presumably one row per leaf holding action logits -- TODO confirm\n",
    "for leaf_row in model.state_dict()['policy.dc_leaves'].detach().cpu().numpy():\n",
    "    print(softmax(leaf_row))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 1, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'episodes': 40, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0}\n",
      "../data/cdt_compare_depth/model/cartpole/rl_ppo_12_1\n",
      "[0.00731772 0.9926823 ]\n",
      "[0.7414794  0.25852057]\n",
      "[0.9659206  0.03407942]\n",
      "[0.7439119  0.25608805]\n",
      "OrderedDict([('policy.fl_leaf_weights', tensor([[ 0.5395,  0.1492,  2.6079,  2.6178],\n",
      "        [ 0.4740,  0.0534,  3.5718,  0.7234],\n",
      "        [-0.2066,  0.5147,  3.6235,  3.6063],\n",
      "        [ 0.2708,  2.5149,  3.1295,  1.0171]], device='cuda:0')), ('policy.dc_leaves', tensor([[-2.7513,  2.1589],\n",
      "        [ 0.1815, -0.8722],\n",
      "        [ 1.3746, -1.9698],\n",
      "        [ 1.3319,  0.2655]], device='cuda:0')), ('policy.fl_inner_nodes.weight', tensor([[-1.2752,  0.2580,  0.0931,  0.0452,  0.3635]], device='cuda:0')), ('policy.dc_inner_nodes.weight', tensor([[ 0.2481,  2.3432,  1.6599],\n",
      "        [ 1.1299,  1.9963,  1.3973],\n",
      "        [ 0.9243, -1.7002, -0.7125]], device='cuda:0')), ('fc1.weight', tensor([[ 0.1701,  0.3506,  0.7501,  0.2310],\n",
      "        [ 0.2816,  0.2305,  0.0112, -0.1559],\n",
      "        [-0.1952,  0.3126, -0.4049, -0.3921],\n",
      "        [ 0.2524, -0.0509, -0.3769, -0.0882],\n",
      "        [-0.4032,  0.0787,  0.4864, -0.2674],\n",
      "        [-0.2842,  0.4338,  0.4017, -0.2596],\n",
      "        [-0.3031, -0.1364, -0.3194,  0.1112],\n",
      "        [-0.4517,  0.0658,  0.3630, -0.0548],\n",
      "        [ 0.6397, -0.2563, -0.3042, -0.2061],\n",
      "        [ 0.0874,  0.2903,  0.4961, -0.1798],\n",
      "        [ 0.0335, -0.2052,  0.3994, -0.3340],\n",
      "        [ 0.3501,  0.2419,  0.7207,  0.2007],\n",
      "        [-0.4308, -0.4150, -0.4409, -0.4072],\n",
      "        [-0.3375, -0.3676, -0.0264, -0.1924],\n",
      "        [ 0.0256, -0.2909, -0.4180,  0.3518],\n",
      "        [-0.2088, -0.4678, -0.3034, -0.0503],\n",
      "        [-0.1336,  0.1257, -0.3175,  0.1502],\n",
      "        [-0.1981,  0.3473,  0.4018, -0.1190],\n",
      "        [-0.1237, -0.2484, -0.3284, -0.2381],\n",
      "        [ 0.1575, -0.2604,  0.0651, -0.0396],\n",
      "        [-0.2581, -0.2529, -0.3187,  0.0639],\n",
      "        [ 0.0274,  0.4458, -0.4588,  0.2915],\n",
      "        [-0.1787,  0.3036, -0.0640,  0.2442],\n",
      "        [-0.0392,  0.3565, -0.1600,  0.2868],\n",
      "        [-0.3125,  0.0928,  0.4402, -0.4276],\n",
      "        [ 0.2399, -0.1450, -0.4408,  0.1615],\n",
      "        [-0.2335, -0.1637, -0.0515,  0.1973],\n",
      "        [ 0.0224,  0.2929,  0.1676, -0.4284],\n",
      "        [-0.0072, -0.2639,  0.2564,  0.1993],\n",
      "        [ 0.0253, -0.3916,  0.2145, -0.2133],\n",
      "        [-0.3304, -0.3844, -0.1701, -0.2983],\n",
      "        [-0.3757, -0.3640,  0.4809, -0.1197],\n",
      "        [ 0.2436,  0.4491,  0.2426, -0.1501],\n",
      "        [-0.3130,  0.4223, -0.0913, -0.3265],\n",
      "        [-0.0800,  0.1984,  0.3321, -0.3760],\n",
      "        [-0.0604,  0.2231,  0.3347, -0.0662],\n",
      "        [ 0.3106,  0.1898,  0.0787, -0.4053],\n",
      "        [ 0.1288,  0.1084, -0.2789, -0.1490],\n",
      "        [-0.3366, -0.2422,  0.1718, -0.1405],\n",
      "        [-0.3989, -0.2642,  0.3693,  0.3786],\n",
      "        [ 0.3157, -0.1871, -0.3169, -0.3130],\n",
      "        [ 0.0447,  0.2781,  0.0482, -0.2327],\n",
      "        [-0.0984, -0.2288,  0.3463, -0.0489],\n",
      "        [ 0.4157, -0.4146,  0.5275, -0.1221],\n",
      "        [-0.1694, -0.4257,  0.2579, -0.1577],\n",
      "        [ 0.2570,  0.4116, -0.4864,  0.2940],\n",
      "        [-0.2551,  0.3359,  0.3267, -0.0686],\n",
      "        [-0.5144,  0.2599,  0.4441,  0.1169],\n",
      "        [ 0.4988,  0.0685, -0.2549, -0.1782],\n",
      "        [-0.1873,  0.0685, -0.9365,  0.0346],\n",
      "        [ 0.4326, -0.4119, -0.5929, -0.2788],\n",
      "        [-0.3721, -0.4470, -0.3640,  0.2235],\n",
      "        [-0.0603, -0.1171, -0.3238,  0.0436],\n",
      "        [ 0.3319,  0.0038,  0.1052,  0.1072],\n",
      "        [ 0.2968, -0.4278,  0.1511, -0.2196],\n",
      "        [ 0.3141, -0.4167, -0.0353, -0.4084],\n",
      "        [-0.0515,  0.2537,  0.5439,  0.4385],\n",
      "        [ 0.2601,  0.0910, -0.3613,  0.0493],\n",
      "        [ 0.3813,  0.3372, -0.2163,  0.1010],\n",
      "        [ 0.3752,  0.1222,  0.4600,  0.0887],\n",
      "        [-0.2694,  0.1550, -0.3204,  0.1995],\n",
      "        [ 0.3384, -0.2629,  0.2917,  0.3213],\n",
      "        [ 0.3878,  0.4003,  0.6189,  0.2807],\n",
      "        [-0.0525,  0.1555,  0.1949, -0.0905],\n",
      "        [-0.5592,  0.4493,  0.1591,  0.3570],\n",
      "        [ 0.3051,  0.0641,  0.4374,  0.3211],\n",
      "        [-0.1022, -0.3521,  0.4704,  0.0209],\n",
      "        [-0.1136, -0.3096,  0.1692, -0.1632],\n",
      "        [-0.2168,  0.2068, -0.7841, -0.3658],\n",
      "        [-0.3591,  0.4159,  0.4503,  0.2402],\n",
      "        [ 0.1412,  0.3838,  0.4739,  0.3964],\n",
      "        [ 0.1097, -0.3981,  0.3219,  0.2770],\n",
      "        [ 0.2540,  0.3550, -0.2674,  0.2779],\n",
      "        [-0.2947, -0.3613,  0.0065, -0.1210],\n",
      "        [-0.1798,  0.4989,  0.5653,  0.3789],\n",
      "        [ 0.2510, -0.3040, -0.0211, -0.3819],\n",
      "        [-0.1404,  0.0338, -0.0479,  0.3557],\n",
      "        [ 0.2191, -0.3362,  0.1931,  0.2853],\n",
      "        [-0.2598,  0.3563,  0.1274, -0.3546],\n",
      "        [-0.0589, -0.2862, -0.1856, -0.0153],\n",
      "        [-0.3807, -0.0412, -0.0817,  0.3627],\n",
      "        [ 0.3615,  0.4385,  0.2661, -0.3062],\n",
      "        [ 0.3122, -0.3682, -0.1418,  0.4236],\n",
      "        [ 0.1254,  0.2874,  0.0149,  0.2917],\n",
      "        [ 0.5055,  0.2698,  0.4261,  0.3242],\n",
      "        [ 0.1523, -0.1694, -0.3527,  0.4551],\n",
      "        [-0.2849,  0.1694, -0.1982, -0.2810],\n",
      "        [-0.4422, -0.1147,  0.4069, -0.2726],\n",
      "        [ 0.2976,  0.4060,  0.1853, -0.1360],\n",
      "        [ 0.2515, -0.3956,  0.0654,  0.4206],\n",
      "        [-0.1380, -0.3701,  0.3548, -0.1998],\n",
      "        [ 0.1514,  0.0378, -0.3996,  0.2506],\n",
      "        [-0.2372, -0.2047,  0.5520,  0.1533],\n",
      "        [-0.3529,  0.1703,  0.2409, -0.0867],\n",
      "        [ 0.1005, -0.3687,  0.1520, -0.2527],\n",
      "        [ 0.2091,  0.3161, -0.0192,  0.1321],\n",
      "        [ 0.1923,  0.3066,  0.1255, -0.1501],\n",
      "        [ 0.2753, -0.1281, -0.1687,  0.1971],\n",
      "        [-0.2483,  0.0616, -0.1527,  0.4190],\n",
      "        [-0.0882,  0.2721,  0.3519, -0.1743],\n",
      "        [ 0.2641, -0.0620, -0.6809, -0.3494],\n",
      "        [-0.3483, -0.2013,  0.1587,  0.4265],\n",
      "        [ 0.5133, -0.2154,  0.1154, -0.1476],\n",
      "        [-0.0319,  0.2117, -0.4405,  0.4296],\n",
      "        [ 0.2077,  0.2912,  0.3001,  0.1206],\n",
      "        [-0.4873,  0.0077,  0.2479,  0.2782],\n",
      "        [ 0.0117,  0.0181, -0.3077, -0.0205],\n",
      "        [ 0.0989, -0.1827,  0.4554,  0.3935],\n",
      "        [-0.4567,  0.1205,  0.6141,  0.1367],\n",
      "        [-0.3141, -0.2242,  0.2074, -0.3686],\n",
      "        [ 0.2959,  0.0989, -0.4390, -0.0380],\n",
      "        [ 0.0872,  0.3045,  0.3351, -0.1921],\n",
      "        [ 0.4170,  0.1091, -0.3593,  0.2646],\n",
      "        [-0.6040,  0.0683, -0.0171, -0.2032],\n",
      "        [ 0.3311,  0.0298,  0.1399, -0.0294],\n",
      "        [-0.2551, -0.3529, -0.3985, -0.0615],\n",
      "        [-0.4059,  0.2991, -0.0719,  0.2268],\n",
      "        [ 0.0666, -0.2651,  0.3114, -0.1491],\n",
      "        [-0.2418, -0.3530, -0.5751, -0.2926],\n",
      "        [ 0.5571, -0.2162, -0.3675, -0.2702],\n",
      "        [ 0.2715,  0.0517, -0.5230, -0.4113],\n",
      "        [ 0.0064, -0.3831, -0.0778, -0.2490],\n",
      "        [ 0.1341,  0.4649, -0.1436,  0.2838],\n",
      "        [-0.0957, -0.4000, -0.3819, -0.2409],\n",
      "        [-0.0119, -0.3050, -0.1573, -0.0399],\n",
      "        [ 0.1760,  0.1965,  0.4153, -0.1745],\n",
      "        [-0.2709, -0.2370,  0.0104,  0.3885],\n",
      "        [-0.2364,  0.2236,  0.3657,  0.3981]], device='cuda:0')), ('fc1.bias', tensor([-0.3394,  0.0044, -0.3285, -0.1163, -0.0411, -0.4824, -0.5538,  0.0252,\n",
      "         0.0277,  0.4844,  0.2934, -0.0615, -0.3490,  0.0471,  0.3781, -0.4223,\n",
      "        -0.2258, -0.1901, -0.2485, -0.4559,  0.2481, -0.2215,  0.2132, -0.0543,\n",
      "         0.1858,  0.5079, -0.1232,  0.3250,  0.4806,  0.2990, -0.0255, -0.4912,\n",
      "        -0.4955, -0.3594, -0.3035, -0.4770, -0.2914,  0.4048, -0.3858,  0.3932,\n",
      "        -0.0203,  0.3135, -0.3557, -0.3594, -0.1764,  0.4031, -0.3906, -0.1450,\n",
      "        -0.0520, -0.0069, -0.1595, -0.0460,  0.4795,  0.2623,  0.2423, -0.3686,\n",
      "         0.2411, -0.2455, -0.0986,  0.1823,  0.2401,  0.1340, -0.0152,  0.3128,\n",
      "        -0.1814,  0.3435, -0.4725,  0.0813, -0.4955,  0.3085, -0.2968, -0.1863,\n",
      "        -0.6889,  0.3569,  0.1336, -0.3718,  0.2326, -0.0394, -0.0710,  0.1217,\n",
      "        -0.2503,  0.4278, -0.2493,  0.4184, -0.4930, -0.3436,  0.3300,  0.2279,\n",
      "         0.0361,  0.3760,  0.0483,  0.4223, -0.2281, -0.1148, -0.0427, -0.4899,\n",
      "         0.0460,  0.2512,  0.2604,  0.0875, -0.5194,  0.3765, -0.0487,  0.2897,\n",
      "         0.0690, -0.3824,  0.4007,  0.4040, -0.1793,  0.1709,  0.4540,  0.2187,\n",
      "        -0.0128,  0.0970, -0.1375,  0.3923,  0.0222, -0.2353, -0.0709, -0.2302,\n",
      "        -0.3751, -0.0829,  0.4747, -0.3825,  0.4542, -0.1527,  0.2862,  0.3334],\n",
      "       device='cuda:0')), ('fc_v.weight', tensor([[-0.1965,  0.0233, -0.0577, -0.0897,  0.0147, -0.0252, -0.0535, -0.0037,\n",
      "         -0.0790,  0.0906,  0.0137, -0.1507, -0.0275, -0.0116,  0.0325, -0.0133,\n",
      "         -0.0395, -0.0077, -0.0940, -0.1005,  0.0254,  0.0284,  0.0453,  0.0060,\n",
      "          0.0169,  0.1042, -0.0159, -0.0144,  0.0468,  0.0265, -0.0423, -0.0262,\n",
      "          0.0117, -0.0561, -0.0039,  0.0172, -0.0558,  0.0808, -0.0965,  0.0342,\n",
      "          0.0566,  0.0025, -0.0483, -0.1075,  0.0513, -0.0308,  0.0302, -0.0162,\n",
      "         -0.0619, -0.2828, -0.1003, -0.0802,  0.0561, -0.0089,  0.0254,  0.0076,\n",
      "         -0.0754,  0.0138,  0.0332,  0.0282,  0.0026, -0.0132, -0.0992,  0.0809,\n",
      "         -0.0077, -0.0506, -0.0797,  0.0445, -0.0203,  0.0591, -0.0650,  0.0077,\n",
      "         -0.1111, -0.0312, -0.0694,  0.0231,  0.0140, -0.0215,  0.0099, -0.0383,\n",
      "          0.0400, -0.0265, -0.0290, -0.0259, -0.0760, -0.0127,  0.0323,  0.0598,\n",
      "          0.0351,  0.0901,  0.0433,  0.0799, -0.0453,  0.0440,  0.0644, -0.0220,\n",
      "          0.0613,  0.0627,  0.0278,  0.0242, -0.0488, -0.0270, -0.1071,  0.0264,\n",
      "         -0.0629,  0.0255,  0.0248, -0.0293, -0.0299, -0.0192,  0.0674,  0.0368,\n",
      "          0.0116,  0.0255, -0.0637, -0.0199,  0.1081, -0.0333, -0.1373,  0.1011,\n",
      "         -0.0482,  0.0111,  0.0844, -0.1024,  0.0767,  0.0140,  0.0481, -0.0280]],\n",
      "       device='cuda:0')), ('fc_v.bias', tensor([0.0399], device='cuda:0'))])\n"
     ]
    }
   ],
   "source": [
    "# Inspect the PPO policy trained on CartPole with the 'cdt' tree model.\n",
    "EnvName = 'CartPole-v1'\n",
    "m = 'cdt'\n",
    "\n",
    "# The unwrapped env is used here to read the observation/action dimensions.\n",
    "env = gym.make(EnvName).unwrapped\n",
    "state_dim = env.observation_space.shape[0]\n",
    "action_dim = env.action_space.n  # discrete action space\n",
    "\n",
    "# RL training configuration file used for the depth-comparison runs.\n",
    "conf_path = '../src/{0}/{0}_rl_train_compare.json'.format(m)\n",
    "with open(conf_path, \"r\") as read_file:\n",
    "    rl_confs = json.load(read_file)\n",
    "\n",
    "env_confs = rl_confs[EnvName]\n",
    "learner_args = env_confs['learner_args']\n",
    "\n",
    "# Override the tree depths before building the model.\n",
    "# presumably these must match the checkpoint selected by the suffix below -- TODO confirm\n",
    "learner_args['feature_learning_depth'] = 1\n",
    "learner_args['decision_depth'] = 2\n",
    "\n",
    "# Build the model from the configured arguments and move it to the configured device.\n",
    "model = PPO(\n",
    "    state_dim,\n",
    "    action_dim,\n",
    "    rl_confs[\"General\"][\"policy_approx\"],\n",
    "    learner_args,\n",
    "    **env_confs[\"alg_confs\"]\n",
    ").to(torch.device(learner_args[\"device\"]))\n",
    "\n",
    "# Checkpoint path = configured prefix + run suffix.\n",
    "i = '_12_1'\n",
    "model_path = env_confs[\"train_confs\"][\"model_path\"] + i\n",
    "print(model_path)\n",
    "model.load_model(model_path)\n",
    "\n",
    "# Softmax each row of the decision-tree leaf parameters, then dump all weights.\n",
    "for leaf_row in model.state_dict()['policy.dc_leaves'].detach().cpu().numpy():\n",
    "    print(softmax(leaf_row))\n",
    "print(model.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/quantumiracle/.conda/envs/robo/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated.  Call .resolve and .require separately.\n",
      "  result = entry_point.load(False)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 1, 'decision_depth': 2, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'episodes': 40, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0}\n",
      "../data/cdt_compare_depth/model/lunarlander/rl_ppo_22_1\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Error(s) in loading state_dict for PPO:\n\tsize mismatch for policy.fl_leaf_weights: copying a param with shape torch.Size([8, 8]) from checkpoint, the shape in current model is torch.Size([4, 8]).\n\tsize mismatch for policy.fl_inner_nodes.weight: copying a param with shape torch.Size([3, 9]) from checkpoint, the shape in current model is torch.Size([1, 9]).",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-f3a7bf321afd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[0mmodel_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrl_confs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mEnvName\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"train_confs\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"model_path\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     21\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mw\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'policy.dc_leaves'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/research/Explainability/XRL_BorealisAI/src/rl/PPO.py\u001b[0m in \u001b[0;36mload_model\u001b[0;34m(self, path)\u001b[0m\n\u001b[1;32m    122\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    123\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mload_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/robo/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m    775\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    776\u001b[0m             raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[0;32m--> 777\u001b[0;31m                                self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[1;32m    778\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    779\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for PPO:\n\tsize mismatch for policy.fl_leaf_weights: copying a param with shape torch.Size([8, 8]) from checkpoint, the shape in current model is torch.Size([4, 8]).\n\tsize mismatch for policy.fl_inner_nodes.weight: copying a param with shape torch.Size([3, 9]) from checkpoint, the shape in current model is torch.Size([1, 9])."
     ]
    }
   ],
   "source": [
    "# Inspect the PPO policy trained on LunarLander with the 'cdt' tree model.\n",
    "EnvName = 'LunarLander-v2'\n",
    "m = 'cdt'\n",
    "\n",
    "# The unwrapped env is used here to read the observation/action dimensions.\n",
    "env = gym.make(EnvName).unwrapped\n",
    "state_dim = env.observation_space.shape[0]\n",
    "action_dim = env.action_space.n  # discrete action space\n",
    "\n",
    "# RL training configuration file used for the depth-comparison runs.\n",
    "conf_path = '../src/{0}/{0}_rl_train_compare.json'.format(m)\n",
    "with open(conf_path, \"r\") as read_file:\n",
    "    rl_confs = json.load(read_file)\n",
    "\n",
    "env_confs = rl_confs[EnvName]\n",
    "learner_args = env_confs['learner_args']\n",
    "\n",
    "# Override the tree depths before building the model. They have to match the\n",
    "# architecture stored in the checkpoint, otherwise loading the state dict\n",
    "# raises a size-mismatch RuntimeError.\n",
    "learner_args['feature_learning_depth'] = 2\n",
    "learner_args['decision_depth'] = 2\n",
    "\n",
    "# Build the model from the configured arguments and move it to the configured device.\n",
    "model = PPO(\n",
    "    state_dim,\n",
    "    action_dim,\n",
    "    rl_confs[\"General\"][\"policy_approx\"],\n",
    "    learner_args,\n",
    "    **env_confs[\"alg_confs\"]\n",
    ").to(torch.device(learner_args[\"device\"]))\n",
    "\n",
    "# Checkpoint path = configured prefix + run suffix.\n",
    "i = '_22_1'\n",
    "model_path = env_confs[\"train_confs\"][\"model_path\"] + i\n",
    "print(model_path)\n",
    "model.load_model(model_path)\n",
    "\n",
    "# Softmax each row of the decision-tree leaf parameters, then dump all weights.\n",
    "for leaf_row in model.state_dict()['policy.dc_leaves'].detach().cpu().numpy():\n",
    "    print(softmax(leaf_row))\n",
    "print(model.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
