{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e2610a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import copy\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import nn_utils\n",
    "from nn_utils import train, test, load_model\n",
    "from fc import FC2, FC1\n",
    "from id_utils import pruneID, compare_prune, getID\n",
    "from data_utils import fashionmnist_loader, cifar10_loader\n",
    "from plt_utils import plt_IDerr\n",
    "import torch.nn.utils.prune as prune\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from train import model_summary\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import importlib\n",
    "# Input shape later passed to model_summary as summary_input.\n",
    "summary_input = (3,32,32)\n",
    "# Render matplotlib figures inline in the notebook.\n",
    "%matplotlib inline\n",
    "# Seed torch and numpy before load_model(args) is called further down.\n",
    "trial=2\n",
    "torch.manual_seed(trial)\n",
    "np.random.seed(trial)\n",
    "\n",
    "\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt \n",
    "%matplotlib inline\n",
    "import torch.nn.functional as F\n",
    "class arguments():\n",
    "    \"\"\"Container of experiment settings; an instance is passed as `args` to the helpers.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Every option and its fixed value, applied to the instance in one step.\n",
    "        settings = dict(\n",
    "            batch_size=64,\n",
    "            test_batch_size=1000,\n",
    "            lr=.2,\n",
    "            epochs=200,\n",
    "            log_interval=10,\n",
    "            verbose=False,\n",
    "            prune_batch_size=1000,\n",
    "            dataset=\"cifar10\",\n",
    "            use_valid=True,\n",
    "            ft_proportion=0,\n",
    "            fine_tune=False,\n",
    "            arch=\"VGG16\",\n",
    "            load_fname=\"vgg16_epo160_seed2_best.pt\",\n",
    "        )\n",
    "        for option, value in settings.items():\n",
    "            setattr(self, option, value)\n",
    "# Build the settings object and attach the extra options that __init__ does not set.\n",
    "args=arguments()\n",
    "args.k_args=\"fracSkipIter\"\n",
    "\n",
    "\n",
    "args.skip=[]\n",
    "\n",
    "# Reference network; the distill/test helpers defined below read it as the global model_full.\n",
    "model_full=load_model(args)\n",
    "# Re-seed torch and numpy with a different value before the data loaders are created.\n",
    "seed=4\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "train_loader, test_loader, prune_loader=cifar10_loader(args)\n",
    "\n",
    "# Options presumably consumed by choosePruneMethod -- TODO confirm against id_utils6.\n",
    "args.pruner=\"id\"\n",
    "args.pruner_args=\"Zorig\"\n",
    "# NOTE(review): k_args was already set to this same value above; the repeat is redundant.\n",
    "args.k_args=\"fracSkipIter\"\n",
    "args.saveR=False\n",
    "# Keep the inputs of the first prune_loader batch; they are later passed to choosePruneMethod.\n",
    "X_prune, _ = next(iter(prune_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d0496f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(i, t=1.0):\n",
    "    \"\"\"Temperature-scaled softmax over the last dimension of `i`.\n",
    "\n",
    "    The previous version summed over axis 0, which for a (batch, classes)\n",
    "    tensor normalises across samples instead of across classes. For 1-D\n",
    "    input the result is unchanged. torch.softmax is also safe against the\n",
    "    overflow that a bare exp() hits for large logits.\n",
    "\n",
    "    Args:\n",
    "        i: tensor of logits.\n",
    "        t: temperature the logits are divided by before normalising.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of the same shape as `i` whose last dimension sums to 1.\n",
    "    \"\"\"\n",
    "    return torch.softmax(i / t, dim=-1)\n",
    "\n",
    "def cross_entropy(distribution, target, t=1.0):\n",
    "    \"\"\"Soft-label cross entropy between two sets of logits, averaged per sample.\n",
    "\n",
    "    Both inputs are logits. `target` is turned into probabilities and\n",
    "    `distribution` into log-probabilities, each over the last dimension and\n",
    "    each at temperature `t` (previously `t` was accepted but ignored).\n",
    "\n",
    "    Args:\n",
    "        distribution: logits of the model being trained.\n",
    "        target: logits of the reference model.\n",
    "        t: temperature both sets of logits are divided by.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor: the summed cross entropy divided by the number of rows\n",
    "        (1 for 1-D input). This equals the old fixed divisor of 1000 for a\n",
    "        1000-row batch but stays correct for other batch sizes.\n",
    "    \"\"\"\n",
    "    # log_softmax avoids the log(0) that log(softmax(x)) can produce.\n",
    "    log_probs = torch.log_softmax(distribution / t, dim=-1)\n",
    "    soft_targets = torch.softmax(target / t, dim=-1)\n",
    "    rows = distribution.shape[0] if distribution.dim() > 1 else 1\n",
    "    return -torch.sum(soft_targets * log_probs) / max(rows, 1)\n",
    "    \n",
    "\n",
    "def distill(args, model, device, train_loader, criterion, optimizer, epoch, alpha=.1, T=1, teacher=None):\n",
    "    \"\"\"Train `model` for one pass over `train_loader` with an extra teacher-matching loss.\n",
    "\n",
    "    For every batch the loss is `criterion(output, target)` plus `alpha` times\n",
    "    `cross_entropy(output, teacher_logits, t=T)`.\n",
    "\n",
    "    Args:\n",
    "        args: not read by this function.\n",
    "        model: network being trained (switched to train mode).\n",
    "        device: device the batch, the labels and the teacher logits are moved to.\n",
    "        train_loader: iterable of `(data, target)` batches.\n",
    "        criterion: loss applied to the model output and the labels.\n",
    "        optimizer: optimizer stepped once per batch.\n",
    "        epoch: not read by this function.\n",
    "        alpha: weight of the teacher-matching term.\n",
    "        T: temperature forwarded to `cross_entropy`.\n",
    "        teacher: network supplying the soft targets; defaults to the\n",
    "            notebook-level `model_full`.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the per-batch losses, or 0.0 when the loader yields no batches.\n",
    "    \"\"\"\n",
    "    if teacher is None:\n",
    "        teacher = model_full\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    num_batches = 0\n",
    "    for data, target in train_loader:\n",
    "        # The teacher only provides targets, so no autograd graph is built for it.\n",
    "        # As before it sees the batch prior to the device transfer; its logits are\n",
    "        # then moved to `device` so they can be combined with the student output.\n",
    "        with torch.no_grad():\n",
    "            teacher_logits = teacher(data).to(device)\n",
    "        data = data.to(device, non_blocking=True)\n",
    "        target = target.to(device, non_blocking=True)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target) + alpha * cross_entropy(output, teacher_logits, t=T)\n",
    "        running_loss += loss.item()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        num_batches += 1\n",
    "\n",
    "    # Average over batches; an empty loader no longer divides by zero.\n",
    "    return running_loss / num_batches if num_batches else 0.0\n",
    "\n",
    "def test(args, model, device, test_loader, criterion, epoch, returnAcc=False, pstatement=\"Test\", teacher=None):\n",
    "    \"\"\"Evaluate `model` on `test_loader` and compare its predictions with a reference network.\n",
    "\n",
    "    This definition replaces the `test` imported from nn_utils in the notebook\n",
    "    namespace. Loss and hit counts are only accumulated when `args.dataset` is\n",
    "    'fashion-mnist' or 'cifar10'.\n",
    "\n",
    "    Args:\n",
    "        args: settings object; `dataset`, `log_interval` and `verbose` are read.\n",
    "        model: network under evaluation (switched to eval mode).\n",
    "        device: device each batch and its labels are moved to.\n",
    "        test_loader: iterable of `(data, target)` batches that supports len().\n",
    "        criterion: loss applied to the model output and the labels.\n",
    "        epoch: only used to decide whether to print (epoch % args.log_interval == 0).\n",
    "        returnAcc: when true, return the accuracies as well as the loss.\n",
    "        pstatement: prefix of the printed summary line.\n",
    "        teacher: reference network; defaults to the notebook-level `model_full`.\n",
    "\n",
    "    Returns:\n",
    "        The loss averaged over batches, or with `returnAcc` the tuple\n",
    "        (loss, top1_accuracy, top5_accuracy, agreement), where agreement is the\n",
    "        fraction of samples whose arg-max prediction equals the reference network's.\n",
    "    \"\"\"\n",
    "    if teacher is None:\n",
    "        teacher = model_full\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correlation = 0\n",
    "    correct_1, correct_5, total = 0, 0, 0\n",
    "    is_classification = args.dataset in (\"fashion-mnist\", \"cifar10\")\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data = data.to(device, non_blocking=True)\n",
    "            target = target.to(device, non_blocking=True)\n",
    "            output = model(data)\n",
    "            labels = teacher(data).argmax(dim=1, keepdim=True)\n",
    "\n",
    "            if is_classification:\n",
    "                test_loss += criterion(output, target).item()\n",
    "                pred = output.argmax(dim=1, keepdim=True)\n",
    "                # Compare as tensors: the earlier numpy round trip fails for CUDA tensors.\n",
    "                correlation += pred.eq(labels).sum().item()\n",
    "                correct_1 += pred.eq(target.view_as(pred)).sum().item()\n",
    "                # Top-5 was printed and returned but never counted; count it here.\n",
    "                k = min(5, output.size(1))\n",
    "                top_k = output.topk(k, dim=1)[1]\n",
    "                correct_5 += top_k.eq(target.view(-1, 1)).sum().item()\n",
    "            total += len(target)\n",
    "\n",
    "    num_batches = len(test_loader)\n",
    "    if num_batches:\n",
    "        test_loss /= num_batches # over num batches\n",
    "    # Guard the ratios below against an empty loader.\n",
    "    denom = max(total, 1)\n",
    "\n",
    "    if epoch % args.log_interval == 0 and args.verbose:\n",
    "        if args.dataset == \"patches\":\n",
    "            print('{:s}: \\t\\t\\tAvg Loss: {:f}\\n'.format(pstatement, test_loss), flush=True)\n",
    "        elif is_classification:\n",
    "            summary = '{:s}: Avg loss: {:f}, Accuracy: {}/{} ({:.1f}%) Top5 Accuracy: {}/{} ({:.1f}%)'\n",
    "            print(summary.format(pstatement,\n",
    "                                 test_loss, correct_1, total,\n",
    "                                 100. * correct_1 / denom,\n",
    "                                 correct_5, total,\n",
    "                                 100. * correct_5 / denom),\n",
    "                  flush=True)\n",
    "    if returnAcc:\n",
    "        return test_loss, correct_1 / denom, correct_5 / denom, correlation / denom\n",
    "    return test_loss\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa6abd97",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "import time\n",
    "import id_utils6\n",
    "# Reload first and bind the name afterwards: a from-import made before the\n",
    "# reload keeps pointing at the old function object.\n",
    "importlib.reload(id_utils6)\n",
    "from id_utils6 import choosePruneMethod\n",
    "\n",
    "start=time.time()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "ms, conv, lin=model_summary(model_full,summary_input= summary_input, input_res=32)\n",
    "pruned=choosePruneMethod(args, model_full, X_prune, .95)\n",
    "criterion=nn.CrossEntropyLoss()\n",
    "lr=.004\n",
    "epoch=1\n",
    "# Keep pruning until the first value returned by model_summary is at most half of 627880970.0.\n",
    "# NOTE(review): that constant is presumably the same measure for the unpruned network -- confirm.\n",
    "while ms>627880970.0/2:\n",
    "    pruned=choosePruneMethod(args, pruned, X_prune, .90)\n",
    "    ms, conv, lin=model_summary(pruned,summary_input= summary_input, input_res=32)\n",
    "    # pruned is reassigned every round, so the optimizer is rebuilt for its current parameters.\n",
    "    optimizer = optim.SGD(pruned.parameters(), lr=lr)\n",
    "    distill(args, pruned, device, prune_loader, \n",
    "            criterion, optimizer, epoch, alpha=1)\n",
    "print(time.time()-start)\n",
    "# NOTE(review): evaluation uses 'cpu' while distill used `device`; confirm where pruned lives when CUDA is available.\n",
    "print(test(args, pruned, 'cpu', test_loader, criterion = nn.CrossEntropyLoss(), epoch=160, returnAcc=True))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3655852e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import time\n",
    "import id_utils6\n",
    "# Reload first and bind the name afterwards: a from-import made before the\n",
    "# reload keeps pointing at the old function object.\n",
    "importlib.reload(id_utils6)\n",
    "from id_utils6 import choosePruneMethod\n",
    "\n",
    "start=time.time()\n",
    "# device, criterion and lr are not used in this cell; the fine-tuning cell that follows reads them.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "ms, conv, lin=model_summary(model_full,summary_input= summary_input, input_res=32)\n",
    "pruned=choosePruneMethod(args, model_full, X_prune, .95)\n",
    "criterion=nn.CrossEntropyLoss()\n",
    "lr=.001\n",
    "epoch=1\n",
    "# Keep pruning, without any training in between, until the first value returned by\n",
    "# model_summary is at most half of 627880970.0.\n",
    "# NOTE(review): that constant is presumably the same measure for the unpruned network -- confirm.\n",
    "while ms>627880970.0/2:\n",
    "    pruned=choosePruneMethod(args, pruned, X_prune, .90)\n",
    "    ms, conv, lin=model_summary(pruned,summary_input= summary_input, input_res=32)\n",
    "\n",
    "print(time.time()-start)\n",
    "print(test(args, pruned, 'cpu', test_loader, criterion = nn.CrossEntropyLoss(), epoch=160, returnAcc=True))\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "323737bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune the pruned network for 30 passes over prune_loader, evaluating after each pass.\n",
    "# The optimizer is built once: pruned is not reassigned inside the loop and plain SGD\n",
    "# (no momentum) keeps no state, so rebuilding it every epoch was redundant.\n",
    "optimizer = optim.SGD(pruned.parameters(), lr=lr)\n",
    "for epoch in range(30):\n",
    "    print(epoch)\n",
    "    distill(args, pruned, device, prune_loader, \n",
    "            criterion, optimizer, epoch, alpha=1)\n",
    "    print(test(args ,pruned, device, test_loader, \n",
    "            criterion, epoch, returnAcc=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3edb1e98",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
