{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b6f35b06",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/eoinkenny/opt/anaconda3/envs/rl_env/lib/python3.9/site-packages/seaborn/rcmod.py:82: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
      "  if LooseVersion(mpl.__version__) >= \"3.0\":\n",
      "/Users/eoinkenny/opt/anaconda3/envs/rl_env/lib/python3.9/site-packages/setuptools/_distutils/version.py:351: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
      "  other = LooseVersion(other)\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "from random import sample\n",
    "from tqdm import tqdm\n",
    "from time import sleep\n",
    "from model import DQN_Agent\n",
    "import numpy as np      \n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from collections import Counter\n",
    "from sklearn.cluster import KMeans, DBSCAN, OPTICS\n",
    "from scipy.spatial.distance import euclidean\n",
    "from copy import deepcopy\n",
    "from sklearn.metrics import silhouette_score\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "import pandas as pd\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "faab21e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide settings, gathered in one place.\n",
    "\n",
    "# Mini-batch size handed to the DataLoader built in the data-loading cell.\n",
    "BATCH_SIZE = 64\n",
    "\n",
    "# The remaining constants are not referenced by any other cell of this notebook;\n",
    "# presumably they mirror a companion training script - TODO confirm.\n",
    "NUM_CLASSES = 2\n",
    "NUM_PROTOTYPES = 2\n",
    "LATENT_SIZE = 64\n",
    "NUM_EPOCHS = 20\n",
    "DEVICE = 'cpu'\n",
    "MAX_SAMPLES = 100_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3e84d0df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom for each domain == Load in environment and model\n",
    "env = gym.make(\"CartPole-v1\").unwrapped\n",
    "input_dim = env.observation_space.shape[0]\n",
    "output_dim = env.action_space.n\n",
    "exp_replay_size = 256\n",
    "agent = DQN_Agent(seed=1423, layer_sizes=[input_dim, 64, output_dim], lr=1e-3, sync_freq=5,\n",
    "                  exp_replay_size=exp_replay_size)\n",
    "agent.load_pretrained_model(\"weights/cartpole-dqn.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "07213efa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the saved arrays from disk.\n",
    "# q_train is loaded here but not used by any later cell of this notebook.\n",
    "X_train = np.load('data/X_train.npy')\n",
    "a_train = np.load('data/a_train.npy')\n",
    "obs_train = np.load('data/obs_train.npy')\n",
    "q_train = np.load('data/q_train.npy')\n",
    "\n",
    "# torch.Tensor(...) converts to the default tensor type, whereas torch.tensor(...)\n",
    "# keeps the dtype of the NumPy array it is given.\n",
    "tensor_x = torch.Tensor(X_train)\n",
    "tensor_y = torch.tensor(a_train)\n",
    "tensor_z = torch.tensor(obs_train)\n",
    "\n",
    "# Bundle the three tensors and serve them in shuffled mini-batches.\n",
    "train_dataset = TensorDataset(tensor_x, tensor_y, tensor_z)\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "91079220",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, 200, 64)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick look at the dimensions of the loaded array.\n",
    "np.shape(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce4c00d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `temp_x` / `temp_y` were never defined anywhere in this notebook, so a fresh\n",
    "# Restart & Run All raised NameError in this cell. Rebuild them from the loaded arrays.\n",
    "# X_train is 3-D, while fit_transform expects a 2-D (n_samples, n_features) matrix,\n",
    "# so every leading axis is folded into the sample axis.\n",
    "# NOTE(review): assumes a_train holds exactly one label per flattened X_train row\n",
    "# and that no sub-sampling was intended - confirm against the original workflow.\n",
    "temp_x = X_train.reshape(-1, X_train.shape[-1])\n",
    "temp_y = a_train.reshape(-1)\n",
    "assert len(temp_x) == len(temp_y), (temp_x.shape, temp_y.shape)\n",
    "\n",
    "# t-SNE is stochastic; a fixed random_state makes the 2-D layout reproducible.\n",
    "tsne = TSNE(n_components=2, random_state=0)\n",
    "emb_x = tsne.fit_transform(temp_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "254f5a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear alternative to the t-SNE cell above.\n",
    "# The estimator used to be bound to the name `tsne`, which clobbered the real t-SNE\n",
    "# object and mislabelled this one; it now has its own name.\n",
    "# Running this cell overwrites `emb_x`, so the DataFrame and plot cells below show\n",
    "# whichever embedding cell was executed last.\n",
    "pca = PCA(n_components=2)\n",
    "emb_x = pca.fit_transform(temp_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbd88577",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the labels and the two embedding coordinates into one frame for plotting.\n",
    "df = pd.DataFrame({\n",
    "    'a': temp_y,\n",
    "    'x': emb_x[:, 0],\n",
    "    'y': emb_x[:, 1],\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bd8a680",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the label column as a pandas categorical before it is used as the plot hue.\n",
    "df['a'] = df['a'].astype('category')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7aa1b90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter the two embedding coordinates, coloured by the label column 'a'.\n",
    "sns.scatterplot(data=df, x='x', y='y', hue='a')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl_env",
   "language": "python",
   "name": "rl_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
