{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-25T22:18:36.450062Z",
     "iopub.status.busy": "2020-09-25T22:18:36.449561Z",
     "iopub.status.idle": "2020-09-25T22:18:36.454212Z",
     "shell.execute_reply": "2020-09-25T22:18:36.453731Z"
    },
    "papermill": {
     "duration": 0.018094,
     "end_time": "2020-09-25T22:18:36.454370",
     "exception": false,
     "start_time": "2020-09-25T22:18:36.436276",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "source": [
    "# Factors of Clusterability Analysis: small CNNs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-25T22:18:36.471587Z",
     "iopub.status.busy": "2020-09-25T22:18:36.471020Z",
     "iopub.status.idle": "2020-09-25T22:18:40.384035Z",
     "shell.execute_reply": "2020-09-25T22:18:40.382624Z"
    },
    "papermill": {
     "duration": 3.925696,
     "end_time": "2020-09-25T22:18:40.384371",
     "exception": false,
     "start_time": "2020-09-25T22:18:36.458675",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the parent directory on sys.path so the `src` imports below resolve.\n",
    "sys.path.append('..')\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "from src.visualization import run_spectral_cluster\n",
    "from src.experiment_tagging import get_model_path\n",
    "from src.utils import get_weights_paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-25T22:18:40.442101Z",
     "iopub.status.busy": "2020-09-25T22:18:40.440466Z",
     "iopub.status.idle": "2020-09-25T22:18:40.449595Z",
     "shell.execute_reply": "2020-09-25T22:18:40.450834Z"
    },
    "papermill": {
     "duration": 0.040846,
     "end_time": "2020-09-25T22:18:40.451212",
     "exception": false,
     "start_time": "2020-09-25T22:18:40.410366",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Experiment configuration: everything a reader might want to tune lives in this cell.\n",
    "n_clust = 12    # forwarded to run_spectral_cluster as n_clusters\n",
    "n_samples = 50  # forwarded to run_spectral_cluster as n_samples\n",
    "n_workers = 10  # forwarded to run_spectral_cluster as n_workers\n",
    "n_reps = 5      # number of model paths kept (and clustered) per tag\n",
    "\n",
    "# Tags of the experiments to analyse; this order determines the row order of the results table.\n",
    "model_tags = ('CNN-MNIST', 'CNN-MNIST+DROPOUT', 'CNN-MNIST+L1REG', 'CNN-MNIST+L2REG',\n",
    "              'CNN-MNIST+MOD-INIT', 'CNN-STACKED-MNIST', 'CNN-STACKED-SAME-MNIST',\n",
    "              'CNN-FASHION', 'CNN-FASHION+DROPOUT', 'CNN-FASHION+L1REG', 'CNN-FASHION+L2REG',\n",
    "              'CNN-FASHION+MOD-INIT', 'CNN-STACKED-FASHION', 'CNN-STACKED-SAME-FASHION')\n",
    "\n",
    "# Display name written to the 'Network' column; the MNIST and FASHION variants share names.\n",
    "tag_to_net = {'CNN-MNIST': 'Control', 'CNN-MNIST+DROPOUT': 'Dropout', 'CNN-MNIST+L1REG': 'L1 Reg',\n",
    "              'CNN-MNIST+L2REG': 'L2 Reg', 'CNN-MNIST+MOD-INIT': 'Clusterable Init',\n",
    "              'CNN-STACKED-SAME-MNIST': 'Stacked Same', 'CNN-STACKED-MNIST': 'Stacked Diff',\n",
    "              'CNN-FASHION': 'Control', 'CNN-FASHION+DROPOUT': 'Dropout', 'CNN-FASHION+L1REG': 'L1 Reg',\n",
    "              'CNN-FASHION+L2REG': 'L2 Reg', 'CNN-FASHION+MOD-INIT': 'Clusterable Init',\n",
    "              'CNN-STACKED-SAME-FASHION': 'Stacked Same', 'CNN-STACKED-FASHION': 'Stacked Diff'}\n",
    "\n",
    "# Fail fast: tag_to_net[tag] is only looked up after the expensive clustering loop has finished.\n",
    "missing_names = [tag for tag in model_tags if tag not in tag_to_net]\n",
    "assert not missing_names, f'tags without a display name in tag_to_net: {missing_names}'\n",
    "\n",
    "# Keep the last n_reps paths returned for each tag and insist that every tag has exactly that many.\n",
    "model_paths = {tag: get_model_path(tag, filter_='all')[-n_reps:] for tag in model_tags}\n",
    "short_tags = {tag: len(mps) for tag, mps in model_paths.items() if len(mps) != n_reps}\n",
    "assert not short_tags, f'expected {n_reps} model paths per tag, found: {short_tags}'\n",
    "\n",
    "# Filled in by the clustering cell: {tag: {rep: result of run_spectral_cluster}}.\n",
    "clustering_results = {}\n",
    "clustering_results_pruned = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2020-09-25T22:18:40.495096Z",
     "iopub.status.busy": "2020-09-25T22:18:40.493709Z",
     "iopub.status.idle": "2020-09-25T22:22:01.848368Z",
     "shell.execute_reply": "2020-09-25T22:22:01.849675Z"
    },
    "papermill": {
     "duration": 201.387834,
     "end_time": "2020-09-25T22:22:01.850047",
     "exception": false,
     "start_time": "2020-09-25T22:18:40.462213",
     "status": "completed"
    },
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Run spectral clustering on every selected model, twice per model: once on\n",
    "# weight_paths[True] and once on weight_paths[False]; the latter is stored as the pruned result.\n",
    "# NOTE(review): presumably get_weights_paths returns a mapping keyed by an is-unpruned flag;\n",
    "# confirm against src.utils.\n",
    "cluster_kwargs = dict(n_clusters=n_clust, n_samples=n_samples,\n",
    "                      n_workers=n_workers, eigen_solver='arpack')\n",
    "\n",
    "for tag, paths in tqdm(model_paths.items()):\n",
    "\n",
    "    clustering_results[tag] = {}\n",
    "    clustering_results_pruned[tag] = {}\n",
    "\n",
    "    for rep, model_path in enumerate(paths):\n",
    "        weight_paths = get_weights_paths(model_path)\n",
    "        clustering_results[tag][rep] = run_spectral_cluster(weight_paths[True], **cluster_kwargs)\n",
    "        clustering_results_pruned[tag][rep] = run_spectral_cluster(weight_paths[False], **cluster_kwargs)\n",
    "\n",
    "# Flatten both result dicts into one row per (pruning state, tag, repetition).\n",
    "all_results = []\n",
    "for is_pruned, results_by_tag in ((False, clustering_results), (True, clustering_results_pruned)):\n",
    "    for tag, results_by_rep in results_by_tag.items():\n",
    "\n",
    "        network = tag_to_net[tag]\n",
    "        if is_pruned:\n",
    "            network += ', Pruning'\n",
    "\n",
    "        # The dataset is derived from the tag; anything that is neither MNIST nor CIFAR10 counts as FASHION.\n",
    "        if 'MNIST' in tag:\n",
    "            dset = 'MNIST'\n",
    "        elif 'CIFAR10' in tag:\n",
    "            dset = 'CIFAR-10'\n",
    "        else:\n",
    "            dset = 'FASHION'\n",
    "\n",
    "        # Each stored result unpacks into (labels, metrics); only the metrics go into the table.\n",
    "        for _labels, metrics in results_by_rep.values():\n",
    "            row = {'model': tag,\n",
    "                   'network_type': 'cnn' if 'CNN' in tag else 'mlp',\n",
    "                   'Network': network,\n",
    "                   'Dataset': dset}\n",
    "            row.update(metrics)\n",
    "            all_results.append(row)\n",
    "\n",
    "# Build the frame once from plain dict records.\n",
    "result_df = pd.DataFrame(all_results)\n",
    "savepath = '../results/clustering_factors_cnn.csv'\n",
    "# Create the output directory if needed so a long run is not lost to a missing folder.\n",
    "os.makedirs(os.path.dirname(savepath), exist_ok=True)\n",
    "result_df.to_csv(savepath)\n",
    "result_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "papermill": {
   "duration": 207.256063,
   "end_time": "2020-09-25T22:22:02.684723",
   "environment_variables": {},
   "exception": null,
   "input_path": "./notebooks/clustering_factors_cnn.ipynb",
   "output_path": "./notebooks/clustering_factors_cnn.ipynb",
   "parameters": {},
   "start_time": "2020-09-25T22:18:35.428660",
   "version": "1.2.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
