{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9edb6bec",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys; sys.path.append(\"../\") # For relative imports\n",
    "\n",
    "from scipy.stats import entropy\n",
    "\n",
    "from utils.experiment_utils import *\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a477f3d8",
   "metadata": {},
   "source": [
    "In this notebook, we compute various measures of class imbalance for each dataset. The metrics are computed on the data not used for model training. The metric we use in the paper is `Normalized fraction of mass in rarest 0.05 of classes`, since we find that this metric best captures the type of imbalance that is challenging for our problem setting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d104af2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "==== Dataset: imagenet ====\n",
      "softmax_scores shape: (1153051, 1000)\n",
      "Min count: 663\n",
      "Max count: 1201\n",
      "Min/max ratio: 0.552\n",
      "Normalized fraction of mass in rarest 0.05 of classes: 0.7905461250196218\n",
      "# of examples in rarest 0.05 of classes divided by expected number if uniform: 0.7905461250196218\n",
      "Normalized Shannon entropy: 0.9997548276966274\n",
      "[.25, .5, .75, .9] class count quantiles: [1159. 1168. 1176. 1184.]\n",
      "\n",
      "==== Dataset: cifar-100 ====\n",
      "softmax_scores shape: (30000, 100)\n",
      "Min count: 257\n",
      "Max count: 330\n",
      "Min/max ratio: 0.779\n",
      "Normalized fraction of mass in rarest 0.05 of classes: 0.9039999999999999\n",
      "# of examples in rarest 0.05 of classes divided by expected number if uniform: 0.904\n",
      "Normalized Shannon entropy: 0.9997848210317252\n",
      "[.25, .5, .75, .9] class count quantiles: [290.   301.5  310.25 316.1 ]\n",
      "\n",
      "==== Dataset: places365 ====\n",
      "softmax_scores shape: (183996, 365)\n",
      "Min count: 300\n",
      "Max count: 576\n",
      "Min/max ratio: 0.521\n",
      "Normalized fraction of mass in rarest 0.05 of classes: 0.7687123633122458\n",
      "# of examples in rarest 0.05 of classes divided by expected number if uniform: 0.7687123633122458\n",
      "Normalized Shannon entropy: 0.9995684899390082\n",
      "[.25, .5, .75, .9] class count quantiles: [493.  508.  523.  538.6]\n",
      "\n",
      "==== Dataset: inaturalist ====\n",
      "softmax_scores shape: (1393421, 1103)\n",
      "Data preprocessing: Keeping 633 of 1103 classes that have >= 250 examples\n",
      "Min count: 250\n",
      "Max count: 68838\n",
      "Min/max ratio: 0.004\n",
      "Normalized fraction of mass in rarest 0.05 of classes: 0.12139784134651672\n",
      "# of examples in rarest 0.05 of classes divided by expected number if uniform: 0.12139784134651672\n",
      "Normalized Shannon entropy: 0.8596172827554369\n",
      "[.25, .5, .75, .9] class count quantiles: [ 373.   697.  1822.  5099.4]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['imagenet', 'cifar-100', 'places365', 'inaturalist']\n",
    "\n",
    "for dataset in dataset_list:\n",
    "    print(f'\\n==== Dataset: {dataset} ====')\n",
    "    softmax_scores, labels = load_dataset(dataset)\n",
    "    cts = Counter(labels).values()\n",
    "    cts = sorted(np.array(list(cts)))\n",
    "    num_classes = len(cts)\n",
    "    print('Min count:', min(cts))\n",
    "    print('Max count:', max(cts))\n",
    "    print(f'Min/max ratio: { min(cts)/max(cts):.3f}')\n",
    "    frac = .05\n",
    "    print(f'Normalized fraction of mass in rarest {frac} of classes: {(np.sum(cts[:int(frac*num_classes)])/len(labels)) / .05}')\n",
    "    print(f'# of examples in rarest {frac} of classes divided by expected number if uniform: {np.sum(cts[:int(frac*num_classes)])/(len(labels) * .05)}') # Another view\n",
    "    print('Normalized Shannon entropy:', entropy(cts) / np.log(len(cts))) # See https://stats.stackexchange.com/questions/239973/a-general-measure-of-data-set-imbalance\n",
    "    print('[.25, .5, .75, .9] class count quantiles:', np.quantile(cts, [.25, .5, .75, .9]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57d8a4ec",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.9",
   "language": "python",
   "name": "py3.9"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
