{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WZR1tYwA2o8C"
      },
      "source": [
        "# RDL Big Paper Plots\n",
        "\n",
        "*Licensed under the Apache License, Version 2.0.*\n",
        "\n",
        "To run this in a public Colab, change the GitHub link: replace github.com with [githubtocolab.com](http://githubtocolab.com).\n",
        "\n",
        "This colab loads raw measurements from disk and analyzes the results.\n",
        "\n",
        "## Choosing optimal hyperparameters\n",
        "We automatically detect hyperparameter sweeps by selecting fields that don't correspond to dataset metrics but that have more than one chosen value. We choose the hyperparameters that achieve the best according a given metric (see `dataset_metric`) after averaging over random seeds. For example, if the model is trained on CIFAR-10, we use CIFAR-10's validation loss.\n",
        "\n",
        "## Plots\n",
        "All plots report the performance of a given model according to its optimal hyperparameters chosen above. When there are runs with multiple seeds, we show the mean and standard deviation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h3lFstNIFvry"
      },
      "outputs": [],
      "source": [
        "from typing import Dict\n",
        "import itertools\n",
        "import os\n",
        "import pickle\n",
        "\n",
        "import colabtools.fileedit\n",
        "from importlib import reload\n",
        "from IPython import display\n",
        "import matplotlib\n",
        "\n",
        "matplotlib.rcParams['font.sans-serif'] = \"Times New Roman\"\n",
        "matplotlib.rcParams['font.family'] = \"sans-serif\"\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.ticker as mtick\n",
        "import seaborn as sns\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "import tensorflow as tf\n",
        "\n",
        "colab_utils = None\n",
        "\n",
        "if colab_utils is None:\n",
        "  !rm -rf uncertainty-baselines\n",
        "  !git clone https://github.com/google/uncertainty-baselines.git\n",
        "  !cp uncertainty-baselines/experimental/plex/colab_utils.py .\n",
        "  import colab_utils\n",
        "\n",
        "%matplotlib inline\n",
        "%config InlineBackend.figure_format = 'retina'\n",
        "matplotlib.rcParams['figure.dpi'] = 1000\n",
        "matplotlib.rcParams['lines.linewidth'] = 1.25"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-qpT4rtuEN14"
      },
      "source": [
        "## Functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T4JuGWk_t6FW"
      },
      "outputs": [],
      "source": [
        "#@title Choosing optimal hyperparameters\n",
        "\n",
        "# The finetuning deterministic jobs use a fixed random seed but different\n",
        "# upstream checkpoints, which themselves correspond to different random seeds.\n",
        "# In this case, we thus marginalize over upstream checkpoints\n",
        "# (`config.model_init`) rather than the random seed.\n",
        "\n",
        "DATASET_METRIC = {\n",
        "    'cifar10': 'val_loss',\n",
        "    'cifar100': 'val_loss',\n",
        "    'imagenet2012': 'val_loss',\n",
        "    'imagenet21k': 'val_loss',\n",
        "    'jft/entity:1.0.0': 'val_loss',\n",
        "    'retina_country': 'in_domain_validation/auroc',\n",
        "    'retina_severity': 'in_domain_validation/auroc',\n",
        "    'imagenet_variants': 'imagenet/nll',\n",
        "}\n",
        "\n",
        "\n",
        "def get_optimal_results(measurements: Dict[str, pd.DataFrame],\n",
        "                        dataset_metric: Dict[str, str] = DATASET_METRIC,\n",
        "                        verbose=True) -\u003e pd.DataFrame:\n",
        "  \"\"\"Returns a dataframe, typically with one result per model type.\n",
        "\n",
        "  A model type may have multiple results that will be averaged over when\n",
        "  plotting (e.g., random seeds).\n",
        "\n",
        "  Args:\n",
        "    measurements: Dictionary of dataframes to obtain best results for.\n",
        "    dataset_metric: Each dataset's metric to tune for, in the format\n",
        "      `{dataset: metric}`.\n",
        "  \"\"\"\n",
        "  results = []\n",
        "\n",
        "  model_to_marginalization_hparams = {\n",
        "      m: 'config.model_init'\n",
        "      for m in ('Det', 'Det I21K', 'DE', 'DE S/32', 'DE B/32', 'DE L/32',\n",
        "                'Det-\u003eDE', '[Det]_4', 'Det-\u003e[Det]_4', 'Det-\u003eBE')\n",
        "  }\n",
        "  model_to_marginalization_hparams.update({\n",
        "      m: 'config.dune_experts.xid_wid'\n",
        "      for m in ('MoE', 'E^3', '[MoE]_4', 'MoE-\u003e[MoE]_4')\n",
        "  })\n",
        "\n",
        "  for k, v in measurements.items():\n",
        "    marginalization_hparams = (colab_utils.random_seed_col(),)\n",
        "    if k in model_to_marginalization_hparams:\n",
        "      marginalization_hparams += (model_to_marginalization_hparams[k],)\n",
        "    for ds in v[colab_utils.dataset_col()].unique():\n",
        "      df = v[v[colab_utils.dataset_col()] == ds]\n",
        "      try:\n",
        "        results.append(\n",
        "            colab_utils.get_tuned_results(\n",
        "                df,\n",
        "                tuning_metric=dataset_metric[ds],\n",
        "                marginalization_hparams=marginalization_hparams,\n",
        "                verbose=verbose))\n",
        "      except KeyError:\n",
        "        print(f'Could not get optimal results for {k}, {ds}.')\n",
        "    print()\n",
        "  return pd.concat(results)\n",
        "\n",
        "\n",
        "def get_optimal_fewshot_results(measurements: Dict[str, pd.DataFrame],\n",
        "                                verbose=True) -\u003e pd.DataFrame:\n",
        "  \"\"\"Returns a dataframe, typically with one result per model type.\n",
        "\n",
        "  A model type may have multiple results that will be averaged over when\n",
        "  plotting (e.g., random seeds).\n",
        "\n",
        "  Args:\n",
        "    measurements: Dictionary of dataframes to obtain best results for.\n",
        "  \"\"\"\n",
        "  results = []\n",
        "  for k, v in measurements.items():\n",
        "    marginalization_hparams = (colab_utils.random_seed_col(),)\n",
        "    marginalization_hparams += ('config.model_init',)\n",
        "    for ds in v[colab_utils.dataset_col()].unique():\n",
        "      df = v[v[colab_utils.dataset_col()] == ds]\n",
        "      try:\n",
        "        # Gets the model and dataset and standard hps.\n",
        "        dataset = colab_utils.get_unique_value(df, colab_utils._DATASET_COL)\n",
        "        model = colab_utils.get_unique_value(df, colab_utils._MODEL_COL)\n",
        "        hps = colab_utils.get_sweeped_hyperparameters(df, marginalization_hparams)\n",
        "\n",
        "        # Finds the best l2 reg for each shot experiment.\n",
        "        best_l2 = {}\n",
        "        non_metric_cols = [c for c in df.columns if '/' not in c]\n",
        "        dfs_optimal = []\n",
        "        for shot in [1, 5, 10, 25]:\n",
        "          tuning_metrics = [c for c in df.columns if c.endswith(f'_{shot}/test_prec@1')]\n",
        "          marginalized_df = df.groupby(hps)[tuning_metrics].agg('mean').reset_index()\n",
        "          reg_ranks = []\n",
        "          for tuning_metric in tuning_metrics:\n",
        "            reg_accus = marginalized_df[tuning_metric].to_numpy()\n",
        "            reg_ranks.append(np.argsort(np.argsort(reg_accus)))\n",
        "          best_l2[shot] = marginalized_df['config.l2_reg'][np.argmax(np.mean(reg_ranks, axis=0))]\n",
        "\n",
        "          for fewshot_ds in colab_utils.default_fewshot_datasets():\n",
        "            ds_shot_specific_metric_cols = [c for c in df.columns if f'{fewshot_ds}_{shot}/' in c]\n",
        "            dfc = df[df['config.l2_reg']==best_l2[shot]][non_metric_cols + ds_shot_specific_metric_cols].copy()\n",
        "            dfc = dfc.rename(columns={m: str(shot) + 'shot_' + m.split('/')[1] for m in ds_shot_specific_metric_cols})\n",
        "            dfc['config.dataset'] = f'few-shot {fewshot_ds}'\n",
        "            dfs_optimal.append(dfc)\n",
        "        results.append(pd.concat(dfs_optimal))\n",
        "      except KeyError:\n",
        "        print(f'Could not get optimal results for {k}, {ds}.')\n",
        "  return pd.concat(results)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mMdPlsI12gec"
      },
      "outputs": [],
      "source": [
        "#@title Pretty printing \n",
        "\n",
        "def pprint(df, models=None, exclude_models=None):\n",
        "  \"\"\"Pretty print dataframe.\n",
        "\n",
        "  Args:\n",
        "    df: Dataframe.\n",
        "    models: Optional list of models to only show. Useful for comparing specific\n",
        "      models to see which performs better (highlighted cells).\n",
        "    exclude_models: Optional list of models to exclude.\n",
        "  \"\"\"\n",
        "  def _rename(m):\n",
        "    m = m.replace('cifar_10h', 'cifar10h')\n",
        "    m = m.replace('places365_small', 'places365')\n",
        "    m = m.replace('imagenet_', 'imagenet-')\n",
        "    m = m.replace('/mean', '')\n",
        "    m = m.replace('/', ' ')\n",
        "    m = m.replace('_', ' ')\n",
        "    m = m.replace('cropped ', '')\n",
        "    m = m.replace('ood', '')\n",
        "    m = m.replace('ece', 'ECE')\n",
        "    m = m.replace('auc', 'AUC')\n",
        "    m = m.replace('auroc', 'AUROC')\n",
        "    m = m.replace('loss', 'NLL')\n",
        "    m = m.replace('negative log likelih', 'NLL')\n",
        "    m = m.replace('nll', 'NLL')\n",
        "    m = m.replace('brier', 'Brier')\n",
        "    m = m.replace('mce', 'mCE')\n",
        "    m = m.replace('pmk', 'p-mk')\n",
        "    return m\n",
        "  def _formatter(metric):\n",
        "    if any(x in metric for x in ['AUROC', 'AUC']):\n",
        "      return '{:.2f}'.format\n",
        "    elif any(x in metric for x in ['prec', 'ECE', 'accuracy']):\n",
        "      return lambda x: '{:.1f}%'.format(x * 100)\n",
        "    elif any(x in metric for x in ['score', 'exaflops', 'tpu days', 'gflops', \n",
        "                                   'ms step']):\n",
        "      return lambda x: '{:.1f}'.format(x)\n",
        "    elif any(x in metric for x in ['NLL', 'Brier']):\n",
        "      return '{:.3f}'.format\n",
        "    else:\n",
        "      return lambda x: x\n",
        "  def _highlight(data, color='#90EE90'):\n",
        "    attr = 'background-color: {}'.format(color)\n",
        "    data = data.replace('%','', regex=True).astype(float)\n",
        "    if any(x in data.name[1] for x in ['NLL', 'ECE', 'Brier', 'mCE',\n",
        "                                       'relative mCE', 'accuracy drop',\n",
        "                                       'accuracy pm-k']):\n",
        "      is_best = data == data.min()\n",
        "    elif any(x in data.name[1] for x in ['exaflops', 'tpu days', 'gflops',\n",
        "                                         'ms step']):\n",
        "      is_best = data == 'asdf'\n",
        "    else:\n",
        "      is_best = data == data.max()\n",
        "    return [attr if v else '' for v in is_best]\n",
        "\n",
        "  df = df.copy()\n",
        "  df = df.rename(columns=_rename)\n",
        "  for c in df:\n",
        "    df[c] = df[c].apply(_formatter(c[0]))\n",
        "\n",
        "  # Swap order of column's multiindex to be dataset first.\n",
        "  df.columns = df.columns.swaplevel(0, 1)\n",
        "  df = df.sort_index(axis=1, level=0)\n",
        "\n",
        "  df = df.T\n",
        "  if models is not None:\n",
        "    df = df[[c for c in df.columns if c in models]]\n",
        "  elif exclude_models is not None:\n",
        "    df = df[[c for c in df.columns if c not in exclude_models]]\n",
        "\n",
        "  return display.display(df.style.apply(_highlight, axis=1))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TBqSMsgf0U2k",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "#@title RETINA\n",
        "REBUILD_RETINA_RESULTS_CACHE = False\n",
        "\n",
        "if REBUILD_RETINA_RESULTS_CACHE:\n",
        "  import os\n",
        "  os.system('pip install wandb')\n",
        "  import wandb\n",
        "\n",
        "# TODO(nband): add grid search results (currently random search).\n",
        "RETINA_SHIFT_AND_UQ_METHOD_TO_WANDB = {\n",
        "  ('aptos', 'deterministic'): 'vit32-finetune-aptos-deterministic-focused-3',\n",
        "  ('aptos', 'batchensemble'): 'vit32-finetune-aptos-batchensemble',\n",
        "  ('severity', 'deterministic'): 'vit32-finetune-severity-deterministic',\n",
        "  ('severity', 'batchensemble'): 'vit32-finetune-severity-batchensemble-focused-1'\n",
        "}\n",
        "\n",
        "RETINA_SHIFTS = ['aptos', 'severity']\n",
        "RETINA_UQ_METHODS = ['deterministic', 'batchensemble']\n",
        "RETINA_UQ_METHOD_TO_DF_NAME = {\n",
        "    'deterministic': 'Det I21K',\n",
        "    'batchensemble': 'BE L/32 (I21K)'\n",
        "}\n",
        "\n",
        "RETINA_SHIFT_TO_METRICS = {\n",
        "  'aptos': [\n",
        "    # In-Domain\n",
        "    'in_domain_test.in_domain_test/accuracy',\n",
        "    'in_domain_test.in_domain_test/negative_log_likelihood',\n",
        "    'in_domain_test.in_domain_test/ece',\n",
        "    'in_domain_test.in_domain_test/retention_auroc_auc',\n",
        "    # OOD\n",
        "    'ood_test.ood_test/accuracy',\n",
        "    'ood_test.ood_test/negative_log_likelihood',\n",
        "    'ood_test.ood_test/ece',\n",
        "    'ood_test.ood_test/retention_auroc_auc'\n",
        "  ],\n",
        "  'severity': [\n",
        "    # In-Domain\n",
        "    'in_domain_test.in_domain_test/accuracy',\n",
        "    'in_domain_test.in_domain_test/negative_log_likelihood',\n",
        "    'in_domain_test.in_domain_test/ece',\n",
        "    'in_domain_test.in_domain_test/retention_auroc_auc',\n",
        "    # OOD\n",
        "    'ood_test.ood_test/accuracy',\n",
        "    'ood_test.ood_test/negative_log_likelihood',\n",
        "    'ood_test.ood_test/ece',\n",
        "    'ood_test.ood_test/retention_accuracy_auc'\n",
        "  ]\n",
        "}\n",
        "RETINA_MODEL_SELECTION_METRIC = 'in_domain_validation.in_domain_validation/auroc'\n",
        "\n",
        "# Split RETINA results into the two distributional shifts: Country Shift and\n",
        "# Severity Shift.\n",
        "\n",
        "SHIFT_MAP = {'aptos': 'country', 'severity': 'severity'}\n",
        "\n",
        "\n",
        "def select_top_model_from_project(project_name):\n",
        "  api = wandb.Api(timeout=100000000)\n",
        "  runs = api.runs(project_name)\n",
        "  print(f'Retrieved run results from Weights \u0026 Biases project {project_name}.')\n",
        "  sweep_history_df = []\n",
        "\n",
        "  # Get all full histories\n",
        "  for run in runs:\n",
        "    run_history_df = pd.DataFrame(run._full_history())\n",
        "\n",
        "    # Add run name\n",
        "    run_history_df['run_name'] = run.name\n",
        "    sweep_history_df.append(run_history_df)\n",
        "\n",
        "  sweep_history_df = pd.concat(sweep_history_df)\n",
        "  sweep_history_df.reset_index(inplace=True)\n",
        "\n",
        "  # Best performing step of the best performing model\n",
        "  top_idx = sweep_history_df[RETINA_MODEL_SELECTION_METRIC].idxmax()\n",
        "  return sweep_history_df.iloc[top_idx]\n",
        "\n",
        "\n",
        "def get_retina_i21k_results_df():\n",
        "  all_results_df = []\n",
        "  for shift in RETINA_SHIFTS:\n",
        "    for uq_method in RETINA_UQ_METHODS:\n",
        "      print(f'Retrieving results from shift {shift}, '\n",
        "            f'uncertainty quantification method {uq_method}.')\n",
        "      wandb_project = RETINA_SHIFT_AND_UQ_METHOD_TO_WANDB[(shift, uq_method)]\n",
        "      model_results = select_top_model_from_project(wandb_project)\n",
        "      result_df = model_results.to_frame().T\n",
        "      result_df['shift'] = shift\n",
        "      result_df['uq_method'] = uq_method\n",
        "      all_results_df.append(result_df)\n",
        "\n",
        "  return pd.concat(all_results_df)\n",
        "\n",
        "\n",
        "def add_retina_i21k_results(retina_results_df, preprocessed_df, shift_map=SHIFT_MAP):\n",
        "  for shift in RETINA_SHIFTS:\n",
        "    for uq_method in RETINA_UQ_METHODS:\n",
        "      print(f'Adding results from shift {shift}, '\n",
        "            f'uncertainty quantification method {uq_method}.')\n",
        "      model_results = retina_results_df[\n",
        "        (retina_results_df['shift'] == shift) \u0026\n",
        "        (retina_results_df['uq_method'] == uq_method)]\n",
        "      n_results = len(model_results)\n",
        "      assert n_results == 1, f'Found {n_results} model results, expected 1.'\n",
        "      model_results = model_results.iloc[0]\n",
        "      metrics = RETINA_SHIFT_TO_METRICS[shift]\n",
        "      for metric in metrics:\n",
        "        df_metric_name = metric.split('.')[1]\n",
        "        per_metric_result = model_results[metric]\n",
        "        shift_df_name = shift_map[shift]\n",
        "        metric_shift_series = preprocessed_df[(\n",
        "          df_metric_name, f'retina_{shift_df_name}')]\n",
        "        metric_shift_series[\n",
        "          RETINA_UQ_METHOD_TO_DF_NAME[uq_method]] = per_metric_result\n",
        "        preprocessed_df[\n",
        "          (df_metric_name, f'retina_{shift_df_name}')] = metric_shift_series\n",
        "\n",
        "  return preprocessed_df\n",
        "\n",
        "if REBUILD_RETINA_RESULTS_CACHE:\n",
        "  # Retrieve RETINA I21K results from Weights \u0026 Biases\n",
        "  retina_i21k_results_df = get_retina_i21k_results_df()\n",
        "\n",
        "  # Store RETINA results in gs bucket\n",
        "  retina_ub_gs_file_path = 'gs://retina-i21k-results-df/retina-i21k-results.tsv'\n",
        "  with tf.io.gfile.GFile(retina_ub_gs_file_path, 'w') as f:\n",
        "    retina_i21k_results_df.to_csv(f, sep='\\t', index=None)\n",
        "\n",
        "\n",
        "def add_distribution_shift_to_retina_ds_name(row):\n",
        "  dataset = str(row['config.dataset'])\n",
        "  if dataset == 'retina':\n",
        "    shift = SHIFT_MAP[str(row['config.distribution_shift'])]\n",
        "    row['config.dataset'] = f'{dataset}_{shift}'\n",
        "\n",
        "  return row\n",
        "\n",
        "def split_retina_results_by_shifts(raw_dict):\n",
        "  for model in raw_dict.keys():\n",
        "    raw_model_df = raw_dict[model]\n",
        "    if not len(raw_model_df[raw_model_df['config.dataset'] == 'retina']):\n",
        "        continue\n",
        "\n",
        "    print(f'Splitting RETINA results for model {model} by distribution shift.')\n",
        "\n",
        "    raw_model_df = raw_model_df.apply(\n",
        "        add_distribution_shift_to_retina_ds_name, axis='columns')\n",
        "    raw_dict[model] = raw_model_df\n",
        "\n",
        "  return raw_dict"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GygKbZFjwiLV"
      },
      "source": [
        "## Load and preprocess measurements"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lFG3GNLoQN3a"
      },
      "outputs": [],
      "source": [
        "load_from_cloud = True\n",
        "if load_from_cloud == True:\n",
        "  from google.colab import auth\n",
        "  auth.authenticate_user()\n",
        "\n",
        "  project_id = 'marginalization-external-xgcp'\n",
        "  !gcloud config set project {project_id}\n",
        "\n",
        "  measurements_path = '/tmp/big-paper-raw-measurements.pkl'\n",
        "  !gsutil cp gs://ub-checkpoints/big-paper-raw-measurements.pkl {measurements_path}\n",
        "\n",
        "  retina_path = '/tmp/retina-i21k-results.tsv'\n",
        "  !gsutil cp gs://retina-i21k-results-df/retina-i21k-results.tsv {retina_path}\n",
        "\n",
        "  fewshot_measurements_path = '/tmp/big-paper-raw-measurements-fewshot.pkl'\n",
        "  !gsutil cp gs://ub-checkpoints/big-paper-raw-measurements-fewshot.pkl {fewshot_measurements_path}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JcnAzeyWsHu8"
      },
      "outputs": [],
      "source": [
        "with tf.io.gfile.GFile(measurements_path, 'rb') as f:\n",
        "  raw_measurements = pickle.load(f)\n",
        "\n",
        "with tf.io.gfile.GFile(retina_path, 'r') as f:\n",
        "  retina_i21k_results_df = pd.read_csv(f, sep='\\t')\n",
        "\n",
        "with tf.io.gfile.GFile(fewshot_measurements_path, 'rb') as f:\n",
        "  fewshot_raw_measurements = pickle.load(f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K6Bc5TFN3ZTk"
      },
      "outputs": [],
      "source": [
        "raw_measurements = split_retina_results_by_shifts(raw_measurements)\n",
        "\n",
        "excluded_keys = [\n",
        "    'DE', 'Det-\u003eDE', 'DE S/32', 'Det-\u003eDE S/32', 'DE B/32', 'Det-\u003eDE B/32',\n",
        "    'DE L/32', 'Det-\u003eDE L/32', 'Det -\u003e BE L/32 (n=2)', 'Det -\u003e BE L/32 (n=4)',\n",
        "    'Det -\u003e BE L/32 (n=8)'\n",
        "]\n",
        "included_measurements = {\n",
        "    k: v for k, v in raw_measurements.items() if k not in excluded_keys\n",
        "}\n",
        "included_measurements['DE'] = raw_measurements['DE L/32'].query(\n",
        "    'ensemble_size == 3')\n",
        "included_measurements['Det-\u003eDE'] = raw_measurements['Det-\u003eDE L/32'].query(\n",
        "    'ensemble_size == 3')\n",
        "# We fetch the deep ensembles of size 4 to compare with MoEs also of size 4.\n",
        "# In that case, we follow the terminology [MoE]_4 and use [Det]_4. We keep DE to\n",
        "# refer to the deep ensemble used everywhere else in the paper (size 3).\n",
        "included_measurements['[Det]_4'] = raw_measurements['DE L/32'].query(\n",
        "    'ensemble_size == 4')\n",
        "included_measurements['[Det]_4'].loc[:, 'model'] = '[Det]_4'\n",
        "\n",
        "included_measurements['Det-\u003e[Det]_4'] = raw_measurements['Det-\u003eDE L/32'].query(\n",
        "    'ensemble_size == 4')\n",
        "included_measurements['Det-\u003e[Det]_4'].loc[:, 'model'] = 'Det-\u003e[Det]_4'\n",
        "\n",
        "measurements = get_optimal_results(included_measurements)\n",
        "\n",
        "df = colab_utils.process_tuned_results(measurements)\n",
        "df = add_retina_i21k_results(\n",
        "    retina_results_df=retina_i21k_results_df, preprocessed_df=df)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tTRMetR24LqG"
      },
      "outputs": [],
      "source": [
        "# Gets tuned fewshot measurements.\n",
        "fewshot_measurements = get_optimal_fewshot_results(fewshot_raw_measurements)\n",
        "\n",
        "# Prepares fewshot measurement to inject to df.\n",
        "relevant_metrics = [c for c in fewshot_measurements.columns if 'shot' in c]\n",
        "fewshot_df = colab_utils.process_tuned_results(fewshot_measurements, relevant_metrics)\n",
        "\n",
        "# Add the fewshot results for the comparison with sparse MoE's.\n",
        "moe_fewshot_df = colab_utils.process_fewshot_for_moe_comparison(included_measurements)\n",
        "\n",
        "fewshot_df = pd.concat((fewshot_df, moe_fewshot_df))\n",
        "\n",
        "# Removes upstream fewshot results.\n",
        "fewshot_metrics_to_del = [m for m in df.columns.levels[0] if 'shot' in m]\n",
        "df = df.drop(columns=fewshot_metrics_to_del, level=0)\n",
        "df.columns = df.columns.remove_unused_levels()\n",
        "\n",
        "# Adds fewshot results from fewshot_df.\n",
        "df = pd.concat([df, fewshot_df], axis=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VuV4-5wzEQiE"
      },
      "source": [
        "## Compute reliability score and generate table"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6kRH3CQGwaQg"
      },
      "outputs": [],
      "source": [
        "datasets = [\n",
        "    'cifar10',\n",
        "    'cifar100',\n",
        "    'imagenet2012',\n",
        "    # 'imagenet_variants',\n",
        "    # 'retina_country',\n",
        "    # 'retina_severity',\n",
        "]\n",
        "datasets += [f'few-shot {d}' for d in colab_utils.default_fewshot_datasets()]\n",
        "\n",
        "scores = colab_utils.compute_score(\n",
        "    df, datasets=datasets, drop_1shot=True,\n",
        "    drop_incomplete_measurements=False) * 100\n",
        "\n",
        "score_cols = [\n",
        "    'score', 'score_prediction', 'score_uncertainty', 'score_adaptation'\n",
        "]\n",
        "display.display(scores[score_cols])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wQS1I1pRD9rm"
      },
      "outputs": [],
      "source": [
        "df_with_scores = df.copy()\n",
        "for column in score_cols:\n",
        "  df_with_scores[column] = scores[column]\n",
        "\n",
        "pprint(\n",
        "    df_with_scores,\n",
        "    # models=['BE L/32', 'Det'],\n",
        "    # exclude_models=['DE', 'Det-\u003eDE'],\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2RySd-77qMBe"
      },
      "outputs": [],
      "source": [
        "# Show a subset of the table's metrics + models\n",
        "metrics = ['score', 'score_prediction', 'score_uncertainty', 'score_adaptation',\n",
        "           'exaflops', 'test_loss', 'tpu_days']\n",
        "models = ['BE L/32', 'Det', 'GP', 'Het', 'BE L/32 (I21K)', 'Det I21K',\n",
        "          'BE-\u003eBE+Het', 'E^3', '[Det]_4', '[MoE]_4']\n",
        "pprint(df_with_scores.loc[models][metrics].rename(\n",
        "    columns={'compute': 'z/compute'}))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VBtgUevvEYrh"
      },
      "source": [
        "## Plot reliability score"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C5qvJhS_YWwE"
      },
      "outputs": [],
      "source": [
        "def pareto_plot(df, x, y, ax, filename=None, **kwargs):\n",
        "  def is_on_pareto_front(p, points, higher_is_better):\n",
        "    if higher_is_better:\n",
        "      return len([\n",
        "          point for point in points if point[0] \u003c= p[0] and point[1] \u003e p[1]\n",
        "      ]) == 0\n",
        "    else:\n",
        "      return len([\n",
        "          point for point in points if point[0] \u003c= p[0] and point[1] \u003c p[1]\n",
        "      ]) == 0\n",
        "  def get_pareto_points(x, y, higher_is_better=True):\n",
        "    points = list(zip(x, y))\n",
        "    frontier = [\n",
        "        p for p in points if is_on_pareto_front(p, points, higher_is_better)\n",
        "    ]\n",
        "    return sorted(frontier, key=lambda x: x[0])\n",
        "  for model, point in df.iterrows():\n",
        "    ann = ax.annotate(\n",
        "        '  ' + model,\n",
        "        xy=(point[x], point[y]),\n",
        "        ha='left',\n",
        "        va='bottom',\n",
        "  )\n",
        "  sns.scatterplot(x=df[x], y=df[y], ax=ax)\n",
        "  pareto_frontier = get_pareto_points(df[x], df[y])\n",
        "  xx, yy = zip(*pareto_frontier)\n",
        "  sns.lineplot(x=xx, y=yy, linestyle='--', ax=ax)\n",
        "  ax.set(xscale='log', **kwargs)\n",
        "  if filename is not None:\n",
        "    plt.tight_layout()\n",
        "    plt.savefig(filename)\n",
        "    colabtools.fileedit.download_file(filename)\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(10.0, 5.0))\n",
        "pareto_plot(\n",
        "    df_with_scores[[x.startswith('BE') for x in df_with_scores.index.values]],\n",
        "    ax=ax,\n",
        "    y='score',\n",
        "    x=('tpu_days', 'compute'),\n",
        "    xlabel='Compute (TPUv3 core days)',\n",
        "    ylabel='Reliability Score',\n",
        "    filename='reliability.png',\n",
        ")\n",
        "\n",
        "fig, axes = plt.subplots(1, 3, figsize=(3.5 * 3, 3.5))\n",
        "pareto_plot(\n",
        "    df_with_scores[[x.startswith('BE') for x in df_with_scores.index.values]],\n",
        "    ax=axes[0],\n",
        "    y='score_prediction',\n",
        "    x=('tpu_days', 'compute'),\n",
        "    xlabel=None,\n",
        "    ylabel=None,\n",
        "    title='Reliability Score (Prediction)',\n",
        ")\n",
        "pareto_plot(\n",
        "    df_with_scores[[x.startswith('BE') for x in df_with_scores.index.values]],\n",
        "    ax=axes[1],\n",
        "    y='score_uncertainty',\n",
        "    x=('tpu_days', 'compute'),\n",
        "    xlabel=None,\n",
        "    ylabel=None,\n",
        "    title='Reliability Score (Uncertainty)',\n",
        ")\n",
        "pareto_plot(\n",
        "    df_with_scores[[x.startswith('BE') for x in df_with_scores.index.values]],\n",
        "    ax=axes[2],\n",
        "    y='score_adaptation',\n",
        "    x=('tpu_days', 'compute'),\n",
        "    xlabel=None,\n",
        "    ylabel=None,\n",
        "    title='Reliability Score (Adaptation)',\n",
        ")\n",
        "filename = 'reliability_components.png'\n",
        "plt.tight_layout()\n",
        "plt.savefig(filename)\n",
        "colabtools.fileedit.download_file(filename)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Oj2HLvvlEg46"
      },
      "source": [
        "## Analyze correlation of metrics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TIaSlPrx8JLB"
      },
      "outputs": [],
      "source": [
        "temp_df = colab_utils.process_tuned_results(\n",
        "    measurements,\n",
        "    relevant_metrics=colab_utils.default_selected_metrics() +\n",
        "    ['training_loss', 'training_prec@1'])\n",
        "datasets = [\n",
        "    'cifar10',\n",
        "    'cifar100',\n",
        "    'imagenet2012',\n",
        "]\n",
        "datasets += [f'few-shot {d}' for d in colab_utils.default_fewshot_datasets()]\n",
        "temp_scores = colab_utils.compute_score(\n",
        "    temp_df,\n",
        "    datasets=datasets,\n",
        "    drop_1shot=True,\n",
        "    drop_incomplete_measurements=True)\n",
        "for column in score_cols:\n",
        "  temp_df[column] = temp_scores[column]\n",
        "\n",
        "# scores correlation matrix\n",
        "columns = ['score', 'score_prediction', 'score_uncertainty', 'score_adaptation']\n",
        "corr_matrix = temp_df[columns]\n",
        "corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]\n",
        "corr_matrix = corr_matrix.corr()\n",
        "display.display(corr_matrix)\n",
        "\n",
        "# upstream test metrics\n",
        "metrics = ['score', 'score_prediction', 'score_uncertainty', 'score_adaptation']\n",
        "corr_matrix = temp_df.corr()[['test_loss', 'test_prec@1']].T.xs(\n",
        "    'jft/entity:1.0.0', level='dataset')\n",
        "corr_matrix = corr_matrix[metrics]\n",
        "corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]\n",
        "display.display(corr_matrix)\n",
        "\n",
        "# imagenet 10-shot. It doesn't correlate well with reliability, mostly due to\n",
        "# it not correlating well surprisingly on other few-shot tasks.\n",
        "corr_matrix = temp_df.corr()[['10shot_prec@1']].T.xs(\n",
        "    'few-shot imagenet', level='dataset')\n",
        "corr_matrix = corr_matrix[metrics]\n",
        "corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]\n",
        "display.display(corr_matrix)\n",
        "\n",
        "# downstream training loss. The correlation is not nearly as tight as on\n",
        "# upstream.\n",
        "corr_matrix = temp_df.corr()[['training_loss']].T\n",
        "corr_matrix = corr_matrix[metrics + ['test_loss']]\n",
        "corr_matrix = corr_matrix.drop(index=('training_loss', 'retina_country'))\n",
        "corr_matrix = corr_matrix.drop(index=('training_loss', 'retina_severity'))\n",
        "corr_matrix = corr_matrix.drop(index=('training_loss', 'imagenet21k'))\n",
        "corr_matrix = corr_matrix.drop(columns=('test_loss', 'imagenet21k'))\n",
        "# Display test loss only for training loss' same downstream dataset. Looking at\n",
        "# cifar10's train loss correlation with I1K's test loss isn't meaningful.\n",
        "test_loss = pd.Series(\n",
        "    np.diag(corr_matrix['test_loss']), index=corr_matrix['test_loss'].index)\n",
        "corr_matrix = corr_matrix.drop(columns='test_loss')\n",
        "corr_matrix['test_loss'] = test_loss\n",
        "corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]\n",
        "display.display(corr_matrix)\n",
        "\n",
        "# Similar to old plot in go/rdl-big-meeting, even generalization gap decreases.\n",
        "# And downstream is not very indicative, but upstream is.\n",
        "temp_df2 = temp_df.copy()\n",
        "for d in temp_df2['test_loss'].columns:\n",
        "  temp_df2['reg_loss',\n",
        "           d] = temp_df2['test_loss', d] - temp_df2['training_loss', d]\n",
        "\n",
        "corr_matrix = temp_df2.corr()[['reg_loss']].T\n",
        "corr_matrix = corr_matrix[metrics + ['training_loss']]\n",
        "corr_matrix = corr_matrix.drop(index=('reg_loss', 'imagenet21k'))\n",
        "display.display(corr_matrix)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-_ShrtSoVQwh"
      },
      "outputs": [],
      "source": [
        "corr_matrix = temp_df.corr()[['test_loss', 'test_prec@1', 'training_loss']].T.xs('jft/entity:1.0.0', level='dataset')\n",
        "\n",
        "# Rename certain task metrics to be under their generic metric name. This way,\n",
        "# we can average values across that metric.\n",
        "corr_matrix.columns = corr_matrix.columns.values\n",
        "corr_matrix.columns = pd.MultiIndex.from_tuples(corr_matrix.rename(columns={\n",
        "    ('imagenet_real_calib_auc', 'imagenet2012'): ('test_calib_auc', 'imagenet_real'),\n",
        "    ('imagenet_real_ece', 'imagenet2012'): ('test_ece', 'imagenet_real'),\n",
        "    ('imagenet_real_loss', 'imagenet2012'): ('test_loss', 'imagenet_real'),\n",
        "    ('imagenet_real_prec@1', 'imagenet2012'): ('test_prec@1', 'imagenet_real'),\n",
        "    ('cifar_10h_calib_auc', 'cifar10'): ('test_calib_auc', 'cifar_10h'),\n",
        "    ('cifar_10h_ece', 'cifar10'): ('test_ece', 'cifar_10h'),\n",
        "    ('cifar_10h_loss', 'cifar10'): ('test_loss', 'cifar_10h'),\n",
        "    ('cifar_10h_prec@1', 'cifar10'): ('test_prec@1', 'cifar_10h'),\n",
        "    ('ood_cifar100_msp_auroc', 'cifar10'): ('msp_auroc', 'cifar10-\u003ecifar100'),\n",
        "    ('ood_cifar10_msp_auroc', 'cifar100'): ('msp_auroc', 'cifar100-\u003ecifar10'),\n",
        "    ('ood_places365_small_msp_auroc', 'imagenet2012'): ('msp_auroc', 'imagenet2012-\u003eplaces365'),\n",
        "    ('ood_svhn_cropped_msp_auroc', 'cifar10'): ('msp_auroc', 'cifar10-\u003esvhn'),\n",
        "    ('ood_svhn_cropped_msp_auroc', 'cifar100'): ('msp_auroc', 'cifar100-\u003esvhn'),\n",
        "}))\n",
        "\n",
        "corr_matrix = corr_matrix.sort_index(axis=1)\n",
        "corr_matrix = corr_matrix.mean(level=0, axis='columns')\n",
        "corr_matrix = abs(corr_matrix)\n",
        "corr_matrix = corr_matrix.reindex(\n",
        "    corr_matrix.mean().sort_values().index, axis=1)\n",
        "for metric in corr_matrix.columns:\n",
        "  if metric.startswith('score') or metric in ['exaflops', 'tpu_days', 'gflops', 'ms_step']:\n",
        "    del corr_matrix[metric]\n",
        "corr_matrix = corr_matrix.T.reset_index()\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(20.0, 5.0))\n",
        "sns.barplot(x='index', y='test_loss', data=corr_matrix)\n",
        "ax.set(xlabel=None)\n",
        "ax.set(ylabel=r'$\\rho(\\cdot,$ test_loss)')\n",
        "\n",
        "filename = 'correlation.png'\n",
        "plt.tight_layout()\n",
        "plt.savefig(filename)\n",
        "colabtools.fileedit.download_file(filename)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_2NIeOSKegBz"
      },
      "source": [
        "## Plot Relative Score and Rankings"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OUGL0zAxOd73"
      },
      "outputs": [],
      "source": [
        "datasets = [\n",
        "    'cifar10',\n",
        "    'cifar100',\n",
        "    'imagenet2012',\n",
        "    # 'imagenet_variants',\n",
        "]\n",
        "datasets += [f'few-shot {d}' for d in colab_utils.default_fewshot_datasets()]\n",
        "rel_scores = colab_utils.compute_score(\n",
        "    df,\n",
        "    drop_1shot=True,\n",
        "    datasets=datasets,\n",
        "    baseline_model='Det',\n",
        "    drop_incomplete_measurements=True)\n",
        "plt.rc('figure', figsize=(20, 20))\n",
        "\n",
        "print(\"Average relative score and ranks across categories\")\n",
        "display.display(rel_scores)\n",
        "\n",
        "print(\"==\" * 50)\n",
        "display.display(df_with_scores)\n",
        "\n",
        "print(\"Full dataframe\")\n",
        "display.display(df)\n",
        "\n",
        "# Plot rank distribution\n",
        "ranks = colab_utils.rank_models(\n",
        "    df, drop_1shot=True, datasets=datasets, drop_incomplete_measurements=True)\n",
        "ax = sns.violinplot(data=ranks.T)\n",
        "ax.set_xticklabels(ax.get_xticklabels(),rotation = 45)\n",
        "ax.set_ylabel('Ranking')\n",
        "print(\"==\" * 50)\n",
        "print(\"Rankings\")\n",
        "display.display(ranks)\n",
        "\n",
        "ranks_by_category = colab_utils.rank_models_by_category(\n",
        "    df, drop_1shot=True, datasets=datasets, drop_incomplete_measurements=False)\n",
        "for key, rank_df in ranks_by_category.items():\n",
        "  plt.figure()\n",
        "  ax = sns.violinplot(data=rank_df.T)\n",
        "  ax.set_xticklabels(ax.get_xticklabels(),rotation = 45)\n",
        "  ax.set_ylabel('Ranking - %s' % key)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ezzOnPxha4XR"
      },
      "outputs": [],
      "source": [
        "#@title Radar plot comparing to SOTA\n",
        "# Note that the plots in the paper are edited\n",
        "# posthoc in illustrator to improve text placement.\n",
        "from matplotlib.lines import Line2D\n",
        "from matplotlib.patches import Patch\n",
        "from matplotlib.ticker import MaxNLocator\n",
        "\n",
        "# Preliminaries\n",
        "fontsize = 24\n",
        "fontfamily = 'sans-serif'\n",
        "radar_filename = \"vision-plex-radar.pdf\"\n",
        "\n",
        "matplotlib.rcParams['figure.dpi'] = 1000\n",
        "matplotlib.rcParams['lines.linewidth'] = 1.25\n",
        "matplotlib.rcParams[\"mathtext.fontset\"] = \"cm\"\n",
        "matplotlib.rcParams['font.family'] = fontfamily\n",
        "matplotlib.rcParams['font.sans-serif'] = 'Times New Roman'\n",
        "matplotlib.rcParams['font.size'] = fontsize\n",
        "matplotlib.rcParams['ps.fonttype'] = 42\n",
        "matplotlib.rcParams['pdf.fonttype'] = 42\n",
        "\n",
        "methods = ['Plex L']\n",
        "colors = ['royalblue', 'orangered', 'tab:blue', 'cornflowerblue', 'r']\n",
        "\n",
        "# Add these manually for now\n",
        "df_with_scores.loc[\"BE L/32\", (\"AL Accuracy\", \"cifar10\")] = .9640 # Margin JFT\n",
        "df_with_scores.loc[\"BE L/32\", (\"AL Accuracy\", \"cifar100\")] = .8739 # Margin JFT\n",
        "df_with_scores.loc[\"BE L/32\", (\"AL Accuracy\", \"places365_small\")] = .8739 # Margin JFT\n",
        "df_with_scores.loc[\"BE L/32\", (\"AL Accuracy\", \"imagenet\")] = 0.771687 # Margin JFT\n",
        "\n",
        "df_with_scores.loc[\"Det\", (\"AL Accuracy\", \"cifar10\")] = .95 # Margin JFT RDL AL Meeting Notes - eyeballed\n",
        "df_with_scores.loc[\"Det\", (\"AL Accuracy\", \"imagenet\")] = 0.73 # RDL AL Meeting Notes - eyeballed\n",
        "df_with_scores.loc[\"Det\", (\"AL Accuracy\", \"cifar100\")] = .65 # Margin JFT\n",
        "\n",
        "df_with_scores.loc[\"SOTA\", (\"cifar_10h_loss\", \"cifar10\")] = 0.26\n",
        "df_with_scores.loc[\"SOTA\", (\"ood_cifar10_msp_auroc\", \"cifar100\")] = .9208\n",
        "df_with_scores.loc[\"SOTA\", (\"ood_cifar100_msp_auroc\", \"cifar10\")] = .9775\n",
        "df_with_scores.loc[\"SOTA\", (\"in_domain_test/accuracy\", \"retina_country\")] = .916\n",
        "\n",
        "# https://arxiv.org/pdf/1911.11132.pdf\n",
        "df_with_scores.loc[\"SOTA\", ('ood_places365_small_msp_auroc', 'imagenet2012')] = 0.79\n",
        "\n",
        "# https://arxiv.org/pdf/2201.07459.pdf (Figure 3, 1k examples)\n",
        "df_with_scores.loc[\"SOTA\", (\"AL Accuracy\", \"cifar10\")] = .56\n",
        "# https://arxiv.org/pdf/2107.14263.pdf (Figure 3, 10k examples)\n",
        "df_with_scores.loc[\"SOTA\", (\"AL Accuracy\", \"cifar100\")] = .40\n",
        "# # https://arxiv.org/pdf/1911.11132.pdf (Table 5, 10k examples)\n",
        "# df_with_scores.loc[\"SOTA\", (\"AL Accuracy\", \"places365\")] = .40\n",
        "\n",
        "# https://arxiv.org/pdf/2111.12880.pdf (Figure 3, 30k examples)\n",
        "df_with_scores.loc[\"SOTA\", ('ood_test/selpred_accuracy_auc', 'retina_country')] = .797\n",
        "df_with_scores.loc[\"BE L/32\", ('ood_test/selpred_accuracy_auc', 'retina_country')] = .848\n",
        "df_with_scores.loc[\"Det\", ('ood_test/selpred_accuracy_auc', 'retina_country')] = .795\n",
        "\n",
        "# Retina Selective Prediction\n",
        "df_with_scores.loc[\"SOTA\", (\"AL Accuracy\", \"imagenet\")] = .54\n",
        "\n",
        "# Subpopulation shift\n",
        "# https://arxiv.org/pdf/2110.14216.pdf (Table 2, 25th %'ile)\n",
        "# CIFAR-10 Plex: .990\n",
        "# CIFAR-10 SOTA: .815\n",
        "df_with_scores.loc[\"BE-\u003eBE+Het\", (\"subpopulation\", \"cifar10\")] = .990\n",
        "df_with_scores.loc[\"SOTA\", (\"subpopulation\", \"cifar10\")] = .815\n",
        "#CIFAR-100 Plex: .931\n",
        "#CIFAR-100 SOTA: .528\n",
        "df_with_scores.loc[\"BE-\u003eBE+Het\", (\"subpopulation\", \"cifar100\")] = .931\n",
        "df_with_scores.loc[\"SOTA\", (\"subpopulation\", \"cifar100\")] = .528\n",
        "\n",
        "radar_df = pd.DataFrame(index=df_with_scores.index.copy())\n",
        "cols = [(\"cifar_10h_loss\", \"cifar10\"),\n",
        "        (\"ood_cifar10_msp_auroc\", \"cifar100\"),\n",
        "        (\"ood_cifar100_msp_auroc\", \"cifar10\"),\n",
        "        ('ood_places365_small_msp_auroc', 'imagenet2012'),\n",
        "        (\"in_domain_test/accuracy\", \"retina_country\"),\n",
        "        ('ood_test/selpred_accuracy_auc', 'retina_country'),\n",
        "        (\"AL Accuracy\", \"cifar10\"),\n",
        "        (\"AL Accuracy\", \"cifar100\"),\n",
        "        (\"AL Accuracy\", \"imagenet\"),\n",
        "        (\"subpopulation\", \"cifar10\"),\n",
        "        (\"subpopulation\", \"cifar100\")]\n",
        "radar_df = df_with_scores.loc[:, cols].copy()\n",
        "radar_df.rename(index={'SOTA': 'SOTA (specialized)'}, inplace=True)\n",
        "\n",
        "def add_default_model_results(df, model, default):\n",
        "  \"\"\"Given, say, BE-\u003eBE+Het, we'd like to default its adaptation #s to BE.\"\"\"\n",
        "  df_copy = df.copy()\n",
        "  df_copy.loc[[model]] = df_copy.loc[[model]].fillna(df_copy.loc[default], axis=0)\n",
        "  return df_copy\n",
        "\n",
        "radar_df = add_default_model_results(radar_df, 'BE-\u003eBE+Het', 'BE L/32')\n",
        "radar_df = radar_df.rename(index={\n",
        "    'BE-\u003eBE+Het': 'Plex L',\n",
        "    'Det': 'None L',\n",
        "})\n",
        "\n",
        "plt.figure(figsize=(20, 20))\n",
        "plt.tight_layout()\n",
        "plt.rc('figure', figsize=(20, 20))\n",
        "ax = plt.subplot(1, 1, 1, polar=True)\n",
        "\n",
        "max_val = 1.0\n",
        "methods.append('SOTA (specialized)')\n",
        "xticklabels = cols.copy()\n",
        "xticklabels = [\"Negative KL\\nCIFAR10H\",\n",
        "               \"OOD AUROC \\nCIFAR100 vs 10\",\n",
        "               \"OOD AUROC  \\n    CIFAR10 vs 100\",\n",
        "               \"OOD AUROC \\nImageNet vs Places365\",\n",
        "               \"Accuracy\\n    RETINA (Country)\",\n",
        "               \"  Selective Prediction\\nRETINA\\n      (OOD Country Shift)\",\n",
        "               \"Active Learning Acc.  \\nCIFAR10 @1k    \",\n",
        "               \"Active Learning Acc.\\n  CIFAR100 @10k\",\n",
        "               \"Active Learning Acc.\\n  ImageNet @30k\",\n",
        "               \"Subpopulation Acc.   \\nCIFAR10\",\n",
        "               \"Subpopulation Acc.\\nCIFAR100 \"]\n",
        "# Ranges for each y-axis corresponding to each ticklabel above.\n",
        "yranges = [(0.2, 0.55), # CIFAR10h LL\n",
        "           (0.8, 1.), (0.9, 1.0), (0.73, 0.9), # OOD\n",
        "           (0.85, 0.92), (0.7, 0.9), # Retina\n",
        "           (0.4, 1.), (0.3, 1.), (0.4, 0.9), #AL\n",
        "           (0.7, 1.), (0.5, 0.95)] # Subpopulation\n",
        "\n",
        "# Replot for each method\n",
        "for i, m in enumerate(methods):\n",
        "  colab_utils.make_radar_plot(radar_df, m, colors[i], max_val, ax,\n",
        "                              xticklabels, yranges, fontfamily=fontfamily)\n",
        "legend_elements = [Patch(facecolor=colors[i], edgecolor='k',\n",
        "                         label=m) for i, m in enumerate(methods)]\n",
        "\n",
        "# Attempt to get labels on top of grid (zorder seems buggy in polar plots)\n",
        "ax.xaxis.set_zorder(0.1)\n",
        "ax.yaxis.set_zorder(0.1)\n",
        "ax.yaxis.grid(True, zorder=1)\n",
        "ax.xaxis.grid(True, zorder=1)\n",
        "ax.grid(True, zorder=0.1)\n",
        "\n",
        "# Create the legend\n",
        "font = font_manager.FontProperties(family=fontfamily,\n",
        "                                   weight='normal',\n",
        "                                   style='normal', size=fontsize)\n",
        "# Use the legend from the text plot.\n",
        "# plt.legend(handles=legend_elements, loc='lower right', bbox_to_anchor=(1.26, 0.025), prop=font)\n",
        "ax.tick_params(axis='x', which='major', pad=120)\n",
        "\n",
        "if radar_filename is not None:\n",
        "  plt.savefig(radar_filename, bbox_inches='tight', pad_inches=0)\n",
        "  colabtools.fileedit.download_file(radar_filename)\n",
        "\n",
        "# The following comment is potentially a much cleaner way to produce this plot.\n",
        "#\n",
        "# radar_df = radar_df.loc[['Plex L', 'SOTA (specialized)']].copy()\n",
        "# radar_df.columns = ['_'.join(col).strip() for col in radar_df.columns.values]\n",
        "# cols = ['_'.join(col).strip() for col in cols]\n",
        "# xtickdict = {col:xticklabels[i] for i, col in enumerate(cols)}\n",
        "\n",
        "# radar_df['model'] = radar_df.index\n",
        "# radar_df = radar_df.melt(\n",
        "#     id_vars=['model'],\n",
        "#     var_name='metric',\n",
        "#     value_name='value')\n",
        "# radar_df = radar_df.rename(columns={('model',''):'model'})\n",
        "# display.display(radar_df)\n",
        "\n",
        "# fig = px.line_polar(radar_df, r='value', theta='metric', color='model', line_close=True, labels=xtickdict)\n",
        "# fig.update_traces(fill='toself')\n",
        "\n",
        "# fig.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JveXH7-t4wx-"
      },
      "source": [
        ""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7MtTqEW-yu7B"
      },
      "source": [
        "# Plotting helpers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5wq6YIte9iIm"
      },
      "outputs": [],
      "source": [
        "#@title Bar plots\n",
        "def plot_metrics(df, train_dataset, metrics):\n",
        "  df = df[df['config.dataset'] == train_dataset].copy()\n",
        "  df = df[['model'] + metrics].melt(\n",
        "      id_vars='model', var_name='metric', value_name='value')\n",
        "  grid = sns.catplot(\n",
        "      col='metric', data=df, y='value', kind='bar', sharey=False,\n",
        "      x='model')\n",
        "  for ax in grid.axes.flat:\n",
        "    ax.set_xticklabels(\n",
        "        ax.get_xticklabels(), rotation=40, horizontalalignment=\"right\"\n",
        "    )\n",
        "\n",
        "\n",
        "def plot_in_distribution(df, train_dataset, split):\n",
        "  metrics = [f'{split}_{m}' for m in ['loss', 'prec@1', 'ece', 'calib_auc']]\n",
        "  plot_metrics(df, train_dataset, metrics)\n",
        "\n",
        "def pareto_plot_in_distribution_subfigs(df, train_dataset, split, axes, xmetric):\n",
        "  metrics = [f'{split}_{m}' for m in ['prec@1', 'loss']]\n",
        "  pareto_plot_subfigs(df, metrics, train_dataset, axes=axes, xmetric=xmetric)\n",
        "\n",
        "def plot_ood(df, train_dataset):\n",
        "  df = df[df['config.dataset'] == train_dataset].copy()\n",
        "  if train_dataset == 'imagenet2012':\n",
        "    datasets = {'places365_small'}\n",
        "    metrics = ['msp', 'entropy', 'mlogit']\n",
        "  else:\n",
        "    datasets = set(['svhn_cropped', 'cifar100', 'cifar10']) - {train_dataset}\n",
        "    metrics = ['msp', 'entropy', 'mlogit', 'maha', 'rmaha']\n",
        "  cols = [\n",
        "      f'ood_{ds}_{m}_auroc' for (ds, m) in itertools.product(datasets, metrics)\n",
        "  ]\n",
        "  cols = list(set(cols).intersection(df.columns))\n",
        "  df = df[['model'] + cols]\n",
        "  df = df.melt(id_vars='model', var_name='metric', value_name='AUROC')\n",
        "  df['dataset'] = df['metric'].apply(lambda x: x.split('_')[1])\n",
        "  df['metric'] = df['metric'].apply(lambda x: x.split('_')[-2])\n",
        "\n",
        "  sns.catplot(\n",
        "      data=df, x='metric', y='AUROC', hue='model', kind='bar', col='dataset')\n",
        "  plt.ylim((0.5, 1))\n",
        "\n",
        "\n",
        "def plot_reclassified(df, train_dataset):\n",
        "  ds = 'imagenet_real' if train_dataset == 'imagenet2012' else 'cifar_10h'\n",
        "  metrics = [f'{ds}_{m}' for m in ['loss', 'prec@1', 'ece', 'calib_auc']]\n",
        "  plot_metrics(df, train_dataset, metrics)\n",
        "\n",
        "\n",
        "def _get_imagenet_shifts_metrics(eval_dataset):\n",
        "  base_metrics = ['accuracy', 'ece', 'nll', 'brier']\n",
        "  metrics = [f'{eval_dataset}/{m}' for m in base_metrics]\n",
        "  if eval_dataset == 'imagenet_c':\n",
        "    metrics = [f'{m}/mean' for m in metrics]\n",
        "  return metrics\n",
        "\n",
        "\n",
        "def _get_imagenet_robustness_metrics(eval_dataset):\n",
        "  base_metrics = ['accuracy_pmk', 'anchor_accuracy', 'accuracy_drop']\n",
        "  return [f'{eval_dataset}/{m}' for m in base_metrics]\n",
        "\n",
        "\n",
        "def plot_imagenet_shifts(df, eval_dataset):\n",
        "  metrics = _get_imagenet_shifts_metrics(eval_dataset)\n",
        "  plot_metrics(df, 'imagenet_variants', metrics)\n",
        "\n",
        "\n",
        "def plot_imagenet_robustness(df, eval_dataset):\n",
        "  metrics = _get_imagenet_robustness_metrics(eval_dataset)\n",
        "  plot_metrics(df, 'imagenet_variants', metrics)\n",
        "\n",
        "\n",
        "def pareto_plot_imagenet_shifts(df, eval_dataset):\n",
        "  metrics = _get_imagenet_shifts_metrics(eval_dataset)\n",
        "  pareto_plot(df, train_dataset='imagenet_variants', metrics=metrics)\n",
        "\n",
        "def pareto_plot_imagenet_shift_subplots(df, eval_dataset, axes, xmetric):\n",
        "  metrics = _get_imagenet_shifts_metrics(eval_dataset)\n",
        "  metrics = [m for m in metrics if 'ece' not in m]\n",
        "  pareto_plot_subfigs(df, train_dataset='imagenet_variants', metrics=metrics, axes=axes, xmetric=xmetric)\n",
        "\n",
        "def pareto_plot_imagenet_robustness(df, eval_dataset):\n",
        "  metrics = _get_imagenet_robustness_metrics(eval_dataset)\n",
        "  pareto_plot_subfigs(df, train_dataset='imagenet_variants', metrics=metrics)\n",
        "\n",
        "def pareto_plot_imagenet_robustness_subplots(df, eval_dataset, axes):\n",
        "  metrics = _get_imagenet_robustness_metrics(eval_dataset)\n",
        "  metrics = [m for m in metrics if 'accuracy_drop' not in m]\n",
        "  pareto_plot_subfigs(df, train_dataset='imagenet_variants', metrics=metrics, axes=axes)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pB9ohPCduqdy"
      },
      "outputs": [],
      "source": [
        "#@title Pareto plots\n",
        "\n",
        "\n",
        "def is_on_pareto_front(p, points, higher_is_better):\n",
        "  if higher_is_better:\n",
        "    return len([\n",
        "        point for point in points if point[0] \u003c= p[0] and point[1] \u003e p[1]\n",
        "    ]) == 0\n",
        "  else:\n",
        "    return len([\n",
        "        point for point in points if point[0] \u003c= p[0] and point[1] \u003c p[1]\n",
        "    ]) == 0\n",
        "\n",
        "\n",
        "def get_pareto_points(x, y, higher_is_better):\n",
        "  points = list(zip(x, y))\n",
        "  frontier = [\n",
        "      p for p in points if is_on_pareto_front(p, points, higher_is_better)\n",
        "  ]\n",
        "  return sorted(frontier, key=lambda x: x[0])\n",
        "\n",
        "\n",
        "def plot_fn(data, x, y, ax=None, annotate_names=False, **kws):\n",
        "  if ax is None:\n",
        "    ax = plt.gca()\n",
        "  sns.scatterplot(\n",
        "      data=data,\n",
        "      x=x,\n",
        "      y=y,\n",
        "      hue='model',\n",
        "      markers=True,\n",
        "      style='model',\n",
        "      s=300,\n",
        "      ax=ax,\n",
        "      alpha=0.8)\n",
        "  if annotate_names:\n",
        "    for _, point in data.iterrows():\n",
        "      ann = ax.annotate(\n",
        "          '  ' + point['model'],\n",
        "          xy=(point[x], point[y]),\n",
        "          ha='left',\n",
        "          va='bottom',\n",
        "          fontsize=16)\n",
        "\n",
        "  metric = data['metric'].iloc[0]\n",
        "  higher_is_better = colab_utils.is_higher_better(metric)\n",
        "  pareto_frontier = get_pareto_points(\n",
        "      data[x], data[y], higher_is_better=higher_is_better)\n",
        "  xx, yy = zip(*pareto_frontier)\n",
        "  sns.lineplot(x=xx, y=yy, linestyle='--', ax=ax)\n",
        "  ax.set_ylabel(metric)\n",
        "\n",
        "\n",
        "def pareto_plot(df,\n",
        "                metrics,\n",
        "                train_dataset=None,\n",
        "                xmetric='num_params',\n",
        "                xlabel='Log # Params'):\n",
        "  df = df[df['config.dataset'] == train_dataset].copy()\n",
        "  df = df.groupby(['model', 'config.dataset',\n",
        "                   xmetric])[metrics].apply(np.mean).reset_index()\n",
        "  df = df.melt(\n",
        "      id_vars=['model', 'config.dataset', xmetric],\n",
        "      var_name='metric',\n",
        "      value_name='value')\n",
        "\n",
        "  g = sns.FacetGrid(data=df, col='metric', sharey=False, size=5)\n",
        "  g.map_dataframe(plot_fn, x=xmetric, y='value')\n",
        "  g.set_xlabels(xlabel)\n",
        "  g.set(xscale='log')\n",
        "\n",
        "\n",
        "def pareto_plot_subfigs(df,\n",
        "                        metrics,\n",
        "                        train_dataset=None,\n",
        "                        xmetric='num_params',\n",
        "                        xlabel='Log # Params',\n",
        "                        axes=None):\n",
        "  \"\"\"Plot subfigures corresponding to pareto frontier plots for each of metrics\n",
        "\n",
        "  in `metrics` on the y-axis and `xmetric` on the x-axis.\n",
        "\n",
        "  Allows for passing in an array of axes handles in `axes` so that the plots\n",
        "  can fill in subfigures (in which case axes must be the same length as\n",
        "  metrics).\n",
        "  \"\"\"\n",
        "  df = df.groupby(['model', 'config.dataset',\n",
        "                   xmetric])[metrics].mean().reset_index()\n",
        "  df = df.melt(\n",
        "      id_vars=['model', 'config.dataset', xmetric],\n",
        "      var_name='metric',\n",
        "      value_name='value')\n",
        "  for i in range(len(metrics)):\n",
        "    if axes is not None:\n",
        "      ax = axes[i]\n",
        "    else:\n",
        "      ax = plt.subplot(len(metrics), 1, i + 1)\n",
        "    sub_df = df[df['config.dataset'] == train_dataset].copy()\n",
        "    sub_df = sub_df[sub_df['metric'] == metrics[i]].copy()\n",
        "    plot_fn(sub_df, x=xmetric, y='value', ax=ax)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cBld21j5yx4I"
      },
      "source": [
        "# Results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "186VNwRThIhE"
      },
      "outputs": [],
      "source": [
        "#@title Upstream JFT\n",
        "plot_metrics(measurements,\n",
        "             train_dataset='jft/entity:1.0.0',\n",
        "             metrics=['val_loss', 'val_prec@1', 'a/imagenet_10shot'])\n",
        "pareto_plot(\n",
        "    measurements,\n",
        "    train_dataset='jft/entity:1.0.0',\n",
        "    metrics=['val_loss', 'val_prec@1', 'a/imagenet_10shot'],\n",
        ")\n",
        "pareto_plot(\n",
        "    measurements,\n",
        "    train_dataset='jft/entity:1.0.0',\n",
        "    metrics=['val_loss', 'val_prec@1', 'a/imagenet_10shot'],\n",
        "    xmetric='tpu_days',\n",
        "    xlabel='Compute (TPUv3 core days)',\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xl7rOkhsuFm0"
      },
      "source": [
        "## Cifar 10"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mzz9lwisoKL3"
      },
      "outputs": [],
      "source": [
        "#@title In-distribution\n",
        "plot_in_distribution(measurements, train_dataset='cifar10', split='test')\n",
        "g = pareto_plot(\n",
        "    measurements,\n",
        "    train_dataset='cifar10',\n",
        "    metrics=['test_loss', 'test_prec@1', 'test_ece', 'test_calib_auc'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Oote6WoS_QOd"
      },
      "outputs": [],
      "source": [
        "#@title Cifar10h\n",
        "plot_reclassified(measurements, train_dataset='cifar10')\n",
        "g = pareto_plot(\n",
        "    measurements,\n",
        "    train_dataset='cifar10',\n",
        "    metrics=['cifar_10h_loss', 'cifar_10h_prec@1', 'cifar_10h_ece', 'cifar_10h_calib_auc'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A0Afa3nr-8ri"
      },
      "outputs": [],
      "source": [
        "#@title OOD\n",
        "plot_ood(measurements, train_dataset='cifar10')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eZ2bY0aPlQ5e"
      },
      "source": [
        "## Cifar100"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-Q_wtXR-9CU0"
      },
      "outputs": [],
      "source": [
        "#@title In-distribution\n",
        "plot_in_distribution(measurements, train_dataset='cifar100', split='test')\n",
        "g = pareto_plot(\n",
        "    measurements,\n",
        "    train_dataset='cifar100',\n",
        "    metrics=['test_loss', 'test_prec@1', 'test_ece', 'test_calib_auc'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j8jKADGx_ZyV"
      },
      "outputs": [],
      "source": [
        "#@title OOD\n",
        "plot_ood(measurements, train_dataset='cifar100')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Kd4Ub0YclSmZ"
      },
      "source": [
        "## Imagenet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gokyk90V_nYw"
      },
      "outputs": [],
      "source": [
        "#@title In-distribution\n",
        "plot_in_distribution(measurements, train_dataset='imagenet2012', split='test')\n",
        "g = pareto_plot(\n",
        "    measurements,\n",
        "    train_dataset='imagenet2012',\n",
        "    metrics=['test_loss', 'test_prec@1', 'test_ece', 'test_calib_auc'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tcx57p3a_r7d"
      },
      "outputs": [],
      "source": [
        "#@title Imagenet Real\n",
        "plot_reclassified(measurements, train_dataset='imagenet2012')\n",
        "g = pareto_plot(\n",
        "    measurements,\n",
        "    train_dataset='imagenet2012',\n",
        "    metrics=[\n",
        "        'imagenet_real_loss', 'imagenet_real_prec@1', 'imagenet_real_ece',\n",
        "        'imagenet_real_calib_auc'\n",
        "    ])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wRbkpaT3P3iV"
      },
      "outputs": [],
      "source": [
        "#@title ImageNet Shifts \u0026 Robustness (ImageNet-C, etc.)\n",
        "shifts_filename = 'imagenet_shifts.pdf'\n",
        "robustness_filename = 'imagenet_yttb_robustness.pdf'\n",
        "fontsize = 32\n",
        "fontfamily = 'serif'\n",
        "xmetric = 'num_params'\n",
        "\n",
        "sns.reset_orig()\n",
        "sns.set_theme()\n",
        "matplotlib.rcParams['figure.dpi'] = 1000\n",
        "matplotlib.rcParams['lines.linewidth'] = 1.25\n",
        "sns.set_style('white')\n",
        "matplotlib.rcParams['font.family'] = fontfamily\n",
        "matplotlib.rcParams['font.size'] = fontsize\n",
        "ytickfontparams = {'fontsize': fontsize * .8, 'fontweight': 'normal'}\n",
        "\n",
        "# Don't keep E^3 or MoEs\n",
        "sub_df = measurements.copy()\n",
        "sub_df = sub_df.drop(sub_df[sub_df['model'].str.contains(\n",
        "    'MoE', case=False)].index)\n",
        "sub_df = sub_df.drop(sub_df[sub_df['model'].str.contains('E\\^3')].index)\n",
        "\n",
        "sub_df['model'] = sub_df['model'].replace({\n",
        "    'BE-\u003eBE+Het': 'Plex L',\n",
        "    'Det': 'None L',\n",
        "})\n",
        "\n",
        "# We're not currently including TPU days since too many numbers are missing\n",
        "# in the df.  However, the below fills in some of the values, which might be\n",
        "# revisited.\n",
        "# # Fix TPU days to match pretraining time\n",
        "# sub_df.loc[sub_df['model'].str.contains('Det-\u003e', case=False),\n",
        "#            'tpu_days'] = 107.29\n",
        "# sub_df.loc[sub_df['model'] == 'Plex L', 'tpu_days'] = 119.12\n",
        "\n",
        "# # Populate all columns for tpu_days\n",
        "# cols = sub_df[sub_df['tpu_days'].notna()][['model', 'tpu_days']]\n",
        "# for ind, c in cols.iterrows():\n",
        "#   sub_df.loc[sub_df['model'] == c['model'], 'tpu_days'] = c['tpu_days']\n",
        "\n",
        "sub_df.loc[:, 'model'] = sub_df['model'].str.replace(r'Det', 'None', regex=True)\n",
        "models_in_imagenet_shifts_fig = sub_df['model']\n",
        "\n",
        "variants = ['imagenet_c', 'imagenet_a', 'imagenet_r', 'imagenet_v2']\n",
        "fig, axes = plt.subplots(3, 4, figsize=(20, 15))\n",
        "axes = np.array(axes).T\n",
        "for i, ds in enumerate(variants):\n",
        "  pareto_plot_imagenet_shift_subplots(sub_df, ds, axes[i], xmetric)\n",
        "\n",
        "# Set titles along columns\n",
        "axes[0, 0].set_title('ImageNet-C', fontsize=fontsize)\n",
        "axes[1, 0].set_title('ImageNet-A', fontsize=fontsize)\n",
        "axes[2, 0].set_title('ImageNet-R', fontsize=fontsize)\n",
        "axes[3, 0].set_title('ImageNet-V2', fontsize=fontsize)\n",
        "\n",
        "# Set labels along columns\n",
        "for ax in axes.flatten():\n",
        "  ax.set_xlabel('')\n",
        "for ax in axes[:, 2]:\n",
        "  ax.set_xlabel('# Params', fontsize=fontsize)\n",
        "\n",
        "# Set labels along rows\n",
        "for ax in axes.flatten():\n",
        "  ax.set_ylabel('')\n",
        "axes[0, 0].set_ylabel('Accuracy', fontsize=fontsize)\n",
        "axes[0, 1].set_ylabel('NLL', fontsize=fontsize)\n",
        "axes[0, 2].set_ylabel('Brier', fontsize=fontsize)\n",
        "\n",
        "# Remove axes legends and make one for the figure\n",
        "for ax in axes.flatten():\n",
        "  ax.get_legend().remove()\n",
        "\n",
        "handles, labels = ax.get_legend_handles_labels()\n",
        "legend = fig.legend(\n",
        "    handles,\n",
        "    labels,\n",
        "    loc='lower center',\n",
        "    ncol=len(labels) // 2 + 1,\n",
        "    labelspacing=0.3,\n",
        "    handletextpad=0.1,\n",
        "    borderpad=0.3,\n",
        "    fontsize=fontsize * .7,\n",
        "    markerscale=3)\n",
        "legend.get_frame().set_linewidth(matplotlib.rcParams['axes.linewidth'])\n",
        "legend.get_frame().set_edgecolor('lightgray')\n",
        "\n",
        "if shifts_filename is not None:\n",
        "  plt.savefig(shifts_filename)\n",
        "  colabtools.fileedit.download_file(shifts_filename)\n",
        "\n",
        "fig, axes = plt.subplots(2, 2, figsize=(10, 10))\n",
        "axes = np.array(axes)\n",
        "pareto_plot_imagenet_robustness_subplots(sub_df, 'imagenet_vid_robust', axes[0])\n",
        "pareto_plot_imagenet_robustness_subplots(sub_df, 'ytbb_robust', axes[1])\n",
        "for ax in axes.flatten():\n",
        "  ax.set_ylabel('')\n",
        "  ax.set_xlabel('')\n",
        "  ax.get_legend().remove()\n",
        "for ax in axes[1, :]:\n",
        "  ax.set_xlabel('# Params', fontsize=fontsize)\n",
        "axes[0, 0].set_title('ImageNet Vid Robust', fontsize=fontsize)\n",
        "axes[0, 1].set_title('YTBB Robust', fontsize=fontsize)\n",
        "axes[0, 0].set_ylabel('Accuracy PMK', fontsize=fontsize)\n",
        "axes[1, 0].set_ylabel('Anchor Accuracy', fontsize=fontsize)\n",
        "handles, labels = ax.get_legend_handles_labels()\n",
        "legend = fig.legend(\n",
        "    handles,\n",
        "    labels,\n",
        "    loc='center right',\n",
        "    ncol=1,\n",
        "    labelspacing=0.3,\n",
        "    handletextpad=0.1,\n",
        "    borderpad=0.3,\n",
        "    fontsize=fontsize * .7,\n",
        "    markerscale=3,\n",
        "    bbox_to_anchor=(1.25, 0.5))\n",
        "legend.get_frame().set_linewidth(matplotlib.rcParams['axes.linewidth'])\n",
        "legend.get_frame().set_edgecolor('lightgray')\n",
        "\n",
        "if robustness_filename is not None:\n",
        "  plt.savefig(robustness_filename, bbox_inches='tight', pad_inches=0)\n",
        "  colabtools.fileedit.download_file(robustness_filename)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q4ABd3HmGtwL"
      },
      "outputs": [],
      "source": [
        "#@title In distribution figures\n",
        "in_dist_filename = 'in_distribution_robust_generalization.pdf'\n",
        "fontsize = 32\n",
        "fontfamily = 'serif'\n",
        "xmetric = 'num_params'\n",
        "\n",
        "pd.set_option('display.max_columns', None)\n",
        "pd.set_option('display.max_rows', None)\n",
        "\n",
        "sns.reset_orig()\n",
        "sns.set_theme()\n",
        "matplotlib.rcParams['figure.dpi'] = 1000\n",
        "matplotlib.rcParams['lines.linewidth'] = 1.25\n",
        "sns.set_style('white')\n",
        "matplotlib.rcParams['font.family'] = fontfamily\n",
        "matplotlib.rcParams['font.size'] = fontsize\n",
        "ytickfontparams = {'fontsize': fontsize * .8, 'fontweight': 'normal'}\n",
        "\n",
        "# Don't keep E^3 or MoEs\n",
        "sub_df = measurements.copy()\n",
        "sub_df = sub_df.drop(sub_df[sub_df['model'].str.contains(\n",
        "    'MoE', case=False)].index).copy()\n",
        "sub_df = sub_df.drop(sub_df[sub_df['model'].str.contains('E\\^3')].index).copy()\n",
        "\n",
        "sub_df['model'] = sub_df['model'].replace({\n",
        "    'BE-\u003eBE+Het': 'Plex L',\n",
        "    'Det': 'None L',\n",
        "})\n",
        "\n",
        "# Make the set of models shown consistent with the imagenet shifts\n",
        "sub_df = sub_df[sub_df['model'].isin(models_in_imagenet_shifts_fig)]\n",
        "\n",
        "variants = ['cifar10', 'cifar100', 'imagenet2012']\n",
        "fig, axes = plt.subplots(2, 3, figsize=(20, 15))\n",
        "axes = np.array(axes).T\n",
        "for i, ds in enumerate(variants):\n",
        "  pareto_plot_in_distribution_subfigs(sub_df, ds, 'test', axes[i], xmetric)\n",
        "\n",
        "# Set titles along columns\n",
        "axes[0, 0].set_title('Cifar-10', fontsize=fontsize)\n",
        "axes[1, 0].set_title('Cifar-100', fontsize=fontsize)\n",
        "axes[2, 0].set_title('ImageNet', fontsize=fontsize)\n",
        "\n",
        "# Set labels along columns\n",
        "for ax in axes.flatten():\n",
        "  ax.set_xlabel('')\n",
        "for ax in axes[:, 1]:\n",
        "  ax.set_xlabel('# Params', fontsize=fontsize)\n",
        "\n",
        "# Set labels along rows\n",
        "for ax in axes.flatten():\n",
        "  ax.set_ylabel('')\n",
        "axes[0, 0].set_ylabel('Accuracy', fontsize=fontsize)\n",
        "axes[0, 1].set_ylabel('NLL', fontsize=fontsize)\n",
        "\n",
        "# Remove axes legends and make one for the figure\n",
        "for ax in axes.flatten():\n",
        "  try:\n",
        "    ax.get_legend().remove()\n",
        "  except AttributeError:\n",
        "    pass\n",
        "\n",
        "handles, labels = ax.get_legend_handles_labels()\n",
        "legend = fig.legend(\n",
        "    handles,\n",
        "    labels,\n",
        "    loc='center right',\n",
        "    ncol=1,\n",
        "    labelspacing=0.3,\n",
        "    handletextpad=0.1,\n",
        "    borderpad=0.3,\n",
        "    fontsize=fontsize * .7,\n",
        "    markerscale=3,\n",
        "    bbox_to_anchor=(1.08, 0.5))\n",
        "legend.get_frame().set_linewidth(matplotlib.rcParams['axes.linewidth'])\n",
        "legend.get_frame().set_edgecolor('lightgray')\n",
        "\n",
        "if in_dist_filename is not None:\n",
        "  plt.savefig(in_dist_filename, bbox_inches='tight', pad_inches=0.25)\n",
        "  colabtools.fileedit.download_file(in_dist_filename)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z_2ETz1XVUkt"
      },
      "source": [
        "## Deep ensemble analysis"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bVVWFbeJLLGs"
      },
      "outputs": [],
      "source": [
        "matplotlib.rcParams['font.family'] = 'serif'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PAn_FJHHdrmT"
      },
      "outputs": [],
      "source": [
        "def get_ensemble_scaling_measurements():\n",
        "  DE_NAMES = ['DE S/32','DE B/32','DE L/32']\n",
        "  de_measurements = get_optimal_results({\n",
        "      k: v for k, v in raw_measurements.items() if k in DE_NAMES\n",
        "  })\n",
        "\n",
        "  de_measurements = de_measurements[de_measurements['model'].isin(DE_NAMES)]\n",
        "  de_measurements['model'] = de_measurements.apply(\n",
        "      lambda x: f'{x.model}_{int(x.ensemble_size)}', axis=1)\n",
        "  de_measurements = de_measurements.drop(\n",
        "      columns=list(colab_utils.compute_metrics()), errors='ignore')\n",
        "\n",
        "  relevant_metrics = colab_utils.default_selected_metrics() + ['num_params']\n",
        "  return colab_utils.process_tuned_results(\n",
        "      de_measurements, relevant_metrics=relevant_metrics)\n",
        "\n",
        "de_results = get_ensemble_scaling_measurements()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v0QTY89OdQnn"
      },
      "outputs": [],
      "source": [
        "datasets = [\n",
        "    'cifar10',\n",
        "    'cifar100',\n",
        "    'imagenet2012',\n",
        "    # 'imagenet_variants',\n",
        "]\n",
        "datasets += [f'few-shot {d}' for d in colab_utils.default_fewshot_datasets()]\n",
        "\n",
        "ensemble_meas = {\n",
        "    'DE': raw_measurements['DE L/32'].query('ensemble_size==3'),\n",
        "    'Det-\u003eDE': raw_measurements['Det-\u003eDE L/32'].query('ensemble_size==3'),\n",
        "    'Det': raw_measurements['Det'],\n",
        "}\n",
        "\n",
        "ensemble_meas = get_optimal_results(ensemble_meas, verbose=False).drop(\n",
        "    columns=list(colab_utils.compute_metrics()), errors='ignore')\n",
        "df = colab_utils.process_tuned_results(ensemble_meas)\n",
        "\n",
        "display.display(\n",
        "    colab_utils.compute_score(\n",
        "        df,\n",
        "        datasets=datasets,\n",
        "        drop_1shot=True,\n",
        "        drop_incomplete_measurements=False).loc[\n",
        "            ['Det'],\n",
        "            ['score_prediction', 'score_uncertainty', 'score_adaptation']])\n",
        "\n",
        "ensemble_scores = colab_utils.compute_score(\n",
        "    df,\n",
        "    datasets=datasets,\n",
        "    drop_1shot=True,\n",
        "    drop_incomplete_measurements=False,\n",
        "    baseline_model='Det')\n",
        "ensemble_scores = ensemble_scores[[\n",
        "    'score_prediction', 'score_uncertainty', 'score_adaptation'\n",
        "]]\n",
        "\n",
        "\n",
        "def get_improvement(value):\n",
        "  improvement = (value - 1)\n",
        "  sign = '+' if improvement \u003e= 0 else '-'\n",
        "  return f'{sign}{improvement * 100:.2f}%'\n",
        "\n",
        "\n",
        "for col in ['prediction', 'uncertainty', 'adaptation']:\n",
        "  ensemble_scores[f'Rel. improvement ({col})'] = ensemble_scores[\n",
        "      f'score_{col}'].apply(get_improvement)\n",
        "\n",
        "display.display(ensemble_scores.loc[['Det-\u003eDE L/32', 'DE L/32'],\n",
        "                                    [c for c in ensemble_scores if 'Rel' in c]])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UMiokBFI4ZuO"
      },
      "outputs": [],
      "source": [
        "datasets = [\n",
        "    'cifar10',\n",
        "    'cifar100',\n",
        "    'imagenet2012',\n",
        "    # 'imagenet_variants',\n",
        "]\n",
        "datasets += [f'few-shot {d}' for d in colab_utils.default_fewshot_datasets()]\n",
        "score_cols = [\n",
        "    'score', 'score_prediction', 'score_uncertainty', 'score_adaptation'\n",
        "]\n",
        "\n",
        "\n",
        "def plot_deep_ensemble_heatmap(scores, col_name):\n",
        "  fontsize = 18\n",
        "  tick_fontsize = 16\n",
        "  de_scores = scores[['DE' in x and x != 'DE' for x in scores.index]]\n",
        "  de_scores.loc[:, 'model_type'] = [\n",
        "      x[3:-2].replace('/32', '') for x in de_scores.index\n",
        "  ]\n",
        "  de_scores.loc[:, 'ensemble_size'] = [int(x[-1:]) for x in de_scores.index]\n",
        "\n",
        "  de_table = pd.pivot_table(\n",
        "      de_scores, values='score', index='model_type', columns='ensemble_size')\n",
        "  de_table = de_table.reindex(['L', 'B', 'S'])\n",
        "  p = sns.heatmap(\n",
        "      de_table,\n",
        "      annot=True,\n",
        "      fmt='.2f',\n",
        "      cmap='Blues',\n",
        "      cbar=False,\n",
        "      annot_kws={'size': 16})\n",
        "  p.set_xlabel('Ensemble Size', fontsize=fontsize)\n",
        "  p.set_ylabel('Model Variant', fontsize=fontsize)\n",
        "  _ = plt.xticks(fontsize=tick_fontsize)\n",
        "  _ = plt.yticks(fontsize=tick_fontsize)\n",
        "\n",
        "\n",
        "de_scores = colab_utils.compute_score(\n",
        "    de_results.drop(columns=['num_params']),\n",
        "    datasets=datasets,\n",
        "    drop_1shot=True,\n",
        "    drop_incomplete_measurements=True)\n",
        "de_scores = de_scores[score_cols] * 100\n",
        "\n",
        "plot_deep_ensemble_heatmap(de_scores, 'score')\n",
        "plt.tight_layout()\n",
        "filename = 'ensemble_tradeoff.pdf'\n",
        "plt.savefig(filename, bbox_inches='tight', pad_inches=0)\n",
        "colabtools.fileedit.download_file(filename)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T1eAXKK54bBn"
      },
      "outputs": [],
      "source": [
        "fontsize = 18\n",
        "tick_fontsize = 16\n",
        "de_results['architecture'] = de_results.index.map(\n",
        "    lambda x: x.split(' ')[1].split('_')[0].replace('/32', ''))\n",
        "de_results['ensemble_size'] = de_results.index.map(\n",
        "    lambda x: int(x.split('_')[-1]))\n",
        "\n",
        "_ = sns.scatterplot(\n",
        "    data=de_results,\n",
        "    x=('num_params', 'imagenet2012'),\n",
        "    y=('test_prec@1', 'imagenet2012'),\n",
        "    hue='architecture',\n",
        "    size='ensemble_size',\n",
        "    sizes=(40, 200))\n",
        "_ = plt.ylabel('Accuracy', fontsize=fontsize)\n",
        "_ = plt.xlabel('# Parameters', fontsize=fontsize)\n",
        "\n",
        "\n",
        "def clean_legend(ax):\n",
        "  handles, labels = ax.get_legend_handles_labels()\n",
        "  legend_texts = [l.get_text() for l in ax.legend().get_texts()]\n",
        "  ensemble_idx = legend_texts.index('ensemble_size')\n",
        "  architecture_idx = legend_texts.index('architecture')\n",
        "\n",
        "  # Remove titles in legend ('Architecture', 'Ensemble size')\n",
        "  del handles[ensemble_idx]\n",
        "  del labels[ensemble_idx]\n",
        "  del handles[architecture_idx]\n",
        "  del labels[architecture_idx]\n",
        "\n",
        "  def _annotate_label(label):\n",
        "    return f'n = {label}' if label.isnumeric() else label\n",
        "\n",
        "  labels = [_annotate_label(l) for l in labels]\n",
        "\n",
        "  # Add empty legends so that the two-column environment breaks at the end of\n",
        "  # the \"architecture\" legend.\n",
        "  empty_handle = matplotlib.collections.PathCollection(\n",
        "      paths=handles[0].get_paths(), sizes=[0.])\n",
        "  handles = [empty_handle, empty_handle] + handles\n",
        "  labels = ['', ''] + labels\n",
        "\n",
        "  legend = plt.legend(\n",
        "      handles=handles,\n",
        "      labels=labels,\n",
        "      fontsize=15,\n",
        "      ncol=2,\n",
        "      frameon=True,\n",
        "      framealpha=.1)\n",
        "  legend.get_frame().set_edgecolor('k')\n",
        "\n",
        "\n",
        "clean_legend(plt.gca())\n",
        "\n",
        "_ = plt.xticks(fontsize=tick_fontsize)\n",
        "_ = plt.yticks(fontsize=tick_fontsize)\n",
        "plt.gca().xaxis.get_offset_text().set_size(tick_fontsize)\n",
        "plt.tight_layout()\n",
        "filename = 'imagenet_params_vs_prec.pdf'\n",
        "plt.savefig(filename, bbox_inches='tight', pad_inches=0)\n",
        "colabtools.fileedit.download_file(filename)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NGafFDjueBSb"
      },
      "source": [
        "## Comparison with sparse MoEs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G8ryM8lweGJ2"
      },
      "outputs": [],
      "source": [
        "score_cols = [\n",
        "    '\\textsc{Score}',\n",
        "    '\\textsc{Score prediction}',\n",
        "    '\\textsc{Score uncertainty}',\n",
        "    '\\textsc{Score adaptation}'\n",
        "]\n",
        "\n",
        "moes_df = df_with_scores.reindex([\n",
        "  'Det', 'MoE', 'E^3', 'Det-\u003e[Det]_4', 'MoE-\u003e[MoE]_4', '[Det]_4', '[MoE]_4'\n",
        "])\n",
        "moes_df = moes_df.rename(index={\n",
        "    'E^3': '\\textsc{E}$^3$',\n",
        "    'Det-\u003e[Det]_4': '$\\textsc{Det}\\rightarrow[\\textsc{Det}]_4$',\n",
        "    'MoE-\u003e[MoE]_4': '$\\textsc{MoE}\\rightarrow[\\textsc{MoE}]_4$',\n",
        "    '[Det]_4': '$[\\textsc{Det}]_4$',\n",
        "    '[MoE]_4': '$[\\textsc{MoE}]_4$'\n",
        "},\n",
        "columns={\n",
        "    'score': '\\textsc{Score}',\n",
        "    'score_prediction': '\\textsc{Score prediction}',\n",
        "    'score_uncertainty': '\\textsc{Score uncertainty}',\n",
        "    'score_adaptation': '\\textsc{Score adaptation}'\n",
        "})\n",
        "\n",
        "# TODO(rjenatton@): regenerate after the adaptation score has been updated.\n",
        "# In the current state, most of the DE and MoEs do not have adaptation scores.\n",
        "moe_df = moes_df[score_cols].applymap(\"{0:.2f}\".format)\n",
        "moe_df = moe_df.applymap(lambda s: '$-$' if s == 'nan' else s)\n",
        "moe_latex_table = moe_df.to_latex(index=True,\n",
        "                                  index_names=False,\n",
        "                                  column_format='ccccc',\n",
        "                                  escape=False)\n",
        "# Add \\midrule at appropriate positions and remove empty line.\n",
        "moe_latex_table = moe_latex_table.splitlines()\n",
        "moe_latex_table = moe_latex_table[:8] + ['\\\\midrule'] + moe_latex_table[8:]\n",
        "moe_latex_table = moe_latex_table[:11] + ['\\\\midrule'] + moe_latex_table[11:]\n",
        "moe_latex_table = moe_latex_table[:3] + moe_latex_table[4:]\n",
        "# Just need to copy/paste the result of the print statement.\n",
        "print('\\n'.join(moe_latex_table))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fk5klOAR4df4"
      },
      "source": [
        "# Upstream vs downstream"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fV1VWDEmLOmI"
      },
      "outputs": [],
      "source": [
        "matplotlib.rcParams['font.family'] = 'serif'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IMO-26rn4fm1"
      },
      "outputs": [],
      "source": [
        "def get_up_vs_down_df():\n",
        "  RELEVANT_MODELS = [\n",
        "      'Det', 'BE L/32', 'Det-\u003eBE', 'DE', 'Det-\u003eDE', 'GP', 'Det-\u003eGP', 'Het',\n",
        "      'Det-\u003eHet'\n",
        "  ]\n",
        "  relevant_measurements = {\n",
        "      k: v for k, v in raw_measurements.items() if k in RELEVANT_MODELS\n",
        "  }\n",
        "  df = get_optimal_results(\n",
        "      relevant_measurements, verbose=False).drop(\n",
        "          columns=list(colab_utils.compute_metrics()), errors='ignore')\n",
        "  df = colab_utils.process_tuned_results(df)\n",
        "  return df.rename({'BE L/32': 'BE'})\n",
        "\n",
        "\n",
        "def _add_up_vs_down_metadata(df):\n",
        "  df = df.copy()\n",
        "\n",
        "  def _adaptation_type(model_name):\n",
        "    if model_name == 'Det':\n",
        "      return 'baseline'\n",
        "    if '-\u003e' in model_name:\n",
        "      return 'Downstream only'\n",
        "    else:\n",
        "      return 'Upstream \u0026 downstream'\n",
        "\n",
        "  df['base_model'] = df.index.map(lambda x: x.split('-\u003e')[-1])\n",
        "  df['up_vs_down'] = df.index.map(_adaptation_type)\n",
        "\n",
        "  return df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cjQ996nE4kjJ"
      },
      "outputs": [],
      "source": [
        "from matplotlib.lines import Line2D\n",
        "\n",
        "\n",
        "def up_vs_down_plot(df, dataset, metric):\n",
        "  fontsize = 18\n",
        "  tick_fontsize = 16\n",
        "\n",
        "  col = metric if dataset is None else (metric, dataset)\n",
        "  ymin = df[col].min()\n",
        "  ymax = df[col].max()\n",
        "\n",
        "  det_baseline = df.loc['Det', col]\n",
        "  cur_df = df[df.index != 'Det']\n",
        "\n",
        "  base_model_col = 'base_model' if dataset is None else ('base_model', '')\n",
        "  up_vs_down_col = 'up_vs_down' if dataset is None else ('up_vs_down', '')\n",
        "\n",
        "  graph = sns.barplot(data=cur_df, x=base_model_col, y=col, hue=up_vs_down_col)\n",
        "  graph.axhline(det_baseline, c='r', linewidth=2)\n",
        "  plt.ylim(max(0, ymin - .5 * (ymax - ymin)), ymax + .1 * (ymax - ymin))\n",
        "  plt.legend(title='location', fontsize=fontsize, title_fontsize=fontsize)\n",
        "  ylabel = metric.replace('test_', '').replace('loss', 'NLL').replace(\n",
        "      'ece', 'ECE').replace('prec@1', 'Accuracy')\n",
        "  plt.gca().legend().set_visible(False)\n",
        "  plt.ylabel(ylabel, fontsize=fontsize)\n",
        "  plt.xticks(fontsize=tick_fontsize)\n",
        "  plt.yticks(fontsize=tick_fontsize)\n",
        "  plt.xlabel('')\n",
        "\n",
        "\n",
        "plot_datasets = ['cifar10', 'cifar100', 'imagenet2012']\n",
        "plot_metrics = ['test_prec@1', 'test_ece', 'test_loss']\n",
        "\n",
        "df = _add_up_vs_down_metadata(get_up_vs_down_df())\n",
        "\n",
        "for row, dataset in enumerate(plot_datasets):\n",
        "  for col, metric in enumerate(plot_metrics):\n",
        "    up_vs_down_plot(df, dataset, metric)\n",
        "    filename = f'up_vs_down_{dataset}_{metric}.pdf'\n",
        "    plt.tight_layout()\n",
        "    plt.savefig(filename, bbox_inches='tight', pad_inches=0)\n",
        "    colabtools.fileedit.download_file(filename)\n",
        "    plt.show()\n",
        "\n",
        "# Generate arbitrary figure to use its legend data\n",
        "up_vs_down_plot(df, 'cifar10', 'test_prec@1')\n",
        "\n",
        "# Save legend in separate figure\n",
        "ax = plt.gca()\n",
        "fig_leg = plt.figure(figsize=(15, .5))\n",
        "ax_leg = fig_leg.add_subplot(111)\n",
        "# add the legend from the previous axes\n",
        "handles, labels = ax.get_legend_handles_labels()\n",
        "handles.extend([Line2D([0], [0], color='r')])\n",
        "labels.extend(['Deterministic (upstream \u0026 downstream)'])\n",
        "ax_leg.legend(handles, labels, loc='center', ncol=3, fontsize=18, frameon=False)\n",
        "# hide the axes frame and the x/y labels\n",
        "ax_leg.axis('off')\n",
        "fig_leg.savefig('up_vs_down_legend.pdf', bbox_inches='tight', pad_inches=0)\n",
        "colabtools.fileedit.download_file('up_vs_down_legend.pdf')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YIjqBuxLstHQ"
      },
      "source": [
        "# Open-set recognition"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TwpM1y0vvMab"
      },
      "outputs": [],
      "source": [
        "df_ood = colab_utils.process_tuned_results(measurements, relevant_metrics=colab_utils.ood_related_metrics())\n",
        "df_ood"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eXULbLNj1UxY"
      },
      "outputs": [],
      "source": [
        "df_ood.keys()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6-W-SSds1uEK"
      },
      "outputs": [],
      "source": [
        "#@title Comparing models by fixing on OOD method = MSP\n",
        "ood_msp_metrics = [x[0] for x in df_ood.keys() if 'msp' in x[0]]\n",
        "df_ood.loc[:, (ood_msp_metrics, slice(None))]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FljtzSHM6IMf"
      },
      "outputs": [],
      "source": [
        "#@title Comparing OOD methods on hard near-OOD tasks: (1) ImageNet2012\n",
        "df_ood.loc[:, (slice(None), 'imagenet2012')]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OYzYdkgZ6wI0"
      },
      "outputs": [],
      "source": [
        "#@title Comparing OOD methods on hard near-OOD tasks (2) CIFAR-100 vs CIFAR-10\n",
        "near_ood_metrics_cifar = ['ood_cifar10_msp_auroc', 'ood_cifar10_entropy_auroc',\n",
        "                          'ood_cifar10_mlogit_auroc', \n",
        "                          'ood_cifar10_maha_auroc', 'ood_cifar10_rmaha_auroc'] \n",
        "df_ood.loc[:, (near_ood_metrics_cifar, 'cifar100')]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4O001cQabuU4"
      },
      "source": [
        "## Zero-shot OOD results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M_vuR5bqKnkz"
      },
      "outputs": [],
      "source": [
        "#@title load zero-shot data\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6cYZpYG6QAkl"
      },
      "outputs": [],
      "source": [
        "excluded_keys = [\n",
        "    'DE', 'Det-\u003eDE', 'DE S/32', 'Det-\u003eDE S/32', 'DE B/32', 'Det-\u003eDE B/32',\n",
        "    'DE L/32', 'Det-\u003eDE L/32', 'Det -\u003e BE L/32 (n=2)', 'Det -\u003e BE L/32 (n=4)',\n",
        "    'Det -\u003e BE L/32 (n=8)', \n",
        "    'E^3', 'BE scaling', 'MoE',\n",
        "    'Det-\u003eBE', 'Det-\u003eGP', 'Det-\u003eHet', 'BE-\u003eBE+Het'\n",
        "]\n",
        "included_measurements = {\n",
        "    k: v for k, v in raw_measurements.items() if k not in excluded_keys\n",
        "}\n",
        "measurements = get_optimal_results(included_measurements)\n",
        "\n",
        "df_ood_zero_shot = colab_utils.process_tuned_results(measurements, relevant_metrics=colab_utils.ood_related_metrics())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SJcj8B9vQOr-"
      },
      "outputs": [],
      "source": [
        "ood_maha_metrics = [x[0] for x in df_ood_zero_shot.keys() if 'maha' in x[0]]\n",
        "df_ood_zero_shot.loc[:, (ood_maha_metrics, slice(None))]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IVpm3CKO_85O"
      },
      "source": [
        "# CIFAR subpopulation shift plot"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5dS9DYPPQRQo"
      },
      "outputs": [],
      "source": [
        "# Raw subpopulation-shift accuracies for each method. Each tuple is\n",
        "# (raw_metrics, method): raw_metrics holds 4 rows of 5 '/'-separated\n",
        "# accuracies, one row per (dataset, #subpopulations) setting.\n",
        "subpopl_metrics_raw = []\n",
        "subpopl_metrics_raw.append((\"\"\".984 / .989 / .992 / .996 / .998\n",
        ".973 / .987 / .993 / 1.0 / 1.0\n",
        ".892 / .909 / .922 / .933 / .944\n",
        ".878 / .904 / .920 / .940 / .961\"\"\", \"None\"))\n",
        "subpopl_metrics_raw.append((\"\"\".986 / .990 / .995 / .997 / 1.0\n",
        ".978 / .990 / 1.0  / 1.0 / 1.0\n",
        ".912 / .931 / .937 / .945 / .960\n",
        ".900 / .920 / .940 / .950 / .971\"\"\", \"Plex\"))\n",
        "# Excluded from the plots; kept for reference.\n",
        "# subpopl_metrics_raw.append((\"\"\".982 / .987 / .990 / .993 / .998\n",
        "# .971 / .985 / .991 / 1.0 / 1.0\n",
        "# .901 / .922 / .933 / .943 / .953\n",
        "# .895 / .920 / .930 / .950 / .970\"\"\", \"None I21K\"))\n",
        "subpopl_metrics_raw.append((\"\"\".986 / .990 / .994 / .997 / 1.0\n",
        ".977 / .989 / 1.0 / 1.0 / 1.0\n",
        ".907 / .923 / .933 / .944 / .959\n",
        ".899 / .919 / .933 / .950 / .971\"\"\", \"BE→BE\"))\n",
        "subpopl_metrics_raw.append((\"\"\".985 / .990 / .994 / .996 / 1.0\n",
        ".977 / .987 / .999 / 1.0 / 1.0\n",
        ".905 / .922 / .931 / .940 / .955\n",
        ".896 / .915 / .930 / .949 / .970\"\"\", \"None→BE\"))\n",
        "\n",
        "\n",
        "# Parse each raw block into {method: [5 accuracies]} per setting. The row\n",
        "# order is fixed: CIFAR10/30, CIFAR10/100, CIFAR100/30, CIFAR100/100, which\n",
        "# is what the slice offsets below rely on.\n",
        "subpopl_metrics_CIFAR10_30 = {}\n",
        "subpopl_metrics_CIFAR10_100 = {}\n",
        "subpopl_metrics_CIFAR100_30 = {}\n",
        "subpopl_metrics_CIFAR100_100 = {}\n",
        "for raw_metrics, key in subpopl_metrics_raw:\n",
        "  # Flatten newlines into the same '/' separator, then parse all 20 floats.\n",
        "  metrics = [float(raw_metric.strip()) for raw_metric in raw_metrics.replace('\\n', '/').split('/')]\n",
        "  # 4 datasets, 5 metrics for each.\n",
        "  assert len(metrics) == 20, f'Expected 20 metrics, got {len(metrics)}'\n",
        "  subpopl_metrics_CIFAR10_30[key] = metrics[0:5]\n",
        "  subpopl_metrics_CIFAR10_100[key] = metrics[5:10]\n",
        "  subpopl_metrics_CIFAR100_30[key] = metrics[10:15]\n",
        "  subpopl_metrics_CIFAR100_100[key] = metrics[15:20]\n",
        "\n",
        "# Keyed as '<dataset>_<num subpopulations>'.\n",
        "subpopl_metrics = {}\n",
        "subpopl_metrics['CIFAR10_30'] = subpopl_metrics_CIFAR10_30\n",
        "subpopl_metrics['CIFAR10_100'] = subpopl_metrics_CIFAR10_100\n",
        "subpopl_metrics['CIFAR100_30'] = subpopl_metrics_CIFAR100_30\n",
        "subpopl_metrics['CIFAR100_100'] = subpopl_metrics_CIFAR100_100\n",
        "\n",
        "# Long-form rows for FacetGrid: one row per (dataset setting, method).\n",
        "subpopl_metrics_list = []\n",
        "for dataset in subpopl_metrics:\n",
        "  for task in subpopl_metrics[dataset]:\n",
        "    dataset_base, dataset_tail = dataset.split('_')\n",
        "    subpopl_metrics_list.append([subpopl_metrics[dataset][task], task, dataset_base, dataset_tail])\n",
        "\n",
        "df = pd.DataFrame(subpopl_metrics_list, columns=['values', 'task', 'dataset', 'subpopulations'])\n",
        "\n",
        "def subpopl_plot_fn(data, color):\n",
        "  \"\"\"Draws one facet: a boxplot of accuracies with one box per method.\n",
        "\n",
        "  Invoked by `FacetGrid.map_dataframe`, which passes the facet's rows as\n",
        "  `data` and a `color` (unused here; seaborn's default palette is kept).\n",
        "  \"\"\"\n",
        "  tasks = data['task'].tolist()\n",
        "  values = data['values'].tolist()\n",
        "  # One column per method (avoids shadowing `data` in the comprehension).\n",
        "  sns.boxplot(data=pd.DataFrame({task: vals for task, vals in zip(tasks, values)}), order=['Plex', 'BE→BE', 'None→BE', 'None'])\n",
        "\n",
        "  # Per-row y-limits keep both CIFAR rows readable at their own scale.\n",
        "  dataset = data['dataset'].tolist()[0]\n",
        "  if dataset == 'CIFAR10':\n",
        "    plt.ylim(.97, 1.0)\n",
        "  elif dataset == 'CIFAR100':\n",
        "    plt.ylim(.87, .98)\n",
        "  else:\n",
        "    raise ValueError(f'Unexpected dataset: {dataset!r}')\n",
        "\n",
        "  plt.xlabel('Method')\n",
        "  plt.ylabel('Accuracy')\n",
        "\n",
        "def plot_subpopl_metrics():\n",
        "  \"\"\"Renders the 2x2 grid of subpopulation boxplots; saves and downloads a PDF.\n",
        "\n",
        "  Returns:\n",
        "    The seaborn FacetGrid, so the cell's final expression displays the figure.\n",
        "  \"\"\"\n",
        "  matplotlib.rcParams['axes.titlepad'] = 10\n",
        "  matplotlib.rcParams['font.family'] = 'serif'\n",
        "  grid = sns.FacetGrid(df, row=\"dataset\", col='subpopulations', aspect=1.8, height=3.2, sharey='row')\n",
        "  # Force tick labels onto every facet, not just the outer ones.\n",
        "  for axis in grid.axes.flatten():\n",
        "    axis.tick_params(labelbottom=True, labelleft=True)\n",
        "  grid.map_dataframe(subpopl_plot_fn)\n",
        "  plt.subplots_adjust(hspace=0.4, wspace=0.2)\n",
        "  grid.despine(right=False, top=False)\n",
        "  for row in (0, 1):\n",
        "    grid.axes[row, 0].set_ylabel('Accuracy')\n",
        "\n",
        "  filename = 'subpopl_ablations.pdf'\n",
        "  plt.savefig(filename, bbox_inches='tight', pad_inches=0, dpi=1000)\n",
        "  colabtools.fileedit.download_file(filename)\n",
        "\n",
        "  return grid\n",
        "\n",
        "plot_subpopl_metrics()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "eZ2bY0aPlQ5e",
        "Fk5klOAR4df4"
      ],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "RDL Big Paper Plots",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/plex/plots.ipynb?workspaceId=jsnoek:in_dist_plots::citc",
          "timestamp": 1658424204324
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/plex/plots.ipynb?workspaceId=jsnoek:in_dist_figures::citc",
          "timestamp": 1658345633757
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb?workspaceId=zmariet:fig-export-rdl_colab-1402-change-18::citc",
          "timestamp": 1655925815511
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1655830288403
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1654788337884
        },
        {
          "file_id": "1ssv7JaVogtAupjSQRDkE4BL1DuAN1ePV",
          "timestamp": 1654621028399
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1654618220486
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1652467886349
        },
        {
          "file_id": "1QxBhHvVLapPVI0iV31WW9b3bNndiqKR5",
          "timestamp": 1652190923733
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1652124724366
        },
        {
          "file_id": "1ysMQuP_2JJNxIvGD2X3HIOyYq8PnDd0x",
          "timestamp": 1652117504457
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1652116911682
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/RDL_Big_Paper_Plots.ipynb?workspaceId=zmariet:rdl_colab::citc",
          "timestamp": 1651094726179
        },
        {
          "file_id": "17EfR-0x8-RZeRD-6_--tA3Qk1FQzriyU",
          "timestamp": 1651094693154
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1651093947882
        },
        {
          "file_id": "16tdnixI5DVYANhrkytxc549-FA5i226p",
          "timestamp": 1651008508678
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1651004739395
        },
        {
          "file_id": "16nO48cMsvHj1Yb2vBI3sym10SF4Stfsf",
          "timestamp": 1650504045971
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1650406391596
        },
        {
          "file_id": "1wVhO3x9rHqzIbCN4jd8j8PbO33iFcEf2",
          "timestamp": 1650382845945
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1650377950326
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1649944249217
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/google/colab/plots.ipynb?workspaceId=zmariet:colab::citc",
          "timestamp": 1648746918580
        },
        {
          "file_id": "1XVIrTYh6R6VpfRMHwkXF4YN6L6LO7WYh",
          "timestamp": 1648746873184
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/big_paper/plots.ipynb",
          "timestamp": 1648740356002
        },
        {
          "file_id": "1Tufx2M784xw4obIzgWXdYAX8lJGtJb3k",
          "timestamp": 1645545840302
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/google/colab/big_paper_experiments.ipynb?workspaceId=trandustin:plots::citc",
          "timestamp": 1645091101660
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/google/colab/big_paper_experiments.ipynb?cl=428611591",
          "timestamp": 1644888762710
        },
        {
          "file_id": "1pql3UgJFiEjGW4igFnWING7A73O_iq04",
          "timestamp": 1644878348078
        },
        {
          "file_id": "1_OgnYgLLR0zpaN2-bBQt1RE3B5JktWNN",
          "timestamp": 1643738065376
        }
      ]
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "metadata": {
          "collapsed": false
        },
        "source": []
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
