{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compressing a VGG-16 trained on CIFAR-10\n",
    "\n",
    "#### To replicate the results with the pretrained model, please download the following checkpoint from the CUP repository and place it in `./checkpoints/`:\n",
    "\n",
    "1. vgg16_cifar10.pth\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpu-data2/kfot/miniconda3/envs/myenv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# The kernel's own argv (e.g. `-f <connection file>`) would confuse argparse.parse_args() below, so clear it\n",
    "import sys; sys.argv=['']; del sys\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "# Expose only physical GPU 0; set before importing torch so it is in effect before CUDA initialises\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from tensorboardX import SummaryWriter\n",
    "\n",
    "import numpy as np\n",
    "import random\n",
    "import time\n",
    "import copy\n",
    "\n",
    "from src.utils import plot_tsne,fancy_dendrogram,save_obj,load_obj\n",
    "from src.model import VGG,load_model\n",
    "from src.prune_model import prune_model\n",
    "from src.cluster_model import cluster_model, fuse_model\n",
    "from src.train_test import train,test,adjust_learning_rate_nips,adjust_learning_rate_iccv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: is a CUDA device visible to PyTorch?\n",
    "cuda_available = torch.cuda.is_available()\n",
    "print(cuda_available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train (or load, if a checkpoint exists) the baseline VGG-16 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "parser = argparse.ArgumentParser(description='PyTorch VGG-16 CIFAR-10 Example')\n",
    "parser.add_argument('--batch-size', type=int, default=64, metavar='N',\n",
    "                    help='input batch size for training (default: 64)')\n",
    "parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n",
    "                    help='number of data loading workers (default: 4)')\n",
    "parser.add_argument('--epochs', default=160, type=int, metavar='N',\n",
    "                    help='number of total epochs to run')\n",
    "parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n",
    "                    metavar='LR', help='initial learning rate')\n",
    "parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n",
    "                    help='momentum')\n",
    "parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n",
    "                    metavar='W', help='weight decay (default: 1e-4)')\n",
    "parser.add_argument('--no-cuda', action='store_true', default=False,\n",
    "                    help='disables CUDA training')\n",
    "parser.add_argument('--seed', type=int, default=12346, metavar='S',\n",
    "                    help='random seed (default: 12346)')\n",
    "parser.add_argument('--num_output', type=int, default=10, metavar='N',\n",
    "                    help='number of classes (default: 10)')\n",
    "parser.add_argument('--log-interval', type=int, default=100, metavar='N',\n",
    "                    help='how many batches to wait before logging training status')\n",
    "parser.add_argument('--checkpoint_path', type=str, default='./checkpoints/vgg16_cifar10.pth', metavar='S',\n",
    "                    help='path to store model training checkpoints')\n",
    "parser.add_argument('--gpu', type=int, default=[5], nargs='+', help='used gpu')\n",
    "\n",
    "args = parser.parse_args()\n",
    "\n",
    "# Set device to CPU or GPU\n",
    "use_cuda = not args.no_cuda and torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "writer = SummaryWriter('logs/vgg_16_cifar10/base_iccv_param/')\n",
    "\n",
    "def set_random_seed(seed):\n",
    "    \"\"\"Seed Python, NumPy and PyTorch RNGs and configure cuDNN for deterministic runs.\n",
    "\n",
    "    Args:\n",
    "        seed: integer seed applied to every RNG (previously NumPy ignored it and used args.seed).\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    # Only affects subprocesses started later; this interpreter's hash seed is fixed at startup\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    # Benchmark mode may select different convolution algorithms between runs, breaking reproducibility\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "set_random_seed(args.seed)\n",
    "\n",
    "# ImageNet mean/std; kept unchanged so the stored checkpoint sees the preprocessing it was trained with\n",
    "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                 std=[0.229, 0.224, 0.225])\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.CIFAR10(root='data/', train=True, transform=transforms.Compose([\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.RandomCrop(32, 4),\n",
    "            transforms.ToTensor(),\n",
    "            normalize,\n",
    "        ]), download=True),\n",
    "        batch_size=args.batch_size, shuffle=True,\n",
    "        num_workers=args.workers, pin_memory=use_cuda)\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "        datasets.CIFAR10(root='data/', train=False, transform=transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            normalize,\n",
    "        ])),\n",
    "        batch_size=args.batch_size, shuffle=False,\n",
    "        num_workers=args.workers, pin_memory=use_cuda)\n",
    "\n",
    "vgg16_bn = VGG('VGG16',num_output=args.num_output)\n",
    "vgg16_bn = vgg16_bn.to(device)\n",
    "\n",
    "optimizer = optim.SGD(vgg16_bn.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay,nesterov=False)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading state from epoch 152 and test loss 0.3201139727592468\n",
      "original model accuracy is 93.64\n"
     ]
    }
   ],
   "source": [
    "best_val_acc = 0\n",
    "\n",
    "# Train from scratch only when no checkpoint exists; otherwise load it and report its accuracy\n",
    "if not os.path.isfile(args.checkpoint_path):\n",
    "    # torch.save does not create missing parent directories (e.g. ./checkpoints/)\n",
    "    os.makedirs(os.path.dirname(args.checkpoint_path) or '.', exist_ok=True)\n",
    "    for epoch in range(1, args.epochs + 1):\n",
    "        start = time.time()\n",
    "        args.lr = adjust_learning_rate_iccv(args,optimizer,epoch)\n",
    "        train_loss,train_acc = train(args,vgg16_bn,device,train_loader,optimizer,epoch)\n",
    "        val_loss,val_acc = test(args,vgg16_bn,device,val_loader)\n",
    "\n",
    "        writer.add_scalars('base_model/loss',{'train_loss': train_loss,\n",
    "                                        'val_loss' : val_loss}, epoch)\n",
    "        writer.add_scalars('base_model/accuracy',{'train_acc': train_acc,\n",
    "                                            'val_acc' : val_acc}, epoch)\n",
    "\n",
    "        print('Time taken for epoch : {}\\n'.format(time.time()-start))\n",
    "\n",
    "        # Only the best-validation-accuracy weights are written to disk\n",
    "        if val_acc > best_val_acc:\n",
    "            best_val_acc = val_acc\n",
    "\n",
    "            torch.save({\n",
    "                        'epoch': epoch,\n",
    "                        'model_state_dict': vgg16_bn.state_dict(),\n",
    "                        'optimizer_state_dict': optimizer.state_dict(),\n",
    "                        'loss': val_loss,\n",
    "                        }, args.checkpoint_path, pickle_protocol=4)\n",
    "else:\n",
    "    vgg16_bn,optimizer = load_model('vgg16_bn','sgd',args)\n",
    "    orig_loss, orig_acc = test(args, vgg16_bn, device, val_loader,verbose=False)\n",
    "    best_val_acc = orig_acc\n",
    "\n",
    "print('original model accuracy is {}'.format(best_val_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using TropNNC, variant 1 (T = 0.02), no retrain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading state from epoch 152 and test loss 0.3201139727592468\n",
      "\n",
      "Test set: Average loss: 0.3201, Accuracy: 9364/10000 (94%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Start this experiment from the same RNG state and the pretrained baseline checkpoint\n",
    "set_random_seed(seed=args.seed)\n",
    "vgg16_bn, optimizer = load_model('vgg16_bn', 'sgd', args)\n",
    "orig_loss, orig_acc = test(args, vgg16_bn, device, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prune using T=0.02"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['cnn_layers', '0']\n",
      "['cnn_layers', '3']\n",
      "['cnn_layers', '7']\n",
      "['cnn_layers', '10']\n",
      "['cnn_layers', '14']\n",
      "['cnn_layers', '17']\n",
      "['cnn_layers', '20']\n",
      "['cnn_layers', '24']\n",
      "['cnn_layers', '27']\n",
      "['cnn_layers', '30']\n",
      "['cnn_layers', '34']\n",
      "['cnn_layers', '37']\n",
      "\n",
      "Test set: Average loss: 0.5065, Accuracy: 8620/10000 (86%)\n",
      "\n",
      "Original accuracy 93.64, compressed model accuracy 86.2, accuracy drop 7.439999999999998\n"
     ]
    }
   ],
   "source": [
    "cluster_args = {\n",
    "    'cluster_layers' : {0:0,3:0,7:0,10:0,14:0,17:0,20:0,24:0,27:0,30:0,34:0,37:0,40:0},\n",
    "    'conv_feature_size' : 1,\n",
    "    'reshape_exists' : True,\n",
    "    'features' : 'both',\n",
    "    'channel_reduction' : 'fro',\n",
    "    'use_bias' : False,\n",
    "    'linkage_method' : 'ward',\n",
    "    'distance_metric' : 'euclidean',\n",
    "    'cluster_criterion' : 'hierarchical',\n",
    "    'distance_threshold' : 0.02,  # T: hierarchical-clustering distance threshold\n",
    "    'merge_criterion' : 'tropnnc',\n",
    "    'tropnnc_features_and_threshold' : True,\n",
    "    'variant' : False,\n",
    "    'verbose' : False\n",
    "}\n",
    "\n",
    "model_modifier = cluster_model(vgg16_bn,cluster_args)\n",
    "vgg16_bn_clustered = model_modifier.cluster_model()\n",
    "# Follow the configured device instead of hard-coding .cuda(), so --no-cuda / CPU-only hosts work\n",
    "vgg16_bn_clustered = vgg16_bn_clustered.to(device)\n",
    "\n",
    "val_loss_no_retrain, val_accuracy_no_retrain = test(args, vgg16_bn_clustered, device, val_loader)\n",
    "\n",
    "print('Original accuracy {}, compressed model accuracy {}, accuracy drop {:.2f}'.format(orig_acc,val_accuracy_no_retrain,orig_acc-val_accuracy_no_retrain))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Calculate #params and #flops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- stats for original model ---\n",
      "  + Number of params: 14.73M\n",
      "  + Number of FLOPs: 0.63G\n",
      "--- stats for compressed model ---\n",
      "  + Number of params: 2.75M\n",
      "  + Number of FLOPs: 0.39G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "389956402.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums,print_model_param_flops\n",
    "\n",
    "# NOTE: .cpu() moves each model in place; every experiment section reloads the baseline before clustering\n",
    "print('--- stats for original model ---')\n",
    "print_model_param_nums(vgg16_bn)\n",
    "print_model_param_flops(vgg16_bn.cpu(),input_res=32)\n",
    "\n",
    "print('--- stats for compressed model ---')\n",
    "print_model_param_nums(vgg16_bn_clustered)\n",
    "# Trailing ';' stops the raw FLOP count from also being echoed as Out[]\n",
    "print_model_param_flops(vgg16_bn_clustered.cpu(),input_res=32);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using TropNNC, variant 1 (T = 0.025), no retrain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading state from epoch 152 and test loss 0.3201139727592468\n",
      "\n",
      "Test set: Average loss: 0.3201, Accuracy: 9364/10000 (94%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Start this experiment from the same RNG state and the pretrained baseline checkpoint\n",
    "set_random_seed(seed=args.seed)\n",
    "vgg16_bn, optimizer = load_model('vgg16_bn', 'sgd', args)\n",
    "orig_loss, orig_acc = test(args, vgg16_bn, device, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prune using T=0.025"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['cnn_layers', '0']\n",
      "['cnn_layers', '3']\n",
      "['cnn_layers', '7']\n",
      "['cnn_layers', '10']\n",
      "['cnn_layers', '14']\n",
      "['cnn_layers', '17']\n",
      "['cnn_layers', '20']\n",
      "['cnn_layers', '24']\n",
      "['cnn_layers', '27']\n",
      "['cnn_layers', '30']\n",
      "['cnn_layers', '34']\n",
      "['cnn_layers', '37']\n",
      "\n",
      "Test set: Average loss: 1.6493, Accuracy: 3489/10000 (35%)\n",
      "\n",
      "Original accuracy 93.64, compressed model accuracy 34.89, accuracy drop 58.75\n"
     ]
    }
   ],
   "source": [
    "cluster_args = {\n",
    "    'cluster_layers' : {0:0,3:0,7:0,10:0,14:0,17:0,20:0,24:0,27:0,30:0,34:0,37:0,40:0},\n",
    "    'conv_feature_size' : 1,\n",
    "    'reshape_exists' : True,\n",
    "    'features' : 'both',\n",
    "    'channel_reduction' : 'fro',\n",
    "    'use_bias' : False,\n",
    "    'linkage_method' : 'ward',\n",
    "    'distance_metric' : 'euclidean',\n",
    "    'cluster_criterion' : 'hierarchical',\n",
    "    'distance_threshold' : 0.025,  # T: hierarchical-clustering distance threshold\n",
    "    'merge_criterion' : 'tropnnc',\n",
    "    'tropnnc_features_and_threshold' : True,\n",
    "    'variant' : False,\n",
    "    'verbose' : False\n",
    "}\n",
    "\n",
    "model_modifier = cluster_model(vgg16_bn,cluster_args)\n",
    "vgg16_bn_clustered = model_modifier.cluster_model()\n",
    "# Follow the configured device instead of hard-coding .cuda(), so --no-cuda / CPU-only hosts work\n",
    "vgg16_bn_clustered = vgg16_bn_clustered.to(device)\n",
    "\n",
    "val_loss_no_retrain, val_accuracy_no_retrain = test(args, vgg16_bn_clustered, device, val_loader)\n",
    "\n",
    "print('Original accuracy {}, compressed model accuracy {}, accuracy drop {:.2f}'.format(orig_acc,val_accuracy_no_retrain,orig_acc-val_accuracy_no_retrain))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Calculate #params and #flops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- stats for original model ---\n",
      "  + Number of params: 14.73M\n",
      "  + Number of FLOPs: 0.63G\n",
      "--- stats for compressed model ---\n",
      "  + Number of params: 1.80M\n",
      "  + Number of FLOPs: 0.33G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "328996762.0"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums,print_model_param_flops\n",
    "\n",
    "# NOTE: .cpu() moves each model in place; every experiment section reloads the baseline before clustering\n",
    "print('--- stats for original model ---')\n",
    "print_model_param_nums(vgg16_bn)\n",
    "print_model_param_flops(vgg16_bn.cpu(),input_res=32)\n",
    "\n",
    "print('--- stats for compressed model ---')\n",
    "print_model_param_nums(vgg16_bn_clustered)\n",
    "# Trailing ';' stops the raw FLOP count from also being echoed as Out[]\n",
    "print_model_param_flops(vgg16_bn_clustered.cpu(),input_res=32);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using iterative TropNNC, variant 1 (T = 0.021), no retrain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading state from epoch 152 and test loss 0.3201139727592468\n",
      "\n",
      "Test set: Average loss: 0.3201, Accuracy: 9364/10000 (94%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Start this experiment from the same RNG state and the pretrained baseline checkpoint\n",
    "set_random_seed(seed=args.seed)\n",
    "vgg16_bn, optimizer = load_model('vgg16_bn', 'sgd', args)\n",
    "orig_loss, orig_acc = test(args, vgg16_bn, device, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prune using T=0.021"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['cnn_layers', '0']\n",
      "['cnn_layers', '3']\n",
      "['cnn_layers', '7']\n",
      "['cnn_layers', '10']\n",
      "['cnn_layers', '14']\n",
      "['cnn_layers', '17']\n",
      "['cnn_layers', '20']\n",
      "['cnn_layers', '24']\n",
      "['cnn_layers', '27']\n",
      "['cnn_layers', '30']\n",
      "['cnn_layers', '34']\n",
      "['cnn_layers', '37']\n",
      "\n",
      "Test set: Average loss: 0.2930, Accuracy: 9120/10000 (91%)\n",
      "\n",
      "Original accuracy 93.64, compressed model accuracy 91.2, accuracy drop 2.4399999999999977\n"
     ]
    }
   ],
   "source": [
    "cluster_args = {\n",
    "    'cluster_layers' : {0:0,3:0,7:0,10:0,14:0,17:0,20:0,24:0,27:0,30:0,34:0,37:0,40:0},\n",
    "    'conv_feature_size' : 1,\n",
    "    'reshape_exists' : True,\n",
    "    'features' : 'both',\n",
    "    'channel_reduction' : 'fro',\n",
    "    'use_bias' : False,\n",
    "    'linkage_method' : 'ward',\n",
    "    'distance_metric' : 'euclidean',\n",
    "    'cluster_criterion' : 'hierarchical',\n",
    "    'distance_threshold' : 0.021,  # T: hierarchical-clustering distance threshold\n",
    "    'merge_criterion' : 'tropnnc3iters',  # iterative TropNNC variant\n",
    "    'tropnnc_features_and_threshold' : True,\n",
    "    'variant' : False,\n",
    "    'verbose' : False\n",
    "}\n",
    "\n",
    "model_modifier = cluster_model(vgg16_bn,cluster_args)\n",
    "vgg16_bn_clustered = model_modifier.cluster_model()\n",
    "# Follow the configured device instead of hard-coding .cuda(), so --no-cuda / CPU-only hosts work\n",
    "vgg16_bn_clustered = vgg16_bn_clustered.to(device)\n",
    "\n",
    "val_loss_no_retrain, val_accuracy_no_retrain = test(args, vgg16_bn_clustered, device, val_loader)\n",
    "\n",
    "print('Original accuracy {}, compressed model accuracy {}, accuracy drop {:.2f}'.format(orig_acc,val_accuracy_no_retrain,orig_acc-val_accuracy_no_retrain))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Calculate #params and #flops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- stats for original model ---\n",
      "  + Number of params: 14.73M\n",
      "  + Number of FLOPs: 0.63G\n",
      "--- stats for compressed model ---\n",
      "  + Number of params: 2.52M\n",
      "  + Number of FLOPs: 0.38G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "378712426.0"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums,print_model_param_flops\n",
    "\n",
    "# NOTE: .cpu() moves each model in place; every experiment section reloads the baseline before clustering\n",
    "print('--- stats for original model ---')\n",
    "print_model_param_nums(vgg16_bn)\n",
    "print_model_param_flops(vgg16_bn.cpu(),input_res=32)\n",
    "\n",
    "print('--- stats for compressed model ---')\n",
    "print_model_param_nums(vgg16_bn_clustered)\n",
    "# Trailing ';' stops the raw FLOP count from also being echoed as Out[]\n",
    "print_model_param_flops(vgg16_bn_clustered.cpu(),input_res=32);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using iterative TropNNC, variant 1 (T = 0.0255), no retrain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading state from epoch 152 and test loss 0.3201139727592468\n",
      "\n",
      "Test set: Average loss: 0.3201, Accuracy: 9364/10000 (94%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Start this experiment from the same RNG state and the pretrained baseline checkpoint\n",
    "set_random_seed(seed=args.seed)\n",
    "vgg16_bn, optimizer = load_model('vgg16_bn', 'sgd', args)\n",
    "orig_loss, orig_acc = test(args, vgg16_bn, device, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prune using T=0.0255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['cnn_layers', '0']\n",
      "['cnn_layers', '3']\n",
      "['cnn_layers', '7']\n",
      "['cnn_layers', '10']\n",
      "['cnn_layers', '14']\n",
      "['cnn_layers', '17']\n",
      "['cnn_layers', '20']\n",
      "['cnn_layers', '24']\n",
      "['cnn_layers', '27']\n",
      "['cnn_layers', '30']\n",
      "['cnn_layers', '34']\n",
      "['cnn_layers', '37']\n",
      "\n",
      "Test set: Average loss: 0.9190, Accuracy: 6995/10000 (70%)\n",
      "\n",
      "Original accuracy 93.64, compressed model accuracy 69.95, accuracy drop 23.689999999999998\n"
     ]
    }
   ],
   "source": [
    "cluster_args = {\n",
    "    'cluster_layers' : {0:0,3:0,7:0,10:0,14:0,17:0,20:0,24:0,27:0,30:0,34:0,37:0,40:0},\n",
    "    'conv_feature_size' : 1,\n",
    "    'reshape_exists' : True,\n",
    "    'features' : 'both',\n",
    "    'channel_reduction' : 'fro',\n",
    "    'use_bias' : False,\n",
    "    'linkage_method' : 'ward',\n",
    "    'distance_metric' : 'euclidean',\n",
    "    'cluster_criterion' : 'hierarchical',\n",
    "    'distance_threshold' : 0.0255,  # T: hierarchical-clustering distance threshold\n",
    "    'merge_criterion' : 'tropnnc3iters',  # iterative TropNNC variant\n",
    "    'tropnnc_features_and_threshold' : True,\n",
    "    'variant' : False,\n",
    "    'verbose' : False\n",
    "}\n",
    "\n",
    "model_modifier = cluster_model(vgg16_bn,cluster_args)\n",
    "vgg16_bn_clustered = model_modifier.cluster_model()\n",
    "# Follow the configured device instead of hard-coding .cuda(), so --no-cuda / CPU-only hosts work\n",
    "vgg16_bn_clustered = vgg16_bn_clustered.to(device)\n",
    "\n",
    "val_loss_no_retrain, val_accuracy_no_retrain = test(args, vgg16_bn_clustered, device, val_loader)\n",
    "\n",
    "print('Original accuracy {}, compressed model accuracy {}, accuracy drop {:.2f}'.format(orig_acc,val_accuracy_no_retrain,orig_acc-val_accuracy_no_retrain))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Calculate #params and #flops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- stats for original model ---\n",
      "  + Number of params: 14.73M\n",
      "  + Number of FLOPs: 0.63G\n",
      "--- stats for compressed model ---\n",
      "  + Number of params: 1.76M\n",
      "  + Number of FLOPs: 0.33G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "326360898.0"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums,print_model_param_flops\n",
    "\n",
    "# NOTE: .cpu() moves each model in place; every experiment section reloads the baseline before clustering\n",
    "print('--- stats for original model ---')\n",
    "print_model_param_nums(vgg16_bn)\n",
    "print_model_param_flops(vgg16_bn.cpu(),input_res=32)\n",
    "\n",
    "print('--- stats for compressed model ---')\n",
    "print_model_param_nums(vgg16_bn_clustered)\n",
    "# Trailing ';' stops the raw FLOP count from also being echoed as Out[]\n",
    "print_model_param_flops(vgg16_bn_clustered.cpu(),input_res=32);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
