{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "aea037e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn_extra.cluster import KMedoids\n",
    "\n",
    "from joblib import dump, load\n",
    "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn import tree\n",
    "\n",
    "# from arguments.args import get_args\n",
    "# import argparse\n",
    "\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "11f16faa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load q values\n",
    "file = open('q_values_hopper.pkl', 'rb')\n",
    "data = pickle.load(file)\n",
    "file.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "84918ae0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000, 3)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[1]['actions'][1].shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5ba99e9d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1027"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fb45ea97",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000,)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(data[1]['q_values'])\n",
    "data[1]['q_values'][1].shape\n",
    "\n",
    "# data[0]['q_values']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "658f2b62",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000, 11)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[1]['observations'][1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d0360b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(q_list):\n",
    "    max_p = max(q_list)\n",
    "    min_p = min(q_list)\n",
    "    alpha = (max_p + min_p) / 2\n",
    "\n",
    "    # Check if max_p is zero\n",
    "    if max_p == 0:\n",
    "        return q_list\n",
    "\n",
    "    for i in range(0, len(q_list)):\n",
    "        if max_p != 0:\n",
    "            q_list[i] = (q_list[i] - alpha) / max_p\n",
    "        else:\n",
    "            # If max_p is zero, set the value to zero\n",
    "            q_list[i] = 0\n",
    "\n",
    "    return q_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "3dad3d54",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_list = []\n",
    "tanh_q_list = []\n",
    "obs_list = []\n",
    "\n",
    "for a_i in range(0, 10):\n",
    "    q_val = []\n",
    "    obs = []\n",
    "#     for batch in range(len(data)):\n",
    "    for batch in range(5):\n",
    "        q_val.extend(data[batch]['q_values'][a_i])\n",
    "        obs.extend(np.mean(data[batch]['observations'][a_i], axis=1))\n",
    "    q_list.append(q_val)\n",
    "    obs_list.append(obs)\n",
    " \n",
    "    tanh_q_list.append(np.tanh(normalize(q_val)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "b4ffeda0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(q_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "10de012b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(obs_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "2a8b9d08",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5000"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(obs_list[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "51cfc2df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5000"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(q_list[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "f1d0222e",
   "metadata": {},
   "outputs": [],
   "source": [
    "state = KMedoids(n_clusters=10, metric='cosine', random_state=0).fit(np.array(q_list).T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "a375f54d",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = state.labels_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "79b0522b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([8, 3, 3, ..., 3, 3, 3], dtype=int64)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "d301408f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int,\n",
       "            {8: 337,\n",
       "             3: 872,\n",
       "             6: 3279,\n",
       "             1: 47,\n",
       "             7: 107,\n",
       "             0: 58,\n",
       "             2: 61,\n",
       "             5: 211,\n",
       "             4: 12,\n",
       "             9: 16})"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "stats = defaultdict(int)\n",
    "for i in labels:\n",
    "    stats[i] += 1\n",
    "\n",
    "stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "1e4c6348",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>KMedoids(metric=&#x27;cosine&#x27;, n_clusters=10, random_state=0)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">KMedoids</label><div class=\"sk-toggleable__content\"><pre>KMedoids(metric=&#x27;cosine&#x27;, n_clusters=10, random_state=0)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "KMedoids(metric='cosine', n_clusters=10, random_state=0)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "700d128a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, 5000)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(obs_list).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "d5c7155f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5000,)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "c15be0c5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# data prep\n",
    "X_train, X_test, y_train, y_test = train_test_split(np.array(obs_list).T, labels, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "b6d19735",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-3 {color: black;background-color: white;}#sk-container-id-3 pre{padding: 0;}#sk-container-id-3 div.sk-toggleable {background-color: white;}#sk-container-id-3 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-3 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-3 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-3 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-3 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-3 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-3 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-3 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-3 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-3 div.sk-item {position: relative;z-index: 1;}#sk-container-id-3 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-3 div.sk-item::before, #sk-container-id-3 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-3 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-3 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-3 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-3 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-3 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-3 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-3 div.sk-label-container {text-align: center;}#sk-container-id-3 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-3 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-3\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GradientBoostingClassifier(max_leaf_nodes=40)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" checked><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GradientBoostingClassifier</label><div class=\"sk-toggleable__content\"><pre>GradientBoostingClassifier(max_leaf_nodes=40)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "GradientBoostingClassifier(max_leaf_nodes=40)"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# train\n",
    "tree_with_svm = GradientBoostingClassifier(max_leaf_nodes=40)\n",
    "tree_with_svm.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "48130883",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['.\\\\tree_with_svm.joblib']"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# save model\n",
    "dump(tree_with_svm, os.path.join('.', 'tree_with_svm.joblib'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "7787c007",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test\n",
    "tree_with_svm = load(os.path.join('.', 'tree_with_svm.joblib'))           \n",
    "pred = tree_with_svm.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "4c3e6264",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(int, {6: 918, 9: 3, 8: 8, 2: 15, 3: 13, 5: 12, 0: 8, 1: 7, 7: 16})"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "action_stats = defaultdict(int)\n",
    "for i in pred:\n",
    "    action_stats[i] += 1\n",
    "\n",
    "action_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c913edc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
