{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "aea037e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn_extra.cluster import KMedoids\n",
    "\n",
    "from joblib import dump, load\n",
    "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn import tree\n",
    "\n",
    "# from arguments.args import get_args\n",
    "# import argparse\n",
    "\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "11f16faa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load q values\n",
    "# NOTE(review): pickle.load can execute code embedded in the file -- only load trusted pickles.\n",
    "# The with-block closes the file even when unpickling raises.\n",
    "with open('q_values_hopper.pkl', 'rb') as file:\n",
    "    data = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "84918ae0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000, 3)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[1]['actions'][1].shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "5ba99e9d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1027"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "fb45ea97",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000,)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(data[1]['q_values'])\n",
    "data[1]['q_values'][1].shape\n",
    "\n",
    "# data[0]['q_values']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "658f2b62",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000, 11)"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[1]['observations'][1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "d0360b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(q_list):\n",
    "    \"\"\"Centre and scale a sequence of Q-values in place.\n",
    "\n",
    "    Each element q is replaced by (q - alpha) / max_p, where max_p and min_p\n",
    "    are the largest and smallest elements and alpha = (max_p + min_p) / 2.\n",
    "\n",
    "    Args:\n",
    "        q_list: mutable 1-D sequence of numbers (e.g. a list). Its elements\n",
    "            are overwritten with the scaled values.\n",
    "\n",
    "    Returns:\n",
    "        The same q_list object. It is returned untouched when it is empty\n",
    "        or when its maximum is 0, since scaling would divide by zero.\n",
    "    \"\"\"\n",
    "    # Empty input: nothing to scale, and max()/min() would raise ValueError.\n",
    "    if len(q_list) == 0:\n",
    "        return q_list\n",
    "\n",
    "    max_p = max(q_list)\n",
    "    # A zero maximum cannot be used as the divisor; leave the values as they are.\n",
    "    if max_p == 0:\n",
    "        return q_list\n",
    "\n",
    "    alpha = (max_p + min(q_list)) / 2\n",
    "    # Callers may rely on the in-place update, so write back into q_list itself.\n",
    "    for i, q in enumerate(q_list):\n",
    "        q_list[i] = (q - alpha) / max_p\n",
    "\n",
    "    return q_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "3dad3d54",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_list, tanh_q_list, obs_list = [], [], []\n",
    "\n",
    "# One pass per index a_i in 0..9; each pass pools the first 40 entries of `data`.\n",
    "for a_i in range(10):\n",
    "    q_val, obs = [], []\n",
    "    for batch in range(40):  # fixed at 40 entries rather than len(data)\n",
    "        entry = data[batch]\n",
    "        q_val.extend(entry['q_values'][a_i])\n",
    "        # one value per row: the mean over axis 1 of the observation block\n",
    "        obs.extend(np.mean(entry['observations'][a_i], axis=1))\n",
    "    q_list.append(q_val)\n",
    "    obs_list.append(obs)\n",
    "    # NOTE(review): normalize() overwrites q_val in place, and q_list holds that same\n",
    "    # list object, so q_list ends up with the normalized values too -- confirm intended.\n",
    "    tanh_q_list.append(np.tanh(normalize(q_val)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "b4ffeda0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(q_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "10de012b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(obs_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "2a8b9d08",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "38698"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(obs_list[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "51cfc2df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "38698"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(q_list[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "f1d0222e",
   "metadata": {},
   "outputs": [],
   "source": [
    "state = KMedoids(n_clusters=10, metric='cosine', random_state=0).fit(np.array(q_list).T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "a375f54d",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = state.labels_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "79b0522b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 2, 2, ..., 9, 9, 9], dtype=int64)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "d301408f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {2: 5889,\n",
       "             0: 28109,\n",
       "             4: 1282,\n",
       "             5: 337,\n",
       "             1: 147,\n",
       "             9: 2306,\n",
       "             7: 331,\n",
       "             3: 37,\n",
       "             6: 126,\n",
       "             8: 134})"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "# Tally how many samples were assigned to each cluster label.\n",
    "stats = defaultdict(int)\n",
    "for cluster_label in labels:\n",
    "    stats[cluster_label] = stats[cluster_label] + 1\n",
    "\n",
    "stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "1e4c6348",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-6 {color: black;background-color: white;}#sk-container-id-6 pre{padding: 0;}#sk-container-id-6 div.sk-toggleable {background-color: white;}#sk-container-id-6 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-6 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-6 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-6 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-6 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-6 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-6 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-6 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-6 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-6 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-6 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-6 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-6 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-6 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-6 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-6 div.sk-item {position: relative;z-index: 1;}#sk-container-id-6 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-6 div.sk-item::before, #sk-container-id-6 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-6 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-6 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-6 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-6 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-6 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-6 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-6 div.sk-label-container {text-align: center;}#sk-container-id-6 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-6 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-6\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>KMedoids(metric=&#x27;cosine&#x27;, n_clusters=10, random_state=0)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-6\" type=\"checkbox\" checked><label for=\"sk-estimator-id-6\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">KMedoids</label><div class=\"sk-toggleable__content\"><pre>KMedoids(metric=&#x27;cosine&#x27;, n_clusters=10, random_state=0)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "KMedoids(metric='cosine', n_clusters=10, random_state=0)"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "700d128a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, 38698)"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(obs_list).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "d5c7155f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(38698,)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "c15be0c5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# data prep\n",
    "X_train, X_test, y_train, y_test = train_test_split(np.array(obs_list).T, labels, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "b6d19735",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-7 {color: black;background-color: white;}#sk-container-id-7 pre{padding: 0;}#sk-container-id-7 div.sk-toggleable {background-color: white;}#sk-container-id-7 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-7 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-7 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-7 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-7 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-7 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-7 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-7 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-7 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-7 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-7 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-7 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-7 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-7 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-7 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-7 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-7 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-7 div.sk-item {position: relative;z-index: 1;}#sk-container-id-7 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-7 div.sk-item::before, #sk-container-id-7 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-7 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-7 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-7 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-7 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-7 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-7 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-7 div.sk-label-container {text-align: center;}#sk-container-id-7 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-7 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-7\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GradientBoostingClassifier(max_leaf_nodes=40)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-7\" type=\"checkbox\" checked><label for=\"sk-estimator-id-7\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GradientBoostingClassifier</label><div class=\"sk-toggleable__content\"><pre>GradientBoostingClassifier(max_leaf_nodes=40)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "GradientBoostingClassifier(max_leaf_nodes=40)"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# train\n",
    "# NOTE: despite the variable name, this is a gradient-boosted tree ensemble, not an SVM.\n",
    "# random_state is fixed so that refitting on the seeded split gives the same model.\n",
    "tree_with_svm = GradientBoostingClassifier(max_leaf_nodes=40, random_state=42)\n",
    "tree_with_svm.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "48130883",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['.\\\\tree_with_svm_hop.joblib']"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# save model\n",
    "dump(tree_with_svm, os.path.join('.', 'tree_with_svm_hop.joblib'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "7787c007",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test\n",
    "tree_with_svm = load(os.path.join('.', 'tree_with_svm_hop.joblib'))           \n",
    "pred = tree_with_svm.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "4c3e6264",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {0: 7525,\n",
       "             4: 16,\n",
       "             5: 37,\n",
       "             8: 27,\n",
       "             9: 19,\n",
       "             7: 48,\n",
       "             6: 27,\n",
       "             1: 26,\n",
       "             2: 6,\n",
       "             3: 9})"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tally how often each label is predicted on the test split.\n",
    "action_stats = defaultdict(int)\n",
    "for predicted_label in pred:\n",
    "    action_stats[predicted_label] = action_stats[predicted_label] + 1\n",
    "\n",
    "action_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a72f0584",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "ec562d31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# baseline training\n",
    "# data prep\n",
    "# X_train, X_test, y_train, y_test = train_test_split(np.array(obs_list).T, labels, test_size=0.2, random_state=42)\n",
    "# For baselines\n",
    "X_train_bs, X_test_bs, y_train_bs, y_test_bs = train_test_split(np.array(obs_list).T, labels, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "ed4cae2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-8 {color: black;background-color: white;}#sk-container-id-8 pre{padding: 0;}#sk-container-id-8 div.sk-toggleable {background-color: white;}#sk-container-id-8 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-8 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-8 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-8 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-8 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-8 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-8 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-8 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-8 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-8 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-8 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-8 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-8 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-8 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-8 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-8 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-8 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-8 div.sk-item {position: relative;z-index: 1;}#sk-container-id-8 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-8 div.sk-item::before, #sk-container-id-8 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-8 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-8 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-8 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-8 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-8 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-8 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-8 div.sk-label-container {text-align: center;}#sk-container-id-8 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-8 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-8\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>ExtraTreesClassifier(max_leaf_nodes=40)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-8\" type=\"checkbox\" checked><label for=\"sk-estimator-id-8\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">ExtraTreesClassifier</label><div class=\"sk-toggleable__content\"><pre>ExtraTreesClassifier(max_leaf_nodes=40)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "ExtraTreesClassifier(max_leaf_nodes=40)"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baselines: a single CART tree, a random forest and an extra-trees ensemble,\n",
    "# each limited to 40 leaves per tree. The estimator classes are imported in the\n",
    "# first cell, so they are not re-imported here.\n",
    "# Fixed seeds make the fitted baselines repeatable, like the seeded split they train on.\n",
    "tree_with_cart = DecisionTreeClassifier(max_leaf_nodes=40, random_state=42)\n",
    "tree_with_cart.fit(X_train_bs, y_train_bs)\n",
    "\n",
    "tree_with_rf = RandomForestClassifier(max_leaf_nodes=40, random_state=42)\n",
    "tree_with_rf.fit(X_train_bs, y_train_bs)\n",
    "\n",
    "tree_with_et = ExtraTreesClassifier(max_leaf_nodes=40, random_state=42)\n",
    "tree_with_et.fit(X_train_bs, y_train_bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "6c913edc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['.\\\\tree_with_et_hop.joblib']"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# save model\n",
    "dump(tree_with_cart, os.path.join('.', 'tree_with_cart_hop.joblib'))\n",
    "dump(tree_with_rf, os.path.join('.', 'tree_with_rf_hop.joblib'))\n",
    "dump(tree_with_et, os.path.join('.', 'tree_with_et_hop.joblib'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c095f1f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "1852b8b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load q values (halfcheetah); this rebinds `data`, replacing the hopper data\n",
    "# NOTE(review): pickle.load can execute code embedded in the file -- only load trusted pickles.\n",
    "# The with-block closes the file even when unpickling raises.\n",
    "with open('q_values_halfcheetah.pkl', 'rb') as file:\n",
    "    data = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "288edca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_list_, tanh_q_list_, obs_list_ = [], [], []\n",
    "\n",
    "# One pass per index a_i in 0..9; each pass pools the first 40 entries of `data`.\n",
    "for a_i in range(10):\n",
    "    q_val, obs = [], []\n",
    "    for batch in range(40):  # fixed at 40 entries rather than len(data)\n",
    "        entry = data[batch]\n",
    "        q_val.extend(entry['q_values'][a_i])\n",
    "        # one value per row: the mean over axis 1 of the observation block\n",
    "        obs.extend(np.mean(entry['observations'][a_i], axis=1))\n",
    "    q_list_.append(q_val)\n",
    "    obs_list_.append(obs)\n",
    "    # NOTE(review): normalize() overwrites q_val in place, and q_list_ holds that same\n",
    "    # list object, so q_list_ ends up with the normalized values too -- confirm intended.\n",
    "    tanh_q_list_.append(np.tanh(normalize(q_val)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "a97039fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cluster the halfcheetah Q-values (q_list_), not the hopper ones (q_list).\n",
    "state = KMedoids(n_clusters=10, metric='cosine', random_state=0).fit(np.array(q_list_).T)\n",
    "labels = state.labels_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "53c4ebad",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# data prep\n",
    "# Features must come from the halfcheetah observations (obs_list_) so they line up\n",
    "# with the halfcheetah cluster labels computed in the previous cell.\n",
    "X_train, X_test, y_train, y_test = train_test_split(np.array(obs_list_).T, labels, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "1c7061a3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-9 {color: black;background-color: white;}#sk-container-id-9 pre{padding: 0;}#sk-container-id-9 div.sk-toggleable {background-color: white;}#sk-container-id-9 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-9 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-9 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-9 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-9 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-9 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-9 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-9 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-9 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-9 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-9 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-9 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-9 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-9 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-9 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-9 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-9 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-9 div.sk-item {position: relative;z-index: 1;}#sk-container-id-9 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-9 div.sk-item::before, #sk-container-id-9 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-9 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-9 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-9 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-9 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-9 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-9 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-9 div.sk-label-container {text-align: center;}#sk-container-id-9 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-9 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-9\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GradientBoostingClassifier(max_leaf_nodes=40)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-9\" type=\"checkbox\" checked><label for=\"sk-estimator-id-9\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GradientBoostingClassifier</label><div class=\"sk-toggleable__content\"><pre>GradientBoostingClassifier(max_leaf_nodes=40)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "GradientBoostingClassifier(max_leaf_nodes=40)"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# train\n",
    "tree_with_svm = GradientBoostingClassifier(max_leaf_nodes=40)\n",
    "tree_with_svm.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "8221a81a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['.\\\\tree_with_svm_half.joblib']"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# save model\n",
    "dump(tree_with_svm, os.path.join('.', 'tree_with_svm_half.joblib'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "b034b6dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test\n",
    "tree_with_svm = load(os.path.join('.', 'tree_with_svm_half.joblib'))           \n",
    "pred = tree_with_svm.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "d0e5b765",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, ..., 0, 0, 0], dtype=int64)"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a9560a0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "f88f4029",
   "metadata": {},
   "outputs": [],
   "source": [
    "# baseline training\n",
    "# data prep\n",
    "X_train, X_test, y_train, y_test = train_test_split(np.array(obs_list).T, labels, test_size=0.2, random_state=42)\n",
    "# For baselines\n",
    "X_train_bs, X_test_bs, y_train_bs, y_test_bs = train_test_split(np.array(obs_list).T, labels, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "1aa2511c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-10 {color: black;background-color: white;}#sk-container-id-10 pre{padding: 0;}#sk-container-id-10 div.sk-toggleable {background-color: white;}#sk-container-id-10 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-10 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-10 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-10 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-10 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-10 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-10 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-10 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-10 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-10 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-10 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-10 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-10 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-10 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-10 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-10 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-10 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-10 div.sk-item {position: relative;z-index: 1;}#sk-container-id-10 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-10 div.sk-item::before, #sk-container-id-10 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-10 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-10 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-10 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-10 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-10 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-10 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-10 div.sk-label-container {text-align: center;}#sk-container-id-10 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-10 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-10\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>ExtraTreesClassifier(max_leaf_nodes=40)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-10\" type=\"checkbox\" checked><label for=\"sk-estimator-id-10\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">ExtraTreesClassifier</label><div class=\"sk-toggleable__content\"><pre>ExtraTreesClassifier(max_leaf_nodes=40)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "ExtraTreesClassifier(max_leaf_nodes=40)"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree_with_cart = DecisionTreeClassifier(max_leaf_nodes=40)\n",
    "tree_with_cart.fit(X_train_bs, y_train_bs)\n",
    "\n",
    "tree_with_rf = RandomForestClassifier(max_leaf_nodes=40)\n",
    "tree_with_rf.fit(X_train_bs, y_train_bs)\n",
    "\n",
    "tree_with_et = ExtraTreesClassifier(max_leaf_nodes=40)\n",
    "tree_with_et.fit(X_train_bs, y_train_bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "36be7563",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['.\\\\tree_with_et_half.joblib']"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# save model\n",
    "dump(tree_with_cart, os.path.join('.', 'tree_with_cart_half.joblib'))\n",
    "dump(tree_with_rf, os.path.join('.', 'tree_with_rf_half.joblib'))\n",
    "dump(tree_with_et, os.path.join('.', 'tree_with_et_half.joblib'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f467b08",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0570bdc9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
