{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Create a two-class MNIST/FashionMNIST dataset where a 90-degree rotation is the attribute.\n",
    "Bias can be created on the class label (denoted by Y), on rotation (denoted by A), or on both.\n",
    "\n",
    "Config:\n",
    "    dataset : 'fmnist', 'mnist'\n",
    "    bias_ratio : major/minor ratio (positive; the bias is bias_ratio:1).\n",
    "    bias_factor: 'None', 'Y', 'A', 'Both' (which variable(s) carry the bias)\n",
    "    class_list : two class labels (int) to keep (e.g. [2, 8] for MNIST).\n",
    "                 Order matters: [major, minor].\n",
    "    random_seed : random seed (default 0)\n",
    "'''\n",
    "\n",
    "\n",
    "import torch\n",
    "from numpy import random\n",
    "\n",
    "# config\n",
    "dataset = 'mnist'\n",
    "bias_ratio = 4\n",
    "bias_factor = 'Y'\n",
    "class_list = [1, 6]\n",
    "random_seed = 0\n",
    "\n",
    "# sanity-check the config before anything is downloaded or generated\n",
    "if bias_factor not in ('None', 'Y', 'A', 'Both'):\n",
    "    raise ValueError(f'bias_factor must be one of None/Y/A/Both, got {bias_factor!r}')\n",
    "if len(class_list) != 2 or class_list[0] == class_list[1]:\n",
    "    raise ValueError(f'class_list must be two distinct labels [major, minor], got {class_list}')\n",
    "if bias_ratio <= 0:\n",
    "    raise ValueError(f'bias_ratio must be positive, got {bias_ratio}')\n",
    "\n",
    "# random seeds (torch and numpy)\n",
    "torch.manual_seed(random_seed)\n",
    "random.seed(random_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as tf\n",
    "import torchvision.datasets as datasets\n",
    "from os import path, mkdir, makedirs\n",
    "\n",
    "\n",
    "\n",
    "# load data (downloaded into ./original on the first run)\n",
    "if not path.exists('original'):\n",
    "    mkdir('original')\n",
    "\n",
    "dataset_cls = {'mnist': datasets.MNIST, 'fmnist': datasets.FashionMNIST}.get(dataset)\n",
    "if dataset_cls is None:\n",
    "    raise NotImplementedError(\"Dataset should be mnist or fmnist\")\n",
    "\n",
    "# only the raw .data / .targets tensors are read later, so the ToTensor transform is never applied\n",
    "train = dataset_cls(root='./original', train=True, transform=tf.ToTensor(), download=True)\n",
    "test = dataset_cls(root='./original', train=False, transform=tf.ToTensor(), download=True)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def get_savename(bias_factor, bias_ratio, class_list, dataset):\n",
    "    '''\n",
    "    Return the save name of the dataset: '<dataset>_<classes>_<setting>'.\n",
    "\n",
    "    setting is 'Unbiased' for bias_factor 'None', otherwise '<bias_factor>_<bias_ratio>1',\n",
    "    e.g. get_savename('Y', 4, [1, 6], 'mnist') -> 'mnist_16_Y_41'.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if bias_factor is not 'None', 'Y', 'A' or 'Both'.\n",
    "    '''\n",
    "    if bias_factor == 'None':\n",
    "        setting = 'Unbiased'\n",
    "    elif bias_factor in ('Y', 'A', 'Both'):\n",
    "        setting = f'{bias_factor}_{bias_ratio}1'\n",
    "    else:\n",
    "        raise NotImplementedError(f'Unknown bias_factor {bias_factor!r}; expected None, Y, A or Both')\n",
    "\n",
    "    # class labels are concatenated without a separator, e.g. [1, 6] -> '16'\n",
    "    digit_name = ''.join(str(digit) for digit in class_list)\n",
    "    savename = f'{dataset}_{digit_name}_{setting}'\n",
    "    print(f'Savename of the dataset: {savename}')\n",
    "\n",
    "    return savename\n",
    "\n",
    "savename = get_savename(bias_factor=bias_factor, bias_ratio=bias_ratio, class_list=class_list, dataset=dataset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create bias on Y \n",
    "\n",
    "# ====== Functions =========================================================\n",
    "def filter_class_with_bias(data, targets, class_list, ratio = 4):\n",
    "    '''\n",
    "    Filter the two classes of class_list ([major, minor]) and return a (ratio:1) biased subset.\n",
    "\n",
    "    The minor class is cut to int(n_major / ratio) samples. If the minor class has fewer\n",
    "    samples than that, it is kept whole and the major class is cut to int(n_minor * ratio)\n",
    "    instead, so the requested ratio still holds (e.g. ratio=1 really gives a 1:1 split).\n",
    "    Samples are taken in dataset order (first N of each class).\n",
    "\n",
    "    Returns:\n",
    "        major_class, major_class_label, minor_class, minor_class_label\n",
    "    '''\n",
    "    # as_tensor keeps the incoming dtype (e.g. integer labels) for non-tensor input\n",
    "    data = torch.as_tensor(data)\n",
    "    targets = torch.as_tensor(targets)\n",
    "\n",
    "    mask_minor = (targets == class_list[1])\n",
    "    mask_major = (targets == class_list[0])\n",
    "    n_major_avail = int(mask_major.sum())\n",
    "    n_minor_avail = int(mask_minor.sum())\n",
    "\n",
    "    print(f'\\nCreated dataset with {ratio}:1 ratio on Y')\n",
    "    print('original count of major: ', n_major_avail)\n",
    "    print('original count of minor: ', n_minor_avail)\n",
    "\n",
    "    num_major = n_major_avail\n",
    "    num_minor = int(n_major_avail / ratio)\n",
    "    if num_minor > n_minor_avail:\n",
    "        # not enough minor samples: shrink the major class rather than silently break the ratio\n",
    "        num_minor = n_minor_avail\n",
    "        num_major = min(n_major_avail, int(n_minor_avail * ratio))\n",
    "\n",
    "    major_class = data[mask_major][:num_major]\n",
    "    major_class_label = targets[mask_major][:num_major]\n",
    "    minor_class = data[mask_minor][:num_minor]\n",
    "    minor_class_label = targets[mask_minor][:num_minor]\n",
    "\n",
    "    print('count of major: ', major_class_label.shape[0])\n",
    "    print('count of minor: ', minor_class_label.shape[0])\n",
    "    print('\\n')\n",
    "\n",
    "    return major_class, major_class_label, minor_class, minor_class_label\n",
    "# ============================================================================\n",
    "\n",
    "\n",
    "# filter class and create bias on train/test dataset\n",
    "# Y is only biased when bias_factor is 'Y' or 'Both'; otherwise a 1:1 ratio is requested\n",
    "Y_bias = bias_ratio if bias_factor in ('Y', 'Both') else 1\n",
    "\n",
    "train_img, train_label = train.data, train.targets\n",
    "test_img, test_label = test.data, test.targets\n",
    "\n",
    "(train_data_major, train_Y_major,\n",
    " train_data_minor, train_Y_minor) = filter_class_with_bias(train_img, train_label, class_list, Y_bias)\n",
    "\n",
    "(test_data_major, test_Y_major,\n",
    " test_data_minor, test_Y_minor) = filter_class_with_bias(test_img, test_label, class_list, Y_bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# just for now\n",
    "def rot_aug(train, Y):\n",
    "    '''\n",
    "    Augment 3x with small random rotations: [original, clockwise, counter-clockwise].\n",
    "\n",
    "    RandomRotation draws one angle per call, so calling it on the whole stack would\n",
    "    rotate every image by the same angle; each image is transformed on its own instead.\n",
    "\n",
    "    Returns:\n",
    "        3x augmented data and Y repeated 3 times (same order as the data)\n",
    "    '''\n",
    "    # torchvision: positive angles rotate counter-clockwise, negative angles clockwise\n",
    "    transforms_rotate_r = tf.RandomRotation(degrees=(5,10))\n",
    "    transforms_rotate_l = tf.RandomRotation(degrees=(-10, -5))\n",
    "\n",
    "    def per_sample(transform, batch):\n",
    "        # one transform call per image -> independent random angle per image\n",
    "        if batch.shape[0] == 0:\n",
    "            return batch.clone()\n",
    "        return torch.stack([transform(img.unsqueeze(0)).squeeze(0) for img in batch])\n",
    "\n",
    "    r_rot_data = per_sample(transforms_rotate_r, train)\n",
    "    l_rot_data = per_sample(transforms_rotate_l, train)\n",
    "\n",
    "    data_3x = torch.cat([train, l_rot_data, r_rot_data], dim=0)\n",
    "    data_3x_Y = Y.repeat(3)\n",
    "\n",
    "    return data_3x, data_3x_Y\n",
    "\n",
    "train_data = torch.cat([train_data_minor, train_data_major], dim=0)\n",
    "train_Y = torch.cat([train_Y_minor, train_Y_major], dim=0)\n",
    "train_data, train_Y = rot_aug(train_data, train_Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shift_aug(train, Y):\n",
    "    '''\n",
    "    Augment 2x with a small random translation: [original, shifted].\n",
    "\n",
    "    RandomAffine draws one translation per call, so each image is shifted on its own\n",
    "    to get an independent offset (up to 10% of width/height) instead of one shared shift.\n",
    "\n",
    "    Returns:\n",
    "        2x augmented data and Y repeated 2 times (same order as the data)\n",
    "    '''\n",
    "    transforms_shift = tf.RandomAffine(degrees=0, translate=(0.1,0.1))\n",
    "    if train.shape[0] == 0:\n",
    "        shift_data = train.clone()\n",
    "    else:\n",
    "        # one transform call per image -> independent random offset per image\n",
    "        shift_data = torch.stack([transforms_shift(img.unsqueeze(0)).squeeze(0) for img in train])\n",
    "\n",
    "    data_2x = torch.cat([train, shift_data], dim=0)\n",
    "    data_2x_Y = Y.repeat(2)\n",
    "\n",
    "    return data_2x, data_2x_Y\n",
    "\n",
    "train_data, train_Y = shift_aug(train_data, train_Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rand_sample_with_A(data, Y, size):\n",
    "    '''\n",
    "    Pick `size` random rows (without replacement) from data and the matching entries of Y.\n",
    "    '''\n",
    "    picked = torch.randperm(data.size(0))[:size]\n",
    "    return data.index_select(0, picked), Y.index_select(0, picked)\n",
    "\n",
    "# cap the train set at 60k samples\n",
    "if train_data.size(0) > 60000:\n",
    "    train_data, train_Y = rand_sample_with_A(train_data, train_Y, 60000)\n",
    "\n",
    "# NOTE(review): same directory and file names as the final save cell, which later\n",
    "# overwrites train_data.pt / train_Y.pt with the version that also carries A\n",
    "savedir = path.join('rotated', savename)\n",
    "if not path.exists(savedir):\n",
    "    makedirs(savedir)\n",
    "\n",
    "for tensor, filename in ((train_data, f'train_data.pt'), (train_Y, f'train_Y.pt')):\n",
    "    torch.save(tensor, path.join(savedir, filename), _use_new_zipfile_serialization=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create bias on A (90-degree rotation)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# ====== Functions =========================================================\n",
    "def plot_img(data, Y, A, num):\n",
    "    '''\n",
    "    Show `num` randomly chosen images (drawn with replacement), titled with their Y and A.\n",
    "\n",
    "    Args:\n",
    "        num: number of plots\n",
    "    '''\n",
    "    picks = torch.randint(0, data.shape[0], (num,))\n",
    "    for i in picks:\n",
    "        plt.imshow(data[i], cmap='gray')\n",
    "        plt.title(f'Y={Y[i]} A={A[i]}')\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "def rotate_90(data, bias):\n",
    "    '''\n",
    "    Rotate the first N // (bias + 1) samples 90 degrees counterclockwise; keep the rest clean.\n",
    "\n",
    "    This gives roughly a bias:1 clean-to-rotated ratio. Sample order is preserved.\n",
    "\n",
    "    Returns:\n",
    "        (data with the rotated part first, A) where\n",
    "        A = 1: clean\n",
    "        A = 0: rotated\n",
    "    '''\n",
    "    total = data.shape[0]\n",
    "    n_rot = total // (bias + 1)\n",
    "    rotated = tf.functional.rotate(data[:n_rot], 90)\n",
    "    A = torch.cat([torch.zeros((n_rot,)), torch.ones((total - n_rot,))])\n",
    "    print(f'Created dataset with {bias}:1 ratio on A')\n",
    "    print(f'Clean: {total - n_rot}\\nRotated: {n_rot}\\n')\n",
    "    return torch.cat([rotated, data[n_rot:]], dim=0), A\n",
    "\n",
    "# ============================================================================\n",
    "\n",
    "# create rotation bias on train/test dataset\n",
    "# A is only biased when bias_factor is 'A' or 'Both'; otherwise a 1:1 clean/rotated split\n",
    "A_bias = bias_ratio if bias_factor in ('A', 'Both') else 1\n",
    "\n",
    "(train_data_major, train_A_major), (train_data_minor, train_A_minor) = (\n",
    "    rotate_90(train_data_major, A_bias),\n",
    "    rotate_90(train_data_minor, A_bias),\n",
    ")\n",
    "\n",
    "(test_data_major, test_A_major), (test_data_minor, test_A_minor) = (\n",
    "    rotate_90(test_data_major, A_bias),\n",
    "    rotate_90(test_data_minor, A_bias),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# augment 6x to mimic original MNIST/FMNIST size \n",
    "\n",
    "# ====== Functions =========================================================\n",
    "def rot_aug(train, Y, A):\n",
    "    '''\n",
    "    Augment 3x with small random rotations: [original, clockwise, counter-clockwise].\n",
    "    Y and A are repeated to stay aligned. Includes plot test of random data.\n",
    "\n",
    "    RandomRotation draws one angle per call, so calling it on the whole stack would\n",
    "    rotate every image by the same angle; each image is transformed on its own instead.\n",
    "\n",
    "    Returns:\n",
    "        3x augmented data, Y, A\n",
    "    '''\n",
    "    # torchvision: positive angles rotate counter-clockwise, negative angles clockwise\n",
    "    transforms_rotate_r = tf.RandomRotation(degrees=(5,10))\n",
    "    transforms_rotate_l = tf.RandomRotation(degrees=(-10, -5))\n",
    "\n",
    "    def per_sample(transform, batch):\n",
    "        # one transform call per image -> independent random angle per image\n",
    "        if batch.shape[0] == 0:\n",
    "            return batch.clone()\n",
    "        return torch.stack([transform(img.unsqueeze(0)).squeeze(0) for img in batch])\n",
    "\n",
    "    r_rot_data = per_sample(transforms_rotate_r, train)\n",
    "    l_rot_data = per_sample(transforms_rotate_l, train)\n",
    "\n",
    "    data_3x = torch.cat([train, l_rot_data, r_rot_data], dim=0)\n",
    "    data_3x_Y = Y.repeat(3)\n",
    "    data_3x_A = A.repeat(3)\n",
    "    print(f'With Rotation, Augmented data size: {data_3x.shape[0]}')\n",
    "\n",
    "    def plot_rot_aug(i):\n",
    "        '''\n",
    "        Plot test for Rotation Augmentation.\n",
    "        '''\n",
    "        fig = plt.figure()\n",
    "        rows, cols = 1 , 3\n",
    "\n",
    "        ax1 = fig.add_subplot(rows, cols, 1)\n",
    "        ax1.set_title(f'original')\n",
    "        ax1.imshow(train[i].squeeze(), cmap='gray')\n",
    "\n",
    "        # l_rot_data is rotated clockwise (right), r_rot_data counter-clockwise (left)\n",
    "        ax2 = fig.add_subplot(rows, cols, 2)\n",
    "        ax2.set_title('rotate_right')\n",
    "        ax2.imshow(l_rot_data[i].squeeze(), cmap='gray')\n",
    "\n",
    "        ax3 = fig.add_subplot(rows, cols, 3)\n",
    "        ax3.set_title('rotate_left')\n",
    "        ax3.imshow(r_rot_data[i].squeeze(), cmap='gray')\n",
    "\n",
    "        fig.tight_layout()\n",
    "        plt.show()\n",
    "\n",
    "    plot_rot_aug(torch.randint(0, train.shape[0], (1,)))\n",
    "\n",
    "    return data_3x, data_3x_Y, data_3x_A\n",
    "\n",
    "\n",
    "def shift_aug(train, Y, A):\n",
    "    '''\n",
    "    Augment 2x with a small random translation: [original, shifted].\n",
    "    Y and A are repeated to stay aligned. Includes plot test of random data.\n",
    "\n",
    "    RandomAffine draws one translation per call, so each image is shifted on its own\n",
    "    to get an independent offset (up to 10% of width/height) instead of one shared shift.\n",
    "\n",
    "    Returns:\n",
    "        2x augmented data, Y, A\n",
    "    '''\n",
    "    transforms_shift = tf.RandomAffine(degrees=0, translate=(0.1,0.1))\n",
    "    if train.shape[0] == 0:\n",
    "        shift_data = train.clone()\n",
    "    else:\n",
    "        # one transform call per image -> independent random offset per image\n",
    "        shift_data = torch.stack([transforms_shift(img.unsqueeze(0)).squeeze(0) for img in train])\n",
    "\n",
    "    data_2x = torch.cat([train, shift_data], dim=0)\n",
    "    data_2x_Y = Y.repeat(2)\n",
    "    data_2x_A = A.repeat(2)\n",
    "    print(f'With Random Shift, Augmented data size: {data_2x.shape[0]}')\n",
    "\n",
    "    def plot_shift_aug(i):\n",
    "        '''\n",
    "        Plot test for Shift Augmentation\n",
    "        '''\n",
    "        fig = plt.figure()\n",
    "        rows, cols = 1 , 2\n",
    "\n",
    "        ax1 = fig.add_subplot(rows, cols, 1)\n",
    "        ax1.set_title(f'original')\n",
    "        ax1.imshow(train[i].squeeze(), cmap='gray')\n",
    "\n",
    "        ax2 = fig.add_subplot(rows, cols, 2)\n",
    "        ax2.set_title('shifted')\n",
    "        ax2.imshow(shift_data[i].squeeze(), cmap='gray')\n",
    "\n",
    "        fig.tight_layout()\n",
    "        plt.show()\n",
    "\n",
    "    plot_shift_aug(torch.randint(0, train.shape[0], (1,)))\n",
    "\n",
    "    return data_2x, data_2x_Y, data_2x_A\n",
    "\n",
    "def shuffle_with_A(data, Y, A):\n",
    "    '''\n",
    "    Shuffle the whole dataset: data, Y and A share one random permutation, so rows stay aligned.\n",
    "    '''\n",
    "    perm = torch.randperm(data.size(0))\n",
    "    return tuple(t.index_select(0, perm) for t in (data, Y, A))\n",
    "# ============================================================================\n",
    "\n",
    "\n",
    "# putting everything together: concatenate minor + major groups, then shuffle\n",
    "train_data, train_Y, train_A = shuffle_with_A(\n",
    "    torch.cat([train_data_minor, train_data_major], dim=0),\n",
    "    torch.cat([train_Y_minor, train_Y_major], dim=0),\n",
    "    torch.cat([train_A_minor, train_A_major], dim=0),\n",
    ")\n",
    "test_data, test_Y, test_A = shuffle_with_A(\n",
    "    torch.cat([test_data_minor, test_data_major], dim=0),\n",
    "    torch.cat([test_Y_minor, test_Y_major], dim=0),\n",
    "    torch.cat([test_A_minor, test_A_major], dim=0),\n",
    ")\n",
    "\n",
    "# augment 6x (3x rotation, then 2x shift);\n",
    "# if bias appears on both Y and A, shift once more (12x in total)\n",
    "n_shift_rounds = 2 if bias_factor == 'Both' else 1\n",
    "\n",
    "train_data, train_Y, train_A = rot_aug(train_data, train_Y, train_A)\n",
    "for _round in range(n_shift_rounds):\n",
    "    train_data, train_Y, train_A = shift_aug(train_data, train_Y, train_A)\n",
    "\n",
    "test_data, test_Y, test_A = rot_aug(test_data, test_Y, test_A)\n",
    "for _round in range(n_shift_rounds):\n",
    "    test_data, test_Y, test_A = shift_aug(test_data, test_Y, test_A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# final test. Check datasets are well-made.\n",
    "\n",
    "# ====== Functions =========================================================\n",
    "def print_data_info(Y, A, class_list):\n",
    "    '''\n",
    "    Print the size of every (class, clean/rotated) group; A == 1 is clean, A == 0 rotated.\n",
    "    class_list is [major, minor].\n",
    "    '''\n",
    "    major, minor = class_list\n",
    "    is_clean, is_rotated = (A == 1), (A == 0)\n",
    "    counts = []\n",
    "    for label in (major, minor):\n",
    "        in_class = (Y == label)\n",
    "        counts.append(((in_class & is_clean).sum(), (in_class & is_rotated).sum()))\n",
    "    (cln_major, rot_major), (cln_minor, rot_minor) = counts\n",
    "\n",
    "    print(f'Total Dataset Size: {Y.shape[0]}')\n",
    "    print(f'\\tMajor Class {major}: {cln_major+rot_major}')\n",
    "    print(f'\\t\\tClean: {cln_major}')\n",
    "    print(f'\\t\\tRotated: {rot_major}')\n",
    "    print(f'\\tMinor Class {minor}: {cln_minor+rot_minor}')\n",
    "    print(f'\\t\\tClean: {cln_minor}')\n",
    "    print(f'\\t\\tRotated: {rot_minor}')\n",
    "# ============================================================================\n",
    "\n",
    "# optional visual check that Y and A match the images (set to True to plot)\n",
    "PLOT_SAMPLES = False\n",
    "if PLOT_SAMPLES:\n",
    "    plot_img(train_data, train_Y, train_A, 5)\n",
    "    plot_img(test_data, test_Y, test_A, 5)\n",
    "\n",
    "# data, Y and A must stay aligned and only hold the expected values\n",
    "for split_data, split_Y, split_A in ((train_data, train_Y, train_A), (test_data, test_Y, test_A)):\n",
    "    assert split_data.shape[0] == split_Y.shape[0] == split_A.shape[0], 'data/Y/A length mismatch'\n",
    "    assert set(split_Y.unique().tolist()) <= set(class_list), 'unexpected class label in Y'\n",
    "    assert set(split_A.unique().tolist()) <= {0, 1}, 'A must be 0 (rotated) or 1 (clean)'\n",
    "\n",
    "# check final group size with bias is correct (uses the configured class_list)\n",
    "print_data_info(train_Y, train_A, class_list)\n",
    "print_data_info(test_Y, test_A, class_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# random sample to mimic original dataset size of MNIST/FashionMNIST.\n",
    "# (print_data_info comes from the final-test cell above; it is not redefined here)\n",
    "\n",
    "def rand_sample_with_A(data, Y, A, size):\n",
    "    '''\n",
    "    Randomly sample `size` rows (without replacement) from data, Y and A with one shared permutation.\n",
    "\n",
    "    If the dataset has fewer than `size` rows, every row is kept (still shuffled) and a\n",
    "    message is printed, so the output order never depends on how the data was built.\n",
    "\n",
    "    Returns:\n",
    "        data, Y, A restricted to (and ordered by) the sampled rows\n",
    "    '''\n",
    "    if data.size(0) < size:\n",
    "        # keep all rows, but still shuffle them below\n",
    "        print(\"Target size is bigger than current dataset size\")\n",
    "\n",
    "    indices = torch.randperm(data.size(0))[:size]\n",
    "    return torch.index_select(data, dim=0, index=indices), \\\n",
    "        torch.index_select(Y, dim=0, index=indices), torch.index_select(A, dim=0, index=indices)\n",
    "\n",
    "\n",
    "# 60k for train data, 10k for test data; each split is sampled on its own\n",
    "# (rand_sample_with_A leaves a split at its current size when it is already smaller)\n",
    "train_data, train_Y, train_A = rand_sample_with_A(train_data, train_Y, train_A, 60000)\n",
    "test_data, test_Y, test_A = rand_sample_with_A(test_data, test_Y, test_A, 10000)\n",
    "\n",
    "# check final group size with bias is correct\n",
    "print_data_info(train_Y, train_A, class_list)\n",
    "print_data_info(test_Y, test_A, class_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save data\n",
    "# data: (data size x 28 x 28)\n",
    "# Y : (data size,)\n",
    "# A : (data size,)\n",
    "\n",
    "\n",
    "# save directory: rotated/<savename>/\n",
    "savedir = path.join('rotated', savename)\n",
    "if not path.exists(savedir):\n",
    "    makedirs(savedir)\n",
    "\n",
    "# train tensors first, then test tensors; non-zip format via _use_new_zipfile_serialization=False\n",
    "to_save = (\n",
    "    (train_data, f'train_data.pt'),\n",
    "    (train_Y, f'train_Y.pt'),\n",
    "    (train_A, f'train_A.pt'),\n",
    "    (test_data, f'test_data.pt'),\n",
    "    (test_Y, f'test_Y.pt'),\n",
    "    (test_A, f'test_A.pt'),\n",
    ")\n",
    "for tensor, filename in to_save:\n",
    "    torch.save(tensor, path.join(savedir, filename), _use_new_zipfile_serialization=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pfgan",
   "language": "python",
   "name": "pfgan"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
