{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1f30e34f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from opacus import PrivacyEngine\n",
    "from torchvision import datasets, transforms\n",
    "from tqdm import tqdm\n",
    "import torchvision\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "779fb79a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class LinNet(nn.Module):\n",
    "    \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(784, 300)\n",
    "        self.fc2 = nn.Linear(300, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "\n",
    "        \n",
    "        x = torch.flatten(x, 1) # flatten all dimensions except batch\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "    \n",
    "    \n",
    "def _tracked_grad_norms(model, device, loader, optimizer, max_grad_norm):\n",
    "    \"\"\"Return the clipped per-sample gradient norm of every sample in `loader`.\n",
    "\n",
    "    Each batch gets a forward/backward pass so the Opacus-wrapped model fills\n",
    "    `p.grad_sample` (per-sample gradients) for its parameters. The per-parameter\n",
    "    norms are combined into one norm per sample, sqrt(sum_p ||grad_sample_p||^2),\n",
    "    and clamped at `max_grad_norm`. No optimizer step is taken, so the weights are\n",
    "    unchanged; gradients are zeroed again before returning.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor on `device` with one entry per sample, in loader order.\n",
    "    \"\"\"\n",
    "    batch_norms = []\n",
    "    for data, target in loader:\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        target = F.one_hot(target, num_classes=10)\n",
    "        optimizer.zero_grad()\n",
    "        loss = F.mse_loss(model(data), target.float())\n",
    "        loss.backward()\n",
    "\n",
    "        sq_norms = torch.zeros(len(target), device=device)\n",
    "        for p in model.parameters():\n",
    "            per_sample_grad = getattr(p, 'grad_sample', None)\n",
    "            if per_sample_grad is None:  # e.g. a frozen parameter: nothing to add\n",
    "                continue\n",
    "            # Reduce over every non-batch dimension (bias and weight shapes differ).\n",
    "            dims = list(range(1, per_sample_grad.dim()))\n",
    "            sq_norms += per_sample_grad.norm(dim=dims) ** 2\n",
    "        batch_norms.append(torch.sqrt(sq_norms).clamp(max=max_grad_norm))\n",
    "    optimizer.zero_grad()\n",
    "    if not batch_norms:\n",
    "        return torch.zeros(0, device=device)\n",
    "    return torch.cat(batch_norms)\n",
    "\n",
    "\n",
    "def train(delta, model, device, train_loader, train_loader_0, optimizer, privacy_engine, epoch, max_grad_norm=None):\n",
    "    \"\"\"Run one DP-SGD epoch and track the gradient norms of a fixed subset of samples.\n",
    "\n",
    "    After every optimizer step on a batch of `train_loader`, the clipped per-sample\n",
    "    gradient norm of every sample in `train_loader_0` is recomputed with the current\n",
    "    weights and stored as one column of the returned table.\n",
    "\n",
    "    Args:\n",
    "        delta: delta passed to the privacy accountant for the printed epsilon.\n",
    "        model: Opacus-wrapped model (populates `p.grad_sample` on backward).\n",
    "        device: device the batches and the returned table are placed on.\n",
    "        train_loader: loader driving the DP-SGD updates.\n",
    "        train_loader_0: loader over the tracked samples, iterated in order each step.\n",
    "        optimizer: Opacus DP optimizer; clipping and noise are applied in its step().\n",
    "        privacy_engine: engine whose accountant reports the privacy spent so far.\n",
    "        epoch: epoch number, used only in the printed summary.\n",
    "        max_grad_norm: clamp applied to the tracked norms. Defaults to the optimizer's\n",
    "            `max_grad_norm` (its clipping bound) when present, else 5.0.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (len(train_loader_0.dataset), len(train_loader)); entry [i, s]\n",
    "        is the clipped gradient norm of tracked sample i after step s.\n",
    "    \"\"\"\n",
    "    if max_grad_norm is None:\n",
    "        # Keep the tracked clamp consistent with the bound DP-SGD actually clips to.\n",
    "        max_grad_norm = getattr(optimizer, 'max_grad_norm', 5.0)\n",
    "\n",
    "    model.train()\n",
    "    losses = []\n",
    "    # Sized from the loaders instead of a hard-coded (1200, 200).\n",
    "    grad_norms = torch.zeros(len(train_loader_0.dataset), len(train_loader), device=device)\n",
    "\n",
    "    for step, (data, target) in enumerate(tqdm(train_loader)):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        target = F.one_hot(target, num_classes=10)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.mse_loss(output, target.float())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        losses.append(loss.item())\n",
    "\n",
    "        grad_norms[:, step] = _tracked_grad_norms(model, device, train_loader_0, optimizer, max_grad_norm)\n",
    "\n",
    "    epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent(\n",
    "        delta=delta\n",
    "    )\n",
    "    print(\n",
    "        f\"Train Epoch: {epoch} \\t\"\n",
    "        f\"Loss: {np.mean(losses):.6f} \"\n",
    "        f\"(ε = {epsilon:.2f}, δ = {delta}) for α = {best_alpha}\"\n",
    "    )\n",
    "    return grad_norms\n",
    "\n",
    "def test(model, device, test_loader):\n",
    "    \"\"\"Evaluate `model` on `test_loader`, print a summary and return the accuracy.\n",
    "\n",
    "    The reported loss is the cross-entropy of the model outputs, summed over all\n",
    "    samples and divided by the dataset size, i.e. a true per-sample average.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of correctly classified samples (correct / len(test_loader.dataset)).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    # reduction='sum' so that dividing by the dataset size below yields a per-sample mean;\n",
    "    # with the default 'mean' reduction the printed loss was too small by ~the batch size.\n",
    "    criterion = nn.CrossEntropyLoss(reduction='sum')\n",
    "    test_loss = 0.0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in tqdm(test_loader):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += criterion(output, target).item()  # sum of per-sample losses in this batch\n",
    "            pred = output.argmax(dim=1)  # predicted class = index of the largest output\n",
    "            correct += (pred == target).sum().item()\n",
    "\n",
    "    n_samples = len(test_loader.dataset)\n",
    "    test_loss /= n_samples\n",
    "\n",
    "    print(\n",
    "        \"\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n\".format(\n",
    "            test_loss,\n",
    "            correct,\n",
    "            n_samples,\n",
    "            100.0 * correct / n_samples,\n",
    "        )\n",
    "    )\n",
    "    return correct / n_samples\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "e5090a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "import opacus\n",
    "from opacus.data_loader import DPDataLoader\n",
    "\n",
    "# ---- Configuration ----\n",
    "batch_size=300          # training and tracked-subset batch size\n",
    "test_batch_size=250     # evaluation batch size\n",
    "epochs=50\n",
    "lr=0.01\n",
    "sigma=2.0               # DP noise multiplier\n",
    "c=5.0                   # per-sample gradient clipping bound\n",
    "max_per_sample_grad_norm=c\n",
    "delta=1e-5\n",
    "n_tracked=1200          # number of training samples whose gradient norms are recorded\n",
    "seed=0\n",
    "device='cuda'\n",
    "data_root='./'\n",
    "secure_rng=False\n",
    "\n",
    "# Seed torch's global RNG so model initialisation (and, with secure_rng=False, the\n",
    "# Opacus sampling/noise that draw from it) is reproducible across runs.\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "trainset = torchvision.datasets.MNIST(root=data_root, train=True,\n",
    "                                      download=True, transform=transform)\n",
    "testset = torchvision.datasets.MNIST(root=data_root, train=False,\n",
    "                                     download=True, transform=transform)\n",
    "\n",
    "# Select n_tracked random training samples whose gradient norms are computed during training;\n",
    "# a dedicated seeded generator keeps the subset fixed regardless of other RNG use.\n",
    "X1, X1rest = torch.utils.data.random_split(\n",
    "    trainset, [n_tracked, len(trainset) - n_tracked],\n",
    "    generator=torch.Generator().manual_seed(seed),\n",
    ")\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
    "                                           num_workers=2)\n",
    "\n",
    "train_loader_X1 = torch.utils.data.DataLoader(X1, batch_size=batch_size, shuffle=False,\n",
    "                                              num_workers=2)\n",
    "\n",
    "# Uses test_batch_size (previously defined but ignored in favour of batch_size).\n",
    "test_loader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,\n",
    "                                          shuffle=False, num_workers=2)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "869b693f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "running with learning rate 0.01\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:08<00:00,  2.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 1 \tLoss: 0.090364 (ε = 0.20, δ = 1e-05) for α = 41.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 40.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Average loss: 0.0076, Accuracy: 4124/10000 (41.24%)\n",
      "\n",
      "Test Accuracy : 41.24%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:08<00:00,  2.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 2 \tLoss: 0.077355 (ε = 0.23, δ = 1e-05) for α = 41.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 39.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Average loss: 0.0074, Accuracy: 5963/10000 (59.63%)\n",
      "\n",
      "Test Accuracy : 59.63%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:09<00:00,  2.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 3 \tLoss: 0.069562 (ε = 0.26, δ = 1e-05) for α = 41.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 39.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Average loss: 0.0073, Accuracy: 6761/10000 (67.61%)\n",
      "\n",
      "Test Accuracy : 67.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:09<00:00,  2.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 4 \tLoss: 0.063828 (ε = 0.29, δ = 1e-05) for α = 41.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 40.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Average loss: 0.0071, Accuracy: 7213/10000 (72.13%)\n",
      "\n",
      "Test Accuracy : 72.13%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:08<00:00,  2.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 5 \tLoss: 0.059405 (ε = 0.33, δ = 1e-05) for α = 41.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:00<00:00, 39.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Average loss: 0.0070, Accuracy: 7479/10000 (74.79%)\n",
      "\n",
      "Test Accuracy : 74.79%\n",
      "CPU times: user 1min 15s, sys: 2min 45s, total: 4min 1s\n",
      "Wall time: 5min 49s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "%%time\n",
    "# Train LinNet with DP-SGD; after every step, record the clipped per-sample gradient\n",
    "# norms of the tracked subset X1 (one column per training step).\n",
    "model1 = LinNet().to(device)\n",
    "\n",
    "print('running with learning rate ' + str(lr))\n",
    "\n",
    "optimizer = optim.SGD(model1.parameters(), lr=lr, momentum=0)\n",
    "\n",
    "privacy_engine = PrivacyEngine(secure_mode=secure_rng)\n",
    "# Bind the private loader to a new name so the plain `train_loader` from the data cell\n",
    "# is never overwritten and re-running this cell always starts from it.\n",
    "model1, optimizer, dp_train_loader = privacy_engine.make_private(\n",
    "    module=model1,\n",
    "    optimizer=optimizer,\n",
    "    data_loader=train_loader,\n",
    "    noise_multiplier=sigma,\n",
    "    max_grad_norm=max_per_sample_grad_norm,\n",
    ")\n",
    "\n",
    "# Table shape derived from the loaders instead of hard-coded 1200 x 200 per epoch.\n",
    "n_tracked_samples = len(train_loader_X1.dataset)\n",
    "steps_per_epoch = len(dp_train_loader)\n",
    "grad_norms_table = torch.zeros(n_tracked_samples, steps_per_epoch * epochs)\n",
    "\n",
    "for epoch in range(1, epochs + 1):\n",
    "    grad_norms = train(delta, model1, device, dp_train_loader, train_loader_X1, optimizer, privacy_engine, epoch)\n",
    "    grad_norms_table[:, (epoch - 1) * steps_per_epoch:epoch * steps_per_epoch] = grad_norms\n",
    "    test_acc = test(model1, device, test_loader)\n",
    "    print(\"Test Accuracy : {:.2f}%\".format(test_acc * 100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "57d9137f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transpose to (training steps, tracked samples) and expose as a NumPy array.\n",
    "grad_norms_table_np = grad_norms_table.transpose(0, 1).numpy()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "db6cf37e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# File name encodes the run's hyper-parameters: lr, epochs, sigma, c.\n",
    "np.save(f\"./grad_table_{lr}_{epochs}_{sigma}_{c}.npy\", grad_norms_table_np)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "94834f9b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.89231056 0.7722873  1.054447   ... 1.0343367  0.97184885 0.69236827]\n",
      " [0.89417565 0.7745887  1.0530472  ... 1.0355823  0.96981794 0.6905627 ]\n",
      " [0.89261436 0.77178824 1.0452148  ... 1.0247855  0.9648452  0.69119495]\n",
      " ...\n",
      " [0.8350607  0.7679998  0.8859045  ... 0.9641356  0.9829359  0.6373617 ]\n",
      " [0.8320443  0.77097017 0.8908278  ... 0.964905   0.9821216  0.6402552 ]\n",
      " [0.8282666  0.7731515  0.8841989  ... 0.9589262  0.9944454  0.6337797 ]]\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1647f08",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
