{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8dca029-e956-44c6-bd8c-ac02733f2bf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import multiprocessing\n",
    "import os\n",
    "import pprint\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as T\n",
    "import torch.optim as optim\n",
    "from torch.optim import Adam, AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.metrics import (accuracy_score, precision_score, recall_score,\n",
    "                             f1_score, roc_auc_score, matthews_corrcoef,\n",
    "                             confusion_matrix)\n",
    "from sklearn.preprocessing import label_binarize\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from diffprivlib.mechanisms import Gaussian, GaussianAnalytic\n",
    "import torchattacks\n",
    "from torchattacks import FGSM, PGD"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae8aa11a-f9ec-4a72-8059-4a512e217436",
   "metadata": {},
   "source": [
    "# Use GPU or CPU depending on availability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "346f861f-5163-4c6f-8f31-6eef6ca83050",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a1bd49f-d853-43f8-833b-6960666546e9",
   "metadata": {},
   "source": [
    "## ResNet-9 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cf5ef2c-c67e-4fd4-92ea-5bd2c5bb6ab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspired by \"Deep Residual Learning for Image Recognition\": https://github.com/Moddy2024/ResNet-9\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.3):\n",
    "        super(BasicBlock, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        self.dropout = nn.Dropout(dropout_rate)\n",
    "        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "\n",
    "        self.shortcut = nn.Sequential()\n",
    "        if stride != 1 or in_channels != out_channels:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_channels)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.dropout(out)\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        out += self.shortcut(x)\n",
    "        return F.relu(out)\n",
    "\n",
    "class ResNet9(nn.Module):\n",
    "    def __init__(self, num_classes=10):\n",
    "        super(ResNet9, self).__init__()\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=False),  # For grayscale inputs\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.ReLU(inplace=True)\n",
    "        )\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(2, 2)\n",
    "        )\n",
    "        self.res1 = BasicBlock(64, 64)\n",
    "\n",
    "        self.conv3 = nn.Sequential(\n",
    "            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(2, 2)\n",
    "        )\n",
    "        self.res2 = BasicBlock(128, 128)\n",
    "\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.MaxPool2d(4),\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(128, num_classes)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.res1(x)\n",
    "        x = self.conv3(x)\n",
    "        x = self.res2(x)\n",
    "        x = self.classifier(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62c82944-3891-4df4-97c0-934a759e365e",
   "metadata": {},
   "source": [
    "## ViT Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4404fd9-18da-48c9-9b9c-8c1964473a84",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "    Inspired by \"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale\"\n",
    "    https://github.com/google-research/vision_transformer\n",
    "    https://colab.research.google.com/drive/1rabTm93y39FNbu-21tDhlvYh2gp8edVR?usp=sharing\n",
    "'''\n",
    "\n",
    "# Patch Embeddings\n",
    "class PatchEmbedding(nn.Module):\n",
    "    def __init__(self, d_model, img_size, patch_size, n_channels):\n",
    "        super().__init__()\n",
    "        self.d_model = d_model # Dimensionality of Model\n",
    "        self.img_size = img_size # Image Size\n",
    "        self.patch_size = patch_size # Patch Size\n",
    "        self.n_channels = n_channels # Number of Channels\n",
    "        self.linear_project = nn.Conv2d(self.n_channels, self.d_model, kernel_size=self.patch_size, stride=self.patch_size)\n",
    "\n",
    "    # B: Batch Size\n",
    "    # C: Image Channels\n",
    "    # H: Image Height\n",
    "    # W: Image Width\n",
    "    # P_col: Patch Column\n",
    "    # P_row: Patch Row\n",
    "    def forward(self, x):\n",
    "        x = self.linear_project(x) # (B, C, H, W) -> (B, d_model, P_col, P_row)\n",
    "        x = x.flatten(2) # (B, d_model, P_col, P_row) -> (B, d_model, P)\n",
    "        x = x.transpose(-2, -1) # (B, d_model, P) -> (B, P, d_model)\n",
    "        return x\n",
    "\n",
    "\n",
    "# Class Token and Positional Encoding\n",
    "class PositionalEncoding(nn.Module):\n",
    "    def __init__(self, d_model, max_seq_length):\n",
    "        super().__init__()\n",
    "        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))  # Classification Token\n",
    "\n",
    "        # Fix: add +1 for CLS token\n",
    "        pe = torch.zeros(max_seq_length + 1, d_model)\n",
    "        for pos in range(max_seq_length + 1):\n",
    "            for i in range(0, d_model, 2):\n",
    "                pe[pos, i] = np.sin(pos / (10000 ** (i / d_model)))\n",
    "                if i + 1 < d_model:\n",
    "                    pe[pos, i + 1] = np.cos(pos / (10000 ** ((i + 1) / d_model)))\n",
    "\n",
    "        self.register_buffer('pe', pe.unsqueeze(0))  # Shape: (1, max_seq_length+1, d_model)\n",
    "\n",
    "    def forward(self, x):\n",
    "        tokens_batch = self.cls_token.expand(x.size(0), -1, -1)  # Shape: (B, 1, D)\n",
    "        x = torch.cat((tokens_batch, x), dim=1)  # Now (B, S+1, D)\n",
    "        x = x + self.pe[:, :x.size(1)]  # Match positional encoding length\n",
    "        return x\n",
    "\n",
    "\n",
    "# Multi-Head Attention\n",
    "class AttentionHead(nn.Module):\n",
    "    def __init__(self, d_model, head_size):\n",
    "        super().__init__()\n",
    "        self.head_size = head_size\n",
    "        self.query = nn.Linear(d_model, head_size)\n",
    "        self.key = nn.Linear(d_model, head_size)\n",
    "        self.value = nn.Linear(d_model, head_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Obtaining Queries, Keys, and Values\n",
    "        Q = self.query(x)\n",
    "        K = self.key(x)\n",
    "        V = self.value(x)\n",
    "\n",
    "        # Dot Product of Queries and Keys\n",
    "        attention = Q @ K.transpose(-2,-1)\n",
    "\n",
    "        # Scaling\n",
    "        attention = attention / (self.head_size ** 0.5)\n",
    "        attention = torch.softmax(attention, dim=-1)\n",
    "        attention = attention @ V\n",
    "        return attention\n",
    "\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"Run n_heads AttentionHeads side by side, concatenate them, then mix with W_o.\n",
    "\n",
    "    Each head has width d_model // n_heads; the concatenated head outputs are\n",
    "    projected back to d_model by the output layer W_o.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model, n_heads):\n",
    "        super().__init__()\n",
    "        self.head_size = d_model // n_heads\n",
    "        # W_o is created before the heads, as before, to keep the seeded init order.\n",
    "        self.W_o = nn.Linear(d_model, d_model)\n",
    "        heads = [AttentionHead(d_model, self.head_size) for _ in range(n_heads)]\n",
    "        self.heads = nn.ModuleList(heads)\n",
    "\n",
    "    def forward(self, x):\n",
    "        per_head = [attend(x) for attend in self.heads]\n",
    "        return self.W_o(torch.cat(per_head, dim=-1))\n",
    "\n",
    "\n",
    "# Transformer Encoder\n",
    "class TransformerEncoder(nn.Module):\n",
    "    \"\"\"Pre-norm encoder layer: LN -> multi-head attention -> residual, LN -> MLP -> residual.\n",
    "\n",
    "    Args:\n",
    "        d_model: token embedding width.\n",
    "        n_heads: number of attention heads.\n",
    "        r_mlp: hidden-width multiplier of the feed-forward MLP.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model, n_heads, r_mlp=4):\n",
    "        super().__init__()\n",
    "        self.d_model = d_model\n",
    "        self.n_heads = n_heads\n",
    "        hidden = d_model * r_mlp\n",
    "\n",
    "        # Submodules are created in the original order (ln1, mha, ln2, mlp) so the\n",
    "        # seeded initialisation and the state_dict layout stay the same.\n",
    "        self.ln1 = nn.LayerNorm(d_model)\n",
    "        self.mha = MultiHeadAttention(d_model, n_heads)\n",
    "        self.ln2 = nn.LayerNorm(d_model)\n",
    "        self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.mha(self.ln1(x))     # attention sub-layer with residual\n",
    "        return x + self.mlp(self.ln2(x))  # feed-forward sub-layer with residual\n",
    "\n",
    "\n",
    "# Vision Transformer Model\n",
    "class VisionTransformer(nn.Module):\n",
    "    \"\"\"Minimal Vision Transformer (ViT) classifier.\n",
    "\n",
    "    Args:\n",
    "        d_model: token embedding width; must be divisible by n_heads.\n",
    "        n_classes: number of output classes.\n",
    "        img_size: (H, W) of the input images.\n",
    "        patch_size: (p_h, p_w); must divide img_size exactly.\n",
    "        n_channels: number of image channels.\n",
    "        n_heads: attention heads per encoder layer.\n",
    "        n_layers: number of stacked encoder layers.\n",
    "\n",
    "    forward(images) maps (B, n_channels, H, W) to raw, unnormalised logits of\n",
    "    shape (B, n_classes), read from the [CLS] token. No softmax is applied:\n",
    "    nn.CrossEntropyLoss already applies log-softmax internally, and a softmax in\n",
    "    the model squashes the loss and its gradients. This also makes the output\n",
    "    consistent with ResNet9 and MLP. Use torch.softmax(output, dim=-1) when\n",
    "    probabilities are needed; argmax predictions are unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model, n_classes, img_size, patch_size, n_channels, n_heads, n_layers):\n",
    "        super().__init__()\n",
    "\n",
    "        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \"img_size dimensions must be divisible by patch_size dimensions\"\n",
    "        assert d_model % n_heads == 0, \"d_model must be divisible by n_heads\"\n",
    "\n",
    "        self.d_model = d_model # Dimensionality of model\n",
    "        self.n_classes = n_classes # Number of classes\n",
    "        self.img_size = img_size # Image size\n",
    "        self.patch_size = patch_size # Patch size\n",
    "        self.n_channels = n_channels # Number of channels\n",
    "        self.n_heads = n_heads # Number of attention heads\n",
    "\n",
    "        self.n_patches = (self.img_size[0] * self.img_size[1]) // (self.patch_size[0] * self.patch_size[1])\n",
    "        self.max_seq_length = self.n_patches + 1  # patches + [CLS] token\n",
    "\n",
    "        self.patch_embedding = PatchEmbedding(self.d_model, self.img_size, self.patch_size, self.n_channels)\n",
    "        self.positional_encoding = PositionalEncoding(self.d_model, self.max_seq_length)\n",
    "        self.transformer_encoder = nn.Sequential(*[TransformerEncoder(self.d_model, self.n_heads) for _ in range(n_layers)])\n",
    "\n",
    "        # Classification head on the [CLS] token. It stays an nn.Sequential so the\n",
    "        # state_dict keys (classifier.0.weight / classifier.0.bias) are unchanged and\n",
    "        # previously saved checkpoints still load (the removed Softmax had no params).\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Linear(self.d_model, self.n_classes),\n",
    "        )\n",
    "\n",
    "    def forward(self, images):\n",
    "        x = self.patch_embedding(images)      # (B, P, d_model)\n",
    "        x = self.positional_encoding(x)       # (B, P + 1, d_model)\n",
    "        x = self.transformer_encoder(x)\n",
    "        return self.classifier(x[:, 0])       # logits from the [CLS] token"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "521ef36f-75cb-496f-a898-69b9bd70ad3d",
   "metadata": {},
   "source": [
    "## MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f520411-e1bf-439d-9652-5af636efc44b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspired by the default MLP from the Sci-Kit Learn library: https://scikit-learn.org/stable/ and coded from scratch for GPU parallelization capabilities\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_layer_sizes=(100,), num_classes=None):\n",
    "        super().__init__()\n",
    "        layers = []\n",
    "        in_dim = input_dim\n",
    "        for hidden_dim in hidden_layer_sizes:\n",
    "            layers.append(nn.Linear(in_dim, hidden_dim))\n",
    "            layers.append(nn.ReLU())\n",
    "            in_dim = hidden_dim\n",
    "        layers.append(nn.Linear(in_dim, num_classes))\n",
    "        self.network = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(x.size(0), -1)  # flatten batch of images\n",
    "        return self.network(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51a86794-1e54-4e40-a633-68e7832575f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset(dataset_name, num_workers):\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "    ])\n",
    "\n",
    "    if dataset_name == \"MNIST\":\n",
    "        train_dataset = datasets.MNIST(root='mnist_train', train=True, download=True, transform=transform)\n",
    "        test_dataset = datasets.MNIST(root='mnist_test', train=False, download=True, transform=transform)\n",
    "    elif dataset_name == \"FashionMNIST\":\n",
    "        train_dataset = datasets.FashionMNIST(root='fmnist_train', train=True, download=True, transform=transform)\n",
    "        test_dataset = datasets.FashionMNIST(root='fmnist_test', train=False, download=True, transform=transform)\n",
    "    elif dataset_name == 'USPS':\n",
    "        train_dataset = datasets.USPS(root='usps_train', train=True, download=True, transform=transform)\n",
    "        test_dataset = datasets.USPS(root='usps_test', train=False, download=True, transform=transform)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported dataset!\")\n",
    "        \n",
    "    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=num_workers, pin_memory=True)\n",
    "    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=num_workers, pin_memory=True)\n",
    "    return train_loader, test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd6d46bf-4719-4b96-83e1-b69c3a1b5781",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters used for training\n",
    "num_classes = 10\n",
    "batch_size = 256\n",
    "epochs = 40\n",
    "L_cons = 0.1\n",
    "epsilon = 1\n",
    "total_delta = 1e-5\n",
    "nums_loops = 10\n",
    "classical_mech = 'Analytic'\n",
    "baseline_model = 'MLP'\n",
    "dataset_name = 'MNIST'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f3df52d-0659-4b01-a653-5832a09ddd6e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# === Load Dataset ===\n",
    "train_loader, test_loader = load_dataset(dataset_name, num_workers)\n",
    "\n",
    "# Checkpoints are written to models/ at the end of each run; make sure it exists.\n",
    "os.makedirs(\"models\", exist_ok=True)\n",
    "\n",
    "# === Loop Over Multiple Runs to Average Results ===\n",
    "for run in range(nums_loops):\n",
    "    print(f\"\\n[RUN {run}] on {dataset_name} with epsilon = {epsilon}, and L_cons = {L_cons}\")\n",
    "    # Seed per run: runs still differ from each other, but each one is reproducible.\n",
    "    torch.manual_seed(run)\n",
    "\n",
    "    # === Choose Classical Differential Privacy Mechanism ===\n",
    "    # Only the calibrated noise scale is taken from diffprivlib; the noise itself is\n",
    "    # sampled with torch below. Both branches use L_cons as the sensitivity.\n",
    "    if classical_mech == 'Analytic':\n",
    "        mech = GaussianAnalytic(epsilon=epsilon, sensitivity=L_cons, delta=total_delta)\n",
    "    else:\n",
    "        # NOTE(review): diffprivlib's classic Gaussian mechanism rejects epsilon > 1.\n",
    "        mech = Gaussian(epsilon=epsilon, sensitivity=L_cons, delta=total_delta)\n",
    "    std = mech._scale  # NOTE(review): private diffprivlib attribute; re-check on upgrades\n",
    "\n",
    "    # === Select Model Architecture Based on Baseline Setting ===\n",
    "    if baseline_model == 'ResNet9':\n",
    "        model = ResNet9(num_classes=num_classes).to(device)\n",
    "    elif baseline_model == 'ViT':\n",
    "        # Patch sizes are chosen to divide each dataset's image size evenly.\n",
    "        if dataset_name in [\"MNIST\", \"FashionMNIST\"]:\n",
    "            img_size = (28,28)\n",
    "            patch_size = (14,14)\n",
    "        elif dataset_name in [\"USPS\"]:\n",
    "            img_size = (16,16)\n",
    "            patch_size = (8,8)\n",
    "        n_heads = 3\n",
    "        d_model = 9\n",
    "        n_layers = 3\n",
    "        in_channels = 1\n",
    "        model = VisionTransformer(d_model, num_classes, img_size, patch_size, in_channels, n_heads, n_layers).to(device)\n",
    "    elif baseline_model == 'MLP':\n",
    "        if dataset_name in [\"MNIST\", \"FashionMNIST\"]:\n",
    "            input_dim = 28*28\n",
    "        elif dataset_name in [\"USPS\"]:\n",
    "            input_dim = 16*16\n",
    "        hl = (100,)\n",
    "        model = MLP(input_dim, hl, num_classes).to(device)\n",
    "    else:\n",
    "        # Fail loudly instead of printing and silently skipping all runs.\n",
    "        raise ValueError(f\"Error: Model is not valid: {baseline_model!r}\")\n",
    "\n",
    "    model.train()\n",
    "    # === Optimizer and Loss Setup ===\n",
    "    lrate = 0.001\n",
    "    criterion = nn.CrossEntropyLoss()  # expects raw logits from the model\n",
    "    optimizer = AdamW(model.parameters(), lr=lrate)\n",
    "    # Halve the LR after 5 epochs without improvement in mean training loss. The\n",
    "    # deprecated `verbose` flag is dropped; the current LR is printed every epoch.\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)\n",
    "    train_losses, train_accs = [], []\n",
    "\n",
    "    # === Epoch-wise Training ===\n",
    "    # Fresh Gaussian noise with standard deviation `std` is added to every training\n",
    "    # batch, so each epoch sees a different noisy copy of the inputs.\n",
    "    for epoch in range(epochs):\n",
    "        total_loss, correct, total = 0.0, 0, 0\n",
    "        for data, target in train_loader:\n",
    "            data = data.to(device, non_blocking=True)\n",
    "            target = target.to(device, non_blocking=True)\n",
    "            data_noise = data + torch.randn_like(data) * std  # classical DP noise on input features\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(data_noise)\n",
    "            loss = criterion(outputs, target)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            total_loss += loss.item()\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            correct += (predicted == target).sum().item()\n",
    "            total += target.size(0)\n",
    "\n",
    "        total_loss = total_loss / len(train_loader)\n",
    "        accuracy = correct / total\n",
    "        train_losses.append(total_loss)\n",
    "        train_accs.append(accuracy)\n",
    "\n",
    "        scheduler.step(total_loss)\n",
    "        current_lr = optimizer.param_groups[0]['lr']\n",
    "        print(f\"Epoch {epoch+1}/{epochs} | Train Accuracy: {accuracy:.4f} | Train Loss: {total_loss:.4f} | LR: {current_lr:.2e}\")\n",
    "\n",
    "    # === Save model ===\n",
    "    # Uses `epsilon`; the former `eps` was undefined and raised NameError here.\n",
    "    model_save_path = f\"models/{baseline_model}_epsilon{epsilon}_Lcons{L_cons}_loop{run}_{dataset_name}.pt\"\n",
    "    torch.save(model.state_dict(), model_save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
