{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4068b83",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import copy\n",
    "import torch\n",
    "import torchvision\n",
    "import torchprune as tp\n",
    "import torch.nn as nn\n",
    "import argparse\n",
    "import nn_utils\n",
    "class arguments:\n",
    "    \"\"\"Hard-coded experiment options standing in for an argparse namespace.\n",
    "\n",
    "    The instance is passed to load_model, cifar10_loader and test(); later\n",
    "    cells also attach `k_args` and `skip` to it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # (attribute, value) pairs, applied in this order via setattr.\n",
    "        settings = (\n",
    "            (\"batch_size\", 64),\n",
    "            (\"test_batch_size\", 1000),\n",
    "            (\"lr\", .2),\n",
    "            (\"epochs\", 200),\n",
    "            (\"log_interval\", 10),\n",
    "            (\"verbose\", False),\n",
    "            (\"prune_batch_size\", 1000),\n",
    "            (\"dataset\", \"cifar10\"),\n",
    "            (\"use_valid\", True),\n",
    "            (\"ft_proportion\", 0),\n",
    "            (\"fine_tune\", False),\n",
    "            (\"arch\", \"VGG16\"),\n",
    "            (\"load_fname\", \"vgg16_epo160_seed2_best.pt\"),\n",
    "        )\n",
    "        for attr_name, attr_value in settings:\n",
    "            setattr(self, attr_name, attr_value)\n",
    "args = arguments()\n",
    "\n",
    "# nn_utils.test is not imported: the test() defined next in this cell would\n",
    "# shadow it immediately, and importing it only obscures which one is used.\n",
    "from nn_utils import train, load_model\n",
    "# Unpruned reference network; test() compares predictions against it.\n",
    "model_full = load_model(args)\n",
    "def test(args, model, device, test_loader, criterion, epoch, returnAcc=False, pstatement=\"Test\",\n",
    "         reference_model=None):\n",
    "    \"\"\"Evaluate `model` on `test_loader`; optionally return accuracy metrics.\n",
    "\n",
    "    Besides loss and accuracy against the true labels, this counts how often\n",
    "    `model`'s top-1 prediction agrees with that of a reference network\n",
    "    (the global `model_full` unless `reference_model` is given).\n",
    "\n",
    "    Args:\n",
    "        args: options object; reads `dataset`, `log_interval` and `verbose`.\n",
    "        model: network to evaluate; put in eval mode.\n",
    "        device: device each batch is moved to.\n",
    "        test_loader: sized iterable of (data, target) batches.\n",
    "        criterion: loss, called as criterion(output, target).\n",
    "        epoch: only used to decide whether to log (every `log_interval`).\n",
    "        returnAcc: if True, also return accuracies and agreement.\n",
    "        pstatement: prefix of the log line.\n",
    "        reference_model: network whose top-1 predictions define agreement;\n",
    "            defaults to the global `model_full`. Put in eval mode.\n",
    "\n",
    "    Returns:\n",
    "        The mean per-batch loss, or, if returnAcc is True, the tuple\n",
    "        (loss, top1_acc, top5_acc, agreement) as fractions over all samples.\n",
    "        Loss/accuracy/agreement are only accumulated for the fashion-mnist and\n",
    "        cifar10 datasets.\n",
    "    \"\"\"\n",
    "    if reference_model is None:\n",
    "        reference_model = model_full\n",
    "    model.eval()\n",
    "    # In train mode BatchNorm/dropout would use batch statistics (and BatchNorm\n",
    "    # would update its running stats on every call), skewing the agreement.\n",
    "    reference_model.eval()\n",
    "    test_loss = 0\n",
    "    correlation = 0  # samples where model and reference_model agree on top-1\n",
    "    correct_1, correct_5, total = 0, 0, 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data = data.to(device, non_blocking=True)\n",
    "            target = target.to(device, non_blocking=True)\n",
    "            output = model(data)\n",
    "\n",
    "            if args.dataset == \"fashion-mnist\" or args.dataset == \"cifar10\":\n",
    "                test_loss += criterion(output, target).item()\n",
    "                pred = output.argmax(dim=1)\n",
    "                reference_pred = reference_model(data).argmax(dim=1)\n",
    "                # Compare with torch ops: tensor.numpy() fails for CUDA tensors.\n",
    "                correlation += pred.eq(reference_pred).sum().item()\n",
    "                correct_1 += pred.eq(target).sum().item()\n",
    "                # Top-5 hit: true label among the k highest scores (k <= #classes).\n",
    "                k = min(5, output.size(1))\n",
    "                top_k = output.topk(k, dim=1)[1]\n",
    "                correct_5 += top_k.eq(target.view(-1, 1)).any(dim=1).sum().item()\n",
    "            total += len(target)\n",
    "    # Averages; max(..., 1) avoids ZeroDivisionError on an empty loader.\n",
    "    test_loss /= max(len(test_loader), 1)  # over num batches\n",
    "    denom = max(total, 1)\n",
    "\n",
    "    if epoch % args.log_interval == 0 and args.verbose:\n",
    "        if args.dataset == \"patches\":\n",
    "            print('{:s}: \\t\\t\\tAvg Loss: {:f}\\n'.format(pstatement, test_loss), flush=True)\n",
    "        elif args.dataset == \"fashion-mnist\" or args.dataset == \"cifar10\":\n",
    "            print('{:s}: Avg loss: {:f}, Accuracy: {}/{} ({:.1f}%) Top5 Accuracy: {}/{} ({:.1f}%)'.format(\n",
    "                      pstatement,\n",
    "                      test_loss, correct_1, total,\n",
    "                      100. * correct_1 / denom,\n",
    "                      correct_5, total,\n",
    "                      100. * correct_5 / denom),\n",
    "                  flush=True)\n",
    "    if returnAcc:\n",
    "        return test_loss, correct_1 / denom, correct_5 / denom, correlation / denom\n",
    "    return test_loss\n",
    "\n",
    "from data_utils import fashionmnist_loader, cifar10_loader\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pickle\n",
    "%matplotlib inline\n",
    "\n",
    "# Extra options attached to args; presumably read by the loader/pruning\n",
    "# helpers (not referenced directly in this notebook) -- TODO confirm.\n",
    "args.k_args = \"fracSkipIter\"\n",
    "args.skip = []\n",
    "\n",
    "# Seed torch and numpy before building the loaders.\n",
    "trial = 4\n",
    "torch.manual_seed(trial)\n",
    "np.random.seed(trial)\n",
    "train_loader, test_loader, prune_loader = cifar10_loader(args)\n",
    "\n",
    "# Load the model pickled in prundModel.p (presumably an already-pruned network).\n",
    "# The context manager closes the file even if unpickling fails.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted files.\n",
    "with open(\"prundModel.p\", 'rb') as file:\n",
    "    model = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3efd865d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Hyper-parameters in the format expected by tp.util.train.NetTrainer. In this\n",
    "# notebook the trainer is only used for get_loss_handle().\n",
    "batch_size = 32\n",
    "train_params = dict(\n",
    "    loss=\"CrossEntropyLoss\",  # any loss from tp.util.nn_loss\n",
    "    lossKwargs=dict(reduction=\"mean\"),\n",
    "    # exactly two metrics with __init__ kwargs from tp.util.metrics\n",
    "    metricsTest=[\n",
    "        dict(type=\"TopK\", kwargs=dict(topk=1)),\n",
    "        dict(type=\"TopK\", kwargs=dict(topk=5)),\n",
    "    ],\n",
    "    optimizer=\"SGD\",  # any optimizer from torch.optim\n",
    "    optimizerKwargs=dict(lr=0.1, weight_decay=1.0e-4, nesterov=False, momentum=0.9),\n",
    "    batchSize=batch_size,\n",
    "    startEpoch=0,\n",
    "    retrainStartEpoch=-1,\n",
    "    numEpochs=5,  # 182\n",
    "    # any combination of lr schedulers from tp.util.lr_scheduler\n",
    "    lrSchedulers=[\n",
    "        dict(type=\"MultiStepLR\", stepKwargs=dict(milestones=[91, 136]), kwargs=dict(gamma=0.1)),\n",
    "        dict(type=\"WarmupLR\", stepKwargs=dict(warmup_epoch=5), kwargs=dict()),\n",
    "    ],\n",
    "    outputSize=10,  # output size of the network\n",
    "    dir=os.path.realpath(\"./checkpoints\"),  # checkpoint directory\n",
    ")\n",
    "retrain_params = copy.deepcopy(train_params)\n",
    "trainer = tp.util.train.NetTrainer(\n",
    "    train_params=train_params, retrain_params=retrain_params,\n",
    "    train_loader=train_loader, test_loader=test_loader, valid_loader=prune_loader,\n",
    "    num_gpus=0,\n",
    ")\n",
    "loss_handle = trainer.get_loss_handle()\n",
    "\n",
    "# Wrap the loaded model (switched to eval mode) in a torchprune network handle.\n",
    "net_name = \"vg16_CIFAR10\"\n",
    "model.eval()\n",
    "net = tp.util.net.NetHandle(model, net_name)\n",
    "import pickle\n",
    "from nn_utils import print_model_param_flops\n",
    "\n",
    "# Sweep of keep ratios; each iteration compresses a fresh LearnedRankNet\n",
    "# built from the same `net`.\n",
    "sizes = [.6, .5, .45, .4, .35, .3, .25, .2, .15, .1]\n",
    "criterion = nn.CrossEntropyLoss()  # no learnable state, so one instance is reused\n",
    "params = []\n",
    "accs = []  # per ratio: (loss, top1, top5, agreement with model_full)\n",
    "flops = []\n",
    "for keep_ratio in sizes:\n",
    "    print(keep_ratio)\n",
    "\n",
    "    net_weight_pruned = tp.LearnedRankNet(net, prune_loader, loss_handle)\n",
    "\n",
    "    # Sanity check of the uncompressed wrapper on the prune set; printed so\n",
    "    # the evaluation is not wasted.\n",
    "    print(\"before compression (prune set):\",\n",
    "          test(args, net_weight_pruned, 'cpu', prune_loader, criterion=criterion, epoch=160, returnAcc=True))\n",
    "    net_weight_pruned.compress(keep_ratio=keep_ratio)\n",
    "    flops.append(print_model_param_flops(net_weight_pruned.compressed_net.torchnet, input_res=32)[0])\n",
    "    params.append(net_weight_pruned.size())\n",
    "    accs.append(test(args, net_weight_pruned, 'cpu', test_loader, criterion=criterion, epoch=160, returnAcc=True))\n",
    "    print(accs[-1])  # latest result only, not the whole growing list\n",
    "\n",
    "accs = np.array(accs)\n",
    "params = np.array(params)\n",
    "# Context manager closes the file even if pickling fails.\n",
    "with open('LRank+IDNet.p', 'wb') as file:\n",
    "    pickle.dump([params, accs, flops], file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddb8ea7a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "# Re-import nn_utils so edits to it take effect without restarting the kernel,\n",
    "# then rebind the helper from the reloaded module.\n",
    "importlib.reload(nn_utils)\n",
    "from nn_utils import print_model_param_flops\n",
    "# First value returned for the last network compressed in the sweep above\n",
    "# (presumably its FLOP count).\n",
    "print(print_model_param_flops(net_weight_pruned.compressed_net.torchnet, input_res=32)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "370ced11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First value returned by print_model_param_flops for the unpruned network\n",
    "# (presumably its FLOP count), for comparison with the compressed networks.\n",
    "print_model_param_flops(\n",
    "    model_full,\n",
    "    input_res=32,\n",
    ")[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03cd82d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): hard-coded numbers, presumably FLOP counts copied from earlier\n",
    "# outputs (compressed / full model) -- TODO confirm and compute from live values.\n",
    "31282/62788"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77dc3f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the model loaded from prundModel.p on the test set.\n",
    "test(\n",
    "    args,\n",
    "    model,\n",
    "    device='cpu',\n",
    "    test_loader=test_loader,\n",
    "    criterion=nn.CrossEntropyLoss(),\n",
    "    epoch=160,\n",
    "    returnAcc=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa092cee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-create the unpruned network via load_model and evaluate it. This rebinds\n",
    "# the global model_full that test() uses as its reference network.\n",
    "model_full = load_model(args)\n",
    "test(\n",
    "    args,\n",
    "    model_full,\n",
    "    device='cpu',\n",
    "    test_loader=test_loader,\n",
    "    criterion=nn.CrossEntropyLoss(),\n",
    "    epoch=160,\n",
    "    returnAcc=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d58af665",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
