{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import pickle\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import numpy as np\n",
    "import os\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "file = open(\"data/VGG-Face2/train_list.txt\", \"r\")\n",
    "train_txt = file.read().splitlines()\n",
    "file.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3141890"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels = []\n",
    "images = []\n",
    "for i in range(len(train_txt)):\n",
    "    entry = train_txt[i].split('/')\n",
    "    labels.append(entry[0])\n",
    "    #images.append(entry[1])\n",
    "\n",
    "len(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "792"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import Counter\n",
    "label_counts = Counter(labels)\n",
    "labels = np.array(list(label_counts.keys()))\n",
    "counts = np.array(list(label_counts.values()))\n",
    "large_classes = labels[counts > 500]\n",
    "len(large_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data_root = \"data/VGG-Face2/vggface2_train/train/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_pickle(\"MAAD_Face_1.0.pkl\")\n",
    "#df = pd.read_csv(\"MAAD_Face.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#data preprocessing\n",
    "df.at[250634, 'Identity'] = 713\n",
    "df[[\"Identity\", \"Male\"]] = df[[\"Identity\", \"Male\"]].apply(pd.to_numeric)\n",
    "\n",
    "folder_list = []\n",
    "for i in range(len(df)):\n",
    "    folder_list.append(df['Filename'][i].split('/')[0])\n",
    "\n",
    "df['Folder'] = folder_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Filename</th>\n",
       "      <th>Identity</th>\n",
       "      <th>Male</th>\n",
       "      <th>Young</th>\n",
       "      <th>Middle_Aged</th>\n",
       "      <th>Senior</th>\n",
       "      <th>Asian</th>\n",
       "      <th>White</th>\n",
       "      <th>Black</th>\n",
       "      <th>Rosy_Cheeks</th>\n",
       "      <th>...</th>\n",
       "      <th>Pointy_Nose</th>\n",
       "      <th>Heavy_Makeup</th>\n",
       "      <th>Wearing_Hat</th>\n",
       "      <th>Wearing_Earrings</th>\n",
       "      <th>Wearing_Necktie</th>\n",
       "      <th>Wearing_Lipstick</th>\n",
       "      <th>No_Eyewear</th>\n",
       "      <th>Eyeglasses</th>\n",
       "      <th>Attractive</th>\n",
       "      <th>Folder</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>n000002/0001_01.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>n000002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>n000002/0002_01.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>n000002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>n000002/0003_01.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>n000002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>n000002/0004_01.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>n000002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>n000002/0005_01.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>n000002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3308035</th>\n",
       "      <td>n009294/0284_01.jpg</td>\n",
       "      <td>9294</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>...</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>n009294</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3308036</th>\n",
       "      <td>n009294/0287_01.jpg</td>\n",
       "      <td>9294</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>...</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>n009294</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3308037</th>\n",
       "      <td>n009294/0289_01.jpg</td>\n",
       "      <td>9294</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>...</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>n009294</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3308038</th>\n",
       "      <td>n009294/0290_02.jpg</td>\n",
       "      <td>9294</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>...</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>n009294</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3308039</th>\n",
       "      <td>n009294/0291_01.jpg</td>\n",
       "      <td>9294</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>...</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>-1</td>\n",
       "      <td>n009294</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>3308040 rows × 50 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                    Filename  Identity  Male  Young  Middle_Aged  Senior  \\\n",
       "0        n000002/0001_01.jpg         2    -1      1           -1      -1   \n",
       "1        n000002/0002_01.jpg         2    -1      1           -1      -1   \n",
       "2        n000002/0003_01.jpg         2    -1      1           -1      -1   \n",
       "3        n000002/0004_01.jpg         2    -1      1           -1      -1   \n",
       "4        n000002/0005_01.jpg         2    -1      1           -1      -1   \n",
       "...                      ...       ...   ...    ...          ...     ...   \n",
       "3308035  n009294/0284_01.jpg      9294    -1      0           -1      -1   \n",
       "3308036  n009294/0287_01.jpg      9294    -1     -1           -1      -1   \n",
       "3308037  n009294/0289_01.jpg      9294    -1     -1           -1      -1   \n",
       "3308038  n009294/0290_02.jpg      9294    -1     -1           -1      -1   \n",
       "3308039  n009294/0291_01.jpg      9294    -1     -1           -1      -1   \n",
       "\n",
       "         Asian  White  Black  Rosy_Cheeks  ...  Pointy_Nose  Heavy_Makeup  \\\n",
       "0           -1      1     -1            0  ...            1             1   \n",
       "1           -1      1     -1            0  ...            1             1   \n",
       "2           -1      1     -1            0  ...            1             1   \n",
       "3           -1      1     -1            0  ...            1             1   \n",
       "4           -1      1     -1            0  ...            1             1   \n",
       "...        ...    ...    ...          ...  ...          ...           ...   \n",
       "3308035      1     -1     -1           -1  ...           -1            -1   \n",
       "3308036      1     -1     -1           -1  ...           -1            -1   \n",
       "3308037      1     -1     -1           -1  ...           -1            -1   \n",
       "3308038      1     -1     -1           -1  ...           -1            -1   \n",
       "3308039      1     -1     -1           -1  ...           -1            -1   \n",
       "\n",
       "         Wearing_Hat  Wearing_Earrings  Wearing_Necktie  Wearing_Lipstick  \\\n",
       "0                 -1                 1               -1                 1   \n",
       "1                 -1                 1               -1                 1   \n",
       "2                 -1                 1               -1                 1   \n",
       "3                 -1                 1               -1                 1   \n",
       "4                 -1                 1               -1                 1   \n",
       "...              ...               ...              ...               ...   \n",
       "3308035           -1                 1               -1                 1   \n",
       "3308036           -1                 1               -1                 1   \n",
       "3308037           -1                 0               -1                 1   \n",
       "3308038           -1                 0               -1                 1   \n",
       "3308039           -1                 1               -1                 1   \n",
       "\n",
       "         No_Eyewear  Eyeglasses  Attractive   Folder  \n",
       "0                 1          -1           1  n000002  \n",
       "1                 1          -1           1  n000002  \n",
       "2                 1          -1           1  n000002  \n",
       "3                 1          -1           1  n000002  \n",
       "4                 1          -1           1  n000002  \n",
       "...             ...         ...         ...      ...  \n",
       "3308035           1          -1          -1  n009294  \n",
       "3308036           1          -1          -1  n009294  \n",
       "3308037           0          -1          -1  n009294  \n",
       "3308038           1          -1          -1  n009294  \n",
       "3308039           1          -1          -1  n009294  \n",
       "\n",
       "[3308040 rows x 50 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "large_classes_ids = []\n",
    "for i in range(len(large_classes)):\n",
    "    large_classes_ids.append(int(large_classes[i][3:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "large_classes_ids = np.array(large_classes_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_balanced_classes(Nclasses, large_classes_ids, attribute='Male'):\n",
    "    #theoretically should throw error if replace=False and the sample size is greater than the population size\n",
    "    labels = []\n",
    "\n",
    "    for i in tqdm(range(len(large_classes_ids))):\n",
    "        label = df[df['Identity'] == large_classes_ids[i]][attribute].mode() #gets the most common label for this identity\n",
    "        label = label.to_list()[0]\n",
    "        labels.append(label)\n",
    "    \n",
    "    labels=np.array(labels)\n",
    "\n",
    "    class_1 = large_classes_ids[labels == -1]\n",
    "    class_2 = large_classes_ids[labels == 1]\n",
    "\n",
    "    class_1_subsampled = np.random.choice(class_1, Nclasses, replace=False)\n",
    "    class_2_subsampled = np.random.choice(class_2, Nclasses, replace=False) #NEED TO RERUN THIS CODE NOW THAT I REPLACE THINGS AHHHHHHH\n",
    "\n",
    "    #filtered_df = df[df['Identity'].isin(large_classes_ids)]\n",
    "    #class_1_df = filtered_df[filtered_df[attribute] == -1]\n",
    "    #class_2_df = filtered_df[filtered_df[attribute] == 1]\n",
    "    #class_1_subsampled = class_1_df['Identity'].sample(n=Nclasses).to_list()\n",
    "    #class_2_subsampled = class_2_df['Identity'].sample(n=Nclasses).to_list()\n",
    "    \n",
    "    train_identities = np.concatenate((class_1_subsampled[0:int(Nclasses/2)], class_2_subsampled[0:int(Nclasses/2)]))\n",
    "    test_identities = np.concatenate((class_1_subsampled[int(Nclasses/2):], class_2_subsampled[int(Nclasses/2):]))\n",
    "    train_labels = [0] * int(Nclasses/2) + [1] * int(Nclasses/2) #output 0s and 1s instead of -1 and 1\n",
    "    return(train_identities, test_identities, train_labels)\n",
    "\n",
    "\n",
    "        \n",
    "\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 792/792 [00:01<00:00, 445.57it/s]\n"
     ]
    }
   ],
   "source": [
    "#lacuna10_train_ids, lacuna10_ood_ids, lacuna10_train_labels = select_balanced_classes(10, large_classes_ids, attribute='Male')\n",
    "lacuna100_train_ids, lacuna100_ood_ids, lacuna100_train_labels = select_balanced_classes(100, large_classes_ids, attribute='Male')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#selected_classes = np.random.choice(large_classes, 110)\n",
    "\n",
    "#lacuna100 = selected_classes[:100]\n",
    "#lacuna10 = selected_classes[:10]\n",
    "\n",
    "\n",
    "\n",
    "def get_image(image_path, resize_to=None):\n",
    "    image = Image.open(image_path)\n",
    "    if resize_to is not None:\n",
    "        image = image.resize(resize_to)\n",
    "\n",
    "    #image = np.expand_dims(image, axis=0)\n",
    "    return np.array(image)\n",
    "\n",
    "def make_train_test_dataset(data_root, classes, labels, resize_to=None, num_samples=500, dest=\"data/lacuna100binary128/in_distribution\"): #classes is list of identities, labels is binary label\n",
    "\n",
    "    try:\n",
    "        os.makedirs(os.path.join(dest,'train'))\n",
    "        os.makedirs(os.path.join(dest,'test'))\n",
    "    except:\n",
    "        pass\n",
    "\n",
    "    dataset_train = None\n",
    "    targets_train = None\n",
    "    identities_train = None\n",
    "    dataset_test = None\n",
    "    targets_test = None\n",
    "    identities_test = None\n",
    "\n",
    "    N_train = 0\n",
    "    N_test = 0\n",
    "\n",
    "    for idx in tqdm(range(len(classes))):\n",
    "        identity = classes[idx]\n",
    "        \n",
    "        folder = df[df['Identity'] == identity]['Folder'].iloc[0]\n",
    "\n",
    "        images = []\n",
    "        for fil in os.listdir(os.path.join(train_data_root, folder)):\n",
    "            if fil.endswith('.jpg'):\n",
    "                images.append(fil) #filenames\n",
    "        selected_images = np.random.choice(images, num_samples)\n",
    "\n",
    "        #train test split\n",
    "        selected_images_train = selected_images[:400]\n",
    "        selected_images_test = selected_images[400:]\n",
    "\n",
    "\n",
    "        identity = np.array([identity])\n",
    "        identity = np.expand_dims(identity, axis=0)\n",
    "\n",
    "        for j in range(len(selected_images_train)):\n",
    "            img = selected_images_train[j]\n",
    "            image_path = os.path.join(data_root, folder, img)\n",
    "            image = get_image(image_path, resize_to)\n",
    "\n",
    "            N = idx*len(selected_images_train) + j\n",
    "            np.save(os.path.join(dest, \"train\", f'image{N}.npy'), image)\n",
    "            \n",
    "            N_train += 1\n",
    "            \n",
    "\n",
    "            label = np.array([labels[idx]])\n",
    "            label = np.expand_dims(label, axis=0)\n",
    "\n",
    "\n",
    "            if targets_train is None:\n",
    "                targets_train = label\n",
    "            else:\n",
    "                targets_train = np.concatenate((targets_train,label), axis=0)\n",
    "\n",
    "            if identities_train is None:\n",
    "                identities_train = identity\n",
    "            else:\n",
    "                identities_train = np.concatenate((identities_train,identity), axis=0)\n",
    "\n",
    "\n",
    "        for j in range(len(selected_images_test)):\n",
    "            img = selected_images_test[j]\n",
    "            image_path = os.path.join(data_root, folder, img)\n",
    "            image = get_image(image_path, resize_to)\n",
    "\n",
    "            N = idx*len(selected_images_test) + j\n",
    "            np.save(os.path.join(dest, \"test\", f'image{N}.npy'), image)\n",
    "            N_test += 1\n",
    "\n",
    "\n",
    "            if targets_test is None:\n",
    "                targets_test = label\n",
    "            else:\n",
    "                targets_test = np.concatenate((targets_test,label), axis=0)\n",
    "\n",
    "            if identities_test is None:\n",
    "                identities_test = identity\n",
    "            else:\n",
    "                identities_test = np.concatenate((identities_test,identity), axis=0)\n",
    "\n",
    "\n",
    "    if targets_train is not None and identities_train is not None:\n",
    "        \n",
    "        np.save(os.path.join(dest, \"train\", 'label.npy'), targets_train.reshape(-1))\n",
    "        np.save(os.path.join(dest, \"train\", 'identities.npy'), identities_train.reshape(-1))\n",
    "        print (\"OK! train set saved\")\n",
    "        print (\"dataset size: {}\\tlabels size: {}\\tidentities size: {}\".format(N_train, targets_train.shape, identities_train.shape))\n",
    "    else:\n",
    "        print (\"Error! train set did not saved as the sizes are zero\")\n",
    "\n",
    "    if targets_test is not None and identities_test is not None:\n",
    "        np.save(os.path.join(dest, \"test\", 'label.npy'), targets_test.reshape(-1))\n",
    "        np.save(os.path.join(dest, \"test\", 'identities.npy'), identities_test.reshape(-1))\n",
    "        print (\"OK! test set saved\")\n",
    "        print (\"dataset size: {}\\tlabels size: {}\\tidentities size: {}\".format(N_test, targets_test.shape, identities_test.shape))\n",
    "    else:\n",
    "        print (\"Error! test set did not saved as the sizes are zero\")\n",
    "    \n",
    "    np.save(os.path.join(dest, 'identities.npy'), classes)\n",
    "    print(image.shape)\n",
    "    print(\"Generated dataset with IDs :\" + str(classes.tolist()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_dataset(data_root, classes, labels, resize_to=None, num_samples=100, dest=\"data/lacuna100binary128/oo_distribution\"): #classes is list of identities, labels is binary label\n",
    "\n",
    "    try:\n",
    "        os.makedirs(dest)\n",
    "    except:\n",
    "        pass\n",
    "\n",
    "    dataset = None\n",
    "    targets = None\n",
    "    identities = None\n",
    "\n",
    "    N_samples = 0\n",
    "\n",
    "    for idx in tqdm(range(len(classes))):\n",
    "        identity = classes[idx]\n",
    "        \n",
    "        folder = df[df['Identity'] == identity]['Folder'].iloc[0]\n",
    "\n",
    "        images = []\n",
    "        for fil in os.listdir(os.path.join(train_data_root, folder)):\n",
    "            if fil.endswith('.jpg'):\n",
    "                images.append(fil) #filenames\n",
    "        selected_images = np.random.choice(images, num_samples)\n",
    "\n",
    "        identity = np.array([identity])\n",
    "        identity = np.expand_dims(identity, axis=0)\n",
    "\n",
    "        for j in range(len(selected_images)):\n",
    "            img = selected_images[j]\n",
    "            image_path = os.path.join(data_root, folder, img)\n",
    "            image = get_image(image_path, resize_to)\n",
    "\n",
    "            N = idx*len(selected_images) + j\n",
    "            np.save(os.path.join(dest, f'image{N}.npy'), image)\n",
    "            \n",
    "            N_samples += 1\n",
    "\n",
    "            label = np.array([labels[idx]])\n",
    "            label = np.expand_dims(label, axis=0)\n",
    "\n",
    "\n",
    "            if targets is None:\n",
    "                targets = label\n",
    "            else:\n",
    "                targets = np.concatenate((targets,label), axis=0)\n",
    "\n",
    "            if identities is None:\n",
    "                identities = identity\n",
    "            else:\n",
    "                identities = np.concatenate((identities,identity), axis=0)\n",
    "\n",
    "    if targets is not None and identities is not None:\n",
    "        \n",
    "        np.save(os.path.join(dest, 'label.npy'), targets.reshape(-1))\n",
    "        np.save(os.path.join(dest, 'identities.npy'), identities.reshape(-1))\n",
    "        print (\"OK! dataset saved\")\n",
    "        print (\"dataset size: {}\\tlabels size: {}\\tidentities size: {}\".format(N_samples, targets.shape, identities.shape))\n",
    "    else:\n",
    "        print (\"Error! dataset did not save as the sizes are zero\")\n",
    "    \n",
    "    np.save(os.path.join(dest, 'total_identities.npy'), classes)\n",
    "    print(image.shape)\n",
    "    print(\"Generated dataset with IDs :\" + str(classes.tolist()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([3344, 1091, 2066, 7590, 1397, 7564,  326, 7785, 4038, 5898, 2397,\n",
       "       1237, 4844, 2297, 8930, 8545,  563,  628, 3478, 3181, 3210, 7981,\n",
       "       7980, 8316, 8186, 2360, 2366, 2115, 1727, 8879, 2618, 2502,  472,\n",
       "       2447, 3615, 4373, 7129, 7120,  105, 3973, 7688, 5779, 8502, 7610,\n",
       "       5249, 4150, 5014, 5901, 3392, 3107, 5335, 3680, 2350, 7327, 3831,\n",
       "       4178, 4940, 1132, 8901, 7392, 1496, 7938, 5292, 3776, 4598, 7540,\n",
       "       3769, 5483, 7106, 8683, 1642, 4001, 2060,  560, 1369, 3651, 4686,\n",
       "       7353, 6352,  956, 9036, 1052, 3170, 7658, 5343, 3520, 6671, 5828,\n",
       "       3485,  598, 5814, 3735, 2142, 5738, 5491, 1200, 2141, 6011, 8658,\n",
       "       3239])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lacuna100_train_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(set(lacuna100_ood_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:17<00:00,  1.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OK! train set saved\n",
      "dataset size: 40000\tlabels size: (40000, 1)\tidentities size: (40000, 1)\n",
      "OK! test set saved\n",
      "dataset size: 10000\tlabels size: (10000, 1)\tidentities size: (10000, 1)\n",
      "(128, 128, 3)\n",
      "Generated dataset with IDs :[3344, 1091, 2066, 7590, 1397, 7564, 326, 7785, 4038, 5898, 2397, 1237, 4844, 2297, 8930, 8545, 563, 628, 3478, 3181, 3210, 7981, 7980, 8316, 8186, 2360, 2366, 2115, 1727, 8879, 2618, 2502, 472, 2447, 3615, 4373, 7129, 7120, 105, 3973, 7688, 5779, 8502, 7610, 5249, 4150, 5014, 5901, 3392, 3107, 5335, 3680, 2350, 7327, 3831, 4178, 4940, 1132, 8901, 7392, 1496, 7938, 5292, 3776, 4598, 7540, 3769, 5483, 7106, 8683, 1642, 4001, 2060, 560, 1369, 3651, 4686, 7353, 6352, 956, 9036, 1052, 3170, 7658, 5343, 3520, 6671, 5828, 3485, 598, 5814, 3735, 2142, 5738, 5491, 1200, 2141, 6011, 8658, 3239]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "make_train_test_dataset(train_data_root, lacuna100_train_ids, lacuna100_train_labels, resize_to=(128, 128))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:16<00:00,  6.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OK! dataset saved\n",
      "dataset size: 10000\tlabels size: (10000, 1)\tidentities size: (10000, 1)\n",
      "(128, 128, 3)\n",
      "Generated dataset with IDs :[5408, 1329, 3121, 6391, 419, 3374, 1178, 7559, 1221, 1659, 2724, 248, 3701, 7280, 9161, 4617, 7001, 1333, 4949, 3716, 2454, 2507, 7971, 8638, 5176, 6239, 5338, 2648, 4151, 4651, 6573, 3746, 6623, 6400, 771, 8023, 5782, 2909, 3815, 6860, 5496, 6379, 5318, 111, 228, 4439, 8048, 2947, 6394, 1322, 6138, 4257, 6481, 8972, 1592, 7445, 257, 5674, 7174, 2269, 4270, 6363, 4275, 9052, 4984, 176, 241, 7009, 2262, 554, 103, 2789, 3938, 2885, 7616, 5384, 3076, 4124, 857, 2278, 5829, 8242, 3965, 3350, 817, 3047, 7139, 1367, 5677, 7493, 1493, 8633, 8524, 1775, 8529, 547, 185, 9194, 8078, 4550]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "make_dataset(train_data_root, lacuna100_ood_ids, lacuna100_train_labels, resize_to=(128, 128))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mul-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
