{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import torch\n",
    "assert torch.cuda.is_available()\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch.nn as nn\n",
    "from tqdm.auto import tqdm\n",
    "import json\n",
    "from train import eval\n",
    "from Utils import load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['mica_igq_shuffle']\n"
     ]
    }
   ],
   "source": [
    "# Pruners to process. Bare names are the original single-shot baselines;\n",
    "# MiCA variants carry a suffix that selects their result directory in\n",
    "# get_experimentid. Every entry keeps its trailing comma so that\n",
    "# uncommenting adjacent lines can never silently concatenate two strings.\n",
    "pruners = np.array(\n",
    "    [\n",
    "        # 'cs',\n",
    "        # 'grasp',\n",
    "        # 'sf',\n",
    "        # 'mica_cs_shuffle',\n",
    "        # 'mica_grasp_shuffle',\n",
    "        # 'mica_sf_shuffle',\n",
    "        # 'mica_erk_shuffle',\n",
    "        'mica_igq_shuffle',\n",
    "        # 'mica_cs_rand',\n",
    "        # 'mica_grasp_rand',\n",
    "        # 'mica_sf_rand',\n",
    "        # 'mica_erk_rand',\n",
    "        # 'mica_igq_rand',\n",
    "    ],\n",
    "    dtype=object,\n",
    ")\n",
    "\n",
    "# Per-pruner plot styling and folder tags: tab10 colour, marker symbol,\n",
    "# legend name and pruning-epoch tag. Suffix variants inherit colour, marker\n",
    "# and epoch tag from their base pruner in the loop below.\n",
    "colors = dict(\n",
    "    cs=plt.cm.tab10(9),\n",
    "    grasp=plt.cm.tab10(8),\n",
    "    sf=plt.cm.tab10(6),\n",
    "    mica_cs=plt.cm.tab10(9),\n",
    "    mica_grasp=plt.cm.tab10(8),\n",
    "    mica_sf=plt.cm.tab10(6),\n",
    "    mica_erk=plt.cm.tab10(4),\n",
    "    mica_igq=plt.cm.tab10(3),\n",
    ")\n",
    "\n",
    "markers = dict(\n",
    "    cs='^',\n",
    "    grasp='h',\n",
    "    sf='H',\n",
    "    mica_cs='^',\n",
    "    mica_grasp='h',\n",
    "    mica_sf='H',\n",
    "    mica_erk='p',\n",
    "    mica_igq='*',\n",
    ")\n",
    "\n",
    "pruner_names = dict(\n",
    "    cs='SNIP (Original)',\n",
    "    grasp='GraSP (Original)',\n",
    "    sf='SynFlow (Original)',\n",
    "    mica_cs_shuffle='SNIP (Random)',\n",
    "    mica_grasp_shuffle='GraSP (Random)',\n",
    "    mica_sf_shuffle='SynFlow (Random)',\n",
    "    mica_erk_shuffle='ERK (Random)',\n",
    "    mica_igq_shuffle='IGQ (Random)',\n",
    "    mica_cs_rand='$\\\\bf{MiCA}$-$\\\\bf{SNIP (Rand)}$',\n",
    "    mica_grasp_rand='$\\\\bf{MiCA}$-$\\\\bf{GraSP (Rand)}$',\n",
    "    mica_sf_rand='$\\\\bf{MiCA}$-$\\\\bf{SynFlow (Rand)}$',\n",
    "    mica_erk_rand='$\\\\bf{MiCA}$-$\\\\bf{ERK (Rand)}$',\n",
    "    mica_igq_rand='$\\\\bf{MiCA}$-$\\\\bf{IGQ (Rand)}$',\n",
    ")\n",
    "\n",
    "pepochs = dict(\n",
    "    cs='1',\n",
    "    grasp='1',\n",
    "    sf='100',\n",
    "    mica_cs='1',\n",
    "    mica_grasp='1',\n",
    "    mica_sf='1',\n",
    "    mica_erk='1',\n",
    "    mica_igq='1',\n",
    ")\n",
    "\n",
    "# A suffixed pruner ('<base>_shuffle', '_min', '_max', '_rand') reuses the\n",
    "# colour, marker and pruning-epoch tag of its base pruner. The suffix list and\n",
    "# its precedence mirror get_experimentid, so every variant that function can\n",
    "# locate also has styling (previously '_min' / '_max' raised KeyError later).\n",
    "for pruner in pruners:\n",
    "    for suffix in ('_shuffle', '_min', '_max', '_rand'):\n",
    "        if suffix in pruner:\n",
    "            base = pruner.replace(suffix, '')\n",
    "            colors[pruner] = colors[base]\n",
    "            markers[pruner] = markers[base]\n",
    "            pepochs[pruner] = pepochs[base]\n",
    "            break\n",
    "\n",
    "# Ratio tags, experiment id and run indices for this configuration\n",
    "# (the processing cells further down set their own expid / ratios).\n",
    "ratios = [\n",
    "    '2.5',\n",
    "    # other ratio tags: '0.5', '1.0', '1.5', '2.0', '3.0', '3.5', '4.0'\n",
    "]\n",
    "expid = 'cifar10_lottery_resnet20'\n",
    "runs = list(range(3))\n",
    "max_run_num = len(runs)\n",
    "\n",
    "print(pruners)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_experimentid(pruner, ratio, pepochs, run_num, exp_id=None):\n",
    "    \"\"\"Return the result directory (with trailing '/') of one pruning run.\n",
    "\n",
    "    A variant suffix in `pruner` ('_shuffle', '_min', '_max' or '_rand',\n",
    "    checked in that order) selects the experiment directory\n",
    "    '<exp_id><suffix>' and is stripped from the pruner name used in the\n",
    "    run folder '<pruner>-<ratio>-<pepochs>/run_<run_num>/'.\n",
    "\n",
    "    Args:\n",
    "        pruner: pruner key, e.g. 'mica_igq_shuffle'.\n",
    "        ratio: compression-ratio tag as used in the folder name.\n",
    "        pepochs: pruning-epochs tag as used in the folder name.\n",
    "        run_num: index of the run.\n",
    "        exp_id: experiment id; defaults to the notebook-global `expid`\n",
    "            (which later cells reassign) so existing calls are unchanged.\n",
    "    \"\"\"\n",
    "    if exp_id is None:\n",
    "        exp_id = expid\n",
    "    suffix = next((s for s in ('_shuffle', '_min', '_max', '_rand') if s in pruner), '')\n",
    "    p_name = pruner.replace(suffix, '') if suffix else pruner\n",
    "    base_dir = f'../Results/data/singleshot/{exp_id}{suffix}'\n",
    "    return f'{base_dir}/{p_name}-{ratio}-{pepochs}/run_{run_num}/'\n",
    "\n",
    "def zero_to_nan(values):\n",
    "    \"\"\"Return a new list of `values` with each element equal to 0 replaced by NaN.\"\"\"\n",
    "    out = []\n",
    "    for v in values:\n",
    "        out.append(float('nan') if v == 0 else v)\n",
    "    return out\n",
    "\n",
    "def get_ordered_list(l, order):\n",
    "    \"\"\"Return a new list holding l[i] for each index i in `order`, in that order.\"\"\"\n",
    "    return [l[i] for i in order]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expid = 'cifar100_lottery_resnet20'\n",
    "ratios = [\n",
    "    '0.5', '1.0', '1.5', '2.0',\n",
    "    '2.5', '3.0', '3.5', '4.0',\n",
    "    ]\n",
    "\n",
    "# For every pruner / ratio / run: rebuild the saved pruned model, call\n",
    "# set_effective_sparsity forward (with an all-ones probe) and backward,\n",
    "# save the resulting state dict as model_effective_comp.pt, and print log10 of\n",
    "# (total prunable weights / mask-weighted weight count) after that update.\n",
    "for pruner in pruners:\n",
    "    for i, ratio in enumerate(tqdm(ratios, desc=pruner)):\n",
    "        for r, run_num in enumerate(runs):\n",
    "            experimentid = get_experimentid(pruner, ratio, pepochs[pruner], run_num)\n",
    "            model_file = experimentid + 'model.pt'\n",
    "\n",
    "            # Prunable weight tensors listed in the run's compression report.\n",
    "            compression = pd.read_pickle(experimentid + 'compression.pkl')\n",
    "            selected = compression[compression['prunable'].eq(True) & compression['param'].eq('weight')]\n",
    "            modules = selected['module']\n",
    "            sizes = np.array([np.prod(s) for s in selected['shape']])\n",
    "\n",
    "            with open(experimentid + 'args.json') as f:\n",
    "                args = json.load(f)\n",
    "            input_shape, num_classes = load.dimension(args['dataset'])\n",
    "\n",
    "            # All-ones probe shaped like one test sample; rebuilt on each pruner's first run.\n",
    "            if i == 0 and r == 0:\n",
    "                test_loader, _, _ = load.dataloader(\n",
    "                    args['dataset'], args['test_batch_size'], args['test_batch_size'],\n",
    "                    False, args['workers'], 0, 0)\n",
    "                data, _ = next(iter(test_loader))\n",
    "                probe_input = torch.ones([1] + list(data[0, :].shape)).to('cuda')\n",
    "\n",
    "            model = load.model(args['model'], args['model_class'])(\n",
    "                input_shape, num_classes, args['dense_classifier'], args['pretrained']).to('cuda')\n",
    "            # map_location keeps loading working whichever GPU index the checkpoint was saved from.\n",
    "            model.load_state_dict(torch.load(model_file, map_location='cuda'))\n",
    "            # NOTE(review): nothing here calls model.eval(), so the model is presumably in\n",
    "            # train mode; if the forward pass inside set_effective_sparsity runs through\n",
    "            # BatchNorm, its running statistics change and are saved below -- confirm.\n",
    "            print('Forward :')\n",
    "            model.set_effective_sparsity(is_forward=True, x=probe_input)\n",
    "            print('Backward :')\n",
    "            model.set_effective_sparsity(is_forward=False, x=None)\n",
    "\n",
    "            torch.save(model.state_dict(), experimentid + 'model_effective_comp.pt')\n",
    "\n",
    "            # Match masks to report rows by module name instead of assuming that\n",
    "            # named_modules() yields the masked layers in the report's row order.\n",
    "            mask_density = {\n",
    "                name: (m.weight_mask.sum() / m.weight_mask.numel()).item()\n",
    "                for name, m in model.named_modules()\n",
    "                if hasattr(m, 'weight_mask')\n",
    "            }\n",
    "            effective_densities = np.array([mask_density[name] for name in modules])\n",
    "            e = sizes * effective_densities\n",
    "            if np.all(e == 0):\n",
    "                actual_comp = float('nan')\n",
    "            else:\n",
    "                actual_comp = np.sum(sizes) / np.sum(e)\n",
    "            print(np.log10(actual_comp))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expid = 'cifar100_lottery_vgg16_bn'\n",
    "ratios = [\n",
    "    '0.5', '1.0', '1.5', '2.0',\n",
    "    '2.5', '3.0', '3.5', '4.0',\n",
    "    '4.5', '5.0', '5.5', '6.0',\n",
    "    ]\n",
    "\n",
    "# Same procedure as the CIFAR-100 ResNet-20 cell, without the compression\n",
    "# printout: rebuild each saved pruned model, call set_effective_sparsity\n",
    "# forward (all-ones probe) and backward, and save model_effective_comp.pt.\n",
    "for pruner in pruners:\n",
    "    for i, ratio in enumerate(tqdm(ratios, desc=pruner)):\n",
    "        for r, run_num in enumerate(runs):\n",
    "            experimentid = get_experimentid(pruner, ratio, pepochs[pruner], run_num)\n",
    "            model_file = experimentid + 'model.pt'\n",
    "\n",
    "            with open(experimentid + 'args.json') as f:\n",
    "                args = json.load(f)\n",
    "            input_shape, num_classes = load.dimension(args['dataset'])\n",
    "\n",
    "            # All-ones probe shaped like one test sample; rebuilt on each pruner's first run.\n",
    "            if i == 0 and r == 0:\n",
    "                test_loader, _, _ = load.dataloader(\n",
    "                    args['dataset'], args['test_batch_size'], args['test_batch_size'],\n",
    "                    False, args['workers'], 0, 0)\n",
    "                data, _ = next(iter(test_loader))\n",
    "                probe_input = torch.ones([1] + list(data[0, :].shape)).to('cuda')\n",
    "\n",
    "            model = load.model(args['model'], args['model_class'])(\n",
    "                input_shape, num_classes, args['dense_classifier'], args['pretrained']).to('cuda')\n",
    "            # map_location keeps loading working whichever GPU index the checkpoint was saved from.\n",
    "            model.load_state_dict(torch.load(model_file, map_location='cuda'))\n",
    "            # NOTE(review): nothing here calls model.eval(), so the model is presumably in\n",
    "            # train mode; if the forward pass inside set_effective_sparsity runs through\n",
    "            # BatchNorm, its running statistics change and are saved below -- confirm.\n",
    "            print('Forward :')\n",
    "            model.set_effective_sparsity(is_forward=True, x=probe_input)\n",
    "            print('Backward :')\n",
    "            model.set_effective_sparsity(is_forward=False, x=None)\n",
    "\n",
    "            torch.save(model.state_dict(), experimentid + 'model_effective_comp.pt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connected_edge",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
