{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "86df96d2-0788-46ac-8d73-1e8874432119",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-22 14:06:38.369063: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-05-22 14:06:38.910967: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
      "/home/leiding/pyotrch/lib/python3.8/site-packages/transformers/generation_utils.py:24: FutureWarning: Importing `GenerationMixin` from `src/transformers/generation_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import GenerationMixin` instead.\n",
      "  warnings.warn(\n",
      "/home/leiding/pyotrch/lib/python3.8/site-packages/transformers/generation_tf_utils.py:24: FutureWarning: Importing `TFGenerationMixin` from `src/transformers/generation_tf_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import TFGenerationMixin` instead.\n",
      "  warnings.warn(\n",
      "/tmp/ipykernel_408467/685032672.py:28: DeprecationWarning: Please use `pearsonr` from the `scipy.stats` namespace, the `scipy.stats.stats` namespace is deprecated.\n",
      "  from scipy.stats.stats import pearsonr\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import gensim\n",
    "import codecs\n",
    "import json\n",
    "from gensim.models.keyedvectors import Word2VecKeyedVectors\n",
    "from gensim.models import KeyedVectors\n",
    "import numpy as np\n",
    "import random\n",
    "import sklearn\n",
    "from sklearn import model_selection\n",
    "from sklearn import cluster\n",
    "from sklearn import metrics\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.svm import LinearSVC, SVC\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "import numpy as np\n",
    "from docopt import docopt\n",
    "import torch\n",
    "from transformers import *\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "\n",
    "import scipy\n",
    "from scipy import linalg\n",
    "from scipy import sparse\n",
    "from scipy.stats.stats import pearsonr\n",
    "import tqdm\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import SGDClassifier, SGDRegressor, Perceptron, LogisticRegression\n",
    "\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import pickle\n",
    "from collections import defaultdict, Counter\n",
    "from typing import List, Dict\n",
    "\n",
    "import torch\n",
    "from torch import utils\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import time\n",
    "from mlxtend.math import vectorspace_orthonormalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7eb8922a-7b5c-446d-b8c4-fb6a621b678e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def load_dataset(path):\n",
    "    \n",
    "    with open(path, \"rb\") as f:\n",
    "        \n",
    "        data = pickle.load(f)\n",
    "    return data\n",
    "\n",
    "def load_dictionary(path):\n",
    "    \n",
    "    with open(path, \"r\", encoding = \"utf-8\") as f:\n",
    "        \n",
    "        lines = f.readlines()\n",
    "        \n",
    "    k2v, v2k = {}, {}\n",
    "    for line in lines:\n",
    "        \n",
    "        k,v = line.strip().split(\"\\t\")\n",
    "        v = int(v)\n",
    "        k2v[k] = v\n",
    "        v2k[v] = k\n",
    "    \n",
    "    return k2v, v2k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "43d2e8be-5d97-4cdc-a3e2-466af9873af7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "train = load_dataset(\"../data/train.pickle\")\n",
    "dev = load_dataset(\"../data/dev.pickle\")\n",
    "test = load_dataset(\"../data/test.pickle\")\n",
    "p2i, i2p = load_dictionary(\"../profession2index.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "346da441-d3eb-43b4-a0b8-17c285f0e881",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "path = \"../data/\"\n",
    "x_train = np.load(path + \"train_cls.npy\")\n",
    "x_dev = np.load(path + \"dev_cls.npy\")\n",
    "x_test = np.load(path + \"test_cls.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "58f0a779-dbbc-4396-9a40-c9d40af3db47",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "y_train = np.array([p2i[entry[\"p\"]] for entry in train])\n",
    "y_dev = np.array([p2i[entry[\"p\"]] for entry in dev])\n",
    "y_test = np.array([p2i[entry[\"p\"]] for entry in test])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5d998cfb-82bf-4482-818d-7ae5ac0e4f81",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def slice_freq(Y, interval):\n",
    "    y_f = pd.Series(Y)\n",
    "    a = pd.cut(y_f,interval)\n",
    "    b=a.value_counts()\n",
    "    frequency = b.sort_index().values\n",
    "    return frequency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "550cb737-2e70-4bee-a04e-81a4eb908488",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def SDR_multi_new(X_train, Y, K, h):\n",
    "\n",
    "    n = X_train.shape[0]\n",
    "    d_x = X_train.shape[1]\n",
    "    d_y = Y.shape[1]\n",
    "    \n",
    "    X = X_train\n",
    "\n",
    "    interval = np.linspace(0, 1, h+1,endpoint=True)\n",
    "    Gamma_PMS = np.zeros([d_x,d_x])\n",
    "    for i in range(d_y):  \n",
    "        #interval = np.linspace(np.min(Y[:,i])-0.1, np.max(Y[:,i])+0.1, h+1,endpoint=True)\n",
    "        freq = slice_freq(Y[:,i], interval)\n",
    "        Gamma = np.zeros([d_x,d_x])\n",
    "        for j in range(h):\n",
    "            ph = freq[j]/n\n",
    "            if ph == 0:\n",
    "                mh = np.zeros(d_x)\n",
    "            else:\n",
    "                index = np.where((Y[:,i]>=interval[j])&(Y[:,i]<=interval[j+1]))\n",
    "                mh = (1/(n*ph))*np.sum(X[index,:],0)\n",
    "            Gamma = Gamma + ph*np.dot(mh.T,mh)\n",
    "        la_g, v_g = np.linalg.eig(Gamma)\n",
    "        la_g = la_g.real\n",
    "        Gamma_PMS = Gamma_PMS + Gamma/d_y\n",
    "    la_PMS, v_PMS = np.linalg.eig(Gamma_PMS)\n",
    "    v_PMS = v_PMS.real\n",
    "    b1 = v_PMS[:,:K] #+ #0.001*np.random.rand(d_x,K)\n",
    "    b2 = v_PMS[:,K:] #+ #0.001*np.random.rand(d_x,d_x-K)\n",
    "    #print(b1)\n",
    "    beta1 = get_rowspace_projection(b1.T) #vectorspace_orthonormalization(v_PMS[:,:K])\n",
    "    beta2 = get_rowspace_projection(b2.T)\n",
    "    return beta1, beta2  #np.dot(beta1,beta1.T), np.eye(d_x)-np.dot(beta1,beta1.T) \n",
    "\n",
    "def debias_proj(X_t, Y1_train, Y2_train, Iter, q, h):\n",
    "    n = len(X_t)\n",
    "    d = X_t.shape[1]\n",
    "    X1 = X_t-np.mean(X_t,0).reshape(1,d)\n",
    "    Cx = np.dot(X1.T,X1)/n\n",
    "#     la, v = np.linalg.eig(Cx)\n",
    "#     la = np.abs(la)\n",
    "#     v = v.real\n",
    "#     Cx12 = np.dot(np.dot(v,np.diag(la**(0.5))),v.T)\n",
    "    InvC = np.linalg.pinv(Cx,rcond=1e-8)\n",
    "    X_t = np.dot(X_t, InvC)\n",
    "\n",
    "    X_ = X_t\n",
    "    Y2_ = Y2_train\n",
    "    beta1_g, beta2_g = SDR_multi_new(X_, Y2_, q, h)\n",
    "\n",
    "    return beta1_g, beta2_g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "581334f2-b75f-48f5-9af5-9b5a862929c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "157d495b-de86-4e5f-8bb1-f2d76d3f80ae",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_rowspace_projection(W: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    :param W: the matrix over its nullspace to project\n",
    "    :return: the projection matrix over the rowspace\n",
    "    \"\"\"\n",
    "    if np.allclose(W, 0):\n",
    "        w_basis = np.zeros_like(W.T)\n",
    "    else:\n",
    "        w_basis = scipy.linalg.orth(W.T) # orthogonal basis\n",
    "\n",
    "    P_W = w_basis.dot(w_basis.T) # orthogonal projection on W's rowspace\n",
    "\n",
    "    return P_W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7b69c403-ba37-4dac-8ec4-7e8f759a321d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_projection_to_intersection_of_nullspaces(rowspace_projection_matrices: List[np.ndarray], input_dim: int):\n",
    "    \"\"\"\n",
    "    Given a list of rowspace projection matrices P_R(w_1), ..., P_R(w_n),\n",
    "    this function calculates the projection to the intersection of all nullspasces of the matrices w_1, ..., w_n.\n",
    "    uses the intersection-projection formula of Ben-Israel 2013 http://benisrael.net/BEN-ISRAEL-NOV-30-13.pdf:\n",
    "    N(w1)∩ N(w2) ∩ ... ∩ N(wn) = N(P_R(w1) + P_R(w2) + ... + P_R(wn))\n",
    "    :param rowspace_projection_matrices: List[np.array], a list of rowspace projections\n",
    "    :param dim: input dim\n",
    "    \"\"\"\n",
    "    I = np.eye(input_dim)\n",
    "    Q = np.sum(rowspace_projection_matrices, axis = 0)\n",
    "    P = I - get_rowspace_projection(Q)\n",
    "    return P"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b7b61a0-e47d-4b49-8dfd-7d4ea369e5a3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7cc126b9-2553-48c5-bee8-c0423ae5deb5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_TPR(y_pred, y_true, p2i, i2p, gender):\n",
    "    \n",
    "    scores = defaultdict(Counter)\n",
    "    prof_count_total = defaultdict(Counter)\n",
    "    \n",
    "    for y_hat, y, g in zip(y_pred, y_true, gender):\n",
    "        \n",
    "        if y == y_hat:\n",
    "            \n",
    "            scores[i2p[y]][g] += 1\n",
    "        \n",
    "        prof_count_total[i2p[y]][g] += 1\n",
    "    \n",
    "    tprs = defaultdict(dict)\n",
    "    tprs_change = dict()\n",
    "    tprs_ratio = []\n",
    "    \n",
    "    for profession, scores_dict in scores.items():\n",
    "        \n",
    "        good_m, good_f = scores_dict[\"m\"], scores_dict[\"f\"]\n",
    "        prof_total_f = prof_count_total[profession][\"f\"]\n",
    "        prof_total_m = prof_count_total[profession][\"m\"]\n",
    "        tpr_m = (good_m) / prof_total_m\n",
    "        tpr_f = (good_f) / prof_total_f\n",
    "        \n",
    "        tprs[profession][\"m\"] = tpr_m\n",
    "        tprs[profession][\"f\"] = tpr_f\n",
    "        tprs_ratio.append(0)\n",
    "        tprs_change[profession] = tpr_f - tpr_m\n",
    "        \n",
    "    return tprs, tprs_change, np.mean(np.abs(tprs_ratio))\n",
    " \n",
    "\n",
    "def rms_diff(tpr_diff):\n",
    "    \n",
    "    return np.sqrt(np.mean(tpr_diff**2))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "300fc018-c747-4cb4-83a4-f067dfde528f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def count_profs_and_gender(data: List[dict]):\n",
    "    \n",
    "    counter = defaultdict(Counter)\n",
    "    for entry in data:\n",
    "        gender, prof = entry[\"g\"], entry[\"p\"]\n",
    "        counter[prof][gender] += 1\n",
    "        \n",
    "    return counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "60626290-2c5a-4821-9928-3b7c4b223491",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "g2i, i2g = load_dictionary(\"../gender2index.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3a0113d5-2e20-4500-b0b4-2a7769254de2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "y_train_gender = np.array([g2i[d[\"g\"]] for d in train])\n",
    "y_test_gender = np.array([g2i[d[\"g\"]] for d in test])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ce0e947f-bb2f-487a-a0a3-1900a8afda32",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_gender = [d[\"g\"] for d in test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb74ae98-d072-439b-94d3-9c58a6ea314c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "12c7eb7c-ea52-4d9e-b17e-a99b168d842f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "idx = np.random.rand(x_train.shape[0]) < 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d8261892-51fb-4234-90a5-121bf968e202",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "x_train, y_train = x_train[idx], y_train[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d4429c18-0e8d-40fc-be05-bc001a41e2e1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "y_train_gender = y_train_gender[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6686cce8-868c-4fbc-9888-a294d5fddfc8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 145 seconds\n",
      "0.7979236150654844\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:  2.4min finished\n"
     ]
    }
   ],
   "source": [
    "clf_prof = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 7)\n",
    "\n",
    "clf_prof.fit(x_train, y_train)\n",
    "print(clf_prof.score(x_test, y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "46985cf7-9218-4140-8d6d-7e18b2b6ebc3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "Y1 = clf_prof.predict_proba(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "42887bf4-8531-4505-91bc-b2160d223fc6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[2.66701027e-03, 2.42213386e-03, 2.21161209e-02, ...,\n",
       "        5.61156457e-06, 8.20607446e-01, 8.00269267e-03],\n",
       "       [3.23035395e-04, 5.96147301e-03, 8.60804502e-04, ...,\n",
       "        4.81365957e-05, 4.87210381e-03, 4.21413445e-06],\n",
       "       [3.16179321e-02, 3.97720962e-03, 7.08059029e-01, ...,\n",
       "        3.77870018e-03, 2.51013708e-03, 1.05503086e-05],\n",
       "       ...,\n",
       "       [1.03969786e-05, 1.60737245e-05, 7.17629886e-06, ...,\n",
       "        2.17740095e-03, 1.24836887e-05, 2.59702517e-06],\n",
       "       [3.40466352e-05, 2.24827877e-03, 5.35940797e-03, ...,\n",
       "        4.21395213e-04, 8.79322798e-04, 3.39213183e-06],\n",
       "       [1.93727960e-06, 3.39407021e-02, 2.37074288e-05, ...,\n",
       "        1.00574085e-07, 1.57339217e-05, 3.91961329e-08]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c5dda65-009b-4721-a88f-8b794e26b5cb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "039f0e3c-74ec-4749-9d66-47e9bc10a1d6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "idx = np.random.rand(x_train.shape[0]) < 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "31baf223-6c43-47a5-83bc-a4789ff913e5",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max_iter reached after 4 seconds\n",
      "0.9924753924997967\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:    4.1s finished\n"
     ]
    }
   ],
   "source": [
    "clf_gender = LogisticRegression(warm_start = False, penalty = 'l1',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 1)\n",
    "\n",
    "clf_gender.fit(x_train[idx], y_train_gender[idx])\n",
    "print(clf_gender.score(x_test, y_test_gender))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "9f94ec8e-f4a9-4dc5-808b-b5b68163ad1a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9924753924997967\n"
     ]
    }
   ],
   "source": [
    "print(clf_gender.score(x_test, y_test_gender))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c3e60f07-f7b2-4828-b1c2-fde0b0167746",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "Y2 = clf_gender.predict_proba(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "86ef40be-e4a5-47c0-b634-1d2a0c8217a3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[5.52186906e-06, 9.99994478e-01],\n",
       "       [9.99997030e-01, 2.96974721e-06],\n",
       "       [9.98927050e-01, 1.07294952e-03],\n",
       "       ...,\n",
       "       [9.99965694e-01, 3.43063547e-05],\n",
       "       [9.99999824e-01, 1.76045591e-07],\n",
       "       [9.99999997e-01, 2.58186567e-09]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e767d67-a437-45c9-992c-af33f5ed0bf8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "49f94558-da90-4792-a1d7-562fcfc4d005",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def SDR_PMS(X_train, Y, K, h):\n",
    "\n",
    "    n = X_train.shape[0]\n",
    "    d_x = X_train.shape[1]\n",
    "    d_y = Y.shape[1]\n",
    "    \n",
    "    X = X_train\n",
    "\n",
    "    #interval = np.linspace(0, 1, h+1,endpoint=True)\n",
    "    Gamma_PMS = np.zeros([d_x,d_x])\n",
    "    for i in range(d_y):  \n",
    "        interval = np.linspace(np.min(Y[:,i])-0.01, np.max(Y[:,i])+0.01, h+1,endpoint=True)\n",
    "        freq = slice_freq(Y[:,i], interval)\n",
    "        Gamma = np.zeros([d_x,d_x])\n",
    "        for j in range(h):\n",
    "            ph = freq[j]/n\n",
    "            if ph == 0:\n",
    "                mh = np.zeros(d_x)\n",
    "            else:\n",
    "                index = np.where((Y[:,i]>=interval[j])&(Y[:,i]<=interval[j+1]))\n",
    "                mh = (1/(n*ph))*np.sum(X[index,:],0)\n",
    "            Gamma = Gamma + ph*np.dot(mh.T,mh)\n",
    "        la_g, v_g = np.linalg.eig(Gamma)\n",
    "        la_g = la_g.real\n",
    "        Gamma_PMS = Gamma_PMS + Gamma/d_y\n",
    "    la_PMS, v_PMS = np.linalg.eig(Gamma_PMS)\n",
    "    v_PMS = v_PMS.real\n",
    "    b1 = v_PMS[:,:K] \n",
    "    return b1\n",
    "\n",
    "def get_PMS(X_t, Y2_train, q, h):\n",
    "    n = len(X_t)\n",
    "    d = X_t.shape[1]\n",
    "    X1 = X_t-np.mean(X_t,0).reshape(1,d)\n",
    "    Cx = np.dot(X1.T,X1)/n\n",
    "    InvC = np.linalg.pinv(Cx,rcond=1e-8)\n",
    "    X_t = np.dot(X_t, InvC)\n",
    "\n",
    "    X_ = X_t\n",
    "    Y2_ = Y2_train\n",
    "    cc = SDR_PMS(X_, Y2_, 768, h)\n",
    "\n",
    "    return cc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "068733ea-12b4-4ef7-aac6-361586879997",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "4deb23b0-397a-4393-8cad-2f0cd99b5635",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "V = get_PMS(x_train, Y2, 0, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ecb42b3-262e-41c2-8b24-6c87da74ff43",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "2d1ed045-1402-4fd3-bd89-52ad3a8d1054",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "h = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "8cd5d794-888f-4168-8383-0b3f63986709",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "q_vec = list(range(50,761,50))\n",
    "fout = open('./BIOS-cls-0516.txt','w') "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70acc36b-a4f7-476a-bb2c-4630bcf3604e",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    }
   ],
   "source": [
    "for k in range(len(q_vec)):\n",
    "    q = q_vec[k]\n",
    "\n",
    "    print(\" dim of subspace is \"+ str(q), file=fout, flush=True)\n",
    "\n",
    "    pp = get_rowspace_projection(V[:,q:].T)\n",
    "    \n",
    "    lr2 = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 7)\n",
    "    lr2.fit(np.dot(x_train,pp), y_train)\n",
    "\n",
    "    accuracy = lr2.score(np.dot(x_test,pp), y_test)\n",
    "\n",
    "    y_pred_after = lr2.predict(np.dot(x_test,pp))\n",
    "    tprs, tprs_change_after, mean_ratio_after = get_TPR(y_pred_after, y_test, p2i, i2p, test_gender)\n",
    "    change_vals_after = np.array(list(tprs_change_after.values()))\n",
    "    rms_diff_vec = rms_diff(change_vals_after)\n",
    "\n",
    "    Acc = accuracy*100\n",
    "    \n",
    "    print(Acc, rms_diff_vec, file=fout, flush=True)\n",
    "fout.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f1e4544-77fe-48cf-bcd7-48949ff0d1e9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4fa83a3-e544-4233-badf-f67fded07d8f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d1050e8-40f3-48a6-82c3-a6507499eac9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
