{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from numpy import random\n",
    "\n",
    "# config\n",
    "# dataset: which torchvision dataset to load; the loading cell accepts 'mnist' or 'fmnist'\n",
    "dataset = 'mnist'\n",
    "# bias_ratio: passed as `ratio` to filter_class_with_bias (the target class is cut to 1/ratio of its size)\n",
    "bias_ratio = 4\n",
    "# bias_factor: not read by any other cell in this notebook\n",
    "bias_factor = 'Y'\n",
    "# target_class: label of the class that gets down-sampled\n",
    "target_class = 8\n",
    "random_seed = 0\n",
    "\n",
    "# random seed\n",
    "# seeds torch's global RNG and numpy's global RNG (`random` here is numpy.random, not the stdlib module)\n",
    "torch.manual_seed(random_seed)\n",
    "random.seed(random_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as tf\n",
    "import torchvision.datasets as datasets\n",
    "from os import path, mkdir, makedirs\n",
    "\n",
    "\n",
    "\n",
    "# load data\n",
    "# Each supported dataset maps to its loader class and its download root.\n",
    "# NOTE(review): the mnist root is an absolute, machine-specific path - consider moving it to the config cell.\n",
    "dataset_sources = {\n",
    "    'mnist': (datasets.MNIST, '/home/.../nas/PF-GAN/dataset/original'),\n",
    "    'fmnist': (datasets.FashionMNIST, './original'),\n",
    "}\n",
    "\n",
    "# fail before touching the filesystem when the config names an unsupported dataset\n",
    "if dataset not in dataset_sources:\n",
    "    raise NotImplementedError('Dataset should be mnist or fmnist')\n",
    "\n",
    "dataset_cls, data_root = dataset_sources[dataset]\n",
    "\n",
    "# create the root the selected loader actually uses; exist_ok makes re-running the cell a no-op\n",
    "makedirs(data_root, exist_ok=True)\n",
    "\n",
    "train = dataset_cls(root=data_root, train=True, transform=tf.ToTensor(), download=True)\n",
    "test = dataset_cls(root=data_root, train=False, transform=tf.ToTensor(), download=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original class count: \n",
      "class 0: 5923\n",
      "class 1: 6742\n",
      "class 2: 5958\n",
      "class 3: 6131\n",
      "class 4: 5842\n",
      "class 5: 5421\n",
      "class 6: 5918\n",
      "class 7: 6265\n",
      "class 8: 5851\n",
      "class 9: 5949\n",
      "\n",
      "Created dataset with 4:1 ratio on Y\n",
      "Modified class count: \n",
      "class 0: 5923\n",
      "class 1: 6742\n",
      "class 2: 5958\n",
      "class 3: 6131\n",
      "class 4: 5842\n",
      "class 5: 5421\n",
      "class 6: 5918\n",
      "class 7: 6265\n",
      "class 8: 1462\n",
      "class 9: 5949\n",
      "Original class count: \n",
      "class 0: 980\n",
      "class 1: 1135\n",
      "class 2: 1032\n",
      "class 3: 1010\n",
      "class 4: 982\n",
      "class 5: 892\n",
      "class 6: 958\n",
      "class 7: 1028\n",
      "class 8: 974\n",
      "class 9: 1009\n",
      "\n",
      "Created dataset with 4:1 ratio on Y\n",
      "Modified class count: \n",
      "class 0: 980\n",
      "class 1: 1135\n",
      "class 2: 1032\n",
      "class 3: 1010\n",
      "class 4: 982\n",
      "class 5: 892\n",
      "class 6: 958\n",
      "class 7: 1028\n",
      "class 8: 243\n",
      "class 9: 1009\n"
     ]
    }
   ],
   "source": [
    "# create bias on Y \n",
    "\n",
    "# ====== Functions =========================================================\n",
    "def print_class_count(targets, out=None):\n",
    "    '''\n",
    "    Print one `class <id>: <count>` line per distinct label in `targets`.\n",
    "\n",
    "    Labels are listed in sorted order.\n",
    "\n",
    "    Args:\n",
    "        targets: tensor of class labels.\n",
    "        out: optional writable text stream; when truthy, every line is\n",
    "            written to it in addition to stdout.\n",
    "    '''\n",
    "    labels, frequencies = torch.unique(targets, sorted=True, return_counts=True)\n",
    "    for label, frequency in zip(labels.tolist(), frequencies.tolist()):\n",
    "        line = f'class {int(label)}: {frequency}'\n",
    "        print(line)\n",
    "        if out:\n",
    "            print(line, file=out)\n",
    "\n",
    "\n",
    "def filter_class_with_bias(data, targets, target_class, ratio=4):\n",
    "    '''\n",
    "    Down-sample a single class so that only about 1/ratio of its samples remain.\n",
    "\n",
    "    Every sample of the other classes is kept. Of `target_class`, the first\n",
    "    int(count / ratio) samples (in their existing order) are kept and\n",
    "    appended after the non-target samples. Class counts are printed before\n",
    "    and after filtering.\n",
    "\n",
    "    Args:\n",
    "        data: samples indexed along dim 0; a tensor or anything torch.as_tensor accepts.\n",
    "        targets: class labels aligned with `data` along dim 0.\n",
    "        target_class: label of the class to down-sample.\n",
    "        ratio: positive down-sampling factor applied to `target_class`.\n",
    "\n",
    "    Returns:\n",
    "        (data, targets): non-target samples followed by the kept target samples.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `ratio` is not positive.\n",
    "    '''\n",
    "    if ratio <= 0:\n",
    "        raise ValueError(f'ratio must be positive, got {ratio}')\n",
    "\n",
    "    # Convert each argument independently so a tensor `data` paired with a\n",
    "    # list/array `targets` still works. as_tensor leaves existing tensors\n",
    "    # untouched and keeps integer dtypes (labels stay integers).\n",
    "    data = torch.as_tensor(data)\n",
    "    targets = torch.as_tensor(targets)\n",
    "\n",
    "    target_mask = (targets == target_class)\n",
    "    nontarget_mask = ~target_mask\n",
    "\n",
    "    print('Original class count: ')\n",
    "    print_class_count(targets)\n",
    "\n",
    "    # true division keeps non-integer ratios usable; int() truncates toward zero\n",
    "    target_size = int(target_mask.sum().item() / ratio)\n",
    "    target_data = data[target_mask][:target_size]\n",
    "    target_label = targets[target_mask][:target_size]\n",
    "\n",
    "    data = torch.cat([data[nontarget_mask], target_data])\n",
    "    targets = torch.cat([targets[nontarget_mask], target_label])\n",
    "\n",
    "    print(f'\\nCreated dataset with {ratio}:1 ratio on Y')\n",
    "    print('Modified class count: ')\n",
    "    print_class_count(targets)\n",
    "\n",
    "    return data, targets\n",
    "# ============================================================================\n",
    "\n",
    "\n",
    "# Down-sample `target_class` in both splits, reading the images and labels\n",
    "# from the `.data` / `.targets` attributes of the loaded dataset objects.\n",
    "train_img, train_label = filter_class_with_bias(\n",
    "    train.data, train.targets, target_class, bias_ratio\n",
    ")\n",
    "\n",
    "test_img, test_label = filter_class_with_bias(\n",
    "    test.data, test.targets, target_class, bias_ratio\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "With Rotation, Augmented data size: 166833\n",
      "With Rotation, Augmented data size: 27807\n",
      "class 0: 17769\n",
      "class 1: 20226\n",
      "class 2: 17874\n",
      "class 3: 18393\n",
      "class 4: 17526\n",
      "class 5: 16263\n",
      "class 6: 17754\n",
      "class 7: 18795\n",
      "class 8: 4386\n",
      "class 9: 17847\n",
      "class 0: 2940\n",
      "class 1: 3405\n",
      "class 2: 3096\n",
      "class 3: 3030\n",
      "class 4: 2946\n",
      "class 5: 2676\n",
      "class 6: 2874\n",
      "class 7: 3084\n",
      "class 8: 729\n",
      "class 9: 3027\n"
     ]
    }
   ],
   "source": [
    "# augment data to 3x\n",
    "def rot_aug(train, Y):\n",
    "    '''\n",
    "    Triple a dataset by adding two rotated copies of it.\n",
    "\n",
    "    One copy comes from RandomRotation with degrees in (5, 10), the other\n",
    "    from RandomRotation with degrees in (-10, -5). The result is ordered\n",
    "    [original, (-10, -5) copy, (5, 10) copy] along dim 0, and the labels are\n",
    "    repeated three times to match. The augmented size is printed.\n",
    "\n",
    "    NOTE(review): each transform is called once on the whole tensor, so the\n",
    "    angle is presumably drawn once per call rather than once per image -\n",
    "    confirm against the torchvision version in use.\n",
    "\n",
    "    Args:\n",
    "        train: image tensor, concatenated along dim 0.\n",
    "        Y: label tensor aligned with `train`.\n",
    "\n",
    "    Returns:\n",
    "        (data_3x, data_3x_Y): augmented images and their labels.\n",
    "    '''\n",
    "    # The (5, 10) rotation is applied before the (-10, -5) one so the random\n",
    "    # number stream is consumed in a fixed order.\n",
    "    positive_rot = tf.RandomRotation(degrees=(5, 10))(train)\n",
    "    negative_rot = tf.RandomRotation(degrees=(-10, -5))(train)\n",
    "\n",
    "    data_3x = torch.cat((train, negative_rot, positive_rot), dim=0)\n",
    "    data_3x_Y = Y.repeat(3)\n",
    "\n",
    "    augmented_size = data_3x.shape[0]\n",
    "    print(f'With Rotation, Augmented data size: {augmented_size}')\n",
    "    return data_3x, data_3x_Y\n",
    "\n",
    "\n",
    "# NOTE: this cell reassigns its own inputs, so running it a second time\n",
    "# augments the already-augmented tensors again.\n",
    "train_img, train_label = rot_aug(train_img, train_label)\n",
    "test_img, test_label = rot_aug(test_img, test_label)\n",
    "\n",
    "# class counts after augmentation: train split first, then test split\n",
    "for split_label in (train_label, test_label):\n",
    "    print_class_count(split_label)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "class 0: 17190\n",
      "class 1: 17190\n",
      "class 2: 17190\n",
      "class 3: 17190\n",
      "class 4: 17190\n",
      "class 5: 16260\n",
      "class 6: 17190\n",
      "class 7: 17190\n",
      "class 8: 4290\n",
      "class 9: 17190\n"
     ]
    }
   ],
   "source": [
    "# Trim the augmented training set to fixed per-class sizes:\n",
    "#   - target class: at most 4290 samples\n",
    "#   - every other class: at most 17190 samples, rounded down to a multiple of 10\n",
    "# (9 * 17190 + 4290 = 159000 only if every non-target class has at least 17190 samples).\n",
    "# Samples are taken in index order (the first N of each class), not at random.\n",
    "\n",
    "num_classes = 10\n",
    "desired_samples_target = 4290\n",
    "desired_samples_other = 17190\n",
    "\n",
    "# Collect one index tensor per class: non-target classes first (ascending label),\n",
    "# the target class last.\n",
    "index_chunks = []\n",
    "for class_id in range(num_classes):\n",
    "    if class_id == target_class:  # handled after the loop\n",
    "        continue\n",
    "\n",
    "    # torch.nonzero yields the matching positions in ascending order\n",
    "    class_indices = torch.nonzero(train_label == class_id).flatten()\n",
    "    num_samples = min(len(class_indices), desired_samples_other)\n",
    "    num_samples -= num_samples % 10  # round down to a multiple of 10\n",
    "    index_chunks.append(class_indices[:num_samples])\n",
    "\n",
    "# Target class: slicing already caps at the number of available samples\n",
    "target_indices = torch.nonzero(train_label == target_class).flatten()\n",
    "index_chunks.append(target_indices[:desired_samples_target])\n",
    "\n",
    "balanced_indices = torch.cat(index_chunks)\n",
    "train_img = train_img[balanced_indices]\n",
    "train_label = train_label[balanced_indices]\n",
    "\n",
    "print_class_count(train_label)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhDUlEQVR4nO3de3BU5f3H8c8mkgUxWQyQGzeJqKhcbJFERsQgGUKqDqB4q22h42DVYBVELI6Cth1TsQKjIjrWgoz3C5dqLY4GE6rlIigytEJJDAJCguCwy8UEJM/vD35uXZOAZ9nNNwnv18wzQ855vnu+HM7kw9k9eeJzzjkBANDEEqwbAACcnAggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCDgBG3ZskU+n09//vOfY/aapaWl8vl8Ki0tjdlrAs0NAYST0vz58+Xz+bRmzRrrVuLmyy+/1LXXXqsOHTooJSVFI0eO1Oeff27dFhB2inUDAGJv//79Gjp0qILBoO699161adNGs2bN0qWXXqp169apY8eO1i0CBBDQGj355JPavHmzVq9erYEDB0qSCgsL1adPHz366KN66KGHjDsEeAsOaNShQ4c0bdo0DRgwQIFAQO3bt9cll1yi999/v9GaWbNmqUePHmrXrp0uvfRSbdiwod6cjRs3asyYMUpNTVXbtm114YUX6m9/+9tx+zl48KA2btyo3bt3H3fu66+/roEDB4bDR5J69+6tYcOG6dVXXz1uPdAUCCCgEaFQSH/5y1+Ul5enhx9+WA888IC++uorFRQUaN26dfXmL1iwQI899piKioo0depUbdiwQZdddpmqq6vDc/7973/roosu0meffabf/e53evTRR9W+fXuNGjVKixYtOmY/q1ev1rnnnqsnnnjimPPq6uq0fv16XXjhhfX25eTkqKKiQvv27ftxJwGII96CAxpx+umna8uWLUpKSgpvGz9+vHr37q3HH39czz77bMT88vJybd68WV26dJEkjRgxQrm5uXr44Yc1c+ZMSdIdd9yh7t2766OPPpLf75ck3XbbbRo8eLDuuecejR49+oT7/vrrr1VbW6vMzMx6+77btmPHDp1zzjknfCzgRHAHBDQiMTExHD51dXX6+uuv9e233+rCCy/Uxx9/XG/+qFGjwuEjHb3byM3N1dtvvy3paDAsW7ZM1157rfbt26fdu3dr9+7d2rNnjwoKCrR582Z9+eWXjfaTl5cn55weeOCBY/b9zTffSFI44L6vbdu2EXMASwQQcAzPPfec+vXrp7Zt26pjx47q3Lmz/v73vysYDNabe9ZZZ9XbdvbZZ2vLli2Sjt4hOed0//33q3PnzhFj+vTpkqRdu3adcM/t2rWTJNXW1tbbV1NTEzEHsMRbcEAjnn/+eY0bN06jRo3S3XffrbS0NCUmJqq4uFgVFRWeX6+urk6SNHnyZBUUFDQ4p1evXifUsySlpqbK7/dr586d9fZ9ty0rK+uEjwOcKAIIaMTrr7+u7OxsLVy4UD6fL7z9u7uVH9q8eXO9bf/97391xhlnSJKys7MlSW3atFF+fn7sG/5/CQkJ6tu3b4M/ZLtq1SplZ2crOTk5bscHfizeggMakZiYKElyzoW3rVq1SitWrGhw/uLFiyM+w1m9erVWrVqlwsJCSVJaWpry8vL09NNPN3h38tVXXx2zHy+PYY8ZM0YfffRRRAht2rRJy5Yt0zXXXHPceqApcAeEk9pf//pXLV26tN72O+64Q1dccYUWLlyo0aNH6/LLL1dlZaWeeuopnXfeedq/f3+9ml69emnw4MG69dZbVVtbq9mzZ6tjx46aMmVKeM6cOXM0ePBg9e3bV+PHj1d2draqq6u1YsUKbd++XZ9++mmjva5evVpDhw7V9OnTj/sgwm233aZnnnlGl19+uSZPnqw2bdpo5syZSk9P11133fXjTxAQRwQQTmpz585tcPu4
ceM0btw4VVVV6emnn9Y777yj8847T88//7xee+21BhcJ/dWvfqWEhATNnj1bu3btUk5Ojp544omIx6HPO+88rVmzRg8++KDmz5+vPXv2KC0tTT/5yU80bdq0mP29kpOTVVpaqokTJ+qPf/yj6urqlJeXp1mzZqlz584xOw5wInzu++8vAADQRPgMCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYaHY/B1RXV6cdO3YoOTk5YvkTAEDL4JzTvn37lJWVpYSExu9zml0A7dixQ926dbNuAwBwgrZt26auXbs2ur/ZvQXHIokA0Doc7/t53AJozpw5OuOMM9S2bVvl5uZq9erVP6qOt90AoHU43vfzuATQK6+8okmTJmn69On6+OOP1b9/fxUUFMTkl20BAFoJFwc5OTmuqKgo/PWRI0dcVlaWKy4uPm5tMBh0khgMBoPRwkcwGDzm9/uY3wEdOnRIa9eujfiFWwkJCcrPz2/w96jU1tYqFApFDABA6xfzANq9e7eOHDmi9PT0iO3p6emqqqqqN7+4uFiBQCA8eAIOAE4O5k/BTZ06VcFgMDy2bdtm3RIAoAnE/OeAOnXqpMTERFVXV0dsr66uVkZGRr35fr9ffr8/1m0AAJq5mN8BJSUlacCAASopKQlvq6urU0lJiQYNGhTrwwEAWqi4rIQwadIkjR07VhdeeKFycnI0e/ZsHThwQL/+9a/jcTgAQAsUlwC67rrr9NVXX2natGmqqqrSBRdcoKVLl9Z7MAEAcPLyOeecdRPfFwqFFAgErNsAAJygYDColJSURvebPwUHADg5EUAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADAxCnWDQBofvLy8jzXlJSUeK5JSPD+f+BoeisrK/Ncg/jjDggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJFiMFWrFx48ZFVXf77bd7rqmrq4vqWF7NnDnTc82CBQuiOtacOXM813z77bdRHetkxB0QAMAEAQQAMBHzAHrggQfk8/kiRu/evWN9GABACxeXz4DOP/98vffee/87yCl81AQAiBSXZDjllFOUkZERj5cGALQScfkMaPPmzcrKylJ2drZuvPFGbd26tdG5tbW1CoVCEQMA0PrFPIByc3M1f/58LV26VHPnzlVlZaUuueQS7du3r8H5xcXFCgQC4dGtW7dYtwQAaIZiHkCFhYW65ppr1K9fPxUUFOjtt9/W3r179eqrrzY4f+rUqQoGg+Gxbdu2WLcEAGiG4v50QIcOHXT22WervLy8wf1+v19+vz/ebQAAmpm4/xzQ/v37VVFRoczMzHgfCgDQgsQ8gCZPnqyysjJt2bJF//rXvzR69GglJibqhhtuiPWhAAAtWMzfgtu+fbtuuOEG7dmzR507d9bgwYO1cuVKde7cOdaHAgC0YD7nnLNu4vtCoZACgYB1G0CzE83Cor/85S+jOtaQIUOiqvMqIcH7mzBNteipJPXq1ctzzRdffBGHTlqmYDColJSURvezFhwAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATcf+FdEBL0qFDB881F1xwgeeaefPmea7p1KmT55q2bdt6ronWxo0bPddEsxjp2Wef7bkGzRN3QAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAE6yGjVZp1KhR
UdWNHz/ec83w4cM910SzCnRdXZ3nmqb0yCOPeK6J5jw888wznmvQPHEHBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwASLkaLZ+8UvfuG55rnnnotDJ7ETzSKczZ3P52uS47TGc3ey4l8SAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACRYjRZOKZmHR2bNne66pq6vzXCNJNTU1nmuqq6s91yQnJ3uuSU1N9VwTrWjOQygU8lwTCAQ810T7b4vmhzsgAIAJAggAYMJzAC1fvlxXXnmlsrKy5PP5tHjx4oj9zjlNmzZNmZmZateunfLz87V58+ZY9QsAaCU8B9CBAwfUv39/zZkzp8H9M2bM0GOPPaannnpKq1atUvv27VVQUBDVe8oAgNbL80MIhYWFKiwsbHCfc06zZ8/Wfffdp5EjR0qSFixYoPT0dC1evFjXX3/9iXULAGg1YvoZUGVlpaqqqpSfnx/eFggElJubqxUrVjRYU1tbq1AoFDEAAK1fTAOoqqpKkpSenh6xPT09Pbzvh4qLixUIBMKjW7dusWwJANBMmT8FN3XqVAWDwfDYtm2bdUsAgCYQ0wDKyMiQVP8H86qrq8P7fsjv9yslJSViAABav5gGUM+ePZWRkaGSkpLwtlAopFWrVmnQoEGxPBQAoIXz/BTc/v37VV5eHv66srJS69atU2pqqrp3764777xTf/zjH3XWWWepZ8+euv/++5WVlaVRo0bFsm8AQAvnOYDWrFmjoUOHhr+eNGmSJGns2LGaP3++pkyZogMHDujmm2/W3r17NXjwYC1dulRt27aNXdcAgBbP55xz1k18XygUimqBQjS9aO5q33jjDc81Tbn4ZFlZmeea7//YwY81btw4zzXPPPOM55poffcfSy8ef/xxzzXN/Tz06tXLc80XX3wRh05apmAweMzP9c2fggMAnJwIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACY8/zoGtD7RrEgsSbNnz45pH42pqanxXLNq1aqojvXb3/42qrqm8Omnn3quee6556I61ty5c6Oq8+r111/3XDN+/HjPNTk5OZ5rEH/cAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDBYqTQ/fffH1Vd+/btY9xJwx566CHPNcXFxXHoJHY++OADzzX/+Mc/PNdUV1d7rmlK+/fv91xTW1sbh05ggTsgAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJliMtJW54IILPNckJydHdayEBO//f0lMTIzqWK1NeXm5dQstls/n81wTzbWK+ONfBQBgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkWI23G+vTp47nmjTfe8Fxz+umne66RpLq6uqjqgO+cdtppnmuSkpI813CtNk/cAQEATBBAAAATngNo+fLluvLKK5WVlSWfz6fFixdH7B83bpx8Pl/EGDFiRKz6BQC0Ep4D6MCBA+rfv7/mzJnT6JwRI0Zo586d4fHSSy+dUJMAgNbH80MIhYWFKiwsPOYcv9+vjIyMqJsCALR+cfkMqLS0VGlpaTrnnHN06623as+ePY3Ora2tVSgUihgAgNYv5gE0YsQILViwQCUlJXr44YdVVlamwsJCHTlypMH5xcXFCgQC4dGtW7dYtwQAaIZi/nNA119/ffjPffv2Vb9+/XTmmWeqtLRUw4YNqzd/6tSpmjRpUvjrUChECAHASSDuj2FnZ2erU6dOKi8vb3C/3+9XSkpKxAAAtH5xD6Dt27drz549
yszMjPehAAAtiOe34Pbv3x9xN1NZWal169YpNTVVqampevDBB3X11VcrIyNDFRUVmjJlinr16qWCgoKYNg4AaNk8B9CaNWs0dOjQ8NfffX4zduxYzZ07V+vXr9dzzz2nvXv3KisrS8OHD9cf/vAH+f3+2HUNAGjxPAdQXl6enHON7n/nnXdOqCH8z2OPPea5pnv37nHoBIiPMWPGeK7JycmJQyewwFpwAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATMf+V3Dh5TJkyxboFNCO9e/f2XDNjxow4dFLfli1boqqrqamJbSOIwB0QAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEyxGiqjt2bPHugXESTQLiy5ZssRzTceOHT3X7Nq1y3PNmDFjPNdIUnV1dVR1+HG4AwIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGDC55xz1k18XygUUiAQsG6jWXj//fc91wwZMiQOncROYmKidQst1mmnnea5ZsGCBVEda+TIkVHVefX55597rrniiis812zatMlzDU5cMBhUSkpKo/u5AwIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCxUibsWHDhnmueeWVVzzXNOX5/uCDDzzXRHOJLlmyxHONFN2ilVOmTPFc4/P5PNckJSV5rsnJyfFcI0k1NTWeax566CHPNQsXLvRcw8KiLQeLkQIAmiUCCABgwlMAFRcXa+DAgUpOTlZaWppGjRpV73a4pqZGRUVF6tixo0477TRdffXVqq6ujmnTAICWz1MAlZWVqaioSCtXrtS7776rw4cPa/jw4Tpw4EB4zsSJE/Xmm2/qtddeU1lZmXbs2KGrrroq5o0DAFq2U7xMXrp0acTX8+fPV1pamtauXashQ4YoGAzq2Wef1YsvvqjLLrtMkjRv3jyde+65WrlypS666KLYdQ4AaNFO6DOgYDAoSUpNTZUkrV27VocPH1Z+fn54Tu/evdW9e3etWLGiwdeora1VKBSKGACA1i/qAKqrq9Odd96piy++WH369JEkVVVVKSkpSR06dIiYm56erqqqqgZfp7i4WIFAIDy6desWbUsAgBYk6gAqKirShg0b9PLLL59QA1OnTlUwGAyPbdu2ndDrAQBaBk+fAX1nwoQJeuutt7R8+XJ17do1vD0jI0OHDh3S3r17I+6CqqurlZGR0eBr+f1++f3+aNoAALRgnu6AnHOaMGGCFi1apGXLlqlnz54R+wcMGKA2bdqopKQkvG3Tpk3aunWrBg0aFJuOAQCtgqc7oKKiIr344otasmSJkpOTw5/rBAIBtWvXToFAQDfddJMmTZqk1NRUpaSk6Pbbb9egQYN4Ag4AEMFTAM2dO1eSlJeXF7F93rx5GjdunCRp1qxZSkhI0NVXX63a2loVFBToySefjEmzAIDWg8VIW5lLL73Uc80bb7wR1bGi+XdKSPD+3EtdXZ3nmuauqc5DWVmZ5xpJWrBgQZPUoHVjMVIAQLNEAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDBathQly5doqq7+eabPdfcd999nmta42rYu3bt8lzzz3/+03PNb37zG8810tFVjIETxWrYAIBmiQACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkWI0WTGjt2rOeayZMne67p3bu35xpJ2rhxo+eaRx55xHNNRUWF55oPP/zQcw1gicVIAQDNEgEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQ
AMAEAQQAMEEAAQBMEEAAABMsRgoAiAsWIwUANEsEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDhKYCKi4s1cOBAJScnKy0tTaNGjdKmTZsi5uTl5cnn80WMW265JaZNAwBaPk8BVFZWpqKiIq1cuVLvvvuuDh8+rOHDh+vAgQMR88aPH6+dO3eGx4wZM2LaNACg5TvFy+SlS5dGfD1//nylpaVp7dq1GjJkSHj7qaeeqoyMjNh0CABolU7oM6BgMChJSk1Njdj+wgsvqFOnTurTp4+mTp2qgwcPNvoatbW1CoVCEQMAcBJwUTpy5Ii7/PLL3cUXXxyx/emnn3ZLly5169evd88//7zr0qWLGz16dKOvM336dCeJwWAwGK1sBIPBY+ZI1AF0yy23uB49erht27Ydc15JSYmT5MrLyxvcX1NT44LBYHhs27bN/KQxGAwG48TH8QLI02dA35kwYYLeeustLV++XF27dj3m3NzcXElSeXm5zjzzzHr7/X6//H5/NG0AAFowTwHknNPtt9+uRYsWqbS0VD179jxuzbp16yRJmZmZUTUIAGidPAVQUVGRXnzxRS1ZskTJycmqqqqSJAUCAbVr104VFRV68cUX9bOf/UwdO3bU+vXrNXHiRA0ZMkT9+vWLy18AANBCefncR428zzdv3jznnHNbt251Q4YMcampqc7v97tevXq5u++++7jvA35fMBg0f9+SwWAwGCc+jve93/f/wdJshEIhBQIB6zYAACcoGAwqJSWl0f2sBQcAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMNHsAsg5Z90CACAGjvf9vNkF0L59+6xbAADEwPG+n/tcM7vlqKur044dO5ScnCyfzxexLxQKqVu3btq2bZtSUlKMOrTHeTiK83AU5+EozsNRzeE8OOe0b98+ZWVlKSGh8fucU5qwpx8lISFBXbt2PeaclJSUk/oC+w7n4SjOw1Gch6M4D0dZn4dAIHDcOc3uLTgAwMmBAAIAmGhRAeT3+zV9+nT5/X7rVkxxHo7iPBzFeTiK83BUSzoPze4hBADAyaFF3QEBAFoPAggAYIIAAgCYIIAAACYIIACAiRYTQHPmzNEZZ5yhtm3bKjc3V6tXr7Zuqck98MAD8vl8EaN3797WbcXd8uXLdeWVVyorK0s+n0+LFy+O2O+c07Rp05SZmal27dopPz9fmzdvtmk2jo53HsaNG1fv+hgxYoRNs3FSXFysgQMHKjk5WWlpaRo1apQ2bdoUMaempkZFRUXq2LGjTjvtNF199dWqrq426jg+fsx5yMvLq3c93HLLLUYdN6xFBNArr7yiSZMmafr06fr444/Vv39/FRQUaNeuXdatNbnzzz9fO3fuDI8PPvjAuqW4O3DggPr37685c+Y0uH/GjBl67LHH9NRTT2nVqlVq3769CgoKVFNT08SdxtfxzoMkjRgxIuL6eOmll5qww/grKytTUVGRVq5cqXfffVeHDx/W8OHDdeDAgfCciRMn6s0339Rrr72msrIy7dixQ1dddZVh17H3Y86DJI0fPz7iepgxY4ZRx41wLUBOTo4rKioKf33kyBGXlZXliouLDbtqetOnT3f9+/e3bsOUJLdo0aLw13V1dS4jI8M98sgj4W179+51fr/fvfTSSwYdNo0fngfnnBs7dqwbOXKkST9Wdu3a5SS5srIy59zRf/s2bdq41157LTzns88+c5LcihUrrNqMux+eB+ecu/TSS90dd9xh19SP0OzvgA4dOqS1a9cqPz8/vC0hIUH5+flasWKFYWc2Nm/erKysLGVnZ+vGG2/U1q1brVsyVVlZ
qaqqqojrIxAIKDc396S8PkpLS5WWlqZzzjlHt956q/bs2WPdUlwFg0FJUmpqqiRp7dq1Onz4cMT10Lt3b3Xv3r1VXw8/PA/feeGFF9SpUyf16dNHU6dO1cGDBy3aa1SzWw37h3bv3q0jR44oPT09Ynt6ero2btxo1JWN3NxczZ8/X+ecc4527typBx98UJdccok2bNig5ORk6/ZMVFVVSVKD18d3+04WI0aM0FVXXaWePXuqoqJC9957rwoLC7VixQolJiZatxdzdXV1uvPOO3XxxRerT58+ko5eD0lJSerQoUPE3NZ8PTR0HiTp5z//uXr06KGsrCytX79e99xzjzZt2qSFCxcadhup2QcQ/qewsDD85379+ik3N1c9evTQq6++qptuusmwMzQH119/ffjPffv2Vb9+/XTmmWeqtLRUw4YNM+wsPoqKirRhw4aT4nPQY2nsPNx8883hP/ft21eZmZkaNmyYKioqdOaZZzZ1mw1q9m/BderUSYmJifWeYqmurlZGRoZRV81Dhw4ddPbZZ6u8vNy6FTPfXQNcH/VlZ2erU6dOrfL6mDBhgt566y29//77Eb8/LCMjQ4cOHdLevXsj5rfW66Gx89CQ3NxcSWpW10OzD6CkpCQNGDBAJSUl4W11dXUqKSnRoEGDDDuzt3//flVUVCgzM9O6FTM9e/ZURkZGxPURCoW0atWqk/762L59u/bs2dOqrg/nnCZMmKBFixZp2bJl6tmzZ8T+AQMGqE2bNhHXw6ZNm7R169ZWdT0c7zw0ZN26dZLUvK4H66cgfoyXX37Z+f1+N3/+fPef//zH3Xzzza5Dhw6uqqrKurUmddddd7nS0lJXWVnpPvzwQ5efn+86derkdu3aZd1aXO3bt8998skn7pNPPnGS3MyZM90nn3zivvjiC+ecc3/6059chw4d3JIlS9z69evdyJEjXc+ePd0333xj3HlsHes87Nu3z02ePNmtWLHCVVZWuvfee8/99Kc/dWeddZarqamxbj1mbr31VhcIBFxpaanbuXNneBw8eDA855ZbbnHdu3d3y5Ytc2vWrHGDBg1ygwYNMuw69o53HsrLy93vf/97t2bNGldZWemWLFnisrOz3ZAhQ4w7j9QiAsg55x5//HHXvXt3l5SU5HJyctzKlSutW2py1113ncvMzHRJSUmuS5cu7rrrrnPl5eXWbcXd+++/7yTVG2PHjnXOHX0U+/7773fp6enO7/e7YcOGuU2bNtk2HQfHOg8HDx50w4cPd507d3Zt2rRxPXr0cOPHj291/0lr6O8vyc2bNy8855tvvnG33XabO/30092pp57qRo8e7Xbu3GnXdBwc7zxs3brVDRkyxKWmpjq/3+969erl7r77bhcMBm0b/wF+HxAAwESz/wwIANA6EUAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMDE/wEgY5/vUu9oAQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot test\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "def plot_img(data, label, idx):\n",
    "    '''\n",
    "    Show sample `idx` of `data` as a grayscale image titled with its label.\n",
    "\n",
    "    Args:\n",
    "        data: indexable collection of images; `data[idx]` is passed to plt.imshow.\n",
    "        label: indexable collection of labels aligned with `data`.\n",
    "        idx: position of the sample to display.\n",
    "    '''\n",
    "    image = data[idx]\n",
    "    plt.imshow(image, cmap='gray')\n",
    "    image_label = label[idx]\n",
    "    plt.title(f'Label: {image_label}')\n",
    "    plt.show()\n",
    "\n",
    "# visual sanity check: display the training sample at index 1 together with its label\n",
    "plot_img(train_img, train_label, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shuffle dataset (saving happens in the next cell)\n",
    "def shuffle_with_A(data, Y):\n",
    "    '''\n",
    "    Apply one random permutation to `data` and `Y` along dim 0.\n",
    "\n",
    "    The same permutation is used for both tensors, so sample/label pairs\n",
    "    stay aligned.\n",
    "\n",
    "    Returns:\n",
    "        (shuffled data, shuffled Y)\n",
    "    '''\n",
    "    permutation = torch.randperm(data.size(0))\n",
    "    shuffled_data = data.index_select(0, permutation)\n",
    "    shuffled_Y = Y.index_select(0, permutation)\n",
    "    return shuffled_data, shuffled_Y\n",
    "\n",
    "# only the training split is shuffled; test_img / test_label keep their current order\n",
    "train_img, train_label = shuffle_with_A(train_img, train_label)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "class 0: 17190\n",
      "class 1: 17190\n",
      "class 2: 17190\n",
      "class 3: 17190\n",
      "class 4: 17190\n",
      "class 5: 16260\n",
      "class 6: 17190\n",
      "class 7: 17190\n",
      "class 8: 4290\n",
      "class 9: 17190\n"
     ]
    }
   ],
   "source": [
    "# save\n",
    "\n",
    "train_data = train_img\n",
    "train_Y = train_label\n",
    "\n",
    "# save data\n",
    "# The output directory is derived from the config cell so it cannot drift from the\n",
    "# data that was generated; with dataset='mnist', target_class=8, bias_ratio=4 it is\n",
    "# .../dataset/mnist/downsized/minor_8_bias_4.\n",
    "savedir = f'/home/.../pfgan_hub/dataset/{dataset}/downsized/minor_{target_class}_bias_{bias_ratio}'\n",
    "# torch.save does not create missing directories (`makedirs` is imported in the loading cell)\n",
    "makedirs(savedir, exist_ok=True)\n",
    "\n",
    "# _use_new_zipfile_serialization=False selects the legacy (non-zip) file format\n",
    "torch.save(train_data, path.join(savedir, 'train_data.pt'), _use_new_zipfile_serialization=False)\n",
    "torch.save(train_Y, path.join(savedir, 'train_Y.pt'), _use_new_zipfile_serialization=False)\n",
    "\n",
    "# save info: total size followed by the per-class counts (also echoed to stdout)\n",
    "with open(path.join(savedir, 'class_info.txt'), 'w') as f:\n",
    "    f.write(f'Total dataset size: {train_data.shape[0]}\\n')\n",
    "    f.write('===================\\n')\n",
    "    print_class_count(train_Y, out=f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pfgan",
   "language": "python",
   "name": "pfgan"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
