{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5909348f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "# matplotlib.use('Agg')\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "import numpy as np\n",
    "from torchvision import datasets, transforms\n",
    "import torch\n",
    "\n",
    "from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid, cifar_noniid_shared\n",
    "from utils.options import args_parser\n",
    "from models.Update import LocalUpdate\n",
    "from models.Nets import MLP, CNNMnist, CNNCifar, LeNet, CNNMnist2\n",
    "from models.Fed import FedAvg\n",
    "from models.Fed import FedQAvg, Quantization, Quantization_Finite, my_score, my_score_Finite\n",
    "from models.test import test_img\n",
    "\n",
    "\n",
    "import math\n",
    "\n",
    "\n",
    "# from sympy import * \n",
    "from utils.functions import *\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ee2d7891",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "cuda\n"
     ]
    }
   ],
   "source": [
    "class my_argument:    \n",
    "    epochs = 200    #\"rounds of training\"\n",
    "    num_users = 120  # \"number of users: K\"\n",
    "    frac = 0.5 #\"the fraction of clients: C\"\n",
    "    local_ep=1 #\"the number of local epochs: E\"\n",
    "    local_bs=50 #\"local batch size: B\"\n",
    "    bs=50 #\"test batch size\"\n",
    "    lr=0.001 #\"learning rate\"\n",
    "    momentum=0.9 # \"SGD momentum (default: 0.5)\"\n",
    "    split='user' # \"train-test split type, user or sample\"\n",
    "    weight_decay = 5e-4\n",
    "    opt = 'SGD' #'ADAM'\n",
    "    loss = 'Cross'\n",
    "\n",
    "    # model arguments\n",
    "    model = 'cnn'\n",
    "    kernel_num=9 #, help='number of each kind of kernel')\n",
    "    kernel_sizes='3,4,5' #  help='comma-separated kernel size to use for convolution')\n",
    "    norm='batch_norm' #, help=\"batch_norm, layer_norm, or None\")\n",
    "    num_filters=32 #, help=\"number of filters for conv nets\")\n",
    "    max_pool='True' #help=\"Whether use max pooling rather than strided convolutions\")\n",
    "\n",
    "    # other arguments\n",
    "    dataset='cifar' #, help=\"name of dataset\")\n",
    "    iid=1\n",
    "    num_classes=10#, help=\"number of classes\")\n",
    "    num_channels=1#, help=\"number of channels of imges\")\n",
    "    gpu=1#, help=\"GPU ID, -1 for CPU\")\n",
    "    stopping_rounds=10#, help='rounds of early stopping')\n",
    "    verbose='False'#, help='verbose print')\n",
    "    seed=1#, help='random seed (default: 1)')\n",
    "    \n",
    "args = my_argument()\n",
    "\n",
    "args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "print(use_cuda)\n",
    "args.device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "print(args.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "83cbc7aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])\n",
    "dataset_train = datasets.CIFAR10('./data/cifar', train=True, download=True, transform=trans_cifar)\n",
    "dataset_test = datasets.CIFAR10('./data/cifar', train=False, download=True, transform=trans_cifar)\n",
    "\n",
    "dict_users = cifar_iid(dataset_train, args.num_users)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c5f9addf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.1 0.3 0.1 0.5 0.1 0.1 0.3 0.1 0.5 0.5 0.1 0.4 0.5 0.3 0.4 0.4 0.5 0.1\n",
      " 0.3 0.4 0.1 0.1 0.2 0.5 0.1 0.3 0.2 0.2 0.3 0.2 0.5 0.1 0.1 0.5 0.2 0.2\n",
      " 0.2 0.5 0.3 0.2 0.2 0.5 0.2 0.2 0.3 0.4 0.1 0.5 0.4 0.1 0.1 0.1 0.2 0.5\n",
      " 0.2 0.5 0.2 0.4 0.5 0.5 0.2 0.1 0.5 0.3 0.5 0.2 0.1 0.1 0.5 0.3 0.5 0.2\n",
      " 0.5 0.2 0.2 0.3 0.5 0.5 0.5 0.4 0.1 0.2 0.1 0.1 0.5 0.5 0.2 0.2 0.4 0.2\n",
      " 0.3 0.2 0.2 0.4 0.2 0.3 0.2 0.1 0.1 0.4 0.5 0.1 0.2 0.5 0.2 0.4 0.5 0.3\n",
      " 0.2 0.1 0.1 0.1 0.1 0.3 0.4 0.5 0.5 0.5 0.3 0.1]\n"
     ]
    }
   ],
   "source": [
    "N = args.num_users\n",
    "\n",
    "p_array = np.array([0.1, 0.2, 0.3, 0.4, 0.5])\n",
    "# p_array = np.array([0])\n",
    "\n",
    "# print(p_matrix)\n",
    "p_sel = np.random.randint(low=0, high=len(p_array), size=(N,))\n",
    "\n",
    "p_per_user = np.ones((N,))\n",
    "\n",
    "for i in range(N):\n",
    "    p_per_user[i] = p_array[p_sel[i]]\n",
    "\n",
    "print(p_per_user)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "710ad02a",
   "metadata": {},
   "source": [
    "# 1. Random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0c4d1e1b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Learning Rate = 0.001\n",
      "\n",
      "VGG_origin(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (2): ReLU(inplace=True)\n",
      "    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (6): ReLU(inplace=True)\n",
      "    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (10): ReLU(inplace=True)\n",
      "    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (13): ReLU(inplace=True)\n",
      "    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (16): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (17): ReLU(inplace=True)\n",
      "    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (19): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (20): ReLU(inplace=True)\n",
      "    (21): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (22): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (23): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (24): ReLU(inplace=True)\n",
      "    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (26): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (27): ReLU(inplace=True)\n",
      "    (28): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (29): AvgPool2d(kernel_size=1, stride=1, padding=0)\n",
      "  )\n",
      "  (classifier): Linear(in_features=512, out_features=10, bias=True)\n",
      ")\n",
      "\n",
      "Test set: Average loss: 2.3012 \n",
      "Accuracy: 1097/10000 (10.97%)\n",
      "\n",
      "Round   0, Train average loss 2.393 Test accuracy 10.970\n",
      "\n",
      "Test set: Average loss: 2.3203 \n",
      "Accuracy: 986/10000 (9.86%)\n",
      "\n",
      "Round   1, Train average loss 2.523 Test accuracy 9.860\n",
      "\n",
      "Test set: Average loss: 3.5198 \n",
      "Accuracy: 1985/10000 (19.85%)\n",
      "\n",
      "Round   2, Train average loss 2.478 Test accuracy 19.850\n",
      "\n",
      "Test set: Average loss: 2.4684 \n",
      "Accuracy: 2091/10000 (20.91%)\n",
      "\n",
      "Round   3, Train average loss 2.534 Test accuracy 20.910\n",
      "\n",
      "Test set: Average loss: 3.4832 \n",
      "Accuracy: 1991/10000 (19.91%)\n",
      "\n",
      "Round   4, Train average loss 2.438 Test accuracy 19.910\n",
      "\n",
      "Test set: Average loss: 2.3842 \n",
      "Accuracy: 2576/10000 (25.76%)\n",
      "\n",
      "Round   5, Train average loss 2.279 Test accuracy 25.760\n",
      "\n",
      "Test set: Average loss: 3.8084 \n",
      "Accuracy: 1741/10000 (17.41%)\n",
      "\n",
      "Round   6, Train average loss 2.202 Test accuracy 17.410\n",
      "\n",
      "Test set: Average loss: 2.2540 \n",
      "Accuracy: 2811/10000 (28.11%)\n",
      "\n",
      "Round   7, Train average loss 2.079 Test accuracy 28.110\n",
      "\n",
      "Test set: Average loss: 1.7314 \n",
      "Accuracy: 3672/10000 (36.72%)\n",
      "\n",
      "Round   8, Train average loss 2.008 Test accuracy 36.720\n",
      "\n",
      "Test set: Average loss: 1.7380 \n",
      "Accuracy: 3485/10000 (34.85%)\n",
      "\n",
      "Round   9, Train average loss 1.954 Test accuracy 34.850\n",
      "\n",
      "Test set: Average loss: 1.6223 \n",
      "Accuracy: 3940/10000 (39.40%)\n",
      "\n",
      "Round  10, Train average loss 1.906 Test accuracy 39.400\n",
      "\n",
      "Test set: Average loss: 1.9947 \n",
      "Accuracy: 3221/10000 (32.21%)\n",
      "\n",
      "Round  11, Train average loss 1.882 Test accuracy 32.210\n",
      "\n",
      "Test set: Average loss: 1.5678 \n",
      "Accuracy: 3982/10000 (39.82%)\n",
      "\n",
      "Round  12, Train average loss 1.859 Test accuracy 39.820\n",
      "\n",
      "Test set: Average loss: 1.7032 \n",
      "Accuracy: 3753/10000 (37.53%)\n",
      "\n",
      "Round  13, Train average loss 1.801 Test accuracy 37.530\n",
      "\n",
      "Test set: Average loss: 1.6600 \n",
      "Accuracy: 3644/10000 (36.44%)\n",
      "\n",
      "Round  14, Train average loss 1.868 Test accuracy 36.440\n",
      "\n",
      "Test set: Average loss: 1.6955 \n",
      "Accuracy: 3590/10000 (35.90%)\n",
      "\n",
      "Round  15, Train average loss 1.774 Test accuracy 35.900\n",
      "\n",
      "Test set: Average loss: 1.6403 \n",
      "Accuracy: 3900/10000 (39.00%)\n",
      "\n",
      "Round  16, Train average loss 1.762 Test accuracy 39.000\n",
      "\n",
      "Test set: Average loss: 1.4862 \n",
      "Accuracy: 4457/10000 (44.57%)\n",
      "\n",
      "Round  17, Train average loss 1.734 Test accuracy 44.570\n",
      "\n",
      "Test set: Average loss: 1.5678 \n",
      "Accuracy: 3975/10000 (39.75%)\n",
      "\n",
      "Round  18, Train average loss 1.711 Test accuracy 39.750\n",
      "\n",
      "Test set: Average loss: 1.5287 \n",
      "Accuracy: 4201/10000 (42.01%)\n",
      "\n",
      "Round  19, Train average loss 1.679 Test accuracy 42.010\n",
      "\n",
      "Test set: Average loss: 1.4882 \n",
      "Accuracy: 4561/10000 (45.61%)\n",
      "\n",
      "Round  20, Train average loss 1.632 Test accuracy 45.610\n",
      "\n",
      "Test set: Average loss: 1.4800 \n",
      "Accuracy: 4519/10000 (45.19%)\n",
      "\n",
      "Round  21, Train average loss 1.582 Test accuracy 45.190\n",
      "\n",
      "Test set: Average loss: 1.3670 \n",
      "Accuracy: 4960/10000 (49.60%)\n",
      "\n",
      "Round  22, Train average loss 1.606 Test accuracy 49.600\n",
      "\n",
      "Test set: Average loss: 1.4036 \n",
      "Accuracy: 4726/10000 (47.26%)\n",
      "\n",
      "Round  23, Train average loss 1.556 Test accuracy 47.260\n",
      "\n",
      "Test set: Average loss: 1.3149 \n",
      "Accuracy: 5181/10000 (51.81%)\n",
      "\n",
      "Round  24, Train average loss 1.610 Test accuracy 51.810\n",
      "\n",
      "Test set: Average loss: 1.3479 \n",
      "Accuracy: 4920/10000 (49.20%)\n",
      "\n",
      "Round  25, Train average loss 1.507 Test accuracy 49.200\n",
      "\n",
      "Test set: Average loss: 1.3320 \n",
      "Accuracy: 4861/10000 (48.61%)\n",
      "\n",
      "Round  26, Train average loss 1.548 Test accuracy 48.610\n",
      "\n",
      "Test set: Average loss: 1.3153 \n",
      "Accuracy: 5079/10000 (50.79%)\n",
      "\n",
      "Round  27, Train average loss 1.510 Test accuracy 50.790\n",
      "\n",
      "Test set: Average loss: 1.3532 \n",
      "Accuracy: 4936/10000 (49.36%)\n",
      "\n",
      "Round  28, Train average loss 1.489 Test accuracy 49.360\n",
      "\n",
      "Test set: Average loss: 1.2222 \n",
      "Accuracy: 5434/10000 (54.34%)\n",
      "\n",
      "Round  29, Train average loss 1.488 Test accuracy 54.340\n",
      "\n",
      "Test set: Average loss: 1.3110 \n",
      "Accuracy: 5083/10000 (50.83%)\n",
      "\n",
      "Round  30, Train average loss 1.429 Test accuracy 50.830\n",
      "\n",
      "Test set: Average loss: 1.3022 \n",
      "Accuracy: 5246/10000 (52.46%)\n",
      "\n",
      "Round  31, Train average loss 1.442 Test accuracy 52.460\n",
      "\n",
      "Test set: Average loss: 1.2436 \n",
      "Accuracy: 5370/10000 (53.70%)\n",
      "\n",
      "Round  32, Train average loss 1.433 Test accuracy 53.700\n",
      "\n",
      "Test set: Average loss: 1.2583 \n",
      "Accuracy: 5401/10000 (54.01%)\n",
      "\n",
      "Round  33, Train average loss 1.394 Test accuracy 54.010\n",
      "\n",
      "Test set: Average loss: 1.2667 \n",
      "Accuracy: 5311/10000 (53.11%)\n",
      "\n",
      "Round  34, Train average loss 1.384 Test accuracy 53.110\n",
      "\n",
      "Test set: Average loss: 1.2561 \n",
      "Accuracy: 5319/10000 (53.19%)\n",
      "\n",
      "Round  35, Train average loss 1.392 Test accuracy 53.190\n",
      "\n",
      "Test set: Average loss: 1.1855 \n",
      "Accuracy: 5679/10000 (56.79%)\n",
      "\n",
      "Round  36, Train average loss 1.378 Test accuracy 56.790\n",
      "\n",
      "Test set: Average loss: 1.2776 \n",
      "Accuracy: 5348/10000 (53.48%)\n",
      "\n",
      "Round  37, Train average loss 1.343 Test accuracy 53.480\n",
      "\n",
      "Test set: Average loss: 1.2255 \n",
      "Accuracy: 5489/10000 (54.89%)\n",
      "\n",
      "Round  38, Train average loss 1.317 Test accuracy 54.890\n",
      "\n",
      "Test set: Average loss: 1.2647 \n",
      "Accuracy: 5646/10000 (56.46%)\n",
      "\n",
      "Round  39, Train average loss 1.301 Test accuracy 56.460\n",
      "\n",
      "Test set: Average loss: 1.1895 \n",
      "Accuracy: 5618/10000 (56.18%)\n",
      "\n",
      "Round  40, Train average loss 1.278 Test accuracy 56.180\n",
      "\n",
      "Test set: Average loss: 1.1888 \n",
      "Accuracy: 5833/10000 (58.33%)\n",
      "\n",
      "Round  41, Train average loss 1.259 Test accuracy 58.330\n",
      "\n",
      "Test set: Average loss: 1.2777 \n",
      "Accuracy: 5399/10000 (53.99%)\n",
      "\n",
      "Round  42, Train average loss 1.319 Test accuracy 53.990\n",
      "\n",
      "Test set: Average loss: 1.2635 \n",
      "Accuracy: 5533/10000 (55.33%)\n",
      "\n",
      "Round  43, Train average loss 1.286 Test accuracy 55.330\n",
      "\n",
      "Test set: Average loss: 1.1901 \n",
      "Accuracy: 5671/10000 (56.71%)\n",
      "\n",
      "Round  44, Train average loss 1.240 Test accuracy 56.710\n",
      "\n",
      "Test set: Average loss: 1.1149 \n",
      "Accuracy: 6003/10000 (60.03%)\n",
      "\n",
      "Round  45, Train average loss 1.203 Test accuracy 60.030\n",
      "\n",
      "Test set: Average loss: 1.1028 \n",
      "Accuracy: 5984/10000 (59.84%)\n",
      "\n",
      "Round  46, Train average loss 1.240 Test accuracy 59.840\n",
      "\n",
      "Test set: Average loss: 1.2846 \n",
      "Accuracy: 5427/10000 (54.27%)\n",
      "\n",
      "Round  47, Train average loss 1.215 Test accuracy 54.270\n",
      "\n",
      "Test set: Average loss: 1.0531 \n",
      "Accuracy: 6289/10000 (62.89%)\n",
      "\n",
      "Round  48, Train average loss 1.226 Test accuracy 62.890\n",
      "\n",
      "Test set: Average loss: 1.1148 \n",
      "Accuracy: 6010/10000 (60.10%)\n",
      "\n",
      "Round  49, Train average loss 1.173 Test accuracy 60.100\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Average loss: 1.1382 \n",
      "Accuracy: 5901/10000 (59.01%)\n",
      "\n",
      "Round  50, Train average loss 1.173 Test accuracy 59.010\n",
      "\n",
      "Test set: Average loss: 0.9933 \n",
      "Accuracy: 6430/10000 (64.30%)\n",
      "\n",
      "Round  51, Train average loss 1.161 Test accuracy 64.300\n",
      "\n",
      "Test set: Average loss: 1.0098 \n",
      "Accuracy: 6399/10000 (63.99%)\n",
      "\n",
      "Round  52, Train average loss 1.123 Test accuracy 63.990\n",
      "\n",
      "Test set: Average loss: 1.0756 \n",
      "Accuracy: 6133/10000 (61.33%)\n",
      "\n",
      "Round  53, Train average loss 1.186 Test accuracy 61.330\n",
      "\n",
      "Test set: Average loss: 1.0359 \n",
      "Accuracy: 6292/10000 (62.92%)\n",
      "\n",
      "Round  54, Train average loss 1.125 Test accuracy 62.920\n",
      "\n",
      "Test set: Average loss: 1.0521 \n",
      "Accuracy: 6313/10000 (63.13%)\n",
      "\n",
      "Round  55, Train average loss 1.133 Test accuracy 63.130\n",
      "\n",
      "Test set: Average loss: 1.0811 \n",
      "Accuracy: 6088/10000 (60.88%)\n",
      "\n",
      "Round  56, Train average loss 1.093 Test accuracy 60.880\n",
      "\n",
      "Test set: Average loss: 1.0955 \n",
      "Accuracy: 6143/10000 (61.43%)\n",
      "\n",
      "Round  57, Train average loss 1.099 Test accuracy 61.430\n",
      "\n",
      "Test set: Average loss: 1.1485 \n",
      "Accuracy: 5938/10000 (59.38%)\n",
      "\n",
      "Round  58, Train average loss 1.103 Test accuracy 59.380\n",
      "\n",
      "Test set: Average loss: 1.0134 \n",
      "Accuracy: 6371/10000 (63.71%)\n",
      "\n",
      "Round  59, Train average loss 1.052 Test accuracy 63.710\n",
      "\n",
      "Test set: Average loss: 1.0976 \n",
      "Accuracy: 6270/10000 (62.70%)\n",
      "\n",
      "Round  60, Train average loss 1.054 Test accuracy 62.700\n",
      "\n",
      "Test set: Average loss: 1.0142 \n",
      "Accuracy: 6430/10000 (64.30%)\n",
      "\n",
      "Round  61, Train average loss 1.098 Test accuracy 64.300\n",
      "\n",
      "Test set: Average loss: 1.0836 \n",
      "Accuracy: 6264/10000 (62.64%)\n",
      "\n",
      "Round  62, Train average loss 1.041 Test accuracy 62.640\n",
      "\n",
      "Test set: Average loss: 1.0273 \n",
      "Accuracy: 6364/10000 (63.64%)\n",
      "\n",
      "Round  63, Train average loss 1.060 Test accuracy 63.640\n",
      "\n",
      "Test set: Average loss: 1.0195 \n",
      "Accuracy: 6488/10000 (64.88%)\n",
      "\n",
      "Round  64, Train average loss 1.088 Test accuracy 64.880\n",
      "\n",
      "Test set: Average loss: 0.9767 \n",
      "Accuracy: 6569/10000 (65.69%)\n",
      "\n",
      "Round  65, Train average loss 0.961 Test accuracy 65.690\n",
      "\n",
      "Test set: Average loss: 1.0299 \n",
      "Accuracy: 6490/10000 (64.90%)\n",
      "\n",
      "Round  66, Train average loss 1.064 Test accuracy 64.900\n",
      "\n",
      "Test set: Average loss: 1.0390 \n",
      "Accuracy: 6469/10000 (64.69%)\n",
      "\n",
      "Round  67, Train average loss 1.017 Test accuracy 64.690\n",
      "\n",
      "Test set: Average loss: 1.0071 \n",
      "Accuracy: 6379/10000 (63.79%)\n",
      "\n",
      "Round  68, Train average loss 0.957 Test accuracy 63.790\n",
      "\n",
      "Test set: Average loss: 0.9726 \n",
      "Accuracy: 6597/10000 (65.97%)\n",
      "\n",
      "Round  69, Train average loss 0.998 Test accuracy 65.970\n",
      "\n",
      "Test set: Average loss: 0.9787 \n",
      "Accuracy: 6646/10000 (66.46%)\n",
      "\n",
      "Round  70, Train average loss 0.995 Test accuracy 66.460\n",
      "\n",
      "Test set: Average loss: 0.9432 \n",
      "Accuracy: 6760/10000 (67.60%)\n",
      "\n",
      "Round  71, Train average loss 0.941 Test accuracy 67.600\n",
      "\n",
      "Test set: Average loss: 0.9239 \n",
      "Accuracy: 6847/10000 (68.47%)\n",
      "\n",
      "Round  72, Train average loss 0.910 Test accuracy 68.470\n",
      "\n",
      "Test set: Average loss: 0.9349 \n",
      "Accuracy: 6791/10000 (67.91%)\n",
      "\n",
      "Round  73, Train average loss 0.981 Test accuracy 67.910\n",
      "\n",
      "Test set: Average loss: 0.9254 \n",
      "Accuracy: 6856/10000 (68.56%)\n",
      "\n",
      "Round  74, Train average loss 0.959 Test accuracy 68.560\n",
      "\n",
      "Test set: Average loss: 0.9655 \n",
      "Accuracy: 6696/10000 (66.96%)\n",
      "\n",
      "Round  75, Train average loss 0.907 Test accuracy 66.960\n",
      "\n",
      "Test set: Average loss: 1.0057 \n",
      "Accuracy: 6574/10000 (65.74%)\n",
      "\n",
      "Round  76, Train average loss 0.933 Test accuracy 65.740\n",
      "\n",
      "Test set: Average loss: 1.0361 \n",
      "Accuracy: 6558/10000 (65.58%)\n",
      "\n",
      "Round  77, Train average loss 0.907 Test accuracy 65.580\n",
      "\n",
      "Test set: Average loss: 0.9777 \n",
      "Accuracy: 6763/10000 (67.63%)\n",
      "\n",
      "Round  78, Train average loss 0.900 Test accuracy 67.630\n",
      "\n",
      "Test set: Average loss: 1.0526 \n",
      "Accuracy: 6483/10000 (64.83%)\n",
      "\n",
      "Round  79, Train average loss 0.931 Test accuracy 64.830\n",
      "\n",
      "Test set: Average loss: 1.0091 \n",
      "Accuracy: 6640/10000 (66.40%)\n",
      "\n",
      "Round  80, Train average loss 0.906 Test accuracy 66.400\n",
      "\n",
      "Test set: Average loss: 1.0297 \n",
      "Accuracy: 6590/10000 (65.90%)\n",
      "\n",
      "Round  81, Train average loss 0.897 Test accuracy 65.900\n",
      "\n",
      "Test set: Average loss: 0.9016 \n",
      "Accuracy: 6969/10000 (69.69%)\n",
      "\n",
      "Round  82, Train average loss 0.896 Test accuracy 69.690\n",
      "\n",
      "Test set: Average loss: 0.9770 \n",
      "Accuracy: 6689/10000 (66.89%)\n",
      "\n",
      "Round  83, Train average loss 0.885 Test accuracy 66.890\n",
      "\n",
      "Test set: Average loss: 0.9872 \n",
      "Accuracy: 6788/10000 (67.88%)\n",
      "\n",
      "Round  84, Train average loss 0.839 Test accuracy 67.880\n",
      "\n",
      "Test set: Average loss: 0.9908 \n",
      "Accuracy: 6776/10000 (67.76%)\n",
      "\n",
      "Round  85, Train average loss 0.870 Test accuracy 67.760\n",
      "\n",
      "Test set: Average loss: 1.0159 \n",
      "Accuracy: 6755/10000 (67.55%)\n",
      "\n",
      "Round  86, Train average loss 0.876 Test accuracy 67.550\n",
      "\n",
      "Test set: Average loss: 0.9937 \n",
      "Accuracy: 6793/10000 (67.93%)\n",
      "\n",
      "Round  87, Train average loss 0.802 Test accuracy 67.930\n",
      "\n",
      "Test set: Average loss: 0.9156 \n",
      "Accuracy: 7020/10000 (70.20%)\n",
      "\n",
      "Round  88, Train average loss 0.857 Test accuracy 70.200\n",
      "\n",
      "Test set: Average loss: 0.9222 \n",
      "Accuracy: 6988/10000 (69.88%)\n",
      "\n",
      "Round  89, Train average loss 0.840 Test accuracy 69.880\n",
      "\n",
      "Test set: Average loss: 0.9038 \n",
      "Accuracy: 7015/10000 (70.15%)\n",
      "\n",
      "Round  90, Train average loss 0.814 Test accuracy 70.150\n",
      "\n",
      "Test set: Average loss: 0.9675 \n",
      "Accuracy: 6847/10000 (68.47%)\n",
      "\n",
      "Round  91, Train average loss 0.810 Test accuracy 68.470\n",
      "\n",
      "Test set: Average loss: 1.0484 \n",
      "Accuracy: 6762/10000 (67.62%)\n",
      "\n",
      "Round  92, Train average loss 0.802 Test accuracy 67.620\n",
      "\n",
      "Test set: Average loss: 1.0982 \n",
      "Accuracy: 6683/10000 (66.83%)\n",
      "\n",
      "Round  93, Train average loss 0.846 Test accuracy 66.830\n",
      "\n",
      "Test set: Average loss: 1.0814 \n",
      "Accuracy: 6680/10000 (66.80%)\n",
      "\n",
      "Round  94, Train average loss 0.813 Test accuracy 66.800\n",
      "\n",
      "Test set: Average loss: 0.9299 \n",
      "Accuracy: 6984/10000 (69.84%)\n",
      "\n",
      "Round  95, Train average loss 0.750 Test accuracy 69.840\n",
      "\n",
      "Test set: Average loss: 1.0043 \n",
      "Accuracy: 6813/10000 (68.13%)\n",
      "\n",
      "Round  96, Train average loss 0.794 Test accuracy 68.130\n",
      "\n",
      "Test set: Average loss: 1.0673 \n",
      "Accuracy: 6675/10000 (66.75%)\n",
      "\n",
      "Round  97, Train average loss 0.800 Test accuracy 66.750\n",
      "\n",
      "Test set: Average loss: 1.2997 \n",
      "Accuracy: 6436/10000 (64.36%)\n",
      "\n",
      "Round  98, Train average loss 0.811 Test accuracy 64.360\n",
      "\n",
      "Test set: Average loss: 0.9624 \n",
      "Accuracy: 6949/10000 (69.49%)\n",
      "\n",
      "Round  99, Train average loss 0.722 Test accuracy 69.490\n",
      "\n",
      "Test set: Average loss: 0.9969 \n",
      "Accuracy: 6988/10000 (69.88%)\n",
      "\n",
      "Round 100, Train average loss 0.719 Test accuracy 69.880\n",
      "\n",
      "Test set: Average loss: 1.0392 \n",
      "Accuracy: 6869/10000 (68.69%)\n",
      "\n",
      "Round 101, Train average loss 0.731 Test accuracy 68.690\n",
      "\n",
      "Test set: Average loss: 1.2647 \n",
      "Accuracy: 6497/10000 (64.97%)\n",
      "\n",
      "Round 102, Train average loss 0.765 Test accuracy 64.970\n",
      "\n",
      "Test set: Average loss: 0.9578 \n",
      "Accuracy: 7030/10000 (70.30%)\n",
      "\n",
      "Round 103, Train average loss 0.784 Test accuracy 70.300\n",
      "\n",
      "Test set: Average loss: 1.4200 \n",
      "Accuracy: 6289/10000 (62.89%)\n",
      "\n",
      "Round 104, Train average loss 0.755 Test accuracy 62.890\n",
      "\n",
      "Test set: Average loss: 1.0184 \n",
      "Accuracy: 6864/10000 (68.64%)\n",
      "\n",
      "Round 105, Train average loss 0.735 Test accuracy 68.640\n",
      "\n",
      "Test set: Average loss: 1.0186 \n",
      "Accuracy: 6917/10000 (69.17%)\n",
      "\n",
      "Round 106, Train average loss 0.725 Test accuracy 69.170\n",
      "\n",
      "Test set: Average loss: 1.1652 \n",
      "Accuracy: 6636/10000 (66.36%)\n",
      "\n",
      "Round 107, Train average loss 0.697 Test accuracy 66.360\n",
      "\n",
      "Test set: Average loss: 1.1238 \n",
      "Accuracy: 6763/10000 (67.63%)\n",
      "\n",
      "Round 108, Train average loss 0.719 Test accuracy 67.630\n",
      "\n",
      "Test set: Average loss: 1.1623 \n",
      "Accuracy: 6603/10000 (66.03%)\n",
      "\n",
      "Round 109, Train average loss 0.689 Test accuracy 66.030\n",
      "\n",
      "Test set: Average loss: 1.0474 \n",
      "Accuracy: 6932/10000 (69.32%)\n",
      "\n",
      "Round 110, Train average loss 0.695 Test accuracy 69.320\n",
      "\n",
      "Test set: Average loss: 1.0777 \n",
      "Accuracy: 6883/10000 (68.83%)\n",
      "\n",
      "Round 111, Train average loss 0.691 Test accuracy 68.830\n",
      "\n",
      "Test set: Average loss: 1.2844 \n",
      "Accuracy: 6553/10000 (65.53%)\n",
      "\n",
      "Round 112, Train average loss 0.693 Test accuracy 65.530\n",
      "\n",
      "Test set: Average loss: 1.2487 \n",
      "Accuracy: 6729/10000 (67.29%)\n",
      "\n",
      "Round 113, Train average loss 0.687 Test accuracy 67.290\n",
      "\n",
      "Test set: Average loss: 1.0285 \n",
      "Accuracy: 7103/10000 (71.03%)\n",
      "\n",
      "Round 114, Train average loss 0.703 Test accuracy 71.030\n",
      "\n",
      "Test set: Average loss: 1.1589 \n",
      "Accuracy: 6694/10000 (66.94%)\n",
      "\n",
      "Round 115, Train average loss 0.624 Test accuracy 66.940\n",
      "\n",
      "Test set: Average loss: 1.1023 \n",
      "Accuracy: 6930/10000 (69.30%)\n",
      "\n",
      "Round 116, Train average loss 0.673 Test accuracy 69.300\n",
      "\n",
      "Test set: Average loss: 0.9551 \n",
      "Accuracy: 7148/10000 (71.48%)\n",
      "\n",
      "Round 117, Train average loss 0.679 Test accuracy 71.480\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Average loss: 1.1079 \n",
      "Accuracy: 6933/10000 (69.33%)\n",
      "\n",
      "Round 118, Train average loss 0.707 Test accuracy 69.330\n",
      "\n",
      "Test set: Average loss: 1.1082 \n",
      "Accuracy: 7000/10000 (70.00%)\n",
      "\n",
      "Round 119, Train average loss 0.671 Test accuracy 70.000\n",
      "\n",
      "Test set: Average loss: 1.1881 \n",
      "Accuracy: 6765/10000 (67.65%)\n",
      "\n",
      "Round 120, Train average loss 0.669 Test accuracy 67.650\n",
      "\n",
      "Test set: Average loss: 1.0386 \n",
      "Accuracy: 7054/10000 (70.54%)\n",
      "\n",
      "Round 121, Train average loss 0.610 Test accuracy 70.540\n",
      "\n",
      "Test set: Average loss: 1.1572 \n",
      "Accuracy: 6883/10000 (68.83%)\n",
      "\n",
      "Round 122, Train average loss 0.656 Test accuracy 68.830\n",
      "\n",
      "Test set: Average loss: 0.9779 \n",
      "Accuracy: 7198/10000 (71.98%)\n",
      "\n",
      "Round 123, Train average loss 0.623 Test accuracy 71.980\n",
      "\n",
      "Test set: Average loss: 0.9962 \n",
      "Accuracy: 7200/10000 (72.00%)\n",
      "\n",
      "Round 124, Train average loss 0.591 Test accuracy 72.000\n",
      "\n",
      "Test set: Average loss: 1.0163 \n",
      "Accuracy: 7077/10000 (70.77%)\n",
      "\n",
      "Round 125, Train average loss 0.665 Test accuracy 70.770\n",
      "\n",
      "Test set: Average loss: 1.0206 \n",
      "Accuracy: 7089/10000 (70.89%)\n",
      "\n",
      "Round 126, Train average loss 0.606 Test accuracy 70.890\n",
      "\n",
      "Test set: Average loss: 1.1001 \n",
      "Accuracy: 6939/10000 (69.39%)\n",
      "\n",
      "Round 127, Train average loss 0.681 Test accuracy 69.390\n",
      "\n",
      "Test set: Average loss: 1.0129 \n",
      "Accuracy: 7157/10000 (71.57%)\n",
      "\n",
      "Round 128, Train average loss 0.655 Test accuracy 71.570\n",
      "\n",
      "Test set: Average loss: 1.0527 \n",
      "Accuracy: 7088/10000 (70.88%)\n",
      "\n",
      "Round 129, Train average loss 0.582 Test accuracy 70.880\n",
      "\n",
      "Test set: Average loss: 1.0673 \n",
      "Accuracy: 7129/10000 (71.29%)\n",
      "\n",
      "Round 130, Train average loss 0.597 Test accuracy 71.290\n",
      "\n",
      "Test set: Average loss: 1.1848 \n",
      "Accuracy: 6838/10000 (68.38%)\n",
      "\n",
      "Round 131, Train average loss 0.578 Test accuracy 68.380\n",
      "\n",
      "Test set: Average loss: 1.1263 \n",
      "Accuracy: 7124/10000 (71.24%)\n",
      "\n",
      "Round 132, Train average loss 0.594 Test accuracy 71.240\n",
      "\n",
      "Test set: Average loss: 1.0861 \n",
      "Accuracy: 7011/10000 (70.11%)\n",
      "\n",
      "Round 133, Train average loss 0.599 Test accuracy 70.110\n",
      "\n",
      "Test set: Average loss: 1.2132 \n",
      "Accuracy: 6978/10000 (69.78%)\n",
      "\n",
      "Round 134, Train average loss 0.597 Test accuracy 69.780\n",
      "\n",
      "Test set: Average loss: 1.0187 \n",
      "Accuracy: 7168/10000 (71.68%)\n",
      "\n",
      "Round 135, Train average loss 0.617 Test accuracy 71.680\n",
      "\n",
      "Test set: Average loss: 1.0507 \n",
      "Accuracy: 7120/10000 (71.20%)\n",
      "\n",
      "Round 136, Train average loss 0.556 Test accuracy 71.200\n",
      "\n",
      "Test set: Average loss: 0.9996 \n",
      "Accuracy: 7255/10000 (72.55%)\n",
      "\n",
      "Round 137, Train average loss 0.563 Test accuracy 72.550\n",
      "\n",
      "Test set: Average loss: 1.0929 \n",
      "Accuracy: 7040/10000 (70.40%)\n",
      "\n",
      "Round 138, Train average loss 0.530 Test accuracy 70.400\n",
      "\n",
      "Test set: Average loss: 0.9949 \n",
      "Accuracy: 7245/10000 (72.45%)\n",
      "\n",
      "Round 139, Train average loss 0.589 Test accuracy 72.450\n",
      "\n",
      "Test set: Average loss: 1.2674 \n",
      "Accuracy: 6866/10000 (68.66%)\n",
      "\n",
      "Round 140, Train average loss 0.537 Test accuracy 68.660\n",
      "\n",
      "Test set: Average loss: 1.1762 \n",
      "Accuracy: 6897/10000 (68.97%)\n",
      "\n",
      "Round 141, Train average loss 0.497 Test accuracy 68.970\n",
      "\n",
      "Test set: Average loss: 1.1754 \n",
      "Accuracy: 7073/10000 (70.73%)\n",
      "\n",
      "Round 142, Train average loss 0.544 Test accuracy 70.730\n",
      "\n",
      "Test set: Average loss: 1.1515 \n",
      "Accuracy: 7146/10000 (71.46%)\n",
      "\n",
      "Round 143, Train average loss 0.602 Test accuracy 71.460\n",
      "\n",
      "Test set: Average loss: 1.0264 \n",
      "Accuracy: 7354/10000 (73.54%)\n",
      "\n",
      "Round 144, Train average loss 0.512 Test accuracy 73.540\n",
      "\n",
      "Test set: Average loss: 1.2517 \n",
      "Accuracy: 6781/10000 (67.81%)\n",
      "\n",
      "Round 145, Train average loss 0.556 Test accuracy 67.810\n",
      "\n",
      "Test set: Average loss: 1.2835 \n",
      "Accuracy: 7022/10000 (70.22%)\n",
      "\n",
      "Round 146, Train average loss 0.550 Test accuracy 70.220\n",
      "\n",
      "Test set: Average loss: 1.3987 \n",
      "Accuracy: 6586/10000 (65.86%)\n",
      "\n",
      "Round 147, Train average loss 0.553 Test accuracy 65.860\n",
      "\n",
      "Test set: Average loss: 1.0719 \n",
      "Accuracy: 7201/10000 (72.01%)\n",
      "\n",
      "Round 148, Train average loss 0.554 Test accuracy 72.010\n",
      "\n",
      "Test set: Average loss: 1.2246 \n",
      "Accuracy: 6980/10000 (69.80%)\n",
      "\n",
      "Round 149, Train average loss 0.486 Test accuracy 69.800\n",
      "\n",
      "Test set: Average loss: 1.0981 \n",
      "Accuracy: 7167/10000 (71.67%)\n",
      "\n",
      "Round 150, Train average loss 0.513 Test accuracy 71.670\n",
      "\n",
      "Test set: Average loss: 1.2777 \n",
      "Accuracy: 6845/10000 (68.45%)\n",
      "\n",
      "Round 151, Train average loss 0.497 Test accuracy 68.450\n",
      "\n",
      "Test set: Average loss: 1.2570 \n",
      "Accuracy: 6966/10000 (69.66%)\n",
      "\n",
      "Round 152, Train average loss 0.453 Test accuracy 69.660\n",
      "\n",
      "Test set: Average loss: 1.2103 \n",
      "Accuracy: 7081/10000 (70.81%)\n",
      "\n",
      "Round 153, Train average loss 0.571 Test accuracy 70.810\n",
      "\n",
      "Test set: Average loss: 1.0442 \n",
      "Accuracy: 7248/10000 (72.48%)\n",
      "\n",
      "Round 154, Train average loss 0.489 Test accuracy 72.480\n",
      "\n",
      "Test set: Average loss: 1.0434 \n",
      "Accuracy: 7291/10000 (72.91%)\n",
      "\n",
      "Round 155, Train average loss 0.489 Test accuracy 72.910\n",
      "\n",
      "Test set: Average loss: 1.1477 \n",
      "Accuracy: 7161/10000 (71.61%)\n",
      "\n",
      "Round 156, Train average loss 0.522 Test accuracy 71.610\n",
      "\n",
      "Test set: Average loss: 1.3855 \n",
      "Accuracy: 6704/10000 (67.04%)\n",
      "\n",
      "Round 157, Train average loss 0.488 Test accuracy 67.040\n",
      "\n",
      "Test set: Average loss: 1.3268 \n",
      "Accuracy: 6947/10000 (69.47%)\n",
      "\n",
      "Round 158, Train average loss 0.519 Test accuracy 69.470\n",
      "\n",
      "Test set: Average loss: 1.0727 \n",
      "Accuracy: 7331/10000 (73.31%)\n",
      "\n",
      "Round 159, Train average loss 0.565 Test accuracy 73.310\n",
      "\n",
      "Test set: Average loss: 1.1045 \n",
      "Accuracy: 7137/10000 (71.37%)\n",
      "\n",
      "Round 160, Train average loss 0.528 Test accuracy 71.370\n",
      "\n",
      "Test set: Average loss: 1.0902 \n",
      "Accuracy: 7203/10000 (72.03%)\n",
      "\n",
      "Round 161, Train average loss 0.482 Test accuracy 72.030\n",
      "\n",
      "Test set: Average loss: 1.0222 \n",
      "Accuracy: 7310/10000 (73.10%)\n",
      "\n",
      "Round 162, Train average loss 0.438 Test accuracy 73.100\n",
      "\n",
      "Test set: Average loss: 1.0427 \n",
      "Accuracy: 7148/10000 (71.48%)\n",
      "\n",
      "Round 163, Train average loss 0.482 Test accuracy 71.480\n",
      "\n",
      "Test set: Average loss: 1.1728 \n",
      "Accuracy: 7025/10000 (70.25%)\n",
      "\n",
      "Round 164, Train average loss 0.490 Test accuracy 70.250\n",
      "\n",
      "Test set: Average loss: 1.0962 \n",
      "Accuracy: 7194/10000 (71.94%)\n",
      "\n",
      "Round 165, Train average loss 0.496 Test accuracy 71.940\n",
      "\n",
      "Test set: Average loss: 1.1114 \n",
      "Accuracy: 7212/10000 (72.12%)\n",
      "\n",
      "Round 166, Train average loss 0.491 Test accuracy 72.120\n",
      "\n",
      "Test set: Average loss: 1.2380 \n",
      "Accuracy: 7019/10000 (70.19%)\n",
      "\n",
      "Round 167, Train average loss 0.435 Test accuracy 70.190\n",
      "\n",
      "Test set: Average loss: 1.0755 \n",
      "Accuracy: 7288/10000 (72.88%)\n",
      "\n",
      "Round 168, Train average loss 0.486 Test accuracy 72.880\n",
      "\n",
      "Test set: Average loss: 1.3741 \n",
      "Accuracy: 6929/10000 (69.29%)\n",
      "\n",
      "Round 169, Train average loss 0.423 Test accuracy 69.290\n",
      "\n",
      "Test set: Average loss: 1.2926 \n",
      "Accuracy: 7034/10000 (70.34%)\n",
      "\n",
      "Round 170, Train average loss 0.508 Test accuracy 70.340\n",
      "\n",
      "Test set: Average loss: 1.0839 \n",
      "Accuracy: 7306/10000 (73.06%)\n",
      "\n",
      "Round 171, Train average loss 0.457 Test accuracy 73.060\n",
      "\n",
      "Test set: Average loss: 1.0967 \n",
      "Accuracy: 7318/10000 (73.18%)\n",
      "\n",
      "Round 172, Train average loss 0.438 Test accuracy 73.180\n",
      "\n",
      "Test set: Average loss: 1.1634 \n",
      "Accuracy: 7223/10000 (72.23%)\n",
      "\n",
      "Round 173, Train average loss 0.398 Test accuracy 72.230\n",
      "\n",
      "Test set: Average loss: 1.1023 \n",
      "Accuracy: 7328/10000 (73.28%)\n",
      "\n",
      "Round 174, Train average loss 0.523 Test accuracy 73.280\n",
      "\n",
      "Test set: Average loss: 1.1461 \n",
      "Accuracy: 7137/10000 (71.37%)\n",
      "\n",
      "Round 175, Train average loss 0.381 Test accuracy 71.370\n",
      "\n",
      "Test set: Average loss: 1.1369 \n",
      "Accuracy: 7252/10000 (72.52%)\n",
      "\n",
      "Round 176, Train average loss 0.421 Test accuracy 72.520\n",
      "\n",
      "Test set: Average loss: 1.1641 \n",
      "Accuracy: 7251/10000 (72.51%)\n",
      "\n",
      "Round 177, Train average loss 0.389 Test accuracy 72.510\n",
      "\n",
      "Test set: Average loss: 1.1488 \n",
      "Accuracy: 7197/10000 (71.97%)\n",
      "\n",
      "Round 178, Train average loss 0.450 Test accuracy 71.970\n",
      "\n",
      "Test set: Average loss: 1.2407 \n",
      "Accuracy: 7154/10000 (71.54%)\n",
      "\n",
      "Round 179, Train average loss 0.491 Test accuracy 71.540\n",
      "\n",
      "Test set: Average loss: 1.1391 \n",
      "Accuracy: 7134/10000 (71.34%)\n",
      "\n",
      "Round 180, Train average loss 0.387 Test accuracy 71.340\n",
      "\n",
      "Test set: Average loss: 1.1816 \n",
      "Accuracy: 7083/10000 (70.83%)\n",
      "\n",
      "Round 181, Train average loss 0.485 Test accuracy 70.830\n",
      "\n",
      "Test set: Average loss: 1.1470 \n",
      "Accuracy: 7172/10000 (71.72%)\n",
      "\n",
      "Round 182, Train average loss 0.422 Test accuracy 71.720\n",
      "\n",
      "Test set: Average loss: 1.0674 \n",
      "Accuracy: 7293/10000 (72.93%)\n",
      "\n",
      "Round 183, Train average loss 0.459 Test accuracy 72.930\n",
      "\n",
      "Test set: Average loss: 1.2886 \n",
      "Accuracy: 7119/10000 (71.19%)\n",
      "\n",
      "Round 184, Train average loss 0.428 Test accuracy 71.190\n",
      "\n",
      "Test set: Average loss: 1.2749 \n",
      "Accuracy: 7074/10000 (70.74%)\n",
      "\n",
      "Round 185, Train average loss 0.388 Test accuracy 70.740\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Average loss: 1.4190 \n",
      "Accuracy: 6852/10000 (68.52%)\n",
      "\n",
      "Round 186, Train average loss 0.451 Test accuracy 68.520\n",
      "\n",
      "Test set: Average loss: 1.2811 \n",
      "Accuracy: 7048/10000 (70.48%)\n",
      "\n",
      "Round 187, Train average loss 0.371 Test accuracy 70.480\n",
      "\n",
      "Test set: Average loss: 1.2034 \n",
      "Accuracy: 7076/10000 (70.76%)\n",
      "\n",
      "Round 188, Train average loss 0.517 Test accuracy 70.760\n",
      "\n",
      "Test set: Average loss: 1.3410 \n",
      "Accuracy: 6981/10000 (69.81%)\n",
      "\n",
      "Round 189, Train average loss 0.406 Test accuracy 69.810\n",
      "\n",
      "Test set: Average loss: 1.3051 \n",
      "Accuracy: 6950/10000 (69.50%)\n",
      "\n",
      "Round 190, Train average loss 0.398 Test accuracy 69.500\n",
      "\n",
      "Test set: Average loss: 1.1477 \n",
      "Accuracy: 7337/10000 (73.37%)\n",
      "\n",
      "Round 191, Train average loss 0.397 Test accuracy 73.370\n",
      "\n",
      "Test set: Average loss: 1.2938 \n",
      "Accuracy: 7042/10000 (70.42%)\n",
      "\n",
      "Round 192, Train average loss 0.443 Test accuracy 70.420\n",
      "\n",
      "Test set: Average loss: 1.1590 \n",
      "Accuracy: 7246/10000 (72.46%)\n",
      "\n",
      "Round 193, Train average loss 0.394 Test accuracy 72.460\n",
      "\n",
      "Test set: Average loss: 1.1847 \n",
      "Accuracy: 7223/10000 (72.23%)\n",
      "\n",
      "Round 194, Train average loss 0.441 Test accuracy 72.230\n",
      "\n",
      "Test set: Average loss: 1.1474 \n",
      "Accuracy: 7230/10000 (72.30%)\n",
      "\n",
      "Round 195, Train average loss 0.358 Test accuracy 72.300\n",
      "\n",
      "Test set: Average loss: 1.3652 \n",
      "Accuracy: 7095/10000 (70.95%)\n",
      "\n",
      "Round 196, Train average loss 0.406 Test accuracy 70.950\n",
      "\n",
      "Test set: Average loss: 1.2211 \n",
      "Accuracy: 7126/10000 (71.26%)\n",
      "\n",
      "Round 197, Train average loss 0.415 Test accuracy 71.260\n",
      "\n",
      "Test set: Average loss: 1.0503 \n",
      "Accuracy: 7308/10000 (73.08%)\n",
      "\n",
      "Round 198, Train average loss 0.421 Test accuracy 73.080\n",
      "\n",
      "Test set: Average loss: 1.1269 \n",
      "Accuracy: 7257/10000 (72.57%)\n",
      "\n",
      "Round 199, Train average loss 0.408 Test accuracy 72.570\n"
     ]
    }
   ],
   "source": [
    "from models.Nets import NIN,CNN_moderate, CNNCifar3\n",
    "from models.vgg import *\n",
    "from models.vggmodel import *\n",
    "\n",
    "import math\n",
    "\n",
    "\n",
    "lr_array = [0.001]\n",
    "\n",
    "\n",
    "args.local_ep = 1\n",
    "args.local_bs = 50\n",
    "args.weight_decay = 3e-4\n",
    "\n",
    "N = 40\n",
    "K = 8\n",
    "\n",
    "N_trials = 1\n",
    "Max_iter = 200\n",
    "\n",
    "args.opt = 'ADAM'\n",
    "\n",
    "args.local_ep = 1\n",
    "args.local_bs = 128\n",
    "\n",
    "\n",
    "acc_test_arr_w_random_asymDrop  = np.zeros((len(lr_array), N_trials, Max_iter))\n",
    "loss_test_arr_w_random_asymDrop = np.zeros((len(lr_array), N_trials, Max_iter))\n",
    "\n",
    "\n",
    "for trial_idx in range(N_trials):\n",
    "    \n",
    "    for lr_idx in range(len(lr_array)):\n",
    "        \n",
    "        args.lr = lr_array[lr_idx]\n",
    "        \n",
    "        P_w_random_asymDrop = []\n",
    "        \n",
    "        print()\n",
    "        print('Learning Rate =',args.lr)\n",
    "        print()\n",
    "#         net_glob = vgg11()\n",
    "        net_glob = VGG_origin('VGG11')\n",
    "#         net_glob = VGG('VGG11')\n",
    "\n",
    "        net_glob = net_glob.cuda()\n",
    "        print(net_glob)\n",
    "\n",
    "        net_glob.train()\n",
    "\n",
    "        # copy weights\n",
    "        w_glob = net_glob.state_dict()\n",
    "        for iter in range(Max_iter): #args.epochs\n",
    "            w_locals, loss_locals = [], []\n",
    "            \n",
    "            if iter == 400 or iter == 800:\n",
    "                args.lr = args.lr * 0.4\n",
    "                \n",
    "                \n",
    "            ###############################\n",
    "            # 0. Dropout Realization\n",
    "            ###############################    \n",
    "            \n",
    "            u = np.ones((N,))\n",
    "            for u_idx in range(N):\n",
    "                p_sel = p_per_user[u_idx]\n",
    "                u[u_idx] = np.random.binomial(1, 1-p_sel, size=1)[0]\n",
    "            \n",
    "            result = np.where(u == 1)\n",
    "            drop_result = np.where(u == 0)\n",
    "\n",
    "            ###############################\n",
    "            # 1. Weighted Random Selection\n",
    "            ###############################\n",
    "\n",
    "            if iter == 0:\n",
    "                idxs_users = np.random.choice(result[0], K, replace=False)\n",
    "#                 print('select=',select)\n",
    "            else:\n",
    "                P = np.array(P_w_random_asymDrop)\n",
    "                P_sum = np.sum(P, axis=0).astype(int)\n",
    "        \n",
    "                for i in drop_result[0]:\n",
    "                    P_sum[i] = Max_iter + 1   \n",
    "\n",
    "                P_sum_sort = P_sum.argsort()\n",
    "\n",
    "                idxs_users = P_sum_sort[:K]\n",
    "            \n",
    "\n",
    "            p_tmp = np.zeros(N)\n",
    "            p_tmp[idxs_users] = 1\n",
    "\n",
    "            P_w_random_asymDrop.append(p_tmp)\n",
    "\n",
    "\n",
    "        #     idxs_users = np.random.choice(range(N), K, replace=False)\n",
    "            for idx in idxs_users:\n",
    "        #         print(idx)\n",
    "                local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])\n",
    "                w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))\n",
    "                w_locals.append(copy.deepcopy(w))\n",
    "                loss_locals.append(copy.deepcopy(loss))\n",
    "            # update global weights\n",
    "            w_glob = FedAvg(w_locals)\n",
    "\n",
    "            # copy weight to net_glob\n",
    "            net_glob.load_state_dict(w_glob)\n",
    "\n",
    "            # print loss\n",
    "            loss_avg = sum(loss_locals) / len(loss_locals)\n",
    "\n",
    "    #         loss_train.append(loss_avg)\n",
    "\n",
    "            acc_test, loss_test = test_img(net_glob, dataset_test, args)\n",
    "            acc_test_arr_w_random_asymDrop[lr_idx][trial_idx][iter]  = acc_test\n",
    "            loss_test_arr_w_random_asymDrop[lr_idx][trial_idx][iter] = loss_test\n",
    "            if iter % 1 ==0:\n",
    "                print('Round {:3d}, Train average loss {:.3f} Test accuracy {:.3f}'.format(iter, loss_avg,acc_test))\n",
    "            #print(loss_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebe6bb96",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kernel_env_py38",
   "language": "python",
   "name": "env_py38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
