{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fd264dbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6a3ae41f",
   "metadata": {},
   "outputs": [],
   "source": [
    "random_state = 0  # seed shared by the RNG-equivalence experiments below\n",
    "flipper = np.random.RandomState(seed=random_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "db705060",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 0, 1])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# one Bernoulli draw per entry: a[k] == 1 with probability P[k] (unseeded global RNG)\n",
    "P = np.array([0.1, 0.8, 0.0, 0.9])\n",
    "a = np.random.binomial(n=1, p=P, size=4)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "72cf1504",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4cb82bcc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,\n",
       "        0.1000])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.ones(10).mul(0.1)  # ten entries equal to 0.1\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "63622e1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "random_state = 0\n",
    "flipper = np.random.RandomState(seed=random_state)  # fresh generator: restarts the stream at seed 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "efab4900",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.76405235 0.40015721 0.97873798 2.2408932  1.86755799]\n",
      "[-0.97727788  0.95008842 -0.15135721 -0.10321885  0.4105985 ]\n",
      "[0.14404357 1.45427351 0.76103773 0.12167502 0.44386323]\n",
      "[ 0.33367433  1.49407907 -0.20515826  0.3130677  -0.85409574]\n",
      "[-2.55298982  0.6536186   0.8644362  -0.74216502  2.26975462]\n"
     ]
    }
   ],
   "source": [
    "for i in range(5):\n",
    "    a = flipper.standard_normal(size=5)  # same stream as flipper.randn(5)\n",
    "    print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "e4fa68dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(seed=random_state)  # reseed the legacy global RNG with the same seed as flipper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b16e288c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.76405235 0.40015721 0.97873798 2.2408932  1.86755799]\n",
      "[-0.97727788  0.95008842 -0.15135721 -0.10321885  0.4105985 ]\n",
      "[0.14404357 1.45427351 0.76103773 0.12167502 0.44386323]\n",
      "[ 0.33367433  1.49407907 -0.20515826  0.3130677  -0.85409574]\n",
      "[-2.55298982  0.6536186   0.8644362  -0.74216502  2.26975462]\n"
     ]
    }
   ],
   "source": [
    "for i in range(5):\n",
    "    a = np.random.standard_normal(size=5)  # global-RNG counterpart of flipper.randn(5)\n",
    "    print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "a6b1b6cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_scar_candidate_labels(train_labels, scar_type):\n",
    "    \"\"\"Sample SCAR complementary-label sets, one row per training example.\n",
    "\n",
    "    Builds a K x K transition matrix whose column i is (10/9) * prior[i]\n",
    "    with a zero diagonal, then for each example draws one Bernoulli per\n",
    "    class from the row of its true label using the global NumPy RNG\n",
    "    (seed it with np.random.seed for reproducibility).\n",
    "\n",
    "    Args:\n",
    "        train_labels: 1-D integer tensor of true labels. K is computed as\n",
    "            max - min + 1 and labels are used directly as row indices, so\n",
    "            they are assumed to be 0..K-1.\n",
    "        scar_type: 'set1' (prior 0.1 per class), 'set2' (fixed 10-class\n",
    "            prior) or 'set3' (prior 0.2 per class).\n",
    "\n",
    "    Returns:\n",
    "        Float tensor of shape (n, K); entry [i, k] is 1 when class k was\n",
    "        drawn as a candidate (complementary) label for example i.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if scar_type is unknown, or 'set2' is used with K != 10.\n",
    "    \"\"\"\n",
    "    K = int(torch.max(train_labels) - torch.min(train_labels) + 1)\n",
    "    n = train_labels.shape[0]\n",
    "    if scar_type == 'set1':\n",
    "        cl_class_prior = 0.1 * torch.ones(K)\n",
    "    elif scar_type == 'set2':\n",
    "        cl_class_prior = torch.tensor([0.1, 0.1, 0.9, 0.05, 0.05, 0.1, 0.1, 0.2, 0.05, 0.05])\n",
    "        if K != len(cl_class_prior):\n",
    "            raise ValueError(f\"scar_type 'set2' defines 10 class priors, but the labels span K={K} classes\")\n",
    "    elif scar_type == 'set3':\n",
    "        cl_class_prior = 0.2 * torch.ones(K)\n",
    "    else:\n",
    "        raise ValueError(f\"unknown scar_type {scar_type!r}; expected 'set1', 'set2' or 'set3'\")\n",
    "\n",
    "    # Scale column i by prior[i]; zero the diagonal so the true label is never drawn.\n",
    "    transition_mat = 10 / 9 * torch.ones(K, K)\n",
    "    for i in range(K):\n",
    "        transition_mat[:, i] = transition_mat[:, i] * cl_class_prior[i]\n",
    "        transition_mat[i, i] = 0\n",
    "\n",
    "    partialY = torch.zeros(n, K)\n",
    "    for idx in range(n):\n",
    "        true_label = train_labels[idx]\n",
    "        partial_label = np.random.binomial(1, transition_mat[true_label], K)\n",
    "        partialY[idx] = torch.from_numpy(partial_label)\n",
    "    return partialY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "d20adf48",
   "metadata": {},
   "outputs": [],
   "source": [
    "scar_type = 'set2'\n",
    "train_labels = torch.arange(0, 10)  # one example per class 0..9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "f856e705",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inline copy of generate_scar_candidate_labels' body at top level, so that\n",
    "# trainsition_mat (displayed in the next cell) stays in the namespace.\n",
    "K = int(torch.max(train_labels) - torch.min(train_labels) + 1)\n",
    "n = train_labels.shape[0]\n",
    "if scar_type == 'set1':\n",
    "    cl_class_prior = 0.1 * torch.ones(K)\n",
    "elif scar_type == 'set2':\n",
    "    cl_class_prior = torch.tensor([0.1, 0.1, 0.9, 0.05, 0.05, 0.1, 0.1, 0.2, 0.05, 0.05])\n",
    "elif scar_type == 'set3':\n",
    "    cl_class_prior = 0.2 * torch.ones(K)\n",
    "else:\n",
    "    # Fail loudly instead of silently reusing a cl_class_prior left over from an earlier run.\n",
    "    raise ValueError(f\"unknown scar_type {scar_type!r}; expected 'set1', 'set2' or 'set3'\")\n",
    "# Column i scaled by cl_class_prior[i]; zero diagonal so the true label is never drawn.\n",
    "trainsition_mat = 10 / 9 * torch.ones(K, K)\n",
    "for i in range(K):\n",
    "    trainsition_mat[:, i] = trainsition_mat[:, i] * cl_class_prior[i]\n",
    "    trainsition_mat[i, i] = 0\n",
    "partialY = torch.zeros(n, K)\n",
    "for idx in range(n):\n",
    "    true_label = train_labels[idx]\n",
    "    partial_label = np.random.binomial(1, trainsition_mat[true_label], K)\n",
    "    partial_label = torch.from_numpy(partial_label)\n",
    "    partialY[idx] = partial_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "30b97341",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.1111, 1.0000, 0.0556, 0.0556, 0.1111, 0.1111, 0.2222, 0.0556,\n",
       "         0.0556],\n",
       "        [0.1111, 0.0000, 1.0000, 0.0556, 0.0556, 0.1111, 0.1111, 0.2222, 0.0556,\n",
       "         0.0556],\n",
       "        [0.1111, 0.1111, 0.0000, 0.0556, 0.0556, 0.1111, 0.1111, 0.2222, 0.0556,\n",
       "         0.0556],\n",
       "        [0.1111, 0.1111, 1.0000, 0.0000, 0.0556, 0.1111, 0.1111, 0.2222, 0.0556,\n",
       "         0.0556],\n",
       "        [0.1111, 0.1111, 1.0000, 0.0556, 0.0000, 0.1111, 0.1111, 0.2222, 0.0556,\n",
       "         0.0556],\n",
       "        [0.1111, 0.1111, 1.0000, 0.0556, 0.0556, 0.0000, 0.1111, 0.2222, 0.0556,\n",
       "         0.0556],\n",
       "        [0.1111, 0.1111, 1.0000, 0.0556, 0.0556, 0.1111, 0.0000, 0.2222, 0.0556,\n",
       "         0.0556],\n",
       "        [0.1111, 0.1111, 1.0000, 0.0556, 0.0556, 0.1111, 0.1111, 0.0000, 0.0556,\n",
       "         0.0556],\n",
       "        [0.1111, 0.1111, 1.0000, 0.0556, 0.0556, 0.1111, 0.1111, 0.2222, 0.0000,\n",
       "         0.0556],\n",
       "        [0.1111, 0.1111, 1.0000, 0.0556, 0.0556, 0.1111, 0.1111, 0.2222, 0.0556,\n",
       "         0.0000]])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# K x K transition matrix from the previous cell: zero diagonal, column i equals 10/9 * cl_class_prior[i].\n",
    "trainsition_mat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d3856883",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import numpy as np\n",
    "import torch\n",
    "import os\n",
    "import pickle\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as dsets\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a098b3a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_clcifar(root, name, num_classes):\n",
    "    \"\"\"Load ``<root>/<name>/<name>.pkl`` for a CLCIFAR dataset.\n",
    "\n",
    "    The pickle is indexed by the keys \"images\", \"cl_labels\" and \"ord_labels\".\n",
    "\n",
    "    Returns:\n",
    "        (images, targets, ord_labels) where ``targets`` is an (N, num_classes)\n",
    "        multi-hot float array marking the first three complementary labels of\n",
    "        each sample and ``ord_labels`` is a NumPy array.\n",
    "    \"\"\"\n",
    "    dataset_path = os.path.join(root, name, f\"{name}.pkl\")\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.\n",
    "    with open(dataset_path, \"rb\") as f:\n",
    "        data = pickle.load(f)\n",
    "    cl_labels = data[\"cl_labels\"]\n",
    "    # One row per sample; len(data) would count the top-level keys of a dict instead.\n",
    "    targets = np.zeros((len(cl_labels), num_classes))\n",
    "    for i, labels in enumerate(cl_labels):\n",
    "        for j in range(3):  # first three complementary labels of each sample\n",
    "            targets[i, labels[j]] = 1\n",
    "    return data[\"images\"], targets, np.array(data[\"ord_labels\"])\n",
    "\n",
    "\n",
    "class CLCIFAR20(Dataset):\n",
    "    \"\"\"CLCIFAR20 training set with human-annotated complementary labels.\n",
    "\n",
    "    Loaded from ``<root>/clcifar20/clcifar20.pkl``. ``targets`` is an (N, 20)\n",
    "    multi-hot array marking the first three complementary labels of each sample.\n",
    "\n",
    "    Args:\n",
    "        root: directory containing the ``clcifar20`` folder\n",
    "        transform: optional feature transformation applied to each image\n",
    "    \"\"\"\n",
    "    def __init__(self, root=\"../dataset\", transform=None):\n",
    "        self.transform = transform\n",
    "        self.input_dim = 32 * 32 * 3\n",
    "        self.num_classes = 20\n",
    "        self.data, self.targets, self.ord_labels = _load_clcifar(root, 'clcifar20', self.num_classes)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``(image, targets[index])``.\n",
    "\n",
    "        NOTE(review): CLCIFAR10 also returns the ordinary label and the index;\n",
    "        confirm callers of this class expect only two items.\n",
    "        \"\"\"\n",
    "        image = self.data[index]\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "        return image, self.targets[index]\n",
    "\n",
    "\n",
    "class CLCIFAR10(Dataset):\n",
    "    \"\"\"CLCIFAR10 training set with human-annotated complementary labels.\n",
    "\n",
    "    The CIFAR10 training set (50000 samples), each sample with one ordinary\n",
    "    label and three complementary labels, loaded from\n",
    "    ``<root>/clcifar10/clcifar10.pkl``. ``targets`` is an (N, 10) multi-hot\n",
    "    array marking the first three complementary labels of each sample.\n",
    "\n",
    "    Args:\n",
    "        root: directory containing the ``clcifar10`` folder\n",
    "        transform: optional feature transformation applied to each image\n",
    "    \"\"\"\n",
    "    def __init__(self, root=\"../dataset\", transform=None):\n",
    "        self.transform = transform\n",
    "        self.input_dim = 32 * 32 * 3\n",
    "        self.num_classes = 10\n",
    "        self.data, self.targets, self.ord_labels = _load_clcifar(root, 'clcifar10', self.num_classes)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``(image, targets[index], ord_labels[index], index)``.\"\"\"\n",
    "        image = self.data[index]\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "        return image, self.targets[index], self.ord_labels[index], index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "52a04101",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_cv_datasets(dataname, batch_size, data_root='/home/wangw/dataset'):\n",
    "    \"\"\"Build complementary-label, ordinary-label and test loaders for CLCIFAR10/20.\n",
    "\n",
    "    Args:\n",
    "        dataname: 'clcifar10' or 'clcifar20'.\n",
    "        batch_size: batch size of the complementary, ordinary-train and test loaders.\n",
    "        data_root: directory holding the torchvision CIFAR10/CIFAR100 data; the\n",
    "            training split is downloaded there if missing.\n",
    "\n",
    "    Returns:\n",
    "        (full_train_loader, complementary_train_loader, train_loader, test_loader,\n",
    "        train_dataset, test_dataset, num_classes). full_train_loader yields the\n",
    "        whole complementary-label dataset as a single batch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if dataname is not supported.\n",
    "    \"\"\"\n",
    "    if dataname == 'clcifar10':\n",
    "        mean, std = (0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605)\n",
    "    elif dataname == 'clcifar20':\n",
    "        mean, std = (0.5068, 0.4854, 0.4402), (0.2672, 0.2563, 0.2760)\n",
    "    else:\n",
    "        raise ValueError(f\"unsupported dataname {dataname!r}; expected 'clcifar10' or 'clcifar20'\")\n",
    "\n",
    "    train_transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.RandomCrop(32, 4),\n",
    "        transforms.Normalize(mean, std)])\n",
    "    test_transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean, std)])\n",
    "\n",
    "    if dataname == 'clcifar10':\n",
    "        train_dataset = CLCIFAR10(transform=train_transform)\n",
    "        ordinary_train_dataset = dsets.CIFAR10(root=data_root, train=True, transform=train_transform, download=True)\n",
    "        test_dataset = dsets.CIFAR10(root=data_root, train=False, transform=test_transform)\n",
    "        num_classes = 10\n",
    "    else:\n",
    "        train_dataset = CLCIFAR20(transform=train_transform)\n",
    "        # NOTE(review): the clcifar10 branch uses train_transform here; confirm test_transform is intended.\n",
    "        ordinary_train_dataset = dsets.CIFAR100(root=data_root, train=True, transform=test_transform, download=True)\n",
    "        test_dataset = dsets.CIFAR100(root=data_root, train=False, transform=test_transform)\n",
    "        # CIFAR-100 fine label -> 20-class label. Applied to BOTH splits so every\n",
    "        # loader agrees with num_classes = 20.\n",
    "        cifar100_to_cifar20 = {0: 4, 1: 1, 2: 14, 3: 8, 4: 0, 5: 6, 6: 7, 7: 7, 8: 18, 9: 3, 10: 3, 11: 14, 12: 9, 13: 18, 14: 7, 15: 11, 16: 3, 17: 9, 18: 7, 19: 11, 20: 6, 21: 11, 22: 5, 23: 10, 24: 7, 25: 6, 26: 13, 27: 15, 28: 3, 29: 15, 30: 0, 31: 11, 32: 1, 33: 10, 34: 12, 35: 14, 36: 16, 37: 9, 38: 11, 39: 5, 40: 5, 41: 19, 42: 8, 43: 8, 44: 15, 45: 13, 46: 14, 47: 17, 48: 18, 49: 10, 50: 16, 51: 4, 52: 17, 53: 4, 54: 2, 55: 0, 56: 17, 57: 4, 58: 18, 59: 17, 60: 10, 61: 3, 62: 2, 63: 12, 64: 12, 65: 16, 66: 12, 67: 1, 68: 9, 69: 19, 70: 2, 71: 10, 72: 0, 73: 1, 74: 16, 75: 12, 76: 9, 77: 13, 78: 15, 79: 13, 80: 16, 81: 18, 82: 2, 83: 4, 84: 6, 85: 19, 86: 5, 87: 5, 88: 8, 89: 19, 90: 18, 91: 1, 92: 2, 93: 15, 94: 6, 95: 0, 96: 17, 97: 8, 98: 14, 99: 13}\n",
    "        ordinary_train_dataset.targets = [cifar100_to_cifar20[t] for t in ordinary_train_dataset.targets]\n",
    "        test_dataset.targets = [cifar100_to_cifar20[t] for t in test_dataset.targets]\n",
    "        num_classes = 20\n",
    "\n",
    "    complementary_train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
    "    train_loader = torch.utils.data.DataLoader(dataset=ordinary_train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
    "    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "    # A single batch containing the entire complementary-label training set.\n",
    "    full_train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=len(train_dataset.data), shuffle=True, num_workers=0)\n",
    "\n",
    "    return full_train_loader, complementary_train_loader, train_loader, test_loader, train_dataset, test_dataset, num_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba6b8753",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataname, batch_size = 'clcifar10', 256\n",
    "# Keep the loaders/datasets as named variables instead of only displaying the tuple repr.\n",
    "(full_train_loader, complementary_train_loader, train_loader, test_loader,\n",
    " train_dataset, test_dataset, num_classes) = prepare_cv_datasets(dataname, batch_size)\n",
    "num_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a89820b3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a1f049",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conu",
   "language": "python",
   "name": "conu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
