{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "99333bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "import numpy as np  \n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.cluster import KMeans, DBSCAN, OPTICS\n",
    "\n",
    "from random import sample\n",
    "from tqdm import tqdm\n",
    "from time import sleep\n",
    "from model import DQN_Agent\n",
    "from collections import Counter\n",
    "from scipy.spatial.distance import euclidean\n",
    "from copy import deepcopy\n",
    "from torch.utils.data import TensorDataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b0aa2dd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration -- every tunable constant lives here\n",
    "NUM_CLASSES = 2          # discrete CartPole actions: 0 = move left, 1 = move right\n",
    "NUM_PROTOTYPES = 2       # one prototype per action\n",
    "LATENT_SIZE = 64         # size of the DQN latent vector that PW-Net consumes\n",
    "PROTOTYPE_SIZE = 16      # size of the space each prototype is compared in\n",
    "BATCH_SIZE = 64\n",
    "NUM_EPOCHS = 50\n",
    "DEVICE = 'cpu'\n",
    "MAX_SAMPLES = 100000     # upper bound on points drawn by the optional VIPER resampling\n",
    "SEED = 0\n",
    "\n",
    "# Seed the RNGs used below (PWNet weight init, DataLoader shuffling) so reruns are reproducible\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e8844f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def viper_sample(obss, acts, qs, xs, is_reweight=True):\n",
    "\t\n",
    "# \t\"\"\"\n",
    "# \tSubsample the training data as in VIPER; adapted from https://github.com/obastani/viper\n",
    "# \n",
    "# \tobss: environment observations\n",
    "# \tacts: agent actions\n",
    "# \tqs: Q-values, one row per observation\n",
    "# \txs: latent activations of the agent\n",
    "# \tis_reweight: if True, sample (with replacement) with probability proportional to\n",
    "# \t\tmax(Q) - min(Q); otherwise sample uniformly without replacement.\n",
    "# \tAt most MAX_SAMPLES points are returned.\n",
    "# \t\"\"\"\n",
    "\t\n",
    "# \t# Step 1: Compute probabilities\n",
    "# \tps = np.max(qs, axis=1) - np.min(qs, axis=1)\n",
    "# \tps = ps / np.sum(ps)\n",
    "\n",
    "# \t# Step 2: Sample points\n",
    "# \tif is_reweight:\n",
    "# \t\t# According to p(s)\n",
    "# \t\tidx = np.random.choice(len(obss), size=min(MAX_SAMPLES, np.sum(ps > 0)), p=ps)\n",
    "# \telse:\n",
    "# \t\t# Uniformly (without replacement)\n",
    "# \t\tidx = np.random.choice(len(obss), size=min(MAX_SAMPLES, np.sum(ps > 0)), replace=False)    \n",
    "\n",
    "# \t# Step 3: Obtain sampled indices\n",
    "# \treturn obss[idx], acts[idx], qs[idx], xs[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "96dd86d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CartPole without gym's wrappers (.unwrapped), plus the pretrained DQN agent whose\n",
    "# latent layer PW-Net is trained on\n",
    "env = gym.make(\"CartPole-v1\").unwrapped\n",
    "input_dim, output_dim = env.observation_space.shape[0], env.action_space.n\n",
    "exp_replay_size = 256\n",
    "\n",
    "agent = DQN_Agent(\n",
    "    seed=1423,\n",
    "    layer_sizes=[input_dim, 64, output_dim],\n",
    "    lr=1e-3,\n",
    "    sync_freq=5,\n",
    "    exp_replay_size=exp_replay_size,\n",
    ")\n",
    "agent.load_pretrained_model(\"weights/cartpole-dqn.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8d544585",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-collected dataset: latent vectors X, agent actions a, raw observations obs\n",
    "X_train = np.load('data/X_train.npy').reshape(-1, LATENT_SIZE)\n",
    "a_train = np.load('data/a_train.npy').flatten()\n",
    "obs_train = np.load('data/obs_train.npy').reshape(-1, 4)  # 4 = CartPole observation size\n",
    "\n",
    "# Actions are used as class indices below, so they must lie in [0, NUM_CLASSES)\n",
    "assert np.isin(a_train, np.arange(NUM_CLASSES)).all(), 'a_train contains values outside [0, NUM_CLASSES)'\n",
    "\n",
    "# Optional VIPER resampling (needs Q-values and the viper_sample cell above, currently commented out):\n",
    "# q_train = np.load('data/q_train.npy').reshape(-1, 2)\n",
    "# obs_train, a_train, q_train, X_train = viper_sample(obs_train, a_train, q_train, X_train, is_reweight=True)\n",
    "\n",
    "# Explicit dtypes: float32 features and int64 class-index targets, as NLL/cross-entropy losses require\n",
    "tensor_x = torch.tensor(X_train, dtype=torch.float32)\n",
    "tensor_y = torch.tensor(a_train, dtype=torch.long)\n",
    "\n",
    "train_dataset = TensorDataset(tensor_x, tensor_y)\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "0fdc0eff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_loader(model, loader, cce_loss, intuition_labels=False):\n",
    "    \"\"\"Score `model` on every batch of `loader` without tracking gradients.\n",
    "\n",
    "    Args:\n",
    "        model: module mapping an input batch to per-class scores (argmax = prediction).\n",
    "        loader: iterable of (inputs, labels) batches.\n",
    "        cce_loss: callable loss(scores, labels) returning a scalar tensor.\n",
    "        intuition_labels: reserved; relabelling batches from observations is not implemented.\n",
    "\n",
    "    Returns:\n",
    "        (accuracy in percent, mean loss per batch)\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `intuition_labels` is True.\n",
    "        ValueError: if `loader` yields no samples.\n",
    "\n",
    "    Note: leaves `model` in eval mode; callers switch back with `model.train()`.\n",
    "    \"\"\"\n",
    "    if intuition_labels:\n",
    "        raise NotImplementedError('intuition_labels=True is not implemented')\n",
    "\n",
    "    model.eval()\n",
    "    total_correct = 0\n",
    "    total_loss = 0.0\n",
    "    total = 0\n",
    "    num_batches = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in loader:\n",
    "            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)\n",
    "            scores = model(inputs)\n",
    "            loss = cce_loss(scores, labels)\n",
    "            preds = torch.argmax(scores, dim=1)\n",
    "            # Tensor reduction instead of Python sum() over individual elements\n",
    "            total_correct += (preds == labels).sum().item()\n",
    "            total += len(preds)\n",
    "            total_loss += loss.item()\n",
    "            num_batches += 1\n",
    "\n",
    "    if total == 0:\n",
    "        raise ValueError('loader yielded no samples')\n",
    "    # Batches are counted here, so loaders without __len__ also work\n",
    "    return (total_correct / total) * 100, total_loss / num_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "40a1c504",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ListModule(object):\n",
    "    \"\"\"Index-addressable list of sub-modules registered on a host nn.Module.\n",
    "\n",
    "    Each appended module is attached to the host with `add_module` under the name\n",
    "    `prefix + str(index)`, so the host owns (and exposes) its parameters.\n",
    "    `__getitem__` raises IndexError past the end, which also makes the object iterable.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, module, prefix, *args):\n",
    "        self.module = module      # host nn.Module the children are registered on\n",
    "        self.prefix = prefix      # attribute-name prefix, e.g. 'ts_'\n",
    "        self.num_module = 0       # number of children registered so far\n",
    "        for child in args:\n",
    "            self.append(child)\n",
    "\n",
    "    def _child_name(self, index):\n",
    "        return self.prefix + str(index)\n",
    "\n",
    "    def append(self, new_module):\n",
    "        if isinstance(new_module, nn.Module):\n",
    "            self.module.add_module(self._child_name(self.num_module), new_module)\n",
    "            self.num_module += 1\n",
    "            return\n",
    "        raise ValueError('Not a Module')\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_module\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        if not 0 <= i < self.num_module:\n",
    "            raise IndexError('Out of bound')\n",
    "        return getattr(self.module, self._child_name(i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "35811975",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PWNet(nn.Module):\n",
    "    \"\"\"Prototype-wrapper network over the DQN agent's latent space.\n",
    "\n",
    "    Prototype i has its own projection `ts_i` (LATENT_SIZE -> PROTOTYPE_SIZE). The\n",
    "    prototype itself is `ts_i` applied to the anchor latent vector `nn_human_x[i]`\n",
    "    (frozen: requires_grad=False). An input is projected with the same `ts_i`, and\n",
    "    its squared L2 distance d to the prototype becomes the similarity\n",
    "    log((d + 1) / (d + epsilon)). A fixed prototype-to-class linear layer maps the\n",
    "    similarities to class scores.\n",
    "\n",
    "    `forward` returns softmax PROBABILITIES, not logits: train it with NLL on\n",
    "    log-probabilities rather than nn.CrossEntropyLoss (which re-applies softmax).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(PWNet, self).__init__()\n",
    "        # Registers the projections as sub-modules ts_0, ts_1, ... of this network\n",
    "        self.ts = ListModule(self, 'ts_')\n",
    "        for i in range(NUM_PROTOTYPES):\n",
    "            transformation = nn.Sequential(\n",
    "                nn.Linear(LATENT_SIZE, PROTOTYPE_SIZE),\n",
    "                # NOTE(review): on 2-D (rows, PROTOTYPE_SIZE) input InstanceNorm1d appears to treat the\n",
    "                # tensor as unbatched (C, L), i.e. it normalises each row over its features -- confirm intent\n",
    "                nn.InstanceNorm1d(PROTOTYPE_SIZE),\n",
    "                nn.ReLU(),\n",
    "                nn.Linear(PROTOTYPE_SIZE, PROTOTYPE_SIZE),\n",
    "            )\n",
    "            self.ts.append(transformation)\n",
    "        # Optional cache of projected prototypes (indexed by prototype). Nothing in this\n",
    "        # class sets it, so while it is None forward() recomputes them from nn_human_x.\n",
    "        self.prototypes = None\n",
    "        self.epsilon = 1e-5  # keeps the similarity finite when the distance is 0\n",
    "        self.linear = nn.Linear(NUM_PROTOTYPES, NUM_CLASSES, bias=False)\n",
    "        self.__make_linear_weights()\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "        # Anchor latent vectors. A frozen Parameter, so they are saved in state_dict\n",
    "        # but never updated by the optimizer. Random until overwritten by the caller.\n",
    "        self.nn_human_x = nn.Parameter(torch.randn(NUM_PROTOTYPES, LATENT_SIZE), requires_grad=False)\n",
    "\n",
    "    def __make_linear_weights(self):\n",
    "        \"\"\"Set linear.weight so each prototype feeds only its own class.\n",
    "\n",
    "        Prototype j belongs to class j // (NUM_PROTOTYPES // NUM_CLASSES). That\n",
    "        connection gets weight 1; every other connection gets `incorrect_strength` (0).\n",
    "        \"\"\"\n",
    "        prototype_class_identity = torch.zeros(NUM_PROTOTYPES, NUM_CLASSES)\n",
    "        num_prototypes_per_class = NUM_PROTOTYPES // NUM_CLASSES\n",
    "        for j in range(NUM_PROTOTYPES):\n",
    "            prototype_class_identity[j, j // num_prototypes_per_class] = 1\n",
    "        positive_one_weights_locations = torch.t(prototype_class_identity)\n",
    "        negative_one_weights_locations = 1 - positive_one_weights_locations\n",
    "        incorrect_strength = 0.0\n",
    "        correct_class_connection = 1\n",
    "        incorrect_class_connection = incorrect_strength\n",
    "        self.linear.weight.data.copy_(\n",
    "            correct_class_connection * positive_one_weights_locations\n",
    "            + incorrect_class_connection * negative_one_weights_locations)\n",
    "\n",
    "    def __proto_layer_l2(self, x, p):\n",
    "        \"\"\"Similarity of every row of `x` to the single prototype `p`.\n",
    "\n",
    "        `x` is viewed as (batch, PROTOTYPE_SIZE) and `p` as (1, PROTOTYPE_SIZE).\n",
    "        Returns a (batch,) tensor log((d + 1) / (d + epsilon)), where d is the squared\n",
    "        L2 distance. It is largest at d == 0 and decreases towards 0 as d grows.\n",
    "        \"\"\"\n",
    "        b_size = x.shape[0]\n",
    "        c = x.view(b_size, PROTOTYPE_SIZE).to(DEVICE)\n",
    "        p = p.view(1, PROTOTYPE_SIZE).to(DEVICE)  # broadcasts over the batch; no tile() copy\n",
    "        l2s = ((c - p) ** 2).sum(dim=1)\n",
    "        return torch.log((l2s + 1.) / (l2s + self.epsilon))\n",
    "\n",
    "    def __output_act_func(self, p_acts):\n",
    "        return self.softmax(p_acts)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map latent inputs `x` of shape (batch, LATENT_SIZE) to class probabilities.\"\"\"\n",
    "        if self.prototypes is None:\n",
    "            # Project each anchor with its own transformation. Indexing the Parameter\n",
    "            # directly (instead of torch.tensor(...)) avoids a copy and the\n",
    "            # \"copy construct from a tensor\" UserWarning.\n",
    "            latent_protos = torch.cat(\n",
    "                [t(self.nn_human_x[i].view(1, -1)) for i, t in enumerate(self.ts)], dim=0)\n",
    "        else:\n",
    "            latent_protos = self.prototypes\n",
    "\n",
    "        # One similarity column per prototype -> (batch, NUM_PROTOTYPES)\n",
    "        p_acts = torch.cat(\n",
    "            [self.__proto_layer_l2(t(x), latent_protos[i]).view(-1, 1)\n",
    "             for i, t in enumerate(self.ts)], dim=1)\n",
    "\n",
    "        logits = self.linear(p_acts)\n",
    "        return self.__output_act_func(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "bf2441ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Human-defined concepts: one prototype per action (concept name -> action id).\n",
    "# Only descriptive here; the prototype latent vectors are chosen from data below.\n",
    "human_concepts = {'move_left': [0.], 'move_right': [1.]}\n",
    "human_concepts_list = np.array(list(human_concepts.values()))\n",
    "\n",
    "# Anchor each class prototype to the training sample whose latent vector is closest\n",
    "# (Euclidean) to the mean latent vector of that class.\n",
    "assert NUM_PROTOTYPES == NUM_CLASSES, 'this cell builds exactly one prototype per class'\n",
    "p_idxs = list()      # index of each anchor WITHIN its class subset X_train[a_train == i]\n",
    "nn_human_x = list()  # anchor latent vectors, one row per class\n",
    "\n",
    "for i in range(NUM_CLASSES):\n",
    "    class_x = X_train[a_train == i]\n",
    "    if len(class_x) == 0:\n",
    "        raise ValueError(f'no training samples with action {i}; cannot build its prototype')\n",
    "    class_mean = class_x.mean(axis=0)\n",
    "    # Exact nearest neighbour; no need to fit a KNN classifier for a single query\n",
    "    nearest = int(np.argmin(np.linalg.norm(class_x - class_mean, axis=1)))\n",
    "    p_idxs.append(nearest)\n",
    "    nn_human_x.append(class_x[nearest].tolist())\n",
    "nn_human_x = np.array(nn_human_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44f92be5",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d24f6e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Training\n",
    "model = PWNet()\n",
    "model.nn_human_x.data.copy_(torch.as_tensor(nn_human_x, dtype=torch.float32))\n",
    "\n",
    "# Freeze the prototype->class layer so each prototype only ever votes for its own action\n",
    "model.linear.weight.requires_grad = False\n",
    "\n",
    "# PWNet.forward returns softmax probabilities. nn.CrossEntropyLoss expects raw logits and\n",
    "# would apply softmax a second time, which bounds the loss away from 0 and weakens the\n",
    "# gradients. NLL on log-probabilities is the intended cross-entropy; the clamp guards log(0).\n",
    "nll_loss = nn.NLLLoss()\n",
    "\n",
    "def cce_loss(probs, labels):\n",
    "    return nll_loss(torch.log(probs.clamp_min(1e-12)), labels)\n",
    "\n",
    "# Only optimise trainable tensors (the frozen linear layer and anchors are excluded)\n",
    "optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.01)\n",
    "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)\n",
    "best_acc = 0.\n",
    "loss_data = list()  # per-batch training loss, for later inspection\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    running_loss = 0.\n",
    "    for instances, labels in train_loader:\n",
    "        instances, labels = instances.to(DEVICE), labels.to(DEVICE)\n",
    "        optimizer.zero_grad()\n",
    "        probs = model(instances)\n",
    "        loss = cce_loss(probs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        batch_loss = loss.item()\n",
    "        loss_data.append(batch_loss)\n",
    "        running_loss += batch_loss\n",
    "    scheduler.step()\n",
    "\n",
    "    # Evaluate AFTER this epoch's updates, so the printed accuracy and the saved\n",
    "    # checkpoint describe the same weights and the final epoch is not skipped.\n",
    "    # NOTE: this is training-set accuracy; there is no held-out split.\n",
    "    current_acc = evaluate_loader(model, train_loader, cce_loss)[0]\n",
    "    if current_acc > best_acc:\n",
    "        torch.save(model.state_dict(), 'weights/pw_net.pth')\n",
    "        best_acc = current_acc\n",
    "\n",
    "    print(\"Epoch:\", epoch)\n",
    "    print(\"Running Loss:\", running_loss / len(train_loader))\n",
    "    print(\"Acc.:\", current_acc)\n",
    "    print(\" \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df06ec2c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "8923e932",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                    | 0/30 [00:00<?, ?it/s]/var/folders/jy/977g91sd4p56btqf1lyy890w0000gn/T/ipykernel_5104/2813415531.py:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  trans_nn_human_x.append( t( torch.tensor(self.nn_human_x[i], dtype=torch.float32).view(1, -1)) )\n",
      "100%|███████████████████████████████████████████| 30/30 [00:02<00:00, 10.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average reward per episode : 200.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "NUM_EVAL_EPISODES = 30\n",
    "MAX_EPISODE_STEPS = 200  # .unwrapped drops gym's time limit, so episodes are capped here\n",
    "\n",
    "model = PWNet().eval()\n",
    "model.load_state_dict(torch.load('weights/pw_net.pth'))\n",
    "\n",
    "accuracy = list()  # per step: did PW-Net choose the same action as the DQN agent?\n",
    "reward_arr = []    # return (sum of rewards) of each episode\n",
    "\n",
    "for episode in tqdm(range(NUM_EVAL_EPISODES)):\n",
    "    obs, done, rew = env.reset(), False, 0\n",
    "    count = 0\n",
    "    while not done and count < MAX_EPISODE_STEPS:\n",
    "        AgentAction, latent_x, _ = agent.get_action(obs, env.action_space.n, epsilon=0)\n",
    "        with torch.no_grad():  # inference only; don't build an autograd graph every step\n",
    "            A = torch.argmax(model(latent_x.view(1, -1))).item()\n",
    "\n",
    "        obs, reward, done, info = env.step(A)\n",
    "        rew += reward\n",
    "        count += 1\n",
    "        accuracy.append(AgentAction.item() == A)\n",
    "        # env.render()\n",
    "    reward_arr.append(rew)\n",
    "\n",
    "print(\"average reward per episode :\", sum(reward_arr) / len(reward_arr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "7f5336f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9916666666666667"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of environment steps on which PW-Net and the DQN agent chose the same action\n",
    "agreement_rate = sum(accuracy) / len(accuracy)\n",
    "agreement_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ff875a0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "969ccf07",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pong_env",
   "language": "python",
   "name": "pong_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
