{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Paramters :  Namespace(rounds=101, num_users=100, nclass=2, nsample_pc=250, frac=0.2, local_ep=10, local_bs=10, bs=128, lr=0.01, momentum=0.5, warmup_epoch=0, trial=1, mu=0.001, model='simple-cnn', ks=5, in_ch=3, dataset='cifar10', noniid=False, shard=False, label=False, split_test=False, savedir='../save_results/', datadir='../data/', logdir='../logs/', partition='non-iid3', alg='pacfl', beta=0.1, local_view=True, batch_size=64, noise=0, noise_type='level', cluster_alpha=1.37, n_basis=3, linkage='average', nclasses=10, nsamples_shared=2500, nclusters=3, num_incluster_layers=2, pruning_percent=10, pruning_target=30, dist_thresh=0.0001, acc_thresh=50, weight_decay=0.0001, gpu=-1, is_print=False, print_freq=10, seed=1, load_initial='', device=device(type='cpu'))\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%reload_ext autoreload\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import copy\n",
    "import os \n",
    "import gc \n",
    "import pickle\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "from src.data import *\n",
    "from src.models import *\n",
    "from src.fedavg import *\n",
    "from src.client import * \n",
    "from src.clustering import *\n",
    "from src.utils import * \n",
    "\n",
    "st=time.time()\n",
    "args = args_parser()\n",
    "\n",
    "args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "torch.cuda.set_device(args.gpu) ## Setting cuda on GPU \n",
    "\n",
    "def mkdirs(dirpath):\n",
    "    try:\n",
    "        os.makedirs(dirpath)\n",
    "    except Exception as _:\n",
    "        pass\n",
    "\n",
    "args.local_view=True\n",
    "args.model='simple-cnn'\n",
    "args.dataset='cifar10'\n",
    "args.partition='flag-non-iid'\n",
    "args.num_users=100\n",
    "args.rounds=201\n",
    "args.frac=.2\n",
    "#args.cluster_alpha\n",
    "print(\"Paramters : \",str(args))\n",
    "path = args.savedir + args.alg + '/' + args.partition + '/' + args.dataset + '/'\n",
    "mkdirs(path)\n",
    "\n",
    "template = \"Algorithm {}, Clients {}, Dataset {}, Model {}, Non-IID {}, Threshold {}, K {}, Linkage {}, LR {}, Ep {}, Rounds {}, bs {}, frac {}\"\n",
    "\n",
    "s = template.format(args.alg, args.num_users, args.dataset, args.model, args.partition, args.cluster_alpha, args.n_basis, args.linkage, args.lr, args.local_ep, args.rounds, args.local_ep, args.frac)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "hello  5\n",
      "partition: non-iid3\n",
      "Data statistics Train:\n",
      " {0: {1: 81, 3: 310, 5: 19}, 1: {1: 163, 3: 217, 5: 229}, 2: {3: 9, 5: 46, 7: 127}, 3: {0: 452, 2: 104, 6: 123}, 4: {4: 208, 8: 319, 9: 365}, 5: {3: 85, 5: 110, 7: 309}, 6: {0: 323, 2: 132, 6: 65}, 7: {1: 443, 4: 97, 6: 173}, 8: {0: 342, 2: 90, 6: 158}, 9: {1: 3, 4: 83, 6: 98}, 10: {0: 328, 2: 347, 6: 235}, 11: {3: 198, 5: 375, 7: 150}, 12: {4: 151, 8: 232, 9: 34}, 13: {0: 85, 2: 240, 6: 186}, 14: {0: 385, 2: 539, 6: 43}, 15: {4: 105, 8: 418, 9: 630}, 16: {3: 91, 5: 90, 7: 470}, 17: {3: 439, 5: 101, 7: 171}, 18: {1: 28, 3: 28, 5: 42}, 19: {1: 39, 4: 234, 6: 51}, 20: {3: 46, 5: 330, 7: 183}, 21: {1: 42, 4: 221, 6: 5}, 22: {1: 4, 4: 77, 6: 208}, 23: {3: 6, 5: 26, 7: 706}, 24: {0: 440, 2: 3, 6: 133}, 25: {1: 59, 3: 11, 5: 79}, 26: {1: 66, 3: 469, 5: 116}, 27: {4: 46, 8: 208, 9: 258}, 28: {3: 93, 5: 74, 7: 214}, 29: {4: 31, 8: 91, 9: 292}, 30: {3: 131, 5: 103, 7: 107}, 31: {4: 164, 8: 611, 9: 182}, 32: {3: 100, 5: 107, 7: 27}, 33: {3: 158, 5: 182, 7: 394}, 34: {1: 29, 3: 228, 5: 133}, 35: {4: 63, 8: 31, 9: 370}, 36: {3: 41, 5: 169, 7: 69}, 37: {1: 125, 4: 224, 6: 267}, 38: {3: 205, 5: 239, 7: 225}, 39: {1: 285, 4: 39, 6: 7}, 40: {1: 308, 3: 142, 5: 103}, 41: {3: 51, 5: 73, 7: 2}, 42: {3: 184, 5: 325, 7: 134}, 43: {0: 235, 2: 18, 6: 501}, 44: {4: 215, 8: 6, 9: 552}, 45: {0: 610, 2: 5, 6: 40}, 46: {0: 30, 2: 493, 6: 104}, 47: {3: 62, 5: 112, 7: 37}, 48: {1: 3, 3: 2, 5: 69}, 49: {4: 356, 8: 467, 9: 22}, 50: {1: 289, 4: 23, 6: 18}, 51: {0: 139, 2: 96, 6: 61}, 52: {4: 35, 8: 47, 9: 122}, 53: {1: 139, 3: 45, 5: 3}, 54: {1: 120, 4: 29, 6: 107}, 55: {1: 182, 3: 100, 5: 335}, 56: {0: 279, 2: 972, 6: 260}, 57: {1: 149, 4: 404, 6: 98}, 58: {1: 213, 3: 86, 5: 8}, 59: {0: 158, 2: 733, 6: 16}, 60: {4: 9, 8: 144, 9: 155}, 61: {3: 53, 5: 159, 7: 279}, 62: {3: 233, 5: 351, 7: 153}, 63: {3: 20, 5: 31, 7: 9}, 64: {1: 5, 3: 127, 5: 16}, 65: {0: 28, 2: 38, 6: 463}, 66: {1: 163, 3: 162, 5: 8}, 67: {1: 96, 4: 49, 6: 73}, 68: {4: 152, 8: 104, 9: 8}, 69: {1: 63, 4: 231, 6: 55}, 70: {1: 
1, 3: 42, 5: 191}, 71: {3: 69, 5: 4, 7: 77}, 72: {1: 22, 3: 64, 5: 37}, 73: {1: 229, 4: 202, 6: 271}, 74: {1: 351, 4: 32, 6: 31}, 75: {1: 18, 4: 56, 6: 163}, 76: {1: 115, 4: 159, 6: 128}, 77: {4: 82, 8: 459, 9: 591}, 78: {1: 83, 4: 26, 6: 15}, 79: {3: 54, 5: 147, 7: 846}, 80: {1: 38, 4: 48, 6: 4}, 81: {0: 308, 2: 71, 6: 151}, 82: {1: 67, 3: 148, 5: 146}, 83: {3: 367, 5: 10, 7: 82}, 84: {1: 215, 4: 102, 6: 50}, 85: {4: 30, 8: 18, 9: 52}, 86: {1: 1, 4: 81, 6: 34}, 87: {0: 616, 2: 464, 6: 48}, 88: {1: 34, 4: 34, 6: 83}, 89: {4: 2, 8: 131, 9: 470}, 90: {4: 230, 8: 658, 9: 613}, 91: {1: 11, 4: 200, 6: 13}, 92: {3: 7, 5: 300, 7: 229}, 93: {0: 54, 2: 69, 6: 59}, 94: {1: 664, 4: 34, 6: 56}, 95: {1: 17, 4: 234, 6: 132}, 96: {0: 12, 2: 178, 6: 63}, 97: {1: 37, 3: 117, 5: 2}, 98: {4: 202, 8: 1056, 9: 284}, 99: {0: 176, 2: 408, 6: 151}} \n",
      "\n",
      "Data statistics Test:\n",
      " {0: {1: 1000, 3: 1000, 5: 1000}, 1: {1: 1000, 3: 1000, 5: 1000}, 2: {3: 1000, 5: 1000, 7: 1000}, 3: {0: 1000, 2: 1000, 6: 1000}, 4: {4: 1000, 8: 1000, 9: 1000}, 5: {3: 1000, 5: 1000, 7: 1000}, 6: {0: 1000, 2: 1000, 6: 1000}, 7: {1: 1000, 4: 1000, 6: 1000}, 8: {0: 1000, 2: 1000, 6: 1000}, 9: {1: 1000, 4: 1000, 6: 1000}, 10: {0: 1000, 2: 1000, 6: 1000}, 11: {3: 1000, 5: 1000, 7: 1000}, 12: {4: 1000, 8: 1000, 9: 1000}, 13: {0: 1000, 2: 1000, 6: 1000}, 14: {0: 1000, 2: 1000, 6: 1000}, 15: {4: 1000, 8: 1000, 9: 1000}, 16: {3: 1000, 5: 1000, 7: 1000}, 17: {3: 1000, 5: 1000, 7: 1000}, 18: {1: 1000, 3: 1000, 5: 1000}, 19: {1: 1000, 4: 1000, 6: 1000}, 20: {3: 1000, 5: 1000, 7: 1000}, 21: {1: 1000, 4: 1000, 6: 1000}, 22: {1: 1000, 4: 1000, 6: 1000}, 23: {3: 1000, 5: 1000, 7: 1000}, 24: {0: 1000, 2: 1000, 6: 1000}, 25: {1: 1000, 3: 1000, 5: 1000}, 26: {1: 1000, 3: 1000, 5: 1000}, 27: {4: 1000, 8: 1000, 9: 1000}, 28: {3: 1000, 5: 1000, 7: 1000}, 29: {4: 1000, 8: 1000, 9: 1000}, 30: {3: 1000, 5: 1000, 7: 1000}, 31: {4: 1000, 8: 1000, 9: 1000}, 32: {3: 1000, 5: 1000, 7: 1000}, 33: {3: 1000, 5: 1000, 7: 1000}, 34: {1: 1000, 3: 1000, 5: 1000}, 35: {4: 1000, 8: 1000, 9: 1000}, 36: {3: 1000, 5: 1000, 7: 1000}, 37: {1: 1000, 4: 1000, 6: 1000}, 38: {3: 1000, 5: 1000, 7: 1000}, 39: {1: 1000, 4: 1000, 6: 1000}, 40: {1: 1000, 3: 1000, 5: 1000}, 41: {3: 1000, 5: 1000, 7: 1000}, 42: {3: 1000, 5: 1000, 7: 1000}, 43: {0: 1000, 2: 1000, 6: 1000}, 44: {4: 1000, 8: 1000, 9: 1000}, 45: {0: 1000, 2: 1000, 6: 1000}, 46: {0: 1000, 2: 1000, 6: 1000}, 47: {3: 1000, 5: 1000, 7: 1000}, 48: {1: 1000, 3: 1000, 5: 1000}, 49: {4: 1000, 8: 1000, 9: 1000}, 50: {1: 1000, 4: 1000, 6: 1000}, 51: {0: 1000, 2: 1000, 6: 1000}, 52: {4: 1000, 8: 1000, 9: 1000}, 53: {1: 1000, 3: 1000, 5: 1000}, 54: {1: 1000, 4: 1000, 6: 1000}, 55: {1: 1000, 3: 1000, 5: 1000}, 56: {0: 1000, 2: 1000, 6: 1000}, 57: {1: 1000, 4: 1000, 6: 1000}, 58: {1: 1000, 3: 1000, 5: 1000}, 59: {0: 1000, 2: 1000, 6: 1000}, 60: {4: 1000, 8: 
1000, 9: 1000}, 61: {3: 1000, 5: 1000, 7: 1000}, 62: {3: 1000, 5: 1000, 7: 1000}, 63: {3: 1000, 5: 1000, 7: 1000}, 64: {1: 1000, 3: 1000, 5: 1000}, 65: {0: 1000, 2: 1000, 6: 1000}, 66: {1: 1000, 3: 1000, 5: 1000}, 67: {1: 1000, 4: 1000, 6: 1000}, 68: {4: 1000, 8: 1000, 9: 1000}, 69: {1: 1000, 4: 1000, 6: 1000}, 70: {1: 1000, 3: 1000, 5: 1000}, 71: {3: 1000, 5: 1000, 7: 1000}, 72: {1: 1000, 3: 1000, 5: 1000}, 73: {1: 1000, 4: 1000, 6: 1000}, 74: {1: 1000, 4: 1000, 6: 1000}, 75: {1: 1000, 4: 1000, 6: 1000}, 76: {1: 1000, 4: 1000, 6: 1000}, 77: {4: 1000, 8: 1000, 9: 1000}, 78: {1: 1000, 4: 1000, 6: 1000}, 79: {3: 1000, 5: 1000, 7: 1000}, 80: {1: 1000, 4: 1000, 6: 1000}, 81: {0: 1000, 2: 1000, 6: 1000}, 82: {1: 1000, 3: 1000, 5: 1000}, 83: {3: 1000, 5: 1000, 7: 1000}, 84: {1: 1000, 4: 1000, 6: 1000}, 85: {4: 1000, 8: 1000, 9: 1000}, 86: {1: 1000, 4: 1000, 6: 1000}, 87: {0: 1000, 2: 1000, 6: 1000}, 88: {1: 1000, 4: 1000, 6: 1000}, 89: {4: 1000, 8: 1000, 9: 1000}, 90: {4: 1000, 8: 1000, 9: 1000}, 91: {1: 1000, 4: 1000, 6: 1000}, 92: {3: 1000, 5: 1000, 7: 1000}, 93: {0: 1000, 2: 1000, 6: 1000}, 94: {1: 1000, 4: 1000, 6: 1000}, 95: {1: 1000, 4: 1000, 6: 1000}, 96: {0: 1000, 2: 1000, 6: 1000}, 97: {1: 1000, 3: 1000, 5: 1000}, 98: {4: 1000, 8: 1000, 9: 1000}, 99: {0: 1000, 2: 1000, 6: 1000}} \n",
      "\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "len train_ds_global: 50000\n",
      "len test_ds_global: 10000\n",
      "MODEL: simple-cnn, Dataset: cifar10\n",
      "SimpleCNN(\n",
      "  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (drop1): Dropout2d(p=0.5, inplace=False)\n",
      "  (fc1): Linear(in_features=1024, out_features=128, bias=True)\n",
      "  (drop2): Dropout2d(p=0.5, inplace=False)\n",
      "  (fc2): Linear(in_features=128, out_features=256, bias=True)\n",
      "  (drop3): Dropout2d(p=0.5, inplace=False)\n",
      "  (fc3): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n",
      "conv1.weight torch.Size([64, 3, 3, 3])\n",
      "conv1.bias torch.Size([64])\n",
      "conv2.weight torch.Size([128, 64, 3, 3])\n",
      "conv2.bias torch.Size([128])\n",
      "conv3.weight torch.Size([256, 128, 3, 3])\n",
      "conv3.bias torch.Size([256])\n",
      "fc1.weight torch.Size([128, 1024])\n",
      "fc1.bias torch.Size([128])\n",
      "fc2.weight torch.Size([256, 128])\n",
      "fc2.bias torch.Size([256])\n",
      "fc3.weight torch.Size([10, 256])\n",
      "fc3.bias torch.Size([10])\n",
      "537610\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "###### ROUND 1 ######\n",
      "Clients [87 93 52 85 20 65 92 80 48 29 62 45 12 58 99 78 94 46 22  6]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/anik/anaconda3/envs/PACFL/lib/python3.11/site-packages/torch/nn/functional.py:1374: UserWarning: dropout2d: Received a 2-D input to dropout2d, which is deprecated and will result in an error in a future release. To retain the behavior and silence this warning, please use dropout instead. Note that dropout2d exists to provide channel-wise dropout on inputs with 2 spatial dimensions, a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).\n",
      "  warnings.warn(warn_msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "## END OF ROUND ##\n",
      "Average Train loss 0.896\n",
      "Global Model Test Acc: 18.460, Global Model Best Test Acc: 18.460\n",
      "--- PRINTING ALL CLIENTS STATUS ---\n",
      "Client   0, labels {1: 81, 3: 310, 5: 19}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client   1, labels {1: 163, 3: 217, 5: 229}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client   2, labels {3: 9, 5: 46, 7: 127}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client   3, labels {0: 452, 2: 104, 6: 123}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client   4, labels {4: 208, 8: 319, 9: 365}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client   5, labels {3: 85, 5: 110, 7: 309}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client   6, labels {0: 323, 2: 132, 6: 65}, count 0, best_acc 60.833, current_acc 60.833 \n",
      "\n",
      "Client   7, labels {1: 443, 4: 97, 6: 173}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client   8, labels {0: 342, 2: 90, 6: 158}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client   9, labels {1: 3, 4: 83, 6: 98}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  10, labels {0: 328, 2: 347, 6: 235}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client  11, labels {3: 198, 5: 375, 7: 150}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  12, labels {4: 151, 8: 232, 9: 34}, count 0, best_acc 55.700, current_acc 55.700 \n",
      "\n",
      "Client  13, labels {0: 85, 2: 240, 6: 186}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client  14, labels {0: 385, 2: 539, 6: 43}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client  15, labels {4: 105, 8: 418, 9: 630}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  16, labels {3: 91, 5: 90, 7: 470}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  17, labels {3: 439, 5: 101, 7: 171}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  18, labels {1: 28, 3: 28, 5: 42}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  19, labels {1: 39, 4: 234, 6: 51}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  20, labels {3: 46, 5: 330, 7: 183}, count 0, best_acc 41.567, current_acc 41.567 \n",
      "\n",
      "Client  21, labels {1: 42, 4: 221, 6: 5}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  22, labels {1: 4, 4: 77, 6: 208}, count 0, best_acc 33.333, current_acc 33.333 \n",
      "\n",
      "Client  23, labels {3: 6, 5: 26, 7: 706}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  24, labels {0: 440, 2: 3, 6: 133}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client  25, labels {1: 59, 3: 11, 5: 79}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  26, labels {1: 66, 3: 469, 5: 116}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  27, labels {4: 46, 8: 208, 9: 258}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  28, labels {3: 93, 5: 74, 7: 214}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  29, labels {4: 31, 8: 91, 9: 292}, count 0, best_acc 46.767, current_acc 46.767 \n",
      "\n",
      "Client  30, labels {3: 131, 5: 103, 7: 107}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  31, labels {4: 164, 8: 611, 9: 182}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  32, labels {3: 100, 5: 107, 7: 27}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  33, labels {3: 158, 5: 182, 7: 394}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  34, labels {1: 29, 3: 228, 5: 133}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  35, labels {4: 63, 8: 31, 9: 370}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  36, labels {3: 41, 5: 169, 7: 69}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  37, labels {1: 125, 4: 224, 6: 267}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  38, labels {3: 205, 5: 239, 7: 225}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  39, labels {1: 285, 4: 39, 6: 7}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  40, labels {1: 308, 3: 142, 5: 103}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  41, labels {3: 51, 5: 73, 7: 2}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  42, labels {3: 184, 5: 325, 7: 134}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  43, labels {0: 235, 2: 18, 6: 501}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client  44, labels {4: 215, 8: 6, 9: 552}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  45, labels {0: 610, 2: 5, 6: 40}, count 0, best_acc 33.333, current_acc 33.333 \n",
      "\n",
      "Client  46, labels {0: 30, 2: 493, 6: 104}, count 0, best_acc 33.333, current_acc 33.333 \n",
      "\n",
      "Client  47, labels {3: 62, 5: 112, 7: 37}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  48, labels {1: 3, 3: 2, 5: 69}, count 0, best_acc 33.333, current_acc 33.333 \n",
      "\n",
      "Client  49, labels {4: 356, 8: 467, 9: 22}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  50, labels {1: 289, 4: 23, 6: 18}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  51, labels {0: 139, 2: 96, 6: 61}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client  52, labels {4: 35, 8: 47, 9: 122}, count 0, best_acc 33.333, current_acc 33.333 \n",
      "\n",
      "Client  53, labels {1: 139, 3: 45, 5: 3}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  54, labels {1: 120, 4: 29, 6: 107}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  55, labels {1: 182, 3: 100, 5: 335}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  56, labels {0: 279, 2: 972, 6: 260}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client  57, labels {1: 149, 4: 404, 6: 98}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  58, labels {1: 213, 3: 86, 5: 8}, count 0, best_acc 34.167, current_acc 34.167 \n",
      "\n",
      "Client  59, labels {0: 158, 2: 733, 6: 16}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client  60, labels {4: 9, 8: 144, 9: 155}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  61, labels {3: 53, 5: 159, 7: 279}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  62, labels {3: 233, 5: 351, 7: 153}, count 0, best_acc 33.333, current_acc 33.333 \n",
      "\n",
      "Client  63, labels {3: 20, 5: 31, 7: 9}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  64, labels {1: 5, 3: 127, 5: 16}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  65, labels {0: 28, 2: 38, 6: 463}, count 0, best_acc 34.333, current_acc 34.333 \n",
      "\n",
      "Client  66, labels {1: 163, 3: 162, 5: 8}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  67, labels {1: 96, 4: 49, 6: 73}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  68, labels {4: 152, 8: 104, 9: 8}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  69, labels {1: 63, 4: 231, 6: 55}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  70, labels {1: 1, 3: 42, 5: 191}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  71, labels {3: 69, 5: 4, 7: 77}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  72, labels {1: 22, 3: 64, 5: 37}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  73, labels {1: 229, 4: 202, 6: 271}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  74, labels {1: 351, 4: 32, 6: 31}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  75, labels {1: 18, 4: 56, 6: 163}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  76, labels {1: 115, 4: 159, 6: 128}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  77, labels {4: 82, 8: 459, 9: 591}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  78, labels {1: 83, 4: 26, 6: 15}, count 0, best_acc 33.333, current_acc 33.333 \n",
      "\n",
      "Client  79, labels {3: 54, 5: 147, 7: 846}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  80, labels {1: 38, 4: 48, 6: 4}, count 0, best_acc 46.100, current_acc 46.100 \n",
      "\n",
      "Client  81, labels {0: 308, 2: 71, 6: 151}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client  82, labels {1: 67, 3: 148, 5: 146}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  83, labels {3: 367, 5: 10, 7: 82}, count 0, best_acc 0.100, current_acc 0.100 \n",
      "\n",
      "Client  84, labels {1: 215, 4: 102, 6: 50}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  85, labels {4: 30, 8: 18, 9: 52}, count 0, best_acc 33.333, current_acc 33.333 \n",
      "\n",
      "Client  86, labels {1: 1, 4: 81, 6: 34}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  87, labels {0: 616, 2: 464, 6: 48}, count 0, best_acc 56.133, current_acc 56.133 \n",
      "\n",
      "Client  88, labels {1: 34, 4: 34, 6: 83}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  89, labels {4: 2, 8: 131, 9: 470}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  90, labels {4: 230, 8: 658, 9: 613}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  91, labels {1: 11, 4: 200, 6: 13}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  92, labels {3: 7, 5: 300, 7: 229}, count 0, best_acc 42.367, current_acc 42.367 \n",
      "\n",
      "Client  93, labels {0: 54, 2: 69, 6: 59}, count 0, best_acc 53.733, current_acc 53.733 \n",
      "\n",
      "Client  94, labels {1: 664, 4: 34, 6: 56}, count 0, best_acc 34.133, current_acc 34.133 \n",
      "\n",
      "Client  95, labels {1: 17, 4: 234, 6: 132}, count 0, best_acc 4.333, current_acc 4.333 \n",
      "\n",
      "Client  96, labels {0: 12, 2: 178, 6: 63}, count 0, best_acc 26.967, current_acc 26.967 \n",
      "\n",
      "Client  97, labels {1: 37, 3: 117, 5: 2}, count 0, best_acc 4.433, current_acc 4.433 \n",
      "\n",
      "Client  98, labels {4: 202, 8: 1056, 9: 284}, count 0, best_acc 0.467, current_acc 0.467 \n",
      "\n",
      "Client  99, labels {0: 176, 2: 408, 6: 151}, count 0, best_acc 53.867, current_acc 53.867 \n",
      "\n",
      "Round 1, Avg current_acc 13.112, Avg best_acc 13.112\n",
      "----- Analysis End of Round -------\n",
      "Client 87, Count: 0, Labels: {0: 616, 2: 464, 6: 48}\n",
      "Client 93, Count: 0, Labels: {0: 54, 2: 69, 6: 59}\n",
      "Client 52, Count: 0, Labels: {4: 35, 8: 47, 9: 122}\n",
      "Client 85, Count: 0, Labels: {4: 30, 8: 18, 9: 52}\n",
      "Client 20, Count: 0, Labels: {3: 46, 5: 330, 7: 183}\n",
      "Client 65, Count: 0, Labels: {0: 28, 2: 38, 6: 463}\n",
      "Client 92, Count: 0, Labels: {3: 7, 5: 300, 7: 229}\n",
      "Client 80, Count: 0, Labels: {1: 38, 4: 48, 6: 4}\n",
      "Client 48, Count: 0, Labels: {1: 3, 3: 2, 5: 69}\n",
      "Client 29, Count: 0, Labels: {4: 31, 8: 91, 9: 292}\n",
      "Client 62, Count: 0, Labels: {3: 233, 5: 351, 7: 153}\n",
      "Client 45, Count: 0, Labels: {0: 610, 2: 5, 6: 40}\n",
      "Client 12, Count: 0, Labels: {4: 151, 8: 232, 9: 34}\n",
      "Client 58, Count: 0, Labels: {1: 213, 3: 86, 5: 8}\n",
      "Client 99, Count: 0, Labels: {0: 176, 2: 408, 6: 151}\n",
      "Client 78, Count: 0, Labels: {1: 83, 4: 26, 6: 15}\n",
      "Client 94, Count: 0, Labels: {1: 664, 4: 34, 6: 56}\n",
      "Client 46, Count: 0, Labels: {0: 30, 2: 493, 6: 104}\n",
      "Client 22, Count: 0, Labels: {1: 4, 4: 77, 6: 208}\n",
      "Client 6, Count: 0, Labels: {0: 323, 2: 132, 6: 65}\n",
      "###### ROUND 2 ######\n",
      "Clients [33  5 50 35 27 10 82 92 99 19 62 46 48 17 40 31 70 96 76 75]\n",
      "## END OF ROUND ##\n",
      "Average Train loss 0.837\n",
      "Global Model Test Acc: 16.130, Global Model Best Test Acc: 18.460\n",
      "----- Analysis End of Round -------\n",
      "Client 33, Count: 0, Labels: {3: 158, 5: 182, 7: 394}\n",
      "Client 5, Count: 0, Labels: {3: 85, 5: 110, 7: 309}\n",
      "Client 50, Count: 0, Labels: {1: 289, 4: 23, 6: 18}\n",
      "Client 35, Count: 0, Labels: {4: 63, 8: 31, 9: 370}\n",
      "Client 27, Count: 0, Labels: {4: 46, 8: 208, 9: 258}\n",
      "Client 10, Count: 0, Labels: {0: 328, 2: 347, 6: 235}\n",
      "Client 82, Count: 0, Labels: {1: 67, 3: 148, 5: 146}\n",
      "Client 92, Count: 0, Labels: {3: 7, 5: 300, 7: 229}\n",
      "Client 99, Count: 0, Labels: {0: 176, 2: 408, 6: 151}\n",
      "Client 19, Count: 0, Labels: {1: 39, 4: 234, 6: 51}\n",
      "Client 62, Count: 0, Labels: {3: 233, 5: 351, 7: 153}\n",
      "Client 46, Count: 0, Labels: {0: 30, 2: 493, 6: 104}\n",
      "Client 48, Count: 0, Labels: {1: 3, 3: 2, 5: 69}\n",
      "Client 17, Count: 0, Labels: {3: 439, 5: 101, 7: 171}\n",
      "Client 40, Count: 0, Labels: {1: 308, 3: 142, 5: 103}\n",
      "Client 31, Count: 0, Labels: {4: 164, 8: 611, 9: 182}\n",
      "Client 70, Count: 0, Labels: {1: 1, 3: 42, 5: 191}\n",
      "Client 96, Count: 0, Labels: {0: 12, 2: 178, 6: 63}\n",
      "Client 76, Count: 0, Labels: {1: 115, 4: 159, 6: 128}\n",
      "Client 75, Count: 0, Labels: {1: 18, 4: 56, 6: 163}\n",
      "###### ROUND 3 ######\n",
      "Clients [80 66 34 30 69  6 63 68 19 59 67 12 26  9 27 41 13  3 70 50]\n",
      "## END OF ROUND ##\n",
      "Average Train loss 0.764\n",
      "Global Model Test Acc: 18.890, Global Model Best Test Acc: 18.890\n",
      "----- Analysis End of Round -------\n",
      "Client 80, Count: 0, Labels: {1: 38, 4: 48, 6: 4}\n",
      "Client 66, Count: 0, Labels: {1: 163, 3: 162, 5: 8}\n",
      "Client 34, Count: 0, Labels: {1: 29, 3: 228, 5: 133}\n",
      "Client 30, Count: 0, Labels: {3: 131, 5: 103, 7: 107}\n",
      "Client 69, Count: 0, Labels: {1: 63, 4: 231, 6: 55}\n",
      "Client 6, Count: 0, Labels: {0: 323, 2: 132, 6: 65}\n",
      "Client 63, Count: 0, Labels: {3: 20, 5: 31, 7: 9}\n",
      "Client 68, Count: 0, Labels: {4: 152, 8: 104, 9: 8}\n",
      "Client 19, Count: 0, Labels: {1: 39, 4: 234, 6: 51}\n",
      "Client 59, Count: 0, Labels: {0: 158, 2: 733, 6: 16}\n",
      "Client 67, Count: 0, Labels: {1: 96, 4: 49, 6: 73}\n",
      "Client 12, Count: 0, Labels: {4: 151, 8: 232, 9: 34}\n",
      "Client 26, Count: 0, Labels: {1: 66, 3: 469, 5: 116}\n",
      "Client 9, Count: 0, Labels: {1: 3, 4: 83, 6: 98}\n",
      "Client 27, Count: 0, Labels: {4: 46, 8: 208, 9: 258}\n",
      "Client 41, Count: 0, Labels: {3: 51, 5: 73, 7: 2}\n",
      "Client 13, Count: 0, Labels: {0: 85, 2: 240, 6: 186}\n",
      "Client 3, Count: 0, Labels: {0: 452, 2: 104, 6: 123}\n",
      "Client 70, Count: 0, Labels: {1: 1, 3: 42, 5: 191}\n",
      "Client 50, Count: 0, Labels: {1: 289, 4: 23, 6: 18}\n",
      "###### ROUND 4 ######\n",
      "Clients [17 61 24 66 73 79 69 34  6 62 76 16 10 30  3 48 55 27 35 25]\n",
      "## END OF ROUND ##\n",
      "Average Train loss 0.688\n",
      "Global Model Test Acc: 15.390, Global Model Best Test Acc: 18.890\n",
      "----- Analysis End of Round -------\n",
      "Client 17, Count: 0, Labels: {3: 439, 5: 101, 7: 171}\n",
      "Client 61, Count: 0, Labels: {3: 53, 5: 159, 7: 279}\n",
      "Client 24, Count: 0, Labels: {0: 440, 2: 3, 6: 133}\n",
      "Client 66, Count: 0, Labels: {1: 163, 3: 162, 5: 8}\n",
      "Client 73, Count: 0, Labels: {1: 229, 4: 202, 6: 271}\n",
      "Client 79, Count: 0, Labels: {3: 54, 5: 147, 7: 846}\n",
      "Client 69, Count: 0, Labels: {1: 63, 4: 231, 6: 55}\n",
      "Client 34, Count: 0, Labels: {1: 29, 3: 228, 5: 133}\n",
      "Client 6, Count: 0, Labels: {0: 323, 2: 132, 6: 65}\n",
      "Client 62, Count: 0, Labels: {3: 233, 5: 351, 7: 153}\n",
      "Client 76, Count: 0, Labels: {1: 115, 4: 159, 6: 128}\n",
      "Client 16, Count: 0, Labels: {3: 91, 5: 90, 7: 470}\n",
      "Client 10, Count: 0, Labels: {0: 328, 2: 347, 6: 235}\n",
      "Client 30, Count: 0, Labels: {3: 131, 5: 103, 7: 107}\n",
      "Client 3, Count: 0, Labels: {0: 452, 2: 104, 6: 123}\n",
      "Client 48, Count: 0, Labels: {1: 3, 3: 2, 5: 69}\n",
      "Client 55, Count: 0, Labels: {1: 182, 3: 100, 5: 335}\n",
      "Client 27, Count: 0, Labels: {4: 46, 8: 208, 9: 258}\n",
      "Client 35, Count: 0, Labels: {4: 63, 8: 31, 9: 370}\n",
      "Client 25, Count: 0, Labels: {1: 59, 3: 11, 5: 79}\n",
      "###### ROUND 5 ######\n",
      "Clients [65 53 42 62 25 90  0 10 50 87 46 74  5 56 14  3 66 28 67 81]\n",
      "## END OF ROUND ##\n",
      "Average Train loss 0.557\n",
      "Global Model Test Acc: 20.120, Global Model Best Test Acc: 20.120\n",
      "----- Analysis End of Round -------\n",
      "Client 65, Count: 0, Labels: {0: 28, 2: 38, 6: 463}\n",
      "Client 53, Count: 0, Labels: {1: 139, 3: 45, 5: 3}\n",
      "Client 42, Count: 0, Labels: {3: 184, 5: 325, 7: 134}\n",
      "Client 62, Count: 0, Labels: {3: 233, 5: 351, 7: 153}\n",
      "Client 25, Count: 0, Labels: {1: 59, 3: 11, 5: 79}\n",
      "Client 90, Count: 0, Labels: {4: 230, 8: 658, 9: 613}\n",
      "Client 0, Count: 0, Labels: {1: 81, 3: 310, 5: 19}\n",
      "Client 10, Count: 0, Labels: {0: 328, 2: 347, 6: 235}\n",
      "Client 50, Count: 0, Labels: {1: 289, 4: 23, 6: 18}\n",
      "Client 87, Count: 0, Labels: {0: 616, 2: 464, 6: 48}\n",
      "Client 46, Count: 0, Labels: {0: 30, 2: 493, 6: 104}\n",
      "Client 74, Count: 0, Labels: {1: 351, 4: 32, 6: 31}\n",
      "Client 5, Count: 0, Labels: {3: 85, 5: 110, 7: 309}\n",
      "Client 56, Count: 0, Labels: {0: 279, 2: 972, 6: 260}\n",
      "Client 14, Count: 0, Labels: {0: 385, 2: 539, 6: 43}\n",
      "Client 3, Count: 0, Labels: {0: 452, 2: 104, 6: 123}\n",
      "Client 66, Count: 0, Labels: {1: 163, 3: 162, 5: 8}\n",
      "Client 28, Count: 0, Labels: {3: 93, 5: 74, 7: 214}\n",
      "Client 67, Count: 0, Labels: {1: 96, 4: 49, 6: 73}\n",
      "Client 81, Count: 0, Labels: {0: 308, 2: 71, 6: 151}\n",
      "###### ROUND 6 ######\n",
      "Clients [51 35 50 48 58  8 64 55 96 24 29 71 75 22 74 60 69 94 19 15]\n",
      "## END OF ROUND ##\n",
      "Average Train loss 0.504\n",
      "Global Model Test Acc: 19.730, Global Model Best Test Acc: 20.120\n",
      "----- Analysis End of Round -------\n",
      "Client 51, Count: 0, Labels: {0: 139, 2: 96, 6: 61}\n",
      "Client 35, Count: 0, Labels: {4: 63, 8: 31, 9: 370}\n",
      "Client 50, Count: 0, Labels: {1: 289, 4: 23, 6: 18}\n",
      "Client 48, Count: 0, Labels: {1: 3, 3: 2, 5: 69}\n",
      "Client 58, Count: 0, Labels: {1: 213, 3: 86, 5: 8}\n",
      "Client 8, Count: 0, Labels: {0: 342, 2: 90, 6: 158}\n",
      "Client 64, Count: 0, Labels: {1: 5, 3: 127, 5: 16}\n",
      "Client 55, Count: 0, Labels: {1: 182, 3: 100, 5: 335}\n",
      "Client 96, Count: 0, Labels: {0: 12, 2: 178, 6: 63}\n",
      "Client 24, Count: 0, Labels: {0: 440, 2: 3, 6: 133}\n",
      "Client 29, Count: 0, Labels: {4: 31, 8: 91, 9: 292}\n",
      "Client 71, Count: 0, Labels: {3: 69, 5: 4, 7: 77}\n",
      "Client 75, Count: 0, Labels: {1: 18, 4: 56, 6: 163}\n",
      "Client 22, Count: 0, Labels: {1: 4, 4: 77, 6: 208}\n",
      "Client 74, Count: 0, Labels: {1: 351, 4: 32, 6: 31}\n",
      "Client 60, Count: 0, Labels: {4: 9, 8: 144, 9: 155}\n",
      "Client 69, Count: 0, Labels: {1: 63, 4: 231, 6: 55}\n",
      "Client 94, Count: 0, Labels: {1: 664, 4: 34, 6: 56}\n",
      "Client 19, Count: 0, Labels: {1: 39, 4: 234, 6: 51}\n",
      "Client 15, Count: 0, Labels: {4: 105, 8: 418, 9: 630}\n",
      "###### ROUND 7 ######\n",
      "Clients [99 35 29 58 95 11 26  9 14 77 79 59  3 86 84 93 34  0 10 22]\n",
      "## END OF ROUND ##\n",
      "Average Train loss 0.535\n",
      "Global Model Test Acc: 22.550, Global Model Best Test Acc: 22.550\n",
      "----- Analysis End of Round -------\n",
      "Client 99, Count: 0, Labels: {0: 176, 2: 408, 6: 151}\n",
      "Client 35, Count: 0, Labels: {4: 63, 8: 31, 9: 370}\n",
      "Client 29, Count: 0, Labels: {4: 31, 8: 91, 9: 292}\n",
      "Client 58, Count: 0, Labels: {1: 213, 3: 86, 5: 8}\n",
      "Client 95, Count: 0, Labels: {1: 17, 4: 234, 6: 132}\n",
      "Client 11, Count: 0, Labels: {3: 198, 5: 375, 7: 150}\n",
      "Client 26, Count: 0, Labels: {1: 66, 3: 469, 5: 116}\n",
      "Client 9, Count: 0, Labels: {1: 3, 4: 83, 6: 98}\n",
      "Client 14, Count: 0, Labels: {0: 385, 2: 539, 6: 43}\n",
      "Client 77, Count: 0, Labels: {4: 82, 8: 459, 9: 591}\n",
      "Client 79, Count: 0, Labels: {3: 54, 5: 147, 7: 846}\n",
      "Client 59, Count: 0, Labels: {0: 158, 2: 733, 6: 16}\n",
      "Client 3, Count: 0, Labels: {0: 452, 2: 104, 6: 123}\n",
      "Client 86, Count: 0, Labels: {1: 1, 4: 81, 6: 34}\n",
      "Client 84, Count: 0, Labels: {1: 215, 4: 102, 6: 50}\n",
      "Client 93, Count: 0, Labels: {0: 54, 2: 69, 6: 59}\n",
      "Client 34, Count: 0, Labels: {1: 29, 3: 228, 5: 133}\n",
      "Client 0, Count: 0, Labels: {1: 81, 3: 310, 5: 19}\n",
      "Client 10, Count: 0, Labels: {0: 328, 2: 347, 6: 235}\n",
      "Client 22, Count: 0, Labels: {1: 4, 4: 77, 6: 208}\n",
      "###### ROUND 8 ######\n",
      "Clients [71 95 11 26 48 32 36 75 76 49 39 79 94 23 10 92 97 81 16 22]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[2], line 226\u001b[0m\n\u001b[1;32m    219\u001b[0m         clients[idx]\u001b[38;5;241m.\u001b[39mset_state_dict(copy\u001b[38;5;241m.\u001b[39mdeepcopy(w_glob)) \n\u001b[1;32m    221\u001b[0m \u001b[38;5;66;03m#         loss, acc = clients[idx].eval_test()        \u001b[39;00m\n\u001b[1;32m    222\u001b[0m             \n\u001b[1;32m    223\u001b[0m \u001b[38;5;66;03m#         init_local_tacc.append(acc)\u001b[39;00m\n\u001b[1;32m    224\u001b[0m \u001b[38;5;66;03m#         init_local_tloss.append(loss)\u001b[39;00m\n\u001b[0;32m--> 226\u001b[0m         loss \u001b[38;5;241m=\u001b[39m \u001b[43mclients\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mis_print\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m    228\u001b[0m         loss_locals\u001b[38;5;241m.\u001b[39mappend(copy\u001b[38;5;241m.\u001b[39mdeepcopy(loss))\n\u001b[1;32m    230\u001b[0m \u001b[38;5;66;03m#         loss, acc = clients[idx].eval_test()\u001b[39;00m\n\u001b[1;32m    231\u001b[0m         \n\u001b[1;32m    232\u001b[0m \u001b[38;5;66;03m#         if acc > clients_best_acc[idx]:\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    235\u001b[0m \u001b[38;5;66;03m#         final_local_tacc.append(acc)\u001b[39;00m\n\u001b[1;32m    236\u001b[0m \u001b[38;5;66;03m#         final_local_tloss.append(loss)           \u001b[39;00m\n",
      "File \u001b[0;32m~/Misc/PACFLComboNB/Flag/src/client/client_fedavg.py:35\u001b[0m, in \u001b[0;36mClient_FedAvg.train\u001b[0;34m(self, is_print)\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m iteration \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlocal_ep):\n\u001b[1;32m     34\u001b[0m     batch_loss \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m---> 35\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mimages\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mldr_train\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m     36\u001b[0m \u001b[43m        \u001b[49m\u001b[43mimages\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mimages\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     37\u001b[0m \u001b[43m        
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnet\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mzero_grad\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/PACFL/lib/python3.11/site-packages/torch/utils/data/dataloader.py:631\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    628\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    629\u001b[0m     \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m    630\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()  \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 631\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m    633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    634\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    635\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[0;32m~/anaconda3/envs/PACFL/lib/python3.11/site-packages/torch/utils/data/dataloader.py:675\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    673\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    674\u001b[0m     index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index()  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 675\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m    676\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m    677\u001b[0m         data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n",
      "File \u001b[0;32m~/anaconda3/envs/PACFL/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     49\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m     50\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpossibly_batched_index\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m     52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     53\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[0;32m~/anaconda3/envs/PACFL/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     49\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m     50\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m     52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     53\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[0;32m~/Misc/PACFLComboNB/Flag/src/utils/datasets.py:629\u001b[0m, in \u001b[0;36mCIFAR10_truncated.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m    625\u001b[0m \u001b[38;5;66;03m# print(\"cifar10 img:\", img)\u001b[39;00m\n\u001b[1;32m    626\u001b[0m \u001b[38;5;66;03m# print(\"cifar10 target:\", target)\u001b[39;00m\n\u001b[1;32m    628\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 629\u001b[0m     img \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    631\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    632\u001b[0m     target \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform(target)\n",
      "File \u001b[0;32m~/anaconda3/envs/PACFL/lib/python3.11/site-packages/torchvision/transforms/transforms.py:95\u001b[0m, in \u001b[0;36mCompose.__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m     93\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, img):\n\u001b[1;32m     94\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransforms:\n\u001b[0;32m---> 95\u001b[0m         img \u001b[38;5;241m=\u001b[39m \u001b[43mt\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     96\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m img\n",
      "File \u001b[0;32m~/Misc/PACFLComboNB/Flag/src/utils/utils.py:832\u001b[0m, in \u001b[0;36m__call__\u001b[0;34m(self, tensor)\u001b[0m\n\u001b[1;32m    828\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mNormalNLLLoss\u001b[39;00m:\n\u001b[1;32m    829\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    830\u001b[0m \u001b[38;5;124;03m    Calculate the negative log likelihood\u001b[39;00m\n\u001b[1;32m    831\u001b[0m \u001b[38;5;124;03m    of normal distribution.\u001b[39;00m\n\u001b[0;32m--> 832\u001b[0m \u001b[38;5;124;03m    This needs to be minimised.\u001b[39;00m\n\u001b[1;32m    833\u001b[0m \n\u001b[1;32m    834\u001b[0m \u001b[38;5;124;03m    Treating Q(cj | x) as a factored Gaussian.\u001b[39;00m\n\u001b[1;32m    835\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m    836\u001b[0m     \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, mu, var):\n\u001b[1;32m    838\u001b[0m         logli \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m0.5\u001b[39m \u001b[38;5;241m*\u001b[39m (var\u001b[38;5;241m.\u001b[39mmul(\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m np\u001b[38;5;241m.\u001b[39mpi) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1e-6\u001b[39m)\u001b[38;5;241m.\u001b[39mlog() \u001b[38;5;241m-\u001b[39m (x \u001b[38;5;241m-\u001b[39m mu)\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39mdiv(var\u001b[38;5;241m.\u001b[39mmul(\u001b[38;5;241m2.0\u001b[39m) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1e-6\u001b[39m)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "##################################### Data partitioning section \n",
    "# NOTE: local_view is forced on here, overriding whatever the CLI passed\n",
    "# (presumably it controls whether per-client test indices are produced - confirm in src.data).\n",
    "args.local_view = True\n",
    "(X_train, y_train, X_test, y_test,\n",
    " net_dataidx_map, net_dataidx_map_test,\n",
    " traindata_cls_counts, testdata_cls_counts) = partition_data(\n",
    "    args.dataset, args.datadir, args.logdir, args.partition, args.num_users,\n",
    "    beta=args.beta, local_view=args.local_view)\n",
    "\n",
    "# No index subset is passed, so these loaders presumably cover the whole dataset.\n",
    "# The trailing 32 is get_dataloader's 4th positional argument (presumably the test batch size - TODO confirm).\n",
    "(train_dl_global, test_dl_global,\n",
    " train_ds_global, test_ds_global) = get_dataloader(args.dataset, args.datadir, args.batch_size, 32)\n",
    "\n",
    "print(\"len train_ds_global:\", len(train_ds_global))\n",
    "print(\"len test_ds_global:\", len(test_ds_global))\n",
    "\n",
    "################################### build model\n",
    "def _build_net(args, dropout_p=0.5):\n",
    "    \"\"\"Construct one freshly initialised model for ``args.model`` / ``args.dataset``.\n",
    "\n",
    "    The returned model has already been moved to ``args.device``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the model/dataset combination is not supported.\n",
    "    \"\"\"\n",
    "    device = args.device\n",
    "    if args.dataset == \"generated\":\n",
    "        return PerceptronModel().to(device)\n",
    "\n",
    "    if args.model == \"mlp\":\n",
    "        # dataset -> (input_size, hidden_sizes, output_size)\n",
    "        mlp_specs = {\n",
    "            'covtype': (54, [32, 16, 8], 2),\n",
    "            'a9a': (123, [32, 16, 8], 2),\n",
    "            'rcv1': (47236, [32, 16, 8], 2),\n",
    "            'SUSY': (18, [16, 8], 2),\n",
    "        }\n",
    "        if args.dataset in mlp_specs:\n",
    "            input_size, hidden_sizes, output_size = mlp_specs[args.dataset]\n",
    "            return FcNet(input_size, hidden_sizes, output_size, dropout_p).to(device)\n",
    "    elif args.model == \"vgg\":\n",
    "        return vgg11().to(device)\n",
    "    elif args.model == \"simple-cnn\":\n",
    "        if args.dataset in (\"cifar10\", \"cinic10\", \"svhn\"):\n",
    "            return SimpleCNN(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=10).to(device)\n",
    "        if args.dataset in (\"mnist\", 'femnist', 'fmnist'):\n",
    "            return SimpleCNNMNIST(input_dim=(16 * 4 * 4), hidden_dims=[120, 84], output_dim=10).to(device)\n",
    "        if args.dataset == 'celeba':\n",
    "            return SimpleCNN(input_dim=(16 * 5 * 5), hidden_dims=[120, 84], output_dim=2).to(device)\n",
    "    elif args.model == \"simple-cnn-3\":\n",
    "        if args.dataset == 'cifar100':\n",
    "            return SimpleCNN_3(input_dim=(16 * 3 * 5 * 5), hidden_dims=[120*3, 84*3], output_dim=100).to(device)\n",
    "        if args.dataset == 'tinyimagenet':\n",
    "            return SimpleCNNTinyImagenet_3(input_dim=(16 * 3 * 13 * 13), hidden_dims=[120*3, 84*3],\n",
    "                                           output_dim=200).to(device)\n",
    "    elif args.model == \"vgg-9\":\n",
    "        if args.dataset in (\"mnist\", 'femnist'):\n",
    "            return ModerateCNNMNIST().to(device)\n",
    "        if args.dataset in (\"cifar10\", \"cinic10\", \"svhn\"):\n",
    "            return ModerateCNN().to(device)\n",
    "        if args.dataset == 'celeba':\n",
    "            return ModerateCNN(output_dim=2).to(device)\n",
    "    elif args.model == 'resnet9':\n",
    "        # Moved to args.device like every other branch, so training does not mix devices.\n",
    "        if args.dataset == 'cifar100':\n",
    "            return ResNet9(in_channels=3, num_classes=100).to(device)\n",
    "        if args.dataset == 'tinyimagenet':\n",
    "            return ResNet9(in_channels=3, num_classes=200, dim=512*2*2).to(device)\n",
    "    elif args.model == \"resnet\":\n",
    "        return ResNet50_cifar10().to(device)\n",
    "    elif args.model == \"vgg16\":\n",
    "        return vgg16().to(device)\n",
    "\n",
    "    # Raise rather than exit(1): exit() would terminate the interpreter/kernel instead of\n",
    "    # surfacing a catchable error, and unsupported datasets previously left `net` unbound.\n",
    "    raise ValueError(f\"not supported yet: model={args.model!r}, dataset={args.dataset!r}\")\n",
    "\n",
    "\n",
    "def init_nets(args, dropout_p=0.5):\n",
    "    \"\"\"Build the server model and one model per client, all starting from the same weights.\n",
    "\n",
    "    Args:\n",
    "        args: parsed experiment arguments; reads model, dataset, num_users, device\n",
    "            and load_initial.\n",
    "        dropout_p: dropout probability, forwarded only to the MLP (FcNet) models.\n",
    "\n",
    "    Returns:\n",
    "        (users_model, net_glob, initial_state_dict, server_state_dict) where users_model\n",
    "        is a list of args.num_users client models, net_glob is the server model, and the\n",
    "        two state dicts are independent copies of the shared initial weights (loaded from\n",
    "        args.load_initial when that is set).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the model/dataset combination is not supported.\n",
    "    \"\"\"\n",
    "    # Nets are constructed in the order server, client 0, ..., client N-1 so that any\n",
    "    # seeded random initialisation is consumed in a fixed order.\n",
    "    net_glob = _build_net(args, dropout_p)\n",
    "    initial_state_dict = copy.deepcopy(net_glob.state_dict())\n",
    "    server_state_dict = copy.deepcopy(net_glob.state_dict())\n",
    "    if args.load_initial:\n",
    "        # map_location lets a checkpoint saved on another device (e.g. a GPU) load here.\n",
    "        initial_state_dict = torch.load(args.load_initial, map_location=args.device)\n",
    "        server_state_dict = copy.deepcopy(initial_state_dict)\n",
    "        net_glob.load_state_dict(initial_state_dict)\n",
    "\n",
    "    users_model = []\n",
    "    for _ in range(args.num_users):\n",
    "        # Each client gets its own module, overwritten with the shared initial weights\n",
    "        # so every participant starts identical to the server.\n",
    "        client_net = _build_net(args, dropout_p)\n",
    "        client_net.load_state_dict(initial_state_dict)\n",
    "        users_model.append(client_net)\n",
    "\n",
    "    return users_model, net_glob, initial_state_dict, server_state_dict\n",
    "\n",
    "print(f'MODEL: {args.model}, Dataset: {args.dataset}')\n",
    "\n",
    "users_model, net_glob, initial_state_dict, server_state_dict = init_nets(args, dropout_p=0.5)\n",
    "\n",
    "print(net_glob)\n",
    "\n",
    "# Print every parameter tensor's name and shape, then the total parameter count.\n",
    "# Tensor.numel() always yields an int, whereas np.prod(()) returns the float 1.0\n",
    "# for a 0-dim parameter and would turn the whole count into a float.\n",
    "total = 0\n",
    "for name, param in net_glob.named_parameters():\n",
    "    print(name, param.size())\n",
    "    total += param.numel()\n",
    "print(total)\n",
    "\n",
    "################################# Fixing all to the same Init and data partitioning and random users \n",
    "#print(os.getcwd())\n",
    "\n",
    "# tt = '../initialization/' + 'traindata_'+args.dataset+'_'+args.partition+'.pkl'\n",
    "# with open(tt, 'rb') as f:\n",
    "#     net_dataidx_map = pickle.load(f)\n",
    "    \n",
    "# tt = '../initialization/' + 'testdata_'+args.dataset+'_'+args.partition+'.pkl'\n",
    "# with open(tt, 'rb') as f:\n",
    "#     net_dataidx_map_test = pickle.load(f)\n",
    "    \n",
    "# tt = '../initialization/' + 'traindata_cls_counts_'+args.dataset+'_'+args.partition+'.pkl'\n",
    "# with open(tt, 'rb') as f:\n",
    "#     traindata_cls_counts = pickle.load(f)\n",
    "    \n",
    "# tt = '../initialization/' + 'testdata_cls_counts_'+args.dataset+'_'+args.partition+'.pkl'\n",
    "# with open(tt, 'rb') as f:\n",
    "#     testdata_cls_counts = pickle.load(f)\n",
    "\n",
    "#tt = '../initialization/' + 'init_'+args.model+'_'+args.dataset+'.pth'\n",
    "#initial_state_dict = torch.load(tt, map_location=args.device)\n",
    "\n",
    "#server_state_dict = copy.deepcopy(initial_state_dict)\n",
    "#for idx in range(args.num_users):\n",
    "#    users_model[idx].load_state_dict(initial_state_dict)\n",
    "    \n",
    "#net_glob.load_state_dict(initial_state_dict)\n",
    "\n",
    "# tt = '../initialization/' + 'comm_users.pkl'\n",
    "# with open(tt, 'rb') as f:\n",
    "#     comm_users = pickle.load(f)\n",
    "    \n",
    "################################# Initializing Clients   \n",
    "# Build one Client_FedAvg per user, each with its own local train/test dataloaders.\n",
    "clients = []\n",
    "\n",
    "K = args.n_basis  # copied from args.n_basis; not referenced later in this cell\n",
    "for idx in range(args.num_users):\n",
    "    dataidxs = net_dataidx_map[idx]\n",
    "    # dataidxs_test must be (re)assigned for every client: it is passed to get_dataloader below.\n",
    "    # (Previously the None branch assigned a misspelled `dataidx_test`, leaving this undefined/stale.)\n",
    "    dataidxs_test = None if net_dataidx_map_test is None else net_dataidx_map_test[idx]\n",
    "\n",
    "    # In the 'space' setting the last client gets noise level 0; the 'level' setting below\n",
    "    # overrides noise_level with a value that grows linearly with the client index.\n",
    "    noise_level = args.noise\n",
    "    if idx == args.num_users - 1:\n",
    "        noise_level = 0\n",
    "\n",
    "    if args.noise_type == 'space':\n",
    "        # Also passes the client index and num_users - 1 (presumably the id of the\n",
    "        # noise-free client -- confirm against get_dataloader).\n",
    "        train_dl_local, test_dl_local, train_ds_local, test_ds_local, p, v = get_dataloader(\n",
    "            args.dataset, args.datadir, args.local_bs, 32, dataidxs, noise_level, idx,\n",
    "            args.num_users - 1, dataidxs_test=dataidxs_test)\n",
    "    else:\n",
    "        # Linear ramp: 0 for client 0 up to args.noise for the last client.\n",
    "        noise_level = args.noise / (args.num_users - 1) * idx\n",
    "        train_dl_local, test_dl_local, train_ds_local, test_ds_local, p, v = get_dataloader(\n",
    "            args.dataset, args.datadir, args.local_bs, 32, dataidxs, noise_level,\n",
    "            dataidxs_test=dataidxs_test)\n",
    "\n",
    "    # Each client trains a private deep copy of its model from users_model.\n",
    "    clients.append(Client_FedAvg(idx, copy.deepcopy(users_model[idx]), args.local_bs, args.local_ep,\n",
    "                                 args.lr, args.momentum, args.device, train_dl_local, test_dl_local))\n",
    "    \n",
    "###################################### Federation \n",
    "\n",
    "# Per-round history of the average local training loss.\n",
    "loss_train = []\n",
    "\n",
    "# Per-round train accuracy, test accuracy and test loss, before (init_) and after (final_) local training.\n",
    "# NOTE(review): the code that appends to these lists is commented out, so they stay empty.\n",
    "init_tracc_pr, final_tracc_pr = [], []\n",
    "init_tacc_pr, final_tacc_pr = [], []\n",
    "init_tloss_pr, final_tloss_pr = [], []\n",
    "\n",
    "# Per-round scratch buffers; the loop resets/clears them every round.\n",
    "w_locals = []\n",
    "loss_locals = []\n",
    "init_local_tacc, final_local_tacc = [], []\n",
    "init_local_tloss, final_local_tloss = [], []\n",
    "\n",
    "# Best local test accuracy per client (clients_best_acc is updated during training;\n",
    "# users_best_acc is not referenced later in this cell).\n",
    "clients_best_acc = [0] * args.num_users\n",
    "users_best_acc = [0] * args.num_users\n",
    "\n",
    "# Average current / best client accuracy recorded at each print checkpoint.\n",
    "ckp_avg_tacc = []\n",
    "ckp_avg_best_tacc = []\n",
    "\n",
    "# Global weights broadcast to sampled clients, and the best accuracy the global model has reached.\n",
    "best_glob_acc = 0\n",
    "w_glob = copy.deepcopy(initial_state_dict)\n",
    "print_flag = False\n",
    "# Federated training loop: each round samples clients, trains them locally starting from the\n",
    "# current global weights, and aggregates their weights with FedAvg weighted by local data size.\n",
    "for iteration in range(args.rounds):\n",
    "        \n",
    "    # Sample m = max(int(frac * num_users), 1) distinct clients for this round.\n",
    "    m = max(int(args.frac * args.num_users), 1)\n",
    "    idxs_users = np.random.choice(range(args.num_users), m, replace=False)\n",
    "    \n",
    "    #idxs_users = comm_users[iteration]\n",
    "    \n",
    "    print(f'###### ROUND {iteration+1} ######')\n",
    "    print(f'Clients {idxs_users}')\n",
    "        \n",
    "    # Local training: each sampled client is reset to the current global weights, then trained.\n",
    "    for idx in idxs_users:\n",
    "        \n",
    "        clients[idx].set_state_dict(copy.deepcopy(w_glob)) \n",
    "            \n",
    "#         loss, acc = clients[idx].eval_test()        \n",
    "            \n",
    "#         init_local_tacc.append(acc)\n",
    "#         init_local_tloss.append(loss)\n",
    "            \n",
    "        loss = clients[idx].train(is_print=False)\n",
    "                        \n",
    "        loss_locals.append(copy.deepcopy(loss))\n",
    "                       \n",
    "#         loss, acc = clients[idx].eval_test()\n",
    "        \n",
    "#         if acc > clients_best_acc[idx]:\n",
    "#             clients_best_acc[idx] = acc\n",
    "        \n",
    "#         final_local_tacc.append(acc)\n",
    "#         final_local_tloss.append(loss)           \n",
    "    \n",
    "    # FedAvg weights: each sampled client's share of this round's total training indices.\n",
    "    total_data_points = sum([len(net_dataidx_map[r]) for r in idxs_users])\n",
    "    fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in idxs_users]\n",
    "    \n",
    "    w_locals = []\n",
    "    for idx in idxs_users:\n",
    "        w_locals.append(copy.deepcopy(clients[idx].get_state_dict()))\n",
    "\n",
    "    # Aggregate into the new global weights, load them into net_glob and evaluate on test_dl_global.\n",
    "    ww = FedAvg(w_locals, weight_avg=fed_avg_freqs)\n",
    "    w_glob = copy.deepcopy(ww)\n",
    "    net_glob.load_state_dict(copy.deepcopy(ww))\n",
    "    _, acc = eval_test(net_glob, args, test_dl_global)\n",
    "    if acc > best_glob_acc:\n",
    "        best_glob_acc = acc \n",
    "\n",
    "    # print loss\n",
    "    # loss_locals is non-empty here because m >= 1 clients were trained this round.\n",
    "    loss_avg = sum(loss_locals) / len(loss_locals)\n",
    "    #avg_init_tloss = sum(init_local_tloss) / len(init_local_tloss)\n",
    "    #avg_init_tacc = sum(init_local_tacc) / len(init_local_tacc)\n",
    "    #avg_final_tloss = sum(final_local_tloss) / len(final_local_tloss)\n",
    "    #avg_final_tacc = sum(final_local_tacc) / len(final_local_tacc)\n",
    "         \n",
    "    print('## END OF ROUND ##')\n",
    "    template = 'Average Train loss {:.3f}'\n",
    "    print(template.format(loss_avg))\n",
    "    \n",
    "#     template = \"AVG Init Test Loss: {:.3f}, AVG Init Test Acc: {:.3f}\"\n",
    "#     print(template.format(avg_init_tloss, avg_init_tacc))\n",
    "    \n",
    "#     template = \"AVG Final Test Loss: {:.3f}, AVG Final Test Acc: {:.3f}\"\n",
    "#     print(template.format(avg_final_tloss, avg_final_tacc))\n",
    "    \n",
    "    template = \"Global Model Test Acc: {:.3f}, Global Model Best Test Acc: {:.3f}\"\n",
    "    print(template.format(acc, best_glob_acc))\n",
    "    \n",
    "    # Every args.print_freq rounds (starting with the first round), evaluate the current local model\n",
    "    # of ALL clients, including ones not sampled this round, and update their best test accuracy.\n",
    "    # Note: this reassigns `acc`, so afterwards it no longer holds the global-model accuracy.\n",
    "    print_flag = False\n",
    "#     if iteration < 60:\n",
    "#         print_flag = True\n",
    "    if iteration%args.print_freq == 0: \n",
    "        print_flag = True\n",
    "        \n",
    "    if print_flag:\n",
    "        print('--- PRINTING ALL CLIENTS STATUS ---')\n",
    "        current_acc = []\n",
    "        for k in range(args.num_users):\n",
    "            loss, acc = clients[k].eval_test() \n",
    "            current_acc.append(acc)\n",
    "            \n",
    "            if acc > clients_best_acc[k]:\n",
    "                clients_best_acc[k] = acc\n",
    "                \n",
    "            template = (\"Client {:3d}, labels {}, count {}, best_acc {:3.3f}, current_acc {:3.3f} \\n\")\n",
    "            print(template.format(k, traindata_cls_counts[k], clients[k].get_count(),\n",
    "                                  clients_best_acc[k], current_acc[-1]))\n",
    "            \n",
    "        template = (\"Round {:1d}, Avg current_acc {:3.3f}, Avg best_acc {:3.3f}\")\n",
    "        print(template.format(iteration+1, np.mean(current_acc), np.mean(clients_best_acc)))\n",
    "        \n",
    "        ckp_avg_tacc.append(np.mean(current_acc))\n",
    "        ckp_avg_best_tacc.append(np.mean(clients_best_acc))\n",
    "    \n",
    "    # Log get_count() and traindata_cls_counts for the clients sampled this round.\n",
    "    print('----- Analysis End of Round -------')\n",
    "    for idx in idxs_users:\n",
    "        print(f'Client {idx}, Count: {clients[idx].get_count()}, Labels: {traindata_cls_counts[idx]}')\n",
    "           \n",
    "    loss_train.append(loss_avg)\n",
    "    \n",
    "    #init_tacc_pr.append(avg_init_tacc)\n",
    "    #init_tloss_pr.append(avg_init_tloss)\n",
    "    \n",
    "    #final_tacc_pr.append(avg_final_tacc)\n",
    "    #final_tloss_pr.append(avg_final_tloss)\n",
    "    \n",
    "    #break;\n",
    "    ## clear the placeholders for the next round \n",
    "    loss_locals.clear()\n",
    "    init_local_tacc.clear()\n",
    "    init_local_tloss.clear()\n",
    "    final_local_tacc.clear()\n",
    "    final_local_tloss.clear()\n",
    "    \n",
    "    ## calling garbage collector \n",
    "    gc.collect()\n",
    "    \n",
    "############################### Saving Training Results \n",
    "# with open(path+str(args.trial)+'_loss_train.npy', 'wb') as fp:\n",
    "#     loss_train = np.array(loss_train)\n",
    "#     np.save(fp, loss_train)\n",
    "    \n",
    "# with open(path+str(args.trial)+'_init_tacc_pr.npy', 'wb') as fp:\n",
    "#     init_tacc_pr = np.array(init_tacc_pr)\n",
    "#     np.save(fp, init_tacc_pr)\n",
    "    \n",
    "# with open(path+str(args.trial)+'_init_tloss_pr.npy', 'wb') as fp:\n",
    "#     init_tloss_pr = np.array(init_tloss_pr)\n",
    "#     np.save(fp, init_tloss_pr)\n",
    "    \n",
    "# with open(path+str(args.trial)+'_final_tacc_pr.npy', 'wb') as fp:\n",
    "#     final_tacc_pr = np.array(final_tacc_pr)\n",
    "#     np.save(fp, final_tacc_pr)\n",
    "    \n",
    "# with open(path+str(args.trial)+'_final_tloss_pr.npy', 'wb') as fp:\n",
    "#     final_tloss_pr = np.array(final_tloss_pr)\n",
    "#     np.save(fp, final_tloss_pr)\n",
    "    \n",
    "# with open(path+str(args.trial)+'_best_glob_w.pt', 'wb') as fp:\n",
    "#     torch.save(best_glob_w, fp)\n",
    "############################### Printing Final Test and Train ACC / LOSS\n",
    "# Evaluate every client's local model on its own test and train data, then report the averages.\n",
    "def _average(values):\n",
    "    \"\"\"Arithmetic mean of a non-empty list, computed as plain sum / len.\"\"\"\n",
    "    return sum(values) / len(values)\n",
    "\n",
    "test_loss, test_acc, train_loss, train_acc = [], [], [], []\n",
    "for idx in range(args.num_users):\n",
    "    client = clients[idx]\n",
    "    loss, acc = client.eval_test()\n",
    "    test_loss.append(loss)\n",
    "    test_acc.append(acc)\n",
    "    loss, acc = client.eval_train()\n",
    "    train_loss.append(loss)\n",
    "    train_acc.append(acc)\n",
    "\n",
    "test_loss, test_acc = _average(test_loss), _average(test_acc)\n",
    "train_loss, train_acc = _average(train_loss), _average(train_acc)\n",
    "\n",
    "print(f'Train Loss: {train_loss}, Test_loss: {test_loss}')\n",
    "print(f'Train Acc: {train_acc}, Test Acc: {test_acc}')\n",
    "\n",
    "print(f'Best Clients AVG Acc: {np.mean(clients_best_acc)}')\n",
    "\n",
    "# Re-evaluate the final aggregated global weights and update the global best accuracy.\n",
    "net_glob.load_state_dict(copy.deepcopy(w_glob))\n",
    "_, acc = eval_test(net_glob, args, test_dl_global)\n",
    "best_glob_acc = max(best_glob_acc, acc)\n",
    "\n",
    "template = \"Global Model Test Acc: {:.3f}, Global Model Best Test Acc: {:.3f}\"\n",
    "print(template.format(acc, best_glob_acc))\n",
    "############################# Saving Print Results \n",
    "# Append the final summary lines to <path><trial>_final_results.txt (mode 'a' keeps earlier runs).\n",
    "results_path = path + str(args.trial) + '_final_results.txt'\n",
    "with open(results_path, 'a') as text_file:\n",
    "    template = \"Global Model Test Acc: {:.3f}, Global Model Best Test Acc: {:.3f}\"\n",
    "    summary = [\n",
    "        f'Train Loss: {train_loss}, Test_loss: {test_loss}',\n",
    "        f'Train Acc: {train_acc}, Test Acc: {test_acc}',\n",
    "        f'Best Clients AVG Acc: {np.mean(clients_best_acc)}',\n",
    "        template.format(acc, best_glob_acc),\n",
    "    ]\n",
    "    for line in summary:\n",
    "        print(line, file=text_file)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
