{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d58ed68e-186c-4630-b336-3c1f913b1553",
   "metadata": {
    "id": "d58ed68e-186c-4630-b336-3c1f913b1553"
   },
   "source": [
    "# Librerie e Funzioni"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3291b4b8-c7a5-4f58-8431-66a341882771",
   "metadata": {
    "id": "3291b4b8-c7a5-4f58-8431-66a341882771"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dtraini/.local/lib/python3.10/site-packages/pytorch_wavelets/dtcwt/coeffs.py:7: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n",
      "  from pkg_resources import resource_stream\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.datasets import ImageFolder\n",
    "from torch.utils.data import DataLoader, TensorDataset, Dataset\n",
    "from torch import optim, nn\n",
    "from torch.nn import functional\n",
    "import os\n",
    "import time\n",
    "import csv\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import shutil\n",
    "from torchvision.utils import make_grid\n",
    "from random import randint\n",
    "from PIL import Image\n",
    "import random\n",
    "\n",
    "from einops import rearrange\n",
    "from einops.layers.torch import Rearrange\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from vit_pytorch.ats_vit import ViT as ATSViT\n",
    "from vit_pytorch import ViT\n",
    "\n",
    "from MultiViT import MultiViT\n",
    "from TopK import TopKViT as ViT_topk\n",
    "from PatchMergerViT import PatchMergerViT\n",
    "from ToMe import ToMe\n",
    "from EViT import EViT\n",
    "from DWTViT import DWTViT\n",
    "from DWTViT_pruning import DWTViT_pruning\n",
    "from DWTViT_quantile import DWTViT_quantile\n",
    "from DWTViT_gini import DWTViT_gini\n",
    "\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from einops import rearrange, repeat\n",
    "from einops.layers.torch import Rearrange\n",
    "\n",
    "from thop import profile\n",
    "\n",
    "import torch\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75a86dcf-4ef7-426f-bde1-f7636152ebb3",
   "metadata": {
    "id": "75a86dcf-4ef7-426f-bde1-f7636152ebb3"
   },
   "source": [
    "## Benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "baf78064-db6d-4a48-b5ae-867ce9ab21f7",
   "metadata": {
    "id": "baf78064-db6d-4a48-b5ae-867ce9ab21f7"
   },
   "outputs": [],
   "source": [
    "def benchmark_time(\n",
    "    model,\n",
    "    device,\n",
    "    input_size,\n",
    "    batch_size,\n",
    "    runs,\n",
    "    throw_out,\n",
    "    verbose,\n",
    "):\n",
    "    \"\"\"\n",
    "    Measure the inference throughput of `model` on random inputs.\n",
    "\n",
    "    The model is switched to eval mode, moved to `device` and run `runs`\n",
    "    times under `torch.autocast` and `torch.no_grad()`. The first\n",
    "    `int(runs * throw_out)` passes are warm-up and are excluded from timing.\n",
    "\n",
    "    Args:\n",
    "     - model: the module to benchmark\n",
    "     - device: the device to use (a torch.device or anything torch.device() accepts)\n",
    "     - input_size: per-sample input shape passed to the model, e.g. (channels, h, w)\n",
    "     - batch_size: number of random samples per forward pass\n",
    "     - runs: total number of forward passes, warm-up included\n",
    "     - throw_out: fraction of `runs` discarded as warm-up at the start\n",
    "     - verbose: if True, show a tqdm progress bar and print the device and throughput\n",
    "\n",
    "    Returns:\n",
    "     - a one-element list with the throughput in images / second, formatted\n",
    "       as a string with two decimals (e.g. ['1234.56'])\n",
    "\n",
    "    Raises:\n",
    "     - ValueError: if no timed passes remain after the warm-up\n",
    "    \"\"\"\n",
    "    if not isinstance(device, torch.device):\n",
    "        device = torch.device(device)\n",
    "    is_cuda = device.type == \"cuda\"\n",
    "\n",
    "    warm_up = int(runs * throw_out)\n",
    "    if runs - warm_up < 1:\n",
    "        raise ValueError(f\"no timed runs left: runs={runs}, throw_out={throw_out}\")\n",
    "\n",
    "    model = model.eval().to(device)\n",
    "    if verbose:\n",
    "        print(device)\n",
    "    inputs = torch.rand(batch_size, *input_size, device=device)\n",
    "\n",
    "    total = 0\n",
    "    start = time.perf_counter()\n",
    "\n",
    "    with torch.autocast(device.type):\n",
    "        with torch.no_grad():\n",
    "            for i in tqdm(range(runs), disable=not verbose, desc=\"Benchmarking\"):\n",
    "                if i == warm_up:\n",
    "                    # Wait for queued warm-up kernels so they are not billed to the timed window.\n",
    "                    if is_cuda:\n",
    "                        torch.cuda.synchronize()\n",
    "                    total = 0\n",
    "                    start = time.perf_counter()\n",
    "\n",
    "                model(inputs)\n",
    "                total += batch_size\n",
    "\n",
    "    # CUDA kernels run asynchronously: wait for the last pass before stopping the clock.\n",
    "    if is_cuda:\n",
    "        torch.cuda.synchronize()\n",
    "\n",
    "    elapsed = time.perf_counter() - start\n",
    "    throughput = total / elapsed\n",
    "\n",
    "    if verbose:\n",
    "        print(f\"Throughput: {throughput:.2f} im/s\")\n",
    "\n",
    "    # Callers store this value directly in a DataFrame, so return it as a formatted string.\n",
    "    return [f\"{throughput:.2f}\"]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def benchmark_flops(\n",
    "    model,\n",
    "    device,\n",
    "    input_size,\n",
    "    batch_size,\n",
    "    runs,\n",
    "    verbose,\n",
    "):\n",
    "    \"\"\"\n",
    "    Count the per-image operations and the parameters of `model` with thop.\n",
    "\n",
    "    `thop.profile` is run `runs` times on the same random batch (under\n",
    "    `torch.autocast` and `torch.no_grad()`); the operation count is averaged\n",
    "    over every image processed.\n",
    "\n",
    "    NOTE(review): thop documents the first value returned by `profile` as\n",
    "    MACs; it is reported here (and in the 'Flops' column) as FLOPs.\n",
    "\n",
    "    Args:\n",
    "     - model: the module to profile\n",
    "     - device: the device to use (a torch.device or anything torch.device() accepts)\n",
    "     - input_size: per-sample input shape passed to the model, e.g. (channels, h, w)\n",
    "     - batch_size: number of random samples per forward pass\n",
    "     - runs: number of profiling passes (must be >= 1)\n",
    "     - verbose: if True, show a tqdm progress bar and print the device and totals\n",
    "\n",
    "    Returns:\n",
    "     - a one-element list with the per-image count in units of 1e9, formatted\n",
    "       as a string with four decimals (e.g. ['4.3712'])\n",
    "     - a one-element list with the parameter count reported by thop\n",
    "\n",
    "    Raises:\n",
    "     - ValueError: if `runs` is smaller than 1\n",
    "    \"\"\"\n",
    "    if runs < 1:\n",
    "        raise ValueError(f\"runs must be >= 1, got {runs}\")\n",
    "    if not isinstance(device, torch.device):\n",
    "        device = torch.device(device)\n",
    "\n",
    "    model = model.eval().to(device)\n",
    "    if verbose:\n",
    "        print(device)\n",
    "    inputs = torch.rand(batch_size, *input_size, device=device)\n",
    "\n",
    "    total_images = 0\n",
    "    total_flops = 0\n",
    "\n",
    "    with torch.autocast(device.type):\n",
    "        with torch.no_grad():\n",
    "            for _ in tqdm(range(runs), disable=not verbose, desc=\"Benchmarking\"):\n",
    "                flops, params = profile(model, inputs=(inputs, ), verbose=False)\n",
    "                total_images += batch_size\n",
    "                total_flops += flops\n",
    "\n",
    "    # Average over every image processed, then express in units of 1e9.\n",
    "    gflops_per_image = total_flops / total_images / 1e9\n",
    "\n",
    "    if verbose:\n",
    "        print(f\"Total FLOPs: {gflops_per_image:.4f} GFLOPs\")\n",
    "        print(f\"Total Parameters: {params / 1e6:.4f} M\")\n",
    "\n",
    "    return [f\"{gflops_per_image:.4f}\"], [params]\n",
    "\n",
    "\n",
    "\n",
    "def create_patch_list(total_patches, cut):\n",
    "    # Calcola l'importo scontato per ogni blocco di 3 elementi\n",
    "    discounted_patches = total_patches\n",
    "    patch_list = []\n",
    "    for i in range(12):\n",
    "        if i % 3 == 0 and i != 0:\n",
    "            discounted_patches = int(discounted_patches * (cut / 100))\n",
    "        patch_list.append(discounted_patches)\n",
    "    return patch_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c55da31a-fa55-464e-8c80-cdc0f33d357e",
   "metadata": {
    "id": "c55da31a-fa55-464e-8c80-cdc0f33d357e"
   },
   "outputs": [],
   "source": [
    "def _benchmark_model(model, device, input_size, batch_size, runs_time, runs_flops):\n",
    "    \"\"\"\n",
    "    Benchmark one model and return its measurements as a table row.\n",
    "\n",
    "    Returns:\n",
    "     - dict with the 'Throughput', 'Flops' and 'Params' values produced by\n",
    "       `benchmark_time` and `benchmark_flops`\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    throughput = benchmark_time(\n",
    "        model,\n",
    "        device=device,\n",
    "        verbose=False,\n",
    "        runs=runs_time,\n",
    "        batch_size=batch_size,\n",
    "        input_size=input_size,\n",
    "        throw_out=0.25,\n",
    "    )\n",
    "    flops, params = benchmark_flops(\n",
    "        model,\n",
    "        device=device,\n",
    "        verbose=False,\n",
    "        runs=runs_flops,\n",
    "        batch_size=batch_size,\n",
    "        input_size=input_size,\n",
    "    )\n",
    "    return {'Throughput': throughput[0], 'Flops': flops[0], 'Params': params[0]}\n",
    "\n",
    "\n",
    "def get_results(batch_size, patch_size, runs_time, runs_flops, input_size, img_size, att_dim, depth, heads, mlp_dim, max_tokens_per_depth, patch_merge_layers, tome_patches, DW_patches, type, df, device):\n",
    "    \"\"\"\n",
    "    Benchmark every ViT variant at one model size and append the results to `df`.\n",
    "\n",
    "    Each model is built with the shared hyper-parameters, timed with\n",
    "    `benchmark_time` (`runs_time` passes, 25% warm-up) and profiled with\n",
    "    `benchmark_flops` (`runs_flops` passes). Models are built one at a time\n",
    "    and released after their measurement.\n",
    "\n",
    "    Args:\n",
    "     - batch_size, input_size: shape of the random benchmark batch\n",
    "     - patch_size, img_size, att_dim, depth, heads, mlp_dim: architecture hyper-parameters\n",
    "     - runs_time: forward passes used for the throughput measurement\n",
    "     - runs_flops: profiling passes used for the FLOPs measurement\n",
    "     - max_tokens_per_depth: per-layer token budget (ATSViT, MultiViT, TopK, EViT)\n",
    "     - patch_merge_layers: (layer, tokens) pairs for PatchMergerViT\n",
    "     - tome_patches: per-layer token counts passed to ToMe as `n_patch`\n",
    "     - DW_patches: pruning locations for the DWTViT variants\n",
    "     - type: size label appended to every model name (e.g. 'Base')\n",
    "     - df: DataFrame with columns Modello / Throughput / Flops / Params\n",
    "     - device: device used for benchmarking\n",
    "\n",
    "    Returns:\n",
    "     - a new DataFrame: the rows of `df` followed by one row per benchmarked model\n",
    "    \"\"\"\n",
    "    # Hyper-parameters shared by every architecture under test.\n",
    "    common = dict(\n",
    "        image_size=img_size,\n",
    "        patch_size=patch_size,\n",
    "        num_classes=10,\n",
    "        dim=att_dim,\n",
    "        depth=depth,\n",
    "        heads=heads,\n",
    "        mlp_dim=mlp_dim,\n",
    "    )\n",
    "\n",
    "    # (header to print, row label without the size suffix, zero-argument model factory).\n",
    "    # Factories keep construction lazy so models are created in benchmark order.\n",
    "    specs = [\n",
    "        ('ViT', 'ViT_net', lambda: ViT(**common)),\n",
    "        ('ATSViT', 'ATS', lambda: ATSViT(**common, max_tokens_per_depth=max_tokens_per_depth)),\n",
    "        ('MultiViT', 'MultiViT', lambda: MultiViT(**common, n_patch=max_tokens_per_depth)),\n",
    "        ('Transformer_topk', 'Transformer_topk', lambda: ViT_topk(**common, n_patch=max_tokens_per_depth)),\n",
    "        ('Transformer_evit', 'Transformer_evit', lambda: EViT(**common, n_patch=max_tokens_per_depth)),\n",
    "        ('Transformer_tome', 'Transformer_tome', lambda: ToMe(**common, n_patch=tome_patches)),\n",
    "    ]\n",
    "\n",
    "    # NOTE(review): these variants are built with DWTViT_pruning, not the imported DWTViT class - confirm intended.\n",
    "    for wavelet in ['haar', 'sym2', 'db4', 'db2']:\n",
    "        specs.append((\n",
    "            'Transformer_DWTViT',\n",
    "            f'Transformer_DWTViT_{wavelet}',\n",
    "            # Bind the loop variable now; a plain closure would only see the last wavelet.\n",
    "            lambda wavelet=wavelet: DWTViT_pruning(**common, wavelet=wavelet, pruning_locations=DW_patches),\n",
    "        ))\n",
    "\n",
    "    for strategy in ['cA', 'cD']:\n",
    "        specs.append((\n",
    "            'Transformer_DWTViT',\n",
    "            f'Transformer_DWTViT_{strategy}',\n",
    "            lambda strategy=strategy: DWTViT_pruning(**common, wavelet='haar', pruning_locations=DW_patches, strategy=strategy),\n",
    "        ))\n",
    "\n",
    "    specs += [\n",
    "        # NOTE(review): same arguments as the 'db4' entry above - confirm whether a different setting was intended.\n",
    "        ('Transformer_DWTViT_pruning', 'Transformer_DWTViT_pruning', lambda: DWTViT_pruning(**common, wavelet='db4', pruning_locations=DW_patches)),\n",
    "        ('Transformer_DWTViT_gini', 'Transformer_DWTViT_gini', lambda: DWTViT_gini(**common, wavelet='haar', pruning_locations=DW_patches, soglia=0.5)),\n",
    "        ('Transformer_DWTViT_quantile', 'Transformer_DWTViT_quantile', lambda: DWTViT_quantile(**common, wavelet='haar', pruning_locations=DW_patches)),\n",
    "        ('PatchMerger', 'PatchMerger', lambda: PatchMergerViT(**common, patch_merge_layers=patch_merge_layers)),\n",
    "    ]\n",
    "\n",
    "    is_cuda = torch.device(device).type == 'cuda'\n",
    "    rows = []\n",
    "    for header, label, build_model in specs:\n",
    "        print(f'##### {header} ####')\n",
    "        model = build_model()\n",
    "        rows.append({'Modello': f'{label}_{type}', **_benchmark_model(model, device, input_size, batch_size, runs_time, runs_flops)})\n",
    "        # Release the model before building the next one so device memory does not accumulate.\n",
    "        del model\n",
    "        if is_cuda:\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "    # Concatenate once instead of growing the DataFrame model by model.\n",
    "    return pd.concat([df, pd.DataFrame(rows)], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "676bf11a-905f-460f-88c0-ac758d06ff52",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_patch_list_reversed(total_patches, cut):\n",
    "    remaining_patches = total_patches\n",
    "    removal_list = []\n",
    "    \n",
    "    for i in range(12):\n",
    "        if i % 3 == 0 and i != 0:\n",
    "            removed_patches = remaining_patches - int(remaining_patches * (cut / 100))\n",
    "            remaining_patches -= removed_patches\n",
    "        else:\n",
    "            removed_patches = 0\n",
    "        \n",
    "        removal_list.append(removed_patches)\n",
    "    \n",
    "    return removal_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82477279-f999-4426-9a98-f5e62c9e4e41",
   "metadata": {
    "id": "82477279-f999-4426-9a98-f5e62c9e4e41"
   },
   "source": [
    "# Creazione tabella"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "zF0yxkTalOAc",
   "metadata": {
    "id": "zF0yxkTalOAc"
   },
   "source": [
    "## GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "h3h1nW2llOAo",
   "metadata": {
    "id": "h3h1nW2llOAo"
   },
   "outputs": [],
   "source": [
    "# Benchmark on the first GPU when available; set device = torch.device('cpu') to force CPU.\n",
    "device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "# Input and benchmark settings\n",
    "img_size, patch_size, batch_size = 160, 16, 64\n",
    "input_size = (3, img_size, img_size)\n",
    "runs_time = runs_flops = 10\n",
    "\n",
    "# Pruning locations passed to the DWTViT variants\n",
    "DW_patches = [4, 6, 8]\n",
    "\n",
    "# ViT-Base hyper-parameters\n",
    "att_dim, depth, heads = 768, 12, 12\n",
    "mlp_dim = 4 * att_dim\n",
    "\n",
    "# Token-reduction schedule: keep `cut`% of the tokens every 3 layers\n",
    "total_patches = int((img_size / patch_size) ** 2)\n",
    "cut = 75\n",
    "max_tokens_per_depth = create_patch_list(total_patches, cut)\n",
    "tome_patches = create_patch_list_reversed(total_patches, cut)\n",
    "\n",
    "# (layer, tokens) pairs for PatchMergerViT: layers 2, 5, 8 get the budgets of layers 3, 6, 9\n",
    "patch_merge_layers = [(layer, max_tokens_per_depth[layer + 1]) for layer in (2, 5, 8)]\n",
    "\n",
    "# Size label appended to every model name.\n",
    "# NOTE(review): this global shadows the built-in `type` for every later cell.\n",
    "type = 'Base'\n",
    "\n",
    "df_gpu = pd.DataFrame(columns=['Modello', 'Throughput', 'Flops', 'Params'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "01UR_jedlOAo",
   "metadata": {
    "id": "01UR_jedlOAo",
    "outputId": "9a8df396-f3be-4b67-c122-83e3fb7ebf98"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "##### ViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### ATSViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### MultiViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_topk ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_evit ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_tome ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT_pruning ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT_gini ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT_quantile ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### PatchMerger ####\n",
      "cuda:0\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Benchmark every variant at the ViT-Base size.\n",
    "df_gpu = get_results(\n",
    "    batch_size=batch_size, patch_size=patch_size, runs_time=runs_time, runs_flops=runs_flops,\n",
    "    input_size=input_size, img_size=img_size, att_dim=att_dim, depth=depth, heads=heads, mlp_dim=mlp_dim,\n",
    "    max_tokens_per_depth=max_tokens_per_depth, patch_merge_layers=patch_merge_layers,\n",
    "    tome_patches=tome_patches, DW_patches=DW_patches, type=type, df=df_gpu, device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5ZmL9RyNlOAp",
   "metadata": {
    "id": "5ZmL9RyNlOAp"
   },
   "outputs": [],
   "source": [
    "# ViT-Small hyper-parameters; input size and token schedule are reused from the Base cell.\n",
    "att_dim, depth, heads = 384, 12, 6\n",
    "mlp_dim = 4 * att_dim\n",
    "\n",
    "type = 'Small'\n",
    "\n",
    "df_small_gpu = pd.DataFrame(columns=['Modello', 'Throughput', 'Flops', 'Params'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bwRvlOwrlOAp",
   "metadata": {
    "id": "bwRvlOwrlOAp",
    "outputId": "9cf70fa4-892e-434b-b934-736747f150f3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "##### ViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### ATSViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### MultiViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_topk ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_evit ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_tome ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT_pruning ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT_gini ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT_quantile ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### PatchMerger ####\n",
      "cuda:0\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Benchmark every variant at the ViT-Small size.\n",
    "df_small_gpu = get_results(\n",
    "    batch_size=batch_size, patch_size=patch_size, runs_time=runs_time, runs_flops=runs_flops,\n",
    "    input_size=input_size, img_size=img_size, att_dim=att_dim, depth=depth, heads=heads, mlp_dim=mlp_dim,\n",
    "    max_tokens_per_depth=max_tokens_per_depth, patch_merge_layers=patch_merge_layers,\n",
    "    tome_patches=tome_patches, DW_patches=DW_patches, type=type, df=df_small_gpu, device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "oQ1D4iiclOAp",
   "metadata": {
    "id": "oQ1D4iiclOAp"
   },
   "outputs": [],
   "source": [
    "# ViT-Tiny hyper-parameters; input size and token schedule are reused from the Base cell.\n",
    "att_dim, depth, heads = 192, 12, 3\n",
    "mlp_dim = 4 * att_dim\n",
    "\n",
    "type = 'Tiny'\n",
    "\n",
    "df_tiny_gpu = pd.DataFrame(columns=['Modello', 'Throughput', 'Flops', 'Params'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4RaWtHY0lOAp",
   "metadata": {
    "id": "4RaWtHY0lOAp",
    "outputId": "8e41d0a8-f917-4765-f511-d5759279c954"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "##### ViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### ATSViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### MultiViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_topk ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_evit ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_tome ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT_pruning ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT_gini ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### Transformer_DWTViT_quantile ####\n",
      "cuda:0\n",
      "cuda:0\n",
      "##### PatchMerger ####\n",
      "cuda:0\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Run the full model sweep for the Tiny configuration; the returned frame replaces df_tiny_gpu.\n",
    "df_tiny_gpu = get_results(\n",
    "    batch_size, patch_size, runs_time, runs_flops, input_size, img_size,\n",
    "    att_dim, depth, heads, mlp_dim,\n",
    "    max_tokens_per_depth, patch_merge_layers, tome_patches, DW_patches,\n",
    "    type, df_tiny_gpu, device\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "874b15d4-9a35-4938-bd53-d76df4c6956e",
   "metadata": {},
   "source": [
    "### Display results: Base, Small and Tiny GPU benchmarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3796e3b8-b49d-45c2-aafc-baa0e9881458",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Modello</th>\n",
       "      <th>Throughput</th>\n",
       "      <th>Flops</th>\n",
       "      <th>Params</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ViT_net_Base</td>\n",
       "      <td>793.98</td>\n",
       "      <td>8.6502</td>\n",
       "      <td>85629706.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ATS_Base</td>\n",
       "      <td>931.37</td>\n",
       "      <td>5.3217</td>\n",
       "      <td>85629706.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>MultiViT_Base</td>\n",
       "      <td>1027.11</td>\n",
       "      <td>6.0855</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Transformer_topk_Base</td>\n",
       "      <td>1015.16</td>\n",
       "      <td>6.0855</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Transformer_evit_Base</td>\n",
       "      <td>1029.55</td>\n",
       "      <td>6.1682</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Transformer_tome_Base</td>\n",
       "      <td>1024.95</td>\n",
       "      <td>6.0855</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Transformer_DWTViT_haar_Base</td>\n",
       "      <td>1144.23</td>\n",
       "      <td>4.6162</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>Transformer_DWTViT_sym2_Base</td>\n",
       "      <td>1117.55</td>\n",
       "      <td>4.7107</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>Transformer_DWTViT_db4_Base</td>\n",
       "      <td>1070.08</td>\n",
       "      <td>4.8595</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Transformer_DWTViT_db2_Base</td>\n",
       "      <td>1119.48</td>\n",
       "      <td>4.7107</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>Transformer_DWTViT_cA_Base</td>\n",
       "      <td>1138.14</td>\n",
       "      <td>4.6162</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>Transformer_DWTViT_cD_Base</td>\n",
       "      <td>1141.86</td>\n",
       "      <td>4.6162</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>Transformer_DWTViT_pruning_Base</td>\n",
       "      <td>1069.05</td>\n",
       "      <td>4.8595</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>Transformer_DWTViT_gini_Base</td>\n",
       "      <td>1219.03</td>\n",
       "      <td>4.6162</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>Transformer_DWTViT_quantile_Base</td>\n",
       "      <td>1052.55</td>\n",
       "      <td>4.6162</td>\n",
       "      <td>85626634.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>PatchMerger_Base</td>\n",
       "      <td>1104.02</td>\n",
       "      <td>5.8645</td>\n",
       "      <td>85634314.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                             Modello Throughput   Flops      Params\n",
       "0                       ViT_net_Base     793.98  8.6502  85629706.0\n",
       "1                           ATS_Base     931.37  5.3217  85629706.0\n",
       "2                      MultiViT_Base    1027.11  6.0855  85626634.0\n",
       "3              Transformer_topk_Base    1015.16  6.0855  85626634.0\n",
       "4              Transformer_evit_Base    1029.55  6.1682  85626634.0\n",
       "5              Transformer_tome_Base    1024.95  6.0855  85626634.0\n",
       "6       Transformer_DWTViT_haar_Base    1144.23  4.6162  85626634.0\n",
       "7       Transformer_DWTViT_sym2_Base    1117.55  4.7107  85626634.0\n",
       "8        Transformer_DWTViT_db4_Base    1070.08  4.8595  85626634.0\n",
       "9        Transformer_DWTViT_db2_Base    1119.48  4.7107  85626634.0\n",
       "10        Transformer_DWTViT_cA_Base    1138.14  4.6162  85626634.0\n",
       "11        Transformer_DWTViT_cD_Base    1141.86  4.6162  85626634.0\n",
       "12   Transformer_DWTViT_pruning_Base    1069.05  4.8595  85626634.0\n",
       "13      Transformer_DWTViT_gini_Base    1219.03  4.6162  85626634.0\n",
       "14  Transformer_DWTViT_quantile_Base    1052.55  4.6162  85626634.0\n",
       "15                  PatchMerger_Base    1104.02  5.8645  85634314.0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Modello</th>\n",
       "      <th>Throughput</th>\n",
       "      <th>Flops</th>\n",
       "      <th>Params</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ViT_net_Small</td>\n",
       "      <td>2026.76</td>\n",
       "      <td>2.1806</td>\n",
       "      <td>21581962.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ATS_Small</td>\n",
       "      <td>1718.16</td>\n",
       "      <td>1.3545</td>\n",
       "      <td>21581962.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>MultiViT_Small</td>\n",
       "      <td>2472.32</td>\n",
       "      <td>1.5381</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Transformer_topk_Small</td>\n",
       "      <td>2472.74</td>\n",
       "      <td>1.5381</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Transformer_evit_Small</td>\n",
       "      <td>2496.32</td>\n",
       "      <td>1.5588</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Transformer_tome_Small</td>\n",
       "      <td>2496.37</td>\n",
       "      <td>1.5381</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Transformer_DWTViT_haar_Small</td>\n",
       "      <td>2412.40</td>\n",
       "      <td>1.1703</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>Transformer_DWTViT_sym2_Small</td>\n",
       "      <td>2334.14</td>\n",
       "      <td>1.1940</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>Transformer_DWTViT_db4_Small</td>\n",
       "      <td>2308.58</td>\n",
       "      <td>1.2312</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Transformer_DWTViT_db2_Small</td>\n",
       "      <td>2391.03</td>\n",
       "      <td>1.1940</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>Transformer_DWTViT_cA_Small</td>\n",
       "      <td>2370.81</td>\n",
       "      <td>1.1703</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>Transformer_DWTViT_cD_Small</td>\n",
       "      <td>2424.17</td>\n",
       "      <td>1.1703</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>Transformer_DWTViT_pruning_Small</td>\n",
       "      <td>2321.16</td>\n",
       "      <td>1.2312</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>Transformer_DWTViT_gini_Small</td>\n",
       "      <td>2691.85</td>\n",
       "      <td>1.1703</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>Transformer_DWTViT_quantile_Small</td>\n",
       "      <td>2059.96</td>\n",
       "      <td>1.1703</td>\n",
       "      <td>21579658.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>PatchMerger_Small</td>\n",
       "      <td>2650.52</td>\n",
       "      <td>1.4832</td>\n",
       "      <td>21584266.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                              Modello Throughput   Flops      Params\n",
       "0                       ViT_net_Small    2026.76  2.1806  21581962.0\n",
       "1                           ATS_Small    1718.16  1.3545  21581962.0\n",
       "2                      MultiViT_Small    2472.32  1.5381  21579658.0\n",
       "3              Transformer_topk_Small    2472.74  1.5381  21579658.0\n",
       "4              Transformer_evit_Small    2496.32  1.5588  21579658.0\n",
       "5              Transformer_tome_Small    2496.37  1.5381  21579658.0\n",
       "6       Transformer_DWTViT_haar_Small    2412.40  1.1703  21579658.0\n",
       "7       Transformer_DWTViT_sym2_Small    2334.14  1.1940  21579658.0\n",
       "8        Transformer_DWTViT_db4_Small    2308.58  1.2312  21579658.0\n",
       "9        Transformer_DWTViT_db2_Small    2391.03  1.1940  21579658.0\n",
       "10        Transformer_DWTViT_cA_Small    2370.81  1.1703  21579658.0\n",
       "11        Transformer_DWTViT_cD_Small    2424.17  1.1703  21579658.0\n",
       "12   Transformer_DWTViT_pruning_Small    2321.16  1.2312  21579658.0\n",
       "13      Transformer_DWTViT_gini_Small    2691.85  1.1703  21579658.0\n",
       "14  Transformer_DWTViT_quantile_Small    2059.96  1.1703  21579658.0\n",
       "15                  PatchMerger_Small    2650.52  1.4832  21584266.0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Modello</th>\n",
       "      <th>Throughput</th>\n",
       "      <th>Flops</th>\n",
       "      <th>Params</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ViT_net_Tiny</td>\n",
       "      <td>4091.76</td>\n",
       "      <td>0.5543</td>\n",
       "      <td>5483338.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ATS_Tiny</td>\n",
       "      <td>2252.83</td>\n",
       "      <td>0.3444</td>\n",
       "      <td>5483338.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>MultiViT_Tiny</td>\n",
       "      <td>5053.16</td>\n",
       "      <td>0.3929</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Transformer_topk_Tiny</td>\n",
       "      <td>5100.24</td>\n",
       "      <td>0.3929</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Transformer_evit_Tiny</td>\n",
       "      <td>5414.91</td>\n",
       "      <td>0.3981</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Transformer_tome_Tiny</td>\n",
       "      <td>5193.41</td>\n",
       "      <td>0.3929</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Transformer_DWTViT_haar_Tiny</td>\n",
       "      <td>3530.08</td>\n",
       "      <td>0.3007</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>Transformer_DWTViT_sym2_Tiny</td>\n",
       "      <td>3400.96</td>\n",
       "      <td>0.3067</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>Transformer_DWTViT_db4_Tiny</td>\n",
       "      <td>3308.14</td>\n",
       "      <td>0.3160</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Transformer_DWTViT_db2_Tiny</td>\n",
       "      <td>3206.66</td>\n",
       "      <td>0.3067</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>Transformer_DWTViT_cA_Tiny</td>\n",
       "      <td>3540.06</td>\n",
       "      <td>0.3007</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>Transformer_DWTViT_cD_Tiny</td>\n",
       "      <td>3176.66</td>\n",
       "      <td>0.3007</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>Transformer_DWTViT_pruning_Tiny</td>\n",
       "      <td>3411.64</td>\n",
       "      <td>0.3160</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>Transformer_DWTViT_gini_Tiny</td>\n",
       "      <td>4110.73</td>\n",
       "      <td>0.3007</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>Transformer_DWTViT_quantile_Tiny</td>\n",
       "      <td>2858.91</td>\n",
       "      <td>0.3007</td>\n",
       "      <td>5481418.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>PatchMerger_Tiny</td>\n",
       "      <td>4612.93</td>\n",
       "      <td>0.3795</td>\n",
       "      <td>5484490.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                             Modello Throughput   Flops     Params\n",
       "0                       ViT_net_Tiny    4091.76  0.5543  5483338.0\n",
       "1                           ATS_Tiny    2252.83  0.3444  5483338.0\n",
       "2                      MultiViT_Tiny    5053.16  0.3929  5481418.0\n",
       "3              Transformer_topk_Tiny    5100.24  0.3929  5481418.0\n",
       "4              Transformer_evit_Tiny    5414.91  0.3981  5481418.0\n",
       "5              Transformer_tome_Tiny    5193.41  0.3929  5481418.0\n",
       "6       Transformer_DWTViT_haar_Tiny    3530.08  0.3007  5481418.0\n",
       "7       Transformer_DWTViT_sym2_Tiny    3400.96  0.3067  5481418.0\n",
       "8        Transformer_DWTViT_db4_Tiny    3308.14  0.3160  5481418.0\n",
       "9        Transformer_DWTViT_db2_Tiny    3206.66  0.3067  5481418.0\n",
       "10        Transformer_DWTViT_cA_Tiny    3540.06  0.3007  5481418.0\n",
       "11        Transformer_DWTViT_cD_Tiny    3176.66  0.3007  5481418.0\n",
       "12   Transformer_DWTViT_pruning_Tiny    3411.64  0.3160  5481418.0\n",
       "13      Transformer_DWTViT_gini_Tiny    4110.73  0.3007  5481418.0\n",
       "14  Transformer_DWTViT_quantile_Tiny    2858.91  0.3007  5481418.0\n",
       "15                  PatchMerger_Tiny    4612.93  0.3795  5484490.0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(None, None, None)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Render the Base, Small and Tiny GPU results tables. One display() call with several\n",
    "# arguments shows all three without echoing the (None, None, None) tuple that\n",
    "# `display(a), display(b), display(c)` produces as the cell's last expression.\n",
    "display(df_gpu, df_small_gpu, df_tiny_gpu)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "d58ed68e-186c-4630-b336-3c1f913b1553"
   ],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
