{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "97e2b486-fbda-42bd-ab14-f060bd788311",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader, random_split, TensorDataset\n",
    "import torch.nn.utils.prune as prune\n",
    "from torchvision import datasets, transforms, models\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import copy\n",
    "import time\n",
    "\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "os.environ['PYTHONHASHSEED'] = str(SEED)\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "g = torch.Generator()\n",
    "g.manual_seed(0)\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(SEED)\n",
    "    torch.cuda.manual_seed_all(SEED)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a4833364-f27c-4c3b-88eb-4bb6436b3c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prune_model(model, pruning_amount=0.3, pruning_method=prune.L1Unstructured):\n",
    "    model = copy.deepcopy(model)\n",
    "    parameters_to_prune = []\n",
    "    \n",
    "    for name, module in model.named_modules():\n",
    "        if isinstance(module, nn.Conv2d):\n",
    "            parameters_to_prune.append((module, 'weight'))\n",
    "        elif isinstance(module, nn.Linear):\n",
    "            parameters_to_prune.append((module, 'weight'))\n",
    "    prune.global_unstructured(\n",
    "        parameters_to_prune,\n",
    "        pruning_method=pruning_method,\n",
    "        amount=pruning_amount\n",
    "    )\n",
    "    \n",
    "    for module, _ in parameters_to_prune:\n",
    "        prune.remove(module, 'weight')\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "993d794a-0983-483f-ac1c-50cb70b803fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_loader, val_loader, optimizer, criterion, num_epoch):\n",
    "    \"\"\"Optimise ``model`` for ``num_epoch`` passes over ``train_loader``.\n",
    "\n",
    "    After every epoch the model is scored on ``val_loader`` through ``evaluate``.\n",
    "    Once all epochs are done, the elapsed wall-clock time and the accuracy\n",
    "    accumulated over the last epoch's training batches are printed.\n",
    "\n",
    "    Args:\n",
    "        model: Network to train; batches are moved to the global ``device``.\n",
    "        train_loader: Iterable yielding ``(images, labels)`` training batches.\n",
    "        val_loader: Iterable passed to ``evaluate`` at the end of each epoch.\n",
    "        optimizer: Optimizer stepped once per batch.\n",
    "        criterion: Loss called as ``criterion(outputs, labels)``.\n",
    "        num_epoch: Number of passes over ``train_loader``.\n",
    "    \"\"\"\n",
    "    print('Training....')\n",
    "    start = time.time()\n",
    "\n",
    "    # Defined up front so the summary below also works when num_epoch == 0.\n",
    "    total = 0\n",
    "    correct = 0\n",
    "\n",
    "    for epoch in range(num_epoch):\n",
    "        # Enter training mode every epoch in case validation left the model in eval mode.\n",
    "        model.train()\n",
    "        total = 0\n",
    "        correct = 0\n",
    "        for images, labels in train_loader:\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(images)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Accuracy bookkeeping on detached logits keeps it out of the autograd graph.\n",
    "            predicted = outputs.detach().argmax(dim=1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "        evaluate(model, val_loader)\n",
    "\n",
    "    print('Training Completed in: {} secs'.format(time.time()-start))\n",
    "    # Guard against an empty run (no epochs or no batches) instead of dividing by zero.\n",
    "    train_accuracy = (correct/total)*100 if total else 0.0\n",
    "    print('Training accuracy: {} %'.format(train_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "44c41264-91a3-4ed4-af48-ba1ee6531ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_loader):\n",
    "    \"\"\"Print and return the top-1 accuracy (in percent) of ``model`` on ``test_loader``.\n",
    "\n",
    "    The model is switched to evaluation mode and autograd is disabled while scoring,\n",
    "    so normalisation statistics are not updated and dropout is inactive. The mode the\n",
    "    model was in on entry is restored before returning.\n",
    "\n",
    "    Args:\n",
    "        model: Network to score; batches are moved to the global ``device``.\n",
    "        test_loader: Iterable yielding ``(images, labels)`` batches.\n",
    "\n",
    "    Returns:\n",
    "        Accuracy as a percentage, or ``0.0`` when the loader yields no samples.\n",
    "    \"\"\"\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for images, labels in test_loader:\n",
    "                images = images.to(device)\n",
    "                labels = labels.to(device)\n",
    "                predicted = model(images).argmax(dim=1)\n",
    "                total += labels.size(0)\n",
    "                correct += (predicted == labels).sum().item()\n",
    "    finally:\n",
    "        # Hand the model back in the mode the caller had it in (train() relies on this).\n",
    "        model.train(was_training)\n",
    "\n",
    "    accuracy = (correct/total)*100 if total else 0.0\n",
    "    print('Accuracy: {} %'.format(accuracy))\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecf7ff12-0341-47b9-ab61-a7d0d3597e37",
   "metadata": {},
   "source": [
    "# MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "96d1e792-68c9-4a14-b270-032a59d16e40",
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams = {\n",
    "    \"batch_size\": 256,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "472dc1e0-1745-4071-930d-f47250b8070c",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = transforms.Compose([\n",
    "    transforms.Resize((32, 32)),\n",
    "    transforms.Grayscale(num_output_channels=3),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,))\n",
    "])\n",
    "\n",
    "dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=g)\n",
    "\n",
    "mnist_train_loader = DataLoader(train_dataset, batch_size=hyperparams['batch_size'], shuffle=True, generator=g)\n",
    "mnist_val_loader = DataLoader(val_dataset, batch_size=hyperparams['batch_size'], shuffle=False, generator=g)\n",
    "mnist_test_loader = DataLoader(test_dataset, batch_size=hyperparams['batch_size'], shuffle=False, generator=g)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6680e40b-28cf-4ed4-8101-c3fc1671d21f",
   "metadata": {},
   "source": [
    "## Resnet-18 pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0c46a222-2c72-4f4e-92d1-2fe6aa271b1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams = {\n",
    "    \"num_epochs\": 10,\n",
    "    \"batch_size\": 256,\n",
    "    \"lr\": 4e-5,    \n",
    "    \"task_type\": \"classification\",\n",
    "    \"weight_decay\": 1e-3\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0e8e3ffd-0436-4ef9-8df2-1bea47a462fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.resnet18(weights='DEFAULT')\n",
    "model.fc = nn.Linear(512, 10, bias=True)\n",
    "\n",
    "# for param in model.parameters():\n",
    "#     param.requires_grad = False\n",
    "\n",
    "# for param in model.fc.parameters():\n",
    "#     param.requires_grad = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efe95be4-cec4-468f-9c9c-6499bd5d33d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params = sum(p.numel() for p in model.parameters())\n",
    "pytorch_total_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d5221d5e-22e1-4a9a-a391-b546daa2110a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "# NOTE(review): hyperparams[\"weight_decay\"] is defined for this run but is not passed to Adam,\n",
    "# unlike the ResNeXt50 / MNASNet / EfficientNet / MobileNet MNIST sections — confirm this is intended.\n",
    "optimizer = optim.Adam(model.parameters(), lr=hyperparams[\"lr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28e4f925-2ca1-4ea5-b792-2834c77f0391",
   "metadata": {},
   "outputs": [],
   "source": [
    "train(model, mnist_train_loader, mnist_val_loader, optimizer, criterion, hyperparams[\"num_epochs\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1fc379f-9c8d-45b4-aeac-1c2f65023c93",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(model, mnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "3ec6b961-fe59-488e-97ed-40d11d5fedb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "pruned_model = prune_model(model, pruning_amount=0.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecc35781-96e3-4e3d-8ecb-40d325c0d17e",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(pruned_model, mnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1d1255-1141-43b3-8704-e1951414906a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough parameter budget after 75 % pruning. prune_model only zeroes Conv2d/Linear weights,\n",
    "# so biases and normalisation parameters are counted here although they are never pruned.\n",
    "pytorch_total_params * (1 - 0.75)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50627342-f5ba-40ee-a6a0-96c4f938c089",
   "metadata": {},
   "source": [
    "## Resnext50 pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "c4ed3d3a-a4a7-4e60-8203-457eb9c66838",
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams = {\n",
    "    \"num_epochs\": 10,\n",
    "    \"batch_size\": 256,\n",
    "    \"lr\": 1e-4,    \n",
    "    \"task_type\": \"classification\",\n",
    "    \"weight_decay\": 1e-3\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "8e9e1138-c3a8-40c0-b0a2-425f88397bff",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.resnext50_32x4d(weights='DEFAULT')\n",
    "model.fc = nn.Linear(2048, 10, bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccaf03af-f486-497d-820b-a1592f7571cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params = sum(p.numel() for p in model.parameters())\n",
    "pytorch_total_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "8b0036c3-848f-4841-b0c5-0cd26aa1e9f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=hyperparams[\"lr\"], weight_decay=hyperparams[\"weight_decay\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2c93939-2948-4a1c-b087-75784cf4c0d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "train(model, mnist_train_loader, mnist_val_loader, optimizer, criterion, hyperparams[\"num_epochs\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51a132dd-bef2-4ac0-bc96-dd649b941f13",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(model, mnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "3f8e5835-0ffb-479b-b27c-067279e810c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "pruned_model = prune_model(model, pruning_amount=0.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25711245-ce04-4a29-bc97-bcbc741bb23e",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(pruned_model, mnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a9ea13e-c848-46a8-b424-df53f425ae31",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params * (1 - 0.75)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e92a3c6-5b29-4893-bd08-19451b5b423b",
   "metadata": {},
   "source": [
    "## MNASNet-1 pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "4bb80d8c-ed63-4ae8-9cef-b6d647e872dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams = {\n",
    "    \"num_epochs\": 10,\n",
    "    \"batch_size\": 256,\n",
    "    \"lr\": 4e-4,    \n",
    "    \"task_type\": \"classification\",\n",
    "    \"weight_decay\": 1e-3\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "c25afe95-5702-4b9e-89e7-52501a7de3cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.mnasnet1_0(weights='DEFAULT')\n",
    "model.classifier[1] = nn.Linear(1280, 10, bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f74128a2-32d0-4397-8583-26d863987e56",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params = sum(p.numel() for p in model.parameters())\n",
    "pytorch_total_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "0516a159-3294-46ce-a315-87e915f3b95d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=hyperparams[\"lr\"], weight_decay=hyperparams[\"weight_decay\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eab8349e-01a1-4d82-9079-78cd1354bbc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "train(model, mnist_train_loader, mnist_val_loader, optimizer, criterion, hyperparams[\"num_epochs\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a1be1be-a735-41de-8250-0c2f9756abf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(model, mnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "57880ce8-a08f-4e30-a84f-394ebb43d7c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "pruned_model = prune_model(model, pruning_amount=0.825)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "004a52a1-a0ea-421a-9e69-4b43373b558e",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(pruned_model, mnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "921f67cc-ae08-4a49-af07-8cb5b67a2c15",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params * (1 - 0.825)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "329de23c-bed1-4e7e-b1f4-03d564575c85",
   "metadata": {},
   "source": [
    "## EfficientNet-V2-Small pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "33fe6ea7-404d-485c-8c94-82662c7c7771",
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams = {\n",
    "    \"num_epochs\": 16,\n",
    "    \"batch_size\": 256,\n",
    "    \"lr\": 2e-4,    \n",
    "    \"task_type\": \"classification\",\n",
    "    \"weight_decay\": 1e-3\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "f1784969-08a1-407c-bbc5-616c28a9a8a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.efficientnet_v2_s(weights='DEFAULT')\n",
    "model.classifier[1] = nn.Linear(1280, 10, bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af5bfc4c-34eb-472c-b6f6-77e54e791307",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params = sum(p.numel() for p in model.parameters())\n",
    "pytorch_total_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "52ce5e62-2658-48ca-ab4c-946fbbbed141",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=hyperparams[\"lr\"], weight_decay=hyperparams[\"weight_decay\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c6ae8aa-edab-4b97-96da-5c36c1a60c8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "train(model, mnist_train_loader, mnist_val_loader, optimizer, criterion, hyperparams[\"num_epochs\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "496075ed-a478-41f0-89c7-93005bcd5117",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(model, mnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "26f1f432-45f8-4011-9c78-932c0f6b0782",
   "metadata": {},
   "outputs": [],
   "source": [
    "pruned_model = prune_model(model, pruning_amount=0.91)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "772335d4-b3b9-4c8f-8c9d-db39f821fbc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(pruned_model, mnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a09bd5c-c324-4758-a696-7c40dcf856f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params * (1 - 0.91)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab771775-0644-4a88-9757-5beb9e165f75",
   "metadata": {},
   "source": [
    "## MobileNet-V3-Small pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 220,
   "id": "9b0feab2-bf64-420d-bc22-566b7fb1e41a",
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams = {\n",
    "    \"num_epochs\": 16,\n",
    "    \"batch_size\": 256,\n",
    "    \"lr\": 8e-4,    \n",
    "    \"task_type\": \"classification\",\n",
    "    \"weight_decay\": 1e-3\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 221,
   "id": "1ef5f8b1-e803-4c81-ac7f-0777934aed32",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.mobilenet_v3_small(weights='DEFAULT')\n",
    "model.classifier[3] = nn.Linear(1024, 10, bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dc24e79-e32d-4f2a-bffb-4d3a6bd9d445",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params = sum(p.numel() for p in model.parameters())\n",
    "pytorch_total_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 223,
   "id": "284893b5-f5d4-447d-aa12-4bb76f34305a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=hyperparams[\"lr\"], weight_decay=hyperparams[\"weight_decay\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f30ce742-c0bd-42ec-99d0-3b2165f74542",
   "metadata": {},
   "outputs": [],
   "source": [
    "train(model, mnist_train_loader, mnist_val_loader, optimizer, criterion, hyperparams[\"num_epochs\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d45a171-a2fd-439a-82dd-a3f79351bf1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(model, mnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "id": "a7eee784-b0f5-4a9e-9041-957017f178f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "pruned_model = prune_model(model, pruning_amount=0.937)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f9654ba-e52d-45ef-b645-a2fa404e4cfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(pruned_model, mnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d78d3532-dd3c-4eef-90b1-4eed4405d273",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params * (1 - 0.937)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bbe62e0-3439-46bd-bbf7-541cb3b596e4",
   "metadata": {},
   "source": [
    "# FashionMNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "3dd8ad75-811b-45ea-97cf-0fea1bea8db1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same preprocessing pipeline as the MNIST section, with different Normalize constants.\n",
    "# This cell rebinds transform / dataset / test_dataset / train_dataset / val_dataset from the MNIST section.\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((32, 32)),\n",
    "    transforms.Grayscale(num_output_channels=3),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.2860,), (0.3530,))\n",
    "])\n",
    "\n",
    "dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)\n",
    "\n",
    "test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "# 80/20 train/validation split drawn from the shared seeded generator.\n",
    "train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=g)\n",
    "\n",
    "# NOTE(review): the loaders use 64 here, while every FashionMNIST hyperparams dict\n",
    "# declares \"batch_size\": 256 — confirm which value is intended.\n",
    "batch_size = 64\n",
    "\n",
    "fmnist_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=g)\n",
    "fmnist_val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, generator=g)\n",
    "fmnist_test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, generator=g)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13a5be23-3ca9-430c-a6f0-4727164d88e9",
   "metadata": {},
   "source": [
    "## Resnet-18 pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "d2f1907f-81d1-4fe0-8166-39860933fd09",
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams = {\n",
    "    \"num_epochs\": 10,\n",
    "    \"batch_size\": 256,\n",
    "    \"lr\": 4e-4,    \n",
    "    \"task_type\": \"classification\",\n",
    "    \"weight_decay\": 1e-3\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "f58a8351-bda1-4e08-9e18-840a82e3fefd",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.resnet18(weights='DEFAULT')\n",
    "model.fc = nn.Linear(512, 10, bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b6bb47f-ca8f-4d40-b9c9-33fce9c2bc67",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "pytorch_total_params = sum(p.numel() for p in model.parameters())\n",
    "pytorch_total_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "500e4b25-3519-4cba-ab63-cb6953ac6599",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "# NOTE(review): hyperparams[\"weight_decay\"] is defined for this run but is not passed to Adam,\n",
    "# unlike the FashionMNIST MobileNet-V3 section — confirm this is intended.\n",
    "optimizer = optim.Adam(model.parameters(), lr=hyperparams[\"lr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92d5d0e3-4ac9-4075-86cc-f1046fab5c23",
   "metadata": {},
   "outputs": [],
   "source": [
    "train(model, fmnist_train_loader, fmnist_val_loader, optimizer, criterion, hyperparams[\"num_epochs\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f215196-ab0b-4b0b-9fc4-78ad68ffcc21",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(model, fmnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "0c21e071-f259-49af-a2c9-8dd5c30cd16e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pruned_model = prune_model(model, pruning_amount=0.86)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65974b83-c23e-44cc-bf89-e33cbfc19bda",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(pruned_model, fmnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa7238bf-e6ac-49be-a710-29cc37cc673b",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params * (1 - 0.86)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e9d71b5-3c75-412f-b205-2487926efc57",
   "metadata": {},
   "source": [
    "## Resnext50 pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9638d97c-fbbe-4377-b57d-a4bb48537de4",
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams = {\n",
    "    \"num_epochs\": 20,\n",
    "    \"batch_size\": 256,\n",
    "    \"lr\": 5e-4,    \n",
    "    \"task_type\": \"classification\",\n",
    "    \"weight_decay\": 1e-4\n",
    "}\n",
    "\n",
    "name = \", \".join(\n",
    "    f\"{key}: {value.__class__.__name__ if key == 'metric' else value}\"\n",
    "    for key, value in hyperparams.items()\n",
    ")\n",
    "\n",
    "name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "id": "dd5cc608-0730-409b-9310-da9bbc8004d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trained from scratch: no pretrained weights are loaded for this run.\n",
    "# The FashionMNIST transform expands every image to 3 channels (Grayscale(num_output_channels=3)),\n",
    "# so the stock 3-channel stem is kept; a 1-channel conv1 would reject those batches.\n",
    "model = models.resnext50_32x4d()\n",
    "model.fc = nn.Linear(2048, 10, bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5388762-1111-4464-b9d5-3d965dbcb151",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params = sum(p.numel() for p in model.parameters())\n",
    "pytorch_total_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "582dee0c-5722-4776-be5c-d6577895d9fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "# NOTE(review): hyperparams[\"weight_decay\"] (1e-4 for this run) is not passed to Adam,\n",
    "# unlike the FashionMNIST MobileNet-V3 section — confirm this is intended.\n",
    "optimizer = optim.Adam(model.parameters(), lr=hyperparams[\"lr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff9ecc13-30ca-4a89-8648-6b71a76d3592",
   "metadata": {},
   "outputs": [],
   "source": [
    "train(model, fmnist_train_loader, fmnist_val_loader, optimizer, criterion, hyperparams[\"num_epochs\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c31c6085-d800-4980-bccc-9df5b24fa37c",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(model, fmnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "id": "8e242bda-34e3-4146-b09d-58343b7b0b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "pruned_model = prune_model(model, pruning_amount=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "104c2258-9dab-4019-8b9b-2ac577e29b52",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(pruned_model, fmnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5750579-7dd1-4f72-a4fd-3eaa890aa6c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough parameter budget after pruning; the fraction matches pruning_amount=0.5 used for this model.\n",
    "pytorch_total_params * (1 - 0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67e9ef11-be72-40e8-867c-21733eee47d4",
   "metadata": {},
   "source": [
    "## MobileNet-V3-Small pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "id": "06324f4c-67ef-492a-9b9b-3deb3ebd676e",
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparams = {\n",
    "    \"num_epochs\": 10,\n",
    "    \"batch_size\": 256,\n",
    "    \"lr\": 5e-4,    \n",
    "    \"task_type\": \"classification\",\n",
    "    \"weight_decay\": 1e-3\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "id": "5d96ca2d-6ad3-4521-8194-fae3a2a2e75c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.mobilenet_v3_small(weights='DEFAULT')\n",
    "model.classifier[3] = nn.Linear(1024, 10, bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7668c0e-eb70-447e-b499-7b4150020020",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params = sum(p.numel() for p in model.parameters())\n",
    "pytorch_total_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "id": "d5ae1e19-970c-40d9-b7d3-e8b0e6f4a178",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=hyperparams[\"lr\"], weight_decay=hyperparams[\"weight_decay\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ec9f3e3-8ab1-4589-92b7-a0e872d28833",
   "metadata": {},
   "outputs": [],
   "source": [
    "train(model, fmnist_train_loader, fmnist_val_loader, optimizer, criterion, hyperparams[\"num_epochs\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d25458f2-faf1-455c-9d71-c74637f0cb99",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(model, fmnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 216,
   "id": "3f751d1b-746f-4ba1-b9a8-a1b43c4c1a82",
   "metadata": {},
   "outputs": [],
   "source": [
    "pruned_model = prune_model(model, pruning_amount=0.925)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f68f17d-af33-429f-bb7a-b0429c53ad99",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "evaluate(pruned_model, fmnist_test_loader)\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1855f6e6-85f9-4400-8043-1a77dbb742fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_total_params * (1 - 0.925)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8722ec3c-8453-48a3-b53a-5739be6f85f0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
