{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tree Stability Analysis\n",
    "\n",
    "For imitation learning only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from src.cdt import CDT \n",
    "from src.sdt import SDT\n",
    "\n",
    "def normalize(list_v):\n",
    "    normalized_list = []\n",
    "    for v in list_v:\n",
    "        if np.sum(np.abs(v)) == 0:\n",
    "            continue\n",
    "        else:\n",
    "            v =np.array(v)/np.sum(np.abs(v))\n",
    "        normalized_list.append(v)\n",
    "    return normalized_list\n",
    "\n",
    "def l1_norm(a,b):\n",
    "    '''\n",
    "    Return the L1-norm distance of two vectors\n",
    "    '''\n",
    "    return np.linalg.norm(np.array(a)-np.array(b), ord=1)\n",
    "\n",
    "def difference_metric(list1, list2, norm=True, symmetric=True):\n",
    "    '''\n",
    "    Sign-invariant nearest-neighbour L1 difference between two lists of vectors.\n",
    "\n",
    "    Every vector of one list is matched to the closest vector (or negated vector)\n",
    "    of the other list, and the matched distances are averaged.\n",
    "    norm: pass both lists through normalize() first.\n",
    "    symmetric: average the list1->list2 and list2->list1 directions;\n",
    "               otherwise only the list1->list2 direction is returned.\n",
    "    '''\n",
    "    if norm:\n",
    "        list1, list2 = normalize(list1), normalize(list2)\n",
    "\n",
    "    def directed_difference(sources, targets):\n",
    "        # for each source vector keep the distance to its closest target, up to sign\n",
    "        nearest = [\n",
    "            np.min([np.minimum(l1_norm(src, tgt), l1_norm(src, -1.*np.array(tgt)))\n",
    "                    for tgt in targets])\n",
    "            for src in sources\n",
    "        ]\n",
    "        return np.mean(nearest)\n",
    "\n",
    "    forward = directed_difference(list1, list2)\n",
    "    if not symmetric:\n",
    "        return forward\n",
    "    return 0.5*forward + 0.5*directed_difference(list2, list1)\n",
    "\n",
    "# test\n",
    "# a=[[1,2], [1.25, 2.25]]\n",
    "# b=[[1.5,2.5], [4, 6]]\n",
    "# difference_metric(a,b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SDT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CartPole-v1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2 3 4 5]\n",
      "SDT parameters:  {'input_dim': 4, 'output_dim': 2, 'depth': 3, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/cartpole/il_model', 'log_path': '../data/sdt/log/cartpole/il_log'}\n",
      "SDT parameters:  {'input_dim': 4, 'output_dim': 2, 'depth': 3, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/cartpole/il_model', 'log_path': '../data/sdt/log/cartpole/il_log'}\n",
      "SDT parameters:  {'input_dim': 4, 'output_dim': 2, 'depth': 3, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/cartpole/il_model', 'log_path': '../data/sdt/log/cartpole/il_log'}\n",
      "SDT parameters:  {'input_dim': 4, 'output_dim': 2, 'depth': 3, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/cartpole/il_model', 'log_path': '../data/sdt/log/cartpole/il_log'}\n",
      "SDT parameters:  {'input_dim': 4, 'output_dim': 2, 'depth': 3, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/cartpole/il_model', 'log_path': '../data/sdt/log/cartpole/il_log'}\n",
      "[0.18239866942167282, 0.22911859303712845, 0.2876540273427963, 0.19550787657499313, 0.236889086663723, 0.20481963455677032, 0.09203776717185974, 0.19447296112775803, 0.26314952224493027, 0.20961599797010422]\n",
      "0.20956641361117362\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "EnvName = 'CartPole-v1'\n",
    "m = 'sdt'\n",
    "n = 5  # number of trained runs to compare\n",
    "\n",
    "# hyperparameters that were used for il training\n",
    "conf_path = '../src/' + m + '/' + m + '_il_train.json'\n",
    "with open(conf_path, \"r\") as read_file:\n",
    "    il_confs = json.load(read_file)\n",
    "learner_args = il_confs[EnvName][\"learner_args\"]\n",
    "\n",
    "# collect the tree weights of every run; checkpoints are suffixed 1..n\n",
    "weights_list = []\n",
    "for i in range(1, n+1):\n",
    "    model_path = learner_args[\"model_path\"] + str(i)\n",
    "    device = torch.device(learner_args[\"device\"])\n",
    "    tree = SDT(learner_args).to(device)\n",
    "    tree.load_model(model_path)\n",
    "    weights = tree.get_tree_weights(Bias=True)\n",
    "    weights_list.append(weights)\n",
    "\n",
    "# score every unordered pair of runs\n",
    "similarity_score = [\n",
    "    difference_metric(weights_a, weights_b)\n",
    "    for idx_a, weights_a in enumerate(weights_list)\n",
    "    for weights_b in weights_list[idx_a+1:]\n",
    "]\n",
    "\n",
    "print(similarity_score)\n",
    "print(np.mean(similarity_score))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LunarLander-v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "0.49667860716581347 0.9173517346382141\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "EnvName = 'LunarLander-v2'\n",
    "m = 'sdt'\n",
    "n = 5  # number of trained runs to compare\n",
    "\n",
    "# hyperparameters that were used for il training\n",
    "conf_path = '../src/' + m + '/' + m + '_il_train.json'\n",
    "with open(conf_path, \"r\") as read_file:\n",
    "    il_confs = json.load(read_file)\n",
    "learner_args = il_confs[EnvName][\"learner_args\"]\n",
    "\n",
    "# collect the tree weights of every run; checkpoints are suffixed 1..n\n",
    "weights_list = []\n",
    "for i in range(1, n+1):\n",
    "    model_path = learner_args[\"model_path\"] + str(i)\n",
    "    device = torch.device(learner_args[\"device\"])\n",
    "    tree = SDT(learner_args).to(device)\n",
    "    tree.load_model(model_path)\n",
    "    weights = tree.get_tree_weights(Bias=True)\n",
    "    weights_list.append(weights)\n",
    "\n",
    "# baseline: a tree that never loads a checkpoint\n",
    "tree = SDT(learner_args).to(device)\n",
    "random_weights = tree.get_tree_weights(Bias=True)\n",
    "\n",
    "# every unordered pair of runs\n",
    "similarity_score = [\n",
    "    difference_metric(weights_a, weights_b)\n",
    "    for idx_a, weights_a in enumerate(weights_list)\n",
    "    for weights_b in weights_list[idx_a+1:]\n",
    "]\n",
    "\n",
    "# every run against the baseline\n",
    "random_similarity_score = [\n",
    "    difference_metric(run_weights, random_weights) for run_weights in weights_list\n",
    "]\n",
    "\n",
    "print(np.mean(similarity_score), np.mean(random_similarity_score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LunarLander Heuristic Decision Tree Agent Nodes\n",
    "# Hand-written table of decision nodes, one row per node.\n",
    "# Every row has 9 entries: the bias first, then 8 weights.\n",
    "# NOTE(review): the 8 weights presumably follow the order of the 8-dim LunarLander observation - confirm.\n",
    "nodes_in_heuristic_tree = [  # first dim is bias, the rest are weights\n",
    "    [0, 0,0,0,0,0,0,1,1],\n",
    "\n",
    "    [-0.4, 0.5, 0,1,0,0,0,0,0],\n",
    "    [-0.4, -0.5, 0,-1,0,0,0,0,0],\n",
    "    [0, 1,0,0,0,0,0,0,0],\n",
    "\n",
    "    # at\n",
    "    # NOTE(review): 'at' presumably stands for the angle term of the heuristic controller - TODO confirm\n",
    "    [0.2, 0,0,0,0,-0.5,-1,0,0],\n",
    "    [0.15, 0,0,0,0,-0.5,-1,0,0],\n",
    "    [-0.25, 0,0,0,0,0.5,1,0,0],\n",
    "\n",
    "    [-0.2, 0,0,0,0,-0.5,-1,0,0],\n",
    "    [-0.25, 0,0,0,0,-0.5,-1,0,0],\n",
    "    [0.15, 0,0,0,0,0.5,1,0,0],\n",
    "\n",
    "\n",
    "    [0, 0.25, 0, 0.5, 0, -0.5, -1, 0, 0 ],\n",
    "    [-0.05, 0.25, 0, 0.5, 0, -0.5, -1, 0, 0 ],\n",
    "    [-0.05, -0.25, 0, -0.5, 0, 0.5, 1, 0, 0 ],\n",
    "\n",
    "\n",
    "    # ht\n",
    "    # NOTE(review): 'ht' presumably stands for the hover term of the heuristic controller - TODO confirm\n",
    "    [-0.05, 0.275, -0.5, 0, -0.5, 0,0,0,0],\n",
    "\n",
    "    [-0.05, -0.275, -0.5, 0, -0.5, 0,0,0,0],\n",
    "\n",
    "    [-0.05, 0, 0, 0, -0.5, 0, 0, 0, 0],\n",
    "\n",
    "    # at, ht\n",
    "    # rows below put non-zero weights on both groups of inputs used by the two sections above\n",
    "    [-0.2, 0.275, -0.5, 0,-0.5, 0.5,1,0,0],\n",
    "    [0.2, 0.275, -0.5, 0,-0.5, -0.5, -1, 0,0],\n",
    "\n",
    "    [-0.2, -0.275, -0.5, 0,-0.5, 0.5,1,0,0],\n",
    "    [0.2, -0.275, -0.5, 0,-0.5, -0.5, -1, 0,0],\n",
    "\n",
    "    [0.2, 0.275, -0.5, 0,-0.5, 0.5,1,0,0],\n",
    "    [-0.2, 0.275, -0.5, 0,-0.5, -0.5, -1, 0,0],\n",
    "\n",
    "    [0.2, -0.275, -0.5, 0,-0.5, 0.5,1,0,0],\n",
    "    [-0.2, -0.275, -0.5, 0,-0.5, -0.5, -1, 0,0],\n",
    "\n",
    "    [0, 0.025, -0.5, -0.5, -0.5, 0.5, 1, 0, 0],\n",
    "    [0, 0.525, -0.5, 0.5, -0.5, -0.5, -1, 0, 0],\n",
    "\n",
    "    [0, -0.525, -0.5, -0.5, -0.5, 0.5, 1, 0, 0],\n",
    "    [0, -0.025, -0.5, 0.5, -0.5, -0.5, -1, 0, 0],\n",
    "\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2 3 4 5]\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "SDT parameters:  {'input_dim': 8, 'output_dim': 4, 'depth': 4, 'lamda': 0.001, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'epochs': 80, 'device': 'cuda', 'log_interval': 100, 'exp_scheduler_gamma': 1.0, 'beta': 0, 'l1_regularization': 0, 'greatest_path_probability': 1, 'model_path': '../data/sdt/model/lunarlander/il_model', 'log_path': '../data/sdt/log/lunarlander/il_log'}\n",
      "0.8404764748671589\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "EnvName = 'LunarLander-v2'\n",
    "m = 'sdt'\n",
    "n = 5  # number of trained runs to compare\n",
    "\n",
    "# hyperparameters that were used for il training\n",
    "conf_path = '../src/' + m + '/' + m + '_il_train.json'\n",
    "with open(conf_path, \"r\") as read_file:\n",
    "    il_confs = json.load(read_file)\n",
    "learner_args = il_confs[EnvName][\"learner_args\"]\n",
    "\n",
    "# collect the tree weights of every run; checkpoints are suffixed 1..n\n",
    "weights_list = []\n",
    "for i in range(1, n+1):\n",
    "    model_path = learner_args[\"model_path\"] + str(i)\n",
    "    device = torch.device(learner_args[\"device\"])\n",
    "    tree = SDT(learner_args).to(device)\n",
    "    tree.load_model(model_path)\n",
    "    weights = tree.get_tree_weights(Bias=True)\n",
    "    weights_list.append(weights)\n",
    "\n",
    "similarity_score = []  # not filled in this cell\n",
    "\n",
    "# every run against the hand-written heuristic nodes (nodes_in_heuristic_tree must already be defined)\n",
    "heuristic_similarity_score = [\n",
    "    difference_metric(run_weights, nodes_in_heuristic_tree) for run_weights in weights_list\n",
    "]\n",
    "\n",
    "print(np.mean(heuristic_similarity_score))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CDT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CartPole-v1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "0.38250247836112977 0.5391488242894411\n",
      "1.1868854403495788 0.7651692867279053\n",
      "0.4608256513252854\n",
      "0.9760273635387421\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "EnvName = 'CartPole-v1'\n",
    "m = 'cdt'\n",
    "n = 5  # number of trained runs to compare\n",
    "\n",
    "# hyperparameters that were used for il training\n",
    "conf_path = '../src/' + m + '/' + m + '_il_train.json'\n",
    "with open(conf_path, \"r\") as read_file:\n",
    "    il_confs = json.load(read_file)\n",
    "learner_args = il_confs[EnvName][\"learner_args\"]\n",
    "\n",
    "# the two weight sets returned by every run; checkpoints are suffixed 1..n\n",
    "fl_weights_list = []\n",
    "dm_weights_list = []\n",
    "for i in range(1, n+1):\n",
    "    model_path = learner_args[\"model_path\"] + str(i)\n",
    "    device = torch.device(learner_args[\"device\"])\n",
    "    tree = CDT(learner_args).to(device)\n",
    "    tree.load_model(model_path)\n",
    "    fl_weights, dm_weights = tree.get_tree_weights(Bias=True)\n",
    "    fl_weights_list.append(fl_weights)\n",
    "    dm_weights_list.append(dm_weights)\n",
    "\n",
    "# baseline: a tree that never loads a checkpoint\n",
    "tree = CDT(learner_args).to(device)\n",
    "random_fl_weights, random_dm_weights = tree.get_tree_weights(Bias=True)\n",
    "\n",
    "# every unordered pair of runs, as 0-based indices\n",
    "run_pairs = [(a, b) for a in range(n) for b in range(a+1, n)]\n",
    "fl_similarity_score = [difference_metric(fl_weights_list[a], fl_weights_list[b]) for a, b in run_pairs]\n",
    "dm_similarity_score = [difference_metric(dm_weights_list[a], dm_weights_list[b]) for a, b in run_pairs]\n",
    "\n",
    "# every run against the baseline\n",
    "random_fl_similarity_score = [difference_metric(fl_weights_list[k], random_fl_weights) for k in range(n)]\n",
    "random_dm_similarity_score = [difference_metric(dm_weights_list[k], random_dm_weights) for k in range(n)]\n",
    "\n",
    "print(np.mean(fl_similarity_score), np.mean(dm_similarity_score))\n",
    "print(np.mean(random_fl_similarity_score), np.mean(random_dm_similarity_score))\n",
    "\n",
    "# '+' concatenates the lists, so these are means over the pooled fl and dm scores\n",
    "print(np.mean(fl_similarity_score+dm_similarity_score))\n",
    "print(np.mean(random_fl_similarity_score+random_dm_similarity_score))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LunarLander-v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 2, 'decision_depth': 2, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "0.7915976077318192 0.6702044695615769\n",
      "1.0348077774047852 0.559660267829895\n",
      "0.730901038646698\n",
      "0.7972340226173401\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "EnvName = 'LunarLander-v2'\n",
    "m = 'cdt'\n",
    "n = 5  # number of trained runs to compare\n",
    "\n",
    "# hyperparameters that were used for il training\n",
    "conf_path = '../src/' + m + '/' + m + '_il_train.json'\n",
    "with open(conf_path, \"r\") as read_file:\n",
    "    il_confs = json.load(read_file)\n",
    "learner_args = il_confs[EnvName][\"learner_args\"]\n",
    "\n",
    "# the two weight sets returned by every run; checkpoints are suffixed 1..n\n",
    "fl_weights_list = []\n",
    "dm_weights_list = []\n",
    "for i in range(1, n+1):\n",
    "    model_path = learner_args[\"model_path\"] + str(i)\n",
    "    device = torch.device(learner_args[\"device\"])\n",
    "    tree = CDT(learner_args).to(device)\n",
    "    tree.load_model(model_path)\n",
    "    fl_weights, dm_weights = tree.get_tree_weights(Bias=True)\n",
    "    fl_weights_list.append(fl_weights)\n",
    "    dm_weights_list.append(dm_weights)\n",
    "\n",
    "# baseline: a tree that never loads a checkpoint\n",
    "tree = CDT(learner_args).to(device)\n",
    "random_fl_weights, random_dm_weights = tree.get_tree_weights(Bias=True)\n",
    "\n",
    "# every unordered pair of runs, as 0-based indices\n",
    "run_pairs = [(a, b) for a in range(n) for b in range(a+1, n)]\n",
    "fl_similarity_score = [difference_metric(fl_weights_list[a], fl_weights_list[b]) for a, b in run_pairs]\n",
    "dm_similarity_score = [difference_metric(dm_weights_list[a], dm_weights_list[b]) for a, b in run_pairs]\n",
    "\n",
    "# every run against the baseline\n",
    "random_fl_similarity_score = [difference_metric(fl_weights_list[k], random_fl_weights) for k in range(n)]\n",
    "random_dm_similarity_score = [difference_metric(dm_weights_list[k], random_dm_weights) for k in range(n)]\n",
    "\n",
    "print(np.mean(fl_similarity_score), np.mean(dm_similarity_score))\n",
    "print(np.mean(random_fl_similarity_score), np.mean(random_dm_similarity_score))\n",
    "\n",
    "# '+' concatenates the lists, so these are means over the pooled fl and dm scores\n",
    "print(np.mean(fl_similarity_score+dm_similarity_score))\n",
    "print(np.mean(random_fl_similarity_score+random_dm_similarity_score))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Another way of calculating DM weights (based on raw input $x$)\n",
    "**Adopted in the paper**.\n",
    "\n",
    "Since the DM weights are represented in $\\{f\\}$-space (the features), we want to transform them back to $\\{x\\}$-space (the raw inputs), so as to compare with normal SDT.\n",
    "\n",
    "Parameters:\n",
    "\n",
    "* R: raw feature dimension\n",
    "* K: intermediate feature dimension\n",
    "* O: output dimension\n",
    "* N: number of leaf nodes in feature learning tree\n",
    "* M: number of inner nodes in decision making tree\n",
    "\n",
    "$\\tilde{W}_{N\\times K\\times R}\\cdot W'_{M\\times K} \\rightarrow W^{raw}_{N\\times M \\times R}$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CartPole-v1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 1, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "(7, 5)\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 1, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "(7, 5)\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 1, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "(7, 5)\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 1, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "(7, 5)\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 1, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "(7, 5)\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 1, 'decision_depth': 2, 'input_dim': 4, 'output_dim': 2, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/cartpole/il_model', 'log_path': '../data/cdt/log/cartpole/il_log'}\n",
      "0.0764705200213939\n",
      "1.0942245364189147\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "EnvName = 'CartPole-v1'\n",
    "m = 'cdt'\n",
    "n = 5  # number of trained runs to compare\n",
    "\n",
    "# hyperparameters that were used for il training\n",
    "conf_path = '../src/' + m + '/' + m + '_il_train.json'\n",
    "with open(conf_path, \"r\") as conf_file:\n",
    "    il_confs = json.load(conf_file)\n",
    "\n",
    "weights_list = []  # filled further down with one entry per run\n",
    "\n",
    "def get_all_weights(tree):\n",
    "    f_weights = tree.get_feature_weights()\n",
    "#     print(f_weights.shape)\n",
    "    f_weights = np.moveaxis(f_weights, 1, -1)  # (N, K, R) -> (N, R, K)\n",
    "#     print(f_weights.shape)\n",
    "    fl_weights, dm_weights = tree.get_tree_weights(Bias=True)    \n",
    "    dm_bias = dm_weights[:, 0]\n",
    "    dm_weights_no_bias = np.moveaxis(dm_weights[:, 1:], 0, 1)  # (M, K) -> (K, M)\n",
    "#     print(dm_weights_no_bias.shape, f_weights.shape)\n",
    "    dm_in_x = np.moveaxis(f_weights@dm_weights_no_bias, -1, 0)  # (N, R, K) \\times (K, M) -> (N, R, M) -> (M, N, R)\n",
    "    dm_bias = np.repeat(dm_bias, tree.num_fl_leaves, axis=0).reshape(-1, tree.num_fl_leaves, 1)  # (M,) -> (M, N, 1)\n",
    "    dm_in_x = np.concatenate((dm_bias, dm_in_x), axis=-1)  # (M, N, R)+(M, N, 1) -> (M, N, R+1)\n",
    "    dm_in_x = dm_in_x.reshape(-1, dm_in_x.shape[-1])  # (M, N, R+1) -> (M \\times N, R+1)\n",
    "#     print(dm_in_x.shape)\n",
    "#     print(fl_weights.shape)\n",
    "    return np.concatenate((fl_weights, dm_in_x), axis=0)\n",
    "    \n",
    "for i in range(1, n+1):\n",
    "#         print(il_confs[EnvName][\"learner_args\"])\n",
    "    model_path = il_confs[EnvName][\"learner_args\"][\"model_path\"]+str(i)\n",
    "    device = torch.device(il_confs[EnvName][\"learner_args\"][\"device\"])\n",
    "    tree = CDT(il_confs[EnvName][\"learner_args\"]).to(device)\n",
    "    tree.load_model(model_path)\n",
    "    all_weights = get_all_weights(tree)\n",
    "    print(all_weights.shape)\n",
    "\n",
    "    weights_list.append(all_weights)\n",
    "    \n",
    "tree = CDT(il_confs[EnvName][\"learner_args\"]).to(device)\n",
    "random_weights = get_all_weights(tree)\n",
    "\n",
    "similarity_score=[]\n",
    "random_similarity_score=[]\n",
    "\n",
    "# loop through all possible pairs\n",
    "for i in range(1, n):\n",
    "    for j in range(i+1, n+1):\n",
    "        similarity_score.append(difference_metric(weights_list[i-1], weights_list[j-1]))\n",
    "        \n",
    "for i in range(1, n+1):       \n",
    "    random_similarity_score.append(difference_metric(weights_list[i-1], random_weights))\n",
    "\n",
    "print(np.mean(similarity_score))      \n",
    "print(np.mean(random_similarity_score))   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LunarLander-v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 3, 'decision_depth': 3, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "(63, 9)\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 3, 'decision_depth': 3, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "(63, 9)\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 3, 'decision_depth': 3, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "(63, 9)\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 3, 'decision_depth': 3, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "(63, 9)\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 3, 'decision_depth': 3, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "(63, 9)\n",
      "CDT parameters:  {'num_intermediate_variables': 2, 'feature_learning_depth': 3, 'decision_depth': 3, 'input_dim': 8, 'output_dim': 4, 'lr': 0.001, 'weight_decay': 0.0, 'batch_size': 1280, 'exp_scheduler_gamma': 1.0, 'device': 'cuda', 'epochs': 80, 'log_interval': 100, 'greatest_path_probability': 1, 'beta_fl': 0, 'beta_dc': 0, 'model_path': '../data/cdt/model/lunarlander/il_model', 'log_path': '../data/cdt/log/lunarlander/il_log'}\n",
      "0.5340688943862915\n",
      "0.7743854939937591\n",
      "0.8673055873743681\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "EnvName = 'LunarLander-v2'\n",
    "m = 'cdt'\n",
    "n = 5 # number of runs\n",
    "\n",
    "conf_path = '../src/'+m+'/'+m+'_il_train.json'\n",
    "with open(conf_path, \"r\") as read_file:\n",
    "    il_confs = json.load(read_file)  # hyperparameters for il training\n",
    "#         print(il_confs)\n",
    "\n",
    "weights_list = []\n",
    "\n",
    "def get_all_weights(tree):\n",
    "    f_weights = tree.get_feature_weights()\n",
    "#     print(f_weights.shape)\n",
    "    f_weights = np.moveaxis(f_weights, 1, -1)  # (N, K, R) -> (N, R, K)\n",
    "#     print(f_weights.shape)\n",
    "    fl_weights, dm_weights = tree.get_tree_weights(Bias=True)    \n",
    "    dm_bias = dm_weights[:, 0]\n",
    "    dm_weights_no_bias = np.moveaxis(dm_weights[:, 1:], 0, 1)  # (M, K) -> (K, M)\n",
    "#     print(dm_weights_no_bias.shape, f_weights.shape)\n",
    "    dm_in_x = np.moveaxis(f_weights@dm_weights_no_bias, -1, 0)  # (N, R, K) \\times (K, M) -> (N, R, M) -> (M, N, R)\n",
    "    dm_bias = np.repeat(dm_bias, tree.num_fl_leaves, axis=0).reshape(-1, tree.num_fl_leaves, 1)  # (M,) -> (M, N, 1)\n",
    "    dm_in_x = np.concatenate((dm_bias, dm_in_x), axis=-1)  # (M, N, R)+(M, N, 1) -> (M, N, R+1)\n",
    "    dm_in_x = dm_in_x.reshape(-1, dm_in_x.shape[-1])  # (M, N, R+1) -> (M \\times N, R+1)\n",
    "#     print(dm_in_x.shape)\n",
    "#     print(fl_weights.shape)\n",
    "    return np.concatenate((fl_weights, dm_in_x), axis=0)\n",
    "    \n",
    "for i in range(1, n+1):\n",
    "#         print(il_confs[EnvName][\"learner_args\"])\n",
    "    model_path = il_confs[EnvName][\"learner_args\"][\"model_path\"]+str(i)\n",
    "    device = torch.device(il_confs[EnvName][\"learner_args\"][\"device\"])\n",
    "    tree = CDT(il_confs[EnvName][\"learner_args\"]).to(device)\n",
    "    tree.load_model(model_path)\n",
    "    all_weights = get_all_weights(tree)\n",
    "    print(all_weights.shape)\n",
    "\n",
    "    weights_list.append(all_weights)\n",
    "    \n",
    "tree = CDT(il_confs[EnvName][\"learner_args\"]).to(device)\n",
    "random_weights = get_all_weights(tree)\n",
    "\n",
    "similarity_score=[]\n",
    "random_similarity_score=[]\n",
    "heuristic_similarity_score=[]\n",
    "\n",
    "# loop through all possible pairs\n",
    "for i in range(1, n):\n",
    "    for j in range(i+1, n+1):\n",
    "        similarity_score.append(difference_metric(weights_list[i-1], weights_list[j-1]))\n",
    "        \n",
    "for i in range(1, n+1):       \n",
    "    random_similarity_score.append(difference_metric(weights_list[i-1], random_weights))\n",
    "    heuristic_similarity_score.append(difference_metric(weights_list[i-1], nodes_in_heuristic_tree))\n",
    "\n",
    "print(np.mean(similarity_score))      \n",
    "print(np.mean(random_similarity_score))   \n",
    "print(np.mean(heuristic_similarity_score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
