{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import defaultdict\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.utils.data\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "from datasets.crime import CrimeDataset\n",
    "from generative.gmm import GMM\n",
    "from real_nvp_encoder import FlowEncoder"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Use seaborn's default theme for every matplotlib figure drawn below.\n",
    "sns.set_theme()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Parent of the absolute current working directory; plots_dir and model_dir\n",
    "# below are derived from it. Assumes the kernel is started one level below\n",
    "# the project root - TODO confirm.\n",
    "PROJECT_ROOT = Path('.').absolute().parent"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Run configuration. The commented-out entries are not read anywhere in this\n",
    "# notebook; presumably they mirror options of the training script - TODO confirm.\n",
    "dataset = 'crime'\n",
    "# gamma selects the checkpoint directory (gamma_<value>) and the t-SNE seed.\n",
    "gamma = 1.0\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# alpha sets the margin used by normalize()/denormalize() and by the clamp\n",
    "# applied before logit().\n",
    "alpha = 0.05\n",
    "# lr = 1e-2\n",
    "# weight_decay = 1e-4\n",
    "# kl_start = 0\n",
    "# kl_end = 50\n",
    "# protected_att = None\n",
    "# n_blocks is forwarded to FlowEncoder.\n",
    "n_blocks = 4\n",
    "batch_size = 128\n",
    "# dec_epochs = 100\n",
    "# prior_epochs = 150\n",
    "# n_epochs = 60\n",
    "# adv_epochs = 60\n",
    "# prior = 'gmm'\n",
    "# Number of mixture components of the two GaussianMixture priors.\n",
    "gmm_comps1 = 4\n",
    "gmm_comps2 = 2\n",
    "# gamma = 1.0\n",
    "# n_flows = 1\n",
    "seed = 100\n",
    "# train_dec = False\n",
    "# log_epochs = 10\n",
    "# Split fractions forwarded to CrimeDataset.\n",
    "p_test = 0.2\n",
    "p_val = 0.2\n",
    "# with_test = False\n",
    "# fair_criterion = 'stat_parity'"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Seed the Python, NumPy and PyTorch global generators. The NumPy seed also\n",
    "# determines the coupling masks shuffled with np.random.shuffle further down.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plots_dir receives the figures and CSV files written below (created if\n",
    "# missing); model_dir is where the saved priors and flows are read from.\n",
    "plots_dir = PROJECT_ROOT / 'plots' / dataset\n",
    "plots_dir.mkdir(parents=True, exist_ok=True)\n",
    "model_dir = PROJECT_ROOT / 'code' / dataset / f'gamma_{gamma}'"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load the three splits, passing the same split fractions to each.\n",
    "train_dataset = CrimeDataset('train', p_test=p_test, p_val=p_val)\n",
    "valid_dataset = CrimeDataset('validation', p_test=p_test, p_val=p_val)\n",
    "test_dataset = CrimeDataset('test', p_test=p_test, p_val=p_val)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Indices of the feature columns kept for this analysis.\n",
    "feats = np.array([3, 15, 42, 43, 44, 49])\n",
    "# Names of the selected columns. They are emitted in the iteration order of\n",
    "# train_dataset.column_ids, which lines up with the column order of\n",
    "# features[:, feats] only if that mapping iterates in ascending index\n",
    "# order - TODO confirm.\n",
    "column_ids = [\n",
    "    col for col, idx in train_dataset.column_ids.items() if idx in feats\n",
    "]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Restrict every split to the selected feature columns.\n",
    "train_all = train_dataset.features[:, feats]\n",
    "valid_all = valid_dataset.features[:, feats]\n",
    "test_all = test_dataset.features[:, feats]\n",
    "\n",
    "# Protected attribute per sample; compared against 1 and 0 below to form the\n",
    "# two groups.\n",
    "train_prot = train_dataset.protected\n",
    "valid_prot = valid_dataset.protected\n",
    "test_prot = test_dataset.protected\n",
    "\n",
    "# Labels per sample.\n",
    "train_targets = train_dataset.labels\n",
    "valid_targets = valid_dataset.labels\n",
    "test_targets = test_dataset.labels"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Per-feature minimum and maximum over the training split, keyed by the\n",
    "# original feature index; these are the bounds used for (de)normalization.\n",
    "lb = {\n",
    "    feature_idx: torch.min(train_all[:, column])\n",
    "    for column, feature_idx in enumerate(feats)\n",
    "}\n",
    "ub = {\n",
    "    feature_idx: torch.max(train_all[:, column])\n",
    "    for column, feature_idx in enumerate(feats)\n",
    "}"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def normalize(x, a, b):\n",
    "    \"\"\"Affinely rescale x so that a maps to alpha / 2 and b maps to 1 - alpha / 2.\n",
    "\n",
    "    Reads the notebook-level alpha. The result is undefined when a == b\n",
    "    (division by zero).\n",
    "    \"\"\"\n",
    "    return 0.5 + (1 - alpha) * ((x - a) / (b - a) - 0.5)\n",
    "\n",
    "\n",
    "def denormalize(z, a, b):\n",
    "    \"\"\"Inverse of normalize: map alpha / 2 back to a and 1 - alpha / 2 back to b.\"\"\"\n",
    "    return ((z - 0.5) / (1 - alpha) + 0.5) * (b - a) + a\n",
    "\n",
    "\n",
    "def normalize_data(data):\n",
    "    \"\"\"Return a normalized copy of data, one column per entry of feats.\n",
    "\n",
    "    Column idx is rescaled with the training bounds lb/ub of feats[idx].\n",
    "    The input tensor is left untouched, so the calling cell can be re-run\n",
    "    without normalizing the same tensor twice.\n",
    "    \"\"\"\n",
    "    normalized = data.clone()\n",
    "    for idx, feature_idx in enumerate(feats):\n",
    "        normalized[:, idx] = normalize(data[:, idx], lb[feature_idx], ub[feature_idx])\n",
    "    return normalized\n",
    "\n",
    "\n",
    "def denormalize_data(data):\n",
    "    \"\"\"Return a copy of data with normalize_data undone column by column.\n",
    "\n",
    "    The input tensor is left untouched; tensors that are shared with a\n",
    "    DataLoader therefore keep their normalized values.\n",
    "    \"\"\"\n",
    "    restored = data.clone()\n",
    "    for idx, feature_idx in enumerate(feats):\n",
    "        restored[:, idx] = denormalize(data[:, idx], lb[feature_idx], ub[feature_idx])\n",
    "    return restored"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Rescale all splits with the bounds computed on the training split; values\n",
    "# of the other splits may therefore fall outside [alpha / 2, 1 - alpha / 2].\n",
    "train_all = normalize_data(train_all)\n",
    "valid_all = normalize_data(valid_all)\n",
    "test_all = normalize_data(test_all)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Split every dataset by protected attribute: suffix 1 holds the samples with\n",
    "# protected == 1, suffix 2 those with protected == 0. Labels are cast to long.\n",
    "train1, train2 = train_all[train_prot == 1], train_all[train_prot == 0]\n",
    "targets1, targets2 = train_targets[train_prot == 1].long(), train_targets[train_prot == 0].long()\n",
    "train1_loader = torch.utils.data.DataLoader(TensorDataset(train1, targets1), batch_size=batch_size, shuffle=True)\n",
    "train2_loader = torch.utils.data.DataLoader(TensorDataset(train2, targets2), batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# The validation and test loaders are not consumed by any later cell of this\n",
    "# notebook.\n",
    "valid1, valid2 = valid_all[valid_prot == 1], valid_all[valid_prot == 0]\n",
    "v_targets1, v_targets2 = valid_targets[valid_prot == 1].long(), valid_targets[valid_prot == 0].long()\n",
    "valid1_loader = torch.utils.data.DataLoader(TensorDataset(valid1, v_targets1), batch_size=8, shuffle=True)\n",
    "valid2_loader = torch.utils.data.DataLoader(TensorDataset(valid2, v_targets2), batch_size=8, shuffle=True)\n",
    "\n",
    "test1, test2 = test_all[test_prot == 1], test_all[test_prot == 0]\n",
    "t_targets1, t_targets2 = test_targets[test_prot == 1].long(), test_targets[test_prot == 0].long()\n",
    "test1_loader = torch.utils.data.DataLoader(TensorDataset(test1, t_targets1), batch_size=8, shuffle=True)\n",
    "test2_loader = torch.utils.data.DataLoader(TensorDataset(test2, t_targets2), batch_size=8, shuffle=True)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# The mixtures are never fitted here: their weights, means and covariances\n",
    "# are filled in from the arrays saved in model_dir.\n",
    "gaussian_mixture1 = GaussianMixture(n_components=gmm_comps1, n_init=1, covariance_type='full')\n",
    "gaussian_mixture2 = GaussianMixture(n_components=gmm_comps2, n_init=1, covariance_type='full')\n",
    "\n",
    "gaussian_mixture1.weights_ = np.load(model_dir / 'prior1_weights.npy')\n",
    "gaussian_mixture1.means_ = np.load(model_dir / 'prior1_means.npy')\n",
    "gaussian_mixture1.covariances_ = np.load(model_dir / 'prior1_covs.npy')\n",
    "\n",
    "gaussian_mixture2.weights_ = np.load(model_dir / 'prior2_weights.npy')\n",
    "gaussian_mixture2.means_ = np.load(model_dir / 'prior2_means.npy')\n",
    "gaussian_mixture2.covariances_ = np.load(model_dir / 'prior2_covs.npy')\n",
    "\n",
    "# Wrap the mixtures for use on `device`. Presumably GMM only reads the three\n",
    "# attributes set above (precisions_cholesky_ is never assigned) - TODO confirm.\n",
    "prior1 = GMM(gaussian_mixture1, device=device)\n",
    "prior2 = GMM(gaussian_mixture2, device=device)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Alternating binary coupling masks, drawn from NumPy's global RNG: each\n",
    "# shuffled mask is followed by its complement. The result depends on the seed\n",
    "# set above and on no other np.random call having run in between.\n",
    "# NOTE(review): presumably these must reproduce the masks used when the\n",
    "# checkpoints were trained - confirm.\n",
    "masks = []\n",
    "for _ in range(20):\n",
    "    t = np.array([j % 2 for j in range(feats.shape[0])])\n",
    "    np.random.shuffle(t)\n",
    "    masks += [t, 1 - t]\n",
    "\n",
    "flow1 = FlowEncoder(None, feats.shape[0], [50, 50], n_blocks, masks).to(device)\n",
    "flow2 = FlowEncoder(None, feats.shape[0], [50, 50], n_blocks, masks).to(device)\n",
    "\n",
    "# map_location lets checkpoints that were saved on a GPU load on a CPU-only\n",
    "# machine, matching the CPU fallback used when `device` is chosen.\n",
    "flow1.load_state_dict(torch.load(model_dir / 'flow1.pt', map_location=device))\n",
    "flow2.load_state_dict(torch.load(model_dir / 'flow2.pt', map_location=device))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "mappings = defaultdict(list)\n",
    "\n",
    "# Inference only: with autograd disabled the tensors collected below do not\n",
    "# keep the computation graph of every batch alive.\n",
    "with torch.no_grad():\n",
    "    # zip stops at the shorter loader, so batches of the larger group beyond\n",
    "    # that point are never visited.\n",
    "    for (x1, y1), (x2, y2) in zip(train1_loader, train2_loader):\n",
    "\n",
    "        # NOTE(review): normalize() maps the training data into\n",
    "        # [alpha / 2, 1 - alpha / 2], so the upper bound 1 - alpha does clip\n",
    "        # the largest training values. Confirm against the training code\n",
    "        # before changing it to 1 - alpha / 2.\n",
    "        x1 = torch.clamp(x1, alpha / 2, 1 - alpha).logit()\n",
    "        x2 = torch.clamp(x2, alpha / 2, 1 - alpha).logit()\n",
    "\n",
    "        # group 1 -> latent code -> matched point in group 2\n",
    "        x1_z1, _ = flow1.inverse(x1)\n",
    "        x1_x2, _ = flow2.forward(x1_z1)\n",
    "\n",
    "        mappings['x1_real'].append(x1.sigmoid())\n",
    "        mappings['x2_fake'].append(x1_x2.sigmoid())\n",
    "        mappings['x1_real_logp'].append(prior1.log_prob(x1))\n",
    "        mappings['x2_fake_logp'].append(prior2.log_prob(x1_x2))\n",
    "        mappings['z1'].append(x1_z1)\n",
    "        mappings['y1'].append(y1)\n",
    "\n",
    "        # group 2 -> latent code -> matched point in group 1\n",
    "        x2_z2, _ = flow2.inverse(x2)\n",
    "        x2_x1, _ = flow1.forward(x2_z2)\n",
    "\n",
    "        mappings['x1_fake'].append(x2_x1.sigmoid())\n",
    "        mappings['x2_real'].append(x2.sigmoid())\n",
    "        mappings['x1_fake_logp'].append(prior1.log_prob(x2_x1))\n",
    "        mappings['x2_real_logp'].append(prior2.log_prob(x2))\n",
    "        mappings['z2'].append(x2_z2)\n",
    "        mappings['y2'].append(y2)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# undo flow normalization\n",
    "x1_real, x2_real, x1_fake, x2_fake = (\n",
    "    denormalize_data(torch.vstack(mappings[key]))\n",
    "    for key in ('x1_real', 'x2_real', 'x1_fake', 'x2_fake')\n",
    ")\n",
    "\n",
    "# undo preprocessing normalization and move the result to the CPU\n",
    "x1_real_denormalized, x2_real_denormalized, x1_fake_denormalized, x2_fake_denormalized = (\n",
    "    (train_dataset.std[feats] * batch + train_dataset.mean[feats]).cpu().detach()\n",
    "    for batch in (x1_real, x2_real, x1_fake, x2_fake)\n",
    ")\n",
    "\n",
    "# the flow-denormalized tensors themselves are moved to the CPU as well\n",
    "x1_real, x2_real, x1_fake, x2_fake = (\n",
    "    batch.cpu().detach() for batch in (x1_real, x2_real, x1_fake, x2_fake)\n",
    ")\n",
    "\n",
    "# log-probabilities under the two priors, one value per sample\n",
    "x1_fake_logp, x2_fake_logp, x1_real_logp, x2_real_logp = (\n",
    "    torch.cat(mappings[key]).cpu().detach()\n",
    "    for key in ('x1_fake_logp', 'x2_fake_logp', 'x1_real_logp', 'x2_real_logp')\n",
    ")\n",
    "\n",
    "# latent codes and labels of both groups\n",
    "z1, z2 = (torch.vstack(mappings[key]).cpu().detach() for key in ('z1', 'z2'))\n",
    "y1, y2 = (torch.cat(mappings[key]).cpu().detach() for key in ('y1', 'y2'))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Stack both views so that rows are aligned: the first n_x1 rows of x1 are\n",
    "# real group-1 samples and the same rows of x2 are their mapped counterparts;\n",
    "# the remaining rows are real group-2 samples in x2 and their mapped\n",
    "# counterparts in x1.\n",
    "x1 = torch.cat((x1_real, x1_fake))\n",
    "x2 = torch.cat((x2_fake, x2_real))\n",
    "x1_denormalized = torch.cat((x1_real_denormalized, x1_fake_denormalized))\n",
    "x2_denormalized = torch.cat((x2_fake_denormalized, x2_real_denormalized))\n",
    "\n",
    "x1_logp = torch.cat((x1_real_logp, x1_fake_logp))\n",
    "x2_logp = torch.cat((x2_fake_logp, x2_real_logp))\n",
    "\n",
    "# Latent codes and labels follow the same row order.\n",
    "z = torch.cat((z1, z2))\n",
    "y = torch.cat((y1, y2))\n",
    "\n",
    "# Number of real samples collected for each group.\n",
    "n_x1 = x1_real.shape[0]\n",
    "n_x2 = x2_real.shape[0]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def _add_arrow(fig, start, end):\n",
    "    \"\"\"Attach a curved black arrow from start to end (figure coordinates) to fig.\"\"\"\n",
    "    arrow = patches.FancyArrowPatch(\n",
    "        start, end, transform=fig.transFigure, connectionstyle='arc3,rad=0.5',\n",
    "        arrowstyle='simple, head_width=5, head_length=5', color='k'\n",
    "    )\n",
    "    fig.patches.append(arrow)\n",
    "\n",
    "\n",
    "n_clusters = 4\n",
    "# kmeans_x2 is fitted but its labels are never read: both panels are coloured\n",
    "# with kmeans_x1's labels, so row-aligned (real, matched) pairs share a colour.\n",
    "# NOTE(review): presumably intentional - confirm.\n",
    "kmeans_x1 = KMeans(n_clusters=n_clusters, random_state=0).fit(x1)\n",
    "kmeans_x2 = KMeans(n_clusters=n_clusters, random_state=0).fit(x2)\n",
    "\n",
    "perplexity = 35\n",
    "tsne_seed = 75 if gamma == 0 else 56\n",
    "x1_t_sne = TSNE(perplexity=perplexity, random_state=tsne_seed).fit_transform(x1)\n",
    "x2_t_sne = TSNE(perplexity=perplexity, random_state=tsne_seed).fit_transform(x2)\n",
    "\n",
    "# rescale to [-1, 1]\n",
    "l1, u1 = x1_t_sne.min(0), x1_t_sne.max(0)\n",
    "l2, u2 = x2_t_sne.min(0), x2_t_sne.max(0)\n",
    "\n",
    "x1_t_sne = (2 * x1_t_sne - (u1 + l1)) / (u1 - l1)\n",
    "x2_t_sne = (2 * x2_t_sne - (u2 + l2)) / (u2 - l2)\n",
    "\n",
    "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))\n",
    "\n",
    "# Rows [:n_x1] are real in the left panel and matched in the right panel;\n",
    "# rows [n_x1:] are the other way round.\n",
    "ax[0].scatter(x1_t_sne[:n_x1, 0], x1_t_sne[:n_x1, 1], c=kmeans_x1.labels_[:n_x1], marker='.', cmap='tab10', label='real')\n",
    "ax[0].scatter(x1_t_sne[n_x1:, 0], x1_t_sne[n_x1:, 1], c=kmeans_x1.labels_[n_x1:], marker='x', cmap='tab10', label='matched', s=20)\n",
    "ax[1].scatter(x2_t_sne[n_x1:, 0], x2_t_sne[n_x1:, 1], c=kmeans_x1.labels_[n_x1:], marker='.', cmap='tab10', label='real')\n",
    "ax[1].scatter(x2_t_sne[:n_x1, 0], x2_t_sne[:n_x1, 1], c=kmeans_x1.labels_[:n_x1], marker='x', cmap='tab10', label='matched', s=20)\n",
    "\n",
    "for idx in range(2):\n",
    "    ax[idx].set_xticklabels(list())\n",
    "    ax[idx].set_yticklabels(list())\n",
    "    legend = ax[idx].legend()\n",
    "    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and later\n",
    "    # removed the old name; support both.\n",
    "    handles = getattr(legend, 'legend_handles', None)\n",
    "    if handles is None:\n",
    "        handles = legend.legendHandles\n",
    "    # Legend markers are drawn in black because colour encodes the cluster.\n",
    "    handles[0].set_color('k')\n",
    "    handles[1].set_color('k')\n",
    "    ax[idx].set_xlim(left=-1.1, right=1.1)\n",
    "    ax[idx].set_ylim(bottom=-1.1, top=1.1)\n",
    "\n",
    "ax[0].set_title(r'Non-White (a = 1)')\n",
    "ax[1].set_title(r'White $(a = 0)$')\n",
    "\n",
    "# Convert points given in each panel's data coordinates to figure\n",
    "# coordinates so the arrows can span both axes.\n",
    "ax0tr = ax[0].transData\n",
    "ax1tr = ax[1].transData\n",
    "figtr = fig.transFigure.inverted()\n",
    "\n",
    "# left panel -> right panel\n",
    "_add_arrow(\n",
    "    fig,\n",
    "    figtr.transform(ax0tr.transform((1.15, -0.1))),\n",
    "    figtr.transform(ax1tr.transform((-1.15, -0.1)))\n",
    ")\n",
    "\n",
    "fig.text(0.495, 0.495, 'FNF')\n",
    "\n",
    "# right panel -> left panel\n",
    "_add_arrow(\n",
    "    fig,\n",
    "    figtr.transform(ax1tr.transform((-1.15, 0.15))),\n",
    "    figtr.transform(ax0tr.transform((1.15, 0.15)))\n",
    ")\n",
    "\n",
    "plt.savefig(\n",
    "    plots_dir / f'gamma_{gamma}_perplexity_{perplexity}_n_clusters_{n_clusters}.eps',\n",
    "    bbox_inches='tight'\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# One row per k-means cluster: mean prior log-probability followed by the\n",
    "# mean of every selected feature in original units.\n",
    "clusters_x1 = pd.DataFrame(columns=['logp'] + column_ids)\n",
    "clusters_x2 = pd.DataFrame(columns=['logp'] + column_ids)\n",
    "\n",
    "for cluster in range(n_clusters):\n",
    "    clusters_x1.loc[cluster] = torch.cat((\n",
    "        x1_logp[kmeans_x1.labels_ == cluster].mean().unsqueeze(0),\n",
    "        x1_denormalized[kmeans_x1.labels_ == cluster].mean(axis=0)\n",
    "    )).numpy()\n",
    "    # The x2 rows are selected with kmeans_x1's labels as well, i.e. through\n",
    "    # the row alignment between x1 and x2.\n",
    "    # NOTE(review): presumably intentional - confirm.\n",
    "    clusters_x2.loc[cluster] = torch.cat((\n",
    "        x2_logp[kmeans_x1.labels_ == cluster].mean().unsqueeze(0),\n",
    "        x2_denormalized[kmeans_x1.labels_ == cluster].mean(axis=0)\n",
    "    )).numpy()\n",
    "\n",
    "clusters_x1.to_csv(\n",
    "    plots_dir / f'gamma_{gamma}_n_clusters_{n_clusters}_x1.csv', index=False\n",
    ")\n",
    "clusters_x2.to_csv(\n",
    "    plots_dir / f'gamma_{gamma}_n_clusters_{n_clusters}_x2.csv', index=False\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def _to_original_units(values):\n",
    "    \"\"\"Undo the dataset's mean/std preprocessing for the selected features.\"\"\"\n",
    "    return train_dataset.std[feats] * values + train_dataset.mean[feats]\n",
    "\n",
    "\n",
    "# denormalize_data may write into its argument, and train1/train2 back the\n",
    "# TensorDatasets of the training loaders. Working on copies keeps the loader\n",
    "# data normalized and makes this cell safe to re-run.\n",
    "train1_unscaled = denormalize_data(train1.clone())\n",
    "train2_unscaled = denormalize_data(train2.clone())\n",
    "\n",
    "# Per-feature minimum, mean and maximum of each group in original units.\n",
    "print(_to_original_units(train1_unscaled.min(0)[0]))\n",
    "print(_to_original_units(train2_unscaled.min(0)[0]))\n",
    "print()\n",
    "print(_to_original_units(train1_unscaled.mean(0)))\n",
    "print(_to_original_units(train2_unscaled.mean(0)))\n",
    "print()\n",
    "print(_to_original_units(train1_unscaled.max(0)[0]))\n",
    "print(_to_original_units(train2_unscaled.max(0)[0]))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Fit a logistic regression on the latent codes of both groups. Prediction\n",
    "# and score are computed on the same samples the model was fitted on, so the\n",
    "# displayed value is a training accuracy.\n",
    "clf = LogisticRegression(random_state=0).fit(z, y)\n",
    "y_hat = clf.predict(z)\n",
    "clf.score(z, y)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "mappings = defaultdict(list)\n",
    "\n",
    "# x1, x2 and z are row-aligned and share the predictions y_hat, so the latent\n",
    "# search is identical for both groups: it is run once per negatively\n",
    "# classified sample and the result is decoded with both flows.\n",
    "negative = y_hat == 0\n",
    "z_positive = z[y_hat == 1]\n",
    "\n",
    "# Inference only: no computation graph is needed for the decoded points.\n",
    "with torch.no_grad():\n",
    "    for x1_i, x2_i, z_i in zip(x1[negative], x2[negative], z[negative]):\n",
    "        # nearest latent code that the classifier labels as positive\n",
    "        z_i_nn = z_positive[torch.norm(z_i - z_positive, dim=1).argmin()]\n",
    "\n",
    "        # walk towards that neighbour in steps of 0.1 and stop at the first\n",
    "        # point whose prediction is non-zero\n",
    "        for beta in np.linspace(0, 1, 11):\n",
    "            z_new = torch.unsqueeze(z_i + beta * (z_i_nn - z_i), 0)\n",
    "\n",
    "            if clf.predict(z_new):\n",
    "                break\n",
    "\n",
    "        z_new = z_new.to(device)\n",
    "        x1_i_new, _ = flow1.forward(z_new)\n",
    "        x2_i_new, _ = flow2.forward(z_new)\n",
    "\n",
    "        mappings['x1_old'].append(x1_i.to(device))\n",
    "        mappings['x1_new'].append(x1_i_new.sigmoid().to(device))\n",
    "        mappings['x2_old'].append(x2_i.to(device))\n",
    "        mappings['x2_new'].append(x2_i_new.sigmoid().to(device))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# The *_old points come from x1/x2, which already had the flow normalization\n",
    "# undone in an earlier cell; only the freshly decoded *_new points still need it.\n",
    "x1_old = torch.vstack(mappings['x1_old'])\n",
    "x2_old = torch.vstack(mappings['x2_old'])\n",
    "\n",
    "# undo flow normalization\n",
    "x1_new = denormalize_data(torch.vstack(mappings['x1_new']))\n",
    "x2_new = denormalize_data(torch.vstack(mappings['x2_new']))\n",
    "\n",
    "# undo preprocessing normalization\n",
    "x1_old = train_dataset.std[feats] * x1_old + train_dataset.mean[feats]\n",
    "x1_new = train_dataset.std[feats] * x1_new + train_dataset.mean[feats]\n",
    "x2_old = train_dataset.std[feats] * x2_old + train_dataset.mean[feats]\n",
    "x2_new = train_dataset.std[feats] * x2_new + train_dataset.mean[feats]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Mean per-feature change (new - old, in original units) over all samples\n",
    "# that were classified as 0, one row per view (x1 first, x2 second).\n",
    "avg_recourse = pd.DataFrame(\n",
    "    [torch.mean(x1_new - x1_old, 0).cpu().detach().numpy(),\n",
    "     torch.mean(x2_new - x2_old, 0).cpu().detach().numpy()],\n",
    "    columns=column_ids, index=['Non-White', 'White']\n",
    ")\n",
    "avg_recourse.to_csv(plots_dir / f'recourse_gamma_{gamma}.csv')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}