{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "18378dc0",
   "metadata": {},
   "source": [
    "# Force Breakout To Causally Use Human Understandable Concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "99333bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "import numpy as np      \n",
    "import pandas as pd\n",
    "import matplotlib.animation as animation\n",
    "import pickle\n",
    "import toml\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import cv2\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from copy import deepcopy\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from argparse import ArgumentParser\n",
    "from os.path import join\n",
    "from torch.distributions import Beta\n",
    "from IPython.display import HTML\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "from sklearn.metrics import silhouette_score\n",
    "from sklearn.cluster import KMeans, DBSCAN, OPTICS\n",
    "\n",
    "from random import sample\n",
    "from tqdm import tqdm\n",
    "from time import sleep\n",
    "\n",
    "from collections import deque"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "38e85a94",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, os, gym, time, glob, argparse, sys\n",
    "import numpy as np\n",
    "from scipy.signal import lfilter\n",
    "\n",
    "import cv2 # preserves single-pixel info _unlike_ img = img[::2,::2]\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.multiprocessing as mp\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b0aa2dd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_CLASSES = 6\n",
    "LATENT_SIZE = 1536\n",
    "PROTOTYPE_SIZE = 50\n",
    "BATCH_SIZE = 32\n",
    "NUM_EPOCHS = 20\n",
    "DEVICE = 'cpu'\n",
    "delay_ms = 0\n",
    "NUM_PROTOTYPES = 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0fdc0eff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_loader(model, loader, cce_loss):\n",
    "    model.eval()\n",
    "    total_correct = 0\n",
    "    total_loss = 0\n",
    "    total = 0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for i, data in enumerate(loader):\n",
    "            imgs, labels = data\n",
    "            \n",
    "            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)            \n",
    "            logits = model(imgs)\n",
    "            loss = cce_loss(logits, labels)\n",
    "            preds = torch.argmax(logits, dim=1)\n",
    "            total_correct += sum(preds == labels).item()\n",
    "            total += len(preds)\n",
    "            total_loss += loss.item()\n",
    "                \n",
    "    return (total_correct / total) * 100"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "047d03d6",
   "metadata": {},
   "source": [
    "## Load Pre-Trained Agent & Simulated Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8d544585",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data Size: (39060, 1, 256) (39060,)\n"
     ]
    }
   ],
   "source": [
    "with open('data/X_train.pkl', 'rb') as f:\n",
    "    X_train = pickle.load(f)\n",
    "with open('data/a_train.pkl', 'rb') as f:\n",
    "    a_train = pickle.load(f)\n",
    "\n",
    "X_train = np.array(X_train)\n",
    "a_train = np.array(a_train)\n",
    "print(\"Data Size:\", X_train.shape, a_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2795e178",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(39060, 1, 256)\n",
      "(39060,)\n"
     ]
    }
   ],
   "source": [
    "# print(observations.shape)\n",
    "print(X_train.shape)\n",
    "print(a_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "18bce77c",
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor_x = torch.Tensor(X_train)\n",
    "tensor_y = torch.tensor(a_train, dtype=torch.long)\n",
    "train_dataset = TensorDataset(tensor_x, tensor_y)\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67a19994",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "603ff581",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ListModule(object):\n",
    "    #Should work with all kind of module\n",
    "    def __init__(self, module, prefix, *args):\n",
    "        self.module = module\n",
    "        self.prefix = prefix\n",
    "        self.num_module = 0\n",
    "        for new_module in args:\n",
    "            self.append(new_module)\n",
    "\n",
    "    def append(self, new_module):\n",
    "        if not isinstance(new_module, nn.Module):\n",
    "            raise ValueError('Not a Module')\n",
    "        else:\n",
    "            self.module.add_module(self.prefix + str(self.num_module), new_module)\n",
    "            self.num_module += 1\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_module\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        if i < 0 or i >= self.num_module:\n",
    "            raise IndexError('Out of bound')\n",
    "        return getattr(self.module, self.prefix + str(i))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "7f6d547b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PW_Net(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(PW_Net, self).__init__()\n",
    "        \n",
    "        self.ts = ListModule(self, 'ts_')\n",
    "        for i in range(NUM_PROTOTYPES):\n",
    "            transformation = nn.Sequential(\n",
    "                nn.Linear(LATENT_SIZE, PROTOTYPE_SIZE),\n",
    "                nn.BatchNorm1d(PROTOTYPE_SIZE),\n",
    "                nn.ReLU(),\n",
    "                nn.Linear(PROTOTYPE_SIZE, PROTOTYPE_SIZE),\n",
    "            )\n",
    "            self.ts.append(transformation)  \n",
    "            \n",
    "        prototypes = torch.randn((NUM_PROTOTYPES, PROTOTYPE_SIZE), dtype=torch.float32)\n",
    "        self.prototypes = nn.Parameter(prototypes, requires_grad=True)\n",
    "        self.epsilon = 1e-5\n",
    "        self.linear = nn.Linear(NUM_PROTOTYPES, NUM_CLASSES, bias=False) \n",
    "        self.__make_linear_weights()\n",
    "        self.tanh = nn.Tanh()\n",
    "        self.relu = nn.ReLU()        \n",
    "        \n",
    "    def __make_linear_weights(self):\n",
    "        prototype_class_identity = torch.zeros(NUM_PROTOTYPES, NUM_CLASSES)\n",
    "        num_prototypes_per_class = NUM_PROTOTYPES // NUM_CLASSES\n",
    "        for j in range(NUM_PROTOTYPES):\n",
    "            prototype_class_identity[j, j // num_prototypes_per_class] = 1\n",
    "        positive_one_weights_locations = torch.t(prototype_class_identity)\n",
    "        negative_one_weights_locations = 1 - positive_one_weights_locations\n",
    "        incorrect_strength = 0.0\n",
    "        correct_class_connection = 1\n",
    "        incorrect_class_connection = incorrect_strength\n",
    "        self.linear.weight.data.copy_(\n",
    "            correct_class_connection * positive_one_weights_locations\n",
    "            + incorrect_class_connection * negative_one_weights_locations)\n",
    "        \n",
    "        \n",
    "    def __proto_layer_l2(self, x, p):\n",
    "        \n",
    "        output = list()\n",
    "        b_size = x.shape[0]\n",
    "        \n",
    "        p = p.view(1, PROTOTYPE_SIZE).tile(b_size, 1).to(DEVICE) \n",
    "        c = x.view(b_size, PROTOTYPE_SIZE).to(DEVICE)      \n",
    "        l2s = ( (c - p)**2 ).sum(axis=1).to(DEVICE) \n",
    "        act = torch.log( (l2s + 1. ) / (l2s + self.epsilon) ).to(DEVICE)  \n",
    "        return act\n",
    "    \n",
    "    def __output_act_func(self, p_acts):        \n",
    "        p_acts.T[0] = self.tanh(p_acts.T[0])  # steering between -1 -> +1\n",
    "        p_acts.T[1] = self.relu(p_acts.T[1])  # acc > 0\n",
    "        p_acts.T[2] = self.relu(p_acts.T[2])  # brake > 0\n",
    "        return p_acts\n",
    "    \n",
    "    def __transforms(self, x):\n",
    "        p_acts = list()\n",
    "        for i, t in enumerate(self.ts):\n",
    "            action_prototype = self.prototypes[i]\n",
    "            p_acts.append( self.__proto_layer_l2( t(x), action_prototype).view(-1, 1) )\n",
    "        return torch.cat(p_acts, axis=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \n",
    "        p_acts = self.__transforms(x)\n",
    "                          \n",
    "        # Linear Layer\n",
    "        logits = self.linear(p_acts)\n",
    "                                   \n",
    "        # Activation Functions\n",
    "        final_outputs = self.__output_act_func(logits)\n",
    "        \n",
    "        return final_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fdf659b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = PW_Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "3dd0a635",
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y = next(iter(train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "aee0f5c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = model(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5463d2a2",
   "metadata": {},
   "source": [
    "## Define Concepts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "076947fb",
   "metadata": {},
   "source": [
    "## Human Defined Concepts\n",
    "\n",
    "action 0 and 1 seems useless, as nothing happens to the racket.\n",
    "\n",
    "action 2 & 4 makes the racket go up, and action 3 & 5 makes the racket go down."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "3dd82561",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_concepts = {'stay1':    [0.], 'stay2' :      [1.],\n",
    "                  'move_up1': [2.], 'move_down1' : [3.],\n",
    "                  'move_up2': [4.], 'move_down2' : [5.],\n",
    "                 }\n",
    "human_concepts_list = np.array([l for l in human_concepts.values()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "9221a8b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# number to consider when searching for prototypes manually\n",
    "n_neighbours = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "1420c1f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# a_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "9469c571",
   "metadata": {},
   "outputs": [],
   "source": [
    "knn = KNeighborsClassifier(algorithm='brute')\n",
    "knn.fit(a_train.reshape(-1, 1), list(range(len(a_train))))\n",
    "p_idxs = knn.kneighbors(X=human_concepts_list, n_neighbors=n_neighbours, return_distance=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "06021d22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ### Manually explore nn's for good examples\n",
    "\n",
    "# # which concept to check? (0-3, left, right, acc, brake)\n",
    "# concept_idx = 0\n",
    "\n",
    "# for i in range(n_neighbours):\n",
    "#     print(i, a_train[p_idxs[concept_idx][i]])\n",
    "#     plt.imshow(observations[p_idxs[concept_idx][i]])\n",
    "#     plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "e153183e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# p_idxs = [p_idxs[0][0],\n",
    "#           p_idxs[1][0],\n",
    "#           p_idxs[2][0],\n",
    "#           p_idxs[3][1],\n",
    "#           p_idxs[4][6],\n",
    "#           p_idxs[5][5] ]\n",
    "# p_idxs = np.array(p_idxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "82662617",
   "metadata": {},
   "outputs": [],
   "source": [
    "# nn_human_images = observations[p_idxs.flatten()]\n",
    "nn_human_x = X_train[p_idxs.flatten()]\n",
    "nn_human_actions = a_train[p_idxs.flatten()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "fb1441e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# nn_human_x *= 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c8fca07",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "6dfc5a51",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4, 5])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The actual actions of the data's prototypes we'll use\n",
    "nn_human_actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "384f51a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for i, img in enumerate(nn_human_images):\n",
    "#     print(list(human_concepts.keys())[i])\n",
    "#     print(nn_human_actions[i])\n",
    "#     plt.imshow(img[0])\n",
    "#     plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "2f32c4cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# nn_human_x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "0cd07ce6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def trans_human_concepts(model, nn_human_x):\n",
    "    model.eval()\n",
    "    trans_nn_human_x = list()\n",
    "    for i, t in enumerate(model.ts):\n",
    "        trans_nn_human_x.append( t( torch.tensor(nn_human_x[i], dtype=torch.float32).view(1, -1)) )\n",
    "    model.train()\n",
    "    return torch.cat(trans_nn_human_x, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f9e22de",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d226fb81",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "35415786",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "7de2ad55",
   "metadata": {},
   "outputs": [],
   "source": [
    "def proto_loss(model, nn_human_x, criterion):\n",
    "    model.eval()\n",
    "    target_x = trans_human_concepts(model, nn_human_x)\n",
    "    loss = criterion(model.prototypes, target_x) \n",
    "    model.train()\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "9a33c4bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_DIR = 'data/pwnet.pth'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "5135d1d0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0\n",
      "Acc: 15.621301775147927\n",
      "Running Error: 1.1967579266668236\n",
      "Prototype Projection Error: 0.007889682399710734\n",
      " \n",
      "Epoch: 1\n",
      "Acc: 57.90689384757967\n",
      "Running Error: 0.9050256500877168\n",
      "Prototype Projection Error: 0.0015342073658664785\n",
      " \n",
      "Epoch: 2\n",
      "Acc: 29.942254223996578\n",
      "Running Error: 0.8533902491993495\n",
      "Prototype Projection Error: 0.0023734341007483153\n",
      " \n",
      "Epoch: 3\n",
      "Acc: 80.89684180509018\n",
      "Running Error: 0.8122385519889802\n",
      "Prototype Projection Error: 0.002234454251484636\n",
      " \n",
      "Epoch: 4\n",
      "Acc: 84.16339915876524\n",
      "Running Error: 0.7716108481161785\n",
      "Prototype Projection Error: 0.0017745829054215425\n",
      " \n",
      "Epoch: 5\n",
      "Acc: 49.37192557211093\n",
      "Running Error: 0.7390895219805249\n",
      "Prototype Projection Error: 0.0031962317538219947\n",
      " \n",
      "Epoch: 6\n",
      "Acc: 79.8773793398446\n",
      "Running Error: 0.7249237658374392\n",
      "Prototype Projection Error: 0.001862474041465155\n",
      " \n",
      "Epoch: 7\n",
      "Acc: 83.05696157410708\n",
      "Running Error: 0.7077626974736578\n",
      "Prototype Projection Error: 0.001904556186533878\n",
      " \n",
      "Epoch: 8\n",
      "Acc: 86.63719968631925\n",
      "Running Error: 0.6971569734278822\n",
      "Prototype Projection Error: 0.0013731491526015856\n",
      " \n",
      "Epoch: 9\n",
      "Acc: 86.75696870321524\n",
      "Running Error: 0.6917600913421951\n",
      "Prototype Projection Error: 0.001562503142876446\n",
      " \n",
      "Epoch: 10\n",
      "Acc: 82.23711413702146\n",
      "Running Error: 0.68466219558197\n",
      "Prototype Projection Error: 0.002291753828302697\n",
      " \n",
      "Epoch: 11\n",
      "Acc: 59.0917516218721\n",
      "Running Error: 0.676240599333533\n",
      "Prototype Projection Error: 0.0014220530811144272\n",
      " \n",
      "Epoch: 12\n",
      "Acc: 87.58109360518999\n",
      "Running Error: 0.664490931252711\n",
      "Prototype Projection Error: 0.0009712586582407936\n",
      " \n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn [47], line 49\u001b[0m\n\u001b[1;32m     46\u001b[0m loss_data\u001b[38;5;241m.\u001b[39mappend(loss1\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m*\u001b[39m lambda1)\n\u001b[1;32m     47\u001b[0m proto_data\u001b[38;5;241m.\u001b[39mappend(loss2\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m*\u001b[39m lambda2)\n\u001b[0;32m---> 49\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     50\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m     52\u001b[0m running_loss1 \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss1\u001b[38;5;241m.\u001b[39mitem()\n",
      "File \u001b[0;32m~/Desktop/X-RL/pong/pong_env/lib/python3.9/site-packages/torch/_tensor.py:396\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    387\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    388\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    389\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m    390\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    394\u001b[0m         create_graph\u001b[38;5;241m=\u001b[39mcreate_graph,\n\u001b[1;32m    395\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs)\n\u001b[0;32m--> 396\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Desktop/X-RL/pong/pong_env/lib/python3.9/site-packages/torch/autograd/__init__.py:173\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    168\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m    170\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m    171\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    172\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 173\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    174\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    175\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "#### Train Wrapper\n",
    "model = PW_Net().eval()\n",
    "\n",
    "mse_loss = nn.MSELoss()\n",
    "cce_loss = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-8)\n",
    "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n",
    "best_acc = 0.\n",
    "model.train()\n",
    "\n",
    "# Freeze Linear Layer to make more interpretable and force prototypes to only contribute to a single output action\n",
    "model.linear.weight.requires_grad = False\n",
    "\n",
    "loss_data = list()\n",
    "proto_data = list()\n",
    "\n",
    "# Could tweak these, haven't tried\n",
    "lambda1 = 1.\n",
    "lambda2 = 1.\n",
    "\n",
    "\n",
    "for epoch in range(100):\n",
    "    \n",
    "    running_loss1 = 0\n",
    "    running_loss2 = 0\n",
    "        \n",
    "    model.eval()\n",
    "    current_acc = evaluate_loader(model, train_loader, cce_loss)\n",
    "    model.train()\n",
    "    \n",
    "    if current_acc > best_acc:\n",
    "        torch.save(model.state_dict(), 'weights/wrapper_pre_projection.pth')\n",
    "        best_acc = current_acc\n",
    "    \n",
    "    for instances, labels in train_loader:\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "                \n",
    "        instances, labels = instances.to(DEVICE), labels.to(DEVICE)\n",
    "        logits = model(instances)\n",
    "                \n",
    "        loss1 = cce_loss(logits, labels) * lambda1\n",
    "        loss2 = proto_loss(model, nn_human_x, mse_loss) * lambda2\n",
    "        loss = loss1 + loss2\n",
    "        \n",
    "        loss_data.append(loss1.item() * lambda1)\n",
    "        proto_data.append(loss2.item() * lambda2)\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        running_loss1 += loss1.item()\n",
    "        running_loss2 += loss2.item()\n",
    "        \n",
    "    print(\"Epoch:\", epoch)\n",
    "    print(\"Acc:\", current_acc)\n",
    "    print(\"Running Error:\", running_loss1 / len(train_loader))\n",
    "    print(\"Prototype Projection Error:\", running_loss2 / len(train_loader))\n",
    "    print(\" \")\n",
    "    scheduler.step()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b87db3cd",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "0aa944ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = PW_Net().eval()\n",
    "model.load_state_dict(torch.load('weights/wrapper_pre_projection.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "a1c2548c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "88.0972410351465"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy before projection\n",
    "evaluate_loader(model, train_loader, cce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57bb1b2e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "b9b1ae20",
   "metadata": {},
   "outputs": [],
   "source": [
    "real_trans_x = trans_human_concepts(model, nn_human_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "390f9ed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_trans_x = model.prototypes.clone().detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "89b06cd5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/jy/977g91sd4p56btqf1lyy890w0000gn/T/ipykernel_74371/2837856146.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  model.prototypes.data.copy_(torch.tensor(real_trans_x, dtype=torch.float32))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.3272,  0.2350, -0.4590,  0.8884, -0.1614,  0.2455,  0.9443,  0.5712,\n",
       "          0.5428,  0.4614,  0.9351, -0.4157, -0.2981,  0.0323, -0.7886,  0.1806,\n",
       "          0.1004,  0.1757,  0.5604,  0.6771, -0.2461, -0.2129, -0.0166,  0.6005,\n",
       "          0.3068,  0.2923,  0.3692, -0.2245,  0.4959, -0.7097,  0.0420,  0.0606,\n",
       "         -0.4929, -0.2373,  1.0154, -0.0411,  0.3987, -0.3174,  0.5170, -0.0652,\n",
       "          0.1006,  0.0375, -0.2662, -0.1975,  0.5350,  0.3197, -0.3700, -0.1069,\n",
       "          0.1515, -0.4960],\n",
       "        [-0.1718, -0.4029,  0.3226, -0.1947, -0.1145,  0.1973, -0.5035,  0.3778,\n",
       "          0.0404, -0.8930,  0.0659,  0.4297, -0.0245, -0.5159,  0.6800,  0.0622,\n",
       "          0.4711,  0.7019,  0.2534, -0.5523, -0.1248, -0.0698,  0.2973,  1.0863,\n",
       "         -0.2253,  0.0174,  0.3767,  0.0851,  0.4194, -0.3688,  0.1865, -0.0589,\n",
       "         -0.1353, -0.6923,  0.0129,  0.1565, -0.0464,  0.6523,  0.5230, -0.3694,\n",
       "          0.1056, -0.6518,  0.1374,  0.6418, -0.0802,  0.3025,  0.1512,  0.1003,\n",
       "          0.3821,  0.4084],\n",
       "        [-0.1274, -0.3291, -0.1433,  0.4283,  0.9860,  0.5722,  0.3211,  0.1918,\n",
       "          0.2281, -0.8137,  0.0231,  0.3854, -0.6830,  0.1476, -0.1677, -0.1527,\n",
       "          0.5892, -0.0512, -0.1151,  0.3548, -0.2324, -0.0571,  0.3374, -0.5236,\n",
       "         -0.3321, -0.0194,  0.0048, -0.0492,  0.9038, -0.2640, -0.2555, -0.4652,\n",
       "         -0.0292,  0.0072, -0.8647, -0.2221, -0.1226, -0.2718,  0.5178, -0.2548,\n",
       "         -0.0427,  0.1600,  1.2643,  0.4778, -0.4212, -0.0422, -0.3463,  0.6464,\n",
       "          0.9139, -0.2760],\n",
       "        [-0.0409, -0.3778, -0.4856,  0.2341,  0.3580,  1.1153,  0.3294, -0.0851,\n",
       "         -0.6427, -0.5557, -0.3196, -0.1944,  0.5471, -0.5779,  0.2130, -0.3020,\n",
       "          0.7285,  0.3924,  0.3753, -0.6417,  0.0725, -0.0953,  0.0710,  0.3430,\n",
       "          0.1703,  0.6404, -0.1702, -0.7841,  0.6691,  0.2382, -0.0274, -0.1555,\n",
       "          0.8308,  0.4261, -0.4564,  1.2519,  0.0537,  0.0956,  0.5645, -0.3281,\n",
       "         -0.0098,  0.3488, -0.0358, -0.7114,  0.1274, -0.6279, -0.2053, -0.1561,\n",
       "          0.0772, -0.2450],\n",
       "        [ 0.1426,  0.2593, -1.0020, -0.3097,  0.5228, -0.3321,  0.0959, -0.9085,\n",
       "          0.1837, -0.1541,  0.5556,  1.0916, -0.0959,  0.1240,  0.7333,  0.2997,\n",
       "          0.2534,  0.2609,  0.0936, -0.6518,  0.4791,  0.3741, -0.9455,  0.7149,\n",
       "          0.1096, -0.4465, -0.0725,  1.0563, -0.7372, -0.0636, -1.0722,  0.0679,\n",
       "         -0.0334,  0.1650,  0.0211, -0.4457, -0.2593, -0.1384,  0.2796,  1.1656,\n",
       "          0.9759, -0.6846, -0.3999, -0.0047,  0.1412,  0.4336,  0.2557, -0.0393,\n",
       "         -0.9777,  0.0217],\n",
       "        [ 0.4240, -0.6370,  0.6323, -0.2372,  0.4399,  0.2056, -0.8004, -0.5333,\n",
       "          0.3700, -0.2782, -0.5560, -0.1807,  0.4836,  0.7267,  0.0935, -0.9065,\n",
       "         -0.1880,  0.2324, -0.3246, -1.0854, -0.2035, -0.3904, -0.2599,  0.3371,\n",
       "          0.7560, -0.1532, -0.8644,  0.2724, -0.5440,  0.0433,  0.4430, -0.3394,\n",
       "          0.0474,  0.3031,  1.1410,  0.3018, -0.7239,  0.4206,  0.5204,  0.0093,\n",
       "          1.1950,  0.3422,  0.4641,  0.6674,  0.2474, -0.3846, -0.3379, -0.0028,\n",
       "         -0.4174,  0.2612]])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.prototypes.data.copy_(torch.tensor(real_trans_x, dtype=torch.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "7c2b53c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "87.60390675126541"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy after projection\n",
    "evaluate_loader(model, train_loader, cce_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "2c398935",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), 'weights/pw_net.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "684b4277",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Distances: [0.1341969519853592, 0.20840822160243988, 0.4398880898952484, 0.13887129724025726, 0.11193636804819107, 0.18724007904529572]\n",
      "Average: 0.2315952479839325\n"
     ]
    }
   ],
   "source": [
    "# Distance between learned prototoypes and projected ones\n",
    "print(\"Distances:\",  torch.sqrt(( (real_trans_x - trained_trans_x)**2 ).sum(axis=1)).tolist()  )\n",
    "print(\"Average:\", torch.sqrt(( (real_trans_x - trained_trans_x)**2 ).sum(axis=1).mean()).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dcec022",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2d0ca3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pong_env",
   "language": "python",
   "name": "pong_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
