{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "2xiTG7wKekFU",
        "cHr4FT_nbRlU"
      ],
      "private_outputs": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "Copyright 2024 REDACTED FOR ANONYMITY\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ],
      "metadata": {
        "id": "H_pQkFTcHRwB"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Plots and analyses for \"Towards flexible perception with visual memory\""
      ],
      "metadata": {
        "id": "2C1oBHD5k8FO"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import matplotlib as mpl\n",
        "import seaborn as sns\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import json\n",
        "import time\n",
        "import math\n",
        "import copy\n",
        "from tqdm import tqdm\n",
        "from collections import Counter"
      ],
      "metadata": {
        "id": "_XMAJ31ZlHqZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def file_opener_default(path, mode):\n",
        "  \"\"\"Open `path` in `mode` using Python's built-in `open`.\n",
        "\n",
        "  Args:\n",
        "    path: Location of the file to open.\n",
        "    mode: Mode string handed unchanged to `open`.\n",
        "\n",
        "  Returns:\n",
        "    The file object produced by `open`.\n",
        "  \"\"\"\n",
        "  handle = open(path, mode)\n",
        "  return handle\n",
        "\n",
        "# Alias used for opening files; rebind it to swap in a different opener.\n",
        "file_opener = file_opener_default"
      ],
      "metadata": {
        "id": "4a71INC29qp6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def print_viewing_path_default(save_fig_path):\n",
        "  \"\"\"Print a 'View at: ...' message for `save_fig_path`.\n",
        "\n",
        "  Args:\n",
        "    save_fig_path: Value interpolated into the printed message.\n",
        "  \"\"\"\n",
        "  message = f'View at: {save_fig_path}'\n",
        "  print(message)\n",
        "\n",
        "# Alias used for reporting paths; rebind it to customize the output.\n",
        "print_viewing_path = print_viewing_path_default"
      ],
      "metadata": {
        "id": "kI8kPr749rwq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Locations of figures and input data. The '/path/to/...' values look like\n",
        "# placeholders that need to be adapted to the local setup before running.\n",
        "FIGURE_DIR = '/path/to/figures/'\n",
        "DATA_PARENT_DIR = '/path/to/data_parent_dir/'\n",
        "# NOTE: DATA_PARENT_DIR already ends with '/', so both derived paths below\n",
        "# contain a doubled '//'.\n",
        "DATA_DIR = f'{DATA_PARENT_DIR}/data/'\n",
        "PRUNING_METRICS_DIR = f'{DATA_PARENT_DIR}/dataset-pruning-metrics'"
      ],
      "metadata": {
        "id": "anZ0_mVQ9umd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Defining color scheme & naming"
      ],
      "metadata": {
        "id": "yJJO6e_T2UUR"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# REDACTED FOR ANONYMITY\n",
        "# Hex color assigned to each featurizer identifier.\n",
        "featurizer_to_color = {'dinov2_vits14': '#669DF6',\n",
        "                       'dinov2_vitb14': '#1A73E8',\n",
        "                       'dinov2_vitl14': '#185ABC',\n",
        "                       'clip-vit_b16': '#EA4335',\n",
        "                       'clip-vit_l14': '#B31412',\n",
        "                       }\n",
        "# Display name for each featurizer identifier (same keys as featurizer_to_color).\n",
        "featurizer_to_name = {'dinov2_vits14': 'DinoV2 ViT-S/14',\n",
        "                      'dinov2_vitb14': 'DinoV2 ViT-B/14',\n",
        "                      'dinov2_vitl14': 'DinoV2 ViT-L/14',\n",
        "                      'clip-vit_b16': 'CLIP ViT-B/16',\n",
        "                      'clip-vit_l14': 'CLIP ViT-L/14',\n",
        "                       }\n",
        "# Hex color assigned to each aggregator name.\n",
        "aggregator_to_color = {'PluralityVoting': '#e7298a',\n",
        "                       'DistanceVoting': '#1b9e77',\n",
        "                       'SoftmaxVoting': '#7570b3',\n",
        "                       'RankVoting': '#d95f02',\n",
        "                       }"
      ],
      "metadata": {
        "id": "DOQPPbJQ2TFR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Loading data"
      ],
      "metadata": {
        "id": "6rBvDd-NKokl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def read_neighbors_info_JSON(paths: list[str]) -> pd.DataFrame:\n",
        "  \"\"\"Load nearest-neighbor information from a single JSON file.\n",
        "\n",
        "  Args:\n",
        "    paths: List containing exactly one path, which must end in '.json'.\n",
        "\n",
        "  Returns:\n",
        "    A dataframe whose rows are ordered by 'image_id', with a fresh integer\n",
        "    index and the columns 'featurizer', 'image_id', 'image_class',\n",
        "    'neighbor_image_ids', 'neighbor_classes', 'neighbor_distances'.\n",
        "  \"\"\"\n",
        "\n",
        "  t1 = time.time()\n",
        "  assert len(paths) == 1, 'Only one JSON file is supported'\n",
        "\n",
        "  path = paths[0]\n",
        "  assert path.endswith('.json')\n",
        "\n",
        "  column_order = ['featurizer', 'image_id', 'image_class',\n",
        "                  'neighbor_image_ids', 'neighbor_classes',\n",
        "                  'neighbor_distances']\n",
        "\n",
        "  # Sorting happens on the index, so image_id is moved there and back again.\n",
        "  df = (\n",
        "      pd.read_json(path, orient='index')\n",
        "      .set_index('image_id')\n",
        "      .sort_index()\n",
        "      .reset_index(drop=False)\n",
        "  )\n",
        "  df = df[column_order]\n",
        "\n",
        "  t2 = time.time()\n",
        "  print(f'Loading time: {round(t2 - t1)} seconds')\n",
        "\n",
        "  return df"
      ],
      "metadata": {
        "id": "PAyjyAvNUjkK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def remove_neighbors_identical_to_query(df: pd.DataFrame) -> pd.DataFrame:\n",
        "  \"\"\"Remove neighbors identical to queries from the dataframe.\n",
        "\n",
        "  If neighbors are derived from the training set, the first neighbor\n",
        "  is usually identical to the query image and needs to be removed.\n",
        "\n",
        "  For every row, the first position at which 'image_id' occurs in\n",
        "  'neighbor_image_ids' is dropped from all three neighbor columns. If the\n",
        "  query is not among its neighbors, the last neighbor is dropped instead\n",
        "  to keep the list lengths consistent.\n",
        "\n",
        "  The input is left untouched: new lists are built for the result. (Copying\n",
        "  a dataframe does not copy the Python lists stored in its cells, so\n",
        "  deleting from those lists in place would also modify the caller's data.)\n",
        "\n",
        "  Args:\n",
        "    df: Dataframe with an 'image_id' column and the list-valued columns\n",
        "      'neighbor_image_ids', 'neighbor_classes' and 'neighbor_distances'.\n",
        "\n",
        "  Returns:\n",
        "    A new dataframe with the same index and columns as `df`.\n",
        "  \"\"\"\n",
        "\n",
        "  neighbor_columns = ['neighbor_image_ids', 'neighbor_classes', 'neighbor_distances']\n",
        "\n",
        "  def _position_to_drop(image_id, neighbor_ids):\n",
        "    # None means the query was not found; the last neighbor is dropped then.\n",
        "    if image_id in neighbor_ids:\n",
        "      return neighbor_ids.index(image_id)\n",
        "    return None\n",
        "\n",
        "  def _without_position(values, position):\n",
        "    # Slicing builds a new list, so the list stored in `df` is never modified.\n",
        "    if position is None:\n",
        "      return values[:-1]\n",
        "    return values[:position] + values[position + 1:]\n",
        "\n",
        "  positions = [_position_to_drop(image_id, neighbor_ids)\n",
        "               for image_id, neighbor_ids\n",
        "               in zip(df['image_id'], df['neighbor_image_ids'])]\n",
        "\n",
        "  df_tmp = df.copy()\n",
        "  for c in neighbor_columns:\n",
        "    shortened = [_without_position(values, position)\n",
        "                 for values, position in zip(df[c], positions)]\n",
        "    # Wrap in an object Series so each cell holds one list.\n",
        "    df_tmp[c] = pd.Series(shortened, index=df.index, dtype=object)\n",
        "\n",
        "  return df_tmp"
      ],
      "metadata": {
        "id": "1QuiDLpbTx0B"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def calculate_accuracy(df: pd.DataFrame, k: int, prediction_column: str) -> float:\n",
        "  \"\"\"Compute the fraction of rows whose k-th prediction is correct.\n",
        "\n",
        "  Args:\n",
        "    df: Dataframe with an 'image_class' column and a list-valued\n",
        "      `prediction_column`. 'image_class' holds either a single label per\n",
        "      row or a list of acceptable labels per row.\n",
        "    k: Zero-based position of the prediction to evaluate within each list.\n",
        "    prediction_column: Name of the column holding the per-row predictions.\n",
        "\n",
        "  Returns:\n",
        "    The mean correctness over all rows, or NaN if `df` has no rows.\n",
        "  \"\"\"\n",
        "  if len(df) == 0:\n",
        "    return float('nan')\n",
        "\n",
        "  # Inspect the first row by position, not by label: the index need not\n",
        "  # contain 0 and may contain duplicates (e.g. after pd.concat).\n",
        "  if isinstance(df['image_class'].iloc[0], list):\n",
        "    # Multi-label case: correct if the prediction is any of the row's labels.\n",
        "    accuracy = df.apply(lambda x: x[prediction_column][k] in x['image_class'], axis=1).mean()\n",
        "  else:\n",
        "    matches = df[prediction_column].apply(lambda x: x[k]) == df['image_class']\n",
        "    accuracy = matches.mean()\n",
        "  return accuracy"
      ],
      "metadata": {
        "id": "O00Yg0q4nrd2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def print_df_stats(df: pd.DataFrame) -> None:\n",
        "  \"\"\"Print some basic statistics about the data.\n",
        "\n",
        "  Prints the featurizer name, the number of samples and available\n",
        "  neighbors, the smallest and largest neighbor distance, and the accuracy\n",
        "  of the neighbors at positions 0 and 1.\n",
        "\n",
        "  Args:\n",
        "    df: Dataframe with the columns 'featurizer', 'image_class',\n",
        "      'neighbor_image_ids', 'neighbor_classes' and 'neighbor_distances'.\n",
        "      It must contain exactly one distinct featurizer.\n",
        "  \"\"\"\n",
        "\n",
        "  featurizers = df['featurizer'].unique()\n",
        "  assert len(featurizers) == 1\n",
        "  print(f'Dataframe stats for model {featurizers[0]}:')\n",
        "\n",
        "  # Read the first row by position: the index need not contain the label 0.\n",
        "  num_neighbors = len(df['neighbor_image_ids'].iloc[0])\n",
        "  print(f'Found {len(df)} samples, {num_neighbors} neighbors available.')\n",
        "\n",
        "  min_dist = np.min(df['neighbor_distances'].apply(lambda y: np.min(y)))\n",
        "  max_dist = np.max(df['neighbor_distances'].apply(lambda y: np.max(y)))\n",
        "  print(f'min_dist: {min_dist}, max_dist: {max_dist}')\n",
        "\n",
        "  acc_0 = calculate_accuracy(df, 0, 'neighbor_classes') # accuracy of k=0\n",
        "  print(f'acc of neighbor 0: {acc_0}')\n",
        "\n",
        "  # Position 1 only exists when at least two neighbors are available.\n",
        "  if num_neighbors > 1:\n",
        "    acc_1 = calculate_accuracy(df, 1, 'neighbor_classes') # accuracy of k=1\n",
        "    print(f'acc of neighbor 1: {acc_1}')"
      ],
      "metadata": {
        "id": "Cad2HIlRRu__"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def distance_normalization(df: pd.DataFrame) -> pd.DataFrame:\n",
        "  \"\"\"Normalize distances to [0, 1].\n",
        "\n",
        "  Applies min-max scaling to every value in 'neighbor_distances', using the\n",
        "  minimum and maximum taken over all rows.\n",
        "\n",
        "  The input is left untouched: new lists are built for the result. (Copying\n",
        "  a dataframe does not copy the Python lists stored in its cells, so\n",
        "  overwriting list entries in place would also modify the caller's data.)\n",
        "\n",
        "  Args:\n",
        "    df: Dataframe with a list-valued 'neighbor_distances' column.\n",
        "\n",
        "  Returns:\n",
        "    A new dataframe whose 'neighbor_distances' lists are rescaled.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If all distances are equal, so that no scaling is defined.\n",
        "  \"\"\"\n",
        "\n",
        "  print('Normalizing distances to [0, 1]')\n",
        "\n",
        "  distances = df['neighbor_distances']\n",
        "  min_dist = np.min(distances.apply(lambda y: np.min(y)))\n",
        "  max_dist = np.max(distances.apply(lambda y: np.max(y)))\n",
        "\n",
        "  dist_range = max_dist - min_dist\n",
        "  if dist_range == 0:\n",
        "    raise ValueError('Cannot normalize distances: all distances are equal '\n",
        "                     f'({min_dist}).')\n",
        "\n",
        "  df_tmp = df.copy()\n",
        "  rescaled = [[(d - min_dist) / dist_range for d in row] for row in distances]\n",
        "  # Wrap in an object Series so each cell holds one list.\n",
        "  df_tmp['neighbor_distances'] = pd.Series(rescaled, index=df.index, dtype=object)\n",
        "\n",
        "  # Sanity check: the extremes must now sit at 0 and 1.\n",
        "  new_min_dist = np.min(df_tmp['neighbor_distances'].apply(lambda y: np.min(y)))\n",
        "  new_max_dist = np.max(df_tmp['neighbor_distances'].apply(lambda y: np.max(y)))\n",
        "  assert np.isclose(new_min_dist, 0.0)\n",
        "  assert np.isclose(new_max_dist, 1.0)\n",
        "\n",
        "  return df_tmp"
      ],
      "metadata": {
        "id": "BHl6usiGottE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def read_scaling_df(model,\n",
        "                    query_dataset='imagenet2012', query_split='validation',\n",
        "                    memory_dataset='imagenet2012', memory_split='train',\n",
        "                    size='full',\n",
        "                    verbose=True,\n",
        "                    remove_identical_neighbors=False,\n",
        "                    normalize_distances=False):\n",
        "  \"\"\"Load the neighbor dataframe for one model and memory/query combination.\n",
        "\n",
        "  The file name is assembled from the arguments and looked up in DATA_DIR.\n",
        "\n",
        "  Args:\n",
        "    model: Model identifier as it appears in the file name.\n",
        "    query_dataset: Name of the query dataset (one of the supported names).\n",
        "    query_split: Query split: 'train', 'validation' or 'test'.\n",
        "    memory_dataset: Name of the memory dataset (one of the supported names).\n",
        "    memory_split: Memory split: 'train', 'validation', 'test' or\n",
        "      'train-and-test'.\n",
        "    size: 'full', or a value that appears as 'subsampled_<size>' in the\n",
        "      file name.\n",
        "    verbose: If True, print statistics about the loaded dataframe.\n",
        "    remove_identical_neighbors: If True, remove neighbors identical to the\n",
        "      query. This also happens whenever both splits are 'train'.\n",
        "    normalize_distances: If True, rescale the neighbor distances to [0, 1].\n",
        "\n",
        "  Returns:\n",
        "    The loaded (and optionally post-processed) dataframe.\n",
        "  \"\"\"\n",
        "\n",
        "  valid_memory_splits = ('train', 'validation', 'test', 'train-and-test')\n",
        "  valid_query_splits = ('train', 'validation', 'test')\n",
        "  memory_dataset_list = (\n",
        "      'imagenet2012',\n",
        "      'imagenet-jft-extension-20-classes',\n",
        "      'imagenet2012-and-ninco',\n",
        "      'ninco',\n",
        "      'jft-with-vit22b-labels',\n",
        "      'inaturalist',\n",
        "  )\n",
        "  query_dataset_list = (\n",
        "      'imagenet2012',\n",
        "      'imagenet-v2',\n",
        "      'imagenet-r',\n",
        "      'imagenet-a',\n",
        "      'imagenet-sketch',\n",
        "      'ninco',\n",
        "      'imagenet-real',\n",
        "      'inaturalist',\n",
        "  )\n",
        "\n",
        "  assert memory_split in valid_memory_splits\n",
        "  assert query_split in valid_query_splits\n",
        "  assert memory_dataset in memory_dataset_list\n",
        "  assert query_dataset in query_dataset_list\n",
        "\n",
        "  if size != 'full':\n",
        "    size = f\"subsampled_{size}\"\n",
        "\n",
        "  path = f'{DATA_DIR}/memory-{memory_dataset}_msplit-{memory_split}_query-{query_dataset}_qsplit-{query_split}_{model}_{size}_neighbor_info.json'\n",
        "\n",
        "  df = read_neighbors_info_JSON(paths=[path])\n",
        "\n",
        "  both_splits_are_train = memory_split == \"train\" and query_split == 'train'\n",
        "  if both_splits_are_train or remove_identical_neighbors:\n",
        "    print('Removing neighbors identical to query')\n",
        "    df = remove_neighbors_identical_to_query(df=df)\n",
        "\n",
        "  if normalize_distances:\n",
        "    df = distance_normalization(df=df)\n",
        "\n",
        "  if verbose:\n",
        "    print_df_stats(df=df)\n",
        "\n",
        "  return df"
      ],
      "metadata": {
        "id": "UpRjsSY0ot4l"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def read_multiple_scaling_dfs(models, query_split, query_dataset, memory_dataset, memory_split, verbose=True):\n",
        "  \"\"\"Load the neighbor dataframes of several models and concatenate them.\n",
        "\n",
        "  Args:\n",
        "    models: Iterable of model identifiers; one dataframe is read per model.\n",
        "    query_split: Forwarded to `read_scaling_df`.\n",
        "    query_dataset: Forwarded to `read_scaling_df`.\n",
        "    memory_dataset: Forwarded to `read_scaling_df`.\n",
        "    memory_split: Forwarded to `read_scaling_df`.\n",
        "    verbose: Forwarded to `read_scaling_df`.\n",
        "\n",
        "  Returns:\n",
        "    The per-model dataframes stacked with `pd.concat`, in the order of\n",
        "    `models`. Each part keeps its own index values.\n",
        "  \"\"\"\n",
        "  shared_kwargs = dict(\n",
        "      query_split=query_split,\n",
        "      query_dataset=query_dataset,\n",
        "      memory_dataset=memory_dataset,\n",
        "      memory_split=memory_split,\n",
        "      verbose=verbose,\n",
        "  )\n",
        "  per_model_dfs = [read_scaling_df(model=model, **shared_kwargs) for model in models]\n",
        "  return pd.concat(per_model_dfs)"
      ],
      "metadata": {
        "id": "olsgtAe3FO6_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def read_pruning_data(metric_name, directory=PRUNING_METRICS_DIR):\n",
        "  \"\"\"Read pruning metrics and return them as a dataframe.\n",
        "\n",
        "  Args:\n",
        "    metric_name: Metric name; the file read is\n",
        "      '<directory>/ImageNet-1K_<metric_name>.csv'.\n",
        "    directory: Directory containing the CSV files.\n",
        "\n",
        "  Returns:\n",
        "    The CSV contents as parsed by `pd.read_csv`.\n",
        "  \"\"\"\n",
        "\n",
        "  csv_path = f'{directory}/ImageNet-1K_{metric_name}.csv'\n",
        "  with file_opener(csv_path, 'r') as csv_file:\n",
        "    return pd.read_csv(csv_file)"
      ],
      "metadata": {
        "id": "BSe6Br3-l7nH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Accuracy based on class"
      ],
      "metadata": {
        "id": "9XVC6X9zIbQ8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "synset_to_index = {'n01440764': 0, 'n01443537': 1, 'n01484850': 2, 'n01491361': 3, 'n01494475': 4, 'n01496331': 5, 'n01498041': 6, 'n01514668': 7, 'n01514859': 8, 'n01518878': 9, 'n01530575': 10, 'n01531178': 11, 'n01532829': 12, 'n01534433': 13, 'n01537544': 14, 'n01558993': 15, 'n01560419': 16, 'n01580077': 17, 'n01582220': 18, 'n01592084': 19, 'n01601694': 20, 'n01608432': 21, 'n01614925': 22, 'n01616318': 23, 'n01622779': 24, 'n01629819': 25, 'n01630670': 26, 'n01631663': 27, 'n01632458': 28, 'n01632777': 29, 'n01641577': 30, 'n01644373': 31, 'n01644900': 32, 'n01664065': 33, 'n01665541': 34, 'n01667114': 35, 'n01667778': 36, 'n01669191': 37, 'n01675722': 38, 'n01677366': 39, 'n01682714': 40, 'n01685808': 41, 'n01687978': 42, 'n01688243': 43, 'n01689811': 44, 'n01692333': 45, 'n01693334': 46, 'n01694178': 47, 'n01695060': 48, 'n01697457': 49, 'n01698640': 50, 'n01704323': 51, 'n01728572': 52, 'n01728920': 53, 'n01729322': 54, 'n01729977': 55, 'n01734418': 56, 'n01735189': 57, 'n01737021': 58, 'n01739381': 59, 'n01740131': 60, 'n01742172': 61, 'n01744401': 62, 'n01748264': 63, 'n01749939': 64, 'n01751748': 65, 'n01753488': 66, 'n01755581': 67, 'n01756291': 68, 'n01768244': 69, 'n01770081': 70, 'n01770393': 71, 'n01773157': 72, 'n01773549': 73, 'n01773797': 74, 'n01774384': 75, 'n01774750': 76, 'n01775062': 77, 'n01776313': 78, 'n01784675': 79, 'n01795545': 80, 'n01796340': 81, 'n01797886': 82, 'n01798484': 83, 'n01806143': 84, 'n01806567': 85, 'n01807496': 86, 'n01817953': 87, 'n01818515': 88, 'n01819313': 89, 'n01820546': 90, 'n01824575': 91, 'n01828970': 92, 'n01829413': 93, 'n01833805': 94, 'n01843065': 95, 'n01843383': 96, 'n01847000': 97, 'n01855032': 98, 'n01855672': 99, 'n01860187': 100, 'n01871265': 101, 'n01872401': 102, 'n01873310': 103, 'n01877812': 104, 'n01882714': 105, 'n01883070': 106, 'n01910747': 107, 'n01914609': 108, 'n01917289': 109, 'n01924916': 110, 'n01930112': 111, 'n01943899': 112, 'n01944390': 113, 'n01945685': 114, 
'n01950731': 115, 'n01955084': 116, 'n01968897': 117, 'n01978287': 118, 'n01978455': 119, 'n01980166': 120, 'n01981276': 121, 'n01983481': 122, 'n01984695': 123, 'n01985128': 124, 'n01986214': 125, 'n01990800': 126, 'n02002556': 127, 'n02002724': 128, 'n02006656': 129, 'n02007558': 130, 'n02009229': 131, 'n02009912': 132, 'n02011460': 133, 'n02012849': 134, 'n02013706': 135, 'n02017213': 136, 'n02018207': 137, 'n02018795': 138, 'n02025239': 139, 'n02027492': 140, 'n02028035': 141, 'n02033041': 142, 'n02037110': 143, 'n02051845': 144, 'n02056570': 145, 'n02058221': 146, 'n02066245': 147, 'n02071294': 148, 'n02074367': 149, 'n02077923': 150, 'n02085620': 151, 'n02085782': 152, 'n02085936': 153, 'n02086079': 154, 'n02086240': 155, 'n02086646': 156, 'n02086910': 157, 'n02087046': 158, 'n02087394': 159, 'n02088094': 160, 'n02088238': 161, 'n02088364': 162, 'n02088466': 163, 'n02088632': 164, 'n02089078': 165, 'n02089867': 166, 'n02089973': 167, 'n02090379': 168, 'n02090622': 169, 'n02090721': 170, 'n02091032': 171, 'n02091134': 172, 'n02091244': 173, 'n02091467': 174, 'n02091635': 175, 'n02091831': 176, 'n02092002': 177, 'n02092339': 178, 'n02093256': 179, 'n02093428': 180, 'n02093647': 181, 'n02093754': 182, 'n02093859': 183, 'n02093991': 184, 'n02094114': 185, 'n02094258': 186, 'n02094433': 187, 'n02095314': 188, 'n02095570': 189, 'n02095889': 190, 'n02096051': 191, 'n02096177': 192, 'n02096294': 193, 'n02096437': 194, 'n02096585': 195, 'n02097047': 196, 'n02097130': 197, 'n02097209': 198, 'n02097298': 199, 'n02097474': 200, 'n02097658': 201, 'n02098105': 202, 'n02098286': 203, 'n02098413': 204, 'n02099267': 205, 'n02099429': 206, 'n02099601': 207, 'n02099712': 208, 'n02099849': 209, 'n02100236': 210, 'n02100583': 211, 'n02100735': 212, 'n02100877': 213, 'n02101006': 214, 'n02101388': 215, 'n02101556': 216, 'n02102040': 217, 'n02102177': 218, 'n02102318': 219, 'n02102480': 220, 'n02102973': 221, 'n02104029': 222, 'n02104365': 223, 'n02105056': 224, 'n02105162': 225, 
'n02105251': 226, 'n02105412': 227, 'n02105505': 228, 'n02105641': 229, 'n02105855': 230, 'n02106030': 231, 'n02106166': 232, 'n02106382': 233, 'n02106550': 234, 'n02106662': 235, 'n02107142': 236, 'n02107312': 237, 'n02107574': 238, 'n02107683': 239, 'n02107908': 240, 'n02108000': 241, 'n02108089': 242, 'n02108422': 243, 'n02108551': 244, 'n02108915': 245, 'n02109047': 246, 'n02109525': 247, 'n02109961': 248, 'n02110063': 249, 'n02110185': 250, 'n02110341': 251, 'n02110627': 252, 'n02110806': 253, 'n02110958': 254, 'n02111129': 255, 'n02111277': 256, 'n02111500': 257, 'n02111889': 258, 'n02112018': 259, 'n02112137': 260, 'n02112350': 261, 'n02112706': 262, 'n02113023': 263, 'n02113186': 264, 'n02113624': 265, 'n02113712': 266, 'n02113799': 267, 'n02113978': 268, 'n02114367': 269, 'n02114548': 270, 'n02114712': 271, 'n02114855': 272, 'n02115641': 273, 'n02115913': 274, 'n02116738': 275, 'n02117135': 276, 'n02119022': 277, 'n02119789': 278, 'n02120079': 279, 'n02120505': 280, 'n02123045': 281, 'n02123159': 282, 'n02123394': 283, 'n02123597': 284, 'n02124075': 285, 'n02125311': 286, 'n02127052': 287, 'n02128385': 288, 'n02128757': 289, 'n02128925': 290, 'n02129165': 291, 'n02129604': 292, 'n02130308': 293, 'n02132136': 294, 'n02133161': 295, 'n02134084': 296, 'n02134418': 297, 'n02137549': 298, 'n02138441': 299, 'n02165105': 300, 'n02165456': 301, 'n02167151': 302, 'n02168699': 303, 'n02169497': 304, 'n02172182': 305, 'n02174001': 306, 'n02177972': 307, 'n02190166': 308, 'n02206856': 309, 'n02219486': 310, 'n02226429': 311, 'n02229544': 312, 'n02231487': 313, 'n02233338': 314, 'n02236044': 315, 'n02256656': 316, 'n02259212': 317, 'n02264363': 318, 'n02268443': 319, 'n02268853': 320, 'n02276258': 321, 'n02277742': 322, 'n02279972': 323, 'n02280649': 324, 'n02281406': 325, 'n02281787': 326, 'n02317335': 327, 'n02319095': 328, 'n02321529': 329, 'n02325366': 330, 'n02326432': 331, 'n02328150': 332, 'n02342885': 333, 'n02346627': 334, 'n02356798': 335, 'n02361337': 336, 
'n02363005': 337, 'n02364673': 338, 'n02389026': 339, 'n02391049': 340, 'n02395406': 341, 'n02396427': 342, 'n02397096': 343, 'n02398521': 344, 'n02403003': 345, 'n02408429': 346, 'n02410509': 347, 'n02412080': 348, 'n02415577': 349, 'n02417914': 350, 'n02422106': 351, 'n02422699': 352, 'n02423022': 353, 'n02437312': 354, 'n02437616': 355, 'n02441942': 356, 'n02442845': 357, 'n02443114': 358, 'n02443484': 359, 'n02444819': 360, 'n02445715': 361, 'n02447366': 362, 'n02454379': 363, 'n02457408': 364, 'n02480495': 365, 'n02480855': 366, 'n02481823': 367, 'n02483362': 368, 'n02483708': 369, 'n02484975': 370, 'n02486261': 371, 'n02486410': 372, 'n02487347': 373, 'n02488291': 374, 'n02488702': 375, 'n02489166': 376, 'n02490219': 377, 'n02492035': 378, 'n02492660': 379, 'n02493509': 380, 'n02493793': 381, 'n02494079': 382, 'n02497673': 383, 'n02500267': 384, 'n02504013': 385, 'n02504458': 386, 'n02509815': 387, 'n02510455': 388, 'n02514041': 389, 'n02526121': 390, 'n02536864': 391, 'n02606052': 392, 'n02607072': 393, 'n02640242': 394, 'n02641379': 395, 'n02643566': 396, 'n02655020': 397, 'n02666196': 398, 'n02667093': 399, 'n02669723': 400, 'n02672831': 401, 'n02676566': 402, 'n02687172': 403, 'n02690373': 404, 'n02692877': 405, 'n02699494': 406, 'n02701002': 407, 'n02704792': 408, 'n02708093': 409, 'n02727426': 410, 'n02730930': 411, 'n02747177': 412, 'n02749479': 413, 'n02769748': 414, 'n02776631': 415, 'n02777292': 416, 'n02782093': 417, 'n02783161': 418, 'n02786058': 419, 'n02787622': 420, 'n02788148': 421, 'n02790996': 422, 'n02791124': 423, 'n02791270': 424, 'n02793495': 425, 'n02794156': 426, 'n02795169': 427, 'n02797295': 428, 'n02799071': 429, 'n02802426': 430, 'n02804414': 431, 'n02804610': 432, 'n02807133': 433, 'n02808304': 434, 'n02808440': 435, 'n02814533': 436, 'n02814860': 437, 'n02815834': 438, 'n02817516': 439, 'n02823428': 440, 'n02823750': 441, 'n02825657': 442, 'n02834397': 443, 'n02835271': 444, 'n02837789': 445, 'n02840245': 446, 'n02841315': 447, 
'n02843684': 448, 'n02859443': 449, 'n02860847': 450, 'n02865351': 451, 'n02869837': 452, 'n02870880': 453, 'n02871525': 454, 'n02877765': 455, 'n02879718': 456, 'n02883205': 457, 'n02892201': 458, 'n02892767': 459, 'n02894605': 460, 'n02895154': 461, 'n02906734': 462, 'n02909870': 463, 'n02910353': 464, 'n02916936': 465, 'n02917067': 466, 'n02927161': 467, 'n02930766': 468, 'n02939185': 469, 'n02948072': 470, 'n02950826': 471, 'n02951358': 472, 'n02951585': 473, 'n02963159': 474, 'n02965783': 475, 'n02966193': 476, 'n02966687': 477, 'n02971356': 478, 'n02974003': 479, 'n02977058': 480, 'n02978881': 481, 'n02979186': 482, 'n02980441': 483, 'n02981792': 484, 'n02988304': 485, 'n02992211': 486, 'n02992529': 487, 'n02999410': 488, 'n03000134': 489, 'n03000247': 490, 'n03000684': 491, 'n03014705': 492, 'n03016953': 493, 'n03017168': 494, 'n03018349': 495, 'n03026506': 496, 'n03028079': 497, 'n03032252': 498, 'n03041632': 499, 'n03042490': 500, 'n03045698': 501, 'n03047690': 502, 'n03062245': 503, 'n03063599': 504, 'n03063689': 505, 'n03065424': 506, 'n03075370': 507, 'n03085013': 508, 'n03089624': 509, 'n03095699': 510, 'n03100240': 511, 'n03109150': 512, 'n03110669': 513, 'n03124043': 514, 'n03124170': 515, 'n03125729': 516, 'n03126707': 517, 'n03127747': 518, 'n03127925': 519, 'n03131574': 520, 'n03133878': 521, 'n03134739': 522, 'n03141823': 523, 'n03146219': 524, 'n03160309': 525, 'n03179701': 526, 'n03180011': 527, 'n03187595': 528, 'n03188531': 529, 'n03196217': 530, 'n03197337': 531, 'n03201208': 532, 'n03207743': 533, 'n03207941': 534, 'n03208938': 535, 'n03216828': 536, 'n03218198': 537, 'n03220513': 538, 'n03223299': 539, 'n03240683': 540, 'n03249569': 541, 'n03250847': 542, 'n03255030': 543, 'n03259280': 544, 'n03271574': 545, 'n03272010': 546, 'n03272562': 547, 'n03290653': 548, 'n03291819': 549, 'n03297495': 550, 'n03314780': 551, 'n03325584': 552, 'n03337140': 553, 'n03344393': 554, 'n03345487': 555, 'n03347037': 556, 'n03355925': 557, 'n03372029': 558, 
'n03376595': 559, 'n03379051': 560, 'n03384352': 561, 'n03388043': 562, 'n03388183': 563, 'n03388549': 564, 'n03393912': 565, 'n03394916': 566, 'n03400231': 567, 'n03404251': 568, 'n03417042': 569, 'n03424325': 570, 'n03425413': 571, 'n03443371': 572, 'n03444034': 573, 'n03445777': 574, 'n03445924': 575, 'n03447447': 576, 'n03447721': 577, 'n03450230': 578, 'n03452741': 579, 'n03457902': 580, 'n03459775': 581, 'n03461385': 582, 'n03467068': 583, 'n03476684': 584, 'n03476991': 585, 'n03478589': 586, 'n03481172': 587, 'n03482405': 588, 'n03483316': 589, 'n03485407': 590, 'n03485794': 591, 'n03492542': 592, 'n03494278': 593, 'n03495258': 594, 'n03496892': 595, 'n03498962': 596, 'n03527444': 597, 'n03529860': 598, 'n03530642': 599, 'n03532672': 600, 'n03534580': 601, 'n03535780': 602, 'n03538406': 603, 'n03544143': 604, 'n03584254': 605, 'n03584829': 606, 'n03590841': 607, 'n03594734': 608, 'n03594945': 609, 'n03595614': 610, 'n03598930': 611, 'n03599486': 612, 'n03602883': 613, 'n03617480': 614, 'n03623198': 615, 'n03627232': 616, 'n03630383': 617, 'n03633091': 618, 'n03637318': 619, 'n03642806': 620, 'n03649909': 621, 'n03657121': 622, 'n03658185': 623, 'n03661043': 624, 'n03662601': 625, 'n03666591': 626, 'n03670208': 627, 'n03673027': 628, 'n03676483': 629, 'n03680355': 630, 'n03690938': 631, 'n03691459': 632, 'n03692522': 633, 'n03697007': 634, 'n03706229': 635, 'n03709823': 636, 'n03710193': 637, 'n03710637': 638, 'n03710721': 639, 'n03717622': 640, 'n03720891': 641, 'n03721384': 642, 'n03724870': 643, 'n03729826': 644, 'n03733131': 645, 'n03733281': 646, 'n03733805': 647, 'n03742115': 648, 'n03743016': 649, 'n03759954': 650, 'n03761084': 651, 'n03763968': 652, 'n03764736': 653, 'n03769881': 654, 'n03770439': 655, 'n03770679': 656, 'n03773504': 657, 'n03775071': 658, 'n03775546': 659, 'n03776460': 660, 'n03777568': 661, 'n03777754': 662, 'n03781244': 663, 'n03782006': 664, 'n03785016': 665, 'n03786901': 666, 'n03787032': 667, 'n03788195': 668, 'n03788365': 669, 
'n03791053': 670, 'n03792782': 671, 'n03792972': 672, 'n03793489': 673, 'n03794056': 674, 'n03796401': 675, 'n03803284': 676, 'n03804744': 677, 'n03814639': 678, 'n03814906': 679, 'n03825788': 680, 'n03832673': 681, 'n03837869': 682, 'n03838899': 683, 'n03840681': 684, 'n03841143': 685, 'n03843555': 686, 'n03854065': 687, 'n03857828': 688, 'n03866082': 689, 'n03868242': 690, 'n03868863': 691, 'n03871628': 692, 'n03873416': 693, 'n03874293': 694, 'n03874599': 695, 'n03876231': 696, 'n03877472': 697, 'n03877845': 698, 'n03884397': 699, 'n03887697': 700, 'n03888257': 701, 'n03888605': 702, 'n03891251': 703, 'n03891332': 704, 'n03895866': 705, 'n03899768': 706, 'n03902125': 707, 'n03903868': 708, 'n03908618': 709, 'n03908714': 710, 'n03916031': 711, 'n03920288': 712, 'n03924679': 713, 'n03929660': 714, 'n03929855': 715, 'n03930313': 716, 'n03930630': 717, 'n03933933': 718, 'n03935335': 719, 'n03937543': 720, 'n03938244': 721, 'n03942813': 722, 'n03944341': 723, 'n03947888': 724, 'n03950228': 725, 'n03954731': 726, 'n03956157': 727, 'n03958227': 728, 'n03961711': 729, 'n03967562': 730, 'n03970156': 731, 'n03976467': 732, 'n03976657': 733, 'n03977966': 734, 'n03980874': 735, 'n03982430': 736, 'n03983396': 737, 'n03991062': 738, 'n03992509': 739, 'n03995372': 740, 'n03998194': 741, 'n04004767': 742, 'n04005630': 743, 'n04008634': 744, 'n04009552': 745, 'n04019541': 746, 'n04023962': 747, 'n04026417': 748, 'n04033901': 749, 'n04033995': 750, 'n04037443': 751, 'n04039381': 752, 'n04040759': 753, 'n04041544': 754, 'n04044716': 755, 'n04049303': 756, 'n04065272': 757, 'n04067472': 758, 'n04069434': 759, 'n04070727': 760, 'n04074963': 761, 'n04081281': 762, 'n04086273': 763, 'n04090263': 764, 'n04099969': 765, 'n04111531': 766, 'n04116512': 767, 'n04118538': 768, 'n04118776': 769, 'n04120489': 770, 'n04125021': 771, 'n04127249': 772, 'n04131690': 773, 'n04133789': 774, 'n04136333': 775, 'n04141076': 776, 'n04141327': 777, 'n04141975': 778, 'n04146614': 779, 'n04147183': 780, 
'n04149813': 781, 'n04152593': 782, 'n04153751': 783, 'n04154565': 784, 'n04162706': 785, 'n04179913': 786, 'n04192698': 787, 'n04200800': 788, 'n04201297': 789, 'n04204238': 790, 'n04204347': 791, 'n04208210': 792, 'n04209133': 793, 'n04209239': 794, 'n04228054': 795, 'n04229816': 796, 'n04235860': 797, 'n04238763': 798, 'n04239074': 799, 'n04243546': 800, 'n04251144': 801, 'n04252077': 802, 'n04252225': 803, 'n04254120': 804, 'n04254680': 805, 'n04254777': 806, 'n04258138': 807, 'n04259630': 808, 'n04263257': 809, 'n04264628': 810, 'n04265275': 811, 'n04266014': 812, 'n04270147': 813, 'n04273569': 814, 'n04275548': 815, 'n04277352': 816, 'n04285008': 817, 'n04286575': 818, 'n04296562': 819, 'n04310018': 820, 'n04311004': 821, 'n04311174': 822, 'n04317175': 823, 'n04325704': 824, 'n04326547': 825, 'n04328186': 826, 'n04330267': 827, 'n04332243': 828, 'n04335435': 829, 'n04336792': 830, 'n04344873': 831, 'n04346328': 832, 'n04347754': 833, 'n04350905': 834, 'n04355338': 835, 'n04355933': 836, 'n04356056': 837, 'n04357314': 838, 'n04366367': 839, 'n04367480': 840, 'n04370456': 841, 'n04371430': 842, 'n04371774': 843, 'n04372370': 844, 'n04376876': 845, 'n04380533': 846, 'n04389033': 847, 'n04392985': 848, 'n04398044': 849, 'n04399382': 850, 'n04404412': 851, 'n04409515': 852, 'n04417672': 853, 'n04418357': 854, 'n04423845': 855, 'n04428191': 856, 'n04429376': 857, 'n04435653': 858, 'n04442312': 859, 'n04443257': 860, 'n04447861': 861, 'n04456115': 862, 'n04458633': 863, 'n04461696': 864, 'n04462240': 865, 'n04465501': 866, 'n04467665': 867, 'n04476259': 868, 'n04479046': 869, 'n04482393': 870, 'n04483307': 871, 'n04485082': 872, 'n04486054': 873, 'n04487081': 874, 'n04487394': 875, 'n04493381': 876, 'n04501370': 877, 'n04505470': 878, 'n04507155': 879, 'n04509417': 880, 'n04515003': 881, 'n04517823': 882, 'n04522168': 883, 'n04523525': 884, 'n04525038': 885, 'n04525305': 886, 'n04532106': 887, 'n04532670': 888, 'n04536866': 889, 'n04540053': 890, 'n04542943': 891, 
'n04548280': 892, 'n04548362': 893, 'n04550184': 894, 'n04552348': 895, 'n04553703': 896, 'n04554684': 897, 'n04557648': 898, 'n04560804': 899, 'n04562935': 900, 'n04579145': 901, 'n04579432': 902, 'n04584207': 903, 'n04589890': 904, 'n04590129': 905, 'n04591157': 906, 'n04591713': 907, 'n04592741': 908, 'n04596742': 909, 'n04597913': 910, 'n04599235': 911, 'n04604644': 912, 'n04606251': 913, 'n04612504': 914, 'n04613696': 915, 'n06359193': 916, 'n06596364': 917, 'n06785654': 918, 'n06794110': 919, 'n06874185': 920, 'n07248320': 921, 'n07565083': 922, 'n07579787': 923, 'n07583066': 924, 'n07584110': 925, 'n07590611': 926, 'n07613480': 927, 'n07614500': 928, 'n07615774': 929, 'n07684084': 930, 'n07693725': 931, 'n07695742': 932, 'n07697313': 933, 'n07697537': 934, 'n07711569': 935, 'n07714571': 936, 'n07714990': 937, 'n07715103': 938, 'n07716358': 939, 'n07716906': 940, 'n07717410': 941, 'n07717556': 942, 'n07718472': 943, 'n07718747': 944, 'n07720875': 945, 'n07730033': 946, 'n07734744': 947, 'n07742313': 948, 'n07745940': 949, 'n07747607': 950, 'n07749582': 951, 'n07753113': 952, 'n07753275': 953, 'n07753592': 954, 'n07754684': 955, 'n07760859': 956, 'n07768694': 957, 'n07802026': 958, 'n07831146': 959, 'n07836838': 960, 'n07860988': 961, 'n07871810': 962, 'n07873807': 963, 'n07875152': 964, 'n07880968': 965, 'n07892512': 966, 'n07920052': 967, 'n07930864': 968, 'n07932039': 969, 'n09193705': 970, 'n09229709': 971, 'n09246464': 972, 'n09256479': 973, 'n09288635': 974, 'n09332890': 975, 'n09399592': 976, 'n09421951': 977, 'n09428293': 978, 'n09468604': 979, 'n09472597': 980, 'n09835506': 981, 'n10148035': 982, 'n10565667': 983, 'n11879895': 984, 'n11939491': 985, 'n12057211': 986, 'n12144580': 987, 'n12267677': 988, 'n12620546': 989, 'n12768682': 990, 'n12985857': 991, 'n12998815': 992, 'n13037406': 993, 'n13040303': 994, 'n13044778': 995, 'n13052670': 996, 'n13054560': 997, 'n13133613': 998, 'n15075141': 999, 'test.py': 1000}"
      ],
      "metadata": {
        "id": "h-BspqhoIckh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "index_to_synset = {value: key for key, value in synset_to_index.items()}"
      ],
      "metadata": {
        "id": "-yRVVJ_4Ic7G"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def plurality_element(lst):\n",
        "    counts = Counter(lst)\n",
        "    plurality_element, _ = counts.most_common(1)[0]\n",
        "    return plurality_element"
      ],
      "metadata": {
        "id": "aHn3ESaAIgKV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_accuracy_based_on_class(df, max_num_neighbors = 1):\n",
        "  class_to_acc_list = {}\n",
        "\n",
        "  def update_dict(image_class, neighbor_class):\n",
        "\n",
        "    correctness_score = 0\n",
        "    if image_class == neighbor_class:\n",
        "      correctness_score = 1\n",
        "\n",
        "    if not image_class in class_to_acc_list:\n",
        "      class_to_acc_list[image_class] = [correctness_score]\n",
        "    else:\n",
        "      class_to_acc_list[image_class].append(correctness_score)\n",
        "\n",
        "  df.apply(lambda x: update_dict(image_class=x['image_class'], neighbor_class=plurality_element(x['neighbor_classes'][:max_num_neighbors])), axis=1)\n",
        "\n",
        "  class_to_acc = {}\n",
        "  for k, v in class_to_acc_list.items():\n",
        "    class_to_acc[k] = np.mean(v)\n",
        "\n",
        "  list_of_tuples = sorted(class_to_acc.items(), key=lambda x: x[1])\n",
        "  return [(x, y) for x, y in list_of_tuples]"
      ],
      "metadata": {
        "id": "pZvrLO_gIgMh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Reliability"
      ],
      "metadata": {
        "id": "2HIiGwuCYLgl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def get_reliability_at_k(df: pd.DataFrame, max_num_neighbors: int = 100) -> None:\n",
        "  \"\"\"Calculate the reliability at k (neighbor index) for each model.\"\"\"\n",
        "\n",
        "  result_df = pd.DataFrame(columns=['featurizer', 'k_index', 'accuracy_at_k'])\n",
        "  counter = 0\n",
        "\n",
        "  for featurizer in df['featurizer'].unique():\n",
        "    featurizer_df = copy.deepcopy(df.loc[df['featurizer'] == featurizer])\n",
        "    for k in range(max_num_neighbors):\n",
        "      acc = (featurizer_df.apply(lambda x: x.image_class == x.neighbor_classes[k], axis=1).sum()) / len(featurizer_df)\n",
        "      # Note: this is using k+1 to avoid division by zero\n",
        "      # when fitting a log function to the data in plotting\n",
        "      result_df.loc[counter] = [featurizer, k+1, acc]\n",
        "      counter += 1\n",
        "  return result_df\n",
        "\n",
        "\n",
        "def get_reliability_data_for_pruning(models,\n",
        "                                     query_split='validation', query_dataset='imagenet2012',\n",
        "                                     memory_dataset='imagenet2012', memory_split='train',\n",
        "                                     max_num_neighbors=100):\n",
        "  assert len(models) == 1, 'only 1 model supported at this time'\n",
        "  df_combined = read_multiple_scaling_dfs(models=models, query_split=query_split, query_dataset=query_dataset, memory_dataset=memory_dataset, memory_split=memory_split)\n",
        "  reliability_data = get_reliability_at_k(df=df_combined, max_num_neighbors=max_num_neighbors)\n",
        "  index_to_reliability = reliability_data.set_index('k_index')['accuracy_at_k'].to_dict()\n",
        "  return index_to_reliability\n",
        "\n",
        "\n",
        "def plot_reliability_at_k(df: pd.DataFrame, ylim: float = 0.0, fit: bool = True, save_fig_path: str = None) -> None:\n",
        "  \"\"\"Plot reliability at k (neighbor index) for each model.\"\"\"\n",
        "\n",
        "  plt.figure(figsize=(8, 5))\n",
        "\n",
        "  for _, featurizer in enumerate(df['featurizer'].unique()):\n",
        "    featurizer_df = df[df['featurizer'] == featurizer]\n",
        "\n",
        "    plt.plot(featurizer_df['k_index'], featurizer_df['accuracy_at_k']*100.0,\n",
        "             marker='o', linestyle='-',\n",
        "             linewidth=2, markersize=8,\n",
        "             color=featurizer_to_color[featurizer],\n",
        "             label=featurizer_to_name[featurizer])\n",
        "\n",
        "  if fit:\n",
        "    for _, featurizer in enumerate(df['featurizer'].unique()):\n",
        "      featurizer_df = df[df['featurizer'] == featurizer]\n",
        "      a, b = np.polyfit(np.log(featurizer_df['k_index']), featurizer_df['accuracy_at_k'], 1)\n",
        "      fit = []\n",
        "      for k in featurizer_df['k_index']:\n",
        "        value = a * np.log(k) + b\n",
        "        fit.append(value * 100.0)\n",
        "      plt.plot(featurizer_df['k_index'], fit, linestyle='-', linewidth=1.5, markersize=8, color='black')\n",
        "\n",
        "  plt.ylim(ylim)\n",
        "  plt.gca().spines['top'].set_visible(False)\n",
        "  plt.gca().spines['right'].set_visible(False)\n",
        "  plt.xticks(fontsize=12)\n",
        "  plt.yticks(fontsize=12)\n",
        "  plt.legend(fontsize=12)\n",
        "\n",
        "  plt.xlabel('Neighbor index', fontsize=14)\n",
        "  plt.ylabel('Neighbor accuracy (%)', fontsize=14)\n",
        "\n",
        "  if save_fig_path:\n",
        "    plt.savefig(file_opener(save_fig_path, 'wb'), format='pdf', bbox_inches='tight', pad_inches=0)\n",
        "    print(f'Saved figure to {save_fig_path}')\n",
        "    print_viewing_path(save_fig_path)"
      ],
      "metadata": {
        "id": "UGrSTxSqfFeN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']\n",
        "df_combined = read_multiple_scaling_dfs(models=models, query_split='validation', query_dataset='imagenet2012', memory_dataset='imagenet2012', memory_split='train')"
      ],
      "metadata": {
        "id": "IX9nD0GuFscN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "reliability_data = get_reliability_at_k(df=df_combined, max_num_neighbors=100)"
      ],
      "metadata": {
        "id": "tnaOmwx6Fsi-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "reliability_data[:10]"
      ],
      "metadata": {
        "id": "ChZu81jTrhEm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_reliability_at_k(df=reliability_data,\n",
        "                      ylim=(43, 83),\n",
        "                      save_fig_path=f'{FIGURE_DIR}/imagenet_reliability_at_k.pdf')"
      ],
      "metadata": {
        "id": "7p61dhirGHlM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']\n",
        "df_combined = read_multiple_scaling_dfs(models=models, query_dataset='inaturalist', query_split='validation', memory_dataset='inaturalist', memory_split='train')"
      ],
      "metadata": {
        "id": "NWp3rroU6enQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "reliability_data = get_reliability_at_k(df=df_combined, max_num_neighbors=100)"
      ],
      "metadata": {
        "id": "cTFGK-mY6jrm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "reliability_data[:10]"
      ],
      "metadata": {
        "id": "8_0Udk1b6lnW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_reliability_at_k(df=reliability_data,\n",
        "                      ylim=(0, 62),\n",
        "                      save_fig_path=f'{FIGURE_DIR}/inaturalist_reliability_at_k.pdf')"
      ],
      "metadata": {
        "id": "x48_T-SU6nQF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "models = ['dinov2_vitl14']\n",
        "df_combined = read_multiple_scaling_dfs(models=models,\n",
        "                                        query_split='validation', query_dataset='imagenet2012',\n",
        "                                        memory_dataset='jft-with-vit22b-labels', memory_split='train')\n",
        "reliability_data = get_reliability_at_k(df=df_combined, max_num_neighbors=100)\n",
        "plot_reliability_at_k(df=reliability_data,\n",
        "                      ylim=(70.0, 84.5))"
      ],
      "metadata": {
        "id": "Tx2AoD86WW6k"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Calibration"
      ],
      "metadata": {
        "id": "J655eo6oYDL5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def get_majority_element(l):\n",
        "  c = Counter(l)\n",
        "  elem, count = c.most_common()[0]\n",
        "  return elem\n",
        "\n",
        "def get_majority_count(l):\n",
        "  c = Counter(l)\n",
        "  elem, count = c.most_common()[0]\n",
        "  return count\n",
        "\n",
        "from scipy.stats import entropy\n",
        "\n",
        "def get_entropy(labels, base=2):\n",
        "  value, counts = np.unique(labels, return_counts=True)\n",
        "  return entropy(counts, base=base)\n",
        "\n",
        "def get_entropy_wrong(l):\n",
        "    # Count the occurrences of each value\n",
        "    counter = Counter(l)\n",
        "\n",
        "    # Calculate the total number of values\n",
        "    total_count = len(values)\n",
        "\n",
        "    # Calculate the probability of each value\n",
        "    probabilities = [count / total_count for count in counter.values()]\n",
        "\n",
        "    # Calculate entropy using the formula: -sum(p * log(p))\n",
        "    entropy_value = -sum(p * math.log2(p) for p in probabilities)\n",
        "\n",
        "    return entropy_value\n",
        "\n",
        "def get_fraction_of_majority_class(df: pd.DataFrame, max_num_neighbors: int = 100) -> None:\n",
        "  \"\"\"Get fraction of majority class.\"\"\"\n",
        "\n",
        "  result_df = pd.DataFrame(columns=['featurizer', 'count', 'accuracy'])\n",
        "\n",
        "  for featurizer in tqdm(df['featurizer'].unique()):\n",
        "    featurizer_df = copy.deepcopy(df.loc[df['featurizer'] == featurizer])\n",
        "    featurizer_df['majority-element'] = featurizer_df.apply(lambda x: get_majority_element(first_k(x.neighbor_classes, k=max_num_neighbors)), axis=1)\n",
        "    featurizer_df['majority-count'] = featurizer_df.apply(lambda x: get_majority_count(first_k(x.neighbor_classes, k=max_num_neighbors)), axis=1)\n",
        "    featurizer_df['entropy'] = featurizer_df.apply(lambda x: get_entropy(first_k(x.neighbor_classes, k=max_num_neighbors)), axis=1)\n",
        "\n",
        "    count_to_acc = {}\n",
        "\n",
        "    for i, row in featurizer_df.iterrows():\n",
        "      elem = row['majority-element']\n",
        "      count = row['majority-count']\n",
        "      ground_truth = row['image_class']\n",
        "      is_correct = ground_truth == elem\n",
        "\n",
        "      if not count_to_acc.get(count):\n",
        "        count_to_acc[count] = []\n",
        "\n",
        "      if is_correct:\n",
        "        count_to_acc[count].append(1)\n",
        "      else:\n",
        "        count_to_acc[count].append(0)\n",
        "\n",
        "    for c in range(max_num_neighbors):\n",
        "      if c in count_to_acc:\n",
        "        row = {'featurizer': featurizer,'count': c, 'accuracy': np.mean(count_to_acc[c])}\n",
        "        result_df = pd.concat([result_df, pd.DataFrame([row])], ignore_index=True)\n",
        "\n",
        "  return result_df\n",
        "\n",
        "def plot_accuracy_from_count(df, ylim=(0, 100), save_fig_path = None):\n",
        "  \"\"\"Plot accuracy .\"\"\"\n",
        "\n",
        "  plt.figure(figsize=(8, 5))\n",
        "\n",
        "  for _, featurizer in enumerate(df['featurizer'].unique()):\n",
        "    featurizer_df = df[df['featurizer'] == featurizer]\n",
        "\n",
        "    plt.plot(featurizer_df['count'], featurizer_df['accuracy']*100.0,\n",
        "             marker='o', linestyle='-',\n",
        "             linewidth=2, markersize=8,\n",
        "             color=featurizer_to_color[featurizer], label=featurizer)\n",
        "\n",
        "  plt.ylim(ylim)\n",
        "  plt.gca().spines['top'].set_visible(False)\n",
        "  plt.gca().spines['right'].set_visible(False)\n",
        "  plt.xticks(fontsize=12)\n",
        "  plt.yticks(fontsize=12)\n",
        "  plt.legend(fontsize=12)\n",
        "\n",
        "  # Plot diagonal line\n",
        "  plt.plot([0, 100], [0, 100], color='black')\n",
        "\n",
        "  plt.xlabel('Count of plurality class', fontsize=14)\n",
        "  plt.ylabel('Plurality voting accuracy (%)', fontsize=14)\n",
        "\n",
        "  if save_fig_path:\n",
        "    plt.savefig(file_opener(save_fig_path, 'wb'), format='pdf', bbox_inches='tight')\n",
        "    print(f'Saved figure to {save_fig_path}')\n",
        "    print_viewing_path(save_fig_path)"
      ],
      "metadata": {
        "id": "Ryw_Zr1ELCsF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# read data\n",
        "df = read_scaling_df(model='dinov2_vitl14',\n",
        "                     query_split='validation', query_dataset='imagenet2012',\n",
        "                     memory_dataset='imagenet2012', memory_split='train')"
      ],
      "metadata": {
        "id": "UXj8pdK1Ye8F"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Example usage\n",
        "values = [795, 795, 795, 970, 795, 795, 795, 795, 795]\n",
        "print(\"Entropy:\", get_entropy(values))"
      ],
      "metadata": {
        "id": "JexHaVerR5ya"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "df['entropy'] = df.apply(lambda x: get_entropy(x.neighbor_classes), axis=1)"
      ],
      "metadata": {
        "id": "kCZZtHP47ygh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "max(df['entropy'])"
      ],
      "metadata": {
        "id": "tm-S4kmV-UkU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "min(df['entropy'])"
      ],
      "metadata": {
        "id": "6P3GTnNqAcTf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_entropy_vs_accuracy(df, featurizer, save_fig_path=None):\n",
        "    # Bin the data into 0.05 intervals\n",
        "    bins = np.arange(0, 1.05, 0.05)  # 0, 0.05, 0.1, ..., 1.0\n",
        "\n",
        "    max_entropy = max(df['entropy'])\n",
        "    df['entropy-normalized'] = df.apply(lambda x: 1 - (x.entropy / max_entropy), axis=1)\n",
        "\n",
        "    # Group the data by confidence bins and calculate the mean accuracy for each bin\n",
        "    df['is_correct'] = df.apply(lambda x: get_majority_element(x.neighbor_classes) == x.image_class, axis=1)\n",
        "    binned_data = df.groupby(pd.cut(df['entropy-normalized'], bins, right=False))['is_correct'].mean()\n",
        "\n",
        "    # Convert interval index to midpoint for plotting\n",
        "    midpoints = [interval.left + 0.025 for interval in binned_data.index]\n",
        "\n",
        "    # Plot accuracy vs. confidence\n",
        "    plt.plot(midpoints, binned_data.values * 100, marker='o',\n",
        "             linewidth=2, markersize=8,\n",
        "             color=featurizer_to_color[featurizer], label=featurizer)\n",
        "\n",
        "    # Plot diagonal line from (0, 0) to (100, 100)\n",
        "    plt.plot([0, 1], [0, 100], color='black')\n",
        "\n",
        "    # Set labels and title\n",
        "    plt.xlabel('Reverse normalized entropy', fontsize=14)\n",
        "    plt.ylabel('Accuracy (%)', fontsize=14)\n",
        "\n",
        "    # Remove top and right frame\n",
        "    plt.gca().spines['top'].set_visible(False)\n",
        "    plt.gca().spines['right'].set_visible(False)\n",
        "\n",
        "    # Remove gridlines\n",
        "    plt.grid(False)\n",
        "\n",
        "    # Set axis limits\n",
        "    plt.xlim(0, 1)\n",
        "    plt.ylim(0, 100)\n",
        "    plt.xticks(fontsize=12)\n",
        "    plt.yticks(fontsize=12)\n",
        "    plt.legend(fontsize=12)\n",
        "\n",
        "    #plt.savefig(save_fig_path, format='pdf', bbox_inches='tight')\n",
        "    #plt.close()\n",
        "    plt.show()"
      ],
      "metadata": {
        "id": "is77JmEDR522"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "df_combined['entropy'] = df_combined.apply(lambda x: get_entropy(x.neighbor_classes), axis=1)"
      ],
      "metadata": {
        "id": "eXjldLgbAxx5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_entropy_vs_accuracy(df_combined.loc[df_combined['featurizer'] == 'dinov2_vitl14'],\n",
        "                         featurizer='dinov2_vitl14')\n",
        "                         #save_fig_path=f'{FIGURE_DIR}/calibration_entropy_vs_accuracy.pdf')"
      ],
      "metadata": {
        "id": "TW45SW2a8INC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_entropy_vs_accuracy(df_combined.loc[df_combined['featurizer'] == 'dinov2_vitb14'],\n",
        "                         featurizer='dinov2_vitl14')\n",
        "                         #save_fig_path=f'{FIGURE_DIR}/calibration_entropy_vs_accuracy.pdf')"
      ],
      "metadata": {
        "id": "5XxVze2jBVL0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_entropy_vs_accuracy(df_combined.loc[df_combined['featurizer'] == 'dinov2_vits14'],\n",
        "                         featurizer='dinov2_vitl14')\n",
        "                         #save_fig_path=f'{FIGURE_DIR}/calibration_entropy_vs_accuracy.pdf')"
      ],
      "metadata": {
        "id": "6Sq1atWnBVpn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "result_df = get_fraction_of_majority_class(df=df_combined, max_num_neighbors=100)"
      ],
      "metadata": {
        "id": "iCFF_HUMNl5U"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "result_df"
      ],
      "metadata": {
        "id": "cF0YBpM6SVGU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for featurizer in result_df['featurizer'].unique():\n",
        "  plot_accuracy_from_count(df=result_df.loc[result_df['featurizer'] == featurizer],\n",
        "                          ylim=(0, 100),\n",
        "                          save_fig_path=f'{FIGURE_DIR}/calibration_accuracy_from_count_{featurizer}.pdf')"
      ],
      "metadata": {
        "id": "or60mPHUMhOj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Defining different aggregation methods"
      ],
      "metadata": {
        "id": "WZIiDuP2JBE4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def first_k(x: pd.Series, k: int) -> pd.Series:\n",
        "    \"\"\"Return first k elements of x; raise error if k is too large.\n",
        "\n",
        "    This makes sure that we notice if x is too short (for whatever reason)\n",
        "    since [1,2,3][:42] would simply return [1,2,3] instead of raising an error.\n",
        "    \"\"\"\n",
        "    assert k > 0, print(k)\n",
        "    assert len(x) >= k, print(len(x), k)\n",
        "    return x[:k]"
      ],
      "metadata": {
        "id": "BmU4OE9TzlSy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_class_accuracy(df, max_num_neighbors=100):\n",
        "  \"\"\"For a given dataframe and max_num_neighbors, plot accuracy for each class.\n",
        "\n",
        "  Here, max_num_neighbors determines the maximum number of neighbors that were\n",
        "  considered when aggregating predictions.\n",
        "  \"\"\"\n",
        "\n",
        "  assert max_num_neighbors >= 1\n",
        "\n",
        "  accuracies_per_class = {}\n",
        "  for image_class, class_df in df.groupby('image_class'):\n",
        "\n",
        "    correct_predictions = class_df['prediction_at_k'].apply(lambda x: x[max_num_neighbors-1] == image_class)\n",
        "    accuracy = correct_predictions.mean() * 100\n",
        "    accuracies_per_class[image_class] = accuracy\n",
        "\n",
        "  plt.figure(figsize=(10, 6))\n",
        "  plt.plot([i for i in range(1000)], sorted(accuracies_per_class.values()))\n",
        "\n",
        "  plt.xlabel('Sorted class index')\n",
        "  plt.ylabel('Class-conditional accuracy (%)')\n",
        "  plt.legend()\n",
        "  plt.show()"
      ],
      "metadata": {
        "id": "fojW71OrClE8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_acc_at_k_table(df, aggregators, k_list):\n",
        "  \"\"\"Print LaTeX table of accuracy @k nearest neighbors for aggregators.\"\"\"\n",
        "\n",
        "  # Define columns\n",
        "  cols = ['Aggregation']\n",
        "  for k in k_list:\n",
        "    cols.append(f'@{k}')\n",
        "\n",
        "  # Create empty DataFrame\n",
        "  df_at_k = pd.DataFrame(columns=cols)\n",
        "\n",
        "  for aggregator in tqdm(aggregators):\n",
        "\n",
        "    accuracies = get_acc(df=df, aggregator=aggregator, max_num_neighbors=max(k_list))\n",
        "    row = {'Aggregation': aggregator.get_name()}\n",
        "    for k in k_list:\n",
        "      row[f'@{k}'] = accuracies[k-1]\n",
        "    df_at_k = pd.concat([df_at_k, pd.DataFrame([row])], ignore_index=True)\n",
        "\n",
        "  # Convert DataFrame to LaTeX table\n",
        "  formatters = dict()\n",
        "  cols_bold_mapping = {}\n",
        "  for c in cols[1:]:\n",
        "    cols_bold_mapping[c] = max\n",
        "\n",
        "  def format_numbers(y, num_digits=1):\n",
        "    return (\"{:.\" + str(num_digits) + \"f}\").format(y)\n",
        "\n",
        "  for c, func in cols_bold_mapping.items():\n",
        "    m = func(df_at_k[c])\n",
        "    formatters[c] = lambda y, m=m: \"\\\\textbf{\" + format_numbers(y) + \"}\" if y == m else format_numbers(y)\n",
        "\n",
        "  latex_table = df_at_k.to_latex(escape=False, formatters=formatters,\n",
        "                                 float_format=\"%.1f\", index=False)\n",
        "  print(latex_table)"
      ],
      "metadata": {
        "id": "J6MrU6i01Pt3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_neighbor_scaling(df: pd.DataFrame,\n",
        "                          max_num_neighbors: int,\n",
        "                          aggregators,\n",
        "                          ylim = (71.0, 77.0),\n",
        "                          label_aggregator=True,\n",
        "                          save_fig_path: str = None) -> None:\n",
        "  \"\"\"Plot accuracy as a function of number of nearest neighbors.\"\"\"\n",
        "\n",
        "  plt.figure(figsize=(8, 5))\n",
        "  colors  = mpl.colormaps['tab10'].colors\n",
        "  counter = 0\n",
        "\n",
        "  for aggregator in aggregators:\n",
        "\n",
        "    # aggregate results for the first k neighbors and store in list\n",
        "    df['prediction_at_k'] = df.apply(\n",
        "          lambda x: aggregator.predict(predictions=first_k(x.neighbor_classes, max_num_neighbors),\n",
        "                                       distances=first_k(x.neighbor_distances, max_num_neighbors),\n",
        "                                       neighbor_image_ids=first_k(x.neighbor_image_ids, max_num_neighbors),\n",
        "                                       featurizer=x.featurizer),\n",
        "          axis=1)\n",
        "\n",
        "    for featurizer in df['featurizer'].unique():\n",
        "      featurizer_df = copy.deepcopy(df.loc[df['featurizer'] == featurizer])\n",
        "\n",
        "      accuracy_at_k = []\n",
        "      for k in range(max_num_neighbors):\n",
        "        accuracy = calculate_accuracy(featurizer_df, k, 'prediction_at_k')\n",
        "        accuracy_at_k.append(100.0 * accuracy)\n",
        "        print(f'\\rCalculated {aggregator.get_name()} accuracy for {featurizer} at k={k}: {100.0 * accuracy}', end='', flush=True)\n",
        "\n",
        "      k_list = [x for x in range(1, max_num_neighbors+1)]\n",
        "      print('\\r', end='', flush=True)\n",
        "      print(f'Max {aggregator.get_name()} accuracy for {featurizer}: {np.round(np.max(accuracy_at_k), 3)} for k={k_list[np.argmax(accuracy_at_k)]}')\n",
        "\n",
        "      if label_aggregator:\n",
        "        if aggregator.get_name() in aggregator_to_color.keys():\n",
        "          color = aggregator_to_color[aggregator.get_name()]\n",
        "        else:\n",
        "          color = colors[counter]\n",
        "        label = aggregator.get_name()\n",
        "      else:\n",
        "        color = featurizer_to_color[featurizer]\n",
        "        label = featurizer_to_name[featurizer]\n",
        "\n",
        "      plt.plot(k_list, accuracy_at_k, marker='o', linestyle='-',\n",
        "              linewidth=2, markersize=8, color=color, label=label)\n",
        "      counter += 1\n",
        "\n",
        "  plt.gca().set_ylim(ylim)\n",
        "\n",
        "  plt.gca().spines['top'].set_visible(False)\n",
        "  plt.gca().spines['right'].set_visible(False)\n",
        "  plt.xticks(fontsize=12)\n",
        "  plt.yticks(fontsize=12)\n",
        "  plt.legend(fontsize=12)\n",
        "\n",
        "  plt.xlabel('Number of neighbors (k)', fontsize=14)\n",
        "  plt.ylabel('Top-1 accuracy (%)', fontsize=14)\n",
        "\n",
        "  if save_fig_path:\n",
        "    plt.savefig(file_opener(save_fig_path, 'wb'), format='pdf', bbox_inches='tight', pad_inches=0)\n",
        "    print(f'Saved figure to {save_fig_path}')\n",
        "    print_viewing_path(save_fig_path)\n",
        "\n",
        "  return df"
      ],
      "metadata": {
        "id": "OgFP2kmYYdDt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class PredictionAggregation():\n",
        "  \"\"\"Abstract base class for aggregating predictions from nearest neighbors.\"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    self.exclude_sets = dict()\n",
        "    self.hyperparam = None\n",
        "\n",
        "  def index_to_weight(self, index: int, *args, **kwargs) -> float:\n",
        "    raise NotImplementedError()\n",
        "\n",
        "  def add_exclude_sets(self, exclude_sets):\n",
        "    self.exclude_sets = exclude_sets\n",
        "\n",
        "  def predict(self,\n",
        "              predictions: list,\n",
        "              neighbor_image_ids: list,\n",
        "              featurizer: str,\n",
        "              *args, **kwargs) -> list[str]:\n",
        "    raise NotImplementedError()\n",
        "\n",
        "  def get_name(self, state_hyperparam=False) -> str:\n",
        "\n",
        "    if state_hyperparam and self.hyperparam:\n",
        "      return f'{self.__class__.__name__} ({self.hyperparam})'\n",
        "    else:\n",
        "      return self.__class__.__name__\n",
        "\n",
        "  def predict(self,\n",
        "              predictions: list,\n",
        "              neighbor_image_ids: list,\n",
        "              featurizer: str,\n",
        "              *args, **kwargs) -> list[str]:\n",
        "\n",
        "    predictions_at_k = []\n",
        "    counts = {}\n",
        "    max_count = -np.inf\n",
        "    highest_weight_class = None\n",
        "    i_count = 0\n",
        "\n",
        "    # In case add_exclude_sets() was called, exclude 'bad' neighbors\n",
        "    exclude_neighbors = featurizer in self.exclude_sets.keys()\n",
        "    exclude_set = set()\n",
        "    if exclude_neighbors:\n",
        "      exclude_set = self.exclude_sets[featurizer]\n",
        "      num_intersecting_bad_neighbors = len(set(neighbor_image_ids).intersection(exclude_set))\n",
        "\n",
        "      # If there's no good neighbor, proceed as usual without excluding any.\n",
        "      # Note that this determination is done based on all neighbors,\n",
        "      # not just on the information for k=1 etc.\n",
        "      if len(neighbor_image_ids) - num_intersecting_bad_neighbors <= 0:\n",
        "        exclude_neighbors = False\n",
        "\n",
        "    for i, p in enumerate(predictions):\n",
        "\n",
        "      if not (exclude_neighbors and neighbor_image_ids[i] in exclude_set):\n",
        "\n",
        "        kwargs['neighbor_image_ids'] = neighbor_image_ids\n",
        "        weight = self.index_to_weight(index=i_count, *args, **kwargs)\n",
        "\n",
        "        if p in counts:\n",
        "          counts[p] += weight\n",
        "        else:\n",
        "          counts[p] = weight\n",
        "\n",
        "        if counts[p] > max_count:\n",
        "          max_count = counts[p]\n",
        "          highest_weight_class = p\n",
        "\n",
        "        # Note: index_counter won't be increased in the event of\n",
        "        # a 'bad neighbor' that needs to be excluded\n",
        "        i_count += 1\n",
        "\n",
        "      if not highest_weight_class:\n",
        "        predictions_at_k.append(predictions[0])\n",
        "      else:\n",
        "        predictions_at_k.append(highest_weight_class)\n",
        "\n",
        "    if len(predictions_at_k) != len(predictions):\n",
        "      raise ValueError(len(predictions_at_k), len(predictions))\n",
        "\n",
        "    return predictions_at_k"
      ],
      "metadata": {
        "id": "dabRX6rjYd1L"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class PluralityVoting(PredictionAggregation):\n",
        "  \"\"\"Plurality vote in which every neighbor counts equally.\n",
        "\n",
        "  Distances and ranks are ignored: each neighbor adds a weight of 1 to\n",
        "  its class. Tie-breaking is left to the inherited aggregation logic.\n",
        "  Examples:\n",
        "  [0, 1, 1] -> return 1 # returning majority class(1)\n",
        "  [0, 1] -> return 0 # tie; returning first tied class (0)\n",
        "  [2, 3, 3, 1, 1] -> return 3 # tie; returning first tied class (3)\n",
        "  \"\"\"\n",
        "\n",
        "  # Constant vote weight shared by all neighbors.\n",
        "  _VOTE_WEIGHT = 1\n",
        "\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "\n",
        "  def get_name(self) -> str:\n",
        "    \"\"\"Return the name string of this aggregator.\"\"\"\n",
        "    return 'PluralityVoting'\n",
        "\n",
        "  def index_to_weight(self, index: int, *args, **kwargs) -> float:\n",
        "    \"\"\"Return the same weight for every neighbor index.\"\"\"\n",
        "    return self._VOTE_WEIGHT"
      ],
      "metadata": {
        "id": "i00qK7tXYd3v"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class DistanceWeightedVoting(PredictionAggregation):\n",
        "  \"\"\"Vote with weight exp(-distance) raised to a configurable exponent.\n",
        "\n",
        "  With exponent=0.0 every weight equals 1.0, i.e. the weighting used by\n",
        "  PluralityVoting. With exponent=1.0 the weight is exp(-distance), which\n",
        "  is identical to the voting done by Khandelwal et al. (2020),\n",
        "  Generalization through memorization: nearest neighbor language models.\n",
        "  With exponent > 1.0 the distribution is sharpened: low-distance neighbors\n",
        "  gain weight relative to high-distance neighbors.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, exponent: float = 1.0):\n",
        "    super().__init__()\n",
        "    self.exponent = exponent\n",
        "    # Expose the tunable value under the generic attribute name as well.\n",
        "    self.hyperparam = exponent\n",
        "\n",
        "  def get_name(self) -> str:\n",
        "    \"\"\"Return the name string of this aggregator.\"\"\"\n",
        "    return 'DistanceVoting'\n",
        "\n",
        "  def index_to_weight(self, index: int, distances: list, *args, **kwargs) -> float:\n",
        "    \"\"\"Return exp(-distances[index]) ** exponent.\"\"\"\n",
        "    unscaled_weight = math.exp(-distances[index])\n",
        "    return unscaled_weight ** self.exponent"
      ],
      "metadata": {
        "id": "zUvfI-SzYnNa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class RankWeightedVoting(PredictionAggregation):\n",
        "  \"\"\"Vote with weight 1 / (neighbor index + offset).\n",
        "\n",
        "  Earlier (lower-index) neighbors therefore receive larger weights.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, offset: float = 2.0):\n",
        "    super().__init__()\n",
        "    # index starts at 0, so a positive offset keeps the denominator non-zero\n",
        "    assert offset >= 0.1\n",
        "    self.offset = offset\n",
        "    self.hyperparam = offset\n",
        "\n",
        "  def get_name(self) -> str:\n",
        "    \"\"\"Return the name string of this aggregator.\"\"\"\n",
        "    return 'RankVoting'\n",
        "\n",
        "  def index_to_weight(self, index: int, *args, **kwargs) -> float:\n",
        "    \"\"\"Return the reciprocal of (index + offset).\"\"\"\n",
        "    denominator = index + self.offset\n",
        "    return 1 / denominator"
      ],
      "metadata": {
        "id": "571n2zTyYnWR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class SoftmaxWeightedVoting(PredictionAggregation):\n",
        "  \"\"\"Vote with unnormalized softmax weights derived from distances.\n",
        "\n",
        "  This is the default for kNN evaluation of SSL models such as DinoV2.\n",
        "  A high temperature makes the weights more similar to each other,\n",
        "  a low temperature pushes them towards a one-hot distribution.\n",
        "  \"\"\"\n",
        "\n",
        "  # The default temperature of 0.07 comes from:\n",
        "  # https://github.com/facebookresearch/dinov2/blob/main/dinov2/eval/knn.py#L91\n",
        "  # Exemplary implementations:\n",
        "  # https://github.com/facebookresearch/dino/blob/main/eval_knn.py#L143\n",
        "  # https://github.com/facebookresearch/dinov2/blob/main/dinov2/eval/knn.py#L99\n",
        "  def __init__(self, temperature: float = 0.07):\n",
        "    super().__init__()\n",
        "    self.temperature = temperature\n",
        "    self.hyperparam = temperature\n",
        "\n",
        "  def get_name(self) -> str:\n",
        "    \"\"\"Return the name string of this aggregator.\"\"\"\n",
        "    return 'SoftmaxVoting'\n",
        "\n",
        "  def index_to_weight(self, index: int, distances: list, *args, **kwargs) -> float:\n",
        "    \"\"\"Return exp((1 - distances[index]) / temperature).\n",
        "\n",
        "    Distances are expected to be normalized to [0, 1] (checked with a small\n",
        "    tolerance) and sorted (only the first two entries are checked).\n",
        "    The sum-normalization of a full softmax is skipped: only relative\n",
        "    differences between weights matter for voting, not probabilities.\n",
        "    \"\"\"\n",
        "\n",
        "    assert len(distances) <= 1 or distances[0] <= distances[1], 'Distances must be sorted'\n",
        "\n",
        "    # this code assumes distances to be normalized to [0, 1]\n",
        "    largest = max(distances)\n",
        "    assert largest <= 1.01\n",
        "    smallest = min(distances)\n",
        "    assert smallest >= -0.01\n",
        "\n",
        "    return np.exp((1 - distances[index]) / self.temperature)"
      ],
      "metadata": {
        "id": "BhlRTIajijP9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_acc(df, aggregator, max_num_neighbors):\n",
        "  \"\"\"Return {k: accuracy in percent} for k in range(max_num_neighbors).\n",
        "\n",
        "  The frame must contain exactly one featurizer. Note that a\n",
        "  'prediction_at_k' column is written into `df` as a side effect.\n",
        "  \"\"\"\n",
        "\n",
        "  assert len(df['featurizer'].unique()) == 1\n",
        "\n",
        "  def _predict_row(row):\n",
        "    # Truncate every per-neighbor list to the first max_num_neighbors entries.\n",
        "    return aggregator.predict(\n",
        "        predictions=first_k(row.neighbor_classes, max_num_neighbors),\n",
        "        distances=first_k(row.neighbor_distances, max_num_neighbors),\n",
        "        neighbor_image_ids=first_k(row.neighbor_image_ids, max_num_neighbors),\n",
        "        featurizer=row.featurizer)\n",
        "\n",
        "  df['prediction_at_k'] = df.apply(_predict_row, axis=1)\n",
        "\n",
        "  return {\n",
        "      k: 100.0 * calculate_accuracy(df, k, 'prediction_at_k')\n",
        "      for k in range(max_num_neighbors)\n",
        "  }\n",
        "\n",
        "def get_max_acc(df, aggregator, max_num_neighbors):\n",
        "  \"\"\"Return the highest accuracy (in percent) that get_acc reports over all k.\"\"\"\n",
        "  accuracy_at_k = get_acc(df=df, aggregator=aggregator, max_num_neighbors=max_num_neighbors)\n",
        "  return max(accuracy_at_k.values())"
      ],
      "metadata": {
        "id": "9G0eY18uqH_Z"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Aggregation method results"
      ],
      "metadata": {
        "id": "F0rZgr3yqXmW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Compare the four aggregation methods per model on ImageNet\n",
        "# (validation queries, train split as memory); each call gets its own save_fig_path.\n",
        "models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']\n",
        "\n",
        "# Per-model y-axis range, passed to plot_neighbor_scaling via `ylim`.\n",
        "model_to_ylim = {'dinov2_vitl14': (80.5, 84.0),\n",
        "                 'dinov2_vitb14': (78.5, 82.5),\n",
        "                 'dinov2_vits14': (75, 79.5),\n",
        "                 'clip-vit_l14': (75, 80.5),\n",
        "                 'clip-vit_b16': (68, 74.5)}\n",
        "\n",
        "for FEATURIZER in models:\n",
        "  df = read_scaling_df(model=FEATURIZER,\n",
        "                       query_dataset='imagenet2012', query_split='validation',\n",
        "                       memory_dataset='imagenet2012', memory_split='train')\n",
        "  # Fresh aggregator instances for every model.\n",
        "  aggregators = [\n",
        "    PluralityVoting(),\n",
        "    DistanceWeightedVoting(exponent=1.0),\n",
        "    SoftmaxWeightedVoting(),\n",
        "    RankWeightedVoting(offset=2.0),\n",
        "    ]\n",
        "  plot_neighbor_scaling(df=df,\n",
        "                        max_num_neighbors=100,\n",
        "                        aggregators=aggregators,\n",
        "                        ylim = model_to_ylim[FEATURIZER],\n",
        "                        save_fig_path=f'{FIGURE_DIR}/imagenet2012_aggregators_{FEATURIZER}_no_pruning.pdf');"
      ],
      "metadata": {
        "id": "NerwjlaDkwcv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Read the results of all five models for ImageNet (validation queries, train memory)\n",
        "# and bind them to `df_combined`, which the next cell plots.\n",
        "models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']\n",
        "df_combined = read_multiple_scaling_dfs(models=models, query_split='validation', query_dataset='imagenet2012', memory_dataset='imagenet2012', memory_split='train')"
      ],
      "metadata": {
        "id": "J58SjoP6kwl6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Plot all models in `df_combined` with RankWeightedVoting(offset=2.0) only;\n",
        "# label_aggregator=False is passed because a single aggregator is used.\n",
        "aggregators = [\n",
        "    RankWeightedVoting(offset=2.0),\n",
        "    ]\n",
        "plot_neighbor_scaling(df=df_combined,\n",
        "                      max_num_neighbors=100,\n",
        "                      aggregators=aggregators,\n",
        "                      ylim = (67, 84.0),\n",
        "                      label_aggregator=False,\n",
        "                      save_fig_path=f'{FIGURE_DIR}/imagenet2012_all_models_no_pruning.pdf');"
      ],
      "metadata": {
        "id": "odnL2hJkkwv9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Same per-model aggregator comparison as for ImageNet, now with\n",
        "# iNaturalist validation queries and the iNaturalist train split as memory.\n",
        "models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']\n",
        "\n",
        "# NOTE: rebinds `model_to_ylim` (previously the ImageNet ranges) to iNaturalist ranges.\n",
        "model_to_ylim = {'dinov2_vitl14': (48, 65),\n",
        "                 'dinov2_vitb14': (48, 63),\n",
        "                 'dinov2_vits14': (42, 58),\n",
        "                 'clip-vit_l14': (30, 42),\n",
        "                 'clip-vit_b16': (21, 24)}\n",
        "\n",
        "for FEATURIZER in models:\n",
        "  df = read_scaling_df(model=FEATURIZER,\n",
        "                       query_dataset='inaturalist', query_split='validation',\n",
        "                       memory_dataset='inaturalist', memory_split='train')\n",
        "  # Fresh aggregator instances for every model.\n",
        "  aggregators = [\n",
        "    PluralityVoting(),\n",
        "    DistanceWeightedVoting(exponent=1.0),\n",
        "    SoftmaxWeightedVoting(),\n",
        "    RankWeightedVoting(offset=2.0),\n",
        "    ]\n",
        "  plot_neighbor_scaling(df=df,\n",
        "                        max_num_neighbors=100,\n",
        "                        aggregators=aggregators,\n",
        "                        ylim = model_to_ylim[FEATURIZER],\n",
        "                        save_fig_path=f'{FIGURE_DIR}/inaturalist_aggregators_{FEATURIZER}_no_pruning.pdf');"
      ],
      "metadata": {
        "id": "pbG_1a5y4pKD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# NOTE: rebinds `df_combined` (previously the ImageNet results) to the\n",
        "# iNaturalist results (validation queries, train memory) of all five models.\n",
        "models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']\n",
        "df_combined = read_multiple_scaling_dfs(models=models, query_dataset='inaturalist', query_split='validation', memory_dataset='inaturalist', memory_split='train')"
      ],
      "metadata": {
        "id": "rBwtuY3r4V9G"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Plot all models in `df_combined` (iNaturalist at this point) with\n",
        "# RankWeightedVoting(offset=2.0) only.\n",
        "aggregators = [\n",
        "    RankWeightedVoting(offset=2.0),\n",
        "    ]\n",
        "plot_neighbor_scaling(df=df_combined,\n",
        "                      max_num_neighbors=100,\n",
        "                      aggregators=aggregators,\n",
        "                      ylim = (20, 65),\n",
        "                      label_aggregator=False,\n",
        "                      save_fig_path=f'{FIGURE_DIR}/inaturalist_all_models_no_pruning.pdf');"
      ],
      "metadata": {
        "id": "efwqkUIH4TmB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# For each model: print its name, then call get_acc_at_k_table with the four\n",
        "# aggregators at k = 10, 20, ..., 100 (ImageNet validation queries, train memory).\n",
        "models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']\n",
        "\n",
        "for model in models:\n",
        "  df_val = read_scaling_df(model=model,\n",
        "                           query_dataset='imagenet2012', query_split='validation',\n",
        "                           memory_dataset='imagenet2012', memory_split='train')\n",
        "  print()\n",
        "  print(model)\n",
        "  get_acc_at_k_table(df=df_val,\n",
        "                     aggregators=[PluralityVoting(),\n",
        "                                  DistanceWeightedVoting(exponent=1.0),\n",
        "                                  SoftmaxWeightedVoting(),\n",
        "                                  RankWeightedVoting(offset=2.0)],\n",
        "                     k_list=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100])"
      ],
      "metadata": {
        "id": "oR-F0fo_1cU6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Hyperparameter sensitivity"
      ],
      "metadata": {
        "id": "AyDozisYt5pu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_hyperparameter_accuracy(df, aggregator_fn, hyperparams, max_num_neighbors, verbose=False, save_fig_path=None):\n",
        "  \"\"\"Plot, per featurizer, the best accuracy reached for each hyperparameter value.\n",
        "\n",
        "  `aggregator_fn` is called with a single hyperparameter value to build the\n",
        "  aggregator; the plotted accuracy is the result of get_max_acc for it.\n",
        "  If `save_fig_path` is truthy, the figure is additionally written there as PDF.\n",
        "  \"\"\"\n",
        "\n",
        "  plt.figure(figsize=(8, 5))\n",
        "\n",
        "  for featurizer in df['featurizer'].unique():\n",
        "    # Work on an independent copy of this featurizer's rows.\n",
        "    rows_for_featurizer = df[df['featurizer'] == featurizer]\n",
        "    featurizer_df = copy.deepcopy(rows_for_featurizer)\n",
        "\n",
        "    best_accuracies = []\n",
        "    for hyperparam in hyperparams:\n",
        "      best_acc = get_max_acc(df=featurizer_df,\n",
        "                             aggregator=aggregator_fn(hyperparam),\n",
        "                             max_num_neighbors=max_num_neighbors)\n",
        "      best_accuracies.append(best_acc)\n",
        "      if verbose:\n",
        "        print(featurizer, hyperparam, best_acc)\n",
        "\n",
        "    plt.plot(hyperparams, best_accuracies,\n",
        "             marker='o', linestyle='-', linewidth=1.5, markersize=8,\n",
        "             color=featurizer_to_color[featurizer],\n",
        "             label=featurizer_to_name[featurizer])\n",
        "\n",
        "  # Hide the top and right frame lines.\n",
        "  axes = plt.gca()\n",
        "  for side in ('top', 'right'):\n",
        "    axes.spines[side].set_visible(False)\n",
        "\n",
        "  plt.xticks(fontsize=12)\n",
        "  plt.yticks(fontsize=12)\n",
        "  plt.legend(fontsize=12)\n",
        "\n",
        "  plt.xlabel('Hyperparameter', fontsize=14)\n",
        "  plt.ylabel('Top-1 accuracy (%)', fontsize=14)\n",
        "\n",
        "  if not save_fig_path:\n",
        "    return\n",
        "\n",
        "  plt.savefig(file_opener(save_fig_path, 'wb'), format='pdf', bbox_inches='tight', pad_inches=0.0)\n",
        "  print(f'Saved figure to {save_fig_path}')\n",
        "  print_viewing_path(save_fig_path)"
      ],
      "metadata": {
        "id": "ic2lQB4mt9jJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# NOTE: rebinds `df_combined` back to the ImageNet results (an earlier cell\n",
        "# pointed it at iNaturalist); the hyperparameter sweeps below read this frame.\n",
        "models = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']\n",
        "df_combined = read_multiple_scaling_dfs(models=models,\n",
        "                                        query_dataset='imagenet2012', query_split='validation',\n",
        "                                        memory_dataset='imagenet2012', memory_split='train')"
      ],
      "metadata": {
        "id": "RcRYfKIzt9dd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sweep the RankWeightedVoting offset (its only constructor argument)\n",
        "# from 1.0 to 20.0, using up to 100 neighbors.\n",
        "plot_hyperparameter_accuracy(df=df_combined,\n",
        "                             aggregator_fn=RankWeightedVoting,\n",
        "                             hyperparams = [1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],\n",
        "                             max_num_neighbors=100,\n",
        "                             save_fig_path=f'{FIGURE_DIR}/hyperparameters_RankVoting.pdf');"
      ],
      "metadata": {
        "id": "3gAeB69ut9l0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sweep the DistanceWeightedVoting exponent from 0.0 to 40.\n",
        "# NOTE: uses max_num_neighbors=10, unlike the RankWeightedVoting sweep (100).\n",
        "plot_hyperparameter_accuracy(df=df_combined,\n",
        "                             aggregator_fn=DistanceWeightedVoting,\n",
        "                             hyperparams = [0.0, 2.5, 5.0, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 22.5, 25.0, 27.5, 30, 32.5, 35, 37.5, 40],\n",
        "                             max_num_neighbors=10,\n",
        "                             save_fig_path=f'{FIGURE_DIR}/hyperparameters_DistanceVoting.pdf');"
      ],
      "metadata": {
        "id": "J9mQLFewzPpI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sweep the SoftmaxWeightedVoting temperature from 0.005 to 1.0\n",
        "# (0.07 is the class default), using up to 10 neighbors.\n",
        "plot_hyperparameter_accuracy(df=df_combined,\n",
        "                             aggregator_fn=SoftmaxWeightedVoting,\n",
        "                             hyperparams = [0.005, 0.01, 0.025, 0.05, 0.07, 0.09, 0.12, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],\n",
        "                             max_num_neighbors=10,\n",
        "                             save_fig_path=f'{FIGURE_DIR}/hyperparameters_SoftmaxVoting.pdf');"
      ],
      "metadata": {
        "id": "H7CFpNTzzWea"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Memory pruning plots"
      ],
      "metadata": {
        "id": "Xwx81g6svdS9"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### General pruning functionality"
      ],
      "metadata": {
        "id": "2xiTG7wKekFU"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def get_pruning_exclude_set(featurizer_list, metric_name, frac_data_excluded):\n",
        "  \"\"\"Build per-featurizer exclude sets from a pruning metric.\n",
        "\n",
        "  The rows returned by read_pruning_data(metric_name) are ranked by the\n",
        "  `metric_name` column in ascending order, and the lowest-scoring\n",
        "  `frac_data_excluded` fraction of them is marked for exclusion.\n",
        "\n",
        "  Args:\n",
        "    featurizer_list: Featurizer names to create an entry for.\n",
        "    metric_name: Name of the pruning metric; also the score column name.\n",
        "    frac_data_excluded: Fraction of rows to exclude, in [0, 1].\n",
        "\n",
        "  Returns:\n",
        "    Dict mapping every featurizer to the set of excluded 'img_name' values.\n",
        "    All featurizers share the same set object.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If frac_data_excluded is outside [0, 1].\n",
        "  \"\"\"\n",
        "  # A negative row count would make DataFrame.head() return all but the\n",
        "  # last rows, i.e. silently exclude nearly everything, so reject it here.\n",
        "  if not 0.0 <= frac_data_excluded <= 1.0:\n",
        "    raise ValueError(f'frac_data_excluded must be in [0, 1], got {frac_data_excluded}')\n",
        "\n",
        "  df = read_pruning_data(metric_name)\n",
        "\n",
        "  # Scores are sorted such that high scores mean keeping example is beneficial\n",
        "  # See REDACTED FOR ANONYMITY\n",
        "  # Ascending order therefore puts the exclusion candidates first.\n",
        "  df_sorted = df.sort_values(by=metric_name, ascending=True)\n",
        "\n",
        "  # Number of lowest-scoring rows to exclude (truncated towards zero).\n",
        "  num_rows_to_exclude = int(len(df) * frac_data_excluded)\n",
        "  excluded_rows = df_sorted.head(num_rows_to_exclude)\n",
        "\n",
        "  img_names_set = set(excluded_rows['img_name'])\n",
        "\n",
        "  return {featurizer: img_names_set for featurizer in featurizer_list}"
      ],
      "metadata": {
        "id": "56tFsP07ejeE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_class_from_imagename(image_id, dataset='imagenet'):\n",
        "  \"\"\"Return the part of `image_id` before its first underscore.\n",
        "\n",
        "  Only dataset='imagenet' is supported; any other value raises ValueError.\n",
        "  Example: 'n01440764_10026' -> 'n01440764'.\n",
        "  \"\"\"\n",
        "  if dataset != 'imagenet':\n",
        "    raise ValueError(f'Unknown dataset: {dataset}')\n",
        "\n",
        "  class_prefix, _, _ = image_id.partition('_')\n",
        "  return class_prefix"
      ],
      "metadata": {
        "id": "9C3hjlkCemOd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_wrong_and_correct_neighbors(df: pd.DataFrame, max_num_neighbors: int = 100) -> tuple[dict, dict]:\n",
        "  \"\"\"Return dicts of {img_id: count} for wrong and correct neighbors.\n",
        "\n",
        "  A neighbor counts as correct for a row when its entry in\n",
        "  'neighbor_classes' equals the row's 'image_class', and as wrong otherwise.\n",
        "\n",
        "  Args:\n",
        "    df: Non-empty frame with the columns 'neighbor_classes',\n",
        "      'neighbor_image_ids' (per-row lists of equal length) and 'image_class'.\n",
        "    max_num_neighbors: Only the first max_num_neighbors neighbors of each\n",
        "      row are used. None, or a value larger than the number of available\n",
        "      neighbors, uses all of them.\n",
        "\n",
        "  Returns:\n",
        "    Tuple (wrong_neighbors, correct_neighbors) of {img_id: count} dicts.\n",
        "  \"\"\"\n",
        "\n",
        "  assert len(df) > 0, 'df must contain at least one row'\n",
        "\n",
        "  # convert to np.arrays for speed\n",
        "  neighbor_classes = np.array(df['neighbor_classes'].tolist())[:, :max_num_neighbors]\n",
        "  neighbor_image_ids = np.array(df['neighbor_image_ids'].tolist())[:, :max_num_neighbors]\n",
        "  image_class = np.array(df['image_class'].tolist())\n",
        "\n",
        "  # The (num_rows, 1) column broadcasts against the (num_rows, k) matrix.\n",
        "  image_class_column = image_class[:, np.newaxis]\n",
        "\n",
        "  def _get_neighbor_count_dict(mask):\n",
        "    # Count how often each neighbor image id is selected by the mask.\n",
        "    unique_elements, counts = np.unique(neighbor_image_ids[mask], return_counts=True)\n",
        "    return dict(zip(unique_elements, counts))\n",
        "\n",
        "  wrong_neighbors = _get_neighbor_count_dict(image_class_column != neighbor_classes)\n",
        "  num_wrong_neighbors = sum(wrong_neighbors.values())\n",
        "  print(f'Found {len(wrong_neighbors.keys())} wrong neighbors, occurring a total of {num_wrong_neighbors} times.')\n",
        "\n",
        "  correct_neighbors = _get_neighbor_count_dict(image_class_column == neighbor_classes)\n",
        "  num_correct_neighbors = sum(correct_neighbors.values())\n",
        "  print(f'Found {len(correct_neighbors.keys())} correct neighbors, occurring a total of {num_correct_neighbors} times.')\n",
        "  print(f'Total number of unique neighbors found: {len(set(correct_neighbors.keys()).union(set(wrong_neighbors.keys())))}, occuring a total of {num_wrong_neighbors + num_correct_neighbors} times.')\n",
        "\n",
        "  # Every neighbor that was actually looked at must be either wrong or correct.\n",
        "  # Comparing against the sliced array size (instead of\n",
        "  # len(df) * max_num_neighbors) stays valid when fewer neighbors are available.\n",
        "  assert neighbor_image_ids.size == num_correct_neighbors + num_wrong_neighbors\n",
        "\n",
        "  return wrong_neighbors, correct_neighbors"
      ],
      "metadata": {
        "id": "keygCUeVem0B"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_exclude_set(neighbor_dict, threshold):\n",
        "  \"\"\"Return the set of keys whose count is greater than or equal to threshold.\"\"\"\n",
        "\n",
        "  return {\n",
        "      img_id\n",
        "      for img_id, count in neighbor_dict.items()\n",
        "      if count >= threshold\n",
        "  }"
      ],
      "metadata": {
        "id": "rCPQ6NnIem2h"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_exclude_set_wrong_minus_correct_difference(wrong_neighbor_dict, correct_neighbor_dict, threshold):\n",
        "  \"\"\"Return keys whose (wrong count - correct count) is at least threshold.\n",
        "\n",
        "  Keys from both dicts are considered; a key missing from one dict\n",
        "  (or stored with a falsy count) contributes a count of 0 for that dict.\n",
        "  \"\"\"\n",
        "\n",
        "  def _count_or_zero(neighbor_dict, img_id):\n",
        "    return neighbor_dict.get(img_id) or 0\n",
        "\n",
        "  candidates = set(wrong_neighbor_dict).union(correct_neighbor_dict)\n",
        "\n",
        "  return {\n",
        "      img_id\n",
        "      for img_id in candidates\n",
        "      if (_count_or_zero(wrong_neighbor_dict, img_id)\n",
        "          - _count_or_zero(correct_neighbor_dict, img_id)) >= threshold\n",
        "  }"
      ],
      "metadata": {
        "id": "IsO4No1aem45"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Hard memory pruning"
      ],
      "metadata": {
        "id": "cHr4FT_nbRlU"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# dinov2_vitl14 results for ImageNet validation queries against the train-split\n",
        "# memory; the pruning plots below are evaluated on this frame.\n",
        "df_val = read_scaling_df(model='dinov2_vitl14',\n",
        "                         query_dataset='imagenet2012', query_split='validation',\n",
        "                         memory_dataset='imagenet2012', memory_split='train')"
      ],
      "metadata": {
        "id": "OCohPm8Pv6iP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Train-split queries against the train-split memory; the neighbor statistics\n",
        "# used for pruning are computed from this frame.\n",
        "# NOTE(review): remove_identical_neighbors=True presumably drops each query's own\n",
        "# image from its neighbor list - confirm in read_scaling_df.\n",
        "df_train = read_scaling_df(model='dinov2_vitl14',\n",
        "                           query_dataset='imagenet2012', query_split='train',\n",
        "                           memory_dataset='imagenet2012', memory_split='train',\n",
        "                           remove_identical_neighbors=True)"
      ],
      "metadata": {
        "id": "ewEFg_qpv-pN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# {img_id: count} of wrong / correct neighbor occurrences on the train-vs-train\n",
        "# results, using the first 100 neighbors per query.\n",
        "wrong_neighbors_train, correct_neighbors_train = get_wrong_and_correct_neighbors(df=df_train, max_num_neighbors=100)"
      ],
      "metadata": {
        "id": "hXdf_OFHvlO1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# PluralityVoting\n",
        "# no pruning: 83.22\n",
        "# hard pruning: 83.33 for k=10 and excluding 26257 neighbors based on threshold 128\n",
        "\n",
        "# The first aggregator is the unpruned baseline. Each further one excludes the\n",
        "# images that occurred as a wrong neighbor at least `i` times on the train split.\n",
        "aggregators = [PluralityVoting()]\n",
        "for i in [32, 64, 128, 192, 256, 320, 384, 448, 512]:\n",
        "  aggregator = PluralityVoting()\n",
        "  exclude_set = get_exclude_set(neighbor_dict=wrong_neighbors_train, threshold=i)\n",
        "  # threshold and number of excluded images\n",
        "  print(i, len(exclude_set))\n",
        "  aggregator.add_exclude_sets({'dinov2_vitl14': exclude_set})\n",
        "  aggregators.append(aggregator)\n",
        "\n",
        "plot_neighbor_scaling(df=df_val,\n",
        "                      max_num_neighbors=100,\n",
        "                      aggregators=aggregators,\n",
        "                      ylim = (80.0, 84.5));"
      ],
      "metadata": {
        "id": "6d_JT52NvlRb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# RankWeightedVoting(2.0)\n",
        "# no pruning: 83.62\n",
        "# hard pruning: 83.73 for k=20 and excluding 26257 neighbors based on threshold 128\n",
        "\n",
        "# The first aggregator is the unpruned baseline. Each further one excludes the\n",
        "# images that occurred as a wrong neighbor at least `i` times on the train split.\n",
        "aggregators = [RankWeightedVoting(offset=2.0)]\n",
        "for i in [32, 64, 128, 192, 256, 320, 384, 448, 512]:\n",
        "  aggregator = RankWeightedVoting(offset=2.0)\n",
        "  exclude_set = get_exclude_set(neighbor_dict=wrong_neighbors_train, threshold=i)\n",
        "  # threshold and number of excluded images\n",
        "  print(i, len(exclude_set))\n",
        "  aggregator.add_exclude_sets({'dinov2_vitl14': exclude_set})\n",
        "  aggregators.append(aggregator)\n",
        "\n",
        "plot_neighbor_scaling(df=df_val,\n",
        "                      max_num_neighbors=100,\n",
        "                      aggregators=aggregators,\n",
        "                      ylim = (80.0, 84.5));"
      ],
      "metadata": {
        "id": "GVB9a-dgvlTW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Soft memory pruning"
      ],
      "metadata": {
        "id": "N7afcezV-X7K"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def swarm_plot(data_as_list, sample_size=100):\n",
        "  \"\"\"Show a seaborn swarm plot of the values in `data_as_list`.\n",
        "\n",
        "  Lists longer than `sample_size` are randomly subsampled without\n",
        "  replacement to `sample_size` values; shorter lists are plotted as is.\n",
        "  The argument must be an actual list.\n",
        "  \"\"\"\n",
        "\n",
        "  assert type(data_as_list) is list\n",
        "\n",
        "  # plotting is slow for many data points, therefore subsample if too large\n",
        "  needs_subsampling = len(data_as_list) > sample_size\n",
        "  values = (\n",
        "      np.random.choice(data_as_list, size=sample_size, replace=False)\n",
        "      if needs_subsampling else data_as_list\n",
        "  )\n",
        "\n",
        "  frame = pd.DataFrame({'Value': values})\n",
        "\n",
        "  plt.figure(figsize=(10, 6))\n",
        "  sns.swarmplot(x='Value', data=frame)\n",
        "  plt.show()"
      ],
      "metadata": {
        "id": "Ophsh3iqm55G"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class SoftPluralityVoting(PredictionAggregation):\n",
        "  \"\"\"Plurality-style voting where each neighbor carries a per-image soft weight.\n",
        "\n",
        "  Weights are looked up in `image_id_to_weight`; images without an entry\n",
        "  vote with weight 1.0.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, image_id_to_weight: dict[str, float]):\n",
        "    super().__init__()\n",
        "    self.image_id_to_weight = image_id_to_weight\n",
        "\n",
        "  def index_to_weight(self, index: int, neighbor_image_ids, *args, **kwargs) -> float:\n",
        "    \"\"\"Return the soft weight of the neighbor at `index` (1.0 if unknown).\"\"\"\n",
        "    return self.image_id_to_weight.get(neighbor_image_ids[index], 1.0)"
      ],
      "metadata": {
        "id": "0KPJwaMg0_OC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class SoftRankVoting(PredictionAggregation):\n",
        "  \"\"\"Rank-based voting scaled by a per-image soft weight.\n",
        "\n",
        "  The weight of a neighbor is 1 / (index + offset) multiplied by its entry\n",
        "  in `image_id_to_weight`; images without an entry use a factor of 1.0.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, image_id_to_weight: dict[str, float], offset: float = 2.0):\n",
        "    super().__init__()\n",
        "    # index starts at 0, so a positive offset keeps the denominator non-zero\n",
        "    assert offset >= 0.1\n",
        "    self.offset = offset\n",
        "    self.image_id_to_weight = image_id_to_weight\n",
        "    self.hyperparam = offset\n",
        "\n",
        "  def get_name(self):\n",
        "    \"\"\"Return the name string of this aggregator.\"\"\"\n",
        "    return 'SoftRankVoting (ours)'\n",
        "\n",
        "  def index_to_weight(self, index: int, neighbor_image_ids, *args, **kwargs) -> float:\n",
        "    \"\"\"Return rank weight times soft weight for the neighbor at `index`.\"\"\"\n",
        "    soft_weight = self.image_id_to_weight.get(neighbor_image_ids[index], 1.0)\n",
        "    rank_weight = 1 / (index + self.offset)\n",
        "    return rank_weight * soft_weight"
      ],
      "metadata": {
        "id": "uVKfbOLfjJA1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_soft_weights(wrong_neighbors, constant = 1.0, dividend=1.75):\n",
        "  \"\"\"Map every key of `wrong_neighbors` to dividend / (constant + count).\n",
        "\n",
        "  Larger counts therefore yield smaller weights.\n",
        "  \"\"\"\n",
        "  return {\n",
        "      img_id: dividend / (constant + wrong_count)\n",
        "      for img_id, wrong_count in wrong_neighbors.items()\n",
        "  }"
      ],
      "metadata": {
        "id": "sGV6dKHOi9CC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# NOTE(review): same call as the `df_train` load in the hard-pruning section;\n",
        "# presumably repeated so that this section can be run on its own.\n",
        "df_train = read_scaling_df(model='dinov2_vitl14',\n",
        "                           query_dataset='imagenet2012', query_split='train',\n",
        "                           memory_dataset='imagenet2012', memory_split='train',\n",
        "                           remove_identical_neighbors=True)"
      ],
      "metadata": {
        "id": "3V37tiyuAI0I"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# NOTE(review): same call as the `df_val` load in the hard-pruning section;\n",
        "# presumably repeated so that this section can be run on its own.\n",
        "df_val = read_scaling_df(model='dinov2_vitl14',\n",
        "                         query_dataset='imagenet2012', query_split='validation',\n",
        "                         memory_dataset='imagenet2012', memory_split='train')"
      ],
      "metadata": {
        "id": "RdSg0Bv3AIV6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "wrong_neighbors_train, correct_neighbors_train = get_wrong_and_correct_neighbors(df=df_train, max_num_neighbors=10)"
      ],
      "metadata": {
        "id": "JJPtAs3yAEDt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "soft_weights = get_soft_weights(wrong_neighbors_train)"
      ],
      "metadata": {
        "id": "6qugyAnYjixi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "x = list(wrong_neighbors_train.values())\n",
        "swarm_plot(x, sample_size=100)"
      ],
      "metadata": {
        "id": "WWP7jf5unDEx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "x = list(soft_weights.values())\n",
        "swarm_plot(x, sample_size=150)"
      ],
      "metadata": {
        "id": "tRKuzdKBm9OK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "aggregators = [SoftRankVoting(image_id_to_weight=soft_weights), RankWeightedVoting(offset=2.0)]\n",
        "_ = plot_neighbor_scaling(df=df_val,\n",
        "                          max_num_neighbors=100,\n",
        "                          aggregators=aggregators,\n",
        "                          ylim = (81.0, 84.5));"
      ],
      "metadata": {
        "id": "IAEdNW5CmjiE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "aggregators = [SoftPluralityVoting(image_id_to_weight=soft_weights), PluralityVoting()]\n",
        "_ = plot_neighbor_scaling(df=df_val,\n",
        "                          max_num_neighbors=100,\n",
        "                          aggregators=aggregators,\n",
        "                          ylim = (81.0, 84.5));"
      ],
      "metadata": {
        "id": "5j4TlaLSmjky"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "aggregators = [\n",
        "    PluralityVoting(),\n",
        "    DistanceWeightedVoting(exponent=1.0),\n",
        "    SoftmaxWeightedVoting(),\n",
        "    SoftRankVoting(image_id_to_weight=soft_weights, offset=2.0)\n",
        "]\n",
        "\n",
        "_ = plot_neighbor_scaling(df=df_val,\n",
        "                          max_num_neighbors=100,\n",
        "                          aggregators=aggregators,\n",
        "                          ylim = (80.0, 84.5),\n",
        "                          save_fig_path=f'{FIGURE_DIR}/soft_pruning_vs_baselines_dinov2_vitl14.pdf');"
      ],
      "metadata": {
        "id": "hZwva2EZmmqz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_hyperparameter_accuracy(df=df_val,\n",
        "                             aggregator_fn=lambda y: SoftRankVoting(image_id_to_weight=get_soft_weights(wrong_neighbors=wrong_neighbors_train, constant=1.0, dividend=1.75)),\n",
        "                             hyperparams = np.linspace(1, 1.2, 3),\n",
        "                             max_num_neighbors=100,\n",
        "                             verbose=True);"
      ],
      "metadata": {
        "id": "9_SdLWZfkF2r"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## OOD robustness analysis"
      ],
      "metadata": {
        "id": "yzY_Ow-hfgMk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def get_accuracy_df(featurizer_list, datasets, aggregators, normalize_distances=False, memory_dataset='imagenet2012', max_num_neighbors=100):\n",
        "\n",
        "  rows = []\n",
        "\n",
        "  for featurizer in featurizer_list:\n",
        "    for qdataset, qsplit in datasets.items():\n",
        "      data_df = read_scaling_df(model=featurizer,\n",
        "                                query_split=qsplit, query_dataset=qdataset,\n",
        "                                memory_dataset=memory_dataset,\n",
        "                                normalize_distances=normalize_distances)\n",
        "      for aggregator in aggregators:\n",
        "        acc = get_max_acc(df=data_df,\n",
        "                          aggregator=aggregator,\n",
        "                          max_num_neighbors=max_num_neighbors)\n",
        "        row = {'featurizer': featurizer,\n",
        "               'qdataset': qdataset,\n",
        "               'qsplit': qsplit,\n",
        "               'aggregator': aggregator.get_name(),\n",
        "               'accuracy': acc}\n",
        "        rows.append(row)\n",
        "\n",
        "  return pd.DataFrame(rows)"
      ],
      "metadata": {
        "id": "qB7swLK7gjti"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "featurizer_list = ['dinov2_vitl14']\n",
        "datasets = {'imagenet2012': 'validation',\n",
        "            'imagenet-v2': 'test',\n",
        "            'imagenet-r': 'test',\n",
        "            'imagenet-a': 'test',\n",
        "            'imagenet-sketch': 'test'}\n",
        "            #'imagenet-real': 'test'}\n",
        "aggregators = [PluralityVoting(), DistanceWeightedVoting(), SoftmaxWeightedVoting(), RankWeightedVoting()]"
      ],
      "metadata": {
        "id": "5Y9nMBRAfhdM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "OOD_df_JFT = get_accuracy_df(featurizer_list, datasets, aggregators, memory_dataset='jft-with-vit22b-labels', max_num_neighbors=10, normalize_distances=True)"
      ],
      "metadata": {
        "id": "8RT2UkfY0wkv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "OOD_df_JFT"
      ],
      "metadata": {
        "id": "R8ykiPDxpXvI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "OOD_df_IN = get_accuracy_df(featurizer_list, datasets, aggregators, memory_dataset='imagenet2012', max_num_neighbors=100)"
      ],
      "metadata": {
        "id": "G8TiEc4g93Zr"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "OOD_df_IN = get_accuracy_df(featurizer_list=featurizer_list, datasets={'imagenet-real': 'test'}, aggregators=aggregators, memory_dataset='imagenet2012', max_num_neighbors=10, normalize_distances=True)"
      ],
      "metadata": {
        "id": "ZWjFSLpyTgBk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "OOD_df_IN"
      ],
      "metadata": {
        "id": "TcfkMdIKp_zp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "OOD_df_IN # TODO delete after comparing to top plot"
      ],
      "metadata": {
        "id": "ny3O92q8j6u6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## NINCO out-of-distribution classes"
      ],
      "metadata": {
        "id": "k62CGdDZ_rSL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "df_IN_baseline = read_scaling_df(model='dinov2_vitl14', query_split='validation', query_dataset='imagenet2012', memory_dataset='imagenet2012', memory_split='train')"
      ],
      "metadata": {
        "id": "WCsdy3vzIbJ-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "df_IN_val = read_scaling_df(model='dinov2_vitl14', query_split='validation', query_dataset='imagenet2012', memory_dataset='imagenet2012-and-ninco', memory_split='train-and-test')"
      ],
      "metadata": {
        "id": "g4rDtBVz_vuO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "df_NINCO_test = read_scaling_df(model='dinov2_vitl14', query_split='test', query_dataset='ninco', memory_dataset='imagenet2012-and-ninco', memory_split='train-and-test', remove_identical_neighbors=True)"
      ],
      "metadata": {
        "id": "BzweGsaoABKq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "df_NINCO_OOD = read_scaling_df(model='dinov2_vitl14', query_split='test', query_dataset='ninco', memory_dataset='imagenet2012', memory_split='train')"
      ],
      "metadata": {
        "id": "5zluLL7MgItD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "NINCO_classes = ['Caracal caracal caracal', 'amphiuma_means', 'aphanizomenon_flosaquae', 'araneus_gemma',\n",
        "                 'arctocephalus_galapagoensis', 'bagpipe', 'batrachoseps_attenuatus', 'cable', 'chicken_quesadilla',\n",
        "                 'cirsium_pitcheri', 'creme_brulee', 'ctenolepisma_longicaudata', 'cup_cakes',\n",
        "                 'darlingtonia_californica', 'dendrolagus_lumholtzi', 'donuts', 'door', 'empty_water_dispencer',\n",
        "                 'epithelantha_micromeris', 'erysimum_franciscanum', 'f_field_road', 'f_forest_path',\n",
        "                 'ferocactus_pilosus', 'fire_extinguisher', 'fireworks', 'french_fries', 'glass_of_milk',\n",
        "                 'gramophone', 'haemulon_sciurus', 'high heels', 'hindu_temple', 'hippopus_hippopus',\n",
        "                 'lasionycteris_noctivagans', 'lathyrus_odoratus', 'lepomis_auritus', 'leptoglossus_phyllopus',\n",
        "                 'mbira', 'microcystis_wesenbergii', 'octopus_bimaculoides', 'octopus_rubescens',\n",
        "                 'ozotoceros_bezoarticus', 'platycephalus_fuscus', 'polistes_dominula', 'pseudorca_crassidens',\n",
        "                 'pyramid', 's_sky', 'sarpa_salpa', 'sarracenia_alata', 'scissors', 'sepia_apama',\n",
        "                 'sepia_officinalis', 'sepioteuthis_australis', 'shuttlecock', 'skipper_caterpillar',\n",
        "                 'spaghetti_bolognese', 'stapler', 'streptopus_lanceolatus', 'tapirus_bairdii', 'triturus_marmoratus',\n",
        "                 'tursiops_aduncus', 'vaccinium_reticulatum', 'waffles', 'walker', 'windsor_chair']\n",
        "print(len(NINCO_classes))"
      ],
      "metadata": {
        "id": "r_3cXnIiALMq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Baseline: memory=IN-train, query=IN-val\n",
        "aggregators = [\n",
        "    PluralityVoting(),\n",
        "    DistanceWeightedVoting(exponent=1.0),\n",
        "    SoftmaxWeightedVoting(),\n",
        "    RankWeightedVoting(),\n",
        "    ]\n",
        "plot_neighbor_scaling(df=df_IN_baseline,\n",
        "                      max_num_neighbors=100,\n",
        "                      aggregators=aggregators,\n",
        "                      ylim = (80.0, 84.0));"
      ],
      "metadata": {
        "id": "OtYzEgPFIlZ6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# memory=IN-train-and-NINCO, query=IN-val\n",
        "aggregators = [\n",
        "    PluralityVoting(),\n",
        "    DistanceWeightedVoting(exponent=1.0),\n",
        "    SoftmaxWeightedVoting(),\n",
        "    RankWeightedVoting(),\n",
        "    ]\n",
        "plot_neighbor_scaling(df=df_IN_val,\n",
        "                      max_num_neighbors=100,\n",
        "                      aggregators=aggregators,\n",
        "                      ylim = (80.0, 84.0));"
      ],
      "metadata": {
        "id": "ePlItYOIBkci"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# memory=IN-train-and-NINCO, query=NINCO\n",
        "aggregators = [\n",
        "    PluralityVoting(),\n",
        "    DistanceWeightedVoting(exponent=1.0),\n",
        "    SoftmaxWeightedVoting(),\n",
        "    RankWeightedVoting(),\n",
        "    ]\n",
        "plot_neighbor_scaling(df=df_NINCO_test,\n",
        "                      max_num_neighbors=99,\n",
        "                      aggregators=aggregators,\n",
        "                      ylim = (70.0, 88.0));"
      ],
      "metadata": {
        "id": "f3x_m2C8GmAY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# memory=IN-train-and-NINCO, query=IN-val-and-NINCO\n",
        "aggregators = [\n",
        "    PluralityVoting(),\n",
        "    DistanceWeightedVoting(exponent=1.0),\n",
        "    SoftmaxWeightedVoting(),\n",
        "    RankWeightedVoting(),\n",
        "    ]\n",
        "plot_neighbor_scaling(df=pd.concat([df_IN_val, df_NINCO_test], ignore_index=True),\n",
        "                      max_num_neighbors=99,\n",
        "                      aggregators=aggregators,\n",
        "                      ylim = (70.0, 88.0));"
      ],
      "metadata": {
        "id": "de2DrxKMJBxy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### OOD detection analysis"
      ],
      "metadata": {
        "id": "K7rS3sU93Kd6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_distance_histogram(df):\n",
        "  fig, axs = plt.subplots(1, len(df['featurizer'].unique())+1, figsize=(15, 4))\n",
        "  for i, featurizer in enumerate(df['featurizer'].unique()):\n",
        "    featurizer_df = df[df['featurizer'] == featurizer]\n",
        "    min_distance = np.min(featurizer_df['neighbor_distances'].apply(lambda y: np.min(y)))\n",
        "    max_distance = np.max(featurizer_df['neighbor_distances'].apply(lambda y: np.max(y)))\n",
        "    mean_distance = np.mean(featurizer_df['neighbor_distances'].apply(lambda y: np.mean(y)))\n",
        "    median_distance = np.median(featurizer_df['neighbor_distances'].apply(lambda y: np.median(y)))\n",
        "    std_distance = np.std(featurizer_df['neighbor_distances'].apply(lambda y: np.std(y)))\n",
        "    print(f'{featurizer}, min_distance: {min_distance}, max_distance: {max_distance}, mean_distance: {mean_distance}, median_distance: {median_distance}, std_distance: {std_distance}')\n",
        "\n",
        "    fig.suptitle('Histogram of distances of first neighbors')\n",
        "    #axs[i].hist(featurizer_df['neighbor_distances'].apply(lambda x: x[0]), bins=20);\n",
        "    axs[i].hist(featurizer_df['neighbor_distances'].apply(lambda x: np.mean(x)), bins=20);\n",
        "    axs[i].set_title(featurizer)"
      ],
      "metadata": {
        "id": "LuRrfMRagbz2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_distance_histogram(df_NINCO_OOD)"
      ],
      "metadata": {
        "id": "5A9MXIBbghmz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_distance_histogram(df_IN_baseline)"
      ],
      "metadata": {
        "id": "pRis05QPgmng"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### OOD distance boxplot"
      ],
      "metadata": {
        "id": "3FsTZTNSx-Ms"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_distance_boxplot(df_baseline, df_OOD, statistic='median', save_fig_path=None):\n",
        "\n",
        "  assert statistic in ['median', 'mean']\n",
        "  if statistic == 'median':\n",
        "    ylabel = 'Median distance to first 100 neighbors'\n",
        "  else:\n",
        "    ylabel = 'Mean distance to first 100 neighbors'\n",
        "\n",
        "  df_baseline = copy.deepcopy(df_baseline)\n",
        "  df_OOD = copy.deepcopy(df_OOD)\n",
        "\n",
        "  df_baseline['group'] = 'in-distribution (ImageNet)'\n",
        "  df_OOD['group'] = 'OOD (NINCO)'\n",
        "\n",
        "  df_baseline['median'] = df_baseline['neighbor_distances'].apply(lambda x: np.median(x[:100]))\n",
        "  df_OOD['median'] = df_OOD['neighbor_distances'].apply(lambda x: np.median(x[:100]))\n",
        "  df_baseline['mean'] = df_baseline['neighbor_distances'].apply(lambda x: np.mean(x[:100]))\n",
        "  df_OOD['mean'] = df_OOD['neighbor_distances'].apply(lambda x: np.mean(x[:100]))\n",
        "\n",
        "  combined_df = pd.concat([df_baseline, df_OOD])\n",
        "\n",
        "  # Create the boxplot\n",
        "  plt.figure(figsize=(10, 6))\n",
        "\n",
        "  sns.boxplot(data=combined_df, x='group', y=statistic)\n",
        "  plt.gca().patch.set_facecolor('None')\n",
        "\n",
        "  ax = plt.gca()\n",
        "  ax.spines['top'].set_visible(False)\n",
        "  ax.spines['right'].set_visible(False)\n",
        "\n",
        "  sns.set(font_scale=1.6)  # Adjust the scaling factor as needed\n",
        "  sns.set_style(\"white\")\n",
        "\n",
        "  plt.xlabel('', fontsize=20)\n",
        "  plt.ylabel(ylabel, fontsize=17)\n",
        "\n",
        "  if save_fig_path:\n",
        "    plt.savefig(file_opener(save_fig_path, 'wb'), format='pdf', bbox_inches='tight', pad_inches=0)\n",
        "    print(f'Saved figure to {save_fig_path}')\n",
        "    print_viewing_path(save_fig_path)"
      ],
      "metadata": {
        "id": "pLFU9c-ckUB4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_distance_boxplot(df_baseline=df_IN_baseline,\n",
        "                      df_OOD=df_NINCO_OOD,\n",
        "                      statistic='median',\n",
        "                      save_fig_path=f'{FIGURE_DIR}/boxplot_OOD_detection_NINCO_median.pdf')"
      ],
      "metadata": {
        "id": "BuLnhaaw0YYh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_distance_boxplot(df_baseline=df_IN_baseline,\n",
        "                      df_OOD=df_NINCO_OOD,\n",
        "                      statistic='mean',\n",
        "                      save_fig_path=f'{FIGURE_DIR}/boxplot_OOD_detection_NINCO_mean.pdf')"
      ],
      "metadata": {
        "id": "r_kyHB0o1G37"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Scaling dataset size"
      ],
      "metadata": {
        "id": "EFi68iKtKkxt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def read_imagenet_scaling_results_for_dataset_scaling_experiment(featurizers, sample_sizes, aggregator, max_num_neighbors=100):\n",
        "\n",
        "  df = pd.DataFrame(columns=['featurizer', 'sample_size', 'accuracy'])\n",
        "\n",
        "  counter = 0\n",
        "  for featurizer in featurizers:\n",
        "    for sample_size in sample_sizes:\n",
        "      scaling_df = read_scaling_df(model=featurizer,\n",
        "                                   query_dataset='imagenet2012',\n",
        "                                   query_split='validation',\n",
        "                                   memory_dataset='imagenet2012',\n",
        "                                   memory_split='train',\n",
        "                                   size=sample_size,\n",
        "                                   verbose=False)\n",
        "\n",
        "      acc = get_max_acc(df=scaling_df, aggregator=aggregator, max_num_neighbors=max_num_neighbors)\n",
        "      if sample_size == 'full':\n",
        "        sample_size = 1_281_167 # full IN train set\n",
        "      df.loc[counter] = [featurizer, sample_size, acc]\n",
        "      counter += 1\n",
        "\n",
        "  return df\n",
        "\n",
        "def read_jft_scaling_results_for_dataset_scaling_experiment(data_dir=DATA_DIR):\n",
        "  df1 = pd.read_csv(f'{data_dir}/dinov2_vitl14_JFT_scaling.csv')\n",
        "  df2 = pd.read_csv(f'{data_dir}/dinov2_vits14_JFT_scaling.csv')\n",
        "  combined_df = pd.concat([df1, df2], ignore_index=True)\n",
        "  return combined_df"
      ],
      "metadata": {
        "id": "2kvFPbQMqtix"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_data_scaling(df: pd.DataFrame,\n",
        "                      x_range = [0, 1, 2, 3, 4, 5, 6, 7],\n",
        "                      x_ticks = ['1', '10', '100', '1K', '10K', '100K', '1M'],\n",
        "                      linestyle='-',\n",
        "                      multiply_accuracy_by_100 = False,\n",
        "                      save_fig_path = None,\n",
        "                      plot_error_rate = True,\n",
        "                      fit = False,\n",
        "                      plot_yaxis_log_scale = True) -> None:\n",
        "  \"\"\"Plot accuracy as a function of memory dataset size.\"\"\"\n",
        "\n",
        "  plt.figure(figsize=(8, 5))\n",
        "\n",
        "  for _, featurizer in enumerate(df['featurizer'].unique()):\n",
        "    featurizer_df = df[df['featurizer'] == featurizer]\n",
        "\n",
        "    x = np.log10(featurizer_df['sample_size'])\n",
        "    y = featurizer_df['accuracy'].to_numpy()\n",
        "    if multiply_accuracy_by_100:\n",
        "      y = y * 100\n",
        "\n",
        "    if plot_error_rate:\n",
        "      y = 100 - y # convert accuracy to error rate\n",
        "\n",
        "    if fit:\n",
        "      assert plot_yaxis_log_scale and plot_error_rate and (not multiply_accuracy_by_100)\n",
        "      a, b = np.polyfit(np.log10(x), np.log10(y), deg=1)\n",
        "      print(f'{featurizer}: a = {a}, b = {b}')\n",
        "      x_range_for_fit = np.linspace(x_range[0], x_range[-1], 100).tolist()\n",
        "      fit = [10**(a * np.log10(x_val) + b) for x_val in x_range_for_fit]\n",
        "      plt.plot(x_range_for_fit, fit, linestyle='-', linewidth=1.75, color=featurizer_to_color[featurizer])\n",
        "\n",
        "    plt.plot(x, y, marker='o', linestyle=linestyle, linewidth=2, markersize=10,\n",
        "             color=featurizer_to_color[featurizer],\n",
        "             label=featurizer_to_name[featurizer])\n",
        "\n",
        "  # Use log y scale with human-readable accuracies\n",
        "  if plot_yaxis_log_scale:\n",
        "    plt.yscale('log', base=10)\n",
        "  ax = plt.gca()\n",
        "  ax.yaxis.set_minor_formatter('{x:.0f}')\n",
        "  ax.tick_params(axis='y', which='minor', labelsize=12)\n",
        "\n",
        "\n",
        "  plt.xticks(x_range, x_ticks)\n",
        "\n",
        "  ax.spines['top'].set_visible(False)\n",
        "  ax.spines['right'].set_visible(False)\n",
        "  plt.xticks(fontsize=12)\n",
        "  plt.yticks(fontsize=12)\n",
        "  plt.legend(fontsize=12)\n",
        "\n",
        "  plt.xlabel('Number of images in memory', fontsize=14)\n",
        "  plt.ylabel('Top-1 error rate (%)', fontsize=14)\n",
        "\n",
        "  if save_fig_path:\n",
        "    plt.savefig(file_opener(save_fig_path, 'wb'), format='pdf', bbox_inches='tight')\n",
        "    print(f'Saved figure to {save_fig_path}')\n",
        "    print_viewing_path(save_fig_path)"
      ],
      "metadata": {
        "id": "VFvSyIgvsNh8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "featurizer_list = ['dinov2_vitl14', 'dinov2_vitb14', 'dinov2_vits14', 'clip-vit_l14', 'clip-vit_b16']\n",
        "\n",
        "data_IN = read_imagenet_scaling_results_for_dataset_scaling_experiment(\n",
        "    featurizers=featurizer_list,\n",
        "    sample_sizes=[1_000, 10_000, 100_000, 'full'],\n",
        "    aggregator=PluralityVoting())"
      ],
      "metadata": {
        "id": "E1FTZuj5sa3j"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "data_IN"
      ],
      "metadata": {
        "id": "Tq0u6WEAx01d"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_data_scaling(df=data_IN,\n",
        "                  x_range=[np.log10(i) for i in [1_000, 10_000, 100_000, 1_281_167]],\n",
        "                  x_ticks=['1K', '10K', '100K', '1.28M'],\n",
        "                  save_fig_path=f'{FIGURE_DIR}/memory_size_scaling_imagenet.pdf')"
      ],
      "metadata": {
        "id": "2I_WNURRUlwq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "data_JFT = read_jft_scaling_results_for_dataset_scaling_experiment()"
      ],
      "metadata": {
        "id": "LfRBdwE3_gCu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "data_JFT"
      ],
      "metadata": {
        "id": "5g4HGtek_hqK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# log fit in log log space\n",
        "plot_data_scaling(df=data_JFT,\n",
        "                  x_range=[np.log10(i) for i in [1_000, 10_000, 100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000]],\n",
        "                  x_ticks=['1K', '10K', '100K', '1M', '10M', '100M', '1B'],\n",
        "                  linestyle='',\n",
        "                  fit = True,\n",
        "                  save_fig_path=f'{FIGURE_DIR}/memory_size_scaling_jft.pdf')"
      ],
      "metadata": {
        "id": "ZPHelNKpObOo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Raw JFT data"
      ],
      "metadata": {
        "id": "J53MYJXYThzb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "np.log10(data_JFT.loc[data_JFT['featurizer'] == 'dinov2_vitl14']['sample_size']).tolist()"
      ],
      "metadata": {
        "id": "HlU1QoSQSBYR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "[100 - x for x in data_JFT.loc[data_JFT['featurizer'] == 'dinov2_vitl14']['accuracy'].tolist()]"
      ],
      "metadata": {
        "id": "shE06t15SQec"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}