{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/Trusted-AI/AIF360/blob/master/examples/sklearn/demo_grid_search_reduction_classification_sklearn.ipynb)\n"
      ],
      "metadata": {
        "id": "Zny_LW9qx_vx"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wXvOUWF3e1_5"
      },
      "source": [
        "# Sklearn-compatible Grid Search for classification\n",
        "\n",
        "Grid search is an in-processing technique that can be used for fair classification or fair regression. For classification, it reduces fair classification to a sequence of cost-sensitive classification problems and returns, among the candidates searched, the deterministic classifier with the lowest empirical error subject to the fairness constraints. The code wraps the class `fairlearn.reductions.GridSearch` from the [fairlearn](https://github.com/fairlearn/fairlearn) library, licensed under the MIT License, Copyright Microsoft Corporation."
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Install aif360 with its Reductions extra, which also installs fairlearn\n",
        "!pip install aif360[Reductions]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YbWByiJFfPP5",
        "outputId": "d7f5cc84-4c95-4271-a7de-3cc4e4035d97"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Requirement already satisfied: aif360[Reductions] in /usr/local/lib/python3.7/dist-packages (0.5.0)\n",
            "Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from aif360[Reductions]) (1.21.6)\n",
            "Requirement already satisfied: scikit-learn>=1.0 in /usr/local/lib/python3.7/dist-packages (from aif360[Reductions]) (1.0.2)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from aif360[Reductions]) (3.2.2)\n",
            "Requirement already satisfied: pandas>=0.24.0 in /usr/local/lib/python3.7/dist-packages (from aif360[Reductions]) (1.3.5)\n",
            "Requirement already satisfied: scipy>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from aif360[Reductions]) (1.7.3)\n",
            "Requirement already satisfied: fairlearn~=0.7 in /usr/local/lib/python3.7/dist-packages (from aif360[Reductions]) (0.7.0)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.24.0->aif360[Reductions]) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.24.0->aif360[Reductions]) (2022.2.1)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas>=0.24.0->aif360[Reductions]) (1.15.0)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=1.0->aif360[Reductions]) (3.1.0)\n",
            "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=1.0->aif360[Reductions]) (1.1.0)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->aif360[Reductions]) (3.0.9)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->aif360[Reductions]) (1.4.4)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->aif360[Reductions]) (0.11.0)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from kiwisolver>=1.0.1->matplotlib->aif360[Reductions]) (4.1.1)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "New8gRgee1_-"
      },
      "outputs": [],
      "source": [
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "gUpSgWaAe1__"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "from sklearn.metrics import accuracy_score\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "from aif360.sklearn.inprocessing import GridSearchReduction\n",
        "\n",
        "from aif360.sklearn.datasets import fetch_adult\n",
        "from aif360.sklearn.metrics import average_odds_error"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dCde3Zgle2AA"
      },
      "source": [
        "### Loading data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bzyxulcxe2AA"
      },
      "source": [
        "Datasets are formatted as separate `X` (# samples x # features) and `y` (# samples x # labels) DataFrames. The index of each DataFrame holds the protected attribute values for each sample. Datasets may also load a `sample_weight` object to be used with certain algorithms/metrics. This layout keeps aif360 compatible with scikit-learn objects.\n",
        "\n",
        "For example, we can easily load the Adult dataset from UCI with the following line:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 368
        },
        "id": "izBC2P4be2AB",
        "outputId": "43217f1c-0602-4a7f-f036-863de074b681"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "                 age  workclass     education  education-num  \\\n",
              "race      sex                                                  \n",
              "Non-white Male  25.0    Private          11th            7.0   \n",
              "White     Male  38.0    Private       HS-grad            9.0   \n",
              "          Male  28.0  Local-gov    Assoc-acdm           12.0   \n",
              "Non-white Male  44.0    Private  Some-college           10.0   \n",
              "White     Male  34.0    Private          10th            6.0   \n",
              "\n",
              "                    marital-status         occupation   relationship   race  \\\n",
              "race      sex                                                                 \n",
              "Non-white Male       Never-married  Machine-op-inspct      Own-child  Black   \n",
              "White     Male  Married-civ-spouse    Farming-fishing        Husband  White   \n",
              "          Male  Married-civ-spouse    Protective-serv        Husband  White   \n",
              "Non-white Male  Married-civ-spouse  Machine-op-inspct        Husband  Black   \n",
              "White     Male       Never-married      Other-service  Not-in-family  White   \n",
              "\n",
              "                 sex  capital-gain  capital-loss  hours-per-week  \\\n",
              "race      sex                                                      \n",
              "Non-white Male  Male           0.0           0.0            40.0   \n",
              "White     Male  Male           0.0           0.0            50.0   \n",
              "          Male  Male           0.0           0.0            40.0   \n",
              "Non-white Male  Male        7688.0           0.0            40.0   \n",
              "White     Male  Male           0.0           0.0            30.0   \n",
              "\n",
              "               native-country  \n",
              "race      sex                  \n",
              "Non-white Male  United-States  \n",
              "White     Male  United-States  \n",
              "          Male  United-States  \n",
              "Non-white Male  United-States  \n",
              "White     Male  United-States  "
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-5b05e4d7-2fc1-448e-82c7-08d68d6c62bb\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th>age</th>\n",
              "      <th>workclass</th>\n",
              "      <th>education</th>\n",
              "      <th>education-num</th>\n",
              "      <th>marital-status</th>\n",
              "      <th>occupation</th>\n",
              "      <th>relationship</th>\n",
              "      <th>race</th>\n",
              "      <th>sex</th>\n",
              "      <th>capital-gain</th>\n",
              "      <th>capital-loss</th>\n",
              "      <th>hours-per-week</th>\n",
              "      <th>native-country</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>race</th>\n",
              "      <th>sex</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>Non-white</th>\n",
              "      <th>Male</th>\n",
              "      <td>25.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>11th</td>\n",
              "      <td>7.0</td>\n",
              "      <td>Never-married</td>\n",
              "      <td>Machine-op-inspct</td>\n",
              "      <td>Own-child</td>\n",
              "      <td>Black</td>\n",
              "      <td>Male</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>40.0</td>\n",
              "      <td>United-States</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th rowspan=\"2\" valign=\"top\">White</th>\n",
              "      <th>Male</th>\n",
              "      <td>38.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>HS-grad</td>\n",
              "      <td>9.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Farming-fishing</td>\n",
              "      <td>Husband</td>\n",
              "      <td>White</td>\n",
              "      <td>Male</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>50.0</td>\n",
              "      <td>United-States</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Male</th>\n",
              "      <td>28.0</td>\n",
              "      <td>Local-gov</td>\n",
              "      <td>Assoc-acdm</td>\n",
              "      <td>12.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Protective-serv</td>\n",
              "      <td>Husband</td>\n",
              "      <td>White</td>\n",
              "      <td>Male</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>40.0</td>\n",
              "      <td>United-States</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Non-white</th>\n",
              "      <th>Male</th>\n",
              "      <td>44.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>Some-college</td>\n",
              "      <td>10.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Machine-op-inspct</td>\n",
              "      <td>Husband</td>\n",
              "      <td>Black</td>\n",
              "      <td>Male</td>\n",
              "      <td>7688.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>40.0</td>\n",
              "      <td>United-States</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>White</th>\n",
              "      <th>Male</th>\n",
              "      <td>34.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>10th</td>\n",
              "      <td>6.0</td>\n",
              "      <td>Never-married</td>\n",
              "      <td>Other-service</td>\n",
              "      <td>Not-in-family</td>\n",
              "      <td>White</td>\n",
              "      <td>Male</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>30.0</td>\n",
              "      <td>United-States</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-5b05e4d7-2fc1-448e-82c7-08d68d6c62bb')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-5b05e4d7-2fc1-448e-82c7-08d68d6c62bb button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-5b05e4d7-2fc1-448e-82c7-08d68d6c62bb');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ],
      "source": [
        "X, y, sample_weight = fetch_adult()\n",
        "X.head()"
      ]
    },
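    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the protected attributes live in the index rather than in the feature columns, they can be read back with ordinary pandas index operations. The cell below is a minimal sketch on a hypothetical toy frame: the name `toy` and its values are made up for illustration and are not part of the Adult dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "\n",
        "# Hypothetical toy frame laid out like the dataset above:\n",
        "# features in the columns, protected attributes in the index.\n",
        "toy = pd.DataFrame(\n",
        "    {'age': [25.0, 38.0, 28.0], 'hours-per-week': [40.0, 50.0, 40.0]},\n",
        "    index=pd.MultiIndex.from_arrays(\n",
        "        [['Non-white', 'White', 'White'], ['Male', 'Male', 'Female']],\n",
        "        names=['race', 'sex']))\n",
        "\n",
        "# Protected attribute values, one per sample.\n",
        "print(toy.index.get_level_values('sex'))\n",
        "\n",
        "# Select the rows of one group through the index level.\n",
        "print(toy.xs('White', level='race'))"
      ]
    },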
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tgHCkDc6e2AC"
      },
      "source": [
        "We can then map the protected attributes to integers,"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "spqjtesUe2AC"
      },
      "outputs": [],
      "source": [
        "X.index = pd.MultiIndex.from_arrays(X.index.codes, names=X.index.names)\n",
        "y.index = pd.MultiIndex.from_arrays(y.index.codes, names=y.index.names)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qwIA6rjde2AD"
      },
      "source": [
        "and the target classes to 0/1,"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "Stuowp_fe2AE"
      },
      "outputs": [],
      "source": [
        "y = pd.Series(y.factorize(sort=True)[0], index=y.index)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sj-kKu7ce2AE"
      },
      "source": [
        "and split the dataset into training (70%) and test (30%) sets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "WhunxmLGe2AF"
      },
      "outputs": [],
      "source": [
        "(X_train, X_test,\n",
        " y_train, y_test) = train_test_split(X, y, train_size=0.7, random_state=1234567)"
      ]
    },
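    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two encoding steps above can be sketched on a small made-up dataset. The names `toy_X` and `toy_y` and their values are hypothetical and only mimic the shape of the Adult data. `MultiIndex.codes` stores, for each index level, the integer position of every row's label within that level's `levels`, so rebuilding the index from the codes replaces labels with integers. `factorize(sort=True)` does the same for the target, sorting the unique classes before assigning codes, so the code given to a class does not depend on the order in which classes appear."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "\n",
        "# Hypothetical toy data; protected attributes are stored in the index.\n",
        "toy_index = pd.MultiIndex.from_arrays(\n",
        "    [['Non-white', 'White', 'White'], ['Male', 'Female', 'Male']],\n",
        "    names=['race', 'sex'])\n",
        "toy_X = pd.DataFrame({'age': [25.0, 38.0, 28.0]}, index=toy_index)\n",
        "toy_y = pd.Series(['<=50K', '>50K', '<=50K'], index=toy_index)\n",
        "\n",
        "# Replace each index label with its integer code within its level.\n",
        "toy_X.index = pd.MultiIndex.from_arrays(toy_X.index.codes, names=toy_X.index.names)\n",
        "toy_y.index = pd.MultiIndex.from_arrays(toy_y.index.codes, names=toy_y.index.names)\n",
        "\n",
        "# Encode the target classes as integers; classes are sorted before coding.\n",
        "codes, classes = toy_y.factorize(sort=True)\n",
        "toy_y = pd.Series(codes, index=toy_y.index)\n",
        "\n",
        "print(classes)\n",
        "print(toy_y)"
      ]
    },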
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tvm6f_fSe2AF"
      },
      "source": [
        "We one-hot encode the categorical features with pandas so that the columns derived from protected attributes can be referenced by name, which grid search reduction requires."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 366
        },
        "id": "ILuKRrQ2e2AF",
        "outputId": "65a3ce8b-af27-4168-9e0e-3d0922a39d28"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "           age  education-num  capital-gain  capital-loss  hours-per-week  \\\n",
              "race sex                                                                    \n",
              "1    1    58.0           11.0           0.0           0.0            42.0   \n",
              "     0    51.0           12.0           0.0           0.0            30.0   \n",
              "     1    26.0           14.0           0.0        1887.0            40.0   \n",
              "     1    44.0            3.0           0.0           0.0            40.0   \n",
              "     1    33.0            6.0           0.0           0.0            40.0   \n",
              "\n",
              "          workclass_Private  workclass_Self-emp-not-inc  \\\n",
              "race sex                                                  \n",
              "1    1                    0                           1   \n",
              "     0                    0                           1   \n",
              "     1                    1                           0   \n",
              "     1                    1                           0   \n",
              "     1                    1                           0   \n",
              "\n",
              "          workclass_Self-emp-inc  workclass_Federal-gov  workclass_Local-gov  \\\n",
              "race sex                                                                       \n",
              "1    1                         0                      0                    0   \n",
              "     0                         0                      0                    0   \n",
              "     1                         0                      0                    0   \n",
              "     1                         0                      0                    0   \n",
              "     1                         0                      0                    0   \n",
              "\n",
              "          ...  native-country_Guatemala  native-country_Nicaragua  \\\n",
              "race sex  ...                                                       \n",
              "1    1    ...                         0                         0   \n",
              "     0    ...                         0                         0   \n",
              "     1    ...                         0                         0   \n",
              "     1    ...                         0                         0   \n",
              "     1    ...                         0                         0   \n",
              "\n",
              "          native-country_Scotland  native-country_Thailand  \\\n",
              "race sex                                                     \n",
              "1    1                          0                        0   \n",
              "     0                          0                        0   \n",
              "     1                          0                        0   \n",
              "     1                          0                        0   \n",
              "     1                          0                        0   \n",
              "\n",
              "          native-country_Yugoslavia  native-country_El-Salvador  \\\n",
              "race sex                                                          \n",
              "1    1                            0                           0   \n",
              "     0                            0                           0   \n",
              "     1                            0                           0   \n",
              "     1                            0                           0   \n",
              "     1                            0                           0   \n",
              "\n",
              "          native-country_Trinadad&Tobago  native-country_Peru  \\\n",
              "race sex                                                        \n",
              "1    1                                 0                    0   \n",
              "     0                                 0                    0   \n",
              "     1                                 0                    0   \n",
              "     1                                 0                    0   \n",
              "     1                                 0                    0   \n",
              "\n",
              "          native-country_Hong  native-country_Holand-Netherlands  \n",
              "race sex                                                          \n",
              "1    1                      0                                  0  \n",
              "     0                      0                                  0  \n",
              "     1                      0                                  0  \n",
              "     1                      0                                  0  \n",
              "     1                      0                                  0  \n",
              "\n",
              "[5 rows x 102 columns]"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-89177a14-ec10-47b4-a8c2-c9135d3c59ef\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th>age</th>\n",
              "      <th>education-num</th>\n",
              "      <th>capital-gain</th>\n",
              "      <th>capital-loss</th>\n",
              "      <th>hours-per-week</th>\n",
              "      <th>workclass_Private</th>\n",
              "      <th>workclass_Self-emp-not-inc</th>\n",
              "      <th>workclass_Self-emp-inc</th>\n",
              "      <th>workclass_Federal-gov</th>\n",
              "      <th>workclass_Local-gov</th>\n",
              "      <th>...</th>\n",
              "      <th>native-country_Guatemala</th>\n",
              "      <th>native-country_Nicaragua</th>\n",
              "      <th>native-country_Scotland</th>\n",
              "      <th>native-country_Thailand</th>\n",
              "      <th>native-country_Yugoslavia</th>\n",
              "      <th>native-country_El-Salvador</th>\n",
              "      <th>native-country_Trinadad&amp;Tobago</th>\n",
              "      <th>native-country_Peru</th>\n",
              "      <th>native-country_Hong</th>\n",
              "      <th>native-country_Holand-Netherlands</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>race</th>\n",
              "      <th>sex</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th rowspan=\"5\" valign=\"top\">1</th>\n",
              "      <th>1</th>\n",
              "      <td>58.0</td>\n",
              "      <td>11.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>42.0</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>...</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>51.0</td>\n",
              "      <td>12.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>30.0</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>...</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>26.0</td>\n",
              "      <td>14.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1887.0</td>\n",
              "      <td>40.0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>...</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>44.0</td>\n",
              "      <td>3.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>40.0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>...</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>33.0</td>\n",
              "      <td>6.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>40.0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>...</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>5 rows × 102 columns</p>\n",
              "</div>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ],
      "source": [
        "X_train, X_test = pd.get_dummies(X_train), pd.get_dummies(X_test)\n",
        "X_train = X_train.drop(columns=['sex_Female'])\n",
        "X_test = X_test.drop(columns=['sex_Female'])\n",
        "X_train.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sDhjpwBPe2AG"
      },
      "source": [
        "The protected attribute information is also replicated in the index of the labels:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OrPSVH1fe2AG",
        "outputId": "2f82868b-cbef-4872-fb4b-44b700682c05"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "race  sex\n",
              "1     1      0\n",
              "      0      1\n",
              "      1      1\n",
              "      1      0\n",
              "      1      0\n",
              "dtype: int64"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ],
      "source": [
        "y_train.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SIvMIxDUe2AG"
      },
      "source": [
        "### Running metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eaTRgdT2e2AG"
      },
      "source": [
        "With the data in this format, we can easily train a scikit-learn model and get predictions for the test data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "D1T_mQGve2AH",
        "outputId": "35934dd1-a77e-4bf5-9302-ca38b705f3c7"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.8453600648632712\n"
          ]
        }
      ],
      "source": [
        "y_pred = LogisticRegression(solver='liblinear', random_state=1234).fit(X_train, y_train).predict(X_test)\n",
        "lr_acc = accuracy_score(y_test, y_pred)\n",
        "print(lr_acc)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bj58DvyTe2AH"
      },
      "source": [
        "We can assess how close the predictions are to equality of odds.\n",
        "\n",
        "`average_odds_error()` computes the (unweighted) average of the absolute values of the true positive rate (TPR) difference and false positive rate (FPR) difference, i.e.:\n",
        "\n",
        "$$ \\tfrac{1}{2}\\left(|FPR_{D = \\text{unprivileged}} - FPR_{D = \\text{privileged}}| + |TPR_{D = \\text{unprivileged}} - TPR_{D = \\text{privileged}}|\\right) $$"
      ]
    },
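    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a worked illustration with made-up rates (assumed for this example only, not taken from the data), suppose the unprivileged group has $FPR = 0.10$ and $TPR = 0.60$, while the privileged group has $FPR = 0.20$ and $TPR = 0.70$. The metric would then be:\n",
        "\n",
        "$$ \\tfrac{1}{2}\\left(|0.10 - 0.20| + |0.60 - 0.70|\\right) = \\tfrac{1}{2}(0.10 + 0.10) = 0.10 $$\n",
        "\n",
        "A value of 0 would mean that both rates match across the two groups, which is what equality of odds asks for."
      ]
    },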
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0JnkfgpPe2AH",
        "outputId": "558416e0-c295-48d9-de8c-aa88071f64a9"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.09356509680536546\n"
          ]
        }
      ],
      "source": [
        "lr_aoe = average_odds_error(y_test, y_pred, prot_attr='sex')\n",
        "print(lr_aoe)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZfvXUkFLe2AI"
      },
      "source": [
        "### Grid Search"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cIQBEvyqe2AI"
      },
      "source": [
        "Choose a base model for the candidate classifiers. The base model must implement a `fit` method that accepts a `sample_weight` argument. For details, refer to the `GridSearchReduction` documentation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "QjucJdOne2AI"
      },
      "outputs": [],
      "source": [
        "estimator = LogisticRegression(solver='liblinear', random_state=1234)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PP_Hx3QYe2AI"
      },
      "source": [
        "Identify the column(s) associated with the protected attribute(s). Grid search can handle more than one protected attribute, but doing so is computationally expensive. Exponentiated gradient reduction is a similar method with less computational overhead; see [examples/sklearn/demo_exponentiated_gradient_reduction_sklearn.ipynb](sklearn/demo_exponentiated_gradient_reduction_sklearn.ipynb)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "3y0DSLaWe2AI"
      },
      "outputs": [],
      "source": [
        "prot_attr = 'sex_Male'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ATCopuWVe2AI"
      },
      "source": [
        "Search for the best classifier under the \"EqualizedOdds\" constraint and observe its test accuracy. Other options for `constraints` include \"DemographicParity\", \"TruePositiveRateParity\", \"FalsePositiveRateParity\", and \"ErrorRateParity\"."
      ]
    },
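    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, two of these constraints can be written out as follows. This is a sketch in standard notation rather than the library's exact formulation; the symbols $\\hat{Y}$ (prediction) and $Y$ (true label) are introduced here for illustration, and $D$ denotes the protected attribute.\n",
        "\n",
        "Demographic parity: for every group $d$,\n",
        "\n",
        "$$ P(\\hat{Y} = 1 \\mid D = d) = P(\\hat{Y} = 1) $$\n",
        "\n",
        "Equalized odds: for every group $d$ and every label $y \\in \\{0, 1\\}$,\n",
        "\n",
        "$$ P(\\hat{Y} = 1 \\mid D = d, Y = y) = P(\\hat{Y} = 1 \\mid Y = y) $$\n",
        "\n",
        "Requiring the second condition only for $y = 1$ corresponds to true positive rate parity, and only for $y = 0$ to false positive rate parity."
      ]
    },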
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FVlSIaI0e2AI",
        "outputId": "c7323edb-dcc5-43cd-d4af-98fbe6bbc33b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.8458760227021449\n"
          ]
        }
      ],
      "source": [
        "np.random.seed(0)  # needed for reproducibility\n",
        "grid_search_red = GridSearchReduction(prot_attr=prot_attr,\n",
        "                                      estimator=estimator,\n",
        "                                      constraints=\"EqualizedOdds\",\n",
        "                                      grid_size=20,\n",
        "                                      drop_prot_attr=False)\n",
        "grid_search_red.fit(X_train, y_train)\n",
        "gs_acc = grid_search_red.score(X_test, y_test)\n",
        "print(gs_acc)\n",
        "\n",
        "# Check that accuracy is comparable to the unconstrained baseline\n",
        "assert abs(lr_acc - gs_acc) < 0.03"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7x-yQ8cee2AL",
        "outputId": "ab502622-bcf8-438e-cb6a-a61aad1ad122"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.05787745779072595\n"
          ]
        }
      ],
      "source": [
        "gs_aoe = average_odds_error(y_test, grid_search_red.predict(X_test), prot_attr='sex')\n",
        "print(gs_aoe)\n",
        "\n",
        "# Check that average odds error improved over the baseline\n",
        "assert gs_aoe < lr_aoe"
      ]
    },
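    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Using the values printed above for this particular run, the change from the unconstrained logistic regression to the grid search model works out roughly as follows (rounded figures):\n",
        "\n",
        "$$ 0.84588 - 0.84536 \\approx 0.0005, \\qquad \\frac{0.09357 - 0.05788}{0.09357} \\approx 0.38 $$\n",
        "\n",
        "In other words, test accuracy is essentially unchanged while the average odds error drops by about 38%. A different split or seed may give different numbers."
      ]
    },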
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ybsQPrKNe2AL"
      },
      "source": [
        "Instead of passing a string value for `constraints`, we can also pass a `fairlearn.reductions.Moment` object. You can use a predefined moment, as we do below, or create a custom moment using the fairlearn library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "scrolled": true,
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Dg0Pezure2AL",
        "outputId": "af238e02-a62d-4e71-a664-d77798a5cf1f"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.8458760227021449"
            ]
          },
          "metadata": {},
          "execution_count": 16
        }
      ],
      "source": [
        "import fairlearn.reductions as red\n",
        "\n",
        "np.random.seed(0)  # needed for reproducibility\n",
        "grid_search_red = GridSearchReduction(prot_attr=prot_attr,\n",
        "                                      estimator=estimator,\n",
        "                                      constraints=red.EqualizedOdds(),\n",
        "                                      grid_size=20,\n",
        "                                      drop_prot_attr=False)\n",
        "grid_search_red.fit(X_train, y_train)\n",
        "grid_search_red.score(X_test, y_test)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RZH7ZOFVe2AL",
        "outputId": "ea6ce7cc-3dcb-4dba-d238-b81092a91cf6"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.05787745779072595"
            ]
          },
          "metadata": {},
          "execution_count": 17
        }
      ],
      "source": [
        "average_odds_error(y_test, grid_search_red.predict(X_test), prot_attr='sex')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.9.7 ('aif360')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    },
    "vscode": {
      "interpreter": {
        "hash": "d0c5ced7753e77a483fec8ff7063075635521cce6e0bd54998c8f174742209dd"
      }
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}