{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from IPython.display import Image\n",
    "import torch\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_style('darkgrid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = '../data/UTKFace'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare samples for unbiased FID statistic calculation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single-attribute"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "combine all data across splits to maximize number of samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "splits = ['test', 'val', 'train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = []\n",
    "labels = []\n",
    "for split in splits:\n",
    "    d = torch.load(os.path.join(DATA_DIR, '{}_UTK_64x64.pt'.format(split)))\n",
    "    l = torch.load(os.path.join(DATA_DIR, '{}_labels_UTK_64x64.pt'.format(split)))\n",
    "    data.append(d)\n",
    "    labels.append(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = torch.cat(data)\n",
    "labels = torch.cat(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([23705, 3, 64, 64]), torch.Size([23705]))"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.shape, labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0. 1.] [10078 13627]\n",
      "10078\n"
     ]
    }
   ],
   "source": [
    "# get minimum value of class\n",
    "val, freq = np.unique(labels.data.numpy(), return_counts=True)\n",
    "min_value = min(freq)\n",
    "print(val, freq)\n",
    "print(min(freq))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = []\n",
    "ys = []\n",
    "for i in range(len(val)):\n",
    "    idx = np.where(labels.data.numpy() == i)[0][0:min_value]\n",
    "    samples.append(data[idx])\n",
    "    ys.append(labels[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([10078, 3, 64, 64]),\n",
       " torch.Size([10078]),\n",
       " torch.Size([10078, 3, 64, 64]),\n",
       " torch.Size([10078]))"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples[0].shape, ys[0].shape, samples[1].shape, ys[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20156, 3, 64, 64)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples = torch.cat(samples)\n",
    "samples = samples.numpy()\n",
    "samples.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now we should convert to numpy\n",
    "samples = np.array(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20156, 3, 64, 64)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "samples.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.savez('../fid_stats/UTKFace/unbiased_all_race_samples.npz', **{'x':samples})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
