{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import defaultdict\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.utils.data\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "from datasets.lawschool import LawschoolDataset\n",
    "from generative.gmm import GMM\n",
    "from real_nvp_encoder import FlowEncoder"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Apply seaborn's default theme to every figure drawn in this notebook.\n",
    "sns.set_theme()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Parent of the current working directory at the time this cell runs;\n",
    "# model and plot directories are resolved relative to it.\n",
    "PROJECT_ROOT = Path('.').absolute().parent"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Notebook configuration. The commented-out entries are not read by any cell of this notebook.\n",
    "dataset = 'lawschool'\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# gamma selects the model directory (gamma_<value>) and is embedded in output file names\n",
    "gamma = 1.0\n",
    "# alpha sets the margin of the normalized range: features are mapped into [alpha / 2, 1 - alpha / 2]\n",
    "alpha = 0.05\n",
    "# lr = 1e-2\n",
    "# weight_decay = 1e-4\n",
    "# kl_start = 0\n",
    "# kl_end = 50\n",
    "# protected_att = None\n",
    "# n_blocks is passed to FlowEncoder; batch_size to the DataLoaders\n",
    "n_blocks = 4\n",
    "batch_size = 128\n",
    "# dec_epochs = 100\n",
    "# prior_epochs = 150\n",
    "# n_epochs = 100\n",
    "# adv_epochs = 100\n",
    "# prior = 'gmm'\n",
    "# number of mixture components of the two GMM priors\n",
    "gmm_comps1 = 10\n",
    "gmm_comps2 = 10\n",
    "# out_file = None\n",
    "# n_flows = 1\n",
    "# seed for the random, NumPy and PyTorch RNGs\n",
    "seed = 100\n",
    "# train_dec = True\n",
    "# log_epochs = 10\n",
    "# quantiles, p_test and p_val are forwarded to LawschoolDataset\n",
    "quantiles = False\n",
    "p_test = 0.2\n",
    "p_val = 0.2\n",
    "# with_test = False\n",
    "# fair_criterion = 'stat_parity'\n",
    "# no_early_stop = False\n",
    "# load_enc = False"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Seed the Python, NumPy and PyTorch global RNGs with the configured seed.\n",
    "# The random masks and the dequantization noise below are drawn from these global RNGs.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# model_dir: where the saved prior arrays (prior*_*.npy) and flow weights (flow*.pt) are read from\n",
    "model_dir = PROJECT_ROOT / 'code' / dataset / f'gamma_{gamma}'\n",
    "# plots_dir: receives the figures and CSV files written below; created if it does not exist\n",
    "plots_dir = PROJECT_ROOT / 'plots' / dataset\n",
    "plots_dir.mkdir(parents=True, exist_ok=True)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class Args:\n",
    "    \"\"\"Minimal argument holder handed to LawschoolDataset; it only carries `quantiles`.\"\"\"\n",
    "\n",
    "    def __init__(self, quantiles):\n",
    "        # stored unchanged; how it is interpreted is up to LawschoolDataset\n",
    "        self.quantiles = quantiles"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Build the train / validation / test splits with identical settings.\n",
    "# p_test and p_val are forwarded as-is (presumably split fractions - confirm in LawschoolDataset).\n",
    "args = Args(quantiles=quantiles)\n",
    "\n",
    "train_dataset = LawschoolDataset('train', args, p_test=p_test, p_val=p_val)\n",
    "valid_dataset = LawschoolDataset('validation', args, p_test=p_test, p_val=p_val)\n",
    "test_dataset = LawschoolDataset('test', args, p_test=p_test, p_val=p_val)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Pull features, protected attribute and labels out of each split.\n",
    "# The code below slices and masks them like 2-D / 1-D torch tensors.\n",
    "train_all = train_dataset.features\n",
    "valid_all = valid_dataset.features\n",
    "test_all = test_dataset.features\n",
    "\n",
    "train_prot = train_dataset.protected\n",
    "valid_prot = valid_dataset.protected\n",
    "test_prot = test_dataset.protected\n",
    "\n",
    "train_targets = train_dataset.labels\n",
    "valid_targets = valid_dataset.labels\n",
    "test_targets = test_dataset.labels\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Index of the largest entry among columns 6..29 of every row\n",
    "# (assumes these 24 columns one-hot encode the college - TODO confirm against LawschoolDataset).\n",
    "train_college = train_all[:, 6:30].max(dim=1)[1]\n",
    "valid_college = valid_all[:, 6:30].max(dim=1)[1]\n",
    "test_college = test_all[:, 6:30].max(dim=1)[1]\n",
    "\n",
    "# Share of each college index among the training rows whose target is 0.\n",
    "# NOTE(review): np.bincount is called without minlength, so it only covers indices up to the\n",
    "# largest one observed with target 0; any larger index is not remapped by the loop below.\n",
    "c1_cnt = np.bincount(train_college[train_targets == 0].detach().cpu().numpy())\n",
    "c1_cnt = c1_cnt / np.sum(c1_cnt)\n",
    "\n",
    "# College indices ordered by ascending share; the position in this list becomes the new code.\n",
    "college_rnk = list(range(c1_cnt.shape[0]))\n",
    "college_rnk.sort(key=lambda i: c1_cnt[i])\n",
    "\n",
    "new_train_college = train_college.detach().clone()\n",
    "new_valid_college = valid_college.detach().clone()\n",
    "new_test_college = test_college.detach().clone()\n",
    "\n",
    "# The masks are taken on the original index tensors, so a value that was already replaced is\n",
    "# never remapped a second time. The ranking comes from the training split only.\n",
    "for i, college in enumerate(college_rnk):\n",
    "    new_train_college = torch.where(train_college == college, i, new_train_college)\n",
    "    new_valid_college = torch.where(valid_college == college, i, new_valid_college)\n",
    "    new_test_college = torch.where(test_college == college, i, new_test_college)\n",
    "\n",
    "# Keep the first two feature columns and append the re-ranked college code as a third float column.\n",
    "train_all = torch.cat([train_all[:, :2], new_train_college.unsqueeze(1)], dim=1).float()\n",
    "valid_all = torch.cat([valid_all[:, :2], new_valid_college.unsqueeze(1)], dim=1).float()\n",
    "test_all = torch.cat([test_all[:, :2], new_test_college.unsqueeze(1)], dim=1).float()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def compute_quants(train_all):\n",
    "    quants = []\n",
    "    for i in range(train_all.shape[1]):\n",
    "        x = np.sort(train_all[:, i].detach().cpu().numpy())\n",
    "        min_quant = 1000.0\n",
    "        for j in range(x.shape[0] - 1):\n",
    "            if x[j+1] - x[j] < 1e-4:\n",
    "                continue\n",
    "            min_quant = min(min_quant, x[j+1] - x[j])\n",
    "        quants += [min_quant]\n",
    "    return quants\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Smallest step between distinct values of each raw training feature; normalize_data adds it\n",
    "# to the upper bound of the column.\n",
    "quants = compute_quants(train_all)\n",
    "# Column 1 gets no step (presumably because it is treated as continuous - confirm).\n",
    "quants[1] = 0"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Labels for the three feature columns, used as DataFrame column names in the exported CSVs.\n",
    "# The third column is the re-ranked college code; the first two names are assumed to match\n",
    "# the first two dataset features - TODO confirm.\n",
    "column_ids = ['lsat', 'gpa', 'college']"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Per-column minimum (lb) and maximum (ub) of the training features, keyed by column index.\n",
    "lb = {col: train_all[:, col].min() for col in range(train_all.shape[1])}\n",
    "ub = {col: train_all[:, col].max() for col in range(train_all.shape[1])}\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def normalize(x, a, b):\n",
    "    \"\"\"Affinely map the interval [a, b] onto [alpha / 2, 1 - alpha / 2].\"\"\"\n",
    "    return 0.5 + (1 - alpha) * ((x - a) / (b - a) - 0.5)\n",
    "\n",
    "def denormalize(z, a, b):\n",
    "    \"\"\"Inverse of `normalize`: map [alpha / 2, 1 - alpha / 2] back onto [a, b].\"\"\"\n",
    "    return ((z - 0.5) / (1 - alpha) + 0.5) * (b - a) + a\n",
    "\n",
    "def normalize_data(data):\n",
    "    \"\"\"Clamp and normalize every column with the training-set bounds.\n",
    "\n",
    "    Column idx is clamped to [lb[idx], ub[idx] + quants[idx]] and then passed\n",
    "    through `normalize` with the same bounds.\n",
    "\n",
    "    Args:\n",
    "        data: 2-D tensor, one column per feature. It is not modified.\n",
    "\n",
    "    Returns:\n",
    "        A new tensor of the same shape holding the normalized values.\n",
    "    \"\"\"\n",
    "    # work on a copy so that the caller's tensor is left untouched\n",
    "    out = data.clone()\n",
    "    for idx in range(out.shape[1]):\n",
    "        lo, hi = lb[idx], ub[idx] + quants[idx]\n",
    "        # clamping has no effect on training data\n",
    "        out[:, idx] = normalize(torch.clamp(data[:, idx], lo, hi), lo, hi)\n",
    "    return out\n",
    "\n",
    "def denormalize_data(data):\n",
    "    \"\"\"Map every column from the normalized range back to original units.\n",
    "\n",
    "    Args:\n",
    "        data: 2-D tensor of normalized values. It is not modified.\n",
    "\n",
    "    Returns:\n",
    "        A new tensor of the same shape in original feature units.\n",
    "    \"\"\"\n",
    "    out = data.clone()\n",
    "    for idx in range(out.shape[1]):\n",
    "        out[:, idx] = denormalize(data[:, idx], lb[idx], ub[idx] + quants[idx])\n",
    "    return out"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Rescale all three splits with the bounds taken from the training split.\n",
    "train_all = normalize_data(train_all)\n",
    "valid_all = normalize_data(valid_all)\n",
    "test_all = normalize_data(test_all)\n",
    "\n",
    "# q: smallest step of each column measured on the normalized training data, shaped (1, n_features)\n",
    "# and placed on `device`. It scales the uniform dequantization noise added in the next cell.\n",
    "q = torch.tensor(compute_quants(train_all)).float().unsqueeze(0).to(device)\n",
    "# no noise for column 1\n",
    "q[0, 1] = 0\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def _dequantized_logit(x):\n",
    "    # add uniform noise scaled per column by q, clamp into [alpha / 2, 1 - alpha / 2], go to logit space\n",
    "    jitter = q * torch.rand(x.shape).to(device)\n",
    "    return torch.clamp(x + jitter, alpha / 2, 1 - alpha / 2).logit()\n",
    "\n",
    "\n",
    "def _split_by_group(features, protected, targets):\n",
    "    # group 1 is protected == 1, group 2 is protected == 0; noise is drawn for group 1 first\n",
    "    x_g1 = _dequantized_logit(features[protected == 1])\n",
    "    x_g2 = _dequantized_logit(features[protected == 0])\n",
    "    return x_g1, x_g2, targets[protected == 1].long(), targets[protected == 0].long()\n",
    "\n",
    "\n",
    "def _make_loader(x, y, shuffle=False):\n",
    "    return torch.utils.data.DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "\n",
    "# Splits are processed in the order train, validation, test so the noise draws keep their sequence.\n",
    "train1, train2, targets1, targets2 = _split_by_group(train_all, train_prot, train_targets)\n",
    "train1_loader = _make_loader(train1, targets1, shuffle=True)\n",
    "train2_loader = _make_loader(train2, targets2, shuffle=True)\n",
    "\n",
    "valid1, valid2, v_targets1, v_targets2 = _split_by_group(valid_all, valid_prot, valid_targets)\n",
    "valid1_loader = _make_loader(valid1, v_targets1)\n",
    "valid2_loader = _make_loader(valid2, v_targets2)\n",
    "\n",
    "test1, test2, t_targets1, t_targets2 = _split_by_group(test_all, test_prot, test_targets)\n",
    "test1_loader = _make_loader(test1, t_targets1)\n",
    "test2_loader = _make_loader(test2, t_targets2)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# The mixtures are never fitted here: they serve as containers for parameters saved elsewhere.\n",
    "gaussian_mixture1 = GaussianMixture(\n",
    "    n_components=gmm_comps1, n_init=1, covariance_type='full'\n",
    ")\n",
    "gaussian_mixture2 = GaussianMixture(\n",
    "    n_components=gmm_comps2, n_init=1, covariance_type='full'\n",
    ")\n",
    "\n",
    "# Only weights_, means_ and covariances_ are restored from model_dir.\n",
    "# NOTE(review): scikit-learn's own predict / score methods also need precisions_cholesky_, which\n",
    "# is not set; presumably the GMM wrapper only reads the three arrays below - verify.\n",
    "gaussian_mixture1.weights_ = np.load(model_dir / 'prior1_weights.npy')\n",
    "gaussian_mixture1.means_ = np.load(model_dir / 'prior1_means.npy')\n",
    "gaussian_mixture1.covariances_ = np.load(model_dir / 'prior1_covs.npy')\n",
    "\n",
    "gaussian_mixture2.weights_ = np.load(model_dir / 'prior2_weights.npy')\n",
    "gaussian_mixture2.means_ = np.load(model_dir / 'prior2_means.npy')\n",
    "gaussian_mixture2.covariances_ = np.load(model_dir / 'prior2_covs.npy')\n",
    "\n",
    "# Wrap both mixtures in the project's GMM class on the selected device.\n",
    "prior1 = GMM(gaussian_mixture1, device=device)\n",
    "prior2 = GMM(gaussian_mixture2, device=device)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "in_dim = train_all.shape[1]\n",
    "\n",
    "# 20 randomly shuffled 0/1 masks, each followed by its complement. They are drawn from NumPy's\n",
    "# global RNG, so they depend on the seed and on every draw made before this cell.\n",
    "# NOTE(review): whether they have to equal the masks used during training depends on whether\n",
    "# FlowEncoder keeps its masks in the state dict - verify.\n",
    "masks = []\n",
    "for i in range(20):\n",
    "    t = np.array([j % 2 for j in range(in_dim)])\n",
    "    np.random.shuffle(t)\n",
    "    masks += [t, 1 - t]\n",
    "\n",
    "flow1 = FlowEncoder(None, in_dim, [50, 50], n_blocks, masks).to(device)\n",
    "flow2 = FlowEncoder(None, in_dim, [50, 50], n_blocks, masks).to(device)\n",
    "\n",
    "# map_location remaps the stored tensors onto the active device, so a checkpoint written on a\n",
    "# GPU machine still loads when only the CPU is available\n",
    "flow1.load_state_dict(torch.load(model_dir / 'flow1.pt', map_location=device))\n",
    "flow2.load_state_dict(torch.load(model_dir / 'flow2.pt', map_location=device))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "mappings = defaultdict(list)\n",
    "\n",
    "# zip stops at the shorter loader, so surplus batches of the larger group are not visited.\n",
    "for (x1, y1), (x2, y2) in zip(train1_loader, train2_loader):\n",
    "\n",
    "    # The loaders already yield clamped, logit-space features (that transform was applied when\n",
    "    # train1 / train2 were built), so the batches go to the flows unchanged. Clamping and\n",
    "    # applying logit a second time would corrupt the inputs.\n",
    "\n",
    "    # group 1 -> latent -> group 2\n",
    "    x1_z1, _ = flow1.inverse(x1)\n",
    "    x1_x2, _ = flow2.forward(x1_z1)\n",
    "\n",
    "    # sigmoid undoes the logit, so the stored samples are in the normalized range\n",
    "    mappings['x1_real'].append(x1.sigmoid())\n",
    "    mappings['x2_fake'].append(x1_x2.sigmoid())\n",
    "    mappings['z1'].append(x1_z1)\n",
    "    mappings['y1'].append(y1)\n",
    "\n",
    "    # group 2 -> latent -> group 1\n",
    "    x2_z2, _ = flow2.inverse(x2)\n",
    "    x2_x1, _ = flow1.forward(x2_z2)\n",
    "\n",
    "    mappings['x1_fake'].append(x2_x1.sigmoid())\n",
    "    mappings['x2_real'].append(x2.sigmoid())\n",
    "    mappings['z2'].append(x2_z2)\n",
    "    mappings['y2'].append(y2)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Stack the per-batch results, map the samples back to original feature units and move\n",
    "# everything to the CPU without gradient history.\n",
    "x1_real = denormalize_data(torch.vstack(mappings['x1_real'])).cpu().detach()\n",
    "x2_real = denormalize_data(torch.vstack(mappings['x2_real'])).cpu().detach()\n",
    "x2_fake = denormalize_data(torch.vstack(mappings['x2_fake'])).cpu().detach()\n",
    "x1_fake = denormalize_data(torch.vstack(mappings['x1_fake'])).cpu().detach()\n",
    "\n",
    "# latent codes and labels stay as collected\n",
    "z1 = torch.vstack(mappings['z1']).cpu().detach()\n",
    "z2 = torch.vstack(mappings['z2']).cpu().detach()\n",
    "y1 = torch.cat(mappings['y1']).cpu().detach()\n",
    "y2 = torch.cat(mappings['y2']).cpu().detach()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# undo dequantization: floor columns 0 and 2 of every sample set in place (column 1 is left as is)\n",
    "for samples in (x1_real, x2_real, x1_fake, x2_fake):\n",
    "    for col in (0, 2):\n",
    "        samples[:, col] = torch.floor(samples[:, col])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Every sample is represented once per group space: rows from the group-1 batches come first,\n",
    "# rows from the group-2 batches second. x1, x2, z and y are therefore row-aligned.\n",
    "x1 = torch.cat((x1_real, x1_fake))\n",
    "x2 = torch.cat((x2_fake, x2_real))\n",
    "z = torch.cat((z1, z2))\n",
    "y = torch.cat((y1, y2))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# For several cluster counts: cluster the samples, draw t-SNE embeddings at a range of\n",
    "# perplexities, and export the mean feature values of every cluster.\n",
    "for n_clusters in [4, 6, 8]:\n",
    "    kmeans_x1 = KMeans(n_clusters=n_clusters, random_state=0).fit(x1)\n",
    "    # NOTE(review): kmeans_x2 is fitted but never read below. Both panels and both CSVs use the\n",
    "    # labels of kmeans_x1, which lets paired rows be followed from x1 to x2 - confirm that this\n",
    "    # is intended and whether this fit can be dropped.\n",
    "    kmeans_x2 = KMeans(n_clusters=n_clusters, random_state=0).fit(x2)\n",
    "\n",
    "    perplexities = [5, 15, 25, 35, 45]\n",
    "    # one row per perplexity; left column shows x1, right column shows x2\n",
    "    fig, ax = plt.subplots(\n",
    "        nrows=len(perplexities), ncols=2, figsize=(10, 5 * len(perplexities))\n",
    "    )\n",
    "\n",
    "    for idx, perplexity in enumerate(perplexities):\n",
    "\n",
    "        # no random_state is passed to TSNE, so the embeddings are not pinned to a fixed seed here\n",
    "        x1_t_sne = TSNE(perplexity=perplexity).fit_transform(x1)\n",
    "        x2_t_sne = TSNE(perplexity=perplexity).fit_transform(x2)\n",
    "\n",
    "        ax[idx, 0].scatter(x1_t_sne[:, 0], x1_t_sne[:, 1], c=kmeans_x1.labels_, cmap='tab10')\n",
    "        ax[idx, 1].scatter(x2_t_sne[:, 0], x2_t_sne[:, 1], c=kmeans_x1.labels_, cmap='tab10')\n",
    "\n",
    "    ax[0, 0].set_title('Non-White')\n",
    "    ax[0, 1].set_title('White')\n",
    "\n",
    "    fig.suptitle(f't-SNE with {n_clusters} Clusters')\n",
    "    fig.tight_layout()\n",
    "    plt.savefig(plots_dir / f'gamma_{gamma}_n_clusters_{n_clusters}.eps')\n",
    "\n",
    "    clusters_x1 = pd.DataFrame(columns=column_ids)\n",
    "    clusters_x2 = pd.DataFrame(columns=column_ids)\n",
    "\n",
    "    # mean feature vector of each cluster, taken over the same rows in both representations\n",
    "    for cluster in range(n_clusters):\n",
    "        clusters_x1.loc[cluster] = x1[kmeans_x1.labels_ == cluster].mean(axis=0).numpy()\n",
    "        clusters_x2.loc[cluster] = x2[kmeans_x1.labels_ == cluster].mean(axis=0).numpy()\n",
    "\n",
    "    clusters_x1.to_csv(\n",
    "        plots_dir / f'gamma_{gamma}_n_clusters_{n_clusters}_x1.csv', index=False\n",
    "    )\n",
    "    clusters_x2.to_csv(\n",
    "        plots_dir / f'gamma_{gamma}_n_clusters_{n_clusters}_x2.csv', index=False\n",
    "    )"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Bring the logit-space training tensors back to original feature units.\n",
    "# NOTE: this rebinds train1 / train2, so running the cell a second time would apply the\n",
    "# sigmoid and the denormalization again.\n",
    "train1 = denormalize_data(train1.sigmoid())\n",
    "train2 = denormalize_data(train2.sigmoid())\n",
    "\n",
    "# undo dequantization\n",
    "train1[:, 0] = torch.floor(train1[:, 0])\n",
    "train1[:, 2] = torch.floor(train1[:, 2])\n",
    "train2[:, 0] = torch.floor(train2[:, 0])\n",
    "train2[:, 2] = torch.floor(train2[:, 2])\n",
    "\n",
    "# per-column minimum, mean and maximum of both groups\n",
    "print(train1.min(0)[0])\n",
    "print(train2.min(0)[0])\n",
    "\n",
    "print(train1.mean(0))\n",
    "print(train2.mean(0))\n",
    "\n",
    "print(train1.max(0)[0])\n",
    "print(train2.max(0)[0])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Logistic regression from latent codes z to labels y. y_hat holds its predictions on the same\n",
    "# points it was fitted on, and the last line displays the accuracy on those points.\n",
    "clf = LogisticRegression(random_state=0).fit(z, y)\n",
    "y_hat = clf.predict(z)\n",
    "clf.score(z, y)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "mappings = defaultdict(list)\n",
    "\n",
    "# Selections that stay the same for every sample are taken once, outside the loop.\n",
    "neg_mask = y_hat == 0\n",
    "z_pos = z[y_hat == 1]\n",
    "\n",
    "# x1, x2 and z are row-aligned, so a single pass finds the recourse latent of each sample once\n",
    "# and decodes it with both flows.\n",
    "for x1_i, x2_i, z_i in zip(x1[neg_mask], x2[neg_mask], z[neg_mask]):\n",
    "    # closest latent code (Euclidean distance) among the points predicted as class 1\n",
    "    z_i_nn = z_pos[torch.norm(z_i - z_pos, dim=1).argmin()]\n",
    "\n",
    "    # move from z_i towards that neighbour in steps of 10% and stop at the first point that the\n",
    "    # classifier predicts as class 1\n",
    "    for beta in np.linspace(0, 1, 11):\n",
    "        z_new = torch.unsqueeze(z_i + beta * (z_i_nn - z_i), 0)\n",
    "\n",
    "        if clf.predict(z_new):\n",
    "            break\n",
    "\n",
    "    z_new = z_new.to(device)\n",
    "    x1_i_new, _ = flow1.forward(z_new)\n",
    "    x2_i_new, _ = flow2.forward(z_new)\n",
    "\n",
    "    # *_old are already in original units; *_new are stored in the normalized range\n",
    "    mappings['x1_old'].append(x1_i.to(device))\n",
    "    mappings['x1_new'].append(x1_i_new.sigmoid().to(device))\n",
    "    mappings['x2_old'].append(x2_i.to(device))\n",
    "    mappings['x2_new'].append(x2_i_new.sigmoid().to(device))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Starting points: taken from x1 / x2, which were denormalized and floored earlier.\n",
    "x1_old = torch.vstack(mappings['x1_old'])\n",
    "x2_old = torch.vstack(mappings['x2_old'])\n",
    "\n",
    "# undo flow normalization\n",
    "x1_new = denormalize_data(torch.vstack(mappings['x1_new']))\n",
    "x2_new = denormalize_data(torch.vstack(mappings['x2_new']))\n",
    "# undo dequantization\n",
    "x1_new[:, 0] = torch.floor(x1_new[:, 0])\n",
    "x2_new[:, 0] = torch.floor(x2_new[:, 0])\n",
    "x1_new[:, 2] = torch.floor(x1_new[:, 2])\n",
    "x2_new[:, 2] = torch.floor(x2_new[:, 2])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Per-feature change between each starting point and its recourse point.\n",
    "x1_diff = x1_new - x1_old\n",
    "x2_diff = x2_new - x2_old\n",
    "\n",
    "# Keep only the rows whose college column (index 2) is unchanged.\n",
    "x1_diff_without_college = x1_diff[x1_diff[:, 2] == 0]\n",
    "x2_diff_without_college = x2_diff[x2_diff[:, 2] == 0]\n",
    "\n",
    "# Mean change per feature for both group spaces, written as a two-row CSV.\n",
    "avg_recourse = pd.DataFrame(\n",
    "    [x1_diff_without_college.mean(0).cpu().detach().numpy(),\n",
    "     x2_diff_without_college.mean(0).cpu().detach().numpy()],\n",
    "    columns=column_ids, index=['Non-White', 'White']\n",
    ")\n",
    "avg_recourse.to_csv(plots_dir / f'recourse_gamma_{gamma}.csv')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}