{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "wa3UrNHqPG8V"
      },
      "outputs": [],
      "source": [
        "from numpy.core.fromnumeric import compress\n",
        "from numpy.lib.function_base import average\n",
        "import numpy as np\n",
        "import torch\n",
        "from torchvision import datasets\n",
        "from torchvision.transforms import ToTensor\n",
        "from torch.nn.utils import parameters_to_vector, vector_to_parameters\n",
        "from matplotlib import pyplot as plt\n",
        "import math\n",
        "import sys\n",
        "import pickle\n",
        "from datetime import datetime\n",
        "import os\n",
        "import random\n",
        "import copy\n",
        "from torch import optim\n",
        "from torch.autograd import Variable\n",
        "import torch.nn as nn\n",
        "import itertools\n",
        "\n",
        "def get_avg(x):\n",
        "  \"\"\"Return the running (prefix) mean of a sequence of numbers.\n",
        "\n",
        "  Element k of the result is mean(x[:k+1]), so the output has the same\n",
        "  length as x and its last entry is the mean of the whole sequence.\n",
        "  An empty input gives an empty list.\n",
        "  \"\"\"\n",
        "  avgs = []\n",
        "  total = 0\n",
        "  for i, value in enumerate(x, start=1):\n",
        "    # running sum: same additions in the same order as sum(x[:i]),\n",
        "    # but O(n) overall instead of re-summing every prefix\n",
        "    total += value\n",
        "    avgs.append(total / i)\n",
        "  return avgs\n",
        "\n",
        "#compute binary KL divergence (in bits)\n",
        "def KL(p,q):\n",
        "  \"\"\"Binary KL divergence kl(p || q), in bits, between Bernoulli(p) and Bernoulli(q).\n",
        "\n",
        "  Args:\n",
        "    p: probability in [0, 1]; the training loop passes the accuracy on the\n",
        "       samples seen so far in the epoch.\n",
        "    q: probability in [0, 1]; the training loop passes the current batch accuracy.\n",
        "\n",
        "  Returns:\n",
        "    A non-negative float. Uses the convention 0*log(0) = 0, so p in {0, 1}\n",
        "    with q in (0, 1) gives the finite limit value. When q is exactly 0 or 1,\n",
        "    kl(p || q) is infinite unless p == q, so the reverse divergence\n",
        "    kl(q || p) is returned instead to keep running sums finite. If one of\n",
        "    p, q is 0 and the other is 1, both directions are infinite and\n",
        "    math.inf is returned.\n",
        "  \"\"\"\n",
        "  if q == 0 or q == 1:\n",
        "    if p == 1 - q:\n",
        "      # Bernoulli(0) vs Bernoulli(1): infinite in both directions\n",
        "      return math.inf\n",
        "    # reverse divergence kl(q || p); finite because p is not the opposite endpoint\n",
        "    if q == 0:\n",
        "      return math.log(1/(1-p), 2)\n",
        "    return math.log(1/p, 2)\n",
        "  div = 0.0\n",
        "  if p > 0:\n",
        "    div += p*math.log(p/q, 2)\n",
        "  if p < 1:\n",
        "    div += (1-p)*math.log((1-p)/(1-q), 2)\n",
        "  return div\n",
        "\n",
        "#Get accuracy of model on data\n",
        "def getAcc(net, x, labels, get_preds = False):\n",
        "    \"\"\"Return the classification accuracy of `net` on inputs `x`.\n",
        "\n",
        "    The forward pass runs in eval mode under torch.no_grad(); the module's\n",
        "    previous train/eval mode is restored afterwards (even if the forward\n",
        "    pass raises) instead of always being forced back to train mode.\n",
        "\n",
        "    Args:\n",
        "      net: torch.nn.Module whose output holds per-class scores along dim 1.\n",
        "      x: input batch accepted by `net`.\n",
        "      labels: class indices, one per sample, on the same device as the output.\n",
        "      get_preds: if True, also return the per-sample correctness mask.\n",
        "\n",
        "    Returns:\n",
        "      accuracy as a float in [0, 1], or (accuracy, correct_mask) when\n",
        "      get_preds is True.\n",
        "    \"\"\"\n",
        "    was_training = net.training\n",
        "    net.eval()\n",
        "    try:\n",
        "      with torch.no_grad():\n",
        "        output = net(x)\n",
        "        # index of the highest score along the class dimension = predicted label\n",
        "        pred_y = torch.max(output, 1)[1].squeeze()\n",
        "        correct = pred_y == labels\n",
        "        accuracy = correct.sum().item() / float(labels.size(0))\n",
        "    finally:\n",
        "      net.train(was_training)\n",
        "    if get_preds is True:\n",
        "      return accuracy, correct\n",
        "    return accuracy\n",
        "\n",
        "\n",
        "#plot accuracy discrepancy (left axis) and accuracy in percent (right axis)\n",
        "def plot_two(x,y, batch_size, num_hidden):\n",
        "  \"\"\"Draw two per-epoch curves on one figure with twin y axes.\n",
        "\n",
        "  x is drawn in black against the left axis (discrepancy); y holds fractions\n",
        "  that are scaled by 100 and drawn in red against the right axis (accuracy %).\n",
        "  The title records batch_size and num_hidden. Nothing is returned.\n",
        "  \"\"\"\n",
        "  percent = [v*100 for v in y]\n",
        "  label_size = 13\n",
        "\n",
        "  fig, left = plt.subplots()\n",
        "  right = left.twinx()\n",
        "\n",
        "  left.title.set_text(\"batch size: %d, num hidden: %s\" % (batch_size, num_hidden))\n",
        "  left.set_xlabel('Epoch', fontsize=label_size)\n",
        "  left.set_ylabel('Average accuracy discrepancy', color='black', fontsize=label_size)\n",
        "  right.set_ylabel('Accuracy (%)', color='red', fontsize=label_size)\n",
        "\n",
        "  left.plot(x, color='black')\n",
        "  right.plot(percent, color='red')\n",
        "\n",
        "#Fully connected network: 784 inputs -> hidden_size GELU units -> 10 class scores\n",
        "class FNN(nn.Module):\n",
        "    \"\"\"One-hidden-layer MLP over flattened 28x28 inputs.\n",
        "\n",
        "    Both weight matrices are re-initialised with torch.nn.init.uniform_\n",
        "    (its default range U(0, 1)); biases keep nn.Linear's default init.\n",
        "    \"\"\"\n",
        "    def __init__(self, hidden_size = 10):\n",
        "        super().__init__()\n",
        "        n_in, n_out = 784, 10\n",
        "        # registration order fc1, gelu, fc2 fixes the printed layout and RNG draw order\n",
        "        self.fc1 = nn.Linear(n_in, hidden_size)\n",
        "        self.gelu = nn.GELU()\n",
        "        self.fc2 = nn.Linear(hidden_size, n_out)\n",
        "        for layer in (self.fc1, self.fc2):\n",
        "            torch.nn.init.uniform_(layer.weight)\n",
        "\n",
        "    def forward(self, x):\n",
        "        # flatten each image to a 784-vector; -1 infers the batch size\n",
        "        flat = x.reshape(-1, 28 * 28)\n",
        "        return self.fc2(self.gelu(self.fc1(flat)))\n",
        "\n",
        "# use the GPU when present, otherwise fall back to the CPU\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#experiment params\n",
        "batch_size = 100 # in the paper we use 50, 100, 200\n",
        "num_hidden = 5 # in the paper we use 2,5,10\n",
        "num_epochs = 300 # in the paper we use 300\n",
        "rand_labels = True # random label flag\n",
        "seed = 0 # fixed seed: random labels, batch shuffling and weight init become reproducible\n",
        "\n",
        "random.seed(seed)\n",
        "np.random.seed(seed)\n",
        "torch.manual_seed(seed)\n",
        "\n",
        "train_data = datasets.MNIST(\n",
        "    root = 'data',\n",
        "    train = True,\n",
        "    transform = ToTensor(),\n",
        "    download = True\n",
        ")\n",
        "n = len(train_data) # number of training samples (60000 for MNIST)\n",
        "\n",
        "if rand_labels:\n",
        "  # replace every label with a uniformly random class in [0, 10)\n",
        "  train_data.targets = torch.randint(10, train_data.targets.shape)\n",
        "from torch.utils.data import DataLoader\n",
        "loaders = {\n",
        "    'train' : DataLoader(train_data,\n",
        "                         batch_size=batch_size,\n",
        "                         shuffle=True,\n",
        "                         num_workers=0)\n",
        "}\n",
        "\n",
        "net = FNN(num_hidden)\n",
        "net = net.to(device)\n",
        "net.train()\n",
        "print(net)\n",
        "loss_func = nn.CrossEntropyLoss()\n",
        "\n",
        "learning_rate = 0.01\n",
        "optimizer = optim.SGD(net.parameters(), lr = learning_rate, momentum=0)\n",
        "\n",
        "epoch_accs = []   # accuracy on the full training set after each epoch\n",
        "batch_accs = []   # per epoch: accuracy on each batch right after its SGD step\n",
        "pref_accs = []    # per epoch: accuracy on all samples seen so far in that epoch\n",
        "epoch_discs = []  # per epoch: sum over steps of KL(prefix acc, batch acc)\n",
        "X = None  # full training set in epoch-0 order, captured once\n",
        "Y = None\n",
        "# Buffers holding the samples seen so far in the current epoch. Filling them\n",
        "# in place avoids re-copying the whole prefix with torch.cat at every step.\n",
        "X_buf = None\n",
        "Y_buf = None\n",
        "for epoch in range(num_epochs):\n",
        "    seen = 0\n",
        "    batch_accs.append([])\n",
        "    pref_accs.append([])\n",
        "    prog_disc = 0\n",
        "\n",
        "    for i, (images, labels) in enumerate(loaders['train']):\n",
        "      images = images.to(device)\n",
        "      labels = labels.to(device)\n",
        "      if X_buf is None:\n",
        "          X_buf = images.new_empty((n,) + tuple(images.shape[1:]))\n",
        "          Y_buf = labels.new_empty((n,))\n",
        "      end = seen + images.size(0)\n",
        "      X_buf[seen:end] = images\n",
        "      Y_buf[seen:end] = labels\n",
        "      seen = end\n",
        "      Xf = X_buf[:seen]  # view, no copy\n",
        "      Yf = Y_buf[:seen]\n",
        "\n",
        "      optimizer.zero_grad()\n",
        "      output = net(images)\n",
        "      loss = loss_func(output, labels)\n",
        "      loss.backward()\n",
        "      optimizer.step()\n",
        "\n",
        "      # both accuracies are measured after the update\n",
        "      batch_accs[-1].append(getAcc(net, images, labels))\n",
        "      pref_accs[-1].append(getAcc(net, Xf, Yf))\n",
        "      prog_disc += KL(pref_accs[-1][-1], batch_accs[-1][-1])\n",
        "\n",
        "    if X is None:\n",
        "      # clone: the buffers are overwritten during the next epoch\n",
        "      X = Xf.clone().detach()\n",
        "      Y = Yf.clone().detach()\n",
        "    train_acc = getAcc(net, X, Y)\n",
        "\n",
        "    epoch_accs.append(train_acc)\n",
        "    epoch_discs.append(prog_disc)\n",
        "    print('epoch', epoch, 'disc', epoch_discs[-1], 'train acc', epoch_accs[-1])\n",
        "\n",
        "plot_two(get_avg(epoch_discs), get_avg(epoch_accs), batch_size, num_hidden)\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "IyK4I0DPQahE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "OiINygDwZ_fo"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}