{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aab52dc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import itertools\n",
    "\n",
    "from copy import deepcopy\n",
    "from ucimlrepo import fetch_ucirepo \n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn import svm\n",
    "import sklearn.preprocessing as preprocessing\n",
    "from scipy.stats import beta\n",
    "from sklearn.utils import shuffle\n",
    "from sklearn import metrics\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a0139e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names for the UCI Adult files, which have no header row.\n",
    "# NOTE: \"Martial Status\" (sic) is the column name used throughout this notebook.\n",
    "features = [\"Age\", \"Workclass\", \"fnlwgt\", \"Education\", \"Education-Num\", \"Martial Status\",\n",
    "            \"Occupation\", \"Relationship\", \"Race\", \"Sex\", \"Capital Gain\", \"Capital Loss\",\n",
    "            \"Hours per week\", \"Country\", \"Target\"]\n",
    "\n",
    "train_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data'\n",
    "test_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test'\n",
    "\n",
    "# Load train and test data. '?' marks missing values; skiprows=1 skips the\n",
    "# first line of adult.test, which is not a data row.\n",
    "original_train = pd.read_csv(train_url, names=features, sep=r'\\s*,\\s*',\n",
    "                             engine='python', na_values=\"?\")\n",
    "original_test = pd.read_csv(test_url, names=features, sep=r'\\s*,\\s*',\n",
    "                            engine='python', na_values=\"?\", skiprows=1)\n",
    "\n",
    "# Both files number their rows from 0; ignore_index keeps the combined index unique.\n",
    "original = pd.concat([original_train, original_test], ignore_index=True)\n",
    "\n",
    "# Group categories within each column. Values not listed in a mapping are kept as-is.\n",
    "WORKCLASS_GROUPS = {\n",
    "    'Federal-gov': 'Government', 'Local-gov': 'Government', 'State-gov': 'Government',\n",
    "    'Self-emp-inc': 'Self-Employed', 'Self-emp-not-inc': 'Self-Employed',\n",
    "    'Never-worked': 'Other/Unknown', 'Without-pay': 'Other/Unknown',\n",
    "}\n",
    "MARITAL_GROUPS = {\n",
    "    'Married-AF-spouse': 'Married', 'Married-civ-spouse': 'Married',\n",
    "    'Married-spouse-absent': 'Married', 'Never-married': 'Single',\n",
    "}\n",
    "OCCUPATION_GROUPS = {\n",
    "    'Adm-clerical': 'White-Collar', 'Exec-managerial': 'White-Collar',\n",
    "    'Craft-repair': 'Blue-Collar', 'Farming-fishing': 'Blue-Collar',\n",
    "    'Handlers-cleaners': 'Blue-Collar', 'Machine-op-inspct': 'Blue-Collar',\n",
    "    'Transport-moving': 'Blue-Collar',\n",
    "    'Other-service': 'Service', 'Priv-house-serv': 'Service',\n",
    "    'Protective-serv': 'Service', 'Tech-support': 'Service',\n",
    "    'Prof-specialty': 'Professional',\n",
    "    'Unknown': 'Other/Unknown', 'Armed-Forces': 'Other/Unknown',\n",
    "}\n",
    "EDUCATION_GROUPS = {\n",
    "    '10th': 'High-School', '11th': 'High-School', '12th': 'High-School', 'HS-grad': 'High-School',\n",
    "    '1st-4th': 'Elem-School', '5th-6th': 'Elem-School',\n",
    "    '7th-8th': 'Mid-School', '9th': 'Mid-School',\n",
    "    'Assoc-acdm': 'Assoc', 'Assoc-voc': 'Assoc',\n",
    "    'Some-college': 'Bachelors', 'Prof-school': 'Masters',\n",
    "}\n",
    "# Each country belongs to exactly one region (Honduras is Central-America).\n",
    "# NOTE(review): Cuba and Dominican-Republic are grouped as South-America even though a\n",
    "# Caribbean group exists, and El-Salvador as South-America rather than Central-America -- confirm intent.\n",
    "COUNTRY_REGIONS = {\n",
    "    'North-America': ['Canada', 'United-States', 'Mexico', 'Outlying-US(Guam-USVI-etc)', 'Puerto-Rico'],\n",
    "    'Central-America': ['Guatemala', 'Honduras', 'Nicaragua'],\n",
    "    'South-America': ['Columbia', 'Cuba', 'Dominican-Republic', 'Ecuador', 'El-Salvador', 'Peru'],\n",
    "    'Europe': ['England', 'France', 'Germany', 'Greece', 'Ireland', 'Italy', 'Portugal',\n",
    "               'Poland', 'Scotland', 'Hungary', 'Yugoslavia', 'Holand-Netherlands'],\n",
    "    'Asia': ['China', 'India', 'Iran', 'Hong', 'Hong Kong', 'Japan', 'Laos', 'Philippines',\n",
    "             'Cambodia', 'Taiwan', 'Thailand', 'South', 'Vietnam'],\n",
    "    'Caribbean': ['Haiti', 'Jamaica', 'Trinadad&Tobago'],\n",
    "}\n",
    "COUNTRY_GROUPS = {name: region for region, names in COUNTRY_REGIONS.items() for name in names}\n",
    "\n",
    "# Binary encodings. Series.map gives a numeric dtype directly (chained .replace relies on\n",
    "# deprecated object->int downcasting) and turns unexpected values into NaN, caught below.\n",
    "TARGET_CODES = {'<=50K': 0, '>50K': 1}  # adult.test labels end with '.', stripped before mapping\n",
    "SEX_CODES = {'Male': 1, 'Female': 0}\n",
    "RACE_CODES = {'White': 1, 'Amer-Indian-Eskimo': 0, 'Asian-Pac-Islander': 0, 'Black': 0, 'Other': 0}\n",
    "\n",
    "work = original['Workclass'].replace(WORKCLASS_GROUPS).fillna('Other/Unknown')\n",
    "marital = original['Martial Status'].replace(MARITAL_GROUPS)\n",
    "occupation = original['Occupation'].replace(OCCUPATION_GROUPS).fillna('Other/Unknown')\n",
    "education = original['Education'].replace(EDUCATION_GROUPS)\n",
    "country = original['Country'].replace(COUNTRY_GROUPS).fillna('Others')\n",
    "labels = original['Target'].str.rstrip('.').map(TARGET_CODES)\n",
    "sex = original['Sex'].map(SEX_CODES)\n",
    "race = original['Race'].map(RACE_CODES)\n",
    "\n",
    "for column_name, encoded in [('Target', labels), ('Sex', sex), ('Race', race)]:\n",
    "    assert encoded.notna().all(), f'unmapped values in {column_name}'\n",
    "\n",
    "# Reassign columns (Education included, so its grouping actually takes effect).\n",
    "original['Target'] = labels\n",
    "original['Workclass'] = work\n",
    "original['Martial Status'] = marital\n",
    "original['Occupation'] = occupation\n",
    "original['Education'] = education\n",
    "original['Sex'] = sex\n",
    "original['Race'] = race\n",
    "original['Country'] = country\n",
    "original = original.drop('fnlwgt', axis=1)\n",
    "\n",
    "# Dataframe\n",
    "original"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63920b91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Planned non-IID client sizes (Race: 0 = non-White, 1 = White), drawn as a table.\n",
    "# Cell text is kept as strings because the table renders it verbatim.\n",
    "client_counts = [\n",
    "    [str(client_id), str(non_white), str(white)]\n",
    "    for client_id, (non_white, white) in enumerate(\n",
    "        [(300, 1000), (300, 1000), (300, 39362), (600, 200), (5580, 200)]\n",
    "    )\n",
    "]\n",
    "table_header = ['Client ID', 'Non-White samples', 'White samples']\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 3))\n",
    "\n",
    "# Show only the table: hide both axes and the frame.\n",
    "for axis in (ax.xaxis, ax.yaxis):\n",
    "    axis.set_visible(False)\n",
    "ax.set_frame_on(False)\n",
    "\n",
    "table = ax.table(cellText=client_counts, colLabels=table_header, cellLoc='center', loc='center')\n",
    "table.scale(1, 2)\n",
    "table.auto_set_font_size(False)\n",
    "table.set_fontsize(12)\n",
    "\n",
    "# Uniform serif, normal-weight, black text in every cell, header row included.\n",
    "for cell in table.get_celld().values():\n",
    "    cell.set_fontsize(12)\n",
    "    cell.set_text_props(fontfamily='serif', fontweight='normal', color='black')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caa74986",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-IID partition by race (client 5 receives whatever remains):\n",
    "#   client   race=0   race=1\n",
    "#   1        300      1000\n",
    "#   2        300      1000\n",
    "#   3        300      39362\n",
    "#   4        600      200\n",
    "#   5        5580     200\n",
    "#   total    7080     41762\n",
    "\n",
    "def _carve(pool, sizes, seed=42):\n",
    "    \"\"\"Cut `pool` into len(sizes) + 1 disjoint pieces.\n",
    "\n",
    "    The first len(sizes) pieces have exactly the requested row counts and are\n",
    "    taken in order by repeated train_test_split calls with the same seed; the\n",
    "    final piece is the leftover rows.\n",
    "    \"\"\"\n",
    "    pieces = []\n",
    "    for size in sizes:\n",
    "        piece, pool = train_test_split(pool, train_size=size, random_state=seed)\n",
    "        pieces.append(piece)\n",
    "    pieces.append(pool)\n",
    "    return pieces\n",
    "\n",
    "\n",
    "# Separate the two race groups, then carve each into five client shares.\n",
    "race_1 = original[original['Race'] == 1]  # race == 1 samples\n",
    "race_0 = original[original['Race'] == 0]  # race == 0 samples\n",
    "\n",
    "race_1_parts = _carve(race_1, [1000, 1000, 39362, 200])\n",
    "race_0_parts = _carve(race_0, [300, 300, 300, 600])\n",
    "race_1_client1, race_1_client2, race_1_client3, race_1_client4, race_1_client5 = race_1_parts\n",
    "race_0_client1, race_0_client2, race_0_client3, race_0_client4, race_0_client5 = race_0_parts\n",
    "\n",
    "# Each client is its race == 1 share followed by its race == 0 share.\n",
    "client1, client2, client3, client4, client5 = (\n",
    "    pd.concat([white_rows, other_rows]) for white_rows, other_rows in zip(race_1_parts, race_0_parts)\n",
    ")\n",
    "\n",
    "# Report the realized per-client counts.\n",
    "for number, (white_rows, other_rows) in enumerate(zip(race_1_parts, race_0_parts), start=1):\n",
    "    print(f\"Client {number}: Race==1: {len(white_rows)}, Race==0: {len(other_rows)}\")\n",
    "\n",
    "clients = {\n",
    "    'client1': client1,\n",
    "    'client2': client2,\n",
    "    'client3': client3,\n",
    "    'client4': client4,\n",
    "    'client5': client5,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e45fb37a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# IID split: shuffle all rows once, then deal them into N_CLIENTS disjoint,\n",
    "# near-equal parts (part sizes differ by at most one row).\n",
    "N_CLIENTS = 5\n",
    "shuffled_data = original.sample(frac=1, random_state=42).reset_index(drop=True)\n",
    "\n",
    "# Split row positions rather than the DataFrame itself: np.array_split on a DataFrame\n",
    "# goes through DataFrame.swapaxes, which is deprecated since pandas 2.1. The part\n",
    "# sizes are the same as np.array_split(shuffled_data, N_CLIENTS).\n",
    "part_positions = np.array_split(np.arange(len(shuffled_data)), N_CLIENTS)\n",
    "\n",
    "# Assign each part to a different client: client1 ... client{N_CLIENTS}.\n",
    "clients = {\n",
    "    f'client{i + 1}': shuffled_data.iloc[positions]\n",
    "    for i, positions in enumerate(part_positions)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c5c0f1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to split one client's data, convert to tensors, and create DataLoaders.\n",
    "def create_dataloaders(i, client_data, batch_size=32, test_size=0.2, feature_columns=None):\n",
    "    \"\"\"Build (train_loader, test_loader) for a single client.\n",
    "\n",
    "    Every batch is a tuple (X, S, Y) of float32 tensors: one-hot encoded\n",
    "    features, the sensitive attribute (Race) shaped (n, 1), and the income\n",
    "    label (Target) shaped (n, 1).\n",
    "\n",
    "    Args:\n",
    "        i: Client index. Not used inside the function; kept for call compatibility.\n",
    "        client_data: DataFrame for this client; must contain 'Target' and 'Race'.\n",
    "        batch_size: Mini-batch size for both loaders.\n",
    "        test_size: Held-out fraction (or count) passed to train_test_split.\n",
    "        feature_columns: Optional full list of one-hot columns shared by all clients.\n",
    "            When given, this client's dummy columns are reindexed to it (categories\n",
    "            absent from this client become all-zero columns), so every client has the\n",
    "            same input width. When None, only this client's own categories are used.\n",
    "\n",
    "    Returns:\n",
    "        (train_loader, test_loader): DataLoaders over TensorDatasets.\n",
    "    \"\"\"\n",
    "    # Sensitive attribute is Race; to study Sex instead, swap 'Race' for 'Sex' in the two lines below.\n",
    "    X = client_data.drop(columns=['Target', 'Race'])\n",
    "    S = client_data['Race'].to_numpy(dtype=np.float32)\n",
    "    Y = client_data['Target'].to_numpy(dtype=np.float32)\n",
    "\n",
    "    X = pd.get_dummies(X)\n",
    "    if feature_columns is not None:\n",
    "        X = X.reindex(columns=feature_columns, fill_value=0)\n",
    "    # pandas >= 2.0 emits bool dummy columns; mixed bool/int frames become an object\n",
    "    # array via .values, which torch.tensor rejects. Cast everything to float32 first.\n",
    "    X = X.astype(np.float32)\n",
    "\n",
    "    X_train, X_test, S_train, S_test, Y_train, Y_test = train_test_split(\n",
    "        X, S, Y, test_size=test_size, random_state=42)\n",
    "\n",
    "    print(X_train.shape)\n",
    "    # Convert to torch tensors; S and Y become column vectors.\n",
    "    X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)\n",
    "    X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)\n",
    "    S_train_tensor = torch.tensor(S_train, dtype=torch.float32).view(-1, 1)\n",
    "    S_test_tensor = torch.tensor(S_test, dtype=torch.float32).view(-1, 1)\n",
    "    Y_train_tensor = torch.tensor(Y_train, dtype=torch.float32).view(-1, 1)\n",
    "    Y_test_tensor = torch.tensor(Y_test, dtype=torch.float32).view(-1, 1)\n",
    "\n",
    "    train_dataset = TensorDataset(X_train_tensor, S_train_tensor, Y_train_tensor)\n",
    "    test_dataset = TensorDataset(X_test_tensor, S_test_tensor, Y_test_tensor)\n",
    "\n",
    "    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n",
    "    # Evaluation does not need shuffling; a fixed order keeps test passes reproducible.\n",
    "    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    return train_loader, test_loader\n",
    "\n",
    "\n",
    "# One-hot columns computed over all clients at once, so every client's feature matrix\n",
    "# has identical columns in identical order (a model shared across clients needs this).\n",
    "all_feature_columns = pd.get_dummies(\n",
    "    pd.concat(list(clients.values())).drop(columns=['Target', 'Race'])\n",
    ").columns\n",
    "\n",
    "iid_train = []\n",
    "iid_test = []\n",
    "for i in range(1, len(clients) + 1):\n",
    "    client_data = clients[f'client{i}']\n",
    "    train_loader, test_loader = create_dataloaders(\n",
    "        i, client_data, batch_size=32, test_size=0.2, feature_columns=all_feature_columns)\n",
    "    iid_train.append(train_loader)\n",
    "    iid_test.append(test_loader)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ca5e809",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
