{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "# smaller font size for all plots\n",
    "plt.rcParams.update({'font.size': 6}) #, 'font.family': 'serif'})\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from solo.data.classification_dataloader import prepare_data\n",
    "from solo.methods import METHODS\n",
    "from scripts.utils.get_images_and_feats import get_images_and_feats\n",
    "\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from cmcrameri import cm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all figures produced by this notebook are written to plots/mnist/\n",
    "plot_dir = Path(\"plots\") / \"mnist\"\n",
    "plot_dir.mkdir(exist_ok=True, parents=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data configuration; val_loader is the loader that load_model extracts features from\n",
    "dataset = \"mnist\"\n",
    "data_format = \"image_folder\"\n",
    "# NOTE(review): the train path does not follow the val layout (../datasets/mnist/val) - confirm\n",
    "train_data_path = \"../datasets/\"\n",
    "val_data_path = \"../datasets/mnist/val\"\n",
    "batch_size, num_workers = 256, 4\n",
    "\n",
    "train_loader, val_loader = prepare_data(\n",
    "    dataset,\n",
    "    data_format=data_format,\n",
    "    train_data_path=train_data_path,\n",
    "    val_data_path=val_data_path,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers,\n",
    "    auto_augment=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(model_name, run_name, loader=None, device=\"cuda:0\"):\n",
    "    \"\"\"Load the most recent run named `run_name` and extract its features.\n",
    "\n",
    "    Run directories under ../trained_models/<model_name>/ are visited in ascending\n",
    "    os.path.getctime order and the last one whose args.json has \"name\" == run_name\n",
    "    is used, i.e. the most recent matching run.\n",
    "\n",
    "    Args:\n",
    "        model_name: sub-directory of ../trained_models/ that holds the runs.\n",
    "        run_name: value of the \"name\" field in the run's args.json.\n",
    "        loader: loader passed to get_images_and_feats; defaults to the notebook's val_loader.\n",
    "        device: device the model is moved to (default \"cuda:0\").\n",
    "\n",
    "    Returns:\n",
    "        (data, labels, z, model), with data clipped to [0, 1] for imshow.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if no run with that name, or no .ckpt file in it, is found.\n",
    "    \"\"\"\n",
    "    folder_name = os.path.join(\"..\", \"trained_models\", model_name)\n",
    "    # NOTE: on Unix getctime is the inode change time, used here as a proxy for creation time\n",
    "    run_dirs = sorted(\n",
    "        (os.path.join(folder_name, n) for n in os.listdir(folder_name)),\n",
    "        key=os.path.getctime,\n",
    "    )\n",
    "\n",
    "    pretrained_checkpoint_dir, method_args = None, None\n",
    "    for run in run_dirs:\n",
    "        args_file = os.path.join(run, \"args.json\")\n",
    "        if not os.path.isfile(args_file):\n",
    "            continue  # not a run directory / no saved arguments\n",
    "        with open(args_file, \"r\") as f:\n",
    "            args = json.load(f)\n",
    "        # no break: a later (more recent) run with the same name takes precedence\n",
    "        if args.get(\"name\") == run_name:\n",
    "            pretrained_checkpoint_dir, method_args = run, args\n",
    "    if pretrained_checkpoint_dir is None:\n",
    "        raise FileNotFoundError(f\"no run named {run_name!r} in {folder_name}\")\n",
    "\n",
    "    # choose the checkpoint deterministically (os.listdir order is arbitrary)\n",
    "    ckpt_dir = Path(pretrained_checkpoint_dir)\n",
    "    ckpt_paths = sorted(ckpt_dir / c for c in os.listdir(ckpt_dir) if c.endswith(\".ckpt\"))\n",
    "    if not ckpt_paths:\n",
    "        raise FileNotFoundError(f\"no .ckpt file in {ckpt_dir}\")\n",
    "    ckpt_path = ckpt_paths[0]\n",
    "\n",
    "    # build the model from the matched run's saved arguments\n",
    "    cfg = OmegaConf.create(method_args)\n",
    "    model = METHODS[method_args[\"method\"]].load_from_checkpoint(ckpt_path, strict=False, cfg=cfg)\n",
    "    model = model.to(device)\n",
    "\n",
    "    # get images and features\n",
    "    if loader is None:\n",
    "        loader = val_loader\n",
    "    data, labels, z = get_images_and_feats(device, model, loader)\n",
    "    # clip data for imshow\n",
    "    data = np.clip(data, 0, 1)\n",
    "\n",
    "    return data, labels, z, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Using scheduler_interval=step might generate issues when resuming a checkpoint.\n",
      "Collecting features: 100%|██████████| 40/40 [00:00<00:00, 45.75it/s]\n",
      "WARNING:root:Using scheduler_interval=step might generate issues when resuming a checkpoint.\n",
      "Collecting features: 100%|██████████| 40/40 [00:01<00:00, 38.89it/s]\n",
      "WARNING:root:Using scheduler_interval=step might generate issues when resuming a checkpoint.\n",
      "Collecting features: 100%|██████████| 40/40 [00:00<00:00, 46.10it/s]\n",
      "WARNING:root:Using scheduler_interval=step might generate issues when resuming a checkpoint.\n",
      "Collecting features: 100%|██████████| 40/40 [00:00<00:00, 44.00it/s]\n"
     ]
    }
   ],
   "source": [
    "# load the four \"catprob\" runs compared below (load_model picks the most recent run per name)\n",
    "model_name = \"catprob\"\n",
    "run_name = \"catprob_mnist_pairs_exact_p=0.992\"\n",
    "data_exact, labels_exact, z_exact, model_exact = load_model(model_name, run_name)\n",
    "run_name = \"catprob_mnist_pairs_approx2_p=0.992\"\n",
    "data_approx_p09, labels_approx_p09, z_approx_p09, model_approx09 = load_model(model_name, run_name)\n",
    "run_name = \"catprob_mnist_pairs_approx2_p=0.8\"\n",
    "data_approx_p08, labels_approx_p08, z_approx_p08, model_approx08 = load_model(model_name, run_name)\n",
    "run_name = \"catprob_mnist_pairs_MI\"\n",
    "data_MI, labels_MI, z_MI, model_MI = load_model(model_name, run_name)\n",
    "\n",
    "# plot label -> (images, labels, features, model). Insertion order is the panel order of the\n",
    "# confusion-matrix and histogram figures; the single-model analysis looks the DM entry up by\n",
    "# its key, so its position here does not matter. The 0.99 labels correspond to runs with p=0.992.\n",
    "model_dict = {\n",
    "    r\"DM, $p_\\theta=0.99$\": (data_exact, labels_exact, z_exact, model_exact),\n",
    "    r\"DM+MI, $p_\\theta=0.99$\": (data_approx_p09, labels_approx_p09, z_approx_p09, model_approx09),\n",
    "    r\"DM+MI, $p_\\theta=0.8$\": (data_approx_p08, labels_approx_p08, z_approx_p08, model_approx08),\n",
    "    r\"MI\": (data_MI, labels_MI, z_MI, model_MI),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# per model: hard prediction = index of the largest feature, and the entropy in bits of\n",
    "# each feature vector (assumes rows are probability vectors - TODO confirm);\n",
    "# 0 * log2(0) evaluates to nan, which nansum skips, i.e. it counts as 0\n",
    "predictions_dict = {\n",
    "    key: np.argmax(feats, axis=1) for key, (_, _, feats, _) in model_dict.items()\n",
    "}\n",
    "entropies_dict = {\n",
    "    key: -np.nansum(feats * np.log2(feats), axis=1) for key, (_, _, feats, _) in model_dict.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DM, $p_\\theta=0.99$\n",
      "Accuracy: 0.989\n",
      "DM+MI, $p_\\theta=0.99$\n",
      "Accuracy: 0.892\n",
      "DM+MI, $p_\\theta=0.8$\n",
      "Accuracy: 0.989\n",
      "MI\n",
      "Accuracy: 0.699\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAACKCAYAAABrR3e7AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGn1JREFUeJzt3X1UVHX+B/D3QGiZgCio6w/MXE10H1LDctWBQVKIBzVEzefEPGGpv1+rrhm5SGpIqXtKW2wNtQxKtx8qSSSyODhrZ1PTXHd9qHQ5QCzFKk+iTcnc3x/+mOPIPN6Zud8ZeL/O4Ry533vnfud95/rhzv3Od1SSJEkgIiKSwUd0B4iIyHuxiBARkWwsIkREJBuLCBERycYiQkREsrGIEBGRbCwiREQkG4sIERHJxiJCRESysYgQEZFsLCJERCSb1xaRiooKhISEYPz48YiMjMTy5ctx48YNVFRUQKVS4ejRowCAH3/8EUFBQdi2bZtifVu1ahXUajXmzp2Ln376yaSttbUVc+bMQXR0NFJTU3Hr1i2zy0Rxda7WtquoqEBKSorbn9OdvOHYdNZjYDAY8PTTT0OtVmPcuHG4ePGiov0SxdrxiYiIENw727y2iABAVFQUysrKUF5ejm7duiEjIwMAEBERgYKCAgBAaWkpBg8erFifzp49i2+//RY6nQ7h4eH46KOPTNr379+PBx98EEePHkV4eDgKCgrMLhPJmVyffvrpdstEHo87edOx6YzH4Msvv4Rer4dOp0NWVha2bNkipI8ieMrxkcOri0gblUqFNWvWoLCwEADwwAMPoLKyEpIkYf/+/UhOTra6/a9//WtMmjQJI0aMwGuvveZUXz777DNMnDgRABAXF4fjx4+btF++fBnDhw8HAIwcORLHjh0zu8wTOJtrG7nbATw2nekYhIaGQpIkSJKE+vp6BAcHO7Uvb+LM8RGtQxQRAOjSpQt+/PFH4++/+c1vcOzYMdTV1aFv374Wt2toaMC1a9ewc+dOfP7558jLy7O4bl1dHTQaTbuf2tpa4zr19fUICAgAAAQGBuLatWsmjzFs2DCUlZUBuP0XR319vdllnkJurneTsx2PzW2d5RgEBwfDz88P4eHhWLp0KZ577jm7n1tHIPe4inaP6A64il6vR9euXY2/T506FTNmzMC8efOsbnfu3Dk89dRTCA4OhiRJuO+++wAATU1NyM/PR21tLebMmYNBgwYhJCQEWq3W6uP16NEDTU1NAIDGxkb07NnTpD0xMRFarRbjx4/HL37xC/Tt29fsMk9hT66VlZXG3y9evAiNRgMAKCkpsbqdLTw2t3WWY1BSUoJ77rkHly5dwqlTp7B8+XLs3bvX7r56OznHxxN0mCuRrKwsTJkyxfj74MGDMW7cuHY3Daurq01+P3fuHHx8bsfwwQcfIDExEQCwY8cOXL9+HTU1NTh79iwA+/7SGjNmDEpLSwEAhw8fxtixY032p1KpsHnzZpSVlaFXr16YPHmy2WWewp5c+/fvD61WC61Wi7i4OOO/u3TpYnW7u/HYmNdZjoEkSejVqxeA21cljY2NDuXk7ew5Pp7Iq69EysvLER0djdbWVjz22GN45ZVX8N133xnb33zzTZP1b926hZkzZ0Kn0xmXnTt3Dn5+foiJiUHfvn2xc+dOALdPiI0bN2LlypV47LHHAMCuv7SGDx+OPn36QK1Wo3///lixYgVqa2uRk5ODzMxM1NbWYubMmfDx8UFMTAwiIyPNLhPJ0VztZW07HhtTnfEYrFmzBrt370ZUVBT0en2nurHeRu5xFUnVmb4e98SJEzh79iwWLVpkXDZhwgQcPnzY+NdWm82bN8PPzw/3338/Fi5cqHRXOx0eG/F4DEiOTlVEzImKikJ5ebnobpAZPDbi8RiQLZ2+iBARkXwd5sY6EREpj0WEiIhkYxEhIiLZWESIiEg2j/2ciMFgQE1NDfz9/aFSqUR3x6NIkoTm
5mb069ev3dBLZzF3y5i7GMxdDHtz99giUlNTg7CwMNHd8GhVVVUIDQ116WMyd9uYuxjMXQxbuStSRFpaWvDcc8+hS5cu0Gg0mD17ts1t/P39AQCnPsxE9273tmsfOv9tl/fTW0iSAWioMGZkiTO5f7Fvvdncw+f+UV6nOwBFct/7ivnc522X1+kOgK93MezNXZEiUlBQgJSUFCQlJWHGjBlmD65er4derzf+3tzcDADo3u1e+N/f/uCqVL7u67AXkACbl9/O535fu/WZO3MXgbmLYU/uitxYr66uNl4y+vqaPyhZWVkIDAw0/vAS03nMXQzmLgZzF0ORIhIaGmqcHdRgMJhdZ/Xq1WhsbDT+VFVVKdG1Do25i8HcxWDuYijydlZycjKWLFmCoqIiJCUlmV2na9euJt+ZQM5j7mIwdzGYuxiKFJH7778fu3btUmJXdAfmLgZzF4O5i+GxQ3zbDJ3/ttmbWz9VFlncxq9/gju71CmEz/0jcxcgfN525i4AX+/yKXJP5MqVK1i4cKHXfWOXt2Pu4jB7MZi78hQpIgMHDkRubq7VdfR6PZqamkx+yDnMXRxb2TN392DuyvOYubM49E4M5i4GcxeDubuexxQRDr0Tg7mLwdzFYO6up0gRuXr1KtLS0nDmzBlkZWWZXadr164ICAgw+SHnMHdxbGXP3N2DuStPkdFZvXr1wvbtnXfuH1GYuzjMXgzmrjyPH+JribXhdbVHMiy29Z2Q6Y7udBrWcv+ubJ3Ftj7j17ijO52G1dd7qeXXdN/HLZ8LZJs7cvfxtfzfrqH1ln0d8yAec0+EiIi8j11XIufPn7fYNmzYMJvbHzhwAEVFRWhqasLChQsxceJE+3tIsjF3MZi7GMxdDLuKyOuvv252uUqlws6dO21uP2XKFEyZMgX19fVYsWKF2YN79xTNHL/tPOYuBnMXg7mLYVcRuXM+GkmSUFdXh969ezu8s/Xr1+P5558325aVlYXMTN6vcAfmLgZzF4O5K8uheyJ79+6FWq3G448/jtbWVjz11FN2bSdJElatWoUnnngCI0eONLsOx2+7HnMXg7mLwdzFcGh01tatW3Hs2DHExMTA19cX33//vd3blZaWorGxEd988w3S0tLarcMpml2PuYvB3MVg7mKoJEmS7F05MjISZWVlmDhxIkpKSvD4449Dq9W6pWNNTU0IDAyEKmigS7+i8tuPf2ex7b+SXrPYpvKxftEmWfgSHHeQpFZI9VfQ2Njo8g9LeVrunoS5i+GNudd88pLFtn7xr7psP+5kb+4OXYmkp6dDo9Hgq6++QkxMDNLT053uKBEReS+HikhsbCxiY2NRV1eH4OBgm1/g3ubChQt444038J///AcxMTFYvHixrM6SY5i7GMxdDOYuhkM31v/5z39i8uTJSExMxJNPPol//OMfdm03dOhQbN++Hfv27cPx48fNrsMpml2PuYvB3MVg7mI4VERSU1ORnZ2Nzz//HNnZ2Vi4cKHd2xYWFiIhIQHx8fFm2zlFs3swdzGYuxjMXXkOFZE+ffogPDwcADBkyBCHPisyadIkFBcXIy8vz2w7h965B3MXg7mLwdyVZ9c9kZUrV0KlUkGv1yMyMhIjRozAmTNn0KNHD7t2otVqUVBQAL1eb/EvBA69cz3mLgZzF4O5i2HXEN/y8nKLbVFRUS7tUBt3Db2zpurPSyy2hU3bpkgf7mRpWLFkaIXh2jdeNeTRm1T/7/+YXd7cchNDEld0mNz5endf7tX7X7DYFvrkH1y2H3dy6RDfOwvF5cuXUVNTAwc+XkJERB2UQ/dEli1bhuXLl2POnDnYvHkzNm3aZPe2LS0tiIiIwKFDhxzuJMnH3MVg7mIwd+U59DmRL774AsePH4dGo8HBgweRkpJi97bZ2dmYPn26xXbOrukezF0M5i4Gc1eeQ1cifn5+AIBu3bqhrKwMFy9etGu7I0eOYNiwYVZHc3HonesxdzGYuxjMXQyHrkS2bdsGvV6P
zZs3IycnB3/4g303iLRaLVpaWnD+/Hncd999iI+Ph89dN9FWr16N3/72t8bfm5qaeICdxNzFYO5iMHcx7CoiN27cAAAMHDgQra2teOCBB7Bx40a7d7JhwwYAwO7duxEcHNzuwAIceucOzF0M5i4GcxfDriG+0dHRUKlUxhFZbf9WqVQoKytzS8c8bahp9RvTrLaH/vefFeqJd85qam0WZGszIDN35zB369w2xPdNy/dlQpfts9jmjbOF23UlcvToUZd1jIiIOg6HbqzLpdVqoVarkZaW5rbvH6H2mLsYzF0cZq88h26sy6VSqdC9e3f88MMPCA0NNbsOh965HnMXg7mLYyt75u56Dl+JnD59GoWFhTAYDKiurrZrG7VajeLiYmRnZyMjI8PsOhx653rMXQzmLo6t7Jm76zlURJYvX4533nkH69atg4+PD1JTU+3byf/fLAoKCjL5K+BOnF3T9Zi7GMxdHFvZM3fXc+jtrDNnzqCsrAzR0dEAgFu3btm1XUFBAQ4fPoyGhgYsWWJ+0jcOvXM95i4GcxfHVvbM3fUcKiJdunTBv/71L6hUKlRVVeHee++1a7vk5GQkJyfL6iDJx9zFYO7iMHvlOVREcnJy8OKLL+Lq1atYsWIF3nrrLXf1y0il8jE7dlrJ8dKA7XHxVe9a/pbHsPm5ru6O15F7vGzm/v4ii21hc3bI2qdIrn69uy33956x2BY27x1Z++xIrH0WxBlVec9abAub/bZb9mmLQ0XkwQcfxN69ex3eicFgwJo1a9DU1ISIiAjMnz/f4ccgxzF3MZi7GMxdDIeKyKhRo4yfVq+vr0ePHj1w6tQpm9sdPHgQ1dXV6NWrF4c8Koi5i8HcxWDuYjg0OuvkyZM4ceIETp48aXKD3ZZLly5hzJgx2LJlC3Jycsyuw6F3rsfcxWDuYjB3MWR/Yj0sLAw6nc6udUNDQxEUFAQA8PU1Pz8Nh965HnMXg7mLwdzFkP12ll6vx6xZs+zaLjk5GUuXLoVOp0NkZKTZdTj0zvWYuxjMXQzmLobdRUSSJLzzzjt4+OGHHd5Jt27dkJvLEUpKY+5iMHcxmLsYdhWRhoYG9OjRA+np6Yp/d7EkGQCDyqFt5E5/7Qxrw3j//enLFtt+FrfeHd0RQkjuVobx/vvwGottP4td547uOM1rXu9WhvHWHjE/1QsA9J2Q6Y7ueBXfbv4W21pvNFvd1tow3uoDyy22hU7ZbLtjMtlVRJKTk1FWVoaQkBBkZGRg1KhRxukF4uPjbW6v0+mQl5eHW7du4fz58/jss8+c6zXZhbmLwdzFYO5iOHRPZMCAAVCpVDh9+rRxmT1FRK1WQ61W48CBAxg1apTZdTj0zvWYuxjMXQzmLoZdReTUqVN49NFHcfeXIKpUKvz+97+3e2f5+fkW37PMyspCZiYvdd2BuYvB3MVg7sqya4jvI488Yvx8yJ0/J06csHtHlZWVCAwMhL+/+fcDOfTOPZi7GMxdDOauPEW+lAoAcnNzsWDBAovtHHrnHsxdDOYuBnNXnl1FpLi42Okd8RJSDOYuBnMXg7krz64iYu+U755C6Rl+bbE2jNfarJz951oYvmqQIJlvEcrjcrcyjLcqP81iW9is7e7ojtt4Wu7WhvFWfficxbawp/7oju4IYW3YtbVhvD6+1v9LNrRa/g4na8N43Tn7ryJvZ1VWVmLZsmXo2bMnHnroIbz44otK7LbTY+5iMHcxmLsYihSRc+fOISUlBXPmzMGMGTPMrsOhd67H3MVg7mIwdzFkT8DoiNGjRyM3Nxfjx49HXFyc2XU4u6brMXcxmLsYzF0MRYrIrl27kJmZibKyMhQVFZldh0PvXI+5i8HcxWDuYijydlZcXBzWrl2L/Px8DBgwwOw6HHrnesxdDOYuBnMXQ5Ei8stf/hIfffSREruiOzB3MZi7GMxdDMU+bOgNrA3Ls8XaMEtrj2t1Vs79L5hd3txyE0PiLc/YSbdZzd3KMF5Ls6E2t9zE
kCfMHxOyj7VhvN9+/Duzy5tbbuKhicvc1SW3+PaQ5ZFh/eJftdhmbQivM5wdxmuNIvdEzp8/j+nTp2Px4sX8S0FBzF0M5i4GcxdDkSuR4uJiLF26FGq1GpMmTUJKSkq7dTj0zvWYuxjMXQzmLoYiVyJz587Fhx9+iJUrV+Lq1atm1+HQO9dj7mIwdzGYuxgq6e753d2otbUVycnJOHjwYLs2c38hhIWFQRU0ECqVryL9E3FPxNp2tu6JNDY2IiAgwGbfPD13d5Gdu417IszdPWzdE/Gm3Gs+eclim7V7Ip5Ekloh1V+xmbsib2dVVFTg1VdfRUtLC1auXGl2HQ69cz3mLgZzF4O5i6FIERkwYAD+9Kc/ObRN2wWSJCk4uZxB/kWZ1X5aeVxr2zW33DS7/PqNH/5/W+v99Zrc3cXVubcwd3eynPvt5d6Uu6XncntfrS7dl7u0ZWIrd0XfznJEdXU136+0oaqqCqGhoS59TOZuG3MXg7mLYSt3jy0iBoMBNTU18Pf3R3NzM8LCwlBVVdXuvbm29zQdbXNmW9FtkiShubkZ/fr1g48T93HMsTd3pZ6rJ7UplbtKpRL+XD2pjbl7du4e+2FDHx8fY/VTqVQAgICAAIs3eOS2uetx3d0WGBhodh1nOZq7rfaO1qZE7nL61dHbmLvn5q7IEF8iIuqYWESIiEg2rygiXbt2RUZGhtmheXLb3PW4Sre5U2fIzxNzd6ZfHb3N3TzpuXpSmzUee2OdiIg8n1dciRARkWdiESEiItlYRIiISDavKCItLS2YP38+Fi1ahLy8PJO2K1euYOHChWanfT5w4AAWLVqEGTNmoKSkxKTtwoULSEtLQ0pKCnJycszuMyIiAocOHTJZrtVqoVarkZaWBq1Wa9JmMBiQnp6OpUuX4t133zVp0+l0SEtLwzPPPIMxY8aYtFVWVmLKlClITU3Fxo0bTdpEfkcCc2fucnMHvC975i4zd8kLvPfee1JhYaEkSZI0ffp0s+tMnTrV4vbXrl2TUlNTzba1trZKs2fPbrd8zZo1UnZ2tvTxxx+bLNdqtVJcXJw0f/586euvvzZpKygokObNmye98MILUmlpqdn97d+/X9q+fbvJskOHDkl79uyRJKn989u0aZN07NgxSZIkKSkpyeJzdAfmztydzV2SvCd75i4vd6+4ErlzfhtfX8ena16/fj2ef/75dssLCwuRkJCA+Ph4k+VHjhzBsGHD0Lt373bbqNVqFBcXIzs7GxkZGSZtly5dwpgxY7Blyxazf3UAQH5+PmbNmmWybPTo0cjNzcX48eMRFxdn0mbPdyS4C3Nn7s7mDnhP9sxdXu5eUURCQ0NRXV0N4PalnL0kScKqVavwxBNPYOTIke3aJ02ahOLi4naXrlqtFn/729+Qn5+PHTt2mOyzbQ6ZoKAgk+8laOtnUFAQAPMvwsrKSgQGBsLf399k+a5du5CZmYmysjIUFRWZtPXu3RtvvfUWNm7ciODgYLufuyswd+buTO6Ad2XP3OXl7rFzZ90pOTkZS5YsQVFREZKSkkzarl69ivT0dJw5cwZZWVlYvXq1sW3r1q0oLS1FY2MjvvnmG6SlpRnbtFotCgoKoNfr2/2FsGHDBgDA7t27ERwcbDL5WEFBAQ4fPoyGhgYsWbKkXT+XLl0KnU6HyMjIds8jNzcXCxYsaLc8Li4Oa9euRX5+PgYMGGDSZs93JLgLc2fuzuQOeFf2zF1e7vywIRERyeYVb2cREZFnYhEhIiLZWESIiEg2FhEiIpKNRYSIiGRjESHqACoqKhASEgKNRoNHH30UJ0+edGj7FStWQKvV4ssvv7T4AbaKigqTaT2effZZp/pMHUOHLSI8qcRg7uJERUVBq9Vi69atSE9PNy535INzw4cPx+LFi8223Z3722+/Lb+zHQRf7x24iAA8qURh7mINHz4cVVVVSEpKwpNPPondu3fj008/hVqtxpgxY/DBBx8AAM6ePYtRo0YhMTERf//7
3wHc/nDcihUrAACffPIJRo8eDY1Ggz179iAnJwd79+6FRqPBtWvXEBERAQA4d+4cxo0bh7FjxyIrKwsAsHbtWsydOxfx8fGIiorCzZs3BSShjM7+eu/QRaQNTyoxmLsY5eXlqK2tRWNjIwoKCrBgwQKsW7cOf/nLX6DT6bBt2za0trbi5Zdfxvvvv4/CwkJcv37d5DEMBgNWr16NkpISaLVazJ49G4sXL8aMGTOg1WrRs2dP47ovvfQSduzYgb/+9a84evQoKioqAACDBw82HrsjR44oGYEQnfX17hXTnjir7aQKCQlBeXk5AGDcuHE4evQofH19ERkZienTpxtPqsGDB2PcuHEmj9F2Uul0OgQEBMBgMCAsLAxhYWHYtGmTybptJ1V4eDhiY2Mxc+ZMALdPqj179mDVqlU4cuQIJk2apEwAgjB3ZZWXl0Oj0aB79+54/fXXcf78eahUKnz//ff46quvMHHiRABAQ0MD6urqUFtbiyFDhgAAHnnkEZPHqqurQ1hYGAICAgDAZEqOu9XW1mLo0KEAgJEjR+Ly5csAgBEjRgAAwsLCUF9f79on64E66+u9QxcRnlRiMHcxoqKijN8DodVqcfHiRQBAcHAwwsPDUVJSgi5duuCnn36Cn58f+vTpg6+//hqDBg3C6dOnMXXqVONjhYSEoLq6GtevX0f37t1hMBjg5+eH1tbWdvvt06cPLly4gPDwcJw+fRppaWnQ6XRQqVTGdTry7Eqd/fXeod/Oanuv8tChQxg0aJDxgNx5UrXd1Orbt6/xpJIkCadPnzZ5rDtPKgB2nVRtj/Pzn/8cADrNScXcPYuPjw9efvllTJgwAdHR0Zg9ezYAYN26dZg1axYSExONs8Leuc2GDRsQExOD6Oho5OXl4Ve/+hW++OILTJs2DQ0NDcZ1N2zYgGeeeQZjx45FVFRUu4n9OrrO/nrv0Fciltx5Uvn4+CAkJAT79u0znlS9e/e2elJ169YNqampmDx5MlavXo1p06Zhx44dxnXbTipJkpCQkNDpTipLmLv7DBgwwOTb6DQaDTQajfH32NhYxMbGmmwzYsQIs6OJ2rZLSEhAQkKCSduxY8eM/z516hQA4OGHH8bx48dN1lu7dq3x33fPQttZdJbXO2fxJSIi2Tr021lEROReLCJERCQbiwgREcnGIkJERLKxiBARkWwsIkREJBuLCBERycYiQkREsrGIEBGRbCwiREQk2/8BCT1pV8mxjxoAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 400x120 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def align_feats_to_labels(feats, labels):\n",
    "    \"\"\"Permute the columns of `feats` so that column k is the cluster matched to label k.\n",
    "\n",
    "    Matching is greedy: (label, cluster) cells of the confusion matrix are visited in\n",
    "    descending order of their counts and a pair is accepted when neither the label nor\n",
    "    the cluster has been matched yet.\n",
    "    NOTE(review): greedy matching is not guaranteed to maximise the matched count; the\n",
    "    Hungarian algorithm (scipy.optimize.linear_sum_assignment) would.\n",
    "    Assumes labels and argmax predictions together take exactly the values 0..K-1\n",
    "    (K = number of distinct labels), so confusion rows/columns equal label values and\n",
    "    feature columns.\n",
    "    \"\"\"\n",
    "    confusion = confusion_matrix(labels, np.argmax(feats, axis=1))\n",
    "    n_labels = np.unique(labels).shape[0]\n",
    "    labels_inds = -np.ones(n_labels, dtype=int)\n",
    "    predictions_inds = -np.ones(n_labels, dtype=int)\n",
    "    # (label, cluster) index pairs, largest count first\n",
    "    best_fits = np.dstack(np.unravel_index(np.argsort(confusion.ravel()), confusion.shape))[0][::-1]\n",
    "    for l_ind, p_ind in best_fits:\n",
    "        if p_ind not in predictions_inds and l_ind not in labels_inds:\n",
    "            predictions_inds[l_ind] = p_ind\n",
    "            labels_inds[l_ind] = l_ind\n",
    "    # an unmatched label would leave -1 behind and silently select the last column\n",
    "    if (predictions_inds < 0).any():\n",
    "        raise ValueError(\"could not match every label to a cluster\")\n",
    "    predictions_inds = predictions_inds[np.argsort(labels_inds)]\n",
    "    return feats[:, predictions_inds]\n",
    "\n",
    "\n",
    "# one confusion matrix per model, with clusters aligned to the true labels\n",
    "n_models = len(entropies_dict)\n",
    "fig, axs = plt.subplots(1, n_models, figsize=(n_models, 1.2), squeeze=False)\n",
    "axs = axs[0]\n",
    "for i, key in enumerate(entropies_dict.keys()):\n",
    "    _, labels, feats, _ = model_dict[key]\n",
    "    feats = align_feats_to_labels(feats, labels)\n",
    "    predictions = np.argmax(feats, axis=1)\n",
    "    confusion = confusion_matrix(labels, predictions)\n",
    "    n_classes = confusion.shape[0]\n",
    "\n",
    "    axs[i].imshow(confusion, cmap=cm.lipari)\n",
    "    axs[i].set_xlabel(\"Prediction\")\n",
    "    axs[i].set_xticks(np.arange(n_classes))\n",
    "    axs[i].set_yticks(np.arange(n_classes))\n",
    "    # smaller font size for ticks\n",
    "    axs[i].tick_params(axis='both', labelsize=5)\n",
    "    print(key)\n",
    "    print(\"Accuracy: {:.3f}\".format(np.trace(confusion) / np.sum(confusion)))\n",
    "    axs[i].set_title(key, fontsize=6)\n",
    "axs[0].set_ylabel(\"True label\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(plot_dir / \"mnist_confusion_matrices.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHoAAABmCAYAAAAETYUEAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAACslJREFUeJztnX9QVOUaxz/LVCCxQBGSDQI5ZvCHsIn0g2WVnCBSExyj0muR/ZCrd8qmazk4022GJv2DpGEqvTdzmByyW6lhhuVEKr90DAHJUFwwkU1GZ9ZgVwWN4L1/OO5clB/n7C9gz/uZOTOHc97nOc/Ol/c9Z9/3Oc/qhBACic/jN9oBSLyDFFojSKE1ghRaI0ihNYIUWiNIoTWCFFoj+JzQQgjsdjtyHmggPif0xYsXCQkJ4eLFi6MdypjC54SWDI4UegzT3d1NTU0Nn3/+OVar1SVfUugxTHNzMykpKSxdupS9e/e65EsKPYaJjY2lpKQEgJiYGJd8SaHHMIGBgcTFxQEwYcIEl3x5Teg1a9ZgMpl47rnn6O3tdRw3m80YDAYCAgK4dOmS43hRURFGo5EFCxZgt9u9FabP4hWhGxsbOXv2LFVVVcTGxrJ9+3bHucjISCoqKnj44Ycdx6xWK99++y3V1dU888wzfPzxx94I06fxitAHDx4kPT0dgIyMDGpqahznAgMDCQkJGdC+traW2bNno9Ppbmp/I1evXsVutw/YJDfjFaE7OzsJDg4GICQkhD/++MNt7devX09ISIhjmzx5svsC9yG8InRoaKijp9lsNu688063tc/Ly8Nmszk2i8XivsB9CK8InZycTHl5OQB79+7FaDQO2z4pKYnKykpF7f39/QkODh6wSW7GK0IbDAYiIiIwmUw0NTWxaNEicnNzgWvD9GOPPUZjYyNPPvkk33//PeHh4cybNw+j0ci2bdtYuXKlN8L0aXS+lu5rt9sJCQnBZrP5RO+ur68nMTGRuro6ZsyY4bQfOWGiEaTQGkEKrRFUC/3TTz95Ig6Jh1Et9L59+0hJSeGtt97i2LFjnohJ4gFuUWvw3nvvAVBdXU1+fj5ms5mcnBxeeumlm6YyJWMH1T36zz//ZPv27XzwwQf09fXx9ttvExMTw4IFCzwRn8RNqO7RaWlpLFy4kE2bNjFx4kTH8QsXLrg1MIl7US10RUUFVquV06dPo9PpCA8PB+CVV15xe3AS96Fa6A0bNrB7927i4+P55ZdfmD9/PqtXr/ZEbBI3olroHTt2UFNTg06nQwiB0WiUQo8DVD+MTZ8+nVOnTgHw22+/kZiY6PagJO5HdY8+cuQI6enpBAUFcenSJe644w6SkpLQ6XT8/PPPnohR4gZUC11XV+eJOCSD0NLSwokTJwBob293afUKoZKGhgaRlpYmHnzwQZGeni4aGhrUuvAoNptNAMJms412KC5hNpsFMGAzm81O+1MttMlkEm1tbUIIIU6fPi2MRqPTF/cEviJ0XV2dAMSWLVvEu+++KwBRV1fntD/VD2N//fUX0dHRAERHR9PX1+f8cCIZEYPBwNy5c132o1ro7OxsUlNTWbVqFampqWRnZyuyGyqBv6+vjxdffBGTycTrr7/uOK7X60lNTSU1NVUunrgBVUILIYiLi+PLL79kyZIlfPXVV7zxxhsj2g2XwP/dd99xzz33UFVVxeXLlzl06BAA999/PwcOHODAgQNMnz5d5ceS3IgqoXU6HZs2bSIiIoKHHnqIiIgIRXbDJfAPde7UqVPMmjWLFStWcOXKlSF9ywR+ZageugMDA8nJyeGjjz5i48aNbNy4cUSb4RLyhzrX2tpKZWUlkyZNGvaVHJnArwzVQmdkZDBnzhz0ej233347QUFBI9oMl5A/1LmwsDAAnnrqKRobG4f0LRP4laFa6BMnTpCTk+PYzp49O6LNcAn8g527fPmy42m+qqqKqVOnDunblxP47w7SMaHLzIQuM3cH6Vzy
pVhoq9VKU1MT+/fv5/jx4xw/fpxjx445RBqO4RL458+fT3t7OyaTiYCAAB555BFaWlpISkpi1qxZ7Nmzh1WrVjn/CccxuYm3EVeZS1xlLrmJt7nkS3EC/65duygtLeWHH34gIyMDgFtvvZUnnniChQsXuhSEO/GVBP76+nrmzZ7Jvl3bAJiTuYSyiiNOT4MqnuvOzMwkMzOTtrY2l8ssSJRx7pKgJ3SaY98VVC9qHDp0iGeffZb+/n7HMblqNfZRLXRBQQEVFRXo9XpPxCPxEKqfuhMSEvDzky94jDdU9+jDhw8TFRXFlClTAGTCgYfo7u4Grt0qOzs7XXeodJlr7dq1jv2ioiLH/sqVK51eOvMEvrJMuXnzZreuRyseg68vNgCUlpY69q9nQEjcS1ZWFps3b+bTTz8F4JtvvuG+++5z2p+82Y5R7rrrLl5++WUeeOABAKKiolzyp/ge/euvv/L0008jhBiw39TU5FIAEu+gWOja2lpPxiHxMIqFvp4+JBmfyHu0RpBCawQptEaQQmsEKbRGkEJrhFGvwD9UAr+swO9eRr0C/2AJ/LIC/zW6u7sdawk9PT0u+Rr1CvyDnfNEBX6r1conG/LJX7GI9WtXqfodqZMnT/LP5X8jf8Uiyv67RbGdq7bNzc0sXboUgLa2NlW2NzLqFfgHO+eJCvylpaV07F7HvyLKuXrwPxQXFyuOv6CgAL15B/+KKKf23/+gubnZK7axsbFUV1dTUlLC448/rthuMFQnHjiD2gT+0NBQWltbB21/I3l5eQPe/7Lb7YOKnZWVxU5bB/mtjfgnR7Js2TLF8b/55pt8sqGH/PNXSPr7XGJjY71iGxgYiNFoHLGQvRK8Uq/76NGjFBYWsnXrVtatW8e9997L4sWLgWs9rb6+nvz8fJYvX86yZcuYOnUqixcvpry8nG3btnHmzBny8vIUXctmsxEaGorFYhnX6b5q0Ov16HQjJPi7KSFiRFavXi1SUlLEkiVLxNWrV8Xy5cuFEEL09vaKnJwckZKSIl599VVH+8LCQpGcnCzmzZsnurq6FF/HYrHclJnh65uSbBqfq8Df399PR0fHoP/l14d1Z3r7WLZV0qO9co/2Jn5+fkRGRg7bxpV3tMajLciZMc0ghdYImhLa39+fd955B39/f03Y/j8+9zAmGRxN9WgtI4XWCFJojaAZoVtaWkhOTmbatGkkJSWpevHgtddeIyYmBp1Ox9GjRxXbXblyhaysLKZNm0ZCQgJpaWmOOXwlpKenEx8fj8FgwGQy0dDQoNj2JtRNZI5fHn30UVFcXCyEEOLrr78WM2fOVGxbUVEhLBaLiI6OVlXktqenR5SVlYn+/n4hhBAffvihmD17tmL7zs5Ox/7OnTtFfHy8Ytsb0YTQ58+fF3q9XvT29gohhOjv7xcRERGipaVFlR+1Qt9IbW2tiI6Odsq2uLhYJCQkOH1tn5sCHQyLxcKkSZO45ZZrH1en0xEVFUV7e/uwpa3cTVFREZmZmapsnn/+efbv3w/Anj17nL62JoQeC6xbt47W1lbVP/m4detWAD777DPWrFnjvNhOjwXjiNEeugsKCkRiYuKAe64zBAQECKvV6pStJp66J06cyIwZMygpKQGu/dJPZGSkV4btwsJCvvjiC3788UdCQ0MV23V1ddHR0eH4u7S0lLCwsGGzbYZDM1OgJ0+e5IUXXuDChQsEBwdTXFysuDx0bm4uZWVlnDt3jrCwMPR6vaKvSb///juTJ09mypQpjipO/v7+HD58eETbM2fOkJ2dTU9PD35+foSHh/P+++9jMBgUxXwjmhFa62hi6JZIoTWDFFojSKE1ghRaI0ihNYIUWiNIoTWCFFojSKE1ghRaI/wPv3Ed1dVD6rAAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 120x100 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# box plots of per-sample entropy (bits) for each model, outliers hidden;\n",
    "# models are drawn in reverse dict order: x position p shows keys[n - 1 - p]\n",
    "plt.figure(figsize=(1.2, 1))\n",
    "keys = list(entropies_dict.keys())\n",
    "n = len(keys)\n",
    "for pos, key in enumerate(reversed(keys)):\n",
    "    plt.boxplot(entropies_dict[key], positions=[pos], showfliers=False)\n",
    "plt.xticks(ticks=range(n), labels=range(n), rotation=0, fontsize=8)\n",
    "plt.ylabel(\"Entropy\")\n",
    "# remove top and right spines\n",
    "plt.gca().spines['top'].set_visible(False)\n",
    "plt.gca().spines['right'].set_visible(False)\n",
    "plt.tight_layout()\n",
    "# separate file name: the histogram figure is saved as mnist_entropy_distributions.pdf\n",
    "plt.savefig(plot_dir / \"mnist_entropy_boxplots.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfYAAAB6CAYAAABN7nGrAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAFHxJREFUeJzt3XlMVOf+BvBnqCMVFFChgFF7q1KxqHUBQbbBDakKVfGXYkWLbVzStMmtG5cmRumGttXbxZSGpmJN3W6tNiaYYgiLVFJXoBgrEStRQZQGOqMxHZZ5f39YJ44gzAxzZnnn+SQkzMuced85zznznbNwjkoIIUBERERS8HD0AIiIiMh2WNiJiIgkwsJOREQkERZ2IiIiibCwExERSYSFnYiISCIs7ERERBJhYSciIpIICzsREZFEWNiJiIgkwsJOREQkEbcu7PX19QgICMDMmTMRHx+P9evX4/79+6ivr4dKpUJJSQkAoK2tDYMHD8auXbvsNrbMzEzExcVh+fLlaG9vN/lbZ2cn0tPTMWPGDLz++uvo6Ojotk1mts6up+nq6+uxZMkSxd/To5j/A+6as8FgQEZGBuLi4hAbG4vLly/bdVwy62kZCA8Pd/DobMOtCzsAaDQaFBcXo6ysDF5eXtiyZQsAIDw8HEeOHAEAFBUVISQkxG5jqq6uRkNDA8rLyxEaGorDhw+b/P3o0aN47rnnUFJSgtDQUBw5cqTbNtn1JbuMjIwubY7M/FHM35Q75lxVVQW9Xo/y8nLk5ORg586dDhmjrJxlGVCK2xf2h1QqFTZv3oxjx44BAJ599llcv34dQggcPXoUixcv7nH6iRMnIiUlBZMnT8bHH3/cp7FUVFQgMTERAJCUlIRTp06Z/P3q1auYNGkSAGDKlCk4efJkt23uoq/ZPWTtdADztwd3ynn48OEQQkAIgdbWVvj7+/epLzLVl2XAFbCwP6J///5oa2szPp4+fTpOnjyJ5uZmBAUFPXG6v/76Cy0tLdi9ezdOnz6Nffv2PfG5zc3NSEhI6PLT1NRkfE5rayt8fHwAAL6+vmhpaTF5jRdeeAHFxcUAHnzbbG1t7bbNnVib3eOsmY7524+75Ozv7w+1Wo3Q0FC8/fbbePPNN81+b2Qea5cdV9DP0QNwJnq9Hp6ensbHqampeOWVV7BixYoep6upqUFaWhr8/f0hhMCAAQMAADqdDvv370dTUxPS09MxZswYBAQEoLS0tMfX8/Pzg06nAwBotVoMGTLE5O8LFixAaWkpZs6cibCwMAQFBXXb5k7Mye769evGx5cvX0ZCQgIA4MSJEz1O1xvmbz/ukvOJEyfQr18/1NbW4ty5c1i/fj0OHTpk9lipd9YsA66CW+yPyMnJwcKFC42PQ0JCEBsb2+WEmps3b5o8rqmpgYfHg1l54MABLFiwAADwzTff4N69e2hsbER1dTUA877JR0dHo6ioCABQWFiImJgYk/5UKhV27NiB4uJiDB06FC+//HK3be7EnOxGjhyJ0tJSlJaWIikpyfh7//79e5zucczfcdwlZyEEhg4dCuDB1rtWq7VoPlHvzFkGXJXbb7GXlZVhxowZ6OzsRGRkJN577z3cvn3b+PcvvvjC5PkdHR1YunQpysvLjW01NTVQq9WYNWsWgoKCsHv3bgAPVu5t27Zh48aNiIyMBACzvslPmjQJgYGBiIuLw8iRI7FhwwY0NTUhNzcX2dnZaGpqwtKlS+Hh4YFZs2YhPj6+2zbZWZqduXqajvnbnzvmvHnzZuzZswcajQZ6vZ4nzynE2mXH2amEEMLRg3AlZ86cQXV1NVatWmVsmzNnDgoLC43f5h/asWMH1Go1vL298cYbb9h7qKQA5u8emDO5MhZ2G9BoNCgrK3P0MMhBmL97YM7kKljYiYiIJMKT54iIiCTCwk5ERCQRpyzsP//8M1JSUhw9DFIAs5UXs5UXs3UtdivsWq0W06ZNw8CBA3Hx4kVj
++M3Qbh06RJ0Oh1GjRplr6FRHzFbeTFbeTFbedmtsHt5eaGgoMDkYgDd3QShsLAQjY2NqKysNF7s4XF6vR46nc74o9Vq0dzcDJ4H6Bi2zBZgvs6E2cqL2crLbheoUavVCAgIMGl7/CYI+fn5xtsu1tfX48UXX+z2tXJycpCdnd2lXavVGq+9TPZjy2wB5utMmK28mK28HHqMvaebIHz22WdPnC4rKwtardb4c+PGDaWHShayNluA+To7JbIdv6VQsfGS+bjeysGhl5Tt7WYXT+Lp6WlyIwhyPtZmCzBfZ8ds5cVs5eDQLfbebnZBrovZyovZyovZysGuW+zz5s1DVVUVamtrsWbNGmRkZHS5CQK5JmYrL2YrL2YrJykuKavT6eDr68uTNCT1MN8R//4frv/3/xw9HLIhZisvfi47jtvfttXV/es/Bcbf67fNd+BIiIjIGTjlleeIiIjIOtxilwi33omIiIXdRT1axImIiB5iYSeXwr0SREQ94zF2IiJS3L/+U8A9jXbCLXYiIlLM+C2F8PD0Mj7mXjflsbATEZFDsMgrg4WdiIgcjkXedljYXQiPT5niBwGRnLhu943FJ89lZGSYPH7nnXdsNRayoYcnqvDLgPNiRkS94zpiObO32Ovq6lBbW4vKykocP34cANDR0YELFy4oNjgi2fADisg63Io3n9mFvaGhAefOnYNOp8PZs2cBAGq1Gtu3b1dscETm4kpP5D64vvfM7MKu0Wig0WiQmZmJp59+WskxEbkdflBxHpB1uNx0ZfHJc19++SX27t2LAQMGQAgBlUqFM2fOKDE2IiIispDFhf3HH39EdXU1PDx40TpyTs72DZ7H1Ynsw9nWfUexuLBHRESgsbERw4cPV2I8OHHiBCorK/HHH39g165dUKvVivRD9sds5cVs5eWq2bpzkbd4s7uiogIajQbh4eGIiIjAtGnTzJpOq9Vi2rRpGDhwIC5evGhsz8zMRFxcHJYvX4729nYkJiYiMzMT3t7eaGtrs3R40nGFf4ly5mydfd45O2fOlvqG2crL4sJ+/vx5XL16FefOncPZs2fNPr7u5eWFgoICLFmyxNhWXV2NhoYGlJeXIzQ0FIcPHwYAfP3110hMTIS3t7elwyMHYLbyYrbycqdsH91Acocv+xbvil+5ciVUKpVJ2+7du3udTq1WIyAgwKStoqICiYmJAICkpCTk5+dDq9WiqKgIsbGxiIyMxODBg7u8ll6vh16vNz7W6XSWvg23ovQuKVtmCyiTrzvvlusLR2XLvJTnCustWcfiwr5hwwYAgBAC1dXVfbpATWtrK4KDgwEAvr6+aGlpwdq1a7F27doep8vJyUF2drbV/ZLyrM0WkCdfa7cKHk7nrAWN2crLnbJ90vrprOudJSzeFR8WFoawsDCMHz8ey5YtMzk2Yyk/Pz/jtzqtVoshQ4aYNV1WVha0Wq3x58aNG1aPwd3YazeUtdkCyufrLrvjlGLvbJmX/Tjzekvms3iLfePGjcZd8Tdu3OjTxWqio6Oxc+dOrFixAoWFhYiJiTFrOk9PT3h6elrdLym/q9PabAH75stdvpZzlWzJcsxWDhYX9gULFgAAVCoVBg8ejAkTJpg97bx581BVVYXa2lqsWbMGGRkZCAwMRFxcHEaOHGnczU/2ZYvixmwfkHGr0pmylXn3qSM4U7bOord12BWWNZUQQlgyQUdHBw4cOIC6ujqEhIQgLS0N/fo59u6vOp0Ovr6+0Gq18PHxcehYbM3ehcIZF9qH+Y749//g4elllz4fnQ+WHPO2ZV7OmIWt2Spbd5hXrsYR6629Kbnc9eVcG4sr8ooVKxAWFobo6GicPXsW6enpOHjwoMUdExHZSm8fgjIecpHxPZFtWFzYb926hf379wMA5s6di4SEBFuPye3JuDvX1XSXwZM+SJ05L3f78DcnC3ebJ6Sc3g4NmbOsKfH5YXFh9/HxQV5eHiIiInD69GkMGjTI5oMiIvO+SLAwETmf
3jYMlGZ2Yb9w4QI8PDzw/fffIy8vD3l5eRg1ahQ2bdqk5PiInJK9t9Kdea+AbPjFiVyd2YV906ZNOHr0KAYNGoT169cDAO7evYtFixahqKhIsQESESmlrxcDsvTwDL8okD2YXdg7Ozu77HYfNGgQOjo6bD4oIjIPty7tp7d5bYu9Kq54Hgc5H4uOsd+5cwfPPPOM8XFTU1OX68YTkWPww996T5p3lrZb0k93J1gR2YLZhf3DDz9EYmIiFi1ahGHDhuHmzZs4duwYcnNzlRyf2+DKTeReelvn+ZlA1jL7WvHR0dEoKSnB6NGj0draipCQEBQXFyMqKkrJ8REREZEFLNoVP3jwYKSnpys1FiIiIuoji+/uRkRERM6LhZ2IiEgiLOxEREQSYWEnIiKSCAs7ERGRRBx7I3U3x/9TJSIiW+MWOxERkUScbov9/Pnz+Omnn3Dv3j1s374d/fv3d/SQyEaYrbyYrbyYreux2xa7VqvFtGnTMHDgQFy8eNHYnpmZibi4OCxfvhzt7e04dOgQtm7dihkzZuDUqVP2Gh71AbOVF7OVF7OVl90Ku5eXFwoKCrBkyRJjW3V1NRoaGlBeXo7Q0FAcPnzYXsMhG2K28mK28mK28rLbrni1Wo2AgACTtoqKCiQmJgIAkpKSkJ+fj5UrVyI7Oxv37t3Dtm3bun0tvV4PvV5vfKzVagEAOp1OodErw6C/7+ghdKHT6TBo0CCL7tpny2yBJ+frjPPLlTBbeTFbeVmTrUOPsbe2tiI4OBgA4Ovri5aWFkydOhVTp07tcbqcnBxkZ2d3aR8xYoQi43Qnvp89uD3v4yu8pazNFnhyvg25GX0ak7tjtvJitvKyJluHFnY/Pz/jVrZWq8WQIUPMmi4rKwvr1q0zPjYYDGhpacHQoUMVuT+8TqfDiBEjcOPGDfj4+Nj89Z2tT1ucHGNttoB982W2lmO2ztkns1WWK2Xr0MIeHR2NnTt3YsWKFSgsLERMTIxZ03l6esLT09Okzc/PT4ERmvLx8bFboI7s0xYrobXZAo7Jl9maj9k6Z5/M1j5cIVu7FvZ58+ahqqoKtbW1WLNmDTIyMhAYGIi4uDiMHDkSGzZssOdwyIaYrbyYrbyYrZxUQgjh6EE4O51OB19fX2i1WrvugnGHPh3NXeYzs5V3PjNbeeeztX3yynNm8PT0xJYtW7rsZmKfrs9d5jOzZZ8ycZf5bG2f3GInIiKSCLfYiYiIJMLCTkREJBEW9l486XrKP/zwA6KjozFr1izcvHlTsf7r6+sREBCAhIQEJCQkoLm5WbG+Hr9GtOyYrbyYrbyYrRkE9aitrU3cuXNHvPbaa6KmpkYIIUR7e7uIiooSer1e/PLLL2L16tWK9X/t2jWRmpqq2Os/VFVVJZYtWyaEEOKDDz4Q+/fvV7xPR2O28mK28mK2veMWey+6u57ylStXMG7cOPTv3x8xMTH47bffFB3DqVOnEBcXh3fffRdCoXMdH79GtDvcxYnZyovZyovZ9o6F3Qqtra0m/1PY2dmpWF/BwcGoq6vDyZMncefOHRw5ckSRfh59Tw+vEe2OmK28mK28mK0ph15S1pk0NTUhLS2tS/vBgwcRFBRk0vbo9ZQB4KmnnrJL/4sXL8avv/6K1NTUPvf3uL5cI9rZMVtmCzBbV8Nsrc+Whf0fQUFBKC0tNeu5ISEh+P3339HW1oZz585h4sSJivV/9+5d4+/l5eUYN25cn/vqTl+uEe3smC2zBZitq2G2fchWiYP+snnppZdEcHCwiIqKEvn5+UIIIQ4ePCimT58uZsyYIa5fv65Y38ePHxdTpkwRsbGxYvny5aK9vV2xvjZs2CBiY2PFq6++KvR6vWL9OBNmKy9mKy9m2zNeeY6IiEgiPHmOiIhIIizsREREEmFhJyIikggLOxERkURY2ImIiCTCwk5ERCQRFnYiIiKJsLD/o76+HkuWLOnz62zduhUTJkww3tKvrKzsic/Ny8vrc3/UO2YrL2Yr
L2ZrPV5SVgE5OTlYsGBBr8/Ly8vD6tWrTdoMBgM8PPh9y1kxW3kxW3m5W7auNVo7KykpQVRUFKKiorB3714AQGVlJcLDw5GSkoLk5GSzrmW8Z88epKamIjk5GREREbh16xZyc3NRW1uLhIQEFBcXIyEhAZs2bcLcuXOh0+mQkpICjUaDtLQ0tLW1obS0FImJicbXqKmpQWVlpXEhNBgMmD59OgwGg5KzRBrMVl7MVl7M1kyKXeTWxVy7dk2kpqaatEVGRorm5mbR1tYmpk6dKu7fvy/mz58vamtrhcFgEDExMaKkpMRkmi1btojx48cLjUYjNBqNuHz5ssjPzxcrV64UQgjx1Vdfic8//1wIIcTUqVON02k0GlFUVCSEEOKTTz4Rubm5Qggh3nvvPfHdd9+JkpISERMTIwwGg7h06ZJITk4WQggRHx8v/v77b1FcXCwyMzMVmTeujtnKi9nKi9laj1vsPejs7IS/vz/UajXGjBmDxsZG3L59G88//zxUKhUmT57c7XQ5OTkoLS1FaWkpxo4dCwDG544YMQKtra3dThcREQEAqKurM/4eERGBK1euGF9DpVJh3LhxuHXrFgAgJSUFBQUF2LdvH9LT02335iXHbOXFbOXFbM3Dwt4DDw8P/Pnnn2hvb8eVK1cwbNgwBAYG4sqVKxBCoKqqyuzXUqlUxt/FP/fdebTtYX8AMGbMGJw5cwYAcPbsWYSEhAAAqqqqIIRAbW0tgoODAQDLli1Dfn4+rl69ivHjx1v9Xt0Ns5UXs5UXszUPT557RHl5OWbPng0AmD17Nj766CPMnz8fKpUKb731FgYMGID3338fS5cuRVBQELy9vaFWq7u8TlZWFj799FMAwLp1657Y39ixY5GamtrlOatWrcKyZctw8OBBBAYGIjMzExUVFfD19UVycjJu376Nb7/9FsCDewYbDAazTgxxZ8xWXsxWXszWSg45AODC2trahBBCdHZ2ivj4eNHY2GiXfktKSsT69eu7/dvChQvtNg6ZMVt5MVt5MduuuCveQqdPn0Z8fDwiIyMxZ84c4+4XR1m0aBFGjx7t8HHIgNnKi9nKi9l2pRLin4MLRERE5PK4xU5ERCQRFnYiIiKJsLATERFJhIWdiIhIIizsREREEmFhJyIikggLOxERkURY2ImIiCTCwk5ERCSR/wdxUdtt7wH16gAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 500x120 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# histogram of log-entropy per model (one panel per model, shared bins and axis limits)\n",
    "# NOTE(review): np.log of the entropy (bits) is > 0 for entropies above ~1 bit; such samples\n",
    "# fall outside the bins and are not counted - confirm this is intended\n",
    "LOG_ENTROPY_RANGE = (-14, 0)\n",
    "bins = np.linspace(*LOG_ENTROPY_RANGE, 30)\n",
    "fig, ax = plt.subplots(1, 4, figsize=(5, 1.2))\n",
    "\n",
    "n = len(entropies_dict.keys())\n",
    "keys = list(entropies_dict.keys())\n",
    "for panel, key in zip(ax, keys):\n",
    "    # the 1e-6 offset keeps zero entropies finite (log(1e-6) ~ -13.8, inside the range)\n",
    "    entropies = np.log(entropies_dict[key] + 1e-6)\n",
    "    panel.hist(entropies, bins=bins)\n",
    "    for side in ('top', 'right'):\n",
    "        panel.spines[side].set_visible(False)\n",
    "    panel.set_title(key, fontsize=6)\n",
    "    panel.set_xlabel(\"Log Entropy\")\n",
    "    panel.set_yscale('log')\n",
    "    panel.set_ylim(1, 10000)\n",
    "    panel.set_xlim(*LOG_ENTROPY_RANGE)\n",
    "ax[0].set_ylabel(\"Count\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(plot_dir / \"mnist_entropy_distributions.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# select the DM model and re-align its clusters to the true labels for the analysis below\n",
    "key = r\"DM, $p_\\theta=0.99$\"\n",
    "data, labels, feats, model = model_dict[key]\n",
    "entropies = entropies_dict[key]\n",
    "predictions = predictions_dict[key]\n",
    "# confusion matrix of the raw (unaligned) cluster predictions\n",
    "confusion = confusion_matrix(labels, predictions)\n",
    "\n",
    "# greedy matching of clusters (columns) to labels (rows), largest count first\n",
    "labels_inds = np.full(len(np.unique(labels)), -1, dtype=int)\n",
    "predictions_inds = labels_inds.copy()\n",
    "best_fits = np.column_stack(np.unravel_index(np.argsort(confusion.ravel())[::-1], confusion.shape))\n",
    "for l_ind, p_ind in best_fits:\n",
    "    if l_ind in labels_inds or p_ind in predictions_inds:\n",
    "        continue\n",
    "    labels_inds[l_ind] = l_ind\n",
    "    predictions_inds[l_ind] = p_ind\n",
    "\n",
    "# reorder feature columns so column k holds the cluster matched to label k\n",
    "predictions_inds = predictions_inds[np.argsort(labels_inds)]\n",
    "feats = feats[:, predictions_inds]\n",
    "predictions = np.argmax(feats, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS4AAAB6CAYAAAABKL4WAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHxFJREFUeJzt3XtYVHX+B/D3DJfhfg0B5S4IohiwgIoKFikWiqRlZKSxbbsrXnJT0/aXYq71bKvmqiVurZs95qplhRW1aQLqSuYFUAGRuCmIJHIZYLjN5fP7g/UUcnEGmIGpz+t55nmYM2fOvOfwnc98zznfc0ZERATGGNMj4qEOwBhjmuLCxRjTO1y4GGN6hwsXY0zvcOFijOkdLlyMMb3DhYsxpne4cDHG9A4XLsaY3jHU1QuJRCJdvVSfND1RgHMPTH9OzNDX7Jx7YDTJzT0uxpje4cLFGNM7elG4xGIxli5dips3b6KiogI2NjZDHYkxNoREuro6xEC2o42MjFBZWQkHBwcAwMmTJzF//nzU1dVpvKxfw/Y/oL+5Af3NzrkH5he3j8vKygojRoyASCSCSCTCpEmTMGXKFJiamg51NMbYENCLwtXW1tblvomJCaKjo2FhYTFEidRna2sLNzc3eHh4wMPDAwYGBkMdiTG9p7PhEAPR0dHRbVpoaCjMzMyGIE3PAgICkJiYKGRSqVQQi8UwMzODiYmJ0B2vr6/Hhg0bUF1dPZRxfzEcHBwwd+5chISEAOjc3Dhz5gzEYjH279/fr01VNnBisRgBAQGwtrYGEaG8vBz19fUwMTHBnTt3Brz8YV+4xGIxJkyY0G26vb39sOi9uLm5ITw8HLGxsYiLi4OpqSmICEQEsbjnDu1bb72lV4VLIpHAx8cH5ubmMDIygoeHB0aNGgWRSITCwkKcPXt2SN6Pi4sLnnrqKSQmJmLcuHEAOr8wwsLCIBaL4e3tLRSub7/9FtnZ2ZDJZDrPCQDu7u6YOnUqxowZ0+PjRIScnBwcPXpUx8n6TyQSwdHRETNnzsTo0aO7PCYWi+Hj4wMLCwsQEaqqqtDY2AiJRIKCggKkpqYOqM0M+8JlamqKF198cahj9MrIyAjh4eF4+umnhWl398XpM2dnZ8TFxQHo/B+MGzcOlpaWMDIygo+PDzw8PCASiXD58mUcO3YMZ86cwbFjx3Sa0dfXF7///e/h6emJsrIyXLp0CaNHj4a1tTWcnJzw6quvCv+HBx98EFlZWfj4449RVlams4xWVlaYPn06IiMjMWPGDAQEBPQ4n0qlwtmzZzFy5EiUlJTgxIkTUCqVOsupCUNDQ7i4uCAmJgYjR47EvHnz4Ovr263N37x5E21tbSAiuLm5wcrKCh4eHqioqIC5uTm2bdvW/wwDfRPaZGBgAE9PTzz77LNDHaVXJSUlOH78OKZMmQJ/f39IJBLI5XIUFxeDiCCXy+Hn5weJRDLsi5mhoSH8/f3h4+ODsWPHYtOmTULmtrY2yOVyoQejVCohkUgwadIkBAUFITs7GyKRCGfOnEFzc7NO8trY2GDMmDFoaGjAN998g127diE8PBzW1tbw9fXFqFGjMHLkSIwbNw5z587F3LlzUVRUpNPC5e3tjTVr1mDq1Klob29HfX092tvboVAo0NLSArFYDHt7e9ja2iI8PBzh4eHIysqCWCxGenp6j7tJhopYLIaTkxOmTZuGwMBArFu3rtd5L168iG+++QZ1dXVQqVSQSqVwcnJCXFwcQkNDsXz5cmRnZyMjI6NfWYZ14TIzM8PChQuHOsZ9ZWZmYuPGjUhMTIS9vT2amprwySefQKlUQiaTYe3atcI+mOHIwMAA9vb2CA4ORmJiIhYsWAAiQk1NDZqamnDr1i3cuXMHzc3NUKlUADp7lTY2NvD09IS7uzsmT56Md999F7/73e9w/PhxneTu6OhAc3MzOjo6UFJSgoKCAhQUFAiPOzo6YsqUKXj33XdhZ2c3
JF8cNjY2wpdZfn4+Lly4AKlUitbWVty6dQtGRkYYP348/Pz84OzsDG9vb4SHh2PHjh1YuHAhLl++DLlcrvPc9zIxMYGHhwfmz5+PDRs2wNjYuNd5KyoqsHnzZnz77bddvsQMDQ1RXl6OAwcOYNSoUdi+fTsCAwP7F4h0BIBGN2NjY4qIiOh1ecXFxeTl5aXxcrWZWyQSkYGBQbfpr7/+OikUCiIiUiqV5OfnN2xyGxkZkZeXF61bt054rkqlourqatqyZQstWrSITE1Ne31+REQE/ec//yEioqamJjpw4MCAcmuSfcqUKXT27FmSSqX0/vvv9zrf9evXSaVSUVtbG82ZM0en69zU1JTeeecdOnPmDCUkJPTZduLj46mwsJBUKhURER08eJCcnZ2HtK0YGxvTyJEjKTo6mv71r3/1+HyVSkV1dXVUWlpKpaWltGrVKrKxselxefPnzxeeV1lZ2e/cw7JwGRoa0sSJE6m+vr7X5V29epU8PT2HTQHo63a3IRIR3b59m3x8fIZFbhMTEwoKCqIdO3Z0ea5cLqfly5erne1uY2xqaqKDBw8OKLcm63zChAl08OBBKisro5dffrnXD15VVRUREZ07d47Cw8OHdVtxdnamjo4OYZn+/v5D1lZEIhGFhITQ4cOHSS6XC/OrVCpSKpXU3t5O7e3tVFFRQdu2baOJEyfSpEmTyMjIqNfP9TPPPENERAqFggoLC/ude1huKvr5+WHLli3CqT30v31FRkZGQnc/Nze32/gufXD06FE0NjYOdQwYGxtj9uzZWL9+fZejtkSEhoYG7Nq1S63liEQi4ehpW1sb8vLytJK3J01NTcjKysLp06exe/fuHueZPn06JBIJAGDjxo24cOGCzvL1x61btyCXy2FoaChsjhsaGkKhUOg8ywMPPICVK1diwYIFADrbhkKhQHNzMxobG1FWVgaFQoEZM2aotTx/f3888sgjADqHBe3fv7/f2YblAFQbGxtMnjwZwE8r69ChQ13mef/991FfXz8U8fSeiYkJkpKSsH37dqFoERE6OjpQW1srnFqlDmdnZ4wdO1ZbUftUVlaGXbt29Vq0AODIkSOws7MDAEil0mG1s7s3X3zxhVCo3N3dezxDRCwWw8DAAAYGBlrbd7dy5Uo8/PDDADrbR1tbG7788kvY2dnBx8cHL730ktpFC+j8shyss12GZeHKzc3FwoULhZX15z//GYsXL+4yj0ql0svBhcbGxkN6dNHS0hJ79uzB9u3b4eLiAqCzUebn50MikWhUtABgypQpeO6554TlDKdD+O7u7sP+SG5PXFxchF6sm5tbtw/72LFjkZSUhK1bt+LLL79EfHz8oGdwdXVFaGgonJ2dAXT2kP7xj39g3rx5AAC5XI6cnByNlunu7o7g4GAAg9BWNN4g7icMwvb/z/cVPfTQQySRSIbFviJNcqtUKtq6dSulp6eTVCollUpFKpWKjh8/ThEREVrN7eDgQF9//XWXPEREJ0+e7Nf7sre3p+TkZGF55eXl9Mgjjwwod3/WubW1NcXHx1NsbKwwLSEhgeRyebd1f/d2/fp1WrduHZmYmOi8rdjb29PSpUu75Onpdm/m3m4VFRWDnjs+Pp6uXLkizFNQUEBJSUkD+hwsWbJk0NoKFy4t5x4/fny3QqFQKEipVHaZrlQq6fDhw1rNnZiYSAUFBcI81dXV9Oabb/Z4JFSdW3JysnC0lIgoLy+PRCKRTgvX6NGjaf/+/SSXy0kul1NHRwd1dHSQQqHott5/TqVS0aefftrnzm9t5F68eDHV1dV1WW/91dDQQB9++CFZWVkNeu4VK1ZQcXGxMM/BgwfJ1dW1358DS0tLeuWVV4TlDbStDMud8/ro7mbWqFGjMGfOHKxYsQIAejztp6dTlU6fPo29e/dqNePdMUV3nTx5Env27NG4yy4SifD3v/8df/jDH4T3cvHiRSQkJOh8810mk6GoqAiGhr03ZZlMhkWLFgmbNgYGBmhr
a4NUKtX5KUBmZmawsbERNmEVCkW3AbHe3t7C46dOncKJEyfQ0NAAoPMAyM2bN1FUVITq6mrI5XKtHKRKSUlBVFSUcCrP/PnzYWNjg9WrVyM/P1/j5a1YsQLr168H0HlQJS8vb2BtRePy3E/4hfa4IiMj6dChQ1RVVUXV1dVUU1NDTU1Nai07JyeHPv74Y9q8eTNNnz6djI2NtZr76aef7tL9b2lpoU8//ZR8fX3VXn+LFy+mM2fOUHNzs7Cc3NxcevLJJ3vsufWHJv9PsVhM9vb29Oijj9LatWtp165d9MEHH3RpK8HBwSSRSEgsFgu3e7/tddFWgM7NJSKi/Px8mjVrFjk5OZGDg0OX28+HQ0ycOJHMzc3JzMyMzMzMyNTUlIyNjUksFms99yOPPEKZmZnCfG1tbZSXl0erVq3S6H+UlJREhYWFwnLOnz9Po0ePHlBuLlwDzL127douH+L7+fDDDyk6OprGjx9PHh4e5OzsTHZ2dn3uaxms3La2tpSWltZlvubmZvrhhx8oIyODnn322W6bHZ6envToo4/SmjVr6LvvvqNbt25Re3u78PyzZ89SfHw8mZubD0ru3rL3dROJRGRiYkI2Njbk4OBATk5OQlv57LPPyNbWtl/tTRu5J0+eTKtXr6bJkyf3+j+/W7gqKyt7/IDrKrepqSnt3bu3y7xyuZxu375N33//PS1YsKDX50ZGRtLOnTvpu+++ox9//FF4T5mZmRQUFDTgLzm93VS8efPmsDgVwszMDObm5l2mtbe3IysrC4WFhViyZIkwfdeuXdizZw9KS0uHZAxafX09Vq9eDblcjrlz5wIAzM3NMXr0aLi6usLNzQ1JSUlQKBSoqKgAAIwZMwYWFhawsLDAiBEjYGRkBKDzqO7evXuxc+dOVFRUDNlVFwAIR597WqdZWVlob28fglQ9u3z5MoqLi9UampGXlzek2VtbW/HGG29AJBIhMTERQOdpOw4ODrC1tcXmzZsxd+5cyOVyXLp0Cb6+vrC0tMSoUaPg4OAAR0dHWFlZCW0mJycH77//Pq5cuTLgo896VbjKysrg6emJ06dPCydvDrUbN24gMzMT5eXlyMrKQk1NDZRKJaqrqyGVSiGRSJCYmAiRSITs7OwhK1p3Xb16Fbt374aZmZkwBkckEkEikcDLywteXl4gIuEcMwsLi25DCtLS0vDVV1/hxIkTuHbtms7fQ1/EYrHwIQM63+9QDN7sjUwmU7vI3x2MOpRKSkqwZ88emJubCwNRgc4C5u3tDScnJ6hUKkybNg02NjYwMjKCubl5t327Bw8exO7du3Ht2rVB+X8M+8JlYGCAkJAQxMfHCztgnZycsGnTJtTU1ODq1avIycmBvb096urqUFhYqNN8GRkZKC0tRW1tLUpLS7ucVCoWi7t8sO9e4mOofffdd9i8eTOOHDkCc3NzBAQEIDo6GiNHjgTQWcgsLS27PGffvn3Izs5Ge3s78vLykJ+fD6lUOhTx+yQWi4XLIKWlpeHSpUvDqnDdz+zZs4UvipycnGFxdsjly5exZcsWZGZmYuzYsVi+fDmAru3E2tq62/PKy8vx0UcfoaSkBBcuXEB2dvagZRpWhSs0NBSzZ8/uMk0sFmP8+PGYNWuWMM3b2xtubm6QyWQoLS3F1atXYWNjA6lUisOHDyMtLU1nmUtLS1FaWtrjY2KxGDNnztRZFnU1NTXh1KlTOHXqFExNTeHt7Y1z5871Ofj06NGjKCwsHNYjzw0NDREbGytc82r37t2oqqoaFj1zdb344otCb0WpVA6LL7q2tjah8Hh5eaGmpgZAZ6fiN7/5Tbf5a2pqUFlZiRs3buCbb74RdjsMpmFVuEJCQrBhw4b7zicSiWBiYgITExPY29sjNDRUeKypqUmnhasvIpEIDz300FDH6FNrayuuXLmCK1euDHWUAbOyssKqVatARJBKpTh9+vSwGsmvjp8PlTAxMen1KrpDQaVSobi4GH/5y18AdBauKVOm
dJuvuroaZWVlWt3MHVaF68aNGzh58iQcHBy6jDfqiUwmAxHB1NQUBgYGUKlUaGpqQnFxsY7S3t/PT0Bm2mVoaAhfX1+Eh4eDiLB///4hPWDQX4WFhQgODoZIJEJRUdGwOrBwL6VSiVOnTg3Jaw+rwpWRkYHKykoEBwf3uM38c7du3YJKpcIDDzwAiUQCpVKJ69eva3z+lDbp43ly+srMzKzL5bNXrlypV5uId4WEhAjtpqSkZFhvmg+lYVW4WlpacOnSJVy6dGmoo2iFPu0k1icikQgODg5Yvny5MDRCH4sWgC7XbpfJZHr7PrSNt2O0SKVS4fbt28KleaqqqvRun4s+uPsrRB0dHcLPk+mrhoYGYYe8t7d3n5dI/jXjwqVFCoUC0dHRaGxsxNWrV5Gbm8u9Li1oa2vD+fPnkZWVhbq6Oo2uETXcvPTSS2htbUVbWxt8fX3vu6/310pEOjreOlz292j6djn3wPSneelr9sHK/frrr+P27dt477330NLSovHzfw1thQvXfXDugeHCpXu/htwaF64TJ04gKipK41CMMTZYNN7HlZ6ejqlTp+Lll1/+RQxaZIzpn35vKv73v//Fjh07UFRUhMWLF+P555+/79grxhgbDBr3uDo6OnDkyBFs374dSqUS69evh4eHB2JjY7WRjzHGutF4AOqMGTPw+OOPIyUlBSNGjBCm19bWarQcj3U/nU9Y/tcYTWP8YmhzPdxd9q95/d5LX9udLtqJNpatLRr3uOLi4rBy5UqhaP3zn/8EALzwwguDm4wxxnqhduFSKBSQyWRITU1Fa2srWlpa0NjYiE8++USb+RhjrBu1NxUPHDiAffv24fLly4iJiQERwdjYGHPmzNFmPsYY60btwrV48WIsXrwYZ8+exaRJk7SZiTHG+qR24XrttdeQnJyMbdu2dRtp+9FHHw16MMYY643aheuPf/wjAGDr1q1aC8MYY+pQu3BdvHix18fc3d0HJQxjjKlD7cJ1/vz5Xh977LHHBiUMY4ypQ+3ClZycrM0cjDGmNrUL15IlS5CSkoLQ0FBh5zwRQSQS4dy5c1oLyBhj91K7cKWkpADoe5ORMcZ0QeNTfvLz8zF37lxMnDgRcXFxyM/P10YuxhjrlcYnWf/2t7/FBx98AD8/P1y7dg2LFi3C999/r41sjDHWI417XI6OjvDz8wPQ+VNKP79CBGOM6YLaPa41a9ZAJBKhvb0dERERCAoKQk5ODmxsbLQYjzHGulO7cM2ePRsAEBPz0/V65s2bN/iJGGPsPtQuXJGRkcLfJSUlqKqq6tcvuDDG2EBpvHN+xYoVuHHjBnJychAYGAgiQkREhDayMcZYjzTeOX/x4kWkpqbC09MTR48e5Z8IZ4zpnMaFy8jICABgZmaG9PR0FBYWDnooxhjri8aF6+2330Z7ezu2bduG1NRUbN++XRu5GGOsVxrv4xo/fjxyc3Pxww8/IDExEUFBQdrIxRhjvdK4cP3pT3/Cjz/+iJCQEHz22WdwcHDAjh07tJGNMcZ6pHHhys7OxsmTJ4X7Px8mwRhjuqD2Pq6Wlha0tLTgwQcfxIkTJ1BfX4/09HSEhYVpMx9jjHWjdo8rJiYGIpEIRIQrV64I0+/94QzGGNM2tQtXRkZGl/tKpRIGBgaDHogxxu5H4+EQx48fR2hoKCIiIhAaGopjx45pIxdjjPVK453zycnJSE9Ph6WlJRobGzFr1izMnDlTG9kYY6xHGve4VCoVTExMAAAmJiZQKpWDHooxxvrSr3FcoaGhcHd3x/Xr17Fu3Tpt5GKMsV5pVLiICK2trcjOzkZNTQ0cHBwgFmvcaWOMsQHRqHCJRCJ88cUXeO655+Do6KjWc4gITU1N3aar2luEvxsbGzWJMWgsLS17Hc7RW+7B1p/1oG7uu8seqvV7r1/6+tYmXbQTTZatTX3lvktEGl4NMDY2Fo2NjQgJCRF6W3/72996nb+xsRHW1taavITOSKVSWFlZ9fgY5x58
nFu3fom571J7O+/ixYuIjo5GcXExLl++DH9/f8TExHS5lHNPLC0tIZVKIZVKUVFRAQCoqKjo8ndfj6k7nybLuHuztLTk3Fp4T4OZWxt5fmm5e3rtoczd32Wok/sutTcVly5dip07dyIwMBCZmZl4++238fnnn9/3eSKRqFv1/Pl9Kyurbvf7M19/lsG5tfOeBjO3NvP8UnL3J7s2c/d3Gequc0CDHpeZmRnCwsJgbGyMmTNnorm5We0XYYyxwaR2jysvLw8LFiwA0Llj7+f3P/roI+2kY4yxHqhduM6fPz/gF5NIJEhOToZEIgGALn/39Zi686m7DM6t/fc0mLkHO88vMfe9jw117v4uQ10aH1VkjLGhxqNHGWN6hwsXY0zvaHyuYn/JZDIkJSXB2NgY06dPx+TJk/H6669DKpUiISEBaWlpaGxsxIwZM3DhwgXcuXMHUVFRWLJkCWQyGSIjIzF//nx89dVXGDduHBYsWIATJ04Ig2G9vLxw4MABKBQK5OTkwNXVFXZ2dhgzZgxiY2OxceNG2NvbIyoqCk888QTn/l/u559/Hq6urtixY4eQfdGiRd1yx8fHIyIiAuvXr0djYyMsLS1RV1fHufU8973Z/f39kZeXp5XcBQUFOHToEFasWAE7OztYW1vj5s2b/cqts31c+/fvh42NDebMmYOnnnoKhw8fBgA88cQTOHLkCACgvr4eq1evxt69e6FSqbBo0SJ8+OGH2LBhAywsLKBUKnHq1Ck4OjoiLCwM33//Pezt7RETE4OoqCgAQGpqKr799ltMmjQJCQkJeOqppxAWFoawsDBMmzYNsbGxao0/+7XlBiBk9/Ly6pb71VdfxZUrV5CamtolO+fW79y9ZddG7h9//BEuLi6or69HQkICAgMDsWvXrn7l1tmmYmVlJVxdXQGg1yunbt68GUuXLsXnn3+OmJgYPPbYYzh+/Dj8/f0xYsQIjBs3Dl9//TXefPNNvPvuuwgPD8dbb72FlJQUYRn//ve/hZX88MMPY9asWXj22Wdx6NAhrFmzBrW1tZz7ntwAhOyurq495k5OTsa1a9e6Zefc+p1bneyDlXvhwoWYNGmSkD0xMbHfuXW2qeji4oLKykoEBgZCpVJ1eYyIsG7dOjz66KMIDg5GcHAwYmNjERMTg8DAQMhkMhQUFMDU1BSzZ8+Gra0tjI2NYWtrC+CnlX3jxg1YW1vjyJEjeO211xAREYEnnngCiYmJeOedd6BUKjFv3jzOfU9uoPMc1NjYWHh7e6O1tbVb7vb2dri4uMDY2FjIzrn1P3df2Qc7t6WlJbZu3dol+5EjR/qVW2ebijKZDMuWLYOJiQmmTp2KWbNm4f/+7/+EnklVVRVCQ0NhZmYGhUKB9vZ2TJgwQaj0+/btQ1FREWpra9HQ0IDExER8/PHHMDMzg5+fH5YuXYrk5GRER0fDysoKGzduxAMPPAALCwssW7YMb7zxBmQyGZYsWYKpU6dy7v/lDgwMhJ+fHz799NMu2e/NvWTJEoSFhWH58uVC9tu3b3NuPc99b/YHH3wQubm5WskdHh6OvLw8IbtSqYRIJOpXbh7HxRjTOzwcgjGmd7hwMcb0Dhcuxpje4cLFGNM7XLgYY3qHCxdjTO/opHBlZWVh+vTpiIyMxMMPP4wLFy50m6e8vBzHjh3TRRy1cW7d09fsnFvHSMtqa2tpwoQJVFVVRUREDQ0NdPHixW7zZWRk0KpVqwb99ZVKZb+ex7n7p7+5ifQ3O+fun4G0Fa33uNLS0hAXFwdnZ2cAgLW1NYyMjBAZGYnJkydj2bJlAICUlBQcPnwY06dPR11dHfbt24dp06YhPDwc6enpAIBjx44hKCgITz75JCIiIlBeXo7GxkbExsYiMjIS8fHx6OjoQGZmJubMmYPHH38cW7Zs6fJLRFFRUWr9dhzn1m1ufc7OuXXfVrTe4/rrX/9Ke/bs6TKtpaWFVCoVERHFxsZSUVFRl6p+584dio6OJpVKRc3N
zRQZGUlERBMnTqTa2lpqa2sjDw8PKisroy1btlBKSgoREW3atIk++OADysjIoGnTpgmvERcXR1VVVVRSUkLx8fGcexjm1ufsnFv3bUXrJ1mPHDkSP/zwQ5dpZWVlWLVqFVpaWlBaWoqqqqouj5eUlCA/Px8PPfQQAKCmpgYAoFQqYWdnBwAYP348AKC4uBgvvPACACA0NBRnzpyBm5sbQkJChF/DTUhIwMGDByGTyfDMM89w7mGYW5+zc27dtxWtbyrGxMTg6NGjuHXrFoDOX9Bdu3YtVq1ahZMnTyIoKAhEBCMjIyiVSgCAl5cXJkyYgIyMDGRmZiI3NxdA55nm9fX16OjoQH5+PgDA29sb586dA9D5gx4+Pj6db0z801ubM2cO0tLScPz4ccyaNYtzD8Pc+pydc+u+rWi9x2VnZ4eUlBQ8/fTTICIYGBggOjoaL774Ivz8/ITLaAQEBOCVV17Bk08+iffeew/x8fGIjIyEgYEBAgICsHPnTmzatAlRUVHw9PSEk5MTjIyM8MILL+CZZ57BoUOH4OjoiLVr1yIrK6tLBmNjY/j5+UEsFsPQUL23zLl1m1ufs3Nu3bcVre/jGkwdHR1ERNTW1kYBAQGkUCjUfu6yZcvo/Pnz2orWJ86te/qanXOrR2cXEhwMqampeOedd9DY2IiVK1f2eqXJeyUlJUEqlSIkJETLCXvGuXVPX7NzbvXw9bgYY3qHT/lhjOkdLlyMMb3DhYsxpne4cDHG9A4XLsaY3uHCxRjTO1y4GGN6hwsXY0zvcOFijOmd/wdmCdnzLmq/FwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 300x120 with 12 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS4AAAB6CAYAAAABKL4WAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAIjRJREFUeJztnXlcVFX/xz/DMMOAIAoJiqAQKogbKIuiLIkJiiAmpaVZPGY9bmU/zUczc0krKzU1pXyqR1vcS9NoUVlMsRSXBCTBXRQcFYGBGZhh5n5/f/AwD8M6M8zC2Hm/Xuf1Yubec+77Hu587zn3nnMvj4gIDAaDYUFYmVuAwWAwdIUFLgaDYXGwwMVgMCwOFrgYDIbFwQIXg8GwOFjgYjAYFgcLXAwGw+JggYvBYFgcLHAxGAyLw9pUG+LxeKbaVIvoOlGAebcNfSZmWKo7824bunizFheDwbA4WOBiMBgWBwtcDAbD4mCBi8FgWBwmuzivLS4uLggODoa9vT1SU1Nx//59cyu1mYiICPj5+UEkEqGiogJ37tzB3bt3IRaLUVJSArlcbm7FRw4vLy9cv37d3Bo64+zsDFtbWwwbNgxubm6wsvpf24KIUFZWhvPnz+PChQtmtDQ/PFM9j0ubOxeurq6YMGECXn75ZQgEAmRnZ2PKlCkG9TDlHZdOnTrhyy+/hJubG5ycnMDn86FQKCCTydQpLS0N+/fvR0FBgdm8+Xw+OnToAFtbW4jFYp220xrmuqt49OhR3L59G1u3bkVWVhZqamp0LsNUxwqfz8eTTz6JadOmoVOnTrC2toarqys6dOigUSYRQaFQ4MGDB/jpp5+wbt06KBQKs3m3hLOzM6ytrVFRUYGqqiqtnHTxblctrj59+iA+Ph4BAQEAgM6dO5vZSD/4fD569eqFN998ExMmTGhxXW9vb4SEhGDv3r3YuXOniQz/R0BAAGbOnAkbGxsIBAKIxWK8/vrrJvcwFDweD6GhoRg5ciSqq6vh5uaGTz75BEePHoVMJjO3ngbW1tbw9vbG/Pnz4evri5CQEAgEAq0CiYODAwoKCvD9998b3EsgEOCll15C3759cebMGaSmpuLOnTta51+zZg26d+8OKysrVFdXo6amBgcPHkRKSorhJMlEAGg1xcfHU05OjjrP3bt3qX///lrl1TYZw7t+srW1pVGjRtH27dubLbOkpIQUCoXGdydOnKC4uDiTevfu3Zs+/vhjjXxlZWVmrW9t3ZtLzs7OdPDgQeI4jlasWEFLly6lp556iuzt7dvVseLq6kqzZs1q8Thpidu3b9M777xjFO++ffvSiRMnSKFQ0Pnz52njxo00atQorfbLysqKioqKGm0jPT2d4uLiyMrKyiD13a4C14gRIyglJUWdp7KykpKTk7XKa21tTXw+36wHo0gkomHDhtFPP/3UbHkXLlygNWvW0Pr16ykzM5PKysqIiKiqqop2795NXbp0MZn3pEmT6ObNm+o8FRUV9NFHH7UpcLTVW1v35o6B2NhY4jiOTp06RdbW1ur9jImJIUdHR6O6a1tu9+7daeHChRp1ryvGDFyxsbF06dIl9fLq6mpau3atVvv2yiuvqI/phhw6dIj8/f0NUt/tqqt4+fJlpKWlISwsDA4ODujQoQOmTZuGw4cPY//+/c3ms7e3R2BgIEpLS8160bJbt26YMWMGxowZ0+Ty6upqJCcnY8eOHZBIJBg/fjwWLFiAYcOGQSQSoWfPnhgwYADS0tJM4uvk5IQePXqoP5eXl2PBggXqz+7u7vDy8kJpaSmkUinEYrHW1yvMgZWVFXx8fAAAqamp4DgOADBlyhR06tQJs2fPRk5OjjkVIRKJMHz4cMybNw/dunVrcp2CggIUFhaivLwcHMfBxsYG/v7+cHNzA5/PN7pjp06dIBAI1J+trKw0bhI0R9++fbF27Vp06NChyeUBAQGYPHkyrl+/jvLy8jY5tqvAJRaLkZ6e
jpCQEISGhqJ79+6ws7PD7t27ERYWhlOnTjWZz8vLCy+++CLOnTtntsAlEAjg5+eHpKSkZtc5efIk9uzZA4lEAgD44Ycf0KNHD3Tt2hW9evVCly5dEBAQYLLApVKpoFAoIBQKAQBCoRBhYWFQKBRQqVSIjY1FXFwcCgoKcOfOHZw+fRpFRUXIy8tDaWmpSRx1QaVS4cyZM1CpVPDz84O3tzeuXr2KwMBAdOvWDR07djS3Inr37o3o6OhGQYuIcO/ePeTn52P37t1ITU3FlStXoFKp4OLigi1btmDcuHEmCVyPPfaYOnA9ePAABQUFyMvLazXP/Pnz1cdSU3Tv3h2RkZE4cuQIUlNT2yapc7tST6BDE71r1660ePFiun79OhERcRxHYrG40Xp8Pp88PT3phRdeoHnz5lFgYKDZmv8eHh60cuXKZst58OABBQYGqrsvdcna2ppWrVpFRLVd4127dpnMOyEhgS5evKiRT6VSkVgspmvXrhHHcY3KvXfvHk2cOJGEQmG76yrW1eepU6eI4zj66quvKCgoiO7du0dERImJiWRnZ2e2rqKtrS0tWLCgUT6VSkU3btyg999/v8lrQIMHD6bjx49r5CksLKQVK1YYxTs6OpqOHTtGly9fpo0bNzZ53NZPQqGQ5s2bR0qlUqNcmUxG165do5ycHBKLxURU28VdtWpVm73bZeACQDY2NhQeHq7eYY7jGv1Te/fuTb///jsplUp6+eWXzXIw1qXg4GBKTU1tlJ/jOKqqqqJ33nmnyR87j8dTBzyO4ygjI8Nk3q6urjR//nwqLy/XqewvvviCBgwY0C4DV10qLCyk8vJyjeB7+vRpCg0NNVvgioyMpB9//FEjD8dxVFZWRjExMc3mS05OpgcPHmjku3z5Ms2dO9do3sOHDycfHx8SiUQt7pOVlRWFhoY2OslVVFTQ4cOHKSgoiADQihUriKj25s+XX37ZZu92O3JeLpfj5MmTGDNmDJRKJVQqFTw8PGBtXdu7FQgEWLx4MXx9fcFxHFQqlVl9ZTIZiouLNb4jIsjlchw6dAjLly9vcsxN586d1V2YmpoaVFZWmsQXqO2a79y5E++99x7Ky8tRUVEBuVwOuVyO6upqSCQSlJeXq6+11DF06FB4eXmZzFMfPDw8sGbNGpSUlKjd3d3dm73+YgqeeOIJxMbGanxXU1ODs2fP4pdffmkyD5/Px4ABA+Ds7Kz+TqlU4tq1a8jIyDCaa2ZmJvLz81FdXd3sOjweD926dUNmZqbGEI66a7lJSUnIysoCACgUCvVAa4OMG9M5POsJ9Dxz2tnZ0eHDhykjI4N++eUXio6OJpFIRM8//7z6tuu+ffsoODjYrC0uOzs7mjhxIikUClIqlVRTU0MymYyWLl3aYr5PP/1Uva1bt241u76x69va2pq8vb1p1apVtGrVKnrzzTepQ4cOxOPxyNraWuMW988//0wjRowwSn3r495S6tixIxUXFxPHcTR8+HCd8hrau67VUZ/S0lJKSkpqNs/o0aPpr7/+0sjzyy+/0LBhw0x+jDdMDg4O9Pvvv2uUVVNTQ0uWLKHOnTtrrOvp6Unr168nmUxG33//fZu9233gqp/69+9P3t7edPv2bVKpVERU29TmOM7sXUWgttvXrVs3WrJkCc2cObPV9YVCIX3++efqbWVlZTXbZTBHfQO13fGbN29q1HdCQoLRfvyGdAdAnTt3puLiYpLL5S3+2I1d525ubpScnNwoz+3bt1sc25STk9OoG7Z582azHeN1yc7OjubMmdOorPHjxzd5SaQucBERHT9+nGxsbNrk3a7uKrZGdnZ2o++++uorrFmzBn/99ZcZjDQhIhQXF2P16tVarf/ee+9h/Pjx6s83btxotstgLnJzczVujWdkZOg0itrcuLi4wMrKCiUlJXpN+zEUIpGoyTtubm5ukMvlyMzMRGRkpMayzz//HD179jSRofY4OTlh1qxZWLlyZaNl2dnZUCqVreYPDw/H
kSNH9JfQOTzrCQxw9mx45pHL5RQdHU0CgcAsZ9H6ycXFhZ599ln64IMP1BckW0oxMTGUlZWl3qf9+/dTv379zHoWbZjkcrlGnZ84cULri/L6ehvKvS7dvXuXOI6jKVOmUMeOHY3q3lp5TXUViWpbsSqVihQKhUZSqVRN3tk1d4vLzs6OEhMTNfzPnz9Pzs7OxOPxmszz1ltvkVwuJyKiY8eONbmeTs4676WetOXgEwgEJJFIGv0TOY6j7OxsCgsLM8vB6OnpSZs3byaJREIVFRUkk8lILpdTZWUlSSQSkkgkVF5eTv/5z380ugPW1taUlpamcft4y5YtBpsO0Zb6FggENHfu3Eb1ffbsWfLx8WnR0RDebXFvmMaOHUtSqZQ4jmvxpGCqOh8yZAh98803etVJfcwZuJydnWnu3Lkkk8mIiEihUFBUVFSrw0yWL19OHMfR8ePHKSQkpM3e7T5w8Xg86tGjR6PyCgoKKC8vj9LS0nS6dmEobz6fT+vWrSOpVNpifo7jSCaTUVFREXXo0IFsbW3p+PHj6rMPEdHWrVvJ09PTbAdj/Xq+efNmk8Mjrl+/Tt7e3kb/8evj3lwSi8XEcRxJJBLq06eP0d1bK6/+DZC2YM7ANXjwYMrLy1PnlcvlrfZ4fHx8aPv27ZSTk0OvvPJKs2PCdHLWeS/1RN+DTyAQUFJSkkZZ06dPJx8fH/Ly8qKePXuSra2tyQ/GTZs20cOHD3Uq69KlS/TXX39RdXW1+rt3332XPDw8Wp1naSjv5pJIJKLo6OgmuyZEtWfWlJQU8vHxsZjAVTeR/eOPP252Dqip65zP55OzszNFR0e3OKe1OXJzc2n69OlmOVasra1p9OjRGg8IkMvlreZzc3OjGTNm0IwZM1rsruvkrPNe6ok+Bx6fz6eBAwfShQsX1OV88cUX1KVLl2b70qY4GCdNmkTZ2dlEVPsEi02bNtHw4cM1Unh4OM2fP5/u37/fbDDYu3cvhYSEaLUvxq5vKysr8vDwoP379zdbZmVlJf3f//0fde3a1Wj1rY97U2nXrl2kUqmourqaBg0a1OLIb3PUuUgkokGDBtGSJUu0Kvvhw4e0atUqGjx4cKtB2Bje3bt3p7lz51Jubq46X01NDf38888t5hs2bBiFh4eTj48POTk5Gcy7XQcuBwcHevfdd6mqqoq2bt1KCQkJ5Ovrq3fQMtQ/dffu3VRZWUlERF999RX5+fk1WofH45GrqyvFxMRQbm5uo+CVmZlJcXFxWj9uxRT1bW1tTYMGDaKEhASNdOnSJaqpqSEiovz8fIqMjDRafevrXpf4fD7t37+fKioqiOM4evXVV3Wa5mOqOre3t6dx48bRqVOnWiyzqqqKFi5cSGPGjCEPDw+zneQmTpxIly9fVuepqKigDRs20JAhQzTWc3NzIxcXF+Lz+WRlZUVJSUm0fft2euaZZwzq3a6HQygUCmRkZEAoFCIkJAQvv/xyk+s5OzujuroaUqnU6E7e3t7o1auXegT29evXkZ+f32g9+u+kWZVKBScnp0bLv/76a2RmZpp0pHxrKJVKXLhwocmJ6l9++SU6d+6MPn36YNKkSRCLxe1iCEp9eDweunbtioSEBADAa6+9hm+//RZVVVXmFauHp6cnoqKiEBQUhH79+iE4OLjZdU+cOIHvvvsOO3bswL1790xoqUlAQAASEhLQq1cvAIBUKkV6ejq2bNmicezb29vjvffeg7W1NaRSKX788UdcvnwZZWVluHr1qkGd2nXgksvlOHLkCPLz85t9DIZIJMKzzz6LHj16ICUlBceOHTOqk5eXF2xtbQEAxcXFuHv3bpPTjezt7bFw4UIMHToUXbt21ZjmUFhYiKysLDx8+NCorobiwIEDGD16NJ577jk4Ojpi6NChOHToULsLXDY2NliyZAmICDt27MBnn31m9uf5DxkyBKGhoXjssccA1D4hISQkBH369GnxSQrr169HRkYGUlNTTXJCbokhQ4Zg+PDh
AACJRIKMjAxs2rRJI2gtX74cHTp0wOTJkyEUCpGeng6VSoULFy6goqLC4E7tOnABtc8C8vDwQFZWFgYPHoy8vDz1/CmRSITIyEi88MILCAwMREVFhdEDl1gsVs85VCqVEAqFsLOzg0wmA5/PR/fu3REbGwtHR0csXbpUI+9vv/2GgoIC5Ofno6ioyKiehmbPnj3q/erSpQvs7e3NraSmrt4TExPxz3/+E/n5+VizZo3Zg5azszOmTZuGSZMmwdXVtcV1iQi5ubm4ceMGCgsLsXr1ajx8+LBdPPvMxcVF/Riee/fuYe/evUhNTUWnTp2QmJgIoVCIt99+GzweD7du3cL9+/exZcsWHD9+3ChBC7CQwNWnTx/07dsXbm5uOHLkiLoyOnbsqI7wZ8+eNUkwyMnJQXFxMfz8/ODh4YGYmBhIpVKUlpZCKBSiX79+WLx4sfrBa3VdxlOnTiE5ORnHjh1rV10XbZHJZOrJyl26dIGDg4OZjWpxcHBAYGAgRo4cibfeegtKpRKbN282+wMDgdrLCk888USrQUulUuHcuXP47LPPcObMmXb3Bh8bGxuIRCIAtQ83cHd3R2JiIrp164b3339f3QP59ddfceLECVy7dg1Hjx5VP3fOKOh8JU9P0IYLrjwej4YMGdLk3bk7d+7Q8uXLaezYseTl5WWSC5fLly+nwsLCVvNyHEeXLl2itWvXkoODQ5vqwBDebUnBwcHq56MRkdHmhuri7uDgQOPGjaM//viDiIiUSiX98ccfBtvntnqHhoZqvEOhIQqFgq5du0aHDx+mMWPGtBvvhuntt99u9KytOjiOI6lUSidPntT65oEhvC0icNWlK1eu0L1790gmkxHHcVRSUkIrV66knj17mvSf6u7uTtu3b2/0bG2O40ihUFBZWRndvXuXsrOzafz48e3yYNQ1hYWFqZ+RXl1dTf/4xz+M4q2Le1BQEP36669UU1NDYrGYTp8+TX379m03AaBr1650+PBhjXF71dXVVFxcTHl5eXT8+HF66aWXDP6/MnR9z5kzp8kXYHAcR5WVlXTo0KFWhzoY2rvddxXrk5iYiFGjRmHUqFHw9fXFzp078dlnn5l80u/t27exaNEiyGQyxMXFqb9XqVR48OABfv/9d+Tn52Pr1q1mv85iKNzd3dWTrS9dumTw9y/qQ3V1NYqLi3Hp0iVs27YNa9euNbeSBnfv3sW2bdtgY2MDb29vAMC1a9ewY8cOfPPNN+3qjnJL/PbbbxgyZAhefPFFEBGUSiUePnwIuVyOc+fOtfoKPqOgc3jWExj4rKJvYt76pfDwcLp48SKVlZXRunXrtB5Brw/mruv2UuftyXv69OlUVlZGJSUllJmZSb6+vmb1bldvsjYFuu4u824b+hxelurOvNuGLt7t9tHNDAaD0Rw6t7hSU1MRFRVlLB8Gg8FoFZ1bXGlpaRgxYgQWLlzYLsbKMBiMvx96X+M6ceIENmzYgIKCArzwwguYPn06HB0dDe3HYDAYjdC5xaVQKLBv3z6sX78eKpUKS5cuhaenJ+Lj443hx2AwGI3QeRzXk08+iQkTJiA5ORkuLi7q70tKSgwqxmAwGM2hc4srISEB8+bNUwetzz//HAAwY8YMw5oxGAxGM2gduJRKJaRSKQ4cOICqqirIZDJIJBJ89913xvRjMBiMRmjdVfz222+xbds2ZGdnIzY2FkQEoVCoMeWFwWAwTIHOdxX/+OMPDB061Fg+DAaD0SpaB64VK1Zg2bJlePrppxtNEdizZ49R5BgMBqMptA5cYrEYrq6uuHnzZqNl7fE14QwG49FF62tcZ8+ebXYZC1wMBsOUaB24srKyml02duxYg8gwGAyGNpjssTYMBoNhKLRucc2cORPJyckICgpSX5wnIvB4PJw+fdpoggwGg9EQ1uJiMBgWh85Tfi5evIjx48cjJCQECQkJuHjxojG8GAwGo1l0bnGFhIRg+/bt8PX1RX5+PqZNm4ZTp04Zy4/BYDAa
oXOLy9XVFb6+vgAAHx8fjSdEMBgMhinQ+uL8G2+8AR6PB7lcjvDwcAQEBOD8+fPo1KmTEfUYDAajMVp3FY8dO9bssoiICIMJMRgMRmto3eKqH5yuXr2KoqIivV49xWAw2heei1LUf994P9aMJtqj8xNQX331Vdy6dQvnz5+Hv78/iAjh4eHGcGMwGIwm0fni/NmzZ3HgwAF4eXnhhx9+gFAoNIYXg/G3xHNRikYLiNE0OgcugUAAALCzs0NaWhouXbpkcCkGg8FoCZ27ip988gnkcjnWrl2L5ORkrF+/3hhejHZC3dnfXNc+LPH6C8P46By4+vfvjz///BOXL19GUlISAgICjOHFYDAYzaJz4Hr99dchFosRGBiI/fv3o0uXLtiwYYMx3B4pzN1yeRRhrbG/LzoHrnPnzmmM6fo7j+FiPxwGwzxoHbhkMhkAYNCgQUhNTcXgwYNx/vx5BAcHG00OYMGBwWgvtKffotaBKzY2FjweD0SEnJwc9fcNX5zBYDDa14/8UUTrwJWenq7xWaVSgc/n671h9o9lMBj6ovM1riNHjuDNN9+EUCiEQqHA6tWrMXr06DZJsCBmeTT8n7GbDwxTonPgWrZsGdLS0uDg4ACJRIKYmJg2By7GowU7EZmWv2N96xy4OI6DSCQCAIhEIqhUKoNL6YOpWgBsOgaDYX70GscVFBSEnj174ubNm1i0aJExvCwSQwfLv+OZFNCvHrWtq0elTo11AjVG/RijEaFT4CIiVFVV4dy5c7h//z66dOkCKyudpzu2yqNwveRR+YHUx1Jbm6byttT6sUR0Clw8Hg+HDh3Ciy++CFdXV63yEBEqKioAAP2X/QoAyF0RDU4ua3J9iUSiXlb/bwDo8fpe9d+5K6KbLa9hGXXr1a1bh4ODQ7PDOQzp3XBZfRq6tbRPbfGuv536tLasuX2qT2v7rqu3NmXVX68pH23/Z/Vp63Giz3bq07C+W3LTZluG9tb2t9ica8PfYn20re86dH5ZRnx8PCQSCQIDA9WtrQ8++KDZ9SUSCRwdHXXZhMkoLy9Hx44dm1zGvA0P8zYtj6J3HVr3886ePYvo6GhcuXIF2dnZ8PPzQ2xsLGJjW+4GOTg4oLy8HOXl5SgsLAQAFBYWavzd0jJt19OljLrk4ODAvI2wT4b0NobPo+bd1LbN6a1vGdp416F1V3H27NnYuHEj/P39kZGRgU8++QQHDx5sNR+Px2sUPet/7tixY6PP+qynTxnM2zj7ZEhvY/o8Kt76uBvTW98ytK1zQIcWl52dHYKDgyEUCjF69GhUVlZqvREGg8EwJFq3uHJzc/HMM88AqL2wV//znj17jGPHYDAYTaB14MrKymrzxmxsbLBs2TLY2NgAgMbfLS3Tdj1ty2Dext8nQ3ob2udR9G64zNze+pahLTrfVWQwGAxzY/jRowwGg2FkWOBiMBgWh85zFfVFKpVi1qxZEAqFiIyMxLBhw7B69WqUl5dj6tSpSElJgUQiwZNPPokzZ87gwYMHiIqKwsyZMyGVShEREYGJEyfip59+Qr9+/fDMM88gNTVVPRj28ccfx7fffgulUonz58/Dw8MDTk5O6NOnD+Lj47F8+XI4OzsjKioKiYmJzPu/3tOnT4eHhwc2bNigdp82bVoj78mTJyM8PBxLly6FRCKBg4MDHj58yLwt3Luhu5+fH3Jzc43inZeXh127duHVV1+Fk5MTHB0dcefOHb28TXaN6+uvv0anTp0QFxeHSZMmYffu3QCAxMRE7Nu3DwBQWlqKBQsW4IsvvgDHcZg2bRq++eYbvP3227C3t4dKpcJvv/0GV1dXBAcH49SpU3B2dkZsbCyioqIAAAcOHMDRo0cxdOhQTJ06FZMmTUJwcDCCg4MRFhaG+Ph4rcaf/d28AajdH3/88Ubeb731FnJycnDgwAENd+Zt2d7NuRvDWywWw93dHaWlpZg6dSr8/f2xadMmvbxN1lW8ffs2PDw8AKDZJ6euWrUKs2fP
xsGDBxEbG4uxY8fiyJEj8PPzg4uLC/r164eff/4Za9aswdatWxEaGop169YhOTlZXcaOHTvUlTxy5EjExMTg+eefx65du/DGG2+gpKSEeTfwBqB29/DwaNJ72bJlyM/Pb+TOvC3bWxt3Q3k/99xzGDp0qNo9KSlJb2+TdRXd3d1x+/Zt+Pv7g+M4jWVEhEWLFmHMmDEYPHgwBg8ejPj4eMTGxsLf3x9SqRR5eXmwtbXFuHHj0LlzZwiFQnTu3BnA/yr71q1bcHR0xL59+7BixQqEh4cjMTERSUlJ2Lx5M1QqFZ566inm3cAbqJ2DGh8fj169eqGqqqqRt1wuh7u7O4RCodqdeVu+d0vuhvZ2cHDARx99pOG+b98+vbxN1lWUSqWYM2cORCIRRowYgZiYGCxZskTdMikqKkJQUBDs7OygVCohl8sxcOBAdaTftm0bCgoKUFJSgrKyMiQlJWHv3r2ws7ODr68vZs+ejWXLliE6OhodO3bE8uXL8dhjj8He3h5z5szBu+++C6lUipkzZ2LEiBHM+7/e/v7+8PX1xffff6/h3tB75syZCA4Oxty5c9Xu9+7dY94W7t3QfdCgQfjzzz+N4h0aGorc3Fy1u0qlAo/H08ubjeNiMBgWBxsOwWAwLA4WuBgMhsXBAheDwbA4WOBiMBgWBwtcDAbD4mCBi8FgWBwmCVwnT55EZGQkIiIiMHLkSJw5c6bROjdu3MDhw4dNoaM1zNv0WKo78zYxZGRKSkpo4MCBVFRUREREZWVldPbs2Ubrpaen0/z58w2+fZVKpVc+5q0f+noTWa4789aPthwrRm9xpaSkICEhAd26dQMAODo6QiAQICIiAsOGDcOcOXMAAMnJydi9ezciIyPx8OFDbNu2DWFhYQgNDUVaWhoA4PDhwwgICMDTTz+N8PBw3LhxAxKJBPHx8YiIiMDkyZOhUCiQkZGBuLg4TJgwAR9++KHGm4iioqIavS+OeZvf25LdmbfpjxWjt7jef/99+vTTTzW+k8lkxHEcERHFx8dTQUGBRlR/8OABRUdHE8dxVFlZSREREUREFBISQiUlJVRdXU2enp50/fp1+vDDDyk5OZmIiFauXEnbt2+n9PR0CgsLU28jISGBioqK6OrVqzR58mTm3Q69LdmdeZv+WDH6JGs3NzdcvnxZ47vr169j/vz5kMlkuHbtGoqKijSWX716FRcvXsQTTzwBALh//z4AQKVSwcnJCQDQv39/AMCVK1cwY8YMAEBQUBAyMzPRo0cPBAYGqt+GO3XqVOzcuRNSqRRTpkxh3u3Q25LdmbfpjxWjdxVjY2Pxww8/oLi4GEDtG3T/9a9/Yf78+Th27BgCAgJARBAIBFCpVACAxx9/HAMHDkR6ejoyMjLw559/AqidaV5aWgqFQoGLFy8CAHr16oXTp08DqH2hR+/evWt3zOp/uxYXF4eUlBQcOXIEMTExzLsdeluyO/M2/bFi9BaXk5MTkpOT8eyzz4KIwOfzER0djddeew2+vr7qx2gMGDAAixcvxtNPP41///vfmDx5MiIiIsDn8zFgwABs3LgRK1euRFRUFLy8vNC1a1cIBALMmDEDU6ZMwa5du+Dq6op//etfOHnypIaDUCiEr68vrKysYG2t3S4zb9N6W7I78zb9sWL0a1yGRKFQEBFRdXU1DRgwgJRKpdZ558yZQ1lZWcZSaxHmbXos1Z15a4fJHiRoCA4cOIDNmzdDIpFg3rx5zT5psiGzZs1CeXk5AgMDjWzYNMzb9FiqO/PWDvY8LgaDYXGwKT8MBsPiYIGLwWBYHCxwMRgMi4MFLgaDYXGwwMVgMCwOFrgYDIbFwQIXg8GwOFjgYjAYFgcLXAwGw+L4f4lAjWIVBj0yAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 300x120 with 12 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def plot_entropy_examples(order, filename, n_examples=6):\n",
    "    \"\"\"Plot the first `n_examples` samples of `order` with their category probabilities.\n",
    "\n",
    "    Top row: the images; bottom row: bar charts of the matching `feats` rows.\n",
    "    Reads the notebook-level `data`, `feats` and `plot_dir`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    order : array of int\n",
    "        Sample indices into `data`/`feats`; only the first `n_examples` are used.\n",
    "    filename : str\n",
    "        Name of the file (inside `plot_dir`) the figure is saved to.\n",
    "    n_examples : int\n",
    "        Number of image/bar-chart columns.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    matplotlib.figure.Figure\n",
    "    \"\"\"\n",
    "    shown = order[:n_examples]  # index only the shown samples, not the whole dataset\n",
    "    imgs = data[shown]\n",
    "    sort_feats = feats[shown]\n",
    "    n_categories = feats.shape[1]\n",
    "    fig, axs = plt.subplots(2, n_examples, figsize=(3, 1.2), squeeze=False)\n",
    "    for j in range(n_examples):\n",
    "        # assumes channel-first (C, H, W) images; imshow wants (H, W, C)\n",
    "        axs[0, j].imshow(imgs[j].transpose(1, 2, 0))\n",
    "        axs[0, j].axis(\"off\")\n",
    "\n",
    "        bar_ax = axs[1, j]\n",
    "        bar_ax.bar(range(n_categories), sort_feats[j])\n",
    "        bar_ax.set_ylim(0, 1)\n",
    "        bar_ax.set_xlim(-1, n_categories)\n",
    "        # remove y ticks and the left, top and right spines\n",
    "        bar_ax.set_yticks([])\n",
    "        for side in ('top', 'right', 'left'):\n",
    "            bar_ax.spines[side].set_visible(False)\n",
    "        bar_ax.set_xticks(range(n_categories))\n",
    "        bar_ax.set_xticklabels(range(n_categories))\n",
    "        # smaller font size for x labels\n",
    "        bar_ax.tick_params(axis='x', labelsize=5)\n",
    "        bar_ax.set_xlabel(\"Category\")\n",
    "    axs[1, 0].set_ylabel(\"Probability\")\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(plot_dir / filename)\n",
    "    return fig\n",
    "\n",
    "\n",
    "# np.argsort is ascending: lowest-entropy (most confident) samples come first\n",
    "entropy_order = np.argsort(entropies)\n",
    "\n",
    "# plot maximum entropy images: reverse the ascending order so the\n",
    "# highest-entropy samples are the ones shown\n",
    "plot_entropy_examples(entropy_order[::-1], \"mnist_max_entropy_examples.pdf\")\n",
    "\n",
    "# plot minimum entropy images\n",
    "plot_entropy_examples(entropy_order, \"mnist_min_entropy_examples.pdf\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfUAAAB6CAYAAACm2cqoAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGrBJREFUeJzt3X1wFPX9B/B3ggmpDEaZAmkHBGmZQidAYEIUBojKFIIhaXgsDw0IaRzkocXJUDAMDTpQoIigxUnHmhZaEcvPhgRhGImGRNFB0oAUY9FAhCAUB4HkaiJkEj6/P+xd955393bv9m7fr5mdyT3s9z7v+17uu7u3D3EiIiAiIqKoFx/pAoiIiMgYHNSJiIhiBAd1IiKiGMFBnYiIKEZwUCciIooRHNSJiIhiBAd1IiKiGMFBnYiIKEZwUCciIooRd0W6AKe4uLiIvn64T6xnp7x2ygrYK6+dsgL2ymunrEDs5OWaOhERUYyIuUG9vLwcbW1tmDJlSqRLISIiCqs4q1zQxahNH8o4Wtrkph7z2CkrYK+8dsoK2CuvnbICsZM35tbUiYjIOgoKClBVVWWLraebNm3C1atXUV5ejuTk5IjUwDV1H/OFQ6wsFaphp6yAvfKanTUjIwOPPPIItmzZ4vPxaOjb5ORktLa2GvL60da3mZmZqKmp0dVmNPStUnJyMlpaWnS3yTV1H1JTU11//+Y3v4lgJUTmSU5ORnx8TP3r+rRixQp8+OGH2Lx5M7Zt2xbpcnSZOnUqWlpa0NTUFOlSIkI5oAPAypUrI1JHOHgO6MC3C6VhJxYBIOQplPaiMa9R7xWzRk/e5ORk1+tkZmZGPK9ZOUeNGuX2Or179454Vj15X331VRERqa6uDvi8goICy32WQ+3Dw4cPh9RuuJnxPRPq/HrE5OL+V199FekSIqaoqAgiEpklRIOJiNvk+RvVjh073B7v3bt3hCoNn7Vr17r+Pn/+fAQrMVZZWRkyMzMBACNGjEB9fb3rsZkzZ+LatWuRKi0kffv2BQAsWrTI73MuX76MV155JVwlhUVKSgqysrIiXUbYnD592uu+999/PwKVIAKLQ34gxKWk0tJSV1v5+fmGLGVZOa+v6eWXX3a1n5SUZJm8WjPU1tbK8ePHvdrZvXu363kHDx5U/VrhFmo/jhs3ztXWvn37XPdPnDjRdX+gNbtoyZqQkODVXm1trTQ2NrpuX7582TKfYxH9WwB79erl8/H169eLiPtn2yp5Q+lbI9oNN71ZCwoKDGnPsByGtRSiUL8InTo7Ow3/EIY7r4jI0qVLNdWfmZnpanvw4MGWyqs2w4YNG4K2lZ+f73PAD/Ra4WbUZ1nkf5ttPQfAaOtbz6lfv34+2+vs7HS7HR8fb5msIvq/qPU+Ho19G+z/2IpZtdSlnDx/JnJasWJFSP/3IeUwrKUQhfIlqPwQrVy5UlcbVsk7d+5cERG5efOmpvpPnjwpIiLHjx+3XF419WzatCloO9euXQs4oK9bty7iWUWMXcNx/pbc3Nzsum/cuHFR1bee05gxY3y2tW3bNs3th5uWnFOmTAk436BBgzS3a9Wszmn27NmGtRtuWrMqt6gp1dXVuZ6zevVqWb9+fVjzxsSgrpScnCypqalSWVmpaYC3St69e/cGfY7ySyMhIUF69OjhmiclJcVyeYPVMmLEiIDzT5kyRXr06OG2k5ins2fPWiKriL7P8hNPPOHVzoABAwSA7N69W1P74aQ154wZM7zaqK2tlaSkJLl69arb/f369bNUVhFteZ395nA4fD6+c+dOERHZu3evJb+ntPatv60vSps2bbJkVhFteQcPHhy0nZs3b7rua2pqClveqB/UlUv3ziUiPe2Gm68alJtYKysrA9abl5cnIiLl5eWyb98+S+cNpRbnwAZAhg4d6vd5gQaAcNP6Gfa3lWLw4MGSmprqdl+wfSWs1rfKKSMjw2v+SZMmCQCvnKNGjbLc51hE
W97Tp0+LiMjp06cD1t6jR4+o/L9VTklJSaranDFjhiWziqjPGx8fH7SNhoYGze0blsOwlkKk9YsQ8N6Ml5CQICtXrnS7L9hvclb6AC1dujTg477qraqqCtpupPPq+TI4duyY2/M8v/SV/B3aZaW+9TUNHjzYbQdPTw0NDeJwOFy3Gxsbo6pvg9WlHOyUrLrmKmLOoB7Ke2iVrGqp3RRt5b4NNr/nGKS2fcNyGNZSiLR8gJxTW1uba/66ujqfe9NG0wfIuWkyOTk5YK3K4z+Vg+Lhw4ctmTdQHco9nf3Nc+vWLd3tW6VvPSflkQq+vPzyy15r8NHWt8pJuTVJxP1nIr1tWrVvnVOgQV25qdqqedXUo/zpTy01W5us2rf+ZGRkCKBuLd7svFE5qA8YMMDn/Mo1Xa3thpuWvMpJuQkzJSXFbWebQYMGWTJvOGqxStZgtQSqR3n4GuD+O3u0b6I1o00r9q1yCjSoR0NeM+uxWlY1NflbA1ceXhpIVVVVWPJG3aAebK3Nn2BnbAo3f3VMnDjRkDpPnz7ttiWjubnZsl8OocrLy4uKvvVXz+rVq72eo/xp6datW1H9xa+lpoaGBstmFTF+UNd66KqVsoYiNTXVUllF9OVV/pxQW1ur6nX8nbPAsByGtRQiNR/o6upq3e0HO3493HzV4Dzm0V+NymPRtejs7JQxY8ZY9stBuce/Fs3NzW470lm5b5WTc9NrY2OjDB061Otx52GNTsEOYbPyF79zSkhIkLKysoBtBds5NNJZRfQN6p7z6W3Pan0bihEjRlgqq0jgvL72dleueT/++OOqX+fxxx83NW/UXKWtsrISubm5XvdPnjwZ/fv3D3iaxXPnzmHZsmU4cuSI3+eE+23wlbe4uBgbN270+17Ex8ejq6sLALB161YAwKpVqwAA7e3t2Lx5s9c8Z86cQUVFhdf94cyr5kpFWuvZv38/pk+fruq5VuhbT7169cKNGzd8Pqas909/+hMKCgo0vb7V+taXSZMm4a233nLdvnTpEu6//37N7Vixb51Onz6N4cOHu82n/B5raWnBfffdp+n1rdS3odRixndCqALVNHXqVLz55puGvE5aWprP08oaltewxYMQQcdSm/PMac7Du5TOnj0rTzzxhCWXgEV859V63KpyJzm1OSORV009gY5B93T16lXLZhUJ7ZwLytPftrW16WojGrJ67v/iuSXJilm15lXuDHn58mXXkQ7OI1asvmXCrFp27NhhuawigfPq2SHQl0B7/xslKgZ1X1f7KSoqcj3ueWpJ57GvVv1nEfGd1/lPr/ZUt+Xl5SLiffiX1fJqqWvcuHEyd+5cyc/P97tZfu7cuZbNKqJ/oPP8eUnLjo/R0LfOyfMwRq0LadHSt8qdWJ2U5+/3txnWKnmD1eK8Ap2Tr2OzldScTMjKfRuKzs5Ov1caNDqv5Qd1X4epnTx50u+boWYHDKt+gJx7V6odtJy07mxjtS8Hf5O/M1SpPTGJlfo22OT5hahmXwEr5NVTn+fhfGoPcYp0Vj15z54965p3w4YNrpqDXajGCnnV1FNcXOy1EtWrVy/d7Vm9b7W4du2az31mzM5r+UHdc9O6r02SToFOFRotHyAtCyVO/vamtEpePf3ha2FOzakWI51VT17PIzr0foajoW89z5et5dwKkc6qJ6/zENTa2lq3mrWcLjVa+tY5+dpCoae9cDOqtmBHMJmd9y5Y3L333uv6u6urCz169PB6TmJiIqZNm4Z9+/aFsTJzfPzxx6qe57y2+J49e/zucBVrtm/fHukSDJeXl4fu3bu7buvd8SxaLFu2zO32lClTIlRJeJw4cSLm+9RTU1NTpEswXUtLi9vY5DRs2DDV3+GmMWzxIEQIsuSnd7O62slKebXUHKu/u/q6wEms9u26deukra0tpM3Q0dC3KSkpbvN6njvB6lm15vU1Odfcg503wwp5Q8npeQbBYCdeiXRWrXl9UXsaZ7PzRsWgHo6Jea2Vtaio
yK0Nrb+jR0vfOo9l3rlzZ8z37bFjx3TPa4WsRtTs3LlMz09mVu5bz8nz57NgO4lFOqvWvOvXr3ebt7q62jKfZQ7qUfABiva8eurLz893ze/vghhWzKolr/JiPGpOxmHFvHrrOn78eNRl1ZrXX71lZWVRkTfUrEOHDpWGhoao3nk50LR06VKpqqqS/Px8S32WOahHyQcomvPqqU95jXW1V9qzQla1eZXXFS8uLrZF365evVpEvt1pTHlBl2jJqjWvr+nVV1/VfGhXNPStGVO4xUreqDmjnNnC/TbYKa+erAkJCejo6NA9v5IV+7a4uBjTpk3DvHnz0NjYaOjrW71vjWTFvjUT+9Y8sZKXg/p/8QNkHr1Zy8vL8e6772LHjh0hvT771jx2ygrYK6+dsgKxk5eD+n/xA2QeO2UF7JXXTlkBe+W1U1YgdvJaZlAnIiKi0MRHugAiIiIyBgd1IiKiGKF5UH/nnXfMqIOIiIhCpHlQr66uxrhx4/DrX/8aZ86cMaMmIiIi0kH3jnLHjh3DCy+8gM8++wwLFy5EQUGB6yIjREREFH6a19Q7OjrwxhtvYPv27ejq6sK6deswcOBA5ObmmlEfERERqaT50qs/+clPMG3aNJSWlqJPnz6u+69fv25oYURERKSN5jX1vLw8rFy50jWgv/LKKwCAwsJCYysjIiIiTVQP6p2dnWhra0NFRQW++eYbtLe3w+Fw4O9//7uZ9REREZFKqje/79mzB7t27cI///lPZGdnQ0SQmJiInJwcM+sjIiIilTTv/X78+HE89NBDZtVDREREOqke1J955hmUlJRg1qxZXie+37dvnynFERERkXqqB/Uvv/wSffv2xcWLF70eGzBggOGFERERkTaqf1Ovr6/3+xgHdSIidwPXHHK7fWFzdoQqITtRPajX1dX5feyxxx4zpBgiIopOXIixBtWDeklJiZl1EBERUYhUD+pPPvkkSktLMXr0aNeOciKCuLg4nDhxwrQCiYiISB3Vg3ppaSmAwJvhiYiIKHI0nya2oaEBP/3pT/Hggw8iLy8PDQ0NZtRFRERkWQPXHHJNVqL5gi6LFy/G7t27MWTIEHz66adYsGABPvzwQzNqoxBxxxVvyvckmt8P9i0R+aJ5UO/bty+GDBkCAPjRj37kdqU2IqJI4YIOkYZBfdWqVYiLi8Pt27cxYcIEjBw5EqdOncK9995rYnlERESklupBferUqQCA7Oz/Lf1Onz7d+IqIiIhIF9WDemZmpuvv8+fP48qVK9B4LRgiIiIykebf1H/5y1+iubkZp06dQlpaGkQEEyZMMKO2gGJlhyciIiKjaD6krb6+HhUVFXjggQdQWVmJxMREM+oiIiIijTQP6gkJCQCAu+++G9XV1Th79qzhRREREZF2mje/79y5E7dv38a2bdtQWlqK7du3m1EXEcUgHnZGZC7Ng3pqaio++ugjNDY2YtGiRRg5cqQZdZGN2Wl/CQ5yRGQkzYP6U089hS+//BLp6enYv38/evfujRdeeMGM2oiIiKJWJFZQNA/qJ0+eRG1treu28lA3q+FaEBER2YnqHeXa29vR3t6OESNG4J133sHNmzdRXV2NjIwMM+sjIiIilVSvqWdnZyMuLg4igjNnzrjud15b3Uh2+k2ViIjIKKoH9aNHj7rd7urqQrdu3QwviLzxZwQiIlJD82/qVVVVKC4uRmJiIjo6OrBx40ZMmjTJjNqIiCjMuBIR3TQP6iUlJaiurkbPnj3hcDiQlZXFQZ2IiMgCNA/qd+7cQVJSEgAgKSkJXV1dhhdld9yngIiI9NB1nPro0aMxYMAAXLx4EWvWrDGjLtKIm8yIjOHrf8nMBW0uxJORNA3qIoJvvvkGJ0+exLVr19C7d2/Exwc+Kk5E8J///AcAkFrylttjHz8z2ec8d263u/52OBw+5/N8TrB2AOD+p/5P1esH0rNnT797/CuzAt55PanN75nDV15fz1EzX7A+0ZJXa9v+qOnbYK+n
9r1VCqVv1WbzrFFNH/midz6lUPpWL711a+1bz7a1ZA32v+Sv5mD/72q+p/x93wV7rVD+b/V8JoN9J2uZL1SR+BwD6r6n9H6XBRIoLwBANJo+fbqm57e2tgqAmJlaW1ttk9Vuee2U1W557ZTVbnntlDVYXhGROBERaJCbmwuHw4H09HTXWvrvfvc7v88XjyUlh8OB/v3749KlS7jnnnt83qfmOXrn09u2k9alwnDWaGTbevJGS7ZI9q0VPhNm9a0VspnVt1bMxr5l3/qievN7fX09iouLcenSJVy9ehULFizAD37wg6DzxcXFeRUFAPfcc4/X/Z73qXmO3vn0th2Iv6zhrtHItgNh3xr7+uF+3wIJpW/11m1m24Fo6Vu9rx/u9y0Q9q2xr2+lvgU0DOrLli3Diy++iLS0NNTU1GDnzp04cOCAphcjIiIi86g+9/vdd9+NjIwMJCYmYtKkSfj666/NrIuIiIg0Ur2m/vHHH2P27NkAvv2dQnl73759ql+we/fuKCkpQffu3f3ep+Y5eufT27Ze4azRyLaNyGrVbJHs23BnC2ffWiGb1T7L4X7fzMpqxfxm5o10tlCyqt5R7uLFi34fGzBggOYXJiIiImNp3vudiIiIrEn1b+pERERkbWEd1Nva2rBw4UIUFhZiz549AICmpiYUFBRg5syZrudVVFSgsLAQP/vZz3DkyBH861//wpIlSzBz5kyUlpa6tZeeno6DBw+ipqYG48ePx5IlS1BTUwPg2/PUr127FitWrMDu3bsBAO+99x6WLFmCX/ziFxg7diyam5uRl5eHxYsXY/PmzQCATz75BLNnz8aTTz6JN954w9S8nlkB+MyrzApAVV7PrAC88hqV1Vdeo/rWV172bWT7Vk1e9q15fRutedX0ra+84czqK29UjUEBT01jsL/85S9y4MABERGZPXu222MzZszwev6NGzdk8eLFrttdXV0yf/581+1169bJli1b5M0335SamhrJysqShQsXSmNjo4iIlJeXy4IFC+Spp56St99+263t/fv3yx/+8Ac5ePCg/PWvf3Wr6bnnnpN3331XRERycnLCktczq2deZVYR0ZTXmVVEvPIalTVQ3lD71lde9q17VpHI9K2avOxb4/s2WvOq6VtfeSOR1VfeaBiDwrqm/sUXX6B///4AgG7dugV9/oYNG7Bs2TIAwIEDB5CdnY3HHnsMwLfXdf/xj3+MPn36AADGjx+Pw4cPY8uWLSgpKQEAfPrppxg7diyef/55t6UrAHjttdcwb948PPTQQygrK8Ojjz6KrKwsAEB+fj5ef/11rFq1CtevXw9LXmVWz7yeWbXmdWYF4JXXqKyh5A3Wt77ysm/dswLsWzv1bbTmVdO3vvJGIqtn3mgZgzRfpS0U/fr1wxdffIG0tDTcuXPH7/NEBGvWrMGUKVMwatQoAN+enjY3NxfZ2dmYN28eampq0NbWhk8++QTf+c53XG/0fffdh9u3b7teLzExEYB7BzY3NyM5ORk9e/bEc889h2eeeQYTJkzAzJkzsWjRIvTp0wcvvfQSurq6MH36dFPz+srqmTctLc0rq/MUvcHyKrMCwJ///GevvEZkDSWvmr71zMu+Zd/avW+jMa+avvWXN1xZ/eWNljEorIP69OnTsXz5chw6dAg5OTkAgOvXr2Pt2rU4deoUNm3ahKeffhq///3v8fbbb6O1tRXnzp3DkCFDUF5ejtu3b7veuI0bNwIAdu3ahe9+97uoqKjAW2+9hZaWFixfvtz1eitWrMB7772HCRMmuOooKyvDokWLAABZWVlYv349XnvtNQwcOBAAcOHCBfz2t79FW1sbVq1aZWreHj16uGV1/h6jzOtcUnRmjY+PR3l5uaq8yqy+8hqV1Vdeo/rWV96MjAz2bQT7Vk1e9q15fRutedX0ra+84cwa7WMQD2kjIiKKETykjYiIKEZwUCciIooRHNSJiIhiBAd1IiKiGMFBnYiIKEZwUCciIooRER/U
P/jgAzz88MPIzMzEo48+in/84x9ez7lw4YLrfMPRzk557ZQVsFdeO2UF7JXXTlmBGMyr6aSyBrt+/boMHz5crly5IiIiLS0tUl9f7/W8o0ePSlFRkeGv39XVZXibgdgpr52yitgrr52yitgrr52yisRm3oiuqR86dAh5eXn43ve+BwBITk5GQkICMjMzMWbMGNdZeUpLS/G3v/0NDz/8MG7cuIFdu3Zh/PjxGDt2LKqrqwEAR44cwciRIzFr1ixMmDABFy5cgMPhQG5uLjIzMzFnzhx0dHSgpqYGOTk5mDZtGrZu3Yrs7GxXPRMnToTD4WBeZmVeZrVlXjtljdm8hi8maLB582bXVXmc2tvb5c6dOyIikpubK5999pnbUtJXX30lkydPljt37sjXX38tmZmZIiLy4IMPyvXr1+XWrVsycOBA+fzzz2Xr1q1SWloqIiLPPvus7N69W44ePSrjx493vUZeXp5cuXJFzp8/L3PmzGFeZmVeZrVtXjtljdW8YT33u6fvf//7aGxsdLvv888/R1FREdrb29HU1IQrV664PX7+/Hk0NDTgkUceAQBcu3YNANDV1YVevXoBAFJTUwEA586dQ2FhIQBg9OjReP/993H//fcjPT0dcXFxAICf//zn2Lt3L9ra2jB//nzzwsJeee2UFbBXXjtlBeyV105ZgdjMG9HN79nZ2aisrMS///1vAIDD4cDq1atRVFSE2tpajBw5EiKChIQEdHV1AQAGDRqE4cOH4+jRo6ipqcFHH30E4Nsr4Ny8eRMdHR1oaGgAAPzwhz/EiRMnAAB1dXUYPHgwALiuHAQAOTk5OHToEKqqqlyXvWNeZmVeZrVjXjtljdW8EV1T79WrF0pLSzF37lyICLp164bJkyfjV7/6FYYMGeK6NN6wYcPw9NNPY9asWfjjH/+IOXPmIDMzE926dcOwYcPw4osv4tlnn8XEiRPxwAMPICUlBQkJCSgsLMT8+fPx+uuvo2/fvli9ejU++OADtxoSExMxZMgQxMfH4667zH077JTXTlntltdOWe2W105ZYzZvyBvwLaKjo0NERG7duiXDhg2Tzs5O1fMuX75c6urqzCrNFHbKa6esIvbKa6esIvbKa6esItbJG9E1dSNVVFTgpZdegsPhwMqVK90uSB/I0qVL0draivT0dJMrNJad8topK2CvvHbKCtgrr52yAtbJy+upExERxYiIn1GOiIiIjMFBnYiIKEZwUCciIooRHNSJiIhiBAd1IiKiGMFBnYiIKEZwUCciIooRHNSJiIhiBAd1IiKiGPH/H7eLvb01kmkAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 500x120 with 16 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot the n minimum-entropy images of the digit 4 together with their per-category probabilities\n",
    "n = 8\n",
    "# sort ascending by entropy (lowest entropy first) and keep only samples labelled 4\n",
    "inds = np.argsort(entropies)\n",
    "inds = inds[labels[inds] == 4]\n",
    "assert len(inds) >= n, f\"only {len(inds)} images labelled 4, need {n}\"\n",
    "imgs = data[inds]\n",
    "sort_feats = feats[inds]\n",
    "n_cats = feats.shape[1]\n",
    "\n",
    "fig, axs = plt.subplots(2, n, figsize=(5,1.2))\n",
    "for j in range(n):\n",
    "    # top row: the image, transposed (1, 2, 0) so the channel axis is last for imshow\n",
    "    axs[0,j].imshow(imgs[j].transpose(1, 2, 0))\n",
    "    axs[0,j].axis(\"off\")\n",
    "    # bottom row: bar chart of the value assigned to each category\n",
    "    ax = axs[1,j]\n",
    "    ax.bar(range(n_cats), sort_feats[j])\n",
    "    ax.set_ylim(0, 1)\n",
    "    ax.set_xlim(-1, n_cats)\n",
    "    # hide y ticks and the left/top/right spines to keep the small panels uncluttered\n",
    "    ax.set_yticks([])\n",
    "    for side in (\"top\", \"right\", \"left\"):\n",
    "        ax.spines[side].set_visible(False)\n",
    "    ax.set_xticks(range(n_cats))\n",
    "    ax.set_xticklabels(range(n_cats))\n",
    "    # smaller font size for x labels\n",
    "    ax.tick_params(axis='x', labelsize=5)\n",
    "    ax.set_xlabel(\"Category\")\n",
    "axs[1,0].set_ylabel(\"Probability\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(plot_dir / \"min_entropy_4s.pdf\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "solo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
