{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46b19989",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cmath import log\n",
    "import copy\n",
    "import logging\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import wandb\n",
    "import os\n",
    "from .client import Client\n",
    "from .my_model_trainer_classification import MyModelTrainer as MyModelTrainerCLS\n",
    "from .my_model_trainer_nwp import MyModelTrainer as MyModelTrainerNWP\n",
    "from .my_model_trainer_tag_prediction import MyModelTrainer as MyModelTrainerTAG\n",
    "import logging\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "from sklearn.cluster import KMeans\n",
    "\n",
    "class FedAvgAPI(object):\n",
    "    def __init__(self, args, device, dataset, model,model_trainer=None):\n",
    "        self.device = device\n",
    "        self.args = args\n",
    "        [\n",
    "            train_data_num,\n",
    "            test_data_num,\n",
    "            train_data_global,\n",
    "            test_data_global,\n",
    "            train_data_local_num_dict,\n",
    "            train_data_local_dict,\n",
    "            test_data_local_dict,\n",
    "            val_data_local_dict,\n",
    "            class_num,\n",
    "        ] = dataset\n",
    "        self.train_global = train_data_global\n",
    "        self.test_global = test_data_global\n",
    "        self.val_global = None\n",
    "        self.train_data_num_in_total = train_data_num\n",
    "        self.test_data_num_in_total = test_data_num\n",
    "\n",
    "        self.client_list = []\n",
    "        self.train_data_local_num_dict = train_data_local_num_dict\n",
    "        self.train_data_local_dict = train_data_local_dict\n",
    "        self.test_data_local_dict = test_data_local_dict\n",
    "        self.val_data_local_dict = val_data_local_dict\n",
    "\n",
    "        logging.info(\"model = {}\".format(model))\n",
    "        if model_trainer is None:\n",
    "            if args.dataset == \"stackoverflow_lr\":\n",
    "                model_trainer = MyModelTrainerTAG(model)\n",
    "            elif args.dataset in [\"fed_shakespeare\", \"stackoverflow_nwp\"]:\n",
    "                model_trainer = MyModelTrainerNWP(model)\n",
    "            else:\n",
    "                # default model trainer is for classification problem\n",
    "                model_trainer = MyModelTrainerCLS(model)\n",
    "        self.model_trainer = model_trainer\n",
    "        logging.info(\"self.model_trainer = {}\".format(self.model_trainer))\n",
    "\n",
    "        self._setup_clients(\n",
    "            train_data_local_num_dict,\n",
    "            train_data_local_dict,\n",
    "            test_data_local_dict,\n",
    "            val_data_local_dict,\n",
    "            self.model_trainer,\n",
    "        )\n",
    "\n",
    "    def _setup_clients(\n",
    "        self,\n",
    "        train_data_local_num_dict,\n",
    "        train_data_local_dict,\n",
    "        test_data_local_dict,\n",
    "        val_data_local_dict,\n",
    "        model_trainer,\n",
    "    ):\n",
    "        logging.info(\"############setup_clients (START)#############\")\n",
    "        for client_idx in self.args.users: \n",
    "            c = Client(\n",
    "                client_idx,\n",
    "                train_data_local_dict[client_idx],\n",
    "                test_data_local_dict[client_idx],\n",
    "                val_data_local_dict[client_idx],\n",
    "                train_data_local_num_dict[client_idx],\n",
    "                self.args,\n",
    "                self.device,\n",
    "                model_trainer,\n",
    "            )\n",
    "            self.client_list.append(c)\n",
    "        logging.info(\"############setup_clients (END)#############\")\n",
    "\n",
    "    def train(self):\n",
    "        logging.info(\"self.model_trainer = {}\".format(self.model_trainer))\n",
    "        w_global = self.model_trainer.get_model_params()\n",
    "        \n",
    "        plot_gap_using_local = np.empty([self.args.comm_round, self.args.client_num_in_total])\n",
    "        plot_acc_using_local = np.empty([self.args.comm_round, self.args.client_num_in_total])     \n",
    "        plot_gap_using_global = np.empty([self.args.comm_round, self.args.client_num_in_total])\n",
    "        plot_acc_using_global = np.empty([self.args.comm_round, self.args.client_num_in_total])\n",
    "        plot_aggregate = np.empty([self.args.comm_round, self.args.client_num_in_total])\n",
    "        plot_global_fairness = np.empty([self.args.comm_round])\n",
    "        plot_global_acc = np.empty([self.args.comm_round])\n",
    "        plot_local_acc_using_global_model = np.empty([self.args.comm_round, self.args.client_num_in_total])\n",
    "        identity = np.empty([self.args.comm_round, self.args.client_num_in_total])\n",
    "        plot_gap_using_local[:] = np.nan\n",
    "        plot_acc_using_local[:] = np.nan\n",
    "        plot_gap_using_global[:] = np.nan\n",
    "        plot_acc_using_global[:] = np.nan\n",
    "        plot_aggregate[:] = np.nan\n",
    "        w_save = []\n",
    "        \n",
    "        \"\"\"\n",
    "        w_global_0 = w_global\n",
    "        w_global_1 = w_global\n",
    "        w_g = [w_global_0, w_global_1]\n",
    "        \"\"\"\n",
    "        \n",
    "        for round_idx in range(self.args.comm_round):\n",
    "            \n",
    "            logging.info(\"################Communication round : {}\".format(round_idx))\n",
    "\n",
    "            w_locals = []\n",
    "            w_locals_0 = []\n",
    "            w_locals_1 = []\n",
    "            \n",
    "            client_indexes = self._client_sampling(\n",
    "                round_idx, self.args.client_num_in_total, self.args.client_num_per_round\n",
    "            )\n",
    "            logging.info(\"client_indexes = \" + str(client_indexes))\n",
    "\n",
    "            #w_save = [] \n",
    "            \n",
    "            \"\"\"\n",
    "            for idx, client_idx in enumerate(client_indexes):\n",
    "                \n",
    "                client = self.client_list[idx]\n",
    "                \n",
    "                # check loss and assign identity\n",
    "                w = client.train(copy.deepcopy(w_global_0))\n",
    "                test_local_metrics, _, _, _ = client.local_test(True)\n",
    "                loss_0 = test_local_metrics[\"test_loss\"]\n",
    "                w = client.train(copy.deepcopy(w_global_1))\n",
    "                test_local_metrics, _, _, _ = client.local_test(True)\n",
    "                loss_1 = test_local_metrics[\"test_loss\"]\n",
    "                if loss_0 < loss_1:\n",
    "                    identity[round_idx, idx] = 0\n",
    "                elif loss_0 > loss_1:\n",
    "                    identity[round_idx, idx] = 1\n",
    "                else: \n",
    "                    identity[round_idx, idx] = random.randint(0, 1)\n",
    "                \n",
    "                w_global = w_g[int(identity[round_idx, idx])]\n",
    "                w = client.train(copy.deepcopy(w_global))\n",
    "                test_local_metrics, _, _, _ = client.local_test(True)\n",
    "                plot_gap_using_local[round_idx, idx] = test_local_metrics[\"dp_gap\"]\n",
    "                plot_acc_using_local[round_idx, idx] = test_local_metrics[\"test_correct\"]/test_local_metrics[\"test_total\"]\n",
    "                w_locals.append((client.get_sample_number(), copy.deepcopy(w)))\n",
    "                if identity[round_idx, idx] == 1:\n",
    "                    w_locals_1.append((client.get_sample_number(), copy.deepcopy(w)))\n",
    "                else :\n",
    "                    w_locals_0.append((client.get_sample_number(), copy.deepcopy(w)))\n",
    "                w_save.append(copy.deepcopy(w))\n",
    "\n",
    "            # find aggregated global model and assign it to model parameter\n",
    "            w_global_1 = self._aggregate(w_locals_1,round_idx) \n",
    "            w_global_0 = self._aggregate(w_locals_0,round_idx) \n",
    "            \"\"\"\n",
    "            \n",
    "            for idx, client_idx in enumerate(client_indexes):\n",
    "                client = self.client_list[idx]  \n",
    "                \n",
    "                if round_idx > 1:\n",
    "                    \n",
    "                    if client_idx in new_index:\n",
    "                \n",
    "                        w = client.train(copy.deepcopy(w_global_0))\n",
    "                        identity[round_idx, idx] = 0\n",
    "                        \n",
    "                    else :\n",
    "                        \n",
    "                        w = client.train(copy.deepcopy(w_global_1))\n",
    "                        identity[round_idx, idx] = 1\n",
    "                     \n",
    "                else:                     \n",
    "                    # round = 0\n",
    "                    w = client.train(copy.deepcopy(w_global))\n",
    "                    \n",
    "                w_locals.append((client.get_sample_number(), copy.deepcopy(w)))\n",
    "                w_save.append(copy.deepcopy(w))\n",
    "\n",
    "            #w_global = self._aggregate(w_locals,round_idx)\n",
    "            #self.model_trainer.set_model_params(w_global)\n",
    "\n",
    "            if round_idx < 1:\n",
    "                w_global = self._aggregate(w_locals,round_idx)\n",
    "                w_global_0 = w_global\n",
    "                w_global_1 = w_global\n",
    "            else :\n",
    "                w_p0_local = []\n",
    "                w_p1_local = []\n",
    "                new_index = [1, 10, 20, 30, 40, 50]\n",
    "                for i in new_index:\n",
    "                    w_p0_local.append(w_locals[i])\n",
    "                w_global_0 = self._aggregate(w_p0_local,round_idx)\n",
    "                new_index_1 = [0,2,3,4,5,6,7,8,9,\n",
    "                               11,12,13,14,15,16,17,18,19,\n",
    "                               21,22,23,24,25,26,27,28,29,\n",
    "                               31,32,33,34,35,36,37,38,39,\n",
    "                               41,42,43,44,45,46,47,48,49]\n",
    "                for j in new_index_1:\n",
    "                    w_p1_local.append(w_locals[j])\n",
    "                w_global_1 = self._aggregate(w_p1_local,round_idx)\n",
    "            \n",
    "            # save information\n",
    "            if round_idx % self.args.save_epoches == 0: \n",
    "                torch.save(self.model_trainer.model.state_dict(),os.path.join(self.args.run_folder, \"%s_at_%s.pt\" %(self.args.save_model_name,round_idx))) # check the fedavg model name\n",
    "                with open(\"%s/%s_locals_at_%s.pt\" %(self.args.run_folder,self.args.save_model_name,round_idx),'wb') as f:\n",
    "                    pickle.dump(w_save, f, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "            if round_idx == self.args.comm_round - 1 or round_idx % self.args.frequency_of_the_test == 0:\n",
    "                # check dp_gap for all clients using global model\n",
    "                plot_gap_each_global, dp_gap_test_global, test_accuracy, loc_acc_global = self._local_test_on_all_clients(round_idx, w_global, w_global_0, w_global_1)\n",
    "                \n",
    "                dp_gap_test_global, test_accuracy = self._local_test_on_all_clients_global(round_idx,w_global, w_global_0, w_global_1)\n",
    "                \n",
    "                print (np.mean(loc_acc_global), test_accuracy)\n",
    "                print (np.mean(plot_gap_each_global), dp_gap_test_global)\n",
    "        \n",
    "                plot_aggregate[round_idx] = plot_gap_each_global\n",
    "                plot_global_fairness[round_idx] = dp_gap_test_global\n",
    "                plot_global_acc[round_idx] = test_accuracy\n",
    "                plot_local_acc_using_global_model[round_idx,:] = loc_acc_global\n",
    "                \n",
    "                \n",
    "        \"\"\"\n",
    "        np.savetxt('income_test_Cluster_FL_60_global_acc_seed0.csv', plot_global_acc, delimiter=',')\n",
    "        np.savetxt('income_test_Cluster_FL_60_global_fairness_seed0.csv', plot_global_fairness, delimiter=',')\n",
    "        np.savetxt('income_test_Cluster_FL_60_local_acc_using_global_model_seed0.csv', plot_local_acc_using_global_model, delimiter=',')\n",
    "        np.savetxt('income_test_Cluster_FL_60_aggregate_seed0.csv', plot_aggregate, delimiter=',')\n",
    "        np.savetxt('income_test_Cluster_FL_60_identity_seed0.csv', identity, delimiter=',')\n",
    "        \"\"\"\n",
    "\n",
    "      \n",
    "    def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round):\n",
    "        if client_num_in_total == client_num_per_round:\n",
    "            client_indexes = self.args.users\n",
    "        else:\n",
    "            num_clients = min(client_num_per_round, client_num_in_total)\n",
    "            np.random.seed(\n",
    "                round_idx\n",
    "            )  # make sure for each comparison, we are selecting the same clients each round\n",
    "            client_indexes = np.random.choice(\n",
    "                self.args.users, num_clients, replace=False\n",
    "            )\n",
    "            np.random.seed(self.args.random_seed)\n",
    "        logging.info(\"client_indexes = %s\" % str(client_indexes))\n",
    "        return client_indexes\n",
    "\n",
    "    def _generate_validation_set(self, num_samples=10000):\n",
    "        return False\n",
    "\n",
    "    def _aggregate(self, w_locals,round_idx):\n",
    "        \n",
    "        training_num = 0\n",
    "        for idx in range(len(w_locals)):\n",
    "            (sample_num, averaged_params) = w_locals[idx]\n",
    "            training_num += sample_num\n",
    "\n",
    "        (sample_num, averaged_params) = w_locals[0]\n",
    "        for k in averaged_params.keys():\n",
    "            for i in range(0, len(w_locals)):\n",
    "                local_sample_number, local_model_params = w_locals[i]\n",
    "                w = local_sample_number / training_num\n",
    "                if i == 0:\n",
    "                    averaged_params[k] = local_model_params[k] * w\n",
    "                else:\n",
    "                    averaged_params[k] += local_model_params[k] * w\n",
    "        \n",
    "        return averaged_params\n",
    "\n",
    "    def _aggregate_noniid_avg(self, w_locals):\n",
    "        # uniform aggregation\n",
    "        \"\"\"\n",
    "        The old aggregate method will impact the model performance when it comes to Non-IID setting\n",
    "        Args:\n",
    "            w_locals:\n",
    "        Returns:\n",
    "        \"\"\"\n",
    "        (_, averaged_params) = w_locals[0]\n",
    "        for k in averaged_params.keys():\n",
    "            temp_w = []\n",
    "            for (_, local_w) in w_locals:\n",
    "                temp_w.append(local_w[k])\n",
    "            averaged_params[k] = sum(temp_w) / len(temp_w)\n",
    "        return averaged_params\n",
    "\n",
    "    \n",
    "    def _local_test_on_all_clients_global(self, round_idx, w_global, w_global_0, w_global_1):\n",
    "\n",
    "        test_metrics = {\"num_samples\": [], \"num_correct\": [], \"losses\": [], \"eo_gap\":[],\"dp_gap\":[]}\n",
    "\n",
    "        test_target_list_global = []\n",
    "        test_pred_list_global = []\n",
    "        test_s_list_global = []\n",
    "        \n",
    "        for key in w_global: \n",
    "                w_global[key] =  6/51 * w_global_0[key] + 45/51 * w_global_1[key]\n",
    "            \n",
    "        self.model_trainer.set_model_params(w_global)\n",
    "        \n",
    "        \n",
    "        for idx,client_idx in enumerate(self.args.users):\n",
    "            \n",
    "            if self.test_data_local_dict[client_idx] is None:\n",
    "                continue\n",
    "\n",
    "            client = self.client_list[idx]\n",
    "\n",
    "            test_local_metrics, test_target_list, test_pred_list, test_s_list = client.local_test(True)\n",
    "            test_metrics[\"num_samples\"].append(copy.deepcopy(test_local_metrics[\"test_total\"]))\n",
    "            test_metrics[\"num_correct\"].append(copy.deepcopy(test_local_metrics[\"test_correct\"]))\n",
    "            test_metrics[\"losses\"].append(copy.deepcopy(test_local_metrics[\"test_loss\"]))\n",
    "            test_metrics[\"eo_gap\"].append(copy.deepcopy(test_local_metrics[\"eo_gap\"]))\n",
    "            test_metrics[\"dp_gap\"].append(copy.deepcopy(test_local_metrics[\"dp_gap\"]))\n",
    "            test_target_list_global.append(test_target_list.tolist())\n",
    "            test_pred_list_global.append(test_pred_list.tolist())\n",
    "            test_s_list_global.append(test_s_list.tolist())\n",
    "            \n",
    "            \n",
    "        test_target_list_global = np.array(sum(test_target_list_global,[]))\n",
    "        test_pred_list_global = np.array(sum(test_pred_list_global,[]))\n",
    "        test_s_list_global = np.array(sum(test_s_list_global,[]))\n",
    "\n",
    "        pred_test_acc = ( test_pred_list_global==test_target_list_global)\n",
    "        ppr_test_global = []\n",
    "        tnr_test_global = []\n",
    "        tpr_test_global = []\n",
    "        converted_test_s = test_s_list_global[:,1] # sex, 1 attribute\n",
    "        \n",
    "        for s_value in np.unique(converted_test_s):\n",
    "            if np.mean(converted_test_s == s_value) > 0.01:\n",
    "                indexs0  = np.logical_and(test_target_list_global==0, converted_test_s==s_value)\n",
    "                indexs1  = np.logical_and(test_target_list_global==1, converted_test_s==s_value)\n",
    "                ppr_test_global.append(np.mean(test_pred_list_global[converted_test_s==s_value]))\n",
    "                tnr_test_global.append(np.mean(pred_test_acc[indexs0]))\n",
    "                tpr_test_global.append(np.mean(pred_test_acc[indexs1]))\n",
    "\n",
    "        eo_gap_test_global = max(max(tnr_test_global)-min(tnr_test_global), max(tpr_test_global)-min(tpr_test_global))\n",
    "        dp_gap_test_global = max(ppr_test_global) - min(ppr_test_global)\n",
    "\n",
    "        test_global_acc = np.mean(pred_test_acc)\n",
    "\n",
    "        return  dp_gap_test_global, test_global_acc\n",
    "\n",
    "    \n",
    "    def _local_test_on_all_clients(self, round_idx, w_global, w_global_0, w_global_1):\n",
    "\n",
    "        logging.info(\"################local_test_on_all_clients : {}\".format(round_idx))\n",
    "\n",
    "        train_metrics = {\"num_samples\": [], \"num_correct\": [], \"losses\": [], \"eo_gap\":[],\"dp_gap\":[]}\n",
    "\n",
    "        test_metrics = {\"num_samples\": [], \"num_correct\": [], \"losses\": [], \"eo_gap\":[],\"dp_gap\":[]}\n",
    "\n",
    "        train_target_list_global = []\n",
    "        train_pred_list_global = []\n",
    "        train_s_list_global = []\n",
    "        test_target_list_global = []\n",
    "        test_pred_list_global = []\n",
    "        test_s_list_global = []\n",
    "        local_train_acc = []\n",
    "        local_test_acc = []\n",
    "        p_1_local = np.empty([self.args.comm_round, self.args.client_num_in_total])\n",
    "        p_0_local = np.empty([self.args.comm_round, self.args.client_num_in_total])\n",
    "\n",
    "        for idx,client_idx in enumerate(self.args.users):\n",
    "            \n",
    "            \n",
    "            if self.test_data_local_dict[client_idx] is None:\n",
    "                continue\n",
    "                \n",
    "            if round_idx >= 1:\n",
    "            \n",
    "                if client_idx in [1, 10, 20, 30, 40, 50]:\n",
    "                    self.model_trainer.set_model_params(w_global_0)\n",
    "                else: \n",
    "                    self.model_trainer.set_model_params(w_global_1)\n",
    "            else:\n",
    "                self.model_trainer.set_model_params(w_global)\n",
    "            \n",
    "            client = self.client_list[idx]\n",
    "            train_local_metrics, train_target_list, train_pred_list, train_s_list = client.local_test(False)\n",
    "            train_metrics[\"num_samples\"].append(copy.deepcopy(train_local_metrics[\"test_total\"]))\n",
    "            train_metrics[\"num_correct\"].append(copy.deepcopy(train_local_metrics[\"test_correct\"]))\n",
    "            train_metrics[\"losses\"].append(copy.deepcopy(train_local_metrics[\"test_loss\"]))\n",
    "            train_metrics[\"eo_gap\"].append(copy.deepcopy(train_local_metrics[\"eo_gap\"]))\n",
    "            train_metrics[\"dp_gap\"].append(copy.deepcopy(train_local_metrics[\"dp_gap\"]))\n",
    "            train_target_list_global.append(train_target_list.tolist())\n",
    "            train_pred_list_global.append(train_pred_list.tolist())\n",
    "            train_s_list_global.append(train_s_list.tolist())\n",
    "            local_train_acc.append(train_local_metrics[\"test_correct\"]/train_local_metrics[\"test_total\"])\n",
    "            \n",
    "            test_local_metrics, test_target_list, test_pred_list, test_s_list = client.local_test(True)\n",
    "            test_metrics[\"num_samples\"].append(copy.deepcopy(test_local_metrics[\"test_total\"]))\n",
    "            test_metrics[\"num_correct\"].append(copy.deepcopy(test_local_metrics[\"test_correct\"]))\n",
    "            test_metrics[\"losses\"].append(copy.deepcopy(test_local_metrics[\"test_loss\"]))\n",
    "            test_metrics[\"eo_gap\"].append(copy.deepcopy(test_local_metrics[\"eo_gap\"]))\n",
    "            test_metrics[\"dp_gap\"].append(copy.deepcopy(test_local_metrics[\"dp_gap\"]))\n",
    "            test_target_list_global.append(test_target_list.tolist())\n",
    "            test_pred_list_global.append(test_pred_list.tolist())\n",
    "            test_s_list_global.append(test_s_list.tolist())\n",
    "            local_test_acc.append(test_local_metrics[\"test_correct\"]/test_local_metrics[\"test_total\"])\n",
    "            p_1_local[round_idx, idx] = np.mean(np.logical_and(test_pred_list==1, test_s_list[:,1]==1))\n",
    "            p_0_local[round_idx, idx] = np.mean(np.logical_and(test_pred_list==1, test_s_list[:,1]==0))\n",
    "            \n",
    "\n",
    "        train_target_list_global = np.array(sum(train_target_list_global, []))\n",
    "        train_pred_list_global = np.array(sum(train_pred_list_global,[]))\n",
    "        train_s_list_global = np.array(sum(train_s_list_global,[]))\n",
    "        test_target_list_global = np.array(sum(test_target_list_global,[]))\n",
    "        test_pred_list_global = np.array(sum(test_pred_list_global,[]))\n",
    "        test_s_list_global = np.array(sum(test_s_list_global,[]))\n",
    "\n",
    "        pred_train_acc = ( train_pred_list_global==train_target_list_global)\n",
    "        pred_test_acc = ( test_pred_list_global==test_target_list_global)\n",
    "        ppr_train_global = []\n",
    "        tnr_train_global = []\n",
    "        tpr_train_global = []\n",
    "        ppr_test_global = []\n",
    "        tnr_test_global = []\n",
    "        tpr_test_global = []\n",
    "        converted_train_s = train_s_list_global[:,1] # sex, 1 attribute\n",
    "        converted_test_s = test_s_list_global[:,1] # sex, 1 attribute\n",
    "        \n",
    "        for s_value in np.unique(converted_train_s):\n",
    "            if np.mean(converted_train_s == s_value) > 0.01:\n",
    "                indexs0  = np.logical_and(train_target_list_global==0, converted_train_s==s_value)\n",
    "                indexs1  = np.logical_and(train_target_list_global==1, converted_train_s==s_value)\n",
    "                ppr_train_global.append(np.mean(train_pred_list_global[converted_train_s==s_value]))\n",
    "                tnr_train_global.append(np.mean(pred_train_acc[indexs0]))\n",
    "                tpr_train_global.append(np.mean(pred_train_acc[indexs1]))\n",
    "               \n",
    "        for s_value in np.unique(converted_test_s):\n",
    "            if np.mean(converted_test_s == s_value) > 0.01:\n",
    "                indexs0  = np.logical_and(test_target_list_global==0, converted_test_s==s_value)\n",
    "                indexs1  = np.logical_and(test_target_list_global==1, converted_test_s==s_value)\n",
    "                ppr_test_global.append(np.mean(test_pred_list_global[converted_test_s==s_value]))\n",
    "                tnr_test_global.append(np.mean(pred_test_acc[indexs0]))\n",
    "                tpr_test_global.append(np.mean(pred_test_acc[indexs1]))\n",
    "\n",
    "        eo_gap_train_global = max(max(tnr_train_global)-min(tnr_train_global), max(tpr_train_global)-min(tpr_train_global))\n",
    "        dp_gap_train_global = max(ppr_train_global) - min(ppr_train_global)\n",
    "        eo_gap_test_global = max(max(tnr_test_global)-min(tnr_test_global), max(tpr_test_global)-min(tpr_test_global))\n",
    "        dp_gap_test_global = max(ppr_test_global) - min(ppr_test_global)\n",
    "\n",
    "        train_global_acc = np.mean(pred_train_acc)\n",
    "        test_global_acc = np.mean(pred_test_acc)\n",
    "        \n",
    "        # test on training dataset\n",
    "        train_acc = sum(train_metrics[\"num_correct\"]) / sum(train_metrics[\"num_samples\"])\n",
    "        train_loss = sum(train_metrics[\"losses\"]) / sum(train_metrics[\"num_samples\"])\n",
    "        train_dp_gap = sum(train_metrics[\"dp_gap\"])/len(self.args.users)\n",
    "        train_eo_gap = sum(train_metrics[\"eo_gap\"])/len(self.args.users)\n",
    "\n",
    "        # test on test dataset\n",
    "        test_acc = sum(test_metrics[\"num_correct\"]) / sum(test_metrics[\"num_samples\"])\n",
    "        test_loss = sum(test_metrics[\"losses\"]) / sum(test_metrics[\"num_samples\"])\n",
    "        test_dp_gap = sum(test_metrics[\"dp_gap\"])/len(self.args.users)\n",
    "        test_eo_gap = sum(test_metrics[\"eo_gap\"])/len(self.args.users)\n",
    "        logging.info('dp_gap' + str(train_metrics[\"dp_gap\"]))\n",
    "        logging.info('Train acc: {} Train Loss: {}, Test acc: {} Test Loss: {}'.format(train_acc, train_loss, test_acc,test_loss))\n",
    "        logging.info('Train dp gap: {} Train eo gap: {}, Test dp gap: {} Test eo gap: {}'.format(train_dp_gap,train_eo_gap, test_dp_gap,test_eo_gap))\n",
    "\n",
    "        if self.args.enable_wandb:\n",
    "            wandb.log({\"Test/Acc\": test_acc, \"round\": round_idx})\n",
    "            wandb.log({\"Test/Loss\": test_loss, \"round\": round_idx})\n",
    "            wandb.log({\"Train/Acc\": train_acc, \"round\": round_idx})\n",
    "            wandb.log({\"Train/Loss\": train_loss, \"round\": round_idx})\n",
    "       \n",
    "        \n",
    "        return test_metrics[\"dp_gap\"], dp_gap_test_global, test_global_acc, local_test_acc\n",
    "\n",
    "    \n",
    "    \n",
    "    \n",
    "    def _local_test_on_validation_set(self, round_idx):\n",
    "\n",
    "        logging.info(\n",
    "            \"################local_test_on_validation_set : {}\".format(round_idx)\n",
    "        )\n",
    "\n",
    "        if self.val_global is None:\n",
    "            self._generate_validation_set()\n",
    "\n",
    "        client = self.client_list[0]\n",
    "        client.update_local_dataset(0, None, self.val_global, None)\n",
    "        # test data\n",
    "        test_metrics = client.local_test(True)\n",
    "\n",
    "        if self.args.dataset == \"stackoverflow_nwp\":\n",
    "            test_acc = test_metrics[\"test_correct\"] / test_metrics[\"test_total\"]\n",
    "            test_loss = test_metrics[\"test_loss\"] / test_metrics[\"test_total\"]\n",
    "            stats = {\"test_acc\": test_acc, \"test_loss\": test_loss}\n",
    "            if self.args.enable_wandb:\n",
    "                wandb.log({\"Test/Acc\": test_acc, \"round\": round_idx})\n",
    "                wandb.log({\"Test/Loss\": test_loss, \"round\": round_idx})\n",
    "        elif self.args.dataset == \"stackoverflow_lr\":\n",
    "            test_acc = test_metrics[\"test_correct\"] / test_metrics[\"test_total\"]\n",
    "            test_pre = test_metrics[\"test_precision\"] / test_metrics[\"test_total\"]\n",
    "            test_rec = test_metrics[\"test_recall\"] / test_metrics[\"test_total\"]\n",
    "            test_loss = test_metrics[\"test_loss\"] / test_metrics[\"test_total\"]\n",
    "            stats = {\n",
    "                \"test_acc\": test_acc,\n",
    "                \"test_pre\": test_pre,\n",
    "                \"test_rec\": test_rec,\n",
    "                \"test_loss\": test_loss,\n",
    "            }\n",
    "            if self.args.enable_wandb:\n",
    "                wandb.log({\"Test/Acc\": test_acc, \"round\": round_idx})\n",
    "                wandb.log({\"Test/Pre\": test_pre, \"round\": round_idx})\n",
    "                wandb.log({\"Test/Rec\": test_rec, \"round\": round_idx})\n",
    "                wandb.log({\"Test/Loss\": test_loss, \"round\": round_idx})\n",
    "        else:\n",
    "            raise Exception(\n",
    "                \"Unknown format to log metrics for dataset {}!\" % self.args.dataset\n",
    "            )\n",
    "\n",
    "        logging.info(stats)\n",
    "\n",
    "    def save(self):\n",
    "        torch.save(self.model_trainer.model.state_dict(),os.path.join(self.args.run_folder, \"%s.pt\" %(self.args.save_model_name))) # check the fedavg model name\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
