{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FairBatch on the Synthetic Data\n",
    "\n",
    "#### This Jupyter Notebook simulates FairBatch on the synthetic data.\n",
    "#### It includes three fairness metrics: equal opportunity, equalized odds, and demographic parity."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "import numpy as np\n",
    "import math\n",
    "import random\n",
    "import itertools\n",
    "import copy\n",
    "\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.utils.data.sampler import Sampler\n",
    "import torch\n",
    "\n",
    "from models import LogisticRegression, weights_init_normal, test_model\n",
    "from FairBatchSampler import FairBatch, CustomDataset\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and process the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read each split from disk and convert it straight into a torch FloatTensor.\n",
    "xz_train = torch.FloatTensor(np.load('./synthetic_data/xz_train.npy'))\n",
    "y_train = torch.FloatTensor(np.load('./synthetic_data/y_train.npy'))\n",
    "z_train = torch.FloatTensor(np.load('./synthetic_data/z_train.npy'))\n",
    "\n",
    "xz_test = torch.FloatTensor(np.load('./synthetic_data/xz_test.npy'))\n",
    "y_test = torch.FloatTensor(np.load('./synthetic_data/y_test.npy'))\n",
    "z_test = torch.FloatTensor(np.load('./synthetic_data/z_test.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------- Number of Data ----------\n",
      "Train data : 2000, Test data : 1000 \n",
      "------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Report how many examples the train and test splits contain.\n",
    "print(\"---------- Number of Data ----------\")\n",
    "print(\"Train data : %d, Test data : %d \" % (len(y_train), len(y_test)))\n",
    "print(\"------------------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_epoch(model, train_features, labels, optimizer, criterion):\n",
    "    \"\"\"Trains the model with the given train data.\n",
    "\n",
    "    Args:\n",
    "        model: A torch model to train.\n",
    "        train_features: A torch tensor indicating the train features.\n",
    "        labels: A torch tensor indicating the true labels.\n",
    "        optimizer: A torch optimizer.\n",
    "        criterion: A torch criterion.\n",
    "\n",
    "    Returns:\n",
    "        loss value.\n",
    "    \"\"\"\n",
    "    \n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    label_predicted = model.forward(train_features)\n",
    "    loss  = criterion((F.tanh(label_predicted.squeeze())+1)/2, (labels.squeeze()+1)/2)\n",
    "    loss.backward()\n",
    "\n",
    "    optimizer.step()\n",
    "    \n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. FairBatch w.r.t. Equal Opportunity\n",
    "### The results are in Section 4.1 of the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "< Seed: 0 >\n",
      "  Test accuracy: 0.8550000190734863, EO disparity: 0.012102903728662517\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 1 >\n",
      "  Test accuracy: 0.8550000190734863, EO disparity: 0.012102903728662517\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 2 >\n",
      "  Test accuracy: 0.8550000190734863, EO disparity: 0.012102903728662517\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 3 >\n",
      "  Test accuracy: 0.8550000190734863, EO disparity: 0.012102903728662517\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 4 >\n",
      "  Test accuracy: 0.8550000190734863, EO disparity: 0.012102903728662517\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 5 >\n",
      "  Test accuracy: 0.8550000190734863, EO disparity: 0.012102903728662517\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 6 >\n",
      "  Test accuracy: 0.8550000190734863, EO disparity: 0.012102903728662517\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 7 >\n",
      "  Test accuracy: 0.8550000190734863, EO disparity: 0.012102903728662517\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 8 >\n",
      "  Test accuracy: 0.8569999933242798, EO disparity: 0.014029686002265618\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 9 >\n",
      "  Test accuracy: 0.8550000190734863, EO disparity: 0.012102903728662517\n",
      "----------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Collects the test metrics returned for every seed.\n",
    "full_tests = []\n",
    "\n",
    "# Bundle the training tensors into a dataset object\n",
    "train_data = CustomDataset(xz_train, y_train, z_train)\n",
    "\n",
    "seeds = list(range(10))\n",
    "for seed in seeds:\n",
    "\n",
    "    print(\"< Seed: {} >\".format(seed))\n",
    "\n",
    "    # --- Model, optimizer, and criterion ---\n",
    "    model = LogisticRegression(3,1)\n",
    "\n",
    "    # Seed torch right before (re-)initializing the weights.\n",
    "    torch.manual_seed(seed)\n",
    "    model.apply(weights_init_normal)\n",
    "\n",
    "    optimizer = torch.optim.Adam(\n",
    "        model.parameters(), lr=0.0005, betas=(0.9, 0.999)\n",
    "    )\n",
    "    criterion = torch.nn.BCELoss()\n",
    "\n",
    "    # Mean training loss of each epoch.\n",
    "    losses = []\n",
    "\n",
    "    # --- FairBatch sampler (equal opportunity) and DataLoader ---\n",
    "    sampler = FairBatch(\n",
    "        model, train_data.x, train_data.y, train_data.z,\n",
    "        batch_size = 100,\n",
    "        alpha = 0.005,\n",
    "        target_fairness = 'eqopp',\n",
    "        replacement = False,\n",
    "        seed = seed,\n",
    "    )\n",
    "    train_loader = DataLoader(train_data, sampler=sampler, num_workers=0)\n",
    "\n",
    "    # --- Model training ---\n",
    "    for epoch in range(300):\n",
    "        # One optimization step per batch yielded by the loader.\n",
    "        batch_losses = [\n",
    "            run_epoch(model, data, target, optimizer, criterion)\n",
    "            for data, target, _z in train_loader\n",
    "        ]\n",
    "        losses.append(sum(batch_losses)/len(batch_losses))\n",
    "\n",
    "    # --- Evaluation on the test split ---\n",
    "    test_metrics = test_model(model, xz_test, y_test, z_test)\n",
    "    full_tests.append(test_metrics)\n",
    "\n",
    "    print(\"  Test accuracy: {}, EO disparity: {}\".format(test_metrics['Acc'], test_metrics['EO_Y1_diff']))\n",
    "    print(\"----------------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy (avg): 0.8552000164985657\n",
      "EO disparity  (avg): 0.012295581956022827\n"
     ]
    }
   ],
   "source": [
    "# Average the per-seed test metrics gathered in full_tests.\n",
    "tmp_acc = [full_tests[i]['Acc'] for i in range(len(seeds))]\n",
    "tmp_eo = [full_tests[i]['EO_Y1_diff'] for i in range(len(seeds))]\n",
    "\n",
    "print(\"Test accuracy (avg): {}\".format(sum(tmp_acc)/len(tmp_acc)))\n",
    "print(\"EO disparity  (avg): {}\".format(sum(tmp_eo)/len(tmp_eo)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. FairBatch w.r.t. Equalized Odds \n",
    "### The results are in the supplementary material of the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "< Seed: 0 >\n",
      "  Test accuracy: 0.8579999804496765, ED disparity: 0.035440184972895264\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 1 >\n",
      "  Test accuracy: 0.8550000190734863, ED disparity: 0.04270697728641655\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 2 >\n",
      "  Test accuracy: 0.8560000061988831, ED disparity: 0.035440184972895264\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 3 >\n",
      "  Test accuracy: 0.8560000061988831, ED disparity: 0.04270697728641655\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 4 >\n",
      "  Test accuracy: 0.8579999804496765, ED disparity: 0.035440184972895264\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 5 >\n",
      "  Test accuracy: 0.8560000061988831, ED disparity: 0.035440184972895264\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 6 >\n",
      "  Test accuracy: 0.8569999933242798, ED disparity: 0.035440184972895264\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 7 >\n",
      "  Test accuracy: 0.8569999933242798, ED disparity: 0.035440184972895264\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 8 >\n",
      "  Test accuracy: 0.8579999804496765, ED disparity: 0.035440184972895264\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 9 >\n",
      "  Test accuracy: 0.8519999980926514, ED disparity: 0.04270697728641655\n",
      "----------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Collects the test metrics returned for every seed.\n",
    "full_tests = []\n",
    "\n",
    "# Bundle the training tensors into a dataset object\n",
    "train_data = CustomDataset(xz_train, y_train, z_train)\n",
    "\n",
    "seeds = list(range(10))\n",
    "for seed in seeds:\n",
    "\n",
    "    print(\"< Seed: {} >\".format(seed))\n",
    "\n",
    "    # --- Model, optimizer, and criterion ---\n",
    "    model = LogisticRegression(3,1)\n",
    "\n",
    "    # Seed torch right before (re-)initializing the weights.\n",
    "    torch.manual_seed(seed)\n",
    "    model.apply(weights_init_normal)\n",
    "\n",
    "    optimizer = torch.optim.Adam(\n",
    "        model.parameters(), lr=0.0005, betas=(0.9, 0.999)\n",
    "    )\n",
    "    criterion = torch.nn.BCELoss()\n",
    "\n",
    "    # Mean training loss of each epoch.\n",
    "    losses = []\n",
    "\n",
    "    # --- FairBatch sampler (equalized odds) and DataLoader ---\n",
    "    sampler = FairBatch(\n",
    "        model, train_data.x, train_data.y, train_data.z,\n",
    "        batch_size = 100,\n",
    "        alpha = 0.005,\n",
    "        target_fairness = 'eqodds',\n",
    "        replacement = False,\n",
    "        seed = seed,\n",
    "    )\n",
    "    train_loader = DataLoader(train_data, sampler=sampler, num_workers=0)\n",
    "\n",
    "    # --- Model training ---\n",
    "    for epoch in range(400):\n",
    "        # One optimization step per batch yielded by the loader.\n",
    "        batch_losses = [\n",
    "            run_epoch(model, data, target, optimizer, criterion)\n",
    "            for data, target, _z in train_loader\n",
    "        ]\n",
    "        losses.append(sum(batch_losses)/len(batch_losses))\n",
    "\n",
    "    # --- Evaluation on the test split ---\n",
    "    test_metrics = test_model(model, xz_test, y_test, z_test)\n",
    "    full_tests.append(test_metrics)\n",
    "\n",
    "    print(\"  Test accuracy: {}, ED disparity: {}\".format(test_metrics['Acc'], test_metrics['EqOdds_diff']))\n",
    "    print(\"----------------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy (avg): 0.8562999963760376\n",
      "ED disparity  (avg): 0.03762022266695164\n"
     ]
    }
   ],
   "source": [
    "# Average the per-seed test metrics gathered in full_tests.\n",
    "tmp_acc = [full_tests[i]['Acc'] for i in range(len(seeds))]\n",
    "tmp_ed = [full_tests[i]['EqOdds_diff'] for i in range(len(seeds))]\n",
    "\n",
    "print(\"Test accuracy (avg): {}\".format(sum(tmp_acc)/len(tmp_acc)))\n",
    "print(\"ED disparity  (avg): {}\".format(sum(tmp_ed)/len(tmp_ed)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. FairBatch w.r.t. Demographic Parity\n",
    "### The results are in Section 4.1 of the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "< Seed: 0 >\n",
      "  Test accuracy: 0.7940000295639038, DP disparity: 0.040395784543325486\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 1 >\n",
      "  Test accuracy: 0.7950000166893005, DP disparity: 0.04242154566744727\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 2 >\n",
      "  Test accuracy: 0.7950000166893005, DP disparity: 0.039395784543325485\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 3 >\n",
      "  Test accuracy: 0.7940000295639038, DP disparity: 0.04307962529274001\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 4 >\n",
      "  Test accuracy: 0.7950000166893005, DP disparity: 0.039737704918032746\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 5 >\n",
      "  Test accuracy: 0.7929999828338623, DP disparity: 0.038711943793911074\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 6 >\n",
      "  Test accuracy: 0.7900000214576721, DP disparity: 0.03571194379391107\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 7 >\n",
      "  Test accuracy: 0.7940000295639038, DP disparity: 0.038737704918032745\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 8 >\n",
      "  Test accuracy: 0.7960000038146973, DP disparity: 0.041370023419203816\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 9 >\n",
      "  Test accuracy: 0.7960000038146973, DP disparity: 0.04073770491803275\n",
      "----------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Collects the test metrics returned for every seed.\n",
    "full_tests = []\n",
    "\n",
    "# Bundle the training tensors into a dataset object\n",
    "train_data = CustomDataset(xz_train, y_train, z_train)\n",
    "\n",
    "seeds = list(range(10))\n",
    "for seed in seeds:\n",
    "\n",
    "    print(\"< Seed: {} >\".format(seed))\n",
    "\n",
    "    # --- Model, optimizer, and criterion ---\n",
    "    model = LogisticRegression(3,1)\n",
    "\n",
    "    # Seed torch right before (re-)initializing the weights.\n",
    "    torch.manual_seed(seed)\n",
    "    model.apply(weights_init_normal)\n",
    "\n",
    "    optimizer = torch.optim.Adam(\n",
    "        model.parameters(), lr=0.0005, betas=(0.9, 0.999)\n",
    "    )\n",
    "    criterion = torch.nn.BCELoss()\n",
    "\n",
    "    # Mean training loss of each epoch.\n",
    "    losses = []\n",
    "\n",
    "    # --- FairBatch sampler (demographic parity) and DataLoader ---\n",
    "    sampler = FairBatch(\n",
    "        model, train_data.x, train_data.y, train_data.z,\n",
    "        batch_size = 100,\n",
    "        alpha = 0.005,\n",
    "        target_fairness = 'dp',\n",
    "        replacement = False,\n",
    "        seed = seed,\n",
    "    )\n",
    "    train_loader = DataLoader(train_data, sampler=sampler, num_workers=0)\n",
    "\n",
    "    # --- Model training ---\n",
    "    for epoch in range(450):\n",
    "        # One optimization step per batch yielded by the loader.\n",
    "        batch_losses = [\n",
    "            run_epoch(model, data, target, optimizer, criterion)\n",
    "            for data, target, _z in train_loader\n",
    "        ]\n",
    "        losses.append(sum(batch_losses)/len(batch_losses))\n",
    "\n",
    "    # --- Evaluation on the test split ---\n",
    "    test_metrics = test_model(model, xz_test, y_test, z_test)\n",
    "    full_tests.append(test_metrics)\n",
    "\n",
    "    print(\"  Test accuracy: {}, DP disparity: {}\".format(test_metrics['Acc'], test_metrics['DP_diff']))\n",
    "    print(\"----------------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy (avg): 0.7942000150680542\n",
      "DP disparity  (avg): 0.040029976580796246\n"
     ]
    }
   ],
   "source": [
    "# Average the per-seed test metrics gathered in full_tests.\n",
    "tmp_acc = [full_tests[i]['Acc'] for i in range(len(seeds))]\n",
    "tmp_dp = [full_tests[i]['DP_diff'] for i in range(len(seeds))]\n",
    "\n",
    "print(\"Test accuracy (avg): {}\".format(sum(tmp_acc)/len(tmp_acc)))\n",
    "print(\"DP disparity  (avg): {}\".format(sum(tmp_dp)/len(tmp_dp)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PyTorch",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
