{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "edbe7f63",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shenyu/miniconda3/envs/DLcourse/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import os\n",
    "import numpy as np\n",
    "from time import time\n",
    "import os\n",
    "from collections import OrderedDict\n",
    "import FrEIA.framework as Ff\n",
    "import FrEIA.modules as Fm\n",
    "from tqdm import tqdm, trange\n",
    "import random\n",
    "\n",
    "torch.set_num_threads(5)\n",
    "\n",
    "def seed_everything(seed):\n",
    "    \"\"\"\n",
    "    Changes the seed for reproducibility. \n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    \n",
    "seed_everything(0)\n",
    "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
    "path = r\"adult.json\"\n",
    "test_path = r\"adult_test.json\"\n",
    "train_df = pd.read_json(path)\n",
    "test_df = pd.read_json(test_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f6f313a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def subnet_fc(c_in, c_out, init_identity):\n",
    "    subnet = nn.Sequential(nn.Linear(c_in, 32), nn.ReLU(),\n",
    "                            nn.Linear(32, 32), nn.ReLU(),\n",
    "                            nn.Linear(32,  c_out))\n",
    "    if init_identity:\n",
    "        subnet[-1].weight.data.fill_(0.)\n",
    "        subnet[-1].bias.data.fill_(0.)\n",
    "    return subnet\n",
    "\n",
    "\n",
    "def construct_net(init_identity=True):\n",
    "    block = Fm.GINCouplingBlock\n",
    "    nodes = [Ff.InputNode(32, name='input')]\n",
    "    \n",
    "    for k in range(18):\n",
    "        nodes.append(Ff.Node(nodes[-1], block,\n",
    "                             {'subnet_constructor':lambda c_in,c_out: subnet_fc(c_in, c_out, init_identity), 'clamp':2.0},\n",
    "                             name=F'coupling_{k}'))\n",
    "        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom,\n",
    "                        {'seed':np.random.randint(2**31)},\n",
    "                        name=F'permute_{k+1}'))\n",
    "\n",
    "    nodes.append(Ff.OutputNode(nodes[-1], name='output'))\n",
    "    return Ff.ReversibleGraphNet(nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6945f027",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ReversibleGraphNet(\n",
      "  (module_list): ModuleList(\n",
      "    (0): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (1): PermuteRandom()\n",
      "    (2): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (3): PermuteRandom()\n",
      "    (4): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (5): PermuteRandom()\n",
      "    (6): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (7): PermuteRandom()\n",
      "    (8): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (9): PermuteRandom()\n",
      "    (10): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (11): PermuteRandom()\n",
      "    (12): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (13): PermuteRandom()\n",
      "    (14): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (15): PermuteRandom()\n",
      "    (16): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (17): PermuteRandom()\n",
      "    (18): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (19): PermuteRandom()\n",
      "    (20): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (21): PermuteRandom()\n",
      "    (22): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (23): PermuteRandom()\n",
      "    (24): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (25): PermuteRandom()\n",
      "    (26): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (27): PermuteRandom()\n",
      "    (28): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (29): PermuteRandom()\n",
      "    (30): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (31): PermuteRandom()\n",
      "    (32): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (33): PermuteRandom()\n",
      "    (34): GINCouplingBlock(\n",
      "      (subnet1): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "      (subnet2): Sequential(\n",
      "        (0): Linear(in_features=16, out_features=32, bias=True)\n",
      "        (1): ReLU()\n",
      "        (2): Linear(in_features=32, out_features=32, bias=True)\n",
      "        (3): ReLU()\n",
      "        (4): Linear(in_features=32, out_features=32, bias=True)\n",
      "      )\n",
      "    )\n",
      "    (35): PermuteRandom()\n",
      "  )\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = construct_net()\n",
    "weight_site = r\"Adult/model_121.pt\"\n",
    "net.load_state_dict(torch.load(weight_site, map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ac355d74",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "def get_naive_dataset(dataset):\n",
    "    X = dataset.drop(['income'], axis=1)\n",
    "    y = dataset['income']\n",
    "    x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state =2048)# 2**(0))\n",
    "    a_train = x_train['sex']\n",
    "    a_test = x_test['sex']\n",
    "    x_train = x_train.drop(['race','sex'], axis=1)\n",
    "    x_test = x_test.drop(['race','sex'], axis=1)\n",
    "    return (x_train, y_train, a_train), (x_test, y_test, a_test)\n",
    "\n",
    "(x_train, y_train, a_train), (x_valid, y_valid, a_valid) = get_naive_dataset(train_df)\n",
    "x_test, y_test, a_test = test_df.drop(['income','race','sex'], axis=1), test_df['income'], test_df['sex']\n",
    "\n",
    "\n",
    "batch_size = 256\n",
    "def make_dataloader(data, y, a,batch_size):\n",
    "    dataset = BasicDataset(data, y, a)\n",
    "    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
    "    return dataloader\n",
    "\n",
    "\n",
    "class BasicDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, data, target, sensitive):\n",
    "        super().__init__()\n",
    "        self.data = torch.tensor(data.values)\n",
    "        self.target = torch.tensor(target.values)\n",
    "        self.sensitive = torch.tensor(sensitive.values)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.data.size(0)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        return self.data.float()[idx], self.target[idx], self.sensitive[idx]\n",
    "    \n",
    "training_loader = make_dataloader(x_train, y_train, a_train, batch_size)\n",
    "valid_loader = make_dataloader(x_valid, y_valid, a_valid, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "041a3ef0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 22.48it/s]\n"
     ]
    }
   ],
   "source": [
    "net = net.to(device)\n",
    "valid_data = []\n",
    "valid_sensitive_list = []\n",
    "valid_target_list = []\n",
    "examples = iter(valid_loader)\n",
    "n_batches = len(examples)\n",
    "for i in trange(n_batches):\n",
    "    data, label, sensitive = next(examples)\n",
    "    data = data.to(device)\n",
    "    z, logdet = net(data)\n",
    "    valid_data.extend(z.detach().cpu().numpy())\n",
    "    valid_sensitive_list.extend(sensitive.detach().cpu().numpy())\n",
    "    valid_target_list.extend(label.detach().cpu().numpy())\n",
    "\n",
    "x_valid = np.array(valid_data)\n",
    "valid_sensitive = np.array(valid_sensitive_list)\n",
    "valid_target = np.array(valid_target_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3f07e88c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#MMD correlation\n",
    "#https://github.com/torchdrift/torchdrift/blob/master/notebooks/note_on_mmd.ipynb\n",
    "def mmd(x, y, sigma):\n",
    "    n, d = x.shape\n",
    "    m, d2 = y.shape\n",
    "    assert d == d2\n",
    "    xy = torch.cat([x.detach(), y.detach()], dim=0)\n",
    "    dists = torch.cdist(xy, xy, p=2.0)\n",
    "    k = torch.exp((-1/(2*sigma**2)) * dists**2) + torch.eye(n+m)*1e-5\n",
    "    k_x = k[:n, :n]\n",
    "    k_y = k[n:, n:]\n",
    "    k_xy = k[:n, n:]\n",
    "    mmd = k_x.sum() / (n * (n - 1)) + k_y.sum() / (m * (m - 1)) - 2 * k_xy.sum() / (n * m)\n",
    "    return mmd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ec0dd5a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:08<00:00,  3.86it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "label_score = []\n",
    "percentage_to_use = 100\n",
    "total_indices = np.arange(x_valid.shape[0])\n",
    "selected_indices = np.random.choice(total_indices, size=int(percentage_to_use / 100* len(total_indices)), replace=False)\n",
    "\n",
    "selected_data = x_valid[selected_indices]\n",
    "selected_labels = valid_target[selected_indices]\n",
    "\n",
    "selected_data_0 = torch.tensor(selected_data[selected_labels== 0])\n",
    "selected_data_1 = torch.tensor(selected_data[selected_labels== 1])\n",
    "\n",
    "\n",
    "for i in tqdm(range(selected_data_0.shape[1])):\n",
    "    label0_subsample = selected_data_0[:,i]\n",
    "    label1_subsample = selected_data_1[:,i]\n",
    "    score_for_dim = mmd(label0_subsample[:, None],label1_subsample[:, None],7)\n",
    "    label_score.append(score_for_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "23285da2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:08<00:00,  3.87it/s]\n"
     ]
    }
   ],
   "source": [
    "score = []\n",
    "percentage_to_use = 100\n",
    "total_indices = np.arange(x_valid.shape[0])\n",
    "selected_indices = np.random.choice(total_indices, size=int(percentage_to_use / 100* len(total_indices)), replace=False)\n",
    "\n",
    "\n",
    "selected_data = x_valid[selected_indices]\n",
    "selected_labels = valid_sensitive[selected_indices]\n",
    "\n",
    "selected_data_0 = torch.tensor(selected_data[selected_labels == 0])\n",
    "selected_data_1 = torch.tensor(selected_data[selected_labels == 1])\n",
    "\n",
    "\n",
    "for i in tqdm(range(selected_data_0.shape[1])):\n",
    "    sensitive0_subsample = selected_data_0[:,i]\n",
    "    sensitive1_subsample = selected_data_1[:,i]\n",
    "    score_for_dim = mmd(sensitive0_subsample[:, None],sensitive1_subsample[:, None],7)\n",
    "    score.append(score_for_dim)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8b2c43e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(21,)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dim_sen = 6\n",
    "score_np = np.array(score)\n",
    "large_index_sensitive = score_np.argsort()[-1*dim_sen:][::-1]\n",
    "dim_label = 26\n",
    "label_score_np = np.array(label_score)\n",
    "sorted_indices = np.argsort(label_score_np)\n",
    "largest_two_indices = sorted_indices[-1*dim_label:]\n",
    "dim_to_delete = large_index_sensitive\n",
    "common_elements = np.intersect1d(largest_two_indices, dim_to_delete)\n",
    "common_elements.shape\n",
    "label_score_without_common = np.setdiff1d(largest_two_indices, common_elements)\n",
    "label_score_without_common.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fc203253",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data_np =np.array(x_train)\n",
    "training_data =train_data_np[:, label_score_without_common]\n",
    "test_data_np = np.array(x_test)\n",
    "test_data =test_data_np[:, label_score_without_common]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f5298010",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_target = np.array(y_train)\n",
    "test_target = np.array(y_test)\n",
    "test_sensitive = np.array(a_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6fed5db3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Basic_dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, data, target, sensitive):\n",
    "        super().__init__()\n",
    "        self.data = torch.tensor(data)\n",
    "        self.target = torch.tensor(target)\n",
    "        self.sensitive = torch.tensor(sensitive)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.data.size(0)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        return self.data.float()[idx], self.target[idx], self.sensitive[idx]\n",
    "    \n",
    "train_set = Basic_dataset(training_data,train_target,  np.array(a_train))\n",
    "test_set = Basic_dataset(test_data, test_target, test_sensitive)\n",
    "training_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f7148a5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Tablur_Model(nn.Module):\n",
    "    \n",
    "    def __init__(self):\n",
    "        super(Tablur_Model, self).__init__()\n",
    "        nodes = 32\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.dropout = nn.Dropout()\n",
    "        self.fc1 = nn.Linear(21, nodes)\n",
    "        self.fc2 = nn.Linear(nodes, 2*nodes)\n",
    "        self.fc = nn.Linear(2*nodes, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.relu(self.fc2(x))\n",
    "        x = self.fc(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2c47448b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 0.000000 : 100%|████████████████████████████████████████████████████████████| 89/89 [00:04<00:00, 18.30batch/s, ul=0.369]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10, Loss: 0.5218\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1.000000 : 100%|████████████████████████████████████████████████████████████| 89/89 [00:04<00:00, 18.35batch/s, ul=0.368]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/10, Loss: 0.3594\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 2.000000 : 100%|████████████████████████████████████████████████████████████| 89/89 [00:04<00:00, 18.39batch/s, ul=0.362]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/10, Loss: 0.3458\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 3.000000 : 100%|████████████████████████████████████████████████████████████| 89/89 [00:04<00:00, 18.21batch/s, ul=0.408]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/10, Loss: 0.3429\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 4.000000 : 100%|████████████████████████████████████████████████████████████| 89/89 [00:04<00:00, 17.84batch/s, ul=0.311]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/10, Loss: 0.3406\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 5.000000 : 100%|█████████████████████████████████████████████████████████████| 89/89 [00:04<00:00, 18.01batch/s, ul=0.35]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6/10, Loss: 0.3394\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 6.000000 : 100%|████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.72batch/s, ul=0.336]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7/10, Loss: 0.3387\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 7.000000 : 100%|████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.73batch/s, ul=0.308]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8/10, Loss: 0.3375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 8.000000 : 100%|████████████████████████████████████████████████████████████| 89/89 [00:04<00:00, 17.80batch/s, ul=0.358]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/10, Loss: 0.3372\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 9.000000 : 100%|████████████████████████████████████████████████████████████| 89/89 [00:05<00:00, 17.79batch/s, ul=0.421]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10/10, Loss: 0.3374\n",
      "***********************************************\n",
      "test procedure\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 59/59 [00:02<00:00, 26.20batch/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Female TPR 0.573055028462998\n",
      "male TPR 0.5880829015544041\n",
      "DP 0.00648671270140197\n",
      "EOP 0.015027873091406074\n",
      "EoD 0.01131386727054838\n",
      "acc 0.8416998671978752\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix, accuracy_score\n",
    "import torch.optim as optim\n",
    "import random\n",
    "\n",
    "\n",
    "def _group_rates(y_true, y_pred):\n",
    "    \"\"\"Return (positive-prediction rate, TPR, FPR) for one group's labels/predictions.\"\"\"\n",
    "    # labels=[0, 1] pins the matrix to 2x2 even when the group is missing a class,\n",
    "    # so the four-way unpacking below cannot hit a 1x1 matrix.\n",
    "    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()\n",
    "    positive_rate = (tp + fp) / (tn + fp + fn + tp)\n",
    "    tpr = tp / (tp + fn)\n",
    "    fpr = fp / (fp + tn)\n",
    "    return positive_rate, tpr, fpr\n",
    "\n",
    "\n",
    "def test(net, dataloader, print_fairness=True):\n",
    "    \"\"\"Evaluate ``net`` on ``dataloader`` and measure accuracy and group-fairness gaps.\n",
    "\n",
    "    Every batch must unpack into ``(features, label, sensitive)``. Rows whose\n",
    "    sensitive value equals 0 form the group printed as 'Female'; all other rows\n",
    "    form the group printed as 'male'.\n",
    "\n",
    "    Args:\n",
    "        net: model returning one logit per row (output is passed through a sigmoid\n",
    "            and rounded to obtain the 0/1 prediction).\n",
    "        dataloader: iterable of ``(features, label, sensitive)`` batches.\n",
    "        print_fairness: when truthy, print both TPRs, the DP / EOP / EoD gaps and\n",
    "            the accuracy.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(accuracy, mean_loss, EOd)`` where ``mean_loss`` is the\n",
    "        BCE-with-logits loss averaged over batches and ``EOd`` is the mean of the\n",
    "        absolute FPR gap and the absolute TPR gap between the two groups.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    criteria = nn.BCEWithLogitsLoss()\n",
    "    epoch_loss = 0\n",
    "    test_pred, test_gt, sense_gt = [], [], []\n",
    "    with torch.no_grad():\n",
    "        with tqdm(dataloader, unit='batch') as tepoch:\n",
    "            # Take the sensitive attribute from the batch itself: the group split\n",
    "            # below must stay row-aligned with the predictions, so it must not be\n",
    "            # read from a same-named variable left over in the kernel.\n",
    "            for features, label, sensitive in tepoch:\n",
    "                # .float() mirrors the input handling in train_model.\n",
    "                features = features.float().to(device)\n",
    "                label = label.unsqueeze(1).to(torch.float).to(device)\n",
    "                logits = net(features)\n",
    "                epoch_loss += criteria(logits, label).item()\n",
    "                hard_pred = torch.round(torch.sigmoid(logits))\n",
    "                # Flatten so every stored element is a scalar, not a length-1 array.\n",
    "                test_pred.extend(hard_pred.reshape(-1).cpu().numpy())\n",
    "                test_gt.extend(label.reshape(-1).cpu().numpy())\n",
    "                # assumes one sensitive value per row -- TODO confirm against the dataset\n",
    "                sense_gt.extend(sensitive.reshape(-1).cpu().numpy())\n",
    "    epoch_loss = epoch_loss / len(dataloader)\n",
    "\n",
    "    test_pred = np.asarray(test_pred)\n",
    "    test_gt = np.asarray(test_gt)\n",
    "    is_female = np.asarray(sense_gt) == 0\n",
    "\n",
    "    female_dp, female_TPR, female_FPR = _group_rates(test_gt[is_female], test_pred[is_female])\n",
    "    male_dp, male_TPR, male_FPR = _group_rates(test_gt[~is_female], test_pred[~is_female])\n",
    "\n",
    "    accuracy = accuracy_score(test_gt, test_pred)\n",
    "    EOd = 0.5 * (abs(female_FPR - male_FPR) + abs(female_TPR - male_TPR))\n",
    "    if print_fairness:\n",
    "        print('Female TPR', female_TPR)\n",
    "        print('male TPR', male_TPR)\n",
    "        print('DP', abs(female_dp - male_dp))\n",
    "        print('EOP', abs(female_TPR - male_TPR))\n",
    "        print('EoD', EOd)\n",
    "        print('acc', accuracy)\n",
    "    return accuracy, epoch_loss, EOd\n",
    "\n",
    "\n",
    "def train_model(num_epochs=10, lr=1e-3):\n",
    "    \"\"\"Train a fresh ``Tablur_Model`` on ``training_loader``.\n",
    "\n",
    "    Uses Adam over the parameters that require gradients and a\n",
    "    BCE-with-logits loss; targets get a trailing dimension added before the loss.\n",
    "\n",
    "    Args:\n",
    "        num_epochs: number of passes over ``training_loader`` (default 10).\n",
    "        lr: learning rate handed to Adam (default 1e-3).\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(train_loss, valid_loss, model)``: the per-epoch mean training\n",
    "        loss, a list that stays empty because validation is disabled, and the\n",
    "        trained model.\n",
    "    \"\"\"\n",
    "    model = Tablur_Model().to(device)\n",
    "    criterion = nn.BCEWithLogitsLoss()\n",
    "    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)\n",
    "    train_loss = []\n",
    "    valid_loss = []  # never filled: the validation pass below is commented out\n",
    "\n",
    "    for epoch_idx in range(num_epochs):\n",
    "        model.train()\n",
    "        running_loss = 0.0\n",
    "        with tqdm(training_loader, unit='batch') as tepoch:\n",
    "            # Set once per epoch (it does not change per batch) and 1-indexed so the\n",
    "            # bar label matches the 'Epoch i/N' summary printed afterwards.\n",
    "            tepoch.set_description(f'epoch {epoch_idx + 1}/{num_epochs}')\n",
    "            for train_input, train_target, _ in tepoch:\n",
    "                train_input = train_input.float().to(device)\n",
    "                train_target = train_target.float().to(device).unsqueeze(1)\n",
    "                optimizer.zero_grad()\n",
    "                outputs = model(train_input)\n",
    "                loss = criterion(outputs, train_target)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                batch_loss = loss.item()\n",
    "                running_loss += batch_loss\n",
    "                tepoch.set_postfix(ul=batch_loss)\n",
    "        run_loss = running_loss / len(training_loader)\n",
    "        print(f'Epoch {epoch_idx + 1}/{num_epochs}, Loss: {run_loss:.4f}')\n",
    "        train_loss.append(run_loss)\n",
    "\n",
    "        # Validation is currently disabled; valid_loss is returned empty.\n",
    "        #valid_acc, _, valid_EOd = test(model, valid_loader, False)\n",
    "        #print('Valid ACC:',valid_acc, 'Valid EOd', valid_EOd)\n",
    "    return train_loss, valid_loss, model\n",
    "        \n",
    "\n",
    "# Fit the classifier; keep both loss histories next to the trained model.\n",
    "tl, vl, model = train_model()\n",
    "\n",
    "# Banner separating the training log from the evaluation on test_loader.\n",
    "print('***********************************************', 'test procedure', sep='\\n')\n",
    "# Only the printed fairness report is wanted here, so the returned tuple is discarded.\n",
    "_, _, _ = test(model, test_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:DLcourse]",
   "language": "python",
   "name": "conda-env-DLcourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
