{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import gc\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torchvision import datasets, transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current GPU: 0\n",
      "GPU Name: NVIDIA GeForce RTX 3090\n",
      "Number of GPUs: 2\n"
     ]
    }
   ],
   "source": [
    "# Report the available GPUs. The device itself is selected in the config cell via args.gpu,\n",
    "# so no set_device call here (it would only create an unused CUDA context).\n",
    "if torch.cuda.is_available():\n",
    "    print(f'Current GPU: {torch.cuda.current_device()}')\n",
    "    print(f'GPU Name: {torch.cuda.get_device_name()}')\n",
    "    print(f'Number of GPUs: {torch.cuda.device_count()}')\n",
    "else:\n",
    "    print('CUDA not available; running on CPU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args:\n",
    "    \"\"\"Experiment configuration, read through the module-level `args` instance.\"\"\"\n",
    "\n",
    "    # federation\n",
    "    num_users = 20      # total number of clients built below\n",
    "    rounds = 30         # number of federation rounds\n",
    "    frac = 0.1          # clients sampled per round = max(int(frac * num_users), 1)\n",
    "    local_bs = 10       # batch size of each client's training DataLoader\n",
    "    local_ep = 10       # passed to Client_FedAvg (presumably local epochs per round)\n",
    "    lr = 0.01\n",
    "    momentum = 0.9\n",
    "\n",
    "    # model / data\n",
    "    model = 'lenet5'    ## options: lenet5\n",
    "    dataset = 'cifar10' ## options: mnist, cifar10, cifar100\n",
    "    datadir = '../data/'\n",
    "\n",
    "    # partitioning (forwarded to get_dataset_global / partition_data)\n",
    "    p_train = 1.0\n",
    "    p_test = 1.0\n",
    "    partition = 'niid-labeldir'\n",
    "    niid_beta = 0.1\n",
    "    iid_beta = 1.0\n",
    "\n",
    "    print_freq = 10\n",
    "\n",
    "    load_initial = ''\n",
    "    seed = 0\n",
    "    gpu = 0\n",
    "\n",
    "args = Args()\n",
    "\n",
    "# Seed the RNGs up front: client sampling in the federation loop uses np.random,\n",
    "# so without this every run picks different clients.\n",
    "torch.manual_seed(args.seed)\n",
    "np.random.seed(args.seed)\n",
    "\n",
    "# torch.cuda.set_device raises on CPU-only machines, so only touch CUDA when it exists.\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.set_device(args.gpu)\n",
    "    args.device = torch.device(f'cuda:{args.gpu}')\n",
    "else:\n",
    "    args.device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.data import *\n",
    "from src.models import *\n",
    "from src.client import *\n",
    "from src.clustering import *\n",
    "from src.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Getting Clients Data\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "print('Getting Clients Data')\n",
    "\n",
    "GLOBAL_BS = 128  # batch size of the global loaders\n",
    "\n",
    "# Full-size global datasets (p=1.0); the per-client subsets are indexed into these.\n",
    "(train_ds_global, test_ds_global,\n",
    " train_dl_global, test_dl_global) = get_dataset_global(\n",
    "    args.dataset, args.datadir, batch_size=GLOBAL_BS, p_train=1.0, p_test=1.0)\n",
    "\n",
    "# Variant honouring args.p_train / args.p_test; its test loader scores the global model.\n",
    "(train_ds_global1, test_ds_global1,\n",
    " train_dl_global1, test_dl_global1) = get_dataset_global(\n",
    "    args.dataset, args.datadir, batch_size=GLOBAL_BS, p_train=args.p_train, p_test=args.p_test)\n",
    "\n",
    "# Per-client index partitions (and label statistics) for args.num_users clients.\n",
    "(partitions_train, partitions_test,\n",
    " partitions_train_stat, partitions_test_stat) = partition_data(\n",
    "    args.dataset, args.datadir, args.partition, args.num_users,\n",
    "    niid_beta=args.niid_beta, iid_beta=args.iid_beta,\n",
    "    p_train=args.p_train, p_test=args.p_test)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building models for clients\n",
      "MODEL: lenet5, Dataset: cifar10\n",
      "----------------------------------------\n",
      "LeNet5(\n",
      "  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
      "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
      "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
      ")\n",
      "\n",
      "total params 62006\n"
     ]
    }
   ],
   "source": [
    "print('Building models for clients')\n",
    "print(f'MODEL: {args.model}, Dataset: {args.dataset}')\n",
    "# users_model is indexed per client below; net_glob receives the aggregated weights each round;\n",
    "# initial_state_dict seeds the first global state.\n",
    "users_model, net_glob, initial_state_dict = get_models(args, dropout_p=0.5)\n",
    "print('-'*40)\n",
    "print(net_glob)\n",
    "print('')\n",
    "\n",
    "# Total number of (trainable and non-trainable) parameter elements of the global model.\n",
    "total = sum(param.numel() for param in net_glob.parameters())\n",
    "print(f'total params {total}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building Clients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initializing Clients\n",
      "-- Client 0, Train Stat {0: 123, 8: 25, 9: 10, 10: 6, 14: 1, 16: 359, 18: 42, 20: 1, 21: 114, 22: 7, 23: 51, 24: 94, 28: 5, 30: 1, 41: 121, 42: 94, 43: 5, 44: 2, 45: 5, 50: 23, 51: 3, 52: 6, 56: 17, 59: 53, 62: 228, 63: 116, 65: 104, 69: 1, 70: 104, 72: 1, 79: 4, 80: 323, 82: 48, 83: 2, 84: 13, 85: 221, 86: 3, 87: 314} Test Stat {0: 100, 8: 100, 9: 100, 10: 100, 14: 100, 16: 100, 18: 100, 20: 100, 21: 100, 22: 100, 23: 100, 24: 100, 28: 100, 30: 100, 41: 100, 42: 100, 43: 100, 44: 100, 45: 100, 50: 100, 51: 100, 52: 100, 56: 100, 59: 100, 62: 100, 63: 100, 65: 100, 69: 100, 70: 100, 72: 100, 79: 100, 80: 100, 82: 100, 83: 100, 84: 100, 85: 100, 86: 100, 87: 100}\n",
      "-- Client 1, Train Stat {0: 1, 8: 1, 9: 1, 11: 63, 12: 2, 13: 71, 15: 9, 16: 3, 18: 1, 20: 25, 21: 1, 24: 26, 26: 27, 32: 54, 33: 57, 34: 334, 37: 17, 39: 8, 40: 1, 44: 1, 46: 2, 47: 3, 48: 63, 52: 79, 55: 186, 56: 1, 59: 8, 60: 119, 61: 36, 63: 1, 67: 1, 68: 212, 70: 66, 72: 16, 74: 7, 76: 1, 78: 27, 79: 8, 81: 1, 83: 11, 84: 113, 85: 18, 86: 413, 87: 21, 88: 16, 92: 496} Test Stat {0: 100, 8: 100, 9: 100, 11: 100, 12: 100, 13: 100, 15: 100, 16: 100, 18: 100, 20: 100, 21: 100, 24: 100, 26: 100, 32: 100, 33: 100, 34: 100, 37: 100, 39: 100, 40: 100, 44: 100, 46: 100, 47: 100, 48: 100, 52: 100, 55: 100, 56: 100, 59: 100, 60: 100, 61: 100, 63: 100, 67: 100, 68: 100, 70: 100, 72: 100, 74: 100, 76: 100, 78: 100, 79: 100, 81: 100, 83: 100, 84: 100, 85: 100, 86: 100, 87: 100, 88: 100, 92: 100}\n",
      "-- Client 2, Train Stat {1: 34, 4: 25, 10: 1, 16: 71, 17: 6, 18: 294, 19: 14, 22: 3, 23: 6, 24: 22, 25: 339, 27: 248, 32: 24, 36: 53, 37: 63, 38: 71, 43: 1, 45: 55, 47: 5, 48: 3, 49: 1, 50: 24, 55: 26, 56: 52, 57: 9, 60: 1, 62: 21, 66: 7, 68: 4, 72: 44, 73: 27, 74: 17, 77: 39, 80: 12, 83: 193, 87: 25, 88: 35, 89: 292, 90: 52, 93: 239, 94: 1, 96: 79} Test Stat {1: 100, 4: 100, 10: 100, 16: 100, 17: 100, 18: 100, 19: 100, 22: 100, 23: 100, 24: 100, 25: 100, 27: 100, 32: 100, 36: 100, 37: 100, 38: 100, 43: 100, 45: 100, 47: 100, 48: 100, 49: 100, 50: 100, 55: 100, 56: 100, 57: 100, 60: 100, 62: 100, 66: 100, 68: 100, 72: 100, 73: 100, 74: 100, 77: 100, 80: 100, 83: 100, 87: 100, 88: 100, 89: 100, 90: 100, 93: 100, 94: 100, 96: 100}\n",
      "-- Client 3, Train Stat {2: 48, 4: 102, 11: 1, 13: 52, 14: 160, 15: 1, 21: 110, 22: 26, 23: 6, 24: 87, 27: 4, 31: 109, 32: 2, 34: 1, 35: 1, 36: 1, 37: 6, 38: 1, 39: 105, 41: 4, 42: 90, 48: 239, 53: 11, 54: 6, 61: 1, 62: 174, 69: 106, 70: 16, 78: 319, 80: 56, 81: 309, 84: 30, 85: 18, 86: 3, 87: 2, 88: 1, 89: 1, 90: 413} Test Stat {2: 100, 4: 100, 11: 100, 13: 100, 14: 100, 15: 100, 21: 100, 22: 100, 23: 100, 24: 100, 27: 100, 31: 100, 32: 100, 34: 100, 35: 100, 36: 100, 37: 100, 38: 100, 39: 100, 41: 100, 42: 100, 48: 100, 53: 100, 54: 100, 61: 100, 62: 100, 69: 100, 70: 100, 78: 100, 80: 100, 81: 100, 84: 100, 85: 100, 86: 100, 87: 100, 88: 100, 89: 100, 90: 100}\n",
      "-- Client 4, Train Stat {1: 297, 2: 4, 5: 106, 6: 120, 9: 2, 10: 1, 13: 85, 15: 62, 18: 15, 20: 1, 21: 49, 23: 1, 25: 8, 26: 125, 28: 54, 29: 15, 33: 96, 35: 1, 36: 9, 39: 146, 40: 85, 44: 8, 47: 5, 51: 1, 52: 23, 53: 43, 55: 179, 56: 21, 57: 2, 59: 33, 61: 87, 62: 26, 64: 16, 66: 1, 68: 38, 72: 102, 73: 4, 74: 2, 75: 8, 76: 24, 77: 53, 79: 117, 80: 24, 81: 54, 82: 8, 83: 3, 88: 6, 89: 1, 92: 1, 95: 248, 96: 398} Test Stat {1: 100, 2: 100, 5: 100, 6: 100, 9: 100, 10: 100, 13: 100, 15: 100, 18: 100, 20: 100, 21: 100, 23: 100, 25: 100, 26: 100, 28: 100, 29: 100, 33: 100, 35: 100, 36: 100, 39: 100, 40: 100, 44: 100, 47: 100, 51: 100, 52: 100, 53: 100, 55: 100, 56: 100, 57: 100, 59: 100, 61: 100, 62: 100, 64: 100, 66: 100, 68: 100, 72: 100, 73: 100, 74: 100, 75: 100, 76: 100, 77: 100, 79: 100, 80: 100, 81: 100, 82: 100, 83: 100, 88: 100, 89: 100, 92: 100, 95: 100, 96: 100}\n",
      "-- Client 5, Train Stat {0: 13, 4: 22, 5: 89, 9: 247, 10: 2, 11: 4, 16: 26, 17: 251, 19: 5, 21: 17, 25: 18, 28: 13, 29: 1, 30: 17, 31: 6, 32: 273, 33: 6, 36: 88, 38: 5, 39: 1, 41: 16, 42: 7, 46: 472, 49: 52, 51: 6, 53: 251, 56: 132, 58: 373, 60: 1, 66: 73, 68: 15} Test Stat {0: 100, 4: 100, 5: 100, 9: 100, 10: 100, 11: 100, 16: 100, 17: 100, 19: 100, 21: 100, 25: 100, 28: 100, 29: 100, 30: 100, 31: 100, 32: 100, 33: 100, 36: 100, 38: 100, 39: 100, 41: 100, 42: 100, 46: 100, 49: 100, 51: 100, 53: 100, 56: 100, 58: 100, 60: 100, 66: 100, 68: 100}\n",
      "-- Client 6, Train Stat {1: 6, 5: 23, 8: 16, 9: 1, 10: 4, 24: 1, 29: 9, 30: 126, 33: 85, 35: 24, 37: 76, 39: 11, 42: 27, 44: 103, 46: 5, 48: 41, 51: 2, 52: 5, 53: 91, 55: 6, 56: 38, 58: 1, 59: 46, 61: 2, 65: 9, 71: 286, 72: 25, 73: 1, 75: 2, 76: 35, 81: 4, 82: 228, 83: 70, 84: 86, 86: 55, 88: 182, 90: 9, 91: 302, 92: 1, 97: 6, 99: 33} Test Stat {1: 100, 5: 100, 8: 100, 9: 100, 10: 100, 24: 100, 29: 100, 30: 100, 33: 100, 35: 100, 37: 100, 39: 100, 42: 100, 44: 100, 46: 100, 48: 100, 51: 100, 52: 100, 53: 100, 55: 100, 56: 100, 58: 100, 59: 100, 61: 100, 65: 100, 71: 100, 72: 100, 73: 100, 75: 100, 76: 100, 81: 100, 82: 100, 83: 100, 84: 100, 86: 100, 88: 100, 90: 100, 91: 100, 92: 100, 97: 100, 99: 100}\n",
      "-- Client 7, Train Stat {1: 73, 3: 30, 5: 36, 7: 51, 15: 17, 18: 74, 23: 20, 24: 47, 26: 15, 27: 31, 29: 10, 35: 25, 36: 158, 39: 17, 40: 229, 41: 37, 43: 1, 44: 69, 45: 37, 47: 66, 48: 105, 54: 6, 58: 1, 61: 20, 62: 24, 65: 40, 67: 3, 68: 35, 69: 20, 73: 327, 74: 41, 75: 5, 80: 3, 83: 197, 84: 11, 85: 1, 86: 4, 88: 2, 94: 21, 95: 201, 98: 463} Test Stat {1: 100, 3: 100, 5: 100, 7: 100, 15: 100, 18: 100, 23: 100, 24: 100, 26: 100, 27: 100, 29: 100, 35: 100, 36: 100, 39: 100, 40: 100, 41: 100, 43: 100, 44: 100, 45: 100, 47: 100, 48: 100, 54: 100, 58: 100, 61: 100, 62: 100, 65: 100, 67: 100, 68: 100, 69: 100, 73: 100, 74: 100, 75: 100, 80: 100, 83: 100, 84: 100, 85: 100, 86: 100, 88: 100, 94: 100, 95: 100, 98: 100}\n",
      "-- Client 8, Train Stat {3: 1, 4: 3, 7: 23, 9: 14, 15: 2, 19: 40, 23: 18, 24: 7, 27: 14, 28: 12, 31: 4, 33: 8, 35: 4, 40: 11, 42: 1, 43: 102, 44: 21, 47: 186, 48: 1, 49: 422, 50: 1, 53: 2, 54: 9, 56: 12, 57: 48, 60: 171, 63: 1, 64: 359, 66: 35, 67: 28, 71: 141, 72: 6, 76: 6, 77: 288, 79: 358, 84: 75, 86: 11, 87: 2, 88: 21, 90: 25, 91: 11} Test Stat {3: 100, 4: 100, 7: 100, 9: 100, 15: 100, 19: 100, 23: 100, 24: 100, 27: 100, 28: 100, 31: 100, 33: 100, 35: 100, 40: 100, 42: 100, 43: 100, 44: 100, 47: 100, 48: 100, 49: 100, 50: 100, 53: 100, 54: 100, 56: 100, 57: 100, 60: 100, 63: 100, 64: 100, 66: 100, 67: 100, 71: 100, 72: 100, 76: 100, 77: 100, 79: 100, 84: 100, 86: 100, 87: 100, 88: 100, 90: 100, 91: 100}\n",
      "-- Client 9, Train Stat {3: 25, 4: 1, 5: 9, 7: 1, 8: 1, 9: 7, 11: 17, 16: 22, 17: 1, 19: 153, 20: 75, 23: 1, 24: 2, 25: 115, 26: 1, 27: 76, 29: 3, 31: 52, 33: 184, 34: 1, 35: 10, 36: 90, 43: 1, 44: 56, 45: 59, 50: 215, 51: 1, 55: 4, 58: 10, 59: 2, 61: 1, 62: 25, 63: 2, 66: 40, 67: 6, 72: 206, 73: 2, 78: 33, 79: 11, 84: 3, 86: 1, 87: 102, 89: 27, 91: 8, 93: 13, 94: 30, 98: 3} Test Stat {3: 100, 4: 100, 5: 100, 7: 100, 8: 100, 9: 100, 11: 100, 16: 100, 17: 100, 19: 100, 20: 100, 23: 100, 24: 100, 25: 100, 26: 100, 27: 100, 29: 100, 31: 100, 33: 100, 34: 100, 35: 100, 36: 100, 43: 100, 44: 100, 45: 100, 50: 100, 51: 100, 55: 100, 58: 100, 59: 100, 61: 100, 62: 100, 63: 100, 66: 100, 67: 100, 72: 100, 73: 100, 78: 100, 79: 100, 84: 100, 86: 100, 87: 100, 89: 100, 91: 100, 93: 100, 94: 100, 98: 100}\n",
      "-- Client 10, Train Stat {0: 7, 2: 8, 3: 1, 9: 172, 11: 5, 15: 51, 18: 4, 23: 66, 25: 1, 26: 213, 27: 34, 28: 2, 31: 1, 36: 87, 40: 114, 41: 42, 44: 36, 47: 24, 48: 17, 50: 20, 51: 1, 53: 3, 57: 5, 58: 6, 59: 2, 60: 4, 61: 115, 62: 1, 64: 2, 67: 46, 71: 72, 73: 43, 74: 47, 75: 121, 78: 86, 79: 1, 81: 3, 83: 15, 84: 6, 85: 2, 87: 3, 88: 1, 91: 102, 92: 1, 93: 208, 97: 394, 98: 25, 99: 465} Test Stat {0: 100, 2: 100, 3: 100, 9: 100, 11: 100, 15: 100, 18: 100, 23: 100, 25: 100, 26: 100, 27: 100, 28: 100, 31: 100, 36: 100, 40: 100, 41: 100, 44: 100, 47: 100, 48: 100, 50: 100, 51: 100, 53: 100, 57: 100, 58: 100, 59: 100, 60: 100, 61: 100, 62: 100, 64: 100, 67: 100, 71: 100, 73: 100, 74: 100, 75: 100, 78: 100, 79: 100, 81: 100, 83: 100, 84: 100, 85: 100, 87: 100, 88: 100, 91: 100, 92: 100, 93: 100, 97: 100, 98: 100, 99: 100}\n",
      "-- Client 11, Train Stat {2: 124, 6: 7, 13: 1, 17: 37, 18: 8, 19: 75, 20: 153, 21: 14, 26: 20, 27: 34, 28: 281, 30: 75, 31: 22, 34: 1, 36: 1, 37: 44, 40: 12, 41: 19, 42: 272, 43: 3, 45: 46, 46: 16, 47: 180, 48: 1, 50: 5, 51: 17, 54: 14, 58: 2, 60: 7, 65: 1, 66: 131, 67: 415, 72: 51, 76: 269, 80: 53, 81: 1, 82: 7, 84: 1, 86: 9, 88: 226} Test Stat {2: 100, 6: 100, 13: 100, 17: 100, 18: 100, 19: 100, 20: 100, 21: 100, 26: 100, 27: 100, 28: 100, 30: 100, 31: 100, 34: 100, 36: 100, 37: 100, 40: 100, 41: 100, 42: 100, 43: 100, 45: 100, 46: 100, 47: 100, 48: 100, 50: 100, 51: 100, 54: 100, 58: 100, 60: 100, 65: 100, 66: 100, 67: 100, 72: 100, 76: 100, 80: 100, 81: 100, 82: 100, 84: 100, 86: 100, 88: 100}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-- Client 12, Train Stat {0: 218, 1: 25, 6: 1, 7: 137, 9: 1, 10: 181, 12: 465, 15: 27, 16: 14, 17: 14, 19: 148, 21: 142, 22: 90, 23: 201, 24: 98, 25: 2, 26: 9, 28: 1, 30: 1, 31: 297, 33: 1, 34: 74, 35: 80, 37: 12, 38: 4, 46: 1, 49: 24, 50: 49, 51: 254} Test Stat {0: 100, 1: 100, 6: 100, 7: 100, 9: 100, 10: 100, 12: 100, 15: 100, 16: 100, 17: 100, 19: 100, 21: 100, 22: 100, 23: 100, 24: 100, 25: 100, 26: 100, 28: 100, 30: 100, 31: 100, 33: 100, 34: 100, 35: 100, 37: 100, 38: 100, 46: 100, 49: 100, 50: 100, 51: 100}\n",
      "-- Client 13, Train Stat {1: 37, 2: 257, 4: 346, 7: 5, 10: 9, 11: 83, 13: 11, 14: 70, 17: 1, 22: 34, 23: 125, 24: 53, 26: 7, 29: 3, 32: 115, 39: 5, 40: 5, 47: 10, 51: 97, 53: 29, 58: 29, 59: 42, 63: 39, 64: 4, 65: 251, 68: 24, 69: 82, 75: 67, 76: 105, 77: 1, 78: 1, 80: 1, 82: 6, 84: 83, 87: 30, 91: 53, 93: 38, 94: 395} Test Stat {1: 100, 2: 100, 4: 100, 7: 100, 10: 100, 11: 100, 13: 100, 14: 100, 17: 100, 22: 100, 23: 100, 24: 100, 26: 100, 29: 100, 32: 100, 39: 100, 40: 100, 47: 100, 51: 100, 53: 100, 58: 100, 59: 100, 63: 100, 64: 100, 65: 100, 68: 100, 69: 100, 75: 100, 76: 100, 77: 100, 78: 100, 80: 100, 82: 100, 84: 100, 87: 100, 91: 100, 93: 100, 94: 100}\n",
      "-- Client 14, Train Stat {0: 137, 2: 1, 5: 2, 10: 6, 13: 45, 15: 77, 16: 1, 17: 2, 18: 47, 19: 3, 20: 243, 21: 7, 23: 1, 25: 2, 28: 64, 29: 13, 31: 8, 32: 14, 33: 58, 34: 25, 37: 9, 43: 2, 44: 142, 45: 259, 46: 1, 54: 102, 56: 27, 58: 18, 60: 37, 63: 20, 64: 37, 65: 52, 68: 96, 69: 232, 70: 4, 72: 26, 75: 279, 77: 117, 78: 32, 80: 27, 82: 182, 83: 8, 88: 9, 89: 178} Test Stat {0: 100, 2: 100, 5: 100, 10: 100, 13: 100, 15: 100, 16: 100, 17: 100, 18: 100, 19: 100, 20: 100, 21: 100, 23: 100, 25: 100, 28: 100, 29: 100, 31: 100, 32: 100, 33: 100, 34: 100, 37: 100, 43: 100, 44: 100, 45: 100, 46: 100, 54: 100, 56: 100, 58: 100, 60: 100, 63: 100, 64: 100, 65: 100, 68: 100, 69: 100, 70: 100, 72: 100, 75: 100, 77: 100, 78: 100, 80: 100, 82: 100, 83: 100, 88: 100, 89: 100}\n",
      "-- Client 15, Train Stat {8: 456, 12: 19, 13: 219, 14: 1, 15: 234, 18: 1, 19: 61, 21: 45, 24: 2, 26: 79, 27: 3, 28: 3, 30: 7, 32: 16, 34: 63, 35: 88, 36: 2, 37: 156, 40: 19, 43: 176, 44: 52, 45: 11, 47: 15, 51: 26, 53: 10, 56: 2, 57: 369, 61: 158, 63: 54, 65: 42, 66: 102, 68: 24} Test Stat {8: 100, 12: 100, 13: 100, 14: 100, 15: 100, 18: 100, 19: 100, 21: 100, 24: 100, 26: 100, 27: 100, 28: 100, 30: 100, 32: 100, 34: 100, 35: 100, 36: 100, 37: 100, 40: 100, 43: 100, 44: 100, 45: 100, 47: 100, 51: 100, 53: 100, 56: 100, 57: 100, 61: 100, 63: 100, 65: 100, 66: 100, 68: 100}\n",
      "-- Client 16, Train Stat {1: 2, 3: 234, 5: 191, 6: 1, 9: 18, 11: 176, 12: 13, 13: 15, 14: 205, 17: 133, 22: 80, 23: 1, 24: 8, 25: 1, 26: 3, 28: 59, 35: 108, 37: 32, 40: 3, 43: 207, 44: 7, 45: 27, 52: 311, 54: 314, 55: 72, 56: 34, 57: 1, 58: 51, 59: 306} Test Stat {1: 100, 3: 100, 5: 100, 6: 100, 9: 100, 11: 100, 12: 100, 13: 100, 14: 100, 17: 100, 22: 100, 23: 100, 24: 100, 25: 100, 26: 100, 28: 100, 35: 100, 37: 100, 40: 100, 43: 100, 44: 100, 45: 100, 52: 100, 54: 100, 55: 100, 56: 100, 57: 100, 58: 100, 59: 100}\n",
      "-- Client 17, Train Stat {2: 9, 9: 14, 11: 1, 15: 19, 16: 1, 18: 13, 22: 256, 23: 2, 24: 52, 27: 9, 35: 78, 36: 2, 37: 84, 38: 4, 39: 186, 42: 8, 44: 2, 46: 2, 47: 5, 48: 19, 50: 45, 51: 1, 52: 9, 54: 48, 55: 26, 56: 163, 57: 65, 59: 2, 60: 4, 64: 65, 68: 46, 69: 1, 70: 298, 72: 6, 73: 93, 74: 1, 75: 15, 76: 4, 78: 1, 82: 20, 83: 1, 85: 7, 86: 1, 87: 1, 88: 1, 90: 1, 91: 24, 92: 1, 93: 2, 94: 53, 95: 50, 96: 22, 97: 100, 98: 9, 99: 2} Test Stat {2: 100, 9: 100, 11: 100, 15: 100, 16: 100, 18: 100, 22: 100, 23: 100, 24: 100, 27: 100, 35: 100, 36: 100, 37: 100, 38: 100, 39: 100, 42: 100, 44: 100, 46: 100, 47: 100, 48: 100, 50: 100, 51: 100, 52: 100, 54: 100, 55: 100, 56: 100, 57: 100, 59: 100, 60: 100, 64: 100, 68: 100, 69: 100, 70: 100, 72: 100, 73: 100, 74: 100, 75: 100, 76: 100, 78: 100, 82: 100, 83: 100, 85: 100, 86: 100, 87: 100, 88: 100, 90: 100, 91: 100, 92: 100, 93: 100, 94: 100, 95: 100, 96: 100, 97: 100, 98: 100, 99: 100}\n",
      "-- Client 18, Train Stat {1: 8, 5: 30, 6: 2, 7: 282, 9: 12, 10: 289, 11: 30, 14: 57, 16: 2, 17: 52, 20: 1, 22: 3, 27: 34, 28: 1, 30: 272, 32: 1, 39: 17, 40: 20, 41: 260, 48: 9, 51: 3, 52: 2, 58: 6, 60: 148, 61: 74, 63: 266, 64: 3, 66: 107, 70: 11, 72: 16, 73: 2, 75: 2, 76: 56, 77: 1, 78: 1, 80: 1, 81: 128, 82: 1, 84: 78, 85: 233} Test Stat {1: 100, 5: 100, 6: 100, 7: 100, 9: 100, 10: 100, 11: 100, 14: 100, 16: 100, 17: 100, 20: 100, 22: 100, 27: 100, 28: 100, 30: 100, 32: 100, 39: 100, 40: 100, 41: 100, 48: 100, 51: 100, 52: 100, 58: 100, 60: 100, 61: 100, 63: 100, 64: 100, 66: 100, 70: 100, 72: 100, 73: 100, 75: 100, 76: 100, 77: 100, 78: 100, 80: 100, 81: 100, 82: 100, 84: 100, 85: 100}\n",
      "-- Client 19, Train Stat {0: 1, 1: 18, 2: 49, 3: 209, 4: 1, 5: 14, 6: 369, 7: 1, 8: 1, 9: 1, 10: 1, 11: 120, 12: 1, 13: 1, 14: 6, 15: 1, 16: 1, 17: 3, 18: 1, 19: 1, 20: 1, 21: 1, 22: 1, 23: 1, 24: 1, 25: 14, 26: 1, 27: 13, 28: 5, 29: 446, 30: 1, 31: 1, 32: 1, 33: 5, 34: 1, 35: 81, 36: 9, 37: 1, 38: 415, 39: 4, 40: 1, 41: 1, 42: 1, 43: 2, 44: 1, 45: 1, 46: 1, 47: 1, 48: 2, 49: 1, 50: 118, 51: 88, 52: 65, 53: 60, 54: 1, 55: 1, 56: 1, 57: 1, 58: 3, 59: 6, 60: 8, 61: 6, 62: 1, 63: 1, 64: 14, 65: 1, 66: 4, 67: 1, 68: 6, 69: 58, 70: 1, 71: 1, 72: 1, 73: 1, 74: 385, 75: 1, 77: 1, 79: 1, 84: 1, 89: 1, 95: 1, 96: 1} Test Stat {0: 100, 1: 100, 2: 100, 3: 100, 4: 100, 5: 100, 6: 100, 7: 100, 8: 100, 9: 100, 10: 100, 11: 100, 12: 100, 13: 100, 14: 100, 15: 100, 16: 100, 17: 100, 18: 100, 19: 100, 20: 100, 21: 100, 22: 100, 23: 100, 24: 100, 25: 100, 26: 100, 27: 100, 28: 100, 29: 100, 30: 100, 31: 100, 32: 100, 33: 100, 34: 100, 35: 100, 36: 100, 37: 100, 38: 100, 39: 100, 40: 100, 41: 100, 42: 100, 43: 100, 44: 100, 45: 100, 46: 100, 47: 100, 48: 100, 49: 100, 50: 100, 51: 100, 52: 100, 53: 100, 54: 100, 55: 100, 56: 100, 57: 100, 58: 100, 59: 100, 60: 100, 61: 100, 62: 100, 63: 100, 64: 100, 65: 100, 66: 100, 67: 100, 68: 100, 69: 100, 70: 100, 71: 100, 72: 100, 73: 100, 74: 100, 75: 100, 77: 100, 79: 100, 84: 100, 89: 100, 95: 100, 96: 100}\n"
     ]
    }
   ],
   "source": [
    "print('Initializing Clients')\n",
    "\n",
    "# The transforms do not depend on the client, so build them once outside the loop.\n",
    "# NOTE(review): transform_train / transform_test are never passed to a dataset or loader here -- confirm intended.\n",
    "transform_train, transform_test = get_transforms(args.dataset, noise_level=0, net_id=None, total=0)\n",
    "\n",
    "clients = []\n",
    "for idx in range(args.num_users):\n",
    "    sys.stdout.flush()\n",
    "    print(f'-- Client {idx}, Train Stat {partitions_train_stat[idx]} Test Stat {partitions_test_stat[idx]}')\n",
    "\n",
    "    # This client's indices into the global train / test datasets.\n",
    "    dataidxs = partitions_train[idx]\n",
    "    dataidxs_test = partitions_test[idx]\n",
    "\n",
    "    train_ds_local = get_subset(train_ds_global, dataidxs)\n",
    "    test_ds_local = get_subset(test_ds_global, dataidxs_test)\n",
    "\n",
    "    train_dl_local = DataLoader(dataset=train_ds_local, batch_size=args.local_bs, shuffle=True, drop_last=False,\n",
    "                                num_workers=4, pin_memory=False)\n",
    "    test_dl_local = DataLoader(dataset=test_ds_local, batch_size=64, shuffle=False, drop_last=False,\n",
    "                               num_workers=4, pin_memory=False)\n",
    "\n",
    "    # deepcopy so each client owns its model instead of aliasing the entry in users_model.\n",
    "    clients.append(Client_FedAvg(idx, copy.deepcopy(users_model[idx]), args.local_bs, args.local_ep,\n",
    "                                 args.lr, args.momentum, args.device, train_dl_local, test_dl_local))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Federation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting FL\n",
      "----------------------------------------\n",
      "----- ROUND 1 -----\n",
      "-- Average Train loss 3.510\n",
      "-- Global Acc: 1.000, Global Best Acc: 1.000\n",
      "----- ROUND 2 -----\n",
      "-- Average Train loss 2.449\n",
      "-- Global Acc: 1.240, Global Best Acc: 1.240\n",
      "----- ROUND 3 -----\n",
      "-- Average Train loss 2.167\n",
      "-- Global Acc: 2.780, Global Best Acc: 2.780\n",
      "----- ROUND 4 -----\n",
      "-- Average Train loss 1.723\n",
      "-- Global Acc: 4.150, Global Best Acc: 4.150\n",
      "----- ROUND 5 -----\n",
      "-- Average Train loss 1.819\n",
      "-- Global Acc: 4.250, Global Best Acc: 4.250\n",
      "----- ROUND 6 -----\n",
      "-- Average Train loss 2.090\n",
      "-- Global Acc: 5.690, Global Best Acc: 5.690\n",
      "----- ROUND 7 -----\n",
      "-- Average Train loss 2.124\n",
      "-- Global Acc: 7.430, Global Best Acc: 7.430\n",
      "----- ROUND 8 -----\n",
      "-- Average Train loss 1.340\n",
      "-- Global Acc: 7.250, Global Best Acc: 7.430\n",
      "----- ROUND 9 -----\n",
      "-- Average Train loss 1.329\n",
      "-- Global Acc: 9.370, Global Best Acc: 9.370\n",
      "----- ROUND 10 -----\n",
      "-- Average Train loss 1.774\n",
      "-- Global Acc: 7.540, Global Best Acc: 9.370\n",
      "----- ROUND 11 -----\n",
      "-- Average Train loss 1.780\n",
      "-- Global Acc: 8.180, Global Best Acc: 9.370\n",
      "----- ROUND 12 -----\n",
      "-- Average Train loss 1.555\n",
      "-- Global Acc: 7.280, Global Best Acc: 9.370\n",
      "----- ROUND 13 -----\n",
      "-- Average Train loss 1.537\n",
      "-- Global Acc: 8.920, Global Best Acc: 9.370\n",
      "----- ROUND 14 -----\n",
      "-- Average Train loss 1.189\n",
      "-- Global Acc: 9.390, Global Best Acc: 9.390\n",
      "----- ROUND 15 -----\n",
      "-- Average Train loss 1.002\n",
      "-- Global Acc: 9.200, Global Best Acc: 9.390\n",
      "----- ROUND 16 -----\n",
      "-- Average Train loss 0.929\n",
      "-- Global Acc: 8.220, Global Best Acc: 9.390\n",
      "----- ROUND 17 -----\n",
      "-- Average Train loss 1.195\n",
      "-- Global Acc: 9.980, Global Best Acc: 9.980\n",
      "----- ROUND 18 -----\n",
      "-- Average Train loss 0.993\n",
      "-- Global Acc: 11.410, Global Best Acc: 11.410\n",
      "----- ROUND 19 -----\n",
      "-- Average Train loss 1.658\n",
      "-- Global Acc: 11.470, Global Best Acc: 11.470\n",
      "----- ROUND 20 -----\n",
      "-- Average Train loss 1.301\n",
      "-- Global Acc: 9.490, Global Best Acc: 11.470\n",
      "----- ROUND 21 -----\n",
      "-- Average Train loss 0.861\n",
      "-- Global Acc: 10.270, Global Best Acc: 11.470\n",
      "----- ROUND 22 -----\n",
      "-- Average Train loss 0.974\n",
      "-- Global Acc: 11.870, Global Best Acc: 11.870\n",
      "----- ROUND 23 -----\n",
      "-- Average Train loss 1.302\n",
      "-- Global Acc: 11.810, Global Best Acc: 11.870\n",
      "----- ROUND 24 -----\n",
      "-- Average Train loss 0.991\n",
      "-- Global Acc: 13.080, Global Best Acc: 13.080\n",
      "----- ROUND 25 -----\n",
      "-- Average Train loss 0.984\n",
      "-- Global Acc: 10.660, Global Best Acc: 13.080\n",
      "----- ROUND 26 -----\n",
      "-- Average Train loss 1.406\n",
      "-- Global Acc: 10.410, Global Best Acc: 13.080\n",
      "----- ROUND 27 -----\n",
      "-- Average Train loss 1.361\n",
      "-- Global Acc: 11.580, Global Best Acc: 13.080\n",
      "----- ROUND 28 -----\n",
      "-- Average Train loss 0.903\n",
      "-- Global Acc: 11.690, Global Best Acc: 13.080\n",
      "----- ROUND 29 -----\n",
      "-- Average Train loss 0.715\n",
      "-- Global Acc: 11.200, Global Best Acc: 13.080\n",
      "----- ROUND 30 -----\n",
      "-- Average Train loss 0.679\n",
      "-- Global Acc: 13.990, Global Best Acc: 13.990\n",
      "----------------------------------------\n"
     ]
    }
   ],
   "source": [
    "print('Starting FL')\n",
    "print('-'*40)\n",
    "start = time.time()\n",
    "\n",
    "num_users_FL = args.num_users\n",
    "\n",
    "loss_train = []   # average local training loss, one entry per round\n",
    "clients_local_acc = {i: [] for i in range(num_users_FL)}  # appended to by the Testing cell\n",
    "w_locals, loss_locals = [], []\n",
    "glob_acc = []     # global-model test accuracy, one entry per round\n",
    "\n",
    "w_glob = copy.deepcopy(initial_state_dict)\n",
    "\n",
    "# Number of clients sampled per round (always at least one).\n",
    "m = max(int(args.frac * num_users_FL), 1)\n",
    "\n",
    "for iteration in range(args.rounds):\n",
    "\n",
    "    # Sample m distinct clients uniformly at random for this round.\n",
    "    idxs_users = np.random.choice(range(num_users_FL), m, replace=False)\n",
    "\n",
    "    print(f'----- ROUND {iteration+1} -----')\n",
    "    # synchronize() raises when CUDA is unavailable, so only call it on GPU runs.\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.synchronize()\n",
    "    sys.stdout.flush()\n",
    "    for idx in idxs_users:\n",
    "        # Every sampled client starts local training from the current global weights.\n",
    "        clients[idx].set_state_dict(copy.deepcopy(w_glob))\n",
    "\n",
    "        loss = clients[idx].train(is_print=False)\n",
    "        loss_locals.append(copy.deepcopy(loss))\n",
    "\n",
    "    # print loss\n",
    "    loss_avg = sum(loss_locals) / len(loss_locals)\n",
    "    template = '-- Average Train loss {:.3f}'\n",
    "    print(template.format(loss_avg))\n",
    "\n",
    "    ####### FedAvg ####### START\n",
    "    # Weight each sampled client by its share of this round's training indices.\n",
    "    total_data_points = sum([len(partitions_train[r]) for r in idxs_users])\n",
    "    fed_avg_freqs = [len(partitions_train[r]) / total_data_points for r in idxs_users]\n",
    "    w_locals = [copy.deepcopy(clients[idx].get_state_dict()) for idx in idxs_users]\n",
    "\n",
    "    ww = AvgWeights(w_locals, weight_avg=fed_avg_freqs)\n",
    "    w_glob = copy.deepcopy(ww)\n",
    "    net_glob.load_state_dict(copy.deepcopy(ww))\n",
    "\n",
    "    ####### FedAvg ####### END\n",
    "    _, acc = eval_test(net_glob, args, test_dl_global1)\n",
    "\n",
    "    glob_acc.append(acc)\n",
    "    template = \"-- Global Acc: {:.3f}, Global Best Acc: {:.3f}\"\n",
    "    print(template.format(glob_acc[-1], np.max(glob_acc)))\n",
    "\n",
    "    loss_train.append(loss_avg)\n",
    "\n",
    "    ## clear the placeholders for the next round\n",
    "    loss_locals.clear()\n",
    "\n",
    "    ## calling garbage collector\n",
    "    gc.collect()\n",
    "\n",
    "end = time.time()\n",
    "duration = end-start\n",
    "print('-'*40)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "*************************\n",
      "---- Testing Final Local Results ----\n",
      "Client   0, Final_acc 30.84, best_acc 30.84 \n",
      "\n",
      "Client   1, Final_acc 22.35, best_acc 22.35 \n",
      "\n",
      "Client   2, Final_acc 23.07, best_acc 23.07 \n",
      "\n",
      "Client   3, Final_acc 22.61, best_acc 22.61 \n",
      "\n",
      "Client   4, Final_acc 22.88, best_acc 22.88 \n",
      "\n",
      "Client   5, Final_acc 32.35, best_acc 32.35 \n",
      "\n",
      "Client   6, Final_acc 26.34, best_acc 26.34 \n",
      "\n",
      "Client   7, Final_acc 26.61, best_acc 26.61 \n",
      "\n",
      "Client   8, Final_acc 22.88, best_acc 22.88 \n",
      "\n",
      "Client   9, Final_acc 15.32, best_acc 15.32 \n",
      "\n",
      "Client  10, Final_acc 21.77, best_acc 21.77 \n",
      "\n",
      "Client  11, Final_acc 28.73, best_acc 28.73 \n",
      "\n",
      "Client  12, Final_acc 30.41, best_acc 30.41 \n",
      "\n",
      "Client  13, Final_acc 24.26, best_acc 24.26 \n",
      "\n",
      "Client  14, Final_acc 25.30, best_acc 25.30 \n",
      "\n",
      "Client  15, Final_acc 31.19, best_acc 31.19 \n",
      "\n",
      "Client  16, Final_acc 28.72, best_acc 28.72 \n",
      "\n",
      "Client  17, Final_acc 5.85, best_acc 5.85 \n",
      "\n",
      "Client  18, Final_acc 21.33, best_acc 21.33 \n",
      "\n",
      "Client  19, Final_acc 9.46, best_acc 9.46 \n",
      "\n",
      "-- Avg Local Acc: 23.61\n",
      "-- Avg Best Local Acc: 23.61\n",
      "*************************\n",
      "----------------------------------------\n",
      "FINAL RESULTS\n",
      "-- Global Acc Final: 13.99\n",
      "-- Global Acc Avg Final [N*C] Rounds: 12.59\n",
      "-- Global Best Acc: 13.99\n",
      "-- Avg Local Acc: 23.61\n",
      "-- Avg Best Local Acc: 23.61\n",
      "-- FL Time: 10.61 minutes\n",
      "----------------------------------------\n"
     ]
    }
   ],
   "source": [
    "print('*'*25)\n",
    "print('---- Testing Final Local Results ----')\n",
    "temp_acc = []\n",
    "temp_best_acc = []\n",
    "\n",
    "# Evaluate every client's current local model and record the accuracy in its history.\n",
    "# NOTE(review): this appends on every execution, so re-running the cell changes best_acc.\n",
    "for k in range(num_users_FL):\n",
    "    sys.stdout.flush()\n",
    "    loss, acc = clients[k].eval_test()\n",
    "    history = clients_local_acc[k]\n",
    "    history.append(acc)\n",
    "    final_acc, best_acc = history[-1], np.max(history)\n",
    "    temp_acc.append(final_acc)\n",
    "    temp_best_acc.append(best_acc)\n",
    "    print(\"Client {:3d}, Final_acc {:3.2f}, best_acc {:3.2f} \\n\".format(k, final_acc, best_acc))\n",
    "\n",
    "# Summary lines are printed twice (here and in FINAL RESULTS), so format them once.\n",
    "avg_local_line = \"-- Avg Local Acc: {:3.2f}\".format(np.mean(temp_acc))\n",
    "avg_best_line = \"-- Avg Best Local Acc: {:3.2f}\".format(np.mean(temp_best_acc))\n",
    "print(avg_local_line)\n",
    "print(avg_best_line)\n",
    "print('*'*25)\n",
    "############################### FedAvg Final Results\n",
    "print('-'*40)\n",
    "print('FINAL RESULTS')\n",
    "print(\"-- Global Acc Final: {:.2f}\".format(glob_acc[-1]))\n",
    "# m is the number of clients sampled per round (N*C), so this averages the last m rounds.\n",
    "print(\"-- Global Acc Avg Final [N*C] Rounds: {:.2f}\".format(np.mean(glob_acc[-m:])))\n",
    "print(\"-- Global Best Acc: {:.2f}\".format(np.max(glob_acc)))\n",
    "print(avg_local_line)\n",
    "print(avg_best_line)\n",
    "print(f'-- FL Time: {duration/60:.2f} minutes')\n",
    "print('-'*40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
