{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "mounted-calculator",
   "metadata": {},
   "source": [
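    "# CIFAR Results\n",
    "- Datasets: CIFAR-10 / CIFAR-100\n",
    "- Architectures: VGG-16 / ResNet-56\n",
    "- Each configuration below is trained once; the results reported in the paper are averaged over multiple runs, so the numbers here may differ from the paper."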
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "overall-deployment",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from rni.cifar import set_dataset_path, train, prune_finetune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "representative-clause",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset locations; override with the CIFAR10_PATH / CIFAR100_PATH environment variables.\n",
    "set_dataset_path(\"cifar10\", os.environ.get(\"CIFAR10_PATH\", \"/data/datasets/CIFAR10\"))\n",
    "set_dataset_path(\"cifar100\", os.environ.get(\"CIFAR100_PATH\", \"/data/datasets/CIFAR100\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bridal-administrator",
   "metadata": {},
   "source": [
    "## VGG-16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "sought-profit",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>train_acc</th>\n",
       "      <th>prune_acc_50</th>\n",
       "      <th>prune_acc_90</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>method</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">cifar10</th>\n",
       "      <th>l1</th>\n",
       "      <td>0.9387</td>\n",
       "      <td>0.9367</td>\n",
       "      <td>0.8995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>rni</th>\n",
       "      <td>0.9363</td>\n",
       "      <td>0.9367</td>\n",
       "      <td>0.9088</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ucs</th>\n",
       "      <td>0.9358</td>\n",
       "      <td>0.9293</td>\n",
       "      <td>0.8405</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">cifar100</th>\n",
       "      <th>l1</th>\n",
       "      <td>0.7450</td>\n",
       "      <td>0.7134</td>\n",
       "      <td>0.2581</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>rni</th>\n",
       "      <td>0.7201</td>\n",
       "      <td>0.7221</td>\n",
       "      <td>0.5935</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ucs</th>\n",
       "      <td>0.7371</td>\n",
       "      <td>0.7067</td>\n",
       "      <td>0.5460</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 train_acc  prune_acc_50  prune_acc_90\n",
       "dataset  method                                       \n",
       "cifar10  l1         0.9387        0.9367        0.8995\n",
       "         rni        0.9363        0.9367        0.9088\n",
       "         ucs        0.9358        0.9293        0.8405\n",
       "cifar100 l1         0.7450        0.7134        0.2581\n",
       "         rni        0.7201        0.7221        0.5935\n",
       "         ucs        0.7371        0.7067        0.5460"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-method hyperparameters for VGG-16; every key is forwarded to `train`.\n",
    "hyperparams = [\n",
    "    {\"method\": \"rni\", \"reg_weight\": 1e-3, \"b\": 3},\n",
    "    {\"method\": \"l1\", \"reg_weight\": 1e-4},\n",
    "    {\"method\": \"ucs\"}\n",
    "]\n",
    "dataset_names = [\"cifar10\", \"cifar100\"]\n",
    "prune_pcs = [0.5, 0.9]  # prune_pc values; column suffix is round(pc * 100)\n",
    "\n",
    "res_vgg = []\n",
    "for hp in hyperparams:\n",
    "    for ds in dataset_names:\n",
    "        trained_model, trained_logs = train(dataset=ds, arch=\"vgg16\", **hp)\n",
    "        # Accuracies are read from the `test_acc` log at the run's `best_epoch`.\n",
    "        row = {**hp, \"dataset\": ds, \"train_acc\": trained_logs[\"test_acc\"][trained_logs[\"best_epoch\"]]}\n",
    "        for pc in prune_pcs:\n",
    "            # NOTE(review): every prune level starts from the same trained_model; if\n",
    "            # prune_finetune modifies its input in place, later levels start from an\n",
    "            # already-pruned model -- confirm, or pass a copy.\n",
    "            _, pruned_logs = prune_finetune(trained_model, dataset=ds, method=hp[\"method\"], prune_pc=pc)\n",
    "            row[f\"prune_acc_{round(pc * 100)}\"] = pruned_logs[\"test_acc\"][pruned_logs[\"best_epoch\"]]\n",
    "        res_vgg.append(row)\n",
    "\n",
    "# Show accuracies only: drop every hyperparameter column except `method` (an index level).\n",
    "hp_cols = sorted({key for params in hyperparams for key in params} - {\"method\"})\n",
    "res_vgg_df = pd.DataFrame(res_vgg).drop(columns=hp_cols).set_index([\"dataset\", \"method\"]).sort_index()\n",
    "res_vgg_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "victorian-spelling",
   "metadata": {},
   "source": [
    "## ResNet-56"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "selective-cinema",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>train_acc</th>\n",
       "      <th>prune_acc_50</th>\n",
       "      <th>prune_acc_90</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>method</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">cifar10</th>\n",
       "      <th>l1</th>\n",
       "      <td>0.9403</td>\n",
       "      <td>0.9337</td>\n",
       "      <td>0.8991</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>rni</th>\n",
       "      <td>0.9322</td>\n",
       "      <td>0.9326</td>\n",
       "      <td>0.9001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ucs</th>\n",
       "      <td>0.9354</td>\n",
       "      <td>0.9351</td>\n",
       "      <td>0.8881</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">cifar100</th>\n",
       "      <th>l1</th>\n",
       "      <td>0.7167</td>\n",
       "      <td>0.7029</td>\n",
       "      <td>0.6192</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>rni</th>\n",
       "      <td>0.7136</td>\n",
       "      <td>0.7038</td>\n",
       "      <td>0.6401</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ucs</th>\n",
       "      <td>0.7157</td>\n",
       "      <td>0.7114</td>\n",
       "      <td>0.6445</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 train_acc  prune_acc_50  prune_acc_90\n",
       "dataset  method                                       \n",
       "cifar10  l1         0.9403        0.9337        0.8991\n",
       "         rni        0.9322        0.9326        0.9001\n",
       "         ucs        0.9354        0.9351        0.8881\n",
       "cifar100 l1         0.7167        0.7029        0.6192\n",
       "         rni        0.7136        0.7038        0.6401\n",
       "         ucs        0.7157        0.7114        0.6445"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-method hyperparameters for ResNet-56; every key is forwarded to `train`.\n",
    "hyperparams = [\n",
    "    {\"method\": \"rni\", \"reg_weight\": 1e-4, \"b\": 0},\n",
    "    {\"method\": \"l1\", \"reg_weight\": 1e-5},\n",
    "    {\"method\": \"ucs\"}\n",
    "]\n",
    "dataset_names = [\"cifar10\", \"cifar100\"]\n",
    "prune_pcs = [0.5, 0.9]  # prune_pc values; column suffix is round(pc * 100)\n",
    "\n",
    "res_resnet = []\n",
    "for hp in hyperparams:\n",
    "    for ds in dataset_names:\n",
    "        trained_model, trained_logs = train(dataset=ds, arch=\"resnet56\", **hp)\n",
    "        # Accuracies are read from the `test_acc` log at the run's `best_epoch`.\n",
    "        row = {**hp, \"dataset\": ds, \"train_acc\": trained_logs[\"test_acc\"][trained_logs[\"best_epoch\"]]}\n",
    "        for pc in prune_pcs:\n",
    "            # NOTE(review): every prune level starts from the same trained_model; if\n",
    "            # prune_finetune modifies its input in place, later levels start from an\n",
    "            # already-pruned model -- confirm, or pass a copy.\n",
    "            _, pruned_logs = prune_finetune(trained_model, dataset=ds, method=hp[\"method\"], prune_pc=pc)\n",
    "            row[f\"prune_acc_{round(pc * 100)}\"] = pruned_logs[\"test_acc\"][pruned_logs[\"best_epoch\"]]\n",
    "        res_resnet.append(row)\n",
    "\n",
    "# Show accuracies only: drop every hyperparameter column except `method` (an index level).\n",
    "hp_cols = sorted({key for params in hyperparams for key in params} - {\"method\"})\n",
    "res_resnet_df = pd.DataFrame(res_resnet).drop(columns=hp_cols).set_index([\"dataset\", \"method\"]).sort_index()\n",
    "res_resnet_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dev]",
   "language": "python",
   "name": "conda-env-dev-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
