{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "90dcd6cb",
   "metadata": {},
   "source": [
    "### Startup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b850eabf",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import breaching\n",
    "except ModuleNotFoundError:\n",
    "    # You only really need this safety net if you want to run these notebooks directly in the examples directory\n",
    "    # Don't worry about this if you installed the package or moved the notebook to the main directory.\n",
    "    import os; os.chdir(\"../..\")\n",
    "    import breaching\n",
    "    \n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import random, json, copy\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Redirects logs directly into the jupyter notebook\n",
    "import logging, sys\n",
    "logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')\n",
    "logger = logging.getLogger()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7408be55",
   "metadata": {},
   "outputs": [],
   "source": [
    "sensitivity_path = os.getcwd() + \"/lm/sensitivity/\"\n",
    "print(sensitivity_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88d5e214",
   "metadata": {},
   "source": [
    "### Initialize cfg object and system setup:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56bd663b",
   "metadata": {},
   "source": [
    "This will load the full configuration object. This includes the configuration for the use case and threat model as `cfg.case` and the hyperparameters and implementation of the attack as `cfg.attack`. All parameters can be modified below, or overriden with `overrides=` as if they were cmd-line arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7dc3a48",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = breaching.get_config(overrides=[\"case=10_causal_lang_training\",  \"attack=tag\"])\n",
    "          \n",
    "device = torch.device(f'cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))\n",
    "setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "203c5fb1",
   "metadata": {},
   "source": [
    "### Modify config options here"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e0764ef",
   "metadata": {},
   "source": [
    "You can use `.attribute` access to modify any of these configurations for the attack, or the case:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac118ea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg.case.user.num_data_points = 1 # How many sentences?\n",
    "cfg.case.user.user_idx = 1 # From which user?\n",
    "cfg.case.data.shape = [16] # This is the sequence length\n",
    "\n",
    "cfg.case.model = \"transformerS\"\n",
    "\n",
    "# cfg.attack.optim.step_size = 0.02\n",
    "# cfg.attack.optim.step_size_decay = \"step-lr\"\n",
    "cfg.attack.optim.max_iterations = 5000 # Increasing the number of iterations can help this attack"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76f64389",
   "metadata": {},
   "source": [
    "### Instantiate all parties"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8db2272f",
   "metadata": {},
   "source": [
    "The following lines generate \"server, \"user\" and \"attacker\" objects and print an overview of their configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3abd955",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = f\"./lm/causal-lm/model-tag-{cfg.case.model}.pt\"\n",
    "user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup, model_path=model_path)\n",
    "\n",
    "attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)\n",
    "breaching.utils.overview(server, user, attacker)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "548c0ad6",
   "metadata": {},
   "source": [
    "### Simulate an attacked FL protocol"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2058bcc2",
   "metadata": {},
   "source": [
    "This exchange is a simulation of a single query in a federated learning protocol. The server sends out a `server_payload` and the user computes an update based on their private local data. This user update is `shared_data` and contains, for example, the parameter gradient of the model in the simplest case. `true_user_data` is also returned by `.compute_local_updates`, but of course not forwarded to the server or attacker and only used for (our) analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0dbd868",
   "metadata": {},
   "outputs": [],
   "source": [
    "server_payload = server.distribute_payload()\n",
    "shared_data, true_user_data = user.compute_local_updates(server_payload)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49c68628",
   "metadata": {},
   "outputs": [],
   "source": [
    "user.print(true_user_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1ecdc38",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_layer = len(model.state_dict())\n",
    "print(f\"Number of Layers: {num_layer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65c1657f",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_param = 0\n",
    "num_param_list = []\n",
    "for layer in shared_data['gradients']:\n",
    "    num_param += layer.numel()\n",
    "    num_param_list.append(layer.numel())\n",
    "print(f\"Number of Parameters: {num_param}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9c3130d",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_param_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17255c5a",
   "metadata": {},
   "source": [
    "### Reconstruct user data:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82360c14",
   "metadata": {},
   "source": [
    "Now we launch the attack, reconstructing user data based on only the `server_payload` and the `shared_data`. \n",
    "\n",
    "You can interrupt the computation early to see a partial solution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9a32fd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_no_dict = {\n",
    "    \"accuracy\": [], \n",
    "    \"sacrebleu\": [], \n",
    "    \"feat_mse\": [], \n",
    "    \"google_bleu\": [],\n",
    "    \"rouge1\": [],\n",
    "    \"rouge2\": [],\n",
    "    \"rougeL\": [],\n",
    "    \"token_acc\": [],\n",
    "    \"token_avg_accuracy\": []\n",
    "}\n",
    "for i in tqdm(range(10)):\n",
    "    reconstructed_user_data, stats = attacker.reconstruct([server_payload],\n",
    "                                                        [copy.deepcopy(shared_data)],\n",
    "                                                        server.secrets, \n",
    "                                                        dryrun=cfg.dryrun)\n",
    "\n",
    "    metrics = breaching.analysis.report(reconstructed_user_data,\n",
    "                                        true_user_data,\n",
    "                                        [server_payload], \n",
    "                                        server.model,\n",
    "                                        order_batch=True,\n",
    "                                        compute_full_iip=False, \n",
    "                                        cfg_case=cfg.case,\n",
    "                                        setup=setup)\n",
    "    for k in results_no_dict.keys():\n",
    "        results_no_dict[k].append(metrics[k])\n",
    "    user.print_and_mark_correct(reconstructed_user_data, true_user_data)\n",
    "torch.save(results_no_dict, f\"./lm/results/tag/{cfg.case.user.user_idx}/{cfg.case.model}_no_protection.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dca8837",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flat_tensor_list(grads):\n",
    "    flat_grads = torch.tensor([], device=grads[0].device)\n",
    "    for i in grads:\n",
    "        flat_grads = torch.concat((flat_grads, i.view(-1)), dim = 0)\n",
    "    return flat_grads\n",
    "\n",
    "def encrypt_by_index(grad_list: list, idx_list: float, if_plot=False) -> list:\n",
    "    \"\"\"\n",
    "    @grad_list: gradient list to be encrypted\n",
    "    @idx_list: list of encrypted parameter indices\n",
    "    @if_plot: whether to plot the encrypted ratio of each layer\n",
    "    \"\"\"\n",
    "    grad_flat = flat_tensor_list(grad_list)\n",
    "\n",
    "    mask = torch.zeros(grad_flat.shape).to(device=grad_flat.device) == 0\n",
    "    mask[idx_list] = False\n",
    "\n",
    "    # Apply the encryption mask\n",
    "    grad_flat = grad_flat * mask + torch.rand(grad_flat.shape).to(device=grad_flat.device) * ~mask\n",
    "\n",
    "    # Reshape back\n",
    "    new_grad_list = []\n",
    "    start_idx = 0\n",
    "    ratio_enc_per_layer = []\n",
    "    for layer_grad in tqdm(grad_list):\n",
    "        layer_param_num = layer_grad.numel()\n",
    "        grad_slice = grad_flat[start_idx: start_idx + layer_param_num]\n",
    "        new_grad_list.append(grad_slice.reshape(layer_grad.shape))\n",
    "\n",
    "        # Record the number of encrypted parameters of each layer\n",
    "        mask_slice = mask[start_idx: start_idx + layer_param_num]\n",
    "        ratio_param_enc_layer = (layer_param_num - mask_slice.sum()) / layer_param_num\n",
    "        ratio_enc_per_layer.append(ratio_param_enc_layer.item())\n",
    "\n",
    "        start_idx += layer_param_num\n",
    "\n",
    "    if if_plot:\n",
    "        plt.bar(np.arange(1, len(ratio_enc_per_layer)+1), ratio_enc_per_layer)\n",
    "        plt.xlabel(\"Layer Index\")\n",
    "        plt.ylabel(\"Encryption Ratio\")\n",
    "        plt.show()\n",
    "    return new_grad_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a4bea91",
   "metadata": {},
   "source": [
    "## Selective Encryption by Sensitivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee47ca1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encrypt_selective(sens_list, ratio_enc, num_repeat=1):\n",
    "    num_enc = (int) (ratio_enc * num_param)\n",
    "    encrypt_list = torch.topk(flat_tensor_list(sens_list).abs(), num_enc, largest=True).indices\n",
    "\n",
    "    results_dict = {\n",
    "        \"accuracy\": [], \n",
    "        \"sacrebleu\": [], \n",
    "        \"feat_mse\": [], \n",
    "        \"google_bleu\": [],\n",
    "        \"rouge1\": [],\n",
    "        \"rouge2\": [],\n",
    "        \"rougeL\": [],\n",
    "        \"token_acc\": [],\n",
    "        \"token_avg_accuracy\": []\n",
    "    }\n",
    "    for _ in tqdm(range(num_repeat), desc=f\"p={ratio_enc}\"):\n",
    "        new_gradient = copy.deepcopy(shared_data['gradients'])\n",
    "        new_gradient = encrypt_by_index(new_gradient, encrypt_list)\n",
    "\n",
    "        gradient_with_he = list(new_gradient)\n",
    "        new_shared_data = dict()\n",
    "        for k in shared_data.keys():\n",
    "            if k == \"gradients\":\n",
    "                new_shared_data[k] = gradient_with_he\n",
    "            else:\n",
    "                new_shared_data[k] = shared_data[k]\n",
    "\n",
    "        new_reconstructed_user_data, stats = attacker.reconstruct([server_payload],\n",
    "                                                                [new_shared_data],\n",
    "                                                                server.secrets,\n",
    "                                                                dryrun=cfg.dryrun)\n",
    "        print(f\"Encrypted the {ratio_enc * 100:.2f}% parameters\")\n",
    "        metrics = breaching.analysis.report(new_reconstructed_user_data,\n",
    "                                            true_user_data,\n",
    "                                            [server_payload],\n",
    "                                            server.model,\n",
    "                                            order_batch=True,\n",
    "                                            compute_full_iip=False,\n",
    "                                            cfg_case=cfg.case,\n",
    "                                            setup=setup)\n",
    "        for k in results_dict.keys():\n",
    "            results_dict[k].append(metrics[k])\n",
    "        user.print_and_mark_correct(new_reconstructed_user_data, true_user_data)\n",
    "    return results_dict\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90d463bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "sens_path = sensitivity_path + f\"{cfg.case.model}_tag_mean_sens.pt\"\n",
    "sens_list = torch.load(sens_path)\n",
    "\n",
    "for ratio in [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]:\n",
    "    results_select_dict = encrypt_selective(sens_list, \n",
    "                                            ratio_enc=ratio,\n",
    "                                            num_repeat=10)\n",
    "    torch.save(results_select_dict, f\"./lm/results/tag/{cfg.case.user.user_idx}/{cfg.case.model}_selective_{ratio}.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "288e95dd",
   "metadata": {},
   "source": [
    "## Random Encryption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b6722bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# below are for for random encryption by parameters (modified):\n",
    "\n",
    "def encrypt_random(ratio_enc, num_repeat=1):\n",
    "    # Randomly pick a proportion to encrypt\n",
    "    protected_params = random.sample(range(num_param), (int) (ratio_enc * num_param))\n",
    "\n",
    "    results_dict = {\n",
    "        \"accuracy\": [], \n",
    "        \"sacrebleu\": [], \n",
    "        \"feat_mse\": [], \n",
    "        \"google_bleu\": [],\n",
    "        \"rouge1\": [],\n",
    "        \"rouge2\": [],\n",
    "        \"rougeL\": [],\n",
    "        \"token_acc\": [],\n",
    "        \"token_avg_accuracy\": []\n",
    "    }\n",
    "\n",
    "    for _ in tqdm(range(num_repeat), desc=f\"p={ratio_enc}\"):\n",
    "        new_gradient = copy.deepcopy(shared_data['gradients'])\n",
    "        new_gradient = encrypt_by_index(new_gradient, protected_params)\n",
    "\n",
    "        gradient_with_he = list(new_gradient)\n",
    "        new_shared_data = dict()\n",
    "        for k in shared_data.keys():\n",
    "            if k == \"gradients\":\n",
    "                new_shared_data[k] = gradient_with_he\n",
    "            else:\n",
    "                new_shared_data[k] = shared_data[k]\n",
    "\n",
    "        # print(new_shared_data['gradients'])\n",
    "\n",
    "        new_reconstructed_user_data, stats = attacker.reconstruct([server_payload],\n",
    "                                                                [new_shared_data],\n",
    "                                                                server.secrets,\n",
    "                                                                dryrun=cfg.dryrun)\n",
    "        print(f\"Randomly encrypted {ratio_enc * 100}% parameters\")\n",
    "        metrics = breaching.analysis.report(new_reconstructed_user_data,\n",
    "                                            true_user_data,\n",
    "                                            [server_payload],\n",
    "                                            server.model,\n",
    "                                            order_batch=True,\n",
    "                                            compute_full_iip=False,\n",
    "                                            cfg_case=cfg.case,\n",
    "                                            setup=setup)\n",
    "        \n",
    "        for k in results_dict.keys():\n",
    "            results_dict[k].append(metrics[k])\n",
    "        user.print_and_mark_correct(new_reconstructed_user_data, true_user_data)\n",
    "    return results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78350fdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "for ratio in [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]:\n",
    "    results_random_list = []\n",
    "    for _ in tqdm(range(10), desc=\"Testing random init...\"):\n",
    "        results_dict = encrypt_random(ratio_enc=ratio, num_repeat=10)\n",
    "        results_random_list.append(results_dict)\n",
    "    torch.save(results_random_list, f\"./lm/results/tag/{cfg.case.user.user_idx}/{cfg.case.model}_random_{ratio}.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
