{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/Trusted-AI/AIF360/blob/master/examples/sklearn/demo_grid_search_reduction_regression_sklearn.ipynb)"
      ],
      "metadata": {
        "id": "fgfyb_2c9WL4"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_uAgvnjw71C_"
      },
      "source": [
        "# Sklearn compatible Grid Search for regression\n",
        "\n",
        "Grid search is an in-processing technique that can be used for fair classification or fair regression. For classification, it reduces fair classification to a sequence of cost-sensitive classification problems and returns the deterministic classifier with the lowest empirical error, subject to fair classification constraints, among\n",
        "the candidates searched. For regression, it uses the same principle to return a deterministic regressor with the lowest empirical error subject to the constraint of bounded group loss. The code for grid search wraps the source class `fairlearn.reductions.GridSearch` available in the https://github.com/fairlearn/fairlearn library, licensed under the MIT License, Copyright Microsoft Corporation."
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Install aif360\n",
        "# Install Reductions from Fairlearn\n",
        "# Install Law School GPA dataset\n",
        "!pip install aif360[Reductions,LawSchoolGPA]"
      ],
      "metadata": {
        "id": "4Ad7nC129i8f",
        "outputId": "0ce37c29-d5a5-4682-968e-bb33cea9de19",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Requirement already satisfied: aif360[LawSchoolGPA,Reductions] in /usr/local/lib/python3.7/dist-packages (0.5.0)\n",
            "Requirement already satisfied: scipy>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA,Reductions]) (1.7.3)\n",
            "Requirement already satisfied: scikit-learn>=1.0 in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA,Reductions]) (1.0.2)\n",
            "Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA,Reductions]) (1.21.6)\n",
            "Requirement already satisfied: pandas>=0.24.0 in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA,Reductions]) (1.3.5)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA,Reductions]) (3.2.2)\n",
            "Requirement already satisfied: fairlearn~=0.7 in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA,Reductions]) (0.7.0)\n",
            "Requirement already satisfied: tempeh in /usr/local/lib/python3.7/dist-packages (from aif360[LawSchoolGPA,Reductions]) (0.1.12)\n",
            "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.24.0->aif360[LawSchoolGPA,Reductions]) (2022.2.1)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.24.0->aif360[LawSchoolGPA,Reductions]) (2.8.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas>=0.24.0->aif360[LawSchoolGPA,Reductions]) (1.15.0)\n",
            "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=1.0->aif360[LawSchoolGPA,Reductions]) (1.1.0)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=1.0->aif360[LawSchoolGPA,Reductions]) (3.1.0)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->aif360[LawSchoolGPA,Reductions]) (1.4.4)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->aif360[LawSchoolGPA,Reductions]) (3.0.9)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->aif360[LawSchoolGPA,Reductions]) (0.11.0)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from kiwisolver>=1.0.1->matplotlib->aif360[LawSchoolGPA,Reductions]) (4.1.1)\n",
            "Requirement already satisfied: pytest in /usr/local/lib/python3.7/dist-packages (from tempeh->aif360[LawSchoolGPA,Reductions]) (3.6.4)\n",
            "Requirement already satisfied: memory-profiler in /usr/local/lib/python3.7/dist-packages (from tempeh->aif360[LawSchoolGPA,Reductions]) (0.60.0)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from tempeh->aif360[LawSchoolGPA,Reductions]) (2.23.0)\n",
            "Requirement already satisfied: shap in /usr/local/lib/python3.7/dist-packages (from tempeh->aif360[LawSchoolGPA,Reductions]) (0.41.0)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from memory-profiler->tempeh->aif360[LawSchoolGPA,Reductions]) (5.4.8)\n",
            "Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA,Reductions]) (1.11.0)\n",
            "Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA,Reductions]) (8.14.0)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA,Reductions]) (57.4.0)\n",
            "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA,Reductions]) (22.1.0)\n",
            "Requirement already satisfied: pluggy<0.8,>=0.5 in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA,Reductions]) (0.7.1)\n",
            "Requirement already satisfied: atomicwrites>=1.0 in /usr/local/lib/python3.7/dist-packages (from pytest->tempeh->aif360[LawSchoolGPA,Reductions]) (1.4.1)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->tempeh->aif360[LawSchoolGPA,Reductions]) (1.24.3)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->tempeh->aif360[LawSchoolGPA,Reductions]) (3.0.4)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->tempeh->aif360[LawSchoolGPA,Reductions]) (2022.6.15)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->tempeh->aif360[LawSchoolGPA,Reductions]) (2.10)\n",
            "Requirement already satisfied: numba in /usr/local/lib/python3.7/dist-packages (from shap->tempeh->aif360[LawSchoolGPA,Reductions]) (0.56.2)\n",
            "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/dist-packages (from shap->tempeh->aif360[LawSchoolGPA,Reductions]) (1.5.0)\n",
            "Requirement already satisfied: slicer==0.0.7 in /usr/local/lib/python3.7/dist-packages (from shap->tempeh->aif360[LawSchoolGPA,Reductions]) (0.0.7)\n",
            "Requirement already satisfied: tqdm>4.25.0 in /usr/local/lib/python3.7/dist-packages (from shap->tempeh->aif360[LawSchoolGPA,Reductions]) (4.64.1)\n",
            "Requirement already satisfied: packaging>20.9 in /usr/local/lib/python3.7/dist-packages (from shap->tempeh->aif360[LawSchoolGPA,Reductions]) (21.3)\n",
            "Requirement already satisfied: llvmlite<0.40,>=0.39.0dev0 in /usr/local/lib/python3.7/dist-packages (from numba->shap->tempeh->aif360[LawSchoolGPA,Reductions]) (0.39.1)\n",
            "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from numba->shap->tempeh->aif360[LawSchoolGPA,Reductions]) (4.12.0)\n",
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->numba->shap->tempeh->aif360[LawSchoolGPA,Reductions]) (3.8.1)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "sivw3vma71DE"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "from sklearn.compose import TransformedTargetRegressor\n",
        "from sklearn.linear_model import LinearRegression\n",
        "from sklearn.metrics import mean_absolute_error\n",
        "from sklearn.preprocessing import MinMaxScaler\n",
        "\n",
        "from aif360.sklearn.datasets import fetch_lawschool_gpa\n",
        "from aif360.sklearn.inprocessing import GridSearchReduction\n",
        "from aif360.sklearn.metrics import difference"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uMTJ62DZ71DF"
      },
      "source": [
        "### Loading data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KgqbgiIY71DG"
      },
      "source": [
        "Datasets are formatted as separate `X` (# samples x # features) and `y` (# samples x # labels) DataFrames. The index of each DataFrame contains the protected attribute values for each sample. Datasets may also load a `sample_weight` object to be used with certain algorithms/metrics. This format keeps aif360 compatible with scikit-learn objects.\n",
        "\n",
        "For example, we can load the Law School GPA dataset from tempeh with the following lines:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "7Xqg6Yn771DG",
        "outputId": "00fda049-0c6a-4c86-9ee8-f49b9e2d5161",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 238
        }
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "      lsat  ugpa  race\n",
              "race                  \n",
              "0     38.0   3.3     0\n",
              "1     34.0   4.0     1\n",
              "1     34.0   3.9     1\n",
              "1     45.0   3.3     1\n",
              "1     39.0   2.5     1"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ],
      "source": [
        "X_train, y_train = fetch_lawschool_gpa(\"train\", numeric_only=True)\n",
        "X_test, y_test = fetch_lawschool_gpa(\"test\", numeric_only=True)\n",
        "X_train.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mgjzey6y71DI"
      },
      "source": [
        "We normalize the continuous values, making sure to propagate column names associated with protected attributes, information necessary for grid search reduction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "7CxY3K8i71DI",
        "outputId": "c14dec8a-8a45-4992-af71-fd6cc5761cc0",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 238
        }
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "          lsat   ugpa  race\n",
              "race                       \n",
              "0     0.729730  0.825   0.0\n",
              "1     0.621622  1.000   1.0\n",
              "1     0.621622  0.975   1.0\n",
              "1     0.918919  0.825   1.0\n",
              "1     0.756757  0.625   1.0"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ],
      "source": [
        "scaler = MinMaxScaler()\n",
        "\n",
        "X_train = pd.DataFrame(scaler.fit_transform(X_train), columns=X_train.columns, index=X_train.index)\n",
        "X_test = pd.DataFrame(scaler.transform(X_test), columns=X_test.columns, index=X_test.index)\n",
        "\n",
        "X_train.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lukGdCjv71DJ"
      },
      "source": [
        "### Running metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x9LheHu-71DK"
      },
      "source": [
        "With the data in this format, we can easily train a scikit-learn model and get predictions for the test data. We drop the protected attribute column so that it is not used in the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "y3mrAFw471DL",
        "outputId": "d7bd400b-2fc1-426b-baab-93139fb48739",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.7400826321650612"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ],
      "source": [
        "tt = TransformedTargetRegressor(LinearRegression(), transformer=scaler)\n",
        "tt = tt.fit(X_train.drop([\"race\"], axis=1), y_train)\n",
        "y_pred = tt.predict(X_test.drop([\"race\"], axis=1))\n",
        "lr_mae = mean_absolute_error(y_test, y_pred)\n",
        "lr_mae"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M2cLBv8q71DL"
      },
      "source": [
        "We can then assess how the mean absolute error differs across groups:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "F97LRCdf71DM",
        "outputId": "fe6f95a7-fdb9-4a47-bd39-284ef726f983",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.2039259052574467"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "lr_mae_diff = difference(mean_absolute_error, y_test, y_pred)\n",
        "lr_mae_diff"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8ADb1vbQ71DM"
      },
      "source": [
        "### Grid Search"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XkoNE09071DM"
      },
      "source": [
        "Reuse the base model for the candidate regressors. Base models should implement a `fit` method that accepts a sample weight as input. For details, refer to the docs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "TxejOHFK71DN"
      },
      "outputs": [],
      "source": [
        "estimator = TransformedTargetRegressor(LinearRegression(), transformer=scaler)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WyqzfrL271DN"
      },
      "source": [
        "Search for the best regressor and observe the mean absolute error. Grid search for regression takes `constraints=\"BoundedGroupLoss\"` to apply bounded group loss constraints. Accordingly, we need to specify a loss function, such as \"Absolute\"; other options include \"Square\" and \"ZeroOne\". When the loss is \"Absolute\" or \"Square\", we also specify the expected range of the y values in `min_val` and `max_val`. For details on the implementation of these loss functions, see the fairlearn library at https://github.com/fairlearn/fairlearn/blob/master/fairlearn/reductions/_moments/bounded_group_loss.py."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "B63yaCCf71DN",
        "outputId": "e3e2c851-b3f0-4296-e868-a46b78899364",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.7622719376746615\n"
          ]
        }
      ],
      "source": [
        "np.random.seed(0)  # needed for reproducibility\n",
        "grid_search_red = GridSearchReduction(prot_attr=\"race\",\n",
        "                                      estimator=estimator,\n",
        "                                      constraints=\"BoundedGroupLoss\",\n",
        "                                      loss=\"Absolute\",\n",
        "                                      min_val=y_train.min(),\n",
        "                                      max_val=y_train.max(),\n",
        "                                      grid_size=10,\n",
        "                                      drop_prot_attr=True)\n",
        "grid_search_red.fit(X_train, y_train)\n",
        "gs_pred = grid_search_red.predict(X_test)\n",
        "gs_mae = mean_absolute_error(y_test, gs_pred)\n",
        "print(gs_mae)\n",
        "\n",
        "# Check that the mean absolute error is comparable to the baseline\n",
        "assert abs(gs_mae - lr_mae) < 0.08"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "WPd4i-Zj71DO",
        "outputId": "d40bcb93-fcb1-405b-ffc3-a72ef29157e1",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.06122151904963524\n"
          ]
        }
      ],
      "source": [
        "gs_mae_diff = difference(mean_absolute_error, y_test, gs_pred)\n",
        "print(gs_mae_diff)\n",
        "\n",
        "# Check that the difference across groups decreased\n",
        "assert gs_mae_diff < lr_mae_diff"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.9.7 ('aif360')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    },
    "vscode": {
      "interpreter": {
        "hash": "d0c5ced7753e77a483fec8ff7063075635521cce6e0bd54998c8f174742209dd"
      }
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}