{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "75f09f61-7920-4579-b115-e179dc59ed6e",
   "metadata": {},
   "source": [
    "## Same with t2.ipynb, but for imbalanced case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "45dae1bc-e492-46c9-9c3c-0fbed305b6c4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torchvision.datasets import CIFAR10\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bf450914-2674-42c9-9f02-06ba3d40f2c6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "data_dir = \"<root_dir>\" # Change as wish\n",
    "c10_train = CIFAR10(root=data_dir, download=True, train=True, transform=None)\n",
    "c10_test = CIFAR10(root=data_dir, download=True, train=False, transform=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9c609fb0-1d95-456e-a4aa-663bddac5c60",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def pad_image(image,pad_size,left=True):\n",
    "    pad_one_end = np.random.randint(int(pad_size//4))\n",
    "    pad_other_end = int(pad_size - pad_one_end)\n",
    "    if left:\n",
    "        new_image = F.pad(image, (pad_other_end, pad_one_end, pad_other_end, pad_one_end),\"constant\", 0)\n",
    "    else: \n",
    "        new_image = F.pad(image, (pad_one_end, pad_other_end, pad_one_end, pad_other_end),\"constant\", 0)\n",
    "    return new_image\n",
    "\n",
    "def create_m1_samples(data, targets, class_size=5000, pad_size=32):\n",
    "    all_data = []\n",
    "    all_targets = []\n",
    "    for cla_idx in range(10):\n",
    "        cla_length = np.sum(targets==cla_idx).item()\n",
    "        # Some randomness\n",
    "        chosen_ones = np.random.permutation(cla_length)[:class_size]\n",
    "        cur_cla_data = data[targets==cla_idx][chosen_ones]\n",
    "        for i in range(class_size):\n",
    "            first_left = torch.rand(1) < 0.5\n",
    "            im = cur_cla_data[i]\n",
    "            cur_data = pad_image(torch.from_numpy(im).float().permute(2,0,1), pad_size, first_left)\n",
    "            all_data.append(cur_data)\n",
    "            \n",
    "            cur_target = torch.zeros(10)\n",
    "            cur_target[cla_idx] = 1.0\n",
    "            all_targets.append(cur_target)\n",
    "    return torch.stack(all_data, dim=0), torch.stack(all_targets, dim=0)\n",
    "\n",
    "def create_m2_samples(data, targets, class_size=5000, pad_size=28):\n",
    "    all_data = []\n",
    "    all_targets = []\n",
    "    for idx_1 in range(10):\n",
    "        for idx_2 in range(idx_1+1, 10):\n",
    "            cla_length_1 = np.sum(targets==idx_1).item()\n",
    "            cla_length_2 = np.sum(targets==idx_2).item()\n",
    "            # Some randomness\n",
    "            chosen_ones_1 = np.random.permutation(cla_length_1)[:class_size]\n",
    "            chosen_ones_2 = np.random.permutation(cla_length_2)[:class_size]\n",
    "            cla_data_1 = data[targets==idx_1][chosen_ones_1]\n",
    "            cla_data_2 = data[targets==idx_2][chosen_ones_2]\n",
    "        \n",
    "            for i in range(class_size):\n",
    "                first_left = torch.rand(1) < 0.5\n",
    "                cur_data_1 = pad_image(torch.from_numpy(cla_data_1[i]).float().permute(2,0,1), pad_size, left=first_left)\n",
    "                cur_data_2 = pad_image(torch.from_numpy(cla_data_2[i]).float().permute(2,0,1), pad_size, left=not first_left)\n",
    "                all_data.append(torch.maximum(cur_data_1, cur_data_2))\n",
    "\n",
    "                cur_target = torch.zeros(10)\n",
    "                cur_target[idx_1] = 0.5\n",
    "                cur_target[idx_2] = 0.5\n",
    "                all_targets.append(cur_target)\n",
    "    return torch.stack(all_data, dim=0), torch.stack(all_targets, dim=0)\n",
    "\n",
    "def create_m1_samples_imbalance(data, targets, total_class_size=[5000,100], pad_size=32):\n",
    "    all_data = []\n",
    "    all_targets = []\n",
    "    large_class_size, small_class_size = total_class_size\n",
    "    for cla_idx in range(10):\n",
    "        # Choose class size accordingly\n",
    "        if cla_idx < 5:\n",
    "            class_size = large_class_size\n",
    "        else:\n",
    "            class_size = small_class_size\n",
    "        cla_length = np.sum(targets==cla_idx).item()\n",
    "        # Some randomness\n",
    "        chosen_ones = np.random.permutation(cla_length)[:class_size]\n",
    "        cur_cla_data = data[targets==cla_idx][chosen_ones]\n",
    "        for i in range(class_size):\n",
    "            first_left = torch.rand(1) < 0.5\n",
    "            im = cur_cla_data[i]\n",
    "            cur_data = pad_image(torch.from_numpy(im).float().permute(2,0,1), pad_size, first_left)\n",
    "            all_data.append(cur_data)\n",
    "            \n",
    "            cur_target = torch.zeros(10)\n",
    "            cur_target[cla_idx] = 1.0\n",
    "            all_targets.append(cur_target)\n",
    "    return torch.stack(all_data, dim=0), torch.stack(all_targets, dim=0)\n",
    "\n",
    "def create_m2_samples_imbalance(data, targets, total_class_size=[500,50,10], pad_size=32):\n",
    "    all_data = []\n",
    "    all_targets = []\n",
    "    large_class_size, coincide_class_size, small_class_size = total_class_size\n",
    "    for idx_1 in range(10):\n",
    "        for idx_2 in range(idx_1+1, 10):\n",
    "            # Choose class size accordingly\n",
    "            if idx_1 < 5 and idx_2 < 5:\n",
    "                # Both from large class\n",
    "                class_size = large_class_size \n",
    "            elif idx_1 < 5 or idx_2 < 5:\n",
    "                # One from large class, one from small class\n",
    "                class_size = coincide_class_size\n",
    "            elif idx_1 >= 5 and idx_2 >= 5:\n",
    "                # Both from small class\n",
    "                class_size = small_class_size\n",
    "            else:\n",
    "                print(\"???\")\n",
    "                \n",
    "            cla_length_1 = np.sum(targets==idx_1).item()\n",
    "            cla_length_2 = np.sum(targets==idx_2).item()\n",
    "            # Some randomness\n",
    "            chosen_ones_1 = np.random.permutation(cla_length_1)[:class_size]\n",
    "            chosen_ones_2 = np.random.permutation(cla_length_2)[:class_size]\n",
    "            cla_data_1 = data[targets==idx_1][chosen_ones_1]\n",
    "            cla_data_2 = data[targets==idx_2][chosen_ones_2]\n",
    "        \n",
    "            for i in range(class_size):\n",
    "                first_left = torch.rand(1) < 0.5\n",
    "                cur_data_1 = pad_image(torch.from_numpy(cla_data_1[i]).float().permute(2,0,1), pad_size, left=first_left)\n",
    "                cur_data_2 = pad_image(torch.from_numpy(cla_data_2[i]).float().permute(2,0,1), pad_size, left=not first_left)\n",
    "                all_data.append(torch.maximum(cur_data_1, cur_data_2))\n",
    "\n",
    "                cur_target = torch.zeros(10)\n",
    "                cur_target[idx_1] = 0.5\n",
    "                cur_target[idx_2] = 0.5\n",
    "                all_targets.append(cur_target)\n",
    "    return torch.stack(all_data, dim=0), torch.stack(all_targets, dim=0)\n",
    "\n",
    "def create_dataset(num_samples, data, targets, pad_size=32):\n",
    "    m1_num, m2_num = num_samples\n",
    "    m1_data, m1_targets = create_m1_samples(data, targets, class_size=m1_num, pad_size=pad_size)\n",
    "    m2_data, m2_targets = create_m2_samples(data, targets, class_size=m2_num, pad_size=pad_size)\n",
    "    \n",
    "    return torch.cat([m1_data, m2_data], dim=0), torch.cat([m1_targets, m2_targets], dim=0)\n",
    "\n",
    "def create_dataset_imbalance(num_samples, data, targets, pad_size=32):\n",
    "    m1_num, m2_num = num_samples\n",
    "    m1_data, m1_targets = create_m1_samples_imbalance(data, targets, total_class_size=m1_num, pad_size=pad_size)\n",
    "    m2_data, m2_targets = create_m2_samples_imbalance(data, targets, total_class_size=m2_num, pad_size=pad_size)\n",
    "    \n",
    "    return torch.cat([m1_data, m2_data], dim=0), torch.cat([m1_targets, m2_targets], dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0eb4c6d2-2a65-4bd2-bcdf-f4aff5f201c4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainset, trainlabels = create_dataset_imbalance([[5000, 5000], [500,50,5]], c10_train.data / np.max(c10_train.data), \n",
    "                                       np.array(c10_train.targets), pad_size=32)\n",
    "testset, testlabels = create_dataset([800, 50], c10_test.data / np.max(c10_train.data), \n",
    "                                       np.array(c10_test.targets), pad_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "96f7d574-4a5a-4e28-9743-0cd7cae79a3b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2250.,  2250.,  2250.,  2250.,  2250.,  9270., 10270., 10270., 10270.,\n",
       "        10270.])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.sum(trainlabels[25500:], dim=0)*2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "268ab90f-f3a4-4d0f-8cf3-e1a0e2e32bd0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "to_save = {\n",
    "            \"train_data\": trainset,\n",
    "            \"train_label\": trainlabels,\n",
    "            \"test_data\": testset,\n",
    "            \"test_label\": testlabels\n",
    "            }\n",
    "    \n",
    "with open(\"<root_dir>/c10_imbalance_multip_combine.pkl\", 'wb') as f: \n",
    "    pickle.dump(to_save, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0956db7e-8e22-425d-8eaf-756436651cd9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
