{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Evaluate performance of clean LinearProbe on noisy data \n",
    "\n",
    "Prereqs:\n",
    "    Models: \n",
    "        - clean CLIP model (A)\n",
    "        - finetuned CLIP model on noisy data (B)\n",
    "    Linear Heads: \n",
    "        - linear fit on cleanCLIP embeddings (1)\n",
    "        - linear fit on noisyCLIP embeddings (2)\n",
    "And then we evaluate:  {A1(clean), B1(noisy), A1(noisy)}\n",
    "     A   B \n",
    "    -------\n",
    "  1|   |   |\n",
    "   |---|----\n",
    "  2|   |   |\n",
    "    -------\n",
    "\"\"\"\n",
    "import torch \n",
    "import torch.nn as nn \n",
    "import numpy as np \n",
    "import os \n",
    "from torch.utils.data import DataLoader\n",
    "from utils import * \n",
    "# imagenet \n",
    "from pytorch_lightning import Trainer, LightningDataModule,LightningModule, seed_everything\n",
    "from pytorch_lightning.metrics import Accuracy\n",
    "import glob\n",
    "from noisy_clip_dataparallel import NoisyCLIP\n",
    "from baselines import Baseline\n",
    "from tabulate import tabulate\n",
    "\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Jupyter evals should be quick \n",
    "class HParam:\n",
    "    def __init__(self, **kwargs):\n",
    "        for k, v in kwargs.items():\n",
    "            setattr(self, k, v) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = os.path.expanduser('~/datasets/ImageNet100/')\n",
    "\n",
    "def get_xform(distort_type, param=None):\n",
    "    if distort_type == None:\n",
    "        return ImageNetBaseTransformVal(HParam(encoder='clip'))\n",
    "    \n",
    "    elif distort_type == 'randommask':\n",
    "        hparam = HParam(encoder='clip', distortion=distort_type, \n",
    "                        percent_missing=param, fixed_mask=False)\n",
    "    elif distort_type == \"gaussiannoise\":\n",
    "        hparam = HParam(encoder='clip', distortion=distort_type, std=param, fixed_mask=False)\n",
    "    elif distort_type == \"gaussianblur\":\n",
    "        hparam = HParam(encoder='clip', distortion=distort_type, kernel_size=param[0], sigma=param[1])\n",
    "    return ImageNetDistortVal(hparam)\n",
    "\n",
    "\n",
    "def get_valset(distort_type, param=None, batch_size=128, num_workers=4):\n",
    "    \"\"\"Return a DataLoader over the ImageNet100 'val' split under DATA_DIR.\n",
    "\n",
    "    Args:\n",
    "        distort_type: distortion name forwarded to get_xform (None = undistorted).\n",
    "        param: distortion parameter forwarded to get_xform.\n",
    "        batch_size: loader batch size (default 128).\n",
    "        num_workers: number of loader worker processes (default 4).\n",
    "\n",
    "    Returns:\n",
    "        A DataLoader created with shuffle=False and pin_memory=True.\n",
    "    \"\"\"\n",
    "    dataset = ImageNet100(root=DATA_DIR, split='val',\n",
    "                          transform=get_xform(distort_type, param))\n",
    "    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,\n",
    "                      pin_memory=True, shuffle=False)\n",
    "\n",
    "\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_batch = next(iter(get_valset(None)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matt/venv/lib/python3.6/site-packages/torchvision/transforms/transforms.py:258: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.\n",
      "  \"Argument interpolation should be of type InterpolationMode instead of int. \"\n"
     ]
    }
   ],
   "source": [
    "# glob order is arbitrary, so sort for a deterministic pick and fail loudly when nothing matches.\n",
    "BASELINE_CKPT_GLOB = '/tmp/Logs/Contrastive-Inversion/RN101_CLEAN_CLIP_LIN/checkpoints/*'\n",
    "baseline_ckpts = sorted(glob.glob(BASELINE_CKPT_GLOB))\n",
    "if not baseline_ckpts:\n",
    "    raise FileNotFoundError('No baseline checkpoint matches ' + BASELINE_CKPT_GLOB)\n",
    "baseline_ckpt = baseline_ckpts[0]\n",
    "\n",
    "# Keep the clean feature extractor and its linear classifier head as separate handles.\n",
    "baseline = Baseline.load_from_checkpoint(baseline_ckpt)\n",
    "clean_backbone = baseline.encoder.feature_extractor\n",
    "clean_classifier = baseline.encoder.classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(in_features=512, out_features=100, bias=True)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clean_classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_classifier = clean_classifier.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['models_table1/blur21.ckpt',\n",
       " 'models_table1/blur37.ckpt',\n",
       " 'models_table1/noise01.ckpt',\n",
       " 'models_table1/noise03.ckpt',\n",
       " 'models_table1/noise05.ckpt',\n",
       " 'models_table1/rand50.ckpt',\n",
       " 'models_table1/rand75.ckpt',\n",
       " 'models_table1/rand90.ckpt']"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_ckpt = sorted(glob.glob('models_table1/rand*'))[0]\n",
    "sorted(glob.glob('models_table1/*'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_model = NoisyCLIP.load_from_checkpoint(test_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestComp(LightningModule):\n",
    "    \"\"\"Chains a backbone with a classifier and tracks top-1 / top-5 test accuracy.\"\"\"\n",
    "\n",
    "    def __init__(self, backbone, classifier):\n",
    "        super().__init__()\n",
    "        self.backbone = backbone\n",
    "        self.classifier = classifier\n",
    "\n",
    "        # One metric object per k; both are updated on every test batch.\n",
    "        self.test_top_1 = Accuracy(top_k=1)\n",
    "        self.test_top_5 = Accuracy(top_k=5)\n",
    "        # Overwritten in test_epoch_end with the final results as plain floats.\n",
    "        self.output_dict = {}\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return classifier(backbone(x)).\"\"\"\n",
    "        features = self.backbone(x)\n",
    "        return self.classifier(features)\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        \"\"\"Feed the softmax of this batch's logits into both accuracy metrics.\"\"\"\n",
    "        inputs, targets = batch\n",
    "        probs = self.forward(inputs).softmax(dim=-1)\n",
    "        for metric in (self.test_top_1, self.test_top_5):\n",
    "            metric(probs, targets)\n",
    "\n",
    "    def test_epoch_end(self, outputs):\n",
    "        \"\"\"Compute, log and reset both metrics; store the results in `output_dict`.\"\"\"\n",
    "        tracked = (('top1', self.test_top_1), ('top5', self.test_top_5))\n",
    "\n",
    "        # Compute both values before anything is logged or reset.\n",
    "        computed = [(key, metric.compute()) for key, metric in tracked]\n",
    "        for key, value in computed:\n",
    "            self.log(key, value)\n",
    "        for _, metric in tracked:\n",
    "            metric.reset()\n",
    "        self.output_dict = {key: value.item() for key, value in computed}\n",
    "        \n",
    "def eval_combo(backbone, classifier, data):\n",
    "    \"\"\"Run one Lightning test pass of backbone + classifier over `data` on GPU 0.\n",
    "\n",
    "    Returns the `output_dict` of the TestComp wrapper used for the run.\n",
    "    \"\"\"\n",
    "    runner = Trainer(gpus=[0])\n",
    "    module = TestComp(backbone, classifier)\n",
    "    runner.test(model=module, test_dataloaders=data)\n",
    "    return module.output_dict\n",
    "    \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_DISTORT_PAIRS = \\\n",
    "[('models_table1/blur21.ckpt', ('gaussianblur', (21, 5))),\n",
    " ('models_table1/blur37.ckpt', ('gaussianblur', (37, 9))), \n",
    " ('models_table1/noise01.ckpt',('gaussiannoise', 0.1)),\n",
    " ('models_table1/noise03.ckpt',('gaussiannoise', 0.3)),\n",
    " ('models_table1/noise05.ckpt',('gaussiannoise', 0.5)),\n",
    " ('models_table1/rand50.ckpt', ('randommask', 0.50)),\n",
    " ('models_table1/rand75.ckpt', ('randommask', 0.75)),\n",
    " ('models_table1/rand90.ckpt', ('randommask', 0.90))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d6de528dee84112aa175eaf629e443d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.7832000255584717, 'top5': 0.9462000131607056}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90de5374784f402bb2df427e48df7441",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.010200000368058681, 'top5': 0.04879999905824661}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "464ccb83977c4e739e6935759ea878ea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.7307999730110168, 'top5': 0.9064000248908997}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f14ce484fca4d3d827ad406d0978752",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.010400000028312206, 'top5': 0.054999999701976776}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1046f975780e4e82b3fdff3cfd34996b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.7724000215530396, 'top5': 0.9422000050544739}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b3817018ab2468e9b35493752db30cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.41440001130104065, 'top5': 0.6615999937057495}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2b3e30ffd15841478c8229f7778aa4ae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.7570000290870667, 'top5': 0.9318000078201294}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a015d9e293db4f718e961da1f2f292e5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.39100000262260437, 'top5': 0.6398000121116638}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9c18ae9f9f434e8cbd2f29c1f000a1f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.7052000164985657, 'top5': 0.8985999822616577}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ecbf696ddc984ab1abbdee85e067a3a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.33000001311302185, 'top5': 0.5640000104904175}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "429c93ec6fb3430286ddd8ac00e06474",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.8014000058174133, 'top5': 0.9495999813079834}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "03514e551a7e495da4229a80ba6e9141",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.6033999919891357, 'top5': 0.8348000049591064}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "57c3e61e05fe4d5a9cd48bcce99f9043",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.76419997215271, 'top5': 0.9314000010490417}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2515ae78c383456893ec38de5077d32e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.36880001425743103, 'top5': 0.5983999967575073}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "14d2a12c39a349e6b7cc2004c8a589e7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: None, using: 0 TPU cores\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.7811999917030334, 'top5': 0.9430000185966492}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e0c486b5a104126a07c2d1be7eabe7d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "DATALOADER:0 TEST RESULTS\n",
      "{'top1': 0.03720000013709068, 'top5': 0.10939999669790268}\n",
      "--------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# For every fine-tuned checkpoint, score its noisy visual encoder with the clean\n",
    "# linear head on the matching distorted val set and on the undistorted val set.\n",
    "OUTPUTS = []\n",
    "for ckpt_path, (distort_type, distort_arg) in MODEL_DISTORT_PAIRS:\n",
    "    noisy_encoder = NoisyCLIP.load_from_checkpoint(ckpt_path).noisy_visual_encoder\n",
    "\n",
    "    distorted_scores = eval_combo(noisy_encoder, clean_classifier,\n",
    "                                  get_valset(distort_type, distort_arg))\n",
    "    clean_scores = eval_combo(noisy_encoder, clean_classifier, get_valset(None))\n",
    "    OUTPUTS.append((ckpt_path, distorted_scores, clean_scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('models_table1/blur21.ckpt',\n",
       "  {'top1': 0.7832000255584717, 'top5': 0.9462000131607056},\n",
       "  {'top1': 0.010200000368058681, 'top5': 0.04879999905824661}),\n",
       " ('models_table1/blur37.ckpt',\n",
       "  {'top1': 0.7307999730110168, 'top5': 0.9064000248908997},\n",
       "  {'top1': 0.010400000028312206, 'top5': 0.054999999701976776}),\n",
       " ('models_table1/noise01.ckpt',\n",
       "  {'top1': 0.7724000215530396, 'top5': 0.9422000050544739},\n",
       "  {'top1': 0.41440001130104065, 'top5': 0.6615999937057495}),\n",
       " ('models_table1/noise03.ckpt',\n",
       "  {'top1': 0.7570000290870667, 'top5': 0.9318000078201294},\n",
       "  {'top1': 0.39100000262260437, 'top5': 0.6398000121116638}),\n",
       " ('models_table1/noise05.ckpt',\n",
       "  {'top1': 0.7052000164985657, 'top5': 0.8985999822616577},\n",
       "  {'top1': 0.33000001311302185, 'top5': 0.5640000104904175}),\n",
       " ('models_table1/rand50.ckpt',\n",
       "  {'top1': 0.8014000058174133, 'top5': 0.9495999813079834},\n",
       "  {'top1': 0.6033999919891357, 'top5': 0.8348000049591064}),\n",
       " ('models_table1/rand75.ckpt',\n",
       "  {'top1': 0.76419997215271, 'top5': 0.9314000010490417},\n",
       "  {'top1': 0.36880001425743103, 'top5': 0.5983999967575073}),\n",
       " ('models_table1/rand90.ckpt',\n",
       "  {'top1': 0.7811999917030334, 'top5': 0.9430000185966492},\n",
       "  {'top1': 0.03720000013709068, 'top5': 0.10939999669790268})]"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "OUTPUTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model      distort_acc    clean_acc\n",
      "-------  -------------  -----------\n",
      "blur21           0.783        0.010\n",
      "blur37           0.731        0.010\n",
      "noise01          0.772        0.414\n",
      "noise03          0.757        0.391\n",
      "noise05          0.705        0.330\n",
      "rand50           0.801        0.603\n",
      "rand75           0.764        0.369\n",
      "rand90           0.781        0.037\n"
     ]
    }
   ],
   "source": [
    "# One row per model: short checkpoint name, top-1 on distorted data, top-1 on clean data.\n",
    "top1s = [\n",
    "    [ckpt_path.split('/')[-1].split('.')[0], distorted['top1'], clean['top1']]\n",
    "    for ckpt_path, distorted, clean in OUTPUTS\n",
    "]\n",
    "print(tabulate(top1s, headers=['model', 'distort_acc', 'clean_acc'], floatfmt=\".3f\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model      distort_acc    clean_acc\n",
      "-------  -------------  -----------\n",
      "blur21           0.946        0.049\n",
      "blur37           0.906        0.055\n",
      "noise01          0.942        0.662\n",
      "noise03          0.932        0.640\n",
      "noise05          0.899        0.564\n",
      "rand50           0.950        0.835\n",
      "rand75           0.931        0.598\n",
      "rand90           0.943        0.109\n"
     ]
    }
   ],
   "source": [
    "# One row per model: short checkpoint name, top-5 on distorted data, top-5 on clean data.\n",
    "top5s = [\n",
    "    [ckpt_path.split('/')[-1].split('.')[0], distorted['top5'], clean['top5']]\n",
    "    for ckpt_path, distorted, clean in OUTPUTS\n",
    "]\n",
    "print(tabulate(top5s, headers=['model', 'distort_acc', 'clean_acc'], floatfmt=\".3f\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
