{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dc9aeda9-0fcd-4114-a24e-450d1edcda85",
   "metadata": {},
   "source": [
    "\n",
    "## Code tested with PyTorch 2.3.0 and timm 0.9.12\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fb00d22-bb4f-4f00-b4c7-32d175bd7bba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import copy\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import json\n",
    "from collections import deque\n",
    "from tqdm import tqdm\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "# NOTE(review): star import kept as-is; later cells may rely on names it brings in\n",
    "# (e.g. weight enums such as ResNet50_Weights) - confirm before narrowing it.\n",
    "from torchvision.models import *\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "\n",
    "import timm\n",
    "\n",
    "# Make the parent directory importable so that `utils` can be imported from it.\n",
    "# The membership check keeps re-runs of this cell from appending duplicate sys.path entries.\n",
    "parent_directory = os.path.abspath('..')\n",
    "if parent_directory not in sys.path:\n",
    "    sys.path.append(parent_directory)\n",
    "from utils import ECE, getTestData\n",
    "# code tested with pytorch 2.3.0, timm 0.9.12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b29c407",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib style shared by every figure in this notebook.\n",
    "plt.rc('font', size=10)\n",
    "plt.rc('axes', titlesize=10)\n",
    "plt.rc('axes', labelsize=10)\n",
    "plt.rc('xtick', labelsize=10)\n",
    "plt.rc('ytick', labelsize=10)\n",
    "plt.rc('legend', fontsize=10)\n",
    "plt.rc('legend', title_fontsize=10)\n",
    "plt.rc('figure', titlesize=10)\n",
    "\n",
    "plt.rc('legend', framealpha=0.0)\n",
    "plt.rc('lines', linewidth=1.5)\n",
    "\n",
    "# LaTeX rendering is off, so the preamble below only matters if usetex is switched on.\n",
    "plt.rcParams['text.usetex'] = False\n",
    "plt.rcParams['text.latex.preamble'] = r'\\usepackage{times}'\n",
    "# NOTE(review): font.serif is only consulted when font.family is 'serif'; with the family\n",
    "# set to 'Times New Roman' the Computer Modern entry likely has no effect - confirm intent.\n",
    "plt.rcParams['font.family'] = \"Times New Roman\"\n",
    "plt.rcParams['font.serif'] = ['Computer Modern']\n",
    "\n",
    "matplotlib.rcParams['figure.dpi'] = 150\n",
    "\n",
    "def colorize(dim_list, colormap):\n",
    "    \"\"\"Return len(dim_list) colours sampled evenly from colormap over the range [0.0, 0.92].\"\"\"\n",
    "    return [colormap(i) for i in np.linspace(0.0, 0.92, len(dim_list))]\n",
    "\n",
    "# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap works on old and new versions.\n",
    "cm = plt.get_cmap('Spectral')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8081e91f",
   "metadata": {},
   "source": [
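    "### Change folder names for V2 (disabled: stored in a markdown cell so it is not executed)\n",
    "Zero-pads every class folder name to 4 digits (e.g. `7` -> `0007`) so that sorting the folder names as strings matches their numeric order.\n",
    "```python\n",
    "import glob\n",
    "for path in glob.glob('D:/data/v2/'):\n",
    "    if os.path.isdir(path):\n",
    "        print(path)\n",
    "        for subpath in glob.glob(f'{path}/*'):\n",
    "            print(subpath)\n",
    "            dirname = os.path.basename(subpath)\n",
    "            os.rename(subpath, os.path.sep.join([os.path.dirname(subpath), dirname.zfill(4)]))\n",
    "```"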
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c54255b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load real.json (presumably the ImageNet ReaL relabelling; the next cell indexes it by image number - 1).\n",
    "# The context manager closes the file even if json.load raises; JSON is UTF-8 by specification.\n",
    "real_json_path = 'real.json'\n",
    "with open(real_json_path, encoding='utf-8') as f:\n",
    "    real = json.load(f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c38ddd5d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os, glob\n",
    "dataset_path = '/mnt/data2/data/in1k/val/'\n",
    "# Parallel lists with one entry per validation image:\n",
    "#   val_indices - image number parsed from the file name\n",
    "#   val_labels  - class id = position of the image's class folder in sorted order\n",
    "#   real_labels - entry of `real` (from real.json) for that image\n",
    "val_indices = []\n",
    "val_labels = []\n",
    "real_labels = []\n",
    "for path in glob.glob(dataset_path):\n",
    "    if os.path.isdir(path):\n",
    "        # Only directories are classes; a stray file here would otherwise make os.listdir raise.\n",
    "        class_dirs = [p for p in sorted(glob.glob(f'{path}/*')) if os.path.isdir(p)]\n",
    "        for class_id, subpath in enumerate(class_dirs):\n",
    "            for file in sorted(os.listdir(subpath)):\n",
    "                if file.endswith('.JPEG'):\n",
    "                    # The 8 characters before '.JPEG' are the image number\n",
    "                    # (presumably ILSVRC2012_val_XXXXXXXX.JPEG); `real` is 0-indexed, hence num-1.\n",
    "                    num = int(file[-13:-5])\n",
    "                    val_indices.append(num)\n",
    "                    val_labels.append(class_id)\n",
    "                    real_labels.append(real[num-1])\n",
    "if not real_labels:\n",
    "    print(f'Warning: no .JPEG images found under {dataset_path}')\n",
    "# Sanity check: show the labels of the first few image ids.\n",
    "for idx, v, r in zip(val_indices, val_labels, real_labels):\n",
    "    if idx < 5:\n",
    "        print(f'Image ID: {idx:5d} class: {v:4d}, real label: {r}')\n",
    "print(len(real_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6556ca30-aabf-4855-8841-894f1d4ae78f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare models\n",
    "#\n",
    "# Each entry maps a run name to a config dict with the keys:\n",
    "#   model            architecture identifier within the library named by `source`\n",
    "#   source           'torchvision', 'timm' or 'internimage'\n",
    "#   macs_official,   reference MAC and parameter counts (the efficientnet_b7_380/_456\n",
    "#   params_official  MACs are the 600px figure scaled by input area)\n",
    "#   weights          torchvision weight name as a string; the string 'None' for other sources\n",
    "#   resize, crop     eval image sizes (presumably resize, then center crop - TODO confirm in getTestData)\n",
    "#   batch_size       batch size for evaluating this model\n",
    "#   interpolation    resampling mode, presumably used for the resize\n",
    "#   cfg, ckpt        internimage only: config and checkpoint paths under INTERNIMAGE_DIR\n",
    "\n",
    "# Root of the local InternImage checkout; override with the INTERNIMAGE_DIR environment variable.\n",
    "INTERNIMAGE_DIR = os.environ.get('INTERNIMAGE_DIR', '/mnt/data1/GitHub/InternImage/classification')\n",
    "\n",
    "models = {\n",
    "    # torchvision AlexNet / ResNets\n",
    "    'alexnet':{\n",
    "        'model':'alexnet',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':0.71e9,\n",
    "        'params_official':61.1e6,\n",
    "        'weights':'AlexNet_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'resnet18_v1':{\n",
    "        'model':'resnet18',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':1.81e9,\n",
    "        'params_official':11.7e6,\n",
    "        'weights':'ResNet18_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'resnet34_v1':{\n",
    "        'model':'resnet34',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':3.66e9,\n",
    "        'params_official':21.8e6,\n",
    "        'weights':'ResNet34_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'resnet50_v1':{\n",
    "        'model':'resnet50',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':4.09e9,\n",
    "        'params_official':25.6e6,\n",
    "        'weights':'ResNet50_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'resnet50_v2':{\n",
    "        'model':'resnet50',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':4.09e9,\n",
    "        'params_official':25.6e6,\n",
    "        'weights':'ResNet50_Weights.IMAGENET1K_V2',\n",
    "        'resize':232,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'resnet101_v1':{\n",
    "        'model':'resnet101',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':7.8e9,\n",
    "        'params_official':44.5e6,\n",
    "        'weights':'ResNet101_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'resnet101_v2':{\n",
    "        'model':'resnet101',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':7.8e9,\n",
    "        'params_official':44.5e6,\n",
    "        'weights':'ResNet101_Weights.IMAGENET1K_V2',\n",
    "        'resize':232,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'resnet152_v1':{\n",
    "        'model':'resnet152',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':11.51e9,\n",
    "        'params_official':60.2e6,\n",
    "        'weights':'ResNet152_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'resnet152_v2':{\n",
    "        'model':'resnet152',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':11.51e9,\n",
    "        'params_official':60.2e6,\n",
    "        'weights':'ResNet152_Weights.IMAGENET1K_V2',\n",
    "        'resize':232,\n",
    "        'crop':224,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    # Pytorch Convnexts\n",
    "    'convnext_t':{\n",
    "        'model':'convnext_tiny',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':4.46e9,\n",
    "        'params_official':28.6e6,\n",
    "        'weights':'ConvNeXt_Tiny_Weights.IMAGENET1K_V1',\n",
    "        'resize':236,\n",
    "        'crop':224,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'convnext_s':{\n",
    "        'model':'convnext_small',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':8.68e9,\n",
    "        'params_official':50.2e6,\n",
    "        'weights':'ConvNeXt_Small_Weights.IMAGENET1K_V1',\n",
    "        'resize':230,\n",
    "        'crop':224,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'convnext_b':{\n",
    "        'model':'convnext_base',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':15.36e9,\n",
    "        'params_official':88.6e6,\n",
    "        'weights':'ConvNeXt_Base_Weights.IMAGENET1K_V1',\n",
    "        'resize':232,\n",
    "        'crop':224,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'convnext_l':{\n",
    "        'model':'convnext_large',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':34.36e9,\n",
    "        'params_official':197.8e6,\n",
    "        'weights':'ConvNeXt_Large_Weights.IMAGENET1K_V1',\n",
    "        'resize':232,\n",
    "        'crop':224,\n",
    "        'batch_size':8,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    # Pytorch Efficientnets\n",
    "    'efficientnet_b0':{\n",
    "        'model':'efficientnet_b0',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':0.39e9,\n",
    "        'params_official':5.3e6,\n",
    "        'weights':'EfficientNet_B0_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientnet_b1':{\n",
    "        'model':'efficientnet_b1',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':0.69e9,\n",
    "        'params_official':7.8e6,\n",
    "        'weights':'EfficientNet_B1_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':240,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientnet_b2':{\n",
    "        'model':'efficientnet_b2',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':1.09e9,\n",
    "        'params_official':9.1e6,\n",
    "        'weights':'EfficientNet_B2_Weights.IMAGENET1K_V1',\n",
    "        'resize':288,\n",
    "        'crop':288,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientnet_b3':{\n",
    "        'model':'efficientnet_b3',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':1.83e9,\n",
    "        'params_official':12.2e6,\n",
    "        'weights':'EfficientNet_B3_Weights.IMAGENET1K_V1',\n",
    "        'resize':320,\n",
    "        'crop':300,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientnet_b4':{\n",
    "        'model':'efficientnet_b4',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':4.39e9,\n",
    "        'params_official':19.3e6,\n",
    "        'weights':'EfficientNet_B4_Weights.IMAGENET1K_V1',\n",
    "        'resize':384,\n",
    "        'crop':380,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientnet_b5':{\n",
    "        'model':'efficientnet_b5',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':10.27e9,\n",
    "        'params_official':30.4e6,\n",
    "        'weights':'EfficientNet_B5_Weights.IMAGENET1K_V1',\n",
    "        'resize':456,\n",
    "        'crop':456,\n",
    "        'batch_size':8,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientnet_b6':{\n",
    "        'model':'efficientnet_b6',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':19.07e9,\n",
    "        'params_official':43.0e6,\n",
    "        'weights':'EfficientNet_B6_Weights.IMAGENET1K_V1',\n",
    "        'resize':528,\n",
    "        'crop':528,\n",
    "        'batch_size':4,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    # B7 evaluated below its native 600px; MACs scaled by the ratio of input areas.\n",
    "    'efficientnet_b7_380':{\n",
    "        'model':'efficientnet_b7',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':37.75e9*(380/600)**2,\n",
    "        'params_official':66.3e6,\n",
    "        'weights':'EfficientNet_B7_Weights.IMAGENET1K_V1',\n",
    "        'resize':384,\n",
    "        'crop':380,\n",
    "        'batch_size':4,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientnet_b7_456':{\n",
    "        'model':'efficientnet_b7',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':37.75e9*(456/600)**2,\n",
    "        'params_official':66.3e6,\n",
    "        'weights':'EfficientNet_B7_Weights.IMAGENET1K_V1',\n",
    "        'resize':456,\n",
    "        'crop':456,\n",
    "        'batch_size':4,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientnet_b7':{\n",
    "        'model':'efficientnet_b7',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':37.75e9,\n",
    "        'params_official':66.3e6,\n",
    "        'weights':'EfficientNet_B7_Weights.IMAGENET1K_V1',\n",
    "        'resize':600,\n",
    "        'crop':600,\n",
    "        'batch_size':4,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    # torch swin\n",
    "    'swin_t':{\n",
    "        'model':'swin_t',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':4.49e9,\n",
    "        'params_official':28.3e6,\n",
    "        'weights':'Swin_T_Weights.IMAGENET1K_V1',\n",
    "        'resize':232,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'swin_s':{\n",
    "        'model':'swin_s',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':8.74e9,\n",
    "        'params_official':49.6e6,\n",
    "        'weights':'Swin_S_Weights.IMAGENET1K_V1',\n",
    "        'resize':246,\n",
    "        'crop':224,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'swin_b':{\n",
    "        'model':'swin_b',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':15.43e9,\n",
    "        'params_official':87.8e6,\n",
    "        'weights':'Swin_B_Weights.IMAGENET1K_V1',\n",
    "        'resize':238,\n",
    "        'crop':224,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'swin_v2_t':{\n",
    "        'model':'swin_v2_t',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':5.94e9,\n",
    "        'params_official':28.4e6,\n",
    "        'weights':'Swin_V2_T_Weights.IMAGENET1K_V1',\n",
    "        'resize':260,\n",
    "        'crop':256,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'swin_v2_s':{\n",
    "        'model':'swin_v2_s',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':11.55e9,\n",
    "        'params_official':49.7e6,\n",
    "        'weights':'Swin_V2_S_Weights.IMAGENET1K_V1',\n",
    "        'resize':260,\n",
    "        'crop':256,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'swin_v2_b':{\n",
    "        'model':'swin_v2_b',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':20.32e9,\n",
    "        'params_official':87.9e6,\n",
    "        'weights':'Swin_V2_B_Weights.IMAGENET1K_V1',\n",
    "        'resize':272,\n",
    "        'crop':256,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    # torch ViTs\n",
    "    'vit_b_32':{\n",
    "        'model':'vit_b_32',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':4.41e9,\n",
    "        'params_official':88.2e6,\n",
    "        'weights':'ViT_B_32_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'vit_l_32':{\n",
    "        'model':'vit_l_32',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':15.38e9,\n",
    "        'params_official':306.5e6,\n",
    "        'weights':'ViT_L_32_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'vit_b_16_224':{\n",
    "        'model':'vit_b_16',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':17.56e9,\n",
    "        'params_official':86.6e6,\n",
    "        'weights':'ViT_B_16_Weights.IMAGENET1K_V1',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'vit_b_16_384':{\n",
    "        'model':'vit_b_16',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':55.48e9,\n",
    "        'params_official':86.9e6,\n",
    "        'weights':'ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1',\n",
    "        'resize':384,\n",
    "        'crop':384,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'vit_l_16_224':{\n",
    "        'model':'vit_l_16',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':61.55e9,\n",
    "        'params_official':304.3e6,\n",
    "        'weights':'ViT_L_16_Weights.IMAGENET1K_V1',\n",
    "        'resize':242,\n",
    "        'crop':224,\n",
    "        'batch_size':8,\n",
    "        'interpolation':InterpolationMode.BILINEAR},\n",
    "\n",
    "    'vit_l_16_512':{\n",
    "        'model':'vit_l_16',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':361.99e9,\n",
    "        'params_official':305.2e6,\n",
    "        'weights':'ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1',\n",
    "        'resize':512,\n",
    "        'crop':512,\n",
    "        'batch_size':8,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'vit_h_14_518':{\n",
    "        'model':'vit_h_14',\n",
    "        'source':'torchvision',\n",
    "        'macs_official':1016.72e9,\n",
    "        'params_official':633.5e6,\n",
    "        'weights':'ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1',\n",
    "        'resize':518,\n",
    "        'crop':518,\n",
    "        'batch_size':2,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    # timm DEITs\n",
    "    'deit_t':{\n",
    "        'model':'deit_tiny_patch16_224',\n",
    "        'source':'timm',\n",
    "        'macs_official':1.3e9,\n",
    "        'params_official':5.7e6,\n",
    "        'weights':'None',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'deit_s':{\n",
    "        'model':'deit_small_patch16_224',\n",
    "        'source':'timm',\n",
    "        'macs_official':4.6e9,\n",
    "        'params_official':22.1e6,\n",
    "        'weights':'None',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'deit_b':{\n",
    "        'model':'deit_base_patch16_224',\n",
    "        'source':'timm',\n",
    "        'macs_official':17.6e9,\n",
    "        'params_official':86.6e6,\n",
    "        'weights':'None',\n",
    "        'resize':256,\n",
    "        'crop':224,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    # timm DEIT3s\n",
    "    'deit3_s_224_v2':{\n",
    "        'model':'deit3_small_patch16_224_in21ft1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':4.6e9,\n",
    "        'params_official':22.1e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'deit3_b_224_v2':{\n",
    "        'model':'deit3_base_patch16_224_in21ft1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':17.6e9,\n",
    "        'params_official':86.6e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'deit3_l_224_v2':{\n",
    "        'model':'deit3_large_patch16_224_in21ft1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':61.6e9,\n",
    "        'params_official':304.4e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'deit3_l_384_v2':{\n",
    "        'model':'deit3_large_patch16_384_in21ft1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':191.2e9,\n",
    "        'params_official':304.8e6,\n",
    "        'weights':'None',\n",
    "        'resize':384,\n",
    "        'crop':384,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'deit3_h_224_v2':{\n",
    "        'model':'deit3_huge_patch14_224_in21ft1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':167.4e9,\n",
    "        'params_official':632.1e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    # timm efficientvits\n",
    "    'efficientvit_b0_224':{\n",
    "        'model':'efficientvit_b0.r224_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':0.1e9,\n",
    "        'params_official':3.4e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_b1_224':{\n",
    "        'model':'efficientvit_b1.r224_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':0.52e9,\n",
    "        'params_official':9.1e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':128,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_b1_256':{\n",
    "        'model':'efficientvit_b1.r256_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':0.68e9,\n",
    "        'params_official':9.1e6,\n",
    "        'weights':'None',\n",
    "        'resize':256,\n",
    "        'crop':256,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_b1_288':{\n",
    "        'model':'efficientvit_b1.r288_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':0.86e9,\n",
    "        'params_official':9.1e6,\n",
    "        'weights':'None',\n",
    "        'resize':288,\n",
    "        'crop':288,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_b2_224':{\n",
    "        'model':'efficientvit_b2.r224_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':1.6e9,\n",
    "        'params_official':24e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_b2_256':{\n",
    "        'model':'efficientvit_b2.r256_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':2.1e9,\n",
    "        'params_official':24e6,\n",
    "        'weights':'None',\n",
    "        'resize':256,\n",
    "        'crop':256,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_b2_288':{\n",
    "        'model':'efficientvit_b2.r288_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':2.6e9,\n",
    "        'params_official':24e6,\n",
    "        'weights':'None',\n",
    "        'resize':288,\n",
    "        'crop':288,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_b3_224':{\n",
    "        'model':'efficientvit_b3.r224_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':4.0e9,\n",
    "        'params_official':49e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_b3_256':{\n",
    "        'model':'efficientvit_b3.r256_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':5.2e9,\n",
    "        'params_official':49e6,\n",
    "        'weights':'None',\n",
    "        'resize':256,\n",
    "        'crop':256,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_b3_288':{\n",
    "        'model':'efficientvit_b3.r288_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':6.5e9,\n",
    "        'params_official':49e6,\n",
    "        'weights':'None',\n",
    "        'resize':288,\n",
    "        'crop':288,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_l1_224':{\n",
    "        'model':'efficientvit_l1.r224_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':5.3e9,\n",
    "        'params_official':53e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_l2_224':{\n",
    "        'model':'efficientvit_l2.r224_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':6.9e9,\n",
    "        'params_official':64e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':64,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_l2_256':{\n",
    "        'model':'efficientvit_l2.r256_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':9.19e9,\n",
    "        'params_official':64e6,\n",
    "        'weights':'None',\n",
    "        'resize':256,\n",
    "        'crop':256,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_l2_288':{\n",
    "        'model':'efficientvit_l2.r288_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':11e9,\n",
    "        'params_official':64e6,\n",
    "        'weights':'None',\n",
    "        'resize':288,\n",
    "        'crop':288,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_l2_384':{\n",
    "        'model':'efficientvit_l2.r384_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':20e9,\n",
    "        'params_official':64e6,\n",
    "        'weights':'None',\n",
    "        'resize':384,\n",
    "        'crop':384,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_l3_224':{\n",
    "        'model':'efficientvit_l3.r224_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':28e9,\n",
    "        'params_official':246e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_l3_256':{\n",
    "        'model':'efficientvit_l3.r256_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':36e9,\n",
    "        'params_official':246e6,\n",
    "        'weights':'None',\n",
    "        'resize':256,\n",
    "        'crop':256,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_l3_320':{\n",
    "        'model':'efficientvit_l3.r320_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':56e9,\n",
    "        'params_official':246e6,\n",
    "        'weights':'None',\n",
    "        'resize':320,\n",
    "        'crop':320,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_l3_384':{\n",
    "        'model':'efficientvit_l3.r384_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':81e9,\n",
    "        'params_official':246e6,\n",
    "        'weights':'None',\n",
    "        'resize':384,\n",
    "        'crop':384,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'efficientvit_m5_224':{\n",
    "        'model':'efficientvit_m5.r224_in1k',\n",
    "        'source':'timm',\n",
    "        'macs_official':0.5e9,\n",
    "        'params_official':12.5e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':16,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    # InternImage, loaded from the local checkout at INTERNIMAGE_DIR\n",
    "    'internimage_t':{\n",
    "        'model':'internimage_t_224',\n",
    "        'cfg':f'{INTERNIMAGE_DIR}/configs/internimage_t_1k_224.yaml',\n",
    "        'ckpt':f'{INTERNIMAGE_DIR}/weights/internimage_t_1k_224.pth',\n",
    "        'source':'internimage',\n",
    "        'macs_official':5.0e9,\n",
    "        'params_official':30.9e6,\n",
    "        'weights':'None',\n",
    "        'resize':224,\n",
    "        'crop':224,\n",
    "        'batch_size':32,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'internimage_l':{\n",
    "        'model':'internimage_l_384',\n",
    "        'cfg':f'{INTERNIMAGE_DIR}/configs/internimage_l_22kto1k_384.yaml',\n",
    "        'ckpt':f'{INTERNIMAGE_DIR}/weights/internimage_l_22kto1k_384.pth',\n",
    "        'source':'internimage',\n",
    "        'macs_official':108e9,\n",
    "        'params_official':223e6,\n",
    "        'weights':'None',\n",
    "        'resize':384,\n",
    "        'crop':384,\n",
    "        'batch_size':4,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'internimage_xl':{\n",
    "        'model':'internimage_xl_384',\n",
    "        'cfg':f'{INTERNIMAGE_DIR}/configs/internimage_xl_22kto1k_384.yaml',\n",
    "        'ckpt':f'{INTERNIMAGE_DIR}/weights/internimage_xl_22kto1k_384.pth',\n",
    "        'source':'internimage',\n",
    "        'macs_official':163e9,\n",
    "        'params_official':335e6,\n",
    "        'weights':'None',\n",
    "        'resize':384,\n",
    "        'crop':384,\n",
    "        'batch_size':4,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "\n",
    "    'internimage_g':{\n",
    "        'model':'internimage_g_512',\n",
    "        'cfg':f'{INTERNIMAGE_DIR}/configs/internimage_g_22kto1k_512.yaml',\n",
    "        'ckpt':f'{INTERNIMAGE_DIR}/weights/internimage_g_22kto1k_512.pth',\n",
    "        'source':'internimage',\n",
    "        'macs_official':2700e9,\n",
    "        'params_official':3076e6,\n",
    "        'weights':'None',\n",
    "        'resize':512,\n",
    "        'crop':512,\n",
    "        'batch_size':4,\n",
    "        'interpolation':InterpolationMode.BICUBIC},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e8a785f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_single_save(net, model_dict, data_path='/mnt/data3/data/in1k', dataset='in1k', amp=True):\n",
    "    \"\"\"Evaluate `net` on a single GPU and collect logits, labels, calibration stats and timings.\n",
    "\n",
    "    Args:\n",
    "        net: classifier to evaluate; it is switched to eval mode here and is expected\n",
    "            to already be on the CUDA device (inputs are moved with `.cuda()`).\n",
    "        model_dict: model config entry; reads 'resize', 'crop', 'interpolation',\n",
    "            'batch_size', 'macs_official' and 'params_official'.\n",
    "        data_path: dataset root passed to `getTestData`.\n",
    "        dataset: currently unused; kept for call-site compatibility.\n",
    "        amp: run the forward pass under `torch.autocast` (bfloat16 if supported, else float16).\n",
    "\n",
    "    Returns:\n",
    "        (ece_results, labels_all, preds_all): the dict returned by `ECE(...)` extended with\n",
    "        'net_times', 'batch_sizes' and 'model' (a copy of `model_dict` without\n",
    "        'interpolation'), the labels as a float array with one row per sample, and the\n",
    "        stacked float logits (one row per sample).\n",
    "    \"\"\"\n",
    "    divider = 1\n",
    "    testset = getTestData(data_path,\n",
    "                          resize_size=model_dict['resize'],\n",
    "                          crop_size=model_dict['crop'],\n",
    "                          interpolation=model_dict['interpolation'])\n",
    "    dataloader = DataLoader(\n",
    "        testset,\n",
    "        batch_size=max(model_dict['batch_size']//divider, 1),\n",
    "        num_workers=16,\n",
    "        shuffle=False,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "\n",
    "    print('{:<30}  {:<8}'.format('Computational complexity: ', model_dict['macs_official']))\n",
    "    print('{:<30}  {:<8}'.format('Number of parameters: ', model_dict['params_official']))\n",
    "\n",
    "    for key in model_dict.keys():\n",
    "        print(f'{key:20s}: {model_dict[key]}')\n",
    "\n",
    "    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\n",
    "    net.eval()\n",
    "    # Per-batch results are collected in lists and concatenated once at the end;\n",
    "    # np.vstack inside the loop would re-copy everything seen so far (quadratic).\n",
    "    preds_batches = []\n",
    "    labels_batches = []\n",
    "    net_times = []\n",
    "    batch_sizes = []\n",
    "    t_start = time.time()\n",
    "\n",
    "    # Inference only: no_grad avoids building (and holding) an autograd graph per batch.\n",
    "    with torch.no_grad():\n",
    "        for batch, (data, target) in (pbar := tqdm(enumerate(dataloader), total=len(dataloader))):\n",
    "            data, target = data.cuda(), target.int().cuda()\n",
    "            t_st = time.time()\n",
    "            with torch.autocast(device_type=\"cuda\", dtype=dtype, enabled=amp):\n",
    "                output = net(data)\n",
    "                # CUDA kernels run asynchronously; wait so t_net covers the actual forward pass.\n",
    "                torch.cuda.synchronize()\n",
    "                t_net = time.time()\n",
    "                # per-sample cross-entropy: logsumexp(logits) - logit of the target class\n",
    "                indiv_loss = -output[range(target.shape[0]), target] + output.logsumexp(dim=1)\n",
    "                loss = indiv_loss.mean()\n",
    "\n",
    "            acc = (output.argmax(dim=1) == target).float().mean()\n",
    "            t_en = time.time()\n",
    "\n",
    "            pbar.set_description(\n",
    "                f\"loss: {loss.item():.3f} acc: {acc.item():.3f} \" +\n",
    "                f'  time: {t_en-t_start}'\n",
    "            )\n",
    "\n",
    "            preds_batches.append(output.float().cpu().numpy())\n",
    "            labels_batches.append(target.view(target.size(0), -1).float().cpu().numpy())\n",
    "            # The first batch is excluded from the timing lists, presumably to skip CUDA warm-up.\n",
    "            if batch > 0:\n",
    "                net_times.append(t_net - t_st)\n",
    "                batch_sizes.append(len(target))\n",
    "\n",
    "    preds_all = np.concatenate(preds_batches, axis=0) if preds_batches else None\n",
    "    labels_all = np.concatenate(labels_batches, axis=0) if labels_batches else None\n",
    "\n",
    "    ece_results = ECE(labels_all, preds_all, 11)\n",
    "    ece_results['net_times'] = net_times\n",
    "    ece_results['batch_sizes'] = batch_sizes\n",
    "    model_info = copy.deepcopy(model_dict)\n",
    "    # presumably dropped because InterpolationMode is not JSON-serialisable (see save_output)\n",
    "    model_info.pop('interpolation', None)\n",
    "    ece_results['model'] = model_info\n",
    "    return ece_results, labels_all, preds_all\n",
    "\n",
    "def load_output(model_name,folder='output', dataset=''):\n",
    "    \"\"\"Load and return the parsed `{model_name}{dataset}_summary.json` file from `folder`.\"\"\"\n",
    "    file = os.path.join(folder, model_name+dataset+'_summary.json')\n",
    "    print(file)\n",
    "    # `with` closes the handle even when json.load raises on a corrupt file\n",
    "    with open(file) as f:\n",
    "        return json.load(f)\n",
    "\n",
    "def save_output(model, summary, labels=None, preds=None, dataset='in1k'):\n",
    "    \"\"\"Write evaluation results for `model` to the `output/` directory.\n",
    "\n",
    "    Always writes `output/{model}{dataset}_summary.json` containing `summary`. If both\n",
    "    `labels` and `preds` are given (arrays with `.tolist()`), also writes\n",
    "    `output/{model}{dataset}_raw.json` with keys 'predictions' and 'labels'.\n",
    "    The directory is created if needed. Each object is serialised before its file is\n",
    "    opened, so a serialisation error cannot leave a truncated file behind (the eval\n",
    "    loop treats existing files as finished results).\n",
    "    \"\"\"\n",
    "    print('saving output')\n",
    "    os.makedirs('output', exist_ok=True)\n",
    "    if labels is not None and preds is not None:\n",
    "        results_raw = {'predictions': preds.tolist(),\n",
    "                       'labels': labels.tolist()}\n",
    "        json_object = json.dumps(results_raw)\n",
    "        raw_file = os.path.join('output', model+dataset+'_raw.json')\n",
    "        with open(raw_file, \"w\") as outfile:\n",
    "            outfile.write(json_object)\n",
    "        print(f'raw results saved {raw_file}')\n",
    "\n",
    "    summary_object = json.dumps(summary)\n",
    "    summary_file = os.path.join('output', model+dataset+'_summary.json')\n",
    "    with open(summary_file, \"w\") as outfile:\n",
    "        outfile.write(summary_object)\n",
    "    print(f'summary saved {summary_file}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07d43987-d5c6-4120-bef8-19a4026dbcc8",
   "metadata": {},
   "source": [
    "## Evaluate all models and save the output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0270c1c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "models_list = list(models.keys())\n",
    "print(models_list)\n",
    "\n",
    "dataset = ''\n",
    "dataset_path = '/mnt/data2/data/in1k/'\n",
    "\n",
    "#dataset = 'inv2'\n",
    "#dataset_path = '/mnt/data2/data/in1k/v2/'\n",
    "\n",
    "force = 0 # set to 1 to overwrite previously saved results\n",
    "for model in models_list[::1]:\n",
    "    print(model)\n",
    "    model_dict = models[model]\n",
    "    raw_file = os.path.join('output', model+dataset+'_raw.json')\n",
    "    summary_file = os.path.join('output', model+dataset+'_summary.json')\n",
    "    if os.path.isfile(raw_file) and os.path.isfile(summary_file) and not force:\n",
    "        # outer f-string uses double quotes so this also parses on Python < 3.12\n",
    "        print(f\"existing file found, skipping eval of {model}, MACs: {model_dict['macs_official']}\")\n",
    "        summary = load_output(model,dataset=dataset)\n",
    "        continue\n",
    "\n",
    "    if model_dict['source'] == 'torchvision':\n",
    "        exp = f\"{model_dict['model']}(weights={model_dict['weights']}).cuda()\"\n",
    "        net = eval(exp)\n",
    "    elif model_dict['source'] == 'timm':\n",
    "        net = timm.create_model(model_dict['model'], pretrained=True).cuda()\n",
    "    else:\n",
    "        # e.g. 'internimage' entries need a dedicated cfg/ckpt loader that this loop does\n",
    "        # not implement; falling through would raise NameError on `net` (deleted below).\n",
    "        print(f\"unsupported model source {model_dict['source']!r}, skipping {model}\")\n",
    "        continue\n",
    "    summary, labels, preds = eval_single_save(net, model_dict, amp=False, data_path=dataset_path, dataset=dataset)\n",
    "    del net\n",
    "    torch.cuda.empty_cache()\n",
    "    acc = summary['accuracy']\n",
    "    print(f'Accuracy: {acc:.5f}')\n",
    "    save_output(model, summary, labels, preds, dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8866dd3a-b250-4a9d-ad10-faaa22261eb5",
   "metadata": {},
   "source": [
    "## Load saved results and evaluate Little-Big pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e631425-c615-4160-8424-fdd2810898e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "vits          = ['vit_b_32','vit_l_32','vit_b_16_224','vit_b_16_384','vit_l_16_512','vit_h_14_518']\n",
    "\n",
    "efficientnets = ['efficientnet_b0','efficientnet_b1','efficientnet_b2',\n",
    "                 'efficientnet_b3','efficientnet_b4','efficientnet_b5',\n",
    "                 'efficientnet_b6','efficientnet_b7']\n",
    "\n",
    "efficientnet_b7 = ['efficientnet_b7_380','efficientnet_b7_456','efficientnet_b7']\n",
    "\n",
    "efficientvit_bs = ['efficientvit_b0_224',\n",
    "                 'efficientvit_b1_224','efficientvit_b1_256','efficientvit_b1_288',\n",
    "                 'efficientvit_b2_224','efficientvit_b2_256','efficientvit_b2_288',\n",
    "                 'efficientvit_b3_224','efficientvit_b3_256','efficientvit_b3_288',\n",
    "                 ]\n",
    "\n",
    "efficientvit_ls = ['efficientvit_l1_224',\n",
    "                 'efficientvit_l2_224','efficientvit_l2_256','efficientvit_l2_288',\n",
    "                 'efficientvit_l2_384',\n",
    "                 'efficientvit_l3_224','efficientvit_l3_256',\n",
    "                 'efficientvit_l3_320','efficientvit_l3_384',\n",
    "                 ]\n",
    "\n",
    "deit3s          = ['deit3_s_224_v2', 'deit3_b_224_v2', 'deit3_l_224_v2', \n",
    "                   'deit3_h_224_v2', 'deit3_l_384_v2']\n",
    "swin_v1         = ['swin_t', 'swin_s', 'swin_b']\n",
    "swin_v2         = ['swin_v2_t', 'swin_v2_s', 'swin_v2_b']\n",
    "\n",
    "deits         = ['deit_t', 'deit_s', 'deit_b']\n",
    "convnexts     = ['convnext_t', 'convnext_s', 'convnext_b', 'convnext_l']\n",
    "mix           = ['efficientvit_b0_224','efficientvit_b1_224', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2']\n",
    "\n",
    "mix2          = ['efficientvit_b0_224','efficientvit_b1_224', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'deit_s']\n",
    "\n",
    "mix3          = ['efficientvit_b0_224','efficientvit_b1_224', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'deit3_s_224_v2']\n",
    "\n",
    "group = efficientnets[:]\n",
    "\n",
    "# Load ImageNet-1k and ImageNet-V2 summaries for every model of `group`; the first\n",
    "# model is the reference for the printed accuracy gain and relative MACs cost.\n",
    "data_v1, data_v2 = [], []\n",
    "for i, m in enumerate(group):\n",
    "    v1 = load_output(m)\n",
    "    v2 = load_output(m, dataset='inv2')\n",
    "    if i == 0:\n",
    "        base_acc = v1['accuracy']\n",
    "        base_macs = v1['model']['macs_official']\n",
    "        print(base_acc)\n",
    "    data_v1.append(v1)\n",
    "    data_v2.append(v2)\n",
    "    # plain locals instead of nested same-quote f-strings (those need Python >= 3.12)\n",
    "    acc_v1, macs_v1 = v1['accuracy'], v1['model']['macs_official']\n",
    "    print('='*20)\n",
    "    print(macs_v1/1e9)\n",
    "    print(acc_v1)\n",
    "    print(f'acc improvement: {acc_v1-base_acc:.4f}')\n",
    "    print(f'additional Macs cost: {(macs_v1-base_macs)/base_macs:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12e19e28",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def big_little(little, big, threshold=0.9, p=1, real_labels=real_labels):\n",
    "    \"\"\"Simulate a Little-Big cascade from two saved evaluation summaries.\n",
    "\n",
    "    Every sample is first classified by the little model. If its confidence is\n",
    "    >= `threshold` the little prediction is kept; otherwise the big model's prediction\n",
    "    is used. Accuracy and mean MACs per sample of the cascade are computed from the\n",
    "    stored per-sample outputs, so no model is run here.\n",
    "\n",
    "    Args:\n",
    "        little, big: summaries as returned by `load_output`; reads 'ECE',\n",
    "            'confidence_raw' (little only), 'prediction_raw', 'labels_raw' and the\n",
    "            'model' entries 'model', 'macs_official' and 'params'.\n",
    "        threshold: confidence threshold applied to the little model.\n",
    "        p: if truthy, print a report.\n",
    "        real_labels: optional per-sample collections of valid labels (ImageNet-ReaL);\n",
    "            samples whose entry is empty are skipped. With None the ReaL metrics are 0.\n",
    "\n",
    "    Returns:\n",
    "        dict with the cascade accuracy ('acc_meta') and mean MACs per sample\n",
    "        ('macs_meta'), standalone and ReaL accuracies, and the fraction of samples\n",
    "        kept by the little model ('solvable_little').\n",
    "    \"\"\"\n",
    "    little_name = little['model']['model']\n",
    "    big_name = big['model']['model']\n",
    "\n",
    "    ece1 = little['ECE']\n",
    "    macs1 = float(little['model']['macs_official'])\n",
    "    params1 = float(little['model']['params'])\n",
    "    confs1 = np.array(little['confidence_raw']).astype(np.float32)\n",
    "    preds1 = np.array(little['prediction_raw']).astype(np.int32)\n",
    "    labels1 = np.array(little['labels_raw']).squeeze().astype(np.int32)\n",
    "    correct1 = (preds1 == labels1).astype(np.float32)\n",
    "\n",
    "    ece2 = big['ECE']\n",
    "    macs2 = float(big['model']['macs_official'])\n",
    "    params2 = float(big['model']['params'])\n",
    "    preds2 = np.array(big['prediction_raw']).astype(np.int32)\n",
    "    labels2 = np.array(big['labels_raw']).squeeze().astype(np.int32)\n",
    "    correct2 = (preds2 == labels2).astype(np.float32)\n",
    "\n",
    "    # Both summaries must cover the same samples in the same order; a zero sum of\n",
    "    # label differences would not be enough to guarantee that.\n",
    "    assert np.array_equal(labels1, labels2), 'little and big summaries have different labels'\n",
    "    total = len(labels2)\n",
    "\n",
    "    # Confident samples keep the little prediction; the rest are deferred to the big model.\n",
    "    solvable1 = confs1 >= threshold\n",
    "    solvable2 = ~solvable1\n",
    "    num_solvable1 = np.sum(solvable1)\n",
    "    num_solvable2 = np.sum(solvable2)\n",
    "\n",
    "    preds_meta = np.where(solvable1, preds1, preds2)\n",
    "    correct_meta = (preds_meta == labels2).astype(np.float32)\n",
    "\n",
    "    # The little model runs on every sample, the big one only on deferred samples.\n",
    "    macs_meta = (total * macs1 + num_solvable2 * macs2) / total\n",
    "    meta_acc = np.sum(correct_meta) / total\n",
    "    acc1 = np.sum(correct1) / total\n",
    "    acc2 = np.sum(correct2) / total\n",
    "\n",
    "    if real_labels is not None:\n",
    "        def _real_acc(preds):\n",
    "            # share of samples with a non-empty ReaL entry whose prediction is a valid label\n",
    "            return np.mean([pred in real_labels[i] for i, pred in enumerate(preds) if real_labels[i]])\n",
    "        acc1_real = _real_acc(preds1)\n",
    "        acc2_real = _real_acc(preds2)\n",
    "        acc_meta_real = _real_acc(preds_meta)\n",
    "    else:\n",
    "        acc1_real, acc2_real, acc_meta_real = 0, 0, 0\n",
    "\n",
    "    if p:\n",
    "        print('='*40)\n",
    "        print(f'little : {little_name:28}  || big : {big_name:28}')\n",
    "        print(f'Acc   1: {acc1:.5f}                       || Acc   2: {acc2:.5f}')\n",
    "        print(f'ReaL  1: {acc1_real:.5f}                       || ReaL  2: {acc2_real:.5f}')\n",
    "        print(f'ECE   1: {ece1:.5f}                       || ECE   2: {ece2:.5f}')\n",
    "        print(f'MACS  1: {macs1:20}          || MACS  2: {macs2:20}')\n",
    "        print(f'PARAM 1: {params1:20}          || PARAM 2: {params2:20}')\n",
    "\n",
    "        print(f'meta accuracy: {meta_acc:.5f}, meta accuracy real: {acc_meta_real:.5f}')\n",
    "        print(f'threshold {threshold}')\n",
    "        print(f'stage 1 solvable: {num_solvable1} stage 2 solvable: {num_solvable2}')\n",
    "        print(f'total solved solvable: {num_solvable1+num_solvable2}')\n",
    "        print(f'stage 2 acc {acc2}')\n",
    "        print(f'delta accuracy: {meta_acc-acc2:.5f}')\n",
    "        print(f'meta model MACs: {macs_meta} ')\n",
    "        print(f'MACs change from stage 2: {(macs_meta-macs2)/macs2*100:3.0f} %')\n",
    "\n",
    "    results = {\n",
    "        'acc_meta':meta_acc,\n",
    "        'macs_meta':macs_meta,\n",
    "        'threshold':threshold,\n",
    "        'little':little['model'],\n",
    "        'big':big['model'],\n",
    "        'macs_little':macs1,\n",
    "        'macs_big':macs2,\n",
    "        'solvable_little':num_solvable1/total,\n",
    "        'acc_little':acc1,\n",
    "        'acc_big':acc2,\n",
    "        'acc_meta_real':acc_meta_real,\n",
    "        'acc_little_real':acc1_real,\n",
    "        'acc_big_real':acc2_real,\n",
    "    }\n",
    "    return results\n",
    "\n",
    "def get_curve_multi(little_data, big_data, little_data_v2=None, big_data_v2=None, N=100):\n",
    "    \"\"\"Sweep the cascade threshold over N + 1 equally spaced values in [0, 1].\n",
    "\n",
    "    `big_little` is evaluated at every threshold on the ImageNet-1k summaries and, if\n",
    "    both `little_data_v2` and `big_data_v2` are given, on the ImageNet-V2 summaries\n",
    "    (without ReaL labels). Without V2 data the '*_v2' lists hold None and the\n",
    "    '*_v2' scalars are None instead of raising NameError.\n",
    "\n",
    "    Returns:\n",
    "        dict of per-threshold lists ('thresholds', 'solvable_little', 'macs', 'acc',\n",
    "        'acc_real' and their '_v2' counterparts) plus the big/little model info and\n",
    "        their standalone accuracies.\n",
    "    \"\"\"\n",
    "    has_v2 = little_data_v2 is not None and big_data_v2 is not None\n",
    "    thresholds = [1/N * n for n in range(0,N+1)]\n",
    "\n",
    "    def _get(res, key):\n",
    "        # V2 results are missing when no V2 summaries were supplied\n",
    "        return None if res is None else res[key]\n",
    "\n",
    "    solvable_little_list, solvable_little_v2_list = [], []\n",
    "    macs_list, macs_v2_list = [], []\n",
    "    accs_list, accs_real_list, accs_v2_list = [], [], []\n",
    "    result, result_v2 = None, None\n",
    "    for thr in thresholds:\n",
    "        result = big_little(little_data, big_data, thr, p=0)\n",
    "        result_v2 = big_little(little_data_v2, big_data_v2, thr, p=0, real_labels=None) if has_v2 else None\n",
    "\n",
    "        solvable_little_list.append(result['solvable_little'])\n",
    "        accs_list.append(result['acc_meta'])\n",
    "        macs_list.append(result['macs_meta'])\n",
    "        accs_real_list.append(result['acc_meta_real'])\n",
    "\n",
    "        solvable_little_v2_list.append(_get(result_v2, 'solvable_little'))\n",
    "        macs_v2_list.append(_get(result_v2, 'macs_meta'))\n",
    "        accs_v2_list.append(_get(result_v2, 'acc_meta'))\n",
    "\n",
    "    # Standalone accuracies do not depend on the threshold, so the last result is used.\n",
    "    results = {'big': big_data['model'],\n",
    "               'little': little_data['model'],\n",
    "               'thresholds': thresholds,\n",
    "               'solvable_little': solvable_little_list,\n",
    "               'solvable_little_v2': solvable_little_v2_list,\n",
    "               'macs': macs_list,\n",
    "               'macs_v2': macs_v2_list,\n",
    "               'acc': accs_list,\n",
    "               'acc_real': accs_real_list,\n",
    "               'acc_v2': accs_v2_list,\n",
    "               'acc_big': result['acc_big'],\n",
    "               'acc_big_real': result['acc_big_real'],\n",
    "               'acc_big_v2': _get(result_v2, 'acc_big'),\n",
    "               'acc_little': result['acc_little'],\n",
    "               'acc_little_real': result['acc_little_real'],\n",
    "               'acc_little_v2': _get(result_v2, 'acc_little')}\n",
    "    return results\n",
    "\n",
    "N = 50  # number of threshold steps (N + 1 thresholds in [0, 1])\n",
    "ks_results = []\n",
    "for dv1, dv2 in zip(data_v1[:-1], data_v2[:-1]):\n",
    "    # pair every smaller model of `group` with the last (largest) one\n",
    "    print(f\"evaluating {dv1['model']['model']} and {data_v1[-1]['model']['model']}\")\n",
    "    ks_results.append(get_curve_multi(dv1, data_v1[-1], dv2, data_v2[-1], N))\n",
    "\n",
    "# all pairs share the same big model, so any entry provides its reference numbers\n",
    "ref = ks_results[0]\n",
    "panels = [('macs', 'acc', 'acc_big', 'ImageNet-1K'),\n",
    "          ('macs', 'acc_real', 'acc_big_real', 'ImageNet-ReaL'),\n",
    "          ('macs_v2', 'acc_v2', 'acc_big_v2', 'ImageNet-V2')]\n",
    "fig, axes = plt.subplots(1, 3, constrained_layout=True, figsize=(10,6))\n",
    "for ax, (macs_key, acc_key, big_key, title) in zip(axes, panels):\n",
    "    for ks in ks_results:\n",
    "        ax.plot(ks[macs_key][0], ks[acc_key][0], '+', markersize=10, label=f\"{ks['little']['model']}\")\n",
    "    for ks in ks_results:\n",
    "        ax.plot(ks[macs_key], ks[acc_key], label=f\"{ks['little']['model']}-{ks['big']['model']}\")\n",
    "    ax.plot(ref[macs_key], np.ones_like(ref[macs_key]) * ref[big_key], label=f\"Baseline-{ref['big']['model']}\")\n",
    "    ax.plot(ref['big']['macs_official'], ref[big_key], '*', markersize=10)\n",
    "    ax.set_xscale('log')\n",
    "    ax.set_xlabel('MACs')\n",
    "    ax.set_title(title)\n",
    "axes[0].set_ylabel('Top-1 Accuracy')\n",
    "axes[2].legend(loc='lower right')\n",
    "\n",
    "# first (lowest) threshold of pair j at which the cascade matches the big model on ImageNet-1k\n",
    "j = 0\n",
    "k, s = ks_results[j]['big']['model'], ks_results[j]['little']['model']\n",
    "macs_little, acc_little = ks_results[j]['little']['macs_official'], ks_results[j]['acc_little']\n",
    "macs_big, acc_big = ks_results[j]['big']['macs_official'], ks_results[j]['acc_big']\n",
    "for acc, mac in zip(ks_results[j]['acc'], ks_results[j]['macs']):\n",
    "    if acc-acc_big>=0:\n",
    "        print(f'small model {s:20s} accuracy: {acc_little}, MACs: {macs_little}')\n",
    "        print(f'large model {k:20s} accuracy: {acc_big}, MACs: {macs_big}')\n",
    "        print(f'smallest model without losing acc: MACs {mac}, reduction {(macs_big-mac)/macs_big}')\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c6a6ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = 0  # index of the first pair in ks_results to plot\n",
    "macs_divider = 1e9  # plot MACs as GMACs\n",
    "fig, axes = plt.subplots(1, 3, constrained_layout=True, figsize=(9,4))\n",
    "# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9\n",
    "cm = matplotlib.colormaps['Spectral']\n",
    "colors = colorize(ks_results, cm)\n",
    "ref = ks_results[a]  # all pairs share the same big model\n",
    "label_big = ref['big']['model']\n",
    "panels = [(axes[0], 'macs', 'acc', 'acc_big', 'ImageNet-1K'),\n",
    "          (axes[1], 'macs', 'acc_real', 'acc_big_real', 'ImageNet-ReaL'),\n",
    "          (axes[2], 'macs_v2', 'acc_v2', 'acc_big_v2', 'ImageNet-V2')]\n",
    "for ax, macs_key, acc_key, big_key, title in panels:\n",
    "    # x marker: the first threshold (T=0), presumably the little model on its own\n",
    "    for ks, color in zip(ks_results[a:], colors):\n",
    "        ax.plot(np.array(ks[macs_key][0])/macs_divider, ks[acc_key][0], 'x', markersize=4, color=color, label=ks['little']['model'])\n",
    "    for ks, color in zip(ks_results[a:], colors):\n",
    "        ax.plot(np.array(ks[macs_key])/macs_divider, ks[acc_key], color=color)\n",
    "    # dotted line and red star: the big model on its own\n",
    "    ax.plot(np.array(ref['macs'])/macs_divider, np.ones_like(ref['macs']) * ref[big_key], ':', color='black')\n",
    "    ax.plot(ref['big']['macs_official']/macs_divider, ref[big_key], '*', color='red', markersize=10, label=label_big)\n",
    "    ax.set_xscale('log')\n",
    "    ax.set_xlabel('GMACs')\n",
    "    ax.set_title(title)\n",
    "\n",
    "axes[0].set_ylabel('Top-1 Accuracy')\n",
    "axes[2].legend(loc='lower right', ncol=1, fancybox=True, shadow=False)\n",
    "\n",
    "def select_threshold(ks_result, acc_tolerance=0.0, p=1, metric='acc_v2'):\n",
    "    \"\"\"Return the lowest threshold at which the cascade is within `acc_tolerance` of the big model.\n",
    "\n",
    "    Index j is accepted when ks_result[metric][j] minus the big model's accuracy on the\n",
    "    same benchmark is >= acc_tolerance; a negative tolerance therefore allows a drop.\n",
    "\n",
    "    Args:\n",
    "        ks_result: one entry produced by `get_curve_multi`.\n",
    "        acc_tolerance: required accuracy difference to the big model.\n",
    "        p: if truthy, print a report for the selected threshold.\n",
    "        metric: accuracy curve used for the decision: 'acc_v2' (ImageNet-V2, default),\n",
    "            'acc_real' (ReaL) or 'acc' (ImageNet-1k).\n",
    "\n",
    "    Returns:\n",
    "        (j, threshold, macs, relative MACs reduction w.r.t. the big model). If no\n",
    "        threshold meets the criterion a warning is printed and the last (highest)\n",
    "        threshold is returned instead of None, so callers can always unpack the result.\n",
    "    \"\"\"\n",
    "    big_key = {'acc': 'acc_big', 'acc_real': 'acc_big_real', 'acc_v2': 'acc_big_v2'}[metric]\n",
    "    k, s = ks_result['big']['model'], ks_result['little']['model']\n",
    "    k_size, s_size = ks_result['big']['params'], ks_result['little']['params']\n",
    "    macs_little, acc_little = ks_result['little']['macs_official'], ks_result['acc_little']\n",
    "    macs_big, acc_big = ks_result['big']['macs_official'], ks_result['acc_big']\n",
    "    acc_big_real, acc_big_v2 = ks_result['acc_big_real'], ks_result['acc_big_v2']\n",
    "    # use the stored thresholds rather than a global step count\n",
    "    thresholds = ks_result['thresholds']\n",
    "    reference = ks_result[big_key]\n",
    "\n",
    "    j = next((idx for idx, value in enumerate(ks_result[metric]) if value - reference >= acc_tolerance), None)\n",
    "    if j is None:\n",
    "        j = len(thresholds) - 1\n",
    "        print(f'WARNING: no threshold for {s}-{k} reaches tolerance {acc_tolerance}; using T={thresholds[j]:.2f}')\n",
    "\n",
    "    acc, mac, acc_real = ks_result['acc'][j], ks_result['macs'][j], ks_result['acc_real'][j]\n",
    "    mac_v2, acc_v2 = ks_result['macs_v2'][j], ks_result['acc_v2'][j]\n",
    "    if p:\n",
    "        print('='*40)\n",
    "        print(f'threshold: {thresholds[j]:.2f}')\n",
    "        print(f'Acc: V1 {acc:2.4f} ReaL {acc_real:2.4f} V2 {acc_v2:2.4f}')\n",
    "        print(f'big Acc: V1 {acc_big:2.4f} ReaL {acc_big_real:2.4f} V2 {acc_big_v2:2.4f}')\n",
    "        # int() so float-valued parameter counts (e.g. 25e6) do not break the 'd' format\n",
    "        print(f'Params: small {int(s_size):15d}, large {int(k_size):15d}, meta {int(k_size+s_size):15d}, delta {(s_size)/k_size:.2f}')\n",
    "        print(f'small model {s:20s} accuracy: {acc_little}, MACs: {macs_little}')\n",
    "        print(f'large model {k:20s} accuracy: {acc_big}, MACs: {macs_big}')\n",
    "        print(f'accuracy improvement on ReaL: {acc_real - acc_big_real:.5f}')\n",
    "        print(f'accuracy improvement on V2: {acc_v2 - acc_big_v2:.5f}')\n",
    "        print(f' MACs {mac}, reduction {(macs_big-mac)/macs_big}')\n",
    "        print(f'V2 MACs {mac_v2}, reduction {(macs_big-mac_v2)/macs_big}')\n",
    "        print('-'*40)\n",
    "    return j, thresholds[j], mac, (macs_big-mac)/macs_big\n",
    "\n",
    "tolerance = -0.000  # required accuracy difference to the big model (negative allows a drop)\n",
    "\n",
    "# Pick a threshold for every pair, then keep the pair with the largest relative MACs reduction.\n",
    "selections = [select_threshold(r, tolerance) for r in ks_results[a:]]\n",
    "names = [r['little']['model'] for r in ks_results[a:]]\n",
    "best_threshold_idx = [sel[0] for sel in selections]\n",
    "best_macs = [sel[2] for sel in selections]\n",
    "best_reductions = [sel[3] for sel in selections]\n",
    "\n",
    "best_config_idx = np.argmax(best_reductions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0589711-245e-4323-b1a7-489e6909b998",
   "metadata": {},
   "outputs": [],
   "source": [
    "# find best config given an accuracy tolerance\n",
    "name_s, name_e = -2, None  # slice of each model name shown in legend labels\n",
    "tolerance = -0.0000\n",
    "best_threshold_idx = []\n",
    "best_macs = []\n",
    "best_reductions = []\n",
    "names = []\n",
    "for r in ks_results[a:]:\n",
    "    names.append(r['little']['model'])\n",
    "    idx, thr, mac, reduction = select_threshold(r, tolerance, 0)\n",
    "    best_macs.append(mac)\n",
    "    best_threshold_idx.append(idx)\n",
    "    best_reductions.append(reduction)\n",
    "    print(f\"{r['little']['model']}: Best relative reduction: {reduction:.3f}, Acc {r['acc'][idx]:.4f}, thr {r['thresholds'][idx]:.2f}\")\n",
    "# best_config_idx indexes the sliced list ks_results[a:] (like names / best_* above)\n",
    "best_config_idx = np.argmax(best_reductions)\n",
    "print(best_config_idx)\n",
    "\n",
    "ref = ks_results[a]  # all pairs share the same big model\n",
    "best = ks_results[a + best_config_idx]\n",
    "best_idx = best_threshold_idx[best_config_idx]\n",
    "colors = colorize(ks_results, cm)\n",
    "fig, axes = plt.subplots(2, 3, constrained_layout=True, figsize=(9,5))\n",
    "\n",
    "for ks, color in zip(ks_results[a:], colors):\n",
    "    # plain locals instead of nested same-quote f-strings (those need Python >= 3.12)\n",
    "    little_label = ks['little']['model'][name_s:name_e].capitalize()\n",
    "    big_label = ks['big']['model'][name_s:name_e].capitalize()\n",
    "    label = f'{little_label}-{big_label}'\n",
    "    axes[1,0].plot(np.array(ks['macs'][0])/macs_divider, ks['acc'][0], 'x', markersize=4, color=color)\n",
    "    axes[1,1].plot(np.array(ks['macs'][0])/macs_divider, ks['acc_real'][0], 'x', markersize=4, color=color)\n",
    "    axes[1,2].plot(np.array(ks['macs_v2'][0])/macs_divider, ks['acc_v2'][0], 'x', markersize=4, color=color)\n",
    "\n",
    "    axes[0,0].plot(ks['thresholds'], 1-np.array(ks['solvable_little']), '-', markersize=4, color=color, label=label)\n",
    "    axes[0,1].plot(ks['thresholds'], np.array(ks['macs'])/ref['big']['macs_official'], '-', markersize=4, color=color, label=label)\n",
    "    axes[0,2].plot(ks['thresholds'], ks['acc'], '-', markersize=4, color=color, label=label)\n",
    "\n",
    "for ks, color in zip(ks_results[a:], colors):\n",
    "    axes[1,0].plot(np.array(ks['macs'])/macs_divider, ks['acc'], color=color)\n",
    "    axes[1,1].plot(np.array(ks['macs'])/macs_divider, ks['acc_real'], color=color)\n",
    "    axes[1,2].plot(np.array(ks['macs_v2'])/macs_divider, ks['acc_v2'], color=color)\n",
    "\n",
    "# big model on its own: dotted reference lines and red stars\n",
    "ref_macs = np.array(ref['macs'])\n",
    "axes[1,0].plot(ref_macs/macs_divider, np.ones_like(ref_macs) * ref['acc_big'], ':', color='black')\n",
    "axes[1,1].plot(ref_macs/macs_divider, np.ones_like(ref_macs) * ref['acc_big_real'], ':', color='black')\n",
    "axes[1,2].plot(ref_macs/macs_divider, np.ones_like(ref_macs) * ref['acc_big_v2'], ':', color='black')\n",
    "axes[0,1].plot(ref['thresholds'], np.ones_like(ref_macs), ':', color='black')\n",
    "axes[0,2].plot(ref['thresholds'], np.ones_like(ref_macs) * ref['acc_big'], ':', color='black')\n",
    "\n",
    "label_big = ref['big']['model'][name_s:name_e].capitalize()\n",
    "big_gmacs = ref['big']['macs_official']/macs_divider\n",
    "axes[1,0].plot(big_gmacs, ref['acc_big'], '*', color='red', markersize=10, label=label_big)\n",
    "axes[1,1].plot(big_gmacs, ref['acc_big_real'], '*', color='red', markersize=10, label=label_big)\n",
    "axes[1,2].plot(big_gmacs, ref['acc_big_v2'], '*', color='red', markersize=10, label=label_big)\n",
    "\n",
    "# best cascade at its selected threshold (blue star)\n",
    "label_best = f\"{best['little']['model'][name_s:name_e].capitalize()}-{label_big} ($T={best['thresholds'][best_idx]:.2f}$)\"\n",
    "axes[1,0].plot(np.array(best['macs'][best_idx])/macs_divider, best['acc'][best_idx], '*', color='blue', markersize=6, label=label_best)\n",
    "axes[1,1].plot(np.array(best['macs'][best_idx])/macs_divider, best['acc_real'][best_idx], '*', color='blue', markersize=6, label=label_best)\n",
    "axes[1,2].plot(np.array(best['macs_v2'][best_idx])/macs_divider, best['acc_v2'][best_idx], '*', color='blue', markersize=6, label=label_best)\n",
    "\n",
    "for ax in axes[1]:\n",
    "    ax.set_xscale('log')\n",
    "    ax.set_xlabel('GMACs')\n",
    "axes[1,0].set_ylim([0.75,0.85])\n",
    "axes[0,1].set_yscale('log')\n",
    "\n",
    "axes[1,0].set_ylabel('ImageNet-1k Accuracy')\n",
    "axes[1,1].set_ylabel('ReaL Accuracy')\n",
    "axes[1,2].set_ylabel('ImageNet-V2 Accuracy')\n",
    "\n",
    "axes[0,0].set_ylabel('|D*| / |D|')\n",
    "axes[0,1].set_ylabel('Relative GMACs')\n",
    "axes[0,2].set_ylabel('ImageNet-1k Accuracy')\n",
    "for ax in axes[0]:\n",
    "    ax.set_xlabel('Threshold $(T )$')\n",
    "\n",
    "axes[0,0].legend(loc='upper left', ncol=1, fancybox=True, shadow=False)\n",
    "axes[1,0].legend(loc='lower right', ncol=1, fancybox=True, shadow=False)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
