{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fa6c68ed",
   "metadata": {},
   "source": [
    "### Startup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4a17153",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:19.260604Z",
     "iopub.status.busy": "2024-05-02T03:04:19.259641Z",
     "iopub.status.idle": "2024-05-02T03:04:20.251166Z",
     "shell.execute_reply": "2024-05-02T03:04:20.250623Z"
    }
   },
   "outputs": [],
   "source": [
    "# `os` is used by later cells (os.getcwd()), so it must be imported unconditionally.\n",
    "# Importing it only inside the fallback branch below leaves it undefined whenever\n",
    "# `breaching` is importable on the first try.\n",
    "import os\n",
    "\n",
    "try:\n",
    "    import breaching\n",
    "except ModuleNotFoundError:\n",
    "    # You only really need this safety net if you want to run these notebooks directly in the examples directory\n",
    "    # Don't worry about this if you installed the package or moved the notebook to the main directory.\n",
    "    os.chdir(\"..\")\n",
    "    import breaching\n",
    "\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import random, json, copy\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Redirects logs directly into the jupyter notebook\n",
    "import logging, sys\n",
    "logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')\n",
    "logger = logging.getLogger()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "182060ad",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:20.253477Z",
     "iopub.status.busy": "2024-05-02T03:04:20.253312Z",
     "iopub.status.idle": "2024-05-02T03:04:20.264881Z",
     "shell.execute_reply": "2024-05-02T03:04:20.264534Z"
    }
   },
   "outputs": [],
   "source": [
    "# Directory the sensitivity files are read from, resolved against the current working directory.\n",
    "sensitivity_path = f\"{os.getcwd()}/classification/sensitivity/\"\n",
    "print(sensitivity_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be173c61",
   "metadata": {},
   "source": [
    "### Initialize cfg object and system setup:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7a73d88",
   "metadata": {},
   "source": [
    "This will load the full configuration object. This includes the configuration for the use case and threat model as `cfg.case` and the hyperparameters and implementation of the attack as `cfg.attack`. All parameters can be modified below, or overridden with `overrides=` as if they were cmd-line arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbd7dbeb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:20.266576Z",
     "iopub.status.busy": "2024-05-02T03:04:20.266413Z",
     "iopub.status.idle": "2024-05-02T03:04:21.530884Z",
     "shell.execute_reply": "2024-05-02T03:04:21.530576Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load the configuration with the case override applied.\n",
    "cfg = breaching.get_config(overrides=[\"case=6_large_batch_cifar\"])\n",
    "\n",
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda:0')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "\n",
    "# device/dtype keyword arguments reused by later cells when creating tensors.\n",
    "setup = {\"device\": device, \"dtype\": getattr(torch, cfg.case.impl.dtype)}\n",
    "setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e7f06c0",
   "metadata": {},
   "source": [
    "### Modify config options here"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fc4dd6f",
   "metadata": {},
   "source": [
    "You can use `.attribute` access to modify any of these configurations for the attack, or the case:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c40a59e2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:21.532592Z",
     "iopub.status.busy": "2024-05-02T03:04:21.532467Z",
     "iopub.status.idle": "2024-05-02T03:04:21.546057Z",
     "shell.execute_reply": "2024-05-02T03:04:21.545722Z"
    }
   },
   "outputs": [],
   "source": [
    "# Model used for this case.\n",
    "cfg.case.model = \"lenet100\"\n",
    "\n",
    "# Data partition and the user whose update is attacked.\n",
    "# cfg.case.data.partition = \"unique-class\"\n",
    "cfg.case.data.partition = \"balanced\"  # original note: 100 unique CIFAR-100 images\n",
    "cfg.case.user.user_idx = 1\n",
    "cfg.case.user.num_data_points = 1\n",
    "\n",
    "# provide_labels is switched off; the attack uses the \"yin\" label strategy instead\n",
    "# (original note: also works here, as labels are unique).\n",
    "cfg.case.user.provide_labels = False\n",
    "cfg.attack.label_strategy = \"yin\"\n",
    "\n",
    "# Smaller total variation regularization for CIFAR.\n",
    "# NOTE(review): the original comment named CIFAR-10 while the partition note above names CIFAR-100 - confirm the dataset.\n",
    "cfg.attack.regularization.total_variation.scale = 5e-4\n",
    "\n",
    "# Reduce max iteration to save time\n",
    "# cfg.attack.optim.max_iterations = 1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37ffd4f5",
   "metadata": {},
   "source": [
    "### Instantiate all parties"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df789fe6",
   "metadata": {},
   "source": [
    "The following lines generate \"server\", \"user\" and \"attacker\" objects and print an overview of their configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a29de024",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:21.547535Z",
     "iopub.status.busy": "2024-05-02T03:04:21.547434Z",
     "iopub.status.idle": "2024-05-02T03:04:24.275830Z",
     "shell.execute_reply": "2024-05-02T03:04:24.275473Z"
    }
   },
   "outputs": [],
   "source": [
    "user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)\n",
    "\n",
    "# Restore the saved weights. map_location remaps the stored tensors onto the device in use,\n",
    "# so a checkpoint written on a GPU machine still loads on a CPU-only one.\n",
    "model_path = \"./classification/model-lenet.pt\"\n",
    "# torch.save(model.state_dict(), model_path)\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "\n",
    "attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)\n",
    "breaching.utils.overview(server, user, attacker)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce331dfb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:24.277740Z",
     "iopub.status.busy": "2024-05-02T03:04:24.277404Z",
     "iopub.status.idle": "2024-05-02T03:04:24.846740Z",
     "shell.execute_reply": "2024-05-02T03:04:24.846280Z"
    }
   },
   "outputs": [],
   "source": [
    "server_payload = server.distribute_payload()\n",
    "shared_data, true_user_data = user.compute_local_updates(server_payload)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be9b55e0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:24.848753Z",
     "iopub.status.busy": "2024-05-02T03:04:24.848605Z",
     "iopub.status.idle": "2024-05-02T03:04:24.890630Z",
     "shell.execute_reply": "2024-05-02T03:04:24.890267Z"
    }
   },
   "outputs": [],
   "source": [
    "user.plot(true_user_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b42ee985",
   "metadata": {},
   "source": [
    "**Network Info**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "798b73fc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:24.892255Z",
     "iopub.status.busy": "2024-05-02T03:04:24.892129Z",
     "iopub.status.idle": "2024-05-02T03:04:24.904612Z",
     "shell.execute_reply": "2024-05-02T03:04:24.904267Z"
    }
   },
   "outputs": [],
   "source": [
    "# One entry per gradient tensor in the shared update.\n",
    "num_layer = len(shared_data['gradients'])\n",
    "print(f\"Number of Layers: {num_layer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "661b0792",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:24.906157Z",
     "iopub.status.busy": "2024-05-02T03:04:24.905957Z",
     "iopub.status.idle": "2024-05-02T03:04:24.918163Z",
     "shell.execute_reply": "2024-05-02T03:04:24.917811Z"
    }
   },
   "outputs": [],
   "source": [
    "# Element count of every gradient tensor, and the total across all of them.\n",
    "num_param_list = [layer.numel() for layer in shared_data['gradients']]\n",
    "num_param = sum(num_param_list)\n",
    "print(f\"Number of Parameters: {num_param}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e02479c7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:24.919639Z",
     "iopub.status.busy": "2024-05-02T03:04:24.919508Z",
     "iopub.status.idle": "2024-05-02T03:04:24.929136Z",
     "shell.execute_reply": "2024-05-02T03:04:24.928805Z"
    }
   },
   "outputs": [],
   "source": [
    "num_param_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a00091de",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:24.930575Z",
     "iopub.status.busy": "2024-05-02T03:04:24.930439Z",
     "iopub.status.idle": "2024-05-02T03:04:24.947109Z",
     "shell.execute_reply": "2024-05-02T03:04:24.946749Z"
    }
   },
   "outputs": [],
   "source": [
    "from sewar.full_ref import msssim, uqi, vifp\n",
    "\n",
    "def get_sim_metrics(true_data, reconstructed_data):\n",
    "    \"\"\"Compute full-reference similarity scores between a true and a reconstructed image.\n",
    "\n",
    "    Both arguments are mappings whose \"data\" entry is an image tensor. Each tensor is\n",
    "    min-max scaled over dims 2 and 3, squeezed, moved to channel-last order and converted\n",
    "    to a NumPy array before being passed to sewar.\n",
    "    Assumes \"data\" has shape (1, C, H, W) with C > 1 - TODO confirm against the caller.\n",
    "\n",
    "    Returns a dict with the keys \"msssim\", \"uqi\" and \"vifp\".\n",
    "    \"\"\"\n",
    "    def scale_tensor(d):\n",
    "        min_val = d.amin(dim=[2, 3], keepdim=True)\n",
    "        value_range = d.amax(dim=[2, 3], keepdim=True) - min_val\n",
    "        # A constant channel has a zero range; dividing by it would produce NaNs that\n",
    "        # poison every metric. Dividing by 1 instead maps such a channel to all zeros.\n",
    "        value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)\n",
    "        return (d - min_val) / value_range\n",
    "\n",
    "    def to_image_array(d):\n",
    "        # (1, C, H, W) -> (H, W, C) NumPy array on the CPU.\n",
    "        return scale_tensor(d).squeeze().permute(1, 2, 0).cpu().numpy()\n",
    "\n",
    "    true_data_scaled = to_image_array(true_data[\"data\"])\n",
    "    reconstructed_data_scaled = to_image_array(reconstructed_data[\"data\"])\n",
    "\n",
    "    metrics = {\n",
    "        \"msssim\": msssim(true_data_scaled, reconstructed_data_scaled, MAX=1.0),\n",
    "        \"uqi\": uqi(true_data_scaled, reconstructed_data_scaled),\n",
    "        \"vifp\": vifp(true_data_scaled, reconstructed_data_scaled)\n",
    "    }\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff196fd1",
   "metadata": {},
   "source": [
    "## DP Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "047a60f5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:24.948695Z",
     "iopub.status.busy": "2024-05-02T03:04:24.948534Z",
     "iopub.status.idle": "2024-05-02T03:04:24.959219Z",
     "shell.execute_reply": "2024-05-02T03:04:24.958880Z"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "def add_laplace_noise(grad_list: list, b: float, iter: int):\n",
    "    \"\"\"Return a copy of grad_list with Laplace(0, b) noise added to every entry.\n",
    "\n",
    "    The result is cached under ./classification/noised_grad/. If a cache file for this\n",
    "    model name, b, user index and iter already exists, it is loaded and returned and\n",
    "    grad_list is ignored.\n",
    "    NOTE(review): the cache key does not depend on grad_list itself, so a stale file is\n",
    "    returned if the underlying gradient changes - delete the cache when that happens.\n",
    "\n",
    "    @grad_list: list of gradient tensors\n",
    "    @b: scale of the Laplace distribution\n",
    "    @iter: repetition index, only used to tell cache files apart\n",
    "    \"\"\"\n",
    "    file_name = f\"./classification/noised_grad/{cfg.case.model}_dp_{b}_user_{cfg.case.user.user_idx}_iter_{iter}.pt\"\n",
    "    file_path = Path(file_name)\n",
    "    if file_path.exists():\n",
    "        print(f\"Found existing noised gradient for b = {b} and iter = {iter}.\")\n",
    "        return torch.load(file_name, map_location=device)\n",
    "\n",
    "    print(f\"Generating new noised gradient for b = {b} and iter = {iter}.\")\n",
    "    loc = torch.as_tensor(0.0, **setup)\n",
    "    scale = torch.as_tensor(b, **setup)\n",
    "    noise_generator = torch.distributions.laplace.Laplace(loc=loc, scale=scale)\n",
    "\n",
    "    new_grad_list = [grad + noise_generator.sample(grad.shape) for grad in grad_list]\n",
    "\n",
    "    # torch.save does not create missing directories, so make sure the cache folder exists.\n",
    "    file_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "    torch.save(new_grad_list, file_name)\n",
    "    return new_grad_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a681327",
   "metadata": {},
   "source": [
    "## No Encryption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bac3bf39",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:24.960792Z",
     "iopub.status.busy": "2024-05-02T03:04:24.960587Z",
     "iopub.status.idle": "2024-05-02T03:04:31.646360Z",
     "shell.execute_reply": "2024-05-02T03:04:31.645952Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create the output directory before the sweep starts: torch.save does not create missing\n",
    "# directories, and failing at the save step would throw away all reconstructions for that b.\n",
    "results_dir = Path(\"./classification/results/dp\")\n",
    "results_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# For every noise scale b: 3 noise draws, 10 reconstruction attempts per draw.\n",
    "for b in [0.001, 0.01, 0.03, 0.05, 0.07, 0.09, 0.11, 0.13, 0.15, 0.17, 0.19, 0.21, 0.23, 0.25, 0.27]:\n",
    "    results_no_list = []\n",
    "\n",
    "    for i in tqdm(range(3), desc=f\"Adding Laplace Noise with b = {b}\"):\n",
    "        share_data_copy = copy.deepcopy(shared_data)\n",
    "        share_data_copy['gradients'] = add_laplace_noise(share_data_copy['gradients'], b, i)\n",
    "\n",
    "        results_no_dict = {\"mse\": [], \"msssim\": [], \"uqi\": [], \"vifp\": []}\n",
    "\n",
    "        for _ in tqdm(range(10), desc=f\"b = {b}\"):\n",
    "            # NOTE(review): this call passes {} as the secrets argument while the selective and\n",
    "            # random encryption helpers pass server.secrets - confirm the difference is intended.\n",
    "            reconstructed_user_data, stats = attacker.reconstruct([server_payload], [share_data_copy], {}, dryrun=cfg.dryrun)\n",
    "\n",
    "            metrics = breaching.analysis.report(reconstructed_user_data,\n",
    "                                                true_user_data, [server_payload],\n",
    "                                                server.model, order_batch=True, compute_full_iip=False,\n",
    "                                                cfg_case=cfg.case, setup=setup)\n",
    "            sim_score = get_sim_metrics(true_user_data, reconstructed_user_data)\n",
    "            results_dict_temp = {\n",
    "                \"mse\": metrics[\"mse\"],\n",
    "                \"msssim\": sim_score[\"msssim\"],\n",
    "                \"uqi\": sim_score[\"uqi\"],\n",
    "                \"vifp\": sim_score[\"vifp\"]\n",
    "            }\n",
    "            for k, v in results_dict_temp.items():\n",
    "                results_no_dict[k].append(v)\n",
    "                print(f\"{k}: {v}\")\n",
    "            user.plot(reconstructed_user_data)\n",
    "        results_no_list.append(results_no_dict)\n",
    "    torch.save(results_no_list, results_dir / f\"dp_{b}_lenet_no_protection.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a7c688e",
   "metadata": {},
   "source": [
    "## Selective Encryption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a61e6e05",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:31.648194Z",
     "iopub.status.busy": "2024-05-02T03:04:31.648085Z",
     "iopub.status.idle": "2024-05-02T03:04:31.664805Z",
     "shell.execute_reply": "2024-05-02T03:04:31.664456Z"
    }
   },
   "outputs": [],
   "source": [
    "def flat_tensor_list(grads):\n",
    "    \"\"\"Concatenate a list of tensors into one 1-D tensor, in list order.\n",
    "\n",
    "    Returns an empty tensor when grads is empty.\n",
    "    \"\"\"\n",
    "    if len(grads) == 0:\n",
    "        return torch.tensor([])\n",
    "    # A single torch.cat avoids re-copying the growing result once per tensor, and\n",
    "    # reshape(-1), unlike view(-1), also accepts non-contiguous tensors.\n",
    "    return torch.cat([g.reshape(-1) for g in grads], dim=0)\n",
    "\n",
    "def encrypt_by_index(grad_list: list, idx_list, if_plot=False) -> list:\n",
    "    \"\"\"\n",
    "    Simulate encryption of selected gradient entries by overwriting them with uniform noise.\n",
    "\n",
    "    @grad_list: gradient list to be encrypted\n",
    "    @idx_list: indices of the encrypted parameters, counted in the flattened and concatenated\n",
    "               gradient; a list of ints or an integer tensor\n",
    "    @if_plot: whether to plot the encrypted ratio of each layer\n",
    "    @return: new list of tensors with the same shapes as the entries of grad_list\n",
    "    \"\"\"\n",
    "    if len(grad_list) == 0:\n",
    "        return []\n",
    "    grad_flat = flat_tensor_list(grad_list)\n",
    "\n",
    "    # True = value is kept, False = value is treated as encrypted.\n",
    "    mask = torch.ones(grad_flat.shape, dtype=torch.bool, device=grad_flat.device)\n",
    "    encrypted_idx = torch.as_tensor(idx_list, dtype=torch.long, device=grad_flat.device)\n",
    "    mask[encrypted_idx] = False\n",
    "\n",
    "    # Apply the encryption mask: encrypted entries become samples from U[0, 1).\n",
    "    # torch.where (rather than multiplying by the mask) keeps a NaN/inf sitting in an\n",
    "    # encrypted slot from leaking into the result.\n",
    "    noise = torch.rand(grad_flat.shape).to(device=grad_flat.device, dtype=grad_flat.dtype)\n",
    "    grad_flat = torch.where(mask, grad_flat, noise)\n",
    "\n",
    "    # Reshape back\n",
    "    layer_sizes = [layer_grad.numel() for layer_grad in grad_list]\n",
    "    new_grad_list = [\n",
    "        grad_slice.reshape(layer_grad.shape)\n",
    "        for grad_slice, layer_grad in zip(grad_flat.split(layer_sizes), grad_list)\n",
    "    ]\n",
    "\n",
    "    if if_plot:\n",
    "        # Fraction of encrypted parameters in each layer; only needed for the plot.\n",
    "        ratio_enc_per_layer = [\n",
    "            ((size - mask_slice.sum()) / size).item()\n",
    "            for mask_slice, size in zip(mask.split(layer_sizes), layer_sizes)\n",
    "        ]\n",
    "        plt.bar(np.arange(1, len(ratio_enc_per_layer)+1), ratio_enc_per_layer)\n",
    "        plt.xlabel(\"Layer Index\")\n",
    "        plt.ylabel(\"Encryption Ratio\")\n",
    "        plt.show()\n",
    "    return new_grad_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf4f7129",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:31.666512Z",
     "iopub.status.busy": "2024-05-02T03:04:31.666411Z",
     "iopub.status.idle": "2024-05-02T03:04:31.682645Z",
     "shell.execute_reply": "2024-05-02T03:04:31.682271Z"
    }
   },
   "outputs": [],
   "source": [
    "def encrypt_selective(sens_list, ratio_enc, b, iter, num_repeat=1):\n",
    "    \"\"\"Attack a noised gradient whose highest-sensitivity entries are masked as encrypted.\n",
    "\n",
    "    @sens_list: list of sensitivity tensors; the entries with the largest absolute values are encrypted\n",
    "    @ratio_enc: fraction of the num_param parameters to encrypt\n",
    "    @b: Laplace noise scale passed to add_laplace_noise\n",
    "    @iter: noise repetition index passed to add_laplace_noise\n",
    "    @num_repeat: number of reconstruction attempts\n",
    "    @return: dict mapping \"mse\", \"msssim\", \"uqi\" and \"vifp\" to lists with one value per attempt\n",
    "    \"\"\"\n",
    "    # Indices of the entries with the largest absolute sensitivity.\n",
    "    flat_sensitivity = flat_tensor_list(sens_list).abs()\n",
    "    encrypt_list = torch.topk(flat_sensitivity, int(ratio_enc * num_param), largest=True).indices\n",
    "\n",
    "    # Noise first, then overwrite the selected entries.\n",
    "    noised_gradient = add_laplace_noise(copy.deepcopy(shared_data['gradients']), b, iter)\n",
    "    gradient_with_he = list(encrypt_by_index(noised_gradient, encrypt_list))\n",
    "\n",
    "    # Same entries as shared_data, with only the gradients swapped out.\n",
    "    new_shared_data = {\n",
    "        key: gradient_with_he if key == \"gradients\" else shared_data[key]\n",
    "        for key in shared_data.keys()\n",
    "    }\n",
    "\n",
    "    metric_names = (\"mse\", \"msssim\", \"uqi\", \"vifp\")\n",
    "    results_dict = {name: [] for name in metric_names}\n",
    "    for _ in tqdm(range(num_repeat), desc=f\"p={ratio_enc}\"):\n",
    "        new_reconstructed_user_data, stats = attacker.reconstruct(\n",
    "            [server_payload],\n",
    "            [copy.deepcopy(new_shared_data)],\n",
    "            server.secrets,\n",
    "            dryrun=cfg.dryrun,\n",
    "        )\n",
    "        metrics = breaching.analysis.report(\n",
    "            new_reconstructed_user_data,\n",
    "            true_user_data,\n",
    "            [server_payload],\n",
    "            server.model,\n",
    "            order_batch=True,\n",
    "            compute_full_iip=False,\n",
    "            cfg_case=cfg.case,\n",
    "            setup=setup,\n",
    "        )\n",
    "        sim_score = get_sim_metrics(true_user_data, new_reconstructed_user_data)\n",
    "        metric_values = (metrics[\"mse\"], sim_score[\"msssim\"], sim_score[\"uqi\"], sim_score[\"vifp\"])\n",
    "\n",
    "        print(f\"Encrypted the {ratio_enc * 100:.2f}% parameters\")\n",
    "        for k, v in zip(metric_names, metric_values):\n",
    "            results_dict[k].append(v)\n",
    "            print(f\"{k}: {v}\")\n",
    "        user.plot(new_reconstructed_user_data)\n",
    "    return results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bdf20d1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:31.684202Z",
     "iopub.status.busy": "2024-05-02T03:04:31.684067Z",
     "iopub.status.idle": "2024-05-02T03:04:37.152561Z",
     "shell.execute_reply": "2024-05-02T03:04:37.152181Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load sensitivity. map_location remaps the stored tensors onto the device in use, so a file\n",
    "# written on a GPU machine still loads on a CPU-only one.\n",
    "sens_path = sensitivity_path + \"lenet_mean_sens.pt\"\n",
    "sens_list = torch.load(sens_path, map_location=device)\n",
    "\n",
    "# Create the output directory before the sweep starts: torch.save does not create missing\n",
    "# directories, and failing at the save step would throw away all reconstructions for that b.\n",
    "results_dir = Path(\"./classification/results/dp\")\n",
    "results_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Fraction of parameters to encrypt; the same for every noise scale.\n",
    "ratio = 0.005\n",
    "\n",
    "# For every noise scale b: 3 noise draws, 10 reconstruction attempts per draw.\n",
    "for b in [0.001, 0.01, 0.03, 0.05, 0.07, 0.09, 0.11, 0.13, 0.15, 0.17, 0.19, 0.21, 0.23, 0.25, 0.27]:\n",
    "    results_select_list = []\n",
    "    for i in tqdm(range(3), desc=f\"Adding Laplace Noise with b = {b}\"):\n",
    "        results_select_dict = encrypt_selective(sens_list,\n",
    "                                                ratio_enc=ratio,\n",
    "                                                b=b,\n",
    "                                                iter=i,\n",
    "                                                num_repeat=10)\n",
    "        results_select_list.append(results_select_dict)\n",
    "    torch.save(results_select_list, results_dir / f\"dp_{b}_lenet_selective_{ratio}.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e31a435e",
   "metadata": {},
   "source": [
    "## Random Encryption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4bbd7b1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:37.154413Z",
     "iopub.status.busy": "2024-05-02T03:04:37.154300Z",
     "iopub.status.idle": "2024-05-02T03:04:37.171419Z",
     "shell.execute_reply": "2024-05-02T03:04:37.171066Z"
    }
   },
   "outputs": [],
   "source": [
    "# Random encryption by parameters (modified):\n",
    "\n",
    "def encrypt_random(ratio_enc, b, iter, num_repeat=1):\n",
    "    \"\"\"Attack a noised gradient in which a random subset of entries is masked as encrypted.\n",
    "\n",
    "    @ratio_enc: fraction of the num_param parameters to encrypt\n",
    "    @b: Laplace noise scale passed to add_laplace_noise\n",
    "    @iter: noise repetition index passed to add_laplace_noise\n",
    "    @num_repeat: number of reconstruction attempts\n",
    "    @return: dict mapping \"mse\", \"msssim\", \"uqi\" and \"vifp\" to lists with one value per attempt\n",
    "    \"\"\"\n",
    "    # Randomly pick a proportion to encrypt\n",
    "    num_protected = int(ratio_enc * num_param)\n",
    "    protected_params = random.sample(range(num_param), num_protected)\n",
    "\n",
    "    # Noise first, then overwrite the sampled entries.\n",
    "    noised_gradient = add_laplace_noise(copy.deepcopy(shared_data['gradients']), b, iter)\n",
    "    gradient_with_he = list(encrypt_by_index(noised_gradient, protected_params))\n",
    "\n",
    "    # Same entries as shared_data, with only the gradients swapped out.\n",
    "    new_shared_data = {\n",
    "        key: gradient_with_he if key == \"gradients\" else shared_data[key]\n",
    "        for key in shared_data.keys()\n",
    "    }\n",
    "\n",
    "    metric_names = (\"mse\", \"msssim\", \"uqi\", \"vifp\")\n",
    "    results_dict = {name: [] for name in metric_names}\n",
    "    for _ in tqdm(range(num_repeat), desc=f\"p={ratio_enc}, b={b}\"):\n",
    "        new_reconstructed_user_data, stats = attacker.reconstruct(\n",
    "            [server_payload],\n",
    "            [copy.deepcopy(new_shared_data)],\n",
    "            server.secrets,\n",
    "            dryrun=cfg.dryrun,\n",
    "        )\n",
    "        metrics = breaching.analysis.report(\n",
    "            new_reconstructed_user_data,\n",
    "            true_user_data,\n",
    "            [server_payload],\n",
    "            server.model,\n",
    "            order_batch=True,\n",
    "            compute_full_iip=False,\n",
    "            cfg_case=cfg.case,\n",
    "            setup=setup,\n",
    "        )\n",
    "        sim_score = get_sim_metrics(true_user_data, new_reconstructed_user_data)\n",
    "        metric_values = (metrics[\"mse\"], sim_score[\"msssim\"], sim_score[\"uqi\"], sim_score[\"vifp\"])\n",
    "\n",
    "        print(f\"Randomly encrypted {ratio_enc * 100}% parameters\")\n",
    "        for k, v in zip(metric_names, metric_values):\n",
    "            results_dict[k].append(v)\n",
    "            print(f\"{k}: {v}\")\n",
    "        user.plot(new_reconstructed_user_data)\n",
    "    return results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dee4e38d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-02T03:04:37.173012Z",
     "iopub.status.busy": "2024-05-02T03:04:37.172906Z",
     "iopub.status.idle": "2024-05-02T03:04:42.693530Z",
     "shell.execute_reply": "2024-05-02T03:04:42.693217Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create the output directory before the sweep starts: torch.save does not create missing\n",
    "# directories, and failing at the save step would throw away all reconstructions for that b.\n",
    "results_dir = Path(\"./classification/results/dp\")\n",
    "results_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Fraction of parameters to encrypt; the same for every noise scale.\n",
    "ratio = 0.005\n",
    "\n",
    "# For every noise scale b: 3 noise draws, 5 random index selections per draw,\n",
    "# 10 reconstruction attempts per selection.\n",
    "for b in [0.001, 0.01, 0.03, 0.05, 0.07, 0.09, 0.11, 0.13, 0.15, 0.17, 0.19, 0.21, 0.23, 0.25, 0.27]:\n",
    "    results_random_list = []\n",
    "    for i in tqdm(range(3), desc=f\"Adding Laplace Noise with b = {b}\"):\n",
    "        results_random_list_b = []\n",
    "        for _ in tqdm(range(5), desc=\"Testing random init...\"):\n",
    "            results_dict = encrypt_random(ratio_enc=ratio, b=b, iter=i, num_repeat=10)\n",
    "            results_random_list_b.append(results_dict)\n",
    "        results_random_list.append(results_random_list_b)\n",
    "    torch.save(results_random_list, results_dir / f\"dp_{b}_lenet_random_{ratio}.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
