{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yHRzgIPrf9Vh"
      },
      "outputs": [],
      "source": [
        "import abc\n",
        "import copy\n",
        "import os\n",
        "import time\n",
        "import urllib\n",
        "import urllib.request\n",
        "from collections import namedtuple\n",
        "from math import pi, sqrt\n",
        "from typing import Any, Callable, List, Optional, Tuple\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch.utils.data as data_utils\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "from torch import nn\n",
        "from torch.utils import data\n",
        "from tqdm import tqdm\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ENoYJ5K_iRow"
      },
      "source": [
        "# Model and Dataset Loaders\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W8qdFcjViQdP"
      },
      "outputs": [],
      "source": [
        "\n",
        "class NetRegression(nn.Module):\n",
        "    \"\"\"Three-layer fully connected network with SELU activations.\n",
        "\n",
        "    The output layer is a plain linear map (no activation), so the network can be\n",
        "    used for regression (``num_classes=1``) or to produce unnormalised class scores.\n",
        "\n",
        "    Args:\n",
        "        input_size: number of input features (size of the last input dimension).\n",
        "        num_classes: number of output units.\n",
        "        hidden_size: width of both hidden layers (default 1000).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, input_size, num_classes, hidden_size=1000):\n",
        "        super().__init__()\n",
        "        self.first = nn.Linear(input_size, hidden_size)\n",
        "        self.mid = nn.Linear(hidden_size, hidden_size)\n",
        "        self.last = nn.Linear(hidden_size, num_classes)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Maps inputs of shape ``(..., input_size)`` to outputs of shape ``(..., num_classes)``.\"\"\"\n",
        "        hidden = F.selu(self.first(x))\n",
        "        hidden = F.selu(self.mid(hidden))\n",
        "        return self.last(hidden)\n",
        "\n",
        "\n",
        "def load_adult(nTrain=None, scaler=True, shuffle=False):\n",
        "    '''Loads the UCI Adult dataset, using sex as the protected attribute.\n",
        "\n",
        "    The official split is kept: ``adult.data`` is the training set and ``adult.test`` the\n",
        "    test set. Each file is downloaded into the working directory if it is missing.\n",
        "\n",
        "    :param nTrain: if given, only the first ``nTrain`` training rows are returned (capped at the\n",
        "        number of training rows available); the test set is unaffected.\n",
        "    :param scaler: if True it applies a StandardScaler() (from sklearn.preprocessing) to the data;\n",
        "        the scaler is fitted on train and test rows together.\n",
        "    :param shuffle: ignored (a warning is printed), because Adult has a fixed test set.\n",
        "    :return: ``(train_df, to_protect, test_df, to_protect_test)``. The data frames hold the\n",
        "        integer-encoded (optionally scaled) features under integer column labels, with the\n",
        "        sex column (9) removed, plus a ``Target`` column that is 1.0 for income >50K and 0.0\n",
        "        otherwise. ``to_protect*`` are 0/1 arrays that are 1 where a row's sex differs from\n",
        "        that of the first training row.\n",
        "\n",
        "    Rows with a missing ('?') workclass, occupation or native-country are discarded, and\n",
        "    marital-status is collapsed to 'married' / 'not married' before encoding.\n",
        "\n",
        "    Features of the Adult dataset:\n",
        "    0. age: continuous.\n",
        "    1. workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.\n",
        "    2. fnlwgt: continuous.\n",
        "    3. education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th,\n",
        "    Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.\n",
        "    4. education-num: continuous.\n",
        "    5. marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed,\n",
        "    Married-spouse-absent, Married-AF-spouse.\n",
        "    6. occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty,\n",
        "    Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv,\n",
        "    Protective-serv, Armed-Forces.\n",
        "    7. relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.\n",
        "    8. race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.\n",
        "    9. sex: Female, Male.\n",
        "    10. capital-gain: continuous.\n",
        "    11. capital-loss: continuous.\n",
        "    12. hours-per-week: continuous.\n",
        "    13. native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc),\n",
        "    India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico,\n",
        "    Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala,\n",
        "    Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.\n",
        "    (14. label: <=50K, >50K)\n",
        "    '''\n",
        "    if shuffle:\n",
        "        print('Warning: I wont shuffle because adult has fixed test set')\n",
        "    base_url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/\"\n",
        "    # Check each file on its own so that a partial earlier download gets completed.\n",
        "    for fname in (\"adult.data\", \"adult.test\"):\n",
        "        if not os.path.isfile(fname):\n",
        "            urllib.request.urlretrieve(base_url + fname, fname)\n",
        "    columns = [\n",
        "        \"Age\", \"workclass\", \"fnlwgt\", \"education\", \"education-num\", \"marital-status\",\n",
        "        \"occupation\", \"relationship\", \"race\", \"gender\", \"capital gain\", \"capital loss\",\n",
        "        \"hours per week\", \"native-country\", \"income\"]\n",
        "    # The raw files put a blank after every comma; skipinitialspace strips it so that\n",
        "    # categorical values compare equal to the bare strings used below.\n",
        "    train_raw = pd.read_csv(\"adult.data\", names=columns, skipinitialspace=True)\n",
        "    # The first line of adult.test is not a data row.\n",
        "    test_raw = pd.read_csv(\"adult.test\", names=columns, skiprows=1, header=None, skipinitialspace=True)\n",
        "\n",
        "    def drop_missing(frame):\n",
        "        # Considering the relative low portion of missing data ('?'), we discard those rows.\n",
        "        keep = np.ones(len(frame), dtype=bool)\n",
        "        for col in (\"workclass\", \"occupation\", \"native-country\"):\n",
        "            keep &= (frame[col] != '?').to_numpy()\n",
        "        return frame[keep]\n",
        "\n",
        "    train_raw = drop_missing(train_raw)\n",
        "    test_raw = drop_missing(test_raw)\n",
        "    # Count training rows AFTER filtering: counting before would shift the train/test\n",
        "    # boundary and move test rows into the training split.\n",
        "    len_train = len(train_raw)\n",
        "    data = pd.concat([train_raw, test_raw], ignore_index=True)\n",
        "\n",
        "    # Here we apply discretisation on column marital-status\n",
        "    data['marital-status'] = data['marital-status'].replace({\n",
        "        'Divorced': 'not married', 'Married-AF-spouse': 'married',\n",
        "        'Married-civ-spouse': 'married', 'Married-spouse-absent': 'married',\n",
        "        'Never-married': 'not married', 'Separated': 'not married', 'Widowed': 'not married'})\n",
        "    # Labels in adult.test carry a trailing dot ('>50K.'), so strip it before comparing.\n",
        "    target = np.where(data['income'].str.rstrip('.') == '>50K', 1.0, -1.0)\n",
        "    # Integer-encode categorical fields; codes follow sorted order and are shared by train and test.\n",
        "    category_col = ['workclass', 'race', 'education', 'marital-status', 'occupation',\n",
        "                    'relationship', 'gender', 'native-country']\n",
        "    for col in category_col:\n",
        "        data[col] = np.unique(data[col], return_inverse=True)[1]\n",
        "    datamat = data.drop(columns='income').values\n",
        "    if scaler:\n",
        "        datamat = StandardScaler().fit_transform(datamat)\n",
        "\n",
        "    n_train = len_train if nTrain is None else min(nTrain, len_train)\n",
        "    train_x, train_y = datamat[:n_train], target[:n_train]\n",
        "    test_x, test_y = datamat[len_train:], target[len_train:]\n",
        "\n",
        "    # Column 9 is gender. Both splits compare against the SAME reference value (first\n",
        "    # training row) so that a 1 in to_protect means the same group in train and test.\n",
        "    reference = datamat[0, 9]\n",
        "    to_protect = 1. * (train_x[:, 9] != reference)\n",
        "    to_protect_test = 1. * (test_x[:, 9] != reference)\n",
        "\n",
        "    encoded_data = pd.DataFrame(train_x)\n",
        "    encoded_data['Target'] = (train_y + 1) / 2\n",
        "    encoded_data_test = pd.DataFrame(test_x)\n",
        "    encoded_data_test['Target'] = (test_y + 1) / 2\n",
        "\n",
        "    # Variable to protect (9:Sex) is removed from dataset\n",
        "    return encoded_data.drop(columns=9), to_protect, encoded_data_test.drop(columns=9), to_protect_test\n",
        "\n",
        "# Load insurance dataset\n",
        "def load_insurance(path='insurance.csv', test_size=0.2, random_state=0):\n",
        "    \"\"\"Loads the medical-cost insurance CSV, using sex as the protected attribute.\n",
        "\n",
        "    ``sex``, ``smoker`` and ``region`` are mapped to integer codes, and ``charges``, ``age``\n",
        "    and ``bmi`` are min-max scaled to [0, 1] over the whole file (before splitting).\n",
        "\n",
        "    Args:\n",
        "        path: location of the CSV file (default: ``insurance.csv`` in the working directory).\n",
        "        test_size: fraction of rows held out for testing.\n",
        "        random_state: seed forwarded to ``train_test_split``.\n",
        "\n",
        "    Returns:\n",
        "        ``(x_train, x_test, E_train, E_test, y_train, y_test)`` where ``x`` holds every\n",
        "        column except ``sex`` and ``charges``, ``E`` is the sex code (1 = male, 0 = female)\n",
        "        and ``y`` is the scaled ``charges``.\n",
        "    \"\"\"\n",
        "    insurance = pd.read_csv(path)\n",
        "    # Values missing from these mappings become NaN.\n",
        "    encodings = {\n",
        "        'sex': {'male': 1, 'female': 0},\n",
        "        'smoker': {'yes': 1, 'no': 0},\n",
        "        'region': {'northeast': 3, 'northwest': 2, 'southeast': 1, 'southwest': 0},\n",
        "    }\n",
        "    for col, mapping in encodings.items():\n",
        "        insurance[col] = insurance[col].map(mapping)\n",
        "    cols_to_norm = ['charges', 'age', 'bmi']\n",
        "    insurance[cols_to_norm] = insurance[cols_to_norm].apply(lambda x: (x - x.min()) / (x.max() - x.min()))\n",
        "\n",
        "    E = insurance['sex'].to_numpy()\n",
        "    y = insurance['charges'].to_numpy()\n",
        "    x = insurance.drop(['sex', 'charges'], axis=1).to_numpy()\n",
        "    return train_test_split(x, E, y, test_size=test_size, random_state=random_state)\n",
        "\n",
        "# Load communities and crime dataset\n",
        "def read_crimes(label='ViolentCrimesPerPop', sensitive_attribute='racepctblack', env_partition=0.05,\n",
        "                random_state=None):\n",
        "    \"\"\"Loads the UCI Communities and Crime dataset for regression.\n",
        "\n",
        "    ``communities.data`` and ``communities.names`` are downloaded into the working directory\n",
        "    if missing; column names are read from the ``@attribute`` lines of the ``.names`` file.\n",
        "\n",
        "    Args:\n",
        "        label: name of the target column.\n",
        "        sensitive_attribute: name of the sensitive column ``z``; it is removed from the features.\n",
        "        env_partition: only rows with ``z >= env_partition`` are kept.\n",
        "        random_state: seed for the row shuffle; ``None`` (default) leaves it unseeded.\n",
        "\n",
        "    Returns:\n",
        "        ``(x_train, x_test, z_train, z_test, y_train, y_test)`` from an 80/20\n",
        "        ``train_test_split`` with ``random_state=0``. Features are z-scored per column\n",
        "        over all rows, before the ``env_partition`` filter and the split.\n",
        "    \"\"\"\n",
        "    base_url = \"http://archive.ics.uci.edu/ml/machine-learning-databases/communities/\"\n",
        "    # Check each file on its own so that a partial earlier download gets completed.\n",
        "    for fname in (\"communities.data\", \"communities.names\"):\n",
        "        if not os.path.isfile(fname):\n",
        "            urllib.request.urlretrieve(base_url + fname, fname)\n",
        "    with open('communities.names', 'r') as file:\n",
        "        names = [line.split(' ')[1] for line in file if line.startswith('@attribute')]\n",
        "    # '?' marks missing values; na_values turns them into NaN at load time.\n",
        "    data = pd.read_csv('communities.data', names=names, na_values=['?'])\n",
        "    data = data.drop(['state', 'county', 'community', 'fold', 'communityname'], axis=1)\n",
        "    # OtherPerCap is mean-imputed so it survives the dropna below, which discards every\n",
        "    # other column that has missing values.\n",
        "    data['OtherPerCap'] = data['OtherPerCap'].fillna(data['OtherPerCap'].astype(float).mean())\n",
        "    data = data.dropna(axis=1)\n",
        "    data['OtherPerCap'] = data['OtherPerCap'].astype(float)\n",
        "    # shuffle\n",
        "    data = data.sample(frac=1, replace=False, random_state=random_state).reset_index(drop=True)\n",
        "    y = data[label].values\n",
        "    z = data[sensitive_attribute].values\n",
        "    features = data.drop(columns=[label, sensitive_attribute])\n",
        "    features = (features - features.mean()) / features.std()\n",
        "    keep = z >= env_partition\n",
        "    x = features.to_numpy()[keep]\n",
        "    return train_test_split(x, z[keep], y[keep], test_size=0.2, random_state=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3HksaKqGivtC"
      },
      "source": [
        "# Influence Calculation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QBqE5RySiw74"
      },
      "outputs": [],
      "source": [
        "def _set_attr(obj, names, val):\n",
        "    \"\"\"Assigns ``val`` to the attribute reached from ``obj`` by following ``names``.\n",
        "\n",
        "    ``names`` is a sequence of attribute names (e.g. ``'layer.weight'.split('.')``);\n",
        "    every name but the last is traversed with ``getattr`` and the last one is set.\n",
        "    \"\"\"\n",
        "    owner = obj\n",
        "    for name in names[:-1]:\n",
        "        owner = getattr(owner, name)\n",
        "    setattr(owner, names[-1], val)\n",
        "\n",
        "\n",
        "def _del_attr(obj, names):\n",
        "    \"\"\"Deletes the attribute reached from ``obj`` by following ``names``.\n",
        "\n",
        "    ``names`` is a sequence of attribute names (e.g. ``'layer.weight'.split('.')``);\n",
        "    every name but the last is traversed with ``getattr`` and the last one is deleted.\n",
        "    \"\"\"\n",
        "    owner = obj\n",
        "    for name in names[:-1]:\n",
        "        owner = getattr(owner, name)\n",
        "    delattr(owner, names[-1])\n",
        "\n",
        "class BaseObjective(abc.ABC):\n",
        "    \"\"\"Abstract adapter telling torch-influence how a project computes its objectives.\n",
        "\n",
        "    To use torch-influence, subclass this and implement the four abstract methods:\n",
        "    :meth:`train_outputs`, :meth:`train_loss_on_outputs`, :meth:`train_regularization`\n",
        "    and :meth:`test_loss`. :meth:`train_loss` combines the first three and normally\n",
        "    does not need to be overridden.\n",
        "    \"\"\"\n",
        "\n",
        "    @abc.abstractmethod\n",
        "    def train_outputs(self, model: nn.Module, batch: Any) -> torch.Tensor:\n",
        "        \"\"\"Runs ``model`` on a training ``batch`` and returns its outputs\n",
        "        (e.g. logits or probabilities).\n",
        "\n",
        "        Args:\n",
        "            model: the model.\n",
        "            batch: a batch of training data.\n",
        "\n",
        "        Returns:\n",
        "            the model outputs for the batch.\n",
        "        \"\"\"\n",
        "        raise NotImplementedError()\n",
        "\n",
        "    @abc.abstractmethod\n",
        "    def train_loss_on_outputs(self, outputs: torch.Tensor, batch: Any) -> torch.Tensor:\n",
        "        \"\"\"Returns the **mean**-reduced training loss of ``outputs`` against ``batch``.\n",
        "\n",
        "        Args:\n",
        "            outputs: model outputs produced by :meth:`train_outputs`.\n",
        "            batch: the training batch the outputs were produced from.\n",
        "\n",
        "        Returns:\n",
        "            the loss over the batch.\n",
        "\n",
        "        Note:\n",
        "            Where the forward pass ends and the loss begins is a modelling choice; in\n",
        "            binary classification, for instance, the outputs may be logits or normalized\n",
        "            probabilities. Standard influence functions behave the same either way, but\n",
        "            the Gauss-Newton Hessian approximation requires this method to be convex in\n",
        "            ``outputs``.\n",
        "\n",
        "        See also:\n",
        "            :class:`CGInfluenceModule`\n",
        "            :class:`LiSSAInfluenceModule`\n",
        "        \"\"\"\n",
        "        raise NotImplementedError()\n",
        "\n",
        "    @abc.abstractmethod\n",
        "    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:\n",
        "        \"\"\"Returns the regularization loss at ``params``, a flattened vector of model parameters.\"\"\"\n",
        "        raise NotImplementedError()\n",
        "\n",
        "    def train_loss(self, model: nn.Module, params: torch.Tensor, batch: Any) -> torch.Tensor:\n",
        "        \"\"\"Returns the **mean**-reduced, regularized training loss of ``model`` on ``batch``.\n",
        "\n",
        "        Computed as ``train_loss_on_outputs(train_outputs(model, batch), batch)\n",
        "        + train_regularization(params)``; most projects should not override it.\n",
        "\n",
        "        Args:\n",
        "            model: the model.\n",
        "            params: a flattened vector of the model's parameters.\n",
        "            batch: a batch of training data.\n",
        "\n",
        "        Returns:\n",
        "            the training loss over the batch.\n",
        "        \"\"\"\n",
        "        outputs = self.train_outputs(model, batch)\n",
        "        data_term = self.train_loss_on_outputs(outputs, batch)\n",
        "        reg_term = self.train_regularization(params)\n",
        "        return data_term + reg_term\n",
        "\n",
        "    @abc.abstractmethod\n",
        "    def test_loss(self, model: nn.Module, params: torch.Tensor, batch: Any) -> torch.Tensor:\n",
        "        \"\"\"Returns the **mean**-reduced test loss of ``model`` on ``batch``.\n",
        "\n",
        "        Args:\n",
        "            model: the model.\n",
        "            params: a flattened vector of the model's parameters.\n",
        "            batch: a batch of test data.\n",
        "\n",
        "        Returns:\n",
        "            the test loss over the batch.\n",
        "        \"\"\"\n",
        "        raise NotImplementedError()\n",
        "\n",
        "\n",
        "\n",
        "class BaseInfluenceModule(abc.ABC):\n",
        "    \"\"\"The core module that contains convenience methods for computing influence functions.\n",
        "\n",
        "    Subclasses implement :meth:`inverse_hvp`; gradients, influence scores, the unlearning\n",
        "    helpers and the data/parameter plumbing are provided here.\n",
        "\n",
        "    Args:\n",
        "        model: the model of interest; it is switched to eval mode and moved to ``device``.\n",
        "        objective: an implementation of :class:`BaseObjective`.\n",
        "        train_loader: a training dataset loader.\n",
        "        test_loader: a test dataset loader.\n",
        "        device: the device on which operations are performed.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "            self,\n",
        "            model: nn.Module,\n",
        "            objective: BaseObjective,\n",
        "            train_loader: data.DataLoader,\n",
        "            test_loader: data.DataLoader,\n",
        "            device: torch.device\n",
        "    ):\n",
        "        model.eval()\n",
        "        self.model = model.to(device)\n",
        "        self.device = device\n",
        "\n",
        "        self.is_model_functional = False\n",
        "        # Names and shapes of the trainable parameters, used to flatten/rebuild them later.\n",
        "        self.params_names = tuple(name for name, _ in self._model_params())\n",
        "        self.params_shape = tuple(p.shape for _, p in self._model_params())\n",
        "\n",
        "        self.objective = objective\n",
        "        self.train_loader = train_loader\n",
        "        self.test_loader = test_loader\n",
        "\n",
        "    @abc.abstractmethod\n",
        "    def inverse_hvp(self, vec: torch.Tensor, **kwargs) -> torch.Tensor:\n",
        "        \"\"\"Computes an inverse-Hessian vector product, where the Hessian is specifically\n",
        "        that of the (mean) empirical risk over the training dataset.\n",
        "\n",
        "        Args:\n",
        "            vec: a vector laid out like the flattened model parameters.\n",
        "            **kwargs: extra options. :meth:`influences`, :meth:`unlearning` and :meth:`ACV`\n",
        "                pass ``train_len`` (the training-set size) and ``unlearning`` (a bool),\n",
        "                while :meth:`stest` passes neither, so implementations should give both a\n",
        "                default. NOTE(review): the effect of ``unlearning=True`` is defined by the\n",
        "                subclass; confirm there.\n",
        "\n",
        "        Returns:\n",
        "            the inverse-Hessian vector product.\n",
        "        \"\"\"\n",
        "\n",
        "        raise NotImplementedError()\n",
        "\n",
        "    # ====================================================\n",
        "    # Interface functions\n",
        "    # ====================================================\n",
        "\n",
        "    def train_loss_grad(self, train_idxs: List[int]) -> torch.Tensor:\n",
        "        \"\"\"Returns the gradient of the (mean) training loss over a set of training\n",
        "        data points with respect to the model's flattened parameters.\n",
        "\n",
        "        Args:\n",
        "            train_idxs: the indices of the training points.\n",
        "\n",
        "        Returns:\n",
        "            the loss gradient at the training points.\n",
        "        \"\"\"\n",
        "\n",
        "        return self._loss_grad(train_idxs, train=True)\n",
        "\n",
        "    def test_loss_grad(self, test_idxs: List[int]) -> torch.Tensor:\n",
        "        \"\"\"Returns the gradient of the (mean) test loss over a set of test\n",
        "        data points with respect to the model's flattened parameters.\n",
        "\n",
        "        Args:\n",
        "           test_idxs: the indices of the test points.\n",
        "\n",
        "        Returns:\n",
        "           the loss gradient at the test points.\n",
        "        \"\"\"\n",
        "\n",
        "        return self._loss_grad(test_idxs, train=False)\n",
        "\n",
        "    def stest(self, test_idxs: List[int]) -> torch.Tensor:\n",
        "        r\"\"\"This function simply composes :func:`inverse_hvp` with :func:`test_loss_grad`.\n",
        "\n",
        "        In the original influence function paper, the resulting vector was called\n",
        "        :math:`\\mathbf{s}_{\\mathrm{test}}`. No extra keyword arguments are passed to\n",
        "        :meth:`inverse_hvp` here.\n",
        "\n",
        "        Args:\n",
        "            test_idxs: the indices of the test points.\n",
        "\n",
        "        Returns:\n",
        "            the :math:`\\mathbf{s}_{\\mathrm{test}}` vector.\n",
        "        \"\"\"\n",
        "\n",
        "        return self.inverse_hvp(self.test_loss_grad(test_idxs))\n",
        "\n",
        "    def influences(\n",
        "            self,\n",
        "            train_idxs: List[int],\n",
        "            test_idxs: List[int],\n",
        "            stest: Optional[torch.Tensor] = None,\n",
        "            target_grad: Optional[torch.Tensor] = None,\n",
        "            influence_objective: Optional[str] = 'Taylor'\n",
        "    ) -> Tuple[torch.Tensor, float]:\n",
        "        r\"\"\"Returns the influence scores of a set of training data points with respect to\n",
        "        the (mean) test loss over a set of test data points.\n",
        "\n",
        "        Each score is the dot product of a training point's loss gradient with\n",
        "        :math:`\\mathbf{s}_{\\mathrm{test}}`, divided by the training-set size. The scores\n",
        "        estimate the following quantities:\n",
        "\n",
        "            Let :math:`\\mathcal{L}_0` be the (mean) test loss of the current model\n",
        "            over the input test points. Suppose we produce a new model by (1) removing\n",
        "            the ``train_idxs[i]``-th example from the training dataset and (2) retraining\n",
        "            the model on this one-smaller dataset. Let :math:`\\mathcal{L}` be the (mean)\n",
        "            test loss of the **new** model over the input test points. Then the ``i``-th\n",
        "            influence score estimates :math:`\\mathcal{L} - \\mathcal{L}_0`.\n",
        "\n",
        "        Args:\n",
        "            train_idxs: the indices of the training points.\n",
        "            test_idxs: the indices of the test points. If empty, ``target_grad`` is used in\n",
        "                place of the test-loss gradient (and ``stest`` is ignored).\n",
        "            stest: a precomputed :math:`\\mathbf{s}_{\\mathrm{test}}` vector for ``test_idxs``.\n",
        "                If ``None``, it is computed with :meth:`stest`.\n",
        "            target_grad: required when ``test_idxs`` is empty; either a flat vector or a\n",
        "                sequence of parameter-shaped tensors, flattened before :meth:`inverse_hvp`.\n",
        "            influence_objective: currently unused.\n",
        "\n",
        "        Returns:\n",
        "            a tuple ``(scores, seconds)``: a 1D tensor of ``len(train_idxs)`` influence\n",
        "            scores, and the wall-clock time spent obtaining\n",
        "            :math:`\\mathbf{s}_{\\mathrm{test}}`.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if ``test_idxs`` is empty and ``target_grad`` is ``None``.\n",
        "        \"\"\"\n",
        "\n",
        "        # Time both branches so the returned duration is always defined.\n",
        "        time_start = time.perf_counter()\n",
        "        if len(test_idxs) == 0:\n",
        "            if target_grad is None:\n",
        "                raise ValueError(\"target_grad is required when test_idxs is empty\")\n",
        "            stest = self.inverse_hvp(self._flatten_params_like(target_grad),\n",
        "                                     train_len=len(self.train_loader.dataset), unlearning=False)\n",
        "        else:\n",
        "            stest = self.stest(test_idxs) if (stest is None) else stest.to(self.device)\n",
        "        elapsed = time.perf_counter() - time_start\n",
        "\n",
        "        scores = [\n",
        "            grad_z @ stest\n",
        "            for grad_z, _ in self._loss_grad_loader_wrapper(batch_size=1, subset=train_idxs, train=True)\n",
        "        ]\n",
        "        return torch.tensor(scores) / len(self.train_loader.dataset), elapsed\n",
        "\n",
        "    # Unlearning\n",
        "    def unlearning(\n",
        "            self,\n",
        "            train_idxs: List[int]\n",
        "    ) -> Tuple[nn.Module, float]:\n",
        "        \"\"\"Unlearns pre-specified training samples from a trained model.\n",
        "\n",
        "        The mean training-loss gradient over ``train_idxs`` is passed to :meth:`inverse_hvp`\n",
        "        with ``unlearning=True``; the vector it returns is discarded. NOTE(review): this\n",
        "        presumably relies on the subclass's ``inverse_hvp`` updating ``self.model`` in\n",
        "        place when ``unlearning=True``; confirm there.\n",
        "\n",
        "        Args:\n",
        "            train_idxs: the indices of the training points to unlearn.\n",
        "\n",
        "        Returns:\n",
        "            a tuple ``(model, seconds)``: ``self.model`` and the wall-clock time spent on\n",
        "            the gradient and the :meth:`inverse_hvp` call.\n",
        "        \"\"\"\n",
        "        time_start = time.perf_counter()\n",
        "        self.inverse_hvp(self.train_loss_grad(train_idxs),\n",
        "                         train_len=len(self.train_loader.dataset), unlearning=True)\n",
        "        return self.model, time.perf_counter() - time_start\n",
        "\n",
        "    # ACV\n",
        "    def ACV(\n",
        "            self,\n",
        "            train_idx: List[int],\n",
        "            target_grad: Optional[torch.Tensor] = None\n",
        "    ) -> Tuple[nn.Module, float]:\n",
        "        \"\"\"Runs :meth:`inverse_hvp` with ``unlearning=True`` on a caller-supplied gradient.\n",
        "\n",
        "        NOTE(review): the name presumably stands for approximate cross-validation. As in\n",
        "        :meth:`unlearning`, the vector returned by :meth:`inverse_hvp` is discarded, so any\n",
        "        effect must come from the subclass updating ``self.model``; confirm there.\n",
        "\n",
        "        Args:\n",
        "            train_idx: currently unused; kept for interface compatibility.\n",
        "            target_grad: the gradient to pass (flattened) to :meth:`inverse_hvp`; required.\n",
        "\n",
        "        Returns:\n",
        "            a tuple ``(model, seconds)``: ``self.model`` and the wall-clock time spent in\n",
        "            the :meth:`inverse_hvp` call.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if ``target_grad`` is ``None``.\n",
        "        \"\"\"\n",
        "        if target_grad is None:\n",
        "            raise ValueError(\"ACV requires target_grad\")\n",
        "        time_start = time.perf_counter()\n",
        "        self.inverse_hvp(self._flatten_params_like(target_grad),\n",
        "                         train_len=len(self.train_loader.dataset), unlearning=True)\n",
        "        return self.model, time.perf_counter() - time_start\n",
        "\n",
        "    # ====================================================\n",
        "    # Private helper functions\n",
        "    # ====================================================\n",
        "\n",
        "    # Model and parameter helpers\n",
        "\n",
        "    def _model_params(self, with_names=True):\n",
        "        \"\"\"Returns the trainable parameters, as ``(name, param)`` pairs or bare tensors.\"\"\"\n",
        "        assert not self.is_model_functional\n",
        "        return tuple((name, p) if with_names else p for name, p in self.model.named_parameters() if p.requires_grad)\n",
        "\n",
        "    def _model_make_functional(self):\n",
        "        \"\"\"Removes the trainable parameters from the model so plain tensors can be set in\n",
        "        their place; returns detached tensors (sharing storage) that require grad.\"\"\"\n",
        "        assert not self.is_model_functional\n",
        "        params = tuple(p.detach().requires_grad_() for p in self._model_params(False))\n",
        "\n",
        "        for name in self.params_names:\n",
        "            _del_attr(self.model, name.split(\".\"))\n",
        "        self.is_model_functional = True\n",
        "\n",
        "        return params\n",
        "\n",
        "    def _model_reinsert_params(self, params, register=False):\n",
        "        \"\"\"Sets each tensor in ``params`` back on the model under its original name.\n",
        "\n",
        "        With ``register=True`` they are wrapped as ``nn.Parameter`` and the model is marked\n",
        "        non-functional again; otherwise they are set as plain tensors.\n",
        "        \"\"\"\n",
        "        for name, p in zip(self.params_names, params):\n",
        "            _set_attr(self.model, name.split(\".\"), torch.nn.Parameter(p) if register else p)\n",
        "        self.is_model_functional = not register\n",
        "\n",
        "    def _flatten_params_like(self, params_like):\n",
        "        \"\"\"Concatenates the given tensors into one flat vector.\"\"\"\n",
        "        # reshape (unlike view) also accepts non-contiguous tensors.\n",
        "        return torch.cat([p.reshape(-1) for p in params_like])\n",
        "\n",
        "    def _reshape_like_params(self, vec):\n",
        "        \"\"\"Splits a flat vector into a tuple of tensors shaped like the model parameters.\"\"\"\n",
        "        pointer = 0\n",
        "        split_tensors = []\n",
        "        for dim in self.params_shape:\n",
        "            num_param = dim.numel()\n",
        "            split_tensors.append(vec[pointer: pointer + num_param].view(dim))\n",
        "            pointer += num_param\n",
        "        return tuple(split_tensors)\n",
        "\n",
        "    # Data helpers\n",
        "\n",
        "    def _transfer_to_device(self, batch):\n",
        "        \"\"\"Recursively moves a tensor, tuple/list or dict of tensors to ``self.device``.\n",
        "\n",
        "        Raises:\n",
        "            NotImplementedError: for any other batch type.\n",
        "        \"\"\"\n",
        "        if isinstance(batch, torch.Tensor):\n",
        "            return batch.to(self.device)\n",
        "        elif isinstance(batch, (tuple, list)):\n",
        "            return type(batch)(self._transfer_to_device(x) for x in batch)\n",
        "        elif isinstance(batch, dict):\n",
        "            return {k: self._transfer_to_device(x) for k, x in batch.items()}\n",
        "        else:\n",
        "            raise NotImplementedError(\"cannot move a batch of type {} to a device\".format(type(batch).__name__))\n",
        "\n",
        "    def _loader_wrapper(self, train, batch_size=None, subset=None, sample_n_batches=-1):\n",
        "        \"\"\"Yields ``(batch, size)`` pairs from a fresh, unshuffled loader over the train or test set.\n",
        "\n",
        "        Args:\n",
        "            train: selects ``self.train_loader`` (True) or ``self.test_loader`` (False).\n",
        "            batch_size: batch size; defaults to that of the selected loader.\n",
        "            subset: optional indices into the dataset; must be 1-D, unique and in range.\n",
        "            sample_n_batches: if positive, draws exactly this many batches by sampling\n",
        "                with replacement instead of passing over the dataset once.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if ``subset`` is not 1-D or contains duplicates.\n",
        "            IndexError: if ``subset`` contains an out-of-range index.\n",
        "        \"\"\"\n",
        "        loader = self.train_loader if train else self.test_loader\n",
        "        batch_size = loader.batch_size if (batch_size is None) else batch_size\n",
        "\n",
        "        if subset is None:\n",
        "            dataset = loader.dataset\n",
        "        else:\n",
        "            subset = np.array(subset)\n",
        "            if len(subset.shape) != 1 or len(np.unique(subset)) != len(subset):\n",
        "                raise ValueError(\"subset must be a 1-D collection of unique indices\")\n",
        "            if np.any((subset < 0) | (subset >= len(loader.dataset))):\n",
        "                raise IndexError(\"subset contains indices outside the dataset\")\n",
        "            dataset = data.Subset(loader.dataset, indices=subset)\n",
        "\n",
        "        if sample_n_batches > 0:\n",
        "            num_samples = sample_n_batches * batch_size\n",
        "            sampler = data.RandomSampler(data_source=dataset, replacement=True, num_samples=num_samples)\n",
        "            # The sampler yields num_samples items, which can exceed len(dataset).\n",
        "            data_left = num_samples\n",
        "        else:\n",
        "            sampler = None\n",
        "            data_left = len(dataset)\n",
        "\n",
        "        new_loader = data.DataLoader(\n",
        "            dataset=dataset,\n",
        "            batch_size=batch_size,\n",
        "            shuffle=False,\n",
        "            sampler=sampler,\n",
        "            collate_fn=loader.collate_fn,\n",
        "            num_workers=loader.num_workers,\n",
        "            worker_init_fn=loader.worker_init_fn,\n",
        "        )\n",
        "\n",
        "        for batch in new_loader:\n",
        "            batch = self._transfer_to_device(batch)\n",
        "            size = min(batch_size, data_left)  # only the last batch can be smaller\n",
        "            yield batch, size\n",
        "            data_left -= size\n",
        "\n",
        "    # Loss and autograd helpers\n",
        "\n",
        "    def _loss_grad_loader_wrapper(self, train, **kwargs):\n",
        "        \"\"\"Yields ``(flat_grad, batch_size)`` per batch of :meth:`_loader_wrapper`, where\n",
        "        ``flat_grad`` is the flattened gradient of the objective's train loss (train=True)\n",
        "        or test loss (train=False) with respect to the trainable parameters.\"\"\"\n",
        "        params = self._model_params(with_names=False)\n",
        "        flat_params = self._flatten_params_like(params)\n",
        "        loss_fn = self.objective.train_loss if train else self.objective.test_loss\n",
        "\n",
        "        for batch, batch_size in self._loader_wrapper(train=train, **kwargs):\n",
        "            loss = loss_fn(model=self.model, params=flat_params, batch=batch)\n",
        "            yield self._flatten_params_like(torch.autograd.grad(loss, params)), batch_size\n",
        "\n",
        "    def _loss_grad(self, idxs, train):\n",
        "        \"\"\"Returns the flattened gradient of the mean loss over the points ``idxs``.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if ``idxs`` is empty.\n",
        "        \"\"\"\n",
        "        if len(idxs) == 0:\n",
        "            raise ValueError(\"idxs must contain at least one index\")\n",
        "        grad = 0.0\n",
        "        for grad_batch, batch_size in self._loss_grad_loader_wrapper(subset=idxs, train=train):\n",
        "            # Batch losses are mean-reduced, so re-weight by batch size before averaging.\n",
        "            grad = grad + grad_batch * batch_size\n",
        "        return grad / len(idxs)\n",
        "\n",
        "    def _hvp_at_batch(self, batch, flat_params, vec, gnh):\n",
        "        \"\"\"Returns a curvature-matrix-vector product on one batch at ``flat_params``.\n",
        "\n",
        "        ``gnh`` selects the matrix:\n",
        "            * ``'gnh'``: Gauss-Newton Hessian (Jacobian of ``train_outputs`` combined with\n",
        "              the Hessian of ``train_loss_on_outputs``) plus the regularization Hessian;\n",
        "            * ``'emp'``: outer product of the batch's ``objective.test_loss`` gradient with\n",
        "              itself, plus the regularization Hessian;\n",
        "            * anything else: the exact Hessian of ``objective.train_loss``.\n",
        "\n",
        "        The closures set ``flat_params`` on the model as plain tensors, which presumably\n",
        "        requires :meth:`_model_make_functional` to have been called first.\n",
        "        \"\"\"\n",
        "\n",
        "        def f(theta_):\n",
        "            self._model_reinsert_params(self._reshape_like_params(theta_))\n",
        "            return self.objective.train_loss(self.model, theta_, batch)\n",
        "\n",
        "        def out_f(theta_):\n",
        "            self._model_reinsert_params(self._reshape_like_params(theta_))\n",
        "            return self.objective.train_outputs(self.model, batch)\n",
        "\n",
        "        def loss_f(out_):\n",
        "            return self.objective.train_loss_on_outputs(out_, batch)\n",
        "\n",
        "        def reg_f(theta_):\n",
        "            return self.objective.train_regularization(theta_)\n",
        "\n",
        "        def train_loss_unregularized(theta_):\n",
        "            # Uses test_loss as the unregularized loss on this (training) batch.\n",
        "            self._model_reinsert_params(self._reshape_like_params(theta_))\n",
        "            return self.objective.test_loss(self.model, theta_, batch)\n",
        "\n",
        "        if gnh=='gnh':\n",
        "            y, jvp = torch.autograd.functional.jvp(out_f, flat_params, v=vec)\n",
        "            hjvp = torch.autograd.functional.hvp(loss_f, y, v=jvp)[1]\n",
        "            gnhvp_batch = torch.autograd.functional.vjp(out_f, flat_params, v=hjvp)[1]\n",
        "            return gnhvp_batch + torch.autograd.functional.hvp(reg_f, flat_params, v=vec)[1]\n",
        "        elif gnh=='emp':\n",
        "            _, jvp = torch.autograd.functional.jvp(train_loss_unregularized, flat_params, v=vec)\n",
        "            gnhvp_batch = torch.autograd.functional.vjp(train_loss_unregularized, flat_params, v=jvp)[1]\n",
        "            return gnhvp_batch + torch.autograd.functional.hvp(reg_f, flat_params, v=vec)[1]\n",
        "        else:\n",
        "            return torch.autograd.functional.hvp(f, flat_params, v=vec)[1]\n",
        "\n",
        "\n",
        "class LiSSAInfluenceModule(BaseInfluenceModule):\n",
        "    r\"\"\"An influence module that computes inverse-Hessian vector products\n",
        "    using the Linear time Stochastic Second-Order Algorithm (LiSSA).\n",
        "\n",
        "    At a high level, LiSSA estimates an inverse-Hessian vector product\n",
        "    by using truncated Neumann iterations:\n",
        "\n",
        "    .. math::\n",
        "        \\mathbf{H}^{-1}\\mathbf{v} \\approx \\frac{1}{R}\\sum\\limits_{r = 1}^R\n",
        "        \\left(\\sigma^{-1}\\sum_{t = 0}^{T}(\\mathbf{I} - \\sigma^{-1}\\mathbf{H}_{r, t})^t\\mathbf{v}\\right)\n",
        "\n",
        "    Here, :math:`\\mathbf{H}` is the (damped) risk Hessian matrix and :math:`\\mathbf{H}_{r, t}` are\n",
        "    loss Hessian matrices over batches of training data drawn randomly with replacement (we\n",
        "    also use a batch size in ``train_loader``). In addition, :math:`\\sigma > 0` is a scaling\n",
        "    factor chosen sufficiently large such that :math:`\\sigma^{-1} \\mathbf{H} \\preceq \\mathbf{I}`.\n",
        "\n",
        "    In practice, we can compute each inner sum recursively. Starting with\n",
        "    :math:`\\mathbf{h}_{r, 0} = \\mathbf{v}`, we can iteratively update for :math:`T` steps:\n",
        "\n",
        "    .. math::\n",
        "        \\mathbf{h}_{r, t} = \\mathbf{v} + \\mathbf{h}_{r, t - 1} - \\sigma^{-1}\\mathbf{H}_{r, t}\\mathbf{h}_{r, t - 1}\n",
        "\n",
        "    where :math:`\\mathbf{h}_{r, T}` will be equal to the :math:`r`-th inner sum.\n",
        "\n",
        "    Args:\n",
        "        model: the model of interest.\n",
        "        objective: an implementation of :class:`BaseObjective`.\n",
        "        train_loader: a training dataset loader.\n",
        "        test_loader: a test dataset loader.\n",
        "        device: the device on which operations are performed.\n",
        "        damp: the damping strength :math:`\\lambda`. Influence functions assume that the\n",
        "            risk Hessian :math:`\\mathbf{H}` is positive-definite, which often fails to\n",
        "            hold for neural networks. Hence, a damped risk Hessian :math:`\\mathbf{H} + \\lambda\\mathbf{I}`\n",
        "            is used instead, for some sufficiently large :math:`\\lambda > 0` and\n",
        "            identity matrix :math:`\\mathbf{I}`.\n",
        "        repeat: the number of trials :math:`R` (must be at least 1).\n",
        "        depth: the recurrence depth :math:`T`.\n",
        "        scale: the scaling factor :math:`\\sigma` (must be positive).\n",
        "        gnh: the curvature used for :math:`\\mathbf{H}_{r, t}`. ``'gnh'`` (or ``True``)\n",
        "            approximates the risk Hessian with the Gauss-Newton Hessian, which is positive\n",
        "            semi-definite; ``'emp'`` uses the outer product of the batch gradient of the\n",
        "            unregularized loss; in both cases the regularizer's Hessian is added. Any other\n",
        "            value (e.g. ``False`` or ``'Hessian'``) uses the exact risk Hessian.\n",
        "        debug_callback: a callback function which is passed in :math:`(r, t, \\mathbf{h}_{r, t})`\n",
        "            at each recurrence step.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if ``repeat < 1`` or ``scale <= 0``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "            self,\n",
        "            model: nn.Module,\n",
        "            objective: BaseObjective,\n",
        "            train_loader: data.DataLoader,\n",
        "            test_loader: data.DataLoader,\n",
        "            device: torch.device,\n",
        "            damp: float,\n",
        "            repeat: int,\n",
        "            depth: int,\n",
        "            scale: float,\n",
        "            gnh: Any = False,\n",
        "            debug_callback: Optional[Callable[[int, int, torch.Tensor], None]] = None\n",
        "    ):\n",
        "        # Validate up front: repeat == 0 would otherwise surface later as a\n",
        "        # ZeroDivisionError in inverse_hvp, and scale == 0 would silently yield inf.\n",
        "        if repeat < 1:\n",
        "            raise ValueError(f'repeat must be at least 1, got {repeat}')\n",
        "        if scale <= 0:\n",
        "            raise ValueError(f'scale must be positive, got {scale}')\n",
        "\n",
        "        super().__init__(\n",
        "            model=model,\n",
        "            objective=objective,\n",
        "            train_loader=train_loader,\n",
        "            test_loader=test_loader,\n",
        "            device=device,\n",
        "        )\n",
        "\n",
        "        self.damp = damp\n",
        "        # Normalise the boolean flag form (gnh=True) to the 'gnh' mode string\n",
        "        # that _hvp_at_batch dispatches on.\n",
        "        self.gnh = 'gnh' if gnh is True else gnh\n",
        "        self.repeat = repeat\n",
        "        self.depth = depth\n",
        "        self.scale = scale\n",
        "        self.debug_callback = debug_callback\n",
        "\n",
        "    def inverse_hvp(self, vec, train_len, unlearning=False):\n",
        "        \"\"\"Estimate the damped inverse-Hessian vector product of ``vec`` with LiSSA.\n",
        "\n",
        "        Args:\n",
        "            vec: flat vector to multiply.\n",
        "            train_len: not used by this implementation.\n",
        "            unlearning: if True, the model is left holding ``flat_params + ihvp``\n",
        "                instead of its original parameters.\n",
        "\n",
        "        Returns:\n",
        "            The estimate averaged over ``self.repeat`` trials.\n",
        "        \"\"\"\n",
        "        params = self._model_make_functional()\n",
        "        flat_params = self._flatten_params_like(params)\n",
        "\n",
        "        ihvp = 0.0\n",
        "\n",
        "        for r in range(self.repeat):\n",
        "\n",
        "            h_est = vec.clone()\n",
        "\n",
        "            for t, (batch, _) in enumerate(self._loader_wrapper(sample_n_batches=self.depth, train=True)):\n",
        "\n",
        "                hvp_batch = self._hvp_at_batch(batch, flat_params, vec=h_est, gnh=self.gnh)\n",
        "\n",
        "                with torch.no_grad():\n",
        "                    # h <- v + h - (H_{r,t} + damp * I) h / scale\n",
        "                    hvp_batch = hvp_batch + self.damp * h_est\n",
        "                    h_est = vec + h_est - hvp_batch / self.scale\n",
        "\n",
        "                if self.debug_callback is not None:\n",
        "                    self.debug_callback(r, t, h_est)\n",
        "\n",
        "            ihvp = ihvp + h_est / self.scale\n",
        "\n",
        "        ihvp = ihvp / self.repeat\n",
        "\n",
        "        with torch.no_grad():\n",
        "            # Re-register parameters on the model: the shifted ones when unlearning, else the originals.\n",
        "            new_params = flat_params + ihvp if unlearning else flat_params\n",
        "            self._model_reinsert_params(self._reshape_like_params(new_params), register=True)\n",
        "\n",
        "        return ihvp\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QAgQ0xEnhJ8R"
      },
      "source": [
        "# Common Functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zhu2n0bfj8io"
      },
      "outputs": [],
      "source": [
        "# train model\n",
        "def train_model(model, train_loader, test_loader, DEVICE, lr = 0.001, num_epochs = 10, l2_weight= 1e-6, dataset_type = 'adult'):\n",
        "    if dataset_type == 'adult':\n",
        "      criterion = nn.CrossEntropyLoss()\n",
        "    else:\n",
        "      criterion = nn.MSELoss()\n",
        "\n",
        "    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=l2_weight)\n",
        "\n",
        "    # Fine-tuning the model\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        running_loss = 0.0\n",
        "\n",
        "        for data, labels in train_loader:\n",
        "            data, labels = data.to(DEVICE), labels.to(DEVICE)\n",
        "\n",
        "            # Forward pass\n",
        "            outputs = model(data)\n",
        "            if dataset_type == 'adult':\n",
        "              loss = criterion(outputs, labels)\n",
        "            else:\n",
        "              loss = criterion(outputs, labels.float().view(-1, 1))\n",
        "\n",
        "            # Backward pass and optimization\n",
        "            optimizer.zero_grad()\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            running_loss += loss.item()\n",
        "\n",
        "        # print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.6f}')\n",
        "\n",
        "        # Evaluate the model on the test set\n",
        "        model.eval()\n",
        "        correct = 0\n",
        "        total = 0\n",
        "\n",
        "        with torch.no_grad():\n",
        "            for data, labels in test_loader:\n",
        "                data, labels = data.to(DEVICE), labels.to(DEVICE)\n",
        "                outputs = model(data)\n",
        "                total += labels.size(0)\n",
        "                if dataset_type == 'adult':\n",
        "                  _, predicted = torch.max(outputs.data, 1)\n",
        "                  correct += (predicted == labels).sum().item()\n",
        "                else:\n",
        "                  correct += criterion(outputs, labels.float().view(-1, 1)).item()\n",
        "\n",
        "    if dataset_type == 'adult':\n",
        "      print(f'Accuracy of the model on the test data: {100 * correct / total:.2f}%')\n",
        "    else:\n",
        "      print(f'MSE of the model on the test data: {correct / total:.6f}')\n",
        "\n",
        "    print(f'Final Loss: {running_loss/len(train_loader):.6f}')\n",
        "    return model, correct / total\n",
        "\n",
        "def calculate_perf(model, test_loader, DEVICE, dataset_type = 'adult'):\n",
        "    # Evaluate the model on the test set\n",
        "    model.eval()\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    if dataset_type == 'adult':\n",
        "      criterion = nn.CrossEntropyLoss()\n",
        "    else:\n",
        "      criterion = nn.MSELoss()\n",
        "    with torch.no_grad():\n",
        "        for data, labels in test_loader:\n",
        "            data, labels = data.to(DEVICE), labels.to(DEVICE)\n",
        "            outputs = model(data)\n",
        "\n",
        "            total += labels.size(0)\n",
        "            if dataset_type == 'adult':\n",
        "              _, predicted = torch.max(outputs.data, 1)\n",
        "              correct += (predicted == labels).sum().item()\n",
        "            else:\n",
        "              correct += criterion(outputs, labels.float().view(-1, 1)).item()\n",
        "    if dataset_type == 'adult':\n",
        "      print(f'Accuracy of the model on the test data: {100 * correct / total:.2f}%')\n",
        "    else:\n",
        "      print(f'MSE of the model on the test data: {correct / total:.6f}')\n",
        "\n",
        "    return correct / total\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mY8pfBiGyffw"
      },
      "source": [
        "# Common Functions for Fairness Metrics Analysis"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CgN2Xgz9yh4X"
      },
      "outputs": [],
      "source": [
        "# DP Loss\n",
        "def entropy_to_prob(entropy):  # Only for X Tensor of dimension 2\n",
        "    return entropy[:, 1].exp() / entropy.exp().sum(dim=1)\n",
        "\n",
        "def DemographicParity(model, encoded_data, to_protect, device='cpu'):\n",
        "    \"\"\"Absolute demographic-parity gap of `model` on `encoded_data`.\n",
        "\n",
        "    Computes |mean P(class 1) over `to_protect == 0` - mean P(class 1) over `to_protect == 1`|.\n",
        "    The result stays on the autograd graph so it can be differentiated w.r.t. the model\n",
        "    parameters. NOTE(review): an empty group makes its mean (and the result) NaN.\n",
        "    \"\"\"\n",
        "    prob_positive = entropy_to_prob(model(encoded_data.to(device)))\n",
        "    gap = prob_positive[to_protect == 0].mean() - prob_positive[to_protect == 1].mean()\n",
        "    return gap.abs()\n",
        "\n",
        "def calc_grad_dp(model, encoded_data, to_protect, device='cpu'):\n",
        "    \"\"\"Alias of `dp_calculations`; the two functions had identical bodies.\n",
        "\n",
        "    Returns `(dp, grads)`: the demographic-parity gap as a Python float and its\n",
        "    gradients w.r.t. `model.parameters()`.\n",
        "    \"\"\"\n",
        "    return dp_calculations(model, encoded_data, to_protect, device=device)\n",
        "\n",
        "def dp_calculations(model, encoded_data, to_protect, device='cpu'):\n",
        "    \"\"\"Demographic-parity gap of `model` together with its gradient.\n",
        "\n",
        "    Returns:\n",
        "        `(dp, grads)` -- the gap as a Python float and the tuple of gradients w.r.t.\n",
        "        `model.parameters()` returned by `torch.autograd.grad` (graph not retained).\n",
        "    \"\"\"\n",
        "    dp_gap = DemographicParity(model, encoded_data, to_protect, device=device)\n",
        "    grads = torch.autograd.grad(dp_gap, model.parameters(), retain_graph=False, create_graph=False)\n",
        "    return dp_gap.item(), grads\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WGP6XAyxhGaR"
      },
      "outputs": [],
      "source": [
        "def _unsqueeze_multiple_times(input, axis, times):\n",
        "    \"\"\"\n",
        "    Utils function to unsqueeze tensor to avoid cumbersome code\n",
        "    :param input: A pytorch Tensor of dimensions (D_1,..., D_k)\n",
        "    :param axis: the axis to unsqueeze repeatedly\n",
        "    :param times: the number of repetitions of the unsqueeze\n",
        "    :return: the unsqueezed tensor. ex: dimensions (D_1,... D_i, 0,0,0, D_{i+1}, ... D_k) for unsqueezing 3x axis i.\n",
        "    \"\"\"\n",
        "    output = input\n",
        "    for i in range(times):\n",
        "        output = output.unsqueeze(axis)\n",
        "    return output\n",
        "\n",
        "class kde:\n",
        "    \"\"\"\n",
        "    A Gaussian KDE implemented in pytorch for the gradients to flow in pytorch optimization.\n",
        "    Keep in mind that KDE are not scaling well with the number of dimensions and this implementation is not really\n",
        "    optimized...\n",
        "    \"\"\"\n",
        "    def __init__(self, x_train):\n",
        "        \"\"\"Store the `(n, d)` sample and derive an isotropic bandwidth from n and d.\"\"\"\n",
        "        n, d = x_train.shape\n",
        "        self.n = n\n",
        "        self.d = d\n",
        "        # Silverman's rule-of-thumb factor: (n (d + 2) / 4) ** (-1 / (d + 4)).\n",
        "        self.bandwidth = (n * (d + 2) / 4.) ** (-1. / (d + 4))\n",
        "        self.std = self.bandwidth\n",
        "        self.train_x = x_train\n",
        "\n",
        "    def pdf(self, x):\n",
        "        \"\"\"Evaluate the estimate at points `x` of shape `(..., d)`; returns shape `(...)`.\n",
        "\n",
        "        The normalising constant is that of a 1-D Gaussian, not `(sqrt(2 pi) h) ** d`, so for\n",
        "        d > 1 the values are only proportional to a density.\n",
        "        \"\"\"\n",
        "        s = x.shape\n",
        "        d = s[-1]\n",
        "        s = s[:-1]\n",
        "        assert d == self.d\n",
        "        # Follow the device of the stored sample instead of hard-coding `.cuda()`, which\n",
        "        # fails on machines without CUDA (the notebook's DEVICE falls back to 'cpu').\n",
        "        data = x.unsqueeze(-2).to(self.train_x.device)\n",
        "        train_x = _unsqueeze_multiple_times(self.train_x, 0, len(s))\n",
        "        pdf_values = (torch.exp(-((data - train_x).norm(dim=-1) ** 2 / (self.bandwidth ** 2) / 2))).mean(dim=-1) / sqrt(2 * pi) / self.bandwidth\n",
        "        return pdf_values\n",
        "\n",
        "def _joint_2(X, Y, density, damping=1e-10):\n",
        "    X = (X - X.mean()) / X.std()\n",
        "    Y = (Y - Y.mean()) / Y.std()\n",
        "    data = torch.cat([X.unsqueeze(-1), Y.unsqueeze(-1)], -1)\n",
        "    joint_density = density(data)\n",
        "\n",
        "    nbins = int(min(50, 5. / joint_density.std))\n",
        "    # nbins = np.sqrt( Y.size/5 )\n",
        "    x_centers = torch.linspace(-2.5, 2.5, nbins)\n",
        "    y_centers = torch.linspace(-2.5, 2.5, nbins)\n",
        "\n",
        "    xx, yy = torch.meshgrid([x_centers, y_centers])\n",
        "    grid = torch.cat([xx.unsqueeze(-1), yy.unsqueeze(-1)], -1)\n",
        "    h2d = joint_density.pdf(grid) + damping\n",
        "    h2d /= h2d.sum()\n",
        "    return h2d\n",
        "\n",
        "\n",
        "def chi_2(X, Y, density, damping=0):\n",
        "    r\"\"\"\n",
        "    The \\chi^2 divergence between the joint distribution on (x,y) and the product of marginals. This is known to be the\n",
        "    square of an upper-bound on the Hirschfeld-Gebelein-Renyi maximum correlation coefficient. We compute it here on\n",
        "    an empirical and discretized density estimated from the input data.\n",
        "    :param X: A torch 1-D Tensor\n",
        "    :param Y: A torch 1-D Tensor\n",
        "    :param density: so far only kde is supported\n",
        "    :param damping: constant added to every grid cell before normalisation (forwarded to `_joint_2`)\n",
        "    :return: numerical value between 0 and \\infty (0: independent)\n",
        "    \"\"\"\n",
        "    h2d = _joint_2(X, Y, density, damping=damping)\n",
        "    marginal_x = h2d.sum(dim=1).unsqueeze(1)\n",
        "    marginal_y = h2d.sum(dim=0).unsqueeze(0)\n",
        "    # Q_ij = p_ij / sqrt(p_i p_j); sum_ij Q_ij^2 - 1 is the chi^2 divergence.\n",
        "    Q = h2d / (torch.sqrt(marginal_x) * torch.sqrt(marginal_y))\n",
        "    return ((Q ** 2).sum(dim=[0, 1]) - 1.)\n",
        "\n",
        "def chi_squared_l1_kde(X, Y):\n",
        "    \"\"\"Chi^2 dependence between X and Y estimated with the Gaussian `kde` (no damping).\"\"\"\n",
        "    return chi_2(X, Y, density=kde)\n",
        "\n",
        "def chi2_calculations(model, train_data, train_protect, train_target, device='cpu', frac=1.0):\n",
        "    \"\"\"KDE-based chi^2 dependence between model outputs and the protected attribute, with its gradient.\n",
        "\n",
        "    Args:\n",
        "        model: network whose flattened outputs are compared with `train_protect`.\n",
        "        train_data: input rows fed to `model`.\n",
        "        train_protect: protected attribute, one value per row of `train_data`.\n",
        "        train_target: not used by the computation; kept so existing calls keep working.\n",
        "        device: device the inputs and protected attribute are moved to.\n",
        "        frac: if < 1, evaluate on a random Bernoulli(frac) subsample of the rows; the\n",
        "            default 1.0 uses every row and draws no random numbers.\n",
        "\n",
        "    Returns:\n",
        "        `(value, grads)`: the divergence as a Python float and its gradients w.r.t.\n",
        "        `model.parameters()` as returned by `torch.autograd.grad`.\n",
        "    \"\"\"\n",
        "    if frac < 1:\n",
        "        keep = torch.bernoulli(frac * torch.ones(train_data.shape[0])).bool()\n",
        "        inputs, protected = train_data[keep, :], train_protect[keep]\n",
        "    else:\n",
        "        inputs, protected = train_data, train_protect\n",
        "\n",
        "    outputs = model(inputs.to(device)).flatten()\n",
        "    fairness_loss = chi_squared_l1_kde(outputs, protected.to(device))\n",
        "\n",
        "    chi2_gradients = torch.autograd.grad(fairness_loss, model.parameters(), retain_graph=False, create_graph=False)\n",
        "\n",
        "    return fairness_loss.item(), chi2_gradients\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gyrawSYZj97v"
      },
      "source": [
        "# Main Code"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xrQOU33OiAFK"
      },
      "source": [
        "## Global hyperparameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HY4IPdiXiBle"
      },
      "outputs": [],
      "source": [
        "# Pick dataset\n",
        "dataset_type = 'adult' #'crime', 'adult', 'insurance'\n",
        "\n",
        "num_epochs = 200\n",
        "# Hyper Parameters\n",
        "if dataset_type == 'crime':\n",
        "  lr = 1e-4\n",
        "  batch_size = 100\n",
        "  num_classes = 1\n",
        "elif dataset_type == 'insurance':\n",
        "  lr = 1e-4\n",
        "  batch_size = 100\n",
        "  num_classes = 1\n",
        "else: # adults\n",
        "  lr = 1e-4\n",
        "  batch_size = 256\n",
        "  num_classes = 2\n",
        "\n",
        "l2_weight = 1e-6 # [0.0, 1e-8, 1e-6, 1e-4, 1e-2, 1e-1]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_mqGuFcakXB7"
      },
      "source": [
        "Load data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h4n4Ao-JkV-x"
      },
      "outputs": [],
      "source": [
        "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(f\"Using {DEVICE}\")\n",
        "\n",
        "if dataset_type in ('insurance', 'crime'):\n",
        "  # Regression datasets. Both loaders return\n",
        "  # (x_train, x_test, z_train, z_test, y_train, y_test), z being the protected attribute.\n",
        "  if dataset_type == 'insurance':\n",
        "    x_train, x_test, z_train, z_test, y_train, y_test = load_insurance()\n",
        "  else:\n",
        "    x_train, x_test, z_train, z_test, y_train, y_test = read_crimes(env_partition=0.05)\n",
        "\n",
        "  # prepare dataset for training (every tensor is moved to DEVICE)\n",
        "  train_target = torch.tensor(y_train.astype(np.float32)).to(DEVICE)\n",
        "  train_data = torch.tensor(x_train.astype(np.float32)).to(DEVICE)\n",
        "  train_loader = data_utils.DataLoader(dataset=data_utils.TensorDataset(train_data, train_target),\n",
        "                                       batch_size=batch_size, shuffle=True)\n",
        "  train_protect = torch.tensor(z_train.astype(np.float32)).to(DEVICE)\n",
        "\n",
        "  test_target = torch.tensor(y_test.astype(np.float32)).to(DEVICE)\n",
        "  test_data = torch.tensor(x_test.astype(np.float32)).to(DEVICE)\n",
        "  test_loader = data_utils.DataLoader(dataset=data_utils.TensorDataset(test_data, test_target),\n",
        "                                       batch_size=batch_size, shuffle=True)\n",
        "  test_protect = torch.tensor(z_test.astype(np.float32)).to(DEVICE)\n",
        "  input_size = x_train.shape[1]\n",
        "  influence_train_indices = list(range(len(y_train)))\n",
        "\n",
        "else:\n",
        "  # 'adult' -- also the fallback for any other dataset_type.\n",
        "  # load adult dataset, standardize features, and split into train and test\n",
        "  encoded_data, to_protect, encoded_data_test, to_protect_test = load_adult()\n",
        "  # prepare dataset for training\n",
        "  train_target = torch.tensor(encoded_data['Target']).long()\n",
        "  train_data = torch.tensor(encoded_data.drop('Target', axis=1).values.astype(np.float32))\n",
        "  train_loader = data_utils.DataLoader(dataset=data_utils.TensorDataset(train_data, train_target),\n",
        "                                       batch_size=batch_size, shuffle=True)\n",
        "  train_protect = torch.tensor(to_protect).long().to(DEVICE)\n",
        "\n",
        "  test_target = torch.tensor(encoded_data_test['Target']).long()\n",
        "  test_data = torch.tensor(encoded_data_test.drop('Target', axis=1).values.astype(np.float32))\n",
        "  test_loader = data_utils.DataLoader(dataset=data_utils.TensorDataset(test_data, test_target),\n",
        "                                       batch_size=batch_size, shuffle=True)\n",
        "  test_protect = torch.tensor(to_protect_test).long().to(DEVICE)\n",
        "  input_size = encoded_data.shape[1] - 1\n",
        "  influence_train_indices = list(range(len(encoded_data)))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U0Rb1OX0keSq"
      },
      "source": [
        "LiSSA hyperparameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r1p8ztrakf0d"
      },
      "outputs": [],
      "source": [
        "# LiSSA parameters\n",
        "repeat_lissa = 3\n",
        "depth_lissa = 2000\n",
        "scale_lissa = 250 # Full with Adult: 500, crime: 2000"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "79LYeA7-klbx"
      },
      "source": [
        "Main loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z51t-lB-jpoV"
      },
      "outputs": [],
      "source": [
        "\n",
        "class BinClassObjective_CE(BaseObjective):\n",
        "    \"\"\"Cross-entropy objective plus an L2 penalty, used for the classification ('adult') runs.\n",
        "\n",
        "    Batches are `(inputs, labels)` pairs; only `inputs` go through the model.\n",
        "    \"\"\"\n",
        "\n",
        "    def train_outputs(self, model, batch):\n",
        "        inputs = batch[0]\n",
        "        return model(inputs)\n",
        "\n",
        "    def train_loss_on_outputs(self, outputs, batch):\n",
        "        labels = batch[1]\n",
        "        return torch.nn.CrossEntropyLoss()(outputs, labels)\n",
        "\n",
        "    def train_regularization(self, params):\n",
        "        # `l2_weight` is the notebook-level hyperparameter, read at call time.\n",
        "        return l2_weight * torch.square(params.norm())\n",
        "\n",
        "    def test_loss(self, model, params, batch):\n",
        "        # Unregularized loss: the same cross-entropy as in training, without the L2 term.\n",
        "        return self.train_loss_on_outputs(self.train_outputs(model, batch), batch)\n",
        "\n",
        "class BinClassObjective_MSE(BaseObjective):\n",
        "    \"\"\"Mean-squared-error objective plus an L2 penalty, used for the regression runs.\n",
        "\n",
        "    Despite the ``BinClass`` prefix this is a regression loss: labels are cast to float and\n",
        "    reshaped to ``(-1, 1)`` to match the single-output network. Batches are\n",
        "    `(inputs, labels)` pairs.\n",
        "    \"\"\"\n",
        "\n",
        "    def train_outputs(self, model, batch):\n",
        "        inputs = batch[0]\n",
        "        return model(inputs)\n",
        "\n",
        "    def train_loss_on_outputs(self, outputs, batch):\n",
        "        targets = batch[1].float().view(-1, 1)\n",
        "        return torch.nn.MSELoss()(outputs, targets)\n",
        "\n",
        "    def train_regularization(self, params):\n",
        "        # `l2_weight` is the notebook-level hyperparameter, read at call time.\n",
        "        return l2_weight * torch.square(params.norm())\n",
        "\n",
        "    def test_loss(self, model, params, batch):\n",
        "        # Unregularized loss: the same MSE as in training, without the L2 term.\n",
        "        return self.train_loss_on_outputs(self.train_outputs(model, batch), batch)\n",
        "\n",
        "# Run the experiments\n",
        "num_tests = 10\n",
        "fairness_orig = []\n",
        "fairness_remove_Fisher = []\n",
        "fairness_remove_Hessian = []\n",
        "perf_ERM = []\n",
        "perf_Fisher = []\n",
        "perf_Hessian = []\n",
        "\n",
        "time_Newton = []\n",
        "time_gnh = []\n",
        "\n",
        "for exp_idx in tqdm(range(num_tests)):\n",
        "    print(f\"Experiment {exp_idx + 1}/{num_tests}\")\n",
        "\n",
        "    # ===========\n",
        "    # train the model\n",
        "    model = NetRegression(input_size, num_classes).to(DEVICE)\n",
        "    trained_model, orig_perf = train_model(model, train_loader, test_loader, DEVICE, lr = lr, \\\n",
        "        num_epochs = num_epochs, l2_weight= l2_weight, dataset_type=dataset_type)\n",
        "    if dataset_type == 'adult':\n",
        "      # error rate\n",
        "      perf_ERM.append(1.0 - orig_perf)\n",
        "    else:\n",
        "      # MSE\n",
        "      perf_ERM.append(orig_perf)\n",
        "\n",
        "    # ===========\n",
        "    # Calculate fairness score and its gradient\n",
        "    # ===========\n",
        "    if dataset_type == 'adult':\n",
        "      fairness_score, fairness_grad = dp_calculations(trained_model, train_data, train_protect, device=DEVICE)\n",
        "      fairness_orig.append(fairness_score)\n",
        "      print('Original fairness score: ' + str(fairness_score))\n",
        "    else:\n",
        "      fairness_score, fairness_grad = chi2_calculations(model, train_data, train_protect, train_target, device=DEVICE)\n",
        "      fairness_orig.append(fairness_score)\n",
        "      print('Original fairness score: ' + str(fairness_score))\n",
        "\n",
        "    # ===========\n",
        "    # Initialize influence module using custom objective\n",
        "    # ===========\n",
        "    # GNH\n",
        "    if dataset_type == 'adult':\n",
        "      curr_net_gnh = copy.deepcopy(trained_model)\n",
        "      lissa_gnh = LiSSAInfluenceModule(model=curr_net_gnh,objective=BinClassObjective_CE(),train_loader=train_loader,\n",
        "                                      test_loader=test_loader,device=DEVICE,damp=0.001,repeat=repeat_lissa,depth=depth_lissa,\n",
        "                                      scale=scale_lissa,gnh='gnh')\n",
        "      # Newton\n",
        "      curr_net_Newton = copy.deepcopy(trained_model)\n",
        "      lissa_Newton = LiSSAInfluenceModule(model=curr_net_Newton,objective=BinClassObjective_CE(),train_loader=train_loader,\n",
        "                                      test_loader=test_loader,device=DEVICE,damp=0.001,repeat=repeat_lissa,depth=depth_lissa,\n",
        "                                          scale=scale_lissa,gnh='Hessian')\n",
        "    else:\n",
        "      curr_net_gnh = copy.deepcopy(trained_model)\n",
        "      lissa_gnh = LiSSAInfluenceModule(model=curr_net_gnh,objective=BinClassObjective_MSE(),train_loader=train_loader,\n",
        "                                      test_loader=test_loader,device=DEVICE,damp=0.001,repeat=repeat_lissa,depth=depth_lissa,\n",
        "                                      scale=scale_lissa,gnh='gnh')\n",
        "      # Newton\n",
        "      curr_net_Newton = copy.deepcopy(trained_model)\n",
        "      lissa_Newton = LiSSAInfluenceModule(model=curr_net_Newton,objective=BinClassObjective_MSE(),train_loader=train_loader,\n",
        "                                      test_loader=test_loader,device=DEVICE,damp=0.001,repeat=repeat_lissa,depth=depth_lissa,\n",
        "                                          scale=scale_lissa,gnh='Hessian')\n",
        "\n",
        "    # ===========\n",
        "    # Calculate the influence score of each training point, where the functional is the DP fairness metric\n",
        "    # We repeat the experiment several times and report the running time and the error between the Hessian-based and the FIM-based\n",
        "    # calculations\n",
        "    # ===========\n",
        "\n",
        "    # Calculate the influence, and find the training points for which we achieve a negative influence\n",
        "    # Then, remove the samples and re-calculate the DP and ERM Loss\n",
        "\n",
        "    # Fisher\n",
        "    influences_gnh, curr_time_gnh = lissa_gnh.influences(train_idxs=influence_train_indices, test_idxs=[], target_grad = fairness_grad)\n",
        "    indices_gnh = np.where(influences_gnh.detach().numpy() < 0.0)[0]\n",
        "    new_model_gnh, curr_time_gnh2 = lissa_gnh.unlearning(indices_gnh)\n",
        "\n",
        "    if dataset_type == 'adult':\n",
        "      fairness_score, _ = dp_calculations(new_model_gnh, train_data, train_protect, device=DEVICE)\n",
        "      curr_error_rate = calculate_perf(new_model_gnh, test_loader, DEVICE, dataset_type = 'adult')\n",
        "      perf_Fisher.append(1.0 - curr_error_rate)\n",
        "      fairness_remove_Fisher.append(fairness_score)\n",
        "    else:\n",
        "      fairness_score, _ = chi2_calculations(new_model_gnh, train_data, train_protect, train_target, device=DEVICE)\n",
        "      curr_mse = calculate_perf(new_model_gnh, test_loader, DEVICE, dataset_type = dataset_type)\n",
        "      perf_Fisher.append(curr_mse)\n",
        "      fairness_remove_Fisher.append(fairness_score)\n",
        "\n",
        "    time_gnh.append(curr_time_gnh + curr_time_gnh2)\n",
        "    print('fairness score Fisher: ' + str(fairness_score))\n",
        "\n",
        "    # Hessian\n",
        "    influences_hessian, curr_time_hessian = lissa_Newton.influences(train_idxs=influence_train_indices, test_idxs=[], target_grad = fairness_grad)\n",
        "    indices_Newton = np.where(influences_hessian.detach().numpy() < 0.0)[0]\n",
        "    new_model_Newton, curr_time_hessian2 = lissa_Newton.unlearning(indices_Newton)\n",
        "    if dataset_type == 'adult':\n",
        "      fairness_score, _ = dp_calculations(new_model_Newton, train_data, train_protect, device=DEVICE)\n",
        "      curr_error_rate = calculate_perf(new_model_Newton, test_loader, DEVICE, dataset_type = 'adult')\n",
        "      perf_Hessian.append(1.0 - curr_error_rate)\n",
        "      fairness_remove_Hessian.append(fairness_score)\n",
        "    else:\n",
        "      fairness_score, _ = chi2_calculations(new_model_Newton, train_data, train_protect, train_target, device=DEVICE)\n",
        "      curr_mse = calculate_perf(new_model_Newton, test_loader, DEVICE, dataset_type = dataset_type)\n",
        "      perf_Hessian.append(curr_mse)\n",
        "      fairness_remove_Hessian.append(fairness_score)\n",
        "\n",
        "    time_Newton.append(curr_time_hessian + curr_time_hessian2)\n",
        "\n",
        "    print('fairness score Hessian: ' + str(fairness_score))\n",
        "    print('Finished experiment: ' + str(exp_idx))\n",
        "    print('=====================================')\n",
        "    print('time gnh: ' + str(curr_time_gnh + curr_time_gnh2))\n",
        "    print('time hessian: ' + str(curr_time_hessian + curr_time_hessian2))\n",
        "    print('=====================================')\n",
        "\n",
        "print('=======Finished: start saving figures=======')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gf0H2d72kNOt"
      },
      "source": [
        "Plots"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RAd-kxoCkN_O"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "plt.rcParams.update({\n",
        "    \"axes.titlesize\": 32,\n",
        "    \"axes.labelsize\": 26,\n",
        "})\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(10.4, 7.0))\n",
        "\n",
        "col_hess  = \"#14b8a6\"  # teal\n",
        "col_fish  = \"#3b82f6\"  # blue\n",
        "col_erm   = \"#8b5cf6\"  # violet\n",
        "\n",
        "ax.scatter(\n",
        "    fairness_remove_Hessian, perf_Hessian,\n",
        "    label=\"Hessian\", marker=\"*\", s=600,\n",
        "    c=col_hess, edgecolors=\"white\", linewidths=1.4, alpha=0.95, zorder=3\n",
        ")\n",
        "\n",
        "ax.scatter(\n",
        "    fairness_remove_Fisher, perf_Fisher,\n",
        "    label=\"Fisher (ours)\", marker=\"o\", s=260,\n",
        "    c=col_fish, edgecolors=\"white\", linewidths=1.2, alpha=0.95, zorder=3\n",
        ")\n",
        "ax.scatter(\n",
        "    fairness_orig, perf_ERM,\n",
        "    label=\"ERM\", marker=\"s\", s=260,\n",
        "    c=col_erm, edgecolors=\"white\", linewidths=1.2, alpha=0.95, zorder=3\n",
        ")\n",
        "\n",
        "if dataset_type == \"adult\":\n",
        "    ax.set_xlabel(r\"fairness metric\", fontweight=\"bold\", fontsize=30)\n",
        "    ax.set_ylabel(\"Error rate\", fontweight=\"bold\", fontsize=30)\n",
        "else:\n",
        "    ax.set_xlabel(r\"fairness metric\", fontweight=\"bold\", fontsize=30)\n",
        "    ax.set_ylabel(\"MSE\", fontweight=\"bold\", fontsize=30)\n",
        "\n",
        "ax.set_title(rf\"Speedup: {np.mean(time_Newton)/np.mean(time_gnh):.2f}\",\n",
        "             fontweight=\"bold\")\n",
        "\n",
        "ax.set_xscale(\"log\")\n",
        "ax.set_yscale(\"log\")\n",
        "\n",
        "ax.tick_params(axis=\"both\", which=\"both\", labelsize=24)\n",
        "\n",
        "ax.grid(True, which=\"both\", linestyle=\"--\", linewidth=0.6, alpha=0.5)\n",
        "\n",
        "leg = ax.legend(\n",
        "    loc=\"upper center\",\n",
        "    bbox_to_anchor=(0.5, -0.22),\n",
        "    ncol=3,\n",
        "    frameon=False,\n",
        "    handletextpad=0.8,\n",
        "    columnspacing=0.25,\n",
        "    borderpad=0.6,\n",
        "    labelspacing=0.6,\n",
        "    prop=dict(weight=\"bold\", size=26),\n",
        "    markerscale=1.4,\n",
        "    handlelength=2.0\n",
        ")\n",
        "\n",
        "fig.tight_layout()\n",
        "outname = {\n",
        "    \"adult\":     f\"Error_rate_DP_adult_repeat_{repeat_lissa}_depth_{depth_lissa}_scale_{scale_lissa}.pdf\",\n",
        "    \"insurance\": f\"MSE_chi2_insurance_repeat_{repeat_lissa}_depth_{depth_lissa}_scale_{scale_lissa}.pdf\",\n",
        "}.get(dataset_type, f\"MSE_chi2_crime_repeat_{repeat_lissa}_depth_{depth_lissa}_scale_{scale_lissa}.pdf\")\n",
        "\n",
        "fig.savefig(outname, bbox_inches=\"tight\")\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}