{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 158,
   "id": "99333bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "from random import sample\n",
    "from tqdm import tqdm\n",
    "from time import sleep\n",
    "from model import DQN_Agent\n",
    "import numpy as np      \n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from collections import Counter\n",
    "from sklearn.cluster import KMeans, DBSCAN, OPTICS\n",
    "from scipy.spatial.distance import euclidean\n",
    "from copy import deepcopy\n",
    "from sklearn.metrics import silhouette_score\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "import pandas as pd\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "b0aa2dd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_CLASSES = 2\n",
    "NUM_PROTOTYPES = 2\n",
    "LATENT_SIZE = 64\n",
    "PROTOTYPE_SIZE = 16\n",
    "BATCH_SIZE = 64\n",
    "NUM_EPOCHS = 20\n",
    "DEVICE = 'cpu'\n",
    "MAX_SAMPLES = 100000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "e8844f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def viper_sample(obss, acts, qs, xs, is_reweight=True):\n",
    "\t\n",
    "\t\"\"\"\n",
    "\tFunction taken from: https://github.com/obastani/viper\n",
    "\tobservations\n",
    "\tlatent x activations\n",
    "\tq values\n",
    "\tMax num observations\n",
    "\tUniform or sampling\n",
    "\t\"\"\"\n",
    "\t\n",
    "\t# Step 1: Compute probabilities\n",
    "\tps = np.max(qs, axis=1) - np.min(qs, axis=1)\n",
    "\tps = ps / np.sum(ps)\n",
    "\n",
    "\t# Step 2: Sample points\n",
    "\tif is_reweight:\n",
    "\t\t# According to p(s)\n",
    "\t\tidx = np.random.choice(len(obss), size=min(MAX_SAMPLES, np.sum(ps > 0)), p=ps)\n",
    "\telse:\n",
    "\t\t# Uniformly (without replacement)\n",
    "\t\tidx = np.random.choice(len(obss), size=min(MAX_SAMPLES, np.sum(ps > 0)), replace=False)    \n",
    "\n",
    "\t# Step 3: Obtain sampled indices\n",
    "\treturn obss[idx], acts[idx], qs[idx], xs[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "96dd86d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Environment and Agent\n",
    "env = gym.make(\"CartPole-v1\").unwrapped\n",
    "input_dim = env.observation_space.shape[0]\n",
    "output_dim = env.action_space.n\n",
    "exp_replay_size = 256\n",
    "agent = DQN_Agent(seed=1423, layer_sizes=[input_dim, 64, output_dim], lr=1e-3,\n",
    "                  sync_freq=5, exp_replay_size=exp_replay_size)\n",
    "agent.load_pretrained_model(\"weights/cartpole-dqn.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "8d544585",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = np.load('data/X_train.npy', ).reshape(-1, LATENT_SIZE)\n",
    "a_train = np.load('data/a_train.npy', ).flatten()\n",
    "obs_train = np.load('data/obs_train.npy', ).reshape(-1, 4)\n",
    "q_train = np.load('data/q_train.npy').reshape(-1, 2)\n",
    "\n",
    "obs_train, a_train, q_train, X_train = viper_sample(obs_train,\n",
    "                                                     a_train,\n",
    "                                                     q_train,\n",
    "                                                     X_train,\n",
    "                                                     is_reweight=True)\n",
    "\n",
    "tensor_x = torch.Tensor(X_train)\n",
    "tensor_y = torch.tensor(a_train)\n",
    "tensor_z = torch.tensor(obs_train)\n",
    "\n",
    "train_dataset = TensorDataset(tensor_x, tensor_y, tensor_z)\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "0fdc0eff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_loader(model, loader, cce_loss, intuition_labels=False):\n",
    "    model.eval()\n",
    "    total_correct = 0\n",
    "    total_loss = 0\n",
    "    total = 0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for i, data in enumerate(loader):\n",
    "            imgs, labels, observations = data\n",
    "            \n",
    "            if intuition_labels:\n",
    "                labels = intuition_loss(observations)\n",
    "\n",
    "            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)            \n",
    "            logits, x, _ = model(imgs)\n",
    "            loss = cce_loss(logits, labels)\n",
    "            preds = torch.argmax(logits, dim=1)\n",
    "            total_correct += sum(preds == labels).item()\n",
    "            total += len(preds)\n",
    "            total_loss += loss.item()\n",
    "                \n",
    "    return (total_correct / total) * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "b404eb46",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ListModule(object):\n",
    "    #Should work with all kind of module\n",
    "    def __init__(self, module, prefix, *args):\n",
    "        self.module = module\n",
    "        self.prefix = prefix\n",
    "        self.num_module = 0\n",
    "        for new_module in args:\n",
    "            self.append(new_module)\n",
    "\n",
    "    def append(self, new_module):\n",
    "        if not isinstance(new_module, nn.Module):\n",
    "            raise ValueError('Not a Module')\n",
    "        else:\n",
    "            self.module.add_module(self.prefix + str(self.num_module), new_module)\n",
    "            self.num_module += 1\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_module\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        if i < 0 or i >= self.num_module:\n",
    "            raise IndexError('Out of bound')\n",
    "        return getattr(self.module, self.prefix + str(i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "808afb87",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Wrapper(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Wrapper, self).__init__()\n",
    "        \n",
    "        self.ts = ListModule(self, 'ts_')\n",
    "        for i in range(NUM_PROTOTYPES):\n",
    "            transformation = nn.Sequential(\n",
    "                nn.Linear(LATENT_SIZE, PROTOTYPE_SIZE),\n",
    "                nn.BatchNorm1d(PROTOTYPE_SIZE),\n",
    "                nn.ReLU(),\n",
    "#                 nn.Dropout(0.2),\n",
    "                nn.Linear(PROTOTYPE_SIZE, PROTOTYPE_SIZE),\n",
    "            )\n",
    "            self.ts.append(transformation)  \n",
    "            \n",
    "        prototypes = torch.randn((NUM_PROTOTYPES, PROTOTYPE_SIZE), dtype=torch.float32)\n",
    "        self.protos_post_trans = nn.Parameter(prototypes, requires_grad=True)\n",
    "        self.epsilon = 1e-5\n",
    "        self.linear = nn.Linear(NUM_PROTOTYPES, NUM_CLASSES, bias=False) \n",
    "        self.__make_linear_weights()\n",
    "        self.relu = nn.ReLU()        \n",
    "        \n",
    "    def __make_linear_weights(self):\n",
    "        \n",
    "        prototype_class_identity = torch.zeros(NUM_PROTOTYPES, NUM_CLASSES)\n",
    "        num_prototypes_per_class = NUM_PROTOTYPES // NUM_CLASSES\n",
    "        \n",
    "        for j in range(NUM_PROTOTYPES):\n",
    "            prototype_class_identity[j, j // num_prototypes_per_class] = 1\n",
    "            \n",
    "        positive_one_weights_locations = torch.t(prototype_class_identity)\n",
    "        negative_one_weights_locations = 1 - positive_one_weights_locations\n",
    "\n",
    "        incorrect_strength = 0.0\n",
    "        correct_class_connection = 1\n",
    "        incorrect_class_connection = incorrect_strength\n",
    "        self.linear.weight.data.copy_(\n",
    "            correct_class_connection * positive_one_weights_locations\n",
    "            + incorrect_class_connection * negative_one_weights_locations)\n",
    "        \n",
    "    def __proto_layer_l2(self, x, p):\n",
    "        \n",
    "        output = list()\n",
    "        b_size = x.shape[0]\n",
    "        \n",
    "        p = p.view(1, PROTOTYPE_SIZE).tile(b_size, 1).to(DEVICE) \n",
    "        c = x.view(b_size, PROTOTYPE_SIZE).to(DEVICE)      \n",
    "        l2s = ( (c - p)**2 ).sum(axis=1).to(DEVICE) \n",
    "        act = torch.log( (l2s + 1. ) / (l2s + self.epsilon) ).to(DEVICE)  \n",
    "        return act\n",
    "    \n",
    "    def __output_act_func(self, p_acts):        \n",
    "        return self.relu(p_acts)\n",
    "    \n",
    "    def __transforms(self, x):\n",
    "        p_acts = list()\n",
    "        for i, t in enumerate(self.ts):\n",
    "            action_prototype = self.protos_post_trans[i]\n",
    "            p_acts.append( self.__proto_layer_l2( t(x), action_prototype).view(-1, 1) )\n",
    "        return torch.cat(p_acts, axis=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \n",
    "        p_acts = self.__transforms(x)\n",
    "                          \n",
    "        # Linear Layer\n",
    "        logits = self.linear(p_acts)\n",
    "                                   \n",
    "        # Activation Functions\n",
    "        final_outputs = self.__output_act_func(logits)\n",
    "        \n",
    "        return final_outputs, x, p_acts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "3debf9ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def proto_loss(model, nn_human_x, criterion):\n",
    "    model.eval()\n",
    "    p = model.protos_post_trans\n",
    "    target_x = transform(model, nn_human_x)\n",
    "    loss = criterion(p, target_x.clone().detach()) \n",
    "    model.train()\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "2e5e3b6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform(wrapper, x):\n",
    "    p_acts = list()\n",
    "    for i, t in enumerate(wrapper.ts):\n",
    "        p_acts.append( t(x[i].view(1, -1)) )\n",
    "    return torch.cat(p_acts, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "35b4bb7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def intuition_loss(observations):\n",
    "    return (observations.T[3] > 0).long()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14b9c581",
   "metadata": {},
   "source": [
    "### Define Human Concept\n",
    "\n",
    "* 0 -- Move left\n",
    "* 1 -- Move right"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f36c001f",
   "metadata": {},
   "source": [
    "* Cart Position: -4.8 => 4.8\n",
    "* Cart Velocity -inf => +inf\n",
    "* Pole Angle -0.418 rad => 0.418 rad\n",
    "* Pole Angular Velocity -inf => +inf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "id": "da2a364b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "73d72ada",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Wrapper()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "35d9b240",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_concepts = {'tilt_left': [0., 0., -0.01, -0.5], 'tilt_right' : [0., 0., 0.01, 0.5],\n",
    "                 }\n",
    "\n",
    "human_concepts_list = np.array([l for l in human_concepts.values()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "7558a662",
   "metadata": {},
   "outputs": [],
   "source": [
    "knn = KNeighborsClassifier(algorithm='brute')\n",
    "knn.fit(obs_train, list(range(len(obs_train))))\n",
    "p_idxs = knn.kneighbors(X=human_concepts_list, n_neighbors=100, return_distance=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "65ba3d20",
   "metadata": {},
   "outputs": [],
   "source": [
    "p_idxs = np.array([p_idxs[0][0], p_idxs[1][0],\n",
    "                  ])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "id": "42290397",
   "metadata": {},
   "outputs": [],
   "source": [
    "nn_human_images = obs_train[p_idxs.flatten()]\n",
    "nn_human_x = X_train[p_idxs.flatten()]\n",
    "nn_human_actions = a_train[p_idxs.flatten()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "id": "f1cf2801",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.01153835,  0.19514857,  0.04402352, -0.42990083],\n",
       "       [ 0.00464073, -0.22610292, -0.0417408 ,  0.3206237 ]])"
      ]
     },
     "execution_count": 122,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn_human_images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "67e134c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1])"
      ]
     },
     "execution_count": 123,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn_human_actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "5135d1d0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Accuracy: 47.837\n",
      "CCE Loss: 0.09568427289152066\n",
      "Proto Loss: 1.8787282017876579\n",
      "Train Accuracy: 96.60300000000001\n",
      "CCE Loss: 0.06583541974151819\n",
      "Proto Loss: 0.4940799786841827\n",
      "Train Accuracy: 99.572\n",
      "CCE Loss: 0.053773889487927454\n",
      "Proto Loss: 0.16339589428683648\n",
      "Train Accuracy: 97.083\n",
      "CCE Loss: 0.0751402957128958\n",
      "Proto Loss: 0.20540894637487933\n",
      "Train Accuracy: 99.824\n",
      "CCE Loss: 0.06763242204860985\n",
      "Proto Loss: 0.268297635748004\n",
      "Train Accuracy: 99.818\n",
      "CCE Loss: 0.05037251029742175\n",
      "Proto Loss: 0.2582383618091853\n",
      "Train Accuracy: 99.787\n",
      "CCE Loss: 0.05062351369531974\n",
      "Proto Loss: 0.16705248243031995\n",
      "Train Accuracy: 96.122\n",
      "CCE Loss: 0.057664051095723284\n",
      "Proto Loss: 0.35178025993384093\n",
      "Train Accuracy: 99.42999999999999\n",
      "CCE Loss: 0.062476050723743314\n",
      "Proto Loss: 0.2963843544735938\n",
      "Train Accuracy: 98.74300000000001\n",
      "CCE Loss: 0.06131919870212357\n",
      "Proto Loss: 0.3234483618308039\n",
      "Train Accuracy: 97.056\n",
      "CCE Loss: 0.05777360919102674\n",
      "Proto Loss: 0.2683967727520077\n",
      "Train Accuracy: 96.12899999999999\n",
      "CCE Loss: 0.05746541862907538\n",
      "Proto Loss: 0.472704507573977\n",
      "Train Accuracy: 99.889\n",
      "CCE Loss: 0.06060399991053175\n",
      "Proto Loss: 0.33317632326895436\n",
      "Train Accuracy: 99.84100000000001\n",
      "CCE Loss: 0.08672486237394465\n",
      "Proto Loss: 1.121649845372741\n",
      "Train Accuracy: 99.591\n",
      "CCE Loss: 0.07755787224063122\n",
      "Proto Loss: 0.2241363715461541\n",
      "Train Accuracy: 99.03999999999999\n",
      "CCE Loss: 0.06746777847274585\n",
      "Proto Loss: 0.3231059055909397\n",
      "Train Accuracy: 99.792\n",
      "CCE Loss: 0.059158454001753076\n",
      "Proto Loss: 0.12864061554389458\n",
      "Train Accuracy: 99.709\n",
      "CCE Loss: 0.06578313933938065\n",
      "Proto Loss: 0.22307187499394593\n",
      "Train Accuracy: 97.03399999999999\n",
      "CCE Loss: 0.08634458223686411\n",
      "Proto Loss: 0.9031059381547051\n",
      "Train Accuracy: 99.28200000000001\n",
      "CCE Loss: 0.05978574632188533\n",
      "Proto Loss: 0.26826443240919134\n"
     ]
    }
   ],
   "source": [
    "cce_loss = nn.CrossEntropyLoss()\n",
    "mse_loss = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)\n",
    "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n",
    "best_acc = 0\n",
    "\n",
    "# Freeze prototype learning and use full model (inc. transformation)\n",
    "model.train()\n",
    "\n",
    "lambda1 = 1.0\n",
    "lambda2 = 10.0\n",
    "\n",
    "model.linear.weight.requires_grad = False\n",
    "\n",
    "\n",
    "for epoch in range(20):\n",
    "    \n",
    "    running_loss1 = 0\n",
    "    running_loss2 = 0\n",
    "    \n",
    "    model.eval()\n",
    "    train_acc = evaluate_loader(model, train_loader, cce_loss)\n",
    "    model.train()\n",
    "    print(\"Train Accuracy:\", train_acc)\n",
    "    \n",
    "    if train_acc > best_acc:\n",
    "        torch.save(model.state_dict(), 'weights/wrapper_pre_projection.pth')\n",
    "        best_acc = train_acc\n",
    "    \n",
    "    for instances, labels, observations in train_loader:\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "                \n",
    "        instances, labels, observations = instances.to(DEVICE), labels.to(DEVICE), observations.to(DEVICE)\n",
    "        logits, x, _ = model(instances)\n",
    "        \n",
    "        loss1 = cce_loss(logits, labels) * lambda1\n",
    "        loss2 = proto_loss(model, torch.tensor(nn_human_x, dtype=torch.float32), mse_loss) * lambda2\n",
    "        loss = loss1 + loss2\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        running_loss1 += loss1.item()\n",
    "        running_loss2 += loss2.item()        \n",
    "        \n",
    "    print(\"CCE Loss:\", running_loss1 / len(train_loader))\n",
    "    print(\"Proto Loss:\", running_loss2 / len(train_loader))\n",
    "    print(\" \")\n",
    "    scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c04dc163",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "947c4542",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b87db3cd",
   "metadata": {},
   "source": [
    "## Project Prototypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "0aa944ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Wrapper().eval()\n",
    "model.load_state_dict(torch.load('weights/wrapper_pre_projection.pth'))\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "id": "a1c2548c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "99.889"
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy before projection\n",
    "evaluate_loader(model, train_loader, cce_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5040174",
   "metadata": {},
   "source": [
    "## Project"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "9ffe11a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "learned_prototypes = model.protos_post_trans.clone().detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "501786b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 16])"
      ]
     },
     "execution_count": 129,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learned_prototypes.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "id": "d005c104",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_prototypes = transform(model, torch.tensor(nn_human_x, dtype=torch.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "id": "d3061aa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.protos_post_trans = torch.nn.Parameter(human_prototypes)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "id": "4e3948cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Distances: [0.24756118655204773, 0.06394386291503906]\n",
      "Average: 0.18079733848571777\n"
     ]
    }
   ],
   "source": [
    "# Distance between learned prototoypes and projected ones\n",
    "print(\"Distances:\",  torch.sqrt(( (learned_prototypes - human_prototypes)**2 ).sum(axis=1)).tolist()  )\n",
    "print(\"Average:\", torch.sqrt(( (learned_prototypes - human_prototypes)**2 ).sum(axis=1).mean()).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "id": "79026bc6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "99.47"
      ]
     },
     "execution_count": 137,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy after projection\n",
    "evaluate_loader(model, train_loader, cce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "id": "3c6bf5a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), 'weights/wrapper_post_projection.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "483185bc",
   "metadata": {},
   "source": [
    "## Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "id": "4d9be2b5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 139,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = Wrapper().eval()\n",
    "model.load_state_dict(torch.load('weights/wrapper_post_projection.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "id": "89cb9528",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [00:09<00:00, 10592.80it/s]\n"
     ]
    }
   ],
   "source": [
    "trans_x = list()\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():    \n",
    "    for i in tqdm(range(len(X_train))):\n",
    "        img = X_train[i]\n",
    "        temp = list()\n",
    "        for t in model.ts:\n",
    "            x = t( torch.tensor(img, dtype=torch.float32).view(1, -1) )\n",
    "            temp.append(x[0].tolist())\n",
    "        trans_x.append(temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "id": "d0c8f5ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "trans_x = np.array(trans_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "9e5eb0b0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100000, 2, 16)"
      ]
     },
     "execution_count": 148,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trans_x.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6388ee83",
   "metadata": {},
   "source": [
    "## Move Left"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "id": "d179b5ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_x = trans_x[:, 0, : ]\n",
    "temp_y = [0 for _ in range(len(temp_x))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "id": "c5ca6cc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_x = np.append(temp_x, model.protos_post_trans.clone().detach().numpy()[0].reshape(1,-2), axis=0)\n",
    "temp_y = np.append(temp_y, np.array( [1] ), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "id": "d923b1cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "pca = PCA(n_components=2)\n",
    "emb_x = pca.fit_transform(temp_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "id": "ce654ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame()\n",
    "df['Legend'] = temp_y\n",
    "df['x'] = emb_x.T[0]\n",
    "df['y'] = emb_x.T[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "f7af1b80",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.Legend = df.Legend.astype('category')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "id": "6a01c81f",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.Legend = df.Legend.replace([0, 1], ['Training Data', 'Turn Left'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c95d2a8b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68e935d7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91fd9f53",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9ff43d0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05f9b988",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47b4c47b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6d91870",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "969ccf07",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lunar_env",
   "language": "python",
   "name": "lunar_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
