{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2f757874-6a8c-4709-83ed-2582517905af",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-22 14:08:31.818308: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-05-22 14:08:32.352283: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
      "/home/leiding/pyotrch/lib/python3.8/site-packages/transformers/generation_utils.py:24: FutureWarning: Importing `GenerationMixin` from `src/transformers/generation_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import GenerationMixin` instead.\n",
      "  warnings.warn(\n",
      "/home/leiding/pyotrch/lib/python3.8/site-packages/transformers/generation_tf_utils.py:24: FutureWarning: Importing `TFGenerationMixin` from `src/transformers/generation_tf_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import TFGenerationMixin` instead.\n",
      "  warnings.warn(\n",
      "/tmp/ipykernel_408790/685032672.py:28: DeprecationWarning: Please use `pearsonr` from the `scipy.stats` namespace, the `scipy.stats.stats` namespace is deprecated.\n",
      "  from scipy.stats.stats import pearsonr\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import gensim\n",
    "import codecs\n",
    "import json\n",
    "from gensim.models.keyedvectors import Word2VecKeyedVectors\n",
    "from gensim.models import KeyedVectors\n",
    "import numpy as np\n",
    "import random\n",
    "import sklearn\n",
    "from sklearn import model_selection\n",
    "from sklearn import cluster\n",
    "from sklearn import metrics\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.svm import LinearSVC, SVC\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "import numpy as np\n",
    "from docopt import docopt\n",
    "import torch\n",
    "from transformers import *\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "\n",
    "import scipy\n",
    "from scipy import linalg\n",
    "from scipy import sparse\n",
    "from scipy.stats.stats import pearsonr\n",
    "import tqdm\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import SGDClassifier, SGDRegressor, Perceptron, LogisticRegression\n",
    "\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import pickle\n",
    "from collections import defaultdict, Counter\n",
    "from typing import List, Dict\n",
    "\n",
    "import torch\n",
    "from torch import utils\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import time\n",
    "from mlxtend.math import vectorspace_orthonormalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "47312e11-bc0e-429a-aec9-b3aef87648fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.wrappers.scikit_learn import KerasClassifier\n",
    "from keras.utils import np_utils\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.pipeline import Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed69b5ef-67df-45d2-b123-e4774eadb506",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def load_dataset(path):\n",
    "    \n",
    "    with open(path, \"rb\") as f:\n",
    "        \n",
    "        data = pickle.load(f)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b1baa387-7d66-494a-a6ce-f187fca656fd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "path = \"../data/deepmoji\"\n",
    "tags = ['pos_pos','pos_neg','neg_pos','neg_neg']\n",
    "clss = [[1,1],[1,0],[0,1],[0,0]]\n",
    "sets = ['test','dev','train']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81b7d544-8f90-40e1-83ca-b36c2fca81b6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8ce2187a-cec7-4f6a-a7d6-bbf9c2276236",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def preprocess(path, subset, portions):\n",
    "    X = np.load(path +'/'+subset+'/'+tags[0]+\".npy\")\n",
    "    X = X[:portions[0]]\n",
    "    Y_sem = clss[0][0]*np.ones(X.shape[0])\n",
    "    Y_race = clss[0][1]*np.ones(X.shape[0])\n",
    "    \n",
    "    \n",
    "    for i in range(1,len(tags)):\n",
    "        X_new = np.load(path +'/'+subset+'/'+tags[i]+\".npy\")\n",
    "        X_new = X_new[:portions[i]]\n",
    "        X = np.r_[X, X_new]\n",
    "        Y_sem_new = clss[i][0]*np.ones(X_new.shape[0])\n",
    "        Y_sem = np.r_[Y_sem, Y_sem_new]\n",
    "        Y_race_new = clss[i][1]*np.ones(X_new.shape[0])\n",
    "        Y_race = np.r_[Y_race, Y_race_new]\n",
    "\n",
    "    print(X.shape)\n",
    "\n",
    "    np.random.seed(0)\n",
    "    np.random.shuffle(X)\n",
    "\n",
    "    np.random.seed(0)\n",
    "    np.random.shuffle(Y_sem)\n",
    "\n",
    "    np.random.seed(0)\n",
    "    np.random.shuffle(Y_race)\n",
    "    \n",
    "    return X, Y_sem, Y_race"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a9a4af16-06a1-4c55-8dd3-4b139eaa8896",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(100000, 2304)\n"
     ]
    }
   ],
   "source": [
    "X_train, Y_sem_train, Y_race_train = preprocess(path, 'train', [40000,10000,10000,40000])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "eaa5194c-3f38-4e9b-845c-a451483dc3a6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(7998, 2304)\n"
     ]
    }
   ],
   "source": [
    "X_test, Y_sem_test, Y_race_test = preprocess(path, 'test', [2000, 2000, 2000, 2000])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8124f556-1337-43d2-ac81-b53f4b7da606",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def TPR(Y_p, Y_t, tag):\n",
    "    return np.mean(Y_p[Y_t == tag] == tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e96084bf-8ba5-47d8-a539-32ac762fe89c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_TPR_gap(Y_sem_pred, Y_sem_true, Y_race):\n",
    "    gap1 = TPR(Y_sem_pred[Y_race == 0], Y_sem_true[Y_race == 0], 0) - TPR(Y_sem_pred[Y_race == 1], Y_sem_true[Y_race == 1], 0)\n",
    "    gap2 = TPR(Y_sem_pred[Y_race == 0], Y_sem_true[Y_race == 0], 1) - TPR(Y_sem_pred[Y_race == 1], Y_sem_true[Y_race == 1], 1)\n",
    "    rms = np.sqrt(0.5*gap1**2 + 0.5*gap2**2)\n",
    "    return rms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbeb23f3-8ccb-474a-ae47-42e68bcec82b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0bff4dbc-1241-4557-bb81-7c7e83cd82f7",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   16.4s finished\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LogisticRegression(fit_intercept=False, max_iter=7, multi_class=&#x27;multinomial&#x27;,\n",
       "                   n_jobs=90, random_state=1, solver=&#x27;saga&#x27;, verbose=5,\n",
       "                   warm_start=True)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression(fit_intercept=False, max_iter=7, multi_class=&#x27;multinomial&#x27;,\n",
       "                   n_jobs=90, random_state=1, solver=&#x27;saga&#x27;, verbose=5,\n",
       "                   warm_start=True)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "LogisticRegression(fit_intercept=False, max_iter=7, multi_class='multinomial',\n",
       "                   n_jobs=90, random_state=1, solver='saga', verbose=5,\n",
       "                   warm_start=True)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 7)\n",
    "clf.fit(X_train, Y_sem_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d513f948-9890-4db6-983d-2b846e661680",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7119279819954989"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.score(X_test, Y_sem_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "932e6b3c-289d-4f06-8264-82c151deb223",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "Y_sem_pred_orig = clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d0ac575b-3d7e-4708-bdc3-4031d2e5d433",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.38949996396878994"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_TPR_gap(Y_sem_pred_orig, Y_sem_test, Y_race_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "bbb78893-a933-4227-b407-572e68db07d2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_TPR(y_main, y_hat_main, y_protected):\n",
    "    \n",
    "    all_y = list(Counter(y_main).keys())\n",
    "    \n",
    "    protected_vals = defaultdict(dict)\n",
    "    for label in all_y:\n",
    "        for i in range(2):\n",
    "            used_vals = (y_main == label) & (y_protected == i)\n",
    "            y_label = y_main[used_vals]\n",
    "            y_hat_label = y_hat_main[used_vals]\n",
    "            protected_vals['y:{}'.format(label)]['p:{}'.format(i)] = (y_label == y_hat_label).mean()\n",
    "            \n",
    "    diffs = {}\n",
    "    for k, v in protected_vals.items():\n",
    "        vals = list(v.values())\n",
    "        diffs[k] = vals[0] - vals[1]\n",
    "    return protected_vals, diffs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "dd07d851-41f6-4ef5-9ac6-e21442da1168",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def rms(arr):\n",
    "    return np.sqrt(np.mean(np.square(arr)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "fb16421b-dd49-4060-b465-670420d61417",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "_, diffs = get_TPR(Y_sem_test, Y_sem_pred_orig,  Y_race_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e7901c22-a7c2-4b38-9018-3b21ae27222e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.38949996396878994"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rms(list(diffs.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "794ff3c2-874a-4adf-b0d5-12aee9f9178b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b4dd1d4-ff56-40dd-b677-9decdc84ae03",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "87974c69-622f-4212-a018-a6650614858c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def slice_freq(Y, interval):\n",
    "    y_f = pd.Series(Y)\n",
    "    a = pd.cut(y_f,interval)\n",
    "    b=a.value_counts()\n",
    "    frequency = b.sort_index().values\n",
    "    return frequency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "9f6ec0c2-2a95-436b-96a4-3e0fce7773df",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def SDR_multi_new(X_train, Y, K, h):\n",
    "\n",
    "    n = X_train.shape[0]\n",
    "    d_x = X_train.shape[1]\n",
    "    d_y = Y.shape[1]\n",
    "    \n",
    "    X = X_train\n",
    "\n",
    "    interval = np.linspace(0, 1, h+1,endpoint=True)\n",
    "    Gamma_PMS = np.zeros([d_x,d_x])\n",
    "    for i in range(d_y):  \n",
    "        #interval = np.linspace(np.min(Y[:,i])-0.1, np.max(Y[:,i])+0.1, h+1,endpoint=True)\n",
    "        freq = slice_freq(Y[:,i], interval)\n",
    "        Gamma = np.zeros([d_x,d_x])\n",
    "        for j in range(h):\n",
    "            ph = freq[j]/n\n",
    "            if ph == 0:\n",
    "                mh = np.zeros(d_x)\n",
    "            else:\n",
    "                index = np.where((Y[:,i]>=interval[j])&(Y[:,i]<=interval[j+1]))\n",
    "                mh = (1/(n*ph))*np.sum(X[index,:],0)\n",
    "            Gamma = Gamma + ph*np.dot(mh.T,mh)\n",
    "        la_g, v_g = np.linalg.eig(Gamma)\n",
    "        la_g = la_g.real\n",
    "        Gamma_PMS = Gamma_PMS + la_g[0]*Gamma\n",
    "    la_PMS, v_PMS = np.linalg.eig(Gamma_PMS)\n",
    "    v_PMS = v_PMS.real\n",
    "    b1 = v_PMS[:,:K] + 0.001*np.random.rand(d_x,K)\n",
    "    b2 = v_PMS[:,K:] + 0.001*np.random.rand(d_x,d_x-K)\n",
    "    #print(b1)\n",
    "    beta1 = get_rowspace_projection(b1.T) #vectorspace_orthonormalization(v_PMS[:,:K])\n",
    "    beta2 = get_rowspace_projection(b2.T)\n",
    "    return beta1, beta2  #np.dot(beta1,beta1.T), np.eye(d_x)-np.dot(beta1,beta1.T) \n",
    "\n",
    "def debias_proj(X_t, Y1_train, Y2_train, Iter, q, h):\n",
    "    n = len(X_t)\n",
    "    d = X_t.shape[1]\n",
    "    X1 = X_t-np.mean(X_t,0).reshape(1,d)\n",
    "    Cx = np.dot(X1.T,X1)/n\n",
    "#     la, v = np.linalg.eig(Cx)\n",
    "#     la = np.abs(la)\n",
    "#     v = v.real\n",
    "#     Cx12 = np.dot(np.dot(v,np.diag(la**(0.5))),v.T)\n",
    "    InvC = np.linalg.pinv(Cx,rcond=1e-8)\n",
    "    X_t = np.dot(X_t, InvC)\n",
    "\n",
    "    X_ = X_t\n",
    "    Y2_ = Y2_train\n",
    "    beta1_g, beta2_g = SDR_multi_new(X_, Y2_, q, h)\n",
    "\n",
    "    return beta1_g, beta2_g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "69a2407b-ae16-48d1-91ad-cb8e0e05d74b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_rowspace_projection(W: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    :param W: the matrix over its nullspace to project\n",
    "    :return: the projection matrix over the rowspace\n",
    "    \"\"\"\n",
    "    if np.allclose(W, 0):\n",
    "        w_basis = np.zeros_like(W.T)\n",
    "    else:\n",
    "        w_basis = scipy.linalg.orth(W.T) # orthogonal basis\n",
    "\n",
    "    P_W = w_basis.dot(w_basis.T) # orthogonal projection on W's rowspace\n",
    "\n",
    "    return P_W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "8a2fd2bb-bf4d-4a7a-9b88-56b79c6d8102",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_projection_to_intersection_of_nullspaces(rowspace_projection_matrices: List[np.ndarray], input_dim: int):\n",
    "    \"\"\"\n",
    "    Given a list of rowspace projection matrices P_R(w_1), ..., P_R(w_n),\n",
    "    this function calculates the projection to the intersection of all nullspasces of the matrices w_1, ..., w_n.\n",
    "    uses the intersection-projection formula of Ben-Israel 2013 http://benisrael.net/BEN-ISRAEL-NOV-30-13.pdf:\n",
    "    N(w1)∩ N(w2) ∩ ... ∩ N(wn) = N(P_R(w1) + P_R(w2) + ... + P_R(wn))\n",
    "    :param rowspace_projection_matrices: List[np.array], a list of rowspace projections\n",
    "    :param dim: input dim\n",
    "    \"\"\"\n",
    "    I = np.eye(input_dim)\n",
    "    Q = np.sum(rowspace_projection_matrices, axis = 0)\n",
    "    P = I - get_rowspace_projection(Q)\n",
    "    return P"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c967836f-5198-40ac-a9af-c5b996d3e2ed",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c7b80759-6f91-4060-a9b8-3fde5deed9b2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n",
      "0.7119279819954989\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   16.5s finished\n"
     ]
    }
   ],
   "source": [
    "clf_sem = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 7)\n",
    "clf_sem.fit(X_train, Y_sem_train)\n",
    "print(clf_sem.score(X_test, Y_sem_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "81e6b69f-fa81-4f83-b37a-a944c07d2744",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "Y = clf_sem.predict_proba(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53e0159c-7066-4dca-97c1-1420a5f9c19b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e2a6d2f9-aa80-4c2a-8168-65dca49452f8",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n",
      "0.8823455863965991\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   16.4s finished\n"
     ]
    }
   ],
   "source": [
    "clf_race = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 7)\n",
    "clf_race.fit(X_train, Y_race_train)\n",
    "print(clf_race.score(X_test, Y_race_test))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "6f9884b1-cb93-4b2e-b80d-422f4410d86b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "Z = clf_race.predict_proba(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48dae6fb-24f2-43bb-a264-7563bd50fcb6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "0f5ba936-bf28-4a1a-8d11-b3853db2dc27",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[3.86421321e-04, 9.99613579e-01],\n",
       "       [8.88230709e-01, 1.11769291e-01],\n",
       "       [9.95507796e-03, 9.90044922e-01],\n",
       "       ...,\n",
       "       [9.10343071e-01, 8.96569293e-02],\n",
       "       [8.75272960e-01, 1.24727040e-01],\n",
       "       [9.73272006e-01, 2.67279943e-02]])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af4cf1e7-7cb2-4cec-8624-3740572a427b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "df6cb722-9fd7-4f3d-8b38-0706426eef41",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "h = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "4e4ba166-d76b-4795-99a9-b471de6d4b72",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "beta1, beta2 = debias_proj(X_train, 0, Z, 1, 1250, h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "e0bff539-c351-4b36-9dc3-a4104d86807f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "pp = beta2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "74dc6933-4c4f-4a7d-9937-ff3259262e18",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 16 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   16.0s finished\n"
     ]
    }
   ],
   "source": [
    "lr2 = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 7)\n",
    "lr2.fit(np.dot(X_train,pp), Y_sem_train)\n",
    "\n",
    "accuracy = lr2.score(np.dot(X_test,pp), Y_sem_test)\n",
    "\n",
    "Y_sem_pred = lr2.predict(np.dot(X_test,pp))\n",
    "\n",
    "rms_diff_vec = get_TPR_gap(Y_sem_pred, Y_sem_test, Y_race_test)\n",
    "\n",
    "Acc = accuracy*100\n",
    "GAP = (1-rms_diff_vec)*100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "989a00d9-fe18-4977-bafb-6ebd39cb7e01",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "69.00475118779696"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "d5c6900e-00f3-4bb7-89ce-096f5cda444c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.10569311467661427"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rms_diff_vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd666ca2-8ffa-44e1-b5fc-c54bf553f91f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69586e2b-c910-4d95-b70e-04c6f98f812b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bbc7b65-dd80-42eb-94ff-0fc9600ba28e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3232093-486a-405a-b016-5f336d8d8ef6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "868fd6cc-53ba-46f8-8e57-7a2e4ef5dcf6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
