{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zGiRyw_lYkHa"
      },
      "source": [
        "# Zero-shot Prompt Ensembling for Text-Image Models Results\n",
        "\n",
        "*Licensed under the Apache License, Version 2.0.*\n",
        "\n",
        "\u003ca href=\"https://githubtocolab.com/google/uncertainty-baselines/blob/main/experimental/multimodal/Zero_shot_prompt_ensembling_for_text_image_models_results.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n",
        "\n",
        "This notebook produces all of the CLIP results for the ICML 2023 paper [\"A Simple Zero-shot Prompt Weighting Technique to Improve Prompt Ensembling in Text-Image Models\"](https://arxiv.org/abs/2302.06235).\n",
        "\n",
        "If you find this notebook or our implementations in Uncertainty Baselines useful, please cite:\n",
        "```\n",
        "@InProceedings{allingham2023simple,\n",
        "  title = \t {A Simple Zero-shot Prompt Weighting Technique to Improve Prompt Ensembling in Text-Image Models},\n",
        "  author =       {Allingham, James Urquhart and Ren, Jie and Dusenberry, Michael W and Gu, Xiuye and Cui, Yin and Tran, Dustin and Liu, Jeremiah Zhe and Lakshminarayanan, Balaji},\n",
        "  booktitle = \t {Proceedings of the 40th International Conference on Machine Learning},\n",
        "  pages = \t {547--568},\n",
        "  year = \t {2023},\n",
        "  editor = \t {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},\n",
        "  volume = \t {202},\n",
        "  series = \t {Proceedings of Machine Learning Research},\n",
        "  month = \t {23--29 Jul},\n",
        "  publisher =    {PMLR},\n",
        "}\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O0-J7yjSr9h3"
      },
      "source": [
        "Note that the code uses a lot of memory. If the kernel crashes, try either running on a machine with more memory or manually freeing memory with `del` statements."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZOfLAqOeFpSO"
      },
      "outputs": [],
      "source": [
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9iApUpbkd7aF"
      },
      "outputs": [],
      "source": [
        "#@title Imports\n",
        "\n",
        "import jax\n",
        "print(jax.local_devices())\n",
        "\n",
        "import tensorflow as tf\n",
        "tf.config.experimental.set_visible_devices([], \"TPU\")\n",
        "print(tf.config.get_visible_devices())\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "import ml_collections\n",
        "from importlib import reload\n",
        "\n",
        "import functools\n",
        "import itertools\n",
        "from typing import Sequence\n",
        "import multiprocessing\n",
        "import os\n",
        "from tqdm import tqdm\n",
        "import pickle\n",
        "from scipy import stats\n",
        "import pandas as pd\n",
        "\n",
        "import jax.numpy as jnp\n",
        "from jax import random\n",
        "\n",
        "from absl import app\n",
        "from absl import flags\n",
        "from absl import logging\n",
        "from clu import metric_writers\n",
        "from clu import parameter_overview\n",
        "from clu import periodic_actions\n",
        "from clu import preprocess_spec\n",
        "import flax\n",
        "import flax.linen as nn\n",
        "from flax.training import train_state\n",
        "import optax\n",
        "import ml_collections\n",
        "import ml_collections.config_flags\n",
        "import numpy as np\n",
        "\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib.lines import Line2D\n",
        "# Reference widths copied from the ICML LaTeX template (presumably inches, for sizing figures - confirm).\n",
        "text_width = 6.75133  # From the ICML LaTeX template.\n",
        "line_width = 3.25063  # From the ICML LaTeX template.\n",
        "# Global matplotlib defaults; they apply to every figure created after this cell runs.\n",
        "matplotlib.rc('font', size=7)  # Controls default text sizes.\n",
        "matplotlib.rc('axes', titlesize=7)\n",
        "matplotlib.rc('axes', labelsize=7)\n",
        "matplotlib.rc('xtick', labelsize=6)\n",
        "matplotlib.rc('ytick', labelsize=6)\n",
        "matplotlib.rc('legend', fontsize=6)\n",
        "matplotlib.rc('figure', titlesize=8)\n",
        "# Second 'font' call only sets the family; the size set above is left untouched.\n",
        "matplotlib.rc('font', **{'family':'serif', 'serif': ['Palatino']})\n",
        "# NOTE(review): usetex=True makes matplotlib typeset text through an external LaTeX install - confirm it is available.\n",
        "matplotlib.rc('text', usetex=True)\n",
        "\n",
        "\n",
        "from uncertainty_baselines.models import clip\n",
        "import robustness_metrics as rm\n",
        "import uncertainty_baselines as ub\n",
        "# TODO(jallingham): Fork remaining utils once imports below merged into UB API.\n",
        "# import train_utils  # local file import from baselines.jft\n",
        "# NOTE: Usually we do not allow cross-imports between subdirectories. We are\n",
        "# doing so here because this is an experimental directory and the offending\n",
        "# utils are soon to have much of their functionality merged into the UB API.\n",
        "from experimental.multimodal import input_utils\n",
        "from experimental.multimodal import checkpoint_utils\n",
        "from experimental.multimodal import multimodal_utils\n",
        "from experimental.multimodal import preprocess_utils\n",
        "from experimental.multimodal import simple_tokenizer\n",
        "from experimental.multimodal.configs import clip_common\n",
        "\n",
        "preprocess_utils = reload(preprocess_utils)\n",
        "multimodal_utils = reload(multimodal_utils)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T08CMdD_ee7l"
      },
      "outputs": [],
      "source": [
        "#@title Define config\n",
        "MAX_TEXT_LENGTH = 77\n",
        "def get_config(model_name='vit_b16'):\n",
        "  \"\"\"Builds the experiment config for one CLIP variant.\n",
        "\n",
        "  Args:\n",
        "    model_name: Key used to look up the image resolution, the checkpoint path\n",
        "      and the architecture settings ('vit_b16' or 'vit_b32').\n",
        "\n",
        "  Returns:\n",
        "    A `(config, image_resolution)` tuple: the `ml_collections.ConfigDict` and\n",
        "    the value of `clip_common.IMAGE_RESOLUTION[model_name]`.\n",
        "  \"\"\"\n",
        "  input_res = clip_common.IMAGE_RESOLUTION[model_name]\n",
        "\n",
        "  # Pixel scaling and normalisation ops shared by the train and zero-shot specs.\n",
        "  normalize_ops = [\n",
        "      'value_range(0, 1)',\n",
        "      f'normalize({clip_common.CLIP_IMAGE_MEAN}, {clip_common.CLIP_IMAGE_STD})',\n",
        "  ]\n",
        "\n",
        "  def zeroshot_pp(n_classes, resize_method='bicubic'):\n",
        "    \"\"\"Returns the zero-shot preprocessing spec for `n_classes` one-hot labels.\"\"\"\n",
        "    ops = [\n",
        "        'decode',\n",
        "        f'resize_small({input_res}, method=\"{resize_method}\")',\n",
        "        f'central_crop({input_res})',\n",
        "        *normalize_ops,\n",
        "        f'onehot({n_classes}, key=\"label\", key_result=\"label\")',\n",
        "        'keep([\"image\", \"label\"])',\n",
        "    ]\n",
        "    return '|'.join(ops)\n",
        "\n",
        "  config = ml_collections.ConfigDict()\n",
        "\n",
        "  # Data section.\n",
        "  config.data_dir = '/mnt/disks/persist/data/'\n",
        "  config.model_name = model_name\n",
        "  config.dataset = 'laion400m'\n",
        "  config.train_split = 'all[10000:210000]'\n",
        "  config.batch_size = 5000\n",
        "  config.tokenizer_max_len = MAX_TEXT_LENGTH\n",
        "  config.pp_train = '|'.join([\n",
        "      'decode',\n",
        "      f'inception_crop({input_res})',\n",
        "      *normalize_ops,\n",
        "      f'clip_tokenize({config.tokenizer_max_len}, key=\"caption\", key_result=\"text\", bpe_path=\"uncertainty-baselines/experimental/multimodal/bpe_simple_vocab_16e6.txt.gz\")',\n",
        "      'keep([\"image\", \"text\"])',\n",
        "  ])\n",
        "  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.\n",
        "  config.prefetch_to_device = 2\n",
        "  config.seed = 0\n",
        "\n",
        "  # Model section.\n",
        "  checkpoint_paths = {\n",
        "      'vit_b16': 'ADD_PATH_HERE/clip_vit-b16.npy',\n",
        "      'vit_b32': 'ADD_PATH_HERE/clip_vit-b32.npy',\n",
        "  }\n",
        "  config.model_init = checkpoint_paths[model_name]\n",
        "  config.convert_pytorch = True\n",
        "  config.model = ml_collections.config_dict.create(\n",
        "      **clip_common.CONFIGS[model_name])\n",
        "\n",
        "  # Optimizer section.\n",
        "  config.optim_name = 'Adam'\n",
        "  config.optim = ml_collections.ConfigDict()\n",
        "  config.grad_clip_norm = 1.0\n",
        "  config.weight_decay = 1e-5\n",
        "  config.lr = ml_collections.ConfigDict()\n",
        "  config.lr.base = 1e-4\n",
        "\n",
        "  # Zero-shot section.\n",
        "  # Columns: (name, dataset, split, classnames_key, prompts_key, n_classes).\n",
        "  zeroshot_specs = [\n",
        "      ('imagenet', 'imagenet2012', 'validation', 'imagenet', 'imagenet', 1000),\n",
        "      ('imagenet_a', 'imagenet_a', 'test', 'imagenet_a', 'imagenet', 1000),\n",
        "      ('imagenet_r', 'imagenet_r', 'test', 'imagenet_r', 'imagenet', 1000),\n",
        "      ('imagenet_sketch', 'imagenet_sketch', 'test', 'imagenet', 'imagenet', 1000),\n",
        "      ('imagenet_v2', 'imagenet_v2', 'test', 'imagenet', 'imagenet', 1000),\n",
        "      ('caltech101', 'caltech101', 'test', 'caltech101', 'caltech101', 102),\n",
        "      ('cars196', 'cars196', 'test', 'cars196', 'cars196', 196),\n",
        "      ('cifar10', 'cifar10', 'test', 'cifar10', 'cifar10', 10),\n",
        "      ('cifar100', 'cifar100', 'test', 'cifar100', 'cifar100', 100),\n",
        "      ('dtd', 'dtd', 'test', 'dtd', 'dtd', 47),\n",
        "      ('eurosat', 'eurosat', 'train', 'eurosat', 'eurosat', 10),\n",
        "      ('food101', 'food101', 'validation', 'food101', 'food101', 101),\n",
        "      ('oxford_flowers102', 'oxford_flowers102', 'test', 'oxford_flowers102', 'oxford_flowers102', 102),\n",
        "      ('oxford_iiit_pet', 'oxford_iiit_pet', 'test', 'oxford_iiit_pet', 'oxford_iiit_pet', 37),\n",
        "      ('resisc45', 'resisc45', 'train', 'resisc45', 'resisc45', 45),\n",
        "      ('sun397', 'sun397', 'test', 'sun397', 'sun397', 397),\n",
        "  ]\n",
        "  config.zeroshot_eval_datasets = {\n",
        "      name: {\n",
        "          'dataset': dataset,\n",
        "          'split': split,\n",
        "          'classnames_key': classnames_key,\n",
        "          'prompts_key': prompts_key,\n",
        "          'pp_spec': zeroshot_pp(n_classes),\n",
        "      }\n",
        "      for name, dataset, split, classnames_key, prompts_key, n_classes\n",
        "      in zeroshot_specs\n",
        "  }\n",
        "\n",
        "  return config, input_res\n",
        "\n",
        "config, image_resolution = get_config('vit_b16')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qa82kkFYcR8G"
      },
      "outputs": [],
      "source": [
        "#@title Create model\n",
        "# Root PRNG key for the notebook; later cells split it for parameter init and data loading.\n",
        "seed = config.get('seed', 0)\n",
        "rng = jax.random.PRNGKey(seed)\n",
        "\n",
        "# Build the CLIP module from the architecture settings stored in `config.model`.\n",
        "clip_model = ub.models.clip(**config.model)\n",
        "@functools.partial(jax.jit, backend='cpu')\n",
        "def init(rng):\n",
        "    \"\"\"Initialises the CLIP variables on the CPU backend from all-zero dummy inputs.\n",
        "\n",
        "    Args:\n",
        "        rng: PRNG key forwarded to `clip_model.init`.\n",
        "\n",
        "    Returns:\n",
        "        A `(params, states)` tuple: the unfrozen 'params' collection and the\n",
        "        remaining variable collections.\n",
        "    \"\"\"\n",
        "    image_shape = (1, image_resolution, image_resolution, 3)\n",
        "    text_shape = (1, MAX_TEXT_LENGTH)\n",
        "    variables = clip_model.init(\n",
        "        rng,\n",
        "        jnp.zeros(image_shape, jnp.float32),\n",
        "        jnp.zeros(text_shape, jnp.int32))\n",
        "    states, params = variables.pop('params')\n",
        "    return flax.core.unfreeze(params), states\n",
        "\n",
        "# Initialise fresh parameters first; they give the checkpoint loader the target structure.\n",
        "rng, rng_init = jax.random.split(rng)\n",
        "params_cpu, states_cpu = init(rng_init)\n",
        "# Load the optimizer from flax. We need to create an optimizer because our\n",
        "# checkpoint loader assumes that the optimizer is storing the params.\n",
        "# NOTE(review): `flax.optim` was deprecated in favour of optax and is absent from\n",
        "# recent Flax releases, so this cell presumably needs an older pinned Flax - confirm.\n",
        "opt_name = config.get('optim_name')\n",
        "opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))\n",
        "# We jit this, such that the arrays that are created are created on the same\n",
        "# device as the input is, in this case the CPU. Else they'd be on device[0].\n",
        "opt_cpu = jax.jit(opt_def.create)(params_cpu)\n",
        "\n",
        "# Load the checkpoint.\n",
        "checkpoint_data = checkpoint_utils.maybe_load_checkpoint(\n",
        "    train_loop_rngs=rng,\n",
        "    save_checkpoint_path=None,\n",
        "    init_optimizer=opt_cpu,\n",
        "    init_params=params_cpu,\n",
        "    init_fixed_model_states=states_cpu,\n",
        "    default_reinit_params=[],\n",
        "    config=config)\n",
        "# The loaded parameters are read back from the optimizer's target.\n",
        "loaded_params = checkpoint_data.optimizer.target\n",
        "loaded_states = checkpoint_data.fixed_model_states\n",
        "# Sanity check to make sure we loaded params:\n",
        "print('params_cpu[logit_scale]=%s, loaded_params[logit_scale]=%s' % (params_cpu['logit_scale'], loaded_params['logit_scale']))\n",
        "\n",
        "# Drop the freshly initialised copies now that the loaded values are held separately.\n",
        "del opt_cpu\n",
        "del params_cpu\n",
        "del states_cpu\n",
        "\n",
        "# Variable dict passed to `clip_model.apply` by the encode_* helpers.\n",
        "clip_vars = {'params': flax.core.freeze(loaded_params), **loaded_states}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HMeOdMjj3Wf-"
      },
      "outputs": [],
      "source": [
        "#@title Create tokenizer\n",
        "# BPE vocabulary file handed to the tokenizer.\n",
        "bpe_path = 'uncertainty-baselines/experimental/multimodal/bpe_simple_vocab_16e6.txt.gz'\n",
        "\n",
        "tokenizer = simple_tokenizer.SimpleTokenizer(bpe_path=bpe_path)\n",
        "# Tokenize function bound to the configured maximum text length.\n",
        "tokenize_fn = simple_tokenizer.make_tokenize_fn(\n",
        "    tokenizer, config.tokenizer_max_len)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "slo5g08MpCYv"
      },
      "source": [
        "## Create helper functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "779boNm0aezH"
      },
      "outputs": [],
      "source": [
        "#@title encode_texts \u0026 encode_images\n",
        "\n",
        "# For batches of texts/images.\n",
        "@jax.jit\n",
        "def encode_texts(texts):\n",
        "    \"\"\"Encodes a batch of tokenized texts with the CLIP text encoder.\n",
        "\n",
        "    Normalisation and logit scaling are both switched off here.\n",
        "    \"\"\"\n",
        "    text_options = dict(normalize=False, scale_logits=False)\n",
        "    return clip_model.apply(\n",
        "        clip_vars, texts, method=clip_model.encode_text, **text_options)\n",
        "\n",
        "@jax.jit\n",
        "def encode_images(images):\n",
        "    \"\"\"Encodes a batch of images with the CLIP image encoder.\"\"\"\n",
        "    image_embeddings = clip_model.apply(\n",
        "        clip_vars, images, method=clip_model.encode_image)\n",
        "    return image_embeddings\n",
        "\n",
        "\n",
        "# For a single text/image.\n",
        "def encode_text(text):\n",
        "    \"\"\"Encodes one tokenized text by adding a leading batch axis of size 1.\"\"\"\n",
        "    batched_text = jnp.expand_dims(text, axis=0)\n",
        "    return encode_texts(batched_text)\n",
        "\n",
        "def encode_image(image):\n",
        "    \"\"\"Encodes one image by adding a leading batch axis of size 1.\"\"\"\n",
        "    batched_image = jnp.expand_dims(image, axis=0)\n",
        "    return encode_images(batched_image)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "2FVS7s4dpHQM"
      },
      "outputs": [],
      "source": [
        "#@title load_xxx_dataset\n",
        "\n",
        "def _get_split(dataset, split, pp, rng, data_dir, batch_size=None, drop_remainder=False):\n",
        "    \"\"\"Builds a single-epoch, unshuffled input pipeline for one dataset split.\n",
        "\n",
        "    Args:\n",
        "        dataset: Dataset name forwarded to `input_utils.get_data`.\n",
        "        split: Split string forwarded to `input_utils.get_data`.\n",
        "        pp: Preprocessing function, or a spec string that is parsed with\n",
        "            `preprocess_utils.all_ops()`.\n",
        "        rng: PRNG key; the process index is folded in before use.\n",
        "        data_dir: Data directory forwarded to `input_utils.get_data`.\n",
        "        batch_size: Per-process batch size. When None, a notebook-level\n",
        "            `BATCH_SIZE` is used if one is defined, else `config.batch_size`.\n",
        "        drop_remainder: Forwarded to `input_utils.get_data`.\n",
        "\n",
        "    Returns:\n",
        "        The dataset returned by `input_utils.get_data`.\n",
        "    \"\"\"\n",
        "    if isinstance(pp, str):\n",
        "        pp = preprocess_spec.parse(spec=pp, available_ops=preprocess_utils.all_ops())\n",
        "\n",
        "    if batch_size is None:\n",
        "        # No cell above defines `BATCH_SIZE`, so a bare reference would raise\n",
        "        # NameError on a fresh kernel. Honour it when present, else use the config.\n",
        "        batch_size = globals().get('BATCH_SIZE', config.batch_size)\n",
        "\n",
        "    rng = jax.random.fold_in(rng, jax.process_index())\n",
        "\n",
        "    val_ds = input_utils.get_data(\n",
        "        dataset=dataset,\n",
        "        split=split,\n",
        "        rng=rng,\n",
        "        process_batch_size=batch_size,\n",
        "        preprocess_fn=pp,\n",
        "        cache=config.get('val_cache', 'batched'),\n",
        "        num_epochs=1,\n",
        "        repeat_after_batching=True,\n",
        "        shuffle=False,\n",
        "        prefetch_size=config.get('prefetch_to_host', 2),\n",
        "        drop_remainder=drop_remainder,\n",
        "        data_dir=data_dir)\n",
        "\n",
        "    return val_ds\n",
        "\n",
        "def load_zeroshot_dataset(config, rng, dataset_name, zs_batch_size=5000):\n",
        "    \"\"\"Loads one zero-shot evaluation dataset described in `config`.\n",
        "\n",
        "    Args:\n",
        "        config: Config holding `zeroshot_eval_datasets` and `data_dir`.\n",
        "        rng: PRNG key; a sub-key split from it is used for the pipeline.\n",
        "        dataset_name: Key into `config.zeroshot_eval_datasets`.\n",
        "        zs_batch_size: Batch size passed to `_get_split`.\n",
        "\n",
        "    Returns:\n",
        "        The dataset produced by `_get_split`.\n",
        "    \"\"\"\n",
        "    ds_cfg = config.zeroshot_eval_datasets[dataset_name]\n",
        "    _, zeroshot_ds_rng = jax.random.split(rng)\n",
        "    parsed_pp = preprocess_spec.parse(\n",
        "        spec=ds_cfg['pp_spec'], available_ops=preprocess_utils.all_ops())\n",
        "\n",
        "    return _get_split(\n",
        "        ds_cfg['dataset'],\n",
        "        split=ds_cfg['split'],\n",
        "        pp=parsed_pp,\n",
        "        rng=zeroshot_ds_rng,\n",
        "        data_dir=config.get('data_dir'),\n",
        "        batch_size=zs_batch_size,\n",
        "    )\n",
        "\n",
        "def load_train_dataset(config, train_ds_rng):\n",
        "    \"\"\"Loads the split named by `config.train_split` from `config.dataset`.\n",
        "\n",
        "    Args:\n",
        "        config: Config providing `dataset`, `train_split`, `pp_train`,\n",
        "            `batch_size` and `data_dir`.\n",
        "        train_ds_rng: PRNG key passed to `_get_split`.\n",
        "\n",
        "    Returns:\n",
        "        The dataset produced by `_get_split`.\n",
        "    \"\"\"\n",
        "    parsed_pp = preprocess_spec.parse(\n",
        "        spec=config.pp_train, available_ops=preprocess_utils.all_ops())\n",
        "    return _get_split(\n",
        "        config.dataset,\n",
        "        split=config.train_split,\n",
        "        pp=parsed_pp,\n",
        "        rng=train_ds_rng,\n",
        "        data_dir=config.get('data_dir'),\n",
        "        batch_size=config.batch_size,\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "55vfGNDyGTPt"
      },
      "outputs": [],
      "source": [
        "#@title compute_text_embeddings\n",
        "def compute_text_embeddings(templates, tokenize_fn, logit_scale, dataset, use_l2_norm=True, classnames=None):\n",
        "    \"\"\"Embeds every (prompt template, class name) pair with the text encoder.\n",
        "\n",
        "    Args:\n",
        "        templates: Prompt templates, each filled in via `template.format(classname)`.\n",
        "        tokenize_fn: Callable applied to a `tf.string` constant of the filled prompt.\n",
        "        logit_scale: Embeddings are multiplied by `sqrt(exp(logit_scale))`.\n",
        "        dataset: Key into `config.zeroshot_eval_datasets`; only read when\n",
        "            `classnames` is None.\n",
        "        use_l2_norm: Whether to L2-normalise each embedding before scaling.\n",
        "        classnames: Optional explicit class names, overriding the dataset lookup.\n",
        "\n",
        "    Returns:\n",
        "        Array of shape [n_prompts, n_classes, emb_dim].\n",
        "    \"\"\"\n",
        "    if classnames is None:\n",
        "        classnames_key = config.zeroshot_eval_datasets[dataset]['classnames_key']\n",
        "        classnames = multimodal_utils._ZEROSHOT_CLASS_NAMES[classnames_key]\n",
        "\n",
        "    def tokenize(text):\n",
        "        return tokenize_fn(tf.constant(text, dtype=tf.string))\n",
        "\n",
        "    per_class = []\n",
        "    for clsname in tqdm(classnames):\n",
        "        prompts = [template.format(clsname) for template in templates]\n",
        "        tokens = jnp.array([tokenize(prompt) for prompt in prompts])\n",
        "        emb = encode_texts(tokens)  # [n_prompts, emb_dim]\n",
        "        if use_l2_norm:\n",
        "            emb = emb * jax.lax.rsqrt(jnp.sum(emb**2, axis=-1, keepdims=True))\n",
        "        emb = emb * jnp.sqrt(jnp.exp(logit_scale))\n",
        "        per_class.append(emb)\n",
        "\n",
        "    return np.stack(per_class, axis=1)  # [n_prompts, n_classes, emb_dim]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "1old4iYbL-hZ"
      },
      "outputs": [],
      "source": [
        "#@title compute_image_embeddings\n",
        "def compute_image_embeddings(ds_iter, image_resolution):\n",
        "    \"\"\"Embeds every unmasked image yielded by `ds_iter`.\n",
        "\n",
        "    Args:\n",
        "        ds_iter: Iterable of batches with 'image', 'label' and 'mask' entries.\n",
        "            Labels are reduced with argmax over their last axis.\n",
        "        image_resolution: Unused; kept so existing callers keep working.\n",
        "\n",
        "    Returns:\n",
        "        A `(embeddings, labels)` tuple of stacked NumPy arrays holding only the\n",
        "        entries whose mask is non-zero.\n",
        "    \"\"\"\n",
        "    # Build the pmapped encoder once instead of re-wrapping it for every batch.\n",
        "    encode_fn = jax.pmap(encode_images)\n",
        "\n",
        "    zimgs = []\n",
        "    labels = []\n",
        "    for batch in tqdm(ds_iter):\n",
        "        image_embedding = encode_fn(batch['image'])\n",
        "        # Flatten all leading axes so rows line up with the flattened mask.\n",
        "        image_embedding = image_embedding.reshape(-1, image_embedding.shape[-1])\n",
        "        labels_ = batch['label'].reshape(-1, batch['label'].shape[-1]).argmax(-1)\n",
        "\n",
        "        mask = batch['mask'].reshape(-1).astype(np.int32)\n",
        "        mask = np.where(mask)\n",
        "        zimgs.append(image_embedding[mask])\n",
        "        labels.append(labels_[mask])\n",
        "\n",
        "    return np.vstack(zimgs), np.hstack(labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "hvG8dT3i6is9"
      },
      "outputs": [],
      "source": [
        "#@title compute_metrics\n",
        "\n",
        "def compute_accuracy(logits, labels):\n",
        "    \"\"\"Computes top-1 and top-5 accuracy in percent.\n",
        "\n",
        "    Args:\n",
        "        logits: Scores of shape [n_examples, n_classes].\n",
        "        labels: Integer class indices of shape [n_examples].\n",
        "\n",
        "    Returns:\n",
        "        A `(top1, top5)` tuple. With fewer than five classes, `top5` counts a\n",
        "        hit among all available classes.\n",
        "    \"\"\"\n",
        "    # `top_k` rejects k larger than the class axis, so clamp k for small label sets.\n",
        "    k = min(5, logits.shape[-1])\n",
        "    _, top_labels = jax.lax.top_k(logits, k)\n",
        "    top1 = 100 * jnp.mean(top_labels[:, 0] == labels)\n",
        "    top5 = 100 * jnp.sum(top_labels == labels[:, None]) / labels.shape[0]\n",
        "    return top1, top5\n",
        "\n",
        "def compute_metrics(labels, logits, print_out=True):\n",
        "    \"\"\"Computes top-1 and top-5 accuracy and optionally prints the top-1 value.\n",
        "\n",
        "    Args:\n",
        "        labels: Integer class indices, passed to `compute_accuracy`.\n",
        "        logits: Per-class scores, passed to `compute_accuracy`.\n",
        "        print_out: Whether to print the top-1 accuracy.\n",
        "\n",
        "    Returns:\n",
        "        The `(top1, top5)` tuple from `compute_accuracy`.\n",
        "    \"\"\"\n",
        "    # The softmax, argmax and confidence arrays computed here previously were\n",
        "    # never used, so they are no longer materialised.\n",
        "    acc, acc5 = compute_accuracy(logits, labels)\n",
        "\n",
        "    if print_out:\n",
        "        print(f'top1_acc: {acc:5.2f}')\n",
        "\n",
        "    return acc, acc5"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KCo812aHpfiD"
      },
      "outputs": [],
      "source": [
        "#@title get_logits\n",
        "def get_logits(\n",
        "    ztxts,  # [n_prompts, n_classes, emb_dim]\n",
        "    zimgs   # [n_imgs, emb_dim]\n",
        "):\n",
        "    \"\"\"Computes image-text dot products for every prompt.\n",
        "\n",
        "    Returns:\n",
        "        Logits of shape [n_imgs, n_prompts, n_classes].\n",
        "    \"\"\"\n",
        "    @functools.partial(jax.jit, backend='cpu')\n",
        "    def _single_prompt_logits(img_emb, txt_emb):\n",
        "        return jnp.dot(img_emb, txt_emb.T)\n",
        "\n",
        "    # Map over the prompt axis of `ztxts` only and place it second in the output.\n",
        "    per_prompt = jax.vmap(_single_prompt_logits, in_axes=(None, 0), out_axes=1)\n",
        "    return per_prompt(zimgs, ztxts)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "2HKbJ1zF6avN"
      },
      "outputs": [],
      "source": [
        "#@title agg_logits\n",
        "@functools.partial(jax.jit, backend='cpu')\n",
        "def agg_logits(\n",
        "    logits, # [n_imgs, n_prompts, n_classes]\n",
        "    weights=None\n",
        "):\n",
        "    \"\"\"Averages (optionally weighted) logits over the prompt axis.\n",
        "\n",
        "    The weighted products are divided by the number of prompts, not by the sum\n",
        "    of the weights.\n",
        "\n",
        "    Returns:\n",
        "        Array of shape [n_imgs, n_classes].\n",
        "    \"\"\"\n",
        "    _n_imgs, n_prompts, _n_classes = logits.shape\n",
        "\n",
        "    if weights is None:\n",
        "        # Uniform weights, shaped to broadcast against `logits`.\n",
        "        weights = jnp.ones((1, n_prompts, 1))\n",
        "\n",
        "    return jnp.mean(logits * weights, axis=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "k2bvQv10-LY0"
      },
      "outputs": [],
      "source": [
        "#@title get_weights\n",
        "@functools.partial(jax.jit, static_argnums=(2, 3, 4), backend='cpu')\n",
        "def get_weights(\n",
        "    logits,  # [n_img, n_prompt, n_cls]\n",
        "    random_logits=None,  # [n_rand, n_prompt, n_cls]\n",
        "    debias_mode='both',\n",
        "    img_mean=True,\n",
        "    frac_test=1.,\n",
        "):\n",
        "    \"\"\"Scores each prompt by its (optionally debiased) maximum logit.\n",
        "\n",
        "    Args:\n",
        "        logits: Logits of shape [n_img, n_prompt, n_cls].\n",
        "        random_logits: Logits of shape [n_rand, n_prompt, n_cls]; required for\n",
        "            the 'both', 'pretrain' and 'pretrain_star' modes.\n",
        "        debias_mode: What is subtracted from `logits` before taking the max.\n",
        "            'test': mean over the first `round(n_img * frac_test)` rows of `logits`.\n",
        "            'pretrain': mean of `random_logits` over its first axis.\n",
        "            'pretrain_star': mean of `random_logits` over its first and last axes.\n",
        "            'both': average of the 'pretrain' and 'test' terms.\n",
        "            'none': nothing.\n",
        "        img_mean: If True, the scores are averaged over the image axis.\n",
        "        frac_test: Fraction of the rows of `logits` used for the 'test' term.\n",
        "\n",
        "    Returns:\n",
        "        Scores of shape [1, n_prompt, 1] when `img_mean` is True, otherwise\n",
        "        [n_img, n_prompt, 1].\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: If `debias_mode` is not one of the modes listed above.\n",
        "    \"\"\"\n",
        "    needs_pretrain = debias_mode in ('both', 'pretrain', 'pretrain_star')\n",
        "    needs_test = debias_mode in ('both', 'test')\n",
        "    if not (needs_pretrain or needs_test or debias_mode == 'none'):\n",
        "        raise RuntimeError(f'Unknown \"debias_mode\" type {debias_mode}')\n",
        "\n",
        "    if needs_pretrain:\n",
        "        assert random_logits is not None\n",
        "        reduce_axes = (0, 2) if debias_mode == 'pretrain_star' else (0,)\n",
        "        pretrain_bias = random_logits.mean(reduce_axes, keepdims=True)  # [1, n_prompt, n_cls or 1]\n",
        "\n",
        "    if needs_test:\n",
        "        n_test = round(logits.shape[0] * frac_test)\n",
        "        test_bias = logits[:n_test].mean(0, keepdims=True)  # [1, n_prompt, n_cls]\n",
        "\n",
        "    if needs_pretrain and needs_test:\n",
        "        debiased = logits - 0.5*(pretrain_bias + test_bias)\n",
        "    elif needs_test:\n",
        "        debiased = logits - test_bias\n",
        "    elif needs_pretrain:\n",
        "        debiased = logits - pretrain_bias\n",
        "    else:\n",
        "        debiased = logits\n",
        "\n",
        "    scores = debiased.max(-1, keepdims=True)  # [n_img, n_prompt, 1]\n",
        "    return scores.mean(0, keepdims=True) if img_mean else scores"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "-RmNr7RJjn0e"
      },
      "outputs": [],
      "source": [
        "#@title mad_method\n",
        "def mad_method(data, threshold=3):\n",
        "    \"\"\"Median Absolute Deviation outlier detection (upper tail only).\n",
        "\n",
        "    Args:\n",
        "        data: Sequence of values.\n",
        "        threshold: An entry is flagged when its distance above the median, in\n",
        "            units of the MAD, exceeds this value. The string 'NA' disables the\n",
        "            filter so that every index is returned.\n",
        "\n",
        "    Returns:\n",
        "        Array of the flagged indices (int32), or `np.arange(len(data))` when\n",
        "        `threshold` is 'NA'.\n",
        "    \"\"\"\n",
        "    if threshold == 'NA':\n",
        "        return np.arange(len(data))\n",
        "\n",
        "    centre = np.median(data)\n",
        "    spread = np.abs(stats.median_abs_deviation(data))\n",
        "    flagged = [\n",
        "        idx for idx, value in enumerate(data)\n",
        "        if (value - centre)/spread \u003e threshold\n",
        "    ]\n",
        "    return np.array(flagged, dtype=np.int32)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8T7BZLXeHiue"
      },
      "source": [
        "## Collect tabular results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "evmA61nhTrgE"
      },
      "outputs": [],
      "source": [
        "# Embed the LAION training subset selected by `config.train_split`.\n",
        "rng, train_ds_rng = jax.random.split(rng, 2)\n",
        "train_ds = load_train_dataset(config, train_ds_rng)\n",
        "train_iter = input_utils.start_input_pipeline(\n",
        "      train_ds, config.get('prefetch_to_device', 1))\n",
        "\n",
        "# Build the pmapped encoder once instead of re-wrapping it for every batch.\n",
        "encode_images_pmapped = jax.pmap(encode_images)\n",
        "\n",
        "zimgs_laion = []\n",
        "\n",
        "for batch in tqdm(train_iter):\n",
        "  zimg = encode_images_pmapped(batch['image'])\n",
        "  # Merge the two leading axes of the pmapped output into one.\n",
        "  zimg = np.array(zimg).reshape(-1, *zimg.shape[2:])\n",
        "  zimgs_laion.append(zimg)\n",
        "\n",
        "zimgs_laion = np.concatenate(zimgs_laion)\n",
        "\n",
        "zimgs_laion.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "q5BwR5C5lZ5R"
      },
      "outputs": [],
      "source": [
        "#@title Extra templates generated by GPT\n",
        "EXTRA_TEMPLATES = {\n",
        "    'A photo of a {}, a type of insect.',\n",
        "    'A photo of a {}, a type of fish.',\n",
        "    'A photo of a {}, a type of tree.',\n",
        "    'A photo of a {}, a type of fruit.',\n",
        "    'A photo of a {}, a type of car.',\n",
        "    'A photo of a {}, a type of dog.',\n",
        "    'A photo of a {}, a type of mammal.',\n",
        "    'A photo of a {}, a type of reptile.',\n",
        "    'A photo of a {}, a type of food.',\n",
        "    'A photo of a {}, a type of vegetable.',\n",
        "    'A photo of a {}, a type of landscape.',\n",
        "    'A photo of a {}, a type of cityscape.',\n",
        "    'A photo of a {}, a type of seascape.',\n",
        "    'A photo of a {}, a type of architecture.',\n",
        "    'A photo of a {}, a type of monument.',\n",
        "    'A photo of a {}, a type of painting.',\n",
        "    'A photo of a {}, a type of sculpture.',\n",
        "    'A photo of a {}, a type of musical instrument.',\n",
        "    'A photo of a {}, a type of weapon.',\n",
        "    'A photo of a {}, a type of clothing.',\n",
        "    'A photo of a {}, a type of jewelry.',\n",
        "    'A photo of a {}, a type of household item.',\n",
        "    'A photo of a {}, a type of electronic device.',\n",
        "    'A photo of a {}, a type of tool.',\n",
        "    'A photo of a {}, a type of transportation.',\n",
        "    'A photo of a {}, a type of recreational activity.',\n",
        "    'A photo of a {}, a type of game.',\n",
        "    'A photo of a {}, a type of sport.',\n",
        "    'A photo of a {}, a type of musical genre.',\n",
        "    'A photo of a {}, a type of movie genre.',\n",
        "    'A photo of a {}, a type of book genre.',\n",
        "    'A photo of a {}, a type of historical event.',\n",
        "    'A photo of a {}, a type of mythological creature.',\n",
        "    'A photo of a {}, a type of fantasy creature.',\n",
        "    'A photo of a {}, a type of planet.',\n",
        "    'A photo of a {}, a type of constellation.',\n",
        "    'A photo of a {}, a type of comet.',\n",
        "    'A photo of a {}, a type of galaxy.',\n",
        "    'A photo of a {}, a type of meteor.',\n",
        "    'A photo of a {}, a type of asteroid.',\n",
        "    'A photo of a {}, a type of planet.',\n",
        "    'A photo of a {}, a type of star.',\n",
        "    'A photo of a {}, a type of black hole.',\n",
        "    'A photo of a {}, a type of neutron star.',\n",
        "    'A photo of a {}, a type of quasar.',\n",
        "    'A photo of a {}, a type of pulsar.',\n",
        "    'A photo of a {}, a type of supernova.',\n",
        "    'A photo of a {}, a type of brown dwarf.',\n",
        "    'A photo of a {}, a type of white dwarf.',\n",
        "    'A photo of a {}, a type of red giant.',\n",
        "    'A photo of a {}, a type of butterfly.',\n",
        "    'A photo of a {}, a type of amphibian.',\n",
        "    'A photo of a {}, a type of berry.',\n",
        "    'A photo of a {}, a type of motorcycle.',\n",
        "    'A photo of a {}, a type of cat.',\n",
        "    'A photo of a {}, a type of rodent.',\n",
        "    'A photo of a {}, a type of fish.',\n",
        "    'A photo of a {}, a type of dinosaur.',\n",
        "    'A photo of a {}, a type of pasta.',\n",
        "    'A photo of a {}, a type of grain.',\n",
        "    'A photo of a {}, a type of mountain range.',\n",
        "    'A photo of a {}, a type of waterfall.',\n",
        "    'A photo of a {}, a type of lake.',\n",
        "    'A photo of a {}, a type of bridge.',\n",
        "    'A photo of a {}, a type of lighthouse.',\n",
        "    'A photo of a {}, a type of pottery.',\n",
        "    'A photo of a {}, a type of tapestry.',\n",
        "    'A photo of a {}, a type of drum.',\n",
        "    'A photo of a {}, a type of sword.',\n",
        "    'A photo of a {}, a type of hat.',\n",
        "    'A photo of a {}, a type of watch.',\n",
        "    'A photo of a {}, a type of kitchen appliance.',\n",
        "    'A photo of a {}, a type of camera.',\n",
        "    'A photo of a {}, a type of power tool.',\n",
        "    'A photo of a {}, a type of boat.',\n",
        "    'A photo of a {}, a type of adventure sport.',\n",
        "    'A photo of a {}, a type of board game.',\n",
        "    'A photo of a {}, a type of ball sport.',\n",
        "    'A photo of a {}, a type of folk music.',\n",
        "    'A photo of a {}, a type of action movie.',\n",
        "    'A photo of a {}, a type of mystery novel.',\n",
        "    'A photo of a {}, a type of war.',\n",
        "    'A photo of a {}, a type of mythical king.',\n",
        "    'A photo of a {}, a type of fantasy race.',\n",
        "    'A photo of a {}, a type of planet.',\n",
        "    'A photo of a {}, a type of constellation.',\n",
        "    'A photo of a {}, a type of comet.',\n",
        "    'A photo of a {}, a type of galaxy.',\n",
        "    'A photo of a {}, a type of meteor.',\n",
        "    'A photo of a {}, a type of asteroid.',\n",
        "    'A photo of a {}, a type of planet.',\n",
        "    'A photo of a {}, a type of star.',\n",
        "    'A photo of a {}, a type of black hole.',\n",
        "    'A photo of a {}, a type of neutron star.',\n",
        "    'A photo of a {}, a type of quasar.',\n",
        "    'A photo of a {}, a type of pulsar.',\n",
        "    'A photo of a {}, a type of supernova.',\n",
        "    'A photo of a {}, a type of brown dwarf.',\n",
        "    'A photo of a {}, a type of white dwarf.',\n",
        "    'A photo of a {}, a type of red giant.',\n",
        "    'A panoramic photo of a {}.',\n",
        "    'A close-up photo of a {}.',\n",
        "    'A wide-angle photo of a {}.',\n",
        "    'A high-resolution photo of a {}.',\n",
        "    'A low-light photo of a {}.',\n",
        "    'A time-lapse photo of a {}.',\n",
        "    'A long-exposure photo of a {}.',\n",
        "    'A night photo of a {}.',\n",
        "    'A sunset photo of a {}.',\n",
        "    'A sunrise photo of a {}.',\n",
        "    'A silhouette photo of a {}.',\n",
        "    'A sepia-toned photo of a {}.',\n",
        "    'A colored photo of a {}.',\n",
        "    'A watercolor photo of a {}.',\n",
        "    'A sketch photo of a {}.',\n",
        "    'A hyperlapse photo of a {}.',\n",
        "    'A tilt-shift photo of a {}.',\n",
        "    'A motion-blurred photo of a {}.',\n",
        "    'A double-exposure photo of a {}.',\n",
        "    'A HDR photo of a {}.',\n",
        "    'A 360-degree photo of a {}.',\n",
        "    'A black-and-white negative photo of a {}.',\n",
        "    'A split-tone photo of a {}.',\n",
        "    'A film-grain photo of a {}.',\n",
        "    'A thermal photo of a {}.',\n",
        "    'A infrared photo of a {}.',\n",
        "    'A ultraviolet photo of a {}.',\n",
        "    'A x-ray photo of a {}.',\n",
        "    'A 3D photo of a {}.',\n",
        "    'A stop-motion photo of a {}.',\n",
        "    'A bokeh photo of a {}.',\n",
        "    'A miniature photo of a {}.',\n",
        "    'A light-painted photo of a {}.',\n",
        "    'A composite photo of a {}.',\n",
        "    'A polarized photo of a {}.',\n",
        "    'A photomontage photo of a {}.',\n",
        "    'A digital-art photo of a {}.',\n",
        "    'A abstract photo of a {}.',\n",
        "    'A selective-focus photo of a {}.',\n",
        "    'A black-and-white film photo of a {}.',\n",
        "    'A cross-processed photo of a {}.',\n",
        "    'A cyanotype photo of a {}.',\n",
        "    'A lomography photo of a {}.',\n",
        "    'A pinhole photo of a {}.',\n",
        "    'A cyanotype photo of a {}.',\n",
        "    'A high-dynamic-range photo of a {}.',\n",
        "    'A low-dynamic-range photo of a {}.',\n",
        "    'A multiexposure photo of a {}.',\n",
        "    'A high-speed photo of a {}.',\n",
        "    'A underwater photo of a {}.',\n",
        "    'A sculpture of a {}.',\n",
        "    'A print of a {}.',\n",
        "    'A sketch of a {}.',\n",
        "    'A engraving of a {}.',\n",
        "    'A etching of a {}.',\n",
        "    'A lithograph of a {}.',\n",
        "    'A watercolor of a {}.',\n",
        "    'A pastel of a {}.',\n",
        "    'A charcoal of a {}.',\n",
        "    'A oil painting of a {}.',\n",
        "    'A acrylic painting of a {}.',\n",
        "    'A digital painting of a {}.',\n",
        "    'A fresco of a {}.',\n",
        "    'A mosaic of a {}.',\n",
        "    'A collage of a {}.',\n",
        "    'A graffiti of a {}.',\n",
        "    'A stained glass of a {}.',\n",
        "    'A quilt of a {}.',\n",
        "    'A tapestry of a {}.',\n",
        "    'A batik of a {}.',\n",
        "    'A calligraphy of a {}.',\n",
        "    'A wood carving of a {}.',\n",
        "    'A metal sculpture of a {}.',\n",
        "    'A glass sculpture of a {}.',\n",
        "    'A clay sculpture of a {}.',\n",
        "    'A ice sculpture of a {}.',\n",
        "    'A sand sculpture of a {}.',\n",
        "    'A paper mache of a {}.',\n",
        "    'A sculptural installation of a {}.',\n",
        "    'A mural of a {}.',\n",
        "    'A fresco of a {}.',\n",
        "    'A graffiti of a {}.',\n",
        "    'A street art of a {}.',\n",
        "    'A digital art of a {}.',\n",
        "    'A film of a {}.',\n",
        "    'A animation of a {}.',\n",
        "    'A stop motion animation of a {}.',\n",
        "    'A motion graphics of a {}.',\n",
        "    'A 3D animation of a {}.',\n",
        "    'A VR of a {}.',\n",
        "    'A AR of a {}.',\n",
        "    'A hologram of a {}.',\n",
        "    'A laser show of a {}.',\n",
        "    'A light show of a {}.',\n",
        "    'A pyrotechnics of a {}.',\n",
        "    'A performance of a {}.',\n",
        "    'A sound sculpture of a {}.',\n",
        "    'A kinetic sculpture of a {}.',\n",
        "    'A land art of a {}.',\n",
        "    'A environmental art of a {}.',\n",
        "}\n",
        "len(EXTRA_TEMPLATES)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K-00FLI_Ds23"
      },
      "outputs": [],
      "source": [
        "#@title Make shared ztxt\n",
        "\n",
        "def _unique_flat(groups):\n",
        "  \"\"\"Concatenates the lists in `groups` and drops duplicates (set order, i.e. arbitrary).\"\"\"\n",
        "  return list(set(sum(groups, [])))\n",
        "\n",
        "\n",
        "# The prompt pool is the union of all template lists in `_ZEROSHOT_TEMPLATES`.\n",
        "pool_templates = _unique_flat(multimodal_utils._ZEROSHOT_TEMPLATES.values())\n",
        "pool_n_prompts = len(pool_templates)\n",
        "print(\"pool_n_prompts\", pool_n_prompts)\n",
        "\n",
        "USE_EXTRA_TEMPLATE = True #@param\n",
        "# Pool templates always come first; extra templates that are not already in the\n",
        "# pool are appended after them when requested.\n",
        "all_templates = pool_templates\n",
        "if USE_EXTRA_TEMPLATE:\n",
        "  extras_templates = EXTRA_TEMPLATES - set(pool_templates)\n",
        "  all_templates = pool_templates + list(extras_templates)\n",
        "\n",
        "print(\"len(all_templates)\", len(all_templates))\n",
        "\n",
        "# Likewise, the union of all class-name lists in `_ZEROSHOT_CLASS_NAMES`.\n",
        "all_classnames = _unique_flat(multimodal_utils._ZEROSHOT_CLASS_NAMES.values())\n",
        "print(\"len(all_classnames)\", len(all_classnames))\n",
        "\n",
        "# Embed all templates for all class names in a single call. The class names are\n",
        "# given explicitly via `classnames`, with '' in the dataset-name position.\n",
        "ztxts_all_prompts_all_class = compute_text_embeddings(\n",
        "    all_templates,\n",
        "    tokenize_fn,\n",
        "    loaded_params['logit_scale'],\n",
        "    '',\n",
        "    classnames=all_classnames,\n",
        ")\n",
        "\n",
        "ztxts_all_prompts_all_class.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RSfEcwnbDos9"
      },
      "outputs": [],
      "source": [
        "OVERWRITE_RESULTS = False #@param\n",
        "LOAD_RESULTS = False #@param\n",
        "\n",
        "# Location of the pickled results table.\n",
        "base_path = ''\n",
        "df_name = 'final_results_dataframe.pkl'\n",
        "df_path = os.path.join(base_path, df_name)\n",
        "\n",
        "# Begin with an empty table: one column per experiment setting or metric.\n",
        "result_columns = [\n",
        "    'dataset_name', 'debias_mode', 'img_mean', 'prompt_set', 'weighting',\n",
        "    'select_threshold', 'num_pretrain', 'frac_test', 'top1_acc', 'top5_acc',\n",
        "]\n",
        "results_df = pd.DataFrame(columns=result_columns)\n",
        "\n",
        "# Optionally reset the file on disk to the empty table.\n",
        "if OVERWRITE_RESULTS:\n",
        "    with tf.io.gfile.GFile(df_path, 'w') as results_file:\n",
        "        results_file.write(pickle.dumps(results_df, protocol=4))\n",
        "\n",
        "# Optionally continue from the table that is stored on disk.\n",
        "if LOAD_RESULTS:\n",
        "    with tf.io.gfile.GFile(df_path, 'rb') as results_file:\n",
        "        results_df = pickle.load(results_file)\n",
        "\n",
        "results_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q6NSEOZQJiUA"
      },
      "outputs": [],
      "source": [
        "#@title Specify datasets and batch sizes\n",
        "\n",
        "# Each mapping is dataset name -\u003e batch size.\n",
        "ds_inet = dict(\n",
        "    imagenet=5000,\n",
        "    imagenet_a=5000,\n",
        "    imagenet_r=5000,\n",
        "    imagenet_sketch=5000,\n",
        "    imagenet_v2=5000,\n",
        ")\n",
        "ds_fine = dict(\n",
        "    caltech101=5000,\n",
        "    cars196=5000,\n",
        "    cifar10=5000,\n",
        "    cifar100=5000,\n",
        "    dtd=1880,\n",
        "    eurosat=5000,\n",
        "    food101=5000,\n",
        "    oxford_iiit_pet=3669,\n",
        "    oxford_flowers102=5000,\n",
        "    resisc45=5000,\n",
        "    sun397=5000,\n",
        ")\n",
        "# ImageNet variants first, then the fine-grained datasets.\n",
        "ds_list = {**ds_inet, **ds_fine}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q9GYhLKKuDbD"
      },
      "outputs": [],
      "source": [
        "#@title Main loop\n",
        "\n",
        "# Flags read by the evaluation loop in the next cell.\n",
        "RUN_ABLATIONS = False #@param\n",
        "# ^ Whether to run all of the ablations, or the key results only.\n",
        "RUN_FULL_THRESHOLD_SWEEP = False #@param\n",
        "# ^ Whether to run the full sweep over threshold values, or just the values we found to be best.\n",
        "IMG_MEANS = [True] # [True, False] #@param\n",
        "# ^ Settings to loop over: per-dataset (True) or per-example (False) prompt scores.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFjxveSsVYDu"
      },
      "outputs": [],
      "source": [
        "def add_row_(\n",
        "    df, dataset_name, debias_mode, img_mean, prompt_set, weighting, select_threshold,\n",
        "    num_pretrain, frac_test, top1_acc, top5_acc\n",
        "):\n",
        "    \"\"\"Returns a new DataFrame consisting of `df` plus one results row.\n",
        "\n",
        "    Every argument after `df` fills the results column of the same name. The\n",
        "    callers below pass the string 'NA' for settings that do not apply to a row.\n",
        "    \"\"\"\n",
        "    return pd.concat([df, pd.DataFrame.from_dict({\n",
        "        'dataset_name': [dataset_name], 'debias_mode': [debias_mode], 'img_mean': [img_mean],\n",
        "        'prompt_set': [prompt_set], 'weighting': [weighting], 'select_threshold': [select_threshold],\n",
        "        'num_pretrain': [num_pretrain],  'frac_test': [frac_test],\n",
        "        'top1_acc': [top1_acc], 'top5_acc': [top5_acc],\n",
        "    })], ignore_index=True)\n",
        "\n",
        "\n",
        "# (top1_acc, top5_acc) recorded when thresholding leaves no prompt selected. It has\n",
        "# exactly two entries because it is unpacked into those two names below.\n",
        "EMPTY_SELECTION_ACCS = (0, 0)\n",
        "\n",
        "for dataset_name, batch_size in ds_list.items():\n",
        "    print(\"dataset_name\", dataset_name)\n",
        "\n",
        "    # Collect prompts.\n",
        "    # 'ds' stands for 'dataset specific'. I.e., the 'hand-crafted' prompts.\n",
        "    ds_templates = multimodal_utils._ZEROSHOT_TEMPLATES[config.zeroshot_eval_datasets[dataset_name]['prompts_key']]\n",
        "    ds_n_prompts = len(ds_templates)\n",
        "\n",
        "    # Positions of each prompt subset along the prompt axis of the shared embeddings.\n",
        "    pool_idxs = np.array([all_templates.index(p) for p in pool_templates])\n",
        "    ds_idxs = np.array([all_templates.index(p) for p in ds_templates])\n",
        "    inet_idxs = np.array([all_templates.index(p) for p in multimodal_utils._ZEROSHOT_TEMPLATES['imagenet']])\n",
        "    none_idx = np.array([all_templates.index(p) for p in multimodal_utils._ZEROSHOT_TEMPLATES['none']])\n",
        "    photo_idx = np.array([all_templates.index(p) for p in multimodal_utils._ZEROSHOT_TEMPLATES['photo']])\n",
        "\n",
        "    # Get prompt embeddings.\n",
        "    classnames = multimodal_utils._ZEROSHOT_CLASS_NAMES[config.zeroshot_eval_datasets[dataset_name]['classnames_key']]\n",
        "    classname_idxs = np.array([all_classnames.index(classname) for classname in classnames])\n",
        "    ztxts_all_prompts = ztxts_all_prompts_all_class[:, classname_idxs, :]\n",
        "\n",
        "    ztxts_classname = compute_text_embeddings(multimodal_utils._ZEROSHOT_TEMPLATES['none'],\n",
        "                                                tokenize_fn, loaded_params['logit_scale'], dataset_name)\n",
        "\n",
        "    ztxts_photo = compute_text_embeddings(multimodal_utils._ZEROSHOT_TEMPLATES['photo'],\n",
        "                                                tokenize_fn, loaded_params['logit_scale'], dataset_name)\n",
        "\n",
        "    # Get image embeddings.\n",
        "    zs_split = load_zeroshot_dataset(config, rng, dataset_name, zs_batch_size=batch_size)\n",
        "    ds_iter = input_utils.start_input_pipeline(zs_split, config.get('prefetch_to_device', 1))\n",
        "    zimgs, labels = compute_image_embeddings(ds_iter, image_resolution)\n",
        "\n",
        "    # These two datasets use a subset of labels; map each label to its position in that subset.\n",
        "    if dataset_name == 'imagenet_r':\n",
        "        labels = np.array([multimodal_utils._IMAGENET_R_LABELSET.index(l) for l in labels])\n",
        "    elif dataset_name == 'imagenet_a':\n",
        "        labels = np.array([multimodal_utils._IMAGENET_A_LABELSET.index(l) for l in labels])\n",
        "\n",
        "    # Get logits.\n",
        "    classname_logits = get_logits(ztxts_classname, zimgs)\n",
        "    photo_logits = get_logits(ztxts_photo, zimgs)\n",
        "    all_prompts_logits = get_logits(ztxts_all_prompts, zimgs)\n",
        "    pool_logits = all_prompts_logits[:, pool_idxs, :]\n",
        "    ds_logits = all_prompts_logits[:, ds_idxs, :]\n",
        "    inet_logits = all_prompts_logits[:, inet_idxs, :]\n",
        "    # The per-dataset slice `ztxts_all_prompts` is rebuilt on demand in the sweep below,\n",
        "    # so it can be released here together with the other text embeddings.\n",
        "    del ztxts_classname, ztxts_photo, ztxts_all_prompts\n",
        "\n",
        "    # Class name.\n",
        "    top1_acc, top5_acc = compute_metrics(labels, agg_logits(classname_logits))\n",
        "    results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'classname', 'NA', 'NA',\n",
        "                          'NA', 'NA', top1_acc, top5_acc)\n",
        "\n",
        "    # 'A photo of {}'.\n",
        "    top1_acc, top5_acc = compute_metrics(labels, agg_logits(photo_logits))\n",
        "    results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'photo', 'NA', 'NA',\n",
        "                          'NA', 'NA', top1_acc, top5_acc)\n",
        "\n",
        "    # Dataset specific prompts with equal weighting.\n",
        "    top1_acc, top5_acc = compute_metrics(labels, agg_logits(ds_logits))\n",
        "    results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'dataset', 'equal', 'NA',\n",
        "                          'NA', 'NA', top1_acc, top5_acc)\n",
        "\n",
        "    # Pool prompts with equal weighting.\n",
        "    top1_acc, top5_acc = compute_metrics(labels, agg_logits(pool_logits))\n",
        "    results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'pool', 'equal', 'NA',\n",
        "                          'NA', 'NA', top1_acc, top5_acc)\n",
        "\n",
        "    # INet prompts with equal weighting.\n",
        "    if RUN_ABLATIONS:\n",
        "        top1_acc, top5_acc = compute_metrics(labels, agg_logits(inet_logits))\n",
        "        results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'inet', 'equal', 'NA',\n",
        "                              'NA', 'NA', top1_acc, top5_acc)\n",
        "\n",
        "        # All prompts with equal weighting.\n",
        "        top1_acc, top5_acc = compute_metrics(labels, agg_logits(all_prompts_logits))\n",
        "        results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'all', 'equal', 'NA',\n",
        "                              'NA', 'NA', top1_acc, top5_acc)\n",
        "\n",
        "    for img_mean in IMG_MEANS:\n",
        "        debias_modes = ['both', 'test', 'pretrain_star', 'pretrain', 'none'] if RUN_ABLATIONS else ['both']\n",
        "        for debias_mode in debias_modes:\n",
        "            print(debias_mode)\n",
        "            # Pairs of (number of pretraining images, fraction of test images) to sweep;\n",
        "            # 'NA' marks the quantity that the debias mode does not use.\n",
        "            if debias_mode == 'both':\n",
        "                if RUN_ABLATIONS:\n",
        "                    num_pretrains_frac_test = [(5_000, 1.), (10_000, 1.), (20_000, 1.), (20_000, .5), (20_000, .2), (20_000, .1)]\n",
        "                else:\n",
        "                    num_pretrains_frac_test = [(20_000, 1.)]\n",
        "            elif debias_mode in ['pretrain_star', 'pretrain']:\n",
        "                num_pretrains_frac_test = [(20_000, 'NA')]\n",
        "            else:\n",
        "                num_pretrains_frac_test = [('NA', 1.)]\n",
        "\n",
        "            for num_pretrain, frac_test in num_pretrains_frac_test:\n",
        "                if debias_mode in ['both', 'pretrain']:\n",
        "                    # Rebuild the per-dataset slice for this iteration and drop it again right\n",
        "                    # after use. Deleting a single shared copy inside this loop would make every\n",
        "                    # later iteration fail with a NameError.\n",
        "                    ztxts_all_prompts = ztxts_all_prompts_all_class[:, classname_idxs, :]\n",
        "                    random_logits = get_logits(ztxts_all_prompts, zimgs_laion[:num_pretrain])  # [n_pretrain, n_prompts, n_classes_ds]\n",
        "                    del ztxts_all_prompts\n",
        "                elif debias_mode == 'pretrain_star':\n",
        "                    random_logits = get_logits(ztxts_all_prompts_all_class, zimgs_laion[:num_pretrain])  # [n_pretrain, n_prompts, n_classes_all]\n",
        "                else:\n",
        "                    # 'test' and 'none' pass no pretraining logits.\n",
        "                    random_logits = None\n",
        "\n",
        "                all_weights = get_weights(all_prompts_logits, random_logits, debias_mode=debias_mode, img_mean=img_mean, frac_test=frac_test)\n",
        "\n",
        "                # ds prompts with score weighting.\n",
        "                top1_acc, top5_acc = compute_metrics(labels, agg_logits(ds_logits, all_weights[:, ds_idxs, :]))\n",
        "                results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'dataset', 'scores', 'NA',\n",
        "                                      num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                top1_acc, top5_acc = compute_metrics(labels, agg_logits(ds_logits, jax.nn.softmax(all_weights, axis=1)[:, ds_idxs, :]))\n",
        "                results_df = add_row_(results_df, dataset_name, debias_mode, img_mean,'dataset', 'softmax_scores', 'NA',\n",
        "                                      num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                # Thresholds are only swept for per-dataset scores (img_mean == True).\n",
        "                if img_mean:\n",
        "                    if RUN_FULL_THRESHOLD_SWEEP:\n",
        "                        thresholds = ['NA', 0.1, 0.2, 0.3, 0.4, 0.5, 1.0, 1.5, 1.8, 2.0, 2.5]\n",
        "                    else:\n",
        "                        thresholds = ['NA', 0.5, 2.0]\n",
        "                else:\n",
        "                    thresholds = ['NA']\n",
        "\n",
        "                for select_threshold in thresholds:\n",
        "\n",
        "                    # Pool prompts with score weighting / thresholding.\n",
        "                    selected_prompt_idxs = mad_method(all_weights[0, pool_idxs, 0], select_threshold)\n",
        "                    # Note: ^ select_threshold == 'NA' is a special case for mad_method equivalent to select_threshold == 0.\n",
        "\n",
        "                    top1_acc, top5_acc = compute_metrics(\n",
        "                        labels, agg_logits(pool_logits[:, selected_prompt_idxs, :], all_weights[:, pool_idxs[selected_prompt_idxs], :])\n",
        "                    ) if len(selected_prompt_idxs) \u003e 0 else EMPTY_SELECTION_ACCS\n",
        "                    results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'pool', 'scores', select_threshold,\n",
        "                                          num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                    top1_acc, top5_acc = compute_metrics(\n",
        "                        labels, agg_logits(pool_logits[:, selected_prompt_idxs, :], jax.nn.softmax(all_weights, axis=1)[:, pool_idxs[selected_prompt_idxs], :])\n",
        "                    ) if len(selected_prompt_idxs) \u003e 0 else EMPTY_SELECTION_ACCS\n",
        "                    results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'pool', 'softmax_scores', select_threshold,\n",
        "                                          num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                    if RUN_ABLATIONS and debias_mode == 'both' and num_pretrain == 20_000 and frac_test == 1.:\n",
        "                        top1_acc, top5_acc = compute_metrics(\n",
        "                            labels, agg_logits(pool_logits[:, selected_prompt_idxs, :], (all_weights**10)[:, pool_idxs[selected_prompt_idxs], :])\n",
        "                        ) if len(selected_prompt_idxs) \u003e 0 else EMPTY_SELECTION_ACCS\n",
        "                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'pool', 'scores^10', select_threshold,\n",
        "                                              num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                    if select_threshold != 'NA':\n",
        "                        top1_acc, top5_acc = compute_metrics(\n",
        "                            labels, agg_logits(pool_logits[:, selected_prompt_idxs, :])\n",
        "                        ) if len(selected_prompt_idxs) \u003e 0 else EMPTY_SELECTION_ACCS\n",
        "                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'pool', 'equal', select_threshold,\n",
        "                                              num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                    if RUN_ABLATIONS and debias_mode == 'both' and num_pretrain == 20_000 and frac_test == 1.:\n",
        "                        # INet prompts with score weighting / thresholding.\n",
        "                        selected_prompt_idxs = mad_method(all_weights[0, inet_idxs, 0], select_threshold)\n",
        "\n",
        "                        top1_acc, top5_acc = compute_metrics(\n",
        "                            labels, agg_logits(inet_logits[:, selected_prompt_idxs, :], all_weights[:, inet_idxs[selected_prompt_idxs], :])\n",
        "                        ) if len(selected_prompt_idxs) \u003e 0 else EMPTY_SELECTION_ACCS\n",
        "                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'inet', 'scores', select_threshold,\n",
        "                                              num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                        top1_acc, top5_acc = compute_metrics(\n",
        "                            labels, agg_logits(inet_logits[:, selected_prompt_idxs, :], jax.nn.softmax(all_weights, axis=1)[:, inet_idxs[selected_prompt_idxs], :])\n",
        "                        ) if len(selected_prompt_idxs) \u003e 0 else EMPTY_SELECTION_ACCS\n",
        "                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'inet', 'softmax_scores', select_threshold,\n",
        "                                              num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                        if select_threshold != 'NA':\n",
        "                            top1_acc, top5_acc = compute_metrics(\n",
        "                                labels, agg_logits(inet_logits[:, selected_prompt_idxs, :])\n",
        "                            ) if len(selected_prompt_idxs) \u003e 0 else EMPTY_SELECTION_ACCS\n",
        "                            results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'inet', 'equal', select_threshold,\n",
        "                                                  num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                        # All prompts with score weighting / thresholding.\n",
        "                        selected_prompt_idxs = mad_method(all_weights[0, :, 0], select_threshold)\n",
        "\n",
        "                        top1_acc, top5_acc = compute_metrics(\n",
        "                            labels, agg_logits(all_prompts_logits[:, selected_prompt_idxs, :], all_weights[:, selected_prompt_idxs, :])\n",
        "                        ) if len(selected_prompt_idxs) \u003e 0 else EMPTY_SELECTION_ACCS\n",
        "                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'all', 'scores', select_threshold,\n",
        "                                            num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                        top1_acc, top5_acc = compute_metrics(\n",
        "                            labels, agg_logits(all_prompts_logits[:, selected_prompt_idxs, :], jax.nn.softmax(all_weights, axis=1)[:, selected_prompt_idxs, :])\n",
        "                        ) if len(selected_prompt_idxs) \u003e 0 else EMPTY_SELECTION_ACCS\n",
        "                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'all', 'softmax_scores', select_threshold,\n",
        "                                            num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                        if select_threshold != 'NA':\n",
        "                            top1_acc, top5_acc = compute_metrics(\n",
        "                                labels, agg_logits(all_prompts_logits[:, selected_prompt_idxs, :])\n",
        "                            ) if len(selected_prompt_idxs) \u003e 0 else EMPTY_SELECTION_ACCS\n",
        "                            results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'all', 'equal', select_threshold,\n",
        "                                                num_pretrain, frac_test, top1_acc, top5_acc)\n",
        "\n",
        "                del all_weights\n",
        "                del random_logits\n",
        "\n",
        "                # Checkpoint the results after every configuration.\n",
        "                with tf.io.gfile.GFile(df_path, 'w') as f:\n",
        "                    f.write(pickle.dumps(results_df, protocol=4))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gfVcNPpYDnKE"
      },
      "outputs": [],
      "source": [
        "results_df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xzq1pUX6YvYi"
      },
      "source": [
        "## Make tables"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "butoXNNwYxJp"
      },
      "outputs": [],
      "source": [
        "# Read the results table back from disk.\n",
        "# NOTE(review): pickle can execute arbitrary code, so only load a file you wrote yourself.\n",
        "with tf.io.gfile.GFile(\"final_results_dataframe.pkl\", 'rb') as results_file:\n",
        "  results_df = pickle.load(results_file)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c5qRipvcZKqG"
      },
      "outputs": [],
      "source": [
        "results_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yFltEDCnTMCJ"
      },
      "outputs": [],
      "source": [
        "def df_to_latex(df, apply_formatting=True):\n",
        "    \"\"\"Prints `df` as a LaTeX table with values rounded to two decimals.\n",
        "\n",
        "    Args:\n",
        "      df: DataFrame to render.\n",
        "      apply_formatting: If True, then within every float64 column the largest\n",
        "        value is wrapped in a bold (textbf) command and the second largest\n",
        "        distinct value in an underline (ul) command. Tied cells are all marked.\n",
        "\n",
        "    Returns:\n",
        "      None. The LaTeX source is printed.\n",
        "    \"\"\"\n",
        "    df_rounded = df.round(2)\n",
        "\n",
        "    if apply_formatting:\n",
        "        # Mark the cells column by column *before* rendering. Searching for the numbers\n",
        "        # in the rendered string instead would also match equal numbers in other columns\n",
        "        # (or inside longer numbers) and could wrap one cell in several commands.\n",
        "        df_marked = df_rounded.copy()\n",
        "        for column in df_rounded.select_dtypes(include=['float64']).columns:\n",
        "            values = df_rounded[column]\n",
        "            top_two = values.dropna().drop_duplicates().nlargest(2).tolist()\n",
        "            largest = top_two[0] if len(top_two) \u003e 0 else None\n",
        "            second_largest = top_two[1] if len(top_two) \u003e 1 else None\n",
        "\n",
        "            # Default arguments bind this column's extremes to the helper.\n",
        "            def mark(value, largest=largest, second_largest=second_largest):\n",
        "                if pd.isna(value):\n",
        "                    return value\n",
        "                text = f'{value:.2f}'\n",
        "                if value == largest:\n",
        "                    return r'\\textbf{' + text + '}'\n",
        "                if value == second_largest:\n",
        "                    return r'\\ul{' + text + '}'\n",
        "                return text\n",
        "\n",
        "            df_marked[column] = values.map(mark)\n",
        "        df_rounded = df_marked\n",
        "\n",
        "    # escape=False keeps the LaTeX commands inserted above intact.\n",
        "    print(df_rounded.to_latex(escape=False, header=True))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "gj9rPrEDZZs8"
      },
      "outputs": [],
      "source": [
        "#@title Table 1\n",
        "\n",
        "# Take only ImageNet datasets.\n",
        "table1_df = results_df[results_df.dataset_name.isin(ds_inet.keys())]\n",
        "\n",
        "# Remove most of the ablation rows. 'NA' is kept in every filter because the\n",
        "# baseline rows store it for settings that do not apply to them.\n",
        "table1_df = table1_df[\n",
        "    table1_df.num_pretrain.isin(['NA', 20_000]) \u0026\n",
        "    table1_df.frac_test.isin(['NA', 1.]) \u0026\n",
        "    table1_df.img_mean.isin(['NA', True]) \u0026\n",
        "    table1_df.debias_mode.isin(['NA', 'both', 'none'])\n",
        "]\n",
        "\n",
        "# Construct table rows. The order here is the row order of the final table,\n",
        "# because the pivot below is done with sort=False.\n",
        "table1_df = pd.concat([\n",
        "    # 'class name'\n",
        "    table1_df[\n",
        "        table1_df.prompt_set == 'classname'\n",
        "    ].assign(Name='class name'),\n",
        "    # 'A photo of {}'\n",
        "    # The label is a raw string so that its LaTeX backslashes are kept verbatim.\n",
        "    table1_df[\n",
        "        table1_df.prompt_set == 'photo'\n",
        "    ].assign(Name=r\"`\\emph{A photo of \\{\\}.}'\"),\n",
        "    # hand-crafted, equal average\n",
        "    table1_df[\n",
        "        (table1_df.prompt_set == 'dataset') \u0026 (table1_df.weighting == 'equal')\n",
        "    ].assign(Name='hand-crafted, equal average'),\n",
        "    # pool set, equal average\n",
        "    table1_df[\n",
        "        (table1_df.prompt_set == 'pool') \u0026 (table1_df.weighting == 'equal') \u0026 (table1_df.select_threshold == 'NA')\n",
        "    ].assign(Name='pool set, equal average'),\n",
        "    # max-logit scoring\n",
        "    # NOTE: the main loop only produces debias_mode == 'none' rows when RUN_ABLATIONS is True.\n",
        "    table1_df[\n",
        "        (table1_df.prompt_set == 'pool') \u0026 (table1_df.weighting == 'scores') \u0026 (table1_df.debias_mode == 'none') \u0026 (table1_df.select_threshold == 'NA')\n",
        "    ].assign(Name='max-logit scoring'),\n",
        "    # ZPE (weighted average)\n",
        "    table1_df[\n",
        "        (table1_df.prompt_set == 'pool') \u0026 (table1_df.weighting == 'softmax_scores') \u0026 (table1_df.debias_mode == 'both') \u0026 (table1_df.select_threshold == 'NA')\n",
        "    ].assign(Name='ZPE (weighted average)'),\n",
        "    # ZPE (prompt selection, ours)\n",
        "    table1_df[\n",
        "        (table1_df.prompt_set == 'pool') \u0026 (table1_df.weighting == 'softmax_scores') \u0026 (table1_df.debias_mode == 'both') \u0026 (table1_df.select_threshold == .5)\n",
        "    ].assign(Name='ZPE (prompt selection, ours)'),\n",
        "])\n",
        "\n",
        "# Drop columns.\n",
        "table1_df = table1_df[['Name', 'dataset_name', 'top1_acc']]\n",
        "\n",
        "# Pivot the table: one row per method name, one column per dataset.\n",
        "table1_df = table1_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)\n",
        "\n",
        "# Drop extra levels.\n",
        "table1_df.columns.name = None\n",
        "table1_df.index.name = None\n",
        "\n",
        "# Sort columns.\n",
        "table1_df = table1_df.sort_index(axis=1)\n",
        "\n",
        "# Add the averaged column (mean over the dataset columns of each row).\n",
        "table1_df['avg'] = table1_df.mean(axis=1)\n",
        "\n",
        "table1_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "15vpZBIbPX32"
      },
      "outputs": [],
      "source": [
        "df_to_latex(table1_df)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "F1rBAd66qeqf"
      },
      "outputs": [],
      "source": [
        "#@title Table 2\n",
        "\n",
        "# Take only fine datasets.\n",
        "table2_df = results_df[results_df.dataset_name.isin(ds_fine.keys())]\n",
        "\n",
        "# Remove most of the ablation rows.\n",
        "table2_df = table2_df[\n",
        "    table2_df.num_pretrain.isin(['NA', 20_000]) \u0026\n",
        "    table2_df.frac_test.isin(['NA', 1.]) \u0026\n",
        "    table2_df.img_mean.isin(['NA', True]) \u0026\n",
        "    table2_df.debias_mode.isin(['NA', 'both', 'none'])\n",
        "]\n",
        "\n",
        "# Construct table rows.\n",
        "table2_df = pd.concat([\n",
        "    # 'class name'\n",
        "    table2_df[\n",
        "        table2_df.prompt_set == 'classname'\n",
        "    ].assign(Name='class name'),\n",
        "    # 'A photo of {}'\n",
        "    table2_df[\n",
        "        table2_df.prompt_set == 'photo'\n",
        "    ].assign(Name=\"`\\emph{A photo of \\{\\}.}'\"),\n",
        "    # hand-crafted, equal average\n",
        "    table2_df[\n",
        "        (table2_df.prompt_set == 'dataset') \u0026 (table2_df.weighting == 'equal')\n",
        "    ].assign(Name='hand-crafted, equal average'),\n",
        "    # pool set, equal average\n",
        "    table2_df[\n",
        "        (table2_df.prompt_set == 'pool') \u0026 (table2_df.weighting == 'equal') \u0026 (table2_df.select_threshold == 'NA')\n",
        "    ].assign(Name='pool set, equal average'),\n",
        "    # max-logit scoring\n",
        "    table2_df[\n",
        "        (table2_df.prompt_set == 'pool') \u0026 (table2_df.weighting == 'scores') \u0026 (table2_df.debias_mode == 'none') \u0026 (table2_df.select_threshold == 'NA')\n",
        "    ].assign(Name='max-logit scoring'),\n",
        "    # ZPE (weighted average)\n",
        "    table2_df[\n",
        "        (table2_df.prompt_set == 'pool') \u0026 (table2_df.weighting == 'softmax_scores') \u0026 (table2_df.debias_mode == 'both') \u0026 (table2_df.select_threshold == 'NA')\n",
        "    ].assign(Name='ZPE (weighted average)'),\n",
        "    # ZPE (prompt selection, ours)\n",
        "    table2_df[\n",
        "        (table2_df.prompt_set == 'pool') \u0026 (table2_df.weighting == 'softmax_scores') \u0026 (table2_df.debias_mode == 'both') \u0026 (table2_df.select_threshold == 2.)\n",
        "    ].assign(Name='ZPE (prompt selection, ours)'),\n",
        "])\n",
        "\n",
        "# Drop columns.\n",
        "table2_df = table2_df[['Name', 'dataset_name', 'top1_acc']]\n",
        "\n",
        "# Pivot the table.\n",
        "table2_df = table2_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)\n",
        "\n",
        "# Drop extra levels.\n",
        "table2_df.columns.name = None\n",
        "table2_df.index.name = None\n",
        "\n",
        "# Sort columns.\n",
        "table2_df = table2_df.sort_index(axis=1)\n",
        "\n",
        "# Add the averaged column.\n",
        "table2_df['avg'] = table2_df.mean(axis=1)\n",
        "\n",
        "table2_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VIJclKTbT-Pz"
      },
      "outputs": [],
      "source": [
        "df_to_latex(table2_df)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "B1AZhqzmregs"
      },
      "outputs": [],
      "source": [
        "#@title Table 3\n",
        "\n",
        "table3_df = results_df\n",
        "\n",
        "# Remove most of the ablation rows.\n",
        "table3_df = table3_df[\n",
        "    table3_df.num_pretrain.isin([20_000, 'NA']) \u0026\n",
        "    table3_df.frac_test.isin([1., 'NA']) \u0026\n",
        "    table3_df.img_mean.isin([True]) \u0026\n",
        "    table3_df.weighting.isin(['softmax_scores']) \u0026\n",
        "    (\n",
        "        (table3_df.dataset_name.isin(ds_inet.keys()) \u0026 (table3_df.select_threshold.isin(['NA', .5]))) |\n",
        "        (table3_df.dataset_name.isin(ds_fine.keys()) \u0026 (table3_df.select_threshold.isin(['NA', 2.])))\n",
        "    ) \u0026\n",
        "    table3_df.prompt_set.isin(['pool'])\n",
        "]\n",
        "\n",
        "# Construct table rows.\n",
        "table3_df = pd.concat([\n",
        "    # none (weighted average)\n",
        "    table3_df[\n",
        "        (table3_df.debias_mode == 'none') \u0026 (table3_df.select_threshold == 'NA')\n",
        "    ].assign(Name='none (weighted average)'),\n",
        "    # pretrain (weighted average)\n",
        "    table3_df[\n",
        "        (table3_df.debias_mode == 'pretrain') \u0026 (table3_df.select_threshold == 'NA')\n",
        "    ].assign(Name='pretrain (weighted average)'),\n",
        "    # pretrain_star (weighted average)\n",
        "    table3_df[\n",
        "        (table3_df.debias_mode == 'pretrain_star') \u0026 (table3_df.select_threshold == 'NA')\n",
        "    ].assign(Name='pretrain_star (weighted average)'),\n",
        "    # test (weighted average)\n",
        "    table3_df[\n",
        "        (table3_df.debias_mode == 'test') \u0026 (table3_df.select_threshold == 'NA')\n",
        "    ].assign(Name='test (weighted average)'),\n",
        "    # both (weighted average)\n",
        "    table3_df[\n",
        "        (table3_df.debias_mode == 'both') \u0026 (table3_df.select_threshold == 'NA')\n",
        "    ].assign(Name='both (weighted average)'),\n",
        "    # none (prompt selection, ours)\n",
        "    table3_df[\n",
        "        (table3_df.debias_mode == 'none') \u0026 (table3_df.select_threshold != 'NA')\n",
        "    ].assign(Name='none (prompt selection, ours)'),\n",
        "    # pretrain (prompt selection)\n",
        "    table3_df[\n",
        "        (table3_df.debias_mode == 'pretrain') \u0026 (table3_df.select_threshold != 'NA')\n",
        "    ].assign(Name='pretrain (prompt selection)'),\n",
        "    # pretrain_star (prompt selection)\n",
        "    table3_df[\n",
        "        (table3_df.debias_mode == 'pretrain_star') \u0026 (table3_df.select_threshold != 'NA')\n",
        "    ].assign(Name='pretrain_star (prompt selection)'),\n",
        "    # test (prompt selection)\n",
        "    table3_df[\n",
        "        (table3_df.debias_mode == 'test') \u0026 (table3_df.select_threshold != 'NA')\n",
        "    ].assign(Name='test (prompt selection)'),\n",
        "    # both (prompt selection)\n",
        "    table3_df[\n",
        "        (table3_df.debias_mode == 'both') \u0026 (table3_df.select_threshold != 'NA')\n",
        "    ].assign(Name='both (prompt selection)'),\n",
        "])\n",
        "\n",
        "# # Order the rows.\n",
        "# desired_order = ['none', 'pretrain', 'pretrain_star', 'test', 'both']\n",
        "# table3_df = table3_df.sort_values(by='debias_mode', key=lambda x: pd.Categorical(x, categories=desired_order, ordered=True))\n",
        "\n",
        "# Pivot the table.\n",
        "table3_df = table3_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)\n",
        "\n",
        "# Drop extra levels.\n",
        "table3_df.columns.name = None\n",
        "table3_df.index.name = None\n",
        "\n",
        "# Create the new columns.\n",
        "table3_df['variants'] = table3_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1)\n",
        "table3_df['fine'] = table3_df[ds_fine.keys()].mean(axis=1)\n",
        "table3_df['all'] = table3_df.mean(axis=1)\n",
        "\n",
        "\n",
        "# Drop columns.\n",
        "table3_df = table3_df[['imagenet', 'variants', 'fine', 'all']]\n",
        "\n",
        "table3_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L0gApPvCUE9k"
      },
      "outputs": [],
      "source": [
        "df_to_latex(table3_df[:5])\n",
        "df_to_latex(table3_df[5:])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "r5Ows45IE1VF"
      },
      "outputs": [],
      "source": [
        "#@title Table 4\n",
        "\n",
        "table4_df = results_df\n",
        "\n",
        "# Remove most of the ablation rows.\n",
        "table4_df = table4_df[\n",
        "    table4_df.num_pretrain.isin([20_000]) \u0026\n",
        "    table4_df.frac_test.isin([1.]) \u0026\n",
        "    table4_df.img_mean.isin([True]) \u0026\n",
        "    table4_df.debias_mode.isin(['both']) \u0026\n",
        "    table4_df.prompt_set.isin(['pool']) \u0026\n",
        "    (\n",
        "        (table4_df.dataset_name.isin(ds_inet.keys()) \u0026 (table4_df.select_threshold.isin(['NA', .5]))) |\n",
        "        (table4_df.dataset_name.isin(ds_fine.keys()) \u0026 (table4_df.select_threshold.isin(['NA', 2.])))\n",
        "    )\n",
        "]\n",
        "\n",
        "# Construct table rows.\n",
        "table4_df = pd.concat([\n",
        "    # scores (weighted average)\n",
        "    table4_df[\n",
        "        (table4_df.weighting == 'scores') \u0026 (table4_df.select_threshold == 'NA')\n",
        "    ].assign(Name='scores (weighted average)'),\n",
        "    # scores^10 (weighted average)\n",
        "    table4_df[\n",
        "        (table4_df.weighting == 'scores^10') \u0026 (table4_df.select_threshold == 'NA')\n",
        "    ].assign(Name='scores^10 (weighted average)'),\n",
        "    # softmax_scores (weighted average)\n",
        "    table4_df[\n",
        "        (table4_df.weighting == 'softmax_scores') \u0026 (table4_df.select_threshold == 'NA')\n",
        "    ].assign(Name='softmax_scores (weighted average)'),\n",
        "    # scores (prompt selection)\n",
        "    table4_df[\n",
        "        (table4_df.weighting == 'scores') \u0026 (table4_df.select_threshold != 'NA')\n",
        "    ].assign(Name='scores (prompt selection)'),\n",
        "    # scores^10 (prompt selection)\n",
        "    table4_df[\n",
        "        (table4_df.weighting == 'scores^10') \u0026 (table4_df.select_threshold != 'NA')\n",
        "    ].assign(Name='scores^10 (prompt selection)'),\n",
        "    # softmax_scores (prompt selection)\n",
        "    table4_df[\n",
        "        (table4_df.weighting == 'softmax_scores') \u0026 (table4_df.select_threshold != 'NA')\n",
        "    ].assign(Name='softmax_scores (prompt selection)'),\n",
        "])\n",
        "\n",
        "# Drop columns.\n",
        "table4_df = table4_df[['Name', 'dataset_name', 'top1_acc']]\n",
        "\n",
        "# Pivot the table.\n",
        "table4_df = table4_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)\n",
        "\n",
        "# Drop extra levels.\n",
        "table4_df.columns.name = None\n",
        "table4_df.index.name = None\n",
        "\n",
        "# Create the new columns.\n",
        "table4_df['variants'] = table4_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1)\n",
        "table4_df['fine'] = table4_df[ds_fine.keys()].mean(axis=1)\n",
        "table4_df['all'] = table4_df.mean(axis=1)\n",
        "\n",
        "\n",
        "# Drop columns.\n",
        "table4_df = table4_df[['imagenet', 'variants', 'fine', 'all']]\n",
        "\n",
        "table4_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a1zt5WShVAtX"
      },
      "outputs": [],
      "source": [
        "df_to_latex(table4_df[:3], apply_formatting=True)\n",
        "df_to_latex(table4_df[3:], apply_formatting=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "h9LpEpAX6bHv"
      },
      "outputs": [],
      "source": [
        "#@title Table 5\n",
        "\n",
        "table5_df = results_df\n",
        "\n",
        "# Remove most of the ablation rows.\n",
        "table5_df = table5_df[\n",
        "    table5_df.num_pretrain.isin(['NA', 20_000]) \u0026\n",
        "    table5_df.frac_test.isin(['NA', 1.]) \u0026\n",
        "    table5_df.img_mean.isin(['NA', True]) \u0026\n",
        "    table5_df.debias_mode.isin(['NA', 'both'])  \u0026\n",
        "    (\n",
        "        (table5_df.dataset_name.isin(ds_inet.keys()) \u0026 (table5_df.select_threshold.isin(['NA', .5]))) |\n",
        "        (table5_df.dataset_name.isin(ds_fine.keys()) \u0026 (table5_df.select_threshold.isin(['NA', 2.])))\n",
        "    )\n",
        "]\n",
        "\n",
        "# Construct table rows.\n",
        "table5_df = pd.concat([\n",
        "    # hand-crafted, equal average\n",
        "    table5_df[\n",
        "        (table5_df.prompt_set == 'dataset') \u0026 (table5_df.weighting == 'equal')\n",
        "    ].assign(Name='hand-crafted, equal average'),\n",
        "    # pool set, equal average\n",
        "    table5_df[\n",
        "        (table5_df.prompt_set == 'pool') \u0026 (table5_df.weighting == 'equal') \u0026 (table5_df.select_threshold == 'NA')\n",
        "    ].assign(Name='pool set, equal average'),\n",
        "    # ZPE (weighted average)\n",
        "    table5_df[\n",
        "        (table5_df.prompt_set == 'pool') \u0026 (table5_df.weighting == 'softmax_scores') \u0026 (table5_df.select_threshold == 'NA')\n",
        "    ].assign(Name='ZPE (weighted average)'),\n",
        "    # ZPE (prompt selection, ours)\n",
        "    table5_df[\n",
        "        (table5_df.prompt_set == 'pool') \u0026 (table5_df.weighting == 'softmax_scores') \u0026 (table5_df.select_threshold != 'NA')\n",
        "    ].assign(Name='ZPE (prompt selection, ours)'),\n",
        "])\n",
        "\n",
        "# Drop columns.\n",
        "table5_df = table5_df[['Name', 'dataset_name', 'top1_acc']]\n",
        "\n",
        "# Pivot the table.\n",
        "table5_df = table5_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)\n",
        "\n",
        "# Drop extra levels.\n",
        "table5_df.columns.name = None\n",
        "table5_df.index.name = None\n",
        "\n",
        "# Create the new columns.\n",
        "table5_df['variants'] = table5_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1)\n",
        "table5_df['fine'] = table5_df[ds_fine.keys()].mean(axis=1)\n",
        "table5_df['all'] = table5_df.mean(axis=1)\n",
        "\n",
        "\n",
        "# Drop columns.\n",
        "table5_df = table5_df[['imagenet', 'variants', 'fine', 'all']]\n",
        "\n",
        "table5_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ExUWS53xUSka"
      },
      "outputs": [],
      "source": [
        "df_to_latex(table5_df)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "oluDxJRh8DfN"
      },
      "outputs": [],
      "source": [
        "#@title Table 6\n",
        "\n",
        "table6_df = results_df\n",
        "\n",
        "# Remove most of the ablation rows.\n",
        "table6_df = table6_df[\n",
        "    table6_df.num_pretrain.isin(['NA', 20_000]) \u0026\n",
        "    table6_df.frac_test.isin(['NA', 1.]) \u0026\n",
        "    table6_df.img_mean.isin(['NA', True]) \u0026\n",
        "    table6_df.debias_mode.isin(['NA', 'both']) \u0026\n",
        "    (\n",
        "        (table6_df.dataset_name.isin(ds_inet.keys()) \u0026 (table6_df.select_threshold.isin(['NA', .5]))) |\n",
        "        (table6_df.dataset_name.isin(ds_fine.keys()) \u0026 (table6_df.select_threshold.isin(['NA', 2.])))\n",
        "    )\n",
        "]\n",
        "\n",
        "# Construct table rows.\n",
        "table6_df = pd.concat([\n",
        "    # hand-crafted, equal average\n",
        "    table6_df[\n",
        "        (table6_df.prompt_set == 'dataset') \u0026 (table6_df.weighting == 'equal')\n",
        "    ].assign(Name='hand-crafted, equal average'),\n",
        "    # hand-crafted, ZPE weights\n",
        "    table6_df[\n",
        "        (table6_df.prompt_set == 'dataset') \u0026 (table6_df.weighting == 'softmax_scores')\n",
        "    ].assign(Name='hand-crafted, ZPE weights'),\n",
        "    # ZPE (weighted average, 80 prompts)\n",
        "    table6_df[\n",
        "        (table6_df.prompt_set == 'inet') \u0026 (table6_df.weighting == 'softmax_scores') \u0026 (table6_df.select_threshold == 'NA')\n",
        "    ].assign(Name='ZPE (weighted average, 80 prompts)'),\n",
        "    # ZPE (weighted average, 247 prompts)\n",
        "    table6_df[\n",
        "        (table6_df.prompt_set == 'pool') \u0026 (table6_df.weighting == 'softmax_scores') \u0026 (table6_df.select_threshold == 'NA')\n",
        "    ].assign(Name='ZPE (weighted average, 247 prompts)'),\n",
        "    # ZPE (prompt selection, 80 prompts)\n",
        "    table6_df[\n",
        "        (table6_df.prompt_set == 'inet') \u0026 (table6_df.weighting == 'softmax_scores') \u0026 (table6_df.select_threshold != 'NA')\n",
        "    ].assign(Name='ZPE (prompt selection, 80 prompts)'),\n",
        "    # ZPE (prompt selection, 247 prompts)\n",
        "    table6_df[\n",
        "        (table6_df.prompt_set == 'pool') \u0026 (table6_df.weighting == 'softmax_scores') \u0026 (table6_df.select_threshold != 'NA')\n",
        "    ].assign(Name='ZPE (prompt selection, 247 prompts)'),\n",
        "])\n",
        "\n",
        "# Drop columns.\n",
        "table6_df = table6_df[['Name', 'dataset_name', 'top1_acc']]\n",
        "\n",
        "# Pivot the table.\n",
        "table6_df = table6_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)\n",
        "\n",
        "# Drop extra levels.\n",
        "table6_df.columns.name = None\n",
        "table6_df.index.name = None\n",
        "\n",
        "# Create the new columns.\n",
        "table6_df['variants'] = table6_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1)\n",
        "table6_df['fine'] = table6_df[ds_fine.keys()].mean(axis=1)\n",
        "table6_df['all'] = table6_df.mean(axis=1)\n",
        "\n",
        "\n",
        "# Drop columns.\n",
        "table6_df = table6_df[['imagenet', 'variants', 'fine', 'all']]\n",
        "\n",
        "table6_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sDXljulaUWUi"
      },
      "outputs": [],
      "source": [
        "df_to_latex(table6_df[:2], apply_formatting=False)\n",
        "df_to_latex(table6_df[2:5], apply_formatting=True)\n",
        "df_to_latex(table6_df[5:], apply_formatting=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "cYhdzyBF969A"
      },
      "outputs": [],
      "source": [
        "#@title Table 7\n",
        "\n",
        "table7_df = results_df\n",
        "\n",
        "# Remove most of the ablation rows.\n",
        "table7_df = table7_df[\n",
        "    table7_df.num_pretrain.isin([20_000, 10_000, 5_000]) \u0026\n",
        "    table7_df.frac_test.isin([1.]) \u0026\n",
        "    table7_df.img_mean.isin([True]) \u0026\n",
        "    table7_df.weighting.isin(['softmax_scores']) \u0026\n",
        "    table7_df.debias_mode.isin(['both']) \u0026\n",
        "    table7_df.prompt_set.isin(['pool']) \u0026\n",
        "    (\n",
        "        (table7_df.dataset_name.isin(ds_inet.keys()) \u0026 (table7_df.select_threshold.isin(['NA', .5]))) |\n",
        "        (table7_df.dataset_name.isin(ds_fine.keys()) \u0026 (table7_df.select_threshold.isin(['NA', 2.])))\n",
        "    )\n",
        "]\n",
        "\n",
        "# Construct table rows.\n",
        "table7_df = pd.concat([\n",
        "    # 5k (weighted average)\n",
        "    table7_df[\n",
        "        (table7_df.num_pretrain == 5_000) \u0026 (table7_df.select_threshold == 'NA')\n",
        "    ].assign(Name='5k (weighted average)'),\n",
        "    # 10k (weighted average)\n",
        "    table7_df[\n",
        "        (table7_df.num_pretrain == 10_000) \u0026 (table7_df.select_threshold == 'NA')\n",
        "    ].assign(Name='10k (weighted average)'),\n",
        "    # 20k (weighted average)\n",
        "    table7_df[\n",
        "        (table7_df.num_pretrain == 20_000) \u0026 (table7_df.select_threshold == 'NA')\n",
        "    ].assign(Name='20k (weighted average)'),\n",
        "    # 5k (prompt selection)\n",
        "    table7_df[\n",
        "        (table7_df.num_pretrain == 5_000) \u0026 (table7_df.select_threshold != 'NA')\n",
        "    ].assign(Name='5k (prompt selection)'),\n",
        "    # 10k (prompt selection)\n",
        "    table7_df[\n",
        "        (table7_df.num_pretrain == 10_000) \u0026 (table7_df.select_threshold != 'NA')\n",
        "    ].assign(Name='10k (prompt selection)'),\n",
        "    # 20k (prompt selection)\n",
        "    table7_df[\n",
        "        (table7_df.num_pretrain == 20_000) \u0026 (table7_df.select_threshold != 'NA')\n",
        "    ].assign(Name='20k (prompt selection)'),\n",
        "])\n",
        "\n",
        "# Drop columns.\n",
        "table7_df = table7_df[['Name', 'dataset_name', 'top1_acc']]\n",
        "\n",
        "# Pivot the table.\n",
        "table7_df = table7_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)\n",
        "\n",
        "# Drop extra levels.\n",
        "table7_df.columns.name = None\n",
        "table7_df.index.name = None\n",
        "\n",
        "# Create the new columns.\n",
        "table7_df['variants'] = table7_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1)\n",
        "table7_df['fine'] = table7_df[ds_fine.keys()].mean(axis=1)\n",
        "table7_df['all'] = table7_df.mean(axis=1)\n",
        "\n",
        "\n",
        "# Drop columns.\n",
        "table7_df = table7_df[['imagenet', 'variants', 'fine', 'all']]\n",
        "\n",
        "table7_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q9XWpCkaUZCt"
      },
      "outputs": [],
      "source": [
        "df_to_latex(table7_df, apply_formatting=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "1EK_cAdHCAOT"
      },
      "outputs": [],
      "source": [
        "#@title Table 8\n",
        "\n",
        "table8_df = results_df\n",
        "\n",
        "# Remove most of the ablation rows.\n",
        "table8_df = table8_df[\n",
        "    table8_df.num_pretrain.isin([20_000]) \u0026\n",
        "    table8_df.frac_test.isin([1., .5, .2, .1]) \u0026\n",
        "    table8_df.img_mean.isin([True]) \u0026\n",
        "    table8_df.weighting.isin(['softmax_scores']) \u0026\n",
        "    table8_df.debias_mode.isin(['both']) \u0026\n",
        "    table8_df.prompt_set.isin(['pool']) \u0026\n",
        "    (\n",
        "        (table8_df.dataset_name.isin(ds_inet.keys()) \u0026 (table8_df.select_threshold.isin(['NA', .5]))) |\n",
        "        (table8_df.dataset_name.isin(ds_fine.keys()) \u0026 (table8_df.select_threshold.isin(['NA', 2.])))\n",
        "    )\n",
        "]\n",
        "\n",
        "# Construct table rows.\n",
        "table8_df = pd.concat([\n",
        "    # 10% (weighted average)\n",
        "    table8_df[\n",
        "        (table8_df.frac_test == .1) \u0026 (table8_df.select_threshold == 'NA')\n",
        "    ].assign(Name='10% (weighted average)'),\n",
        "    # 20% (weighted average)\n",
        "    table8_df[\n",
        "        (table8_df.frac_test == .2) \u0026 (table8_df.select_threshold == 'NA')\n",
        "    ].assign(Name='20% (weighted average)'),\n",
        "    # 50% (weighted average)\n",
        "    table8_df[\n",
        "        (table8_df.frac_test == .5) \u0026 (table8_df.select_threshold == 'NA')\n",
        "    ].assign(Name='50% (weighted average)'),\n",
        "    # 100% (weighted average)\n",
        "    table8_df[\n",
        "        (table8_df.frac_test == 1.) \u0026 (table8_df.select_threshold == 'NA')\n",
        "    ].assign(Name='100% (weighted average)'),\n",
        "    # 10% (prompt selection)\n",
        "    table8_df[\n",
        "        (table8_df.frac_test == .1) \u0026 (table8_df.select_threshold != 'NA')\n",
        "    ].assign(Name='10% (prompt selection)'),\n",
        "    # 20% (prompt selection)\n",
        "    table8_df[\n",
        "        (table8_df.frac_test == .2) \u0026 (table8_df.select_threshold != 'NA')\n",
        "    ].assign(Name='20% (prompt selection)'),\n",
        "    # 50% (prompt selection)\n",
        "    table8_df[\n",
        "        (table8_df.frac_test == .5) \u0026 (table8_df.select_threshold != 'NA')\n",
        "    ].assign(Name='50% (prompt selection)'),\n",
        "    # 100% (prompt selection)\n",
        "    table8_df[\n",
        "        (table8_df.frac_test == 1.) \u0026 (table8_df.select_threshold != 'NA')\n",
        "    ].assign(Name='100% (prompt selection)'),\n",
        "])\n",
        "\n",
        "# Drop columns.\n",
        "table8_df = table8_df[['Name', 'dataset_name', 'top1_acc']]\n",
        "\n",
        "# Pivot the table.\n",
        "table8_df = table8_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)\n",
        "\n",
        "# Drop extra levels.\n",
        "table8_df.columns.name = None\n",
        "table8_df.index.name = None\n",
        "\n",
        "# Create the new columns.\n",
        "table8_df['variants'] = table8_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1)\n",
        "table8_df['fine'] = table8_df[ds_fine.keys()].mean(axis=1)\n",
        "table8_df['all'] = table8_df.mean(axis=1)\n",
        "\n",
        "\n",
        "# Drop columns.\n",
        "table8_df = table8_df[['imagenet', 'variants', 'fine', 'all']]\n",
        "\n",
        "table8_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gyobh0TzUe_o"
      },
      "outputs": [],
      "source": [
        "df_to_latex(table8_df, apply_formatting=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "3w_hhd8Qlwzf"
      },
      "outputs": [],
      "source": [
        "#@title Table 9\n",
        "\n",
        "table9_df = results_df\n",
        "\n",
        "# Remove most of the ablation rows.\n",
        "table9_df = table9_df[\n",
        "    table9_df.num_pretrain.isin([20_000, 'NA']) \u0026\n",
        "    table9_df.frac_test.isin([1., 'NA']) \u0026\n",
        "    table9_df.debias_mode.isin(['both', 'none', 'NA']) \u0026\n",
        "    table9_df.prompt_set.isin(['dataset', 'pool']) \u0026\n",
        "    table9_df.weighting.isin(['softmax_scores', 'scores', 'equal']) \u0026\n",
        "    table9_df.select_threshold.isin(['NA'])\n",
        "]\n",
        "\n",
        "# Construct table rows.\n",
        "table9_df = pd.concat([\n",
        "    # hand crafted, equal average\n",
        "    table9_df[\n",
        "        (table9_df.prompt_set == 'dataset') \u0026 (table9_df.weighting == 'equal')\n",
        "    ].assign(Name='hand crafted, equal average'),\n",
        "    # hand crafted, ZPE weights, per-dataset\n",
        "    table9_df[\n",
        "        (table9_df.prompt_set == 'dataset') \u0026 (table9_df.weighting == 'softmax_scores') \u0026 (table9_df.debias_mode == 'both') \u0026 (table9_df.img_mean == True)\n",
        "    ].assign(Name='hand crafted, ZPE weights, per-dataset'),\n",
        "    # hand crafted, ZPE weights, per-example\n",
        "    table9_df[\n",
        "        (table9_df.prompt_set == 'dataset') \u0026 (table9_df.weighting == 'softmax_scores') \u0026 (table9_df.debias_mode == 'both') \u0026 (table9_df.img_mean == False)\n",
        "    ].assign(Name='hand crafted, ZPE weights, per-example'),\n",
        "\n",
        "    # pool set, equal average\n",
        "    table9_df[\n",
        "        (table9_df.prompt_set == 'pool') \u0026 (table9_df.weighting == 'equal')\n",
        "    ].assign(Name='pool set, equal average'),\n",
        "    # pool set, ZPE weights, per-dataset\n",
        "    table9_df[\n",
        "        (table9_df.prompt_set == 'pool') \u0026 (table9_df.weighting == 'softmax_scores') \u0026 (table9_df.debias_mode == 'both') \u0026 (table9_df.img_mean == True)\n",
        "    ].assign(Name='pool set, ZPE weights, per-dataset'),\n",
        "    # pool set, ZPE weights, per-example\n",
        "    table9_df[\n",
        "        (table9_df.prompt_set == 'pool') \u0026 (table9_df.weighting == 'softmax_scores') \u0026 (table9_df.debias_mode == 'both') \u0026 (table9_df.img_mean == False)\n",
        "    ].assign(Name='pool set, ZPE weights, per-example'),\n",
        "\n",
        "    # pool set, ZPE weights, per-dataset, no softmax\n",
        "    table9_df[\n",
        "        (table9_df.prompt_set == 'pool') \u0026 (table9_df.weighting == 'scores') \u0026 (table9_df.debias_mode == 'both') \u0026 (table9_df.img_mean == True)\n",
        "    ].assign(Name='pool set, ZPE weights, per-dataset, no softmax'),\n",
        "    # pool set, ZPE weights, per-example, no softmax\n",
        "    table9_df[\n",
        "        (table9_df.prompt_set == 'pool') \u0026 (table9_df.weighting == 'scores') \u0026 (table9_df.debias_mode == 'both') \u0026 (table9_df.img_mean == False)\n",
        "    ].assign(Name='pool set, ZPE weights, per-example, no softmax'),\n",
        "\n",
        "    # pool set, ZPE weights, per-dataset, no norm\n",
        "    table9_df[\n",
        "        (table9_df.prompt_set == 'pool') \u0026 (table9_df.weighting == 'softmax_scores') \u0026 (table9_df.debias_mode == 'none') \u0026 (table9_df.img_mean == True)\n",
        "    ].assign(Name='pool set, ZPE weights, per-dataset, no norm'),\n",
        "    # pool set, ZPE weights, per-example, no norm\n",
        "    table9_df[\n",
        "        (table9_df.prompt_set == 'pool') \u0026 (table9_df.weighting == 'softmax_scores') \u0026 (table9_df.debias_mode == 'none') \u0026 (table9_df.img_mean == False)\n",
        "    ].assign(Name='pool set, ZPE weights, per-example, no norm'),\n",
        "\n",
        "])\n",
        "\n",
        "# Drop columns.\n",
        "table9_df = table9_df[['Name', 'dataset_name', 'top1_acc']]\n",
        "\n",
        "# Pivot the table: one row per `Name`, one column per dataset, holding the top-1 accuracy.\n",
        "table9_df = table9_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)\n",
        "\n",
        "# Drop extra levels.\n",
        "table9_df.columns.name = None\n",
        "table9_df.index.name = None\n",
        "\n",
        "# Remember the per-dataset columns before any aggregate column is added.\n",
        "table9_dataset_cols = list(table9_df.columns)\n",
        "\n",
        "# Create the new columns.\n",
        "table9_df['variants'] = table9_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1)\n",
        "table9_df['fine'] = table9_df[list(ds_fine.keys())].mean(axis=1)\n",
        "# Average over the individual datasets only. Averaging the whole frame here would also\n",
        "# include the 'variants' and 'fine' columns added just above, double-counting them.\n",
        "table9_df['all'] = table9_df[table9_dataset_cols].mean(axis=1)\n",
        "\n",
        "# Keep only the summary columns.\n",
        "table9_df = table9_df[['imagenet', 'variants', 'fine', 'all']]\n",
        "\n",
        "table9_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2Przv8q4uYNu"
      },
      "outputs": [],
      "source": [
        "# Emit the table in four consecutive row groups (rows 0-2, 3-5, 6-7 and 8 onwards),\n",
        "# selected by position.\n",
        "df_to_latex(table9_df.iloc[:3], apply_formatting=True)\n",
        "df_to_latex(table9_df.iloc[3:6], apply_formatting=True)\n",
        "df_to_latex(table9_df.iloc[6:8], apply_formatting=True)\n",
        "df_to_latex(table9_df.iloc[8:], apply_formatting=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sRJ7Q85TA1iC"
      },
      "source": [
        "## Collect per-prompt per-dataset scores (Appendix C)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "De1nSaosK_HX"
      },
      "outputs": [],
      "source": [
        "# Set to True to reset the results file on disk to an empty table.\n",
        "OVERWRITE_RESULTS = False\n",
        "# Set to True to continue from the results already stored on disk.\n",
        "LOAD_RESULTS = False\n",
        "\n",
        "per_dataset_per_prompt_scores_df_name = 'per_dataset_per_prompt_scores.pkl'\n",
        "per_dataset_per_prompt_scores_df_path = os.path.join(base_path, per_dataset_per_prompt_scores_df_name)\n",
        "\n",
        "# Start from an empty table; rows of (dataset_name, prompt, score) are appended later.\n",
        "per_dataset_per_prompt_scores_df = pd.DataFrame(columns=['dataset_name', 'prompt', 'score'])\n",
        "\n",
        "if OVERWRITE_RESULTS:\n",
        "    with tf.io.gfile.GFile(per_dataset_per_prompt_scores_df_path, 'w') as f:\n",
        "        serialized_results = pickle.dumps(per_dataset_per_prompt_scores_df, protocol=4)\n",
        "        f.write(serialized_results)\n",
        "\n",
        "if LOAD_RESULTS:\n",
        "    # NOTE: unpickling can execute arbitrary code; only load a file written by this notebook.\n",
        "    with tf.io.gfile.GFile(per_dataset_per_prompt_scores_df_path, 'rb') as f:\n",
        "        per_dataset_per_prompt_scores_df = pickle.load(f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xo9mEAUwA-UH"
      },
      "outputs": [],
      "source": [
        "# Collect prompts. The pool indices are the same for every dataset, so compute them once.\n",
        "pool_idxs = np.array([all_templates.index(p) for p in pool_templates])\n",
        "\n",
        "for dataset_name, batch_size in ds_list.items():\n",
        "    print(\"dataset_name\", dataset_name)\n",
        "\n",
        "    # Get prompt embeddings.\n",
        "    classnames = multimodal_utils._ZEROSHOT_CLASS_NAMES[config.zeroshot_eval_datasets[dataset_name]['classnames_key']]\n",
        "    classname_idxs = np.array([all_classnames.index(classname) for classname in classnames])\n",
        "    del classnames\n",
        "    # NOTE: `ztxts_all_prompts_all_class` must not be deleted inside this loop. It is indexed\n",
        "    # again on every iteration, so deleting it raises a NameError on the second dataset.\n",
        "    ztxts_all_prompts = ztxts_all_prompts_all_class[:, classname_idxs, :]\n",
        "    del classname_idxs\n",
        "    ztxts_pool = ztxts_all_prompts[pool_idxs, :, :]\n",
        "    del ztxts_all_prompts\n",
        "\n",
        "    # Get image embeddings.\n",
        "    zs_split = load_zeroshot_dataset(config, rng, dataset_name, zs_batch_size=batch_size)\n",
        "    ds_iter = input_utils.start_input_pipeline(zs_split, config.get('prefetch_to_device', 1))\n",
        "    zimgs, _ = compute_image_embeddings(ds_iter, image_resolution)\n",
        "\n",
        "    # Get logits.\n",
        "    pool_logits = get_logits(ztxts_pool, zimgs)\n",
        "    random_logits = get_logits(ztxts_pool, zimgs_laion)  # [n_pretrain, n_prompts, n_classes_ds]\n",
        "    del zimgs\n",
        "    del ztxts_pool\n",
        "\n",
        "    # Score the pool prompts and free the per-dataset logits as soon as they are used.\n",
        "    pool_weights = get_weights(pool_logits, random_logits, debias_mode='both', img_mean=True, frac_test=1.)\n",
        "    del pool_logits\n",
        "    del random_logits\n",
        "    # Softmax over axis 1, then keep the 1-D slice along that axis\n",
        "    # (presumably one weight per prompt -- TODO confirm against `get_weights`).\n",
        "    pool_weights = jax.nn.softmax(pool_weights, axis=1)[0, :, 0]\n",
        "\n",
        "    # Order the prompts from the largest to the smallest absolute weight.\n",
        "    idx = np.argsort(np.abs(pool_weights))\n",
        "    prompts_ordered = np.array(pool_templates)[idx]\n",
        "    weights_list = pool_weights[idx][::-1]\n",
        "    prompts_list = prompts_ordered[::-1]\n",
        "\n",
        "    # Append this dataset's rows to the results table.\n",
        "    triplets = zip(itertools.repeat(dataset_name), prompts_list, weights_list)\n",
        "    per_dataset_per_prompt_scores_df = pd.concat([\n",
        "        per_dataset_per_prompt_scores_df,\n",
        "        pd.DataFrame(triplets, columns=per_dataset_per_prompt_scores_df.columns)\n",
        "    ], ignore_index=True)\n",
        "\n",
        "    # Checkpoint after every dataset so that partial results survive an interruption.\n",
        "    with tf.io.gfile.GFile(per_dataset_per_prompt_scores_df_path, 'w') as f:\n",
        "        f.write(pickle.dumps(per_dataset_per_prompt_scores_df, protocol=4))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cd8FcZDED2mj"
      },
      "outputs": [],
      "source": [
        "# Display the collected (dataset_name, prompt, score) rows.\n",
        "per_dataset_per_prompt_scores_df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a5jF9bGEi6J7"
      },
      "outputs": [],
      "source": [
        "# Export the prompts stored for caltech101 (in their stored order) as the pool-set listing.\n",
        "is_caltech101 = per_dataset_per_prompt_scores_df.dataset_name == 'caltech101'\n",
        "per_dataset_per_prompt_scores_df.loc[is_caltech101, 'prompt'].to_csv(\"pool_set.csv\", index=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jWObbJ8DR5_h"
      },
      "outputs": [],
      "source": [
        "def _print_latex_rows(rows):\n",
        "    \"\"\"Prints one LaTeX table line (rank, prompt, score) for each row of `rows`.\n",
        "\n",
        "    Args:\n",
        "      rows: DataFrame with 'prompt' and 'score' columns. The index value plus one is\n",
        "        printed as the rank.\n",
        "    \"\"\"\n",
        "    for i, row in rows.iterrows():\n",
        "        # Escape the braces in the prompt template so that LaTeX prints them literally.\n",
        "        # The backslashes are doubled because \"\\{\" and \"\\e\" are invalid Python escapes.\n",
        "        prompt = row[\"prompt\"].replace(\"{\", \"\\\\{\").replace(\"}\", \"\\\\}\")\n",
        "        score = row[\"score\"]\n",
        "        # In the f-string, `{{` and `}}` produce the literal braces around the \\emph argument.\n",
        "        print(f\"{i + 1} \u0026 `\\\\emph{{{prompt}}}' \u0026 {score:0.4f} \\\\\\\\\")\n",
        "\n",
        "\n",
        "for dataset_name, group in per_dataset_per_prompt_scores_df.groupby(\"dataset_name\"):\n",
        "    print(dataset_name)\n",
        "\n",
        "    # Renumber the rows within the dataset so that the index can be used as the rank.\n",
        "    ranked = group.reset_index()\n",
        "    _print_latex_rows(ranked.head(10))\n",
        "    print('\\\\multicolumn{3}{c}{\\\\vdots} \\\\\\\\')\n",
        "    _print_latex_rows(ranked.tail(10))\n",
        "\n",
        "    print(\"\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MTEYFUAimCfo"
      },
      "source": [
        "## Make figure 3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PWiBAWFetBx_"
      },
      "outputs": [],
      "source": [
        "# Compute the pool-set logits and prompt scores for ImageNet only. `pool_logits`, `labels`,\n",
        "# `pool_weights` and `pool_weights_softmax` are kept for the accuracy sweeps and plot below.\n",
        "dataset_name = 'imagenet'\n",
        "\n",
        "# Collect prompts.\n",
        "pool_idxs = np.array([all_templates.index(p) for p in pool_templates])\n",
        "\n",
        "# Get prompt embeddings.\n",
        "classnames = multimodal_utils._ZEROSHOT_CLASS_NAMES[config.zeroshot_eval_datasets[dataset_name]['classnames_key']]\n",
        "classname_idxs = np.array([all_classnames.index(classname) for classname in classnames])\n",
        "del classnames\n",
        "ztxts_all_prompts = ztxts_all_prompts_all_class[:, classname_idxs, :]\n",
        "# NOTE: this deletes the text embeddings for all classes, so this cell cannot be re-run\n",
        "# without first recomputing `ztxts_all_prompts_all_class`.\n",
        "del ztxts_all_prompts_all_class\n",
        "del classname_idxs\n",
        "ztxts_pool = ztxts_all_prompts[pool_idxs, :, :]\n",
        "del ztxts_all_prompts\n",
        "del pool_idxs\n",
        "\n",
        "# Get image embeddings.\n",
        "zs_split = load_zeroshot_dataset(config, rng, dataset_name, zs_batch_size=5000)\n",
        "ds_iter = input_utils.start_input_pipeline(zs_split, config.get('prefetch_to_device', 1))\n",
        "zimgs, labels = compute_image_embeddings(ds_iter, image_resolution)\n",
        "\n",
        "# Get logits.\n",
        "pool_logits = get_logits(ztxts_pool, zimgs)\n",
        "random_logits = get_logits(ztxts_pool, zimgs_laion)  # [n_pretrain, n_prompts, n_classes_ds]\n",
        "del zimgs\n",
        "del ztxts_pool\n",
        "\n",
        "# `pool_logits` is deliberately not deleted here: the sweep cells below aggregate it.\n",
        "pool_weights = get_weights(pool_logits, random_logits, debias_mode='both', img_mean=True, frac_test=1.)\n",
        "del random_logits\n",
        "\n",
        "# Keep the scores with and without softmax over axis 1, each reduced to the 1-D slice\n",
        "# along that axis (presumably one value per prompt -- TODO confirm against `get_weights`).\n",
        "pool_weights_softmax = jax.nn.softmax(pool_weights, axis=1)[0, :, 0]\n",
        "pool_weights = pool_weights[0, :, 0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4A48X8zFt7TU"
      },
      "outputs": [],
      "source": [
        "# Fractions of the pool to discard, in steps of 2%.\n",
        "pcts = 1 - np.arange(0.02, 1.01, 0.02)\n",
        "n_prompts = pool_weights.shape[0]\n",
        "\n",
        "# Numbers of prompts to keep: the complementary fractions of the pool, plus a single prompt.\n",
        "prompt_counts = np.floor((1 - pcts) * n_prompts).astype(np.int32)\n",
        "n_prompts_selected = np.array(sorted(set(prompt_counts) | {1}))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M4AOufZatUg_"
      },
      "outputs": [],
      "source": [
        "# Accuracy when only the n highest-|score| prompts are kept, using the raw scores as weights.\n",
        "accs_no_softmax = []\n",
        "for n in n_prompts_selected:\n",
        "  # Sorting by 1/|score| in ascending order puts the largest |score| first.\n",
        "  top_idxs = np.argsort(1/np.abs(pool_weights))[:n]\n",
        "  # Binary mask that zeroes the weight of every prompt outside the top n.\n",
        "  keep_mask = np.zeros((1, pool_n_prompts, 1))\n",
        "  keep_mask[0, top_idxs, 0] = 1\n",
        "  masked_weights = keep_mask * pool_weights[jnp.newaxis, :, jnp.newaxis]\n",
        "  acc, _, _, _ = compute_metrics(labels, agg_logits(pool_logits, weights=masked_weights))\n",
        "  accs_no_softmax.append(acc)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k3x_Dkq-25mQ"
      },
      "outputs": [],
      "source": [
        "# Accuracy when only the n highest-weight prompts are kept, using the softmax weights.\n",
        "accs_softmax = []\n",
        "for n in n_prompts_selected:\n",
        "  # Sorting by 1/|weight| in ascending order puts the largest |weight| first.\n",
        "  top_idxs = np.argsort(1/np.abs(pool_weights_softmax))[:n]\n",
        "  # Binary mask that zeroes the weight of every prompt outside the top n.\n",
        "  keep_mask = np.zeros((1, pool_n_prompts, 1))\n",
        "  keep_mask[0, top_idxs, 0] = 1\n",
        "  masked_weights = keep_mask * pool_weights_softmax[jnp.newaxis, :, jnp.newaxis]\n",
        "  acc, _, _, _ = compute_metrics(labels, agg_logits(pool_logits, weights=masked_weights))\n",
        "  accs_softmax.append(acc)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Yj7JkqURfwwu"
      },
      "outputs": [],
      "source": [
        "fig, (ax1, ax2) = plt.subplots(1, 2, dpi=400, figsize=(line_width, text_width/4.5), tight_layout=True, sharey=True)\n",
        "\n",
        "# One panel per weighting: (axis, accuracies, per-prompt scores, title).\n",
        "panels = [\n",
        "    (ax1, accs_no_softmax, pool_weights, 'no softmax'),\n",
        "    (ax2, accs_softmax, pool_weights_softmax, 'softmax'),\n",
        "]\n",
        "\n",
        "score_axes = []\n",
        "for ax, accs, scores, title in panels:\n",
        "    # Accuracy against the number of prompts kept, with a star on the best accuracy.\n",
        "    ax.plot(n_prompts_selected, accs, '-', alpha=0.8, color='C0', lw=0.8)\n",
        "    best = np.argmax(accs)\n",
        "    ax.plot(n_prompts_selected[best], accs[best], marker='*', color='C0', lw=0.8, ms=3)\n",
        "\n",
        "    # Absolute scores on a secondary y-axis, ordered from largest to smallest.\n",
        "    score_ax = ax.twinx()\n",
        "    order = np.argsort(1/np.abs(scores))\n",
        "    score_ax.plot(np.abs(scores)[order], alpha=0.8, lw=0.8, color='C2')\n",
        "    score_axes.append(score_ax)\n",
        "\n",
        "    ax.set_title(title)\n",
        "    ax.grid(alpha=0.3)\n",
        "\n",
        "# Label only the outer y-axes: accuracy on the far left, score on the far right.\n",
        "ax1.set_ylabel('acc')\n",
        "score_axes[-1].set_ylabel('score')\n",
        "fig.text(0.5, 0.0, 'prompt index', ha='center')\n",
        "\n",
        "# Build the legend by hand because the two curves live on different (twin) axes.\n",
        "legend_elements = [\n",
        "    Line2D([0], [0], alpha=0.8, color='C0', lw=0.8, label='acc'),\n",
        "    Line2D([0], [0], alpha=0.8, color='C2', lw=0.8, label='score'),\n",
        "]\n",
        "ax2.legend(handles=legend_elements, loc='center right')\n",
        "\n",
        "plt.savefig(\"acc_score_curves.pdf\", dpi=400, format=\"pdf\", bbox_inches='tight', pad_inches=0.01)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bpHgRIbr3imT"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/multimodal/Zero_shot_prompt_ensembling_for_text_image_models_results.ipynb",
          "timestamp": 1691470422212
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/ood_clm/Zero_shot_prompt_ensembling_for_text_image_models_results.ipynb?workspaceId=jjren:james_open::citc",
          "timestamp": 1689358685529
        },
        {
          "file_id": "18tnwm_f-1nY7opx22khKibJTRyWexuVv",
          "timestamp": 1689358670966
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
