{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "#import seaborn as sns\n",
    "import torch\n",
    "from IPython.display import Image\n",
    "import torch\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'sns' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-3-8fdf3272f71a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msns\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_style\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'darkgrid'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'sns' is not defined"
     ]
    }
   ],
   "source": [
    "sns.set_style('darkgrid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = '../data/FairFace'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare samples for unbiased FID statistic calculation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single-attribute"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "combine all data across splits to maximize number of samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "splits = ['val', 'train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = []\n",
    "labels = []\n",
    "for split in splits:\n",
    "    d = torch.load(os.path.join(DATA_DIR, '{}_FairFace_64x64.pt'.format(split)))\n",
    "    l = torch.load(os.path.join(DATA_DIR, '{}_labels_FairFace_64x64.pt'.format(split)))\n",
    "    data.append(d)\n",
    "    labels.append(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = torch.cat(data)\n",
    "labels = torch.cat(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([32401, 3, 64, 64]), torch.Size([32401]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.shape, labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0. 1.] [18612 13789]\n",
      "13789\n"
     ]
    }
   ],
   "source": [
    "# get minimum value of class\n",
    "val, freq = np.unique(labels.data.numpy(), return_counts=True)\n",
    "min_value = min(freq)\n",
    "print(val, freq)\n",
    "print(min(freq))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = []\n",
    "ys = []\n",
    "for i in range(len(val)):\n",
    "    idx = np.where(labels.data.numpy() == i)[0][0:min_value]\n",
    "    samples.append(data[idx])\n",
    "    ys.append(labels[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([13789, 3, 64, 64]),\n",
       " torch.Size([13789]),\n",
       " torch.Size([13789, 3, 64, 64]),\n",
       " torch.Size([13789]))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples[0].shape, ys[0].shape, samples[1].shape, ys[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(27578, 3, 64, 64)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples = torch.cat(samples)\n",
    "samples = samples.numpy()\n",
    "samples.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now we should convert to numpy\n",
    "samples = np.array(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(27578, 3, 64, 64)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.savez('../fid_stats/FairFace/unbiased_all_race_samples.npz', **{'x':samples})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
