{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reproduction of Adult dataset experiments\n",
    "\n",
    "In this notebook we reproduce the results from Table 2 of the DECAF paper. We compare various methods for generating debiased data using the DECAF model against synthetic data generated using benchmark models GAN, WGAN-GP and FairGAN. As described in the paper we run all experiments (as implemented in this notebook) 10 times and avarage the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_score, recall_score, roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "from data import load_adult, preprocess_adult\n",
    "from metrics import DP, FTU\n",
    "from train import train_decaf, train_fairgan, train_vanilla_gan, train_wgan_gp\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516.0</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13.0</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>2174.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311.0</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646.0</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9.0</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721.0</td>\n",
       "      <td>11th</td>\n",
       "      <td>7.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409.0</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>Cuba</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  age         workclass    fnlwgt  education  education-num  \\\n",
       "0  39         State-gov   77516.0  Bachelors           13.0   \n",
       "1  50  Self-emp-not-inc   83311.0  Bachelors           13.0   \n",
       "2  38           Private  215646.0    HS-grad            9.0   \n",
       "3  53           Private  234721.0       11th            7.0   \n",
       "4  28           Private  338409.0  Bachelors           13.0   \n",
       "\n",
       "       marital-status         occupation   relationship   race     sex  \\\n",
       "0       Never-married       Adm-clerical  Not-in-family  White    Male   \n",
       "1  Married-civ-spouse    Exec-managerial        Husband  White    Male   \n",
       "2            Divorced  Handlers-cleaners  Not-in-family  White    Male   \n",
       "3  Married-civ-spouse  Handlers-cleaners        Husband  Black    Male   \n",
       "4  Married-civ-spouse     Prof-specialty           Wife  Black  Female   \n",
       "\n",
       "   capital-gain  capital-loss  hours-per-week native-country income  \n",
       "0        2174.0           0.0            40.0  United-States  <=50K  \n",
       "1           0.0           0.0            13.0  United-States  <=50K  \n",
       "2           0.0           0.0            40.0  United-States  <=50K  \n",
       "3           0.0           0.0            40.0  United-States  <=50K  \n",
       "4           0.0           0.0            40.0           Cuba  <=50K  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = load_adult()\n",
    "dataset.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocess the data next in order to make it suitable for training models on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.301370</td>\n",
       "      <td>0.833333</td>\n",
       "      <td>0.043350</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.615385</td>\n",
       "      <td>0.6</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.02174</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.397959</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.452055</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.047274</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.307692</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.122449</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.287671</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.136877</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>0.533333</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.461538</td>\n",
       "      <td>0.6</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.397959</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.493151</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.149792</td>\n",
       "      <td>0.133333</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.461538</td>\n",
       "      <td>0.4</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.397959</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.150685</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.219998</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.384615</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.397959</td>\n",
       "      <td>0.3</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        age  workclass    fnlwgt  education  education-num  marital-status  \\\n",
       "0  0.301370   0.833333  0.043350   0.000000       0.800000        0.333333   \n",
       "1  0.452055   0.166667  0.047274   0.000000       0.800000        0.000000   \n",
       "2  0.287671   0.000000  0.136877   0.200000       0.533333        0.166667   \n",
       "3  0.493151   0.000000  0.149792   0.133333       0.400000        0.000000   \n",
       "4  0.150685   0.000000  0.219998   0.000000       0.800000        0.000000   \n",
       "\n",
       "   occupation  relationship  race  sex  capital-gain  capital-loss  \\\n",
       "0    0.615385           0.6   0.0  1.0       0.02174           0.0   \n",
       "1    0.307692           0.4   0.0  1.0       0.00000           0.0   \n",
       "2    0.461538           0.6   0.0  1.0       0.00000           0.0   \n",
       "3    0.461538           0.4   1.0  1.0       0.00000           0.0   \n",
       "4    0.384615           0.0   1.0  0.0       0.00000           0.0   \n",
       "\n",
       "   hours-per-week  native-country  income  \n",
       "0        0.397959             0.0     1.0  \n",
       "1        0.122449             0.0     1.0  \n",
       "2        0.397959             0.0     1.0  \n",
       "3        0.397959             0.0     1.0  \n",
       "4        0.397959             0.3     1.0  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = preprocess_adult(dataset)\n",
    "dataset.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the dataset into train and test folds. Test fold size is 2000."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of train set: 43222\n",
      "Size of test set: 2000\n"
     ]
    }
   ],
   "source": [
    "# Split data into train and testing sets\n",
    "dataset_train, dataset_test = train_test_split(dataset, test_size=2000,\n",
    "                                               stratify=dataset['income'])\n",
    "\n",
    "print('Size of train set:', len(dataset_train))\n",
    "print('Size of test set:', len(dataset_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the DAG\n",
    "\n",
    "We need to define a DAG which captures the biases of the dataset. As described in the DECAF paper normally a causal discovery algorithm is used. In this notebook we simply copy the DAG which as described in the Zhang et al. paper which is the one also used in the DECAF paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[8, 6], [8, 14], [8, 12], [8, 3], [8, 5], [0, 6], [0, 12], [0, 14], [0, 1], [0, 5], [0, 3], [0, 7], [9, 6], [9, 5], [9, 14], [9, 1], [9, 3], [9, 7], [13, 5], [13, 12], [13, 3], [13, 1], [13, 14], [13, 7], [5, 6], [5, 12], [5, 14], [5, 1], [5, 7], [5, 3], [3, 6], [3, 12], [3, 14], [3, 1], [3, 7], [6, 14], [12, 14], [1, 14], [7, 14]]\n"
     ]
    }
   ],
   "source": [
    "# Define DAG for Adult dataset\n",
    "dag = [\n",
    "    # Edges from race\n",
    "    ['race', 'occupation'],\n",
    "    ['race', 'income'],\n",
    "    ['race', 'hours-per-week'],\n",
    "    ['race', 'education'],\n",
    "    ['race', 'marital-status'],\n",
    "\n",
    "    # Edges from age\n",
    "    ['age', 'occupation'],\n",
    "    ['age', 'hours-per-week'],\n",
    "    ['age', 'income'],\n",
    "    ['age', 'workclass'],\n",
    "    ['age', 'marital-status'],\n",
    "    ['age', 'education'],\n",
    "    ['age', 'relationship'],\n",
    "    \n",
    "    # Edges from sex\n",
    "    ['sex', 'occupation'],\n",
    "    ['sex', 'marital-status'],\n",
    "    ['sex', 'income'],\n",
    "    ['sex', 'workclass'],\n",
    "    ['sex', 'education'],\n",
    "    ['sex', 'relationship'],\n",
    "    \n",
    "    # Edges from native country\n",
    "    ['native-country', 'marital-status'],\n",
    "    ['native-country', 'hours-per-week'],\n",
    "    ['native-country', 'education'],\n",
    "    ['native-country', 'workclass'],\n",
    "    ['native-country', 'income'],\n",
    "    ['native-country', 'relationship'],\n",
    "    \n",
    "    # Edges from marital status\n",
    "    ['marital-status', 'occupation'],\n",
    "    ['marital-status', 'hours-per-week'],\n",
    "    ['marital-status', 'income'],\n",
    "    ['marital-status', 'workclass'],\n",
    "    ['marital-status', 'relationship'],\n",
    "    ['marital-status', 'education'],\n",
    "    \n",
    "    # Edges from education\n",
    "    ['education', 'occupation'],\n",
    "    ['education', 'hours-per-week'],\n",
    "    ['education', 'income'],\n",
    "    ['education', 'workclass'],\n",
    "    ['education', 'relationship'],\n",
    "    \n",
    "    # All remaining edges\n",
    "    ['occupation', 'income'],\n",
    "    ['hours-per-week', 'income'],\n",
    "    ['workclass', 'income'],\n",
    "    ['relationship', 'income'],\n",
    "]\n",
    "\n",
    "def dag_to_idx(df, dag):\n",
    "    \"\"\"Convert columns in a DAG to the corresponding indices.\"\"\"\n",
    "\n",
    "    dag_idx = []\n",
    "    for edge in dag:\n",
    "        dag_idx.append([df.columns.get_loc(edge[0]), df.columns.get_loc(edge[1])])\n",
    "\n",
    "    return dag_idx\n",
    "\n",
    "# Convert the DAG to one that can be provided to the DECAF model\n",
    "dag_seed = dag_to_idx(dataset, dag)\n",
    "print(dag_seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's also necessary to define edges we want to remove from the DAG in order to meet the various fairness criteria described in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bias dict FTU: {14: [9]}\n",
      "Bias dict DP: {14: [6, 12, 5, 3, 9, 1, 7]}\n",
      "Bias dict CF: {14: [5, 9]}\n"
     ]
    }
   ],
   "source": [
    "def create_bias_dict(df, edge_map):\n",
    "    \"\"\"\n",
    "    Convert the given edge tuples to a bias dict used for generating\n",
    "    debiased synthetic data.\n",
    "    \"\"\"\n",
    "    bias_dict = {}\n",
    "    for key, val in edge_map.items():\n",
    "        bias_dict[df.columns.get_loc(key)] = [df.columns.get_loc(f) for f in val]\n",
    "    \n",
    "    return bias_dict\n",
    "\n",
    "# Bias dictionary to satisfy FTU\n",
    "bias_dict_ftu = create_bias_dict(dataset, {'income': ['sex']})\n",
    "print('Bias dict FTU:', bias_dict_ftu)\n",
    "\n",
    "# Bias dictionary to satisfy DP\n",
    "bias_dict_dp = create_bias_dict(dataset, {'income': [\n",
    "    'occupation', 'hours-per-week', 'marital-status', 'education', 'sex',\n",
    "    'workclass', 'relationship']})\n",
    "print('Bias dict DP:', bias_dict_dp)\n",
    "\n",
    "# Bias dictionary to satisfy CF\n",
    "bias_dict_cf = create_bias_dict(dataset, {'income': [\n",
    "    'marital-status', 'sex']})\n",
    "print('Bias dict CF:', bias_dict_cf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiments\n",
    "\n",
    "We have loaded and preprocessed the data and we are ready to run the experiments. For each experiment we train a generative model, sample synthetic data from the trained model and then obtain metrics by training and evaluating a downstream multi-layer perceptron using the test fold we generated in the previous section. We use the MLP model from `sklearn` with default parameters which matches the settings described in Appendix D of the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_model(dataset_train, dataset_test):\n",
    "    \"\"\"Helper function that prints evaluation metrics.\"\"\"\n",
    "\n",
    "    X_train, y_train = dataset_train.drop(columns=['income']), dataset_train['income']\n",
    "    X_test, y_test = dataset_test.drop(columns=['income']), dataset_test['income']\n",
    "\n",
    "    clf = MLPClassifier()\n",
    "    clf.fit(X_train, y_train)\n",
    "    y_pred = clf.predict(X_test)\n",
    "\n",
    "    precision = precision_score(y_test, y_pred)\n",
    "    recall = recall_score(y_test, y_pred)\n",
    "    auroc = roc_auc_score(y_test, y_pred)\n",
    "    dp = DP(clf, X_test)\n",
    "    ftu = FTU(clf, X_test)\n",
    "\n",
    "    return {'precision': precision, 'recall': recall, 'auroc': auroc,\n",
    "            'dp': dp, 'ftu': ftu}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Original dataset\n",
    "\n",
    "As a benchmark we want to first train the downstream model on the original dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/sklearn/neural_network/_multilayer_perceptron.py:692: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'precision': 0.8748451053283767,\n",
       " 'recall': 0.9388297872340425,\n",
       " 'auroc': 0.7657858613589568,\n",
       " 'dp': 0.15769354186140594,\n",
       " 'ftu': 0.008499999999999952}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_model(dataset_train, dataset_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the following sections we train various models in order to reproduce the results from Table 2 of the DECAF paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-02-04 16:55:26.745579: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
      "2022-02-04 16:55:26.871962: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]2022-02-04 16:55:27.720874: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)\n",
      "  2%|▏         | 1/50 [00:14<12:09, 14.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 [D loss: 0.000034, acc.: 100.00%] [G loss: 43.320641]\n",
      "generated_data\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 2/50 [00:28<11:25, 14.28s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 [D loss: 0.000001, acc.: 100.00%] [G loss: 71.679199]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 3/50 [00:42<11:08, 14.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 [D loss: 0.000004, acc.: 100.00%] [G loss: 101.990181]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 4/50 [00:56<10:48, 14.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3 [D loss: 0.000000, acc.: 100.00%] [G loss: 118.899742]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 5/50 [01:10<10:21, 13.80s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4 [D loss: 0.000000, acc.: 100.00%] [G loss: 125.251663]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 6/50 [01:24<10:14, 13.97s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5 [D loss: 0.000000, acc.: 100.00%] [G loss: 135.860107]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 7/50 [01:38<10:04, 14.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6 [D loss: 0.000000, acc.: 100.00%] [G loss: 153.603333]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 8/50 [01:52<09:49, 14.04s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7 [D loss: 0.000000, acc.: 100.00%] [G loss: 155.913116]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 9/50 [02:07<09:44, 14.26s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8 [D loss: 0.000000, acc.: 100.00%] [G loss: 175.700729]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 10/50 [02:21<09:31, 14.29s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9 [D loss: 0.000000, acc.: 100.00%] [G loss: 191.106064]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 11/50 [02:35<09:05, 13.98s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 [D loss: 0.000000, acc.: 100.00%] [G loss: 200.216995]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 12/50 [02:49<08:58, 14.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11 [D loss: 0.000000, acc.: 100.00%] [G loss: 221.028473]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 13/50 [03:03<08:43, 14.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12 [D loss: 0.000000, acc.: 100.00%] [G loss: 215.032562]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 14/50 [03:19<08:49, 14.71s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13 [D loss: 0.000000, acc.: 100.00%] [G loss: 220.074982]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 15/50 [03:32<08:18, 14.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14 [D loss: 0.000007, acc.: 100.00%] [G loss: 233.851868]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 16/50 [03:46<08:03, 14.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15 [D loss: 0.000000, acc.: 100.00%] [G loss: 237.539993]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 17/50 [04:00<07:45, 14.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16 [D loss: 0.000000, acc.: 100.00%] [G loss: 226.648666]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 18/50 [04:14<07:31, 14.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17 [D loss: 0.000000, acc.: 100.00%] [G loss: 238.086273]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 19/50 [04:28<07:10, 13.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18 [D loss: 0.000000, acc.: 100.00%] [G loss: 241.048676]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 20/50 [04:41<06:49, 13.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19 [D loss: 0.000000, acc.: 100.00%] [G loss: 249.809219]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 21/50 [04:54<06:33, 13.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20 [D loss: 0.000000, acc.: 100.00%] [G loss: 262.289795]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 22/50 [05:11<06:45, 14.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21 [D loss: 0.000000, acc.: 100.00%] [G loss: 268.167114]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 23/50 [05:24<06:22, 14.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22 [D loss: 0.000000, acc.: 100.00%] [G loss: 274.003418]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 24/50 [05:39<06:11, 14.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23 [D loss: 0.000000, acc.: 100.00%] [G loss: 270.760864]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 25/50 [05:53<05:58, 14.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24 [D loss: 0.000000, acc.: 100.00%] [G loss: 284.079407]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 26/50 [06:08<05:43, 14.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25 [D loss: 0.000000, acc.: 100.00%] [G loss: 289.136902]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 27/50 [06:22<05:30, 14.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "26 [D loss: 0.000000, acc.: 100.00%] [G loss: 277.091644]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 28/50 [06:35<05:08, 14.04s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "27 [D loss: 0.000000, acc.: 100.00%] [G loss: 297.106140]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 29/50 [06:49<04:51, 13.87s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28 [D loss: 0.000000, acc.: 100.00%] [G loss: 292.717499]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 30/50 [07:02<04:35, 13.76s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "29 [D loss: 0.000000, acc.: 100.00%] [G loss: 297.127472]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 31/50 [07:16<04:19, 13.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30 [D loss: 0.000000, acc.: 100.00%] [G loss: 322.954773]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 32/50 [07:31<04:14, 14.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "31 [D loss: 0.000000, acc.: 100.00%] [G loss: 310.573669]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 33/50 [07:45<04:01, 14.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32 [D loss: 0.000000, acc.: 100.00%] [G loss: 288.561218]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 34/50 [08:00<03:47, 14.24s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "33 [D loss: 0.000000, acc.: 100.00%] [G loss: 308.145325]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 35/50 [08:13<03:31, 14.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "34 [D loss: 0.000000, acc.: 100.00%] [G loss: 321.515198]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|███████▏  | 36/50 [08:27<03:13, 13.80s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "35 [D loss: 0.000000, acc.: 100.00%] [G loss: 333.250549]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▍  | 37/50 [08:40<02:56, 13.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "36 [D loss: 0.000000, acc.: 100.00%] [G loss: 322.087280]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 38/50 [08:53<02:41, 13.45s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37 [D loss: 0.000000, acc.: 100.00%] [G loss: 336.577698]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 39/50 [09:06<02:26, 13.31s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38 [D loss: 0.000028, acc.: 100.00%] [G loss: 355.909821]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 40/50 [09:19<02:12, 13.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "39 [D loss: 0.000000, acc.: 100.00%] [G loss: 321.536743]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 41/50 [09:32<01:58, 13.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "40 [D loss: 0.000001, acc.: 100.00%] [G loss: 345.735809]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 42/50 [09:45<01:44, 13.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "41 [D loss: 0.000000, acc.: 100.00%] [G loss: 332.080841]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|████████▌ | 43/50 [09:58<01:31, 13.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "42 [D loss: 0.000000, acc.: 100.00%] [G loss: 360.490906]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 44/50 [10:11<01:18, 13.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "43 [D loss: 0.000000, acc.: 100.00%] [G loss: 346.502136]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 45/50 [10:24<01:05, 13.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "44 [D loss: 0.000000, acc.: 100.00%] [G loss: 350.296021]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 46/50 [10:37<00:52, 13.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "45 [D loss: 0.000000, acc.: 100.00%] [G loss: 368.256470]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████▍| 47/50 [10:50<00:39, 13.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "46 [D loss: 0.000000, acc.: 100.00%] [G loss: 340.930298]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 48/50 [11:03<00:26, 13.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47 [D loss: 0.000000, acc.: 100.00%] [G loss: 360.207825]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 49/50 [11:16<00:13, 13.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "48 [D loss: 0.000000, acc.: 100.00%] [G loss: 359.248108]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [11:29<00:00, 13.80s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "49 [D loss: 0.000000, acc.: 100.00%] [G loss: 385.927917]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Synthetic data generation: 100%|██████████| 338/338 [00:03<00:00, 106.60it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.080582</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.035735</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.133333</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.846154</td>\n",
       "      <td>0.6</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.017997</td>\n",
       "      <td>0.026124</td>\n",
       "      <td>0.096906</td>\n",
       "      <td>0.225</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.080582</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.035735</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.533333</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.846154</td>\n",
       "      <td>0.6</td>\n",
       "      <td>0.50</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.017997</td>\n",
       "      <td>0.026124</td>\n",
       "      <td>0.096906</td>\n",
       "      <td>0.275</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.080582</td>\n",
       "      <td>0.833333</td>\n",
       "      <td>0.035735</td>\n",
       "      <td>0.066667</td>\n",
       "      <td>0.933333</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.25</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.017997</td>\n",
       "      <td>0.026124</td>\n",
       "      <td>0.096906</td>\n",
       "      <td>0.150</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.080582</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.035735</td>\n",
       "      <td>0.466667</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.230769</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.017997</td>\n",
       "      <td>0.026124</td>\n",
       "      <td>0.096906</td>\n",
       "      <td>0.325</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.080582</td>\n",
       "      <td>0.833333</td>\n",
       "      <td>0.035735</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>0.533333</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.923077</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0.50</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.017997</td>\n",
       "      <td>0.026124</td>\n",
       "      <td>0.096906</td>\n",
       "      <td>0.475</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        age  workclass    fnlwgt  education  education-num  marital-status  \\\n",
       "0  0.080582   0.666667  0.035735   0.333333       0.133333        0.333333   \n",
       "1  0.080582   0.333333  0.035735   0.333333       0.533333        0.333333   \n",
       "2  0.080582   0.833333  0.035735   0.066667       0.933333        0.666667   \n",
       "3  0.080582   0.166667  0.035735   0.466667       0.800000        0.166667   \n",
       "4  0.080582   0.833333  0.035735   0.400000       0.533333        0.500000   \n",
       "\n",
       "   occupation  relationship  race  sex  capital-gain  capital-loss  \\\n",
       "0    0.846154           0.6  0.00  0.0      0.017997      0.026124   \n",
       "1    0.846154           0.6  0.50  1.0      0.017997      0.026124   \n",
       "2    0.000000           0.8  0.25  1.0      0.017997      0.026124   \n",
       "3    0.230769           0.4  0.00  0.0      0.017997      0.026124   \n",
       "4    0.923077           0.2  0.50  0.0      0.017997      0.026124   \n",
       "\n",
       "   hours-per-week  native-country  income  \n",
       "0        0.096906           0.225     0.0  \n",
       "1        0.096906           0.275     1.0  \n",
       "2        0.096906           0.150     1.0  \n",
       "3        0.096906           0.325     1.0  \n",
       "4        0.096906           0.475     1.0  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_data = train_vanilla_gan(dataset_train)\n",
    "synth_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'precision': 0.6971736204576043,\n",
       " 'recall': 0.6888297872340425,\n",
       " 'auroc': 0.39078586135895677,\n",
       " 'dp': 0.7117996939976097,\n",
       " 'ftu': 0.47}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_model(synth_data, dataset_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### WGAN-GP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 1/50 [00:05<04:48,  5.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0 | disc_loss: 0.15245389938354492 | gen_loss: -0.06989498436450958\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 2/50 [00:10<03:54,  4.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | disc_loss: 0.0009180586785078049 | gen_loss: -0.0033318146597594023\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 3/50 [00:14<03:36,  4.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2 | disc_loss: -0.004699191078543663 | gen_loss: 0.013895398937165737\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 4/50 [00:18<03:23,  4.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 3 | disc_loss: 0.4139225482940674 | gen_loss: 0.014603140763938427\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 5/50 [00:22<03:14,  4.33s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 4 | disc_loss: 1.4078881740570068 | gen_loss: 0.008926557376980782\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 6/50 [00:27<03:11,  4.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 5 | disc_loss: -0.012374036014080048 | gen_loss: 0.04716317355632782\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 7/50 [00:32<03:16,  4.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 6 | disc_loss: 1.080539345741272 | gen_loss: -0.0200481116771698\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 8/50 [00:36<03:09,  4.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 7 | disc_loss: 0.46246033906936646 | gen_loss: 0.014931818470358849\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 9/50 [00:40<03:02,  4.46s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 8 | disc_loss: 1.791371464729309 | gen_loss: 0.03856757655739784\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 10/50 [00:45<03:01,  4.54s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 9 | disc_loss: 0.133454367518425 | gen_loss: 0.07898350059986115\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 11/50 [00:49<02:54,  4.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 10 | disc_loss: -0.056419532746076584 | gen_loss: 0.029874488711357117\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 12/50 [00:54<02:48,  4.43s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 11 | disc_loss: 3.5388360023498535 | gen_loss: 0.047961827367544174\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 13/50 [00:59<02:49,  4.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 12 | disc_loss: -0.036514509469270706 | gen_loss: 0.0739879384636879\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 14/50 [01:03<02:44,  4.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 13 | disc_loss: -0.043887440115213394 | gen_loss: 0.05410666763782501\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 15/50 [01:10<02:59,  5.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 14 | disc_loss: -0.060908351093530655 | gen_loss: 0.06259451061487198\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 16/50 [01:16<03:09,  5.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 15 | disc_loss: -0.008532661944627762 | gen_loss: 0.06303340196609497\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 17/50 [01:21<02:53,  5.24s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 16 | disc_loss: -0.05893365666270256 | gen_loss: 0.05481185391545296\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 18/50 [01:25<02:37,  4.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 17 | disc_loss: -0.06201693415641785 | gen_loss: -0.008117406629025936\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 19/50 [01:29<02:25,  4.70s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 18 | disc_loss: -0.07819495350122452 | gen_loss: 0.014502013102173805\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 20/50 [01:33<02:15,  4.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 19 | disc_loss: -0.06925803422927856 | gen_loss: 0.02424549125134945\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 21/50 [01:38<02:12,  4.59s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 20 | disc_loss: 0.05352470278739929 | gen_loss: 0.04082176834344864\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 22/50 [01:42<02:08,  4.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 21 | disc_loss: 0.19452141225337982 | gen_loss: 0.017852380871772766\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 23/50 [01:47<02:04,  4.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 22 | disc_loss: 0.7697275876998901 | gen_loss: -0.018885279074311256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 24/50 [01:52<02:01,  4.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 23 | disc_loss: -0.07032063603401184 | gen_loss: 0.07426878809928894\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 25/50 [01:56<01:54,  4.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 24 | disc_loss: -0.07268654555082321 | gen_loss: 0.036055222153663635\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 26/50 [02:00<01:47,  4.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 25 | disc_loss: 4.07672643661499 | gen_loss: 0.02717735432088375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 27/50 [02:05<01:42,  4.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 26 | disc_loss: -0.08391407877206802 | gen_loss: 0.046227939426898956\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 28/50 [02:10<01:39,  4.53s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 27 | disc_loss: -0.02380479872226715 | gen_loss: 0.03740321844816208\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 29/50 [02:15<01:40,  4.81s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 28 | disc_loss: 0.025824055075645447 | gen_loss: -0.0008147454354912043\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 30/50 [02:21<01:40,  5.01s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 29 | disc_loss: 0.019960414618253708 | gen_loss: 0.0661730095744133\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 31/50 [02:26<01:36,  5.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 30 | disc_loss: -0.06524480134248734 | gen_loss: -0.016044048592448235\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 32/50 [02:31<01:29,  5.00s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 31 | disc_loss: -0.08846060186624527 | gen_loss: 0.03986186906695366\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 33/50 [02:35<01:23,  4.92s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 32 | disc_loss: -0.08835487812757492 | gen_loss: 0.03278104215860367\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 34/50 [02:40<01:17,  4.82s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 33 | disc_loss: -0.060576777905225754 | gen_loss: 0.03870237246155739\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 35/50 [02:44<01:10,  4.69s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 34 | disc_loss: 0.2938733696937561 | gen_loss: 0.04040674865245819\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|███████▏  | 36/50 [02:49<01:07,  4.80s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 35 | disc_loss: -0.07837940007448196 | gen_loss: 0.04813133180141449\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▍  | 37/50 [02:54<01:02,  4.80s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 36 | disc_loss: 0.9505723714828491 | gen_loss: -0.008810448460280895\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 38/50 [02:59<00:57,  4.82s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 37 | disc_loss: -0.0900762602686882 | gen_loss: 0.04091276228427887\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 39/50 [03:03<00:51,  4.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 38 | disc_loss: 0.31794506311416626 | gen_loss: 0.0437326580286026\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 40/50 [03:07<00:44,  4.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 39 | disc_loss: -0.09149685502052307 | gen_loss: 0.02371656522154808\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 41/50 [03:11<00:39,  4.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 40 | disc_loss: 0.10043340921401978 | gen_loss: 0.1226792186498642\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 42/50 [03:15<00:34,  4.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 41 | disc_loss: -0.05715509504079819 | gen_loss: 0.04945666715502739\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|████████▌ | 43/50 [03:20<00:29,  4.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 42 | disc_loss: -0.08194859325885773 | gen_loss: 0.015959346666932106\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 44/50 [03:24<00:25,  4.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 43 | disc_loss: -0.05691428855061531 | gen_loss: 0.04214879125356674\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 45/50 [03:29<00:22,  4.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 44 | disc_loss: -0.05352303385734558 | gen_loss: 0.04443015158176422\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 46/50 [03:34<00:17,  4.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 45 | disc_loss: -0.07117447257041931 | gen_loss: 0.08183664828538895\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████▍| 47/50 [03:38<00:13,  4.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 46 | disc_loss: -0.0995122417807579 | gen_loss: 0.05079025402665138\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 48/50 [03:42<00:08,  4.33s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 47 | disc_loss: 0.03659405559301376 | gen_loss: 0.04000601917505264\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 49/50 [03:46<00:04,  4.28s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 48 | disc_loss: -0.08710119128227234 | gen_loss: 0.041870132088661194\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [03:51<00:00,  4.62s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 49 | disc_loss: -0.10412894934415817 | gen_loss: 0.05354977026581764\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Synthetic data generation: 100%|██████████| 87/87 [00:01<00:00, 69.38it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.415160</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.136200</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.600000</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.307692</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.000021</td>\n",
       "      <td>0.041940</td>\n",
       "      <td>0.407321</td>\n",
       "      <td>0.375</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.452994</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.128935</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.846154</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0.75</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.026776</td>\n",
       "      <td>0.029268</td>\n",
       "      <td>0.403623</td>\n",
       "      <td>0.100</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.500516</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.204948</td>\n",
       "      <td>0.066667</td>\n",
       "      <td>0.933333</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-0.028636</td>\n",
       "      <td>0.001469</td>\n",
       "      <td>0.480637</td>\n",
       "      <td>0.400</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.208567</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.118594</td>\n",
       "      <td>0.600000</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.833333</td>\n",
       "      <td>0.307692</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.027901</td>\n",
       "      <td>0.015777</td>\n",
       "      <td>0.280495</td>\n",
       "      <td>0.475</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.430072</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.126853</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.466667</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.384615</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.030149</td>\n",
       "      <td>0.020552</td>\n",
       "      <td>0.380429</td>\n",
       "      <td>0.650</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        age  workclass    fnlwgt  education  education-num  marital-status  \\\n",
       "0  0.415160   0.333333  0.136200   0.333333       0.600000        0.166667   \n",
       "1  0.452994   0.666667  0.128935   0.400000       0.800000        0.500000   \n",
       "2  0.500516   0.166667  0.204948   0.066667       0.933333        0.666667   \n",
       "3  0.208567   0.500000  0.118594   0.600000       0.800000        0.833333   \n",
       "4  0.430072   0.166667  0.126853   0.666667       0.466667        0.000000   \n",
       "\n",
       "   occupation  relationship  race  sex  capital-gain  capital-loss  \\\n",
       "0    0.307692           1.0  1.00  1.0      0.000021      0.041940   \n",
       "1    0.846154           0.2  0.75  0.0     -0.026776      0.029268   \n",
       "2    1.000000           0.0  0.00  1.0     -0.028636      0.001469   \n",
       "3    0.307692           0.0  1.00  0.0     -0.027901      0.015777   \n",
       "4    0.384615           0.0  0.00  0.0     -0.030149      0.020552   \n",
       "\n",
       "   hours-per-week  native-country  income  \n",
       "0        0.407321           0.375     0.0  \n",
       "1        0.403623           0.100     0.0  \n",
       "2        0.480637           0.400     0.0  \n",
       "3        0.280495           0.475     0.0  \n",
       "4        0.380429           0.650     1.0  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_data = train_wgan_gp(dataset_train)\n",
    "synth_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'precision': 0.7207890743550834,\n",
       " 'recall': 0.3158244680851064,\n",
       " 'auroc': 0.47242836307481123,\n",
       " 'dp': 0.10060642419000054,\n",
       " 'ftu': 0.14}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_model(synth_data, dataset_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### FairGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:201: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
      "cache/adult.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-02-04 17:11:43.219249: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:196] None of the MLIR optimization passes are enabled (registered 0 passes)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pretrain_Epoch:0, trainLoss:0.018377, validLoss:0.007400, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:1, trainLoss:0.007770, validLoss:0.007489, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:2, trainLoss:0.007857, validLoss:0.007542, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:3, trainLoss:0.007895, validLoss:0.007547, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:4, trainLoss:0.007891, validLoss:0.006571, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:5, trainLoss:0.002940, validLoss:0.002490, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:6, trainLoss:0.002588, validLoss:0.002484, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:7, trainLoss:0.002579, validLoss:0.002462, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:8, trainLoss:0.002581, validLoss:0.002461, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:9, trainLoss:0.002567, validLoss:0.002443, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:10, trainLoss:0.002573, validLoss:0.002476, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:11, trainLoss:0.002570, validLoss:0.002426, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:12, trainLoss:0.002569, validLoss:0.002483, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:13, trainLoss:0.002560, validLoss:0.002377, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:14, trainLoss:0.002559, validLoss:0.002425, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:15, trainLoss:0.002530, validLoss:0.002426, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:16, trainLoss:0.002519, validLoss:0.002468, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:17, trainLoss:0.002530, validLoss:0.002368, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:18, trainLoss:0.002519, validLoss:0.002354, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:19, trainLoss:0.002518, validLoss:0.002378, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:20, trainLoss:0.002515, validLoss:0.002413, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:21, trainLoss:0.002518, validLoss:0.002381, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:22, trainLoss:0.002514, validLoss:0.002395, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:23, trainLoss:0.002518, validLoss:0.002355, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:24, trainLoss:0.002512, validLoss:0.002376, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:25, trainLoss:0.002508, validLoss:0.002399, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:26, trainLoss:0.002509, validLoss:0.002375, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:27, trainLoss:0.002510, validLoss:0.002362, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:28, trainLoss:0.002503, validLoss:0.002424, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:29, trainLoss:0.002510, validLoss:0.002372, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:30, trainLoss:0.002491, validLoss:0.002377, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:31, trainLoss:0.002478, validLoss:0.002377, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:32, trainLoss:0.002478, validLoss:0.002295, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:33, trainLoss:0.002462, validLoss:0.002328, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:34, trainLoss:0.002460, validLoss:0.002335, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:35, trainLoss:0.002391, validLoss:0.002229, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:36, trainLoss:0.002308, validLoss:0.002176, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:37, trainLoss:0.002276, validLoss:0.002142, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:38, trainLoss:0.002283, validLoss:0.002074, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:39, trainLoss:0.002275, validLoss:0.002157, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:40, trainLoss:0.002270, validLoss:0.002135, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:41, trainLoss:0.002265, validLoss:0.002140, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:42, trainLoss:0.002273, validLoss:0.002175, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:43, trainLoss:0.002271, validLoss:0.002114, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:44, trainLoss:0.002269, validLoss:0.002204, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:45, trainLoss:0.002263, validLoss:0.002157, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:46, trainLoss:0.002267, validLoss:0.002175, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:47, trainLoss:0.002270, validLoss:0.002125, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:48, trainLoss:0.002250, validLoss:0.002216, validReverseLoss:0.000000\n",
      "Pretrain_Epoch:49, trainLoss:0.002275, validLoss:0.002140, validReverseLoss:0.000000\n",
      "Epoch:0, d_loss:0.357153, g_loss:3.251941, d accuracy:1.000000, d AUC:1.000000, g accuracy:1.000000, rdf 0.000000\n",
      "Epoch:1, d_loss:0.085009, g_loss:4.590782, d accuracy:0.998977, d AUC:1.000000, g accuracy:1.000000, rdf 0.000000\n",
      "Epoch:2, d_loss:0.918341, g_loss:2.717331, d accuracy:0.854886, d AUC:0.942447, g accuracy:0.748864, rdf 0.000000\n",
      "Epoch:3, d_loss:0.727511, g_loss:2.209599, d accuracy:0.810568, d AUC:0.905650, g accuracy:0.829545, rdf 0.000000\n",
      "Epoch:4, d_loss:0.803367, g_loss:2.215892, d accuracy:0.811705, d AUC:0.913673, g accuracy:0.732500, rdf 0.000000\n",
      "Epoch:5, d_loss:0.956857, g_loss:1.956211, d accuracy:0.790682, d AUC:0.896345, g accuracy:0.703409, rdf 0.000000\n",
      "Epoch:6, d_loss:1.016619, g_loss:1.816614, d accuracy:0.613636, d AUC:0.670148, g accuracy:0.563864, rdf 0.000000\n",
      "Epoch:7, d_loss:0.944658, g_loss:1.895438, d accuracy:0.801364, d AUC:0.893936, g accuracy:0.759773, rdf 0.000000\n",
      "Epoch:8, d_loss:0.945125, g_loss:1.866466, d accuracy:0.704545, d AUC:0.791189, g accuracy:0.646591, rdf 0.000000\n",
      "Epoch:9, d_loss:0.955225, g_loss:1.820898, d accuracy:0.745114, d AUC:0.833359, g accuracy:0.703409, rdf 0.000000\n",
      "INFO:tensorflow:cache/fairgan_unfair is not in all_model_checkpoint_paths. Manually adding it.\n",
      "INFO:tensorflow:cache/fairgan is not in all_model_checkpoint_paths. Manually adding it.\n",
      "cache/fairgan\n",
      "INFO:tensorflow:Restoring parameters from cache/fairgan\n",
      "burning in\n",
      "generating\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/sklearn/neural_network/_multilayer_perceptron.py:692: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.\n",
      "  warnings.warn(\n",
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/sklearn/base.py:450: UserWarning: X does not have valid feature names, but MLPClassifier was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.313098</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.447853</td>\n",
       "      <td>0.307014</td>\n",
       "      <td>0.633629</td>\n",
       "      <td>0.370685</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.414991</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.369372</td>\n",
       "      <td>0.526167</td>\n",
       "      <td>1.048100</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.162797</td>\n",
       "      <td>0.153582</td>\n",
       "      <td>0.032848</td>\n",
       "      <td>0.780304</td>\n",
       "      <td>0.193229</td>\n",
       "      <td>0.461950</td>\n",
       "      <td>0.398713</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.813397</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.385241</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.102639</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.208462</td>\n",
       "      <td>0.239274</td>\n",
       "      <td>0.839121</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.545892</td>\n",
       "      <td>0.406260</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.180958</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.410443</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.149550</td>\n",
       "      <td>0.366827</td>\n",
       "      <td>0.022638</td>\n",
       "      <td>0.792689</td>\n",
       "      <td>0.087440</td>\n",
       "      <td>0.244400</td>\n",
       "      <td>0.490716</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.860048</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.332419</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.363599</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.063530</td>\n",
       "      <td>0.319478</td>\n",
       "      <td>0.718697</td>\n",
       "      <td>0.598897</td>\n",
       "      <td>0.166194</td>\n",
       "      <td>0.573069</td>\n",
       "      <td>0.628330</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.016383</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.070529</td>\n",
       "      <td>0.514503</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.772208</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        age  workclass    fnlwgt  education  education-num  marital-status  \\\n",
       "0  0.000000   0.313098  0.000000   0.447853       0.307014        0.633629   \n",
       "1  0.162797   0.153582  0.032848   0.780304       0.193229        0.461950   \n",
       "2  0.000000   0.208462  0.239274   0.839121       0.000000        0.545892   \n",
       "3  0.149550   0.366827  0.022638   0.792689       0.087440        0.244400   \n",
       "4  0.063530   0.319478  0.718697   0.598897       0.166194        0.573069   \n",
       "\n",
       "   occupation  relationship      race  sex  capital-gain  capital-loss  \\\n",
       "0    0.370685           0.0  0.414991  0.0      0.000000      0.369372   \n",
       "1    0.398713           0.0  0.813397  0.0      0.000000      0.385241   \n",
       "2    0.406260           0.0  1.180958  0.0      0.000000      0.410443   \n",
       "3    0.490716           0.0  0.860048  0.0      0.000000      0.332419   \n",
       "4    0.628330           0.0  1.016383  0.0      0.070529      0.514503   \n",
       "\n",
       "   hours-per-week  native-country  income  \n",
       "0        0.526167        1.048100     1.0  \n",
       "1        0.000000        1.102639     1.0  \n",
       "2        0.000000        0.000000     1.0  \n",
       "3        0.000000        0.363599     1.0  \n",
       "4        0.000000        0.772208     1.0  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_data = train_fairgan(dataset_train)\n",
    "synth_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'precision': 0.773604590505999,\n",
       " 'recall': 0.9860372340425532,\n",
       " 'auroc': 0.5555186170212766,\n",
       " 'dp': 0.01400030693248877,\n",
       " 'ftu': 0.030999999999999917}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_model(synth_data, dataset_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DECAF"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### DECAF-ND"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:175: LightningDeprecationWarning: DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\n",
      "  rank_zero_deprecation(\"DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\")\n",
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:170: LightningDeprecationWarning: DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\n",
      "  rank_zero_deprecation(\"DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\")\n",
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:120: UserWarning: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.\n",
      "  rank_zero_warn(\"You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.\")\n",
      "\n",
      "  | Name          | Type             | Params\n",
      "---------------------------------------------------\n",
      "0 | generator     | Generator_causal | 134 K \n",
      "1 | discriminator | Discriminator    | 43.6 K\n",
      "---------------------------------------------------\n",
      "178 K     Trainable params\n",
      "225       Non-trainable params\n",
      "178 K     Total params\n",
      "0.713     Total estimated model params size (MB)\n",
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:631: UserWarning: Checkpoint directory /Users/*****/Projects/UvA/UvA_FACT2022/checkpoints exists and is not empty.\n",
      "  rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n",
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialised adjacency matrix as parsed:\n",
      " Parameter containing:\n",
      "tensor([[0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])\n",
      "Epoch 49: 100%|██████████| 676/676 [00:11<00:00, 59.96it/s, loss=-0.113] \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3675</th>\n",
       "      <td>0.380142</td>\n",
       "      <td>0.000388</td>\n",
       "      <td>0.069295</td>\n",
       "      <td>0.193782</td>\n",
       "      <td>0.602936</td>\n",
       "      <td>2.099937e-01</td>\n",
       "      <td>0.196843</td>\n",
       "      <td>0.239727</td>\n",
       "      <td>6.095866e-35</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.093916e-05</td>\n",
       "      <td>1.270744e-06</td>\n",
       "      <td>0.189316</td>\n",
       "      <td>9.965676e-11</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5946</th>\n",
       "      <td>0.607431</td>\n",
       "      <td>0.615952</td>\n",
       "      <td>0.023988</td>\n",
       "      <td>0.110090</td>\n",
       "      <td>0.606911</td>\n",
       "      <td>3.790743e-14</td>\n",
       "      <td>0.860790</td>\n",
       "      <td>0.708142</td>\n",
       "      <td>3.790271e-26</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.319775e-06</td>\n",
       "      <td>1.160979e-05</td>\n",
       "      <td>0.345060</td>\n",
       "      <td>2.871152e-24</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16075</th>\n",
       "      <td>0.572381</td>\n",
       "      <td>0.079426</td>\n",
       "      <td>0.030153</td>\n",
       "      <td>0.208819</td>\n",
       "      <td>0.960146</td>\n",
       "      <td>4.513823e-08</td>\n",
       "      <td>0.590635</td>\n",
       "      <td>0.044224</td>\n",
       "      <td>9.959477e-27</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.349094e-08</td>\n",
       "      <td>2.085748e-02</td>\n",
       "      <td>0.393091</td>\n",
       "      <td>2.918873e-11</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23510</th>\n",
       "      <td>0.084389</td>\n",
       "      <td>0.000023</td>\n",
       "      <td>0.104752</td>\n",
       "      <td>0.065256</td>\n",
       "      <td>0.816472</td>\n",
       "      <td>4.859577e-01</td>\n",
       "      <td>0.744997</td>\n",
       "      <td>0.953279</td>\n",
       "      <td>2.737193e-27</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.918518e-07</td>\n",
       "      <td>3.674373e-18</td>\n",
       "      <td>0.205421</td>\n",
       "      <td>2.281011e-12</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24885</th>\n",
       "      <td>0.234341</td>\n",
       "      <td>0.000942</td>\n",
       "      <td>0.112664</td>\n",
       "      <td>0.035219</td>\n",
       "      <td>0.603397</td>\n",
       "      <td>4.412857e-03</td>\n",
       "      <td>0.835228</td>\n",
       "      <td>0.391127</td>\n",
       "      <td>7.269054e-32</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.768562e-07</td>\n",
       "      <td>8.484753e-16</td>\n",
       "      <td>0.448975</td>\n",
       "      <td>5.465361e-09</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            age  workclass    fnlwgt  education  education-num  \\\n",
       "3675   0.380142   0.000388  0.069295   0.193782       0.602936   \n",
       "5946   0.607431   0.615952  0.023988   0.110090       0.606911   \n",
       "16075  0.572381   0.079426  0.030153   0.208819       0.960146   \n",
       "23510  0.084389   0.000023  0.104752   0.065256       0.816472   \n",
       "24885  0.234341   0.000942  0.112664   0.035219       0.603397   \n",
       "\n",
       "       marital-status  occupation  relationship          race  sex  \\\n",
       "3675     2.099937e-01    0.196843      0.239727  6.095866e-35  1.0   \n",
       "5946     3.790743e-14    0.860790      0.708142  3.790271e-26  1.0   \n",
       "16075    4.513823e-08    0.590635      0.044224  9.959477e-27  0.0   \n",
       "23510    4.859577e-01    0.744997      0.953279  2.737193e-27  0.0   \n",
       "24885    4.412857e-03    0.835228      0.391127  7.269054e-32  1.0   \n",
       "\n",
       "       capital-gain  capital-loss  hours-per-week  native-country  income  \n",
       "3675   2.093916e-05  1.270744e-06        0.189316    9.965676e-11     0.0  \n",
       "5946   1.319775e-06  1.160979e-05        0.345060    2.871152e-24     0.0  \n",
       "16075  7.349094e-08  2.085748e-02        0.393091    2.918873e-11     0.0  \n",
       "23510  1.918518e-07  3.674373e-18        0.205421    2.281011e-12     1.0  \n",
       "24885  3.768562e-07  8.484753e-16        0.448975    5.465361e-09     1.0  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_data = train_decaf(dataset_train, dag_seed)\n",
    "synth_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'precision': 0.8732394366197183,\n",
       " 'recall': 0.7832446808510638,\n",
       " 'auroc': 0.7192433081674674,\n",
       " 'dp': 0.35986203,\n",
       " 'ftu': 0.19800001}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_model(synth_data, dataset_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### DECAF-FTU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:175: LightningDeprecationWarning: DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\n",
      "  rank_zero_deprecation(\"DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\")\n",
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:170: LightningDeprecationWarning: DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\n",
      "  rank_zero_deprecation(\"DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialised adjacency matrix as parsed:\n",
      " Parameter containing:\n",
      "tensor([[0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3675</th>\n",
       "      <td>0.569362</td>\n",
       "      <td>2.306030e-02</td>\n",
       "      <td>0.040449</td>\n",
       "      <td>0.031290</td>\n",
       "      <td>0.460504</td>\n",
       "      <td>4.049787e-01</td>\n",
       "      <td>0.687541</td>\n",
       "      <td>0.959119</td>\n",
       "      <td>1.518811e-13</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.004490e-06</td>\n",
       "      <td>3.206931e-10</td>\n",
       "      <td>0.434427</td>\n",
       "      <td>3.572187e-23</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5946</th>\n",
       "      <td>0.377764</td>\n",
       "      <td>2.334096e-05</td>\n",
       "      <td>0.122391</td>\n",
       "      <td>0.728067</td>\n",
       "      <td>0.882204</td>\n",
       "      <td>1.850907e-04</td>\n",
       "      <td>0.202852</td>\n",
       "      <td>0.380964</td>\n",
       "      <td>1.583383e-16</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.184416e-06</td>\n",
       "      <td>7.724157e-09</td>\n",
       "      <td>0.418983</td>\n",
       "      <td>5.500095e-15</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16075</th>\n",
       "      <td>0.350978</td>\n",
       "      <td>6.441578e-08</td>\n",
       "      <td>0.089029</td>\n",
       "      <td>0.050665</td>\n",
       "      <td>0.487395</td>\n",
       "      <td>5.067925e-01</td>\n",
       "      <td>0.511657</td>\n",
       "      <td>0.591335</td>\n",
       "      <td>6.622320e-15</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.145748e-07</td>\n",
       "      <td>1.514143e-08</td>\n",
       "      <td>0.405432</td>\n",
       "      <td>2.919201e-19</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23510</th>\n",
       "      <td>0.471133</td>\n",
       "      <td>8.222158e-06</td>\n",
       "      <td>0.122726</td>\n",
       "      <td>0.043518</td>\n",
       "      <td>0.468221</td>\n",
       "      <td>3.292640e-09</td>\n",
       "      <td>0.566924</td>\n",
       "      <td>0.369867</td>\n",
       "      <td>7.818704e-37</td>\n",
       "      <td>1.0</td>\n",
       "      <td>9.189102e-10</td>\n",
       "      <td>1.149758e-10</td>\n",
       "      <td>0.435682</td>\n",
       "      <td>8.755472e-09</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24885</th>\n",
       "      <td>0.608282</td>\n",
       "      <td>1.187612e-02</td>\n",
       "      <td>0.076645</td>\n",
       "      <td>0.269258</td>\n",
       "      <td>0.461104</td>\n",
       "      <td>2.710335e-07</td>\n",
       "      <td>0.516129</td>\n",
       "      <td>0.372954</td>\n",
       "      <td>5.737523e-24</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.757823e-09</td>\n",
       "      <td>8.570197e-11</td>\n",
       "      <td>0.759713</td>\n",
       "      <td>1.737677e-05</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            age     workclass    fnlwgt  education  education-num  \\\n",
       "3675   0.569362  2.306030e-02  0.040449   0.031290       0.460504   \n",
       "5946   0.377764  2.334096e-05  0.122391   0.728067       0.882204   \n",
       "16075  0.350978  6.441578e-08  0.089029   0.050665       0.487395   \n",
       "23510  0.471133  8.222158e-06  0.122726   0.043518       0.468221   \n",
       "24885  0.608282  1.187612e-02  0.076645   0.269258       0.461104   \n",
       "\n",
       "       marital-status  occupation  relationship          race  sex  \\\n",
       "3675     4.049787e-01    0.687541      0.959119  1.518811e-13  0.0   \n",
       "5946     1.850907e-04    0.202852      0.380964  1.583383e-16  1.0   \n",
       "16075    5.067925e-01    0.511657      0.591335  6.622320e-15  1.0   \n",
       "23510    3.292640e-09    0.566924      0.369867  7.818704e-37  1.0   \n",
       "24885    2.710335e-07    0.516129      0.372954  5.737523e-24  1.0   \n",
       "\n",
       "       capital-gain  capital-loss  hours-per-week  native-country  income  \n",
       "3675   1.004490e-06  3.206931e-10        0.434427    3.572187e-23     1.0  \n",
       "5946   1.184416e-06  7.724157e-09        0.418983    5.500095e-15     1.0  \n",
       "16075  1.145748e-07  1.514143e-08        0.405432    2.919201e-19     1.0  \n",
       "23510  9.189102e-10  1.149758e-10        0.435682    8.755472e-09     0.0  \n",
       "24885  2.757823e-09  8.570197e-11        0.759713    1.737677e-05     1.0  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_data = train_decaf(dataset_train, dag_seed, biased_edges=bias_dict_ftu)\n",
    "synth_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'precision': 0.8654259126700071,\n",
       " 'recall': 0.8038563829787234,\n",
       " 'auroc': 0.7124120624571036,\n",
       " 'dp': 0.26447117,\n",
       " 'ftu': 0.009500027}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_model(synth_data, dataset_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### DECAF-CF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:175: LightningDeprecationWarning: DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\n",
      "  rank_zero_deprecation(\"DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\")\n",
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:170: LightningDeprecationWarning: DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\n",
      "  rank_zero_deprecation(\"DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialised adjacency matrix as parsed:\n",
      " Parameter containing:\n",
      "tensor([[0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3675</th>\n",
       "      <td>0.126801</td>\n",
       "      <td>2.034548e-07</td>\n",
       "      <td>0.097748</td>\n",
       "      <td>0.445107</td>\n",
       "      <td>0.504234</td>\n",
       "      <td>5.187284e-02</td>\n",
       "      <td>0.555084</td>\n",
       "      <td>0.207682</td>\n",
       "      <td>1.940717e-26</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.969487e-07</td>\n",
       "      <td>8.796071e-22</td>\n",
       "      <td>0.438005</td>\n",
       "      <td>9.395826e-04</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5946</th>\n",
       "      <td>0.481099</td>\n",
       "      <td>2.218284e-02</td>\n",
       "      <td>0.096830</td>\n",
       "      <td>0.036081</td>\n",
       "      <td>0.593112</td>\n",
       "      <td>2.401668e-04</td>\n",
       "      <td>0.822327</td>\n",
       "      <td>0.372022</td>\n",
       "      <td>3.589954e-33</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.816202e-05</td>\n",
       "      <td>3.667631e-02</td>\n",
       "      <td>0.358581</td>\n",
       "      <td>3.940624e-20</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16075</th>\n",
       "      <td>0.066505</td>\n",
       "      <td>7.122482e-01</td>\n",
       "      <td>0.024837</td>\n",
       "      <td>0.168736</td>\n",
       "      <td>0.601697</td>\n",
       "      <td>3.587708e-01</td>\n",
       "      <td>0.080743</td>\n",
       "      <td>0.418794</td>\n",
       "      <td>2.489512e-37</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.710960e-05</td>\n",
       "      <td>1.951058e-18</td>\n",
       "      <td>0.418606</td>\n",
       "      <td>1.317396e-09</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23510</th>\n",
       "      <td>0.091973</td>\n",
       "      <td>1.355053e-01</td>\n",
       "      <td>0.029914</td>\n",
       "      <td>0.085077</td>\n",
       "      <td>0.956460</td>\n",
       "      <td>5.039743e-07</td>\n",
       "      <td>0.500809</td>\n",
       "      <td>0.375713</td>\n",
       "      <td>1.503318e-26</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4.706238e-07</td>\n",
       "      <td>1.705442e-03</td>\n",
       "      <td>0.390647</td>\n",
       "      <td>9.432316e-13</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24885</th>\n",
       "      <td>0.294165</td>\n",
       "      <td>9.693924e-08</td>\n",
       "      <td>0.049272</td>\n",
       "      <td>0.744953</td>\n",
       "      <td>0.583081</td>\n",
       "      <td>7.366261e-11</td>\n",
       "      <td>0.326614</td>\n",
       "      <td>0.381308</td>\n",
       "      <td>2.713189e-27</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.116689e-09</td>\n",
       "      <td>1.143300e-22</td>\n",
       "      <td>0.427482</td>\n",
       "      <td>1.598419e-06</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            age     workclass    fnlwgt  education  education-num  \\\n",
       "3675   0.126801  2.034548e-07  0.097748   0.445107       0.504234   \n",
       "5946   0.481099  2.218284e-02  0.096830   0.036081       0.593112   \n",
       "16075  0.066505  7.122482e-01  0.024837   0.168736       0.601697   \n",
       "23510  0.091973  1.355053e-01  0.029914   0.085077       0.956460   \n",
       "24885  0.294165  9.693924e-08  0.049272   0.744953       0.583081   \n",
       "\n",
       "       marital-status  occupation  relationship          race  sex  \\\n",
       "3675     5.187284e-02    0.555084      0.207682  1.940717e-26  0.0   \n",
       "5946     2.401668e-04    0.822327      0.372022  3.589954e-33  1.0   \n",
       "16075    3.587708e-01    0.080743      0.418794  2.489512e-37  1.0   \n",
       "23510    5.039743e-07    0.500809      0.375713  1.503318e-26  1.0   \n",
       "24885    7.366261e-11    0.326614      0.381308  2.713189e-27  1.0   \n",
       "\n",
       "       capital-gain  capital-loss  hours-per-week  native-country  income  \n",
       "3675   7.969487e-07  8.796071e-22        0.438005    9.395826e-04     1.0  \n",
       "5946   3.816202e-05  3.667631e-02        0.358581    3.940624e-20     1.0  \n",
       "16075  2.710960e-05  1.951058e-18        0.418606    1.317396e-09     1.0  \n",
       "23510  4.706238e-07  1.705442e-03        0.390647    9.432316e-13     0.0  \n",
       "24885  1.116689e-09  1.143300e-22        0.427482    1.598419e-06     1.0  "
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_data = train_decaf(dataset_train, dag_seed, biased_edges=bias_dict_cf)\n",
    "synth_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'precision': 0.7692307692307693,\n",
       " 'recall': 0.9441489361702128,\n",
       " 'auroc': 0.5426389842141386,\n",
       " 'dp': 0.0018555522,\n",
       " 'ftu': 0.07699996}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_model(synth_data, dataset_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### DECAF-DP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:175: LightningDeprecationWarning: DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\n",
      "  rank_zero_deprecation(\"DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\")\n",
      "/Users/*****/Projects/UvA/UvA_FACT2022/env/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:170: LightningDeprecationWarning: DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\n",
      "  rank_zero_deprecation(\"DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialised adjacency matrix as parsed:\n",
      " Parameter containing:\n",
      "tensor([[0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
      "        [0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3675</th>\n",
       "      <td>0.602783</td>\n",
       "      <td>1.143769e-07</td>\n",
       "      <td>0.040871</td>\n",
       "      <td>0.034705</td>\n",
       "      <td>0.498522</td>\n",
       "      <td>3.042221e-07</td>\n",
       "      <td>0.265838</td>\n",
       "      <td>0.692303</td>\n",
       "      <td>1.167078e-14</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.071873e-06</td>\n",
       "      <td>1.343675e-04</td>\n",
       "      <td>0.617038</td>\n",
       "      <td>2.140054e-13</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5946</th>\n",
       "      <td>0.596992</td>\n",
       "      <td>2.013029e-03</td>\n",
       "      <td>0.143913</td>\n",
       "      <td>0.794298</td>\n",
       "      <td>0.455123</td>\n",
       "      <td>1.978627e-02</td>\n",
       "      <td>0.727317</td>\n",
       "      <td>0.614474</td>\n",
       "      <td>1.705038e-13</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.786253e-07</td>\n",
       "      <td>3.505098e-16</td>\n",
       "      <td>0.353130</td>\n",
       "      <td>8.698998e-23</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16075</th>\n",
       "      <td>0.112791</td>\n",
       "      <td>5.396063e-07</td>\n",
       "      <td>0.082294</td>\n",
       "      <td>0.040499</td>\n",
       "      <td>0.596541</td>\n",
       "      <td>1.945614e-06</td>\n",
       "      <td>0.540141</td>\n",
       "      <td>0.381207</td>\n",
       "      <td>8.063686e-32</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.033979e-08</td>\n",
       "      <td>7.078151e-10</td>\n",
       "      <td>0.426859</td>\n",
       "      <td>2.057964e-24</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23510</th>\n",
       "      <td>0.313090</td>\n",
       "      <td>4.725884e-05</td>\n",
       "      <td>0.115046</td>\n",
       "      <td>0.111139</td>\n",
       "      <td>0.596356</td>\n",
       "      <td>8.643508e-03</td>\n",
       "      <td>0.142311</td>\n",
       "      <td>0.789508</td>\n",
       "      <td>2.858072e-05</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8.492114e-10</td>\n",
       "      <td>7.012393e-07</td>\n",
       "      <td>0.428549</td>\n",
       "      <td>1.228493e-14</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24885</th>\n",
       "      <td>0.495689</td>\n",
       "      <td>2.521220e-06</td>\n",
       "      <td>0.031824</td>\n",
       "      <td>0.286899</td>\n",
       "      <td>0.453205</td>\n",
       "      <td>3.095134e-01</td>\n",
       "      <td>0.571484</td>\n",
       "      <td>0.978397</td>\n",
       "      <td>1.969127e-21</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.708674e-09</td>\n",
       "      <td>2.719935e-20</td>\n",
       "      <td>0.409587</td>\n",
       "      <td>2.226772e-21</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            age     workclass    fnlwgt  education  education-num  \\\n",
       "3675   0.602783  1.143769e-07  0.040871   0.034705       0.498522   \n",
       "5946   0.596992  2.013029e-03  0.143913   0.794298       0.455123   \n",
       "16075  0.112791  5.396063e-07  0.082294   0.040499       0.596541   \n",
       "23510  0.313090  4.725884e-05  0.115046   0.111139       0.596356   \n",
       "24885  0.495689  2.521220e-06  0.031824   0.286899       0.453205   \n",
       "\n",
       "       marital-status  occupation  relationship          race  sex  \\\n",
       "3675     3.042221e-07    0.265838      0.692303  1.167078e-14  1.0   \n",
       "5946     1.978627e-02    0.727317      0.614474  1.705038e-13  1.0   \n",
       "16075    1.945614e-06    0.540141      0.381207  8.063686e-32  1.0   \n",
       "23510    8.643508e-03    0.142311      0.789508  2.858072e-05  0.0   \n",
       "24885    3.095134e-01    0.571484      0.978397  1.969127e-21  0.0   \n",
       "\n",
       "       capital-gain  capital-loss  hours-per-week  native-country  income  \n",
       "3675   2.071873e-06  1.343675e-04        0.617038    2.140054e-13     1.0  \n",
       "5946   3.786253e-07  3.505098e-16        0.353130    8.698998e-23     1.0  \n",
       "16075  1.033979e-08  7.078151e-10        0.426859    2.057964e-24     0.0  \n",
       "23510  8.492114e-10  7.012393e-07        0.428549    1.228493e-14     0.0  \n",
       "24885  2.708674e-09  2.719935e-20        0.409587    2.226772e-21     0.0  "
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_data = train_decaf(dataset_train, dag_seed, biased_edges=bias_dict_dp)\n",
    "synth_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'precision': 0.7696105320899616,\n",
       " 'recall': 0.932845744680851,\n",
       " 'auroc': 0.543035775566232,\n",
       " 'dp': 0.045112073,\n",
       " 'ftu': 0.021500051}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_model(synth_data, dataset_test)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "13e45f1d2a71a5202f039caac1b9c752d2af996637e618c185f277d12109b1e8"
  },
  "kernelspec": {
   "display_name": "Python 3.9.8 64-bit ('env': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
