{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "os.chdir(\"../..\")\n",
    "sys.path.append(os.getcwd())\n",
    "\n",
    "import breaching\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import scipy.stats as stats\n",
    "import scipy.integrate as integrate\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sensitivity_path = os.getcwd() + \"/lm/sensitivity/\"\n",
    "print(sensitivity_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sensitivity Calculation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def min_max_normalize(l: list) -> list:\n",
    "    l_min = min(l)\n",
    "    l_max = max(l)\n",
    "    l_norm = [(x - l_min) / (l_max - l_min) for x in l]\n",
    "    return l_norm\n",
    "\n",
    "def scale_to_100(l: list) -> list:\n",
    "    l = min_max_normalize(l)\n",
    "    original_sum = sum(l)\n",
    "    l_scaled = [x / original_sum * 100 for x in l]\n",
    "    return l_scaled\n",
    "\n",
    "def plot_layer_sens_mean(sens: list, model_name: str):\n",
    "    sens_layer_mean = [layer_sens.mean().item() for layer_sens in sens]\n",
    "    sens_layer_mean_scale = scale_to_100(sens_layer_mean)\n",
    "\n",
    "    plt.bar(np.arange(1, len(sens_layer_mean_scale)+1), sens_layer_mean_scale)\n",
    "    plt.xlabel(\"Layer Index\")\n",
    "    plt.ylabel(\"Mean Sensitivity Ratio (%)\")\n",
    "    plt.title(model_name)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data_point(user, setup):\n",
    "    data_point = dict()\n",
    "    for data_block in user.dataloader:\n",
    "        data = dict()\n",
    "        for key in data_block:\n",
    "            data[key] = data_block[key].to(device=setup[\"device\"])\n",
    "        data_key = \"input_ids\" if \"input_ids\" in data.keys() else \"inputs\"\n",
    "        data_point = {key: val[0 : 1] for key, val in data.items()}\n",
    "        data_point[data_key] = (\n",
    "            data_point[data_key] + user.generator_input.sample(data_point[data_key].shape)\n",
    "            if user.generator_input is not None\n",
    "            else data_point[data_key]\n",
    "        )\n",
    "        break\n",
    "    return data_point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "from tqdm import tqdm\n",
    "\n",
    "def get_sensitivity(model, loss_fn, data_point, device, discrete_grad=False, grad_on_x=False):\n",
    "    model.eval()\n",
    "    outputs = model(**data_point)\n",
    "    labels = data_point[\"labels\"].to(torch.float32)\n",
    "    gt_label = torch.Tensor([labels[0, 1]]).long().to(device)\n",
    "\n",
    "    one_hot_labels = torch.zeros_like(outputs).to(device)\n",
    "    for i in range(labels.shape[-1]):\n",
    "        if i > 0:\n",
    "            one_hot_labels[0, i, int(labels[0, i].item())] = 1\n",
    "\n",
    "    if grad_on_x:\n",
    "        def get_grad_x(input):\n",
    "            outputs = model(**input)\n",
    "            l = loss_fn(outputs, one_hot_labels)\n",
    "            l.backward(create_graph=True)\n",
    "            dl_dw = [param.grad.clone().detach() for param in model.parameters()]\n",
    "            model.zero_grad()\n",
    "            return dl_dw\n",
    "\n",
    "        data_point_plus = copy.deepcopy(data_point)\n",
    "        data_point_plus[\"input_ids\"][0] += 1\n",
    "        data_point_minus = copy.deepcopy(data_point)\n",
    "        data_point_minus[\"input_ids\"][0] -= 1\n",
    "        \n",
    "        dl_dw = get_grad_x(data_point)\n",
    "        dl_dw_plus = get_grad_x(data_point_plus)\n",
    "        dl_dw_minus = get_grad_x(data_point_minus)\n",
    "\n",
    "        d2l_dwdx = []\n",
    "        for i in range(len(dl_dw)):\n",
    "            grad_minus = dl_dw_minus[i]\n",
    "            grad = dl_dw[i]\n",
    "            grad_plus = dl_dw_plus[i]\n",
    "            d2l_dwdx.append(torch.max(torch.abs(grad_minus - grad), torch.abs(grad - grad_plus)))\n",
    "        \n",
    "        return d2l_dwdx\n",
    "    else:\n",
    "        one_hot_labels_minus = torch.zeros_like(outputs).to(device)\n",
    "        one_hot_labels_plus = torch.zeros_like(outputs).to(device)\n",
    "        for i in range(labels.shape[-1]):\n",
    "            if i > 0:\n",
    "                one_hot_labels_minus[0, i, int(labels[0, i].item()) - 1] = 1\n",
    "                one_hot_labels_plus[0, i, int(labels[0, i].item()) + 1] = 1\n",
    "        \n",
    "        def get_grad(one_hot_label):\n",
    "            l = loss_fn(outputs, one_hot_label)\n",
    "            l.backward(create_graph=True)\n",
    "            # print(\"Loss:\", l)\n",
    "            grad_list = [param.grad.clone().detach() for param in model.parameters()]\n",
    "            model.zero_grad()\n",
    "            return grad_list\n",
    "        \n",
    "        dl_dw = get_grad(one_hot_labels)\n",
    "        dl_dw_minus = get_grad(one_hot_labels_minus)\n",
    "        dl_dw_plus = get_grad(one_hot_labels_plus)\n",
    "\n",
    "        assert len(dl_dw) == len(dl_dw_minus) == len(dl_dw_plus)\n",
    "\n",
    "        num_layer = len(dl_dw)\n",
    "        d2l_dwdy = []\n",
    "\n",
    "        if discrete_grad:\n",
    "            for i in range(num_layer):\n",
    "                grad_minus = dl_dw_minus[i]\n",
    "                grad = dl_dw[i]\n",
    "                grad_plus = dl_dw_plus[i]\n",
    "                d2l_dwdy.append(torch.max(torch.abs(grad_minus - grad), torch.abs(grad - grad_plus)))\n",
    "        else:\n",
    "            d2l_dwdy = dl_dw\n",
    "        \n",
    "        return d2l_dwdy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean_sens(cfg_config, model_name, device, num_user=5, discrete_grad=False, grad_on_x=False):\n",
    "    sens_mean = []\n",
    "    for i in tqdm(range(num_user)):\n",
    "        cfg_config.case.user.user_idx = i+1 # From which user?\n",
    "        setup = dict(device=device, dtype=getattr(torch, cfg_config.case.impl.dtype))\n",
    "        user, server, model, loss_fn = breaching.cases.construct_case(cfg_config.case, setup)\n",
    "        model.to(device=setup[\"device\"])\n",
    "\n",
    "        data_point = get_data_point(user, setup)\n",
    "        sens_single = get_sensitivity(model, loss_fn,\n",
    "                                      copy.deepcopy(data_point),\n",
    "                                      setup[\"device\"], \n",
    "                                      discrete_grad,\n",
    "                                      grad_on_x)\n",
    "    \n",
    "        if i == 0:\n",
    "            sens_mean = sens_single\n",
    "        else:\n",
    "            sens_mean = [sens_mean[j] + sens_single[j] / num_user for j in range(len(sens_mean))]\n",
    "        sens_mean_path = sensitivity_path + model_name + \"_mean_sens\"\n",
    "        if discrete_grad:\n",
    "            sens_mean_path += \"_discrete\"\n",
    "        sens_mean_path += \".pt\"\n",
    "    torch.save(sens_mean, sens_mean_path)\n",
    "    return sens_mean, sens_mean_path\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformer3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = breaching.get_config(overrides=[\"case=10_causal_lang_training\",  \"attack=tag\"])\n",
    "          \n",
    "device = torch.device(f'cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "\n",
    "cfg.case.user.num_data_points = 1 # How many sentences?\n",
    "cfg.case.data.shape = [2] # This is the sequence length\n",
    "\n",
    "cfg.case.model = \"transformer3\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer3_sens_mean, transformer3_sens_mean_path = get_mean_sens(cfg, \"200_transformer3_tag\", \n",
    "                                                    torch.device('cpu'), \n",
    "                                                    num_user=5, discrete_grad=False)\n",
    "plot_layer_sens_mean(transformer3_sens_mean, \"Transformer3 (TAG)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformer3f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = breaching.get_config(overrides=[\"case=10_causal_lang_training\",  \"attack=tag\"])\n",
    "          \n",
    "device = torch.device(f'cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "\n",
    "cfg.case.user.num_data_points = 1 # How many sentences?\n",
    "cfg.case.user.user_idx = 1 # From which user?\n",
    "cfg.case.data.shape = [2] # This is the sequence length\n",
    "\n",
    "cfg.case.model = \"transformer3f\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer3f_sens_mean, transformer3f_sens_mean_path = get_mean_sens(cfg, \"transformer3f_tag\", \n",
    "                                                    torch.device('cpu'), \n",
    "                                                    num_user=5, discrete_grad=False)\n",
    "plot_layer_sens_mean(transformer3f_sens_mean, \"Transformer3f (TAG)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformer3t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = breaching.get_config(overrides=[\"case=10_causal_lang_training\",  \"attack=tag\"])\n",
    "          \n",
    "device = torch.device(f'cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "\n",
    "cfg.case.user.num_data_points = 1 # How many sentences?\n",
    "cfg.case.user.user_idx = 1 # From which user?\n",
    "cfg.case.data.shape = [2] # This is the sequence length\n",
    "\n",
    "cfg.case.model = \"transformer3t\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer3t_sens_mean, transformer3t_sens_mean_path = get_mean_sens(cfg, \"transformer3t_tag\", \n",
    "                                                    torch.device('cpu'), \n",
    "                                                    num_user=5, discrete_grad=False)\n",
    "plot_layer_sens_mean(transformer3t_sens_mean, \"Transformer3t (TAG)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TransformerS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = breaching.get_config(overrides=[\"case=10_causal_lang_training\",  \"attack=tag\"])\n",
    "          \n",
    "device = torch.device(f'cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "\n",
    "cfg.case.user.num_data_points = 1 # How many sentences?\n",
    "cfg.case.user.user_idx = 1 # From which user?\n",
    "cfg.case.data.shape = [2] # This is the sequence length\n",
    "\n",
    "cfg.case.model = \"transformerS\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformerS_sens_mean, transformerS_sens_mean_path = get_mean_sens(cfg, \"transformerS_tag\", \n",
    "                                                    torch.device('cpu'), \n",
    "                                                    num_user=5, discrete_grad=False)\n",
    "plot_layer_sens_mean(transformerS_sens_mean, \"TransformerS (TAG)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GPT-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = breaching.get_config(overrides=[\"case=10_causal_lang_training\",  \"attack=tag\"])\n",
    "          \n",
    "device = torch.device(f'cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "torch.backends.cudnn.benchmark = cfg.case.impl.benchmark\n",
    "\n",
    "cfg.case.user.num_data_points = 1 # How many sentences?\n",
    "cfg.case.user.user_idx = 1 # From which user?\n",
    "cfg.case.data.shape = [2] # This is the sequence length\n",
    "\n",
    "cfg.case.model = \"gpt2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt2_sens_mean, gpt2_sens_mean_path = get_mean_sens(cfg, f\"{cfg.case.model}_tag\", \n",
    "                                                    torch.device('cpu'), \n",
    "                                                    num_user=5, discrete_grad=False)\n",
    "plot_layer_sens_mean(gpt2_sens_mean, \"GPT-2 (TAG)\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "breaching",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
