{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import hydra\n",
    "from hydra import initialize, compose\n",
    "from typing import Dict, List\n",
    "from hydra.utils import instantiate\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "# Drop any global Hydra instance left over from a previous run of this cell,\n",
    "# so that the cell can be re-executed without restarting the kernel.\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "\n",
    "# Point Hydra at the dataset config directory.\n",
    "initialize(\n",
    "    version_base=None,\n",
    "    config_path=\"../../conf/dataset/\",\n",
    "    job_name=\"matching_n_models\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compose the `tiny_imagenet` config with no overrides.\n",
    "cfg = compose(\n",
    "    config_name=\"tiny_imagenet\",\n",
    "    overrides=[],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset reference passed to `load_dataset`.\n",
    "ref = \"Maysee/tiny-imagenet\"\n",
    "\n",
    "# Split names requested from the dataset; the test data comes from its \"valid\" split.\n",
    "train_split, test_split = \"train\", \"valid\"\n",
    "\n",
    "# Keys of the label and image fields in each example.\n",
    "label_key, image_key = \"label\", \"image\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training split.\n",
    "# The former `use_auth_token=True` argument is deprecated in favour of `token`\n",
    "# and is rejected by recent `datasets` releases, so it is omitted here.\n",
    "# NOTE(review): assumes the dataset is public; pass `token=True` if it becomes gated.\n",
    "train_hf_dataset = load_dataset(\n",
    "    ref,\n",
    "    split=train_split,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the split used as test data.\n",
    "# The former `use_auth_token=True` argument is deprecated in favour of `token`\n",
    "# and is rejected by recent `datasets` releases, so it is omitted here.\n",
    "# NOTE(review): assumes the dataset is public; pass `token=True` if it becomes gated.\n",
    "test_hf_dataset = load_dataset(\n",
    "    ref,\n",
    "    split=test_split,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from PIL import Image\n",
    "from datasets import load_dataset\n",
    "\n",
    "dataset_name = \"tiny_imagenet\"\n",
    "splits = [\"train\", \"test\"]\n",
    "\n",
    "base_dir = \"../../data/\" + dataset_name\n",
    "\n",
    "datasets = {\"train\": train_hf_dataset, \"test\": test_hf_dataset}\n",
    "\n",
    "# Export every split to disk as {base_dir}/{split}/{class_name}/{index}.jpg,\n",
    "# where index is the example's position within its split.\n",
    "for split in splits:\n",
    "    split_dataset = datasets[split]\n",
    "    split_dir = os.path.join(base_dir, split)\n",
    "\n",
    "    # Class names come from the label feature; a label id indexes this list.\n",
    "    # `label_key` / `image_key` are the notebook-level field names, used here\n",
    "    # instead of repeating the string literals.\n",
    "    classes = split_dataset.features[label_key].names\n",
    "\n",
    "    # Create every class folder up front, whether or not it receives images.\n",
    "    for class_name in classes:\n",
    "        os.makedirs(os.path.join(split_dir, class_name), exist_ok=True)\n",
    "\n",
    "    for i, example in enumerate(split_dataset):\n",
    "        class_name = classes[example[label_key]]\n",
    "        image_path = os.path.join(split_dir, class_name, f\"{i}.jpg\")\n",
    "        example[image_key].save(image_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
