{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
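    "# Synthetic data experiments\n",
    "\n",
    "This notebook replicates the synthetic-data experiments. It generates a synthetic learning-to-rank dataset, trains a baseline ranker and SenSTIR rankers with two fair-regularization strengths ($\\rho = 0.0003$ and $\\rho = 0.001$), runs Fair-PG-Rank (`PG.py`), and saves the resulting contour plots (`synthetic.pdf`) and heatmaps (`heatmap.pdf`)."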
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import os\n",
    "from utils import generate_synthetic_LTR_data\n",
    "from fair_training_ranking import train_fair_nn\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_v2_behavior()\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import ImageGrid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Delete the output directories left over from a previous run, if they exist, so that plotting works properly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "# Remove outputs of previous runs. The isdir guard makes this a silent no-op for\n",
    "# missing directories, and shutil does not depend on IPython's `rm` shell alias\n",
    "# (unavailable on Windows or when automagic is off).\n",
    "for output_dir in ['heatmaps', 'tensorboard_simulations', 'results_simulations']:\n",
    "    if os.path.isdir(output_dir):\n",
    "        shutil.rmtree(output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create synthetic train/test data\n",
    "num_docs_per_query = 10\n",
    "num_queries = 100\n",
    "\n",
    "# Seed NumPy's global RNG so that Restart & Run All regenerates the same dataset.\n",
    "# NOTE(review): assumes generate_synthetic_LTR_data draws from np.random -- confirm.\n",
    "DATA_SEED = 0\n",
    "np.random.seed(DATA_SEED)\n",
    "\n",
    "X_queries, relevances, majority_status = generate_synthetic_LTR_data(num_queries = num_queries, num_docs_per_query = num_docs_per_query)\n",
    "X_queries_test, relevances_test, majority_status_test = generate_synthetic_LTR_data(num_queries = num_queries, num_docs_per_query = num_docs_per_query)\n",
    "\n",
    "# Persist both splits; the plotting cells read data/X.npy and data/relevance.npy back.\n",
    "os.makedirs('data', exist_ok=True)\n",
    "splits_to_save = {\n",
    "    'X': X_queries,\n",
    "    'relevance': relevances,\n",
    "    'majority_status': majority_status,\n",
    "    'X_test': X_queries_test,\n",
    "    'relevance_test': relevances_test,\n",
    "    'majority_status_test': majority_status_test,\n",
    "}\n",
    "for stem, array in splits_to_save.items():\n",
    "    np.save(os.path.join('data', stem + '.npy'), array)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the rankers\n",
    "\n",
    "Estimate the sensitive direction, then train the baseline ranker and the SenSTIR rankers (`tf_prefix='sensei'`) for each fair-regularization strength."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sensitive direction: coefficients of a logistic regression that predicts group\n",
    "# membership (majority_status) from the features.\n",
    "LR = LogisticRegression(C = 100).fit(X_queries, majority_status)\n",
    "sens_directions = LR.coef_\n",
    "print('sensitive directions', sens_directions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_ranker(tf_prefix, **overrides):\n",
    "    \"\"\"Train one ranker on the synthetic split with `train_fair_nn`.\n",
    "\n",
    "    Settings shared by every run in this notebook live here; `overrides` adds the\n",
    "    run-specific keyword arguments (and may replace a shared one). Reads the data\n",
    "    globals (X_queries, relevances, majority_status, the *_test arrays,\n",
    "    num_docs_per_query, num_queries) and sens_directions from earlier cells.\n",
    "\n",
    "    The default TF graph is reset first so every run -- including a re-run of a\n",
    "    cell on a live kernel -- builds its model in an empty graph.\n",
    "\n",
    "    Returns whatever `train_fair_nn` returns.\n",
    "    \"\"\"\n",
    "    tf.reset_default_graph()\n",
    "    train_kwargs = dict(\n",
    "        X_test=X_queries_test,\n",
    "        relevance_test=relevances_test,\n",
    "        group_membership_test=majority_status_test,\n",
    "        num_items_per_query=num_docs_per_query,\n",
    "        weights=None,\n",
    "        n_units=[],  # fresh list on every call\n",
    "        lr=0.04,\n",
    "        batch_size=1,\n",
    "        epoch=20*num_queries,\n",
    "        verbose=True,\n",
    "        activ_f=tf.nn.relu,\n",
    "        l2_reg=0.0,\n",
    "        plot=True,\n",
    "        sens_directions=sens_directions,\n",
    "        seed=None,\n",
    "        simul=True,  # need to make this true if you want to make plots\n",
    "        num_monte_carlo_samples=10,\n",
    "        bias=False,\n",
    "        init_range=.0001,\n",
    "        entropy_regularizer=.0,\n",
    "        baseline_ndcg=True,\n",
    "    )\n",
    "    train_kwargs.update(overrides)\n",
    "    return train_fair_nn(X_queries, relevances, majority_status,\n",
    "                         tf_prefix=tf_prefix, **train_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# baseline\n",
    "_ = train_ranker('baseline', fair_start=1., load=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train SenSeI with different fair regularization strengths.\n",
    "# Keep these values in sync with the Z_<fair_reg>.npy and heatmap file names\n",
    "# loaded in the plotting cells.\n",
    "for fair_reg in [.0003, .001]:\n",
    "    print('fair_reg', fair_reg)\n",
    "    _ = train_ranker(\n",
    "        'sensei',\n",
    "        lamb_init=2.,\n",
    "        adv_epoch=20,\n",
    "        adv_step=.001,\n",
    "        epsilon=0.001,\n",
    "        l2_attack=0.001,\n",
    "        adv_epoch_full=20,\n",
    "        fair_reg=fair_reg,\n",
    "        fair_start=0.,\n",
    "        load=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# run fair-pg-rank with the kernel's own interpreter; a bare `python` may resolve\n",
    "# to a different environment on PATH\n",
    "!\"{sys.executable}\" PG.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot results\n",
    "\n",
    "Load the saved training data and pick the two items highlighted with stars in the contour plots: a minority item (blue) and a majority item (black)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = np.load('data/X.npy')\n",
    "relevance_train = np.load('data/relevance.npy')\n",
    "xx = np.load('data/xx.npy')\n",
    "yy = np.load('data/yy.npy')\n",
    "\n",
    "# 1-D view of the relevances for the searches below (tolerates an (N, 1) array).\n",
    "relevance_flat = np.ravel(relevance_train)\n",
    "\n",
    "# Minority star: among minority points (X[:, 1] == 0) with X[:, 0] < 1.6, the first one\n",
    "# with the highest strictly positive relevance; falls back to index 0 if there is none.\n",
    "minority_candidates = np.flatnonzero((X_train[:, 1] == 0) & (X_train[:, 0] < 1.6) & (relevance_flat > 0))\n",
    "if minority_candidates.size:\n",
    "    minority_idx = minority_candidates[np.argmax(relevance_flat[minority_candidates])]\n",
    "    minority_relevance = relevance_flat[minority_idx]\n",
    "else:\n",
    "    minority_idx, minority_relevance = 0, 0\n",
    "print(minority_relevance)\n",
    "\n",
    "# Majority star: the first point with X[:, 1] > 2.9 and 1.2 < X[:, 0] < 1.7 whose relevance\n",
    "# lies within 5 of majority_reference_relevance.\n",
    "# NOTE(review): the reference is 0, not minority_relevance -- confirm this is intended.\n",
    "majority_reference_relevance = 0\n",
    "majority_candidates = np.flatnonzero(\n",
    "    (X_train[:, 1] > 2.9) & (X_train[:, 0] < 1.7) & (X_train[:, 0] > 1.2)\n",
    "    & (np.abs(relevance_flat - majority_reference_relevance) < 5))\n",
    "if majority_candidates.size:\n",
    "    majority_idx = majority_candidates[0]\n",
    "else:\n",
    "    majority_idx = 0\n",
    "    print('warning: no majority item matched the search; highlighting index 0')\n",
    "print(relevance_flat[majority_idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_score_panel(ax, xx, yy, Z_grid, title, X, relevance, minority_idx, majority_idx, cmap):\n",
    "    \"\"\"Draw one contour panel: Z_grid over (xx, yy), the points coloured by relevance,\n",
    "    and the minority (blue) and majority (black) stars.\n",
    "\n",
    "    Returns the scatter of all points so a relevance colorbar can be attached to it.\n",
    "    \"\"\"\n",
    "    ax.contourf(xx, yy, Z_grid, cmap='bwr', alpha=.8)\n",
    "    ax.set_title(title, fontsize=25)\n",
    "    points = ax.scatter(X[:, 0], X[:, 1], cmap=cmap, c=relevance)\n",
    "    ax.scatter(X[minority_idx, 0], 0, marker='*', s=250, color='blue')\n",
    "    ax.scatter(X[majority_idx, 0], X[majority_idx, 1], marker='*', s=300, color='black')\n",
    "    return points\n",
    "\n",
    "\n",
    "# (contour data file, panel title), left to right\n",
    "score_panels = [\n",
    "    ('data/Z_0.0.npy', 'Baseline: $\\\\rho=0$'),\n",
    "    ('data/Z_0.0003.npy', 'SenSTIR: $\\\\rho=.0003$'),  # final group_exposure_test_stochastic 0.0103\n",
    "    ('data/Z_0.001.npy', 'SenSTIR: $\\\\rho=.001$'),  # final group_exposure_test_stochastic 0.008\n",
    "]\n",
    "\n",
    "fig_scores = plt.figure(figsize=(15, 5))\n",
    "grid = ImageGrid(fig_scores, 111,          # as in plt.subplot(111)\n",
    "                 nrows_ncols=(1,3),\n",
    "                 axes_pad=0.15,\n",
    "                 share_all=True,\n",
    "                 cbar_location=\"right\",\n",
    "                 cbar_mode=\"single\",\n",
    "                 cbar_size=\"7%\",\n",
    "                 cbar_pad=0.15,\n",
    "                 )\n",
    "\n",
    "plt.rc('xtick', labelsize=20)\n",
    "plt.rc('ytick', labelsize=20)\n",
    "\n",
    "for ax, (z_path, title) in zip(grid.axes_all, score_panels):\n",
    "    Z_grid = np.load(z_path).reshape(xx.shape)\n",
    "    points = plot_score_panel(ax, xx, yy, Z_grid, title, X_train, relevance_train,\n",
    "                              minority_idx, majority_idx, cmap='PiYG_r')\n",
    "\n",
    "# Every panel colours the points by the same relevances, so one colorbar (from the\n",
    "# last panel) covers all three.\n",
    "grid[2].cax.colorbar(points)\n",
    "grid[2].cax.toggle_label(True)\n",
    "\n",
    "fig_scores.savefig('synthetic.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_heatmap_panel(ax, heatmap, title, vmin, vmax):\n",
    "    \"\"\"Draw one heatmap panel (ticks 0-9 labelled 1-10) on the shared [vmin, vmax] scale.\n",
    "\n",
    "    Returns the AxesImage so a colorbar can be attached to it.\n",
    "    \"\"\"\n",
    "    ax.set_title(title, fontsize=17)\n",
    "    ax.set_xticks(np.arange(10))\n",
    "    ax.set_xticklabels(np.arange(11)[1:])\n",
    "    ax.set_yticks(np.arange(10))\n",
    "    ax.set_yticklabels(np.arange(11)[1:])\n",
    "    return ax.imshow(heatmap, vmin=vmin, vmax=vmax, aspect='auto')\n",
    "\n",
    "\n",
    "fig_heatmaps = plt.figure(figsize=(15, 4))\n",
    "grid = ImageGrid(fig_heatmaps, 111,          # as in plt.subplot(111)\n",
    "                 nrows_ncols=(1,4),\n",
    "                 axes_pad=0.15,\n",
    "                 share_all=True,\n",
    "                 cbar_location=\"right\",\n",
    "                 cbar_mode=\"single\",\n",
    "                 cbar_size=\"7%\",\n",
    "                 cbar_pad=0.15,\n",
    "                 )\n",
    "\n",
    "plt.rc('xtick', labelsize=15)\n",
    "plt.rc('ytick', labelsize=15)\n",
    "\n",
    "# File names encode the training hyperparameters (e.g. epoch:2000 == 20*num_queries);\n",
    "# update them if the training settings change.\n",
    "baseline_heatmap = np.load('heatmaps/baseline_adv-epoch:100_batch_size:1_adv-step:1.0_l2_attack:0.01_adv_epoch_full:10_epsilon:None_lr:0.04_MC:10_reg:0.0_epoch:2000_l2reg:0.0_init_range:0.0001_arch:_heatmap_test_stochastic_0.npy')\n",
    "sensei_heatmap_1 = np.load('heatmaps/sensei_adv-epoch:20_batch_size:1_adv-step:0.001_l2_attack:0.001_adv_epoch_full:20_epsilon:0.001_lr:0.04_MC:10_reg:0.0003_epoch:2000_l2reg:0.0_init_range:0.0001_arch:_heatmap_test_stochastic_0.npy')\n",
    "sensei_heatmap_2 = np.load('heatmaps/sensei_adv-epoch:20_batch_size:1_adv-step:0.001_l2_attack:0.001_adv_epoch_full:20_epsilon:0.001_lr:0.04_MC:10_reg:0.001_epoch:2000_l2reg:0.0_init_range:0.0001_arch:_heatmap_test_stochastic_0.npy')\n",
    "PG_heatmap = np.load('heatmaps/sensei_adv-epoch:100_batch_size:1_adv-step:1.0_l2_attack:0.01_adv_epoch_full:10_epsilon:None_lr:0.04_MC:10_reg:0.0_epoch:0_l2reg:0.0_init_range:0.0001_arch:_heatmap_test_stochastic_0.npy')\n",
    "\n",
    "# One colour scale for all four panels so they are directly comparable. (Named\n",
    "# heatmap_vmin/heatmap_vmax so the builtins min/max are not shadowed.)\n",
    "all_heatmaps = np.concatenate((baseline_heatmap, sensei_heatmap_1, sensei_heatmap_2, PG_heatmap))\n",
    "heatmap_vmin, heatmap_vmax = np.min(all_heatmaps), np.max(all_heatmaps)\n",
    "\n",
    "# (grid position, heatmap, title), in the order the panels are drawn\n",
    "heatmap_panels = [\n",
    "    (0, baseline_heatmap, 'Baseline: $\\\\rho=0$'),\n",
    "    (2, sensei_heatmap_1, 'SenSTIR: $\\\\rho=.0003$'),\n",
    "    (3, sensei_heatmap_2, 'SenSTIR: $\\\\rho=.001$'),\n",
    "    (1, PG_heatmap, 'Fair-PG-Rank: $\\\\lambda=25$'),\n",
    "]\n",
    "heatmap_images = {}\n",
    "for position, heatmap, title in heatmap_panels:\n",
    "    heatmap_images[position] = plot_heatmap_panel(grid[position], heatmap, title,\n",
    "                                                  heatmap_vmin, heatmap_vmax)\n",
    "\n",
    "grid[3].cax.colorbar(heatmap_images[3])\n",
    "grid[3].cax.toggle_label(True)\n",
    "\n",
    "fig_heatmaps.savefig('heatmap.pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
