{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2020-09-02T22:54:03.624493Z",
     "iopub.status.busy": "2020-09-02T22:54:03.623202Z",
     "iopub.status.idle": "2020-09-02T22:54:03.626775Z",
     "shell.execute_reply": "2020-09-02T22:54:03.627468Z"
    },
    "papermill": {
     "duration": 0.024645,
     "end_time": "2020-09-02T22:54:03.627741",
     "exception": false,
     "start_time": "2020-09-02T22:54:03.603096",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "source": [
    "# Factors of Clusterability Analysis: CIFAR-10 VGGs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2020-09-02T22:54:03.646623Z",
     "iopub.status.busy": "2020-09-02T22:54:03.645770Z",
     "iopub.status.idle": "2020-09-02T22:54:08.108034Z",
     "shell.execute_reply": "2020-09-02T22:54:08.106594Z"
    },
    "papermill": {
     "duration": 4.472527,
     "end_time": "2020-09-02T22:54:08.108390",
     "exception": false,
     "start_time": "2020-09-02T22:54:03.635863",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "sys.path.append('..')\n",
    "from src.visualization import run_spectral_cluster\n",
    "from src.experiment_tagging import get_model_path\n",
    "from src.utils import get_weights_paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2020-09-02T22:54:08.151193Z",
     "iopub.status.busy": "2020-09-02T22:54:08.150530Z",
     "iopub.status.idle": "2020-09-02T22:54:08.164088Z",
     "shell.execute_reply": "2020-09-02T22:54:08.164598Z"
    },
    "papermill": {
     "duration": 0.040641,
     "end_time": "2020-09-02T22:54:08.164758",
     "exception": false,
     "start_time": "2020-09-02T22:54:08.124117",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "n_clust = 12\n",
    "n_samples = 50\n",
    "n_workers = 5\n",
    "n_reps = 5\n",
    "\n",
    "model_tags = ('CNN-VGG-CIFAR10', 'CNN-VGG-CIFAR10+L1REG', 'CNN-VGG-CIFAR10+L2REG',\n",
    "              'CNN-VGG-CIFAR10+DROPOUT', 'CNN-VGG-CIFAR10+DROPOUT+L2REG', 'CNN-VGG-CIFAR10+MOD-INIT')\n",
    "tag_to_net = {'CNN-VGG-CIFAR10': 'Control', 'CNN-VGG-CIFAR10+L1REG': 'L1 Reg',\n",
    "              'CNN-VGG-CIFAR10+L2REG': 'L2 Reg', 'CNN-VGG-CIFAR10+DROPOUT': 'Dropout',\n",
    "              'CNN-VGG-CIFAR10+DROPOUT+L2REG': 'Dropout, L2 Reg',\n",
    "              'CNN-VGG-CIFAR10+MOD-INIT': 'Clusterable Init'}\n",
    "\n",
    "model_paths = {tag: get_model_path(tag, filter_='all')[-n_reps:] for tag in model_tags}\n",
    "assert all([len(mps)==n_reps for mps in model_paths.values()])\n",
    "\n",
    "clustering_results = {}\n",
    "clustering_results_pruned = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2020-09-02T22:54:08.188739Z",
     "iopub.status.busy": "2020-09-02T22:54:08.188166Z",
     "iopub.status.idle": "2020-09-03T06:25:11.655571Z",
     "shell.execute_reply": "2020-09-03T06:25:11.656963Z"
    },
    "papermill": {
     "duration": 27063.48692,
     "end_time": "2020-09-03T06:25:11.657524",
     "exception": false,
     "start_time": "2020-09-02T22:54:08.170604",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/32 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  3%|▎         | 1/32 [46:57<24:15:35, 2817.26s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  6%|▋         | 2/32 [1:15:55<20:46:48, 2493.61s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  9%|▉         | 3/32 [1:37:18<17:09:43, 2130.47s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 12%|█▎        | 4/32 [2:04:58<15:28:17, 1989.21s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 16%|█▌        | 5/32 [2:25:25<13:12:16, 1760.61s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 19%|█▉        | 6/32 [3:08:56<14:33:26, 2015.63s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 22%|██▏       | 7/32 [3:34:46<13:01:39, 1875.97s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 25%|██▌       | 8/32 [3:54:53<11:10:06, 1675.28s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 28%|██▊       | 9/32 [4:20:03<10:23:14, 1625.85s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 31%|███▏      | 10/32 [4:39:40<9:06:44, 1491.13s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 34%|███▍      | 11/32 [5:25:18<10:52:47, 1865.14s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 38%|███▊      | 12/32 [5:48:50<9:36:20, 1729.04s/it] "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 41%|████      | 13/32 [6:07:44<8:11:00, 1550.53s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 44%|████▍     | 14/32 [6:31:55<7:36:12, 1520.69s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 47%|████▋     | 15/32 [6:54:28<6:56:39, 1470.54s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 50%|█████     | 16/32 [7:14:06<6:08:44, 1382.76s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 53%|█████▎    | 17/32 [7:29:00<5:09:00, 1236.02s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 56%|█████▋    | 18/32 [7:29:13<3:22:50, 869.34s/it] "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 59%|█████▉    | 19/32 [7:29:21<2:12:18, 610.69s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 62%|██████▎   | 20/32 [7:29:27<1:25:51, 429.30s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 66%|██████▌   | 21/32 [7:29:34<55:28, 302.58s/it]  "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 69%|██████▉   | 22/32 [7:29:41<35:39, 213.90s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 72%|███████▏  | 23/32 [7:29:54<23:04, 153.84s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 75%|███████▌  | 24/32 [7:30:01<14:38, 109.76s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 78%|███████▊  | 25/32 [7:30:08<09:11, 78.78s/it] "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 81%|████████▏ | 26/32 [7:30:15<05:43, 57.26s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 84%|████████▍ | 27/32 [7:30:22<03:30, 42.18s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 88%|████████▊ | 28/32 [7:30:35<02:14, 33.61s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 91%|█████████ | 29/32 [7:30:42<01:16, 25.63s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 94%|█████████▍| 30/32 [7:30:49<00:40, 20.01s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 97%|█████████▋| 31/32 [7:30:56<00:16, 16.10s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "100%|██████████| 32/32 [7:31:03<00:00, 13.30s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "100%|██████████| 32/32 [7:31:03<00:00, 845.73s/it]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for tag, paths in tqdm(model_paths.items()):\n",
    "    \n",
    "    clustering_results[tag] = {}\n",
    "    clustering_results_pruned[tag] = {}\n",
    "\n",
    "    for rep in range(n_reps):\n",
    "\n",
    "        weight_paths = get_weights_paths(paths[rep])\n",
    "        results = run_spectral_cluster(weight_paths[True], n_clusters=n_clust, n_samples=n_samples,\n",
    "                                       n_workers=n_workers, eigen_solver='arpack')\n",
    "        clustering_results[tag][rep] = results\n",
    "\n",
    "        results_pruned = run_spectral_cluster(weight_paths[False], n_clusters=n_clust, n_samples=n_samples,\n",
    "                                              n_workers=n_workers, eigen_solver='arpack')\n",
    "        clustering_results_pruned[tag][rep] = results_pruned\n",
    "\n",
    "all_results = []\n",
    "for i, res in enumerate([clustering_results, clustering_results_pruned]):\n",
    "    for tag in res:\n",
    "\n",
    "        network = tag_to_net[tag]\n",
    "        if i == 1:\n",
    "            network += ', Pruning'\n",
    "\n",
    "        for rep in res[tag]:\n",
    "            result = {'model': tag,\n",
    "                      'network_type': 'cnn' if 'CNN' in tag else 'mlp',\n",
    "                      'Network': network}\n",
    "            labels, metrics = res[tag][rep]\n",
    "            result.update(metrics)\n",
    "            all_results.append(pd.Series(result))\n",
    "\n",
    "result_df = pd.DataFrame(all_results)\n",
    "savepath = '../results/clustering_factors_cnn_vgg.csv'\n",
    "result_df.to_csv(savepath)\n",
    "result_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "papermill": {
   "duration": 154094.44518,
   "end_time": "2020-09-04T17:42:17.045092",
   "environment_variables": {},
   "exception": null,
   "input_path": "./notebooks/clustering_factors.ipynb",
   "output_path": "./notebooks/clustering_factors.ipynb",
   "parameters": {},
   "start_time": "2020-09-02T22:54:02.599912",
   "version": "1.2.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}