{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0f513615-3e75-498d-90ca-41ae0d4a5710",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Order GPUs by PCI bus ID and expose only device 0 to this process.\n",
    "# These are set before any GPU-backed library is imported in this notebook.\n",
    "os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "# Environment variables are not shell-expanded, so resolve `~` here rather than\n",
    "# relying on every consumer of HF_HOME to expand it.\n",
    "os.environ['HF_HOME'] = os.path.expanduser('~/codes/.cache/huggingface')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28311549-8eb0-40f9-9b9f-fe0119761fb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load every available split of CIFAR-10 (no `split=` filter is applied).\n",
    "dataset = load_dataset('cifar10')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ee6fe21c-6b7f-4bf5-993f-f5aea19b5912",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['img', 'label'],\n",
       "        num_rows: 50000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['img', 'label'],\n",
       "        num_rows: 10000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the available splits with their features and row counts.\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "656e9fa0-237b-4f46-b8a7-8d6a9294b44f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img': Image(decode=True, id=None),\n",
       " 'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], id=None)}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the column schema; the ClassLabel names give the integer-to-class mapping.\n",
    "dataset['train'].features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1929a08c-bdd8-42c7-8711-8ba61db617fa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2c03a185-352e-4f19-a734-07f3add931d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0853de21-2412-408e-a39c-5dd8812fd74e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "112772bc-5a08-4d83-aebb-1a73909b758d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label\n",
       "0      0\n",
       "1      6\n",
       "2      0\n",
       "3      2\n",
       "4      7"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Label-only frame; the default integer index is each row's position in dataset['train'].\n",
    "df_train = pd.DataFrame({'label': dataset['train']['label']})\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73530660-fd62-4394-8ecb-10a3783c3f56",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e0216791",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label\n",
       "0      3\n",
       "1      8\n",
       "2      8\n",
       "3      0\n",
       "4      6"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Label-only frame for the test split; the index is each row's position in dataset['test'].\n",
    "df_val = pd.DataFrame({'label': dataset['test']['label']})\n",
    "df_val.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8750cb29-9d40-4aff-ba26-de2150bbf2a7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a4afb6bb-327e-40fa-8455-82346198a767",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "label\n",
       "1    1000\n",
       "7    1000\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only the two classes under study: 1 = automobile, 7 = horse.\n",
    "df_val = df_val[df_val['label'].isin([1, 7])]\n",
    "df_val['label'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3ec96f0e-7654-4c63-8306-186206801cd9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "label\n",
       "7    500\n",
       "1    500\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stratified down-sample of the validation pool to n_val rows, preserving class proportions.\n",
    "# train_test_split raises ValueError when train_size is not smaller than the number of\n",
    "# rows, so skip the split once df_val is already at the target size; this keeps the\n",
    "# cell safe to re-run.\n",
    "n_val = 1000\n",
    "if n_val < len(df_val):\n",
    "    df_val, _ = train_test_split(df_val, train_size=n_val, random_state=42, stratify=df_val['label'])\n",
    "df_val['label'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8cdf9477-5258-475b-bb61-cace672929ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>7405</th>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5226</th>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1363</th>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6615</th>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7612</th>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      label\n",
       "7405      7\n",
       "5226      1\n",
       "1363      1\n",
       "6615      7\n",
       "7612      7"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The index still holds row positions in the original test split, not 0..n-1.\n",
    "df_val.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dad05109-7a79-4b37-90a1-1f2df645fcbf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "Index: 1000 entries, 7405 to 3027\n",
      "Data columns (total 1 columns):\n",
      " #   Column  Non-Null Count  Dtype\n",
      "---  ------  --------------  -----\n",
      " 0   label   1000 non-null   int64\n",
      "dtypes: int64(1)\n",
      "memory usage: 15.6 KB\n"
     ]
    }
   ],
   "source": [
    "# Sanity check after down-sampling: row count, non-null count and dtype of the label column.\n",
    "df_val.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03653cf2-cfc8-4534-92a9-f77f408257cb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a42f3b19-4e99-49fb-b7ec-b0969edd598e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the validation row positions once per (train size, subset fraction) configuration.\n",
    "val_indices = df_val.index.to_list()\n",
    "for i in [5000]:\n",
    "    for j in [0.5]:\n",
    "        filename = './data/indices/{}-{}/idx-val.pkl'.format(i, j)\n",
    "        os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "        with open(filename, 'wb') as handle:\n",
    "            pickle.dump(val_indices, handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58e25fb7-5ede-4f14-a8d2-70c9a4a7e95b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "43714f88-9ff9-48c6-8263-497db05bf47b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "label\n",
       "7    5000\n",
       "1    5000\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only the two classes under study: 1 = automobile, 7 = horse.\n",
    "df_train = df_train[df_train['label'].isin([1, 7])]\n",
    "df_train['label'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53c4db24-32c4-4161-a320-88b960e047b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3d2af2a7-1ef9-41a6-987d-3fdb7072b984",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5000\n",
      "256\n",
      "512\n",
      "1536\n"
     ]
    }
   ],
   "source": [
    "def _dump_indices(indices, filename):\n",
    "    \"\"\"Pickle `indices` to `filename`, creating the parent directory if needed.\"\"\"\n",
    "    os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "    with open(filename, 'wb') as handle:\n",
    "        pickle.dump(indices, handle)\n",
    "\n",
    "\n",
    "# (sub-directory, number of subsets), in the order the seeds are handed out.\n",
    "subset_groups = [('lds-val', 256), ('lds-test', 256), ('retrain', 1024)]\n",
    "\n",
    "for i in [5000]:\n",
    "    for j in [0.5]:\n",
    "        # Stratified down-sample to i rows; keep the whole frame when it has no more than i rows.\n",
    "        if i < len(df_train):\n",
    "            df_train_, _ = train_test_split(df_train, train_size=i,\n",
    "                                            random_state=42,\n",
    "                                            stratify=df_train['label'])\n",
    "        else:\n",
    "            df_train_ = df_train.copy()\n",
    "\n",
    "        print(len(df_train_))\n",
    "        root = './data/indices/{}-{}'.format(i, j)\n",
    "        _dump_indices(df_train_.index.to_list(), '{}/idx-train.pkl'.format(root))\n",
    "\n",
    "        # `count` is the number of subsets written so far across all groups, so\n",
    "        # 42 + count gives every subset its own seed. The previous 42 + count + k\n",
    "        # advanced by two per subset, which made the second half of each group reuse\n",
    "        # the seeds (and therefore the exact subsets) of the first half of the next group.\n",
    "        # NOTE: index files written before this fix differ from the ones written now.\n",
    "        count = 0\n",
    "        for group, n_subsets in subset_groups:\n",
    "            for k in range(n_subsets):\n",
    "                # Stratified random subset holding a fraction j of the training rows.\n",
    "                tmp, _ = train_test_split(df_train_, train_size=j,\n",
    "                                          random_state=42 + count,\n",
    "                                          stratify=df_train_['label'])\n",
    "                _dump_indices(tmp.index.to_list(),\n",
    "                              '{}/{}/sub-idx-{}.pkl'.format(root, group, k))\n",
    "                count += 1\n",
    "            # Running total of subsets written after each group.\n",
    "            print(count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c302dd78-4d57-4351-9181-648491c1b9c1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ea1a950-69b6-4d24-a4b5-4c89e5b81292",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a27c0e5-14f5-4a16-a334-67dc4c2ba738",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "824b1dd7-131d-44c1-8e84-47c1402e3b1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5000"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read back the saved training indices and confirm how many there are.\n",
    "with open('./data/indices/5000-0.5/idx-train.pkl', 'rb') as fh:\n",
    "    idx_train = pickle.load(fh)\n",
    "len(idx_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "dff0dfff-3651-47ca-bf34-d82709894e10",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1000"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read back the saved validation indices and confirm how many there are.\n",
    "with open('./data/indices/5000-0.5/idx-val.pkl', 'rb') as fh:\n",
    "    idx_val = pickle.load(fh)\n",
    "len(idx_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bee47cd-8e18-46d7-9cdc-7bc73ebd8ce3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56543a8-9c44-4d97-9a46-388cfbed7920",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c19ef489-d8c3-4750-867e-3e4c2c506925",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAJJElEQVR4nC3VSXNc13XA8XPOHd597/WAbhADCUEUR5GWEVdUkcyykkqqksomLn/ULLxwlbNQKSmnotgmTZEwB5kgZjTQaPTcb7r3npOF8wX+y98fR5NV5BiZyzoopZWiEGOQoDggcx2RCYsqRtBAAhGL0h9eXBY+rvd7Wpnh9WS+LNa3e2udfDZZTObVfBmKAKyVSKTo8fdvzxRR4/2iqJTW1lpEVCQKWRhXdRwtqtm8ZhaGyA0bl62aEAQi1wggYJoQBdkq0FoHprKIRRM9A6E4Y/S702sRFpYmMimltSFCq0QRCaiyiatGfKToGVGQPPMSDRlSFlxR1PPFHFA55xAQBULwkT0AcGRCjIQ6mpxjZBJMFJKKAEE4cOTALOAbCbFiBgABiVpJYjBw8LHiYBKktlOLZcESTObEBw3YSsmwUg0rQK2VriIyIwByAJYAACCCwt5HEFYSSKpWmrdbPZEYylgsl0eHx6PrKyJU2ljnUClJUmV6WmurlEcRBY2Cpq7qgLryUWIUEAaIIgCAACAiQsQRQp3lLs86Z+fXZ+eDkw/Hw/MLXxedliurpY+x0+0oo6umIuZ2q7O+uZO2W7q1lqytK5MRgg4hCDOIyP+HWZAECFgg+I6z7Vbn9GJyM6nKRp1fXqx1k81bWxKbo8MKmHZ2P9OJGU/Gg7PT1fVkXkDwsWB8sPe3nX6vqgoSRmAgFoyCEhAYRIkoitWtzGz2+9OiLIB99Kc/vlMEmNtkY20apRZiVpfD0aqpbbsvpu3FdNY37j59kHftxx/3z4+OpjcT7YxiBhCFDIDEAAJGQpMb6vfak6Ke1pERT44PptdDDVRPynm6urlcSDDiZToq0VYe6mLlleByMVu/3d5Yb58cD9su2XnykEZnB4rrdqa6GSdQ91KdqzKpp7fXOlUTxlVY1nh8cHT0/r1SlikX7K4qx9DFxMVEReuWQeomOp057Zpi2SyL3GSpcVdXowhKf//tb3prrbX1tf56CqAfPnqy0W+LTTnUJ4ObD9fFxeX4+NX3aagXvorZmiM3n87IGqtRKZXkDomUdvl6B2O9rG+OT4aJTpCS8Wg2Ox9pLEZb27ajSw32/uOn9+7fd5mtK1wuuT3n6seb9y/3k7pyCew+vlciTYtCW5caHRbJaFzNZzNKkpLKWDaxLippyCSJrowhQt0pb/RXew/v39359MGTpP9oWcPZ6UywnjfF69fvXr88GAzmDKytGwwH19NXrfXdybISAQXMzEIJaos1iqoxLoIEcj0wuVLe6sZymVc/6q1+vnX30UJvvXhxsioXaZ4eHg5/ePt6Nl9IUEprpUwpjfR2AvPMW047SMKxWM7P0pxUYkBYN6WvLmazVb7dszYr5qc3w3fN9ORO1dZEMpjz9z984KJ8dL87XV7+8dWfmkqjZBEjKgAynhyZjuXKl2OpfFUtq8VQpmc1IpOWGEMg3wyV67Y3fqogRL8Yn7zB+up8sK1vGjd48XJRu3/+x1/cu9v7cHZF9l29qLEeczVFnaS37lSz0fjykMtpU06Ea43BoE+AfGBhJkIBAGm0mOr6h8WA+23dzYLoZFxE3URaFavdnU9ub98WoYuLsggacLUan82vDkG7Tj0rJgOcX7UtiBGtdWp1ajPSSVE1vmlIkYiUdTNZrmblG0QZjMjwgmOYzRfaiu+0W4PR9b//+j+syQ7OJyEqaRZcT7Y3202A+fBjrrnVcyjMQiJRQljWTR0WkRlAYhQS6xLXyxLRrtVtj8Y3UIMoypzVVuvtW63T/UOmpaK16LHfNlfnV70W/vTJ48D4/PlzzU1VlquyjBFERESiiCADAvwV8hDqpnCpffrFg88e3Hv/4cPR+x+C95lV+mJw2YQGV/Nb25/l7czcrMbTm7C6+uJnD7udtN3pzabjVy9fiHANRhSLMAOKEIOICCIC
oNaqjo1z9sHTRzt37ojCweEbFL61vq73X7+yKrbaa7wYng4G+2//UpazTEOv39vdvW1tqt4feCQGitYIBGFGpMQmSZIiKY4REDXhaj7WWSo26d++8/LPb7LU5Yn54slTvb3eGgwGo+tZfzPMy5LriQ61tvmirDbu7J6dXjQxJllGpLJuHxRJCE1VR9/8ddKhaULwLCH62mVb3d6GSlv3Hj9R5eSTjf7XX/9c//Jf/+nXv/1ucTyezavalxwaEkqSdHvnU5XkWbf3+ZPP//Juv66KaeDac/RBCWsQkSawBxGttVJEIFZZrdLZqr6ZLp49e/bV3tMs7+i/+/JvNnd2//t/97/99n+aWpy2PjYxNPPSj8o4XJQHR2epdp1ckbVkM610nqYEMhwe19Xi2bOvvvnFMw7yh+ev/vT246v91/nldpiPH/7Dz1zWapqgT84vbmaL9V6+vd0dvznf3dleX+8fHB9//HgwqWNdLr2v//6bnzuMy6pA42Lkpm5ms2mep58/uvv4wb1eJ++21/prvePzy5vrQavX3+61E62quopB8Ff/9i9Xw+vpdDFfLEhRYp1ziY/Q3338yf2neZo8//7bZnZtUBpmAQQGQCzLYnw9Yu+7nezupztfffnl3t7ei9dvf/fH/d17D7/ee/TNl18YY5VK9HQVypqZAQHK5XLmJ0RojKOktbPz6en5yeDkCH2ptfaABIAAhMgszmU6BWPVzc3su+/+6/3bt61Od3ejvXf/9k/ufxLqUljSLNGt7mYIqMgSWWsy7wNzFAnT6+EffvefKNh2uUpTRAKdaAQEVkopIuBAKKlLWnl2e+vW3k8+397adHlbEWlkrZUy2hij2+1bGpMynVV56X1gZuYI0JDSxjoEZN+IhMiMZDUKoaRZlrpkc3Ntd+fO9uaWcy5vuUQJITRBFKEirY01NrHW6jztWnLOpU2rCiHEGEEEIRCyIpIYvNeIMU2TbrvXbec7d25vbGy2WjmgD74xShMSIDJHQCKliAiVIpNom1ib6NSmJKQUJImJHIIPAKAAnVWtzKap7XbSXi931vS7fZAIIIioDYcQBRk5KKVRaaMdAERQRGRdqq1LsyxJnDbap4lOVBcVK8V5lrRbbQQymlp5SiQggAhNU3sfRZiZtTGBWesky7O6rkQpJGwiG2NckgGRNkneahGpOob/A/Epv6decUL6AAAAAElFTkSuQmCC",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import DatasetDict, Dataset, load_dataset, Image\n",
    "\n",
    "# Build the training subset from the saved row positions and preview its first image.\n",
    "train_dataset = dataset['train'].select(idx_train)\n",
    "train_dataset[0][\"img\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "315b64d9-525f-4350-b91e-9c5b403f6f68",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['img', 'label'],\n",
       "    num_rows: 5000\n",
       "})"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confirm the features and row count of the selected training subset.\n",
    "train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b4958a1a-66f2-4c27-a370-6a61038e7821",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAJjUlEQVR4nCXHyXNcx30A4N673zIrBoOdEEjQlGkuMi2JkiwpcTlOjjmlKmUf7Cr/Vz7lkEtOuaRSLpcX2VFki0UttEwSJAGI5GCbATBvZt7Sr193/zqHfLcP717fZoxtbKyXZbHIZ8aY4XCYqPRf/+UXP/nHfy4q/W///qvf/Po/m8Zsb2+Px2NOaKfbHawsG2f/7sN/iuXy3Xvvz/L5f/33f8zmr1eGazffvH/vzseUcUChCZ5orQkh3nvvnWkMYwwhlM2m+wfPvfecR/3eUCnFOdda11pzztbWVpM4KYtSl/nNG9c6imTH+zGx/Xa7yGYP/vfP2eWlheCBffHgSwIAWuuzs3Fd1wFASiml5Bydjg8ePPy00gvKCABYZ8fjcUAoBHR6evbs+fPz8/Mvv/lqNH79+z/9ev/ga4JcNx0wyqrycu/5FyZUlLNuJ8Zr64MokkmSRFE0zaZxrJI0IbiJJG+3N75/7/1Hf310uL/nvIuUcs4xSjEhPgRAobM0uH3rLVvpyclh8D5JW2lLtlppwNHmd+5vre0KBnh1vZemSikZgJVlzQRaGnTiiHtTdzrdJElOT8e6AvBACMEEOxQIIRgTqSTGKJJSEDLs9XavvrGzfeXp02dx1JpcZi9GRzJOklbCMEKEUM5ZUWhGkRKcBIwDcQ7KskIIgfeEYISQ95ZimkRRAKi0znXZbbeGqyv333l3bbh6cT5pGt9u97W2t2/defziSW3mRSUYpUxJaZ2jxIcAFDGGqaAy7kXWuVo7KRNCXAjBOauU4oTOs5kk5NqNG+++/f3bt24RIv7y+ZeHhwdLg1671f3tH/74s5/+dGt9/dvXL4XgDAJY5zAOFMPa+vLm5s6773zwg3feSdutk5Pjzz777IuHD2ezS0xw2oqddaFxOxvrf/+jH3304YfBm6dPnpyMZ/svDrPFItfllSs8aXeevTi4ce3m0etTZxDeujLo9lLG8Xt37/zy57/Y3L7qHeZSCClUFFvvX7569eDPnz7f2yvKcjhYvvfWvTu3b/UH/RfPnx0+3/vmb4+pSGe5GR0dUYoEV/v7B42x//Djn+TF4sneEyZ5xLDghH78wcffvXbjcpGfnJ0TQijB21d3+8PVW7ff/t6b31ss5mVRdLrdJEkQuCw7b/c627u7QEVZ2/O//q1qKlNVZVVNs0vK+KuzY0JxgyzbWOtLyZVSmJBnBwezRXnwcsQofe/d+2mrGwIFD4iI9tJKu7/ivXdW27rEAbXSxDf+5s3O7z755Ph4ZF1tcQPUiYRhQh49+SpOolYrZh9/9N35LMeYjkav5rMiEP7w60etJProo/cpCx4MpTIAeOsa0xRFjl1DMCAM1tqqMnt7Lx58/rnzpt9LwUvvU4SWCGEX04WKVKeTsqaigvZbaXf0enJ1t72yvnLjzZvHR4fzxfj8QjAZMyapx+B9Ns0m5xOlIqWEEJRL2ustnZz8zyJfbO1sDnrtUOvB8nK/19O1yUujdSUkY5nW3gXv4Wx8knTSqjaTyfz46OyTP376wQ/f6/VXmhp56/LF4vHjxxcX59ev71y/vtsw6byfzspvXx0stcV31ga9tB0aRyXnCCFKuu2kv7NlbMNEygJC9WKel9NXo5dHowc4iGl2eT49yqtqfXWboGgyHk/OJ5eXl7Yxb+xsmabZ2zs4PHh5PJ5OTo5/fO/6/SubImAuxfl0CsGmnRiKUukFwpidTk4RQtwCVWheLmrnbZMbqPrDjacvnnzz5Gms+tiBcw4AKl0++ubx6OT0xfNvZ9micVjhsLO+MlREYoRFwFFgnPU6CkcUe8AIM6qobRyiSKWi21vR9eXJyWhje/nNO9ta58dHkyw/JpYQjEMIVV09/PqrJGkhxExtrTfJsti7+PY084orKpVvXBpFXWckwQxQyhWrtfbOR0IJRhHBGBPC
yMaVJY9zHvmd6ytV7qqsqsrKOhvT4EFxrpwL88U0akF/c+kEipG3wWooKEWYLWbheNRJEuKA+cBaLFGKInDGo/HZNMuyXi8hDJ2enUOAEJBgQrV41OswSr0PqVgZvcoO9l+XxWLz2lrUipz3mHJKCCHAKCUEg2cFtiIWjTFsPi8v6mp1eWm+0JeXRinV68bTi1LGwjnkvAvQYAjD5eUo4pwL79Esm5dFISWNI1WXNec0NI5SjFAAHgLGGGPAoYZaKskWpbW6iePGe2atpsQhRAmKcIh0Oa+NN3WNHOjiQkrZ7XYTCdY6CCFRoswLD3ow6EulvG9CCJgSQgjG2KMAAGVRsrNxJhgOZ9P5vBwMVr23o9fHTDBtNcY4imNwAI0ryynGeDKeXr1yFSPmvRMyEoz7xoMNVFJrQwjgkaeCEkItWISRtZZ1OoJihIKTEet1Y/BNpAiToqqoMU0kWdJrUwIeXLfTwgRxKsbnF1zRuC17S11dFbPZLC9ySmCp1+aUIu+FFISQoiiwB3b3zhZG2DRmsdBtxUITVISBICmlMQGg4ZwpIQLC7Q4hBNXGtPsykEgl1CM9XO0RSgACRRBLzjkPIVS1JoylSjnrmBTBOZcmjHNVzxdba2+MTrKLfEYlqp1njAVCCeGcKd8wKsVSr+Utrqt5v9/udZTgAXwTR4pgzgglhCCEOOM8kpRSjDEjgDmmBBGPXCz5zpWti8tFeVQ1FhVFhRAKMMW4kULGUTpYXr57d8ObKQE27A3TGNW6IJg60yCERNpyzllrPUA1rznnaZoyAhwguMZTLNtp+9XhfrvFfvDWLlicZRlCqLHu+OyiLE02LbOp5kiEQOaXev/pKIp8pxNHUUwppQwbYwCAUkoZq6vm/8uqXHMupFQUh1goHDyGgFDgEU1lKqVkQm6/sVbXzWKeA8DG2iql6uRkdHB4NCuytJ2kadrtdldXusPlxNS1EKLb63DfAEDjGpa2kqqqPJBYSnANxriTtBlnCKEKV7WpccC9SMhukidMCIGCnS8Wb791LRCV5XmhC/C+aczFZOKMRAgRQvKiIJwUZcEoY8ZUSnEuaJLGVhsAqHTFHZdS6lo756SQAnDQTUIEw8x63U15p8Uwl5trqfOuLEspBOG0trX3Hrz3EJwnToeqKVmrFVHGMULGGE5pkiQA0DTNbDYDgCRJlJJOG0IICkhrHWjwGHxwDAUCKDjfTkUcxx7A2QrACUalVDJubawNjTHMgrPGee+98amMPQoYoRAC5cwDBISEkE0IURQTQvI8X1SlRUAFtd5xTAG8Mb7x3lrLGB0sLTlnvYfGFFJJKfD/AenmvEAf3YYuAAAAAElFTkSuQmCC",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the validation subset from the saved test-split row positions and preview its first image.\n",
    "val_dataset =  dataset['test'].select(idx_val)\n",
    "val_dataset[0][\"img\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8955f515-fbcf-484b-b1df-522c1c326f66",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['img', 'label'],\n",
       "    num_rows: 1000\n",
       "})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confirm the features and row count of the selected validation subset.\n",
    "val_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbad2bc0-babb-4465-a1af-193b08c222fa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b657be-e172-42c7-aece-0450a8f9f1e6",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3854a013-b6e0-4fab-9738-93d410329c74",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "30ff1255-784e-4503-96eb-bc01e0826659",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[10392, 44460, 16163, 49891, 13381]\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "# Spot-check lds-val subset 0: peek at its leading indices.\n",
    "with open(\"./data/indices/5000-0.5/lds-val/sub-idx-0.pkl\", 'rb') as fh:\n",
    "    sub_idx = pickle.load(fh)\n",
    "print(sub_idx[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "5c3889f3-49f7-4012-94fc-d17565d2c5ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2500\n"
     ]
    }
   ],
   "source": [
    "# Expect half of the 5000 training rows (subset fraction 0.5).\n",
    "print(len(sub_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "df904154-6c44-4afe-be75-69d7a82825bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4871, 21837, 20107, 553, 20868]\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "# Spot-check lds-val subset 8: peek at its leading indices.\n",
    "with open(\"./data/indices/5000-0.5/lds-val/sub-idx-8.pkl\", 'rb') as fh:\n",
    "    sub_idx = pickle.load(fh)\n",
    "print(sub_idx[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "bfac9520-ad7e-47bb-ac83-4a3548640521",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2500\n"
     ]
    }
   ],
   "source": [
    "# Expect half of the 5000 training rows (subset fraction 0.5).\n",
    "print(len(sub_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "41d6f882-5297-4d8a-bb44-e676b2808d19",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[7405, 5226, 1363, 6615, 7612]\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "# Peek at the saved validation indices (this reuses the `sub_idx` name from the cells above).\n",
    "with open(\"./data/indices/5000-0.5/idx-val.pkl\", 'rb') as fh:\n",
    "    sub_idx = pickle.load(fh)\n",
    "print(sub_idx[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "2b1b3261-ae38-47fa-9612-9456abb60245",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000\n"
     ]
    }
   ],
   "source": [
    "# Expect the 1000 validation indices saved earlier.\n",
    "print(len(sub_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b46d3e33-46d8-4529-b3ab-d652752a0051",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "160e394d-f981-49e8-833e-a8b6a17e7b73",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "116f1ac5-6754-45c2-82e3-63ddf51de624",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9987b5a0-58a0-4cf7-b845-12a0bbab7089",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4696e9c3-5d96-49e7-be83-7a47291ddb16",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba7bdefa-759d-41bf-b547-318b4d696688",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
