{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "When does dough become a bagel? ImageNet-M Evaluation",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/imagenet-mistakes/blob/main/notebooks/When_does_dough_become_a_bagel__ImageNet_M_Evaluation.ipynb)"
      ],
      "metadata": {
        "id": "B2nhG2wmQZbG"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "This notebook demonstrates ImageNet-M evaluation, as well as\n",
        "multi-label validation and standard ImageNet validation, using pre-computed\n",
        "logits for the ViT3B and Greedy Soups models.\n"
      ],
      "metadata": {
        "id": "fdHHDfKeK4Ey"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k1hgxSQi-aR7"
      },
      "outputs": [],
      "source": [
        "#@title Run some imports.\n",
        "import json\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import tqdm"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Accuracy and Loader functions.\n",
        "def imagenet_val_accuracy(imagenet_val_ds, logits_dict):\n",
        "  \"\"\"Returns top-1 accuracy against the original ImageNet labels.\n",
        "\n",
        "  Args:\n",
        "    imagenet_val_ds: Sized iterable of dicts, each with a 'label' class id and\n",
        "      a 'file_name' whose `.numpy()` value is a key of `logits_dict`.\n",
        "    logits_dict: Mapping from file name to that image's logits.\n",
        "\n",
        "  Returns:\n",
        "    The fraction of examples whose label passes `tf.nn.in_top_k` with k=1.\n",
        "  \"\"\"\n",
        "  num_hits = 0\n",
        "  for record in tqdm.tqdm(imagenet_val_ds):\n",
        "    target = record['label']\n",
        "    probs = tf.nn.softmax(logits_dict[record['file_name'].numpy()]).numpy()\n",
        "    # in_top_k works on batches, so wrap the single example in a batch of one.\n",
        "    is_hit = tf.nn.in_top_k(\n",
        "        np.reshape(target, [1]), probs[np.newaxis], k=1).numpy()[0]\n",
        "    num_hits += 1 if is_hit else 0\n",
        "  return num_hits / float(len(imagenet_val_ds))\n",
        "\n",
        "def imagenet_multilabel_accuracy(imagenet_multilabel_ds, logits_dict,\n",
        "                                 num_classes=1000):\n",
        "  \"\"\"Computes class-balanced multi-label accuracy.\n",
        "\n",
        "  Images flagged as problematic are skipped. For every remaining image the\n",
        "  top prediction counts as correct if it appears among the image's correct or\n",
        "  unclear labels. Accuracy is computed per original class and then averaged\n",
        "  over all classes.\n",
        "\n",
        "  Args:\n",
        "    imagenet_multilabel_ds: Iterable of dicts with 'is_problematic',\n",
        "      'original_label' and 'file_name' entries exposing `.numpy()`, plus\n",
        "      'correct_multi_labels' and 'unclear_multi_labels' sequences of class ids.\n",
        "    logits_dict: Mapping from file name to that image's logits.\n",
        "    num_classes: Number of classes; ids 0..num_classes-1 must all occur among\n",
        "      the non-problematic images. Defaults to the 1,000 ImageNet classes.\n",
        "\n",
        "  Returns:\n",
        "    The mean of the per-class accuracies.\n",
        "\n",
        "  Raises:\n",
        "    AssertionError: If the set of observed classes is not exactly\n",
        "      0..num_classes-1.\n",
        "  \"\"\"\n",
        "  num_correct_per_class = {}\n",
        "  num_images_per_class = {}\n",
        "  for example in imagenet_multilabel_ds:\n",
        "    # We ignore all problematic images\n",
        "    if example['is_problematic'].numpy():\n",
        "      continue\n",
        "\n",
        "    correct_labels = (list(example['correct_multi_labels']) +\n",
        "                      list(example['unclear_multi_labels']))\n",
        "\n",
        "    # The label of the image in ImageNet\n",
        "    cur_class = example['original_label'].numpy()\n",
        "\n",
        "    # Register the class in both counters, even if nothing is ever predicted\n",
        "    # correctly for it, so the coverage checks below see every class.\n",
        "    num_correct_per_class.setdefault(cur_class, 0)\n",
        "    num_images_per_class[cur_class] = num_images_per_class.get(cur_class, 0) + 1\n",
        "\n",
        "    # Get the predictions for this image\n",
        "    file_name = example['file_name'].numpy()\n",
        "    logits = logits_dict[file_name]\n",
        "    scores = tf.nn.softmax(logits).numpy()\n",
        "    top_pred = np.argmax(scores)\n",
        "\n",
        "    # We count a prediction as correct if it is marked as correct or unclear\n",
        "    # (i.e., we are lenient with the unclear labels)\n",
        "    if top_pred in correct_labels:\n",
        "      num_correct_per_class[cur_class] += 1\n",
        "\n",
        "  # Check that we have collected accuracy data for each of the classes\n",
        "  assert len(num_correct_per_class) == num_classes, (\n",
        "      f'expected {num_classes} classes, saw {len(num_correct_per_class)}')\n",
        "  assert len(num_images_per_class) == num_classes, (\n",
        "      f'expected {num_classes} classes, saw {len(num_images_per_class)}')\n",
        "\n",
        "  # Compute the per-class accuracies and then average them\n",
        "  final_avg = 0\n",
        "  for cid in range(num_classes):\n",
        "    assert cid in num_correct_per_class, f'no data for class {cid}'\n",
        "    assert cid in num_images_per_class, f'no data for class {cid}'\n",
        "    final_avg += num_correct_per_class[cid] / num_images_per_class[cid]\n",
        "  final_avg /= num_classes\n",
        "  return final_avg\n",
        "\n",
        "def imagenet_m_correct(imagenet_m_ds, logits_dict):\n",
        "  \"\"\"Counts the ImageNet-M images whose top prediction is acceptable.\n",
        "\n",
        "  A prediction is acceptable if it appears in the image's correct or unclear\n",
        "  multi-labels.\n",
        "\n",
        "  Args:\n",
        "    imagenet_m_ds: Iterable of dicts with 'correct_multi_labels' and\n",
        "      'unclear_multi_labels' sequences and a 'file_name' exposing `.numpy()`.\n",
        "    logits_dict: Mapping from file name to that image's logits.\n",
        "\n",
        "  Returns:\n",
        "    The number (not the fraction) of acceptable top predictions.\n",
        "  \"\"\"\n",
        "  def _is_hit(record):\n",
        "    accepted = [*record['correct_multi_labels'],\n",
        "                *record['unclear_multi_labels']]\n",
        "    probs = tf.nn.softmax(logits_dict[record['file_name'].numpy()]).numpy()\n",
        "    return np.argmax(probs) in accepted\n",
        "\n",
        "  return sum(1 for record in imagenet_m_ds if _is_hit(record))\n",
        "\n",
        "def load_npy_as_dict(path):\n",
        "  \"\"\"Loads per-image logits and keys them by ImageNet validation file name.\n",
        "\n",
        "  Row i of the stored array is assigned to the (i+1)-th validation image, i.e.\n",
        "  to the bytes key b'ILSVRC2012_val_<i+1, zero-padded to 8 digits>.JPEG'.\n",
        "\n",
        "  Args:\n",
        "    path: Location of the saved array, opened through `tf.io.gfile.GFile`.\n",
        "      Either a plain NPY file or an NPZ archive holding exactly one array.\n",
        "\n",
        "  Returns:\n",
        "    A dict mapping file-name bytes to the corresponding row of the array.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If `path` is an NPZ archive that does not hold exactly one\n",
        "      array.\n",
        "  \"\"\"\n",
        "  # Use a context manager so the file handle is closed once the data is read.\n",
        "  with tf.io.gfile.GFile(path, 'rb') as f:\n",
        "    logits_array = np.load(f)\n",
        "    # np.load returns a lazy archive object (with a `files` list) for NPZ\n",
        "    # input; pull the single array out while the file is still open.\n",
        "    if hasattr(logits_array, 'files'):\n",
        "      if len(logits_array.files) != 1:\n",
        "        raise ValueError(\n",
        "            f'Expected exactly one array in {path}, found: {logits_array.files}')\n",
        "      logits_array = logits_array[logits_array.files[0]]\n",
        "\n",
        "  ret = {}\n",
        "  for i, row in enumerate(logits_array):\n",
        "    filename = f'ILSVRC2012_val_{i+1:08}.JPEG'.encode()\n",
        "    ret[filename] = row\n",
        "  return ret"
      ],
      "metadata": {
        "cellView": "form",
        "id": "drRqzhqC-ecV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Download all relevant assets.\n",
        "# Class metadata; later used to map WordNet ids (wnid) to class ids (cid).\n",
        "!wget https://raw.githubusercontent.com/modestyachts/ImageNetV2/master/data/metadata/class_info.json\n",
        "# Original ImageNet validation labels, one wnid per line in image order.\n",
        "!wget https://raw.githubusercontent.com/tensorflow/datasets/master/tensorflow_datasets/image_classification/imagenet2012_validation_labels.txt\n",
        "# Multi-label annotations, the problematic-image list and the ImageNet-M file list.\n",
        "!wget https://storage.googleapis.com/brain-car-datasets/imagenet-mistakes/human_accuracy_v3.0.0.json\n",
        "# Pre-computed logits for the two evaluated models.\n",
        "!wget https://storage.googleapis.com/brain-car-datasets/imagenet-mistakes/logits/vit3b.npz\n",
        "!wget https://storage.googleapis.com/brain-car-datasets/imagenet-mistakes/logits/greedysoups.npz\n"
      ],
      "metadata": {
        "id": "TjBItwj7CT15"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now we load and preprocess the downloaded files. This should take less than a minute to run."
      ],
      "metadata": {
        "id": "OnV4Q6iAKOwK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# WNID to CID mapping.\n",
        "# Read through a context manager so the file handle is closed right away.\n",
        "with open('class_info.json') as f:\n",
        "  imagenet_class_info = json.load(f)\n",
        "wnid_to_cid = {entry['wnid']: entry['cid'] for entry in imagenet_class_info}\n",
        "\n",
        "# Validation set list, in wnid form (one entry per line, in image order).\n",
        "with open('imagenet2012_validation_labels.txt') as f:\n",
        "  labels_list = f.read().strip().splitlines()\n",
        "\n",
        "# Human accuracy json\n",
        "with open('human_accuracy_v3.0.0.json') as f:\n",
        "  human_accuracy = json.load(f)\n",
        "imagenetv1_prefix = 'ILSVRC2012_val_'\n",
        "prefix_len = len(imagenetv1_prefix)\n",
        "# Keep only the annotations whose file name carries the ImageNet validation\n",
        "# prefix.\n",
        "annotated_images = {\n",
        "    image_name: data\n",
        "    for image_name, data in human_accuracy['initial_annots'].items()\n",
        "    if image_name.startswith(imagenetv1_prefix)\n",
        "}\n",
        "\n",
        "problematic_images = set(human_accuracy['problematic_images'].keys())\n",
        "imagenet_m_2022_files = set(human_accuracy['imagenet_m'])\n",
        "\n",
        "# Set up data structure of inputs: one record per annotated image, with the\n",
        "# ImageNet-M records being the subset listed in `imagenet_m_2022_files`.\n",
        "imagenet_multilabel_v3 = {}\n",
        "imagenet_m_2022 = {}\n",
        "for image_name, data in annotated_images.items():\n",
        "  # Translate each annotation group from WordNet ids to integer class ids; a\n",
        "  # missing group is treated as empty.\n",
        "  correct_cids, wrong_cids, unclear_cids = (\n",
        "      [wnid_to_cid[wnid] for wnid in data.get(group, [])]\n",
        "      for group in ('correct', 'wrong', 'unclear'))\n",
        "\n",
        "  # The digits between the last underscore and the 5-character extension are\n",
        "  # the 1-based position of the image in the validation label list.\n",
        "  image_id = int(image_name.rsplit('_', 1)[-1][:-5])\n",
        "  original_label = wnid_to_cid[labels_list[image_id - 1]]\n",
        "\n",
        "  entry = {\n",
        "      'file_name': tf.constant(image_name),\n",
        "      'is_problematic': tf.constant(image_name in problematic_images),\n",
        "      'correct_multi_labels': correct_cids,\n",
        "      'wrong_multi_labels': wrong_cids,\n",
        "      'unclear_multi_labels': unclear_cids,\n",
        "      'label': original_label,\n",
        "      'original_label': tf.constant(original_label),\n",
        "  }\n",
        "  imagenet_multilabel_v3[image_name] = entry\n",
        "  if image_name in imagenet_m_2022_files:\n",
        "    imagenet_m_2022[image_name] = entry\n",
        "\n",
        "\n",
        "# Full imagenet validation labels.\n",
        "# Entry N (1-based) of the label list belongs to the file whose name carries N\n",
        "# zero-padded to 8 digits.\n",
        "imagenet_val = {}\n",
        "for index, wnid in enumerate(labels_list, start=1):\n",
        "  file_name = f'ILSVRC2012_val_{index:08}.JPEG'\n",
        "  record = {\n",
        "      'file_name': tf.constant(file_name),\n",
        "      'label': wnid_to_cid[wnid],\n",
        "  }\n",
        "  imagenet_val[file_name] = record\n",
        "\n",
        "# Load vit3b and greedysoups logits into dictionaries keyed by file name.\n",
        "vit3b = load_npy_as_dict('vit3b.npz')\n",
        "greedysoups = load_npy_as_dict('greedysoups.npz')"
      ],
      "metadata": {
        "id": "MBojairUCiil"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluate on ImageNet-M\n",
        "# Note: imagenet_m_correct returns the raw count of correct predictions, not a\n",
        "# fraction.\n",
        "\n",
        "# We expect ViT3B to get 0 on ImageNet-M by construction.\n",
        "print('ImageNet-M ViT3B: ', imagenet_m_correct(imagenet_m_2022.values(), vit3b))\n",
        "\n",
        "# Greedy Soups gets 19 correct.\n",
        "print('ImageNet-M greedysoups: ', imagenet_m_correct(imagenet_m_2022.values(), greedysoups))"
      ],
      "metadata": {
        "id": "0Z-5A5WQCyWg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluate ImageNet Multi-Label Accuracy\n",
        "# Problematic images are skipped, a prediction counts if it matches a correct\n",
        "# or unclear label, and accuracy is averaged per class.\n",
        "print('ImageNet-MLA ViT3B: ', imagenet_multilabel_accuracy(imagenet_multilabel_v3.values(), vit3b))\n",
        "print('ImageNet-MLA greedysoups: ', imagenet_multilabel_accuracy(imagenet_multilabel_v3.values(), greedysoups))"
      ],
      "metadata": {
        "id": "LNUu3htFGzPn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluate ImageNet Top-1 Accuracy\n",
        "# Uses the original validation label of every image in the label list; each\n",
        "# call iterates the full set under a tqdm progress bar.\n",
        "print('ImageNet-Top-1 ViT3B: ', imagenet_val_accuracy(imagenet_val.values(), vit3b))\n",
        "print('ImageNet-Top-1 greedysoups: ', imagenet_val_accuracy(imagenet_val.values(), greedysoups))"
      ],
      "metadata": {
        "id": "giBjlajnIqte"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "vyPNoIuvJlER"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}