{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "90dcd6cb",
   "metadata": {},
   "source": [
    "### Startup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b850eabf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library. `os` is imported unconditionally because later cells call\n",
    "# os.getcwd(); importing it only in the fallback branch below left it undefined\n",
    "# whenever `breaching` was importable on the first attempt.\n",
    "import copy\n",
    "import json\n",
    "import logging\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "\n",
    "try:\n",
    "    import breaching\n",
    "except ModuleNotFoundError:\n",
    "    # You only really need this safety net if you want to run these notebooks directly in the examples directory\n",
    "    # Don't worry about this if you installed the package or moved the notebook to the main directory.\n",
    "    os.chdir(\"..\")\n",
    "    import breaching\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Redirects logs directly into the jupyter notebook\n",
    "logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')\n",
    "logger = logging.getLogger()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a1e902f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory with the stored sensitivity files, resolved against the current working directory\n",
    "sensitivity_path = f\"{os.getcwd()}/classification/sensitivity/\"\n",
    "print(sensitivity_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88d5e214",
   "metadata": {},
   "source": [
    "### Initialize cfg object and system setup:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56bd663b",
   "metadata": {},
   "source": [
    "This will load the full configuration object. This includes the configuration for the use case and threat model as `cfg.case` and the hyperparameters and implementation of the attack as `cfg.attack`. All parameters can be modified below, or overridden with `overrides=` as if they were cmd-line arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7dc3a48",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = breaching.get_config(overrides=[\"case=6_large_batch_cifar\"])\n",
    "\n",
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "\n",
    "# Device and dtype settings handed to the case construction, the attack and the report below\n",
    "setup = {\"device\": device, \"dtype\": getattr(torch, cfg.case.impl.dtype)}\n",
    "setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "203c5fb1",
   "metadata": {},
   "source": [
    "### Modify config options here"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e0764ef",
   "metadata": {},
   "source": [
    "You can use `.attribute` access to modify any of these configurations for the attack, or the case:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac118ea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Case: user 0 holds a single data point from the unique-class partition\n",
    "cfg.case.data.partition = \"unique-class\"  # 100 unique CIFAR-100 images\n",
    "cfg.case.user.num_data_points = 1\n",
    "cfg.case.user.user_idx = 0\n",
    "cfg.case.model = \"lenet100\"\n",
    "\n",
    "# Labels are not provided by the user, so the attack is given a label strategy\n",
    "cfg.case.user.provide_labels = False\n",
    "cfg.attack.label_strategy = \"yin\"  # also works here, as labels are unique\n",
    "\n",
    "# Total variation regularization needs to be smaller on CIFAR-10:\n",
    "# NOTE(review): the partition comment above mentions CIFAR-100 - confirm which dataset is meant\n",
    "cfg.attack.regularization.total_variation.scale = 5e-4\n",
    "\n",
    "# Reduce max iteration to save time\n",
    "# cfg.attack.optim.max_iterations = 100\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76f64389",
   "metadata": {},
   "source": [
    "### Instantiate all parties"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8db2272f",
   "metadata": {},
   "source": [
    "The following lines generate \"server\", \"user\" and \"attacker\" objects and print an overview of their configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3abd955",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct user, server, model and loss for the configured case; `model_path` is forwarded as-is\n",
    "model_path = \"./classification/model-lenet.pt\"\n",
    "user, server, model, loss_fn = breaching.cases.construct_case(\n",
    "    cfg.case, setup, model_path=model_path\n",
    ")\n",
    "\n",
    "# The attacker is prepared from the server's model and loss together with the attack config\n",
    "attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)\n",
    "breaching.utils.overview(server, user, attacker)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0dbd868",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the server payload and let the user compute its local update from it.\n",
    "# `shared_data` is what the attacks below receive; `true_user_data` is used for plotting and evaluation.\n",
    "server_payload = server.distribute_payload()\n",
    "shared_data, true_user_data = user.compute_local_updates(server_payload)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49c68628",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the user's true data as the reference for the reconstructions below\n",
    "user.plot(true_user_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d30b926",
   "metadata": {},
   "source": [
    "**Network Info**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40d84231",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counts the entries of the shared gradient list (one tensor per entry)\n",
    "num_layer = len(shared_data['gradients'])\n",
    "print(f\"Number of Layers: {num_layer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb325a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element count of every shared gradient tensor, and the total over all of them\n",
    "num_param_list = [layer.numel() for layer in shared_data['gradients']]\n",
    "num_param = sum(num_param_list)\n",
    "print(f\"Number of Parameters: {num_param}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23f346d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element count of each gradient tensor, in list order\n",
    "num_param_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fabb81bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sewar.full_ref import msssim, uqi, vifp\n",
    "\n",
    "def get_sim_metrics(true_data, reconstructed_data):\n",
    "    \"\"\"Score a reconstruction against the true image with sewar's MS-SSIM, UQI and VIFP.\n",
    "\n",
    "    Args:\n",
    "        true_data: dict whose \"data\" entry is the ground-truth image tensor.\n",
    "        reconstructed_data: dict whose \"data\" entry is the reconstructed image tensor.\n",
    "\n",
    "    Both tensors are min-max scaled over dims 2 and 3, stripped of dim 0 and permuted\n",
    "    with (1, 2, 0), so they are presumably single-image batches of shape (1, C, H, W).\n",
    "\n",
    "    Returns:\n",
    "        dict with the keys \"msssim\", \"uqi\" and \"vifp\".\n",
    "    \"\"\"\n",
    "    def scale_tensor(d):\n",
    "        min_val, max_val = d.amin(dim=[2, 3], keepdim=True), d.amax(dim=[2, 3], keepdim=True)\n",
    "        value_range = max_val - min_val\n",
    "        # A constant slice has max == min; divide by 1 there so it maps to 0 instead of NaN (0 / 0)\n",
    "        value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))\n",
    "        return (d - min_val) / value_range\n",
    "\n",
    "    def to_image_array(d):\n",
    "        # squeeze(0) removes only dim 0 (a bare squeeze() would also drop a size-1 channel);\n",
    "        # detach() lets tensors that still require grad be converted to numpy\n",
    "        return scale_tensor(d).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()\n",
    "\n",
    "    true_data_scaled = to_image_array(true_data[\"data\"])\n",
    "    reconstructed_data_scaled = to_image_array(reconstructed_data[\"data\"])\n",
    "\n",
    "    metrics = {\n",
    "        \"msssim\": msssim(true_data_scaled, reconstructed_data_scaled, MAX=1.0),\n",
    "        \"uqi\": uqi(true_data_scaled, reconstructed_data_scaled),\n",
    "        \"vifp\": vifp(true_data_scaled, reconstructed_data_scaled)\n",
    "    }\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15d696e1",
   "metadata": {},
   "source": [
    "## No Encryption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f2685a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline without protection: attack the unmodified shared data 10 times and collect the scores\n",
    "results_no_dict = {\"mse\": [], \"msssim\": [], \"uqi\": [], \"vifp\": []}\n",
    "for _ in tqdm(range(10)):\n",
    "    reconstructed_user_data, stats = attacker.reconstruct(\n",
    "        [server_payload], [shared_data], {}, dryrun=cfg.dryrun\n",
    "    )\n",
    "    metrics = breaching.analysis.report(\n",
    "        reconstructed_user_data,\n",
    "        true_user_data,\n",
    "        [server_payload],\n",
    "        server.model,\n",
    "        order_batch=True,\n",
    "        compute_full_iip=False,\n",
    "        cfg_case=cfg.case,\n",
    "        setup=setup,\n",
    "    )\n",
    "    sim_score = get_sim_metrics(true_user_data, reconstructed_user_data)\n",
    "\n",
    "    # MSE comes from the breaching report, the other three scores from get_sim_metrics\n",
    "    run_scores = {\"mse\": metrics[\"mse\"]}\n",
    "    run_scores.update({name: sim_score[name] for name in (\"msssim\", \"uqi\", \"vifp\")})\n",
    "    for name, value in run_scores.items():\n",
    "        results_no_dict[name].append(value)\n",
    "        print(f\"{name}: {value}\")\n",
    "    user.plot(reconstructed_user_data)\n",
    "\n",
    "torch.save(results_no_dict, \"./classification/results/lenet_no_protection.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e1ec71e",
   "metadata": {},
   "source": [
    "## Selective Encryption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7705614b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flat_tensor_list(grads):\n",
    "    \"\"\"Flatten a list of tensors into a single 1-D tensor, keeping the list order.\n",
    "\n",
    "    Args:\n",
    "        grads: list of tensors (e.g. per-layer gradients); entries may be non-contiguous.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor with all elements of `grads` in order, or an empty tensor if `grads` is empty.\n",
    "    \"\"\"\n",
    "    if len(grads) == 0:\n",
    "        return torch.tensor([])\n",
    "    # reshape(-1) also accepts non-contiguous tensors (view(-1) raises on those), and one\n",
    "    # cat over the whole list avoids re-copying the growing result once per tensor\n",
    "    return torch.cat([g.reshape(-1) for g in grads], dim=0)\n",
    "\n",
    "def encrypt_by_index(grad_list: list, idx_list, if_plot=False) -> list:\n",
    "    \"\"\"\n",
    "    Simulate encryption of selected gradient entries by overwriting them with random noise.\n",
    "\n",
    "    @grad_list: gradient list to be encrypted (list of tensors); it is not modified\n",
    "    @idx_list: list (or 1-D integer tensor) of encrypted parameter indices, counted in the\n",
    "               flattened concatenation of `grad_list`\n",
    "    @if_plot: whether to plot the encrypted ratio of each layer\n",
    "    @return: new list of tensors shaped like `grad_list`; the selected entries hold\n",
    "             uniform noise from [0, 1), all other entries keep their value\n",
    "    \"\"\"\n",
    "    grad_flat = flat_tensor_list(grad_list)\n",
    "\n",
    "    # Boolean keep-mask over the flattened gradients: True = plaintext, False = encrypted\n",
    "    idx = torch.as_tensor(idx_list, dtype=torch.long, device=grad_flat.device)\n",
    "    mask = torch.ones(grad_flat.shape, dtype=torch.bool, device=grad_flat.device)\n",
    "    mask[idx] = False\n",
    "\n",
    "    # Apply the encryption mask. torch.where selects values instead of multiplying by the mask,\n",
    "    # so a non-finite gradient at an encrypted position cannot turn into NaN (inf * 0)\n",
    "    grad_flat = torch.where(mask, grad_flat, torch.rand_like(grad_flat))\n",
    "\n",
    "    # Reshape back\n",
    "    new_grad_list = []\n",
    "    ratio_enc_per_layer = []\n",
    "    start_idx = 0\n",
    "    for layer_grad in grad_list:\n",
    "        layer_param_num = layer_grad.numel()\n",
    "        end_idx = start_idx + layer_param_num\n",
    "        new_grad_list.append(grad_flat[start_idx:end_idx].reshape(layer_grad.shape))\n",
    "\n",
    "        if if_plot:\n",
    "            # Record the ratio of encrypted parameters of each layer (only needed for the plot)\n",
    "            num_encrypted = (~mask[start_idx:end_idx]).sum().item()\n",
    "            ratio_enc_per_layer.append(num_encrypted / max(layer_param_num, 1))\n",
    "\n",
    "        start_idx = end_idx\n",
    "\n",
    "    if if_plot:\n",
    "        plt.bar(np.arange(1, len(ratio_enc_per_layer)+1), ratio_enc_per_layer)\n",
    "        plt.xlabel(\"Layer Index\")\n",
    "        plt.ylabel(\"Encryption Ratio\")\n",
    "        plt.show()\n",
    "    return new_grad_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b538e0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encrypt_selective(sens_list, ratio_enc, num_repeat=1):\n",
    "    \"\"\"Attack gradients whose most sensitive entries are encrypted and score each reconstruction.\n",
    "\n",
    "    The int(ratio_enc * num_param) entries with the largest absolute sensitivity are\n",
    "    selected once; the encrypt -> reconstruct -> report cycle then runs `num_repeat` times.\n",
    "    Uses the notebook globals num_param, shared_data, server_payload, server, attacker,\n",
    "    cfg, setup, true_user_data and user.\n",
    "\n",
    "    Returns a dict mapping \"mse\", \"msssim\", \"uqi\" and \"vifp\" to lists with one value per repeat.\n",
    "    \"\"\"\n",
    "    num_enc = int(ratio_enc * num_param)\n",
    "    abs_sensitivity = flat_tensor_list(sens_list).abs()\n",
    "    encrypt_list = torch.topk(abs_sensitivity, num_enc, largest=True).indices\n",
    "\n",
    "    results_dict = {\"mse\": [], \"msssim\": [], \"uqi\": [], \"vifp\": []}\n",
    "    for _ in tqdm(range(num_repeat), desc=f\"p={ratio_enc}\"):\n",
    "        gradient_copy = copy.deepcopy(shared_data['gradients'])\n",
    "        gradient_with_he = list(encrypt_by_index(gradient_copy, encrypt_list))\n",
    "\n",
    "        # Same shared data as before, with only the gradients entry swapped out\n",
    "        new_shared_data = {\n",
    "            key: (gradient_with_he if key == \"gradients\" else value)\n",
    "            for key, value in shared_data.items()\n",
    "        }\n",
    "\n",
    "        new_reconstructed_user_data, stats = attacker.reconstruct(\n",
    "            [server_payload], [new_shared_data], server.secrets, dryrun=cfg.dryrun\n",
    "        )\n",
    "        metrics = breaching.analysis.report(\n",
    "            new_reconstructed_user_data,\n",
    "            true_user_data,\n",
    "            [server_payload],\n",
    "            server.model,\n",
    "            order_batch=True,\n",
    "            compute_full_iip=False,\n",
    "            cfg_case=cfg.case,\n",
    "            setup=setup,\n",
    "        )\n",
    "        sim_score = get_sim_metrics(true_user_data, new_reconstructed_user_data)\n",
    "\n",
    "        # MSE comes from the breaching report, the other three scores from get_sim_metrics\n",
    "        run_scores = {\"mse\": metrics[\"mse\"]}\n",
    "        run_scores.update({name: sim_score[name] for name in (\"msssim\", \"uqi\", \"vifp\")})\n",
    "\n",
    "        print(f\"Encrypted the {ratio_enc * 100:.2f}% parameters\")\n",
    "        for name, value in run_scores.items():\n",
    "            results_dict[name].append(value)\n",
    "            print(f\"{name}: {value}\")\n",
    "        user.plot(new_reconstructed_user_data)\n",
    "    return results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a639d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "sens_path = sensitivity_path + \"lenet_x_mean_sens.pt\"\n",
    "# map_location places the loaded tensors on the device used in this session, so a file\n",
    "# saved from a GPU run can still be loaded when only the CPU is available\n",
    "sens_list = torch.load(sens_path, map_location=device)\n",
    "\n",
    "# Run the selective-encryption attack for each ratio and store one result file per ratio\n",
    "for ratio in [0.01, 0.03, 0.05, 0.07, 0.09, 0.11, 0.13]:\n",
    "    results_select_dict = encrypt_selective(sens_list, \n",
    "                                            ratio_enc=ratio,\n",
    "                                            num_repeat=10)\n",
    "    torch.save(results_select_dict,\n",
    "               f\"./classification/results/lenet_selective_{ratio}.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bd3e885",
   "metadata": {},
   "source": [
    "## Random Encryption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b8bb337",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random encryption by parameters (modified):\n",
    "\n",
    "def encrypt_random(ratio_enc, num_repeat=1):\n",
    "    \"\"\"Attack gradients with a random subset of entries encrypted and score each reconstruction.\n",
    "\n",
    "    int(ratio_enc * num_param) parameter indices are drawn once with random.sample; the\n",
    "    encrypt -> reconstruct -> report cycle then runs `num_repeat` times on that same subset.\n",
    "    Uses the notebook globals num_param, shared_data, server_payload, server, attacker,\n",
    "    cfg, setup, true_user_data and user.\n",
    "\n",
    "    Returns a dict mapping \"mse\", \"msssim\", \"uqi\" and \"vifp\" to lists with one value per repeat.\n",
    "    \"\"\"\n",
    "    # Randomly pick a proportion to encrypt\n",
    "    protected_params = random.sample(range(num_param), int(ratio_enc * num_param))\n",
    "\n",
    "    results_dict = {\"mse\": [], \"msssim\": [], \"uqi\": [], \"vifp\": []}\n",
    "    for _ in tqdm(range(num_repeat), desc=f\"p={ratio_enc}\"):\n",
    "        new_gradient = copy.deepcopy(shared_data['gradients'])\n",
    "        gradient_with_he = list(encrypt_by_index(new_gradient, protected_params))\n",
    "\n",
    "        # Same shared data as before, with only the gradients entry swapped out\n",
    "        new_shared_data = {\n",
    "            k: (gradient_with_he if k == \"gradients\" else v)\n",
    "            for k, v in shared_data.items()\n",
    "        }\n",
    "\n",
    "        new_reconstructed_user_data, stats = attacker.reconstruct([server_payload],\n",
    "                                                                [new_shared_data],\n",
    "                                                                server.secrets,\n",
    "                                                                dryrun=cfg.dryrun)\n",
    "        metrics = breaching.analysis.report(new_reconstructed_user_data,\n",
    "                                            true_user_data,\n",
    "                                            [server_payload],\n",
    "                                            server.model,\n",
    "                                            order_batch=True,\n",
    "                                            compute_full_iip=False,\n",
    "                                            cfg_case=cfg.case,\n",
    "                                            setup=setup)\n",
    "        sim_score = get_sim_metrics(true_user_data, new_reconstructed_user_data)\n",
    "        results_dict_temp = {\n",
    "            \"mse\": metrics[\"mse\"],\n",
    "            \"msssim\": sim_score[\"msssim\"],\n",
    "            \"uqi\": sim_score[\"uqi\"],\n",
    "            \"vifp\": sim_score[\"vifp\"]\n",
    "        }\n",
    "        # Fixed two-decimal format: a bare `ratio_enc * 100` prints float noise such as\n",
    "        # 7.000000000000001, and encrypt_selective already reports the ratio with :.2f\n",
    "        print(f\"Randomly encrypted {ratio_enc * 100:.2f}% parameters\")\n",
    "        for k, v in results_dict_temp.items():\n",
    "            results_dict[k].append(v)\n",
    "            print(f\"{k}: {v}\")\n",
    "        user.plot(new_reconstructed_user_data)\n",
    "    return results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec842839",
   "metadata": {},
   "outputs": [],
   "source": [
    "for ratio in [0.01, 0.03, 0.05, 0.07, 0.09, 0.11, 0.13]:\n",
    "    # One encrypt_random call per ratio (the progress bar loops exactly once), 10 repeats each\n",
    "    results_random_list = [\n",
    "        encrypt_random(ratio_enc=ratio, num_repeat=10)\n",
    "        for _ in tqdm(range(1), desc=\"Testing random init...\")\n",
    "    ]\n",
    "    torch.save(results_random_list, f\"./classification/results/lenet_random_{ratio}.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.undefined"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
