{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d66ff280-78f7-4f88-9b84-4d1496bc3ad4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from PIL import Image\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision import transforms\n",
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision.transforms import Compose, Normalize, RandomVerticalFlip\n",
    "from torchvision import models\n",
    "class Animal10NDataset(Dataset):\n",
    "    \"\"\"Image dataset backed by a flat directory of labelled image files.\n",
    "\n",
    "    Every ``.png``/``.jpg``/``.jpeg`` file (extension matched case-insensitively)\n",
    "    directly inside ``root_dir`` is one sample. Its integer label is parsed from\n",
    "    the filename prefix before the first underscore (e.g. ``3_01234.jpg`` -> 3).\n",
    "\n",
    "    Args:\n",
    "        root_dir: Directory containing the image files (not searched recursively).\n",
    "        transform: Optional callable applied to each RGB PIL image in ``__getitem__``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: from ``int()`` if an image filename has a non-integer prefix.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, root_dir, transform=None):\n",
    "        self.root_dir = root_dir\n",
    "        self.transform = transform\n",
    "        self.image_paths = []\n",
    "        self.labels = []\n",
    "\n",
    "        # os.listdir order is filesystem-dependent; sort so sample indices (and any\n",
    "        # tensors saved from an unshuffled loader) are reproducible across machines.\n",
    "        for file_name in sorted(os.listdir(root_dir)):\n",
    "            if file_name.lower().endswith(('.png', '.jpg', '.jpeg')):\n",
    "                self.image_paths.append(os.path.join(root_dir, file_name))\n",
    "                # The label is encoded as the filename prefix before the first '_'.\n",
    "                self.labels.append(int(file_name.split('_')[0]))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.image_paths)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(image, label)`` for sample ``idx``, transforming the image if a transform is set.\"\"\"\n",
    "        img_path = self.image_paths[idx]\n",
    "        # The context manager closes the underlying file handle. convert() returns a\n",
    "        # new, fully loaded image, so it stays valid after the file is closed.\n",
    "        with Image.open(img_path) as img:\n",
    "            image = img.convert('RGB')\n",
    "        label = self.labels[idx]\n",
    "\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "\n",
    "        return image, label\n",
    "\n",
    "# Get the current working directory\n",
    "current_directory = os.getcwd()\n",
    "\n",
    "# Construct the paths for training and testing directories\n",
    "train_dir = os.path.join(current_directory, 'training')\n",
    "test_dir = os.path.join(current_directory, 'testing')\n",
    "\n",
    "# Seed torch's global RNG so the shuffle order and the random augmentations drawn\n",
    "# when train_loader is iterated (and hence any tensors saved from it) are reproducible.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "IMAGE_SIZE = 64\n",
    "BATCH_SIZE = 32\n",
    "# Per-channel (R, G, B) normalisation statistics, presumably computed on the\n",
    "# training split -- TODO(review): record how they were derived.\n",
    "NORM_MEAN = (0.6959, 0.6537, 0.6371)\n",
    "NORM_STD = (0.3113, 0.3192, 0.3214)\n",
    "\n",
    "# Define transformations\n",
    "transform_train = transforms.Compose([\n",
    "    # An int size rescales only the shorter side; the crop below makes samples square.\n",
    "    transforms.Resize(IMAGE_SIZE),\n",
    "    transforms.RandomCrop(IMAGE_SIZE),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(NORM_MEAN, NORM_STD),\n",
    "])\n",
    "\n",
    "# Test-time preprocessing mirrors the training geometry deterministically, so every\n",
    "# test tensor has the same size and can be batched and concatenated.\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.Resize(IMAGE_SIZE),\n",
    "    transforms.CenterCrop(IMAGE_SIZE),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(NORM_MEAN, NORM_STD),\n",
    "])\n",
    "\n",
    "# Load the datasets\n",
    "train_dataset = Animal10NDataset(root_dir=train_dir, transform=transform_train)\n",
    "test_dataset = Animal10NDataset(root_dir=test_dir, transform=transform_test)\n",
    "\n",
    "# Create data loaders\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "01f6abd5-b965-488c-9c6d-3736a8e08a79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run one full pass over train_loader, keeping every batch in memory.\n",
    "# NOTE(review): train_loader shuffles and applies RandomCrop/RandomHorizontalFlip,\n",
    "# so what gets collected (and later saved) is a single frozen augmentation draw in\n",
    "# shuffled order; use a deterministic transform/loader if canonical data is intended.\n",
    "train_data1 = []\n",
    "train_targets1 = []\n",
    "for batch in train_loader:\n",
    "    images, labels = batch\n",
    "    train_targets1.append(labels)\n",
    "    train_data1.append(images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1de1a8c2-3d31-4792-a15e-7253c00fc713",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-batch training tensors along the batch axis (torch.cat defaults to dim 0).\n",
    "train_data = torch.cat(train_data1)\n",
    "train_targets = torch.cat(train_targets1)\n",
    "\n",
    "# Repeat the same materialisation for the unshuffled test split.\n",
    "test_data1 = []\n",
    "test_targets1 = []\n",
    "for batch in test_loader:\n",
    "    images, labels = batch\n",
    "    test_targets1.append(labels)\n",
    "    test_data1.append(images)\n",
    "\n",
    "test_data = torch.cat(test_data1)\n",
    "test_targets = torch.cat(test_targets1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86684cc6-8700-46c6-8c06-5f67b406c6fa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90fa7d25-c1ab-4da2-bb25-32c2a00b31e3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cbc0f3af-dc18-4add-b277-0a76dea4aaf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the materialised training split so later runs can skip image decoding.\n",
    "for out_path, tensor in {'train_data.pt': train_data, 'train_targets.pt': train_targets}.items():\n",
    "    torch.save(tensor, out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5ad11d93-0b26-464f-aac4-a9f865765963",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the materialised test split alongside the training tensors.\n",
    "for out_path, tensor in {'test_data.pt': test_data, 'test_targets.pt': test_targets}.items():\n",
    "    torch.save(tensor, out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c164861-0ba4-421b-8760-efbd6cac5b02",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0f0b89c7-783c-4457-aec2-b3a5aacd6071",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-derive the training directory from the current working directory, so this\n",
    "# cell does not depend on the train_dir value left over from the first cell.\n",
    "current_directory = os.getcwd()\n",
    "train_dir = os.path.join(current_directory, 'training')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bd58e9ca-dce9-440f-ae1b-4b9926cf2668",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Sanity-check the training directory without flooding the notebook output:\n",
    "# printing every path exceeded Jupyter's IOPub data-rate limit.\n",
    "directory = train_dir\n",
    "\n",
    "file_names = sorted(os.listdir(directory))\n",
    "print(f'{len(file_names)} entries in {directory}')\n",
    "\n",
    "# Show only a small, deterministic sample of paths; raise N_PREVIEW to see more.\n",
    "N_PREVIEW = 10\n",
    "for file_name in file_names[:N_PREVIEW]:\n",
    "    file_path = os.path.join(directory, file_name)\n",
    "    print(file_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6ac2e3-e996-405b-a2ca-77842e356f90",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
