{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PqKCJSkmvsJT"
      },
      "outputs": [],
      "source": [
        "# !pip install scikit-learn==1.1.0 #diffprivlib contractions\n",
        "!pip install optuna\n",
        "!pip install numpy==1.24.4  # Stable with torch 2.x and opacus\n",
        "!pip install --upgrade opacus\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from opacus import PrivacyEngine\n",
        "\n",
        "from tqdm.notebook import tqdm\n",
        "from scipy.stats import gamma\n",
        "import os as os\n",
        "import torch\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import DataLoader, TensorDataset\n",
        "import torch.nn as nn\n",
        "import time\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import math\n",
        "import torch.optim as optim\n",
        "\n",
        "# For confidence intervals\n",
        "import scipy\n",
        "# Prefer the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.\n",
        "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vqN7PrJpv3Gj"
      },
      "source": [
        "# Utility Functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T9C4FyDPv4hl"
      },
      "outputs": [],
      "source": [
        "# Chebyshev approximation polynomial\n",
        "def phi_logit(x):\n",
        "    \"\"\"\n",
        "    Log-sigmoid of a scalar: log(1 / (1 + exp(-x))) = -log(1 + exp(-x)).\n",
        "\n",
        "    Evaluated as min(x, 0) - log1p(exp(-|x|)) so the argument of exp is never\n",
        "    positive. The direct formula raises OverflowError in math.exp once -x\n",
        "    exceeds roughly 709.\n",
        "\n",
        "    Parameters:\n",
        "    - x: real scalar (math functions are used, so arrays are not supported)\n",
        "\n",
        "    Returns:\n",
        "    - float, always <= 0\n",
        "    \"\"\"\n",
        "    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))\n",
        "\n",
        "class Chebyshev:\n",
        "    \"\"\"\n",
        "    Chebyshev(a, b, n, func)\n",
        "    Given a function func, lower and upper limits of the interval [a,b],\n",
        "    and a number of coefficients n (polynomial degree n - 1), this class\n",
        "    computes a Chebyshev approximation of the function:\n",
        "\n",
        "        f(x) ≈ c[0]/2 + sum_{j=1}^{n-1} c[j] * T_j(y),   y = (2x - a - b) / (b - a)\n",
        "\n",
        "    Method eval(x) yields the approximated function value.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, a, b, n, func):\n",
        "        \"\"\"\n",
        "        Samples func at the n Chebyshev nodes of [a, b] and stores the series\n",
        "        coefficients in self.c (a list of length n).\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if n < 1 or the interval is degenerate (a == b).\n",
        "        \"\"\"\n",
        "        if n < 1:\n",
        "            raise ValueError(\"Chebyshev requires n >= 1.\")\n",
        "        if a == b:\n",
        "            raise ValueError(\"Chebyshev requires a != b.\")\n",
        "        self.a = a\n",
        "        self.b = b\n",
        "        self.func = func\n",
        "        self.n = n\n",
        "        bma = 0.5 * (b - a)\n",
        "        bpa = 0.5 * (b + a)\n",
        "        # Function values at the Chebyshev nodes, mapped from [-1, 1] onto [a, b].\n",
        "        f = [func(math.cos(math.pi * (k + 0.5) / n) * bma + bpa) for k in range(n)]\n",
        "        fac = 2.0 / n\n",
        "        self.c = [fac * sum(f[k] * math.cos(math.pi * j * (k + 0.5) / n)\n",
        "                            for k in range(n)) for j in range(n)]\n",
        "\n",
        "    def eval(self, x):\n",
        "        \"\"\"\n",
        "        Evaluates the approximation at x with the Clenshaw recurrence.\n",
        "        No range check is performed on x.\n",
        "        \"\"\"\n",
        "        a, b = self.a, self.b\n",
        "        y = (2.0 * x - a - b) * (1.0 / (b - a))\n",
        "        y2 = 2.0 * y\n",
        "        # Start from zero and fold in c[n-1] .. c[1]; c[0] enters once, halved,\n",
        "        # in the return statement. For n == 1 the loop is empty and the result is\n",
        "        # the constant c[0]/2.\n",
        "        d, dd = 0.0, 0.0\n",
        "        for cj in self.c[:0:-1]:\n",
        "            d, dd = y2 * d - dd + cj, d\n",
        "        return y * d - dd + 0.5 * self.c[0]\n",
        "\n",
        "    def monomial_coeffs(self):\n",
        "        \"\"\"\n",
        "        Converts the Chebyshev approximation (requires n >= 3) into a\n",
        "        standard polynomial form: f(x) ≈ a0 + a1*x + a2*x^2\n",
        "\n",
        "        Only c[0], c[1] and c[2] are used, so for n > 3 the higher-order terms\n",
        "        are dropped. For n == 3 the result agrees with eval(x).\n",
        "\n",
        "        Returns:\n",
        "            [a0, a1, a2]\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if n < 3.\n",
        "        \"\"\"\n",
        "        if self.n < 3:\n",
        "            raise ValueError(\"monomial_coeffs() requires at least a second-order (n >= 3) approximation.\")\n",
        "\n",
        "        c0, c1, c2 = self.c[0], self.c[1], self.c[2]\n",
        "        a, b = self.a, self.b\n",
        "        # y = alpha * x + beta maps [a, b] onto [-1, 1].\n",
        "        alpha = 2 / (b - a)\n",
        "        beta = -(a + b) / (b - a)\n",
        "\n",
        "        # Expand c0/2 + c1*y + c2*(2*y^2 - 1) in powers of x. The constant term\n",
        "        # carries c0/2 (not c0), matching the convention used by eval().\n",
        "        a0 = 0.5 * c0 - c2 + c1 * beta + 2 * c2 * beta**2\n",
        "        a1 = c1 * alpha + 4 * c2 * alpha * beta\n",
        "        a2 = 2 * c2 * alpha**2\n",
        "\n",
        "        return [a0, a1, a2]\n",
        "\n",
        "# ----------------------------------------------------------\n",
        "# 2. Simple CNN\n",
        "# ----------------------------------------------------------\n",
        "class SmallCNN(nn.Module):\n",
        "    \"\"\"\n",
        "    Two conv/ReLU/max-pool stages followed by a two-layer classifier head.\n",
        "\n",
        "    The head's first Linear layer expects 64 * 8 * 8 inputs, i.e. 3-channel\n",
        "    32x32 images after the two 2x2 poolings.\n",
        "\n",
        "    Args:\n",
        "        output_dim: width of the hidden (feature) layer of the classifier head\n",
        "        num_classes: number of output logits\n",
        "    \"\"\"\n",
        "    def __init__(self, output_dim = 128, num_classes = 2):\n",
        "        super().__init__()\n",
        "        self.features = nn.Sequential(\n",
        "            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),\n",
        "            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)\n",
        "        )\n",
        "        # Layer order matters to code that indexes into the head:\n",
        "        # [0] Flatten, [1] Linear -> output_dim, [2] ReLU, [3] Linear -> num_classes.\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(64 * 8 * 8, output_dim), nn.ReLU(),\n",
        "            nn.Linear(output_dim, num_classes),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Returns the class logits for a batch of images.\"\"\"\n",
        "        x = self.features(x)\n",
        "        return self.classifier(x)\n",
        "\n",
        "def extract_cnn_features(model, loader, device=torch.device('cpu')):\n",
        "    \"\"\"\n",
        "    Runs every batch of `loader` through the network up to and including the\n",
        "    ReLU of the hidden classifier layer, and collects the activations.\n",
        "\n",
        "    The model is switched to eval mode and is expected to already live on\n",
        "    `device`; only the input batches are moved.\n",
        "\n",
        "    Returns:\n",
        "        feats  : torch.Tensor  [N, D]  (D = hidden width, 128 by default), on CPU\n",
        "        labels : torch.Tensor  [N], on CPU\n",
        "    \"\"\"\n",
        "    model.eval()                       # inference mode\n",
        "    feats, labels = [], []\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for x, y in tqdm(loader, desc=\"Extracting\"):\n",
        "            x = x.to(device)\n",
        "            # --- forward until the hidden feature layer ---\n",
        "            # model.classifier is [Flatten, Linear, ReLU, Linear]; indices 0..2\n",
        "            # give the post-ReLU features that the final Linear consumes.\n",
        "            # (Indexing only [0] and [1] would stop before the ReLU.)\n",
        "            h = model.features(x)                # shape: [b, 64, 8, 8]\n",
        "            h = model.classifier[0](h)           # Flatten\n",
        "            h = model.classifier[1](h)           # Linear(64*8*8 -> D)\n",
        "            h = model.classifier[2](h)           # ReLU\n",
        "\n",
        "            feats.append(h.cpu())\n",
        "            labels.append(y.cpu())\n",
        "\n",
        "    feats  = torch.cat(feats)   # [N, D]\n",
        "    labels = torch.cat(labels)  # [N]\n",
        "    return feats, labels\n",
        "\n",
        "##################### Load Datasets #####################\n",
        "def filter_and_relabel(dataset, digits):\n",
        "    \"\"\"\n",
        "    Keeps the samples whose target is listed in `digits` and relabels them\n",
        "    for a binary task: digits[0] -> 0, any other kept class -> 1.\n",
        "\n",
        "    Parameters:\n",
        "    - dataset: indexable dataset with a `.targets` sequence; dataset[i][0] must be a tensor\n",
        "    - digits: class ids to keep\n",
        "\n",
        "    Returns:\n",
        "    - TensorDataset(images, labels) holding the kept samples in dataset order\n",
        "    \"\"\"\n",
        "    kept = []\n",
        "    for position, target in enumerate(dataset.targets):\n",
        "        if int(target) in digits:\n",
        "            kept.append(position)\n",
        "\n",
        "    # dataset[position] goes through the dataset's own transform pipeline.\n",
        "    images = torch.stack([dataset[position][0] for position in kept])\n",
        "    binary_targets = [0 if dataset.targets[position] == digits[0] else 1 for position in kept]\n",
        "    return TensorDataset(images, torch.tensor(binary_targets))\n",
        "\n",
        "def _rgb_transform():\n",
        "    \"\"\"Transform for datasets that are already 3-channel: tensor conversion, then shift by 0.5.\"\"\"\n",
        "    return transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize((0.5,), (1.0,))\n",
        "    ])\n",
        "\n",
        "def _grayscale_to_rgb_transform():\n",
        "    \"\"\"\n",
        "    Transform for single-channel datasets: resize to 32x32 and replicate the\n",
        "    channel three times so the images match the 3-channel 32x32 input SmallCNN expects.\n",
        "    \"\"\"\n",
        "    return transforms.Compose([\n",
        "        transforms.Resize((32, 32)),\n",
        "        transforms.Grayscale(num_output_channels=3),  # 1 channel -> 3 identical channels\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize((0.5,), (1.0,))\n",
        "    ])\n",
        "\n",
        "def _load_binary_dataset(dataset_cls, transform, digits, batch_size):\n",
        "    \"\"\"\n",
        "    Downloads `dataset_cls` into ./data, keeps only the classes in `digits`\n",
        "    (relabelled to 0/1 by filter_and_relabel) and wraps both splits in DataLoaders.\n",
        "\n",
        "    Returns:\n",
        "        (train_loader, test_loader). The training loader shuffles and drops the\n",
        "        last incomplete batch; the test loader does neither.\n",
        "    \"\"\"\n",
        "    train_ds = dataset_cls(root='./data', train=True, download=True, transform=transform)\n",
        "    test_ds  = dataset_cls(root='./data', train=False, download=True, transform=transform)\n",
        "\n",
        "    train_set = filter_and_relabel(train_ds, digits)\n",
        "    test_set  = filter_and_relabel(test_ds,  digits)\n",
        "\n",
        "    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)\n",
        "    test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
        "    return train_loader, test_loader\n",
        "\n",
        "def load_CIFAR10(digits, batch_size):\n",
        "    \"\"\"Binary CIFAR-10 loaders restricted to the two class ids in `digits`.\"\"\"\n",
        "    return _load_binary_dataset(datasets.CIFAR10, _rgb_transform(), digits, batch_size)\n",
        "\n",
        "def load_CIFAR100(digits, batch_size):\n",
        "    \"\"\"Binary CIFAR-100 loaders restricted to the two class ids in `digits`.\"\"\"\n",
        "    return _load_binary_dataset(datasets.CIFAR100, _rgb_transform(), digits, batch_size)\n",
        "\n",
        "def load_MNIST(digits, batch_size):\n",
        "    \"\"\"Binary MNIST loaders (resized to 32x32, 3 channels) for the two class ids in `digits`.\"\"\"\n",
        "    return _load_binary_dataset(datasets.MNIST, _grayscale_to_rgb_transform(), digits, batch_size)\n",
        "\n",
        "def load_FMNIST(digits, batch_size):\n",
        "    \"\"\"Binary Fashion-MNIST loaders (resized to 32x32, 3 channels) for the two class ids in `digits`.\"\"\"\n",
        "    return _load_binary_dataset(datasets.FashionMNIST, _grayscale_to_rgb_transform(), digits, batch_size)\n",
        "\n",
        "##################### Logistic Regression Utilities #####################\n",
        "\n",
        "# Compute the derivatives of the logistic loss function\n",
        "def logistic_derivatives(X, y, theta, n=1):\n",
        "    \"\"\"\n",
        "    Per-sample derivative of the logistic loss with respect to the logit:\n",
        "    (sigmoid(x_i . theta) - y_i) / n.\n",
        "\n",
        "    Parameters:\n",
        "    - X: array of shape (num_samples, dim)\n",
        "    - y: array of shape (num_samples,), labels in {0, 1}\n",
        "    - theta: array of shape (dim,)\n",
        "    - n: scalar divisor applied to the result (default 1); it is NOT inferred from X\n",
        "\n",
        "    Returns:\n",
        "    - array of shape (num_samples,)\n",
        "    \"\"\"\n",
        "    # Compute θ^T x_i for each i\n",
        "    t = X @ theta\n",
        "    # σ(t) = exp(-log(1 + exp(-t))); logaddexp keeps this finite for large |t|,\n",
        "    # whereas 1 / (1 + np.exp(-t)) overflows in exp for very negative t.\n",
        "    p = np.exp(-np.logaddexp(0.0, -t))\n",
        "    # d_i = p - y_i\n",
        "    d = p - y\n",
        "    return d/n\n",
        "\n",
        "def sample_from_density(d, c):\n",
        "    \"\"\"\n",
        "    Draws one vector b in R^d from the density h(b) ∝ exp(-c * ||b||).\n",
        "\n",
        "    The sample is built as radius * direction, using NumPy's global RNG:\n",
        "    the radius is drawn first, the direction second.\n",
        "\n",
        "    Parameters:\n",
        "    - d: Dimension of the vector b\n",
        "    - c: Rate parameter of the density\n",
        "\n",
        "    Returns:\n",
        "    - b: array of shape (d,)\n",
        "    \"\"\"\n",
        "    # Radius: ||b|| ~ Gamma(shape=d, scale=1/c)\n",
        "    radius = np.random.gamma(shape=d, scale=1/c)\n",
        "\n",
        "    # Direction: a standard Gaussian vector scaled to unit length is uniform on the sphere\n",
        "    gaussian = np.random.normal(size=d)\n",
        "    unit_vector = gaussian / np.linalg.norm(gaussian)\n",
        "\n",
        "    return radius * unit_vector\n",
        "\n",
        "# calculate values of \\sigma_1 and \\sigma_2 for (\\alpha,\\eps)-Renyi-DP\n",
        "def objective_func(alpha, k, d, sigma1, sigma2, delta,C_max):\n",
        "    \"\"\"\n",
        "    Objective minimised over the Renyi order alpha; the value is compared\n",
        "    against the target epsilon in solve_sigma_renyi.\n",
        "\n",
        "    Returns np.inf when alpha is outside (1, min(sigma1, sigma2) / (1 + C_max)),\n",
        "    so that a bounded scalar minimiser stays inside the valid range.\n",
        "\n",
        "    Parameters:\n",
        "    - alpha: Renyi order being optimised\n",
        "    - k: multiplier of the two logarithmic terms\n",
        "    - d: unused; kept so the positional argument list stays stable\n",
        "    - sigma1, sigma2: noise levels (only their min and max are used)\n",
        "    - delta: the delta of the (epsilon, delta) guarantee\n",
        "    - C_max: constant added to 1 in the numerators of both log arguments\n",
        "    \"\"\"\n",
        "    sigma_small = np.min((sigma1, sigma2))\n",
        "    sigma_large = np.max((sigma1, sigma2))\n",
        "\n",
        "    if alpha <= 1 or alpha >= sigma_small/(1+C_max):\n",
        "        return np.inf  # Penalize out-of-bound values\n",
        "\n",
        "    order_gap = alpha - 1\n",
        "    log_large = np.log(1.0 - (1.0 + C_max) / sigma_large)\n",
        "    log_small = np.log(1 - (alpha*(1 + C_max)) / sigma_small)\n",
        "    conversion = np.log(1.0 / delta) + order_gap*np.log(1-1/alpha) - np.log(alpha)\n",
        "\n",
        "    total = (k * alpha) / (2 * order_gap) * log_large\n",
        "    total = total + (- (k / (2 * order_gap)) * log_small)\n",
        "    total = total + conversion / order_gap\n",
        "    return total\n",
        "\n",
        "def solve_sigma_renyi(sigma_DP, n_prime, d, delta, target_epsilon, C_max):\n",
        "    \"\"\"\n",
        "    Bisection search for the smallest sigma in [sigma_DP / 30000, 30000 * sigma_DP]\n",
        "    for which min over alpha of objective_func(alpha, n_prime, d, sigma, sigma, delta, C_max)\n",
        "    falls below target_epsilon.\n",
        "\n",
        "    Parameters:\n",
        "    - sigma_DP: reference noise level that sets the search interval\n",
        "    - n_prime: passed to objective_func as k\n",
        "    - d: passed through to objective_func\n",
        "    - delta: the delta of the (epsilon, delta) guarantee\n",
        "    - target_epsilon: threshold the minimised objective must stay below\n",
        "    - C_max: passed through to objective_func\n",
        "\n",
        "    Returns:\n",
        "    - the smallest feasible sigma found (to about 1e-6), or the upper end of\n",
        "      the search interval if no tested sigma was feasible\n",
        "    \"\"\"\n",
        "    # Imported explicitly: a bare `import scipy` is not guaranteed to load scipy.optimize.\n",
        "    from scipy.optimize import minimize_scalar\n",
        "\n",
        "    # Define binary search bounds\n",
        "    left, right = sigma_DP / 30000.0, 30000.0*sigma_DP\n",
        "    best_sigma = right  # Default to upper bound in case no solution is found\n",
        "    while right - left > 1e-6:  # Precision threshold\n",
        "        mid_sigma = (left + right) / 2\n",
        "        if mid_sigma <= left or mid_sigma >= right:\n",
        "            # Interval is below floating-point resolution; bisecting further cannot progress.\n",
        "            break\n",
        "\n",
        "        # Solve for optimal alpha given the current sigma\n",
        "        alpha_low, alpha_high = 1 + 1e-5, mid_sigma - 1e-5\n",
        "        feasible = False\n",
        "        if alpha_low < alpha_high:\n",
        "            result = minimize_scalar(objective_func,\n",
        "                                     bounds=(alpha_low, alpha_high),\n",
        "                                     args=(n_prime, d, mid_sigma, mid_sigma, delta,C_max),\n",
        "                                     method='bounded')\n",
        "            feasible = bool(result.success and result.fun < target_epsilon)\n",
        "        # else: no admissible alpha exists for this sigma (minimize_scalar would\n",
        "        # reject the inverted bounds), so the candidate counts as infeasible.\n",
        "\n",
        "        if feasible:\n",
        "            best_sigma = mid_sigma  # Update best found sigma\n",
        "            right = mid_sigma  # Search for a smaller sigma\n",
        "        else:\n",
        "            left = mid_sigma  # Increase sigma to meet target_epsilon\n",
        "    return best_sigma"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "41Z8dgOpttO7"
      },
      "source": [
        "## Logistic Regression DP-Training Procedure"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i8B0JbRQtvm5"
      },
      "outputs": [],
      "source": [
        "# ---------- helper : logistic loss with ℓ2‑regularisation ----------\n",
        "def lr_loss(w, X, y, lam):\n",
        "    \"\"\"\n",
        "    Mean binary cross-entropy of the logits X @ w against y, plus the ridge\n",
        "    penalty (lam/2)*||w||^2. Labels y are expected in {0, 1} as floats.\n",
        "    \"\"\"\n",
        "    bce = torch.nn.functional.binary_cross_entropy_with_logits(X @ w, y, reduction='mean')\n",
        "    ridge = 0.5 * lam * w.pow(2).sum()\n",
        "    return bce + ridge\n",
        "\n",
        "# ---------- Objective‑Perturbation with L‑BFGS ----------\n",
        "def Objective_perturb_LBFGS(X_train, y_train, X_test, y_test,\n",
        "                            lambda_param, epsilon, delta,\n",
        "                            num_steps=500, tol=1e-6, verbose=False,\n",
        "                            device=torch.device('cpu')):\n",
        "    \"\"\"\n",
        "    Fits an l2-regularised logistic regression by minimising a randomly\n",
        "    perturbed objective with L-BFGS, then scores it on the test set.\n",
        "\n",
        "    Minimised objective (n = number of training rows):\n",
        "        lr_loss(w) + (b . w) / n + ||w||^2 * Delta / (4 n)\n",
        "    with b = sigma_scale * z, z ~ N(0, I_d),\n",
        "        sigma_scale = sqrt((4*epsilon + 8*log(2/delta)) / epsilon^2),\n",
        "        Delta = 1 / (2*epsilon).\n",
        "\n",
        "    Args:\n",
        "        X_train, y_train: training features [n, d] and labels in {0, 1}\n",
        "        X_test, y_test: test features and labels in {0, 1}\n",
        "        lambda_param: l2 regularisation weight passed to lr_loss\n",
        "        epsilon, delta: privacy parameters that set the noise scale above\n",
        "        num_steps: max_iter of the single L-BFGS step\n",
        "        tol: tolerance_grad of L-BFGS\n",
        "        verbose: accepted for interface compatibility; currently unused\n",
        "        device: torch device on which all tensors are created and the solve runs\n",
        "\n",
        "    Returns:\n",
        "        (acc, seconds): test accuracy in [0, 1] and wall-clock optimisation time\n",
        "    \"\"\"\n",
        "\n",
        "    # --- move data to torch tensors on the requested device ---\n",
        "    # (the `device` argument is honoured rather than the notebook-level DEVICE)\n",
        "    Xtr = torch.tensor(X_train, dtype=torch.float32, device=device)\n",
        "    ytr = torch.tensor(y_train, dtype=torch.float32, device=device)\n",
        "    Xte = torch.tensor(X_test,  dtype=torch.float32, device=device)\n",
        "    yte = torch.tensor(y_test,  dtype=torch.float32, device=device)\n",
        "\n",
        "    n, d = Xtr.shape\n",
        "\n",
        "    # --- sample perturbation vector b = sigma_scale * N(0, I_d) ---\n",
        "    b_noise     = torch.randn(d, device=device)\n",
        "    sigma_scale = np.sqrt((4.0*epsilon + 8.0*np.log(2.0/delta)) / (epsilon**2))\n",
        "    b_vec       = sigma_scale * b_noise\n",
        "    Delta       = 1.0/(2.0 * epsilon)\n",
        "\n",
        "    # --- initialise parameter vector ---\n",
        "    w = torch.zeros(d, device=device, requires_grad=True)\n",
        "\n",
        "    # --- L-BFGS solver ---\n",
        "    # tolerance_change is set tiny so termination is governed by tolerance_grad / max_iter.\n",
        "    optimizer = optim.LBFGS([w], max_iter=num_steps, tolerance_grad=tol, tolerance_change=1e-20)\n",
        "\n",
        "    # closure with the linear perturbation (b.w)/n and the extra quadratic term\n",
        "    def closure():\n",
        "        optimizer.zero_grad()\n",
        "        loss = lr_loss(w, Xtr, ytr, lambda_param) + b_vec.dot(w) / n + w.dot(w) * Delta/(4.0 * n)\n",
        "        loss.backward()\n",
        "        return loss\n",
        "\n",
        "    t0 = time.time()\n",
        "    optimizer.step(closure)\n",
        "    t1 = time.time()\n",
        "\n",
        "    # --- evaluate on test set ---\n",
        "    with torch.no_grad():\n",
        "        probs  = torch.sigmoid(Xte @ w)\n",
        "        preds  = (probs >= 0.5).long()\n",
        "        acc    = (preds == yte.long()).float().mean().item()\n",
        "\n",
        "    return acc, t1 - t0\n",
        "\n",
        "# Solve with chebyshev approximation\n",
        "def second_order_cheb(X_train, y_train, k, n, d, tau, sigma_matrix, Q = 6):\n",
        "    \"\"\"\n",
        "    Solves logistic regression in closed form from a noisy random projection\n",
        "    of the data, using the degree-2 Chebyshev approximation of phi_logit on [-Q, Q].\n",
        "\n",
        "    Steps:\n",
        "      1. Project with a Gaussian matrix S of shape (k, n) and add Gaussian noise:\n",
        "         X_PR = S X + s N,  y_PR = S (2y - 1) + s N_y.\n",
        "      2. Noise scale s: sqrt(sigma_matrix) when sigma_matrix <= tau; otherwise\n",
        "         sigma_matrix is first reduced by a noisy estimate of the smallest\n",
        "         eigenvalue of [X, y]^T [X, y] (clipped at 0).\n",
        "      3. theta = -(c1 / (2 c2)) * (X_PR^T X_PR)^{-1} X_PR^T y_PR.\n",
        "\n",
        "    Parameters:\n",
        "    - X_train: array (n, d); y_train: labels in {0, 1}, shape (n,) or (n, 1)\n",
        "    - k: number of projected rows\n",
        "    - n, d: number of training rows / features (must match X_train.shape)\n",
        "    - tau: threshold selecting the noise branch\n",
        "    - sigma_matrix: noise variance before the eigenvalue correction\n",
        "    - Q: half-width of the Chebyshev approximation interval\n",
        "\n",
        "    Returns:\n",
        "    - (accuracy, seconds, theta): test accuracy in [0, 1], wall-clock time of\n",
        "      the projection and solve, and the fitted parameter vector\n",
        "\n",
        "    NOTE(review): X_test and y_test are NOT parameters; they are read from the\n",
        "    notebook's global namespace and must be defined before this is called.\n",
        "    \"\"\"\n",
        "    cheb = Chebyshev(-Q, Q, 3, phi_logit)\n",
        "    c1 = cheb.c[1]\n",
        "    c2 = cheb.c[2]\n",
        "\n",
        "    # Noise matrices (drawn before S so the RNG consumption order is unchanged)\n",
        "    N = np.random.randn(k, d)\n",
        "    N_y = np.random.randn(k)\n",
        "\n",
        "    # Smallest eigenvalue of the Gram matrix of [X, y]. The matrix is symmetric,\n",
        "    # so eigvalsh applies: real eigenvalues in ascending order.\n",
        "    y_col = y_train.reshape(-1, 1) if y_train.ndim == 1 else y_train\n",
        "    XY = np.hstack((X_train, y_col))\n",
        "    lambda_min = np.linalg.eigvalsh(XY.T @ XY)[0]\n",
        "    sigma_eigenval = sigma_matrix/np.sqrt(k)\n",
        "\n",
        "    # Compute projected data\n",
        "    time_start = time.time()\n",
        "    S = np.random.randn(k, n)\n",
        "    if sigma_matrix <= tau:\n",
        "        noise_std = np.sqrt(sigma_matrix)\n",
        "    else:\n",
        "        gamma_tilde = np.max((0, lambda_min - np.sqrt(sigma_eigenval) * (tau - np.random.randn())))\n",
        "        noise_std = np.sqrt(np.max((0, sigma_matrix - gamma_tilde)))\n",
        "    y_signed = 2.0 * y_train.reshape(-1, 1) - 1.0   # labels {0, 1} -> {-1, +1}\n",
        "    X_PR = S @ X_train + noise_std * N              # (k, d)\n",
        "    y_PR = (S @ y_signed).ravel() + noise_std * N_y  # (k,)\n",
        "\n",
        "    # Solve over the noised sufficient statistics (linear solve instead of an explicit inverse)\n",
        "    theta_acc = -(c1/(2.0 * c2)) * np.linalg.solve(X_PR.T @ X_PR, X_PR.T @ y_PR)\n",
        "    time_end = time.time()\n",
        "    total_time = time_end - time_start\n",
        "\n",
        "    # Calculate accuracy on the global test split\n",
        "    t_test = X_test @ theta_acc\n",
        "    p_test = 1.0 / (1.0 + np.exp(-t_test))\n",
        "    preds = (p_test >= 0.5).astype(int)\n",
        "\n",
        "    # return final accuracy and time\n",
        "    return np.mean(preds == y_test), total_time, theta_acc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r4uTrtI5DEdB"
      },
      "source": [
        "### Training Functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FeZdQdNsDGXm"
      },
      "outputs": [],
      "source": [
        "# Train with DP-SGD\n",
        "def train_with_dp(model, optimizer, train_loader, test_loader, epochs, lr, max_grad_norm,\n",
        "                    delta, eps, device=\"cpu\"):\n",
        "    \"\"\"\n",
        "    Trains the given model with DP-SGD using Opacus.\n",
        "\n",
        "    Args:\n",
        "        model: classifier producing per-class logits\n",
        "        optimizer: optimizer already bound to model.parameters()\n",
        "        train_loader, test_loader: loaders yielding (images, labels) batches\n",
        "        epochs: number of training passes; also given to Opacus to calibrate the noise\n",
        "        lr: unused here (the learning rate lives in `optimizer`); kept for interface compatibility\n",
        "        max_grad_norm: per-sample gradient clipping norm\n",
        "        delta, eps: target (eps, delta) privacy budget\n",
        "        device: device for the model and the batches\n",
        "\n",
        "    Returns:\n",
        "        (epsilon, test_accuracy, training_time, model):\n",
        "        epsilon reported by the Opacus accountant, test accuracy in percent,\n",
        "        training wall-clock seconds, and the Opacus-wrapped model.\n",
        "    \"\"\"\n",
        "\n",
        "    # Move model to device and put it in training mode before Opacus wraps it\n",
        "    model = model.to(device)\n",
        "    model.train()\n",
        "\n",
        "    # Attach privacy engine; it returns wrapped versions of all three objects\n",
        "    privacy_engine = PrivacyEngine()\n",
        "    model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(\n",
        "        module=model,\n",
        "        optimizer=optimizer,\n",
        "        data_loader=train_loader,\n",
        "        target_epsilon=eps,\n",
        "        target_delta=delta,\n",
        "        epochs=epochs,\n",
        "        max_grad_norm=max_grad_norm\n",
        "    )\n",
        "\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    start_time = time.time()\n",
        "    for _ in tqdm(range(epochs)):\n",
        "        for images, labels in train_loader:\n",
        "            images, labels = images.to(device), labels.to(device)\n",
        "            optimizer.zero_grad()\n",
        "            outputs = model(images)\n",
        "            loss = criterion(outputs, labels)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "    training_time = time.time() - start_time\n",
        "\n",
        "    # Compute final epsilon\n",
        "    epsilon = privacy_engine.get_epsilon(delta=delta)\n",
        "\n",
        "    # Evaluate model\n",
        "    model.eval()\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    with torch.no_grad():\n",
        "        for images, labels in test_loader:\n",
        "            images, labels = images.to(device), labels.to(device)\n",
        "            outputs = model(images)\n",
        "            _, predicted = torch.max(outputs, 1)\n",
        "            total += labels.size(0)\n",
        "            correct += (predicted == labels).sum().item()\n",
        "\n",
        "    test_accuracy = 100.0 * correct / total\n",
        "\n",
        "    return epsilon, test_accuracy, training_time, model\n",
        "\n",
        "def train_logistic_regression_with_dp_sgd(X_train, y_train, X_test, y_test,\n",
        "                      epsilon, delta, max_grad_norm,\n",
        "                      num_epochs=10,\n",
        "                      batch_size=256,\n",
        "                      learning_rate=1e-2):\n",
        "    \"\"\"\n",
        "    Trains a logistic regression model with DP-SGD using Opacus.\n",
        "\n",
        "    Args:\n",
        "        X_train, y_train: training data (numpy arrays)\n",
        "        X_test, y_test: test data (numpy arrays)\n",
        "        epsilon, delta: DP parameters\n",
        "        max_grad_norm: per-sample gradient clipping norm\n",
        "        num_epochs: number of training epochs\n",
        "        batch_size: training batch size\n",
        "        learning_rate: SGD learning rate\n",
        "\n",
        "    Returns:\n",
        "        test_acc: accuracy on the test set (fraction in [0, 1])\n",
        "        train_time: total training time in seconds\n",
        "    \"\"\"\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Convert data to PyTorch tensors\n",
        "    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)\n",
        "    y_train_tensor = torch.tensor(y_train, dtype=torch.long)\n",
        "    X_test_tensor = torch.tensor(X_test, dtype=torch.float32)\n",
        "\n",
        "    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)\n",
        "    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
        "\n",
        "    # Define logistic regression model: one linear layer emitting per-class logits\n",
        "    class LogisticRegression(nn.Module):\n",
        "        def __init__(self, input_dim, output_dim):\n",
        "            super().__init__()\n",
        "            self.linear = nn.Linear(input_dim, output_dim)\n",
        "\n",
        "        def forward(self, x):\n",
        "            return self.linear(x)\n",
        "\n",
        "    model = LogisticRegression(X_train.shape[1], len(np.unique(y_train)))\n",
        "    model = model.to(device)\n",
        "\n",
        "    # Loss and optimizer\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
        "\n",
        "    # Attach Privacy Engine\n",
        "    privacy_engine = PrivacyEngine()\n",
        "    model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(\n",
        "        module=model,\n",
        "        optimizer=optimizer,\n",
        "        data_loader=train_loader,\n",
        "        epochs=num_epochs,\n",
        "        target_epsilon=epsilon,\n",
        "        target_delta=delta,\n",
        "        max_grad_norm=max_grad_norm\n",
        "    )\n",
        "\n",
        "    # Training\n",
        "    model.train()\n",
        "    start_time = time.time()\n",
        "    for _ in range(num_epochs):\n",
        "        for x_batch, y_batch in train_loader:\n",
        "            x_batch, y_batch = x_batch.to(device), y_batch.to(device)\n",
        "            optimizer.zero_grad()\n",
        "            outputs = model(x_batch)\n",
        "            loss = criterion(outputs, y_batch)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "    end_time = time.time()\n",
        "\n",
        "    # Evaluation: fraction of test rows whose arg-max class equals the label.\n",
        "    # Computed with NumPy so the function needs no scikit-learn import.\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        y_pred = model(X_test_tensor.to(device))\n",
        "        y_pred_labels = torch.argmax(y_pred, dim=1).cpu().numpy()\n",
        "    test_acc = float(np.mean(y_pred_labels == np.asarray(y_test).ravel()))\n",
        "\n",
        "    train_time = end_time - start_time\n",
        "    return test_acc, train_time\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W4AekOBuv5gI"
      },
      "source": [
        "# Load Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yEimHIRTv6-C"
      },
      "outputs": [],
      "source": [
        "# Choose dataset from the next options\n",
        "# dataset_type = 'CIFAR10'\n",
        "# dataset_type = 'CIFAR100'\n",
        "dataset_type = 'MNIST'\n",
        "# dataset_type = 'FMNIST'\n",
        "\n",
        "# Ids of the two classes kept for the binary task, and the loader batch size.\n",
        "digits = [3, 8]\n",
        "batch_size = 500\n",
        "\n",
        "# One loader per supported dataset; all share the (digits, batch_size) signature.\n",
        "dataset_loaders = {\n",
        "    'CIFAR10': load_CIFAR10,\n",
        "    'CIFAR100': load_CIFAR100,\n",
        "    'MNIST': load_MNIST,\n",
        "    'FMNIST': load_FMNIST,\n",
        "}\n",
        "\n",
        "# Fail here on a misspelt name instead of leaving train_loader undefined for later cells.\n",
        "if dataset_type not in dataset_loaders:\n",
        "    raise ValueError(f'Unknown dataset_type {dataset_type!r}; expected one of {sorted(dataset_loaders)}')\n",
        "\n",
        "train_loader, test_loader = dataset_loaders[dataset_type](digits, batch_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tppKiIGG47Ps"
      },
      "source": [
        "# Baseline Model (Non-Private)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mB8NARUD4822"
      },
      "outputs": [],
      "source": [
        "# Non-private baseline: train SmallCNN with Adam and cross-entropy, then report test accuracy.\n",
        "# lr_full, epochs and output_dim are notebook-level settings reused by the DP training cell.\n",
        "lr_full = 0.001\n",
        "epochs = 20\n",
        "output_dim = 128\n",
        "\n",
        "model_non_private = SmallCNN(output_dim=output_dim, num_classes=2)\n",
        "optimizer = optim.Adam(model_non_private.parameters(), lr=lr_full, weight_decay=0.0)\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "# Baseline Training\n",
        "model_non_private = model_non_private.to(DEVICE)\n",
        "\n",
        "# Training loop\n",
        "model_non_private.train()\n",
        "start_time = time.time()\n",
        "for epoch in tqdm(range(epochs)):\n",
        "    total_loss = 0.0\n",
        "    num_batches = 0\n",
        "    for images, labels in train_loader:\n",
        "        images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
        "        optimizer.zero_grad()\n",
        "        outputs = model_non_private(images)\n",
        "        loss = criterion(outputs, labels)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Accumulate loss\n",
        "        total_loss += loss.item()\n",
        "        num_batches += 1\n",
        "\n",
        "    # Print the average loss for the epoch (mean of the per-batch mean losses)\n",
        "    avg_loss = total_loss / num_batches\n",
        "    print(f\"Epoch [{epoch+1}/{epochs}] - Average Train Loss: {avg_loss:.4f}\")\n",
        "training_time = time.time() - start_time\n",
        "\n",
        "# Evaluate model on the test set\n",
        "model_non_private.eval()\n",
        "correct = 0\n",
        "total = 0\n",
        "with torch.no_grad():\n",
        "    for images, labels in test_loader:\n",
        "        images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
        "        outputs = model_non_private(images)\n",
        "        # Predicted class = index of the largest logit\n",
        "        _, predicted = torch.max(outputs, 1)\n",
        "        total += labels.size(0)\n",
        "        correct += (predicted == labels).sum().item()\n",
        "# baseline_test_acc is a fraction in [0, 1]; it is printed as a percentage.\n",
        "baseline_test_acc = correct / total\n",
        "print('Baseline (Non-Private) Test accuracy: ' + str(100.0 * baseline_test_acc))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PHOphOHXEId0"
      },
      "source": [
        "# Set hyperparameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7vTmNrdwEeme"
      },
      "outputs": [],
      "source": [
        "# Privacy parameters\n",
        "# Five epsilon values, log-spaced from 10^-1 to 10^1.5.\n",
        "epsilon_values = np.logspace(-1.0, 1.5, 5)\n",
        "# delta is capped at 1e-6 and otherwise scales as 0.01 / (number of training samples).\n",
        "delta_DP = np.min((1e-6, 0.01/len(train_loader.dataset)))\n",
        "clip_norm_dpsgd = 4.0\n",
        "# A 1/removal_size share of the (epsilon, delta) budget is given to the DP-SGD\n",
        "# feature-extractor training; the remaining share of delta is used below.\n",
        "removal_size = 10\n",
        "\n",
        "# Number of Monte-Carlo iterations\n",
        "iters = 100\n",
        "\n",
        "# Confidence Intervals\n",
        "confidence = 0.95  # Change to your desired confidence level\n",
        "# Two-sided Student-t critical value with iters - 1 degrees of freedom.\n",
        "t_value = scipy.stats.t.ppf((1 + confidence) / 2.0, df=iters - 1)\n",
        "\n",
        "# delta left after the DP-SGD share, and the threshold tau = sqrt(2 * log(3 / delta_adjusted))\n",
        "# derived from it (presumably the `tau` argument of second_order_cheb - confirm at the call site).\n",
        "delta_DP_adjusted       = ((removal_size-1)/removal_size) * delta_DP\n",
        "tau = np.sqrt(2.0 * np.log(3.0/delta_DP_adjusted))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EhoxksfVAZri"
      },
      "source": [
        "## Use optimal hyperparameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gENzSbV_9Cue"
      },
      "outputs": [],
      "source": [
        "# k_opt_second_order is set to a multiple of the feature width output_dim\n",
        "# (presumably the projection size k passed to second_order_cheb - confirm at the call site).\n",
        "# Every dataset currently uses the same factor; the table keeps them individually tunable.\n",
        "k_factor_by_dataset = {\n",
        "    'MNIST': 4.5,\n",
        "    'CIFAR10': 4.5,\n",
        "    'CIFAR100': 4.5,\n",
        "    'FMNIST': 4.5,\n",
        "}\n",
        "\n",
        "# Fail here rather than with a NameError on k_opt_second_order in a later cell.\n",
        "if dataset_type not in k_factor_by_dataset:\n",
        "    raise ValueError(f'No k_opt_second_order setting for dataset_type {dataset_type!r}')\n",
        "\n",
        "k_opt_second_order = int(k_factor_by_dataset[dataset_type] * output_dim)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6LMvjziQZQPR"
      },
      "source": [
        "### Baseline train with DP"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wIfAwMNrZRqs"
      },
      "outputs": [],
      "source": [
        "# Train a fresh SmallCNN with DP-SGD, spending a 1/removal_size share of the largest\n",
        "# epsilon in epsilon_values (and of delta_DP), and keep only the returned wrapped model.\n",
        "model_hyperparameters = SmallCNN(output_dim=output_dim, num_classes=2)\n",
        "optimizer = optim.Adam(model_hyperparameters.parameters(), lr=lr_full, weight_decay=0.0)\n",
        "_, _, _, model_dp = train_with_dp(model_hyperparameters, optimizer, train_loader, test_loader, epochs, lr_full, clip_norm_dpsgd,\n",
        "                    delta_DP/removal_size, epsilon_values[-1]/removal_size, device=DEVICE)\n",
        "# Extract Features\n",
        "# train_loader was built with shuffle=True and drop_last=True, so the training features\n",
        "# come back in shuffled order and exclude the final incomplete batch.\n",
        "train_feats, train_labels = extract_cnn_features(model_dp, train_loader, DEVICE)\n",
        "test_feats,  test_labels  = extract_cnn_features(model_dp, test_loader,  DEVICE)\n",
        "# Convert to NumPy for the NumPy-based solvers defined above\n",
        "X_train = train_feats.numpy()\n",
        "y_train = train_labels.numpy()\n",
        "X_test  = test_feats.numpy()\n",
        "y_test  = test_labels.numpy()\n",
        "# Scale by the largest training-row l2 norm so every training row has norm <= 1;\n",
        "# the same factor is applied to the test features.\n",
        "norm_fact = np.sqrt(np.max(np.sum(X_train**2, 1)))\n",
        "X_train = X_train/norm_fact\n",
        "X_test = X_test/norm_fact\n",
        "n, d = X_train.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k1tCa9bJwJEU"
      },
      "source": [
        "# Main For Loop: Simulate optimal hyperparameter setting"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qyzKA1s9aFn_"
      },
      "outputs": [],
      "source": [
        "import warnings\n",
        "\n",
        "# Message patterns of the specific warnings to silence during DP training\n",
        "_ignored_warning_patterns = (\n",
        "    \"Secure RNG turned off.*\",\n",
        "    \"Optimal order is the largest alpha.*\",\n",
        "    \"Full backward hook is firing when gradients are computed with respect to module outputs since no inputs require gradients.*\",\n",
        ")\n",
        "for _pattern in _ignored_warning_patterns:\n",
        "    warnings.filterwarnings(\"ignore\", message=_pattern)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": true,
        "id": "Q3LY-gw2wKWa"
      },
      "outputs": [],
      "source": [
        "def _mean_and_ci_halfwidth(samples):\n",
        "    \"\"\"Return (mean, confidence-interval half-width) of a list of accuracies.\n",
        "\n",
        "    The half-width is t_value * s / sqrt(len(samples)), where s is the sample\n",
        "    standard deviation (ddof=1). This matches t_value, the Student-t critical\n",
        "    value computed with iters - 1 degrees of freedom.\n",
        "    \"\"\"\n",
        "    half_width = t_value * np.std(samples, ddof=1) / np.sqrt(len(samples))\n",
        "    return np.mean(samples), half_width\n",
        "\n",
        "\n",
        "# Per-epsilon mean test accuracy and CI half-width of each method\n",
        "# (the *_std names are kept because the plotting cell reads them)\n",
        "opt_acc_ours = []\n",
        "opt_acc_ours_std = []\n",
        "\n",
        "opt_acc_objective = []\n",
        "opt_acc_objective_std = []\n",
        "\n",
        "opt_acc_dp_sgd = []\n",
        "opt_acc_dp_sgd_std = []\n",
        "\n",
        "# Average per-repetition time of each method; only filled in for the last epsilon\n",
        "total_time_ours = 0.0\n",
        "total_time_obj_perturb = 0.0\n",
        "total_time_dp_sgd = 0.0\n",
        "\n",
        "n, d = X_train.shape\n",
        "lambda_param = 0.0\n",
        "\n",
        "# Share of delta handed to the three compared methods. It does not depend on\n",
        "# epsilon, so it is computed once; delta_DP/removal_size goes to train_with_dp.\n",
        "delta_DP_adjusted = ((removal_size-1)/removal_size) * delta_DP\n",
        "\n",
        "###############################################################################\n",
        "# Outer loop: iterate over epsilon_values\n",
        "###############################################################################\n",
        "\n",
        "# tqdm wraps the sequence itself (not enumerate) so the bar can show the total\n",
        "for eps_idx, eps in enumerate(tqdm(epsilon_values)):\n",
        "    print('Started eps = ' + str(eps))\n",
        "\n",
        "    curr_test_acc = []\n",
        "    curr_test_acc_loss_perturb = []\n",
        "    curr_test_acc_dp_sgd = []\n",
        "\n",
        "    # Time accumulators for this epsilon, summed over the `iters` repetitions\n",
        "    curr_time_sketching = 0.0\n",
        "    curr_time_obj_perturb = 0.0\n",
        "    curr_time_dp_sgd = 0.0\n",
        "\n",
        "    # Share of epsilon handed to the three compared methods\n",
        "    epsilon_adjusted  = ((removal_size-1)/removal_size) * eps\n",
        "\n",
        "    # NOTE(review): sigma_DP is built from the unadjusted eps and delta_DP while\n",
        "    # solve_sigma_renyi receives the adjusted budget - confirm this is intended.\n",
        "    sigma_DP = 2.0 * np.log(1.25/delta_DP) / (eps**2)\n",
        "    sigma_matrix = solve_sigma_renyi(sigma_DP, k_opt_second_order, d, delta_DP_adjusted, epsilon_adjusted, 1.0)\n",
        "\n",
        "    # Train the CNN with DP, giving it a 1/removal_size share of (eps, delta)\n",
        "    model_hyperparameters = SmallCNN(output_dim=output_dim, num_classes=2)\n",
        "    optimizer = optim.Adam(model_hyperparameters.parameters(), lr=lr_full, weight_decay=0.0)\n",
        "    _, _, _, model_dp = train_with_dp(model_hyperparameters, optimizer, train_loader, test_loader, epochs, lr_full, clip_norm_dpsgd,\n",
        "                        delta_DP/removal_size, eps/removal_size, device=DEVICE)\n",
        "\n",
        "    # Extract features with the DP-trained network\n",
        "    train_feats, train_labels = extract_cnn_features(model_dp, train_loader, DEVICE)\n",
        "    test_feats,  test_labels  = extract_cnn_features(model_dp, test_loader,  DEVICE)\n",
        "\n",
        "    # Convert to NumPy\n",
        "    X_train = train_feats.numpy()\n",
        "    y_train = train_labels.numpy()\n",
        "    X_test  = test_feats.numpy()\n",
        "    y_test  = test_labels.numpy()\n",
        "\n",
        "    # Divide both splits by the largest Euclidean row norm of the training features\n",
        "    norm_fact = np.sqrt(np.max(np.sum(X_train**2, 1)))\n",
        "    X_train = X_train/norm_fact\n",
        "    X_test = X_test/norm_fact\n",
        "\n",
        "    print('Data norm is:' + str(np.max(np.sum(X_train**2, 1))))\n",
        "\n",
        "    # Repeat the experiment `iters` times for averaging\n",
        "    for _ in tqdm(range(iters)):\n",
        "\n",
        "        #######################################################################\n",
        "        # (1) Objective Perturbation\n",
        "        #######################################################################\n",
        "        test_acc_obj_perturb, time_obj_perturb = Objective_perturb_LBFGS(\n",
        "            X_train, y_train, X_test, y_test,\n",
        "            lambda_param, epsilon_adjusted, delta_DP_adjusted,\n",
        "            num_steps=500, tol=1e-6, verbose=False,\n",
        "            device=torch.device('cpu'))\n",
        "\n",
        "        curr_time_obj_perturb += time_obj_perturb\n",
        "        curr_test_acc_loss_perturb.append(test_acc_obj_perturb)\n",
        "\n",
        "        #######################################################################\n",
        "        # (2) Second order approximation\n",
        "        #######################################################################\n",
        "        test_acc_second_order, time_second_order, final_theta_cheb = second_order_cheb(\n",
        "            X_train, y_train, k_opt_second_order, n, d, tau, sigma_matrix)\n",
        "        curr_time_sketching += time_second_order\n",
        "        curr_test_acc.append(test_acc_second_order)\n",
        "\n",
        "        #######################################################################\n",
        "        # (3) DP SGD\n",
        "        #######################################################################\n",
        "        test_acc_dp_sgd, time_dp_sgd = train_logistic_regression_with_dp_sgd(\n",
        "            X_train, y_train, X_test, y_test, epsilon_adjusted, delta_DP_adjusted,\n",
        "            max_grad_norm = 1.0, num_epochs = 10, batch_size = 1024, learning_rate = 0.5)\n",
        "        curr_time_dp_sgd += time_dp_sgd\n",
        "        curr_test_acc_dp_sgd.append(test_acc_dp_sgd)\n",
        "\n",
        "    # Aggregate the `iters` repetitions of every method for this epsilon\n",
        "    for samples, means, half_widths in (\n",
        "        (curr_test_acc, opt_acc_ours, opt_acc_ours_std),\n",
        "        (curr_test_acc_loss_perturb, opt_acc_objective, opt_acc_objective_std),\n",
        "        (curr_test_acc_dp_sgd, opt_acc_dp_sgd, opt_acc_dp_sgd_std),\n",
        "    ):\n",
        "        mean_acc, half_width = _mean_and_ci_halfwidth(samples)\n",
        "        means.append(mean_acc)\n",
        "        half_widths.append(half_width)\n",
        "\n",
        "    # Timings are reported for the last epsilon value only\n",
        "    if eps_idx == (len(epsilon_values) - 1):\n",
        "      total_time_ours += curr_time_sketching / iters\n",
        "      total_time_obj_perturb += curr_time_obj_perturb / iters\n",
        "      total_time_dp_sgd += curr_time_dp_sgd / iters\n",
        "\n",
        "    # Print some diagnostics (the STD lines print np.std with its default ddof=0)\n",
        "    print('Test Acc Sketching: ' + str(np.mean(curr_test_acc)))\n",
        "    print('Test Acc STD Sketching: ' + str( np.std(curr_test_acc)))\n",
        "\n",
        "    print('Test Acc Objective Pertubration: ' + str(np.mean(curr_test_acc_loss_perturb)))\n",
        "    print('Test Acc STD Objective Pertubration: ' + str(np.std(curr_test_acc_loss_perturb)))\n",
        "\n",
        "    print('Test Acc DP SGD: ' + str(np.mean(curr_test_acc_dp_sgd)))\n",
        "    print('Test Acc STD DP SGD: ' + str(np.std(curr_test_acc_dp_sgd)))\n",
        "\n",
        "    print('Total time ours: ' + str(curr_time_sketching / iters))\n",
        "    print('Total time Objective Perturb: ' + str(curr_time_obj_perturb / iters))\n",
        "    print('Total time DP SGD: ' + str(curr_time_dp_sgd / iters))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PhpNK4mKS11a"
      },
      "source": [
        "## Plot results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dyCzzTW1-5o1"
      },
      "outputs": [],
      "source": [
        "def _plot_error_curve(accuracies, band_widths, marker, color, label):\n",
        "    \"\"\"Plot test error (1 - accuracy) over epsilon_values with a shaded band.\n",
        "\n",
        "    The band spans accuracy +/- band_width/2, converted to error and clipped at 0.\n",
        "    NOTE(review): the band_widths passed in below are t_value * std / sqrt(iters),\n",
        "    i.e. already a half-width, so the band covers half of that interval -\n",
        "    confirm this is intended.\n",
        "    Returns the Line2D handle of the curve.\n",
        "    \"\"\"\n",
        "    curve, = plt.plot(\n",
        "        epsilon_values,\n",
        "        [1.0 - acc for acc in accuracies],\n",
        "        marker=marker,\n",
        "        markersize=12,\n",
        "        color=color,\n",
        "        label=label\n",
        "    )\n",
        "    band_upper = [max(0.0, 1.0 - (m - s/2.0)) for (m, s) in zip(accuracies, band_widths)]\n",
        "    band_lower = [max(0.0, 1.0 - (m + s/2.0)) for (m, s) in zip(accuracies, band_widths)]\n",
        "    plt.fill_between(epsilon_values, band_lower, band_upper, alpha=0.2, color=color)\n",
        "    return curve\n",
        "\n",
        "\n",
        "plt.figure(figsize=(9,8), constrained_layout=True)\n",
        "\n",
        "# Baseline (non-private) horizontal line\n",
        "plt.axhline(\n",
        "    y=1.0 - baseline_test_acc,\n",
        "    color='black',\n",
        "    linestyle='--',\n",
        "    linewidth=4.0,\n",
        "    label='Baseline (Non-private)'\n",
        ")\n",
        "\n",
        "# (A) Ours\n",
        "_plot_error_curve(\n",
        "    opt_acc_ours, opt_acc_ours_std, marker='o', color='orangered',\n",
        "    label=fr\"Gaussian mixing (ours) $\\frac{{k}}{{d}}$: {k_opt_second_order/X_train.shape[1]:.3f}\"\n",
        ")\n",
        "\n",
        "# (B) Objective Perturbation\n",
        "_plot_error_curve(\n",
        "    opt_acc_objective, opt_acc_objective_std, marker='*', color='blue',\n",
        "    label='Objective Perturbation (Guo et al., 2020)'\n",
        ")\n",
        "\n",
        "# (C) DP SGD\n",
        "_plot_error_curve(\n",
        "    opt_acc_dp_sgd, opt_acc_dp_sgd_std, marker='p', color='magenta',\n",
        "    label='DP-SGD'\n",
        ")\n",
        "\n",
        "# Axis labels. \\mathrm (also used in the delta annotation below) is understood by\n",
        "# Matplotlib's mathtext, whereas \\text is not available in older releases.\n",
        "plt.xlabel(r\"$\\epsilon_{\\mathrm{DP}}$\", fontsize=24)\n",
        "plt.ylabel(r\"$Test\\ Error\\ Rate$\", fontsize=24)\n",
        "plt.grid(True)\n",
        "\n",
        "# Build a multi-row legend: split the handles into four consecutive groups and\n",
        "# draw each group as its own legend below the axes.\n",
        "handles, labels = plt.gca().get_legend_handles_labels()\n",
        "row_bounds = [row * len(handles) // 4 for row in range(5)]\n",
        "row_offsets = (-0.35, -0.45, -0.55, -0.65)\n",
        "legend_rows = [\n",
        "    plt.legend(handles[lo:hi], labels[lo:hi], loc=\"lower center\",\n",
        "               bbox_to_anchor=(0.5, y_offset), ncol=2, frameon=False,\n",
        "               fontsize=24, labelspacing=1.0)\n",
        "    for lo, hi, y_offset in zip(row_bounds[:-1], row_bounds[1:], row_offsets)\n",
        "]\n",
        "\n",
        "# plt.legend only keeps the most recent legend, so re-attach the first three\n",
        "for earlier_legend in legend_rows[:-1]:\n",
        "    plt.gca().add_artist(earlier_legend)\n",
        "\n",
        "plt.title(dataset_type + ' Classification', fontsize = 28)\n",
        "\n",
        "# Runtime ratios relative to our method (timings of the last epsilon value)\n",
        "speedup_obj_perturb = f\"Obj. Pert.: {total_time_obj_perturb/total_time_ours:.2f}× Faster\"\n",
        "speedup_dp_sgd = f\"DP-SGD: {total_time_dp_sgd/total_time_ours:.2f}× Faster\"\n",
        "\n",
        "# Combine all lines with the DP delta string\n",
        "textstr = (\n",
        "    fr\"$\\delta_{{\\mathrm{{DP}}}} = \\min(10^{{-6}}, \\frac{{0.01}}{{n}})$\" + \"\\n\"\n",
        "    + f\"{speedup_obj_perturb}\\n\"\n",
        "    + f\"{speedup_dp_sgd}\"\n",
        ")\n",
        "\n",
        "plt.xticks(fontsize=18)\n",
        "plt.yticks(fontsize=18)\n",
        "plt.xscale('log')\n",
        "plt.text(\n",
        "    0.95, 0.95, textstr, transform=plt.gca().transAxes, fontsize=24,\n",
        "    verticalalignment='top', horizontalalignment='right',\n",
        "    bbox=dict(boxstyle=\"round,pad=0.3\", edgecolor='black', facecolor='white'),\n",
        "    family='monospace'\n",
        ")\n",
        "\n",
        "# Save figures\n",
        "base_name = \"Logistic_oneshot_\" + dataset_type + \"_optimal_hyperparam\"\n",
        "plt.savefig(base_name + \".pdf\", format=\"pdf\", bbox_inches=\"tight\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OHTz_WnT_BoR"
      },
      "source": [
        "## Download all files"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pDWPzfHZ1jyg"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# google.colab only exists on Colab; guard the import so a full run of the\n",
        "# notebook elsewhere does not fail after all results have been computed.\n",
        "try:\n",
        "    from google.colab import files\n",
        "except ImportError:\n",
        "    files = None\n",
        "\n",
        "# PDF files in the working directory, in a deterministic order\n",
        "pdf_names = sorted(f for f in os.listdir('.') if os.path.isfile(f) and f.endswith('.pdf'))\n",
        "\n",
        "if files is None:\n",
        "    print('google.colab is unavailable; PDF files in the working directory: ' + ', '.join(pdf_names))\n",
        "else:\n",
        "    for pdf_name in pdf_names:\n",
        "        files.download(pdf_name)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
