{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation of \"Procedural Fairness Through Decoupling Objectionable Data Generating Components\"\n",
    "\n",
    "Experiment on [UCI Adult data set](https://archive.ics.uci.edu/dataset/2/adult).\n",
    "\n",
    "Following Nabi \\& Shpitser (2018) and Chiappa (2019), the potential locations for objectionable components include:\n",
    "- $A \\rightarrow Y$\n",
    "- $A \\rightarrow M \\rightarrow \\cdots \\rightarrow Y$\n",
    "  - $A \\rightarrow M \\rightarrow Y$\n",
    "  - $A \\rightarrow M \\rightarrow L \\rightarrow Y$\n",
    "  - $A \\rightarrow M \\rightarrow R \\rightarrow Y$\n",
    "  - $A \\rightarrow M \\rightarrow L \\rightarrow R \\rightarrow Y$\n",
    "\n",
    "1. We define the causal graph `CausalGraph` by specifying the list of nodes and their parents.\n",
    "1. We construct the prediction model `CustomNetwork` using `SubNetworks` for each local causal module, i.e., for nodes that have parent nodes in the `CausalGraph`.\n",
    "1. We implement training and testing afterwards, followed by the optimization for reference points."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the nonlinear model without fairness constraints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Python package for simulated annealing\n",
    "%pip install -e git+https://github.com/perrygeo/simanneal.git#egg=simanneal"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define causal model and load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import libraries\n",
    "import os\n",
    "import numpy as np\n",
    "from pprint import pprint\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "from utils import (\n",
    "    CustomDataset,\n",
    "    CausalGraph,\n",
    "    CustomMinMaxScaler,\n",
    "    CustomNetwork,\n",
    "    CustomStandardScaler,\n",
    ")\n",
    "from utils import (\n",
    "    train_hierarchical,\n",
    "    evaluate_hierarchical,\n",
    "    load_uci_adult_preprocessed,\n",
    "    TailRefPtConfigAnnealer,\n",
    ")\n",
    "\n",
    "if_train = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define causal graph and prepare data\n",
    "experiment_prefix = \"uci_adult\"\n",
    "relation_type = \"nonlinear\"\n",
    "experiment_prefix = experiment_prefix + \".\" + relation_type\n",
    "is_linear = True if \"linear\" == relation_type else False\n",
    "\n",
    "nodes = [\n",
    "    \"male\",  # sex\n",
    "    \"married\",  # marital-status\n",
    "    \"higher_edu\",  # education-num\n",
    "    \"managerial_occ\",  # occupation\n",
    "    \"high_hours\",  # hours-per-week\n",
    "    \"gov_jobs\",  # work-class\n",
    "    # \"race\",  # not included in prediction process\n",
    "    \"age\",\n",
    "    \"native_country\",  # native-country\n",
    "    \"high_income\",  # income\n",
    "]\n",
    "parent_dict = {\n",
    "    \"male\": [],  # A\n",
    "    \"age\": [],  # C (part)\n",
    "    \"native_country\": [],  # C (part)\n",
    "    \"married\": [\n",
    "        \"male\",\n",
    "        \"age\",\n",
    "        \"native_country\",\n",
    "    ],  # M\n",
    "    \"higher_edu\": [\n",
    "        \"male\",\n",
    "        \"age\",\n",
    "        \"native_country\",\n",
    "        \"married\",\n",
    "    ],  # L\n",
    "    \"managerial_occ\": [\n",
    "        \"male\",\n",
    "        \"age\",\n",
    "        \"native_country\",\n",
    "        \"married\",\n",
    "        \"higher_edu\",\n",
    "    ],  # R1\n",
    "    \"high_hours\": [\n",
    "        \"male\",\n",
    "        \"age\",\n",
    "        \"native_country\",\n",
    "        \"married\",\n",
    "        \"higher_edu\",\n",
    "    ],  # R2\n",
    "    \"gov_jobs\": [\n",
    "        \"male\",\n",
    "        \"age\",\n",
    "        \"native_country\",\n",
    "        \"married\",\n",
    "        \"higher_edu\",\n",
    "    ],  # R3\n",
    "    \"high_income\": [\n",
    "        \"male\",\n",
    "        \"age\",\n",
    "        \"native_country\",\n",
    "        \"married\",\n",
    "        \"higher_edu\",\n",
    "        \"managerial_occ\",\n",
    "        \"high_hours\",\n",
    "        \"gov_jobs\",\n",
    "    ],  # Y\n",
    "}\n",
    "node_types = {node: \"continuous\" for node in nodes}\n",
    "node_types[\"male\"] = \"binary\"\n",
    "node_types[\"high_income\"] = \"binary\"\n",
    "causal_graph = CausalGraph(nodes, parent_dict, node_types)\n",
    "final_output_node = causal_graph.nodes[-1]\n",
    "\n",
    "data_dict, n_samples = load_uci_adult_preprocessed(\n",
    "    nodes, data_file_path=\"data/adult_processed.csv\"\n",
    ")\n",
    "print(f\"Number of samples: {n_samples}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experiment preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify experimental setting\n",
    "learning_rate = 1e-3\n",
    "min_hidden_neurons = 0 if \"linear\" == relation_type else 5\n",
    "if 0 != min_hidden_neurons:\n",
    "    experiment_prefix += f\"_min_hidden_neurons_{min_hidden_neurons}\"\n",
    "\n",
    "if_linear_model = is_linear\n",
    "\n",
    "# GPU index (if use cuda)\n",
    "gpu_idx = 0\n",
    "\n",
    "n_epochs = 100\n",
    "batch_size = 2048\n",
    "checkpoint_interval = 10\n",
    "n_repeats = 1 if if_train is True else 0  # number of runs\n",
    "\n",
    "# Used only in linear models, will be ignored for nonlinear models\n",
    "lambda_linear_L1 = 0.0\n",
    "# No explicit edge coef constraint in our approach\n",
    "edge_linear_coef_constraint_config = None  # set to {...} for linear\n",
    "\n",
    "# Set device\n",
    "device_type = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "device_id = torch.device(f\"cuda:{gpu_idx}\") if \"cuda\" == device_type else \"cpu\"\n",
    "device = torch.device(device_id)\n",
    "experiment_prefix = experiment_prefix + \".\" + device_type\n",
    "print(f\"Device: {device}\")\n",
    "\n",
    "# Set checkpoint directory and file name\n",
    "checkpoint_dir = os.path.join(os.getcwd(), \"checkpoints\")\n",
    "if not os.path.exists(checkpoint_dir):\n",
    "    os.makedirs(checkpoint_dir)\n",
    "checkpoint_file = os.path.join(checkpoint_dir, f\"{experiment_prefix}.pt\")\n",
    "\n",
    "print(f\"Experiment prefix: {experiment_prefix}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data set processing\n",
    "\n",
    "- For features: shift and scale to 0 mean and 1 variance\n",
    "- For targets: shift and scale to [0, 1] without clipping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data preprocessing and data set preparation\n",
    "# Split the indices into train and test indices (fix testing indices)\n",
    "train_indices, test_indices = train_test_split(\n",
    "    np.arange(n_samples), test_size=0.25, random_state=42\n",
    ")\n",
    "\n",
    "# Create StandardScaler for features and fit on train_indices (later into train_ and val_indices)\n",
    "feature_standard_scaler_dict = CustomStandardScaler()\n",
    "feature_standard_scaler_dict.fit(\n",
    "    data_dict,\n",
    "    [  # excluded nodes that are not features\n",
    "        final_output_node,\n",
    "    ],\n",
    "    train_indices,\n",
    ")\n",
    "\n",
    "# Create MinMaxScaler and scale final output to [0, 1] on train_indices\n",
    "target_minmax_scaler_dict = CustomMinMaxScaler((0, 1), clip=False)\n",
    "target_minmax_scaler_dict.fit(\n",
    "    data_dict,\n",
    "    [  # only include target nodes\n",
    "        final_output_node,\n",
    "    ],\n",
    "    train_indices,\n",
    ")\n",
    "\n",
    "# Transform features and targets sequentially\n",
    "# NOTE use args instead of kwargs to avoid `wrapped()` error of sklearn\n",
    "_training_data_dict_scaled = feature_standard_scaler_dict.transform(\n",
    "    data_dict,\n",
    "    [  # excluded keys\n",
    "        final_output_node,\n",
    "    ],\n",
    "    train_indices,\n",
    ")\n",
    "training_data_dict_scaled = target_minmax_scaler_dict.transform(\n",
    "    _training_data_dict_scaled,\n",
    "    [  # target keys\n",
    "        final_output_node,\n",
    "    ],\n",
    "    np.arange(len(train_indices)),\n",
    ")\n",
    "\n",
    "_testing_data_dict_scaled = feature_standard_scaler_dict.transform(\n",
    "    data_dict,\n",
    "    [  # excluded keys\n",
    "        final_output_node,\n",
    "    ],\n",
    "    test_indices,\n",
    ")\n",
    "testing_data_dict_scaled = target_minmax_scaler_dict.transform(\n",
    "    _testing_data_dict_scaled,\n",
    "    [  # target keys\n",
    "        final_output_node,\n",
    "    ],\n",
    "    np.arange(len(test_indices)),\n",
    ")\n",
    "\n",
    "# Create CustomDataset instances\n",
    "training_dataset = CustomDataset(training_data_dict_scaled)\n",
    "testing_dataset = CustomDataset(testing_data_dict_scaled)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the neural network with cross validation\n",
    "\n",
    "Indicative running time: ~ 2 min\n",
    "\n",
    "If the current experimental setting is already trained previously, skip this section and directly load the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the neural network with cross validation\n",
    "best_checkpoint_val_loss = float(\"inf\")\n",
    "\n",
    "for _ in range(n_repeats):\n",
    "    # Split training_dataset into training, validating\n",
    "    _train_indices, _val_indices = train_test_split(\n",
    "        np.arange(len(training_dataset)), test_size=0.25\n",
    "    )\n",
    "\n",
    "    # Create DataLoader on training_dataset\n",
    "    train_loader = DataLoader(\n",
    "        training_dataset,\n",
    "        batch_size=batch_size,\n",
    "        sampler=SubsetRandomSampler(_train_indices),\n",
    "    )\n",
    "    val_loader = DataLoader(\n",
    "        training_dataset,\n",
    "        batch_size=batch_size,\n",
    "        sampler=SubsetRandomSampler(_val_indices),\n",
    "    )\n",
    "\n",
    "    # Initialize model, criterion, and optimizer\n",
    "    model = CustomNetwork(\n",
    "        causal_graph, min_hidden_neurons=min_hidden_neurons, is_linear=if_linear_model\n",
    "    ).to(device)\n",
    "\n",
    "    criterion_dict = {\n",
    "        node: nn.BCELoss() if submodel.output_type == \"binary\" else nn.MSELoss()\n",
    "        for node, submodel in model.submodels.items()\n",
    "    }\n",
    "\n",
    "    optimizer_dict = {\n",
    "        node: torch.optim.Adam(submodel.parameters(), lr=learning_rate)\n",
    "        for node, submodel in model.submodels.items()\n",
    "    }\n",
    "\n",
    "    model.set_lambda_linear_model_coef_penalty_L1(lambda_linear_L1)\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        train_loss_dict = train_hierarchical(\n",
    "            model,\n",
    "            train_loader,\n",
    "            criterion_dict,\n",
    "            optimizer_dict,\n",
    "            edge_linear_coef_constraint_config,\n",
    "        )\n",
    "        val_loss_dict = evaluate_hierarchical(\n",
    "            model, val_loader, criterion_dict, feature_standard_scaler_dict\n",
    "        )\n",
    "\n",
    "        # Save the top-1 best performance checkpoints\n",
    "        if epoch % checkpoint_interval == 0:\n",
    "            # Sum over all output nodes of local modules\n",
    "            val_loss = sum(val_loss_dict.values())\n",
    "            if val_loss < best_checkpoint_val_loss:\n",
    "                best_checkpoint_val_loss = val_loss\n",
    "\n",
    "                # Save the checkpoint\n",
    "                torch.save(\n",
    "                    {\n",
    "                        \"epoch\": epoch,\n",
    "                        \"model_state_dict\": model.state_dict(),\n",
    "                        # 'optimizer_state_dict': optimizer.state_dict(),\n",
    "                        # 'loss': val_loss,\n",
    "                        \"hyperparams\": {\n",
    "                            \"learning_rate\": learning_rate,\n",
    "                            \"min_hidden_neurons\": min_hidden_neurons,\n",
    "                        },\n",
    "                    },\n",
    "                    checkpoint_file,\n",
    "                )\n",
    "\n",
    "pprint(f\"Best model state (checkpoint_file): {checkpoint_file}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Derive $\\widehat{Y}$ with reference point\n",
    "\n",
    "For better interpretability, the reference point configuration is always in the raw data scale.\n",
    "The corresponding scaling for prediction is included in `CustomNetwork`.\n",
    "\n",
    "In this implementation, disadvantaged individuals are female individuals, but there are different reference point configurations:\n",
    "1. $\\mathcal{E}_{\\mathrm{obj}} = \\{ A \\rightarrow Y \\}$\n",
    "1. $\\mathcal{E}_{\\mathrm{obj}} = \\{ A \\rightarrow Y, M \\rightarrow Y \\}$\n",
    "1. $\\mathcal{E}_{\\mathrm{obj}} = \\{ A \\rightarrow Y, M \\rightarrow R_1, M \\rightarrow Y \\}$\n",
    "1. $\\mathcal{E}_{\\mathrm{obj}} = \\{ A \\rightarrow Y, M \\rightarrow R_1, M \\rightarrow L, M \\rightarrow Y \\}$\n",
    "\n",
    "Note:\n",
    "- There is no fairness constraint enforced when fitting model parameters\n",
    "- The reference points are introduced to decouple objectionable components"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model\n",
    "def load_best_model(checkpoint_dir, experiment_prefix):\n",
    "    files = os.listdir(checkpoint_dir)\n",
    "    matching_files = [f for f in files if f.startswith(experiment_prefix)]\n",
    "    # matching_files.sort(key=lambda x: int(x.split('.')[1].split('_')[1]))\n",
    "    best_checkpoint_file = matching_files[0]\n",
    "    print(best_checkpoint_file)\n",
    "    checkpoint = torch.load(os.path.join(checkpoint_dir, best_checkpoint_file))\n",
    "    model = CustomNetwork(\n",
    "        causal_graph,\n",
    "        min_hidden_neurons=checkpoint[\"hyperparams\"][\"min_hidden_neurons\"],\n",
    "        is_linear=if_linear_model,\n",
    "    ).to(device)\n",
    "    model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
    "    return model, checkpoint[\"hyperparams\"]\n",
    "\n",
    "\n",
    "best_model, best_hyperparams = load_best_model(checkpoint_dir, experiment_prefix)\n",
    "\n",
    "criterion_dict = {\n",
    "    node: nn.BCELoss() if submodel.output_type == \"binary\" else nn.MSELoss()\n",
    "    for node, submodel in best_model.submodels.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define most disadvantaged individuals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DataLoader for data points where REFERENCE_POINT need to be set\n",
    "mask_most_disadv_individuals = testing_dataset.torch_tensor_dict[\"male\"] < 0\n",
    "mask_length = mask_most_disadv_individuals.shape[0]\n",
    "idx_most_disadv_individuals = np.arange(mask_length)[\n",
    "    mask_most_disadv_individuals.reshape(\n",
    "        -1,\n",
    "    )\n",
    "]\n",
    "most_disadv_individuals = Subset(testing_dataset, idx_most_disadv_individuals)\n",
    "\n",
    "record_loader = DataLoader(\n",
    "    most_disadv_individuals, batch_size=len(most_disadv_individuals)\n",
    ")  # only one batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perform simulated annealing to derive the best reference point values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Simulated annealing configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find best reference points by Simulated Annealing\n",
    "# Create MinMaxScaler for features are in its raw scale\n",
    "feature_minmax_scaler_dict = CustomMinMaxScaler((0, 1), clip=False)\n",
    "feature_nodes = nodes.copy()\n",
    "feature_nodes.remove(final_output_node)\n",
    "feature_minmax_scaler_dict.fit(\n",
    "    data_dict, feature_nodes, np.arange(len(train_indices))  # included nodes\n",
    ")\n",
    "\n",
    "# Different number of tail reference points initialization with raw-scale variable values\n",
    "\n",
    "initial_tail_ref_pt_config = {\n",
    "    \"male->high_income\": 1,\n",
    "    # \"married->higher_edu\": 1,\n",
    "    # \"married->managerial_occ\": 1,\n",
    "    \"married->high_income\": 1,\n",
    "}\n",
    "\n",
    "best_overall_tail_ref_pt_config = None\n",
    "best_overall_avg_outcome = float(\"-inf\")\n",
    "\n",
    "# Create an instance of CustomNetworkAnnealer\n",
    "annealer = TailRefPtConfigAnnealer(\n",
    "    state=initial_tail_ref_pt_config,\n",
    "    model=best_model,\n",
    "    dataloader=record_loader,\n",
    "    node_types=node_types,\n",
    "    output_node=final_output_node,\n",
    "    feature_standard_scaler_dict=feature_standard_scaler_dict,\n",
    "    feature_minmax_scaler_dict=feature_minmax_scaler_dict,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Multiple run to find the best reference point values\n",
    "\n",
    "Indicative running time: ~ 1 min per repeat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of times to run the simulated annealing process\n",
    "n_repeats_SA = 1\n",
    "\n",
    "for i in range(n_repeats_SA):\n",
    "    # Create an instance of CustomNetworkAnnealer\n",
    "    annealer = TailRefPtConfigAnnealer(\n",
    "        state=initial_tail_ref_pt_config,\n",
    "        model=best_model,\n",
    "        dataloader=record_loader,\n",
    "        node_types=node_types,\n",
    "        output_node=final_output_node,\n",
    "        feature_standard_scaler_dict=feature_standard_scaler_dict,\n",
    "        feature_minmax_scaler_dict=feature_minmax_scaler_dict,\n",
    "    )\n",
    "\n",
    "    annealer.set_schedule({\"tmax\": 5.0, \"tmin\": 0.001, \"steps\": 1000, \"updates\": 100})\n",
    "\n",
    "    # Run the annealing optimization\n",
    "    best_tail_ref_pt_config, best_avg_outcome = annealer.anneal()\n",
    "\n",
    "    if -best_avg_outcome > best_overall_avg_outcome:\n",
    "        best_overall_tail_ref_pt_config = best_tail_ref_pt_config\n",
    "        best_overall_avg_outcome = -best_avg_outcome\n",
    "\n",
    "    print(\n",
    "        f\"Step {i+1}/{n_repeats_SA}: Best tail reference point configuration: {best_tail_ref_pt_config}\"\n",
    "    )\n",
    "    print(\n",
    "        f\"Step {i+1}/{n_repeats_SA}: Best average outcome: {-best_avg_outcome}\"\n",
    "    )  # Negate to get the maximized outcome\n",
    "\n",
    "print(\"Best overall tail reference point configuration:\")\n",
    "pprint(best_overall_tail_ref_pt_config)\n",
    "print(\"Best overall average outcome:\", best_overall_avg_outcome)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction with objectionable components decoupled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"married\": married is coded 6 or 7\n",
    "tail_ref_pt_config = {\n",
    "    \"male->high_income\": 0,\n",
    "    # \"married->higher_edu\": 7,\n",
    "    # \"married->managerial_occ\": 1,\n",
    "    \"married->high_income\": 7,\n",
    "}\n",
    "\n",
    "test_whole = DataLoader(testing_dataset, batch_size=len(test_indices), shuffle=False)\n",
    "\n",
    "batch_test_whole = next(iter(test_whole))\n",
    "batch_test_whole = {\n",
    "    node: node_tensor.to(device) for node, node_tensor in batch_test_whole.items()\n",
    "}\n",
    "\n",
    "batch_test_whole_output_inference = best_model.inference(\n",
    "    batch_test_whole,\n",
    "    feature_standard_scaler_dict,\n",
    "    tail_ref_pt_config=tail_ref_pt_config,\n",
    "    tail_ref_pt_config_as_dummy=False,\n",
    ")\n",
    "\n",
    "Yhat_ref_pt = batch_test_whole_output_inference[\"high_income\"].numpy()\n",
    "Yhat_ref_pt = (\n",
    "    (Yhat_ref_pt > 0.5)\n",
    "    .astype(int)\n",
    "    .reshape(\n",
    "        -1,\n",
    "    )\n",
    ")\n",
    "\n",
    "print(\"Done. Outputs with CustomNetwork.inference() with reference points.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "workspace",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
