{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from shrinkbench.experiment import PruningClass\n",
    "from shrinkbench.csv_analysis import cifar10_init, cifar10_log\n",
    "import os\n",
    "\n",
    "# Replace with absolute or relative paths to shrinkbench and the training data\n",
    "os.environ['DATAPATH'] = './shrinkbench/Training_data'\n",
    "os.environ[\"ShrinkPATH\"] = './shrinkbench'\n",
    "\n",
    "# The experiment code relies on a 'saved_models' folder inside the shrinkbench\n",
    "# directory (i.e. `shrinkbench/saved_models`); create it if it does not exist yet.\n",
    "os.makedirs(os.path.join(os.environ[\"ShrinkPATH\"], 'saved_models'), exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train N_MODELS baseline models per strategy, then prune and fine-tune each one\n",
    "# at every compression ratio in `compressions`.\n",
    "\n",
    "model_arch = 'resnet20'  # Replace with the model architecture wanted\n",
    "strategies = [\"MixedMagGrad\"]  # Pruning strategies to run\n",
    "compressions = [2, 4, 10, 20, 50]  # Compression ratios to evaluate\n",
    "N_MODELS = 30  # Number of models to train per strategy\n",
    "# Results for each strategy go to RESULTS_ROOT/model_arch/strategy; replace with your own folder\n",
    "RESULTS_ROOT = '/home/username/Developer/Summer'\n",
    "\n",
    "# (epochs, lr) training stages, applied in order\n",
    "BASELINE_SCHEDULE = [(10, 1e-2), (20, 1e-1), (10, 1e-2), (10, 1e-3), (10, 1e-4)]\n",
    "FINETUNE_SCHEDULE = [(10, 1e-1), (5, 1e-2), (5, 1e-3), (5, 1e-4)]\n",
    "\n",
    "\n",
    "def run_schedule(exp, schedule):\n",
    "    \"\"\"Train `exp` through each (epochs, lr) stage of `schedule`, in order.\"\"\"\n",
    "    for epochs, lr in schedule:\n",
    "        exp.update_optim(epochs=epochs, lr=lr)\n",
    "        exp.run()\n",
    "\n",
    "\n",
    "exp = PruningClass(dataset='CIFAR10',\n",
    "                   model=model_arch,\n",
    "                   train_kwargs={\n",
    "                       'optim': 'SGD',\n",
    "                       'epochs': 10,\n",
    "                       'lr': 1e-2,\n",
    "                       'weight_decay': 5e-4},\n",
    "                   dl_kwargs={'batch_size': 128},\n",
    "                   save_freq=1)\n",
    "\n",
    "exp.run_init()\n",
    "for strategy in strategies:\n",
    "    exp.fix_seed()\n",
    "    # Separate results folder per strategy so different strategies do not mix\n",
    "    result_path = os.path.join(RESULTS_ROOT, model_arch, strategy)\n",
    "    os.makedirs(result_path, exist_ok=True)\n",
    "    os.environ[\"ResultPATH\"] = result_path\n",
    "\n",
    "    for i in range(N_MODELS):\n",
    "        exp.round = i\n",
    "\n",
    "        # Train the unpruned baseline model\n",
    "        exp.state = 'Original'\n",
    "        exp.compression = 0\n",
    "        exp.pruning = False\n",
    "        exp.build_model(model_arch)\n",
    "        run_schedule(exp, BASELINE_SCHEDULE)\n",
    "\n",
    "        exp.load_model(checkpoint=True)\n",
    "        baseline_name = f\"{i}-{model_arch}-{exp.compression}\"\n",
    "        exp.save_model(baseline_name)\n",
    "\n",
    "        cifar10_init(exp, i)\n",
    "        cifar10_log(exp, i)\n",
    "\n",
    "        # Prune the baseline at each compression ratio, then fine-tune\n",
    "        exp.strategy = strategy\n",
    "        for compression in compressions:\n",
    "            exp.compression = compression\n",
    "            exp.prune()\n",
    "            exp.state = \"Compressed\"\n",
    "            cifar10_log(exp, i)\n",
    "\n",
    "            exp.state = \"Finetuned\"\n",
    "            run_schedule(exp, FINETUNE_SCHEDULE)\n",
    "\n",
    "            exp.load_model(prune=True)\n",
    "            # NOTE(review): no exp.run() follows this update before update_optim is\n",
    "            # called again below; confirm whether save_model/cifar10_log read it.\n",
    "            exp.update_optim('SGD', 15, 1e-2)\n",
    "            # NOTE(review): \"FineTuned\" differs in case from \"Finetuned\" set above;\n",
    "            # confirm which spelling cifar10_log and later analysis expect.\n",
    "            exp.state = \"FineTuned\"\n",
    "            exp.save_model(f\"{i}-{model_arch}-{exp.compression}\")\n",
    "            cifar10_log(exp, i)\n",
    "\n",
    "            # Reload the unpruned baseline so every ratio is pruned from the same model\n",
    "            exp.load_model(baseline_name)\n",
    "            exp.update_optim(epochs=10, lr=1e-1)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
