{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# An example: running BuDRO on the German credit data set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy as sp\n",
    "import scipy.spatial.distance  # load the submodule explicitly so sp.spatial.distance is available\n",
    "import tensorflow as tf\n",
    "import xgboost as xgb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## load the BuDRO scripts\n",
    "import sys\n",
    "sys.path.append('../scripts')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import BuDRO scripts.  `german/german_proc.py` contains functions for loading and processing the German credit data set.  `scripts/script.py` contains various helper functions for collecting statistics.  `scripts/tf_xgboost_log.py` contains all of the boosting functions for BuDRO (primal vs dual, SGD vs without SGD)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import german_proc\n",
    "import script\n",
    "import tf_xgboost_log as txl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide defaults.\n",
    "_verbose = False  # verbosity flag\n",
    "_dtype = tf.float32  # dtype used when building the TensorFlow constant C"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load seeds and get a train/test split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The seeds are stored under the 'seeds' key of the .npz archive.\n",
    "# np.load returns a lazily-read NpzFile for .npz archives; the context\n",
    "# manager closes the underlying file handle once the array has been read.\n",
    "with np.load('german-seeds.npz') as seed_archive:\n",
    "    seeds = seed_archive['seeds']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = seeds[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "orig_dataset = german_proc.get_german_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2, 3, 4, 5, 6]\n"
     ]
    }
   ],
   "source": [
    "# Split the data set; the helper returns nine values, unpacked below.\n",
    "split_parts = german_proc.get_german_train_test_age(\n",
    "    orig_dataset,\n",
    "    pct=0.8,             # fraction of the data placed in the training set\n",
    "    removeProt=False,    # False keeps the protected features in the data\n",
    "    seed=seed,\n",
    ")\n",
    "\n",
    "X_train, X_test, y_train, y_test = split_parts[:4]\n",
    "y_age_train, y_age_test = split_parts[4:6]\n",
    "feature_names, traininds, testinds = split_parts[6:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Indices of protected features and status features for running consistency evaluations.  Also create a binary sex feature (males are 1, females are 0) from the personal status columns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Personal status indices: [37, 38, 39, 40]\n"
     ]
    }
   ],
   "source": [
    "# Column positions of the 'age' and 'age_bin' features.\n",
    "aind = feature_names.index('age')\n",
    "abind = feature_names.index('age_bin')\n",
    "\n",
    "# Every column whose name mentions personal_status.\n",
    "status_inds = [\n",
    "    col for col, name in enumerate(feature_names) if \"personal_status\" in name\n",
    "]\n",
    "\n",
    "# Status columns whose code is neither A92 nor A95 are accumulated into the\n",
    "# sex labels, so males are 1 and females are 0 in the following.\n",
    "male_inds = [\n",
    "    col for col in status_inds\n",
    "    if not ('A92' in feature_names[col] or 'A95' in feature_names[col])\n",
    "]\n",
    "\n",
    "y_sex_bin = np.zeros(X_train.shape[0])\n",
    "yt_sex = np.zeros(X_test.shape[0])\n",
    "for col in male_inds:\n",
    "    y_sex_bin += X_train[:, col]\n",
    "    yt_sex += X_test[:, col]\n",
    "\n",
    "print(\"Personal status indices: {}\".format(status_inds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'personal_status=A91'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Name of the first personal_status column; the index comes from\n",
    "# status_inds rather than a hard-coded position, so it stays correct if\n",
    "# the feature ordering changes.\n",
    "feature_names[status_inds[0]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copy out the binary age feature and zero out its column in the data (the column itself is kept, so feature indices do not shift).  See the note about agebin in the `00_README.txt` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep independent copies of the binary age column for train and test.\n",
    "# NOTE: this cell is not idempotent -- re-running it after the column has\n",
    "# been zeroed would store all-zero labels.\n",
    "y_age_bin = np.array(X_train[:, abind])\n",
    "yt_age = np.array(X_test[:, abind])\n",
    "\n",
    "# Blank the column in place in both feature matrices.\n",
    "X_train[:, abind], X_test[:, abind] = 0, 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the BuDRO parameters\n",
    "Distance matrix, etc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape of projection matrix: (62, 62)\n",
      "Projection error: 1.1102230246251565e-16\n"
     ]
    }
   ],
   "source": [
    "# first project the data onto the sensitive subspace, then calculate the pairwise distances\n",
    "RCV = german_proc.train_ridge(X_train, y_age_train, aind=aind, abind=None)\n",
    "proj = german_proc.german_proj_mat(RCV, aind)\n",
    "projData = np.matmul(X_train, proj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Condensed vector of squared Euclidean distances between all pairs of\n",
    "# projected training points, then expanded into a full square matrix.\n",
    "pairwise_sq = sp.spatial.distance.pdist(projData, metric='sqeuclidean')\n",
    "Corig = sp.spatial.distance.squareform(pairwise_sq)\n",
    "\n",
    "# TensorFlow constant holding the same matrix in the notebook's default dtype.\n",
    "C = tf.constant(Corig, dtype=_dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of rows of the pairwise distance matrix.\n",
    "n = len(Corig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# XGBoost data objects and tree parameters; the last two return values\n",
    "# are discarded.\n",
    "(dtrain, dtest, watchlist, param, _, _) = german_proc.prep_baseline_xgb(\n",
    "    X_train, X_test, y_train, y_test\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tweak the values in param to change the input hyperparameters.\n",
    "\n",
    "Generally keep `scale_pos_weight` untouched, and don't touch the objective.  Otherwise, I've found good results by keeping the tree pretty short and regularizing quite a bit (`lambda` on the order of 1 - 1000, fairly large `min_child_weight` 0.1-10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'max_depth': 5,\n",
       " 'eta': 0.05,\n",
       " 'objective': 'binary:logistic',\n",
       " 'min_child_weight': 0.5,\n",
       " 'lambda': 0.0001,\n",
       " 'scale_pos_weight': 2.3333333333333335}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run BuDRO to reproduce one of the hand-tuning runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Override the baseline tree hyperparameters for this hand-tuned run.\n",
    "param.update({\n",
    "    'max_depth': 10,\n",
    "    'eta': 0.005,\n",
    "    'min_child_weight': 0.1/n,\n",
    "    'lambda': 1.0,\n",
    "})\n",
    "\n",
    "# Presumably the number of boosting iterations and the perturbation budget\n",
    "# passed to the boosting call -- confirm in scripts/tf_xgboost_log.py.\n",
    "n_iter, eps = 500, 1.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Boosting functions:\n",
    "\n",
    "* `txl.boost_dual` uses dual without sgd\n",
    "* `txl.boost_dual_sgd` uses dual with sgd\n",
    "* `txl.boost_sinkhorn_tf` uses entropic regularization without sgd\n",
    "* `txl.boost_sinkhorn_2gpu` uses entropic regularization without sgd explicitly on 2 gpus.  This code is not portable and should be avoided if possible. \n",
    "* `txl.boost_sinkhorn_sgd_2gpu` uses entropic regularization with sgd and expects 2 gpus. \n",
    "\n",
    "Please see the file `scripts/tf_xgboost_log.py` for the function arguments and implementations.\n",
    "\n",
    "In the output of the following cell, we are minimizing the inner max, so that quantity should trend downward across iterations (occasional small increases are normal).  If it does not, adjust the hyperparameters some more."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 0: Final value of inner max: 0.6914510862529278\n",
      "Iter 1: Final value of inner max: 0.6913871238729493\n",
      "Iter 2: Final value of inner max: 0.6913456810265779\n",
      "Iter 3: Final value of inner max: 0.6913054555654525\n",
      "Iter 4: Final value of inner max: 0.6912731740868736\n",
      "Iter 5: Final value of inner max: 0.6912468572827626\n",
      "Iter 6: Final value of inner max: 0.6912159740139269\n",
      "Iter 7: Final value of inner max: 0.6911799963616841\n",
      "Iter 8: Final value of inner max: 0.69113916606486\n",
      "Iter 9: Final value of inner max: 0.6911111467331648\n",
      "Iter 10: Final value of inner max: 0.6910489483177662\n",
      "Iter 11: Final value of inner max: 0.6910253007802201\n",
      "Iter 12: Final value of inner max: 0.6909951419532356\n",
      "Iter 13: Final value of inner max: 0.6909613249450922\n",
      "Iter 14: Final value of inner max: 0.6909336142987013\n",
      "Iter 15: Final value of inner max: 0.6908779411017896\n",
      "Iter 16: Final value of inner max: 0.6908568576371931\n",
      "Iter 17: Final value of inner max: 0.6908369352668524\n",
      "Iter 18: Final value of inner max: 0.690806607176782\n",
      "Iter 19: Final value of inner max: 0.6907581660027127\n",
      "Iter 20: Final value of inner max: 0.6907347445189953\n",
      "Iter 21: Final value of inner max: 0.6907067283269777\n",
      "Iter 22: Final value of inner max: 0.6906810290366412\n",
      "Iter 23: Final value of inner max: 0.6906392924533699\n",
      "Iter 24: Final value of inner max: 0.6906182212382554\n",
      "Iter 25: Final value of inner max: 0.6905949907283464\n",
      "Iter 26: Final value of inner max: 0.6905461509226506\n",
      "Iter 27: Final value of inner max: 0.6905327262729407\n",
      "Iter 28: Final value of inner max: 0.6905134259164334\n",
      "Iter 29: Final value of inner max: 0.6904895022044228\n",
      "Iter 30: Final value of inner max: 0.6904310883578688\n",
      "Iter 31: Final value of inner max: 0.6903785175787445\n",
      "Iter 32: Final value of inner max: 0.6903567477227772\n",
      "Iter 33: Final value of inner max: 0.690336673359315\n",
      "Iter 34: Final value of inner max: 0.6903019496057107\n",
      "Iter 35: Final value of inner max: 0.6902885508537293\n",
      "Iter 36: Final value of inner max: 0.6902451783258247\n",
      "Iter 37: Final value of inner max: 0.6901899235695601\n",
      "Iter 38: Final value of inner max: 0.6901661171654221\n",
      "Iter 39: Final value of inner max: 0.6901357239484787\n",
      "Iter 40: Final value of inner max: 0.6901055033619073\n",
      "Iter 41: Final value of inner max: 0.6900557160377503\n",
      "Iter 42: Final value of inner max: 0.6900159494646053\n",
      "Iter 43: Final value of inner max: 0.6899705209848106\n",
      "Iter 44: Final value of inner max: 0.6899284569174051\n",
      "Iter 45: Final value of inner max: 0.6899017240852118\n",
      "Iter 46: Final value of inner max: 0.6898732610411064\n",
      "Iter 47: Final value of inner max: 0.6898559067398309\n",
      "Iter 48: Final value of inner max: 0.6898150096088648\n",
      "Iter 49: Final value of inner max: 0.6897637256046044\n",
      "Iter 50: Final value of inner max: 0.6897312454879284\n",
      "Iter 51: Final value of inner max: 0.6896942621066781\n",
      "Iter 52: Final value of inner max: 0.6896543526757659\n",
      "Iter 53: Final value of inner max: 0.6896468444168569\n",
      "Iter 54: Final value of inner max: 0.6895998660475016\n",
      "Iter 55: Final value of inner max: 0.6895953169465064\n",
      "Iter 56: Final value of inner max: 0.6895818493573642\n",
      "Iter 57: Final value of inner max: 0.6895553481159375\n",
      "Iter 58: Final value of inner max: 0.6895321437002035\n",
      "Iter 59: Final value of inner max: 0.6895133732906241\n",
      "Iter 60: Final value of inner max: 0.6894671309221068\n",
      "Iter 61: Final value of inner max: 0.6894501656963099\n",
      "Iter 62: Final value of inner max: 0.6894353747188837\n",
      "Iter 63: Final value of inner max: 0.6893851976841688\n",
      "Iter 64: Final value of inner max: 0.6893647124810307\n",
      "Iter 65: Final value of inner max: 0.6893357610702515\n",
      "Iter 66: Final value of inner max: 0.6892926111072302\n",
      "Iter 67: Final value of inner max: 0.6892588570596543\n",
      "Iter 68: Final value of inner max: 0.6892370022561559\n",
      "Iter 69: Final value of inner max: 0.6892204059522482\n",
      "Iter 70: Final value of inner max: 0.6891857254505158\n",
      "Iter 71: Final value of inner max: 0.6891635348647833\n",
      "Iter 72: Final value of inner max: 0.6891256213760033\n",
      "Iter 73: Final value of inner max: 0.6891113719778773\n",
      "Iter 74: Final value of inner max: 0.689090060271289\n",
      "Iter 75: Final value of inner max: 0.6890595487505198\n",
      "Iter 76: Warning: Multiple points have more than one non-trivial transport location.\n",
      "Warning: All but first point ignored\n",
      "Warning: Points are [606 701]\n",
      "Warning: Maximum distance between these points: 16.217430114746094\n",
      "Final value of inner max: 0.6890219787646106\n",
      "Iter 77: Final value of inner max: 0.6890075595672052\n",
      "Iter 78: Final value of inner max: 0.6889829142391681\n",
      "Iter 79: Final value of inner max: 0.6889697460830212\n",
      "Iter 80: Final value of inner max: 0.688925602562136\n",
      "Iter 81: Final value of inner max: 0.6889135755598546\n",
      "Iter 82: Final value of inner max: 0.6888800268658828\n",
      "Iter 83: Final value of inner max: 0.6888510210067034\n",
      "Iter 84: Final value of inner max: 0.6888372925807038\n",
      "Iter 85: Warning: Multiple points have more than one non-trivial transport location.\n",
      "Warning: All but first point ignored\n",
      "Warning: Points are [ 57 264]\n",
      "Warning: Maximum distance between these points: 16.307392120361328\n",
      "Final value of inner max: 0.6887899889478213\n",
      "Iter 86: Final value of inner max: 0.6887585151940584\n",
      "Iter 87: Final value of inner max: 0.6887434498038504\n",
      "Iter 88: Final value of inner max: 0.6887071213126182\n",
      "Iter 89: Final value of inner max: 0.6886937272130427\n",
      "Iter 90: Final value of inner max: 0.6886650834568979\n",
      "Iter 91: Final value of inner max: 0.6886410814685412\n",
      "Iter 92: Final value of inner max: 0.6886035730222273\n",
      "Iter 93: Final value of inner max: 0.6885793164204728\n",
      "Iter 94: Final value of inner max: 0.6885519608110189\n",
      "Iter 95: Final value of inner max: 0.6885423645940346\n",
      "Iter 96: Final value of inner max: 0.6884975349903106\n",
      "Iter 97: Final value of inner max: 0.6884594261277901\n",
      "Iter 98: Final value of inner max: 0.6884373930937698\n",
      "Iter 99: Final value of inner max: 0.688416554875779\n",
      "Iter 100: Final value of inner max: 0.6883851503580809\n",
      "Iter 101: Final value of inner max: 0.6883667521272874\n",
      "Iter 102: Final value of inner max: 0.6883409694336637\n",
      "Iter 103: Final value of inner max: 0.6883227635331222\n",
      "Iter 104: Final value of inner max: 0.6883096025949841\n",
      "Iter 105: Final value of inner max: 0.6882823844757298\n",
      "Iter 106: Final value of inner max: 0.6882416411722843\n",
      "Iter 107: Final value of inner max: 0.6882253430038692\n",
      "Iter 108: Final value of inner max: 0.6881867523728485\n",
      "Iter 109: Final value of inner max: 0.6881662512452509\n",
      "Iter 110: Final value of inner max: 0.6881518602638994\n",
      "Iter 111: Final value of inner max: 0.6881050707399845\n",
      "Iter 112: Final value of inner max: 0.6880890866369009\n",
      "Iter 113: Final value of inner max: 0.6880583137117122\n",
      "Iter 114: Final value of inner max: 0.6880233937374833\n",
      "Iter 115: Final value of inner max: 0.6880064830585184\n",
      "Iter 116: Final value of inner max: 0.6879850642650123\n",
      "Iter 117: Warning: Multiple points have more than one non-trivial transport location.\n",
      "Warning: All but first point ignored\n",
      "Warning: Points are [ 53 223]\n",
      "Warning: Maximum distance between these points: 26.30180549621582\n",
      "Final value of inner max: 0.6879580216836904\n",
      "Iter 118: Final value of inner max: 0.6879164901331367\n",
      "Iter 119: Final value of inner max: 0.6878877985372874\n",
      "Iter 120: Final value of inner max: 0.6878747042268515\n",
      "Iter 121: Final value of inner max: 0.6878482659906149\n",
      "Iter 122: Final value of inner max: 0.6878235793302889\n",
      "Iter 123: Final value of inner max: 0.6878097255977152\n",
      "Iter 124: Final value of inner max: 0.6877798634767532\n",
      "Iter 125: Final value of inner max: 0.6877713060379028\n",
      "Iter 126: Final value of inner max: 0.6877488154917956\n",
      "Iter 127: Final value of inner max: 0.6877208723414316\n",
      "Iter 128: Final value of inner max: 0.6876830858333678\n",
      "Iter 129: Final value of inner max: 0.6876760711520911\n",
      "Iter 130: Final value of inner max: 0.6876377320470028\n",
      "Iter 131: Final value of inner max: 0.68759091356446\n",
      "Iter 132: Final value of inner max: 0.6875747812269393\n",
      "Iter 133: Final value of inner max: 0.6875416430087053\n",
      "Iter 134: Final value of inner max: 0.6875291032055659\n",
      "Iter 135: Final value of inner max: 0.687510164603591\n",
      "Iter 136: Final value of inner max: 0.6874698023457068\n",
      "Iter 137: Final value of inner max: 0.6874503611716101\n",
      "Iter 138: Final value of inner max: 0.6874244434284718\n",
      "Iter 139: Final value of inner max: 0.6874312902987003\n",
      "Iter 140: Final value of inner max: 0.6873824698836104\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 141: Final value of inner max: 0.6873617044836282\n",
      "Iter 142: Final value of inner max: 0.6873367621004581\n",
      "Iter 143: Final value of inner max: 0.6873167233335615\n",
      "Iter 144: Final value of inner max: 0.6872963924834541\n",
      "Iter 145: Final value of inner max: 0.6872841471051725\n",
      "Iter 146: Final value of inner max: 0.6872377872945277\n",
      "Iter 147: Final value of inner max: 0.6872268605205842\n",
      "Iter 148: Final value of inner max: 0.6872155416756869\n",
      "Iter 149: Final value of inner max: 0.6872064643353224\n",
      "Iter 150: Final value of inner max: 0.6871736271679402\n",
      "Iter 151: Final value of inner max: 0.6871471609175206\n",
      "Iter 152: Final value of inner max: 0.6870865130424499\n",
      "Iter 153: Final value of inner max: 0.6870686276197586\n",
      "Iter 154: Final value of inner max: 0.6870512844135968\n",
      "Iter 155: Final value of inner max: 0.6870311967478221\n",
      "Iter 156: Final value of inner max: 0.6870069048553705\n",
      "Iter 157: Final value of inner max: 0.6869708538723616\n",
      "Iter 158: Final value of inner max: 0.6869342570841099\n",
      "Iter 159: Final value of inner max: 0.6869104766845704\n",
      "Iter 160: Final value of inner max: 0.686903587862778\n",
      "Iter 161: Final value of inner max: 0.6868485835939646\n",
      "Iter 162: Final value of inner max: 0.6868369350189386\n",
      "Iter 163: Final value of inner max: 0.6868149753659963\n",
      "Iter 164: Final value of inner max: 0.6867797750952036\n",
      "Iter 165: Final value of inner max: 0.6867587479203939\n",
      "Iter 166: Final value of inner max: 0.6867338518986239\n",
      "Iter 167: Final value of inner max: 0.6867052880593068\n",
      "Iter 168: Final value of inner max: 0.6866843630521218\n",
      "Iter 169: Final value of inner max: 0.686660509541597\n",
      "Iter 170: Final value of inner max: 0.6866215083521594\n",
      "Iter 171: Final value of inner max: 0.6865827783197165\n",
      "Iter 172: Final value of inner max: 0.6865891893953086\n",
      "Iter 173: Final value of inner max: 0.6865554463063606\n",
      "Iter 174: Final value of inner max: 0.6865493978559971\n",
      "Iter 175: Final value of inner max: 0.6865167171933428\n",
      "Iter 176: Final value of inner max: 0.6864767058938742\n",
      "Iter 177: Final value of inner max: 0.6864797795563937\n",
      "Iter 178: Final value of inner max: 0.686453520655632\n",
      "Iter 179: Final value of inner max: 0.6864396781654286\n",
      "Iter 180: Final value of inner max: 0.6864136709272861\n",
      "Iter 181: Final value of inner max: 0.6863554190844298\n",
      "Iter 182: Final value of inner max: 0.6863380048834403\n",
      "Iter 183: Final value of inner max: 0.6863258634419072\n",
      "Iter 184: Final value of inner max: 0.6863393727689981\n",
      "Iter 185: Final value of inner max: 0.6862856245322487\n",
      "Iter 186: Final value of inner max: 0.6862412315230684\n",
      "Iter 187: Final value of inner max: 0.6862063279907316\n",
      "Iter 188: Final value of inner max: 0.6861812441773047\n",
      "Iter 189: Final value of inner max: 0.6861653971694881\n",
      "Iter 190: Final value of inner max: 0.6861425330999105\n",
      "Iter 191: Final value of inner max: 0.6861222846158945\n",
      "Iter 192: Final value of inner max: 0.6860714527219534\n",
      "Iter 193: Final value of inner max: 0.6860928135365248\n",
      "Iter 194: Final value of inner max: 0.6860416801273823\n",
      "Iter 195: Final value of inner max: 0.6860275875494312\n",
      "Iter 196: Final value of inner max: 0.6860082271695137\n",
      "Iter 197: Final value of inner max: 0.685983597259942\n",
      "Iter 198: Final value of inner max: 0.685962337255478\n",
      "Iter 199: Final value of inner max: 0.6859286023388025\n",
      "Iter 200: Final value of inner max: 0.6859090071171522\n",
      "Iter 201: Final value of inner max: 0.6859014231711626\n",
      "Iter 202: Final value of inner max: 0.6858619790020934\n",
      "Iter 203: Final value of inner max: 0.6858619245886802\n",
      "Iter 204: Final value of inner max: 0.685814851000905\n",
      "Iter 205: Final value of inner max: 0.6857810530911554\n",
      "Iter 206: Final value of inner max: 0.685749680193479\n",
      "Iter 207: Final value of inner max: 0.6857265438139439\n",
      "Iter 208: Final value of inner max: 0.6857097131968886\n",
      "Iter 209: Final value of inner max: 0.6857064005732536\n",
      "Iter 210: Final value of inner max: 0.6856791763752699\n",
      "Iter 211: Final value of inner max: 0.6856756829470396\n",
      "Iter 212: Final value of inner max: 0.6856480638682843\n",
      "Iter 213: Final value of inner max: 0.6855831895768643\n",
      "Iter 214: Final value of inner max: 0.6855638358634748\n",
      "Iter 215: Final value of inner max: 0.6855499818058225\n",
      "Iter 216: Final value of inner max: 0.6855043783038854\n",
      "Iter 217: Final value of inner max: 0.6854905313141832\n",
      "Iter 218: Final value of inner max: 0.6854594064502452\n",
      "Iter 219: Final value of inner max: 0.685434372456657\n",
      "Iter 220: Final value of inner max: 0.6854146779189504\n",
      "Iter 221: Final value of inner max: 0.6853824562579394\n",
      "Iter 222: Final value of inner max: 0.6853792034089565\n",
      "Iter 223: Final value of inner max: 0.6853510839492083\n",
      "Iter 224: Final value of inner max: 0.6853323638758276\n",
      "Iter 225: Final value of inner max: 0.6852989468723536\n",
      "Iter 226: Final value of inner max: 0.6852663616204837\n",
      "Iter 227: Warning: Multiple points have more than one non-trivial transport location.\n",
      "Warning: All but first point ignored\n",
      "Warning: Points are [163 606]\n",
      "Warning: Maximum distance between these points: 11.536554336547852\n",
      "Final value of inner max: 0.6852467439243045\n",
      "Iter 228: Final value of inner max: 0.6852160226099657\n",
      "Iter 229: Final value of inner max: 0.685187375422833\n",
      "Iter 230: Final value of inner max: 0.6851570908725262\n",
      "Iter 231: Final value of inner max: 0.6851569602638483\n",
      "Iter 232: Final value of inner max: 0.685140810534358\n",
      "Iter 233: Final value of inner max: 0.6851015290617942\n",
      "Iter 234: Final value of inner max: 0.6850901333987713\n",
      "Iter 235: Final value of inner max: 0.685080875530839\n",
      "Iter 236: Final value of inner max: 0.6850446796417237\n",
      "Iter 237: Final value of inner max: 0.6850185760072942\n",
      "Iter 238: Final value of inner max: 0.6850064224749803\n",
      "Iter 239: Final value of inner max: 0.684972790852189\n",
      "Iter 240: Final value of inner max: 0.6849500789493322\n",
      "Iter 241: Final value of inner max: 0.6848936203809746\n",
      "Iter 242: Final value of inner max: 0.6848428215831519\n",
      "Iter 243: Final value of inner max: 0.6848321325428863\n",
      "Iter 244: Final value of inner max: 0.6848043317686597\n",
      "Iter 245: Final value of inner max: 0.6847875987477314\n",
      "Iter 246: Final value of inner max: 0.6847686033280023\n",
      "Iter 247: Final value of inner max: 0.6847462106749522\n",
      "Iter 248: Final value of inner max: 0.6847294891521136\n",
      "Iter 249: Final value of inner max: 0.6846905318409868\n",
      "Iter 250: Final value of inner max: 0.6846799018234014\n",
      "Iter 251: Final value of inner max: 0.6846323899179697\n",
      "Iter 252: Final value of inner max: 0.6845978878511246\n",
      "Iter 253: Final value of inner max: 0.6845785857160147\n",
      "Iter 254: Final value of inner max: 0.6845672226378695\n",
      "Iter 255: Final value of inner max: 0.6845519009743032\n",
      "Iter 256: Final value of inner max: 0.6844901974499226\n",
      "Iter 257: Final value of inner max: 0.6844831844819959\n",
      "Iter 258: Final value of inner max: 0.6844325793534517\n",
      "Iter 259: Final value of inner max: 0.684415811755769\n",
      "Iter 260: Final value of inner max: 0.6844049768644136\n",
      "Iter 261: Final value of inner max: 0.6843729018419982\n",
      "Iter 262: Final value of inner max: 0.6843550123156352\n",
      "Iter 263: Final value of inner max: 0.6843252911418677\n",
      "Iter 264: Final value of inner max: 0.6843138462395805\n",
      "Iter 265: Final value of inner max: 0.6843102751672268\n",
      "Iter 266: Final value of inner max: 0.6842792280018329\n",
      "Iter 267: Final value of inner max: 0.6842542394250631\n",
      "Iter 268: Final value of inner max: 0.6842170082746246\n",
      "Iter 269: Final value of inner max: 0.6841907193511725\n",
      "Iter 270: Final value of inner max: 0.6841824784874916\n",
      "Iter 271: Final value of inner max: 0.6841239117830993\n",
      "Iter 272: Final value of inner max: 0.6841287432450774\n",
      "Iter 273: Final value of inner max: 0.6841130037605763\n",
      "Iter 274: Final value of inner max: 0.6840294692665339\n",
      "Iter 275: Final value of inner max: 0.6840329073369503\n",
      "Iter 276: Final value of inner max: 0.6840103355050087\n",
      "Iter 277: Final value of inner max: 0.6839606419205666\n",
      "Iter 278: Final value of inner max: 0.683943765665699\n",
      "Iter 279: Final value of inner max: 0.6839169112264369\n",
      "Iter 280: Final value of inner max: 0.6838958611339331\n",
      "Iter 281: Final value of inner max: 0.6838817587790436\n",
      "Iter 282: Final value of inner max: 0.6838347703218459\n",
      "Iter 283: Final value of inner max: 0.6838364623486995\n",
      "Iter 284: Final value of inner max: 0.6837760793417693\n",
      "Iter 285: Final value of inner max: 0.6837752707704057\n",
      "Iter 286: Final value of inner max: 0.6837496081739665\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 287: Final value of inner max: 0.683701735822728\n",
      "Iter 288: Final value of inner max: 0.6836833859980107\n",
      "Iter 289: Final value of inner max: 0.683678777590394\n",
      "Iter 290: Final value of inner max: 0.6836590974032879\n",
      "Iter 291: Final value of inner max: 0.6835922434865724\n",
      "Iter 292: Final value of inner max: 0.6835798062666161\n",
      "Iter 293: Final value of inner max: 0.6835450796041775\n",
      "Iter 294: Final value of inner max: 0.6835146425664425\n",
      "Iter 295: Final value of inner max: 0.6834697360545396\n",
      "Iter 296: Final value of inner max: 0.6834826292097569\n",
      "Iter 297: Final value of inner max: 0.6834393246474078\n",
      "Iter 298: Final value of inner max: 0.6834219286069811\n",
      "Iter 299: Final value of inner max: 0.6834126716198564\n",
      "Iter 300: Final value of inner max: 0.6833976822183332\n",
      "Iter 301: Final value of inner max: 0.6833754934370518\n",
      "Iter 302: Final value of inner max: 0.6833506166189909\n",
      "Iter 303: Final value of inner max: 0.683332399433479\n",
      "Iter 304: Final value of inner max: 0.683298641363036\n",
      "Iter 305: Final value of inner max: 0.6832802148500339\n",
      "Iter 306: Final value of inner max: 0.6832477049529553\n",
      "Iter 307: Final value of inner max: 0.6832347201555967\n",
      "Iter 308: Final value of inner max: 0.683206840255441\n",
      "Iter 309: Final value of inner max: 0.6831605143100024\n",
      "Iter 310: Final value of inner max: 0.6831480919439811\n",
      "Iter 311: Final value of inner max: 0.6831395919620991\n",
      "Iter 312: Final value of inner max: 0.6830701650679112\n",
      "Iter 313: Final value of inner max: 0.683069523125887\n",
      "Iter 314: Final value of inner max: 0.6830668982112094\n",
      "Iter 315: Final value of inner max: 0.683048404545831\n",
      "Iter 316: Final value of inner max: 0.6830058641731739\n",
      "Iter 317: Final value of inner max: 0.6830005816122596\n",
      "Iter 318: Final value of inner max: 0.6829478199779988\n",
      "Iter 319: Final value of inner max: 0.6829610707674043\n",
      "Iter 320: Final value of inner max: 0.6829248927533627\n",
      "Iter 321: Final value of inner max: 0.6829291519522667\n",
      "Iter 322: Final value of inner max: 0.6829158298594067\n",
      "Iter 323: Final value of inner max: 0.6829104640334845\n",
      "Iter 324: Final value of inner max: 0.6828701952844858\n",
      "Iter 325: Final value of inner max: 0.6828573288023472\n",
      "Iter 326: Final value of inner max: 0.6828220848400552\n",
      "Iter 327: Final value of inner max: 0.6828095031682258\n",
      "Iter 328: Final value of inner max: 0.6827831854705281\n",
      "Iter 329: Final value of inner max: 0.6827921372652054\n",
      "Iter 330: Final value of inner max: 0.6827459343630333\n",
      "Iter 331: Final value of inner max: 0.6827100143580053\n",
      "Iter 332: Final value of inner max: 0.6826836361333304\n",
      "Iter 333: Final value of inner max: 0.6826610454171896\n",
      "Iter 334: Final value of inner max: 0.6826132574677468\n",
      "Iter 335: Final value of inner max: 0.6826122116528968\n",
      "Iter 336: Final value of inner max: 0.6825946503579925\n",
      "Iter 337: Final value of inner max: 0.6825475710862279\n",
      "Iter 338: Final value of inner max: 0.682519645036605\n",
      "Iter 339: Final value of inner max: 0.6824985603570245\n",
      "Iter 340: Final value of inner max: 0.6824832691925182\n",
      "Iter 341: Final value of inner max: 0.682466310538543\n",
      "Iter 342: Final value of inner max: 0.6824467881621266\n",
      "Iter 343: Final value of inner max: 0.6824224822035041\n",
      "Iter 344: Final value of inner max: 0.6824025498857115\n",
      "Iter 345: Final value of inner max: 0.68234579205513\n",
      "Iter 346: Final value of inner max: 0.6823048110306263\n",
      "Iter 347: Final value of inner max: 0.6823219794510489\n",
      "Iter 348: Final value of inner max: 0.6822726177424192\n",
      "Iter 349: Final value of inner max: 0.6822741337865591\n",
      "Iter 350: Final value of inner max: 0.682262207491942\n",
      "Iter 351: Final value of inner max: 0.6822456572369965\n",
      "Iter 352: Final value of inner max: 0.6822588175535202\n",
      "Iter 353: Final value of inner max: 0.6821813564002513\n",
      "Iter 354: Final value of inner max: 0.6821440849453211\n",
      "Iter 355: Final value of inner max: 0.6821345968544483\n",
      "Iter 356: Final value of inner max: 0.6821221946179867\n",
      "Iter 357: Final value of inner max: 0.6820989096394797\n",
      "Iter 358: Final value of inner max: 0.6820499277859926\n",
      "Iter 359: Final value of inner max: 0.6820612569470791\n",
      "Iter 360: Final value of inner max: 0.6820669415593148\n",
      "Iter 361: Final value of inner max: 0.6820008223503828\n",
      "Iter 362: Final value of inner max: 0.6819850550215887\n",
      "Iter 363: Final value of inner max: 0.6820013776421547\n",
      "Iter 364: Final value of inner max: 0.6819516463522661\n",
      "Iter 365: Final value of inner max: 0.681933486367766\n",
      "Iter 366: Final value of inner max: 0.6819045183062553\n",
      "Iter 367: Final value of inner max: 0.6819049803167582\n",
      "Iter 368: Final value of inner max: 0.6818626472674457\n",
      "Iter 369: Final value of inner max: 0.6818243599683046\n",
      "Iter 370: Final value of inner max: 0.6817877334356308\n",
      "Iter 371: Final value of inner max: 0.681740616261959\n",
      "Iter 372: Final value of inner max: 0.6817360529719818\n",
      "Iter 373: Final value of inner max: 0.6817228033682002\n",
      "Iter 374: Final value of inner max: 0.6817245414853097\n",
      "Iter 375: Final value of inner max: 0.6816857717931271\n",
      "Iter 376: Final value of inner max: 0.6816564547270536\n",
      "Iter 377: Final value of inner max: 0.6816024714708329\n",
      "Iter 378: Final value of inner max: 0.6816141933993582\n",
      "Iter 379: Final value of inner max: 0.6815965289154711\n",
      "Iter 380: Final value of inner max: 0.6815929458290338\n",
      "Iter 381: Final value of inner max: 0.6815531316399575\n",
      "Iter 382: Final value of inner max: 0.6815522227436304\n",
      "Iter 383: Final value of inner max: 0.6815008345246316\n",
      "Iter 384: Final value of inner max: 0.6814898881541522\n",
      "Iter 385: Final value of inner max: 0.6814958003163337\n",
      "Iter 386: Final value of inner max: 0.6814771503955126\n",
      "Iter 387: Final value of inner max: 0.6813935260474682\n",
      "Iter 388: Final value of inner max: 0.681379854427691\n",
      "Iter 389: Final value of inner max: 0.6813166403770446\n",
      "Iter 390: Final value of inner max: 0.6813254785065407\n",
      "Iter 391: Final value of inner max: 0.6813022788328544\n",
      "Iter 392: Final value of inner max: 0.6812812729179859\n",
      "Iter 393: Final value of inner max: 0.6812465240074452\n",
      "Iter 394: Final value of inner max: 0.6812354157418564\n",
      "Iter 395: Final value of inner max: 0.6812216842364258\n",
      "Iter 396: Final value of inner max: 0.6811990464478731\n",
      "Iter 397: Final value of inner max: 0.6811707420435038\n",
      "Iter 398: Final value of inner max: 0.6811257168398523\n",
      "Iter 399: Final value of inner max: 0.6811165450513362\n",
      "Iter 400: Final value of inner max: 0.6810588317364454\n",
      "Iter 401: Final value of inner max: 0.6810540645569563\n",
      "Iter 402: Final value of inner max: 0.6810262033053782\n",
      "Iter 403: Final value of inner max: 0.6810231605917215\n",
      "Iter 404: Final value of inner max: 0.6810294085741044\n",
      "Iter 405: Final value of inner max: 0.6809757553216687\n",
      "Iter 406: Final value of inner max: 0.6809538919478655\n",
      "Iter 407: Final value of inner max: 0.6809046761829289\n",
      "Iter 408: Final value of inner max: 0.6808708769083023\n",
      "Iter 409: Final value of inner max: 0.6808530952490937\n",
      "Iter 410: Final value of inner max: 0.6808339596552944\n",
      "Iter 411: Final value of inner max: 0.6808137151680489\n",
      "Iter 412: Final value of inner max: 0.6808007769286633\n",
      "Iter 413: Final value of inner max: 0.6808102086186409\n",
      "Iter 414: Final value of inner max: 0.680730353370309\n",
      "Iter 415: Final value of inner max: 0.6807350979519562\n",
      "Iter 416: Final value of inner max: 0.6807073487731801\n",
      "Iter 417: Final value of inner max: 0.6806520561128855\n",
      "Iter 418: Final value of inner max: 0.6806599878113795\n",
      "Iter 419: Final value of inner max: 0.6806411802259517\n",
      "Iter 420: Final value of inner max: 0.6806333106918256\n",
      "Iter 421: Final value of inner max: 0.680589192940163\n",
      "Iter 422: Final value of inner max: 0.6805705390011854\n",
      "Iter 423: Final value of inner max: 0.6805023039877415\n",
      "Iter 424: Final value of inner max: 0.6805117439478636\n",
      "Iter 425: Final value of inner max: 0.6804799315333367\n",
      "Iter 426: Final value of inner max: 0.6804685005391495\n",
      "Iter 427: Final value of inner max: 0.6804534419625998\n",
      "Iter 428: Final value of inner max: 0.680439219891144\n",
      "Iter 429: Final value of inner max: 0.6803606385737658\n",
      "Iter 430: Final value of inner max: 0.6803888395428658\n",
      "Iter 431: Final value of inner max: 0.680349140629064\n",
      "Iter 432: Final value of inner max: 0.680368613600731\n",
      "Iter 433: Final value of inner max: 0.6803330185264349\n",
      "Iter 434: Final value of inner max: 0.6802726297080517\n",
      "Iter 435: Final value of inner max: 0.6802352663874627\n",
      "Iter 436: Final value of inner max: 0.6802909424155951\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 437: Final value of inner max: 0.6802034498803455\n",
      "Iter 438: Final value of inner max: 0.680212764069438\n",
      "Iter 439: Final value of inner max: 0.6801758870482445\n",
      "Iter 440: Final value of inner max: 0.6801415503107421\n",
      "Iter 441: Final value of inner max: 0.6801134432493906\n",
      "Iter 442: Final value of inner max: 0.6801000665256396\n",
      "Iter 443: Final value of inner max: 0.680065809735424\n",
      "Iter 444: Final value of inner max: 0.6800542785240148\n",
      "Iter 445: Final value of inner max: 0.6799929788708687\n",
      "Iter 446: Final value of inner max: 0.679985922202468\n",
      "Iter 447: Final value of inner max: 0.6799692156165839\n",
      "Iter 448: Final value of inner max: 0.6799416894612854\n",
      "Iter 449: Final value of inner max: 0.6799261550076536\n",
      "Iter 450: Final value of inner max: 0.6799091942802048\n",
      "Iter 451: Final value of inner max: 0.6798969122767449\n",
      "Iter 452: Final value of inner max: 0.6798791015893222\n",
      "Iter 453: Final value of inner max: 0.6798612185567618\n",
      "Iter 454: Final value of inner max: 0.6798561360687018\n",
      "Iter 455: Final value of inner max: 0.6798064255084202\n",
      "Iter 456: Final value of inner max: 0.6797526107728482\n",
      "Iter 457: Final value of inner max: 0.6797521365153796\n",
      "Iter 458: Final value of inner max: 0.6797196860403918\n",
      "Iter 459: Final value of inner max: 0.6797053042799235\n",
      "Iter 460: Final value of inner max: 0.6796877005540292\n",
      "Iter 461: Final value of inner max: 0.6796771083772183\n",
      "Iter 462: Final value of inner max: 0.6796465963870287\n",
      "Iter 463: Final value of inner max: 0.6796241180278799\n",
      "Iter 464: Final value of inner max: 0.6796388170868157\n",
      "Iter 465: Final value of inner max: 0.6796093708276749\n",
      "Iter 466: Final value of inner max: 0.6795268365740776\n",
      "Iter 467: Final value of inner max: 0.6795281331241131\n",
      "Iter 468: Final value of inner max: 0.6795028756906287\n",
      "Iter 469: Final value of inner max: 0.6794827160951471\n",
      "Iter 470: Final value of inner max: 0.6794504988193513\n",
      "Iter 471: Final value of inner max: 0.6794486918300391\n",
      "Iter 472: Final value of inner max: 0.6794036746025085\n",
      "Iter 473: Final value of inner max: 0.6793888878195578\n",
      "Iter 474: Final value of inner max: 0.679360518856603\n",
      "Iter 475: Final value of inner max: 0.6793058822304011\n",
      "Iter 476: Final value of inner max: 0.679310963600874\n",
      "Iter 477: Final value of inner max: 0.6792792215786452\n",
      "Iter 478: Final value of inner max: 0.6792665118724108\n",
      "Iter 479: Final value of inner max: 0.6792211055513196\n",
      "Iter 480: Final value of inner max: 0.679202571722991\n",
      "Iter 481: Final value of inner max: 0.6791837750142312\n",
      "Iter 482: Final value of inner max: 0.6791802467252923\n",
      "Iter 483: Final value of inner max: 0.6791951147466898\n",
      "Iter 484: Final value of inner max: 0.6791512668253374\n",
      "Iter 485: Final value of inner max: 0.6791252661173893\n",
      "Iter 486: Final value of inner max: 0.6791119163483381\n",
      "Iter 487: Final value of inner max: 0.6790710425376892\n",
      "Iter 488: Final value of inner max: 0.6790544482945868\n",
      "Iter 489: Final value of inner max: 0.6790110080781802\n",
      "Iter 490: Final value of inner max: 0.678993552212358\n",
      "Iter 491: Final value of inner max: 0.678973422101089\n",
      "Iter 492: Final value of inner max: 0.6789539454465845\n",
      "Iter 493: Final value of inner max: 0.6789380879230664\n",
      "Iter 494: Final value of inner max: 0.6789024476742319\n",
      "Iter 495: Final value of inner max: 0.6788473685085774\n",
      "Iter 496: Final value of inner max: 0.6788678987324237\n",
      "Iter 497: Final value of inner max: 0.6788362301346886\n",
      "Iter 498: Final value of inner max: 0.678829535394907\n",
      "Iter 499: Final value of inner max: 0.6788089331186997\n",
      "Boosting took 94.218552\n"
     ]
    }
   ],
   "source": [
    "t_s = time.time()\n",
    "use_dual = True\n",
    "# fst is the trained GBDT object, xport is the transport map in pairs.  stuff is extra information about a\n",
    "# solution that is ambiguous (e.g. a point is spread out over multiple outputs)\n",
    "fst, xport, stuff = txl.boost_dual(\n",
    "    X_train,\n",
    "    y_train,\n",
    "    C,\n",
    "    n_iter,\n",
    "    X_test=X_test,\n",
    "    y_test=y_test,\n",
    "    pred=None,\n",
    "    eps=eps,       # perturbation budget\n",
    "    param=param,   # parameters, above\n",
    "    verbose=_verbose,\n",
    "    verify=True,\n",
    "    lowg=0.0,      # lower and upper guesses for the bisection method.  A smaller range will result in \n",
    "    highg=15.0,    # a shorter runtime, but the optimal eta could be missed.  See Appendix B.3.\n",
    "    method='brentq',\n",
    "    dtype=_dtype,\n",
    "    outfunc=None,   # how to save the data to a tensorboard file, see `german_fair_age.py`\n",
    ")\n",
    "\n",
    "print(\"Boosting took %f\" % (time.time() - t_s))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create Pi from the boosting solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Constraint value: 0.9999999995500375\n"
     ]
    }
   ],
   "source": [
    "# Recompute the transport cost implied by the boosting solution and print it; the value is\n",
    "# expected to match eps.  If use_dual is False nothing is accumulated and 0 is printed.\n",
    "final = 0\n",
    "if use_dual:\n",
    "    # pick the entry Corig[xport[j], j] for every column j\n",
    "    # (assumes xport holds one row index per column of Corig -- TODO confirm)\n",
    "    dists = Corig[xport, np.arange(Corig.shape[1])]\n",
    "    \n",
    "    # calculate the exact solution to the linear system with two equations and two unknowns.\n",
    "    if stuff is not None:\n",
    "        # the entry for stuff[0] is accounted for explicitly on the next line, so remove it from dists\n",
    "        dists[stuff[0]] = 0\n",
    "        # the two weights, stuff[3] and (1/n - stuff[3]), add up to 1/n: presumably the mass of the\n",
    "        # single ambiguous point split between stuff[2] and stuff[1] -- verify against txl.boost_dual\n",
    "        final += Corig[stuff[0], stuff[2]]*stuff[3] + Corig[stuff[0], stuff[1]]*(1/n - stuff[3])\n",
    "    \n",
    "    # remaining entries each carry weight 1/n (n is presumably the number of training points)\n",
    "    final += dists.sum()/n\n",
    "                       \n",
    "# should be equal to eps\n",
    "print(\"Constraint value: {}\".format(final))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the quality of the fair GBDT model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use the trained model to classify the test set\n",
    "preds = fst.predict(dtest)\n",
    "# threshold every prediction in one vectorized step: label 1 when strictly above 0.5, otherwise 0.\n",
    "# np.asarray keeps this working if predict returns a plain sequence; astype(int) yields an\n",
    "# integer array even when preds is empty.\n",
    "y_guess = (np.asarray(preds) > 0.5).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy is 0.735000\n",
      "Protected class:\n",
      "TPR: 0.7222222222222222; TNR: 0.631578947368421\n",
      "Priv class:\n",
      "TPR: 0.7142857142857143; TNR: 0.7603305785123967\n",
      "Gap RMS: 0.09121395364871711, Gap MAX: 0.1287516311439757\n",
      "Average odds difference is 0.068344\n",
      "Equal opportunity difference is 0.007937\n",
      "Statistical parity difference is 0.178577\n"
     ]
    }
   ],
   "source": [
    "# accuracy and group fairness metrics\n",
    "# balanced_accuracy returns three values; only the first two are kept, once for the test split\n",
    "# (p0, p1) and once for the train split (p0t, p1t).  The next cell averages each pair.\n",
    "p0, p1, _ = script.balanced_accuracy(fst, dtest, y_test)\n",
    "p0t, p1t, _ = script.balanced_accuracy(fst, dtrain, y_train)\n",
    "# group_metrics receives the true test labels, the thresholded predictions y_guess and yt_age\n",
    "# (presumably the protected age attribute of the test rows -- TODO confirm).  label_good=1\n",
    "# presumably marks which label is the favourable outcome; verbose=True presumably triggers the\n",
    "# printed report shown in this cell's output.\n",
    "age_res = script.group_metrics(\n",
    "    y_test,\n",
    "    y_guess,\n",
    "    yt_age,\n",
    "    label_good=1,\n",
    "    verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train balanced: 0.8607142857142858\n",
      "test balanced: 0.7297619047619048\n"
     ]
    }
   ],
   "source": [
    "# report the mean of the two values kept from script.balanced_accuracy, first for the\n",
    "# train split (p0t, p1t) and then for the test split (p0, p1)\n",
    "print(\"train balanced: {}\".format((p0t + p1t) / 2))\n",
    "print(\"test balanced: {}\".format((p0 + p1) / 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "status consistency: 0.96\n"
     ]
    }
   ],
   "source": [
    "# status consistency: total of the values returned by german_proc.status_cons, divided by the\n",
    "# number of test rows\n",
    "scons = german_proc.status_cons(fst, X_test, status_inds)\n",
    "\n",
    "print(\"status consistency: {}\".format(np.sum(scons) / X_test.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37fair",
   "language": "python",
   "name": "fair"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
