{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from torchvision import datasets \n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from PIL import Image\n",
    "import os\n",
    "from torchvision import datasets, transforms\n",
    "import torch\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import time\n",
    "import copy\n",
    "import sys\n",
    "from torchvision import datasets, transforms\n",
    "from tqdm import tqdm\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from torchvision import datasets \n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from PIL import Image\n",
    "import os\n",
    "from torchvision import datasets, transforms\n",
    "import torch\n",
    "\n",
    "\n",
    "class DatasetZoo(Dataset):\n",
    "    def __init__(self, root, dataset='cifar10', dataidxs=None, train=True, transform=None, target_transform=None,\n",
    "                 download=False, p_data=1.0, seed=2023):\n",
    "        \n",
    "        self.root = root\n",
    "        self.dataset = dataset\n",
    "        self.dataidxs = dataidxs\n",
    "        self.train = train\n",
    "        self.transform = transform\n",
    "        self.target_transform = target_transform\n",
    "        self.download = download\n",
    "        self.p_data = p_data\n",
    "        self.seed = seed\n",
    "        \n",
    "        self.data, self.target, self.dataobj, self.mode = self.__init_dataset__()\n",
    "\n",
    "    def load_tinyimagenet_data(self, datadir):\n",
    "        transform_train = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.4807, 0.4574, 0.4083), (0.2056, 0.2035, 0.2041)),\n",
    "        ])\n",
    "        \n",
    "        transform_val = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.4807, 0.4574, 0.4083), (0.2056, 0.2035, 0.2041)),\n",
    "        ])\n",
    "        \n",
    "        # Load TinyImageNet datasets\n",
    "        train_ds = datasets.ImageNet(root=datadir , split='train', transform=transform_train)\n",
    "        val_ds = datasets.ImageNet(root=datadir , split='val', transform=transform_val)\n",
    "        \n",
    "        # Define the 100 labels to select\n",
    "        labels = [847, 874, 471, 476, 764, 138,  49, 226, 100, 426, 815, 836, 338,\n",
    "                669, 743, 912, 320, 843, 796, 322, 261, 136, 841, 460, 699, 935,\n",
    "                949, 877,  61, 332, 416,  35, 227, 493,  32, 478, 660,  13, 451,\n",
    "                438, 323, 867, 168,  40, 156, 455, 691, 223, 354, 495, 799, 432,\n",
    "                158, 866, 657, 768, 183, 852, 727, 249, 402, 507,  12, 880, 995,\n",
    "                233, 176, 776, 830, 586, 865, 475, 610, 534, 953, 490, 160, 386,\n",
    "                117, 942, 675,  24, 538, 494, 266, 295, 272,  11,   9, 154, 967,\n",
    "                901, 123, 649, 737, 121,  20, 439, 641, 398]\n",
    "        \n",
    "        subset_train = np.array([], dtype='int')\n",
    "        train_targets = np.array(train_ds.targets)\n",
    "        for label in labels:\n",
    "            subset_train = np.hstack([subset_train, np.where(train_targets == label)[0][0:500]])\n",
    "        public_ds_train = torch.utils.data.Subset(train_ds, subset_train)\n",
    "        \n",
    "        subset_val = np.array([], dtype='int')\n",
    "        val_targets = np.array(val_ds.targets)\n",
    "        for label in labels:\n",
    "            subset_val = np.hstack([subset_val, np.where(val_targets == label)[0]])\n",
    "        public_ds_val = torch.utils.data.Subset(val_ds, subset_val)\n",
    "        \n",
    "        return public_ds_train, public_ds_val\n",
    "\n",
    "    def __init_dataset__(self):\n",
    "        \n",
    "        if self.dataset == 'mnist':\n",
    "            dataobj = datasets.MNIST(self.root, self.train, self.transform, self.target_transform, self.download)\n",
    "            mode = 'L'\n",
    "        elif self.dataset == 'usps':\n",
    "            dataobj = datasets.USPS(self.root, self.train, self.transform, self.target_transform, self.download)\n",
    "            mode = 'L'\n",
    "        elif self.dataset == 'fmnist':\n",
    "            dataobj = datasets.FashionMNIST(self.root, self.train, self.transform, self.target_transform, self.download)\n",
    "            mode = 'L'\n",
    "        elif self.dataset == 'cifar10':\n",
    "            dataobj = datasets.CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download)\n",
    "            mode = 'RGB'\n",
    "        elif self.dataset == 'cifar100':\n",
    "            dataobj = datasets.CIFAR100(self.root, self.train, self.transform, self.target_transform, self.download)\n",
    "            mode = 'RGB'\n",
    "        elif self.dataset == 'svhn':\n",
    "            if self.train:\n",
    "                dataobj = datasets.SVHN(self.root, 'train', self.transform, self.target_transform, self.download)\n",
    "            else:\n",
    "                dataobj = datasets.SVHN(self.root, 'test', self.transform, self.target_transform, self.download)\n",
    "            mode = 'RGB'\n",
    "        elif self.dataset == 'stl10':\n",
    "            if self.train:\n",
    "                dataobj = datasets.STL10(self.root, 'train', self.transform, self.target_transform, self.download)\n",
    "            else:\n",
    "                dataobj = datasets.STL10(self.root, 'test', self.transform, self.target_transform, self.download)\n",
    "            mode = 'RGB'\n",
    "        elif self.dataset == 'celeba':\n",
    "            X_train, y_train, X_test, y_test = load_celeba_data(datadir)\n",
    "            mode = 'RGB'\n",
    "        elif self.dataset == 'femnist':\n",
    "            X_train, y_train, u_train, X_test, y_test, u_test = load_femnist_data(datadir)\n",
    "            mode = 'L'\n",
    "        elif self.dataset == 'tinyimagenet':\n",
    "            train_ds, val_ds = self.load_tinyimagenet_data(\"/home/mahdi/codes/data/imagenet_resized/\")\n",
    "            mode = 'RGB'\n",
    "            dataobj = train_ds if self.train else val_ds\n",
    "            data = np.array([np.array(img) for img, _ in dataobj])\n",
    "            print(data.dtype)\n",
    "            print(\"length of the data\",len(data))\n",
    "            # data = np.array([np.transpose(img, (1, 2, 0)) for img in data])\n",
    "            \n",
    "            print(\"imagenet\")\n",
    "            print(data.dtype)\n",
    "            print(\"length of the data\",len(data))\n",
    "            print(data.shape)\n",
    "            target = np.array([label for _, label in dataobj])\n",
    "            print(\"length of the target\",len(target))\n",
    "\n",
    "            if self.dataidxs is not None:\n",
    "                data = data[self.dataidxs]\n",
    "                target = target[self.dataidxs]\n",
    "            \n",
    "            if self.dataidxs is None: \n",
    "                idxs_data = np.arange(len(data))\n",
    "                idxs_target = np.arange(len(target))\n",
    "                \n",
    "                perm_data = np.random.RandomState(seed=self.seed).permutation(len(target))\n",
    "                #perm_data = idxs_data  #np.random.permutation(idxs_data) \n",
    "\n",
    "                p_data1 = int(len(idxs_data)*self.p_data)\n",
    "                perm_data = perm_data[0:p_data1]\n",
    "\n",
    "                data = data[perm_data] \n",
    "                target = target[perm_data]\n",
    "           \n",
    "\n",
    "            return data, target, dataobj, mode\n",
    "        elif self.dataset == 'cinic10':\n",
    "            cinic_mean = [0.47889522, 0.47227842, 0.43047404]\n",
    "            cinic_std = [0.24205776, 0.23828046, 0.25874835]\n",
    "            transform = transforms.Compose([\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Normalize(mean=cinic_mean, std=cinic_std)\n",
    "            ])\n",
    "            mode=\"RGB\"\n",
    "            \n",
    "            mode = \"RGB\"\n",
    "    \n",
    "            train_ds = datasets.ImageFolder(root='/home/mahdi/codes/data/cinic-10/train', transform=transform)\n",
    "            val_ds = datasets.ImageFolder(root='/home/mahdi/codes/data/cinic-10/valid', transform=transform)\n",
    "            test_ds = datasets.ImageFolder(root='/home/mahdi/codes/data/cinic-10/test', transform=transform)\n",
    "            whole_ds = torch.utils.data.ConcatDataset([train_ds, val_ds, test_ds])\n",
    "\n",
    "            train_size = 100000\n",
    "            length = len(whole_ds)\n",
    "            indices = list(range(length))\n",
    "            \n",
    "            # Shuffle indices with fixed seed\n",
    "            np.random.seed(42)\n",
    "            np.random.shuffle(indices)\n",
    "\n",
    "            train_indices = indices[:train_size]\n",
    "            test_indices = indices[train_size:train_size + 30000]\n",
    "\n",
    "\n",
    "            train_ds_global = torch.utils.data.Subset(whole_ds, train_indices)\n",
    "            test_ds_global = torch.utils.data.Subset(whole_ds, test_indices)\n",
    "            \n",
    "            # Assign dataobj for train and test\n",
    "            dataobj = train_ds_global if self.train else test_ds_global\n",
    "            \n",
    "            # Extract data and targets\n",
    "            x_data, y_data = zip(*[(img, label) for img, label in dataobj])\n",
    "            x_data = torch.stack(x_data)\n",
    "            y_data = torch.tensor(y_data)\n",
    "            \n",
    "            print(f\"x_data: {len(x_data)}\")\n",
    "            print(f\"y_data: {len(y_data)}\")\n",
    "            print(f\"x_data: {x_data.shape}\")\n",
    "            \n",
    "            # Apply permutation and filtering\n",
    "            idxs_data = np.arange(len(x_data))\n",
    "            perm_data = np.random.RandomState(seed=self.seed).permutation(len(idxs_data))\n",
    "            p_data1 = min(int(len(idxs_data) * self.p_data), len(idxs_data))\n",
    "            perm_data = perm_data[:p_data1]\n",
    "            \n",
    "            x_data = x_data[perm_data]\n",
    "            y_data = y_data[perm_data]\n",
    "            \n",
    "            return x_data, y_data, dataobj, mode\n",
    "\n",
    "        data = np.array(dataobj.data)\n",
    "        try:\n",
    "            target = np.array(dataobj.targets)\n",
    "        except:\n",
    "            target = np.array(dataobj.labels)\n",
    "        \n",
    "        if data.shape[2]==data.shape[3]:\n",
    "            data = data.transpose(0,2,3,1) ## STL-10\n",
    "\n",
    "        if self.dataidxs is not None:\n",
    "            data = data[self.dataidxs]\n",
    "            target = target[self.dataidxs]\n",
    "            \n",
    "        if self.dataidxs is None: \n",
    "            idxs_data = np.arange(len(data))\n",
    "            idxs_target = np.arange(len(target))\n",
    "            \n",
    "            perm_data = np.random.RandomState(seed=self.seed).permutation(len(target))\n",
    "            #perm_data = idxs_data  #np.random.permutation(idxs_data) \n",
    "\n",
    "            p_data1 = int(len(idxs_data)*self.p_data)\n",
    "            perm_data = perm_data[0:p_data1]\n",
    "\n",
    "            data = data[perm_data] \n",
    "            target = target[perm_data]\n",
    "\n",
    "        return data, target, dataobj, mode\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            index (int): Index\n",
    "\n",
    "        Returns:\n",
    "            tuple: (image, target) where target is index of the target class.\n",
    "        \"\"\"\n",
    "        img, target = self.data[index], self.target[index]\n",
    "        if(self.dataset==\"tinyimagenet\"):\n",
    "            if img.shape[0] == 3:  \n",
    "                img = np.transpose(img, (1, 2, 0))\n",
    "        \n",
    "        if(self.dataset==\"tinyimagenet\"):\n",
    "            img = Image.fromarray((img).astype(np.uint8), mode=self.mode)\n",
    "        elif(self.dataset==\"cinic10\"):\n",
    "            img = Image.fromarray(np.transpose(img.cpu().numpy(), (1, 2, 0)).astype(np.uint8))\n",
    "        else:\n",
    "            img = Image.fromarray(img, mode=self.mode)\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        if self.target_transform is not None:\n",
    "            target = self.target_transform(target)\n",
    "            target.type(torch.LongTensor)\n",
    "\n",
    "        return img, target\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pip install --upgrade tqdm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ../../data/cifar-100-python.tar.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Widget Javascript not detected.  It may not be installed or enabled properly.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df46a7de0f2d4c1e851b8603fc87b4c9"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "AttributeError",
     "evalue": "'FloatProgress' object has no attribute 'style'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-3-f41bcb75b399>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m \u001b[0mdatadir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'../../data/'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m dataobj_train = DatasetZoo(datadir, dataset='cifar100', train=True, transform=transform_train, download=True, p_data=1,\n\u001b[0m\u001b[1;32m     34\u001b[0m                         seed=2023)\n\u001b[1;32m     35\u001b[0m dataobj_test = DatasetZoo(datadir, dataset='cifar100', train=False, transform=transform_train, download=True, p_data=1,\n",
      "\u001b[0;32m<ipython-input-2-a4e475e13416>\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, dataset, dataidxs, train, transform, target_transform, download, p_data, seed)\u001b[0m\n\u001b[1;32m     22\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mseed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init_dataset__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mload_tinyimagenet_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatadir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-2-a4e475e13416>\u001b[0m in \u001b[0;36m__init_dataset__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     78\u001b[0m             \u001b[0mmode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'RGB'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     79\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'cifar100'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 80\u001b[0;31m             \u001b[0mdataobj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mCIFAR100\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_transform\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     81\u001b[0m             \u001b[0mmode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'RGB'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     82\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'svhn'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torchvision/datasets/cifar.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, train, transform, target_transform, download)\u001b[0m\n\u001b[1;32m     63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     64\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     67\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_integrity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torchvision/datasets/cifar.py\u001b[0m in \u001b[0;36mdownload\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    141\u001b[0m             \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Files already downloaded and verified'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    142\u001b[0m             \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 143\u001b[0;31m         \u001b[0mdownload_and_extract_archive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtgz_md5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    145\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_and_extract_archive\u001b[0;34m(url, download_root, extract_root, filename, md5, remove_finished)\u001b[0m\n\u001b[1;32m    314\u001b[0m         \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbasename\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    315\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 316\u001b[0;31m     \u001b[0mdownload_url\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload_root\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    317\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    318\u001b[0m     \u001b[0marchive\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdownload_root\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36mdownload_url\u001b[0;34m(url, root, filename, md5, max_redirect_hops)\u001b[0m\n\u001b[1;32m    132\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Downloading '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0murl\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' to '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m         \u001b[0m_urlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    135\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merror\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mURLError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIOError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# type: ignore[attr-defined]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    136\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'https'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torchvision/datasets/utils.py\u001b[0m in \u001b[0;36m_urlretrieve\u001b[0;34m(url, filename, chunk_size)\u001b[0m\n\u001b[1;32m     28\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"wb\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfh\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0murlopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murllib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRequest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m\"User-Agent\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUSER_AGENT\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m             \u001b[0;32mwith\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtotal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresponse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlength\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpbar\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     31\u001b[0m                 \u001b[0;32mfor\u001b[0m \u001b[0mchunk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchunk_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m                     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mchunk\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/tqdm/notebook.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    244\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisplayed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    245\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisplay\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 246\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolour\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcolour\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    247\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    248\u001b[0m         \u001b[0;31m# Print initial bar state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/tqdm/notebook.py\u001b[0m in \u001b[0;36mcolour\u001b[0;34m(self, bar_color)\u001b[0m\n\u001b[1;32m    203\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mcolour\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbar_color\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    204\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'container'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstyle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbar_color\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbar_color\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    206\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    207\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'FloatProgress' object has no attribute 'style'"
     ]
    }
   ],
   "source": [
    "class AddGaussianNoise(object):\n",
    "    def __init__(self, mean=0., std=1., net_id=None, total=0):\n",
    "        self.std = std\n",
    "        self.mean = mean\n",
    "        self.net_id = net_id\n",
    "        self.num = int(np.sqrt(total))\n",
    "        if self.num * self.num < total:\n",
    "            self.num = self.num + 1\n",
    "\n",
    "    def __call__(self, tensor):\n",
    "        if self.net_id is None:\n",
    "            return tensor + torch.randn(tensor.size()) * self.std + self.mean\n",
    "        else:\n",
    "            tmp = torch.randn(tensor.size())\n",
    "            filt = torch.zeros(tensor.size())\n",
    "            size = int(28 / self.num)\n",
    "            row = int(self.net_id / size)\n",
    "            col = self.net_id % size\n",
    "            for i in range(size):\n",
    "                for j in range(size):\n",
    "                    filt[:,row*size+i,col*size+j] = 1\n",
    "            tmp = tmp * filt\n",
    "            return tensor + tmp * self.std + self.mean\n",
    "\n",
    "transform_train = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            AddGaussianNoise(), \n",
    "            transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])\n",
    "        ])\n",
    "\n",
    "\n",
    "datadir='../../data/'\n",
    "dataobj_train = DatasetZoo(datadir, dataset='cifar100', train=True, transform=transform_train, download=True, p_data=1,\n",
    "                        seed=2023)\n",
    "dataobj_test = DatasetZoo(datadir, dataset='cifar100', train=False, transform=transform_train, download=True, p_data=1,\n",
    "                        seed=2023)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
