{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "# %load_ext lab_black\n",
    "\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "sys.path.insert(1, f\"{os.getcwd()}/backend\")\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set()\n",
    "sns.set_context(\"paper\")\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras.backend as K\n",
    "\n",
    "from tf_backend.tf_models import BigConvNet, OneLayerNet\n",
    "from tf_backend.tf_losses import pinball_loss_with_scores_keras, dependent_label_quantile_loss_keras, pinball_loss_keras\n",
    "from tf_backend.tf_utils import *\n",
    "from tf_backend.tf_metrics import *\n",
    "from tf_backend.tf_constraints import *\n",
    "\n",
    "import np_backend.conformal_utils as cf_utils\n",
    "\n",
    "from cifar10_experiment import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which dataset to run: \"cifar10\" (CIFAR-10 + CIFAR-10.1) or \"mnist\" (MNIST + QMNIST).\n",
    "my_dataset = \"mnist\"\n",
    "# Any other value would silently fall into the MNIST branches below.\n",
    "assert my_dataset in {\"cifar10\", \"mnist\"}, f\"unknown dataset: {my_dataset!r}\"\n",
    "\n",
    "data_folder = \"datasets/CIFAR10\"\n",
    "\n",
    "path_to_scratch = os.getcwd()\n",
    "# Use os.path.join: plain concatenation dropped the separator and produced\n",
    "# e.g. \"/home/userdatasets/CIFAR10\".\n",
    "path_to_data = os.path.join(path_to_scratch, data_folder)\n",
    "\n",
    "n_classes = 10\n",
    "\n",
    "# CIFAR-10 class names indexed by integer label (only meaningful for \"cifar10\").\n",
    "class_names = [\n",
    "    \"airplane\",\n",
    "    \"automobile\",\n",
    "    \"bird\",\n",
    "    \"cat\",\n",
    "    \"deer\",\n",
    "    \"dog\",\n",
    "    \"frog\",\n",
    "    \"horse\",\n",
    "    \"ship\",\n",
    "    \"truck\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Standard CIFAR-10 (or MNIST) Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(5, shape=(), dtype=uint8)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMkAAADHCAYAAACzzHd1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADsFJREFUeJzt3X2UVPV9x/H37uzyvC7ZxSdECEX8EUPUVNRwoOBDasDWHk0C1cbEEzRCzSltcuyhEqsYQSjx2CBNQ9Rg1cQaMSbnCEoagqvkJGw2npRUxS+K1hAeFJYHed6d3ekfc2fcsr/lN1xmZmdnP69/uPPlzv39ZuGzd+6dO99bkUqlEJGuVXb3BERKnUIiEqCQiAQoJCIBColIgEIiEqCQiAQoJCIBColIgEIiElAV94nOua8AXwZagBlm9nYOT9M1MFJKKnJZKVZInHN1wC3ABOCTwCJgei7PTVQPBaBx/Qtc+qmpcYaPTWNqzIy21m05rxv37dYlQIOZJc2sCXAxtyNS8uK+3aoD9nR4nHPYGte/AMDHxozOLheLxtSYccQNyR7g/A6P23J9Ymb32NN2zxqzvMY8kbdbcUPSCNzlnEsAFwBvxtyOSMmLFRIz2+2cewxYB7QCN+d1ViIlJPYpYDNbBizL41xESpI+TBQJUEhEAhQSkQCFRCRAIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkQCFRCQgdiMIyZ9EZcJbP21AbaztVVdWceaguuzjH/cf7V2vZsDRTrVhV/nbNV+8Yqe33nhlDQCDBvej+XMfNvIcuPQR7/rtB/Z46xsmzPfWx+/8jbdeTNqTiAScTFf5g0BT9HCRma3Oz5RESsvJvN16x8wuy9dEREpVRSoV75Yhzrl9wAZgCzDbzJpzeFrqt69sANLNjje+UdzuqKU6ZkUXt8mo6uJYJWS0+xPe7HC7mFGVfb3rJSo7/9v3OcX//2HjnqS3PqYmPcfEWSNo2/putl45fIR/cu3+ttGH39juHzd50L8dTu7fc9xFF0CO9yc5mZAMMbNdzrkZwCVmNiuHp6V0f5LO8n3gvmrtj/iLK/46+7goB+4LlnHgGx/+FyjGgXseGmYX7iY+AGa2K1p8CpgZdzs9xXl1w731/pV9vPUvVaXXH15dw5IzrgBg2jlbvOv2G1Ht3/ai75zoNAGoqh/B2xt/HOu5ba+v89abWp721vvNXwpAZf0IBnz7w6637c1/9K6ffP4H3vqy6nh7zWKIdXbLOTcw6igPMBl1lZcyFndPMgZ4xDn3AXCU9K3hRMpS3FsvvEL6XokiZU8fJooEKCQiAbp26xifOeNCb/3ZVbO99crTRh53e1X1Q7m16Z9Pel4F0dbaqfTwTQ3eVZsrh/i38fN7ALj1ufk8dM092fLrKf/nG1uTH3jrjTvtOBPtXtqTiAQoJCIBColIgEIiEqCQiATo7NYxfn/gD956astG/xMCZ7cKIbliibfe/scdACT+bjEtS+/I1qtnzfWunzq8v1PtazvWxprTNa37mb+tIdZzS532JCIBColIgEIiEqCQiAQoJCIBOrt1jO0HdnvrC25r8tZnn/sTb73h1WEAfHr1fNZMSV/TdO3v7j6huSSf938Fdtjcn3vr+46kr5dqnH6QSx9Yn61P+uEd3vWfGt1yQvPprbQnEQlQSEQCFBKRAIVEJEAhEQnQ2a0cLdzW4K0/tPcUb3334c0ArE/u52+aXwJg6w23etf9yH8+5K1/d56/q2HmLFauXn7vNW996HsntJleKxgS51w10AB8HLjFzJ5xzg0BngBqgDVmNq+QkxTpTrm83UoCnwe+3aE2B1huZhOBi51z5xViciKlIBgSM0uZ2bH7/YnAymh5JTAp3xMTKRVxj0kGmtnhaHkvkPOXKhrXvwCkO4JnloulEGN2
1fk9GXVP/9iY0az/9fMA1A3qort7vb8D+xdWftNbn9z6j8edU7n8bEtlzLghOeSc62dmR4BawH8th0emC3ipdngv5JgbR53vXeejDf4vUdU810Wn9Xte9NbbU+2dxiyWnjZm1FU+J3FPAa8Dro6Wp0aPRcpSTnsS59zTwDjggHPuEmAx8Lhz7uvAWjPzn2MUKQM5hcTMpnvKxd23inQTfeIuEqCQiATospQi+tTWzd76lu/7z2JV33ynt37rw53vdQiwbOsv401Mjkt7EpEAhUQkQCERCVBIRAIUEpEAnd0qoq6+LDX53/7XW//ldf7ri+5ffpW3/s1/TTfMrj2lL+9POSdbf7FpmHf963e/1KmWSqW86/Zm2pOIBCgkIgEKiUiAQiISoJCIBOjsVgn43S7/NV33TX3YW5/7/Axvvebh9PqV9SOyywB/1cW4/3VB52vDZrW/7V13815/e6PeQHsSkQCFRCRAIREJUEhEAhQSkQCd3SphC7Y1eOu/+swub/2nMz4CQGLmfbR8765svc9t/iZ3Ezd0/kZkw2dnede98i3/79NNe7Z66+UkbsPsecA0YCeww8yuL+QkRbpTLnuSTMPsmcfU7zazZ/I/JZHSErdhNsCdzrl1zrkbCjAvkZJRkev3B6K3WK9Gb7fqzazZOVcLrAWuNbMtOWwm9dtXNgDpZscb33gz5rTjKZcxa6r6e+vnDEk376489Szad354rFBx2lk5bzv51rve+qajbd76kbZWoOf9bMdddAFARS7rxjpwN7Pm6M99zrlfAGOBXELSqxtm58vlp4/11jMH7v1n3sfh783N1rs6cPfZdbP/wH3aWwe89cyBe0/72Z5Iw+xYIXHO1UYBqQLGA4/E2Y7E8+J7r3rrox8cDMDq6w4z5cH/ydbvePRe7/qzfvONTrUhzy7zrtu0bJ63XrtAZ7cAb8PsuujuVgngSTPbVMA5inSrk2mYLdIr6BN3kQCFRCRAIREJ0LVbZeT9g3sBaG1vyy4DfO3gWu/6s1pv71zs4/8Mps+MOd76TY+2ADCkeiA3DR2frT+27dc5zbkn0J5EJEAhEQlQSEQCFBKRAB2490BXnP4Jb/1fKvoC4KoH0HTGuGz93Bur/Rvq4iDdJ9nwI2/9ie2NAHy19WB2udxoTyISoJCIBCgkIgEKiUiAQiISoLNbJeDSU523/vQw/7dL6x/4oreeOOdiAKrqz2Zs07fiTSZ51Ftuf8389VQ7AKkOy+VGexKRAIVEJEAhEQlQSEQCFBKRAJ3dKpCRtWcA0DdRnV1+MHGud93Ll0/w1hMXXlWYyUWSP/33TrV7F77vXXfxtvUFnUspy6Vh9njgAaAFOAB8IXreE0ANsMbM5hVwjiLdKpe3W+8CV5rZZOA54KvAHGC5mU0ELo56cImUpeCexMw69oNsId1lfiKQuQHGSmAS8HreZydSAk6kYXY98DNgCrDWzM6P6jcAI83svhw202saZvdNpL/DMerckWze9A4Aw+nrXbdm5CD/RgbUxhq7ItGHVFtLcL3Uns7HH9t3JL3r7mjdf9xt9bR/z7w3zHbODQBWALPNbJdz7pBzrp+ZHQFqgd25Tq63NMzOHKz/ZM0PuO7TNwLHOXB/LL8H7lX1I0g2+7vDd+Q7cP9OlwfuLx13Wz3t3zOvDbOjpthPAUvN7FdReR1wNfAsMBWY28XTy8ZHa0/31q8YOMpbX3L/hQD0HT6Y33/3swBU/dm0wkwuklyxBIDE9DnZZYBF9+/1rr9w+8udauV6/dXJyGVPcgPpY45TnHN/D6wCFgOPO+e+Tvqt12sFnKNIt8rlwP0J0qd7j1XcfatIN9En7iIBColIgEIiEtBrr906c1Cdt/7adWd6630+d7W3nrjIX8+oqKmLfVar9cn7vfWFSw5660t3pfteNUz4Wy77p59l6wdaDscaX9K0JxEJUEhEAhQSkQCFRCRAIREJKJuzW9eeOc5bf/SqI9nlgfX92XPTWACqb7ze
u35ijP9iw3xI7d/lrdufL/TWL2v297rad8R/diujLdWuM1p5pD2JSIBCIhKgkIgEKCQiAQqJSEDZnN26t4//u9n95i/NLlfUD/9/j09E8mX/PQPfucN/n8BkMv37Z/RzD/DmNbcDMGXfZu+67x/0f3NQSoP2JCIBColIgEIiEqCQiASUzYH7J979b/9fnH15drFbekO1HuLSba8UdUzJr7gNs/8BmAbsBHaYmf9CKJEykMueJNMw+5BzbhbphtkAd5vZM4WbmkhpCB6TmNk2MzsUPcw0zAa40zm3LuoFLFK24jbMTplZs3OuFlgLXGtmW3LYTK9pmK0xS3vME2mYnVNIoobZK4E7O/QDzvzdYuBFM3shh/FSieqhQM9rsKwxy2vMqGF2TiEJvt3yNcyO9iCZvxsP+K+3ECkDcRtmu+juVgngSTPbVMA5inSrk2mYLdIr6BN3kQCFRCRAIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkYCcrwLOk9TOnfsBGDx4AHv3Hgqsnl8aU2NmnHpqDeTzKuA8KupgIgE5haTY33HPaVIipUTHJCIBColIgEIiEqCQiAQoJCIBColIQNHbnDrnvgJ8mXQPrxlm9naRxj0INEUPF5nZ6gKMUQ00AB8HbjGzZ5xzQ0h//bkGWGNm84ow5jwK2GGzi66eVRT2dXZbJ9GihsQ5VwfcAkwAPgksAqYXafh3zOyyAo+RBD4PzOxQmwMsN7MVzrlVzrnzzOz1Ao8Jhe2w6evqOZjCvs5u6yRa7LdblwANZpY0sybAFXHss51zLzvnfhg12ss7M0uZ2fZjyhNJ9ywj+nNSEcaEAnbY7KKrZ6FfZ7d1Ei12SOqAPd00/igzmwT8AlhQxHEHmtnhaHkv6Z9BoS01swuBvwRud86dXYhBol82twHfp0iv85gxi/I6ix2SPaR3yxltxRrYzHZFi0+RfqtXLIecc/2i5Vpgd6EHNLPm6M99pH8pjM33GFFXzxXA7OhnW/DXeeyYxXidUPyQNAKTnXMJ59yfAkVpHuucG+icS0QPJxdr3Mg64OpoeWr0uKAK3WHT19WTAr/O7uwkWuyrgIkOur4EtAI3m9lbRRjzIuAR4APgKOmzQH8o0FhPA+NIn4FZDXwLeJz0WZ+1ZnZXEcasAzp22FyS5/G+CCwFMndOWgX8BwV8nV2M6Sjg68woekhEehp9mCgSoJCIBCgkIgEKiUiAQiISoJCIBCgkIgEKiUjA/wFKIEqs6eUiEwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=uint8)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMkAAADHCAYAAACzzHd1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADs9JREFUeJzt3X2UVPV9x/H3srs8rbC4C/JgoqFZ8zMmgsqTHCnYxNSIaTXH6NF4rFHxITWBlKZiUyLQ+kDJISdq26SnqIk0asVET4RoehQ1WMMWSUJiwG8CEsUqkYcFhAX2afrH3F02u7/hN3N35s7s7uf1D3e+e2d+vxn2s3funTvfW5ZKpRCRzAYUewIipU4hEQlQSEQCFBKRAIVEJEAhEQlQSEQCFBKRAIVEJEAhEQmoiHtH59yNwHVAE3C9mb2Rxd10DoyUkrJsVooVEudcDTAHOA84G1gKXJHNfcsrxwFQv/4Zpp17UZzhY9OYGrNda/M7Wa8b9+3WVOBFM2sxsw2Ai/k4IiUv7tutGqCh0+2sw1a//hkAPnr6aR3LSdGYGjOOuCFpACZ0ut2a7R3bN4+9bfOsMfvWmLm83YobknrgDudcOTAR+F3MxxEpebFCYmZ7nXPfA9YBzcANeZ2VSAmJfQjYzL4DfCePcxEpSfowUSRAIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkQCFRCRAIREJiH2qvJSe2WPOBqC6cmjHMsBDE973rj/8gf/oVjtww43eda/91Qne+rM7f5nrNHsdbUlEAhQSkQCFRCRAIREJUEhEAnR0qxf61JiJ3vqqNV8GYGDdSR3LAANGnep/oFRbt9LwFf/uXfUHf9jurVdNmXO8qfYJ2pKIBPSkq/whYEN0c6mZPZufKYmUlp683dpuZufnayIipaoslYp3yRDn3H5gE7ADmGtme7K4W+rVjZuAdLPjLa8n2x21r4w5vHKIt15XNxqAskFVpI4eOvaDioE9H7SlyVv++ZbfA73vtZ08aSJkeX2SnoRkpJntds5dD0w1s1uyuFtK1yfpuUw77k+tmQfAwLrpNG39WUc94457DtoCO+697bWNGmYX7iI+AGa2O1p8DLg57uNIZpeMneStf3/lZd76gJPGpxcqBx1bBu9RLIC23Tu61VJH/Od5lX/gDG/9srFTADixsqpjGWD1rk3e9Y9m2CKVslhHt5xzVVFHeYBZqKu89GFxtySnAyuccweAo6QvDSfSJ8W99MJG0tdKFOnz9GGiSIBCIhKgc7cSNGzQUG/94tozvfUVD/oPb5Z/ZFpe5tP2yo+71Zbfvcu77oL1/qNb//m/ywCoqP1AxzJA/Vn/4F3//D2v5DrNotOWRCRAIREJUEhEAhQSkQCFRCRAR7cS9Mafn+ytn/Cvdyc8k7SKv+x+yt3If7rLu27zCn+98save+tnfiLDSeGrsptbKdGWRCRAIREJUEhEAhQSkQCFRCRAR7cKxNfhfejiL/tXLsvtb1XzA3d66y/d1wzAeT+5i/+5cFFH/YJNS7zrt/6uvlvtqbLdnjVhx4oab33JLcd+hcrKO/06Dcjqm7G9grYkIgEKiUiAQiISoJCIBCgkIgE6utVDuXR4z6W7O8DRZQu89Q+teN1bv672HADOHABrOzV5rJro/5bgtS3dO0HtOOD/ZuJz3iosbm3pWE51Wh7ydf8RtU+/dI+3XsrXXgyGxDlXCbwIfAyYY2ZPOOdGAiuBYcBzZra4kJMUKaZs3m61AJ8DvtWptgB40MxmAFOcc/4vQIv0AcGQmFnKzN7tUp4BrI6WVwMz8z0xkVIRd5+kyswOR8v7gPHHW7mz+vXPAOmO4O3LSSnEmJk6vA/s1OF9YN30dDHH7u7l85d7689fc8Rbr61Id2MZXXcy83507PsfI/y7PPyI7o/T1Gm/IhsVten9rLLygR3L6UHH
edf/1pp7vfX9zY05jQvJ/Q7FDUmjc26wmR0BqoG92d6xvQt4b+tCPnXUR7z1H8/wr99yX/qyahW1p9JyYCcArTte866bev3n3vryhb/31he/88JxZprsa3t4R3ouFbWn0rLnzWM/yHAwYuSdX/TWT3/ytzmPnYeu8lmJewh4HTA7Wr4oui3SJ2W1JXHOPQ5MBg4656YCy4CHnXPzgbVm9psCzlGkqLIKiZld4Skn+15JpEj0ibtIgEIiEqDTUroYUjnIW//vy0/w1gfd/g1vvW1PdKm16rEdy7dd/rh33acObPbWh1X4Dy/3ZgPPPc3/gxhHt5KiLYlIgEIiEqCQiAQoJCIBColIgI5udTF75ARvfdDtS3N6nFtmfxuARU9/iCV/kV5e+c7PejY5KQptSUQCFBKRAIVEJEAhEQlQSEQCdHSri+/+nf+SbZmaWmdqXt1+JOtLzYf69FGtzk2yOy+nMn0NuKz3NdLWlkQkQCERCVBIRAIUEpEAhUQkoN8e3Zo/bpa3Xn7xF/x3yNBH6vn7cmvm1tekMjTMzvR6NTy6tdBTyru4DbMXA5cDu4CdZnZlIScpUkzZbEnaG2bf3KW+yMyeyP+UREpL3IbZAAudc+ucc1cVYF4iJaMslUpltWL0Fuu16O1WrZntcc5VA2uBS81sRxYPk3p14yYg3ex4y+vdLyJTSJ3HHF05zLvOyW60/84ZPnE/sNn39wO2tr7fbcykJDnmORMckG6YnWptOvaDDL9XLdve8tZ/1ehvAn48PXmekydNBMjq4/9YO+5mtif6d79z7nng40A2ISmZhtmZdtzvfOFvvPWyIf5Q/fTCf/TWP7v3pW5jJqWUG2a/d+2t3vq0X/iv3nU8STXMjhUS51x1FJAKYDqwIs7jFFN1hneamcLQuv0X3vpX20q3X1QcmfqO1Y/N7TpNTd9e7K1P3PJ2rlMqurgNs2uiq1uVA4+YWd/6TRHppCcNs0X6BX3iLhKgkIgEKCQiAf323K2cHTrgLW/b5/+cpDfwHcl6eeSZ3nVPe/mfvfXWzekrAVZMrulYBvj7h/1j7jtyMMdZFp+2JCIBColIgEIiEqCQiARoxz1L7922qthTiO1TYyZ6649OP9ytVnWvfwf9yKKveOsnPvRrAOrXT2HahXfEnGFp05ZEJEAhEQlQSEQCFBKRAIVEJKDfHt0akOlbyxm+pnvS8gxf5S+hIzpP18wE4LTyEzqWAT6x5gve9QecNL5b7eAXb/KuO2p1sl9BLiXakogEKCQiAQqJSIBCIhKgkIgE9NujW22Z2pJl6BdVXjfZW98+4aPe+k3vDQRgeOWQjnOn3m7a6133ysF13vq8S/d76xWXXOqf4xl/mv557clcsGlJR7311y94129c2v08rZs3jvCu259l0zB7OvBNoAk4CFwd3W8lMAx4zswWF3COIkWVzdutN4FPmtks4GngVmAB8KCZzQCmRD24RPqk4JbEzDr3g2wi3WV+BtD+KdpqYCawOe+zEykBuTTMrgV+AnwaWGtmE6L6VcB4M7s7i4cpmYbZ4zI0zB5z+lj/nTO8Ts3b/C2Q32xJ7/ScUncKb21NN4lubmv1rnviAH9r0dEj/OuXjciw3xC1aO3WvPqwv4lF2+6G7vM+WO5dt6H5kH/MSLH/P3OV94bZzrmhwCpgrpntds41OucGm9kRoBrw75F6lErD7CXj/sy7zm3rF/rv3HzUW/7DNX/rrc+NdtzvW3Mvcy+eByS54/7Hzasz7bgffei/utVuf3m4d90fvrvBW29X7P/PXOW1YXbUFPsx4H4zeyUqrwNmAz8ELgK+lvs0e5kMjaTHrP4Xb/3JbRsBGDS+licfvRqA1C5/s+iKcy/JwwSh+dHlAJR//msdywD2zZ3e9afstLyM29dlsyW5ivQ+x3Dn3DxgDbAMeNg5N5/0W6/fFHCOIkWVzY77StKHe7tKdtsqUiT6xF0kQCERCVBIRAL67blbjxz2H1+f/wP/0aqK
y76U0+OX101JLwyuOrb84Uk5PUbr2/7PZ9+76d+89T95bQsA9TNvZdptq3MaSzLTlkQkQCERCVBIRAIUEpEAhUQkoN8e3bIG/3lUZy3xn3n7yD3/561PePWeHs/lrQvmeetXN/hPqnx1V//tgVUM2pKIBCgkIgEKiUiAQiISoJCIBPTbo1uZbNv3rrc+DX+dUz553McrxtdaJb+0JREJUEhEAhQSkQCFRCRAIREJiNsw+yvA5cAuYKeZXVnISYoUUzaHgNsbZjc6524h3TAbYJGZPVG4qYmUhuDbLTN7x8wao5vtDbMBFjrn1kW9gEX6rLgNs1Nmtsc5Vw2sBS41M3/n6D9WMg2zNWb/HjOXhtlZhSRqmL0aWNipH3D7z5YBL5jZM1mMlyqvHAf0vgbLGrNvjRk1zM4qJMG3W76G2dEWpP1n04FtsWYq0gvEbZjtoqtblQOPmNlvCzhHkaLqScNskX5BHyaKBCgkIgEKiUiAQiISoJCIBCgkIgEKiUiAQiISoJCIBGR9FnCepHbteh+AESOGsm9fY2D1/NKYGrPdqFHDIJ9nAedRooOJBGQVkqSb02U1KZFSon0SkQCFRCRAIREJUEhEAhQSkQCFRCQg8euTOOduBK4j3cPrejN7I6FxDwEboptLzezZAoxRCbwIfAyYY2ZPOOdGkv768zDgOTNbnMCYiylgh80MXT0rKOzzLFon0URD4pyrAeYA5wFnA0uBKxIafruZnV/gMVqAzwE3d6otAB40s1XOuTXOuTPMbHOBx4TCdtj0dfUcQWGfZ9E6iSb9dmsq8KKZtZjZBsAlOPYHnXM/dc59P2q0l3dmljKzrpfEmkG6ZxnRvzMTGBMK2GEzQ1fPQj/PonUSTTokNUBDkcb/sJnNBJ4H7kpw3CozOxwt7yP9GhTa/WZ2FvAZ4KvOuQ8WYpDoj81fAw+Q0PPsMmYizzPpkDSQ3iy3a01qYDPbHS0+RvqtXlIanXODo+VqYG+hBzSzPdG/+0n/Ufh4vseIunquAuZGr23Bn2fXMZN4npB8SOqBWc65cufcOUAizWOdc1XOufLo5qykxo2sA2ZHyxdFtwuq0B02fV09KfDzLGYn0aTPAiba6foroBm4wcy2JjDmJGAFcAA4Svoo0FsFGutxYDLpIzDPAt8AHiZ91Getmd2RwJg1QOcOm/fmebxrgPuBX0alNcB3KeDzzDCmo4DPs13iIRHpbfRhokiAQiISoJCIBCgkIgEKiUiAQiISoJCIBCgkIgH/D3uLNPzX7O7KAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Keras ships each dataset pre-split into (train, test); the test split is used as \"val\".\n",
    "keras_dataset = tf.keras.datasets.cifar10 if my_dataset == \"cifar10\" else tf.keras.datasets.mnist\n",
    "\n",
    "data_images = {}\n",
    "data_labels = {}\n",
    "(\n",
    "    (data_images[\"train\"], data_labels[\"train\"]),\n",
    "    (data_images[\"val\"], data_labels[\"val\"]),\n",
    ") = keras_dataset.load_data()\n",
    "\n",
    "# One unbatched tf.data pipeline of (image, label) pairs per split.\n",
    "cifar_tf_ds = {}\n",
    "for dataset in (\"train\", \"val\"):\n",
    "    split_arrays = (data_images[dataset], data_labels[dataset])\n",
    "    cifar_tf_ds[dataset] = tf.data.Dataset.from_tensor_slices(split_arrays)\n",
    "\n",
    "\n",
    "def img_cifar_preprocessing(training=True, my_dataset=\"cifar10\"):\n",
    "    \"\"\"Build a tf.data map function that preprocesses one (img, label) pair.\n",
    "\n",
    "    Args:\n",
    "        training: if True, apply random augmentation before rescaling.\n",
    "        my_dataset: \"cifar10\" or \"mnist\"; selects the augmentation and whether\n",
    "            a trailing channel axis is added.\n",
    "\n",
    "    Returns:\n",
    "        A function process(img, label) -> (img / 255, label); the label is\n",
    "        passed through unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    def augment(img):\n",
    "        if my_dataset == \"cifar10\":\n",
    "            # Random horizontal flip, then a random 28x28 crop padded back to 32x32.\n",
    "            flipped = tf.image.random_flip_left_right(img)\n",
    "            cropped = tf.image.random_crop(flipped, [28, 28, 3])\n",
    "            return tf.image.resize_with_crop_or_pad(cropped, 32, 32)\n",
    "        # Any other dataset: a random 28x28 crop. For 28x28 MNIST digits this\n",
    "        # returns the image unchanged, i.e. MNIST gets no effective augmentation.\n",
    "        return tf.image.random_crop(img, [28, 28])\n",
    "\n",
    "    def process(img, label):\n",
    "        if training:\n",
    "            img = augment(img)\n",
    "        if my_dataset == \"mnist\":\n",
    "            # Raw MNIST images are 2-D (28, 28); add a channel axis -> (28, 28, 1).\n",
    "            img = tf.reshape(img, [28, 28, 1])\n",
    "        return img / 255, label\n",
    "\n",
    "    return process\n",
    "\n",
    "\n",
    "# Preview the first two raw (unbatched, not yet rescaled) training examples.\n",
    "for x, y in cifar_tf_ds[\"train\"].take(2):\n",
    "    print(y)\n",
    "    image = x.numpy()\n",
    "    label = y.numpy().item()\n",
    "    fig, ax = plt.subplots(figsize=(5, 3))\n",
    "    # 2-D (MNIST) images are drawn in grayscale; cmap is ignored for RGB images.\n",
    "    ax.imshow(image, cmap=\"gray\" if image.ndim == 2 else None)\n",
    "    ax.set_title(class_names[label] if my_dataset == \"cifar10\" else f\"digit {label}\")\n",
    "    ax.axis(\"off\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "\n",
    "\n",
    "def one_hot_cifar10_preprocessing(img, label):\n",
    "    \"\"\"Flatten a batch of integer labels to 1-D and one-hot encode them (n_classes wide).\"\"\"\n",
    "    return img, tf.one_hot(tf.reshape(label, [-1]), n_classes, axis=-1)\n",
    "\n",
    "\n",
    "# Train: augment + rescale, shuffle, batch, prefetch, then one-hot the labels.\n",
    "cifar_tf_ds[\"train\"] = (\n",
    "    cifar_tf_ds[\"train\"]\n",
    "    .map(img_cifar_preprocessing(training=True, my_dataset=my_dataset))\n",
    "    .shuffle(buffer_size=1000)\n",
    "    .batch(batch_size=batch_size)\n",
    "    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
    "    .map(one_hot_cifar10_preprocessing)\n",
    ")\n",
    "\n",
    "# Val: rescale only (no augmentation), batch, then one-hot the labels.\n",
    "cifar_tf_ds[\"val\"] = (\n",
    "    cifar_tf_ds[\"val\"]\n",
    "    .map(img_cifar_preprocessing(training=False, my_dataset=my_dataset))\n",
    "    .batch(batch_size=batch_size)\n",
    "    .map(one_hot_cifar10_preprocessing)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the New Test Set: CIFAR-10.1 (or QMNIST for MNIST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'mnist'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confirm which dataset this run is configured for.\n",
    "my_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': <MapDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>,\n",
       " 'val': <MapDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the element specs of the batched, one-hot train/val pipelines.\n",
    "cifar_tf_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "if my_dataset == \"cifar10\":\n",
    "    # CIFAR-10.1 (v6) arrays live under the CIFAR-10 data folder. Build the path with\n",
    "    # os.path.join so a separator is always inserted after the working directory.\n",
    "    cifar10_1_dir = os.path.join(path_to_scratch, data_folder, \"CIFAR-10.1\", \"datasets\")\n",
    "    cifar_testv2_X = np.load(os.path.join(cifar10_1_dir, \"cifar10.1_v6_data.npy\"))\n",
    "    cifar_testv2_y = np.load(os.path.join(cifar10_1_dir, \"cifar10.1_v6_labels.npy\"))\n",
    "\n",
    "else:\n",
    "    import idx2numpy\n",
    "\n",
    "    qmnist_dir = os.path.join(os.getcwd(), \"datasets\", \"QMNIST\", \"actual_data\")\n",
    "    images_path = os.path.join(qmnist_dir, \"qmnist-test-images-idx3-ubyte\")\n",
    "    labels_path = os.path.join(qmnist_dir, \"qmnist-test-labels-idx1-ubyte\")\n",
    "\n",
    "    X = idx2numpy.convert_from_file(images_path)\n",
    "    y = idx2numpy.convert_from_file(labels_path)\n",
    "\n",
    "    # Skip the first 10,000 QMNIST test examples. NOTE(review): presumably these\n",
    "    # reproduce the standard MNIST test set already used as \"val\" -- confirm.\n",
    "    cifar_testv2_X = X[10000:, :]\n",
    "    cifar_testv2_y = y[10000:]\n",
    "\n",
    "assert len(cifar_testv2_X) == len(cifar_testv2_y), \"test images/labels length mismatch\"\n",
    "\n",
    "cifar_testv2_y_onehot = np.zeros((cifar_testv2_y.shape[0], n_classes))\n",
    "cifar_testv2_y_onehot[np.arange(cifar_testv2_y.shape[0]), cifar_testv2_y] = 1\n",
    "\n",
    "# Raw, unbatched (image, label) pairs; the preprocessed version is built in a later cell.\n",
    "cifar_tf_ds[\"testV2\"] = tf.data.Dataset.from_tensor_slices(\n",
    "    (cifar_testv2_X, cifar_testv2_y)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(3, shape=(), dtype=uint8)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMkAAADHCAYAAACzzHd1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADnlJREFUeJzt3X+UVOV9x/H37rD8Wje7/NSyGEuEPGKMqCDUikLjOU2xaPwjGj3WnIZA9JhT0pPKIVUKFDVH9ITGepKiIdZoNZ5oczwVgkmUUsHIBjhKjgl+K8amBjb8XiiswO46/ePeHZB9lufu7Ny7s7uf1z9757vPzPPMwGfvzJ2Z763I5/OISOcqe3oBIuVOIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkQCFRCRAIREJGFDsFZ1z84AvASeAOWb22wRX02dgpJxUJBlUVEicc8OBucCVwKXAA8BNSa6bqxoDQMOmtUz7k1nFTF80zak527W17Eo8ttinW1OB9WbWamabAVfk7YiUvWKfbg0HDp5yOXHYGjatBWDiBRMK21nRnJqzGMWG5CBw8SmX25JesX332Nt2z5qzb83ZladbxYakAVjsnMsBk4B3irwdkbJXVEjM7IBz7gfABqAF+HJJVyVSRoo+BGxmK4GVJVyLSFnSm4kiAQqJSIBCIhKgkIgEKCQiAQqJSIBCIhKgkIgEKCQiAQqJSIBCIhKgkIgEKCQiAQqJSEDRH5XvLcbVnlPYHpSrKlxelTvfO/7Ci3d36farZ/m/3n90rQEwrGYQjTPHA3Bo5xDv2FsONnvrb+x7t0trkXRoTyISoJCIBCgkIgEKiUiAQiIS0OePbs06a0Jhu7ZyUOHylZu/meq8A2+Lfubq6hn2zL8CMKyTsRsb/R2Z2tb80Fu/dMVb3vq7TY1dWqMkoz2JSEB3usofBTbHFx8ws5dKsySR8tKdp1vvmdnMUi1EpFxV5PPFnTLEOXcI2Aa8D8w3s/0JrpbfsnUbEDU73v52+t1RR1edVdg+e3w9u3fsBGDsxPrU5waoyA0k33bizINajnvL+UMHvPXf7P7AWz/e1gJk99ieqrfNOWXyJEh4fpLuhGSkme1zzs0BpprZHQmuls/6/CR31l9V2F7wH9/koevvBmBFyi/c2+Xq6mlr2nnGMW0lfuHe25pX98ScccPs9E7iA2Bm++LNZ4Hbi72dtJ2dP3kXB1Dxkcs+rb96xVtv/va/eeuPb/v4GW/vttXLeGr2EgDGtvj/IM0Y5+9wPuqFx7z1Ny/4d2+95qZ/PuNapDhFHd1yzlXHHeUBZqCu8tKHFbsnuQBY5Zw7DBwnOjWcSJ9U7KkXthKdK1Gkz9ObiSIBColIQJ//7Nb9uzcWtme1HClcHnbpEu/4x1r83wZ868DvOplhxxnn/7OWBXyjcd2ZF7nPXz789EPe+sBbF3jrvzz7dQAuGFDNL8++vFCfunuzd7wkoz2JSIBCIhKgkIgEKCQiAQqJSECfP7p1Iv5kLECefOHy/N2BI05lYMaD2731jX/uP9I28Z+iI1qDzh1a2Aaom+O/naZjR7q5wv5BexKRAIVEJEAhEQlQSEQCFBKRgD5/dKs3+0TVcG+9sna0t/7h3rgjfmsr7D3ZHf9oy7GSr60/0Z5EJEAhEQlQSEQCFBKRAIVEJEBHt8rAp0f8sbf+gyeu89YrBlV761uXRc3pLrmyhTeXneww39LW2r0F9nPBkDjnqoD1wKeAuWb2vHNuJPAUUAO8bGZL01ykSE9K8nSrFfg88O1TaguBx81sOnC5c+7CNBYnUg6CITGzvJmdfnaY6cDqeHs1cHWpFyZSLop9TVJtZu2tzZuAcUmv2LBpLRB1BG/fzkq5zjlkwEBvfdC4kZ1cocZbvuSnywEYOqG+sA3Q0Jr+90bK9bEthWJD0uycG2xmx4BawH+OAI/2LuC9rQt5mnMeXnqNt972
laXeer7Z/3CvuH4ZAPNevJfvXb+4UF++5zXv+OOtgVNCdEG5PradibvKJ1LsIeANwLXx9qz4skiflGhP4pz7ETAFOOKcmwo8CDzpnPs6sM7Mfp3iGkV6VKKQmNlNnnK2+1aRHqJ33EUCFBKRAH0sJUOjq+u89cpLLuvS7VTk/P9sd2+Jjm7l6sYUtgEuv+ge7/jrDrzapXn7K+1JRAIUEpEAhUQkQCERCVBIRAJ0dCtDe442eeuvzXndW//k2NXe+vBZI7z1yk+cF/28fj4tL64s1K/ZdJd3/KHvdfwy1j8+XeUdu2LXf3nr/YH2JCIBColIgEIiEqCQiAQoJCIBOrpVBj57YKP/F5193/NXnd1SdDsNk/6Kabc/XaiuG/6ed/TkeR3/Ri579E+9Y9vuyHvrD+/s+5//0p5EJEAhEQlQSEQCFBKRAIVEJEBHt/qBzxz4hbc+/tExHWpvTLrYO3b55vu99Zop/wDAH1XVsGjMzEL9vl3ru7TGclZsw+ylwI3AXuAPZnZzmosU6UlJ9iTtDbNvP62+xMyeL/2SRMpLsQ2zARY55zY4525JYV0iZaMin/e/k3q6+CnWW/HTrRFmtt85VwusA24ws/cT3Ex+y9ZtQNTsePvb7xS57OJozo8anOv43ZGJH/d3dKmo8Z8uu3F71FN31Ph69u7YebLe8n9Jltot3Xlsp0yeBFCRZGxRL9zNbH/885Bz7hXgIiBJSNQwu4zmHF/neeH+3Ru8Y6tm+p8wPHZd9ML9Ky/ex2PXLSrUs3jhnlXD7KJC4pyrjQMyALgCWFXM7UjP2tHU8T/KfXf9xjt2ySt7vPVvrJwGwKDzqgvbAGvm+P9mvrHv3a4us8cV2zB7eHx2qxzwjJn9d4prFOlR3WmYLdIv6B13kQCFRCRAIREJ0Ge35COW71rvrVf6T+vIki33RhtDaxkwZXahPi+32Tv+Tnrf0S3tSUQCFBKRAIVEJEAhEQnQC3dJ5JF9Dd76osboA4a56pG0NZ78sOFtCz/mHf+1Bf7/ci1tHZt3lwvtSUQCFBKRAIVEJEAhEQlQSEQCdHRLEvnsiIu89cq6c6KNXNXJbSB/8VT/+IqflHxtadOeRCRAIREJUEhEAhQSkQCFRCRAR7f6gdHV/oZzgzzN6e4a+mnv2Hk/m+etVwypiTYqK09uA033POEdf7z1xBlWWp6SNMy+AlgBnACOALfG13sKqAFeNrOlKa5RpEclebr1O+AaM5sBvAh8FVgIPG5m04HL4x5cIn1ScE9iZqe2+TtB1GV+OrA4rq0Grgb8rf9EernEr0mccyOAO4G/AG41sw/iXzUB45LeTsOmtUDU7Lh9Oyv9dc6qypx3XEVFx37R51QO8Y6tOm+Uf7IB0euaitxAcnX1hfKIVY96hzccPea/nSJk9dgmbXM6FHgOmG9m+5xzzc65wWZ2DKil8zOOd6CG2dnPWZIX7j/3v3CvHDEWgFxdPW1NJ7vK7597+ulsItMaStcRt2waZsdNsZ8FHjGz9vOKbQCuBX4MzALu7voy+4aVoz/jrc8e93sARlQP5vfTPgnAC/8z1ju2uZNXhr/gsLd+Vd7/rb+/vir6T3rWsMHsv9EV6gOvmewdXzHxsg61AROmeUZ27sMD0X+2XM3owjbAF97z7716oyR7kluIXnN8zDn3NWAN8CDwpHPu68A6M/t1imsU6VFJXrg/RXS493TZPocQ6SF6x10kQCERCVBIRAL02a1u+lnOfwLNW78QfQghN2wwtfH23Jv/rku3/TcnPvDWKwb638toV1lXT/XDpT1D34lVy7z1L/7LQQCWr/4WC2c/VKi/tmd7SefvSdqTiAQoJCIBColIgEIiEqCQiATo6FY3/bjRf9qz1xdHpz1bM/2r/OXiVwH43Lf2eccuOX+3tz54gv8o1pB7v9OlNTYv8H/Y8Pj7LR1q9+zwf9p39SH/NyH2HG0C4O9bmnmhcUuX1tVbaE8iEqCQiAQoJCIBColIgEIi
EqCjWylpPBJ9o7nlw9bC9sojG71jV+70luHVTurfv+qMc3fvK8PvhIf0M9qTiAQoJCIBColIgEIiEqCQiAQU2zD7b4Ebgb3AH8zs5jQXKdKTkhwCbm+Y3eycu4OoYTbAEjN7Pr2liZSH4NMtM9tlZs3xxfaG2QCLnHMbnHO3pLY6kTJQkc/nEw2MG2b/lKhhdt7M9jvnaoF1wA1m9n6Cm8lv2boNiJodb3872zeuNKfmbDdl8iSAjh3DPRKFJG6YvRpYdEo/4PbfPQj8p5klae+dz1WNAcqjkbTm7L9zxg2zE4Uk+HTL1zA73oO0/+4K4N2iVirSCxTbMNvFZ7fKAc+YWen66YuUme40zBbpF/RmokiAQiISoJCIBCgkIgEKiUiAQiISoJCIBCgkIgEKiUhA4k8Bl0h+797o9Gl1dUNpamoODC8tzak5240aVQOl/BRwCWU6mUhAopBk3Zwu0aJEyolek4gEKCQiAQqJSIBCIhKgkIgEKCQiAZmfn8Q5Nw/4ElEPrzlm9tuM5j0KtJ8q9wEzeymFOaqA9cCngLlm9rxzbiTR159rgJfNbGkGcy4lxQ6bnXT1HEC697PHOolmGhLn3HBgLnAlcCnwAHBTRtO/Z2YzU56jFfg8cOo5oRcCj5vZc865Nc65C83Mf77n0s0J6XbY9HX1rCPd+9ljnUSzfro1FVhvZq1mthlwGc59rnPuVefc03GjvZIzs7yZNZ5Wnk7Us4z459UZzAkpdtjspKtn2vezxzqJZh2S4cDBHpr/fDO7GngFuD/DeavN7IN4u4noMUjbI2Z2CTAbuMs5d24ak8R/bO4Evk9G9/O0OTO5n1mH5CDRbrldW1YTm9m+ePNZoqd6WWl2zg2Ot2uBA2lPaGb745+HiP4oXFTqOeKuns8B8+PHNvX7efqcWdxPyD4kDcAM51zOOXcZGZ3F0jlX7ZzLxRdnZDVvbANwbbw9K76cqrQ7bPq6epLy/ezJTqJZfwqY+EXXF4EW4MtmtiODOScDq4DDwHGio0D/m9JcPwKmEB2BeQl4CHiS6KjPOjNbnMGcw4FTO2w+XOL5bgMeAd6MS2uAJ0jxfnYypyPF+9ku85CI9DZ6M1EkQCERCVBIRAIUEpEAhUQkQCERCVBIRAIUEpGA/wcQ4Pzhyn3MngAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(8, shape=(), dtype=uint8)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMkAAADHCAYAAACzzHd1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADwJJREFUeJzt3X+UVOV9x/H37uzyawNscJH4A1IF8igmgQQErSuYYs+pNAnW3zbRExGqyBGjbQ+0x8I2HjzWnkMTyOnRipxWo/EoNZZAUY8gkZqyRXIkCcKX+BMqMcLyq4KwO5vpH3NnWdhneO7Ozp2dXT6vf7jz5Zn73Lvw2Tv3zp3vVGQyGUQkv8ru3gCRcqeQiAQoJCIBColIgEIiEqCQiAQoJCIBColIgEIiEqCQiARUFfpE59ws4DagGZhhZu/GeJrugZFyUhFnUEEhcc4NAWYClwFfAR4Cbojz3FT12QA0blzDpEuuKmT6gmlOzZnT2rI79thCX25NBNabWdrMNgGuwPWIlL1CX24NAfa3exw7bI0b1wBw4QWj25ZLRXNqzkIUGpL9wJfbPW6N+8Tc4bGnHZ41Z++aszMvtwoNSSOwwDmXAsYCvylwPSJlr6CQmNk+59y/ARuAFuD2om6VSBkp+BKwmT0CPFLEbREpS3ozUSRAIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkQCFRCRAIREJUEhEAhQSkYCCG0FI8gb06eet31k3yVuf1T/bL3BE3/5sG3m8Ldq53/tD7/iK2roOtZ13r/SOndb0O2/9/YP+em+iI4lIQFe6yh8GNkUPHzKzF4uzSSLlpSsvt94zsyuKtSEi5aoikynsK0OccweBLcAuYK6ZNcV4WuaNzVuAbLPjbdtL2x21p81ZWeF/NTy0qsZfr0wD0GfkCJrf2dlWrz77M97xFamOvyObdx3wjv1NOu2tH2ttAXrez3bC+LEQ8/tJuhKSOjPb65ybAUw0sztjPC2j7yeJr+AT9xVL2Hnd3LZ6KU7ce9rPNmqYndyX+ACY2d5o8RngjkLXIzCu7nxv/dXp/iNAv4YGb72iMnvkSX12OOe9trStnvn972Nvy3k/m+qt/2r9j731++dvB2BY9We455zJbfUffPha7DnLXUFXt5xzNVFHeYApqKu89GKFHkkuAJY55w4Bx8h+NZxIr1ToVy9sJvtdiSK9nt5MFAlQSEQCdO9WQvpW9QGgkoq25QeG1nvHznn+Wm+98twxyWxcAaquuNlbX7TkpwD0HVHDoiUT2uqH7/Z/jeay3a8Xf+MSpiOJSIBCIhKgkIgEKCQiAQqJSICubiVk5aCJAIxK1bQtT/6fv0t0zr3XzALgjGWP0jTz+O10C3Z1vJER4M7oDt72Bg066h07YmWDt151yfTsQk3t8WVgTsV67/hl3mp505FEJEAhEQlQSEQCFBKRAIVEJEBXt7roG2eN99YvX3sPANXDh3H5f91T0LrTeT4NuHie/zNuC3fvAKDx8FEmNe5o9zc7vOOXe2r1Gf/9Yi81+6965TPqe1/y/8Xtv+jUesqBjiQiAQqJSIBCIhKgkIgEKCQiAbq61UWzjw3w1itrP5ddSFUdX874+1+lf/4Tb33MHH9916E9ndvIPCYO/UKH2osvzPaOTdWd619J+y6TeTpO9nTBkDjnqoH1wEXATDNb4ZyrA54EBgKvmFlDkhsp0p3iRD8NXAd8v11tHrDczOqBi51z5fNhbJEiC4bEzDJm9tuTyvXAqmh5FTAZkV6q0HOSGjP7NFo+AJwX94mNG9cA2Y7gueVSSWLO0Sl/v95U7TkAVKT6tC3nk5pyq7e+cu2feuvNrf4O7zlx97OmqmND7r6jzvQP7tP/lOs6eT8rp83xjmvceFNwu+Iq1f+hQkNyxDnXz8yOAoOBfXGfmOsC
3tO6kOezaoj/IHrlrx4Ass2rW/fvAvI3rl497u+99eua1he0TXH38+0LL+pQS7/8w85Nlj4GQNWZo0h//HZb+dXxi7zDp+3b0Ln1n0IRusrHUujliA3AtGj5quixSK8U60jinHsWmAB84pybCDwMPOGcuw9YZ2ZbE9xGkW4VKyRmdoOnXNrXSiLdpHe++yNSRAqJSIBuSykD+1PF+V3la9INsHX0aO/4s5bf1eU5P5p+LwDDnvxnfnfLvW31afu2dXnd5UJHEpEAhUQkQCERCVBIRAIUEpEAXd3qotlp/1Wct6J2QJVXzqQlWq6afKN37J+//B1v/dE/3uWtv7HH31Jo15UjABg8qE/bMsDAxzp5P5ZHvvZGk9/fC8DK5jTfjJZ7Gx1JRAIUEpEAhUQkQCERCVBIRAJ0dauL8rX3qZ/7EgA/eukGvh0tb3zTf3Wr8ozh3vq674701lt+mfLWaxY/ml3fZ4cz6PHH2+r5PhHpk+8q1kVzXvDWc/vf3JouWqujcqMjiUiAQiISoJCIBCgkIgEKiUiArm4lZEvTewAcSR9rW35swgPesTNfv89b7/Odv/HX802ap3l1vj7WLav+pUNtzF+/7B3bW69cxVFow+wG4HpgD/CRmRWvLZ9ImYlzJMk1zL7jpPpCM1tR/E0SKS+FNswGuN85t8E5d3MC2yVSNioymUysgdFLrF9HL7fOMLMm59xgYB1wtZn5P/xwoswbm7cA2WbH27b7PxeRlO6ec0TVQO+YuguG+Z9c6X9nPaQi1YdMa3NwXOZgx/OMrf97yDs2TpPunvTvOWH8WICKOGMLOnE3s6boz4POubXAF4E4Iel1DbM7M+eSYX/kHZPvxL2yvz9UeUVn6Knac2g98OHxep5v2PKduH+zwBP37v7ZdlZnGmYXFBLn3OAoIFXApcCyQtZzupm//+fe+m0v+L++oeLGe731fE64itUuGOkta73jx81b16F2Ol/FyqfQhtlDom+3SgFPm9mOBLdRpFt1pWG2yGlB77iLBCgkIgEKiUiA7t1KiK/D+zuX+D+BWNXJq1idVXH2KG/9pgGuQ23RAd/7xqc3HUlEAhQSkQCFRCRAIREJ0Il7F9WfOcZbf/6r2RsM2zevHvhYx3ulTqV1839660ce+Ym33v+WK4ETm3RD/kbd8x88v0PtsTlvesd+fPjAKbe1N9ORRCRAIREJUEhEAhQSkQCFRCRAV7e66D++5v/UX83ibMPq9s2rO9O4GmD17Zu89Rub3vbWx25qBU5s0g35G3VXTb2lQ23zmNe9Y4dv0tUtEclDIREJUEhEAhQSkQCFRCRAV7di+qfP+Xtm1Sxe6H+Cr3l1+lPv0JmXNXjrTzVtjLt5gL9JN8Ans//CO37gox07QQ19wX9/2YgLr/fWdx76uFPb2BPFaZh9KbAYaAY+Ab4VPe9JYCDwipk1JLiNIt0qzsutD4CpZjYF+CkwB5gHLDezeuDiqAeXSK8UPJKYWft+kM1ku8zXAwui2ipgMvBW0bdOpAx0pmH2GcBLwJ8A68zsy1H9ZuA8M3swxmp6bMPs4dX+vrxDLzz7lM87oXl1np68H2z396VtajkcfwPbOXk/xw3u6x1X+fk/iL3OX299x1vPNdLuaf+eRW+Y7ZwbADwHzDWzvc65I865fmZ2FBgM7Iu7cT21YXa+E/fZm0994t6+eXXmmP8//cJvNHjrT+3u3Il7zsn7uefro73jfCfu+Uyf+i1vPXfi3tP+PYvaMDtqiv0MsNTMch2fNwDTgOeBq4C/7fxm9iy3Xvaht57vfixf8+rt9f4f01O7f9GlbcsZWXsWAH1T1W3LANUj/EfBzt5LdrqKcyS5mew5xyDn3D3AauBh4Ann3H1kX3ptTXAbRbpVnBP3J8le7j1ZaY+tIt1E77iLBCgkIgEKiUiA7t2KKVXXv8vrOHTU/35FPu2vULX3xvVneuvVfzYdgL7n17Hl6Vlt9dSXvhZ7zlb7b2/9aIwvKu2t
dCQRCVBIRAIUEpEAhUQkQCERCdDVrZi+8KP3vfX3Zmz31lPndvyIzcUb7vOO/dnli731kaOavPV+C77vrVdURr/z+g+iauzUtnq+e7R8V7Ku/faz3rHqKi8ieSkkIgEKiUiAQiISoJCIBOjqVkz5ru7Mv+bH3vqD84cBUHnNX9Ly/A8BqLr6Lu/YiW/GaQ8Q9u7ldwPw+X9fwgfXzm2r/8PRGu/4NQe3daidzlex8tGRRCRAIREJUEhEAhQSkQCduHfR0g9f89ez59A0XjyDSXc/l32Q+zNhjcc+ZdK7vyzJXKeDQhtmfxe4HtgDfGRmNyW5kSLdKc6RJNcw+4hz7k6yDbMBFprZiuQ2TaQ8BM9JzGy3mR2JHuYaZgPc75zbEPUCFum1Cm2YnTGzJufcYGAdcLWZ7Yqxmh7bMFtz9q45O9MwO1ZIoobZq4D72/UDzv3dw8CrZrYmxnyZVHW2C3tPa7CsOXvXnFHD7FghCb7c8jXMjo4gub+7FPD35RfpBQptmO2ib7dKAU+b2Y4Et1GkW3WlYbbIaUHvuIsEKCQiAQqJSIBCIhKgkIgEKCQiAQqJSIBCIhKgkIgExL4LuEgye/b8HwC1tQM4cOBIYHhxaU7NmTN06EAo5l3ARVTSyUQCYoWk1J9xj7VRIuVE5yQiAQqJSIBCIhKgkIgEKCQiAQqJSEDJ25w652YBt5Ht4TXDzN4t0byHgU3Rw4fM7MUE5qgG1gMXATPNbIVzro7sx58HAq+YWUMJ5mwgwQ6bebp6VpHsfnZbJ9GShsQ5NwSYCVwGfAV4CLihRNO/Z2ZXJDxHGrgOuKNdbR6w3Myec86tds6NMbO3Ep4Tku2w6evqWUuy+9ltnURL/XJrIrDezNJmtglwJZx7uHPuNefcU1GjvaIzs4yZ/fakcj3ZnmVEf04uwZyQYIfNPF09k97PbuskWuqQDAH2d9P8I81sMrAWWFTCeWvM7NNo+QDZn0HSlprZOODrwF8554YnMUn0y+Yu4HFKtJ8nzVmS/Sx1SPaTPSzntJZqYjPbGy0+Q/alXqkccc71i5YHA/uSntDMmqI/D5L9pfDFYs8RdfV8Dpgb/WwT38+T5yzFfkLpQ9IITHHOpZxzXwVK0jzWOVfjnEtFD6eUat7IBmBatHxV9DhRSXfY9HX1JOH97M5OoqW+C5jopOtWoAW43czeLsGc44FlwCHgGNmrQDsTmutZYALZKzAvAv8IPEH2qs86M1tQgjmHAO07bP6gyPPdAiwF3oxKq4F/JcH9zDOnI8H9zCl5SER6Gr2ZKBKgkIgEKCQiAQqJSIBCIhKgkIgEKCQiAQqJSMD/A404V2z8yYa7AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Preview the first two raw examples of the new test set.\n",
    "for x, y in cifar_tf_ds[\"testV2\"].take(2):\n",
    "    print(y)\n",
    "    image = x.numpy()\n",
    "    label = y.numpy().item()\n",
    "    fig, ax = plt.subplots(figsize=(5, 3))\n",
    "    # 2-D (QMNIST) images are drawn in grayscale; cmap is ignored for RGB images.\n",
    "    ax.imshow(image, cmap=\"gray\" if image.ndim == 2 else None)\n",
    "    ax.set_title(class_names[label] if my_dataset == \"cifar10\" else f\"digit {label}\")\n",
    "    ax.axis(\"off\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'testV2': <TensorSliceDataset shapes: ((28, 28), ()), types: (tf.uint8, tf.uint8)>,\n",
       " 'train': <MapDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>,\n",
       " 'val': <MapDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# At this point testV2 is still raw, unbatched (image, label) pairs.\n",
    "cifar_tf_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<MapDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preprocess the new test set exactly like \"val\": rescale (no augmentation), batch,\n",
    "# one-hot labels. Pass my_dataset rather than a hard-coded \"mnist\"; the latter would\n",
    "# try to reshape CIFAR-10.1 images to [28, 28, 1] and fail. Rebuilding from the raw\n",
    "# arrays also keeps this cell safe to re-run (no double mapping).\n",
    "cifar_tf_ds[\"testV2\"] = (\n",
    "    tf.data.Dataset.from_tensor_slices((cifar_testv2_X, cifar_testv2_y))\n",
    "    .map(img_cifar_preprocessing(training=False, my_dataset=my_dataset))\n",
    "    .batch(batch_size=batch_size)\n",
    "    .map(one_hot_cifar10_preprocessing)\n",
    ")\n",
    "\n",
    "cifar_tf_ds[\"testV2\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'testV2': <MapDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>,\n",
       " 'train': <MapDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>,\n",
       " 'val': <MapDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# After preprocessing, all three splits should share the same element spec.\n",
    "cifar_tf_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<MapDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cifar_tf_ds[\"train\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Prediction Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "if my_dataset == \"cifar10\":\n",
    "    \n",
    "    model = BigConvNet(\n",
    "        name=\"ResNetCIFARV2\",\n",
    "        input_shape=(32, 32, 3),\n",
    "        output_shape=10,\n",
    "        n_blocks=3,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A pre-trained CIFAR-10 predictive network (about 93% accuracy) is stored at the following path:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tmp  weights_ResNet50V2CIFAR_pretrained.h5  weights_ResNet50V2_pretrained.h5\n"
     ]
    }
   ],
   "source": [
    "f\"{os.getcwd()}/datasets/CIFAR10/models\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<MapDataset shapes: ((None, 28, 28, 1), (None, 10)), types: (tf.float32, tf.float32)>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cifar_tf_ds[\"train\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/6\n",
      "469/469 [==============================] - 3s 6ms/step - loss: 0.3659 - accuracy: 0.8992 - val_accuracy: 0.9411 - val_loss: 0.2040\n",
      "Epoch 2/6\n",
      "469/469 [==============================] - 2s 5ms/step - loss: 0.1683 - accuracy: 0.9523 - val_accuracy: 0.9597 - val_loss: 0.1381\n",
      "Epoch 3/6\n",
      "469/469 [==============================] - 3s 5ms/step - loss: 0.1198 - accuracy: 0.9654 - val_accuracy: 0.9670 - val_loss: 0.1118\n",
      "Epoch 4/6\n",
      "469/469 [==============================] - 2s 5ms/step - loss: 0.0922 - accuracy: 0.9733 - val_accuracy: 0.9690 - val_loss: 0.1011\n",
      "Epoch 5/6\n",
      "469/469 [==============================] - 3s 5ms/step - loss: 0.0745 - accuracy: 0.9783 - val_accuracy: 0.9727 - val_loss: 0.0924\n",
      "Epoch 6/6\n",
      "469/469 [==============================] - 2s 5ms/step - loss: 0.0618 - accuracy: 0.9824 - val_accuracy: 0.9737 - val_loss: 0.0870\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fe753701588>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "if dataset == \"mnist\":\n",
    "    \n",
    "    ds_train = cifar_tf_ds[\"train\"]\n",
    "    ds_test = cifar_tf_ds[\"val\"]\n",
    "\n",
    "    model = tf.keras.models.Sequential([\n",
    "      tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n",
    "      tf.keras.layers.Dense(128,activation='relu'),\n",
    "      tf.keras.layers.Dense(10, activation='softmax')\n",
    "    ])\n",
    "    model.compile(\n",
    "        loss='categorical_crossentropy',\n",
    "        optimizer=tf.keras.optimizers.Adam(0.001),\n",
    "        metrics=['accuracy'],\n",
    "    )\n",
    "\n",
    "    model.fit(\n",
    "        ds_train,\n",
    "        epochs=6,\n",
    "        validation_data=ds_test,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset == \"mnist\":\n",
    "    \n",
    "    model.save_weights(f'{os.getcwd()}/models/weights_mnist.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset == \"cifar10\":\n",
    "\n",
    "    model_weights_path = path_to_data + \"/models\" + \"/weights_ResNet50V2CIFAR_pretrained.h5\"\n",
    "    model.load_weights(model_weights_path)\n",
    "\n",
    "    optimizer_args = {\"learning_rate\": 0.1, \"momentum\": 0.9, \"nesterov\": True}\n",
    "    optimizer = optimizer_dict[\"SGD\"](**optimizer_args)\n",
    "    metrics_list = [tf.keras.metrics.CategoricalAccuracy(name=\"accuracy\")]\n",
    "\n",
    "    model.compile(\n",
    "        optimizer,\n",
    "        loss=\"categorical_crossentropy\",\n",
    "        metrics=metrics_list,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_2\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten_3 (Flatten)          (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense_5 (Dense)              (None, 128)               100480    \n",
      "_________________________________________________________________\n",
      "dense_6 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 101,770\n",
      "Trainable params: 101,770\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check model performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "469/469 [==============================] - 2s 4ms/step - loss: 0.0549 - accuracy: 0.9835\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.054887041449546814, 0.9835166931152344]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(cifar_tf_ds[\"train\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "79/79 [==============================] - 0s 5ms/step - loss: 0.0870 - accuracy: 0.9737\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.08698935806751251, 0.9736999869346619]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(cifar_tf_ds[\"val\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "391/391 [==============================] - 2s 4ms/step - loss: 0.1023 - accuracy: 0.9692\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.10227405279874802, 0.9691799879074097]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(cifar_tf_ds[\"testV2\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save Features to numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset == \"cifar10\":\n",
    "    n_features = 256\n",
    "\n",
    "else:\n",
    "    n_features = 128\n",
    "    \n",
    "model_preprocessing = tf.keras.Model(\n",
    "    inputs=model.input, outputs=model.layers[-2].output\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_split_list = list(cifar_tf_ds.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = {}\n",
    "labels = {}\n",
    "features = {}\n",
    "\n",
    "for dataset in dataset_split_list:\n",
    "    features[dataset] = np.empty((0, n_features))\n",
    "    labels[dataset] = np.empty((0, n_classes))\n",
    "    scores[dataset] = np.empty((0, n_classes))\n",
    "    for x, y in cifar_tf_ds[dataset]:\n",
    "        features[dataset] = np.concatenate(\n",
    "            (features[dataset], model_preprocessing(x).numpy()), axis=0\n",
    "        )\n",
    "        scores[dataset] = np.concatenate(\n",
    "            (scores[dataset], np.log(model(x).numpy())), axis=0\n",
    "        )\n",
    "        labels[dataset] = np.concatenate((labels[dataset], y.numpy()), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/britten1/maxcauch/datasets/CIFAR10'"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path_to_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_to_data_bkup = path_to_data\n",
    "\n",
    "if dataset == \"cifar10\":\n",
    "    path_to_data = f'{os.getcwd()}/datasets/CIFAR10'\n",
    "    \n",
    "else:\n",
    "    path_to_data = f'{os.getcwd()}/datasets/QMNIST'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset in dataset_split_list:\n",
    "    np.save(\n",
    "        path_to_data + \"/np_data/\" + dataset + \".npy\", {\n",
    "            \"features\": features[dataset],\n",
    "            \"labels\": labels[dataset],\n",
    "            \"scores\": scores[dataset]\n",
    "        }\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The train, val, and testV2 splits, already passed through the network on Britten, can be loaded directly in NumPy format from the following location:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000, 128)"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.load(path_to_data + \"/np_data/\" + dataset + \".npy\",\n",
    "        allow_pickle=True).item()[\"features\"].shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
