{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from TinyImageNet import TinyImageNet\n",
    "import numpy as np\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100000/100000 [00:18<00:00, 5528.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.4800912  0.44808727 0.39818029] [0.2285629  0.22575072 0.22453852]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "dataset = TinyImageNet(root='../data/')\n",
    "\n",
    "# Per-channel mean and std over every pixel in the dataset.\n",
    "# The std comes from the pooled pixel distribution, Var[x] = E[x^2] - E[x]^2;\n",
    "# averaging per-image stds would ignore the variance *between* images and\n",
    "# underestimate the true per-channel std.\n",
    "channel_sum = 0.\n",
    "channel_sq_sum = 0.\n",
    "n_pixels = 0\n",
    "\n",
    "for i in tqdm(range(len(dataset))):\n",
    "    # assumes dataset[i][0] converts to an (H, W) or (H, W, C) array\n",
    "    img = np.asarray(dataset[i][0], dtype=np.float64) / 255.\n",
    "    # Sum over height and width; a 2-D (single-channel) image reduces to a\n",
    "    # scalar that broadcasts onto every channel of the running sums.\n",
    "    channel_sum += img.sum(axis=(0, 1))\n",
    "    channel_sq_sum += np.square(img).sum(axis=(0, 1))\n",
    "    n_pixels += img.shape[0] * img.shape[1]\n",
    "\n",
    "mean = channel_sum / n_pixels\n",
    "# Clamp tiny negative variances caused by floating-point cancellation.\n",
    "std = np.sqrt(np.maximum(channel_sq_sum / n_pixels - mean ** 2, 0.))\n",
    "\n",
    "print(mean, std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: opencv-python in /home/ejahns/miniconda3/envs/Norse/lib/python3.12/site-packages (4.10.0.84)\n",
      "Requirement already satisfied: numpy>=1.21.2 in /home/ejahns/miniconda3/envs/Norse/lib/python3.12/site-packages (from opencv-python) (1.26.4)\n",
      "Note: you may need to restart the kernel to use updated packages.\n",
      "Downloading the dataset, this may take a while...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "tiny-imagenet-200.zip: 237MB [00:04, 54.9MB/s]                              \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting the dataset, this may take a while...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Extracting: 100%|██████████| 120609/120609 [00:02<00:00, 41539.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating the dataset directory...\n",
      "Moving train images...\n",
      "Splitting original dataset images...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Splitting images: 100%|██████████| 200/200 [00:00<00:00, 1231.17class/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Copying processed dataset to tiny-64...\n",
      "Resizing images...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Resizing images: 100%|██████████| 110000/110000 [00:05<00:00, 20610.41file/s]\n"
     ]
    }
   ],
   "source": [
    "%pip install opencv-python\n",
    "import random\n",
    "import shutil\n",
    "from hashlib import md5\n",
    "from pathlib import Path\n",
    "from urllib.request import urlretrieve\n",
    "from zipfile import ZipFile\n",
    "import cv2\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Source archive, where it is cached locally, and its expected MD5 digest.\n",
    "DATASET_URL = \"http://cs231n.stanford.edu/tiny-imagenet-200.zip\"\n",
    "# Alternative download location (currently unused):\n",
    "# DATASET_URL = \"https://github.com/tjmoon0104/pytorch-tiny-imagenet/releases/download/tiny-imagenet-dataset/tiny-imagenet-200.zip\"\n",
    "DATASET_ZIP = Path(\"../data/tiny-imagenet-200.zip\")\n",
    "DATASET_MD5_HASH = \"90528d7ca1a48142e341f4ef8d21d0de\"\n",
    "\n",
    "# Fixed seed so that the random val/test split below is reproducible.\n",
    "random.seed(42)\n",
    "\n",
    "# Download the archive if it is not cached locally yet.\n",
    "if not DATASET_ZIP.exists():\n",
    "    print(\"Downloading the dataset, this may take a while...\")\n",
    "    DATASET_ZIP.parent.mkdir(parents=True, exist_ok=True)\n",
    "    # Download under a temporary name and rename only once complete, so an\n",
    "    # interrupted download never leaves a truncated archive at DATASET_ZIP\n",
    "    # (which the exists() check above would then skip on every later run).\n",
    "    partial_zip = DATASET_ZIP.with_name(DATASET_ZIP.name + \".part\")\n",
    "\n",
    "    with tqdm(unit=\"B\", unit_scale=True, unit_divisor=1024, miniters=1, desc=DATASET_URL.split(\"/\")[-1]) as t:\n",
    "\n",
    "        def show_progress(block_num, block_size, total_size):\n",
    "            \"\"\"urlretrieve reporthook: move the bar to the number of bytes fetched so far.\"\"\"\n",
    "            # urlretrieve passes total_size=-1 when the size is not reported.\n",
    "            if total_size > 0:\n",
    "                t.total = total_size\n",
    "            t.update(block_num * block_size - t.n)\n",
    "\n",
    "        urlretrieve(url=DATASET_URL, filename=partial_zip, reporthook=show_progress)\n",
    "    partial_zip.replace(DATASET_ZIP)\n",
    "\n",
    "# Verify the archive against the known MD5 digest, hashing it in 1 MiB chunks\n",
    "# instead of reading the whole file into memory at once.\n",
    "zip_hash = md5()\n",
    "with DATASET_ZIP.open(\"rb\") as f:\n",
    "    for chunk in iter(lambda: f.read(1 << 20), b\"\"):\n",
    "        zip_hash.update(chunk)\n",
    "if zip_hash.hexdigest() != DATASET_MD5_HASH:\n",
    "    # Delete the bad archive: the download step only runs when DATASET_ZIP is\n",
    "    # missing, so keeping it would make every re-run fail here again.\n",
    "    DATASET_ZIP.unlink()\n",
    "    raise RuntimeError(\"The dataset zip file seems corrupted and was deleted. Re-run this cell to download it again.\")\n",
    "\n",
    "\n",
    "# Always extract into a fresh directory, discarding leftovers from earlier runs.\n",
    "ORIGINAL_DATASET_DIR = Path(\"../data/original\")\n",
    "if ORIGINAL_DATASET_DIR.exists():\n",
    "    shutil.rmtree(ORIGINAL_DATASET_DIR)\n",
    "\n",
    "# Extract member by member so tqdm can show per-file progress.\n",
    "print(\"Extracting the dataset, this may take a while...\")\n",
    "with ZipFile(DATASET_ZIP, \"r\") as archive:\n",
    "    for member in tqdm(archive.infolist(), desc=\"Extracting\"):\n",
    "        archive.extract(member, ORIGINAL_DATASET_DIR)\n",
    "\n",
    "# Rebuild the processed dataset directory from scratch on every run.\n",
    "DATASET_DIR = Path(\"../data/tiny-imagenet-200\")\n",
    "if DATASET_DIR.exists():\n",
    "    shutil.rmtree(DATASET_DIR)\n",
    "print(\"Creating the dataset directory...\")\n",
    "DATASET_DIR.mkdir()\n",
    "\n",
    "# The original training split is kept as-is: move it into the processed dataset.\n",
    "ORIGINAL_TRAIN_DIR = ORIGINAL_DATASET_DIR.joinpath(\"tiny-imagenet-200\", \"train\")\n",
    "if ORIGINAL_TRAIN_DIR.exists():\n",
    "    print(\"Moving train images...\")\n",
    "    # Path.replace renames the whole directory rather than copying its files.\n",
    "    ORIGINAL_TRAIN_DIR.replace(DATASET_DIR.joinpath(\"train\"))\n",
    "\n",
    "# Group the original validation images by class label. In val_annotations.txt\n",
    "# the first two tab-separated fields of each line are the image filename and\n",
    "# its class label.\n",
    "val_dict = {}\n",
    "ORIGINAL_VAL_DIR = ORIGINAL_DATASET_DIR / \"tiny-imagenet-200\" / \"val\"\n",
    "with (ORIGINAL_VAL_DIR / \"val_annotations.txt\").open(\"r\") as f:\n",
    "    for line in f:\n",
    "        fields = line.split(\"\\t\")\n",
    "        val_dict.setdefault(fields[1], []).append(fields[0])\n",
    "\n",
    "\n",
    "def split_list_randomly(input_list: list[str], split_ratio=0.5) -> dict[str, list[str]]:\n",
    "    \"\"\"Randomly partition a list into a 'val' part and a 'test' part.\n",
    "\n",
    "    Args:\n",
    "        input_list: Items to split. Left unmodified; a shuffled copy is split.\n",
    "        split_ratio: Fraction of items (rounded down) assigned to 'val'; must be\n",
    "            in [0, 1]. The remaining items go to 'test'.\n",
    "\n",
    "    Returns:\n",
    "        {'val': first int(len * split_ratio) shuffled items, 'test': the rest}.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If split_ratio is outside [0, 1].\n",
    "    \"\"\"\n",
    "    if not 0.0 <= split_ratio <= 1.0:\n",
    "        raise ValueError(f\"split_ratio must be in [0, 1], got {split_ratio}\")\n",
    "\n",
    "    # Shuffle a copy so the caller's list is not reordered as a side effect.\n",
    "    # The module-level random state is consumed exactly as before, so splits\n",
    "    # produced under a fixed random.seed(...) are unchanged.\n",
    "    shuffled = list(input_list)\n",
    "    random.shuffle(shuffled)\n",
    "\n",
    "    split_index = int(len(shuffled) * split_ratio)\n",
    "    return {\"val\": shuffled[:split_index], \"test\": shuffled[split_index:]}\n",
    "\n",
    "\n",
    "# Randomly split each class's original validation images 50/50 into the new\n",
    "# val and test splits, laid out as DATASET_DIR/<split>/<class>/images/<file>.\n",
    "print(\"Splitting original dataset images...\")\n",
    "# Iterating the tqdm wrapper advances the bar once per class by itself, so no\n",
    "# manual update() call is needed.\n",
    "for image_label, images in tqdm(val_dict.items(), desc=\"Splitting images\", unit=\"class\"):\n",
    "    for split_type, split_images in split_list_randomly(images, split_ratio=0.5).items():\n",
    "        if not split_images:\n",
    "            continue  # never create an empty class folder\n",
    "        dest_folder = DATASET_DIR / split_type / image_label / \"images\"\n",
    "        dest_folder.mkdir(parents=True, exist_ok=True)  # once per split, not per image\n",
    "        for image in split_images:\n",
    "            src = ORIGINAL_VAL_DIR / \"images\" / image\n",
    "            src.replace(dest_folder / image)\n",
    "\n",
    "# Remove the extracted archive; the splits kept by this notebook now live in DATASET_DIR.\n",
    "shutil.rmtree(ORIGINAL_DATASET_DIR)\n",
    "\n",
    "# Start the resized copy from scratch on every run.\n",
    "RESIZED_DIR = Path(\"../data/tiny-64\")\n",
    "if RESIZED_DIR.exists():\n",
    "    shutil.rmtree(RESIZED_DIR)\n",
    "\n",
    "# Copy the processed dataset to tiny-64; the images in this copy are resized\n",
    "# in place afterwards, leaving DATASET_DIR untouched.\n",
    "print(\"Copying processed dataset to tiny-64...\")\n",
    "shutil.copytree(DATASET_DIR, RESIZED_DIR)\n",
    "\n",
    "\n",
    "def resize_img(image_path: Path, size: int = 224) -> None:\n",
    "    \"\"\"Resize the image at image_path to size x size pixels, overwriting the file.\n",
    "\n",
    "    Uses bicubic interpolation. cv2.imread is called with its default flag, which\n",
    "    loads images as 3-channel BGR, so the rewritten file always has three channels.\n",
    "\n",
    "    Raises:\n",
    "        OSError: If the image cannot be read or the resized image cannot be written.\n",
    "    \"\"\"\n",
    "    img = cv2.imread(image_path.as_posix())\n",
    "    if img is None:\n",
    "        # cv2.imread signals failure by returning None instead of raising.\n",
    "        raise OSError(f\"Could not read image: {image_path}\")\n",
    "    img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)\n",
    "    # cv2.imwrite reports failure through its boolean return value.\n",
    "    if not cv2.imwrite(image_path.as_posix(), img):\n",
    "        raise OSError(f\"Could not write image: {image_path}\")\n",
    "\n",
    "\n",
    "# Resize every .JPEG image in the tiny-64 copy (RESIZED_DIR) to 64x64 in place.\n",
    "TARGET_SIZE = 64\n",
    "all_images = list(RESIZED_DIR.rglob(\"*.JPEG\"))\n",
    "print(\"Resizing images...\")\n",
    "for image in tqdm(all_images, desc=\"Resizing images\", unit=\"file\"):\n",
    "    resize_img(image, TARGET_SIZE)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Norse",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
