{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4cf135",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from copy import deepcopy\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from ucimlrepo import fetch_ucirepo \n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e6e75c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defining True Parameters\n",
    "# mu_<y>_group<s>_cluster<c>: mean of the single feature for samples with\n",
    "# label y and sensitive attribute s on client <c>. Clients '1', '11', '111'\n",
    "# and '1111' are drawn with standard deviation sd_1; clients '0', '00',\n",
    "# '000' and '0000' with sd_0.\n",
    "mu_1_group1_cluster1 = 12\n",
    "mu_0_group1_cluster1 = 8\n",
    "mu_1_group0_cluster1 = 11\n",
    "mu_0_group0_cluster1 = 7\n",
    "\n",
    "mu_1_group1_cluster11 = 12\n",
    "mu_0_group1_cluster11 = 9\n",
    "mu_1_group0_cluster11 = 12\n",
    "mu_0_group0_cluster11 = 9\n",
    "\n",
    "mu_1_group1_cluster111 = 11.5\n",
    "mu_0_group1_cluster111 = 8.5\n",
    "mu_1_group0_cluster111 = 11.5\n",
    "mu_0_group0_cluster111 = 8.5\n",
    "\n",
    "mu_1_group1_cluster1111 = 10.5\n",
    "mu_0_group1_cluster1111 = 7.5\n",
    "mu_1_group0_cluster1111 = 10.5\n",
    "mu_0_group0_cluster1111 = 7.5\n",
    "\n",
    "mu_1_group1_cluster0 = 8\n",
    "mu_0_group1_cluster0 = 6\n",
    "mu_1_group0_cluster0 = 8\n",
    "mu_0_group0_cluster0 = 6\n",
    "\n",
    "mu_1_group1_cluster00 = 7.5\n",
    "mu_0_group1_cluster00 = 5.5\n",
    "mu_1_group0_cluster00 = 7.5\n",
    "mu_0_group0_cluster00 = 5.5\n",
    "\n",
    "mu_1_group1_cluster000 = 12\n",
    "mu_0_group1_cluster000 = 8\n",
    "mu_1_group0_cluster000 = 11\n",
    "mu_0_group0_cluster000 = 7\n",
    "\n",
    "mu_1_group1_cluster0000 = 11\n",
    "mu_0_group1_cluster0000 = 8\n",
    "mu_1_group0_cluster0000 = 11\n",
    "mu_0_group0_cluster0000 = 8\n",
    "\n",
    "sd_1 = 1\n",
    "sd_0 = 1\n",
    "\n",
    "pop_prob = 0.1\n",
    "size = 1200  # samples drawn per (label, group) array of each client\n",
    "batch_size = 32  \n",
    "shuffle = True  \n",
    "\n",
    "iid_train = []\n",
    "iid_test = []\n",
    "\n",
    "# Seed numpy and torch so that Restart & Run All reproduces the same samples,\n",
    "# DataLoader shuffling and (when the cells run in order) model initialisation.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Generate initial training data \n",
    "data_1_group1_cluster1 = np.random.normal(mu_1_group1_cluster1,sd_1,size=size)\n",
    "data_0_group1_cluster1 = np.random.normal(mu_0_group1_cluster1,sd_1,size=size)\n",
    "data_1_group0_cluster1 = np.random.normal(mu_1_group0_cluster1,sd_1,size=size)\n",
    "data_0_group0_cluster1 = np.random.normal(mu_0_group0_cluster1,sd_1,size=size)\n",
    "\n",
    "data_1_group1_cluster11 = np.random.normal(mu_1_group1_cluster11,sd_1,size=size)\n",
    "data_0_group1_cluster11 = np.random.normal(mu_0_group1_cluster11,sd_1,size=size)\n",
    "data_1_group0_cluster11 = np.random.normal(mu_1_group0_cluster11,sd_1,size=size)\n",
    "data_0_group0_cluster11 = np.random.normal(mu_0_group0_cluster11,sd_1,size=size)\n",
    "\n",
    "data_1_group1_cluster111 = np.random.normal(mu_1_group1_cluster111,sd_1,size=size)\n",
    "data_0_group1_cluster111 = np.random.normal(mu_0_group1_cluster111,sd_1,size=size)\n",
    "data_1_group0_cluster111 = np.random.normal(mu_1_group0_cluster111,sd_1,size=size)\n",
    "data_0_group0_cluster111 = np.random.normal(mu_0_group0_cluster111,sd_1,size=size)\n",
    "\n",
    "data_1_group1_cluster1111 = np.random.normal(mu_1_group1_cluster1111,sd_1,size=size)\n",
    "data_0_group1_cluster1111 = np.random.normal(mu_0_group1_cluster1111,sd_1,size=size)\n",
    "data_1_group0_cluster1111 = np.random.normal(mu_1_group0_cluster1111,sd_1,size=size)\n",
    "data_0_group0_cluster1111 = np.random.normal(mu_0_group0_cluster1111,sd_1,size=size)\n",
    "\n",
    "data_1_group1_cluster0 = np.random.normal(mu_1_group1_cluster0,sd_0,size=size)\n",
    "data_0_group1_cluster0 = np.random.normal(mu_0_group1_cluster0,sd_0,size=size)\n",
    "data_1_group0_cluster0 = np.random.normal(mu_1_group0_cluster0,sd_0,size=size)\n",
    "data_0_group0_cluster0 = np.random.normal(mu_0_group0_cluster0,sd_0,size=size)\n",
    "\n",
    "data_1_group1_cluster00 = np.random.normal(mu_1_group1_cluster00,sd_0,size=size)\n",
    "data_0_group1_cluster00 = np.random.normal(mu_0_group1_cluster00,sd_0,size=size)\n",
    "data_1_group0_cluster00 = np.random.normal(mu_1_group0_cluster00,sd_0,size=size)\n",
    "data_0_group0_cluster00 = np.random.normal(mu_0_group0_cluster00,sd_0,size=size)\n",
    "\n",
    "data_1_group1_cluster000 = np.random.normal(mu_1_group1_cluster000,sd_0,size=size)\n",
    "data_0_group1_cluster000 = np.random.normal(mu_0_group1_cluster000,sd_0,size=size)\n",
    "data_1_group0_cluster000 = np.random.normal(mu_1_group0_cluster000,sd_0,size=size)\n",
    "data_0_group0_cluster000 = np.random.normal(mu_0_group0_cluster000,sd_0,size=size)\n",
    "\n",
    "data_1_group1_cluster0000 = np.random.normal(mu_1_group1_cluster0000,sd_0,size=size)\n",
    "data_0_group1_cluster0000 = np.random.normal(mu_0_group1_cluster0000,sd_0,size=size)\n",
    "data_1_group0_cluster0000 = np.random.normal(mu_1_group0_cluster0000,sd_0,size=size)\n",
    "data_0_group0_cluster0000 = np.random.normal(mu_0_group0_cluster0000,sd_0,size=size)\n",
    "\n",
    "def _build_group(data_pos, data_neg, group):\n",
    "    \"\"\"Stack one sensitive group's samples into float32 tensors (X, s, y).\n",
    "\n",
    "    `data_pos` holds the label-1 samples and `data_neg` the label-0 samples.\n",
    "    They are concatenated in that order, so y is 1 for the first\n",
    "    len(data_pos.T) rows and 0 for the rest. s is 1 on every row when\n",
    "    `group == 1` and 0 otherwise. All three tensors are shaped (n, 1).\n",
    "    \"\"\"\n",
    "    pos = data_pos.T  # Label 1\n",
    "    neg = data_neg.T  # Label 0\n",
    "    features = np.hstack((pos, neg))\n",
    "    sensitive = np.full((len(features), 1), 1.0 if group == 1 else 0.0)\n",
    "    labels = np.vstack((np.ones((len(pos), 1)), np.zeros((len(neg), 1))))\n",
    "\n",
    "    X = torch.tensor(features, dtype=torch.float32).view(-1, 1)\n",
    "    s = torch.tensor(sensitive, dtype=torch.float32).view(-1, 1)\n",
    "    y = torch.tensor(labels, dtype=torch.float32)\n",
    "    return X, s, y\n",
    "\n",
    "\n",
    "def generate_data (data1, data2, group1, data3, data4, group2, test_size=400, random_state=42):\n",
    "    \"\"\"Build one client's train/test tensors from two sensitive groups.\n",
    "\n",
    "    (data1, data2, group1) and (data3, data4, group2) each describe one group:\n",
    "    label-1 samples, label-0 samples and the group's sensitive value. Each\n",
    "    group is split on its own with sklearn's train_test_split (`test_size`\n",
    "    rows of that group go to the test set; `random_state` fixes the split).\n",
    "    The matching pieces of both groups are then concatenated, first group first.\n",
    "\n",
    "    Returns (X_train, X_test, s_train, s_test, y_train, y_test).\n",
    "    \"\"\"\n",
    "    groups = [_build_group(data1, data2, group1), _build_group(data3, data4, group2)]\n",
    "    # Each split is (X_train, X_test, s_train, s_test, y_train, y_test).\n",
    "    splits = [train_test_split(X, s, y, test_size=test_size, random_state=random_state)\n",
    "              for X, s, y in groups]\n",
    "    # Concatenate the pieces at the same position of both splits along the sample axis.\n",
    "    return tuple(torch.cat([split[pos] for split in splits], dim=0) for pos in range(6))\n",
    "\n",
    "# Build the train/test tensors of the eight clients. Each generate_data call\n",
    "# produces ONE client whose data mixes sensitive group s=1 (first pair of\n",
    "# arrays) and s=0 (second pair). The resulting variable names are looked up\n",
    "# by string through globals() further down in this cell, so keep them as is.\n",
    "# Clients 1, 11, 111 and 1111:\n",
    "(X1_train_cluster1, X1_test_cluster1, \n",
    " s1_train_cluster1, s1_test_cluster1, \n",
    " y1_train_cluster1, y1_test_cluster1) = generate_data(data_1_group1_cluster1, data_0_group1_cluster1, 1,\n",
    "                                                      data_1_group0_cluster1, data_0_group0_cluster1, 0)\n",
    "\n",
    "\n",
    "\n",
    "(X1_train_cluster11, X1_test_cluster11, \n",
    " s1_train_cluster11, s1_test_cluster11, \n",
    " y1_train_cluster11, y1_test_cluster11) = generate_data(data_1_group1_cluster11, data_0_group1_cluster11, 1,\n",
    "                                                        data_1_group0_cluster11, data_0_group0_cluster11, 0)\n",
    "\n",
    "(X1_train_cluster111, X1_test_cluster111, \n",
    " s1_train_cluster111, s1_test_cluster111, \n",
    " y1_train_cluster111, y1_test_cluster111) = generate_data(data_1_group1_cluster111, data_0_group1_cluster111, 1,\n",
    "                                                        data_1_group0_cluster111, data_0_group0_cluster111, 0)\n",
    "\n",
    "(X1_train_cluster1111, X1_test_cluster1111, \n",
    " s1_train_cluster1111, s1_test_cluster1111, \n",
    " y1_train_cluster1111, y1_test_cluster1111) = generate_data(data_1_group1_cluster1111, data_0_group1_cluster1111, 1,\n",
    "                                                        data_1_group0_cluster1111, data_0_group0_cluster1111, 0)\n",
    "\n",
    "# Clients 0, 00, 000 and 0000:\n",
    "(X1_train_cluster0, X1_test_cluster0, \n",
    " s1_train_cluster0, s1_test_cluster0, \n",
    " y1_train_cluster0, y1_test_cluster0) = generate_data(data_1_group1_cluster0, data_0_group1_cluster0, 1,\n",
    "                                                      data_1_group0_cluster0, data_0_group0_cluster0, 0)\n",
    "\n",
    "\n",
    "\n",
    "(X1_train_cluster00, X1_test_cluster00, \n",
    " s1_train_cluster00, s1_test_cluster00, \n",
    " y1_train_cluster00, y1_test_cluster00) = generate_data(data_1_group1_cluster00, data_0_group1_cluster00, 1,\n",
    "                                                        data_1_group0_cluster00, data_0_group0_cluster00, 0)\n",
    "\n",
    "(X1_train_cluster000, X1_test_cluster000, \n",
    " s1_train_cluster000, s1_test_cluster000, \n",
    " y1_train_cluster000, y1_test_cluster000) = generate_data(data_1_group1_cluster000, data_0_group1_cluster000, 1,\n",
    "                                                        data_1_group0_cluster000, data_0_group0_cluster000, 0)\n",
    "\n",
    "(X1_train_cluster0000, X1_test_cluster0000, \n",
    " s1_train_cluster0000, s1_test_cluster0000, \n",
    " y1_train_cluster0000, y1_test_cluster0000) = generate_data(data_1_group1_cluster0000, data_0_group1_cluster0000, 1,\n",
    "                                                        data_1_group0_cluster0000, data_0_group0_cluster0000, 0)\n",
    "\n",
    "# Create a dataset and a dataloader\n",
    "def data_loader (X,s,y,batch_size, shuffle=True):\n",
    "    \"\"\"Wrap the (X, s, y) tensors in a TensorDataset and return a DataLoader.\n",
    "\n",
    "    Each batch is a (features, sensitive, labels) triple of at most\n",
    "    `batch_size` rows. `shuffle` defaults to True, as before; pass False for\n",
    "    a loader whose iteration order must stay fixed.\n",
    "    \"\"\"\n",
    "    dataset = TensorDataset(X, s, y)\n",
    "    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "\n",
    "# Build one train and one test DataLoader per client and append them to\n",
    "# iid_train / iid_test in the order 0, 1, 00, 11, 000, 111, 0000, 1111\n",
    "# (client id = digit j repeated `depth` times). The client tensors are\n",
    "# module-level variables created above, so they are fetched by name.\n",
    "for depth in range(1, 5):\n",
    "    for j in range(2):\n",
    "        suffix = str(j) * depth\n",
    "\n",
    "        X_train = globals()[f'X1_train_cluster{suffix}']\n",
    "        s_train = globals()[f's1_train_cluster{suffix}']\n",
    "        y_train = globals()[f'y1_train_cluster{suffix}']\n",
    "        result_train = data_loader(X_train, s_train, y_train, batch_size)\n",
    "\n",
    "        X_test = globals()[f'X1_test_cluster{suffix}']\n",
    "        s_test = globals()[f's1_test_cluster{suffix}']\n",
    "        y_test = globals()[f'y1_test_cluster{suffix}']\n",
    "        result_test = data_loader(X_test, s_test, y_test, batch_size)\n",
    "\n",
    "        iid_train.append(result_train)\n",
    "        iid_test.append(result_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df70d0cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NN model\n",
    "\n",
    "class LogisticRegressionNN(nn.Module):\n",
    "    \"\"\"Logistic regression as a module: sigmoid(fc(x)) with a single output unit.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size):\n",
    "        super().__init__()\n",
    "        # One affine map from `input_size` features to a single logit.\n",
    "        self.fc = nn.Linear(input_size, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map the features in the last dimension of `x` to the predicted probability of label 1.\"\"\"\n",
    "        logit = self.fc(x)\n",
    "        return torch.sigmoid(logit)\n",
    "\n",
    "# Initialize the model: the synthetic data has a single scalar feature.\n",
    "input_size = 1  \n",
    "model_0 = LogisticRegressionNN(input_size)  \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ac7f358",
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculating the loss under the classifier\n",
    "\n",
    "def loss_classifier(predictions, labels):\n",
    "    \"\"\"Mean binary cross-entropy between predicted probabilities and 0/1 labels.\"\"\"\n",
    "    criterion = nn.BCELoss()\n",
    "    return criterion(predictions, labels)\n",
    "\n",
    "\n",
    "# calculating the loss for the dataset\n",
    "def loss_dataset(model, dataset, loss_f):\n",
    "    \"\"\"Compute the loss of `model` on `dataset`.\n",
    "\n",
    "    `dataset` yields (features, sensitive, labels) batches. `loss_f` is\n",
    "    assumed to return the mean loss of a batch (as loss_classifier does), so\n",
    "    each batch is weighted by its size. The result is therefore the per-sample\n",
    "    mean over the whole dataset, even when the last batch is smaller.\n",
    "    Evaluation runs under torch.no_grad(), so no autograd graph is kept.\n",
    "    Returns a tensor for tensor-valued losses; callers may call .detach() on it.\n",
    "\n",
    "    Raises ValueError if `dataset` yields no batches.\n",
    "    \"\"\"\n",
    "    total_loss = 0\n",
    "    n_seen = 0\n",
    "    with torch.no_grad():\n",
    "        for features, _, labels in dataset:\n",
    "            n_batch = len(labels)\n",
    "            total_loss += loss_f(model(features), labels) * n_batch\n",
    "            n_seen += n_batch\n",
    "    if n_seen == 0:\n",
    "        raise ValueError(\"loss_dataset: `dataset` yielded no batches\")\n",
    "    return total_loss / n_seen\n",
    "\n",
    "# calculating the accuracy for the dataset\n",
    "def accuracy_dataset(model, dataset):\n",
    "    \"\"\"Compute the accuracy of `model` on `dataset`, in percent.\n",
    "\n",
    "    Predictions are thresholded with round() (probabilities above 0.5 become 1)\n",
    "    and compared with the labels. `dataset` must be a DataLoader: its\n",
    "    `.dataset` provides the sample count. Evaluation runs under torch.no_grad()\n",
    "    because nothing is back-propagated.\n",
    "\n",
    "    Raises ValueError if the dataset is empty.\n",
    "    \"\"\"\n",
    "    n_samples = len(dataset.dataset)\n",
    "    if n_samples == 0:\n",
    "        raise ValueError(\"accuracy_dataset: `dataset` is empty\")\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for features, _, labels in dataset:\n",
    "            predicted = model(features).round()\n",
    "            correct += torch.sum(predicted.view(-1, 1) == labels.view(-1, 1)).item()\n",
    "    return 100 * correct / n_samples\n",
    "\n",
    "# calculating the fairness for the dataset\n",
    "def fairness_dataset(model, dataset):\n",
    "    \"\"\"Compute the demographic-parity gap of `model` on `dataset`.\n",
    "\n",
    "    The positive-prediction rate (mean of the rounded predictions) is computed\n",
    "    for every sensitive value that makes up more than 1% of the samples. The gap\n",
    "    is the largest rate minus the smallest (0.0 when only one group qualifies).\n",
    "    Evaluation runs under torch.no_grad() because nothing is back-propagated.\n",
    "    \"\"\"\n",
    "    pred_chunks = []\n",
    "    s_chunks = []\n",
    "    with torch.no_grad():\n",
    "        for features, s, _ in dataset:\n",
    "            pred_chunks.append(model(features).round().detach().numpy())\n",
    "            s_chunks.append(s.numpy())\n",
    "\n",
    "    preds = np.concatenate(pred_chunks).ravel()\n",
    "    sens = np.concatenate(s_chunks).ravel()\n",
    "    rates = [np.mean(preds[sens == value])\n",
    "             for value in np.unique(sens)\n",
    "             if np.mean(sens == value) > 0.01]  # ignore groups with <= 1% of the samples\n",
    "    return max(rates) - min(rates)\n",
    "\n",
    "\n",
    "\n",
    "# train the model for one epoch; return the average batch loss and the DP gap\n",
    "def train_step(model, model_0, mu:int, optimizer, train_data, loss_f):\n",
    "    \"\"\"Train `model` on one epoch of `train_data`.\n",
    "\n",
    "    Each batch minimises `loss_f(predictions, labels)`. When `mu` is non-zero,\n",
    "    the FedProx proximal term mu/2 * ||w - w_0||^2 is added, where w_0 are the\n",
    "    parameters of `model_0` (local_learning passes a copy of the model taken\n",
    "    before local training). With mu == 0 the objective is exactly `loss_f`.\n",
    "    `optimizer` is expected to own `model`'s parameters.\n",
    "\n",
    "    Returns (mean batch loss of the epoch as a detached tensor, DP gap of the\n",
    "    rounded predictions made during the epoch across sensitive groups holding\n",
    "    more than 1% of the samples).\n",
    "    Raises ValueError if `train_data` yields no batches.\n",
    "    \"\"\"\n",
    "    total_loss = 0\n",
    "    n_batches = 0\n",
    "    pred_chunks = []\n",
    "    s_chunks = []\n",
    "    # Reference weights for the proximal term; not needed when mu == 0.\n",
    "    ref_params = [p.detach() for p in model_0.parameters()] if mu else None\n",
    "\n",
    "    for features, s, labels in train_data:\n",
    "        predictions = model(features)\n",
    "        pred_chunks.append(predictions.round().detach().numpy())\n",
    "        s_chunks.append(s.numpy())\n",
    "\n",
    "        loss = loss_f(predictions, labels)\n",
    "        if mu:\n",
    "            prox = sum(((p - p0) ** 2).sum() for p, p0 in zip(model.parameters(), ref_params))\n",
    "            loss = loss + mu / 2 * prox\n",
    "        # Accumulate a detached value so the epoch total does not keep every\n",
    "        # batch's autograd graph alive.\n",
    "        total_loss += loss.detach()\n",
    "        n_batches += 1\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    if n_batches == 0:\n",
    "        raise ValueError(\"train_step: `train_data` yielded no batches\")\n",
    "\n",
    "    preds = np.concatenate(pred_chunks).ravel()\n",
    "    sens = np.concatenate(s_chunks).ravel()\n",
    "    rates = [np.mean(preds[sens == value])\n",
    "             for value in np.unique(sens)\n",
    "             if np.mean(sens == value) > 0.01]  # ignore groups with <= 1% of the samples\n",
    "    dp_gap = max(rates) - min(rates)\n",
    "\n",
    "    return total_loss / n_batches, dp_gap\n",
    "\n",
    "\n",
    "# local training: run several epochs of train_step on one client's data\n",
    "def local_learning(model, mu:float, optimizer, train_data, epochs:int, loss_f):\n",
    "    \"\"\"Run `epochs` epochs of train_step on `model` with `optimizer`.\n",
    "\n",
    "    A copy of the starting model is passed to train_step as the reference\n",
    "    model `model_0` (used for the FedProx term when `mu` != 0). `optimizer` is\n",
    "    expected to own `model`'s parameters, so `model` is updated in place.\n",
    "\n",
    "    Returns (loss, fairness) of the LAST epoch: the mean batch loss as a Python\n",
    "    float and the DP gap reported by train_step.\n",
    "    Raises ValueError if `epochs` < 1, since there would be no loss to report.\n",
    "    \"\"\"\n",
    "    if epochs < 1:\n",
    "        raise ValueError(f\"local_learning: epochs must be >= 1, got {epochs}\")\n",
    "    model_0 = deepcopy(model)\n",
    "    for _ in range(epochs):\n",
    "        local_loss, local_fairness = train_step(model, model_0, mu, optimizer, train_data, loss_f)\n",
    "    return float(local_loss.detach()), local_fairness\n",
    "\n",
    "# reset model parameter\n",
    "def set_to_zero_model_weights(model):\n",
    "    \"\"\"Set all the parameters of a model to 0, in place.\n",
    "\n",
    "    zero_() is used instead of x.sub_(x): inf - inf and NaN - NaN are NaN, so\n",
    "    subtracting a tensor from itself does not always produce zeros.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        for param in model.parameters():\n",
    "            param.zero_()\n",
    "\n",
    "        \n",
    "# aggregation\n",
    "def average_models(model, clients_models_hist:list , weights:list):\n",
    "    \"\"\"Return a new model whose parameters are the weighted sum of the clients'.\n",
    "\n",
    "    `clients_models_hist[k]` is client k's list of parameter tensors, in the\n",
    "    order of `model.parameters()`, and `weights[k]` multiplies them. `model`\n",
    "    only provides the architecture and is not modified. With no clients, an\n",
    "    unmodified copy of `model` is returned instead of an all-zero model\n",
    "    (ClusterFL can leave a cluster without members).\n",
    "\n",
    "    Raises ValueError if `weights` and `clients_models_hist` differ in length.\n",
    "    \"\"\"\n",
    "    if len(weights) != len(clients_models_hist):\n",
    "        raise ValueError(\n",
    "            f\"average_models: got {len(weights)} weights for {len(clients_models_hist)} clients\")\n",
    "    new_model = deepcopy(model)\n",
    "    if not clients_models_hist:\n",
    "        return new_model\n",
    "    with torch.no_grad():\n",
    "        for idx, param in enumerate(new_model.parameters()):\n",
    "            # Accumulate client 0 first, then client 1, ...\n",
    "            param.copy_(sum(client_params[idx].data * w\n",
    "                            for client_params, w in zip(clients_models_hist, weights)))\n",
    "    return new_model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1782ae77",
   "metadata": {},
   "outputs": [],
   "source": [
    "def FedAvg(model, training_sets:list, n_iter:int, testing_sets:list, mu=0, \n",
    "    file_name=\"test\", epochs = 50, lr=10**-2, decay=1):\n",
    "    \"\"\"Federated averaging over all clients (mu is forwarded for FedProx).\n",
    "\n",
    "    In every server iteration, each client trains a copy of the global model\n",
    "    with Adam for `epochs` local epochs. The server then replaces the global\n",
    "    model with the average of the client models, weighted by training-set size.\n",
    "\n",
    "    Parameters:\n",
    "        - `model`: common structure used by the clients and the server; it is\n",
    "            not modified (the trained model is returned)\n",
    "        - `training_sets`: list of the training DataLoaders. At index k is the\n",
    "            training set of client k\n",
    "        - `n_iter`: number of iterations the server will run\n",
    "        - `testing_sets`: list of the testing DataLoaders, aligned with\n",
    "            `training_sets`. If [], test accuracy/fairness are not computed\n",
    "        - `mu`: regularization term for FedProx, forwarded to local_learning. mu=0 for FedAvg\n",
    "        - `file_name`: unused, kept for interface compatibility\n",
    "        - `epochs`: number of epochs each client is running\n",
    "        - `lr`: learning rate of the optimizer\n",
    "        - `decay`: multiplicative learning-rate decay applied after each iteration\n",
    "\n",
    "    returns :\n",
    "        - `model`: the final global model\n",
    "        - `loss_hist`, `acc_hist`, `fairness_hist`: one entry per evaluation (the\n",
    "            first is before training) holding the per-client training loss,\n",
    "            test accuracy and test DP gap\n",
    "    \"\"\"\n",
    "    loss_f = loss_classifier\n",
    "\n",
    "    K = len(training_sets)  # number of clients\n",
    "    n_samples = sum(len(db.dataset) for db in training_sets)\n",
    "    weights = [len(db.dataset) / n_samples for db in training_sets]\n",
    "    print(\"Clients' weights:\", weights)\n",
    "\n",
    "    def server_average(values):\n",
    "        # Weighted mean of per-client metrics. Returns None when nothing was\n",
    "        # measured (testing_sets == []) instead of raising IndexError.\n",
    "        if not values:\n",
    "            return None\n",
    "        return sum(weights[j] * values[j] for j in range(len(weights)))\n",
    "\n",
    "    def evaluate(current_model):\n",
    "        # Per-client training loss, test accuracy and test DP gap.\n",
    "        losses = [float(loss_dataset(current_model, dl, loss_f).detach()) for dl in training_sets]\n",
    "        accs = [accuracy_dataset(current_model, dl) for dl in testing_sets]\n",
    "        gaps = [fairness_dataset(current_model, dl) for dl in testing_sets]\n",
    "        return losses, accs, gaps\n",
    "\n",
    "    losses, accs, gaps = evaluate(model)\n",
    "    loss_hist, acc_hist, fairness_hist = [losses], [accs], [gaps]\n",
    "    server_loss = server_average(losses)\n",
    "    server_acc = server_average(accs)\n",
    "    server_fairness = server_average(gaps)\n",
    "    print(f'====> i: 0 Loss: {server_loss} Server Test Accuracy: {server_acc} Server Fairness: {server_fairness}')\n",
    "\n",
    "    for i in range(n_iter):\n",
    "        clients_params = []\n",
    "        for k in range(K):\n",
    "            local_model = deepcopy(model)\n",
    "            local_optimizer = optim.Adam(local_model.parameters(), lr=lr)\n",
    "            local_learning(local_model, mu, local_optimizer, training_sets[k], epochs, loss_f)\n",
    "            # Detached parameter tensors of the trained client model.\n",
    "            clients_params.append([p.detach() for p in local_model.parameters()])\n",
    "\n",
    "        # CREATE THE NEW GLOBAL MODEL\n",
    "        model = average_models(deepcopy(model), clients_params, weights=weights)\n",
    "\n",
    "        # COMPUTE THE LOSS/ACCURACY OF THE DIFFERENT CLIENTS WITH THE NEW MODEL\n",
    "        losses, accs, gaps = evaluate(model)\n",
    "        loss_hist.append(losses)\n",
    "        acc_hist.append(accs)\n",
    "        fairness_hist.append(gaps)\n",
    "        server_loss = server_average(losses)\n",
    "        server_acc = server_average(accs)\n",
    "        server_fairness = server_average(gaps)\n",
    "        print(f'====> i: {i+1} Loss: {server_loss} Server Test Accuracy: {server_acc} Server Fairness: {server_fairness}' )\n",
    "\n",
    "        # DECREASING THE LEARNING RATE AT EACH SERVER ITERATION\n",
    "        lr *= decay\n",
    "\n",
    "        with torch.no_grad():\n",
    "            learned_W = model.fc.weight.item()\n",
    "            learned_b = model.fc.bias.item()\n",
    "            # sigmoid(W*x + b) > 0.5  <=>  W*x + b > 0; undefined (nan) when W == 0.\n",
    "            print(\"Decision threshold:\", -learned_b/learned_W if learned_W != 0 else float(\"nan\"))\n",
    "\n",
    "    return model, loss_hist, acc_hist, fairness_hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b99ba79",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ClusterFL (model, training_sets:list, n_iter:int, testing_sets:list, mu=0, \n",
    "    file_name=\"test\", epochs = 50, lr=10**-2, decay=1):\n",
    "    \"\"\"Clustered federated learning with two cluster models.\n",
    "\n",
    "    Both cluster models start as copies of `model`. In every server iteration,\n",
    "    each client trains a local copy of BOTH cluster models for `epochs`\n",
    "    epochs. It then joins the cluster whose copy ends with the lower local\n",
    "    training loss (ties are broken uniformly at random), and that trained copy\n",
    "    becomes its update. Each cluster model is replaced by the uniform average\n",
    "    of its members' updates; a cluster with no members is left unchanged.\n",
    "\n",
    "    Parameters:\n",
    "        - `model`: common structure; both cluster models are initialised from it\n",
    "        - `training_sets`: list of the training DataLoaders, one per client\n",
    "        - `n_iter`: number of iterations the server will run\n",
    "        - `testing_sets`: list of the testing DataLoaders, aligned with\n",
    "            `training_sets`. If [], test accuracy/fairness are not computed\n",
    "        - `mu`: regularization term for FedProx, forwarded to local_learning\n",
    "        - `file_name`: unused, kept for interface compatibility\n",
    "        - `epochs`: number of epochs each client is running\n",
    "        - `lr`: learning rate of the optimizer\n",
    "        - `decay`: multiplicative learning-rate decay applied after each iteration\n",
    "\n",
    "    returns :\n",
    "        - `model`: the input model, unchanged (see the NOTE before the return)\n",
    "        - `loss_hist`, `acc_hist`, `fairness_hist`: one entry per evaluation of\n",
    "            per-client metrics, each client evaluated with its cluster's model\n",
    "    \"\"\"\n",
    "    loss_f = loss_classifier\n",
    "\n",
    "    K = len(training_sets)  # number of clients\n",
    "    # identity[i + 1, k] = cluster (0 or 1) chosen by client k in iteration i;\n",
    "    # row 0 is never assigned.\n",
    "    identity = np.full([n_iter + 1, K], np.nan)\n",
    "    # Separate copies, so the clusters share no parameter storage with each\n",
    "    # other or with the caller's `model`.\n",
    "    model0 = deepcopy(model)\n",
    "    model1 = deepcopy(model)\n",
    "    # NOTE(review): both clusters start from identical weights, so the first\n",
    "    # assignments are driven only by training noise and random tie-breaks.\n",
    "\n",
    "    n_samples = sum(len(db.dataset) for db in training_sets)\n",
    "    weights = [len(db.dataset) / n_samples for db in training_sets]\n",
    "    print(\"Clients' weights:\", weights)\n",
    "\n",
    "    def server_average(values):\n",
    "        # Weighted mean of per-client metrics. Returns None when nothing was\n",
    "        # measured (testing_sets == []) instead of raising IndexError.\n",
    "        if not values:\n",
    "            return None\n",
    "        return sum(weights[j] * values[j] for j in range(len(weights)))\n",
    "\n",
    "    def evaluate(models_per_client):\n",
    "        # Per-client training loss, test accuracy and test DP gap, each client\n",
    "        # using its own model.\n",
    "        losses = [float(loss_dataset(m, dl, loss_f).detach())\n",
    "                  for m, dl in zip(models_per_client, training_sets)]\n",
    "        accs = [accuracy_dataset(m, dl) for m, dl in zip(models_per_client, testing_sets)]\n",
    "        gaps = [fairness_dataset(m, dl) for m, dl in zip(models_per_client, testing_sets)]\n",
    "        return losses, accs, gaps\n",
    "\n",
    "    losses, accs, gaps = evaluate([model] * K)\n",
    "    loss_hist, acc_hist, fairness_hist = [losses], [accs], [gaps]\n",
    "    server_loss = server_average(losses)\n",
    "    server_acc = server_average(accs)\n",
    "    server_fairness = server_average(gaps)\n",
    "    print(f'====> i: 0 Loss: {server_loss} Server Test Accuracy: {server_acc} Server Fairness: {server_fairness}')\n",
    "\n",
    "    for i in range(n_iter):\n",
    "        row = i + 1\n",
    "        clients_params = ([], [])  # updates sent to cluster 0 / cluster 1\n",
    "\n",
    "        for k in range(K):  # for client k\n",
    "            # Train a fresh local copy of each cluster model on client k's data.\n",
    "            # The optimizer must own the parameters of the model being trained.\n",
    "            candidates = []\n",
    "            candidate_losses = []\n",
    "            for cluster_model in (model0, model1):\n",
    "                local_model = deepcopy(cluster_model)\n",
    "                local_optimizer = optim.Adam(local_model.parameters(), lr=lr)\n",
    "                local_loss, _ = local_learning(local_model, mu, local_optimizer,\n",
    "                                               training_sets[k], epochs, loss_f)\n",
    "                candidates.append(local_model)\n",
    "                candidate_losses.append(local_loss)\n",
    "\n",
    "            if candidate_losses[0] > candidate_losses[1]:\n",
    "                chosen = 1\n",
    "            elif candidate_losses[0] < candidate_losses[1]:\n",
    "                chosen = 0\n",
    "            else:\n",
    "                chosen = np.random.randint(0, 2)  # upper bound is exclusive: 0 or 1\n",
    "            identity[row, k] = chosen\n",
    "\n",
    "            # The copy already trained from the chosen cluster model is the update.\n",
    "            clients_params[chosen].append([p.detach() for p in candidates[chosen].parameters()])\n",
    "\n",
    "        # CREATE THE NEW CLUSTER MODELS (uniform average over each cluster's\n",
    "        # members); a cluster that received no update keeps its current model.\n",
    "        if clients_params[0]:\n",
    "            n0 = len(clients_params[0])\n",
    "            model0 = average_models(deepcopy(model0), clients_params[0], weights=[1 / n0] * n0)\n",
    "        if clients_params[1]:\n",
    "            n1 = len(clients_params[1])\n",
    "            model1 = average_models(deepcopy(model1), clients_params[1], weights=[1 / n1] * n1)\n",
    "\n",
    "        # COMPUTE THE LOSS/ACCURACY OF EACH CLIENT WITH ITS CLUSTER'S NEW MODEL\n",
    "        client_models = [model0 if iden == 0 else model1 for iden in identity[row]]\n",
    "        losses, accs, gaps = evaluate(client_models)\n",
    "        loss_hist.append(losses)\n",
    "        acc_hist.append(accs)\n",
    "        fairness_hist.append(gaps)\n",
    "        server_loss = server_average(losses)\n",
    "        server_acc = server_average(accs)\n",
    "        server_fairness = server_average(gaps)\n",
    "        print(f'====> i: {i+1} Loss: {server_loss} Server Test Accuracy: {server_acc} Server Fairness: {server_fairness}' )\n",
    "\n",
    "        # DECREASING THE LEARNING RATE AT EACH SERVER ITERATION\n",
    "        lr *= decay\n",
    "        with torch.no_grad():\n",
    "            learned_W1 = model1.fc.weight.item()\n",
    "            learned_b1 = model1.fc.bias.item()\n",
    "            learned_W0 = model0.fc.weight.item()\n",
    "            learned_b0 = model0.fc.bias.item()\n",
    "            # sigmoid(W*x + b) > 0.5  <=>  W*x + b > 0; undefined (nan) when W == 0.\n",
    "            print(\"Decision threshold for cluster 1:\", -learned_b1/learned_W1 if learned_W1 != 0 else float(\"nan\"))\n",
    "            print(\"Decision threshold for cluster 0:\", -learned_b0/learned_W0 if learned_W0 != 0 else float(\"nan\"))\n",
    "\n",
    "    # NOTE(review): this returns the caller's untouched `model`, not the trained\n",
    "    # cluster models (model0/model1); kept for interface compatibility.\n",
    "    return model, loss_hist, acc_hist, fairness_hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b78c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _fca_share(a, b):\n",
    "    \"\"\"Return a / (a + b), or 0.5 when a + b == 0 (the comparison then degrades to a tie).\"\"\"\n",
    "    total = a + b\n",
    "    return a / total if total != 0 else 0.5\n",
    "\n",
    "\n",
    "def _fca_pick_cluster(loss0, loss1, fairness0, fairness1, use_fairness):\n",
    "    \"\"\"Return the cluster (0 or 1) a client is assigned to.\n",
    "\n",
    "    A higher score means the corresponding model did worse on the client, so the\n",
    "    client goes to the other cluster. With `use_fairness`, a model's score is the\n",
    "    mean of its share of the summed losses and its share of the summed fairness\n",
    "    values; otherwise the raw losses are compared. Exact ties are broken uniformly\n",
    "    at random.\n",
    "    NOTE(review): a larger fairness value is treated as worse (e.g. a disparity);\n",
    "    confirm against what local_learning / fairness_dataset return.\n",
    "    \"\"\"\n",
    "    if use_fairness:\n",
    "        score0 = 0.5*_fca_share(loss0, loss1) + 0.5*_fca_share(fairness0, fairness1)\n",
    "        score1 = 0.5*_fca_share(loss1, loss0) + 0.5*_fca_share(fairness1, fairness0)\n",
    "    else:\n",
    "        score0, score1 = loss0, loss1\n",
    "    if score0 > score1:\n",
    "        return 1\n",
    "    if score0 < score1:\n",
    "        return 0\n",
    "    # randint's upper bound is exclusive: (0, 2) draws 0 or 1, whereas (0, 1) is always 0\n",
    "    return np.random.randint(0, 2)\n",
    "\n",
    "\n",
    "def _fca_local_update(global_model, mu, lr, train_data, epochs, loss_f):\n",
    "    \"\"\"Train a deep copy of `global_model` with a fresh Adam optimizer via `local_learning`.\n",
    "\n",
    "    Returns (local_model, local_loss, local_fairness); `global_model` itself is not\n",
    "    passed to `local_learning`, so it is left untouched.\n",
    "    \"\"\"\n",
    "    local_model = deepcopy(global_model)\n",
    "    local_optimizer = optim.Adam(local_model.parameters(), lr=lr)\n",
    "    local_loss, local_fairness = local_learning(local_model, mu, local_optimizer,\n",
    "        train_data, epochs, loss_f)\n",
    "    return local_model, local_loss, local_fairness\n",
    "\n",
    "\n",
    "def _fca_report(round_idx, weights, losses, accs, fairness):\n",
    "    \"\"\"Print the client-weighted loss, test accuracy and fairness of one round.\"\"\"\n",
    "    server_loss = sum(w*v for w, v in zip(weights, losses))\n",
    "    server_acc = sum(w*v for w, v in zip(weights, accs))\n",
    "    server_fairness = sum(w*v for w, v in zip(weights, fairness))\n",
    "    print(f'====> i: {round_idx} Loss: {server_loss} Server Test Accuracy: {server_acc} Server Fairness: {server_fairness}')\n",
    "\n",
    "\n",
    "def _fca_print_threshold(cluster_model, cluster_id):\n",
    "    \"\"\"Print -b/w for the model's `fc` layer, i.e. the x solving w*x + b == 0 (nan if w == 0).\n",
    "\n",
    "    NOTE(review): `.item()` requires `fc.weight` and `fc.bias` to hold a single\n",
    "    element each, so this assumes a 1-input, 1-output `fc` layer.\n",
    "    \"\"\"\n",
    "    learned_w = cluster_model.fc.weight.item()\n",
    "    learned_b = cluster_model.fc.bias.item()\n",
    "    threshold = -learned_b/learned_w if learned_w != 0 else float('nan')\n",
    "    print(f\"Decision threshold for cluster {cluster_id}:\", threshold)\n",
    "\n",
    "\n",
    "def Fair_FCA (model, training_sets:list, n_iter:int, testing_sets:list, mu=0, \n",
    "    file_name=\"test\", epochs = 50, lr=10**-2, decay=1):\n",
    "    \"\"\"Two-cluster federated training with loss/fairness-based client assignment.\n",
    "\n",
    "    Each round, every client locally trains a copy of both cluster models and is\n",
    "    assigned to the cluster whose model scored better on its training data (loss\n",
    "    only in the first round, loss and fairness afterwards). The client then trains\n",
    "    a fresh copy of its cluster's model, and each cluster model is replaced by the\n",
    "    uniform average of the updates it received. A cluster with no clients in a\n",
    "    round keeps its previous model.\n",
    "\n",
    "    Parameters:\n",
    "        - `model`: initial model; both cluster models start as deep copies of it\n",
    "        - `training_sets`: list of the training sets. At each index is the \n",
    "            training set of client \"index\"\n",
    "        - `n_iter`: number of iterations the server will run\n",
    "        - `testing_sets`: list of the testing sets, in the same client order as\n",
    "            `training_sets`; used for accuracy and fairness\n",
    "        - `mu`: regularization term forwarded to `local_learning` (FedProx;\n",
    "            mu=0 for FedAvg)\n",
    "        - `file_name`: unused\n",
    "        - `epochs`: number of epochs each client is running\n",
    "        - `lr`: learning rate of the Adam optimizers\n",
    "        - `decay`: multiplies `lr` after each server iteration\n",
    "\n",
    "    Prints the clients' weights, the weighted server metrics of every round and,\n",
    "    for each cluster that received clients, the decision threshold of its `fc` layer.\n",
    "\n",
    "    returns :\n",
    "        - `model`: the `model` argument itself. NOTE(review): the trained cluster\n",
    "            models are not returned\n",
    "        - `loss_hist`: per round, each client's training loss under its\n",
    "            cluster's model (round 0 uses `model` for every client)\n",
    "        - `acc_hist`, `fairness_hist`: same as `loss_hist`, on the testing sets\n",
    "    \"\"\"\n",
    "    loss_f=loss_classifier\n",
    "\n",
    "    K = len(training_sets) #number of clients\n",
    "\n",
    "    # identity[r, k]: cluster (0 or 1) of client k after round r. Row 0 is the\n",
    "    # starting point, where both cluster models are copies of `model`.\n",
    "    identity = np.zeros([n_iter+1, K], dtype=int)\n",
    "    # Independent copies, so the clusters (and the caller's `model`) never share tensors\n",
    "    model0 = deepcopy(model)\n",
    "    model1 = deepcopy(model)\n",
    "\n",
    "    # Clients are weighted by training-set size when reporting server metrics\n",
    "    n_samples=sum([len(db.dataset) for db in training_sets])\n",
    "    weights=[len(db.dataset)/n_samples for db in training_sets]\n",
    "    print(\"Clients' weights:\",weights)\n",
    "\n",
    "    loss_hist=[[float(loss_dataset(model, dl, loss_f).detach()) \n",
    "        for dl in training_sets]]\n",
    "    acc_hist=[[accuracy_dataset(model, dl) for dl in testing_sets]]\n",
    "    fairness_hist = [[fairness_dataset(model, dl) for dl in testing_sets]]\n",
    "    _fca_report(0, weights, loss_hist[-1], acc_hist[-1], fairness_hist[-1])\n",
    "\n",
    "    for i in range(n_iter):\n",
    "        # View on this round's row of `identity`: writes go into the history\n",
    "        round_identity = identity[i+1]\n",
    "        # Detached parameter lists received by cluster 0 and cluster 1\n",
    "        clients_params = [[], []]\n",
    "\n",
    "        for k in range(K): # for client k\n",
    "            # Score both cluster models on client k after a local update of each\n",
    "            _, local_loss0, local_fairness0 = _fca_local_update(model0, mu, lr,\n",
    "                training_sets[k], epochs, loss_f)\n",
    "            _, local_loss1, local_fairness1 = _fca_local_update(model1, mu, lr,\n",
    "                training_sets[k], epochs, loss_f)\n",
    "            # The first round compares losses only; later rounds also use fairness\n",
    "            round_identity[k] = _fca_pick_cluster(local_loss0, local_loss1,\n",
    "                local_fairness0, local_fairness1, use_fairness=(i >= 1))\n",
    "\n",
    "            # Fresh local update of the chosen cluster's model: this is what gets aggregated\n",
    "            chosen_model = model0 if round_identity[k] == 0 else model1\n",
    "            local_model, _, _ = _fca_local_update(chosen_model, mu, lr,\n",
    "                training_sets[k], epochs, loss_f)\n",
    "            clients_params[round_identity[k]].append(\n",
    "                [tens_param.detach() for tens_param in local_model.parameters()])\n",
    "\n",
    "        #CREATE THE NEW GLOBAL MODELS: uniform average within each cluster.\n",
    "        # A cluster that received no client this round keeps its previous model,\n",
    "        # since there is nothing to average.\n",
    "        if clients_params[0]:\n",
    "            model0 = average_models(deepcopy(model0), clients_params[0],\n",
    "                weights=[1/len(clients_params[0])]*len(clients_params[0]))\n",
    "        if clients_params[1]:\n",
    "            model1 = average_models(deepcopy(model1), clients_params[1],\n",
    "                weights=[1/len(clients_params[1])]*len(clients_params[1]))\n",
    "\n",
    "        #COMPUTE THE LOSS/ACCURACY/FAIRNESS OF EACH CLIENT WITH ITS CLUSTER'S NEW MODEL\n",
    "        cluster_models = [model0, model1]\n",
    "        loss_hist.append([float(loss_dataset(cluster_models[c], dl, loss_f).detach())\n",
    "            for dl, c in zip(training_sets, round_identity)])\n",
    "        acc_hist.append([accuracy_dataset(cluster_models[c], dl)\n",
    "            for dl, c in zip(testing_sets, round_identity)])\n",
    "        fairness_hist.append([fairness_dataset(cluster_models[c], dl)\n",
    "            for dl, c in zip(testing_sets, round_identity)])\n",
    "        _fca_report(i+1, weights, loss_hist[-1], acc_hist[-1], fairness_hist[-1])\n",
    "\n",
    "        #DECREASING THE LEARNING RATE AT EACH SERVER ITERATION\n",
    "        lr*=decay\n",
    "        with torch.no_grad():\n",
    "            # Only report clusters that received clients this round\n",
    "            if clients_params[1]:\n",
    "                _fca_print_threshold(model1, 1)\n",
    "            if clients_params[0]:\n",
    "                _fca_print_threshold(model0, 0)\n",
    "\n",
    "    return model, loss_hist, acc_hist, fairness_hist"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
