{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a45c719",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os, json, re\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import PIL\n",
    "from PIL import Image\n",
    "import scipy.io as sio\n",
    "from matplotlib import cm\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "# import mne"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0002494a",
   "metadata": {},
   "source": [
    "# COIL-20"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6279b913",
   "metadata": {},
   "source": [
    "COIL-20 is a set of 1440 greyscale images of 20 objects, each photographed under 72 different rotations spanning 360 degrees. Each image is a 128x128\n",
    "image, which we treat as a single 16384-dimensional vector for the purposes\n",
    "of computing distances between images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe1ca1e1-608a-4c97-9a7e-674c8884e55c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9470a79-18d6-4747-a2e2-628fa9bc6901",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p data/coil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d22427-5d22-415e-86ba-cd13ef0d6ef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "\n",
    "# Remote COIL-20 archive and the local folder / file it is stored in.\n",
    "dataset_url = \"http://www.cs.columbia.edu/CAVE/databases/SLAM_coil-20_coil-100/coil-20/coil-20-proc.zip\"\n",
    "download_dir = \"data\"\n",
    "zip_file_path = os.path.join(download_dir, \"coil-20.zip\")\n",
    "\n",
    "# Make sure the target folder is present, then fetch the archive into it.\n",
    "os.makedirs(download_dir, exist_ok=True)\n",
    "urllib.request.urlretrieve(dataset_url, zip_file_path)\n",
    "\n",
    "print(\"Dataset downloaded successfully!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eff5cfc8-0a84-425e-a2e2-674a341c8c6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import zipfile\n",
    "import os\n",
    "\n",
    "# Path to the downloaded ZIP file\n",
    "zip_file_path = \"data/coil-20.zip\"\n",
    "\n",
    "# Directory that receives the archive contents\n",
    "extracted_dir = \"data/COIL-20\"\n",
    "\n",
    "# Create the directory from Python only: unlike a shell `mkdir`, this is\n",
    "# portable and does not complain when the cell is re-run.\n",
    "os.makedirs(extracted_dir, exist_ok=True)\n",
    "\n",
    "# Extract every archive member into the target directory\n",
    "with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:\n",
    "    zip_ref.extractall(extracted_dir)\n",
    "\n",
    "print(\"ZIP file extracted successfully!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8a9c44b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dirname = \"data/COIL-20/coil-20-proc/\"\n",
    "\n",
    "# Image files are named `obj<objId>__<imgId>.png`; any other entry is skipped.\n",
    "name_pattern = re.compile(r\"obj([0-9]+)__([0-9]+)\\.png\")\n",
    "\n",
    "# os.listdir returns entries in arbitrary order, so sort numerically by\n",
    "# (object id, image id) to make the saved arrays reproducible across runs.\n",
    "parsed = []\n",
    "for file in os.listdir(dirname):\n",
    "    m = name_pattern.fullmatch(file)\n",
    "    if m is not None:\n",
    "        parsed.append((int(m.group(1)), int(m.group(2)), file))\n",
    "parsed.sort()\n",
    "filenames = [file for _, _, file in parsed]\n",
    "\n",
    "labels = []\n",
    "data = []\n",
    "for objId, imgId, file in tqdm(parsed):\n",
    "    # `img` deliberately stays bound after the loop: a later cell displays it.\n",
    "    img = Image.open(dirname + file)\n",
    "    data.append(np.array(img))\n",
    "    labels.append(objId)\n",
    "data = np.asarray(data)\n",
    "labels = np.asarray(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9068351d",
   "metadata": {},
   "outputs": [],
   "source": [
    "objId = 1\n",
    "n_rows, n_cols = 9, 8\n",
    "\n",
    "# figsize is (width, height): width scales with columns, height with rows.\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows))\n",
    "\n",
    "# Select this object's images once instead of re-masking on every iteration.\n",
    "obj_images = data[labels == objId]\n",
    "\n",
    "for i, ax in enumerate(axes.flatten()):\n",
    "    ax.get_xaxis().set_visible(False)\n",
    "    ax.get_yaxis().set_visible(False)\n",
    "    # Leave the panel empty if the object has fewer images than grid slots.\n",
    "    if i < len(obj_images):\n",
    "        # The images are greyscale, so render them with a grey colormap.\n",
    "        ax.imshow(obj_images[i], cmap=\"gray\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67518aaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# demonstration\n",
    "img.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "562d8553-422c-46c1-8bd1-0b753640951c",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d45f9168",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output folder from Python (portable, and safe to re-run)\n",
    "# before writing the image array and its object labels.\n",
    "os.makedirs('data/COIL-20/prepared', exist_ok=True)\n",
    "np.save('data/COIL-20/prepared/data.npy', data)\n",
    "np.save('data/COIL-20/prepared/labels.npy', labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78369b25",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbfc7273",
   "metadata": {},
   "outputs": [],
   "source": [
    "ids = (labels == 1) | (labels == 2) | (labels == 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cf1d26b",
   "metadata": {},
   "outputs": [],
   "source": [
    "clabels = labels[ids]\n",
    "cdata = data[ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca698c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "cdata.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76c9f982-cd5d-49c4-b78e-b457605ae1dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d8274c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('data/COIL-20/prepared/data_3obj.npy', cdata)\n",
    "np.save('data/COIL-20/prepared/labels_3obj.npy', clabels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da311f69",
   "metadata": {},
   "source": [
    "# COIL-100"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "916b5464",
   "metadata": {},
   "source": [
    "COIL-100 is a set of 7200 colour images of 100 objects, each photographed under 72 different rotations spanning 360 degrees. Each image consists of three 128x128 intensity matrices (one per colour channel). We treat this as a single 49152-dimensional vector for the purposes of computing distances between images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17dab81b-8154-4d6c-a464-ee6141daf6c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "import zipfile\n",
    "\n",
    "# URL of the Coil-100 dataset\n",
    "dataset_url = \"http://www.cs.columbia.edu/CAVE/databases/SLAM_coil-20_coil-100/coil-100/coil-100.zip\"\n",
    "\n",
    "# Directory that receives both the archive and its extracted contents.\n",
    "# Created from Python only, so the cell is portable and re-runnable\n",
    "# (a shell `mkdir` errors when the directory already exists).\n",
    "download_dir = \"data/COIL-100\"\n",
    "os.makedirs(download_dir, exist_ok=True)\n",
    "\n",
    "# Path to save the downloaded ZIP file\n",
    "zip_file_path = os.path.join(download_dir, \"coil-100.zip\")\n",
    "\n",
    "# Download the ZIP file\n",
    "urllib.request.urlretrieve(dataset_url, zip_file_path)\n",
    "\n",
    "print(\"Dataset downloaded successfully!\")\n",
    "\n",
    "# Unzip the downloaded ZIP file next to it\n",
    "with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:\n",
    "    zip_ref.extractall(download_dir)\n",
    "\n",
    "print(\"Dataset unzipped successfully!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29cd5074-0cd3-46bf-a0fc-37cd0999206b",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls data/COIL-100/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c11d5172-0ba1-4611-9e82-b84ccdac8fff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Rename the extracted folder to `images`, but only when the target does not\n",
    "# exist yet: a plain `mv` on a re-run would nest `coil-100` inside `images`.\n",
    "src_dir = \"data/COIL-100/coil-100\"\n",
    "dst_dir = \"data/COIL-100/images\"\n",
    "if os.path.isdir(src_dir) and not os.path.exists(dst_dir):\n",
    "    os.rename(src_dir, dst_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2523cabd",
   "metadata": {},
   "outputs": [],
   "source": [
    "dirname = \"data/COIL-100/images/\"\n",
    "filenames = os.listdir(dirname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f35886f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accept only names of the exact form `obj<objId>__<imgId>.png`; the groups\n",
    "# capture the two integer ids. Other files in the folder are skipped.\n",
    "name_pattern = re.compile(r\"obj([0-9]+)__([0-9]+)\\.png\")\n",
    "\n",
    "parsed = []\n",
    "for file in filenames:\n",
    "    m = name_pattern.fullmatch(file)\n",
    "    if m is not None:\n",
    "        parsed.append((int(m.group(1)), int(m.group(2)), file))\n",
    "# Numeric (object id, image id) order keeps the saved arrays reproducible,\n",
    "# independent of the arbitrary order returned by os.listdir.\n",
    "parsed.sort()\n",
    "\n",
    "labels = []\n",
    "data = []\n",
    "for objId, imgId, file in tqdm(parsed):\n",
    "    # `img` deliberately stays bound after the loop: the next cell displays it.\n",
    "    img = Image.open(dirname + file)\n",
    "    data.append(np.array(img))\n",
    "    labels.append([objId, imgId])\n",
    "data = np.asarray(data)\n",
    "labels = np.asarray(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9a0165e",
   "metadata": {},
   "outputs": [],
   "source": [
    "img.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0a0710a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output folder from Python (portable, and safe to re-run)\n",
    "# before writing the image array and its [objId, imgId] labels.\n",
    "os.makedirs('data/COIL-100/prepared', exist_ok=True)\n",
    "np.save('data/COIL-100/prepared/data.npy', data)\n",
    "np.save('data/COIL-100/prepared/labels.npy', labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf5cc8eb",
   "metadata": {},
   "source": [
    "# MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8059aab-69e9-4467-9414-97a2b8d9aa9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "import gzip\n",
    "import shutil\n",
    "\n",
    "# Remote locations of the MNIST training images and labels\n",
    "images_url = \"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\"\n",
    "labels_url = \"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"\n",
    "\n",
    "# Local folder that receives the files\n",
    "download_dir = \"data/MNIST\"\n",
    "os.makedirs(download_dir, exist_ok=True)\n",
    "\n",
    "\n",
    "def download_and_extract(url, file_path):\n",
    "    \"\"\"Fetch `url` to `file_path + '.gz'` and write the decompressed bytes to `file_path`.\n",
    "\n",
    "    The compressed download is kept on disk next to the extracted file.\n",
    "    \"\"\"\n",
    "    archive_path = file_path + '.gz'\n",
    "    urllib.request.urlretrieve(url, archive_path)\n",
    "    with gzip.open(archive_path, 'rb') as f_in, open(file_path, 'wb') as f_out:\n",
    "        shutil.copyfileobj(f_in, f_out)\n",
    "\n",
    "\n",
    "# Destination paths (without the .gz suffix) for images and labels\n",
    "images_file_path = os.path.join(download_dir, \"train-images-idx3-ubyte\")\n",
    "labels_file_path = os.path.join(download_dir, \"train-labels-idx1-ubyte\")\n",
    "\n",
    "download_and_extract(images_url, images_file_path)\n",
    "download_and_extract(labels_url, labels_file_path)\n",
    "\n",
    "print(\"MNIST dataset downloaded and extracted successfully!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1419495e-b9f5-4c09-87a0-02a6eaa9eebc",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls data/MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "031dcc55",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip\n",
    "\n",
    "image_size = 28\n",
    "num_images = 60000\n",
    "\n",
    "# Images: skip the first 16 bytes, then read one unsigned byte per pixel.\n",
    "# The `with` block closes the gzip handle once the bytes are in memory.\n",
    "with gzip.open('data/MNIST/train-images-idx3-ubyte.gz', 'rb') as f:\n",
    "    f.read(16)\n",
    "    buf = f.read(image_size * image_size * num_images)\n",
    "data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)\n",
    "data = data.reshape(num_images, image_size, image_size)\n",
    "\n",
    "# Labels: skip the first 8 bytes, then decode all label bytes in one call\n",
    "# instead of reading them one at a time in a Python loop.\n",
    "with gzip.open('data/MNIST/train-labels-idx1-ubyte.gz', 'rb') as f:\n",
    "    f.read(8)\n",
    "    buf = f.read(num_images)\n",
    "labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)\n",
    "assert labels.shape == (num_images,), \"label file holds fewer labels than expected\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d25c857",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output folder from Python (portable, and safe to re-run)\n",
    "# before writing the training images and labels.\n",
    "os.makedirs('data/MNIST/prepared', exist_ok=True)\n",
    "np.save('data/MNIST/prepared/train_data.npy', data)\n",
    "np.save('data/MNIST/prepared/train_labels.npy', labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09c3ae59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# demonstration\n",
    "Image.fromarray(data[0].astype(np.uint8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "166c2c62-c102-44b4-bcc9-ad4ef4878308",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "import gzip\n",
    "import shutil\n",
    "\n",
    "# Remote locations of the MNIST test images and labels\n",
    "test_images_url = \"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\"\n",
    "test_labels_url = \"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"\n",
    "\n",
    "# Local folder that receives the files\n",
    "download_dir = \"data/MNIST\"\n",
    "os.makedirs(download_dir, exist_ok=True)\n",
    "\n",
    "\n",
    "def download_and_extract(url, file_path):\n",
    "    \"\"\"Fetch `url` to `file_path + '.gz'` and write the decompressed bytes to `file_path`.\n",
    "\n",
    "    The compressed download is kept on disk next to the extracted file.\n",
    "    \"\"\"\n",
    "    archive_path = file_path + '.gz'\n",
    "    urllib.request.urlretrieve(url, archive_path)\n",
    "    with gzip.open(archive_path, 'rb') as f_in, open(file_path, 'wb') as f_out:\n",
    "        shutil.copyfileobj(f_in, f_out)\n",
    "\n",
    "\n",
    "# Destination paths (without the .gz suffix) for images and labels\n",
    "images_file_path = os.path.join(download_dir, \"test-images-idx3-ubyte\")\n",
    "labels_file_path = os.path.join(download_dir, \"test-labels-idx1-ubyte\")\n",
    "\n",
    "download_and_extract(test_images_url, images_file_path)\n",
    "download_and_extract(test_labels_url, labels_file_path)\n",
    "\n",
    "print(\"MNIST dataset downloaded and extracted successfully!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d69d231-d7b1-4c7d-bbab-76ea7fb46fcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls data/MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58e71a0c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import gzip\n",
    "\n",
    "image_size = 28\n",
    "num_images = 10000\n",
    "\n",
    "# Images: skip the first 16 bytes, then read one unsigned byte per pixel.\n",
    "# The `with` block closes the gzip handle once the bytes are in memory.\n",
    "with gzip.open('data/MNIST/test-images-idx3-ubyte.gz', 'rb') as f:\n",
    "    f.read(16)\n",
    "    buf = f.read(image_size * image_size * num_images)\n",
    "data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)\n",
    "data = data.reshape(num_images, image_size, image_size)\n",
    "\n",
    "# Labels: skip the first 8 bytes, then decode all label bytes in one call\n",
    "# instead of reading them one at a time in a Python loop.\n",
    "with gzip.open('data/MNIST/test-labels-idx1-ubyte.gz', 'rb') as f:\n",
    "    f.read(8)\n",
    "    buf = f.read(num_images)\n",
    "labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)\n",
    "assert labels.shape == (num_images,), \"label file holds fewer labels than expected\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6a4c16c",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('data/MNIST/prepared/test_data.npy', data)\n",
    "np.save('data/MNIST/prepared/test_labels.npy', labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "320099f0",
   "metadata": {},
   "source": [
    "# F-MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b42691f-a9da-4e89-9526-ea1af47a9d87",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "import gzip\n",
    "import shutil\n",
    "\n",
    "\n",
    "download_dir = \"data/F-MNIST\"\n",
    "\n",
    "\n",
    "def download_and_extract(url, download_dir):\n",
    "    \"\"\"Download the gzip file at `url` into `download_dir` and decompress it alongside.\n",
    "\n",
    "    The directory is created when missing. The decompressed file carries the\n",
    "    archive's name without its final extension; the archive itself is kept.\n",
    "    \"\"\"\n",
    "    os.makedirs(download_dir, exist_ok=True)\n",
    "\n",
    "    filename = os.path.basename(url)\n",
    "    gzip_file_path = os.path.join(download_dir, filename)\n",
    "    binary_file_path, _ = os.path.splitext(gzip_file_path)\n",
    "\n",
    "    urllib.request.urlretrieve(url, gzip_file_path)\n",
    "    with gzip.open(gzip_file_path, 'rb') as f_in, open(binary_file_path, 'wb') as f_out:\n",
    "        shutil.copyfileobj(f_in, f_out)\n",
    "\n",
    "    print(f\"{filename} downloaded and extracted successfully!\")\n",
    "\n",
    "\n",
    "# URLs for F-MNIST test images and labels\n",
    "test_images_url = \"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\"\n",
    "test_labels_url = \"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\"\n",
    "# URLs for F-MNIST train images and labels\n",
    "train_images_url = \"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\"\n",
    "train_labels_url = \"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\"\n",
    "\n",
    "# Fetch the test files first, then the train files\n",
    "for url in (test_images_url, test_labels_url, train_images_url, train_labels_url):\n",
    "    download_and_extract(url, download_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "327a1eea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip\n",
    "\n",
    "image_size = 28\n",
    "num_images = 60000\n",
    "\n",
    "# Images: skip the first 16 bytes, then read one unsigned byte per pixel.\n",
    "# The `with` block closes the gzip handle once the bytes are in memory.\n",
    "with gzip.open('data/F-MNIST/train-images-idx3-ubyte.gz', 'rb') as f:\n",
    "    f.read(16)\n",
    "    buf = f.read(image_size * image_size * num_images)\n",
    "data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)\n",
    "data = data.reshape(num_images, image_size, image_size)\n",
    "\n",
    "# Labels: skip the first 8 bytes, then decode all label bytes in one call\n",
    "# instead of reading them one at a time in a Python loop.\n",
    "with gzip.open('data/F-MNIST/train-labels-idx1-ubyte.gz', 'rb') as f:\n",
    "    f.read(8)\n",
    "    buf = f.read(num_images)\n",
    "labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)\n",
    "assert labels.shape == (num_images,), \"label file holds fewer labels than expected\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52ad57cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# demonstration\n",
    "Image.fromarray(data[0].astype(np.uint8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7537b170",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output folder from Python (portable, and safe to re-run)\n",
    "# before writing the training images and labels.\n",
    "os.makedirs('data/F-MNIST/prepared', exist_ok=True)\n",
    "np.save('data/F-MNIST/prepared/train_data.npy', data)\n",
    "np.save('data/F-MNIST/prepared/train_labels.npy', labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71b40fda",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_size = 28\n",
    "num_images = 10000\n",
    "\n",
    "# Images: skip the first 16 bytes, then read one unsigned byte per pixel.\n",
    "# The `with` block closes the gzip handle once the bytes are in memory.\n",
    "with gzip.open('data/F-MNIST/t10k-images-idx3-ubyte.gz', 'rb') as f:\n",
    "    f.read(16)\n",
    "    buf = f.read(image_size * image_size * num_images)\n",
    "data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)\n",
    "data = data.reshape(num_images, image_size, image_size)\n",
    "\n",
    "# Labels: skip the first 8 bytes, then decode all label bytes in one call\n",
    "# instead of reading them one at a time in a Python loop.\n",
    "with gzip.open('data/F-MNIST/t10k-labels-idx1-ubyte.gz', 'rb') as f:\n",
    "    f.read(8)\n",
    "    buf = f.read(num_images)\n",
    "labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)\n",
    "assert labels.shape == (num_images,), \"label file holds fewer labels than expected\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16efb030",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('data/F-MNIST/prepared/test_data.npy', data)\n",
    "np.save('data/F-MNIST/prepared/test_labels.npy', labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fa58e3e",
   "metadata": {},
   "source": [
    "# CIFAR-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cd49d60",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9025ed5f-bd37-4331-b298-d40d1b7aed7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "import tarfile\n",
    "\n",
    "# URL for the CIFAR-10 dataset\n",
    "cifar10_url = \"https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\"\n",
    "\n",
    "# Directory that receives both the archive and its extracted contents\n",
    "download_dir = \"data/CIFAR-10\"\n",
    "os.makedirs(download_dir, exist_ok=True)\n",
    "\n",
    "# Path to save the downloaded tar.gz file\n",
    "tar_file_path = os.path.join(download_dir, \"cifar-10-python.tar.gz\")\n",
    "\n",
    "# Download the tar.gz file\n",
    "urllib.request.urlretrieve(cifar10_url, tar_file_path)\n",
    "\n",
    "# Extract the tar.gz file (the archive itself is kept on disk)\n",
    "with tarfile.open(tar_file_path, 'r:gz') as tar:\n",
    "    # Use the \"data\" extraction filter where the interpreter provides it, so\n",
    "    # archive members cannot be written outside download_dir; Python builds\n",
    "    # without extraction filters do not accept the `filter` argument.\n",
    "    if hasattr(tarfile, \"data_filter\"):\n",
    "        tar.extractall(download_dir, filter=\"data\")\n",
    "    else:\n",
    "        tar.extractall(download_dir)\n",
    "\n",
    "print(\"CIFAR-10 dataset downloaded and extracted successfully!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b31aeef-5c14-4514-9a9c-66ccbae79f76",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls data/CIFAR-10/cifar-10-batches-py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de20548",
   "metadata": {},
   "outputs": [],
   "source": [
    "def unpickle(file):\n",
    "    \"\"\"Return the object stored in the pickle file at path `file`.\n",
    "\n",
    "    The file is read with `encoding='bytes'`.\n",
    "    \"\"\"\n",
    "    # NOTE(review): pickle.load can run arbitrary code; only use on trusted files.\n",
    "    with open(file, 'rb') as handle:\n",
    "        return pickle.load(handle, encoding='bytes')\n",
    "\n",
    "\n",
    "# Folder holding the extracted CIFAR-10 batch files\n",
    "dirname = 'data/CIFAR-10/cifar-10-batches-py/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d9afe31",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.listdir(dirname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "952228cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_files = sorted([file for file in os.listdir(dirname) if 'data_batch' in file])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47ee4d60",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = []\n",
    "data = []\n",
    "for file in train_files:\n",
    "    loaded = unpickle(dirname+file)\n",
    "    data.append(loaded[b'data'])\n",
    "    labels.extend(loaded[b'labels'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46aeb9e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.concatenate(data).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99c715c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output folder from Python (portable, and safe to re-run)\n",
    "# before writing the training labels and the concatenated batch data.\n",
    "os.makedirs('data/CIFAR-10/prepared', exist_ok=True)\n",
    "np.save('data/CIFAR-10/prepared/train_labels.npy', np.array(labels))\n",
    "np.save('data/CIFAR-10/prepared/train_data.npy', np.concatenate(data, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fd60f40",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded = unpickle(dirname+'test_batch')\n",
    "test_data = loaded[b'data']\n",
    "test_labels = np.array(loaded[b'labels'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26bbb84",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('data/CIFAR-10/prepared/test_labels.npy', test_labels)\n",
    "np.save('data/CIFAR-10/prepared/test_data.npy', test_data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
