{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "112016f2-a889-414a-bd44-35535330b875",
   "metadata": {},
   "source": [
    "# CIFAR-10 Dataset\n",
    "\n",
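    "Each of the 10 data owners owns one class of the CIFAR-10 training set. Their privacy budgets are imbalanced: owner $k = 1, \\ldots, 10$ gets $\\epsilon_k = k/2$, i.e. $\\epsilon \\in \\{0.5, 1, 1.5, \\ldots, 5\\}$, and every owner uses $\\delta = 10^{-5}$."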
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e231f95a-1745-43be-8c86-58b201565a4b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Use the first GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\", index=0) if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d26cb442-29d2-4b4d-b8a4-4c80e71010a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f3330756150>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed torch's global RNG up front so later cells start from a reproducible state.\n",
    "GLOBAL_SEED = 122\n",
    "torch.random.manual_seed(GLOBAL_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "386f73d1-4bdb-48a3-a07a-ed65000d06c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "from utils.datasets.cifar import load_cifar\n",
    "\n",
    "# Train and test splits, loaded in that order.\n",
    "train_set, test_set = load_cifar('train'), load_cifar('test')\n",
    "n_data = len(train_set)\n",
    "\n",
    "# One data owner per CIFAR-10 class: class c is owned by owner c.\n",
    "n_owners = 10\n",
    "class_to_owner_dict = {c: c for c in range(10)}\n",
    "\n",
    "# Per-owner dataset size; 5000 assumes the full CIFAR-10 training split (5000 images per class).\n",
    "dataset_sizes = torch.full((n_owners,), 5000)\n",
    "\n",
    "# Imbalanced budgets: owner k (0-based) gets epsilon = (k + 1) / 2, i.e. 0.5, 1.0, ..., 5.0; one shared delta.\n",
    "epsilons = 0.5 * torch.arange(1, n_owners + 1)\n",
    "deltas = torch.full_like(epsilons, 1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ea6a7a27-d3d0-4f49-a57d-608e669b2ede",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6, 9, 9,  ..., 9, 1, 1], dtype=torch.int32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from utils.colab import assign_data_owners_by_class\n",
    "\n",
    "# Column 0 of train_set.targets is used as the per-example label (presumably the class id - confirm).\n",
    "train_labels = train_set.targets[:, 0]\n",
    "assignment = assign_data_owners_by_class(train_labels, class_to_owner_dict)\n",
    "assignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "47c47dfa-19a3-411a-834a-81fd11452178",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regenerates the IPP sampling plan that the next cell loads from disk (same path as below).\n",
    "# Uncomment to recompute it: 3000 iterations (per the saved file name), batch_size=1024, clipping_threshold=2.\n",
    "# The training cell uses batch_size=1024 as well, so keep the two in sync if this is changed.\n",
    "#from ipp.planner import IPPPlanner\n",
    "#from utils.file import save\n",
    "\n",
    "#ipp_planner = IPPPlanner(3000, dataset_sizes, epsilons, deltas, assignment)\n",
    "#ipp = ipp_planner.plan_sampling(batch_size=1024, clipping_threshold=2)\n",
    "#save(ipp, \"saved_data/cifar_sample_ipp_3000_iterations_2_clipping\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "12a74f08-2546-4f40-9a43-85e2c1585842",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.file import load\n",
    "\n",
    "# Precomputed IPP, produced by the (commented-out) planner cell above.\n",
    "IPP_PATH = \"saved_data/cifar_sample_ipp_3000_iterations_2_clipping\"\n",
    "ipp = load(IPP_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "711913b3-d5cc-4f99-9c95-ad4675a2990e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1,\tloss: 2.0627,\taccuracy: 23.95\n",
      "epoch: 2,\tloss: 1.8195,\taccuracy: 26.42\n",
      "epoch: 3,\tloss: 1.7528,\taccuracy: 28.31\n",
      "epoch: 4,\tloss: 1.6982,\taccuracy: 30.78\n",
      "epoch: 5,\tloss: 1.6553,\taccuracy: 32.55\n",
      "epoch: 6,\tloss: 1.6053,\taccuracy: 34.01\n",
      "epoch: 7,\tloss: 1.5324,\taccuracy: 35.13\n",
      "epoch: 8,\tloss: 1.486,\taccuracy: 36.49\n",
      "epoch: 9,\tloss: 1.4237,\taccuracy: 38.06\n",
      "epoch: 10,\tloss: 1.3969,\taccuracy: 38.51\n",
      "epoch: 11,\tloss: 1.3638,\taccuracy: 39.16\n",
      "epoch: 12,\tloss: 1.352,\taccuracy: 39.98\n",
      "epoch: 13,\tloss: 1.3327,\taccuracy: 40.62\n",
      "epoch: 14,\tloss: 1.3146,\taccuracy: 41.53\n",
      "epoch: 15,\tloss: 1.3082,\taccuracy: 41.84\n",
      "epoch: 16,\tloss: 1.299,\taccuracy: 42.36\n",
      "epoch: 17,\tloss: 1.2997,\taccuracy: 43.09\n",
      "epoch: 18,\tloss: 1.2999,\taccuracy: 42.75\n",
      "epoch: 19,\tloss: 1.2994,\taccuracy: 43.49\n",
      "epoch: 20,\tloss: 1.3119,\taccuracy: 43.83\n",
      "epoch: 21,\tloss: 1.3163,\taccuracy: 44.25\n",
      "epoch: 22,\tloss: 1.3211,\taccuracy: 43.88\n",
      "epoch: 23,\tloss: 1.3328,\taccuracy: 44.29\n",
      "epoch: 24,\tloss: 1.3337,\taccuracy: 45.0\n",
      "epoch: 25,\tloss: 1.3439,\taccuracy: 44.83\n",
      "epoch: 26,\tloss: 1.3316,\taccuracy: 45.02\n",
      "epoch: 27,\tloss: 1.343,\taccuracy: 45.72\n",
      "epoch: 28,\tloss: 1.3421,\taccuracy: 46.02\n",
      "epoch: 29,\tloss: 1.3462,\taccuracy: 45.98\n",
      "epoch: 30,\tloss: 1.3408,\taccuracy: 45.82\n",
      "epoch: 31,\tloss: 1.344,\taccuracy: 45.83\n",
      "epoch: 32,\tloss: 1.3232,\taccuracy: 46.88\n",
      "epoch: 33,\tloss: 1.3253,\taccuracy: 47.15\n",
      "epoch: 34,\tloss: 1.3179,\taccuracy: 46.93\n",
      "epoch: 35,\tloss: 1.3184,\taccuracy: 47.0\n",
      "epoch: 36,\tloss: 1.3033,\taccuracy: 46.95\n",
      "epoch: 37,\tloss: 1.3215,\taccuracy: 47.18\n",
      "epoch: 38,\tloss: 1.3085,\taccuracy: 47.72\n",
      "epoch: 39,\tloss: 1.3086,\taccuracy: 48.12\n",
      "epoch: 40,\tloss: 1.2877,\taccuracy: 47.86\n",
      "epoch: 41,\tloss: 1.2918,\taccuracy: 48.33\n",
      "epoch: 42,\tloss: 1.2974,\taccuracy: 48.84\n",
      "epoch: 43,\tloss: 1.2991,\taccuracy: 48.35\n",
      "epoch: 44,\tloss: 1.2955,\taccuracy: 48.04\n",
      "epoch: 45,\tloss: 1.2823,\taccuracy: 47.9\n",
      "epoch: 46,\tloss: 1.281,\taccuracy: 48.73\n",
      "epoch: 47,\tloss: 1.2702,\taccuracy: 48.75\n",
      "epoch: 48,\tloss: 1.2712,\taccuracy: 48.64\n",
      "epoch: 49,\tloss: 1.2616,\taccuracy: 49.47\n",
      "epoch: 50,\tloss: 1.2511,\taccuracy: 48.93\n",
      "epoch: 51,\tloss: 1.2532,\taccuracy: 49.39\n",
      "epoch: 52,\tloss: 1.2458,\taccuracy: 49.78\n",
      "epoch: 53,\tloss: 1.2466,\taccuracy: 50.15\n",
      "epoch: 54,\tloss: 1.2353,\taccuracy: 50.0\n",
      "epoch: 55,\tloss: 1.2673,\taccuracy: 49.48\n",
      "epoch: 56,\tloss: 1.2315,\taccuracy: 49.99\n",
      "epoch: 57,\tloss: 1.2565,\taccuracy: 50.08\n",
      "epoch: 58,\tloss: 1.2597,\taccuracy: 50.64\n",
      "epoch: 59,\tloss: 1.2653,\taccuracy: 50.48\n",
      "epoch: 60,\tloss: 1.2507,\taccuracy: 50.29\n",
      "epoch: 61,\tloss: 1.2408,\taccuracy: 49.96\n",
      "epoch: 62,\tloss: 1.2484,\taccuracy: 50.8\n",
      "epoch: 63,\tloss: 1.2336,\taccuracy: 50.52\n",
      "epoch: 1,\tloss: 2.1111,\taccuracy: 24.08\n",
      "epoch: 2,\tloss: 1.8534,\taccuracy: 27.24\n",
      "epoch: 3,\tloss: 1.7845,\taccuracy: 28.77\n",
      "epoch: 4,\tloss: 1.7304,\taccuracy: 31.25\n",
      "epoch: 5,\tloss: 1.6884,\taccuracy: 32.78\n",
      "epoch: 6,\tloss: 1.642,\taccuracy: 34.53\n",
      "epoch: 7,\tloss: 1.5749,\taccuracy: 35.71\n",
      "epoch: 8,\tloss: 1.5284,\taccuracy: 36.78\n",
      "epoch: 9,\tloss: 1.4658,\taccuracy: 38.46\n",
      "epoch: 10,\tloss: 1.435,\taccuracy: 38.97\n",
      "epoch: 11,\tloss: 1.3925,\taccuracy: 39.86\n",
      "epoch: 12,\tloss: 1.3657,\taccuracy: 40.76\n",
      "epoch: 13,\tloss: 1.3377,\taccuracy: 41.38\n",
      "epoch: 14,\tloss: 1.3082,\taccuracy: 42.27\n",
      "epoch: 15,\tloss: 1.2883,\taccuracy: 42.98\n",
      "epoch: 16,\tloss: 1.2706,\taccuracy: 43.45\n",
      "epoch: 17,\tloss: 1.2571,\taccuracy: 44.16\n",
      "epoch: 18,\tloss: 1.2389,\taccuracy: 44.19\n",
      "epoch: 19,\tloss: 1.2299,\taccuracy: 44.76\n",
      "epoch: 20,\tloss: 1.2202,\taccuracy: 45.23\n",
      "epoch: 21,\tloss: 1.2026,\taccuracy: 45.65\n",
      "epoch: 22,\tloss: 1.1941,\taccuracy: 46.15\n",
      "epoch: 23,\tloss: 1.1924,\taccuracy: 45.94\n",
      "epoch: 24,\tloss: 1.182,\taccuracy: 46.78\n",
      "epoch: 25,\tloss: 1.1747,\taccuracy: 47.01\n",
      "epoch: 26,\tloss: 1.1594,\taccuracy: 47.23\n",
      "epoch: 27,\tloss: 1.1601,\taccuracy: 47.22\n",
      "epoch: 28,\tloss: 1.157,\taccuracy: 47.75\n",
      "epoch: 29,\tloss: 1.1594,\taccuracy: 48.2\n",
      "epoch: 30,\tloss: 1.1497,\taccuracy: 47.62\n",
      "epoch: 31,\tloss: 1.1522,\taccuracy: 47.9\n",
      "epoch: 32,\tloss: 1.1375,\taccuracy: 48.6\n",
      "epoch: 33,\tloss: 1.14,\taccuracy: 48.85\n",
      "epoch: 34,\tloss: 1.1314,\taccuracy: 48.53\n",
      "epoch: 35,\tloss: 1.1384,\taccuracy: 48.7\n",
      "epoch: 36,\tloss: 1.1254,\taccuracy: 48.51\n",
      "epoch: 37,\tloss: 1.148,\taccuracy: 48.63\n",
      "epoch: 38,\tloss: 1.1367,\taccuracy: 49.45\n",
      "epoch: 39,\tloss: 1.1362,\taccuracy: 49.3\n",
      "epoch: 40,\tloss: 1.1254,\taccuracy: 49.07\n",
      "epoch: 41,\tloss: 1.1397,\taccuracy: 49.24\n",
      "epoch: 42,\tloss: 1.1449,\taccuracy: 49.99\n",
      "epoch: 43,\tloss: 1.1556,\taccuracy: 49.54\n",
      "epoch: 44,\tloss: 1.1559,\taccuracy: 48.94\n",
      "epoch: 45,\tloss: 1.1506,\taccuracy: 48.53\n",
      "epoch: 46,\tloss: 1.1544,\taccuracy: 49.17\n",
      "epoch: 47,\tloss: 1.1485,\taccuracy: 49.35\n",
      "epoch: 48,\tloss: 1.1569,\taccuracy: 49.71\n",
      "epoch: 49,\tloss: 1.1562,\taccuracy: 50.12\n",
      "epoch: 50,\tloss: 1.1498,\taccuracy: 49.8\n",
      "epoch: 51,\tloss: 1.1626,\taccuracy: 50.0\n",
      "epoch: 52,\tloss: 1.1596,\taccuracy: 50.09\n",
      "epoch: 53,\tloss: 1.1631,\taccuracy: 50.25\n",
      "epoch: 54,\tloss: 1.1575,\taccuracy: 50.48\n",
      "epoch: 55,\tloss: 1.1894,\taccuracy: 49.91\n",
      "epoch: 56,\tloss: 1.1502,\taccuracy: 50.33\n",
      "epoch: 57,\tloss: 1.1852,\taccuracy: 50.37\n",
      "epoch: 58,\tloss: 1.1903,\taccuracy: 51.53\n",
      "epoch: 59,\tloss: 1.1966,\taccuracy: 50.66\n",
      "epoch: 60,\tloss: 1.1855,\taccuracy: 50.89\n",
      "epoch: 61,\tloss: 1.1692,\taccuracy: 50.98\n",
      "epoch: 62,\tloss: 1.182,\taccuracy: 51.03\n",
      "epoch: 63,\tloss: 1.176,\taccuracy: 51.4\n",
      "epoch: 1,\tloss: 2.1997,\taccuracy: 24.98\n",
      "epoch: 2,\tloss: 1.992,\taccuracy: 27.47\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import pickle\n",
    "import gc\n",
    "import torch\n",
    "import torchvision\n",
    "from torch import optim\n",
    "from torch.utils.data import DataLoader\n",
    "from ipp import IPrivacyEngine\n",
    "from ipp.train import ipp_train\n",
    "from utils.models import CifarCNN\n",
    "from utils.file import save\n",
    "\n",
    "LEARNING_RATE = 0.15\n",
    "MOMENTUM = 0.5\n",
    "\n",
    "# Settings of the precomputed IPP loaded above (planned with 3000 iterations, batch size 1024,\n",
    "# clipping threshold 2). n_iteration is only used here to name the result files.\n",
    "n_iteration = 3000\n",
    "batch_size = 1024\n",
    "# NOTE(review): clipping_threshold is not passed to make_private or ipp_train below, so it has\n",
    "# no effect on training; the IPP was planned with clipping_threshold=2. Confirm the intended value.\n",
    "clipping_threshold = 0.5\n",
    "\n",
    "SEEDS = [122]\n",
    "\n",
    "# Each (alpha, beta, scale) triple is forwarded to ipp_train and yields one training run.\n",
    "betas = [(4, 4, 800), (6, 4, 800)]\n",
    "\n",
    "for seed in SEEDS:\n",
    "    # Unpack each triple directly; enumerate() here would yield (index, triple) pairs.\n",
    "    for alpha, beta, scale in betas:\n",
    "        # Header so the concatenated per-epoch logs of consecutive runs can be told apart.\n",
    "        print(f\"run: alpha={alpha}, beta={beta}, scale={scale}, seed={seed}\")\n",
    "        # Reseed before building loaders/model so every configuration starts from the same RNG state.\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "        data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False)\n",
    "        test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "        model = CifarCNN().to(device)\n",
    "        sgd_optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)\n",
    "\n",
    "        i_privacy_engine = IPrivacyEngine()\n",
    "        ipp_data_loader, ipp_model, ipp_optimizer = i_privacy_engine.make_private(\n",
    "            data_loader=data_loader, model=model, optimizer=sgd_optimizer, ipp=ipp\n",
    "        )\n",
    "\n",
    "        result = ipp_train(\n",
    "            ipp_data_loader, ipp_model, ipp_optimizer, test_loader, ipp,\n",
    "            device=device, adaptive_threshold=0, alpha=alpha, beta=beta, scale=scale,\n",
    "        )\n",
    "\n",
    "        # Encode every run setting in the file name so runs do not overwrite each other.\n",
    "        run_tag = f\"{alpha}_{beta}_{scale}_alpha_beta_scale\"\n",
    "        save(result, f\"saved_data/cifar_sample_ino_{n_iteration}_iterations_2_clipping_{LEARNING_RATE}_lr_{run_tag}_{seed}_seed.pkl\")\n",
    "\n",
    "        # Drop per-run objects and cached GPU memory before the next configuration.\n",
    "        del data_loader, test_loader, model, sgd_optimizer, i_privacy_engine, ipp_data_loader, ipp_model, ipp_optimizer\n",
    "        gc.collect()\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b1b6a6a-3076-4a6c-bf8f-f1c61c21aef0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b784334-4b48-4e13-8bc8-f2ebd2d5a159",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "py310"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
