{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "os.chdir(\"../..\")\n",
    "sys.path.append(os.getcwd())\n",
    "\n",
    "import breaching\n",
    "import numpy as np\n",
    "from pprint import pprint\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import grad\n",
    "import torchvision\n",
    "from torchvision import models, datasets, transforms\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "\n",
    "import scipy.stats as stats\n",
    "import scipy.integrate as integrate\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "print(torch.__version__, torchvision.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory under the current working directory where sensitivity tensors are stored.\n",
    "# The trailing slash is required: file names are appended by plain string concatenation.\n",
    "sensitivity_path = f\"{os.getcwd()}/classification/sensitivity/\"\n",
    "print(sensitivity_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(\"Running on %s\" % device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flat_tensor_list(grads):\n",
    "    \"\"\"Flatten a sequence of tensors into one 1-D tensor.\n",
    "\n",
    "    Args:\n",
    "        grads: Iterable of tensors (for example one gradient per parameter).\n",
    "\n",
    "    Returns:\n",
    "        A 1-D tensor with the elements of every tensor in ``grads``, in order.\n",
    "        An empty input yields an empty 1-D tensor.\n",
    "    \"\"\"\n",
    "    tensors = list(grads)\n",
    "    if not tensors:\n",
    "        return torch.tensor([])\n",
    "    # reshape(-1) also accepts non-contiguous tensors (view(-1) raises on them),\n",
    "    # and a single cat avoids re-copying the growing result on every iteration.\n",
    "    return torch.cat([t.reshape(-1) for t in tensors], dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def min_max_normalize(l: list) -> list:\n",
    "    \"\"\"Rescale values linearly so the minimum maps to 0 and the maximum to 1.\n",
    "\n",
    "    Args:\n",
    "        l: Sequence of numbers.\n",
    "\n",
    "    Returns:\n",
    "        A new list where each value is ``(x - min) / (max - min)``.\n",
    "        An empty input returns ``[]``. When all values are equal the range is\n",
    "        zero, so every entry is returned as ``0.0`` instead of dividing by zero.\n",
    "    \"\"\"\n",
    "    if len(l) == 0:\n",
    "        return []\n",
    "    l_min = min(l)\n",
    "    l_max = max(l)\n",
    "    span = l_max - l_min\n",
    "    if span == 0:\n",
    "        # Constant input: there is no spread to normalize against.\n",
    "        return [0.0 for _ in l]\n",
    "    return [(x - l_min) / span for x in l]\n",
    "\n",
    "def scale_to_100(l: list) -> list:\n",
    "    \"\"\"Min-max normalize values and rescale them so that they sum to 100.\n",
    "\n",
    "    Args:\n",
    "        l: Sequence of numbers.\n",
    "\n",
    "    Returns:\n",
    "        A list of shares that add up to 100. An empty input returns ``[]``.\n",
    "        When all values are equal, the 100 is split evenly because the\n",
    "        normalized values would otherwise sum to zero.\n",
    "    \"\"\"\n",
    "    if len(l) == 0:\n",
    "        return []\n",
    "    if max(l) == min(l):\n",
    "        # Zero range: normalization cannot rank the entries, so share equally\n",
    "        # rather than dividing by zero.\n",
    "        return [100 / len(l) for _ in l]\n",
    "    normalized = min_max_normalize(l)\n",
    "    total = sum(normalized)\n",
    "    return [x / total * 100 for x in normalized]\n",
    "\n",
    "def plot_layer_sens_mean(sens: list, model_name: str):\n",
    "    \"\"\"Show a bar chart of each layer's share of the mean sensitivity.\n",
    "\n",
    "    Args:\n",
    "        sens: One tensor per layer; each is reduced with ``.mean().item()``.\n",
    "        model_name: Text used as the figure title.\n",
    "    \"\"\"\n",
    "    layer_shares = scale_to_100([layer.mean().item() for layer in sens])\n",
    "    layer_ids = np.arange(1, len(layer_shares) + 1)  # 1-based x positions\n",
    "\n",
    "    ax = plt.figure().gca()\n",
    "    ax.bar(layer_ids, layer_shares)\n",
    "    ax.set_xlabel(\"Layer Index\")\n",
    "    ax.set_ylabel(\"Mean Sensitivity Ratio (%)\")\n",
    "    ax.set_title(model_name)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data_point(user, setup):\n",
    "    \"\"\"Return the first sample of the user's first data block as a batch of one.\n",
    "\n",
    "    Every field of the first block yielded by ``user.dataloader`` is moved to\n",
    "    ``setup[\"device\"]`` and sliced to ``[0:1]``. When ``user.generator_input``\n",
    "    is not ``None``, a sample drawn from it is added to the input field\n",
    "    (\"input_ids\" if present, otherwise \"inputs\").\n",
    "\n",
    "    Args:\n",
    "        user: Object exposing ``dataloader`` and ``generator_input``.\n",
    "        setup: Dict with a \"device\" entry.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping each field name to its one-element slice; an empty dict\n",
    "        when the dataloader yields nothing.\n",
    "    \"\"\"\n",
    "    first_block = next(iter(user.dataloader), None)\n",
    "    if first_block is None:\n",
    "        return dict()\n",
    "\n",
    "    on_device = {name: first_block[name].to(device=setup[\"device\"]) for name in first_block}\n",
    "    input_key = \"input_ids\" if \"input_ids\" in on_device else \"inputs\"\n",
    "    sample = {name: tensor[0 : 1] for name, tensor in on_device.items()}\n",
    "\n",
    "    base_input = sample[input_key]\n",
    "    if user.generator_input is not None:\n",
    "        base_input = base_input + user.generator_input.sample(base_input.shape)\n",
    "    sample[input_key] = base_input\n",
    "    return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sensitivity(model, loss_fn, data_point, device, discrete_grad=False, grad_on_x=False, use_jacobian=False):\n",
    "    \"\"\"Compute per-parameter sensitivity tensors for one data point.\n",
    "\n",
    "    Modes (selected by the flags):\n",
    "\n",
    "    * default: gradient of the loss w.r.t. every model parameter, using the\n",
    "      one-hot ground-truth label as target.\n",
    "    * ``discrete_grad=True``: element-wise maximum absolute change of that\n",
    "      gradient when the one-hot target moves to class ``label - 1`` or\n",
    "      ``label + 1``. Both neighbours wrap around the class range, so the first\n",
    "      and the last class are handled the same way.\n",
    "    * ``grad_on_x=True``: derivatives of the parameter gradients w.r.t.\n",
    "      ``data_point[\"inputs\"]``; either the result of\n",
    "      ``torch.autograd.functional.jacobian`` (``use_jacobian=True``) or, for\n",
    "      each parameter-gradient entry, the mean of its gradient w.r.t. the input.\n",
    "\n",
    "    Args:\n",
    "        model: Called as ``model(**data_point)``; switched to eval mode here.\n",
    "        loss_fn: Called as ``loss_fn(outputs, one_hot_target)``.\n",
    "        data_point: Dict holding at least \"inputs\" and \"labels\". With\n",
    "            ``grad_on_x=True`` its \"inputs\" tensor is set to require grad.\n",
    "        device: Device the label and one-hot target tensors are moved to.\n",
    "        discrete_grad: Use the neighbouring-label difference (only read when\n",
    "            ``grad_on_x`` is False).\n",
    "        grad_on_x: Differentiate the parameter gradients w.r.t. the input.\n",
    "        use_jacobian: With ``grad_on_x``, return the full Jacobian.\n",
    "\n",
    "    Returns:\n",
    "        One tensor per model parameter (a list, or the value returned by\n",
    "        ``torch.autograd.functional.jacobian`` in Jacobian mode).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    labels = data_point[\"labels\"].to(torch.float32)\n",
    "    gt_label = torch.Tensor([labels[0]]).long().to(device)\n",
    "    gt_index = gt_label.item()\n",
    "    print(\"gt_label:\", gt_index)\n",
    "\n",
    "    def one_hot_like(reference, class_index):\n",
    "        # Zero tensor shaped like the model output with a 1 at [0, class_index].\n",
    "        target = torch.zeros_like(reference).to(device)\n",
    "        target[0, class_index] = 1\n",
    "        return target\n",
    "\n",
    "    if grad_on_x:\n",
    "        data_point[\"inputs\"].requires_grad = True\n",
    "        outputs = model(**data_point)\n",
    "\n",
    "        gt_one_hot_label = one_hot_like(outputs, gt_index)\n",
    "        print(\"One hot label:\", torch.argmax(gt_one_hot_label, dim=-1).item())\n",
    "        if use_jacobian:\n",
    "            def grad_w(data_in):\n",
    "                data_point_input = {\n",
    "                    \"inputs\": data_in, \n",
    "                    \"labels\": data_point[\"labels\"]\n",
    "                }\n",
    "                out = model(**data_point_input)\n",
    "                l = loss_fn(out, gt_one_hot_label)\n",
    "                l.backward(create_graph=True)\n",
    "                dl_dw = [param.grad for param in model.parameters()]\n",
    "                return tuple(dl_dw)\n",
    "\n",
    "            d2l_dwdx = torch.autograd.functional.jacobian(\n",
    "                grad_w,\n",
    "                data_point[\"inputs\"]\n",
    "            )\n",
    "            return d2l_dwdx\n",
    "\n",
    "        loss = loss_fn(outputs, gt_one_hot_label)\n",
    "        # create_graph=True keeps the parameter gradients differentiable so each\n",
    "        # entry can be back-propagated to the input below.\n",
    "        loss.backward(create_graph=True)\n",
    "\n",
    "        dl_dw = [param.grad for param in model.parameters()]\n",
    "\n",
    "        d2l_dwdx = []\n",
    "        for idx, layer_grad in enumerate(dl_dw):\n",
    "            print(f\"Layer {idx+1}\")\n",
    "            layer_grad_flatten = layer_grad.view(-1)\n",
    "            sens_layer = torch.zeros_like(layer_grad_flatten)\n",
    "\n",
    "            for j in tqdm(range(len(layer_grad_flatten))):\n",
    "                # Reset the input gradient so each entry is measured on its own.\n",
    "                data_point[\"inputs\"].grad.data.zero_()\n",
    "                layer_grad_flatten[j].backward(retain_graph=True)\n",
    "                sens_layer[j] = data_point[\"inputs\"].grad.mean().clone().detach()\n",
    "            d2l_dwdx.append(sens_layer)\n",
    "        return d2l_dwdx\n",
    "\n",
    "    outputs = model(**data_point)\n",
    "    gt_one_hot_label = one_hot_like(outputs, gt_index)\n",
    "    print(\"One hot label:\", torch.argmax(gt_one_hot_label, dim=-1).item())\n",
    "\n",
    "    def get_grad(target):\n",
    "        # create_graph=True also retains the graph of `outputs`, which is\n",
    "        # back-propagated once per target.\n",
    "        l = loss_fn(outputs, target)\n",
    "        l.backward(create_graph=True)\n",
    "        print(\"Loss:\", l)\n",
    "        grad_list = [param.grad.clone().detach() for param in model.parameters()]\n",
    "        model.zero_grad()\n",
    "        return grad_list\n",
    "\n",
    "    dl_dw = get_grad(gt_one_hot_label)\n",
    "    # Total number of gradient entries across all parameters.\n",
    "    print(\"len(dl_dw):\", sum(g.numel() for g in dl_dw))\n",
    "\n",
    "    if not discrete_grad:\n",
    "        return dl_dw\n",
    "\n",
    "    # Neighbouring classes, wrapped explicitly: a plain `label + 1` is out of\n",
    "    # range for the last class, while `label - 1` already wrapped for class 0\n",
    "    # through negative indexing.\n",
    "    num_classes = outputs.shape[-1]\n",
    "    gt_one_hot_label_minus = one_hot_like(outputs, (gt_index - 1) % num_classes)\n",
    "    gt_one_hot_label_plus = one_hot_like(outputs, (gt_index + 1) % num_classes)\n",
    "\n",
    "    dl_dw_minus = get_grad(gt_one_hot_label_minus)\n",
    "    dl_dw_plus = get_grad(gt_one_hot_label_plus)\n",
    "\n",
    "    assert len(dl_dw) == len(dl_dw_minus) == len(dl_dw_plus)\n",
    "\n",
    "    d2l_dwdy = [\n",
    "        torch.max(torch.abs(grad_minus - grad), torch.abs(grad - grad_plus))\n",
    "        for grad_minus, grad, grad_plus in zip(dl_dw_minus, dl_dw, dl_dw_plus)\n",
    "    ]\n",
    "    return d2l_dwdy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean_sens(cfg_config, model_name, device, num_user=5, discrete_grad=False, grad_on_x=False, use_jacobian=False):\n",
    "    \"\"\"Average the sensitivity of one data point per user and save the result.\n",
    "\n",
    "    For user indices ``1..num_user`` a case is built with\n",
    "    ``breaching.cases.construct_case``, one data point is taken with\n",
    "    ``get_data_point`` and its sensitivity is computed with ``get_sensitivity``.\n",
    "    Every user contributes ``1 / num_user`` of its sensitivity, so the result is\n",
    "    the arithmetic mean over users.\n",
    "\n",
    "    Args:\n",
    "        cfg_config: Config object; ``cfg_config.case.user.user_idx`` is\n",
    "            overwritten for every user.\n",
    "        model_name: Used to build the output file name.\n",
    "        device: Device the model and data are placed on.\n",
    "        num_user: Number of users to average over.\n",
    "        discrete_grad, grad_on_x, use_jacobian: Forwarded to ``get_sensitivity``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(sens_mean, sens_mean_path)``: the list of averaged per-parameter\n",
    "        tensors and the path they were saved to with ``torch.save``.\n",
    "    \"\"\"\n",
    "    # The output path does not depend on the loop, so build it once up front\n",
    "    # (it is then also defined when num_user is 0).\n",
    "    sens_mean_path = sensitivity_path + model_name + \"_mean_sens\"\n",
    "    if discrete_grad:\n",
    "        sens_mean_path += \"_discrete\"\n",
    "    sens_mean_path += \".pt\"\n",
    "\n",
    "    sens_mean = []\n",
    "    for i in tqdm(range(num_user)):\n",
    "        cfg_config.case.user.user_idx = i + 1  # users 1..num_user\n",
    "        setup = dict(device=device, dtype=getattr(torch, cfg_config.case.impl.dtype))\n",
    "        user, server, model, loss_fn = breaching.cases.construct_case(cfg_config.case, setup)\n",
    "        model.to(device=setup[\"device\"])\n",
    "\n",
    "        data_point = get_data_point(user, setup)\n",
    "\n",
    "        # get_sensitivity may modify the data point, so hand it a copy.\n",
    "        sens_single = get_sensitivity(model, loss_fn,\n",
    "                                      copy.deepcopy(data_point),\n",
    "                                      setup[\"device\"],\n",
    "                                      discrete_grad,\n",
    "                                      grad_on_x,\n",
    "                                      use_jacobian)\n",
    "\n",
    "        # Weight every user (including the first) by 1 / num_user.\n",
    "        contribution = [layer_sens / num_user for layer_sens in sens_single]\n",
    "        if i == 0:\n",
    "            sens_mean = contribution\n",
    "        else:\n",
    "            sens_mean = [acc + new for acc, new in zip(sens_mean, contribution)]\n",
    "\n",
    "    torch.save(sens_mean, sens_mean_path)\n",
    "    print(sens_mean_path)\n",
    "    return sens_mean, sens_mean_path\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LeNet (CIFAR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the \"6_large_batch_cifar\" case and adapt it for the LeNet run.\n",
    "cfg = breaching.get_config(overrides=[\"case=6_large_batch_cifar\"])\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "\n",
    "# Case settings: model, data partition and user.\n",
    "cfg.case.model = \"lenet100\"\n",
    "cfg.case.data.partition = \"balanced\"\n",
    "cfg.case.user.user_idx = 0  # get_mean_sens overwrites this for every user\n",
    "cfg.case.user.provide_labels = False\n",
    "\n",
    "# Attack settings; the sensitivity helpers in this notebook only read cfg.case.\n",
    "cfg.attack.label_strategy = \"yin\"\n",
    "# Original note: total variation regularization needs to be smaller on CIFAR-10.\n",
    "cfg.attack.regularization.total_variation.scale = 5e-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average over 5 users on the CPU with plain (non-discrete) gradients, then plot.\n",
    "lenet_sens_mean, lenet_mean_path = get_mean_sens(\n",
    "    cfg,\n",
    "    \"lenet\",\n",
    "    torch.device(\"cpu\"),\n",
    "    num_user=5,\n",
    "    discrete_grad=False,\n",
    ")\n",
    "plot_layer_sens_mean(lenet_sens_mean, \"LeNet\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CNN (CIFAR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the \"6_large_batch_cifar\" case and adapt it for the CNN_FedAvg run.\n",
    "cfg = breaching.get_config(overrides=[\"case=6_large_batch_cifar\"])\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "\n",
    "# Case settings: model, data partition and user.\n",
    "cfg.case.model = \"CNN_FedAvg\"\n",
    "cfg.case.data.partition = \"balanced\"\n",
    "cfg.case.user.user_idx = 0  # get_mean_sens overwrites this for every user\n",
    "cfg.case.user.provide_labels = False\n",
    "\n",
    "# Attack settings; the sensitivity helpers in this notebook only read cfg.case.\n",
    "cfg.attack.label_strategy = \"yin\"\n",
    "# Original note: total variation regularization needs to be smaller on CIFAR-10.\n",
    "cfg.attack.regularization.total_variation.scale = 5e-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average over 5 users on the CPU with plain (non-discrete) gradients, then plot.\n",
    "CNN_FedAvg_sens_mean, CNN_FedAvg_mean_path = get_mean_sens(\n",
    "    cfg,\n",
    "    \"CNN_FedAvg\",\n",
    "    torch.device(\"cpu\"),\n",
    "    num_user=5,\n",
    "    discrete_grad=False,\n",
    ")\n",
    "plot_layer_sens_mean(CNN_FedAvg_sens_mean, \"CNN_FedAvg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ResNet18 (CIFAR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the \"6_large_batch_cifar\" case and adapt it for the ResNet18 run.\n",
    "cfg = breaching.get_config(overrides=[\"case=6_large_batch_cifar\"])\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "\n",
    "# Case settings: model and data partition (the user index is set by get_mean_sens).\n",
    "cfg.case.model = \"resnet18\"\n",
    "cfg.case.data.partition = \"unique-class\"\n",
    "cfg.case.user.provide_labels = False\n",
    "\n",
    "# Attack settings; the sensitivity helpers in this notebook only read cfg.case.\n",
    "cfg.attack.label_strategy = \"yin\"\n",
    "# Original note: total variation regularization needs to be smaller on CIFAR-10.\n",
    "cfg.attack.regularization.total_variation.scale = 5e-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average over 5 users on the CPU with plain (non-discrete) gradients, then plot.\n",
    "resnet18_cifar_sens_mean, resnet18_cifar_mean_path = get_mean_sens(\n",
    "    cfg,\n",
    "    \"resnet18_cifar\",\n",
    "    torch.device(\"cpu\"),\n",
    "    num_user=5,\n",
    "    discrete_grad=False,\n",
    ")\n",
    "plot_layer_sens_mean(resnet18_cifar_sens_mean, \"ResNet18\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "breaching_vision",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
