{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d908ca32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bf7c24cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "acs_2010 = pd.read_csv('acs_2010_1yr.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "32ca85f1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['STATE', 'RACPACIS', 'VET47X50', 'WORKEDYR', 'SPEAKENG', 'EMPSTAT',\n",
       "       'DIFFREM', 'DIFFEYE', 'SEX', 'HWSEI', 'FERTYR', 'RACOTHER', 'HINSCAID',\n",
       "       'NCHLT5', 'SCHLTYPE', 'BUILTYR2', 'EDUC', 'DIFFSENS', 'VETSTAT',\n",
       "       'RACBLK', 'DIFFPHYS', 'MARRNO', 'VET75X90', 'VETWWII', 'LABFORCE',\n",
       "       'HCOVANY', 'MORTGAGE', 'OCCSCORE', 'VACANCY', 'VET55X64', 'ELDCH',\n",
       "       'SCHOOL', 'CITIZEN', 'DEGFIELD', 'AGE', 'OWNERSHP', 'ACREHOUS',\n",
       "       'GRADEATT', 'RACAMIND', 'CLASSWKR', 'HISPAN', 'YNGCH', 'WIDINYR',\n",
       "       'MIGRATE1', 'POVERTY', 'RELATE', 'VETDISAB', 'RACWHT', 'BEDROOMS',\n",
       "       'MARST', 'HINSCARE', 'NSIBS', 'LOOKING', 'VETVIETN', 'DIFFCARE',\n",
       "       'MARRINYR', 'DIVINYR', 'FAMSIZE', 'AVAILBLE', 'NMOTHERS', 'NFATHERS',\n",
       "       'SEI', 'DIFFHEAR', 'NCHILD', 'RACASIAN', 'FOODSTMP', 'RACE', 'VEHICLES',\n",
       "       'NCOUPLES', 'HCOVPRIV', 'MIGTYPE1', 'VET90X01', 'DIFFMOB', 'NFAMS',\n",
       "       'METRO', 'LANGUAGE', 'VETKOREA', 'HINSVA', 'MULTGEN', 'ROOMS',\n",
       "       'VET01LTR', 'AGEORIG'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acs_2010.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "581234e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b478057a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'RACPACIS': 2,\n",
       " 'VET47X50': 3,\n",
       " 'WORKEDYR': 4,\n",
       " 'SPEAKENG': 6,\n",
       " 'EMPSTAT': 4,\n",
       " 'DIFFREM': 3,\n",
       " 'DIFFEYE': 2,\n",
       " 'SEX': 2,\n",
       " 'HWSEI': 305,\n",
       " 'FERTYR': 3,\n",
       " 'RACOTHER': 2,\n",
       " 'HINSCAID': 2,\n",
       " 'NCHLT5': 8,\n",
       " 'SCHLTYPE': 4,\n",
       " 'BUILTYR2': 23,\n",
       " 'EDUC': 11,\n",
       " 'DIFFSENS': 2,\n",
       " 'VETSTAT': 3,\n",
       " 'RACBLK': 2,\n",
       " 'DIFFPHYS': 3,\n",
       " 'MARRNO': 4,\n",
       " 'VET75X90': 3,\n",
       " 'VETWWII': 3,\n",
       " 'LABFORCE': 3,\n",
       " 'HCOVANY': 2,\n",
       " 'MORTGAGE': 4,\n",
       " 'OCCSCORE': 48,\n",
       " 'VACANCY': 1,\n",
       " 'VET55X64': 3,\n",
       " 'ELDCH': 92,\n",
       " 'SCHOOL': 3,\n",
       " 'CITIZEN': 4,\n",
       " 'DEGFIELD': 38,\n",
       " 'AGE': 10,\n",
       " 'OWNERSHP': 3,\n",
       " 'ACREHOUS': 3,\n",
       " 'GRADEATT': 8,\n",
       " 'RACAMIND': 2,\n",
       " 'CLASSWKR': 3,\n",
       " 'HISPAN': 5,\n",
       " 'YNGCH': 91,\n",
       " 'WIDINYR': 3,\n",
       " 'MIGRATE1': 5,\n",
       " 'POVERTY': 502,\n",
       " 'RELATE': 13,\n",
       " 'VETDISAB': 8,\n",
       " 'RACWHT': 2,\n",
       " 'BEDROOMS': 16,\n",
       " 'MARST': 6,\n",
       " 'HINSCARE': 2,\n",
       " 'NSIBS': 10,\n",
       " 'LOOKING': 4,\n",
       " 'VETVIETN': 3,\n",
       " 'DIFFCARE': 3,\n",
       " 'MARRINYR': 3,\n",
       " 'DIVINYR': 4,\n",
       " 'FAMSIZE': 20,\n",
       " 'AVAILBLE': 5,\n",
       " 'NMOTHERS': 7,\n",
       " 'NFATHERS': 7,\n",
       " 'SEI': 81,\n",
       " 'DIFFHEAR': 2,\n",
       " 'NCHILD': 10,\n",
       " 'RACASIAN': 2,\n",
       " 'FOODSTMP': 2,\n",
       " 'RACE': 9,\n",
       " 'VEHICLES': 8,\n",
       " 'NCOUPLES': 7,\n",
       " 'HCOVPRIV': 2,\n",
       " 'MIGTYPE1': 7,\n",
       " 'VET90X01': 3,\n",
       " 'DIFFMOB': 3,\n",
       " 'NFAMS': 20,\n",
       " 'METRO': 5,\n",
       " 'LANGUAGE': 67,\n",
       " 'VETKOREA': 3,\n",
       " 'HINSVA': 2,\n",
       " 'MULTGEN': 4,\n",
       " 'ROOMS': 29,\n",
       " 'VET01LTR': 3,\n",
       " 'AGEORIG': 97}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "json.load(open('acs_2010_1yr-domain.json'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8de7a9d6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>STATE</th>\n",
       "      <th>RACPACIS</th>\n",
       "      <th>VET47X50</th>\n",
       "      <th>WORKEDYR</th>\n",
       "      <th>SPEAKENG</th>\n",
       "      <th>EMPSTAT</th>\n",
       "      <th>DIFFREM</th>\n",
       "      <th>DIFFEYE</th>\n",
       "      <th>SEX</th>\n",
       "      <th>HWSEI</th>\n",
       "      <th>...</th>\n",
       "      <th>DIFFMOB</th>\n",
       "      <th>NFAMS</th>\n",
       "      <th>METRO</th>\n",
       "      <th>LANGUAGE</th>\n",
       "      <th>VETKOREA</th>\n",
       "      <th>HINSVA</th>\n",
       "      <th>MULTGEN</th>\n",
       "      <th>ROOMS</th>\n",
       "      <th>VET01LTR</th>\n",
       "      <th>AGEORIG</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>AL</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>75</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>AL</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>279</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>AL</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>269</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>AL</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>AL</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>87</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3061687</th>\n",
       "      <td>WY</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>69</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3061688</th>\n",
       "      <td>WY</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3061689</th>\n",
       "      <td>WY</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3061690</th>\n",
       "      <td>WY</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>33</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>55</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3061691</th>\n",
       "      <td>WY</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>170</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>52</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>3061692 rows × 82 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        STATE  RACPACIS  VET47X50  WORKEDYR  SPEAKENG  EMPSTAT  DIFFREM  \\\n",
       "0          AL         0         0         1         2        3        1   \n",
       "1          AL         0         0         3         2        1        1   \n",
       "2          AL         0         0         3         2        1        1   \n",
       "3          AL         0         0         0         0        0        0   \n",
       "4          AL         0         0         1         2        3        2   \n",
       "...       ...       ...       ...       ...       ...      ...      ...   \n",
       "3061687    WY         0         0         2         2        3        1   \n",
       "3061688    WY         0         0         0         0        0        0   \n",
       "3061689    WY         0         0         0         0        0        0   \n",
       "3061690    WY         0         0         3         2        1        1   \n",
       "3061691    WY         0         0         3         2        1        1   \n",
       "\n",
       "         DIFFEYE  SEX  HWSEI  ...  DIFFMOB  NFAMS  METRO  LANGUAGE  VETKOREA  \\\n",
       "0              0    1      0  ...        1      0      1         1         0   \n",
       "1              0    0    279  ...        1      0      2         1         0   \n",
       "2              0    1    269  ...        1      0      2         1         0   \n",
       "3              0    0      0  ...        0      0      2         0         0   \n",
       "4              0    1      0  ...        2      0      4         1         0   \n",
       "...          ...  ...    ...  ...      ...    ...    ...       ...       ...   \n",
       "3061687        0    1     69  ...        1      0      1         1         0   \n",
       "3061688        0    0      0  ...        0      0      1         0         0   \n",
       "3061689        0    1      0  ...        0      0      1         0         0   \n",
       "3061690        0    1     33  ...        1      0      1         1         0   \n",
       "3061691        0    0    170  ...        1      0      1         1         0   \n",
       "\n",
       "         HINSVA  MULTGEN  ROOMS  VET01LTR  AGEORIG  \n",
       "0             0        1      5         0       75  \n",
       "1             0        2      6         0       25  \n",
       "2             0        2      6         0       26  \n",
       "3             0        2      6         0        3  \n",
       "4             0        1      5         0       87  \n",
       "...         ...      ...    ...       ...      ...  \n",
       "3061687       0        2      5         0       26  \n",
       "3061688       0        2      5         0        3  \n",
       "3061689       0        2      5         0        1  \n",
       "3061690       0        1      4         0       55  \n",
       "3061691       0        1      4         0       52  \n",
       "\n",
       "[3061692 rows x 82 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acs_2010"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a9c2477b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from folktables import ACSDataSource, ACSEmployment\n",
    "\n",
    "data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')\n",
    "acs_data = data_source.get_data(states=[\"CA\"], download=True)\n",
    "features, label, group = ACSEmployment.df_to_numpy(acs_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "248e5ea1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.046779134534380984"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "X_train, X_test, y_train, y_test, group_train, group_test = train_test_split(\n",
    "    features, label, group, test_size=0.2, random_state=0)\n",
    "\n",
    "###### Your favorite learning algorithm here #####\n",
    "# model = make_pipeline(StandardScaler(), LogisticRegression())\n",
    "model = LogisticRegression()\n",
    "model.fit(X_train, y_train)\n",
    "\n",
    "yhat = model.predict(X_test)\n",
    "\n",
    "white_tpr = np.mean(yhat[(y_test == 1) & (group_test == 1)])\n",
    "black_tpr = np.mean(yhat[(y_test == 1) & (group_test == 2)])\n",
    "\n",
    "# Equality of opportunity violation: 0.0455\n",
    "white_tpr - black_tpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d499cacf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2 3 4 5 6 7 8 9]\n"
     ]
    }
   ],
   "source": [
    "print(np.unique(group_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0fcbb55d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(303053, 16)\n"
     ]
    }
   ],
   "source": [
    "print(X_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f2b182c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_states = np.unique(acs_2010['STATE'].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4bc91c2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_clients = {}\n",
    "def partition_data(x,y,a,state):\n",
    "    X_train, X_test, y_train, y_test, group_train, group_test = train_test_split(\n",
    "    x,y,a, test_size=0.2, random_state=0)\n",
    "    data = {}\n",
    "    index_of_0_1_train = group_train <= 2\n",
    "    index_of_0_1_test = group_test <= 2\n",
    "    data['x'] = X_train[index_of_0_1_train]\n",
    "    data['a'] = group_train[index_of_0_1_train]\n",
    "    data['y'] = y_train[index_of_0_1_train]\n",
    "    data['x_test'] = X_test[index_of_0_1_test]\n",
    "    data['a_test'] = group_test[index_of_0_1_test]\n",
    "    data['y_test'] = y_test[index_of_0_1_test]\n",
    "    all_clients[state] = data\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "20b1c70c",
   "metadata": {},
   "outputs": [],
   "source": [
    "for state in all_states:\n",
    "    if state != \"DC\":\n",
    "        acs_data = data_source.get_data(states=[state], download=True)\n",
    "        features, label, group = ACSEmployment.df_to_numpy(acs_data)\n",
    "        partition_data(features, label, group, state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ccfc3b62",
   "metadata": {},
   "outputs": [],
   "source": [
    "def repartition_data(all_clients):\n",
    "    data = {}\n",
    "    for i,client in enumerate(all_clients.values()):\n",
    "        if i == 0:\n",
    "            data['x'] = client['x']\n",
    "            data['a'] = client['a']\n",
    "            data['y'] = client['y']\n",
    "            data['x_test'] = client['x_test']\n",
    "            data['a_test'] = client['a_test']\n",
    "            data['y_test'] = client['y_test']\n",
    "        else:\n",
    "            data['x'] = np.concatenate((data['x'],client['x']),axis=0)\n",
    "            data['a'] = np.concatenate((data['a'],client['a']),axis=0)\n",
    "            data['y'] = np.concatenate((data['y'],client['y']),axis=0)\n",
    "            data['x_test'] = np.concatenate((data['x_test'],client['x_test']),axis=0)\n",
    "            data['a_test'] = np.concatenate((data['a_test'],client['a_test']),axis=0)\n",
    "            data['y_test'] = np.concatenate((data['y_test'],client['y_test']),axis=0)\n",
    "    new_client_data = {}\n",
    "    data_x_split = np.array_split(data['x'],50)\n",
    "    data_a_split = np.array_split(data['a'],50)\n",
    "    data_y_split = np.array_split(data['y'],50)\n",
    "    data_x_test_split = np.array_split(data['x_test'],50)\n",
    "    data_a_test_split = np.array_split(data['a_test'],50)\n",
    "    data_y_test_split = np.array_split(data['y_test'],50)\n",
    "    for i,client_name in enumerate(all_clients.keys()):\n",
    "        client_data = {}\n",
    "        client_data['x'] = data_x_split[i]\n",
    "        client_data['a'] = data_a_split[i]\n",
    "        client_data['y'] = data_y_split[i]\n",
    "        client_data['x_test'] = data_x_test_split[i]\n",
    "        client_data['a_test'] = data_a_test_split[i]\n",
    "        client_data['y_test'] = data_y_test_split[i]\n",
    "        new_client_data[client_name] = client_data\n",
    "    return new_client_data\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "4c8680b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_clients = repartition_data(all_clients)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e22b3d67",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import functools\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from copy import deepcopy\n",
    "from sklearn.preprocessing import normalize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "641c6c76",
   "metadata": {},
   "outputs": [],
   "source": [
    "for client in all_clients.values():\n",
    "    client['x'] = torch.tensor(client['x'])\n",
    "    client['y'] = torch.LongTensor(client['y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b453d0e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "for client in all_clients.values():\n",
    "    client['x_test'] = torch.tensor(client['x_test'])\n",
    "    client['y_test'] = torch.LongTensor(client['y_test'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "61f7d626",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super(LogisticRegression, self).__init__()\n",
    "        self.linear = torch.nn.Linear(input_dim, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        outputs = self.linear(x)\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a4988354",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sizes(lst):\n",
    "    sizes = []\n",
    "    for w in lst:\n",
    "        sizes.append(functools.reduce((lambda x, y: x*y), w.size()))\n",
    "    c = np.cumsum(sizes)\n",
    "    bounds = list(zip([0] + c[:-1].tolist(), c.tolist()))\n",
    "    return sizes, bounds\n",
    "\n",
    "def torch_to_numpy(lst, arr=None):\n",
    "    # lst: obtained either from list(net.parameters()) or from torch.autograd.grad\n",
    "    lst = list(lst)\n",
    "    sizes, bounds = get_sizes(lst)\n",
    "    if arr is None:\n",
    "        arr = np.zeros(sum(sizes))\n",
    "    else:\n",
    "        assert len(arr) == sum(sizes)\n",
    "    for bound, var in zip(bounds, lst):\n",
    "        arr[bound[0]: bound[1]] = var.data.cpu().numpy().reshape(-1)\n",
    "    return arr\n",
    "\n",
    "\n",
    "\n",
    "def numpy_to_torch(arr, net):\n",
    "    device = next(net.parameters()).device\n",
    "    arr = torch.from_numpy(arr).to(device)\n",
    "    sizes, bounds = get_sizes(net.parameters())\n",
    "    assert len(arr) == sum(sizes)\n",
    "    for bound, var in zip(bounds, net.parameters()):\n",
    "        vnp = var.data.view(-1)\n",
    "        vnp[:] = arr[bound[0] : bound[1]]\n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "46c4e015",
   "metadata": {},
   "outputs": [],
   "source": [
    "def FedAvg(epochs, client_data, model, lr, global_lr, local_iterations):\n",
    "    protected_attributes = [1,2]\n",
    "    num_instances = 0\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = (client['a'] == p).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_instances += running_instance\n",
    "    print(num_instances)\n",
    "    num_pos_instances = 0\n",
    "    num_neg_instances = 0\n",
    "    num_test_pos_instances = 0\n",
    "    num_test_neg_instances = 0\n",
    "    test_accs = []\n",
    "    max_sps = []\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = ((torch.tensor(client['a']) == p) & (client['y'] == 1)).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_pos_instances += running_instance\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = ((torch.tensor(client['a']) == p) & (client['y'] == 0)).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_neg_instances += running_instance\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = ((torch.tensor(client['a_test']) == p) & (client['y_test'] == 1)).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_test_pos_instances += running_instance\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = ((torch.tensor(client['a_test']) == p) & (client['y_test'] == 0)).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_test_neg_instances += running_instance\n",
    "    \n",
    "    \n",
    "    for i in range(epochs):\n",
    "        if i % 20 == 19:\n",
    "            lr /= 2\n",
    "        all_client_updates = 0\n",
    "        for client in client_data.values():\n",
    "            client_model_copy = deepcopy(model)\n",
    "            optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)\n",
    "            for j in range(local_iterations):\n",
    "                optimizer.zero_grad()\n",
    "                model_output = client_model_copy(client['x'].float())\n",
    "#                 print(model_output)\n",
    "                model_prediction = model_output.argmax(1)\n",
    "#                 import pdb;\n",
    "#                 pdb.set_trace()\n",
    "#                 loss = nn.CrossEntropyLoss()(model_output, client['y'])\n",
    "                loss = (nn.CrossEntropyLoss()(model_output[client['a'] == 2], client['y'][client['a'] == 2]) + nn.CrossEntropyLoss()(model_output[client['a'] == 1], client['y'][client['a'] == 1])) / 2\n",
    "                loss = 0.\n",
    "                for p in protected_attributes:\n",
    "                    loss += nn.CrossEntropyLoss(reduction='mean')(model_output[(torch.tensor(client['a']) == p) & (client['y'] == 1)], client['y'][(torch.tensor(client['a']) == p) & (client['y'] == 1)])\n",
    "                    loss += nn.CrossEntropyLoss(reduction='mean')(model_output[(torch.tensor(client['a']) == p) & (client['y'] == 0)], client['y'][(torch.tensor(client['a']) == p) & (client['y'] == 0)])\n",
    "                loss /= 4.\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "            all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())\n",
    "        all_client_updates /= len(client_data)\n",
    "        w = torch_to_numpy(model.parameters())\n",
    "        w += all_client_updates * global_lr\n",
    "        numpy_to_torch(w, model)\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        \n",
    "#         print('Printing train accuracy...')\n",
    "        for client in client_data.values():\n",
    "            model_output = client_model_copy(client['x'].float())\n",
    "            model_prediction = model_output.argmax(1)\n",
    "            correct += model_prediction.eq(client['y']).sum()\n",
    "            total += len(client['y'])\n",
    "#         print(correct, total)\n",
    "#         print(correct.item()*1.0/total)\n",
    "\n",
    "        if i == epochs-1:\n",
    "                print('Printing test accuracy')\n",
    "                correct = 0\n",
    "                PP = 0\n",
    "                total = 0\n",
    "                PP_by_group = [0]*9\n",
    "                total_by_group = [0]*9\n",
    "                running_theta_test = [0]*9\n",
    "                for client in client_data.values():\n",
    "                    model_output = model(client['x_test'].float())\n",
    "                    model_prediction = model_output.argmax(1)\n",
    "#                     PP += model_prediction.sum()\n",
    "                    for idx,p in enumerate(protected_attributes):\n",
    "                        if len(model_output[(torch.tensor(client['a_test']) == p) & (client['y_test'] == 1)]) < 1:\n",
    "                            running_theta_test[idx] += torch.tensor(0)\n",
    "                            continue\n",
    "                        the = nn.CrossEntropyLoss(reduction='sum')(model_output[(torch.tensor(client['a_test']) == p) & (client['y_test'] == 1)], client['y_test'][(torch.tensor(client['a_test']) == p) & (client['y_test'] == 1)])\n",
    "                        running_theta_test[idx] += the\n",
    "\n",
    "                    for idx,p in enumerate(protected_attributes):\n",
    "                        if len(model_output[(torch.tensor(client['a_test']) == p) & (client['y_test'] == 0)]) < 1:\n",
    "                            running_theta_test[idx+2] += torch.tensor(0)\n",
    "                            continue\n",
    "                        the = nn.CrossEntropyLoss(reduction='sum')(model_output[(torch.tensor(client['a_test']) == p) & (client['y_test'] == 0)], client['y_test'][(torch.tensor(client['a_test']) == p) & (client['y_test'] == 0)])\n",
    "                        running_theta_test[idx+2] += the\n",
    "#                         PP_by_group[idx] += model_prediction[client['a_test'] == p].sum()\n",
    "#                         total_by_group[idx] += len(model_prediction[client['a_test'] == p])\n",
    "                    correct += model_prediction.eq(client['y_test']).sum()\n",
    "                    total += len(client['y_test'])\n",
    "                print(correct.item()*1.0/total)\n",
    "                test_accs.append(correct.item()*1.0/total)\n",
    "\n",
    "                SP_by_group = [0]*9\n",
    "                for idx,p in enumerate(protected_attributes):\n",
    "#                     SP_by_group[idx] = (PP_by_group[idx] / total_by_group[idx]).item()\n",
    "                    running_theta_test[idx] /= num_test_pos_instances[idx]\n",
    "                for idx,p in enumerate(protected_attributes):\n",
    "#                     SP_by_group[idx] = (PP_by_group[idx] / total_by_group[idx]).item()\n",
    "                    running_theta_test[idx+2] /= num_test_neg_instances[idx]\n",
    "                print(\"TP loss: \", running_theta_test[:2])\n",
    "                print(\"FP loss: \", running_theta_test[2:])\n",
    "                max_sps.append(running_theta_test)\n",
    "        \n",
    "#         if i % 200 == 199:\n",
    "#             print('Printing test accuracy')\n",
    "#             correct = 0\n",
    "#             PP = 0\n",
    "#             total = 0\n",
    "#             PP_by_group = [0]*9\n",
    "#             total_by_group = [0]*9\n",
    "#             running_theta = [0]*9\n",
    "            \n",
    "#             for client in client_data.values():\n",
    "#                 model_output = model(client['x_test'].float())\n",
    "#                 model_prediction = model_output.argmax(1)\n",
    "#                 PP += model_prediction.sum()\n",
    "\n",
    "#                 for idx,p in enumerate(protected_attributes):\n",
    "#                     if len(model_output[client['a_test'] == p]) < 1:\n",
    "#                         running_theta[idx] += torch.tensor(0)\n",
    "#                         continue\n",
    "#                     the = nn.CrossEntropyLoss(reduction='sum')(model_output[client['a_test'] == p], client['y_test'][client['a_test'] == p])\n",
    "#                     running_theta[idx] += the\n",
    "#                     PP_by_group[idx] += model_prediction[client['a_test'] == p].sum()\n",
    "#                     total_by_group[idx] += len(model_prediction[client['a_test'] == p])\n",
    "#                 correct += model_prediction.eq(client['y_test']).sum()\n",
    "#                 total += len(client['y_test'])\n",
    "#             print(correct.item()*1.0/total)\n",
    "\n",
    "#             SP_by_group = [0]*9\n",
    "#             for idx,p in enumerate(protected_attributes):\n",
    "#                 SP_by_group[idx] = (PP_by_group[idx] / total_by_group[idx]).item()\n",
    "#                 running_theta[idx] /= num_instances[idx]\n",
    "#             print(running_theta)\n",
    "#             print('Max SP gap: ', max(SP_by_group)-min(SP_by_group))\n",
    "#             print('SP std: ', np.array(SP_by_group).std())\n",
    "        \n",
    "        \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "5a5690e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration. The names mirror the parameters of the Fair_FedAvg* routines.\n",
    "\n",
    "# Global model; the arguments are presumably (n_features, n_classes) -- confirm against LogisticRegression.\n",
    "model = LogisticRegression(16, 2)\n",
    "\n",
    "# Loop sizes: outer rounds, epochs per round, client SGD steps per epoch.\n",
    "rounds = 10\n",
    "epochs = 200\n",
    "local_iterations = 3\n",
    "\n",
    "# Step sizes: client SGD, server-side scaling of the averaged update, dual variable theta.\n",
    "lr = 0.1\n",
    "global_lr = 1\n",
    "lr_theta = 0.1\n",
    "\n",
    "# Constraint settings: multiplier scale B and the loss thresholds.\n",
    "B = 1\n",
    "threshold = 1\n",
    "pos_threshold = 1\n",
    "neg_threshold = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "199aa06f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Fair_FedAvg(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, threshold):\n",
    "    \"\"\"Run FedAvg on a cross-entropy objective penalised by per-group loss constraints.\n",
    "\n",
    "    Every outer round fixes the multipliers\n",
    "    ``lmbda = B * exp(theta) / (1 + sum(exp(theta)))``, runs ``epochs`` FedAvg\n",
    "    communication rounds on the penalised client objective, evaluates on the\n",
    "    test splits, and then moves ``theta`` along the measured constraint values.\n",
    "\n",
    "    Args:\n",
    "        epochs: FedAvg communication rounds per outer round.\n",
    "        client_data: mapping whose values are per-client dicts with the keys\n",
    "            'x', 'y', 'a' (train) and 'x_test', 'y_test', 'a_test' (test).\n",
    "        model: global model; the averaged weights are handed to\n",
    "            ``numpy_to_torch(w, model)`` after every epoch (presumably an\n",
    "            in-place update -- confirm against that helper).\n",
    "        protected_attributes: sequence of group ids compared against 'a'.\n",
    "        initial_lr: client SGD step size; divided by 5 at the start of every\n",
    "            outer round (including the first) and halved every 20 epochs.\n",
    "        local_iterations: SGD steps each client takes per epoch.\n",
    "        rounds: number of outer rounds (updates of ``theta``).\n",
    "        B: scale factor of the multipliers.\n",
    "        lr_theta: step size of the ``theta`` update.\n",
    "        threshold: value subtracted from each group's mean cross-entropy.\n",
    "\n",
    "    Returns:\n",
    "        ``(avg_model, test_accs, max_sps)``: a copy of the model fed with the\n",
    "        running average of the end-of-round weights, the test accuracy of\n",
    "        each round, and each round's list of per-group test losses (padded\n",
    "        with zeros to at least 9 entries, as before).\n",
    "    \"\"\"\n",
    "    n_groups = len(protected_attributes)\n",
    "    theta = torch.tensor(np.zeros(n_groups))\n",
    "    average_iterate = 0\n",
    "    iterates = 0\n",
    "    avg_model = deepcopy(model)\n",
    "    test_accs = []\n",
    "    max_sps = []\n",
    "    mean_ce = nn.CrossEntropyLoss(reduction='mean')\n",
    "    sum_ce = nn.CrossEntropyLoss(reduction='sum')\n",
    "\n",
    "    def count_per_group(attr_key):\n",
    "        # Number of samples of every protected group, summed over all clients.\n",
    "        return sum(\n",
    "            np.array([(client[attr_key] == p).sum() for p in protected_attributes])\n",
    "            for client in client_data.values()\n",
    "        )\n",
    "\n",
    "    num_instances = count_per_group('a')\n",
    "    num_test_instances = count_per_group('a_test')\n",
    "    print(num_instances, num_test_instances)\n",
    "\n",
    "    for round_idx in range(rounds):\n",
    "        lmbda = B * theta.exp() / (1 + theta.exp().sum())\n",
    "        print(lmbda)\n",
    "        grad_theta = 0\n",
    "        initial_lr /= 5\n",
    "        lr = initial_lr\n",
    "        for i in range(epochs):\n",
    "            if i % 20 == 19:\n",
    "                lr /= 2\n",
    "            all_client_updates = 0\n",
    "            for client in client_data.values():\n",
    "                client_model_copy = deepcopy(model)\n",
    "                optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)\n",
    "                for local_step in range(local_iterations):\n",
    "                    optimizer.zero_grad()\n",
    "                    model_output = client_model_copy(client['x'].float())\n",
    "                    loss = 0\n",
    "                    for p, l in zip(protected_attributes, lmbda):\n",
    "                        group_mask = client['a'] == p\n",
    "                        if len(model_output[group_mask]) < 1:\n",
    "                            continue\n",
    "                        # Local fairness: the client's own mean loss on the group. The\n",
    "                        # global variant would instead divide the summed loss by the\n",
    "                        # group's total count in num_instances.\n",
    "                        loss += l * (mean_ce(model_output[group_mask], client['y'][group_mask]) - threshold)\n",
    "                    loss += mean_ce(model_output, client['y'])\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "                all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())\n",
    "            # Server step: move the global weights by the mean client delta.\n",
    "            all_client_updates /= len(client_data)\n",
    "            w = torch_to_numpy(model.parameters())\n",
    "            w += all_client_updates\n",
    "            numpy_to_torch(w, model)\n",
    "\n",
    "            if i != epochs - 1:\n",
    "                continue\n",
    "\n",
    "            # ---- end of the outer round: constraint values for the theta update ----\n",
    "            # The values are copied into a fresh tensor right away, so no autograd\n",
    "            # graph is needed for this forward pass.\n",
    "            with torch.no_grad():\n",
    "                for client in client_data.values():\n",
    "                    running_theta = []\n",
    "                    model_output = model(client['x'].float())\n",
    "                    for p in protected_attributes:\n",
    "                        group_mask = client['a'] == p\n",
    "                        if len(model_output[group_mask]) < 1:\n",
    "                            running_theta.append(torch.tensor(0.))\n",
    "                            continue\n",
    "                        running_theta.append(mean_ce(model_output[group_mask], client['y'][group_mask]) - threshold)\n",
    "                    grad_theta = grad_theta + torch.tensor(running_theta)\n",
    "\n",
    "            iterates += 1\n",
    "            average_iterate += torch_to_numpy(model.parameters())\n",
    "            numpy_to_torch(average_iterate / iterates, avg_model)\n",
    "\n",
    "            # ---- end of the outer round: test metrics ----\n",
    "            print('Printing test accuracy')\n",
    "            correct = 0\n",
    "            total = 0\n",
    "            PP_by_group = [0] * n_groups\n",
    "            total_by_group = [0] * n_groups\n",
    "            # At least 9 entries so the layout of the returned per-round lists is\n",
    "            # unchanged for existing callers; it now also grows for more groups.\n",
    "            running_theta_test = [0] * max(9, n_groups)\n",
    "            # NOTE(review): these losses keep their autograd graph because they are\n",
    "            # returned through max_sps; detach them if callers only need the values.\n",
    "            for client in client_data.values():\n",
    "                model_output = model(client['x_test'].float())\n",
    "                model_prediction = model_output.argmax(1)\n",
    "                for idx, p in enumerate(protected_attributes):\n",
    "                    group_mask = client['a_test'] == p\n",
    "                    if len(model_output[group_mask]) < 1:\n",
    "                        # Float zero: an integer accumulator would reject the\n",
    "                        # in-place float add of a later client's loss.\n",
    "                        running_theta_test[idx] += torch.tensor(0.)\n",
    "                        continue\n",
    "                    running_theta_test[idx] += sum_ce(model_output[group_mask], client['y_test'][group_mask])\n",
    "                    PP_by_group[idx] += model_prediction[group_mask].sum()\n",
    "                    total_by_group[idx] += len(model_prediction[group_mask])\n",
    "                correct += model_prediction.eq(client['y_test']).sum()\n",
    "                total += len(client['y_test'])\n",
    "            test_acc = correct.item() * 1.0 / total\n",
    "            print(test_acc)\n",
    "            test_accs.append(test_acc)\n",
    "\n",
    "            # Sum of predicted class indices over group size, i.e. the\n",
    "            # positive-prediction rate for 0/1 labels. Only groups that occur in the\n",
    "            # test data take part, so neither padding zeros nor empty groups distort\n",
    "            # the gap / std below.\n",
    "            SP_by_group = [\n",
    "                (PP_by_group[idx] / total_by_group[idx]).item()\n",
    "                for idx in range(n_groups)\n",
    "                if total_by_group[idx] > 0\n",
    "            ]\n",
    "            for idx in range(n_groups):\n",
    "                running_theta_test[idx] /= num_test_instances[idx]\n",
    "            print(running_theta_test)\n",
    "            max_sps.append(running_theta_test)\n",
    "            print('Max SP gap: ', max(SP_by_group) - min(SP_by_group))\n",
    "            print('SP std: ', np.array(SP_by_group).std())\n",
    "\n",
    "        # NOTE(review): grad_theta is already built from per-client mean losses, so\n",
    "        # this extra division by the global group count looks like a leftover from\n",
    "        # the sum-reduced variant (Fair_FedAvg_TP_FP has the equivalent step\n",
    "        # commented out). Left as is to keep results unchanged -- confirm.\n",
    "        for idx in range(n_groups):\n",
    "            grad_theta[idx] = grad_theta[idx] / num_instances[idx]\n",
    "        theta += lr_theta * grad_theta\n",
    "\n",
    "    return avg_model, test_accs, max_sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "eab9f40e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Fair_FedAvg_TP_FP(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, pos_threshold, neg_threshold):\n",
    "    \"\"\"FedAvg with dual-weighted loss constraints per (protected group, label) subgroup.\n",
    "\n",
    "    Each constraint is the summed cross-entropy of a client's samples in one\n",
    "    subgroup divided by that subgroup's total count over all clients, minus a\n",
    "    threshold. Every outer round fixes the multipliers\n",
    "    ``lmbda = B * exp(theta) / (1 + sum(exp(theta)))``, runs ``epochs`` FedAvg\n",
    "    communication rounds, evaluates on the test splits, and then updates\n",
    "    ``theta`` with the constraint values summed over clients.\n",
    "\n",
    "    Args:\n",
    "        epochs: FedAvg communication rounds per outer round.\n",
    "        client_data: mapping whose values are per-client dicts with the keys\n",
    "            'x', 'y', 'a' (train) and 'x_test', 'y_test', 'a_test' (test).\n",
    "        model: global model; the averaged weights are handed to\n",
    "            ``numpy_to_torch(w, model)`` after every epoch (presumably an\n",
    "            in-place update -- confirm against that helper).\n",
    "        protected_attributes: sequence of group ids compared against 'a'.\n",
    "        initial_lr: client SGD step size; divided by 5 at the start of every\n",
    "            outer round (including the first) and halved every 20 epochs.\n",
    "        local_iterations: SGD steps each client takes per epoch.\n",
    "        rounds: number of outer rounds (updates of ``theta``).\n",
    "        B: scale factor of the multipliers.\n",
    "        lr_theta: step size of the ``theta`` update.\n",
    "        pos_threshold: threshold of the constraints on label-1 samples.\n",
    "        neg_threshold: threshold of the constraints on label-0 samples.\n",
    "\n",
    "    Returns:\n",
    "        ``(avg_model, test_accs, max_sps)``: a copy of the model fed with the\n",
    "        running average of the end-of-round weights, the test accuracy of\n",
    "        each round, and each round's list of test losses: label-1 losses per\n",
    "        group first, then label-0 losses per group, zero-padded to at least\n",
    "        9 entries as before.\n",
    "    \"\"\"\n",
    "    # Hard-coded experiment switches.\n",
    "    TP = False        # constrain the loss on label-1 samples of every group\n",
    "    FP = True         # constrain the loss on label-0 samples of every group\n",
    "    weighted = False  # 10x weight on group 1 and a label-balanced base loss\n",
    "\n",
    "    n_groups = len(protected_attributes)\n",
    "    theta = torch.tensor(np.zeros(n_groups * 2 if (TP and FP) else n_groups))\n",
    "    average_iterate = 0\n",
    "    iterates = 0\n",
    "    avg_model = deepcopy(model)\n",
    "    test_accs = []\n",
    "    max_sps = []\n",
    "    mean_ce = nn.CrossEntropyLoss()\n",
    "    sum_ce = nn.CrossEntropyLoss(reduction='sum')\n",
    "\n",
    "    def subgroup_mask(client, attr_key, label_key, p, label):\n",
    "        # Samples of `client` that belong to group `p` and carry `label`.\n",
    "        return (torch.tensor(client[attr_key]) == p) & (client[label_key] == label)\n",
    "\n",
    "    def count_subgroups(attr_key, label_key, label):\n",
    "        # Per-group sample counts for one label, summed over all clients.\n",
    "        return sum(\n",
    "            np.array([subgroup_mask(client, attr_key, label_key, p, label).sum() for p in protected_attributes])\n",
    "            for client in client_data.values()\n",
    "        )\n",
    "\n",
    "    def group_weight(p):\n",
    "        # Extra weight on group 1, only in the `weighted` configuration.\n",
    "        return 10 if ((p == 1) and weighted) else 1\n",
    "\n",
    "    num_pos_instances = count_subgroups('a', 'y', 1)\n",
    "    num_neg_instances = count_subgroups('a', 'y', 0)\n",
    "    num_test_pos_instances = count_subgroups('a_test', 'y_test', 1)\n",
    "    num_test_neg_instances = count_subgroups('a_test', 'y_test', 0)\n",
    "    print(num_pos_instances, num_neg_instances, num_test_pos_instances, num_test_neg_instances)\n",
    "\n",
    "    for round_idx in range(rounds):\n",
    "        lmbda = B * theta.exp() / (1 + theta.exp().sum())\n",
    "        print(theta)\n",
    "        # Active constraints as (label, multipliers, global subgroup counts, threshold).\n",
    "        # The label-1 multipliers occupy the first n_groups entries of theta when TP\n",
    "        # is on; the label-0 ones follow them, or come first when TP is off. Slicing\n",
    "        # by n_groups (not a literal 2) keeps every group when there are more than two.\n",
    "        constraints = []\n",
    "        if TP:\n",
    "            constraints.append((1, lmbda[:n_groups], num_pos_instances, pos_threshold))\n",
    "        if FP:\n",
    "            target_dual = lmbda[n_groups:] if TP else lmbda[:n_groups]\n",
    "            constraints.append((0, target_dual, num_neg_instances, neg_threshold))\n",
    "        grad_theta = 0\n",
    "        initial_lr /= 5\n",
    "        lr = initial_lr\n",
    "        for i in range(epochs):\n",
    "            if i % 20 == 19:\n",
    "                lr /= 2\n",
    "            all_client_updates = 0\n",
    "            for client in client_data.values():\n",
    "                client_model_copy = deepcopy(model)\n",
    "                optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)\n",
    "                for local_step in range(local_iterations):\n",
    "                    optimizer.zero_grad()\n",
    "                    model_output = client_model_copy(client['x'].float())\n",
    "                    loss = 0\n",
    "                    for label, duals, counts, bound in constraints:\n",
    "                        for p, l, ins in zip(protected_attributes, duals, counts):\n",
    "                            # Skip groups this client does not hold at all.\n",
    "                            if len(model_output[client['a'] == p]) < 1:\n",
    "                                continue\n",
    "                            mask = subgroup_mask(client, 'a', 'y', p, label)\n",
    "                            loss += l * group_weight(p) * (sum_ce(model_output[mask], client['y'][mask]) / ins - bound)\n",
    "                    if weighted:\n",
    "                        neg_mask = client['y'] == 0\n",
    "                        pos_mask = client['y'] == 1\n",
    "                        loss += (mean_ce(model_output[neg_mask], client['y'][neg_mask]) + mean_ce(model_output[pos_mask], client['y'][pos_mask])) / 2\n",
    "                    else:\n",
    "                        loss += mean_ce(model_output, client['y'])\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "                all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())\n",
    "            # Server step: move the global weights by the mean client delta.\n",
    "            all_client_updates /= len(client_data)\n",
    "            w = torch_to_numpy(model.parameters())\n",
    "            w += all_client_updates\n",
    "            numpy_to_torch(w, model)\n",
    "\n",
    "            if i != epochs - 1:\n",
    "                continue\n",
    "\n",
    "            # ---- end of the outer round: constraint values for the theta update ----\n",
    "            # The values are copied into a fresh tensor right away, so no autograd\n",
    "            # graph is needed for this forward pass.\n",
    "            with torch.no_grad():\n",
    "                for client in client_data.values():\n",
    "                    running_theta = []\n",
    "                    model_output = model(client['x'].float())\n",
    "                    for label, _, counts, bound in constraints:\n",
    "                        for p, ins in zip(protected_attributes, counts):\n",
    "                            if len(model_output[client['a'] == p]) < 1:\n",
    "                                running_theta.append(torch.tensor(0.))\n",
    "                                continue\n",
    "                            mask = subgroup_mask(client, 'a', 'y', p, label)\n",
    "                            running_theta.append(group_weight(p) * sum_ce(model_output[mask], client['y'][mask]) / ins - bound)\n",
    "                    grad_theta = grad_theta + torch.tensor(running_theta)\n",
    "\n",
    "            iterates += 1\n",
    "            average_iterate += torch_to_numpy(model.parameters())\n",
    "            numpy_to_torch(average_iterate / iterates, avg_model)\n",
    "\n",
    "            # ---- end of the outer round: test metrics ----\n",
    "            print('Printing test accuracy')\n",
    "            correct = 0\n",
    "            total = 0\n",
    "            # Slots [0, n_groups) hold label-1 losses, [n_groups, 2*n_groups) label-0\n",
    "            # losses. At least 9 entries so the returned layout is unchanged for the\n",
    "            # two-group case.\n",
    "            running_theta_test = [0] * max(9, 2 * n_groups)\n",
    "            # NOTE(review): these losses keep their autograd graph because they are\n",
    "            # returned through max_sps; detach them if callers only need the values.\n",
    "            for client in client_data.values():\n",
    "                model_output = model(client['x_test'].float())\n",
    "                model_prediction = model_output.argmax(1)\n",
    "                for offset, label in ((0, 1), (n_groups, 0)):\n",
    "                    for idx, p in enumerate(protected_attributes):\n",
    "                        mask = subgroup_mask(client, 'a_test', 'y_test', p, label)\n",
    "                        if len(model_output[mask]) < 1:\n",
    "                            # Float zero: an integer accumulator would reject the\n",
    "                            # in-place float add of a later client's loss.\n",
    "                            running_theta_test[offset + idx] += torch.tensor(0.)\n",
    "                            continue\n",
    "                        running_theta_test[offset + idx] += sum_ce(model_output[mask], client['y_test'][mask])\n",
    "                correct += model_prediction.eq(client['y_test']).sum()\n",
    "                total += len(client['y_test'])\n",
    "            test_acc = correct.item() * 1.0 / total\n",
    "            print(test_acc)\n",
    "            test_accs.append(test_acc)\n",
    "\n",
    "            # Turn the summed losses into per-sample averages over all clients.\n",
    "            for idx in range(n_groups):\n",
    "                running_theta_test[idx] /= num_test_pos_instances[idx]\n",
    "                running_theta_test[idx + n_groups] /= num_test_neg_instances[idx]\n",
    "            print(\"TP loss: \", running_theta_test[:n_groups])\n",
    "            print(\"FP loss: \", running_theta_test[n_groups:])\n",
    "            max_sps.append(running_theta_test)\n",
    "\n",
    "        # The constraint values are already normalised by the global subgroup\n",
    "        # counts, so theta is updated with their plain sum over clients.\n",
    "        theta += lr_theta * grad_theta\n",
    "\n",
    "    return avg_model, test_accs, max_sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "5c39bd13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Fair_FedAvg_TP_FP_local(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, pos_threshold, neg_threshold):\n",
    "    TP = False\n",
    "    FP = True\n",
    "    weighted = False\n",
    "    if TP and FP:\n",
    "        theta = torch.tensor(np.zeros(len(protected_attributes)*2))\n",
    "    else:\n",
    "        theta = torch.tensor(np.zeros(len(protected_attributes)))\n",
    "    average_iterate = 0\n",
    "    iterates = 0\n",
    "    avg_model = deepcopy(model)\n",
    "    num_pos_instances = 0\n",
    "    num_neg_instances = 0\n",
    "    num_test_pos_instances = 0\n",
    "    num_test_neg_instances = 0\n",
    "    test_accs = []\n",
    "    max_sps = []\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = ((torch.tensor(client['a']) == p) & (client['y'] == 1)).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_pos_instances += running_instance\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = ((torch.tensor(client['a']) == p) & (client['y'] == 0)).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_neg_instances += running_instance\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = ((torch.tensor(client['a_test']) == p) & (client['y_test'] == 1)).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_test_pos_instances += running_instance\n",
    "    for client in client_data.values():\n",
    "        running_instance = []\n",
    "        for p in protected_attributes:\n",
    "            instance = ((torch.tensor(client['a_test']) == p) & (client['y_test'] == 0)).sum()\n",
    "            running_instance.append(instance)\n",
    "        running_instance = np.array(running_instance)\n",
    "        num_test_neg_instances += running_instance\n",
    "    print(num_pos_instances, num_neg_instances, num_test_pos_instances, num_test_neg_instances)\n",
    "    \n",
    "    for k in range(rounds):\n",
    "        lmbda = B*theta.exp()/(1+theta.exp().sum())\n",
    "        print(theta)\n",
    "        grad_theta = 0\n",
    "        initial_lr /= 5\n",
    "        lr = initial_lr\n",
    "        for i in range(epochs):\n",
    "            if i % 20 == 19:\n",
    "                lr /= 2\n",
    "            all_client_updates = 0\n",
    "            for client in client_data.values():\n",
    "                client_model_copy = deepcopy(model)\n",
    "                optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)\n",
    "                for j in range(local_iterations):\n",
    "                    optimizer.zero_grad()\n",
    "                    model_output = client_model_copy(client['x'].float())\n",
    "                    loss = 0\n",
    "                    if TP:\n",
    "                        for p, l, ins in zip(protected_attributes, lmbda[:2], num_pos_instances):\n",
    "                            if len(model_output[client['a'] == p]) < 1:\n",
    "                                continue                        \n",
    "                            loss += l * (10 if ((p == 1) and weighted) else 1) * (nn.CrossEntropyLoss(reduction='mean')(model_output[(torch.tensor(client['a']) == p) & (client['y'] == 1)], client['y'][(torch.tensor(client['a']) == p) & (client['y'] == 1)]) - pos_threshold / 50)  \n",
    "                    if FP:\n",
    "                        if TP:\n",
    "                            target_dual = lmbda[2:]\n",
    "                        else:\n",
    "                            target_dual = lmbda[:2]\n",
    "                        for p, l, ins in zip(protected_attributes, target_dual, num_neg_instances):\n",
    "                            if len(model_output[client['a'] == p]) < 1:\n",
    "                                continue                        \n",
    "                            loss += l * (10 if ((p == 1) and weighted) else 1) * (nn.CrossEntropyLoss(reduction='mean')(model_output[(torch.tensor(client['a']) == p) & (client['y'] == 0)], client['y'][(torch.tensor(client['a']) == p) & (client['y'] == 0)]) - neg_threshold / 50)  \n",
    "                    if weighted:\n",
    "                        loss += (nn.CrossEntropyLoss()(model_output[client['y'] == 0], client['y'][client['y'] == 0]) + nn.CrossEntropyLoss()(model_output[client['y'] == 1], client['y'][client['y'] == 1]))/2\n",
    "                    else:\n",
    "                        loss += nn.CrossEntropyLoss()(model_output, client['y'])\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "                all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())\n",
    "            all_client_updates /= len(client_data)\n",
    "            w = torch_to_numpy(model.parameters())\n",
    "            w += all_client_updates\n",
    "            numpy_to_torch(w, model)\n",
    "            correct = 0\n",
    "            total = 0\n",
    "            \n",
    "            if i == epochs-1:\n",
    "                for client in client_data.values():\n",
    "                    running_theta = []\n",
    "                    model_output = model(client['x'].float())\n",
    "                    if TP:\n",
    "                        for p, ins in zip(protected_attributes, num_pos_instances):\n",
    "                            if len(model_output[client['a'] == p]) < 1:\n",
    "                                running_theta.append(torch.tensor(0))\n",
    "                                continue\n",
    "\n",
    "                            the = (10 if ((p == 1) and weighted) else 1) * nn.CrossEntropyLoss(reduction='mean')(model_output[(torch.tensor(client['a']) == p) & (client['y'] == 1)], client['y'][(torch.tensor(client['a']) == p) & (client['y'] == 1)]) / ins - pos_threshold / 50\n",
    "                            running_theta.append(the)\n",
    "                    if FP:\n",
    "                        for p, ins in zip(protected_attributes, num_neg_instances):\n",
    "                            if len(model_output[client['a'] == p]) < 1:\n",
    "                                running_theta.append(torch.tensor(0))\n",
    "                                continue\n",
    "\n",
    "                            the = (10 if ((p == 1) and weighted) else 1) * nn.CrossEntropyLoss(reduction='mean')(model_output[(torch.tensor(client['a']) == p) & (client['y'] == 0)], client['y'][(torch.tensor(client['a']) == p) & (client['y'] == 0)]) / ins - neg_threshold / 50\n",
    "                            running_theta.append(the)\n",
    "                    running_theta = torch.tensor(running_theta)\n",
    "                    grad_theta = grad_theta + running_theta\n",
    "                \n",
    "                iterates += 1\n",
    "                average_iterate += torch_to_numpy(model.parameters())\n",
    "                numpy_to_torch(average_iterate/iterates, avg_model)\n",
    "\n",
    "            \n",
    "            if i == epochs-1:\n",
    "                print('Printing test accuracy')\n",
    "                correct = 0\n",
    "                PP = 0\n",
    "                total = 0\n",
    "                PP_by_group = [0]*9\n",
    "                total_by_group = [0]*9\n",
    "                running_theta_test = [0]*9\n",
    "                for client in client_data.values():\n",
    "                    model_output = model(client['x_test'].float())\n",
    "                    model_prediction = model_output.argmax(1)\n",
    "#                     PP += model_prediction.sum()\n",
    "                    for idx,p in enumerate(protected_attributes):\n",
    "                        if len(model_output[(torch.tensor(client['a_test']) == p) & (client['y_test'] == 1)]) < 1:\n",
    "                            running_theta_test[idx] += torch.tensor(0)\n",
    "                            continue\n",
    "                        the = nn.CrossEntropyLoss(reduction='sum')(model_output[(torch.tensor(client['a_test']) == p) & (client['y_test'] == 1)], client['y_test'][(torch.tensor(client['a_test']) == p) & (client['y_test'] == 1)])\n",
    "                        running_theta_test[idx] += the\n",
    "\n",
    "                    for idx,p in enumerate(protected_attributes):\n",
    "                        if len(model_output[(torch.tensor(client['a_test']) == p) & (client['y_test'] == 0)]) < 1:\n",
    "                            running_theta_test[idx+2] += torch.tensor(0)\n",
    "                            continue\n",
    "                        the = nn.CrossEntropyLoss(reduction='sum')(model_output[(torch.tensor(client['a_test']) == p) & (client['y_test'] == 0)], client['y_test'][(torch.tensor(client['a_test']) == p) & (client['y_test'] == 0)])\n",
    "                        running_theta_test[idx+2] += the\n",
    "#                         PP_by_group[idx] += model_prediction[client['a_test'] == p].sum()\n",
    "#                         total_by_group[idx] += len(model_prediction[client['a_test'] == p])\n",
    "                    correct += model_prediction.eq(client['y_test']).sum()\n",
    "                    total += len(client['y_test'])\n",
    "                print(correct.item()*1.0/total)\n",
    "                test_accs.append(correct.item()*1.0/total)\n",
    "\n",
    "                SP_by_group = [0]*9\n",
    "                for idx,p in enumerate(protected_attributes):\n",
    "#                     SP_by_group[idx] = (PP_by_group[idx] / total_by_group[idx]).item()\n",
    "                    running_theta_test[idx] /= num_test_pos_instances[idx]\n",
    "                for idx,p in enumerate(protected_attributes):\n",
    "#                     SP_by_group[idx] = (PP_by_group[idx] / total_by_group[idx]).item()\n",
    "                    running_theta_test[idx+2] /= num_test_neg_instances[idx]\n",
    "                print(\"TP loss: \", running_theta_test[:2])\n",
    "                print(\"FP loss: \", running_theta_test[2:])\n",
    "                max_sps.append(running_theta_test)\n",
    "#                 print('Max SP gap: ', max(SP_by_group)-min(SP_by_group))\n",
    "#                 print('SP std: ', np.array(SP_by_group).std())\n",
    "        \n",
    "#         for idx,p in enumerate(protected_attributes):\n",
    "#             if TP:\n",
    "#                 grad_theta[idx] = grad_theta[idx] / num_pos_instances[idx]\n",
    "#             if FP:\n",
    "#                 target = 2 if TP else 0\n",
    "#                 grad_theta[idx+target] = grad_theta[idx+target] / num_neg_instances[idx]\n",
    "#         print(grad_theta)\n",
    "        theta += lr_theta * grad_theta\n",
    "#         print(theta, grad_theta)\n",
    "        \n",
    "        \n",
    "    return avg_model, test_accs, max_sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "ec3daefa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def FedMinMax(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, threshold):\n",
    "    \"\"\"Federated min-max training over protected groups.\n",
    "\n",
    "    Every round each client copies the global model and runs `local_iterations`\n",
    "    SGD steps on the sum over groups of\n",
    "    theta[g] * (summed cross-entropy of group g on that client) / (training count of g over all clients).\n",
    "    The mean of the client parameter deltas is added to the global parameters,\n",
    "    and theta is then increased in proportion to each group's training loss.\n",
    "    On the last round the test accuracy, per-group mean test loss and the\n",
    "    spread of per-group positive-prediction rates are printed.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of communication rounds.\n",
    "        client_data: dict of per-client dicts, read through the keys\n",
    "            'x', 'y', 'a' (training) and 'x_test', 'y_test', 'a_test' (test).\n",
    "        model: global torch model; new parameters are written back each round\n",
    "            through numpy_to_torch (presumably in place - confirm in that helper).\n",
    "        protected_attributes: sequence of group ids compared against 'a' / 'a_test'.\n",
    "        initial_lr: learning rate of each client's SGD optimiser.\n",
    "        local_iterations: SGD steps per client per round.\n",
    "        rounds, B, threshold: not used by this function; kept so existing calls still work.\n",
    "        lr_theta: step size of the group-weight (theta) update.\n",
    "\n",
    "    Returns:\n",
    "        (avg_model, test_accs, max_sps): a copy of the model loaded with the running\n",
    "        average of the global parameters over rounds, a list holding the last-round\n",
    "        test accuracy, and a list holding the last-round per-group mean test losses.\n",
    "    \"\"\"\n",
    "    n_groups = len(protected_attributes)\n",
    "    # One weight per protected group, initialised uniformly.\n",
    "    theta = torch.tensor(np.ones(n_groups))*1.0 / n_groups\n",
    "    average_iterate = 0\n",
    "    iterates = 0\n",
    "    avg_model = deepcopy(model)\n",
    "    test_accs = []\n",
    "    max_sps = []\n",
    "\n",
    "    # Number of training samples in each group, summed over all clients.\n",
    "    num_instances = 0\n",
    "    for client in client_data.values():\n",
    "        num_instances += np.array([(client['a'] == p).sum() for p in protected_attributes])\n",
    "    print(num_instances)\n",
    "\n",
    "    for i in range(epochs):\n",
    "        lr = initial_lr\n",
    "        lmbda = theta\n",
    "\n",
    "        # ---- client updates -------------------------------------------------\n",
    "        all_client_updates = 0\n",
    "        for client in client_data.values():\n",
    "            client_model_copy = deepcopy(model)\n",
    "            optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)\n",
    "            for _ in range(local_iterations):\n",
    "                optimizer.zero_grad()\n",
    "                model_output = client_model_copy(client['x'].float())\n",
    "                loss = 0\n",
    "                for p, l, ins in zip(protected_attributes, lmbda, num_instances):\n",
    "                    in_group = client['a'] == p\n",
    "                    if len(model_output[in_group]) < 1:\n",
    "                        continue\n",
    "                    loss += l * (nn.CrossEntropyLoss(reduction='sum')(model_output[in_group], client['y'][in_group]) / ins)\n",
    "                if not torch.is_tensor(loss):\n",
    "                    # No sample of any listed group on this client: loss is still the\n",
    "                    # int 0, so there is nothing to back-propagate.\n",
    "                    break\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "            all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())\n",
    "\n",
    "        # ---- server aggregation: add the mean client delta ------------------\n",
    "        all_client_updates /= len(client_data)\n",
    "        w = torch_to_numpy(model.parameters())\n",
    "        w += all_client_updates\n",
    "        numpy_to_torch(w, model)\n",
    "\n",
    "        # ---- per-group training loss of the new global model ----------------\n",
    "        grad_theta = 0\n",
    "        # Pure evaluation, so no autograd graph is recorded.\n",
    "        with torch.no_grad():\n",
    "            for client in client_data.values():\n",
    "                running_theta = []\n",
    "                model_output = model(client['x'].float())\n",
    "                for p, ins in zip(protected_attributes, num_instances):\n",
    "                    in_group = client['a'] == p\n",
    "                    if len(model_output[in_group]) < 1:\n",
    "                        running_theta.append(torch.tensor(0))\n",
    "                        continue\n",
    "                    the = nn.CrossEntropyLoss(reduction='sum')(model_output[in_group], client['y'][in_group]) / ins\n",
    "                    running_theta.append(the)\n",
    "                running_theta = torch.tensor(running_theta)\n",
    "                grad_theta = grad_theta + running_theta\n",
    "\n",
    "        # Running average of the global parameters over rounds.\n",
    "        iterates += 1\n",
    "        average_iterate += torch_to_numpy(model.parameters())\n",
    "        numpy_to_torch(average_iterate/iterates, avg_model)\n",
    "\n",
    "        # ---- final-round evaluation on the test splits ----------------------\n",
    "        if i == epochs-1:\n",
    "            print('Printing test accuracy')\n",
    "            correct = 0\n",
    "            total = 0\n",
    "            PP_by_group = [0]*n_groups\n",
    "            total_by_group = [0]*n_groups\n",
    "            # The returned list used to have exactly 9 slots; keep at least that many\n",
    "            # so callers indexing it still work. Slots past n_groups stay 0.\n",
    "            running_theta_test = [0]*max(9, n_groups)\n",
    "            with torch.no_grad():\n",
    "                for client in client_data.values():\n",
    "                    model_output = model(client['x_test'].float())\n",
    "                    model_prediction = model_output.argmax(1)\n",
    "\n",
    "                    for idx, p in enumerate(protected_attributes):\n",
    "                        in_group = client['a_test'] == p\n",
    "                        if len(model_output[in_group]) < 1:\n",
    "                            running_theta_test[idx] += torch.tensor(0)\n",
    "                            continue\n",
    "                        the = nn.CrossEntropyLoss(reduction='sum')(model_output[in_group], client['y_test'][in_group])\n",
    "                        running_theta_test[idx] += the\n",
    "                        PP_by_group[idx] += model_prediction[in_group].sum()\n",
    "                        total_by_group[idx] += len(model_prediction[in_group])\n",
    "                    correct += model_prediction.eq(client['y_test']).sum()\n",
    "                    total += len(client['y_test'])\n",
    "            print(correct.item()*1.0/total)\n",
    "            test_accs.append(correct.item()*1.0/total)\n",
    "\n",
    "            # Positive-prediction rate and mean test loss per group. Only groups that\n",
    "            # occur in the test data are included, so padding zeros cannot distort the\n",
    "            # gap/std and an absent group cannot divide by zero.\n",
    "            SP_by_group = []\n",
    "            for idx in range(n_groups):\n",
    "                if total_by_group[idx] == 0:\n",
    "                    continue\n",
    "                SP_by_group.append((PP_by_group[idx] / total_by_group[idx]).item())\n",
    "                # Summed test loss is averaged over the group's test count\n",
    "                # (previously the training count was used here).\n",
    "                running_theta_test[idx] /= total_by_group[idx]\n",
    "            print(running_theta_test)\n",
    "            max_sps.append(running_theta_test)\n",
    "            if SP_by_group:\n",
    "                print('Max SP gap: ', max(SP_by_group)-min(SP_by_group))\n",
    "                print('SP std: ', np.array(SP_by_group).std())\n",
    "\n",
    "        # ---- group-weight update --------------------------------------------\n",
    "        # NOTE(review): every term of grad_theta was already divided by the group's\n",
    "        # training count when it was computed, so this divides by the count a second\n",
    "        # time. Left unchanged because removing it would rescale the effective\n",
    "        # lr_theta for existing callers - confirm which scaling is intended.\n",
    "        for idx in range(n_groups):\n",
    "            grad_theta[idx] = grad_theta[idx] / num_instances[idx]\n",
    "        theta += lr_theta * grad_theta\n",
    "        # theta is not re-normalised to sum to 1 after the update.\n",
    "\n",
    "    return avg_model, test_accs, max_sps"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
