{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from IPython.display import Image\n",
    "import torch\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply seaborn's 'darkgrid' look to the figures drawn below.\n",
    "sns.set_style(style='darkgrid')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-attribute labels: gender and hair color"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Attribute (column) indices used to build the new labels:\n",
    "\n",
    "- black hair = 8\n",
    "- gender (male) = 20"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build new labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder with the pre-processed CelebA tensors, relative to the notebook's working directory.\n",
    "DATA_DIR = '../data/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the CelebA dataset has been pre-processed before running this cell!\n",
    "# Re-run this cell and the two cells below once per split ('train', 'val' and 'test'):\n",
    "# the FID sections at the end of the notebook load the multi-label file of all three splits.\n",
    "split = 'val'  # one of 'train', 'val', 'test'\n",
    "data = torch.load(os.path.join(DATA_DIR, '{}_celeba_64x64.pt'.format(split)))\n",
    "labels = torch.load(os.path.join(DATA_DIR, '{}_labels_celeba_64x64.pt'.format(split)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0. 0.] 0\n",
      "[0. 1.] 1\n",
      "[1. 0.] 2\n",
      "[1. 1.] 3\n"
     ]
    }
   ],
   "source": [
    "# Fold the (black hair, male) attribute pair into one class id: 2 * black_hair + male.\n",
    "# The closed form keeps the ids identical across splits even when a split lacks one of the\n",
    "# four combinations (enumerating the np.unique rows would silently renumber the classes),\n",
    "# and it replaces the per-row Python comparison loop with one vectorized expression.\n",
    "attrs = np.asarray(labels)[:, [8, 20]]  # column 8 = black hair, column 20 = male\n",
    "if not np.isin(attrs, (0, 1)).all():\n",
    "    raise ValueError('expected attribute columns 8 and 20 to be coded as 0/1')\n",
    "# float64 matches the dtype np.zeros(len(labels)) gave the labels before\n",
    "new_labels = (2 * attrs[:, 0] + attrs[:, 1]).astype(np.float64)\n",
    "\n",
    "# Print the combinations present in this split next to the class id they map to.\n",
    "for unique in np.unique(attrs, axis=0):\n",
    "    print(unique, int(2 * unique[0] + unique[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.as_tensor accepts an ndarray as well as a tensor, so re-running this cell after\n",
    "# new_labels was already converted no longer fails (torch.from_numpy only takes ndarrays).\n",
    "new_labels = torch.as_tensor(new_labels)\n",
    "# Written next to the input tensors, one file per split.\n",
    "torch.save(new_labels, os.path.join(DATA_DIR, '{}_multi_labels_celeba_64x64.pt'.format(split)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Breakdown\n",
    "\n",
    "The new label equals `2 * black hair + male`:\n",
    "\n",
    "- black hair = 0, male = 0 -> 0\n",
    "- black hair = 0, male = 1 -> 1\n",
    "- black hair = 1, male = 0 -> 2\n",
    "- black hair = 1, male = 1 -> 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### test label and image consistency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload one split together with the multi-attribute labels saved above.\n",
    "# The label file only exists once the label-building cells were run with split = 'train'.\n",
    "split = 'train'\n",
    "images_path = os.path.join(DATA_DIR, '{}_celeba_64x64.pt'.format(split))\n",
    "multi_labels_path = os.path.join(DATA_DIR, '{}_multi_labels_celeba_64x64.pt'.format(split))\n",
    "data = torch.load(images_path)\n",
    "labels = torch.load(multi_labels_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the first five sample indices of each of the four multi-attribute classes.\n",
    "label_values = labels.data.cpu().numpy()\n",
    "for class_id in range(4):\n",
    "    first_five = np.where(label_values == class_id)[0][0:5]\n",
    "    print(class_id, first_five)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 5\n",
    "# Move axis 0 of the image to the end before handing it to imshow.\n",
    "image = data[i].data.cpu().numpy().transpose(1, 2, 0)\n",
    "plt.imshow(image)\n",
    "plt.show()\n",
    "# Label stored for the image shown above.\n",
    "print(labels[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 20\n",
    "# Move axis 0 of the image to the end before handing it to imshow.\n",
    "image = data[i].data.cpu().numpy().transpose(1, 2, 0)\n",
    "plt.imshow(image)\n",
    "plt.show()\n",
    "# Label stored for the image shown above.\n",
    "print(labels[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 16\n",
    "# Move axis 0 of the image to the end before handing it to imshow.\n",
    "image = data[i].data.cpu().numpy().transpose(1, 2, 0)\n",
    "plt.imshow(image)\n",
    "plt.show()\n",
    "# Label stored for the image shown above.\n",
    "print(labels[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 7\n",
    "# Move axis 0 of the image to the end before handing it to imshow.\n",
    "image = data[i].data.cpu().numpy().transpose(1, 2, 0)\n",
    "plt.imshow(image)\n",
    "plt.show()\n",
    "# Label stored for the image shown above.\n",
    "print(labels[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Construct splits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train set (label values, samples per label):\n",
    "\n",
    "- balanced dataset ratio: (array([0., 1., 2., 3.]), array([15000, 15000, 15000, 15000]))\n",
    "- unbalanced dataset ratio: (array([0., 1., 2., 3.]), array([26216, 24878,  3784,  5122]))\n",
    "\n",
    "Validation set (label values, samples per label):\n",
    "\n",
    "- balanced dataset ratio: (array([0., 1., 2., 3.]), array([1315, 1315, 1315, 1315]))\n",
    "- unbalanced dataset ratio: (array([0., 1., 2., 3.]), array([ 973, 1091,  342,  224]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "New labels, with the number of samples per label in the unbalanced train set:\n",
    "\n",
    "- 0) (black hair = 0, male = 0): 26216\n",
    "- 1) (black hair = 0, male = 1): 24878\n",
    "- 2) (black hair = 1, male = 0): 3784\n",
    "- 3) (black hair = 1, male = 1): 5122"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the four unbalanced train-set class counts listed above sum to 60000.\n",
    "(26216 + 24878 + 3784 + 5122)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of label 0 (black hair = 0, male = 0) among the 60000 samples.\n",
    "26216/ 60000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of label 1 (black hair = 0, male = 1) among the 60000 samples.\n",
    "24878/ 60000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of label 2 (black hair = 1, male = 0) among the 60000 samples.\n",
    "3784 / 60000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of label 3 (black hair = 1, male = 1) among the 60000 samples.\n",
    "5122 / 60000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "--------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare samples for unbiased FID statistic calculation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single-attribute"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at attribute column 20 (male) of the raw attribute labels.\n",
    "# `labels` still holds the 1-D multi-attribute ids loaded further up, so indexing it with\n",
    "# [:, 20] raises IndexError on a top-to-bottom run; read the raw attribute matrix instead\n",
    "# and show only the first rows rather than the whole column.\n",
    "attr_labels = torch.load(os.path.join(DATA_DIR, '{}_labels_celeba_64x64.pt'.format('val')))\n",
    "attr_labels[:10, 20]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Combine the data of all splits (test, val and train) to maximize the number of samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Splits to pool; the cells below load and concatenate them in this order.\n",
    "splits = ['test', 'val', 'train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the image tensor and the raw attribute labels of every split, in `splits` order.\n",
    "data, labels = [], []\n",
    "for split in splits:\n",
    "    images_path = os.path.join(DATA_DIR, '{}_celeba_64x64.pt'.format(split))\n",
    "    labels_path = os.path.join(DATA_DIR, '{}_labels_celeba_64x64.pt'.format(split))\n",
    "    data.append(torch.load(images_path))\n",
    "    labels.append(torch.load(labels_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the per-split lists into one tensor each.\n",
    "# Run once only: data and labels are lists before this cell and tensors after it.\n",
    "data, labels = torch.cat(data), torch.cat(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Balance the classes of attribute column 20 (male): keep, for every class value, the\n",
    "# first min_value samples, where min_value is the size of the rarest class.\n",
    "gender = labels[:, 20].data.numpy()  # converted once instead of on every loop pass\n",
    "val, freq = np.unique(gender, return_counts=True)\n",
    "min_value = min(freq)\n",
    "print(val, freq)\n",
    "print(min_value)\n",
    "\n",
    "samples = []\n",
    "ys = []  # labels matching `samples`; not used by the cells visible below\n",
    "# Compare against the class values returned by np.unique rather than the loop counter,\n",
    "# so the selection stays correct when the classes are not coded as 0..k-1.\n",
    "for class_value in val:\n",
    "    idx = np.where(gender == class_value)[0][0:min_value]\n",
    "    samples.append(data[idx])\n",
    "    ys.append(labels[:, 20][idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably the expected balanced size, 2 classes * min_value, with min_value = 84434\n",
    "# as printed by the previous cell - TODO confirm, that output is not saved in the notebook.\n",
    "84434 * 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of per-class batches collected above (one per class value).\n",
    "len(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the first class batch; its first dimension should equal min_value.\n",
    "samples[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the second class batch; it should match the first one.\n",
    "samples[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-class batches into one tensor and convert it to a numpy array.\n",
    "# Run once only: samples is a list before this cell and an ndarray after it.\n",
    "samples = torch.cat(samples).numpy()\n",
    "samples.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# samples is already a numpy array at this point; this only moves axis 1 to the end,\n",
    "# i.e. (N, C, H, W) -> (N, H, W, C) if the images are stored channels-first (TODO confirm).\n",
    "# Run once only: a second run would permute the axes again.\n",
    "samples = samples.transpose(0, 2, 3, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape after the axis permutation above.\n",
    "samples.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the balanced samples under the key 'x' of the .npz archive.\n",
    "np.savez('../fid_stats/unbiased_all_gender_samples.npz', x=samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi-attribute"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Splits to pool; same list and order as in the single-attribute section.\n",
    "splits = ['test', 'val', 'train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# multi: collect the image tensor and the multi-attribute labels of every split.\n",
    "# The multi-label files must already have been written for all three splits.\n",
    "data, labels = [], []\n",
    "for split in splits:\n",
    "    images_path = os.path.join(DATA_DIR, '{}_celeba_64x64.pt'.format(split))\n",
    "    multi_labels_path = os.path.join(DATA_DIR, '{}_multi_labels_celeba_64x64.pt'.format(split))\n",
    "    data.append(torch.load(images_path))\n",
    "    labels.append(torch.load(multi_labels_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the per-split lists into one tensor each.\n",
    "# Run once only: data and labels are lists before this cell and tensors after it.\n",
    "data, labels = torch.cat(data), torch.cat(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Balance the multi-attribute classes: keep, for every class value, the first min_value\n",
    "# samples, where min_value is the size of the rarest class.\n",
    "multi = labels.data.numpy()  # converted once instead of on every loop pass\n",
    "val, freq = np.unique(multi, return_counts=True)\n",
    "min_value = min(freq)\n",
    "print(val, freq)\n",
    "print(min_value)\n",
    "\n",
    "samples = []\n",
    "ys = []  # labels matching `samples`; not used by the cells visible below\n",
    "# Compare against the class values returned by np.unique rather than the loop counter,\n",
    "# so the selection stays correct when a class id is missing from the pooled labels.\n",
    "for class_value in val:\n",
    "    idx = np.where(multi == class_value)[0][0:min_value]\n",
    "    samples.append(data[idx])\n",
    "    ys.append(labels[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably the expected balanced size, 4 classes * min_value, with min_value = 23316\n",
    "# as printed by the previous cell - TODO confirm, that output is not saved in the notebook.\n",
    "23316*4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-class batches into one tensor and convert it to a numpy array.\n",
    "# Run once only: samples is a list before this cell and an ndarray after it.\n",
    "samples = torch.cat(samples).numpy()\n",
    "samples.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# samples is already a numpy array at this point; this only moves axis 1 to the end,\n",
    "# i.e. (N, C, H, W) -> (N, H, W, C) if the images are stored channels-first (TODO confirm).\n",
    "# Run once only: a second run would permute the axes again.\n",
    "samples = samples.transpose(0, 2, 3, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the balanced samples under the key 'x' of the .npz archive.\n",
    "np.savez('../fid_stats/unbiased_all_multi_samples.npz', x=samples)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
