{
 "cells": [
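  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compressing a ResNet-18 trained on ImageNet\n",
    "\n",
    "#### To replicate the results with the pretrained model, download the following checkpoint from the CUP repository and place it at `./checkpoints/resnet18_imagenet_pytorch.pth` (the default `--checkpoint_path`):\n",
    "\n",
    "1. resnet18_imagenet_pytorch.pth\n",
    "\n",
    "All accuracies in this notebook are measured on a random subset of 2000 ImageNet validation images, not on the full validation set.\n"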
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpu-data2/kfot/miniconda3/envs/myenv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import sys; sys.argv=['']; \n",
    "sys.path.insert(0, '../')\n",
    "del sys\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
    "\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch.nn.parallel\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.distributed as dist\n",
    "import torch.multiprocessing as mp\n",
    "import torch.utils.data\n",
    "import torch.utils.data.distributed\n",
    "import torch.utils.model_zoo as model_zoo\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "from torchvision import datasets, transforms\n",
    "from tensorboardX import SummaryWriter\n",
    "\n",
    "import numpy as np\n",
    "import random\n",
    "import time\n",
    "import copy\n",
    "\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n",
    "\n",
    "from src.imagenet_utils import train,validate,save_checkpoint,AverageMeter,ProgressMeter\n",
    "from src.imagenet_utils import adjust_learning_rate,accuracy,adjust_learning_rate_pytorch_retrain\n",
    "from src.utils import plot_tsne,fancy_dendrogram,save_obj,load_obj,weights_init\n",
    "from src.model import VGG,load_model\n",
    "from src.prune_model import prune_model\n",
    "from src.cluster_model import cluster_model, fuse_model\n",
    "from src.train_test import adjust_learning_rate_nips,adjust_learning_rate_iccv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch version: 1.8.1\n",
      "torchvision version: 0.9.1\n",
      "CUDA available: True\n",
      "CUDA device: NVIDIA GeForce RTX 2080 Ti\n"
     ]
    }
   ],
   "source": [
    "# Record the library / hardware versions these results were produced with.\n",
    "cuda_ok = torch.cuda.is_available()\n",
    "print(f\"PyTorch version: {torch.__version__}\")\n",
    "print(f\"torchvision version: {torchvision.__version__}\")\n",
    "print(f\"CUDA available: {cuda_ok}\")\n",
    "print(f\"CUDA device: {torch.cuda.get_device_name(0)}\" if cuda_ok else \"No CUDA device available\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Specify the ImageNet data path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')\n",
    "\n",
    "### REPLACE THE PATH WITH YOUR OWN PATH TO IMAGENET (or set the IMAGENET_PATH environment variable) ###\n",
    "path_to_imagenet = os.environ.get('IMAGENET_PATH', '/path/to/imagenet/ILSVRC/Data/CLS-LOC/')\n",
    "\n",
    "### add path to dataset here #####\n",
    "# trailing ';' suppresses the argparse Action repr in the cell output\n",
    "parser.add_argument('--data', type=str, default=path_to_imagenet, metavar='S', help='path to dataset');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remaining command-line options (the parser and --data are created in the previous cell).\n",
    "parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n",
    "                    help='manual epoch number (useful on restarts)')\n",
    "parser.add_argument('-b', '--batch-size', default=256, type=int,\n",
    "                    metavar='N',\n",
    "                    help='mini-batch size (default: 256), this is the total '\n",
    "                         'batch size of all GPUs on the current node when '\n",
    "                         'using Data Parallel or Distributed Data Parallel')\n",
    "parser.add_argument('-j', '--workers', default=6, type=int, metavar='N',\n",
    "                    help='number of data loading workers (default: 6)')\n",
    "parser.add_argument('--epochs', default=90, type=int, metavar='N',\n",
    "                    help='number of total epochs to run')\n",
    "parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n",
    "                    metavar='LR', help='initial learning rate')\n",
    "parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n",
    "                    help='momentum')\n",
    "parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n",
    "                    metavar='W', help='weight decay (default: 1e-4)')\n",
    "parser.add_argument('-p', '--print-freq', default=500, type=int,\n",
    "                    metavar='N', help='print frequency (default: 500)')\n",
    "parser.add_argument('--resume', default='', type=str, metavar='PATH',\n",
    "                    help='path to latest checkpoint (default: none)')\n",
    "parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',\n",
    "                    help='evaluate model on validation set')\n",
    "parser.add_argument('--pretrained', dest='pretrained', action='store_true',\n",
    "                    help='use pre-trained model')\n",
    "parser.add_argument('--world-size', default=-1, type=int,\n",
    "                    help='number of nodes for distributed training')\n",
    "parser.add_argument('--rank', default=-1, type=int,\n",
    "                    help='node rank for distributed training')\n",
    "parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,\n",
    "                    help='url used to set up distributed training')\n",
    "parser.add_argument('--dist-backend', default='nccl', type=str,\n",
    "                    help='distributed backend')\n",
    "parser.add_argument('--no-cuda', action='store_true', default=False,\n",
    "                    help='disables CUDA training')\n",
    "parser.add_argument('--seed', type=int, default=12346, metavar='S',\n",
    "                    help='random seed (default: 12346)')\n",
    "# NOTE(review): ImageNet has 1000 classes; a default of 10 looks like a CIFAR-10 leftover.\n",
    "# Confirm whether anything in src/ reads args.num_output before changing it.\n",
    "parser.add_argument('--num_output', type=int, default=10, metavar='S',\n",
    "                    help='number of classes(default: 10)')\n",
    "parser.add_argument('--log-interval', type=int, default=100, metavar='N',\n",
    "                    help='how many batches to wait before logging training status')\n",
    "parser.add_argument('--checkpoint_path', type=str, default='./checkpoints/resnet18_imagenet_pytorch.pth', metavar='S',\n",
    "                    help='path to store model training checkpoints')\n",
    "parser.add_argument('--gpu', type=int, default=0, nargs='+', help='used gpu')\n",
    "parser.add_argument('--multiprocessing-distributed', action='store_true',\n",
    "                    help='Use multi-processing distributed training to launch '\n",
    "                         'N processes per node, which has N GPUs. This is the '\n",
    "                         'fastest way to use PyTorch for either single node or '\n",
    "                         'multi node data parallel training')\n",
    "\n",
    "args = parser.parse_args()\n",
    "\n",
    "use_cuda = not args.no_cuda and torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "\n",
    "def set_random_seed(seed):\n",
    "    \"\"\"Seed Python's `random`, NumPy and PyTorch (CPU + all GPUs) and request deterministic cuDNN kernels.\n",
    "\n",
    "    Args:\n",
    "        seed: integer seed applied to every RNG listed above.\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    np.random.seed(seed)  # use the argument, not the global args.seed\n",
    "    # PYTHONHASHSEED is read at interpreter start-up, so this only affects interpreters launched later.\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "\n",
    "# Seed *before* drawing the validation subset below; otherwise the subset, and\n",
    "# every accuracy measured on it, changes from run to run.\n",
    "set_random_seed(args.seed)\n",
    "\n",
    "# Data loading code\n",
    "traindir = os.path.join(args.data, 'train')\n",
    "valdir = os.path.join(args.data, 'val')\n",
    "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                 std=[0.229, 0.224, 0.225])\n",
    "\n",
    "train_dataset = datasets.ImageFolder(\n",
    "    traindir,\n",
    "    transforms.Compose([\n",
    "        transforms.RandomResizedCrop(224),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        normalize,\n",
    "    ]))\n",
    "\n",
    "val_dataset = datasets.ImageFolder(\n",
    "    valdir,\n",
    "    transforms.Compose([\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        normalize,\n",
    "    ]))\n",
    "\n",
    "# `val_loader` below draws from a random subset of the validation set (sampled\n",
    "# without replacement), so accuracies measured with it are subset accuracies.\n",
    "VAL_SUBSET_SIZE = 2000\n",
    "total_size = len(val_dataset)\n",
    "subset_size = min(VAL_SUBSET_SIZE, total_size)  # choice(replace=False) raises if size > population\n",
    "indices = np.random.choice(total_size, subset_size, replace=False)\n",
    "val_subset = torch.utils.data.Subset(val_dataset, indices)\n",
    "\n",
    "train_sampler = None\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),\n",
    "    num_workers=args.workers, pin_memory=True, sampler=train_sampler)\n",
    "\n",
    "# Loader over the full validation set, for a full-set evaluation when needed.\n",
    "full_val_loader = torch.utils.data.DataLoader(\n",
    "    val_dataset,\n",
    "    batch_size=args.batch_size, shuffle=False,\n",
    "    num_workers=args.workers, pin_memory=True)\n",
    "\n",
    "# Loader over the validation subset drawn above.\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    val_subset,\n",
    "    batch_size=args.batch_size, shuffle=False,\n",
    "    num_workers=args.workers, pin_memory=True)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss().to(device)\n",
    "\n",
    "writer = SummaryWriter('logs/resnet18_imagenet/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load (or create) the baseline model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 69.750 Acc@5 89.800\n",
      "loaded model with top1 : 69.75, top5 : 89.80000305175781\n"
     ]
    }
   ],
   "source": [
    "set_random_seed(args.seed)\n",
    "\n",
    "args.lr = 0.1  # learning rate for the optional fine-tuning below\n",
    "\n",
    "# The baseline is torchvision's ImageNet-pretrained ResNet-18. It is written to\n",
    "# args.checkpoint_path only when no checkpoint exists there yet, so a checkpoint\n",
    "# downloaded from the CUP repository is used as-is instead of being overwritten.\n",
    "FINETUNE_BASELINE = False  # True: fine-tune the pretrained weights before the compression runs\n",
    "FINETUNE_EPOCHS = 5\n",
    "\n",
    "if not os.path.isfile(args.checkpoint_path):\n",
    "    os.makedirs(os.path.dirname(args.checkpoint_path) or '.', exist_ok=True)\n",
    "    resnet18 = torchvision.models.resnet18(pretrained=True).to(device)\n",
    "    torch.save(resnet18, args.checkpoint_path, pickle_protocol=4)\n",
    "\n",
    "    if FINETUNE_BASELINE:\n",
    "        optimizer = torch.optim.SGD(resnet18.parameters(), args.lr,\n",
    "                                    momentum=args.momentum,\n",
    "                                    weight_decay=args.weight_decay)\n",
    "        # Start from the pretrained accuracy so a worse epoch never replaces the checkpoint.\n",
    "        _, best_val_acc, _ = validate(val_loader, resnet18, criterion, args, verbose=False)\n",
    "        for epoch in range(FINETUNE_EPOCHS):\n",
    "            adjust_learning_rate_pytorch_retrain(optimizer, epoch, args)\n",
    "\n",
    "            # train for one epoch\n",
    "            train_loss, train_top1, train_top5 = train(train_loader, resnet18, criterion, optimizer, epoch, args)\n",
    "\n",
    "            # evaluate on validation set\n",
    "            val_loss, val_top1, val_top5 = validate(val_loader, resnet18, criterion, args)\n",
    "\n",
    "            if val_top1 > best_val_acc:\n",
    "                torch.save(resnet18, args.checkpoint_path, pickle_protocol=4)\n",
    "                best_val_acc = val_top1\n",
    "\n",
    "            writer.add_scalars('resnet18_imagenet_pytorch_schedule/loss', {'train_loss': train_loss,\n",
    "                                                                          'val_loss': val_loss}, epoch)\n",
    "            writer.add_scalars('resnet18_imagenet_pytorch_schedule/accuracy', {'train_top1': train_top1,\n",
    "                                                                              'val_top1': val_top1,\n",
    "                                                                              'train_top5': train_top5,\n",
    "                                                                              'val_top5': val_top5}, epoch)\n",
    "\n",
    "# Always reload from disk, so this cell reports exactly the weights later sections load.\n",
    "# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.\n",
    "resnet18 = torch.load(args.checkpoint_path, map_location=device).to(device)\n",
    "_, val_top1, val_top5 = validate(val_loader, resnet18, criterion, args, verbose=False)\n",
    "\n",
    "print('loaded model with top1 : {}, top5 : {}'.format(val_top1, val_top5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using CUP (T = 0.5)\n",
    "\n",
    "- This section compresses the baseline ResNet-18 stored at `args.checkpoint_path`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 69.750 Acc@5 89.800\n",
      "loaded model with top1 : 69.75, top5 : 89.80000305175781\n"
     ]
    }
   ],
   "source": [
    "# Reload the uncompressed baseline from disk so this section starts from the stored weights.\n",
    "# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.\n",
    "resnet18 = torch.load(args.checkpoint_path, map_location=device).to(device)\n",
    "_,val_top1,val_top5 = validate(val_loader, resnet18, criterion, args, verbose=False)\n",
    "# NOTE(review): the baseline cell compares best_val_acc against top-1; this stores top-5 -- confirm which is intended.\n",
    "best_val_acc = val_top5\n",
    "print('loaded model with top1 : {}, top5 : {}'.format(val_top1,val_top5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['layer1', '0', 'conv1']\n",
      "['layer1', '1', 'conv1']\n",
      "['layer2', '0', 'conv1']\n",
      "['layer2', '1', 'conv1']\n",
      "['layer3', '0', 'conv1']\n",
      "['layer3', '1', 'conv1']\n",
      "['layer4', '0', 'conv1']\n",
      "['layer4', '1', 'conv1']\n",
      " * Acc@1 66.800 Acc@5 86.600\n",
      "large top-1 69.75, small top-1 66.80, top-1 drop 2.95, large top-5 89.80, small top-5 86.60, top-5 drop 3.20\n"
     ]
    }
   ],
   "source": [
    "cluster_args = dict(\n",
    "    # indices of the layers to cluster; the cell output lists them as layerX.Y.conv1\n",
    "    cluster_layers={4: 0, 9: 0, 14: 0, 21: 0, 26: 0, 33: 0, 38: 0, 45: 0},\n",
    "    conv_feature_size=1,\n",
    "    features='both',\n",
    "    channel_reduction='fro',\n",
    "    use_bias=False,\n",
    "    reshape_exists=False,\n",
    "    linkage_method='ward',\n",
    "    distance_metric='euclidean',\n",
    "    cluster_criterion='hierarchical',\n",
    "    distance_threshold=0.5,  # the threshold T from the section title\n",
    "    merge_criterion='max_l2_norm',\n",
    "    tropnnc_features_and_threshold=False,\n",
    "    variant=False,\n",
    "    verbose=False,\n",
    ")\n",
    "\n",
    "# Build the compressed network from the baseline and move it to the working device.\n",
    "model_modifier = cluster_model(resnet18, cluster_args)\n",
    "resnet18_clustered = model_modifier.cluster_model().to(device)\n",
    "\n",
    "# Accuracy of the compressed model with no retraining.\n",
    "_, top1_acc_no_retrain, top5_acc_no_retrain = validate(val_loader, resnet18_clustered, criterion, args, verbose=False)\n",
    "\n",
    "print('large top-1 {:.2f}, small top-1 {:.2f}, top-1 drop {:.2f}, large top-5 {:.2f}, small top-5 {:.2f}, top-5 drop {:.2f}'.format(\n",
    "    val_top1, top1_acc_no_retrain, val_top1 - top1_acc_no_retrain,\n",
    "    val_top5, top5_acc_no_retrain, val_top5 - top5_acc_no_retrain))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  + Number of params: 11.69M\n",
      "  + Number of params: 11.66M\n",
      "  + Number of FLOPs: 3.64G\n",
      "  + Number of FLOPs: 3.58G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "3581177856.0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums, print_model_param_flops\n",
    "\n",
    "# Parameter counts: uncompressed baseline vs. clustered model.\n",
    "print_model_param_nums(resnet18)\n",
    "print_model_param_nums(resnet18_clustered)\n",
    "\n",
    "# FLOPs are counted on throw-away CPU copies: `.cpu()` moves a module in place, so calling\n",
    "# it on the working models would silently leave both off `device` for later cells.\n",
    "# The last expression (FLOPs of the clustered model) is shown as the cell output.\n",
    "print_model_param_flops(copy.deepcopy(resnet18).cpu(), input_res=224)\n",
    "print_model_param_flops(copy.deepcopy(resnet18_clustered).cpu(), input_res=224)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using CUP (T = 0.6)\n",
    "\n",
    "- This section compresses the baseline ResNet-18 stored at `args.checkpoint_path`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 69.750 Acc@5 89.800\n",
      "loaded model with top1 : 69.75, top5 : 89.80000305175781\n"
     ]
    }
   ],
   "source": [
    "# Reload the uncompressed baseline from disk so this section starts from the stored weights.\n",
    "# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.\n",
    "resnet18 = torch.load(args.checkpoint_path, map_location=device).to(device)\n",
    "_,val_top1,val_top5 = validate(val_loader, resnet18, criterion, args, verbose=False)\n",
    "# NOTE(review): the baseline cell compares best_val_acc against top-1; this stores top-5 -- confirm which is intended.\n",
    "best_val_acc = val_top5\n",
    "print('loaded model with top1 : {}, top5 : {}'.format(val_top1,val_top5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['layer1', '0', 'conv1']\n",
      "['layer1', '1', 'conv1']\n",
      "['layer2', '0', 'conv1']\n",
      "['layer2', '1', 'conv1']\n",
      "['layer3', '0', 'conv1']\n",
      "['layer3', '1', 'conv1']\n",
      "['layer4', '0', 'conv1']\n",
      "['layer4', '1', 'conv1']\n",
      " * Acc@1 53.400 Acc@5 78.000\n",
      "large top-1 69.75, small top-1 53.40, top-1 drop 16.35, large top-5 89.80, small top-5 78.00, top-5 drop 11.80\n"
     ]
    }
   ],
   "source": [
    "cluster_args = dict(\n",
    "    # indices of the layers to cluster; the cell output lists them as layerX.Y.conv1\n",
    "    cluster_layers={4: 0, 9: 0, 14: 0, 21: 0, 26: 0, 33: 0, 38: 0, 45: 0},\n",
    "    conv_feature_size=1,\n",
    "    features='both',\n",
    "    channel_reduction='fro',\n",
    "    use_bias=False,\n",
    "    reshape_exists=False,\n",
    "    linkage_method='ward',\n",
    "    distance_metric='euclidean',\n",
    "    cluster_criterion='hierarchical',\n",
    "    distance_threshold=0.6,  # the threshold T from the section title\n",
    "    merge_criterion='max_l2_norm',\n",
    "    tropnnc_features_and_threshold=False,\n",
    "    variant=False,\n",
    "    verbose=False,\n",
    ")\n",
    "\n",
    "# Build the compressed network from the baseline and move it to the working device.\n",
    "model_modifier = cluster_model(resnet18, cluster_args)\n",
    "resnet18_clustered = model_modifier.cluster_model().to(device)\n",
    "\n",
    "# Accuracy of the compressed model with no retraining.\n",
    "_, top1_acc_no_retrain, top5_acc_no_retrain = validate(val_loader, resnet18_clustered, criterion, args, verbose=False)\n",
    "\n",
    "print('large top-1 {:.2f}, small top-1 {:.2f}, top-1 drop {:.2f}, large top-5 {:.2f}, small top-5 {:.2f}, top-5 drop {:.2f}'.format(\n",
    "    val_top1, top1_acc_no_retrain, val_top1 - top1_acc_no_retrain,\n",
    "    val_top5, top5_acc_no_retrain, val_top5 - top5_acc_no_retrain))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  + Number of params: 11.69M\n",
      "  + Number of params: 11.49M\n",
      "  + Number of FLOPs: 3.64G\n",
      "  + Number of FLOPs: 3.38G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "3382521321.0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums, print_model_param_flops\n",
    "\n",
    "# Parameter counts: uncompressed baseline vs. clustered model.\n",
    "print_model_param_nums(resnet18)\n",
    "print_model_param_nums(resnet18_clustered)\n",
    "\n",
    "# FLOPs are counted on throw-away CPU copies: `.cpu()` moves a module in place, so calling\n",
    "# it on the working models would silently leave both off `device` for later cells.\n",
    "# The last expression (FLOPs of the clustered model) is shown as the cell output.\n",
    "print_model_param_flops(copy.deepcopy(resnet18).cpu(), input_res=224)\n",
    "print_model_param_flops(copy.deepcopy(resnet18_clustered).cpu(), input_res=224)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using CUP (T = 0.65)\n",
    "\n",
    "- This section compresses the baseline ResNet-18 stored at `args.checkpoint_path`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 69.750 Acc@5 89.800\n",
      "loaded model with top1 : 69.75, top5 : 89.80000305175781\n"
     ]
    }
   ],
   "source": [
    "# Reload the uncompressed baseline from disk so this section starts from the stored weights.\n",
    "# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.\n",
    "resnet18 = torch.load(args.checkpoint_path, map_location=device).to(device)\n",
    "_,val_top1,val_top5 = validate(val_loader, resnet18, criterion, args, verbose=False)\n",
    "# NOTE(review): the baseline cell compares best_val_acc against top-1; this stores top-5 -- confirm which is intended.\n",
    "best_val_acc = val_top5\n",
    "print('loaded model with top1 : {}, top5 : {}'.format(val_top1,val_top5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['layer1', '0', 'conv1']\n",
      "['layer1', '1', 'conv1']\n",
      "['layer2', '0', 'conv1']\n",
      "['layer2', '1', 'conv1']\n",
      "['layer3', '0', 'conv1']\n",
      "['layer3', '1', 'conv1']\n",
      "['layer4', '0', 'conv1']\n",
      "['layer4', '1', 'conv1']\n",
      " * Acc@1 28.450 Acc@5 50.950\n",
      "large top-1 69.75, small top-1 28.45, top-1 drop 41.30, large top-5 89.80, small top-5 50.95, top-5 drop 38.85\n"
     ]
    }
   ],
   "source": [
    "cluster_args = dict(\n",
    "    # indices of the layers to cluster; the cell output lists them as layerX.Y.conv1\n",
    "    cluster_layers={4: 0, 9: 0, 14: 0, 21: 0, 26: 0, 33: 0, 38: 0, 45: 0},\n",
    "    conv_feature_size=1,\n",
    "    features='both',\n",
    "    channel_reduction='fro',\n",
    "    use_bias=False,\n",
    "    reshape_exists=False,\n",
    "    linkage_method='ward',\n",
    "    distance_metric='euclidean',\n",
    "    cluster_criterion='hierarchical',\n",
    "    distance_threshold=0.65,  # the threshold T from the section title\n",
    "    merge_criterion='max_l2_norm',\n",
    "    tropnnc_features_and_threshold=False,\n",
    "    variant=False,\n",
    "    verbose=False,\n",
    ")\n",
    "\n",
    "# Build the compressed network from the baseline and move it to the working device.\n",
    "model_modifier = cluster_model(resnet18, cluster_args)\n",
    "resnet18_clustered = model_modifier.cluster_model().to(device)\n",
    "\n",
    "# Accuracy of the compressed model with no retraining.\n",
    "_, top1_acc_no_retrain, top5_acc_no_retrain = validate(val_loader, resnet18_clustered, criterion, args, verbose=False)\n",
    "\n",
    "print('large top-1 {:.2f}, small top-1 {:.2f}, top-1 drop {:.2f}, large top-5 {:.2f}, small top-5 {:.2f}, top-5 drop {:.2f}'.format(\n",
    "    val_top1, top1_acc_no_retrain, val_top1 - top1_acc_no_retrain,\n",
    "    val_top5, top5_acc_no_retrain, val_top5 - top5_acc_no_retrain))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  + Number of params: 11.69M\n",
      "  + Number of params: 11.17M\n",
      "  + Number of FLOPs: 3.64G\n",
      "  + Number of FLOPs: 3.15G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "3150679389.0"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums, print_model_param_flops\n",
    "\n",
    "# Parameter counts: uncompressed baseline vs. clustered model.\n",
    "print_model_param_nums(resnet18)\n",
    "print_model_param_nums(resnet18_clustered)\n",
    "\n",
    "# FLOPs are counted on throw-away CPU copies: `.cpu()` moves a module in place, so calling\n",
    "# it on the working models would silently leave both off `device` for later cells.\n",
    "# The last expression (FLOPs of the clustered model) is shown as the cell output.\n",
    "print_model_param_flops(copy.deepcopy(resnet18).cpu(), input_res=224)\n",
    "print_model_param_flops(copy.deepcopy(resnet18_clustered).cpu(), input_res=224)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using TropNNC, variant 1 (T = 0.0187)\n",
    "\n",
    "- This section compresses the baseline ResNet-18 stored at `args.checkpoint_path`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 69.750 Acc@5 89.800\n",
      "loaded model with top1 : 69.75, top5 : 89.80000305175781\n"
     ]
    }
   ],
   "source": [
    "# Reload the uncompressed baseline from disk so this section starts from the stored weights.\n",
    "# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.\n",
    "resnet18 = torch.load(args.checkpoint_path, map_location=device).to(device)\n",
    "_,val_top1,val_top5 = validate(val_loader, resnet18, criterion, args, verbose=False)\n",
    "# NOTE(review): the baseline cell compares best_val_acc against top-1; this stores top-5 -- confirm which is intended.\n",
    "best_val_acc = val_top5\n",
    "print('loaded model with top1 : {}, top5 : {}'.format(val_top1,val_top5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['layer1', '0', 'conv1']\n",
      "['layer1', '1', 'conv1']\n",
      "['layer2', '0', 'conv1']\n",
      "['layer2', '1', 'conv1']\n",
      "['layer3', '0', 'conv1']\n",
      "['layer3', '1', 'conv1']\n",
      "['layer4', '0', 'conv1']\n",
      "['layer4', '1', 'conv1']\n",
      " * Acc@1 59.150 Acc@5 83.600\n",
      "large top-1 69.75, small top-1 59.15, top-1 drop 10.60, large top-5 89.80, small top-5 83.60, top-5 drop 6.20\n"
     ]
    }
   ],
   "source": [
    "cluster_args = dict(\n",
    "    # indices of the layers to cluster; the cell output lists them as layerX.Y.conv1\n",
    "    cluster_layers={4: 0, 9: 0, 14: 0, 21: 0, 26: 0, 33: 0, 38: 0, 45: 0},\n",
    "    conv_feature_size=1,\n",
    "    features='both',\n",
    "    channel_reduction='fro',\n",
    "    use_bias=False,\n",
    "    reshape_exists=False,\n",
    "    linkage_method='ward',\n",
    "    distance_metric='euclidean',\n",
    "    cluster_criterion='hierarchical',\n",
    "    distance_threshold=0.0187,  # the threshold T from the section title\n",
    "    merge_criterion='tropnnc',\n",
    "    tropnnc_features_and_threshold=True,\n",
    "    variant=False,\n",
    "    verbose=False,\n",
    ")\n",
    "\n",
    "# Build the compressed network from the baseline and move it to the working device.\n",
    "model_modifier = cluster_model(resnet18, cluster_args)\n",
    "resnet18_clustered = model_modifier.cluster_model().to(device)\n",
    "\n",
    "# Accuracy of the compressed model with no retraining.\n",
    "_, top1_acc_no_retrain, top5_acc_no_retrain = validate(val_loader, resnet18_clustered, criterion, args, verbose=False)\n",
    "\n",
    "print('large top-1 {:.2f}, small top-1 {:.2f}, top-1 drop {:.2f}, large top-5 {:.2f}, small top-5 {:.2f}, top-5 drop {:.2f}'.format(\n",
    "    val_top1, top1_acc_no_retrain, val_top1 - top1_acc_no_retrain,\n",
    "    val_top5, top5_acc_no_retrain, val_top5 - top5_acc_no_retrain))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  + Number of params: 11.69M\n",
      "  + Number of params: 9.84M\n",
      "  + Number of FLOPs: 3.64G\n",
      "  + Number of FLOPs: 3.46G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "3455823704.0"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums, print_model_param_flops\n",
    "\n",
    "# Parameter counts: uncompressed baseline vs. clustered model.\n",
    "print_model_param_nums(resnet18)\n",
    "print_model_param_nums(resnet18_clustered)\n",
    "\n",
    "# FLOPs are counted on throw-away CPU copies: `.cpu()` moves a module in place, so calling\n",
    "# it on the working models would silently leave both off `device` for later cells.\n",
    "# The last expression (FLOPs of the clustered model) is shown as the cell output.\n",
    "print_model_param_flops(copy.deepcopy(resnet18).cpu(), input_res=224)\n",
    "print_model_param_flops(copy.deepcopy(resnet18_clustered).cpu(), input_res=224)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using TropNNC, variant 2 (T = 1.1)\n",
    "\n",
    "- This section compresses the baseline ResNet-18 stored at `args.checkpoint_path`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 69.750 Acc@5 89.800\n",
      "loaded model with top1 : 69.75, top5 : 89.80000305175781\n"
     ]
    }
   ],
   "source": [
    "# Reload the uncompressed baseline from disk so this section starts from the stored weights.\n",
    "# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.\n",
    "resnet18 = torch.load(args.checkpoint_path, map_location=device).to(device)\n",
    "_,val_top1,val_top5 = validate(val_loader, resnet18, criterion, args, verbose=False)\n",
    "# NOTE(review): the baseline cell compares best_val_acc against top-1; this stores top-5 -- confirm which is intended.\n",
    "best_val_acc = val_top5\n",
    "print('loaded model with top1 : {}, top5 : {}'.format(val_top1,val_top5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['layer1', '0', 'conv1']\n",
      "['layer1', '1', 'conv1']\n",
      "['layer2', '0', 'conv1']\n",
      "['layer2', '1', 'conv1']\n",
      "['layer3', '0', 'conv1']\n",
      "['layer3', '1', 'conv1']\n",
      "['layer4', '0', 'conv1']\n",
      "['layer4', '1', 'conv1']\n",
      " * Acc@1 68.950 Acc@5 89.050\n",
      "large top-1 69.75, small top-1 68.95, top-1 drop 0.80, large top-5 89.80, small top-5 89.05, top-5 drop 0.75\n"
     ]
    }
   ],
   "source": [
    "cluster_args = dict(\n",
    "    # indices of the layers to cluster; the cell output lists them as layerX.Y.conv1\n",
    "    cluster_layers={4: 0, 9: 0, 14: 0, 21: 0, 26: 0, 33: 0, 38: 0, 45: 0},\n",
    "    conv_feature_size=1,\n",
    "    features='both',\n",
    "    channel_reduction='fro',\n",
    "    use_bias=False,\n",
    "    reshape_exists=False,\n",
    "    linkage_method='ward',\n",
    "    distance_metric='euclidean',\n",
    "    cluster_criterion='hierarchical',\n",
    "    distance_threshold=1.1,  # the threshold T from the section title\n",
    "    merge_criterion='tropnnc',\n",
    "    tropnnc_features_and_threshold=True,\n",
    "    variant=True,\n",
    "    verbose=False,\n",
    ")\n",
    "\n",
    "# Build the compressed network from the baseline and move it to the working device.\n",
    "model_modifier = cluster_model(resnet18, cluster_args)\n",
    "resnet18_clustered = model_modifier.cluster_model().to(device)\n",
    "\n",
    "# Accuracy of the compressed model with no retraining.\n",
    "_, top1_acc_no_retrain, top5_acc_no_retrain = validate(val_loader, resnet18_clustered, criterion, args, verbose=False)\n",
    "\n",
    "print('large top-1 {:.2f}, small top-1 {:.2f}, top-1 drop {:.2f}, large top-5 {:.2f}, small top-5 {:.2f}, top-5 drop {:.2f}'.format(\n",
    "    val_top1, top1_acc_no_retrain, val_top1 - top1_acc_no_retrain,\n",
    "    val_top5, top5_acc_no_retrain, val_top5 - top5_acc_no_retrain))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  + Number of params: 11.69M\n",
      "  + Number of params: 11.66M\n",
      "  + Number of FLOPs: 3.64G\n",
      "  + Number of FLOPs: 3.48G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "3476479752.0"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums,print_model_param_flops\n",
    "\n",
    "# Parameter counts: original model first, then the clustered one.\n",
    "print_model_param_nums(resnet18)\n",
    "print_model_param_nums(resnet18_clustered)\n",
    "\n",
    "# FLOPs at a 224x224 input. nn.Module.cpu() moves a model in place, so profile\n",
    "# deep copies and leave resnet18 / resnet18_clustered on `device`.\n",
    "# The trailing ';' suppresses echoing the returned FLOP count.\n",
    "print_model_param_flops(copy.deepcopy(resnet18).cpu(),input_res=224)\n",
    "print_model_param_flops(copy.deepcopy(resnet18_clustered).cpu(),input_res=224);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using TropNNC, variant 2 (T = 1.2)\n",
    "\n",
    "- This section compresses the pre-trained ResNet-18 checkpoint with distance threshold $T = 1.2$ and evaluates the result without retraining."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 69.750 Acc@5 89.800\n",
      "loaded model with top1 : 69.75, top5 : 89.80000305175781\n"
     ]
    }
   ],
   "source": [
    "# Reload the uncompressed baseline from the checkpoint so this section starts fresh.\n",
    "# torch.load unpickles the full model object: only load trusted checkpoint files.\n",
    "# map_location lets a checkpoint saved on another device load directly onto `device`.\n",
    "resnet18 = torch.load(args.checkpoint_path, map_location=device).to(device)\n",
    "_,val_top1,val_top5 = validate(val_loader, resnet18, criterion, args, verbose=False)\n",
    "best_val_acc = val_top5  # NOTE(review): tracks top-5, not top-1 -- confirm intended\n",
    "print('loaded model with top1 : {:.2f}, top5 : {:.2f}'.format(val_top1,val_top5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['layer1', '0', 'conv1']\n",
      "['layer1', '1', 'conv1']\n",
      "['layer2', '0', 'conv1']\n",
      "['layer2', '1', 'conv1']\n",
      "['layer3', '0', 'conv1']\n",
      "['layer3', '1', 'conv1']\n",
      "['layer4', '0', 'conv1']\n",
      "['layer4', '1', 'conv1']\n",
      " * Acc@1 61.300 Acc@5 83.450\n",
      "large top-1 69.75, small top-1 61.30, top-1 drop 8.45, large top-5 89.80, small top-5 83.45, top-5 drop 6.35\n"
     ]
    }
   ],
   "source": [
    "# Cluster the layers listed in 'cluster_layers' and merge the clusters with the\n",
    "# TropNNC criterion. The printed layer names in the output appear to be the conv1\n",
    "# of each residual block. Only DISTANCE_THRESHOLD changes between sections.\n",
    "DISTANCE_THRESHOLD = 1.2\n",
    "\n",
    "cluster_args = {\n",
    "    'cluster_layers' : {4:0,9:0,14:0,21:0,26:0,33:0,38:0,45:0},\n",
    "    'conv_feature_size' : 1,\n",
    "    'features' : 'both',\n",
    "    'channel_reduction' : 'fro',\n",
    "    'use_bias' : False,\n",
    "    'reshape_exists' : False,\n",
    "    'linkage_method' : 'ward',\n",
    "    'distance_metric' : 'euclidean',\n",
    "    'cluster_criterion' : 'hierarchical',\n",
    "    'distance_threshold' : DISTANCE_THRESHOLD,\n",
    "    'merge_criterion' : 'tropnnc',\n",
    "    'tropnnc_features_and_threshold' : True,\n",
    "    'variant' : True,\n",
    "    'verbose' : False\n",
    "}\n",
    "\n",
    "model_modifier = cluster_model(resnet18,cluster_args)\n",
    "resnet18_clustered = model_modifier.cluster_model().to(device)\n",
    "\n",
    "# Evaluate the compressed model directly, without any fine-tuning.\n",
    "_,top1_acc_no_retrain,top5_acc_no_retrain = validate(val_loader, resnet18_clustered, criterion, args, verbose=False)\n",
    "\n",
    "# 'large' = original model, 'small' = clustered model; drop = large - small.\n",
    "print('large top-1 {:.2f}, small top-1 {:.2f}, top-1 drop {:.2f}, '\n",
    "      'large top-5 {:.2f}, small top-5 {:.2f}, top-5 drop {:.2f}'.format(\n",
    "          val_top1, top1_acc_no_retrain, val_top1 - top1_acc_no_retrain,\n",
    "          val_top5, top5_acc_no_retrain, val_top5 - top5_acc_no_retrain))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  + Number of params: 11.69M\n",
      "  + Number of params: 11.46M\n",
      "  + Number of FLOPs: 3.64G\n",
      "  + Number of FLOPs: 3.25G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "3249418750.0"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums,print_model_param_flops\n",
    "\n",
    "# Parameter counts: original model first, then the clustered one.\n",
    "print_model_param_nums(resnet18)\n",
    "print_model_param_nums(resnet18_clustered)\n",
    "\n",
    "# FLOPs at a 224x224 input. nn.Module.cpu() moves a model in place, so profile\n",
    "# deep copies and leave resnet18 / resnet18_clustered on `device`.\n",
    "# The trailing ';' suppresses echoing the returned FLOP count.\n",
    "print_model_param_flops(copy.deepcopy(resnet18).cpu(),input_res=224)\n",
    "print_model_param_flops(copy.deepcopy(resnet18_clustered).cpu(),input_res=224);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compress using TropNNC, variant 2 (T = 1.25)\n",
    "\n",
    "- This section compresses the pre-trained ResNet-18 checkpoint with distance threshold $T = 1.25$ and evaluates the result without retraining."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 69.750 Acc@5 89.800\n",
      "loaded model with top1 : 69.75, top5 : 89.80000305175781\n"
     ]
    }
   ],
   "source": [
    "# Reload the uncompressed baseline from the checkpoint so this section starts fresh.\n",
    "# torch.load unpickles the full model object: only load trusted checkpoint files.\n",
    "# map_location lets a checkpoint saved on another device load directly onto `device`.\n",
    "resnet18 = torch.load(args.checkpoint_path, map_location=device).to(device)\n",
    "_,val_top1,val_top5 = validate(val_loader, resnet18, criterion, args, verbose=False)\n",
    "best_val_acc = val_top5  # NOTE(review): tracks top-5, not top-1 -- confirm intended\n",
    "print('loaded model with top1 : {:.2f}, top5 : {:.2f}'.format(val_top1,val_top5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['layer1', '0', 'conv1']\n",
      "['layer1', '1', 'conv1']\n",
      "['layer2', '0', 'conv1']\n",
      "['layer2', '1', 'conv1']\n",
      "['layer3', '0', 'conv1']\n",
      "['layer3', '1', 'conv1']\n",
      "['layer4', '0', 'conv1']\n",
      "['layer4', '1', 'conv1']\n",
      " * Acc@1 41.650 Acc@5 66.150\n",
      "large top-1 69.75, small top-1 41.65, top-1 drop 28.10, large top-5 89.80, small top-5 66.15, top-5 drop 23.65\n"
     ]
    }
   ],
   "source": [
    "# Cluster the layers listed in 'cluster_layers' and merge the clusters with the\n",
    "# TropNNC criterion. The printed layer names in the output appear to be the conv1\n",
    "# of each residual block. Only DISTANCE_THRESHOLD changes between sections.\n",
    "DISTANCE_THRESHOLD = 1.25\n",
    "\n",
    "cluster_args = {\n",
    "    'cluster_layers' : {4:0,9:0,14:0,21:0,26:0,33:0,38:0,45:0},\n",
    "    'conv_feature_size' : 1,\n",
    "    'features' : 'both',\n",
    "    'channel_reduction' : 'fro',\n",
    "    'use_bias' : False,\n",
    "    'reshape_exists' : False,\n",
    "    'linkage_method' : 'ward',\n",
    "    'distance_metric' : 'euclidean',\n",
    "    'cluster_criterion' : 'hierarchical',\n",
    "    'distance_threshold' : DISTANCE_THRESHOLD,\n",
    "    'merge_criterion' : 'tropnnc',\n",
    "    'tropnnc_features_and_threshold' : True,\n",
    "    'variant' : True,\n",
    "    'verbose' : False\n",
    "}\n",
    "\n",
    "model_modifier = cluster_model(resnet18,cluster_args)\n",
    "resnet18_clustered = model_modifier.cluster_model().to(device)\n",
    "\n",
    "# Evaluate the compressed model directly, without any fine-tuning.\n",
    "_,top1_acc_no_retrain,top5_acc_no_retrain = validate(val_loader, resnet18_clustered, criterion, args, verbose=False)\n",
    "\n",
    "# 'large' = original model, 'small' = clustered model; drop = large - small.\n",
    "print('large top-1 {:.2f}, small top-1 {:.2f}, top-1 drop {:.2f}, '\n",
    "      'large top-5 {:.2f}, small top-5 {:.2f}, top-5 drop {:.2f}'.format(\n",
    "          val_top1, top1_acc_no_retrain, val_top1 - top1_acc_no_retrain,\n",
    "          val_top5, top5_acc_no_retrain, val_top5 - top5_acc_no_retrain))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  + Number of params: 11.69M\n",
      "  + Number of params: 10.98M\n",
      "  + Number of FLOPs: 3.64G\n",
      "  + Number of FLOPs: 2.92G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "2916881034.0"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from src.compute_flops import print_model_param_nums,print_model_param_flops\n",
    "\n",
    "# Parameter counts: original model first, then the clustered one.\n",
    "print_model_param_nums(resnet18)\n",
    "print_model_param_nums(resnet18_clustered)\n",
    "\n",
    "# FLOPs at a 224x224 input. nn.Module.cpu() moves a model in place, so profile\n",
    "# deep copies and leave resnet18 / resnet18_clustered on `device`.\n",
    "# The trailing ';' suppresses echoing the returned FLOP count.\n",
    "print_model_param_flops(copy.deepcopy(resnet18).cpu(),input_res=224)\n",
    "print_model_param_flops(copy.deepcopy(resnet18_clustered).cpu(),input_res=224);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
