{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f04439e8-3684-4e71-a58d-1b222008c74d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"   \n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]='0'\n",
    "os.environ[\"HF_HOME\"]=\"~/codes/.cache/huggingface\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "379a0758-ffa9-4c0f-8828-6976b5752918",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a7c672bc-d87e-450e-bf89-86799b5ba6ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate.utils import ProjectConfiguration, set_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8820a3ed-9b69-423b-868b-79c11c67762e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "def set_seeds(seed):\n",
    "    set_seed(seed)\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    \n",
    "set_seeds(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e9bc9238-5dc8-4d43-8551-a2ce62bfd120",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "from diffusers import DDPMPipeline, DDIMPipeline, DDPMScheduler, DDIMScheduler, UNet2DModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af017ec5-da0e-43ed-b9d1-8544336a4736",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1c7aa6f4-4adf-494a-bfd2-dd414ef7d8e6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'config.json'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Args():\n",
    "    \n",
    "    model_config_name_or_path=\"config.json\"\n",
    "    \n",
    "    dataset_name=\"cifar10\"\n",
    "\n",
    "    resolution=32\n",
    "    center_crop=True\n",
    "    random_flip=False\n",
    "\n",
    "    dataloader_num_workers=8\n",
    "    \n",
    "    seed=42\n",
    "    gen_seed=0    \n",
    "\n",
    "    train_batch_size=256\n",
    "\n",
    "    ddpm_num_steps=1000\n",
    "    ddpm_num_inference_steps=50\n",
    "    ddpm_beta_schedule='linear'\n",
    "\n",
    "args=Args()\n",
    "args.model_config_name_or_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "11f8d9d3-a6f0-4113-802d-964b459666de",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('./data/indices/5000-0.5/counter/idx-gen-sampled.pkl', 'rb') as handle:\n",
    "    test_index = pickle.load(handle)\n",
    "# test_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ae2ce82-e7e2-41bc-a208-8c63c4262b38",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f4b8133-2eda-41cb-9ec9-85b8e964ca55",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df86fcb8-af99-499e-83f9-b60d74b7acc0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "04f73514-a407-49c6-90e5-22262cec9def",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./counter/saved.pkl\", 'rb') as handle:\n",
    "    index_image_list = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b685379f-65f1-4c98-990b-5bbc39b9cd15",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI3klEQVR4nCXWW48c13EA4Ko6p093T89Mz+wsd5e7S1K8SaQo2U5iRzEcwEYcJEDe8/f8C/wcBMhTAD8koRIDkhCS4p0il3vj7Fy7z62q8pDvV3z4u9/c3/hkDDrCyhlLPG2rB/duN3VFhKAaU1alwlbbzdbHKvrym9/8ers+Wy3PO7+5ml+dnX0E1MLZsrSZs4ICgXBOMRGAZU4GckHkrJlNmuPD/fGoqpxFskQkKiwZlBRo3XeL5WZ5FW+dXAPpmPvR0A7r2d7eMOcYcwox9d6nnI1BayjlFGO2O8NyMJrUdXW4f3D/zr3ZdCac+75jzgggDN6nrk/vP749uZgrlzs7zWLxkzVQVgZNUThqxkOW7H0MIU+gZWFAJAQWFhX7y589PDy+PhwO23bajieGChFQ4ZhCjJGzbLvYfbwQgNl0Z9Q0o6a0Rg0ZVxagkFlS9iGl4CMh2sIaQ4iIiGRQQexnd44HTVW6yhpC5CxZRAgNYFZMgpLYh7QpSxoNR1Vlcup9lOFgBAA5ZyJSVRAwZMiAggAgAooyC+ecbMB19p311cBFpJgkCUvpqpRz1/chxE3XkQnWirUkkkNMqhBSyiKiYskQgopYY4xFRUBAMtT7XkRAwTbDAZHZXnHQsN2qCAOCCHe+X2yWwXexjyFoCMGagsgWBapKYhZVYwwLswKAWouIag0BAholAmYFUNv3GUHn8467T3VtBVSUGaQLft2tWRKnImWXgvZBDWWAiJABuK7Kui5BNaVgrQG0SBRTzMymMN4nQHCFsf0mAlgC7H0X+sQCWSSpBubMoFD6zvjghE3yFiACeFckMhx89D5ag8YCIKQuN4M6Z71abclYY0xZWkRj66IqbL3NW6kNAeaEm56jFtt1KpsdV44zC+XknJnttDFtor+0pitstCaXDl1ZNE3V936z6euyBjUEpbIqUgiKoJY5i2xDXpsKDBRUVCcfzl/8dCZ2Z/+46Rd+s+6qqh6PBpjIbwVioYwG42xKzmlK0XvofZ9SiiEJo++TdcaVhog4i02cswS1DFqEbH2ugtu/+1e/qid756fnob9Y+xBENn4rKafgS0uVQ2tITGaAgpKPHlRVxYc+RfU+1FQB6mQ6VMmWRSMzM+ZYeF8L7rW7g9OrxYvXfzaoIeWcA4kPOeSUEU1ihLIaDpvLcJlXcTYCJDVojKW6dk3jrCuWm03vUXSUU7Q+0GoL1pV1uQc4ZWo378++//bxdKe5fffuyemlpBA5cvJkrKuKmFIMYbkGMk1h6MHRDb88NcDNoBgPm7oatJNsLjCLdl1nAex87q426fe//Tui0eP/efbu/bOz8/PV/Pzg4MHph5Pz00vJOXcbzZ4Aue7IusR5Me+TmmuPvv6LX/5zTeHHH/5kcWFNAFQFadsRkhHhwpBNsf3bv/6H63s3WfSL29lm5u3VN7/4xXRv993J6aMHX87aqTM4qiplzknBFGqIHBpXPHzwVTs+HlfF44s/WQjDaVaKSGY0HrmiBFBCxD/+4V9u3LybUkea+vVys7wKnG3TZKTOp8Pjm23bls4VhZUkr159+PP3T3yOhct1aY1oZt6Zzfz84qeX3w3b7c61ggoiIkNISIbAHu0fWUBDNnbrfPVpc3qSENqD66PJzrgsZ01TDSrrnLXWDt3hMT7+/vl8vbY2Qc79dmsL49r96dGXH779zs5XVT27djgmRGPAEKqy7dcLCQFRrs5PF+/f+cV8Pp+//OG7cjw+vnFzfziYHUzUGjQGiabT5uBg/3yxWW+9sAoMrGrUYtmjG+1OCiIwTdWikRi3rAKglmMKDGVtF+vVD8+fDkFrayvNvF5cvM3/+vbl9MatR3/5q6Mbt6rBICx99n23XPsU2rYdjQbK67A6L4f50YPbj27/3OLS508+rTPHkD0zW0tGVYbt+M6DBz/++GR7dYmS
jXOuLtfef/j4Yf38x58+nhwdHI+KUqlWHM0mo8zN7qydta5brS7eP7n16GeH9/dGtWzWebtZJ/aiimAA0RbOMKuozPb2/+a3v/vfx//hP52iK1ecs3P1tPn88ODdydnTJ09uza7Pdo+K4/vtuAy+J97U1BjLy7S2YV7B4Opi8eLVc4FuMhkgWVUEZasqZEhFAPDO/c8t8H/9+7/V43ays9fHmF8vED0YWal6V11uu+7tKyodpF5rA+3Bzrje+fKLQTPw2/7p05evXr+5vj/ZbXcBFDFnYptSts6ignOOLH3+1VedXz979uL6/pEYuDx9ttpuPq2XWDmpyovF+vL8sqiKtjINDZXb8WgHB9V62z998+Y///vbojDHBwcpoCkMWjLKloVJSUARqXAlof35N79mU3ZeRvWwGe58/LCYLzeBKiQhCxx71AR2sF1v37x5d3FxulwuTy8/PX39crVZ37txw5ZlzNkggoICW7IGDSlojLGoHJCpB6OHX3/9/t0ZKBRlu+okq/14clqgPdo/vrbbFq4k4bNPn16+W3b9ettvfQrbsB001fWjfTcoBZnQICHo/z/DWASIIRTBWWfJwmjcHt+wq+XKlPWbj2dX6z4zPH/z9upqNRo2KcXVarlaLzklJCoKqmvz2fHewf5uO65y9q5wZNAYK6IWERFARaL31hmFEhkVxBgtC7rz2a2jw6P502c5+Zjjh7POnCsoi4ghGI+b/b296aSZjMvD67OqqizZSTuqSscqCggqFhBBgXPqQ6+oCKCgrFklp9hNx80//ePfP/zi3qs3r88/Xc4Xq+22F8mlMzcPD+7cvLE7HVelNcRFaYiMMWVhHSAKKzOLilWFzMyQY4qIZMggAquo5hRiCmE6HA2/+Pzhvdvrzerpi5cfzy8RdDKs792+ORrUwkkkIQESAJAi+hBizkk4c2Zmy5JBAVQBIOXUbTsiUlJClawsSmisddYYN6FH9+/eOj4sjK0cla5Q4RBEFJAQAEQ0i/QxdCGElGNOImI5MwIQACpoZi9dURggRNKUE2cREVAEBVAYNoOmrg0RgogyC1K2hAZAWURBVDVy7rz3OTMLM/8fV/LYP1y7WlEAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index_image_list[0][0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "da93aa94-9e34-463f-aacb-ec3305b1ae83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAIy0lEQVR4nC2WSY9e13GGq+rUOXf4xp67OTTJpiJLciQ5UahNfkFgZ5f8hPwuBwGyC6wEWRhJENiCkcGQgoAaSJEUyf449dz9jfeee4YqL9q1K1QV8KA2z4t/97d/ZZktARu0ltigEiiAc4UxBhFR1RBZ5wBtUgSyhogkoWZAFc2akvgQQlwsFnW/1x+MfOMBFIiEmB0TM1pjCsfGELMBFQF1zhljAFAFCBGURRGJjbEGEUBBAREIKGcNqimnqiwMIAEmzVkyAwMiE6ohMAQiWUGuoQEghICEhi0BKeA1EbEltoiAYAyqIUICLQMZoyCQRQCa0CgqGUQCQGVrDaIAAhAKQkJABRFFIkIDOV3PlMuiNxbFToSQKufquiwKp5o6v4yYMa1S6gwikBpDIKqiKomtIwA1jK4w7Awg5C5JEEOMhgAxpxxjLOs1a2slw2T6dW93a/v2rb3RcAiYQ2iOjp89ffZ/Ka9S9snnbgaxzTmLiHJRs2EoatsbVLbgtlk1i05RQLIKAQEoeh8wSWUL60rrbNUfusFwtLmzvbnJjGTUDfR88ZzLnqi/OplddZ1vQpeSpMxU+f6oqgboqmyYbL8wdVzNfDMNsc2ELFFCTBWZqqrEGGTjgz85PR0N+sbCsK5DCo++e7achwFi03VAXPa4ucqgMYtwPSRbJ+EQADQQCIOTzCFpVrXedzEIFuX27o319Y3L5Wy5mvnGW+TxaHDj5tbp2dmjR98+fPgQcXX3YNgbj1xZNhczkZVBF1Lk3JBPIBSVJHTR+4QEqmK4MIW2vnO9wd69+3fv37956+7Zxcm33z1sF5dY1GXpVm37m3/79//4z1+/mrxZGw/evtvcP7jx3vufrW+Or97NKSYgYr9IdT30Xou6HBS1X52enh3F3PX7w/Xxxp98+N77H37aG22iG2tO2fut0XhnvFHX/YP9/ZjTF1/88+TwsO5Xou7wxfybR8efPSj/5q9/fvbmaCG57Tzf2N/dXr9V2tHuzXuD/vj09OjN68myWY43Nu/cO9jeu1X1h5lI0FqwB7fvVFVpGBeLZb83+Pa7b+fT+ebW1ucP/nw0Gn/z+Em9trdcdJPDV87ZqqqaxuHj///HUW9c2e2i2gYiBUEAZmtsgcRClBHBgAqiKCKAqKIYNvPp9It/+tX/fvX7j/70p/2qQiJX18iuXfnl1Zm007Sctr5ltJUYmwhJAYGR0SAJEQgRAgARMioAKaDknFOOKnm5jD88/n68NvjFL34+HI7ni1mIUUAlZ03+9PS4x1ISWmsZeAiuj3bIRQ1ogBCRAFGBFBEVSBRRBTV0PsVOYgySposFIu3duLFctaenp/PF3F9X6yeHk8nk5QcHt25vj31WbpNSJ8GvjBlWRSEqgGiIVBEUAVQ0ikoMvvNd1u7pkydCJGDm01no2ul0Pp/P2rYNIYSuW66ao+OTyeR1QTruV00W/vp/vro4vzp6d/7g8798/+DD8WBgnV22y9FobJkBoaxKRJAYkqRH3z/6+1/+w/69uwfv/2S1XHS+nV5ddd4DgKpKzsH745PjuqpCjPPFigrHZ+/Of3z28r/++/ePv3+2ubG5tbG5tjaeTCaf/tnPPv7448nrV3/x4LPd3c3p+cXDh9/867/8+svf/u69D47q/iCmmGNIMYpkzaIiqsKGmqY5uH27X9ByuRjZETPYFEWVXr6YfP3V18zm/sHdjbWNV5PDO3fuHL489L77/MGnrw9ffPnb3yjoRx999Pbo7fnJSX84iDGqCIhITillyflqOssxbq6vjXr28vI8p8St72bTWQrBumJnZ1tFQghEcHl+9rsvv7yaz88vLwe94s3k+Wq53NvZu3nj1tX88uTkuKoqzZJiapu2Wa6y5LKoXr9+61ufUyIsc1bNwAAwGPQP7t3OKs7yarlSla5dnZ+l1jdVr7+3
velXyxA6BSmcY1vu7uzMZtOUkoqkGCVnYlKht0enT58+Lwqbc2p9x8Yu5g2vr28NhheLla9Kt7e3hSpkjLPOsEHExrf9XpElFUVljd3d2Vn5bntra/L6VdM0XedD18UQlqvVZPLm+Oi4LN3m5joC+K4DRVFgNvaTT38Wga5m89ky9CpbOlfUvbouy7IgMsycRS8uZ83KT+eLwXBgjHny5OmLlxMVYea6LpvGh667f39/PB5bQ6rJt5mNQUA+PTna2b3x0w9+8vjHlyfHb1aNGfYH3gdnTd2vy7Iqi9IW7uz8ar5cTedzW1artu26UFVVPew761QkuVRVRa8uDQoBSM7MxjKHGPnq4ip0oazrYekuAGfTeWELY5nRsisNW2Jm67zvkK0P6fnkcL6Y79++uXtjm5AWs+XZ2RkRWTY5JVOVzjnLzIZVhIj46PiYCF1RxJSYdDmbP/7hh7oq9/f3P/nkE7texJT9bDZfrES1DZGZ5os2RHn54lUMKadcFK5flXVd9uteVZTWWSKTU04pZ1FeNW0IXZaccsqSY4jBd4vZ7OL8YvLycDgYsjPedxdnl2DsdLFYX1t78fzHHDvnnCuK0XgwGvSLoiicddYCYAgpxDalrFkAgeuqFBHfhBBCTFFE19fWYogxxavp7N3b4xiTKJCCLZ1Cml6cjgfVsL/Rq2tAJENEZIwB0JRikJizxCygAKBEyONBX0UMm1Wz6kKXUzaG2BiXbelcKOsUBACqsixrV/SKQb/XL11h2RiKKceUREGyZBEAUAVRUEBVJUBVYHauLIrReDxdFJcXF2IkZWOMAVUiNIiGkMk4wz5nW9i6KkvLzEZVjUER0JxVr4kBAAGvUxxdt7y2uenbNoTWEJZFIZKTGJOMihIhGbTGWAZHpseVomFDxhgFEAC5FpIIIDAwIgJgVlEFVcgpqyrPF8umbbrQARIRItD1i0QEEBCBDRGhElhmQFZQBQCALKqqAIDX6oM/WiQnBVDJOecMiPz27TuNIasgXiNfg1zfqoISIv3RxYqoqCBZrwPy9Q4iqiqAppRUIOUcUwQFVc0EfwCX/45z0VRJogAAAABJRU5ErkJggg==",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "df = pd.DataFrame()\n",
    "df['path'] = ['{}/{}.png'.format('./saved/5000-0.5/gen', i) for i in range(1000)]\n",
    "\n",
    "from datasets import DatasetDict, Dataset, load_dataset, Image\n",
    "dataset = DatasetDict({\n",
    "\"train\": Dataset.from_dict({\n",
    "    \"img\": df['path'].tolist(),\n",
    "}).cast_column(\"img\", Image()),})\n",
    "gen_dataset = dataset[\"train\"]\n",
    "gen_dataset[0]['img']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "41122a48-e6b2-4fe5-a0a5-349ef6d2d824",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "60"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(index_image_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccc1816e-129e-437e-b6cb-839b9374cd9d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cb69ca39-2130-42fd-8f25-e07420fc4c36",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI10lEQVR4nCXUx5Kc13UA4HPOPfeP3T0dpicAAwwiCZsEHCRtZJXKKpW90GvqBbT1xhtXqVgs0SUSEAFSyGEGEzv+4aZzvPD3EB/+4T++bvvAjExYWDJGhoP8+OhGkWcIgKA+xJTQmrxpWuez6Mtf/OqXbXO5Xl90rlkslhcX50iaFzbPbYhBQZFAJcaYDCKn6BEiARHioBrt7+2WpUXEENQgiWrXRQAGC6umXa+3zToe3doF6JO0wwEPqun+3iBKCDGGENuuCzESAXORUowxcZmZ6aQqinxvtn//zsPZZBa8b9pNDAEAQLTvU9v5D6fvz66XqPl4Wq/WJ8ZoXhgwajOqR3VKse2893EwrGNMgIAISZKC8uNHD/cP5/VgONmZT8dz5lwEVJIPvXcuRWka9+HTKbHZm++OR4PhIDcEhMiWkTQptL3rne86R4DWMhIQIhIigarwrePDqi7yvMoso0lRW1E1xoAkMBFAhHyEtihwMBjWtY2xD0mroiaDMSUCVAUCytgiqYKAACCKSJLog+dONrFvue99H4F8lCAieVaEFLqu88F3Xc/WZ5lYCyKpd0EEiALFJJIMGQIAUTbGMCqCKhhDTduICChwXReGuV2KA9e2GlNEVEmudd2mWXvfehdc0BA9QEVk8wxEVERExZBJIqKqomQQUY0hACRGthS8ECG3XSBI11e9uKuqYkBVFEXpQ7/ttklC9BxCFnpJKVoWgAAak6Y8y4oiVxHvnTFk0YBQ711MiS07FwCUreFu60AZFULfblwSUCAQgqgpRlAtfMuuLzSaVnJFr+qZI2DsyXfWIYFhsMi+DXVZpgjLVYuGjDF5YYmIq6zIbNlTn0NnSFMEF9Sp3W5TXk2KYpKxaTGwod3ZpHPrpjkH3NrMG/KZNWy5qoq27Vbr7aAcEBKbXFRByTshIBZJPnZROy7JoKGUvTtbvvy4EJ7OD3f6i369boq8HA7rZu03y21yIhIzbvbnxjBq9M5D57oYo/dBErg+GktZbhAppcRJxUePrJKMC9a5YoPzo69/mQ+nl+eX2+3lddNVCmvfxdPgmqbIuCiYA9HaJ7CMsfO9pCQgveujF+dcjhkijMe1amBAFk0xQvCZ66sE+6N5/fl6df7qe0MQUxIJoZMUeue9IdO4gHmWDeprdwFNnA6JQJHQWlOU1lTWZLzabHtPinWKkV3Imi6QzYpsjrgrtHN28vnpX74djspbt48vrpbJu15i6Dskyqraedd3MYEwlxnRvTuH4lekklkzqMoyK0ch8qVJqs55A8DNpl43+tvf/DvR6Jtvn7//9NPZxdny8vNs+vDs5OTqagWivt1G3yOAb1rD7JPdLHu0+c2vvnr8+D8rG04//CBpAdoDgAYZjCpEEhE2wAZn//ar393cv52SfvUglkjoNsN/+ped2fTtuw8Pj+/NZ7uZoUFVgaSUgG1umG3Btiju3rm7M9wrWV5dP5eUuIxRgwDW9YANAwCh4p/++F/7BzdT8iTRbZq+3UYEW1VOZLXa7h3cmEwmeZFbyxLk3buTFz+/6r0nTvb/NwetqiJuFqvLN1yui6EkSgBoCIiQEHg23UUFVkq9c9cXi7PPydD08MZ4Opvuz3em43xYc2bJMBN7Lz/87eX1coscIabgPVs+uHWrGh2/ffp3SOs79wfTwwmiGqPGoErgrt0mDqiyOj+7fvfWr5aL5fLl06eD8fjwxs0vvn48Hj3gzLC1SGYyGe3Od08vFl3rAEDAGMApZJKsFjtDowhmOjxQck23SBoBgVOIQSAv7PVm9dfnz2rEYVFwCu3l+af15uzd251bR/cfPzm+d7+qa+08xbhZLDvf7+3Pd8bDFLfaXRWD6ZMn97+8PSltaMPVqvkcou98KyJMiJLSYGd258sv/vbs+3a1oOSNIWO4Df3q5PLH969/fvtmOBrP
6sHOYGbyycFsJ4Tq+Gh/d1Kul+3Jp1e7s+rgeD4oyPf9drlqXaMCSBZVmNmAQkxxtj//9e9//+J/v+uXlwqwbLaqijV/cXD04v3H//7z/9yf3zg+uH3w6J/Ho6JvE/pNaVgyypLDfsWxXFxsX79+2fSL0SRHY1FURVgkGsOSkircffggM/j9X74pi2IEcH11tV2cu9B68d6yVIOzbXPx8/PRdEKx41Ttzbiuy4cPHhRF0Tf9i59e/vjTjzs71Xh8bCAH8oYi+xBzRBW1bImze4++9OLfvHo9nUzzwfD1dnW1uL5er7K6MnW5XnXrkw+DZrVbZ7UV17ej4TTnbLNtfnrz4s/ffhNSqAeDZhuykpEzA4ZFUlIjKqrAbNHyPzx+AmSvL6+LcjAYTz99XDddj8BoABlD03cNBbuzWm5evnr76ex0sVienJ+9ePX3xXo1n04eGtM5LySUAEnZMBORSOr7jjNmNkVZPXj05cmHk81y+bke9QltUZ+enn/kT3vj+XBYF0UZJb0/u37+9uW23WzaTe/7PvZFmd+4eTgY1YqiqGQMIDCzYWZECiG4vpc8Y6CyLG8e3VwPh5/PPl+s1uum56w4ubhwLgyqarW5XK1W6/XSB6cAxmBR8OH+dL47nYxrUc+2zvPMWhZVRkRABFCJMQQPqEkICQClLLNHj7548+7JN99955aLznV91xiEFH2MAVGrqpjNpqNhvTPMbxzu1nWV22Jvd29QlaIiKikmVkBVTSn54JSUiFKMAglAg+uLjH/329/cu3v8w7On796/O7+8bpo2JbK2Ojrcu3vr1u5sXBaWWcsqszbLbJlnBQCG4GNMKQmDgiRRlZACIDFZRFAUUAk+uL5nojtHN44O5ovF4q/Pnn08OQWV8ah+cO/2zqBGEIVkDJAxRAbQdL3rnXMxhBhFhEWSKqEKKKQY+64nQiAwBiVqjBEECCgzvDeb/uvjf7x/fIvZFLnJLauK9z2AkCFCEoWQpHNu03Yu+BCjqHKMYkgNIAJITH1qbcZoUBKE4GNIIpEAAFQljQZ1XZZEgCCiMSVRSIgAACoKqqoQYuxc70KMIiml/wNkiuAfCYLZxAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gen_images = [gen_dataset[i]['img'] for i in test_index]\n",
    "gen_images[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "371b1704-2c31-460f-91dd-f8d620d577e1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "60"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "random_images = [index_image_list[i][0][0] for i in range(60)]\n",
    "len(random_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1253cd76-d79d-4f47-8ad6-697496b94182",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI3klEQVR4nCXWW48c13EA4Ko6p093T89Mz+wsd5e7S1K8SaQo2U5iRzEcwEYcJEDe8/f8C/wcBMhTAD8koRIDkhCS4p0il3vj7Fy7z62q8pDvV3z4u9/c3/hkDDrCyhlLPG2rB/duN3VFhKAaU1alwlbbzdbHKvrym9/8ers+Wy3PO7+5ml+dnX0E1MLZsrSZs4ICgXBOMRGAZU4GckHkrJlNmuPD/fGoqpxFskQkKiwZlBRo3XeL5WZ5FW+dXAPpmPvR0A7r2d7eMOcYcwox9d6nnI1BayjlFGO2O8NyMJrUdXW4f3D/zr3ZdCac+75jzgggDN6nrk/vP749uZgrlzs7zWLxkzVQVgZNUThqxkOW7H0MIU+gZWFAJAQWFhX7y589PDy+PhwO23bajieGChFQ4ZhCjJGzbLvYfbwQgNl0Z9Q0o6a0Rg0ZVxagkFlS9iGl4CMh2sIaQ4iIiGRQQexnd44HTVW6yhpC5CxZRAgNYFZMgpLYh7QpSxoNR1Vlcup9lOFgBAA5ZyJSVRAwZMiAggAgAooyC+ecbMB19p311cBFpJgkCUvpqpRz1/chxE3XkQnWirUkkkNMqhBSyiKiYskQgopYY4xFRUBAMtT7XkRAwTbDAZHZXnHQsN2qCAOCCHe+X2yWwXexjyFoCMGagsgWBapKYhZVYwwLswKAWouIag0BAholAmYFUNv3GUHn8467T3VtBVSUGaQLft2tWRKnImWXgvZBDWWAiJABuK7Kui5BNaVgrQG0SBRTzMymMN4nQHCFsf0mAlgC7H0X+sQCWSSpBubMoFD6zvjghE3yFiACeFckMhx89D5ag8YCIKQuN4M6Z71abclYY0xZWkRj66IqbL3NW6kNAeaEm56jFtt1KpsdV44zC+XknJnttDFtor+0pitstCaXDl1ZNE3V936z6euyBjUEpbIqUgiKoJY5i2xDXpsKDBRUVCcfzl/8dCZ2Z/+46Rd+s+6qqh6PBpjIbwVioYwG42xKzmlK0XvofZ9SiiEJo++TdcaVhog4i02cswS1DFqEbH2ugtu/+1e/qid756fnob9Y+xBENn4rKafgS0uVQ2tITGaAgpKPHlRVxYc+RfU+1FQB6mQ6VMmWRSMzM+ZYeF8L7rW7g9OrxYvXfzaoIeWcA4kPOeSUEU1ihLIaDpvLcJlXcTYCJDVojKW6dk3jrCuWm03vUXSUU7Q+0GoL1pV1uQc4ZWo378++//bxdKe5fffuyemlpBA5cvJkrKuKmFIMYbkGMk1h6MHRDb88NcDNoBgPm7oatJNsLjCLdl1nAex87q426fe//Tui0eP/efbu/bOz8/PV/Pzg4MHph5Pz00vJOXcbzZ4Aue7IusR5Me+TmmuPvv6LX/5zTeHHH/5kcWFNAFQFadsRkhHhwpBNsf3bv/6H63s3WfSL29lm5u3VN7/4xXRv993J6aMHX87aqTM4qiplzknBFGqIHBpXPHzwVTs+HlfF44s/WQjDaVaKSGY0HrmiBFBCxD/+4V9u3LybUkea+vVys7wKnG3TZKTOp8Pjm23bls4VhZUkr159+PP3T3yOhct1aY1oZt6Zzfz84qeX3w3b7c61ggoiIkNISIbAHu0fWUBDNnbrfPVpc3qSENqD66PJzrgsZ01TDSrrnLXWDt3hMT7+/vl8vbY2Qc79dmsL49r96dGXH779zs5XVT27djgmRGPAEKqy7dcLCQFRrs5PF+/f+cV8Pp+//OG7cjw+vnFzfziYHUzUGjQGiabT5uBg/3yxWW+9sAoMrGrUYtmjG+1OCiIwTdWikRi3rAKglmMKDGVtF+vVD8+fDkFrayvNvF5cvM3/+vbl9MatR3/5q6Mbt6rBICx99n23XPsU2rYdjQbK67A6L4f50YPbj27/3OLS508+rTPHkD0zW0tGVYbt+M6DBz/++GR7dYmS
jXOuLtfef/j4Yf38x58+nhwdHI+KUqlWHM0mo8zN7qydta5brS7eP7n16GeH9/dGtWzWebtZJ/aiimAA0RbOMKuozPb2/+a3v/vfx//hP52iK1ecs3P1tPn88ODdydnTJ09uza7Pdo+K4/vtuAy+J97U1BjLy7S2YV7B4Opi8eLVc4FuMhkgWVUEZasqZEhFAPDO/c8t8H/9+7/V43ays9fHmF8vED0YWal6V11uu+7tKyodpF5rA+3Bzrje+fKLQTPw2/7p05evXr+5vj/ZbXcBFDFnYptSts6ignOOLH3+1VedXz979uL6/pEYuDx9ttpuPq2XWDmpyovF+vL8sqiKtjINDZXb8WgHB9V62z998+Y///vbojDHBwcpoCkMWjLKloVJSUARqXAlof35N79mU3ZeRvWwGe58/LCYLzeBKiQhCxx71AR2sF1v37x5d3FxulwuTy8/PX39crVZ37txw5ZlzNkggoICW7IGDSlojLGoHJCpB6OHX3/9/t0ZKBRlu+okq/14clqgPdo/vrbbFq4k4bNPn16+W3b9ettvfQrbsB001fWjfTcoBZnQICHo/z/DWASIIRTBWWfJwmjcHt+wq+XKlPWbj2dX6z4zPH/z9upqNRo2KcXVarlaLzklJCoKqmvz2fHewf5uO65y9q5wZNAYK6IWERFARaL31hmFEhkVxBgtC7rz2a2jw6P502c5+Zjjh7POnCsoi4ghGI+b/b296aSZjMvD67OqqizZSTuqSscqCggqFhBBgXPqQ6+oCKCgrFklp9hNx80//ePfP/zi3qs3r88/Xc4Xq+22F8mlMzcPD+7cvLE7HVelNcRFaYiMMWVhHSAKKzOLilWFzMyQY4qIZMggAquo5hRiCmE6HA2/+Pzhvdvrzerpi5cfzy8RdDKs792+ORrUwkkkIQESAJAi+hBizkk4c2Zmy5JBAVQBIOXUbTsiUlJClawsSmisddYYN6FH9+/eOj4sjK0cla5Q4RBEFJAQAEQ0i/QxdCGElGNOImI5MwIQACpoZi9dURggRNKUE2cREVAEBVAYNoOmrg0RgogyC1K2hAZAWURBVDVy7rz3OTMLM/8fV/LYP1y7WlEAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "random_images[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ffd8ab02-5fc4-4db4-ad92-b1f5052d79d7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "60"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trak_images = [index_image_list[i][0][1] for i in range(60)]\n",
    "len(trak_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e5bc4262-e856-4dc0-bc4b-0f3a930f0bbc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAJF0lEQVR4nCXQx5Jkx3UA0Js3b77M58t0V1eb6THAYDQCDUhEkDIRCipC4l47/aIW+gCusGCIXIASgBAw4GBgxva0Kfde1TPptdD5hMP+/d8+7UdLggvOlOAh2CylDx5cqiQBAIhhNAY8KVluN7vB5M5W//yv/3Jor7brN12/Xa3u3r19jQRpJvNCDeMYIDIOIRgzWAJG3lmOXiAQgyLNj47OpBKMcecRGYbgxxEAEBBum916vd3cmYePTjkbtG7KkufZ0XKZO2+Mddq4EEE7K4jJRGlpjLaUCqaKXAg6np48vHwymxx5H7QeQgiMYQzQD37Xji9++ObV1V0MajLNr959oyTPS2I8EYLVs4mxZrftvR9Pl6fGOeTAkVlvIkR6cHHv+GQmVTafnS4X9xKRRYAYwTljrQmBtXu7aV+7CPP5fFIWdakEd1mR5IUKMYQQm/3Q7vv1+iCJpAqIwBE5R0AKEGh5vlCZTGWlkowS5mAIPnLkHrQD4yMYP46mESLOppPpRLb7tj242azinPfjIIhiDMGC4JTm3EfnfEAunHcumGEcyIgeohkbH32SVWScDsErmRrn9n2vR71ve0qMJCjylDjrB2NN7AZtfdBGSxGIsehcKhOZ8IjACZUS22YbwbOIVE9L4uLmpZVg+rZxwQBC8HY/dOtm2/et6c3Y837UcxCMiaqYOu8HrY3TUibGWBvBOi8V5zwiYYwgJAiB1njiSEYHT7BrdPPy+/07ihwcB8uhs0Oz34VovRXapP2elS0eeme0C1Eba6oyT4TwMe6bVggSkQEk/aHvjclsdugGTphnCZnBM4RUSd+91wfrAXXCesV19AEY52rUrOso+LTryfvB6I7QumCtdmNvOI+IUTDcNXuZzL3F9+9bkYxC8GqSkBCYcjnN60qJuqRZpSapgoDBy6bB4OfIzyM7sTZFyqbHZzKrARTDRCmVJEwpVhTy4uKYE2v3XXAMQSQ8BY/IaBz80EcKMWpzcNDQcYJRUMiuf7j68d1dpOpITkxn2r0hUrKodwbag3U2GbVlMEynFGMw1ozj2A+9tc4Y6yyMnROKSymIyFlHPsZh2GIKHtNGk/V5ny2WTz6W+WS73nW2GRxPGNvuu/V2Px5aiSERKIgLHSNzDKy2o3cWMIx60EM0xjKCGMPxyST4kQC4diGC8D7tR4psmk1mq2Z/df2ShWB9CCE43ZveWuMA0LKQJyiybGt2g4tVBugAAiPOVZqUhfKM3axX7YEto9daU9/joeMyz1I1Cz4NkJvN6qfnL8pSzebzpu3BOx+s1QNRorJ0GMYQsdOB0zQhsbg4GttbsLYssMiLSVXLNOv0gTFsd3uBQG2jNq3+3c9+C1H9zxff3a5ere62fbs7Oro87Pf75hBDcENnxo7FGCsDXNiRmsPoAn/65NHPPvl9Suanv37O/S4hG7zrh66qs7IqAUIiiYwtf/PpPy6Pz2IMTz/0KsL+7ubJg4tqWt+uN5fnZ7PJJFdykuftdn19uwlceca4YlzIj//2l4zXdZ03d3/mdrCF3sRGB3Z8spzVhQ+eI2P/+R+fXZxfQuiViKDNYbfZ7tuAXHu/7/qze/cXx4uqLrNUWe0/++xPf/jsj564ULEu8lQo5PzexanZbF5//20i2+k0cYRCilxGIuQIdLE8l5yjZ7wb3KGNu10GMS1LLpWd2PniWE1KmaUkRF1nf/8Pf/enL57d7HYssF0z3AzbNJWLi/vl2UdXf/ma+zFJ1OJyySnm0osEnR/pxdefT5Mi4TH0bWh2pml27a7RA1Py/Px8TlhOc84jchajraf5ow8frb98NvQdAkamtEcbqNMosqrijKO8WF4i2bZ5a7wJ4KnfrBzsqmm1uX6/f/Mq9U4AoBnt0K+0+a83r6uHDx7+6tenl/ezItd7TSHqbnTBT6flpC7Bd8P2
fTZZ/uLjB08fTKd5Eknv2lttu/3YWu9pNp0NfX/2wcP5xfKr/hCbbXQ2yVKSotGmub1bv/qpfv7i6GQ5UUqqklhxNKutcxcnx8ujXB/s5vr72aI6+/iszrkdD2+uXreHjY0OGAFDque1TJOIMD87/ejTX99++3XYbQLCymqjSGD6+Hz+7OXV519+eVzUy+XF8dNPFkdV3x0SpidZ4ZnUK5vYbQpJs9r99fnzze52MstRcogUQyDnLUmK0QPCyf176PWbb74qqjoty6brbl4+s65P0iQUpS+n1/v+5ttn5XQCdtRMYlB1lU2fPM5SqTv9/Lsf//fZd0WuThYLZJGhI+5IG01CRAAp0ySRyZOPHLi7m7vl2b3cu+3tq83Nu7vdvphMi+l8dX27ffe6aNezXBZUmWGfVnNGed+PP7x+9+f//mK0pq5rPXhSHEnwiORCYNGHGBlDmSqR8A8++QS//9E6SFVW1UdXb99u244grUqpu2y1a/pDl3O2Xe9egF3dXPf9sNo23718udm3s7piHAejEyQWWARPIkk48RhhHAeSCackFfzeo0e7VcMYUlr3XvCkePv2PSJfHi8XR9MQIAS4Xu3eXV97b7TRozW9GSjhJ8vjrEwjBsDIOfqIJJUkEkKQ0drokYRAztK8QEbjoHmav9/s9joIVb56e9v3vszkYd82bdt3nfcekSGClLg4quazalrLGA0JJQQJIXwIRJyICJGN40AqUQzAxwgBWIBoT08X5+fnzYuX1ozG2pvb6zWD4LR1NoLPMjmfz8tMljm/OJ+nqUTEo+ksS5X3PgJjzhNyQuTOudV244JfLBbWWecNi1EPB5Xg7/7pt48fP3j5+s3t3Xq9PXS9do7nGC7Pjx+cn07rQknOwKa54EjIhRASGXMhOu9DjITIvffGmaE/IKeqrLxzLlgGXutRD32Rqg/un394eXbour989ezN1S0ymNXpL58+LjLpnXZOcxKcEBiGiF0/aOu0c8455z1FiCFEH1yWZizGtm05YxGCSDgDHgJwSjgSRD+f8J//zcOzxSxVqipkkcngXdcZEolICJH5EJ3zh2HYHoZeW21sjJGci8hAcELJQ4z73TbNFHKMEa21IcTgAzJgACH46aSsyzwRgrMYIVjHhbWcEBF8CCzECGG0bvf/kT567/4PcQTMQleFHicAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trak_images[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "cc57274b-d591-44f0-b9f8-53fda6c34fa5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "60"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ours_images = [index_image_list[i][0][2] for i in range(60)]\n",
    "len(ours_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "db088da1-30f2-429a-be78-74dd2ea06099",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI6ElEQVR4nAXBSY+c61UA4HPe6Rvrq+quco9uu+22E4VwbxJEiIQEOwQLVkj5A/y1sGTJCiQiJCKUHSCm2PfeXNvtdtvVXdM3vPM5PA/+zV99ZX0WEiViqaVAbhr97OKiNAYBADhlAlBGlW7yIesUzasf/MDbwzjunJt2h/1ms2bI2khTqJQiARFwzin6KABUSgkhS5QScd4dXV6eF0ZyTgRCIBJRiEFJwSAH5w77wY1puZhJkQSn+azsWnN+Oo8p+Bh8DJO1MSVAVlKEwoeYlEboFo0x5vzk6U+//uPz80tnp+1mbScLDMwQIzmbbj++u1/vgXTXNtOwURJNqVAorURbzFKO4+StC7N2HlIEZESIOQKQurm+vrw6b5ru5MnlyxevjSlzzufnF3YavXNE2Pfh7Te/tz5XZTtvmqbUEkkppaTkDBE4JGedHycHANpIBhaIKFChZED17ObZrG3qoit1JRVYf4ghNnWtNaaMnHHy46HfKolt1zaV8W5yKc9VxwwxJCGRmTmxRImSM2diAhScOXMK3qmk3UR53ATyevlktu+3OSUlT0JMj7vNdtevPz/m7JBZKw0gXEgpJqV9IGLKSkmBmDNJIaRGEkDMUotxGpgzAKrlybFR5u5300TusH6wblBa+Wlc7w7fvX+/3ayTTTEI72NdtZkkCg0CJhdUSlrrlKMAYGBdSCFAKGQEaVAlGQNLKVRKQrCynu3de9rdMwJp+Z36dL/d3j98FpgFFCmqHLEKGGJ2PqeUck7G6AoAgVOKSgmQWoFx1rkcTGms80JAWWkVLYCSdV2Hj5+GTzESWCF6IQ7OEQtldAwYIgpQMWNO0dmQc6CcQogxRClRKmCQYYhlVXCWm8dRKK+0bGeF0loYUTRm1pambXU3q5qynGwYbTz0gbkkqkKUIRARoFQh5xBzTDkzE2cGUgq7WQPAwzBxAoVaiwpZSlTRc/CsgDAlx8LL41qyrEhv1vvb7VaoUld6GqO1QUmlhXAxWutCzEzEnIzBmBNG1t5Zb0MMISRijC7LQhTGaCOBSTHj5EZZSGqaYUooTC7qdllVVROcc5F8JET2MYbt1tmRKSESIlOizOQjh+iYmZlC9DlB8EELBQjL1YI5KAAVYjS6AC0CTEiqrOYPw6MdH4XATEDMOQUbXU6ZiQmyECwkOR8lUllIBiEAhBBNU9ZVS0reP34ZRslwlFJUIajgdVm0ppbZ6egiZXF/dz/v6tWT5WGwOaeQM6cIKJQ2OUQGiEyJsK3ar3729f7hy2H7UGjZ1M352WnRNoMfALE/jEaCcq7wgc9PbyDTuHu7P2x3262dxtPT1eHQb7Z7YKRgKXkBWNcNCEFZBEpjSD/9yS9++cu/7febf/qHv3f9o8iUvLfWLhazelYLpMIoFaJ+8fx1XR1JpOXxvn/c+Kl/9vTy8vzs7bffXl+/urw4L43qmnrYbT68u93u+ykEgWa5mv35n/4Zgzw9eSqhdIcwiD6lELQ5u7ycNQVRkgLV2fn1cnkBmJWi1fGyX8xvrp9dvHxZVJVW6i/+8q9vXr+u6rIsC8Hwm3/5t1/96u9ShLKCH7262X+++69+d3Vz8+z6hx/f3W57f9Y0omhdROonKUAJELO2E0Io1AKEQhQARquyKKWUR7P2fLVq2kYXBQhVNt0f/fxPjp5cEKuibDKJ3/3fm++++TZmPrt+PSaxG2JiNesWbdPN2nnXLYq6VofdF5lC1zUEYffw6MfJ7g//+ut/VkXxo5fXhn0lGSUKJQCobIrnL65v79feJ+toPluMwXufhazaxaIr2qqdr46WhDSNm5gTASly00DUNsYG+9277w/rdaFUKURblpVS//Hb39x+eHfx4mb55Kysq8f7hxQCZwrEylRXV8/H/ecvd+9Pzq6+/skPz45LrTBCHu3k/TSGMREpo2T0tqjK
umtU1RZllXO4OFnNVysfwsfbD99/vP3+/bvZrIMU+975caqqIqdUGX3UFYbN/fs3x/P6+molMQY/Tf1hmkbKGYVCJmVKyZxi8rPjk+cvX68xPn4JoEtPjChMac5XR4+bzZv/+c95MyurBZGp6yoFB+SlSHWtN+swHh4WR+3kxoeHh37sjdGAkjMxg8opCMkxekDx5OwMws6Ffdkem25lx+Hh9s04jlLLcjbrjk+8pe3hwAhaQU6OyZdV8fTyQgsMNq7Xuw+398x8cXmqBRFIiaS8d0JJyqkoSnl0jPkqkLMeVqeXzk73H95+uLtLgNV8YaryYbvZHQ4kZKOBUvDeVkVT1Y0P6dP60/++edv341HX1fVY1Aa1QEaVKSMIAgCAqm5pcXyhXj1ueiF1Xc/qZvH5y/3jvl9mcTRbJM4xBNRaaj3047t3HzdtOYzTerN/+/v368et0frHr+vDMFaYVJaAqHRRKKMBYBz6brHQpm46IlEkLwTKqj1KbAJX//7f7z8/jqujVVUWKGXItNls7tYPRNF6Z30YnZVavn7x/MnpSsmstFBaMbMqq8oURVEU1o7aaARCoau6iZIpU9l2j/108JjN4puP/XbA1aKOftgd+vXDQ0pZCiEl1JV8/vR0tTo6XnRSUVmasiyM0TmTUkorZYQU0zjoQpdFkTPlnFEAeb9czs/OTofbNfgMEiZnN9vAeRomC5iPj9unFxezpmgqvDhfaqNz5sV8XpUmpkjECZMSUiOKGP3Hu7uYwuXTq+BDSF4gOndg8n/4By+fPbv4st4chn6zm/YHy2yahn/21avXL667WS0F5+yMkYBCSK2UZgBKOeZIDEoImVIKwQ7DHoVcLlcxhJSTEOydG/p9XRaLbvbjV8+FgH/89W+nYWpns7PV+S9+/nVVaO9siE5JhUIwIjH3o7XeuxhCDCllxUBEOVNaLOaZaLPZKCUzpbI0gCLE2NaVKQwgFEY9v1oy5KvLy5PVrKmKlFNIkQCkkIDIBJHyYRx3wzSF6EMkIkWZENFoo5WKKT98uT86XqAU3pNzngkGayuCpNDa8ei4m3fzRdcpBcQZEAFBSQ3IREzEOZMNYT9NNuaUKOX0/+AK6q2lo87IAAAAAElFTkSuQmCC",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ours_images[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cb24da0-787a-4e79-a7fe-5780ea4763ec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ff7d9e6-bd65-4656-a9a2-a2aca804d652",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2634fa58-babe-452b-97f4-8c9c38fceaa7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI10lEQVR4nCXUx5Kc13UA4HPOPfeP3T0dpicAAwwiCZsEHCRtZJXKKpW90GvqBbT1xhtXqVgs0SUSEAFSyGEGEzv+4aZzvPD3EB/+4T++bvvAjExYWDJGhoP8+OhGkWcIgKA+xJTQmrxpWuez6Mtf/OqXbXO5Xl90rlkslhcX50iaFzbPbYhBQZFAJcaYDCKn6BEiARHioBrt7+2WpUXEENQgiWrXRQAGC6umXa+3zToe3doF6JO0wwEPqun+3iBKCDGGENuuCzESAXORUowxcZmZ6aQqinxvtn//zsPZZBa8b9pNDAEAQLTvU9v5D6fvz66XqPl4Wq/WJ8ZoXhgwajOqR3VKse2893EwrGNMgIAISZKC8uNHD/cP5/VgONmZT8dz5lwEVJIPvXcuRWka9+HTKbHZm++OR4PhIDcEhMiWkTQptL3rne86R4DWMhIQIhIigarwrePDqi7yvMoso0lRW1E1xoAkMBFAhHyEtihwMBjWtY2xD0mroiaDMSUCVAUCytgiqYKAACCKSJLog+dONrFvue99H4F8lCAieVaEFLqu88F3Xc/WZ5lYCyKpd0EEiALFJJIMGQIAUTbGMCqCKhhDTduICChwXReGuV2KA9e2GlNEVEmudd2mWXvfehdc0BA9QEVk8wxEVERExZBJIqKqomQQUY0hACRGthS8ECG3XSBI11e9uKuqYkBVFEXpQ7/ttklC9BxCFnpJKVoWgAAak6Y8y4oiVxHvnTFk0YBQ711MiS07FwCUreFu60AZFULfblwSUCAQgqgpRlAtfMuuLzSaVnJFr+qZI2DsyXfWIYFhsMi+DXVZpgjLVYuGjDF5YYmIq6zIbNlTn0NnSFMEF9Sp3W5TXk2KYpKxaTGwod3ZpHPrpjkH3NrMG/KZNWy5qoq27Vbr7aAcEBKbXFRByTshIBZJPnZROy7JoKGUvTtbvvy4EJ7OD3f6i369boq8HA7rZu03y21yIhIzbvbnxjBq9M5D57oYo/dBErg+GktZbhAppcRJxUePrJKMC9a5YoPzo69/mQ+nl+eX2+3lddNVCmvfxdPgmqbIuCiYA9HaJ7CMsfO9pCQgveujF+dcjhkijMe1amBAFk0xQvCZ66sE+6N5/fl6df7qe0MQUxIJoZMUeue9IdO4gHmWDeprdwFNnA6JQJHQWlOU1lTWZLzabHtPinWKkV3Imi6QzYpsjrgrtHN28vnpX74djspbt48vrpbJu15i6Dskyqraedd3MYEwlxnRvTuH4lekklkzqMoyK0ch8qVJqs55A8DNpl43+tvf/DvR6Jtvn7//9NPZxdny8vNs+vDs5OTqagWivt1G3yOAb1rD7JPdLHu0+c2vvnr8+D8rG04//CBpAdoDgAYZjCpEEhE2wAZn//ar393cv52SfvUglkjoNsN/+ped2fTtuw8Pj+/NZ7uZoUFVgaSUgG1umG3Btiju3rm7M9wrWV5dP5eUuIxRgwDW9YANAwCh4p/++F/7BzdT8iTRbZq+3UYEW1VOZLXa7h3cmEwmeZFbyxLk3buTFz+/6r0nTvb/NwetqiJuFqvLN1yui6EkSgBoCIiQEHg23UUFVkq9c9cXi7PPydD08MZ4Opvuz3em43xYc2bJMBN7Lz/87eX1coscIabgPVs+uHWrGh2/ffp3SOs79wfTwwmiGqPGoErgrt0mDqiyOj+7fvfWr5aL5fLl06eD8fjwxs0vvn48Hj3gzLC1SGYyGe3Od08vFl3rAEDAGMApZJKsFjtDowhmOjxQck23SBoBgVOIQSAv7PVm9dfnz2rEYVFwCu3l+af15uzd251bR/cfPzm+d7+qa+08xbhZLDvf7+3Pd8bDFLfaXRWD6ZMn97+8PSltaMPVqvkcou98KyJMiJLSYGd258sv/vbs+3a1oOSNIWO4Df3q5PLH969/fvtmOBrP
6sHOYGbyycFsJ4Tq+Gh/d1Kul+3Jp1e7s+rgeD4oyPf9drlqXaMCSBZVmNmAQkxxtj//9e9//+J/v+uXlwqwbLaqijV/cXD04v3H//7z/9yf3zg+uH3w6J/Ho6JvE/pNaVgyypLDfsWxXFxsX79+2fSL0SRHY1FURVgkGsOSkircffggM/j9X74pi2IEcH11tV2cu9B68d6yVIOzbXPx8/PRdEKx41Ttzbiuy4cPHhRF0Tf9i59e/vjTjzs71Xh8bCAH8oYi+xBzRBW1bImze4++9OLfvHo9nUzzwfD1dnW1uL5er7K6MnW5XnXrkw+DZrVbZ7UV17ej4TTnbLNtfnrz4s/ffhNSqAeDZhuykpEzA4ZFUlIjKqrAbNHyPzx+AmSvL6+LcjAYTz99XDddj8BoABlD03cNBbuzWm5evnr76ex0sVienJ+9ePX3xXo1n04eGtM5LySUAEnZMBORSOr7jjNmNkVZPXj05cmHk81y+bke9QltUZ+enn/kT3vj+XBYF0UZJb0/u37+9uW23WzaTe/7PvZFmd+4eTgY1YqiqGQMIDCzYWZECiG4vpc8Y6CyLG8e3VwPh5/PPl+s1uum56w4ubhwLgyqarW5XK1W6/XSB6cAxmBR8OH+dL47nYxrUc+2zvPMWhZVRkRABFCJMQQPqEkICQClLLNHj7548+7JN99955aLznV91xiEFH2MAVGrqpjNpqNhvTPMbxzu1nWV22Jvd29QlaIiKikmVkBVTSn54JSUiFKMAglAg+uLjH/329/cu3v8w7On796/O7+8bpo2JbK2Ojrcu3vr1u5sXBaWWcsqszbLbJlnBQCG4GNMKQmDgiRRlZACIDFZRFAUUAk+uL5nojtHN44O5ovF4q/Pnn08OQWV8ah+cO/2zqBGEIVkDJAxRAbQdL3rnXMxhBhFhEWSKqEKKKQY+64nQiAwBiVqjBEECCgzvDeb/uvjf7x/fIvZFLnJLauK9z2AkCFCEoWQpHNu03Yu+BCjqHKMYkgNIAJITH1qbcZoUBKE4GNIIpEAAFQljQZ1XZZEgCCiMSVRSIgAACoKqqoQYuxc70KMIiml/wNkiuAfCYLZxAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from datasets import DatasetDict, Dataset, load_dataset, Image\n",
    "\n",
    "dataset = DatasetDict({\n",
    "\"train\": Dataset.from_dict({\n",
    "    \"img\": gen_images,\n",
    "})}\n",
    "                     )\n",
    "gen_dataset = dataset[\"train\"]\n",
    "gen_dataset[0][\"img\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5e01b83a-e639-4b2d-afeb-a886411a3676",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI3klEQVR4nCXWW48c13EA4Ko6p093T89Mz+wsd5e7S1K8SaQo2U5iRzEcwEYcJEDe8/f8C/wcBMhTAD8koRIDkhCS4p0il3vj7Fy7z62q8pDvV3z4u9/c3/hkDDrCyhlLPG2rB/duN3VFhKAaU1alwlbbzdbHKvrym9/8ers+Wy3PO7+5ml+dnX0E1MLZsrSZs4ICgXBOMRGAZU4GckHkrJlNmuPD/fGoqpxFskQkKiwZlBRo3XeL5WZ5FW+dXAPpmPvR0A7r2d7eMOcYcwox9d6nnI1BayjlFGO2O8NyMJrUdXW4f3D/zr3ZdCac+75jzgggDN6nrk/vP749uZgrlzs7zWLxkzVQVgZNUThqxkOW7H0MIU+gZWFAJAQWFhX7y589PDy+PhwO23bajieGChFQ4ZhCjJGzbLvYfbwQgNl0Z9Q0o6a0Rg0ZVxagkFlS9iGl4CMh2sIaQ4iIiGRQQexnd44HTVW6yhpC5CxZRAgNYFZMgpLYh7QpSxoNR1Vlcup9lOFgBAA5ZyJSVRAwZMiAggAgAooyC+ecbMB19p311cBFpJgkCUvpqpRz1/chxE3XkQnWirUkkkNMqhBSyiKiYskQgopYY4xFRUBAMtT7XkRAwTbDAZHZXnHQsN2qCAOCCHe+X2yWwXexjyFoCMGagsgWBapKYhZVYwwLswKAWouIag0BAholAmYFUNv3GUHn8467T3VtBVSUGaQLft2tWRKnImWXgvZBDWWAiJABuK7Kui5BNaVgrQG0SBRTzMymMN4nQHCFsf0mAlgC7H0X+sQCWSSpBubMoFD6zvjghE3yFiACeFckMhx89D5ag8YCIKQuN4M6Z71abclYY0xZWkRj66IqbL3NW6kNAeaEm56jFtt1KpsdV44zC+XknJnttDFtor+0pitstCaXDl1ZNE3V936z6euyBjUEpbIqUgiKoJY5i2xDXpsKDBRUVCcfzl/8dCZ2Z/+46Rd+s+6qqh6PBpjIbwVioYwG42xKzmlK0XvofZ9SiiEJo++TdcaVhog4i02cswS1DFqEbH2ugtu/+1e/qid756fnob9Y+xBENn4rKafgS0uVQ2tITGaAgpKPHlRVxYc+RfU+1FQB6mQ6VMmWRSMzM+ZYeF8L7rW7g9OrxYvXfzaoIeWcA4kPOeSUEU1ihLIaDpvLcJlXcTYCJDVojKW6dk3jrCuWm03vUXSUU7Q+0GoL1pV1uQc4ZWo378++//bxdKe5fffuyemlpBA5cvJkrKuKmFIMYbkGMk1h6MHRDb88NcDNoBgPm7oatJNsLjCLdl1nAex87q426fe//Tui0eP/efbu/bOz8/PV/Pzg4MHph5Pz00vJOXcbzZ4Aue7IusR5Me+TmmuPvv6LX/5zTeHHH/5kcWFNAFQFadsRkhHhwpBNsf3bv/6H63s3WfSL29lm5u3VN7/4xXRv993J6aMHX87aqTM4qiplzknBFGqIHBpXPHzwVTs+HlfF44s/WQjDaVaKSGY0HrmiBFBCxD/+4V9u3LybUkea+vVys7wKnG3TZKTOp8Pjm23bls4VhZUkr159+PP3T3yOhct1aY1oZt6Zzfz84qeX3w3b7c61ggoiIkNISIbAHu0fWUBDNnbrfPVpc3qSENqD66PJzrgsZ01TDSrrnLXWDt3hMT7+/vl8vbY2Qc79dmsL49r96dGXH779zs5XVT27djgmRGPAEKqy7dcLCQFRrs5PF+/f+cV8Pp+//OG7cjw+vnFzfziYHUzUGjQGiabT5uBg/3yxWW+9sAoMrGrUYtmjG+1OCiIwTdWikRi3rAKglmMKDGVtF+vVD8+fDkFrayvNvF5cvM3/+vbl9MatR3/5q6Mbt6rBICx99n23XPsU2rYdjQbK67A6L4f50YPbj27/3OLS508+rTPHkD0zW0tGVYbt+M6DBz/++GR7dYmS
jXOuLtfef/j4Yf38x58+nhwdHI+KUqlWHM0mo8zN7qydta5brS7eP7n16GeH9/dGtWzWebtZJ/aiimAA0RbOMKuozPb2/+a3v/vfx//hP52iK1ecs3P1tPn88ODdydnTJ09uza7Pdo+K4/vtuAy+J97U1BjLy7S2YV7B4Opi8eLVc4FuMhkgWVUEZasqZEhFAPDO/c8t8H/9+7/V43ays9fHmF8vED0YWal6V11uu+7tKyodpF5rA+3Bzrje+fKLQTPw2/7p05evXr+5vj/ZbXcBFDFnYptSts6ignOOLH3+1VedXz979uL6/pEYuDx9ttpuPq2XWDmpyovF+vL8sqiKtjINDZXb8WgHB9V62z998+Y///vbojDHBwcpoCkMWjLKloVJSUARqXAlof35N79mU3ZeRvWwGe58/LCYLzeBKiQhCxx71AR2sF1v37x5d3FxulwuTy8/PX39crVZ37txw5ZlzNkggoICW7IGDSlojLGoHJCpB6OHX3/9/t0ZKBRlu+okq/14clqgPdo/vrbbFq4k4bNPn16+W3b9ettvfQrbsB001fWjfTcoBZnQICHo/z/DWASIIRTBWWfJwmjcHt+wq+XKlPWbj2dX6z4zPH/z9upqNRo2KcXVarlaLzklJCoKqmvz2fHewf5uO65y9q5wZNAYK6IWERFARaL31hmFEhkVxBgtC7rz2a2jw6P502c5+Zjjh7POnCsoi4ghGI+b/b296aSZjMvD67OqqizZSTuqSscqCggqFhBBgXPqQ6+oCKCgrFklp9hNx80//ePfP/zi3qs3r88/Xc4Xq+22F8mlMzcPD+7cvLE7HVelNcRFaYiMMWVhHSAKKzOLilWFzMyQY4qIZMggAquo5hRiCmE6HA2/+Pzhvdvrzerpi5cfzy8RdDKs792+ORrUwkkkIQESAJAi+hBizkk4c2Zmy5JBAVQBIOXUbTsiUlJClawsSmisddYYN6FH9+/eOj4sjK0cla5Q4RBEFJAQAEQ0i/QxdCGElGNOImI5MwIQACpoZi9dURggRNKUE2cREVAEBVAYNoOmrg0RgogyC1K2hAZAWURBVDVy7rz3OTMLM/8fV/LYP1y7WlEAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dataset over `random_images` with a single \"img\" column; show the first image as a spot check.\n",
    "random_dataset = Dataset.from_dict({\"img\": random_images})\n",
    "random_dataset[0][\"img\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "0ef5867c-7855-4ff1-a526-4c3ccec15cb2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAJF0lEQVR4nCXQx5Jkx3UA0Js3b77M58t0V1eb6THAYDQCDUhEkDIRCipC4l47/aIW+gCusGCIXIASgBAw4GBgxva0Kfde1TPptdD5hMP+/d8+7UdLggvOlOAh2CylDx5cqiQBAIhhNAY8KVluN7vB5M5W//yv/3Jor7brN12/Xa3u3r19jQRpJvNCDeMYIDIOIRgzWAJG3lmOXiAQgyLNj47OpBKMcecRGYbgxxEAEBBum916vd3cmYePTjkbtG7KkufZ0XKZO2+Mddq4EEE7K4jJRGlpjLaUCqaKXAg6np48vHwymxx5H7QeQgiMYQzQD37Xji9++ObV1V0MajLNr959oyTPS2I8EYLVs4mxZrftvR9Pl6fGOeTAkVlvIkR6cHHv+GQmVTafnS4X9xKRRYAYwTljrQmBtXu7aV+7CPP5fFIWdakEd1mR5IUKMYQQm/3Q7vv1+iCJpAqIwBE5R0AKEGh5vlCZTGWlkowS5mAIPnLkHrQD4yMYP46mESLOppPpRLb7tj242azinPfjIIhiDMGC4JTm3EfnfEAunHcumGEcyIgeohkbH32SVWScDsErmRrn9n2vR71ve0qMJCjylDjrB2NN7AZtfdBGSxGIsehcKhOZ8IjACZUS22YbwbOIVE9L4uLmpZVg+rZxwQBC8HY/dOtm2/et6c3Y837UcxCMiaqYOu8HrY3TUibGWBvBOi8V5zwiYYwgJAiB1njiSEYHT7BrdPPy+/07ihwcB8uhs0Oz34VovRXapP2elS0eeme0C1Eba6oyT4TwMe6bVggSkQEk/aHvjclsdugGTphnCZnBM4RUSd+91wfrAXXCesV19AEY52rUrOso+LTryfvB6I7QumCtdmNvOI+IUTDcNXuZzL3F9+9bkYxC8GqSkBCYcjnN60qJuqRZpSapgoDBy6bB4OfIzyM7sTZFyqbHZzKrARTDRCmVJEwpVhTy4uKYE2v3XXAMQSQ8BY/IaBz80EcKMWpzcNDQcYJRUMiuf7j68d1dpOpITkxn2r0hUrKodwbag3U2GbVlMEynFGMw1ozj2A+9tc4Y6yyMnROKSymIyFlHPsZh2GIKHtNGk/V5ny2WTz6W+WS73nW2GRxPGNvuu/V2Px5aiSERKIgLHSNzDKy2o3cWMIx60EM0xjKCGMPxyST4kQC4diGC8D7tR4psmk1mq2Z/df2ShWB9CCE43ZveWuMA0LKQJyiybGt2g4tVBugAAiPOVZqUhfKM3axX7YEto9daU9/joeMyz1I1Cz4NkJvN6qfnL8pSzebzpu3BOx+s1QNRorJ0GMYQsdOB0zQhsbg4GttbsLYssMiLSVXLNOv0gTFsd3uBQG2jNq3+3c9+C1H9zxff3a5ere62fbs7Oro87Pf75hBDcENnxo7FGCsDXNiRmsPoAn/65NHPPvl9Suanv37O/S4hG7zrh66qs7IqAUIiiYwtf/PpPy6Pz2IMTz/0KsL+7ubJg4tqWt+uN5fnZ7PJJFdykuftdn19uwlceca4YlzIj//2l4zXdZ03d3/mdrCF3sRGB3Z8spzVhQ+eI2P/+R+fXZxfQuiViKDNYbfZ7tuAXHu/7/qze/cXx4uqLrNUWe0/++xPf/jsj564ULEu8lQo5PzexanZbF5//20i2+k0cYRCilxGIuQIdLE8l5yjZ7wb3KGNu10GMS1LLpWd2PniWE1KmaUkRF1nf/8Pf/enL57d7HYssF0z3AzbNJWLi/vl2UdXf/ma+zFJ1OJyySnm0osEnR/pxdefT5Mi4TH0bWh2pml27a7RA1Py/Px8TlhOc84jchajraf5ow8frb98NvQdAkamtEcbqNMosqrijKO8WF4i2bZ5a7wJ4KnfrBzsqmm1uX6/f/Mq9U4AoBnt0K+0+a83r6uHDx7+6tenl/ezItd7TSHqbnTBT6flpC7Bd8P2
fTZZ/uLjB08fTKd5Eknv2lttu/3YWu9pNp0NfX/2wcP5xfKr/hCbbXQ2yVKSotGmub1bv/qpfv7i6GQ5UUqqklhxNKutcxcnx8ujXB/s5vr72aI6+/iszrkdD2+uXreHjY0OGAFDque1TJOIMD87/ejTX99++3XYbQLCymqjSGD6+Hz+7OXV519+eVzUy+XF8dNPFkdV3x0SpidZ4ZnUK5vYbQpJs9r99fnzze52MstRcogUQyDnLUmK0QPCyf176PWbb74qqjoty6brbl4+s65P0iQUpS+n1/v+5ttn5XQCdtRMYlB1lU2fPM5SqTv9/Lsf//fZd0WuThYLZJGhI+5IG01CRAAp0ySRyZOPHLi7m7vl2b3cu+3tq83Nu7vdvphMi+l8dX27ffe6aNezXBZUmWGfVnNGed+PP7x+9+f//mK0pq5rPXhSHEnwiORCYNGHGBlDmSqR8A8++QS//9E6SFVW1UdXb99u244grUqpu2y1a/pDl3O2Xe9egF3dXPf9sNo23718udm3s7piHAejEyQWWARPIkk48RhhHAeSCackFfzeo0e7VcMYUlr3XvCkePv2PSJfHi8XR9MQIAS4Xu3eXV97b7TRozW9GSjhJ8vjrEwjBsDIOfqIJJUkEkKQ0drokYRAztK8QEbjoHmav9/s9joIVb56e9v3vszkYd82bdt3nfcekSGClLg4quazalrLGA0JJQQJIXwIRJyICJGN40AqUQzAxwgBWIBoT08X5+fnzYuX1ozG2pvb6zWD4LR1NoLPMjmfz8tMljm/OJ+nqUTEo+ksS5X3PgJjzhNyQuTOudV244JfLBbWWecNi1EPB5Xg7/7pt48fP3j5+s3t3Xq9PXS9do7nGC7Pjx+cn07rQknOwKa54EjIhRASGXMhOu9DjITIvffGmaE/IKeqrLxzLlgGXutRD32Rqg/un394eXbour989ezN1S0ymNXpL58+LjLpnXZOcxKcEBiGiF0/aOu0c8455z1FiCFEH1yWZizGtm05YxGCSDgDHgJwSjgSRD+f8J//zcOzxSxVqipkkcngXdcZEolICJH5EJ3zh2HYHoZeW21sjJGci8hAcELJQ4z73TbNFHKMEa21IcTgAzJgACH46aSsyzwRgrMYIVjHhbWcEBF8CCzECGG0bvf/kT567/4PcQTMQleFHicAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dataset over `trak_images` with a single \"img\" column; show the first image as a spot check.\n",
    "trak_dataset = Dataset.from_dict({\"img\": trak_images})\n",
    "trak_dataset[0][\"img\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "7af27bc1-97ed-4fa1-9bb6-a3b0e7b66d9f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI6ElEQVR4nAXBSY+c61UA4HPe6Rvrq+quco9uu+22E4VwbxJEiIQEOwQLVkj5A/y1sGTJCiQiJCKUHSCm2PfeXNvtdtvVXdM3vPM5PA/+zV99ZX0WEiViqaVAbhr97OKiNAYBADhlAlBGlW7yIesUzasf/MDbwzjunJt2h/1ms2bI2khTqJQiARFwzin6KABUSgkhS5QScd4dXV6eF0ZyTgRCIBJRiEFJwSAH5w77wY1puZhJkQSn+azsWnN+Oo8p+Bh8DJO1MSVAVlKEwoeYlEboFo0x5vzk6U+//uPz80tnp+1mbScLDMwQIzmbbj++u1/vgXTXNtOwURJNqVAorURbzFKO4+StC7N2HlIEZESIOQKQurm+vrw6b5ru5MnlyxevjSlzzufnF3YavXNE2Pfh7Te/tz5XZTtvmqbUEkkppaTkDBE4JGedHycHANpIBhaIKFChZED17ObZrG3qoit1JRVYf4ghNnWtNaaMnHHy46HfKolt1zaV8W5yKc9VxwwxJCGRmTmxRImSM2diAhScOXMK3qmk3UR53ATyevlktu+3OSUlT0JMj7vNdtevPz/m7JBZKw0gXEgpJqV9IGLKSkmBmDNJIaRGEkDMUotxGpgzAKrlybFR5u5300TusH6wblBa+Wlc7w7fvX+/3ayTTTEI72NdtZkkCg0CJhdUSlrrlKMAYGBdSCFAKGQEaVAlGQNLKVRKQrCynu3de9rdMwJp+Z36dL/d3j98FpgFFCmqHLEKGGJ2PqeUck7G6AoAgVOKSgmQWoFx1rkcTGms80JAWWkVLYCSdV2Hj5+GTzESWCF6IQ7OEQtldAwYIgpQMWNO0dmQc6CcQogxRClRKmCQYYhlVXCWm8dRKK+0bGeF0loYUTRm1pambXU3q5qynGwYbTz0gbkkqkKUIRARoFQh5xBzTDkzE2cGUgq7WQPAwzBxAoVaiwpZSlTRc/CsgDAlx8LL41qyrEhv1vvb7VaoUld6GqO1QUmlhXAxWutCzEzEnIzBmBNG1t5Zb0MMISRijC7LQhTGaCOBSTHj5EZZSGqaYUooTC7qdllVVROcc5F8JET2MYbt1tmRKSESIlOizOQjh+iYmZlC9DlB8EELBQjL1YI5KAAVYjS6AC0CTEiqrOYPw6MdH4XATEDMOQUbXU6ZiQmyECwkOR8lUllIBiEAhBBNU9ZVS0reP34ZRslwlFJUIajgdVm0ppbZ6egiZXF/dz/v6tWT5WGwOaeQM6cIKJQ2OUQGiEyJsK3ar3729f7hy2H7UGjZ1M352WnRNoMfALE/jEaCcq7wgc9PbyDTuHu7P2x3262dxtPT1eHQb7Z7YKRgKXkBWNcNCEFZBEpjSD/9yS9++cu/7febf/qHv3f9o8iUvLfWLhazelYLpMIoFaJ+8fx1XR1JpOXxvn/c+Kl/9vTy8vzs7bffXl+/urw4L43qmnrYbT68u93u+ykEgWa5mv35n/4Zgzw9eSqhdIcwiD6lELQ5u7ycNQVRkgLV2fn1cnkBmJWi1fGyX8xvrp9dvHxZVJVW6i/+8q9vXr+u6rIsC8Hwm3/5t1/96u9ShLKCH7262X+++69+d3Vz8+z6hx/f3W57f9Y0omhdROonKUAJELO2E0Io1AKEQhQARquyKKWUR7P2fLVq2kYXBQhVNt0f/fxPjp5cEKuibDKJ3/3fm++++TZmPrt+PSaxG2JiNesWbdPN2nnXLYq6VofdF5lC1zUEYffw6MfJ7g//+ut/VkXxo5fXhn0lGSUKJQCobIrnL65v79feJ+toPluMwXufhazaxaIr2qqdr46WhDSNm5gTASly00DUNsYG+9277w/rdaFUKURblpVS//Hb39x+eHfx4mb55Kysq8f7hxQCZwrEylRXV8/H/ecvd+9Pzq6+/skPz45LrTBCHu3k/TSGMREpo2T0tqjK
umtU1RZllXO4OFnNVysfwsfbD99/vP3+/bvZrIMU+975caqqIqdUGX3UFYbN/fs3x/P6+molMQY/Tf1hmkbKGYVCJmVKyZxi8rPjk+cvX68xPn4JoEtPjChMac5XR4+bzZv/+c95MyurBZGp6yoFB+SlSHWtN+swHh4WR+3kxoeHh37sjdGAkjMxg8opCMkxekDx5OwMws6Ffdkem25lx+Hh9s04jlLLcjbrjk+8pe3hwAhaQU6OyZdV8fTyQgsMNq7Xuw+398x8cXmqBRFIiaS8d0JJyqkoSnl0jPkqkLMeVqeXzk73H95+uLtLgNV8YaryYbvZHQ4kZKOBUvDeVkVT1Y0P6dP60/++edv341HX1fVY1Aa1QEaVKSMIAgCAqm5pcXyhXj1ueiF1Xc/qZvH5y/3jvl9mcTRbJM4xBNRaaj3047t3HzdtOYzTerN/+/v368et0frHr+vDMFaYVJaAqHRRKKMBYBz6brHQpm46IlEkLwTKqj1KbAJX//7f7z8/jqujVVUWKGXItNls7tYPRNF6Z30YnZVavn7x/MnpSsmstFBaMbMqq8oURVEU1o7aaARCoau6iZIpU9l2j/108JjN4puP/XbA1aKOftgd+vXDQ0pZCiEl1JV8/vR0tTo6XnRSUVmasiyM0TmTUkorZYQU0zjoQpdFkTPlnFEAeb9czs/OTofbNfgMEiZnN9vAeRomC5iPj9unFxezpmgqvDhfaqNz5sV8XpUmpkjECZMSUiOKGP3Hu7uYwuXTq+BDSF4gOndg8n/4By+fPbv4st4chn6zm/YHy2yahn/21avXL667WS0F5+yMkYBCSK2UZgBKOeZIDEoImVIKwQ7DHoVcLlcxhJSTEOydG/p9XRaLbvbjV8+FgH/89W+nYWpns7PV+S9+/nVVaO9siE5JhUIwIjH3o7XeuxhCDCllxUBEOVNaLOaZaLPZKCUzpbI0gCLE2NaVKQwgFEY9v1oy5KvLy5PVrKmKlFNIkQCkkIDIBJHyYRx3wzSF6EMkIkWZENFoo5WKKT98uT86XqAU3pNzngkGayuCpNDa8ei4m3fzRdcpBcQZEAFBSQ3IREzEOZMNYT9NNuaUKOX0/+AK6q2lo87IAAAAAElFTkSuQmCC",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dataset over `ours_images` with a single \"img\" column; show the first image as a spot check.\n",
    "ours_dataset = Dataset.from_dict({\"img\": ours_images})\n",
    "ours_dataset[0][\"img\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cbbba78-5c51-45ed-a044-4001a18949bf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "3ac1bf0e-897b-4ca7-a287-6f470dc956a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args:\n",
    "    \"\"\"Preprocessing options read through `args` by the transform cell.\"\"\"\n",
    "\n",
    "    resolution = 32  # size passed to Resize and to CenterCrop/RandomCrop\n",
    "    center_crop = True  # True: CenterCrop, False: RandomCrop\n",
    "    random_flip = False  # False: identity Lambda instead of RandomHorizontalFlip\n",
    "\n",
    "\n",
    "args = Args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "fb0345d7-02df-4b53-a8bf-1f95d148ee12",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "\n",
    "# Resize + crop to `args.resolution`, optional horizontal flip, then ToTensor + Normalize(0.5, 0.5),\n",
    "# which maps pixel values from [0, 1] to [-1, 1].\n",
    "augmentations = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n",
    "        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize([0.5], [0.5]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "\n",
    "def transform_images(examples):\n",
    "    \"\"\"`set_transform` hook: convert each PIL image in the batch's \"img\" column to a normalized tensor under \"input\".\"\"\"\n",
    "    images = [augmentations(image.convert(\"RGB\")) for image in examples[\"img\"]]\n",
    "    return {\"input\": images}\n",
    "\n",
    "\n",
    "# Apply the same on-the-fly transform to every dataset being compared.\n",
    "for ds in (gen_dataset, random_dataset, trak_dataset, ours_dataset):\n",
    "    ds.set_transform(transform_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "b4631d7f-ad40-431e-abb5-ebf32b38e708",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_SAMPLES = 60  # leading rows taken from each dataset; rows are compared index-by-index later\n",
    "\n",
    "\n",
    "def stack_features(ds, n=N_SAMPLES):\n",
    "    \"\"\"Return the first `n` transformed images of `ds` flattened into an (n, C*H*W) tensor.\n",
    "\n",
    "    Flattening from the actual image shape (instead of a hard-coded 3*32*32) keeps exactly\n",
    "    one row per image even if `args.resolution` changes.\n",
    "    \"\"\"\n",
    "    return torch.stack([ds[i][\"input\"] for i in range(n)]).flatten(start_dim=1)\n",
    "\n",
    "\n",
    "gen_features = stack_features(gen_dataset)\n",
    "random_features = stack_features(random_dataset)\n",
    "trak_features = stack_features(trak_dataset)\n",
    "ours_features = stack_features(ours_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "dd5ab3bf-1943-40dd-b657-6183d2d1315e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NumPy views of the (CPU) feature tensors, used for the distance computations below.\n",
    "gen_features_array, random_features_array, trak_features_array, ours_features_array = (\n",
    "    tensor.numpy() for tensor in (gen_features, random_features, trak_features, ours_features)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80ea5261-7b8c-4477-ab11-1900691f4fd1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "6d291f63-d573-4d27-86ca-a8583455b7ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scoring: each method's images are compared row by row with the generated images (next cell)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "0b046bc3-28b7-47be-94e1-d4b6693daa18",
   "metadata": {},
   "outputs": [],
   "source": [
    "def l2_distances(features, reference):\n",
    "    \"\"\"Row-wise L2 distance: entry i is ||features[i] - reference[i]||_2.\n",
    "\n",
    "    Shapes must match exactly; otherwise NumPy broadcasting would silently compare\n",
    "    mismatched rows (e.g. a single row against every reference row).\n",
    "    \"\"\"\n",
    "    if features.shape != reference.shape:\n",
    "        raise ValueError(f\"shape mismatch: {features.shape} vs {reference.shape}\")\n",
    "    return np.linalg.norm(features - reference, axis=1)\n",
    "\n",
    "\n",
    "gen_scores = l2_distances(gen_features_array, gen_features_array)  # sanity baseline: zero by construction\n",
    "random_scores = l2_distances(random_features_array, gen_features_array)\n",
    "trak_scores = l2_distances(trak_features_array, gen_features_array)\n",
    "ours_scores = l2_distances(ours_features_array, gen_features_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "6148e65a-eedd-4e25-b827-5eededa0315f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.514583"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the per-sample L2 distances in `random_scores`.\n",
    "np.mean(random_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "085325fe-b4a7-4c3b-9023-30d9c9b78968",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4.712385"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Median of the per-sample L2 distances in `random_scores`.\n",
    "np.median(random_scores, axis=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "84e96291-00f1-42e9-bb8c-89caca63b351",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6.4391108"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the per-sample L2 distances in `trak_scores`.\n",
    "np.mean(trak_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "97805b62-f894-47ae-bdb1-a7f74bf091fa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.901608"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Median of the per-sample L2 distances in `trak_scores`.\n",
    "np.median(trak_scores, axis=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "1c6d7429-34d7-4d28-a40b-771f47d3f245",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9.462168"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the per-sample L2 distances in `ours_scores`.\n",
    "np.mean(ours_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "de78d5dc-d975-49c4-b4e3-81ef2aee933b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8.965666"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Median of the per-sample L2 distances in `ours_scores`.\n",
    "np.median(ours_scores, axis=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "10e850ed-2b9a-4270-ab6b-4766ff667bb2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Random</th>\n",
       "      <th>TRAK</th>\n",
       "      <th>D-TRAK</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2.788872</td>\n",
       "      <td>4.483785</td>\n",
       "      <td>7.861606</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>6.189478</td>\n",
       "      <td>6.249662</td>\n",
       "      <td>19.911118</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3.807986</td>\n",
       "      <td>5.988188</td>\n",
       "      <td>7.086569</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>6.362907</td>\n",
       "      <td>5.454236</td>\n",
       "      <td>7.952570</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>6.149230</td>\n",
       "      <td>7.665555</td>\n",
       "      <td>6.885323</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     Random      TRAK     D-TRAK\n",
       "0  2.788872  4.483785   7.861606\n",
       "1  6.189478  6.249662  19.911118\n",
       "2  3.807986  5.988188   7.086569\n",
       "3  6.362907  5.454236   7.952570\n",
       "4  6.149230  7.665555   6.885323"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One column per method; row i holds each method's L2 distance to generated image i.\n",
    "df = pd.DataFrame(\n",
    "    {\n",
    "        \"Random\": random_scores,\n",
    "        \"TRAK\": trak_scores,\n",
    "        \"D-TRAK\": ours_scores,\n",
    "    }\n",
    ")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4cdc057-f2be-4b1d-9659-d419b303ff05",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a70b172-b850-486d-9295-f65db8c72ae6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e768c0-cee6-404c-9fe0-c8e75072101c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
