{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fairness under Correlation Shifts\n",
    "\n",
    "### This Jupyter Notebook simulates the proposed pre-processing approach on the synthetic data.\n",
    "\n",
    "We consider two scenarios: supporting (1) a single metric (DP) and (2) multiple metrics (DP & EO).\n",
    "\n",
    "We use FairBatch [Roh et al., ICLR 2021] as an in-processing approach."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "import numpy as np\n",
    "import math\n",
    "import random \n",
    "import itertools\n",
    "\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.utils.data.sampler import Sampler\n",
    "import torch\n",
    "\n",
    "from models import LogisticRegression, weights_init_normal\n",
    "from FairBatchSampler_Multiple import FairBatch, CustomDataset\n",
    "from utils import correlation_reweighting, datasampling, test_model\n",
    "\n",
    "import cvxopt\n",
    "import cvxpy as cp\n",
    "from cvxpy import OPTIMAL, Minimize, Problem, Variable, quad_form # Work in YJ kernel\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and process the data\n",
    "\n",
    "In the synthetic_data directory, there are a total of 6 numpy files including training data and test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read every split from disk and wrap it in a 32-bit float torch tensor in one step,\n",
    "# so each name is bound exactly once.\n",
    "xz_train = torch.FloatTensor(np.load('./synthetic_data/xz_train.npy'))\n",
    "y_train = torch.FloatTensor(np.load('./synthetic_data/y_train.npy'))\n",
    "z_train = torch.FloatTensor(np.load('./synthetic_data/z_train.npy'))\n",
    "\n",
    "# Test split, stored with the same three-file layout as the training split.\n",
    "xz_test = torch.FloatTensor(np.load('./synthetic_data/xz_test.npy'))\n",
    "y_test = torch.FloatTensor(np.load('./synthetic_data/y_test.npy'))\n",
    "z_test = torch.FloatTensor(np.load('./synthetic_data/z_test.npy'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters and functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One training run is performed per seed.\n",
    "seeds = list(range(5))\n",
    "\n",
    "# Empirical fraction of training examples in each (z, y) group, in the order\n",
    "# (z=1, y=1), (z=0, y=1), (z=1, y=-1), (z=0, y=-1).\n",
    "n_train = len(y_train)\n",
    "group_masks = [\n",
    "    (z_train==1)&(y_train==1),\n",
    "    (z_train==0)&(y_train==1),\n",
    "    (z_train==1)&(y_train==-1),\n",
    "    (z_train==0)&(y_train==-1),\n",
    "]\n",
    "w = np.array([sum(mask)/n_train for mask in group_masks])\n",
    "\n",
    "corr = 0.18 # Target correlation passed to find_w_cvxpy\n",
    "alpha = 0.005 # Used in FairBatch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_epoch(model, train_features, labels, optimizer, criterion):\n",
    "    \"\"\"Trains the model with the given train data.\n",
    "\n",
    "    Args:\n",
    "        model: A torch model to train.\n",
    "        train_features: A torch tensor indicating the train features.\n",
    "        labels: A torch tensor indicating the true labels.\n",
    "        optimizer: A torch optimizer.\n",
    "        criterion: A torch criterion.\n",
    "\n",
    "    Returns:\n",
    "        loss value.\n",
    "    \"\"\"\n",
    "    \n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    label_predicted = model.forward(train_features)\n",
    "    loss  = criterion((F.tanh(label_predicted.squeeze())+1)/2, (labels.squeeze()+1)/2)\n",
    "    loss.backward()\n",
    "\n",
    "    optimizer.step()\n",
    "    \n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_w_cvxpy(w, corr, gamma1, gamma2):\n",
    "\n",
    "    \"\"\"Solves the SDP relaxation problem.\n",
    "\n",
    "    The new ratio vector x (same layout as w) is lifted to a symmetric PSD\n",
    "    matrix X of shape (5, 5): the last column of X stands for [x, 1] and the\n",
    "    remaining entries stand in for the pairwise products of x, so the\n",
    "    objective and every constraint are linear in X (written as trace(P @ X)).\n",
    "\n",
    "    Args:\n",
    "        w: A length-4 sequence indicating the original data ratio for each\n",
    "            (y, z)-class.\n",
    "        corr: A real number indicating the target correlation. Enforced as\n",
    "            X[0,3] - X[1,2] == corr * (X[0,1] + X[0,3] + X[1,2] + X[2,3]).\n",
    "        gamma1: A real number indicating the range of Pr(y) change. Enforced\n",
    "            as |2 * ((x[0] + x[1]) - (w[0] + w[1]))| <= gamma1.\n",
    "        gamma2: A real number indicating the range of Pr(z) change. Enforced\n",
    "            as |2 * ((x[0] + x[2]) - (w[0] + w[2]))| <= gamma2.\n",
    "\n",
    "    Returns:\n",
    "        solution for the optimization problem: a numpy array with len(w)\n",
    "        entries taken from the last column of the optimal X.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If w does not have exactly 4 entries, or if the solver\n",
    "            returns without a solution (e.g. the problem is infeasible).\n",
    "    \"\"\"\n",
    "\n",
    "    n = len(w)\n",
    "    if n != 4:\n",
    "        # The constraint matrices below are hard-coded for four classes.\n",
    "        raise ValueError('find_w_cvxpy expects exactly 4 class ratios, got {}'.format(n))\n",
    "    a, b, c, d = w\n",
    "\n",
    "    # Objective: trace(P0 @ X) = sum_i X[i,i] - 2 * (w . x), i.e. the lifted\n",
    "    # form of ||x - w||^2 up to the constant ||w||^2.\n",
    "    P0 = np.array([[1,0,0,0,-a],[0,1,0,0,-b],[0,0,1,0,-c],[0,0,0,1,-d],[-a,-b,-c,-d,0]])\n",
    "\n",
    "    # Correlation constraint (see `corr` in the docstring).\n",
    "    P1 = np.array([[0,-corr/2,0,(1-corr)/2,0],[-corr/2,0,(-1-corr)/2,0,0],[0,(-1-corr)/2,0,-corr/2,0],[(1-corr)/2,0,-corr/2,0,0],[0,0,0,0,0]])\n",
    "\n",
    "    # trace(P2 @ X) + r2 = 2 * ((x[0] + x[1]) - (a + b))\n",
    "    P2 = np.array([[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,0],[0,0,0,0,0],[1,1,0,0,0]])\n",
    "    r2 = -2*(a+b)\n",
    "\n",
    "    # trace(P3 @ X) + r3 = 2 * ((x[0] + x[2]) - (a + c))\n",
    "    P3 = np.array([[0,0,0,0,1],[0,0,0,0,0],[0,0,0,0,1],[0,0,0,0,0],[1,0,1,0,0]])\n",
    "    r3 = -2*(a+c)\n",
    "\n",
    "    # trace(P4 @ X) + r4 = 2 * (sum(x) - 1): the new ratios must sum to one.\n",
    "    P4 = np.array([[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[0,0,0,0,1],[1,1,1,1,0]])\n",
    "    r4 = -2*1\n",
    "\n",
    "    # trace(P5 @ X) = X[4,4], pinned to 1 so the last column reads as [x, 1].\n",
    "    P5 = np.array([[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,1]])\n",
    "\n",
    "    X = cp.Variable((n+1,n+1), symmetric=True)\n",
    "\n",
    "    constraints = [\n",
    "        cp.trace(P1 @ X) == 0,\n",
    "        cp.trace(P2 @ X) + r2 <= gamma1,\n",
    "        cp.trace(P2 @ X) + r2 >= -gamma1,\n",
    "        cp.trace(P3 @ X) + r3 <= gamma2,\n",
    "        cp.trace(P3 @ X) + r3 >= -gamma2,\n",
    "        cp.trace(P4 @ X) + r4 == 0,\n",
    "        cp.trace(P5 @ X) == 1,\n",
    "        X >> 0\n",
    "    ]\n",
    "    prob = cp.Problem(cp.Minimize(cp.trace(P0 @ X)),constraints)\n",
    "\n",
    "    prob.solve()\n",
    "\n",
    "    # X.value stays None when the solver finds no solution; fail with a clear\n",
    "    # message instead of an opaque TypeError on the indexing below.\n",
    "    if X.value is None:\n",
    "        raise ValueError('SDP relaxation could not be solved (status: {})'.format(prob.status))\n",
    "\n",
    "    # First n entries of the last column of X are the new class ratios.\n",
    "    return X.value[:-1, -1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Supporting a single metric (DP)\n",
    "### The results are in the experiments of the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.1 In-processing-only (FairBatch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "< Seed: 0 >\n",
      "  Test accuracy: 0.8209999799728394, Unfairness (DP): 0.048428571428571376\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 1 >\n",
      "  Test accuracy: 0.8209999799728394, Unfairness (DP): 0.048428571428571376\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 2 >\n",
      "  Test accuracy: 0.8209999799728394, Unfairness (DP): 0.048428571428571376\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 3 >\n",
      "  Test accuracy: 0.8199999928474426, Unfairness (DP): 0.047428571428571376\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 4 >\n",
      "  Test accuracy: 0.8209999799728394, Unfairness (DP): 0.048428571428571376\n",
      "----------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Experiment 1.1: for every seed, train a LogisticRegression model with the FairBatch\n",
    "# sampler (target_fairness = 'dp') on the original training set, then evaluate on the test set.\n",
    "train_type = 'in-processing-only'\n",
    "\n",
    "# full_tests collects one test_model() result per seed; full_trains is not filled in this cell.\n",
    "full_tests = []\n",
    "full_trains = []\n",
    "\n",
    "# NOTE: w_new / our_weights are only read in the `else` branch of the data setup below,\n",
    "# which is not taken while train_type == 'in-processing-only'.\n",
    "\"\"\" Find new data ratio for each (y, z)-class \"\"\"  \n",
    "w_new = find_w_cvxpy(w, corr, 0.1, 0.1)\n",
    "\n",
    "\"\"\" Find example weights according to the new weight \"\"\"  \n",
    "our_weights = correlation_reweighting(xz_train, y_train, z_train, w, w_new)\n",
    "\n",
    "\"\"\" Train models \"\"\"\n",
    "for seed in seeds:\n",
    "\n",
    "    print(\"< Seed: {} >\".format(seed))\n",
    "\n",
    "    # ---------------------\n",
    "    #  Initialize model, optimizer, and criterion\n",
    "    # ---------------------\n",
    "\n",
    "    # Only the model is moved to the GPU when useCuda is True; the data tensors are not moved in this cell.\n",
    "    useCuda = False\n",
    "    if useCuda:\n",
    "        model = LogisticRegression(3,1).cuda()\n",
    "    else:\n",
    "        model = LogisticRegression(3,1)\n",
    "\n",
    "    # The seed is set after the model is constructed, so it only governs torch RNG use from here on\n",
    "    # (presumably weights_init_normal re-initialises the parameters under this seed; confirm in models.py).\n",
    "    torch.manual_seed(seed)\n",
    "    model.apply(weights_init_normal)\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999))\n",
    "    criterion = torch.nn.BCELoss()\n",
    "\n",
    "    # Mean batch loss of each epoch; collected but not read again in this cell.\n",
    "    losses = []\n",
    "    \n",
    "    \n",
    "    # ---------------------\n",
    "    #  Set data and batch sampler\n",
    "    # ---------------------\n",
    "    \n",
    "    if train_type == 'in-processing-only':\n",
    "        train_data = CustomDataset(xz_train, y_train, z_train)\n",
    "    else:\n",
    "        new_index = datasampling(xz_train, y_train, z_train, our_weights, seed = seed)\n",
    "        train_data = CustomDataset(xz_train[new_index], y_train[new_index], z_train[new_index])\n",
    "\n",
    "    # FairBatch is handed to the DataLoader as `sampler` -- presumably it yields whole batches of\n",
    "    # indices, so each DataLoader item is one batch; confirm in FairBatchSampler_Multiple.\n",
    "    sampler = FairBatch (model, train_data.x, train_data.y, train_data.z, batch_size = 100, alpha = alpha, target_fairness = 'dp', replacement = False, seed = seed)\n",
    "    train_loader = torch.utils.data.DataLoader (train_data, sampler=sampler, num_workers=0)\n",
    "\n",
    "    \n",
    "    # ---------------------\n",
    "    #  Model training\n",
    "    # ---------------------\n",
    "\n",
    "    for epoch in range(500):\n",
    "\n",
    "        tmp_loss = []\n",
    "\n",
    "        # NOTE: despite its name, run_epoch performs a single optimizer step on the batch it receives.\n",
    "        for batch_idx, (data, target, z) in enumerate (train_loader):\n",
    "            loss = run_epoch (model, data, target, optimizer, criterion)\n",
    "            tmp_loss.append(loss)\n",
    "\n",
    "        losses.append(sum(tmp_loss)/len(tmp_loss))\n",
    "    \n",
    "    # Evaluate the trained model on the held-out test tensors and keep the result for the summary cell.\n",
    "    tmp_test = test_model(model, xz_test, y_test, z_test)\n",
    "    full_tests.append(tmp_test)\n",
    "    \n",
    "    print(\"  Test accuracy: {}, Unfairness (DP): {}\".format(tmp_test['Acc'], tmp_test['DP_diff']))\n",
    "    print(\"----------------------------------------------------------------------\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy (avg): 0.82079998254776\n",
      "Unfairness (DP) (avg): 0.04822857142857138\n"
     ]
    }
   ],
   "source": [
    "# Average the per-seed test metrics gathered in full_tests by the previous cell.\n",
    "tmp_acc = [full_tests[k]['Acc'] for k in range(len(seeds))]\n",
    "tmp_unfair = [full_tests[k]['DP_diff'] for k in range(len(seeds))]\n",
    "\n",
    "mean_acc = np.mean(tmp_acc)\n",
    "mean_unfair = np.mean(tmp_unfair)\n",
    "print(\"Test accuracy (avg): {}\".format(mean_acc))\n",
    "print(\"Unfairness (DP) (avg): {}\".format(mean_unfair))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.2 Ours + In-processing (FairBatch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "< Seed: 0 >\n",
      "  Test accuracy: 0.8360000252723694, Unfairness (DP): 0.0034285714285714475\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 1 >\n",
      "  Test accuracy: 0.8349999785423279, Unfairness (DP): 0.0024285714285714466\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 2 >\n",
      "  Test accuracy: 0.8370000123977661, Unfairness (DP): 0.0008095238095238155\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 3 >\n",
      "  Test accuracy: 0.8360000252723694, Unfairness (DP): 0.0034285714285714475\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 4 >\n",
      "  Test accuracy: 0.8360000252723694, Unfairness (DP): 0.004190476190476189\n",
      "----------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Experiment 1.2: for every seed, resample the training set with the weights derived from the\n",
    "# SDP solution, train a LogisticRegression model with the FairBatch sampler\n",
    "# (target_fairness = 'dp') on the resampled data, then evaluate on the test set.\n",
    "train_type = 'ours'\n",
    "\n",
    "# full_tests collects one test_model() result per seed; full_trains is not filled in this cell.\n",
    "full_tests = []\n",
    "full_trains = []\n",
    "\n",
    "\"\"\" Find new data ratio for each (y, z)-class \"\"\"  \n",
    "w_new = find_w_cvxpy(w, corr, 0.1, 0.1)\n",
    "\n",
    "\"\"\" Find example weights according to the new weight \"\"\"  \n",
    "our_weights = correlation_reweighting(xz_train, y_train, z_train, w, w_new)\n",
    "\n",
    "\"\"\" Train models \"\"\"\n",
    "for seed in seeds:\n",
    "\n",
    "    print(\"< Seed: {} >\".format(seed))\n",
    "\n",
    "    # ---------------------\n",
    "    #  Initialize model, optimizer, and criterion\n",
    "    # ---------------------\n",
    "\n",
    "    # Only the model is moved to the GPU when useCuda is True; the data tensors are not moved in this cell.\n",
    "    useCuda = False\n",
    "    if useCuda:\n",
    "        model = LogisticRegression(3,1).cuda()\n",
    "    else:\n",
    "        model = LogisticRegression(3,1)\n",
    "\n",
    "    # The seed is set after the model is constructed, so it only governs torch RNG use from here on\n",
    "    # (presumably weights_init_normal re-initialises the parameters under this seed; confirm in models.py).\n",
    "    torch.manual_seed(seed)\n",
    "    model.apply(weights_init_normal)\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999))\n",
    "    criterion = torch.nn.BCELoss()\n",
    "\n",
    "    # Mean batch loss of each epoch; collected but not read again in this cell.\n",
    "    losses = []\n",
    "    \n",
    "    \n",
    "    # ---------------------\n",
    "    #  Set data and batch sampler\n",
    "    # ---------------------\n",
    "    \n",
    "    # With train_type == 'ours' the else branch runs: datasampling() returns indices (seeded per run)\n",
    "    # that are used to index the training tensors before they are wrapped in CustomDataset.\n",
    "    if train_type == 'in-processing-only':\n",
    "        train_data = CustomDataset(xz_train, y_train, z_train)\n",
    "    else:\n",
    "        new_index = datasampling(xz_train, y_train, z_train, our_weights, seed = seed)\n",
    "        train_data = CustomDataset(xz_train[new_index], y_train[new_index], z_train[new_index])\n",
    "\n",
    "    # FairBatch is handed to the DataLoader as `sampler` -- presumably it yields whole batches of\n",
    "    # indices, so each DataLoader item is one batch; confirm in FairBatchSampler_Multiple.\n",
    "    sampler = FairBatch (model, train_data.x, train_data.y, train_data.z, batch_size = 100, alpha = alpha, target_fairness = 'dp', replacement = False, seed = seed)\n",
    "    train_loader = torch.utils.data.DataLoader (train_data, sampler=sampler, num_workers=0)\n",
    "\n",
    "    \n",
    "    # ---------------------\n",
    "    #  Model training\n",
    "    # ---------------------\n",
    "\n",
    "    for epoch in range(500):\n",
    "\n",
    "        tmp_loss = []\n",
    "\n",
    "        # NOTE: despite its name, run_epoch performs a single optimizer step on the batch it receives.\n",
    "        for batch_idx, (data, target, z) in enumerate (train_loader):\n",
    "            loss = run_epoch (model, data, target, optimizer, criterion)\n",
    "            tmp_loss.append(loss)\n",
    "\n",
    "        losses.append(sum(tmp_loss)/len(tmp_loss))\n",
    "    \n",
    "    # Evaluate the trained model on the held-out test tensors and keep the result for the summary cell.\n",
    "    tmp_test = test_model(model, xz_test, y_test, z_test)\n",
    "    full_tests.append(tmp_test)\n",
    "    \n",
    "    print(\"  Test accuracy: {}, Unfairness (DP): {}\".format(tmp_test['Acc'], tmp_test['DP_diff']))\n",
    "    print(\"----------------------------------------------------------------------\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy (avg): 0.8360000133514405\n",
      "Unfairness (DP) (avg): 0.0028571428571428693\n"
     ]
    }
   ],
   "source": [
    "# Average the per-seed test metrics gathered in full_tests by the previous cell.\n",
    "tmp_acc = [full_tests[k]['Acc'] for k in range(len(seeds))]\n",
    "tmp_unfair = [full_tests[k]['DP_diff'] for k in range(len(seeds))]\n",
    "\n",
    "mean_acc = np.mean(tmp_acc)\n",
    "mean_unfair = np.mean(tmp_unfair)\n",
    "print(\"Test accuracy (avg): {}\".format(mean_acc))\n",
    "print(\"Unfairness (DP) (avg): {}\".format(mean_unfair))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Supporting multiple metrics (DP & EO)\n",
    "### The results are in the experiments of the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.1 In-processing-only (FairBatch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "< Seed: 0 >\n",
      "  Test accuracy: 0.8479999899864197, Unfairness (DP & EO): 0.09999999999999999\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 1 >\n",
      "  Test accuracy: 0.8500000238418579, Unfairness (DP & EO): 0.08823529411764706\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 2 >\n",
      "  Test accuracy: 0.8489999771118164, Unfairness (DP & EO): 0.09215686274509803\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 3 >\n",
      "  Test accuracy: 0.8479999899864197, Unfairness (DP & EO): 0.08823529411764706\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 4 >\n",
      "  Test accuracy: 0.8479999899864197, Unfairness (DP & EO): 0.08823529411764706\n",
      "----------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Experiment 2.1: for every seed, train a LogisticRegression model with the FairBatch\n",
    "# sampler (target_fairness = 'dpeo', knob = 3) on the original training set, then evaluate\n",
    "# on the test set. The reported unfairness is the larger of DP_diff and EqOdds_diff.\n",
    "train_type = 'in-processing-only'\n",
    "\n",
    "# full_tests collects one test_model() result per seed; full_trains is not filled in this cell.\n",
    "full_tests = []\n",
    "full_trains = []\n",
    "\n",
    "# NOTE: w_new / our_weights are only read in the `else` branch of the data setup below,\n",
    "# which is not taken while train_type == 'in-processing-only'.\n",
    "\"\"\" Find new data ratio for each (y, z)-class \"\"\"  \n",
    "w_new = find_w_cvxpy(w, corr, 0.1, 0.1)\n",
    "\n",
    "\"\"\" Find example weights according to the new weight \"\"\"  \n",
    "our_weights = correlation_reweighting(xz_train, y_train, z_train, w, w_new)\n",
    "\n",
    "\"\"\" Train models \"\"\"\n",
    "for seed in seeds:\n",
    "\n",
    "    print(\"< Seed: {} >\".format(seed))\n",
    "\n",
    "    # ---------------------\n",
    "    #  Initialize model, optimizer, and criterion\n",
    "    # ---------------------\n",
    "\n",
    "    # Only the model is moved to the GPU when useCuda is True; the data tensors are not moved in this cell.\n",
    "    useCuda = False\n",
    "    if useCuda:\n",
    "        model = LogisticRegression(3,1).cuda()\n",
    "    else:\n",
    "        model = LogisticRegression(3,1)\n",
    "\n",
    "    # The seed is set after the model is constructed, so it only governs torch RNG use from here on\n",
    "    # (presumably weights_init_normal re-initialises the parameters under this seed; confirm in models.py).\n",
    "    torch.manual_seed(seed)\n",
    "    model.apply(weights_init_normal)\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999))\n",
    "    criterion = torch.nn.BCELoss()\n",
    "\n",
    "    # Mean batch loss of each epoch; collected but not read again in this cell.\n",
    "    losses = []\n",
    "    \n",
    "    \n",
    "    # ---------------------\n",
    "    #  Set data and batch sampler\n",
    "    # ---------------------\n",
    "\n",
    "    if train_type == 'in-processing-only':\n",
    "        train_data = CustomDataset(xz_train, y_train, z_train)\n",
    "    else:\n",
    "        new_index = datasampling(xz_train, y_train, z_train, our_weights, seed = seed)\n",
    "        train_data = CustomDataset(xz_train[new_index], y_train[new_index], z_train[new_index])\n",
    "    \n",
    "    # FairBatch is handed to the DataLoader as `sampler` -- presumably it yields whole batches of\n",
    "    # indices, so each DataLoader item is one batch; confirm in FairBatchSampler_Multiple.\n",
    "    # NOTE(review): the meaning of `knob` is defined by FairBatch and is not visible here.\n",
    "    sampler = FairBatch (model, train_data.x, train_data.y, train_data.z, batch_size = 100, alpha = alpha, target_fairness = 'dpeo', knob = 3, replacement = False, seed = seed)\n",
    "    train_loader = torch.utils.data.DataLoader (train_data, sampler=sampler, num_workers=0)\n",
    "\n",
    "    \n",
    "    # ---------------------\n",
    "    #  Model training\n",
    "    # ---------------------\n",
    "\n",
    "    for epoch in range(500):\n",
    "\n",
    "        tmp_loss = []\n",
    "\n",
    "        # NOTE: despite its name, run_epoch performs a single optimizer step on the batch it receives.\n",
    "        for batch_idx, (data, target, z) in enumerate (train_loader):\n",
    "            loss = run_epoch (model, data, target, optimizer, criterion)\n",
    "            tmp_loss.append(loss)\n",
    "\n",
    "        losses.append(sum(tmp_loss)/len(tmp_loss))\n",
    "    \n",
    "    # Evaluate the trained model on the held-out test tensors and keep the result for the summary cell.\n",
    "    tmp_test = test_model(model, xz_test, y_test, z_test)\n",
    "    full_tests.append(tmp_test)\n",
    "    \n",
    "    print(\"  Test accuracy: {}, Unfairness (DP & EO): {}\".format(tmp_test['Acc'], max(tmp_test['DP_diff'], tmp_test['EqOdds_diff'])))\n",
    "    print(\"----------------------------------------------------------------------\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy (avg): 0.8485999941825867\n",
      "Unfairness (DP & EO) (avg): 0.09137254901960785\n"
     ]
    }
   ],
   "source": [
    "# Average the per-seed test metrics gathered in full_tests by the previous cell.\n",
    "# Unfairness per seed is the larger of the DP and EO gaps.\n",
    "tmp_acc = [full_tests[k]['Acc'] for k in range(len(seeds))]\n",
    "tmp_unfair = [max(full_tests[k]['DP_diff'], full_tests[k]['EqOdds_diff']) for k in range(len(seeds))]\n",
    "\n",
    "mean_acc = np.mean(tmp_acc)\n",
    "mean_unfair = np.mean(tmp_unfair)\n",
    "print(\"Test accuracy (avg): {}\".format(mean_acc))\n",
    "print(\"Unfairness (DP & EO) (avg): {}\".format(mean_unfair))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 Ours + In-processing (FairBatch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "< Seed: 0 >\n",
      "  Test accuracy: 0.8600000143051147, Unfairness (DP & EO): 0.05686274509803921\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 1 >\n",
      "  Test accuracy: 0.8500000238418579, Unfairness (DP & EO): 0.0586666666666667\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 2 >\n",
      "  Test accuracy: 0.8489999771118164, Unfairness (DP & EO): 0.05728571428571422\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 3 >\n",
      "  Test accuracy: 0.8500000238418579, Unfairness (DP & EO): 0.0586666666666667\n",
      "----------------------------------------------------------------------\n",
      "< Seed: 4 >\n",
      "  Test accuracy: 0.8489999771118164, Unfairness (DP & EO): 0.05728571428571422\n",
      "----------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Experiment 2.2: for every seed, resample the training set with the weights derived from the\n",
    "# SDP solution, train a LogisticRegression model with the FairBatch sampler\n",
    "# (target_fairness = 'dpeo', knob = 10) on the resampled data, then evaluate on the test set.\n",
    "# The reported unfairness is the larger of DP_diff and EqOdds_diff.\n",
    "train_type = 'ours'\n",
    "\n",
    "# full_tests collects one test_model() result per seed; full_trains is not filled in this cell.\n",
    "full_tests = []\n",
    "full_trains = []\n",
    "\n",
    "\"\"\" Find new data ratio for each (y, z)-class \"\"\"  \n",
    "w_new = find_w_cvxpy(w, corr, 0.1, 0.1)\n",
    "\n",
    "\"\"\" Find example weights according to the new weight \"\"\"  \n",
    "our_weights = correlation_reweighting(xz_train, y_train, z_train, w, w_new)\n",
    "\n",
    "\"\"\" Train models \"\"\"\n",
    "for seed in seeds:\n",
    "\n",
    "    print(\"< Seed: {} >\".format(seed))\n",
    "\n",
    "    # ---------------------\n",
    "    #  Initialize model, optimizer, and criterion\n",
    "    # ---------------------\n",
    "\n",
    "    # Only the model is moved to the GPU when useCuda is True; the data tensors are not moved in this cell.\n",
    "    useCuda = False\n",
    "    if useCuda:\n",
    "        model = LogisticRegression(3,1).cuda()\n",
    "    else:\n",
    "        model = LogisticRegression(3,1)\n",
    "\n",
    "    # The seed is set after the model is constructed, so it only governs torch RNG use from here on\n",
    "    # (presumably weights_init_normal re-initialises the parameters under this seed; confirm in models.py).\n",
    "    torch.manual_seed(seed)\n",
    "    model.apply(weights_init_normal)\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999))\n",
    "    criterion = torch.nn.BCELoss()\n",
    "\n",
    "    # Mean batch loss of each epoch; collected but not read again in this cell.\n",
    "    losses = []\n",
    "    \n",
    "    \n",
    "    # ---------------------\n",
    "    #  Set data and batch sampler\n",
    "    # ---------------------\n",
    "\n",
    "    # With train_type == 'ours' the else branch runs: datasampling() returns indices (seeded per run)\n",
    "    # that are used to index the training tensors before they are wrapped in CustomDataset.\n",
    "    if train_type == 'in-processing-only':\n",
    "        train_data = CustomDataset(xz_train, y_train, z_train)\n",
    "    else:\n",
    "        new_index = datasampling(xz_train, y_train, z_train, our_weights, seed = seed)\n",
    "        train_data = CustomDataset(xz_train[new_index], y_train[new_index], z_train[new_index])\n",
    "    \n",
    "    # FairBatch is handed to the DataLoader as `sampler` -- presumably it yields whole batches of\n",
    "    # indices, so each DataLoader item is one batch; confirm in FairBatchSampler_Multiple.\n",
    "    # NOTE(review): the meaning of `knob` is defined by FairBatch and is not visible here.\n",
    "    sampler = FairBatch (model, train_data.x, train_data.y, train_data.z, batch_size = 100, alpha = alpha, target_fairness = 'dpeo', knob = 10, replacement = False, seed = seed)\n",
    "    train_loader = torch.utils.data.DataLoader (train_data, sampler=sampler, num_workers=0)\n",
    "\n",
    "    \n",
    "    # ---------------------\n",
    "    #  Model training\n",
    "    # ---------------------\n",
    "\n",
    "    for epoch in range(500):\n",
    "\n",
    "        tmp_loss = []\n",
    "\n",
    "        # NOTE: despite its name, run_epoch performs a single optimizer step on the batch it receives.\n",
    "        for batch_idx, (data, target, z) in enumerate (train_loader):\n",
    "            loss = run_epoch (model, data, target, optimizer, criterion)\n",
    "            tmp_loss.append(loss)\n",
    "\n",
    "        losses.append(sum(tmp_loss)/len(tmp_loss))\n",
    "    \n",
    "    # Evaluate the trained model on the held-out test tensors and keep the result for the summary cell.\n",
    "    tmp_test = test_model(model, xz_test, y_test, z_test)\n",
    "    full_tests.append(tmp_test)\n",
    "    \n",
    "    print(\"  Test accuracy: {}, Unfairness (DP & EO): {}\".format(tmp_test['Acc'], max(tmp_test['DP_diff'], tmp_test['EqOdds_diff'])))\n",
    "    print(\"----------------------------------------------------------------------\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy (avg): 0.8516000032424926\n",
      "Unfairness (DP & EO) (avg): 0.057753501400560216\n"
     ]
    }
   ],
   "source": [
    "# Average the per-seed test metrics gathered in full_tests by the previous cell.\n",
    "# Unfairness per seed is the larger of the DP and EO gaps.\n",
    "tmp_acc = [full_tests[k]['Acc'] for k in range(len(seeds))]\n",
    "tmp_unfair = [max(full_tests[k]['DP_diff'], full_tests[k]['EqOdds_diff']) for k in range(len(seeds))]\n",
    "\n",
    "mean_acc = np.mean(tmp_acc)\n",
    "mean_unfair = np.mean(tmp_unfair)\n",
    "print(\"Test accuracy (avg): {}\".format(mean_acc))\n",
    "print(\"Unfairness (DP & EO) (avg): {}\".format(mean_unfair))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
