{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0f513615-3e75-498d-90ca-41ae0d4a5710",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Pin the process to the first GPU (devices numbered in PCI bus order) and point\n",
    "# the Hugging Face cache at a custom directory. Kept in the first cell so the\n",
    "# variables are already set when `datasets` is imported below.\n",
    "# NOTE(review): the `~` is stored unexpanded; confirm every consumer expands it.\n",
    "os.environ.update({\n",
    "    \"CUDA_DEVICE_ORDER\": \"PCI_BUS_ID\",\n",
    "    \"CUDA_VISIBLE_DEVICES\": \"0\",\n",
    "    \"HF_HOME\": \"~/codes/.cache/huggingface\",\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28311549-8eb0-40f9-9b9f-fe0119761fb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load CIFAR-10 with all of its splits (a DatasetDict keyed by 'train' and\n",
    "# 'test'); individual splits are picked by key in later cells.\n",
    "dataset = load_dataset('cifar10')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ee6fe21c-6b7f-4bf5-993f-f5aea19b5912",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['img', 'label'],\n",
       "        num_rows: 50000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['img', 'label'],\n",
       "        num_rows: 10000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Overview of the available splits and their sizes\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "656e9fa0-237b-4f46-b8a7-8d6a9294b44f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img': Image(decode=True, id=None),\n",
       " 'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], id=None)}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column schema of the training split: image column 'img' and class-id column 'label'\n",
    "dataset['train'].features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1929a08c-bdd8-42c7-8711-8ba61db617fa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2c03a185-352e-4f19-a734-07f3add931d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0853de21-2412-408e-a39c-5dd8812fd74e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "112772bc-5a08-4d83-aebb-1a73909b758d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label\n",
       "0      0\n",
       "1      6\n",
       "2      0\n",
       "3      2\n",
       "4      7"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Label-only view of the training split. Its RangeIndex is each row's position in\n",
    "# dataset['train'], which is what the pickled index lists later refer to.\n",
    "df_train = pd.DataFrame({'label': dataset['train']['label']})\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73530660-fd62-4394-8ecb-10a3783c3f56",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e0216791",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label\n",
       "0      3\n",
       "1      8\n",
       "2      8\n",
       "3      0\n",
       "4      6"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same label-only view for the test split; it is subsampled into the validation set below.\n",
    "df_val = pd.DataFrame({'label': dataset['test']['label']})\n",
    "df_val.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8750cb29-9d40-4aff-ba26-de2150bbf2a7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a4afb6bb-327e-40fa-8455-82346198a767",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "label\n",
       "3    1000\n",
       "8    1000\n",
       "0    1000\n",
       "6    1000\n",
       "1    1000\n",
       "9    1000\n",
       "5    1000\n",
       "7    1000\n",
       "4    1000\n",
       "2    1000\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class balance of the test split before subsampling.\n",
    "# Optional: restrict to two classes by uncommenting the filter below.\n",
    "# df_val = df_val[(df_val['label']==1) | (df_val['label']==7)]\n",
    "df_val['label'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3ec96f0e-7654-4c63-8306-186206801cd9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "label\n",
       "4    100\n",
       "5    100\n",
       "6    100\n",
       "9    100\n",
       "2    100\n",
       "7    100\n",
       "8    100\n",
       "0    100\n",
       "1    100\n",
       "3    100\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stratified 1000-image subsample of the test split (100 per class), used as the\n",
    "# validation set. NOTE: this rebinds df_val, so re-run the cell that builds df_val\n",
    "# first; running this cell again on its own would resample the already-subsampled frame.\n",
    "df_val, _ = train_test_split(\n",
    "    df_val,\n",
    "    train_size=1000,\n",
    "    stratify=df_val['label'],\n",
    "    random_state=42,\n",
    ")\n",
    "df_val['label'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8cdf9477-5258-475b-bb61-cace672929ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>9128</th>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5468</th>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>62</th>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2465</th>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2810</th>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      label\n",
       "9128      4\n",
       "5468      5\n",
       "62        6\n",
       "2465      9\n",
       "2810      9"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First rows of the validation subset; the index keeps each row's position in dataset['test']\n",
    "df_val.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dad05109-7a79-4b37-90a1-1f2df645fcbf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "Index: 1000 entries, 9128 to 1330\n",
      "Data columns (total 1 columns):\n",
      " #   Column  Non-Null Count  Dtype\n",
      "---  ------  --------------  -----\n",
      " 0   label   1000 non-null   int64\n",
      "dtypes: int64(1)\n",
      "memory usage: 15.6 KB\n"
     ]
    }
   ],
   "source": [
    "# Confirm the validation subset size and label dtype\n",
    "df_val.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03653cf2-cfc8-4534-92a9-f77f408257cb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a42f3b19-4e99-49fb-b7ec-b0969edd598e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the validation indices (positions in dataset['test']) into the folder of\n",
    "# every train-size / subset-fraction configuration.\n",
    "for i in [50000]:\n",
    "    for j in [0.5]:\n",
    "        filename = './data/indices/{}-{}/idx-val.pkl'.format(i, j)\n",
    "        os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "        with open(filename, 'wb') as handle:\n",
    "            pickle.dump(df_val.index.to_list(), handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b03a4cb2-ccdd-426b-8952-059366f79cf2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dfa5e0e-2927-4787-8e03-17aaa46ecdfb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58e25fb7-5ede-4f14-a8d2-70c9a4a7e95b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "43714f88-9ff9-48c6-8263-497db05bf47b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "label\n",
       "0    5000\n",
       "6    5000\n",
       "2    5000\n",
       "7    5000\n",
       "1    5000\n",
       "4    5000\n",
       "5    5000\n",
       "3    5000\n",
       "8    5000\n",
       "9    5000\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class balance of the full training split (no subsampling yet).\n",
    "# Optional: restrict to two classes by uncommenting the filter below.\n",
    "# df_train = df_train[(df_train['label']==1) | (df_train['label']==7)]\n",
    "df_train['label'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53c4db24-32c4-4161-a320-88b960e047b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3d2af2a7-1ef9-41a6-987d-3fdb7072b984",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "50000\n",
       "256\n",
       "512\n",
       "1536\n"
     ]
    }
   ],
   "source": [
    "def _dump_indices(indices, filename):\n",
    "    \"\"\"Pickle `indices` (a list of row positions) to `filename`, creating parent dirs.\"\"\"\n",
    "    os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "    with open(filename, 'wb') as handle:\n",
    "        pickle.dump(indices, handle)\n",
    "\n",
    "\n",
    "# Subset groups written for every (train size, fraction) configuration, in\n",
    "# generation order: (sub-folder, number of random subsets).\n",
    "SUBSET_GROUPS = [('lds-val', 256), ('lds-test', 256), ('retrain', 1024)]\n",
    "BASE_SEED = 42\n",
    "\n",
    "for i in [\n",
    "    50000\n",
    "]:\n",
    "    for j in [0.5]:\n",
    "        base_dir = './data/indices/{}-{}'.format(i, j)\n",
    "\n",
    "        # Stratified training pool of i rows. When i covers the whole frame, use it\n",
    "        # as-is: train_test_split rejects an integer train_size >= the row count.\n",
    "        if i < len(df_train):\n",
    "            df_train_, _ = train_test_split(df_train, train_size=i,\n",
    "                                            random_state=BASE_SEED,\n",
    "                                            stratify=df_train['label'])\n",
    "        else:\n",
    "            df_train_ = df_train.copy()\n",
    "        print(len(df_train_))\n",
    "        _dump_indices(df_train_.index.to_list(), os.path.join(base_dir, 'idx-train.pkl'))\n",
    "\n",
    "        # Stratified subsets holding a fraction j of the pool. `count` runs across all\n",
    "        # groups, so the seed BASE_SEED + count is unique for every subset file.\n",
    "        # (The old seed 42+count+k added the loop index twice, so seeds collided across\n",
    "        # groups: e.g. lds-val/sub-idx-128 was identical to lds-test/sub-idx-0.)\n",
    "        # Re-running this cell rewrites every subset file with the corrected seeds.\n",
    "        count = 0\n",
    "        for subdir, n_subsets in SUBSET_GROUPS:\n",
    "            for k in range(n_subsets):\n",
    "                df_subset, _ = train_test_split(df_train_, train_size=j,\n",
    "                                                random_state=BASE_SEED + count,\n",
    "                                                stratify=df_train_['label'])\n",
    "                _dump_indices(df_subset.index.to_list(),\n",
    "                              os.path.join(base_dir, subdir, 'sub-idx-{}.pkl'.format(k)))\n",
    "                count += 1\n",
    "            print(count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c302dd78-4d57-4351-9181-648491c1b9c1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ea1a950-69b6-4d24-a4b5-4c89e5b81292",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a27c0e5-14f5-4a16-a334-67dc4c2ba738",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "824b1dd7-131d-44c1-8e84-47c1402e3b1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50000"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Round-trip check: reload the pickled training indices written above.\n",
    "with open('./data/indices/50000-0.5/idx-train.pkl', 'rb')  as handle:\n",
    "    idx_train = pickle.load(handle)\n",
    "len(idx_train)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "dff0dfff-3651-47ca-bf34-d82709894e10",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1000"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Round-trip check: reload the pickled validation indices (positions in dataset['test']).\n",
    "with open('./data/indices/50000-0.5/idx-val.pkl', 'rb')  as handle:\n",
    "    idx_val = pickle.load(handle)\n",
    "len(idx_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bee47cd-8e18-46d7-9cdc-7bc73ebd8ce3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56543a8-9c44-4d97-9a46-388cfbed7920",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c19ef489-d8c3-4750-867e-3e4c2c506925",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAH8klEQVR4nHVWS68dRxGuqn7MnDkz53UfwdzEQgSEEsdCihSJLRKIn8KGHX8MwYIVGwRS2CQECSlRYsy149j3Xp/HnDNnpqe7q4rFwcaJRKvVqq6u/r6vululxt//7k/wqiEigAIoIsL/b6r6pvFqqqpw6iIiwiLMzJaIVF8hIiIogL4J9x2y18H/2/UGkyqoIiIi6mnRMvO3UfRNQEQUEURQhdMIAIggJ0u/m8fJLaIiIiKqSq/jAABRTzpO0CmlrutOK6/okEWGIfR9fzz2J1nfEvffE1YiOvntKdMxhO1uwyqXF5feOVXNnF/c3kwn1XQ6PVGKyEnp3fpldzwu54tpVTHzawIBUOHNZt227Wp5Nq1rALAIGFP86vGjJ8+eKMBPHzy8uvd9MrTbt4+vr3/8w3dPClSViG5f3qWURCWEYX7/fozx5ubGWnt5cWG9e3l3e3397+1m3ff9Yr56+MHDqqrss2+ePXn29G5zF9MIgI8ePzoej9O62m7XLze3rDzGIAqL+Xw5n68367v1unT22Pfdfj+m+OLutq4bAQlxfPny7vrJ9RizAEyqJKoIiL/+zW9DGIwhBCAkYS58ASCJx8RSWF8V1aEfp5NyNqu37f4YRkIForPlqu9aREC0232bU64mhUEKfRTRh+89WJyvYorWoS4vz/f7w6HrSu+dIeaUUET1rK6naMcshTUied/tvXf3Lt8qCndo2/bQjmFoCgekVenZGlZpJsVbswYFQYbHT764ubuznty0nPSHHoEuzlZ1WW13u8PYHzkNx76oKlHxjqp66r1HkTT2OeiExTuno2HRUUZQEElZ+NALAiDQYRO33X4YBouZU9dbxKauJmVpLBmLZ8V0lotDe2QW620a03DsIBVGdDf0ibn07rKq31mdDSkdlLPw0/bGGgPKN7t2GOOibqbFxBtrl7OaCSbolmXFKQxJQxxZdV6WYkzLXIlNLKEP63RABTfxdVWy5IOIMzQIl0XZxyAMGYA5xZwzS4rReduNo51eXmx3e0WDzk9KNco5hpggsk6c89aiqvXGkt9qQLJNM00xiUIGUZVuHIuiWDT1vXuXSaTdbEi5tCQ577qui9EaY2LOOStLsiSQQ8o8qcpuyDnngjCmXE/L89XCHw5A5nJ19vxuPYY4KSfOmv44kDVIVLnST4rxeDx2R+9NyGlIkJnttmuFREkF9DAMm/WdClRVrIpi1jQiHFPeh3TMrXG2tHa72XLOInI4dMYSEj57/mI2nzEz9zmEITMXgCpCCARgl1VtWLPJRVl4v1yv1zEmhNEZo4UW5SRkHmMOQ7g4XyHQs5vn/RhndV14e2yPfRj6IYjwbFYfDp0hOlvMC+cRAQhE1BpnACFlNlmaabGom+BDymmIQVl8EVXUO9dMJ6W1YewvLs/b/SGEYQhgyNRVUU8nhMQpzZv5crEahmNOUQFzZs7Jqoh3zhAVRQHAhlg4obJB5wtrDQiCSs4x70MvqsZQ6cx2108KX1YVC6eUAYCZc46SYBxH4QxIqmiMtaoQhhhjGsc0nzUqyHlUFbQ29IO8quWqAISoqqDCGQSHkBIfVDTnbIz1zvX9LrMoKPOpPgpntjlxHHMfRj7I7c1huwlZRzTEmoA5s6ScrbOgYJHcGMFiKux0MgVDIYyl9yIqkp2xgFAUJmURlbqqrDEibFVVQdbr7aNHT1WtpXo2nyyWxMIpR0O2JJwVE6eqqnUEX03h6gKNsc6klKqy6IdRBLyzaMBbG8YYwricz+p6IiLWWHP51vl0Wt3erNfb4/lZUcc8bmFx1VTTlSGSIUuI2veJcMCMFs8Wc1Xw3lRlgYgiqgAGEQgRUVQ5Ze+stZaFrYgwc92UH3308JNPPgcXcRh3Ldl68fDBT5ar2X7b/uOzTwOBEIIntrBgRkVni8nEqwIiiQqRQ0RFAQF1
zCoCqog25xxjOgaeNtXVvfNP/v6VFW9NefN19/GfH//8lx++98GDt++/k7MoIGv2zs2bGSBaa5ylnKUfxrqaxLE3gFhWzDnHFMZhv9/t91vbdV3f94eud95Yn5u6GHpzdtFcXV3ePh/++Ie/3X93+Ytf/ezsrAYwqgKoCJiFVSAxHA7Dx3/9bHm+yk+ufzRzzQfvm/NzY33tCrLFGNnu2w0irRZNCOH5dlcvidx40/7r6/UX/ZHHMX/6uT6/u/7ww/dnzXTWzJz1dHqCrIf+sNt3L7cvHl8/cuMYbCjbjf/BfWOMsS6zGGOsMdYYMs6eNZdFUTlvRXiz3m7b9ng8Hocxx/zN86f7v2zPV8vFYsGi3jkEUBVCFGGgCGYYS/hnVr/dVqFXAmadNzNrjFUwMcvtNy8AoPKOiJpqagnfefteXTfdoYt53Ky3Ctq1uy+/fLTZ7VarZemcMfTO1dVsVjOPbbsh8kRgEdViVlnv2nldA4Jtuzal1IeBVClZJZI0JNCwH7rjXjgjQDVx1lrvyFpaLabL+dx7r6DWqGgmgu9dniERCDhUIEKlWVO17ZoQbe0Ii/KsqQjAEACSijIqqKoIWUuAqiCqflLOJwXiOQICABIioKg4QPBeVdAaRAhx5BS9AUkRkSykqIgAyKCKqK+/kCcYUAZFBAKE0+8URQFVFE9RqioMCKJCaBANCZOoApAKiNghDsxiyGaNpKAIgOjIEKpotmhz5owyxOTJO2MVGBRUQFAJIQsjIYugmtPNK0hmQQEkEOH/AIgrr0LkBku7AAAAAElFTkSuQmCC",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import DatasetDict, Dataset, load_dataset, Image\n",
    "\n",
    "# Materialise the training subset from the reloaded indices and preview its first image.\n",
    "# NOTE(review): DatasetDict, Dataset and Image are not used in the visible cells.\n",
    "train_dataset = dataset['train'].select(idx_train)\n",
    "train_dataset[0][\"img\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "315b64d9-525f-4350-b91e-9c5b403f6f68",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['img', 'label'],\n",
       "    num_rows: 50000\n",
       "})"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size check: num_rows should equal len(idx_train)\n",
    "train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b4958a1a-66f2-4c27-a370-6a61038e7821",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAFOElEQVR4nN1Wy3IbyRHMrO6ZwWMAEFrKJCWSoq1Y+7COsP//J3TUwRE+6MHHhkgQIAbz6O4qHwaAKJKr215cgUAgujGZnZVVNc3z83P8mSF/Kvr/BYEnCcDMnm1tV0iagdwvGwCKmBnMKPLCo49CvPciP+gg4RxFCBhgJEgDDKZCkCBsOCh2vyH8WUiMUVX5/YQA4H2WZZmIDAYD7z3JXqKZOefGZZnnOYAsy5yTnwqA2C6ebDjnnHM9UJ8oca7nmM/neZ6LSJ7nWZbbo3iB4PmSGULozCzLspQSgD6NPYSqhhCccyR7lT3ri+gvEwBQNVLyLBcRJ85U+48I1XSzqcx0NBwOBoVq2pfDi+F7K54oABmjZh6OEmKwGBxxcHgwGA2ur65F1IsNB9lmveqamiDlD8vdP0EHABqpIpK088LhMOva5L07OT6czmaZk8Xi/nZx13VBk4IO/Fkz+edLpAGatMvdoCyHuXchFE5kMh6enhyt7hefP62SEnBGTwrwQyc9OfHLJgOIMa3XVYp6fHRyePhakz4s7+tqVXg3Ho0AGsTEK53tcPfQJEWkr0NeXFw8Jdj+CTD1wmk5Ho9HoWsPpvlsOsoH48Vq85//fmoCxBUAHfoG5BMR/bd3oqQYARCg7bdBwAirmjaoluORiG/qNkRMysn56enl9beYQBHZpaEv1r2Ons9PxznIBAJipIFmZikZDCAMIJVQQkEFU5d8lubT2d3tso4tjHiUmRdMno6HyTRENYrtRKhKPz8oYqoxpRi6mPygKOqm+/Tlaxe0aaP1Dv/YZfuhslMwKkKMtXUhRaNQBJREJjPAvHM+L0LokqYQQ0h5VTff7u4V4lwOCmBAwg4UzwazLwdZTOIF1aZOpnmegy4k7WApJWjKity5oq7rpmkUXD5URnEuNzg1EqDZ42Z+OjezzJlpkecxpaYN3jmf5ZISiLZV0yREkecaQ93F1aqqm0BxBtq25Ix/AN2Pey/ifS6hDQpROnMZnGdSmsLUCTPvvBPvnDjp2g6SsW9dMzLt6vGZz7tEeRFmlFqjhs5CQMyMFkIbQuB24jtVte0TBNhP1j0UH+X8e40CAIT0dduGLtR1PSlHx2VpautNHboupVQUxXQ6NbOmaVJMqkwpPR2OBtD2Cr5zbLtZ/P0mEMhH09PTN5Oy3FTrcP07H6q+M2OMZVnGGNu2a9uuPz6AXgQACnsFj18JJCEU70ScPzx7fzCdeoGluKrbTd0Vo/Jf/z6LMX78+PHr5eXff/11Pp8DfGiWzknf5iKiqj0WbTsntpQkSRMaaTD/9vSv08kkyzyBZlPd3Fw39frw6OjVfJ4Nig8fPiyWy3cX7xS4r1prgplRaKBzHhQaTBMMlN567vw1qBrgp5NpCMEMo9FoflhODn55WC2qh9tvi8Vs/mo6m19dXR2s1vNXrybLTd3ekr3TgDhSCNLkyYggtzcSAP7i4mK1Wi2Xy8Vi4Zw7ODg4PHw9m5X3i7urqyujH4wn17/f/vL6+Pj4ZLWuzMxABXsCGMW0d7S33+xHgrquAZRlOR6PN5vNarUSIUzFFX97/4+3Zxe3325vbm7qJr59e7Zcre6XS4pLoNq2halKQkT29yuSZgoYSV8UhZn1dxAhq6pyzheDIUmI/OXo9dn5+7quLy+/Fl5+++2fn79+qeumS5rUADE1aLK94dyWWUoRMBHxJ2/eqOrnL5/XVdW27WA43GokN3UdQpzNZqPR6Pz83eZhMRrlIlTg9u6+DVGc06imqin1N5r9BSmlBJj37n+v2zva5lpxsAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Validation subset drawn from the test split via idx_val; preview its first image.\n",
    "val_dataset =  dataset['test'].select(idx_val)\n",
    "val_dataset[0][\"img\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8955f515-fbcf-484b-b1df-522c1c326f66",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['img', 'label'],\n",
       "    num_rows: 1000\n",
       "})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size check: num_rows should equal len(idx_val)\n",
    "val_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbad2bc0-babb-4465-a1af-193b08c222fa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b657be-e172-42c7-aece-0450a8f9f1e6",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3854a013-b6e0-4fab-9738-93d410329c74",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "30ff1255-784e-4503-96eb-bc01e0826659",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[10936, 11178, 26489, 22034, 14307]\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "# Peek at LDS-validation subset 0: its first five training-row positions.\n",
    "with open(\"./data/indices/50000-0.5/lds-val/sub-idx-0.pkl\", mode='rb') as handle:\n",
    "    sub_idx = pickle.load(handle)\n",
    "print(sub_idx[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "5c3889f3-49f7-4012-94fc-d17565d2c5ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25000\n"
     ]
    }
   ],
   "source": [
    "# Subset size: fraction j=0.5 of the 50000-row training pool\n",
    "print(len(sub_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "df904154-6c44-4afe-be75-69d7a82825bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[6591, 35601, 1087, 17631, 28857]\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "# Peek at LDS-validation subset 8; its indices should differ from subset 0.\n",
    "with open(\"./data/indices/50000-0.5/lds-val/sub-idx-8.pkl\", mode='rb') as handle:\n",
    "    sub_idx = pickle.load(handle)\n",
    "print(sub_idx[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "bfac9520-ad7e-47bb-ac83-4a3548640521",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25000\n"
     ]
    }
   ],
   "source": [
    "# Subset size: fraction j=0.5 of the 50000-row training pool\n",
    "print(len(sub_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "41d6f882-5297-4d8a-bb44-e676b2808d19",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[9128, 5468, 62, 2465, 2810]\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "# Peek at the validation indices (positions in dataset['test']); note that this\n",
    "# reuses the name sub_idx even though these are not training-subset indices.\n",
    "with open(\"./data/indices/50000-0.5/idx-val.pkl\", mode='rb') as handle:\n",
    "    sub_idx = pickle.load(handle)\n",
    "print(sub_idx[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "2b1b3261-ae38-47fa-9612-9456abb60245",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000\n"
     ]
    }
   ],
   "source": [
    "# Number of validation indices\n",
    "print(len(sub_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b46d3e33-46d8-4529-b3ab-d652752a0051",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "160e394d-f981-49e8-833e-a8b6a17e7b73",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "116f1ac5-6754-45c2-82e3-63ddf51de624",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9987b5a0-58a0-4cf7-b845-12a0bbab7089",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4696e9c3-5d96-49e7-be83-7a47291ddb16",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba7bdefa-759d-41bf-b547-318b4d696688",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
